OffGridSchedula / training /modal_quant_eval.py
ParetoOptimal's picture
Initial Commit
0366d65
Raw
History Blame Contribute Delete
4.73 kB
"""Quantization-penalty test: eval the SAME fine-tuned weights at several precisions.
Our Q4_K_M fine-tunes underperform base, and v1's misses were the "206"/`2062`
dropped-year-digit bug — a classic low-bit quantization artifact. This serves an
on-volume fp16 GGUF (from a past run) at f16 / Q8_0 / Q4_K_M and scores each on the
same 28-example eval, so precision is the only variable. If higher precision restores
schema validity + recall and kills the digit bug, quantization is the culprit.
Usage:
    PYTHONUTF8=1 python -m modal run training/modal_quant_eval.py
    PYTHONUTF8=1 python -m modal run training/modal_quant_eval.py \
        --f16-path /outputs/gemma-cal-staging-f16.gguf --quants f16,Q8_0
"""
from __future__ import annotations
import json
import os
import re
import shutil
import subprocess
import time
import urllib.request
from pathlib import Path
import modal
# Two levels up from this file (this script lives in training/, per the usage
# line in the module docstring).
REPO_ROOT = Path(__file__).resolve().parent.parent

# Patterns handed to add_local_dir(ignore=...) when the repo is copied into the image.
_REPO_UPLOAD_IGNORE = [
    ".git",
    "**/__pycache__",
    "**/*.gguf",
    "training/outputs",
    "training/data/.smcalflow_cache",
]

# full-cuda image has BOTH llama-quantize and llama-server; clear its ENTRYPOINT so
# Modal can run python.
_llama_cpp_base = modal.Image.from_registry(
    "ghcr.io/ggml-org/llama.cpp:full-cuda",
    add_python="3.11",
).entrypoint([])

# Python dependencies are layered on first, then the local repo is added at /root/repo.
_with_python_deps = _llama_cpp_base.pip_install("requests", "pydantic>=2", "huggingface_hub")
image = _with_python_deps.add_local_dir(
    str(REPO_ROOT),
    "/root/repo",
    ignore=_REPO_UPLOAD_IGNORE,
)

app = modal.App("imessage-cal-quant-eval", image=image)

# Named volume; quant_eval mounts it at /outputs.
outputs = modal.Volume.from_name("imessage-cal-outputs")
def _wait_until_healthy(proc: subprocess.Popen, health_url: str, label: str,
                        attempts: int = 900, poll_seconds: int = 2) -> bool:
    """Poll ``health_url`` until it answers HTTP 200, the server dies, or attempts run out.

    Args:
        proc: The running server process; polling stops as soon as it has exited.
        health_url: URL requested on every attempt.
        label: Precision name used in the log lines.
        attempts: Maximum number of polls.
        poll_seconds: Pause between polls.

    Returns:
        True once a 200 response is seen, otherwise False.
    """
    for attempt in range(attempts):
        if proc.poll() is not None:
            print(f"[quant] {label}: server exited early", flush=True)
            return False
        try:
            with urllib.request.urlopen(health_url, timeout=5) as resp:
                if resp.status == 200:
                    print(f"[quant] {label}: ready after ~{attempt * poll_seconds}s", flush=True)
                    return True
        except Exception:  # noqa: BLE001
            # Any failure here just means "not up yet"; try again after the pause.
            pass
        # Pause after every unsuccessful poll (error OR non-200 response) so the
        # loop can never spin without sleeping.
        time.sleep(poll_seconds)
    print(f"[quant] {label}: server not ready after {attempts} polls; skipping eval", flush=True)
    return False


def _stop_server(proc: subprocess.Popen, grace_seconds: float = 30) -> None:
    """Terminate ``proc`` and block until it is really gone, killing it if needed.

    Waiting (rather than sleeping a fixed time) reaps the child and ensures the
    port is free before the next precision's server starts.
    """
    proc.terminate()
    try:
        proc.wait(timeout=grace_seconds)
    except subprocess.TimeoutExpired:
        proc.kill()
        proc.wait()


def _parse_results_json(stdout: str, label: str) -> dict | None:
    """Extract the ``RESULTS_JSON: {...}`` payload from the eval script's stdout.

    Returns the decoded dict, or None (after logging why) when the marker line is
    absent or its payload is not valid JSON.
    """
    found = re.search(r"RESULTS_JSON:\s*(\{.*\})", stdout)
    if found is None:
        print(f"[quant] {label}: no RESULTS_JSON line in eval output", flush=True)
        return None
    try:
        return json.loads(found.group(1))
    except json.JSONDecodeError as err:
        print(f"[quant] {label}: could not parse RESULTS_JSON ({err})", flush=True)
        return None


@app.function(gpu="A100-80GB", timeout=2 * 60 * 60, volumes={"/outputs": outputs})
def quant_eval(f16_path: str = "/outputs/gemma-cal-f16.gguf",
               quants: str = "f16,Q8_0,Q4_K_M") -> dict:
    """Serve one fp16 GGUF at several precisions and score each with training/eval.py.

    For every entry in ``quants``: quantize ``f16_path`` to a scratch file on the
    volume (skipped for "f16", which is served as-is), start llama-server on it,
    wait for /health, run the eval script against the server, and collect the
    dict printed on its ``RESULTS_JSON:`` line. The server is always stopped and
    the scratch file always removed, even when a step fails.

    Args:
        f16_path: Path of the fp16 GGUF to start from.
        quants: Comma-separated precision names; blanks are ignored.

    Returns:
        Mapping of precision name to the parsed results dict. Precisions whose
        server never became healthy, or whose eval produced no parsable
        RESULTS_JSON, are omitted.

    Raises:
        FileNotFoundError: ``f16_path`` does not exist.
        subprocess.CalledProcessError: llama-quantize exited non-zero.
    """
    workspace = "/root/repo"
    quantize = shutil.which("llama-quantize") or "/app/llama-quantize"
    server = shutil.which("llama-server") or "/app/llama-server"

    env = {**os.environ}
    # Drop empty components: a trailing ":" would add the current directory to
    # the loader search path.
    lib_dirs = (os.path.dirname(server), "/app", env.get("LD_LIBRARY_PATH", ""))
    env["LD_LIBRARY_PATH"] = ":".join(d for d in lib_dirs if d)

    if not os.path.exists(f16_path):
        raise FileNotFoundError(f"{f16_path} not on the volume; `modal volume ls imessage-cal-outputs`")

    host, port = "127.0.0.1", "8080"
    base_url = f"http://{host}:{port}"

    results: dict[str, dict] = {}
    for q in [x.strip() for x in quants.split(",") if x.strip()]:
        if q == "f16":
            gguf, tmp = f16_path, None
        else:
            tmp = gguf = f"/outputs/_quanttest-{q}.gguf"
        try:
            if tmp:
                print(f"\n[quant] {quantize} {f16_path} -> {gguf} ({q})", flush=True)
                subprocess.run([quantize, f16_path, gguf, q], env=env, check=True)

            proc = subprocess.Popen(
                [server, "-m", gguf, "--host", host, "--port", port,
                 "-ngl", "999", "-c", "8192", "--jinja"],
                env=env,
            )
            try:
                if _wait_until_healthy(proc, f"{base_url}/health", q):
                    eval_env = {**env, "INFERENCE_BASE_URL": f"{base_url}/v1",
                                "MODEL_LABEL": f"{Path(f16_path).name}:{q}"}
                    run = subprocess.run(["python3", "training/eval.py"], cwd=workspace,
                                         env=eval_env, capture_output=True, text=True)
                    print(f"\n========== PRECISION {q} ==========", flush=True)
                    print(run.stdout, flush=True)
                    if run.stderr:
                        print("STDERR:", run.stderr[-2000:], flush=True)
                    if run.returncode != 0:
                        print(f"[quant] {q}: eval exited with code {run.returncode}", flush=True)
                    parsed = _parse_results_json(run.stdout, q)
                    if parsed is not None:
                        results[q] = parsed
            finally:
                # Always release the port/GPU before moving on to the next precision.
                _stop_server(proc)
        finally:
            if tmp and os.path.exists(tmp):
                os.remove(tmp)  # keep the volume clean (scratch files)

    print("\n==================== QUANT SWEEP SUMMARY ====================", flush=True)
    print(f" (weights: {Path(f16_path).name})", flush=True)
    for q, res in results.items():
        print(f" {q:8s} validity={res.get('schema_validity')} "
              f"f1={res.get('event_f1')} recall={res.get('event_recall_start_exact')}", flush=True)
    print(" base ref: validity=1.0 f1=0.977 recall=0.955", flush=True)
    return results
@app.local_entrypoint()
def main(f16_path: str = "/outputs/gemma-cal-f16.gguf", quants: str = "f16,Q8_0,Q4_K_M"):
    """Run the quant sweep remotely and print the per-precision results it returns.

    Args:
        f16_path: Forwarded to ``quant_eval`` as the fp16 GGUF to start from.
        quants: Forwarded to ``quant_eval`` as the comma-separated precision list.
    """
    sweep_results = quant_eval.remote(f16_path=f16_path, quants=quants)
    print(sweep_results)