# SPARKNET — src/llm/ollama_client.py
# Provenance (from scraped page header): author MHamdan,
# "Initial commit: SPARKNET framework", commit a9dc537
"""
Ollama Client for SPARKNET
Handles communication with local Ollama LLM models
"""
import ollama
from typing import List, Dict, Optional, Generator, Any
from loguru import logger
import json
class OllamaClient:
"""Client for interacting with Ollama LLM models."""
def __init__(
self,
host: str = "localhost",
port: int = 11434,
default_model: str = "llama3.2:latest",
timeout: int = 300,
):
"""
Initialize Ollama client.
Args:
host: Ollama server host
port: Ollama server port
default_model: Default model to use
timeout: Request timeout in seconds
"""
self.host = host
self.port = port
self.base_url = f"http://{host}:{port}"
self.default_model = default_model
self.timeout = timeout
self.client = ollama.Client(host=self.base_url)
logger.info(f"Initialized Ollama client: {self.base_url}")
def list_models(self) -> List[Dict[str, Any]]:
"""
List available models.
Returns:
List of model information dictionaries
"""
try:
response = self.client.list()
models = response.get("models", [])
logger.info(f"Found {len(models)} available models")
return models
except Exception as e:
logger.error(f"Error listing models: {e}")
return []
def pull_model(self, model_name: str) -> bool:
"""
Pull/download a model.
Args:
model_name: Name of the model to pull
Returns:
True if successful, False otherwise
"""
try:
logger.info(f"Pulling model: {model_name}")
self.client.pull(model_name)
logger.info(f"Successfully pulled model: {model_name}")
return True
except Exception as e:
logger.error(f"Error pulling model {model_name}: {e}")
return False
def generate(
self,
prompt: str,
model: Optional[str] = None,
system: Optional[str] = None,
temperature: float = 0.7,
max_tokens: Optional[int] = None,
stream: bool = False,
**kwargs,
) -> str | Generator[str, None, None]:
"""
Generate completion from a prompt.
Args:
prompt: Input prompt
model: Model to use (default: self.default_model)
system: System prompt
temperature: Sampling temperature
max_tokens: Maximum tokens to generate
stream: Whether to stream the response
**kwargs: Additional generation parameters
Returns:
Generated text or generator if streaming
"""
model = model or self.default_model
options = {
"temperature": temperature,
}
if max_tokens:
options["num_predict"] = max_tokens
options.update(kwargs)
try:
logger.debug(f"Generating with model {model}, prompt length: {len(prompt)}")
if stream:
return self._generate_stream(prompt, model, system, options)
else:
response = self.client.generate(
model=model,
prompt=prompt,
system=system,
options=options,
)
generated_text = response.get("response", "")
logger.debug(f"Generated {len(generated_text)} characters")
return generated_text
except Exception as e:
logger.error(f"Error generating completion: {e}")
return ""
def _generate_stream(
self,
prompt: str,
model: str,
system: Optional[str],
options: Dict,
) -> Generator[str, None, None]:
"""
Generate streaming completion.
Args:
prompt: Input prompt
model: Model to use
system: System prompt
options: Generation options
Yields:
Generated text chunks
"""
try:
stream = self.client.generate(
model=model,
prompt=prompt,
system=system,
options=options,
stream=True,
)
for chunk in stream:
if "response" in chunk:
yield chunk["response"]
except Exception as e:
logger.error(f"Error in streaming generation: {e}")
yield ""
def chat(
self,
messages: List[Dict[str, str]],
model: Optional[str] = None,
temperature: float = 0.7,
stream: bool = False,
**kwargs,
) -> str | Generator[str, None, None]:
"""
Chat completion with conversation history.
Args:
messages: List of message dicts with 'role' and 'content'
model: Model to use (default: self.default_model)
temperature: Sampling temperature
stream: Whether to stream the response
**kwargs: Additional chat parameters
Returns:
Response text or generator if streaming
"""
model = model or self.default_model
options = {
"temperature": temperature,
}
options.update(kwargs)
try:
logger.debug(f"Chat with model {model}, {len(messages)} messages")
if stream:
return self._chat_stream(messages, model, options)
else:
response = self.client.chat(
model=model,
messages=messages,
options=options,
)
message = response.get("message", {})
content = message.get("content", "")
logger.debug(f"Chat response: {len(content)} characters")
return content
except Exception as e:
logger.error(f"Error in chat completion: {e}")
return ""
def _chat_stream(
self,
messages: List[Dict[str, str]],
model: str,
options: Dict,
) -> Generator[str, None, None]:
"""
Streaming chat completion.
Args:
messages: List of message dicts
model: Model to use
options: Chat options
Yields:
Response text chunks
"""
try:
stream = self.client.chat(
model=model,
messages=messages,
options=options,
stream=True,
)
for chunk in stream:
if "message" in chunk:
message = chunk["message"]
if "content" in message:
yield message["content"]
except Exception as e:
logger.error(f"Error in streaming chat: {e}")
yield ""
def embed(
self,
text: str | List[str],
model: str = "nomic-embed-text:latest",
) -> List[List[float]]:
"""
Generate embeddings for text.
Args:
text: Text or list of texts to embed
model: Embedding model to use
Returns:
List of embedding vectors
"""
try:
if isinstance(text, str):
text = [text]
logger.debug(f"Generating embeddings for {len(text)} texts")
embeddings = []
for t in text:
response = self.client.embeddings(model=model, prompt=t)
embedding = response.get("embedding", [])
embeddings.append(embedding)
logger.debug(f"Generated {len(embeddings)} embeddings")
return embeddings
except Exception as e:
logger.error(f"Error generating embeddings: {e}")
return []
def count_tokens(self, text: str) -> int:
"""
Estimate token count for text.
Simple estimation: ~4 characters per token for English text.
Args:
text: Text to count tokens for
Returns:
Estimated token count
"""
# Simple estimation - this can be improved with proper tokenization
return len(text) // 4
def is_available(self) -> bool:
"""
Check if Ollama server is available.
Returns:
True if server is responding, False otherwise
"""
try:
self.list_models()
return True
except Exception:
return False
# Module-level singleton holding the shared Ollama client.
_ollama_client: Optional[OllamaClient] = None
def get_ollama_client(
    host: str = "localhost",
    port: int = 11434,
    default_model: str = "llama3.2:latest",
) -> OllamaClient:
    """Return the shared OllamaClient, creating it on first use.

    Note: the arguments only take effect on the first call; once the
    singleton exists, subsequent calls return it unchanged and any
    differing arguments are ignored.
    """
    global _ollama_client
    if _ollama_client is not None:
        return _ollama_client
    _ollama_client = OllamaClient(
        host=host,
        port=port,
        default_model=default_model,
    )
    return _ollama_client