---
language: en
tags: [robotics, 6-axis-arm, visual-policy, pytorch, imitation-learning, action-chunking]
license: apache-2.0
---

# 6Net 2.0 — 6-Axis Visual Robot Policy (~228M)
|
|
Version 2 of a custom transformer policy for vision-based 6-DoF robot-arm control.
Optimised for **LDX-218**, **LD-1501MG** and **LFD-01M** hardware.
|
|
| Component | Detail | Params |
|---|---|---|
| Visual Encoder | ResNet-50 (ImageNet-1K V2 weights, shared by both cameras) | ~23.5M |
| Visual Projection (overhead) | Linear(2048→1024) | ~2.1M |
| Visual Projection (wrist) | Linear(2048→1024) | ~2.1M |
| State Encoder | MLP(6→256→1024) | ~0.3M |
| Transformer | 16 layers · d=1024 · 16 heads · FFN=4096 | ~201.6M |
| Action Head | MLP(1024→512→K×6) | ~0.5M |
| **Total** | | **~228M** |
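As a rough cross-check of the table, the Linear and transformer rows can be reproduced with plain arithmetic. This sketch assumes standard PyTorch layouts: every `nn.Linear` has a bias, and each transformer layer is an `nn.TransformerEncoderLayer`-style block with fused Q/K/V, an output projection, a two-layer FFN and two LayerNorms. It ignores anything the card does not list, such as positional embeddings. The helper names are hypothetical, and the ResNet-50 row is left out.

```python
# Rough parameter arithmetic for the table rows (assumes biased Linear layers
# and nn.TransformerEncoderLayer-style blocks; positional embeddings, extra
# tokens and final norms are NOT counted).

def linear(n_in: int, n_out: int) -> int:
    """Weights + bias of nn.Linear(n_in, n_out)."""
    return n_in * n_out + n_out


def mlp(*dims: int) -> int:
    """A stack of Linear layers, e.g. mlp(6, 256, 1024)."""
    return sum(linear(a, b) for a, b in zip(dims, dims[1:]))


def encoder_layer(d: int, ffn: int) -> int:
    """One encoder block: attention + feed-forward + two LayerNorms."""
    attn = 3 * linear(d, d) + linear(d, d)  # fused Q/K/V + output projection
    ff = linear(d, ffn) + linear(ffn, d)    # two-layer feed-forward network
    norms = 2 * (2 * d)                     # two LayerNorms (weight + bias each)
    return attn + ff + norms


K = 10  # action-chunk length
rows = {
    "visual projection (each)": linear(2048, 1024),
    "state encoder": mlp(6, 256, 1024),
    "transformer (16 layers)": 16 * encoder_layer(1024, 4096),
    "action head": mlp(1024, 512, K * 6),
}
for name, n in rows.items():
    print(f"{name:26s} {n / 1e6:7.2f}M")
```

This prints about 2.10M, 0.26M, 201.54M and 0.56M, which is close to the table's estimates. The head count (16) does not enter the formula, because standard multi-head attention splits `d` across heads without adding weights.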
|
|
**Dataset:** `lerobot/pusht_image` · **Training steps:** 910 · **Effective batch size:** 32 · **Hardware profile:** `LDX-218`
|
|
## Key improvements over v1
- **2× parameter count** via a ResNet-50 backbone and a wider, deeper transformer
- **Dual-camera input**: separate overhead and wrist-camera tokens
- **Action chunking** (K=10): predicts the next 10 steps; only step 0 is returned at inference
- **Hardware profiles**: joint limits, maximum velocity and gravity compensation for LDX-218 / LD-1501MG / LFD-01M
- **Streaming fallback**: tries a streaming dataset download before falling back to synthetic data
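The chunking and hardware-profile points describe a post-processing path. The model takes the K=10 chunk from the action head, keeps step 0, and holds it inside the selected hardware's limits. Below is a framework-free sketch of that idea under stated assumptions: the head output is read as a flat K×6 vector, and a profile holds per-joint position limits plus one velocity cap. `HypotheticalProfile`, `first_step` and `safe_target` are hypothetical names, not the training script's API. The limit values are made up, and gravity compensation is not modelled.

```python
import math
from dataclasses import dataclass


@dataclass
class HypotheticalProfile:
    """Illustrative stand-in for a HARDWARE_PROFILES entry (values are NOT real servo specs)."""
    joint_min: list[float]  # lower joint limits, rad
    joint_max: list[float]  # upper joint limits, rad
    max_velocity: float     # rad/s, same for every joint in this sketch


def first_step(flat_chunk: list[float], k: int, dof: int = 6) -> list[float]:
    """Split a flat K*dof action-head output into K steps and keep step 0."""
    if len(flat_chunk) != k * dof:
        raise ValueError(f"expected {k * dof} values, got {len(flat_chunk)}")
    steps = [flat_chunk[i * dof:(i + 1) * dof] for i in range(k)]
    return steps[0]


def safe_target(target: list[float], current: list[float],
                profile: HypotheticalProfile, dt: float) -> list[float]:
    """Limit each joint's change to max_velocity * dt, then clamp to joint limits."""
    max_step = profile.max_velocity * dt
    out = []
    for t, c, lo, hi in zip(target, current, profile.joint_min, profile.joint_max):
        t = min(max(t, c - max_step), c + max_step)  # velocity limit
        out.append(min(max(t, lo), hi))              # position limit
    return out


# Usage with made-up numbers (hypothetical profile, not LDX-218 data)
profile = HypotheticalProfile(joint_min=[-math.pi / 2] * 6,
                              joint_max=[math.pi / 2] * 6,
                              max_velocity=1.0)
K = 10
flat = [float(i) for i in range(K * 6)]  # stand-in for one sample of the head output
step0 = first_step(flat, K)              # [0.0, 1.0, 2.0, 3.0, 4.0, 5.0]
cmd = safe_target(step0, current=[0.0] * 6, profile=profile, dt=0.1)
print(cmd)                               # [0.0, 0.1, 0.1, 0.1, 0.1, 0.1]
```

The sketch applies the velocity cap before the position clamp. As a result, a joint that starts inside its limits never moves more than `max_velocity * dt` in one control step.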
|
|
## Inference
|
|
```python
import torch
import torchvision.transforms as T
from PIL import Image

from train_6net_v2 import SixNetV2, Config, HARDWARE_PROFILES

# Build the model for the target hardware and load the final checkpoint
cfg = Config(hardware="LDX-218")
model = SixNetV2(cfg)
ckpt = torch.load("6net_v2_final.pt", map_location="cpu")
model.load_state_dict(ckpt["model_state"])
model.eval()

# ImageNet preprocessing for the ResNet-50 backbone
tf = T.Compose([
    T.Resize((224, 224)),
    T.ToTensor(),
    T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])
# .convert("RGB") guards against grayscale or RGBA image files
overhead = tf(Image.open("overhead.jpg").convert("RGB")).unsqueeze(0)  # (1, 3, 224, 224)
wrist = tf(Image.open("wrist.jpg").convert("RGB")).unsqueeze(0)        # (1, 3, 224, 224)
jts = torch.zeros(1, 6)  # current joint angles (rad)

with torch.no_grad():  # inference only, no autograd graph
    action = model.predict(overhead, jts, wrist=wrist, hw=HARDWARE_PROFILES["LDX-218"])
# → tensor of shape (1, 6), clamped to LDX-218 joint limits
```
|
|