|
|
from typing import List, Optional, Dict |
|
|
|
|
|
from sentence_transformers import SentenceTransformer |
|
|
from llama_index.core.embeddings import BaseEmbedding |
|
|
|
|
|
from evoagentx.core.logging import logger |
|
|
from .base import BaseEmbeddingWrapper, EmbeddingProvider, SUPPORTED_MODELS |
|
|
|
|
|
|
|
|
class HuggingFaceEmbedding(BaseEmbedding): |
|
|
"""HuggingFace embedding model compatible with LlamaIndex BaseEmbedding.""" |
|
|
|
|
|
model: SentenceTransformer = None |
|
|
_dimension: int = None |
|
|
model_name: str = "sentence-transformers/all-MiniLM-L6-v2" |
|
|
embed_batch_size: int = 10 |
|
|
device: Optional[str] = None |
|
|
normalize: bool = False |
|
|
model_kwargs: Dict = {} |
|
|
|
|
|
def __init__( |
|
|
self, |
|
|
model_name: str = "sentence-transformers/all-MiniLM-L6-v2", |
|
|
device: Optional[str] = None, |
|
|
normalize: bool = False, |
|
|
**model_kwargs |
|
|
): |
|
|
super().__init__(model_name=model_name, embed_batch_size=10) |
|
|
self.device = device |
|
|
self.normalize = normalize |
|
|
self.model_kwargs = model_kwargs or {} |
|
|
|
|
|
if not EmbeddingProvider.validate_model(EmbeddingProvider.HUGGINGFACE, model_name): |
|
|
raise ValueError(f"Unsupported HuggingFace model: {model_name}. Supported models: {SUPPORTED_MODELS['huggingface']}") |
|
|
try: |
|
|
self.model = SentenceTransformer(model_name, device=device, **model_kwargs) |
|
|
logger.debug(f"Initialized HuggingFace embedding model: {model_name}") |
|
|
except Exception as e: |
|
|
logger.error(f"Failed to initialize HuggingFace embedding: {str(e)}") |
|
|
raise |
|
|
|
|
|
self._dimension = self.model.get_sentence_embedding_dimension() |
|
|
|
|
|
def _get_query_embedding(self, query: str) -> List[float]: |
|
|
"""Get embedding for a query string.""" |
|
|
try: |
|
|
embedding = self.model.encode( |
|
|
query, |
|
|
normalize_embeddings=self.normalize, |
|
|
convert_to_numpy=True |
|
|
).tolist() |
|
|
return embedding |
|
|
except Exception as e: |
|
|
logger.error(f"Failed to encode query: {str(e)}") |
|
|
raise |
|
|
|
|
|
def _get_text_embedding(self, text: str) -> List[float]: |
|
|
"""Get embedding for a text string.""" |
|
|
try: |
|
|
embedding = self.model.encode( |
|
|
text, |
|
|
normalize_embeddings=self.normalize, |
|
|
convert_to_numpy=True |
|
|
).tolist() |
|
|
return embedding |
|
|
except Exception as e: |
|
|
logger.error(f"Failed to encode text: {str(e)}") |
|
|
raise |
|
|
|
|
|
def _get_text_embeddings(self, texts: List[str]) -> List[List[float]]: |
|
|
"""Get embeddings for a list of texts synchronously.""" |
|
|
try: |
|
|
embeddings = self.model.encode( |
|
|
texts, |
|
|
normalize_embeddings=self.normalize, |
|
|
convert_to_numpy=True, |
|
|
batch_size=self.embed_batch_size |
|
|
).tolist() |
|
|
return embeddings |
|
|
except Exception as e: |
|
|
logger.error(f"Failed to encode texts: {str(e)}") |
|
|
raise |
|
|
|
|
|
async def _aget_query_embedding(self, query: str) -> List[float]: |
|
|
"""Asynchronous query embedding (falls back to sync).""" |
|
|
return self._get_query_embedding(query) |
|
|
|
|
|
@property |
|
|
def dimension(self) -> int: |
|
|
"""Return the embedding dimension.""" |
|
|
return self._dimension |
|
|
|
|
|
|
|
|
class HuggingFaceEmbeddingWrapper(BaseEmbeddingWrapper): |
|
|
"""Wrapper for HuggingFace embedding models.""" |
|
|
|
|
|
def __init__( |
|
|
self, |
|
|
model_name: str = "sentence-transformers/all-MiniLM-L6-v2", |
|
|
device: Optional[str] = None, |
|
|
normalize: bool = True, |
|
|
**model_kwargs |
|
|
): |
|
|
self.model_name = model_name |
|
|
self.device = device |
|
|
self.normalize = normalize |
|
|
self.model_kwargs = model_kwargs |
|
|
self._embedding_model = None |
|
|
self._embedding_model = self.get_embedding_model() |
|
|
|
|
|
def get_embedding_model(self) -> BaseEmbedding: |
|
|
"""Return the LlamaIndex-compatible embedding model.""" |
|
|
if self._embedding_model is None: |
|
|
try: |
|
|
self._embedding_model = HuggingFaceEmbedding( |
|
|
model_name=self.model_name, |
|
|
device=self.device, |
|
|
normalize=self.normalize, |
|
|
**self.model_kwargs |
|
|
) |
|
|
logger.debug(f"Initialized HuggingFace embedding wrapper for model: {self.model_name}") |
|
|
except Exception as e: |
|
|
logger.error(f"Failed to initialize HuggingFace embedding wrapper: {str(e)}") |
|
|
raise |
|
|
return self._embedding_model |
|
|
|
|
|
@property |
|
|
def dimensions(self) -> int: |
|
|
"""Return the embedding dimensions.""" |
|
|
return self._embedding_model.dimension |