# moondream2-e621 / lora.py
# Bocchi
# upload models
# 6a47858
import functools
import os
import shutil
import sys
import torch
from pathlib import Path
from typing import Optional
from urllib.request import Request, urlopen
from urllib.error import HTTPError, URLError
def variant_cache_dir():
    """Return the directory used to cache downloaded model variants.

    Resolution mirrors the Hugging Face hub conventions, in order:
    ``$HF_HUB_CACHE``, then ``$HF_HOME/hub``, then the default
    ``~/.cache/huggingface/hub`` — each with an ``md_variants`` subdir.
    """
    for env_var, subdirs in (("HF_HUB_CACHE", ()), ("HF_HOME", ("hub",))):
        root = os.environ.get(env_var)
        if root is not None:
            return Path(root).joinpath(*subdirs, "md_variants")
    return Path("~/.cache/huggingface/hub").expanduser() / "md_variants"
def cached_variant_path(variant_id: str):
    """Resolve a variant id to a local ``final.pt`` checkpoint path.

    Lookup order:
      1. A previously downloaded copy in the local variant cache.
      2. ``variant_id`` interpreted as a local file, or a directory
         containing ``final.pt``.
      3. A download from the Moondream API endpoint (``MOONDREAM_ENDPOINT``,
         authenticated with ``MOONDREAM_API_KEY`` when set).

    Returns the checkpoint path, or ``None`` when the variant cannot be
    found or downloaded (callers then continue without LoRA).
    """
    cache_dir = variant_cache_dir() / variant_id
    os.makedirs(cache_dir, exist_ok=True)
    dest = cache_dir / "final.pt"
    if dest.exists():
        return dest

    # If variant_id is a local path or a file, prefer it directly.
    try:
        p = Path(variant_id).expanduser()
        if p.exists():
            # If a directory was passed, look for final.pt inside it.
            if p.is_dir():
                candidate = p / "final.pt"
                if candidate.exists():
                    return candidate
            else:
                return p
    except Exception:
        # Best-effort local resolution only; ignore and try remote fetch.
        pass

    md_endpoint = os.getenv("MOONDREAM_ENDPOINT", "https://api.moondream.ai")
    headers = {"User-Agent": "moondream-torch"}
    api_key = os.getenv("MOONDREAM_API_KEY")
    if api_key is not None:
        headers["X-Moondream-Auth"] = api_key
    req = Request(f"{md_endpoint}/v1/variants/{variant_id}/download", headers=headers)
    # Download to a temporary file and atomically rename on success, so an
    # interrupted download never leaves a truncated final.pt behind that a
    # later call would treat as a cache hit.
    tmp = dest.with_suffix(".pt.partial")
    try:
        with urlopen(req) as r, open(tmp, "wb") as f:
            shutil.copyfileobj(r, f)
        os.replace(tmp, dest)
        return dest
    except HTTPError as e:
        print(
            f"[moondream.lora] Variant '{variant_id}' not found on server (HTTP {e.code}). Continue without LoRA.",
            file=sys.stderr,
        )
        return None
    except URLError as e:
        print(
            f"[moondream.lora] Could not reach endpoint for variant '{variant_id}': {e}. Continue without LoRA.",
            file=sys.stderr,
        )
        return None
    except Exception as e:
        print(
            f"[moondream.lora] Unexpected error downloading variant '{variant_id}': {e}. Continue without LoRA.",
            file=sys.stderr,
        )
        return None
    finally:
        # Best-effort cleanup of any partial download left by a failure.
        try:
            if tmp.exists():
                tmp.unlink()
        except OSError:
            pass
def nest(flat):
    """Expand a flat mapping with dotted keys into a nested dict tree.

    ``{"a.b.c": 1}`` becomes ``{"a": {"b": {"c": 1}}}``; intermediate
    dicts are created on demand and shared across keys with a common prefix.
    """
    root = {}
    for dotted, value in flat.items():
        *ancestors, leaf = dotted.split(".")
        node = root
        for name in ancestors:
            node = node.setdefault(name, {})
        node[leaf] = value
    return root
@functools.lru_cache(maxsize=5)
def variant_state_dict(variant_id: Optional[str] = None, device: str = "cpu"):
    """Load, rename, and nest the state dict for a LoRA variant.

    Returns ``None`` when no variant is requested or it cannot be fetched;
    otherwise a nested dict (via ``nest``) of tensors on ``device``.
    Results are memoized per (variant_id, device).
    """
    if variant_id is None:
        return None
    path = cached_variant_path(variant_id)
    if path is None:
        return None
    raw = torch.load(path, map_location=device, weights_only=True)

    # TODO: Move these into the training code that saves checkpoints...
    rename_rules = [
        ("text_model.transformer.h", "text.blocks"),
        (".mixer", ".attn"),
        (".out_proj", ".proj"),
        (".Wqkv", ".qkv"),
        (".parametrizations.weight.0", ""),
    ]

    def canonical(name):
        # str.replace is a no-op when the pattern is absent, so the rules
        # can be applied unconditionally in order.
        for old, new in rename_rules:
            name = name.replace(old, new)
        return name

    return nest({canonical(key): tensor for key, tensor in raw.items()})