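"""Utilities for fetching and loading Moondream LoRA variant weights.

Resolves a variant ID to a cached ``final.pt`` checkpoint (using a local
path directly, or downloading from the Moondream API and caching under the
Hugging Face hub cache) and loads it as a nested state dict.
"""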
import functools
import os
import shutil
import sys
from pathlib import Path
from typing import Optional
from urllib.error import HTTPError, URLError
from urllib.request import Request, urlopen

import torch


def variant_cache_dir() -> Path:
    """Return the directory used to cache downloaded variant weights."""
    hf_hub_cache = os.environ.get("HF_HUB_CACHE")
    if hf_hub_cache is not None:
        return Path(hf_hub_cache) / "md_variants"
    hf_home = os.environ.get("HF_HOME")
    if hf_home is not None:
        return Path(hf_home) / "hub" / "md_variants"
    return Path("~/.cache/huggingface/hub").expanduser() / "md_variants"


def cached_variant_path(variant_id: str) -> Optional[Path]:
    """Resolve a variant ID to a local checkpoint file.

    Local paths are preferred directly; otherwise the variant is downloaded
    from the Moondream endpoint and cached. Returns None if the variant
    cannot be fetched.
    """
    # If variant_id is a local path or a file, prefer it directly. Checking
    # this before touching the cache avoids creating stray cache directories
    # named after local paths.
    try:
        p = Path(variant_id).expanduser()
        if p.exists():
            # If a directory was passed, look for final.pt inside it.
            if p.is_dir():
                candidate = p / "final.pt"
                if candidate.exists():
                    return candidate
            else:
                return p
    except Exception:
        # Ignore and try a remote fetch.
        pass
    cache_dir = variant_cache_dir() / variant_id
    os.makedirs(cache_dir, exist_ok=True)
    dest = cache_dir / "final.pt"
    if dest.exists():
        return dest
    md_endpoint = os.getenv("MOONDREAM_ENDPOINT", "https://api.moondream.ai")
    headers = {"User-Agent": "moondream-torch"}
    api_key = os.getenv("MOONDREAM_API_KEY")
    if api_key is not None:
        headers["X-Moondream-Auth"] = api_key
    req = Request(f"{md_endpoint}/v1/variants/{variant_id}/download", headers=headers)
    # Download to a temp file and rename, so an interrupted transfer never
    # leaves a partial final.pt that would be treated as a valid cache hit.
    tmp = dest.with_suffix(".pt.tmp")
    try:
        with urlopen(req) as r, open(tmp, "wb") as f:
            shutil.copyfileobj(r, f)
        os.replace(tmp, dest)
        return dest
    except HTTPError as e:
        print(
            f"[moondream.lora] Variant '{variant_id}' not found on server (HTTP {e.code}). Continuing without LoRA.",
            file=sys.stderr,
        )
        return None
    except URLError as e:
        print(
            f"[moondream.lora] Could not reach endpoint for variant '{variant_id}': {e}. Continuing without LoRA.",
            file=sys.stderr,
        )
        return None
    except Exception as e:
        print(
            f"[moondream.lora] Unexpected error downloading variant '{variant_id}': {e}. Continuing without LoRA.",
            file=sys.stderr,
        )
        return None


def nest(flat):
    """Convert a flat dict with dotted keys into a nested dict tree.

    >>> nest({"a.b": 1, "a.c": 2})
    {'a': {'b': 1, 'c': 2}}
    """
    tree = {}
    for k, v in flat.items():
        parts = k.split(".")
        d = tree
        for p in parts[:-1]:
            d = d.setdefault(p, {})
        d[parts[-1]] = v
    return tree


@functools.lru_cache(maxsize=5)
def variant_state_dict(variant_id: Optional[str] = None, device: str = "cpu"):
    """Load a variant's LoRA weights as a nested state dict.

    Returns None when no variant is given or the weights cannot be fetched.
    Results are memoized, so repeated lookups don't reload from disk.
    """
    if variant_id is None:
        return None
    path = cached_variant_path(variant_id)
    if path is None:
        return None
    state_dict = torch.load(path, map_location=device, weights_only=True)
    # TODO: Move these into the training code that saves checkpoints...
    rename_rules = [
        ("text_model.transformer.h", "text.blocks"),
        (".mixer", ".attn"),
        (".out_proj", ".proj"),
        (".Wqkv", ".qkv"),
        (".parametrizations.weight.0", ""),
    ]
    new_state_dict = {}
    for key, tensor in state_dict.items():
        new_key = key
        for old, new in rename_rules:
            if old in new_key:
                new_key = new_key.replace(old, new)
        new_state_dict[new_key] = tensor
    return nest(new_state_dict)
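

if __name__ == "__main__":
    # Minimal usage sketch: resolve and load a variant from the command line.
    # "my-variant" is a hypothetical placeholder ID, not a real variant.
    vid = sys.argv[1] if len(sys.argv) > 1 else "my-variant"
    weights = variant_state_dict(vid)
    if weights is None:
        print(f"No LoRA weights loaded for '{vid}'.")
    else:
        print(f"Loaded LoRA tree with top-level keys: {sorted(weights)}")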