Spaces:
Sleeping
Sleeping
# Standard library
import os

# Third-party
import numpy as np
import onnx
import onnxruntime
import torch
import torchvision
import torchvision.models.feature_extraction
from onnxconverter_common import float16

# Project-local
from src.modules.config_extractor import MODEL_CONFIG

# Allow duplicate OpenMP runtimes (torch + onnxruntime each ship libiomp);
# without this the process can abort with "OMP: Error #15" on some platforms.
os.environ["KMP_DUPLICATE_LIB_OK"] = "True"
class FeatureExtractor:
    """Extract image features using a pre-trained torchvision backbone via ONNX.

    On construction, an existing ONNX model is loaded from ``onnx_path`` if
    present. Otherwise the PyTorch backbone is built, exported to ONNX,
    converted to float16, cached at ``onnx_path``, and loaded with ONNX
    Runtime. Subsequent constructions reuse the cached ONNX file and never
    instantiate the PyTorch model.
    """

    def __init__(self, base_model, onnx_path=None):
        """Initialize the extractor, creating or loading the ONNX model.

        Args:
            base_model: str, key into MODEL_CONFIG selecting the backbone.
            onnx_path: str or None, path of the cached ONNX model. Defaults
                to ``model/{base_model}_feature_extractor.onnx``.

        Raises:
            ValueError: if ``base_model`` is not a key in MODEL_CONFIG
                (raised from init_model).
        """
        # set the base model
        self.base_model = base_model
        # number of output feature dimensions of the chosen backbone
        self.feat_dims = MODEL_CONFIG[base_model]["feat_dims"]
        # name of the layer whose activations are extracted
        self.feat_layer = MODEL_CONFIG[base_model]["feat_layer"]
        # Set default ONNX path if not provided
        if onnx_path is None:
            onnx_path = f"model/{base_model}_feature_extractor.onnx"
        self.onnx_path = onnx_path
        self.onnx_session = None
        # Transforms are needed for both the ONNX and the PyTorch path.
        _, self.transforms = self.init_model(base_model)
        if os.path.exists(onnx_path):
            print(f"Loading existing ONNX model from {onnx_path}")
            # Cache the size once so extract_features need not stat the
            # file on every inference call.
            self._model_size_mb = os.path.getsize(onnx_path) / (1024 * 1024)
            print(f"Model size: {self._model_size_mb:.2f} MB")
            self.onnx_session = onnxruntime.InferenceSession(onnx_path)
        else:
            print(
                f"ONNX model not found at {onnx_path}. Initializing PyTorch model and converting to ONNX..."
            )
            # Initialize PyTorch model (CPU only; export does not need a GPU)
            self.model, _ = self.init_model(base_model)
            self.model.eval()
            self.device = torch.device("cpu")
            self.model.to(self.device)
            # Bug fix: os.makedirs("") raises FileNotFoundError when the
            # path has no directory component, so only create it if present.
            out_dir = os.path.dirname(onnx_path)
            if out_dir:
                os.makedirs(out_dir, exist_ok=True)
            # Convert to ONNX and load the newly created model
            self.convert_to_onnx(onnx_path)
            self.onnx_session = onnxruntime.InferenceSession(onnx_path)
            self._model_size_mb = os.path.getsize(onnx_path) / (1024 * 1024)
            print(f"Successfully created and loaded ONNX model from {onnx_path}")
            print(f"Model size: {self._model_size_mb:.2f} MB")

    def init_model(self, base_model):
        """Initialize the model for feature extraction.

        Args:
            base_model: str, the name of the base model (a MODEL_CONFIG key)

        Returns:
            model: torch.nn.Module, the feature extraction model
            transforms: torchvision.transforms.Compose, the image transformations

        Raises:
            ValueError: if ``base_model`` is not a key in MODEL_CONFIG.
        """
        if base_model not in MODEL_CONFIG:
            raise ValueError(f"Invalid base model: {base_model}")
        # Build a sub-network that returns only the configured feature layer.
        weights = MODEL_CONFIG[base_model]["weights"]
        model = torchvision.models.feature_extraction.create_feature_extractor(
            MODEL_CONFIG[base_model]["model"](weights=weights),
            [MODEL_CONFIG[base_model]["feat_layer"]],
        )
        # The weights object carries the matching preprocessing transforms.
        transforms = weights.transforms()
        return model, transforms

    def extract_features(self, img):
        """Extract features from an image.

        Args:
            img: PIL.Image, the input image

        Returns:
            output: torch.Tensor, the extracted features (float16, since the
                ONNX model weights are converted to float16)
        """
        # apply transformations, then add the batch dimension
        x = self.transforms(img)
        x = x.unsqueeze(0)
        # The exported ONNX graph is float16, so the input must match.
        x_numpy = x.numpy().astype(np.float16)
        # Use the size cached at construction instead of stat-ing the file
        # on every call (fall back to a fresh stat if not yet cached).
        model_size = getattr(
            self, "_model_size_mb", None
        ) or os.path.getsize(self.onnx_path) / (1024 * 1024)
        print(f"Running inference with ONNX model (size: {model_size:.2f} MB)")
        output = self.onnx_session.run(
            None,
            {'input': x_numpy}
        )[0]
        # Convert back to torch tensor for downstream consumers
        output = torch.from_numpy(output)
        return output

    def convert_to_onnx(self, save_path):
        """Convert the PyTorch model to ONNX format (float16) and save it.

        Exports the model at FP32 first, validates the graph, then converts
        all floats to float16 in place at ``save_path``.

        Args:
            save_path: str, the path to save the ONNX model

        Returns:
            None
        """
        # Dummy input fixes the export trace; batch axis is kept dynamic.
        dummy_input = torch.randn(1, 3, 224, 224, device=self.device)
        # Export the model to ONNX (FP32 first)
        torch.onnx.export(
            self.model,
            dummy_input,
            save_path,
            export_params=True,
            opset_version=14,
            do_constant_folding=True,
            input_names=['input'],
            output_names=['output'],
            dynamic_axes={
                'input': {0: 'batch_size'},
                'output': {0: 'batch_size'}
            }
        )
        # Load and verify the exported model before converting precision
        print("Converting model to float16...")
        onnx_model = onnx.load(save_path)
        onnx.checker.check_model(onnx_model)
        # Convert to float16. NOTE(review): inputs/outputs become float16 as
        # well (no keep_io_types), which extract_features relies on.
        model_fp16 = float16.convert_float_to_float16(onnx_model)
        # Save the float16 model, overwriting the FP32 export
        onnx.save(model_fp16, save_path)
        print(f"Float16 ONNX model saved to {save_path}")