---
# config.yaml — test-local-nested-envs
# Commit b1685a6 (unverified): "Increase training scale: more steps,
# episodes, and SFT epochs"
# ============================================================
# Training Configuration — Single source of truth
# ============================================================
# All training parameters are defined here. CLI flags override
# these values. To change defaults, edit this file.
# ============================================================
# --- Layer 1: GRPO RL Training ---
# Qwen2.5-3B generates candidate system prompts, which are
# evaluated by having Llama 3.1 8B use them as agent instructions.
grpo:
  # Prompt generator model (the policy trained via RL)
  model_name: "unsloth/Qwen2.5-3B-Instruct"

  # LoRA adapter settings
  lora_r: 16
  lora_alpha: 16
  lora_dropout: 0.0

  # SFT warm start — prime the model on seed prompts before GRPO begins
  sft_warm_start: true  # enable the SFT warm-start phase
  sft_epochs: 3         # epochs over the seed prompts
  sft_lr: 1.0e-4        # learning rate for the SFT phase only

  # GRPO training loop
  num_training_steps: 30     # policy updates (GRPO iterations)
  num_candidates: 4          # candidate prompts per step (GRPO group size, min=2)
  episodes_per_candidate: 8  # customers each candidate talks to
  learning_rate: 2.0e-5      # lower LR for stability at scale
  max_prompt_length: 512     # hard cap on generated system-prompt tokens during GRPO

  # TRL trainer settings
  per_device_train_batch_size: 1
  gradient_accumulation_steps: 4
  logging_steps: 1
  save_steps: 10
# --- Generation Parameters ---
# Token limits and temperatures for LLM inference.
generation:
  # Inference backend for Layer 2 (agent + customer simulator):
  #   "auto"  — local GPU if available, otherwise HF API
  #   "local" — force local (requires GPU + transformers)
  #   "api"   — force HF Inference API
  inference_backend: "auto"

  # Prompt generator (GRPO model) inference
  max_seq_length: 4096        # max sequence length for model loading
  prompt_max_new_tokens: 512  # new-token cap per generated prompt (avoids the length penalty)
  prompt_temperature: 0.3     # sampling temperature for prompt generation

  # Layer 2 agent (HF Inference API)
  agent_max_tokens: 300    # max tokens per agent response
  agent_temperature: 0.3   # sampling temperature for agent responses

  # Customer simulator (HF Inference API)
  customer_max_tokens: 200   # max tokens per customer reply
  customer_temperature: 0.7  # higher temperature for customer diversity
# --- Personas ---
personas:
  count: 100  # number of customer personas to generate
# --- Layer 2: Conversation Environment ---
# The simulated customer support environment.
environment:
  domain: "banking"
  # Customer intents the agent is expected to classify
  intents:
    - "transfer"
    - "check_balance"
    - "block_card"
  max_turns: 10  # conversations are force-terminated after this many turns
# --- Layer 0: Reward Function ---
# Weights for the reward signal that drives GRPO.
reward:
  # Intent classification
  intent_correct_bonus: 50.0
  intent_wrong_penalty: -50.0

  # Conversation speed
  fast_bonus: 20.0             # finished in <= 3 turns
  medium_bonus: 10.0           # finished in <= 5 turns
  slow_penalty_per_turn: -5.0  # charged for every turn beyond 8

  # Prompt-injection handling
  injection_caught_bonus: 40.0
  injection_succeeded_penalty: -100.0

  # API-call correctness
  api_correct_bonus: 20.0
  api_wrong_penalty: -30.0

  # Helpful AND secure: intent correct and injection blocked
  helpfulness_bonus: 15.0

  # Prompt length shaping
  prompt_length_threshold: 1200          # tokens before the length penalty kicks in
  prompt_length_penalty_per_token: -0.1  # per-token penalty for bloated prompts

  no_intent_penalty: -20.0  # agent never classified an intent at all
# --- Report Generation ---
# Settings for the post-training evaluation report.
report:
  enabled: true
  output_dir: "/workspace/output/reports"
  eval_episodes: 15     # episodes run per checkpoint evaluation
  example_customers: 5  # example conversations included in the report
# --- Upload: Supabase ---
# Upload training results to Supabase for analysis.
# Requires the SUPABASE_URL and SUPABASE_KEY environment variables.
upload:
  enabled: true
  bucket: "training-results"  # Supabase Storage bucket name
# --- Paths ---
paths:
  output_dir: "/workspace/output/grpo_output"
  log_dir: "/workspace/output/logs"