onurcopur committed on
Commit
8880ccb
·
1 Parent(s): 917a5d3

unfreeze caption generator

Browse files
Files changed (1) hide show
  1. main.py +89 -50
main.py CHANGED
@@ -12,14 +12,15 @@ import requests
12
  import torch
13
  import torch.nn.functional as F
14
  from dotenv import load_dotenv
15
- from fastapi import FastAPI, File, HTTPException, UploadFile, Query
16
  from fastapi.middleware.cors import CORSMiddleware
17
  from huggingface_hub import InferenceClient
18
  from PIL import Image
 
 
 
19
  from search_engines import SearchEngineManager
20
  from utils import SearchCache, URLValidator
21
- from embeddings import EmbeddingModelFactory, EmbeddingModel, get_default_model_configs
22
- from patch_attention import PatchAttentionAnalyzer
23
 
24
  # Load environment variables from .env file
25
  load_dotenv()
@@ -95,26 +96,25 @@ class TattooSearchEngine:
95
  image_b64 = base64.b64encode(img_buffer.getvalue()).decode()
96
  image_url = f"data:image/jpeg;base64,{image_b64}"
97
 
98
- # completion = self.client.chat.completions.create(
99
- # model=self.vlm_model,
100
- # messages=[
101
- # {
102
- # "role": "user",
103
- # "content": [
104
- # {
105
- # "type": "text",
106
- # "text": "Generate a one search engine query to find the most similar tattoos to this image. Response in json format",
107
- # },
108
- # {
109
- # "type": "image_url",
110
- # "image_url": {"url": image_url},
111
- # },
112
- # ],
113
- # }
114
- # ],
115
- # )
116
- caption = '<|begin_of_box|>{"search_query": "hand tattoo geometric human figure abstract blackwork"}<|end_of_box|>'
117
- # caption = completion.choices[0].message.content
118
  if caption:
119
  match = re.search(r"\{.*\}", caption)
120
  if match:
@@ -256,8 +256,11 @@ class TattooSearchEngine:
256
  return None
257
 
258
  def download_and_process_image(
259
- self, url: str, query_features: torch.Tensor, query_image: Image.Image = None,
260
- include_patch_attention: bool = False
 
 
 
261
  ) -> Dict[str, Any]:
262
  """Download and compute similarity for a single image"""
263
  candidate_image = self.download_image(url)
@@ -266,7 +269,9 @@ class TattooSearchEngine:
266
 
267
  try:
268
  candidate_features = self.embedding_model.encode_image(candidate_image)
269
- similarity = self.embedding_model.compute_similarity(query_features, candidate_features)
 
 
270
 
271
  result = {"score": float(similarity), "url": url}
272
 
@@ -274,12 +279,16 @@ class TattooSearchEngine:
274
  if include_patch_attention and query_image is not None:
275
  try:
276
  analyzer = PatchAttentionAnalyzer(self.embedding_model)
277
- patch_data = analyzer.compute_patch_similarities(query_image, candidate_image)
 
 
278
  result["patch_attention"] = {
279
  "overall_similarity": patch_data["overall_similarity"],
280
  "query_grid_size": patch_data["query_grid_size"],
281
  "candidate_grid_size": patch_data["candidate_grid_size"],
282
- "attention_summary": analyzer.get_similarity_summary(patch_data)
 
 
283
  }
284
  except Exception as e:
285
  logger.warning(f"Failed to compute patch attention for {url}: {e}")
@@ -292,7 +301,10 @@ class TattooSearchEngine:
292
  return None
293
 
294
  def compute_similarity(
295
- self, query_image: Image.Image, candidate_urls: List[str], include_patch_attention: bool = False
 
 
 
296
  ) -> List[Dict[str, Any]]:
297
  # Encode query image using the selected embedding model
298
  query_features = self.embedding_model.encode_image(query_image)
@@ -306,7 +318,11 @@ class TattooSearchEngine:
306
  # Submit all download tasks
307
  future_to_url = {
308
  executor.submit(
309
- self.download_and_process_image, url, query_features, query_image, include_patch_attention
 
 
 
 
310
  ): url
311
  for url in candidate_urls
312
  }
@@ -343,10 +359,14 @@ class TattooSearchEngine:
343
  # Global variable to store search engine instance
344
  search_engine = None
345
 
 
346
  def get_search_engine(embedding_model: str = "clip") -> TattooSearchEngine:
347
  """Get or create search engine instance with specified embedding model."""
348
  global search_engine
349
- if search_engine is None or search_engine.embedding_model.get_model_name().lower() != embedding_model:
 
 
 
350
  search_engine = TattooSearchEngine(embedding_model)
351
  return search_engine
352
 
@@ -354,8 +374,12 @@ def get_search_engine(embedding_model: str = "clip") -> TattooSearchEngine:
354
  @app.post("/search")
355
  async def search_tattoos(
356
  file: UploadFile = File(...),
357
- embedding_model: str = Query(default="clip", description="Embedding model to use (clip, dinov2, siglip)"),
358
- include_patch_attention: bool = Query(default=False, description="Include patch-level attention analysis")
 
 
 
 
359
  ):
360
  if not file.content_type.startswith("image/"):
361
  raise HTTPException(status_code=400, detail="File must be an image")
@@ -366,7 +390,7 @@ async def search_tattoos(
366
  if embedding_model not in available_models:
367
  raise HTTPException(
368
  status_code=400,
369
- detail=f"Invalid embedding model. Available: {available_models}"
370
  )
371
 
372
  # Get search engine with specified embedding model
@@ -386,17 +410,23 @@ async def search_tattoos(
386
  candidate_urls = engine.search_images(caption, max_results=100)
387
 
388
  if not candidate_urls:
389
- return {"caption": caption, "results": [], "embedding_model": engine.embedding_model.get_model_name()}
 
 
 
 
390
 
391
  # Compute similarities and rank
392
  logger.info("Computing similarities...")
393
- results = engine.compute_similarity(query_image, candidate_urls, include_patch_attention)
 
 
394
 
395
  return {
396
  "caption": caption,
397
  "results": results,
398
  "embedding_model": engine.embedding_model.get_model_name(),
399
- "patch_attention_enabled": include_patch_attention
400
  }
401
 
402
  except Exception as e:
@@ -407,9 +437,15 @@ async def search_tattoos(
407
  @app.post("/analyze-attention")
408
  async def analyze_patch_attention(
409
  query_file: UploadFile = File(...),
410
- candidate_url: str = Query(..., description="URL of the candidate image to compare"),
411
- embedding_model: str = Query(default="clip", description="Embedding model to use (clip, dinov2, siglip)"),
412
- include_visualizations: bool = Query(default=True, description="Include attention visualizations")
 
 
 
 
 
 
413
  ):
414
  """Analyze patch-level attention between query image and a specific candidate image."""
415
  if not query_file.content_type.startswith("image/"):
@@ -421,7 +457,7 @@ async def analyze_patch_attention(
421
  if embedding_model not in available_models:
422
  raise HTTPException(
423
  status_code=400,
424
- detail=f"Invalid embedding model. Available: {available_models}"
425
  )
426
 
427
  # Get search engine with specified embedding model
@@ -434,11 +470,15 @@ async def analyze_patch_attention(
434
  # Download candidate image
435
  candidate_image = engine.download_image(candidate_url)
436
  if candidate_image is None:
437
- raise HTTPException(status_code=400, detail="Failed to download candidate image")
 
 
438
 
439
  # Analyze patch attention
440
  analyzer = PatchAttentionAnalyzer(engine.embedding_model)
441
- similarity_data = analyzer.compute_patch_similarities(query_image, candidate_image)
 
 
442
 
443
  result = {
444
  "query_image_size": query_image.size,
@@ -446,8 +486,10 @@ async def analyze_patch_attention(
446
  "candidate_url": candidate_url,
447
  "embedding_model": engine.embedding_model.get_model_name(),
448
  "similarity_analysis": analyzer.get_similarity_summary(similarity_data),
449
- "attention_matrix_shape": similarity_data['attention_matrix'].shape,
450
- "top_correspondences": similarity_data['top_correspondences'][:10] # Top 10
 
 
451
  }
452
 
453
  # Add visualizations if requested
@@ -462,7 +504,7 @@ async def analyze_patch_attention(
462
 
463
  result["visualizations"] = {
464
  "attention_heatmap": f"data:image/png;base64,{attention_heatmap}",
465
- "top_correspondences": f"data:image/png;base64,{top_correspondences_viz}"
466
  }
467
  except Exception as e:
468
  logger.warning(f"Failed to generate visualizations: {e}")
@@ -480,10 +522,7 @@ async def get_available_models():
480
  """Get list of available embedding models and their configurations."""
481
  models = EmbeddingModelFactory.get_available_models()
482
  configs = get_default_model_configs()
483
- return {
484
- "available_models": models,
485
- "model_configs": configs
486
- }
487
 
488
 
489
  @app.get("/health")
 
12
  import torch
13
  import torch.nn.functional as F
14
  from dotenv import load_dotenv
15
+ from fastapi import FastAPI, File, HTTPException, Query, UploadFile
16
  from fastapi.middleware.cors import CORSMiddleware
17
  from huggingface_hub import InferenceClient
18
  from PIL import Image
19
+
20
+ from embeddings import EmbeddingModel, EmbeddingModelFactory, get_default_model_configs
21
+ from patch_attention import PatchAttentionAnalyzer
22
  from search_engines import SearchEngineManager
23
  from utils import SearchCache, URLValidator
 
 
24
 
25
  # Load environment variables from .env file
26
  load_dotenv()
 
96
  image_b64 = base64.b64encode(img_buffer.getvalue()).decode()
97
  image_url = f"data:image/jpeg;base64,{image_b64}"
98
 
99
+ completion = self.client.chat.completions.create(
100
+ model=self.vlm_model,
101
+ messages=[
102
+ {
103
+ "role": "user",
104
+ "content": [
105
+ {
106
+ "type": "text",
107
+ "text": "Generate a one search engine query to find the most similar tattoos to this image. Response in json format",
108
+ },
109
+ {
110
+ "type": "image_url",
111
+ "image_url": {"url": image_url},
112
+ },
113
+ ],
114
+ }
115
+ ],
116
+ )
117
+ caption = completion.choices[0].message.content
 
118
  if caption:
119
  match = re.search(r"\{.*\}", caption)
120
  if match:
 
256
  return None
257
 
258
  def download_and_process_image(
259
+ self,
260
+ url: str,
261
+ query_features: torch.Tensor,
262
+ query_image: Image.Image = None,
263
+ include_patch_attention: bool = False,
264
  ) -> Dict[str, Any]:
265
  """Download and compute similarity for a single image"""
266
  candidate_image = self.download_image(url)
 
269
 
270
  try:
271
  candidate_features = self.embedding_model.encode_image(candidate_image)
272
+ similarity = self.embedding_model.compute_similarity(
273
+ query_features, candidate_features
274
+ )
275
 
276
  result = {"score": float(similarity), "url": url}
277
 
 
279
  if include_patch_attention and query_image is not None:
280
  try:
281
  analyzer = PatchAttentionAnalyzer(self.embedding_model)
282
+ patch_data = analyzer.compute_patch_similarities(
283
+ query_image, candidate_image
284
+ )
285
  result["patch_attention"] = {
286
  "overall_similarity": patch_data["overall_similarity"],
287
  "query_grid_size": patch_data["query_grid_size"],
288
  "candidate_grid_size": patch_data["candidate_grid_size"],
289
+ "attention_summary": analyzer.get_similarity_summary(
290
+ patch_data
291
+ ),
292
  }
293
  except Exception as e:
294
  logger.warning(f"Failed to compute patch attention for {url}: {e}")
 
301
  return None
302
 
303
  def compute_similarity(
304
+ self,
305
+ query_image: Image.Image,
306
+ candidate_urls: List[str],
307
+ include_patch_attention: bool = False,
308
  ) -> List[Dict[str, Any]]:
309
  # Encode query image using the selected embedding model
310
  query_features = self.embedding_model.encode_image(query_image)
 
318
  # Submit all download tasks
319
  future_to_url = {
320
  executor.submit(
321
+ self.download_and_process_image,
322
+ url,
323
+ query_features,
324
+ query_image,
325
+ include_patch_attention,
326
  ): url
327
  for url in candidate_urls
328
  }
 
359
  # Global variable to store search engine instance
360
  search_engine = None
361
 
362
+
363
  def get_search_engine(embedding_model: str = "clip") -> TattooSearchEngine:
364
  """Get or create search engine instance with specified embedding model."""
365
  global search_engine
366
+ if (
367
+ search_engine is None
368
+ or search_engine.embedding_model.get_model_name().lower() != embedding_model
369
+ ):
370
  search_engine = TattooSearchEngine(embedding_model)
371
  return search_engine
372
 
 
374
  @app.post("/search")
375
  async def search_tattoos(
376
  file: UploadFile = File(...),
377
+ embedding_model: str = Query(
378
+ default="clip", description="Embedding model to use (clip, dinov2, siglip)"
379
+ ),
380
+ include_patch_attention: bool = Query(
381
+ default=False, description="Include patch-level attention analysis"
382
+ ),
383
  ):
384
  if not file.content_type.startswith("image/"):
385
  raise HTTPException(status_code=400, detail="File must be an image")
 
390
  if embedding_model not in available_models:
391
  raise HTTPException(
392
  status_code=400,
393
+ detail=f"Invalid embedding model. Available: {available_models}",
394
  )
395
 
396
  # Get search engine with specified embedding model
 
410
  candidate_urls = engine.search_images(caption, max_results=100)
411
 
412
  if not candidate_urls:
413
+ return {
414
+ "caption": caption,
415
+ "results": [],
416
+ "embedding_model": engine.embedding_model.get_model_name(),
417
+ }
418
 
419
  # Compute similarities and rank
420
  logger.info("Computing similarities...")
421
+ results = engine.compute_similarity(
422
+ query_image, candidate_urls, include_patch_attention
423
+ )
424
 
425
  return {
426
  "caption": caption,
427
  "results": results,
428
  "embedding_model": engine.embedding_model.get_model_name(),
429
+ "patch_attention_enabled": include_patch_attention,
430
  }
431
 
432
  except Exception as e:
 
437
  @app.post("/analyze-attention")
438
  async def analyze_patch_attention(
439
  query_file: UploadFile = File(...),
440
+ candidate_url: str = Query(
441
+ ..., description="URL of the candidate image to compare"
442
+ ),
443
+ embedding_model: str = Query(
444
+ default="clip", description="Embedding model to use (clip, dinov2, siglip)"
445
+ ),
446
+ include_visualizations: bool = Query(
447
+ default=True, description="Include attention visualizations"
448
+ ),
449
  ):
450
  """Analyze patch-level attention between query image and a specific candidate image."""
451
  if not query_file.content_type.startswith("image/"):
 
457
  if embedding_model not in available_models:
458
  raise HTTPException(
459
  status_code=400,
460
+ detail=f"Invalid embedding model. Available: {available_models}",
461
  )
462
 
463
  # Get search engine with specified embedding model
 
470
  # Download candidate image
471
  candidate_image = engine.download_image(candidate_url)
472
  if candidate_image is None:
473
+ raise HTTPException(
474
+ status_code=400, detail="Failed to download candidate image"
475
+ )
476
 
477
  # Analyze patch attention
478
  analyzer = PatchAttentionAnalyzer(engine.embedding_model)
479
+ similarity_data = analyzer.compute_patch_similarities(
480
+ query_image, candidate_image
481
+ )
482
 
483
  result = {
484
  "query_image_size": query_image.size,
 
486
  "candidate_url": candidate_url,
487
  "embedding_model": engine.embedding_model.get_model_name(),
488
  "similarity_analysis": analyzer.get_similarity_summary(similarity_data),
489
+ "attention_matrix_shape": similarity_data["attention_matrix"].shape,
490
+ "top_correspondences": similarity_data["top_correspondences"][
491
+ :10
492
+ ], # Top 10
493
  }
494
 
495
  # Add visualizations if requested
 
504
 
505
  result["visualizations"] = {
506
  "attention_heatmap": f"data:image/png;base64,{attention_heatmap}",
507
+ "top_correspondences": f"data:image/png;base64,{top_correspondences_viz}",
508
  }
509
  except Exception as e:
510
  logger.warning(f"Failed to generate visualizations: {e}")
 
522
  """Get list of available embedding models and their configurations."""
523
  models = EmbeddingModelFactory.get_available_models()
524
  configs = get_default_model_configs()
525
+ return {"available_models": models, "model_configs": configs}
 
 
 
526
 
527
 
528
  @app.get("/health")