# Hugging Face upload metadata (commented out so the module parses):
# TurkishCodeMan's picture
# Upload folder using huggingface_hub
# 3f8c153 verified
import torch
from typing import List, Dict, Optional
from PIL import Image
def l2_normalize(x: torch.Tensor, eps: float = 1e-12) -> torch.Tensor:
    """Normalize *x* to unit L2 norm along its last dimension.

    Args:
        x: Input tensor of any shape.
        eps: Small constant added to the norm to avoid division by zero
            for all-zero rows.

    Returns:
        Tensor of the same shape as *x* with each last-axis vector scaled
        to unit length (zero vectors stay zero).
    """
    # Explicit sqrt(sum(x^2)) is equivalent to torch.norm(p=2) for real tensors.
    denom = x.pow(2).sum(dim=-1, keepdim=True).sqrt() + eps
    return x / denom
def search_embeddings(
query: str,
query_image: Optional[Image.Image],
model,
embeddings: torch.Tensor,
documents: List[str],
image_paths: List[str],
top_k: int = 20,
modality: str = "image"
) -> List[Dict]:
"""
Search embeddings using query.
Returns:
List of dicts with 'image_path', 'text', 'rank', 'score'
"""
device = next(model.parameters()).device
# Encode query
with torch.inference_mode():
if query_image and query:
query_emb = model.encode_documents(images=[query_image], texts=[query])
elif query:
query_emb = model.encode_queries([query])
else:
query_emb = model.encode_documents(images=[query_image])
# Compute similarity
cos_sim = l2_normalize(query_emb.to(device)) @ l2_normalize(embeddings.to(device)).T
cos_sim_flat = cos_sim.flatten()
sorted_indices = torch.argsort(cos_sim_flat, descending=True)
# Format results
results = []
for rank, idx in enumerate(sorted_indices[:top_k], 1):
doc_idx = idx.item()
score = cos_sim_flat[doc_idx].item()
result = {
"rank": rank,
"score": score,
"image_path": image_paths[doc_idx] if modality != "text" else None,
"text": documents[doc_idx] if modality != "image" else None
}
results.append(result)
return results
def rerank_results(
    query: str,
    results: List[Dict],
    rerank_model,
    rerank_processor,
    device: str,
    top_k: int = 10
) -> List[Dict]:
    """
    Rerank the top-k retrieval results with a cross-encoder.

    Args:
        query: The user's text query.
        results: Retrieval results as dicts with at least 'image_path'
            (loadable by transformers' ``load_image``) and optionally 'text'.
        rerank_model: Cross-encoder model returning ``.logits`` of shape
            (top_k,) or (top_k, 1).
        rerank_processor: Processor exposing
            ``process_queries_documents_crossencoder(examples)``.
        device: Device string to move batch tensors to (e.g. "cuda").
        top_k: How many of the leading results to rerank; the rest are
            dropped from the returned list (matching prior behavior).

    Returns:
        A NEW list of result dicts (inputs are not mutated), ordered by
        descending cross-encoder logit, with 'rank' and 'score' updated.
    """
    from transformers.image_utils import load_image

    # Build one (query, doc_text, doc_image) example per candidate.
    examples = []
    for result in results[:top_k]:
        img = load_image(result["image_path"])
        examples.append({
            "question": query,
            "doc_text": result.get("text", ""),
            "doc_image": img
        })

    # Tokenize/collate, then move only tensor fields to the target device.
    batch = rerank_processor.process_queries_documents_crossencoder(examples)
    batch = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in batch.items()}

    # Score candidates; squeeze trailing dim so logits is 1-D.
    with torch.no_grad():
        outputs = rerank_model(**batch, return_dict=True)
        logits = outputs.logits.squeeze(-1)

    rerank_indices = torch.argsort(logits, descending=True)

    # Reorder into a fresh list; copy each dict so the caller's `results`
    # entries are not mutated in place (the original code clobbered them).
    reranked = []
    for new_rank, idx in enumerate(rerank_indices, 1):
        entry = dict(results[idx.item()])
        entry["rank"] = new_rank
        entry["score"] = logits[idx].item()
        reranked.append(entry)
    return reranked