import sys
import subprocess
from typing import List, Dict, Optional

from llama_index.core.embeddings import BaseEmbedding

from evoagentx.core.logging import logger
from .base import BaseEmbeddingWrapper, EmbeddingProvider, SUPPORTED_MODELS

try:
    from ollama import Client
except ImportError:
    logger.warning("The 'ollama' library is not installed. Attempting to install it.")
    try:
        subprocess.check_call([sys.executable, "-m", "pip", "install", "ollama"])
        from ollama import Client
    except subprocess.CalledProcessError:
        logger.error("Failed to install 'ollama'. Please install it manually using 'pip install ollama'.")
        raise ImportError("The 'ollama' library is required.")

# Known default embedding dimensions for supported Ollama models.
# Note: nomic-embed-text produces 768-dimensional vectors by default.
MODEL_DIMENSIONS = {
    "nomic-embed-text": 768,
}


class OllamaEmbedding(BaseEmbedding):
    """Ollama embedding model compatible with LlamaIndex BaseEmbedding."""

    base_url: Optional[str] = None
    client: Optional[Client] = None
    model_name: str = "nomic-embed-text"
    embed_batch_size: int = 10
    embedding_dims: Optional[int] = None
    kwargs: Optional[Dict] = {}

    def __init__(
        self,
        model_name: str = "nomic-embed-text",
        base_url: Optional[str] = None,
        embedding_dims: Optional[int] = None,
        **kwargs
    ):
        super().__init__(model_name=model_name, embed_batch_size=10)
        self.base_url = base_url or "http://localhost:11434"
        # Prefer an explicit value, then the known model default, then 512 as a last resort.
        self.embedding_dims = embedding_dims or MODEL_DIMENSIONS.get(model_name, 512)
        self.kwargs = kwargs

        if not EmbeddingProvider.validate_model(EmbeddingProvider.OLLAMA, model_name):
            raise ValueError(f"Unsupported Ollama model: {model_name}. Supported models: {SUPPORTED_MODELS['ollama']}")

        try:
            self.client = Client(host=self.base_url)
            self._ensure_model_exists()
            logger.debug(f"Initialized Ollama embedding model: {model_name}")
        except Exception as e:
            logger.error(f"Failed to initialize Ollama client: {str(e)}")
            raise

    def _ensure_model_exists(self):
        """Ensure the specified model exists locally, pulling it if necessary."""
        try:
            local_models = self.client.list()["models"]
            # Installed models are listed with a tag suffix (e.g. "nomic-embed-text:latest"),
            # so compare on the base name; newer ollama clients key the name as "model"
            # rather than "name".
            if not any(
                str(model.get("name") or model.get("model") or "").split(":")[0] == self.model_name
                for model in local_models
            ):
                logger.info(f"Pulling Ollama model: {self.model_name}")
                self.client.pull(self.model_name)
        except Exception as e:
            logger.error(f"Failed to ensure Ollama model exists: {str(e)}")
            raise

    def _get_query_embedding(self, query: str) -> List[float]:
        """Get embedding for a query string."""
        try:
            response = self.client.embeddings(model=self.model_name, prompt=query, **self.kwargs)
            return response["embedding"]
        except Exception as e:
            logger.error(f"Failed to encode query: {str(e)}")
            raise

    def _get_text_embedding(self, text: str) -> List[float]:
        """Get embedding for a text string."""
        try:
            response = self.client.embeddings(model=self.model_name, prompt=text, **self.kwargs)
            return response["embedding"]
        except Exception as e:
            logger.error(f"Failed to encode text: {str(e)}")
            raise

    def _get_text_embeddings(self, texts: List[str]) -> List[List[float]]:
        """Get embeddings for a list of texts synchronously."""
        try:
            embeddings = []
            # Process texts in chunks of embed_batch_size; the Ollama embeddings
            # endpoint accepts one prompt at a time, so each text within a chunk
            # is embedded individually.
            for i in range(0, len(texts), self.embed_batch_size):
                batch = texts[i:i + self.embed_batch_size]
                batch_embeddings = [self._get_text_embedding(text) for text in batch]
                embeddings.extend(batch_embeddings)
            return embeddings
        except Exception as e:
            logger.error(f"Failed to encode texts: {str(e)}")
            raise

    async def _aget_query_embedding(self, query: str) -> List[float]:
        """Asynchronous query embedding (falls back to sync)."""
        return self._get_query_embedding(query)

    @property
    def dimension(self) -> int:
        """Return the embedding dimension."""
        return self.embedding_dims


class OllamaEmbeddingWrapper(BaseEmbeddingWrapper):
    """Wrapper for Ollama embedding models."""

    def __init__(
        self,
        model_name: str = "nomic-embed-text",
        base_url: Optional[str] = None,
        dimensions: Optional[int] = None,
        **kwargs
    ):
        self.model_name = model_name
        self.base_url = base_url
        # An explicitly passed dimension takes precedence over the known model default.
        self._dimensions = dimensions or MODEL_DIMENSIONS.get(model_name)
        self.kwargs = kwargs
        self._embedding_model = None
        self._embedding_model = self.get_embedding_model()

    def get_embedding_model(self) -> BaseEmbedding:
        """Return the LlamaIndex-compatible embedding model."""
        if self._embedding_model is None:
            try:
                self._embedding_model = OllamaEmbedding(
                    model_name=self.model_name,
                    base_url=self.base_url,
                    embedding_dims=self._dimensions,
                    **self.kwargs
                )
                logger.debug(f"Initialized Ollama embedding wrapper for model: {self.model_name}")
            except Exception as e:
                logger.error(f"Failed to initialize Ollama embedding wrapper: {str(e)}")
                raise
        return self._embedding_model

    @property
    def dimensions(self) -> int:
        """Return the embedding dimensions."""
        return self._dimensions
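

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the original module).
# It assumes an Ollama server is reachable at http://localhost:11434 and that
# "nomic-embed-text" passes the provider's model validation; the model is
# pulled automatically on first use if it is not present locally.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    wrapper = OllamaEmbeddingWrapper(model_name="nomic-embed-text")
    embed_model = wrapper.get_embedding_model()

    # BaseEmbedding exposes public helpers that dispatch to the private
    # _get_query_embedding / _get_text_embedding(s) hooks implemented above.
    query_vec = embed_model.get_query_embedding("What is retrieval-augmented generation?")
    doc_vecs = embed_model.get_text_embedding_batch(["first document", "second document"])

    print(f"dimensions reported by wrapper: {wrapper.dimensions}")
    print(f"query vector length: {len(query_vec)}")
    print(f"number of document vectors: {len(doc_vecs)}")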