ABAO77 commited on
Commit
0ec5620
·
verified ·
1 Parent(s): 5c8a6b6

Upload 12 files

Browse files
app.py CHANGED
@@ -1,41 +1,38 @@
1
- from dotenv import load_dotenv
2
  import base64
3
  import json
4
- import os
 
5
  load_dotenv(override=True)
6
- encoded_env = os.getenv("ENCODED_ENV")
7
- print(f"Encoded environment: {encoded_env}")
8
  if encoded_env:
9
  decoded_env = base64.b64decode(encoded_env).decode()
10
  env_data = json.loads(decoded_env)
11
  for key, value in env_data.items():
12
  os.environ[key] = value
13
- print(f"Environment variable {key} set to {value}")
14
-
15
- import os
16
- import faiss
17
  import torch
18
- import faulthandler
19
- import json
20
  from fastapi import FastAPI
21
  from fastapi.responses import JSONResponse
22
  from fastapi.middleware.cors import CORSMiddleware
 
 
 
23
  from PIL import Image
24
 
25
- from src.modules.feature_extractor import FeatureExtractor
26
- from src.firebase.firebase_provider import process_images
27
  from src.utils.image_utils import base64_to_image, image_to_base64, is_image_file
28
- from src.utils.file_utils import extract_zip_file
29
- from src.models.schemas import ImageSearchBody, ImageUploadBody
 
30
 
31
  # Enable fault handler to debug segmentation faults
32
  faulthandler.enable()
 
33
 
34
  # Force CPU mode to avoid segmentation faults with ONNX/PyTorch
35
  os.environ["CUDA_VISIBLE_DEVICES"] = ""
36
  torch.set_num_threads(1)
37
 
38
- # Load environment variables from base64 encoded string
39
 
40
 
41
  # Initialize FastAPI app
@@ -50,30 +47,21 @@ app.add_middleware(
50
  allow_headers=["*"],
51
  )
52
 
53
- # Initialize paths and extract data
54
  index_path = "./model/db_vit_b_16.index"
55
  onnx_path = "./model/vit_b_16_feature_extractor.onnx"
 
 
 
56
  zip_file = "./images_2.zip"
57
  extract_path = "./data"
58
-
59
- # Check if index file exists
60
- if not os.path.exists(index_path):
61
- raise FileNotFoundError(f"Index file not found: {index_path}")
62
-
63
- try:
64
- # Load FAISS index
65
- index = faiss.read_index(index_path)
66
- print(f"Successfully loaded FAISS index from {index_path}")
67
- # Initialize feature extractor with ONNX support
68
- feature_extractor = FeatureExtractor(base_model="vit_b_16", onnx_path=onnx_path)
69
- print("Successfully initialized feature extractor with ONNX support")
70
- except Exception as e:
71
- raise RuntimeError(f"Error initializing models: {str(e)}")
72
-
73
- # Extract zip file if needed
74
  extract_zip_file(zip_file, extract_path)
75
 
76
 
 
 
 
 
77
  @app.post("/search-image/")
78
  def search_image(body: ImageSearchBody):
79
  try:
@@ -81,27 +69,31 @@ def search_image(body: ImageSearchBody):
81
  image = base64_to_image(body.base64_image)
82
 
83
  # Extract features using ONNX model
84
- output = feature_extractor.extract_features(image)
85
-
86
- # Prepare features for FAISS search
87
- output = output.view(output.size(0), -1)
88
- output = output / output.norm(p=2, dim=1, keepdim=True)
89
 
90
  # Search for similar images
91
- D, I = index.search(output.cpu().numpy(), 1)
92
 
93
  # Get the matched image
94
- image_list = sorted([f for f in os.listdir(extract_path + "/images") if is_image_file(f)])
 
 
95
  image_name = image_list[int(I[0][0])]
96
  matched_image_path = f"{extract_path}/images/{image_name}"
97
  matched_image = Image.open(matched_image_path)
98
  matched_image_base64 = image_to_base64(matched_image)
99
-
100
- # Post-process image name
101
- image_name_post_process = image_name.replace("_", " ") # Replace underscores with spaces
102
- image_name_post_process = ''.join([c for c in image_name_post_process if not c.isdigit()]) # Remove numbers
103
- image_name_post_process = image_name_post_process.rsplit('.', 1)[0] # Remove file extension
104
-
 
 
 
 
 
 
105
  return JSONResponse(
106
  content={
107
  "image_base64": matched_image_base64,
@@ -114,26 +106,35 @@ def search_image(body: ImageSearchBody):
114
  except Exception as e:
115
  print(f"Error in search_image: {str(e)}")
116
  return JSONResponse(
117
- content={"error": f"Error processing image: {str(e)}"},
118
- status_code=500
119
  )
120
 
121
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
122
  @app.post("/upload_image")
123
- async def upload_image(body: ImageUploadBody):
124
  try:
125
  public_url = await process_images(body.base64_image)
126
- return JSONResponse(
127
- content={"public_url": public_url},
128
- status_code=200
129
- )
130
  except Exception as e:
131
- return JSONResponse(
132
- content={"error": str(e)},
133
- status_code=500
134
- )
135
 
136
 
137
  if __name__ == "__main__":
138
  import uvicorn
139
- uvicorn.run(app, host="0.0.0.0", port=8000)
 
 
1
+ import os
2
  import base64
3
  import json
4
+ from dotenv import load_dotenv
5
+
6
  load_dotenv(override=True)
7
+ encoded_env = os.getenv("ENCODED_ENV_IMAGE")
 
8
  if encoded_env:
9
  decoded_env = base64.b64decode(encoded_env).decode()
10
  env_data = json.loads(decoded_env)
11
  for key, value in env_data.items():
12
  os.environ[key] = value
 
 
 
 
13
  import torch
 
 
14
  from fastapi import FastAPI
15
  from fastapi.responses import JSONResponse
16
  from fastapi.middleware.cors import CORSMiddleware
17
+ from pydantic import BaseModel, Field
18
+ from dotenv import load_dotenv
19
+ import faulthandler
20
  from PIL import Image
21
 
 
 
22
  from src.utils.image_utils import base64_to_image, image_to_base64, is_image_file
23
+ from src.utils.zip_utils import extract_zip_file
24
+ from src.utils.model_utils import init_models, search_similar_images
25
+ from src.firebase.firebase_provider import process_images
26
 
27
  # Enable fault handler to debug segmentation faults
28
  faulthandler.enable()
29
+ load_dotenv(override=True)
30
 
31
  # Force CPU mode to avoid segmentation faults with ONNX/PyTorch
32
  os.environ["CUDA_VISIBLE_DEVICES"] = ""
33
  torch.set_num_threads(1)
34
 
35
+ # Load environment variables
36
 
37
 
38
  # Initialize FastAPI app
 
47
  allow_headers=["*"],
48
  )
49
 
50
+ # Initialize paths and models
51
  index_path = "./model/db_vit_b_16.index"
52
  onnx_path = "./model/vit_b_16_feature_extractor.onnx"
53
+ index, feature_extractor = init_models(index_path, onnx_path)
54
+
55
+ # Extract images if needed
56
  zip_file = "./images_2.zip"
57
  extract_path = "./data"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58
  extract_zip_file(zip_file, extract_path)
59
 
60
 
61
+ class ImageSearchBody(BaseModel):
62
+ base64_image: str = Field(..., title="Base64 Image String")
63
+
64
+
65
  @app.post("/search-image/")
66
  def search_image(body: ImageSearchBody):
67
  try:
 
69
  image = base64_to_image(body.base64_image)
70
 
71
  # Extract features using ONNX model
72
+ features = feature_extractor.extract_features(image)
 
 
 
 
73
 
74
  # Search for similar images
75
+ D, I = search_similar_images(index, features)
76
 
77
  # Get the matched image
78
+ image_list = sorted(
79
+ [f for f in os.listdir(extract_path + "/images") if is_image_file(f)]
80
+ )
81
  image_name = image_list[int(I[0][0])]
82
  matched_image_path = f"{extract_path}/images/{image_name}"
83
  matched_image = Image.open(matched_image_path)
84
  matched_image_base64 = image_to_base64(matched_image)
85
+
86
+ # Post-process image name: remove underscores, numbers, and file extension
87
+ image_name_post_process = image_name.replace(
88
+ "_", " "
89
+ ) # Replace underscores with spaces
90
+ image_name_post_process = "".join(
91
+ [c for c in image_name_post_process if not c.isdigit()]
92
+ ) # Remove numbers
93
+ image_name_post_process = image_name_post_process.rsplit(".", 1)[
94
+ 0
95
+ ] # Remove file extension
96
+
97
  return JSONResponse(
98
  content={
99
  "image_base64": matched_image_base64,
 
106
  except Exception as e:
107
  print(f"Error in search_image: {str(e)}")
108
  return JSONResponse(
109
+ content={"error": f"Error processing image: {str(e)}"}, status_code=500
 
110
  )
111
 
112
 
113
+ class Body(BaseModel):
114
+ base64_image: list[str] = Field(..., title="Base64 Image String")
115
+ model_config = {
116
+ "json_schema_extra": {
117
+ "examples": [
118
+ {
119
+ "base64_image": [
120
+ "iVBORw0KGgoAAAANSUhEUgAAABQAAAAUCAYAAACNiR0NAAABdUlEQVR42mNk",
121
+ ]
122
+ }
123
+ ]
124
+ }
125
+ }
126
+
127
+
128
  @app.post("/upload_image")
129
+ async def upload_image(body: Body):
130
  try:
131
  public_url = await process_images(body.base64_image)
132
+ return JSONResponse(content={"public_url": public_url}, status_code=200)
 
 
 
133
  except Exception as e:
134
+ return JSONResponse(content={"error": str(e)}, status_code=500)
 
 
 
135
 
136
 
137
  if __name__ == "__main__":
138
  import uvicorn
139
+
140
+ uvicorn.run(app, host="0.0.0.0", port=8000)
src/firebase/firebase_provider.py CHANGED
@@ -8,6 +8,7 @@ import asyncio
8
  from typing import List, Optional
9
  from datetime import datetime
10
  import pytz
 
11
 
12
 
13
  import asyncio
@@ -36,20 +37,23 @@ async def upload_file_to_storage(file_path: str, file_name: str) -> str:
36
  """
37
  Asynchronous wrapper to upload a file to Firebase Storage using a thread pool.
38
 
39
- param:
40
  file_path: str - The path of the file on the local machine to be uploaded.
41
  file_name: str - The name of the file in Firebase Storage.
42
 
43
- return:
44
  str - The public URL of the uploaded file.
45
  """
46
  loop = asyncio.get_event_loop()
47
 
48
- # Run the synchronous `upload_file_to_storage_sync` in a thread pool.
49
- public_url = await loop.run_in_executor(
50
- None, functools.partial(upload_file_to_storage_sync, file_path, file_name)
51
- )
 
 
52
 
 
53
  return public_url
54
 
55
 
@@ -81,8 +85,8 @@ def delete_file_by_url(public_url):
81
  try:
82
  # Extract the file name from the public URL
83
  # URL format is typically: https://storage.googleapis.com/BUCKET_NAME/FILE_NAME
84
- file_name = public_url.split('/')[-1]
85
-
86
  # Delete the file using the extracted name
87
  return delete_file_from_storage(file_name)
88
  except Exception as e:
@@ -121,7 +125,7 @@ def download_file_from_storage(file_name, destination_path):
121
 
122
 
123
  async def upload_base64_image_to_storage(
124
- base64_image: str, file_name: str
125
  ) -> Optional[str]:
126
  """
127
  Upload a base64 image to Firebase Storage asynchronously.
@@ -129,46 +133,42 @@ async def upload_base64_image_to_storage(
129
  Args:
130
  base64_image: str - The base64 encoded image
131
  file_name: str - The name of the file to be uploaded
 
132
 
133
  Returns:
134
  Optional[str] - The public URL of the uploaded file or None if failed
135
  """
136
  try:
137
- # Run CPU-intensive operations in thread pool
138
- loop = asyncio.get_event_loop()
139
-
140
- # Decode base64 in thread pool
141
- image_data = await loop.run_in_executor(
142
- None, lambda: base64.b64decode(base64_image)
143
- )
144
-
145
- # Open and process image in thread pool
146
- image = await loop.run_in_executor(
147
- None, lambda: Image.open(io.BytesIO(image_data))
148
- )
149
 
150
- # Create unique temp file path
151
  temp_file_path = os.path.join(
152
- tempfile.gettempdir(), f"{file_name}_{datetime.now().timestamp()}.jpg"
 
153
  )
154
 
155
- # Save image in thread pool
156
- await loop.run_in_executor(
157
- None, lambda: image.save(temp_file_path, format="JPEG")
158
- )
159
 
160
  try:
161
  # Upload to Firebase
162
  public_url = await upload_file_to_storage(
163
- temp_file_path, f"{file_name}.jpg"
164
  )
165
  return public_url
166
  finally:
167
- # Clean up temp file in thread pool
168
- await loop.run_in_executor(None, os.remove, temp_file_path)
 
169
 
170
  except Exception as e:
171
  print(f"Error processing image {file_name}: {str(e)}")
 
 
 
 
 
172
  return None
173
 
174
 
@@ -190,6 +190,18 @@ async def process_images(base64_images: List[str]) -> List[Optional[str]]:
190
  .strftime("%Y-%m-%d_%H-%M-%S")
191
  )
192
  file_name = f"image_{timestamp}_{idx}"
193
- tasks.append(upload_base64_image_to_storage(base64_image, file_name))
 
 
 
 
 
 
 
 
 
 
 
 
194
 
195
  return await asyncio.gather(*tasks, return_exceptions=True)
 
8
  from typing import List, Optional
9
  from datetime import datetime
10
  import pytz
11
+ from src.utils.image_utils import base64_to_image
12
 
13
 
14
  import asyncio
 
37
  """
38
  Asynchronous wrapper to upload a file to Firebase Storage using a thread pool.
39
 
40
+ Args:
41
  file_path: str - The path of the file on the local machine to be uploaded.
42
  file_name: str - The name of the file in Firebase Storage.
43
 
44
+ Returns:
45
  str - The public URL of the uploaded file.
46
  """
47
  loop = asyncio.get_event_loop()
48
 
49
+ # Run the synchronous upload in a thread pool
50
+ def upload_sync():
51
+ blob = firebase_bucket.blob(file_name)
52
+ blob.upload_from_filename(file_path)
53
+ blob.make_public()
54
+ return blob.public_url
55
 
56
+ public_url = await loop.run_in_executor(None, upload_sync)
57
  return public_url
58
 
59
 
 
85
  try:
86
  # Extract the file name from the public URL
87
  # URL format is typically: https://storage.googleapis.com/BUCKET_NAME/FILE_NAME
88
+ file_name = public_url.split("/")[-1]
89
+
90
  # Delete the file using the extracted name
91
  return delete_file_from_storage(file_name)
92
  except Exception as e:
 
125
 
126
 
127
  async def upload_base64_image_to_storage(
128
+ base64_image: str, file_name: str, format: str = "JPEG"
129
  ) -> Optional[str]:
130
  """
131
  Upload a base64 image to Firebase Storage asynchronously.
 
133
  Args:
134
  base64_image: str - The base64 encoded image
135
  file_name: str - The name of the file to be uploaded
136
+ format: str - The format to save the image in (JPEG, PNG, etc.)
137
 
138
  Returns:
139
  Optional[str] - The public URL of the uploaded file or None if failed
140
  """
141
  try:
142
+ # Convert base64 to PIL Image
143
+ image = base64_to_image(base64_image)
 
 
 
 
 
 
 
 
 
 
144
 
145
+ # Create unique temp file path with appropriate extension
146
  temp_file_path = os.path.join(
147
+ tempfile.gettempdir(),
148
+ f"{file_name}_{datetime.now().timestamp()}.{format.lower()}",
149
  )
150
 
151
+ # Save image in the specified format
152
+ image.save(temp_file_path, format=format)
 
 
153
 
154
  try:
155
  # Upload to Firebase
156
  public_url = await upload_file_to_storage(
157
+ temp_file_path, f"{file_name}.{format.lower()}"
158
  )
159
  return public_url
160
  finally:
161
+ # Clean up temp file
162
+ if os.path.exists(temp_file_path):
163
+ os.remove(temp_file_path)
164
 
165
  except Exception as e:
166
  print(f"Error processing image {file_name}: {str(e)}")
167
+ # If format is not JPEG, try again with JPEG
168
+ if format.upper() != "JPEG":
169
+ return await upload_base64_image_to_storage(
170
+ base64_image, file_name, format="JPEG"
171
+ )
172
  return None
173
 
174
 
 
190
  .strftime("%Y-%m-%d_%H-%M-%S")
191
  )
192
  file_name = f"image_{timestamp}_{idx}"
193
+
194
+ # Determine format from base64 header or default to JPEG
195
+ format = "JPEG"
196
+ if "data:image/" in base64_image:
197
+ mime_type = base64_image.split(";")[0].split("/")[1]
198
+ if mime_type == "png":
199
+ format = "PNG"
200
+ elif mime_type == "webp":
201
+ format = "WEBP"
202
+
203
+ tasks.append(
204
+ upload_base64_image_to_storage(base64_image, file_name, format=format)
205
+ )
206
 
207
  return await asyncio.gather(*tasks, return_exceptions=True)
src/utils/image_utils.py CHANGED
@@ -5,55 +5,105 @@ from fastapi import HTTPException
5
 
6
 
7
  def base64_to_image(base64_str: str) -> Image.Image:
8
- """
9
- Convert a base64 string to a PIL Image.
10
 
11
  Args:
12
- base64_str (str): The base64 encoded image string
13
 
14
  Returns:
15
- Image.Image: The decoded PIL Image
16
 
17
  Raises:
18
- HTTPException: If the base64 string is invalid
19
  """
20
  try:
21
- # Handle frontend format: data:image/jpeg;base64,{base64_data}
22
  if "," in base64_str:
23
  base64_str = base64_str.split(",", 1)[1]
24
 
25
  image_data = base64.b64decode(base64_str)
26
- image = Image.open(BytesIO(image_data)).convert("RGB")
 
 
 
 
 
 
 
 
 
 
 
 
27
  return image
28
  except Exception as e:
29
  print(f"Base64 decoding error: {str(e)}")
30
  raise HTTPException(status_code=400, detail=f"Invalid Base64 image: {str(e)}")
31
 
32
 
33
- def image_to_base64(image: Image.Image) -> str:
34
- """
35
- Convert a PIL Image to a base64 string.
36
 
37
  Args:
38
- image (Image.Image): The PIL Image to convert
 
39
 
40
  Returns:
41
- str: The base64 encoded image string
42
  """
43
- buffered = BytesIO()
44
- image.save(buffered, format="JPEG")
45
- return base64.b64encode(buffered.getvalue()).decode("utf-8")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46
 
47
 
48
  def is_image_file(filename: str) -> bool:
49
- """
50
- Check if a filename has a valid image extension.
51
 
52
  Args:
53
- filename (str): The filename to check
54
 
55
  Returns:
56
- bool: True if the file has a valid image extension
57
  """
58
  valid_extensions = (".png", ".jpg", ".jpeg", ".bmp", ".gif", ".tiff", ".webp")
59
- return filename.lower().endswith(valid_extensions)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
 
6
 
7
  def base64_to_image(base64_str: str) -> Image.Image:
8
+ """Convert base64 string to PIL Image.
 
9
 
10
  Args:
11
+ base64_str: Base64 encoded image string
12
 
13
  Returns:
14
+ PIL.Image: Decoded image
15
 
16
  Raises:
17
+ HTTPException: If base64 string is invalid
18
  """
19
  try:
20
+ # Handle frontend base64 format (data:image/jpeg;base64,{base64_data})
21
  if "," in base64_str:
22
  base64_str = base64_str.split(",", 1)[1]
23
 
24
  image_data = base64.b64decode(base64_str)
25
+ image = Image.open(BytesIO(image_data))
26
+
27
+ # Convert RGBA to RGB if necessary
28
+ if image.mode in ('RGBA', 'LA'):
29
+ background = Image.new('RGB', image.size, (255, 255, 255))
30
+ if image.mode == 'RGBA':
31
+ background.paste(image, mask=image.split()[3]) # 3 is the alpha channel
32
+ else:
33
+ background.paste(image, mask=image.split()[1]) # 1 is the alpha channel
34
+ image = background
35
+ elif image.mode != 'RGB':
36
+ image = image.convert('RGB')
37
+
38
  return image
39
  except Exception as e:
40
  print(f"Base64 decoding error: {str(e)}")
41
  raise HTTPException(status_code=400, detail=f"Invalid Base64 image: {str(e)}")
42
 
43
 
44
+ def image_to_base64(image: Image.Image, format: str = "JPEG") -> str:
45
+ """Convert PIL Image to base64 string.
 
46
 
47
  Args:
48
+ image: PIL Image object
49
+ format: Output format (JPEG, PNG, etc.)
50
 
51
  Returns:
52
+ str: Base64 encoded image string
53
  """
54
+ try:
55
+ # Convert RGBA to RGB if saving as JPEG
56
+ if format.upper() == "JPEG" and image.mode in ('RGBA', 'LA'):
57
+ background = Image.new('RGB', image.size, (255, 255, 255))
58
+ if image.mode == 'RGBA':
59
+ background.paste(image, mask=image.split()[3])
60
+ else:
61
+ background.paste(image, mask=image.split()[1])
62
+ image = background
63
+ elif format.upper() == "JPEG" and image.mode != 'RGB':
64
+ image = image.convert('RGB')
65
+
66
+ buffered = BytesIO()
67
+ image.save(buffered, format=format)
68
+ return base64.b64encode(buffered.getvalue()).decode("utf-8")
69
+ except Exception as e:
70
+ print(f"Error converting image to base64: {str(e)}")
71
+ # Try JPEG as fallback
72
+ if format.upper() != "JPEG":
73
+ return image_to_base64(image, format="JPEG")
74
+ raise
75
 
76
 
77
  def is_image_file(filename: str) -> bool:
78
+ """Check if a filename has a valid image extension.
 
79
 
80
  Args:
81
+ filename: Name of the file to check
82
 
83
  Returns:
84
+ bool: True if file has valid image extension
85
  """
86
  valid_extensions = (".png", ".jpg", ".jpeg", ".bmp", ".gif", ".tiff", ".webp")
87
+ return filename.lower().endswith(valid_extensions)
88
+
89
+
90
+ def get_image_format(filename: str) -> str:
91
+ """Get the format to use for saving an image based on its filename.
92
+
93
+ Args:
94
+ filename: Name of the file
95
+
96
+ Returns:
97
+ str: Format to use (JPEG, PNG, etc.)
98
+ """
99
+ ext = filename.lower().split('.')[-1]
100
+ if ext in ('jpg', 'jpeg'):
101
+ return 'JPEG'
102
+ elif ext == 'png':
103
+ return 'PNG'
104
+ elif ext == 'webp':
105
+ return 'WEBP'
106
+ elif ext == 'gif':
107
+ return 'GIF'
108
+ else:
109
+ return 'JPEG' # Default to JPEG
src/utils/model_utils.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import faiss
3
+ import torch
4
+ from src.modules.feature_extractor import FeatureExtractor
5
+
6
+
7
+ def init_models(index_path: str, onnx_path: str) -> tuple[faiss.IndexFlatIP, FeatureExtractor]:
8
+ """Initialize FAISS index and feature extractor.
9
+
10
+ Args:
11
+ index_path: Path to FAISS index file
12
+ onnx_path: Path to ONNX model file
13
+
14
+ Returns:
15
+ tuple: (FAISS index, Feature extractor)
16
+
17
+ Raises:
18
+ FileNotFoundError: If index file doesn't exist
19
+ RuntimeError: If model initialization fails
20
+ """
21
+ # Check if index file exists
22
+ if not os.path.exists(index_path):
23
+ raise FileNotFoundError(f"Index file not found: {index_path}")
24
+
25
+ try:
26
+ # Load FAISS index
27
+ index = faiss.read_index(index_path)
28
+ print(f"Successfully loaded FAISS index from {index_path}")
29
+
30
+ # Initialize feature extractor with ONNX support
31
+ feature_extractor = FeatureExtractor(base_model="vit_b_16", onnx_path=onnx_path)
32
+ print("Successfully initialized feature extractor with ONNX support")
33
+
34
+ return index, feature_extractor
35
+
36
+ except Exception as e:
37
+ raise RuntimeError(f"Error initializing models: {str(e)}")
38
+
39
+
40
+ def search_similar_images(
41
+ index: faiss.IndexFlatIP,
42
+ features: torch.Tensor,
43
+ k: int = 1
44
+ ) -> tuple[torch.Tensor, torch.Tensor]:
45
+ """Search for similar images using FAISS index.
46
+
47
+ Args:
48
+ index: FAISS index
49
+ features: Image features to search for
50
+ k: Number of similar images to return
51
+
52
+ Returns:
53
+ tuple: (Distances, Indices)
54
+ """
55
+ # Prepare features for FAISS search
56
+ features = features.view(features.size(0), -1)
57
+ features = features / features.norm(p=2, dim=1, keepdim=True)
58
+
59
+ # Search for similar images
60
+ D, I = index.search(features.cpu().numpy(), k)
61
+
62
+ return D, I
src/utils/zip_utils.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import zipfile
3
+
4
+
5
+ def extract_zip_file(zip_file_path: str, destination_folder: str) -> None:
6
+ """Extract a zip file to a destination folder.
7
+ If destination folder already exists, extraction is skipped.
8
+
9
+ Args:
10
+ zip_file_path: Path to the zip file
11
+ destination_folder: Path to the destination folder
12
+
13
+ Raises:
14
+ FileNotFoundError: If zip file doesn't exist
15
+ """
16
+ # Check if destination folder already exists
17
+ if os.path.exists(destination_folder):
18
+ print(f"Destination folder {destination_folder} already exists. Skipping extraction.")
19
+ return
20
+
21
+ # Check if zip file exists
22
+ if not os.path.exists(zip_file_path):
23
+ raise FileNotFoundError(f"Zip file not found: {zip_file_path}")
24
+
25
+ # Create destination folder
26
+ os.makedirs(destination_folder, exist_ok=True)
27
+
28
+ # Extract the zip file
29
+ with zipfile.ZipFile(zip_file_path, 'r') as zip_ref:
30
+ for member in zip_ref.infolist():
31
+ # Handle non-ASCII filenames
32
+ filename = member.filename.encode('cp437').decode('utf-8')
33
+ extracted_path = os.path.join(destination_folder, filename)
34
+
35
+ # Create directories if needed
36
+ os.makedirs(os.path.dirname(extracted_path), exist_ok=True)
37
+
38
+ # Extract file
39
+ if not filename.endswith('/'): # Skip directories
40
+ with zip_ref.open(member) as source, open(extracted_path, 'wb') as target:
41
+ target.write(source.read())
42
+
43
+ print(f"Successfully extracted {zip_file_path} to {destination_folder}")