Spaces:
Running
Running
| """ | |
| Evaluate quantized ONNX models on ImageNet-1k validation set. | |
| Uses ONNX Runtime for inference. Loads the cached ImageNet dataset | |
| directly from arrow shard files. | |
| """ | |
| import argparse | |
| import os | |
| import time | |
| import io | |
| import tempfile | |
| import numpy as np | |
| import onnx | |
| import onnxruntime as ort | |
| import torch | |
| from torch.utils.data import Dataset, DataLoader | |
| from PIL import Image | |
| from datasets import load_dataset | |
| import pyarrow.ipc as ipc | |
| from transformers import ViTForImageClassification, ViTImageProcessor | |
| from sklearn.metrics import average_precision_score, precision_recall_fscore_support | |
| # --------------------------------------------------------------------------- | |
| # Dataset that reads directly from arrow shards | |
| # --------------------------------------------------------------------------- | |
class ArrowImageNetDataset(Dataset):
    """Load ImageNet validation data from cached arrow shard files.

    Every file in ``arrow_dir`` named
    ``imagenet-1k_validation-validation-*.arrow`` is read fully into memory as
    a pyarrow table (shards that fail to read are reported and skipped). Rows
    are addressed by one global index that runs across the shards in sorted
    filename order.

    Args:
        arrow_dir: Directory holding the cached arrow shards.
        transform: Optional callable applied to each decoded RGB PIL image.
    """

    def __init__(self, arrow_dir, transform=None):
        self.transform = transform
        self.shards = []
        # offsets[k] is the global index of the first row of shard k;
        # offsets[-1] is the total number of rows.
        self.offsets = [0]
        shard_files = sorted(
            f for f in os.listdir(arrow_dir)
            if f.startswith("imagenet-1k_validation-validation-") and f.endswith(".arrow")
        )
        for fname in shard_files:
            path = os.path.join(arrow_dir, fname)
            try:
                with open(path, "rb") as f:
                    table = ipc.RecordBatchStreamReader(f).read_all()
            except Exception as e:
                # Best effort: an unreadable shard is reported and skipped.
                print(f" SKIP shard {fname}: {e}")
                continue
            self.shards.append(table)
            self.offsets.append(self.offsets[-1] + len(table))
            print(f" Loaded shard {fname}: {len(table)} rows")
        self.total = self.offsets[-1]
        print(f" Total images: {self.total}")

    def __len__(self):
        return self.total

    def _locate(self, idx):
        """Map a global row index to ``(shard_index, row_within_shard)``.

        Negative indices count from the end, as for Python sequences.

        Raises:
            IndexError: if ``idx`` is outside ``[-len(self), len(self))``.
        """
        from bisect import bisect_right

        requested = idx
        if idx < 0:
            idx += self.total
        if not 0 <= idx < self.total:
            raise IndexError(
                f"index {requested} out of range for dataset of size {self.total}"
            )
        # bisect_right lands after any run of equal offsets, so zero-length
        # shards are skipped and we get the shard that actually owns the row.
        shard_idx = bisect_right(self.offsets, idx) - 1
        return shard_idx, idx - self.offsets[shard_idx]

    @staticmethod
    def _decode_image(cell):
        """Decode one ``image`` cell into an RGB PIL image.

        Accepts raw bytes or a dict with ``bytes``/``path`` keys. When
        ``bytes`` is missing or None, the ``path`` entry is opened from disk.
        A cell that cannot be resolved falls back to a black 224x224
        placeholder, matching the previous best-effort behavior. Such
        placeholders still count toward the metrics.
        """
        if isinstance(cell, dict):
            data = cell.get("bytes")
            if data is None:
                data = cell.get("path")
        else:
            data = cell
        if isinstance(data, (bytes, bytearray)):
            with Image.open(io.BytesIO(data)) as img:
                return img.convert("RGB")
        if isinstance(data, str) and os.path.isfile(data):
            with Image.open(data) as img:
                return img.convert("RGB")
        return Image.new("RGB", (224, 224))

    def __getitem__(self, idx):
        shard_idx, local_idx = self._locate(idx)
        table = self.shards[shard_idx]
        img = self._decode_image(table.column("image")[local_idx].as_py())
        label = table.column("label")[local_idx].as_py()
        if self.transform:
            img = self.transform(img)
        return img, label
| # --------------------------------------------------------------------------- | |
| # Metrics (same as model_eval_test.py) | |
| # --------------------------------------------------------------------------- | |
def compute_metrics(logits, labels, num_classes):
    """Compute classification metrics from raw model logits.

    Args:
        logits: Array of shape (N, num_classes) with unnormalized scores.
            Half-precision input is upcast to float32 before the softmax.
        labels: Length-N sequence of integer class ids.
        num_classes: Number of classes (columns of ``logits``).

    Returns:
        dict with ``top1``/``top5`` accuracy and ``mAP`` (mean average
        precision over classes that occur in ``labels``). It also holds
        macro and weighted precision, recall and F1.

    Raises:
        ValueError: if there are no samples.
    """
    labels = np.asarray(labels)
    N = len(labels)
    if N == 0:
        raise ValueError("compute_metrics: no samples to evaluate")
    logits = np.asarray(logits)
    # Upcast FP16 logits; float32/float64 are kept as they are.
    logits = logits.astype(np.promote_types(logits.dtype, np.float32), copy=False)
    probs = torch.softmax(torch.from_numpy(logits), dim=1).numpy()
    preds = probs.argmax(axis=1)
    top1 = (preds == labels).sum() / N
    # Class ids sorted by descending probability. A top-5 hit means the
    # label is among the first five.
    ranked = np.argsort(probs, axis=1)[:, ::-1]
    top5 = (ranked[:, :5] == labels[:, None]).any(axis=1).sum() / N
    one_hot = np.zeros((N, num_classes), dtype=np.int32)
    one_hot[np.arange(N), labels] = 1
    aps = []
    for c in range(num_classes):
        # AP is undefined for a class with no positive samples; skip it.
        if not one_hot[:, c].any():
            continue
        try:
            aps.append(average_precision_score(one_hot[:, c], probs[:, c]))
        except ValueError:
            aps.append(0.0)
    mAP = np.mean(aps) if aps else 0.0
    prec_mac, rec_mac, f1_mac, _ = precision_recall_fscore_support(
        labels, preds, average="macro", zero_division=0
    )
    prec_wt, rec_wt, f1_wt, _ = precision_recall_fscore_support(
        labels, preds, average="weighted", zero_division=0
    )
    return {
        "top1": top1,
        "top5": top5,
        "mAP": mAP,
        "precision_macro": prec_mac,
        "recall_macro": rec_mac,
        "f1_macro": f1_mac,
        "precision_weighted": prec_wt,
        "recall_weighted": rec_wt,
        "f1_weighted": f1_wt,
    }
| # --------------------------------------------------------------------------- | |
| # Model repair — fix known issues in quantized ONNX models | |
| # --------------------------------------------------------------------------- | |
def repair_onnx_model(onnx_path):
    """Patch known bugs in quantized ONNX models before ORT inference.

    Currently fixes:
      - INT4 classifier DequantizeLinear axis: ModelOpt sets axis=0 but the
        scale shape [1000, 8] is correct for axis=1 with block_size=128.
        ORT validates ceil(Di/block_size) on the declared axis and rejects it.

    Only paths containing an ``int4/`` directory component are inspected.
    Anything else is returned unchanged without loading the model.

    Args:
        onnx_path: Path (str or path-like) to the ONNX model.

    Returns:
        The path to use for inference. This is ``onnx_path`` itself when
        nothing changed. Otherwise it is a new temporary ``.onnx`` file
        holding the repaired model, which the caller should delete when done.
    """
    # Normalize separators so the directory check also matches Windows paths.
    if "int4/" not in os.fspath(onnx_path).replace(os.sep, "/"):
        return onnx_path
    model = onnx.load(onnx_path)
    # Build the initializer shape table once instead of rescanning per node.
    init_dims = {init.name: list(init.dims) for init in model.graph.initializer}
    fixed = False
    for node in model.graph.node:
        if node.op_type != "DequantizeLinear" or "classifier.weight" not in node.name:
            continue
        attrs = {attr.name: attr for attr in node.attribute}
        axis_attr = attrs.get("axis")
        block_attr = attrs.get("block_size")
        if axis_attr is None or axis_attr.i != 0 or block_attr is None:
            continue
        block_size = block_attr.i
        if block_size <= 0 or len(node.input) < 2:
            continue
        weight_shape = init_dims.get(node.input[0])
        scale_shape = init_dims.get(node.input[1])
        # Re-targeting to axis=1 only makes sense for weights of rank >= 2.
        if weight_shape is None or scale_shape is None or len(weight_shape) < 2:
            continue
        # Expected scale shape if axis were 1: [D0, ceil(D1/block_size), ...]
        expected_axis1 = list(weight_shape)
        expected_axis1[1] = (expected_axis1[1] + block_size - 1) // block_size
        if scale_shape == expected_axis1:
            axis_attr.i = 1
            fixed = True
            print(f" [repair] Fixed {node.name}: axis 0 -> 1 "
                  f"(scale {scale_shape} matches axis=1, block_size={block_size})")
    if not fixed:
        return onnx_path
    # Save the repaired model to a temp file; the caller owns its cleanup.
    fd, tmp_path = tempfile.mkstemp(suffix=".onnx", prefix="repaired_")
    os.close(fd)
    try:
        onnx.save(model, tmp_path)
    except BaseException:
        # Do not leave an empty or partial temp file behind.
        os.remove(tmp_path)
        raise
    print(f" [repair] Saved repaired model to {tmp_path}")
    return tmp_path
| # --------------------------------------------------------------------------- | |
| # ONNX model evaluation | |
| # --------------------------------------------------------------------------- | |
def _create_session(onnx_path):
    """Build an ORT session for ``onnx_path`` after repairing known model bugs.

    A temporary repaired copy produced by ``repair_onnx_model`` is deleted as
    soon as the session exists, so repeated runs don't accumulate model files
    in the temp directory.
    """
    repaired_path = repair_onnx_model(onnx_path)
    providers = []
    if "CUDAExecutionProvider" in ort.get_available_providers():
        providers.append(("CUDAExecutionProvider", {"device_id": 0}))
    providers.append("CPUExecutionProvider")
    session_options = ort.SessionOptions()
    session_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_DISABLE_ALL
    # Disable memory pattern/reuse optimizations that conflict with dim_param
    # (unknown) dimensions in quantized models. Without this, ORT pre-allocates
    # buffers based on incorrectly resolved shapes, causing runtime crashes
    # on FP16 (Add node) and FP8 (MatMul in decomposed SDPA).
    session_options.enable_mem_pattern = False
    session_options.enable_mem_reuse = False
    try:
        return ort.InferenceSession(repaired_path, sess_options=session_options, providers=providers)
    finally:
        if repaired_path != onnx_path:
            # ORT loads the model while constructing the session, so the
            # temporary repaired copy is no longer needed.
            try:
                os.remove(repaired_path)
            except OSError as e:
                print(f" [repair] Could not remove {repaired_path}: {e}")


def _session_input_dtype(session):
    """Return the NumPy dtype the session's first input is declared with.

    Half and double precision are special-cased; everything else maps to
    float32.
    """
    ort_type = session.get_inputs()[0].type
    return {"tensor(float16)": np.float16, "tensor(double)": np.float64}.get(ort_type, np.float32)


def evaluate_onnx(onnx_path, loader, num_classes, print_every=500):
    """Evaluate an ONNX model using ONNX Runtime with batch=1 (models have static shapes).

    Args:
        onnx_path: Path to the model; known INT4 export bugs are repaired first.
        loader: Iterable of ``(images, labels)`` batches. ``images`` is a
            torch tensor of shape (B, C, H, W).
        num_classes: Number of output classes, forwarded to ``compute_metrics``.
        print_every: Print throughput every this many batches (0/None disables).

    Returns:
        The ``compute_metrics`` dict plus ``total_images``, ``elapsed`` (seconds),
        ``avg_process_ms`` and ``avg_inference_ms`` (milliseconds per image).

    Raises:
        ValueError: if the loader yields no images.
    """
    session = _create_session(onnx_path)
    input_name = session.get_inputs()[0].name
    input_dtype = _session_input_dtype(session)
    all_logits = []
    all_labels = []
    total = 0
    total_inference_time = 0.0  # strict inference-only timing
    start = time.perf_counter()  # process timer (includes data prep, inference, output, metrics)
    for batch_idx, (imgs, labels) in enumerate(loader):
        # Match the dtype the model declares (e.g. float16 inputs for FP16 exports).
        batch = imgs.numpy().astype(input_dtype, copy=False)
        # Run one image at a time (model has static batch=1 in internal Reshape nodes)
        for i in range(batch.shape[0]):
            single_img = batch[i:i + 1]  # shape (1, C, H, W)
            t0 = time.perf_counter()
            outputs = session.run(None, {input_name: single_img})
            total_inference_time += time.perf_counter() - t0
            logits = outputs[0]
            if logits.dtype == np.float16:
                # Keep all models' logits in float32 for the metric computation.
                logits = logits.astype(np.float32)
            all_logits.append(logits)
        batch_labels = labels.numpy() if torch.is_tensor(labels) else np.asarray(labels)
        all_labels.append(batch_labels.reshape(-1))
        total += batch.shape[0]
        if print_every and (batch_idx + 1) % print_every == 0:
            speed = total / (time.perf_counter() - start)
            print(f" [{total:>6d} images] {speed:.1f} img/s")
    if total == 0:
        raise ValueError(f"No images were produced by the data loader for {onnx_path}")
    all_logits = np.concatenate(all_logits, axis=0)
    all_labels = np.concatenate(all_labels, axis=0)
    metrics = compute_metrics(all_logits, all_labels, num_classes)
    metrics["total_images"] = total
    elapsed = time.perf_counter() - start
    metrics["elapsed"] = elapsed
    metrics["avg_process_ms"] = elapsed / total * 1000
    metrics["avg_inference_ms"] = total_inference_time / total * 1000
    return metrics
| # --------------------------------------------------------------------------- | |
| # Main | |
| # --------------------------------------------------------------------------- | |
def _vit_preprocess(img, processor):
    """Turn one PIL image into a (C, H, W) pixel tensor with the HF processor.

    Module-level (not a closure) so DataLoader workers can pickle it under
    the ``spawn`` start method.
    """
    inputs = processor(images=img, return_tensors="pt")
    return inputs["pixel_values"].squeeze(0)  # (C, H, W)


def _print_model_metrics(metrics):
    """Print the headline metrics for one evaluated model."""
    print(f"\n Top-1 Accuracy: {metrics['top1']*100:.3f}%")
    print(f" Top-5 Accuracy: {metrics['top5']*100:.3f}%")
    print(f" mAP: {metrics['mAP']:.4f}")
    print(f" F1 (macro): {metrics['f1_macro']:.4f}")
    print(f" F1 (weighted): {metrics['f1_weighted']:.4f}")
    print(f" Time: {metrics['elapsed']:.1f}s")
    print(f" Avg Process: {metrics['avg_process_ms']:.2f}ms/img")
    print(f" Avg Inference: {metrics['avg_inference_ms']:.2f}ms/img")


def _print_comparison_table(results):
    """Print one aligned row per model, or its failure message."""
    rule = "=" * 100
    print(f"\n\n{rule}")
    print("Evaluation Comparison Table")
    print(rule)
    print(f" {'Model':<25s} {'Images':>7s} {'Top-1%':>8s} {'Top-5%':>8s} {'mAP':>8s} {'F1_mac':>8s} {'F1_wt':>8s} {'Proc(ms)':>9s} {'Inf(ms)':>8s} {'Time':>8s}")
    print(f" {'-'*25} {'-'*7} {'-'*8} {'-'*8} {'-'*8} {'-'*8} {'-'*8} {'-'*9} {'-'*8} {'-'*8}")
    for name, m in results.items():
        if "error" in m:
            print(f" {name:<25s} FAILED: {m['error']}")
            continue
        print(
            f" {name:<25s} "
            f"{m['total_images']:>7d} "
            f"{m['top1']*100:>8.3f} "
            f"{m['top5']*100:>8.3f} "
            f"{m['mAP']:>8.4f} "
            f"{m['f1_macro']:>8.4f} "
            f"{m['f1_weighted']:>8.4f} "
            f"{m['avg_process_ms']:>9.2f} "
            f"{m['avg_inference_ms']:>8.2f} "
            f"{m['elapsed']:>7.1f}s"
        )
    # NOTE(review): these reference numbers are attributed to a timm model
    # card; confirm they apply to google/vit-large-patch16-224.
    print("\n Reference (timm model card): Top-1: 82.346% | Top-5: 96.394%")
    print(rule)


def _report_best_int8(results, model_paths):
    """Print the successful INT8 variant with the highest top-1 accuracy, if any."""
    int8_results = {k: v for k, v in results.items() if k.startswith("INT8") and "error" not in v}
    if not int8_results:
        return
    best_int8 = max(int8_results, key=lambda k: int8_results[k]["top1"])
    print(f"\n Best INT8 model: {best_int8} ({model_paths[best_int8]})")
    print(f" Top-1: {int8_results[best_int8]['top1']*100:.3f}%")
    print(f" Top-5: {int8_results[best_int8]['top5']*100:.3f}%")


def main():
    """Evaluate the exported ViT-Large ONNX variants on ImageNet-1k validation.

    Command-line flags choose the batch size, DataLoader workers, an optional
    leading subset of images, and the quantization modes to run. Each selected
    model file that exists (paths are relative to the working directory) is
    evaluated and its metrics printed. A comparison table and the best INT8
    variant follow.
    """
    import functools
    import traceback

    from torch.utils.data import Subset

    parser = argparse.ArgumentParser(description="Evaluate quantized ONNX models")
    parser.add_argument("--batch_size", type=int, default=32)
    parser.add_argument("--num_workers", type=int, default=4)
    parser.add_argument("--subset", type=int, default=0, help="Evaluate on first N images (0=all)")
    ALL_MODES = ["fp32", "fp16", "int8", "fp8", "int4"]
    parser.add_argument(
        "--mode", type=str, nargs="*", default=ALL_MODES, choices=ALL_MODES,
        help=f"Quantization mode(s) to evaluate (default: all). Choices: {ALL_MODES}",
    )
    args = parser.parse_args()
    # A bare `--mode` with no values parses to []; treat that as "all".
    modes = args.mode or ALL_MODES

    # ------------------------------------------------------------------
    # Load model & processor
    # ------------------------------------------------------------------
    model_name = "google/vit-large-patch16-224"
    print(f"Loading {model_name} ...")
    processor = ViTImageProcessor.from_pretrained(model_name)
    input_size = (1, 3, 224, 224)  # (B, C, H, W)
    print(f" Input size: {input_size}")
    # Picklable transform built from the HF processor for use in DataLoader workers.
    transform = functools.partial(_vit_preprocess, processor=processor)
    num_classes = 1000

    # Load dataset from cached arrow shards
    arrow_dir = os.path.expanduser(
        "~/.cache/huggingface/datasets/Tsomaros___imagenet-1k_validation/"
        "default/0.0.0/55405c49dece42420e68ddd5f80174f19b29ebaf/"
    )
    print(f"Loading dataset from arrow shards: {arrow_dir}")
    dataset = ArrowImageNetDataset(arrow_dir, transform=transform)
    if args.subset > 0:
        dataset = Subset(dataset, range(min(args.subset, len(dataset))))
        # Report the effective size, which may be smaller than requested.
        print(f" Using subset: {len(dataset)} images")
    loader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.num_workers,
        pin_memory=True,
    )

    # Define models to evaluate
    models = {
        "FP32 (baseline)": "vit_large_patch16_224_fp32.onnx",
        "FP16": "fp16/vit_large_patch16_224_fp16.onnx",
        "INT8 entropy": "int8/vit_large_patch16_224_int8_entropy.onnx",
        "INT8 max": "int8/vit_large_patch16_224_int8_max.onnx",
        "FP8 entropy": "fp8/vit_large_patch16_224_fp8_entropy.onnx",
        "FP8 max": "fp8/vit_large_patch16_224_fp8_max.onnx",
        "INT4 awq_clip": "int4/vit_large_patch16_224_int4_awq_clip.onnx",
        "INT4 awq_lite (asym)": "int4/vit_large_patch16_224_int4_awq_lite_asym.onnx",
        "INT4 awq_lite (sym)": "int4/vit_large_patch16_224_int4_awq_lite.onnx",
        "INT4 awq_full": "int4/vit_large_patch16_224_int4_awq_full.onnx",
        "INT4 rtn_dq": "int4/vit_large_patch16_224_int4_rtn_dq.onnx",
    }
    # Filter by --mode selection
    mode_prefix = {"fp32": "FP32", "fp16": "FP16", "int8": "INT8", "fp8": "FP8", "int4": "INT4"}
    selected_prefixes = tuple(mode_prefix[m] for m in modes)
    models = {k: v for k, v in models.items() if k.startswith(selected_prefixes)}
    print(f"Evaluating modes: {modes}")

    # Filter to only existing files
    existing_models = {}
    for name, path in models.items():
        if os.path.exists(path):
            existing_models[name] = path
        else:
            print(f" SKIP: {name} — file not found: {path}")

    results = {}
    for name, onnx_path in existing_models.items():
        print(f"\n{'='*60}")
        print(f"Evaluating: {name}")
        print(f" Model: {onnx_path}")
        print(f"{'='*60}")
        try:
            metrics = evaluate_onnx(onnx_path, loader, num_classes)
        except Exception as e:
            # Keep going with the remaining models; the failure shows in the table.
            print(f" FAILED: {e}")
            traceback.print_exc()
            results[name] = {"error": str(e)}
            continue
        results[name] = metrics
        _print_model_metrics(metrics)

    _print_comparison_table(results)
    # Report (nothing is copied) the best INT8 model by top-1 accuracy.
    _report_best_int8(results, existing_models)
if __name__ == "__main__":
    # Run the evaluation only when executed as a script, not when imported.
    main()