import functools
import os
import shutil
import sys
import torch
from pathlib import Path
from typing import Optional
from urllib.request import Request, urlopen
from urllib.error import HTTPError, URLError


def variant_cache_dir():
    """Return the directory used to cache downloaded LoRA variants.

    Resolution order mirrors the Hugging Face hub conventions:
    HF_HUB_CACHE, then HF_HOME/hub, then ~/.cache/huggingface/hub.
    """
    hf_hub_cache = os.environ.get("HF_HUB_CACHE")
    if hf_hub_cache is not None:
        return Path(hf_hub_cache) / "md_variants"
    hf_home = os.environ.get("HF_HOME")
    if hf_home is not None:
        return Path(hf_home) / "hub" / "md_variants"
    return Path("~/.cache/huggingface/hub").expanduser() / "md_variants"


def cached_variant_path(variant_id: str):
    """Resolve ``variant_id`` to a local checkpoint file, downloading if needed.

    Lookup order:
      1. A previously cached ``final.pt`` under the variant cache dir.
      2. ``variant_id`` interpreted as a local file or directory
         (a directory must contain ``final.pt``).
      3. Download from the Moondream endpoint (MOONDREAM_ENDPOINT,
         authenticated via MOONDREAM_API_KEY when set).

    Returns the Path to the checkpoint, or None when the variant cannot
    be obtained (errors are reported on stderr, never raised).
    """
    cache_dir = variant_cache_dir() / variant_id
    os.makedirs(cache_dir, exist_ok=True)
    dest = cache_dir / "final.pt"
    if dest.exists():
        return dest

    # If variant_id is a local path or a file, prefer it directly.
    try:
        p = Path(variant_id).expanduser()
        if p.exists():
            # If a directory was passed, look for final.pt inside it.
            if p.is_dir():
                candidate = p / "final.pt"
                if candidate.exists():
                    return candidate
            else:
                return p
    except OSError:
        # Best-effort probe only; fall through to the remote fetch.
        pass

    md_endpoint = os.getenv("MOONDREAM_ENDPOINT", "https://api.moondream.ai")
    headers = {"User-Agent": "moondream-torch"}
    api_key = os.getenv("MOONDREAM_API_KEY")
    if api_key is not None:
        headers["X-Moondream-Auth"] = api_key
    req = Request(f"{md_endpoint}/v1/variants/{variant_id}/download", headers=headers)

    # Download to a temp file and atomically rename into place so an
    # interrupted transfer never leaves a truncated final.pt that a later
    # call would mistake for a valid cache hit.
    tmp = cache_dir / "final.pt.tmp"
    try:
        with urlopen(req) as r, open(tmp, "wb") as f:
            shutil.copyfileobj(r, f)
        os.replace(tmp, dest)
        return dest
    except HTTPError as e:
        print(
            f"[moondream.lora] Variant '{variant_id}' not found on server (HTTP {e.code}). Continue without LoRA.",
            file=sys.stderr,
        )
        return None
    except URLError as e:
        print(
            f"[moondream.lora] Could not reach endpoint for variant '{variant_id}': {e}. Continue without LoRA.",
            file=sys.stderr,
        )
        return None
    except Exception as e:
        print(
            f"[moondream.lora] Unexpected error downloading variant '{variant_id}': {e}. Continue without LoRA.",
            file=sys.stderr,
        )
        return None
    finally:
        # Remove any partial temp file left behind by a failed download.
        if tmp.exists():
            try:
                tmp.unlink()
            except OSError:
                pass


def nest(flat):
    """Expand a flat dict with dotted keys into a nested dict tree.

    e.g. {"a.b.c": 1} -> {"a": {"b": {"c": 1}}}
    """
    tree = {}
    for k, v in flat.items():
        parts = k.split(".")
        d = tree
        for p in parts[:-1]:
            d = d.setdefault(p, {})
        d[parts[-1]] = v
    return tree


@functools.lru_cache(maxsize=5)
def variant_state_dict(variant_id: Optional[str] = None, device: str = "cpu"):
    """Load, rename, and nest the state dict for a LoRA variant.

    Returns None when no variant is requested or it cannot be fetched.
    Results are memoized (up to 5 variants) via lru_cache.
    """
    if variant_id is None:
        return None
    path = cached_variant_path(variant_id)
    if path is None:
        return None
    state_dict = torch.load(path, map_location=device, weights_only=True)

    # TODO: Move these into the training code that saves checkpoints...
    # Maps checkpoint-time module names onto the inference model's names.
    rename_rules = [
        ("text_model.transformer.h", "text.blocks"),
        (".mixer", ".attn"),
        (".out_proj", ".proj"),
        (".Wqkv", ".qkv"),
        (".parametrizations.weight.0", ""),
    ]

    new_state_dict = {}
    for key, tensor in state_dict.items():
        new_key = key
        for old, new in rename_rules:
            if old in new_key:
                new_key = new_key.replace(old, new)
        new_state_dict[new_key] = tensor

    return nest(new_state_dict)