theguy21 commited on
Commit
1d2bcdf
·
verified ·
1 Parent(s): 5da54d3

Add fast model loader (cached distill, ~21s load)

Browse files
Files changed (1) hide show
  1. fast_load.py +48 -0
fast_load.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Fast(er) model loader for OpenVLA-Micro on CPU.
3
+ Uses the cached distill checkpoint directly (~21s load).
4
+ """
5
+ import os, time
6
+ from pathlib import Path
7
+
8
+ import torch
9
+
10
+
11
+ def _find_cached_distill():
12
+ """Find the cached distill checkpoint in HF hub cache."""
13
+ import glob
14
+ pattern = os.path.expanduser("~/.cache/huggingface/hub/models--theguy21--openvla-micro/snapshots/*/openvla-micro-distill.pt")
15
+ matches = sorted(glob.glob(pattern))
16
+ return matches[-1] if matches else None
17
+
18
+
19
+ def load_openvla_micro(device: str = "cpu"):
20
+ from modeling_openvla_micro import DinoSigLIPEncoder, CombinedProjector, ShimMLP, OpenVLAMicro
21
+ from transformers import AutoModelForCausalLM, AutoTokenizer
22
+
23
+ ckpt_path = _find_cached_distill()
24
+ if ckpt_path is None:
25
+ print("Distill checkpoint not cached. Downloading from HF...", flush=True)
26
+ from huggingface_hub import hf_hub_download
27
+ ckpt_path = hf_hub_download("theguy21/openvla-micro", "openvla-micro-distill.pt")
28
+
29
+ t0 = time.time()
30
+ print(f"Loading checkpoint...", flush=True)
31
+ ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=False)
32
+ msd = ckpt["model"]
33
+
34
+ print(f"Building vision encoder...", flush=True)
35
+ ve = DinoSigLIPEncoder().eval()
36
+ ve.load_state_dict(msd["vision_backbone"])
37
+ proj = CombinedProjector(ShimMLP(384), ShimMLP(768), torch.nn.Linear(8704, 896), torch.nn.Linear(896, 896))
38
+ proj.load_state_dict(msd["projector"])
39
+
40
+ print(f"Loading Qwen...", flush=True)
41
+ llm = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B", dtype=torch.float32)
42
+ llm_sd = {k.replace("llm.", "", 1): v for k, v in msd["llm_backbone"].items()}
43
+ llm.load_state_dict(llm_sd, strict=True)
44
+ tok = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B")
45
+
46
+ model = OpenVLAMicro(ve, proj, llm, tok, norm_stats=ckpt.get("norm_stats")).eval()
47
+ print(f"Loaded in {time.time()-t0:.0f}s", flush=True)
48
+ return model