Spaces:
Paused
Paused
| # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. | |
| # SPDX-License-Identifier: Apache-2.0 | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| """ | |
| Configuration module for the Universal Deep Research Backend (UDR-B). | |
| This module centralizes all configurable settings and provides | |
| environment variable support for easy deployment customization. | |
| """ | |
| import os | |
| from dataclasses import dataclass, field | |
| from typing import Any, Dict, Optional | |
| from dotenv import load_dotenv | |
| # Load environment variables from .env file | |
| load_dotenv() | |
| class ServerConfig: | |
| """Server configuration settings.""" | |
| host: str = field(default_factory=lambda: os.getenv("HOST", "0.0.0.0")) | |
| port: int = field(default_factory=lambda: int(os.getenv("PORT", "8000"))) | |
| log_level: str = field(default_factory=lambda: os.getenv("LOG_LEVEL", "info")) | |
| reload: bool = field( | |
| default_factory=lambda: os.getenv("RELOAD", "true").lower() == "true" | |
| ) | |
| class CORSConfig: | |
| """CORS configuration settings.""" | |
| frontend_url: str = field( | |
| default_factory=lambda: os.getenv("FRONTEND_URL", "http://localhost:3000") | |
| ) | |
| allow_credentials: bool = field( | |
| default_factory=lambda: os.getenv("ALLOW_CREDENTIALS", "true").lower() == "true" | |
| ) | |
| allow_methods: list = field(default_factory=lambda: ["*"]) | |
| allow_headers: list = field(default_factory=lambda: ["*"]) | |
| class ModelConfig: | |
| """Model configuration settings.""" | |
| default_model: str = field( | |
| default_factory=lambda: os.getenv("DEFAULT_MODEL", "llama-3.1-nemotron-253b") | |
| ) | |
| base_url: str = field( | |
| default_factory=lambda: os.getenv( | |
| "LLM_BASE_URL", "https://integrate.api.nvidia.com/v1" | |
| ) | |
| ) | |
| api_key_file: str = field( | |
| default_factory=lambda: os.getenv("LLM_API_KEY_FILE", "nvdev_api.txt") | |
| ) | |
| temperature: float = field( | |
| default_factory=lambda: float(os.getenv("LLM_TEMPERATURE", "0.2")) | |
| ) | |
| top_p: float = field(default_factory=lambda: float(os.getenv("LLM_TOP_P", "0.7"))) | |
| max_tokens: int = field( | |
| default_factory=lambda: int(os.getenv("LLM_MAX_TOKENS", "2048")) | |
| ) | |
| class SearchConfig: | |
| """Search configuration settings.""" | |
| tavily_api_key_file: str = field( | |
| default_factory=lambda: os.getenv("TAVILY_API_KEY_FILE", "tavily_api.txt") | |
| ) | |
| max_search_results: int = field( | |
| default_factory=lambda: int(os.getenv("MAX_SEARCH_RESULTS", "10")) | |
| ) | |
| class ResearchConfig: | |
| """Research configuration settings.""" | |
| max_topics: int = field(default_factory=lambda: int(os.getenv("MAX_TOPICS", "1"))) | |
| max_search_phrases: int = field( | |
| default_factory=lambda: int(os.getenv("MAX_SEARCH_PHRASES", "1")) | |
| ) | |
| mock_directory: str = field( | |
| default_factory=lambda: os.getenv( | |
| "MOCK_DIRECTORY", "mock_instances/stocks_24th_3_sections" | |
| ) | |
| ) | |
| random_seed: int = field( | |
| default_factory=lambda: int(os.getenv("RANDOM_SEED", "42")) | |
| ) | |
| class LoggingConfig: | |
| """Logging configuration settings.""" | |
| log_dir: str = field(default_factory=lambda: os.getenv("LOG_DIR", "logs")) | |
| trace_enabled: bool = field( | |
| default_factory=lambda: os.getenv("TRACE_ENABLED", "true").lower() == "true" | |
| ) | |
| copy_into_stdout: bool = field( | |
| default_factory=lambda: os.getenv("COPY_INTO_STDOUT", "false").lower() == "true" | |
| ) | |
| class FrameConfig: | |
| """FrameV4 configuration settings.""" | |
| long_context_cutoff: int = field( | |
| default_factory=lambda: int(os.getenv("LONG_CONTEXT_CUTOFF", "8192")) | |
| ) | |
| force_long_context: bool = field( | |
| default_factory=lambda: os.getenv("FORCE_LONG_CONTEXT", "false").lower() | |
| == "true" | |
| ) | |
| max_iterations: int = field( | |
| default_factory=lambda: int(os.getenv("MAX_ITERATIONS", "1024")) | |
| ) | |
| interaction_level: str = field( | |
| default_factory=lambda: os.getenv("INTERACTION_LEVEL", "none") | |
| ) | |
| class AppConfig: | |
| """Main application configuration.""" | |
| server: ServerConfig = field(default_factory=ServerConfig) | |
| cors: CORSConfig = field(default_factory=CORSConfig) | |
| model: ModelConfig = field(default_factory=ModelConfig) | |
| search: SearchConfig = field(default_factory=SearchConfig) | |
| research: ResearchConfig = field(default_factory=ResearchConfig) | |
| logging: LoggingConfig = field(default_factory=LoggingConfig) | |
| frame: FrameConfig = field(default_factory=FrameConfig) | |
| def __post_init__(self): | |
| """Ensure log directory exists.""" | |
| os.makedirs(self.logging.log_dir, exist_ok=True) | |
| def to_dict(self) -> Dict[str, Any]: | |
| """Convert configuration to dictionary.""" | |
| return { | |
| "server": { | |
| "host": self.server.host, | |
| "port": self.server.port, | |
| "log_level": self.server.log_level, | |
| "reload": self.server.reload, | |
| }, | |
| "cors": { | |
| "frontend_url": self.cors.frontend_url, | |
| "allow_credentials": self.cors.allow_credentials, | |
| "allow_methods": self.cors.allow_methods, | |
| "allow_headers": self.cors.allow_headers, | |
| }, | |
| "model": { | |
| "default_model": self.model.default_model, | |
| "base_url": self.model.base_url, | |
| "api_key_file": self.model.api_key_file, | |
| "temperature": self.model.temperature, | |
| "top_p": self.model.top_p, | |
| "max_tokens": self.model.max_tokens, | |
| }, | |
| "search": { | |
| "tavily_api_key_file": self.search.tavily_api_key_file, | |
| "max_search_results": self.search.max_search_results, | |
| }, | |
| "research": { | |
| "max_topics": self.research.max_topics, | |
| "max_search_phrases": self.research.max_search_phrases, | |
| "mock_directory": self.research.mock_directory, | |
| "random_seed": self.research.random_seed, | |
| }, | |
| "logging": { | |
| "log_dir": self.logging.log_dir, | |
| "trace_enabled": self.logging.trace_enabled, | |
| "copy_into_stdout": self.logging.copy_into_stdout, | |
| }, | |
| "frame": { | |
| "long_context_cutoff": self.frame.long_context_cutoff, | |
| "force_long_context": self.frame.force_long_context, | |
| "max_iterations": self.frame.max_iterations, | |
| "interaction_level": self.frame.interaction_level, | |
| }, | |
| } | |
| # Global configuration instance | |
| config = AppConfig() | |
| def get_config() -> AppConfig: | |
| """Get the global configuration instance.""" | |
| return config | |
| def update_config(**kwargs) -> None: | |
| """Update configuration with new values.""" | |
| for key, value in kwargs.items(): | |
| if hasattr(config, key): | |
| setattr(config, key, value) | |
| else: | |
| raise ValueError(f"Unknown configuration key: {key}") | |
| def get_model_configs() -> Dict[str, Dict[str, Any]]: | |
| """Get the model configurations from clients.py with configurable values.""" | |
| from clients import MODEL_CONFIGS | |
| # Update model configs with environment variables | |
| updated_configs = {} | |
| for model_name, model_config in MODEL_CONFIGS.items(): | |
| updated_config = model_config.copy() | |
| # Override with environment variables if available | |
| env_prefix = f"{model_name.upper().replace('-', '_')}_" | |
| if os.getenv(f"{env_prefix}BASE_URL"): | |
| updated_config["base_url"] = os.getenv(f"{env_prefix}BASE_URL") | |
| if os.getenv(f"{env_prefix}MODEL"): | |
| updated_config["completion_config"]["model"] = os.getenv( | |
| f"{env_prefix}MODEL" | |
| ) | |
| if os.getenv(f"{env_prefix}TEMPERATURE"): | |
| updated_config["completion_config"]["temperature"] = float( | |
| os.getenv(f"{env_prefix}TEMPERATURE") | |
| ) | |
| if os.getenv(f"{env_prefix}TOP_P"): | |
| updated_config["completion_config"]["top_p"] = float( | |
| os.getenv(f"{env_prefix}TOP_P") | |
| ) | |
| if os.getenv(f"{env_prefix}MAX_TOKENS"): | |
| updated_config["completion_config"]["max_tokens"] = int( | |
| os.getenv(f"{env_prefix}MAX_TOKENS") | |
| ) | |
| updated_configs[model_name] = updated_config | |
| return updated_configs | |