# NUDR/backend/clients.py
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Client management module for LLM and search API interactions.
This module provides client creation and management for:
- Large Language Models (OpenAI, NVIDIA, local vLLM)
- Web search (Tavily API)
- Configuration-based client setup
"""

from typing import Any, Dict, List, Literal, NotRequired, TypedDict
from openai import OpenAI
from tavily import TavilyClient
from config import get_config
# Get configuration
config = get_config()

# Supported API providers
ApiType = Literal["nvdev", "openai", "tavily"]


class ModelConfig(TypedDict):
    base_url: str
    api_type: ApiType
    completion_config: Dict[str, Any]
    # Optional: fold the caller's system prompt into the first user message and
    # prepend the "detailed thinking off" control prompt (see get_completion).
    # NotRequired needs Python 3.11+ (or typing_extensions on older versions).
    merge_system_prompt: NotRequired[bool]


# Available model configurations
MODEL_CONFIGS: Dict[str, ModelConfig] = {
    "llama-3.1-8b": {
        "base_url": "https://integrate.api.nvidia.com/v1",
        "api_type": "nvdev",
        "completion_config": {
            "model": "nvdev/meta/llama-3.1-8b-instruct",
            "temperature": 0.2,
            "top_p": 0.7,
            "max_tokens": 2048,
            "stream": True,
        },
    },
    "llama-3.1-nemotron-8b": {
        "base_url": "https://integrate.api.nvidia.com/v1",
        "api_type": "nvdev",
        "completion_config": {
            "model": "nvdev/nvidia/llama-3.1-nemotron-nano-8b-v1",
            "temperature": 0.2,
            "top_p": 0.7,
            "max_tokens": 2048,
            "stream": True,
        },
    },
    "llama-3.1-nemotron-253b": {
        "base_url": "https://integrate.api.nvidia.com/v1",
        "api_type": "nvdev",
        "completion_config": {
            "model": "nvdev/nvidia/llama-3.1-nemotron-ultra-253b-v1",
            "temperature": 0.2,
            "top_p": 0.7,
            "max_tokens": 2048,
            "stream": True,
        },
    },
}
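
# Illustrative sketch: the module docstring mentions local vLLM support. An
# entry for a locally served model could look like the following (the URL,
# model name, and the "local-vllm" key are assumptions, not shipped config):
#
# MODEL_CONFIGS["local-vllm"] = {
#     "base_url": "http://localhost:8000/v1",  # vLLM's OpenAI-compatible server
#     "api_type": "openai",                    # reuses the OpenAI key handling
#     "completion_config": {
#         "model": "meta-llama/Llama-3.1-8B-Instruct",
#         "temperature": 0.2,
#         "top_p": 0.7,
#         "max_tokens": 2048,
#         "stream": True,
#     },
# }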
# Default model to use (from configuration)
DEFAULT_MODEL = config.model.default_model


def get_api_key(api_type: ApiType) -> str:
    """
    Get the API key for the specified API type.

    This function reads API keys from configuration-specified files.
    The file paths can be customized via environment variables.

    Args:
        api_type: The type of API to get the key for ("nvdev", "openai", "tavily")

    Returns:
        str: The API key from the configured file

    Raises:
        FileNotFoundError: If the API key file doesn't exist
        ValueError: If the API type is unknown

    Example:
        >>> get_api_key("tavily")
        "your-tavily-api-key"
    """
    api_key_files = {
        "nvdev": config.model.api_key_file,
        "openai": "openai_api.txt",
        "tavily": config.search.tavily_api_key_file,
    }
    key_file = api_key_files.get(api_type)
    if not key_file:
        raise ValueError(f"Unknown API type: {api_type}")
    try:
        with open(key_file, "r") as file:
            return file.read().strip()
    except FileNotFoundError:
        # Re-raise with an actionable message; "from None" keeps the
        # traceback free of the redundant original exception.
        raise FileNotFoundError(
            f"API key file not found for {api_type}. "
            f"Please create {key_file} with your API key. "
            f"See README.md for configuration instructions."
        ) from None
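
# Sketch of the intended failure handling around get_api_key (the caller
# shown here is hypothetical; the key file path comes from the config):
#
# try:
#     key = get_api_key("nvdev")
# except FileNotFoundError as err:
#     print(err)  # tells the operator which file to create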


def create_lm_client(model_config: ModelConfig | None = None) -> OpenAI:
    """
    Create an OpenAI client instance with the specified configuration.

    This function creates a client for the configured LLM provider.
    The client can be customized with a specific model configuration,
    or it will use the default model from configuration.

    Args:
        model_config: Optional model configuration to override defaults.
            If None, uses the default model from configuration.

    Returns:
        OpenAI: Configured OpenAI client instance

    Example:
        >>> client = create_lm_client()
        >>> response = client.chat.completions.create(...)
    """
    model_config = model_config or MODEL_CONFIGS[DEFAULT_MODEL]
    api_key = get_api_key(model_config["api_type"])
    return OpenAI(base_url=model_config["base_url"], api_key=api_key)
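
# Example: pick a specific entry from MODEL_CONFIGS instead of the default
# (assumes a valid nvdev API key file is in place):
#
# client = create_lm_client(MODEL_CONFIGS["llama-3.1-nemotron-253b"])
# # The return value is a plain OpenAI client; get_completion() below wraps
# # the common chat-completion call.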


def create_tavily_client() -> TavilyClient:
    """
    Create a Tavily client instance for web search functionality.

    This function creates a client for the Tavily search API using
    the API key from the configured file path.

    Returns:
        TavilyClient: Configured Tavily client instance

    Raises:
        FileNotFoundError: If the Tavily API key file is not found

    Example:
        >>> client = create_tavily_client()
        >>> results = client.search("quantum computing")
    """
    api_key = get_api_key("tavily")
    return TavilyClient(api_key=api_key)
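
# Example: a basic search call. TavilyClient.search returns a dict whose
# "results" list holds per-hit dicts with "title", "url", and "content":
#
# tavily = create_tavily_client()
# response = tavily.search("quantum computing")
# for hit in response.get("results", []):
#     print(hit["title"], "-", hit["url"])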


def get_completion(
    client: OpenAI,
    messages: List[Dict[str, Any]],
    model_config: ModelConfig | None = None,
) -> str:
    """
    Get a completion from the OpenAI client using the specified model configuration.

    This function handles both streaming and non-streaming completions,
    with special handling for model configurations that require the
    system prompt to be merged into the first user message.

    Args:
        client: OpenAI client instance
        messages: List of messages for the completion
        model_config: Optional model configuration to override defaults.
            If None, uses the default model configuration.

    Returns:
        str: The completion text

    Example:
        >>> client = create_lm_client()
        >>> messages = [{"role": "user", "content": "Hello"}]
        >>> response = get_completion(client, messages)
        >>> print(response)
        "Hello! How can I help you today?"
    """
    model_config = model_config or MODEL_CONFIGS[DEFAULT_MODEL]
    # Handle models that cannot take the caller's system prompt directly:
    # fold it into the first remaining message and use the Nemotron
    # "detailed thinking off" control prompt as the system message instead.
    if model_config.get("merge_system_prompt"):
        # Copy the messages so the caller's list and dicts are not mutated.
        messages = [dict(msg) for msg in messages]
        if messages[0]["role"] == "system":
            first_message = messages[0]
            messages = [msg for msg in messages if msg["role"] != "system"]
            messages[0]["content"] = (
                first_message["content"] + "\n\n" + messages[0]["content"]
            )
        messages.insert(0, {"role": "system", "content": "detailed thinking off"})
    completion = client.chat.completions.create(
        messages=messages, **model_config["completion_config"]
    )
    # Handle streaming vs non-streaming responses
    if model_config["completion_config"]["stream"]:
        ret = ""
        for chunk in completion:
            if chunk.choices[0].delta.content:
                ret += chunk.choices[0].delta.content
    else:
        ret = completion.choices[0].message.content
    return ret
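
# End-to-end sketch tying the helpers together (the prompt text is only
# illustrative):
#
# client = create_lm_client()
# answer = get_completion(
#     client,
#     [
#         {"role": "system", "content": "You are a concise assistant."},
#         {"role": "user", "content": "What does the Tavily API do?"},
#     ],
# )
# print(answer)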


def is_output_positive(output: str) -> bool:
    """
    Check if the output contains positive indicators.

    This function checks whether the given output string contains a
    positive word such as "yes" or "true" (case-insensitive substring
    match, so e.g. "eyes" would also count as a match).

    Args:
        output: The string to check for positive indicators

    Returns:
        bool: True if positive indicators are found, False otherwise

    Example:
        >>> is_output_positive("Yes, that's correct")
        True
        >>> is_output_positive("No, that's not right")
        False
    """
    positive_words = ["yes", "true"]
    return any(word in output.lower() for word in positive_words)
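
# Typical use: gate a pipeline step on a yes/no judgment from the model.
# The prompt wording is an assumption; since the check is a plain substring
# match, phrasing the question to elicit a bare "yes"/"no" helps:
#
# verdict = get_completion(
#     client, [{"role": "user", "content": "Is 17 prime? Answer yes or no."}]
# )
# if is_output_positive(verdict):
#     ...  # proceed with the positive branch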