File size: 4,782 Bytes
5374a2d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 |
from typing import List, Optional, Dict
from sentence_transformers import SentenceTransformer
from llama_index.core.embeddings import BaseEmbedding
from evoagentx.core.logging import logger
from .base import BaseEmbeddingWrapper, EmbeddingProvider, SUPPORTED_MODELS
class HuggingFaceEmbedding(BaseEmbedding):
    """HuggingFace embedding model compatible with LlamaIndex ``BaseEmbedding``.

    Wraps a ``sentence_transformers.SentenceTransformer`` model and implements
    the sync/async embedding hooks LlamaIndex calls.
    """

    # Loaded SentenceTransformer instance; populated in __init__.
    model: Optional[SentenceTransformer] = None
    # Cached embedding dimension reported by the loaded model.
    _dimension: Optional[int] = None
    model_name: str = "sentence-transformers/all-MiniLM-L6-v2"
    embed_batch_size: int = 10
    device: Optional[str] = None
    normalize: bool = False
    model_kwargs: Dict = {}

    def __init__(
        self,
        model_name: str = "sentence-transformers/all-MiniLM-L6-v2",
        device: Optional[str] = None,
        normalize: bool = False,
        embed_batch_size: int = 10,
        **model_kwargs
    ):
        """Load and validate a HuggingFace sentence-transformer model.

        Args:
            model_name: HuggingFace model identifier; must pass
                ``EmbeddingProvider.validate_model`` for the HUGGINGFACE provider.
            device: Device string forwarded to ``SentenceTransformer`` (e.g.
                ``"cpu"``, ``"cuda"``); ``None`` lets the library choose.
            normalize: If True, embeddings are L2-normalized at encode time.
            embed_batch_size: Batch size used when encoding lists of texts.
                Defaults to 10, preserving the previously hard-coded value.
            **model_kwargs: Extra keyword arguments forwarded to the
                ``SentenceTransformer`` constructor.

        Raises:
            ValueError: If ``model_name`` is not a supported HuggingFace model.
            Exception: Re-raised (after logging) if model loading fails.
        """
        super().__init__(model_name=model_name, embed_batch_size=embed_batch_size)
        self.device = device
        self.normalize = normalize
        self.model_kwargs = model_kwargs or {}
        # Validate the name before paying the cost of loading the model.
        if not EmbeddingProvider.validate_model(EmbeddingProvider.HUGGINGFACE, model_name):
            raise ValueError(f"Unsupported HuggingFace model: {model_name}. Supported models: {SUPPORTED_MODELS['huggingface']}")
        try:
            self.model = SentenceTransformer(model_name, device=device, **model_kwargs)
            logger.debug(f"Initialized HuggingFace embedding model: {model_name}")
        except Exception as e:
            logger.error(f"Failed to initialize HuggingFace embedding: {str(e)}")
            raise
        self._dimension = self.model.get_sentence_embedding_dimension()

    def _encode(self, inputs, label: str, **encode_kwargs):
        """Encode ``inputs`` with the underlying model; log and re-raise on failure.

        ``label`` only customizes the error-log message ("query"/"text"/"texts")
        so the three public hooks keep their original log output.
        """
        try:
            return self.model.encode(
                inputs,
                normalize_embeddings=self.normalize,
                convert_to_numpy=True,
                **encode_kwargs
            ).tolist()
        except Exception as e:
            logger.error(f"Failed to encode {label}: {str(e)}")
            raise

    def _get_query_embedding(self, query: str) -> List[float]:
        """Get embedding for a query string."""
        return self._encode(query, "query")

    def _get_text_embedding(self, text: str) -> List[float]:
        """Get embedding for a text string."""
        return self._encode(text, "text")

    def _get_text_embeddings(self, texts: List[str]) -> List[List[float]]:
        """Get embeddings for a list of texts synchronously (batched)."""
        return self._encode(texts, "texts", batch_size=self.embed_batch_size)

    async def _aget_query_embedding(self, query: str) -> List[float]:
        """Asynchronous query embedding (falls back to sync)."""
        return self._get_query_embedding(query)

    @property
    def dimension(self) -> int:
        """Return the embedding dimension of the loaded model."""
        return self._dimension
class HuggingFaceEmbeddingWrapper(BaseEmbeddingWrapper):
    """Wrapper for HuggingFace embedding models."""

    def __init__(
        self,
        model_name: str = "sentence-transformers/all-MiniLM-L6-v2",
        device: Optional[str] = None,
        normalize: bool = True,
        **model_kwargs
    ):
        """Store configuration and eagerly build the underlying model.

        Args:
            model_name: HuggingFace model identifier to load.
            device: Device string forwarded to the embedding model; ``None``
                lets the library choose.
            normalize: Whether embeddings are L2-normalized.
                NOTE(review): defaults to True here but False on
                ``HuggingFaceEmbedding`` — confirm this asymmetry is intended.
            **model_kwargs: Extra kwargs forwarded to ``HuggingFaceEmbedding``.

        Raises:
            Exception: Re-raised (after logging) if model creation fails.
        """
        self.model_name = model_name
        self.device = device
        self.normalize = normalize
        self.model_kwargs = model_kwargs
        # Sentinel consulted by get_embedding_model() for one-time init;
        # resolved immediately so construction fails fast on bad config.
        self._embedding_model: Optional[BaseEmbedding] = None
        self._embedding_model = self.get_embedding_model()

    def get_embedding_model(self) -> BaseEmbedding:
        """Return the LlamaIndex-compatible embedding model, creating it on first call."""
        if self._embedding_model is None:
            try:
                self._embedding_model = HuggingFaceEmbedding(
                    model_name=self.model_name,
                    device=self.device,
                    normalize=self.normalize,
                    **self.model_kwargs
                )
                logger.debug(f"Initialized HuggingFace embedding wrapper for model: {self.model_name}")
            except Exception as e:
                logger.error(f"Failed to initialize HuggingFace embedding wrapper: {str(e)}")
                raise
        return self._embedding_model

    @property
    def dimensions(self) -> int:
        """Return the embedding dimensions."""
        # Go through the lazy accessor rather than the cached field directly,
        # so the property still works if the cache is ever reset to None.
        return self.get_embedding_model().dimension