File size: 5,489 Bytes
a63c61f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
import logging
import os
from typing import Dict, Any, List

from src.core.config import settings
from src.core.ports.embedder_port import EmbedderPort

logger = logging.getLogger(__name__)

# On Windows, set HF_HUB_DISABLE_SYMLINKS before transformers/FlagEmbedding are
# imported below, so it is in the environment when those libraries load.
# NOTE(review): presumably meant to stop the Hugging Face cache from creating
# symlinks (often not permitted on Windows without elevated rights) - confirm
# that the installed huggingface_hub version actually reads this variable.
if os.name == 'nt':
    os.environ["HF_HUB_DISABLE_SYMLINKS"] = "1"

# Optional dependency: FlagEmbedding (BGE-M3). If it (or transformers) cannot be
# imported, HAS_FLAG_EMBEDDING is False and BgeEmbedderAdapter returns dummy
# embeddings instead of the module failing to import.
try:
    import transformers.utils.import_utils
    # Compatibility shim: when this transformers build has no
    # `is_torch_fx_available`, install a stub that returns False *before*
    # importing FlagEmbedding. NOTE(review): presumably FlagEmbedding (or a
    # module it imports) still references this helper, which newer
    # transformers releases dropped - confirm against the pinned versions.
    if not hasattr(transformers.utils.import_utils, 'is_torch_fx_available'):
        transformers.utils.import_utils.is_torch_fx_available = lambda: False
    
    from FlagEmbedding import BGEM3FlagModel
    HAS_FLAG_EMBEDDING = True
except ImportError as e:
    # Only ImportError (incl. ModuleNotFoundError) is handled; any other
    # exception raised while importing propagates and aborts module import.
    HAS_FLAG_EMBEDDING = False
    logger.warning(f"FlagEmbedding not installed: {e}. Using dummy embeddings.")

class BgeEmbedderAdapter(EmbedderPort):
    """EmbedderPort implementation backed by FlagEmbedding's BGE-M3 model.

    The model named by ``settings.EMBEDDING_MODEL`` is loaded lazily (with
    ``use_fp16=True``) on the first encode call. When FlagEmbedding could not
    be imported (``HAS_FLAG_EMBEDDING`` is False), every encode method returns
    placeholder output instead: a constant dense vector
    ``[0.1] * settings.VECTOR_SIZE`` and ``None`` for each sparse vector.

    A sparse vector is ``{"indices": [int, ...], "values": [float, ...]}``,
    built from one entry of the model's ``lexical_weights`` output.
    """

    # Passed as ``max_length`` to every ``model.encode`` call.
    # NOTE(review): presumably longer inputs are truncated to this many tokens
    # by FlagEmbedding - confirm against the installed version.
    MAX_LENGTH: int = 512

    # Class-level default so the "not installed" warning in _load_model is
    # emitted at most once per instance; the instance attribute set there
    # shadows this value afterwards.
    _missing_dependency_warned: bool = False

    def __init__(self) -> None:
        self.model = None
        self.model_name = settings.EMBEDDING_MODEL

    @staticmethod
    def _to_sparse(lexical_dict: Dict[Any, Any]) -> Dict[str, List[Any]]:
        """Convert one ``lexical_weights`` mapping into parallel index/value lists.

        Keys are converted with ``int()`` and weights with ``float()``. Both
        lists follow the mapping's iteration order, so position i of
        ``indices`` corresponds to position i of ``values``.
        """
        return {
            "indices": [int(k) for k in lexical_dict.keys()],
            "values": [float(v) for v in lexical_dict.values()],
        }

    def _load_model(self) -> None:
        """Load the BGE-M3 model if it is not loaded yet.

        When FlagEmbedding is unavailable, logs a warning (once per instance)
        and leaves ``self.model`` as None so callers fall back to dummy output.

        Raises:
            Exception: whatever ``BGEM3FlagModel`` raises while loading. It is
                logged with its traceback and re-raised unchanged.
                ``self.model`` stays None, so the next encode call retries.
        """
        if self.model is not None:
            return

        if not HAS_FLAG_EMBEDDING:
            # This path is reached on every encode call while the dependency
            # is missing; warn once instead of on every query.
            if not self._missing_dependency_warned:
                logger.warning("FlagEmbedding not installed. Using dummy embeddings.")
                self._missing_dependency_warned = True
            return

        logger.info(f"Loading embedding model: {self.model_name}")
        try:
            self.model = BGEM3FlagModel(self.model_name, use_fp16=True)
        except Exception as e:
            logger.error(f"Failed to load embedding model: {e}", exc_info=True)
            raise
        logger.info(f"Successfully loaded {self.model_name} (Hybrid Mode)")

    def _model_ready(self) -> bool:
        """Lazily load the model; return True if real (non-dummy) encoding is possible."""
        if self.model is None:
            self._load_model()
        return HAS_FLAG_EMBEDDING and self.model is not None

    def encode_query(self, text: str) -> Dict[str, Any]:
        """Encode a query string into dense and sparse vectors.

        Args:
            text: The query text.

        Returns:
            ``{"dense": [float, ...], "sparse": {"indices": ..., "values": ...}}``.
            In dummy mode, ``dense`` is ``[0.1] * settings.VECTOR_SIZE`` and
            ``sparse`` is None.

        Raises:
            Exception: propagated from model loading or ``model.encode``.
        """
        if not self._model_ready():
            return {
                "dense": [0.1] * settings.VECTOR_SIZE,
                "sparse": None
            }

        embeddings = self.model.encode(
            sentences=[text],
            batch_size=1,
            max_length=self.MAX_LENGTH,
            return_dense=True,
            return_sparse=True,
            return_colbert_vecs=False
        )

        return {
            "dense": embeddings['dense_vecs'][0].tolist(),
            "sparse": self._to_sparse(embeddings['lexical_weights'][0]),
        }

    def encode_sparse_only(self, text: str) -> Dict[str, Any]:
        """Encode only the sparse (lexical) vector for a single text.

        Intended for cases where the dense vector is already available
        (e.g. per-language sparse queries that reuse the English query's
        dense vector). Dense output is not requested (``return_dense=False``).
        NOTE(review): dense and sparse outputs presumably share the same
        encoder forward pass, so the saving from skipping dense may be small;
        measure before relying on it.

        Args:
            text: The text to encode.

        Returns:
            ``{"sparse": {"indices": ..., "values": ...}}``, or
            ``{"sparse": None}`` in dummy mode.

        Raises:
            Exception: propagated from model loading or ``model.encode``.
        """
        if not self._model_ready():
            return {"sparse": None}

        embeddings = self.model.encode(
            sentences=[text],
            batch_size=1,
            max_length=self.MAX_LENGTH,
            return_dense=False,  # caller already has a dense vector
            return_sparse=True,
            return_colbert_vecs=False
        )

        return {"sparse": self._to_sparse(embeddings['lexical_weights'][0])}

    def encode_sparse_batch(self, texts: List[str]) -> List[Dict[str, Any]]:
        """Encode several texts into sparse vectors with one ``model.encode`` call.

        All texts are sent as a single batch (``batch_size=len(texts)``), so
        tokenization and the forward pass are shared across texts instead of
        being repeated once per text as with ``encode_sparse_only``. The
        method is meant for short lists, such as a handful of per-language
        query variants, because a long list becomes one large batch.
        NOTE(review): the reported speed-up (~5x for six queries vs.
        sequential single-text calls) has not been re-measured.

        Args:
            texts: Texts to encode.

        Returns:
            One ``{"sparse": ...}`` dict per input text, in input order. An
            empty ``texts`` returns ``[]``. In dummy mode, or if
            ``model.encode`` raises, every entry is ``{"sparse": None}`` and
            the error is logged with its traceback.

        Raises:
            Exception: errors from *loading* the model are not caught here
                and propagate to the caller.
        """
        if not texts:
            return []

        if not self._model_ready():
            return [{"sparse": None} for _ in texts]

        try:
            embeddings = self.model.encode(
                sentences=texts,
                batch_size=len(texts),  # everything in a single batch
                max_length=self.MAX_LENGTH,
                return_dense=False,     # only sparse vectors are needed
                return_sparse=True,
                return_colbert_vecs=False
            )
            return [
                {"sparse": self._to_sparse(lexical_dict)}
                for lexical_dict in embeddings['lexical_weights']
            ]
        except Exception as e:
            # Deliberate best-effort: degrade to empty sparse vectors rather
            # than failing the whole request, but keep the traceback.
            logger.error(
                f"encode_sparse_batch failed: {e} - returning empty sparse vectors",
                exc_info=True,
            )
            return [{"sparse": None} for _ in texts]