Claude committed
Commit 4e2b74e · unverified · 1 parent: 288d9a2

Centralize all training params in config.yaml (single source of truth)


- Add config.yaml with all GRPO, environment, reward, and report params
- Add config_loader.py to parse YAML into GRPOConfig/EnvConfig/RewardConfig
- Move hardcoded TRL trainer values (batch_size, grad_accum, save_steps)
into GRPOConfig and config.yaml
- train.py now loads from config.yaml, CLI flags override YAML values
- Config banner prints all parameters at startup
- Add pyyaml to dependencies

https://claude.ai/code/session_01DPirJ78YYN4fJUvUFJ5D6V
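The override mechanism is the standard two-level pattern: every argparse flag defaults to None, and only explicitly passed values replace what config.yaml supplied. A minimal sketch of that pattern (illustrative only, not the exact train.py code; assumes the config.yaml added in this commit is in the working directory):

# YAML supplies defaults; a CLI flag overrides only when actually given.
import argparse
import yaml

parser = argparse.ArgumentParser()
parser.add_argument("--steps", type=int, default=None)  # None means "not given"

args = parser.parse_args(["--steps", "20"])
with open("config.yaml") as f:
    cfg = yaml.safe_load(f)

steps = cfg["grpo"]["num_training_steps"]  # 10 from YAML
if args.steps is not None:
    steps = args.steps                     # CLI wins -> 20
print(steps)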

Files changed (6)
  1. Dockerfile +1 -1
  2. config.yaml +76 -0
  3. config_loader.py +104 -0
  4. layer1/grpo_trainer.py +10 -4
  5. layer1/train.py +89 -57
  6. pyproject.toml +1 -0
Dockerfile CHANGED
@@ -4,7 +4,7 @@ WORKDIR /app
 
 COPY . .
 
-RUN pip install --no-cache-dir gradio huggingface-hub requests pydantic matplotlib python-dotenv
+RUN pip install --no-cache-dir gradio huggingface-hub requests pydantic matplotlib python-dotenv pyyaml
 
 EXPOSE 7860
 
config.yaml ADDED
@@ -0,0 +1,76 @@
+# ============================================================
+# Training Configuration — Single source of truth
+# ============================================================
+# All training parameters are defined here. CLI flags override
+# these values. To change defaults, edit this file.
+# ============================================================
+
+# --- Layer 1: GRPO RL Training ---
+# Qwen2.5-3B generates candidate system prompts, which are
+# evaluated by having Llama 3.1 8B use them as agent instructions.
+
+grpo:
+  # Prompt generator model (trained via RL)
+  model_name: "unsloth/Qwen2.5-3B-Instruct"
+
+  # LoRA adapter settings
+  lora_r: 16
+  lora_alpha: 16
+  lora_dropout: 0.0
+
+  # GRPO training loop
+  num_training_steps: 10       # Number of policy updates (GRPO iterations)
+  num_candidates: 4            # Candidate prompts per step (GRPO group size)
+  episodes_per_candidate: 7    # Customers each candidate talks to
+  learning_rate: 5.0e-5
+  max_prompt_length: 512       # Max tokens for generated system prompt
+
+  # TRL trainer settings
+  per_device_train_batch_size: 1
+  gradient_accumulation_steps: 4
+  logging_steps: 1
+  save_steps: 10
+
+
+# --- Layer 2: Conversation Environment ---
+# The simulated customer support environment.
+
+environment:
+  domain: "banking"
+  intents:
+    - "transfer"
+    - "check_balance"
+    - "block_card"
+  max_turns: 10                # Max conversation turns before forced termination
+
+
+# --- Layer 0: Reward Function ---
+# Weights for the reward signal that drives GRPO.
+
+reward:
+  intent_correct_bonus: 50.0
+  intent_wrong_penalty: -50.0
+  fast_bonus: 20.0             # Bonus for <= 3 turns
+  medium_bonus: 10.0           # Bonus for <= 5 turns
+  slow_penalty_per_turn: -5.0  # Per turn beyond 8
+  injection_caught_bonus: 40.0
+  injection_succeeded_penalty: -100.0
+  api_correct_bonus: 20.0
+  api_wrong_penalty: -30.0
+
+
+# --- Report Generation ---
+# Settings for the post-training evaluation report.
+
+report:
+  enabled: true
+  output_dir: "./reports"
+  eval_episodes: 5             # Episodes per checkpoint evaluation
+  example_customers: 3         # Example conversations in report
+
+
+# --- Paths ---
+
+paths:
+  output_dir: "./grpo_output"
+  log_dir: "./logs"
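The grpo block above fixes the training budget: each of the 10 policy updates evaluates 4 candidate prompts against 7 simulated customers apiece. A quick check of the arithmetic behind the banner's conversation estimate:

# Conversation budget implied by the defaults above: one policy update
# evaluates num_candidates prompts, each against episodes_per_candidate
# simulated customers.
num_training_steps = 10
num_candidates = 4
episodes_per_candidate = 7

total = num_training_steps * num_candidates * episodes_per_candidate
print(total)  # 280 LLM conversations, matching the "~" banner estimate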
config_loader.py ADDED
@@ -0,0 +1,104 @@
+"""
+Loads training configuration from config.yaml.
+
+Single source of truth for all training parameters.
+CLI arguments override values from the YAML file.
+"""
+
+from __future__ import annotations
+
+import os
+from pathlib import Path
+from typing import Any
+
+import yaml
+
+from layer0.reward import RewardConfig
+from layer2.environment import EnvConfig
+
+
+_ROOT = Path(__file__).resolve().parent
+_DEFAULT_CONFIG_PATH = _ROOT / "config.yaml"
+
+
+def load_config(config_path: str | Path | None = None) -> dict[str, Any]:
+    """Load the raw YAML config as a dict."""
+    path = Path(config_path) if config_path else _DEFAULT_CONFIG_PATH
+    if not path.exists():
+        raise FileNotFoundError(f"Config file not found: {path}")
+    with open(path) as f:
+        return yaml.safe_load(f)
+
+
+def make_grpo_config(cfg: dict[str, Any]):
+    """Build a GRPOConfig from the loaded YAML dict."""
+    # Import here to avoid circular imports
+    from layer1.grpo_trainer import GRPOConfig
+
+    grpo = cfg.get("grpo", {})
+    env = cfg.get("environment", {})
+    paths = cfg.get("paths", {})
+
+    return GRPOConfig(
+        model_name=grpo.get("model_name", "unsloth/Qwen2.5-3B-Instruct"),
+        lora_r=grpo.get("lora_r", 16),
+        lora_alpha=grpo.get("lora_alpha", 16),
+        lora_dropout=grpo.get("lora_dropout", 0.0),
+        num_candidates=grpo.get("num_candidates", 4),
+        episodes_per_candidate=grpo.get("episodes_per_candidate", 7),
+        num_training_steps=grpo.get("num_training_steps", 10),
+        learning_rate=grpo.get("learning_rate", 5e-5),
+        max_prompt_length=grpo.get("max_prompt_length", 512),
+        per_device_train_batch_size=grpo.get("per_device_train_batch_size", 1),
+        gradient_accumulation_steps=grpo.get("gradient_accumulation_steps", 4),
+        logging_steps=grpo.get("logging_steps", 1),
+        save_steps=grpo.get("save_steps", 10),
+        domain=env.get("domain", "banking"),
+        intents=env.get("intents", ["transfer", "check_balance", "block_card"]),
+        output_dir=paths.get("output_dir", "./grpo_output"),
+    )
+
+
+def make_env_config(cfg: dict[str, Any]) -> EnvConfig:
+    """Build an EnvConfig from the loaded YAML dict."""
+    env = cfg.get("environment", {})
+    reward = cfg.get("reward", {})
+
+    reward_config = RewardConfig(
+        intent_correct_bonus=reward.get("intent_correct_bonus", 50.0),
+        intent_wrong_penalty=reward.get("intent_wrong_penalty", -50.0),
+        fast_bonus=reward.get("fast_bonus", 20.0),
+        medium_bonus=reward.get("medium_bonus", 10.0),
+        slow_penalty_per_turn=reward.get("slow_penalty_per_turn", -5.0),
+        injection_caught_bonus=reward.get("injection_caught_bonus", 40.0),
+        injection_succeeded_penalty=reward.get("injection_succeeded_penalty", -100.0),
+        api_correct_bonus=reward.get("api_correct_bonus", 20.0),
+        api_wrong_penalty=reward.get("api_wrong_penalty", -30.0),
+    )
+
+    return EnvConfig(
+        domain=env.get("domain", "banking"),
+        intents=env.get("intents", ["transfer", "check_balance", "block_card"]),
+        max_turns=env.get("max_turns", 10),
+        reward_config=reward_config,
+    )
+
+
+def get_report_config(cfg: dict[str, Any]) -> dict[str, Any]:
+    """Extract report settings from config."""
+    report = cfg.get("report", {})
+    return {
+        "enabled": report.get("enabled", True),
+        "output_dir": report.get("output_dir", "./reports"),
+        "eval_episodes": report.get("eval_episodes", 5),
+        "example_customers": report.get("example_customers", 3),
+    }
+
+
+def get_paths(cfg: dict[str, Any]) -> dict[str, str]:
+    """Extract path settings from config."""
+    paths = cfg.get("paths", {})
+    return {
+        "output_dir": paths.get("output_dir", "./grpo_output"),
+        "log_dir": paths.get("log_dir", "./logs"),
+    }
layer1/grpo_trainer.py CHANGED
@@ -39,6 +39,12 @@ class GRPOConfig:
     learning_rate: float = 5e-5
     max_prompt_length: int = 512
 
+    # TRL trainer
+    per_device_train_batch_size: int = 1
+    gradient_accumulation_steps: int = 4
+    logging_steps: int = 1
+    save_steps: int = 10
+
     # Environment
     domain: str = "banking"
     intents: list[str] = field(default_factory=lambda: list(BANKING_INTENTS))
@@ -258,13 +264,13 @@ class GRPOPromptTrainer:
         training_args = TRLGRPOConfig(
             output_dir=self.config.output_dir,
             num_train_epochs=1,
-            per_device_train_batch_size=1,
-            gradient_accumulation_steps=4,
+            per_device_train_batch_size=self.config.per_device_train_batch_size,
+            gradient_accumulation_steps=self.config.gradient_accumulation_steps,
             learning_rate=self.config.learning_rate,
             num_generations=self.config.num_candidates,
             max_completion_length=self.config.max_prompt_length,
-            logging_steps=1,
-            save_steps=10,
+            logging_steps=self.config.logging_steps,
+            save_steps=self.config.save_steps,
         )
 
         trainer = GRPOTrainer(
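A design note on the hunks above: each new dataclass field defaults to the literal it replaces, so a bare GRPOConfig() with no YAML involved still reproduces the trainer's previous hardcoded settings. A minimal sanity check (hypothetical snippet, assuming layer1 is importable):

# The new fields' defaults match the old hardcoded TRLGRPOConfig values.
from layer1.grpo_trainer import GRPOConfig

cfg = GRPOConfig()
assert cfg.per_device_train_batch_size == 1
assert cfg.gradient_accumulation_steps == 4
assert cfg.logging_steps == 1
assert cfg.save_steps == 10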
layer1/train.py CHANGED
@@ -1,9 +1,15 @@
 """
 Layer 1 — GRPO training script for prompt optimization.
 
+All parameters are loaded from config.yaml (single source of truth).
+CLI flags override config.yaml values.
+
 Usage:
-    # GRPO training (requires GPU + train deps)
-    python -m layer1.train --steps 10
+    # Train with defaults from config.yaml
+    python -m layer1.train
+
+    # Override specific params
+    python -m layer1.train --steps 20 --episodes 10
 
     # Evaluate a single prompt
     python -m layer1.train --mode eval --prompt "You are a helpful agent."
@@ -23,12 +29,8 @@ load_dotenv(os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file_
 
 sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
 
-from layer1.grpo_trainer import (
-    GRPOConfig,
-    GRPOPromptTrainer,
-    PromptEvaluator,
-    build_meta_prompt,
-)
+from config_loader import load_config, make_grpo_config, make_env_config, get_report_config, get_paths
+from layer1.grpo_trainer import GRPOConfig, GRPOPromptTrainer, PromptEvaluator
 from layer1.training_logger import TrainingLogger, ReportGenerator
 from layer2.customer_sim import CustomerPersona, CustomerSimulator
 from layer2.hf_agent import HFAgent
@@ -60,31 +62,48 @@ def load_evaluator(hf_token: str | None = None) -> PromptEvaluator:
     return PromptEvaluator(personas=personas, simulator=simulator, agent_fn=agent)
 
 
-def _print_config_banner(args):
-    """Print training configuration with both technical and domain names."""
+def _print_config_banner(config: GRPOConfig, report_cfg: dict, paths_cfg: dict):
+    """Print all training parameters from config."""
+    total_conversations = (
+        config.num_training_steps * config.num_candidates * config.episodes_per_candidate
+    )
+
     print(f"\n{'='*70}")
-    print(f" TRAINING CONFIGURATION")
+    print(f" TRAINING CONFIGURATION (from config.yaml)")
     print(f"{'='*70}")
-    print(f" Steps / GRPO Iterations: {args.steps}")
-    print(f" Candidates / Customer Reps: 4 per step (GRPO-generated)")
-    print(f" Episodes / Customers: {args.episodes} per prompt")
+    print()
+    print(f" --- Layer 1: GRPO RL Training ---")
+    print(f" Prompt Generator Model: {config.model_name}")
+    print(f" LoRA: r={config.lora_r} alpha={config.lora_alpha} dropout={config.lora_dropout}")
+    print(f" Learning Rate: {config.learning_rate:.1e}")
+    print(f" Steps / GRPO Iterations: {config.num_training_steps}")
+    print(f" Candidates / Customer Reps: {config.num_candidates} per step")
+    print(f" Episodes / Customers: {config.episodes_per_candidate} per candidate")
+    print(f" Max Prompt Length: {config.max_prompt_length} tokens")
+    print(f" Batch Size: {config.per_device_train_batch_size}")
+    print(f" Gradient Accumulation: {config.gradient_accumulation_steps}")
+    print()
+    print(f" --- Layer 2: Conversation Environment ---")
+    print(f" Domain: {config.domain}")
+    print(f" Intents: {config.intents}")
+    print(f" Max Turns per Conversation: (from env config)")
     print(f" Customer Rep Agent: Llama 3.1 8B (HF Inference API)")
     print(f" Customer Simulator: Llama 3.1 8B (HF Inference API)")
-    total = args.steps * 4 * args.episodes
-    print(f" Total LLM conversations: ~{total}")
-    print(f" Report generation: {'yes' if args.report else 'no'}")
+    print()
+    print(f" --- Totals ---")
+    print(f" Total LLM Conversations: ~{total_conversations}")
+    print(f" Report Generation: {'yes' if report_cfg['enabled'] else 'no'}")
+    print(f" Output Dir: {paths_cfg['output_dir']}")
+    print(f" Log Dir: {paths_cfg['log_dir']}")
     print(f"{'='*70}\n")
 
 
-def run_train(args):
+def run_train(config: GRPOConfig, report_cfg: dict, paths_cfg: dict, hf_token: str | None):
     """Run GRPO training."""
-    _print_config_banner(args)
-    evaluator = load_evaluator(args.hf_token)
-    training_logger = TrainingLogger(log_dir=args.log_dir, total_steps=args.steps)
-    config = GRPOConfig(
-        num_training_steps=args.steps,
-        episodes_per_candidate=args.episodes,
-        output_dir=args.output_dir,
+    _print_config_banner(config, report_cfg, paths_cfg)
+    evaluator = load_evaluator(hf_token)
+    training_logger = TrainingLogger(
+        log_dir=paths_cfg["log_dir"], total_steps=config.num_training_steps
     )
     trainer = GRPOPromptTrainer(config=config, evaluator=evaluator, logger=training_logger)
    trainer.setup_model()
@@ -97,31 +116,32 @@ def run_train(args):
     print(best_prompt)
 
     # Evaluate the trained prompt
-    result = evaluator.evaluate_prompt(best_prompt, num_episodes=args.episodes)
+    result = evaluator.evaluate_prompt(
+        best_prompt, num_episodes=config.episodes_per_candidate
+    )
     print(f"\nEvaluation: mean_reward={result['mean_reward']:.1f}")
 
-    if args.report:
+    if report_cfg["enabled"]:
         print(f"\n{'='*60}")
         print("GENERATING TRAINING REPORT...")
         print(f"{'='*60}")
         report_gen = ReportGenerator(evaluator, training_logger)
         report_path = report_gen.generate_report(
-            output_dir=args.report_dir,
-            num_eval_episodes=args.eval_episodes,
-            num_example_customers=args.example_customers,
+            output_dir=report_cfg["output_dir"],
+            num_eval_episodes=report_cfg["eval_episodes"],
+            num_example_customers=report_cfg["example_customers"],
         )
         print(f"\nReport saved to {report_path}")
 
 
-def run_eval(args):
+def run_eval(hf_token: str | None, prompt: str, episodes: int):
     """Evaluate a single prompt."""
-    evaluator = load_evaluator(args.hf_token)
-    result = evaluator.evaluate_prompt(args.prompt, num_episodes=args.episodes)
-    print(f"Prompt: {args.prompt[:80]}...")
+    evaluator = load_evaluator(hf_token)
+    result = evaluator.evaluate_prompt(prompt, num_episodes=episodes)
+    print(f"Prompt: {prompt[:80]}...")
     print(f"Mean reward: {result['mean_reward']:.1f}")
     print(f"Min/Max: {result['min_reward']:.1f} / {result['max_reward']:.1f}")
 
-    # Show per-episode breakdown
     for i, log in enumerate(result["logs"]):
         print(
             f" Episode {i}: intent={log['true_intent']} "
@@ -133,37 +153,49 @@
 def main():
     parser = argparse.ArgumentParser(description="Layer 1 — GRPO Prompt Optimizer")
     parser.add_argument(
-        "--mode",
-        choices=["train", "eval"],
-        default="train",
+        "--mode", choices=["train", "eval"], default="train",
         help="Mode: train (GRPO RL training), eval (evaluate a single prompt)",
     )
-    parser.add_argument("--episodes", type=int, default=7, help="Episodes per evaluation")
-    parser.add_argument("--steps", type=int, default=10, help="GRPO training steps")
-    parser.add_argument("--output", type=str, default=None, help="Save results to JSON")
-    parser.add_argument("--output-dir", type=str, default="./grpo_output", help="Training output dir")
-    parser.add_argument("--hf-token", type=str, default=None, help="HuggingFace API token")
-    parser.add_argument("--prompt", type=str, default=None, help="Prompt to evaluate (eval mode)")
-    parser.add_argument("--report", action="store_true", default=True,
-                        help="Generate training report after completion (default: True)")
-    parser.add_argument("--no-report", action="store_false", dest="report",
+    parser.add_argument("--config", type=str, default=None,
+                        help="Path to config.yaml (default: ./config.yaml)")
+    parser.add_argument("--episodes", type=int, default=None,
+                        help="Override episodes_per_candidate from config")
+    parser.add_argument("--steps", type=int, default=None,
+                        help="Override num_training_steps from config")
+    parser.add_argument("--output-dir", type=str, default=None,
+                        help="Override output directory from config")
+    parser.add_argument("--hf-token", type=str, default=None,
+                        help="HuggingFace API token")
+    parser.add_argument("--prompt", type=str, default=None,
+                        help="Prompt to evaluate (eval mode)")
+    parser.add_argument("--no-report", action="store_true",
                         help="Skip report generation")
-    parser.add_argument("--report-dir", type=str, default="./reports",
-                        help="Directory for report output")
-    parser.add_argument("--log-dir", type=str, default="./logs",
-                        help="Directory for training logs")
-    parser.add_argument("--eval-episodes", type=int, default=5,
-                        help="Episodes per checkpoint for report evaluation")
-    parser.add_argument("--example-customers", type=int, default=3,
-                        help="Number of example customers in report")
     args = parser.parse_args()
 
+    # Load config from YAML
+    cfg = load_config(args.config)
+    grpo_config = make_grpo_config(cfg)
+    report_cfg = get_report_config(cfg)
+    paths_cfg = get_paths(cfg)
+
+    # CLI overrides
+    if args.steps is not None:
+        grpo_config.num_training_steps = args.steps
+    if args.episodes is not None:
+        grpo_config.episodes_per_candidate = args.episodes
+    if args.output_dir is not None:
+        grpo_config.output_dir = args.output_dir
+        paths_cfg["output_dir"] = args.output_dir
+    if args.no_report:
+        report_cfg["enabled"] = False
+
     if args.mode == "train":
-        run_train(args)
+        run_train(grpo_config, report_cfg, paths_cfg, args.hf_token)
     elif args.mode == "eval":
         if not args.prompt:
            parser.error("--prompt is required for eval mode")
-        run_eval(args)
+        episodes = args.episodes or grpo_config.episodes_per_candidate
+        run_eval(args.hf_token, args.prompt, episodes)
 
 
 if __name__ == "__main__":
pyproject.toml CHANGED
@@ -18,6 +18,7 @@ dependencies = [
     "python-dotenv>=1.0.0",
     "gradio>=4.0.0",
     "matplotlib>=3.7.0",
+    "pyyaml>=6.0",
 ]
 
 [project.optional-dependencies]