Spaces:
Running
Running
| import argparse | |
| import json | |
| import os | |
| import shutil | |
| import time | |
| from typing import Dict, List, Tuple | |
| import numpy as np | |
| import torch | |
| from transformers import AutoModelForSequenceClassification, AutoTokenizer | |
| from config import CFG | |
| from data_loader import load_test_only | |
| from transformer_model import _checkpoint_to_dir | |
| def _dir_size_bytes(path: str) -> int: | |
| total = 0 | |
| for root, _, files in os.walk(path): | |
| for f in files: | |
| fp = os.path.join(root, f) | |
| try: | |
| total += os.path.getsize(fp) | |
| except OSError: | |
| pass | |
| return total | |
| def _mb(n_bytes: int) -> float: | |
| return float(n_bytes) / (1024.0 * 1024.0) | |
| def _copy_tokenizer_files(src_dir: str, dst_dir: str) -> List[str]: | |
| os.makedirs(dst_dir, exist_ok=True) | |
| whitelist = { | |
| "tokenizer.json", | |
| "tokenizer_config.json", | |
| "special_tokens_map.json", | |
| "vocab.txt", | |
| "merges.txt", | |
| "added_tokens.json", | |
| "sentencepiece.bpe.model", | |
| "spiece.model", | |
| "config.json", | |
| } | |
| copied: List[str] = [] | |
| for name in os.listdir(src_dir): | |
| src = os.path.join(src_dir, name) | |
| dst = os.path.join(dst_dir, name) | |
| if not os.path.isfile(src): | |
| continue | |
| if name in whitelist or name.startswith("tokenizer"): | |
| shutil.copy2(src, dst) | |
| copied.append(name) | |
| return copied | |
def _load_fp32_model(fp32_dir: str):
    """Load the fine-tuned FP32 classifier and its tokenizer from ``fp32_dir``.

    Returns:
        ``(model, tokenizer)``. The model is in eval mode and on the CPU, the
        same device ``_predict`` feeds inputs to.
    """
    model = AutoModelForSequenceClassification.from_pretrained(fp32_dir)
    tokenizer = AutoTokenizer.from_pretrained(fp32_dir)
    # nn.Module.to() and .eval() both mutate in place and return the module.
    model = model.to("cpu").eval()
    return model, tokenizer
| def _quantize_dynamic_int8(model_fp32: torch.nn.Module) -> torch.nn.Module: | |
| # Apple Silicon (ARM) requires qnnpack; x86 defaults to fbgemm which is unavailable on MPS. | |
| torch.backends.quantized.engine = "qnnpack" | |
| model_int8 = torch.quantization.quantize_dynamic( | |
| model_fp32, | |
| {torch.nn.Linear}, | |
| dtype=torch.qint8, | |
| ) | |
| model_int8.eval() | |
| return model_int8 | |
| def _batched(iterable: List[str], batch_size: int): | |
| for i in range(0, len(iterable), batch_size): | |
| yield iterable[i : i + batch_size] | |
def _predict(
    model: torch.nn.Module,
    tokenizer,
    texts: List[str],
    batch_size: int,
) -> np.ndarray:
    """Classify ``texts`` in batches and return the predicted class ids.

    Each batch is tokenized with truncation to ``CFG.max_length`` and padded
    to its longest item. Inputs are moved to the CPU and the model is run
    under ``torch.inference_mode()``. The prediction is the argmax over the
    last logits dimension.

    Returns:
        A 1-D ``int64`` array with one id per input text, in input order.
    """
    collected: List[int] = []
    with torch.inference_mode():
        for chunk in _batched(texts, batch_size):
            inputs = tokenizer(
                chunk,
                truncation=True,
                max_length=CFG.max_length,
                padding=True,
                return_tensors="pt",
            )
            inputs = {name: tensor.to("cpu") for name, tensor in inputs.items()}
            outputs = model(**inputs)
            chunk_preds = outputs.logits.argmax(dim=-1)
            collected.extend(chunk_preds.cpu().tolist())
    return np.asarray(collected, dtype=np.int64)
| def _accuracy( | |
| model: torch.nn.Module, | |
| tokenizer, | |
| X_test: List[str], | |
| y_test: List[int], | |
| batch_size: int = 32, | |
| ) -> float: | |
| y_pred = _predict(model, tokenizer, X_test, batch_size=batch_size) | |
| y_true = np.asarray(y_test, dtype=np.int64) | |
| return float((y_pred == y_true).mean()) | |
| def _benchmark_latency_ms( | |
| model: torch.nn.Module, | |
| tokenizer, | |
| sample_texts: List[str], | |
| batch_size: int, | |
| runs: int = 50, | |
| warmup: int = 5, | |
| ) -> float: | |
| per_text_ms: List[float] = [] | |
| for i in range(runs): | |
| t0 = time.perf_counter() | |
| _predict(model, tokenizer, sample_texts, batch_size=batch_size) | |
| dt = time.perf_counter() - t0 | |
| if i >= warmup: | |
| per_text_ms.append((dt / len(sample_texts)) * 1000.0) | |
| return float(np.median(per_text_ms)) | |
def _save_quantized_model(
    model_int8: torch.nn.Module,
    fp32_dir: str,
    int8_dir: str,
    checkpoint_dir_name: str,
) -> Dict:
    """Persist the INT8 model, its tokenizer files and a metadata JSON in ``int8_dir``.

    Writes ``model_int8.pt`` and copies the tokenizer/config files from
    ``fp32_dir``. It then compares the two directory sizes (in MB, measured
    before the metadata file is written) and records them in
    ``quantization_info.json``.

    Returns:
        ``{"model_path": ..., "info_path": ..., "info": {...}}``, where
        ``info`` is the dict that was written to ``quantization_info.json``.
    """
    os.makedirs(int8_dir, exist_ok=True)

    model_path = os.path.join(int8_dir, "model_int8.pt")
    # NOTE: this pickles the whole module object, not just a state_dict.
    # Loading it back needs full unpickling (torch.load with
    # weights_only=False on recent PyTorch), so only load trusted files.
    torch.save(model_int8, model_path)
    _copy_tokenizer_files(fp32_dir, int8_dir)

    fp32_mb = _mb(_dir_size_bytes(fp32_dir))
    int8_mb = _mb(_dir_size_bytes(int8_dir))
    ratio = 0.0
    if int8_mb > 0:
        ratio = float(fp32_mb) / float(int8_mb)

    info = dict(
        original_model=checkpoint_dir_name,
        quantization_type="dynamic_int8",
        original_size_mb=round(fp32_mb, 2),
        quantized_size_mb=round(int8_mb, 2),
        compression_ratio=round(ratio, 3),
    )

    info_path = os.path.join(int8_dir, "quantization_info.json")
    with open(info_path, "w", encoding="utf-8") as fh:
        json.dump(info, fh, indent=2)

    return {"model_path": model_path, "info_path": info_path, "info": info}
| def _print_table( | |
| fp32_size_mb: float, | |
| int8_size_mb: float, | |
| fp32_single_ms: float, | |
| int8_single_ms: float, | |
| fp32_batch16_ms: float, | |
| int8_batch16_ms: float, | |
| fp32_acc: float, | |
| int8_acc: float, | |
| ) -> None: | |
| size_change_pct = 100.0 * (1.0 - (int8_size_mb / fp32_size_mb)) if fp32_size_mb > 0 else 0.0 | |
| single_speedup = (fp32_single_ms / int8_single_ms) if int8_single_ms > 0 else 0.0 | |
| batch_speedup = (fp32_batch16_ms / int8_batch16_ms) if int8_batch16_ms > 0 else 0.0 | |
| acc_delta_pp = (int8_acc - fp32_acc) * 100.0 | |
| def line(a: str, b: str, c: str, d: str) -> str: | |
| return f"β {a:<15} β {b:<10} β {c:<11} β {d:<17} β" | |
| print("βββββββββββββββββββ¬βββββββββββββ¬ββββββββββββββ¬ββββββββββββββββββββ") | |
| print(line("Metric", "FP32 Model", "INT8 Model", "Change")) | |
| print("βββββββββββββββββββΌβββββββββββββΌββββββββββββββΌββββββββββββββββββββ€") | |
| print( | |
| line( | |
| "Model size", | |
| f"{fp32_size_mb:.1f} MB", | |
| f"{int8_size_mb:.1f} MB", | |
| f"-{size_change_pct:.1f}% smaller", | |
| ) | |
| ) | |
| print( | |
| line( | |
| "Single-text ms", | |
| f"{fp32_single_ms:.2f} ms", | |
| f"{int8_single_ms:.2f} ms", | |
| f"{single_speedup:.2f}x faster", | |
| ) | |
| ) | |
| print( | |
| line( | |
| "Batch-16 ms", | |
| f"{fp32_batch16_ms:.2f} ms", | |
| f"{int8_batch16_ms:.2f} ms", | |
| f"{batch_speedup:.2f}x faster", | |
| ) | |
| ) | |
| print( | |
| line( | |
| "Test accuracy", | |
| f"{fp32_acc * 100:.2f}%", | |
| f"{int8_acc * 100:.2f}%", | |
| f"{acc_delta_pp:+.2f} pp", | |
| ) | |
| ) | |
| print("βββββββββββββββββββ΄βββββββββββββ΄ββββββββββββββ΄ββββββββββββββββββββ") | |
def main() -> None:
    """CLI entry point: quantize a fine-tuned checkpoint and compare it with FP32.

    Steps:
    1. Load the FP32 model from ``CFG.models_dir/<dir_name>``.
    2. Apply dynamic INT8 quantization.
    3. Save the result to ``<dir_name>_int8``, unless ``--benchmark-only``
       is given.
    4. Print a table comparing on-disk size, latency and test accuracy.

    Raises:
        FileNotFoundError: If the FP32 model directory does not exist.
        ValueError: If the test split is empty.
    """
    import tempfile

    parser = argparse.ArgumentParser(description="Dynamic INT8 quantization for transformer inference on CPU.")
    parser.add_argument("--model", type=str, default="distilbert-base-uncased")
    parser.add_argument("--benchmark-only", action="store_true")
    args = parser.parse_args()

    dir_name = _checkpoint_to_dir(args.model)
    fp32_dir = os.path.join(CFG.models_dir, dir_name)
    int8_dir = os.path.join(CFG.models_dir, f"{dir_name}_int8")
    if not os.path.isdir(fp32_dir):
        raise FileNotFoundError(
            f"FP32 model directory not found: {fp32_dir}\n"
            f"Expected a fine-tuned model saved via save_pretrained() under saved_models/."
        )

    print(f"[Quantize] Loading FP32 model from: {fp32_dir}")
    model_fp32, tokenizer_fp32 = _load_fp32_model(fp32_dir)
    print("[Quantize] Applying dynamic INT8 quantization (Linear layers)...")
    model_int8 = _quantize_dynamic_int8(model_fp32)

    if args.benchmark_only:
        # Nothing is persisted in this mode, so int8_dir may be missing, empty
        # or stale, and measuring it reported a bogus size. Save into a
        # throwaway directory instead so the size matches a real save.
        with tempfile.TemporaryDirectory() as scratch_dir:
            _save_quantized_model(model_int8, fp32_dir, scratch_dir, checkpoint_dir_name=dir_name)
            int8_size_mb = _mb(_dir_size_bytes(scratch_dir))
    else:
        saved = _save_quantized_model(model_int8, fp32_dir, int8_dir, checkpoint_dir_name=dir_name)
        print(f"[Quantize] Saved INT8 model -> {saved['model_path']}")
        print(f"[Quantize] Saved metadata -> {saved['info_path']}")
        int8_size_mb = _mb(_dir_size_bytes(int8_dir))
    fp32_size_mb = _mb(_dir_size_bytes(fp32_dir))

    X_test, y_test = load_test_only()
    if len(X_test) == 0:
        raise ValueError("Test split is empty; cannot benchmark or evaluate.")

    # Fixed-seed subsample (at most 100 texts) keeps latency runs comparable.
    rng = np.random.default_rng(CFG.seed)
    sample_idx = rng.choice(len(X_test), size=min(100, len(X_test)), replace=False).tolist()
    sample_texts = [X_test[i] for i in sample_idx]

    print("[Benchmark] Measuring latency (median ms per text)...")
    fp32_single = _benchmark_latency_ms(model_fp32, tokenizer_fp32, sample_texts, batch_size=1)
    int8_single = _benchmark_latency_ms(model_int8, tokenizer_fp32, sample_texts, batch_size=1)
    fp32_b16 = _benchmark_latency_ms(model_fp32, tokenizer_fp32, sample_texts, batch_size=16)
    int8_b16 = _benchmark_latency_ms(model_int8, tokenizer_fp32, sample_texts, batch_size=16)

    print(f"[Eval] Computing test accuracy on {len(X_test):,} examples...")
    fp32_acc = _accuracy(model_fp32, tokenizer_fp32, X_test, y_test, batch_size=32)
    int8_acc = _accuracy(model_int8, tokenizer_fp32, X_test, y_test, batch_size=32)

    _print_table(
        fp32_size_mb=fp32_size_mb,
        int8_size_mb=int8_size_mb,
        fp32_single_ms=fp32_single,
        int8_single_ms=int8_single,
        fp32_batch16_ms=fp32_b16,
        int8_batch16_ms=int8_b16,
        fp32_acc=fp32_acc,
        int8_acc=int8_acc,
    )


if __name__ == "__main__":
    main()