Spaces:
Sleeping
Sleeping
unfreeze caption generator
Browse files
main.py
CHANGED
|
@@ -12,14 +12,15 @@ import requests
|
|
| 12 |
import torch
|
| 13 |
import torch.nn.functional as F
|
| 14 |
from dotenv import load_dotenv
|
| 15 |
-
from fastapi import FastAPI, File, HTTPException,
|
| 16 |
from fastapi.middleware.cors import CORSMiddleware
|
| 17 |
from huggingface_hub import InferenceClient
|
| 18 |
from PIL import Image
|
|
|
|
|
|
|
|
|
|
| 19 |
from search_engines import SearchEngineManager
|
| 20 |
from utils import SearchCache, URLValidator
|
| 21 |
-
from embeddings import EmbeddingModelFactory, EmbeddingModel, get_default_model_configs
|
| 22 |
-
from patch_attention import PatchAttentionAnalyzer
|
| 23 |
|
| 24 |
# Load environment variables from .env file
|
| 25 |
load_dotenv()
|
|
@@ -95,26 +96,25 @@ class TattooSearchEngine:
|
|
| 95 |
image_b64 = base64.b64encode(img_buffer.getvalue()).decode()
|
| 96 |
image_url = f"data:image/jpeg;base64,{image_b64}"
|
| 97 |
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
caption =
|
| 117 |
-
# caption = completion.choices[0].message.content
|
| 118 |
if caption:
|
| 119 |
match = re.search(r"\{.*\}", caption)
|
| 120 |
if match:
|
|
@@ -256,8 +256,11 @@ class TattooSearchEngine:
|
|
| 256 |
return None
|
| 257 |
|
| 258 |
def download_and_process_image(
|
| 259 |
-
self,
|
| 260 |
-
|
|
|
|
|
|
|
|
|
|
| 261 |
) -> Dict[str, Any]:
|
| 262 |
"""Download and compute similarity for a single image"""
|
| 263 |
candidate_image = self.download_image(url)
|
|
@@ -266,7 +269,9 @@ class TattooSearchEngine:
|
|
| 266 |
|
| 267 |
try:
|
| 268 |
candidate_features = self.embedding_model.encode_image(candidate_image)
|
| 269 |
-
similarity = self.embedding_model.compute_similarity(
|
|
|
|
|
|
|
| 270 |
|
| 271 |
result = {"score": float(similarity), "url": url}
|
| 272 |
|
|
@@ -274,12 +279,16 @@ class TattooSearchEngine:
|
|
| 274 |
if include_patch_attention and query_image is not None:
|
| 275 |
try:
|
| 276 |
analyzer = PatchAttentionAnalyzer(self.embedding_model)
|
| 277 |
-
patch_data = analyzer.compute_patch_similarities(
|
|
|
|
|
|
|
| 278 |
result["patch_attention"] = {
|
| 279 |
"overall_similarity": patch_data["overall_similarity"],
|
| 280 |
"query_grid_size": patch_data["query_grid_size"],
|
| 281 |
"candidate_grid_size": patch_data["candidate_grid_size"],
|
| 282 |
-
"attention_summary": analyzer.get_similarity_summary(
|
|
|
|
|
|
|
| 283 |
}
|
| 284 |
except Exception as e:
|
| 285 |
logger.warning(f"Failed to compute patch attention for {url}: {e}")
|
|
@@ -292,7 +301,10 @@ class TattooSearchEngine:
|
|
| 292 |
return None
|
| 293 |
|
| 294 |
def compute_similarity(
|
| 295 |
-
self,
|
|
|
|
|
|
|
|
|
|
| 296 |
) -> List[Dict[str, Any]]:
|
| 297 |
# Encode query image using the selected embedding model
|
| 298 |
query_features = self.embedding_model.encode_image(query_image)
|
|
@@ -306,7 +318,11 @@ class TattooSearchEngine:
|
|
| 306 |
# Submit all download tasks
|
| 307 |
future_to_url = {
|
| 308 |
executor.submit(
|
| 309 |
-
self.download_and_process_image,
|
|
|
|
|
|
|
|
|
|
|
|
|
| 310 |
): url
|
| 311 |
for url in candidate_urls
|
| 312 |
}
|
|
@@ -343,10 +359,14 @@ class TattooSearchEngine:
|
|
| 343 |
# Global variable to store search engine instance
|
| 344 |
search_engine = None
|
| 345 |
|
|
|
|
| 346 |
def get_search_engine(embedding_model: str = "clip") -> TattooSearchEngine:
|
| 347 |
"""Get or create search engine instance with specified embedding model."""
|
| 348 |
global search_engine
|
| 349 |
-
if
|
|
|
|
|
|
|
|
|
|
| 350 |
search_engine = TattooSearchEngine(embedding_model)
|
| 351 |
return search_engine
|
| 352 |
|
|
@@ -354,8 +374,12 @@ def get_search_engine(embedding_model: str = "clip") -> TattooSearchEngine:
|
|
| 354 |
@app.post("/search")
|
| 355 |
async def search_tattoos(
|
| 356 |
file: UploadFile = File(...),
|
| 357 |
-
embedding_model: str = Query(
|
| 358 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 359 |
):
|
| 360 |
if not file.content_type.startswith("image/"):
|
| 361 |
raise HTTPException(status_code=400, detail="File must be an image")
|
|
@@ -366,7 +390,7 @@ async def search_tattoos(
|
|
| 366 |
if embedding_model not in available_models:
|
| 367 |
raise HTTPException(
|
| 368 |
status_code=400,
|
| 369 |
-
detail=f"Invalid embedding model. Available: {available_models}"
|
| 370 |
)
|
| 371 |
|
| 372 |
# Get search engine with specified embedding model
|
|
@@ -386,17 +410,23 @@ async def search_tattoos(
|
|
| 386 |
candidate_urls = engine.search_images(caption, max_results=100)
|
| 387 |
|
| 388 |
if not candidate_urls:
|
| 389 |
-
return {
|
|
|
|
|
|
|
|
|
|
|
|
|
| 390 |
|
| 391 |
# Compute similarities and rank
|
| 392 |
logger.info("Computing similarities...")
|
| 393 |
-
results = engine.compute_similarity(
|
|
|
|
|
|
|
| 394 |
|
| 395 |
return {
|
| 396 |
"caption": caption,
|
| 397 |
"results": results,
|
| 398 |
"embedding_model": engine.embedding_model.get_model_name(),
|
| 399 |
-
"patch_attention_enabled": include_patch_attention
|
| 400 |
}
|
| 401 |
|
| 402 |
except Exception as e:
|
|
@@ -407,9 +437,15 @@ async def search_tattoos(
|
|
| 407 |
@app.post("/analyze-attention")
|
| 408 |
async def analyze_patch_attention(
|
| 409 |
query_file: UploadFile = File(...),
|
| 410 |
-
candidate_url: str = Query(
|
| 411 |
-
|
| 412 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 413 |
):
|
| 414 |
"""Analyze patch-level attention between query image and a specific candidate image."""
|
| 415 |
if not query_file.content_type.startswith("image/"):
|
|
@@ -421,7 +457,7 @@ async def analyze_patch_attention(
|
|
| 421 |
if embedding_model not in available_models:
|
| 422 |
raise HTTPException(
|
| 423 |
status_code=400,
|
| 424 |
-
detail=f"Invalid embedding model. Available: {available_models}"
|
| 425 |
)
|
| 426 |
|
| 427 |
# Get search engine with specified embedding model
|
|
@@ -434,11 +470,15 @@ async def analyze_patch_attention(
|
|
| 434 |
# Download candidate image
|
| 435 |
candidate_image = engine.download_image(candidate_url)
|
| 436 |
if candidate_image is None:
|
| 437 |
-
raise HTTPException(
|
|
|
|
|
|
|
| 438 |
|
| 439 |
# Analyze patch attention
|
| 440 |
analyzer = PatchAttentionAnalyzer(engine.embedding_model)
|
| 441 |
-
similarity_data = analyzer.compute_patch_similarities(
|
|
|
|
|
|
|
| 442 |
|
| 443 |
result = {
|
| 444 |
"query_image_size": query_image.size,
|
|
@@ -446,8 +486,10 @@ async def analyze_patch_attention(
|
|
| 446 |
"candidate_url": candidate_url,
|
| 447 |
"embedding_model": engine.embedding_model.get_model_name(),
|
| 448 |
"similarity_analysis": analyzer.get_similarity_summary(similarity_data),
|
| 449 |
-
"attention_matrix_shape": similarity_data[
|
| 450 |
-
"top_correspondences": similarity_data[
|
|
|
|
|
|
|
| 451 |
}
|
| 452 |
|
| 453 |
# Add visualizations if requested
|
|
@@ -462,7 +504,7 @@ async def analyze_patch_attention(
|
|
| 462 |
|
| 463 |
result["visualizations"] = {
|
| 464 |
"attention_heatmap": f"data:image/png;base64,{attention_heatmap}",
|
| 465 |
-
"top_correspondences": f"data:image/png;base64,{top_correspondences_viz}"
|
| 466 |
}
|
| 467 |
except Exception as e:
|
| 468 |
logger.warning(f"Failed to generate visualizations: {e}")
|
|
@@ -480,10 +522,7 @@ async def get_available_models():
|
|
| 480 |
"""Get list of available embedding models and their configurations."""
|
| 481 |
models = EmbeddingModelFactory.get_available_models()
|
| 482 |
configs = get_default_model_configs()
|
| 483 |
-
return {
|
| 484 |
-
"available_models": models,
|
| 485 |
-
"model_configs": configs
|
| 486 |
-
}
|
| 487 |
|
| 488 |
|
| 489 |
@app.get("/health")
|
|
|
|
| 12 |
import torch
|
| 13 |
import torch.nn.functional as F
|
| 14 |
from dotenv import load_dotenv
|
| 15 |
+
from fastapi import FastAPI, File, HTTPException, Query, UploadFile
|
| 16 |
from fastapi.middleware.cors import CORSMiddleware
|
| 17 |
from huggingface_hub import InferenceClient
|
| 18 |
from PIL import Image
|
| 19 |
+
|
| 20 |
+
from embeddings import EmbeddingModel, EmbeddingModelFactory, get_default_model_configs
|
| 21 |
+
from patch_attention import PatchAttentionAnalyzer
|
| 22 |
from search_engines import SearchEngineManager
|
| 23 |
from utils import SearchCache, URLValidator
|
|
|
|
|
|
|
| 24 |
|
| 25 |
# Load environment variables from .env file
|
| 26 |
load_dotenv()
|
|
|
|
| 96 |
image_b64 = base64.b64encode(img_buffer.getvalue()).decode()
|
| 97 |
image_url = f"data:image/jpeg;base64,{image_b64}"
|
| 98 |
|
| 99 |
+
completion = self.client.chat.completions.create(
|
| 100 |
+
model=self.vlm_model,
|
| 101 |
+
messages=[
|
| 102 |
+
{
|
| 103 |
+
"role": "user",
|
| 104 |
+
"content": [
|
| 105 |
+
{
|
| 106 |
+
"type": "text",
|
| 107 |
+
"text": "Generate a one search engine query to find the most similar tattoos to this image. Response in json format",
|
| 108 |
+
},
|
| 109 |
+
{
|
| 110 |
+
"type": "image_url",
|
| 111 |
+
"image_url": {"url": image_url},
|
| 112 |
+
},
|
| 113 |
+
],
|
| 114 |
+
}
|
| 115 |
+
],
|
| 116 |
+
)
|
| 117 |
+
caption = completion.choices[0].message.content
|
|
|
|
| 118 |
if caption:
|
| 119 |
match = re.search(r"\{.*\}", caption)
|
| 120 |
if match:
|
|
|
|
| 256 |
return None
|
| 257 |
|
| 258 |
def download_and_process_image(
|
| 259 |
+
self,
|
| 260 |
+
url: str,
|
| 261 |
+
query_features: torch.Tensor,
|
| 262 |
+
query_image: Image.Image = None,
|
| 263 |
+
include_patch_attention: bool = False,
|
| 264 |
) -> Dict[str, Any]:
|
| 265 |
"""Download and compute similarity for a single image"""
|
| 266 |
candidate_image = self.download_image(url)
|
|
|
|
| 269 |
|
| 270 |
try:
|
| 271 |
candidate_features = self.embedding_model.encode_image(candidate_image)
|
| 272 |
+
similarity = self.embedding_model.compute_similarity(
|
| 273 |
+
query_features, candidate_features
|
| 274 |
+
)
|
| 275 |
|
| 276 |
result = {"score": float(similarity), "url": url}
|
| 277 |
|
|
|
|
| 279 |
if include_patch_attention and query_image is not None:
|
| 280 |
try:
|
| 281 |
analyzer = PatchAttentionAnalyzer(self.embedding_model)
|
| 282 |
+
patch_data = analyzer.compute_patch_similarities(
|
| 283 |
+
query_image, candidate_image
|
| 284 |
+
)
|
| 285 |
result["patch_attention"] = {
|
| 286 |
"overall_similarity": patch_data["overall_similarity"],
|
| 287 |
"query_grid_size": patch_data["query_grid_size"],
|
| 288 |
"candidate_grid_size": patch_data["candidate_grid_size"],
|
| 289 |
+
"attention_summary": analyzer.get_similarity_summary(
|
| 290 |
+
patch_data
|
| 291 |
+
),
|
| 292 |
}
|
| 293 |
except Exception as e:
|
| 294 |
logger.warning(f"Failed to compute patch attention for {url}: {e}")
|
|
|
|
| 301 |
return None
|
| 302 |
|
| 303 |
def compute_similarity(
|
| 304 |
+
self,
|
| 305 |
+
query_image: Image.Image,
|
| 306 |
+
candidate_urls: List[str],
|
| 307 |
+
include_patch_attention: bool = False,
|
| 308 |
) -> List[Dict[str, Any]]:
|
| 309 |
# Encode query image using the selected embedding model
|
| 310 |
query_features = self.embedding_model.encode_image(query_image)
|
|
|
|
| 318 |
# Submit all download tasks
|
| 319 |
future_to_url = {
|
| 320 |
executor.submit(
|
| 321 |
+
self.download_and_process_image,
|
| 322 |
+
url,
|
| 323 |
+
query_features,
|
| 324 |
+
query_image,
|
| 325 |
+
include_patch_attention,
|
| 326 |
): url
|
| 327 |
for url in candidate_urls
|
| 328 |
}
|
|
|
|
| 359 |
# Global variable to store search engine instance
|
| 360 |
search_engine = None
|
| 361 |
|
| 362 |
+
|
| 363 |
def get_search_engine(embedding_model: str = "clip") -> TattooSearchEngine:
|
| 364 |
"""Get or create search engine instance with specified embedding model."""
|
| 365 |
global search_engine
|
| 366 |
+
if (
|
| 367 |
+
search_engine is None
|
| 368 |
+
or search_engine.embedding_model.get_model_name().lower() != embedding_model
|
| 369 |
+
):
|
| 370 |
search_engine = TattooSearchEngine(embedding_model)
|
| 371 |
return search_engine
|
| 372 |
|
|
|
|
| 374 |
@app.post("/search")
|
| 375 |
async def search_tattoos(
|
| 376 |
file: UploadFile = File(...),
|
| 377 |
+
embedding_model: str = Query(
|
| 378 |
+
default="clip", description="Embedding model to use (clip, dinov2, siglip)"
|
| 379 |
+
),
|
| 380 |
+
include_patch_attention: bool = Query(
|
| 381 |
+
default=False, description="Include patch-level attention analysis"
|
| 382 |
+
),
|
| 383 |
):
|
| 384 |
if not file.content_type.startswith("image/"):
|
| 385 |
raise HTTPException(status_code=400, detail="File must be an image")
|
|
|
|
| 390 |
if embedding_model not in available_models:
|
| 391 |
raise HTTPException(
|
| 392 |
status_code=400,
|
| 393 |
+
detail=f"Invalid embedding model. Available: {available_models}",
|
| 394 |
)
|
| 395 |
|
| 396 |
# Get search engine with specified embedding model
|
|
|
|
| 410 |
candidate_urls = engine.search_images(caption, max_results=100)
|
| 411 |
|
| 412 |
if not candidate_urls:
|
| 413 |
+
return {
|
| 414 |
+
"caption": caption,
|
| 415 |
+
"results": [],
|
| 416 |
+
"embedding_model": engine.embedding_model.get_model_name(),
|
| 417 |
+
}
|
| 418 |
|
| 419 |
# Compute similarities and rank
|
| 420 |
logger.info("Computing similarities...")
|
| 421 |
+
results = engine.compute_similarity(
|
| 422 |
+
query_image, candidate_urls, include_patch_attention
|
| 423 |
+
)
|
| 424 |
|
| 425 |
return {
|
| 426 |
"caption": caption,
|
| 427 |
"results": results,
|
| 428 |
"embedding_model": engine.embedding_model.get_model_name(),
|
| 429 |
+
"patch_attention_enabled": include_patch_attention,
|
| 430 |
}
|
| 431 |
|
| 432 |
except Exception as e:
|
|
|
|
| 437 |
@app.post("/analyze-attention")
|
| 438 |
async def analyze_patch_attention(
|
| 439 |
query_file: UploadFile = File(...),
|
| 440 |
+
candidate_url: str = Query(
|
| 441 |
+
..., description="URL of the candidate image to compare"
|
| 442 |
+
),
|
| 443 |
+
embedding_model: str = Query(
|
| 444 |
+
default="clip", description="Embedding model to use (clip, dinov2, siglip)"
|
| 445 |
+
),
|
| 446 |
+
include_visualizations: bool = Query(
|
| 447 |
+
default=True, description="Include attention visualizations"
|
| 448 |
+
),
|
| 449 |
):
|
| 450 |
"""Analyze patch-level attention between query image and a specific candidate image."""
|
| 451 |
if not query_file.content_type.startswith("image/"):
|
|
|
|
| 457 |
if embedding_model not in available_models:
|
| 458 |
raise HTTPException(
|
| 459 |
status_code=400,
|
| 460 |
+
detail=f"Invalid embedding model. Available: {available_models}",
|
| 461 |
)
|
| 462 |
|
| 463 |
# Get search engine with specified embedding model
|
|
|
|
| 470 |
# Download candidate image
|
| 471 |
candidate_image = engine.download_image(candidate_url)
|
| 472 |
if candidate_image is None:
|
| 473 |
+
raise HTTPException(
|
| 474 |
+
status_code=400, detail="Failed to download candidate image"
|
| 475 |
+
)
|
| 476 |
|
| 477 |
# Analyze patch attention
|
| 478 |
analyzer = PatchAttentionAnalyzer(engine.embedding_model)
|
| 479 |
+
similarity_data = analyzer.compute_patch_similarities(
|
| 480 |
+
query_image, candidate_image
|
| 481 |
+
)
|
| 482 |
|
| 483 |
result = {
|
| 484 |
"query_image_size": query_image.size,
|
|
|
|
| 486 |
"candidate_url": candidate_url,
|
| 487 |
"embedding_model": engine.embedding_model.get_model_name(),
|
| 488 |
"similarity_analysis": analyzer.get_similarity_summary(similarity_data),
|
| 489 |
+
"attention_matrix_shape": similarity_data["attention_matrix"].shape,
|
| 490 |
+
"top_correspondences": similarity_data["top_correspondences"][
|
| 491 |
+
:10
|
| 492 |
+
], # Top 10
|
| 493 |
}
|
| 494 |
|
| 495 |
# Add visualizations if requested
|
|
|
|
| 504 |
|
| 505 |
result["visualizations"] = {
|
| 506 |
"attention_heatmap": f"data:image/png;base64,{attention_heatmap}",
|
| 507 |
+
"top_correspondences": f"data:image/png;base64,{top_correspondences_viz}",
|
| 508 |
}
|
| 509 |
except Exception as e:
|
| 510 |
logger.warning(f"Failed to generate visualizations: {e}")
|
|
|
|
| 522 |
"""Get list of available embedding models and their configurations."""
|
| 523 |
models = EmbeddingModelFactory.get_available_models()
|
| 524 |
configs = get_default_model_configs()
|
| 525 |
+
return {"available_models": models, "model_configs": configs}
|
|
|
|
|
|
|
|
|
|
| 526 |
|
| 527 |
|
| 528 |
@app.get("/health")
|