""" Fast(er) model loader for OpenVLA-Micro on CPU. Uses the cached distill checkpoint directly (~21s load). """ import os, time from pathlib import Path import torch def _find_cached_distill(): """Find the cached distill checkpoint in HF hub cache.""" import glob pattern = os.path.expanduser("~/.cache/huggingface/hub/models--theguy21--openvla-micro/snapshots/*/openvla-micro-distill.pt") matches = sorted(glob.glob(pattern)) return matches[-1] if matches else None def load_openvla_micro(device: str = "cpu"): from modeling_openvla_micro import DinoSigLIPEncoder, CombinedProjector, ShimMLP, OpenVLAMicro from transformers import AutoModelForCausalLM, AutoTokenizer ckpt_path = _find_cached_distill() if ckpt_path is None: print("Distill checkpoint not cached. Downloading from HF...", flush=True) from huggingface_hub import hf_hub_download ckpt_path = hf_hub_download("theguy21/openvla-micro", "openvla-micro-distill.pt") t0 = time.time() print(f"Loading checkpoint...", flush=True) ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=False) msd = ckpt["model"] print(f"Building vision encoder...", flush=True) ve = DinoSigLIPEncoder().eval() ve.load_state_dict(msd["vision_backbone"]) proj = CombinedProjector(ShimMLP(384), ShimMLP(768), torch.nn.Linear(8704, 896), torch.nn.Linear(896, 896)) proj.load_state_dict(msd["projector"]) print(f"Loading Qwen...", flush=True) llm = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B", dtype=torch.float32) llm_sd = {k.replace("llm.", "", 1): v for k, v in msd["llm_backbone"].items()} llm.load_state_dict(llm_sd, strict=True) tok = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B") model = OpenVLAMicro(ve, proj, llm, tok, norm_stats=ckpt.get("norm_stats")).eval() print(f"Loaded in {time.time()-t0:.0f}s", flush=True) return model