File size: 6,334 Bytes
5374a2d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
import os
import warnings
from typing import List, Optional, Dict

from openai import OpenAI
from llama_index.core.embeddings import BaseEmbedding

from evoagentx.core.logging import logger
from .base import BaseEmbeddingWrapper, EmbeddingProvider, SUPPORTED_MODELS

# Native output width (number of floats per vector) of each known OpenAI
# embedding model, used when the caller does not request a specific size.
MODEL_DIMENSIONS = {
    "text-embedding-ada-002": 1536,
    "text-embedding-3-small": 1536,
    "text-embedding-3-large": 3072,
}

# Models for which a caller-chosen ``dimensions`` value is honoured.
SUPPORTED_DIMENSIONS = [
    "text-embedding-3-small",
    "text-embedding-3-large",
]

class OpenAIEmbedding(BaseEmbedding):
    """OpenAI embedding model compatible with LlamaIndex ``BaseEmbedding``.

    Every embedding request goes through ``self.client.embeddings.create``
    with ``self.model_name``, the resolved ``self.dimensions`` (only when
    set) and any extra keyword arguments captured at construction time.
    """

    api_key: str
    client: OpenAI = None
    base_url: str = "https://api.openai.com/v1"
    model_name: str = "text-embedding-3-small"
    embed_batch_size: int = 10
    dimensions: Optional[int] = None
    kwargs: Optional[Dict] = {}

    def __init__(
        self,
        model_name: str = "text-embedding-3-small",
        api_key: str = None,
        dimensions: int = None,
        base_url: str = None,
        **kwargs
    ):
        """Configure the model and create the underlying OpenAI client.

        Args:
            model_name: OpenAI embedding model id. Must be accepted by
                ``EmbeddingProvider.validate_model`` for the OpenAI provider.
            api_key: API key. Falls back to ``OPENAI_API_KEY``, then ``""``.
            dimensions: Requested vector size. Honoured only for models in
                ``SUPPORTED_DIMENSIONS``; for other models it is dropped with
                a warning. When omitted for a supported model, the model's
                default from ``MODEL_DIMENSIONS`` is used.
            base_url: API base URL. Falls back to ``OPENAI_API_BASE``
                (deprecated), then ``OPENAI_BASE_URL``, then the public
                OpenAI endpoint.
            **kwargs: Extra keyword arguments forwarded to every
                ``embeddings.create`` call.

        Raises:
            ValueError: If ``model_name`` is not a supported OpenAI model.
            Exception: Whatever the ``OpenAI`` client constructor raises
                (logged, then re-raised).
        """
        api_key = api_key or os.getenv("OPENAI_API_KEY") or ""
        super().__init__(api_key=api_key, model_name=model_name, embed_batch_size=10)
        base_url = (
            base_url
            or os.getenv("OPENAI_API_BASE")
            or os.getenv("OPENAI_BASE_URL")
            or "https://api.openai.com/v1"
        )
        if os.environ.get("OPENAI_API_BASE"):
            warnings.warn(
                "The environment variable 'OPENAI_API_BASE' is deprecated and will be removed in the 0.1.80. "
                "Please use 'OPENAI_BASE_URL' instead.",
                DeprecationWarning,
                stacklevel=2,
            )
        self.base_url = base_url
        self.dimensions = dimensions
        self.kwargs = kwargs

        if not EmbeddingProvider.validate_model(EmbeddingProvider.OPENAI, model_name):
            raise ValueError(f"Unsupported OpenAI model: {model_name}. Supported models: {SUPPORTED_MODELS['openai']}")
        # Only some models accept a custom output size.
        if dimensions is not None and model_name not in SUPPORTED_DIMENSIONS:
            logger.warning(
                f"Dimensions parameter is not supported for model {model_name}. "
                f"Only '{SUPPORTED_DIMENSIONS}' support custom dimensions. Ignoring dimensions parameter."
            )
            self.dimensions = None
        elif dimensions is None and model_name in SUPPORTED_DIMENSIONS:
            # No explicit request: pin the model's default size.
            self.dimensions = MODEL_DIMENSIONS.get(model_name)

        try:
            self.client = OpenAI(api_key=self.api_key, base_url=self.base_url)
            logger.debug(f"Initialized OpenAI embedding model: {model_name}")
        except Exception as e:
            logger.error(f"Failed to initialize OpenAI client: {str(e)}")
            raise

    def _create_embeddings(self, inputs: List[str]) -> List[List[float]]:
        """Embed ``inputs`` in a single API call and return one vector per item.

        Newlines are replaced by spaces before sending. ``dimensions`` is
        only included when it is set. The SDK types that parameter as
        ``int | NotGiven``, so passing ``None`` would send an explicit null.
        That null may be rejected by models without dimension support or by
        OpenAI-compatible servers.
        """
        cleaned = [item.replace("\n", " ") for item in inputs]
        request_kwargs = dict(self.kwargs or {})
        if self.dimensions is not None:
            request_kwargs["dimensions"] = self.dimensions
        response = self.client.embeddings.create(
            input=cleaned,
            model=self.model_name,
            **request_kwargs
        )
        return [item.embedding for item in response.data]

    def _get_query_embedding(self, query: str) -> List[float]:
        """Get the embedding for a single query string.

        Raises:
            Exception: Any client/API error (logged, then re-raised).
        """
        try:
            return self._create_embeddings([query])[0]
        except Exception as e:
            logger.error(f"Failed to encode query: {str(e)}")
            raise

    def _get_text_embedding(self, text: str) -> List[float]:
        """Get the embedding for a single text string.

        Raises:
            Exception: Any client/API error (logged, then re-raised).
        """
        try:
            return self._create_embeddings([text])[0]
        except Exception as e:
            logger.error(f"Failed to encode text: {str(e)}")
            raise

    async def _aget_query_embedding(self, query: str) -> List[float]:
        """Async query embedding.

        This delegates to the synchronous client call, so it blocks the
        running event loop for the duration of the request.
        """
        return self._get_query_embedding(query)

    def _get_text_embeddings(self, texts: List[str]) -> List[List[float]]:
        """Get embeddings for a list of texts in one synchronous request.

        An empty list returns ``[]`` without contacting the API.

        Raises:
            Exception: Any client/API error (logged, then re-raised).
        """
        if not texts:
            return []
        try:
            return self._create_embeddings(texts)
        except Exception as e:
            logger.error(f"Failed to encode texts: {str(e)}")
            raise


class OpenAIEmbeddingWrapper(BaseEmbeddingWrapper):
    """Wrapper exposing an ``OpenAIEmbedding`` through ``BaseEmbeddingWrapper``."""

    def __init__(
        self,
        model_name: str = "text-embedding-3-small",
        api_key: str = None,
        dimensions: int = None,
        base_url: str = None,
        **kwargs
    ):
        """Build the wrapped model eagerly and resolve its output size.

        Args:
            model_name: OpenAI embedding model id.
            api_key: API key. ``None`` lets the model fall back to the
                environment.
            dimensions: Requested vector size. Forwarded as-is; the model
                decides whether to honour it.
            base_url: Optional API base URL.
            **kwargs: Extra arguments forwarded to the embedding model.

        Raises:
            Exception: Whatever ``OpenAIEmbedding`` raises on construction
                (logged, then re-raised).
        """
        self.model_name = model_name
        self.api_key = api_key
        # Forward only what the caller asked for. Injecting the table default
        # here would override a caller's custom size, and for models without
        # dimension support it would trigger a spurious "not supported" warning.
        self._requested_dimensions = dimensions
        self._dimensions = dimensions
        self.base_url = base_url
        self.kwargs = kwargs
        self._embedding_model = None
        self._embedding_model = self.get_embedding_model()
        # Report the size the model will actually produce. The model may drop
        # an unsupported request or fill in a default. The caller's hint is
        # the last resort for models missing from MODEL_DIMENSIONS.
        self._dimensions = (
            getattr(self._embedding_model, "dimensions", None)
            or MODEL_DIMENSIONS.get(self.model_name)
            or dimensions
        )

    def get_embedding_model(self) -> BaseEmbedding:
        """Return the LlamaIndex-compatible embedding model, creating it on first use.

        Raises:
            Exception: Whatever ``OpenAIEmbedding`` raises on construction
                (logged, then re-raised).
        """
        if getattr(self, "_embedding_model", None) is None:
            try:
                self._embedding_model = OpenAIEmbedding(
                    model_name=self.model_name,
                    api_key=self.api_key,
                    dimensions=getattr(self, "_requested_dimensions", None),
                    base_url=self.base_url,
                    **self.kwargs
                )
                logger.debug(f"Initialized OpenAI embedding wrapper for model: {self.model_name}")
            except Exception as e:
                logger.error(f"Failed to initialize OpenAI embedding wrapper: {str(e)}")
                raise
        return self._embedding_model

    @property
    def dimensions(self) -> int:
        """Return the vector size produced by the wrapped model."""
        return self._dimensions