File size: 20,371 Bytes
4ac5fb6
2c31cbe
 
 
 
 
4ac5fb6
2c31cbe
3c6434b
 
2c31cbe
 
 
 
 
 
 
 
 
 
 
 
4ac5fb6
2c31cbe
 
 
 
 
 
 
 
 
 
 
 
 
 
4ac5fb6
 
 
2c31cbe
 
 
 
4ac5fb6
 
 
 
2c31cbe
 
 
 
 
4ac5fb6
 
 
2c31cbe
 
 
 
 
 
 
4ac5fb6
2c31cbe
 
 
4ac5fb6
2c31cbe
 
 
 
 
 
 
 
 
305a939
 
 
 
4ac5fb6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b3243bf
 
4ac5fb6
 
 
b3243bf
4ac5fb6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b3243bf
4ac5fb6
 
 
b3243bf
4ac5fb6
 
 
 
 
 
 
2c31cbe
4ac5fb6
 
 
 
 
 
2c31cbe
4ac5fb6
 
 
2c31cbe
 
4ac5fb6
 
 
 
 
2c31cbe
4ac5fb6
 
 
2c31cbe
 
4ac5fb6
 
 
 
01fe245
 
 
 
4ac5fb6
 
 
 
 
 
 
 
 
2c31cbe
 
4ac5fb6
2c31cbe
 
4ac5fb6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
01fe245
 
4ac5fb6
 
 
 
2c31cbe
 
4ac5fb6
2c31cbe
 
4ac5fb6
 
 
 
 
 
2c31cbe
4ac5fb6
 
 
 
 
 
 
2c31cbe
 
4ac5fb6
 
 
2c31cbe
4ac5fb6
 
 
2c31cbe
4ac5fb6
 
 
2c31cbe
4ac5fb6
 
 
2c31cbe
4ac5fb6
 
 
 
 
 
 
 
 
 
 
 
2c31cbe
4ac5fb6
 
 
2c31cbe
 
4ac5fb6
2c31cbe
 
 
 
 
 
 
 
 
4ac5fb6
 
 
2c31cbe
4ac5fb6
 
 
 
 
 
2c31cbe
4ac5fb6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2c31cbe
4ac5fb6
 
2c31cbe
4ac5fb6
 
 
2c31cbe
 
 
 
 
 
 
 
 
 
 
3c6434b
4ac5fb6
2c31cbe
 
4ac5fb6
 
2c31cbe
 
 
 
 
4ac5fb6
2c31cbe
 
 
3c6434b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3ce4ba6
 
3c6434b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4ac5fb6
3c6434b
 
4ac5fb6
 
3c6434b
 
 
 
3ce4ba6
 
 
 
 
 
2c31cbe
3c6434b
3ce4ba6
3c6434b
3ce4ba6
3c6434b
 
 
 
 
 
 
3ce4ba6
13285c7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3c6434b
 
3ce4ba6
3c6434b
3ce4ba6
3c6434b
 
 
 
 
 
3ce4ba6
3c6434b
 
3ce4ba6
3c6434b
 
 
3ce4ba6
 
 
 
 
3c6434b
 
 
3ce4ba6
 
 
3c6434b
3ce4ba6
3c6434b
3ce4ba6
2c31cbe
 
 
 
 
4ac5fb6
 
 
 
 
 
 
2c31cbe
 
4ac5fb6
2c31cbe
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
"""Model manager that uses NVIDIA API for inference."""
import asyncio
import json
from datetime import datetime
from typing import Any, AsyncGenerator, Dict, List, Optional

from openai import OpenAI, AsyncOpenAI

from config import settings
from system_prompt import DEFAULT_SYSTEM_PROMPT
from tool_client import tool_client


# Chat-template control markers that can leak into raw completions.
# ModelManager._strip_special_tokens deletes every occurrence of each one.
SPECIAL_TOKENS = ["<|im_end|>", "<|im_start|>", "<|endoftext|>", "<|startoftext|>"]


class ModelManager:
    """Singleton manager that uses NVIDIA API for inference.

    Talks to an OpenAI-compatible endpoint (``settings.NVIDIA_BASE_URL``)
    through the ``openai`` client library. On top of that it provides:

    - prompt assembly;
    - a rough character-based token budget;
    - stop-sequence truncation;
    - a JSON-based tool-calling loop driven by ``tool_client``.

    ``ModelManager()`` always returns the same object. ``__init__`` only
    populates state on the first construction.
    """

    _instance = None
    _initialized = False

    # Tool results longer than this many characters are truncated before
    # being appended to the follow-up prompt.
    MAX_TOOL_RESULT_CHARS = 50000

    def __new__(cls):
        # Hand back the one shared instance instead of creating a new object.
        if cls._instance is None:
            cls._instance = super(ModelManager, cls).__new__(cls)
        return cls._instance

    def __init__(self):
        # __init__ runs on every ModelManager() call; only initialise once.
        if self._initialized:
            return
        self._initialized = True

        self.nvidia_api_key = settings.NVIDIA_API_KEY
        self.nvidia_base_url = settings.NVIDIA_BASE_URL
        self.nvidia_model = settings.NVIDIA_MODEL
        self.n_ctx = settings.N_CTX
        self.temperature = settings.TEMPERATURE
        self.max_tokens = settings.MAX_TOKENS
        self.top_p = settings.TOP_P

        self._client = None
        self._async_client = None
        self._is_available = False
        self._last_error = None
        self._last_prompt_meta = {}
        # Tokens held back from the context window when budgeting a response.
        self._context_safety_buffer = 0
        # Floor for max_tokens, even when the prompt appears to fill the window.
        self._min_response_tokens = 64
        self._tool_client = tool_client

        # Maximum number of tool-execution rounds per streamed request.
        self.MAX_TOOL_ROUNDS = 3

    # ------------------------------------------------------------------ #
    #  Properties                                                         #
    # ------------------------------------------------------------------ #

    @property
    def is_loaded(self) -> bool:
        """True once ``load_model`` has succeeded (and until ``unload_model``)."""
        return self._is_available

    @property
    def is_available(self) -> bool:
        """True when an NVIDIA API key is configured."""
        return bool(self.nvidia_api_key)

    @property
    def last_error(self) -> Optional[str]:
        """Message describing the most recent failure, or None."""
        return self._last_error

    @property
    def last_prompt_meta(self) -> Dict[str, Any]:
        """Metadata recorded by the most recent ``build_prompt`` call."""
        return self._last_prompt_meta

    def get_max_generation_tokens_limit(self) -> int:
        """Get the maximum generation tokens limit."""
        return self.max_tokens

    def get_model_info(self) -> Dict[str, Any]:
        """Get comprehensive model information for API responses.

        The API key is masked. A suffix is shown only when at least as many
        characters as are revealed stay hidden.
        """
        key = self.nvidia_api_key
        if not key:
            masked_key = None
        elif len(key) > 16:
            masked_key = "***" + key[-8:]
        else:
            # Too short to reveal 8 characters without exposing most of it.
            masked_key = "***"

        return {
            "nvidia_api_key": masked_key,
            "nvidia_base_url": self.nvidia_base_url,
            "model_name": self.nvidia_model,
            "is_loaded": self.is_loaded,
            "is_available": self.is_available,
            "last_error": self.last_error,
            "tools_available": self._tool_client.is_available,
            "tools": self._tool_client.get_tool_names() if self._tool_client.is_available else [],
            "context_window": self.n_ctx,
            "max_generation_tokens_limit": self.max_tokens,
            "default_temperature": self.temperature,
            "default_max_tokens": self.max_tokens,
            "default_top_p": self.top_p,
        }

    # ------------------------------------------------------------------ #
    #  Client initialization                                              #
    # ------------------------------------------------------------------ #

    def _get_client(self) -> OpenAI:
        """Get or create synchronous OpenAI client."""
        if self._client is None:
            self._client = OpenAI(
                base_url=self.nvidia_base_url,
                api_key=self.nvidia_api_key
            )
        return self._client

    def _get_async_client(self) -> AsyncOpenAI:
        """Get or create asynchronous OpenAI client."""
        if self._async_client is None:
            self._async_client = AsyncOpenAI(
                base_url=self.nvidia_base_url,
                api_key=self.nvidia_api_key
            )
        return self._async_client

    # ------------------------------------------------------------------ #
    #  Model loading/unloading                                            #
    # ------------------------------------------------------------------ #

    def load_model(self) -> bool:
        """Mark the NVIDIA API as ready for use.

        No network request is made. This only checks that an API key is
        configured and that the synchronous client can be constructed.

        Returns:
            True on success. On failure, False with ``last_error`` set.
        """
        if not self.nvidia_api_key:
            self._last_error = "NVIDIA API key not configured"
            self._is_available = False
            return False

        try:
            self._get_client()
        except Exception as exc:
            self._last_error = f"NVIDIA API initialization failed: {exc}"
            self._is_available = False
            return False

        self._is_available = True
        self._last_error = None
        print(f"NVIDIA API initialized: model={self.nvidia_model}")
        return True

    def unload_model(self):
        """Close API clients.

        The synchronous client is closed explicitly. The async client cannot
        be awaited from this synchronous method, so its reference is dropped.
        """
        self._is_available = False
        client, self._client = self._client, None
        self._async_client = None
        if client is not None:
            try:
                client.close()
            except Exception as exc:  # best-effort cleanup; never fail unload
                print(f"[NVIDIA] Error closing client: {exc}")
        print("NVIDIA API connection closed")

    # ------------------------------------------------------------------ #
    #  Token estimation                                                   #
    # ------------------------------------------------------------------ #

    @staticmethod
    def estimate_tokens(text: str) -> int:
        """Rough token estimation (3 chars ≈ 1 token), never below 1."""
        return max(1, len(text) // 3)

    def count_tokens(self, text: str) -> int:
        """Count tokens in text (alias for estimate_tokens for compatibility)."""
        return self.estimate_tokens(text)

    def resolve_max_tokens(self, prompt: str, requested: Optional[int]) -> int:
        """Calculate a safe ``max_tokens`` for *prompt*.

        The room left in the context window is
        ``n_ctx - estimated prompt tokens - safety buffer``. It is floored at
        ``_min_response_tokens``. The result is the smaller of that room and
        *requested*, or of that room and the configured ``max_tokens`` when
        *requested* is None.
        """
        prompt_tokens = self.estimate_tokens(prompt)
        available = self.n_ctx - prompt_tokens - self._context_safety_buffer
        available = max(available, self._min_response_tokens)

        if requested is None:
            return min(self.max_tokens, available)
        return min(requested, available)

    # ------------------------------------------------------------------ #
    #  Prompt building                                                    #
    # ------------------------------------------------------------------ #

    def build_prompt(
        self,
        query: str,
        history: List[Dict[str, Any]] = None,
        system_prompt: str = None,
        file_content: str = None,
        custom_instructions: str = None,
        max_history_messages: int = 50,
    ) -> str:
        """Build a complete plain-text prompt.

        Sections appear in this order, separated by blank lines:

        - ``SYSTEM``;
        - ``INSTRUCTIONS``;
        - ``FILE CONTENT``;
        - the most recent *max_history_messages* history entries;
        - the ``USER`` query;
        - a trailing ``ASSISTANT:`` cue.

        History is truncated by message count only, not by token budget. A
        non-positive *max_history_messages* omits history entirely.
        Metadata about the result is stored in ``last_prompt_meta``.
        """
        history = history or []
        system = system_prompt or DEFAULT_SYSTEM_PROMPT

        # history[-0:] would be the whole list, so handle <= 0 explicitly.
        if max_history_messages > 0:
            recent_history = history[-max_history_messages:]
        else:
            recent_history = []

        sections = []

        if system:
            sections.append(f"SYSTEM: {system}")

        if custom_instructions:
            sections.append(f"INSTRUCTIONS: {custom_instructions}")

        if file_content:
            sections.append(f"FILE CONTENT:\n{file_content}")

        if recent_history:
            history_lines = ["--- Conversation History ---"]
            for msg in recent_history:
                role = msg.get("role", "user").upper()
                content = msg.get("content", "")
                history_lines.append(f"{role}: {content}")
            sections.append("\n".join(history_lines) + "\n")

        sections.append(f"USER: {query}")
        sections.append("ASSISTANT:")

        prompt = "\n\n".join(sections)

        self._last_prompt_meta = {
            "prompt_length": len(prompt),
            "estimated_tokens": self.estimate_tokens(prompt),
            "history_messages": len(history),
            "history_messages_used": len(recent_history),
            "history_messages_total": len(history),
            "timestamp": datetime.now().isoformat(),
        }

        return prompt

    # ------------------------------------------------------------------ #
    #  Text processing utilities                                          #
    # ------------------------------------------------------------------ #

    @staticmethod
    def _strip_special_tokens(text: str) -> str:
        """Remove every occurrence of each SPECIAL_TOKENS entry from *text*."""
        for token in SPECIAL_TOKENS:
            text = text.replace(token, "")
        return text

    @staticmethod
    def _find_stop_index(text: str, stop_markers: List[str]) -> int:
        """Return the index of the earliest stop-marker occurrence, or -1.

        Empty markers are ignored, since they would match everywhere.
        """
        hits = [
            pos
            for pos in (text.find(marker) for marker in stop_markers if marker)
            if pos != -1
        ]
        return min(hits) if hits else -1

    @staticmethod
    def _apply_stop_sequences(text: str, stop_markers: List[str]) -> str:
        """Truncate text at the earliest occurrence of any stop marker."""
        cut = ModelManager._find_stop_index(text, stop_markers)
        return text if cut == -1 else text[:cut]

    @staticmethod
    def _chunk_text(text: str, chunk_size: int = 10) -> List[str]:
        """Split text into consecutive *chunk_size*-character pieces.

        Not used by this class's own methods.
        """
        return [text[i:i + chunk_size] for i in range(0, len(text), chunk_size)]

    # ------------------------------------------------------------------ #
    #  Tool call extraction                                               #
    # ------------------------------------------------------------------ #

    def _extract_tool_calls(self, text: str) -> List[Dict[str, Any]]:
        """Extract tool calls from model output.

        Finds the first JSON object in *text* that has a top-level
        ``"tool_calls"`` list. An example is
        ``{"tool_calls": [{"tool": "web_search", "arguments": {...}}]}``.
        The entries of that list that are dicts with a ``"tool"`` key are
        returned. Surrounding prose and other brace-delimited text are
        tolerated.

        Returns:
            The tool-call dicts, or ``[]`` when tools are unavailable or no
            such object is found.
        """
        if not self._tool_client.is_available:
            return []

        decoder = json.JSONDecoder()
        start = text.find("{")
        while start != -1:
            try:
                data, end = decoder.raw_decode(text, start)
            except ValueError:  # includes json.JSONDecodeError
                # Not valid JSON starting at this brace; try the next one.
                start = text.find("{", start + 1)
                continue

            # Decoding from "{" always yields a dict.
            if isinstance(data.get("tool_calls"), list):
                return [
                    call
                    for call in data["tool_calls"]
                    if isinstance(call, dict) and "tool" in call
                ]
            # Only top-level "tool_calls" keys count; skip this whole object.
            start = text.find("{", end)

        return []

    # ------------------------------------------------------------------ #
    #  Tool execution helpers                                             #
    # ------------------------------------------------------------------ #

    async def _retry_web_search(
        self, tool_name: str, arguments: Dict[str, Any], result_str: str
    ) -> str:
        """Retry a failed web search with up to two simplified queries.

        Returns the first result that does not contain ``"status": "error"``.
        If every alternative fails, the last result is returned; if no
        alternative can be derived, *result_str* is returned unchanged.
        """
        original_query = arguments.get("query", "")
        print(f"[TOOL] Search failed for '{original_query}', trying simpler query...")

        alternative_queries = []

        # Drop recency words that tend to make searches come back empty.
        simplified = (
            original_query.replace("latest", "")
            .replace("today", "")
            .replace("news", "")
            .strip()
        )
        if simplified and simplified != original_query:
            alternative_queries.append(simplified)

        # Fall back to just the first two words as the main topic.
        words = original_query.split()
        if len(words) > 2:
            main_topic = " ".join(words[:2])
            if main_topic not in alternative_queries:
                alternative_queries.append(main_topic)

        for alt_query in alternative_queries[:2]:
            print(f"[TOOL] Retrying with: '{alt_query}'")
            alt_args = arguments.copy()
            alt_args["query"] = alt_query
            result_str = await self._tool_client.call_tool(tool_name, alt_args)
            if '"status": "error"' not in result_str:
                print("[TOOL] Alternative query succeeded!")
                break

        return result_str

    async def _run_tool_call(self, call: Dict[str, Any]) -> Dict[str, Any]:
        """Execute one parsed tool call.

        Returns:
            ``{"tool": name, "result": str}`` on success, or
            ``{"tool": name, "error": str}`` if the call raised.
        """
        tool_name = call.get("tool", "")
        arguments = call.get("arguments", {})
        try:
            result_str = await self._tool_client.call_tool(tool_name, arguments)
            # A failed search is detected by scanning the serialized result.
            if tool_name == "web_search" and '"status": "error"' in result_str:
                result_str = await self._retry_web_search(tool_name, arguments, result_str)
            print(f"[TOOL] {tool_name} executed successfully, result length: {len(result_str)}")
            return {"tool": tool_name, "result": result_str}
        except Exception as tool_exc:
            print(f"[TOOL] {tool_name} failed: {tool_exc}")
            return {"tool": tool_name, "error": f"Tool {tool_name} failed: {tool_exc}"}

    def _format_tool_results(self, tool_results: List[Dict[str, Any]]) -> str:
        """Render tool results as the text block appended to the prompt.

        Each result is truncated to ``MAX_TOOL_RESULT_CHARS`` characters.
        """
        parts = ["\n\n=== TOOL EXECUTION RESULTS ===\n"]
        for tr in tool_results:
            parts.append(f"\nTool: {tr['tool']}\n")
            if "result" in tr:
                result = tr["result"]
                if len(result) > self.MAX_TOOL_RESULT_CHARS:
                    result = result[:self.MAX_TOOL_RESULT_CHARS] + "\n... (truncated)"
                parts.append(f"Result:\n{result}\n")
            if "error" in tr:
                parts.append(f"Error: {tr['error']}\n")

        parts.append(
            "\n=== END TOOL RESULTS ===\n\nNow provide a helpful answer to the user based on these search results. Cite sources and be specific. Do NOT output more tool_calls JSON.\n"
        )
        return "".join(parts)

    @staticmethod
    async def _close_stream(stream: Any) -> None:
        """Best-effort close of a streaming completion's HTTP response."""
        close = getattr(stream, "close", None)
        if close is None:
            return
        try:
            await close()
        except Exception as exc:  # cleanup must not mask the real outcome
            print(f"[NVIDIA] Error closing stream: {exc}")

    # ------------------------------------------------------------------ #
    #  Generation methods                                                 #
    # ------------------------------------------------------------------ #

    def generate(
        self,
        prompt: str,
        temperature: float = None,
        max_tokens: int = None,
        top_p: float = None,
        stop: List[str] = None,
    ) -> str:
        """Generate a non-streaming response.

        The prompt is sent as a single user message. Special tokens are
        stripped, the output is cut at the earliest *stop* marker, and
        surrounding whitespace is removed.

        Returns:
            The generated text. On failure, a string starting with
            ``"Error: "``, with ``last_error`` set.
        """
        if not self._is_available:
            if not self.load_model():
                return "Error: NVIDIA API is not available."

        resolved_max_tokens = self.resolve_max_tokens(prompt, max_tokens)
        temp = self.temperature if temperature is None else float(temperature)
        top_p_val = self.top_p if top_p is None else float(top_p)

        try:
            client = self._get_client()
            response = client.chat.completions.create(
                model=self.nvidia_model,
                messages=[{"role": "user", "content": prompt}],
                temperature=temp,
                top_p=top_p_val,
                max_tokens=resolved_max_tokens,
                stream=False
            )

            text = response.choices[0].message.content or ""
            text = self._strip_special_tokens(text)

            if stop:
                text = self._apply_stop_sequences(text, stop)

            self._last_error = None
            return text.strip()

        except Exception as exc:
            self._last_error = f"Generation failed: {exc}"
            print(f"[NVIDIA] Error: {exc}")
            return f"Error: {exc}"

    async def generate_stream(
        self,
        prompt: str,
        temperature: float = None,
        max_tokens: int = None,
        top_p: float = None,
        top_k: int = None,
        stop: List[str] = None,
        stop_event: Optional[Any] = None,
    ) -> AsyncGenerator[str, None]:
        """Generate a streaming response via NVIDIA API with tool support.

        Each model round is buffered in full, so tool-call JSON is never shown
        to the user. The final text is then emitted one character at a time
        as ``{"token": ..., "finish_reason": None}`` JSON strings.

        When the buffered text contains a ``tool_calls`` object, the tools
        are executed. The model is then queried again with *prompt* plus
        the tool results. At most ``MAX_TOOL_ROUNDS`` tool rounds run. The
        generation that follows the last allowed round is always emitted as
        the final answer.

        Args:
            prompt: Fully built prompt (see ``build_prompt``).
            temperature, max_tokens, top_p: Overrides for configured defaults.
                ``max_tokens`` is re-budgeted each round against the grown
                prompt.
            top_k: Accepted for interface compatibility; not sent to the API.
            stop: Stop markers; defaults to ``["USER:", "SYSTEM:"]``.
            stop_event: Optional object with an ``is_set()`` method. When it
                is set, streaming ends with ``{"stopped": True, "done": True}``.

        Yields:
            JSON strings: token messages followed by
            ``{"finish_reason": "stop", "done": True}``. On failure, an
            object with ``"error"`` and ``"content"`` keys is yielded instead.
        """
        if not self._is_available:
            if not self.load_model():
                yield json.dumps({
                    "error": "NVIDIA API not available",
                    "content": "Error: NVIDIA API is not available.",
                })
                return

        temp = self.temperature if temperature is None else float(temperature)
        top_p_val = self.top_p if top_p is None else float(top_p)
        stop_markers = stop or ["USER:", "SYSTEM:"]

        try:
            current_prompt = prompt
            tool_round = 0

            while True:
                # Re-budget every round: appended tool results lengthen the prompt.
                resolved_max_tokens = self.resolve_max_tokens(current_prompt, max_tokens)
                client = self._get_async_client()
                stream = await client.chat.completions.create(
                    model=self.nvidia_model,
                    messages=[{"role": "user", "content": current_prompt}],
                    temperature=temp,
                    top_p=top_p_val,
                    max_tokens=resolved_max_tokens,
                    stream=True
                )

                accumulated_text = ""
                try:
                    async for chunk in stream:
                        if stop_event and getattr(stop_event, "is_set", lambda: False)():
                            yield json.dumps({"stopped": True, "done": True})
                            return

                        if not chunk.choices:
                            continue

                        choice = chunk.choices[0]
                        if choice.delta.content:
                            accumulated_text += choice.delta.content
                            # Cut at the earliest marker so no role text leaks through.
                            cut = self._find_stop_index(accumulated_text, stop_markers)
                            if cut != -1:
                                accumulated_text = accumulated_text[:cut]
                                break

                        if choice.finish_reason:
                            break
                finally:
                    # Release the HTTP response even when reading stops early.
                    await self._close_stream(stream)

                accumulated_text = self._strip_special_tokens(accumulated_text)

                # Once the tool budget is spent, this generation is the answer.
                if tool_round < self.MAX_TOOL_ROUNDS:
                    tool_calls = self._extract_tool_calls(accumulated_text)
                else:
                    tool_calls = []

                if not tool_calls:
                    for char in accumulated_text:
                        yield json.dumps({"token": char, "finish_reason": None})
                        await asyncio.sleep(0)
                    break

                tool_round += 1
                print(f"[TOOL] Executing {len(tool_calls)} tool call(s) in round {tool_round}")

                tool_results = []
                for call in tool_calls:
                    tool_results.append(await self._run_tool_call(call))

                # Each round sees the original prompt plus this round's results.
                current_prompt = prompt + self._format_tool_results(tool_results)
                print(f"[TOOL] Continuing generation with tool results, prompt length: {len(current_prompt)}")

            self._last_error = None
            yield json.dumps({"finish_reason": "stop", "done": True})

        except Exception as exc:
            self._last_error = f"Streaming generation failed: {exc}"
            print(f"[NVIDIA] Error: {exc}")
            yield json.dumps({
                "error": str(exc),
                "content": f"Error: {exc}",
                "done": True
            })


# Global singleton instance shared by everything that imports this module.
# ModelManager.__new__ caches the instance, so any later ModelManager() call
# returns this same object (and __init__ does not reinitialise it).
model_manager = ModelManager()