File size: 4,097 Bytes
1904012
 
 
6a14fa9
1904012
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6a14fa9
 
1904012
6a14fa9
1904012
 
 
 
 
 
 
 
6a14fa9
 
1904012
6a14fa9
1904012
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6a14fa9
1904012
 
 
 
 
 
 
6a14fa9
 
1904012
 
 
 
 
6a14fa9
1904012
 
 
 
 
 
6a14fa9
1904012
 
 
 
6a14fa9
1904012
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
import torch
from sentence_transformers import SentenceTransformer
import logging
import time
from typing import List, Union
from config.settings import get_settings

logger = logging.getLogger(__name__)

class EmbeddingService:
    """Process-wide singleton wrapping a SentenceTransformer embedding model.

    The model is loaded lazily on the first instantiation and then shared:
    every ``EmbeddingService()`` call returns the same instance, and the
    underlying model is stored once on the class (``_model``).
    """

    # Expected embedding dimensionality; a mismatch is logged as a warning,
    # not treated as fatal (downstream consumers may still cope).
    _EXPECTED_DIM = 1024
    # Max sequence length applied to the model right after loading, and the
    # default per-call limit in encode().
    _DEFAULT_MAX_SEQ_LENGTH = 2048

    _instance = None  # the single shared EmbeddingService instance
    _model = None     # the shared SentenceTransformer, loaded once

    def __new__(cls):
        # Classic singleton: every construction yields the same object.
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self):
        # __init__ runs on every EmbeddingService() call (Python always calls
        # it after __new__), so guard on the class-level model to load once.
        if EmbeddingService._model is None:
            self._load_model()

    def _load_model(self) -> None:
        """Load the configured SentenceTransformer onto GPU when available.

        Reads the model name from settings, prefers CUDA when present, and
        caps the model's sequence length at ``_DEFAULT_MAX_SEQ_LENGTH``.

        Raises:
            Exception: whatever SentenceTransformer raises on failure,
                re-raised after logging with full traceback.
        """
        settings = get_settings()
        try:
            start_time = time.perf_counter()
            logger.info(f"[EMBEDDING] Starting to load embedding model: {settings.EMBEDDING_MODEL_NAME}")
            device = "cuda" if torch.cuda.is_available() else "cpu"
            logger.info(f"[EMBEDDING] Using device: {device}")

            EmbeddingService._model = SentenceTransformer(
                settings.EMBEDDING_MODEL_NAME,
                device=device
            )

            EmbeddingService._model.max_seq_length = self._DEFAULT_MAX_SEQ_LENGTH

            load_time = time.perf_counter() - start_time
            logger.info(f"[EMBEDDING] Embedding model loaded successfully in {load_time:.3f}s")
        except Exception as e:
            logger.error(f"[EMBEDDING] Failed to load embedding model: {str(e)}", exc_info=True)
            raise

    def _as_checked_list(self, embedding) -> List[float]:
        """Convert one embedding vector to a plain list, warning on a dimension mismatch."""
        values = embedding.tolist()
        if len(values) != self._EXPECTED_DIM:
            logger.warning(f"[EMBEDDING] Embedding dimension mismatch: expected {self._EXPECTED_DIM}, got {len(values)}")
        return values

    def encode(
        self,
        texts: Union[str, List[str]],
        is_query: bool = False,
        batch_size: int = 32,
        max_length: int = 2048
    ) -> Union[List[float], List[List[float]]]:
        """Embed one text or a batch of texts.

        Args:
            texts: A single string or a list of strings.
            is_query: Logged for observability only; does not change the
                encoding (no query/passage prompt is applied here).
            batch_size: Batch size passed to the model's encode().
            max_length: Per-call token limit applied to the model for this
                call only (the model's previous setting is restored after).
            
        Returns:
            A flat ``List[float]`` when ``texts`` was a single string,
            otherwise a ``List[List[float]]`` in input order.

        Raises:
            RuntimeError: if the model is not loaded.
            ValueError: if ``texts`` is an empty list.
        """
        if EmbeddingService._model is None:
            raise RuntimeError("Embedding model not loaded")

        single_text = isinstance(texts, str)
        if single_text:
            texts = [texts]

        if not texts:
            raise ValueError("Texts cannot be empty")

        model = EmbeddingService._model
        # FIX: max_length used to be accepted but silently ignored. Apply it
        # for the duration of this call, then restore the model's setting.
        previous_max = model.max_seq_length
        try:
            model.max_seq_length = max_length
            encode_start = time.perf_counter()
            embeddings = model.encode(
                texts,
                batch_size=batch_size,
                show_progress_bar=False,
                convert_to_numpy=True,
                normalize_embeddings=False
            )
            encode_time = time.perf_counter() - encode_start
            logger.info(f"[EMBEDDING] Encoded {len(texts)} text(s) in {encode_time:.3f}s (is_query={is_query})")

            if single_text:
                return self._as_checked_list(embeddings[0])
            return [self._as_checked_list(emb) for emb in embeddings]
        except Exception as e:
            logger.error(f"[EMBEDDING] Error encoding texts: {str(e)}", exc_info=True)
            raise
        finally:
            model.max_seq_length = previous_max

    def get_model_info(self) -> dict:
        """Return model metadata: name, embedding dimension, device, max seq length.

        NOTE(review): "device" reports current CUDA availability, not the
        device the model was actually loaded on — confirm this is intended.
        """
        settings = get_settings()
        dimension = self._EXPECTED_DIM  # fallback when the model is unavailable

        model = EmbeddingService._model
        if model is not None:
            try:
                # Cheaper than encoding a probe text: the model exposes its
                # output dimensionality directly (may return None for some
                # architectures, hence the fallback).
                dimension = model.get_sentence_embedding_dimension() or dimension
            except Exception as e:
                logger.warning(f"Could not determine model dimension: {str(e)}")

        return {
            "model_name": settings.EMBEDDING_MODEL_NAME,
            "dimension": dimension,
            "device": "cuda" if torch.cuda.is_available() else "cpu",
            "max_seq_length": model.max_seq_length if model else self._DEFAULT_MAX_SEQ_LENGTH
        }