File size: 3,439 Bytes
ed4f8d3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
"""
Custom handler for ANCE dense retrieval on HuggingFace Inference Endpoints.

Returns L2-normalized dense embeddings for input texts, or cosine-similarity
scores between a query and a list of passages.
"""

from typing import Any, Dict, List, Union

import torch
from transformers import AutoModel, AutoTokenizer


class EndpointHandler:
    """Handler for ANCE embedding generation.

    Loads a tokenizer and encoder once at construction time and then serves
    requests through ``__call__``: either L2-normalized CLS-token embeddings
    for input texts, or cosine-similarity scores of passages against a query.
    """

    def __init__(self, path: str = ""):
        """Initialize the model and tokenizer.

        Args:
            path: Location handed to ``from_pretrained`` for both the
                tokenizer and the model.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(path)
        self.model = AutoModel.from_pretrained(path)
        self.model.eval()

        # Move to GPU if available
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = self.model.to(self.device)

        print(f"ANCE loaded on {self.device}")

    def __call__(self, data: Dict[str, Any]) -> Union[List[List[float]], Dict[str, Any]]:
        """Process an inference request.

        Accepts:
        - {"inputs": "text"} - single text
        - {"inputs": ["text1", "text2", ...]} - batch of texts
        - {"query": "...", "passages": ["...", ...]} - query + passages

        When both ``query`` and ``passages`` are present they take precedence
        over ``inputs``.

        Args:
            data: Request payload as described above.

        Returns:
            Always a dict, one of:
            - {"embeddings": [[float, ...], ...]} - one embedding per input
            - {"scores": [float, ...]} - one cosine similarity per passage
            - {"error": "..."} - when neither mode's keys were supplied

            Empty ``inputs`` / ``passages`` lists yield empty results without
            running the model.
        """
        inputs = data.get("inputs")
        query = data.get("query")
        passages = data.get("passages")

        # Mode 1: Query + Passages -> return similarity scores
        if query is not None and passages is not None:
            if isinstance(passages, str):
                passages = [passages]
            # Nothing to score: skip encoding the query altogether.
            if not passages:
                return {"scores": []}

            query_emb = self._encode([query])[0]
            passage_embs = self._encode(passages)
            return {
                "scores": [
                    self._cosine_similarity(query_emb, p_emb)
                    for p_emb in passage_embs
                ]
            }

        # Mode 2: Just inputs -> return embeddings
        if inputs is None:
            return {"error": "No inputs provided. Use 'inputs' or 'query'+'passages'."}

        if isinstance(inputs, str):
            inputs = [inputs]

        return {"embeddings": self._encode(inputs)}

    def _encode(
        self,
        texts: List[str],
        max_length: int = 512,
        batch_size: int = 32,
    ) -> List[List[float]]:
        """Encode texts into L2-normalized CLS-token embeddings.

        Args:
            texts: Texts to encode. A bare string is treated as one text.
            max_length: Token limit; longer inputs are truncated.
            batch_size: Maximum number of texts per forward pass, so that a
                large request does not have to fit in memory at once.

        Returns:
            One embedding (list of floats) per text, in input order. An empty
            input returns an empty list without touching the tokenizer/model.

        Raises:
            ValueError: If ``batch_size`` is smaller than 1.
        """
        if isinstance(texts, str):
            # Slicing a string below would split it into character chunks.
            texts = [texts]
        if not texts:
            return []
        if batch_size < 1:
            raise ValueError("batch_size must be a positive integer")

        embeddings: List[List[float]] = []
        with torch.no_grad():
            for start in range(0, len(texts), batch_size):
                # Tokenize one chunk; padding is per chunk, masked by the
                # attention mask the tokenizer returns.
                encoded = self.tokenizer(
                    texts[start:start + batch_size],
                    padding=True,
                    truncation=True,
                    max_length=max_length,
                    return_tensors="pt",
                ).to(self.device)

                outputs = self.model(**encoded)
                # Use CLS token embedding
                cls_vectors = outputs.last_hidden_state[:, 0, :]
                # Normalize embeddings
                cls_vectors = torch.nn.functional.normalize(cls_vectors, p=2, dim=1)
                embeddings.extend(cls_vectors.cpu().tolist())

        return embeddings

    def _cosine_similarity(self, emb1: List[float], emb2: List[float]) -> float:
        """Compute cosine similarity between two embeddings.

        Args:
            emb1: First vector.
            emb2: Second vector; must have the same length as ``emb1``.

        Returns:
            The cosine of the angle between the vectors, or 0.0 when either
            vector has zero norm.

        Raises:
            ValueError: If the vectors differ in length (``zip`` would
                otherwise silently ignore the extra components).
        """
        import math

        if len(emb1) != len(emb2):
            raise ValueError(
                f"Embedding dimensions differ: {len(emb1)} != {len(emb2)}"
            )

        dot = sum(a * b for a, b in zip(emb1, emb2))
        norm1 = math.sqrt(sum(a * a for a in emb1))
        norm2 = math.sqrt(sum(b * b for b in emb2))
        if norm1 == 0 or norm2 == 0:
            return 0.0
        return dot / (norm1 * norm2)