Spaces:
Build error
Build error
import os

import numpy as np
import onnxruntime
# `import PIL` alone does not reliably expose the Image submodule;
# the code below references PIL.Image.BILINEAR, so import it explicitly.
import PIL.Image

# Work around "duplicate OpenMP runtime" aborts when onnxruntime and other
# native libraries each bundle their own OpenMP. NOTE(review): this is the
# usual reason for setting KMP_DUPLICATE_LIB_OK — confirm it is still needed.
os.environ["KMP_DUPLICATE_LIB_OK"] = "True"
def transforms(image):
    """Preprocess a PIL image for a torchvision-style ImageNet model.

    Steps:
      - Convert to 3-channel RGB (guards against grayscale/RGBA inputs,
        whose arrays would not broadcast against the 3-channel mean/std).
      - Resize so the shorter side is 256 (bilinear).
      - Center-crop to 224x224.
      - Scale pixel values to [0, 1] and normalize with ImageNet mean/std.
      - Transpose HWC -> CHW.

    Args:
        image: PIL.Image.Image input.

    Returns:
        np.ndarray of shape (3, 224, 224), dtype float32.
    """
    # Ensure exactly 3 channels so the per-channel normalization below is valid.
    image = image.convert("RGB")

    # Step 1: resize so the shorter side becomes 256, preserving aspect ratio.
    w, h = image.size
    if h < w:
        new_h, new_w = 256, int(w * 256 / h)
    else:
        new_h, new_w = int(h * 256 / w), 256
    image = image.resize((new_w, new_h), PIL.Image.BILINEAR)

    # Step 2: center crop to 224x224.
    left = (image.width - 224) // 2
    top = (image.height - 224) // 2
    image = image.crop((left, top, left + 224, top + 224))

    # Step 3: to float32 in [0, 1].
    image_np = np.array(image).astype(np.float32) / 255.0

    # Step 4: ImageNet normalization. float32 constants keep the result
    # float32 — plain np.array([...]) would be float64 and silently promote
    # the whole array.
    mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
    std = np.array([0.229, 0.224, 0.225], dtype=np.float32)
    image_np = (image_np - mean) / std

    # Step 5: HWC -> CHW, the layout torchvision-style models expect.
    return np.transpose(image_np, (2, 0, 1))
class FeatureExtractor:
    """Extracts image features with a pre-trained model served via ONNX Runtime."""

    def __init__(self, base_model, onnx_path=None):
        """Load an existing ONNX feature extractor.

        Args:
            base_model: Name of the base model; used to build the default
                ONNX path ``model/{base_model}_feature_extractor.onnx``.
            onnx_path: Optional explicit path to the ONNX file.

        Raises:
            FileNotFoundError: if the ONNX file does not exist. The original
                code printed a "converting to ONNX" message here but never
                performed a conversion, leaving ``onnx_session = None`` and
                deferring the failure to a confusing AttributeError at
                inference time — failing fast is clearer.
        """
        if onnx_path is None:
            onnx_path = f"model/{base_model}_feature_extractor.onnx"
        self.onnx_path = onnx_path
        self.onnx_session = None

        if not os.path.exists(onnx_path):
            raise FileNotFoundError(
                f"ONNX model not found at {onnx_path}. Export the "
                f"'{base_model}' feature extractor to ONNX before "
                f"constructing FeatureExtractor."
            )

        print(f"Loading existing ONNX model from {onnx_path}")
        # Report model size in MB.
        model_size = os.path.getsize(onnx_path) / (1024 * 1024)
        print(f"Model size: {model_size:.2f} MB")
        self.onnx_session = onnxruntime.InferenceSession(onnx_path)

    def extract_features(self, img):
        """Extract features from an image.

        Args:
            img: PIL.Image.Image, the input image.

        Returns:
            np.ndarray: the model's first output for a batch of one image.
                (The original docstring claimed ``torch.Tensor``; ONNX
                Runtime returns numpy arrays.)
        """
        # Preprocess and add the batch dimension -> (1, C, H, W).
        x = transforms(img)
        x = np.expand_dims(x, axis=0)

        # Match the model's declared input name and precision instead of
        # hard-coding 'input' and float16: the original cast to float16
        # unconditionally, which fails against a float32-exported model.
        inp = self.onnx_session.get_inputs()[0]
        dtype = np.float16 if "float16" in (inp.type or "") else np.float32
        x = x.astype(dtype)

        model_size = os.path.getsize(self.onnx_path) / (1024 * 1024)
        print(f"Running inference with ONNX model (size: {model_size:.2f} MB)")
        return self.onnx_session.run(None, {inp.name: x})[0]