# evoagentx/rag/embeddings/huggingface_embedding.py
from typing import List, Optional, Dict
from sentence_transformers import SentenceTransformer
from llama_index.core.embeddings import BaseEmbedding
from evoagentx.core.logging import logger
from .base import BaseEmbeddingWrapper, EmbeddingProvider, SUPPORTED_MODELS
class HuggingFaceEmbedding(BaseEmbedding):
    """HuggingFace embedding model compatible with LlamaIndex BaseEmbedding.

    Wraps a ``sentence_transformers.SentenceTransformer`` so it can be used
    anywhere LlamaIndex expects a ``BaseEmbedding``.
    """

    # Loaded SentenceTransformer instance (populated in __init__).
    model: SentenceTransformer = None
    # Cached embedding dimension, queried from the loaded model.
    _dimension: int = None
    model_name: str = "sentence-transformers/all-MiniLM-L6-v2"
    embed_batch_size: int = 10
    device: Optional[str] = None
    normalize: bool = False
    model_kwargs: Dict = {}

    def __init__(
        self,
        model_name: str = "sentence-transformers/all-MiniLM-L6-v2",
        device: Optional[str] = None,
        normalize: bool = False,
        embed_batch_size: int = 10,
        **model_kwargs
    ):
        """Initialize the HuggingFace embedding model.

        Args:
            model_name: HuggingFace model identifier; must be listed in
                ``SUPPORTED_MODELS['huggingface']``.
            device: Device string forwarded to SentenceTransformer
                (e.g. "cpu", "cuda"); None lets the library choose.
            normalize: Whether to L2-normalize returned embeddings.
            embed_batch_size: Batch size used when encoding lists of texts
                (defaults to 10, matching the previous hard-coded value).
            **model_kwargs: Extra keyword arguments forwarded to
                ``SentenceTransformer``.

        Raises:
            ValueError: If ``model_name`` is not a supported HuggingFace model.
        """
        # Validate first so an unsupported model fails before any
        # (potentially slow) model download/load or attribute setup.
        if not EmbeddingProvider.validate_model(EmbeddingProvider.HUGGINGFACE, model_name):
            raise ValueError(f"Unsupported HuggingFace model: {model_name}. Supported models: {SUPPORTED_MODELS['huggingface']}")
        super().__init__(model_name=model_name, embed_batch_size=embed_batch_size)
        self.device = device
        self.normalize = normalize
        self.model_kwargs = model_kwargs or {}
        try:
            self.model = SentenceTransformer(model_name, device=device, **model_kwargs)
            logger.debug(f"Initialized HuggingFace embedding model: {model_name}")
        except Exception as e:
            logger.error(f"Failed to initialize HuggingFace embedding: {str(e)}")
            raise
        self._dimension = self.model.get_sentence_embedding_dimension()

    def _encode(self, payload, label: str, **encode_kwargs):
        """Encode a string or list of strings, logging failures uniformly.

        Args:
            payload: Text (str) or texts (list of str) to embed.
            label: Noun used in the error log message ("query"/"text"/"texts"),
                kept identical to the original per-method messages.
            **encode_kwargs: Extra options forwarded to ``model.encode``
                (e.g. ``batch_size``).

        Returns:
            The embedding(s) as nested Python lists of floats.
        """
        try:
            return self.model.encode(
                payload,
                normalize_embeddings=self.normalize,
                convert_to_numpy=True,
                **encode_kwargs
            ).tolist()
        except Exception as e:
            logger.error(f"Failed to encode {label}: {str(e)}")
            raise

    def _get_query_embedding(self, query: str) -> List[float]:
        """Get embedding for a query string."""
        return self._encode(query, "query")

    def _get_text_embedding(self, text: str) -> List[float]:
        """Get embedding for a text string."""
        return self._encode(text, "text")

    def _get_text_embeddings(self, texts: List[str]) -> List[List[float]]:
        """Get embeddings for a list of texts synchronously, in batches."""
        return self._encode(texts, "texts", batch_size=self.embed_batch_size)

    async def _aget_query_embedding(self, query: str) -> List[float]:
        """Asynchronous query embedding (falls back to sync)."""
        return self._get_query_embedding(query)

    async def _aget_text_embedding(self, text: str) -> List[float]:
        """Asynchronous text embedding (falls back to sync)."""
        return self._get_text_embedding(text)

    @property
    def dimension(self) -> int:
        """Return the embedding dimension of the loaded model."""
        return self._dimension
class HuggingFaceEmbeddingWrapper(BaseEmbeddingWrapper):
    """Wrapper for HuggingFace embedding models.

    Eagerly constructs the underlying :class:`HuggingFaceEmbedding` at init
    time so configuration errors surface immediately, while still exposing a
    lazy ``get_embedding_model`` accessor for rebuilds.
    """

    def __init__(
        self,
        model_name: str = "sentence-transformers/all-MiniLM-L6-v2",
        device: Optional[str] = None,
        normalize: bool = True,
        **model_kwargs
    ):
        """Store configuration and build the embedding model.

        Args:
            model_name: HuggingFace model identifier.
            device: Device string forwarded to the underlying model; None
                lets the library choose.
            normalize: Whether embeddings are L2-normalized (defaults to
                True here, unlike HuggingFaceEmbedding's False).
            **model_kwargs: Extra keyword arguments forwarded to the
                underlying SentenceTransformer.
        """
        self.model_name = model_name
        self.device = device
        self.normalize = normalize
        self.model_kwargs = model_kwargs
        # Must be None before the accessor runs so it knows to build the model.
        self._embedding_model = None
        self._embedding_model = self.get_embedding_model()

    def get_embedding_model(self) -> BaseEmbedding:
        """Return the LlamaIndex-compatible embedding model, creating it on first use.

        Raises:
            Exception: Re-raised from HuggingFaceEmbedding construction
                (e.g. ValueError for unsupported model names).
        """
        if self._embedding_model is None:
            try:
                self._embedding_model = HuggingFaceEmbedding(
                    model_name=self.model_name,
                    device=self.device,
                    normalize=self.normalize,
                    **self.model_kwargs
                )
                logger.debug(f"Initialized HuggingFace embedding wrapper for model: {self.model_name}")
            except Exception as e:
                logger.error(f"Failed to initialize HuggingFace embedding wrapper: {str(e)}")
                raise
        return self._embedding_model

    @property
    def dimensions(self) -> int:
        """Return the embedding dimensions."""
        # Go through the accessor (not the raw attribute) so this cannot
        # raise AttributeError on None if the cache was never populated.
        return self.get_embedding_model().dimension