"""
Evaluate quantized ONNX models on ImageNet-1k validation set.
Uses ONNX Runtime for inference.
Loads the cached ImageNet dataset directly from arrow shard files.
"""
import argparse
import bisect
import io
import os
import tempfile
import time
import traceback
import warnings

import numpy as np
import onnx
import onnxruntime as ort
import pyarrow as pa
import pyarrow.ipc as ipc
import torch
from torch.utils.data import Dataset, DataLoader, Subset
from PIL import Image
from datasets import load_dataset  # noqa: F401  (pre-existing import, kept)
from transformers import ViTForImageClassification, ViTImageProcessor  # noqa: F401
from sklearn.metrics import average_precision_score, precision_recall_fscore_support


# ---------------------------------------------------------------------------
# Dataset that reads directly from arrow shards
# ---------------------------------------------------------------------------
class ArrowImageNetDataset(Dataset):
    """Load ImageNet validation data from cached arrow shard files.

    Every file in ``arrow_dir`` named ``imagenet-1k_validation-validation-*.arrow``
    is opened as a memory-mapped Arrow IPC stream, in sorted filename order.
    Image bytes are therefore paged in on demand instead of being copied into
    RAM up front. Shards that fail to open are reported and skipped.

    Each item is ``(image, label)``. ``image`` is a PIL RGB image, or whatever
    ``transform`` returns for it. ``label`` is the row's ``label`` value.
    """

    _SHARD_PREFIX = "imagenet-1k_validation-validation-"
    _SHARD_SUFFIX = ".arrow"
    # Size of the placeholder used when a row holds no decodable image.
    _BLANK_SIZE = (224, 224)

    def __init__(self, arrow_dir, transform=None):
        self.transform = transform
        self.shards = []
        # offsets[k] is the global index of shard k's first row; offsets[-1] is the total.
        self.offsets = [0]

        shard_files = sorted(
            f for f in os.listdir(arrow_dir)
            if f.startswith(self._SHARD_PREFIX) and f.endswith(self._SHARD_SUFFIX)
        )
        for fname in shard_files:
            path = os.path.join(arrow_dir, fname)
            try:
                # Zero-copy read: the table's buffers reference the memory map,
                # which stays open for as long as the table is alive.
                table = ipc.open_stream(pa.memory_map(path, "r")).read_all()
            except Exception as e:  # best-effort: one corrupt shard must not abort the run
                print(f" SKIP shard {fname}: {e}")
                continue
            self.shards.append(table)
            self.offsets.append(self.offsets[-1] + len(table))
            print(f" Loaded shard {fname}: {len(table)} rows")

        self.total = self.offsets[-1]
        print(f" Total images: {self.total}")

    def __len__(self):
        return self.total

    def _locate(self, idx):
        """Map a global row index (negative allowed) to ``(table, local_row)``.

        Raises:
            IndexError: if ``idx`` is outside ``[-len(self), len(self))``.
        """
        if idx < 0:
            idx += self.total
        if not 0 <= idx < self.total:
            raise IndexError(f"index out of range for dataset of size {self.total}")
        # Rightmost shard whose start offset is <= idx. Empty shards (equal
        # consecutive offsets) are skipped automatically.
        shard_idx = bisect.bisect_right(self.offsets, idx) - 1
        return self.shards[shard_idx], idx - self.offsets[shard_idx]

    @classmethod
    def _decode_image(cls, value):
        """Decode one ``image`` cell into a PIL RGB image.

        Accepts raw encoded bytes, or a Hugging Face ``Image`` dict with
        ``bytes``/``path`` entries (``bytes`` may be ``None`` when only a path
        was stored). Falls back to a blank image, with a warning, when neither
        source is usable.
        """
        if isinstance(value, dict):
            data = value.get("bytes")
            if data is None:
                data = value.get("path")
            value = data
        if isinstance(value, (bytes, bytearray)):
            return Image.open(io.BytesIO(value)).convert("RGB")
        if isinstance(value, str) and os.path.isfile(value):
            with Image.open(value) as img:
                return img.convert("RGB")
        warnings.warn(
            "ArrowImageNetDataset: row has no decodable image; "
            "substituting a blank image (metrics will be skewed)",
            RuntimeWarning,
        )
        return Image.new("RGB", cls._BLANK_SIZE)

    def __getitem__(self, idx):
        table, local_idx = self._locate(idx)
        img = self._decode_image(table.column("image")[local_idx].as_py())
        label = table.column("label")[local_idx].as_py()
        if self.transform:
            img = self.transform(img)
        return img, label


class HFProcessorTransform:
    """Callable that turns a PIL image into a ``(C, H, W)`` pixel tensor.

    The conversion uses a Hugging Face image processor. This is a module-level
    class rather than a closure so that DataLoader workers can unpickle it
    under the ``spawn`` start method.
    """

    def __init__(self, processor):
        self.processor = processor

    def __call__(self, img):
        inputs = self.processor(images=img, return_tensors="pt")
        return inputs["pixel_values"].squeeze(0)  # (C, H, W)


# ---------------------------------------------------------------------------
# Metrics (same as model_eval_test.py)
# ---------------------------------------------------------------------------
def compute_metrics(logits, labels, num_classes):
    """Compute classification metrics from raw logits.

    Args:
        logits: array of shape (N, num_classes), any float dtype. It is upcast
            to float32 before the softmax. FP16 model outputs would otherwise
            risk missing half-precision CPU kernels and probability ties.
        labels: integer array of shape (N,) holding ground-truth class ids.
        num_classes: width of the one-hot label matrix used for mAP.

    Returns:
        dict with:
            - ``top1`` / ``top5`` accuracy;
            - ``mAP`` averaged over classes with at least one positive;
            - macro and weighted precision / recall / F1.

    Raises:
        ValueError: if there are no samples.
    """
    logits = np.ascontiguousarray(logits, dtype=np.float32)
    labels = np.asarray(labels)
    N = len(labels)
    if N == 0:
        raise ValueError("compute_metrics: no samples to evaluate")

    probs = torch.softmax(torch.from_numpy(logits), dim=1).numpy()
    preds = probs.argmax(axis=1)

    top1 = (preds == labels).sum() / N
    # Class ids ordered by descending probability; a top-5 hit is the label in the first 5.
    ranked = np.argsort(probs, axis=1)[:, ::-1]
    top5 = (ranked[:, :5] == labels[:, None]).any(axis=1).sum() / N

    one_hot = np.zeros((N, num_classes), dtype=np.int32)
    one_hot[np.arange(N), labels] = 1
    aps = []
    for c in range(num_classes):
        if one_hot[:, c].sum() == 0:
            continue  # AP is undefined for a class with no positives
        try:
            ap = average_precision_score(one_hot[:, c], probs[:, c])
        except ValueError:
            ap = 0.0
        aps.append(ap)
    mAP = np.mean(aps) if aps else 0.0

    prec_mac, rec_mac, f1_mac, _ = precision_recall_fscore_support(
        labels, preds, average="macro", zero_division=0
    )
    prec_wt, rec_wt, f1_wt, _ = precision_recall_fscore_support(
        labels, preds, average="weighted", zero_division=0
    )
    return {
        "top1": top1,
        "top5": top5,
        "mAP": mAP,
        "precision_macro": prec_mac,
        "recall_macro": rec_mac,
        "f1_macro": f1_mac,
        "precision_weighted": prec_wt,
        "recall_weighted": rec_wt,
        "f1_weighted": f1_wt,
    }


# ---------------------------------------------------------------------------
# Model repair — fix known issues in quantized ONNX models
# ---------------------------------------------------------------------------
def repair_onnx_model(onnx_path):
    """Patch known bugs in quantized ONNX models before ORT inference.

    Currently fixes:
      - INT4 classifier DequantizeLinear axis: ModelOpt sets axis=0, but the
        scale shape [1000, 8] is correct for axis=1 with block_size=128. ORT
        validates ceil(Di/block_size) on the declared axis and rejects it.

    Only paths containing an ``int4/`` directory component are inspected.

    Returns:
        ``onnx_path`` unchanged when nothing was fixed. Otherwise, the path of
        a new temporary file holding the repaired model. The caller owns that
        temp file and should delete it once the model has been loaded.
    """
    if "int4/" not in onnx_path.replace(os.sep, "/"):
        return onnx_path

    model = onnx.load(onnx_path)
    init_dims = {init.name: list(init.dims) for init in model.graph.initializer}
    fixed = False
    for node in model.graph.node:
        if node.op_type != "DequantizeLinear" or "classifier.weight" not in node.name:
            continue

        attrs = {attr.name: attr for attr in node.attribute}
        axis = attrs["axis"].i if "axis" in attrs else None
        block_size = attrs["block_size"].i if "block_size" in attrs else None
        # block_size of 0/None means "not block-quantized"; nothing to fix.
        if axis != 0 or not block_size:
            continue

        weight_shape = init_dims.get(node.input[0])
        scale_shape = init_dims.get(node.input[1])
        if weight_shape is None or scale_shape is None or len(weight_shape) < 2:
            continue

        # Expected scale shape if axis were 1: [D0, ceil(D1/block_size)]
        expected_axis1 = list(weight_shape)
        expected_axis1[1] = (expected_axis1[1] + block_size - 1) // block_size
        if scale_shape == expected_axis1:
            attrs["axis"].i = 1  # mutates the attribute proto inside the node
            fixed = True
            print(f" [repair] Fixed {node.name}: axis 0 -> 1 "
                  f"(scale {scale_shape} matches axis=1, block_size={block_size})")

    if not fixed:
        return onnx_path

    fd, tmp_path = tempfile.mkstemp(suffix=".onnx", prefix="repaired_")
    os.close(fd)
    try:
        onnx.save(model, tmp_path)
    except Exception:
        os.remove(tmp_path)  # don't leave a half-written file behind
        raise
    print(f" [repair] Saved repaired model to {tmp_path}")
    return tmp_path


# ---------------------------------------------------------------------------
# ONNX model evaluation
# ---------------------------------------------------------------------------
def _create_session(onnx_path):
    """Build an ORT session for ``onnx_path``, applying ``repair_onnx_model`` first.

    Any temporary repaired copy is deleted once the session has been built.
    """
    repaired_path = repair_onnx_model(onnx_path)

    providers = []
    if "CUDAExecutionProvider" in ort.get_available_providers():
        providers.append(("CUDAExecutionProvider", {"device_id": 0}))
    providers.append("CPUExecutionProvider")

    session_options = ort.SessionOptions()
    session_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_DISABLE_ALL
    # Disable memory pattern/reuse optimizations that conflict with dim_param
    # (unknown) dimensions in quantized models. Without this, ORT pre-allocates
    # buffers based on incorrectly resolved shapes, causing runtime crashes
    # on FP16 (Add node) and FP8 (MatMul in decomposed SDPA).
    session_options.enable_mem_pattern = False
    session_options.enable_mem_reuse = False

    try:
        return ort.InferenceSession(repaired_path, sess_options=session_options,
                                    providers=providers)
    finally:
        if repaired_path != onnx_path:
            try:
                os.remove(repaired_path)
            except OSError as e:
                print(f" [repair] Could not remove temp model {repaired_path}: {e}")


@torch.no_grad()
def evaluate_onnx(onnx_path, loader, num_classes, print_every=500):
    """Evaluate an ONNX model with ONNX Runtime, one image per ``session.run``.

    The loader may batch, but each image is sent as a ``(1, C, H, W)`` slice.
    The exported graphs have a static batch of 1 baked into internal Reshape
    nodes.

    Args:
        onnx_path: model file. INT4 models are repaired on the fly.
        loader: iterable of ``(imgs, labels)``. ``imgs`` is a CPU float tensor
            batch; ``labels`` is a tensor or a sequence of ints.
        num_classes: forwarded to ``compute_metrics``.
        print_every: progress is printed every this many *batches* (0 disables it).

    Returns:
        ``compute_metrics`` output, plus:
            - ``total_images``;
            - ``elapsed``: wall-clock seconds including data loading and metrics;
            - ``avg_process_ms``;
            - ``avg_inference_ms``: time spent in ``session.run`` only.

    Raises:
        RuntimeError: if the loader yields no images.
    """
    session = _create_session(onnx_path)
    input_meta = session.get_inputs()[0]
    input_name = input_meta.name
    # Graphs that declare an FP16 input get FP16 data; others get float32 as produced.
    cast_to_fp16 = input_meta.type == "tensor(float16)"

    all_logits = []
    all_labels = []
    total = 0
    total_inference_time = 0.0  # strict inference-only timing
    start = time.time()  # process timer (includes data prep, inference, output, metrics)

    for batch_idx, (imgs, labels) in enumerate(loader):
        batch = imgs.numpy()
        if cast_to_fp16:
            batch = batch.astype(np.float16)
        for i in range(batch.shape[0]):
            single_img = batch[i:i + 1]  # shape (1, C, H, W)
            t0 = time.perf_counter()
            outputs = session.run(None, {input_name: single_img})
            total_inference_time += time.perf_counter() - t0
            all_logits.append(outputs[0])
        all_labels.append(labels.numpy() if torch.is_tensor(labels) else np.asarray(labels))
        total += batch.shape[0]

        if print_every and (batch_idx + 1) % print_every == 0:
            speed = total / (time.time() - start)
            print(f" [{total:>6d} images] {speed:.1f} img/s")

    if total == 0:
        raise RuntimeError("evaluate_onnx: the loader yielded no images")

    metrics = compute_metrics(
        np.concatenate(all_logits, axis=0),
        np.concatenate(all_labels, axis=0),
        num_classes,
    )
    elapsed = time.time() - start
    metrics["total_images"] = total
    metrics["elapsed"] = elapsed
    metrics["avg_process_ms"] = elapsed / total * 1000
    metrics["avg_inference_ms"] = total_inference_time / total * 1000
    return metrics


# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------
def _print_comparison_table(results):
    """Print one row per evaluated model, or its failure message, then the reference."""
    rule = "=" * 100
    print(f"\n\n{rule}")
    print("Evaluation Comparison Table")
    print(rule)
    print(f" {'Model':<25s} {'Images':>7s} {'Top-1%':>8s} {'Top-5%':>8s} {'mAP':>8s} {'F1_mac':>8s} {'F1_wt':>8s} {'Proc(ms)':>9s} {'Inf(ms)':>8s} {'Time':>8s}")
    print(f" {'-'*25} {'-'*7} {'-'*8} {'-'*8} {'-'*8} {'-'*8} {'-'*8} {'-'*9} {'-'*8} {'-'*8}")
    for name, m in results.items():
        if "error" in m:
            print(f" {name:<25s} FAILED: {m['error']}")
            continue
        print(
            f" {name:<25s} "
            f"{m['total_images']:>7d} "
            f"{m['top1']*100:>8.3f} "
            f"{m['top5']*100:>8.3f} "
            f"{m['mAP']:>8.4f} "
            f"{m['f1_macro']:>8.4f} "
            f"{m['f1_weighted']:>8.4f} "
            f"{m['avg_process_ms']:>9.2f} "
            f"{m['avg_inference_ms']:>8.2f} "
            f"{m['elapsed']:>7.1f}s"
        )
    print("\n Reference (timm model card): Top-1: 82.346% | Top-5: 96.394%")
    print(rule)


def main():
    parser = argparse.ArgumentParser(description="Evaluate quantized ONNX models")
    parser.add_argument("--batch_size", type=int, default=32)
    parser.add_argument("--num_workers", type=int, default=4)
    parser.add_argument("--subset", type=int, default=0, help="Evaluate on first N images (0=all)")
    ALL_MODES = ["fp32", "fp16", "int8", "fp8", "int4"]
    parser.add_argument(
        "--mode", type=str, nargs="*", default=ALL_MODES, choices=ALL_MODES,
        help=f"Quantization mode(s) to evaluate (default: all). Choices: {ALL_MODES}",
    )
    args = parser.parse_args()
    # A bare `--mode` parses to []; treat it as "all" rather than evaluating nothing.
    modes = args.mode or ALL_MODES

    # ------------------------------------------------------------------
    # Load model & processor
    # ------------------------------------------------------------------
    model_name = "google/vit-large-patch16-224"
    print(f"Loading {model_name} ...")
    processor = ViTImageProcessor.from_pretrained(model_name)
    input_size = (1, 3, 224, 224)  # (B, C, H, W)
    print(f" Input size: {input_size}")
    transform = HFProcessorTransform(processor)
    num_classes = 1000

    # Load dataset from cached arrow shards
    arrow_dir = os.path.expanduser(
        "~/.cache/huggingface/datasets/Tsomaros___imagenet-1k_validation/"
        "default/0.0.0/55405c49dece42420e68ddd5f80174f19b29ebaf/"
    )
    print(f"Loading dataset from arrow shards: {arrow_dir}")
    dataset = ArrowImageNetDataset(arrow_dir, transform=transform)
    if args.subset > 0:
        dataset = Subset(dataset, range(min(args.subset, len(dataset))))
        print(f" Using subset: {len(dataset)} images")

    # Batches are converted to NumPy for ORT on the host, so pinned memory buys nothing.
    loader = DataLoader(
        dataset, batch_size=args.batch_size, shuffle=False,
        num_workers=args.num_workers, pin_memory=False,
    )

    # Define models to evaluate
    models = {
        "FP32 (baseline)": "vit_large_patch16_224_fp32.onnx",
        "FP16": "fp16/vit_large_patch16_224_fp16.onnx",
        "INT8 entropy": "int8/vit_large_patch16_224_int8_entropy.onnx",
        "INT8 max": "int8/vit_large_patch16_224_int8_max.onnx",
        "FP8 entropy": "fp8/vit_large_patch16_224_fp8_entropy.onnx",
        "FP8 max": "fp8/vit_large_patch16_224_fp8_max.onnx",
        "INT4 awq_clip": "int4/vit_large_patch16_224_int4_awq_clip.onnx",
        "INT4 awq_lite (asym)": "int4/vit_large_patch16_224_int4_awq_lite_asym.onnx",
        "INT4 awq_lite (sym)": "int4/vit_large_patch16_224_int4_awq_lite.onnx",
        "INT4 awq_full": "int4/vit_large_patch16_224_int4_awq_full.onnx",
        "INT4 rtn_dq": "int4/vit_large_patch16_224_int4_rtn_dq.onnx",
    }

    # Filter by --mode selection
    mode_prefix = {"fp32": "FP32", "fp16": "FP16", "int8": "INT8", "fp8": "FP8", "int4": "INT4"}
    selected_prefixes = tuple(mode_prefix[m] for m in modes)
    models = {k: v for k, v in models.items() if k.startswith(selected_prefixes)}
    print(f"Evaluating modes: {modes}")

    # Filter to only existing files
    existing_models = {}
    for name, path in models.items():
        if os.path.exists(path):
            existing_models[name] = path
        else:
            print(f" SKIP: {name} — file not found: {path}")

    results = {}
    for name, onnx_path in existing_models.items():
        print(f"\n{'='*60}")
        print(f"Evaluating: {name}")
        print(f" Model: {onnx_path}")
        print(f"{'='*60}")
        try:
            metrics = evaluate_onnx(onnx_path, loader, num_classes)
        except Exception as e:  # top-level boundary: record the failure, keep evaluating others
            print(f" FAILED: {e}")
            traceback.print_exc()
            results[name] = {"error": str(e)}
            continue
        results[name] = metrics
        print(f"\n Top-1 Accuracy: {metrics['top1']*100:.3f}%")
        print(f" Top-5 Accuracy: {metrics['top5']*100:.3f}%")
        print(f" mAP: {metrics['mAP']:.4f}")
        print(f" F1 (macro): {metrics['f1_macro']:.4f}")
        print(f" F1 (weighted): {metrics['f1_weighted']:.4f}")
        print(f" Time: {metrics['elapsed']:.1f}s")
        print(f" Avg Process: {metrics['avg_process_ms']:.2f}ms/img")
        print(f" Avg Inference: {metrics['avg_inference_ms']:.2f}ms/img")

    _print_comparison_table(results)

    # Report the best INT8 variant by Top-1 (reporting only; no files are copied).
    int8_results = {k: v for k, v in results.items()
                    if k.startswith("INT8") and "error" not in v}
    if int8_results:
        best_int8 = max(int8_results, key=lambda k: int8_results[k]["top1"])
        best_path = existing_models[best_int8]
        print(f"\n Best INT8 model: {best_int8} ({best_path})")
        print(f" Top-1: {int8_results[best_int8]['top1']*100:.3f}%")
        print(f" Top-5: {int8_results[best_int8]['top5']*100:.3f}%")


if __name__ == "__main__":
    main()