# Scraped page header: Spaces · Running · File size: 13,942 Bytes · revision 6252f54
"""AI Learning Engine tab — enrichment heatmap + reward curve + training control."""
import html

import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import streamlit as st

from frontend.utils.api_client import get_training_metrics, get_training_coverage, trigger_training
from frontend.utils.terminology import domain_label
# ── Tooltip helper ────────────────────────────────────────────────────────────
_TIP = 'cursor:help;border-bottom:1px dotted #9ca3af;text-decoration:none'
def _t(term: str, defn: str) -> str:
"""Wrap a term in an HTML abbr tag — browser shows defn on hover."""
safe = defn.replace('"', '"').replace("'", "'")
return f'<abbr title="{safe}" style="{_TIP}">{term}</abbr>'
def render_training_tab():
    """Render the "AI Learning Engine" tab.

    Draws, in order: an introductory paragraph, summary metrics, the
    enrichment heatmap, the reward progression charts, the training controls
    and a table of the latest learning sessions. Coverage and metrics are
    fetched once via the API client and passed to the section helpers.

    Returns:
        None. All output is written to the Streamlit page.
    """
    st.subheader("AI Learning Engine")
    st.markdown(
        "The "
        + _t("AI Prioritisation Engine",
             "An autonomous reinforcement learning system that trains on every "
             "strategic domain in the knowledge graph. It learns to rank capabilities "
             "by value, feasibility, and governance alignment — producing better "
             "roadmap recommendations with every session.")
        + " trains on every strategic domain in the knowledge graph. "
        "Each learning session is recorded as a node in Neo4j — the graph continuously "
        "improves itself, grounding capability prioritisation in sector-specific "
        + _t("governance frameworks",
             "Industry-recognised standards (e.g. TOGAF, ISO 27001, COBIT) that define "
             "how an enterprise should structure, secure, and operate its architecture.")
        + " and "
        + _t("innovation drivers",
             "Emerging technology trends (e.g. Generative AI, Cloud-Native, ESG) that "
             "signal where investment should be directed to stay competitive.")
        + ".",
        unsafe_allow_html=True,
    )

    # Fall back to empty lists so a falsy response (e.g. None) renders the
    # "no data" states instead of raising on len()/iteration.
    coverage = get_training_coverage() or []
    metrics = get_training_metrics() or []

    _render_summary_metrics(coverage, metrics)
    st.divider()
    _render_enrichment_heatmap(coverage)
    st.divider()
    _render_reward_progression(metrics)
    st.divider()
    _render_training_controls()
    _render_latest_sessions(metrics)


def _render_summary_metrics(coverage, metrics):
    """Section A: show the four headline counters derived from coverage and metrics."""
    total = len(coverage)
    trained = sum(1 for d in coverage if d.get("drl_trained"))
    std_enriched = sum(1 for d in coverage if d.get("standard_enriched"))
    trend_enriched = sum(1 for d in coverage if d.get("trend_enriched"))
    # Runs without a final reward are left out of the average.
    rewards = [m["final_reward"] for m in metrics if m.get("final_reward") is not None]
    avg_reward = round(sum(rewards) / len(rewards), 4) if rewards else None

    c1, c2, c3, c4 = st.columns(4)
    c1.metric(
        "Strategic Domains Trained",
        f"{trained} / {total}",
        help=(
            "Number of enterprise domains where the AI policy has completed at least one "
            "training run. Trained domains produce significantly better capability "
            "prioritisation — the AI learns the value-vs-complexity trade-off specific "
            "to that domain's real capability data."
        ),
    )
    c2.metric(
        "Governance Frameworks Enriched",
        f"{std_enriched} / {total}",
        help=(
            "Domains linked to an industry standard (e.g. TOGAF, ISO 27001, COBIT) with "
            "detailed compliance requirements loaded. These ground roadmap outputs in "
            "recognised governance structures rather than generic best-guess advice."
        ),
    )
    c3.metric(
        "Innovation Drivers Enriched",
        f"{trend_enriched} / {total}",
        help=(
            "Domains linked to a technology trend (e.g. Generative AI, ESG, Cloud-Native) "
            "with measurable business impact data. Used to weight forward-looking "
            "capabilities higher when generating strategic roadmaps."
        ),
    )
    c4.metric(
        "Avg Prioritisation Reward",
        f"{avg_reward:.4f}" if avg_reward is not None else "—",
        help=(
            "Mean final reward across all AI training sessions. Range: −1.0 to +1.0. "
            "Higher values mean the policy reliably ranks high-value, feasible capabilities "
            "above complex, high-risk ones. A score above 0.5 is considered well-converged."
        ),
    )


def _render_enrichment_heatmap(coverage):
    """Section B: draw a domain x enrichment-flag heatmap, or an info box when empty."""
    if not coverage:
        st.info("No coverage data yet. Run training to populate.")
        return

    st.markdown("#### Knowledge Graph Enrichment Coverage")
    st.markdown(
        "Each row is a strategic domain. Columns show three dimensions of "
        "knowledge enrichment: "
        + _t("Governance Framework",
             "Whether the domain is linked to an industry standard with full "
             "compliance requirements (e.g. TOGAF for architecture, ISO 27001 "
             "for security, HL7 FHIR for healthcare).")
        + ", "
        + _t("Innovation Driver",
             "Whether the domain is linked to a technology trend with documented "
             "business impact (e.g. Generative AI, ESG Reporting, Blockchain).")
        + ", and "
        + _t("AI Trained",
             "Whether the reinforcement learning policy has run at least one "
             "training session on this domain's capability data. "
             "Green = fully enriched · Red = not yet enriched.")
        + ".",
        unsafe_allow_html=True,
    )
    domains = [domain_label(d["domain"]) for d in coverage]
    cols = ["Governance Framework", "Innovation Driver", "AI Trained"]
    # bool() first so a None flag counts as "not enriched" (matching the
    # truthiness test used for the summary counters) instead of raising.
    z = [
        [
            int(bool(d.get("standard_enriched"))),
            int(bool(d.get("trend_enriched"))),
            int(bool(d.get("drl_trained"))),
        ]
        for d in coverage
    ]
    fig_heat = go.Figure(go.Heatmap(
        z=z,
        x=cols,
        y=domains,
        colorscale=[[0, "#e74c3c"], [0.5, "#f39c12"], [1, "#27ae60"]],
        zmin=0, zmax=1,
        showscale=False,
        xgap=2, ygap=1,
    ))
    fig_heat.update_layout(
        # Grow the figure with the number of rows so tick labels stay legible.
        height=max(400, len(domains) * 14),
        margin=dict(l=200, r=20, t=20, b=20),
        yaxis=dict(tickfont=dict(size=10)),
    )
    st.plotly_chart(fig_heat, width='stretch')


def _render_reward_progression(metrics):
    """Section C: plot final reward per run with a rolling mean, plus per-sector averages."""
    if not metrics:
        st.info("No training runs recorded yet. Use the controls below to start training.")
        return

    st.markdown("#### AI Prioritisation Reward Progression")
    st.markdown(
        "Each point is one domain learning session. The "
        + _t("final reward",
             "A scalar score computed at the end of each training episode. "
             "It combines: capability complexity alignment (did the AI deprioritise "
             "high-risk items?), budget feasibility (do selected capabilities fit "
             "the effort envelope?), and domain coverage breadth. Range: −1 to +1.")
        + " measures how well the AI ordered that domain's capabilities. "
        "The dashed line is a "
        + _t("5-run rolling average",
             "The mean reward across the last 5 training sessions. Smooths out "
             "episode-level noise to reveal the overall learning trend. A rising "
             "trend means the policy is improving.")
        + " showing the learning trend.",
        unsafe_allow_html=True,
    )

    df = pd.DataFrame(metrics)
    if "final_reward" not in df.columns:
        # Nothing to plot; avoids a KeyError from dropna(subset=...).
        return
    df = df.dropna(subset=["final_reward"])
    if df.empty:
        return
    if "domain_name" in df.columns:
        df["domain_name"] = df["domain_name"].apply(domain_label)
    if "ts" in df.columns:
        df = df.sort_values("ts")
    # Renumber rows 0..n-1 so the x axis ("Training Run #") follows the sorted
    # order and has no gaps left by dropna; otherwise the points and the
    # rolling average are drawn against the pre-sort positions.
    df = df.reset_index(drop=True)

    # Only reference optional columns that are actually present.
    hover_cols = [c for c in ("domain_name", "episodes", "device", "ts") if c in df.columns]
    fig_line = px.scatter(
        df,
        x=df.index,
        y="final_reward",
        color="sector" if "sector" in df.columns else None,
        hover_data=hover_cols,
        labels={"index": "Training Run #", "final_reward": "Final Reward"},
        height=350,
    )
    fig_line.add_trace(go.Scatter(
        x=df.index,
        # min_periods=1 so the average starts at the first run.
        y=df["final_reward"].rolling(5, min_periods=1).mean(),
        mode="lines",
        name="5-run avg",
        line=dict(color="white", width=2, dash="dash"),
    ))
    fig_line.update_layout(margin=dict(t=20, b=20))
    st.plotly_chart(fig_line, width='stretch')

    if "sector" in df.columns:
        sector_avg = df.groupby("sector")["final_reward"].mean().reset_index()
        sector_avg.columns = ["Sector", "Avg Reward"]
        fig_bar = px.bar(
            sector_avg.sort_values("Avg Reward", ascending=True),
            x="Avg Reward", y="Sector", orientation="h",
            height=300,
            color="Avg Reward",
            color_continuous_scale="RdYlGn",
            title="Average Prioritisation Reward by Sector",
        )
        fig_bar.update_layout(margin=dict(t=40, b=20), coloraxis_showscale=False)
        st.plotly_chart(fig_bar, width='stretch')


def _render_training_controls():
    """Section D: show the episode/domain inputs and submit a training job on click."""
    st.markdown("#### Run AI Learning Session")
    st.markdown(
        "Uses "
        + _t("REINFORCE",
             "A policy gradient algorithm (Williams, 1992). The AI tries different "
             "capability orderings, observes the reward signal, then adjusts its "
             "policy to make high-reward orderings more likely in future. "
             "No environment model required — learns purely from trial-and-error.")
        + " policy gradient training on the "
        + _t("AMD MI300X via ROCm",
             "AMD's Instinct MI300X is a GPU accelerator optimised for AI workloads. "
             "ROCm is AMD's open-source GPU compute platform — analogous to NVIDIA CUDA "
             "but for AMD hardware. The DRL policy trains in seconds on MI300X vs minutes on CPU.")
        + ". Each strategic domain's capabilities form the learning state — the AI engine "
        "learns optimal prioritisation grounded in real governance frameworks, "
        + _t("effort estimates",
             "typical_duration_weeks: the median number of delivery weeks for a capability "
             "based on industry benchmarks stored in the knowledge graph.")
        + ", and risk data.",
        unsafe_allow_html=True,
    )

    ctrl1, ctrl2, ctrl3 = st.columns([2, 2, 3])
    with ctrl1:
        episodes = st.number_input(
            "Learning episodes per domain",
            min_value=10, max_value=500, value=50, step=10,
            help=(
                "One episode = the AI runs through all capabilities in a domain once, "
                "receives a reward, and updates its policy. More episodes → more "
                "refined prioritisation, but takes longer. 50 is a good starting point; "
                "200+ is recommended for production-quality rankings."
            ),
        )
    with ctrl2:
        # NOTE(review): the "44" in these strings is hard-coded — confirm it
        # still matches the number of domains the backend trains.
        domain_filter = st.text_input(
            "Specific domain (blank = all 44)",
            placeholder="e.g. Healthcare Provider",
            help="Leave blank to train all 44 domains sequentially. Enter a domain name to train just that one.",
        )
    with ctrl3:
        # Spacer so the button lines up with the labelled inputs beside it.
        st.markdown("<br>", unsafe_allow_html=True)
        if st.button("Run AI Learning Session", type="primary", width='stretch'):
            with st.spinner("Submitting training job…"):
                resp = trigger_training(
                    episodes_per_domain=int(episodes),
                    # Blank or whitespace-only input is sent as None.
                    domain=domain_filter.strip() or None,
                )
            if resp.get("status") == "started":
                st.success(
                    f"Training started (run_id: `{resp.get('run_id')}`). "
                    "Metrics appear in the graph as each domain completes. "
                    "Refresh this tab in ~60 seconds."
                )
            else:
                st.error(f"Failed to start training: {resp.get('message', resp)}")


def _render_latest_sessions(metrics):
    """Show the first 20 entries of ``metrics`` as a table with display headers."""
    if not metrics:
        return

    st.markdown("##### Latest AI Learning Sessions")
    st.markdown(
        "_"
        + _t("Prioritisation Reward",
             "Final reward score at the end of training (range −1 to +1). "
             "Combines complexity alignment, budget feasibility, and coverage breadth. "
             "Higher = better capability ordering.")
        + " · "
        + _t("Avg (last 10)",
             "Rolling mean of the final reward across the last 10 episodes of "
             "the same training run. A value close to the Final Reward means the "
             "policy has stabilised; a lower value means it was still improving.")
        + "_",
        unsafe_allow_html=True,
    )
    wanted = ["domain_name", "sector", "episodes", "final_reward", "avg_reward_last10", "device", "ts"]
    raw = pd.DataFrame(metrics[:20])
    # reindex() keeps the column order and fills any absent column with NaN
    # rather than raising KeyError; it also returns a new frame.
    df_show = raw.reindex(columns=wanted)
    if "domain_name" in raw.columns:
        df_show["domain_name"] = df_show["domain_name"].apply(domain_label)
    df_show.columns = [
        "Strategic Domain", "Sector", "Episodes",
        "Prioritisation Reward", "Avg (last 10)", "Device", "Timestamp",
    ]
    st.dataframe(df_show, width='stretch', hide_index=True)
|