Spaces:
Runtime error
Runtime error
| import torch | |
| from typing import List, Dict, Optional | |
| from PIL import Image | |
def l2_normalize(x: torch.Tensor, eps: float = 1e-12) -> torch.Tensor:
    """Normalize *x* to unit L2 norm along its last dimension.

    ``eps`` is added to the norm so all-zero rows divide safely
    (they come back as zeros instead of NaN).
    """
    norms = torch.norm(x, dim=-1, p=2, keepdim=True)
    return x / (norms + eps)
def search_embeddings(
    query: str,
    query_image: Optional[Image.Image],
    model,
    embeddings: torch.Tensor,
    documents: List[str],
    image_paths: List[str],
    top_k: int = 20,
    modality: str = "image"
) -> List[Dict]:
    """
    Search pre-computed document embeddings with a text and/or image query.

    Args:
        query: Text query; may be empty when searching by image only.
        query_image: Optional query image, used alone or combined with text.
        model: Encoder exposing ``parameters()``, ``encode_queries`` and
            ``encode_documents``; its parameter device hosts the matmul.
        embeddings: Document embedding matrix, one row per document.
        documents: Document texts aligned with ``embeddings`` rows.
        image_paths: Image paths aligned with ``embeddings`` rows.
        top_k: Maximum number of results to return (clamped to corpus size).
        modality: "text" suppresses 'image_path', "image" suppresses 'text';
            any other value fills both fields.

    Returns:
        List of dicts with 'rank', 'score', 'image_path', 'text',
        ordered by descending cosine similarity.

    Raises:
        ValueError: If both ``query`` and ``query_image`` are missing.
    """
    if not query and query_image is None:
        raise ValueError("Provide a text query, a query image, or both.")

    device = next(model.parameters()).device

    # Encode the query. Test the image with `is not None`: PIL images
    # should not be truth-tested.
    with torch.inference_mode():
        if query_image is not None and query:
            query_emb = model.encode_documents(images=[query_image], texts=[query])
        elif query:
            query_emb = model.encode_queries([query])
        else:
            query_emb = model.encode_documents(images=[query_image])

        # Cosine similarity between the single query row and every document.
        cos_sim = l2_normalize(query_emb.to(device)) @ l2_normalize(embeddings.to(device)).T

    cos_sim_flat = cos_sim.flatten()
    # topk avoids sorting the full similarity vector; clamp so a top_k
    # larger than the corpus does not raise.
    k = min(top_k, cos_sim_flat.numel())
    top_scores, top_indices = torch.topk(cos_sim_flat, k)

    results = []
    for rank, (score, doc_idx) in enumerate(
        zip(top_scores.tolist(), top_indices.tolist()), 1
    ):
        results.append({
            "rank": rank,
            "score": score,
            "image_path": image_paths[doc_idx] if modality != "text" else None,
            "text": documents[doc_idx] if modality != "image" else None,
        })
    return results
def rerank_results(
    query: str,
    results: List[Dict],
    rerank_model,
    rerank_processor,
    device: str,
    top_k: int = 10
) -> List[Dict]:
    """
    Rerank the leading search results with a cross-encoder.

    Args:
        query: The text query used for the original search.
        results: Search results; each dict needs 'image_path' and may carry
            'text'. Only the first ``top_k`` entries are reranked.
        rerank_model: Cross-encoder returning per-example relevance logits.
        rerank_processor: Processor exposing
            ``process_queries_documents_crossencoder``.
        device: Device the batched tensors are moved to.
        top_k: Number of leading results to rerank.

    Returns:
        A NEW list of result dicts (the input dicts are not mutated),
        ordered by descending cross-encoder score, with 'rank' and
        'score' rewritten. Empty when there is nothing to rerank.
    """
    from transformers.image_utils import load_image

    candidates = results[:top_k]
    if not candidates:
        # Nothing to score; avoid building an empty batch.
        return []

    # Build one (query, document) example per candidate.
    examples = []
    for result in candidates:
        img = load_image(result["image_path"])
        examples.append({
            "question": query,
            # 'text' may be present but None (image-only results); the
            # cross-encoder expects a string, so coerce None to "".
            "doc_text": result.get("text") or "",
            "doc_image": img,
        })

    batch = rerank_processor.process_queries_documents_crossencoder(examples)
    batch = {
        k: v.to(device) if isinstance(v, torch.Tensor) else v
        for k, v in batch.items()
    }

    with torch.no_grad():
        outputs = rerank_model(**batch, return_dict=True)
        logits = outputs.logits.squeeze(-1)

    rerank_indices = torch.argsort(logits, descending=True)

    # Copy each dict so the caller's `results` list stays untouched
    # (the original mutated shared dicts in place).
    reranked = []
    for new_rank, idx in enumerate(rerank_indices.tolist(), 1):
        entry = dict(candidates[idx])
        entry["rank"] = new_rank
        entry["score"] = logits[idx].item()
        reranked.append(entry)
    return reranked