ABAO77 committed on
Commit
322d2b7
·
verified ·
1 Parent(s): 9ef7463

Upload 7 files

Browse files
src/config/settings.py ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Base working directory; every other path is derived from it.
WORK_DIR = "./"
# Raw dataset root.
DATA_DIR = f"{WORK_DIR}/data"
# Images used for indexing and querying.
IMAGES_DIR = f"{DATA_DIR}/images"
# Output directory for query-result plots.
RESULTS_DIR = f"{WORK_DIR}/results"

# Backbone names the feature-extraction pipeline supports
# (ResNet depths first, then Vision Transformer variants).
FEATURE_EXTRACTOR_MODELS = [
    *(f"resnet{depth}" for depth in ("18", "34", "50", "101", "152")),
    *(f"vit_{variant}" for variant in ("b_16", "b_32", "l_16", "l_32", "h_14")),
]
src/firebase/firebase_config.py ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import firebase_admin
from firebase_admin import credentials
from firebase_admin import storage

import os
from dotenv import load_dotenv

# Load configuration from a .env file (if present) into the environment.
load_dotenv()
firebase_url_storageBucket = os.getenv("URL_STORAGEBUCKET")

# Service-account credential fields; each is read from the env var whose
# name is the upper-cased field name (e.g. "project_id" -> PROJECT_ID).
_CREDENTIAL_FIELDS = (
    "type",
    "project_id",
    "private_key_id",
    "private_key",
    "client_email",
    "client_id",
    "auth_uri",
    "token_uri",
    "auth_provider_x509_cert_url",
    "client_x509_cert_url",
    "universe_domain",
)
credential_firebase = {field: os.getenv(field.upper()) for field in _CREDENTIAL_FIELDS}


# Initialize the Firebase app exactly once per process.
if not firebase_admin._apps:
    cred = credentials.Certificate(credential_firebase)
    firebase_admin.initialize_app(cred, {"storageBucket": firebase_url_storageBucket})

# Shared Storage bucket handle used by the rest of the package.
firebase_bucket = storage.bucket(app=firebase_admin.get_app())
print("Storage connected")
src/firebase/firebase_provider.py ADDED
@@ -0,0 +1,175 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .firebase_config import firebase_bucket
2
+ import base64
3
+ import os
4
+ import tempfile
5
+ from PIL import Image
6
+ import io
7
+ import asyncio
8
+ from typing import List, Optional
9
+ from datetime import datetime
10
+ import pytz
11
+
12
+
13
+ import asyncio
14
+ import functools
15
+
16
+
17
def upload_file_to_storage_sync(file_path, file_name):
    """
    Synchronous upload of a local file to Firebase Storage.

    param:
        file_path: str - Path of the file on the local machine to upload.
        file_name: str - Object name to create in Firebase Storage.

    return:
        str - The public URL of the uploaded file.
    """
    target_blob = firebase_bucket.blob(file_name)
    target_blob.upload_from_filename(file_path)
    # Make the object world-readable so the returned URL works without auth.
    target_blob.make_public()
    return target_blob.public_url
33
+
34
+
35
async def upload_file_to_storage(file_path: str, file_name: str) -> str:
    """
    Asynchronous wrapper to upload a file to Firebase Storage using a thread pool.

    param:
        file_path: str - The path of the file on the local machine to be uploaded.
        file_name: str - The name of the file in Firebase Storage.

    return:
        str - The public URL of the uploaded file.
    """
    # get_running_loop() is the correct call from inside a coroutine;
    # get_event_loop() is deprecated for this use since Python 3.10.
    loop = asyncio.get_running_loop()

    # Run the blocking upload in the default thread-pool executor so the
    # event loop stays responsive.
    public_url = await loop.run_in_executor(
        None, functools.partial(upload_file_to_storage_sync, file_path, file_name)
    )

    return public_url
54
+
55
+
56
def delete_file_from_storage(file_name):
    """
    Delete a file from Firebase Storage (best effort).

    param:
        file_name: str - The name of the file to be deleted
    return:
        bool - True if the file is deleted successfully, False otherwise
    """
    try:
        firebase_bucket.blob(file_name).delete()
    except Exception as e:
        print("Error:", e)
        return False
    return True
71
+
72
+
73
def list_all_files_in_storage():
    """
    List every file currently stored in Firebase Storage.

    return:
        dict - Mapping of blob name -> public URL for all files in the bucket.
    """
    name_to_url = {}
    for blob in firebase_bucket.list_blobs():
        name_to_url[blob.name] = blob.public_url
    return name_to_url
82
+
83
+
84
def download_file_from_storage(file_name, destination_path):
    """
    Download a file from Firebase Storage to a local path.

    param:
        file_name: str - The name of the file to be downloaded
        destination_path: str - The path to save the downloaded file
    return:
        bool - True on success, False if the download failed
    """
    try:
        firebase_bucket.blob(file_name).download_to_filename(destination_path)
    except Exception as e:
        print("Error:", e)
        return False
    # NOTE(review): debug message kept verbatim (Vietnamese for
    # "downloaded successfully") to preserve runtime output.
    print("da tai xun thanh cong")
    return True
101
+
102
+
103
async def upload_base64_image_to_storage(
    base64_image: str, file_name: str
) -> Optional[str]:
    """
    Upload a base64 image to Firebase Storage asynchronously.

    The image is decoded, re-encoded as JPEG via a unique temp file, uploaded,
    and the temp file is always removed afterwards.

    Args:
        base64_image: str - The base64 encoded image
        file_name: str - The name of the file to be uploaded (".jpg" appended)

    Returns:
        Optional[str] - The public URL of the uploaded file or None if failed
    """
    try:
        # get_running_loop() is the correct call from inside a coroutine;
        # get_event_loop() is deprecated for this use since Python 3.10.
        loop = asyncio.get_running_loop()

        # Decode base64 in the thread pool (can be large payloads).
        image_data = await loop.run_in_executor(
            None, base64.b64decode, base64_image
        )

        # Parse the image bytes in the thread pool.
        image = await loop.run_in_executor(
            None, lambda: Image.open(io.BytesIO(image_data))
        )

        # Unique temp path so concurrent uploads with the same file_name
        # do not clobber each other.
        temp_file_path = os.path.join(
            tempfile.gettempdir(), f"{file_name}_{datetime.now().timestamp()}.jpg"
        )

        # Re-encode as JPEG on disk in the thread pool.
        await loop.run_in_executor(
            None, lambda: image.save(temp_file_path, format="JPEG")
        )

        try:
            # Upload to Firebase and return the public URL.
            public_url = await upload_file_to_storage(
                temp_file_path, f"{file_name}.jpg"
            )
            return public_url
        finally:
            # Always remove the temp file, even if the upload fails.
            await loop.run_in_executor(None, os.remove, temp_file_path)

    except Exception as e:
        print(f"Error processing image {file_name}: {str(e)}")
        return None
153
+
154
+
155
async def process_images(base64_images: List[str]) -> List[Optional[str]]:
    """
    Process multiple base64 images concurrently.

    Args:
        base64_images: List[str] - List of base64 encoded images

    Returns:
        List[Optional[str]] - List of public URLs or None for failed uploads
    """
    upload_tasks = []
    for idx, encoded in enumerate(base64_images):
        # Name each file with a Vietnam-local timestamp plus its position.
        stamp = (
            datetime.now(pytz.timezone("Asia/Ho_Chi_Minh"))
            .replace(tzinfo=None)
            .strftime("%Y-%m-%d_%H-%M-%S")
        )
        upload_tasks.append(
            upload_base64_image_to_storage(encoded, f"image_{stamp}_{idx}")
        )

    # Gather everything; exceptions are returned in-place, not raised.
    return await asyncio.gather(*upload_tasks, return_exceptions=True)
src/modules/config_extractor.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import torchvision


# Compact spec per supported backbone:
# (weights-enum class name, constructor name, feature-node name, feature dim)
_MODEL_SPECS = {
    "resnet18": ("ResNet18_Weights", "resnet18", "flatten", 512),
    "resnet34": ("ResNet34_Weights", "resnet34", "flatten", 512),
    "resnet50": ("ResNet50_Weights", "resnet50", "flatten", 2048),
    "resnet101": ("ResNet101_Weights", "resnet101", "flatten", 2048),
    "resnet152": ("ResNet152_Weights", "resnet152", "flatten", 2048),
    "vit_b_16": ("ViT_B_16_Weights", "vit_b_16", "getitem_5", 768),
    "vit_b_32": ("ViT_B_32_Weights", "vit_b_32", "getitem_5", 768),
    "vit_l_16": ("ViT_L_16_Weights", "vit_l_16", "getitem_5", 1024),
    "vit_l_32": ("ViT_L_32_Weights", "vit_l_32", "getitem_5", 1024),
    "vit_h_14": ("ViT_H_14_Weights", "vit_h_14", "getitem_5", 1280),
}

# Expanded configuration consumed by the feature extractor: for each model
# name, the default pretrained weights, the constructor, the graph node to
# read features from, and the feature dimensionality.
MODEL_CONFIG = {
    name: {
        "weights": getattr(torchvision.models, weights_cls).DEFAULT,
        "model": getattr(torchvision.models, ctor_name),
        "feat_layer": feat_layer,
        "feat_dims": feat_dims,
    }
    for name, (weights_cls, ctor_name, feat_layer, feat_dims) in _MODEL_SPECS.items()
}
src/modules/feature_extractor.py ADDED
@@ -0,0 +1,139 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torchvision.models.feature_extraction
2
+ import torchvision
3
+ import os
4
+ import torch
5
+ import onnx
6
+ import onnxruntime
7
+
8
+ from src.modules.config_extractor import MODEL_CONFIG
9
+
10
# Tolerate duplicate OpenMP runtimes being loaded (torch + onnxruntime both
# ship one); without this, importing both can abort the process.
os.environ["KMP_DUPLICATE_LIB_OK"] = "True"


class FeatureExtractor:
    """Extract feature vectors from images with a pre-trained backbone.

    Inference always runs through ONNX Runtime: an existing ONNX export at
    ``onnx_path`` is loaded directly; otherwise the PyTorch model is built,
    exported to ONNX, and then loaded.
    """

    def __init__(self, base_model, onnx_path=None):
        """Set up transforms and an ONNX Runtime session for ``base_model``.

        Args:
            base_model: str, key into MODEL_CONFIG (e.g. "resnet50").
            onnx_path: str | None, where to load/store the ONNX export.
                Defaults to "model/<base_model>_feature_extractor.onnx".

        Raises:
            ValueError: if base_model is not a supported model name.
        """
        # Validate up front with a clear error instead of a raw KeyError.
        if base_model not in MODEL_CONFIG:
            raise ValueError(f"Invalid base model: {base_model}")

        config = MODEL_CONFIG[base_model]
        self.base_model = base_model
        self.feat_dims = config["feat_dims"]
        self.feat_layer = config["feat_layer"]

        # Set default ONNX path if not provided
        if onnx_path is None:
            onnx_path = f"model/{base_model}_feature_extractor.onnx"
        self.onnx_path = onnx_path
        self.onnx_session = None

        # Take the preprocessing transforms straight from the pretrained
        # weights. (Previously init_model() was called here, which built the
        # entire backbone just to obtain the transforms — and built it a
        # second time below whenever the ONNX file was missing.)
        self.transforms = config["weights"].transforms()

        if os.path.exists(onnx_path):
            print(f"Loading existing ONNX model from {onnx_path}")
            self.onnx_session = onnxruntime.InferenceSession(onnx_path)
        else:
            print(
                f"ONNX model not found at {onnx_path}. Initializing PyTorch model and converting to ONNX..."
            )
            # Build the PyTorch model only when conversion is needed.
            self.model, _ = self.init_model(base_model)
            self.model.eval()
            self.device = torch.device("cpu")
            self.model.to(self.device)

            # Only create a directory when the path actually contains one
            # (os.makedirs("") raises FileNotFoundError).
            onnx_dir = os.path.dirname(onnx_path)
            if onnx_dir:
                os.makedirs(onnx_dir, exist_ok=True)

            # Convert to ONNX, then load the freshly created model.
            self.convert_to_onnx(onnx_path)
            self.onnx_session = onnxruntime.InferenceSession(onnx_path)
            print(f"Successfully created and loaded ONNX model from {onnx_path}")

    def init_model(self, base_model):
        """Build the PyTorch feature-extraction model and its transforms.

        Args:
            base_model: str, the name of the base model

        Returns:
            model: torch.nn.Module, the feature extraction model
            transforms: the image preprocessing transforms from the weights

        Raises:
            ValueError: if base_model is not in MODEL_CONFIG.
        """
        if base_model not in MODEL_CONFIG:
            raise ValueError(f"Invalid base model: {base_model}")

        # Wrap the pretrained backbone so it exposes only the feature node.
        weights = MODEL_CONFIG[base_model]["weights"]
        model = torchvision.models.feature_extraction.create_feature_extractor(
            MODEL_CONFIG[base_model]["model"](weights=weights),
            [MODEL_CONFIG[base_model]["feat_layer"]],
        )
        transforms = weights.transforms()
        return model, transforms

    def extract_features(self, img):
        """Extract features from a single image via ONNX Runtime.

        Args:
            img: PIL.Image, the input image

        Returns:
            torch.Tensor, the extracted features (batch dimension first)
        """
        # Preprocess and add the batch dimension.
        x = self.transforms(img)
        x = x.unsqueeze(0)

        # ONNX Runtime consumes numpy arrays, not torch tensors.
        x_numpy = x.numpy()
        print("Running inference with ONNX Runtime")
        output = self.onnx_session.run(None, {'input': x_numpy})[0]
        # Hand the result back as a torch tensor for downstream code.
        return torch.from_numpy(output)

    def convert_to_onnx(self, save_path):
        """Export the PyTorch model to ONNX and verify the saved file.

        Args:
            save_path: str, the path to save the ONNX model

        Returns:
            None
        """
        # Fixed 224x224 dummy input; batch dimension is declared dynamic below.
        # NOTE(review): some ViT weights preprocess to other resolutions —
        # confirm 224x224 tracing is valid for every supported backbone.
        dummy_input = torch.randn(1, 3, 224, 224, device=self.device)

        torch.onnx.export(
            self.model,
            dummy_input,
            save_path,
            export_params=True,
            opset_version=14,
            do_constant_folding=True,
            input_names=['input'],
            output_names=['output'],
            dynamic_axes={
                'input': {0: 'batch_size'},
                'output': {0: 'batch_size'},
            },
        )

        # Sanity-check the exported graph before anyone loads it.
        onnx_model = onnx.load(save_path)
        onnx.checker.check_model(onnx_model)
        print(f"ONNX model saved to {save_path}")
src/search_query.py ADDED
@@ -0,0 +1,154 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Description:
2
+ # This script is used to query the index for similar images to a set of random images.
3
+ # The script uses the FeatureExtractor class to extract the features from the images and the Faiss index to search for similar images.
4
+ #
5
+ # Usage:
6
+ #
7
+ # To use this script, you can run the following commands: (You MUST define a feat_extractor since indexings are different for each model)
8
+ # python3 search_query.py --feat_extractor resnet50
9
+ # python3 search_query.py --feat_extractor resnet101
10
+ # python3 search_query.py --feat_extractor resnet50 --n 5
11
+ # python3 search_query.py --feat_extractor resnet50 --k 20
12
+ # python3 search_query.py --feat_extractor resnet50 --n 10 --k 12
13
+ #
14
+ import matplotlib.pyplot as plt
15
+ import numpy as np
16
+ import argparse
17
+ import torch
18
+ import faiss
19
+ import PIL
20
+ import os
21
+
22
+ from src.modules.feature_extractor import FeatureExtractor
23
+ from src.config.settings import FEATURE_EXTRACTOR_MODELS
24
+ from src.config.settings import DATA_DIR, IMAGES_DIR, RESULTS_DIR
25
+
26
+
27
def select_random_images(n, image_list):
    """Pick n images at random (with replacement) from the image list.

    Args:
        n (int): The number of images to select.
        image_list (list[str]): Image file names inside IMAGES_DIR.

    Returns:
        list[PIL.Image]: The opened images.
    """
    chosen_indices = np.random.randint(len(image_list), size=n)
    return [
        PIL.Image.open(os.path.join(IMAGES_DIR, image_list[i]))
        for i in chosen_indices
    ]
44
+
45
+
46
def plot_query_results(query_img, similar_imgs, distances, out_filepath):
    """Plot the query image and its retrieved neighbours; save the figure.

    Layout: a (1 + 2) x (k // 2) grid — the query image alone on the first
    row, then the k similar images filled row by row with their distances
    as titles.

    Fix: k is derived from len(similar_imgs) instead of the module-global
    ``args`` the original read, which only exists when run as a script.

    Args:
        query_img (PIL.Image): The query image.
        similar_imgs (list[PIL.Image]): The list of similar images.
        distances (list[float]): Distances of the similar images.
        out_filepath (str): The file path to save the plot.

    Returns:
        None
    """
    k = len(similar_imgs)
    cols = k // 2

    # initialize the figure: one row for the query, two rows of results
    fig, axes = plt.subplots(3, cols, figsize=(20, 10))

    # plot the query image in the top-left cell
    axes[0, 0].imshow(query_img)
    axes[0, 0].set_title("Query Image")
    axes[0, 0].axis("off")
    # hide the remaining (empty) plots in the first row
    for col in range(1, cols):
        axes[0, col].axis("off")

    # plot the similar images with their distances as titles
    for i, (img, dist) in enumerate(zip(similar_imgs, distances)):
        ax = axes[i // cols + 1, i % cols]
        ax.imshow(img)
        ax.set_title(f"{dist:.4f}")
        ax.axis("off")

    plt.tight_layout()
    plt.savefig(out_filepath, bbox_inches="tight", dpi=200)
76
+
77
+
78
def main(args=None):
    """Run similarity search for a batch of random query images.

    Args:
        args: Parsed CLI namespace with feat_extractor, n, k and seed.

    Side effects:
        Writes one result plot per query image under RESULTS_DIR.
    """
    # set the random seed for reproducible image selection
    np.random.seed(args.seed)

    # load the vector database index built for this feature extractor
    index_filepath = os.path.join(DATA_DIR, f"db_{args.feat_extractor}.index")
    index = faiss.read_index(index_filepath)

    # initialize the feature extractor with the requested base model
    feature_extractor = FeatureExtractor(base_model=args.feat_extractor)

    # sorted order must match the order used when the index was built
    image_list = sorted(os.listdir(IMAGES_DIR))
    # select n random query images
    query_images = select_random_images(args.n, image_list)

    with torch.no_grad():
        for query_idx, img in enumerate(query_images, start=1):
            # BUGFIX: the method name contained a stray non-ASCII character
            # ("extract_ưfeatures"); call the real extract_features().
            output = feature_extractor.extract_features(img)
            # flatten to (batch, features)
            output = output.view(output.size(0), -1)
            # normalize each feature vector to unit length
            output = output / output.norm(p=2, dim=1, keepdim=True)
            # D: distances, I: indices of the k nearest neighbours
            D, I = index.search(output.cpu().numpy(), args.k)

            # open the retrieved neighbours (renamed loop variable so it no
            # longer shadows the faiss index above)
            similar_images = [
                PIL.Image.open(os.path.join(IMAGES_DIR, image_list[neighbor_idx]))
                for neighbor_idx in I[0]
            ]

            # plot the query results and save the plot
            query_results_folderpath = f"{RESULTS_DIR}/results_{args.feat_extractor}"
            os.makedirs(query_results_folderpath, exist_ok=True)
            query_results_filepath = (
                f"{query_results_folderpath}/query_{query_idx:03}.jpg"
            )
            plot_query_results(
                img, similar_images, D[0], out_filepath=query_results_filepath
            )
121
+
122
+
123
if __name__ == "__main__":
    # build the CLI (the parser gets its own name instead of reusing "args")
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--feat_extractor",
        type=str,
        choices=FEATURE_EXTRACTOR_MODELS,
        required=True,
    )
    parser.add_argument(
        "--n",
        type=int,
        default=10,
        help="Number of random images to select",
    )
    parser.add_argument(
        "--k",
        type=int,
        default=12,
        help="Number of similar images to retrieve",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=777,
        help="Random seed for reproducibility",
    )

    # kept as a module-level name: other code in this file reads the global
    # "args" when run as a script
    args = parser.parse_args()

    # run the main function
    main(args)
src/utils/helper.py ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import base64
2
+ from PIL import Image
3
+ from fastapi import HTTPException
4
+ from io import BytesIO
5
+
6
+
7
def base64_to_image(base64_str: str) -> Image.Image:
    """Decode a base64 string into an RGB PIL image.

    Args:
        base64_str: Base64-encoded image payload.

    Returns:
        The decoded image, converted to RGB.

    Raises:
        HTTPException: 400 if the payload is not valid base64 or not a
            decodable image.
    """
    try:
        image_data = base64.b64decode(base64_str)
        image = Image.open(BytesIO(image_data)).convert("RGB")
        return image
    except Exception as e:
        # Chain the original error so server logs keep the root cause
        # (the caught exception was previously discarded).
        raise HTTPException(status_code=400, detail="Invalid Base64 image") from e
14
+
15
+
16
def image_to_base64(image: Image.Image) -> str:
    """Encode a PIL image as a base64 string of its JPEG bytes."""
    jpeg_buffer = BytesIO()
    image.save(jpeg_buffer, format="JPEG")
    encoded = base64.b64encode(jpeg_buffer.getvalue())
    return encoded.decode("utf-8")