ABAO77 commited on
Commit
5c8a6b6
·
verified ·
1 Parent(s): 4768f1e

Upload 11 files

Browse files
app.py CHANGED
@@ -1,39 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import os
2
  import faiss
3
- import base64
4
- from PIL import Image
5
- from fastapi import FastAPI, HTTPException
 
6
  from fastapi.responses import JSONResponse
7
- from io import BytesIO
8
- from src.modules.feature_extractor import FeatureExtractor
9
  from fastapi.middleware.cors import CORSMiddleware
10
- from pydantic import BaseModel, Field
11
- import json
12
- from dotenv import load_dotenv
13
- import faulthandler
14
- import torch
 
 
15
 
16
  # Enable fault handler to debug segmentation faults
17
  faulthandler.enable()
18
- load_dotenv(override=True)
19
 
20
  # Force CPU mode to avoid segmentation faults with ONNX/PyTorch
21
  os.environ["CUDA_VISIBLE_DEVICES"] = ""
22
  torch.set_num_threads(1)
23
 
24
- encoded_env = os.getenv("ENCODED_ENV")
25
- if encoded_env:
26
- # Decode the base64 string
27
- decoded_env = base64.b64decode(encoded_env).decode()
28
-
29
- # Load it as a dictionary
30
- env_data = json.loads(decoded_env)
31
-
32
- # Set environment variables
33
- for key, value in env_data.items():
34
- os.environ[key] = value
35
 
36
 
 
37
  app = FastAPI(docs_url="/")
38
  origins = ["*"]
39
 
@@ -45,9 +50,11 @@ app.add_middleware(
45
  allow_headers=["*"],
46
  )
47
 
48
- # Initialize paths
49
  index_path = "./model/db_vit_b_16.index"
50
  onnx_path = "./model/vit_b_16_feature_extractor.onnx"
 
 
51
 
52
  # Check if index file exists
53
  if not os.path.exists(index_path):
@@ -63,82 +70,10 @@ try:
63
  except Exception as e:
64
  raise RuntimeError(f"Error initializing models: {str(e)}")
65
 
66
-
67
- def base64_to_image(base64_str: str) -> Image.Image:
68
- try:
69
- image_data = base64.b64decode(base64_str)
70
- image = Image.open(BytesIO(image_data)).convert("RGB")
71
- return image
72
- except Exception as e:
73
- raise HTTPException(status_code=400, detail="Invalid Base64 image")
74
-
75
-
76
- def image_to_base64(image: Image.Image) -> str:
77
- buffered = BytesIO()
78
- image.save(buffered, format="JPEG")
79
- return base64.b64encode(buffered.getvalue()).decode("utf-8")
80
-
81
-
82
- def extract_zip_file(zip_file_path, destination_folder):
83
- """
84
- Extract a zip file to a destination folder
85
- If destination folder already exists, extraction is skipped
86
-
87
- Args:
88
- zip_file_path: str, path to the zip file
89
- destination_folder: str, path to the destination folder
90
-
91
- Returns:
92
- None
93
- """
94
- import zipfile
95
- import os
96
-
97
- # Check if destination folder already exists
98
- if os.path.exists(destination_folder):
99
- print(f"Destination folder {destination_folder} already exists. Skipping extraction.")
100
- return
101
-
102
- # Check if zip file exists
103
- if not os.path.exists(zip_file_path):
104
- raise FileNotFoundError(f"Zip file not found: {zip_file_path}")
105
-
106
- # Create destination folder
107
- os.makedirs(destination_folder, exist_ok=True)
108
-
109
- # Extract the zip file
110
- with zipfile.ZipFile(zip_file_path, 'r') as zip_ref:
111
- for member in zip_ref.infolist():
112
- # Handle non-ASCII filenames
113
- filename = member.filename.encode('cp437').decode('utf-8')
114
- extracted_path = os.path.join(destination_folder, filename)
115
-
116
- # Create directories if needed
117
- os.makedirs(os.path.dirname(extracted_path), exist_ok=True)
118
-
119
- # Extract file
120
- if not filename.endswith('/'): # Skip directories
121
- with zip_ref.open(member) as source, open(extracted_path, 'wb') as target:
122
- target.write(source.read())
123
-
124
- print(f"Successfully extracted {zip_file_path} to {destination_folder}")
125
-
126
-
127
- zip_file = "./images_2.zip"
128
- extract_path = "./data"
129
- # Fix function name to match the defined function
130
  extract_zip_file(zip_file, extract_path)
131
 
132
 
133
- def is_image_file(filename):
134
- valid_extensions = (".png", ".jpg", ".jpeg", ".bmp", ".gif", ".tiff", ".webp")
135
- return filename.lower().endswith(valid_extensions)
136
-
137
-
138
- class ImageSearchBody(BaseModel):
139
- base64_image: str = Field(..., title="Base64 Image String")
140
-
141
-
142
  @app.post("/search-image/")
143
  def search_image(body: ImageSearchBody):
144
  try:
@@ -162,7 +97,7 @@ def search_image(body: ImageSearchBody):
162
  matched_image = Image.open(matched_image_path)
163
  matched_image_base64 = image_to_base64(matched_image)
164
 
165
- # Post-process image name: remove underscores, numbers, and file extension
166
  image_name_post_process = image_name.replace("_", " ") # Replace underscores with spaces
167
  image_name_post_process = ''.join([c for c in image_name_post_process if not c.isdigit()]) # Remove numbers
168
  image_name_post_process = image_name_post_process.rsplit('.', 1)[0] # Remove file extension
@@ -179,38 +114,26 @@ def search_image(body: ImageSearchBody):
179
  except Exception as e:
180
  print(f"Error in search_image: {str(e)}")
181
  return JSONResponse(
182
- content={"error": f"Error processing image: {str(e)}"}, status_code=500
 
183
  )
184
 
185
 
186
- from src.firebase.firebase_provider import process_images
187
-
188
-
189
- class Body(BaseModel):
190
- base64_image: list[str] = Field(..., title="Base64 Image String")
191
- model_config = {
192
- "json_schema_extra": {
193
- "examples": [
194
- {
195
- "base64_image": [
196
- "iVBORw0KGgoAAAANSUhEUgAAABQAAAAUCAYAAACNiR0NAAABdUlEQVR42mNk",
197
- ]
198
- }
199
- ]
200
- }
201
- }
202
-
203
-
204
  @app.post("/upload_image")
205
- async def upload_image(body: Body):
206
  try:
207
  public_url = await process_images(body.base64_image)
208
- return JSONResponse(content={"public_url": public_url}, status_code=200)
 
 
 
209
  except Exception as e:
210
- return JSONResponse(content={"error": str(e)}, status_code=500)
 
 
 
211
 
212
 
213
  if __name__ == "__main__":
214
  import uvicorn
215
-
216
  uvicorn.run(app, host="0.0.0.0", port=8000)
 
1
+ from dotenv import load_dotenv
2
+ import base64
3
+ import json
4
+ import os
5
+ load_dotenv(override=True)
6
+ encoded_env = os.getenv("ENCODED_ENV")
7
+ print(f"Encoded environment: {encoded_env}")
8
+ if encoded_env:
9
+ decoded_env = base64.b64decode(encoded_env).decode()
10
+ env_data = json.loads(decoded_env)
11
+ for key, value in env_data.items():
12
+ os.environ[key] = value
13
+ print(f"Environment variable {key} set to {value}")
14
+
15
  import os
16
  import faiss
17
+ import torch
18
+ import faulthandler
19
+ import json
20
+ from fastapi import FastAPI
21
  from fastapi.responses import JSONResponse
 
 
22
  from fastapi.middleware.cors import CORSMiddleware
23
+ from PIL import Image
24
+
25
+ from src.modules.feature_extractor import FeatureExtractor
26
+ from src.firebase.firebase_provider import process_images
27
+ from src.utils.image_utils import base64_to_image, image_to_base64, is_image_file
28
+ from src.utils.file_utils import extract_zip_file
29
+ from src.models.schemas import ImageSearchBody, ImageUploadBody
30
 
31
  # Enable fault handler to debug segmentation faults
32
  faulthandler.enable()
 
33
 
34
  # Force CPU mode to avoid segmentation faults with ONNX/PyTorch
35
  os.environ["CUDA_VISIBLE_DEVICES"] = ""
36
  torch.set_num_threads(1)
37
 
38
+ # Load environment variables from base64 encoded string
 
 
 
 
 
 
 
 
 
 
39
 
40
 
41
+ # Initialize FastAPI app
42
  app = FastAPI(docs_url="/")
43
  origins = ["*"]
44
 
 
50
  allow_headers=["*"],
51
  )
52
 
53
+ # Initialize paths and extract data
54
  index_path = "./model/db_vit_b_16.index"
55
  onnx_path = "./model/vit_b_16_feature_extractor.onnx"
56
+ zip_file = "./images_2.zip"
57
+ extract_path = "./data"
58
 
59
  # Check if index file exists
60
  if not os.path.exists(index_path):
 
70
  except Exception as e:
71
  raise RuntimeError(f"Error initializing models: {str(e)}")
72
 
73
+ # Extract zip file if needed
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
74
  extract_zip_file(zip_file, extract_path)
75
 
76
 
 
 
 
 
 
 
 
 
 
77
  @app.post("/search-image/")
78
  def search_image(body: ImageSearchBody):
79
  try:
 
97
  matched_image = Image.open(matched_image_path)
98
  matched_image_base64 = image_to_base64(matched_image)
99
 
100
+ # Post-process image name
101
  image_name_post_process = image_name.replace("_", " ") # Replace underscores with spaces
102
  image_name_post_process = ''.join([c for c in image_name_post_process if not c.isdigit()]) # Remove numbers
103
  image_name_post_process = image_name_post_process.rsplit('.', 1)[0] # Remove file extension
 
114
  except Exception as e:
115
  print(f"Error in search_image: {str(e)}")
116
  return JSONResponse(
117
+ content={"error": f"Error processing image: {str(e)}"},
118
+ status_code=500
119
  )
120
 
121
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
122
  @app.post("/upload_image")
123
+ async def upload_image(body: ImageUploadBody):
124
  try:
125
  public_url = await process_images(body.base64_image)
126
+ return JSONResponse(
127
+ content={"public_url": public_url},
128
+ status_code=200
129
+ )
130
  except Exception as e:
131
+ return JSONResponse(
132
+ content={"error": str(e)},
133
+ status_code=500
134
+ )
135
 
136
 
137
  if __name__ == "__main__":
138
  import uvicorn
 
139
  uvicorn.run(app, host="0.0.0.0", port=8000)
src/firebase/firebase_provider.py CHANGED
@@ -70,6 +70,26 @@ def delete_file_from_storage(file_name):
70
  return False
71
 
72
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
73
  def list_all_files_in_storage():
74
  """
75
  View all files in Firebase Storage
 
70
  return False
71
 
72
 
73
+ def delete_file_by_url(public_url):
74
+ """
75
+ Delete a file from Firebase Storage using its public URL
76
+ param:
77
+ public_url: str - The public URL of the file to be deleted
78
+ return:
79
+ bool - True if the file is deleted successfully, False if the file is not found
80
+ """
81
+ try:
82
+ # Extract the file name from the public URL
83
+ # URL format is typically: https://storage.googleapis.com/BUCKET_NAME/FILE_NAME
84
+ file_name = public_url.split('/')[-1]
85
+
86
+ # Delete the file using the extracted name
87
+ return delete_file_from_storage(file_name)
88
+ except Exception as e:
89
+ print(f"Error deleting file by URL: {e}")
90
+ return False
91
+
92
+
93
  def list_all_files_in_storage():
94
  """
95
  View all files in Firebase Storage
src/models/schemas.py ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pydantic import BaseModel, Field
2
+ from typing import List
3
+
4
+
5
+ class ImageSearchBody(BaseModel):
6
+ """
7
+ Schema for image search request body.
8
+ """
9
+ base64_image: str = Field(
10
+ ...,
11
+ title="Base64 Image String",
12
+ description="Base64 encoded image string to search for similar images"
13
+ )
14
+
15
+
16
+ class ImageUploadBody(BaseModel):
17
+ """
18
+ Schema for image upload request body.
19
+ """
20
+ base64_image: List[str] = Field(
21
+ ...,
22
+ title="Base64 Image String",
23
+ description="List of base64 encoded image strings to upload"
24
+ )
25
+
26
+ model_config = {
27
+ "json_schema_extra": {
28
+ "examples": [
29
+ {
30
+ "base64_image": [
31
+ "iVBORw0KGgoAAAANSUhEUgAAABQAAAAUCAYAAACNiR0NAAABdUlEQVR42mNk",
32
+ ]
33
+ }
34
+ ]
35
+ }
36
+ }
src/utils/file_utils.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import zipfile
3
+ from typing import Optional
4
+
5
+
6
+ def extract_zip_file(zip_file_path: str, destination_folder: str) -> None:
7
+ """
8
+ Extract a zip file to a destination folder.
9
+ If destination folder already exists, extraction is skipped.
10
+
11
+ Args:
12
+ zip_file_path (str): Path to the zip file
13
+ destination_folder (str): Path to the destination folder
14
+
15
+ Raises:
16
+ FileNotFoundError: If the zip file doesn't exist
17
+ """
18
+ # Check if destination folder already exists
19
+ if os.path.exists(destination_folder):
20
+ print(f"Destination folder {destination_folder} already exists. Skipping extraction.")
21
+ return
22
+
23
+ # Check if zip file exists
24
+ if not os.path.exists(zip_file_path):
25
+ raise FileNotFoundError(f"Zip file not found: {zip_file_path}")
26
+
27
+ # Create destination folder
28
+ os.makedirs(destination_folder, exist_ok=True)
29
+
30
+ # Extract the zip file
31
+ with zipfile.ZipFile(zip_file_path, 'r') as zip_ref:
32
+ for member in zip_ref.infolist():
33
+ # Handle non-ASCII filenames
34
+ filename = member.filename.encode('cp437').decode('utf-8')
35
+ extracted_path = os.path.join(destination_folder, filename)
36
+
37
+ # Create directories if needed
38
+ os.makedirs(os.path.dirname(extracted_path), exist_ok=True)
39
+
40
+ # Extract file
41
+ if not filename.endswith('/'): # Skip directories
42
+ with zip_ref.open(member) as source, open(extracted_path, 'wb') as target:
43
+ target.write(source.read())
44
+
45
+ print(f"Successfully extracted {zip_file_path} to {destination_folder}")
46
+
47
+
48
+ def ensure_directory_exists(directory_path: str) -> None:
49
+ """
50
+ Ensure that a directory exists, create it if it doesn't.
51
+
52
+ Args:
53
+ directory_path (str): Path to the directory
54
+ """
55
+ os.makedirs(directory_path, exist_ok=True)
56
+
57
+
58
+ def get_file_extension(filename: str) -> Optional[str]:
59
+ """
60
+ Get the extension of a file.
61
+
62
+ Args:
63
+ filename (str): Name of the file
64
+
65
+ Returns:
66
+ Optional[str]: The file extension without the dot, or None if no extension
67
+ """
68
+ split = os.path.splitext(filename)
69
+ if len(split) > 1:
70
+ return split[1][1:] # Remove the dot
71
+ return None
src/utils/image_utils.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import base64
2
+ from PIL import Image
3
+ from io import BytesIO
4
+ from fastapi import HTTPException
5
+
6
+
7
+ def base64_to_image(base64_str: str) -> Image.Image:
8
+ """
9
+ Convert a base64 string to a PIL Image.
10
+
11
+ Args:
12
+ base64_str (str): The base64 encoded image string
13
+
14
+ Returns:
15
+ Image.Image: The decoded PIL Image
16
+
17
+ Raises:
18
+ HTTPException: If the base64 string is invalid
19
+ """
20
+ try:
21
+ # Handle frontend format: data:image/jpeg;base64,{base64_data}
22
+ if "," in base64_str:
23
+ base64_str = base64_str.split(",", 1)[1]
24
+
25
+ image_data = base64.b64decode(base64_str)
26
+ image = Image.open(BytesIO(image_data)).convert("RGB")
27
+ return image
28
+ except Exception as e:
29
+ print(f"Base64 decoding error: {str(e)}")
30
+ raise HTTPException(status_code=400, detail=f"Invalid Base64 image: {str(e)}")
31
+
32
+
33
+ def image_to_base64(image: Image.Image) -> str:
34
+ """
35
+ Convert a PIL Image to a base64 string.
36
+
37
+ Args:
38
+ image (Image.Image): The PIL Image to convert
39
+
40
+ Returns:
41
+ str: The base64 encoded image string
42
+ """
43
+ buffered = BytesIO()
44
+ image.save(buffered, format="JPEG")
45
+ return base64.b64encode(buffered.getvalue()).decode("utf-8")
46
+
47
+
48
+ def is_image_file(filename: str) -> bool:
49
+ """
50
+ Check if a filename has a valid image extension.
51
+
52
+ Args:
53
+ filename (str): The filename to check
54
+
55
+ Returns:
56
+ bool: True if the file has a valid image extension
57
+ """
58
+ valid_extensions = (".png", ".jpg", ".jpeg", ".bmp", ".gif", ".tiff", ".webp")
59
+ return filename.lower().endswith(valid_extensions)