"""
LangChain Ollama Client for SPARKNET.

Integrates Ollama with LangChain for multi-model complexity routing.
Provides a unified interface for chat, embeddings, and GPU monitoring.
"""

from typing import Optional, Dict, Any, List, Literal

from loguru import logger
from langchain_ollama import ChatOllama, OllamaEmbeddings
from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.messages import BaseMessage
from langchain_core.outputs import LLMResult

from ..utils.gpu_manager import get_gpu_manager

ComplexityLevel = Literal["simple", "standard", "complex", "analysis"]


class SparknetCallbackHandler(BaseCallbackHandler):
    """
    Custom callback handler for SPARKNET.

    Monitors GPU usage, token counts, and latency.
    """

    def __init__(self):
        super().__init__()
        self.gpu_manager = get_gpu_manager()
        self.token_count = 0
        self.llm_calls = 0

    def on_llm_start(
        self,
        serialized: Dict[str, Any],
        prompts: List[str],
        **kwargs: Any,
    ) -> None:
        """Called when the LLM starts processing."""
        self.llm_calls += 1
        gpu_status = self.gpu_manager.monitor()
        logger.debug(f"LLM call #{self.llm_calls} started")
        # Guard against hosts with no visible GPU to avoid an IndexError.
        if gpu_status.get("gpus"):
            logger.debug(f"GPU Status: {gpu_status['gpus'][0]['memory_used']:.2f} GB used")

    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
        """Called when the LLM finishes processing."""
        if hasattr(response, "llm_output") and response.llm_output:
            token_usage = response.llm_output.get("token_usage", {})
            if token_usage:
                self.token_count += token_usage.get("total_tokens", 0)
                logger.debug(f"Tokens used: {token_usage.get('total_tokens', 0)}")

    def on_llm_error(self, error: Exception, **kwargs: Any) -> None:
        """Called when the LLM encounters an error."""
        logger.error(f"LLM error: {error}")

    def get_stats(self) -> Dict[str, Any]:
        """Get accumulated statistics."""
        return {
            "llm_calls": self.llm_calls,
            "total_tokens": self.token_count,
            "gpu_status": self.gpu_manager.monitor(),
        }
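
# Illustrative usage (a sketch, not part of the public API): the handler plugs
# into any LangChain chat model via `callbacks` and accumulates stats across
# calls. The model tag and prompt below are placeholder values:
#
#     handler = SparknetCallbackHandler()
#     llm = ChatOllama(base_url="http://localhost:11434",
#                      model="llama3.1:8b", callbacks=[handler])
#     llm.invoke("One-line summary of SPARKNET, please.")
#     print(handler.get_stats())  # {'llm_calls': 1, 'total_tokens': ..., ...}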


class LangChainOllamaClient:
    """
    LangChain-powered Ollama client with intelligent model routing.

    Manages multiple Ollama models for different complexity levels:
    - simple: Fast, lightweight tasks (gemma2:2b)
    - standard: General-purpose tasks (llama3.1:8b)
    - complex: Advanced reasoning and planning (qwen2.5:14b)
    - analysis: Critical analysis and validation (mistral:latest)

    Features:
    - Automatic model selection based on task complexity
    - GPU monitoring via custom callbacks
    - Embedding generation for vector search
    - Streaming and non-streaming support
    """

    MODEL_CONFIG: Dict[ComplexityLevel, Dict[str, Any]] = {
        "simple": {
            "model": "gemma2:2b",
            "temperature": 0.3,
            "max_tokens": 512,
            "description": "Fast classification, routing, simple Q&A",
            "size_gb": 1.6,
        },
        "standard": {
            "model": "llama3.1:8b",
            "temperature": 0.7,
            "max_tokens": 1024,
            "description": "General tasks, code generation, summarization",
            "size_gb": 4.9,
        },
        "complex": {
            "model": "qwen2.5:14b",
            "temperature": 0.7,
            "max_tokens": 2048,
            "description": "Complex reasoning, planning, multi-step tasks",
            "size_gb": 9.0,
        },
        "analysis": {
            "model": "mistral:latest",
            "temperature": 0.6,
            "max_tokens": 1024,
            "description": "Critical analysis, validation, quality assessment",
            "size_gb": 4.4,
        },
    }
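
    # The model tags above are this project's defaults; any model served by
    # the local Ollama instance can be swapped in, e.g. (illustrative only):
    #
    #     LangChainOllamaClient.MODEL_CONFIG["standard"]["model"] = "llama3.2:3b"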

    def __init__(
        self,
        base_url: str = "http://localhost:11434",
        default_complexity: ComplexityLevel = "standard",
        enable_monitoring: bool = True,
    ):
        """
        Initialize the LangChain Ollama client.

        Args:
            base_url: Ollama server URL
            default_complexity: Default model complexity level
            enable_monitoring: Enable GPU monitoring callbacks
        """
        self.base_url = base_url
        self.default_complexity = default_complexity
        self.enable_monitoring = enable_monitoring

        # Optional GPU/token monitoring via the custom callback handler.
        self.callback_handler = SparknetCallbackHandler() if enable_monitoring else None
        self.callbacks = [self.callback_handler] if self.callback_handler else []

        # One ChatOllama instance per complexity level.
        self.llms: Dict[ComplexityLevel, ChatOllama] = {}
        self._initialize_models()

        # Embedding model for vector search.
        self.embeddings = OllamaEmbeddings(
            base_url=base_url,
            model="nomic-embed-text:latest",
        )

        logger.info(f"Initialized LangChainOllamaClient with {len(self.llms)} models")
        logger.info(f"Default complexity: {default_complexity}")

    def _initialize_models(self) -> None:
        """Initialize a ChatOllama instance for each complexity level."""
        for complexity, config in self.MODEL_CONFIG.items():
            try:
                self.llms[complexity] = ChatOllama(
                    base_url=self.base_url,
                    model=config["model"],
                    temperature=config["temperature"],
                    num_predict=config["max_tokens"],
                    callbacks=self.callbacks,
                )
                logger.debug(f"Initialized {complexity} model: {config['model']}")
            except Exception as e:
                logger.error(f"Failed to initialize {complexity} model: {e}")

    def get_llm(
        self,
        complexity: Optional[ComplexityLevel] = None,
        temperature: Optional[float] = None,
        max_tokens: Optional[int] = None,
    ) -> ChatOllama:
        """
        Get the LLM for the specified complexity level.

        Args:
            complexity: Complexity level (simple, standard, complex, analysis)
            temperature: Override the default temperature
            max_tokens: Override the default max tokens

        Returns:
            ChatOllama instance
        """
        complexity = complexity or self.default_complexity

        if complexity not in self.llms:
            logger.warning(f"Unknown complexity '{complexity}', using default")
            complexity = self.default_complexity

        # Fast path: no overrides, so reuse the cached instance.
        if temperature is None and max_tokens is None:
            return self.llms[complexity]

        # Overrides requested: build a fresh instance from the base config.
        config = self.MODEL_CONFIG[complexity]
        return ChatOllama(
            base_url=self.base_url,
            model=config["model"],
            temperature=temperature if temperature is not None else config["temperature"],
            num_predict=max_tokens if max_tokens is not None else config["max_tokens"],
            callbacks=self.callbacks,
        )
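
    # Illustrative override (a sketch; values are arbitrary): request the
    # "complex" model with a lower temperature and a tighter output cap:
    #
    #     llm = client.get_llm("complex", temperature=0.2, max_tokens=256)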

    def get_embeddings(self) -> OllamaEmbeddings:
        """
        Get the embedding model for vector operations.

        Returns:
            OllamaEmbeddings instance
        """
        return self.embeddings

    async def ainvoke(
        self,
        messages: List[BaseMessage],
        complexity: Optional[ComplexityLevel] = None,
        **kwargs: Any,
    ) -> BaseMessage:
        """
        Asynchronously invoke the LLM with messages.

        Args:
            messages: List of messages for the conversation
            complexity: Model complexity level
            **kwargs: Additional arguments for the LLM

        Returns:
            AI response message
        """
        llm = self.get_llm(complexity)
        return await llm.ainvoke(messages, **kwargs)

    def invoke(
        self,
        messages: List[BaseMessage],
        complexity: Optional[ComplexityLevel] = None,
        **kwargs: Any,
    ) -> BaseMessage:
        """
        Synchronously invoke the LLM with messages.

        Args:
            messages: List of messages for the conversation
            complexity: Model complexity level
            **kwargs: Additional arguments for the LLM

        Returns:
            AI response message
        """
        llm = self.get_llm(complexity)
        return llm.invoke(messages, **kwargs)
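
    # Illustrative call (a sketch; `client` and the prompt are placeholders):
    #
    #     from langchain_core.messages import HumanMessage, SystemMessage
    #     reply = client.invoke(
    #         [SystemMessage(content="You are concise."),
    #          HumanMessage(content="Is 'login fails' a bug report? yes/no")],
    #         complexity="simple",
    #     )
    #     print(reply.content)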

    async def astream(
        self,
        messages: List[BaseMessage],
        complexity: Optional[ComplexityLevel] = None,
        **kwargs: Any,
    ):
        """
        Asynchronously stream LLM responses.

        Args:
            messages: List of messages for the conversation
            complexity: Model complexity level
            **kwargs: Additional arguments for the LLM

        Yields:
            Chunks of the AI response
        """
        llm = self.get_llm(complexity)
        async for chunk in llm.astream(messages, **kwargs):
            yield chunk
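
    # Streaming consumption (a sketch, assuming an asyncio event loop):
    #
    #     async for chunk in client.astream(messages, complexity="standard"):
    #         print(chunk.content, end="", flush=True)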

    async def embed_text(self, text: str) -> List[float]:
        """
        Generate an embedding for a single text.

        Args:
            text: Text to embed

        Returns:
            Embedding vector
        """
        return await self.embeddings.aembed_query(text)

    async def embed_documents(self, documents: List[str]) -> List[List[float]]:
        """
        Generate embeddings for multiple documents.

        Args:
            documents: List of documents to embed

        Returns:
            List of embedding vectors
        """
        return await self.embeddings.aembed_documents(documents)
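
    # Embedding usage (a sketch): both helpers route through nomic-embed-text.
    #
    #     vec = await client.embed_text("hello world")             # one vector
    #     vecs = await client.embed_documents(["doc a", "doc b"])  # one per doc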

    def get_model_info(self, complexity: Optional[ComplexityLevel] = None) -> Dict[str, Any]:
        """
        Get information about a model.

        Args:
            complexity: Complexity level (defaults to the client default)

        Returns:
            Model configuration dictionary
        """
        complexity = complexity or self.default_complexity
        return self.MODEL_CONFIG.get(complexity, {})

    def list_models(self) -> Dict[ComplexityLevel, Dict[str, Any]]:
        """
        List all available models and their configurations.

        Returns:
            Dictionary mapping complexity levels to model configs
        """
        return self.MODEL_CONFIG.copy()

    def get_stats(self) -> Dict[str, Any]:
        """
        Get client statistics.

        Returns:
            Statistics dictionary (empty when monitoring is disabled)
        """
        if self.callback_handler:
            return self.callback_handler.get_stats()
        return {}

    def recommend_complexity(self, task_description: str) -> ComplexityLevel:
        """
        Recommend a complexity level based on a task description.

        Uses simple keyword heuristics to suggest an appropriate model:
        - Planning keywords ("plan", "strategy", "workflow", ...) -> complex
        - Evaluation keywords ("validate", "critique", "assess", ...) -> analysis
        - Routing keywords ("classify", "route", "simple", ...) -> simple
        - Anything else -> standard

        Args:
            task_description: Natural language task description

        Returns:
            Recommended complexity level
        """
        task_lower = task_description.lower()

        # Planning and multi-step reasoning -> complex
        if any(kw in task_lower for kw in ["plan", "strategy", "decompose", "workflow", "multi-step"]):
            return "complex"

        # Validation and quality assessment -> analysis
        if any(kw in task_lower for kw in ["validate", "critique", "assess", "review", "quality"]):
            return "analysis"

        # Lightweight classification and routing -> simple
        if any(kw in task_lower for kw in ["classify", "route", "yes/no", "binary", "simple"]):
            return "simple"

        return "standard"


def get_langchain_client(
    base_url: str = "http://localhost:11434",
    default_complexity: ComplexityLevel = "standard",
    enable_monitoring: bool = True,
) -> LangChainOllamaClient:
    """
    Get a LangChain Ollama client instance.

    Args:
        base_url: Ollama server URL
        default_complexity: Default model complexity
        enable_monitoring: Enable GPU monitoring

    Returns:
        LangChainOllamaClient instance
    """
    return LangChainOllamaClient(
        base_url=base_url,
        default_complexity=default_complexity,
        enable_monitoring=enable_monitoring,
    )
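

# Minimal end-to-end sketch (an illustrative demo, not part of the API): it
# assumes a local Ollama server with the configured models already pulled.
if __name__ == "__main__":
    from langchain_core.messages import HumanMessage

    client = get_langchain_client()
    level = client.recommend_complexity("Classify this request: yes/no")
    reply = client.invoke([HumanMessage(content="Say hello in five words.")], complexity=level)
    print(f"[{level}] {reply.content}")
    print(client.get_stats())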