File size: 15,266 Bytes
a795a71
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
import os

import torch
import numpy as np
#from sklearn.metrics.pairwise import cosine_similarity
from transformers import (
    AutoTokenizer,
    AutoModel,
)


class BGERetriever:
    """Dense retriever built on a BGE encoder (BGE-m3 family by default).

    Texts are encoded into L2-normalised sentence embeddings; query/passage
    relevance is their inner product (i.e. cosine similarity).
    """

    # Checkpoint directory under ``models/`` (next to this file) used when no
    # ``model_name`` is given. Earlier candidates were "bge-m3",
    # "test_encoder_only_m3_bge-m3_sd",
    # "test_encoder_only_base_bge-large-en-v1.5_sd" and
    # "test_encoder_only_base_bge_m3_new".
    DEFAULT_LOCAL_MODEL = "test_encoder_only_base_bge_m3_new1"

    def __init__(self, model_name=None, device=None, sentence_pooling_method="cls"):
        """Load the tokenizer and encoder and put the model in eval mode.

        Args:
            model_name: Hugging Face model id or local path. When None, the local
                checkpoint ``models/<DEFAULT_LOCAL_MODEL>`` next to this file is used.
            device: Device for the model; defaults to "cuda" when available,
                otherwise "cpu".
            sentence_pooling_method: "cls", "mean" or "last_token"
                (see ``_dense_embedding``).

        Raises:
            FileNotFoundError: ``model_name`` is None and the default local
                checkpoint directory does not exist.
        """
        if model_name is None:
            local_model_path = os.path.join(
                os.path.dirname(os.path.abspath(__file__)), 'models', self.DEFAULT_LOCAL_MODEL
            )
            if not os.path.isdir(local_model_path):
                # Fail early with a clear message instead of passing None to from_pretrained.
                raise FileNotFoundError(f"Default BGE model not found at: {local_model_path}")
            model_name = local_model_path
            print(f"Using local BGE model from: {model_name}")

        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")

        # Output-type switches kept for compatibility; only dense output is produced.
        self.return_dense: bool = True
        self.return_sparse: bool = False
        self.return_colbert_vecs: bool = False
        self.return_sparse_embedding: bool = False

        print(f"Loading BGE multilingual model on device: {self.device}")
        self.sentence_pooling_method = sentence_pooling_method
        # float16 halves GPU memory; on CPU half-precision kernels are slow or
        # unavailable, so full precision is used there.
        dtype = torch.float16 if str(self.device).startswith("cuda") else torch.float32
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(model_name, torch_dtype=dtype, device_map=self.device)
        self.vocab_size = self.model.config.vocab_size
        self.temperature = 1.0
        self.model.eval()

        # Filled in by the caller after embedding a corpus.
        self.corpus_ids = []
        self.corpus_embeddings = None

    def _dense_embedding(self, last_hidden_state, attention_mask):
        """Pool token states into one vector per sequence.

        Args:
            last_hidden_state (torch.Tensor): (batch, seq_len, hidden) token states.
            attention_mask (torch.Tensor): (batch, seq_len) mask, 1 for real tokens.

        Raises:
            NotImplementedError: Unknown ``sentence_pooling_method``.

        Returns:
            torch.Tensor: (batch, hidden) pooled embeddings.
        """
        if self.sentence_pooling_method == "cls":
            return last_hidden_state[:, 0]
        elif self.sentence_pooling_method == "mean":
            # Masked mean: padding positions contribute neither to sum nor count.
            s = torch.sum(
                last_hidden_state * attention_mask.unsqueeze(-1).float(), dim=1
            )
            d = attention_mask.sum(dim=1, keepdim=True).float()
            return s / d
        elif self.sentence_pooling_method == "last_token":
            # If every row's final position is a real token, the batch is left-padded.
            left_padding = attention_mask[:, -1].sum() == attention_mask.shape[0]
            if left_padding:
                return last_hidden_state[:, -1]
            else:
                sequence_lengths = attention_mask.sum(dim=1) - 1
                batch_size = last_hidden_state.shape[0]
                return last_hidden_state[
                    torch.arange(batch_size, device=last_hidden_state.device),
                    sequence_lengths,
                ]
        else:
            raise NotImplementedError(f"pooling method {self.sentence_pooling_method} not implemented")

    def _compute_similarity(self, q_reps, p_reps):
        """Inner-product similarity between query and passage representations.

        Args:
            q_reps (torch.Tensor): Query representations.
            p_reps (torch.Tensor): Passage representations (2-D, or batched 3-D).

        Returns:
            torch.Tensor: The similarity matrix.
        """
        # transpose(-2, -1) is the plain transpose for 2-D input and a batched one for 3-D.
        return torch.matmul(q_reps, p_reps.transpose(-2, -1))

    def compute_dense_score(self, q_reps, p_reps):
        """Compute dense relevance scores scaled by ``self.temperature``.

        Args:
            q_reps (torch.Tensor): (num_queries, dim) query representations.
            p_reps (torch.Tensor): (num_passages, dim) passage representations.

        Returns:
            torch.Tensor: (num_queries, -1) scores, i.e. similarity / temperature.
        """
        scores = self._compute_similarity(q_reps, p_reps) / self.temperature
        return scores.reshape(q_reps.size(0), -1)

    @torch.inference_mode()
    def embed_texts(
        self,
        texts,
        is_query=False,
        batch_size=64,
    ):
        """Encode texts into L2-normalised dense embeddings.

        Texts are whitespace-stripped and fed to the model as-is; no
        query/passage instruction prefix is added. ``is_query`` only controls
        logging: progress is printed every 50 batches for non-query input.

        Args:
            texts: Iterable of strings.
            is_query: When True, progress logging is suppressed.
            batch_size: Number of texts per forward pass.

        Returns:
            torch.Tensor: float32 CPU tensor with one row per input text
            (zero rows for empty input).
        """
        cleaned_texts = [text.strip() for text in texts]
        if not cleaned_texts:
            # torch.cat cannot concatenate an empty list of batches.
            return torch.empty((0, self.model.config.hidden_size), dtype=torch.float32)

        all_dense_embeddings = []
        total_batches = (len(cleaned_texts) + batch_size - 1) // batch_size

        for i in range(0, len(cleaned_texts), batch_size):
            batch_num = i // batch_size + 1
            if not is_query and batch_num % 50 == 0:
                print(f"Processing batch {batch_num}/{total_batches} ({(batch_num/total_batches)*100:.1f}%)")
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()

            batch_texts = cleaned_texts[i:i + batch_size]

            encoded = self.tokenizer(
                batch_texts,
                padding=True,
                truncation=True,
                max_length=512,
                return_tensors='pt',
            ).to(self.device)

            model_output = self.model(**encoded)

            dense_vecs = self._dense_embedding(model_output.last_hidden_state, encoded['attention_mask'])
            # Normalise in float32 so the stored embeddings (and the CPU matmul
            # in compute_dense_score) do not run in half precision.
            dense_vecs = torch.nn.functional.normalize(dense_vecs.float(), p=2, dim=1)
            all_dense_embeddings.append(dense_vecs.cpu())

        return torch.cat(all_dense_embeddings, dim=0)


class BGEReranker:
    """Cross-encoder reranker that scores (query, passage) pairs with a BGE
    sequence-classification model (bge-reranker-v2-m3 family by default)."""

    # Checkpoint directory under ``models/`` (next to this file) used when no
    # ``model_name`` is given. Earlier candidates were "bge-reranker-v2-m3",
    # "test_encoder_only_base_bge_reranker_v2_m3" and
    # "test_encoder_only_base_bge_reranker_v2_m3_new".
    DEFAULT_LOCAL_MODEL = "test_encoder_only_base_bge_reranker_v2_m3_new1"

    # Placeholder score used only when no batch at all could be scored.
    _NEUTRAL_SCORE = 0.5

    def __init__(self, model_name=None, device=None):
        """Load the reranker tokenizer and model and put the model in eval mode.

        Args:
            model_name: Hugging Face model id or local path. When None, the local
                checkpoint ``models/<DEFAULT_LOCAL_MODEL>`` next to this file is used.
            device: Device for the model; defaults to "cuda" when available,
                otherwise "cpu".

        Raises:
            FileNotFoundError: ``model_name`` is None and the default local
                checkpoint directory does not exist.
        """
        if model_name is None:
            local_model_path = os.path.join(
                os.path.dirname(os.path.abspath(__file__)), 'models', self.DEFAULT_LOCAL_MODEL
            )
            if not os.path.isdir(local_model_path):
                # Fail early with a clear message instead of passing None to from_pretrained.
                raise FileNotFoundError(f"Default BGE reranker not found at: {local_model_path}")
            model_name = local_model_path
            print(f"Using local BGE model from: {model_name}")

        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")

        print(f"Loading BGE reranker on device: {self.device}")
        from transformers import AutoModelForSequenceClassification
        # float16 halves GPU memory; on CPU half-precision kernels are slow or
        # unavailable, so full precision is used there.
        dtype = torch.float16 if str(self.device).startswith("cuda") else torch.float32
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        # NOTE(security): trust_remote_code=True runs Python code shipped with the
        # checkpoint; only load checkpoints from trusted sources.
        self.model = AutoModelForSequenceClassification.from_pretrained(
            model_name,
            torch_dtype=dtype,
            trust_remote_code=True,
            device_map=self.device,
        )
        self.model.eval()

    @staticmethod
    def _logits_to_scores(logits):
        """Reduce classifier logits to one float relevance score per pair.

        1-D logits and (batch, 1) logits are used directly; for wider outputs the
        column at index 1 (the positive class) is taken.
        """
        logits = logits.float()
        if logits.dim() == 1:
            scores = logits
        elif logits.shape[1] == 1:
            scores = logits.squeeze(-1)
        else:
            scores = logits[:, 1]
        return scores.cpu().tolist()

    @classmethod
    def _fill_failed_scores(cls, scores):
        """Replace ``None`` entries (from failed batches) with a usable score.

        Failed entries get a score strictly below every successfully computed
        one, so they can never outrank a genuinely scored passage. If nothing
        was scored, ``_NEUTRAL_SCORE`` is used for all entries.
        """
        valid = [score for score in scores if score is not None]
        fill = (min(valid) - 1.0) if valid else cls._NEUTRAL_SCORE
        return [fill if score is None else score for score in scores]

    @torch.inference_mode()
    def rerank(self, query_text, passages, passage_ids, top_k=20, batch_size=32):
        """Score every passage against ``query_text`` and return the best ones.

        Args:
            query_text: The query string.
            passages: Sequence of passage strings.
            passage_ids: Sequence of ids aligned one-to-one with ``passages``.
            top_k: Maximum number of results to return.
            batch_size: Number of (query, passage) pairs per forward pass.

        Returns:
            list[tuple]: ``(passage_id, score)`` pairs, highest score first.
            Passages whose batch raised are kept but ranked below every
            successfully scored passage.

        Raises:
            ValueError: ``passages`` and ``passage_ids`` differ in length.
        """
        if not passages:
            return []
        if len(passages) != len(passage_ids):
            raise ValueError(
                f"passages ({len(passages)}) and passage_ids ({len(passage_ids)}) must have the same length"
            )

        # Process passages in length order so each batch pads to similar lengths.
        order = sorted(range(len(passages)), key=lambda idx: len(passages[idx]))
        sorted_ids = [passage_ids[idx] for idx in order]
        sorted_passages = [passages[idx] for idx in order]

        scores = []
        for i in range(0, len(sorted_passages), batch_size):
            batch_passages = sorted_passages[i:i + batch_size]

            try:
                # The tokenizer receives queries and passages as separate
                # sequences (text pairs), not concatenated strings.
                inputs = self.tokenizer(
                    [query_text] * len(batch_passages),
                    batch_passages,
                    padding=True,
                    truncation=True,
                    max_length=512,
                    return_tensors='pt'
                ).to(self.device)
                batch_scores = self._logits_to_scores(self.model(**inputs).logits)
            except Exception as e:
                print(f"Error in reranking batch {i//batch_size + 1}: {e}")
                # Best-effort: mark the batch as failed; resolved after the loop.
                batch_scores = [None] * len(batch_passages)

            scores.extend(batch_scores)

        scores = self._fill_failed_scores(scores)
        results = sorted(zip(sorted_ids, scores), key=lambda pair: pair[1], reverse=True)
        return results[:top_k]


# Module-level pipeline state: preprocess() fills these in, and predict()
# reuses them (or recovers them from its preprocessed_data argument when unset).
retriever = reranker = None
corpus_texts = dict()  # document id -> original passage text, fed to the reranker


def preprocess(corpus_dict):
    """Load the BGE retriever and reranker and embed the whole corpus.

    Args:
        corpus_dict: Mapping of document id -> document dict. The passage text
            is read from the 'passage' key, falling back to 'text', then ''.

    Returns:
        dict with keys:
            'retriever', 'reranker': the loaded model wrappers;
            'corpus_ids': document ids ordered by passage length;
            'corpus_embeddings': one embedding row per id, in the same order;
            'corpus_texts': id -> passage text;
            'num_documents': len(corpus_dict).

    Side effects:
        Sets the module globals ``retriever``, ``reranker`` and ``corpus_texts``
        (reused by ``predict``) and the environment variable
        ``PYTORCH_CUDA_ALLOC_CONF``.
    """
    global retriever, reranker, corpus_texts
    print("=" * 60)
    print("PREPROCESSING: Initializing BGE Reranker Pipeline...")
    print("=" * 60)

    # GPU memory optimisation. NOTE(review): PyTorch reads this when the CUDA
    # allocator initialises, so it presumably only helps if CUDA is unused so far.
    os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'

    print("Loading BGE retriever...")
    retriever = BGERetriever()

    print("Loading BGE reranker...")
    reranker = BGEReranker()

    print(f"Preparing corpus with {len(corpus_dict)} documents...")

    # (doc_id, passage) pairs ordered by passage length so embedding batches hold
    # similarly sized texts (less padding); sorted() is stable for equal lengths.
    items = sorted(
        (
            (doc_id, doc.get('passage', doc.get('text', '')))
            for doc_id, doc in corpus_dict.items()
        ),
        key=lambda item: len(item[1]),
    )
    retriever.corpus_ids = [doc_id for doc_id, _ in items]
    passages = [passage for _, passage in items]

    # Original texts, looked up by id when building reranker inputs.
    corpus_texts = dict(items)

    print("Computing BGE embeddings...")
    if passages:
        retriever.corpus_embeddings = retriever.embed_texts(passages, is_query=False, batch_size=64)
    else:
        # Empty corpus: encode one dummy text and keep zero rows, so the matrix
        # still has the width and dtype that query embeddings will have.
        retriever.corpus_embeddings = retriever.embed_texts([''], is_query=False, batch_size=1)[:0]

    print("✓ Corpus preprocessing complete!")
    print(f"✓ Generated embeddings for {len(retriever.corpus_ids)} documents")

    print(f"✓ Embedding matrix shape: {retriever.corpus_embeddings.shape}")

    return {
        'retriever': retriever,
        'reranker': reranker,
        'corpus_ids': retriever.corpus_ids,
        'corpus_embeddings': retriever.corpus_embeddings,
        'corpus_texts': corpus_texts,
        'num_documents': len(corpus_dict)
    }


def _dense_scores(dense_retriever, query_text):
    """Return a 1-D float32 NumPy array of query-vs-corpus dense scores."""
    query_embedding = dense_retriever.embed_texts([query_text], is_query=True, batch_size=1)
    scores = dense_retriever.compute_dense_score(query_embedding, dense_retriever.corpus_embeddings)
    return scores.squeeze(0).float().numpy()


def predict(query, preprocessed_data, *, num_candidates=100, top_k=20):
    """Two-stage ranking: BGE dense retrieval followed by BGE reranking.

    Args:
        query: dict whose 'query' field holds the query text.
        preprocessed_data: dict returned by ``preprocess``; consulted only when
            the module-level ``retriever`` has not been set yet.
        num_candidates: How many dense-retrieval hits are passed to the reranker.
        top_k: Maximum number of results returned.

    Returns:
        list of {'paragraph_uuid': id, 'score': float}, best first. Scores are
        reranker scores, or dense scores if reranking raised. Returns [] for an
        empty query, an empty corpus, or missing models.

    Side effects:
        May set the module globals ``retriever``, ``reranker`` and
        ``corpus_texts`` from ``preprocessed_data``.
    """
    global retriever, reranker, corpus_texts

    query_text = query.get('query', '')
    if not query_text:
        return []

    # Prefer the module-level instances; otherwise recover them from preprocessed_data.
    if retriever is None:
        retriever = preprocessed_data.get('retriever')
        reranker = preprocessed_data.get('reranker')
        corpus_texts = preprocessed_data.get('corpus_texts', {})

        if retriever is None or reranker is None:
            print("Error: Missing retriever or reranker in preprocessed data")
            return []

    if len(retriever.corpus_ids) == 0:
        return []

    dense_scores = None
    try:
        # STAGE 1: dense retrieval of the top `num_candidates` documents.
        print("Stage 1: BGE retrieval...")
        dense_scores = _dense_scores(retriever, query_text)
        candidate_indices = np.argsort(dense_scores)[::-1][:num_candidates]

        candidate_ids = [retriever.corpus_ids[idx] for idx in candidate_indices]
        candidate_passages = [corpus_texts.get(doc_id, '') for doc_id in candidate_ids]

        # STAGE 2: cross-encoder reranking of the candidates down to `top_k`.
        print("Stage 2: BGE reranking...")
        reranked_results = reranker.rerank(
            query_text,
            candidate_passages,
            candidate_ids,
            top_k=top_k,
            batch_size=16,
        )

        results = [
            {'paragraph_uuid': passage_id, 'score': float(rerank_score)}
            for passage_id, rerank_score in reranked_results
        ]

        print(f"✓ Returned {len(results)} results with reranker scores")
        return results

    except Exception as e:
        print(f"Error in prediction: {e}")
        # Fallback: dense-only ranking. Reuse the stage-1 scores when they were
        # computed; only re-encode the query if stage 1 itself failed.
        if dense_scores is None:
            dense_scores = _dense_scores(retriever, query_text)

        top_indices = np.argsort(dense_scores)[::-1][:top_k]
        return [
            {'paragraph_uuid': retriever.corpus_ids[idx], 'score': float(dense_scores[idx])}
            for idx in top_indices
        ]