""" Loads training configuration from config.yaml. Single source of truth for all training parameters. CLI arguments override values from the YAML file. """ from __future__ import annotations import os from pathlib import Path from typing import Any import yaml from layer0.reward import RewardConfig from layer2.environment import EnvConfig _ROOT = Path(__file__).resolve().parent _DEFAULT_CONFIG_PATH = _ROOT / "config.yaml" def load_config(config_path: str | Path | None = None) -> dict[str, Any]: """Load the raw YAML config as a dict.""" path = Path(config_path) if config_path else _DEFAULT_CONFIG_PATH if not path.exists(): raise FileNotFoundError(f"Config file not found: {path}") with open(path) as f: return yaml.safe_load(f) def make_grpo_config(cfg: dict[str, Any]): """Build a GRPOConfig from the loaded YAML dict.""" # Import here to avoid circular imports from layer1.grpo_trainer import GRPOConfig grpo = cfg.get("grpo", {}) env = cfg.get("environment", {}) paths = cfg.get("paths", {}) gen = cfg.get("generation", {}) return GRPOConfig( model_name=grpo.get("model_name", "unsloth/Qwen2.5-3B-Instruct"), lora_r=grpo.get("lora_r", 16), lora_alpha=grpo.get("lora_alpha", 16), lora_dropout=grpo.get("lora_dropout", 0.0), num_candidates=grpo.get("num_candidates", 2), episodes_per_candidate=grpo.get("episodes_per_candidate", 3), num_training_steps=grpo.get("num_training_steps", 5), learning_rate=grpo.get("learning_rate", 5e-5), max_prompt_length=grpo.get("max_prompt_length", 2048), max_seq_length=gen.get("max_seq_length", 4096), prompt_max_new_tokens=gen.get("prompt_max_new_tokens", 2048), prompt_temperature=gen.get("prompt_temperature", 0.3), per_device_train_batch_size=grpo.get("per_device_train_batch_size", 1), gradient_accumulation_steps=grpo.get("gradient_accumulation_steps", 4), logging_steps=grpo.get("logging_steps", 1), save_steps=grpo.get("save_steps", 10), sft_warm_start=grpo.get("sft_warm_start", True), sft_epochs=grpo.get("sft_epochs", 2), sft_lr=grpo.get("sft_lr", 1e-4), domain=env.get("domain", "banking"), intents=env.get("intents", ["transfer", "check_balance", "block_card"]), output_dir=paths.get("output_dir", "./grpo_output"), ) def make_env_config(cfg: dict[str, Any]) -> EnvConfig: """Build an EnvConfig from the loaded YAML dict.""" env = cfg.get("environment", {}) reward = cfg.get("reward", {}) reward_config = RewardConfig( intent_correct_bonus=reward.get("intent_correct_bonus", 50.0), intent_wrong_penalty=reward.get("intent_wrong_penalty", -50.0), fast_bonus=reward.get("fast_bonus", 20.0), medium_bonus=reward.get("medium_bonus", 10.0), slow_penalty_per_turn=reward.get("slow_penalty_per_turn", -5.0), injection_caught_bonus=reward.get("injection_caught_bonus", 40.0), injection_succeeded_penalty=reward.get("injection_succeeded_penalty", -100.0), api_correct_bonus=reward.get("api_correct_bonus", 20.0), api_wrong_penalty=reward.get("api_wrong_penalty", -30.0), helpfulness_bonus=reward.get("helpfulness_bonus", 15.0), prompt_length_threshold=reward.get("prompt_length_threshold", 300), prompt_length_penalty_per_token=reward.get("prompt_length_penalty_per_token", -0.1), no_intent_penalty=reward.get("no_intent_penalty", -20.0), ) return EnvConfig( domain=env.get("domain", "banking"), intents=env.get("intents", ["transfer", "check_balance", "block_card"]), max_turns=env.get("max_turns", 10), reward_config=reward_config, ) def get_report_config(cfg: dict[str, Any]) -> dict[str, Any]: """Extract report settings from config.""" report = cfg.get("report", {}) return { "enabled": report.get("enabled", True), "output_dir": report.get("output_dir", "./reports"), "eval_episodes": report.get("eval_episodes", 5), "example_customers": report.get("example_customers", 3), } def get_paths(cfg: dict[str, Any]) -> dict[str, str]: """Extract path settings from config.""" paths = cfg.get("paths", {}) return { "output_dir": paths.get("output_dir", "./grpo_output"), "log_dir": paths.get("log_dir", "./logs"), } def get_generation_config(cfg: dict[str, Any]) -> dict[str, Any]: """Extract generation/inference settings from config.""" gen = cfg.get("generation", {}) return { "inference_backend": gen.get("inference_backend", "auto"), "max_seq_length": gen.get("max_seq_length", 4096), "prompt_max_new_tokens": gen.get("prompt_max_new_tokens", 2048), "prompt_temperature": gen.get("prompt_temperature", 0.3), "agent_max_tokens": gen.get("agent_max_tokens", 300), "agent_temperature": gen.get("agent_temperature", 0.3), "customer_max_tokens": gen.get("customer_max_tokens", 200), "customer_temperature": gen.get("customer_temperature", 0.7), } def get_upload_config(cfg: dict[str, Any]) -> dict[str, Any]: """Extract Supabase upload settings from config.""" upload = cfg.get("upload", {}) return { "enabled": upload.get("enabled", False), "bucket": upload.get("bucket", "training-results"), } def get_personas_config(cfg: dict[str, Any]) -> dict[str, Any]: """Extract persona settings from config.""" personas = cfg.get("personas", {}) return { "count": personas.get("count", 100), }