# modules_play/vit_large_patch16_224_eval_quantized.py (16 kB)
# Commit 65012cb (richard.lin): "update: log output."
"""
Evaluate quantized ONNX models on ImageNet-1k validation set.
Uses ONNX Runtime for inference. Loads the cached ImageNet dataset
directly from arrow shard files.
"""
import argparse
import bisect
import io
import os
import tempfile
import time
import warnings

import numpy as np
import onnx
import onnxruntime as ort
import torch
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from datasets import load_dataset
import pyarrow.ipc as ipc
from transformers import ViTForImageClassification, ViTImageProcessor
from sklearn.metrics import average_precision_score, precision_recall_fscore_support
# ---------------------------------------------------------------------------
# Dataset that reads directly from arrow shards
# ---------------------------------------------------------------------------
class ArrowImageNetDataset(Dataset):
    """Load ImageNet validation data from cached arrow shard files.

    Every ``imagenet-1k_validation-validation-*.arrow`` file in *arrow_dir* is
    read in full (``read_all``) at construction time, in sorted filename order.
    Shards that fail to load are reported and skipped. The shards are then
    addressed as one logical sequence: global index ``i`` belongs to the shard
    whose cumulative offset range contains ``i``.

    Args:
        arrow_dir: Directory containing the arrow shard files.
        transform: Optional callable applied to each decoded PIL image.
    """

    def __init__(self, arrow_dir, transform=None):
        self.transform = transform
        self.shards = []
        # offsets[k] is the global index of the first row of shard k;
        # offsets[-1] is the total number of rows.
        self.offsets = [0]
        shard_files = sorted(
            f for f in os.listdir(arrow_dir)
            if f.startswith("imagenet-1k_validation-validation-") and f.endswith(".arrow")
        )
        for fname in shard_files:
            path = os.path.join(arrow_dir, fname)
            try:
                with open(path, "rb") as f:
                    reader = ipc.RecordBatchStreamReader(f)
                    table = reader.read_all()
            except Exception as e:
                # Best effort: an unreadable shard is skipped rather than
                # aborting the whole evaluation.
                print(f" SKIP shard {fname}: {e}")
                continue
            self.shards.append(table)
            self.offsets.append(self.offsets[-1] + len(table))
            print(f" Loaded shard {fname}: {len(table)} rows")
        self.total = self.offsets[-1]
        print(f" Total images: {self.total}")

    def __len__(self):
        return self.total

    def _locate(self, idx):
        """Map a global row index to ``(shard_idx, local_idx)``.

        Negative indices count from the end, as for Python sequences.

        Raises:
            IndexError: if *idx* is outside ``[-len(self), len(self))``.
        """
        if idx < 0:
            idx += self.total
        if not 0 <= idx < self.total:
            raise IndexError(f"index {idx} out of range for dataset of size {self.total}")
        # bisect_right selects the last shard whose start offset is <= idx.
        # An empty shard shares its start offset with the next shard, so it
        # is stepped over automatically.
        shard_idx = bisect.bisect_right(self.offsets, idx) - 1
        return shard_idx, idx - self.offsets[shard_idx]

    @staticmethod
    def _decode_image(value):
        """Decode one ``image`` cell into an RGB PIL image.

        A cell may hold raw bytes or a dict with ``"bytes"`` and/or ``"path"``
        keys. When the dict has no bytes but its path names an existing file,
        that file is read. Anything undecodable becomes a black 224x224
        placeholder and a warning is emitted, because a silent placeholder
        would quietly distort the accuracy numbers.
        """
        if isinstance(value, dict):
            data = value.get("bytes")
            if data is None:
                path = value.get("path")
                if isinstance(path, str) and os.path.isfile(path):
                    with open(path, "rb") as fh:
                        data = fh.read()
                else:
                    data = path
            value = data
        if isinstance(value, bytes):
            return Image.open(io.BytesIO(value)).convert("RGB")
        # Constant message so the default warning filter reports it once per
        # process instead of once per image.
        warnings.warn(
            "ArrowImageNetDataset: image cell has no decodable bytes; "
            "substituting a black 224x224 placeholder"
        )
        return Image.new("RGB", (224, 224))

    def __getitem__(self, idx):
        """Return ``(image, label)`` for global index *idx*.

        The image is transformed by ``self.transform`` when one is set.
        """
        shard_idx, local_idx = self._locate(idx)
        table = self.shards[shard_idx]
        img = self._decode_image(table.column("image")[local_idx].as_py())
        label = table.column("label")[local_idx].as_py()
        if self.transform:
            img = self.transform(img)
        return img, label
# ---------------------------------------------------------------------------
# Metrics (same as model_eval_test.py)
# ---------------------------------------------------------------------------
def compute_metrics(logits, labels, num_classes):
    """Compute classification metrics from raw model outputs.

    Args:
        logits: ``(N, C)`` array of unnormalised class scores. Any float dtype
            is accepted. It is upcast to float32 before the softmax, so FP16
            model outputs yield full-precision probabilities.
        labels: length-``N`` array of integer class ids.
        num_classes: number of classes, used as the width of the one-hot
            matrix for per-class average precision.

    Returns:
        dict with ``top1``/``top5`` accuracy, ``mAP`` (mean AP over classes
        that occur in *labels*), and macro/weighted precision, recall and F1.

    Raises:
        ValueError: if there are no samples.
    """
    logits = np.asarray(logits, dtype=np.float32)
    labels = np.asarray(labels)
    N = len(labels)
    if N == 0:
        raise ValueError("compute_metrics needs at least one sample")
    probs = torch.softmax(torch.from_numpy(logits), dim=1).numpy()
    preds = probs.argmax(axis=1)
    top1 = (preds == labels).sum() / N
    # Top-5 membership needs the k largest scores but not their order, so
    # argpartition avoids a full per-row sort and the Python-level loop.
    k = min(5, probs.shape[1])
    topk = np.argpartition(probs, -k, axis=1)[:, -k:]
    top5 = (topk == labels[:, None]).any(axis=1).sum() / N
    one_hot = np.zeros((N, num_classes), dtype=np.int32)
    one_hot[np.arange(N), labels] = 1
    aps = []
    for c in range(num_classes):
        # AP is undefined for classes with no positive sample; skip them.
        if one_hot[:, c].sum() == 0:
            continue
        try:
            ap = average_precision_score(one_hot[:, c], probs[:, c])
        except ValueError:
            ap = 0.0
        aps.append(ap)
    mAP = np.mean(aps) if aps else 0.0
    prec_mac, rec_mac, f1_mac, _ = precision_recall_fscore_support(
        labels, preds, average="macro", zero_division=0
    )
    prec_wt, rec_wt, f1_wt, _ = precision_recall_fscore_support(
        labels, preds, average="weighted", zero_division=0
    )
    return {
        "top1": top1,
        "top5": top5,
        "mAP": mAP,
        "precision_macro": prec_mac,
        "recall_macro": rec_mac,
        "f1_macro": f1_mac,
        "precision_weighted": prec_wt,
        "recall_weighted": rec_wt,
        "f1_weighted": f1_wt,
    }
# ---------------------------------------------------------------------------
# Model repair — fix known issues in quantized ONNX models
# ---------------------------------------------------------------------------
def repair_onnx_model(onnx_path):
"""Patch known bugs in quantized ONNX models before ORT inference.
Currently fixes:
- INT4 classifier DequantizeLinear axis: ModelOpt sets axis=0 but the
scale shape [1000, 8] is correct for axis=1 with block_size=128.
ORT validates ceil(Di/block_size) on the declared axis and rejects it.
Returns the path to use for inference (original or a temp file with fixes).
"""
if "int4/" not in onnx_path:
return onnx_path
model = onnx.load(onnx_path)
fixed = False
for node in model.graph.node:
if node.op_type != "DequantizeLinear":
continue
if "classifier.weight" not in node.name:
continue
# Read attributes
axis = None
block_size = None
for attr in node.attribute:
if attr.name == "axis":
axis = attr.i
elif attr.name == "block_size":
block_size = attr.i
if axis != 0 or block_size is None:
continue
# Check whether scale shape matches axis=1 instead of axis=0
weight_name = node.input[0]
scale_name = node.input[1]
weight_shape = scale_shape = None
for init in model.graph.initializer:
if init.name == weight_name:
weight_shape = list(init.dims)
if init.name == scale_name:
scale_shape = list(init.dims)
if weight_shape is None or scale_shape is None:
continue
# Expected scale shape if axis were 1: [D0, ceil(D1/block_size)]
expected_axis1 = list(weight_shape)
expected_axis1[1] = (expected_axis1[1] + block_size - 1) // block_size
if scale_shape == expected_axis1:
for attr in node.attribute:
if attr.name == "axis":
attr.i = 1
fixed = True
print(f" [repair] Fixed {node.name}: axis 0 -> 1 "
f"(scale {scale_shape} matches axis=1, block_size={block_size})")
if not fixed:
return onnx_path
# Save repaired model to a temp file (persists for session lifetime)
fd, tmp_path = tempfile.mkstemp(suffix=".onnx", prefix="repaired_")
os.close(fd)
onnx.save(model, tmp_path)
print(f" [repair] Saved repaired model to {tmp_path}")
return tmp_path
# ---------------------------------------------------------------------------
# ONNX model evaluation
# ---------------------------------------------------------------------------
@torch.no_grad()
def evaluate_onnx(onnx_path, loader, num_classes, print_every=500):
    """Evaluate an ONNX model with ONNX Runtime, one image per ``session.run``.

    The models have a static batch of 1 in internal Reshape nodes, so every
    loader batch is split into single-image runs.

    Args:
        onnx_path: path to the ``.onnx`` file. It is first passed through
            ``repair_onnx_model``; a repaired temp copy is deleted once the
            session has been created.
        loader: iterable of ``(images, labels)`` batches, where ``images`` is
            a torch tensor indexed as ``(B, C, H, W)``.
        num_classes: forwarded to ``compute_metrics``.
        print_every: log throughput every this many loader batches
            (0/None disables).

    Returns:
        The ``compute_metrics`` dict plus ``total_images``, ``elapsed``
        (seconds, wall clock including data loading and metrics),
        ``avg_process_ms`` and ``avg_inference_ms`` (``session.run`` only).

    Raises:
        ValueError: if the loader yields no images.
    """
    # Repair known model bugs before loading
    repaired_path = repair_onnx_model(onnx_path)
    providers = []
    if "CUDAExecutionProvider" in ort.get_available_providers():
        providers.append(("CUDAExecutionProvider", {"device_id": 0}))
    providers.append("CPUExecutionProvider")
    session_options = ort.SessionOptions()
    session_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_DISABLE_ALL
    # Disable memory pattern/reuse optimizations that conflict with dim_param
    # (unknown) dimensions in quantized models. Without this, ORT pre-allocates
    # buffers based on incorrectly resolved shapes, causing runtime crashes
    # on FP16 (Add node) and FP8 (MatMul in decomposed SDPA).
    session_options.enable_mem_pattern = False
    session_options.enable_mem_reuse = False
    try:
        session = ort.InferenceSession(repaired_path, sess_options=session_options, providers=providers)
    finally:
        # The repaired copy is a temp file that is only read while the session
        # is built; delete it so repeated runs do not accumulate model copies.
        if repaired_path != onnx_path:
            try:
                os.remove(repaired_path)
            except OSError:
                pass

    model_input = session.get_inputs()[0]
    input_name = model_input.name
    # Cast to the float type the graph declares (e.g. a float16 input on an
    # FP16 export); unknown types are fed unchanged, as before.
    float_types = {
        "tensor(float)": np.float32,
        "tensor(float16)": np.float16,
        "tensor(double)": np.float64,
    }
    input_dtype = float_types.get(model_input.type)

    all_logits = []
    all_labels = []
    total = 0
    total_inference_time = 0.0  # strict inference-only timing
    start = time.time()  # process timer (includes data prep, inference, output, metrics)
    for batch_idx, (imgs, labels) in enumerate(loader):
        batch = imgs.numpy()
        if input_dtype is not None and batch.dtype != input_dtype:
            batch = batch.astype(input_dtype)
        # Run one image at a time (model has static batch=1 in internal Reshape nodes)
        for i in range(batch.shape[0]):
            t0 = time.perf_counter()
            outputs = session.run(None, {input_name: batch[i:i + 1]})
            total_inference_time += time.perf_counter() - t0
            all_logits.append(outputs[0])
        all_labels.append(labels.numpy() if torch.is_tensor(labels) else np.asarray(labels))
        total += batch.shape[0]
        if print_every and (batch_idx + 1) % print_every == 0:
            elapsed = time.time() - start
            speed = total / elapsed
            print(f" [{total:>6d} images] {speed:.1f} img/s")
    if total == 0:
        raise ValueError(f"No images were evaluated for {onnx_path}")
    all_logits = np.concatenate(all_logits, axis=0)
    all_labels = np.concatenate(all_labels, axis=0)
    metrics = compute_metrics(all_logits, all_labels, num_classes)
    metrics["total_images"] = total
    elapsed = time.time() - start
    metrics["elapsed"] = elapsed
    metrics["avg_process_ms"] = elapsed / total * 1000
    metrics["avg_inference_ms"] = total_inference_time / total * 1000
    return metrics
# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------
def _print_comparison_table(results):
    """Print one fixed-width summary row per evaluated model (or its failure).

    Args:
        results: mapping of display name -> metrics dict from
            ``evaluate_onnx``, or ``{"error": message}`` for a failed model.
    """
    rule = "=" * 100
    print(f"\n\n{rule}")
    print("Evaluation Comparison Table")
    print(rule)
    print(f" {'Model':<25s} {'Images':>7s} {'Top-1%':>8s} {'Top-5%':>8s} {'mAP':>8s} {'F1_mac':>8s} {'F1_wt':>8s} {'Proc(ms)':>9s} {'Inf(ms)':>8s} {'Time':>8s}")
    print(f" {'-'*25} {'-'*7} {'-'*8} {'-'*8} {'-'*8} {'-'*8} {'-'*8} {'-'*9} {'-'*8} {'-'*8}")
    for name, m in results.items():
        if "error" in m:
            print(f" {name:<25s} FAILED: {m['error']}")
            continue
        print(
            f" {name:<25s} "
            f"{m['total_images']:>7d} "
            f"{m['top1']*100:>8.3f} "
            f"{m['top5']*100:>8.3f} "
            f"{m['mAP']:>8.4f} "
            f"{m['f1_macro']:>8.4f} "
            f"{m['f1_weighted']:>8.4f} "
            f"{m['avg_process_ms']:>9.2f} "
            f"{m['avg_inference_ms']:>8.2f} "
            f"{m['elapsed']:>7.1f}s"
        )
    print("\n Reference (timm model card): Top-1: 82.346% | Top-5: 96.394%")
    # Closing rule uses the same 100-column width as the header rules.
    print(rule)


def main():
    """Evaluate the selected quantized ONNX exports on ImageNet-1k validation.

    Command line:
        --batch_size   DataLoader batch size (each batch is still run one
                       image at a time by ``evaluate_onnx``).
        --num_workers  DataLoader worker processes.
        --subset N     evaluate only the first N images (0 = all).
        --mode ...     quantization families to evaluate (default: all).

    Model paths are relative to the current working directory; missing files
    are skipped. A failure in one model is reported and does not stop the
    others.
    """
    parser = argparse.ArgumentParser(description="Evaluate quantized ONNX models")
    parser.add_argument("--batch_size", type=int, default=32)
    parser.add_argument("--num_workers", type=int, default=4)
    parser.add_argument("--subset", type=int, default=0, help="Evaluate on first N images (0=all)")
    ALL_MODES = ["fp32", "fp16", "int8", "fp8", "int4"]
    parser.add_argument(
        "--mode", type=str, nargs="*", default=ALL_MODES, choices=ALL_MODES,
        help=f"Quantization mode(s) to evaluate (default: all). Choices: {ALL_MODES}",
    )
    args = parser.parse_args()

    # ------------------------------------------------------------------
    # Load processor (used only for image preprocessing)
    # ------------------------------------------------------------------
    model_name = "google/vit-large-patch16-224"
    print(f"Loading {model_name} ...")
    processor = ViTImageProcessor.from_pretrained(model_name)
    input_size = (1, 3, 224, 224)  # (B, C, H, W)
    print(f" Input size: {input_size}")

    def transform(img):
        # The HF processor returns a batched (1, C, H, W) tensor; drop the
        # batch dim so the DataLoader can stack samples itself.
        inputs = processor(images=img, return_tensors="pt")
        return inputs["pixel_values"].squeeze(0)

    num_classes = 1000
    # Load dataset from cached arrow shards
    arrow_dir = os.path.expanduser(
        "~/.cache/huggingface/datasets/Tsomaros___imagenet-1k_validation/"
        "default/0.0.0/55405c49dece42420e68ddd5f80174f19b29ebaf/"
    )
    print(f"Loading dataset from arrow shards: {arrow_dir}")
    dataset = ArrowImageNetDataset(arrow_dir, transform=transform)
    if args.subset > 0:
        from torch.utils.data import Subset
        dataset = Subset(dataset, range(min(args.subset, len(dataset))))
        # Report the actual size, which is smaller than --subset when the
        # dataset has fewer images.
        print(f" Using subset: {len(dataset)} images")
    loader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.num_workers,
        pin_memory=True,
    )

    # Display name -> ONNX path. Names must start with a prefix from
    # mode_prefix below to be selectable via --mode.
    models = {
        "FP32 (baseline)": "vit_large_patch16_224_fp32.onnx",
        "FP16": "fp16/vit_large_patch16_224_fp16.onnx",
        "INT8 entropy": "int8/vit_large_patch16_224_int8_entropy.onnx",
        "INT8 max": "int8/vit_large_patch16_224_int8_max.onnx",
        "FP8 entropy": "fp8/vit_large_patch16_224_fp8_entropy.onnx",
        "FP8 max": "fp8/vit_large_patch16_224_fp8_max.onnx",
        "INT4 awq_clip": "int4/vit_large_patch16_224_int4_awq_clip.onnx",
        "INT4 awq_lite (asym)": "int4/vit_large_patch16_224_int4_awq_lite_asym.onnx",
        "INT4 awq_lite (sym)": "int4/vit_large_patch16_224_int4_awq_lite.onnx",
        "INT4 awq_full": "int4/vit_large_patch16_224_int4_awq_full.onnx",
        "INT4 rtn_dq": "int4/vit_large_patch16_224_int4_rtn_dq.onnx",
    }
    mode_prefix = {"fp32": "FP32", "fp16": "FP16", "int8": "INT8", "fp8": "FP8", "int4": "INT4"}
    selected_prefixes = tuple({mode_prefix[m] for m in args.mode})
    models = {k: v for k, v in models.items() if k.startswith(selected_prefixes)}
    print(f"Evaluating modes: {args.mode}")

    # Keep only models whose file exists
    existing_models = {}
    for name, path in models.items():
        if os.path.exists(path):
            existing_models[name] = path
        else:
            print(f" SKIP: {name} — file not found: {path}")

    results = {}
    for name, onnx_path in existing_models.items():
        print(f"\n{'='*60}")
        print(f"Evaluating: {name}")
        print(f" Model: {onnx_path}")
        print(f"{'='*60}")
        try:
            metrics = evaluate_onnx(onnx_path, loader, num_classes)
            results[name] = metrics
            print(f"\n Top-1 Accuracy: {metrics['top1']*100:.3f}%")
            print(f" Top-5 Accuracy: {metrics['top5']*100:.3f}%")
            print(f" mAP: {metrics['mAP']:.4f}")
            print(f" F1 (macro): {metrics['f1_macro']:.4f}")
            print(f" F1 (weighted): {metrics['f1_weighted']:.4f}")
            print(f" Time: {metrics['elapsed']:.1f}s")
            print(f" Avg Process: {metrics['avg_process_ms']:.2f}ms/img")
            print(f" Avg Inference: {metrics['avg_inference_ms']:.2f}ms/img")
        except Exception as e:
            # Record the failure and continue with the remaining models.
            print(f" FAILED: {e}")
            import traceback
            traceback.print_exc()
            results[name] = {"error": str(e)}

    _print_comparison_table(results)

    # Report the INT8 variant with the highest top-1 accuracy (report only;
    # no file is copied or written).
    int8_results = {k: v for k, v in results.items() if k.startswith("INT8") and "error" not in v}
    if int8_results:
        best_int8 = max(int8_results, key=lambda k: int8_results[k]["top1"])
        best_path = existing_models[best_int8]
        print(f"\n Best INT8 model: {best_int8} ({best_path})")
        print(f" Top-1: {int8_results[best_int8]['top1']*100:.3f}%")
        print(f" Top-5: {int8_results[best_int8]['top5']*100:.3f}%")


if __name__ == "__main__":
    main()