| """Quantization-penalty test: eval the SAME fine-tuned weights at several precisions. |
| |
Our Q4_K_M fine-tunes underperform the base model, and v1's misses were the
"206"/`2062` dropped-year-digit bug — a classic low-bit quantization artifact.
This script takes an fp16 GGUF already on the outputs volume (from a past run),
quantizes it to each requested precision (default f16 / Q8_0 / Q4_K_M; f16 is
served as-is), serves each one with llama-server, and scores it on the same
28-example eval, so precision is the only variable. If higher precision restores
schema validity + recall and eliminates the digit bug, quantization is the culprit.
| |
| PYTHONUTF8=1 python -m modal run training/modal_quant_eval.py |
| PYTHONUTF8=1 python -m modal run training/modal_quant_eval.py \ |
| --f16-path /outputs/gemma-cal-staging-f16.gguf --quants f16,Q8_0 |
| """ |
| from __future__ import annotations |
|
|
| import json |
| import os |
| import re |
| import shutil |
| import subprocess |
| import time |
| import urllib.request |
| from pathlib import Path |
|
|
| import modal |
|
|
# This script lives one directory below the repo root (training/), so the root
# is the grandparent of this file's resolved (symlink-free) path.
_THIS_FILE = Path(__file__).resolve()
REPO_ROOT = _THIS_FILE.parent.parent
|
|
| |
| |
# Repo paths (glob patterns) kept out of the upload into the container: VCS
# metadata, bytecode caches, GGUF weights, local training outputs, and the
# SMCalFlow download cache.
_UPLOAD_IGNORE = [
    ".git",
    "**/__pycache__",
    "**/*.gguf",
    "training/outputs",
    "training/data/.smcalflow_cache",
]

# llama.cpp's CUDA "full" image provides the llama.cpp binaries (quant_eval
# looks for them on PATH or under /app); Modal layers Python 3.11 on top. The
# image's own entrypoint is cleared so Modal controls what the container runs.
_llama_cuda = modal.Image.from_registry(
    "ghcr.io/ggml-org/llama.cpp:full-cuda", add_python="3.11"
).entrypoint([])
_with_deps = _llama_cuda.pip_install("requests", "pydantic>=2", "huggingface_hub")
# The repo (minus the ignored paths) is mounted at /root/repo, where
# quant_eval runs training/eval.py from.
image = _with_deps.add_local_dir(str(REPO_ROOT), "/root/repo", ignore=_UPLOAD_IGNORE)

app = modal.App("imessage-cal-quant-eval", image=image)
# Persistent volume holding the GGUF weights; mounted at /outputs by quant_eval.
outputs = modal.Volume.from_name("imessage-cal-outputs")
|
|
|
|
_SERVER_PORT = 8080
_SERVER_URL = f"http://127.0.0.1:{_SERVER_PORT}"


def _wait_for_server(proc: subprocess.Popen, q: str, attempts: int = 900,
                     interval: float = 2.0) -> bool:
    """Poll llama-server's ``/health`` endpoint until it answers HTTP 200.

    Args:
        proc: the running llama-server process.
        q: precision label, used only in log lines.
        attempts: maximum number of polls.
        interval: seconds to sleep between polls.

    Returns:
        True once ``/health`` returns 200. False if the process exits first,
        or if ``attempts`` polls pass without a 200.
    """
    for i in range(attempts):
        if proc.poll() is not None:
            print(f"[quant] {q}: server exited early (rc={proc.returncode})", flush=True)
            return False
        try:
            with urllib.request.urlopen(f"{_SERVER_URL}/health", timeout=5) as r:
                if r.status == 200:
                    print(f"[quant] {q}: ready after ~{int(i * interval)}s", flush=True)
                    return True
        except Exception:
            # Best-effort poll: connection refused before the server binds, or
            # an HTTP error status (llama-server reports 503 while loading).
            pass
        # Sleep on every miss, including a non-200 reply that raised nothing,
        # so the loop never spins without delay.
        time.sleep(interval)
    print(f"[quant] {q}: server not healthy after {attempts} polls", flush=True)
    return False


def _stop_server(proc: subprocess.Popen, grace: float = 30.0) -> None:
    """Terminate ``proc`` and block until it exits, killing it after ``grace`` s.

    The next precision reuses the same port and GPU, so the previous server
    must be fully gone before the next one starts. A fixed sleep can't
    guarantee that.
    """
    if proc.poll() is not None:
        return
    proc.terminate()
    try:
        proc.wait(timeout=grace)
    except subprocess.TimeoutExpired:
        proc.kill()
        proc.wait()


def _run_eval(workspace: str, env: dict[str, str], q: str, label: str) -> dict | None:
    """Run ``training/eval.py`` against the local server and parse its result.

    Echoes the eval's stdout (and the last 2000 chars of stderr). Returns the
    dict parsed from its ``RESULTS_JSON: {...}`` line, or None if that line
    is missing or is not valid JSON.
    """
    eval_env = {**env, "INFERENCE_BASE_URL": f"{_SERVER_URL}/v1", "MODEL_LABEL": label}
    r = subprocess.run(["python3", "training/eval.py"], cwd=workspace, env=eval_env,
                       capture_output=True, text=True)
    print(f"\n========== PRECISION {q} ==========", flush=True)
    print(r.stdout, flush=True)
    if r.stderr:
        print("STDERR:", r.stderr[-2000:], flush=True)
    m = re.search(r"RESULTS_JSON:\s*(\{.*\})", r.stdout)
    if m is None:
        print(f"[quant] {q}: no RESULTS_JSON in eval output (rc={r.returncode})", flush=True)
        return None
    try:
        return json.loads(m.group(1))
    except json.JSONDecodeError as err:
        print(f"[quant] {q}: unparseable RESULTS_JSON: {err}", flush=True)
        return None


@app.function(gpu="A100-80GB", timeout=2 * 60 * 60, volumes={"/outputs": outputs})
def quant_eval(f16_path: str = "/outputs/gemma-cal-f16.gguf",
               quants: str = "f16,Q8_0,Q4_K_M") -> dict:
    """Serve ``f16_path`` at each precision in ``quants`` and run the eval on each.

    Args:
        f16_path: fp16 GGUF on the outputs volume. It is the source for every
            quantization.
        quants: comma-separated precision names passed to llama-quantize.
            ``f16`` means "serve ``f16_path`` unchanged".

    Returns:
        Mapping of precision name to the eval's RESULTS_JSON dict. Precisions
        whose server never became healthy, or whose eval produced no parseable
        result, are omitted here and listed as FAILED in the printed summary.

    Raises:
        FileNotFoundError: ``f16_path`` does not exist.
        subprocess.CalledProcessError: llama-quantize exits non-zero.
    """
    workspace = "/root/repo"
    quantize = shutil.which("llama-quantize") or "/app/llama-quantize"
    server = shutil.which("llama-server") or "/app/llama-server"
    env = {**os.environ}
    # Put the llama.cpp binary dir and /app ahead of any existing entries
    # (presumably so the image's bundled shared libs resolve). Empty parts are
    # dropped, because a trailing ':' would add the current directory.
    lib_dirs = [os.path.dirname(server), "/app", env.get("LD_LIBRARY_PATH", "")]
    env["LD_LIBRARY_PATH"] = ":".join(d for d in lib_dirs if d)

    if not os.path.exists(f16_path):
        raise FileNotFoundError(f"{f16_path} not on the volume; `modal volume ls imessage-cal-outputs`")

    results: dict[str, dict] = {}
    failed: list[str] = []
    for q in [x.strip() for x in quants.split(",") if x.strip()]:
        # "f16" serves the source file as-is. Any other precision is quantized
        # into a scratch file on the volume, which is deleted after its eval.
        tmp = None if q == "f16" else f"/outputs/_quanttest-{q}.gguf"
        gguf = tmp or f16_path
        res = None
        try:
            if tmp:
                print(f"\n[quant] {quantize} {f16_path} -> {gguf} ({q})", flush=True)
                subprocess.run([quantize, f16_path, gguf, q], env=env, check=True)
            # -ngl 999: offload every layer to the GPU.
            proc = subprocess.Popen(
                [server, "-m", gguf, "--host", "127.0.0.1", "--port", str(_SERVER_PORT),
                 "-ngl", "999", "-c", "8192", "--jinja"], env=env,
            )
            try:
                if _wait_for_server(proc, q):
                    res = _run_eval(workspace, env, q, f"{Path(f16_path).name}:{q}")
            finally:
                # Always stop the server (even if the eval raised) so the port
                # and GPU are free for the next precision.
                _stop_server(proc)
        finally:
            # Never leave scratch quants on the persistent volume, even on failure.
            if tmp and os.path.exists(tmp):
                os.remove(tmp)
        if res is None:
            failed.append(q)
        else:
            results[q] = res

    print("\n==================== QUANT SWEEP SUMMARY ====================", flush=True)
    print(f" (weights: {Path(f16_path).name})", flush=True)
    for q, res in results.items():
        print(f" {q:8s} validity={res.get('schema_validity')} "
              f"f1={res.get('event_f1')} recall={res.get('event_recall_start_exact')}", flush=True)
    for q in failed:
        print(f" {q:8s} FAILED (server not healthy or no parseable RESULTS_JSON)", flush=True)
    # Hard-coded reference numbers, presumably from an earlier eval of the base
    # (un-fine-tuned) model; update them if the eval set changes.
    print(" base ref: validity=1.0 f1=0.977 recall=0.955", flush=True)
    return results
|
|
|
|
@app.local_entrypoint()
def main(f16_path: str = "/outputs/gemma-cal-f16.gguf", quants: str = "f16,Q8_0,Q4_K_M"):
    """Local CLI entry point for ``modal run``.

    Runs the precision sweep remotely and prints the returned results dict.
    ``--f16-path`` / ``--quants`` on the command line map onto these parameters
    (see the usage lines in the module docstring).
    """
    sweep = quant_eval.remote(f16_path=f16_path, quants=quants)
    print(sweep)
|
|