File size: 2,829 Bytes
5ce8318
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
import os
import onnxruntime
import numpy as np
import PIL
# NOTE: permits duplicate OpenMP runtimes in one process (Intel libiomp5
# workaround) -- avoids the "OMP: Error #15" abort when e.g. numpy/MKL and
# onnxruntime each load their own OpenMP. Presumably needed on this setup;
# harmless but masks the underlying duplicate-runtime condition.
os.environ["KMP_DUPLICATE_LIB_OK"] = "True"
def transforms(image):
    """
    Preprocess a PIL image for a torchvision-style ImageNet model.

    Steps:
    - Force 3-channel RGB (handles grayscale / RGBA inputs)
    - Resize so that the shorter side is 256, preserving aspect ratio
    - Center crop to 224x224
    - Scale pixel values to [0, 1]
    - Normalize with the ImageNet channel mean/std
    - Convert HWC to CHW layout

    Args:
        image: PIL.Image.Image, the input image.

    Returns:
        np.ndarray of shape (3, 224, 224), dtype float32.
    """
    # Step 0: Ensure 3 channels. Without this, np.array() on a grayscale
    # image yields a 2-D array (and RGBA yields 4 channels), breaking the
    # per-channel normalization below. convert("RGB") is lossless for RGB.
    image = image.convert("RGB")

    # Step 1: Resize so the shorter side becomes 256.
    w, h = image.size
    if h < w:
        new_h, new_w = 256, int(w * 256 / h)
    else:
        new_h, new_w = int(h * 256 / w), 256
    image = image.resize((new_w, new_h), PIL.Image.BILINEAR)

    # Step 2: Center crop to 224x224.
    left = (image.width - 224) // 2
    top = (image.height - 224) // 2
    image = image.crop((left, top, left + 224, top + 224))

    # Step 3: Convert to NumPy and scale to [0, 1].
    image_np = np.array(image).astype(np.float32) / 255.0

    # Step 4: Normalize with ImageNet statistics. dtype=float32 keeps the
    # result float32 (float64 constants would silently promote the array).
    mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
    std = np.array([0.229, 0.224, 0.225], dtype=np.float32)
    image_np = (image_np - mean) / std

    # Step 5: HWC -> CHW, the layout expected by torchvision-style models.
    return np.transpose(image_np, (2, 0, 1))

class FeatureExtractor:
    """Extract feature vectors from images using a pre-trained ONNX model."""

    def __init__(self, base_model, onnx_path=None):
        """Load an exported ONNX feature extractor.

        Args:
            base_model: str, backbone name; used to build the default path.
            onnx_path: str or None. Path to the ONNX model file. Defaults to
                "model/{base_model}_feature_extractor.onnx".

        Raises:
            FileNotFoundError: if no ONNX model exists at onnx_path.
        """
        if onnx_path is None:
            onnx_path = f"model/{base_model}_feature_extractor.onnx"

        self.onnx_path = onnx_path
        self.onnx_session = None
        if os.path.exists(onnx_path):
            print(f"Loading existing ONNX model from {onnx_path}")
            # Stat the file once here; extract_features reuses the cached
            # size instead of hitting the filesystem on every call.
            self.model_size_mb = os.path.getsize(onnx_path) / (1024 * 1024)
            print(f"Model size: {self.model_size_mb:.2f} MB")
            self.onnx_session = onnxruntime.InferenceSession(onnx_path)
        else:
            # The previous code only printed a promise to convert a PyTorch
            # model here but never did, leaving onnx_session = None and a
            # confusing AttributeError on first use. Fail fast instead.
            raise FileNotFoundError(
                f"ONNX model not found at {onnx_path}. Export the "
                f"{base_model} feature extractor to ONNX first."
            )

    def extract_features(self, img):
        """Extract features from an image.

        Args:
            img: PIL.Image, the input image.

        Returns:
            np.ndarray: the extracted features (first output of the model).
        """
        # Preprocess and add a batch dim: (3, 224, 224) -> (1, 3, 224, 224).
        x = np.expand_dims(transforms(img), axis=0)

        # NOTE(review): assumes the exported graph takes float16 input bound
        # to the name 'input' -- confirm against the ONNX export code.
        x_numpy = x.astype(np.float16)
        print(f"Running inference with ONNX model (size: {self.model_size_mb:.2f} MB)")
        return self.onnx_session.run(
            None,
            {'input': x_numpy}
        )[0]