NUDR / backend /config.py
magicboris's picture
Upload 83 files
3647b02 verified
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Configuration module for the Universal Deep Research Backend (UDR-B).
This module centralizes all configurable settings and provides
environment variable support for easy deployment customization.
"""
import os
from dataclasses import dataclass, field
from typing import Any, Dict, Optional
from dotenv import load_dotenv
# Load environment variables from .env file
load_dotenv()
@dataclass
class ServerConfig:
    """Settings for the backend server process.

    Each field is resolved from an environment variable at the moment an
    instance is created; the literal next to the variable name is the
    fallback used when the variable is not set.
    """

    # HOST (fallback "0.0.0.0").
    host: str = field(
        default_factory=lambda: os.environ.get("HOST", "0.0.0.0"),
    )
    # PORT (fallback "8000"), converted with int(); a non-numeric value
    # raises ValueError during construction.
    port: int = field(
        default_factory=lambda: int(os.environ.get("PORT", "8000")),
    )
    # LOG_LEVEL (fallback "info"), stored as given.
    log_level: str = field(
        default_factory=lambda: os.environ.get("LOG_LEVEL", "info"),
    )
    # RELOAD (fallback "true"): True only when the value lower-cases to
    # exactly "true"; any other string yields False.
    reload: bool = field(
        default_factory=lambda: "true" == os.environ.get("RELOAD", "true").lower(),
    )
@dataclass
class CORSConfig:
    """Cross-origin resource sharing (CORS) settings.

    ``frontend_url`` and ``allow_credentials`` are read from the environment
    when an instance is created; the method and header lists default to a
    single wildcard entry.
    """

    # FRONTEND_URL (fallback "http://localhost:3000").
    frontend_url: str = field(
        default_factory=lambda: os.environ.get(
            "FRONTEND_URL", "http://localhost:3000"
        ),
    )
    # ALLOW_CREDENTIALS (fallback "true"): True only when the value
    # lower-cases to exactly "true".
    allow_credentials: bool = field(
        default_factory=lambda: "true"
        == os.environ.get("ALLOW_CREDENTIALS", "true").lower(),
    )
    # A factory is used so every instance gets its own ["*"] list.
    allow_methods: list = field(default_factory=lambda: ["*"])
    allow_headers: list = field(default_factory=lambda: ["*"])
@dataclass
class ModelConfig:
    """Default language-model settings.

    All values come from environment variables at instantiation time, with
    the fallbacks noted on each field. Numeric fields are converted with
    ``float()`` / ``int()``, so a malformed value raises ``ValueError``.
    """

    # DEFAULT_MODEL (fallback "llama-3.1-nemotron-253b").
    default_model: str = field(
        default_factory=lambda: os.environ.get(
            "DEFAULT_MODEL", "llama-3.1-nemotron-253b"
        ),
    )
    # LLM_BASE_URL (fallback "https://integrate.api.nvidia.com/v1").
    base_url: str = field(
        default_factory=lambda: os.environ.get(
            "LLM_BASE_URL", "https://integrate.api.nvidia.com/v1"
        ),
    )
    # LLM_API_KEY_FILE (fallback "nvdev_api.txt"): a file name, not the key.
    api_key_file: str = field(
        default_factory=lambda: os.environ.get("LLM_API_KEY_FILE", "nvdev_api.txt"),
    )
    # LLM_TEMPERATURE (fallback "0.2").
    temperature: float = field(
        default_factory=lambda: float(os.environ.get("LLM_TEMPERATURE", "0.2")),
    )
    # LLM_TOP_P (fallback "0.7").
    top_p: float = field(
        default_factory=lambda: float(os.environ.get("LLM_TOP_P", "0.7")),
    )
    # LLM_MAX_TOKENS (fallback "2048").
    max_tokens: int = field(
        default_factory=lambda: int(os.environ.get("LLM_MAX_TOKENS", "2048")),
    )
@dataclass
class SearchConfig:
    """Search settings, resolved from the environment at instantiation."""

    # TAVILY_API_KEY_FILE (fallback "tavily_api.txt"): a file name, not the
    # key itself.
    tavily_api_key_file: str = field(
        default_factory=lambda: os.environ.get(
            "TAVILY_API_KEY_FILE", "tavily_api.txt"
        ),
    )
    # MAX_SEARCH_RESULTS (fallback "10"), converted with int().
    max_search_results: int = field(
        default_factory=lambda: int(os.environ.get("MAX_SEARCH_RESULTS", "10")),
    )
@dataclass
class ResearchConfig:
    """Research settings, resolved from the environment at instantiation.

    Integer fields are converted with ``int()``; a malformed value raises
    ``ValueError`` while the instance is being built.
    """

    # MAX_TOPICS (fallback "1").
    max_topics: int = field(
        default_factory=lambda: int(os.environ.get("MAX_TOPICS", "1")),
    )
    # MAX_SEARCH_PHRASES (fallback "1").
    max_search_phrases: int = field(
        default_factory=lambda: int(os.environ.get("MAX_SEARCH_PHRASES", "1")),
    )
    # MOCK_DIRECTORY (fallback "mock_instances/stocks_24th_3_sections").
    mock_directory: str = field(
        default_factory=lambda: os.environ.get(
            "MOCK_DIRECTORY", "mock_instances/stocks_24th_3_sections"
        ),
    )
    # RANDOM_SEED (fallback "42").
    random_seed: int = field(
        default_factory=lambda: int(os.environ.get("RANDOM_SEED", "42")),
    )
@dataclass
class LoggingConfig:
    """Logging settings, resolved from the environment at instantiation.

    Boolean flags are True only when the variable's value lower-cases to
    exactly ``"true"``.
    """

    # LOG_DIR (fallback "logs").
    log_dir: str = field(
        default_factory=lambda: os.environ.get("LOG_DIR", "logs"),
    )
    # TRACE_ENABLED (fallback "true").
    trace_enabled: bool = field(
        default_factory=lambda: "true"
        == os.environ.get("TRACE_ENABLED", "true").lower(),
    )
    # COPY_INTO_STDOUT (fallback "false").
    copy_into_stdout: bool = field(
        default_factory=lambda: "true"
        == os.environ.get("COPY_INTO_STDOUT", "false").lower(),
    )
@dataclass
class FrameConfig:
    """FrameV4 settings, resolved from the environment at instantiation.

    Integer fields are converted with ``int()``; the boolean flag is True
    only when its variable lower-cases to exactly ``"true"``.
    """

    # LONG_CONTEXT_CUTOFF (fallback "8192").
    long_context_cutoff: int = field(
        default_factory=lambda: int(os.environ.get("LONG_CONTEXT_CUTOFF", "8192")),
    )
    # FORCE_LONG_CONTEXT (fallback "false").
    force_long_context: bool = field(
        default_factory=lambda: "true"
        == os.environ.get("FORCE_LONG_CONTEXT", "false").lower(),
    )
    # MAX_ITERATIONS (fallback "1024").
    max_iterations: int = field(
        default_factory=lambda: int(os.environ.get("MAX_ITERATIONS", "1024")),
    )
    # INTERACTION_LEVEL (fallback "none"), stored as given.
    interaction_level: str = field(
        default_factory=lambda: os.environ.get("INTERACTION_LEVEL", "none"),
    )
@dataclass
class AppConfig:
    """Main application configuration, grouping every settings section.

    Each section defaults to a freshly constructed instance of its own
    config class, so building an ``AppConfig`` with no arguments resolves
    all settings from the current environment.
    """

    server: ServerConfig = field(default_factory=ServerConfig)
    cors: CORSConfig = field(default_factory=CORSConfig)
    model: ModelConfig = field(default_factory=ModelConfig)
    search: SearchConfig = field(default_factory=SearchConfig)
    research: ResearchConfig = field(default_factory=ResearchConfig)
    logging: LoggingConfig = field(default_factory=LoggingConfig)
    frame: FrameConfig = field(default_factory=FrameConfig)

    def __post_init__(self):
        """Ensure the configured log directory exists."""
        # os.makedirs("") raises FileNotFoundError, so an empty LOG_DIR is
        # skipped instead of crashing configuration construction.
        if self.logging.log_dir:
            os.makedirs(self.logging.log_dir, exist_ok=True)

    def to_dict(self) -> Dict[str, Any]:
        """Convert the configuration to a nested plain dictionary.

        Returns:
            A dict with one key per section (``server``, ``cors``, ``model``,
            ``search``, ``research``, ``logging``, ``frame``), each mapping
            field names to their current values.
        """
        # Imported locally because the module only imports `dataclass` and
        # `field` at the top.
        from dataclasses import asdict

        # asdict walks the declared fields recursively, so the output keeps
        # tracking the section classes as fields are added, and list values
        # are copied rather than shared with the live configuration.
        return asdict(self)
# Global configuration instance, built once when this module is imported.
# Constructing it resolves every section from the current environment and,
# through AppConfig.__post_init__, creates the configured log directory.
config: AppConfig = AppConfig()
def get_config() -> AppConfig:
"""Get the global configuration instance."""
return config
def update_config(**kwargs) -> None:
    """Replace top-level attributes of the global configuration.

    Args:
        **kwargs: Configuration field names mapped to their new values.
            Only declared dataclass fields of the global ``config`` object
            are accepted; values are assigned as given.

    Raises:
        ValueError: If any key is not a declared configuration field. All
            keys are validated before anything is assigned, so a rejected
            call leaves the configuration untouched.
    """
    # Imported locally because the module only imports `dataclass` and
    # `field` at the top.
    from dataclasses import fields as dataclass_fields

    # hasattr() would also accept methods and dunder attributes (for example
    # "to_dict"), letting a caller overwrite them; restrict to real fields.
    known_keys = {spec.name for spec in dataclass_fields(config)}
    for key in kwargs:
        if key not in known_keys:
            raise ValueError(f"Unknown configuration key: {key}")

    for key, value in kwargs.items():
        setattr(config, key, value)
def get_model_configs() -> Dict[str, Dict[str, Any]]:
    """Return the model configurations from clients.py with env overrides.

    For each entry of ``clients.MODEL_CONFIGS`` an environment prefix is
    derived from the model name (upper-cased, ``-`` replaced by ``_``,
    followed by ``_``). These variables, when set to a non-empty value,
    override the entry:

    * ``<PREFIX>BASE_URL``    -> ``base_url``
    * ``<PREFIX>MODEL``       -> ``completion_config["model"]``
    * ``<PREFIX>TEMPERATURE`` -> ``completion_config["temperature"]`` (float)
    * ``<PREFIX>TOP_P``       -> ``completion_config["top_p"]`` (float)
    * ``<PREFIX>MAX_TOKENS``  -> ``completion_config["max_tokens"]`` (int)

    Returns:
        A new dict mapping model names to their (possibly overridden)
        configuration. ``MODEL_CONFIGS`` itself is never modified.

    Raises:
        ValueError: If a numeric override cannot be parsed.
        KeyError: If a completion override is set for a model whose
            configuration has no ``completion_config`` entry.
    """
    import copy

    from clients import MODEL_CONFIGS

    # (environment suffix, completion_config key, converter)
    completion_overrides = (
        ("MODEL", "model", str),
        ("TEMPERATURE", "temperature", float),
        ("TOP_P", "top_p", float),
        ("MAX_TOKENS", "max_tokens", int),
    )

    updated_configs = {}
    for model_name, model_config in MODEL_CONFIGS.items():
        updated_config = model_config.copy()
        # The outer copy is shallow: without copying the nested mapping as
        # well, the overrides below would be written into the shared
        # MODEL_CONFIGS entry and persist across calls.
        if "completion_config" in updated_config:
            updated_config["completion_config"] = copy.copy(
                updated_config["completion_config"]
            )

        # NOTE(review): only '-' is replaced, so a model name containing
        # other punctuation (e.g. '.') yields a prefix that many shells
        # cannot export.
        env_prefix = f"{model_name.upper().replace('-', '_')}_"

        base_url = os.getenv(f"{env_prefix}BASE_URL")
        if base_url:
            updated_config["base_url"] = base_url

        for suffix, option, convert in completion_overrides:
            raw_value = os.getenv(f"{env_prefix}{suffix}")
            if raw_value:
                updated_config["completion_config"][option] = convert(raw_value)

        updated_configs[model_name] = updated_config
    return updated_configs