---
language: en
tags:
- robotics
- 6-axis-arm
- visual-policy
- pytorch
- imitation-learning
license: mit
---

# 6Net: 6-Axis Visual Robot Policy (~115M)

A custom transformer policy for visual control of a 6-DoF robot arm, trained from scratch (no LoRA adapters).

| Component | Detail | Params |
|---|---|---|
| Visual Encoder | ResNet-18 (fine-tuned) | ~11.7M |
| Visual Projection | Linear(512→768) | ~0.4M |
| State Encoder | MLP(6→256→768) | ~0.2M |
| Transformer | 14 layers · d=768 · 12 heads · ffn=3072 | ~99.1M |
| Action Head | MLP(768→256→6) | ~0.2M |
| **Total** | | **~111M** |
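
As a rough sanity check on the table, the per-component counts can be reproduced with a few lines of arithmetic. This is a sketch under assumptions the table does not state: linear layers are counted with biases, and the transformer is approximated by weights only (four d×d attention projections plus two d×ffn feed-forward matrices per layer), ignoring biases, LayerNorm, and any positional embeddings. The helper names are hypothetical and are not part of `train_6net_local`.

```python
def linear_params(n_in: int, n_out: int, bias: bool = True) -> int:
    """Parameter count of one fully connected layer."""
    return n_in * n_out + (n_out if bias else 0)


def mlp_params(sizes: list[int]) -> int:
    """Parameter count of an MLP given its layer widths, e.g. [6, 256, 768]."""
    return sum(linear_params(a, b) for a, b in zip(sizes, sizes[1:]))


def transformer_weight_params(layers: int, d_model: int, d_ffn: int) -> int:
    """Weights-only estimate: Q, K, V, output projections plus two FFN matrices."""
    attention = 4 * d_model * d_model
    feed_forward = 2 * d_model * d_ffn
    return layers * (attention + feed_forward)


# Assumption: shapes are taken directly from the component table.
counts = {
    "visual_projection": linear_params(512, 768),
    "state_encoder": mlp_params([6, 256, 768]),
    "transformer": transformer_weight_params(14, 768, 3072),
    "action_head": mlp_params([768, 256, 6]),
}

for name, n in counts.items():
    print(f"{name:18s} {n / 1e6:6.2f}M")
```

Adding the ~11.7M listed for ResNet-18 gives a sum in line with the table's total. The exact figure depends on the biases, normalization layers, and embeddings in the actual `SixNet` implementation.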

**Dataset:** `synthetic` · **Training steps:** 455 · **Effective batch size:** 32

## Inference
```python
import torch
import torchvision.transforms as T
from PIL import Image

from train_6net_local import SixNet, Config

# Build the model and load the trained weights on CPU.
model = SixNet(Config())
ckpt = torch.load("6net_final.pt", map_location="cpu")
model.load_state_dict(ckpt["model_state"])
model.eval()

# Resize to 224x224 and apply ImageNet mean/std normalization.
tf = T.Compose([
    T.Resize((224, 224)),
    T.ToTensor(),
    T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

# convert("RGB") guards against grayscale or RGBA camera frames.
img = tf(Image.open("cam.jpg").convert("RGB")).unsqueeze(0)  # (1, 3, 224, 224)
jts = torch.zeros(1, 6)  # current joint angles (rad)
with torch.no_grad():
    action = model(img, jts)  # (1, 6) predicted targets
```
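
The `action` tensor has shape `(1, 6)`, so it may need to be flattened and range-checked before being handed to a controller. Below is a minimal sketch, assuming the six outputs are joint targets in radians and that clamping to per-joint limits is acceptable for the arm in question. The limit values and the `clamp_joint_targets` helper are hypothetical and are not part of `train_6net_local`. With the real model, `targets` would come from `action.squeeze(0).tolist()`.

```python
import math

# Hypothetical joint limits in radians; replace with the arm's real limits.
JOINT_LIMITS = [(-math.pi, math.pi)] * 6


def clamp_joint_targets(targets, limits=JOINT_LIMITS):
    """Clamp each predicted joint target to its (lower, upper) limit."""
    if len(targets) != len(limits):
        raise ValueError(f"expected {len(limits)} joints, got {len(targets)}")
    return [min(max(t, lo), hi) for t, (lo, hi) in zip(targets, limits)]


# Stand-in for action.squeeze(0).tolist() from the inference example.
targets = [0.1, -0.5, 4.0, 0.0, -3.5, 1.2]
safe_targets = clamp_joint_targets(targets)
print(safe_targets)
```

Clamping is only one possible policy. Rejecting out-of-range predictions outright may be preferable when a large deviation signals a bad observation.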