nexa-classify-api / quantize_model.py
Prototype6239's picture
Upload folder using huggingface_hub
a229747 verified
Raw
History Blame Contribute Delete
9.58 kB
import argparse
import json
import os
import platform
import shutil
import tempfile
import time
from typing import Dict, List, Optional, Tuple

import numpy as np
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

from config import CFG
from data_loader import load_test_only
from transformer_model import _checkpoint_to_dir
def _dir_size_bytes(path: str) -> int:
    """Return the summed size in bytes of every file found beneath ``path``.

    The tree is walked recursively with ``os.walk``. A file whose size cannot
    be read (``OSError``, e.g. it disappeared mid-walk) contributes 0 instead
    of aborting the measurement. A path that does not exist yields 0.
    """

    def _size_or_zero(file_path: str) -> int:
        # Best-effort: unreadable entries are counted as empty.
        try:
            return os.path.getsize(file_path)
        except OSError:
            return 0

    return sum(
        _size_or_zero(os.path.join(dirpath, name))
        for dirpath, _subdirs, names in os.walk(path)
        for name in names
    )
def _mb(n_bytes: int) -> float:
    """Convert a byte count to megabytes, using 1 MB = 1024 * 1024 bytes."""
    bytes_per_mb = 1024.0 * 1024.0
    as_float = float(n_bytes)
    return as_float / bytes_per_mb
def _copy_tokenizer_files(src_dir: str, dst_dir: str) -> List[str]:
    """Copy tokenizer artifacts (and ``config.json``) from ``src_dir`` to ``dst_dir``.

    A top-level regular file is copied when its name is in the known set below
    or starts with ``"tokenizer"``. ``config.json`` is included so the
    destination also carries the model config. Sub-directories are ignored.
    ``dst_dir`` is created if missing, and same-named files there are
    overwritten (``shutil.copy2`` also preserves metadata).

    Returns:
        The names of the copied files, in sorted order.
    """
    whitelist = frozenset(
        {
            "tokenizer.json",
            "tokenizer_config.json",
            "special_tokens_map.json",
            "vocab.txt",
            # BPE tokenizers (GPT-2 / RoBERTa style) need vocab.json alongside merges.txt.
            "vocab.json",
            "merges.txt",
            "added_tokens.json",
            "sentencepiece.bpe.model",
            "spiece.model",
            "spm.model",
            "config.json",
        }
    )
    os.makedirs(dst_dir, exist_ok=True)
    copied: List[str] = []
    # Sorted so the returned list does not depend on filesystem listing order.
    for name in sorted(os.listdir(src_dir)):
        src = os.path.join(src_dir, name)
        if not os.path.isfile(src):
            continue
        if name in whitelist or name.startswith("tokenizer"):
            shutil.copy2(src, os.path.join(dst_dir, name))
            copied.append(name)
    return copied
def _load_fp32_model(fp32_dir: str):
    """Load the saved FP32 sequence classifier and its tokenizer from ``fp32_dir``.

    The model is put in eval mode and moved to the CPU, where this script's
    quantization and benchmarking run.

    Returns:
        A ``(model, tokenizer)`` pair.
    """
    classifier = AutoModelForSequenceClassification.from_pretrained(fp32_dir)
    text_tokenizer = AutoTokenizer.from_pretrained(fp32_dir)
    classifier.eval()
    classifier.to("cpu")
    return classifier, text_tokenizer
def _select_quantized_engine(preferred: Optional[str] = None) -> str:
    """Return a quantized backend name that this PyTorch build supports.

    If ``preferred`` is given it must appear in
    ``torch.backends.quantized.supported_engines``. Otherwise ARM hosts
    (e.g. Apple Silicon) prefer ``qnnpack``, and other hosts prefer the x86
    kernels (``x86`` / ``fbgemm``), falling back to any available engine.

    Raises:
        ValueError: if ``preferred`` is not supported by this build.
        RuntimeError: if the build exposes no usable quantized engine.
    """
    supported = [e for e in torch.backends.quantized.supported_engines if e != "none"]
    if preferred is not None:
        if preferred not in supported:
            raise ValueError(
                f"Quantized engine {preferred!r} is not supported by this PyTorch build; "
                f"available engines: {supported}"
            )
        return preferred

    machine = platform.machine().lower()
    is_arm = machine.startswith(("arm", "aarch"))
    order = ("qnnpack", "x86", "fbgemm") if is_arm else ("x86", "fbgemm", "qnnpack")
    for engine in order:
        if engine in supported:
            return engine
    if supported:
        return supported[0]
    raise RuntimeError(
        "No quantized engine is available in this PyTorch build "
        f"(supported_engines={list(torch.backends.quantized.supported_engines)})."
    )


def _quantize_dynamic_int8(
    model_fp32: torch.nn.Module,
    engine: Optional[str] = None,
) -> torch.nn.Module:
    """Return a dynamically INT8-quantized copy of ``model_fp32``.

    Every ``torch.nn.Linear`` is replaced by a dynamic-quantized equivalent
    (``qint8`` weights, activations quantized on the fly). ``model_fp32`` is
    not modified, because ``quantize_dynamic`` works on a copy by default. The
    result is put in eval mode.

    Args:
        model_fp32: The float model to quantize.
        engine: Optional quantized backend name to force. When ``None``, one is
            chosen for the host CPU from the engines this build supports.

    Raises:
        ValueError: if ``engine`` is given but unsupported by this build.
        RuntimeError: if no quantized engine is available at all.
    """
    # Process-wide setting. It is intentionally left in place: the packed INT8
    # weights are built for this backend, and later inference with the
    # returned model runs on it.
    torch.backends.quantized.engine = _select_quantized_engine(engine)
    model_int8 = torch.quantization.quantize_dynamic(
        model_fp32,
        {torch.nn.Linear},
        dtype=torch.qint8,
    )
    model_int8.eval()
    return model_int8
def _batched(iterable: List[str], batch_size: int):
    """Lazily yield consecutive slices of ``iterable`` with at most ``batch_size`` items.

    The last slice may be shorter, and an empty input yields nothing.
    """
    total = len(iterable)
    yield from (iterable[lo : lo + batch_size] for lo in range(0, total, batch_size))
def _predict(
    model: torch.nn.Module,
    tokenizer,
    texts: List[str],
    batch_size: int,
) -> np.ndarray:
    """Return the argmax class id for every text, as an int64 NumPy array.

    ``texts`` are processed in chunks of ``batch_size``. Each chunk is
    tokenized with truncation to ``CFG.max_length`` and padding to the longest
    text in the chunk, moved to the CPU, and run through ``model`` under
    ``torch.inference_mode``. The output order matches ``texts``, and an empty
    input gives an empty array.
    """

    def _classify(chunk: List[str]) -> List[int]:
        encoded = tokenizer(
            chunk,
            truncation=True,
            max_length=CFG.max_length,
            padding=True,
            return_tensors="pt",
        )
        inputs = {name: tensor.to("cpu") for name, tensor in encoded.items()}
        scores = model(**inputs).logits
        return scores.argmax(dim=-1).cpu().numpy().tolist()

    with torch.inference_mode():
        labels = [
            label
            for chunk in _batched(texts, batch_size)
            for label in _classify(chunk)
        ]
    return np.asarray(labels, dtype=np.int64)
def _accuracy(
    model: torch.nn.Module,
    tokenizer,
    X_test: List[str],
    y_test: List[int],
    batch_size: int = 32,
) -> float:
    """Return the fraction of ``X_test`` whose predicted class equals ``y_test``.

    Predictions come from ``_predict`` with the given ``batch_size``.

    Raises:
        ValueError: if ``X_test`` and ``y_test`` differ in length, or are empty.
    """
    # Validate before running the model over the whole set. Otherwise a
    # length mismatch surfaces only after inference, as a numpy broadcast error
    # (or a silent 0.0 on older numpy), and an empty set would give NaN.
    if len(X_test) != len(y_test):
        raise ValueError(
            f"X_test and y_test must have the same length "
            f"(got {len(X_test)} and {len(y_test)})."
        )
    if len(X_test) == 0:
        raise ValueError("Cannot compute accuracy on an empty test set.")
    y_pred = _predict(model, tokenizer, X_test, batch_size=batch_size)
    y_true = np.asarray(y_test, dtype=np.int64)
    return float((y_pred == y_true).mean())
def _benchmark_latency_ms(
    model: torch.nn.Module,
    tokenizer,
    sample_texts: List[str],
    batch_size: int,
    runs: int = 50,
    warmup: int = 5,
) -> float:
    """Return the median per-text latency of ``_predict``, in milliseconds.

    Each of the ``runs`` iterations times one full ``_predict`` pass over
    ``sample_texts`` (tokenization included) with ``time.perf_counter`` and
    divides by the number of texts. The first ``warmup`` iterations run but are
    excluded from the median.

    Raises:
        ValueError: if ``sample_texts`` is empty (per-text latency is
            undefined), or ``runs`` leaves no timed iteration after ``warmup``.
    """
    n_texts = len(sample_texts)
    if n_texts == 0:
        raise ValueError("sample_texts must contain at least one text to benchmark.")
    if runs <= max(warmup, 0):
        raise ValueError(
            f"runs ({runs}) must be greater than warmup ({warmup}) "
            "so that at least one run is timed."
        )
    per_text_ms: List[float] = []
    for i in range(runs):
        t0 = time.perf_counter()
        _predict(model, tokenizer, sample_texts, batch_size=batch_size)
        dt = time.perf_counter() - t0
        # Warmup passes typically absorb one-off costs (lazy init, allocator
        # growth), so they are executed but not recorded.
        if i >= warmup:
            per_text_ms.append((dt / n_texts) * 1000.0)
    return float(np.median(per_text_ms))
def _save_quantized_model(
    model_int8: torch.nn.Module,
    fp32_dir: str,
    int8_dir: str,
    checkpoint_dir_name: str,
) -> Dict:
    """Write the INT8 model, tokenizer files and a JSON size summary into ``int8_dir``.

    The full module is saved with ``torch.save`` as ``model_int8.pt``, and the
    files chosen by ``_copy_tokenizer_files`` are copied over from
    ``fp32_dir``. Both directory sizes (MB) and their ratio are then written to
    ``quantization_info.json``.

    The INT8 size is measured before ``quantization_info.json`` is written.
    However, any files already in ``int8_dir`` (for example from a previous
    run) are included. The compression ratio is 0.0 when the INT8 directory
    measures empty.

    Returns:
        A dict with ``model_path``, ``info_path`` and the ``info`` payload.
    """
    os.makedirs(int8_dir, exist_ok=True)
    model_path = os.path.join(int8_dir, "model_int8.pt")
    # NOTE(review): this pickles the whole nn.Module. Reloading presumably needs
    # torch.load(..., weights_only=False) and must only be done on trusted files.
    torch.save(model_int8, model_path)
    _copy_tokenizer_files(fp32_dir, int8_dir)

    fp32_mb = _mb(_dir_size_bytes(fp32_dir))
    int8_mb = _mb(_dir_size_bytes(int8_dir))
    if int8_mb > 0:
        ratio = float(fp32_mb) / float(int8_mb)
    else:
        ratio = 0.0

    info_path = os.path.join(int8_dir, "quantization_info.json")
    info = dict(
        original_model=checkpoint_dir_name,
        quantization_type="dynamic_int8",
        original_size_mb=round(fp32_mb, 2),
        quantized_size_mb=round(int8_mb, 2),
        compression_ratio=round(ratio, 3),
    )
    with open(info_path, "w", encoding="utf-8") as handle:
        json.dump(info, handle, indent=2)
    return {"model_path": model_path, "info_path": info_path, "info": info}
def _print_table(
    fp32_size_mb: float,
    int8_size_mb: float,
    fp32_single_ms: float,
    int8_single_ms: float,
    fp32_batch16_ms: float,
    int8_batch16_ms: float,
    fp32_acc: float,
    int8_acc: float,
) -> None:
    """Print a box-drawn FP32-vs-INT8 comparison table to stdout.

    Latencies are milliseconds per text (single-text and batch-16 modes), and
    accuracies are fractions. The "Change" column shows the relative size
    difference ("-X% smaller" / "+X% larger"), the speedup ("Nx faster" /
    "Nx slower"), and the accuracy delta in percentage points. A ratio whose
    inputs are not positive is shown as "n/a".
    """
    # Content width of each column; cells are padded by one space on each side.
    widths = (15, 10, 11, 17)

    def rule(left: str, junction: str, right: str) -> str:
        # Border derived from `widths` so it always lines up with the rows.
        return left + junction.join("─" * (w + 2) for w in widths) + right

    def line(*cells: str) -> str:
        padded = (f" {cell:<{w}} " for cell, w in zip(cells, widths))
        return "│" + "│".join(padded) + "│"

    def size_change() -> str:
        if fp32_size_mb <= 0:
            return "n/a"
        pct = 100.0 * (1.0 - (int8_size_mb / fp32_size_mb))
        if pct >= 0:
            return f"-{pct:.1f}% smaller"
        return f"+{-pct:.1f}% larger"

    def speed_change(fp32_ms: float, int8_ms: float) -> str:
        if fp32_ms <= 0 or int8_ms <= 0:
            return "n/a"
        speedup = fp32_ms / int8_ms
        if speedup >= 1.0:
            return f"{speedup:.2f}x faster"
        return f"{1.0 / speedup:.2f}x slower"

    acc_delta_pp = (int8_acc - fp32_acc) * 100.0
    rows = [
        ("Model size", f"{fp32_size_mb:.1f} MB", f"{int8_size_mb:.1f} MB", size_change()),
        (
            "Single-text ms",
            f"{fp32_single_ms:.2f} ms",
            f"{int8_single_ms:.2f} ms",
            speed_change(fp32_single_ms, int8_single_ms),
        ),
        (
            "Batch-16 ms",
            f"{fp32_batch16_ms:.2f} ms",
            f"{int8_batch16_ms:.2f} ms",
            speed_change(fp32_batch16_ms, int8_batch16_ms),
        ),
        (
            "Test accuracy",
            f"{fp32_acc * 100:.2f}%",
            f"{int8_acc * 100:.2f}%",
            f"{acc_delta_pp:+.2f} pp",
        ),
    ]

    print(rule("┌", "┬", "┐"))
    print(line("Metric", "FP32 Model", "INT8 Model", "Change"))
    print(rule("├", "┼", "┤"))
    for row in rows:
        print(line(*row))
    print(rule("└", "┴", "┘"))
def main() -> None:
    """CLI entry point: quantize a saved FP32 classifier to dynamic INT8 and benchmark it.

    The FP32 model is loaded from ``os.path.join(CFG.models_dir, dir_name)``,
    where ``dir_name`` comes from ``_checkpoint_to_dir(--model)``. Its Linear
    layers are quantized. Unless ``--benchmark-only`` is given, the result is
    saved to ``<dir_name>_int8``. Finally, size, latency and test-accuracy
    comparisons are printed.

    With ``--benchmark-only`` nothing is written under ``CFG.models_dir``. The
    INT8 size is then taken from a temporary serialization of the model being
    benchmarked, so it never reflects an empty or stale ``_int8`` directory.

    Raises:
        FileNotFoundError: if the FP32 model directory does not exist.
        ValueError: if the test split is empty.
    """
    parser = argparse.ArgumentParser(description="Dynamic INT8 quantization for transformer inference on CPU.")
    parser.add_argument("--model", type=str, default="distilbert-base-uncased")
    parser.add_argument("--benchmark-only", action="store_true")
    args = parser.parse_args()

    dir_name = _checkpoint_to_dir(args.model)
    fp32_dir = os.path.join(CFG.models_dir, dir_name)
    int8_dir = os.path.join(CFG.models_dir, f"{dir_name}_int8")
    if not os.path.isdir(fp32_dir):
        raise FileNotFoundError(
            f"FP32 model directory not found: {fp32_dir}\n"
            f"Expected a fine-tuned model saved via save_pretrained() under saved_models/."
        )

    print(f"[Quantize] Loading FP32 model from: {fp32_dir}")
    model_fp32, tokenizer_fp32 = _load_fp32_model(fp32_dir)
    print("[Quantize] Applying dynamic INT8 quantization (Linear layers)...")
    model_int8 = _quantize_dynamic_int8(model_fp32)

    if args.benchmark_only:
        # Leave int8_dir untouched, but still report the real on-disk size of the
        # model being benchmarked. Measuring an empty int8_dir would report 0 MB.
        with tempfile.TemporaryDirectory() as scratch_dir:
            _save_quantized_model(model_int8, fp32_dir, scratch_dir, checkpoint_dir_name=dir_name)
            int8_size_mb = _mb(_dir_size_bytes(scratch_dir))
    else:
        saved = _save_quantized_model(model_int8, fp32_dir, int8_dir, checkpoint_dir_name=dir_name)
        print(f"[Quantize] Saved INT8 model -> {saved['model_path']}")
        print(f"[Quantize] Saved metadata -> {saved['info_path']}")
        int8_size_mb = _mb(_dir_size_bytes(int8_dir))

    X_test, y_test = load_test_only()
    if len(X_test) == 0:
        raise ValueError("Test split is empty; nothing to benchmark or evaluate.")
    rng = np.random.default_rng(CFG.seed)
    sample_idx = rng.choice(len(X_test), size=min(100, len(X_test)), replace=False).tolist()
    sample_texts = [X_test[i] for i in sample_idx]

    print("[Benchmark] Measuring latency (median ms per text)...")
    fp32_single = _benchmark_latency_ms(model_fp32, tokenizer_fp32, sample_texts, batch_size=1)
    int8_single = _benchmark_latency_ms(model_int8, tokenizer_fp32, sample_texts, batch_size=1)
    fp32_b16 = _benchmark_latency_ms(model_fp32, tokenizer_fp32, sample_texts, batch_size=16)
    int8_b16 = _benchmark_latency_ms(model_int8, tokenizer_fp32, sample_texts, batch_size=16)

    print(f"[Eval] Computing test accuracy on {len(X_test):,} examples...")
    fp32_acc = _accuracy(model_fp32, tokenizer_fp32, X_test, y_test, batch_size=32)
    int8_acc = _accuracy(model_int8, tokenizer_fp32, X_test, y_test, batch_size=32)

    fp32_size_mb = _mb(_dir_size_bytes(fp32_dir))
    _print_table(
        fp32_size_mb=fp32_size_mb,
        int8_size_mb=int8_size_mb,
        fp32_single_ms=fp32_single,
        int8_single_ms=int8_single,
        fp32_batch16_ms=fp32_b16,
        int8_batch16_ms=int8_b16,
        fp32_acc=fp32_acc,
        int8_acc=int8_acc,
    )
# Run the quantize/benchmark CLI only when executed as a script, not on import.
if __name__ == "__main__":
    main()