# SPARKNET / src/llm/langchain_ollama_client.py
# Author: MHamdan
# Initial commit: SPARKNET framework (commit a9dc537)
"""
LangChain Ollama Client for SPARKNET
Integrates Ollama with LangChain for multi-model complexity routing
Provides unified interface for chat, embeddings, and GPU monitoring
"""
from typing import Optional, Dict, Any, List, Literal
from loguru import logger
from langchain_ollama import ChatOllama, OllamaEmbeddings
from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.messages import BaseMessage
from langchain_core.outputs import LLMResult
from ..utils.gpu_manager import get_gpu_manager
# Type alias for complexity levels.
# The four values match the keys of LangChainOllamaClient.MODEL_CONFIG below,
# each mapping to a different Ollama model tier.
ComplexityLevel = Literal["simple", "standard", "complex", "analysis"]
class SparknetCallbackHandler(BaseCallbackHandler):
    """
    Custom callback handler for SPARKNET.

    Monitors GPU usage, token counts, and latency. Counters accumulate
    across LLM calls and can be read back via :meth:`get_stats`.
    """

    def __init__(self):
        super().__init__()
        self.gpu_manager = get_gpu_manager()
        # Running total of tokens reported by the backend (when available).
        self.token_count = 0
        # Number of on_llm_start events observed.
        self.llm_calls = 0

    def on_llm_start(
        self,
        serialized: Dict[str, Any],
        prompts: List[str],
        **kwargs: Any
    ) -> None:
        """Called when LLM starts processing; logs call count and GPU memory."""
        self.llm_calls += 1
        gpu_status = self.gpu_manager.monitor()
        logger.debug(f"LLM call #{self.llm_calls} started")
        # Fix: the original indexed gpu_status['gpus'][0] unconditionally,
        # raising IndexError/KeyError on CPU-only hosts. Guard first.
        gpus = gpu_status.get('gpus') if isinstance(gpu_status, dict) else None
        if gpus:
            logger.debug(f"GPU Status: {gpus[0]['memory_used']:.2f} GB used")

    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
        """Called when LLM finishes processing; accumulates token usage."""
        # Count tokens if available — llm_output shape is backend-dependent,
        # so every lookup uses a safe default.
        if hasattr(response, 'llm_output') and response.llm_output:
            token_usage = response.llm_output.get('token_usage', {})
            if token_usage:
                self.token_count += token_usage.get('total_tokens', 0)
                logger.debug(f"Tokens used: {token_usage.get('total_tokens', 0)}")

    def on_llm_error(self, error: Exception, **kwargs: Any) -> None:
        """Called when LLM encounters an error."""
        logger.error(f"LLM error: {error}")

    def get_stats(self) -> Dict[str, Any]:
        """Return accumulated call/token counters plus a fresh GPU snapshot."""
        return {
            'llm_calls': self.llm_calls,
            'total_tokens': self.token_count,
            'gpu_status': self.gpu_manager.monitor(),
        }
class LangChainOllamaClient:
    """
    LangChain-powered Ollama client with intelligent model routing.

    Manages multiple Ollama models for different complexity levels:
    - simple: Fast, lightweight tasks (gemma2:2b)
    - standard: General-purpose tasks (llama3.1:8b)
    - complex: Advanced reasoning and planning (qwen2.5:14b)
    - analysis: Critical analysis and validation (mistral:latest)

    Features:
    - Automatic model selection based on task complexity
    - GPU monitoring via custom callbacks
    - Embedding generation for vector search
    - Streaming and non-streaming support
    """

    # Model configuration for each complexity level.
    # Treated as read-only: accessors return copies so callers cannot
    # mutate this shared class attribute.
    MODEL_CONFIG: Dict[ComplexityLevel, Dict[str, Any]] = {
        "simple": {
            "model": "gemma2:2b",
            "temperature": 0.3,
            "max_tokens": 512,
            "description": "Fast classification, routing, simple Q&A",
            "size_gb": 1.6,
        },
        "standard": {
            "model": "llama3.1:8b",
            "temperature": 0.7,
            "max_tokens": 1024,
            "description": "General tasks, code generation, summarization",
            "size_gb": 4.9,
        },
        "complex": {
            "model": "qwen2.5:14b",
            "temperature": 0.7,
            "max_tokens": 2048,
            "description": "Complex reasoning, planning, multi-step tasks",
            "size_gb": 9.0,
        },
        "analysis": {
            "model": "mistral:latest",
            "temperature": 0.6,
            "max_tokens": 1024,
            "description": "Critical analysis, validation, quality assessment",
            "size_gb": 4.4,
        },
    }

    def __init__(
        self,
        base_url: str = "http://localhost:11434",
        default_complexity: ComplexityLevel = "standard",
        enable_monitoring: bool = True,
    ):
        """
        Initialize LangChain Ollama client.

        Args:
            base_url: Ollama server URL
            default_complexity: Default model complexity level
            enable_monitoring: Enable GPU monitoring callbacks
        """
        self.base_url = base_url
        self.default_complexity = default_complexity
        self.enable_monitoring = enable_monitoring
        # Callback handler is optional — callbacks list stays empty when
        # monitoring is disabled.
        self.callback_handler = SparknetCallbackHandler() if enable_monitoring else None
        self.callbacks = [self.callback_handler] if self.callback_handler else []
        # One cached ChatOllama instance per complexity level.
        self.llms: Dict[ComplexityLevel, ChatOllama] = {}
        self._initialize_models()
        # Embedding model for vector search.
        self.embeddings = OllamaEmbeddings(
            base_url=base_url,
            model="nomic-embed-text:latest",
        )
        logger.info(f"Initialized LangChainOllamaClient with {len(self.llms)} models")
        logger.info(f"Default complexity: {default_complexity}")

    def _initialize_models(self) -> None:
        """Initialize ChatOllama instances for each complexity level.

        Failures are logged and skipped, so ``self.llms`` may end up with
        fewer entries than MODEL_CONFIG — get_llm() guards against that.
        """
        for complexity, config in self.MODEL_CONFIG.items():
            try:
                self.llms[complexity] = ChatOllama(
                    base_url=self.base_url,
                    model=config["model"],
                    temperature=config["temperature"],
                    num_predict=config["max_tokens"],
                    callbacks=self.callbacks,
                )
                logger.debug(f"Initialized {complexity} model: {config['model']}")
            except Exception as e:
                logger.error(f"Failed to initialize {complexity} model: {e}")

    def get_llm(
        self,
        complexity: Optional[ComplexityLevel] = None,
        temperature: Optional[float] = None,
        max_tokens: Optional[int] = None,
    ) -> ChatOllama:
        """
        Get LLM for specified complexity level.

        Args:
            complexity: Complexity level (simple, standard, complex, analysis)
            temperature: Override default temperature
            max_tokens: Override default max tokens

        Returns:
            ChatOllama instance

        Raises:
            RuntimeError: If no model is initialized for the resolved level.
        """
        complexity = complexity or self.default_complexity
        if complexity not in self.llms:
            logger.warning(f"Unknown complexity '{complexity}', using default")
            complexity = self.default_complexity
        # Fix: if the default model also failed to initialize (see
        # _initialize_models, which logs-and-skips failures), the original
        # raised an opaque KeyError here. Fail loudly with a clear message.
        if complexity not in self.llms:
            raise RuntimeError(
                f"No initialized model for complexity '{complexity}'; "
                f"check that the Ollama server at {self.base_url} is reachable"
            )
        # If no overrides, return cached instance
        if temperature is None and max_tokens is None:
            return self.llms[complexity]
        # Create new instance with overridden parameters
        config = self.MODEL_CONFIG[complexity]
        return ChatOllama(
            base_url=self.base_url,
            model=config["model"],
            temperature=temperature if temperature is not None else config["temperature"],
            num_predict=max_tokens if max_tokens is not None else config["max_tokens"],
            callbacks=self.callbacks,
        )

    def get_embeddings(self) -> OllamaEmbeddings:
        """
        Get embedding model for vector operations.

        Returns:
            OllamaEmbeddings instance
        """
        return self.embeddings

    async def ainvoke(
        self,
        messages: List[BaseMessage],
        complexity: Optional[ComplexityLevel] = None,
        **kwargs: Any,
    ) -> BaseMessage:
        """
        Async invoke LLM with messages.

        Args:
            messages: List of messages for the conversation
            complexity: Model complexity level
            **kwargs: Additional arguments for the LLM

        Returns:
            AI response message
        """
        llm = self.get_llm(complexity)
        response = await llm.ainvoke(messages, **kwargs)
        return response

    def invoke(
        self,
        messages: List[BaseMessage],
        complexity: Optional[ComplexityLevel] = None,
        **kwargs: Any,
    ) -> BaseMessage:
        """
        Synchronous invoke LLM with messages.

        Args:
            messages: List of messages for the conversation
            complexity: Model complexity level
            **kwargs: Additional arguments for the LLM

        Returns:
            AI response message
        """
        llm = self.get_llm(complexity)
        response = llm.invoke(messages, **kwargs)
        return response

    async def astream(
        self,
        messages: List[BaseMessage],
        complexity: Optional[ComplexityLevel] = None,
        **kwargs: Any,
    ):
        """
        Async stream LLM responses.

        Args:
            messages: List of messages for the conversation
            complexity: Model complexity level
            **kwargs: Additional arguments for the LLM

        Yields:
            Chunks of AI response
        """
        llm = self.get_llm(complexity)
        async for chunk in llm.astream(messages, **kwargs):
            yield chunk

    async def embed_text(self, text: str) -> List[float]:
        """
        Generate embedding for text.

        Args:
            text: Text to embed

        Returns:
            Embedding vector
        """
        embedding = await self.embeddings.aembed_query(text)
        return embedding

    async def embed_documents(self, documents: List[str]) -> List[List[float]]:
        """
        Generate embeddings for multiple documents.

        Args:
            documents: List of documents to embed

        Returns:
            List of embedding vectors
        """
        embeddings = await self.embeddings.aembed_documents(documents)
        return embeddings

    def get_model_info(self, complexity: Optional[ComplexityLevel] = None) -> Dict[str, Any]:
        """
        Get information about a model.

        Args:
            complexity: Complexity level (defaults to current default)

        Returns:
            Model configuration dictionary (a copy; empty if unknown level)
        """
        complexity = complexity or self.default_complexity
        # Fix: return a copy so callers cannot mutate the shared MODEL_CONFIG.
        return dict(self.MODEL_CONFIG.get(complexity, {}))

    def list_models(self) -> Dict[ComplexityLevel, Dict[str, Any]]:
        """
        List all available models and their configurations.

        Returns:
            Dictionary mapping complexity levels to model configs (copies)
        """
        # Fix: the original shallow .copy() still exposed the nested config
        # dicts for in-place mutation of the shared class attribute.
        return {level: dict(cfg) for level, cfg in self.MODEL_CONFIG.items()}

    def get_stats(self) -> Dict[str, Any]:
        """
        Get client statistics.

        Returns:
            Statistics dictionary (empty when monitoring is disabled)
        """
        if self.callback_handler:
            return self.callback_handler.get_stats()
        return {}

    def recommend_complexity(self, task_description: str) -> ComplexityLevel:
        """
        Recommend complexity level based on task description.

        Uses simple keyword heuristics to suggest an appropriate model:
        - Keywords like "plan", "strategy", "workflow" -> complex
        - Keywords like "validate", "critique", "assess" -> analysis
        - Keywords like "classify", "route", "simple" -> simple
        - Default -> standard

        Args:
            task_description: Natural language task description

        Returns:
            Recommended complexity level
        """
        task_lower = task_description.lower()
        # Complex tasks
        if any(kw in task_lower for kw in ["plan", "strategy", "decompose", "workflow", "multi-step"]):
            return "complex"
        # Analysis tasks
        if any(kw in task_lower for kw in ["validate", "critique", "assess", "review", "quality"]):
            return "analysis"
        # Simple tasks
        if any(kw in task_lower for kw in ["classify", "route", "yes/no", "binary", "simple"]):
            return "simple"
        # Default to standard
        return "standard"
# Convenience factory for quick initialization
def get_langchain_client(
    base_url: str = "http://localhost:11434",
    default_complexity: ComplexityLevel = "standard",
    enable_monitoring: bool = True,
) -> LangChainOllamaClient:
    """
    Build and return a LangChain Ollama client.

    Args:
        base_url: Ollama server URL
        default_complexity: Default model complexity
        enable_monitoring: Enable GPU monitoring

    Returns:
        LangChainOllamaClient instance
    """
    client = LangChainOllamaClient(
        base_url,
        default_complexity,
        enable_monitoring,
    )
    return client