# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
| """ | |
| Client management module for LLM and search API interactions. | |
| This module provides client creation and management for: | |
| - Large Language Models (OpenAI, NVIDIA, local vLLM) | |
| - Web search (Tavily API) | |
| - Configuration-based client setup | |
| """ | |

import re
from typing import Any, Dict, List, Literal, TypedDict

from openai import OpenAI
from tavily import TavilyClient

from config import get_config

# Get configuration
config = get_config()

# Configuration system
ApiType = Literal["nvdev", "openai", "tavily"]


class ModelConfig(TypedDict):
    base_url: str
    api_type: ApiType
    completion_config: Dict[str, Any]


# Available model configurations
MODEL_CONFIGS: Dict[str, ModelConfig] = {
    "llama-3.1-8b": {
        "base_url": "https://integrate.api.nvidia.com/v1",
        "api_type": "nvdev",
        "completion_config": {
            "model": "nvdev/meta/llama-3.1-8b-instruct",
            "temperature": 0.2,
            "top_p": 0.7,
            "max_tokens": 2048,
            "stream": True,
        },
    },
    "llama-3.1-nemotron-8b": {
        "base_url": "https://integrate.api.nvidia.com/v1",
        "api_type": "nvdev",
        "completion_config": {
            "model": "nvdev/nvidia/llama-3.1-nemotron-nano-8b-v1",
            "temperature": 0.2,
            "top_p": 0.7,
            "max_tokens": 2048,
            "stream": True,
        },
    },
    "llama-3.1-nemotron-253b": {
        "base_url": "https://integrate.api.nvidia.com/v1",
        "api_type": "nvdev",
        "completion_config": {
            "model": "nvdev/nvidia/llama-3.1-nemotron-ultra-253b-v1",
            "temperature": 0.2,
            "top_p": 0.7,
            "max_tokens": 2048,
            "stream": True,
        },
    },
}
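
# The module docstring mentions local vLLM. A hypothetical entry for a vLLM
# server exposing an OpenAI-compatible endpoint could be registered like this;
# the URL, port, and model name below are illustrative assumptions, not part of
# this project (vLLM typically ignores the API key value):
#
# MODEL_CONFIGS["local-vllm"] = {
#     "base_url": "http://localhost:8000/v1",
#     "api_type": "openai",
#     "completion_config": {
#         "model": "meta-llama/Llama-3.1-8B-Instruct",
#         "temperature": 0.2,
#         "top_p": 0.7,
#         "max_tokens": 2048,
#         "stream": True,
#     },
# }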

# Default model to use (from configuration)
DEFAULT_MODEL = config.model.default_model


def get_api_key(api_type: ApiType) -> str:
    """
    Get the API key for the specified API type.

    This function reads API keys from configuration-specified files.
    The file paths can be customized via environment variables.

    Args:
        api_type: The type of API to get the key for ("nvdev", "openai", "tavily")

    Returns:
        str: The API key from the configured file

    Raises:
        FileNotFoundError: If the API key file doesn't exist
        ValueError: If the API type is unknown

    Example:
        >>> get_api_key("tavily")
        "your-tavily-api-key"
    """
    api_key_files = {
        "nvdev": config.model.api_key_file,
        "openai": "openai_api.txt",
        "tavily": config.search.tavily_api_key_file,
    }
    key_file = api_key_files.get(api_type)
    if not key_file:
        raise ValueError(f"Unknown API type: {api_type}")
    try:
        with open(key_file, "r") as file:
            return file.read().strip()
    except FileNotFoundError:
        raise FileNotFoundError(
            f"API key file not found for {api_type}. "
            f"Please create {key_file} with your API key. "
            f"See README.md for configuration instructions."
        )


def create_lm_client(model_config: ModelConfig | None = None) -> OpenAI:
    """
    Create an OpenAI client instance with the specified configuration.

    This function creates a client for the configured LLM provider.
    The client can be customized with specific model configurations
    or will use the default model from configuration.

    Args:
        model_config: Optional model configuration to override defaults.
            If None, uses the default model from configuration.

    Returns:
        OpenAI: Configured OpenAI client instance

    Example:
        >>> client = create_lm_client()
        >>> response = client.chat.completions.create(...)
    """
    model_config = model_config or MODEL_CONFIGS[DEFAULT_MODEL]
    api_key = get_api_key(model_config["api_type"])
    return OpenAI(base_url=model_config["base_url"], api_key=api_key)
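
# To target a non-default model, pass one of the entries above explicitly, e.g.
# client = create_lm_client(MODEL_CONFIGS["llama-3.1-nemotron-253b"]).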


def create_tavily_client() -> TavilyClient:
    """
    Create a Tavily client instance for web search functionality.

    This function creates a client for the Tavily search API using
    the API key from the configured file path.

    Returns:
        TavilyClient: Configured Tavily client instance

    Raises:
        FileNotFoundError: If the Tavily API key file is not found

    Example:
        >>> client = create_tavily_client()
        >>> results = client.search("quantum computing")
    """
    api_key = get_api_key("tavily")
    return TavilyClient(api_key=api_key)


def get_completion(
    client: OpenAI,
    messages: List[Dict[str, Any]],
    model_config: ModelConfig | None = None,
) -> str:
    """
    Get a completion from the OpenAI client using the specified model configuration.

    This function handles both streaming and non-streaming completions,
    with special handling for model configurations that require the
    system prompt to be merged into the first user message.

    Args:
        client: OpenAI client instance
        messages: List of messages for the completion
        model_config: Optional model configuration to override defaults.
            If None, uses the default model configuration.

    Returns:
        str: The completion text

    Example:
        >>> client = create_lm_client()
        >>> messages = [{"role": "user", "content": "Hello"}]
        >>> response = get_completion(client, messages)
        >>> print(response)
        "Hello! How can I help you today?"
    """
    model_config = model_config or MODEL_CONFIGS[DEFAULT_MODEL]

    # Handle models that cannot take a user-supplied system prompt directly:
    # the optional "merge_system_prompt" flag (not set by any built-in config)
    # folds the system prompt into the first remaining message and substitutes
    # a fixed "detailed thinking off" system prompt.
    if model_config.get("merge_system_prompt"):
        # Copy so the caller's message list is not mutated below.
        messages = list(messages)
        if messages[0]["role"] == "system":
            system_message = messages.pop(0)
            messages[0] = dict(messages[0])
            messages[0]["content"] = (
                system_message["content"] + "\n\n" + messages[0]["content"]
            )
        messages.insert(0, {"role": "system", "content": "detailed thinking off"})

    completion = client.chat.completions.create(
        messages=messages, **model_config["completion_config"]
    )

    # Handle streaming vs non-streaming responses
    if model_config["completion_config"].get("stream"):
        ret = ""
        for chunk in completion:
            if chunk.choices[0].delta.content:
                ret += chunk.choices[0].delta.content
    else:
        ret = completion.choices[0].message.content
    return ret


def is_output_positive(output: str) -> bool:
    """
    Check if the output contains positive indicators.

    This function checks whether the given output string contains
    the words "yes" or "true" (case-insensitive, matched as whole
    words so that e.g. "eyes" or "untrue" do not count).

    Args:
        output: The string to check for positive indicators

    Returns:
        bool: True if positive indicators are found, False otherwise

    Example:
        >>> is_output_positive("Yes, that's correct")
        True
        >>> is_output_positive("No, that's not right")
        False
    """
    # Match whole words only; a bare substring test would also fire on
    # words like "eyes" or "untrue".
    return bool(re.search(r"\b(yes|true)\b", output, re.IGNORECASE))
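

if __name__ == "__main__":
    # Minimal smoke test (a sketch: it assumes the API key files described in
    # get_api_key() exist and that the default model endpoint is reachable).
    lm_client = create_lm_client()
    demo_messages = [{"role": "user", "content": "Reply with 'yes' if you can hear me."}]
    reply = get_completion(lm_client, demo_messages)
    print(reply)
    print("positive:", is_output_positive(reply))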