"""Train the hidden-state shim (896 -> 4096) for OpenVLA-Micro.

The shim maps Qwen2.5-0.5B's 896-dim hidden states into a teacher LLM's
4096-dim hidden-state space (e.g. Llama-2, Llama-3). This lets the small model
drive OmniVLA's pretrained action head with near-zero accuracy loss.

Workflow:
    1. Cache your teacher's hidden states on your dataset.
    2. Run this script to train the shim.
    3. Bake the shim into the checkpoint with bake_shim.py.

Usage:
    python train_shim.py --cache-dir ./my_cache --data-dir ./my_data --base-model /path/to/openvla_micro.pt

``--base-model`` is opened with ``torch.load``, so it must be a local
checkpoint file whose ``"model"`` entry holds the ``vision_backbone``,
``projector`` and ``llm_backbone`` state dicts. If you only have the Hub repo
id (e.g. ``theguy21/openvla-micro``), download the checkpoint file first.

For the full training pipeline used in openvla-micro-distill, see:
https://huggingface.co/theguy21/openvla-micro
"""
import argparse
import json
import os
from pathlib import Path

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from tqdm import tqdm

from modeling_openvla_micro import DinoSigLIPEncoder, CombinedProjector, ShimMLP
from model_wrapper import IMAGENET_MEAN as IM4D, IMAGENET_STD as IS4D, SIGLIP_MEAN, SIGLIP_STD
from transformers import AutoModelForCausalLM, AutoTokenizer

# Per-channel ImageNet statistics reshaped to broadcast over a [3, H, W] image.
IMAGENET_MEAN = IM4D.view(3, 1, 1)
IMAGENET_STD = IS4D.view(3, 1, 1)

NUM_ACTION_TOKENS = 32  # OmniVLA uses 8 chunks × 4 DoF
NUM_VIS = 452  # 256 dino patches + 196 siglip patches
IMAGE_SIZE = 224  # square input resolution fed to both vision featurizers

# NOTE(review): the literal action-token text was lost from this copy of the
# script (every occurrence read f""). Hugging Face ``add_tokens`` appends new
# tokens with consecutive ids, so these strings only need to be unique and
# absent from the Qwen vocab; "<ACT_{i}>" is a placeholder — confirm it
# matches the token names used when the checkpoint was built.
ACTION_TOKENS = [f"<ACT_{i}>" for i in range(NUM_ACTION_TOKENS)]
# Assistant-turn text: the action tokens separated by single spaces.
ACTION_STRING = " ".join(ACTION_TOKENS)


def to_siglip(pv):
    """Re-normalize ImageNet-normalized pixels ``pv`` with the SigLIP mean/std.

    Undoes the ImageNet normalization, then applies SigLIP's. Assumes
    ``SIGLIP_MEAN``/``SIGLIP_STD`` broadcast against ``pv`` (presumably
    [B, 3, H, W]) — TODO confirm against model_wrapper.
    """
    return (pv * IMAGENET_STD.to(pv.device) + IMAGENET_MEAN.to(pv.device)
            - SIGLIP_MEAN.to(pv.device)) / SIGLIP_STD.to(pv.device)


# ─────────────────────────────────────────────────────────────
# Dataset — ADAPT THE IMAGE/INSTRUCTION LOGIC TO YOUR FORMAT
# ─────────────────────────────────────────────────────────────
class DistillDataset(Dataset):
    """Per-step samples of (image, instruction, teacher hidden states).

    Each ``episode_*.pt`` in ``cache_dir`` is expected to contain:
        episode_id:    str
        num_steps:     int
        hidden_states: Tensor[T, 32, teacher_dim]  (required for training)
        instructions:  list[str] of length T       (optional; defaults to
                       "move forward"; a list entry uses its first element)

    Image paths are constructed as
    ``{data_dir}/{episode_id}/img/step_{t:04d}.png``.
    Override ``_load_image`` / ``_get_instruction`` for custom formats.

    Episodes are split by sorted filename: ``split="train"`` takes the first
    ``1 - val_ratio`` fraction of files, any other split takes the rest.
    """

    def __init__(self, cache_dir, data_dir, split="train", val_ratio=0.1):
        self.data_dir = Path(data_dir)
        cache_files = sorted(Path(cache_dir).glob("episode_*.pt"))
        split_idx = int(len(cache_files) * (1 - val_ratio))
        files = cache_files[:split_idx] if split == "train" else cache_files[split_idx:]

        # Flat (episode file, step) index. Each file is fully loaded here just
        # to read its step count; the tensors are dropped right away.
        self.index = []
        for cf in files:
            num_steps = torch.load(cf, weights_only=True)["num_steps"]
            self.index.extend((cf, t) for t in range(num_steps))

        # Episodes are loaded lazily in __getitem__ and kept for reuse.
        # NOTE: unbounded — every episode touched stays resident in RAM.
        self._cache = {}
        print(f"  [{split}] {len(self.index)} steps from {len(files)} episodes", flush=True)

    def __len__(self):
        return len(self.index)

    def __getitem__(self, idx):
        """Return ``{"cur_img", "hs_target", "instruction"}`` for one step.

        ``cur_img`` is an ImageNet-normalized [3, 224, 224] float tensor and
        ``hs_target`` is the cached teacher hidden states for the step, as
        float32.
        """
        cf_path, t = self.index[idx]
        ep = self._load_episode(cf_path)
        if "hidden_states" not in ep:
            raise KeyError(f"{cf_path} has no 'hidden_states'; cache the teacher's hidden states first")
        return {
            "cur_img": self._load_image(ep["episode_id"], t),
            "hs_target": ep["hidden_states"][t].float(),
            "instruction": self._get_instruction(ep, t),
        }

    def _load_episode(self, cf_path):
        """Return the episode dict stored in ``cf_path``, loading it on first use."""
        key = str(cf_path)
        ep = self._cache.get(key)
        if ep is None:
            ep = self._cache[key] = torch.load(cf_path, weights_only=True)
        return ep

    def _load_image(self, episode_id, t):
        """Load step ``t`` of ``episode_id`` as an ImageNet-normalized [3, 224, 224] tensor."""
        from torchvision.transforms.functional import resize as tv_resize

        path = self.data_dir / episode_id / "img" / f"step_{t:04d}.png"
        with Image.open(path) as raw:
            # Resize to an exact square: NUM_VIS assumes a 224x224 input, and a
            # shortest-side resize would change the patch count for
            # non-square frames. Square frames are unaffected.
            img = tv_resize(raw.convert("RGB"), [IMAGE_SIZE, IMAGE_SIZE])
        pixels = torch.from_numpy(np.asarray(img, dtype=np.float32) / 255.0).permute(2, 0, 1)
        return (pixels - IMAGENET_MEAN) / IMAGENET_STD

    def _get_instruction(self, ep, t):
        """Return the stripped instruction for step ``t`` ("move forward" if none cached)."""
        if "instructions" not in ep:
            return "move forward"
        instr = ep["instructions"][t]
        if isinstance(instr, list):
            instr = instr[0]
        return str(instr).strip()


def find_action_offset(tokenizer, action_token_ids):
    """Return the index of the first action token in a fixed dummy chat prompt.

    The offset is only valid for this dummy prompt (user message "test").
    Real prompts embed the instruction, so their action tokens start at
    different positions. Use a per-sample id mask (as ``main`` does) rather
    than this offset to locate them.

    Raises:
        ValueError: if no action token appears in the templated prompt.
    """
    dummy = tokenizer.apply_chat_template(
        [{"role": "system", "content": "You are a helpful assistant."},
         {"role": "user", "content": "test"},
         {"role": "assistant", "content": ACTION_STRING}],
        tokenize=True, add_generation_prompt=False, return_dict=True, return_tensors="pt",
    )
    ids = dummy["input_ids"].squeeze(0)
    pos = torch.where(torch.isin(ids, torch.as_tensor(action_token_ids)))[0]
    if pos.numel() == 0:
        raise ValueError("no action tokens found in the chat template; were ACTION_TOKENS added to the tokenizer?")
    return pos[0].item()


def main():
    """Parse CLI args, train the shim against cached teacher states, and save checkpoints."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--cache-dir", type=str, required=True)
    parser.add_argument("--data-dir", type=str, required=True, help="Dataset root with {episode_id}/img/step_*.png")
    parser.add_argument("--base-model", type=str, default="theguy21/openvla-micro",
                        help="Local OpenVLA-Micro checkpoint file (loaded with torch.load)")
    parser.add_argument("--teacher-dim", type=int, default=4096)
    parser.add_argument("--max-steps", type=int, default=10000)
    parser.add_argument("--batch-size", type=int, default=8)
    parser.add_argument("--lr", type=float, default=5e-5)
    parser.add_argument("--grad-accum", type=int, default=4)
    parser.add_argument("--val-every", type=int, default=500)
    parser.add_argument("--save-every", type=int, default=5000)
    parser.add_argument("--resume", type=str, default=None)
    parser.add_argument("--run-name", type=str, default="shim_run")
    parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu")
    args = parser.parse_args()
    # These are used as divisors / loop counts below; zero would crash mid-run.
    for name in ("grad_accum", "val_every", "save_every"):
        if getattr(args, name) < 1:
            parser.error(f"--{name.replace('_', '-')} must be >= 1")

    device = torch.device(args.device)
    dtype = torch.bfloat16
    print(f"Device: {device}")
    run_dir = Path(args.run_name)
    run_dir.mkdir(parents=True, exist_ok=True)

    # ── 1. Load base model ──
    print("\n[1] Loading base model...")
    ckpt_path = Path(os.path.expanduser(args.base_model))
    if not ckpt_path.is_file():
        raise FileNotFoundError(
            f"--base-model must be a local checkpoint file, got {args.base_model!r}; "
            "download the OpenVLA-Micro checkpoint and pass its path")
    # SECURITY: weights_only=False unpickles arbitrary objects — only load checkpoints you trust.
    ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=False)
    msd = ckpt["model"]

    ve = DinoSigLIPEncoder().eval()
    ve.load_state_dict(msd["vision_backbone"])
    ve.to(device, dtype=dtype)
    for p in ve.parameters():
        p.requires_grad_(False)

    projector = CombinedProjector(ShimMLP(384), ShimMLP(768), nn.Linear(8704, 896), nn.Linear(896, 896))
    projector.load_state_dict(msd["projector"])
    projector.to(device, dtype=dtype).eval()
    for p in projector.parameters():
        p.requires_grad_(False)

    llm = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B", torch_dtype=dtype)
    llm_sd = {k.replace("llm.", "", 1): v for k, v in msd["llm_backbone"].items()}
    llm.load_state_dict(llm_sd)
    llm.to(device, dtype=dtype).eval()
    for p in llm.parameters():
        p.requires_grad_(False)

    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B", use_fast=True)
    # Vision tokens are spliced in after index 0 of every row, so each prompt
    # must start at index 0: pad on the right.
    tokenizer.padding_side = "right"
    tokenizer.add_tokens(ACTION_TOKENS)
    action_token_ids = tokenizer.convert_tokens_to_ids(ACTION_TOKENS)
    action_ids_t = torch.tensor(action_token_ids, device=device)
    print(f"  Action token ids {action_token_ids[0]}..{action_token_ids[-1]}")

    # ── 2. Shim ──
    print("\n[2] Building shim...")
    shim = nn.Sequential(nn.Linear(896, 2048), nn.GELU(), nn.Linear(2048, args.teacher_dim))
    if args.resume:
        shim.load_state_dict(torch.load(args.resume, map_location="cpu", weights_only=True))
        print(f"  Resumed from {args.resume}")
    # fp32 master weights: with pure-bf16 params, AdamW steps of ~lr are
    # mostly below half a bf16 ulp and get rounded away. Compute still runs
    # in bf16 via autocast (see shim_forward).
    shim.to(device).train()

    def save_shim(path):
        # Stored as bf16 tensors (the dtype the pipeline runs in).
        # NOTE(review): switch to fp32 if bake_shim.py accepts it and you want
        # lossless resumes.
        torch.save({k: v.to(dtype) for k, v in shim.state_dict().items()}, path)

    # ── 3. Data ──
    print("\n[3] Loading data...")
    train_ds = DistillDataset(args.cache_dir, args.data_dir, split="train")
    val_ds = DistillDataset(args.cache_dir, args.data_dir, split="val")
    if len(train_ds) == 0:
        raise RuntimeError(f"no training steps found: need episode_*.pt files in {args.cache_dir}")

    def collate(batch):
        """Stack images/targets and tokenize one chat prompt per sample (right-padded)."""
        chat = [
            [{"role": "system", "content": "You are a helpful assistant."},
             {"role": "user", "content": f"What action should the robot take to {b['instruction'].lower()}?"},
             {"role": "assistant", "content": ACTION_STRING}]
            for b in batch
        ]
        tok = tokenizer.apply_chat_template(chat, tokenize=True, add_generation_prompt=False, return_dict=True,
                                            return_tensors="pt", padding=True)
        return {"cur_img": torch.stack([b["cur_img"] for b in batch]),
                "input_ids": tok["input_ids"],
                "attention_mask": tok["attention_mask"],
                "hs_target": torch.stack([b["hs_target"] for b in batch])}

    train_loader = DataLoader(train_ds, batch_size=args.batch_size, shuffle=True, collate_fn=collate, num_workers=0)
    val_loader = DataLoader(val_ds, batch_size=1, shuffle=False, collate_fn=collate, num_workers=0)

    # ── 4. Optimizer ──
    opt = torch.optim.AdamW(shim.parameters(), lr=args.lr, weight_decay=0.01)
    sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=args.max_steps)

    # ── 5. Training ──
    print("\n[4] Training...")
    dino, siglip = ve.dino_featurizer, ve.siglip_featurizer
    embed_tokens = llm.get_input_embeddings()

    def encode_image(cur):
        """Project DINO + SigLIP patch features of ``cur`` into the LLM embedding space."""
        with torch.no_grad():
            df = dino(cur)
            if isinstance(df, (list, tuple)):
                df = df[0]
            df = df[:, 1:]  # drop the first token (presumably CLS)
            sf = siglip(to_siglip(cur))
            if isinstance(sf, (list, tuple)):
                sf = sf[0]
            sf = sf[:, 1:]
            batch, width = cur.shape[0], 1152

            def pad(f, ed):
                # Copy the first `ed` channels into a zero buffer `width` wide so both
                # featurizers can be concatenated along the token axis.
                p = torch.zeros(batch, f.shape[1], width, device=device, dtype=dtype)
                p[..., :ed] = f[..., :ed]
                return p

            return projector(torch.cat([pad(df, 384), pad(sf, 768)], dim=1))

    @torch.no_grad()
    def student_action_states(cur_img, input_ids, attention_mask):
        """Run the frozen student; return last-layer hidden states at the action tokens.

        Returns a [B, NUM_ACTION_TOKENS, hidden] tensor.
        """
        batch = cur_img.shape[0]
        vis = encode_image(cur_img)
        n_vis = vis.shape[1]
        embed = embed_tokens(input_ids)
        mm = torch.cat([embed[:, :1], vis, embed[:, 1:]], dim=1)
        mm_attn = torch.cat([attention_mask[:, :1], attention_mask.new_ones(batch, n_vis), attention_mask[:, 1:]],
                            dim=1)

        # Locate action tokens per sample by id: their position depends on the
        # instruction length, and the spaces in ACTION_STRING separate them.
        is_act = torch.isin(input_ids, action_ids_t)
        counts = is_act.sum(dim=1)
        if not bool((counts == NUM_ACTION_TOKENS).all()):
            raise ValueError(f"expected {NUM_ACTION_TOKENS} action tokens per sample, found {counts.tolist()}")
        mask = torch.cat([is_act[:, :1], is_act.new_zeros(batch, n_vis), is_act[:, 1:]], dim=1)

        # The action-token input embeddings are zeroed; only their positions
        # (and hence their output hidden states) are used.
        mm = mm.masked_fill(mask.unsqueeze(-1), 0)
        with torch.autocast(device_type=device.type, dtype=dtype):
            out = llm(inputs_embeds=mm, attention_mask=mm_attn, labels=None,
                      output_hidden_states=True, return_dict=True)
        # Boolean indexing keeps row-major order, and every row has exactly
        # NUM_ACTION_TOKENS hits, so the result reshapes cleanly per sample.
        return out.hidden_states[-1][mask].view(batch, NUM_ACTION_TOKENS, -1)

    def shim_forward(hs_student, hs_target):
        """Apply the shim in bf16 and return ``(prediction, fp32 MSE loss)``."""
        with torch.autocast(device_type=device.type, dtype=dtype):
            pred = shim(hs_student)
        return pred, F.mse_loss(pred.float(), hs_target.float())

    def mean_cosine(pred, target):
        """Mean per-token cosine similarity between prediction and target."""
        return F.cosine_similarity(pred.float().reshape(-1, args.teacher_dim),
                                   target.float().reshape(-1, args.teacher_dim), dim=-1).mean().item()

    def batch_to_device(b):
        return (b["cur_img"].to(device, dtype=dtype), b["input_ids"].to(device),
                b["attention_mask"].to(device), b["hs_target"].to(device))

    def evaluate():
        """Return mean per-batch (MSE, cosine) over the val set, or (nan, nan) if it is empty."""
        shim.eval()
        total_loss, total_cos, n = 0.0, 0.0, 0
        with torch.no_grad():
            for vb in val_loader:
                ci, ip, am, ht = batch_to_device(vb)
                pred, loss = shim_forward(student_action_states(ci, ip, am), ht)
                total_loss += loss.item()
                total_cos += mean_cosine(pred, ht)
                n += 1
        if n == 0:
            return float("nan"), float("nan")
        return total_loss / n, total_cos / n

    best_loss = float("inf")
    global_step = 0
    train_iter = iter(train_loader)
    pbar = tqdm(total=args.max_steps, desc="Train")
    while global_step < args.max_steps:
        shim.train()
        opt.zero_grad()
        accum_loss = 0.0
        for _ in range(args.grad_accum):
            try:
                batch = next(train_iter)
            except StopIteration:  # epoch boundary: restart the shuffled loader
                train_iter = iter(train_loader)
                batch = next(train_iter)
            cur_img, inp, am, hs_target = batch_to_device(batch)
            hs_shimmed, loss = shim_forward(student_action_states(cur_img, inp, am), hs_target)
            (loss / args.grad_accum).backward()
            accum_loss += loss.item()
        torch.nn.utils.clip_grad_norm_(shim.parameters(), 1.0)
        opt.step()
        sched.step()
        global_step += 1

        if global_step % 100 == 0:
            with torch.no_grad():
                cos = mean_cosine(hs_shimmed, hs_target)  # last micro-batch only
            pbar.set_postfix({"loss": f"{accum_loss/args.grad_accum:.5f}", "cos": f"{cos:.4f}"})
        pbar.update(1)

        # Validation
        if global_step % args.val_every == 0:
            v_loss, v_cos = evaluate()
            print(f"\n─── Val @ {global_step}: loss={v_loss:.5f} cos={v_cos:.4f} ───", flush=True)
            if v_loss < best_loss:  # False for nan (empty val set)
                best_loss = v_loss
                save_shim(run_dir / "shim_best.pt")
                print(f"  → Saved best (loss={v_loss:.5f})")

        if global_step % args.save_every == 0:
            d = run_dir / f"step_{global_step}"
            d.mkdir(exist_ok=True)
            save_shim(d / "shim.pt")

    pbar.close()
    # Always keep the last weights, even when max_steps is not a multiple of save_every.
    save_shim(run_dir / "shim_final.pt")
    print(f"\nDone! Best val loss: {best_loss:.5f}")


if __name__ == "__main__":
    main()