Reinforcement Learning
Transformers
English
robotics
vla
vision-language-action
openvla
omnivla
robot
qwen
dinov2
siglip
Instructions to use theguy21/openvla-micro with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use theguy21/openvla-micro with Transformers:
# Load model directly
from transformers import AutoModel
model = AutoModel.from_pretrained("theguy21/openvla-micro", dtype="auto")
- Notebooks
- Google Colab
- Kaggle
"""
Train the hidden-state shim (896→4096) for OpenVLA-Micro.

The shim maps Qwen2.5 0.5B's 896-dim hidden states into a teacher
LLM's 4096-dim space (e.g., Llama-2, Llama-3). This lets the small model
drive OmniVLA's pretrained action head with near-zero accuracy loss.

Workflow:
  1. Cache your teacher's hidden states on your dataset
  2. Run this script to train the shim
  3. Bake the shim into the checkpoint with bake_shim.py

Usage:
  python train_shim.py --cache-dir ./my_cache --data-dir ./my_data --base-model /path/to/checkpoint.pt

  --data-dir is required. --base-model is opened with torch.load, so it must
  be a local checkpoint file; it is not fetched from the Hub.

For the full training pipeline used in openvla-micro-distill, see:
  https://huggingface.co/theguy21/openvla-micro
"""
| import argparse, json, os | |
| from pathlib import Path | |
| import numpy as np | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from torch.utils.data import Dataset, DataLoader | |
| from PIL import Image | |
| from tqdm import tqdm | |
| from modeling_openvla_micro import DinoSigLIPEncoder, CombinedProjector, ShimMLP | |
| from model_wrapper import IMAGENET_MEAN as IM4D, IMAGENET_STD as IS4D, SIGLIP_MEAN, SIGLIP_STD | |
| from transformers import AutoModelForCausalLM, AutoTokenizer | |
# ImageNet normalization stats reshaped to (3, 1, 1) so they broadcast over a
# channels-first image tensor.
IMAGENET_MEAN = IM4D.view(3, 1, 1)
IMAGENET_STD = IS4D.view(3, 1, 1)
# Number of <ACTION_i> tokens appended to every prompt.
NUM_ACTION_TOKENS = 32  # OmniVLA uses 8 chunks × 4 DoF
# Number of visual tokens spliced into the LLM input sequence.
NUM_VIS = 452  # 256 dino patches + 196 siglip patches
def to_siglip(pv, src_mean=None, src_std=None, dst_mean=None, dst_std=None):
    """Re-normalize ImageNet-normalized pixels into SigLIP's normalization.

    Undoes ``(x - src_mean) / src_std`` and then applies
    ``(x - dst_mean) / dst_std``. The statistics default to the module's
    ImageNet constants (source) and ``SIGLIP_MEAN`` / ``SIGLIP_STD``
    (destination). All of them are moved to ``pv``'s device.

    The result is cast back to ``pv.dtype``. If the statistics are stored in
    a wider float type than ``pv``, type promotion would otherwise return
    float32, and a bf16 backbone would reject that input.

    Args:
        pv: normalized image tensor whose channel axis lines up with the
            statistics' shape (the module defaults are (3, 1, 1) views).
        src_mean, src_std, dst_mean, dst_std: optional override statistics.

    Returns:
        The re-normalized tensor on ``pv``'s device, with ``pv``'s dtype.
    """
    dev = pv.device
    src_mean = (IMAGENET_MEAN if src_mean is None else src_mean).to(dev)
    src_std = (IMAGENET_STD if src_std is None else src_std).to(dev)
    dst_mean = (SIGLIP_MEAN if dst_mean is None else dst_mean).to(dev)
    dst_std = (SIGLIP_STD if dst_std is None else dst_std).to(dev)
    return ((pv * src_std + src_mean - dst_mean) / dst_std).to(pv.dtype)
# ─────────────────────────────────────────────────────────────
# Dataset — ADAPT THE IMAGE/INSTRUCTION LOGIC TO YOUR FORMAT
# ─────────────────────────────────────────────────────────────
class DistillDataset(Dataset):
    """Per-step dataset pairing a camera frame and instruction with cached teacher hidden states.

    Each ``episode_*.pt`` file in ``cache_dir`` is expected to contain:
        episode_id: str
        num_steps: int
        hidden_states: Tensor[T, 32, teacher_dim]
        (optional) instructions: list of length T; an entry may itself be a
            list, in which case its first element is used.

    Image paths are constructed as ``{data_dir}/{episode_id}/img/step_{t:04d}.png``.
    Override ``_load_image`` / ``_get_instruction`` for custom formats.

    Episodes are split by sorted filename. The first ``1 - val_ratio``
    fraction goes to ``split="train"``; every other split value gets the rest.
    """

    def __init__(self, cache_dir, data_dir, split="train", val_ratio=0.1):
        self.data_dir = Path(data_dir)
        cache_files = sorted(Path(cache_dir).glob("episode_*.pt"))
        split_idx = int(len(cache_files) * (1 - val_ratio))
        files = cache_files[:split_idx] if split == "train" else cache_files[split_idx:]
        # Flat list of (episode file, step) pairs. Each episode is loaded once
        # here just to read its length.
        self.index = []
        for cf in files:
            d = torch.load(cf, weights_only=True)
            self.index.extend((cf, t) for t in range(d["num_steps"]))
        # Episodes loaded lazily by __getitem__, keyed by path string.
        # NOTE(review): unbounded, so every episode touched stays resident.
        self._cache = {}
        self._instr_cache = {}  # not read by this class; kept for compatibility
        print(f" [{split}] {len(self.index)} steps from {len(files)} episodes", flush=True)

    def __len__(self):
        return len(self.index)

    def __getitem__(self, idx):
        """Return a dict with ``cur_img``, float ``hs_target`` and stripped ``instruction``."""
        cf_path, t = self.index[idx]
        key = str(cf_path)
        ep = self._cache.get(key)
        if ep is None:
            ep = self._cache[key] = torch.load(cf_path, weights_only=True)
        return {
            "cur_img": self._load_image(ep, t),
            "hs_target": ep["hidden_states"][t].float(),
            "instruction": self._get_instruction(ep, t),
        }

    def _load_image(self, ep, t):
        """Load frame ``t`` of episode ``ep`` as an ImageNet-normalized (3, 224, 224) tensor."""
        from torchvision.transforms.functional import resize as tv_resize

        path = self.data_dir / ep["episode_id"] / "img" / f"step_{t:04d}.png"
        # Image.open is lazy and keeps the file open; close it as soon as the
        # pixels have been read.
        with Image.open(path) as im:
            # Resize both edges to 224 so every frame yields the fixed visual
            # token count (NUM_VIS) the training loop assumes. An int size would
            # only fix the shorter edge.
            # NOTE(review): confirm squashing matches the preprocessing used
            # when the teacher cache was built.
            img = tv_resize(im.convert("RGB"), [224, 224])
        arr = np.asarray(img, dtype=np.float32) / 255.0
        chw = torch.from_numpy(arr).permute(2, 0, 1)  # HWC -> CHW
        return (chw - IMAGENET_MEAN) / IMAGENET_STD

    def _get_instruction(self, ep, t):
        """Return the stripped instruction for step ``t``, or ``"move forward"`` if none is cached."""
        if "instructions" not in ep:
            return "move forward"
        instr = ep["instructions"][t]
        if isinstance(instr, list):
            # Several candidate instructions per step: use the first one.
            # An empty list falls back to the default instead of raising IndexError.
            if not instr:
                return "move forward"
            instr = instr[0]
        return str(instr).strip()
def find_action_offset(tokenizer, action_token_ids, user_content="test"):
    """Return the index of the first action token in a chat-templated sequence.

    Builds a system/user/assistant conversation whose assistant turn is the
    space-joined ``<ACTION_i>`` tokens, one per entry of ``action_token_ids``.
    It tokenizes the conversation with ``tokenizer.apply_chat_template`` and
    returns the position of the first token whose id is in ``action_token_ids``.

    The offset depends on how many tokens the system and user turns take. It
    is therefore only valid for prompts whose user turn tokenizes to the same
    length as ``user_content``; pass the real prompt text when the result will
    index real batches.

    Args:
        tokenizer: HF tokenizer that already contains the ``<ACTION_i>`` tokens.
        action_token_ids: ids of the action tokens, in ``<ACTION_0>..`` order.
        user_content: text of the user turn in the dummy conversation.

    Returns:
        int index into the templated ``input_ids``.

    Raises:
        ValueError: if no action token appears in the tokenized sequence.
    """
    ids_list = list(action_token_ids)
    actions = " ".join(f"<ACTION_{i}>" for i in range(len(ids_list)))
    dummy = tokenizer.apply_chat_template(
        [{"role": "system", "content": "You are a helpful assistant."},
         {"role": "user", "content": user_content},
         {"role": "assistant", "content": actions}],
        tokenize=True, add_generation_prompt=False, return_dict=True, return_tensors="pt",
    )
    ids = dummy["input_ids"].squeeze(0)
    # Membership test instead of a min/max range check: correct even if the
    # action-token ids are not contiguous.
    wanted = torch.as_tensor(ids_list, dtype=ids.dtype)
    pos = torch.nonzero(torch.isin(ids, wanted)).flatten()
    if pos.numel() == 0:
        raise ValueError(
            "no action tokens found in the templated prompt; "
            "were they added to the tokenizer?"
        )
    return pos[0].item()
def main():
    """Train the hidden-state shim mapping the student's 896-d states to the teacher's.

    Steps:
      1. Load the frozen student from the ``--base-model`` checkpoint: the
         DINO/SigLIP vision backbone, the projector and the Qwen2.5-0.5B LLM.
      2. Feed each (image, instruction) pair through the student, with the
         action-token input embeddings zeroed.
      3. Regress the shim's output at the action-token positions onto the
         cached teacher hidden states with an MSE loss.

    Only the shim is trained. Writes ``shim_best.pt`` (lowest validation loss)
    and ``step_{N}/shim.pt`` every ``--save-every`` optimizer steps under
    ``--run-name``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--cache-dir", type=str, required=True)
    parser.add_argument("--data-dir", type=str, required=True,
                        help="Dataset root with {episode_id}/img/step_*.png")
    parser.add_argument("--base-model", type=str, default="theguy21/openvla-micro",
                        help="Local checkpoint file opened with torch.load (must contain a "
                             "'model' dict); it is not fetched from the Hub.")
    parser.add_argument("--teacher-dim", type=int, default=4096)
    parser.add_argument("--max-steps", type=int, default=10000)
    parser.add_argument("--batch-size", type=int, default=8)
    parser.add_argument("--lr", type=float, default=5e-5)
    parser.add_argument("--grad-accum", type=int, default=4)
    parser.add_argument("--val-every", type=int, default=500)
    parser.add_argument("--save-every", type=int, default=5000)
    parser.add_argument("--resume", type=str, default=None)
    parser.add_argument("--run-name", type=str, default="shim_run")
    parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu")
    args = parser.parse_args()
    device = torch.device(args.device)
    dtype = torch.bfloat16  # dtype of the frozen student; the shim trains in float32
    print(f"Device: {device}")
    run_dir = Path(args.run_name)
    run_dir.mkdir(parents=True, exist_ok=True)

    # ── 1. Load base model ──
    print("\n[1] Loading base model...")
    # NOTE(review): weights_only=False unpickles arbitrary Python objects; only
    # point --base-model at checkpoints from a trusted source.
    ckpt = torch.load(os.path.expanduser(args.base_model), map_location="cpu", weights_only=False)
    msd = ckpt["model"]
    ve = DinoSigLIPEncoder().eval()
    ve.load_state_dict(msd["vision_backbone"])
    ve.to(device, dtype=dtype)
    ve.requires_grad_(False)
    projector = CombinedProjector(ShimMLP(384), ShimMLP(768), nn.Linear(8704, 896), nn.Linear(896, 896))
    projector.load_state_dict(msd["projector"])
    projector.to(device, dtype=dtype).eval()
    projector.requires_grad_(False)
    llm = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B", torch_dtype=dtype)
    # Strip the first "llm." from each checkpoint key (presumably a wrapper prefix).
    llm.load_state_dict({k.replace("llm.", "", 1): v for k, v in msd["llm_backbone"].items()})
    llm.to(device, dtype=dtype).eval()
    llm.requires_grad_(False)
    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B", use_fast=True)
    action_tokens = [f"<ACTION_{i}>" for i in range(NUM_ACTION_TOKENS)]
    tokenizer.add_tokens(action_tokens)
    action_token_ids = tokenizer.convert_tokens_to_ids(action_tokens)
    action_ids_t = torch.tensor(action_token_ids, device=device)
    action_str = " ".join(action_tokens)
    # Informational only, and raises if the action tokens cannot be found. The
    # offset depends on prompt length, so the training mask is built per
    # sample from token ids instead of from this value.
    action_offset = find_action_offset(tokenizer, action_token_ids)
    print(f" Action tokens at position {action_offset}")

    # ── 2. Shim ──
    print("\n[2] Building shim...")
    shim = nn.Sequential(nn.Linear(896, 2048), nn.GELU(), nn.Linear(2048, args.teacher_dim))
    if args.resume:
        shim.load_state_dict(torch.load(args.resume, map_location="cpu", weights_only=True))
        print(f" Resumed from {args.resume}")
    # Keep float32 master weights. With bf16 parameters, AdamW steps of about
    # lr (5e-5) fall below half a bf16 ulp at typical weight magnitudes and
    # are rounded away.
    shim.to(device).train()

    def shim_state():
        """Shim state dict cast to the student dtype (bf16, as this script has always saved)."""
        return {k: v.to(dtype) for k, v in shim.state_dict().items()}

    # ── 3. Data ──
    print("\n[3] Loading data...")
    train_ds = DistillDataset(args.cache_dir, args.data_dir, split="train")
    val_ds = DistillDataset(args.cache_dir, args.data_dir, split="val")
    if len(train_ds) == 0:
        raise SystemExit(f"No training steps found under {args.cache_dir}")

    def collate(batch):
        """Stack images/targets and tokenize each instruction into the padded chat prompt."""
        chat = [
            [{"role": "system", "content": "You are a helpful assistant."},
             {"role": "user", "content": f"What action should the robot take to {b['instruction'].lower()}?"},
             {"role": "assistant", "content": action_str}]
            for b in batch
        ]
        tok = tokenizer.apply_chat_template(chat, tokenize=True, add_generation_prompt=False,
                                            return_dict=True, return_tensors="pt", padding=True)
        return {
            "cur_img": torch.stack([b["cur_img"] for b in batch]),
            "input_ids": tok["input_ids"],
            "attention_mask": tok["attention_mask"],
            "hs_target": torch.stack([b["hs_target"] for b in batch]),
        }

    train_loader = DataLoader(train_ds, batch_size=args.batch_size, shuffle=True, collate_fn=collate, num_workers=0)
    val_loader = DataLoader(val_ds, batch_size=1, shuffle=False, collate_fn=collate, num_workers=0)

    # ── 4. Optimizer ──
    opt = torch.optim.AdamW(shim.parameters(), lr=args.lr, weight_decay=0.01)
    sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=args.max_steps)

    # ── 5. Training ──
    print("\n[4] Training...")
    dino, siglip = ve.dino_featurizer, ve.siglip_featurizer

    def encode_image(cur):
        """Encode ImageNet-normalized images into projected visual tokens for the LLM."""
        with torch.no_grad():
            df = dino(cur)
            if isinstance(df, (list, tuple)):
                df = df[0]
            df = df[:, 1:]  # drop the first token
            # Cast in case the normalization constants promoted the dtype.
            sf = siglip(to_siglip(cur).to(dtype))
            if isinstance(sf, (list, tuple)):
                sf = sf[0]
            sf = sf[:, 1:]
            B, D = cur.shape[0], 1152

            def pad(f, ed):
                # Zero-pad to the shared width D, keeping the first `ed` channels.
                p = torch.zeros(B, f.shape[1], D, device=device, dtype=dtype)
                p[..., :ed] = f[..., :ed]
                return p

            return projector(torch.cat([pad(df, 384), pad(sf, 768)], dim=1))

    def shim_forward(batch):
        """Shim the student's last-layer hidden states at the action tokens of a collated batch.

        Returns (hs_shimmed, hs_target), both float32 with shape
        [B, NUM_ACTION_TOKENS, teacher_dim]. Gradients reach only the shim.
        """
        cur_img = batch["cur_img"].to(device, dtype=dtype)
        inp = batch["input_ids"].to(device)
        am = batch["attention_mask"].to(device)
        hs_target = batch["hs_target"].to(device, dtype=torch.float32)
        B = inp.shape[0]
        with torch.no_grad():
            vis = encode_image(cur_img)
            n_vis = vis.shape[1]
            embed = llm.get_input_embeddings()(inp)
            # Splice the visual tokens in after the first text token.
            mm = torch.cat([embed[:, :1], vis, embed[:, 1:]], dim=1)
            mm_attn = torch.cat([am[:, :1], am.new_ones(B, n_vis), am[:, 1:]], dim=1)
            # Locate the action tokens by id in each row, because their position
            # depends on the instruction length. Then shift the mask exactly as
            # the embeddings were shifted.
            is_act = torch.isin(inp, action_ids_t)
            counts = is_act.sum(dim=1)
            if not bool((counts == NUM_ACTION_TOKENS).all()):
                raise ValueError(
                    f"expected {NUM_ACTION_TOKENS} action tokens per sample, got {counts.tolist()}"
                )
            mask = torch.cat([is_act[:, :1], is_act.new_zeros(B, n_vis), is_act[:, 1:]], dim=1)
            # Action-token inputs are zeroed; only their output hidden states are used.
            mm = mm.masked_fill(mask.unsqueeze(-1), 0)
            with torch.autocast(device_type=device.type, dtype=dtype):
                out = llm(inputs_embeds=mm, attention_mask=mm_attn, labels=None,
                          output_hidden_states=True, return_dict=True)
            # Boolean indexing flattens row-major, so each row keeps its tokens in order.
            hs_act = out.hidden_states[-1][mask].view(B, NUM_ACTION_TOKENS, -1)
        hs_shimmed = shim(hs_act.float())
        if hs_shimmed.shape != hs_target.shape:
            raise ValueError(
                f"shim output {tuple(hs_shimmed.shape)} does not match "
                f"cached target {tuple(hs_target.shape)}"
            )
        return hs_shimmed, hs_target

    def mean_cos(a, b):
        """Mean per-token cosine similarity between two [..., dim] tensors."""
        return F.cosine_similarity(a.float().reshape(-1, a.shape[-1]),
                                   b.float().reshape(-1, b.shape[-1]), dim=-1).mean().item()

    def validate():
        """Return (mean MSE, mean cosine) over the validation set, or None if it is empty."""
        shim.eval()
        v_loss, v_cos, nv = 0.0, 0.0, 0
        with torch.no_grad():
            for vb in val_loader:
                hs, ht = shim_forward(vb)
                v_loss += F.mse_loss(hs, ht).item()
                v_cos += mean_cos(hs, ht)
                nv += 1
        if nv == 0:
            return None
        return v_loss / nv, v_cos / nv

    best_loss = float("inf")
    global_step = 0
    train_iter = iter(train_loader)
    pbar = tqdm(total=args.max_steps, desc="Train")
    while global_step < args.max_steps:
        shim.train()
        opt.zero_grad()
        accum_loss = 0.0
        for _ in range(args.grad_accum):
            try:
                batch = next(train_iter)
            except StopIteration:
                train_iter = iter(train_loader)
                batch = next(train_iter)
            hs_shimmed, hs_target = shim_forward(batch)
            loss = F.mse_loss(hs_shimmed, hs_target)
            (loss / args.grad_accum).backward()
            accum_loss += loss.item()
        torch.nn.utils.clip_grad_norm_(shim.parameters(), 1.0)
        opt.step()
        sched.step()
        global_step += 1
        if global_step % 100 == 0:
            with torch.no_grad():
                cos = mean_cos(hs_shimmed, hs_target)  # last micro-batch only
            pbar.set_postfix({"loss": f"{accum_loss/args.grad_accum:.5f}", "cos": f"{cos:.4f}"})
        pbar.update(1)

        # Validation
        if global_step % args.val_every == 0:
            result = validate()
            if result is None:
                print("\n (no validation data; skipping validation)", flush=True)
            else:
                v_loss, v_cos = result
                print(f"\n─── Val @ {global_step}: loss={v_loss:.5f} cos={v_cos:.4f} ───", flush=True)
                if v_loss < best_loss:
                    best_loss = v_loss
                    torch.save(shim_state(), run_dir / "shim_best.pt")
                    print(f" ✓ Saved best (loss={v_loss:.5f})")
        if global_step % args.save_every == 0:
            d = run_dir / f"step_{global_step}"
            d.mkdir(exist_ok=True)
            torch.save(shim_state(), d / "shim.pt")
    pbar.close()
    print(f"\nDone! Best val loss: {best_loss:.5f}")
# Run training only when executed as a script, not when imported.
if __name__ == "__main__":
    main()