Spaces:
Runtime error
Runtime error
File size: 4,499 Bytes
982b011 5147546 982b011 5147546 982b011 5147546 982b011 5147546 982b011 5147546 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 |
import torchvision.models.feature_extraction
import torchvision
import os
import torch
import onnx
import onnxruntime
import numpy as np
from .config_extractor import MODEL_CONFIG
# Allow the process to continue when more than one copy of the Intel OpenMP
# runtime gets loaded (presumably via torch / numpy / onnxruntime, each of
# which may bundle its own).
# NOTE(review): this is assigned after those libraries are imported; confirm
# it still takes effect in the deployment environment.
os.environ.update({"KMP_DUPLICATE_LIB_OK": "True"})
class FeatureExtractor:
    """Extract image features with a pre-trained torchvision model run via ONNX Runtime.

    On construction, an ONNX export of the configured feature layer is loaded
    from ``onnx_path`` if that file exists; otherwise the PyTorch model is built,
    exported to ``onnx_path``, and the export is loaded. Inference always goes
    through the ONNX Runtime session.
    """

    def __init__(self, base_model, onnx_path=None):
        """Set up preprocessing and an ONNX Runtime session for ``base_model``.

        Args:
            base_model: str, key into MODEL_CONFIG naming the backbone.
            onnx_path: str or None, where the ONNX export is read from or
                written to. Defaults to
                ``model/{base_model}_feature_extractor.onnx``.

        Raises:
            ValueError: if ``base_model`` is not a key of MODEL_CONFIG.
        """
        # Validate before any config lookup so an unknown name raises the same
        # ValueError as init_model instead of a bare KeyError.
        if base_model not in MODEL_CONFIG:
            raise ValueError(f"Invalid base model: {base_model}")
        config = MODEL_CONFIG[base_model]

        self.base_model = base_model
        # number of features, as configured for this backbone
        self.feat_dims = config["feat_dims"]
        # name of the layer whose activations are extracted
        self.feat_layer = config["feat_layer"]

        if onnx_path is None:
            onnx_path = f"model/{base_model}_feature_extractor.onnx"
        self.onnx_path = onnx_path
        self.onnx_session = None
        # Only populated when the PyTorch model has to be built for export.
        self.model = None
        self.device = torch.device("cpu")

        if os.path.exists(onnx_path):
            print(f"Loading existing ONNX model from {onnx_path}")
            # Preprocessing only needs the weights' transforms; building the
            # full PyTorch model here would load pretrained weights for nothing.
            self.transforms = config["weights"].transforms()
            self.onnx_session = self._load_session(onnx_path)
            return

        print(
            f"ONNX model not found at {onnx_path}. Initializing PyTorch model and converting to ONNX..."
        )
        # Build the model once and reuse the transforms returned with it.
        self.model, self.transforms = self.init_model(base_model)
        self.model.eval()
        self.model.to(self.device)
        # dirname is "" for a bare filename, and os.makedirs("") raises.
        save_dir = os.path.dirname(onnx_path)
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)
        self.convert_to_onnx(onnx_path)
        self.onnx_session = self._load_session(onnx_path)
        print(f"Successfully created and loaded ONNX model from {onnx_path}")

    @staticmethod
    def _load_session(onnx_path):
        """Open an ONNX Runtime inference session for ``onnx_path`` on the CPU."""
        # Name the provider explicitly: since onnxruntime 1.9, builds that ship
        # several execution providers refuse to create a session without one,
        # and this class already runs its PyTorch side on the CPU.
        return onnxruntime.InferenceSession(
            onnx_path, providers=["CPUExecutionProvider"]
        )

    def init_model(self, base_model):
        """Initialize the model for feature extraction.

        Args:
            base_model: str, the name of the base model (a MODEL_CONFIG key).

        Returns:
            model: torch.nn.Module, the feature extraction model.
            transforms: the preprocessing transforms of the configured weights.

        Raises:
            ValueError: if ``base_model`` is not a key of MODEL_CONFIG.
        """
        if base_model not in MODEL_CONFIG:
            raise ValueError(f"Invalid base model: {base_model}")
        weights = MODEL_CONFIG[base_model]["weights"]
        # Wrap the backbone so its forward pass returns only the configured
        # layer's activations.
        model = torchvision.models.feature_extraction.create_feature_extractor(
            MODEL_CONFIG[base_model]["model"](weights=weights),
            [MODEL_CONFIG[base_model]["feat_layer"]],
        )
        transforms = weights.transforms()
        return model, transforms

    def extract_features(self, img):
        """Extract features from a single image.

        Args:
            img: PIL.Image, the input image (anything ``self.transforms`` accepts).

        Returns:
            torch.Tensor, the first output of the ONNX graph for a batch of one
            image.
        """
        # Preprocess, then add the batch dimension.
        x = self.transforms(img).unsqueeze(0)
        # Feed by the session's own input name instead of assuming "input", so
        # ONNX files exported with a different input name still work.
        input_name = self.onnx_session.get_inputs()[0].name
        print("Running inference with ONNX Runtime")
        output = self.onnx_session.run(None, {input_name: x.numpy()})[0]
        return torch.from_numpy(output)

    def _input_hw(self):
        """Return the (height, width) produced by the preprocessing transforms.

        torchvision's classification presets expose ``crop_size`` (e.g. ``[224]``).
        Fall back to the previously hard-coded 224x224 when it is unavailable.
        """
        crop = getattr(getattr(self, "transforms", None), "crop_size", None)
        if isinstance(crop, int):
            return crop, crop
        if isinstance(crop, (list, tuple)) and crop:
            return crop[0], crop[-1]
        return 224, 224

    def convert_to_onnx(self, save_path):
        """Export ``self.model`` to ONNX at ``save_path`` and validate the file.

        The exported graph has one input named "input" and one output named
        "output", both with a dynamic batch dimension.

        Args:
            save_path: str, the path to save the ONNX model.

        Raises:
            RuntimeError: if no PyTorch model is loaded (e.g. the instance was
                created from an existing ONNX file).
        """
        model = getattr(self, "model", None)
        if model is None:
            raise RuntimeError("No PyTorch model is loaded; cannot export to ONNX")
        device = getattr(self, "device", torch.device("cpu"))
        # Only the batch axis is dynamic, so the dummy's spatial size becomes
        # fixed in the graph; it must match what the preprocessing produces.
        height, width = self._input_hw()
        dummy_input = torch.randn(1, 3, height, width, device=device)
        torch.onnx.export(
            model,
            dummy_input,
            save_path,
            export_params=True,
            opset_version=14,
            do_constant_folding=True,
            input_names=["input"],
            output_names=["output"],
            dynamic_axes={
                "input": {0: "batch_size"},
                "output": {0: "batch_size"},
            },
        )
        # Verify the exported model
        onnx_model = onnx.load(save_path)
        onnx.checker.check_model(onnx_model)
        print(f"ONNX model saved to {save_path}")
|