File size: 5,636 Bytes
5374a2d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
import sys
import subprocess
from typing import List, Dict, Optional

from llama_index.core.embeddings import BaseEmbedding

from evoagentx.core.logging import logger
from .base import BaseEmbeddingWrapper, EmbeddingProvider, SUPPORTED_MODELS

# Best-effort import of the 'ollama' client; if missing, try to pip-install it
# into the current interpreter and import again.
# NOTE(review): installing packages at import time via a pip subprocess is a
# side effect (slow first import, may fail in sandboxed/offline environments);
# consider moving this to an explicit setup/extras step.
try:
    from ollama import Client
except ImportError:
    logger.warning("The 'ollama' library is not installed. Attempting to install it.")
    try:
        # Use the running interpreter's pip so the install lands in the active env.
        subprocess.check_call([sys.executable, "-m", "pip", "install", "ollama"])
        from ollama import Client
    except subprocess.CalledProcessError:
        logger.error("Failed to install 'ollama'. Please install it manually using 'pip install ollama'.")
        raise ImportError("The 'ollama' library is required.")


# Native output dimensionality of the supported Ollama embedding models,
# per the Ollama model library. Used as the default when the caller does
# not specify dimensions explicitly.
# NOTE(review): the previous value for nomic-embed-text was 384, but that
# model emits 768-dim vectors (384 is all-minilm's size) — confirm no stored
# index depends on the old value before upgrading.
MODEL_DIMENSIONS = {
    "nomic-embed-text": 768,
    "mxbai-embed-large": 1024,
    "bge-m3": 1024,
    "all-minilm": 384,
    "snowflake-arctic-embed": 1024,
}

class OllamaEmbedding(BaseEmbedding):
    """Ollama embedding model compatible with LlamaIndex ``BaseEmbedding``.

    Wraps an ``ollama.Client`` and implements the synchronous and
    asynchronous embedding hooks LlamaIndex expects. If the requested
    model is not present on the local Ollama server, it is pulled on
    initialization.
    """

    # Pydantic-declared fields (BaseEmbedding is a pydantic model, so
    # mutable defaults like {} are deep-copied per instance).
    base_url: Optional[str] = None        # Ollama server URL; set in __init__
    client: Optional[Client] = None       # created in __init__
    model_name: str = "nomic-embed-text"
    embed_batch_size: int = 10
    embedding_dims: Optional[int] = None  # reported via the `dimension` property
    kwargs: Optional[Dict] = {}           # extra args forwarded to client.embeddings

    def __init__(
        self,
        model_name: str = "nomic-embed-text",
        base_url: str = None,
        embedding_dims: int = None,
        **kwargs
    ):
        """Initialize the client and make sure the model is available locally.

        Args:
            model_name: Name of the Ollama embedding model.
            base_url: Ollama server URL; defaults to http://localhost:11434.
            embedding_dims: Reported embedding dimensionality; defaults to 512.
            **kwargs: Extra keyword arguments passed to ``client.embeddings``.

        Raises:
            ValueError: If ``model_name`` is not a supported Ollama model.
        """
        super().__init__(model_name=model_name, embed_batch_size=10)
        self.base_url = base_url or "http://localhost:11434"
        self.embedding_dims = embedding_dims or 512
        self.kwargs = kwargs

        # Fail fast on a bad model name before opening a connection.
        if not EmbeddingProvider.validate_model(EmbeddingProvider.OLLAMA, model_name):
            raise ValueError(f"Unsupported Ollama model: {model_name}. Supported models: {SUPPORTED_MODELS['ollama']}")

        try:
            self.client = Client(host=self.base_url)
            self._ensure_model_exists()
            logger.debug(f"Initialized Ollama embedding model: {model_name}")
        except Exception as e:
            logger.error(f"Failed to initialize Ollama client: {str(e)}")
            raise

    @staticmethod
    def _listed_model_name(entry) -> str:
        """Extract the model name from one ``client.list()`` entry.

        Entries are plain dicts keyed by 'name' in older ollama clients, and
        pydantic objects with a 'model' attribute in newer ones — handle both.
        """
        if isinstance(entry, dict):
            return entry.get("name") or entry.get("model") or ""
        return getattr(entry, "model", None) or getattr(entry, "name", None) or ""

    def _ensure_model_exists(self):
        """Ensure the specified model exists locally, pulling it if necessary."""
        try:
            local_models = self.client.list()["models"]
            # Locally installed models are listed with a tag suffix
            # (e.g. "nomic-embed-text:latest"), so compare base names only;
            # an exact string match would never hit a tagged name and would
            # re-pull the model on every initialization.
            wanted = self.model_name.split(":")[0]
            present = any(
                self._listed_model_name(m).split(":")[0] == wanted
                for m in local_models
            )
            if not present:
                logger.info(f"Pulling Ollama model: {self.model_name}")
                self.client.pull(self.model_name)
        except Exception as e:
            logger.error(f"Failed to ensure Ollama model exists: {str(e)}")
            raise

    def _get_query_embedding(self, query: str) -> List[float]:
        """Get the embedding vector for a query string."""
        try:
            response = self.client.embeddings(model=self.model_name, prompt=query, **self.kwargs)
            return response["embedding"]
        except Exception as e:
            logger.error(f"Failed to encode query: {str(e)}")
            raise

    def _get_text_embedding(self, text: str) -> List[float]:
        """Get the embedding vector for a text string."""
        try:
            response = self.client.embeddings(model=self.model_name, prompt=text, **self.kwargs)
            return response["embedding"]
        except Exception as e:
            logger.error(f"Failed to encode text: {str(e)}")
            raise

    def _get_text_embeddings(self, texts: List[str]) -> List[List[float]]:
        """Get embeddings for a list of texts synchronously, in batches.

        Batching is cosmetic here (one request per text either way); it keeps
        the structure parallel to other wrappers that do true batch requests.
        """
        try:
            embeddings = []
            for i in range(0, len(texts), self.embed_batch_size):
                batch = texts[i:i + self.embed_batch_size]
                embeddings.extend(self._get_text_embedding(text) for text in batch)
            return embeddings
        except Exception as e:
            logger.error(f"Failed to encode texts: {str(e)}")
            raise

    async def _aget_query_embedding(self, query: str) -> List[float]:
        """Asynchronous query embedding (falls back to the sync client call)."""
        return self._get_query_embedding(query)

    @property
    def dimension(self) -> int:
        """Return the configured embedding dimension."""
        return self.embedding_dims


class OllamaEmbeddingWrapper(BaseEmbeddingWrapper):
    """Wrapper exposing an Ollama embedding model behind the project's
    ``BaseEmbeddingWrapper`` interface."""

    def __init__(
        self,
        model_name: str = "nomic-embed-text",
        base_url: str = None,
        dimensions: int = None,
        **kwargs
    ):
        """Create the wrapper and eagerly build the underlying model.

        Args:
            model_name: Name of the Ollama embedding model.
            base_url: Ollama server URL (defaults inside OllamaEmbedding).
            dimensions: Explicit embedding dimensionality; overrides the
                table default for known models.
            **kwargs: Extra keyword arguments forwarded to OllamaEmbedding.
        """
        self.model_name = model_name
        self.base_url = base_url
        # An explicit caller-supplied value takes precedence over the table
        # default (previously the table silently overrode the argument).
        self._dimensions = dimensions or MODEL_DIMENSIONS.get(model_name)
        self.kwargs = kwargs
        self._embedding_model = None
        self._embedding_model = self.get_embedding_model()
        # If neither an argument nor a table entry provided dimensions, sync
        # with the fallback the underlying model actually applied so the
        # `dimensions` property never reports None for a live model.
        if self._dimensions is None:
            self._dimensions = self._embedding_model.dimension

    def get_embedding_model(self) -> BaseEmbedding:
        """Return the LlamaIndex-compatible embedding model, creating it once."""
        if self._embedding_model is None:
            try:
                self._embedding_model = OllamaEmbedding(
                    model_name=self.model_name,
                    base_url=self.base_url,
                    embedding_dims=self._dimensions,
                    **self.kwargs
                )
                logger.debug(f"Initialized Ollama embedding wrapper for model: {self.model_name}")
            except Exception as e:
                logger.error(f"Failed to initialize Ollama embedding wrapper: {str(e)}")
                raise
        return self._embedding_model

    @property
    def dimensions(self) -> int:
        """Return the embedding dimensions."""
        return self._dimensions