Spaces:
Sleeping
Sleeping
| """Zero-shot inference test with LBST pre-trained checkpoint on ElRobot. | |
| Maps ElRobot 8 joints → SO-101 6 joints: | |
| ElRobot motor 1 (base) → SO-101 joint 0 | |
| ElRobot motor 2 (shoulder) → SO-101 joint 1 | |
| ElRobot motor 3 (elbow) → SO-101 joint 2 | |
| ElRobot motor 4 (wrist flex)→ SO-101 joint 3 | |
| ElRobot motor 7 (wrist roll)→ SO-101 joint 4 | |
| ElRobot motor 8 (gripper) → SO-101 joint 5 | |
| Motors 5,6 (extra wrist DOFs) are held at current position. | |
| Run from the smolvla_py directory: | |
| uv run python scripts/zero_shot_test.py | |
| """ | |
| from __future__ import annotations | |
| import asyncio | |
| import io | |
| import logging | |
| import sys | |
| import time | |
| from pathlib import Path | |
| import numpy as np | |
| import torch | |
| from PIL import Image | |
# Locate the repository root relative to this script so in-repo packages can be
# imported without being installed. parents[4] is the fifth ancestor directory
# of this file; this breaks if the script is moved to a different depth.
_HERE = Path(__file__).resolve()
_REPO = _HERE.parents[4]
# Put the repo root first, then the station "shared" directory, at the front of
# the import path; the project imports that follow rely on this ordering.
sys.path[:0] = [str(_REPO), str(_REPO / "software" / "station" / "shared")]
| from station_py import new_station_client | |
| from target.gen_python.protobuf.drivers.inferences import normvla | |
| from smolvla import SmolVLAPolicy | |
| from smolvla.normalize import normalize_state, unnormalize_action | |
| from smolvla.stats import load_stats | |
# Station queue that fetch_frame reads frames from (decoded via normvla.FrameReader).
QUEUE_ID = "inference/normvla"
# Batch key under which the camera image is handed to the policy.
IMAGE_KEY = "observation.images.front"

# Pre-trained LBST checkpoint directory and its normalization statistics file.
CHECKPOINT = _REPO / "checkpoints" / "lbst-pick-place"
STATS_PATH = CHECKPOINT / "stats.safetensors"

# 0-based ElRobot joint index feeding each SO-101 joint, listed in SO-101 order.
# Must stay a list: it is used for torch advanced indexing, where a tuple would
# instead be interpreted as a multi-dimensional index.
ELROBOT_TO_SO101 = [
    0,  # motor 1 (base)       -> SO-101 joint 0
    1,  # motor 2 (shoulder)   -> SO-101 joint 1
    2,  # motor 3 (elbow)      -> SO-101 joint 2
    3,  # motor 4 (wrist flex) -> SO-101 joint 3
    6,  # motor 7 (wrist roll) -> SO-101 joint 4
    7,  # motor 8 (gripper)    -> SO-101 joint 5
]
async def fetch_frame(client, timeout: float = 5.0):
    """Fetch a single entry from the tail of QUEUE_ID and wrap it for reading.

    Args:
        client: station client (as returned by ``new_station_client``).
        timeout: seconds to wait for an entry before giving up.

    Returns:
        A ``normvla.FrameReader`` over a copy of the entry's ``Data`` bytes.

    Raises:
        asyncio.TimeoutError: no entry arrived within ``timeout`` seconds.
        RuntimeError: the queue's data channel yielded ``None`` (closed); the
            message includes the query's ``err`` attribute.
    """
    reader = client.read_from_tail(QUEUE_ID, offset=b"\x00", limit=1, step=1, buf_size=1)
    item = await asyncio.wait_for(reader.data.get(), timeout=timeout)
    if item is not None:
        return normvla.FrameReader(memoryview(bytes(item.Data)))
    raise RuntimeError(f"Queue closed: {reader.err}")
def _select_so101_stats(stats):
    """Return state/action normalization stats restricted to the SO-101 joints.

    Stats vectors that already have one entry per SO-101 joint (e.g. stats that
    ship with an SO-101 checkpoint) are used unchanged. Longer vectors are treated
    as ElRobot-indexed and remapped through ELROBOT_TO_SO101, as before.

    Raises:
        KeyError: a required stats key is missing.
        ValueError: a stats vector matches neither layout.
    """
    n_so101 = len(ELROBOT_TO_SO101)
    # Smallest vector length that ELROBOT_TO_SO101 can index into.
    min_elrobot = max(ELROBOT_TO_SO101) + 1
    selected = {}
    for key in ("state_mean", "state_std", "action_mean", "action_std"):
        vec = stats[key]
        size = vec.shape[0]
        if size == n_so101:
            selected[key] = vec
        elif size >= min_elrobot:
            selected[key] = vec[ELROBOT_TO_SO101]
        else:
            raise ValueError(
                f"stats[{key!r}] has {size} entries; expected {n_so101} (SO-101) "
                f"or at least {min_elrobot} (ElRobot)"
            )
    return selected


def _jpeg_to_tensor(jpeg, device):
    """Decode JPEG bytes into a float (1, 3, H, W) RGB tensor in [0, 1] on ``device``."""
    with Image.open(io.BytesIO(jpeg)) as im:
        arr = np.asarray(im.convert("RGB"), dtype=np.uint8)
    # Copy so torch.from_numpy receives a writable, independently owned buffer.
    return torch.from_numpy(arr.copy()).permute(2, 0, 1).float().unsqueeze(0).to(device) / 255.0


def frame_to_batch(frame, stats, device):
    """Convert a station frame into a SmolVLA input batch for the 6-joint SO-101 policy.

    Args:
        frame: frame reader as returned by ``fetch_frame`` (exposes
            ``get_joints()`` and ``get_images()``).
        stats: dict of normalization tensors with keys "state_mean",
            "state_std", "action_mean", "action_std"; either SO-101-sized
            (6 entries) or ElRobot-indexed.
        device: torch device for the produced tensors.

    Returns:
        ``(batch, stats_6, ranges_6, state_8, ranges_8)``:
        ``batch`` holds the normalized state (shape (1, 6)) and the image under
        IMAGE_KEY. ``stats_6`` holds the SO-101 stats. ``ranges_*`` are
        ``(range_min, range_max)`` int pairs per joint. ``state_8`` holds the raw
        normalized positions of all reported joints.

    Raises:
        RuntimeError: the frame reports fewer than 8 joints.
        ValueError: a stats vector has an unexpected length.
    """
    joints = frame.get_joints() or []
    images = frame.get_images() or []
    if len(joints) < 8:
        raise RuntimeError(f"Expected 8 joints, got {len(joints)}")

    state_8 = [j.get_position_norm() for j in joints]
    state_6 = [state_8[i] for i in ELROBOT_TO_SO101]
    state = torch.tensor(state_6, dtype=torch.float32, device=device).unsqueeze(0)
    stats_6 = _select_so101_stats(stats)
    batch = {"observation.state": normalize_state(state, stats_6)}

    if images:
        # Only the first camera image is used.
        batch[IMAGE_KEY] = _jpeg_to_tensor(bytes(images[0].get_jpeg()), device)
    else:
        print("WARNING: No camera images in frame, using black image")
        batch[IMAGE_KEY] = torch.zeros(1, 3, 224, 224, device=device)

    ranges_8 = [(int(j.get_range_min()), int(j.get_range_max())) for j in joints]
    ranges_6 = [ranges_8[i] for i in ELROBOT_TO_SO101]
    return batch, stats_6, ranges_6, state_8, ranges_8
def _load_policy(device):
    """Load the SmolVLA policy from CHECKPOINT, configured for 6-D state/action, in eval mode."""
    t0 = time.perf_counter()
    policy = SmolVLAPolicy.from_pretrained(
        str(CHECKPOINT),
        config_overrides={
            "load_vlm_weights": False,
            "image_keys": [IMAGE_KEY],
            "state_dim": 6,
            "action_dim": 6,
        },
        strict=False,
    ).to(device)
    policy.eval()
    n_params = sum(p.numel() for p in policy.parameters())
    print(f"Model loaded in {time.perf_counter() - t0:.1f}s ({n_params:,} params)")
    return policy


def _print_predictions(next_goal, state_8, ranges_6):
    """Print current vs predicted normalized position, delta and raw ticks per SO-101 joint."""
    joint_names = ["base", "shoulder", "elbow", "wrist_flex", "wrist_roll", "gripper"]
    # Size the name column to the longest label so rows stay aligned
    # ("wrist_flex"/"wrist_roll" are wider than the old fixed width of 8).
    width = max(map(len, ["Joint", *joint_names]))
    print("\nPredicted actions (6 SO-101 joints):")
    print(f" {'Joint':>{width}} {'Current':>8} {'Predicted':>10} {'Delta':>8} {'Raw ticks':>10}")
    rows = zip(joint_names, ELROBOT_TO_SO101, next_goal, ranges_6)
    for name, elrobot_idx, goal, (rmin, rmax) in rows:
        goal = float(goal)
        cur = state_8[elrobot_idx]
        # Linear map of the normalized goal back onto the joint's raw tick range.
        raw = int(round(rmin + goal * (rmax - rmin)))
        print(f" {name:>{width}} {cur:>8.4f} {goal:>10.4f} {goal - cur:>+8.4f} {raw:>10}")


async def main():
    """Run a one-shot zero-shot inference check against the local station.

    Loads normalization stats and the SmolVLA checkpoint, reads one frame from the
    station, predicts an action chunk for a fixed task, and prints current vs
    predicted joint positions. No commands are sent to the robot.
    """
    logging.basicConfig(level=logging.WARNING)
    logger = logging.getLogger("zero-shot")
    device = torch.device("cpu")
    task = "pick up the block"

    print(f"Loading stats from {STATS_PATH}")
    stats = {k: v.to(device) for k, v in load_stats(str(STATS_PATH)).items()}

    print(f"Loading checkpoint from {CHECKPOINT}")
    policy = _load_policy(device)

    print("Connecting to station...")
    client = await new_station_client("localhost", logger)
    print("Connected. Fetching frame...")
    frame = await fetch_frame(client)
    batch, stats_6, ranges_6, state_8, _ranges_8 = frame_to_batch(frame, stats, device)

    print("\nCurrent ElRobot state (8 joints):")
    for motor, s in enumerate(state_8, start=1):
        print(f" motor {motor}: position_norm={s:.4f}")
    print("\nMapped to SO-101 (6 joints):")
    for i, idx in enumerate(ELROBOT_TO_SO101):
        print(f" SO-101 joint {i}: ElRobot motor {idx + 1}, norm={state_8[idx]:.4f}")

    tokens, mask = policy.tokenize_task(task, device=device)
    batch["observation.language.tokens"] = tokens
    batch["observation.language.attention_mask"] = mask

    print(f"\nRunning inference (task: '{task}')...")
    t0 = time.perf_counter()
    with torch.no_grad():
        pred = policy.predict_action_chunk(batch)
    print(f"Inference took {time.perf_counter() - t0:.1f}s")

    # NOTE(review): assumes pred is (batch, steps, action_dim); pred[0] is the
    # only batch element and pred_goal[0] its first step — confirm.
    pred_goal = unnormalize_action(pred[0], stats_6)
    # .cpu() so .numpy() keeps working if `device` is switched off the CPU.
    next_goal = pred_goal[0].clamp(0.0, 1.0).cpu().numpy()
    _print_predictions(next_goal, state_8, ranges_6)

    print("\nZero-shot test complete. NOT sending commands to the robot.")
    print("Review the predictions above — if they look reasonable, we can proceed.")
if __name__ == "__main__":
    # Executed as a script rather than imported: drive the async entry point to completion.
    asyncio.run(main())