Spaces:
Running
Running
Commit ·
d662461
1
Parent(s): c0cedb4
Redesign frontend as training dashboard + add live activity feed
Browse files- Replace manual origami editor with training-focused UI
- TrainingDashboard polls /training/feed for live step data
- 3D canvas shows selected training entry's fold attempt
- Backend: add /training/feed and /training/log endpoints
- Notebook: log each env.step() to dashboard via /training/log
- Add API proxy rewrite and turbopack config
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
- app/page.tsx +123 -158
- components/TrainingDashboard.tsx +298 -0
- next.config.ts +1 -0
- origami_server/app.py +56 -0
- training/train_grpo.ipynb +13 -105
app/page.tsx
CHANGED
|
@@ -1,118 +1,96 @@
|
|
| 1 |
'use client';
|
| 2 |
|
| 3 |
-
import { useState,
|
| 4 |
import { Canvas } from '@react-three/fiber';
|
| 5 |
-
import { OrbitControls,
|
| 6 |
import { patterns, Pattern } from '@/lib/patterns';
|
| 7 |
import { OrigamiMesh } from '@/components/OrigamiMesh';
|
| 8 |
-
import {
|
| 9 |
import { parseFoldFile } from '@/lib/foldParser';
|
| 10 |
-
import { LLMPrompt } from '@/components/LLMPrompt';
|
| 11 |
|
| 12 |
export default function Optigami() {
|
| 13 |
-
const [
|
| 14 |
-
const [
|
| 15 |
-
const [
|
| 16 |
-
const [key, setKey] = useState(0); // Used to force reset the simulation
|
| 17 |
-
const fileInputRef = useRef<HTMLInputElement>(null);
|
| 18 |
|
| 19 |
-
|
| 20 |
-
const
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
const handleFileUpload = (e: React.ChangeEvent<HTMLInputElement>) => {
|
| 28 |
-
const file = e.target.files?.[0];
|
| 29 |
-
if (!file) return;
|
| 30 |
-
|
| 31 |
-
const reader = new FileReader();
|
| 32 |
-
reader.onload = (event) => {
|
| 33 |
-
const content = event.target?.result as string;
|
| 34 |
-
const parsedPattern = parseFoldFile(content, file.name);
|
| 35 |
-
if (parsedPattern) {
|
| 36 |
-
setCustomPatterns(prev => [...prev, parsedPattern]);
|
| 37 |
-
setSelectedPatternId(parsedPattern.id);
|
| 38 |
-
setFoldPercent(0);
|
| 39 |
-
setKey(k => k + 1);
|
| 40 |
-
} else {
|
| 41 |
-
alert("Failed to parse .fold file. Please ensure it's a valid FOLD format with vertices_coords and faces_vertices.");
|
| 42 |
-
}
|
| 43 |
-
};
|
| 44 |
-
reader.readAsText(file);
|
| 45 |
-
|
| 46 |
-
if (fileInputRef.current) {
|
| 47 |
-
fileInputRef.current.value = '';
|
| 48 |
}
|
| 49 |
-
};
|
|
|
|
|
|
|
|
|
|
| 50 |
|
| 51 |
-
const
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
setFoldPercent(0);
|
| 55 |
setKey(k => k + 1);
|
| 56 |
-
};
|
| 57 |
|
| 58 |
return (
|
| 59 |
<div className="flex h-screen w-full bg-zinc-950 text-zinc-100 font-sans overflow-hidden">
|
| 60 |
-
{/* Left
|
| 61 |
-
<div className="w-
|
| 62 |
-
<div className="
|
| 63 |
-
<
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
</
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 67 |
</div>
|
| 68 |
|
| 69 |
-
<div className="
|
| 70 |
-
{
|
| 71 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 72 |
<div>
|
| 73 |
-
<
|
| 74 |
-
|
| 75 |
-
</label>
|
| 76 |
-
<div className="flex gap-2">
|
| 77 |
-
<select
|
| 78 |
-
className="flex-1 bg-zinc-800 border border-zinc-700 rounded-lg px-3 py-2 text-sm focus:outline-none focus:ring-2 focus:ring-indigo-500"
|
| 79 |
-
value={selectedPatternId}
|
| 80 |
-
onChange={(e) => {
|
| 81 |
-
setSelectedPatternId(e.target.value);
|
| 82 |
-
setFoldPercent(0);
|
| 83 |
-
setKey(k => k + 1);
|
| 84 |
-
}}
|
| 85 |
-
>
|
| 86 |
-
{allPatterns.map(p => (
|
| 87 |
-
<option key={p.id} value={p.id}>{p.name}</option>
|
| 88 |
-
))}
|
| 89 |
-
</select>
|
| 90 |
-
<button
|
| 91 |
-
onClick={() => fileInputRef.current?.click()}
|
| 92 |
-
className="bg-zinc-800 hover:bg-zinc-700 p-2 rounded-lg border border-zinc-700 transition-colors flex-shrink-0 text-zinc-300"
|
| 93 |
-
title="Upload .fold file"
|
| 94 |
-
>
|
| 95 |
-
<Upload size={18} />
|
| 96 |
-
</button>
|
| 97 |
-
<input
|
| 98 |
-
type="file"
|
| 99 |
-
accept=".fold"
|
| 100 |
-
className="hidden"
|
| 101 |
-
ref={fileInputRef}
|
| 102 |
-
onChange={handleFileUpload}
|
| 103 |
-
/>
|
| 104 |
-
</div>
|
| 105 |
</div>
|
| 106 |
-
|
| 107 |
<div>
|
| 108 |
-
<
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 116 |
<input
|
| 117 |
type="range"
|
| 118 |
min="0"
|
|
@@ -120,37 +98,57 @@ export default function Optigami() {
|
|
| 120 |
step="0.01"
|
| 121 |
value={foldPercent}
|
| 122 |
onChange={(e) => setFoldPercent(parseFloat(e.target.value))}
|
| 123 |
-
className="w-
|
| 124 |
/>
|
|
|
|
| 125 |
</div>
|
| 126 |
-
|
| 127 |
-
<button
|
| 128 |
-
onClick={handleReset}
|
| 129 |
-
className="w-full flex items-center justify-center gap-2 bg-zinc-800 hover:bg-zinc-700 text-sm py-2 rounded-lg transition-colors border border-zinc-700"
|
| 130 |
-
>
|
| 131 |
-
<RefreshCw size={14} />
|
| 132 |
-
Reset Simulation
|
| 133 |
-
</button>
|
| 134 |
-
|
| 135 |
-
<LLMPrompt onPatternGenerated={handlePatternGenerated} />
|
| 136 |
</div>
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 137 |
|
| 138 |
-
|
| 139 |
-
|
| 140 |
-
|
| 141 |
-
|
| 142 |
-
|
| 143 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 144 |
<svg viewBox="-1.2 -1.2 2.4 2.4" className="w-full h-full">
|
| 145 |
<g transform="scale(1, -1)">
|
| 146 |
-
{
|
| 147 |
-
|
| 148 |
-
const
|
| 149 |
-
const
|
| 150 |
-
const v3 = pattern.vertices[face[2]];
|
| 151 |
return (
|
| 152 |
<polygon
|
| 153 |
-
key={`
|
| 154 |
points={`${v1[0]},${v1[1]} ${v2[0]},${v2[1]} ${v3[0]},${v3[1]}`}
|
| 155 |
fill="#3f3f46"
|
| 156 |
stroke="#52525b"
|
|
@@ -158,18 +156,15 @@ export default function Optigami() {
|
|
| 158 |
/>
|
| 159 |
);
|
| 160 |
})}
|
| 161 |
-
{
|
| 162 |
-
|
| 163 |
-
const
|
| 164 |
-
const v2 = pattern.vertices[crease.edge[1]];
|
| 165 |
const color = crease.type === 'mountain' ? '#ef4444' : '#3b82f6';
|
| 166 |
return (
|
| 167 |
<line
|
| 168 |
-
key={`
|
| 169 |
-
x1={v1[0]}
|
| 170 |
-
|
| 171 |
-
x2={v2[0]}
|
| 172 |
-
y2={v2[1]}
|
| 173 |
stroke={color}
|
| 174 |
strokeWidth="0.03"
|
| 175 |
strokeLinecap="round"
|
|
@@ -178,40 +173,10 @@ export default function Optigami() {
|
|
| 178 |
})}
|
| 179 |
</g>
|
| 180 |
</svg>
|
| 181 |
-
<div className="absolute bottom-2 left-2 flex gap-3 text-[10px] uppercase font-mono text-zinc-500">
|
| 182 |
-
<div className="flex items-center gap-1">
|
| 183 |
-
<div className="w-2 h-0.5 bg-red-500"></div> Mountain
|
| 184 |
-
</div>
|
| 185 |
-
<div className="flex items-center gap-1">
|
| 186 |
-
<div className="w-2 h-0.5 bg-blue-500"></div> Valley
|
| 187 |
-
</div>
|
| 188 |
-
</div>
|
| 189 |
</div>
|
| 190 |
-
|
| 191 |
</div>
|
| 192 |
</div>
|
| 193 |
-
|
| 194 |
-
{/* 3D Canvas */}
|
| 195 |
-
<div className="flex-1 relative bg-zinc-950">
|
| 196 |
-
<Canvas camera={{ position: [0, 0, 3], fov: 45 }}>
|
| 197 |
-
<ambientLight intensity={0.5} />
|
| 198 |
-
<directionalLight position={[5, 5, 5]} intensity={1} castShadow />
|
| 199 |
-
<directionalLight position={[-5, -5, -5]} intensity={0.2} />
|
| 200 |
-
|
| 201 |
-
<group key={key}>
|
| 202 |
-
<OrigamiMesh pattern={pattern} foldPercent={foldPercent} />
|
| 203 |
-
</group>
|
| 204 |
-
|
| 205 |
-
<OrbitControls makeDefault />
|
| 206 |
-
<Grid
|
| 207 |
-
infiniteGrid
|
| 208 |
-
fadeDistance={10}
|
| 209 |
-
sectionColor="#333"
|
| 210 |
-
cellColor="#222"
|
| 211 |
-
position={[0, 0, -0.01]}
|
| 212 |
-
/>
|
| 213 |
-
</Canvas>
|
| 214 |
-
</div>
|
| 215 |
</div>
|
| 216 |
);
|
| 217 |
}
|
|
|
|
| 1 |
'use client';
|
| 2 |
|
| 3 |
+
import { useState, useCallback, useMemo } from 'react';
|
| 4 |
import { Canvas } from '@react-three/fiber';
|
| 5 |
+
import { OrbitControls, Grid } from '@react-three/drei';
|
| 6 |
import { patterns, Pattern } from '@/lib/patterns';
|
| 7 |
import { OrigamiMesh } from '@/components/OrigamiMesh';
|
| 8 |
+
import { TrainingDashboard, TrainingEntry } from '@/components/TrainingDashboard';
|
| 9 |
import { parseFoldFile } from '@/lib/foldParser';
|
|
|
|
| 10 |
|
| 11 |
export default function Optigami() {
|
| 12 |
+
const [selectedEntry, setSelectedEntry] = useState<TrainingEntry | null>(null);
|
| 13 |
+
const [foldPercent, setFoldPercent] = useState(1);
|
| 14 |
+
const [key, setKey] = useState(0);
|
|
|
|
|
|
|
| 15 |
|
| 16 |
+
// Convert training entry's fold_data into a Pattern for the 3D viewer
|
| 17 |
+
const activePattern: Pattern | null = useMemo(() => {
|
| 18 |
+
if (!selectedEntry?.fold_data) return null;
|
| 19 |
+
try {
|
| 20 |
+
const parsed = parseFoldFile(JSON.stringify(selectedEntry.fold_data), selectedEntry.task_name);
|
| 21 |
+
return parsed;
|
| 22 |
+
} catch {
|
| 23 |
+
return null;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 24 |
}
|
| 25 |
+
}, [selectedEntry]);
|
| 26 |
+
|
| 27 |
+
// Fallback to a default pattern when no training entry is selected
|
| 28 |
+
const displayPattern = activePattern || patterns[0];
|
| 29 |
|
| 30 |
+
const handleEntrySelect = useCallback((entry: TrainingEntry) => {
|
| 31 |
+
setSelectedEntry(entry);
|
| 32 |
+
setFoldPercent(1);
|
|
|
|
| 33 |
setKey(k => k + 1);
|
| 34 |
+
}, []);
|
| 35 |
|
| 36 |
return (
|
| 37 |
<div className="flex h-screen w-full bg-zinc-950 text-zinc-100 font-sans overflow-hidden">
|
| 38 |
+
{/* Left panel — Training Dashboard */}
|
| 39 |
+
<div className="w-96 flex-shrink-0 border-r border-zinc-800 bg-zinc-900 flex flex-col">
|
| 40 |
+
<div className="px-5 py-4 border-b border-zinc-800 flex items-center justify-between">
|
| 41 |
+
<div>
|
| 42 |
+
<h1 className="text-lg font-semibold tracking-tight">Optigami</h1>
|
| 43 |
+
<p className="text-[11px] text-zinc-500">RL Training Environment</p>
|
| 44 |
+
</div>
|
| 45 |
+
<div className="flex items-center gap-2">
|
| 46 |
+
<a
|
| 47 |
+
href="https://huggingface.co/spaces/openenv-community/optigami_"
|
| 48 |
+
target="_blank"
|
| 49 |
+
rel="noopener noreferrer"
|
| 50 |
+
className="text-[10px] text-zinc-500 hover:text-zinc-300 bg-zinc-800 px-2 py-1 rounded border border-zinc-700"
|
| 51 |
+
>
|
| 52 |
+
OpenEnv 0.2.1
|
| 53 |
+
</a>
|
| 54 |
+
</div>
|
| 55 |
</div>
|
| 56 |
|
| 57 |
+
<div className="flex-1 overflow-y-auto p-4">
|
| 58 |
+
<TrainingDashboard onEntrySelect={handleEntrySelect} />
|
| 59 |
+
</div>
|
| 60 |
+
</div>
|
| 61 |
+
|
| 62 |
+
{/* Right side — 3D viewer + detail */}
|
| 63 |
+
<div className="flex-1 flex flex-col">
|
| 64 |
+
{/* Top bar with context about selected entry */}
|
| 65 |
+
{selectedEntry && (
|
| 66 |
+
<div className="flex-shrink-0 border-b border-zinc-800 bg-zinc-900/50 px-6 py-3 flex items-center gap-6 text-xs">
|
| 67 |
<div>
|
| 68 |
+
<span className="text-zinc-500">Step</span>{' '}
|
| 69 |
+
<span className="font-mono text-zinc-200">#{selectedEntry.step}</span>
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 70 |
</div>
|
|
|
|
| 71 |
<div>
|
| 72 |
+
<span className="text-zinc-500">Task</span>{' '}
|
| 73 |
+
<span className="text-zinc-200">{selectedEntry.task_name}</span>
|
| 74 |
+
</div>
|
| 75 |
+
<div>
|
| 76 |
+
<span className="text-zinc-500">Reward</span>{' '}
|
| 77 |
+
<span className={`font-mono font-semibold ${
|
| 78 |
+
selectedEntry.reward >= 15 ? 'text-green-400' :
|
| 79 |
+
selectedEntry.reward >= 5 ? 'text-yellow-400' :
|
| 80 |
+
selectedEntry.reward >= 0 ? 'text-orange-400' : 'text-red-400'
|
| 81 |
+
}`}>{selectedEntry.reward.toFixed(2)}</span>
|
| 82 |
+
</div>
|
| 83 |
+
<div>
|
| 84 |
+
<span className="text-zinc-500">Similarity</span>{' '}
|
| 85 |
+
<span className="font-mono text-indigo-400">
|
| 86 |
+
{(selectedEntry.shape_similarity * 100).toFixed(1)}%
|
| 87 |
+
</span>
|
| 88 |
+
</div>
|
| 89 |
+
{selectedEntry.error && (
|
| 90 |
+
<div className="text-red-400/80 truncate flex-1">{selectedEntry.error}</div>
|
| 91 |
+
)}
|
| 92 |
+
<div className="ml-auto flex items-center gap-2">
|
| 93 |
+
<span className="text-zinc-500">Fold</span>
|
| 94 |
<input
|
| 95 |
type="range"
|
| 96 |
min="0"
|
|
|
|
| 98 |
step="0.01"
|
| 99 |
value={foldPercent}
|
| 100 |
onChange={(e) => setFoldPercent(parseFloat(e.target.value))}
|
| 101 |
+
className="w-24 accent-indigo-500"
|
| 102 |
/>
|
| 103 |
+
<span className="font-mono text-zinc-400 w-8">{Math.round(foldPercent * 100)}%</span>
|
| 104 |
</div>
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 105 |
</div>
|
| 106 |
+
)}
|
| 107 |
+
|
| 108 |
+
{/* 3D Canvas */}
|
| 109 |
+
<div className="flex-1 relative">
|
| 110 |
+
{!selectedEntry && (
|
| 111 |
+
<div className="absolute inset-0 flex items-center justify-center z-10 pointer-events-none">
|
| 112 |
+
<div className="text-center">
|
| 113 |
+
<div className="text-zinc-500 text-sm mb-1">Waiting for training data</div>
|
| 114 |
+
<div className="text-zinc-600 text-xs">
|
| 115 |
+
Start a GRPO training run in the Colab notebook
|
| 116 |
+
</div>
|
| 117 |
+
</div>
|
| 118 |
+
</div>
|
| 119 |
+
)}
|
| 120 |
+
<Canvas camera={{ position: [0, 0, 3], fov: 45 }}>
|
| 121 |
+
<ambientLight intensity={0.5} />
|
| 122 |
+
<directionalLight position={[5, 5, 5]} intensity={1} castShadow />
|
| 123 |
+
<directionalLight position={[-5, -5, -5]} intensity={0.2} />
|
| 124 |
+
|
| 125 |
+
<group key={key}>
|
| 126 |
+
<OrigamiMesh pattern={displayPattern} foldPercent={foldPercent} />
|
| 127 |
+
</group>
|
| 128 |
|
| 129 |
+
<OrbitControls makeDefault />
|
| 130 |
+
<Grid
|
| 131 |
+
infiniteGrid
|
| 132 |
+
fadeDistance={10}
|
| 133 |
+
sectionColor="#333"
|
| 134 |
+
cellColor="#222"
|
| 135 |
+
position={[0, 0, -0.01]}
|
| 136 |
+
/>
|
| 137 |
+
</Canvas>
|
| 138 |
+
|
| 139 |
+
{/* 2D crease pattern overlay */}
|
| 140 |
+
{activePattern && (
|
| 141 |
+
<div className="absolute bottom-4 left-4 w-40 h-40 bg-zinc-900/90 rounded-lg border border-zinc-700/50 p-2 backdrop-blur-sm">
|
| 142 |
+
<div className="text-[9px] uppercase text-zinc-500 tracking-wider mb-1">Crease Pattern</div>
|
| 143 |
<svg viewBox="-1.2 -1.2 2.4 2.4" className="w-full h-full">
|
| 144 |
<g transform="scale(1, -1)">
|
| 145 |
+
{activePattern.faces.map((face, i) => {
|
| 146 |
+
const v1 = activePattern.vertices[face[0]];
|
| 147 |
+
const v2 = activePattern.vertices[face[1]];
|
| 148 |
+
const v3 = activePattern.vertices[face[2]];
|
|
|
|
| 149 |
return (
|
| 150 |
<polygon
|
| 151 |
+
key={`f-${i}`}
|
| 152 |
points={`${v1[0]},${v1[1]} ${v2[0]},${v2[1]} ${v3[0]},${v3[1]}`}
|
| 153 |
fill="#3f3f46"
|
| 154 |
stroke="#52525b"
|
|
|
|
| 156 |
/>
|
| 157 |
);
|
| 158 |
})}
|
| 159 |
+
{activePattern.creases.map((crease, i) => {
|
| 160 |
+
const v1 = activePattern.vertices[crease.edge[0]];
|
| 161 |
+
const v2 = activePattern.vertices[crease.edge[1]];
|
|
|
|
| 162 |
const color = crease.type === 'mountain' ? '#ef4444' : '#3b82f6';
|
| 163 |
return (
|
| 164 |
<line
|
| 165 |
+
key={`c-${i}`}
|
| 166 |
+
x1={v1[0]} y1={v1[1]}
|
| 167 |
+
x2={v2[0]} y2={v2[1]}
|
|
|
|
|
|
|
| 168 |
stroke={color}
|
| 169 |
strokeWidth="0.03"
|
| 170 |
strokeLinecap="round"
|
|
|
|
| 173 |
})}
|
| 174 |
</g>
|
| 175 |
</svg>
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 176 |
</div>
|
| 177 |
+
)}
|
| 178 |
</div>
|
| 179 |
</div>
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 180 |
</div>
|
| 181 |
);
|
| 182 |
}
|
components/TrainingDashboard.tsx
ADDED
|
@@ -0,0 +1,298 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
'use client';
|
| 2 |
+
|
| 3 |
+
import { useState, useEffect, useRef, useCallback } from 'react';
|
| 4 |
+
import { Activity, TrendingUp, AlertCircle, CheckCircle2, Circle } from 'lucide-react';
|
| 5 |
+
|
| 6 |
+
export interface TrainingEntry {
|
| 7 |
+
step: number;
|
| 8 |
+
timestamp: number;
|
| 9 |
+
task_name: string;
|
| 10 |
+
reward: number;
|
| 11 |
+
shape_similarity: number;
|
| 12 |
+
is_valid: boolean;
|
| 13 |
+
error: string | null;
|
| 14 |
+
fold_data: any | null;
|
| 15 |
+
final_positions: number[][];
|
| 16 |
+
target_positions: number[][];
|
| 17 |
+
}
|
| 18 |
+
|
| 19 |
+
interface TrainingStats {
|
| 20 |
+
total_steps: number;
|
| 21 |
+
best_reward: number;
|
| 22 |
+
best_similarity: number;
|
| 23 |
+
}
|
| 24 |
+
|
| 25 |
+
interface FeedResponse {
|
| 26 |
+
entries: TrainingEntry[];
|
| 27 |
+
stats: TrainingStats;
|
| 28 |
+
}
|
| 29 |
+
|
| 30 |
+
interface TrainingDashboardProps {
|
| 31 |
+
onEntrySelect?: (entry: TrainingEntry) => void;
|
| 32 |
+
}
|
| 33 |
+
|
| 34 |
+
export function TrainingDashboard({ onEntrySelect }: TrainingDashboardProps) {
|
| 35 |
+
const [entries, setEntries] = useState<TrainingEntry[]>([]);
|
| 36 |
+
const [stats, setStats] = useState<TrainingStats>({ total_steps: 0, best_reward: -999, best_similarity: 0 });
|
| 37 |
+
const [connected, setConnected] = useState(false);
|
| 38 |
+
const [error, setError] = useState<string | null>(null);
|
| 39 |
+
const [selectedStep, setSelectedStep] = useState<number | null>(null);
|
| 40 |
+
const lastStep = useRef(0);
|
| 41 |
+
const feedRef = useRef<HTMLDivElement>(null);
|
| 42 |
+
|
| 43 |
+
// Auto-select newest entry
|
| 44 |
+
const autoSelect = useRef(true);
|
| 45 |
+
|
| 46 |
+
useEffect(() => {
|
| 47 |
+
let active = true;
|
| 48 |
+
|
| 49 |
+
async function poll() {
|
| 50 |
+
while (active) {
|
| 51 |
+
try {
|
| 52 |
+
const res = await fetch(`/api/env/training/feed?since=${lastStep.current}`);
|
| 53 |
+
if (!res.ok) throw new Error(`HTTP ${res.status}`);
|
| 54 |
+
const data: FeedResponse = await res.json();
|
| 55 |
+
|
| 56 |
+
if (data.entries.length > 0) {
|
| 57 |
+
setEntries(prev => {
|
| 58 |
+
const combined = [...prev, ...data.entries];
|
| 59 |
+
return combined.slice(-50);
|
| 60 |
+
});
|
| 61 |
+
lastStep.current = data.entries[data.entries.length - 1].step;
|
| 62 |
+
|
| 63 |
+
// Auto-select latest
|
| 64 |
+
if (autoSelect.current) {
|
| 65 |
+
const latest = data.entries[data.entries.length - 1];
|
| 66 |
+
setSelectedStep(latest.step);
|
| 67 |
+
onEntrySelect?.(latest);
|
| 68 |
+
}
|
| 69 |
+
}
|
| 70 |
+
setStats(data.stats);
|
| 71 |
+
setConnected(true);
|
| 72 |
+
setError(null);
|
| 73 |
+
} catch (e: any) {
|
| 74 |
+
setConnected(false);
|
| 75 |
+
setError(e.message);
|
| 76 |
+
}
|
| 77 |
+
await new Promise(r => setTimeout(r, 2000));
|
| 78 |
+
}
|
| 79 |
+
}
|
| 80 |
+
|
| 81 |
+
poll();
|
| 82 |
+
return () => { active = false; };
|
| 83 |
+
}, [onEntrySelect]);
|
| 84 |
+
|
| 85 |
+
// Auto-scroll feed
|
| 86 |
+
useEffect(() => {
|
| 87 |
+
if (feedRef.current && autoSelect.current) {
|
| 88 |
+
feedRef.current.scrollTop = feedRef.current.scrollHeight;
|
| 89 |
+
}
|
| 90 |
+
}, [entries]);
|
| 91 |
+
|
| 92 |
+
const handleEntryClick = useCallback((entry: TrainingEntry) => {
|
| 93 |
+
autoSelect.current = false;
|
| 94 |
+
setSelectedStep(entry.step);
|
| 95 |
+
onEntrySelect?.(entry);
|
| 96 |
+
}, [onEntrySelect]);
|
| 97 |
+
|
| 98 |
+
const rewardColor = (r: number) => {
|
| 99 |
+
if (r >= 15) return 'text-green-400';
|
| 100 |
+
if (r >= 5) return 'text-yellow-400';
|
| 101 |
+
if (r >= 0) return 'text-orange-400';
|
| 102 |
+
return 'text-red-400';
|
| 103 |
+
};
|
| 104 |
+
|
| 105 |
+
const rewardBg = (r: number) => {
|
| 106 |
+
if (r >= 15) return 'bg-green-500';
|
| 107 |
+
if (r >= 5) return 'bg-yellow-500';
|
| 108 |
+
if (r >= 0) return 'bg-orange-500';
|
| 109 |
+
return 'bg-red-500';
|
| 110 |
+
};
|
| 111 |
+
|
| 112 |
+
const simBar = (sim: number) => {
|
| 113 |
+
const pct = Math.min(sim * 100, 100);
|
| 114 |
+
const color = pct > 70 ? 'bg-green-500' : pct > 40 ? 'bg-yellow-500' : 'bg-red-500';
|
| 115 |
+
return (
|
| 116 |
+
<div className="w-full h-1.5 bg-zinc-700 rounded-full overflow-hidden">
|
| 117 |
+
<div className={`h-full ${color} rounded-full transition-all`} style={{ width: `${pct}%` }} />
|
| 118 |
+
</div>
|
| 119 |
+
);
|
| 120 |
+
};
|
| 121 |
+
|
| 122 |
+
// Reward chart data
|
| 123 |
+
const recentRewards = entries.slice(-30);
|
| 124 |
+
const maxR = Math.max(...recentRewards.map(e => e.reward), 1);
|
| 125 |
+
const minR = Math.min(...recentRewards.map(e => e.reward), 0);
|
| 126 |
+
const range = maxR - minR || 1;
|
| 127 |
+
|
| 128 |
+
// Task breakdown
|
| 129 |
+
const taskCounts: Record<string, { count: number; avgReward: number; totalReward: number }> = {};
|
| 130 |
+
for (const e of entries) {
|
| 131 |
+
if (!taskCounts[e.task_name]) taskCounts[e.task_name] = { count: 0, avgReward: 0, totalReward: 0 };
|
| 132 |
+
taskCounts[e.task_name].count++;
|
| 133 |
+
taskCounts[e.task_name].totalReward += e.reward;
|
| 134 |
+
taskCounts[e.task_name].avgReward = taskCounts[e.task_name].totalReward / taskCounts[e.task_name].count;
|
| 135 |
+
}
|
| 136 |
+
|
| 137 |
+
return (
|
| 138 |
+
<div className="flex flex-col gap-3 h-full">
|
| 139 |
+
{/* Connection status */}
|
| 140 |
+
<div className="flex items-center gap-2 text-xs">
|
| 141 |
+
<div className={`w-2 h-2 rounded-full ${connected ? 'bg-green-500 animate-pulse' : 'bg-zinc-600'}`} />
|
| 142 |
+
<span className="text-zinc-400">
|
| 143 |
+
{connected ? 'Live' : 'Connecting...'}
|
| 144 |
+
</span>
|
| 145 |
+
{error && <span className="text-red-400/80 ml-auto truncate max-w-[140px]">{error}</span>}
|
| 146 |
+
</div>
|
| 147 |
+
|
| 148 |
+
{/* Stats row */}
|
| 149 |
+
<div className="grid grid-cols-3 gap-2">
|
| 150 |
+
<div className="bg-zinc-800/80 rounded-lg p-2.5 border border-zinc-700/60">
|
| 151 |
+
<div className="text-[10px] uppercase text-zinc-500 tracking-wider">Steps</div>
|
| 152 |
+
<div className="text-xl font-mono font-bold text-zinc-100">{stats.total_steps}</div>
|
| 153 |
+
</div>
|
| 154 |
+
<div className="bg-zinc-800/80 rounded-lg p-2.5 border border-zinc-700/60">
|
| 155 |
+
<div className="text-[10px] uppercase text-zinc-500 tracking-wider">Best Reward</div>
|
| 156 |
+
<div className={`text-xl font-mono font-bold ${rewardColor(stats.best_reward)}`}>
|
| 157 |
+
{stats.best_reward > -999 ? stats.best_reward.toFixed(1) : '--'}
|
| 158 |
+
</div>
|
| 159 |
+
</div>
|
| 160 |
+
<div className="bg-zinc-800/80 rounded-lg p-2.5 border border-zinc-700/60">
|
| 161 |
+
<div className="text-[10px] uppercase text-zinc-500 tracking-wider">Best Sim</div>
|
| 162 |
+
<div className="text-xl font-mono font-bold text-indigo-400">
|
| 163 |
+
{stats.best_similarity > 0 ? (stats.best_similarity * 100).toFixed(0) + '%' : '--'}
|
| 164 |
+
</div>
|
| 165 |
+
</div>
|
| 166 |
+
</div>
|
| 167 |
+
|
| 168 |
+
{/* Reward trend chart */}
|
| 169 |
+
{recentRewards.length > 1 && (
|
| 170 |
+
<div className="bg-zinc-800/80 rounded-lg p-3 border border-zinc-700/60">
|
| 171 |
+
<div className="text-[10px] uppercase text-zinc-500 tracking-wider mb-2 flex items-center gap-1">
|
| 172 |
+
<TrendingUp size={10} /> Reward History
|
| 173 |
+
</div>
|
| 174 |
+
<div className="flex items-end gap-[2px] h-20">
|
| 175 |
+
{recentRewards.map((e, i) => {
|
| 176 |
+
const h = ((e.reward - minR) / range) * 100;
|
| 177 |
+
const isSelected = e.step === selectedStep;
|
| 178 |
+
return (
|
| 179 |
+
<div
|
| 180 |
+
key={e.step}
|
| 181 |
+
className={`flex-1 rounded-t cursor-pointer transition-all ${isSelected ? 'ring-1 ring-white' : ''}`}
|
| 182 |
+
style={{
|
| 183 |
+
height: `${Math.max(h, 3)}%`,
|
| 184 |
+
backgroundColor: e.reward >= 15 ? '#22c55e' : e.reward >= 5 ? '#eab308' : e.reward >= 0 ? '#f97316' : '#ef4444',
|
| 185 |
+
opacity: isSelected ? 1 : 0.4 + (i / recentRewards.length) * 0.5,
|
| 186 |
+
}}
|
| 187 |
+
title={`#${e.step} ${e.task_name}: ${e.reward.toFixed(2)}`}
|
| 188 |
+
onClick={() => handleEntryClick(e)}
|
| 189 |
+
/>
|
| 190 |
+
);
|
| 191 |
+
})}
|
| 192 |
+
</div>
|
| 193 |
+
<div className="flex justify-between text-[9px] text-zinc-600 mt-1 font-mono">
|
| 194 |
+
<span>{minR.toFixed(1)}</span>
|
| 195 |
+
<span>{maxR.toFixed(1)}</span>
|
| 196 |
+
</div>
|
| 197 |
+
</div>
|
| 198 |
+
)}
|
| 199 |
+
|
| 200 |
+
{/* Task breakdown */}
|
| 201 |
+
{Object.keys(taskCounts).length > 1 && (
|
| 202 |
+
<div className="bg-zinc-800/80 rounded-lg p-3 border border-zinc-700/60">
|
| 203 |
+
<div className="text-[10px] uppercase text-zinc-500 tracking-wider mb-2">Tasks</div>
|
| 204 |
+
<div className="space-y-1">
|
| 205 |
+
{Object.entries(taskCounts).map(([name, data]) => (
|
| 206 |
+
<div key={name} className="flex items-center gap-2 text-[11px]">
|
| 207 |
+
<span className="text-zinc-400 w-24 truncate">{name}</span>
|
| 208 |
+
<div className="flex-1 h-1 bg-zinc-700 rounded-full overflow-hidden">
|
| 209 |
+
<div
|
| 210 |
+
className={`h-full rounded-full ${rewardBg(data.avgReward)}`}
|
| 211 |
+
style={{ width: `${Math.max((data.avgReward + 5) / 25 * 100, 2)}%`, opacity: 0.7 }}
|
| 212 |
+
/>
|
| 213 |
+
</div>
|
| 214 |
+
<span className={`font-mono w-10 text-right ${rewardColor(data.avgReward)}`}>
|
| 215 |
+
{data.avgReward.toFixed(1)}
|
| 216 |
+
</span>
|
| 217 |
+
<span className="text-zinc-600 font-mono w-6 text-right">{data.count}</span>
|
| 218 |
+
</div>
|
| 219 |
+
))}
|
| 220 |
+
</div>
|
| 221 |
+
</div>
|
| 222 |
+
)}
|
| 223 |
+
|
| 224 |
+
{/* Activity feed */}
|
| 225 |
+
<div className="flex-1 min-h-0 flex flex-col">
|
| 226 |
+
<div className="text-[10px] uppercase text-zinc-500 tracking-wider mb-2 flex items-center justify-between">
|
| 227 |
+
<span className="flex items-center gap-1"><Activity size={10} /> Activity</span>
|
| 228 |
+
{!autoSelect.current && entries.length > 0 && (
|
| 229 |
+
<button
|
| 230 |
+
className="text-indigo-400 hover:text-indigo-300 normal-case tracking-normal"
|
| 231 |
+
onClick={() => {
|
| 232 |
+
autoSelect.current = true;
|
| 233 |
+
const latest = entries[entries.length - 1];
|
| 234 |
+
setSelectedStep(latest.step);
|
| 235 |
+
onEntrySelect?.(latest);
|
| 236 |
+
}}
|
| 237 |
+
>
|
| 238 |
+
Follow latest
|
| 239 |
+
</button>
|
| 240 |
+
)}
|
| 241 |
+
</div>
|
| 242 |
+
<div ref={feedRef} className="overflow-y-auto flex-1 space-y-1 pr-1">
|
| 243 |
+
{entries.length === 0 ? (
|
| 244 |
+
<div className="text-xs text-zinc-600 text-center py-12">
|
| 245 |
+
<div className="text-zinc-500 mb-1">No training activity yet</div>
|
| 246 |
+
<div>Start a training run in the Colab notebook</div>
|
| 247 |
+
</div>
|
| 248 |
+
) : (
|
| 249 |
+
entries.map(e => (
|
| 250 |
+
<div
|
| 251 |
+
key={e.step}
|
| 252 |
+
onClick={() => handleEntryClick(e)}
|
| 253 |
+
className={`rounded-lg px-2.5 py-1.5 text-xs cursor-pointer transition-all border ${
|
| 254 |
+
e.step === selectedStep
|
| 255 |
+
? 'bg-zinc-700/80 border-indigo-500/50'
|
| 256 |
+
: 'bg-zinc-800/40 border-zinc-700/30 hover:bg-zinc-800/70'
|
| 257 |
+
}`}
|
| 258 |
+
>
|
| 259 |
+
<div className="flex items-center justify-between">
|
| 260 |
+
<div className="flex items-center gap-1.5">
|
| 261 |
+
<span className="font-mono text-zinc-500 text-[10px]">#{e.step}</span>
|
| 262 |
+
<span className="text-zinc-300 text-[10px] truncate max-w-[80px]">{e.task_name}</span>
|
| 263 |
+
</div>
|
| 264 |
+
<div className="flex items-center gap-1.5">
|
| 265 |
+
{e.error ? (
|
| 266 |
+
<Circle size={6} className="fill-red-400 text-red-400" />
|
| 267 |
+
) : e.is_valid ? (
|
| 268 |
+
<Circle size={6} className="fill-green-400 text-green-400" />
|
| 269 |
+
) : (
|
| 270 |
+
<Circle size={6} className="fill-yellow-400 text-yellow-400" />
|
| 271 |
+
)}
|
| 272 |
+
<span className={`font-mono font-semibold ${rewardColor(e.reward)}`}>
|
| 273 |
+
{e.reward.toFixed(1)}
|
| 274 |
+
</span>
|
| 275 |
+
</div>
|
| 276 |
+
</div>
|
| 277 |
+
{e.step === selectedStep && (
|
| 278 |
+
<div className="mt-1.5 space-y-1">
|
| 279 |
+
<div className="flex items-center gap-2">
|
| 280 |
+
<span className="text-zinc-500 text-[10px] w-8">sim</span>
|
| 281 |
+
{simBar(e.shape_similarity)}
|
| 282 |
+
<span className="text-zinc-400 font-mono text-[10px] w-8 text-right">
|
| 283 |
+
{(e.shape_similarity * 100).toFixed(0)}%
|
| 284 |
+
</span>
|
| 285 |
+
</div>
|
| 286 |
+
{e.error && (
|
| 287 |
+
<div className="text-red-400/70 text-[10px] truncate">{e.error}</div>
|
| 288 |
+
)}
|
| 289 |
+
</div>
|
| 290 |
+
)}
|
| 291 |
+
</div>
|
| 292 |
+
))
|
| 293 |
+
)}
|
| 294 |
+
</div>
|
| 295 |
+
</div>
|
| 296 |
+
</div>
|
| 297 |
+
);
|
| 298 |
+
}
|
next.config.ts
CHANGED
|
@@ -29,6 +29,7 @@ const nextConfig: NextConfig = {
|
|
| 29 |
];
|
| 30 |
},
|
| 31 |
transpilePackages: ['motion'],
|
|
|
|
| 32 |
webpack: (config, {dev}) => {
|
| 33 |
// HMR is disabled in AI Studio via DISABLE_HMR env var.
|
| 34 |
// Do not modifyâfile watching is disabled to prevent flickering during agent edits.
|
|
|
|
| 29 |
];
|
| 30 |
},
|
| 31 |
transpilePackages: ['motion'],
|
| 32 |
+
turbopack: {},
|
| 33 |
webpack: (config, {dev}) => {
|
| 34 |
// HMR is disabled in AI Studio via DISABLE_HMR env var.
|
| 35 |
// Do not modifyâfile watching is disabled to prevent flickering during agent edits.
|
origami_server/app.py
CHANGED
|
@@ -1,9 +1,12 @@
|
|
| 1 |
"""FastAPI entry point — OpenEnv create_app() + custom endpoints."""
|
| 2 |
|
| 3 |
import os
|
|
|
|
|
|
|
| 4 |
from pathlib import Path
|
| 5 |
|
| 6 |
from fastapi import HTTPException
|
|
|
|
| 7 |
from fastapi.responses import HTMLResponse
|
| 8 |
|
| 9 |
from openenv.core.env_server.http_server import create_app
|
|
@@ -19,6 +22,59 @@ app = create_app(
|
|
| 19 |
env_name="origami_env",
|
| 20 |
)
|
| 21 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 22 |
|
| 23 |
@app.get("/tasks")
|
| 24 |
def get_tasks():
|
|
|
|
| 1 |
"""FastAPI entry point — OpenEnv create_app() + custom endpoints."""
|
| 2 |
|
| 3 |
import os
|
| 4 |
+
import time
|
| 5 |
+
from collections import deque
|
| 6 |
from pathlib import Path
|
| 7 |
|
| 8 |
from fastapi import HTTPException
|
| 9 |
+
from fastapi.middleware.cors import CORSMiddleware
|
| 10 |
from fastapi.responses import HTMLResponse
|
| 11 |
|
| 12 |
from openenv.core.env_server.http_server import create_app
|
|
|
|
| 22 |
env_name="origami_env",
|
| 23 |
)
|
| 24 |
|
| 25 |
+
# Allow CORS for frontend polling
|
| 26 |
+
app.add_middleware(
|
| 27 |
+
CORSMiddleware,
|
| 28 |
+
allow_origins=["*"],
|
| 29 |
+
allow_methods=["*"],
|
| 30 |
+
allow_headers=["*"],
|
| 31 |
+
)
|
| 32 |
+
|
| 33 |
+
# ── Training Activity Feed ───────────────────────────────────────────────────
|
| 34 |
+
# Ring buffer of recent training steps — the frontend polls this to visualize
|
| 35 |
+
# what's happening during GRPO training.
|
| 36 |
+
|
| 37 |
+
ACTIVITY_FEED: deque = deque(maxlen=50) # Last 50 steps
|
| 38 |
+
TRAINING_STATS: dict = {"total_steps": 0, "best_reward": -999, "best_similarity": 0}
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
@app.get("/training/feed")
|
| 42 |
+
def get_training_feed(since: int = 0):
|
| 43 |
+
"""Get recent training activity. Pass `since=<step>` to get only new entries."""
|
| 44 |
+
entries = [e for e in ACTIVITY_FEED if e["step"] > since]
|
| 45 |
+
return {"entries": entries, "stats": TRAINING_STATS}
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
@app.post("/training/log")
|
| 49 |
+
def log_training_step(data: dict):
|
| 50 |
+
"""Log a training step from the notebook. Called after each env.step()."""
|
| 51 |
+
step = TRAINING_STATS["total_steps"] + 1
|
| 52 |
+
TRAINING_STATS["total_steps"] = step
|
| 53 |
+
|
| 54 |
+
entry = {
|
| 55 |
+
"step": step,
|
| 56 |
+
"timestamp": time.time(),
|
| 57 |
+
"task_name": data.get("task_name", ""),
|
| 58 |
+
"reward": data.get("reward", 0),
|
| 59 |
+
"shape_similarity": data.get("shape_similarity", 0),
|
| 60 |
+
"is_valid": data.get("is_valid", False),
|
| 61 |
+
"error": data.get("error", None),
|
| 62 |
+
"fold_data": data.get("fold_data", None),
|
| 63 |
+
"final_positions": data.get("final_positions", []),
|
| 64 |
+
"target_positions": data.get("target_positions", []),
|
| 65 |
+
}
|
| 66 |
+
|
| 67 |
+
ACTIVITY_FEED.append(entry)
|
| 68 |
+
|
| 69 |
+
reward = entry["reward"]
|
| 70 |
+
sim = entry["shape_similarity"]
|
| 71 |
+
if reward > TRAINING_STATS["best_reward"]:
|
| 72 |
+
TRAINING_STATS["best_reward"] = reward
|
| 73 |
+
if sim > TRAINING_STATS["best_similarity"]:
|
| 74 |
+
TRAINING_STATS["best_similarity"] = sim
|
| 75 |
+
|
| 76 |
+
return {"step": step}
|
| 77 |
+
|
| 78 |
|
| 79 |
@app.get("/tasks")
|
| 80 |
def get_tasks():
|
training/train_grpo.ipynb
CHANGED
|
@@ -181,108 +181,7 @@
|
|
| 181 |
"execution_count": null,
|
| 182 |
"metadata": {},
|
| 183 |
"outputs": [],
|
| 184 |
-
"source": [
|
| 185 |
-
"PRINTER = 0\n",
|
| 186 |
-
"\n",
|
| 187 |
-
"def extract_fold_json(response):\n",
|
| 188 |
-
" \"\"\"Extract FOLD JSON from LLM response text.\"\"\"\n",
|
| 189 |
-
" m = re.search(r\"```(?:json)?\\s*(\\{.*?\\})\\s*```\", response, re.DOTALL)\n",
|
| 190 |
-
" if m:\n",
|
| 191 |
-
" try: return json.loads(m.group(1))\n",
|
| 192 |
-
" except: pass\n",
|
| 193 |
-
" m = re.search(r'\\{[^{}]*\"vertices_coords\"[^{}]*\\}', response, re.DOTALL)\n",
|
| 194 |
-
" if m:\n",
|
| 195 |
-
" try: return json.loads(m.group(0))\n",
|
| 196 |
-
" except: pass\n",
|
| 197 |
-
" try:\n",
|
| 198 |
-
" d = json.loads(response.strip())\n",
|
| 199 |
-
" if isinstance(d, dict) and \"vertices_coords\" in d: return d\n",
|
| 200 |
-
" except: pass\n",
|
| 201 |
-
" return None\n",
|
| 202 |
-
"\n",
|
| 203 |
-
"\n",
|
| 204 |
-
"def valid_fold_reward(completions, **kwargs):\n",
|
| 205 |
-
" \"\"\"Reward 1 (local): +1.0 valid FOLD structure, -0.5 bad structure, -2.0 unparseable.\"\"\"\n",
|
| 206 |
-
" REQUIRED = {\"vertices_coords\", \"edges_vertices\", \"edges_assignment\"}\n",
|
| 207 |
-
" scores = []\n",
|
| 208 |
-
" for c in completions:\n",
|
| 209 |
-
" fold = extract_fold_json(c[0][\"content\"])\n",
|
| 210 |
-
" if fold is None:\n",
|
| 211 |
-
" scores.append(-2.0)\n",
|
| 212 |
-
" continue\n",
|
| 213 |
-
" # Basic structural checks\n",
|
| 214 |
-
" if not REQUIRED.issubset(fold.keys()):\n",
|
| 215 |
-
" scores.append(-0.5); continue\n",
|
| 216 |
-
" verts = fold[\"vertices_coords\"]\n",
|
| 217 |
-
" edges = fold[\"edges_vertices\"]\n",
|
| 218 |
-
" asgn = fold[\"edges_assignment\"]\n",
|
| 219 |
-
" if len(verts) < 3 or len(edges) < 3 or len(edges) != len(asgn):\n",
|
| 220 |
-
" scores.append(-0.5); continue\n",
|
| 221 |
-
" if not any(a in (\"M\",\"V\") for a in asgn):\n",
|
| 222 |
-
" scores.append(-0.5); continue\n",
|
| 223 |
-
" if not any(a == \"B\" for a in asgn):\n",
|
| 224 |
-
" scores.append(-0.5); continue\n",
|
| 225 |
-
" scores.append(1.0)\n",
|
| 226 |
-
" return scores\n",
|
| 227 |
-
"\n",
|
| 228 |
-
"\n",
|
| 229 |
-
"def openenv_reward(completions, task_name, **kwargs):\n",
|
| 230 |
-
" \"\"\"Reward 2 (OpenEnv API): Submit fold to environment, get simulation reward.\n",
|
| 231 |
-
"\n",
|
| 232 |
-
" Calls POST /reset and POST /step on the HF Space OpenEnv environment.\n",
|
| 233 |
-
" The environment runs the fold simulation and computes shape similarity.\n",
|
| 234 |
-
" \"\"\"\n",
|
| 235 |
-
" global PRINTER\n",
|
| 236 |
-
" # task_name comes as a list from the dataset\n",
|
| 237 |
-
" tn = task_name[0] if isinstance(task_name, list) else task_name\n",
|
| 238 |
-
"\n",
|
| 239 |
-
" scores = []\n",
|
| 240 |
-
" for c in completions:\n",
|
| 241 |
-
" resp = c[0][\"content\"]\n",
|
| 242 |
-
"\n",
|
| 243 |
-
" # Periodic logging\n",
|
| 244 |
-
" if PRINTER % 10 == 0:\n",
|
| 245 |
-
" print(f\"\\n--- [{tn}] Sample {PRINTER} ---\\n{resp[:300]}\")\n",
|
| 246 |
-
" PRINTER += 1\n",
|
| 247 |
-
"\n",
|
| 248 |
-
" # Parse the FOLD JSON from the LLM response\n",
|
| 249 |
-
" fold = extract_fold_json(resp)\n",
|
| 250 |
-
" if fold is None:\n",
|
| 251 |
-
" scores.append(-2.0)\n",
|
| 252 |
-
" continue\n",
|
| 253 |
-
"\n",
|
| 254 |
-
" try:\n",
|
| 255 |
-
" # Reset environment for this task\n",
|
| 256 |
-
" env.reset(task_name=tn)\n",
|
| 257 |
-
"\n",
|
| 258 |
-
" # Submit the fold to OpenEnv — environment simulates and scores it\n",
|
| 259 |
-
" result = env.step(fold)\n",
|
| 260 |
-
"\n",
|
| 261 |
-
" # Get reward from the environment\n",
|
| 262 |
-
" reward = result.get(\"reward\", None)\n",
|
| 263 |
-
" if reward is not None:\n",
|
| 264 |
-
" scores.append(float(reward))\n",
|
| 265 |
-
" else:\n",
|
| 266 |
-
" # Fallback: extract from observation\n",
|
| 267 |
-
" obs = result.get(\"observation\", {})\n",
|
| 268 |
-
" if obs.get(\"error\"):\n",
|
| 269 |
-
" scores.append(-2.0)\n",
|
| 270 |
-
" else:\n",
|
| 271 |
-
" sim = obs.get(\"shape_similarity\", 0.0)\n",
|
| 272 |
-
" scores.append(float(sim) * 20.0)\n",
|
| 273 |
-
"\n",
|
| 274 |
-
" except requests.exceptions.RequestException as e:\n",
|
| 275 |
-
" print(f\"OpenEnv API error: {e}\")\n",
|
| 276 |
-
" scores.append(-1.0)\n",
|
| 277 |
-
" except Exception as e:\n",
|
| 278 |
-
" print(f\"Reward error: {e}\")\n",
|
| 279 |
-
" scores.append(-1.0)\n",
|
| 280 |
-
"\n",
|
| 281 |
-
" return scores\n",
|
| 282 |
-
"\n",
|
| 283 |
-
"\n",
|
| 284 |
-
"print(\"Reward functions ready (valid_fold=local, openenv_reward=API).\")"
|
| 285 |
-
]
|
| 286 |
},
|
| 287 |
{
|
| 288 |
"cell_type": "markdown",
|
|
@@ -533,9 +432,18 @@
|
|
| 533 |
],
|
| 534 |
"metadata": {
|
| 535 |
"accelerator": "GPU",
|
| 536 |
-
"colab": {
|
| 537 |
-
|
| 538 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 539 |
},
|
| 540 |
"nbformat": 4,
|
| 541 |
"nbformat_minor": 0
|
|
|
|
| 181 |
"execution_count": null,
|
| 182 |
"metadata": {},
|
| 183 |
"outputs": [],
|
| 184 |
+
"source": "PRINTER = 0\n\ndef extract_fold_json(response):\n \"\"\"Extract FOLD JSON from LLM response text.\"\"\"\n m = re.search(r\"```(?:json)?\\s*(\\{.*?\\})\\s*```\", response, re.DOTALL)\n if m:\n try: return json.loads(m.group(1))\n except: pass\n m = re.search(r'\\{[^{}]*\"vertices_coords\"[^{}]*\\}', response, re.DOTALL)\n if m:\n try: return json.loads(m.group(0))\n except: pass\n try:\n d = json.loads(response.strip())\n if isinstance(d, dict) and \"vertices_coords\" in d: return d\n except: pass\n return None\n\n\ndef log_to_dashboard(task_name, reward, shape_similarity, is_valid, error=None, fold_data=None, final_positions=None, target_positions=None):\n \"\"\"Send training step data to the frontend dashboard via /training/log.\"\"\"\n try:\n requests.post(\n f\"{OPENENV_URL.replace('/api/env', '')}/training/log\",\n json={\n \"task_name\": task_name,\n \"reward\": reward,\n \"shape_similarity\": shape_similarity,\n \"is_valid\": is_valid,\n \"error\": error,\n \"fold_data\": fold_data,\n \"final_positions\": final_positions or [],\n \"target_positions\": target_positions or [],\n },\n timeout=5,\n )\n except:\n pass # Don't let dashboard logging break training\n\n\ndef valid_fold_reward(completions, **kwargs):\n \"\"\"Reward 1 (local): +1.0 valid FOLD structure, -0.5 bad structure, -2.0 unparseable.\"\"\"\n REQUIRED = {\"vertices_coords\", \"edges_vertices\", \"edges_assignment\"}\n scores = []\n for c in completions:\n fold = extract_fold_json(c[0][\"content\"])\n if fold is None:\n scores.append(-2.0)\n continue\n # Basic structural checks\n if not REQUIRED.issubset(fold.keys()):\n scores.append(-0.5); continue\n verts = fold[\"vertices_coords\"]\n edges = fold[\"edges_vertices\"]\n asgn = fold[\"edges_assignment\"]\n if len(verts) < 3 or len(edges) < 3 or len(edges) != len(asgn):\n scores.append(-0.5); continue\n if not any(a in (\"M\",\"V\") for a in asgn):\n scores.append(-0.5); continue\n if not any(a == \"B\" for a in asgn):\n scores.append(-0.5); continue\n scores.append(1.0)\n return scores\n\n\ndef openenv_reward(completions, task_name, **kwargs):\n \"\"\"Reward 2 (OpenEnv API): Submit fold to environment, get simulation reward.\n\n Calls POST /reset and POST /step on the HF Space OpenEnv environment.\n The environment runs the fold simulation and computes shape similarity.\n Also logs each step to the frontend training dashboard.\n \"\"\"\n global PRINTER\n # task_name comes as a list from the dataset\n tn = task_name[0] if isinstance(task_name, list) else task_name\n\n scores = []\n for c in completions:\n resp = c[0][\"content\"]\n\n # Periodic logging\n if PRINTER % 10 == 0:\n print(f\"\\n--- [{tn}] Sample {PRINTER} ---\\n{resp[:300]}\")\n PRINTER += 1\n\n # Parse the FOLD JSON from the LLM response\n fold = extract_fold_json(resp)\n if fold is None:\n scores.append(-2.0)\n log_to_dashboard(tn, -2.0, 0.0, False, error=\"No JSON parsed\")\n continue\n\n try:\n # Reset environment for this task\n env.reset(task_name=tn)\n\n # Submit the fold to OpenEnv — environment simulates and scores it\n result = env.step(fold)\n\n # Get reward from the environment\n reward = result.get(\"reward\", None)\n obs = result.get(\"observation\", {})\n sim = obs.get(\"shape_similarity\", 0.0)\n is_valid = not bool(obs.get(\"error\"))\n\n if reward is not None:\n scores.append(float(reward))\n else:\n if obs.get(\"error\"):\n scores.append(-2.0)\n else:\n reward = float(sim) * 20.0\n scores.append(reward)\n\n # Log to frontend dashboard\n log_to_dashboard(\n task_name=tn,\n reward=float(reward) if reward is not None else scores[-1],\n shape_similarity=float(sim),\n is_valid=is_valid,\n error=obs.get(\"error\"),\n fold_data=fold,\n final_positions=obs.get(\"final_positions\", []),\n target_positions=obs.get(\"target_positions\", []),\n )\n\n except requests.exceptions.RequestException as e:\n print(f\"OpenEnv API error: {e}\")\n scores.append(-1.0)\n log_to_dashboard(tn, -1.0, 0.0, False, error=str(e))\n except Exception as e:\n print(f\"Reward error: {e}\")\n scores.append(-1.0)\n log_to_dashboard(tn, -1.0, 0.0, False, error=str(e))\n\n return scores\n\n\nprint(\"Reward functions ready (valid_fold=local, openenv_reward=API + dashboard logging).\")"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 185 |
},
|
| 186 |
{
|
| 187 |
"cell_type": "markdown",
|
|
|
|
| 432 |
],
|
| 433 |
"metadata": {
|
| 434 |
"accelerator": "GPU",
|
| 435 |
+
"colab": {
|
| 436 |
+
"gpuType": "T4",
|
| 437 |
+
"provenance": []
|
| 438 |
+
},
|
| 439 |
+
"kernelspec": {
|
| 440 |
+
"display_name": "Python 3",
|
| 441 |
+
"name": "python3"
|
| 442 |
+
},
|
| 443 |
+
"language_info": {
|
| 444 |
+
"name": "python",
|
| 445 |
+
"version": "3.11.0"
|
| 446 |
+
}
|
| 447 |
},
|
| 448 |
"nbformat": 4,
|
| 449 |
"nbformat_minor": 0
|