| """Model manager that uses NVIDIA API for inference.""" |
| import asyncio |
| import json |
| from datetime import datetime |
| from typing import Any, AsyncGenerator, Dict, List, Optional |
|
|
| from openai import OpenAI, AsyncOpenAI |
|
|
| from config import settings |
| from system_prompt import DEFAULT_SYSTEM_PROMPT |
| from tool_client import tool_client |
|
|
|
|
# Chat-template control tokens removed from generated text by
# ModelManager._strip_special_tokens (applied in order, each via str.replace).
SPECIAL_TOKENS = ["<|im_end|>", "<|im_start|>", "<|endoftext|>", "<|startoftext|>"]
|
|
|
|
class ModelManager:
    """Singleton manager that uses the NVIDIA API for inference.

    Requests go through the OpenAI-compatible ``OpenAI`` / ``AsyncOpenAI``
    clients pointed at ``settings.NVIDIA_BASE_URL``. Every ``ModelManager()``
    call returns the same instance; ``__init__`` only reads settings once.
    """

    _instance = None
    _initialized = False

    def __new__(cls):
        # Singleton: create the instance on first use, then always return it.
        if cls._instance is None:
            cls._instance = super(ModelManager, cls).__new__(cls)
        return cls._instance

    def __init__(self):
        # __init__ runs on every ModelManager() call; configure only once.
        if self._initialized:
            return
        self._initialized = True

        self.nvidia_api_key = settings.NVIDIA_API_KEY
        self.nvidia_base_url = settings.NVIDIA_BASE_URL
        self.nvidia_model = settings.NVIDIA_MODEL
        self.n_ctx = settings.N_CTX
        self.temperature = settings.TEMPERATURE
        self.max_tokens = settings.MAX_TOKENS
        self.top_p = settings.TOP_P

        # Clients are created lazily by _get_client / _get_async_client.
        self._client = None
        self._async_client = None
        self._is_available = False
        self._last_error = None
        self._last_prompt_meta = {}
        # Tokens held back from n_ctx when sizing max_tokens (none by default).
        self._context_safety_buffer = 0
        # Floor applied to the token budget computed in resolve_max_tokens.
        self._min_response_tokens = 64
        self._tool_client = tool_client

        # Maximum number of tool-execution rounds per streamed request.
        self.MAX_TOOL_ROUNDS = 3

    # ------------------------------------------------------------------
    # Status / info
    # ------------------------------------------------------------------

    @property
    def is_loaded(self) -> bool:
        """True once load_model() has succeeded (and unload_model() not called)."""
        return self._is_available

    @property
    def is_available(self) -> bool:
        """True when an API key is configured (no connectivity check)."""
        return bool(self.nvidia_api_key)

    @property
    def last_error(self) -> Optional[str]:
        """Message of the most recent failure, or None after a success."""
        return self._last_error

    @property
    def last_prompt_meta(self) -> Dict[str, Any]:
        """Metadata recorded by the most recent build_prompt() call."""
        return self._last_prompt_meta

    def get_max_generation_tokens_limit(self) -> int:
        """Get the maximum generation tokens limit."""
        return self.max_tokens

    def get_model_info(self) -> Dict[str, Any]:
        """Get comprehensive model information for API responses.

        The API key is masked: only its last 8 characters are included.
        """
        return {
            "nvidia_api_key": "***" + self.nvidia_api_key[-8:] if self.nvidia_api_key else None,
            "nvidia_base_url": self.nvidia_base_url,
            "model_name": self.nvidia_model,
            "is_loaded": self.is_loaded,
            "is_available": self.is_available,
            "last_error": self.last_error,
            "tools_available": self._tool_client.is_available,
            "tools": self._tool_client.get_tool_names() if self._tool_client.is_available else [],
            "context_window": self.n_ctx,
            "max_generation_tokens_limit": self.max_tokens,
            "default_temperature": self.temperature,
            "default_max_tokens": self.max_tokens,
            "default_top_p": self.top_p,
        }

    # ------------------------------------------------------------------
    # Clients / lifecycle
    # ------------------------------------------------------------------

    def _get_client(self) -> "OpenAI":
        """Get or lazily create the synchronous OpenAI-compatible client."""
        if self._client is None:
            self._client = OpenAI(
                base_url=self.nvidia_base_url,
                api_key=self.nvidia_api_key,
            )
        return self._client

    def _get_async_client(self) -> "AsyncOpenAI":
        """Get or lazily create the asynchronous OpenAI-compatible client."""
        if self._async_client is None:
            self._async_client = AsyncOpenAI(
                base_url=self.nvidia_base_url,
                api_key=self.nvidia_api_key,
            )
        return self._async_client

    def load_model(self) -> bool:
        """Mark the NVIDIA API as usable if a key is configured.

        This only constructs the synchronous client; it does not send a test
        request, so a wrong key or URL is only detected on first generation.

        Returns:
            True on success; False (with ``last_error`` set) otherwise.
        """
        if not self.nvidia_api_key:
            self._last_error = "NVIDIA API key not configured"
            self._is_available = False
            return False

        try:
            self._get_client()
            self._is_available = True
            self._last_error = None
            print(f"NVIDIA API initialized: model={self.nvidia_model}")
            return True
        except Exception as exc:
            self._last_error = f"NVIDIA API initialization failed: {exc}"
            self._is_available = False
            return False

    def unload_model(self):
        """Close API clients and mark the API as unavailable.

        The synchronous client is closed explicitly. The async client's
        ``close()`` is a coroutine that cannot be awaited from this sync
        method, so that reference is only dropped.
        """
        self._is_available = False
        client, self._client = self._client, None
        if client is not None:
            try:
                client.close()
            except Exception as exc:
                # Best-effort cleanup: never let shutdown fail on close().
                print(f"[NVIDIA] Error closing client: {exc}")
        self._async_client = None
        print("NVIDIA API connection closed")

    # ------------------------------------------------------------------
    # Token budgeting
    # ------------------------------------------------------------------

    @staticmethod
    def estimate_tokens(text: str) -> int:
        """Rough token estimation (3 chars ≈ 1 token), never below 1."""
        return max(1, len(text) // 3)

    def count_tokens(self, text: str) -> int:
        """Count tokens in text (alias for estimate_tokens for compatibility)."""
        return self.estimate_tokens(text)

    def resolve_max_tokens(self, prompt: str, requested: Optional[int]) -> int:
        """Calculate a safe max_tokens for ``prompt``.

        The budget is ``n_ctx`` minus the estimated prompt size and the safety
        buffer, floored at ``_min_response_tokens``. The result is
        ``min(requested or self.max_tokens, budget)``.
        """
        prompt_tokens = self.estimate_tokens(prompt)
        available = self.n_ctx - prompt_tokens - self._context_safety_buffer
        available = max(available, self._min_response_tokens)

        if requested is None:
            return min(self.max_tokens, available)
        return min(requested, available)

    # ------------------------------------------------------------------
    # Prompt construction
    # ------------------------------------------------------------------

    def build_prompt(
        self,
        query: str,
        history: Optional[List[Dict[str, Any]]] = None,
        system_prompt: Optional[str] = None,
        file_content: Optional[str] = None,
        custom_instructions: Optional[str] = None,
        max_history_messages: int = 50,
    ) -> str:
        """Build a plain-text prompt from system text, context, history and query.

        Sections (each only if non-empty) are joined with blank lines:
        ``SYSTEM:``, ``INSTRUCTIONS:``, ``FILE CONTENT:``, the last
        ``max_history_messages`` history entries as ``ROLE: content`` lines,
        then ``USER: <query>`` and a trailing ``ASSISTANT:`` cue.

        Args:
            query: Current user message.
            history: Message dicts with optional ``role`` (default "user") and
                ``content`` (default "") keys.
            system_prompt: Overrides ``DEFAULT_SYSTEM_PROMPT`` when truthy.
            file_content: Optional attached file text.
            custom_instructions: Optional extra instructions.
            max_history_messages: Number of most recent history entries to
                include; values <= 0 include no history.

        Returns:
            The prompt string. Also records ``last_prompt_meta``.
        """
        history = history or []
        system = system_prompt or DEFAULT_SYSTEM_PROMPT

        # history[-0:] would be the *whole* list, so handle <= 0 explicitly.
        recent = history[-max_history_messages:] if max_history_messages > 0 else []

        sections = []
        if system:
            sections.append(f"SYSTEM: {system}")
        if custom_instructions:
            sections.append(f"INSTRUCTIONS: {custom_instructions}")
        if file_content:
            sections.append(f"FILE CONTENT:\n{file_content}")
        if recent:
            lines = "".join(
                f"{msg.get('role', 'user').upper()}: {msg.get('content', '')}\n"
                for msg in recent
            )
            sections.append("--- Conversation History ---\n" + lines)
        sections.append(f"USER: {query}")
        sections.append("ASSISTANT:")

        prompt = "\n\n".join(sections)

        self._last_prompt_meta = {
            "prompt_length": len(prompt),
            "estimated_tokens": self.estimate_tokens(prompt),
            "history_messages": len(history),
            "history_messages_used": len(recent),
            "history_messages_total": len(history),
            "timestamp": datetime.now().isoformat(),
        }

        return prompt

    # ------------------------------------------------------------------
    # Text post-processing
    # ------------------------------------------------------------------

    @staticmethod
    def _strip_special_tokens(text: str) -> str:
        """Remove every token in SPECIAL_TOKENS from generated text."""
        for token in SPECIAL_TOKENS:
            text = text.replace(token, "")
        return text

    @staticmethod
    def _apply_stop_sequences(text: str, stop_markers: List[str]) -> str:
        """Truncate text at the earliest occurrence of any stop marker.

        Empty markers are ignored (they would match everywhere).
        """
        cut = len(text)
        for marker in stop_markers:
            if not marker:
                continue
            idx = text.find(marker)
            if idx != -1 and idx < cut:
                cut = idx
        return text[:cut]

    @staticmethod
    def _chunk_text(text: str, chunk_size: int = 10) -> List[str]:
        """Split text into consecutive chunks of at most ``chunk_size`` chars."""
        return [text[i:i + chunk_size] for i in range(0, len(text), chunk_size)]

    # ------------------------------------------------------------------
    # Tool handling
    # ------------------------------------------------------------------

    def _extract_tool_calls(self, text: str) -> List[Dict[str, Any]]:
        """Extract tool calls from model output.

        Parses the span from the first ``{`` to the last ``}`` as JSON and
        returns the dict entries of its ``tool_calls`` list that have a
        ``tool`` key. Returns [] when tools are unavailable or parsing fails.
        """
        if not self._tool_client.is_available:
            return []

        tool_calls = []
        try:
            start_idx = text.find("{")
            end_idx = text.rfind("}")
            if start_idx != -1 and end_idx != -1:
                data = json.loads(text[start_idx:end_idx + 1])
                if isinstance(data.get("tool_calls"), list):
                    for call in data["tool_calls"]:
                        if isinstance(call, dict) and "tool" in call:
                            tool_calls.append(call)
        except (json.JSONDecodeError, ValueError):
            # Ordinary prose (or malformed JSON) simply means "no tool calls".
            pass

        return tool_calls

    async def _retry_web_search(self, arguments: Dict[str, Any], result_str: str) -> str:
        """Retry a failed web_search with up to two simplified queries.

        Candidates: the query with "latest"/"today"/"news" removed, then its
        first two words (only if the query has more than two words). Returns
        the first non-error result, or the last result obtained.
        """
        original_query = arguments.get("query", "")
        print(f"[TOOL] Search failed for '{original_query}', trying simpler query...")

        alternative_queries = []
        simplified = original_query.replace("latest", "").replace("today", "").replace("news", "").strip()
        if simplified and simplified != original_query:
            alternative_queries.append(simplified)

        words = original_query.split()
        if len(words) > 2:
            main_topic = " ".join(words[:2])
            if main_topic not in alternative_queries:
                alternative_queries.append(main_topic)

        for alt_query in alternative_queries[:2]:
            print(f"[TOOL] Retrying with: '{alt_query}'")
            alt_args = arguments.copy()
            alt_args["query"] = alt_query
            result_str = await self._tool_client.call_tool("web_search", alt_args)
            if '"status": "error"' not in result_str:
                print("[TOOL] Alternative query succeeded!")
                break

        return result_str

    async def _execute_tool_calls(self, tool_calls: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Run each tool call in order; each entry gets a "result" or an "error"."""
        tool_results = []
        for call in tool_calls:
            tool_name = call.get("tool", "")
            arguments = call.get("arguments", {})
            try:
                result_str = await self._tool_client.call_tool(tool_name, arguments)
                if tool_name == "web_search" and '"status": "error"' in result_str:
                    result_str = await self._retry_web_search(arguments, result_str)
                tool_results.append({"tool": tool_name, "result": result_str})
                print(f"[TOOL] {tool_name} executed successfully, result length: {len(result_str)}")
            except Exception as tool_exc:
                # Report the failure to the model instead of aborting the stream.
                tool_results.append({"tool": tool_name, "error": f"Tool {tool_name} failed: {tool_exc}"})
                print(f"[TOOL] {tool_name} failed: {tool_exc}")
        return tool_results

    @staticmethod
    def _format_tool_results(tool_results: List[Dict[str, Any]], max_result_chars: int = 50000) -> str:
        """Render tool results as a prompt suffix; results are capped at max_result_chars."""
        parts = ["\n\n=== TOOL EXECUTION RESULTS ===\n"]
        for tr in tool_results:
            parts.append(f"\nTool: {tr['tool']}\n")
            if "result" in tr:
                result = tr["result"]
                if len(result) > max_result_chars:
                    result = result[:max_result_chars] + "\n... (truncated)"
                parts.append(f"Result:\n{result}\n")
            if "error" in tr:
                parts.append(f"Error: {tr['error']}\n")
        parts.append(
            "\n=== END TOOL RESULTS ===\n\nNow provide a helpful answer to the user based on these "
            "search results. Cite sources and be specific. Do NOT output more tool_calls JSON.\n"
        )
        return "".join(parts)

    # ------------------------------------------------------------------
    # Generation
    # ------------------------------------------------------------------

    def generate(
        self,
        prompt: str,
        temperature: float = None,
        max_tokens: int = None,
        top_p: float = None,
        stop: List[str] = None,
    ) -> str:
        """Generate a non-streaming response.

        The prompt is sent as a single user message. Special tokens are
        stripped and the text is cut at the earliest stop marker (if any).

        Returns:
            The stripped response text, or a string starting with "Error:"
            on failure (``last_error`` is set in that case).
        """
        if not self._is_available:
            if not self.load_model():
                return "Error: NVIDIA API is not available."

        resolved_max_tokens = self.resolve_max_tokens(prompt, max_tokens)
        temp = self.temperature if temperature is None else float(temperature)
        top_p_val = self.top_p if top_p is None else float(top_p)

        try:
            client = self._get_client()
            response = client.chat.completions.create(
                model=self.nvidia_model,
                messages=[{"role": "user", "content": prompt}],
                temperature=temp,
                top_p=top_p_val,
                max_tokens=resolved_max_tokens,
                stream=False,
            )

            text = response.choices[0].message.content or ""
            text = self._strip_special_tokens(text)

            if stop:
                text = self._apply_stop_sequences(text, stop)

            self._last_error = None
            return text.strip()

        except Exception as exc:
            self._last_error = f"Generation failed: {exc}"
            print(f"[NVIDIA] Error: {exc}")
            return f"Error: {exc}"

    async def _collect_stream(
        self,
        prompt: str,
        temperature: float,
        top_p: float,
        max_tokens: int,
        stop_markers: List[str],
        stop_event: Optional[Any],
    ) -> Optional[str]:
        """Run one streamed completion and return its full text.

        Reading stops at the first stop marker (the text is cut there) or at
        the first chunk carrying a finish_reason. Returns None if
        ``stop_event.is_set()`` became true while reading.
        """
        client = self._get_async_client()
        stream = await client.chat.completions.create(
            model=self.nvidia_model,
            messages=[{"role": "user", "content": prompt}],
            temperature=temperature,
            top_p=top_p,
            max_tokens=max_tokens,
            stream=True,
        )

        accumulated = ""
        try:
            async for chunk in stream:
                if stop_event and getattr(stop_event, "is_set", lambda: False)():
                    return None
                if not chunk.choices:
                    continue

                choice = chunk.choices[0]
                delta = choice.delta
                if delta is not None and delta.content:
                    accumulated += delta.content
                    truncated = self._apply_stop_sequences(accumulated, stop_markers)
                    if truncated != accumulated:
                        # A (non-empty) marker was found; it is cut off.
                        return truncated

                if choice.finish_reason:
                    break
        finally:
            # Release the HTTP connection even when we stop reading early.
            await stream.close()

        return accumulated

    async def generate_stream(
        self,
        prompt: str,
        temperature: float = None,
        max_tokens: int = None,
        top_p: float = None,
        top_k: int = None,
        stop: List[str] = None,
        stop_event: Optional[Any] = None,
    ) -> AsyncGenerator[str, None]:
        """Generate a streaming response via NVIDIA API with tool support.

        Each completion is read completely (up to the first stop marker)
        before anything is emitted, so a tool-call JSON reply can be
        intercepted. When the reply contains tool calls they are executed and
        the model is asked again with ``prompt`` plus the formatted results.
        At most ``MAX_TOOL_ROUNDS`` tool rounds run; a final completion is then
        always emitted so the last round's results are not silently dropped.

        Args:
            prompt: Full prompt text (e.g. from build_prompt).
            temperature, max_tokens, top_p: Per-call overrides of defaults.
                max_tokens is re-resolved each round against the current
                (possibly tool-extended) prompt.
            top_k: Accepted for interface compatibility; not sent to the API.
            stop: Stop markers; defaults to ["USER:", "SYSTEM:"].
            stop_event: Optional object with an ``is_set()`` method
                (presumably an Event); when set, streaming ends early.

        Yields:
            JSON strings: ``{"token": <char>, "finish_reason": null}`` per
            character, then ``{"finish_reason": "stop", "done": true}``;
            ``{"stopped": true, "done": true}`` if stop_event fires; or an
            object with an "error" key on failure.
        """
        if not self._is_available:
            if not self.load_model():
                yield json.dumps({
                    "error": "NVIDIA API not available",
                    "content": "Error: NVIDIA API is not available.",
                })
                return

        temp = self.temperature if temperature is None else float(temperature)
        top_p_val = self.top_p if top_p is None else float(top_p)
        stop_markers = stop or ["USER:", "SYSTEM:"]

        try:
            current_prompt = prompt
            # MAX_TOOL_ROUNDS tool rounds + one final answer round.
            for round_idx in range(self.MAX_TOOL_ROUNDS + 1):
                # Tool results lengthen the prompt, so re-budget every round.
                resolved_max_tokens = self.resolve_max_tokens(current_prompt, max_tokens)
                text = await self._collect_stream(
                    current_prompt, temp, top_p_val, resolved_max_tokens, stop_markers, stop_event
                )
                if text is None:
                    yield json.dumps({"stopped": True, "done": True})
                    return

                is_final_round = round_idx == self.MAX_TOOL_ROUNDS
                tool_calls = [] if is_final_round else self._extract_tool_calls(text)

                if not tool_calls or not self._tool_client.is_available:
                    for char in text:
                        yield json.dumps({"token": char, "finish_reason": None})
                        await asyncio.sleep(0)  # let other tasks run between tokens
                    break

                tool_round = round_idx + 1
                print(f"[TOOL] Executing {len(tool_calls)} tool call(s) in round {tool_round}")
                tool_results = await self._execute_tool_calls(tool_calls)

                current_prompt = prompt + self._format_tool_results(tool_results)
                print(f"[TOOL] Continuing generation with tool results, prompt length: {len(current_prompt)}")

            self._last_error = None
            yield json.dumps({"finish_reason": "stop", "done": True})

        except Exception as exc:
            self._last_error = f"Streaming generation failed: {exc}"
            print(f"[NVIDIA] Error: {exc}")
            yield json.dumps({
                "error": str(exc),
                "content": f"Error: {exc}",
                "done": True,
            })
|
|
|
|
| |
# Shared module-level instance. ModelManager.__new__ caches the instance,
# so any later ModelManager() call returns this same object.
model_manager: ModelManager = ModelManager()
|
|