Spaces:
Sleeping
Sleeping
fix: align rl-state API shape with frontend, add error boundary
Browse files- backend/api/demo.py +30 -7
- frontend/src/components/PerformanceGraph.tsx +5 -1
- frontend/src/main.tsx +39 -2
backend/api/demo.py
CHANGED
|
@@ -438,17 +438,40 @@ async def run_benchmark(req: BenchmarkRequest):
|
|
| 438 |
|
| 439 |
@router.get("/rl-state")
|
| 440 |
async def get_rl_state():
|
|
|
|
| 441 |
state = get_bandit_state()
|
|
|
|
|
|
|
| 442 |
action_names = [REPAIR_ACTION_NAMES[RepairAction(i)] for i in range(8)]
|
| 443 |
-
|
| 444 |
-
|
|
|
|
|
|
|
| 445 |
for i, name in enumerate(action_names)
|
| 446 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 447 |
return {
|
| 448 |
-
"
|
| 449 |
-
"
|
| 450 |
-
"
|
| 451 |
-
"
|
|
|
|
|
|
|
| 452 |
}
|
| 453 |
|
| 454 |
|
|
|
|
| 438 |
|
| 439 |
@router.get("/rl-state")
|
| 440 |
async def get_rl_state():
|
| 441 |
+
from rl.experience import get_metrics
|
| 442 |
state = get_bandit_state()
|
| 443 |
+
metrics = get_metrics()
|
| 444 |
+
|
| 445 |
action_names = [REPAIR_ACTION_NAMES[RepairAction(i)] for i in range(8)]
|
| 446 |
+
|
| 447 |
+
# Build actionDistribution as array [{action, count}] expected by frontend
|
| 448 |
+
action_distribution = [
|
| 449 |
+
{"action": name, "count": state["action_counts"][i]}
|
| 450 |
for i, name in enumerate(action_names)
|
| 451 |
+
]
|
| 452 |
+
|
| 453 |
+
# Build episodes array [{episode, totalReward, successRate}] from reward_history
|
| 454 |
+
reward_history: list[float] = metrics.reward_history or []
|
| 455 |
+
total_eps = max(metrics.total_episodes, len(reward_history))
|
| 456 |
+
episodes = [
|
| 457 |
+
{
|
| 458 |
+
"episode": i + 1,
|
| 459 |
+
"totalReward": round(r, 3),
|
| 460 |
+
"successRate": round(metrics.success_rate, 3),
|
| 461 |
+
}
|
| 462 |
+
for i, r in enumerate(reward_history)
|
| 463 |
+
]
|
| 464 |
+
|
| 465 |
+
from gepa.optimizer import get_gepa
|
| 466 |
+
gepa = get_gepa()
|
| 467 |
+
|
| 468 |
return {
|
| 469 |
+
"totalEpisodes": total_eps,
|
| 470 |
+
"successRate": round(metrics.success_rate, 3),
|
| 471 |
+
"currentAlpha": round(state["alpha"], 4),
|
| 472 |
+
"episodes": episodes,
|
| 473 |
+
"actionDistribution": action_distribution,
|
| 474 |
+
"currentGeneration": gepa.current_generation,
|
| 475 |
}
|
| 476 |
|
| 477 |
|
frontend/src/components/PerformanceGraph.tsx
CHANGED
|
@@ -66,7 +66,11 @@ export function PerformanceGraph() {
|
|
| 66 |
)
|
| 67 |
}
|
| 68 |
|
| 69 |
-
const
|
|
|
|
|
|
|
|
|
|
|
|
|
| 70 |
|
| 71 |
return (
|
| 72 |
<div className="flex flex-col gap-3">
|
|
|
|
| 66 |
)
|
| 67 |
}
|
| 68 |
|
| 69 |
+
const totalEpisodes: number = rlState.totalEpisodes ?? 0
|
| 70 |
+
const successRate: number = rlState.successRate ?? 0
|
| 71 |
+
const currentAlpha: number = rlState.currentAlpha ?? 0
|
| 72 |
+
const episodes: { episode: number; totalReward: number; successRate: number }[] = Array.isArray(rlState.episodes) ? rlState.episodes : []
|
| 73 |
+
const actionDistribution: { action: string; count: number }[] = Array.isArray(rlState.actionDistribution) ? rlState.actionDistribution : []
|
| 74 |
|
| 75 |
return (
|
| 76 |
<div className="flex flex-col gap-3">
|
frontend/src/main.tsx
CHANGED
|
@@ -12,8 +12,45 @@ try {
|
|
| 12 |
document.documentElement.setAttribute('data-theme', 'dark')
|
| 13 |
}
|
| 14 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 15 |
ReactDOM.createRoot(document.getElementById('root')!).render(
|
| 16 |
-
<
|
| 17 |
<App />
|
| 18 |
-
</
|
| 19 |
)
|
|
|
|
| 12 |
document.documentElement.setAttribute('data-theme', 'dark')
|
| 13 |
}
|
| 14 |
|
| 15 |
+
class ErrorBoundary extends React.Component<
|
| 16 |
+
{ children: React.ReactNode },
|
| 17 |
+
{ error: Error | null }
|
| 18 |
+
> {
|
| 19 |
+
constructor(props: { children: React.ReactNode }) {
|
| 20 |
+
super(props)
|
| 21 |
+
this.state = { error: null }
|
| 22 |
+
}
|
| 23 |
+
static getDerivedStateFromError(error: Error) {
|
| 24 |
+
return { error }
|
| 25 |
+
}
|
| 26 |
+
render() {
|
| 27 |
+
if (this.state.error) {
|
| 28 |
+
return (
|
| 29 |
+
<div style={{
|
| 30 |
+
height: '100vh', display: 'flex', alignItems: 'center', justifyContent: 'center',
|
| 31 |
+
background: '#08080d', color: '#fff', flexDirection: 'column', gap: 16, padding: 32,
|
| 32 |
+
fontFamily: 'monospace'
|
| 33 |
+
}}>
|
| 34 |
+
<div style={{ fontSize: 14, color: '#ef4444', marginBottom: 8 }}>Runtime error</div>
|
| 35 |
+
<pre style={{ fontSize: 11, color: '#9ca3af', maxWidth: 600, overflow: 'auto' }}>
|
| 36 |
+
{this.state.error.message}
|
| 37 |
+
</pre>
|
| 38 |
+
<button
|
| 39 |
+
onClick={() => window.location.reload()}
|
| 40 |
+
style={{ marginTop: 8, padding: '8px 16px', background: '#8b5cf6', border: 'none',
|
| 41 |
+
color: '#fff', borderRadius: 8, cursor: 'pointer', fontSize: 12 }}
|
| 42 |
+
>
|
| 43 |
+
Reload
|
| 44 |
+
</button>
|
| 45 |
+
</div>
|
| 46 |
+
)
|
| 47 |
+
}
|
| 48 |
+
return this.props.children
|
| 49 |
+
}
|
| 50 |
+
}
|
| 51 |
+
|
| 52 |
ReactDOM.createRoot(document.getElementById('root')!).render(
|
| 53 |
+
<ErrorBoundary>
|
| 54 |
<App />
|
| 55 |
+
</ErrorBoundary>
|
| 56 |
)
|