File size: 1,936 Bytes
1d2bcdf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
"""
Fast(er) model loader for OpenVLA-Micro on CPU.
Uses the cached distill checkpoint directly (~21s load).
"""
import os, time
from pathlib import Path

import torch


def _find_cached_distill(repo_id: str = "theguy21/openvla-micro",
                         filename: str = "openvla-micro-distill.pt") -> "str | None":
    """Return the path of a cached checkpoint in the HF hub cache, or ``None``.

    No network access is performed. The cache root is resolved like
    ``huggingface_hub`` resolves it: ``HF_HUB_CACHE``, then the legacy
    ``HUGGINGFACE_HUB_CACHE``, then ``$HF_HOME/hub`` where ``HF_HOME``
    defaults to ``$XDG_CACHE_HOME/huggingface`` (``~/.cache/huggingface``).

    When several snapshots of the repo are cached, the snapshot recorded in
    ``refs/main`` wins; otherwise the most recently modified file is used.
    Snapshot directories are named by commit hash, so sorting them by name
    would pick an arbitrary revision.

    Args:
        repo_id: Hub repository id (``owner/name``).
        filename: File name inside the snapshot directory.

    Returns:
        The checkpoint path as a string, or ``None`` if it is not cached.
    """
    cache_root = os.environ.get("HF_HUB_CACHE") or os.environ.get("HUGGINGFACE_HUB_CACHE")
    if not cache_root:
        xdg_cache = os.environ.get("XDG_CACHE_HOME") or os.path.join("~", ".cache")
        hf_home = os.environ.get("HF_HOME") or os.path.join(xdg_cache, "huggingface")
        cache_root = os.path.join(hf_home, "hub")
    repo_dir = Path(os.path.expanduser(cache_root)) / ("models--" + repo_id.replace("/", "--"))

    # Prefer the revision that the hub cache currently maps "main" to.
    try:
        main_rev = (repo_dir / "refs" / "main").read_text().strip()
    except OSError:
        main_rev = ""
    if main_rev:
        candidate = repo_dir / "snapshots" / main_rev / filename
        if candidate.is_file():
            return str(candidate)

    # Fallback: newest readable file across all cached snapshots.
    # is_file() follows symlinks, so dangling snapshot links are skipped.
    # Path.glob does not treat the base directory as a pattern.
    matches = [p for p in (repo_dir / "snapshots").glob("*/" + filename) if p.is_file()]
    if not matches:
        return None
    return str(max(matches, key=lambda p: (p.stat().st_mtime, str(p))))


def load_openvla_micro(device: str = "cpu"):
    """Assemble an OpenVLA-Micro model from the distill checkpoint.

    The checkpoint comes from the local HF hub cache when present and is
    downloaded from ``theguy21/openvla-micro`` otherwise. Its ``"model"``
    entry must hold ``vision_backbone``, ``projector`` and ``llm_backbone``
    state dicts. The last one overwrites the weights of a freshly loaded
    ``Qwen/Qwen2.5-0.5B``; its keys carry an ``llm.`` prefix, which is
    stripped. All weights are materialised on CPU in float32 and then moved
    to ``device``.

    Args:
        device: Target device passed to ``Module.to`` (default ``"cpu"``).

    Returns:
        The ``OpenVLAMicro`` model in eval mode, on ``device``.

    Raises:
        KeyError: If the checkpoint lacks one of the expected entries.
        RuntimeError: If a state dict does not match its module.
    """
    from modeling_openvla_micro import DinoSigLIPEncoder, CombinedProjector, ShimMLP, OpenVLAMicro
    from transformers import AutoModelForCausalLM, AutoTokenizer

    base_llm = "Qwen/Qwen2.5-0.5B"

    ckpt_path = _find_cached_distill()
    if ckpt_path is None:
        print("Distill checkpoint not cached. Downloading from HF...", flush=True)
        from huggingface_hub import hf_hub_download
        ckpt_path = hf_hub_download("theguy21/openvla-micro", "openvla-micro-distill.pt")

    t0 = time.time()
    print("Loading checkpoint...", flush=True)
    # SECURITY: weights_only=False unpickles arbitrary Python objects
    # (presumably required for norm_stats). Only load trusted checkpoints.
    ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=False)
    norm_stats = ckpt.get("norm_stats")
    msd = ckpt["model"]
    vision_sd = msd["vision_backbone"]
    projector_sd = msd["projector"]
    llm_backbone_sd = msd["llm_backbone"]
    # Drop the containers so each state dict can be freed as soon as it has
    # been copied into its module, keeping peak memory down while Qwen loads.
    del ckpt, msd

    print("Building vision encoder...", flush=True)
    ve = DinoSigLIPEncoder().eval()
    ve.load_state_dict(vision_sd)
    # These hard-coded sizes must match the checkpoint's projector tensors
    # (load_state_dict is strict by default).
    proj = CombinedProjector(ShimMLP(384), ShimMLP(768), torch.nn.Linear(8704, 896), torch.nn.Linear(896, 896))
    proj.load_state_dict(projector_sd)
    del vision_sd, projector_sd

    print("Loading Qwen...", flush=True)
    llm = AutoModelForCausalLM.from_pretrained(base_llm, dtype=torch.float32)
    # Strip "llm." only when it is a real prefix. str.replace(..., 1) would
    # also rewrite a key that merely contains "llm." somewhere inside it.
    prefix = "llm."
    llm_sd = {(k[len(prefix):] if k.startswith(prefix) else k): v for k, v in llm_backbone_sd.items()}
    llm.load_state_dict(llm_sd, strict=True)
    del llm_sd, llm_backbone_sd
    tok = AutoTokenizer.from_pretrained(base_llm)

    model = OpenVLAMicro(ve, proj, llm, tok, norm_stats=norm_stats).eval()
    # Weights were materialised on CPU above; honour the requested device.
    model = model.to(device)
    print(f"Loaded in {time.time()-t0:.0f}s", flush=True)
    return model