| """ |
| Configuration handling for SkyDiscover |
| """ |
|
|
| from __future__ import annotations |
|
|
| import logging |
| import os |
| import re |
| from dataclasses import dataclass, field, fields |
| from pathlib import Path |
| from typing import Any, Callable, Dict, List, Optional, Union |
|
|
| import yaml |
|
|
| logger = logging.getLogger(__name__) |
|
|
|
|
| |
| |
| |
|
|
# Registry of known LLM providers, keyed by lowercase provider name.
# Each value is a ``(default_api_base, env_vars)`` pair:
#   * ``default_api_base`` -- endpoint used when a model does not set its own
#     ``api_base``; ``None`` means there is no built-in default, so the caller
#     has to supply one (``apply_overrides`` raises if it is missing).
#   * ``env_vars`` -- environment variable names checked, in order, for an API
#     key (see ``_resolve_api_key_from_env``).
_PROVIDERS: Dict[str, tuple] = {
    "openai": ("https://api.openai.com/v1", ["OPENAI_API_KEY"]),
    # NOTE(review): "azure" defaults to the OpenAI endpoint -- presumably Azure
    # deployments are expected to override ``api_base``; confirm.
    "azure": ("https://api.openai.com/v1", ["AZURE_API_KEY", "OPENAI_API_KEY"]),
    "gemini": (
        "https://generativelanguage.googleapis.com/v1beta/openai/",
        ["GEMINI_API_KEY", "GOOGLE_API_KEY"],
    ),
    "anthropic": ("https://api.anthropic.com/v1/", ["ANTHROPIC_API_KEY"]),
    "deepseek": ("https://api.deepseek.com/v1", ["DEEPSEEK_API_KEY"]),
    "mistral": ("https://api.mistral.ai/v1", ["MISTRAL_API_KEY"]),
    "cohere": ("https://api.cohere.com/v1", ["CO_API_KEY", "COHERE_API_KEY"]),
    # No default endpoint for the providers below.
    "huggingface": (None, ["HF_TOKEN", "HUGGINGFACE_API_KEY"]),
    "ollama": (None, []),
    "vllm": (None, []),
}
|
|
| |
# Maps a model-name prefix to a key of ``_PROVIDERS``. ``_parse_model_spec``
# uses it for names given without an explicit ``provider/`` part: entries are
# tried in insertion order with ``str.startswith`` and the first match wins.
_BARE_PREFIX_MAP: Dict[str, str] = {
    "gpt-": "openai",
    # The "o1"/"o3"/"o4" entries have no trailing dash, so any name that merely
    # starts with those two characters matches.
    "o1": "openai",
    "o3": "openai",
    "o4": "openai",
    "gemini-": "gemini",
    "claude-": "anthropic",
    "deepseek-": "deepseek",
    "mistral-": "mistral",
    "command-": "cohere",
}
|
|
|
|
def _parse_model_spec(model_str: str) -> tuple:
    """Split a model string into ``(provider, model_name, default_api_base, env_vars)``.

    Resolution order:

    1. ``provider/model`` (e.g. ``gemini/gemini-3-pro``): when the text before
       the first ``/`` is, case-insensitively, a key of ``_PROVIDERS``, that
       provider is used and the text after the slash is the model name.
    2. Bare names with a known prefix (e.g. ``gemini-3-pro`` -> ``gemini``):
       the first matching entry of ``_BARE_PREFIX_MAP`` selects the provider
       and the whole input string is kept as the model name.
    3. Anything else (including ``unknown/model``) defaults to ``openai`` with
       the whole input string as the model name.
    """
    head, slash, tail = model_str.partition("/")
    if slash:
        provider_key = head.lower()
        if provider_key in _PROVIDERS:
            base_url, key_vars = _PROVIDERS[provider_key]
            return provider_key, tail, base_url, key_vars

    # First prefix hit wins; no hit at all means the OpenAI fallback.
    provider_key = next(
        (
            mapped_provider
            for prefix, mapped_provider in _BARE_PREFIX_MAP.items()
            if model_str.startswith(prefix)
        ),
        "openai",
    )
    base_url, key_vars = _PROVIDERS[provider_key]
    return provider_key, model_str, base_url, key_vars
|
|
|
|
def _resolve_api_key_from_env(env_vars: Optional[List[str]] = None) -> Optional[str]:
    """Look up an API key in the environment.

    The names in *env_vars* (usually the last element returned by
    ``_parse_model_spec()``) are tried in order; the first one holding a
    non-empty value wins. When none of them yields a value, the result is
    whatever ``OPENAI_API_KEY`` holds, which may be ``None``.
    """
    candidate_values = (os.environ.get(name) for name in (env_vars or ()))
    first_hit = next((value for value in candidate_values if value), None)
    if first_hit is not None:
        return first_hit
    return os.environ.get("OPENAI_API_KEY")
|
|
|
|
def _expand_env_vars(text: str) -> str:
    """Substitute ``${VAR}`` references in *text* from the environment.

    Only the braced form is recognised. A reference to a variable that is not
    set is left in the text unchanged.
    """
    reference = re.compile(r"\$\{(\w+)\}")
    return reference.sub(
        lambda found: os.environ.get(found.group(1), found.group(0)),
        text,
    )
|
|
|
|
| |
| |
| |
|
|
|
|
@dataclass
class ContextBuilderConfig:
    """Configuration for prompt generation.

    ``Config.from_dict`` builds this object from the ``prompt`` section of the
    configuration.
    """

    # Template selection. NOTE(review): how ``template`` / ``template_dir`` are
    # resolved is decided by the prompt builder, which is not part of this module.
    template: str = "default"
    template_dir: Optional[str] = None

    # System prompts. ``Config.from_yaml`` replaces ``system_message`` with the
    # contents of a file when the configured value names an existing file
    # relative to the config directory; ``load_config`` then copies it onto
    # every model that has no ``system_message`` of its own.
    system_message: str = "system_message"
    evaluator_system_message: str = "evaluator_system_message"

    # NOTE(review): presumably a character-count threshold after which the
    # prompt suggests simplifying the solution -- confirm against the consumer.
    suggest_simplification_after_chars: Optional[int] = 500
|
|
|
|
| |
| |
| |
|
|
|
|
@dataclass
class LLMModelConfig:
    """Configuration for a single LLM model.

    Every field except ``weight`` defaults to ``None``.
    ``LLMConfig.update_model_params`` only fills attributes that are still
    ``None`` (unless asked to overwrite), so ``None`` here effectively means
    "inherit the shared setting".
    """

    # Endpoint, credential and model identifier. ``name`` may be given as
    # ``provider/model``; ``LLMConfig.__post_init__`` resolves the provider.
    api_base: Optional[str] = None
    api_key: Optional[str] = None
    name: Optional[str] = None

    # NOTE(review): presumably a factory for a custom client object -- this
    # module only stores it; confirm against the code that reads it.
    init_client: Optional[Callable] = None

    # NOTE(review): presumably the relative sampling weight among the models
    # of one list -- not used in this module; confirm.
    weight: float = 1.0

    # Generation parameters.
    system_message: Optional[str] = None
    temperature: Optional[float] = None
    top_p: Optional[float] = None
    max_tokens: Optional[int] = None

    # Request handling parameters (units are not established in this module).
    timeout: Optional[int] = None
    retries: Optional[int] = None
    retry_delay: Optional[int] = None

    # Passed through unchanged by this module.
    reasoning_effort: Optional[str] = None
|
|
|
|
@dataclass
class LLMConfig(LLMModelConfig):
    """Shared LLM settings plus the lists of models that inherit them.

    The scalar fields act as defaults: ``__post_init__`` copies the shared
    connection and sampling settings onto every model in ``models``,
    ``evaluator_models`` and ``guide_models`` whose corresponding attribute is
    still ``None``.
    """

    # Shared endpoint; the OpenAI URL doubles as the "user did not set one" marker.
    api_base: str = _PROVIDERS["openai"][0]

    # Shared generation parameters.
    system_message: Optional[str] = "system_message"
    temperature: Optional[float] = 0.7
    top_p: Optional[float] = None
    max_tokens: int = 32000

    # Shared request handling parameters.
    timeout: int = 600
    retries: int = 3
    retry_delay: int = 5

    # Models used for generation.
    models: List[LLMModelConfig] = field(default_factory=list)

    # Models used for evaluation; defaults to the entries of ``models``.
    evaluator_models: List[LLMModelConfig] = field(default_factory=list)

    # Guide models; defaults to the entries of ``models``.
    guide_models: List[LLMModelConfig] = field(default_factory=list)

    reasoning_effort: Optional[str] = None

    def __post_init__(self):
        """Resolve per-model providers, then push shared defaults onto the models.

        For every model that has a ``name`` but no ``api_base``:

        * the provider is derived from the name via ``_parse_model_spec``;
        * the provider's default endpoint becomes the model's ``api_base``,
          except when the provider is only the OpenAI *fallback* and the user
          configured a non-default shared ``api_base`` (that one wins);
        * a missing ``api_key`` is looked up in the provider's env vars;
        * a recognised non-OpenAI ``provider/`` prefix is stripped from ``name``.
        """
        # An empty list means "reuse the generation models". The copies are
        # shallow, so the lists then share the very same model objects.
        if not self.evaluator_models:
            self.evaluator_models = self.models.copy()
        if not self.guide_models:
            self.guide_models = self.models.copy()

        user_set_api_base = self.api_base.rstrip("/") != _PROVIDERS["openai"][0].rstrip("/")

        # Resolve each distinct model object exactly once. Because the lists may
        # share objects, a second pass would re-parse a name whose provider
        # prefix was already stripped: "ollama/deepseek-r1" becomes
        # "deepseek-r1", which would then match the "deepseek-" prefix and be
        # routed to the DeepSeek cloud endpoint instead of the user's api_base.
        seen_ids = set()
        for model in self.models + self.evaluator_models + self.guide_models:
            if id(model) in seen_ids:
                continue
            seen_ids.add(id(model))

            if not (model.name and model.api_base is None):
                continue

            provider, bare_name, provider_base, env_vars = _parse_model_spec(model.name)

            # True when "openai" was chosen only because nothing else matched
            # (no explicit "openai/" prefix and no known bare-name prefix).
            is_fallback = provider == "openai" and not (
                model.name.startswith("openai/")
                or any(model.name.startswith(p) for p in _BARE_PREFIX_MAP)
            )
            if provider_base and not (user_set_api_base and is_fallback):
                model.api_base = provider_base
            if model.api_key is None:
                model.api_key = _resolve_api_key_from_env(env_vars)

            if "/" in model.name and provider != "openai":
                model.name = bare_name

        # Whatever is still unset on a model inherits the shared value.
        shared_config = {
            "api_base": self.api_base,
            "api_key": self.api_key,
            "temperature": self.temperature,
            "top_p": self.top_p,
            "max_tokens": self.max_tokens,
            "timeout": self.timeout,
            "retries": self.retries,
            "retry_delay": self.retry_delay,
            "reasoning_effort": self.reasoning_effort,
        }
        self.update_model_params(shared_config)

    def update_model_params(self, args: Dict[str, Any], overwrite: bool = False) -> None:
        """Set attributes on every model in all three model lists.

        Args:
            args: Mapping of attribute name to the value to apply.
            overwrite: When ``False`` (default) only attributes whose current
                value is ``None`` (or that are missing) are set; when ``True``
                every listed attribute is replaced.
        """
        all_models = self.models + self.evaluator_models + self.guide_models
        for model in all_models:
            for key, value in args.items():
                if overwrite or getattr(model, key, None) is None:
                    setattr(model, key, value)
|
|
|
|
@dataclass
class AgenticConfig:
    """Configuration for agentic solution generation.

    When enabled, replaces the single-shot LLM call with a multi-turn
    tool-calling agent loop that can read files and search the codebase
    before outputting the discovered solution.

    ``apply_overrides(agentic=True)`` switches ``enabled`` on at runtime.
    The limits below are only stored here; their enforcement lives in the
    agent implementation, which is not part of this module.
    """

    enabled: bool = False
    codebase_root: Optional[str] = None

    # Step budget of the agent loop.
    max_steps: int = 5

    # Time budgets (presumably seconds -- confirm against the agent loop).
    per_step_timeout: float = 60.0
    overall_timeout: float = 300.0

    # Size limits for context, file reads and search output.
    max_context_chars: int = 400_000
    max_file_chars: int = 50_000
    max_search_results: int = 50
    max_files_read: int = 20

    # Limits applied to regex searches.
    regex_timeout: float = 2.0
    max_regex_length: int = 200

    # NOTE(review): presumably the directory depth of the repository map shown
    # to the agent -- confirm.
    repo_map_max_depth: int = 4

    # File extensions and directory names used to filter the codebase.
    # ``Config.from_dict`` converts YAML lists for these two fields to tuples.
    allowed_extensions: tuple = (
        ".py",
        ".txt",
        ".md",
        ".json",
        ".yaml",
        ".yml",
        ".toml",
        ".cfg",
        ".ini",
        ".js",
        ".ts",
        ".java",
        ".cpp",
        ".c",
        ".h",
        ".rs",
        ".go",
    )
    excluded_dirs: tuple = (
        ".git",
        "__pycache__",
        "node_modules",
        ".venv",
        "venv",
        ".mypy_cache",
        ".pytest_cache",
        "dist",
        "build",
    )
|
|
|
|
| |
| |
| |
|
|
|
|
@dataclass
class EvaluatorConfig:
    """Configuration for program evaluation.

    ``Config.from_dict`` builds this object from the ``evaluator`` section.
    The fields are only stored here; the evaluator that interprets them is not
    part of this module.
    """

    # Evaluation script and the kind of artefact being evaluated.
    evaluation_file: Optional[str] = None
    file_suffix: str = ".py"
    is_image_mode: bool = False

    # Per-evaluation time budget (presumably seconds -- confirm) and retry count.
    timeout: int = 360
    max_retries: int = 3

    # Cascade evaluation switch and its stage thresholds.
    cascade_evaluation: bool = True
    cascade_thresholds: List[float] = field(default_factory=lambda: [0.3, 0.6])

    # NOTE(review): presumably controls whether evaluator context is added to
    # the prompt -- confirm against the consumer.
    inject_evaluator_context: bool = False

    # NOTE(review): presumably enables scoring by an LLM judge -- confirm.
    llm_as_judge: bool = False
|
|
|
|
| |
| |
| |
|
|
|
|
@dataclass
class DatabaseConfig:
    """Base configuration shared by all database types.

    Used directly for the ``topk`` search type and as the fallback for search
    types that are not listed in ``_DB_CONFIG_BY_TYPE``.
    """

    # Optional location of the database on disk.
    db_path: Optional[str] = None
    # NOTE(review): presumably controls whether prompts are stored with each
    # entry -- confirm against the database implementation.
    log_prompts: bool = True
|
|
|
|
@dataclass
class EvolveDatabaseConfig(DatabaseConfig):
    """Read database from a file."""

    # Path of the file to read; ``None`` leaves it unset.
    database_file_path: Optional[str] = None
|
|
|
|
@dataclass
class EvoxDatabaseConfig(EvolveDatabaseConfig):
    """Evox (co-evolution) database config with built-in defaults.

    Any of ``database_file_path``, ``evaluation_file`` and ``config_path``
    left as ``None`` is replaced after construction by the path of a file
    under ``search/evox`` next to this module.
    """

    evaluation_file: Optional[str] = None
    config_path: Optional[str] = None
    auto_generate_variation_operators: bool = True

    # Unannotated on purpose: plain class attributes, not dataclass fields.
    _evox_config_dir = Path(__file__).parent / "search" / "evox" / "config"
    _evox_database_dir = Path(__file__).parent / "search" / "evox" / "database"

    def __post_init__(self):
        """Fill every unset path field with its bundled default."""
        bundled_defaults = (
            ("database_file_path", self._evox_database_dir / "initial_search_strategy.py"),
            ("evaluation_file", self._evox_database_dir / "search_strategy_evaluator.py"),
            ("config_path", self._evox_config_dir / "search.yaml"),
        )
        for attr_name, bundled_path in bundled_defaults:
            if getattr(self, attr_name) is None:
                setattr(self, attr_name, str(bundled_path))
|
|
|
|
@dataclass
class BeamSearchDatabaseConfig(DatabaseConfig):
    """Beam search database config.

    Selected by ``_DB_CONFIG_BY_TYPE`` for the ``beam_search`` search type.
    The values are only stored here; the search implementation that
    interprets them is not part of this module.
    """

    beam_width: int = 5
    # Name of the selection strategy; the accepted values are defined by the
    # beam search implementation.
    beam_selection_strategy: str = "diversity_weighted"
    beam_diversity_weight: float = 0.3
    beam_temperature: float = 1.0
    beam_depth_penalty: float = 0.0
|
|
|
|
@dataclass
class BestOfNDatabaseConfig(DatabaseConfig):
    """Best-of-N database config.

    Selected by ``_DB_CONFIG_BY_TYPE`` for the ``best_of_n`` search type.
    """

    # The "N" of best-of-N (presumably candidates per round -- confirm).
    best_of_n: int = 5
|
|
|
|
@dataclass
class AdaEvolveDatabaseConfig(DatabaseConfig):
    """AdaEvolve adaptive multi-island database config.

    Selected by ``_DB_CONFIG_BY_TYPE`` for the ``adaevolve`` search type.
    This class only declares defaults; the meaning of each knob is defined by
    the AdaEvolve search implementation, which is not part of this module.
    """

    # Population / island layout and adaptive search intensity.
    population_size: int = 20
    num_islands: int = 2
    decay: float = 0.9
    intensity_min: float = 0.15
    intensity_max: float = 0.5
    # Feature switches.
    use_adaptive_search: bool = True
    use_ucb_selection: bool = True
    use_migration: bool = True
    use_unified_archive: bool = True
    # NOTE(review): presumably the intensity used when adaptive search is
    # disabled -- confirm.
    fixed_intensity: float = 0.4
    # Migration between islands.
    migration_interval: int = 15
    migration_count: int = 5
    # Context composition and selection weights.
    local_context_program_ratio: float = 0.6
    archive_elite_ratio: float = 0.2
    pareto_weight: float = 0.4
    fitness_weight: float = 1.0
    novelty_weight: float = 0.0
    k_neighbors: int = 5
    diversity_strategy: str = "code"
    # Dynamic island spawning.
    use_dynamic_islands: bool = True
    max_islands: int = 5
    spawn_productivity_threshold: float = 0.015
    spawn_cooldown_iterations: int = 30
    # Paradigm-breakthrough settings.
    use_paradigm_breakthrough: bool = True
    paradigm_window_size: int = 10
    paradigm_improvement_threshold: float = 0.12
    paradigm_max_uses: int = 2
    paradigm_num_to_generate: int = 3
    paradigm_max_tried: int = 10

    # Stagnation handling.
    stagnation_threshold: int = 10
    stagnation_multi_child_count: int = 3

    # Limit on sibling programs included as context.
    sibling_context_limit: int = 5

    # Retry behaviour after errors.
    enable_error_retry: bool = True
    max_error_retries: int = 2

    # Archive capacity.
    archive_size: int = 100

    # Metric handling. NOTE(review): ``higher_is_better`` presumably maps a
    # metric name to its optimisation direction -- confirm.
    higher_is_better: Dict[str, bool] = field(default_factory=dict)
    fitness_key: Optional[str] = None
    pareto_objectives: List[str] = field(default_factory=list)
    pareto_objectives_weight: float = 0.0
|
|
|
|
@dataclass
class OpenEvolveNativeDatabaseConfig(DatabaseConfig):
    """OpenEvolve Native: MAP-Elites + island-based search config.

    Selected by ``_DB_CONFIG_BY_TYPE`` for the ``openevolve_native`` search
    type. The values are only stored here; the search implementation that
    interprets them is not part of this module.
    """

    # Island / population layout.
    num_islands: int = 5
    population_size: int = 40
    archive_size: int = 100
    # Selection ratios; the three defaults sum to 1.0.
    exploration_ratio: float = 0.2
    exploitation_ratio: float = 0.7
    elite_selection_ratio: float = 0.1
    # Feature map settings.
    feature_dimensions: List[str] = field(default_factory=lambda: ["complexity", "diversity"])
    feature_bins: int = 10
    diversity_reference_size: int = 20
    # Migration between islands.
    migration_interval: int = 10
    migration_rate: float = 0.1
    # Seed for reproducibility (presumably ``None`` means unseeded -- confirm).
    random_seed: Optional[int] = 42
|
|
|
|
@dataclass
class ClaudeCodeConfig(DatabaseConfig):
    """Configuration for the Claude Code baseline.

    Claude Code runs autonomously inside a Docker container, iterating on
    the solution using the evaluator directly. max_turns maps to the
    --max-turns flag passed to the claude CLI.

    Selected by ``_DB_CONFIG_BY_TYPE`` for the ``claude_code`` search type.
    """

    # Value for the CLI's --max-turns flag.
    max_turns: int = 50
    # Docker image the baseline runs in.
    docker_image: str = "skydiscover-claude-code:latest"
|
|
|
|
@dataclass
class GEPANativeDatabaseConfig(DatabaseConfig):
    """Configuration for GEPA Native search database.

    GEPA (Guided Evolution for Program Adaptation) uses an elite pool with
    epsilon-greedy selection, acceptance gating, and LLM-mediated merge.

    Selected by ``_DB_CONFIG_BY_TYPE`` for the ``gepa_native`` search type.
    """

    # Pool size and candidate selection.
    population_size: int = 40
    candidate_selection_strategy: str = "epsilon_greedy"
    epsilon: float = 0.1
    max_rejection_history: int = 20

    # Acceptance gating and merge behaviour; exact semantics are defined by the
    # GEPA search implementation, which is not part of this module.
    acceptance_gating: bool = True
    use_merge: bool = True
    merge_after_stagnation: int = 15
    max_merge_attempts: int = 10
    max_recent_failures: int = 5
    # Seed for reproducibility (presumably ``None`` means unseeded -- confirm).
    random_seed: Optional[int] = 42
|
|
|
|
# Maps a ``search.type`` value to the database config class built for it.
# ``Config.from_dict`` falls back to ``DatabaseConfig`` for types not listed
# here; ``apply_overrides`` leaves the existing database config untouched for
# such types.
_DB_CONFIG_BY_TYPE: Dict[str, type] = {
    "evox": EvoxDatabaseConfig,
    "beam_search": BeamSearchDatabaseConfig,
    "best_of_n": BestOfNDatabaseConfig,
    "topk": DatabaseConfig,
    "adaevolve": AdaEvolveDatabaseConfig,
    "openevolve_native": OpenEvolveNativeDatabaseConfig,
    "gepa_native": GEPANativeDatabaseConfig,
    "claude_code": ClaudeCodeConfig,
}
|
|
|
|
@dataclass
class SearchConfig:
    """General Configuration for All Search Algorithms"""

    # Search algorithm name; also selects the database config class through
    # ``_DB_CONFIG_BY_TYPE``.
    type: str = "topk"
    # Algorithm-specific database settings (a ``DatabaseConfig`` subclass).
    database: DatabaseConfig = field(default_factory=DatabaseConfig)
    num_context_programs: int = 4
    output_dir: Optional[str] = None
    # NOTE(review): semantics of the next two fields are not established in
    # this module -- confirm against the search controller.
    switch_interval: Optional[int] = (
        None
    )
    share_llm: bool = (
        False
    )
|
|
|
|
| |
| |
| |
|
|
|
|
@dataclass
class MonitorConfig:
    """Configuration for the live run monitor dashboard"""

    # Dashboard server settings.
    enabled: bool = False
    port: int = 8765
    host: str = "127.0.0.1"
    max_solution_length: int = 10000

    # Settings for the model that produces summaries. NOTE(review): the
    # meaning of ``summary_interval == 0`` is not established in this module
    # (presumably "disabled") -- confirm against the monitor.
    summary_model: str = "gpt-5-mini"
    summary_api_key: Optional[str] = None
    summary_api_base: str = _PROVIDERS["openai"][0]
    summary_top_k: int = 3
    summary_interval: int = 0
|
|
|
|
| |
| |
| |
|
|
|
|
@dataclass
class BenchmarkConfig:
    """Configuration for loading problems from external benchmark datasets.

    When enabled, allows SkyDiscover to fetch problems from external
    benchmark datasets (e.g., KernelBench, Frontier-CS) without requiring
    explicit initial_program paths.

    Benchmark specification and evaluation parameters (e.g., target problem)
    are stored in a `params` dictionary.
    """

    enabled: bool = False
    name: Optional[str] = None
    # NOTE(review): presumably identifies the component that resolves the
    # benchmark into a problem -- confirm against the benchmark loader.
    resolver: Optional[str] = (
        None
    )
    # ``Config.from_dict`` collects every key of the ``benchmark`` section that
    # is not one of the fields above into this dictionary.
    params: Dict[str, Any] = field(default_factory=dict)
|
|
|
|
| |
| |
| |
|
|
|
|
@dataclass
class Config:
    """Master configuration for SkyDiscover.

    Groups the run-level settings with one sub-config per subsystem and
    provides conversion from/to YAML files and plain dictionaries.
    """

    # General run settings.
    max_iterations: int = 100
    checkpoint_interval: int = 10
    log_level: str = "INFO"
    log_dir: Optional[str] = None
    language: Optional[str] = None
    file_suffix: str = ".py"

    # Component configurations.
    llm: LLMConfig = field(default_factory=LLMConfig)
    context_builder: ContextBuilderConfig = field(default_factory=ContextBuilderConfig)
    search: SearchConfig = field(default_factory=SearchConfig)
    evaluator: EvaluatorConfig = field(default_factory=EvaluatorConfig)
    agentic: AgenticConfig = field(default_factory=AgenticConfig)
    benchmark: BenchmarkConfig = field(default_factory=BenchmarkConfig)

    # Live monitor dashboard.
    monitor: MonitorConfig = field(default_factory=MonitorConfig)

    # Human feedback settings.
    human_feedback_enabled: bool = False
    human_feedback_file: Optional[str] = None
    human_feedback_mode: str = "append"

    # Solution generation settings.
    diff_based_generation: bool = True
    max_solution_length: int = 60000

    # Number of iterations run in parallel.
    max_parallel_iterations: int = 1

    # Set by ``apply_overrides`` when a system prompt is passed at runtime.
    system_prompt_override: Optional[str] = None

    @classmethod
    def from_yaml(cls, path: Union[str, Path]) -> Config:
        """Load configuration from a YAML file.

        ``${VAR}`` references are expanded from the environment before the
        text is parsed. If ``prompt.system_message`` is a short single-line
        string naming an existing file relative to the config file's
        directory, it is replaced by that file's contents.

        Raises:
            ValueError: If the top level of the YAML document is not a mapping.
        """
        config_path = Path(path)
        config_dir = config_path.parent

        with open(path, "r") as f:
            raw = f.read()
        # An empty or comment-only document parses to None; treat it as a
        # configuration without overrides instead of failing further down.
        config_dict = yaml.safe_load(_expand_env_vars(raw)) or {}
        if not isinstance(config_dict, dict):
            raise ValueError(
                f"Top level of config file must be a mapping, got {type(config_dict).__name__}"
            )

        prompt_section = config_dict.get("prompt")
        if isinstance(prompt_section, dict) and "system_message" in prompt_section:
            system_message = prompt_section["system_message"]
            # Only short single-line values can plausibly be a file name.
            if (
                isinstance(system_message, str)
                and "\n" not in system_message.strip()
                and len(system_message.strip()) < 256
            ):
                file_path = config_dir / system_message
                try:
                    if file_path.exists() and file_path.is_file():
                        with open(file_path, "r") as f:
                            prompt_section["system_message"] = f.read()
                except OSError:
                    # Best effort: keep the literal string as the system message.
                    logger.debug("Could not read system_message from %s", file_path, exc_info=True)

        return cls.from_dict(config_dict)

    def to_yaml(self, path: Union[str, Path]) -> None:
        """Save configuration to a YAML file (the output of ``to_dict``)."""
        with open(path, "w") as f:
            yaml.dump(self.to_dict(), f, default_flow_style=False)

    @classmethod
    def from_dict(cls, config_dict: Dict[str, Any]) -> Config:
        """Create configuration from a dictionary.

        Top-level keys matching an attribute are assigned directly; the
        ``llm``, ``prompt``, ``search``, ``evaluator``, ``agentic``,
        ``benchmark`` and ``monitor`` sections are turned into their config
        objects. Unknown top-level keys are ignored. A section that is present
        but empty (``None``) is treated as an empty mapping.

        *config_dict* and its nested sections are not modified, so the same
        dictionary can be passed more than once.
        """
        config = cls()

        section_keys = {
            "llm",
            "prompt",
            "database",
            "search",
            "evaluator",
            "agentic",
            "benchmark",
            "monitor",
        }
        for key, value in config_dict.items():
            if key not in section_keys and hasattr(config, key):
                setattr(config, key, value)

        if "llm" in config_dict:
            # Work on a copy so the caller's dictionary keeps its plain entries.
            llm_dict = dict(config_dict["llm"] or {})
            for list_key in ("models", "evaluator_models", "guide_models"):
                if list_key in llm_dict:
                    # Already-built LLMModelConfig objects (as produced by
                    # to_dict) are accepted unchanged.
                    llm_dict[list_key] = [
                        m if isinstance(m, LLMModelConfig) else LLMModelConfig(**m)
                        for m in (llm_dict[list_key] or [])
                    ]
            config.llm = LLMConfig(**llm_dict)
        if "prompt" in config_dict:
            config.context_builder = ContextBuilderConfig(**(config_dict["prompt"] or {}))

        if "search" in config_dict:
            search_dict = dict(config_dict["search"] or {})
            search_type = search_dict.get("type", "topk")
            db_config_cls = _DB_CONFIG_BY_TYPE.get(search_type, DatabaseConfig)
            db_dict = search_dict.get("database") or {}
            # Declared fields go through the constructor; any other key is
            # attached as a plain attribute so custom settings are not lost.
            known_fields = {f.name for f in fields(db_config_cls)}
            db_known = {k: v for k, v in db_dict.items() if k in known_fields}
            db_extras = {k: v for k, v in db_dict.items() if k not in known_fields}
            db_config = db_config_cls(**db_known)
            for k, v in db_extras.items():
                setattr(db_config, k, v)
            search_dict["database"] = db_config
            config.search = SearchConfig(**search_dict)

        if "evaluator" in config_dict:
            config.evaluator = EvaluatorConfig(**(config_dict["evaluator"] or {}))
        if "agentic" in config_dict:
            agentic_dict = dict(config_dict["agentic"] or {})
            # YAML has no tuple type, so lists are converted for the tuple fields.
            for tuple_field in ("allowed_extensions", "excluded_dirs"):
                if tuple_field in agentic_dict and isinstance(agentic_dict[tuple_field], list):
                    agentic_dict[tuple_field] = tuple(agentic_dict[tuple_field])
            config.agentic = AgenticConfig(**agentic_dict)
        if "benchmark" in config_dict:
            benchmark_dict = config_dict["benchmark"] or {}
            # Everything that is not a declared field (``params`` itself is
            # deliberately excluded from the declared set) is collected into
            # ``params``.
            known_fields = {f.name for f in fields(BenchmarkConfig) if f.name != "params"}
            benchmark_known = {k: v for k, v in benchmark_dict.items() if k in known_fields}
            benchmark_params = {k: v for k, v in benchmark_dict.items() if k not in known_fields}
            config.benchmark = BenchmarkConfig(**benchmark_known, params=benchmark_params)
        if "monitor" in config_dict:
            config.monitor = MonitorConfig(**(config_dict["monitor"] or {}))

        return config

    def to_dict(self) -> Dict[str, Any]:
        """Convert configuration to a dictionary.

        The ``llm`` model lists contain the ``LLMModelConfig`` objects
        themselves, not plain dictionaries. Not every field is included; for
        example ``benchmark``, ``llm.guide_models`` and ``llm.api_key`` are
        omitted.
        """
        return {
            # General settings
            "max_iterations": self.max_iterations,
            "checkpoint_interval": self.checkpoint_interval,
            "log_level": self.log_level,
            "log_dir": self.log_dir,
            # Component configurations
            "llm": {
                "models": self.llm.models,
                "evaluator_models": self.llm.evaluator_models,
                "api_base": self.llm.api_base,
                "temperature": self.llm.temperature,
                "top_p": self.llm.top_p,
                "max_tokens": self.llm.max_tokens,
                "timeout": self.llm.timeout,
                "retries": self.llm.retries,
                "retry_delay": self.llm.retry_delay,
            },
            "prompt": {
                "template": self.context_builder.template,
                "template_dir": self.context_builder.template_dir,
                "system_message": self.context_builder.system_message,
                "evaluator_system_message": self.context_builder.evaluator_system_message,
            },
            "search": {
                "type": self.search.type,
                "num_context_programs": self.search.num_context_programs,
                "database": {
                    f.name: getattr(self.search.database, f.name)
                    for f in fields(self.search.database)
                },
            },
            "evaluator": {
                "evaluation_file": self.evaluator.evaluation_file,
                "file_suffix": self.evaluator.file_suffix,
                "is_image_mode": self.evaluator.is_image_mode,
                "timeout": self.evaluator.timeout,
                "max_retries": self.evaluator.max_retries,
                "cascade_evaluation": self.evaluator.cascade_evaluation,
                "cascade_thresholds": self.evaluator.cascade_thresholds,
                "inject_evaluator_context": self.evaluator.inject_evaluator_context,
                "llm_as_judge": self.evaluator.llm_as_judge,
            },
            # Agentic generation (tuples are emitted as lists)
            "agentic": {
                "enabled": self.agentic.enabled,
                "codebase_root": self.agentic.codebase_root,
                "max_steps": self.agentic.max_steps,
                "per_step_timeout": self.agentic.per_step_timeout,
                "overall_timeout": self.agentic.overall_timeout,
                "max_context_chars": self.agentic.max_context_chars,
                "max_file_chars": self.agentic.max_file_chars,
                "max_search_results": self.agentic.max_search_results,
                "max_files_read": self.agentic.max_files_read,
                "regex_timeout": self.agentic.regex_timeout,
                "max_regex_length": self.agentic.max_regex_length,
                "repo_map_max_depth": self.agentic.repo_map_max_depth,
                "allowed_extensions": list(self.agentic.allowed_extensions),
                "excluded_dirs": list(self.agentic.excluded_dirs),
            },
            # Monitor dashboard
            "monitor": {
                "enabled": self.monitor.enabled,
                "port": self.monitor.port,
                "host": self.monitor.host,
                "max_solution_length": self.monitor.max_solution_length,
                "summary_model": self.monitor.summary_model,
                "summary_top_k": self.monitor.summary_top_k,
                "summary_interval": self.monitor.summary_interval,
            },
            # Human feedback
            "human_feedback_enabled": self.human_feedback_enabled,
            "human_feedback_file": self.human_feedback_file,
            # Solution generation
            "diff_based_generation": self.diff_based_generation,
            "max_solution_length": self.max_solution_length,
            # Parallelism
            "max_parallel_iterations": self.max_parallel_iterations,
        }
|
|
|
|
def load_config(config_path: Optional[Union[str, Path]] = None) -> Config:
    """Build a ``Config`` from a YAML file, or from defaults when no path is given.

    After loading, in this order:

    1. ``OPENAI_API_BASE`` / ``OPENAI_BASE_URL`` from the environment, if set,
       replaces the shared ``api_base`` and that of every model.
    2. If no shared API key is configured, one is looked up in the environment
       (using the first model's provider when there is one) and applied to
       the models that still lack a key.
    3. The context builder's system message is copied onto every model that
       has none.
    4. ``bridge_provider_env`` exports the resolved credentials.

    NOTE(review): step 4 may set ``OPENAI_API_BASE`` in ``os.environ``; a later
    ``load_config`` call in the same process would then pick that value up in
    step 1 -- confirm this is intended.

    Raises:
        FileNotFoundError: If *config_path* is given but does not exist.
    """
    if not config_path:
        config = Config()
    elif os.path.exists(config_path):
        config = Config.from_yaml(config_path)
    else:
        raise FileNotFoundError(f"Config file not found: {config_path}")

    llm = config.llm

    env_api_base = os.environ.get("OPENAI_API_BASE") or os.environ.get("OPENAI_BASE_URL")
    if env_api_base:
        llm.api_base = env_api_base
        llm.update_model_params({"api_base": env_api_base}, overwrite=True)

    if not llm.api_key:
        # Prefer the env vars of the first model's provider, when known.
        primary_name = llm.models[0].name if llm.models else None
        key_vars = _parse_model_spec(primary_name)[3] if primary_name else None
        env_key = _resolve_api_key_from_env(key_vars)
        if env_key:
            llm.api_key = env_key
            llm.update_model_params({"api_key": env_key})

    llm.update_model_params({"system_message": config.context_builder.system_message})

    bridge_provider_env(config)

    return config
|
|
|
|
def bridge_provider_env(config: Config) -> None:
    """
    Export the first model's resolved credentials as environment variables.

    External backends read credentials from environment variables directly.

    Nothing happens when there are no models or the first model has no API
    key. Otherwise the key is stored under each of the provider's env var
    names and under ``OPENAI_API_KEY``, and the model's ``api_base`` (if any)
    under ``OPENAI_API_BASE``. Variables that are already set are left as-is.
    """
    primary = config.llm.models[0] if config.llm.models else None
    if primary is None or not primary.api_key:
        return

    provider_vars = _parse_model_spec(primary.name or "")[3]
    for var_name in [*provider_vars, "OPENAI_API_KEY"]:
        os.environ.setdefault(var_name, primary.api_key)

    if primary.api_base:
        os.environ.setdefault("OPENAI_API_BASE", primary.api_base)
|
|
|
|
def build_output_dir(search_type: str, initial_program_path: str, base_dir: str = "outputs") -> str:
    """Return ``<base_dir>/<search_type>/<problem_name>_<MMDD_HHMM>``.

    ``problem_name`` is the name of the directory containing
    *initial_program_path* (``"unknown"`` when that name is empty); the
    timestamp is taken from the local clock at call time.
    """
    from datetime import datetime

    parent_dir = os.path.dirname(os.path.abspath(initial_program_path))
    problem_name = os.path.basename(parent_dir)
    if not problem_name:
        problem_name = "unknown"

    stamp = datetime.now().strftime("%m%d_%H%M")
    run_dir = f"{problem_name}_{stamp}"
    return os.path.join(base_dir, search_type, run_dir)
|
|
|
|
| |
| |
| |
|
|
|
|
def apply_overrides(
    config: Config,
    *,
    model: Optional[str] = None,
    api_base: Optional[str] = None,
    agentic: bool = False,
    search: Optional[str] = None,
    system_prompt: Optional[str] = None,
) -> None:
    """Apply runtime overrides (model, api_base, etc.) to a loaded Config in place.

    Args:
        config: Configuration to modify.
        model: One model spec or several separated by commas (see
            ``_parse_model_spec`` for the accepted forms). Replaces the
            generation, evaluator and guide model lists. Blank entries, such
            as the one produced by a trailing comma, are ignored.
        api_base: Endpoint override. Together with ``model`` it is used for
            every listed model; alone it replaces the endpoint of the
            existing models and re-resolves the API key from the environment.
        agentic: When true, enables agentic generation.
        search: Search type to switch to; the database config is replaced by
            a default instance when its class does not fit the new type.
        system_prompt: Replaces the context builder's system message and is
            recorded in ``config.system_prompt_override``.

    Raises:
        ValueError: If ``model`` contains no model name, or a model's provider
            has no default endpoint and ``api_base`` was not given.
    """
    if model:
        # Drop blank entries so "a,b," or "a, ,b" do not create nameless models.
        specs = [spec for spec in (part.strip() for part in model.split(",")) if spec]
        if not specs:
            raise ValueError(f"No model names found in model override: {model!r}")

        models: List[LLMModelConfig] = []
        for spec in specs:
            provider, model_name, default_api_base, env_vars = _parse_model_spec(spec)
            effective_base = api_base or default_api_base
            if effective_base is None:
                raise ValueError(
                    f"Provider '{provider}' requires an explicit api_base.\n"
                    f"Example: model='{spec}', api_base='http://localhost:8000/v1'"
                )
            resolved_key = _resolve_api_key_from_env(env_vars)
            models.append(
                LLMModelConfig(
                    name=model_name,
                    api_base=effective_base,
                    api_key=resolved_key,
                )
            )

        # The first model defines the shared endpoint and key.
        config.llm.api_base = models[0].api_base
        if models[0].api_key:
            config.llm.api_key = models[0].api_key
        config.llm.models = models
        # Evaluator and guide lists get their own objects, not shared references.
        config.llm.evaluator_models = [
            LLMModelConfig(name=m.name, api_base=m.api_base, api_key=m.api_key) for m in models
        ]
        config.llm.guide_models = [
            LLMModelConfig(name=m.name, api_base=m.api_base, api_key=m.api_key) for m in models
        ]
    elif api_base:
        config.llm.api_base = api_base
        config.llm.update_model_params({"api_base": api_base}, overwrite=True)

    # With only the endpoint overridden, pick the key that belongs to it: the
    # first provider whose default URL is a prefix of the new endpoint supplies
    # the env var names (OPENAI_API_KEY remains the fallback).
    if not model and api_base:
        parsed_env_vars: Optional[List[str]] = None
        for _prefix, (base_url, env_list) in _PROVIDERS.items():
            if base_url and config.llm.api_base.startswith(base_url.rstrip("/")):
                parsed_env_vars = env_list
                break
        resolved_key = _resolve_api_key_from_env(parsed_env_vars)
        if resolved_key:
            config.llm.api_key = resolved_key
            config.llm.update_model_params({"api_key": resolved_key}, overwrite=True)

    # Re-apply the shared settings so new or changed models carry them.
    if model or api_base:
        config.llm.update_model_params(
            {
                "temperature": config.llm.temperature,
                "top_p": config.llm.top_p,
                "max_tokens": config.llm.max_tokens,
                "timeout": config.llm.timeout,
                "retries": config.llm.retries,
                "retry_delay": config.llm.retry_delay,
                "reasoning_effort": config.llm.reasoning_effort,
            },
            overwrite=True,
        )
        # Endpoint and key only fill gaps; per-model values set above win.
        config.llm.update_model_params(
            {"api_base": config.llm.api_base, "api_key": config.llm.api_key},
            overwrite=False,
        )

    if agentic:
        config.agentic.enabled = True

    if search:
        if not hasattr(config, "search"):
            config.search = SearchConfig()
        config.search.type = search
        # Unknown search types keep whatever database config is present.
        new_db_cls = _DB_CONFIG_BY_TYPE.get(search)
        if new_db_cls and not isinstance(config.search.database, new_db_cls):
            config.search.database = new_db_cls()

    if system_prompt:
        config.context_builder.system_message = system_prompt
        config.system_prompt_override = system_prompt
|
|