""" LangChain Ollama Client for SPARKNET Integrates Ollama with LangChain for multi-model complexity routing Provides unified interface for chat, embeddings, and GPU monitoring """ from typing import Optional, Dict, Any, List, Literal from loguru import logger from langchain_ollama import ChatOllama, OllamaEmbeddings from langchain_core.callbacks import BaseCallbackHandler from langchain_core.messages import BaseMessage from langchain_core.outputs import LLMResult from ..utils.gpu_manager import get_gpu_manager # Type alias for complexity levels ComplexityLevel = Literal["simple", "standard", "complex", "analysis"] class SparknetCallbackHandler(BaseCallbackHandler): """ Custom callback handler for SPARKNET. Monitors GPU usage, token counts, and latency. """ def __init__(self): super().__init__() self.gpu_manager = get_gpu_manager() self.token_count = 0 self.llm_calls = 0 def on_llm_start( self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any ) -> None: """Called when LLM starts processing.""" self.llm_calls += 1 gpu_status = self.gpu_manager.monitor() logger.debug(f"LLM call #{self.llm_calls} started") logger.debug(f"GPU Status: {gpu_status['gpus'][0]['memory_used']:.2f} GB used") def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None: """Called when LLM finishes processing.""" # Count tokens if available if hasattr(response, 'llm_output') and response.llm_output: token_usage = response.llm_output.get('token_usage', {}) if token_usage: self.token_count += token_usage.get('total_tokens', 0) logger.debug(f"Tokens used: {token_usage.get('total_tokens', 0)}") def on_llm_error(self, error: Exception, **kwargs: Any) -> None: """Called when LLM encounters an error.""" logger.error(f"LLM error: {error}") def get_stats(self) -> Dict[str, Any]: """Get accumulated statistics.""" return { 'llm_calls': self.llm_calls, 'total_tokens': self.token_count, 'gpu_status': self.gpu_manager.monitor(), } class LangChainOllamaClient: """ LangChain-powered Ollama client with intelligent model routing. Manages multiple Ollama models for different complexity levels: - simple: Fast, lightweight tasks (gemma2:2b) - standard: General-purpose tasks (llama3.1:8b) - complex: Advanced reasoning and planning (qwen2.5:14b) - analysis: Critical analysis and validation (mistral:latest) Features: - Automatic model selection based on task complexity - GPU monitoring via custom callbacks - Embedding generation for vector search - Streaming and non-streaming support """ # Model configuration for each complexity level MODEL_CONFIG: Dict[ComplexityLevel, Dict[str, Any]] = { "simple": { "model": "gemma2:2b", "temperature": 0.3, "max_tokens": 512, "description": "Fast classification, routing, simple Q&A", "size_gb": 1.6, }, "standard": { "model": "llama3.1:8b", "temperature": 0.7, "max_tokens": 1024, "description": "General tasks, code generation, summarization", "size_gb": 4.9, }, "complex": { "model": "qwen2.5:14b", "temperature": 0.7, "max_tokens": 2048, "description": "Complex reasoning, planning, multi-step tasks", "size_gb": 9.0, }, "analysis": { "model": "mistral:latest", "temperature": 0.6, "max_tokens": 1024, "description": "Critical analysis, validation, quality assessment", "size_gb": 4.4, }, } def __init__( self, base_url: str = "http://localhost:11434", default_complexity: ComplexityLevel = "standard", enable_monitoring: bool = True, ): """ Initialize LangChain Ollama client. Args: base_url: Ollama server URL default_complexity: Default model complexity level enable_monitoring: Enable GPU monitoring callbacks """ self.base_url = base_url self.default_complexity = default_complexity self.enable_monitoring = enable_monitoring # Initialize callback handler self.callback_handler = SparknetCallbackHandler() if enable_monitoring else None self.callbacks = [self.callback_handler] if self.callback_handler else [] # Initialize LLMs for each complexity level self.llms: Dict[ComplexityLevel, ChatOllama] = {} self._initialize_models() # Initialize embedding model self.embeddings = OllamaEmbeddings( base_url=base_url, model="nomic-embed-text:latest", ) logger.info(f"Initialized LangChainOllamaClient with {len(self.llms)} models") logger.info(f"Default complexity: {default_complexity}") def _initialize_models(self) -> None: """Initialize ChatOllama instances for each complexity level.""" for complexity, config in self.MODEL_CONFIG.items(): try: self.llms[complexity] = ChatOllama( base_url=self.base_url, model=config["model"], temperature=config["temperature"], num_predict=config["max_tokens"], callbacks=self.callbacks, ) logger.debug(f"Initialized {complexity} model: {config['model']}") except Exception as e: logger.error(f"Failed to initialize {complexity} model: {e}") def get_llm( self, complexity: Optional[ComplexityLevel] = None, temperature: Optional[float] = None, max_tokens: Optional[int] = None, ) -> ChatOllama: """ Get LLM for specified complexity level. Args: complexity: Complexity level (simple, standard, complex, analysis) temperature: Override default temperature max_tokens: Override default max tokens Returns: ChatOllama instance """ complexity = complexity or self.default_complexity if complexity not in self.llms: logger.warning(f"Unknown complexity '{complexity}', using default") complexity = self.default_complexity # If no overrides, return cached instance if temperature is None and max_tokens is None: return self.llms[complexity] # Create new instance with overridden parameters config = self.MODEL_CONFIG[complexity] return ChatOllama( base_url=self.base_url, model=config["model"], temperature=temperature if temperature is not None else config["temperature"], num_predict=max_tokens if max_tokens is not None else config["max_tokens"], callbacks=self.callbacks, ) def get_embeddings(self) -> OllamaEmbeddings: """ Get embedding model for vector operations. Returns: OllamaEmbeddings instance """ return self.embeddings async def ainvoke( self, messages: List[BaseMessage], complexity: Optional[ComplexityLevel] = None, **kwargs: Any, ) -> BaseMessage: """ Async invoke LLM with messages. Args: messages: List of messages for the conversation complexity: Model complexity level **kwargs: Additional arguments for the LLM Returns: AI response message """ llm = self.get_llm(complexity) response = await llm.ainvoke(messages, **kwargs) return response def invoke( self, messages: List[BaseMessage], complexity: Optional[ComplexityLevel] = None, **kwargs: Any, ) -> BaseMessage: """ Synchronous invoke LLM with messages. Args: messages: List of messages for the conversation complexity: Model complexity level **kwargs: Additional arguments for the LLM Returns: AI response message """ llm = self.get_llm(complexity) response = llm.invoke(messages, **kwargs) return response async def astream( self, messages: List[BaseMessage], complexity: Optional[ComplexityLevel] = None, **kwargs: Any, ): """ Async stream LLM responses. Args: messages: List of messages for the conversation complexity: Model complexity level **kwargs: Additional arguments for the LLM Yields: Chunks of AI response """ llm = self.get_llm(complexity) async for chunk in llm.astream(messages, **kwargs): yield chunk async def embed_text(self, text: str) -> List[float]: """ Generate embedding for text. Args: text: Text to embed Returns: Embedding vector """ embedding = await self.embeddings.aembed_query(text) return embedding async def embed_documents(self, documents: List[str]) -> List[List[float]]: """ Generate embeddings for multiple documents. Args: documents: List of documents to embed Returns: List of embedding vectors """ embeddings = await self.embeddings.aembed_documents(documents) return embeddings def get_model_info(self, complexity: Optional[ComplexityLevel] = None) -> Dict[str, Any]: """ Get information about a model. Args: complexity: Complexity level (defaults to current default) Returns: Model configuration dictionary """ complexity = complexity or self.default_complexity return self.MODEL_CONFIG.get(complexity, {}) def list_models(self) -> Dict[ComplexityLevel, Dict[str, Any]]: """ List all available models and their configurations. Returns: Dictionary mapping complexity levels to model configs """ return self.MODEL_CONFIG.copy() def get_stats(self) -> Dict[str, Any]: """ Get client statistics. Returns: Statistics dictionary """ if self.callback_handler: return self.callback_handler.get_stats() return {} def recommend_complexity(self, task_description: str) -> ComplexityLevel: """ Recommend complexity level based on task description. Uses simple heuristics to suggest appropriate model: - Keywords like "plan", "analyze", "complex" → complex - Keywords like "validate", "critique", "assess" → analysis - Keywords like "classify", "route", "simple" → simple - Default → standard Args: task_description: Natural language task description Returns: Recommended complexity level """ task_lower = task_description.lower() # Complex tasks if any(kw in task_lower for kw in ["plan", "strategy", "decompose", "workflow", "multi-step"]): return "complex" # Analysis tasks if any(kw in task_lower for kw in ["validate", "critique", "assess", "review", "quality"]): return "analysis" # Simple tasks if any(kw in task_lower for kw in ["classify", "route", "yes/no", "binary", "simple"]): return "simple" # Default to standard return "standard" # Convenience function for quick initialization def get_langchain_client( base_url: str = "http://localhost:11434", default_complexity: ComplexityLevel = "standard", enable_monitoring: bool = True, ) -> LangChainOllamaClient: """ Get a LangChain Ollama client instance. Args: base_url: Ollama server URL default_complexity: Default model complexity enable_monitoring: Enable GPU monitoring Returns: LangChainOllamaClient instance """ return LangChainOllamaClient( base_url=base_url, default_complexity=default_complexity, enable_monitoring=enable_monitoring, )