Spaces:
Sleeping
Sleeping
Upload 7 files
Browse files- src/config/settings.py +18 -0
- src/firebase/firebase_config.py +36 -0
- src/firebase/firebase_provider.py +175 -0
- src/modules/config_extractor.py +66 -0
- src/modules/feature_extractor.py +139 -0
- src/search_query.py +154 -0
- src/utils/helper.py +19 -0
src/config/settings.py
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Root working directory for the project.
# NOTE: "." instead of the original "./" — the old value produced doubled
# slashes in every derived path (".//data", ".//results"); harmless on POSIX
# but ugly in logs and brittle for string comparisons.
WORK_DIR = "."
# Directory holding the dataset and the per-model Faiss index files.
DATA_DIR = f"{WORK_DIR}/data"
# Directory containing the images the index was built from.
IMAGES_DIR = f"{DATA_DIR}/images"
# Directory where query-result plots are written.
RESULTS_DIR = f"{WORK_DIR}/results"

# supported feature extractor models
FEATURE_EXTRACTOR_MODELS = [
    "resnet18",
    "resnet34",
    "resnet50",
    "resnet101",
    "resnet152",
    "vit_b_16",
    "vit_b_32",
    "vit_l_16",
    "vit_l_32",
    "vit_h_14",
]
|
src/firebase/firebase_config.py
ADDED
|
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import firebase_admin
from firebase_admin import credentials
from firebase_admin import storage

import os
from dotenv import load_dotenv

# Load environment variables from a local .env file (no-op if the file is absent).
load_dotenv()
# Name of the Cloud Storage bucket (e.g. "<project>.appspot.com").
firebase_url_storageBucket = os.getenv("URL_STORAGEBUCKET")

# Build the service-account credential dict from individual env vars.
# The keys mirror the fields of the JSON service-account key file that
# Firebase provides for download.
# NOTE(review): PRIVATE_KEY read from an env var often has its newlines
# escaped as "\\n"; if Certificate() rejects it, it may need
# .replace("\\n", "\n") — confirm against the deployment environment.
credential_firebase = {
    "type": os.getenv("TYPE"),
    "project_id": os.getenv("PROJECT_ID"),
    "private_key_id": os.getenv("PRIVATE_KEY_ID"),
    "private_key": os.getenv("PRIVATE_KEY"),
    "client_email": os.getenv("CLIENT_EMAIL"),
    "client_id": os.getenv("CLIENT_ID"),
    "auth_uri": os.getenv("AUTH_URI"),
    "token_uri": os.getenv("TOKEN_URI"),
    "auth_provider_x509_cert_url": os.getenv("AUTH_PROVIDER_X509_CERT_URL"),
    "client_x509_cert_url": os.getenv("CLIENT_X509_CERT_URL"),
    "universe_domain": os.getenv("UNIVERSE_DOMAIN"),
}


# Initialize the Firebase app exactly once (guards against repeated imports).
# _apps is the SDK's registry of initialized apps; empty means not initialized.
if not firebase_admin._apps:
    # Initialize the app with the credentials
    cred = credentials.Certificate(credential_firebase)
    firebase_admin.initialize_app(cred, {"storageBucket": firebase_url_storageBucket})

# Module-level handle to the default Storage bucket; imported by
# firebase_provider.py. (The original comment said "Firestore", but this is
# Cloud Storage.)
firebase_bucket = storage.bucket(app=firebase_admin.get_app())
print("Storage connected")
|
src/firebase/firebase_provider.py
ADDED
|
@@ -0,0 +1,175 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .firebase_config import firebase_bucket
|
| 2 |
+
import base64
|
| 3 |
+
import os
|
| 4 |
+
import tempfile
|
| 5 |
+
from PIL import Image
|
| 6 |
+
import io
|
| 7 |
+
import asyncio
|
| 8 |
+
from typing import List, Optional
|
| 9 |
+
from datetime import datetime
|
| 10 |
+
import pytz
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
import asyncio
|
| 14 |
+
import functools
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def upload_file_to_storage_sync(file_path, file_name):
    """
    Blocking upload of a local file to Firebase Storage.

    param:
        file_path: str - path of the file on the local machine to upload.
        file_name: str - destination object name in Firebase Storage.

    return:
        str - public URL of the uploaded object.
    """
    target_blob = firebase_bucket.blob(file_name)
    target_blob.upload_from_filename(file_path)
    # Expose the object via an unauthenticated public URL.
    target_blob.make_public()
    return target_blob.public_url
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
async def upload_file_to_storage(file_path: str, file_name: str) -> str:
    """
    Asynchronous wrapper to upload a file to Firebase Storage using a thread pool.

    The blocking `upload_file_to_storage_sync` runs in the default executor so
    the event loop is not blocked while the upload is in flight.

    param:
        file_path: str - The path of the file on the local machine to be uploaded.
        file_name: str - The name of the file in Firebase Storage.

    return:
        str - The public URL of the uploaded file.
    """
    # get_running_loop() is the recommended call inside a coroutine; the
    # previously used get_event_loop() is deprecated for this purpose since
    # Python 3.10 and could silently create a new loop.
    loop = asyncio.get_running_loop()

    # run_in_executor forwards positional arguments itself, so the
    # functools.partial indirection is unnecessary.
    public_url = await loop.run_in_executor(
        None, upload_file_to_storage_sync, file_path, file_name
    )

    return public_url
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def delete_file_from_storage(file_name):
    """
    Delete a file from Firebase Storage
    param:
        file_name: str - The name of the file to be deleted
    return:
        bool - True if the file is deleted successfully, False if the file is not found
    """
    try:
        firebase_bucket.blob(file_name).delete()
    except Exception as e:
        # Deletion failed (e.g. the object does not exist): report and signal failure.
        print("Error:", e)
        return False
    return True
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
def list_all_files_in_storage():
    """
    View all files in Firebase Storage
    return:
        dict - Dictionary with keys are names and values are url of all files in Firebase Storage
    """
    name_to_url = {}
    for stored_blob in firebase_bucket.list_blobs():
        name_to_url[stored_blob.name] = stored_blob.public_url
    return name_to_url
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
def download_file_from_storage(file_name, destination_path):
    """
    Download a file from Firebase Storage

    param:
        file_name: str - The name of the file to be downloaded
        destination_path: str - The path to save the downloaded file
    return:
        bool - True if the file is downloaded successfully, False if the file is not found
    """
    try:
        blob = firebase_bucket.blob(file_name)
        blob.download_to_filename(destination_path)
        # Replaces the garbled debug message ("da tai xun thanh cong")
        # with a clear English one.
        print(f"Downloaded {file_name} to {destination_path}")
        return True
    except Exception as e:
        print("Error:", e)
        return False
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
async def upload_base64_image_to_storage(
    base64_image: str, file_name: str
) -> Optional[str]:
    """
    Upload a base64 image to Firebase Storage asynchronously.

    Args:
        base64_image: str - The base64 encoded image
        file_name: str - The name of the file to be uploaded

    Returns:
        Optional[str] - The public URL of the uploaded file or None if failed
    """
    try:
        # get_running_loop() is the coroutine-safe replacement for the
        # deprecated-in-coroutines get_event_loop().
        loop = asyncio.get_running_loop()

        # Decode base64 in the thread pool (CPU-bound for large payloads);
        # b64decode can be passed directly, no lambda needed.
        image_data = await loop.run_in_executor(
            None, base64.b64decode, base64_image
        )

        # Decode the image bytes into a PIL image in the thread pool.
        image = await loop.run_in_executor(
            None, lambda: Image.open(io.BytesIO(image_data))
        )

        # Unique temp-file path; the timestamp suffix avoids collisions
        # between concurrent uploads with the same file_name.
        temp_file_path = os.path.join(
            tempfile.gettempdir(), f"{file_name}_{datetime.now().timestamp()}.jpg"
        )

        # Re-encode as JPEG on disk in the thread pool.
        await loop.run_in_executor(
            None, lambda: image.save(temp_file_path, format="JPEG")
        )

        try:
            # Upload to Firebase
            public_url = await upload_file_to_storage(
                temp_file_path, f"{file_name}.jpg"
            )
            return public_url
        finally:
            # Always remove the temp file, even if the upload failed.
            await loop.run_in_executor(None, os.remove, temp_file_path)

    except Exception as e:
        print(f"Error processing image {file_name}: {str(e)}")
        return None
|
| 153 |
+
|
| 154 |
+
|
| 155 |
+
async def process_images(base64_images: List[str]) -> List[Optional[str]]:
    """
    Process multiple base64 images concurrently.

    Args:
        base64_images: List[str] - List of base64 encoded images

    Returns:
        List[Optional[str]] - List of public URLs or None for failed uploads
    """
    # Timezone object is loop-invariant; build it once.
    vn_tz = pytz.timezone("Asia/Ho_Chi_Minh")
    upload_tasks = []
    for position, encoded_image in enumerate(base64_images):
        # Second-resolution local timestamp plus the list position makes
        # each object name unique within this batch.
        stamp = (
            datetime.now(vn_tz).replace(tzinfo=None).strftime("%Y-%m-%d_%H-%M-%S")
        )
        upload_tasks.append(
            upload_base64_image_to_storage(encoded_image, f"image_{stamp}_{position}")
        )

    # return_exceptions=True maps per-image failures to exception objects
    # instead of cancelling the whole batch.
    return await asyncio.gather(*upload_tasks, return_exceptions=True)
|
src/modules/config_extractor.py
ADDED
|
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torchvision


# Config for the models that are supported by the extractor.
# Schema, per backbone name:
#   "weights"    - pretrained weight enum; also provides the preprocessing
#                  transforms via weights.transforms(),
#   "model"      - torchvision constructor for the backbone,
#   "feat_layer" - graph-node name (torchvision feature_extraction naming)
#                  whose output is used as the image embedding,
#   "feat_dims"  - dimensionality of that embedding vector.
MODEL_CONFIG = {
    "resnet18": {
        "weights": torchvision.models.ResNet18_Weights.DEFAULT,
        "model": torchvision.models.resnet18,
        "feat_layer": "flatten",
        "feat_dims": 512,
    },
    "resnet34": {
        "weights": torchvision.models.ResNet34_Weights.DEFAULT,
        "model": torchvision.models.resnet34,
        "feat_layer": "flatten",
        "feat_dims": 512,
    },
    "resnet50": {
        "weights": torchvision.models.ResNet50_Weights.DEFAULT,
        "model": torchvision.models.resnet50,
        "feat_layer": "flatten",
        "feat_dims": 2048,
    },
    "resnet101": {
        "weights": torchvision.models.ResNet101_Weights.DEFAULT,
        "model": torchvision.models.resnet101,
        "feat_layer": "flatten",
        "feat_dims": 2048,
    },
    "resnet152": {
        "weights": torchvision.models.ResNet152_Weights.DEFAULT,
        "model": torchvision.models.resnet152,
        "feat_layer": "flatten",
        "feat_dims": 2048,
    },
    "vit_b_16": {
        "weights": torchvision.models.ViT_B_16_Weights.DEFAULT,
        "model": torchvision.models.vit_b_16,
        "feat_layer": "getitem_5",
        "feat_dims": 768,
    },
    "vit_b_32": {
        "weights": torchvision.models.ViT_B_32_Weights.DEFAULT,
        "model": torchvision.models.vit_b_32,
        "feat_layer": "getitem_5",
        "feat_dims": 768,
    },
    "vit_l_16": {
        "weights": torchvision.models.ViT_L_16_Weights.DEFAULT,
        "model": torchvision.models.vit_l_16,
        "feat_layer": "getitem_5",
        "feat_dims": 1024,
    },
    "vit_l_32": {
        "weights": torchvision.models.ViT_L_32_Weights.DEFAULT,
        "model": torchvision.models.vit_l_32,
        "feat_layer": "getitem_5",
        "feat_dims": 1024,
    },
    "vit_h_14": {
        "weights": torchvision.models.ViT_H_14_Weights.DEFAULT,
        "model": torchvision.models.vit_h_14,
        "feat_layer": "getitem_5",
        "feat_dims": 1280,
    },
}
|
src/modules/feature_extractor.py
ADDED
|
@@ -0,0 +1,139 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torchvision.models.feature_extraction
|
| 2 |
+
import torchvision
|
| 3 |
+
import os
|
| 4 |
+
import torch
|
| 5 |
+
import onnx
|
| 6 |
+
import onnxruntime
|
| 7 |
+
|
| 8 |
+
from src.modules.config_extractor import MODEL_CONFIG
|
| 9 |
+
|
| 10 |
+
os.environ["KMP_DUPLICATE_LIB_OK"] = "True"
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class FeatureExtractor:
    """Class for extracting features from images using a pre-trained model.

    Inference always runs through ONNX Runtime: on construction, an exported
    ONNX model is loaded from disk if present, otherwise the PyTorch backbone
    is built, exported to ONNX, and then loaded.
    """

    def __init__(self, base_model, onnx_path=None):
        """Build the extractor for the given backbone.

        Args:
            base_model: str, key into MODEL_CONFIG (e.g. "resnet50").
            onnx_path: str | None, where to load/save the ONNX model;
                defaults to "model/<base_model>_feature_extractor.onnx".
        """
        # set the base model
        self.base_model = base_model
        # get the number of features
        self.feat_dims = MODEL_CONFIG[base_model]["feat_dims"]
        # get the feature layer name
        self.feat_layer = MODEL_CONFIG[base_model]["feat_layer"]

        # Set default ONNX path if not provided
        if onnx_path is None:
            onnx_path = f"model/{base_model}_feature_extractor.onnx"

        self.onnx_path = onnx_path
        self.onnx_session = None

        # Initialize transforms (needed for both ONNX and PyTorch)
        # NOTE(review): this builds (and discards) the full PyTorch model just
        # to obtain the transforms; in the else-branch below init_model runs a
        # second time, so pretrained weights are loaded twice on the
        # conversion path — consider caching if startup time matters.
        _, self.transforms = self.init_model(base_model)

        # Check if ONNX model exists
        if os.path.exists(onnx_path):
            print(f"Loading existing ONNX model from {onnx_path}")
            self.onnx_session = onnxruntime.InferenceSession(onnx_path)
        else:
            print(
                f"ONNX model not found at {onnx_path}. Initializing PyTorch model and converting to ONNX..."
            )
            # Initialize PyTorch model
            self.model, _ = self.init_model(base_model)
            self.model.eval()
            # self.model and self.device exist only on this branch;
            # convert_to_onnx (which uses them) is only called from here.
            self.device = torch.device("cpu")
            self.model.to(self.device)

            # Create directory if it doesn't exist
            os.makedirs(os.path.dirname(onnx_path), exist_ok=True)

            # Convert to ONNX
            self.convert_to_onnx(onnx_path)

            # Load the newly created ONNX model
            self.onnx_session = onnxruntime.InferenceSession(onnx_path)
            print(f"Successfully created and loaded ONNX model from {onnx_path}")

    def init_model(self, base_model):
        """Initialize the model for feature extraction

        Args:
            base_model: str, the name of the base model

        Returns:
            model: torch.nn.Module, the feature extraction model
            transforms: torchvision.transforms.Compose, the image transformations

        Raises:
            ValueError: if base_model is not a key of MODEL_CONFIG.
        """
        if base_model not in MODEL_CONFIG:
            raise ValueError(f"Invalid base model: {base_model}")

        # get the model and weights; create_feature_extractor truncates the
        # network so its forward pass returns the configured feat_layer node.
        weights = MODEL_CONFIG[base_model]["weights"]
        model = torchvision.models.feature_extraction.create_feature_extractor(
            MODEL_CONFIG[base_model]["model"](weights=weights),
            [MODEL_CONFIG[base_model]["feat_layer"]],
        )
        # get the image transformations matching the pretrained weights
        transforms = weights.transforms()
        return model, transforms

    def extract_features(self, img):
        """Extract features from an image

        Args:
            img: PIL.Image, the input image

        Returns:
            output: torch.Tensor, the extracted features
        """
        # apply transformations
        x = self.transforms(img)
        # add batch dimension
        x = x.unsqueeze(0)

        # Convert to numpy for ONNX Runtime
        x_numpy = x.numpy()
        # Run inference with ONNX Runtime ('input' matches the name used in
        # convert_to_onnx's export call)
        print("Running inference with ONNX Runtime")
        output = self.onnx_session.run(
            None,
            {'input': x_numpy}
        )[0]
        # Convert back to torch tensor
        output = torch.from_numpy(output)

        return output

    def convert_to_onnx(self, save_path):
        """Convert the model to ONNX format and save it

        Requires self.model and self.device, which are set only on the
        conversion branch of __init__.

        Args:
            save_path: str, the path to save the ONNX model

        Returns:
            None
        """
        # Create a dummy input tensor (standard 224x224 RGB; the batch axis
        # is marked dynamic below so other batch sizes still work)
        dummy_input = torch.randn(1, 3, 224, 224, device=self.device)

        # Export the model
        torch.onnx.export(
            self.model,
            dummy_input,
            save_path,
            export_params=True,
            opset_version=14,
            do_constant_folding=True,
            input_names=['input'],
            output_names=['output'],
            dynamic_axes={
                'input': {0: 'batch_size'},
                'output': {0: 'batch_size'}
            }
        )

        # Verify the exported model is structurally valid
        onnx_model = onnx.load(save_path)
        onnx.checker.check_model(onnx_model)
        print(f"ONNX model saved to {save_path}")
|
src/search_query.py
ADDED
|
@@ -0,0 +1,154 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Description:
|
| 2 |
+
# This script is used to query the index for similar images to a set of random images.
|
| 3 |
+
# The script uses the FeatureExtractor class to extract the features from the images and the Faiss index to search for similar images.
|
| 4 |
+
#
|
| 5 |
+
# Usage:
|
| 6 |
+
#
|
| 7 |
+
# To use this script, you can run the following commands: (You MUST define a feat_extractor since indexings are different for each model)
|
| 8 |
+
# python3 search_query.py --feat_extractor resnet50
|
| 9 |
+
# python3 search_query.py --feat_extractor resnet101
|
| 10 |
+
# python3 search_query.py --feat_extractor resnet50 --n 5
|
| 11 |
+
# python3 search_query.py --feat_extractor resnet50 --k 20
|
| 12 |
+
# python3 search_query.py --feat_extractor resnet50 --n 10 --k 12
|
| 13 |
+
#
|
| 14 |
+
import matplotlib.pyplot as plt
|
| 15 |
+
import numpy as np
|
| 16 |
+
import argparse
|
| 17 |
+
import torch
|
| 18 |
+
import faiss
|
| 19 |
+
import PIL
|
| 20 |
+
import os
|
| 21 |
+
|
| 22 |
+
from src.modules.feature_extractor import FeatureExtractor
|
| 23 |
+
from src.config.settings import FEATURE_EXTRACTOR_MODELS
|
| 24 |
+
from src.config.settings import DATA_DIR, IMAGES_DIR, RESULTS_DIR
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def select_random_images(n, image_list):
    """Select n random images from the image list.

    Sampling is with replacement (np.random.randint), so duplicates are
    possible — this matches the original behavior.

    Args:
        n (int): The number of images to select.
        image_list (list[str]): The list of image file names.

    Returns:
        list[PIL.Image]: The list of selected images.
    """
    picked_indices = np.random.randint(len(image_list), size=n)
    opened_images = []
    for idx in picked_indices:
        opened_images.append(
            PIL.Image.open(os.path.join(IMAGES_DIR, image_list[idx]))
        )
    return opened_images
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def plot_query_results(query_img, similar_imgs, distances, out_filepath, k=None):
    """Plot the query image and the similar images side by side. Save the plot
    to the specified file path.

    Args:
        query_img (PIL.Image): The query image.
        similar_imgs (list[PIL.Image]): The list of similar images.
        distances (list[float]): The list of distances of the similar images.
        out_filepath (str): The file path to save the plot.
        k (int, optional): Number of retrieved images; sets the grid width to
            k // 2 columns. Defaults to the module-level ``args.k`` for
            backward compatibility with the CLI entry point.

    Returns:
        None
    """
    # Fall back to the CLI-parsed global for existing callers; the explicit
    # parameter removes the previous hidden dependency on module-level `args`.
    if k is None:
        k = args.k
    cols = k // 2

    # one row for the query image plus two rows of results
    fig, axes = plt.subplots(3, cols, figsize=(20, 10))
    # plot the query image in the top-left cell
    axes[0, 0].imshow(query_img)
    axes[0, 0].set_title("Query Image")
    axes[0, 0].axis("off")
    # hide the unused cells in the first row
    for col in range(1, cols):
        axes[0, col].axis("off")
    # plot the similar images with their distance as the title
    for i, (img, dist) in enumerate(zip(similar_imgs, distances)):
        row, col = i // cols + 1, i % cols
        axes[row, col].imshow(img)
        axes[row, col].set_title(f"{dist:.4f}")
        axes[row, col].axis("off")
    plt.tight_layout()
    # save the plot
    plt.savefig(out_filepath, bbox_inches="tight", dpi=200)
    # Close the figure: without this, calling the function once per query
    # leaks matplotlib figures across the whole run.
    plt.close(fig)
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
def main(args=None):
    """Query the Faiss index with n random images and save result plots.

    Args:
        args (argparse.Namespace): Parsed CLI arguments with attributes
            feat_extractor, n, k, seed. Must not be None despite the default.

    Returns:
        None
    """
    # set the random seed for reproducibility
    np.random.seed(args.seed)

    # load the vector database index (one index file per extractor model)
    index_filepath = os.path.join(DATA_DIR, f"db_{args.feat_extractor}.index")
    index = faiss.read_index(index_filepath)

    # initialize the feature extractor with the base model specified in the arguments
    feature_extractor = FeatureExtractor(base_model=args.feat_extractor)

    # get the list of images in sorted order since the index is built in the same order
    image_list = sorted(os.listdir(IMAGES_DIR))
    # select n random images
    query_images = select_random_images(args.n, image_list)

    with torch.no_grad():
        # iterate over the selected/query images
        for query_idx, img in enumerate(query_images, start=1):
            # BUG FIX: the method name previously contained a stray non-ASCII
            # character ("extract_ưfeatures"), which raised AttributeError.
            output = feature_extractor.extract_features(img)
            # flatten to (batch, feat_dims)
            output = output.view(output.size(0), -1)
            # L2-normalize so the search operates on unit-length vectors
            output = output / output.norm(p=2, dim=1, keepdim=True)
            # search for the k nearest neighbours
            D, I = index.search(output.cpu().numpy(), args.k)

            # open the retrieved images (I[0] holds indices into image_list)
            similar_images = [
                PIL.Image.open(os.path.join(IMAGES_DIR, image_list[index]))
                for index in I[0]
            ]
            # plot the query results and save the plot
            query_results_folderpath = f"{RESULTS_DIR}/results_{args.feat_extractor}"
            os.makedirs(query_results_folderpath, exist_ok=True)
            query_results_filepath = (
                f"{query_results_folderpath}/query_{query_idx:03}.jpg"
            )
            plot_query_results(
                img, similar_images, D[0], out_filepath=query_results_filepath
            )
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
if __name__ == "__main__":
    # Build the CLI. The parsed namespace is still bound to the module-level
    # name `args`, which plot_query_results reads as a global.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--feat_extractor",
        type=str,
        choices=FEATURE_EXTRACTOR_MODELS,
        required=True,
    )
    parser.add_argument(
        "--n",
        type=int,
        default=10,
        help="Number of random images to select",
    )
    parser.add_argument(
        "--k",
        type=int,
        default=12,
        help="Number of similar images to retrieve",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=777,
        help="Random seed for reproducibility",
    )

    args = parser.parse_args()

    # run the main function
    main(args)
|
src/utils/helper.py
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import base64
|
| 2 |
+
from PIL import Image
|
| 3 |
+
from fastapi import HTTPException
|
| 4 |
+
from io import BytesIO
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
def base64_to_image(base64_str: str) -> Image.Image:
    """Decode a base64 string into an RGB PIL image.

    Args:
        base64_str: Base64-encoded image bytes.

    Returns:
        Image.Image: The decoded image, converted to RGB.

    Raises:
        HTTPException: 400 if the payload is not valid base64 or not a
            decodable image.
    """
    try:
        image_data = base64.b64decode(base64_str)
        image = Image.open(BytesIO(image_data)).convert("RGB")
        return image
    except Exception as e:
        # Chain the original error so server logs show the real decode
        # failure instead of a bare 400.
        raise HTTPException(status_code=400, detail="Invalid Base64 image") from e
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def image_to_base64(image: Image.Image) -> str:
    """Encode a PIL image as a base64 string of its JPEG bytes."""
    jpeg_buffer = BytesIO()
    image.save(jpeg_buffer, format="JPEG")
    encoded = base64.b64encode(jpeg_buffer.getvalue())
    return encoded.decode("utf-8")
|