import os

import numpy as np
import onnx
import onnxruntime
import torch
import torchvision
import torchvision.models.feature_extraction

from .config_extractor import MODEL_CONFIG

# Allow duplicate OpenMP runtimes to coexist (commonly needed when torch and
# another library each ship their own libiomp).
os.environ["KMP_DUPLICATE_LIB_OK"] = "True"

# Spatial size used for the export dummy input when the preprocessing
# transform does not advertise a crop size.
_DEFAULT_INPUT_SIZE = 224


class FeatureExtractor:
    """Class for extracting features from images using a pre-trained model.

    Inference always runs through ONNX Runtime. On construction, an ONNX
    export of the feature extractor is loaded from ``onnx_path``. If that
    file does not exist, the PyTorch model is built, exported to
    ``onnx_path``, and the fresh export is loaded.

    Attributes:
        base_model: key into ``MODEL_CONFIG``.
        feat_dims: ``MODEL_CONFIG[base_model]["feat_dims"]``.
        feat_layer: name of the layer whose activations are extracted.
        onnx_path: location of the ONNX export.
        onnx_session: ``onnxruntime.InferenceSession`` used for inference.
        transforms: preprocessing transform from the configured weights.
        model: PyTorch feature extractor, or ``None`` when an existing ONNX
            file was loaded and no export was required.
        device: torch device used for the PyTorch model (always CPU).
    """

    def __init__(self, base_model, onnx_path=None):
        """Set up preprocessing and an ONNX Runtime session for ``base_model``.

        Args:
            base_model: str, key into ``MODEL_CONFIG``.
            onnx_path: str or None, where the ONNX export lives. Defaults to
                ``model/{base_model}_feature_extractor.onnx``.

        Raises:
            ValueError: if ``base_model`` is not in ``MODEL_CONFIG``.
        """
        # Validate first: the config lookups below would otherwise raise a
        # bare KeyError before init_model's ValueError could ever fire.
        if base_model not in MODEL_CONFIG:
            raise ValueError(f"Invalid base model: {base_model}")

        config = MODEL_CONFIG[base_model]
        self.base_model = base_model
        self.feat_dims = config["feat_dims"]
        self.feat_layer = config["feat_layer"]

        # Set default ONNX path if not provided
        if onnx_path is None:
            onnx_path = f"model/{base_model}_feature_extractor.onnx"
        self.onnx_path = onnx_path
        self.onnx_session = None

        # The PyTorch model is only materialised when an export is needed.
        self.model = None
        self.device = torch.device("cpu")

        # Preprocessing is needed on both paths. Take it straight from the
        # weights enum (exactly what init_model returns) so the pretrained
        # network is not loaded just to obtain the transforms.
        self.transforms = config["weights"].transforms()

        if os.path.exists(onnx_path):
            print(f"Loading existing ONNX model from {onnx_path}")
            self.onnx_session = onnxruntime.InferenceSession(onnx_path)
            return

        print(
            f"ONNX model not found at {onnx_path}. Initializing PyTorch model and converting to ONNX..."
        )
        # Build the PyTorch model once (the original built it twice).
        self.model, _ = self.init_model(base_model)
        self.model.eval()
        self.model.to(self.device)

        # dirname() is "" for a bare filename; os.makedirs("") would raise.
        out_dir = os.path.dirname(onnx_path)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)

        self.convert_to_onnx(onnx_path)

        # Load the newly created ONNX model
        self.onnx_session = onnxruntime.InferenceSession(onnx_path)
        print(f"Successfully created and loaded ONNX model from {onnx_path}")

    def init_model(self, base_model):
        """Initialize the model for feature extraction.

        Args:
            base_model: str, the name of the base model (key into MODEL_CONFIG).

        Returns:
            model: torch.nn.Module. The backbone wrapped by torchvision's
                ``create_feature_extractor`` with the configured feature layer
                as its only return node.
            transforms: the preprocessing transform from ``weights.transforms()``.

        Raises:
            ValueError: if ``base_model`` is not in ``MODEL_CONFIG``.
        """
        if base_model not in MODEL_CONFIG:
            raise ValueError(f"Invalid base model: {base_model}")

        config = MODEL_CONFIG[base_model]
        weights = config["weights"]
        backbone = config["model"](weights=weights)
        model = torchvision.models.feature_extraction.create_feature_extractor(
            backbone,
            [config["feat_layer"]],
        )

        transforms = weights.transforms()
        return model, transforms

    def _export_input_hw(self):
        """Return the (height, width) the preprocessing transform produces.

        Uses the transform's ``crop_size`` attribute when present, falling back
        to ``_DEFAULT_INPUT_SIZE`` otherwise.
        """
        crop = getattr(self.transforms, "crop_size", None)
        if not crop:
            return _DEFAULT_INPUT_SIZE, _DEFAULT_INPUT_SIZE
        if isinstance(crop, int):
            return crop, crop
        if len(crop) == 1:
            return int(crop[0]), int(crop[0])
        return int(crop[0]), int(crop[1])

    def extract_features(self, img):
        """Extract features from an image.

        Args:
            img: PIL.Image, the input image (anything ``self.transforms``
                accepts).

        Returns:
            torch.Tensor: the first output of the ONNX graph for a batch of
            one image. NOTE(review): exact feature shape depends on the
            configured layer — confirm against ``feat_dims``.
        """
        # apply transformations and add batch dimension
        batch = self.transforms(img).unsqueeze(0)

        # Read the graph's input name rather than assuming 'input', so ONNX
        # files exported elsewhere still work.
        input_name = self.onnx_session.get_inputs()[0].name

        print("Running inference with ONNX Runtime")
        output = self.onnx_session.run(None, {input_name: batch.numpy()})[0]

        # Convert back to torch tensor
        return torch.from_numpy(output)

    def convert_to_onnx(self, save_path):
        """Convert ``self.model`` to ONNX format, save it, and validate it.

        The dummy input matches the spatial size produced by
        ``self.transforms``. Only the batch axis is exported as dynamic, so a
        mismatched size would make the graph reject real inputs.

        Args:
            save_path: str, the path to save the ONNX model.

        Returns:
            None
        """
        height, width = self._export_input_hw()
        dummy_input = torch.randn(1, 3, height, width, device=self.device)

        torch.onnx.export(
            self.model,
            dummy_input,
            save_path,
            export_params=True,
            opset_version=14,
            do_constant_folding=True,
            input_names=['input'],
            output_names=['output'],
            dynamic_axes={
                'input': {0: 'batch_size'},
                'output': {0: 'batch_size'},
            },
        )

        # Verify the exported model
        onnx_model = onnx.load(save_path)
        onnx.checker.check_model(onnx_model)
        print(f"ONNX model saved to {save_path}")