File size: 4,731 Bytes
0366d65
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
"""Quantization-penalty test: eval the SAME fine-tuned weights at several precisions.

Our Q4_K_M fine-tunes underperform base, and v1's misses were the "206"/`2062`
dropped-year-digit bug — a classic low-bit quantization artifact. This serves an
on-volume fp16 GGUF (from a past run) at f16 / Q8_0 / Q4_K_M and scores each on the
same 28-example eval, so precision is the only variable. If higher precision restores
schema validity + recall and kills the digit bug, quantization is the culprit.

  PYTHONUTF8=1 python -m modal run training/modal_quant_eval.py
  PYTHONUTF8=1 python -m modal run training/modal_quant_eval.py \
      --f16-path /outputs/gemma-cal-staging-f16.gguf --quants f16,Q8_0
"""
from __future__ import annotations

import json
import os
import re
import shutil
import subprocess
import time
import urllib.request
from pathlib import Path

import modal

# Two directories above this file; the usage examples in the module docstring run it as
# training/modal_quant_eval.py, which makes this the repository root.
REPO_ROOT = Path(__file__).resolve().parent.parent

# full-cuda image has BOTH llama-quantize and llama-server; clear its ENTRYPOINT so
# Modal can run python.
image = (
    modal.Image.from_registry("ghcr.io/ggml-org/llama.cpp:full-cuda", add_python="3.11")
    .entrypoint([])
    .pip_install("requests", "pydantic>=2", "huggingface_hub")
    # Copy the repo into the image at /root/repo (the directory quant_eval runs
    # training/eval.py from), skipping VCS metadata, bytecode caches, local GGUF
    # files, and the training output/cache directories.
    .add_local_dir(
        str(REPO_ROOT), "/root/repo",
        ignore=[".git", "**/__pycache__", "**/*.gguf", "training/outputs",
                "training/data/.smcalflow_cache"],
    )
)
# Every function registered on this app uses the image above by default.
app = modal.App("imessage-cal-quant-eval", image=image)
# Named volume that quant_eval mounts at /outputs; it must already hold the fp16 GGUF.
outputs = modal.Volume.from_name("imessage-cal-outputs")


def _wait_for_server(proc: subprocess.Popen, health_url: str, label: str,
                     attempts: int = 900, poll_s: int = 2) -> bool:
    """Poll ``health_url`` until it answers HTTP 200; return whether the server is ready.

    Gives up (returning False) as soon as ``proc`` has exited or after ``attempts`` polls.
    """
    for attempt in range(attempts):
        if proc.poll() is not None:
            print(f"[quant] {label}: server exited early", flush=True)
            return False
        try:
            with urllib.request.urlopen(health_url, timeout=5) as resp:
                if resp.status == 200:
                    print(f"[quant] {label}: ready after ~{attempt * poll_s}s", flush=True)
                    return True
        except Exception:  # noqa: BLE001
            # Best-effort probe: connection refused / HTTP errors are expected while the
            # server is still starting up, so just retry.
            pass
        # Sleep after every unsuccessful probe (not only after exceptions) so an
        # unexpected non-200 response can never turn this into a busy loop.
        time.sleep(poll_s)
    print(f"[quant] {label}: server not ready after {attempts} attempts; skipping", flush=True)
    return False


def _stop_server(proc: subprocess.Popen, grace_s: float = 30) -> None:
    """Terminate ``proc`` and wait until it is really gone, killing it if it ignores SIGTERM."""
    if proc.poll() is not None:
        return
    proc.terminate()
    try:
        proc.wait(timeout=grace_s)
    except subprocess.TimeoutExpired:
        proc.kill()
        proc.wait()


@app.function(gpu="A100-80GB", timeout=2 * 60 * 60, volumes={"/outputs": outputs})
def quant_eval(f16_path: str = "/outputs/gemma-cal-f16.gguf",
               quants: str = "f16,Q8_0,Q4_K_M") -> dict:
    """Serve one fp16 GGUF at several precisions and run the eval script against each.

    For every entry in ``quants`` the fp16 file is (unless the entry is ``f16``)
    quantized with ``llama-quantize`` into a scratch file on the volume, served with
    ``llama-server`` on 127.0.0.1:8080, and ``training/eval.py`` is run against it.
    The scratch file and the server process are always cleaned up, even on errors.

    Args:
        f16_path: Path (inside the container) of the fp16 GGUF to start from.
        quants: Comma-separated precision names; ``f16`` serves ``f16_path`` as-is,
            anything else is passed to ``llama-quantize`` as the target type.

    Returns:
        Mapping of precision name to the dict parsed from the eval script's
        ``RESULTS_JSON: {...}`` output line. Precisions whose server never became
        ready, or whose eval printed no parseable ``RESULTS_JSON`` line, are omitted.

    Raises:
        FileNotFoundError: If ``f16_path`` does not exist.
        subprocess.CalledProcessError: If ``llama-quantize`` fails.
    """
    workspace = "/root/repo"
    quantize = shutil.which("llama-quantize") or "/app/llama-quantize"
    server = shutil.which("llama-server") or "/app/llama-server"
    env = {**os.environ}
    # Put the server binary's directory and /app first on the shared-library search path.
    env["LD_LIBRARY_PATH"] = f"{os.path.dirname(server)}:/app:" + env.get("LD_LIBRARY_PATH", "")

    if not os.path.exists(f16_path):
        raise FileNotFoundError(f"{f16_path} not on the volume; `modal volume ls imessage-cal-outputs`")

    results: dict[str, dict] = {}
    for q in [x.strip() for x in quants.split(",") if x.strip()]:
        tmp = None
        proc = None
        try:
            if q == "f16":
                gguf = f16_path
            else:
                tmp = gguf = f"/outputs/_quanttest-{q}.gguf"
                print(f"\n[quant] {quantize} {f16_path} -> {gguf} ({q})", flush=True)
                subprocess.run([quantize, f16_path, gguf, q], env=env, check=True)

            proc = subprocess.Popen(
                [server, "-m", gguf, "--host", "127.0.0.1", "--port", "8080",
                 "-ngl", "999", "-c", "8192", "--jinja"], env=env,
            )

            if _wait_for_server(proc, "http://127.0.0.1:8080/health", q):
                # Point the eval script at the local server through its environment.
                env2 = {**env, "INFERENCE_BASE_URL": "http://127.0.0.1:8080/v1",
                        "MODEL_LABEL": f"{Path(f16_path).name}:{q}"}
                r = subprocess.run(["python3", "training/eval.py"], cwd=workspace, env=env2,
                                   capture_output=True, text=True)
                print(f"\n========== PRECISION {q} ==========", flush=True)
                print(r.stdout, flush=True)
                if r.stderr:
                    print("STDERR:", r.stderr[-2000:], flush=True)
                if r.returncode != 0:
                    print(f"[quant] {q}: eval exited with code {r.returncode}", flush=True)
                # The metrics are expected on a single line (`.` does not cross newlines).
                m = re.search(r"RESULTS_JSON:\s*(\{.*\})", r.stdout)
                if m is None:
                    print(f"[quant] {q}: no RESULTS_JSON line in eval output", flush=True)
                else:
                    try:
                        results[q] = json.loads(m.group(1))
                    except json.JSONDecodeError as err:
                        # One bad line must not throw away the other precisions' results.
                        print(f"[quant] {q}: could not parse RESULTS_JSON ({err})", flush=True)
        finally:
            # Always reap the server before the next precision starts: a merely
            # "terminated" process may still hold port 8080 and GPU memory.
            if proc is not None:
                _stop_server(proc)
            if tmp and os.path.exists(tmp):
                os.remove(tmp)  # keep the volume clean (scratch files)

    print("\n==================== QUANT SWEEP SUMMARY ====================", flush=True)
    print(f"  (weights: {Path(f16_path).name})", flush=True)
    for q, res in results.items():
        print(f"  {q:8s} validity={res.get('schema_validity')}  "
              f"f1={res.get('event_f1')}  recall={res.get('event_recall_start_exact')}", flush=True)
    print("  base ref: validity=1.0 f1=0.977 recall=0.955", flush=True)
    return results


@app.local_entrypoint()
def main(f16_path: str = "/outputs/gemma-cal-f16.gguf", quants: str = "f16,Q8_0,Q4_K_M"):
    """Run the quantization sweep remotely and print the per-precision results dict."""
    sweep_results = quant_eval.remote(f16_path=f16_path, quants=quants)
    print(sweep_results)