File size: 4,782 Bytes
5374a2d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
from typing import List, Optional, Dict

from sentence_transformers import SentenceTransformer
from llama_index.core.embeddings import BaseEmbedding

from evoagentx.core.logging import logger
from .base import BaseEmbeddingWrapper, EmbeddingProvider, SUPPORTED_MODELS


class HuggingFaceEmbedding(BaseEmbedding):
    """HuggingFace embedding model compatible with LlamaIndex BaseEmbedding.

    Wraps a ``SentenceTransformer`` instance and implements the query/text
    embedding hooks that LlamaIndex's ``BaseEmbedding`` calls.
    """

    # Loaded sentence-transformers model; assigned in __init__.
    model: SentenceTransformer = None
    # Embedding vector size reported by the loaded model; assigned in __init__.
    _dimension: int = None
    model_name: str = "sentence-transformers/all-MiniLM-L6-v2"
    # Batch size passed to SentenceTransformer.encode.
    embed_batch_size: int = 10
    # Device forwarded to the SentenceTransformer constructor (None = library default).
    device: Optional[str] = None
    # Forwarded as ``normalize_embeddings`` to SentenceTransformer.encode.
    normalize: bool = False
    # Extra keyword arguments forwarded to the SentenceTransformer constructor.
    model_kwargs: Dict = {}

    def __init__(
        self,
        model_name: str = "sentence-transformers/all-MiniLM-L6-v2",
        device: Optional[str] = None,
        normalize: bool = False,
        *,
        embed_batch_size: int = 10,
        **model_kwargs
    ):
        """Validate the configuration and load the SentenceTransformer model.

        Args:
            model_name: HuggingFace model identifier; must be accepted by
                ``EmbeddingProvider.validate_model`` for the HuggingFace provider.
            device: Device string forwarded to ``SentenceTransformer``.
            normalize: Passed as ``normalize_embeddings`` to every ``encode`` call.
            embed_batch_size: Batch size used when encoding lists of texts
                (keyword-only; defaults to the previous hard-coded value, 10).
            **model_kwargs: Forwarded verbatim to the ``SentenceTransformer`` constructor.

        Raises:
            ValueError: If the model name is unsupported or ``embed_batch_size``
                is not positive.
            Exception: Any error raised while constructing the model is logged
                and re-raised unchanged.
        """
        # Validate up front so a bad configuration fails before any base-class
        # or model setup happens on this instance.
        if not EmbeddingProvider.validate_model(EmbeddingProvider.HUGGINGFACE, model_name):
            raise ValueError(f"Unsupported HuggingFace model: {model_name}. Supported models: {SUPPORTED_MODELS['huggingface']}")
        # The class-level override of embed_batch_size replaces the base field,
        # so enforce positivity explicitly.
        if embed_batch_size <= 0:
            raise ValueError(f"embed_batch_size must be a positive integer, got {embed_batch_size}")

        super().__init__(model_name=model_name, embed_batch_size=embed_batch_size)
        self.device = device
        self.normalize = normalize
        self.model_kwargs = model_kwargs or {}

        try:
            self.model = SentenceTransformer(model_name, device=device, **model_kwargs)
            logger.debug(f"Initialized HuggingFace embedding model: {model_name}")
        except Exception as e:
            logger.error(f"Failed to initialize HuggingFace embedding: {str(e)}")
            raise

        self._dimension = self.model.get_sentence_embedding_dimension()

    def _encode_with_model(self, inputs, what: str):
        """Encode *inputs* with the shared settings and return plain Python lists.

        Args:
            inputs: A single string (yields one vector) or a list of strings
                (yields one vector per string).
            what: Label used in the error log message ("query", "text", "texts").

        Raises:
            Exception: Any encoding error is logged and re-raised unchanged.
        """
        try:
            return self.model.encode(
                inputs,
                normalize_embeddings=self.normalize,
                convert_to_numpy=True,
                batch_size=self.embed_batch_size,
            ).tolist()
        except Exception as e:
            logger.error(f"Failed to encode {what}: {str(e)}")
            raise

    def _get_query_embedding(self, query: str) -> List[float]:
        """Get embedding for a query string."""
        return self._encode_with_model(query, "query")

    def _get_text_embedding(self, text: str) -> List[float]:
        """Get embedding for a text string."""
        return self._encode_with_model(text, "text")

    def _get_text_embeddings(self, texts: List[str]) -> List[List[float]]:
        """Get embeddings for a list of texts synchronously.

        Returns an empty list for an empty input without invoking the model.
        """
        if not texts:
            return []
        return self._encode_with_model(texts, "texts")

    async def _aget_query_embedding(self, query: str) -> List[float]:
        """Asynchronous query embedding.

        Falls back to the synchronous encoder in the calling coroutine, so the
        event loop is blocked while the model runs.
        """
        return self._get_query_embedding(query)

    @property
    def dimension(self) -> int:
        """Return the embedding dimension reported by the loaded model."""
        return self._dimension


class HuggingFaceEmbeddingWrapper(BaseEmbeddingWrapper):
    """Builds and caches a ``HuggingFaceEmbedding`` for one model configuration.

    The embedding model is created eagerly during construction (via
    ``get_embedding_model``) and reused on every later call.
    """

    def __init__(
        self,
        model_name: str = "sentence-transformers/all-MiniLM-L6-v2",
        device: Optional[str] = None,
        normalize: bool = True,
        **model_kwargs
    ):
        # Configuration consumed by get_embedding_model().
        self.model_name, self.device, self.normalize = model_name, device, normalize
        self.model_kwargs = model_kwargs
        # Empty cache first, so the call below actually builds the model.
        self._embedding_model = None
        self._embedding_model = self.get_embedding_model()

    def get_embedding_model(self) -> BaseEmbedding:
        """Return the LlamaIndex-compatible embedding, building it on first use.

        Raises:
            Exception: Any error from constructing ``HuggingFaceEmbedding`` is
                logged and re-raised unchanged.
        """
        if self._embedding_model is not None:
            return self._embedding_model

        try:
            self._embedding_model = HuggingFaceEmbedding(
                model_name=self.model_name,
                device=self.device,
                normalize=self.normalize,
                **self.model_kwargs
            )
            logger.debug(f"Initialized HuggingFace embedding wrapper for model: {self.model_name}")
        except Exception as err:
            logger.error(f"Failed to initialize HuggingFace embedding wrapper: {str(err)}")
            raise
        return self._embedding_model

    @property
    def dimensions(self) -> int:
        """Size of the vectors produced by the cached embedding model."""
        cached = self._embedding_model
        return cached.dimension