ashutosh-koottu commited on
Commit
a1075ab
·
1 Parent(s): a54e757

Adds caching and parallel processing

Browse files
Files changed (1) hide show
  1. handler.py +99 -58
handler.py CHANGED
@@ -15,6 +15,7 @@ from PIL import Image
15
  from azure.storage.blob import BlobServiceClient, BlobClient, ContainerClient
16
  from config import get_config
17
  import time
 
18
 
19
  class EndpointHandler:
20
  def __init__(self, model_dir=None):
@@ -45,6 +46,12 @@ class EndpointHandler:
45
 
46
  # Get container client
47
  self.container_client = self.blob_service_client.get_container_client(self.container_name)
 
 
 
 
 
 
48
 
49
  def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
50
  try:
@@ -79,7 +86,15 @@ class EndpointHandler:
79
  raise ValueError("Invalid JSON structure.")
80
 
81
  def load_embeddings_from_azure(self):
82
- """Load existing embeddings from Azure Blob Storage if they exist, else return an empty list."""
 
 
 
 
 
 
 
 
83
  try:
84
  # Check if embeddings file exists in Azure - look in profile-media/embeddings/
85
  blob_name = f'profile-media/embeddings/embeddings_db.json'
@@ -94,9 +109,14 @@ class EndpointHandler:
94
  download_file.write(download_stream.readall())
95
 
96
  with open(temp_file_path, 'r') as f:
97
- return json.load(f)
 
 
 
98
  except Exception as e:
99
  print(f'Embeddings file not found in Azure, initializing a new one: {e}')
 
 
100
  return []
101
 
102
  def extract_and_save_embeddings(self):
@@ -225,66 +245,57 @@ class EndpointHandler:
225
  print(f"Debug: Starting similarity search with {len(query_images)} query images")
226
  print(f"Debug: Looking for gender: {gender}, top_n: {top_n}")
227
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
228
  similarities = {}
 
 
229
  for i, image_input in enumerate(query_images):
230
- print(f"Debug: Processing query image {i+1}/{len(query_images)}: {image_input}")
 
 
 
 
 
 
231
  try:
232
- # Determine the type of image input
233
- if image_input.startswith('http'):
234
- # It's a URL
235
- img = self.load_image_from_url(image_input)
236
- elif image_input.startswith('data:image/'):
237
- # It's a base64-encoded image
238
- img = self.load_image_from_base64(image_input)
239
- else:
240
- # It's a local file path reference - convert to full Azure blob URL
241
- blob_url = f"https://koottuprod.blob.core.windows.net/koottu-media/{image_input}"
242
- img = self.load_image_from_url(blob_url)
243
-
244
- if img is None:
245
- print(f"Failed to load image: {image_input}")
246
- continue
247
-
248
- faces = self.app.get(img)
249
- if len(faces) == 0:
250
- print(f"Debug: No faces detected in query image {i+1}")
251
- continue
252
-
253
- query_embedding = faces[0].embedding
254
- print(f"Debug: Successfully extracted face embedding from query image {i+1}")
255
-
256
- # Load embeddings database from Azure
257
- embeddings_db = self.load_embeddings_from_azure()
258
- print(f"Debug: Total embeddings in database: {len(embeddings_db)}")
259
-
260
- # Filter to only include images from profile-media folder structure
261
- profile_media_db = [item for item in embeddings_db if 'image_url' in item and 'profile-media' in item['image_url']]
262
- print(f"Debug: Profile-media embeddings: {len(profile_media_db)}")
263
-
264
- # Filter by gender: if 'all', include all items with gender field; otherwise filter by specific gender
265
- if gender == 'all':
266
- filtered_db = [item for item in profile_media_db if 'gender' in item and 'embedding' in item]
267
- else:
268
- filtered_db = [item for item in profile_media_db if 'gender' in item and item['gender'] == gender and 'embedding' in item]
269
- print(f"Debug: Filtered by gender '{gender}': {len(filtered_db)}")
270
-
271
- if len(filtered_db) == 0:
272
- print(f"Debug: No embeddings found for gender '{gender}' in profile-media folder")
273
- print(f"Debug: Available genders in profile-media: {list(set([item.get('gender') for item in profile_media_db if 'gender' in item]))}")
274
- continue
275
-
276
- for item in filtered_db:
277
- similarity = 1 - cosine(query_embedding, np.array(item['embedding']))
278
- if item['image_url'] in similarities:
279
- similarities[item['image_url']].append(similarity)
280
- else:
281
- similarities[item['image_url']] = [similarity]
282
-
283
  except Exception as e:
284
- error_message = f"Error processing image input: {e}"
285
- print(error_message)
286
- # Return empty list instead of error dict
287
- return []
 
 
 
 
 
 
 
 
 
 
288
 
289
  # Aggregate similarities
290
  print(f"Debug: Total similarities found: {len(similarities)}")
@@ -293,6 +304,36 @@ class EndpointHandler:
293
  result = [url for _, url in aggregated_similarities[:top_n]]
294
  print(f"Debug: Returning {len(result)} recommendations")
295
  return result
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
296
 
297
  def find_similar_images_by_embedding(self, query_embedding: np.ndarray, gender: str = 'all', top_n: int = 10, excluded_images: List[str] = None) -> List[str]:
298
  """Find similar images based on a given embedding vector."""
 
15
  from azure.storage.blob import BlobServiceClient, BlobClient, ContainerClient
16
  from config import get_config
17
  import time
18
+ from concurrent.futures import ThreadPoolExecutor, as_completed
19
 
20
  class EndpointHandler:
21
  def __init__(self, model_dir=None):
 
46
 
47
  # Get container client
48
  self.container_client = self.blob_service_client.get_container_client(self.container_name)
49
+
50
+ # Initialize caching
51
+ self.embeddings_cache = None
52
+ self.cache_timestamp = 0
53
+ self.cache_ttl = 3600 # 1 hour in seconds
54
+ self.thread_pool = ThreadPoolExecutor(max_workers=4)
55
 
56
  def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
57
  try:
 
86
  raise ValueError("Invalid JSON structure.")
87
 
88
  def load_embeddings_from_azure(self):
89
+ """Load existing embeddings from Azure Blob Storage with caching."""
90
+ current_time = time.time()
91
+
92
+ # Return cached embeddings if still valid
93
+ if self.embeddings_cache is not None and (current_time - self.cache_timestamp) < self.cache_ttl:
94
+ print(f"Using cached embeddings (age: {int(current_time - self.cache_timestamp)}s)")
95
+ return self.embeddings_cache
96
+
97
+ print("Fetching embeddings from Azure...")
98
  try:
99
  # Check if embeddings file exists in Azure - look in profile-media/embeddings/
100
  blob_name = f'profile-media/embeddings/embeddings_db.json'
 
109
  download_file.write(download_stream.readall())
110
 
111
  with open(temp_file_path, 'r') as f:
112
+ self.embeddings_cache = json.load(f)
113
+ self.cache_timestamp = current_time
114
+ print(f"Loaded {len(self.embeddings_cache)} embeddings from Azure")
115
+ return self.embeddings_cache
116
  except Exception as e:
117
  print(f'Embeddings file not found in Azure, initializing a new one: {e}')
118
+ self.embeddings_cache = []
119
+ self.cache_timestamp = current_time
120
  return []
121
 
122
  def extract_and_save_embeddings(self):
 
245
  print(f"Debug: Starting similarity search with {len(query_images)} query images")
246
  print(f"Debug: Looking for gender: {gender}, top_n: {top_n}")
247
 
248
+ # Load embeddings database once (cached)
249
+ embeddings_db = self.load_embeddings_from_azure()
250
+ print(f"Debug: Total embeddings in database: {len(embeddings_db)}")
251
+
252
+ # Filter to only include images from profile-media folder structure
253
+ profile_media_db = [item for item in embeddings_db if 'image_url' in item and 'profile-media' in item['image_url']]
254
+ print(f"Debug: Profile-media embeddings: {len(profile_media_db)}")
255
+
256
+ # Filter by gender: if 'all', include all items with gender field; otherwise filter by specific gender
257
+ if gender == 'all':
258
+ filtered_db = [item for item in profile_media_db if 'gender' in item and 'embedding' in item]
259
+ else:
260
+ filtered_db = [item for item in profile_media_db if 'gender' in item and item['gender'] == gender and 'embedding' in item]
261
+ print(f"Debug: Filtered by gender '{gender}': {len(filtered_db)}")
262
+
263
+ if len(filtered_db) == 0:
264
+ print(f"Debug: No embeddings found for gender '{gender}' in profile-media folder")
265
+ return []
266
+
267
+ # Process query images in parallel
268
  similarities = {}
269
+ futures = {}
270
+
271
  for i, image_input in enumerate(query_images):
272
+ future = self.thread_pool.submit(self._extract_query_embedding, image_input, i)
273
+ futures[future] = i
274
+
275
+ # Collect results from parallel processing
276
+ query_embeddings = []
277
+ for future in as_completed(futures):
278
+ i = futures[future]
279
  try:
280
+ query_embedding = future.result()
281
+ if query_embedding is not None:
282
+ query_embeddings.append(query_embedding)
283
+ print(f"Debug: Successfully extracted face embedding from query image {i+1}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
284
  except Exception as e:
285
+ print(f"Debug: Error processing query image {i+1}: {e}")
286
+
287
+ if not query_embeddings:
288
+ print("Debug: No valid query embeddings extracted")
289
+ return []
290
+
291
+ # Compute similarities for all query embeddings against filtered database
292
+ for query_embedding in query_embeddings:
293
+ for item in filtered_db:
294
+ similarity = 1 - cosine(query_embedding, np.array(item['embedding']))
295
+ if item['image_url'] in similarities:
296
+ similarities[item['image_url']].append(similarity)
297
+ else:
298
+ similarities[item['image_url']] = [similarity]
299
 
300
  # Aggregate similarities
301
  print(f"Debug: Total similarities found: {len(similarities)}")
 
304
  result = [url for _, url in aggregated_similarities[:top_n]]
305
  print(f"Debug: Returning {len(result)} recommendations")
306
  return result
307
+
308
+ def _extract_query_embedding(self, image_input: str, index: int) -> Any:
309
+ """Extract embedding from a single query image (for parallel processing)."""
310
+ try:
311
+ print(f"Debug: Processing query image {index+1}: {image_input}")
312
+ # Determine the type of image input
313
+ if image_input.startswith('http'):
314
+ # It's a URL
315
+ img = self.load_image_from_url(image_input)
316
+ elif image_input.startswith('data:image/'):
317
+ # It's a base64-encoded image
318
+ img = self.load_image_from_base64(image_input)
319
+ else:
320
+ # It's a local file path reference - convert to full Azure blob URL
321
+ blob_url = f"https://koottuprod.blob.core.windows.net/koottu-media/{image_input}"
322
+ img = self.load_image_from_url(blob_url)
323
+
324
+ if img is None:
325
+ print(f"Failed to load image: {image_input}")
326
+ return None
327
+
328
+ faces = self.app.get(img)
329
+ if len(faces) == 0:
330
+ print(f"Debug: No faces detected in query image {index+1}")
331
+ return None
332
+
333
+ return faces[0].embedding
334
+ except Exception as e:
335
+ print(f"Error extracting embedding from image {index+1}: {e}")
336
+ return None
337
 
338
  def find_similar_images_by_embedding(self, query_embedding: np.ndarray, gender: str = 'all', top_n: int = 10, excluded_images: List[str] = None) -> List[str]:
339
  """Find similar images based on a given embedding vector."""