pepijn223 HF Staff commited on
Commit
39baaef
·
unverified ·
1 Parent(s): 6346799

Make episode visualizer nicer (colors, formatting) add detection for still epsiodes and fix urdf visualizer

Browse files
next.config.ts CHANGED
@@ -2,13 +2,13 @@ import type { NextConfig } from "next";
2
  import packageJson from './package.json';
3
 
4
  const nextConfig: NextConfig = {
5
-
6
  typescript: {
7
  ignoreBuildErrors: true,
8
  },
9
  eslint: {
10
  ignoreDuringBuilds: true,
11
  },
 
12
  generateBuildId: () => packageJson.version,
13
  };
14
 
 
2
  import packageJson from './package.json';
3
 
4
  const nextConfig: NextConfig = {
 
5
  typescript: {
6
  ignoreBuildErrors: true,
7
  },
8
  eslint: {
9
  ignoreDuringBuilds: true,
10
  },
11
+ transpilePackages: ["three"],
12
  generateBuildId: () => packageJson.version,
13
  };
14
 
package-lock.json CHANGED
@@ -17,7 +17,8 @@
17
  "react-dom": "^19.0.0",
18
  "react-icons": "^5.5.0",
19
  "recharts": "^2.15.3",
20
- "three": "^0.182.0"
 
21
  },
22
  "devDependencies": {
23
  "@eslint/eslintrc": "^3",
@@ -6849,6 +6850,15 @@
6849
  "@unrs/resolver-binding-win32-x64-msvc": "1.11.1"
6850
  }
6851
  },
 
 
 
 
 
 
 
 
 
6852
  "node_modules/uri-js": {
6853
  "version": "4.4.1",
6854
  "resolved": "https://registry.npmjs.org/uri-js/-/uri-js-4.4.1.tgz",
 
17
  "react-dom": "^19.0.0",
18
  "react-icons": "^5.5.0",
19
  "recharts": "^2.15.3",
20
+ "three": "^0.182.0",
21
+ "urdf-loader": "^0.12.6"
22
  },
23
  "devDependencies": {
24
  "@eslint/eslintrc": "^3",
 
6850
  "@unrs/resolver-binding-win32-x64-msvc": "1.11.1"
6851
  }
6852
  },
6853
+ "node_modules/urdf-loader": {
6854
+ "version": "0.12.6",
6855
+ "resolved": "https://registry.npmjs.org/urdf-loader/-/urdf-loader-0.12.6.tgz",
6856
+ "integrity": "sha512-EwpgOCPe6Tep2+MXoo/r13keHaKQXMcM+4s9+jX0NRxNS/QSNuP5JPdk5AIgWEoEB43AkEj9Vk+Nr53NkXgSbA==",
6857
+ "license": "Apache-2.0",
6858
+ "peerDependencies": {
6859
+ "three": ">=0.152.0"
6860
+ }
6861
+ },
6862
  "node_modules/uri-js": {
6863
  "version": "4.4.1",
6864
  "resolved": "https://registry.npmjs.org/uri-js/-/uri-js-4.4.1.tgz",
package.json CHANGED
@@ -19,7 +19,8 @@
19
  "react-dom": "^19.0.0",
20
  "react-icons": "^5.5.0",
21
  "recharts": "^2.15.3",
22
- "three": "^0.182.0"
 
23
  },
24
  "devDependencies": {
25
  "@eslint/eslintrc": "^3",
 
19
  "react-dom": "^19.0.0",
20
  "react-icons": "^5.5.0",
21
  "recharts": "^2.15.3",
22
+ "three": "^0.182.0",
23
+ "urdf-loader": "^0.12.6"
24
  },
25
  "devDependencies": {
26
  "@eslint/eslintrc": "^3",
src/app/[org]/[dataset]/[episode]/actions.ts CHANGED
@@ -5,8 +5,10 @@ import type { DatasetMetadata } from "@/utils/parquetUtils";
5
  import {
6
  loadAllEpisodeLengthsV3,
7
  loadAllEpisodeFrameInfo,
 
8
  type EpisodeLengthStats,
9
  type EpisodeFramesData,
 
10
  } from "./fetch-data";
11
 
12
  export async function fetchEpisodeLengthStats(
@@ -28,3 +30,12 @@ export async function fetchEpisodeFrames(
28
  return loadAllEpisodeFrameInfo(repoId, version, info as unknown as DatasetMetadata);
29
  }
30
 
 
 
 
 
 
 
 
 
 
 
5
  import {
6
  loadAllEpisodeLengthsV3,
7
  loadAllEpisodeFrameInfo,
8
+ loadCrossEpisodeActionVariance,
9
  type EpisodeLengthStats,
10
  type EpisodeFramesData,
11
+ type CrossEpisodeVarianceData,
12
  } from "./fetch-data";
13
 
14
  export async function fetchEpisodeLengthStats(
 
30
  return loadAllEpisodeFrameInfo(repoId, version, info as unknown as DatasetMetadata);
31
  }
32
 
33
+ export async function fetchCrossEpisodeVariance(
34
+ org: string,
35
+ dataset: string,
36
+ ): Promise<CrossEpisodeVarianceData | null> {
37
+ const repoId = `${org}/${dataset}`;
38
+ const { version, info } = await getDatasetVersionAndInfo(repoId);
39
+ return loadCrossEpisodeActionVariance(repoId, version, info as unknown as DatasetMetadata, info.fps);
40
+ }
41
+
src/app/[org]/[dataset]/[episode]/episode-viewer.tsx CHANGED
@@ -19,12 +19,14 @@ import {
19
  type ColumnMinMax,
20
  type EpisodeLengthStats,
21
  type EpisodeFramesData,
 
22
  } from "./fetch-data";
23
- import { fetchEpisodeLengthStats, fetchEpisodeFrames } from "./actions";
24
 
25
  const URDFViewer = lazy(() => import("@/components/urdf-viewer"));
 
26
 
27
- type ActiveTab = "episodes" | "statistics" | "frames" | "urdf";
28
 
29
  export default function EpisodeViewer({
30
  data,
@@ -87,6 +89,9 @@ function EpisodeViewerInner({ data, org, dataset }: { data: EpisodeData; org?: s
87
  const [episodeFramesData, setEpisodeFramesData] = useState<EpisodeFramesData | null>(null);
88
  const [framesLoading, setFramesLoading] = useState(false);
89
  const framesLoadedRef = useRef(false);
 
 
 
90
 
91
  const loadStats = () => {
92
  if (statsLoadedRef.current) return;
@@ -113,10 +118,21 @@ function EpisodeViewerInner({ data, org, dataset }: { data: EpisodeData; org?: s
113
  .finally(() => setFramesLoading(false));
114
  };
115
 
 
 
 
 
 
 
 
 
 
 
116
  const handleTabChange = (tab: ActiveTab) => {
117
  setActiveTab(tab);
118
  if (tab === "statistics") loadStats();
119
  if (tab === "frames") loadFrames();
 
120
  };
121
 
122
  // Use context for time sync
@@ -288,6 +304,19 @@ function EpisodeViewerInner({ data, org, dataset }: { data: EpisodeData; org?: s
288
  <span className="absolute bottom-0 left-0 right-0 h-0.5 bg-orange-500" />
289
  )}
290
  </button>
 
 
 
 
 
 
 
 
 
 
 
 
 
291
  {isSO101Robot(datasetInfo.robot_type) && (
292
  <button
293
  className={`px-6 py-2.5 text-sm font-medium transition-colors relative ${
@@ -405,6 +434,17 @@ function EpisodeViewerInner({ data, org, dataset }: { data: EpisodeData; org?: s
405
  <OverviewPanel data={episodeFramesData} loading={framesLoading} />
406
  )}
407
 
 
 
 
 
 
 
 
 
 
 
 
408
  {activeTab === "urdf" && (
409
  <Suspense fallback={<Loading />}>
410
  <URDFViewer data={data} />
 
19
  type ColumnMinMax,
20
  type EpisodeLengthStats,
21
  type EpisodeFramesData,
22
+ type CrossEpisodeVarianceData,
23
  } from "./fetch-data";
24
+ import { fetchEpisodeLengthStats, fetchEpisodeFrames, fetchCrossEpisodeVariance } from "./actions";
25
 
26
  const URDFViewer = lazy(() => import("@/components/urdf-viewer"));
27
+ const ActionInsightsPanel = lazy(() => import("@/components/action-insights-panel"));
28
 
29
+ type ActiveTab = "episodes" | "statistics" | "frames" | "insights" | "urdf";
30
 
31
  export default function EpisodeViewer({
32
  data,
 
89
  const [episodeFramesData, setEpisodeFramesData] = useState<EpisodeFramesData | null>(null);
90
  const [framesLoading, setFramesLoading] = useState(false);
91
  const framesLoadedRef = useRef(false);
92
+ const [crossEpData, setCrossEpData] = useState<CrossEpisodeVarianceData | null>(null);
93
+ const [insightsLoading, setInsightsLoading] = useState(false);
94
+ const insightsLoadedRef = useRef(false);
95
 
96
  const loadStats = () => {
97
  if (statsLoadedRef.current) return;
 
118
  .finally(() => setFramesLoading(false));
119
  };
120
 
121
+ const loadInsights = () => {
122
+ if (insightsLoadedRef.current || !org || !dataset) return;
123
+ insightsLoadedRef.current = true;
124
+ setInsightsLoading(true);
125
+ fetchCrossEpisodeVariance(org, dataset)
126
+ .then(setCrossEpData)
127
+ .catch((err) => console.error("[cross-ep] Failed:", err))
128
+ .finally(() => setInsightsLoading(false));
129
+ };
130
+
131
  const handleTabChange = (tab: ActiveTab) => {
132
  setActiveTab(tab);
133
  if (tab === "statistics") loadStats();
134
  if (tab === "frames") loadFrames();
135
+ if (tab === "insights") loadInsights();
136
  };
137
 
138
  // Use context for time sync
 
304
  <span className="absolute bottom-0 left-0 right-0 h-0.5 bg-orange-500" />
305
  )}
306
  </button>
307
+ <button
308
+ className={`px-6 py-2.5 text-sm font-medium transition-colors relative ${
309
+ activeTab === "insights"
310
+ ? "text-orange-400"
311
+ : "text-slate-400 hover:text-slate-200"
312
+ }`}
313
+ onClick={() => handleTabChange("insights")}
314
+ >
315
+ Action Insights
316
+ {activeTab === "insights" && (
317
+ <span className="absolute bottom-0 left-0 right-0 h-0.5 bg-orange-500" />
318
+ )}
319
+ </button>
320
  {isSO101Robot(datasetInfo.robot_type) && (
321
  <button
322
  className={`px-6 py-2.5 text-sm font-medium transition-colors relative ${
 
434
  <OverviewPanel data={episodeFramesData} loading={framesLoading} />
435
  )}
436
 
437
+ {activeTab === "insights" && (
438
+ <Suspense fallback={<Loading />}>
439
+ <ActionInsightsPanel
440
+ flatChartData={data.flatChartData}
441
+ fps={datasetInfo.fps}
442
+ crossEpisodeData={crossEpData}
443
+ crossEpisodeLoading={insightsLoading}
444
+ />
445
+ </Suspense>
446
+ )}
447
+
448
  {activeTab === "urdf" && (
449
  <Suspense fallback={<Loading />}>
450
  <URDFViewer data={data} />
src/app/[org]/[dataset]/[episode]/fetch-data.ts CHANGED
@@ -1337,6 +1337,309 @@ export async function loadAllEpisodeFrameInfo(
1337
  return { cameras, framesByCamera };
1338
  }
1339
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1340
  // Safe wrapper for UI error display
1341
  export async function getEpisodeDataSafe(
1342
  org: string,
 
1337
  return { cameras, framesByCamera };
1338
  }
1339
 
1340
+ // ─── Cross-episode action variance ──────────────────────────────
1341
+
1342
+ export type LowMovementEpisode = { episodeIndex: number; totalMovement: number };
1343
+
1344
+ export type AggVelocityStat = {
1345
+ name: string;
1346
+ std: number;
1347
+ maxAbs: number;
1348
+ bins: number[];
1349
+ lo: number;
1350
+ hi: number;
1351
+ };
1352
+
1353
+ export type AggAutocorrelation = {
1354
+ chartData: Record<string, number>[];
1355
+ suggestedChunk: number | null;
1356
+ shortKeys: string[];
1357
+ };
1358
+
1359
+ export type CrossEpisodeVarianceData = {
1360
+ actionNames: string[];
1361
+ timeBins: number[];
1362
+ variance: number[][];
1363
+ numEpisodes: number;
1364
+ lowMovementEpisodes: LowMovementEpisode[];
1365
+ aggVelocity: AggVelocityStat[];
1366
+ aggAutocorrelation: AggAutocorrelation | null;
1367
+ };
1368
+
1369
+ export async function loadCrossEpisodeActionVariance(
1370
+ repoId: string,
1371
+ version: string,
1372
+ info: DatasetMetadata,
1373
+ fps: number,
1374
+ maxEpisodes = 500,
1375
+ numTimeBins = 50,
1376
+ ): Promise<CrossEpisodeVarianceData | null> {
1377
+ const actionEntry = Object.entries(info.features)
1378
+ .find(([key, f]) => key === "action" && f.shape.length === 1);
1379
+ if (!actionEntry) {
1380
+ console.warn("[cross-ep] No action feature found. Available features:", Object.entries(info.features).map(([k, f]) => `${k}(${f.dtype}, shape=${JSON.stringify(f.shape)})`).join(", "));
1381
+ return null;
1382
+ }
1383
+
1384
+ const [actionKey, actionMeta] = actionEntry;
1385
+ const actionDim = actionMeta.shape[0];
1386
+
1387
+ let names: unknown = actionMeta.names;
1388
+ while (typeof names === "object" && names !== null && !Array.isArray(names)) {
1389
+ names = Object.values(names)[0];
1390
+ }
1391
+ const actionNames = Array.isArray(names)
1392
+ ? (names as string[]).map(n => `${actionKey}${SERIES_NAME_DELIMITER}${n}`)
1393
+ : Array.from({ length: actionDim }, (_, i) => `${actionKey}${SERIES_NAME_DELIMITER}${i}`);
1394
+
1395
+ // Collect episode metadata
1396
+ type EpMeta = { index: number; chunkIdx: number; fileIdx: number; from: number; to: number };
1397
+ const allEps: EpMeta[] = [];
1398
+
1399
+ if (version === "v3.0") {
1400
+ let fileIndex = 0;
1401
+ while (true) {
1402
+ const path = `meta/episodes/chunk-000/file-${fileIndex.toString().padStart(3, "0")}.parquet`;
1403
+ try {
1404
+ const buf = await fetchParquetFile(buildVersionedUrl(repoId, version, path));
1405
+ const rows = await readParquetAsObjects(buf, []);
1406
+ if (rows.length === 0 && fileIndex > 0) break;
1407
+ for (const row of rows) {
1408
+ const parsed = parseEpisodeRowSimple(row);
1409
+ allEps.push({
1410
+ index: parsed.episode_index,
1411
+ chunkIdx: parsed.data_chunk_index,
1412
+ fileIdx: parsed.data_file_index,
1413
+ from: parsed.dataset_from_index,
1414
+ to: parsed.dataset_to_index,
1415
+ });
1416
+ }
1417
+ fileIndex++;
1418
+ } catch { break; }
1419
+ }
1420
+ } else {
1421
+ for (let i = 0; i < info.total_episodes; i++) {
1422
+ allEps.push({ index: i, chunkIdx: 0, fileIdx: 0, from: 0, to: 0 });
1423
+ }
1424
+ }
1425
+
1426
+ if (allEps.length < 2) {
1427
+ console.warn(`[cross-ep] Only ${allEps.length} episode(s) found in metadata, need ≥2`);
1428
+ return null;
1429
+ }
1430
+ console.log(`[cross-ep] Found ${allEps.length} episodes in metadata, sampling up to ${maxEpisodes}`);
1431
+
1432
+ // Sample episodes evenly
1433
+ const sampled = allEps.length <= maxEpisodes
1434
+ ? allEps
1435
+ : Array.from({ length: maxEpisodes }, (_, i) =>
1436
+ allEps[Math.round((i * (allEps.length - 1)) / (maxEpisodes - 1))]
1437
+ );
1438
+
1439
+ // Load action data per episode, tracking episode index alongside
1440
+ const episodeActions: { index: number; actions: number[][] }[] = [];
1441
+
1442
+ if (version === "v3.0") {
1443
+ const byFile = new Map<string, EpMeta[]>();
1444
+ for (const ep of sampled) {
1445
+ const key = `${ep.chunkIdx}-${ep.fileIdx}`;
1446
+ if (!byFile.has(key)) byFile.set(key, []);
1447
+ byFile.get(key)!.push(ep);
1448
+ }
1449
+
1450
+ for (const [, eps] of byFile) {
1451
+ const ep0 = eps[0];
1452
+ const dataPath = `data/chunk-${ep0.chunkIdx.toString().padStart(3, "0")}/file-${ep0.fileIdx.toString().padStart(3, "0")}.parquet`;
1453
+ try {
1454
+ const buf = await fetchParquetFile(buildVersionedUrl(repoId, version, dataPath));
1455
+ const rows = await readParquetAsObjects(buf, []);
1456
+ const fileStart = rows.length > 0 && rows[0].index !== undefined ? Number(rows[0].index) : 0;
1457
+
1458
+ for (const ep of eps) {
1459
+ const localFrom = Math.max(0, ep.from - fileStart);
1460
+ const localTo = Math.min(rows.length, ep.to - fileStart);
1461
+ const actions: number[][] = [];
1462
+ for (let r = localFrom; r < localTo; r++) {
1463
+ const raw = rows[r]?.[actionKey];
1464
+ if (Array.isArray(raw)) actions.push(raw.map(Number));
1465
+ }
1466
+ if (actions.length > 0) episodeActions.push({ index: ep.index, actions });
1467
+ }
1468
+ } catch { /* skip file */ }
1469
+ }
1470
+ } else {
1471
+ const chunkSize = info.chunks_size || 1000;
1472
+ for (const ep of sampled) {
1473
+ const chunk = Math.floor(ep.index / chunkSize);
1474
+ const dataPath = formatStringWithVars(info.data_path, {
1475
+ episode_chunk: chunk.toString().padStart(3, "0"),
1476
+ episode_index: ep.index.toString().padStart(6, "0"),
1477
+ });
1478
+ try {
1479
+ const buf = await fetchParquetFile(buildVersionedUrl(repoId, version, dataPath));
1480
+ const rows = await readParquetAsObjects(buf, []);
1481
+ const actions: number[][] = [];
1482
+ for (const row of rows) {
1483
+ const raw = row[actionKey];
1484
+ if (Array.isArray(raw)) {
1485
+ actions.push(raw.map(Number));
1486
+ } else {
1487
+ const vec: number[] = [];
1488
+ for (let d = 0; d < actionDim; d++) {
1489
+ const v = row[`${actionKey}.${d}`] ?? row[d];
1490
+ vec.push(typeof v === "number" ? v : Number(v) || 0);
1491
+ }
1492
+ actions.push(vec);
1493
+ }
1494
+ }
1495
+ if (actions.length > 0) episodeActions.push({ index: ep.index, actions });
1496
+ } catch { /* skip */ }
1497
+ }
1498
+ }
1499
+
1500
+ if (episodeActions.length < 2) {
1501
+ console.warn(`[cross-ep] Only ${episodeActions.length} episode(s) had loadable action data out of ${sampled.length} sampled`);
1502
+ return null;
1503
+ }
1504
+ console.log(`[cross-ep] Loaded action data for ${episodeActions.length}/${sampled.length} episodes`);
1505
+
1506
+ // Resample each episode to numTimeBins and compute variance
1507
+ const timeBins = Array.from({ length: numTimeBins }, (_, i) => i / (numTimeBins - 1));
1508
+ const sums = Array.from({ length: numTimeBins }, () => new Float64Array(actionDim));
1509
+ const sumsSq = Array.from({ length: numTimeBins }, () => new Float64Array(actionDim));
1510
+ const counts = new Uint32Array(numTimeBins);
1511
+
1512
+ for (const { actions: epActions } of episodeActions) {
1513
+ const T = epActions.length;
1514
+ for (let b = 0; b < numTimeBins; b++) {
1515
+ const srcIdx = Math.min(Math.round(timeBins[b] * (T - 1)), T - 1);
1516
+ const row = epActions[srcIdx];
1517
+ for (let d = 0; d < actionDim; d++) {
1518
+ const v = row[d] ?? 0;
1519
+ sums[b][d] += v;
1520
+ sumsSq[b][d] += v * v;
1521
+ }
1522
+ counts[b]++;
1523
+ }
1524
+ }
1525
+
1526
+ const variance: number[][] = [];
1527
+ for (let b = 0; b < numTimeBins; b++) {
1528
+ const row: number[] = [];
1529
+ const n = counts[b];
1530
+ for (let d = 0; d < actionDim; d++) {
1531
+ if (n < 2) { row.push(0); continue; }
1532
+ const mean = sums[b][d] / n;
1533
+ row.push(sumsSq[b][d] / n - mean * mean);
1534
+ }
1535
+ variance.push(row);
1536
+ }
1537
+
1538
+ // Per-episode average movement per frame: mean L2 norm of frame-to-frame action deltas
1539
+ const movementScores: LowMovementEpisode[] = episodeActions.map(({ index, actions: ep }) => {
1540
+ if (ep.length < 2) return { episodeIndex: index, totalMovement: 0 };
1541
+ let total = 0;
1542
+ for (let t = 1; t < ep.length; t++) {
1543
+ let sumSq = 0;
1544
+ for (let d = 0; d < actionDim; d++) {
1545
+ const delta = (ep[t][d] ?? 0) - (ep[t - 1][d] ?? 0);
1546
+ sumSq += delta * delta;
1547
+ }
1548
+ total += Math.sqrt(sumSq);
1549
+ }
1550
+ const avgPerFrame = total / (ep.length - 1);
1551
+ return { episodeIndex: index, totalMovement: Math.round(avgPerFrame * 10000) / 10000 };
1552
+ });
1553
+
1554
+ movementScores.sort((a, b) => a.totalMovement - b.totalMovement);
1555
+ const lowMovementEpisodes = movementScores.slice(0, 10);
1556
+
1557
+ // Aggregated velocity stats: pool deltas from all episodes
1558
+ const shortName = (k: string) => { const p = k.split(SERIES_NAME_DELIMITER); return p.length > 1 ? p[p.length - 1] : k; };
1559
+
1560
+ const aggVelocity: AggVelocityStat[] = (() => {
1561
+ const binCount = 30;
1562
+ return Array.from({ length: actionDim }, (_, d) => {
1563
+ const deltas: number[] = [];
1564
+ for (const { actions: ep } of episodeActions) {
1565
+ for (let t = 1; t < ep.length; t++) {
1566
+ deltas.push((ep[t][d] ?? 0) - (ep[t - 1][d] ?? 0));
1567
+ }
1568
+ }
1569
+ if (deltas.length === 0) return { name: shortName(actionNames[d]), std: 0, maxAbs: 0, bins: [], lo: 0, hi: 0 };
1570
+ let sum = 0, maxAbs = 0, lo = Infinity, hi = -Infinity;
1571
+ for (const v of deltas) { sum += v; const a = Math.abs(v); if (a > maxAbs) maxAbs = a; if (v < lo) lo = v; if (v > hi) hi = v; }
1572
+ const mean = sum / deltas.length;
1573
+ let varSum = 0; for (const v of deltas) varSum += (v - mean) ** 2;
1574
+ const std = Math.sqrt(varSum / deltas.length);
1575
+ const range = hi - lo || 1;
1576
+ const binW = range / binCount;
1577
+ const bins = new Array(binCount).fill(0);
1578
+ for (const v of deltas) { let b = Math.floor((v - lo) / binW); if (b >= binCount) b = binCount - 1; bins[b]++; }
1579
+ return { name: shortName(actionNames[d]), std, maxAbs, bins, lo, hi };
1580
+ });
1581
+ })();
1582
+
1583
+ // Aggregated autocorrelation: average per-episode ACFs
1584
+ const aggAutocorrelation: AggAutocorrelation | null = (() => {
1585
+ const maxLag = Math.min(100, Math.floor(
1586
+ episodeActions.reduce((min, e) => Math.min(min, e.actions.length), Infinity) / 2
1587
+ ));
1588
+ if (maxLag < 2) return null;
1589
+
1590
+ const avgAcf: number[][] = Array.from({ length: actionDim }, () => new Array(maxLag).fill(0));
1591
+ let epCount = 0;
1592
+
1593
+ for (const { actions: ep } of episodeActions) {
1594
+ if (ep.length < maxLag * 2) continue;
1595
+ epCount++;
1596
+ for (let d = 0; d < actionDim; d++) {
1597
+ const vals = ep.map(row => row[d] ?? 0);
1598
+ const n = vals.length;
1599
+ const m = vals.reduce((a, b) => a + b, 0) / n;
1600
+ const centered = vals.map(v => v - m);
1601
+ const vari = centered.reduce((a, v) => a + v * v, 0);
1602
+ if (vari === 0) continue;
1603
+ for (let lag = 1; lag <= maxLag; lag++) {
1604
+ let s = 0;
1605
+ for (let t = 0; t < n - lag; t++) s += centered[t] * centered[t + lag];
1606
+ avgAcf[d][lag - 1] += s / vari;
1607
+ }
1608
+ }
1609
+ }
1610
+
1611
+ if (epCount === 0) return null;
1612
+ for (let d = 0; d < actionDim; d++) for (let l = 0; l < maxLag; l++) avgAcf[d][l] /= epCount;
1613
+
1614
+ const shortKeys = actionNames.map(shortName);
1615
+ const chartData = Array.from({ length: maxLag }, (_, lag) => {
1616
+ const row: Record<string, number> = { lag: lag + 1, time: (lag + 1) / fps };
1617
+ shortKeys.forEach((k, d) => { row[k] = avgAcf[d][lag]; });
1618
+ return row;
1619
+ });
1620
+
1621
+ // Suggested chunk: median lag where ACF drops below 0.5
1622
+ const lags = avgAcf.map(acf => { const i = acf.findIndex(v => v < 0.5); return i >= 0 ? i + 1 : null; }).filter(Boolean) as number[];
1623
+ const suggestedChunk = lags.length > 0 ? lags.sort((a, b) => a - b)[Math.floor(lags.length / 2)] : null;
1624
+
1625
+ return { chartData, suggestedChunk, shortKeys };
1626
+ })();
1627
+
1628
+ return { actionNames, timeBins, variance, numEpisodes: episodeActions.length, lowMovementEpisodes, aggVelocity, aggAutocorrelation };
1629
+ }
1630
+
1631
+ // Load only flatChartData for a specific episode (used by URDF viewer episode switching)
1632
+ export async function loadEpisodeFlatChartData(
1633
+ repoId: string,
1634
+ version: string,
1635
+ info: DatasetMetadata,
1636
+ episodeId: number,
1637
+ ): Promise<Record<string, number>[]> {
1638
+ const episodeMetadata = await loadEpisodeMetadataV3Simple(repoId, version, episodeId);
1639
+ const { flatChartData } = await loadEpisodeDataV3(repoId, version, info, episodeMetadata);
1640
+ return flatChartData;
1641
+ }
1642
+
1643
  // Safe wrapper for UI error display
1644
  export async function getEpisodeDataSafe(
1645
  org: string,
src/app/page.tsx CHANGED
@@ -130,14 +130,6 @@ function HomeInner() {
130
  <h1 className="text-4xl md:text-5xl font-bold mb-6 drop-shadow-lg">
131
  LeRobot Dataset Visualizer
132
  </h1>
133
- <a
134
- href="https://x.com/RemiCadene/status/1825455895561859185"
135
- target="_blank"
136
- rel="noopener noreferrer"
137
- className="text-sky-400 font-medium text-lg underline mb-8 inline-block hover:text-sky-300 transition-colors"
138
- >
139
- create & train your own robots
140
- </a>
141
  <form onSubmit={handleGo} className="flex gap-2 justify-center mt-6">
142
  <input
143
  ref={inputRef}
 
130
  <h1 className="text-4xl md:text-5xl font-bold mb-6 drop-shadow-lg">
131
  LeRobot Dataset Visualizer
132
  </h1>
 
 
 
 
 
 
 
 
133
  <form onSubmit={handleGo} className="flex gap-2 justify-center mt-6">
134
  <input
135
  ref={inputRef}
src/components/action-insights-panel.tsx ADDED
@@ -0,0 +1,545 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ "use client";
2
+
3
+ import React, { useMemo } from "react";
4
+ import {
5
+ LineChart,
6
+ Line,
7
+ XAxis,
8
+ YAxis,
9
+ CartesianGrid,
10
+ ResponsiveContainer,
11
+ Tooltip,
12
+ } from "recharts";
13
+ import type { CrossEpisodeVarianceData, LowMovementEpisode, AggVelocityStat, AggAutocorrelation } from "@/app/[org]/[dataset]/[episode]/fetch-data";
14
+
15
+ const DELIMITER = " | ";
16
+ const COLORS = [
17
+ "#f97316", "#3b82f6", "#22c55e", "#ef4444", "#a855f7",
18
+ "#eab308", "#06b6d4", "#ec4899", "#14b8a6", "#f59e0b",
19
+ "#6366f1", "#84cc16",
20
+ ];
21
+
22
+ function shortName(key: string): string {
23
+ const parts = key.split(DELIMITER);
24
+ return parts.length > 1 ? parts[parts.length - 1] : key;
25
+ }
26
+
27
+ function getActionKeys(row: Record<string, number>): string[] {
28
+ return Object.keys(row)
29
+ .filter(k => k.startsWith("action") && k !== "timestamp")
30
+ .sort();
31
+ }
32
+
33
+ // ─── Autocorrelation ─────────────────────────────────────────────
34
+
35
+ function computeAutocorrelation(values: number[], maxLag: number): number[] {
36
+ const n = values.length;
37
+ const mean = values.reduce((a, b) => a + b, 0) / n;
38
+ const centered = values.map(v => v - mean);
39
+ const variance = centered.reduce((a, v) => a + v * v, 0);
40
+ if (variance === 0) return Array(maxLag).fill(0);
41
+
42
+ const result: number[] = [];
43
+ for (let lag = 1; lag <= maxLag; lag++) {
44
+ let sum = 0;
45
+ for (let t = 0; t < n - lag; t++) sum += centered[t] * centered[t + lag];
46
+ result.push(sum / variance);
47
+ }
48
+ return result;
49
+ }
50
+
51
+ function findDecorrelationLag(acf: number[], threshold = 0.5): number | null {
52
+ const idx = acf.findIndex(v => v < threshold);
53
+ return idx >= 0 ? idx + 1 : null;
54
+ }
55
+
56
+ function AutocorrelationSection({ data, fps, agg, numEpisodes }: { data: Record<string, number>[]; fps: number; agg?: AggAutocorrelation | null; numEpisodes?: number }) {
57
+ const actionKeys = useMemo(() => (data.length > 0 ? getActionKeys(data[0]) : []), [data]);
58
+ const maxLag = useMemo(() => Math.min(Math.floor(data.length / 2), 100), [data]);
59
+
60
+ const fallback = useMemo(() => {
61
+ if (agg) return null;
62
+ if (actionKeys.length === 0 || maxLag < 2) return { chartData: [], suggestedChunk: null, shortKeys: [] as string[] };
63
+
64
+ const acfs = actionKeys.map(key => {
65
+ const values = data.map(row => row[key] ?? 0);
66
+ return computeAutocorrelation(values, maxLag);
67
+ });
68
+
69
+ const rows = Array.from({ length: maxLag }, (_, lag) => {
70
+ const row: Record<string, number> = { lag: lag + 1, time: (lag + 1) / fps };
71
+ actionKeys.forEach((key, ki) => { row[shortName(key)] = acfs[ki][lag]; });
72
+ return row;
73
+ });
74
+
75
+ const lags = acfs.map(acf => findDecorrelationLag(acf, 0.5)).filter(Boolean) as number[];
76
+ const suggested = lags.length > 0 ? lags.sort((a, b) => a - b)[Math.floor(lags.length / 2)] : null;
77
+
78
+ return { chartData: rows, suggestedChunk: suggested, shortKeys: actionKeys.map(shortName) };
79
+ }, [data, actionKeys, maxLag, fps, agg]);
80
+
81
+ const { chartData, suggestedChunk, shortKeys } = agg ?? fallback ?? { chartData: [], suggestedChunk: null, shortKeys: [] };
82
+ const numEpisodesLabel = agg ? ` (${numEpisodes} episodes sampled)` : " (current episode)";
83
+
84
+ if (shortKeys.length === 0) return <p className="text-slate-500 italic">No action columns found.</p>;
85
+
86
+ return (
87
+ <div className="bg-slate-800/60 rounded-lg p-5 border border-slate-700 space-y-4">
88
+ <div>
89
+ <h3 className="text-sm font-semibold text-slate-200">Action Autocorrelation<span className="text-xs text-slate-500 ml-2 font-normal">{numEpisodesLabel}</span></h3>
90
+ <p className="text-xs text-slate-400 mt-1">
91
+ Shows how correlated each action dimension is with itself over increasing time lags.
92
+ Where autocorrelation drops below 0.5 suggests a <span className="text-orange-400 font-medium">natural action chunk boundary</span> — actions
93
+ beyond this lag are essentially independent, so executing them open-loop offers diminishing returns.
94
+ <br />
95
+ <span className="text-slate-500">
96
+ Grounded in the theoretical result that chunk length should scale logarithmically with system stability constants
97
+ (Zhang et al., 2025 — arXiv:2507.09061, Theorem 1).
98
+ </span>
99
+ </p>
100
+ </div>
101
+
102
+ {suggestedChunk && (
103
+ <div className="flex items-center gap-3 bg-orange-500/10 border border-orange-500/30 rounded-md px-4 py-2.5">
104
+ <span className="text-orange-400 font-bold text-lg tabular-nums">{suggestedChunk}</span>
105
+ <div>
106
+ <p className="text-sm text-orange-300 font-medium">
107
+ Suggested chunk length: {suggestedChunk} steps ({(suggestedChunk / fps).toFixed(2)}s)
108
+ </p>
109
+ <p className="text-xs text-slate-400">Median lag where autocorrelation drops below 0.5 across action dimensions</p>
110
+ </div>
111
+ </div>
112
+ )}
113
+
114
+ <div className="h-64">
115
+ <ResponsiveContainer width="100%" height="100%">
116
+ <LineChart data={chartData} margin={{ top: 8, right: 16, left: 0, bottom: 16 }}>
117
+ <CartesianGrid strokeDasharray="3 3" stroke="#334155" />
118
+ <XAxis
119
+ dataKey="lag"
120
+ stroke="#94a3b8"
121
+ label={{ value: "Lag (steps)", position: "insideBottom", offset: -8, fill: "#94a3b8", fontSize: 11 }}
122
+ />
123
+ <YAxis stroke="#94a3b8" domain={[-0.2, 1]} />
124
+ <Tooltip
125
+ contentStyle={{ background: "#1e293b", border: "1px solid #475569", borderRadius: 6 }}
126
+ labelFormatter={(v) => `Lag ${v} (${(Number(v) / fps).toFixed(2)}s)`}
127
+ formatter={(v: number) => v.toFixed(3)}
128
+ />
129
+ <Line
130
+ dataKey={() => 0.5}
131
+ stroke="#64748b"
132
+ strokeDasharray="6 4"
133
+ dot={false}
134
+ name="0.5 threshold"
135
+ legendType="none"
136
+ isAnimationActive={false}
137
+ />
138
+ {shortKeys.map((name, i) => (
139
+ <Line
140
+ key={name}
141
+ dataKey={name}
142
+ stroke={COLORS[i % COLORS.length]}
143
+ dot={false}
144
+ strokeWidth={1.5}
145
+ legendType="none"
146
+ isAnimationActive={false}
147
+ />
148
+ ))}
149
+ </LineChart>
150
+ </ResponsiveContainer>
151
+ </div>
152
+
153
+ {/* Custom legend */}
154
+ <div className="flex flex-wrap gap-x-4 gap-y-1 px-1">
155
+ {shortKeys.map((name, i) => (
156
+ <div key={name} className="flex items-center gap-1.5">
157
+ <span className="w-3 h-[3px] rounded-full shrink-0" style={{ background: COLORS[i % COLORS.length] }} />
158
+ <span className="text-[11px] text-slate-400">{name}</span>
159
+ </div>
160
+ ))}
161
+ </div>
162
+ </div>
163
+ );
164
+ }
165
+
166
+ // ─── Action Velocity ─────────────────────────────────────────────
167
+
168
+ function ActionVelocitySection({ data, agg, numEpisodes }: { data: Record<string, number>[]; agg?: AggVelocityStat[]; numEpisodes?: number }) {
169
+ const actionKeys = useMemo(() => (data.length > 0 ? getActionKeys(data[0]) : []), [data]);
170
+
171
+ const fallbackStats = useMemo(() => {
172
+ if (agg && agg.length > 0) return null;
173
+ if (actionKeys.length === 0 || data.length < 2) return [];
174
+
175
+ return actionKeys.map(key => {
176
+ const values = data.map(row => row[key] ?? 0);
177
+ const deltas = values.slice(1).map((v, i) => v - values[i]);
178
+ const mean = deltas.reduce((a, b) => a + b, 0) / deltas.length;
179
+ const std = Math.sqrt(deltas.reduce((a, d) => a + (d - mean) ** 2, 0) / deltas.length);
180
+ const maxAbs = Math.max(...deltas.map(Math.abs));
181
+ const binCount = 30;
182
+ const lo = Math.min(...deltas);
183
+ const hi = Math.max(...deltas);
184
+ const range = hi - lo || 1;
185
+ const binW = range / binCount;
186
+ const bins: number[] = new Array(binCount).fill(0);
187
+ for (const d of deltas) { let b = Math.floor((d - lo) / binW); if (b >= binCount) b = binCount - 1; bins[b]++; }
188
+ return { name: shortName(key), std, maxAbs, bins, lo, hi };
189
+ });
190
+ }, [data, actionKeys, agg]);
191
+
192
+ const stats = (agg && agg.length > 0) ? agg : fallbackStats ?? [];
193
+ const isAgg = agg && agg.length > 0;
194
+
195
+ if (stats.length === 0) return <p className="text-slate-500 italic">No action data for velocity analysis.</p>;
196
+
197
+ const maxBinCount = Math.max(...stats.flatMap(s => s.bins));
198
+ const maxStd = Math.max(...stats.map(s => s.std));
199
+
200
+ const insight = useMemo(() => {
201
+ const smooth = stats.filter(s => s.std / maxStd < 0.4);
202
+ const moderate = stats.filter(s => s.std / maxStd >= 0.4 && s.std / maxStd < 0.7);
203
+ const jerky = stats.filter(s => s.std / maxStd >= 0.7);
204
+ const isGripper = (n: string) => /grip/i.test(n);
205
+ const jerkyNonGripper = jerky.filter(s => !isGripper(s.name));
206
+ const jerkyGripper = jerky.filter(s => isGripper(s.name));
207
+ const smoothRatio = smooth.length / stats.length;
208
+
209
+ let verdict: { label: string; color: string };
210
+ if (smoothRatio >= 0.6 && jerkyNonGripper.length === 0)
211
+ verdict = { label: "Smooth", color: "text-green-400" };
212
+ else if (jerkyNonGripper.length <= 2 && smoothRatio >= 0.3)
213
+ verdict = { label: "Moderate", color: "text-yellow-400" };
214
+ else
215
+ verdict = { label: "Jerky", color: "text-red-400" };
216
+
217
+ const lines: string[] = [];
218
+ if (smooth.length > 0)
219
+ lines.push(`${smooth.length} smooth (${smooth.map(s => s.name).join(", ")})`);
220
+ if (moderate.length > 0)
221
+ lines.push(`${moderate.length} moderate (${moderate.map(s => s.name).join(", ")})`);
222
+ if (jerkyNonGripper.length > 0)
223
+ lines.push(`${jerkyNonGripper.length} jerky (${jerkyNonGripper.map(s => s.name).join(", ")})`);
224
+ if (jerkyGripper.length > 0)
225
+ lines.push(`${jerkyGripper.length} gripper${jerkyGripper.length > 1 ? "s" : ""} jerky — expected for binary open/close`);
226
+
227
+ let tip: string;
228
+ if (verdict.label === "Smooth")
229
+ tip = "Actions are consistent — longer action chunks should work well.";
230
+ else if (verdict.label === "Moderate")
231
+ tip = "Some dimensions show abrupt changes. Consider moderate chunk sizes.";
232
+ else
233
+ tip = "Many dimensions are jerky. Use shorter action chunks and consider filtering outlier episodes.";
234
+
235
+ return { verdict, lines, tip };
236
+ }, [stats, maxStd]);
237
+
238
+ return (
239
+ <div className="bg-slate-800/60 rounded-lg p-5 border border-slate-700 space-y-4">
240
+ <div>
241
+ <h3 className="text-sm font-semibold text-slate-200">Action Velocity (Δa) — Smoothness Proxy<span className="text-xs text-slate-500 ml-2 font-normal">{isAgg ? `(${numEpisodes} episodes sampled)` : "(current episode)"}</span></h3>
242
+ <p className="text-xs text-slate-400 mt-1">
243
+ Shows the distribution of frame-to-frame action changes (Δa = a<sub>t+1</sub> − a<sub>t</sub>) for each dimension.
244
+ A <span className="text-green-400">tight distribution around zero</span> means smooth, predictable control — the system
245
+ is likely stable and benefits from longer action chunks.
246
+ <span className="text-red-400"> Fat tails or high std</span> indicate jerky demonstrations, suggesting shorter chunks
247
+ and potentially beneficial noise injection.
248
+ <br />
249
+ <span className="text-slate-500">
250
+ Relates to the Lipschitz constant L<sub>π</sub> and smoothness C<sub>π</sub> in Zhang et al. (2025), which govern
251
+ compounding error bounds (Assumptions 3.1, 4.1).
252
+ </span>
253
+ </p>
254
+ </div>
255
+
256
+ {/* Per-dimension mini histograms + stats */}
257
+ <div className="grid gap-2" style={{ gridTemplateColumns: "repeat(auto-fill, minmax(180px, 1fr))" }}>
258
+ {stats.map((s, si) => {
259
+ const barH = 28;
260
+ return (
261
+ <div key={s.name} className="bg-slate-900/50 rounded-md px-2.5 py-2 space-y-1">
262
+ <p className="text-[11px] font-medium text-slate-200 truncate" title={s.name}>{s.name}</p>
263
+ <div className="flex gap-2 text-[9px] text-slate-400 tabular-nums">
264
+ <span>σ={s.std.toFixed(4)}</span>
265
+ <span>|Δ|<sub>max</sub>={s.maxAbs.toFixed(4)}</span>
266
+ </div>
267
+ <svg width="100%" viewBox={`0 0 ${s.bins.length} ${barH}`} preserveAspectRatio="none" className="h-7 rounded" aria-label={`Δa distribution for ${s.name}`}>
268
+ {[...s.bins].map((count, bi) => {
269
+ const h = maxBinCount > 0 ? (count / maxBinCount) * barH : 0;
270
+ return <rect key={bi} x={bi} y={barH - h} width={0.85} height={h} fill={COLORS[si % COLORS.length]} opacity={0.7} />;
271
+ })}
272
+ </svg>
273
+ <div className="h-1 w-full bg-slate-700 rounded-full overflow-hidden">
274
+ <div
275
+ className="h-full rounded-full"
276
+ style={{
277
+ width: `${Math.min(100, (s.std / maxStd) * 100)}%`,
278
+ background: s.std / maxStd < 0.4 ? "#22c55e" : s.std / maxStd < 0.7 ? "#eab308" : "#ef4444",
279
+ }}
280
+ />
281
+ </div>
282
+ </div>
283
+ );
284
+ })}
285
+ </div>
286
+
287
+ <div className="bg-slate-900/60 rounded-md px-4 py-3 border border-slate-700/60 space-y-1.5">
288
+ <p className="text-sm font-medium text-slate-200">
289
+ Overall: <span className={insight.verdict.color}>{insight.verdict.label}</span>
290
+ </p>
291
+ <ul className="text-xs text-slate-400 space-y-0.5 list-disc list-inside">
292
+ {insight.lines.map((l, i) => <li key={i}>{l}</li>)}
293
+ </ul>
294
+ <p className="text-xs text-slate-500 pt-1">{insight.tip}</p>
295
+ </div>
296
+ </div>
297
+ );
298
+ }
299
+
300
+ // ─── Cross-Episode Variance Heatmap ──────────────────────────────
301
+
302
+ function VarianceHeatmap({ data, loading }: { data: CrossEpisodeVarianceData | null; loading: boolean }) {
303
+ if (loading) {
304
+ return (
305
+ <div className="bg-slate-800/60 rounded-lg p-5 border border-slate-700">
306
+ <h3 className="text-sm font-semibold text-slate-200 mb-2">Cross-Episode Action Variance</h3>
307
+ <div className="flex items-center gap-2 text-slate-400 text-sm py-8 justify-center">
308
+ <svg className="animate-spin h-4 w-4" viewBox="0 0 24 24" fill="none">
309
+ <circle className="opacity-25" cx="12" cy="12" r="10" stroke="currentColor" strokeWidth="4" />
310
+ <path className="opacity-75" fill="currentColor" d="M4 12a8 8 0 018-8V0C5.373 0 0 5.373 0 12h4z" />
311
+ </svg>
312
+ Loading cross-episode data (sampled up to 500 episodes)…
313
+ </div>
314
+ </div>
315
+ );
316
+ }
317
+
318
+ if (!data) {
319
+ return (
320
+ <div className="bg-slate-800/60 rounded-lg p-5 border border-slate-700">
321
+ <h3 className="text-sm font-semibold text-slate-200 mb-2">Cross-Episode Action Variance</h3>
322
+ <p className="text-slate-500 italic text-sm">Not enough episodes or no action data to compute variance.</p>
323
+ </div>
324
+ );
325
+ }
326
+
327
+ const { actionNames, timeBins, variance, numEpisodes } = data;
328
+ const numDims = actionNames.length;
329
+ const numBins = timeBins.length;
330
+
331
+ // Find global max variance for color scale
332
+ const maxVar = Math.max(...variance.flat(), 1e-10);
333
+
334
+ // Heatmap dimensions
335
+ const cellW = Math.max(6, Math.min(14, Math.floor(560 / numBins)));
336
+ const cellH = Math.max(20, Math.min(36, Math.floor(300 / numDims)));
337
+ const labelW = 100;
338
+ const svgW = labelW + numBins * cellW + 60; // 60 for color bar
339
+ const svgH = numDims * cellH + 40; // 40 for x-axis label
340
+
341
+ function varColor(v: number): string {
342
+ const t = Math.sqrt(v / maxVar); // sqrt for better visual spread
343
+ // Dark blue → teal → orange
344
+ const r = Math.round(t * 249);
345
+ const g = Math.round(t < 0.5 ? 80 + t * 200 : 180 - (t - 0.5) * 200);
346
+ const b = Math.round((1 - t) * 200 + 30);
347
+ return `rgb(${r},${g},${b})`;
348
+ }
349
+
350
+ return (
351
+ <div className="bg-slate-800/60 rounded-lg p-5 border border-slate-700 space-y-4">
352
+ <div>
353
+ <h3 className="text-sm font-semibold text-slate-200">
354
+ Cross-Episode Action Variance
355
+ <span className="text-xs text-slate-500 ml-2 font-normal">({numEpisodes} episodes sampled)</span>
356
+ </h3>
357
+ <p className="text-xs text-slate-400 mt-1">
358
+ Shows how much each action dimension varies across episodes at each point in time (normalized 0–100%).
359
+ <span className="text-orange-400"> High-variance regions</span> indicate multi-modal or inconsistent demonstrations —
360
+ generative policies (diffusion, flow-matching) and action chunking help here by modeling multiple modes.
361
+ <span className="text-blue-400"> Low-variance regions</span> indicate consistent behavior across demonstrations.
362
+ <br />
363
+ <span className="text-slate-500">
364
+ Relates to the &quot;coverage&quot; discussion in Zhang et al. (2025) — regions with low variance may lack the
365
+ exploratory coverage needed to prevent compounding errors (Section 4).
366
+ </span>
367
+ </p>
368
+ </div>
369
+
370
+ <div className="overflow-x-auto">
371
+ <svg width={svgW} height={svgH} className="block">
372
+ {/* Heatmap cells */}
373
+ {variance.map((row, bi) =>
374
+ row.map((v, di) => (
375
+ <rect
376
+ key={`${bi}-${di}`}
377
+ x={labelW + bi * cellW}
378
+ y={di * cellH}
379
+ width={cellW}
380
+ height={cellH}
381
+ fill={varColor(v)}
382
+ stroke="#1e293b"
383
+ strokeWidth={0.5}
384
+ >
385
+ <title>{`${shortName(actionNames[di])} @ ${(timeBins[bi] * 100).toFixed(0)}%: var=${v.toFixed(5)}`}</title>
386
+ </rect>
387
+ ))
388
+ )}
389
+
390
+ {/* Y-axis: action names */}
391
+ {actionNames.map((name, di) => (
392
+ <text
393
+ key={di}
394
+ x={labelW - 4}
395
+ y={di * cellH + cellH / 2}
396
+ textAnchor="end"
397
+ dominantBaseline="central"
398
+ className="fill-slate-400"
399
+ fontSize={Math.min(11, cellH - 4)}
400
+ >
401
+ {shortName(name)}
402
+ </text>
403
+ ))}
404
+
405
+ {/* X-axis labels */}
406
+ {[0, 0.25, 0.5, 0.75, 1].map(frac => {
407
+ const binIdx = Math.round(frac * (numBins - 1));
408
+ return (
409
+ <text
410
+ key={frac}
411
+ x={labelW + binIdx * cellW + cellW / 2}
412
+ y={numDims * cellH + 14}
413
+ textAnchor="middle"
414
+ className="fill-slate-400"
415
+ fontSize={9}
416
+ >
417
+ {(frac * 100).toFixed(0)}%
418
+ </text>
419
+ );
420
+ })}
421
+ <text
422
+ x={labelW + (numBins * cellW) / 2}
423
+ y={numDims * cellH + 30}
424
+ textAnchor="middle"
425
+ className="fill-slate-500"
426
+ fontSize={10}
427
+ >
428
+ Episode progress
429
+ </text>
430
+
431
+ {/* Color bar */}
432
+ {Array.from({ length: 10 }, (_, i) => {
433
+ const t = i / 9;
434
+ const barX = labelW + numBins * cellW + 16;
435
+ const barH = (numDims * cellH) / 10;
436
+ return (
437
+ <rect
438
+ key={i}
439
+ x={barX}
440
+ y={(9 - i) * barH}
441
+ width={12}
442
+ height={barH}
443
+ fill={varColor(t * maxVar)}
444
+ />
445
+ );
446
+ })}
447
+ <text
448
+ x={labelW + numBins * cellW + 34}
449
+ y={10}
450
+ className="fill-slate-500"
451
+ fontSize={8}
452
+ dominantBaseline="central"
453
+ >
454
+ high
455
+ </text>
456
+ <text
457
+ x={labelW + numBins * cellW + 34}
458
+ y={numDims * cellH - 4}
459
+ className="fill-slate-500"
460
+ fontSize={8}
461
+ dominantBaseline="central"
462
+ >
463
+ low
464
+ </text>
465
+ </svg>
466
+ </div>
467
+ </div>
468
+ );
469
+ }
470
+
471
+ // ─── Low-Movement Episodes ──────────────────────────────────────
472
+
473
+ function LowMovementSection({ episodes }: { episodes: LowMovementEpisode[] }) {
474
+ if (episodes.length === 0) return null;
475
+
476
+ const maxMovement = Math.max(...episodes.map(e => e.totalMovement), 1e-10);
477
+
478
+ return (
479
+ <div className="bg-slate-800/60 rounded-lg p-5 border border-slate-700 space-y-3">
480
+ <div>
481
+ <h3 className="text-sm font-semibold text-slate-200">Lowest-Movement Episodes</h3>
482
+ <p className="text-xs text-slate-400 mt-1">
483
+ Episodes with the lowest average action change per frame (mean ‖Δa<sub>t</sub>‖). Very low values may indicate the robot
484
+ was <span className="text-yellow-400">standing still</span> or the episode was recorded incorrectly.
485
+ </p>
486
+ </div>
487
+ <div className="grid gap-2" style={{ gridTemplateColumns: "repeat(auto-fill, minmax(220px, 1fr))" }}>
488
+ {episodes.map(ep => (
489
+ <div key={ep.episodeIndex} className="bg-slate-900/50 rounded-md px-3 py-2 flex items-center gap-3">
490
+ <span className="text-xs text-slate-300 font-medium shrink-0">ep {ep.episodeIndex}</span>
491
+ <div className="flex-1 min-w-0">
492
+ <div className="h-1.5 bg-slate-700 rounded-full overflow-hidden">
493
+ <div
494
+ className="h-full rounded-full"
495
+ style={{
496
+ width: `${Math.max(2, (ep.totalMovement / maxMovement) * 100)}%`,
497
+ background: ep.totalMovement / maxMovement < 0.15 ? "#ef4444" : ep.totalMovement / maxMovement < 0.4 ? "#eab308" : "#22c55e",
498
+ }}
499
+ />
500
+ </div>
501
+ </div>
502
+ <span className="text-[10px] text-slate-500 tabular-nums shrink-0">{ep.totalMovement.toFixed(2)}</span>
503
+ </div>
504
+ ))}
505
+ </div>
506
+ </div>
507
+ );
508
+ }
509
+
510
+ // ─── Main Panel ──────────────────────────────────────────────────
511
+
512
+ interface ActionInsightsPanelProps {
513
+ flatChartData: Record<string, number>[];
514
+ fps: number;
515
+ crossEpisodeData: CrossEpisodeVarianceData | null;
516
+ crossEpisodeLoading: boolean;
517
+ }
518
+
519
+ const ActionInsightsPanel: React.FC<ActionInsightsPanelProps> = ({
520
+ flatChartData,
521
+ fps,
522
+ crossEpisodeData,
523
+ crossEpisodeLoading,
524
+ }) => {
525
+ return (
526
+ <div className="max-w-5xl mx-auto py-6 space-y-8">
527
+ <div>
528
+ <h2 className="text-xl font-bold text-slate-100">Action Insights</h2>
529
+ <p className="text-sm text-slate-400 mt-1">
530
+ Data-driven analysis to guide action chunking, data quality assessment, and training configuration.
531
+ </p>
532
+ </div>
533
+
534
+ <AutocorrelationSection data={flatChartData} fps={fps} agg={crossEpisodeData?.aggAutocorrelation} numEpisodes={crossEpisodeData?.numEpisodes} />
535
+ <ActionVelocitySection data={flatChartData} agg={crossEpisodeData?.aggVelocity} numEpisodes={crossEpisodeData?.numEpisodes} />
536
+ <VarianceHeatmap data={crossEpisodeData} loading={crossEpisodeLoading} />
537
+ {crossEpisodeData?.lowMovementEpisodes && (
538
+ <LowMovementSection episodes={crossEpisodeData.lowMovementEpisodes} />
539
+ )}
540
+ </div>
541
+ );
542
+ };
543
+
544
+ export default ActionInsightsPanel;
545
+
src/components/data-recharts.tsx CHANGED
@@ -10,6 +10,7 @@ import {
10
  CartesianGrid,
11
  ResponsiveContainer,
12
  Tooltip,
 
13
  } from "recharts";
14
 
15
  type ChartRow = Record<string, number | Record<string, number>>;
@@ -21,9 +22,14 @@ type DataGraphProps = {
21
 
22
  import React, { useMemo } from "react";
23
 
24
- // Use the same delimiter as the data processing
25
  const SERIES_NAME_DELIMITER = " | ";
26
 
 
 
 
 
 
 
27
  export const DataRecharts = React.memo(
28
  ({ data, onChartsReady }: DataGraphProps) => {
29
  // Shared hoveredTime for all graphs
@@ -112,11 +118,10 @@ const SingleDataGraph = React.memo(
112
  }
113
  });
114
 
115
- // Assign a color per group (and for singles)
116
  const allGroups = [...Object.keys(groups), ...singles];
117
  const groupColorMap: Record<string, string> = {};
118
  allGroups.forEach((group, idx) => {
119
- groupColorMap[group] = `hsl(${idx * (360 / allGroups.length)}, 100%, 50%)`;
120
  });
121
 
122
  // Find the closest data point to the current time for highlighting
@@ -160,11 +165,10 @@ const SingleDataGraph = React.memo(
160
  }
161
  });
162
 
163
- // Assign a color per group (and for singles)
164
  const allGroups = [...Object.keys(groups), ...singles];
165
  const groupColorMap: Record<string, string> = {};
166
  allGroups.forEach((group, idx) => {
167
- groupColorMap[group] = `hsl(${idx * (360 / allGroups.length)}, 100%, 50%)`;
168
  });
169
 
170
  const isGroupChecked = (group: string) => groups[group].every(k => visibleKeys.includes(k));
@@ -187,58 +191,59 @@ const SingleDataGraph = React.memo(
187
  };
188
 
189
  return (
190
- <div className="grid grid-cols-[repeat(auto-fit,250px)] gap-4 mx-8">
191
- {/* Grouped keys */}
192
  {Object.entries(groups).map(([group, children]) => {
193
  const color = groupColorMap[group];
194
  return (
195
- <div key={group} className="mb-2">
196
- <label className="flex gap-2 cursor-pointer select-none font-semibold">
197
  <input
198
  type="checkbox"
199
  checked={isGroupChecked(group)}
200
  ref={el => { if (el) el.indeterminate = isGroupIndeterminate(group); }}
201
  onChange={() => handleGroupCheckboxChange(group)}
202
- className="size-3.5 mt-1"
203
  style={{ accentColor: color }}
204
  />
205
- <span className="text-sm w-40 text-white">{group}</span>
206
  </label>
207
- <div className="pl-7 flex flex-col gap-1 mt-1">
208
- {children.map((key) => (
209
- <label key={key} className="flex gap-2 cursor-pointer select-none">
210
- <input
211
- type="checkbox"
212
- checked={visibleKeys.includes(key)}
213
- onChange={() => handleCheckboxChange(key)}
214
- className="size-3.5 mt-1"
215
- style={{ accentColor: color }}
216
- />
217
- <span className={`text-xs break-all w-36 ${visibleKeys.includes(key) ? "text-white" : "text-gray-400"}`}>{key.slice(group.length + 1)}</span>
218
- <span className={`text-xs font-mono ml-auto ${visibleKeys.includes(key) ? "text-orange-300" : "text-gray-500"}`}>
219
- {typeof currentData[key] === "number" ? currentData[key].toFixed(2) : "--"}
220
- </span>
221
- </label>
222
- ))}
 
 
 
223
  </div>
224
  </div>
225
  );
226
  })}
227
- {/* Singles (non-grouped) */}
228
  {singles.map((key) => {
229
  const color = groupColorMap[key];
230
  return (
231
- <label key={key} className="flex gap-2 cursor-pointer select-none">
232
  <input
233
  type="checkbox"
234
  checked={visibleKeys.includes(key)}
235
  onChange={() => handleCheckboxChange(key)}
236
- className="size-3.5 mt-1"
237
  style={{ accentColor: color }}
238
  />
239
- <span className={`text-sm break-all w-40 ${visibleKeys.includes(key) ? "text-white" : "text-gray-400"}`}>{key}</span>
240
- <span className={`text-sm font-mono ml-auto ${visibleKeys.includes(key) ? "text-orange-300" : "text-gray-500"}`}>
241
- {typeof currentData[key] === "number" ? currentData[key].toFixed(2) : "--"}
242
  </span>
243
  </label>
244
  );
@@ -247,14 +252,30 @@ const SingleDataGraph = React.memo(
247
  );
248
  };
249
 
 
 
 
 
 
 
 
 
 
 
 
 
 
250
  return (
251
- <div className="w-full">
252
- <div className="w-full h-80" onMouseLeave={handleMouseLeave}>
 
 
 
253
  <ResponsiveContainer width="100%" height="100%">
254
  <LineChart
255
  data={chartData}
256
  syncId="episode-sync"
257
- margin={{ top: 24, right: 16, left: 0, bottom: 16 }}
258
  onClick={handleClick}
259
  onMouseMove={(state) => {
260
  const payload = state?.activePayload?.[0]?.payload as { timestamp?: number } | undefined;
@@ -262,33 +283,24 @@ const SingleDataGraph = React.memo(
262
  }}
263
  onMouseLeave={handleMouseLeave}
264
  >
265
- <CartesianGrid strokeDasharray="3 3" stroke="#444" />
266
  <XAxis
267
  dataKey="timestamp"
268
- label={{
269
- value: "time",
270
- position: "insideBottomLeft",
271
- fill: "#cbd5e1",
272
- }}
273
  domain={[
274
  chartData.at(0)?.timestamp ?? 0,
275
  chartData.at(-1)?.timestamp ?? 0,
276
  ]}
277
- ticks={useMemo(
278
- () =>
279
- Array.from(
280
- new Set(chartData.map((d) => Math.ceil(d.timestamp))),
281
- ),
282
- [chartData],
283
- )}
284
- stroke="#cbd5e1"
285
- minTickGap={20} // Increased for fewer ticks
286
  allowDataOverflow={true}
287
  />
288
  <YAxis
289
  domain={["auto", "auto"]}
290
- stroke="#cbd5e1"
291
- interval={0}
 
292
  allowDataOverflow={true}
293
  />
294
 
@@ -301,9 +313,14 @@ const SingleDataGraph = React.memo(
301
  }
302
  />
303
 
304
- {/* Render lines for visible dataKeys only */}
 
 
 
 
 
 
305
  {dataKeys.map((key) => {
306
- // Use group color for all keys in a group
307
  const group = key.includes(SERIES_NAME_DELIMITER) ? key.split(SERIES_NAME_DELIMITER)[0] : key;
308
  const color = groupColorMap[group];
309
  let strokeDasharray: string | undefined = undefined;
 
10
  CartesianGrid,
11
  ResponsiveContainer,
12
  Tooltip,
13
+ ReferenceLine,
14
  } from "recharts";
15
 
16
  type ChartRow = Record<string, number | Record<string, number>>;
 
22
 
23
  import React, { useMemo } from "react";
24
 
 
25
  const SERIES_NAME_DELIMITER = " | ";
26
 
27
+ const CHART_COLORS = [
28
+ "#f97316", "#3b82f6", "#22c55e", "#ef4444", "#a855f7",
29
+ "#eab308", "#06b6d4", "#ec4899", "#14b8a6", "#f59e0b",
30
+ "#6366f1", "#84cc16",
31
+ ];
32
+
33
  export const DataRecharts = React.memo(
34
  ({ data, onChartsReady }: DataGraphProps) => {
35
  // Shared hoveredTime for all graphs
 
118
  }
119
  });
120
 
 
121
  const allGroups = [...Object.keys(groups), ...singles];
122
  const groupColorMap: Record<string, string> = {};
123
  allGroups.forEach((group, idx) => {
124
+ groupColorMap[group] = CHART_COLORS[idx % CHART_COLORS.length];
125
  });
126
 
127
  // Find the closest data point to the current time for highlighting
 
165
  }
166
  });
167
 
 
168
  const allGroups = [...Object.keys(groups), ...singles];
169
  const groupColorMap: Record<string, string> = {};
170
  allGroups.forEach((group, idx) => {
171
+ groupColorMap[group] = CHART_COLORS[idx % CHART_COLORS.length];
172
  });
173
 
174
  const isGroupChecked = (group: string) => groups[group].every(k => visibleKeys.includes(k));
 
191
  };
192
 
193
  return (
194
+ <div className="flex flex-wrap gap-x-5 gap-y-2 px-1 pt-2">
 
195
  {Object.entries(groups).map(([group, children]) => {
196
  const color = groupColorMap[group];
197
  return (
198
+ <div key={group}>
199
+ <label className="flex items-center gap-1.5 cursor-pointer select-none">
200
  <input
201
  type="checkbox"
202
  checked={isGroupChecked(group)}
203
  ref={el => { if (el) el.indeterminate = isGroupIndeterminate(group); }}
204
  onChange={() => handleGroupCheckboxChange(group)}
205
+ className="size-3"
206
  style={{ accentColor: color }}
207
  />
208
+ <span className="text-[11px] font-semibold text-slate-200">{group}</span>
209
  </label>
210
+ <div className="pl-5 flex flex-col gap-0.5 mt-0.5">
211
+ {children.map((key) => {
212
+ const label = key.split(SERIES_NAME_DELIMITER).pop() ?? key;
213
+ return (
214
+ <label key={key} className="flex items-center gap-1.5 cursor-pointer select-none">
215
+ <input
216
+ type="checkbox"
217
+ checked={visibleKeys.includes(key)}
218
+ onChange={() => handleCheckboxChange(key)}
219
+ className="size-2.5"
220
+ style={{ accentColor: color }}
221
+ />
222
+ <span className={`text-[10px] ${visibleKeys.includes(key) ? "text-slate-300" : "text-slate-500"}`}>{label}</span>
223
+ <span className={`text-[10px] font-mono tabular-nums ml-1 ${visibleKeys.includes(key) ? "text-orange-300/80" : "text-slate-600"}`}>
224
+ {typeof currentData[key] === "number" ? currentData[key].toFixed(2) : "–"}
225
+ </span>
226
+ </label>
227
+ );
228
+ })}
229
  </div>
230
  </div>
231
  );
232
  })}
 
233
  {singles.map((key) => {
234
  const color = groupColorMap[key];
235
  return (
236
+ <label key={key} className="flex items-center gap-1.5 cursor-pointer select-none">
237
  <input
238
  type="checkbox"
239
  checked={visibleKeys.includes(key)}
240
  onChange={() => handleCheckboxChange(key)}
241
+ className="size-3"
242
  style={{ accentColor: color }}
243
  />
244
+ <span className={`text-[11px] ${visibleKeys.includes(key) ? "text-slate-200" : "text-slate-500"}`}>{key}</span>
245
+ <span className={`text-[10px] font-mono tabular-nums ml-1 ${visibleKeys.includes(key) ? "text-orange-300/80" : "text-slate-600"}`}>
246
+ {typeof currentData[key] === "number" ? currentData[key].toFixed(2) : ""}
247
  </span>
248
  </label>
249
  );
 
252
  );
253
  };
254
 
255
+ // Derive chart title from the grouped feature names
256
+ const chartTitle = useMemo(() => {
257
+ const featureNames = Object.keys(groups);
258
+ if (featureNames.length > 0) {
259
+ const suffixes = featureNames.map(g => {
260
+ const parts = g.split(SERIES_NAME_DELIMITER);
261
+ return parts[parts.length - 1];
262
+ });
263
+ return suffixes.join(", ");
264
+ }
265
+ return singles.join(", ");
266
+ }, [groups, singles]);
267
+
268
  return (
269
+ <div className="w-full bg-slate-800/40 rounded-lg border border-slate-700/50 p-3">
270
+ {chartTitle && (
271
+ <p className="text-xs font-medium text-slate-300 mb-1 px-1 truncate" title={chartTitle}>{chartTitle}</p>
272
+ )}
273
+ <div className="w-full h-72" onMouseLeave={handleMouseLeave}>
274
  <ResponsiveContainer width="100%" height="100%">
275
  <LineChart
276
  data={chartData}
277
  syncId="episode-sync"
278
+ margin={{ top: 12, right: 12, left: -8, bottom: 8 }}
279
  onClick={handleClick}
280
  onMouseMove={(state) => {
281
  const payload = state?.activePayload?.[0]?.payload as { timestamp?: number } | undefined;
 
283
  }}
284
  onMouseLeave={handleMouseLeave}
285
  >
286
+ <CartesianGrid strokeDasharray="3 3" stroke="#334155" strokeOpacity={0.6} />
287
  <XAxis
288
  dataKey="timestamp"
 
 
 
 
 
289
  domain={[
290
  chartData.at(0)?.timestamp ?? 0,
291
  chartData.at(-1)?.timestamp ?? 0,
292
  ]}
293
+ tickFormatter={(v: number) => `${v.toFixed(1)}s`}
294
+ stroke="#64748b"
295
+ tick={{ fontSize: 10, fill: "#94a3b8" }}
296
+ minTickGap={30}
 
 
 
 
 
297
  allowDataOverflow={true}
298
  />
299
  <YAxis
300
  domain={["auto", "auto"]}
301
+ stroke="#64748b"
302
+ tick={{ fontSize: 10, fill: "#94a3b8" }}
303
+ width={45}
304
  allowDataOverflow={true}
305
  />
306
 
 
313
  }
314
  />
315
 
316
+ <ReferenceLine
317
+ x={currentTime}
318
+ stroke="#f97316"
319
+ strokeWidth={1.5}
320
+ strokeOpacity={0.7}
321
+ />
322
+
323
  {dataKeys.map((key) => {
 
324
  const group = key.includes(SERIES_NAME_DELIMITER) ? key.split(SERIES_NAME_DELIMITER)[0] : key;
325
  const color = groupColorMap[group];
326
  let strokeDasharray: string | undefined = undefined;
src/components/urdf-viewer.tsx CHANGED
@@ -1,141 +1,24 @@
1
  "use client";
2
 
3
- import React, { useState, useEffect, useRef, useMemo, useCallback, Suspense } from "react";
4
- import { Canvas, useFrame, useLoader } from "@react-three/fiber";
5
- import { OrbitControls, Grid, Environment } from "@react-three/drei";
6
  import * as THREE from "three";
7
- import { STLLoader } from "three/addons/loaders/STLLoader.js";
8
- import {
9
- SO101_JOINTS,
10
- SO101_LINKS,
11
- MATERIAL_COLORS,
12
- autoMatchJoints,
13
- type JointDef,
14
- type MeshDef,
15
- } from "@/lib/so101-robot";
16
  import type { EpisodeData } from "@/app/[org]/[dataset]/[episode]/fetch-data";
17
 
18
  const SERIES_DELIM = " | ";
 
19
 
20
- // ─── STL Mesh component ───
21
- function STLMesh({ mesh }: { mesh: MeshDef }) {
22
- const geometry = useLoader(STLLoader, mesh.file);
23
- const color = MATERIAL_COLORS[mesh.material];
24
- return (
25
- <mesh
26
- geometry={geometry}
27
- position={mesh.origin.xyz}
28
- rotation={new THREE.Euler(...mesh.origin.rpy, "XYZ")}
29
- >
30
- <meshStandardMaterial
31
- color={color}
32
- metalness={mesh.material === "motor" ? 0.7 : 0.1}
33
- roughness={mesh.material === "motor" ? 0.3 : 0.6}
34
- />
35
- </mesh>
36
- );
37
- }
38
-
39
- // ─── Link visual: renders all meshes for a link ───
40
- function LinkVisual({ linkIndex }: { linkIndex: number }) {
41
- const link = SO101_LINKS[linkIndex];
42
- if (!link) return null;
43
- return (
44
- <>
45
- {link.meshes.map((mesh, i) => (
46
- <STLMesh key={i} mesh={mesh} />
47
- ))}
48
- </>
49
- );
50
  }
51
 
52
- // ─── Joint group: applies origin transform + joint rotation ───
53
- function JointGroup({
54
- joint,
55
- angle,
56
- linkIndex,
57
- children,
58
- }: {
59
- joint: JointDef;
60
- angle: number;
61
- linkIndex: number;
62
- children?: React.ReactNode;
63
- }) {
64
- const rotRef = useRef<THREE.Group>(null);
65
-
66
- useEffect(() => {
67
- if (rotRef.current) {
68
- rotRef.current.quaternion.setFromAxisAngle(new THREE.Vector3(...joint.axis), angle);
69
- }
70
- }, [angle, joint.axis]);
71
-
72
- return (
73
- <group position={joint.origin.xyz} rotation={new THREE.Euler(...joint.origin.rpy, "XYZ")}>
74
- <group ref={rotRef}>
75
- <LinkVisual linkIndex={linkIndex} />
76
- {children}
77
- </group>
78
- </group>
79
- );
80
- }
81
-
82
- // ─── Full robot arm ───
83
- function RobotArm({ angles }: { angles: Record<string, number> }) {
84
- return (
85
- <group>
86
- {/* Base link (no parent joint) */}
87
- <LinkVisual linkIndex={0} />
88
-
89
- {/* shoulder_pan → shoulder_link (1) */}
90
- <JointGroup joint={SO101_JOINTS[0]} angle={angles.shoulder_pan ?? 0} linkIndex={1}>
91
- {/* shoulder_lift → upper_arm_link (2) */}
92
- <JointGroup joint={SO101_JOINTS[1]} angle={angles.shoulder_lift ?? 0} linkIndex={2}>
93
- {/* elbow_flex → lower_arm_link (3) */}
94
- <JointGroup joint={SO101_JOINTS[2]} angle={angles.elbow_flex ?? 0} linkIndex={3}>
95
- {/* wrist_flex → wrist_link (4) */}
96
- <JointGroup joint={SO101_JOINTS[3]} angle={angles.wrist_flex ?? 0} linkIndex={4}>
97
- {/* wrist_roll → gripper_link (5) */}
98
- <JointGroup joint={SO101_JOINTS[4]} angle={angles.wrist_roll ?? 0} linkIndex={5}>
99
- {/* gripper → moving_jaw (6) */}
100
- <JointGroup joint={SO101_JOINTS[5]} angle={angles.gripper ?? 0} linkIndex={6} />
101
- </JointGroup>
102
- </JointGroup>
103
- </JointGroup>
104
- </JointGroup>
105
- </JointGroup>
106
- </group>
107
- );
108
- }
109
-
110
- // ─── Playback driver (advances frame inside Canvas render loop) ───
111
- function PlaybackDriver({
112
- playing,
113
- fps,
114
- totalFrames,
115
- frameRef,
116
- }: {
117
- playing: boolean;
118
- fps: number;
119
- totalFrames: number;
120
- frameRef: React.MutableRefObject<number>;
121
- }) {
122
- const elapsed = useRef(0);
123
- useFrame((_, delta) => {
124
- if (!playing) {
125
- elapsed.current = 0;
126
- return;
127
- }
128
- elapsed.current += delta;
129
- const frameDelta = Math.floor(elapsed.current * fps);
130
- if (frameDelta > 0) {
131
- elapsed.current -= frameDelta / fps;
132
- frameRef.current = (frameRef.current + frameDelta) % totalFrames;
133
- }
134
- });
135
- return null;
136
- }
137
-
138
- // ─── Detect raw servo values (0-4096) vs radians ───
139
  function detectAndConvert(values: number[]): number[] {
140
  if (values.length === 0) return values;
141
  const max = Math.max(...values.map(Math.abs));
@@ -143,7 +26,7 @@ function detectAndConvert(values: number[]): number[] {
143
  return values;
144
  }
145
 
146
- // ─── Group columns by feature prefix ───
147
  function groupColumnsByPrefix(keys: string[]): Record<string, string[]> {
148
  const groups: Record<string, string[]> = {};
149
  for (const key of keys) {
@@ -156,6 +39,144 @@ function groupColumnsByPrefix(keys: string[]): Record<string, string[]> {
156
  return groups;
157
  }
158
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
159
  // ═══════════════════════════════════════
160
  // ─── Main URDF Viewer ───
161
  // ═══════════════════════════════════════
@@ -163,7 +184,13 @@ export default function URDFViewer({ data }: { data: EpisodeData }) {
163
  const { flatChartData, datasetInfo } = data;
164
  const totalFrames = flatChartData.length;
165
  const fps = datasetInfo.fps || 30;
 
166
 
 
 
 
 
 
167
  const columnGroups = useMemo(() => {
168
  if (totalFrames === 0) return {};
169
  return groupColumnsByPrefix(Object.keys(flatChartData[0]));
@@ -183,47 +210,46 @@ export default function URDFViewer({ data }: { data: EpisodeData }) {
183
  useEffect(() => setSelectedGroup(defaultGroup), [defaultGroup]);
184
 
185
  const selectedColumns = columnGroups[selectedGroup] ?? [];
186
- const autoMapping = useMemo(() => autoMatchJoints(selectedColumns), [selectedColumns]);
 
 
 
 
 
187
  const [mapping, setMapping] = useState<Record<string, string>>(autoMapping);
188
  useEffect(() => setMapping(autoMapping), [autoMapping]);
189
 
 
190
  const [frame, setFrame] = useState(0);
191
  const [playing, setPlaying] = useState(false);
192
  const frameRef = useRef(0);
193
 
194
- useEffect(() => {
195
- if (!playing) return;
196
- const interval = setInterval(() => setFrame(frameRef.current), 33);
197
- return () => clearInterval(interval);
198
- }, [playing]);
199
-
200
  const handleFrameChange = useCallback((e: React.ChangeEvent<HTMLInputElement>) => {
201
  const f = parseInt(e.target.value);
202
  setFrame(f);
203
  frameRef.current = f;
204
  }, []);
205
 
206
- const jointAngles = useMemo(() => {
207
- if (totalFrames === 0) return {};
 
208
  const row = flatChartData[Math.min(frame, totalFrames - 1)];
209
  const rawValues: number[] = [];
210
- const jointNames: string[] = [];
211
 
212
- for (const joint of SO101_JOINTS) {
213
- const col = mapping[joint.name];
214
  if (col && typeof row[col] === "number") {
215
  rawValues.push(row[col]);
216
- jointNames.push(joint.name);
217
  }
218
  }
219
 
220
  const converted = detectAndConvert(rawValues);
221
- const angles: Record<string, number> = {};
222
- jointNames.forEach((name, i) => {
223
- angles[name] = converted[i];
224
- });
225
- return angles;
226
- }, [flatChartData, frame, mapping, totalFrames]);
227
 
228
  const currentTime = totalFrames > 0 ? (frame / fps).toFixed(2) : "0.00";
229
  const totalTime = (totalFrames / fps).toFixed(2);
@@ -236,66 +262,46 @@ export default function URDFViewer({ data }: { data: EpisodeData }) {
236
  <div className="flex-1 flex flex-col overflow-hidden">
237
  {/* 3D Viewport */}
238
  <div className="flex-1 min-h-0 bg-slate-950 rounded-lg overflow-hidden border border-slate-700">
239
- <Canvas camera={{ position: [0.35, 0.25, 0.3], fov: 45, near: 0.001, far: 10 }}>
240
  <ambientLight intensity={0.5} />
241
- <directionalLight position={[3, 5, 4]} intensity={1.2} castShadow />
242
  <directionalLight position={[-2, 3, -2]} intensity={0.4} />
243
  <hemisphereLight args={["#b1e1ff", "#444444", 0.4]} />
244
- <Suspense fallback={null}>
245
- <RobotArm angles={jointAngles} />
246
- </Suspense>
247
  <Grid
248
- args={[1, 1]}
249
- cellSize={0.02}
250
  cellThickness={0.5}
251
  cellColor="#334155"
252
- sectionSize={0.1}
253
  sectionThickness={1}
254
  sectionColor="#475569"
255
- fadeDistance={1}
256
  position={[0, 0, 0]}
257
  />
258
- <OrbitControls target={[0, 0.1, 0]} />
259
- <PlaybackDriver playing={playing} fps={fps} totalFrames={totalFrames} frameRef={frameRef} />
260
  </Canvas>
261
  </div>
262
 
263
- {/* Controls Panel */}
264
  <div className="bg-slate-800/90 border-t border-slate-700 p-3 space-y-3 shrink-0">
265
- {/* Playback bar */}
266
  <div className="flex items-center gap-3">
267
  <button
268
- onClick={() => {
269
- setPlaying(!playing);
270
- if (!playing) frameRef.current = frame;
271
- }}
272
  className="w-8 h-8 flex items-center justify-center rounded bg-orange-600 hover:bg-orange-500 text-white transition-colors shrink-0"
273
  >
274
  {playing ? (
275
- <svg width="12" height="14" viewBox="0 0 12 14">
276
- <rect x="1" y="1" width="3" height="12" fill="white" />
277
- <rect x="8" y="1" width="3" height="12" fill="white" />
278
- </svg>
279
  ) : (
280
- <svg width="12" height="14" viewBox="0 0 12 14">
281
- <polygon points="2,1 11,7 2,13" fill="white" />
282
- </svg>
283
  )}
284
  </button>
285
- <input
286
- type="range"
287
- min={0}
288
- max={Math.max(totalFrames - 1, 0)}
289
- value={frame}
290
- onChange={handleFrameChange}
291
- className="flex-1 h-1.5 accent-orange-500 cursor-pointer"
292
- />
293
- <span className="text-xs text-slate-400 tabular-nums w-28 text-right shrink-0">
294
- {currentTime}s / {totalTime}s
295
- </span>
296
- <span className="text-xs text-slate-500 tabular-nums w-20 text-right shrink-0">
297
- F {frame}/{totalFrames - 1}
298
- </span>
299
  </div>
300
 
301
  {/* Data source + joint mapping */}
@@ -304,17 +310,10 @@ export default function URDFViewer({ data }: { data: EpisodeData }) {
304
  <label className="text-xs text-slate-400">Data source</label>
305
  <div className="flex gap-1 flex-wrap">
306
  {groupNames.map((name) => (
307
- <button
308
- key={name}
309
- onClick={() => setSelectedGroup(name)}
310
  className={`px-2 py-1 text-xs rounded transition-colors ${
311
- selectedGroup === name
312
- ? "bg-orange-600 text-white"
313
- : "bg-slate-700 text-slate-300 hover:bg-slate-600"
314
- }`}
315
- >
316
- {name}
317
- </button>
318
  ))}
319
  </div>
320
  </div>
@@ -330,29 +329,23 @@ export default function URDFViewer({ data }: { data: EpisodeData }) {
330
  </tr>
331
  </thead>
332
  <tbody>
333
- {SO101_JOINTS.map((joint) => (
334
- <tr key={joint.name} className="border-t border-slate-700/50">
335
- <td className="px-1 py-0.5 text-slate-300 font-mono">{joint.name}</td>
336
  <td className="px-1 text-slate-600">→</td>
337
  <td className="px-1 py-0.5">
338
- <select
339
- value={mapping[joint.name] ?? ""}
340
- onChange={(e) => setMapping((m) => ({ ...m, [joint.name]: e.target.value }))}
341
- className="bg-slate-900 text-slate-200 text-xs rounded px-1 py-0.5 border border-slate-600 w-full max-w-[200px]"
342
- >
343
  <option value="">-- unmapped --</option>
344
  {selectedColumns.map((col) => {
345
  const label = col.split(SERIES_DELIM).pop() ?? col;
346
- return (
347
- <option key={col} value={col}>
348
- {label}
349
- </option>
350
- );
351
  })}
352
  </select>
353
  </td>
354
  <td className="px-1 py-0.5 text-right tabular-nums text-slate-400 font-mono">
355
- {jointAngles[joint.name] !== undefined ? jointAngles[joint.name].toFixed(3) : "—"}
356
  </td>
357
  </tr>
358
  ))}
 
1
  "use client";
2
 
3
+ import React, { useState, useEffect, useRef, useMemo, useCallback } from "react";
4
+ import { Canvas, useThree, useFrame } from "@react-three/fiber";
5
+ import { OrbitControls, Grid, Html } from "@react-three/drei";
6
  import * as THREE from "three";
7
+ import URDFLoader from "urdf-loader";
8
+ import type { URDFRobot } from "urdf-loader";
9
+ import { STLLoader } from "three/examples/jsm/loaders/STLLoader.js";
 
 
 
 
 
 
10
  import type { EpisodeData } from "@/app/[org]/[dataset]/[episode]/fetch-data";
11
 
12
  const SERIES_DELIM = " | ";
13
+ const SCALE = 10;
14
 
15
+ function getUrdfUrl(robotType: string | null): string {
16
+ const lower = (robotType ?? "").toLowerCase();
17
+ if (lower.includes("so100") && !lower.includes("so101")) return "/urdf/so101/so100.urdf";
18
+ return "/urdf/so101/so101_new_calib.urdf";
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
  }
20
 
21
+ // Detect raw servo values (0-4096) vs radians
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
  function detectAndConvert(values: number[]): number[] {
23
  if (values.length === 0) return values;
24
  const max = Math.max(...values.map(Math.abs));
 
26
  return values;
27
  }
28
 
29
+ // Group flat chart columns by feature prefix
30
  function groupColumnsByPrefix(keys: string[]): Record<string, string[]> {
31
  const groups: Record<string, string[]> = {};
32
  for (const key of keys) {
 
39
  return groups;
40
  }
41
 
42
+ // Auto-match dataset columns to URDF joint names
43
+ function autoMatchJoints(urdfJointNames: string[], columnKeys: string[]): Record<string, string> {
44
+ const mapping: Record<string, string> = {};
45
+ for (const jointName of urdfJointNames) {
46
+ const lower = jointName.toLowerCase();
47
+ const exactMatch = columnKeys.find((k) => {
48
+ const suffix = (k.split(SERIES_DELIM).pop()?.trim() ?? k).toLowerCase();
49
+ return suffix === lower;
50
+ });
51
+ if (exactMatch) { mapping[jointName] = exactMatch; continue; }
52
+ const fuzzy = columnKeys.find((k) => k.toLowerCase().includes(lower));
53
+ if (fuzzy) mapping[jointName] = fuzzy;
54
+ }
55
+ return mapping;
56
+ }
57
+
58
+ // ─── Robot scene (imperative, inside Canvas) ───
59
+ function RobotScene({
60
+ urdfUrl,
61
+ jointValues,
62
+ onJointsLoaded,
63
+ }: {
64
+ urdfUrl: string;
65
+ jointValues: Record<string, number>;
66
+ onJointsLoaded: (names: string[]) => void;
67
+ }) {
68
+ const { scene } = useThree();
69
+ const robotRef = useRef<URDFRobot | null>(null);
70
+ const [loading, setLoading] = useState(true);
71
+ const [error, setError] = useState<string | null>(null);
72
+
73
+ useEffect(() => {
74
+ setLoading(true);
75
+ setError(null);
76
+
77
+ const manager = new THREE.LoadingManager();
78
+ const loader = new URDFLoader(manager);
79
+
80
+ loader.loadMeshCb = (url, mgr, onLoad) => {
81
+ const stlLoader = new STLLoader(mgr);
82
+ stlLoader.load(
83
+ url,
84
+ (geometry) => {
85
+ const isMotor = url.includes("sts3215");
86
+ const material = new THREE.MeshStandardMaterial({
87
+ color: isMotor ? "#1a1a1a" : "#FFD700",
88
+ metalness: isMotor ? 0.7 : 0.1,
89
+ roughness: isMotor ? 0.3 : 0.6,
90
+ });
91
+ onLoad(new THREE.Mesh(geometry, material));
92
+ },
93
+ undefined,
94
+ (err) => onLoad(new THREE.Object3D(), err as Error),
95
+ );
96
+ };
97
+
98
+ loader.load(
99
+ urdfUrl,
100
+ (robot) => {
101
+ robotRef.current = robot;
102
+ robot.rotateOnAxis(new THREE.Vector3(1, 0, 0), -Math.PI / 2);
103
+ robot.traverse((c) => { c.castShadow = true; });
104
+ robot.updateMatrixWorld(true);
105
+ robot.scale.set(SCALE, SCALE, SCALE);
106
+ scene.add(robot);
107
+
108
+ const revolute = Object.values(robot.joints)
109
+ .filter((j) => j.jointType === "revolute" || j.jointType === "continuous")
110
+ .map((j) => j.name);
111
+ onJointsLoaded(revolute);
112
+ setLoading(false);
113
+ },
114
+ undefined,
115
+ (err) => {
116
+ console.error("Error loading URDF:", err);
117
+ setError(String(err));
118
+ setLoading(false);
119
+ },
120
+ );
121
+
122
+ return () => {
123
+ if (robotRef.current) {
124
+ scene.remove(robotRef.current);
125
+ robotRef.current = null;
126
+ }
127
+ };
128
+ }, [urdfUrl, scene, onJointsLoaded]);
129
+
130
+ useFrame(() => {
131
+ if (!robotRef.current) return;
132
+ for (const [name, value] of Object.entries(jointValues)) {
133
+ robotRef.current.setJointValue(name, value);
134
+ }
135
+ });
136
+
137
+ if (loading) return <Html center><span className="text-white text-lg">Loading robot…</span></Html>;
138
+ if (error) return <Html center><span className="text-red-400">Failed to load URDF</span></Html>;
139
+ return null;
140
+ }
141
+
142
+ // ─── Playback ticker (inside Canvas) ───
143
+ function PlaybackDriver({
144
+ playing, fps, totalFrames, frameRef, setFrame,
145
+ }: {
146
+ playing: boolean; fps: number; totalFrames: number;
147
+ frameRef: React.MutableRefObject<number>;
148
+ setFrame: React.Dispatch<React.SetStateAction<number>>;
149
+ }) {
150
+ const elapsed = useRef(0);
151
+ const last = useRef(0);
152
+
153
+ useEffect(() => {
154
+ if (!playing) return;
155
+ let raf: number;
156
+ const tick = () => {
157
+ raf = requestAnimationFrame(tick);
158
+ const now = performance.now();
159
+ const dt = (now - last.current) / 1000;
160
+ last.current = now;
161
+ if (dt > 0 && dt < 0.5) {
162
+ elapsed.current += dt;
163
+ const frameDelta = Math.floor(elapsed.current * fps);
164
+ if (frameDelta > 0) {
165
+ elapsed.current -= frameDelta / fps;
166
+ frameRef.current = (frameRef.current + frameDelta) % totalFrames;
167
+ setFrame(frameRef.current);
168
+ }
169
+ }
170
+ };
171
+ last.current = performance.now();
172
+ elapsed.current = 0;
173
+ raf = requestAnimationFrame(tick);
174
+ return () => cancelAnimationFrame(raf);
175
+ }, [playing, fps, totalFrames, frameRef, setFrame]);
176
+
177
+ return null;
178
+ }
179
+
180
  // ═══════════════════════════════════════
181
  // ─── Main URDF Viewer ───
182
  // ═══════════════════════════════════════
 
184
  const { flatChartData, datasetInfo } = data;
185
  const totalFrames = flatChartData.length;
186
  const fps = datasetInfo.fps || 30;
187
+ const urdfUrl = useMemo(() => getUrdfUrl(datasetInfo.robot_type), [datasetInfo.robot_type]);
188
 
189
+ // URDF joint names (set after robot loads)
190
+ const [urdfJointNames, setUrdfJointNames] = useState<string[]>([]);
191
+ const onJointsLoaded = useCallback((names: string[]) => setUrdfJointNames(names), []);
192
+
193
+ // Feature group selection
194
  const columnGroups = useMemo(() => {
195
  if (totalFrames === 0) return {};
196
  return groupColumnsByPrefix(Object.keys(flatChartData[0]));
 
210
  useEffect(() => setSelectedGroup(defaultGroup), [defaultGroup]);
211
 
212
  const selectedColumns = columnGroups[selectedGroup] ?? [];
213
+
214
+ // Joint mapping (re-compute when URDF joints or selected columns change)
215
+ const autoMapping = useMemo(
216
+ () => autoMatchJoints(urdfJointNames, selectedColumns),
217
+ [urdfJointNames, selectedColumns],
218
+ );
219
  const [mapping, setMapping] = useState<Record<string, string>>(autoMapping);
220
  useEffect(() => setMapping(autoMapping), [autoMapping]);
221
 
222
+ // Playback
223
  const [frame, setFrame] = useState(0);
224
  const [playing, setPlaying] = useState(false);
225
  const frameRef = useRef(0);
226
 
 
 
 
 
 
 
227
  const handleFrameChange = useCallback((e: React.ChangeEvent<HTMLInputElement>) => {
228
  const f = parseInt(e.target.value);
229
  setFrame(f);
230
  frameRef.current = f;
231
  }, []);
232
 
233
+ // Compute joint values for current frame
234
+ const jointValues = useMemo(() => {
235
+ if (totalFrames === 0 || urdfJointNames.length === 0) return {};
236
  const row = flatChartData[Math.min(frame, totalFrames - 1)];
237
  const rawValues: number[] = [];
238
+ const names: string[] = [];
239
 
240
+ for (const jn of urdfJointNames) {
241
+ const col = mapping[jn];
242
  if (col && typeof row[col] === "number") {
243
  rawValues.push(row[col]);
244
+ names.push(jn);
245
  }
246
  }
247
 
248
  const converted = detectAndConvert(rawValues);
249
+ const values: Record<string, number> = {};
250
+ names.forEach((n, i) => { values[n] = converted[i]; });
251
+ return values;
252
+ }, [flatChartData, frame, mapping, totalFrames, urdfJointNames]);
 
 
253
 
254
  const currentTime = totalFrames > 0 ? (frame / fps).toFixed(2) : "0.00";
255
  const totalTime = (totalFrames / fps).toFixed(2);
 
262
  <div className="flex-1 flex flex-col overflow-hidden">
263
  {/* 3D Viewport */}
264
  <div className="flex-1 min-h-0 bg-slate-950 rounded-lg overflow-hidden border border-slate-700">
265
+ <Canvas camera={{ position: [0.3 * SCALE, 0.25 * SCALE, 0.3 * SCALE], fov: 45, near: 0.01, far: 100 }}>
266
  <ambientLight intensity={0.5} />
267
+ <directionalLight position={[3, 5, 4]} intensity={1.2} />
268
  <directionalLight position={[-2, 3, -2]} intensity={0.4} />
269
  <hemisphereLight args={["#b1e1ff", "#444444", 0.4]} />
270
+ <RobotScene urdfUrl={urdfUrl} jointValues={jointValues} onJointsLoaded={onJointsLoaded} />
 
 
271
  <Grid
272
+ args={[10, 10]}
273
+ cellSize={0.2}
274
  cellThickness={0.5}
275
  cellColor="#334155"
276
+ sectionSize={1}
277
  sectionThickness={1}
278
  sectionColor="#475569"
279
+ fadeDistance={10}
280
  position={[0, 0, 0]}
281
  />
282
+ <OrbitControls target={[0, 0.8, 0]} />
283
+ <PlaybackDriver playing={playing} fps={fps} totalFrames={totalFrames} frameRef={frameRef} setFrame={setFrame} />
284
  </Canvas>
285
  </div>
286
 
287
+ {/* Controls */}
288
  <div className="bg-slate-800/90 border-t border-slate-700 p-3 space-y-3 shrink-0">
289
+ {/* Timeline */}
290
  <div className="flex items-center gap-3">
291
  <button
292
+ onClick={() => { setPlaying(!playing); if (!playing) frameRef.current = frame; }}
 
 
 
293
  className="w-8 h-8 flex items-center justify-center rounded bg-orange-600 hover:bg-orange-500 text-white transition-colors shrink-0"
294
  >
295
  {playing ? (
296
+ <svg width="12" height="14" viewBox="0 0 12 14"><rect x="1" y="1" width="3" height="12" fill="white" /><rect x="8" y="1" width="3" height="12" fill="white" /></svg>
 
 
 
297
  ) : (
298
+ <svg width="12" height="14" viewBox="0 0 12 14"><polygon points="2,1 11,7 2,13" fill="white" /></svg>
 
 
299
  )}
300
  </button>
301
+ <input type="range" min={0} max={Math.max(totalFrames - 1, 0)} value={frame}
302
+ onChange={handleFrameChange} className="flex-1 h-1.5 accent-orange-500 cursor-pointer" />
303
+ <span className="text-xs text-slate-400 tabular-nums w-28 text-right shrink-0">{currentTime}s / {totalTime}s</span>
304
+ <span className="text-xs text-slate-500 tabular-nums w-20 text-right shrink-0">F {frame}/{totalFrames - 1}</span>
 
 
 
 
 
 
 
 
 
 
305
  </div>
306
 
307
  {/* Data source + joint mapping */}
 
310
  <label className="text-xs text-slate-400">Data source</label>
311
  <div className="flex gap-1 flex-wrap">
312
  {groupNames.map((name) => (
313
+ <button key={name} onClick={() => setSelectedGroup(name)}
 
 
314
  className={`px-2 py-1 text-xs rounded transition-colors ${
315
+ selectedGroup === name ? "bg-orange-600 text-white" : "bg-slate-700 text-slate-300 hover:bg-slate-600"
316
+ }`}>{name}</button>
 
 
 
 
 
317
  ))}
318
  </div>
319
  </div>
 
329
  </tr>
330
  </thead>
331
  <tbody>
332
+ {urdfJointNames.map((jointName) => (
333
+ <tr key={jointName} className="border-t border-slate-700/50">
334
+ <td className="px-1 py-0.5 text-slate-300 font-mono">{jointName}</td>
335
  <td className="px-1 text-slate-600">→</td>
336
  <td className="px-1 py-0.5">
337
+ <select value={mapping[jointName] ?? ""}
338
+ onChange={(e) => setMapping((m) => ({ ...m, [jointName]: e.target.value }))}
339
+ className="bg-slate-900 text-slate-200 text-xs rounded px-1 py-0.5 border border-slate-600 w-full max-w-[200px]">
 
 
340
  <option value="">-- unmapped --</option>
341
  {selectedColumns.map((col) => {
342
  const label = col.split(SERIES_DELIM).pop() ?? col;
343
+ return <option key={col} value={col}>{label}</option>;
 
 
 
 
344
  })}
345
  </select>
346
  </td>
347
  <td className="px-1 py-0.5 text-right tabular-nums text-slate-400 font-mono">
348
+ {jointValues[jointName] !== undefined ? jointValues[jointName].toFixed(3) : "—"}
349
  </td>
350
  </tr>
351
  ))}
src/lib/so101-robot.ts CHANGED
@@ -1,154 +1,5 @@
1
- export type JointDef = {
2
- name: string;
3
- origin: { xyz: [number, number, number]; rpy: [number, number, number] };
4
- axis: [number, number, number];
5
- limits: [number, number];
6
- };
7
-
8
- export type MeshDef = {
9
- file: string;
10
- origin: { xyz: [number, number, number]; rpy: [number, number, number] };
11
- material: "3d_printed" | "motor";
12
- };
13
-
14
- export type LinkDef = {
15
- name: string;
16
- meshes: MeshDef[];
17
- };
18
-
19
- const ASSET_BASE = "/urdf/so101/assets";
20
- const P = Math.PI;
21
-
22
- // ─── Visual meshes per link (from URDF) ───
23
- export const SO101_LINKS: LinkDef[] = [
24
- {
25
- name: "base_link",
26
- meshes: [
27
- { file: `${ASSET_BASE}/base_motor_holder_so101_v1.stl`, origin: { xyz: [-0.00636471, -9.94414e-05, -0.0024], rpy: [P / 2, 0, P / 2] }, material: "3d_printed" },
28
- { file: `${ASSET_BASE}/base_so101_v2.stl`, origin: { xyz: [-0.00636471, 0, -0.0024], rpy: [P / 2, 0, P / 2] }, material: "3d_printed" },
29
- { file: `${ASSET_BASE}/sts3215_03a_v1.stl`, origin: { xyz: [0.0263353, 0, 0.0437], rpy: [0, 0, 0] }, material: "motor" },
30
- { file: `${ASSET_BASE}/waveshare_mounting_plate_so101_v2.stl`, origin: { xyz: [-0.0309827, -0.000199441, 0.0474], rpy: [P / 2, 0, P / 2] }, material: "3d_printed" },
31
- ],
32
- },
33
- {
34
- name: "shoulder_link",
35
- meshes: [
36
- { file: `${ASSET_BASE}/sts3215_03a_v1.stl`, origin: { xyz: [-0.0303992, 0.000422241, -0.0417], rpy: [P / 2, P / 2, 0] }, material: "motor" },
37
- { file: `${ASSET_BASE}/motor_holder_so101_base_v1.stl`, origin: { xyz: [-0.0675992, -0.000177759, 0.0158499], rpy: [P / 2, -P / 2, 0] }, material: "3d_printed" },
38
- { file: `${ASSET_BASE}/rotation_pitch_so101_v1.stl`, origin: { xyz: [0.0122008, 2.22413e-05, 0.0464], rpy: [-P / 2, 0, 0] }, material: "3d_printed" },
39
- ],
40
- },
41
- {
42
- name: "upper_arm_link",
43
- meshes: [
44
- { file: `${ASSET_BASE}/sts3215_03a_v1.stl`, origin: { xyz: [-0.11257, -0.0155, 0.0187], rpy: [-P, 0, -P / 2] }, material: "motor" },
45
- { file: `${ASSET_BASE}/upper_arm_so101_v1.stl`, origin: { xyz: [-0.065085, 0.012, 0.0182], rpy: [P, 0, 0] }, material: "3d_printed" },
46
- ],
47
- },
48
- {
49
- name: "lower_arm_link",
50
- meshes: [
51
- { file: `${ASSET_BASE}/under_arm_so101_v1.stl`, origin: { xyz: [-0.0648499, -0.032, 0.0182], rpy: [P, 0, 0] }, material: "3d_printed" },
52
- { file: `${ASSET_BASE}/motor_holder_so101_wrist_v1.stl`, origin: { xyz: [-0.0648499, -0.032, 0.018], rpy: [-P, 0, 0] }, material: "3d_printed" },
53
- { file: `${ASSET_BASE}/sts3215_03a_v1.stl`, origin: { xyz: [-0.1224, 0.0052, 0.0187], rpy: [-P, 0, -P] }, material: "motor" },
54
- ],
55
- },
56
- {
57
- name: "wrist_link",
58
- meshes: [
59
- { file: `${ASSET_BASE}/sts3215_03a_no_horn_v1.stl`, origin: { xyz: [0, -0.0424, 0.0306], rpy: [P / 2, P / 2, 0] }, material: "motor" },
60
- { file: `${ASSET_BASE}/wrist_roll_pitch_so101_v2.stl`, origin: { xyz: [0, -0.028, 0.0181], rpy: [-P / 2, -P / 2, 0] }, material: "3d_printed" },
61
- ],
62
- },
63
- {
64
- name: "gripper_link",
65
- meshes: [
66
- { file: `${ASSET_BASE}/sts3215_03a_v1.stl`, origin: { xyz: [0.0077, 0.0001, -0.0234], rpy: [-P / 2, 0, 0] }, material: "motor" },
67
- { file: `${ASSET_BASE}/wrist_roll_follower_so101_v1.stl`, origin: { xyz: [0, -0.000218214, 0.000949706], rpy: [-P, 0, 0] }, material: "3d_printed" },
68
- ],
69
- },
70
- {
71
- name: "moving_jaw_link",
72
- meshes: [
73
- { file: `${ASSET_BASE}/moving_jaw_so101_v1.stl`, origin: { xyz: [0, 0, 0.0189], rpy: [0, 0, 0] }, material: "3d_printed" },
74
- ],
75
- },
76
- ];
77
-
78
- // Kinematic chain: each joint connects a parent link to a child link
79
- // Index in SO101_LINKS: base=0, shoulder=1, upper_arm=2, lower_arm=3, wrist=4, gripper=5, jaw=6
80
- export const SO101_JOINTS: JointDef[] = [
81
- {
82
- name: "shoulder_pan",
83
- origin: { xyz: [0.0388353, -8.97657e-09, 0.0624], rpy: [P, 4.18253e-17, -P] },
84
- axis: [0, 0, 1],
85
- limits: [-1.91986, 1.91986],
86
- },
87
- {
88
- name: "shoulder_lift",
89
- origin: { xyz: [-0.0303992, -0.0182778, -0.0542], rpy: [-P / 2, -P / 2, 0] },
90
- axis: [0, 0, 1],
91
- limits: [-1.74533, 1.74533],
92
- },
93
- {
94
- name: "elbow_flex",
95
- origin: { xyz: [-0.11257, -0.028, 1.73763e-16], rpy: [0, 0, P / 2] },
96
- axis: [0, 0, 1],
97
- limits: [-1.69, 1.69],
98
- },
99
- {
100
- name: "wrist_flex",
101
- origin: { xyz: [-0.1349, 0.0052, 0], rpy: [0, 0, -P / 2] },
102
- axis: [0, 0, 1],
103
- limits: [-1.65806, 1.65806],
104
- },
105
- {
106
- name: "wrist_roll",
107
- origin: { xyz: [0, -0.0611, 0.0181], rpy: [P / 2, 0.0486795, P] },
108
- axis: [0, 0, 1],
109
- limits: [-2.74385, 2.84121],
110
- },
111
- {
112
- name: "gripper",
113
- origin: { xyz: [0.0202, 0.0188, -0.0234], rpy: [P / 2, 0, 0] },
114
- axis: [0, 0, 1],
115
- limits: [-0.174533, 1.74533],
116
- },
117
- ];
118
-
119
- export const MATERIAL_COLORS = {
120
- "3d_printed": "#FFD700",
121
- motor: "#1a1a1a",
122
- } as const;
123
-
124
  export function isSO101Robot(robotType: string | null): boolean {
125
  if (!robotType) return false;
126
  const lower = robotType.toLowerCase();
127
  return lower.includes("so100") || lower.includes("so101") || lower === "so_follower";
128
  }
129
-
130
- // Collect all unique STL file paths for preloading
131
- export function getAllSTLPaths(): string[] {
132
- const paths = new Set<string>();
133
- for (const link of SO101_LINKS) {
134
- for (const mesh of link.meshes) {
135
- paths.add(mesh.file);
136
- }
137
- }
138
- return [...paths];
139
- }
140
-
141
- // Auto-match dataset columns to URDF joint names
142
- export function autoMatchJoints(columnKeys: string[]): Record<string, string> {
143
- const mapping: Record<string, string> = {};
144
- for (const joint of SO101_JOINTS) {
145
- const exactMatch = columnKeys.find((k) => {
146
- const suffix = k.split(" | ").pop()?.trim() ?? k;
147
- return suffix === joint.name;
148
- });
149
- if (exactMatch) { mapping[joint.name] = exactMatch; continue; }
150
- const fuzzy = columnKeys.find((k) => k.toLowerCase().includes(joint.name));
151
- if (fuzzy) mapping[joint.name] = fuzzy;
152
- }
153
- return mapping;
154
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  export function isSO101Robot(robotType: string | null): boolean {
2
  if (!robotType) return false;
3
  const lower = robotType.toLowerCase();
4
  return lower.includes("so100") || lower.includes("so101") || lower === "so_follower";
5
  }