File size: 3,453 Bytes
6a47858
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
import functools
import os
import shutil
import sys
import torch

from pathlib import Path
from typing import Optional
from urllib.request import Request, urlopen
from urllib.error import HTTPError, URLError


def variant_cache_dir():
    """Return the directory used to cache downloaded model variants.

    Resolution order mirrors the Hugging Face hub conventions:
    ``HF_HUB_CACHE`` wins, then ``HF_HOME`` (with a ``hub`` subdirectory),
    and finally the default ``~/.cache/huggingface/hub`` location.
    """
    hub_cache = os.environ.get("HF_HUB_CACHE")
    hf_home = os.environ.get("HF_HOME")

    if hub_cache is not None:
        base = Path(hub_cache)
    elif hf_home is not None:
        base = Path(hf_home) / "hub"
    else:
        base = Path("~/.cache/huggingface/hub").expanduser()

    return base / "md_variants"


def cached_variant_path(variant_id: str):
    """Resolve *variant_id* to a local weights file, downloading if needed.

    Resolution order:
      1. ``variant_id`` interpreted as an existing local file, or a local
         directory containing ``final.pt``.
      2. A previously downloaded copy in the variant cache.
      3. Download from the Moondream endpoint into the cache.

    Returns the path to the weights file, or ``None`` when the variant
    cannot be fetched (the caller continues without LoRA).
    """
    # Prefer an explicit local path first. Checking this before touching the
    # cache also avoids a FileExistsError the old code hit: for an absolute
    # variant_id, `variant_cache_dir() / variant_id` collapses to the file
    # path itself (joining an absolute path replaces the prefix), and
    # os.makedirs() then failed on the existing file.
    try:
        p = Path(variant_id).expanduser()
        if p.exists():
            # If a directory was passed, look for final.pt inside it.
            if p.is_dir():
                candidate = p / "final.pt"
                if candidate.exists():
                    return candidate
            else:
                return p
    except Exception:
        # ignore and try the cache / remote fetch
        pass

    cache_dir = variant_cache_dir() / variant_id
    dest = cache_dir / "final.pt"
    if dest.exists():
        return dest

    md_endpoint = os.getenv("MOONDREAM_ENDPOINT", "https://api.moondream.ai")

    headers = {"User-Agent": "moondream-torch"}
    api_key = os.getenv("MOONDREAM_API_KEY")
    if api_key is not None:
        headers["X-Moondream-Auth"] = api_key

    req = Request(f"{md_endpoint}/v1/variants/{variant_id}/download", headers=headers)
    # Only create the cache directory on the download path; download to a
    # temp file and rename atomically so a failed transfer never leaves a
    # truncated final.pt that later calls would treat as a cache hit.
    os.makedirs(cache_dir, exist_ok=True)
    tmp = cache_dir / "final.pt.partial"
    try:
        with urlopen(req) as r, open(tmp, "wb") as f:
            shutil.copyfileobj(r, f)
        os.replace(tmp, dest)
        return dest
    except HTTPError as e:
        print(
            f"[moondream.lora] Variant '{variant_id}' not found on server (HTTP {e.code}). Continue without LoRA.",
            file=sys.stderr,
        )
        return None
    except URLError as e:
        print(
            f"[moondream.lora] Could not reach endpoint for variant '{variant_id}': {e}. Continue without LoRA.",
            file=sys.stderr,
        )
        return None
    except Exception as e:
        print(
            f"[moondream.lora] Unexpected error downloading variant '{variant_id}': {e}. Continue without LoRA.",
            file=sys.stderr,
        )
        return None
    finally:
        # Remove any partial download; after a successful os.replace the
        # temp file no longer exists, so this is a no-op.
        try:
            tmp.unlink()
        except OSError:
            pass


def nest(flat):
    """Turn a flat dot-separated mapping into a nested dict tree.

    ``{"a.b.c": 1}`` becomes ``{"a": {"b": {"c": 1}}}``. Leaf values are
    stored as-is; intermediate levels are plain dicts created on demand.
    """
    nested = {}
    for dotted_key, value in flat.items():
        *ancestors, leaf = dotted_key.split(".")
        node = nested
        for name in ancestors:
            node = node.setdefault(name, {})
        node[leaf] = value
    return nested


@functools.lru_cache(maxsize=5)
def variant_state_dict(variant_id: Optional[str] = None, device: str = "cpu"):
    """Load a variant checkpoint as a nested state-dict tree.

    Returns ``None`` when no variant is requested or the weights cannot be
    fetched. Results are memoized per (variant_id, device); note the cached
    tree is mutable, so callers should treat it as read-only.
    """
    if variant_id is None:
        return None

    checkpoint = cached_variant_path(variant_id)
    if checkpoint is None:
        return None

    raw = torch.load(checkpoint, map_location=device, weights_only=True)

    # TODO: Move these into the training code that saves checkpoints...
    renames = [
        ("text_model.transformer.h", "text.blocks"),
        (".mixer", ".attn"),
        (".out_proj", ".proj"),
        (".Wqkv", ".qkv"),
        (".parametrizations.weight.0", ""),
    ]

    def canonical(key):
        # str.replace is a no-op when the substring is absent, so no
        # membership check is needed before applying each rule.
        for old, new in renames:
            key = key.replace(old, new)
        return key

    return nest({canonical(k): v for k, v in raw.items()})