|
|
import functools
|
|
|
import os
|
|
|
import shutil
|
|
|
import torch
|
|
|
|
|
|
from pathlib import Path
|
|
|
from urllib.request import Request, urlopen
|
|
|
from typing import Optional
|
|
|
|
|
|
|
|
|
def variant_cache_dir():
|
|
|
hf_hub_cache = os.environ.get("HF_HUB_CACHE")
|
|
|
if hf_hub_cache is not None:
|
|
|
return Path(hf_hub_cache) / "md_variants"
|
|
|
|
|
|
hf_home = os.environ.get("HF_HOME")
|
|
|
if hf_home is not None:
|
|
|
return Path(hf_home) / "hub" / "md_variants"
|
|
|
|
|
|
return Path("~/.cache/huggingface/hub").expanduser() / "md_variants"
|
|
|
|
|
|
|
|
|
def cached_variant_path(variant_id: str):
|
|
|
variant, *rest = variant_id.split("/", 1)
|
|
|
step = rest[0] if rest else "final"
|
|
|
|
|
|
cache_dir = variant_cache_dir() / variant
|
|
|
os.makedirs(cache_dir, exist_ok=True)
|
|
|
dest = cache_dir / f"{step}.pt"
|
|
|
if dest.exists():
|
|
|
return dest
|
|
|
|
|
|
md_endpoint = os.getenv("MOONDREAM_ENDPOINT", "https://api.moondream.ai")
|
|
|
|
|
|
headers = {"User-Agent": "moondream-torch"}
|
|
|
api_key = os.getenv("MOONDREAM_API_KEY")
|
|
|
if api_key is not None:
|
|
|
headers["X-Moondream-Auth"] = api_key
|
|
|
|
|
|
req = Request(f"{md_endpoint}/v1/variants/{variant_id}/download", headers=headers)
|
|
|
with urlopen(req) as r, open(dest, "wb") as f:
|
|
|
shutil.copyfileobj(r, f)
|
|
|
return dest
|
|
|
|
|
|
|
|
|
def nest(flat):
|
|
|
tree = {}
|
|
|
for k, v in flat.items():
|
|
|
parts = k.split(".")
|
|
|
d = tree
|
|
|
for p in parts[:-1]:
|
|
|
d = d.setdefault(p, {})
|
|
|
d[parts[-1]] = v
|
|
|
return tree
|
|
|
|
|
|
|
|
|
@functools.lru_cache(maxsize=5)
|
|
|
def variant_state_dict(variant_id: Optional[str] = None, device: str = "cpu"):
|
|
|
if variant_id is None:
|
|
|
return None
|
|
|
|
|
|
state_dict = torch.load(
|
|
|
cached_variant_path(variant_id), map_location=device, weights_only=True
|
|
|
)
|
|
|
|
|
|
|
|
|
rename_rules = [
|
|
|
("text_model.transformer.h", "text.blocks"),
|
|
|
(".mixer", ".attn"),
|
|
|
(".out_proj", ".proj"),
|
|
|
(".Wqkv", ".qkv"),
|
|
|
(".parametrizations.weight.0", ""),
|
|
|
]
|
|
|
new_state_dict = {}
|
|
|
for key, tensor in state_dict.items():
|
|
|
new_key = key
|
|
|
for old, new in rename_rules:
|
|
|
if old in new_key:
|
|
|
new_key = new_key.replace(old, new)
|
|
|
new_state_dict[new_key] = tensor
|
|
|
|
|
|
return nest(new_state_dict)
|
|
|
|