Spaces:
Sleeping
Sleeping
| 'use client'; | |
| import { useState, useEffect, useRef, useCallback } from 'react'; | |
| import { Activity, TrendingUp, AlertCircle, CheckCircle2, Circle } from 'lucide-react'; | |
/**
 * One row of the training feed, as delivered by the training feed endpoint.
 * Field names mirror the wire format (snake_case) and must not be renamed.
 */
export interface TrainingEntry {
  // --- identity / ordering ---
  /** Training step number; used as the row key and as the `since` polling cursor. */
  step: number;
  /** Time the step was recorded (unit not visible here — NOTE(review): confirm s vs ms). */
  timestamp: number;
  /** Name of the task this step was evaluated on. */
  task_name: string;

  // --- scoring ---
  /** Scalar reward for the step. */
  reward: number;
  /** Shape similarity score; displayed as a percentage, so presumably in [0, 1] — confirm. */
  shape_similarity: number;
  /** Whether the produced result was considered valid. */
  is_valid: boolean;
  /** Error message for the step, or `null` when none occurred. */
  error: string | null;

  // --- geometry payloads ---
  /** Opaque fold payload; its schema is not defined in this file. */
  fold_data: any | null;
  /** Final positions produced by the step (presumably one coordinate tuple per point). */
  final_positions: number[][];
  /** Target positions the result is compared against (same layout as `final_positions`). */
  target_positions: number[][];
}
/** Run-level aggregate statistics returned alongside each feed poll. */
interface TrainingStats {
  /** Best shape similarity seen so far (0 is treated as "none yet" by the dashboard). */
  best_similarity: number;
  /** Best reward seen so far (the dashboard uses -999 as its "none yet" placeholder). */
  best_reward: number;
  /** Total number of training steps recorded. */
  total_steps: number;
}
/** JSON body of `GET /api/env/training/feed?since=<step>`. */
interface FeedResponse {
  /** Current aggregate statistics for the run. */
  stats: TrainingStats;
  /** Entries newer than the requested `since` step, presumably in ascending step order. */
  entries: TrainingEntry[];
}
/** Props accepted by the training dashboard component. */
interface TrainingDashboardProps {
  /** Optional listener notified with each entry the dashboard selects (by click or auto-follow). */
  onEntrySelect?: (selected: TrainingEntry) => void;
}
/** Maximum number of entries kept in the in-memory activity feed. */
const MAX_FEED_ENTRIES = 50;
/** Number of most recent entries plotted in the reward history chart. */
const CHART_WINDOW = 30;
/** Delay between feed polls, in milliseconds. */
const POLL_INTERVAL_MS = 2000;
/** Placeholder best reward meaning "no reward reported yet". */
const UNSET_BEST_REWARD = -999;

/** Tailwind text colour for a reward: >=15 green, >=5 yellow, >=0 orange, otherwise red. */
function rewardColor(r: number): string {
  if (r >= 15) return 'text-green-400';
  if (r >= 5) return 'text-yellow-400';
  if (r >= 0) return 'text-orange-400';
  return 'text-red-400';
}

/** Tailwind background colour using the same reward thresholds as `rewardColor`. */
function rewardBg(r: number): string {
  if (r >= 15) return 'bg-green-500';
  if (r >= 5) return 'bg-yellow-500';
  if (r >= 0) return 'bg-orange-500';
  return 'bg-red-500';
}

/** Hex colour (inline-style form) using the same reward thresholds as `rewardColor`. */
function rewardHex(r: number): string {
  if (r >= 15) return '#22c55e';
  if (r >= 5) return '#eab308';
  if (r >= 0) return '#f97316';
  return '#ef4444';
}

/** Tailwind classes for the small status dot: red on error, green if valid, yellow otherwise. */
function statusDotClass(entry: TrainingEntry): string {
  if (entry.error) return 'fill-red-400 text-red-400';
  if (entry.is_valid) return 'fill-green-400 text-green-400';
  return 'fill-yellow-400 text-yellow-400';
}

/** Per-task aggregate shown in the "Tasks" panel. */
interface TaskSummary {
  count: number;
  avgReward: number;
  totalReward: number;
}

/**
 * Groups entries by `task_name` and computes count, total and mean reward.
 * A Map keeps strict insertion order and is safe for arbitrary task names
 * (e.g. "__proto__"), unlike a plain object used as a dictionary.
 */
function summarizeTasks(list: readonly TrainingEntry[]): Map<string, TaskSummary> {
  const byTask = new Map<string, TaskSummary>();
  for (const entry of list) {
    const summary = byTask.get(entry.task_name) ?? { count: 0, avgReward: 0, totalReward: 0 };
    summary.count += 1;
    summary.totalReward += entry.reward;
    summary.avgReward = summary.totalReward / summary.count;
    byTask.set(entry.task_name, summary);
  }
  return byTask;
}

/** Horizontal progress bar for a similarity score, clamped to the 0–100% range. */
function SimilarityBar({ similarity }: { similarity: number }) {
  const pct = Math.min(Math.max(similarity * 100, 0), 100);
  const color = pct > 70 ? 'bg-green-500' : pct > 40 ? 'bg-yellow-500' : 'bg-red-500';
  return (
    <div className="w-full h-1.5 bg-zinc-700 rounded-full overflow-hidden">
      <div className={`h-full ${color} rounded-full transition-all`} style={{ width: `${pct}%` }} />
    </div>
  );
}

/**
 * Resolves after `ms` milliseconds, or immediately once `signal` aborts.
 * The abort listener is removed when the timer fires, so repeated polling does not accumulate listeners.
 */
function sleep(ms: number, signal: AbortSignal): Promise<void> {
  return new Promise(resolve => {
    const onAbort = () => {
      clearTimeout(timer);
      resolve();
    };
    const timer = setTimeout(() => {
      signal.removeEventListener('abort', onAbort);
      resolve();
    }, ms);
    signal.addEventListener('abort', onAbort, { once: true });
  });
}

/**
 * Live training dashboard.
 *
 * Polls `/api/env/training/feed?since=<lastStep>` every `POLL_INTERVAL_MS` and keeps the last
 * `MAX_FEED_ENTRIES` entries. It renders stats, a reward chart, a per-task breakdown and an
 * activity feed.
 *
 * While in "follow latest" mode the newest entry is selected automatically and reported through
 * `onEntrySelect`. Clicking any entry leaves that mode; the "Follow latest" button re-enters it.
 *
 * Polling starts once on mount. It is cancelled on unmount by aborting the in-flight request and
 * the sleep, and no state is updated after cancellation.
 */
export function TrainingDashboard({ onEntrySelect }: TrainingDashboardProps) {
  const [entries, setEntries] = useState<TrainingEntry[]>([]);
  const [stats, setStats] = useState<TrainingStats>({
    total_steps: 0,
    best_reward: UNSET_BEST_REWARD,
    best_similarity: 0,
  });
  const [connected, setConnected] = useState(false);
  const [error, setError] = useState<string | null>(null);
  const [selectedStep, setSelectedStep] = useState<number | null>(null);

  // Follow mode lives in state so the "Follow latest" button re-renders reliably.
  // It is mirrored in a ref so the long-lived polling loop always reads the current value.
  const [followLatest, setFollowLatestState] = useState(true);
  const followLatestRef = useRef(true);

  const lastStep = useRef(0);
  const feedRef = useRef<HTMLDivElement>(null);

  // Latest callback, kept in a ref so the polling loop does not restart (and race itself)
  // whenever the parent passes a new function identity.
  const onEntrySelectRef = useRef(onEntrySelect);
  useEffect(() => {
    onEntrySelectRef.current = onEntrySelect;
  }, [onEntrySelect]);

  const setFollowLatest = useCallback((value: boolean) => {
    followLatestRef.current = value;
    setFollowLatestState(value);
  }, []);

  useEffect(() => {
    const controller = new AbortController();
    const { signal } = controller;

    async function poll(): Promise<void> {
      while (!signal.aborted) {
        try {
          const res = await fetch(`/api/env/training/feed?since=${lastStep.current}`, { signal });
          if (!res.ok) throw new Error(`HTTP ${res.status}`);
          const data: FeedResponse = await res.json();
          // The component may have unmounted while we were awaiting.
          if (signal.aborted) return;

          // Defensive de-duplication: never append a step we already hold (duplicate React keys).
          const fresh = data.entries.filter(e => e.step > lastStep.current);
          const latest = fresh[fresh.length - 1];
          if (latest !== undefined) {
            setEntries(prev => [...prev, ...fresh].slice(-MAX_FEED_ENTRIES));
            lastStep.current = latest.step;
            if (followLatestRef.current) {
              setSelectedStep(latest.step);
              onEntrySelectRef.current?.(latest);
            }
          }
          setStats(data.stats);
          setConnected(true);
          setError(null);
        } catch (e: unknown) {
          // An abort during unmount is expected and must not surface as an error.
          if (signal.aborted) return;
          setConnected(false);
          setError(e instanceof Error ? e.message : String(e));
        }
        await sleep(POLL_INTERVAL_MS, signal);
      }
    }

    void poll();
    return () => controller.abort();
  }, []);

  // Keep the feed scrolled to the bottom while following the latest entry.
  useEffect(() => {
    const feed = feedRef.current;
    if (feed && followLatestRef.current) {
      feed.scrollTop = feed.scrollHeight;
    }
  }, [entries]);

  const handleEntryClick = useCallback((entry: TrainingEntry) => {
    setFollowLatest(false);
    setSelectedStep(entry.step);
    onEntrySelect?.(entry);
  }, [onEntrySelect, setFollowLatest]);

  const handleFollowLatest = useCallback(() => {
    setFollowLatest(true);
    const latest = entries[entries.length - 1];
    if (latest === undefined) return;
    setSelectedStep(latest.step);
    onEntrySelect?.(latest);
  }, [entries, onEntrySelect, setFollowLatest]);

  // Reward chart data: the floor of the scale is at most 0 and the ceiling at least 1.
  const recentRewards = entries.slice(-CHART_WINDOW);
  const rewardValues = recentRewards.map(e => e.reward);
  const maxR = Math.max(...rewardValues, 1);
  const minR = Math.min(...rewardValues, 0);
  const range = maxR - minR || 1;

  const taskSummaries = summarizeTasks(entries);

  return (
    <div className="flex flex-col gap-3 h-full">
      {/* Connection status */}
      <div className="flex items-center gap-2 text-xs">
        <div className={`w-2 h-2 rounded-full ${connected ? 'bg-green-500 animate-pulse' : 'bg-zinc-600'}`} />
        <span className="text-zinc-400">
          {connected ? 'Live' : 'Connecting...'}
        </span>
        {error && <span className="text-red-400/80 ml-auto truncate max-w-[140px]">{error}</span>}
      </div>

      {/* Stats row */}
      <div className="grid grid-cols-3 gap-2">
        <div className="bg-zinc-800/80 rounded-lg p-2.5 border border-zinc-700/60">
          <div className="text-[10px] uppercase text-zinc-500 tracking-wider">Steps</div>
          <div className="text-xl font-mono font-bold text-zinc-100">{stats.total_steps}</div>
        </div>
        <div className="bg-zinc-800/80 rounded-lg p-2.5 border border-zinc-700/60">
          <div className="text-[10px] uppercase text-zinc-500 tracking-wider">Best Reward</div>
          <div className={`text-xl font-mono font-bold ${rewardColor(stats.best_reward)}`}>
            {stats.best_reward > UNSET_BEST_REWARD ? stats.best_reward.toFixed(1) : '--'}
          </div>
        </div>
        <div className="bg-zinc-800/80 rounded-lg p-2.5 border border-zinc-700/60">
          <div className="text-[10px] uppercase text-zinc-500 tracking-wider">Best Sim</div>
          <div className="text-xl font-mono font-bold text-indigo-400">
            {stats.best_similarity > 0 ? (stats.best_similarity * 100).toFixed(0) + '%' : '--'}
          </div>
        </div>
      </div>

      {/* Reward trend chart */}
      {recentRewards.length > 1 && (
        <div className="bg-zinc-800/80 rounded-lg p-3 border border-zinc-700/60">
          <div className="text-[10px] uppercase text-zinc-500 tracking-wider mb-2 flex items-center gap-1">
            <TrendingUp size={10} /> Reward History
          </div>
          <div className="flex items-end gap-[2px] h-20">
            {recentRewards.map((e, i) => {
              const h = ((e.reward - minR) / range) * 100;
              const isSelected = e.step === selectedStep;
              return (
                <div
                  key={e.step}
                  className={`flex-1 rounded-t cursor-pointer transition-all ${isSelected ? 'ring-1 ring-white' : ''}`}
                  style={{
                    height: `${Math.max(h, 3)}%`,
                    backgroundColor: rewardHex(e.reward),
                    // Older bars fade out; the selected bar is always fully opaque.
                    opacity: isSelected ? 1 : 0.4 + (i / recentRewards.length) * 0.5,
                  }}
                  title={`#${e.step} ${e.task_name}: ${e.reward.toFixed(2)}`}
                  onClick={() => handleEntryClick(e)}
                />
              );
            })}
          </div>
          <div className="flex justify-between text-[9px] text-zinc-600 mt-1 font-mono">
            <span>{minR.toFixed(1)}</span>
            <span>{maxR.toFixed(1)}</span>
          </div>
        </div>
      )}

      {/* Task breakdown */}
      {taskSummaries.size > 1 && (
        <div className="bg-zinc-800/80 rounded-lg p-3 border border-zinc-700/60">
          <div className="text-[10px] uppercase text-zinc-500 tracking-wider mb-2">Tasks</div>
          <div className="space-y-1">
            {Array.from(taskSummaries, ([name, summary]) => (
              <div key={name} className="flex items-center gap-2 text-[11px]">
                <span className="text-zinc-400 w-24 truncate">{name}</span>
                <div className="flex-1 h-1 bg-zinc-700 rounded-full overflow-hidden">
                  <div
                    className={`h-full rounded-full ${rewardBg(summary.avgReward)}`}
                    // Maps an average reward of -5..20 onto 0..100%, clamped to [2%, 100%].
                    style={{
                      width: `${Math.min(Math.max((summary.avgReward + 5) / 25 * 100, 2), 100)}%`,
                      opacity: 0.7,
                    }}
                  />
                </div>
                <span className={`font-mono w-10 text-right ${rewardColor(summary.avgReward)}`}>
                  {summary.avgReward.toFixed(1)}
                </span>
                <span className="text-zinc-600 font-mono w-6 text-right">{summary.count}</span>
              </div>
            ))}
          </div>
        </div>
      )}

      {/* Activity feed */}
      <div className="flex-1 min-h-0 flex flex-col">
        <div className="text-[10px] uppercase text-zinc-500 tracking-wider mb-2 flex items-center justify-between">
          <span className="flex items-center gap-1"><Activity size={10} /> Activity</span>
          {!followLatest && entries.length > 0 && (
            <button
              className="text-indigo-400 hover:text-indigo-300 normal-case tracking-normal"
              onClick={handleFollowLatest}
            >
              Follow latest
            </button>
          )}
        </div>
        <div ref={feedRef} className="overflow-y-auto flex-1 space-y-1 pr-1">
          {entries.length === 0 ? (
            <div className="text-xs text-zinc-600 text-center py-12">
              <div className="text-zinc-500 mb-1">No training activity yet</div>
              <div>Start a training run in the Colab notebook</div>
            </div>
          ) : (
            entries.map(e => (
              <div
                key={e.step}
                onClick={() => handleEntryClick(e)}
                className={`rounded-lg px-2.5 py-1.5 text-xs cursor-pointer transition-all border ${
                  e.step === selectedStep
                    ? 'bg-zinc-700/80 border-indigo-500/50'
                    : 'bg-zinc-800/40 border-zinc-700/30 hover:bg-zinc-800/70'
                }`}
              >
                <div className="flex items-center justify-between">
                  <div className="flex items-center gap-1.5">
                    <span className="font-mono text-zinc-500 text-[10px]">#{e.step}</span>
                    <span className="text-zinc-300 text-[10px] truncate max-w-[80px]">{e.task_name}</span>
                  </div>
                  <div className="flex items-center gap-1.5">
                    <Circle size={6} className={statusDotClass(e)} />
                    <span className={`font-mono font-semibold ${rewardColor(e.reward)}`}>
                      {e.reward.toFixed(1)}
                    </span>
                  </div>
                </div>
                {e.step === selectedStep && (
                  <div className="mt-1.5 space-y-1">
                    <div className="flex items-center gap-2">
                      <span className="text-zinc-500 text-[10px] w-8">sim</span>
                      <SimilarityBar similarity={e.shape_similarity} />
                      <span className="text-zinc-400 font-mono text-[10px] w-8 text-right">
                        {(e.shape_similarity * 100).toFixed(0)}%
                      </span>
                    </div>
                    {e.error && (
                      <div className="text-red-400/70 text-[10px] truncate">{e.error}</div>
                    )}
                  </div>
                )}
              </div>
            ))
          )}
        </div>
      </div>
    </div>
  );
}