"""
runtime.py - manages the local LLM-serving subprocess + HuggingFace LoRA cache.

We use **LLaMA-Factory's `api` command** as the serving engine, not vLLM.
LLaMA-Factory wraps transformers + bitsandbytes and streams shards from disk
during quantization, so it can fit Qwen3-32B BF16 -> 4-bit on a single 24 GB
GPU. vLLM's bitsandbytes path tries to load the full 64 GB BF16 first and
OOMs on consumer hardware.

Contracts:
  * `ensure_lora_cached()` downloads ONLY the LoRA adapter (~1 GB).
    The 64 GB BF16 base model is the user's responsibility.
  * `resolve_base_model()` returns abs path or hard-fails with help text.
  * `start_server(base_path, lora_path)` spawns `llamafactory-cli api`.
  * `wait_for_server()` polls /v1/models until ready.
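
Typical call sequence (a sketch; `args` stands for whatever CLI parsing the
caller does and is not part of this module):

    check_gpu()
    base = resolve_base_model(args.base_model)
    lora = ensure_lora_cached(args.lora_path)
    proc = start_server(base, lora)
    try:
        wait_for_server(proc=proc)
        ...  # talk to the API server on 127.0.0.1:LOCAL_LLM_PORT
    finally:
        stop_server(proc)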
"""
from __future__ import annotations

import os
import subprocess
import sys
import tempfile
import time
from pathlib import Path
from typing import Optional

from . import (
    DEFAULT_LORA_REPO,
    DEFAULT_LORA_SUBFOLDER,
    DEFAULT_QUANTIZATION,
    LOCAL_LLM_PORT,
)


# ──────────────────────────── GPU CHECK ────────────────────────────
def check_gpu(min_vram_gb: int = 22) -> None:
    """Hard-fail if no NVIDIA GPU with enough VRAM."""
    try:
        import torch
    except ImportError:
        sys.exit(
            "[statlens] FATAL β€” torch is not importable. Reinstall with `pip install -U statlens`."
        )

    if not torch.cuda.is_available():
        sys.exit(
            "[statlens] FATAL β€” no NVIDIA GPU detected. statLens requires a CUDA GPU\n"
            "with at least 22 GB VRAM (e.g. RTX 3090, 4090, A40, A100).\n"
            "Mac / CPU-only / AMD ROCm are not supported."
        )

    n = torch.cuda.device_count()
    devices = []
    for i in range(n):
        p = torch.cuda.get_device_properties(i)
        devices.append((p.name, p.total_memory / 1024**3))

    biggest = max(mem for _, mem in devices)
    if biggest < min_vram_gb:
        names = ", ".join(f"{name}({mem:.0f}GB)" for name, mem in devices)
        sys.exit(
            f"[statlens] FATAL - biggest GPU has only {biggest:.0f} GB VRAM ({names}).\n"
            f"Need at least {min_vram_gb} GB to load Qwen3-32B (4-bit) + LoRA + KV cache."
        )

    print(f"[statlens] GPU OK - {n} device(s): "
          + ", ".join(f"{name} ({mem:.0f}GB)" for name, mem in devices))


# ──────────────────────────── BASE MODEL CHECK ────────────────────────────
# Common places people put a downloaded HF model. We search these in order
# when the user didn't pass --base-model.
_CANDIDATE_BASE_PATHS = (
    "~/models/qwen3-32b",
    "/root/autodl-tmp/models/qwen3-32b",     # AutoDL persistent disk
    "/workspace/models/qwen3-32b",            # RunPod / Lambda common path
    "/data/models/qwen3-32b",
    "/mnt/models/qwen3-32b",
)


def _looks_like_hf_model(p: Path) -> bool:
    return p.is_dir() and (p / "config.json").exists()


def _search_hf_cache_for_qwen3_32b() -> Optional[Path]:
    """If the user ran `huggingface-cli download Qwen/Qwen3-32B` without
    --local-dir, the snapshot lives at
        ~/.cache/huggingface/hub/models--Qwen--Qwen3-32B/snapshots/<sha>/
    Return that path if found, else None.
    """
    try:
        from huggingface_hub.constants import HF_HUB_CACHE
    except ImportError:
        return None
    repo_dir = Path(HF_HUB_CACHE) / "models--Qwen--Qwen3-32B" / "snapshots"
    if not repo_dir.exists():
        return None
    for snap in repo_dir.iterdir():
        if _looks_like_hf_model(snap):
            return snap
    return None


def resolve_base_model(base_model: Optional[str]) -> str:
    """Resolve to an absolute path of the BF16 base model directory.

    Resolution order:
      1. CLI arg `base_model`           (highest priority)
      2. env var STATLENS_BASE_MODEL
      3. auto-search common paths       (~/models/qwen3-32b, /root/autodl-tmp/..., etc.)
      4. auto-search the HF Hub cache   (in case user did `huggingface-cli download`)
      5. hard-fail with clear instructions
    """
    explicit = base_model or os.environ.get("STATLENS_BASE_MODEL")

    if explicit:
        p = Path(explicit).expanduser().resolve()
        if not p.exists():
            # Try auto-discovery before giving up; maybe they passed the wrong
            # path but the model IS somewhere obvious.
            auto = _auto_discover()
            if auto:
                sys.exit(
                    f"[statlens] FATAL β€” --base-model path not found: {p}\n"
                    f"           but I found a Qwen3-32B at: {auto}\n"
                    f"           Re-run with --base-model {auto}\n"
                    f"           (or just omit --base-model; statLens will auto-detect.)"
                )
            sys.exit(f"[statlens] FATAL β€” --base-model path not found: {p}")
        if not _looks_like_hf_model(p):
            sys.exit(
                f"[statlens] FATAL β€” {p} does not look like a HF model directory "
                "(missing config.json)."
            )
        return str(p)

    # No explicit path given; auto-discover.
    auto = _auto_discover()
    if auto:
        print(f"[statlens] auto-detected base model: {auto}")
        return str(auto)

    sys.exit(
        "[statlens] FATAL β€” no base model found.\n\n"
        "statLens does not auto-download the 64 GB BF16 base. Get it once:\n\n"
        "    # mainland China:\n"
        "    HF_ENDPOINT=https://hf-mirror.com \\\n"
        "        huggingface-cli download Qwen/Qwen3-32B --local-dir ~/models/qwen3-32b\n\n"
        "    # elsewhere:\n"
        "    huggingface-cli download Qwen/Qwen3-32B --local-dir ~/models/qwen3-32b\n\n"
        "Then either:\n"
        "  Β· run statLens with no --base-model flag (auto-detected from common paths)\n"
        "  Β· pass --base-model <path>\n"
        "  Β· set the env var STATLENS_BASE_MODEL=<path>\n\n"
        f"Searched (in order): {list(_CANDIDATE_BASE_PATHS)}\n"
        "                     + the HuggingFace Hub cache."
    )


def _auto_discover() -> Optional[Path]:
    """Walk the candidate list and return the first valid HF model dir."""
    for cand in _CANDIDATE_BASE_PATHS:
        p = Path(cand).expanduser().resolve()
        if _looks_like_hf_model(p):
            return p
    return _search_hf_cache_for_qwen3_32b()


# ──────────────────────────── LoRA CACHE ────────────────────────────
def ensure_lora_cached(
    lora_path_override: Optional[str] = None,
    lora_repo: str = DEFAULT_LORA_REPO,
    lora_subfolder: str = DEFAULT_LORA_SUBFOLDER,
) -> str:
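    """Return a local directory containing the LoRA adapter.

    If *lora_path_override* is given it is validated and used as-is; otherwise
    the adapter (roughly 1 GB) is pulled from the HF Hub via snapshot_download().
    The 64 GB base model is never downloaded here.
    """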
    if lora_path_override:
        p = Path(lora_path_override).expanduser().resolve()
        if not (p / "adapter_model.safetensors").exists():
            sys.exit(
                f"[statlens] FATAL β€” --lora-path {p} has no adapter_model.safetensors"
            )
        return str(p)

    from huggingface_hub import snapshot_download
    print(f"[statlens] checking LoRA {lora_repo} ...")
    lora_root = snapshot_download(lora_repo)
    lora_path = str(Path(lora_root) / lora_subfolder)
    if not Path(lora_path, "adapter_model.safetensors").exists():
        sys.exit(
            f"[statlens] FATAL β€” LoRA adapter not found at {lora_path}.\n"
            f"Repo {lora_repo} may have changed layout."
        )
    return lora_path


def cache_dir_for(repo: str) -> Path:
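    """Map a repo id to its HF Hub cache directory, e.g. "Qwen/Qwen3-32B"
    maps to <HF_HUB_CACHE>/models--Qwen--Qwen3-32B. The directory may not
    exist yet if the repo has never been downloaded."""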
    from huggingface_hub.constants import HF_HUB_CACHE
    safe = repo.replace("/", "--")
    return Path(HF_HUB_CACHE) / f"models--{safe}"


# ──────────────────────────── LLaMA-Factory SUBPROCESS ────────────────────────────
def _build_yaml(base_path: str, lora_path: str, quantization: str) -> Path:
    """Materialise a LLaMA-Factory inference YAML in tmp; return its path."""
    import yaml

    cfg = {
        "model_name_or_path": base_path,
        "adapter_name_or_path": lora_path,
        "template": "qwen",
        "finetuning_type": "lora",
        "trust_remote_code": True,
        "infer_backend": "huggingface",
        "infer_dtype": "bfloat16",
        "flash_attn": "sdpa",
    }
    if quantization == "bitsandbytes":
        cfg["quantization_bit"] = 4
        cfg["quantization_method"] = "bnb"
    elif quantization == "none":
        pass
    else:
        # gptq / awq go straight through; LLaMA-Factory will reject if not supported
        cfg["quantization_method"] = quantization

    yaml_path = Path(tempfile.gettempdir()) / "statlens_lf_api.yaml"
    yaml_path.write_text(yaml.safe_dump(cfg, sort_keys=False))
    return yaml_path
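
# For reference, the YAML written above looks roughly like this for the default
# bitsandbytes path (paths are illustrative; the real ones come from the caller):
#
#   model_name_or_path: /home/user/models/qwen3-32b
#   adapter_name_or_path: /home/user/.cache/huggingface/hub/models--<repo>/snapshots/<sha>/<subfolder>
#   template: qwen
#   finetuning_type: lora
#   trust_remote_code: true
#   infer_backend: huggingface
#   infer_dtype: bfloat16
#   flash_attn: sdpa
#   quantization_bit: 4
#   quantization_method: bnb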


def _llm_log_path() -> Path:
    p = Path.home() / ".cache" / "statlens" / "llm.log"
    p.parent.mkdir(parents=True, exist_ok=True)
    return p


def start_server(
    base_path: str,
    lora_path: str,
    port: int = LOCAL_LLM_PORT,
    quantization: str = DEFAULT_QUANTIZATION,
    log_path: Optional[Path] = None,
) -> subprocess.Popen:
    """Spawn `llamafactory-cli api <yaml>` as a child.

    LLaMA-Factory's verbose output goes to a log file (default ~/.cache/statlens/llm.log)
    rather than the user's terminal. The CLI will surface a clean status line.

    Note: this is LLaMA-Factory's API server, not vLLM. Tensor parallelism /
    max-model-len would need to be plumbed via the LF YAML config; they are
    not currently exposed.
    """
    yaml_path = _build_yaml(base_path, lora_path, quantization)
    log_path = log_path or _llm_log_path()

    env = {
        **os.environ,
        "API_HOST": "127.0.0.1",
        "API_PORT": str(port),
    }

    print(f"[statlens] starting LLM backend (quantization={quantization})")
    print(f"[statlens]   base : {base_path}")
    print(f"[statlens]   lora : {lora_path}")
    print(f"[statlens]   log  : {log_path}    (tail this for full LLaMA-Factory output)")

    cmd = [sys.executable, "-m", "llamafactory.cli", "api", str(yaml_path)]
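    # The log file handle is deliberately left open: the child process inherits
    # it and keeps writing to it for its whole lifetime.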
    log_file = open(log_path, "wb")
    proc = subprocess.Popen(
        cmd,
        stdout=log_file,
        stderr=subprocess.STDOUT,
        env=env,
    )
    return proc


def wait_for_server(
    port: int = LOCAL_LLM_PORT,
    timeout: float = 600.0,
    proc: Optional[subprocess.Popen] = None,
) -> None:
    """Poll /v1/models until the server is up, printing a single self-overwriting line."""
    import httpx
    url = f"http://127.0.0.1:{port}/v1/models"
    t0 = time.time()
    spinner = "|/-\\"
    i = 0
    while time.time() - t0 < timeout:
        # If subprocess crashed, abort early with a useful message.
        if proc is not None and proc.poll() is not None:
            sys.stdout.write("\r" + " " * 80 + "\r")
            raise RuntimeError(
                f"LLM backend process exited early with code {proc.returncode}. "
                f"Check the log at {_llm_log_path()}"
            )
        try:
            r = httpx.get(url, timeout=2.0)
            if r.status_code == 200:
                # finish the progress line
                sys.stdout.write("\r" + " " * 80 + "\r")
                sys.stdout.flush()
                print(f"[statlens] LLM ready after {time.time()-t0:.0f}s")
                return
        except Exception:
            pass
        # animated progress
        sys.stdout.write(
            f"\r[statlens]   loading model {spinner[i % 4]} "
            f"({time.time()-t0:.0f}s elapsed) "
        )
        sys.stdout.flush()
        i += 1
        time.sleep(2)
    sys.stdout.write("\r" + " " * 80 + "\r")
    raise TimeoutError(
        f"LLM server did not become ready within {timeout:.0f}s on port {port}"
    )


def stop_server(proc: Optional[subprocess.Popen]) -> None:
    if proc is None or proc.poll() is not None:
        return
    print("[statlens] stopping LLM server ...")
    proc.terminate()
    try:
        proc.wait(timeout=10)
    except subprocess.TimeoutExpired:
        proc.kill()
        proc.wait(timeout=5)