# openvla-micro / train_shim.py
# From the Hugging Face Hub repo theguy21/openvla-micro, commit bd89217
# ("Add CPU inference script, update README with model details and perf stats").
"""
Train the hidden state shim (896→4096) for OpenVLA-Micro.
The shim maps Qwen2.5 0.5B's 896-dim hidden states to match a teacher
LLM's 4096-dim space (e.g., Llama-2, Llama-3). This lets the small model
drive OmniVLA's pretrained action head with near-zero accuracy loss.
Workflow:
1. Cache your teacher's hidden states on your dataset
2. Run this script to train the shim
3. Bake the shim into the checkpoint with bake_shim.py
Usage:
python train_shim.py --cache-dir ./my_cache --data-dir ./my_data --base-model /path/to/local_checkpoint.pt
For the full training pipeline used in openvla-micro-distill, see:
https://huggingface.co/theguy21/openvla-micro
"""
import argparse, json, os
from pathlib import Path
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from tqdm import tqdm
from modeling_openvla_micro import DinoSigLIPEncoder, CombinedProjector, ShimMLP
from model_wrapper import IMAGENET_MEAN as IM4D, IMAGENET_STD as IS4D, SIGLIP_MEAN, SIGLIP_STD
from transformers import AutoModelForCausalLM, AutoTokenizer
# model_wrapper's constants are presumably 4-D (e.g. (1, 3, 1, 1)); view them as
# (3, 1, 1) so they broadcast against a single CHW image in the dataset.
IMAGENET_MEAN = IM4D.view(3, 1, 1)
IMAGENET_STD = IS4D.view(3, 1, 1)
NUM_ACTION_TOKENS = 32  # OmniVLA uses 8 chunks × 4 DoF
NUM_VIS = 452  # 256 dino patches + 196 siglip patches


def to_siglip(pv):
    """Re-normalize ImageNet-normalized pixels to SigLIP normalization.

    Undoes ``(x - IMAGENET_MEAN) / IMAGENET_STD`` and applies
    ``(x - SIGLIP_MEAN) / SIGLIP_STD`` instead.

    Args:
        pv: pixel tensor normalized with IMAGENET_MEAN / IMAGENET_STD whose
            channel axis broadcasts against a (3, 1, 1) tensor.

    Returns:
        The re-normalized tensor (broadcast shape of ``pv`` and the constants),
        on ``pv``'s device and in ``pv``'s dtype.
    """
    dev = pv.device
    out = (pv * IMAGENET_STD.to(dev) + IMAGENET_MEAN.to(dev)
           - SIGLIP_MEAN.to(dev)) / SIGLIP_STD.to(dev)
    # torch type promotion: if the constants are float32 and pv is bfloat16, the
    # arithmetic above yields float32. Cast back so a bf16 input stays bf16 and
    # matches a bf16 featurizer (a no-op when the dtypes already agree).
    return out.to(pv.dtype)
# ─────────────────────────────────────────────────────────────
# Dataset — ADAPT THE IMAGE/INSTRUCTION LOGIC TO YOUR FORMAT
# ─────────────────────────────────────────────────────────────
class DistillDataset(Dataset):
    """Step-level distillation dataset: (image, instruction) -> teacher hidden states.

    Each episode_*.pt is expected to contain:
        episode_id: str
        num_steps: int
        hidden_states: Tensor[T, 32, teacher_dim]
        (optional) instructions: list[str] of length T
    Image paths are constructed as {data_dir}/{episode_id}/img/step_{t:04d}.png
    Override _load_image / _get_instruction for custom formats.

    Episode files are sorted by name; the first ``int(n * (1 - val_ratio))``
    files form the "train" split and any other ``split`` value gets the rest.
    """

    def __init__(self, cache_dir, data_dir, split="train", val_ratio=0.1):
        self.data_dir = Path(data_dir)
        cache_files = sorted(Path(cache_dir).glob("episode_*.pt"))
        split_idx = int(len(cache_files) * (1 - val_ratio))
        files = cache_files[:split_idx] if split == "train" else cache_files[split_idx:]
        # Flat index of (episode file, step) pairs.
        self.index = []
        for cf in files:
            # NOTE: loads the whole episode (hidden states included) just to read
            # num_steps; the loaded tensors are discarded right after.
            d = torch.load(cf, weights_only=True)
            self.index.extend((cf, t) for t in range(d["num_steps"]))
        # Episodes loaded by __getitem__ are kept here for the dataset's lifetime
        # (no eviction), keyed by the file path string.
        self._cache = {}
        self._instr_cache = {}  # unused by this class; kept for backward compatibility
        print(f" [{split}] {len(self.index)} steps from {len(files)} episodes", flush=True)

    def __len__(self):
        return len(self.index)

    def _load_episode(self, cf_path):
        """Return the (cached) contents of one episode_*.pt file."""
        key = str(cf_path)
        if key not in self._cache:
            self._cache[key] = torch.load(cf_path, weights_only=True)
        return self._cache[key]

    def _load_image(self, ep_id, t):
        """Load step ``t`` of episode ``ep_id`` as an ImageNet-normalized (3, 224, 224) float32 tensor."""
        from torchvision.transforms.functional import resize as tv_resize

        path = self.data_dir / ep_id / "img" / f"step_{t:04d}.png"
        # The context manager closes the file handle; convert() reads the pixels first.
        with Image.open(path) as im:
            rgb = im.convert("RGB")
        # Resize to exactly 224x224. An int size only matches the shorter edge, so
        # non-square frames would keep a different spatial size (changing the
        # vision patch count) and could not be torch.stack'ed in the collate fn.
        rgb = tv_resize(rgb, [224, 224])
        img = torch.from_numpy(np.asarray(rgb, dtype=np.float32) / 255.0).permute(2, 0, 1)
        return (img - IMAGENET_MEAN) / IMAGENET_STD

    def _get_instruction(self, ep, t):
        """Return the stripped instruction for step ``t``; "move forward" when none is stored."""
        if "instructions" not in ep:
            return "move forward"
        instr = ep["instructions"][t]
        if isinstance(instr, list):
            # Several phrasings per step: use the first one, if there is any.
            instr = instr[0] if instr else "move forward"
        return str(instr).strip()

    def __getitem__(self, idx):
        cf_path, t = self.index[idx]
        ep = self._load_episode(cf_path)
        return {
            "cur_img": self._load_image(ep["episode_id"], t),
            "hs_target": ep["hidden_states"][t].float(),
            "instruction": self._get_instruction(ep, t),
        }
def find_action_offset(tokenizer, action_token_ids):
    """Return the index of the first action token in a reference chat prompt.

    Renders a fixed dummy conversation (system prompt, user message "test",
    assistant reply made of all NUM_ACTION_TOKENS ``<ACTION_i>`` tokens) with
    the tokenizer's chat template, then finds the first token whose id is in
    ``action_token_ids``.

    The result is only valid for that exact dummy prompt: prompts with a
    different user message tokenize to a different length, so their action
    tokens sit at a different index.

    Args:
        tokenizer: tokenizer exposing ``apply_chat_template``, with the
            ``<ACTION_i>`` tokens already added.
        action_token_ids: ids of the action tokens.

    Returns:
        int index into the tokenized dummy prompt.

    Raises:
        ValueError: if no action token appears in the rendered prompt.
    """
    dummy = tokenizer.apply_chat_template(
        [{"role": "system", "content": "You are a helpful assistant."},
         {"role": "user", "content": "test"},
         {"role": "assistant", "content": " ".join([f"<ACTION_{i}>" for i in range(NUM_ACTION_TOKENS)])}],
        tokenize=True, add_generation_prompt=False, return_dict=True, return_tensors="pt",
    )
    ids = dummy["input_ids"].squeeze(0)
    # Membership test rather than an id range: does not assume the ids are contiguous.
    is_action = torch.isin(ids, torch.as_tensor(list(action_token_ids), dtype=ids.dtype))
    pos = torch.nonzero(is_action).flatten()
    if pos.numel() == 0:
        raise ValueError("No action tokens found in the chat-template output; "
                         "were the <ACTION_i> tokens added to the tokenizer?")
    return pos[0].item()
def _build_collate(tokenizer):
    """Return a DataLoader collate_fn that stacks images/targets and tokenizes the chat prompts."""
    action_text = " ".join([f"<ACTION_{i}>" for i in range(NUM_ACTION_TOKENS)])

    def collate(batch):
        chat = [[{"role": "system", "content": "You are a helpful assistant."},
                 {"role": "user", "content": f"What action should the robot take to {b['instruction'].lower()}?"},
                 {"role": "assistant", "content": action_text}]
                for b in batch]
        tok = tokenizer.apply_chat_template(chat, tokenize=True, add_generation_prompt=False,
                                            return_dict=True, return_tensors="pt", padding=True)
        return {"cur_img": torch.stack([b["cur_img"] for b in batch]),
                "input_ids": tok["input_ids"],
                "attention_mask": tok["attention_mask"],
                "hs_target": torch.stack([b["hs_target"] for b in batch])}

    return collate


def _student_action_hidden(llm, encode_image, batch, action_ids, device, dtype):
    """Run the frozen student on a batch and return its last hidden states at the action tokens.

    Vision tokens are spliced in right after the first text token, and the
    input embeddings at the action-token positions are zeroed before the
    forward pass.

    Returns:
        Tensor of shape (B, NUM_ACTION_TOKENS, hidden_size) in the model dtype.

    Raises:
        ValueError: if a sample does not contain exactly NUM_ACTION_TOKENS action tokens.
    """
    cur_img = batch["cur_img"].to(device, dtype=dtype)
    inp = batch["input_ids"].to(device)
    am = batch["attention_mask"].to(device)
    B = inp.shape[0]
    with torch.no_grad():
        vis = encode_image(cur_img)
        n_vis = vis.shape[1]
        embed = llm.get_input_embeddings()(inp)
        mm = torch.cat([embed[:, :1, :], vis, embed[:, 1:, :]], dim=1)
        mm_attn = torch.cat(
            [am[:, :1], torch.ones(B, n_vis, dtype=am.dtype, device=device), am[:, 1:]], dim=1)
        # Find the action tokens per sample from the actual ids. Their index depends
        # on the instruction length and padding, so one offset measured on a dummy
        # prompt does not fit every sample. The mask uses the same splice layout as mm.
        is_action = torch.isin(inp, action_ids)
        mask = torch.cat(
            [is_action[:, :1], torch.zeros(B, n_vis, dtype=torch.bool, device=device), is_action[:, 1:]],
            dim=1)
        counts = mask.sum(dim=1)
        if not torch.all(counts == NUM_ACTION_TOKENS):
            raise ValueError(f"Expected {NUM_ACTION_TOKENS} action tokens per sample, got {counts.tolist()}")
        mm = mm * ~mask.unsqueeze(-1)
        with torch.autocast(device_type=device.type, dtype=dtype):
            out = llm(inputs_embeds=mm, attention_mask=mm_attn, labels=None,
                      output_hidden_states=True, return_dict=True)
        hs_all = out.hidden_states[-1]
    # Boolean indexing walks rows in order, so each sample's tokens stay contiguous.
    return hs_all[mask].view(B, NUM_ACTION_TOKENS, -1)


def _shim_loss(shim, hs_act, hs_target):
    """Apply the shim in float32; return (MSE loss, shimmed hidden states)."""
    hs_shimmed = shim(hs_act.float())
    hs_target = hs_target.float()
    if hs_shimmed.shape != hs_target.shape:
        # F.mse_loss would silently broadcast mismatched shapes.
        raise ValueError(f"Shim output {tuple(hs_shimmed.shape)} does not match "
                         f"teacher target {tuple(hs_target.shape)}; check --teacher-dim")
    return F.mse_loss(hs_shimmed, hs_target), hs_shimmed


def _mean_cosine(a, b):
    """Mean cosine similarity between corresponding vectors (last axis) of a and b."""
    dim = a.shape[-1]
    return F.cosine_similarity(a.float().reshape(-1, dim), b.float().reshape(-1, dim), dim=-1).mean().item()


def _evaluate(shim, llm, encode_image, loader, action_ids, device, dtype):
    """Return (mean MSE, mean cosine) over ``loader`` averaged per batch; NaNs if it is empty."""
    shim.eval()
    v_loss, v_cos, nv = 0.0, 0.0, 0
    with torch.no_grad():
        for vb in loader:
            ha = _student_action_hidden(llm, encode_image, vb, action_ids, device, dtype)
            ht = vb["hs_target"].to(device)
            loss, hs = _shim_loss(shim, ha, ht)
            v_loss += loss.item()
            v_cos += _mean_cosine(hs, ht)
            nv += 1
    if nv == 0:
        return float("nan"), float("nan")
    return v_loss / nv, v_cos / nv


def _save_shim(shim, path):
    """Save the shim's state_dict in bfloat16, the dtype the shim weights were historically saved in."""
    torch.save({k: v.detach().to(torch.bfloat16) for k, v in shim.state_dict().items()}, path)


def _cycle(loader):
    """Yield batches forever, starting a fresh (reshuffled) pass when the loader is exhausted."""
    while True:
        yield from loader


def main():
    """Train the 896 -> teacher_dim hidden-state shim against cached teacher hidden states."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--cache-dir", type=str, required=True)
    parser.add_argument("--data-dir", type=str, required=True,
                        help="Dataset root with {episode_id}/img/step_*.png")
    parser.add_argument("--base-model", type=str, default="theguy21/openvla-micro",
                        help="Path to a local checkpoint file (loaded with torch.load)")
    parser.add_argument("--teacher-dim", type=int, default=4096)
    parser.add_argument("--max-steps", type=int, default=10000)
    parser.add_argument("--batch-size", type=int, default=8)
    parser.add_argument("--lr", type=float, default=5e-5)
    parser.add_argument("--grad-accum", type=int, default=4)
    parser.add_argument("--val-every", type=int, default=500)
    parser.add_argument("--save-every", type=int, default=5000)
    parser.add_argument("--resume", type=str, default=None)
    parser.add_argument("--run-name", type=str, default="shim_run")
    parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu")
    args = parser.parse_args()
    if args.grad_accum < 1:
        parser.error("--grad-accum must be >= 1")
    device = torch.device(args.device)
    dtype = torch.bfloat16
    print(f"Device: {device}")
    run_dir = Path(args.run_name)
    run_dir.mkdir(parents=True, exist_ok=True)

    # ── 1. Load base model ──
    print("\n[1] Loading base model...")
    base_path = Path(os.path.expanduser(args.base_model))
    if not base_path.is_file():
        raise FileNotFoundError(
            f"--base-model must be a local checkpoint file readable by torch.load, got {args.base_model!r}")
    # SECURITY: weights_only=False unpickles arbitrary objects; only load trusted checkpoints.
    ckpt = torch.load(base_path, map_location="cpu", weights_only=False)
    msd = ckpt["model"]
    ve = DinoSigLIPEncoder().eval()
    ve.load_state_dict(msd["vision_backbone"])
    ve.to(device, dtype=dtype)
    ve.requires_grad_(False)
    projector = CombinedProjector(ShimMLP(384), ShimMLP(768), nn.Linear(8704, 896), nn.Linear(896, 896))
    projector.load_state_dict(msd["projector"])
    projector.to(device, dtype=dtype).eval()
    projector.requires_grad_(False)
    llm = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B", torch_dtype=dtype)
    llm.load_state_dict({k.replace("llm.", "", 1): v for k, v in msd["llm_backbone"].items()})
    llm.to(device, dtype=dtype).eval()
    llm.requires_grad_(False)
    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B", use_fast=True)
    action_tokens = [f"<ACTION_{i}>" for i in range(NUM_ACTION_TOKENS)]
    tokenizer.add_tokens(action_tokens)
    action_token_ids = tokenizer.convert_tokens_to_ids(action_tokens)
    # Sanity check that the chat template keeps the action tokens; the training
    # loop locates them per sample from input_ids instead of using this offset.
    action_offset = find_action_offset(tokenizer, action_token_ids)
    print(f" Action tokens at position {action_offset} in the reference prompt")
    action_ids = torch.tensor(action_token_ids, device=device)

    # ── 2. Shim ──
    print("\n[2] Building shim...")
    shim = nn.Sequential(nn.Linear(896, 2048), nn.GELU(), nn.Linear(2048, args.teacher_dim))
    if args.resume:
        # load_state_dict copies (and casts) the saved tensors into the fp32 parameters.
        shim.load_state_dict(torch.load(args.resume, map_location="cpu", weights_only=True))
        print(f" Resumed from {args.resume}")
    # Keep the trainable weights in float32: AdamW steps of roughly lr (5e-5 by
    # default) are far below bfloat16's resolution for typical weight magnitudes
    # and would mostly round away. Checkpoints are still written in bf16.
    shim.to(device).train()

    # ── 3. Data ──
    print("\n[3] Loading data...")
    train_ds = DistillDataset(args.cache_dir, args.data_dir, split="train")
    val_ds = DistillDataset(args.cache_dir, args.data_dir, split="val")
    if len(train_ds) == 0:
        raise ValueError(f"No training steps found under {args.cache_dir} (expected episode_*.pt files)")
    collate = _build_collate(tokenizer)
    train_loader = DataLoader(train_ds, batch_size=args.batch_size, shuffle=True, collate_fn=collate, num_workers=0)
    val_loader = DataLoader(val_ds, batch_size=1, shuffle=False, collate_fn=collate, num_workers=0)

    # ── 4. Optimizer ──
    opt = torch.optim.AdamW(shim.parameters(), lr=args.lr, weight_decay=0.01)
    sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=args.max_steps)

    # ── 5. Training ──
    print("\n[4] Training...")
    dino, siglip = ve.dino_featurizer, ve.siglip_featurizer

    def encode_image(cur):
        """Encode images with DINO + SigLIP (first token of each stream dropped) and project them."""
        with torch.no_grad():
            df = dino(cur)
            if isinstance(df, (list, tuple)):
                df = df[0]
            df = df[:, 1:]
            sf = siglip(to_siglip(cur))
            if isinstance(sf, (list, tuple)):
                sf = sf[0]
            sf = sf[:, 1:]
            B, D = cur.shape[0], 1152

            def pad(f, ed):
                # Zero-pad the feature axis to width D so both streams can be concatenated.
                p = torch.zeros(B, f.shape[1], D, device=device, dtype=dtype)
                p[..., :ed] = f[..., :ed]
                return p

            return projector(torch.cat([pad(df, 384), pad(sf, 768)], dim=1))

    best_loss = float("inf")
    global_step = 0
    train_batches = _cycle(train_loader)
    pbar = tqdm(total=args.max_steps, desc="Train")
    while global_step < args.max_steps:
        shim.train()
        opt.zero_grad()
        accum_loss = 0.0
        for _ in range(args.grad_accum):
            batch = next(train_batches)
            hs_act = _student_action_hidden(llm, encode_image, batch, action_ids, device, dtype)
            hs_target = batch["hs_target"].to(device)
            loss, hs_shimmed = _shim_loss(shim, hs_act, hs_target)
            (loss / args.grad_accum).backward()
            accum_loss += loss.item()
        torch.nn.utils.clip_grad_norm_(shim.parameters(), 1.0)
        opt.step()
        sched.step()
        global_step += 1
        if global_step % 100 == 0:
            # Cosine is reported for the last micro-batch only.
            with torch.no_grad():
                cos = _mean_cosine(hs_shimmed, hs_target)
            pbar.set_postfix({"loss": f"{accum_loss/args.grad_accum:.5f}", "cos": f"{cos:.4f}"})
        pbar.update(1)

        # Validation
        if global_step % args.val_every == 0:
            v_loss, v_cos = _evaluate(shim, llm, encode_image, val_loader, action_ids, device, dtype)
            print(f"\n─── Val @ {global_step}: loss={v_loss:.5f} cos={v_cos:.4f} ───", flush=True)
            if v_loss < best_loss:
                best_loss = v_loss
                _save_shim(shim, run_dir / "shim_best.pt")
                print(f" → Saved best (loss={v_loss:.5f})")
        if global_step % args.save_every == 0:
            d = run_dir / f"step_{global_step}"
            d.mkdir(exist_ok=True)
            _save_shim(shim, d / "shim.pt")
    pbar.close()
    # Always keep the final weights, even if no validation/save step was reached.
    _save_shim(shim, run_dir / "shim_final.pt")
    print(f"\nDone! Best val loss: {best_loss:.5f}")


if __name__ == "__main__":
    main()