import sys
import subprocess
from typing import List, Dict, Optional
from llama_index.core.embeddings import BaseEmbedding
from evoagentx.core.logging import logger
from .base import BaseEmbeddingWrapper, EmbeddingProvider, SUPPORTED_MODELS
try:
    from ollama import Client
except ImportError:
    logger.warning("The 'ollama' library is not installed. Attempting to install it.")
    try:
        # Best-effort runtime install; the import is then retried in the current process.
        subprocess.check_call([sys.executable, "-m", "pip", "install", "ollama"])
        from ollama import Client
    except (subprocess.CalledProcessError, ImportError):
        logger.error("Failed to install 'ollama'. Please install it manually using 'pip install ollama'.")
        raise ImportError("The 'ollama' library is required. Install it with 'pip install ollama'.")
# Known output dimensionalities of supported Ollama embedding models.
MODEL_DIMENSIONS = {
    "nomic-embed-text": 768,
    # "mxbai-embed-large": ,
    # "bge-m3": ,
    # "all-minilm": ,
    # "snowflake-arctic-embed": ,
}
class OllamaEmbedding(BaseEmbedding):
"""Ollama embedding model compatible with LlamaIndex BaseEmbedding."""
    base_url: Optional[str] = None
    client: Optional[Client] = None
    model_name: str = "nomic-embed-text"
    embed_batch_size: int = 10
    embedding_dims: Optional[int] = None
    kwargs: Optional[Dict] = {}
def __init__(
self,
model_name: str = "nomic-embed-text",
        base_url: Optional[str] = None,
        embedding_dims: Optional[int] = None,
**kwargs
):
super().__init__(model_name=model_name, embed_batch_size=10)
self.base_url = base_url or "http://localhost:11434"
        # Fall back to the model's known dimensionality, then to 768 (the output
        # size of the default nomic-embed-text model).
        self.embedding_dims = embedding_dims or MODEL_DIMENSIONS.get(model_name, 768)
self.kwargs = kwargs
if not EmbeddingProvider.validate_model(EmbeddingProvider.OLLAMA, model_name):
raise ValueError(f"Unsupported Ollama model: {model_name}. Supported models: {SUPPORTED_MODELS['ollama']}")
try:
self.client = Client(host=self.base_url)
self._ensure_model_exists()
logger.debug(f"Initialized Ollama embedding model: {model_name}")
except Exception as e:
logger.error(f"Failed to initialize Ollama client: {str(e)}")
raise
    def _ensure_model_exists(self):
        """Ensure the specified model exists locally, pulling it if necessary."""
        try:
            local_models = self.client.list()["models"]
            # Names returned by Ollama may carry a tag suffix (e.g. "nomic-embed-text:latest"),
            # and depending on the client version the key is "model" or "name", so compare base names.
            installed = {
                str(m.get("model") or m.get("name") or "").split(":")[0]
                for m in local_models
            }
            if self.model_name.split(":")[0] not in installed:
                logger.info(f"Pulling Ollama model: {self.model_name}")
                self.client.pull(self.model_name)
        except Exception as e:
            logger.error(f"Failed to ensure Ollama model exists: {str(e)}")
            raise
def _get_query_embedding(self, query: str) -> List[float]:
"""Get embedding for a query string."""
try:
response = self.client.embeddings(model=self.model_name, prompt=query, **self.kwargs)
return response["embedding"]
except Exception as e:
logger.error(f"Failed to encode query: {str(e)}")
raise
def _get_text_embedding(self, text: str) -> List[float]:
"""Get embedding for a text string."""
try:
response = self.client.embeddings(model=self.model_name, prompt=text, **self.kwargs)
return response["embedding"]
except Exception as e:
logger.error(f"Failed to encode text: {str(e)}")
raise
def _get_text_embeddings(self, texts: List[str]) -> List[List[float]]:
"""Get embeddings for a list of texts synchronously."""
try:
embeddings = []
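            # The Ollama embeddings endpoint handles one prompt per request, so each
            # batch is embedded text-by-text rather than in a single call.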
for i in range(0, len(texts), self.embed_batch_size):
batch = texts[i:i + self.embed_batch_size]
batch_embeddings = [self._get_text_embedding(text) for text in batch]
embeddings.extend(batch_embeddings)
return embeddings
except Exception as e:
logger.error(f"Failed to encode texts: {str(e)}")
raise
async def _aget_query_embedding(self, query: str) -> List[float]:
"""Asynchronous query embedding (falls back to sync)."""
return self._get_query_embedding(query)
@property
def dimension(self) -> int:
"""Return the embedding dimension."""
return self.embedding_dims
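
# Note: OllamaEmbedding can be plugged in anywhere LlamaIndex expects a BaseEmbedding,
# e.g. `Settings.embed_model = OllamaEmbedding(model_name="nomic-embed-text")` after
# `from llama_index.core import Settings` (illustrative usage, not exercised in this module).
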
class OllamaEmbeddingWrapper(BaseEmbeddingWrapper):
"""Wrapper for Ollama embedding models."""
def __init__(
self,
model_name: str = "nomic-embed-text",
        base_url: Optional[str] = None,
        dimensions: Optional[int] = None,
**kwargs
):
self.model_name = model_name
self.base_url = base_url
        # An explicitly requested dimension wins; otherwise fall back to the known model size.
        self._dimensions = dimensions or MODEL_DIMENSIONS.get(model_name)
self.kwargs = kwargs
self._embedding_model = None
self._embedding_model = self.get_embedding_model()
def get_embedding_model(self) -> BaseEmbedding:
"""Return the LlamaIndex-compatible embedding model."""
if self._embedding_model is None:
try:
self._embedding_model = OllamaEmbedding(
model_name=self.model_name,
base_url=self.base_url,
embedding_dims=self._dimensions,
**self.kwargs
)
logger.debug(f"Initialized Ollama embedding wrapper for model: {self.model_name}")
except Exception as e:
logger.error(f"Failed to initialize Ollama embedding wrapper: {str(e)}")
raise
return self._embedding_model
@property
def dimensions(self) -> int:
"""Return the embedding dimensions."""
        return self._dimensions
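

if __name__ == "__main__":
    # Minimal usage sketch, not part of the library API: it assumes an Ollama server is
    # reachable at http://localhost:11434 and able to pull the nomic-embed-text model.
    wrapper = OllamaEmbeddingWrapper(model_name="nomic-embed-text")
    embed_model = wrapper.get_embedding_model()
    vector = embed_model.get_text_embedding("EvoAgentX builds self-evolving agent workflows.")
    print(f"dimensions={wrapper.dimensions}, vector length={len(vector)}")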