|
|
""" |
|
|
Ollama Client for SPARKNET |
|
|
Handles communication with local Ollama LLM models |
|
|
""" |
|
|
|
|
|
import ollama |
|
|
from typing import List, Dict, Optional, Generator, Any |
|
|
from loguru import logger |
|
|
import json |
|
|
|
|
|
|
|
|
class OllamaClient: |
|
|
"""Client for interacting with Ollama LLM models.""" |
|
|
|
|
|
def __init__( |
|
|
self, |
|
|
host: str = "localhost", |
|
|
port: int = 11434, |
|
|
default_model: str = "llama3.2:latest", |
|
|
timeout: int = 300, |
|
|
): |
|
|
""" |
|
|
Initialize Ollama client. |
|
|
|
|
|
Args: |
|
|
host: Ollama server host |
|
|
port: Ollama server port |
|
|
default_model: Default model to use |
|
|
timeout: Request timeout in seconds |
|
|
""" |
|
|
self.host = host |
|
|
self.port = port |
|
|
self.base_url = f"http://{host}:{port}" |
|
|
self.default_model = default_model |
|
|
self.timeout = timeout |
|
|
self.client = ollama.Client(host=self.base_url) |
|
|
|
|
|
logger.info(f"Initialized Ollama client: {self.base_url}") |
|
|
|
|
|
def list_models(self) -> List[Dict[str, Any]]: |
|
|
""" |
|
|
List available models. |
|
|
|
|
|
Returns: |
|
|
List of model information dictionaries |
|
|
""" |
|
|
try: |
|
|
response = self.client.list() |
|
|
models = response.get("models", []) |
|
|
logger.info(f"Found {len(models)} available models") |
|
|
return models |
|
|
except Exception as e: |
|
|
logger.error(f"Error listing models: {e}") |
|
|
return [] |
|
|
|
|
|
def pull_model(self, model_name: str) -> bool: |
|
|
""" |
|
|
Pull/download a model. |
|
|
|
|
|
Args: |
|
|
model_name: Name of the model to pull |
|
|
|
|
|
Returns: |
|
|
True if successful, False otherwise |
|
|
""" |
|
|
try: |
|
|
logger.info(f"Pulling model: {model_name}") |
|
|
self.client.pull(model_name) |
|
|
logger.info(f"Successfully pulled model: {model_name}") |
|
|
return True |
|
|
except Exception as e: |
|
|
logger.error(f"Error pulling model {model_name}: {e}") |
|
|
return False |
|
|
|
|
|
def generate( |
|
|
self, |
|
|
prompt: str, |
|
|
model: Optional[str] = None, |
|
|
system: Optional[str] = None, |
|
|
temperature: float = 0.7, |
|
|
max_tokens: Optional[int] = None, |
|
|
stream: bool = False, |
|
|
**kwargs, |
|
|
) -> str | Generator[str, None, None]: |
|
|
""" |
|
|
Generate completion from a prompt. |
|
|
|
|
|
Args: |
|
|
prompt: Input prompt |
|
|
model: Model to use (default: self.default_model) |
|
|
system: System prompt |
|
|
temperature: Sampling temperature |
|
|
max_tokens: Maximum tokens to generate |
|
|
stream: Whether to stream the response |
|
|
**kwargs: Additional generation parameters |
|
|
|
|
|
Returns: |
|
|
Generated text or generator if streaming |
|
|
""" |
|
|
model = model or self.default_model |
|
|
|
|
|
options = { |
|
|
"temperature": temperature, |
|
|
} |
|
|
if max_tokens: |
|
|
options["num_predict"] = max_tokens |
|
|
|
|
|
options.update(kwargs) |
|
|
|
|
|
try: |
|
|
logger.debug(f"Generating with model {model}, prompt length: {len(prompt)}") |
|
|
|
|
|
if stream: |
|
|
return self._generate_stream(prompt, model, system, options) |
|
|
else: |
|
|
response = self.client.generate( |
|
|
model=model, |
|
|
prompt=prompt, |
|
|
system=system, |
|
|
options=options, |
|
|
) |
|
|
generated_text = response.get("response", "") |
|
|
logger.debug(f"Generated {len(generated_text)} characters") |
|
|
return generated_text |
|
|
|
|
|
except Exception as e: |
|
|
logger.error(f"Error generating completion: {e}") |
|
|
return "" |
|
|
|
|
|
def _generate_stream( |
|
|
self, |
|
|
prompt: str, |
|
|
model: str, |
|
|
system: Optional[str], |
|
|
options: Dict, |
|
|
) -> Generator[str, None, None]: |
|
|
""" |
|
|
Generate streaming completion. |
|
|
|
|
|
Args: |
|
|
prompt: Input prompt |
|
|
model: Model to use |
|
|
system: System prompt |
|
|
options: Generation options |
|
|
|
|
|
Yields: |
|
|
Generated text chunks |
|
|
""" |
|
|
try: |
|
|
stream = self.client.generate( |
|
|
model=model, |
|
|
prompt=prompt, |
|
|
system=system, |
|
|
options=options, |
|
|
stream=True, |
|
|
) |
|
|
|
|
|
for chunk in stream: |
|
|
if "response" in chunk: |
|
|
yield chunk["response"] |
|
|
|
|
|
except Exception as e: |
|
|
logger.error(f"Error in streaming generation: {e}") |
|
|
yield "" |
|
|
|
|
|
def chat( |
|
|
self, |
|
|
messages: List[Dict[str, str]], |
|
|
model: Optional[str] = None, |
|
|
temperature: float = 0.7, |
|
|
stream: bool = False, |
|
|
**kwargs, |
|
|
) -> str | Generator[str, None, None]: |
|
|
""" |
|
|
Chat completion with conversation history. |
|
|
|
|
|
Args: |
|
|
messages: List of message dicts with 'role' and 'content' |
|
|
model: Model to use (default: self.default_model) |
|
|
temperature: Sampling temperature |
|
|
stream: Whether to stream the response |
|
|
**kwargs: Additional chat parameters |
|
|
|
|
|
Returns: |
|
|
Response text or generator if streaming |
|
|
""" |
|
|
model = model or self.default_model |
|
|
|
|
|
options = { |
|
|
"temperature": temperature, |
|
|
} |
|
|
options.update(kwargs) |
|
|
|
|
|
try: |
|
|
logger.debug(f"Chat with model {model}, {len(messages)} messages") |
|
|
|
|
|
if stream: |
|
|
return self._chat_stream(messages, model, options) |
|
|
else: |
|
|
response = self.client.chat( |
|
|
model=model, |
|
|
messages=messages, |
|
|
options=options, |
|
|
) |
|
|
message = response.get("message", {}) |
|
|
content = message.get("content", "") |
|
|
logger.debug(f"Chat response: {len(content)} characters") |
|
|
return content |
|
|
|
|
|
except Exception as e: |
|
|
logger.error(f"Error in chat completion: {e}") |
|
|
return "" |
|
|
|
|
|
def _chat_stream( |
|
|
self, |
|
|
messages: List[Dict[str, str]], |
|
|
model: str, |
|
|
options: Dict, |
|
|
) -> Generator[str, None, None]: |
|
|
""" |
|
|
Streaming chat completion. |
|
|
|
|
|
Args: |
|
|
messages: List of message dicts |
|
|
model: Model to use |
|
|
options: Chat options |
|
|
|
|
|
Yields: |
|
|
Response text chunks |
|
|
""" |
|
|
try: |
|
|
stream = self.client.chat( |
|
|
model=model, |
|
|
messages=messages, |
|
|
options=options, |
|
|
stream=True, |
|
|
) |
|
|
|
|
|
for chunk in stream: |
|
|
if "message" in chunk: |
|
|
message = chunk["message"] |
|
|
if "content" in message: |
|
|
yield message["content"] |
|
|
|
|
|
except Exception as e: |
|
|
logger.error(f"Error in streaming chat: {e}") |
|
|
yield "" |
|
|
|
|
|
def embed( |
|
|
self, |
|
|
text: str | List[str], |
|
|
model: str = "nomic-embed-text:latest", |
|
|
) -> List[List[float]]: |
|
|
""" |
|
|
Generate embeddings for text. |
|
|
|
|
|
Args: |
|
|
text: Text or list of texts to embed |
|
|
model: Embedding model to use |
|
|
|
|
|
Returns: |
|
|
List of embedding vectors |
|
|
""" |
|
|
try: |
|
|
if isinstance(text, str): |
|
|
text = [text] |
|
|
|
|
|
logger.debug(f"Generating embeddings for {len(text)} texts") |
|
|
|
|
|
embeddings = [] |
|
|
for t in text: |
|
|
response = self.client.embeddings(model=model, prompt=t) |
|
|
embedding = response.get("embedding", []) |
|
|
embeddings.append(embedding) |
|
|
|
|
|
logger.debug(f"Generated {len(embeddings)} embeddings") |
|
|
return embeddings |
|
|
|
|
|
except Exception as e: |
|
|
logger.error(f"Error generating embeddings: {e}") |
|
|
return [] |
|
|
|
|
|
def count_tokens(self, text: str) -> int: |
|
|
""" |
|
|
Estimate token count for text. |
|
|
Simple estimation: ~4 characters per token for English text. |
|
|
|
|
|
Args: |
|
|
text: Text to count tokens for |
|
|
|
|
|
Returns: |
|
|
Estimated token count |
|
|
""" |
|
|
|
|
|
return len(text) // 4 |
|
|
|
|
|
def is_available(self) -> bool: |
|
|
""" |
|
|
Check if Ollama server is available. |
|
|
|
|
|
Returns: |
|
|
True if server is responding, False otherwise |
|
|
""" |
|
|
try: |
|
|
self.list_models() |
|
|
return True |
|
|
except Exception: |
|
|
return False |
|
|
|
|
|
|
|
|
|
|
|
# Process-wide singleton instance, created lazily by get_ollama_client().
_ollama_client: Optional[OllamaClient] = None


def get_ollama_client(
    host: str = "localhost",
    port: int = 11434,
    default_model: str = "llama3.2:latest",
) -> OllamaClient:
    """Return the shared OllamaClient, creating it on first use.

    The connection arguments only take effect on the first call; every
    subsequent call returns the already-constructed instance unchanged.
    """
    global _ollama_client
    if _ollama_client is not None:
        return _ollama_client
    _ollama_client = OllamaClient(
        host=host,
        port=port,
        default_model=default_model,
    )
    return _ollama_client
|
|
|