Spaces:
Sleeping
Sleeping
| """Quantization-penalty test: eval the SAME fine-tuned weights at several precisions. | |
| Our Q4_K_M fine-tunes underperform base, and v1's misses were the "206"/`2062` | |
| dropped-year-digit bug — a classic low-bit quantization artifact. This serves an | |
| on-volume fp16 GGUF (from a past run) at f16 / Q8_0 / Q4_K_M and scores each on the | |
| same 28-example eval, so precision is the only variable. If higher precision restores | |
| schema validity + recall and kills the digit bug, quantization is the culprit. | |
| PYTHONUTF8=1 python -m modal run training/modal_quant_eval.py | |
| PYTHONUTF8=1 python -m modal run training/modal_quant_eval.py \ | |
| --f16-path /outputs/gemma-cal-staging-f16.gguf --quants f16,Q8_0 | |
| """ | |
| from __future__ import annotations | |
| import json | |
| import os | |
| import re | |
| import shutil | |
| import subprocess | |
| import time | |
| import urllib.request | |
| from pathlib import Path | |
| import modal | |
# Repository root: two directory levels above this file (the usage lines in the
# module docstring run it as training/modal_quant_eval.py).
REPO_ROOT = Path(__file__).resolve().parent.parent
# full-cuda image has BOTH llama-quantize and llama-server; clear its ENTRYPOINT so
# Modal can run python.
image = (
    modal.Image.from_registry("ghcr.io/ggml-org/llama.cpp:full-cuda", add_python="3.11")
    .entrypoint([])
    # Python deps, presumably for training/eval.py, which runs inside the container — confirm.
    .pip_install("requests", "pydantic>=2", "huggingface_hub")
    # Ship the repo to /root/repo (quant_eval runs training/eval.py from there),
    # excluding VCS data, bytecode caches, GGUF weights and local training artifacts.
    .add_local_dir(
        str(REPO_ROOT), "/root/repo",
        ignore=[".git", "**/__pycache__", "**/*.gguf", "training/outputs",
                "training/data/.smcalflow_cache"],
    )
)
# Every function registered on this app runs in the image above.
app = modal.App("imessage-cal-quant-eval", image=image)
# Persistent volume holding the fp16 GGUF(s). quant_eval reads and writes under
# /outputs, so this is presumably mounted there by its (not visible here)
# @app.function decorator — confirm.
outputs = modal.Volume.from_name("imessage-cal-outputs")
def quant_eval(f16_path: str = "/outputs/gemma-cal-f16.gguf",
               quants: str = "f16,Q8_0,Q4_K_M") -> dict:
    """Serve the same fine-tuned weights at several precisions and score each one.

    For every precision in ``quants``:
      * ``f16`` is served directly from ``f16_path``;
      * any other value is passed to ``llama-quantize`` as the target type, which
        writes a scratch GGUF under /outputs that is deleted afterwards.
    ``llama-server`` is started on the GGUF. Once ``/health`` returns 200,
    ``training/eval.py`` is run against the server's ``/v1`` endpoint, and the
    ``RESULTS_JSON:`` line it prints is parsed.

    NOTE(review): ``main`` calls ``quant_eval.remote`` and the paths live under
    /outputs, so this function is expected to carry an ``@app.function(...)``
    decorator that mounts ``outputs`` at /outputs. That decorator is not visible
    in this copy of the file — confirm it is present.

    Args:
        f16_path: Path to the fp16 GGUF on the mounted volume.
        quants: Comma-separated precisions, e.g. ``"f16,Q8_0,Q4_K_M"``.

    Returns:
        Mapping of precision -> parsed results dict. A precision is absent if its
        server never became ready or its eval printed no ``RESULTS_JSON:`` line.

    Raises:
        FileNotFoundError: if ``f16_path`` does not exist.
        subprocess.CalledProcessError: if ``llama-quantize`` fails.
    """
    workspace = "/root/repo"
    quantize = shutil.which("llama-quantize") or "/app/llama-quantize"
    server = shutil.which("llama-server") or "/app/llama-server"
    env = {**os.environ}
    # Let the binaries resolve shared libraries from their own directory and /app
    # (presumably where the llama.cpp image installs them — confirm).
    env["LD_LIBRARY_PATH"] = f"{os.path.dirname(server)}:/app:" + env.get("LD_LIBRARY_PATH", "")
    if not os.path.exists(f16_path):
        raise FileNotFoundError(f"{f16_path} not on the volume; `modal volume ls imessage-cal-outputs`")

    def _stop_server(proc: subprocess.Popen) -> None:
        """Terminate llama-server and wait for it to exit so port 8080 is free again."""
        proc.terminate()
        try:
            proc.wait(timeout=30)
        except subprocess.TimeoutExpired:
            proc.kill()
            proc.wait()

    requested = [x.strip() for x in quants.split(",") if x.strip()]
    results: dict[str, dict] = {}
    for q in requested:
        if q == "f16":
            gguf = f16_path
            tmp = None
        else:
            tmp = gguf = f"/outputs/_quanttest-{q}.gguf"
        # Outer try/finally: the scratch GGUF is removed even if quantization,
        # the server, or the eval step fails.
        try:
            if tmp is not None:
                print(f"\n[quant] {quantize} {f16_path} -> {gguf} ({q})", flush=True)
                subprocess.run([quantize, f16_path, gguf, q], env=env, check=True)
            proc = subprocess.Popen(
                [server, "-m", gguf, "--host", "127.0.0.1", "--port", "8080",
                 "-ngl", "999", "-c", "8192", "--jinja"], env=env,
            )
            # Inner try/finally: always stop (and reap) the server. Otherwise it
            # keeps holding port 8080 when the next precision's server starts.
            try:
                ready = False
                for i in range(900):
                    if proc.poll() is not None:
                        print(f"[quant] {q}: server exited early", flush=True)
                        break
                    try:
                        with urllib.request.urlopen("http://127.0.0.1:8080/health", timeout=5) as resp:
                            if resp.status == 200:
                                ready = True
                                print(f"[quant] {q}: ready after ~{i * 2}s", flush=True)
                                break
                    except Exception:  # noqa: BLE001 - best-effort poll: not listening yet / still loading
                        pass
                    # Back off after every non-ready poll, including a non-200
                    # reply that did not raise (avoids a busy loop).
                    time.sleep(2)

                if not ready:
                    print(f"[quant] {q}: server never became ready; skipping eval", flush=True)
                else:
                    env2 = {**env, "INFERENCE_BASE_URL": "http://127.0.0.1:8080/v1",
                            "MODEL_LABEL": f"{Path(f16_path).name}:{q}"}
                    ev = subprocess.run(["python3", "training/eval.py"], cwd=workspace, env=env2,
                                        capture_output=True, text=True)
                    print(f"\n========== PRECISION {q} ==========", flush=True)
                    print(ev.stdout, flush=True)
                    if ev.stderr:
                        print("STDERR:", ev.stderr[-2000:], flush=True)
                    m = re.search(r"RESULTS_JSON:\s*(\{.*\})", ev.stdout)
                    if m:
                        results[q] = json.loads(m.group(1))
                    else:
                        print(f"[quant] {q}: no RESULTS_JSON line in eval output "
                              f"(exit code {ev.returncode})", flush=True)
            finally:
                _stop_server(proc)
        finally:
            if tmp and os.path.exists(tmp):
                os.remove(tmp)  # keep the volume clean (scratch files)

    print("\n==================== QUANT SWEEP SUMMARY ====================", flush=True)
    print(f" (weights: {Path(f16_path).name})", flush=True)
    for q, res in results.items():
        print(f" {q:8s} validity={res.get('schema_validity')} "
              f"f1={res.get('event_f1')} recall={res.get('event_recall_start_exact')}", flush=True)
    missing = [q for q in requested if q not in results]
    if missing:
        print(f" no result: {', '.join(missing)}", flush=True)
    print(" base ref: validity=1.0 f1=0.977 recall=0.955", flush=True)
    return results
def main(f16_path: str = "/outputs/gemma-cal-f16.gguf", quants: str = "f16,Q8_0,Q4_K_M"):
    """Kick off the remote precision sweep and print the per-precision results dict.

    NOTE(review): ``modal run ... --f16-path ... --quants ...`` maps CLI flags onto
    these parameters only for an ``@app.local_entrypoint()``. That decorator is
    not visible in this copy of the file — confirm it is present.
    """
    sweep_results = quant_eval.remote(f16_path=f16_path, quants=quants)
    print(sweep_results)