# TheQuantEd — deploy: AMD EA Strategy Optimizer — Neo4j + FastAPI + Streamlit (commit 6252f54)
"""AI Learning Engine tab — enrichment heatmap + reward curve + training control."""
import streamlit as st
import plotly.graph_objects as go
import plotly.express as px
import pandas as pd
from frontend.utils.api_client import get_training_metrics, get_training_coverage, trigger_training
from frontend.utils.terminology import domain_label
# ── Tooltip helper ────────────────────────────────────────────────────────────
_TIP = 'cursor:help;border-bottom:1px dotted #9ca3af;text-decoration:none'
def _t(term: str, defn: str) -> str:
    """Wrap *term* in an HTML ``<abbr>`` tag so the browser shows *defn* on hover.

    The definition is interpolated into a ``title="…"`` attribute, so it must be
    HTML-attribute-escaped — otherwise a quote inside *defn* would terminate the
    attribute early and break the markup. Ampersand is escaped first so already
    produced entities are not double-escaped.
    """
    safe = (
        defn.replace("&", "&amp;")
            .replace('"', "&quot;")
            .replace("'", "&#39;")
    )
    return f'<abbr title="{safe}" style="{_TIP}">{term}</abbr>'
def _render_intro() -> None:
    """Tab header plus the explanatory paragraph with hover-tooltip terms."""
    st.subheader("AI Learning Engine")
    st.markdown(
        "The "
        + _t("AI Prioritisation Engine",
             "An autonomous reinforcement learning system that trains on every "
             "strategic domain in the knowledge graph. It learns to rank capabilities "
             "by value, feasibility, and governance alignment — producing better "
             "roadmap recommendations with every session.")
        + " trains on every strategic domain in the knowledge graph. "
        "Each learning session is recorded as a node in Neo4j — the graph continuously "
        "improves itself, grounding capability prioritisation in sector-specific "
        + _t("governance frameworks",
             "Industry-recognised standards (e.g. TOGAF, ISO 27001, COBIT) that define "
             "how an enterprise should structure, secure, and operate its architecture.")
        + " and "
        + _t("innovation drivers",
             "Emerging technology trends (e.g. Generative AI, Cloud-Native, ESG) that "
             "signal where investment should be directed to stay competitive.")
        + ".",
        unsafe_allow_html=True,
    )


def _render_summary_metrics(coverage: list, metrics: list) -> None:
    """Section A — four headline KPI cards summarising enrichment + training state."""
    total = len(coverage)
    trained = sum(1 for d in coverage if d.get("drl_trained"))
    std_enriched = sum(1 for d in coverage if d.get("standard_enriched"))
    trend_enriched = sum(1 for d in coverage if d.get("trend_enriched"))
    # Mean final reward across only those sessions that reported one.
    rewards = [m["final_reward"] for m in metrics if m.get("final_reward") is not None]
    avg_reward = round(sum(rewards) / len(rewards), 4) if rewards else None
    c1, c2, c3, c4 = st.columns(4)
    c1.metric(
        "Strategic Domains Trained",
        f"{trained} / {total}",
        help=(
            "Number of enterprise domains where the AI policy has completed at least one "
            "training run. Trained domains produce significantly better capability "
            "prioritisation — the AI learns the value-vs-complexity trade-off specific "
            "to that domain's real capability data."
        ),
    )
    c2.metric(
        "Governance Frameworks Enriched",
        f"{std_enriched} / {total}",
        help=(
            "Domains linked to an industry standard (e.g. TOGAF, ISO 27001, COBIT) with "
            "detailed compliance requirements loaded. These ground roadmap outputs in "
            "recognised governance structures rather than generic best-guess advice."
        ),
    )
    c3.metric(
        "Innovation Drivers Enriched",
        f"{trend_enriched} / {total}",
        help=(
            "Domains linked to a technology trend (e.g. Generative AI, ESG, Cloud-Native) "
            "with measurable business impact data. Used to weight forward-looking "
            "capabilities higher when generating strategic roadmaps."
        ),
    )
    c4.metric(
        "Avg Prioritisation Reward",
        f"{avg_reward:.4f}" if avg_reward is not None else "—",
        help=(
            "Mean final reward across all AI training sessions. Range: −1.0 to +1.0. "
            "Higher values mean the policy reliably ranks high-value, feasible capabilities "
            "above complex, high-risk ones. A score above 0.5 is considered well-converged."
        ),
    )


def _render_enrichment_heatmap(coverage: list) -> None:
    """Section B — red/green heatmap of per-domain knowledge-graph enrichment."""
    if not coverage:
        st.info("No coverage data yet. Run training to populate.")
        return
    st.markdown("#### Knowledge Graph Enrichment Coverage")
    st.markdown(
        "Each row is a strategic domain. Columns show three dimensions of "
        "knowledge enrichment: "
        + _t("Governance Framework",
             "Whether the domain is linked to an industry standard with full "
             "compliance requirements (e.g. TOGAF for architecture, ISO 27001 "
             "for security, HL7 FHIR for healthcare).")
        + ", "
        + _t("Innovation Driver",
             "Whether the domain is linked to a technology trend with documented "
             "business impact (e.g. Generative AI, ESG Reporting, Blockchain).")
        + ", and "
        + _t("AI Trained",
             "Whether the reinforcement learning policy has run at least one "
             "training session on this domain's capability data. "
             "Green = fully enriched · Red = not yet enriched.")
        + ".",
        unsafe_allow_html=True,
    )
    domains = [domain_label(d["domain"]) for d in coverage]
    cols = ["Governance Framework", "Innovation Driver", "AI Trained"]
    # One row per domain; 1 = enriched (green), 0 = missing (red).
    z = [
        [
            int(d.get("standard_enriched", False)),
            int(d.get("trend_enriched", False)),
            int(d.get("drl_trained", False)),
        ]
        for d in coverage
    ]
    fig_heat = go.Figure(go.Heatmap(
        z=z,
        x=cols,
        y=domains,
        colorscale=[[0, "#e74c3c"], [0.5, "#f39c12"], [1, "#27ae60"]],
        zmin=0, zmax=1,
        showscale=False,
        xgap=2, ygap=1,
    ))
    fig_heat.update_layout(
        # Grow with the domain count so every row label stays readable.
        height=max(400, len(domains) * 14),
        margin=dict(l=200, r=20, t=20, b=20),
        yaxis=dict(tickfont=dict(size=10)),
    )
    st.plotly_chart(fig_heat, width='stretch')


def _render_reward_progression(metrics: list) -> None:
    """Section C — per-run reward scatter with rolling average, plus sector bars."""
    if not metrics:
        st.info("No training runs recorded yet. Use the controls below to start training.")
        return
    st.markdown("#### AI Prioritisation Reward Progression")
    st.markdown(
        "Each point is one domain learning session. The "
        + _t("final reward",
             "A scalar score computed at the end of each training episode. "
             "It combines: capability complexity alignment (did the AI deprioritise "
             "high-risk items?), budget feasibility (do selected capabilities fit "
             "the effort envelope?), and domain coverage breadth. Range: −1 to +1.")
        + " measures how well the AI ordered that domain's capabilities. "
        "The dashed line is a "
        + _t("5-run rolling average",
             "The mean reward across the last 5 training sessions. Smooths out "
             "episode-level noise to reveal the overall learning trend. A rising "
             "trend means the policy is improving.")
        + " showing the learning trend.",
        unsafe_allow_html=True,
    )
    df = pd.DataFrame(metrics)
    df = df.dropna(subset=["final_reward"])
    if "domain_name" in df.columns:
        df["domain_name"] = df["domain_name"].apply(domain_label)
    if not df.empty:
        if "ts" in df.columns:
            df = df.sort_values("ts")
        fig_line = px.scatter(
            df,
            x=df.index,
            y="final_reward",
            color="sector",
            hover_data=["domain_name", "episodes", "device", "ts"],
            labels={"index": "Training Run #", "final_reward": "Final Reward"},
            height=350,
        )
        # Dashed rolling mean overlays the scatter to show the learning trend.
        fig_line.add_trace(go.Scatter(
            x=df.index,
            y=df["final_reward"].rolling(5, min_periods=1).mean(),
            mode="lines",
            name="5-run avg",
            line=dict(color="white", width=2, dash="dash"),
        ))
        fig_line.update_layout(margin=dict(t=20, b=20))
        st.plotly_chart(fig_line, width='stretch')
    if "sector" in df.columns and not df.empty:
        sector_avg = df.groupby("sector")["final_reward"].mean().reset_index()
        sector_avg.columns = ["Sector", "Avg Reward"]
        fig_bar = px.bar(
            sector_avg.sort_values("Avg Reward", ascending=True),
            x="Avg Reward", y="Sector", orientation="h",
            height=300,
            color="Avg Reward",
            color_continuous_scale="RdYlGn",
            title="Average Prioritisation Reward by Sector",
        )
        fig_bar.update_layout(margin=dict(t=40, b=20), coloraxis_showscale=False)
        st.plotly_chart(fig_bar, width='stretch')


def _render_training_controls() -> None:
    """Section D — episode/domain inputs plus the button that submits a training job."""
    st.markdown("#### Run AI Learning Session")
    st.markdown(
        "Uses "
        + _t("REINFORCE",
             "A policy gradient algorithm (Williams, 1992). The AI tries different "
             "capability orderings, observes the reward signal, then adjusts its "
             "policy to make high-reward orderings more likely in future. "
             "No environment model required — learns purely from trial-and-error.")
        + " policy gradient training on the "
        + _t("AMD MI300X via ROCm",
             "AMD's Instinct MI300X is a GPU accelerator optimised for AI workloads. "
             "ROCm is AMD's open-source GPU compute platform — analogous to NVIDIA CUDA "
             "but for AMD hardware. The DRL policy trains in seconds on MI300X vs minutes on CPU.")
        + ". Each strategic domain's capabilities form the learning state — the AI engine "
        "learns optimal prioritisation grounded in real governance frameworks, "
        + _t("effort estimates",
             "typical_duration_weeks: the median number of delivery weeks for a capability "
             "based on industry benchmarks stored in the knowledge graph.")
        + ", and risk data.",
        unsafe_allow_html=True,
    )
    ctrl1, ctrl2, ctrl3 = st.columns([2, 2, 3])
    with ctrl1:
        episodes = st.number_input(
            "Learning episodes per domain",
            min_value=10, max_value=500, value=50, step=10,
            help=(
                "One episode = the AI runs through all capabilities in a domain once, "
                "receives a reward, and updates its policy. More episodes → more "
                "refined prioritisation, but takes longer. 50 is a good starting point; "
                "200+ is recommended for production-quality rankings."
            ),
        )
    with ctrl2:
        domain_filter = st.text_input(
            "Specific domain (blank = all 44)",
            placeholder="e.g. Healthcare Provider",
            help="Leave blank to train all 44 domains sequentially. Enter a domain name to train just that one.",
        )
    with ctrl3:
        st.markdown("<br>", unsafe_allow_html=True)  # vertical spacer to align the button
        if st.button("Run AI Learning Session", type="primary", width='stretch'):
            with st.spinner("Submitting training job…"):
                resp = trigger_training(
                    episodes_per_domain=int(episodes),
                    domain=domain_filter.strip() or None,
                )
            if resp.get("status") == "started":
                st.success(
                    f"Training started (run_id: `{resp.get('run_id')}`). "
                    "Metrics appear in the graph as each domain completes. "
                    "Refresh this tab in ~60 seconds."
                )
            else:
                st.error(f"Failed to start training: {resp.get('message', resp)}")


def _render_latest_runs(metrics: list) -> None:
    """Table of the 20 most recent AI learning sessions (assumes metrics is newest-first)."""
    if not metrics:
        return
    st.markdown("##### Latest AI Learning Sessions")
    st.markdown(
        "_"
        + _t("Prioritisation Reward",
             "Final reward score at the end of training (range −1 to +1). "
             "Combines complexity alignment, budget feasibility, and coverage breadth. "
             "Higher = better capability ordering.")
        + " · "
        + _t("Avg (last 10)",
             "Rolling mean of the final reward across the last 10 episodes of "
             "the same training run. A value close to the Final Reward means the "
             "policy has stabilised; a lower value means it was still improving.")
        + "_",
        unsafe_allow_html=True,
    )
    df_show = pd.DataFrame(metrics[:20])[
        ["domain_name", "sector", "episodes", "final_reward", "avg_reward_last10", "device", "ts"]
    ].copy()
    df_show["domain_name"] = df_show["domain_name"].apply(domain_label)
    df_show.columns = [
        "Strategic Domain", "Sector", "Episodes",
        "Prioritisation Reward", "Avg (last 10)", "Device", "Timestamp",
    ]
    st.dataframe(df_show, width='stretch', hide_index=True)


def render_training_tab():
    """Render the AI Learning Engine tab.

    Layout, top to bottom: intro copy, summary KPI cards (Section A),
    enrichment heatmap (Section B), reward-progression charts (Section C),
    training controls (Section D), and the latest-runs table. Coverage and
    per-session metrics are fetched once from the backend API and shared
    across sections.
    """
    _render_intro()
    coverage = get_training_coverage()
    metrics = get_training_metrics()
    _render_summary_metrics(coverage, metrics)
    st.divider()
    _render_enrichment_heatmap(coverage)
    st.divider()
    _render_reward_progression(metrics)
    st.divider()
    _render_training_controls()
    _render_latest_runs(metrics)