File size: 13,942 Bytes
6252f54
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
"""AI Learning Engine tab — enrichment heatmap + reward curve + training control."""

import html

import streamlit as st
import plotly.graph_objects as go
import plotly.express as px
import pandas as pd

from frontend.utils.api_client import get_training_metrics, get_training_coverage, trigger_training
from frontend.utils.terminology import domain_label

# ── Tooltip helper ────────────────────────────────────────────────────────────
_TIP = 'cursor:help;border-bottom:1px dotted #9ca3af;text-decoration:none'

def _t(term: str, defn: str) -> str:
    """Wrap a term in an HTML abbr tag — browser shows defn on hover."""
    safe = defn.replace('"', '"').replace("'", "'")
    return f'<abbr title="{safe}" style="{_TIP}">{term}</abbr>'


def render_training_tab():
    """Render the "AI Learning Engine" Streamlit tab.

    Fetches training coverage and training metrics through the API client,
    then draws, top to bottom: an introductory paragraph, four summary
    metrics, the knowledge-graph enrichment heatmap, the reward-progression
    charts, the training-control form and a table of the latest sessions.
    """
    st.subheader("AI Learning Engine")
    st.markdown(
        "The "
        + _t("AI Prioritisation Engine",
             "An autonomous reinforcement learning system that trains on every "
             "strategic domain in the knowledge graph. It learns to rank capabilities "
             "by value, feasibility, and governance alignment — producing better "
             "roadmap recommendations with every session.")
        + " trains on every strategic domain in the knowledge graph. "
        "Each learning session is recorded as a node in Neo4j — the graph continuously "
        "improves itself, grounding capability prioritisation in sector-specific "
        + _t("governance frameworks",
             "Industry-recognised standards (e.g. TOGAF, ISO 27001, COBIT) that define "
             "how an enterprise should structure, secure, and operate its architecture.")
        + " and "
        + _t("innovation drivers",
             "Emerging technology trends (e.g. Generative AI, Cloud-Native, ESG) that "
             "signal where investment should be directed to stay competitive.")
        + ".",
        unsafe_allow_html=True,
    )

    coverage = get_training_coverage()
    metrics = get_training_metrics()

    _render_summary_metrics(coverage, metrics)
    st.divider()
    _render_enrichment_heatmap(coverage)
    st.divider()
    _render_reward_progression(metrics)
    st.divider()
    _render_training_controls()
    _render_latest_runs(metrics)


def _render_summary_metrics(coverage: list[dict], metrics: list[dict]) -> None:
    """Section A: four headline metrics summarising coverage and rewards."""
    total = len(coverage)
    trained = sum(1 for d in coverage if d.get("drl_trained"))
    std_enriched = sum(1 for d in coverage if d.get("standard_enriched"))
    trend_enriched = sum(1 for d in coverage if d.get("trend_enriched"))
    rewards = [m["final_reward"] for m in metrics if m.get("final_reward") is not None]
    avg_reward = round(sum(rewards) / len(rewards), 4) if rewards else None

    c1, c2, c3, c4 = st.columns(4)
    c1.metric(
        "Strategic Domains Trained",
        f"{trained} / {total}",
        help=(
            "Number of enterprise domains where the AI policy has completed at least one "
            "training run. Trained domains produce significantly better capability "
            "prioritisation — the AI learns the value-vs-complexity trade-off specific "
            "to that domain's real capability data."
        ),
    )
    c2.metric(
        "Governance Frameworks Enriched",
        f"{std_enriched} / {total}",
        help=(
            "Domains linked to an industry standard (e.g. TOGAF, ISO 27001, COBIT) with "
            "detailed compliance requirements loaded. These ground roadmap outputs in "
            "recognised governance structures rather than generic best-guess advice."
        ),
    )
    c3.metric(
        "Innovation Drivers Enriched",
        f"{trend_enriched} / {total}",
        help=(
            "Domains linked to a technology trend (e.g. Generative AI, ESG, Cloud-Native) "
            "with measurable business impact data. Used to weight forward-looking "
            "capabilities higher when generating strategic roadmaps."
        ),
    )
    c4.metric(
        "Avg Prioritisation Reward",
        f"{avg_reward:.4f}" if avg_reward is not None else "—",
        help=(
            "Mean final reward across all AI training sessions. Range: −1.0 to +1.0. "
            "Higher values mean the policy reliably ranks high-value, feasible capabilities "
            "above complex, high-risk ones. A score above 0.5 is considered well-converged."
        ),
    )


def _render_enrichment_heatmap(coverage: list[dict]) -> None:
    """Section B: per-domain heatmap of the three enrichment flags (0 = no, 1 = yes)."""
    if not coverage:
        st.info("No coverage data yet. Run training to populate.")
        return

    st.markdown("#### Knowledge Graph Enrichment Coverage")
    st.markdown(
        "Each row is a strategic domain. Columns show three dimensions of "
        "knowledge enrichment: "
        + _t("Governance Framework",
             "Whether the domain is linked to an industry standard with full "
             "compliance requirements (e.g. TOGAF for architecture, ISO 27001 "
             "for security, HL7 FHIR for healthcare).")
        + ", "
        + _t("Innovation Driver",
             "Whether the domain is linked to a technology trend with documented "
             "business impact (e.g. Generative AI, ESG Reporting, Blockchain).")
        + ", and "
        + _t("AI Trained",
             "Whether the reinforcement learning policy has run at least one "
             "training session on this domain's capability data. "
             "Green = fully enriched · Red = not yet enriched.")
        + ".",
        unsafe_allow_html=True,
    )

    domains = [domain_label(d["domain"]) for d in coverage]
    cols = ["Governance Framework", "Innovation Driver", "AI Trained"]
    # One row per domain; each cell is the boolean flag coerced to 0/1.
    z = [
        [
            int(d.get("standard_enriched", False)),
            int(d.get("trend_enriched", False)),
            int(d.get("drl_trained", False)),
        ]
        for d in coverage
    ]

    fig_heat = go.Figure(go.Heatmap(
        z=z,
        x=cols,
        y=domains,
        colorscale=[[0, "#e74c3c"], [0.5, "#f39c12"], [1, "#27ae60"]],
        zmin=0, zmax=1,
        showscale=False,
        xgap=2, ygap=1,
    ))
    fig_heat.update_layout(
        # Grow with the number of rows so domain labels stay legible.
        height=max(400, len(domains) * 14),
        margin=dict(l=200, r=20, t=20, b=20),
        yaxis=dict(tickfont=dict(size=10)),
    )
    st.plotly_chart(fig_heat, width='stretch')


def _prepare_reward_frame(metrics: list[dict]) -> pd.DataFrame:
    """Return the runs that have a ``final_reward``, oldest first, indexed 0..n-1.

    Rows are sorted by ``ts`` when that column exists. The index is renumbered
    *after* sorting so it can serve directly as the chronological run number
    for the x-axis and for the rolling mean. If no run carries
    ``final_reward`` at all, an empty frame is returned instead of raising.
    """
    df = pd.DataFrame(metrics)
    if "final_reward" not in df.columns:
        return df.iloc[0:0].copy()
    df = df.dropna(subset=["final_reward"])
    if "ts" in df.columns:
        df = df.sort_values("ts", kind="stable")
    return df.reset_index(drop=True)


def _render_reward_progression(metrics: list[dict]) -> None:
    """Section C: reward-per-run scatter with rolling mean, plus per-sector averages."""
    if not metrics:
        st.info("No training runs recorded yet. Use the controls below to start training.")
        return

    st.markdown("#### AI Prioritisation Reward Progression")
    st.markdown(
        "Each point is one domain learning session. The "
        + _t("final reward",
             "A scalar score computed at the end of each training episode. "
             "It combines: capability complexity alignment (did the AI deprioritise "
             "high-risk items?), budget feasibility (do selected capabilities fit "
             "the effort envelope?), and domain coverage breadth. Range: −1 to +1.")
        + " measures how well the AI ordered that domain's capabilities. "
        "The dashed line is a "
        + _t("5-run rolling average",
             "The mean reward across the last 5 training sessions. Smooths out "
             "episode-level noise to reveal the overall learning trend. A rising "
             "trend means the policy is improving.")
        + " showing the learning trend.",
        unsafe_allow_html=True,
    )

    df = _prepare_reward_frame(metrics)
    if "domain_name" in df.columns:
        df["domain_name"] = df["domain_name"].map(domain_label, na_action="ignore")

    if df.empty:
        return

    # Only ask plotly for columns that exist; missing ones would raise.
    hover_cols = [c for c in ("domain_name", "episodes", "device", "ts") if c in df.columns]
    has_sector = "sector" in df.columns

    fig_line = px.scatter(
        df,
        x=df.index,
        y="final_reward",
        color="sector" if has_sector else None,
        hover_data=hover_cols,
        labels={"index": "Training Run #", "final_reward": "Final Reward"},
        height=350,
    )
    fig_line.add_trace(go.Scatter(
        x=df.index,
        y=df["final_reward"].rolling(5, min_periods=1).mean(),
        mode="lines",
        name="5-run avg",
        line=dict(color="white", width=2, dash="dash"),
    ))
    fig_line.update_layout(margin=dict(t=20, b=20))
    st.plotly_chart(fig_line, width='stretch')

    if has_sector:
        sector_avg = df.groupby("sector")["final_reward"].mean().reset_index()
        sector_avg.columns = ["Sector", "Avg Reward"]
        fig_bar = px.bar(
            sector_avg.sort_values("Avg Reward", ascending=True),
            x="Avg Reward", y="Sector", orientation="h",
            height=300,
            color="Avg Reward",
            color_continuous_scale="RdYlGn",
            title="Average Prioritisation Reward by Sector",
        )
        fig_bar.update_layout(margin=dict(t=40, b=20), coloraxis_showscale=False)
        st.plotly_chart(fig_bar, width='stretch')


def _render_training_controls() -> None:
    """Section D: episode/domain inputs and the button that submits a training job."""
    st.markdown("#### Run AI Learning Session")
    st.markdown(
        "Uses "
        + _t("REINFORCE",
             "A policy gradient algorithm (Williams, 1992). The AI tries different "
             "capability orderings, observes the reward signal, then adjusts its "
             "policy to make high-reward orderings more likely in future. "
             "No environment model required — learns purely from trial-and-error.")
        + " policy gradient training on the "
        + _t("AMD MI300X via ROCm",
             "AMD's Instinct MI300X is a GPU accelerator optimised for AI workloads. "
             "ROCm is AMD's open-source GPU compute platform — analogous to NVIDIA CUDA "
             "but for AMD hardware. The DRL policy trains in seconds on MI300X vs minutes on CPU.")
        + ". Each strategic domain's capabilities form the learning state — the AI engine "
        "learns optimal prioritisation grounded in real governance frameworks, "
        + _t("effort estimates",
             "typical_duration_weeks: the median number of delivery weeks for a capability "
             "based on industry benchmarks stored in the knowledge graph.")
        + ", and risk data.",
        unsafe_allow_html=True,
    )

    ctrl1, ctrl2, ctrl3 = st.columns([2, 2, 3])
    with ctrl1:
        episodes = st.number_input(
            "Learning episodes per domain",
            min_value=10, max_value=500, value=50, step=10,
            help=(
                "One episode = the AI runs through all capabilities in a domain once, "
                "receives a reward, and updates its policy. More episodes → more "
                "refined prioritisation, but takes longer. 50 is a good starting point; "
                "200+ is recommended for production-quality rankings."
            ),
        )
    with ctrl2:
        # NOTE(review): "44" is hard-coded and can drift from len(coverage).
        domain_filter = st.text_input(
            "Specific domain (blank = all 44)",
            placeholder="e.g. Healthcare Provider",
            help="Leave blank to train all 44 domains sequentially. Enter a domain name to train just that one.",
        )
    with ctrl3:
        # Spacer so the button lines up with the labelled inputs beside it.
        st.markdown("<br>", unsafe_allow_html=True)
        if st.button("Run AI Learning Session", type="primary", width='stretch'):
            with st.spinner("Submitting training job…"):
                resp = trigger_training(
                    episodes_per_domain=int(episodes),
                    # Blank/whitespace input means "all domains".
                    domain=domain_filter.strip() or None,
                )
            if resp.get("status") == "started":
                st.success(
                    f"Training started (run_id: `{resp.get('run_id')}`). "
                    "Metrics appear in the graph as each domain completes. "
                    "Refresh this tab in ~60 seconds."
                )
            else:
                st.error(f"Failed to start training: {resp.get('message', resp)}")


def _render_latest_runs(metrics: list[dict]) -> None:
    """Table of the first 20 entries of *metrics* under a "Latest" heading."""
    if not metrics:
        return

    st.markdown("##### Latest AI Learning Sessions")
    st.markdown(
        "_"
        + _t("Prioritisation Reward",
             "Final reward score at the end of training (range −1 to +1). "
             "Combines complexity alignment, budget feasibility, and coverage breadth. "
             "Higher = better capability ordering.")
        + " · "
        + _t("Avg (last 10)",
             "Rolling mean of the final reward across the last 10 episodes of "
             "the same training run. A value close to the Final Reward means the "
             "policy has stabilised; a lower value means it was still improving.")
        + "_",
        unsafe_allow_html=True,
    )
    # reindex (not []-selection) so a column missing from every record shows
    # as empty cells instead of raising KeyError.
    df_show = pd.DataFrame(metrics[:20]).reindex(
        columns=["domain_name", "sector", "episodes", "final_reward", "avg_reward_last10", "device", "ts"]
    )
    df_show["domain_name"] = df_show["domain_name"].map(domain_label, na_action="ignore")
    df_show.columns = [
        "Strategic Domain", "Sector", "Episodes",
        "Prioritisation Reward", "Avg (last 10)", "Device", "Timestamp",
    ]
    st.dataframe(df_show, width='stretch', hide_index=True)