"""
Embedding Manager - Centralized Embedding Management
Fixes the critical embedding consistency issue where different components
were using different embedding methods (random vs real embeddings).
"""

import os
import logging
from typing import Optional, List

logger = logging.getLogger(__name__)

class EmbeddingManager:
    """Centralized embedding management to ensure consistency across all components"""
    
    _instance: Optional['EmbeddingManager'] = None
    
    def __new__(cls):
        # Singleton: build the backend once on first construction; every
        # later call returns the same fully initialized instance.
        if cls._instance is None:
            cls._instance = super().__new__(cls)
            cls._instance._initialize()
        return cls._instance
    
    def _initialize(self):
        """Initialize embedding model once"""
        self._client = None
        self._model = None
        
        if os.getenv("OPENAI_API_KEY"):
            try:
                from openai import OpenAI
                self._client = OpenAI()  # reads OPENAI_API_KEY from the environment
                self.method = "openai"
                self.dimension = 1536  # output size of text-embedding-3-small
                logger.info("Using OpenAI embeddings")
            except ImportError:
                logger.warning("openai package not installed; falling back to local embeddings")
                self._setup_local_embeddings()
        else:
            self._setup_local_embeddings()
    
    def _setup_local_embeddings(self):
        """Setup local sentence transformer embeddings"""
        try:
            from sentence_transformers import SentenceTransformer
            self._model = SentenceTransformer('all-MiniLM-L6-v2')
            self.method = "local"
            self.dimension = 384  # output size of all-MiniLM-L6-v2
            logger.info("Using local sentence transformer embeddings")
        except ImportError:
            logger.error("No embedding backend available; install openai or sentence-transformers")
            self.method = "none"
            self.dimension = 0
    
    def embed(self, text: str) -> List[float]:
        """Get embedding for text"""
        if not text or self.method == "none":
            # Fallback: zero vector; dimension is 0 when no backend is
            # configured, so default to 384 to keep callers working
            return [0.0] * max(self.dimension, 384)
        
        if self.method == "openai":
            try:
                response = self._client.embeddings.create(
                    model="text-embedding-3-small",
                    input=text
                )
                return response.data[0].embedding
            except Exception as e:
                logger.error(f"OpenAI embedding failed: {e}")
                return [0.0] * self.dimension
        else:
            try:
                embedding = self._model.encode(text)
                return embedding.tolist()
            except Exception as e:
                logger.error(f"Local embedding failed: {e}")
                return [0.0] * self.dimension
    
    def embed_batch(self, texts: List[str]) -> List[List[float]]:
        """Get embeddings for multiple texts (more efficient)"""
        if not texts:
            return []
        
        if self.method == "none":
            # Mirror embed(): zero-vector fallback when no backend is configured
            return [[0.0] * max(self.dimension, 384) for _ in texts]
        
        if self.method == "openai":
            try:
                response = self._client.embeddings.create(
                    model="text-embedding-3-small",
                    input=texts
                )
                return [data.embedding for data in response.data]
            except Exception as e:
                logger.error(f"OpenAI batch embedding failed: {e}")
                return [[0.0] * self.dimension for _ in texts]
        else:
            try:
                embeddings = self._model.encode(texts)
                return embeddings.tolist()
            except Exception as e:
                logger.error(f"Local batch embedding failed: {e}")
                return [[0.0] * self.dimension for _ in texts]
    
    def get_dimension(self) -> int:
        """Get embedding dimension"""
        return self.dimension
    
    def get_method(self) -> str:
        """Get embedding method being used"""
        return self.method

# Global embedding manager instance, created eagerly at import time so the
# backend is selected once for the whole process
embedding_manager = EmbeddingManager()

def get_embedding_manager() -> EmbeddingManager:
    """Get the global embedding manager instance"""
    return embedding_manager
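

# Minimal usage sketch (an illustration, not part of the module API; assumes
# at least one backend is installed: the openai package with OPENAI_API_KEY
# set, or sentence-transformers for local embeddings).
if __name__ == "__main__":
    manager = get_embedding_manager()
    print(f"method={manager.get_method()} dimension={manager.get_dimension()}")

    # Single text: a list[float] of length get_dimension()
    vec = manager.embed("Centralized embeddings keep components comparable.")
    print(f"single embedding: {len(vec)} floats")

    # Batch: one API call / one encode() pass instead of N separate ones
    vecs = manager.embed_batch(["first document", "second document"])
    print(f"batch embedding sizes: {[len(v) for v in vecs]}")

    # Vectors from the same manager share a backend and dimension, so they
    # can be compared directly (plain-Python cosine similarity; skipped if
    # the zero-vector fallback was returned)
    a, b = vecs
    if any(a) and any(b):
        dot = sum(x * y for x, y in zip(a, b))
        na = sum(x * x for x in a) ** 0.5
        nb = sum(x * x for x in b) ** 0.5
        print(f"cosine(first, second) = {dot / (na * nb):.3f}")

    # Singleton check: constructing EmbeddingManager again returns the same object
    assert EmbeddingManager() is manager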