ABAO77 committed on
Commit
4dc9354
·
verified ·
1 Parent(s): 55ecbbd

Upload 3 files

Browse files
Files changed (3) hide show
  1. app.py +30 -21
  2. feature_extractor.py +140 -0
  3. requirements.txt +3 -1
app.py CHANGED
@@ -39,19 +39,24 @@ app.add_middleware(
39
  allow_headers=["*"],
40
  )
41
 
 
42
  index_path = "./model/db_vit_b_16.index"
 
43
 
 
44
  if not os.path.exists(index_path):
45
  raise FileNotFoundError(f"Index file not found: {index_path}")
46
 
47
  try:
 
48
  index = faiss.read_index(index_path)
49
- except RuntimeError as e:
50
- raise RuntimeError(f"Error reading FAISS index: {e}")
51
- feature_extractor = FeatureExtractor(base_model="vit_b_16")
52
 
53
- if torch.backends.mps.is_built():
54
- torch.set_default_device("mps")
 
 
 
55
 
56
 
57
  def base64_to_image(base64_str: str) -> Image.Image:
@@ -69,12 +74,6 @@ def image_to_base64(image: Image.Image) -> str:
69
  return base64.b64encode(buffered.getvalue()).decode("utf-8")
70
 
71
 
72
- def image_to_base64(image: Image.Image) -> str:
73
- buffered = BytesIO()
74
- image.save(buffered, format="JPEG")
75
- return base64.b64encode(buffered.getvalue()).decode("utf-8")
76
-
77
-
78
  def unzip_folder(zip_file_path, extract_to_path):
79
  if not os.path.exists(zip_file_path):
80
  raise FileNotFoundError(f"Zip file not found: {zip_file_path}")
@@ -105,20 +104,26 @@ class ImageSearchBody(BaseModel):
105
  @app.post("/search-image/")
106
  async def search_image(body: ImageSearchBody):
107
  try:
 
108
  image = base64_to_image(body.base64_image)
109
- with torch.no_grad():
110
- output = feature_extractor.extract_features(image)
111
- output = output.view(output.size(0), -1)
112
- output = output / output.norm(p=2, dim=1, keepdim=True)
113
- D, I = index.search(output.cpu().numpy(), 1)
114
- print(D, I)
 
 
 
 
 
 
115
  image_list = sorted([f for f in os.listdir(extract_path) if is_image_file(f)])
116
- print(image_list)
117
  image_name = image_list[int(I[0][0])]
118
- matched_image_path = f"{extract_path}/{image_list[int(I[0][0])]}"
119
  matched_image = Image.open(matched_image_path)
120
  matched_image_base64 = image_to_base64(matched_image)
121
-
122
  return JSONResponse(
123
  content={
124
  "image_base64": matched_image_base64,
@@ -129,11 +134,15 @@ async def search_image(body: ImageSearchBody):
129
  )
130
 
131
  except Exception as e:
132
- return JSONResponse(content={"error": str(e)}, status_code=500)
 
 
 
133
 
134
 
135
  from src.firebase.firebase_provider import process_images
136
 
 
137
  class Body(BaseModel):
138
  base64_image: list[str] = Field(..., title="Base64 Image String")
139
  model_config = {
 
39
  allow_headers=["*"],
40
  )
41
 
42
+ # Initialize paths
43
  index_path = "./model/db_vit_b_16.index"
44
+ onnx_path = "./model/vit_b_16_feature_extractor.onnx"
45
 
46
+ # Check if index file exists
47
  if not os.path.exists(index_path):
48
  raise FileNotFoundError(f"Index file not found: {index_path}")
49
 
50
  try:
51
+ # Load FAISS index
52
  index = faiss.read_index(index_path)
53
+ print(f"Successfully loaded FAISS index from {index_path}")
 
 
54
 
55
+ # Initialize feature extractor with ONNX support
56
+ feature_extractor = FeatureExtractor(base_model="vit_b_16", onnx_path=onnx_path)
57
+ print("Successfully initialized feature extractor with ONNX support")
58
+ except Exception as e:
59
+ raise RuntimeError(f"Error initializing models: {str(e)}")
60
 
61
 
62
  def base64_to_image(base64_str: str) -> Image.Image:
 
74
  return base64.b64encode(buffered.getvalue()).decode("utf-8")
75
 
76
 
 
 
 
 
 
 
77
  def unzip_folder(zip_file_path, extract_to_path):
78
  if not os.path.exists(zip_file_path):
79
  raise FileNotFoundError(f"Zip file not found: {zip_file_path}")
 
104
  @app.post("/search-image/")
105
  async def search_image(body: ImageSearchBody):
106
  try:
107
+ # Convert base64 to image
108
  image = base64_to_image(body.base64_image)
109
+
110
+ # Extract features using ONNX model
111
+ output = feature_extractor.extract_features(image)
112
+
113
+ # Prepare features for FAISS search
114
+ output = output.view(output.size(0), -1)
115
+ output = output / output.norm(p=2, dim=1, keepdim=True)
116
+
117
+ # Search for similar images
118
+ D, I = index.search(output.cpu().numpy(), 1)
119
+
120
+ # Get the matched image
121
  image_list = sorted([f for f in os.listdir(extract_path) if is_image_file(f)])
 
122
  image_name = image_list[int(I[0][0])]
123
+ matched_image_path = f"{extract_path}/{image_name}"
124
  matched_image = Image.open(matched_image_path)
125
  matched_image_base64 = image_to_base64(matched_image)
126
+
127
  return JSONResponse(
128
  content={
129
  "image_base64": matched_image_base64,
 
134
  )
135
 
136
  except Exception as e:
137
+ print(f"Error in search_image: {str(e)}")
138
+ return JSONResponse(
139
+ content={"error": f"Error processing image: {str(e)}"}, status_code=500
140
+ )
141
 
142
 
143
  from src.firebase.firebase_provider import process_images
144
 
145
+
146
  class Body(BaseModel):
147
  base64_image: list[str] = Field(..., title="Base64 Image String")
148
  model_config = {
feature_extractor.py ADDED
@@ -0,0 +1,140 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torchvision.models.feature_extraction
2
+ import torchvision
3
+ import os
4
+ import torch
5
+ import onnx
6
+ import onnxruntime
7
+ import numpy as np
8
+
9
+ from .config_extractor import MODEL_CONFIG
10
+
11
+ os.environ["KMP_DUPLICATE_LIB_OK"] = "True"
12
+
13
+
14
class FeatureExtractor:
    """Extract image feature vectors using a pre-trained torchvision model.

    Inference always runs through ONNX Runtime. If no exported ONNX model
    exists at ``onnx_path``, the PyTorch model is built once, exported to
    ONNX, and the exported file is loaded for subsequent inference.
    """

    def __init__(self, base_model, onnx_path=None):
        """Initialize the extractor and its ONNX Runtime session.

        Args:
            base_model: str, key into MODEL_CONFIG selecting the backbone
            onnx_path: str or None, path of the ONNX model file; defaults
                to "model/{base_model}_feature_extractor.onnx"

        Raises:
            ValueError: if base_model is not a key of MODEL_CONFIG
        """
        # Validate up front so both branches below fail with the same,
        # explicit error instead of a KeyError on the fast path.
        if base_model not in MODEL_CONFIG:
            raise ValueError(f"Invalid base model: {base_model}")

        self.base_model = base_model
        # number of features produced by the chosen backbone
        self.feat_dims = MODEL_CONFIG[base_model]["feat_dims"]
        # name of the layer whose activations are extracted
        self.feat_layer = MODEL_CONFIG[base_model]["feat_layer"]

        # Default ONNX path when the caller did not provide one.
        if onnx_path is None:
            onnx_path = f"model/{base_model}_feature_extractor.onnx"
        self.onnx_path = onnx_path
        self.onnx_session = None
        # Export runs on CPU; set unconditionally so convert_to_onnx() is
        # safe to call no matter which branch ran here (the original only
        # set self.device in the conversion branch).
        self.device = torch.device("cpu")

        if os.path.exists(onnx_path):
            # Fast path: the preprocessing transforms come straight from
            # the weights object, so the heavyweight PyTorch model is
            # never instantiated (the original built it just to discard it).
            weights = MODEL_CONFIG[base_model]["weights"]
            self.transforms = weights.transforms()
            print(f"Loading existing ONNX model from {onnx_path}")
            self.onnx_session = onnxruntime.InferenceSession(onnx_path)
        else:
            print(
                f"ONNX model not found at {onnx_path}. Initializing PyTorch model and converting to ONNX..."
            )
            # Build the PyTorch model exactly once and reuse it for both
            # the transforms and the ONNX export.
            self.model, self.transforms = self.init_model(base_model)
            self.model.eval()
            self.model.to(self.device)

            # os.makedirs("") raises FileNotFoundError, so only create a
            # directory when the path actually has one.
            parent_dir = os.path.dirname(onnx_path)
            if parent_dir:
                os.makedirs(parent_dir, exist_ok=True)

            # Export, then load the freshly written file for inference.
            self.convert_to_onnx(onnx_path)
            self.onnx_session = onnxruntime.InferenceSession(onnx_path)
            print(f"Successfully created and loaded ONNX model from {onnx_path}")

    def init_model(self, base_model):
        """Initialize the PyTorch model used for feature extraction.

        Args:
            base_model: str, the name of the base model

        Returns:
            model: torch.nn.Module, the feature extraction model
            transforms: the image preprocessing transforms for the weights
        """
        if base_model not in MODEL_CONFIG:
            raise ValueError(f"Invalid base model: {base_model}")

        # Wrap the backbone so the forward pass returns only the chosen
        # intermediate layer's activations.
        weights = MODEL_CONFIG[base_model]["weights"]
        model = torchvision.models.feature_extraction.create_feature_extractor(
            MODEL_CONFIG[base_model]["model"](weights=weights),
            [MODEL_CONFIG[base_model]["feat_layer"]],
        )
        # Preprocessing matched to the pretrained weights.
        transforms = weights.transforms()
        return model, transforms

    def extract_features(self, img):
        """Extract features from a single image via ONNX Runtime.

        Args:
            img: PIL.Image, the input image

        Returns:
            output: torch.Tensor, the extracted features (batch dim first)
        """
        # Preprocess and add the batch dimension expected by the model.
        x = self.transforms(img).unsqueeze(0)
        # Query the session for its real input name instead of hard-coding
        # "input" — robust if the export ever changes its input_names.
        input_name = self.onnx_session.get_inputs()[0].name
        # ONNX Runtime consumes numpy arrays; take the first output.
        output = self.onnx_session.run(None, {input_name: x.numpy()})[0]
        # Hand back a torch tensor so callers can keep using tensor ops.
        return torch.from_numpy(output)

    def convert_to_onnx(self, save_path):
        """Export the PyTorch model to ONNX format and verify the file.

        Args:
            save_path: str, the path to save the ONNX model

        Returns:
            None
        """
        # Dummy input fixing the 3x224x224 image layout; the batch axis is
        # declared dynamic below.
        dummy_input = torch.randn(1, 3, 224, 224, device=self.device)

        torch.onnx.export(
            self.model,
            dummy_input,
            save_path,
            export_params=True,
            opset_version=14,
            do_constant_folding=True,
            input_names=["input"],
            output_names=["output"],
            dynamic_axes={
                "input": {0: "batch_size"},
                "output": {0: "batch_size"},
            },
        )

        # Structural validation of the exported graph.
        onnx_model = onnx.load(save_path)
        onnx.checker.check_model(onnx_model)
        print(f"ONNX model saved to {save_path}")
requirements.txt CHANGED
@@ -10,4 +10,6 @@ python-multipart
10
  firebase-admin
11
  python-dotenv
12
  aiofiles
13
- pytz
 
 
 
10
  firebase-admin
11
  python-dotenv
12
  aiofiles
13
+ pytz
14
+ onnx
15
+ onnxruntime