visualize_dataset / src /components /urdf-viewer.tsx
mishig's picture
mishig HF Staff
Sync from GitHub via hub-sync
5e6e3e2 verified
raw
history blame
31.3 kB
"use client";
import React, {
useState,
useEffect,
useRef,
useMemo,
useCallback,
} from "react";
import { Canvas, useThree, useFrame } from "@react-three/fiber";
import { OrbitControls, Grid, Html } from "@react-three/drei";
import * as THREE from "three";
import URDFLoader from "urdf-loader";
import type { URDFRobot } from "urdf-loader";
import { STLLoader } from "three/examples/jsm/loaders/STLLoader.js";
import { ColladaLoader } from "three/examples/jsm/loaders/ColladaLoader.js";
import { Line2 } from "three/examples/jsm/lines/Line2.js";
import { LineMaterial } from "three/examples/jsm/lines/LineMaterial.js";
import { LineGeometry } from "three/examples/jsm/lines/LineGeometry.js";
import type { EpisodeData } from "@/app/[org]/[dataset]/[episode]/fetch-data";
import { loadEpisodeFlatChartData } from "@/app/[org]/[dataset]/[episode]/fetch-data";
import UrdfPlaybackBar from "@/components/urdf-playback-bar";
import { CHART_CONFIG } from "@/utils/constants";
import { getDatasetVersionAndInfo } from "@/utils/versionUtils";
import type { DatasetMetadata } from "@/utils/parquetUtils";
// Separator between a feature-group prefix and a column name in flat chart
// keys (split on in groupColumnsByPrefix / autoMatchJoints).
const SERIES_DELIM = CHART_CONFIG.SERIES_NAME_DELIMITER;
// Multiply a value in degrees by this to get radians.
const DEG2RAD = Math.PI / 180;
// Module-level geometry cache — survives component remounts (tab switches,
// episode navigations). Avoids re-fetching and re-parsing STL files.
// Keyed by mesh URL.
const stlGeometryCache = new Map<string, THREE.BufferGeometry>();
// In-flight promise cache — prevents duplicate simultaneous fetches.
// Keyed by mesh URL; entries are removed once the load settles successfully.
const stlGeometryLoading = new Map<string, Promise<THREE.BufferGeometry>>();
/**
 * Chooses the URDF model and display scale for a dataset's `robot_type`.
 *
 * Matching is a case-insensitive substring test; rules are checked in order
 * and the first hit wins. Anything unrecognised (including `null`) falls back
 * to the SO-101 model.
 *
 * @param robotType Robot type string from the dataset metadata, if any.
 * @returns The URDF URL to load and the uniform scale to apply to it.
 */
function getRobotConfig(robotType: string | null) {
  const name = (robotType ?? "").toLowerCase();
  const has = (token: string) => name.includes(token);

  const rules: { matches: boolean; urdfUrl: string; scale: number }[] = [
    {
      matches: has("g1") || has("unitree"),
      urdfUrl: "/urdf/g1/g1_body29_hand14.urdf",
      scale: 1,
    },
    {
      matches: has("openarm"),
      urdfUrl: "/urdf/openarm/openarm_bimanual.urdf",
      scale: 3,
    },
    {
      matches: has("so100") && !has("so101"),
      urdfUrl: "/urdf/so101/so100.urdf",
      scale: 10,
    },
  ];

  const hit = rules.find((rule) => rule.matches);
  return hit
    ? { urdfUrl: hit.urdfUrl, scale: hit.scale }
    : { urdfUrl: "/urdf/so101/so101_new_calib.urdf", scale: 10 };
}
/**
 * Converts one set of joint readings to radians, guessing the unit from the
 * largest absolute value among them:
 *  - peak > 360 → servo ticks centred on 2048 (2048 ↦ 0, 0 / 4096 ↦ ∓π);
 *  - peak > 6.3 → degrees;
 *  - otherwise the values are taken to be radians already and the input
 *    array itself is returned.
 *
 * @param values Raw joint readings (all in the same unit).
 * @returns Values in radians.
 */
function detectAndConvert(values: number[]): number[] {
  if (!values.length) return values;

  const peak = values.reduce(
    (acc, v) => Math.max(acc, Math.abs(v)),
    -Infinity,
  );

  let toRadians: ((v: number) => number) | null = null;
  if (peak > 360) {
    toRadians = (v) => ((v - 2048) / 2048) * Math.PI; // servo ticks
  } else if (peak > 6.3) {
    toRadians = (v) => v * DEG2RAD; // degrees
  }

  return toRadians ? values.map((v) => toRadians!(v)) : values;
}
/**
 * Buckets flat chart column keys by the text before the first SERIES_DELIM.
 *
 * Keys without a delimiter go into the "other" bucket; the "timestamp"
 * column is skipped. Within each bucket the input order is kept.
 *
 * @param keys Column keys of one chart-data row.
 * @returns Group prefix (trimmed) → column keys in that group.
 */
function groupColumnsByPrefix(keys: string[]): Record<string, string[]> {
  const buckets: Record<string, string[]> = {};
  keys.forEach((column) => {
    if (column === "timestamp") return;
    const [head, ...rest] = column.split(SERIES_DELIM);
    const group = rest.length > 0 ? head.trim() : "other";
    if (!buckets[group]) buckets[group] = [];
    buckets[group].push(column);
  });
  return buckets;
}
// Unitree G1 SDK column suffix (lower-cased, as compared in autoMatchJoints)
// → URDF joint name. Each key is "k" + the joint name without "_joint" and
// without underscores + ".q".
const G1_SDK_TO_URDF: Record<string, string> = Object.fromEntries([
  // Legs
  ["klefthippitch.q", "left_hip_pitch_joint"],
  ["klefthiproll.q", "left_hip_roll_joint"],
  ["klefthipyaw.q", "left_hip_yaw_joint"],
  ["kleftknee.q", "left_knee_joint"],
  ["kleftanklepitch.q", "left_ankle_pitch_joint"],
  ["kleftankleroll.q", "left_ankle_roll_joint"],
  ["krighthippitch.q", "right_hip_pitch_joint"],
  ["krighthiproll.q", "right_hip_roll_joint"],
  ["krighthipyaw.q", "right_hip_yaw_joint"],
  ["krightknee.q", "right_knee_joint"],
  ["krightanklepitch.q", "right_ankle_pitch_joint"],
  ["krightankleroll.q", "right_ankle_roll_joint"],
  // Waist
  ["kwaistyaw.q", "waist_yaw_joint"],
  ["kwaistroll.q", "waist_roll_joint"],
  ["kwaistpitch.q", "waist_pitch_joint"],
  // Arms
  ["kleftshoulderpitch.q", "left_shoulder_pitch_joint"],
  ["kleftshoulderroll.q", "left_shoulder_roll_joint"],
  ["kleftshoulderyaw.q", "left_shoulder_yaw_joint"],
  ["kleftelbow.q", "left_elbow_joint"],
  ["kleftwristroll.q", "left_wrist_roll_joint"],
  ["kleftwristpitch.q", "left_wrist_pitch_joint"],
  ["kleftwristyaw.q", "left_wrist_yaw_joint"],
  ["krightshoulderpitch.q", "right_shoulder_pitch_joint"],
  ["krightshoulderroll.q", "right_shoulder_roll_joint"],
  ["krightshoulderyaw.q", "right_shoulder_yaw_joint"],
  ["krightelbow.q", "right_elbow_joint"],
  ["krightwristroll.q", "right_wrist_roll_joint"],
  ["krightwristpitch.q", "right_wrist_pitch_joint"],
  ["krightwristyaw.q", "right_wrist_yaw_joint"],
]);
/**
 * Heuristically pairs URDF joint names with dataset columns.
 *
 * For each joint the first strategy that yields a column wins:
 *  1. a column whose trailing segment (after the last SERIES_DELIM, trimmed,
 *     lower-cased) equals the lower-cased joint name;
 *  2. a Unitree G1 SDK column, translated through G1_SDK_TO_URDF;
 *  3. OpenArm arm joints `openarm_<side>_joint<N>` → first column whose
 *     trailing segment contains `<side>_joint_<N>`;
 *  4. OpenArm grippers `openarm_<side>_finger_joint1` → first column whose
 *     trailing segment contains `<side>_gripper`;
 *  5. otherwise, unless the joint is a `finger_joint2` mimic joint, the first
 *     column whose full lower-cased key contains the joint name.
 *
 * @param urdfJointNames Joint names from the loaded URDF.
 * @param columnKeys Candidate dataset column keys.
 * @returns Joint name → column key; joints with no match are absent.
 */
function autoMatchJoints(
  urdfJointNames: string[],
  columnKeys: string[],
): Record<string, string> {
  const armJointRe = /^openarm_(left|right)_joint(\d+)$/;
  const fingerJointRe = /^openarm_(left|right)_finger_joint1$/;

  // Lower-cased trailing segment of each key, index-aligned with columnKeys.
  const tails = columnKeys.map((key) =>
    (key.split(SERIES_DELIM).pop()?.trim() ?? key).toLowerCase(),
  );

  // URDF joint name → column key for G1 SDK-style columns; when two columns
  // translate to the same joint, the later one is kept.
  const g1ColumnFor = new Map<string, string>();
  tails.forEach((tail, i) => {
    const urdfJoint = G1_SDK_TO_URDF[tail];
    if (urdfJoint) g1ColumnFor.set(urdfJoint, columnKeys[i]);
  });

  // First column whose trailing segment passes `test`, or undefined.
  const firstColumnWhere = (test: (tail: string) => boolean) => {
    const at = tails.findIndex(test);
    return at >= 0 ? columnKeys[at] : undefined;
  };

  const result: Record<string, string> = {};
  for (const jointName of urdfJointNames) {
    const needle = jointName.toLowerCase();
    const arm = armJointRe.exec(needle);
    const finger = fingerJointRe.exec(needle);

    const matched =
      firstColumnWhere((tail) => tail === needle) ??
      g1ColumnFor.get(needle) ??
      (arm
        ? firstColumnWhere((tail) =>
            tail.includes(`${arm[1]}_joint_${arm[2]}`),
          )
        : undefined) ??
      (finger
        ? firstColumnWhere((tail) => tail.includes(`${finger[1]}_gripper`))
        : undefined);

    if (matched !== undefined) {
      result[jointName] = matched;
    } else if (!needle.includes("finger_joint2")) {
      // finger_joint2 is a mimic joint and stays unmapped; everything else
      // gets a loose substring match as a last resort.
      const loose = columnKeys.find((key) =>
        key.toLowerCase().includes(needle),
      );
      if (loose) result[jointName] = loose;
    }
  }
  return result;
}
// End-effector frame names looked up in the loaded robot's `frames` to
// anchor motion trails.

// Single-arm robots: candidates in priority order; only the first one the
// robot actually has is used.
const SINGLE_ARM_TIP_NAMES: string[] = [
  "gripper_frame_link",
  "gripperframe",
  "gripper_link",
  "gripper",
];

// OpenArm bimanual: one frame per arm.
const DUAL_ARM_TIP_NAMES: string[] = [
  "openarm_left_hand_tcp",
  "openarm_right_hand_tcp",
];

// Unitree G1: both palm links.
const G1_TIP_NAMES: string[] = [
  "left_hand_palm_link",
  "right_hand_palm_link",
];
// Seconds for a trail sample to fade from its full colour to black.
const TRAIL_DURATION = 1.0;
// Maximum samples kept per trail; older samples are shifted out.
const MAX_TRAIL_POINTS = 300;
// Trail colour per tip index (cycled if there are more tips than colours).
const TRAIL_COLORS = ["#ff6600", "#00aaff"].map(
  (hex) => new THREE.Color(hex),
);
// ─── Robot scene (imperative, inside Canvas) ───
/**
 * Loads the URDF at `urdfUrl` into the enclosing react-three-fiber scene,
 * applies `jointValues` to it on every frame and, when `trailEnabled`, draws
 * fading polyline trails behind the robot's end-effector frames.
 *
 * The robot and the trail lines are added to / removed from the scene
 * imperatively; the component itself only renders an HTML overlay while the
 * URDF is loading or after it failed to load.
 *
 * @param urdfUrl URL of the URDF file to load.
 * @param jointValues Joint name → value, applied with `setJointValue` each frame.
 * @param onJointsLoaded Called after a successful load with the names of all
 *   revolute, continuous and prismatic joints.
 * @param trailEnabled Whether end-effector trails are drawn.
 * @param trailResetKey Any change clears the existing trails.
 * @param scale Uniform scale applied to the loaded robot.
 */
function RobotScene({
  urdfUrl,
  jointValues,
  onJointsLoaded,
  trailEnabled,
  trailResetKey,
  scale,
}: {
  urdfUrl: string;
  jointValues: Record<string, number>;
  onJointsLoaded: (names: string[]) => void;
  trailEnabled: boolean;
  trailResetKey: number;
  scale: number;
}) {
  const { scene, size } = useThree();
  const robotRef = useRef<URDFRobot | null>(null);
  const tipLinksRef = useRef<THREE.Object3D[]>([]);
  const [loading, setLoading] = useState(true);
  const [error, setError] = useState<string | null>(null);

  // Per-tip trail history: `count` xyz positions / rgb colours (oldest
  // first) plus the matching sample times in seconds.
  type TrailState = {
    positions: Float32Array;
    colors: Float32Array;
    times: number[];
    count: number;
  };
  const trailsRef = useRef<TrailState[]>([]);
  const linesRef = useRef<Line2[]>([]);
  const trailMatsRef = useRef<LineMaterial[]>([]);
  const trailCountRef = useRef(0);

  // Reset trails when episode changes
  useEffect(() => {
    for (const t of trailsRef.current) {
      t.count = 0;
      t.times = [];
    }
    for (const l of linesRef.current) l.visible = false;
  }, [trailResetKey]);

  // Remove every trail line from the scene and free its geometry/material.
  const disposeTrails = useCallback(() => {
    for (const l of linesRef.current) {
      scene.remove(l);
      l.geometry.dispose();
    }
    for (const m of trailMatsRef.current) m.dispose();
    trailsRef.current = [];
    linesRef.current = [];
    trailMatsRef.current = [];
    trailCountRef.current = 0;
  }, [scene]);

  // Create/destroy trail Line2 objects when tip count changes
  const ensureTrails = useCallback(
    (count: number) => {
      if (trailCountRef.current === count) return;
      disposeTrails();
      const trails: TrailState[] = [];
      const lines: Line2[] = [];
      const mats: LineMaterial[] = [];
      for (let i = 0; i < count; i++) {
        trails.push({
          positions: new Float32Array(MAX_TRAIL_POINTS * 3),
          colors: new Float32Array(MAX_TRAIL_POINTS * 3),
          times: [],
          count: 0,
        });
        const mat = new LineMaterial({
          color: 0xffffff,
          linewidth: 4,
          vertexColors: true,
          transparent: true,
          worldUnits: false,
        });
        mat.resolution.set(window.innerWidth, window.innerHeight);
        mats.push(mat);
        const line = new Line2(new LineGeometry(), mat);
        line.frustumCulled = false;
        line.visible = false;
        lines.push(line);
        scene.add(line);
      }
      trailsRef.current = trails;
      linesRef.current = lines;
      trailMatsRef.current = mats;
      trailCountRef.current = count;
    },
    [scene, disposeTrails],
  );

  // Trail lines live in the scene outside React's tree, so tear them down
  // when this component unmounts instead of leaving orphaned objects behind.
  useEffect(() => () => disposeTrails(), [disposeTrails]);

  useEffect(() => {
    // Flipped by the cleanup below. Loading is asynchronous, so without this
    // a URDF that finishes after unmount / a URL change (or after React
    // StrictMode's extra mount-unmount cycle in development) would add a
    // second, orphaned robot to the scene and set state for a stale effect.
    let cancelled = false;
    setLoading(true);
    setError(null);
    const isOpenArm = urdfUrl.includes("openarm");
    const isG1 = urdfUrl.includes("g1");
    const manager = new THREE.LoadingManager();
    const loader = new URDFLoader(manager);
    loader.loadMeshCb = (url, mgr, onLoad) => {
      // DAE (Collada) files — load with embedded materials
      if (url.endsWith(".dae")) {
        const colladaLoader = new ColladaLoader(mgr);
        colladaLoader.load(
          url,
          (collada) => {
            if (isOpenArm) {
              // Render both faces and darken very light colours
              // (HSL lightness > 0.7 is clamped to 0.55).
              collada.scene.traverse((child) => {
                if (child instanceof THREE.Mesh && child.material) {
                  const mat = child.material as THREE.MeshStandardMaterial;
                  if (mat.side !== undefined) mat.side = THREE.DoubleSide;
                  if (mat.color) {
                    const hsl = { h: 0, s: 0, l: 0 };
                    mat.color.getHSL(hsl);
                    if (hsl.l > 0.7) mat.color.setHSL(hsl.h, hsl.s, 0.55);
                  }
                }
              });
            }
            onLoad(collada.scene);
          },
          undefined,
          (err) => onLoad(new THREE.Object3D(), err as Error),
        );
        return;
      }
      // STL files — apply custom materials, with module-level geometry cache
      const makeMesh = (geometry: THREE.BufferGeometry) => {
        let color = "#FFD700";
        let metalness = 0.1;
        let roughness = 0.6;
        if (isG1) {
          const lower = url.toLowerCase();
          const isWhitePart =
            lower.includes("contour") ||
            lower.includes("roll_link") ||
            lower.includes("logo") ||
            lower.includes("rubber") ||
            lower.includes("constraint") ||
            lower.includes("support");
          color = isWhitePart ? "#c0c0c0" : "#2a2a2a";
          metalness = 0.3;
          roughness = 0.5;
        } else if (url.includes("sts3215")) {
          color = "#1a1a1a";
          metalness = 0.7;
          roughness = 0.3;
        } else if (isOpenArm) {
          color = url.includes("body_link0") ? "#3a3a4a" : "#f5f5f5";
          metalness = 0.15;
          roughness = 0.6;
        }
        return new THREE.Mesh(
          geometry,
          new THREE.MeshStandardMaterial({
            color,
            metalness,
            roughness,
            side: isOpenArm ? THREE.DoubleSide : THREE.FrontSide,
          }),
        );
      };
      const cached = stlGeometryCache.get(url);
      if (cached) {
        onLoad(makeMesh(cached));
        return;
      }
      // Deduplicate in-flight requests for the same URL
      let pending = stlGeometryLoading.get(url);
      if (!pending) {
        pending = new Promise<THREE.BufferGeometry>((resolve, reject) => {
          new STLLoader(mgr).load(url, resolve, undefined, reject);
        }).then(
          (geometry) => {
            stlGeometryCache.set(url, geometry);
            stlGeometryLoading.delete(url);
            return geometry;
          },
          (err: unknown) => {
            // Forget failed requests as well; otherwise the rejected promise
            // stays cached and every later load fails without refetching.
            stlGeometryLoading.delete(url);
            throw err;
          },
        );
        stlGeometryLoading.set(url, pending);
      }
      pending
        .then((geometry) => onLoad(makeMesh(geometry)))
        .catch((err) => onLoad(new THREE.Object3D(), err as Error));
    };
    loader.load(
      urdfUrl,
      (robot) => {
        if (cancelled) return;
        robotRef.current = robot;
        robot.rotateOnAxis(new THREE.Vector3(1, 0, 0), -Math.PI / 2);
        robot.traverse((c) => {
          c.castShadow = true;
        });
        robot.updateMatrixWorld(true);
        robot.scale.set(scale, scale, scale);
        scene.add(robot);
        const tipNames = isG1
          ? G1_TIP_NAMES
          : isOpenArm
            ? DUAL_ARM_TIP_NAMES
            : SINGLE_ARM_TIP_NAMES;
        const tips: THREE.Object3D[] = [];
        for (const name of tipNames) {
          if (robot.frames[name]) tips.push(robot.frames[name]);
          // Single-arm robots only use the first tip frame that exists.
          if (!isOpenArm && !isG1 && tips.length === 1) break;
        }
        tipLinksRef.current = tips;
        ensureTrails(tips.length);
        const movable = Object.values(robot.joints)
          .filter(
            (j) =>
              j.jointType === "revolute" ||
              j.jointType === "continuous" ||
              j.jointType === "prismatic",
          )
          .map((j) => j.name);
        onJointsLoaded(movable);
        setLoading(false);
      },
      undefined,
      (err) => {
        if (cancelled) return;
        console.error("Error loading URDF:", err);
        setError(String(err));
        setLoading(false);
      },
    );
    return () => {
      cancelled = true;
      const robot = robotRef.current;
      if (robot) {
        scene.remove(robot);
        // Free per-load GPU resources. Geometries held by the module-level
        // STL cache are shared with future loads, so those are kept.
        const sharedGeometries = new Set<THREE.BufferGeometry>(
          stlGeometryCache.values(),
        );
        robot.traverse((obj) => {
          if (!(obj instanceof THREE.Mesh)) return;
          if (!sharedGeometries.has(obj.geometry)) obj.geometry.dispose();
          const materials = Array.isArray(obj.material)
            ? obj.material
            : [obj.material];
          for (const m of materials) m.dispose();
        });
        robotRef.current = null;
      }
      tipLinksRef.current = [];
    };
  }, [urdfUrl, scale, scene, onJointsLoaded, ensureTrails]);

  const tipWorldPos = useMemo(() => new THREE.Vector3(), []);
  useFrame(() => {
    const robot = robotRef.current;
    if (!robot) return;
    for (const [name, value] of Object.entries(jointValues)) {
      robot.setJointValue(name, value);
    }
    robot.updateMatrixWorld(true);
    const tips = tipLinksRef.current;
    if (!trailEnabled || tips.length === 0) {
      for (const l of linesRef.current) l.visible = false;
      return;
    }
    const now = performance.now() / 1000;
    for (let ti = 0; ti < tips.length; ti++) {
      const tip = tips[ti];
      const trail = trailsRef.current[ti];
      const line = linesRef.current[ti];
      const mat = trailMatsRef.current[ti];
      if (!trail || !line || !mat) continue;
      mat.resolution.set(size.width, size.height);
      tip.getWorldPosition(tipWorldPos);
      const trailColor = TRAIL_COLORS[ti % TRAIL_COLORS.length];
      // Append the new sample; once full, drop the oldest by shifting left.
      if (trail.count < MAX_TRAIL_POINTS) {
        trail.count++;
      } else {
        trail.positions.copyWithin(0, 3);
        trail.colors.copyWithin(0, 3);
        trail.times.shift();
      }
      const idx = trail.count - 1;
      trail.positions[idx * 3] = tipWorldPos.x;
      trail.positions[idx * 3 + 1] = tipWorldPos.y;
      trail.positions[idx * 3 + 2] = tipWorldPos.z;
      trail.times.push(now);
      // Fade each sample linearly to black over TRAIL_DURATION seconds.
      for (let i = 0; i < trail.count; i++) {
        const t = Math.max(0, 1 - (now - trail.times[i]) / TRAIL_DURATION);
        trail.colors[i * 3] = trailColor.r * t;
        trail.colors[i * 3 + 1] = trailColor.g * t;
        trail.colors[i * 3 + 2] = trailColor.b * t;
      }
      if (trail.count < 2) {
        line.visible = false;
        continue;
      }
      // A fresh geometry per frame; the previous one is disposed so its GPU
      // buffers are released.
      const geo = new LineGeometry();
      geo.setPositions(
        Array.from(trail.positions.subarray(0, trail.count * 3)),
      );
      geo.setColors(Array.from(trail.colors.subarray(0, trail.count * 3)));
      line.geometry.dispose();
      line.geometry = geo;
      line.computeLineDistances();
      line.visible = true;
    }
  });

  if (loading)
    return (
      <Html center>
        <span className="text-white text-lg">Loading robot…</span>
      </Html>
    );
  if (error)
    return (
      <Html center>
        <span className="text-red-400">Failed to load URDF</span>
      </Html>
    );
  return null;
}
// ─── Playback ticker ───
/**
 * Advances the playback position in real time while `playing` is true.
 *
 * Runs its own requestAnimationFrame loop, accumulates elapsed wall-clock
 * time and steps `frameRef` / `setFrame` forward by whole frames at `fps`,
 * wrapping around at `totalFrames`. Renders nothing.
 *
 * @param playing Whether playback is running.
 * @param fps Playback rate in frames per second.
 * @param totalFrames Number of frames in the episode; playback is idle when 0.
 * @param frameRef Mutable current frame, read and written by the loop.
 * @param setFrame State setter mirrored from `frameRef` after each step.
 */
function PlaybackDriver({
  playing,
  fps,
  totalFrames,
  frameRef,
  setFrame,
}: {
  playing: boolean;
  fps: number;
  totalFrames: number;
  frameRef: React.MutableRefObject<number>;
  setFrame: React.Dispatch<React.SetStateAction<number>>;
}) {
  const elapsed = useRef(0);
  const last = useRef(0);
  useEffect(() => {
    // With no frames the wrap-around below would be `% 0` and push NaN into
    // the frame state (e.g. after switching to an episode with no rows).
    if (!playing || totalFrames <= 0) return;
    let raf: number;
    const tick = () => {
      raf = requestAnimationFrame(tick);
      const now = performance.now();
      const dt = (now - last.current) / 1000;
      last.current = now;
      // Gaps of 0.5 s or more between ticks are skipped rather than
      // fast-forwarded through.
      if (dt > 0 && dt < 0.5) {
        elapsed.current += dt;
        const fd = Math.floor(elapsed.current * fps);
        if (fd > 0) {
          // Keep the fractional remainder so the average rate stays at fps.
          elapsed.current -= fd / fps;
          frameRef.current = (frameRef.current + fd) % totalFrames;
          setFrame(frameRef.current);
        }
      }
    };
    last.current = performance.now();
    elapsed.current = 0;
    raf = requestAnimationFrame(tick);
    return () => cancelAnimationFrame(raf);
  }, [playing, fps, totalFrames, frameRef, setFrame]);
  return null;
}
// ═══════════════════════════════════════
// ─── Main URDF Viewer ───
// ═══════════════════════════════════════
/**
 * Replays a dataset episode on a 3D URDF model of the dataset's robot.
 *
 * The URDF is chosen from `datasetInfo.robot_type`. URDF joints are mapped
 * onto one feature group of the episode's flat chart data; the mapping is
 * auto-matched and editable in the "Joint Mapping" panel. The viewer also
 * provides playback controls and optional end-effector trails.
 *
 * @param data Episode payload; its `flatChartData` is the episode shown first.
 * @param org Dataset owner; together with `dataset` enables loading other episodes.
 * @param dataset Dataset name.
 * @param episodeChangerRef Receives a function that switches to another episode.
 * @param playToggleRef Receives a function that toggles play/pause.
 */
export default function URDFViewer({
  data,
  org,
  dataset,
  episodeChangerRef,
  playToggleRef,
}: {
  data: EpisodeData;
  org?: string;
  dataset?: string;
  episodeChangerRef?: React.RefObject<((ep: number) => void) | undefined>;
  playToggleRef?: React.RefObject<(() => void) | undefined>;
}) {
  const { datasetInfo } = data;
  const fps = datasetInfo.fps || 30;
  const robotConfig = useMemo(
    () => getRobotConfig(datasetInfo.robot_type),
    [datasetInfo.robot_type],
  );
  const { urdfUrl, scale } = robotConfig;
  const isG1 = urdfUrl.includes("g1");
  const repoId = org && dataset ? `${org}/${dataset}` : null;

  // Dataset version + metadata: fetched on the first episode switch that
  // needs it, then reused.
  const datasetInfoRef = useRef<{
    version: string;
    info: DatasetMetadata;
  } | null>(null);
  const ensureDatasetInfo = useCallback(async () => {
    if (!repoId) return null;
    if (datasetInfoRef.current) return datasetInfoRef.current;
    const { version, info } = await getDatasetVersionAndInfo(repoId);
    const payload = { version, info: info as unknown as DatasetMetadata };
    datasetInfoRef.current = payload;
    return payload;
  }, [repoId]);

  // Playback (declared before the episode handler, which resets it).
  // `frameRef` mirrors `frame` for the rAF-driven PlaybackDriver.
  const [frame, setFrame] = useState(0);
  const [playing, setPlaying] = useState(false);
  const frameRef = useRef(0);

  // Episode selection & chart data
  const [selectedEpisode, setSelectedEpisode] = useState(data.episodeId);
  const [chartData, setChartData] = useState(data.flatChartData);
  const [episodeLoading, setEpisodeLoading] = useState(false);
  const chartDataCache = useRef<Record<number, Record<string, number>[]>>({
    [data.episodeId]: data.flatChartData,
  });
  // Most recently requested episode. Responses for other episodes are still
  // cached but not displayed, so a slow earlier request can no longer
  // overwrite the episode the user picked last.
  const latestEpisodeRef = useRef(data.episodeId);
  const handleEpisodeChange = useCallback(
    (epId: number) => {
      latestEpisodeRef.current = epId;
      setSelectedEpisode(epId);
      setFrame(0);
      frameRef.current = 0;
      setPlaying(false);
      const cachedRows = chartDataCache.current[epId];
      if (cachedRows) {
        setChartData(cachedRows);
        // A request for a previously selected episode may still be pending;
        // it no longer clears the overlay itself, so clear it here.
        setEpisodeLoading(false);
        return;
      }
      if (!repoId) return;
      setEpisodeLoading(true);
      ensureDatasetInfo()
        .then((payload) => {
          if (!payload) return null;
          return loadEpisodeFlatChartData(
            repoId,
            payload.version,
            payload.info,
            epId,
          );
        })
        .then((result) => {
          if (!result) return;
          chartDataCache.current[epId] = result;
          if (latestEpisodeRef.current === epId) setChartData(result);
        })
        .catch((err) => console.error("Failed to load episode:", err))
        .finally(() => {
          if (latestEpisodeRef.current === epId) setEpisodeLoading(false);
        });
    },
    [ensureDatasetInfo, repoId],
  );
  useEffect(() => {
    if (episodeChangerRef) episodeChangerRef.current = handleEpisodeChange;
  }, [episodeChangerRef, handleEpisodeChange]);
  const totalFrames = chartData.length;

  // URDF joint names
  const [urdfJointNames, setUrdfJointNames] = useState<string[]>([]);
  const onJointsLoaded = useCallback(
    (names: string[]) => setUrdfJointNames(names),
    [],
  );

  // Feature groups
  const columnGroups = useMemo(() => {
    if (totalFrames === 0) return {};
    return groupColumnsByPrefix(Object.keys(chartData[0]));
  }, [chartData, totalFrames]);
  const groupNames = useMemo(() => Object.keys(columnGroups), [columnGroups]);
  const defaultGroup = useMemo(
    () =>
      groupNames.find((g) => g.toLowerCase().includes("state")) ??
      groupNames.find((g) => g.toLowerCase().includes("action")) ??
      groupNames[0] ??
      "",
    [groupNames],
  );
  const [selectedGroup, setSelectedGroup] = useState(defaultGroup);
  useEffect(() => setSelectedGroup(defaultGroup), [defaultGroup]);
  const selectedColumns = useMemo(
    () => columnGroups[selectedGroup] ?? [],
    [columnGroups, selectedGroup],
  );

  // Joint mapping
  const autoMapping = useMemo(
    () => autoMatchJoints(urdfJointNames, selectedColumns),
    [urdfJointNames, selectedColumns],
  );
  const [mapping, setMapping] = useState<Record<string, string>>(autoMapping);
  useEffect(() => setMapping(autoMapping), [autoMapping]);

  // Trail
  const [trailEnabled, setTrailEnabled] = useState(true);
  const [showMapping, setShowMapping] = useState(false);

  const handleFrameChange = useCallback(
    (e: React.ChangeEvent<HTMLInputElement>) => {
      const f = parseInt(e.target.value, 10);
      setFrame(f);
      frameRef.current = f;
    },
    [],
  );
  const handlePlayPause = useCallback(() => {
    setPlaying((prev) => {
      if (!prev) frameRef.current = frame;
      return !prev;
    });
  }, [frame]);
  useEffect(() => {
    if (playToggleRef) playToggleRef.current = handlePlayPause;
  }, [playToggleRef, handlePlayPause]);

  // Filter out mimic joints (finger_joint2) from the UI list
  const displayJointNames = useMemo(
    () =>
      urdfJointNames.filter((n) => !n.toLowerCase().includes("finger_joint2")),
    [urdfJointNames],
  );

  // Auto-detect gripper column range for linear mapping to 0-0.044m
  const gripperRanges = useMemo(() => {
    const ranges: Record<string, { min: number; max: number }> = {};
    for (const jn of urdfJointNames) {
      if (!jn.toLowerCase().includes("finger_joint1")) continue;
      const col = mapping[jn];
      if (!col) continue;
      let min = Infinity,
        max = -Infinity;
      for (const row of chartData) {
        const v = row[col];
        if (typeof v === "number") {
          if (v < min) min = v;
          if (v > max) max = v;
        }
      }
      // Constant (or missing) columns get no range; the fallback below applies.
      if (min < max) ranges[jn] = { min, max };
    }
    return ranges;
  }, [chartData, mapping, urdfJointNames]);

  // Compute joint values for current frame
  const jointValues = useMemo(() => {
    if (totalFrames === 0 || urdfJointNames.length === 0) return {};
    const row = chartData[Math.min(frame, totalFrames - 1)];
    const revoluteValues: number[] = [];
    const revoluteNames: string[] = [];
    const values: Record<string, number> = {};
    for (const jn of urdfJointNames) {
      if (jn.toLowerCase().includes("finger_joint2")) continue;
      const col = mapping[jn];
      if (!col || typeof row[col] !== "number") continue;
      const raw = row[col];
      if (jn.toLowerCase().includes("finger_joint1")) {
        // Map gripper range → 0-0.044m using auto-detected min/max
        const range = gripperRanges[jn];
        if (range) {
          const t = (raw - range.min) / (range.max - range.min);
          values[jn] = t * 0.044;
        } else {
          values[jn] = (raw / 100) * 0.044; // fallback: assume 0-100
        }
      } else {
        revoluteValues.push(raw);
        revoluteNames.push(jn);
      }
    }
    // NOTE(review): the unit (ticks / degrees / radians) is guessed from this
    // frame's values only, so a degree-valued frame whose joints all sit
    // within ±6.3 is treated as radians. Detecting once per episode would
    // make the conversion stable across frames.
    const converted = detectAndConvert(revoluteValues);
    revoluteNames.forEach((n, i) => {
      values[n] = converted[i];
    });
    // Copy finger_joint1 → finger_joint2 (mimic joints)
    for (const jn of urdfJointNames) {
      if (jn.toLowerCase().includes("finger_joint2")) {
        const j1 = jn.replace(/finger_joint2/, "finger_joint1");
        if (values[j1] !== undefined) values[jn] = values[j1];
      }
    }
    return values;
  }, [chartData, frame, gripperRanges, mapping, totalFrames, urdfJointNames]);

  if (data.flatChartData.length === 0) {
    return (
      <div className="text-slate-400 p-8 text-center">
        No trajectory data available.
      </div>
    );
  }

  return (
    <div className="flex-1 flex flex-col overflow-hidden">
      {/* 3D Viewport */}
      <div className="flex-1 min-h-0 bg-slate-950 rounded-lg overflow-hidden border border-slate-700 relative">
        {episodeLoading && (
          <div className="absolute inset-0 z-10 flex items-center justify-center bg-slate-950/70">
            <span className="text-white text-lg animate-pulse">
              Loading episode {selectedEpisode}…
            </span>
          </div>
        )}
        <Canvas
          camera={{
            position: isG1
              ? [1.5, 1.0, 1.5]
              : [0.3 * scale, 0.25 * scale, 0.3 * scale],
            fov: 45,
            near: 0.01,
            far: 100,
          }}
        >
          <ambientLight intensity={0.7} />
          <directionalLight position={[3, 5, 4]} intensity={1.5} />
          <directionalLight position={[-2, 3, -2]} intensity={0.6} />
          <hemisphereLight args={["#b1e1ff", "#666666", 0.5]} />
          <RobotScene
            urdfUrl={urdfUrl}
            jointValues={jointValues}
            onJointsLoaded={onJointsLoaded}
            trailEnabled={trailEnabled}
            trailResetKey={selectedEpisode}
            scale={scale}
          />
          <Grid
            args={[10, 10]}
            cellSize={isG1 ? 0.5 : 0.2}
            cellThickness={0.5}
            cellColor="#334155"
            sectionSize={isG1 ? 2 : 1}
            sectionThickness={1}
            sectionColor="#475569"
            fadeDistance={isG1 ? 20 : 10}
            position={[0, 0, 0]}
          />
          <OrbitControls target={isG1 ? [0, 0.5, 0] : [0, 0.8, 0]} />
          <PlaybackDriver
            playing={playing}
            fps={fps}
            totalFrames={totalFrames}
            frameRef={frameRef}
            setFrame={setFrame}
          />
        </Canvas>
      </div>
      {/* Controls */}
      <div className="bg-slate-800/90 border-t border-slate-700 p-3 space-y-3 shrink-0">
        <UrdfPlaybackBar
          frame={frame}
          totalFrames={totalFrames}
          fps={fps}
          playing={playing}
          onPlayPause={handlePlayPause}
          trailEnabled={trailEnabled}
          onTrailToggle={() => setTrailEnabled((v) => !v)}
          onFrameChange={handleFrameChange}
        />
        {/* Collapsible joint mapping */}
        <button
          onClick={() => setShowMapping((v) => !v)}
          className="flex items-center gap-1.5 text-xs text-slate-400 hover:text-slate-200 transition-colors"
        >
          <span
            className={`transition-transform ${showMapping ? "rotate-90" : ""}`}
          >
          </span>
          Joint Mapping
          <span className="text-slate-600">
            ({Object.keys(mapping).filter((k) => mapping[k]).length}/
            {displayJointNames.length} mapped)
          </span>
        </button>
        {showMapping && (
          <div className="flex gap-4 items-start">
            <div className="space-y-1 shrink-0">
              <label className="text-xs text-slate-400">Data source</label>
              <div className="flex gap-1 flex-wrap">
                {groupNames.map((name) => (
                  <button
                    key={name}
                    onClick={() => setSelectedGroup(name)}
                    className={`px-2 py-1 text-xs rounded transition-colors ${
                      selectedGroup === name
                        ? "bg-orange-600 text-white"
                        : "bg-slate-700 text-slate-300 hover:bg-slate-600"
                    }`}
                  >
                    {name}
                  </button>
                ))}
              </div>
            </div>
            <div className="flex-1 overflow-x-auto max-h-48 overflow-y-auto">
              <table className="w-full text-xs">
                <thead className="sticky top-0 bg-slate-800">
                  <tr className="text-slate-500">
                    <th className="text-left font-normal px-1">URDF Joint</th>
                    <th className="text-left font-normal px-1"></th>
                    <th className="text-left font-normal px-1">
                      Dataset Column
                    </th>
                    <th className="text-right font-normal px-1">Value</th>
                  </tr>
                </thead>
                <tbody>
                  {displayJointNames.map((jointName) => (
                    <tr
                      key={jointName}
                      className="border-t border-slate-700/50"
                    >
                      <td className="px-1 py-0.5 text-slate-300 font-mono">
                        {jointName}
                      </td>
                      <td className="px-1 text-slate-600"></td>
                      <td className="px-1 py-0.5">
                        <select
                          value={mapping[jointName] ?? ""}
                          onChange={(e) =>
                            setMapping((m) => ({
                              ...m,
                              [jointName]: e.target.value,
                            }))
                          }
                          className="bg-slate-900 text-slate-200 text-xs rounded px-1 py-0.5 border border-slate-600 w-full max-w-[200px]"
                        >
                          <option value="">-- unmapped --</option>
                          {selectedColumns.map((col) => {
                            const label = col.split(SERIES_DELIM).pop() ?? col;
                            return (
                              <option key={col} value={col}>
                                {label}
                              </option>
                            );
                          })}
                        </select>
                      </td>
                      <td className="px-1 py-0.5 text-right tabular-nums text-slate-400 font-mono">
                        {jointValues[jointName] !== undefined
                          ? jointValues[jointName].toFixed(3)
                          : "—"}
                      </td>
                    </tr>
                  ))}
                </tbody>
              </table>
            </div>
          </div>
        )}
      </div>
    </div>
  );
}