pepijn223 HF Staff commited on
Commit
9333aea
Β·
unverified Β·
1 Parent(s): 39baaef

add trail in urdf viewer

Browse files
src/app/[org]/[dataset]/[episode]/actions.ts CHANGED
@@ -6,6 +6,7 @@ import {
6
  loadAllEpisodeLengthsV3,
7
  loadAllEpisodeFrameInfo,
8
  loadCrossEpisodeActionVariance,
 
9
  type EpisodeLengthStats,
10
  type EpisodeFramesData,
11
  type CrossEpisodeVarianceData,
@@ -39,3 +40,13 @@ export async function fetchCrossEpisodeVariance(
39
  return loadCrossEpisodeActionVariance(repoId, version, info as unknown as DatasetMetadata, info.fps);
40
  }
41
 
 
 
 
 
 
 
 
 
 
 
 
6
  loadAllEpisodeLengthsV3,
7
  loadAllEpisodeFrameInfo,
8
  loadCrossEpisodeActionVariance,
9
+ loadEpisodeFlatChartData,
10
  type EpisodeLengthStats,
11
  type EpisodeFramesData,
12
  type CrossEpisodeVarianceData,
 
40
  return loadCrossEpisodeActionVariance(repoId, version, info as unknown as DatasetMetadata, info.fps);
41
  }
42
 
43
+ export async function fetchEpisodeChartData(
44
+ org: string,
45
+ dataset: string,
46
+ episodeId: number,
47
+ ): Promise<Record<string, number>[]> {
48
+ const repoId = `${org}/${dataset}`;
49
+ const { version, info } = await getDatasetVersionAndInfo(repoId);
50
+ return loadEpisodeFlatChartData(repoId, version, info as unknown as DatasetMetadata, episodeId);
51
+ }
52
+
src/app/[org]/[dataset]/[episode]/episode-viewer.tsx CHANGED
@@ -447,7 +447,7 @@ function EpisodeViewerInner({ data, org, dataset }: { data: EpisodeData; org?: s
447
 
448
  {activeTab === "urdf" && (
449
  <Suspense fallback={<Loading />}>
450
- <URDFViewer data={data} />
451
  </Suspense>
452
  )}
453
  </div>
 
447
 
448
  {activeTab === "urdf" && (
449
  <Suspense fallback={<Loading />}>
450
+ <URDFViewer data={data} org={org} dataset={dataset} />
451
  </Suspense>
452
  )}
453
  </div>
src/components/urdf-viewer.tsx CHANGED
@@ -7,10 +7,15 @@ import * as THREE from "three";
7
  import URDFLoader from "urdf-loader";
8
  import type { URDFRobot } from "urdf-loader";
9
  import { STLLoader } from "three/examples/jsm/loaders/STLLoader.js";
 
 
 
10
  import type { EpisodeData } from "@/app/[org]/[dataset]/[episode]/fetch-data";
 
11
 
12
  const SERIES_DELIM = " | ";
13
  const SCALE = 10;
 
14
 
15
  function getUrdfUrl(robotType: string | null): string {
16
  const lower = (robotType ?? "").toLowerCase();
@@ -18,15 +23,15 @@ function getUrdfUrl(robotType: string | null): string {
18
  return "/urdf/so101/so101_new_calib.urdf";
19
  }
20
 
21
- // Detect raw servo values (0-4096) vs radians
22
  function detectAndConvert(values: number[]): number[] {
23
  if (values.length === 0) return values;
24
  const max = Math.max(...values.map(Math.abs));
25
- if (max > 10) return values.map((v) => ((v - 2048) / 2048) * Math.PI);
26
- return values;
 
27
  }
28
 
29
- // Group flat chart columns by feature prefix
30
  function groupColumnsByPrefix(keys: string[]): Record<string, string[]> {
31
  const groups: Record<string, string[]> = {};
32
  for (const key of keys) {
@@ -39,7 +44,6 @@ function groupColumnsByPrefix(keys: string[]): Record<string, string[]> {
39
  return groups;
40
  }
41
 
42
- // Auto-match dataset columns to URDF joint names
43
  function autoMatchJoints(urdfJointNames: string[], columnKeys: string[]): Record<string, string> {
44
  const mapping: Record<string, string> = {};
45
  for (const jointName of urdfJointNames) {
@@ -55,28 +59,72 @@ function autoMatchJoints(urdfJointNames: string[], columnKeys: string[]): Record
55
  return mapping;
56
  }
57
 
 
 
 
 
 
 
58
  // ─── Robot scene (imperative, inside Canvas) ───
59
  function RobotScene({
60
- urdfUrl,
61
- jointValues,
62
- onJointsLoaded,
63
  }: {
64
  urdfUrl: string;
65
  jointValues: Record<string, number>;
66
  onJointsLoaded: (names: string[]) => void;
 
 
67
  }) {
68
- const { scene } = useThree();
69
  const robotRef = useRef<URDFRobot | null>(null);
 
70
  const [loading, setLoading] = useState(true);
71
  const [error, setError] = useState<string | null>(null);
72
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
73
  useEffect(() => {
74
  setLoading(true);
75
  setError(null);
76
-
77
  const manager = new THREE.LoadingManager();
78
  const loader = new URDFLoader(manager);
79
-
80
  loader.loadMeshCb = (url, mgr, onLoad) => {
81
  const stlLoader = new STLLoader(mgr);
82
  stlLoader.load(
@@ -94,7 +142,6 @@ function RobotScene({
94
  (err) => onLoad(new THREE.Object3D(), err as Error),
95
  );
96
  };
97
-
98
  loader.load(
99
  urdfUrl,
100
  (robot) => {
@@ -105,6 +152,11 @@ function RobotScene({
105
  robot.scale.set(SCALE, SCALE, SCALE);
106
  scene.add(robot);
107
 
 
 
 
 
 
108
  const revolute = Object.values(robot.joints)
109
  .filter((j) => j.jointType === "revolute" || j.jointType === "continuous")
110
  .map((j) => j.name);
@@ -112,26 +164,75 @@ function RobotScene({
112
  setLoading(false);
113
  },
114
  undefined,
115
- (err) => {
116
- console.error("Error loading URDF:", err);
117
- setError(String(err));
118
- setLoading(false);
119
- },
120
  );
121
-
122
  return () => {
123
- if (robotRef.current) {
124
- scene.remove(robotRef.current);
125
- robotRef.current = null;
126
- }
127
  };
128
  }, [urdfUrl, scene, onJointsLoaded]);
129
 
 
 
130
  useFrame(() => {
131
- if (!robotRef.current) return;
 
 
 
132
  for (const [name, value] of Object.entries(jointValues)) {
133
- robotRef.current.setJointValue(name, value);
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
134
  }
 
 
 
 
 
 
 
 
 
 
 
 
135
  });
136
 
137
  if (loading) return <Html center><span className="text-white text-lg">Loading robot…</span></Html>;
@@ -139,7 +240,7 @@ function RobotScene({
139
  return null;
140
  }
141
 
142
- // ─── Playback ticker (inside Canvas) ───
143
  function PlaybackDriver({
144
  playing, fps, totalFrames, frameRef, setFrame,
145
  }: {
@@ -149,7 +250,6 @@ function PlaybackDriver({
149
  }) {
150
  const elapsed = useRef(0);
151
  const last = useRef(0);
152
-
153
  useEffect(() => {
154
  if (!playing) return;
155
  let raf: number;
@@ -160,10 +260,10 @@ function PlaybackDriver({
160
  last.current = now;
161
  if (dt > 0 && dt < 0.5) {
162
  elapsed.current += dt;
163
- const frameDelta = Math.floor(elapsed.current * fps);
164
- if (frameDelta > 0) {
165
- elapsed.current -= frameDelta / fps;
166
- frameRef.current = (frameRef.current + frameDelta) % totalFrames;
167
  setFrame(frameRef.current);
168
  }
169
  }
@@ -173,31 +273,68 @@ function PlaybackDriver({
173
  raf = requestAnimationFrame(tick);
174
  return () => cancelAnimationFrame(raf);
175
  }, [playing, fps, totalFrames, frameRef, setFrame]);
176
-
177
  return null;
178
  }
179
 
180
  // ═══════════════════════════════════════
181
  // ─── Main URDF Viewer ───
182
  // ═══════════════════════════════════════
183
- export default function URDFViewer({ data }: { data: EpisodeData }) {
184
- const { flatChartData, datasetInfo } = data;
185
- const totalFrames = flatChartData.length;
 
 
 
 
 
 
 
186
  const fps = datasetInfo.fps || 30;
187
  const urdfUrl = useMemo(() => getUrdfUrl(datasetInfo.robot_type), [datasetInfo.robot_type]);
188
 
189
- // URDF joint names (set after robot loads)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
190
  const [urdfJointNames, setUrdfJointNames] = useState<string[]>([]);
191
  const onJointsLoaded = useCallback((names: string[]) => setUrdfJointNames(names), []);
192
 
193
- // Feature group selection
194
  const columnGroups = useMemo(() => {
195
  if (totalFrames === 0) return {};
196
- return groupColumnsByPrefix(Object.keys(flatChartData[0]));
197
- }, [flatChartData, totalFrames]);
198
 
199
  const groupNames = useMemo(() => Object.keys(columnGroups), [columnGroups]);
200
-
201
  const defaultGroup = useMemo(
202
  () =>
203
  groupNames.find((g) => g.toLowerCase().includes("state")) ??
@@ -208,10 +345,9 @@ export default function URDFViewer({ data }: { data: EpisodeData }) {
208
 
209
  const [selectedGroup, setSelectedGroup] = useState(defaultGroup);
210
  useEffect(() => setSelectedGroup(defaultGroup), [defaultGroup]);
211
-
212
  const selectedColumns = columnGroups[selectedGroup] ?? [];
213
 
214
- // Joint mapping (re-compute when URDF joints or selected columns change)
215
  const autoMapping = useMemo(
216
  () => autoMatchJoints(urdfJointNames, selectedColumns),
217
  [urdfJointNames, selectedColumns],
@@ -219,6 +355,9 @@ export default function URDFViewer({ data }: { data: EpisodeData }) {
219
  const [mapping, setMapping] = useState<Record<string, string>>(autoMapping);
220
  useEffect(() => setMapping(autoMapping), [autoMapping]);
221
 
 
 
 
222
  // Playback
223
  const [frame, setFrame] = useState(0);
224
  const [playing, setPlaying] = useState(false);
@@ -233,7 +372,7 @@ export default function URDFViewer({ data }: { data: EpisodeData }) {
233
  // Compute joint values for current frame
234
  const jointValues = useMemo(() => {
235
  if (totalFrames === 0 || urdfJointNames.length === 0) return {};
236
- const row = flatChartData[Math.min(frame, totalFrames - 1)];
237
  const rawValues: number[] = [];
238
  const names: string[] = [];
239
 
@@ -249,35 +388,34 @@ export default function URDFViewer({ data }: { data: EpisodeData }) {
249
  const values: Record<string, number> = {};
250
  names.forEach((n, i) => { values[n] = converted[i]; });
251
  return values;
252
- }, [flatChartData, frame, mapping, totalFrames, urdfJointNames]);
253
 
254
  const currentTime = totalFrames > 0 ? (frame / fps).toFixed(2) : "0.00";
255
  const totalTime = (totalFrames / fps).toFixed(2);
256
 
257
- if (totalFrames === 0) {
258
- return <div className="text-slate-400 p-8 text-center">No trajectory data available for this episode.</div>;
259
  }
260
 
261
  return (
262
  <div className="flex-1 flex flex-col overflow-hidden">
263
  {/* 3D Viewport */}
264
- <div className="flex-1 min-h-0 bg-slate-950 rounded-lg overflow-hidden border border-slate-700">
 
 
 
 
 
265
  <Canvas camera={{ position: [0.3 * SCALE, 0.25 * SCALE, 0.3 * SCALE], fov: 45, near: 0.01, far: 100 }}>
266
  <ambientLight intensity={0.5} />
267
  <directionalLight position={[3, 5, 4]} intensity={1.2} />
268
  <directionalLight position={[-2, 3, -2]} intensity={0.4} />
269
  <hemisphereLight args={["#b1e1ff", "#444444", 0.4]} />
270
- <RobotScene urdfUrl={urdfUrl} jointValues={jointValues} onJointsLoaded={onJointsLoaded} />
271
  <Grid
272
- args={[10, 10]}
273
- cellSize={0.2}
274
- cellThickness={0.5}
275
- cellColor="#334155"
276
- sectionSize={1}
277
- sectionThickness={1}
278
- sectionColor="#475569"
279
- fadeDistance={10}
280
- position={[0, 0, 0]}
281
  />
282
  <OrbitControls target={[0, 0.8, 0]} />
283
  <PlaybackDriver playing={playing} fps={fps} totalFrames={totalFrames} frameRef={frameRef} setFrame={setFrame} />
@@ -286,8 +424,32 @@ export default function URDFViewer({ data }: { data: EpisodeData }) {
286
 
287
  {/* Controls */}
288
  <div className="bg-slate-800/90 border-t border-slate-700 p-3 space-y-3 shrink-0">
289
- {/* Timeline */}
290
  <div className="flex items-center gap-3">
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
291
  <button
292
  onClick={() => { setPlaying(!playing); if (!playing) frameRef.current = frame; }}
293
  className="w-8 h-8 flex items-center justify-center rounded bg-orange-600 hover:bg-orange-500 text-white transition-colors shrink-0"
@@ -298,10 +460,21 @@ export default function URDFViewer({ data }: { data: EpisodeData }) {
298
  <svg width="12" height="14" viewBox="0 0 12 14"><polygon points="2,1 11,7 2,13" fill="white" /></svg>
299
  )}
300
  </button>
 
 
 
 
 
 
 
 
 
 
 
301
  <input type="range" min={0} max={Math.max(totalFrames - 1, 0)} value={frame}
302
  onChange={handleFrameChange} className="flex-1 h-1.5 accent-orange-500 cursor-pointer" />
303
  <span className="text-xs text-slate-400 tabular-nums w-28 text-right shrink-0">{currentTime}s / {totalTime}s</span>
304
- <span className="text-xs text-slate-500 tabular-nums w-20 text-right shrink-0">F {frame}/{totalFrames - 1}</span>
305
  </div>
306
 
307
  {/* Data source + joint mapping */}
 
7
  import URDFLoader from "urdf-loader";
8
  import type { URDFRobot } from "urdf-loader";
9
  import { STLLoader } from "three/examples/jsm/loaders/STLLoader.js";
10
+ import { Line2 } from "three/examples/jsm/lines/Line2.js";
11
+ import { LineMaterial } from "three/examples/jsm/lines/LineMaterial.js";
12
+ import { LineGeometry } from "three/examples/jsm/lines/LineGeometry.js";
13
  import type { EpisodeData } from "@/app/[org]/[dataset]/[episode]/fetch-data";
14
+ import { fetchEpisodeChartData } from "@/app/[org]/[dataset]/[episode]/actions";
15
 
16
  const SERIES_DELIM = " | ";
17
  const SCALE = 10;
18
+ const DEG2RAD = Math.PI / 180;
19
 
20
  function getUrdfUrl(robotType: string | null): string {
21
  const lower = (robotType ?? "").toLowerCase();
 
23
  return "/urdf/so101/so101_new_calib.urdf";
24
  }
25
 
26
+ // Detect unit: servo ticks (0-4096), degrees (>6.28), or radians
27
  function detectAndConvert(values: number[]): number[] {
28
  if (values.length === 0) return values;
29
  const max = Math.max(...values.map(Math.abs));
30
+ if (max > 360) return values.map((v) => ((v - 2048) / 2048) * Math.PI); // servo ticks
31
+ if (max > 6.3) return values.map((v) => v * DEG2RAD); // degrees
32
+ return values; // already radians
33
  }
34
 
 
35
  function groupColumnsByPrefix(keys: string[]): Record<string, string[]> {
36
  const groups: Record<string, string[]> = {};
37
  for (const key of keys) {
 
44
  return groups;
45
  }
46
 
 
47
  function autoMatchJoints(urdfJointNames: string[], columnKeys: string[]): Record<string, string> {
48
  const mapping: Record<string, string> = {};
49
  for (const jointName of urdfJointNames) {
 
59
  return mapping;
60
  }
61
 
62
+ // Tip link names to try (so101 then so100 naming)
63
+ const TIP_LINK_NAMES = ["gripper_frame_link", "gripperframe", "gripper_link", "gripper"];
64
+ const TRAIL_DURATION = 1.0; // seconds
65
+ const TRAIL_COLOR = new THREE.Color("#ff6600");
66
+ const MAX_TRAIL_POINTS = 300;
67
+
68
  // ─── Robot scene (imperative, inside Canvas) ───
69
  function RobotScene({
70
+ urdfUrl, jointValues, onJointsLoaded, trailEnabled, trailResetKey,
 
 
71
  }: {
72
  urdfUrl: string;
73
  jointValues: Record<string, number>;
74
  onJointsLoaded: (names: string[]) => void;
75
+ trailEnabled: boolean;
76
+ trailResetKey: number;
77
  }) {
78
+ const { scene, size } = useThree();
79
  const robotRef = useRef<URDFRobot | null>(null);
80
+ const tipLinkRef = useRef<THREE.Object3D | null>(null);
81
  const [loading, setLoading] = useState(true);
82
  const [error, setError] = useState<string | null>(null);
83
 
84
+ // Trail state
85
+ const trailRef = useRef<{ positions: Float32Array; colors: Float32Array; times: number[]; count: number }>({
86
+ positions: new Float32Array(MAX_TRAIL_POINTS * 3),
87
+ colors: new Float32Array(MAX_TRAIL_POINTS * 3), // RGB, no alpha
88
+ times: [],
89
+ count: 0,
90
+ });
91
+ const lineRef = useRef<Line2 | null>(null);
92
+ const trailMatRef = useRef<LineMaterial | null>(null);
93
+
94
+ // Reset trail when episode changes
95
+ useEffect(() => {
96
+ trailRef.current.count = 0;
97
+ trailRef.current.times = [];
98
+ if (lineRef.current) lineRef.current.visible = false;
99
+ }, [trailResetKey]);
100
+
101
+ // Create trail Line2 object
102
+ useEffect(() => {
103
+ const geometry = new LineGeometry();
104
+ const material = new LineMaterial({
105
+ color: 0xffffff,
106
+ linewidth: 4, // pixels
107
+ vertexColors: true,
108
+ transparent: true,
109
+ worldUnits: false,
110
+ });
111
+ material.resolution.set(window.innerWidth, window.innerHeight);
112
+ trailMatRef.current = material;
113
+
114
+ const line = new Line2(geometry, material);
115
+ line.frustumCulled = false;
116
+ line.visible = false;
117
+ lineRef.current = line;
118
+ scene.add(line);
119
+
120
+ return () => { scene.remove(line); geometry.dispose(); material.dispose(); };
121
+ }, [scene]);
122
+
123
  useEffect(() => {
124
  setLoading(true);
125
  setError(null);
 
126
  const manager = new THREE.LoadingManager();
127
  const loader = new URDFLoader(manager);
 
128
  loader.loadMeshCb = (url, mgr, onLoad) => {
129
  const stlLoader = new STLLoader(mgr);
130
  stlLoader.load(
 
142
  (err) => onLoad(new THREE.Object3D(), err as Error),
143
  );
144
  };
 
145
  loader.load(
146
  urdfUrl,
147
  (robot) => {
 
152
  robot.scale.set(SCALE, SCALE, SCALE);
153
  scene.add(robot);
154
 
155
+ // Find the tip link for the trail
156
+ for (const name of TIP_LINK_NAMES) {
157
+ if (robot.frames[name]) { tipLinkRef.current = robot.frames[name]; break; }
158
+ }
159
+
160
  const revolute = Object.values(robot.joints)
161
  .filter((j) => j.jointType === "revolute" || j.jointType === "continuous")
162
  .map((j) => j.name);
 
164
  setLoading(false);
165
  },
166
  undefined,
167
+ (err) => { console.error("Error loading URDF:", err); setError(String(err)); setLoading(false); },
 
 
 
 
168
  );
 
169
  return () => {
170
+ if (robotRef.current) { scene.remove(robotRef.current); robotRef.current = null; }
171
+ tipLinkRef.current = null;
 
 
172
  };
173
  }, [urdfUrl, scene, onJointsLoaded]);
174
 
175
+ const tipWorldPos = useMemo(() => new THREE.Vector3(), []);
176
+
177
  useFrame(() => {
178
+ const robot = robotRef.current;
179
+ if (!robot) return;
180
+
181
+ // Apply joint values
182
  for (const [name, value] of Object.entries(jointValues)) {
183
+ robot.setJointValue(name, value);
184
+ }
185
+ robot.updateMatrixWorld(true);
186
+
187
+ // Update trail
188
+ const line = lineRef.current;
189
+ const tip = tipLinkRef.current;
190
+ if (!line || !tip || !trailEnabled) {
191
+ if (line) line.visible = false;
192
+ return;
193
+ }
194
+
195
+ // Keep resolution in sync with viewport
196
+ if (trailMatRef.current) trailMatRef.current.resolution.set(size.width, size.height);
197
+
198
+ tip.getWorldPosition(tipWorldPos);
199
+ const now = performance.now() / 1000;
200
+ const trail = trailRef.current;
201
+
202
+ // Add new point
203
+ if (trail.count < MAX_TRAIL_POINTS) {
204
+ trail.count++;
205
+ } else {
206
+ trail.positions.copyWithin(0, 3);
207
+ trail.colors.copyWithin(0, 3);
208
+ trail.times.shift();
209
+ }
210
+ const idx = trail.count - 1;
211
+ trail.positions[idx * 3] = tipWorldPos.x;
212
+ trail.positions[idx * 3 + 1] = tipWorldPos.y;
213
+ trail.positions[idx * 3 + 2] = tipWorldPos.z;
214
+ trail.times.push(now);
215
+
216
+ // Update colors: fade from orange β†’ black based on age
217
+ for (let i = 0; i < trail.count; i++) {
218
+ const age = now - trail.times[i];
219
+ const t = Math.max(0, 1 - age / TRAIL_DURATION);
220
+ trail.colors[i * 3] = TRAIL_COLOR.r * t;
221
+ trail.colors[i * 3 + 1] = TRAIL_COLOR.g * t;
222
+ trail.colors[i * 3 + 2] = TRAIL_COLOR.b * t;
223
  }
224
+
225
+ // Need at least 2 points for Line2
226
+ if (trail.count < 2) { line.visible = false; return; }
227
+
228
+ // Rebuild geometry (Line2 requires this)
229
+ const geo = new LineGeometry();
230
+ geo.setPositions(Array.from(trail.positions.subarray(0, trail.count * 3)));
231
+ geo.setColors(Array.from(trail.colors.subarray(0, trail.count * 3)));
232
+ line.geometry.dispose();
233
+ line.geometry = geo;
234
+ line.computeLineDistances();
235
+ line.visible = true;
236
  });
237
 
238
  if (loading) return <Html center><span className="text-white text-lg">Loading robot…</span></Html>;
 
240
  return null;
241
  }
242
 
243
+ // ─── Playback ticker ───
244
  function PlaybackDriver({
245
  playing, fps, totalFrames, frameRef, setFrame,
246
  }: {
 
250
  }) {
251
  const elapsed = useRef(0);
252
  const last = useRef(0);
 
253
  useEffect(() => {
254
  if (!playing) return;
255
  let raf: number;
 
260
  last.current = now;
261
  if (dt > 0 && dt < 0.5) {
262
  elapsed.current += dt;
263
+ const fd = Math.floor(elapsed.current * fps);
264
+ if (fd > 0) {
265
+ elapsed.current -= fd / fps;
266
+ frameRef.current = (frameRef.current + fd) % totalFrames;
267
  setFrame(frameRef.current);
268
  }
269
  }
 
273
  raf = requestAnimationFrame(tick);
274
  return () => cancelAnimationFrame(raf);
275
  }, [playing, fps, totalFrames, frameRef, setFrame]);
 
276
  return null;
277
  }
278
 
279
  // ═══════════════════════════════════════
280
  // ─── Main URDF Viewer ───
281
  // ═══════════════════════════════════════
282
+ export default function URDFViewer({
283
+ data,
284
+ org,
285
+ dataset,
286
+ }: {
287
+ data: EpisodeData;
288
+ org?: string;
289
+ dataset?: string;
290
+ }) {
291
+ const { datasetInfo, episodes } = data;
292
  const fps = datasetInfo.fps || 30;
293
  const urdfUrl = useMemo(() => getUrdfUrl(datasetInfo.robot_type), [datasetInfo.robot_type]);
294
 
295
+ // Episode selection & chart data
296
+ const [selectedEpisode, setSelectedEpisode] = useState(data.episodeId);
297
+ const [chartData, setChartData] = useState(data.flatChartData);
298
+ const [episodeLoading, setEpisodeLoading] = useState(false);
299
+ const chartDataCache = useRef<Record<number, Record<string, number>[]>>({
300
+ [data.episodeId]: data.flatChartData,
301
+ });
302
+
303
+ const handleEpisodeChange = useCallback((epId: number) => {
304
+ setSelectedEpisode(epId);
305
+ setFrame(0);
306
+ frameRef.current = 0;
307
+ setPlaying(false);
308
+
309
+ if (chartDataCache.current[epId]) {
310
+ setChartData(chartDataCache.current[epId]);
311
+ return;
312
+ }
313
+
314
+ if (!org || !dataset) return;
315
+ setEpisodeLoading(true);
316
+ fetchEpisodeChartData(org, dataset, epId)
317
+ .then((result) => {
318
+ chartDataCache.current[epId] = result;
319
+ setChartData(result);
320
+ })
321
+ .catch((err) => console.error("Failed to load episode:", err))
322
+ .finally(() => setEpisodeLoading(false));
323
+ }, [org, dataset]);
324
+
325
+ const totalFrames = chartData.length;
326
+
327
+ // URDF joint names
328
  const [urdfJointNames, setUrdfJointNames] = useState<string[]>([]);
329
  const onJointsLoaded = useCallback((names: string[]) => setUrdfJointNames(names), []);
330
 
331
+ // Feature groups
332
  const columnGroups = useMemo(() => {
333
  if (totalFrames === 0) return {};
334
+ return groupColumnsByPrefix(Object.keys(chartData[0]));
335
+ }, [chartData, totalFrames]);
336
 
337
  const groupNames = useMemo(() => Object.keys(columnGroups), [columnGroups]);
 
338
  const defaultGroup = useMemo(
339
  () =>
340
  groupNames.find((g) => g.toLowerCase().includes("state")) ??
 
345
 
346
  const [selectedGroup, setSelectedGroup] = useState(defaultGroup);
347
  useEffect(() => setSelectedGroup(defaultGroup), [defaultGroup]);
 
348
  const selectedColumns = columnGroups[selectedGroup] ?? [];
349
 
350
+ // Joint mapping
351
  const autoMapping = useMemo(
352
  () => autoMatchJoints(urdfJointNames, selectedColumns),
353
  [urdfJointNames, selectedColumns],
 
355
  const [mapping, setMapping] = useState<Record<string, string>>(autoMapping);
356
  useEffect(() => setMapping(autoMapping), [autoMapping]);
357
 
358
+ // Trail
359
+ const [trailEnabled, setTrailEnabled] = useState(true);
360
+
361
  // Playback
362
  const [frame, setFrame] = useState(0);
363
  const [playing, setPlaying] = useState(false);
 
372
  // Compute joint values for current frame
373
  const jointValues = useMemo(() => {
374
  if (totalFrames === 0 || urdfJointNames.length === 0) return {};
375
+ const row = chartData[Math.min(frame, totalFrames - 1)];
376
  const rawValues: number[] = [];
377
  const names: string[] = [];
378
 
 
388
  const values: Record<string, number> = {};
389
  names.forEach((n, i) => { values[n] = converted[i]; });
390
  return values;
391
+ }, [chartData, frame, mapping, totalFrames, urdfJointNames]);
392
 
393
  const currentTime = totalFrames > 0 ? (frame / fps).toFixed(2) : "0.00";
394
  const totalTime = (totalFrames / fps).toFixed(2);
395
 
396
+ if (data.flatChartData.length === 0) {
397
+ return <div className="text-slate-400 p-8 text-center">No trajectory data available.</div>;
398
  }
399
 
400
  return (
401
  <div className="flex-1 flex flex-col overflow-hidden">
402
  {/* 3D Viewport */}
403
+ <div className="flex-1 min-h-0 bg-slate-950 rounded-lg overflow-hidden border border-slate-700 relative">
404
+ {episodeLoading && (
405
+ <div className="absolute inset-0 z-10 flex items-center justify-center bg-slate-950/70">
406
+ <span className="text-white text-lg animate-pulse">Loading episode {selectedEpisode}…</span>
407
+ </div>
408
+ )}
409
  <Canvas camera={{ position: [0.3 * SCALE, 0.25 * SCALE, 0.3 * SCALE], fov: 45, near: 0.01, far: 100 }}>
410
  <ambientLight intensity={0.5} />
411
  <directionalLight position={[3, 5, 4]} intensity={1.2} />
412
  <directionalLight position={[-2, 3, -2]} intensity={0.4} />
413
  <hemisphereLight args={["#b1e1ff", "#444444", 0.4]} />
414
+ <RobotScene urdfUrl={urdfUrl} jointValues={jointValues} onJointsLoaded={onJointsLoaded} trailEnabled={trailEnabled} trailResetKey={selectedEpisode} />
415
  <Grid
416
+ args={[10, 10]} cellSize={0.2} cellThickness={0.5} cellColor="#334155"
417
+ sectionSize={1} sectionThickness={1} sectionColor="#475569"
418
+ fadeDistance={10} position={[0, 0, 0]}
 
 
 
 
 
 
419
  />
420
  <OrbitControls target={[0, 0.8, 0]} />
421
  <PlaybackDriver playing={playing} fps={fps} totalFrames={totalFrames} frameRef={frameRef} setFrame={setFrame} />
 
424
 
425
  {/* Controls */}
426
  <div className="bg-slate-800/90 border-t border-slate-700 p-3 space-y-3 shrink-0">
427
+ {/* Episode selector + Timeline */}
428
  <div className="flex items-center gap-3">
429
+ {/* Episode selector */}
430
+ <div className="flex items-center gap-1.5 shrink-0">
431
+ <button
432
+ onClick={() => { if (selectedEpisode > episodes[0]) handleEpisodeChange(selectedEpisode - 1); }}
433
+ disabled={selectedEpisode <= episodes[0]}
434
+ className="w-6 h-6 flex items-center justify-center rounded bg-slate-700 hover:bg-slate-600 text-slate-300 disabled:opacity-30 disabled:cursor-not-allowed text-xs"
435
+ >β—€</button>
436
+ <select
437
+ value={selectedEpisode}
438
+ onChange={(e) => handleEpisodeChange(Number(e.target.value))}
439
+ className="bg-slate-900 text-slate-200 text-xs rounded px-1.5 py-1 border border-slate-600 w-28"
440
+ >
441
+ {episodes.map((ep) => (
442
+ <option key={ep} value={ep}>Episode {ep}</option>
443
+ ))}
444
+ </select>
445
+ <button
446
+ onClick={() => { if (selectedEpisode < episodes[episodes.length - 1]) handleEpisodeChange(selectedEpisode + 1); }}
447
+ disabled={selectedEpisode >= episodes[episodes.length - 1]}
448
+ className="w-6 h-6 flex items-center justify-center rounded bg-slate-700 hover:bg-slate-600 text-slate-300 disabled:opacity-30 disabled:cursor-not-allowed text-xs"
449
+ >β–Ά</button>
450
+ </div>
451
+
452
+ {/* Play/Pause */}
453
  <button
454
  onClick={() => { setPlaying(!playing); if (!playing) frameRef.current = frame; }}
455
  className="w-8 h-8 flex items-center justify-center rounded bg-orange-600 hover:bg-orange-500 text-white transition-colors shrink-0"
 
460
  <svg width="12" height="14" viewBox="0 0 12 14"><polygon points="2,1 11,7 2,13" fill="white" /></svg>
461
  )}
462
  </button>
463
+
464
+ {/* Trail toggle */}
465
+ <button
466
+ onClick={() => setTrailEnabled((v) => !v)}
467
+ className={`px-2 h-8 text-xs rounded transition-colors shrink-0 ${
468
+ trailEnabled ? "bg-orange-600/30 text-orange-400 border border-orange-500" : "bg-slate-700 text-slate-400 border border-slate-600"
469
+ }`}
470
+ title={trailEnabled ? "Hide trail" : "Show trail"}
471
+ >Trail</button>
472
+
473
+ {/* Scrubber */}
474
  <input type="range" min={0} max={Math.max(totalFrames - 1, 0)} value={frame}
475
  onChange={handleFrameChange} className="flex-1 h-1.5 accent-orange-500 cursor-pointer" />
476
  <span className="text-xs text-slate-400 tabular-nums w-28 text-right shrink-0">{currentTime}s / {totalTime}s</span>
477
+ <span className="text-xs text-slate-500 tabular-nums w-20 text-right shrink-0">F {frame}/{Math.max(totalFrames - 1, 0)}</span>
478
  </div>
479
 
480
  {/* Data source + joint mapping */}