"""
LangChain Ollama Client for SPARKNET
Integrates Ollama with LangChain for multi-model complexity routing
Provides unified interface for chat, embeddings, and GPU monitoring
"""

from typing import Optional, Dict, Any, List, Literal, AsyncIterator
from loguru import logger
from langchain_ollama import ChatOllama, OllamaEmbeddings
from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.messages import BaseMessage
from langchain_core.outputs import LLMResult

from ..utils.gpu_manager import get_gpu_manager


# Type alias for complexity levels
ComplexityLevel = Literal["simple", "standard", "complex", "analysis"]


class SparknetCallbackHandler(BaseCallbackHandler):
    """
    Custom callback handler for SPARKNET.
    Monitors GPU usage, token counts, and latency.
    """

    def __init__(self):
        super().__init__()
        self.gpu_manager = get_gpu_manager()
        self.token_count = 0
        self.llm_calls = 0

    def on_llm_start(
        self,
        serialized: Dict[str, Any],
        prompts: List[str],
        **kwargs: Any
    ) -> None:
        """Called when LLM starts processing."""
        self.llm_calls += 1
        gpu_status = self.gpu_manager.monitor()
        logger.debug(f"LLM call #{self.llm_calls} started")
        # Guard against hosts with no visible GPU to avoid an IndexError
        gpus = gpu_status.get('gpus') or []
        if gpus:
            logger.debug(f"GPU Status: {gpus[0]['memory_used']:.2f} GB used")

    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
        """Called when LLM finishes processing."""
        # Prefer provider-level token usage when the backend reports it
        total = 0
        if response.llm_output:
            total = (response.llm_output.get('token_usage') or {}).get('total_tokens', 0)
        if not total:
            # Fall back to per-message usage_metadata (where ChatOllama reports usage)
            for generation in (gen for gens in response.generations for gen in gens):
                usage = getattr(getattr(generation, 'message', None), 'usage_metadata', None)
                if usage:
                    total += usage.get('total_tokens', 0)
        if total:
            self.token_count += total
            logger.debug(f"Tokens used: {total}")

    def on_llm_error(self, error: BaseException, **kwargs: Any) -> None:
        """Called when LLM encounters an error."""
        logger.error(f"LLM error: {error}")

    def get_stats(self) -> Dict[str, Any]:
        """Get accumulated statistics."""
        return {
            'llm_calls': self.llm_calls,
            'total_tokens': self.token_count,
            'gpu_status': self.gpu_manager.monitor(),
        }


class LangChainOllamaClient:
    """
    LangChain-powered Ollama client with intelligent model routing.

    Manages multiple Ollama models for different complexity levels:
    - simple: Fast, lightweight tasks (gemma2:2b)
    - standard: General-purpose tasks (llama3.1:8b)
    - complex: Advanced reasoning and planning (qwen2.5:14b)
    - analysis: Critical analysis and validation (mistral:latest)

    Features:
    - Automatic model selection based on task complexity
    - GPU monitoring via custom callbacks
    - Embedding generation for vector search
    - Streaming and non-streaming support
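
    Example (illustrative; assumes the listed models are pulled in Ollama and
    HumanMessage is imported from langchain_core.messages):
        >>> client = LangChainOllamaClient()
        >>> reply = client.invoke([HumanMessage(content="Summarize these notes")],
        ...                       complexity="standard")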
    """

    # Model configuration for each complexity level
    MODEL_CONFIG: Dict[ComplexityLevel, Dict[str, Any]] = {
        "simple": {
            "model": "gemma2:2b",
            "temperature": 0.3,
            "max_tokens": 512,
            "description": "Fast classification, routing, simple Q&A",
            "size_gb": 1.6,
        },
        "standard": {
            "model": "llama3.1:8b",
            "temperature": 0.7,
            "max_tokens": 1024,
            "description": "General tasks, code generation, summarization",
            "size_gb": 4.9,
        },
        "complex": {
            "model": "qwen2.5:14b",
            "temperature": 0.7,
            "max_tokens": 2048,
            "description": "Complex reasoning, planning, multi-step tasks",
            "size_gb": 9.0,
        },
        "analysis": {
            "model": "mistral:latest",
            "temperature": 0.6,
            "max_tokens": 1024,
            "description": "Critical analysis, validation, quality assessment",
            "size_gb": 4.4,
        },
    }

    def __init__(
        self,
        base_url: str = "http://localhost:11434",
        default_complexity: ComplexityLevel = "standard",
        enable_monitoring: bool = True,
    ):
        """
        Initialize LangChain Ollama client.

        Args:
            base_url: Ollama server URL
            default_complexity: Default model complexity level
            enable_monitoring: Enable GPU monitoring callbacks
        """
        self.base_url = base_url
        self.default_complexity = default_complexity
        self.enable_monitoring = enable_monitoring

        # Initialize callback handler
        self.callback_handler = SparknetCallbackHandler() if enable_monitoring else None
        self.callbacks = [self.callback_handler] if self.callback_handler else []

        # Initialize LLMs for each complexity level
        self.llms: Dict[ComplexityLevel, ChatOllama] = {}
        self._initialize_models()

        # Initialize embedding model
        self.embeddings = OllamaEmbeddings(
            base_url=base_url,
            model="nomic-embed-text:latest",
        )

        logger.info(f"Initialized LangChainOllamaClient with {len(self.llms)} models")
        logger.info(f"Default complexity: {default_complexity}")

    def _initialize_models(self) -> None:
        """Initialize ChatOllama instances for each complexity level."""
        for complexity, config in self.MODEL_CONFIG.items():
            try:
                self.llms[complexity] = ChatOllama(
                    base_url=self.base_url,
                    model=config["model"],
                    temperature=config["temperature"],
                    num_predict=config["max_tokens"],
                    callbacks=self.callbacks,
                )
                logger.debug(f"Initialized {complexity} model: {config['model']}")
            except Exception as e:
                logger.error(f"Failed to initialize {complexity} model: {e}")

    def get_llm(
        self,
        complexity: Optional[ComplexityLevel] = None,
        temperature: Optional[float] = None,
        max_tokens: Optional[int] = None,
    ) -> ChatOllama:
        """
        Get LLM for specified complexity level.

        Args:
            complexity: Complexity level (simple, standard, complex, analysis)
            temperature: Override default temperature
            max_tokens: Override default max tokens

        Returns:
            ChatOllama instance
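
        Example (illustrative):
            >>> llm = client.get_llm("simple")                   # cached instance
            >>> llm = client.get_llm("simple", temperature=0.0)  # fresh instance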
        """
        complexity = complexity or self.default_complexity

        if complexity not in self.llms:
            logger.warning(f"No initialized model for '{complexity}', falling back to default")
            complexity = self.default_complexity
            if complexity not in self.llms:
                raise RuntimeError(f"Default model '{self.default_complexity}' is unavailable")

        # If no overrides, return cached instance
        if temperature is None and max_tokens is None:
            return self.llms[complexity]

        # Create new instance with overridden parameters
        config = self.MODEL_CONFIG[complexity]
        return ChatOllama(
            base_url=self.base_url,
            model=config["model"],
            temperature=temperature if temperature is not None else config["temperature"],
            num_predict=max_tokens if max_tokens is not None else config["max_tokens"],
            callbacks=self.callbacks,
        )

    def get_embeddings(self) -> OllamaEmbeddings:
        """
        Get embedding model for vector operations.

        Returns:
            OllamaEmbeddings instance
        """
        return self.embeddings

    async def ainvoke(
        self,
        messages: List[BaseMessage],
        complexity: Optional[ComplexityLevel] = None,
        **kwargs: Any,
    ) -> BaseMessage:
        """
        Asynchronously invoke the LLM with messages.

        Args:
            messages: List of messages for the conversation
            complexity: Model complexity level
            **kwargs: Additional arguments for the LLM

        Returns:
            AI response message
        """
        llm = self.get_llm(complexity)
        response = await llm.ainvoke(messages, **kwargs)
        return response

    def invoke(
        self,
        messages: List[BaseMessage],
        complexity: Optional[ComplexityLevel] = None,
        **kwargs: Any,
    ) -> BaseMessage:
        """
        Synchronously invoke the LLM with messages.

        Args:
            messages: List of messages for the conversation
            complexity: Model complexity level
            **kwargs: Additional arguments for the LLM

        Returns:
            AI response message
        """
        llm = self.get_llm(complexity)
        response = llm.invoke(messages, **kwargs)
        return response

    async def astream(
        self,
        messages: List[BaseMessage],
        complexity: Optional[ComplexityLevel] = None,
        **kwargs: Any,
    ) -> AsyncIterator[BaseMessage]:
        """
        Async stream LLM responses.

        Args:
            messages: List of messages for the conversation
            complexity: Model complexity level
            **kwargs: Additional arguments for the LLM

        Yields:
            Chunks of AI response
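
        Example (illustrative; HumanMessage from langchain_core.messages):
            async for chunk in client.astream([HumanMessage(content="Hello")]):
                print(chunk.content, end="", flush=True)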
        """
        llm = self.get_llm(complexity)
        async for chunk in llm.astream(messages, **kwargs):
            yield chunk

    async def embed_text(self, text: str) -> List[float]:
        """
        Generate embedding for text.

        Args:
            text: Text to embed

        Returns:
            Embedding vector
        """
        embedding = await self.embeddings.aembed_query(text)
        return embedding

    async def embed_documents(self, documents: List[str]) -> List[List[float]]:
        """
        Generate embeddings for multiple documents.

        Args:
            documents: List of documents to embed

        Returns:
            List of embedding vectors
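
        Example (illustrative; assumes nomic-embed-text:latest is pulled):
            vectors = await client.embed_documents(["first doc", "second doc"])
            assert len(vectors) == 2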
        """
        embeddings = await self.embeddings.aembed_documents(documents)
        return embeddings

    def get_model_info(self, complexity: Optional[ComplexityLevel] = None) -> Dict[str, Any]:
        """
        Get information about a model.

        Args:
            complexity: Complexity level (defaults to current default)

        Returns:
            Model configuration dictionary
        """
        complexity = complexity or self.default_complexity
        return self.MODEL_CONFIG.get(complexity, {})

    def list_models(self) -> Dict[ComplexityLevel, Dict[str, Any]]:
        """
        List all available models and their configurations.

        Returns:
            Dictionary mapping complexity levels to model configs
        """
        return self.MODEL_CONFIG.copy()

    def get_stats(self) -> Dict[str, Any]:
        """
        Get client statistics.

        Returns:
            Statistics dictionary
        """
        if self.callback_handler:
            return self.callback_handler.get_stats()
        return {}

    def recommend_complexity(self, task_description: str) -> ComplexityLevel:
        """
        Recommend complexity level based on task description.

        Uses simple heuristics to suggest appropriate model:
        - Keywords like "plan", "analyze", "complex" → complex
        - Keywords like "validate", "critique", "assess" → analysis
        - Keywords like "classify", "route", "simple" → simple
        - Default → standard

        Args:
            task_description: Natural language task description

        Returns:
            Recommended complexity level
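
        Example:
            >>> client.recommend_complexity("Plan a multi-step deployment workflow")
            'complex'
            >>> client.recommend_complexity("Translate this sentence")
            'standard'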
        """
        task_lower = task_description.lower()

        # Complex tasks
        if any(kw in task_lower for kw in ["plan", "strategy", "decompose", "workflow", "multi-step"]):
            return "complex"

        # Analysis tasks
        if any(kw in task_lower for kw in ["validate", "critique", "assess", "review", "quality"]):
            return "analysis"

        # Simple tasks
        if any(kw in task_lower for kw in ["classify", "route", "yes/no", "binary", "simple"]):
            return "simple"

        # Default to standard
        return "standard"


# Convenience function for quick initialization
def get_langchain_client(
    base_url: str = "http://localhost:11434",
    default_complexity: ComplexityLevel = "standard",
    enable_monitoring: bool = True,
) -> LangChainOllamaClient:
    """
    Get a LangChain Ollama client instance.

    Args:
        base_url: Ollama server URL
        default_complexity: Default model complexity
        enable_monitoring: Enable GPU monitoring

    Returns:
        LangChainOllamaClient instance
    """
    return LangChainOllamaClient(
        base_url=base_url,
        default_complexity=default_complexity,
        enable_monitoring=enable_monitoring,
    )
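

# Illustrative usage sketch (assumes an Ollama server at http://localhost:11434
# with the models in MODEL_CONFIG already pulled, e.g. via `ollama pull gemma2:2b`):
if __name__ == "__main__":
    from langchain_core.messages import HumanMessage

    client = get_langchain_client()

    # Pick a model tier from a heuristic look at the task wording
    task = "Validate the quality of this summary"
    level = client.recommend_complexity(task)  # -> "analysis"

    reply = client.invoke([HumanMessage(content=task)], complexity=level)
    print(reply.content)

    # Accumulated monitoring stats (empty dict if monitoring is disabled)
    print(client.get_stats())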