File size: 4,846 Bytes
09c4529
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
"""
RexReranker Inference Utilities.

This module provides helper functions for converting model logits to relevance scores.
The model outputs logits for 11 bins representing a distribution over [0, 1].
To get a relevance score, apply softmax and compute the expected value.

Example usage:
    from transformers import AutoModelForSequenceClassification, AutoTokenizer
    from utils import logits_to_relevance, logits_to_relevance_with_uncertainty
    import torch
    
    model = AutoModelForSequenceClassification.from_pretrained("path/to/model")
    tokenizer = AutoTokenizer.from_pretrained("path/to/model")
    
    inputs = tokenizer(
        "Query: best laptop",
        "Title: MacBook Pro\nDescription: Great laptop for developers",
        return_tensors="pt",
        truncation=True,
    )
    
    with torch.no_grad():
        outputs = model(**inputs)
        
        # Simple relevance score
        relevance = logits_to_relevance(outputs.logits)
        print(f"Relevance: {relevance.item():.3f}")
        
        # With uncertainty estimates
        result = logits_to_relevance_with_uncertainty(outputs.logits)
        print(f"Relevance: {result['relevance'].item():.3f}")
        print(f"Variance: {result['variance'].item():.4f}")
        print(f"Entropy: {result['entropy'].item():.3f}")
"""

import torch
from typing import Dict


# Configuration
# Number of discrete relevance bins the classification head predicts over.
NUM_BINS = 11
# Center value of each bin: 0.0, 0.1, ..., 1.0 — evenly spaced over [0, 1].
# A relevance score is the expected value of the softmax distribution over
# these centers (see logits_to_relevance below).
BIN_CENTERS = torch.linspace(0.0, 1.0, NUM_BINS)


def logits_to_relevance(logits: torch.Tensor) -> torch.Tensor:
    """
    Convert model logits to relevance scores.

    Applies softmax over the last dimension and returns the expected value
    of the resulting distribution over evenly spaced bin centers in [0, 1].

    Args:
        logits: Model output logits [B, K]. K is NUM_BINS (11) for
            RexReranker, but any K >= 2 is supported.

    Returns:
        relevance: Relevance scores [B] in range [0, 1]
    """
    # Derive the bin centers from the logit width instead of the module-level
    # constant so the helper works for any bin count; for K == 11 this is
    # identical to BIN_CENTERS. Matching dtype/device avoids implicit casts.
    bin_centers = torch.linspace(
        0.0, 1.0, logits.size(-1), dtype=logits.dtype, device=logits.device
    )
    probs = torch.softmax(logits, dim=-1)
    return (probs * bin_centers).sum(dim=-1)


def logits_to_relevance_with_uncertainty(logits: torch.Tensor) -> Dict[str, torch.Tensor]:
    """
    Convert model logits to relevance scores with uncertainty estimates.

    Args:
        logits: Model output logits [B, K]. K is NUM_BINS (11) for
            RexReranker, but any K >= 2 is supported.

    Returns:
        dict with:
            - relevance: [B] predicted relevance scores in [0, 1]
            - variance: [B] prediction variance (higher = more uncertain)
            - entropy: [B] distribution entropy (higher = more uncertain)
            - probs: [B, K] full probability distribution over bins
    """
    # Bin centers derived from the logit width (== BIN_CENTERS when K == 11).
    bin_centers = torch.linspace(
        0.0, 1.0, logits.size(-1), dtype=logits.dtype, device=logits.device
    )

    # log_softmax is numerically stabler than log(softmax(x).clamp(min=eps)):
    # it never underflows to log(0) and needs no arbitrary epsilon floor.
    log_probs = torch.log_softmax(logits, dim=-1)
    probs = log_probs.exp()

    relevance = (probs * bin_centers).sum(dim=-1)
    variance = (probs * (bin_centers - relevance.unsqueeze(-1)) ** 2).sum(dim=-1)
    # xlogy defines 0 * log(0) == 0, so exact-zero probabilities contribute
    # nothing instead of producing NaN.
    entropy = -torch.special.xlogy(probs, probs).sum(dim=-1)

    return {
        "relevance": relevance,
        "variance": variance,
        "entropy": entropy,
        "probs": probs,
    }


def batch_rerank(
    model,
    tokenizer,
    query: str,
    documents: list,
    max_length: int = 2048,
    batch_size: int = 32,
    device: str = None,
) -> list:
    """
    Score every document against a query and return them sorted by relevance.

    Args:
        model: The RexReranker model
        tokenizer: The tokenizer
        query: The search query
        documents: List of dicts with 'title' and 'description' keys
        max_length: Maximum sequence length
        batch_size: Batch size for inference
        device: Device to use (default: auto-detect)

    Returns:
        List of dicts with original document info plus 'relevance',
        'variance', 'entropy', ordered from most to least relevant.
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    model = model.to(device)
    model.eval()

    # The query side of every pair is identical; format it once.
    query_text = f"Query: {query}"
    scored = []

    for start in range(0, len(documents), batch_size):
        chunk = documents[start:start + batch_size]

        doc_texts = [
            f"Title: {doc.get('title', '')}\nDescription: {doc.get('description', '')}"
            for doc in chunk
        ]

        encoded = tokenizer(
            [query_text] * len(chunk),
            doc_texts,
            padding=True,
            truncation=True,
            max_length=max_length,
            return_tensors="pt",
        ).to(device)

        with torch.no_grad():
            stats = logits_to_relevance_with_uncertainty(model(**encoded).logits)

        for idx, doc in enumerate(chunk):
            entry = dict(doc)
            entry["relevance"] = stats["relevance"][idx].item()
            entry["variance"] = stats["variance"][idx].item()
            entry["entropy"] = stats["entropy"][idx].item()
            scored.append(entry)

    # Most relevant documents first.
    return sorted(scored, key=lambda item: item["relevance"], reverse=True)