File size: 4,130 Bytes
057446f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
"""
Custom handler for BGE dense retrieval on HuggingFace Inference Endpoints.

Returns dense embeddings for queries and passages.

Key difference from ANCE: BGE requires an instruction prefix on queries
for retrieval tasks. Passages are encoded without any prefix.
"""

from typing import Any, Dict, List, Union

import torch
from transformers import AutoModel, AutoTokenizer


# BGE query instruction prefix
QUERY_INSTRUCTION = "Represent this sentence for searching relevant passages: "


class EndpointHandler:
    """Handler for BGE embedding generation.

    BGE retrieval models require an instruction prefix on *queries*
    (``QUERY_INSTRUCTION``); passages are encoded without any prefix.
    Embeddings are CLS-pooled and L2-normalized.
    """

    def __init__(self, path: str = ""):
        """Load the model and tokenizer from *path*.

        Args:
            path: Local path or hub ID of the BGE checkpoint.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(path)
        self.model = AutoModel.from_pretrained(path)
        self.model.eval()

        # Inference-only handler: park the model on GPU when one exists.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = self.model.to(self.device)

        print(f"BGE loaded on {self.device}")

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Process inference requests.

        Accepts:
        - {"inputs": "text"} - single text (encoded as passage, no prefix)
        - {"inputs": ["text1", "text2", ...]} - batch of texts (encoded as passages)
        - {"inputs": "text", "is_query": true} - single text with query prefix
        - {"inputs": [...], "is_query": true} - batch with query prefix
        - {"query": "...", "passages": ["...", ...]} - query + passages (returns similarity scores)

        Returns:
        - {"embeddings": [...]} - list of embeddings
        - {"scores": [...]} - if query + passages provided
        - {"error": "..."} - if the request is malformed

        NOTE: every code path returns a dict, so the return annotation is
        ``Dict[str, Any]`` (the previous ``Union[List[...], Dict[...]]`` was
        never accurate).
        """
        inputs = data.get("inputs", None)
        is_query = data.get("is_query", False)
        query = data.get("query", None)
        passages = data.get("passages", None)

        # Mode 1: query + passages -> cosine similarity scores.
        if query is not None and passages is not None:
            # Be lenient: a single passage may arrive as a bare string.
            if isinstance(passages, str):
                passages = [passages]

            # Only the query gets the BGE instruction prefix; passages never do.
            query_emb = self._encode([f"{QUERY_INSTRUCTION}{query}"])[0]
            passage_embs = self._encode(passages)

            scores = [self._cosine_similarity(query_emb, p_emb) for p_emb in passage_embs]
            return {"scores": scores}

        # Mode 2: plain inputs -> embeddings.
        if inputs is None:
            return {"error": "No inputs provided. Use 'inputs' or 'query'+'passages'."}

        if isinstance(inputs, str):
            inputs = [inputs]

        # Tokenizing an empty batch raises inside the tokenizer; short-circuit.
        if not inputs:
            return {"embeddings": []}

        # Apply query instruction prefix if flagged as queries.
        if is_query:
            inputs = [f"{QUERY_INSTRUCTION}{text}" for text in inputs]

        return {"embeddings": self._encode(inputs)}

    def _encode(self, texts: List[str], max_length: int = 512) -> List[List[float]]:
        """Encode *texts* into L2-normalized CLS embeddings.

        Args:
            texts: Non-empty batch of strings (prefixes already applied).
            max_length: Truncation length in tokens.

        Returns:
            One embedding (list of floats) per input text.
        """
        encoded = self.tokenizer(
            texts,
            padding=True,
            truncation=True,
            max_length=max_length,
            return_tensors="pt"
        ).to(self.device)

        with torch.no_grad():
            outputs = self.model(**encoded)
            # BGE convention: CLS-token pooling, then L2 normalization so
            # dot product == cosine similarity.
            embeddings = outputs.last_hidden_state[:, 0, :]
            embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)

        return embeddings.cpu().tolist()

    def _cosine_similarity(self, emb1: List[float], emb2: List[float]) -> float:
        """Compute cosine similarity between two embeddings.

        Normalizes defensively even though `_encode` output is already
        unit-length, so the helper stays correct for arbitrary vectors.
        Returns 0.0 for a zero vector rather than dividing by zero.
        """
        dot = sum(a * b for a, b in zip(emb1, emb2))
        norm1 = sum(a * a for a in emb1) ** 0.5
        norm2 = sum(b * b for b in emb2) ** 0.5
        if norm1 == 0 or norm2 == 0:
            return 0.0
        return dot / (norm1 * norm2)