"""Model manager that uses NVIDIA API for inference.""" import asyncio import json from datetime import datetime from typing import Any, AsyncGenerator, Dict, List, Optional from openai import OpenAI, AsyncOpenAI from config import settings from system_prompt import DEFAULT_SYSTEM_PROMPT from tool_client import tool_client SPECIAL_TOKENS = [ "<|im_end|>", "<|im_start|>", "<|endoftext|>", "<|startoftext|>", ] class ModelManager: """Singleton manager that uses NVIDIA API for inference.""" _instance = None _initialized = False def __new__(cls): if cls._instance is None: cls._instance = super(ModelManager, cls).__new__(cls) return cls._instance def __init__(self): if self._initialized: return self._initialized = True self.nvidia_api_key = settings.NVIDIA_API_KEY self.nvidia_base_url = settings.NVIDIA_BASE_URL self.nvidia_model = settings.NVIDIA_MODEL self.n_ctx = settings.N_CTX self.temperature = settings.TEMPERATURE self.max_tokens = settings.MAX_TOKENS self.top_p = settings.TOP_P self._client = None self._async_client = None self._is_available = False self._last_error = None self._last_prompt_meta = {} self._context_safety_buffer = 0 self._min_response_tokens = 64 self._tool_client = tool_client # Tool execution settings self.MAX_TOOL_ROUNDS = 3 # ------------------------------------------------------------------ # # Properties # # ------------------------------------------------------------------ # @property def is_loaded(self) -> bool: return self._is_available @property def is_available(self) -> bool: return bool(self.nvidia_api_key) @property def last_error(self) -> Optional[str]: return self._last_error @property def last_prompt_meta(self) -> Dict[str, Any]: return self._last_prompt_meta def get_max_generation_tokens_limit(self) -> int: """Get the maximum generation tokens limit.""" return self.max_tokens def get_model_info(self) -> Dict[str, Any]: """Get comprehensive model information for API responses.""" return { "nvidia_api_key": "***" + self.nvidia_api_key[-8:] if self.nvidia_api_key else None, "nvidia_base_url": self.nvidia_base_url, "model_name": self.nvidia_model, "is_loaded": self.is_loaded, "is_available": self.is_available, "last_error": self.last_error, "tools_available": self._tool_client.is_available, "tools": self._tool_client.get_tool_names() if self._tool_client.is_available else [], "context_window": self.n_ctx, "max_generation_tokens_limit": self.max_tokens, "default_temperature": self.temperature, "default_max_tokens": self.max_tokens, "default_top_p": self.top_p, } # ------------------------------------------------------------------ # # Client initialization # # ------------------------------------------------------------------ # def _get_client(self) -> OpenAI: """Get or create synchronous OpenAI client.""" if self._client is None: self._client = OpenAI( base_url=self.nvidia_base_url, api_key=self.nvidia_api_key ) return self._client def _get_async_client(self) -> AsyncOpenAI: """Get or create asynchronous OpenAI client.""" if self._async_client is None: self._async_client = AsyncOpenAI( base_url=self.nvidia_base_url, api_key=self.nvidia_api_key ) return self._async_client # ------------------------------------------------------------------ # # Model loading/unloading # # ------------------------------------------------------------------ # def load_model(self) -> bool: """Verify NVIDIA API is available.""" if not self.nvidia_api_key: self._last_error = "NVIDIA API key not configured" self._is_available = False return False try: # Simple test to verify API is accessible client = self._get_client() self._is_available = True self._last_error = None print(f"NVIDIA API initialized: model={self.nvidia_model}") return True except Exception as exc: self._last_error = f"NVIDIA API initialization failed: {exc}" self._is_available = False return False def unload_model(self): """Close API clients.""" self._is_available = False self._client = None self._async_client = None print("NVIDIA API connection closed") # ------------------------------------------------------------------ # # Token estimation # # ------------------------------------------------------------------ # @staticmethod def estimate_tokens(text: str) -> int: """Rough token estimation (3 chars ≈ 1 token).""" return max(1, len(text) // 3) def count_tokens(self, text: str) -> int: """Count tokens in text (alias for estimate_tokens for compatibility).""" return self.estimate_tokens(text) def resolve_max_tokens(self, prompt: str, requested: Optional[int]) -> int: """Calculate safe max_tokens given prompt length.""" prompt_tokens = self.estimate_tokens(prompt) available = self.n_ctx - prompt_tokens - self._context_safety_buffer available = max(available, self._min_response_tokens) if requested is None: return min(self.max_tokens, available) return min(requested, available) # ------------------------------------------------------------------ # # Prompt building # # ------------------------------------------------------------------ # def build_prompt( self, query: str, history: List[Dict[str, Any]] = None, system_prompt: str = None, file_content: str = None, custom_instructions: str = None, max_history_messages: int = 50, ) -> str: """Build a complete prompt with dynamic truncation.""" history = history or [] system = system_prompt or DEFAULT_SYSTEM_PROMPT # Build sections sections = [] # System prompt if system: sections.append(f"SYSTEM: {system}") # Custom instructions if custom_instructions: sections.append(f"INSTRUCTIONS: {custom_instructions}") # File content if file_content: sections.append(f"FILE CONTENT:\n{file_content}") # History if history: history_text = "--- Conversation History ---\n" for msg in history[-max_history_messages:]: role = msg.get("role", "user").upper() content = msg.get("content", "") history_text += f"{role}: {content}\n" sections.append(history_text) # Current query sections.append(f"USER: {query}") sections.append("ASSISTANT:") prompt = "\n\n".join(sections) # Store metadata self._last_prompt_meta = { "prompt_length": len(prompt), "estimated_tokens": self.estimate_tokens(prompt), "history_messages": len(history), "history_messages_used": min(len(history), max_history_messages), "history_messages_total": len(history), "timestamp": datetime.now().isoformat(), } return prompt # ------------------------------------------------------------------ # # Text processing utilities # # ------------------------------------------------------------------ # @staticmethod def _strip_special_tokens(text: str) -> str: """Remove special tokens from generated text.""" for token in SPECIAL_TOKENS: text = text.replace(token, "") return text @staticmethod def _apply_stop_sequences(text: str, stop_markers: List[str]) -> str: """Truncate text at first occurrence of any stop marker.""" for marker in stop_markers: if marker in text: text = text.split(marker)[0] return text @staticmethod def _chunk_text(text: str, chunk_size: int = 10) -> List[str]: """Split text into chunks for streaming.""" return [text[i:i + chunk_size] for i in range(0, len(text), chunk_size)] # ------------------------------------------------------------------ # # Tool call extraction # # ------------------------------------------------------------------ # def _extract_tool_calls(self, text: str) -> List[Dict[str, Any]]: """Extract tool calls from model output.""" if not self._tool_client.is_available: return [] # Look for JSON blocks with tool calls tool_calls = [] try: # Try to find JSON in the text start_idx = text.find("{") end_idx = text.rfind("}") if start_idx != -1 and end_idx != -1: json_str = text[start_idx:end_idx + 1] data = json.loads(json_str) # Check for tool_calls array if isinstance(data.get("tool_calls"), list): for call in data["tool_calls"]: if isinstance(call, dict) and "tool" in call: tool_calls.append(call) except (json.JSONDecodeError, ValueError): pass return tool_calls # ------------------------------------------------------------------ # # Generation methods # # ------------------------------------------------------------------ # def generate( self, prompt: str, temperature: float = None, max_tokens: int = None, top_p: float = None, stop: List[str] = None, ) -> str: """Generate a non-streaming response.""" if not self._is_available: if not self.load_model(): return "Error: NVIDIA API is not available." resolved_max_tokens = self.resolve_max_tokens(prompt, max_tokens) temp = self.temperature if temperature is None else float(temperature) top_p_val = self.top_p if top_p is None else float(top_p) try: client = self._get_client() response = client.chat.completions.create( model=self.nvidia_model, messages=[{"role": "user", "content": prompt}], temperature=temp, top_p=top_p_val, max_tokens=resolved_max_tokens, stream=False ) text = response.choices[0].message.content or "" text = self._strip_special_tokens(text) if stop: text = self._apply_stop_sequences(text, stop) self._last_error = None return text.strip() except Exception as exc: self._last_error = f"Generation failed: {exc}" print(f"[NVIDIA] Error: {exc}") return f"Error: {exc}" async def generate_stream( self, prompt: str, temperature: float = None, max_tokens: int = None, top_p: float = None, top_k: int = None, stop: List[str] = None, stop_event: Optional[Any] = None, ) -> AsyncGenerator[str, None]: """Generate a streaming response via NVIDIA API with tool support.""" if not self._is_available: if not self.load_model(): yield json.dumps({ "error": "NVIDIA API not available", "content": "Error: NVIDIA API is not available.", }) return resolved_max_tokens = self.resolve_max_tokens(prompt, max_tokens) temp = self.temperature if temperature is None else float(temperature) top_p_val = self.top_p if top_p is None else float(top_p) stop_markers = stop or ["USER:", "SYSTEM:"] try: # Tool execution loop current_prompt = prompt tool_round = 0 while tool_round < self.MAX_TOOL_ROUNDS: # Stream response from model client = self._get_async_client() stream = await client.chat.completions.create( model=self.nvidia_model, messages=[{"role": "user", "content": current_prompt}], temperature=temp, top_p=top_p_val, max_tokens=resolved_max_tokens, stream=True ) accumulated_text = "" streamed_to_user = False async for chunk in stream: if stop_event and getattr(stop_event, "is_set", lambda: False)(): yield json.dumps({"stopped": True, "done": True}) return if not chunk.choices: continue delta = chunk.choices[0].delta if delta.content: content = delta.content accumulated_text += content # Check for stop sequences should_stop = False for marker in stop_markers: if marker in accumulated_text: content = accumulated_text.split(marker)[0] accumulated_text = content should_stop = True break if should_stop: break if chunk.choices[0].finish_reason: break # Check if response contains tool calls tool_calls = self._extract_tool_calls(accumulated_text) if not tool_calls or not self._tool_client.is_available: # No tools to execute - this is the final response, stream it to user if not streamed_to_user and accumulated_text: # Stream the accumulated text token by token for char in accumulated_text: yield json.dumps({"token": char, "finish_reason": None}) await asyncio.sleep(0) break # Tools detected - execute them without showing the JSON to user tool_round += 1 print(f"[TOOL] Executing {len(tool_calls)} tool call(s) in round {tool_round}") tool_results = [] for call in tool_calls: tool_name = call.get("tool", "") arguments = call.get("arguments", {}) try: result_str = await self._tool_client.call_tool(tool_name, arguments) # Check if search returned empty results and retry with simpler query if tool_name == "web_search" and '"status": "error"' in result_str: original_query = arguments.get("query", "") print(f"[TOOL] Search failed for '{original_query}', trying simpler query...") # Try up to 2 alternative queries alternative_queries = [] # Remove common words that might cause issues simplified = original_query.replace("latest", "").replace("today", "").replace("news", "").strip() if simplified and simplified != original_query: alternative_queries.append(simplified) # Try just the main topic words = original_query.split() if len(words) > 2: main_topic = " ".join(words[:2]) if main_topic not in alternative_queries: alternative_queries.append(main_topic) # Try alternatives for alt_query in alternative_queries[:2]: print(f"[TOOL] Retrying with: '{alt_query}'") alt_args = arguments.copy() alt_args["query"] = alt_query result_str = await self._tool_client.call_tool(tool_name, alt_args) if '"status": "error"' not in result_str: print(f"[TOOL] Alternative query succeeded!") break tool_results.append({ "tool": tool_name, "result": result_str }) print(f"[TOOL] {tool_name} executed successfully, result length: {len(result_str)}") except Exception as tool_exc: error_msg = f"Tool {tool_name} failed: {tool_exc}" tool_results.append({ "tool": tool_name, "error": error_msg }) print(f"[TOOL] {tool_name} failed: {tool_exc}") # Build next prompt with tool results tool_results_text = "\n\n=== TOOL EXECUTION RESULTS ===\n" for tr in tool_results: tool_results_text += f"\nTool: {tr['tool']}\n" if "result" in tr: # Truncate very long results result = tr['result'] if len(result) > 50000: result = result[:50000] + "\n... (truncated)" tool_results_text += f"Result:\n{result}\n" if "error" in tr: tool_results_text += f"Error: {tr['error']}\n" tool_results_text += "\n=== END TOOL RESULTS ===\n\nNow provide a helpful answer to the user based on these search results. Cite sources and be specific. Do NOT output more tool_calls JSON.\n" # Update prompt for next round current_prompt = prompt + tool_results_text print(f"[TOOL] Continuing generation with tool results, prompt length: {len(current_prompt)}") # Continue loop to get final answer with tool results self._last_error = None yield json.dumps({"finish_reason": "stop", "done": True}) except Exception as exc: self._last_error = f"Streaming generation failed: {exc}" print(f"[NVIDIA] Error: {exc}") yield json.dumps({ "error": str(exc), "content": f"Error: {exc}", "done": True }) # Global singleton instance model_manager = ModelManager()