# Triventure-AI / src/modules/feature_extractor.py
# (uploaded by ABAO77 — "Upload 37 files", commit 5ce8318 verified)
import os
import onnxruntime
import numpy as np
import PIL
# NOTE(review): presumably works around the "duplicate OpenMP runtime"
# abort when onnxruntime and numpy each bundle libiomp — must be set
# before the runtimes initialize; confirm it is still needed.
os.environ["KMP_DUPLICATE_LIB_OK"] = "True"
def transforms(image):
    """Preprocess a PIL image for a torchvision-style ImageNet model.

    Steps:
      - Convert to RGB (so grayscale / RGBA inputs also work)
      - Resize so that the shorter side is 256 px (bilinear)
      - Center-crop to 224x224
      - Scale pixel values to [0, 1]
      - Normalize with the ImageNet per-channel mean/std
      - Reorder HWC -> CHW

    Args:
        image: PIL.Image.Image, the input image (any mode).

    Returns:
        np.ndarray of shape (3, 224, 224), dtype float32.
    """
    # Ensure exactly 3 channels: np.array() on a mode-"L" image would give
    # a 2-D array (and "RGBA" a 4-channel one), breaking the per-channel
    # normalization and the (2, 0, 1) transpose below.
    image = image.convert("RGB")

    # Step 1: resize so the shorter side becomes 256, keeping aspect ratio.
    w, h = image.size
    if h < w:
        new_h, new_w = 256, int(w * 256 / h)
    else:
        new_h, new_w = int(h * 256 / w), 256
    image = image.resize((new_w, new_h), PIL.Image.BILINEAR)

    # Step 2: center crop to 224x224.
    left = (image.width - 224) // 2
    top = (image.height - 224) // 2
    image = image.crop((left, top, left + 224, top + 224))

    # Step 3: convert to float32 in [0, 1].
    image_np = np.array(image).astype(np.float32) / 255.0

    # Step 4: ImageNet normalization. The constants are pinned to float32:
    # float64 defaults would silently promote the whole array to float64.
    mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
    std = np.array([0.229, 0.224, 0.225], dtype=np.float32)
    image_np = (image_np - mean) / std

    # Step 5: HWC -> CHW, the layout the model expects.
    return np.transpose(image_np, (2, 0, 1))
class FeatureExtractor:
    """Extract feature vectors from images using a pre-trained ONNX model."""

    def __init__(self, base_model, onnx_path=None):
        """Load an ONNX feature-extractor session from disk.

        Args:
            base_model: str, backbone name; used to derive the default
                ONNX file path when ``onnx_path`` is not given.
            onnx_path: str or None, explicit path to the ``.onnx`` file.
        """
        if onnx_path is None:
            onnx_path = f"model/{base_model}_feature_extractor.onnx"
        self.onnx_path = onnx_path
        # Inference session; stays None when the model file is absent.
        self.onnx_session = None
        if os.path.exists(onnx_path):
            print(f"Loading existing ONNX model from {onnx_path}")
            # Report model size in MB.
            model_size = os.path.getsize(onnx_path) / (1024 * 1024)
            print(f"Model size: {model_size:.2f} MB")
            self.onnx_session = onnxruntime.InferenceSession(onnx_path)
        else:
            # NOTE(review): the advertised PyTorch -> ONNX conversion is not
            # implemented in this chunk; the session remains None and
            # extract_features() raises a clear error instead of an obscure
            # AttributeError on NoneType.
            print(
                f"ONNX model not found at {onnx_path}. Initializing PyTorch model and converting to ONNX..."
            )

    def extract_features(self, img):
        """Extract features from an image.

        Args:
            img: PIL.Image, the input image.

        Returns:
            np.ndarray: the first model output (the extracted features).
            (The original docstring claimed torch.Tensor; the code returns
            a numpy array from ONNX Runtime.)

        Raises:
            RuntimeError: if no ONNX session was loaded in ``__init__``.
        """
        if self.onnx_session is None:
            raise RuntimeError(
                f"No ONNX session available; model file missing at {self.onnx_path}"
            )
        # Apply preprocessing, then add the batch dimension -> (1, C, H, W).
        x = transforms(img)
        x = np.expand_dims(x, axis=0)
        # Cast to float16 — presumably the exported model was converted to
        # half precision; TODO confirm against the export step.
        x_numpy = x.astype(np.float16)
        # Report model size in MB alongside the inference log line.
        model_size = os.path.getsize(self.onnx_path) / (1024 * 1024)
        print(f"Running inference with ONNX model (size: {model_size:.2f} MB)")
        output = self.onnx_session.run(
            None,
            {'input': x_numpy}
        )[0]
        return output