pepijn223 HF Staff commited on
Commit
83fd2de
Β·
unverified Β·
1 Parent(s): cb1fbc7

add more action insight tools

Browse files
src/app/[org]/[dataset]/[episode]/fetch-data.ts CHANGED
@@ -1356,6 +1356,41 @@ export type AggAutocorrelation = {
1356
  shortKeys: string[];
1357
  };
1358
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1359
  export type CrossEpisodeVarianceData = {
1360
  actionNames: string[];
1361
  timeBins: number[];
@@ -1364,6 +1399,10 @@ export type CrossEpisodeVarianceData = {
1364
  lowMovementEpisodes: LowMovementEpisode[];
1365
  aggVelocity: AggVelocityStat[];
1366
  aggAutocorrelation: AggAutocorrelation | null;
 
 
 
 
1367
  };
1368
 
1369
  export async function loadCrossEpisodeActionVariance(
@@ -1392,6 +1431,12 @@ export async function loadCrossEpisodeActionVariance(
1392
  ? (names as string[]).map(n => `${actionKey}${SERIES_NAME_DELIMITER}${n}`)
1393
  : Array.from({ length: actionDim }, (_, i) => `${actionKey}${SERIES_NAME_DELIMITER}${i}`);
1394
 
 
 
 
 
 
 
1395
  // Collect episode metadata
1396
  type EpMeta = { index: number; chunkIdx: number; fileIdx: number; from: number; to: number };
1397
  const allEps: EpMeta[] = [];
@@ -1436,8 +1481,9 @@ export async function loadCrossEpisodeActionVariance(
1436
  allEps[Math.round((i * (allEps.length - 1)) / (maxEpisodes - 1))]
1437
  );
1438
 
1439
- // Load action data per episode, tracking episode index alongside
1440
  const episodeActions: { index: number; actions: number[][] }[] = [];
 
1441
 
1442
  if (version === "v3.0") {
1443
  const byFile = new Map<string, EpMeta[]>();
@@ -1459,11 +1505,19 @@ export async function loadCrossEpisodeActionVariance(
1459
  const localFrom = Math.max(0, ep.from - fileStart);
1460
  const localTo = Math.min(rows.length, ep.to - fileStart);
1461
  const actions: number[][] = [];
 
1462
  for (let r = localFrom; r < localTo; r++) {
1463
  const raw = rows[r]?.[actionKey];
1464
  if (Array.isArray(raw)) actions.push(raw.map(Number));
 
 
 
 
 
 
 
 
1465
  }
1466
- if (actions.length > 0) episodeActions.push({ index: ep.index, actions });
1467
  }
1468
  } catch { /* skip file */ }
1469
  }
@@ -1479,6 +1533,7 @@ export async function loadCrossEpisodeActionVariance(
1479
  const buf = await fetchParquetFile(buildVersionedUrl(repoId, version, dataPath));
1480
  const rows = await readParquetAsObjects(buf, []);
1481
  const actions: number[][] = [];
 
1482
  for (const row of rows) {
1483
  const raw = row[actionKey];
1484
  if (Array.isArray(raw)) {
@@ -1491,8 +1546,15 @@ export async function loadCrossEpisodeActionVariance(
1491
  }
1492
  actions.push(vec);
1493
  }
 
 
 
 
 
 
 
 
1494
  }
1495
- if (actions.length > 0) episodeActions.push({ index: ep.index, actions });
1496
  } catch { /* skip */ }
1497
  }
1498
  }
@@ -1503,10 +1565,12 @@ export async function loadCrossEpisodeActionVariance(
1503
  }
1504
  console.log(`[cross-ep] Loaded action data for ${episodeActions.length}/${sampled.length} episodes`);
1505
 
1506
- // Resample each episode to numTimeBins and compute variance
1507
  const timeBins = Array.from({ length: numTimeBins }, (_, i) => i / (numTimeBins - 1));
1508
  const sums = Array.from({ length: numTimeBins }, () => new Float64Array(actionDim));
1509
  const sumsSq = Array.from({ length: numTimeBins }, () => new Float64Array(actionDim));
 
 
1510
  const counts = new Uint32Array(numTimeBins);
1511
 
1512
  for (const { actions: epActions } of episodeActions) {
@@ -1516,8 +1580,11 @@ export async function loadCrossEpisodeActionVariance(
1516
  const row = epActions[srcIdx];
1517
  for (let d = 0; d < actionDim; d++) {
1518
  const v = row[d] ?? 0;
 
1519
  sums[b][d] += v;
1520
- sumsSq[b][d] += v * v;
 
 
1521
  }
1522
  counts[b]++;
1523
  }
@@ -1625,7 +1692,283 @@ export async function loadCrossEpisodeActionVariance(
1625
  return { chartData, suggestedChunk, shortKeys };
1626
  })();
1627
 
1628
- return { actionNames, timeBins, variance, numEpisodes: episodeActions.length, lowMovementEpisodes, aggVelocity, aggAutocorrelation };
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1629
  }
1630
 
1631
  // Load only flatChartData for a specific episode (used by URDF viewer episode switching)
 
1356
  shortKeys: string[];
1357
  };
1358
 
1359
+ export type SpeedDistEntry = {
1360
+ episodeIndex: number;
1361
+ speed: number;
1362
+ };
1363
+
1364
+ export type ClusterEntry = {
1365
+ episodeIndex: number;
1366
+ x: number;
1367
+ y: number;
1368
+ z: number;
1369
+ cluster: number;
1370
+ distFromCenter: number;
1371
+ isOutlier: boolean;
1372
+ };
1373
+
1374
+ export type TrajectoryClustering = {
1375
+ entries: ClusterEntry[];
1376
+ numClusters: number;
1377
+ clusterSizes: number[];
1378
+ outlierCount: number;
1379
+ };
1380
+
1381
+ export type AggAlignment = {
1382
+ ccData: { lag: number; max: number; mean: number; min: number }[];
1383
+ meanPeakLag: number;
1384
+ meanPeakCorr: number;
1385
+ maxPeakLag: number;
1386
+ maxPeakCorr: number;
1387
+ minPeakLag: number;
1388
+ minPeakCorr: number;
1389
+ lagRangeMin: number;
1390
+ lagRangeMax: number;
1391
+ numPairs: number;
1392
+ };
1393
+
1394
  export type CrossEpisodeVarianceData = {
1395
  actionNames: string[];
1396
  timeBins: number[];
 
1399
  lowMovementEpisodes: LowMovementEpisode[];
1400
  aggVelocity: AggVelocityStat[];
1401
  aggAutocorrelation: AggAutocorrelation | null;
1402
+ speedDistribution: SpeedDistEntry[];
1403
+ multimodality: number[][] | null;
1404
+ trajectoryClustering: TrajectoryClustering | null;
1405
+ aggAlignment: AggAlignment | null;
1406
  };
1407
 
1408
  export async function loadCrossEpisodeActionVariance(
 
1431
  ? (names as string[]).map(n => `${actionKey}${SERIES_NAME_DELIMITER}${n}`)
1432
  : Array.from({ length: actionDim }, (_, i) => `${actionKey}${SERIES_NAME_DELIMITER}${i}`);
1433
 
1434
+ // State feature for alignment computation
1435
+ const stateEntry = Object.entries(info.features)
1436
+ .find(([key, f]) => key === "observation.state" && f.shape.length === 1);
1437
+ const stateKey = stateEntry?.[0] ?? null;
1438
+ const stateDim = stateEntry?.[1].shape[0] ?? 0;
1439
+
1440
  // Collect episode metadata
1441
  type EpMeta = { index: number; chunkIdx: number; fileIdx: number; from: number; to: number };
1442
  const allEps: EpMeta[] = [];
 
1481
  allEps[Math.round((i * (allEps.length - 1)) / (maxEpisodes - 1))]
1482
  );
1483
 
1484
+ // Load action (and state) data per episode
1485
  const episodeActions: { index: number; actions: number[][] }[] = [];
1486
+ const episodeStates: (number[][] | null)[] = [];
1487
 
1488
  if (version === "v3.0") {
1489
  const byFile = new Map<string, EpMeta[]>();
 
1505
  const localFrom = Math.max(0, ep.from - fileStart);
1506
  const localTo = Math.min(rows.length, ep.to - fileStart);
1507
  const actions: number[][] = [];
1508
+ const states: number[][] = [];
1509
  for (let r = localFrom; r < localTo; r++) {
1510
  const raw = rows[r]?.[actionKey];
1511
  if (Array.isArray(raw)) actions.push(raw.map(Number));
1512
+ if (stateKey) {
1513
+ const sRaw = rows[r]?.[stateKey];
1514
+ if (Array.isArray(sRaw)) states.push(sRaw.map(Number));
1515
+ }
1516
+ }
1517
+ if (actions.length > 0) {
1518
+ episodeActions.push({ index: ep.index, actions });
1519
+ episodeStates.push(stateKey && states.length === actions.length ? states : null);
1520
  }
 
1521
  }
1522
  } catch { /* skip file */ }
1523
  }
 
1533
  const buf = await fetchParquetFile(buildVersionedUrl(repoId, version, dataPath));
1534
  const rows = await readParquetAsObjects(buf, []);
1535
  const actions: number[][] = [];
1536
+ const states: number[][] = [];
1537
  for (const row of rows) {
1538
  const raw = row[actionKey];
1539
  if (Array.isArray(raw)) {
 
1546
  }
1547
  actions.push(vec);
1548
  }
1549
+ if (stateKey) {
1550
+ const sRaw = row[stateKey];
1551
+ if (Array.isArray(sRaw)) states.push(sRaw.map(Number));
1552
+ }
1553
+ }
1554
+ if (actions.length > 0) {
1555
+ episodeActions.push({ index: ep.index, actions });
1556
+ episodeStates.push(stateKey && states.length === actions.length ? states : null);
1557
  }
 
1558
  } catch { /* skip */ }
1559
  }
1560
  }
 
1565
  }
1566
  console.log(`[cross-ep] Loaded action data for ${episodeActions.length}/${sampled.length} episodes`);
1567
 
1568
+ // Resample each episode to numTimeBins and compute variance + higher moments for multimodality
1569
  const timeBins = Array.from({ length: numTimeBins }, (_, i) => i / (numTimeBins - 1));
1570
  const sums = Array.from({ length: numTimeBins }, () => new Float64Array(actionDim));
1571
  const sumsSq = Array.from({ length: numTimeBins }, () => new Float64Array(actionDim));
1572
+ const sumsCube = Array.from({ length: numTimeBins }, () => new Float64Array(actionDim));
1573
+ const sumsFourth = Array.from({ length: numTimeBins }, () => new Float64Array(actionDim));
1574
  const counts = new Uint32Array(numTimeBins);
1575
 
1576
  for (const { actions: epActions } of episodeActions) {
 
1580
  const row = epActions[srcIdx];
1581
  for (let d = 0; d < actionDim; d++) {
1582
  const v = row[d] ?? 0;
1583
+ const v2 = v * v;
1584
  sums[b][d] += v;
1585
+ sumsSq[b][d] += v2;
1586
+ sumsCube[b][d] += v2 * v;
1587
+ sumsFourth[b][d] += v2 * v2;
1588
  }
1589
  counts[b]++;
1590
  }
 
1692
  return { chartData, suggestedChunk, shortKeys };
1693
  })();
1694
 
1695
+ // Speed distribution: all episode movement scores (not just lowest 10)
1696
+ const speedDistribution: SpeedDistEntry[] = movementScores.map(s => ({
1697
+ episodeIndex: s.episodeIndex,
1698
+ speed: s.totalMovement,
1699
+ }));
1700
+
1701
+ // Multimodality: bimodality coefficient per time-bin per dimension
1702
+ // BC = (skewnessΒ² + 1) / kurtosis β€” values above 5/9 suggest bimodality
1703
+ const multimodality: number[][] | null = (() => {
1704
+ const result: number[][] = [];
1705
+ for (let b = 0; b < numTimeBins; b++) {
1706
+ const row: number[] = [];
1707
+ const n = counts[b];
1708
+ for (let d = 0; d < actionDim; d++) {
1709
+ if (n < 4) { row.push(0); continue; }
1710
+ const mean = sums[b][d] / n;
1711
+ const m2 = sumsSq[b][d] / n - mean * mean;
1712
+ if (m2 < 1e-12) { row.push(0); continue; }
1713
+ const m3 = sumsCube[b][d] / n - 3 * mean * sumsSq[b][d] / n + 2 * mean * mean * mean;
1714
+ const m4 = sumsFourth[b][d] / n - 4 * mean * sumsCube[b][d] / n
1715
+ + 6 * mean * mean * sumsSq[b][d] / n - 3 * mean * mean * mean * mean;
1716
+ const skew = m3 / Math.pow(m2, 1.5);
1717
+ const kurt = m4 / (m2 * m2);
1718
+ row.push(kurt > 0 ? (skew * skew + 1) / kurt : 0);
1719
+ }
1720
+ result.push(row);
1721
+ }
1722
+ return result.length > 0 ? result : null;
1723
+ })();
1724
+
1725
+ // Trajectory clustering: time-normalize, PCA 2D, k-means, outlier detection
1726
+ const trajectoryClustering: TrajectoryClustering | null = (() => {
1727
+ if (episodeActions.length < 5) return null;
1728
+ const clusterBins = 30;
1729
+ const N = episodeActions.length;
1730
+ const D = clusterBins * actionDim;
1731
+
1732
+ // Time-normalize each episode to a flat vector, standardize per action dim
1733
+ const raw: number[][] = episodeActions.map(({ actions: ep }) => {
1734
+ const vec: number[] = [];
1735
+ for (let b = 0; b < clusterBins; b++) {
1736
+ const srcIdx = Math.min(Math.round(b / (clusterBins - 1) * (ep.length - 1)), ep.length - 1);
1737
+ for (let d = 0; d < actionDim; d++) vec.push(ep[srcIdx][d] ?? 0);
1738
+ }
1739
+ return vec;
1740
+ });
1741
+
1742
+ const dimMean = new Float64Array(D), dimStd = new Float64Array(D);
1743
+ for (const v of raw) for (let d = 0; d < D; d++) dimMean[d] += v[d];
1744
+ for (let d = 0; d < D; d++) dimMean[d] /= N;
1745
+ for (const v of raw) for (let d = 0; d < D; d++) dimStd[d] += (v[d] - dimMean[d]) ** 2;
1746
+ for (let d = 0; d < D; d++) dimStd[d] = Math.sqrt(dimStd[d] / N) || 1;
1747
+ const data = raw.map(v => v.map((x, d) => (x - dimMean[d]) / dimStd[d]));
1748
+
1749
+ // PCA via power iteration on centered data (already centered by standardization)
1750
+ const powerIter = (rows: number[][], init: number[]): number[] => {
1751
+ let v = [...init];
1752
+ let norm = Math.sqrt(v.reduce((s, x) => s + x * x, 0)) || 1;
1753
+ v = v.map(x => x / norm);
1754
+ for (let iter = 0; iter < 50; iter++) {
1755
+ const w = new Float64Array(D);
1756
+ for (const row of rows) {
1757
+ const dot = row.reduce((s, x, d) => s + x * v[d], 0);
1758
+ for (let d = 0; d < D; d++) w[d] += dot * row[d];
1759
+ }
1760
+ norm = Math.sqrt(w.reduce((s, x) => s + x * x, 0)) || 1;
1761
+ v = Array.from(w, x => x / norm);
1762
+ }
1763
+ return v;
1764
+ };
1765
+
1766
+ const ev1 = powerIter(data, data[0]);
1767
+ const proj1 = data.map(row => row.reduce((s, x, d) => s + x * ev1[d], 0));
1768
+ const deflated1 = data.map((row, i) => row.map((x, d) => x - proj1[i] * ev1[d]));
1769
+ const ev2 = powerIter(deflated1, deflated1[1] ?? deflated1[0]);
1770
+ const proj2 = data.map(row => row.reduce((s, x, d) => s + x * ev2[d], 0));
1771
+ const deflated2 = deflated1.map((row, i) => row.map((x, d) => x - proj2[i] * ev2[d]));
1772
+ const ev3 = powerIter(deflated2, deflated2[2] ?? deflated2[0]);
1773
+ const proj3 = data.map(row => row.reduce((s, x, d) => s + x * ev3[d], 0));
1774
+
1775
+ // Precompute pairwise distances for silhouette
1776
+ const dist = new Float64Array(N * N);
1777
+ for (let i = 0; i < N; i++) for (let j = i + 1; j < N; j++) {
1778
+ let d2 = 0;
1779
+ for (let d = 0; d < D; d++) d2 += (data[i][d] - data[j][d]) ** 2;
1780
+ const d = Math.sqrt(d2);
1781
+ dist[i * N + j] = d;
1782
+ dist[j * N + i] = d;
1783
+ }
1784
+
1785
+ // K-means
1786
+ const runKmeans = (k: number): { labels: number[]; centroids: number[][] } => {
1787
+ const centroids = Array.from({ length: k }, (_, c) => [...data[Math.floor(c * N / k)]]);
1788
+ const labels = new Int32Array(N);
1789
+ for (let iter = 0; iter < 30; iter++) {
1790
+ for (let i = 0; i < N; i++) {
1791
+ let minD = Infinity;
1792
+ for (let c = 0; c < k; c++) {
1793
+ let d2 = 0;
1794
+ for (let d = 0; d < D; d++) d2 += (data[i][d] - centroids[c][d]) ** 2;
1795
+ if (d2 < minD) { minD = d2; labels[i] = c; }
1796
+ }
1797
+ }
1798
+ const sums = Array.from({ length: k }, () => new Float64Array(D));
1799
+ const counts = new Uint32Array(k);
1800
+ for (let i = 0; i < N; i++) {
1801
+ counts[labels[i]]++;
1802
+ for (let d = 0; d < D; d++) sums[labels[i]][d] += data[i][d];
1803
+ }
1804
+ for (let c = 0; c < k; c++) {
1805
+ if (counts[c] === 0) continue;
1806
+ for (let d = 0; d < D; d++) centroids[c][d] = sums[c][d] / counts[c];
1807
+ }
1808
+ }
1809
+ return { labels: Array.from(labels), centroids };
1810
+ };
1811
+
1812
+ // Silhouette score using precomputed distances
1813
+ const silhouette = (labels: number[], k: number): number => {
1814
+ let total = 0;
1815
+ for (let i = 0; i < N; i++) {
1816
+ let aSum = 0, aCount = 0;
1817
+ for (let j = 0; j < N; j++) {
1818
+ if (j === i || labels[j] !== labels[i]) continue;
1819
+ aSum += dist[i * N + j]; aCount++;
1820
+ }
1821
+ const a = aCount > 0 ? aSum / aCount : 0;
1822
+ let b = Infinity;
1823
+ for (let c = 0; c < k; c++) {
1824
+ if (c === labels[i]) continue;
1825
+ let bSum = 0, bCount = 0;
1826
+ for (let j = 0; j < N; j++) {
1827
+ if (labels[j] !== c) continue;
1828
+ bSum += dist[i * N + j]; bCount++;
1829
+ }
1830
+ if (bCount > 0) b = Math.min(b, bSum / bCount);
1831
+ }
1832
+ const denom = Math.max(a, b);
1833
+ total += denom > 0 ? (b - a) / denom : 0;
1834
+ }
1835
+ return total / N;
1836
+ };
1837
+
1838
+ // Try k = 2..maxK, pick best silhouette
1839
+ const maxK = Math.min(5, Math.floor(N / 3));
1840
+ let bestK = 2, bestSil = -Infinity, bestLabels: number[] = [], bestCentroids: number[][] = [];
1841
+ for (let k = 2; k <= maxK; k++) {
1842
+ const { labels, centroids } = runKmeans(k);
1843
+ const sil = silhouette(labels, k);
1844
+ if (sil > bestSil) { bestSil = sil; bestK = k; bestLabels = labels; bestCentroids = centroids; }
1845
+ }
1846
+
1847
+ // Compute distance from cluster center and detect outliers
1848
+ const centerDists = bestLabels.map((c, i) => {
1849
+ let d2 = 0;
1850
+ for (let d = 0; d < D; d++) d2 += (data[i][d] - bestCentroids[c][d]) ** 2;
1851
+ return Math.sqrt(d2);
1852
+ });
1853
+ const clusterStats = Array.from({ length: bestK }, (_, c) => {
1854
+ const ds = centerDists.filter((_, i) => bestLabels[i] === c);
1855
+ const m = ds.reduce((a, b) => a + b, 0) / (ds.length || 1);
1856
+ const s = Math.sqrt(ds.reduce((a, v) => a + (v - m) ** 2, 0) / (ds.length || 1));
1857
+ return { mean: m, std: s };
1858
+ });
1859
+
1860
+ const entries: ClusterEntry[] = episodeActions.map(({ index }, i) => {
1861
+ const cs = clusterStats[bestLabels[i]];
1862
+ const isOutlier = centerDists[i] > cs.mean + 2 * cs.std;
1863
+ return {
1864
+ episodeIndex: index, x: proj1[i], y: proj2[i], z: proj3[i],
1865
+ cluster: bestLabels[i], distFromCenter: Math.round(centerDists[i] * 1000) / 1000,
1866
+ isOutlier,
1867
+ };
1868
+ });
1869
+
1870
+ const clusterSizes = Array.from({ length: bestK }, (_, c) => bestLabels.filter(l => l === c).length);
1871
+ const outlierCount = entries.filter(e => e.isOutlier).length;
1872
+
1873
+ return { entries, numClusters: bestK, clusterSizes, outlierCount };
1874
+ })();
1875
+
1876
+ // Aggregated state-action alignment across episodes
1877
+ const aggAlignment: AggAlignment | null = (() => {
1878
+ if (!stateKey || stateDim === 0) return null;
1879
+
1880
+ let sNms: unknown = stateEntry![1].names;
1881
+ while (typeof sNms === "object" && sNms !== null && !Array.isArray(sNms)) sNms = Object.values(sNms)[0];
1882
+ const stateNames = Array.isArray(sNms)
1883
+ ? (sNms as string[])
1884
+ : Array.from({ length: stateDim }, (_, i) => `${i}`);
1885
+ const actionSuffixes = actionNames.map(n => { const p = n.split(SERIES_NAME_DELIMITER); return p[p.length - 1]; });
1886
+
1887
+ // Match pairs by suffix, fall back to index
1888
+ const pairs: [number, number][] = [];
1889
+ for (let ai = 0; ai < actionDim; ai++) {
1890
+ const si = stateNames.findIndex(s => s === actionSuffixes[ai]);
1891
+ if (si >= 0) pairs.push([ai, si]);
1892
+ }
1893
+ if (pairs.length === 0) {
1894
+ const count = Math.min(actionDim, stateDim);
1895
+ for (let i = 0; i < count; i++) pairs.push([i, i]);
1896
+ }
1897
+ if (pairs.length === 0) return null;
1898
+
1899
+ const maxLag = 30;
1900
+ const numLags = 2 * maxLag + 1;
1901
+ const corrSums = pairs.map(() => new Float64Array(numLags));
1902
+ const corrCounts = pairs.map(() => new Uint32Array(numLags));
1903
+
1904
+ for (let ei = 0; ei < episodeActions.length; ei++) {
1905
+ const states = episodeStates[ei];
1906
+ if (!states) continue;
1907
+ const { actions } = episodeActions[ei];
1908
+ const n = Math.min(actions.length, states.length);
1909
+ if (n < 10) continue;
1910
+
1911
+ for (let pi = 0; pi < pairs.length; pi++) {
1912
+ const [ai, si] = pairs[pi];
1913
+ const aVals = actions.slice(0, n).map(r => r[ai] ?? 0);
1914
+ const sDeltas = Array.from({ length: n - 1 }, (_, t) => (states[t + 1][si] ?? 0) - (states[t][si] ?? 0));
1915
+ const effN = Math.min(aVals.length, sDeltas.length);
1916
+ const aM = aVals.slice(0, effN).reduce((a, b) => a + b, 0) / effN;
1917
+ const sM = sDeltas.slice(0, effN).reduce((a, b) => a + b, 0) / effN;
1918
+
1919
+ for (let li = 0; li < numLags; li++) {
1920
+ const lag = -maxLag + li;
1921
+ let sum = 0, aV = 0, sV = 0;
1922
+ for (let t = 0; t < effN; t++) {
1923
+ const sIdx = t + lag;
1924
+ if (sIdx < 0 || sIdx >= sDeltas.length) continue;
1925
+ const a = aVals[t] - aM, s = sDeltas[sIdx] - sM;
1926
+ sum += a * s; aV += a * a; sV += s * s;
1927
+ }
1928
+ const d = Math.sqrt(aV * sV);
1929
+ if (d > 0) { corrSums[pi][li] += sum / d; corrCounts[pi][li]++; }
1930
+ }
1931
+ }
1932
+ }
1933
+
1934
+ const avgCorrs = pairs.map((_, pi) =>
1935
+ Array.from({ length: numLags }, (_, li) =>
1936
+ corrCounts[pi][li] > 0 ? corrSums[pi][li] / corrCounts[pi][li] : 0
1937
+ )
1938
+ );
1939
+
1940
+ const ccData = Array.from({ length: numLags }, (_, li) => {
1941
+ const lag = -maxLag + li;
1942
+ const vals = avgCorrs.map(pc => pc[li]);
1943
+ return { lag, max: Math.max(...vals), mean: vals.reduce((a, b) => a + b, 0) / vals.length, min: Math.min(...vals) };
1944
+ });
1945
+
1946
+ let meanPeakLag = 0, meanPeakCorr = -Infinity;
1947
+ let maxPeakLag = 0, maxPeakCorr = -Infinity;
1948
+ let minPeakLag = 0, minPeakCorr = -Infinity;
1949
+ for (const row of ccData) {
1950
+ if (row.max > maxPeakCorr) { maxPeakCorr = row.max; maxPeakLag = row.lag; }
1951
+ if (row.mean > meanPeakCorr) { meanPeakCorr = row.mean; meanPeakLag = row.lag; }
1952
+ if (row.min > minPeakCorr) { minPeakCorr = row.min; minPeakLag = row.lag; }
1953
+ }
1954
+
1955
+ const perPairPeakLags = avgCorrs.map(pc => {
1956
+ let best = -Infinity, bestLag = 0;
1957
+ for (let li = 0; li < pc.length; li++) { if (pc[li] > best) { best = pc[li]; bestLag = -maxLag + li; } }
1958
+ return bestLag;
1959
+ });
1960
+
1961
+ return {
1962
+ ccData, meanPeakLag, meanPeakCorr, maxPeakLag, maxPeakCorr, minPeakLag, minPeakCorr,
1963
+ lagRangeMin: Math.min(...perPairPeakLags), lagRangeMax: Math.max(...perPairPeakLags), numPairs: pairs.length,
1964
+ };
1965
+ })();
1966
+
1967
+ return {
1968
+ actionNames, timeBins, variance, numEpisodes: episodeActions.length,
1969
+ lowMovementEpisodes, aggVelocity, aggAutocorrelation,
1970
+ speedDistribution, multimodality, trajectoryClustering, aggAlignment,
1971
+ };
1972
  }
1973
 
1974
  // Load only flatChartData for a specific episode (used by URDF viewer episode switching)
src/components/action-insights-panel.tsx CHANGED
@@ -1,6 +1,6 @@
1
  "use client";
2
 
3
- import React, { useMemo } from "react";
4
  import {
5
  LineChart,
6
  Line,
@@ -10,7 +10,7 @@ import {
10
  ResponsiveContainer,
11
  Tooltip,
12
  } from "recharts";
13
- import type { CrossEpisodeVarianceData, LowMovementEpisode, AggVelocityStat, AggAutocorrelation } from "@/app/[org]/[dataset]/[episode]/fetch-data";
14
 
15
  const DELIMITER = " | ";
16
  const COLORS = [
@@ -30,6 +30,12 @@ function getActionKeys(row: Record<string, number>): string[] {
30
  .sort();
31
  }
32
 
 
 
 
 
 
 
33
  // ─── Autocorrelation ─────────────────────────────────────────────
34
 
35
  function computeAutocorrelation(values: number[], maxLag: number): number[] {
@@ -507,6 +513,585 @@ function LowMovementSection({ episodes }: { episodes: LowMovementEpisode[] }) {
507
  );
508
  }
509
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
510
  // ─── Main Panel ──────────────────────────────────────────────────
511
 
512
  interface ActionInsightsPanelProps {
@@ -522,18 +1107,45 @@ const ActionInsightsPanel: React.FC<ActionInsightsPanelProps> = ({
522
  crossEpisodeData,
523
  crossEpisodeLoading,
524
  }) => {
 
 
 
525
  return (
526
  <div className="max-w-5xl mx-auto py-6 space-y-8">
527
- <div>
528
- <h2 className="text-xl font-bold text-slate-100">Action Insights</h2>
529
- <p className="text-sm text-slate-400 mt-1">
530
- Data-driven analysis to guide action chunking, data quality assessment, and training configuration.
531
- </p>
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
532
  </div>
533
 
534
- <AutocorrelationSection data={flatChartData} fps={fps} agg={crossEpisodeData?.aggAutocorrelation} numEpisodes={crossEpisodeData?.numEpisodes} />
535
- <ActionVelocitySection data={flatChartData} agg={crossEpisodeData?.aggVelocity} numEpisodes={crossEpisodeData?.numEpisodes} />
 
 
 
 
 
536
  <VarianceHeatmap data={crossEpisodeData} loading={crossEpisodeLoading} />
 
 
 
 
537
  {crossEpisodeData?.lowMovementEpisodes && (
538
  <LowMovementSection episodes={crossEpisodeData.lowMovementEpisodes} />
539
  )}
 
1
  "use client";
2
 
3
+ import React, { useMemo, useState } from "react";
4
  import {
5
  LineChart,
6
  Line,
 
10
  ResponsiveContainer,
11
  Tooltip,
12
  } from "recharts";
13
+ import type { CrossEpisodeVarianceData, LowMovementEpisode, AggVelocityStat, AggAutocorrelation, SpeedDistEntry, TrajectoryClustering, AggAlignment } from "@/app/[org]/[dataset]/[episode]/fetch-data";
14
 
15
  const DELIMITER = " | ";
16
  const COLORS = [
 
30
  .sort();
31
  }
32
 
33
+ function getStateKeys(row: Record<string, number>): string[] {
34
+ return Object.keys(row)
35
+ .filter(k => k.includes("state") && k !== "timestamp" && !k.startsWith("action"))
36
+ .sort();
37
+ }
38
+
39
  // ─── Autocorrelation ─────────────────────────────────────────────
40
 
41
  function computeAutocorrelation(values: number[], maxLag: number): number[] {
 
513
  );
514
  }
515
 
516
+ // ─── Demonstrator Speed Variance ────────────────────────────────
517
+
518
+ function SpeedVarianceSection({ distribution, numEpisodes }: { distribution: SpeedDistEntry[]; numEpisodes: number }) {
519
+ const { speeds, mean, std, cv, median, bins, lo, binW, maxBin, verdict } = useMemo(() => {
520
+ const sp = distribution.map(d => d.speed).sort((a, b) => a - b);
521
+ const m = sp.reduce((a, b) => a + b, 0) / sp.length;
522
+ const s = Math.sqrt(sp.reduce((a, v) => a + (v - m) ** 2, 0) / sp.length);
523
+ const c = m > 0 ? s / m : 0;
524
+ const med = sp[Math.floor(sp.length / 2)];
525
+
526
+ const binCount = Math.min(30, Math.ceil(Math.sqrt(sp.length)));
527
+ const lo = sp[0], hi = sp[sp.length - 1];
528
+ const bw = (hi - lo || 1) / binCount;
529
+ const b = new Array(binCount).fill(0);
530
+ for (const v of sp) { let i = Math.floor((v - lo) / bw); if (i >= binCount) i = binCount - 1; b[i]++; }
531
+
532
+ let v: { label: string; color: string; tip: string };
533
+ if (c < 0.2) v = { label: "Consistent", color: "text-green-400", tip: "Demonstrators execute at similar speeds β€” no velocity normalization needed." };
534
+ else if (c < 0.4) v = { label: "Moderate variance", color: "text-yellow-400", tip: "Some speed variation across demonstrators. Consider velocity normalization for best results." };
535
+ else v = { label: "High variance", color: "text-red-400", tip: "Large speed differences between demonstrations. Velocity normalization before training is strongly recommended." };
536
+
537
+ return { speeds: sp, mean: m, std: s, cv: c, median: med, bins: b, lo, binW: bw, maxBin: Math.max(...b), verdict: v };
538
+ }, [distribution]);
539
+
540
+ if (speeds.length < 3) return null;
541
+
542
+ const barH = 100;
543
+ const barW = Math.max(8, Math.floor(500 / bins.length));
544
+
545
+ return (
546
+ <div className="bg-slate-800/60 rounded-lg p-5 border border-slate-700 space-y-4">
547
+ <div>
548
+ <h3 className="text-sm font-semibold text-slate-200">
549
+ Demonstrator Speed Variance
550
+ <span className="text-xs text-slate-500 ml-2 font-normal">({numEpisodes} episodes)</span>
551
+ </h3>
552
+ <p className="text-xs text-slate-400 mt-1">
553
+ Distribution of average execution speed (mean β€–Ξ”a<sub>t</sub>β€– per frame) across all episodes.
554
+ Different human demonstrators often execute at <span className="text-orange-400">different speeds</span>, creating
555
+ artificial multimodality in the action distribution that confuses the policy. A coefficient of variation (CV) above 0.3
556
+ strongly suggests normalizing trajectory speed before training.
557
+ <br />
558
+ <span className="text-slate-500">
559
+ Based on &quot;Is Diversity All You Need&quot; (AGI-Bot, 2025) which shows velocity normalization dramatically improves
560
+ fine-tuning success rate. Also relates to ACT (Zhao et al., 2023) and Pi0.5 (Physical Intelligence, 2025).
561
+ </span>
562
+ </p>
563
+ </div>
564
+
565
+ <div className="flex gap-4">
566
+ <div className="flex-1 overflow-x-auto">
567
+ <svg width={bins.length * barW} height={barH + 24} className="block">
568
+ {bins.map((count: number, i: number) => {
569
+ const h = maxBin > 0 ? (count / maxBin) * barH : 0;
570
+ const speed = lo + (i + 0.5) * binW;
571
+ const ratio = median > 0 ? speed / median : 1;
572
+ const dev = Math.abs(ratio - 1);
573
+ const color = dev < 0.2 ? "#22c55e" : dev < 0.5 ? "#eab308" : "#ef4444";
574
+ return (
575
+ <rect key={i} x={i * barW} y={barH - h} width={barW - 1} height={Math.max(1, h)} fill={color} opacity={0.7} rx={1}>
576
+ <title>{`Speed ${(lo + i * binW).toFixed(3)}–${(lo + (i + 1) * binW).toFixed(3)}: ${count} ep (${ratio.toFixed(2)}Γ— median)`}</title>
577
+ </rect>
578
+ );
579
+ })}
580
+ {[0, 0.25, 0.5, 0.75, 1].map(frac => {
581
+ const idx = Math.round(frac * (bins.length - 1));
582
+ return (
583
+ <text key={frac} x={idx * barW + barW / 2} y={barH + 14} textAnchor="middle" className="fill-slate-400" fontSize={9}>
584
+ {(lo + idx * binW).toFixed(2)}
585
+ </text>
586
+ );
587
+ })}
588
+ </svg>
589
+ </div>
590
+ <div className="flex flex-col gap-2 text-xs shrink-0 min-w-[120px]">
591
+ <div><span className="text-slate-500">Mean</span> <span className="text-slate-200 tabular-nums ml-1">{mean.toFixed(4)}</span></div>
592
+ <div><span className="text-slate-500">Median</span> <span className="text-slate-200 tabular-nums ml-1">{median.toFixed(4)}</span></div>
593
+ <div><span className="text-slate-500">Std</span> <span className="text-slate-200 tabular-nums ml-1">{std.toFixed(4)}</span></div>
594
+ <div>
595
+ <span className="text-slate-500">CV</span>
596
+ <span className={`tabular-nums ml-1 font-bold ${verdict.color}`}>{cv.toFixed(3)}</span>
597
+ </div>
598
+ </div>
599
+ </div>
600
+
601
+ <div className="bg-slate-900/60 rounded-md px-4 py-3 border border-slate-700/60 space-y-1.5">
602
+ <p className="text-sm font-medium text-slate-200">
603
+ Verdict: <span className={verdict.color}>{verdict.label}</span>
604
+ </p>
605
+ <p className="text-xs text-slate-400">{verdict.tip}</p>
606
+ </div>
607
+ </div>
608
+ );
609
+ }
610
+
611
+ // ─── State–Action Temporal Alignment ────────────────────────────
612
+
613
+ function StateActionAlignmentSection({ data, fps, agg, numEpisodes }: { data: Record<string, number>[]; fps: number; agg?: AggAlignment | null; numEpisodes?: number }) {
614
+ const result = useMemo(() => {
615
+ if (agg) return { ...agg, fromAgg: true };
616
+ if (data.length < 10) return null;
617
+ const actionKeys = getActionKeys(data[0]);
618
+ const stateKeys = getStateKeys(data[0]);
619
+ if (actionKeys.length === 0 || stateKeys.length === 0) return null;
620
+ const maxLag = Math.min(Math.floor(data.length / 4), 30);
621
+ if (maxLag < 2) return null;
622
+
623
+ // Match action↔state by suffix, fall back to index matching
624
+ const pairs: [string, string][] = [];
625
+ for (const aKey of actionKeys) {
626
+ const match = stateKeys.find(sKey => shortName(sKey) === shortName(aKey));
627
+ if (match) pairs.push([aKey, match]);
628
+ }
629
+ if (pairs.length === 0) {
630
+ const count = Math.min(actionKeys.length, stateKeys.length);
631
+ for (let i = 0; i < count; i++) pairs.push([actionKeys[i], stateKeys[i]]);
632
+ }
633
+ if (pairs.length === 0) return null;
634
+
635
+ // Per-pair cross-correlation
636
+ const pairCorrs: number[][] = [];
637
+ for (const [aKey, sKey] of pairs) {
638
+ const aVals = data.map(row => row[aKey] ?? 0);
639
+ const sDeltas = data.slice(1).map((row, i) => (row[sKey] ?? 0) - (data[i][sKey] ?? 0));
640
+ const n = Math.min(aVals.length, sDeltas.length);
641
+ const aM = aVals.slice(0, n).reduce((a, b) => a + b, 0) / n;
642
+ const sM = sDeltas.slice(0, n).reduce((a, b) => a + b, 0) / n;
643
+
644
+ const corrs: number[] = [];
645
+ for (let lag = -maxLag; lag <= maxLag; lag++) {
646
+ let sum = 0, aV = 0, sV = 0;
647
+ for (let t = 0; t < n; t++) {
648
+ const sIdx = t + lag;
649
+ if (sIdx < 0 || sIdx >= sDeltas.length) continue;
650
+ const a = aVals[t] - aM, s = sDeltas[sIdx] - sM;
651
+ sum += a * s; aV += a * a; sV += s * s;
652
+ }
653
+ const d = Math.sqrt(aV * sV);
654
+ corrs.push(d > 0 ? sum / d : 0);
655
+ }
656
+ pairCorrs.push(corrs);
657
+ }
658
+
659
+ // Aggregate min/mean/max per lag
660
+ const ccData = Array.from({ length: 2 * maxLag + 1 }, (_, li) => {
661
+ const lag = -maxLag + li;
662
+ const vals = pairCorrs.map(pc => pc[li]);
663
+ return {
664
+ lag, time: lag / fps,
665
+ max: Math.max(...vals),
666
+ mean: vals.reduce((a, b) => a + b, 0) / vals.length,
667
+ min: Math.min(...vals),
668
+ };
669
+ });
670
+
671
+ // Peaks of the envelope curves
672
+ let meanPeakLag = 0, meanPeakCorr = -Infinity;
673
+ let maxPeakLag = 0, maxPeakCorr = -Infinity;
674
+ let minPeakLag = 0, minPeakCorr = -Infinity;
675
+ for (const row of ccData) {
676
+ if (row.max > maxPeakCorr) { maxPeakCorr = row.max; maxPeakLag = row.lag; }
677
+ if (row.mean > meanPeakCorr) { meanPeakCorr = row.mean; meanPeakLag = row.lag; }
678
+ if (row.min > minPeakCorr) { minPeakCorr = row.min; minPeakLag = row.lag; }
679
+ }
680
+
681
+ // Per-pair individual peak lags (for showing the true range across dimensions)
682
+ const perPairPeakLags = pairCorrs.map(pc => {
683
+ let best = -Infinity, bestLag = 0;
684
+ for (let li = 0; li < pc.length; li++) {
685
+ if (pc[li] > best) { best = pc[li]; bestLag = -maxLag + li; }
686
+ }
687
+ return bestLag;
688
+ });
689
+ const lagRangeMin = Math.min(...perPairPeakLags);
690
+ const lagRangeMax = Math.max(...perPairPeakLags);
691
+
692
+ return { ccData, meanPeakLag, meanPeakCorr, maxPeakLag, maxPeakCorr, minPeakLag, minPeakCorr, lagRangeMin, lagRangeMax, numPairs: pairs.length, fromAgg: false };
693
+ }, [data, fps, agg]);
694
+
695
+ if (!result) return null;
696
+ const { ccData, meanPeakLag, meanPeakCorr, maxPeakLag, maxPeakCorr, minPeakLag, minPeakCorr, lagRangeMin, lagRangeMax, numPairs, fromAgg } = result;
697
+ const scopeLabel = fromAgg ? `${numEpisodes} episodes sampled` : "current episode";
698
+
699
+ return (
700
+ <div className="bg-slate-800/60 rounded-lg p-5 border border-slate-700 space-y-4">
701
+ <div>
702
+ <h3 className="text-sm font-semibold text-slate-200">
703
+ State–Action Temporal Alignment
704
+ <span className="text-xs text-slate-500 ml-2 font-normal">({scopeLabel}, {numPairs} matched pair{numPairs !== 1 ? "s" : ""})</span>
705
+ </h3>
706
+ <p className="text-xs text-slate-400 mt-1">
707
+ Per-dimension cross-correlation between action<sub>d</sub>(t) and Ξ”state<sub>d</sub>(t+lag), aggregated as
708
+ <span className="text-orange-400"> max</span>, <span className="text-slate-200">mean</span>, and
709
+ <span className="text-blue-400"> min</span> across all matched action–state pairs.
710
+ The <span className="text-orange-400">peak lag</span> reveals the effective control delay β€” the time between
711
+ when an action is commanded and when the corresponding state changes.
712
+ <br />
713
+ <span className="text-slate-500">
714
+ Central to ACT (Zhao et al., 2023 β€” action chunking compensates for delay),
715
+ Real-Time Chunking (RTC, 2024), and Training-Time RTC (Biza et al., 2025) β€” all address
716
+ the timing mismatch between commanded actions and observed state changes.
717
+ </span>
718
+ </p>
719
+ </div>
720
+
721
+ {meanPeakLag !== 0 && (
722
+ <div className="flex items-center gap-3 bg-orange-500/10 border border-orange-500/30 rounded-md px-4 py-2.5">
723
+ <span className="text-orange-400 font-bold text-lg tabular-nums">{meanPeakLag}</span>
724
+ <div>
725
+ <p className="text-sm text-orange-300 font-medium">
726
+ Mean control delay: {meanPeakLag} step{Math.abs(meanPeakLag) !== 1 ? "s" : ""} ({(meanPeakLag / fps).toFixed(3)}s)
727
+ </p>
728
+ <p className="text-xs text-slate-400">
729
+ {meanPeakLag > 0
730
+ ? `State changes lag behind actions by ~${meanPeakLag} frames on average. Consider aligning action[t] with state[t+${meanPeakLag}].`
731
+ : `Actions lag behind state changes by ~${-meanPeakLag} frames on average (predictive actions).`}
732
+ {lagRangeMin !== lagRangeMax && ` Individual dimension peaks range from ${lagRangeMin} to ${lagRangeMax} steps.`}
733
+ </p>
734
+ </div>
735
+ </div>
736
+ )}
737
+
738
+ <div className="h-56">
739
+ <ResponsiveContainer width="100%" height="100%">
740
+ <LineChart data={ccData} margin={{ top: 8, right: 16, left: 0, bottom: 16 }}>
741
+ <CartesianGrid strokeDasharray="3 3" stroke="#334155" />
742
+ <XAxis dataKey="lag" stroke="#94a3b8"
743
+ label={{ value: "Lag (steps)", position: "insideBottom", offset: -8, fill: "#94a3b8", fontSize: 11 }} />
744
+ <YAxis stroke="#94a3b8" domain={[-0.5, 1]} />
745
+ <Tooltip
746
+ contentStyle={{ background: "#1e293b", border: "1px solid #475569", borderRadius: 6 }}
747
+ labelFormatter={(v) => `Lag ${v} (${(Number(v) / fps).toFixed(3)}s)`}
748
+ formatter={(v: number) => v.toFixed(3)}
749
+ />
750
+ <Line dataKey="max" stroke="#f97316" dot={false} strokeWidth={2} isAnimationActive={false} name="max" />
751
+ <Line dataKey="mean" stroke="#94a3b8" dot={false} strokeWidth={2} isAnimationActive={false} name="mean" />
752
+ <Line dataKey="min" stroke="#3b82f6" dot={false} strokeWidth={2} isAnimationActive={false} name="min" />
753
+ <Line dataKey={() => 0} stroke="#64748b" strokeDasharray="6 4" dot={false} name="zero" legendType="none" isAnimationActive={false} />
754
+ </LineChart>
755
+ </ResponsiveContainer>
756
+ </div>
757
+
758
+ <div className="flex flex-wrap gap-x-4 gap-y-1 px-1">
759
+ <div className="flex items-center gap-1.5">
760
+ <span className="w-3 h-[3px] rounded-full shrink-0 bg-orange-500" />
761
+ <span className="text-[11px] text-slate-400">max (peak: lag {maxPeakLag}, r={maxPeakCorr.toFixed(3)})</span>
762
+ </div>
763
+ <div className="flex items-center gap-1.5">
764
+ <span className="w-3 h-[3px] rounded-full shrink-0 bg-slate-400" />
765
+ <span className="text-[11px] text-slate-400">mean (peak: lag {meanPeakLag}, r={meanPeakCorr.toFixed(3)})</span>
766
+ </div>
767
+ <div className="flex items-center gap-1.5">
768
+ <span className="w-3 h-[3px] rounded-full shrink-0 bg-blue-500" />
769
+ <span className="text-[11px] text-slate-400">min (peak: lag {minPeakLag}, r={minPeakCorr.toFixed(3)})</span>
770
+ </div>
771
+ </div>
772
+
773
+ {meanPeakLag === 0 && (
774
+ <p className="text-xs text-green-400">
775
+ Mean peak correlation at lag 0 (r={meanPeakCorr.toFixed(3)}) β€” actions and state changes are well-aligned in this episode.
776
+ </p>
777
+ )}
778
+ </div>
779
+ );
780
+ }
781
+
782
+ // ─── Multimodality Detection ────────────────────────────────────
783
+
784
+ const BC_THRESHOLD = 5 / 9;
785
+
786
+ function MultimodalitySection({ data }: { data: CrossEpisodeVarianceData }) {
787
+ const { actionNames, timeBins, multimodality, numEpisodes } = data;
788
+ if (!multimodality || multimodality.length === 0) return null;
789
+
790
+ const numDims = actionNames.length;
791
+ const numBins = timeBins.length;
792
+
793
+ const { bimodalPct, verdict } = useMemo(() => {
794
+ let bimodal = 0, total = 0;
795
+ for (const row of multimodality!) for (const v of row) { total++; if (v > BC_THRESHOLD) bimodal++; }
796
+ const pct = total > 0 ? (bimodal / total * 100) : 0;
797
+
798
+ let v: { label: string; color: string };
799
+ if (pct < 10) v = { label: "Mostly Unimodal", color: "text-green-400" };
800
+ else if (pct < 30) v = { label: "Some Multimodality", color: "text-yellow-400" };
801
+ else v = { label: "Significantly Multimodal", color: "text-red-400" };
802
+
803
+ return { bimodalPct: pct, verdict: v };
804
+ }, [multimodality]);
805
+
806
+ const cellW = Math.max(6, Math.min(14, Math.floor(560 / numBins)));
807
+ const cellH = Math.max(20, Math.min(36, Math.floor(300 / numDims)));
808
+ const labelW = 100;
809
+ const svgW = labelW + numBins * cellW + 60;
810
+ const svgH = numDims * cellH + 40;
811
+
812
+ function bcColor(bc: number): string {
813
+ if (bc < 0.4) {
814
+ const t = bc / 0.4;
815
+ return `rgb(${Math.round(34 + t * 200)}, ${Math.round(197 - t * 50)}, ${Math.round(94 - t * 50)})`;
816
+ }
817
+ const t = Math.min(1, (bc - 0.4) / 0.4);
818
+ return `rgb(${Math.round(234 + t * 5)}, ${Math.round(147 - t * 79)}, ${Math.round(44 + t * 24)})`;
819
+ }
820
+
821
+ return (
822
+ <div className="bg-slate-800/60 rounded-lg p-5 border border-slate-700 space-y-4">
823
+ <div>
824
+ <h3 className="text-sm font-semibold text-slate-200">
825
+ Multimodality Detection
826
+ <span className="text-xs text-slate-500 ml-2 font-normal">({numEpisodes} episodes sampled)</span>
827
+ </h3>
828
+ <p className="text-xs text-slate-400 mt-1">
829
+ Bimodality coefficient (BC) per action dimension over episode progress.
830
+ BC values above <span className="text-red-400">5/9 β‰ˆ 0.556</span> suggest the action distribution at that point is bimodal β€”
831
+ meaning demonstrators use <span className="text-red-400">multiple distinct strategies</span>. This directly answers:
832
+ &quot;Do I need a generative policy (diffusion, flow-matching) or would MSE regression work?&quot;
833
+ <br />
834
+ <span className="text-slate-500">
835
+ Grounded in Diffusion Policy (Chi et al., 2023 β€” diffusion handles multimodality natively),
836
+ ACT (Zhao et al., 2023 β€” CVAE captures multiple modes). Extends the cross-episode variance heatmap above
837
+ by distinguishing true multimodality from mere noise.
838
+ </span>
839
+ </p>
840
+ </div>
841
+
842
+ <div className="overflow-x-auto">
843
+ <svg width={svgW} height={svgH} className="block">
844
+ {multimodality.map((row, bi) => row.map((bc, di) => (
845
+ <rect key={`${bi}-${di}`} x={labelW + bi * cellW} y={di * cellH} width={cellW} height={cellH}
846
+ fill={bcColor(bc)} stroke="#1e293b" strokeWidth={0.5}>
847
+ <title>{`${shortName(actionNames[di])} @ ${(timeBins[bi] * 100).toFixed(0)}%: BC=${bc.toFixed(3)} ${bc > BC_THRESHOLD ? "(bimodal)" : "(unimodal)"}`}</title>
848
+ </rect>
849
+ )))}
850
+ {actionNames.map((name, di) => (
851
+ <text key={di} x={labelW - 4} y={di * cellH + cellH / 2} textAnchor="end" dominantBaseline="central"
852
+ className="fill-slate-400" fontSize={Math.min(11, cellH - 4)}>{shortName(name)}</text>
853
+ ))}
854
+ {[0, 0.25, 0.5, 0.75, 1].map(frac => {
855
+ const binIdx = Math.round(frac * (numBins - 1));
856
+ return (
857
+ <text key={frac} x={labelW + binIdx * cellW + cellW / 2} y={numDims * cellH + 14}
858
+ textAnchor="middle" className="fill-slate-400" fontSize={9}>{(frac * 100).toFixed(0)}%</text>
859
+ );
860
+ })}
861
+ <text x={labelW + (numBins * cellW) / 2} y={numDims * cellH + 30}
862
+ textAnchor="middle" className="fill-slate-500" fontSize={10}>Episode progress</text>
863
+ {Array.from({ length: 10 }, (_, i) => {
864
+ const t = i / 9;
865
+ const barX = labelW + numBins * cellW + 16;
866
+ const barCellH = (numDims * cellH) / 10;
867
+ return <rect key={i} x={barX} y={(9 - i) * barCellH} width={12} height={barCellH} fill={bcColor(t)} />;
868
+ })}
869
+ <text x={labelW + numBins * cellW + 34} y={10} className="fill-slate-500" fontSize={8}
870
+ dominantBaseline="central">bimodal</text>
871
+ <text x={labelW + numBins * cellW + 34} y={numDims * cellH - 4} className="fill-slate-500" fontSize={8}
872
+ dominantBaseline="central">unimodal</text>
873
+ </svg>
874
+ </div>
875
+
876
+ <div className="bg-slate-900/60 rounded-md px-4 py-3 border border-slate-700/60 space-y-1.5">
877
+ <p className="text-sm font-medium text-slate-200">
878
+ Assessment: <span className={verdict.color}>{verdict.label}</span>
879
+ <span className="text-xs text-slate-500 ml-2">{bimodalPct.toFixed(1)}% of regions above threshold</span>
880
+ </p>
881
+ <p className="text-xs text-slate-400">
882
+ {bimodalPct < 10
883
+ ? "Action distributions are mostly unimodal β€” MSE regression or simple flow-matching should work well."
884
+ : bimodalPct < 30
885
+ ? "Moderate multimodality detected. A generative policy (diffusion/flow-matching) will likely outperform MSE regression in the highlighted regions."
886
+ : "Significant multimodality across the trajectory. A generative policy (diffusion or flow-matching action head) is strongly recommended over MSE regression."}
887
+ </p>
888
+ </div>
889
+ </div>
890
+ );
891
+ }
892
+
893
+ // ─── Trajectory Clustering & Outlier Detection ──────────────────
894
+
895
+ const CLUSTER_COLORS = ["#f97316", "#3b82f6", "#22c55e", "#a855f7", "#eab308"];
896
+
897
+ function TrajectoryClusteringSection({ data, numEpisodes }: { data: TrajectoryClustering; numEpisodes: number }) {
898
+ const { entries, numClusters, clusterSizes, outlierCount } = data;
899
+
900
+ const [showAll, setShowAll] = useState(false);
901
+ const [rotX, setRotX] = useState(-0.4);
902
+ const [rotY, setRotY] = useState(0.6);
903
+ const dragRef = React.useRef<{ x: number; y: number; rx: number; ry: number } | null>(null);
904
+
905
+ const sorted = useMemo(() =>
906
+ [...entries].sort((a, b) => b.distFromCenter - a.distFromCenter),
907
+ [entries]);
908
+
909
+ const plotW = 500, plotH = 400;
910
+ const cx = plotW / 2, cy = plotH / 2;
911
+ const scale = Math.min(plotW, plotH) * 0.35;
912
+
913
+ // Normalize xyz to [-1, 1]
914
+ const bounds = useMemo(() => {
915
+ let xMin = Infinity, xMax = -Infinity, yMin = Infinity, yMax = -Infinity, zMin = Infinity, zMax = -Infinity;
916
+ for (const e of entries) {
917
+ if (e.x < xMin) xMin = e.x; if (e.x > xMax) xMax = e.x;
918
+ if (e.y < yMin) yMin = e.y; if (e.y > yMax) yMax = e.y;
919
+ if (e.z < zMin) zMin = e.z; if (e.z > zMax) zMax = e.z;
920
+ }
921
+ const r = Math.max(xMax - xMin, yMax - yMin, zMax - zMin) / 2 || 1;
922
+ return { mx: (xMin + xMax) / 2, my: (yMin + yMax) / 2, mz: (zMin + zMax) / 2, r };
923
+ }, [entries]);
924
+
925
+ // Project 3D β†’ 2D with rotation
926
+ const projected = useMemo(() => {
927
+ const cosX = Math.cos(rotX), sinX = Math.sin(rotX);
928
+ const cosY = Math.cos(rotY), sinY = Math.sin(rotY);
929
+ return entries.map(e => {
930
+ const nx = (e.x - bounds.mx) / bounds.r;
931
+ const ny = (e.y - bounds.my) / bounds.r;
932
+ const nz = (e.z - bounds.mz) / bounds.r;
933
+ // Rotate around Y then X
934
+ const x1 = nx * cosY + nz * sinY;
935
+ const z1 = -nx * sinY + nz * cosY;
936
+ const y1 = ny * cosX - z1 * sinX;
937
+ const z2 = ny * sinX + z1 * cosX;
938
+ return { sx: cx + x1 * scale, sy: cy - y1 * scale, depth: z2, entry: e };
939
+ });
940
+ }, [entries, rotX, rotY, bounds, cx, cy, scale]);
941
+
942
+ // Sort by depth (back to front) for correct overlap
943
+ const sortedByDepth = useMemo(() => [...projected].sort((a, b) => a.depth - b.depth), [projected]);
944
+
945
+ const onMouseDown = (ev: React.MouseEvent) => {
946
+ dragRef.current = { x: ev.clientX, y: ev.clientY, rx: rotX, ry: rotY };
947
+ };
948
+ const onMouseMove = (ev: React.MouseEvent) => {
949
+ if (!dragRef.current) return;
950
+ setRotY(dragRef.current.ry + (ev.clientX - dragRef.current.x) * 0.005);
951
+ setRotX(dragRef.current.rx + (ev.clientY - dragRef.current.y) * 0.005);
952
+ };
953
+ const onMouseUp = () => { dragRef.current = null; };
954
+
955
+ // Axis lines
956
+ const axisLines = useMemo(() => {
957
+ const cosX = Math.cos(rotX), sinX = Math.sin(rotX);
958
+ const cosY = Math.cos(rotY), sinY = Math.sin(rotY);
959
+ const project = (x: number, y: number, z: number) => {
960
+ const x1 = x * cosY + z * sinY;
961
+ const z1 = -x * sinY + z * cosY;
962
+ const y1 = y * cosX - z1 * sinX;
963
+ return { px: cx + x1 * scale * 0.9, py: cy - y1 * scale * 0.9 };
964
+ };
965
+ const len = 1.1;
966
+ return [
967
+ { ...project(len, 0, 0), label: "PC1", color: "#64748b" },
968
+ { ...project(0, len, 0), label: "PC2", color: "#64748b" },
969
+ { ...project(0, 0, len), label: "PC3", color: "#64748b" },
970
+ ].map(a => ({ ...a, ox: cx, oy: cy }));
971
+ }, [rotX, rotY, cx, cy, scale]);
972
+
973
+ const imbalance = useMemo(() => {
974
+ const max = Math.max(...clusterSizes), min = Math.min(...clusterSizes);
975
+ return max > 0 ? (max - min) / max : 0;
976
+ }, [clusterSizes]);
977
+
978
+ return (
979
+ <div className="bg-slate-800/60 rounded-lg p-5 border border-slate-700 space-y-4">
980
+ <div>
981
+ <h3 className="text-sm font-semibold text-slate-200">
982
+ Trajectory Clustering & Outlier Detection
983
+ <span className="text-xs text-slate-500 ml-2 font-normal">({numEpisodes} episodes sampled)</span>
984
+ </h3>
985
+ <p className="text-xs text-slate-400 mt-1">
986
+ Episodes clustered by trajectory similarity: each episode&apos;s action trajectory is time-normalized, standardized,
987
+ and projected to 3D via PCA. K-means clustering (k selected by silhouette score) groups similar demonstrations.
988
+ <span className="text-red-400"> Outlier episodes</span> ({">"} 2Οƒ from cluster center) may indicate recording errors,
989
+ failed demonstrations, or fundamentally different strategies worth reviewing.
990
+ <span className="text-yellow-400"> Imbalanced clusters</span> suggest multimodal demonstrations.
991
+ Drag to rotate.
992
+ <br />
993
+ <span className="text-slate-500">
994
+ Grounded in FAST-UMI-100K (Zhao et al., 2025 β€” automatic quality tools at scale),
995
+ &quot;Curating Demonstrations using Online Experience&quot; (Burns et al., 2025),
996
+ GVL (Mazzaglia et al., 2024), and SARM (Li et al., 2025).
997
+ </span>
998
+ </p>
999
+ </div>
1000
+
1001
+ <div className="flex gap-4 flex-wrap">
1002
+ <div className="flex-1 min-w-[340px]">
1003
+ <svg
1004
+ width={plotW} height={plotH}
1005
+ className="block bg-slate-900/50 rounded cursor-grab active:cursor-grabbing select-none"
1006
+ onMouseDown={onMouseDown} onMouseMove={onMouseMove} onMouseUp={onMouseUp} onMouseLeave={onMouseUp}
1007
+ >
1008
+ {/* Axis lines */}
1009
+ {axisLines.map(a => (
1010
+ <React.Fragment key={a.label}>
1011
+ <line x1={a.ox} y1={a.oy} x2={a.px} y2={a.py} stroke={a.color} strokeWidth={0.5} strokeDasharray="4 3" opacity={0.4} />
1012
+ <text x={a.px} y={a.py} className="fill-slate-600" fontSize={9} textAnchor="middle" dominantBaseline="central">{a.label}</text>
1013
+ </React.Fragment>
1014
+ ))}
1015
+ {/* Points sorted back→front */}
1016
+ {sortedByDepth.map(({ sx, sy, depth, entry: e }, i) => {
1017
+ const color = CLUSTER_COLORS[e.cluster % CLUSTER_COLORS.length];
1018
+ const depthFade = 0.3 + 0.7 * ((depth + 1) / 2);
1019
+ const r = e.isOutlier ? 5 : 2.5 + depthFade * 2;
1020
+ return (
1021
+ <circle key={i} cx={sx} cy={sy} r={r}
1022
+ fill={e.isOutlier ? "transparent" : color}
1023
+ stroke={e.isOutlier ? "#ef4444" : color}
1024
+ strokeWidth={e.isOutlier ? 2 : 0}
1025
+ opacity={e.isOutlier ? 1 : depthFade * 0.8}>
1026
+ <title>{`ep ${e.episodeIndex} β€” cluster ${e.cluster}${e.isOutlier ? " (outlier)" : ""}, dist=${e.distFromCenter.toFixed(2)}`}</title>
1027
+ </circle>
1028
+ );
1029
+ })}
1030
+ </svg>
1031
+ </div>
1032
+
1033
+ <div className="flex flex-col gap-3 text-xs shrink-0 min-w-[160px]">
1034
+ <div>
1035
+ <p className="text-slate-500 mb-1">Clusters: {numClusters}</p>
1036
+ {clusterSizes.map((size, c) => (
1037
+ <div key={c} className="flex items-center gap-2 py-0.5">
1038
+ <span className="w-2.5 h-2.5 rounded-full shrink-0" style={{ background: CLUSTER_COLORS[c % CLUSTER_COLORS.length] }} />
1039
+ <span className="text-slate-300">Cluster {c}</span>
1040
+ <span className="text-slate-500 tabular-nums ml-auto">{size} ep</span>
1041
+ </div>
1042
+ ))}
1043
+ </div>
1044
+ {outlierCount > 0 && (
1045
+ <div className="flex items-center gap-2">
1046
+ <span className="w-2.5 h-2.5 rounded-full shrink-0 border-2 border-red-500" />
1047
+ <span className="text-red-400">{outlierCount} outlier{outlierCount !== 1 ? "s" : ""}</span>
1048
+ </div>
1049
+ )}
1050
+ {imbalance > 0.5 && (
1051
+ <p className="text-yellow-400 text-[11px]">
1052
+ Clusters are imbalanced ({(imbalance * 100).toFixed(0)}% size ratio) β€” the dataset may contain multiple distinct strategies.
1053
+ </p>
1054
+ )}
1055
+ </div>
1056
+ </div>
1057
+
1058
+ <div className="bg-slate-900/60 rounded-md px-4 py-3 border border-slate-700/60 space-y-2">
1059
+ <div className="flex items-center justify-between">
1060
+ <p className="text-sm font-medium text-slate-200">
1061
+ {showAll ? "All Episodes" : "Most Anomalous Episodes"} <span className="text-xs text-slate-500 font-normal">sorted by distance from cluster center</span>
1062
+ </p>
1063
+ <button onClick={() => setShowAll(v => !v)} className="text-xs text-slate-400 hover:text-slate-200 transition-colors">
1064
+ {showAll ? "Show top 15" : `Show all ${entries.length}`}
1065
+ </button>
1066
+ </div>
1067
+ <div className="max-h-48 overflow-y-auto">
1068
+ <table className="w-full text-xs">
1069
+ <thead>
1070
+ <tr className="text-slate-500 border-b border-slate-700">
1071
+ <th className="text-left py-1 pr-3">Episode</th>
1072
+ <th className="text-left py-1 pr-3">Cluster</th>
1073
+ <th className="text-right py-1">Distance</th>
1074
+ </tr>
1075
+ </thead>
1076
+ <tbody>
1077
+ {(showAll ? sorted : sorted.slice(0, 15)).map(e => (
1078
+ <tr key={e.episodeIndex} className={`border-b border-slate-800/40 ${e.isOutlier ? "text-red-400" : "text-slate-300"}`}>
1079
+ <td className="py-1 pr-3">ep {e.episodeIndex}{e.isOutlier ? " ⚠" : ""}</td>
1080
+ <td className="py-1 pr-3">
1081
+ <span className="inline-block w-2 h-2 rounded-full mr-1.5" style={{ background: CLUSTER_COLORS[e.cluster % CLUSTER_COLORS.length] }} />
1082
+ {e.cluster}
1083
+ </td>
1084
+ <td className="py-1 text-right tabular-nums">{e.distFromCenter.toFixed(2)}</td>
1085
+ </tr>
1086
+ ))}
1087
+ </tbody>
1088
+ </table>
1089
+ </div>
1090
+ </div>
1091
+ </div>
1092
+ );
1093
+ }
1094
+
1095
  // ─── Main Panel ──────────────────────────────────────────────────
1096
 
1097
  interface ActionInsightsPanelProps {
 
1107
  crossEpisodeData,
1108
  crossEpisodeLoading,
1109
  }) => {
1110
+ const [mode, setMode] = useState<"episode" | "dataset">("dataset");
1111
+ const showAgg = mode === "dataset" && !!crossEpisodeData;
1112
+
1113
  return (
1114
  <div className="max-w-5xl mx-auto py-6 space-y-8">
1115
+ <div className="flex items-center justify-between flex-wrap gap-4">
1116
+ <div>
1117
+ <h2 className="text-xl font-bold text-slate-100">Action Insights</h2>
1118
+ <p className="text-sm text-slate-400 mt-1">
1119
+ Data-driven analysis to guide action chunking, data quality assessment, and training configuration.
1120
+ </p>
1121
+ </div>
1122
+ <div className="flex items-center gap-3 shrink-0">
1123
+ <span className={`text-sm ${mode === "episode" ? "text-slate-100 font-medium" : "text-slate-500"}`}>Current Episode</span>
1124
+ <button
1125
+ onClick={() => setMode(m => m === "episode" ? "dataset" : "episode")}
1126
+ className={`relative inline-flex items-center w-9 h-5 rounded-full transition-colors shrink-0 ${mode === "dataset" ? "bg-orange-500" : "bg-slate-600"}`}
1127
+ aria-label="Toggle episode/dataset scope"
1128
+ >
1129
+ <span className={`inline-block w-3.5 h-3.5 bg-white rounded-full transition-transform ${mode === "dataset" ? "translate-x-[18px]" : "translate-x-[3px]"}`} />
1130
+ </button>
1131
+ <span className={`text-sm ${mode === "dataset" ? "text-slate-100 font-medium" : "text-slate-500"}`}>
1132
+ All Episodes{crossEpisodeData ? ` (${crossEpisodeData.numEpisodes})` : ""}
1133
+ </span>
1134
+ </div>
1135
  </div>
1136
 
1137
+ <AutocorrelationSection data={flatChartData} fps={fps} agg={showAgg ? crossEpisodeData?.aggAutocorrelation : null} numEpisodes={crossEpisodeData?.numEpisodes} />
1138
+ <ActionVelocitySection data={flatChartData} agg={showAgg ? crossEpisodeData?.aggVelocity : undefined} numEpisodes={crossEpisodeData?.numEpisodes} />
1139
+
1140
+ {crossEpisodeData?.speedDistribution && crossEpisodeData.speedDistribution.length > 2 && (
1141
+ <SpeedVarianceSection distribution={crossEpisodeData.speedDistribution} numEpisodes={crossEpisodeData.numEpisodes} />
1142
+ )}
1143
+ <StateActionAlignmentSection data={flatChartData} fps={fps} agg={showAgg ? crossEpisodeData?.aggAlignment : null} numEpisodes={crossEpisodeData?.numEpisodes} />
1144
  <VarianceHeatmap data={crossEpisodeData} loading={crossEpisodeLoading} />
1145
+ {crossEpisodeData && <MultimodalitySection data={crossEpisodeData} />}
1146
+ {crossEpisodeData?.trajectoryClustering && (
1147
+ <TrajectoryClusteringSection data={crossEpisodeData.trajectoryClustering} numEpisodes={crossEpisodeData.numEpisodes} />
1148
+ )}
1149
  {crossEpisodeData?.lowMovementEpisodes && (
1150
  <LowMovementSection episodes={crossEpisodeData.lowMovementEpisodes} />
1151
  )}