from typing import TYPE_CHECKING, Dict, List, Optional

import torch

if TYPE_CHECKING:
    # Pillow is only referenced in the ``query_image`` annotation, so it is not
    # imported at runtime.
    from PIL import Image


def l2_normalize(x: torch.Tensor, eps: float = 1e-12) -> torch.Tensor:
    """Return ``x`` scaled to unit L2 norm along its last dimension.

    ``eps`` is added to the norm so all-zero vectors map to zeros instead of NaN.
    """
    return x / (torch.norm(x, dim=-1, p=2, keepdim=True) + eps)


def search_embeddings(
    query: str,
    query_image: Optional["Image.Image"],
    model,
    embeddings: torch.Tensor,
    documents: List[str],
    image_paths: List[str],
    top_k: int = 20,
    modality: str = "image",
) -> List[Dict]:
    """Rank pre-computed document embeddings by cosine similarity to a query.

    The query is encoded with ``model`` depending on what was supplied:
      * text and image -> ``model.encode_documents(images=[image], texts=[text])``
      * text only      -> ``model.encode_queries([text])``
      * image only     -> ``model.encode_documents(images=[image])``

    Args:
        query: Query text; an empty string means "no text query".
        query_image: Optional query image.
        model: Encoder exposing ``parameters()``, ``encode_queries`` and
            ``encode_documents``; scoring runs on the device of its first
            parameter.
        embeddings: Document embeddings, one row per document (presumably
            ``(num_docs, dim)``; row i must match ``documents[i]`` and
            ``image_paths[i]``).
        documents: Document texts, indexed like ``embeddings``.
        image_paths: Document image paths, indexed like ``embeddings``.
        top_k: Maximum number of results to return.
        modality: ``"image"`` sets ``text`` to None in every result,
            ``"text"`` sets ``image_path`` to None; any other value keeps both.

    Returns:
        Up to ``top_k`` dicts with keys ``rank`` (1-based), ``score`` (cosine
        similarity), ``image_path`` and ``text``, best match first.

    Raises:
        ValueError: If neither ``query`` nor ``query_image`` is provided.
    """
    has_text = bool(query)
    has_image = query_image is not None
    if not (has_text or has_image):
        raise ValueError(
            "search_embeddings requires a text query, a query image, or both"
        )

    device = next(model.parameters()).device

    with torch.inference_mode():
        if has_text and has_image:
            query_emb = model.encode_documents(images=[query_image], texts=[query])
        elif has_text:
            query_emb = model.encode_queries([query])
        else:
            query_emb = model.encode_documents(images=[query_image])

        doc_emb = embeddings.to(device)
        # Match the corpus dtype so the matmul cannot fail on a dtype mismatch
        # (e.g. a half-precision encoder scored against float32 embeddings).
        query_emb = query_emb.to(device=device, dtype=doc_emb.dtype)
        scores = (l2_normalize(query_emb) @ l2_normalize(doc_emb).T).flatten()

        # topk avoids fully sorting the corpus when only top_k hits are needed.
        k = min(max(top_k, 0), scores.numel())
        top_scores, top_indices = torch.topk(scores, k)

    results = []
    for rank, (doc_idx, score) in enumerate(
        zip(top_indices.tolist(), top_scores.tolist()), start=1
    ):
        results.append({
            "rank": rank,
            "score": score,
            "image_path": image_paths[doc_idx] if modality != "text" else None,
            "text": documents[doc_idx] if modality != "image" else None,
        })
    return results


def _load_result_image(path):
    """Load a result's image, or return None for text-only results (no path)."""
    if path is None:
        return None
    # Imported lazily: transformers is only needed when there is an image to load.
    from transformers.image_utils import load_image

    return load_image(path)


def rerank_results(
    query: str,
    results: List[Dict],
    rerank_model,
    rerank_processor,
    device: str,
    top_k: int = 10,
) -> List[Dict]:
    """Re-score the first ``top_k`` results with a cross-encoder and reorder them.

    Only the first ``top_k`` entries of ``results`` are scored and returned;
    entries after them are dropped. The input dicts are left untouched: each
    returned dict is a shallow copy with ``rank`` (1-based) and ``score`` (the
    cross-encoder logit) overwritten.

    Args:
        query: Query text paired with every candidate.
        results: Candidate dicts as produced by ``search_embeddings``.
        rerank_model: Cross-encoder called as
            ``rerank_model(**batch, return_dict=True)``; the output's ``logits``
            must contain exactly one score per candidate.
        rerank_processor: Object exposing
            ``process_queries_documents_crossencoder(examples)``.
        device: Device that tensor entries of the processed batch are moved to.
        top_k: Number of leading candidates to rerank.

    Returns:
        The reranked candidates, highest logit first; empty when there is
        nothing to rerank.

    Raises:
        ValueError: If the model does not return one logit per candidate.
    """
    candidates = results[:top_k]
    if not candidates:
        return []

    examples = [
        {
            "question": query,
            # search_embeddings stores text=None for modality="image"; send ""
            # rather than None in that case.
            "doc_text": candidate.get("text") or "",
            # NOTE(review): text-only candidates get doc_image=None — confirm the
            # processor accepts that.
            "doc_image": _load_result_image(candidate.get("image_path")),
        }
        for candidate in candidates
    ]

    batch = rerank_processor.process_queries_documents_crossencoder(examples)
    batch = {
        key: value.to(device) if isinstance(value, torch.Tensor) else value
        for key, value in batch.items()
    }

    with torch.no_grad():
        outputs = rerank_model(**batch, return_dict=True)

    # Flatten (n, 1) or (n,) logits to (n,). squeeze(-1) alone would turn a
    # single-candidate (1,) output into a 0-d tensor that cannot be iterated.
    logits = outputs.logits.reshape(-1)
    if logits.numel() != len(candidates):
        raise ValueError(
            f"expected one logit per candidate ({len(candidates)}), "
            f"got logits of shape {tuple(outputs.logits.shape)}"
        )

    order = torch.argsort(logits, descending=True).tolist()
    scores = logits.tolist()

    reranked = []
    for new_rank, pos in enumerate(order, start=1):
        entry = dict(candidates[pos])  # copy so the caller's dicts are not mutated
        entry["rank"] = new_rank
        entry["score"] = scores[pos]
        reranked.append(entry)
    return reranked