openvla-micro / fast_load.py
theguy21's picture
Add fast model loader (cached distill, ~21s load)
1d2bcdf verified
Raw
History Blame Contribute Delete
1.94 kB
"""
Fast(er) model loader for OpenVLA-Micro on CPU.
Uses the cached distill checkpoint directly (~21s load).
"""
import os, time
from pathlib import Path
import torch
def _find_cached_distill():
"""Find the cached distill checkpoint in HF hub cache."""
import glob
pattern = os.path.expanduser("~/.cache/huggingface/hub/models--theguy21--openvla-micro/snapshots/*/openvla-micro-distill.pt")
matches = sorted(glob.glob(pattern))
return matches[-1] if matches else None
def load_openvla_micro(device: str = "cpu"):
    """Assemble an OpenVLA-Micro model from the distill checkpoint.

    The checkpoint is read from the local HF hub cache when present (no
    network round-trip); otherwise it is downloaded with ``hf_hub_download``.
    Its ``"model"`` entry supplies state dicts for the vision backbone
    (``"vision_backbone"``), the projector (``"projector"``) and the
    Qwen2.5-0.5B language model (``"llm_backbone"``, keys prefixed ``"llm."``).
    ``"norm_stats"``, if present, is forwarded to ``OpenVLAMicro``.

    Args:
        device: Device the assembled model is moved to (e.g. ``"cpu"``,
            ``"cuda"``). All weights are first loaded on CPU.

    Returns:
        The ``OpenVLAMicro`` model in eval mode, placed on *device*.
    """
    from modeling_openvla_micro import DinoSigLIPEncoder, CombinedProjector, ShimMLP, OpenVLAMicro
    from transformers import AutoModelForCausalLM, AutoTokenizer

    ckpt_path = _find_cached_distill()
    if ckpt_path is None:
        print("Distill checkpoint not cached. Downloading from HF...", flush=True)
        from huggingface_hub import hf_hub_download
        ckpt_path = hf_hub_download("theguy21/openvla-micro", "openvla-micro-distill.pt")

    t0 = time.perf_counter()
    print("Loading checkpoint...", flush=True)
    # SECURITY NOTE(review): weights_only=False unpickles arbitrary Python
    # objects, which can execute code embedded in the file. Kept because the
    # checkpoint carries non-tensor data (norm_stats); only load checkpoints
    # from a trusted repo/revision.
    ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=False)
    msd = ckpt["model"]

    print("Building vision encoder...", flush=True)
    ve = DinoSigLIPEncoder().eval()
    ve.load_state_dict(msd["vision_backbone"])
    # Layer sizes are hard-wired to the checkpoint's projector shapes; the
    # (strict, by default) load_state_dict below fails loudly on a mismatch.
    proj = CombinedProjector(ShimMLP(384), ShimMLP(768), torch.nn.Linear(8704, 896), torch.nn.Linear(896, 896))
    proj.load_state_dict(msd["projector"])

    print("Loading Qwen...", flush=True)
    llm = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B", dtype=torch.float32)
    # Strip only a *leading* "llm." so keys map onto the bare HF model; the
    # previous str.replace would also have removed a "llm." found mid-key.
    prefix = "llm."
    llm_sd = {
        (k[len(prefix):] if k.startswith(prefix) else k): v
        for k, v in msd["llm_backbone"].items()
    }
    llm.load_state_dict(llm_sd, strict=True)
    tok = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B")

    model = OpenVLAMicro(ve, proj, llm, tok, norm_stats=ckpt.get("norm_stats")).eval()
    # `device` used to be accepted but ignored, silently returning a CPU model.
    # Assumes OpenVLAMicro is a torch.nn.Module (it exposes .eval()); .to("cpu")
    # is a no-op for the default case. NOTE(review): norm_stats are passed as-is
    # and are not moved here — confirm OpenVLAMicro handles their device.
    model = model.to(device)
    print(f"Loaded in {time.perf_counter() - t0:.0f}s", flush=True)
    return model