GAIA / config.py
hapda12's picture
Upload 12 files
358eb7e verified
"""
配置管理模块 - GAIA Agent 配置
使用 .env 文件加载配置
"""
import os
from pathlib import Path
from dotenv import load_dotenv
# 加载 .env 文件(支持从父目录加载)
env_path = Path(__file__).parent / ".env"
if not env_path.exists():
env_path = Path(__file__).parent.parent / ".env"
load_dotenv(env_path)
# ========================================
# LLM 配置
# ========================================
OPENAI_BASE_URL = os.getenv("OPENAI_BASE_URL", "https://api.openai.com/v1")
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY", "")
MODEL = os.getenv("MODEL", "gpt-4o-mini")
TEMPERATURE = float(os.getenv("TEMPERATURE", "0"))
# ========================================
# API 配置
# ========================================
SCORING_API_URL = os.getenv("SCORING_API_URL", "https://agents-course-unit4-scoring.hf.space")
# ========================================
# Agent 配置
# ========================================
MAX_ITERATIONS = int(os.getenv("MAX_ITERATIONS", "10"))
# ========================================
# 超时配置(秒)
# ========================================
TOOL_TIMEOUT = int(os.getenv("TOOL_TIMEOUT", "30"))
TOTAL_TIMEOUT = int(os.getenv("TOTAL_TIMEOUT", "300"))
LLM_TIMEOUT = int(os.getenv("LLM_TIMEOUT", "120")) # LLM 调用超时
# ========================================
# 搜索配置
# ========================================
SEARCH_MAX_RESULTS = int(os.getenv("SEARCH_MAX_RESULTS", "5"))
TAVILY_API_KEY = os.getenv("TAVILY_API_KEY", "")
WIKIPEDIA_MAX_RESULTS = int(os.getenv("WIKIPEDIA_MAX_RESULTS", "2"))
ARXIV_MAX_RESULTS = int(os.getenv("ARXIV_MAX_RESULTS", "3"))
TAVILY_MAX_RESULTS = int(os.getenv("TAVILY_MAX_RESULTS", "3"))
# ========================================
# 文件处理配置
# ========================================
MAX_FILE_SIZE = int(os.getenv("MAX_FILE_SIZE", "10000"))
# ========================================
# RAG 配置
# ========================================
RAG_PERSIST_DIR = os.getenv("RAG_PERSIST_DIR", "./rag_index")
RAG_CSV_PATH = os.getenv("RAG_CSV_PATH", "data_clean.csv")
RAG_EMBEDDING_MODEL = os.getenv("RAG_EMBEDDING_MODEL", "sentence-transformers/all-MiniLM-L6-v2")
RAG_TOP_K = int(os.getenv("RAG_TOP_K", "3"))
# ========================================
# 速率限制配置
# ========================================
RATE_LIMIT_RETRY_MAX = int(os.getenv("RATE_LIMIT_RETRY_MAX", "5")) # 429错误最大重试次数
RATE_LIMIT_RETRY_BASE_DELAY = float(os.getenv("RATE_LIMIT_RETRY_BASE_DELAY", "10")) # 基础延迟秒数
BATCH_QUESTION_DELAY = float(os.getenv("BATCH_QUESTION_DELAY", "5")) # 批量测试问题间延迟秒数
# ========================================
# 调试配置
# ========================================
DEBUG = os.getenv("DEBUG", "false").lower() == "true"
# ========================================
# 路径配置
# ========================================
BASE_DIR = Path(__file__).parent
TEMP_DIR = BASE_DIR / "temp"
TEMP_DIR.mkdir(exist_ok=True)