sissississi Claude Opus 4.6 commited on
Commit
d662461
·
1 Parent(s): c0cedb4

Redesign frontend as training dashboard + add live activity feed

Browse files

- Replace manual origami editor with training-focused UI
- TrainingDashboard polls /training/feed for live step data
- 3D canvas shows selected training entry's fold attempt
- Backend: add /training/feed and /training/log endpoints
- Notebook: log each env.step() to dashboard via /training/log
- Add API proxy rewrite and turbopack config

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

app/page.tsx CHANGED
@@ -1,118 +1,96 @@
1
  'use client';
2
 
3
- import { useState, useRef } from 'react';
4
  import { Canvas } from '@react-three/fiber';
5
- import { OrbitControls, Environment, Grid } from '@react-three/drei';
6
  import { patterns, Pattern } from '@/lib/patterns';
7
  import { OrigamiMesh } from '@/components/OrigamiMesh';
8
- import { Github, RefreshCw, Upload } from 'lucide-react';
9
  import { parseFoldFile } from '@/lib/foldParser';
10
- import { LLMPrompt } from '@/components/LLMPrompt';
11
 
12
  export default function Optigami() {
13
- const [customPatterns, setCustomPatterns] = useState<Pattern[]>([]);
14
- const [selectedPatternId, setSelectedPatternId] = useState(patterns[0].id);
15
- const [foldPercent, setFoldPercent] = useState(0);
16
- const [key, setKey] = useState(0); // Used to force reset the simulation
17
- const fileInputRef = useRef<HTMLInputElement>(null);
18
 
19
- const allPatterns = [...patterns, ...customPatterns];
20
- const pattern = allPatterns.find(p => p.id === selectedPatternId) || allPatterns[0];
21
-
22
- const handleReset = () => {
23
- setFoldPercent(0);
24
- setKey(k => k + 1);
25
- };
26
-
27
- const handleFileUpload = (e: React.ChangeEvent<HTMLInputElement>) => {
28
- const file = e.target.files?.[0];
29
- if (!file) return;
30
-
31
- const reader = new FileReader();
32
- reader.onload = (event) => {
33
- const content = event.target?.result as string;
34
- const parsedPattern = parseFoldFile(content, file.name);
35
- if (parsedPattern) {
36
- setCustomPatterns(prev => [...prev, parsedPattern]);
37
- setSelectedPatternId(parsedPattern.id);
38
- setFoldPercent(0);
39
- setKey(k => k + 1);
40
- } else {
41
- alert("Failed to parse .fold file. Please ensure it's a valid FOLD format with vertices_coords and faces_vertices.");
42
- }
43
- };
44
- reader.readAsText(file);
45
-
46
- if (fileInputRef.current) {
47
- fileInputRef.current.value = '';
48
  }
49
- };
 
 
 
50
 
51
- const handlePatternGenerated = (newPattern: Pattern) => {
52
- setCustomPatterns(prev => [...prev, newPattern]);
53
- setSelectedPatternId(newPattern.id);
54
- setFoldPercent(0);
55
  setKey(k => k + 1);
56
- };
57
 
58
  return (
59
  <div className="flex h-screen w-full bg-zinc-950 text-zinc-100 font-sans overflow-hidden">
60
- {/* Left Sidebar */}
61
- <div className="w-80 flex-shrink-0 border-r border-zinc-800 bg-zinc-900 flex flex-col">
62
- <div className="p-6 border-b border-zinc-800">
63
- <h1 className="text-xl font-semibold tracking-tight mb-2">Optigami</h1>
64
- <p className="text-sm text-zinc-400">
65
- Optigami
66
- </p>
 
 
 
 
 
 
 
 
 
 
67
  </div>
68
 
69
- <div className="p-6 flex-1 overflow-y-auto flex flex-col gap-8">
70
- {/* Controls */}
71
- <div className="space-y-4">
 
 
 
 
 
 
 
72
  <div>
73
- <label className="block text-xs font-medium text-zinc-400 uppercase tracking-wider mb-2">
74
- Crease Pattern
75
- </label>
76
- <div className="flex gap-2">
77
- <select
78
- className="flex-1 bg-zinc-800 border border-zinc-700 rounded-lg px-3 py-2 text-sm focus:outline-none focus:ring-2 focus:ring-indigo-500"
79
- value={selectedPatternId}
80
- onChange={(e) => {
81
- setSelectedPatternId(e.target.value);
82
- setFoldPercent(0);
83
- setKey(k => k + 1);
84
- }}
85
- >
86
- {allPatterns.map(p => (
87
- <option key={p.id} value={p.id}>{p.name}</option>
88
- ))}
89
- </select>
90
- <button
91
- onClick={() => fileInputRef.current?.click()}
92
- className="bg-zinc-800 hover:bg-zinc-700 p-2 rounded-lg border border-zinc-700 transition-colors flex-shrink-0 text-zinc-300"
93
- title="Upload .fold file"
94
- >
95
- <Upload size={18} />
96
- </button>
97
- <input
98
- type="file"
99
- accept=".fold"
100
- className="hidden"
101
- ref={fileInputRef}
102
- onChange={handleFileUpload}
103
- />
104
- </div>
105
  </div>
106
-
107
  <div>
108
- <div className="flex justify-between items-center mb-2">
109
- <label className="block text-xs font-medium text-zinc-400 uppercase tracking-wider">
110
- Fold Angle
111
- </label>
112
- <span className="text-xs text-zinc-500 font-mono">
113
- {Math.round(foldPercent * 100)}%
114
- </span>
115
- </div>
 
 
 
 
 
 
 
 
 
 
 
 
 
 
116
  <input
117
  type="range"
118
  min="0"
@@ -120,37 +98,57 @@ export default function Optigami() {
120
  step="0.01"
121
  value={foldPercent}
122
  onChange={(e) => setFoldPercent(parseFloat(e.target.value))}
123
- className="w-full accent-indigo-500"
124
  />
 
125
  </div>
126
-
127
- <button
128
- onClick={handleReset}
129
- className="w-full flex items-center justify-center gap-2 bg-zinc-800 hover:bg-zinc-700 text-sm py-2 rounded-lg transition-colors border border-zinc-700"
130
- >
131
- <RefreshCw size={14} />
132
- Reset Simulation
133
- </button>
134
-
135
- <LLMPrompt onPatternGenerated={handlePatternGenerated} />
136
  </div>
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
137
 
138
- {/* 2D View */}
139
- <div className="mt-auto">
140
- <label className="block text-xs font-medium text-zinc-400 uppercase tracking-wider mb-2">
141
- 2D Crease Pattern
142
- </label>
143
- <div className="bg-zinc-800 rounded-lg p-4 aspect-square flex items-center justify-center border border-zinc-700 relative">
 
 
 
 
 
 
 
 
144
  <svg viewBox="-1.2 -1.2 2.4 2.4" className="w-full h-full">
145
  <g transform="scale(1, -1)">
146
- {/* Draw faces */}
147
- {pattern.faces.map((face, i) => {
148
- const v1 = pattern.vertices[face[0]];
149
- const v2 = pattern.vertices[face[1]];
150
- const v3 = pattern.vertices[face[2]];
151
  return (
152
  <polygon
153
- key={`face-${i}`}
154
  points={`${v1[0]},${v1[1]} ${v2[0]},${v2[1]} ${v3[0]},${v3[1]}`}
155
  fill="#3f3f46"
156
  stroke="#52525b"
@@ -158,18 +156,15 @@ export default function Optigami() {
158
  />
159
  );
160
  })}
161
- {/* Draw creases */}
162
- {pattern.creases.map((crease, i) => {
163
- const v1 = pattern.vertices[crease.edge[0]];
164
- const v2 = pattern.vertices[crease.edge[1]];
165
  const color = crease.type === 'mountain' ? '#ef4444' : '#3b82f6';
166
  return (
167
  <line
168
- key={`crease-${i}`}
169
- x1={v1[0]}
170
- y1={v1[1]}
171
- x2={v2[0]}
172
- y2={v2[1]}
173
  stroke={color}
174
  strokeWidth="0.03"
175
  strokeLinecap="round"
@@ -178,40 +173,10 @@ export default function Optigami() {
178
  })}
179
  </g>
180
  </svg>
181
- <div className="absolute bottom-2 left-2 flex gap-3 text-[10px] uppercase font-mono text-zinc-500">
182
- <div className="flex items-center gap-1">
183
- <div className="w-2 h-0.5 bg-red-500"></div> Mountain
184
- </div>
185
- <div className="flex items-center gap-1">
186
- <div className="w-2 h-0.5 bg-blue-500"></div> Valley
187
- </div>
188
- </div>
189
  </div>
190
- </div>
191
  </div>
192
  </div>
193
-
194
- {/* 3D Canvas */}
195
- <div className="flex-1 relative bg-zinc-950">
196
- <Canvas camera={{ position: [0, 0, 3], fov: 45 }}>
197
- <ambientLight intensity={0.5} />
198
- <directionalLight position={[5, 5, 5]} intensity={1} castShadow />
199
- <directionalLight position={[-5, -5, -5]} intensity={0.2} />
200
-
201
- <group key={key}>
202
- <OrigamiMesh pattern={pattern} foldPercent={foldPercent} />
203
- </group>
204
-
205
- <OrbitControls makeDefault />
206
- <Grid
207
- infiniteGrid
208
- fadeDistance={10}
209
- sectionColor="#333"
210
- cellColor="#222"
211
- position={[0, 0, -0.01]}
212
- />
213
- </Canvas>
214
- </div>
215
  </div>
216
  );
217
  }
 
1
  'use client';
2
 
3
+ import { useState, useCallback, useMemo } from 'react';
4
  import { Canvas } from '@react-three/fiber';
5
+ import { OrbitControls, Grid } from '@react-three/drei';
6
  import { patterns, Pattern } from '@/lib/patterns';
7
  import { OrigamiMesh } from '@/components/OrigamiMesh';
8
+ import { TrainingDashboard, TrainingEntry } from '@/components/TrainingDashboard';
9
  import { parseFoldFile } from '@/lib/foldParser';
 
10
 
11
  export default function Optigami() {
12
+ const [selectedEntry, setSelectedEntry] = useState<TrainingEntry | null>(null);
13
+ const [foldPercent, setFoldPercent] = useState(1);
14
+ const [key, setKey] = useState(0);
 
 
15
 
16
+ // Convert training entry's fold_data into a Pattern for the 3D viewer
17
+ const activePattern: Pattern | null = useMemo(() => {
18
+ if (!selectedEntry?.fold_data) return null;
19
+ try {
20
+ const parsed = parseFoldFile(JSON.stringify(selectedEntry.fold_data), selectedEntry.task_name);
21
+ return parsed;
22
+ } catch {
23
+ return null;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
  }
25
+ }, [selectedEntry]);
26
+
27
+ // Fallback to a default pattern when no training entry is selected
28
+ const displayPattern = activePattern || patterns[0];
29
 
30
+ const handleEntrySelect = useCallback((entry: TrainingEntry) => {
31
+ setSelectedEntry(entry);
32
+ setFoldPercent(1);
 
33
  setKey(k => k + 1);
34
+ }, []);
35
 
36
  return (
37
  <div className="flex h-screen w-full bg-zinc-950 text-zinc-100 font-sans overflow-hidden">
38
+ {/* Left panel — Training Dashboard */}
39
+ <div className="w-96 flex-shrink-0 border-r border-zinc-800 bg-zinc-900 flex flex-col">
40
+ <div className="px-5 py-4 border-b border-zinc-800 flex items-center justify-between">
41
+ <div>
42
+ <h1 className="text-lg font-semibold tracking-tight">Optigami</h1>
43
+ <p className="text-[11px] text-zinc-500">RL Training Environment</p>
44
+ </div>
45
+ <div className="flex items-center gap-2">
46
+ <a
47
+ href="https://huggingface.co/spaces/openenv-community/optigami_"
48
+ target="_blank"
49
+ rel="noopener noreferrer"
50
+ className="text-[10px] text-zinc-500 hover:text-zinc-300 bg-zinc-800 px-2 py-1 rounded border border-zinc-700"
51
+ >
52
+ OpenEnv 0.2.1
53
+ </a>
54
+ </div>
55
  </div>
56
 
57
+ <div className="flex-1 overflow-y-auto p-4">
58
+ <TrainingDashboard onEntrySelect={handleEntrySelect} />
59
+ </div>
60
+ </div>
61
+
62
+ {/* Right side — 3D viewer + detail */}
63
+ <div className="flex-1 flex flex-col">
64
+ {/* Top bar with context about selected entry */}
65
+ {selectedEntry && (
66
+ <div className="flex-shrink-0 border-b border-zinc-800 bg-zinc-900/50 px-6 py-3 flex items-center gap-6 text-xs">
67
  <div>
68
+ <span className="text-zinc-500">Step</span>{' '}
69
+ <span className="font-mono text-zinc-200">#{selectedEntry.step}</span>
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
70
  </div>
 
71
  <div>
72
+ <span className="text-zinc-500">Task</span>{' '}
73
+ <span className="text-zinc-200">{selectedEntry.task_name}</span>
74
+ </div>
75
+ <div>
76
+ <span className="text-zinc-500">Reward</span>{' '}
77
+ <span className={`font-mono font-semibold ${
78
+ selectedEntry.reward >= 15 ? 'text-green-400' :
79
+ selectedEntry.reward >= 5 ? 'text-yellow-400' :
80
+ selectedEntry.reward >= 0 ? 'text-orange-400' : 'text-red-400'
81
+ }`}>{selectedEntry.reward.toFixed(2)}</span>
82
+ </div>
83
+ <div>
84
+ <span className="text-zinc-500">Similarity</span>{' '}
85
+ <span className="font-mono text-indigo-400">
86
+ {(selectedEntry.shape_similarity * 100).toFixed(1)}%
87
+ </span>
88
+ </div>
89
+ {selectedEntry.error && (
90
+ <div className="text-red-400/80 truncate flex-1">{selectedEntry.error}</div>
91
+ )}
92
+ <div className="ml-auto flex items-center gap-2">
93
+ <span className="text-zinc-500">Fold</span>
94
  <input
95
  type="range"
96
  min="0"
 
98
  step="0.01"
99
  value={foldPercent}
100
  onChange={(e) => setFoldPercent(parseFloat(e.target.value))}
101
+ className="w-24 accent-indigo-500"
102
  />
103
+ <span className="font-mono text-zinc-400 w-8">{Math.round(foldPercent * 100)}%</span>
104
  </div>
 
 
 
 
 
 
 
 
 
 
105
  </div>
106
+ )}
107
+
108
+ {/* 3D Canvas */}
109
+ <div className="flex-1 relative">
110
+ {!selectedEntry && (
111
+ <div className="absolute inset-0 flex items-center justify-center z-10 pointer-events-none">
112
+ <div className="text-center">
113
+ <div className="text-zinc-500 text-sm mb-1">Waiting for training data</div>
114
+ <div className="text-zinc-600 text-xs">
115
+ Start a GRPO training run in the Colab notebook
116
+ </div>
117
+ </div>
118
+ </div>
119
+ )}
120
+ <Canvas camera={{ position: [0, 0, 3], fov: 45 }}>
121
+ <ambientLight intensity={0.5} />
122
+ <directionalLight position={[5, 5, 5]} intensity={1} castShadow />
123
+ <directionalLight position={[-5, -5, -5]} intensity={0.2} />
124
+
125
+ <group key={key}>
126
+ <OrigamiMesh pattern={displayPattern} foldPercent={foldPercent} />
127
+ </group>
128
 
129
+ <OrbitControls makeDefault />
130
+ <Grid
131
+ infiniteGrid
132
+ fadeDistance={10}
133
+ sectionColor="#333"
134
+ cellColor="#222"
135
+ position={[0, 0, -0.01]}
136
+ />
137
+ </Canvas>
138
+
139
+ {/* 2D crease pattern overlay */}
140
+ {activePattern && (
141
+ <div className="absolute bottom-4 left-4 w-40 h-40 bg-zinc-900/90 rounded-lg border border-zinc-700/50 p-2 backdrop-blur-sm">
142
+ <div className="text-[9px] uppercase text-zinc-500 tracking-wider mb-1">Crease Pattern</div>
143
  <svg viewBox="-1.2 -1.2 2.4 2.4" className="w-full h-full">
144
  <g transform="scale(1, -1)">
145
+ {activePattern.faces.map((face, i) => {
146
+ const v1 = activePattern.vertices[face[0]];
147
+ const v2 = activePattern.vertices[face[1]];
148
+ const v3 = activePattern.vertices[face[2]];
 
149
  return (
150
  <polygon
151
+ key={`f-${i}`}
152
  points={`${v1[0]},${v1[1]} ${v2[0]},${v2[1]} ${v3[0]},${v3[1]}`}
153
  fill="#3f3f46"
154
  stroke="#52525b"
 
156
  />
157
  );
158
  })}
159
+ {activePattern.creases.map((crease, i) => {
160
+ const v1 = activePattern.vertices[crease.edge[0]];
161
+ const v2 = activePattern.vertices[crease.edge[1]];
 
162
  const color = crease.type === 'mountain' ? '#ef4444' : '#3b82f6';
163
  return (
164
  <line
165
+ key={`c-${i}`}
166
+ x1={v1[0]} y1={v1[1]}
167
+ x2={v2[0]} y2={v2[1]}
 
 
168
  stroke={color}
169
  strokeWidth="0.03"
170
  strokeLinecap="round"
 
173
  })}
174
  </g>
175
  </svg>
 
 
 
 
 
 
 
 
176
  </div>
177
+ )}
178
  </div>
179
  </div>
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
180
  </div>
181
  );
182
  }
components/TrainingDashboard.tsx ADDED
@@ -0,0 +1,298 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ 'use client';
2
+
3
+ import { useState, useEffect, useRef, useCallback } from 'react';
4
+ import { Activity, TrendingUp, AlertCircle, CheckCircle2, Circle } from 'lucide-react';
5
+
6
+ export interface TrainingEntry {
7
+ step: number;
8
+ timestamp: number;
9
+ task_name: string;
10
+ reward: number;
11
+ shape_similarity: number;
12
+ is_valid: boolean;
13
+ error: string | null;
14
+ fold_data: any | null;
15
+ final_positions: number[][];
16
+ target_positions: number[][];
17
+ }
18
+
19
+ interface TrainingStats {
20
+ total_steps: number;
21
+ best_reward: number;
22
+ best_similarity: number;
23
+ }
24
+
25
+ interface FeedResponse {
26
+ entries: TrainingEntry[];
27
+ stats: TrainingStats;
28
+ }
29
+
30
+ interface TrainingDashboardProps {
31
+ onEntrySelect?: (entry: TrainingEntry) => void;
32
+ }
33
+
34
+ export function TrainingDashboard({ onEntrySelect }: TrainingDashboardProps) {
35
+ const [entries, setEntries] = useState<TrainingEntry[]>([]);
36
+ const [stats, setStats] = useState<TrainingStats>({ total_steps: 0, best_reward: -999, best_similarity: 0 });
37
+ const [connected, setConnected] = useState(false);
38
+ const [error, setError] = useState<string | null>(null);
39
+ const [selectedStep, setSelectedStep] = useState<number | null>(null);
40
+ const lastStep = useRef(0);
41
+ const feedRef = useRef<HTMLDivElement>(null);
42
+
43
+ // Auto-select newest entry
44
+ const autoSelect = useRef(true);
45
+
46
+ useEffect(() => {
47
+ let active = true;
48
+
49
+ async function poll() {
50
+ while (active) {
51
+ try {
52
+ const res = await fetch(`/api/env/training/feed?since=${lastStep.current}`);
53
+ if (!res.ok) throw new Error(`HTTP ${res.status}`);
54
+ const data: FeedResponse = await res.json();
55
+
56
+ if (data.entries.length > 0) {
57
+ setEntries(prev => {
58
+ const combined = [...prev, ...data.entries];
59
+ return combined.slice(-50);
60
+ });
61
+ lastStep.current = data.entries[data.entries.length - 1].step;
62
+
63
+ // Auto-select latest
64
+ if (autoSelect.current) {
65
+ const latest = data.entries[data.entries.length - 1];
66
+ setSelectedStep(latest.step);
67
+ onEntrySelect?.(latest);
68
+ }
69
+ }
70
+ setStats(data.stats);
71
+ setConnected(true);
72
+ setError(null);
73
+ } catch (e: any) {
74
+ setConnected(false);
75
+ setError(e.message);
76
+ }
77
+ await new Promise(r => setTimeout(r, 2000));
78
+ }
79
+ }
80
+
81
+ poll();
82
+ return () => { active = false; };
83
+ }, [onEntrySelect]);
84
+
85
+ // Auto-scroll feed
86
+ useEffect(() => {
87
+ if (feedRef.current && autoSelect.current) {
88
+ feedRef.current.scrollTop = feedRef.current.scrollHeight;
89
+ }
90
+ }, [entries]);
91
+
92
+ const handleEntryClick = useCallback((entry: TrainingEntry) => {
93
+ autoSelect.current = false;
94
+ setSelectedStep(entry.step);
95
+ onEntrySelect?.(entry);
96
+ }, [onEntrySelect]);
97
+
98
+ const rewardColor = (r: number) => {
99
+ if (r >= 15) return 'text-green-400';
100
+ if (r >= 5) return 'text-yellow-400';
101
+ if (r >= 0) return 'text-orange-400';
102
+ return 'text-red-400';
103
+ };
104
+
105
+ const rewardBg = (r: number) => {
106
+ if (r >= 15) return 'bg-green-500';
107
+ if (r >= 5) return 'bg-yellow-500';
108
+ if (r >= 0) return 'bg-orange-500';
109
+ return 'bg-red-500';
110
+ };
111
+
112
+ const simBar = (sim: number) => {
113
+ const pct = Math.min(sim * 100, 100);
114
+ const color = pct > 70 ? 'bg-green-500' : pct > 40 ? 'bg-yellow-500' : 'bg-red-500';
115
+ return (
116
+ <div className="w-full h-1.5 bg-zinc-700 rounded-full overflow-hidden">
117
+ <div className={`h-full ${color} rounded-full transition-all`} style={{ width: `${pct}%` }} />
118
+ </div>
119
+ );
120
+ };
121
+
122
+ // Reward chart data
123
+ const recentRewards = entries.slice(-30);
124
+ const maxR = Math.max(...recentRewards.map(e => e.reward), 1);
125
+ const minR = Math.min(...recentRewards.map(e => e.reward), 0);
126
+ const range = maxR - minR || 1;
127
+
128
+ // Task breakdown
129
+ const taskCounts: Record<string, { count: number; avgReward: number; totalReward: number }> = {};
130
+ for (const e of entries) {
131
+ if (!taskCounts[e.task_name]) taskCounts[e.task_name] = { count: 0, avgReward: 0, totalReward: 0 };
132
+ taskCounts[e.task_name].count++;
133
+ taskCounts[e.task_name].totalReward += e.reward;
134
+ taskCounts[e.task_name].avgReward = taskCounts[e.task_name].totalReward / taskCounts[e.task_name].count;
135
+ }
136
+
137
+ return (
138
+ <div className="flex flex-col gap-3 h-full">
139
+ {/* Connection status */}
140
+ <div className="flex items-center gap-2 text-xs">
141
+ <div className={`w-2 h-2 rounded-full ${connected ? 'bg-green-500 animate-pulse' : 'bg-zinc-600'}`} />
142
+ <span className="text-zinc-400">
143
+ {connected ? 'Live' : 'Connecting...'}
144
+ </span>
145
+ {error && <span className="text-red-400/80 ml-auto truncate max-w-[140px]">{error}</span>}
146
+ </div>
147
+
148
+ {/* Stats row */}
149
+ <div className="grid grid-cols-3 gap-2">
150
+ <div className="bg-zinc-800/80 rounded-lg p-2.5 border border-zinc-700/60">
151
+ <div className="text-[10px] uppercase text-zinc-500 tracking-wider">Steps</div>
152
+ <div className="text-xl font-mono font-bold text-zinc-100">{stats.total_steps}</div>
153
+ </div>
154
+ <div className="bg-zinc-800/80 rounded-lg p-2.5 border border-zinc-700/60">
155
+ <div className="text-[10px] uppercase text-zinc-500 tracking-wider">Best Reward</div>
156
+ <div className={`text-xl font-mono font-bold ${rewardColor(stats.best_reward)}`}>
157
+ {stats.best_reward > -999 ? stats.best_reward.toFixed(1) : '--'}
158
+ </div>
159
+ </div>
160
+ <div className="bg-zinc-800/80 rounded-lg p-2.5 border border-zinc-700/60">
161
+ <div className="text-[10px] uppercase text-zinc-500 tracking-wider">Best Sim</div>
162
+ <div className="text-xl font-mono font-bold text-indigo-400">
163
+ {stats.best_similarity > 0 ? (stats.best_similarity * 100).toFixed(0) + '%' : '--'}
164
+ </div>
165
+ </div>
166
+ </div>
167
+
168
+ {/* Reward trend chart */}
169
+ {recentRewards.length > 1 && (
170
+ <div className="bg-zinc-800/80 rounded-lg p-3 border border-zinc-700/60">
171
+ <div className="text-[10px] uppercase text-zinc-500 tracking-wider mb-2 flex items-center gap-1">
172
+ <TrendingUp size={10} /> Reward History
173
+ </div>
174
+ <div className="flex items-end gap-[2px] h-20">
175
+ {recentRewards.map((e, i) => {
176
+ const h = ((e.reward - minR) / range) * 100;
177
+ const isSelected = e.step === selectedStep;
178
+ return (
179
+ <div
180
+ key={e.step}
181
+ className={`flex-1 rounded-t cursor-pointer transition-all ${isSelected ? 'ring-1 ring-white' : ''}`}
182
+ style={{
183
+ height: `${Math.max(h, 3)}%`,
184
+ backgroundColor: e.reward >= 15 ? '#22c55e' : e.reward >= 5 ? '#eab308' : e.reward >= 0 ? '#f97316' : '#ef4444',
185
+ opacity: isSelected ? 1 : 0.4 + (i / recentRewards.length) * 0.5,
186
+ }}
187
+ title={`#${e.step} ${e.task_name}: ${e.reward.toFixed(2)}`}
188
+ onClick={() => handleEntryClick(e)}
189
+ />
190
+ );
191
+ })}
192
+ </div>
193
+ <div className="flex justify-between text-[9px] text-zinc-600 mt-1 font-mono">
194
+ <span>{minR.toFixed(1)}</span>
195
+ <span>{maxR.toFixed(1)}</span>
196
+ </div>
197
+ </div>
198
+ )}
199
+
200
+ {/* Task breakdown */}
201
+ {Object.keys(taskCounts).length > 1 && (
202
+ <div className="bg-zinc-800/80 rounded-lg p-3 border border-zinc-700/60">
203
+ <div className="text-[10px] uppercase text-zinc-500 tracking-wider mb-2">Tasks</div>
204
+ <div className="space-y-1">
205
+ {Object.entries(taskCounts).map(([name, data]) => (
206
+ <div key={name} className="flex items-center gap-2 text-[11px]">
207
+ <span className="text-zinc-400 w-24 truncate">{name}</span>
208
+ <div className="flex-1 h-1 bg-zinc-700 rounded-full overflow-hidden">
209
+ <div
210
+ className={`h-full rounded-full ${rewardBg(data.avgReward)}`}
211
+ style={{ width: `${Math.max((data.avgReward + 5) / 25 * 100, 2)}%`, opacity: 0.7 }}
212
+ />
213
+ </div>
214
+ <span className={`font-mono w-10 text-right ${rewardColor(data.avgReward)}`}>
215
+ {data.avgReward.toFixed(1)}
216
+ </span>
217
+ <span className="text-zinc-600 font-mono w-6 text-right">{data.count}</span>
218
+ </div>
219
+ ))}
220
+ </div>
221
+ </div>
222
+ )}
223
+
224
+ {/* Activity feed */}
225
+ <div className="flex-1 min-h-0 flex flex-col">
226
+ <div className="text-[10px] uppercase text-zinc-500 tracking-wider mb-2 flex items-center justify-between">
227
+ <span className="flex items-center gap-1"><Activity size={10} /> Activity</span>
228
+ {!autoSelect.current && entries.length > 0 && (
229
+ <button
230
+ className="text-indigo-400 hover:text-indigo-300 normal-case tracking-normal"
231
+ onClick={() => {
232
+ autoSelect.current = true;
233
+ const latest = entries[entries.length - 1];
234
+ setSelectedStep(latest.step);
235
+ onEntrySelect?.(latest);
236
+ }}
237
+ >
238
+ Follow latest
239
+ </button>
240
+ )}
241
+ </div>
242
+ <div ref={feedRef} className="overflow-y-auto flex-1 space-y-1 pr-1">
243
+ {entries.length === 0 ? (
244
+ <div className="text-xs text-zinc-600 text-center py-12">
245
+ <div className="text-zinc-500 mb-1">No training activity yet</div>
246
+ <div>Start a training run in the Colab notebook</div>
247
+ </div>
248
+ ) : (
249
+ entries.map(e => (
250
+ <div
251
+ key={e.step}
252
+ onClick={() => handleEntryClick(e)}
253
+ className={`rounded-lg px-2.5 py-1.5 text-xs cursor-pointer transition-all border ${
254
+ e.step === selectedStep
255
+ ? 'bg-zinc-700/80 border-indigo-500/50'
256
+ : 'bg-zinc-800/40 border-zinc-700/30 hover:bg-zinc-800/70'
257
+ }`}
258
+ >
259
+ <div className="flex items-center justify-between">
260
+ <div className="flex items-center gap-1.5">
261
+ <span className="font-mono text-zinc-500 text-[10px]">#{e.step}</span>
262
+ <span className="text-zinc-300 text-[10px] truncate max-w-[80px]">{e.task_name}</span>
263
+ </div>
264
+ <div className="flex items-center gap-1.5">
265
+ {e.error ? (
266
+ <Circle size={6} className="fill-red-400 text-red-400" />
267
+ ) : e.is_valid ? (
268
+ <Circle size={6} className="fill-green-400 text-green-400" />
269
+ ) : (
270
+ <Circle size={6} className="fill-yellow-400 text-yellow-400" />
271
+ )}
272
+ <span className={`font-mono font-semibold ${rewardColor(e.reward)}`}>
273
+ {e.reward.toFixed(1)}
274
+ </span>
275
+ </div>
276
+ </div>
277
+ {e.step === selectedStep && (
278
+ <div className="mt-1.5 space-y-1">
279
+ <div className="flex items-center gap-2">
280
+ <span className="text-zinc-500 text-[10px] w-8">sim</span>
281
+ {simBar(e.shape_similarity)}
282
+ <span className="text-zinc-400 font-mono text-[10px] w-8 text-right">
283
+ {(e.shape_similarity * 100).toFixed(0)}%
284
+ </span>
285
+ </div>
286
+ {e.error && (
287
+ <div className="text-red-400/70 text-[10px] truncate">{e.error}</div>
288
+ )}
289
+ </div>
290
+ )}
291
+ </div>
292
+ ))
293
+ )}
294
+ </div>
295
+ </div>
296
+ </div>
297
+ );
298
+ }
next.config.ts CHANGED
@@ -29,6 +29,7 @@ const nextConfig: NextConfig = {
29
  ];
30
  },
31
  transpilePackages: ['motion'],
 
32
  webpack: (config, {dev}) => {
33
  // HMR is disabled in AI Studio via DISABLE_HMR env var.
34
  // Do not modify—file watching is disabled to prevent flickering during agent edits.
 
29
  ];
30
  },
31
  transpilePackages: ['motion'],
32
+ turbopack: {},
33
  webpack: (config, {dev}) => {
34
  // HMR is disabled in AI Studio via DISABLE_HMR env var.
35
  // Do not modify—file watching is disabled to prevent flickering during agent edits.
origami_server/app.py CHANGED
@@ -1,9 +1,12 @@
1
  """FastAPI entry point — OpenEnv create_app() + custom endpoints."""
2
 
3
  import os
 
 
4
  from pathlib import Path
5
 
6
  from fastapi import HTTPException
 
7
  from fastapi.responses import HTMLResponse
8
 
9
  from openenv.core.env_server.http_server import create_app
@@ -19,6 +22,59 @@ app = create_app(
19
  env_name="origami_env",
20
  )
21
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
 
23
  @app.get("/tasks")
24
  def get_tasks():
 
1
  """FastAPI entry point — OpenEnv create_app() + custom endpoints."""
2
 
3
  import os
4
+ import time
5
+ from collections import deque
6
  from pathlib import Path
7
 
8
  from fastapi import HTTPException
9
+ from fastapi.middleware.cors import CORSMiddleware
10
  from fastapi.responses import HTMLResponse
11
 
12
  from openenv.core.env_server.http_server import create_app
 
22
  env_name="origami_env",
23
  )
24
 
25
+ # Allow CORS for frontend polling
26
+ app.add_middleware(
27
+ CORSMiddleware,
28
+ allow_origins=["*"],
29
+ allow_methods=["*"],
30
+ allow_headers=["*"],
31
+ )
32
+
33
+ # ── Training Activity Feed ───────────────────────────────────────────────────
34
+ # Ring buffer of recent training steps — the frontend polls this to visualize
35
+ # what's happening during GRPO training.
36
+
37
+ ACTIVITY_FEED: deque = deque(maxlen=50) # Last 50 steps
38
+ TRAINING_STATS: dict = {"total_steps": 0, "best_reward": -999, "best_similarity": 0}
39
+
40
+
41
+ @app.get("/training/feed")
42
+ def get_training_feed(since: int = 0):
43
+ """Get recent training activity. Pass `since=<step>` to get only new entries."""
44
+ entries = [e for e in ACTIVITY_FEED if e["step"] > since]
45
+ return {"entries": entries, "stats": TRAINING_STATS}
46
+
47
+
48
+ @app.post("/training/log")
49
+ def log_training_step(data: dict):
50
+ """Log a training step from the notebook. Called after each env.step()."""
51
+ step = TRAINING_STATS["total_steps"] + 1
52
+ TRAINING_STATS["total_steps"] = step
53
+
54
+ entry = {
55
+ "step": step,
56
+ "timestamp": time.time(),
57
+ "task_name": data.get("task_name", ""),
58
+ "reward": data.get("reward", 0),
59
+ "shape_similarity": data.get("shape_similarity", 0),
60
+ "is_valid": data.get("is_valid", False),
61
+ "error": data.get("error", None),
62
+ "fold_data": data.get("fold_data", None),
63
+ "final_positions": data.get("final_positions", []),
64
+ "target_positions": data.get("target_positions", []),
65
+ }
66
+
67
+ ACTIVITY_FEED.append(entry)
68
+
69
+ reward = entry["reward"]
70
+ sim = entry["shape_similarity"]
71
+ if reward > TRAINING_STATS["best_reward"]:
72
+ TRAINING_STATS["best_reward"] = reward
73
+ if sim > TRAINING_STATS["best_similarity"]:
74
+ TRAINING_STATS["best_similarity"] = sim
75
+
76
+ return {"step": step}
77
+
78
 
79
  @app.get("/tasks")
80
  def get_tasks():
training/train_grpo.ipynb CHANGED
@@ -181,108 +181,7 @@
181
  "execution_count": null,
182
  "metadata": {},
183
  "outputs": [],
184
- "source": [
185
- "PRINTER = 0\n",
186
- "\n",
187
- "def extract_fold_json(response):\n",
188
- " \"\"\"Extract FOLD JSON from LLM response text.\"\"\"\n",
189
- " m = re.search(r\"```(?:json)?\\s*(\\{.*?\\})\\s*```\", response, re.DOTALL)\n",
190
- " if m:\n",
191
- " try: return json.loads(m.group(1))\n",
192
- " except: pass\n",
193
- " m = re.search(r'\\{[^{}]*\"vertices_coords\"[^{}]*\\}', response, re.DOTALL)\n",
194
- " if m:\n",
195
- " try: return json.loads(m.group(0))\n",
196
- " except: pass\n",
197
- " try:\n",
198
- " d = json.loads(response.strip())\n",
199
- " if isinstance(d, dict) and \"vertices_coords\" in d: return d\n",
200
- " except: pass\n",
201
- " return None\n",
202
- "\n",
203
- "\n",
204
- "def valid_fold_reward(completions, **kwargs):\n",
205
- " \"\"\"Reward 1 (local): +1.0 valid FOLD structure, -0.5 bad structure, -2.0 unparseable.\"\"\"\n",
206
- " REQUIRED = {\"vertices_coords\", \"edges_vertices\", \"edges_assignment\"}\n",
207
- " scores = []\n",
208
- " for c in completions:\n",
209
- " fold = extract_fold_json(c[0][\"content\"])\n",
210
- " if fold is None:\n",
211
- " scores.append(-2.0)\n",
212
- " continue\n",
213
- " # Basic structural checks\n",
214
- " if not REQUIRED.issubset(fold.keys()):\n",
215
- " scores.append(-0.5); continue\n",
216
- " verts = fold[\"vertices_coords\"]\n",
217
- " edges = fold[\"edges_vertices\"]\n",
218
- " asgn = fold[\"edges_assignment\"]\n",
219
- " if len(verts) < 3 or len(edges) < 3 or len(edges) != len(asgn):\n",
220
- " scores.append(-0.5); continue\n",
221
- " if not any(a in (\"M\",\"V\") for a in asgn):\n",
222
- " scores.append(-0.5); continue\n",
223
- " if not any(a == \"B\" for a in asgn):\n",
224
- " scores.append(-0.5); continue\n",
225
- " scores.append(1.0)\n",
226
- " return scores\n",
227
- "\n",
228
- "\n",
229
- "def openenv_reward(completions, task_name, **kwargs):\n",
230
- " \"\"\"Reward 2 (OpenEnv API): Submit fold to environment, get simulation reward.\n",
231
- "\n",
232
- " Calls POST /reset and POST /step on the HF Space OpenEnv environment.\n",
233
- " The environment runs the fold simulation and computes shape similarity.\n",
234
- " \"\"\"\n",
235
- " global PRINTER\n",
236
- " # task_name comes as a list from the dataset\n",
237
- " tn = task_name[0] if isinstance(task_name, list) else task_name\n",
238
- "\n",
239
- " scores = []\n",
240
- " for c in completions:\n",
241
- " resp = c[0][\"content\"]\n",
242
- "\n",
243
- " # Periodic logging\n",
244
- " if PRINTER % 10 == 0:\n",
245
- " print(f\"\\n--- [{tn}] Sample {PRINTER} ---\\n{resp[:300]}\")\n",
246
- " PRINTER += 1\n",
247
- "\n",
248
- " # Parse the FOLD JSON from the LLM response\n",
249
- " fold = extract_fold_json(resp)\n",
250
- " if fold is None:\n",
251
- " scores.append(-2.0)\n",
252
- " continue\n",
253
- "\n",
254
- " try:\n",
255
- " # Reset environment for this task\n",
256
- " env.reset(task_name=tn)\n",
257
- "\n",
258
- " # Submit the fold to OpenEnv — environment simulates and scores it\n",
259
- " result = env.step(fold)\n",
260
- "\n",
261
- " # Get reward from the environment\n",
262
- " reward = result.get(\"reward\", None)\n",
263
- " if reward is not None:\n",
264
- " scores.append(float(reward))\n",
265
- " else:\n",
266
- " # Fallback: extract from observation\n",
267
- " obs = result.get(\"observation\", {})\n",
268
- " if obs.get(\"error\"):\n",
269
- " scores.append(-2.0)\n",
270
- " else:\n",
271
- " sim = obs.get(\"shape_similarity\", 0.0)\n",
272
- " scores.append(float(sim) * 20.0)\n",
273
- "\n",
274
- " except requests.exceptions.RequestException as e:\n",
275
- " print(f\"OpenEnv API error: {e}\")\n",
276
- " scores.append(-1.0)\n",
277
- " except Exception as e:\n",
278
- " print(f\"Reward error: {e}\")\n",
279
- " scores.append(-1.0)\n",
280
- "\n",
281
- " return scores\n",
282
- "\n",
283
- "\n",
284
- "print(\"Reward functions ready (valid_fold=local, openenv_reward=API).\")"
285
- ]
286
  },
287
  {
288
  "cell_type": "markdown",
@@ -533,9 +432,18 @@
533
  ],
534
  "metadata": {
535
  "accelerator": "GPU",
536
- "colab": { "gpuType": "T4", "provenance": [] },
537
- "kernelspec": { "display_name": "Python 3", "name": "python3" },
538
- "language_info": { "name": "python", "version": "3.11.0" }
 
 
 
 
 
 
 
 
 
539
  },
540
  "nbformat": 4,
541
  "nbformat_minor": 0
 
181
  "execution_count": null,
182
  "metadata": {},
183
  "outputs": [],
184
+ "source": "PRINTER = 0\n\ndef extract_fold_json(response):\n \"\"\"Extract FOLD JSON from LLM response text.\"\"\"\n m = re.search(r\"```(?:json)?\\s*(\\{.*?\\})\\s*```\", response, re.DOTALL)\n if m:\n try: return json.loads(m.group(1))\n except: pass\n m = re.search(r'\\{[^{}]*\"vertices_coords\"[^{}]*\\}', response, re.DOTALL)\n if m:\n try: return json.loads(m.group(0))\n except: pass\n try:\n d = json.loads(response.strip())\n if isinstance(d, dict) and \"vertices_coords\" in d: return d\n except: pass\n return None\n\n\ndef log_to_dashboard(task_name, reward, shape_similarity, is_valid, error=None, fold_data=None, final_positions=None, target_positions=None):\n \"\"\"Send training step data to the frontend dashboard via /training/log.\"\"\"\n try:\n requests.post(\n f\"{OPENENV_URL.replace('/api/env', '')}/training/log\",\n json={\n \"task_name\": task_name,\n \"reward\": reward,\n \"shape_similarity\": shape_similarity,\n \"is_valid\": is_valid,\n \"error\": error,\n \"fold_data\": fold_data,\n \"final_positions\": final_positions or [],\n \"target_positions\": target_positions or [],\n },\n timeout=5,\n )\n except:\n pass # Don't let dashboard logging break training\n\n\ndef valid_fold_reward(completions, **kwargs):\n \"\"\"Reward 1 (local): +1.0 valid FOLD structure, -0.5 bad structure, -2.0 unparseable.\"\"\"\n REQUIRED = {\"vertices_coords\", \"edges_vertices\", \"edges_assignment\"}\n scores = []\n for c in completions:\n fold = extract_fold_json(c[0][\"content\"])\n if fold is None:\n scores.append(-2.0)\n continue\n # Basic structural checks\n if not REQUIRED.issubset(fold.keys()):\n scores.append(-0.5); continue\n verts = fold[\"vertices_coords\"]\n edges = fold[\"edges_vertices\"]\n asgn = fold[\"edges_assignment\"]\n if len(verts) < 3 or len(edges) < 3 or len(edges) != len(asgn):\n scores.append(-0.5); continue\n if not any(a in (\"M\",\"V\") for a in asgn):\n scores.append(-0.5); continue\n if not any(a == \"B\" for a in asgn):\n scores.append(-0.5); continue\n scores.append(1.0)\n return scores\n\n\ndef openenv_reward(completions, task_name, **kwargs):\n \"\"\"Reward 2 (OpenEnv API): Submit fold to environment, get simulation reward.\n\n Calls POST /reset and POST /step on the HF Space OpenEnv environment.\n The environment runs the fold simulation and computes shape similarity.\n Also logs each step to the frontend training dashboard.\n \"\"\"\n global PRINTER\n # task_name comes as a list from the dataset\n tn = task_name[0] if isinstance(task_name, list) else task_name\n\n scores = []\n for c in completions:\n resp = c[0][\"content\"]\n\n # Periodic logging\n if PRINTER % 10 == 0:\n print(f\"\\n--- [{tn}] Sample {PRINTER} ---\\n{resp[:300]}\")\n PRINTER += 1\n\n # Parse the FOLD JSON from the LLM response\n fold = extract_fold_json(resp)\n if fold is None:\n scores.append(-2.0)\n log_to_dashboard(tn, -2.0, 0.0, False, error=\"No JSON parsed\")\n continue\n\n try:\n # Reset environment for this task\n env.reset(task_name=tn)\n\n # Submit the fold to OpenEnv — environment simulates and scores it\n result = env.step(fold)\n\n # Get reward from the environment\n reward = result.get(\"reward\", None)\n obs = result.get(\"observation\", {})\n sim = obs.get(\"shape_similarity\", 0.0)\n is_valid = not bool(obs.get(\"error\"))\n\n if reward is not None:\n scores.append(float(reward))\n else:\n if obs.get(\"error\"):\n scores.append(-2.0)\n else:\n reward = float(sim) * 20.0\n scores.append(reward)\n\n # Log to frontend dashboard\n log_to_dashboard(\n task_name=tn,\n reward=float(reward) if reward is not None else scores[-1],\n shape_similarity=float(sim),\n is_valid=is_valid,\n error=obs.get(\"error\"),\n fold_data=fold,\n final_positions=obs.get(\"final_positions\", []),\n target_positions=obs.get(\"target_positions\", []),\n )\n\n except requests.exceptions.RequestException as e:\n print(f\"OpenEnv API error: {e}\")\n scores.append(-1.0)\n log_to_dashboard(tn, -1.0, 0.0, False, error=str(e))\n except Exception as e:\n print(f\"Reward error: {e}\")\n scores.append(-1.0)\n log_to_dashboard(tn, -1.0, 0.0, False, error=str(e))\n\n return scores\n\n\nprint(\"Reward functions ready (valid_fold=local, openenv_reward=API + dashboard logging).\")"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
185
  },
186
  {
187
  "cell_type": "markdown",
 
432
  ],
433
  "metadata": {
434
  "accelerator": "GPU",
435
+ "colab": {
436
+ "gpuType": "T4",
437
+ "provenance": []
438
+ },
439
+ "kernelspec": {
440
+ "display_name": "Python 3",
441
+ "name": "python3"
442
+ },
443
+ "language_info": {
444
+ "name": "python",
445
+ "version": "3.11.0"
446
+ }
447
  },
448
  "nbformat": 4,
449
  "nbformat_minor": 0