"""Dynamic INT8 quantization and CPU benchmarking for a fine-tuned classifier.

Loads an FP32 sequence-classification model saved with ``save_pretrained()``,
applies PyTorch dynamic INT8 quantization to its ``nn.Linear`` layers,
optionally saves the quantized model next to the original, and prints a
size / latency / accuracy comparison table.
"""

import argparse
import io
import json
import os
import platform
import shutil
import time
from typing import Dict, List, Tuple

import numpy as np
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

from config import CFG
from data_loader import load_test_only
from transformer_model import _checkpoint_to_dir


def _dir_size_bytes(path: str) -> int:
    """Return the total size in bytes of all files under ``path``, recursively.

    Files that cannot be stat'ed are skipped (best-effort). A missing ``path``
    yields 0, because ``os.walk`` produces nothing for it.
    """
    total = 0
    for root, _, files in os.walk(path):
        for f in files:
            fp = os.path.join(root, f)
            try:
                total += os.path.getsize(fp)
            except OSError:
                # Best-effort: a file removed or unreadable mid-walk is ignored.
                pass
    return total


def _mb(n_bytes: int) -> float:
    """Convert a byte count to MB, where 1 MB = 1024 * 1024 bytes."""
    return float(n_bytes) / (1024.0 * 1024.0)


def _serialized_size_bytes(model: torch.nn.Module) -> int:
    """Return the byte size of ``torch.save(model)`` without touching disk."""
    buffer = io.BytesIO()
    torch.save(model, buffer)
    return buffer.getbuffer().nbytes


def _copy_tokenizer_files(src_dir: str, dst_dir: str) -> List[str]:
    """Copy tokenizer and model-config files from ``src_dir`` into ``dst_dir``.

    A regular file is copied when its name is in a fixed whitelist or starts
    with ``"tokenizer"``. Subdirectories are ignored. ``dst_dir`` is created
    if needed.

    Returns:
        Names of the files that were copied.
    """
    os.makedirs(dst_dir, exist_ok=True)
    whitelist = {
        "tokenizer.json",
        "tokenizer_config.json",
        "special_tokens_map.json",
        "vocab.txt",
        "merges.txt",
        "added_tokens.json",
        "sentencepiece.bpe.model",
        "spiece.model",
        "config.json",
    }
    copied: List[str] = []
    for name in os.listdir(src_dir):
        src = os.path.join(src_dir, name)
        dst = os.path.join(dst_dir, name)
        if not os.path.isfile(src):
            continue
        if name in whitelist or name.startswith("tokenizer"):
            shutil.copy2(src, dst)
            copied.append(name)
    return copied


def _load_fp32_model(fp32_dir: str):
    """Load the FP32 classifier and its tokenizer from ``fp32_dir``.

    The model is put in eval mode and moved to CPU, where dynamic quantization
    and the benchmarks run.

    Returns:
        ``(model, tokenizer)``.
    """
    model = AutoModelForSequenceClassification.from_pretrained(fp32_dir)
    tokenizer = AutoTokenizer.from_pretrained(fp32_dir)
    model.eval()
    model.to("cpu")
    return model, tokenizer


def _select_quantized_engine() -> str:
    """Pick a quantized backend that this PyTorch build actually supports.

    ARM CPUs (e.g. Apple Silicon "arm64", Linux "aarch64") prefer QNNPACK.
    Other CPUs prefer FBGEMM / x86. Each choice is limited to entries in
    ``torch.backends.quantized.supported_engines``, so we never request a
    backend that is not compiled in.

    Raises:
        RuntimeError: if none of the known engines is supported.
    """
    supported = list(torch.backends.quantized.supported_engines)
    machine = platform.machine().lower()
    if machine.startswith(("arm", "aarch64")):
        preferences = ("qnnpack", "fbgemm", "x86")
    else:
        preferences = ("fbgemm", "x86", "qnnpack")
    for engine in preferences:
        if engine in supported:
            return engine
    raise RuntimeError(
        "No supported quantized engine found for dynamic INT8 quantization "
        f"(machine={machine!r}, supported_engines={supported})."
    )


def _quantize_dynamic_int8(model_fp32: torch.nn.Module) -> torch.nn.Module:
    """Return a dynamically INT8-quantized copy of ``model_fp32``.

    Only ``nn.Linear`` layers are quantized, with ``qint8`` weights.
    ``quantize_dynamic`` defaults to ``inplace=False``, so ``model_fp32`` stays
    FP32 and can still be benchmarked. Note that this sets the process-wide
    ``torch.backends.quantized.engine``.
    """
    # Quantized kernels run on CPU; choose a backend present in this build
    # instead of hard-coding one (a missing engine raises RuntimeError).
    torch.backends.quantized.engine = _select_quantized_engine()
    model_int8 = torch.quantization.quantize_dynamic(
        model_fp32,
        {torch.nn.Linear},
        dtype=torch.qint8,
    )
    model_int8.eval()
    return model_int8


def _batched(iterable: List[str], batch_size: int):
    """Yield consecutive slices of ``iterable`` of length ``batch_size``.

    The last slice may be shorter.
    """
    for i in range(0, len(iterable), batch_size):
        yield iterable[i : i + batch_size]


def _predict(
    model: torch.nn.Module,
    tokenizer,
    texts: List[str],
    batch_size: int,
) -> np.ndarray:
    """Classify ``texts`` on CPU and return the argmax class index per text.

    Each batch is truncated to ``CFG.max_length`` tokens and padded to its
    longest sequence.

    Returns:
        int64 array of predicted labels, one per input text (empty for no texts).
    """
    preds: List[int] = []
    with torch.inference_mode():
        for batch in _batched(texts, batch_size):
            enc = tokenizer(
                batch,
                truncation=True,
                max_length=CFG.max_length,
                padding=True,
                return_tensors="pt",
            )
            enc = {k: v.to("cpu") for k, v in enc.items()}
            logits = model(**enc).logits
            batch_preds = torch.argmax(logits, dim=-1).cpu().numpy().tolist()
            preds.extend(batch_preds)
    return np.asarray(preds, dtype=np.int64)


def _accuracy(
    model: torch.nn.Module,
    tokenizer,
    X_test: List[str],
    y_test: List[int],
    batch_size: int = 32,
) -> float:
    """Return the fraction of ``X_test`` whose prediction equals ``y_test``.

    Raises:
        ValueError: if the inputs are empty or differ in length; either case
            would otherwise produce NaN or a broadcasting error.
    """
    if len(X_test) != len(y_test):
        raise ValueError(
            f"X_test and y_test length mismatch: {len(X_test)} != {len(y_test)}"
        )
    if len(X_test) == 0:
        raise ValueError("Cannot compute accuracy on an empty test set.")
    y_pred = _predict(model, tokenizer, X_test, batch_size=batch_size)
    y_true = np.asarray(y_test, dtype=np.int64)
    return float((y_pred == y_true).mean())


def _benchmark_latency_ms(
    model: torch.nn.Module,
    tokenizer,
    sample_texts: List[str],
    batch_size: int,
    runs: int = 50,
    warmup: int = 5,
) -> float:
    """Return the median wall-clock latency per text, in milliseconds.

    Runs ``_predict`` over all of ``sample_texts`` ``runs`` times and discards
    the first ``warmup`` runs. Each kept run contributes total time divided by
    the number of texts. Timings include tokenization.

    Raises:
        ValueError: if ``sample_texts`` is empty, or no runs remain after warmup.
    """
    if len(sample_texts) == 0:
        raise ValueError("sample_texts must not be empty.")
    if warmup < 0 or runs <= warmup:
        raise ValueError(f"Need runs > warmup >= 0 (got runs={runs}, warmup={warmup}).")
    per_text_ms: List[float] = []
    for i in range(runs):
        t0 = time.perf_counter()
        _predict(model, tokenizer, sample_texts, batch_size=batch_size)
        dt = time.perf_counter() - t0
        if i >= warmup:
            per_text_ms.append((dt / len(sample_texts)) * 1000.0)
    return float(np.median(per_text_ms))


def _save_quantized_model(
    model_int8: torch.nn.Module,
    fp32_dir: str,
    int8_dir: str,
    checkpoint_dir_name: str,
) -> Dict:
    """Persist the INT8 model, tokenizer files and metadata into ``int8_dir``.

    The whole module is pickled with ``torch.save`` to ``model_int8.pt``.
    Loading it back therefore needs the same code available, and presumably a
    compatible quantized engine. The engine used is recorded in the metadata.
    Sizes are measured on the directories before the metadata file is written.

    Returns:
        Dict with ``model_path``, ``info_path`` and the ``info`` metadata.
    """
    os.makedirs(int8_dir, exist_ok=True)
    model_path = os.path.join(int8_dir, "model_int8.pt")
    torch.save(model_int8, model_path)
    _copy_tokenizer_files(fp32_dir, int8_dir)
    original_size_mb = _mb(_dir_size_bytes(fp32_dir))
    quantized_size_mb = _mb(_dir_size_bytes(int8_dir))
    compression_ratio = (
        float(original_size_mb) / float(quantized_size_mb) if quantized_size_mb > 0 else 0.0
    )
    info = {
        "original_model": checkpoint_dir_name,
        "quantization_type": "dynamic_int8",
        "quantization_engine": torch.backends.quantized.engine,
        "original_size_mb": round(original_size_mb, 2),
        "quantized_size_mb": round(quantized_size_mb, 2),
        "compression_ratio": round(compression_ratio, 3),
    }
    info_path = os.path.join(int8_dir, "quantization_info.json")
    with open(info_path, "w", encoding="utf-8") as f:
        json.dump(info, f, indent=2)
    return {"model_path": model_path, "info_path": info_path, "info": info}


def _print_table(
    fp32_size_mb: float,
    int8_size_mb: float,
    fp32_single_ms: float,
    int8_single_ms: float,
    fp32_batch16_ms: float,
    int8_batch16_ms: float,
    fp32_acc: float,
    int8_acc: float,
) -> None:
    """Print a box-drawn FP32-vs-INT8 comparison table to stdout.

    The Change column reports size reduction, speed ratio (faster or slower)
    and accuracy delta in percentage points. Degenerate (zero) inputs show "n/a".
    """
    if fp32_size_mb > 0:
        size_change_pct = 100.0 * (1.0 - (int8_size_mb / fp32_size_mb))
        if size_change_pct >= 0:
            size_label = f"-{size_change_pct:.1f}% smaller"
        else:
            # INT8 artifact is larger than FP32; don't print "--x% smaller".
            size_label = f"+{-size_change_pct:.1f}% larger"
    else:
        size_label = "n/a"

    def speed_label(ref_ms: float, new_ms: float) -> str:
        """Describe new_ms relative to ref_ms as 'Nx faster' or 'Nx slower'."""
        if ref_ms <= 0 or new_ms <= 0:
            return "n/a"
        ratio = ref_ms / new_ms
        if ratio >= 1.0:
            return f"{ratio:.2f}x faster"
        return f"{1.0 / ratio:.2f}x slower"

    acc_delta_pp = (int8_acc - fp32_acc) * 100.0

    def line(a: str, b: str, c: str, d: str) -> str:
        return f"│ {a:<15} │ {b:<10} │ {c:<11} │ {d:<17} │"

    print("┌─────────────────┬────────────┬─────────────┬───────────────────┐")
    print(line("Metric", "FP32 Model", "INT8 Model", "Change"))
    print("├─────────────────┼────────────┼─────────────┼───────────────────┤")
    print(line("Model size", f"{fp32_size_mb:.1f} MB", f"{int8_size_mb:.1f} MB", size_label))
    print(
        line(
            "Single-text ms",
            f"{fp32_single_ms:.2f} ms",
            f"{int8_single_ms:.2f} ms",
            speed_label(fp32_single_ms, int8_single_ms),
        )
    )
    print(
        line(
            "Batch-16 ms",
            f"{fp32_batch16_ms:.2f} ms",
            f"{int8_batch16_ms:.2f} ms",
            speed_label(fp32_batch16_ms, int8_batch16_ms),
        )
    )
    print(
        line(
            "Test accuracy",
            f"{fp32_acc * 100:.2f}%",
            f"{int8_acc * 100:.2f}%",
            f"{acc_delta_pp:+.2f} pp",
        )
    )
    print("└─────────────────┴────────────┴─────────────┴───────────────────┘")


def main() -> None:
    """CLI entry point: quantize, optionally save, benchmark and evaluate.

    The FP32 model is read from ``CFG.models_dir/<dir_name>``, where
    ``dir_name`` comes from ``_checkpoint_to_dir(--model)``. The INT8 output
    goes to ``<dir_name>_int8`` alongside it. With ``--benchmark-only`` nothing
    is written to disk.
    """
    parser = argparse.ArgumentParser(description="Dynamic INT8 quantization for transformer inference on CPU.")
    parser.add_argument("--model", type=str, default="distilbert-base-uncased")
    parser.add_argument("--benchmark-only", action="store_true")
    args = parser.parse_args()

    dir_name = _checkpoint_to_dir(args.model)
    fp32_dir = os.path.join(CFG.models_dir, dir_name)
    int8_dir = os.path.join(CFG.models_dir, f"{dir_name}_int8")

    if not os.path.isdir(fp32_dir):
        raise FileNotFoundError(
            f"FP32 model directory not found: {fp32_dir}\n"
            f"Expected a fine-tuned model saved via save_pretrained() under saved_models/."
        )

    print(f"[Quantize] Loading FP32 model from: {fp32_dir}")
    model_fp32, tokenizer_fp32 = _load_fp32_model(fp32_dir)

    print("[Quantize] Applying dynamic INT8 quantization (Linear layers)...")
    model_int8 = _quantize_dynamic_int8(model_fp32)

    if not args.benchmark_only:
        saved = _save_quantized_model(model_int8, fp32_dir, int8_dir, checkpoint_dir_name=dir_name)
        print(f"[Quantize] Saved INT8 model -> {saved['model_path']}")
        print(f"[Quantize] Saved metadata -> {saved['info_path']}")

    # NOTE(review): assumes load_test_only() returns (texts, int labels) as
    # indexable sequences of equal length — confirm against data_loader.
    X_test, y_test = load_test_only()
    if len(X_test) == 0:
        raise ValueError("Test set is empty; nothing to benchmark or evaluate.")

    # Fixed-seed subsample so latency numbers are reproducible across runs.
    rng = np.random.default_rng(CFG.seed)
    sample_idx = rng.choice(len(X_test), size=min(100, len(X_test)), replace=False).tolist()
    sample_texts = [X_test[i] for i in sample_idx]

    print("[Benchmark] Measuring latency (median ms per text)...")
    fp32_single = _benchmark_latency_ms(model_fp32, tokenizer_fp32, sample_texts, batch_size=1)
    int8_single = _benchmark_latency_ms(model_int8, tokenizer_fp32, sample_texts, batch_size=1)
    fp32_b16 = _benchmark_latency_ms(model_fp32, tokenizer_fp32, sample_texts, batch_size=16)
    int8_b16 = _benchmark_latency_ms(model_int8, tokenizer_fp32, sample_texts, batch_size=16)

    print(f"[Eval] Computing test accuracy on {len(X_test):,} examples...")
    fp32_acc = _accuracy(model_fp32, tokenizer_fp32, X_test, y_test, batch_size=32)
    int8_acc = _accuracy(model_int8, tokenizer_fp32, X_test, y_test, batch_size=32)

    fp32_size_mb = _mb(_dir_size_bytes(fp32_dir))
    if os.path.isfile(os.path.join(int8_dir, "model_int8.pt")):
        int8_size_mb = _mb(_dir_size_bytes(int8_dir))
    else:
        # --benchmark-only without a previously saved INT8 model: measuring an
        # empty/missing directory would report 0 MB, so use the serialized size
        # of the in-memory model (weights only, no tokenizer files).
        print(
            f"[Benchmark] No saved INT8 model in {int8_dir}; "
            "reporting in-memory serialized model size instead."
        )
        int8_size_mb = _mb(_serialized_size_bytes(model_int8))

    _print_table(
        fp32_size_mb=fp32_size_mb,
        int8_size_mb=int8_size_mb,
        fp32_single_ms=fp32_single,
        int8_single_ms=int8_single,
        fp32_batch16_ms=fp32_b16,
        int8_batch16_ms=int8_b16,
        fp32_acc=fp32_acc,
        int8_acc=int8_acc,
    )


if __name__ == "__main__":
    main()