# NOTE(review): removed scraped upload metadata ("iLOVE2D's picture /
# Upload 2846 files / 5374a2d verified") — it is not Python and broke import.
import sys
import subprocess
from typing import List, Dict, Optional
from llama_index.core.embeddings import BaseEmbedding
from evoagentx.core.logging import logger
from .base import BaseEmbeddingWrapper, EmbeddingProvider, SUPPORTED_MODELS
# Import the 'ollama' client, auto-installing it on first use if missing.
# NOTE(review): pip-installing at import time is a side-effectful anti-pattern
# (slows import, can fail in sandboxed/offline environments, mutates the
# running environment) — consider making 'ollama' a declared optional
# dependency instead.
try:
    from ollama import Client
except ImportError:
    logger.warning("The 'ollama' library is not installed. Attempting to install it.")
    try:
        # Install into the *current* interpreter's environment.
        subprocess.check_call([sys.executable, "-m", "pip", "install", "ollama"])
        from ollama import Client
    except subprocess.CalledProcessError:
        # pip itself failed; surface a clear, actionable error.
        logger.error("Failed to install 'ollama'. Please install it manually using 'pip install ollama'.")
        raise ImportError("The 'ollama' library is required.")
# Known output dimensionalities for supported Ollama embedding models.
# nomic-embed-text produces 768-dimensional vectors (the previous value,
# 384, was incorrect — see the model card).
# TODO(review): confirm and add the remaining supported models, e.g.
#   "mxbai-embed-large", "bge-m3", "all-minilm", "snowflake-arctic-embed".
MODEL_DIMENSIONS = {
    "nomic-embed-text": 768,
}
class OllamaEmbedding(BaseEmbedding):
    """Ollama embedding model compatible with LlamaIndex ``BaseEmbedding``.

    Wraps an ``ollama.Client`` and implements the sync/async embedding hooks
    LlamaIndex expects. The model is pulled on construction if it is not
    already available locally.
    """

    # Pydantic fields (BaseEmbedding is a pydantic model). Annotations are
    # Optional to match the ``None`` defaults accepted by __init__.
    base_url: Optional[str] = None
    client: Optional[Client] = None
    model_name: str = "nomic-embed-text"
    embed_batch_size: int = 10
    embedding_dims: Optional[int] = None
    kwargs: Optional[Dict] = {}

    def __init__(
        self,
        model_name: str = "nomic-embed-text",
        base_url: str = None,
        embedding_dims: int = None,
        **kwargs
    ):
        """Initialize the client and ensure the model is available locally.

        Args:
            model_name: Name of the Ollama embedding model.
            base_url: Ollama server URL; defaults to ``http://localhost:11434``.
            embedding_dims: Reported embedding dimensionality; defaults to 512
                when unspecified. NOTE(review): this fallback looks wrong for
                nomic-embed-text (768-dim) — confirm against the deployed model.
            **kwargs: Extra options forwarded to ``client.embeddings``.

        Raises:
            ValueError: If ``model_name`` is not a supported Ollama model.
        """
        super().__init__(model_name=model_name, embed_batch_size=10)
        self.base_url = base_url or "http://localhost:11434"
        self.embedding_dims = embedding_dims or 512
        self.kwargs = kwargs
        # Validate before touching the network so unsupported models fail fast.
        if not EmbeddingProvider.validate_model(EmbeddingProvider.OLLAMA, model_name):
            raise ValueError(f"Unsupported Ollama model: {model_name}. Supported models: {SUPPORTED_MODELS['ollama']}")
        try:
            self.client = Client(host=self.base_url)
            self._ensure_model_exists()
            logger.debug(f"Initialized Ollama embedding model: {model_name}")
        except Exception as e:
            logger.error(f"Failed to initialize Ollama client: {str(e)}")
            raise

    def _ensure_model_exists(self):
        """Ensure the model exists locally, pulling it if necessary.

        Local model names typically carry a tag (e.g. ``nomic-embed-text:latest``)
        and newer ollama clients report the name under the ``model`` key, so
        compare on the untagged base name. The previous exact match against
        ``name`` never matched a tagged entry and re-pulled on every start.
        """
        try:
            local_models = self.client.list()["models"]
            wanted = self.model_name.split(":", 1)[0]

            def _base_name(entry) -> str:
                # Older clients use "name"; newer ones use "model".
                raw = entry.get("name") or entry.get("model") or ""
                return raw.split(":", 1)[0]

            if not any(_base_name(model) == wanted for model in local_models):
                logger.info(f"Pulling Ollama model: {self.model_name}")
                self.client.pull(self.model_name)
        except Exception as e:
            logger.error(f"Failed to ensure Ollama model exists: {str(e)}")
            raise

    def _get_query_embedding(self, query: str) -> List[float]:
        """Get embedding for a query string.

        Ollama does not distinguish query vs. document embeddings, so this
        delegates to the text path (the two bodies were duplicates).
        """
        try:
            return self._get_text_embedding(query)
        except Exception as e:
            logger.error(f"Failed to encode query: {str(e)}")
            raise

    def _get_text_embedding(self, text: str) -> List[float]:
        """Get embedding for a single text string."""
        try:
            response = self.client.embeddings(model=self.model_name, prompt=text, **self.kwargs)
            return response["embedding"]
        except Exception as e:
            logger.error(f"Failed to encode text: {str(e)}")
            raise

    def _get_text_embeddings(self, texts: List[str]) -> List[List[float]]:
        """Get embeddings for a list of texts synchronously, in batches.

        Batching here only chunks the loop; each text is still a separate
        request to the Ollama server.
        """
        try:
            embeddings: List[List[float]] = []
            for i in range(0, len(texts), self.embed_batch_size):
                batch = texts[i:i + self.embed_batch_size]
                embeddings.extend(self._get_text_embedding(text) for text in batch)
            return embeddings
        except Exception as e:
            logger.error(f"Failed to encode texts: {str(e)}")
            raise

    async def _aget_query_embedding(self, query: str) -> List[float]:
        """Asynchronous query embedding (falls back to the sync client)."""
        return self._get_query_embedding(query)

    @property
    def dimension(self) -> int:
        """Return the configured embedding dimension."""
        return self.embedding_dims
class OllamaEmbeddingWrapper(BaseEmbeddingWrapper):
    """Wrapper exposing an Ollama model as a LlamaIndex embedding backend."""

    def __init__(
        self,
        model_name: str = "nomic-embed-text",
        base_url: str = None,
        dimensions: int = None,
        **kwargs
    ):
        """Build the wrapper and eagerly construct the embedding model.

        Args:
            model_name: Name of the Ollama embedding model.
            base_url: Ollama server URL; forwarded to ``OllamaEmbedding``.
            dimensions: Explicit embedding dimensionality; overrides the
                built-in ``MODEL_DIMENSIONS`` table when provided.
            **kwargs: Extra options forwarded to ``OllamaEmbedding``.
        """
        self.model_name = model_name
        self.base_url = base_url
        # A caller-supplied value takes precedence over the lookup table.
        # The previous expression (table ``or`` argument) silently discarded
        # an explicit ``dimensions`` whenever the model was in the table.
        self._dimensions = dimensions or MODEL_DIMENSIONS.get(model_name)
        self.kwargs = kwargs
        self._embedding_model = None
        self._embedding_model = self.get_embedding_model()

    def get_embedding_model(self) -> BaseEmbedding:
        """Return the LlamaIndex-compatible embedding model, building it once."""
        if self._embedding_model is None:
            try:
                self._embedding_model = OllamaEmbedding(
                    model_name=self.model_name,
                    base_url=self.base_url,
                    embedding_dims=self._dimensions,
                    **self.kwargs
                )
                logger.debug(f"Initialized Ollama embedding wrapper for model: {self.model_name}")
            except Exception as e:
                logger.error(f"Failed to initialize Ollama embedding wrapper: {str(e)}")
                raise
        return self._embedding_model

    @property
    def dimensions(self) -> int:
        """Return the embedding dimensions.

        Falls back to the underlying model's dimension when no explicit or
        table value is known, so this never disagrees with the model's own
        report (previously it could return ``None`` while the model used its
        default).
        """
        if self._dimensions is None and self._embedding_model is not None:
            return self._embedding_model.dimension
        return self._dimensions