File size: 5,686 Bytes
4e2b74e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3dc48b7
4e2b74e
 
 
 
 
 
3dc48b7
 
 
4e2b74e
552e492
 
 
3dc48b7
4e2b74e
 
 
 
c2dc160
 
 
4e2b74e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bd8220a
 
 
 
4e2b74e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3dc48b7
 
 
 
 
 
10418d0
552e492
 
3dc48b7
 
 
 
 
 
 
 
28bcb40
 
 
 
 
 
 
 
 
3dc48b7
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
"""
Loads training configuration from config.yaml.

Single source of truth for all training parameters.
CLI arguments override values from the YAML file.
"""

from __future__ import annotations

import os
from pathlib import Path
from typing import Any

import yaml

from layer0.reward import RewardConfig
from layer2.environment import EnvConfig


# Absolute directory of this module. Resolving __file__ makes the default
# config location independent of the current working directory.
_ROOT = Path(__file__).resolve().parents[0]
# config.yaml that sits next to this module; used when no explicit path is given.
_DEFAULT_CONFIG_PATH = _ROOT.joinpath("config.yaml")


def load_config(config_path: str | Path | None = None) -> dict[str, Any]:
    """Load the raw YAML config as a dict.

    Args:
        config_path: Path to a YAML config file. When falsy (``None`` or an
            empty string), the ``config.yaml`` next to this module is used.

    Returns:
        The parsed top-level mapping. An empty YAML file yields ``{}`` so
        callers can always use dict methods such as ``.get()`` on the result.

    Raises:
        FileNotFoundError: If the config file does not exist.
        ValueError: If the YAML document's top level is not a mapping.
    """
    path = Path(config_path) if config_path else _DEFAULT_CONFIG_PATH
    if not path.exists():
        raise FileNotFoundError(f"Config file not found: {path}")
    # Decode explicitly as UTF-8 so parsing does not depend on the
    # platform's locale encoding (e.g. cp1252 on Windows).
    with open(path, encoding="utf-8") as f:
        data = yaml.safe_load(f)
    # safe_load returns None for an empty document; normalize to an empty
    # mapping so the section getters in this module keep working.
    if data is None:
        return {}
    if not isinstance(data, dict):
        raise ValueError(
            f"Config file {path} must contain a mapping at the top level, "
            f"got {type(data).__name__}"
        )
    return data


def make_grpo_config(cfg: dict[str, Any]):
    """Build a GRPOConfig from the loaded YAML dict.

    Reads the ``grpo``, ``generation``, ``environment`` and ``paths``
    sections of *cfg*. Any key missing from a section falls back to the
    default written inline below. A section that is present but empty in the
    YAML (which parses to ``None``) is treated the same as a missing section.

    Args:
        cfg: The mapping returned by ``load_config``.

    Returns:
        A ``layer1.grpo_trainer.GRPOConfig`` instance.
    """
    # Import here to avoid circular imports
    from layer1.grpo_trainer import GRPOConfig

    # ``or {}`` rather than a .get default: a section key with no children
    # (e.g. every entry commented out) loads as None, which has no .get().
    grpo = cfg.get("grpo") or {}
    env = cfg.get("environment") or {}
    paths = cfg.get("paths") or {}
    gen = cfg.get("generation") or {}

    return GRPOConfig(
        model_name=grpo.get("model_name", "unsloth/Qwen2.5-3B-Instruct"),
        lora_r=grpo.get("lora_r", 16),
        lora_alpha=grpo.get("lora_alpha", 16),
        lora_dropout=grpo.get("lora_dropout", 0.0),
        num_candidates=grpo.get("num_candidates", 2),
        episodes_per_candidate=grpo.get("episodes_per_candidate", 3),
        num_training_steps=grpo.get("num_training_steps", 5),
        learning_rate=grpo.get("learning_rate", 5e-5),
        max_prompt_length=grpo.get("max_prompt_length", 2048),
        # Sequence/generation limits live in the shared ``generation`` section.
        max_seq_length=gen.get("max_seq_length", 4096),
        prompt_max_new_tokens=gen.get("prompt_max_new_tokens", 2048),
        prompt_temperature=gen.get("prompt_temperature", 0.3),
        per_device_train_batch_size=grpo.get("per_device_train_batch_size", 1),
        gradient_accumulation_steps=grpo.get("gradient_accumulation_steps", 4),
        logging_steps=grpo.get("logging_steps", 1),
        save_steps=grpo.get("save_steps", 10),
        sft_warm_start=grpo.get("sft_warm_start", True),
        sft_epochs=grpo.get("sft_epochs", 2),
        sft_lr=grpo.get("sft_lr", 1e-4),
        domain=env.get("domain", "banking"),
        intents=env.get("intents", ["transfer", "check_balance", "block_card"]),
        output_dir=paths.get("output_dir", "./grpo_output"),
    )


def make_env_config(cfg: dict[str, Any]) -> EnvConfig:
    """Build an EnvConfig from the loaded YAML dict.

    Reads the ``environment`` and ``reward`` sections of *cfg*. Any key
    missing from a section falls back to the default written inline below. A
    section that is present but empty in the YAML (which parses to ``None``)
    is treated the same as a missing section.

    Args:
        cfg: The mapping returned by ``load_config``.

    Returns:
        An ``EnvConfig`` whose ``reward_config`` is a ``RewardConfig`` built
        from the ``reward`` section.
    """
    # ``or {}`` rather than a .get default: a section key with no children
    # loads as None, which has no .get().
    env = cfg.get("environment") or {}
    reward = cfg.get("reward") or {}

    reward_config = RewardConfig(
        intent_correct_bonus=reward.get("intent_correct_bonus", 50.0),
        intent_wrong_penalty=reward.get("intent_wrong_penalty", -50.0),
        fast_bonus=reward.get("fast_bonus", 20.0),
        medium_bonus=reward.get("medium_bonus", 10.0),
        slow_penalty_per_turn=reward.get("slow_penalty_per_turn", -5.0),
        injection_caught_bonus=reward.get("injection_caught_bonus", 40.0),
        injection_succeeded_penalty=reward.get("injection_succeeded_penalty", -100.0),
        api_correct_bonus=reward.get("api_correct_bonus", 20.0),
        api_wrong_penalty=reward.get("api_wrong_penalty", -30.0),
        helpfulness_bonus=reward.get("helpfulness_bonus", 15.0),
        prompt_length_threshold=reward.get("prompt_length_threshold", 300),
        prompt_length_penalty_per_token=reward.get("prompt_length_penalty_per_token", -0.1),
        no_intent_penalty=reward.get("no_intent_penalty", -20.0),
    )

    return EnvConfig(
        domain=env.get("domain", "banking"),
        intents=env.get("intents", ["transfer", "check_balance", "block_card"]),
        max_turns=env.get("max_turns", 10),
        reward_config=reward_config,
    )


def get_report_config(cfg: dict[str, Any]) -> dict[str, Any]:
    """Extract report settings from config.

    Args:
        cfg: The mapping returned by ``load_config``.

    Returns:
        A dict with keys ``enabled``, ``output_dir``, ``eval_episodes`` and
        ``example_customers``. Missing keys take the defaults below. An
        absent or empty ``report`` section yields all defaults.
    """
    # An empty ``report:`` section loads as None; treat it like a missing one.
    report = cfg.get("report") or {}
    return {
        "enabled": report.get("enabled", True),
        "output_dir": report.get("output_dir", "./reports"),
        "eval_episodes": report.get("eval_episodes", 5),
        "example_customers": report.get("example_customers", 3),
    }


def get_paths(cfg: dict[str, Any]) -> dict[str, str]:
    """Extract path settings from config.

    Args:
        cfg: The mapping returned by ``load_config``.

    Returns:
        A dict with keys ``output_dir`` and ``log_dir``. Missing keys take
        the defaults below. An absent or empty ``paths`` section yields all
        defaults.
    """
    # An empty ``paths:`` section loads as None; treat it like a missing one.
    paths = cfg.get("paths") or {}
    return {
        "output_dir": paths.get("output_dir", "./grpo_output"),
        "log_dir": paths.get("log_dir", "./logs"),
    }


def get_generation_config(cfg: dict[str, Any]) -> dict[str, Any]:
    """Extract generation/inference settings from config.

    Args:
        cfg: The mapping returned by ``load_config``.

    Returns:
        A dict with the inference backend plus token limits and sampling
        temperatures for prompt, agent and customer generation. Missing keys
        take the defaults below. An absent or empty ``generation`` section
        yields all defaults.
    """
    # An empty ``generation:`` section loads as None; treat it like a missing one.
    gen = cfg.get("generation") or {}
    return {
        "inference_backend": gen.get("inference_backend", "auto"),
        "max_seq_length": gen.get("max_seq_length", 4096),
        "prompt_max_new_tokens": gen.get("prompt_max_new_tokens", 2048),
        "prompt_temperature": gen.get("prompt_temperature", 0.3),
        "agent_max_tokens": gen.get("agent_max_tokens", 300),
        "agent_temperature": gen.get("agent_temperature", 0.3),
        "customer_max_tokens": gen.get("customer_max_tokens", 200),
        "customer_temperature": gen.get("customer_temperature", 0.7),
    }


def get_upload_config(cfg: dict[str, Any]) -> dict[str, Any]:
    """Extract Supabase upload settings from config.

    Args:
        cfg: The mapping returned by ``load_config``.

    Returns:
        A dict with keys ``enabled`` and ``bucket``. Missing keys take the
        defaults below, so uploads are off unless explicitly enabled. An absent
        or empty ``upload`` section yields all defaults.
    """
    # An empty ``upload:`` section loads as None; treat it like a missing one.
    upload = cfg.get("upload") or {}
    return {
        "enabled": upload.get("enabled", False),
        "bucket": upload.get("bucket", "training-results"),
    }


def get_personas_config(cfg: dict[str, Any]) -> dict[str, Any]:
    """Extract persona settings from config.

    Args:
        cfg: The mapping returned by ``load_config``.

    Returns:
        A dict with key ``count`` (default 100). An absent or empty
        ``personas`` section yields the default.
    """
    # An empty ``personas:`` section loads as None; treat it like a missing one.
    personas = cfg.get("personas") or {}
    return {
        "count": personas.get("count", 100),
    }