File size: 2,978 Bytes
3f8c153
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
import torch
from typing import List, Dict, Optional
from PIL import Image

def l2_normalize(x: torch.Tensor, eps: float = 1e-12) -> torch.Tensor:
    """Normalize *x* to unit L2 norm along its last dimension.

    A small ``eps`` is added to the denominator so all-zero vectors do
    not produce NaNs (they are returned as zeros instead).
    """
    norms = x.norm(p=2, dim=-1, keepdim=True)
    return x / (norms + eps)


def search_embeddings(
    query: str,
    query_image: Optional[Image.Image],
    model,
    embeddings: torch.Tensor,
    documents: List[str],
    image_paths: List[str],
    top_k: int = 20,
    modality: str = "image"
) -> List[Dict]:
    """
    Search pre-computed embeddings with a text and/or image query.

    Encodes the query with `model` (text-only via `encode_queries`;
    image or image+text via `encode_documents`), scores every document
    by cosine similarity, and returns the best `top_k` matches.

    Args:
        query: Text query; may be empty if `query_image` is given.
        query_image: Optional PIL image query.
        model: Encoder exposing `encode_queries` / `encode_documents`
            and torch parameters (used to infer the device).
        embeddings: (num_docs, dim) document embedding matrix, rows
            aligned with `documents` / `image_paths`.
        documents: Document texts, one per embedding row.
        image_paths: Image paths, one per embedding row.
        top_k: Maximum number of results to return.
        modality: "image" (image_path only), "text" (text only), or any
            other value to populate both fields.

    Returns:
        List of dicts with 'rank', 'score', 'image_path', 'text',
        ordered by descending cosine similarity.

    Raises:
        ValueError: If both `query` and `query_image` are empty/None.
    """
    if not query and query_image is None:
        # Previously this fell through to encode_documents(images=[None])
        # and crashed inside the model with an opaque error.
        raise ValueError("Provide a text query, an image query, or both.")

    device = next(model.parameters()).device

    # Encode the query once; no gradients are needed for retrieval.
    with torch.inference_mode():
        if query_image is not None and query:
            query_emb = model.encode_documents(images=[query_image], texts=[query])
        elif query:
            query_emb = model.encode_queries([query])
        else:
            query_emb = model.encode_documents(images=[query_image])

    # Cosine similarity = dot product of L2-normalized vectors.
    cos_sim = l2_normalize(query_emb.to(device)) @ l2_normalize(embeddings.to(device)).T
    scores = cos_sim.flatten()
    order = torch.argsort(scores, descending=True)

    # Format the top-k hits; 1-based ranks.
    results = []
    for rank, idx in enumerate(order[:top_k], 1):
        doc_idx = idx.item()
        results.append({
            "rank": rank,
            "score": scores[doc_idx].item(),
            # Only populate the field(s) relevant to the requested modality.
            "image_path": image_paths[doc_idx] if modality != "text" else None,
            "text": documents[doc_idx] if modality != "image" else None,
        })
    return results


def rerank_results(
    query: str,
    results: List[Dict],
    rerank_model,
    rerank_processor,
    device: str,
    top_k: int = 10
) -> List[Dict]:
    """
    Rerank the top retrieval results with a cross-encoder.

    Scores the first `top_k` entries of `results` against `query` in a
    single batch and returns them reordered by cross-encoder logit.
    Entries beyond `top_k` are dropped. The input `results` list and
    its dicts are left unmodified (a new list of copied dicts is
    returned).

    Args:
        query: Text query used for scoring.
        results: Retrieval results as produced by `search_embeddings`
            (dicts with 'image_path', 'text', 'rank', 'score').
        rerank_model: Cross-encoder whose output exposes `.logits`.
        rerank_processor: Processor exposing
            `process_queries_documents_crossencoder`.
        device: Target device for tensor batch fields.
        top_k: Number of leading results to rerank.

    Returns:
        New list of result dicts with updated 'rank' (1-based) and
        'score' (cross-encoder logit), ordered by descending score.
    """
    from transformers.image_utils import load_image

    candidates = results[:top_k]

    # Build one cross-encoder example per candidate.
    examples = []
    for result in candidates:
        examples.append({
            "question": query,
            # 'text' may be present but None (image-only retrieval);
            # the cross-encoder expects a string, so coerce to "".
            "doc_text": result.get("text") or "",
            "doc_image": load_image(result["image_path"]),
        })

    batch = rerank_processor.process_queries_documents_crossencoder(examples)
    # Move only tensor fields to the target device; leave the rest as-is.
    batch = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in batch.items()}

    # Inference only — no gradients.
    with torch.no_grad():
        outputs = rerank_model(**batch, return_dict=True)

    logits = outputs.logits.squeeze(-1)
    order = torch.argsort(logits, descending=True)

    # Copy each dict so the caller's original results are not mutated.
    reranked = []
    for new_rank, idx in enumerate(order, 1):
        entry = dict(candidates[idx.item()])
        entry["rank"] = new_rank
        entry["score"] = logits[idx].item()
        reranked.append(entry)
    return reranked