"""
ONNX runtime wrappers.

Drop-in replacement for the PyTorch SigLIP + DINOv2 models inside
AIModelManager._embed_crops_batch.

Import pattern in ai_manager.py:

    from src.services.onnx_models import ONNXVisionStack

    if USE_ONNX_VISION:
        self.vision_stack = ONNXVisionStack(ONNX_MODELS_DIR, ONNX_USE_INT8)
        # use self.vision_stack.encode(crops) instead of torch models
"""

import os

import numpy as np
from PIL import Image
import onnxruntime as ort

# SigLIP normalization for siglip-base-patch16-224: mean = std = 0.5, which
# maps [0, 1] pixels to [-1, 1]. (These are not the ImageNet statistics.)
_SIGLIP_MEAN = np.array([0.5, 0.5, 0.5], dtype=np.float32)
_SIGLIP_STD = np.array([0.5, 0.5, 0.5], dtype=np.float32)

# DINOv2 uses ImageNet stats
_DINO_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
_DINO_STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)

# Side length (pixels) of the square input fed to both models.
_INPUT_SIZE = 224

# Input name used when the ONNX graph declares it (or declares no inputs).
_DEFAULT_INPUT_NAME = "pixel_values"

# Intra-op thread count used when OMP_NUM_THREADS is unset or unparsable.
_DEFAULT_INTRA_OP_THREADS = 2


def _intra_op_threads(default: int = _DEFAULT_INTRA_OP_THREADS) -> int:
    """Return the intra-op thread count derived from ``OMP_NUM_THREADS``.

    Falls back to *default* when the variable is unset, empty, malformed or
    negative instead of raising ``ValueError`` during model construction.
    """
    raw = os.getenv("OMP_NUM_THREADS", "")
    # OpenMP accepts a comma-separated list (one entry per nesting level);
    # only the outermost level is meaningful for a single thread pool.
    first = raw.split(",", 1)[0].strip()
    try:
        threads = int(first)
    except ValueError:
        return default
    return threads if threads >= 0 else default


def _image_input_name(session: "ort.InferenceSession") -> str:
    """Return the name of the input tensor to feed images into.

    ``"pixel_values"`` is used whenever the graph declares it; otherwise the
    first declared input is used, so exports with a different input name
    still run.
    """
    names = [inp.name for inp in session.get_inputs()]
    if _DEFAULT_INPUT_NAME in names or not names:
        return _DEFAULT_INPUT_NAME
    return names[0]


def _load_resized(pil_images: list[Image.Image], size: int) -> np.ndarray:
    """Convert to RGB, resize to size x size and scale pixels to [0, 1].

    Returns a (B, size, size, 3) float32 array; B is 0 for an empty input.
    """
    if not pil_images:
        return np.empty((0, size, size, 3), dtype=np.float32)
    frames = []
    for im in pil_images:
        if im.mode != "RGB":
            im = im.convert("RGB")
        im = im.resize((size, size), Image.BILINEAR)
        frames.append(np.asarray(im, dtype=np.float32))
    return np.stack(frames, axis=0) / 255.0


def _normalize_nchw(
    pixels: np.ndarray, mean: np.ndarray, std: np.ndarray
) -> np.ndarray:
    """Apply per-channel ``(x - mean) / std`` and reorder NHWC -> NCHW.

    *pixels* is the (B, H, W, 3) output of ``_load_resized``. The result is a
    C-contiguous float32 array of shape (B, 3, H, W).
    """
    normalized = (pixels - mean) / std
    # transpose() only returns a view; materialize a contiguous float32
    # buffer, matching what stacking per-image CHW arrays used to produce.
    return np.ascontiguousarray(
        normalized.transpose(0, 3, 1, 2), dtype=np.float32
    )


def _preprocess_batch(
    pil_images: list[Image.Image],
    size: int,
    mean: np.ndarray,
    std: np.ndarray,
) -> np.ndarray:
    """Resize + normalize a batch of PIL images to (B, 3, size, size) fp32.

    Args:
        pil_images: Images in any PIL mode; non-RGB images are converted.
        size: Target height and width in pixels.
        mean: Per-channel mean, shape (3,), applied after scaling to [0, 1].
        std: Per-channel standard deviation, shape (3,).

    Returns:
        A float32 array of shape (B, 3, size, size). An empty input yields an
        array with B == 0 rather than an error.
    """
    return _normalize_nchw(_load_resized(pil_images, size), mean, std)


def _l2_normalize(x: np.ndarray, axis: int = 1) -> np.ndarray:
    """Scale *x* to unit L2 norm along *axis*.

    All-zero vectors are returned unchanged instead of dividing by zero.
    """
    n = np.linalg.norm(x, axis=axis, keepdims=True)
    n = np.where(n == 0, 1.0, n)
    return x / n


class ONNXVisionStack:
    """SigLIP + DINOv2 fused embeddings via ONNX Runtime (CPU)."""

    def __init__(self, models_dir: str, use_int8: bool = True):
        """Load both ONNX models from *models_dir* and warm them up.

        Args:
            models_dir: Directory containing the exported ``.onnx`` files.
            use_int8: Load the ``*_int8.onnx`` variants when True, otherwise
                the non-quantized ``siglip_vision.onnx`` / ``dinov2.onnx``.

        Raises:
            FileNotFoundError: If either model file is missing.
        """
        siglip_name = "siglip_vision_int8.onnx" if use_int8 else "siglip_vision.onnx"
        dino_name = "dinov2_int8.onnx" if use_int8 else "dinov2.onnx"
        siglip_path = os.path.join(models_dir, siglip_name)
        dino_path = os.path.join(models_dir, dino_name)

        if not os.path.exists(siglip_path):
            raise FileNotFoundError(
                f"ONNX model not found: {siglip_path}. "
                "Run scripts/convert_to_onnx.py and upload outputs to the Space."
            )
        if not os.path.exists(dino_path):
            raise FileNotFoundError(f"ONNX model not found: {dino_path}")

        sess_opts = ort.SessionOptions()
        sess_opts.intra_op_num_threads = _intra_op_threads()
        sess_opts.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL

        # ONNX Runtime session for the SigLIP vision tower.
        self.siglip = ort.InferenceSession(
            siglip_path, sess_options=sess_opts, providers=["CPUExecutionProvider"]
        )
        # ONNX Runtime session for DINOv2.
        self.dino = ort.InferenceSession(
            dino_path, sess_options=sess_opts, providers=["CPUExecutionProvider"]
        )

        # Resolve the image input name once per session instead of assuming
        # that every export calls it "pixel_values".
        self._siglip_input = _image_input_name(self.siglip)
        self._dino_input = _image_input_name(self.dino)

        # Warmup — first call is always slow due to kernel compilation
        dummy = np.zeros((1, 3, _INPUT_SIZE, _INPUT_SIZE), dtype=np.float32)
        self.siglip.run(None, {self._siglip_input: dummy})
        self.dino.run(None, {self._dino_input: dummy})

    def encode(self, pil_crops: list[Image.Image]) -> list[np.ndarray]:
        """Returns list of 1536-d L2-normalized fused vectors (same shape as old code).

        Each crop is embedded by both models; the two embeddings are
        L2-normalized separately, concatenated, and normalized again.

        Args:
            pil_crops: Image crops to embed. An empty list returns ``[]``.

        Returns:
            One 1-D array per crop, in input order.
        """
        if not pil_crops:
            return []

        # Resize once and reuse the pixels for both models: only the
        # mean/std normalization differs between SigLIP and DINOv2.
        pixels = _load_resized(pil_crops, _INPUT_SIZE)
        sig_batch = _normalize_nchw(pixels, _SIGLIP_MEAN, _SIGLIP_STD)
        dino_batch = _normalize_nchw(pixels, _DINO_MEAN, _DINO_STD)

        # The first output of each session is taken as the embedding,
        # expected to be (B, 768) per model.
        sig_out = self.siglip.run(None, {self._siglip_input: sig_batch})[0]
        dino_out = self.dino.run(None, {self._dino_input: dino_batch})[0]

        sig_n = _l2_normalize(sig_out)
        dino_n = _l2_normalize(dino_out)
        fused = _l2_normalize(np.concatenate([sig_n, dino_n], axis=1))  # (B, 1536)
        # Iterating a 2-D array yields its rows, one vector per crop.
        return list(fused)