# Captured from a Hugging Face Space ("Spaces"), running on a T4 GPU.
# ============================================================
# Training Configuration — Single source of truth
# ============================================================
# All training parameters are defined here. CLI flags override
# these values. To change defaults, edit this file.
# ============================================================
---
# --- Layer 1: GRPO RL Training ---
# Qwen2.5-3B generates candidate system prompts, which are
# evaluated by having Llama 3.1 8B use them as agent instructions.
grpo:
  # Prompt generator model (trained via RL)
  model_name: "unsloth/Qwen2.5-3B-Instruct"
  # LoRA adapter settings
  lora_r: 16
  lora_alpha: 16
  lora_dropout: 0.0
  # SFT warm start — prime the model on seed prompts before GRPO
  sft_warm_start: true  # Enable SFT warm start phase
  sft_epochs: 3  # Epochs over seed prompts
  sft_lr: 1.0e-4  # Learning rate for SFT phase
  # GRPO training loop
  num_training_steps: 30  # Number of policy updates (GRPO iterations)
  num_candidates: 4  # Candidate prompts per step (GRPO group size, min=2)
  episodes_per_candidate: 8  # Customers each candidate talks to
  learning_rate: 2.0e-5  # Lower LR for stability at scale
  max_prompt_length: 512  # Max tokens for generated system prompt (hard cap during GRPO)
  # TRL trainer settings
  per_device_train_batch_size: 1
  gradient_accumulation_steps: 4
  logging_steps: 1
  save_steps: 10
# --- Generation Parameters ---
# Token limits and temperatures for LLM inference.
generation:
  # Inference backend for Layer 2 (agent + customer simulator)
  # "auto"  = local GPU if available, else HF API
  # "local" = force local (requires GPU + transformers)
  # "api"   = force HF Inference API
  inference_backend: "auto"
  # Prompt generator (GRPO model) inference
  max_seq_length: 4096  # Max sequence length for model loading
  prompt_max_new_tokens: 512  # Max new tokens when generating prompts (capped to avoid length penalty)
  prompt_temperature: 0.3  # Temperature for prompt generation
  # Layer 2 agent (HF Inference API)
  agent_max_tokens: 300  # Max tokens for agent responses
  agent_temperature: 0.3  # Temperature for agent responses
  # Customer simulator (HF Inference API)
  customer_max_tokens: 200  # Max tokens for customer replies
  customer_temperature: 0.7  # Temperature for customer diversity
# --- Personas ---
personas:
  count: 100  # Number of customer personas to generate
# --- Layer 2: Conversation Environment ---
# The simulated customer support environment.
environment:
  domain: "banking"
  # Intents the agent must classify customers into
  intents:
    - "transfer"
    - "check_balance"
    - "block_card"
  max_turns: 10  # Max conversation turns before forced termination
# --- Layer 0: Reward Function ---
# Weights for the reward signal that drives GRPO.
reward:
  intent_correct_bonus: 50.0
  intent_wrong_penalty: -50.0
  fast_bonus: 20.0  # Bonus for <= 3 turns
  medium_bonus: 10.0  # Bonus for <= 5 turns
  slow_penalty_per_turn: -5.0  # Per turn beyond 8
  injection_caught_bonus: 40.0
  injection_succeeded_penalty: -100.0
  api_correct_bonus: 20.0
  api_wrong_penalty: -30.0
  helpfulness_bonus: 15.0  # Bonus for being helpful AND secure (both intent + injection blocked)
  prompt_length_threshold: 1200  # Tokens before length penalty kicks in
  prompt_length_penalty_per_token: -0.1  # Per-token penalty for bloated prompts
  no_intent_penalty: -20.0  # Penalty when agent never classifies intent
# --- Report Generation ---
# Settings for the post-training evaluation report.
report:
  enabled: true
  output_dir: "/workspace/output/reports"
  eval_episodes: 15  # Episodes per checkpoint evaluation
  example_customers: 5  # Example conversations in report
# --- Upload: Supabase ---
# Upload training results to Supabase for analysis.
# Requires SUPABASE_URL and SUPABASE_KEY environment variables.
upload:
  enabled: true
  bucket: "training-results"  # Supabase Storage bucket name
# --- Paths ---
paths:
  output_dir: "/workspace/output/grpo_output"
  log_dir: "/workspace/output/logs"