Spaces:
Running on T4
Running on T4
Claude committed on
Add training report & logging system with reward charts and conversation comparisons
Browse files. Adds TrainingLogger (per-iteration text log + JSON) and ReportGenerator
(matplotlib reward charts, 3-checkpoint prompt comparison evaluated on
30 episodes each, 10 diverse customer conversation examples showing
agent improvement). Integrates into both MockPromptOptimizer and
GRPOPromptTrainer via optional logger parameter.
https://claude.ai/code/session_01DPirJ78YYN4fJUvUFJ5D6V
- Dockerfile +1 -1
- layer1/grpo_trainer.py +16 -2
- layer1/train.py +44 -2
- layer1/training_logger.py +421 -0
- pyproject.toml +1 -0
Dockerfile
CHANGED
|
@@ -4,7 +4,7 @@ WORKDIR /app
|
|
| 4 |
|
| 5 |
COPY . .
|
| 6 |
|
| 7 |
-
RUN pip install --no-cache-dir gradio huggingface-hub requests pydantic
|
| 8 |
|
| 9 |
EXPOSE 7860
|
| 10 |
|
|
|
|
| 4 |
|
| 5 |
COPY . .
|
| 6 |
|
| 7 |
+
RUN pip install --no-cache-dir gradio huggingface-hub requests pydantic matplotlib python-dotenv
|
| 8 |
|
| 9 |
EXPOSE 7860
|
| 10 |
|
layer1/grpo_trainer.py
CHANGED
|
@@ -143,11 +143,13 @@ class GRPOPromptTrainer:
|
|
| 143 |
Requires GPU and train dependencies: pip install -e ".[train]"
|
| 144 |
"""
|
| 145 |
|
| 146 |
-
def __init__(self, config: GRPOConfig, evaluator: PromptEvaluator):
|
| 147 |
self.config = config
|
| 148 |
self.evaluator = evaluator
|
| 149 |
self._model = None
|
| 150 |
self._tokenizer = None
|
|
|
|
|
|
|
| 151 |
|
| 152 |
def setup_model(self):
|
| 153 |
"""Load model with Unsloth LoRA quantization."""
|
|
@@ -197,6 +199,14 @@ class GRPOPromptTrainer:
|
|
| 197 |
rewards.append(result["mean_reward"])
|
| 198 |
logger.info("Prompt reward: %.1f", result["mean_reward"])
|
| 199 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 200 |
return rewards
|
| 201 |
|
| 202 |
def train(self):
|
|
@@ -315,9 +325,10 @@ class MockPromptOptimizer:
|
|
| 315 |
),
|
| 316 |
]
|
| 317 |
|
| 318 |
-
def __init__(self, evaluator: PromptEvaluator):
|
| 319 |
self.evaluator = evaluator
|
| 320 |
self.results: list[dict[str, Any]] = []
|
|
|
|
| 321 |
|
| 322 |
def optimize(self, num_episodes_per_prompt: int = 10) -> dict[str, Any]:
|
| 323 |
"""Evaluate all candidate prompts and return the best one."""
|
|
@@ -333,6 +344,9 @@ class MockPromptOptimizer:
|
|
| 333 |
self.results.append(result)
|
| 334 |
print(f"Prompt {i}: mean_reward={result['mean_reward']:.1f}")
|
| 335 |
|
|
|
|
|
|
|
|
|
|
| 336 |
self.results.sort(key=lambda r: r["mean_reward"], reverse=True)
|
| 337 |
best = self.results[0]
|
| 338 |
|
|
|
|
| 143 |
Requires GPU and train dependencies: pip install -e ".[train]"
|
| 144 |
"""
|
| 145 |
|
| 146 |
+
def __init__(self, config: GRPOConfig, evaluator: PromptEvaluator, logger=None):
|
| 147 |
self.config = config
|
| 148 |
self.evaluator = evaluator
|
| 149 |
self._model = None
|
| 150 |
self._tokenizer = None
|
| 151 |
+
self._logger = logger
|
| 152 |
+
self._current_step = 0
|
| 153 |
|
| 154 |
def setup_model(self):
|
| 155 |
"""Load model with Unsloth LoRA quantization."""
|
|
|
|
| 199 |
rewards.append(result["mean_reward"])
|
| 200 |
logger.info("Prompt reward: %.1f", result["mean_reward"])
|
| 201 |
|
| 202 |
+
if self._logger:
|
| 203 |
+
self._logger.log_iteration(
|
| 204 |
+
step=self._current_step,
|
| 205 |
+
prompt=system_prompt,
|
| 206 |
+
eval_result=result,
|
| 207 |
+
)
|
| 208 |
+
|
| 209 |
+
self._current_step += 1
|
| 210 |
return rewards
|
| 211 |
|
| 212 |
def train(self):
|
|
|
|
| 325 |
),
|
| 326 |
]
|
| 327 |
|
| 328 |
+
def __init__(self, evaluator: PromptEvaluator, logger=None):
|
| 329 |
self.evaluator = evaluator
|
| 330 |
self.results: list[dict[str, Any]] = []
|
| 331 |
+
self._logger = logger
|
| 332 |
|
| 333 |
def optimize(self, num_episodes_per_prompt: int = 10) -> dict[str, Any]:
|
| 334 |
"""Evaluate all candidate prompts and return the best one."""
|
|
|
|
| 344 |
self.results.append(result)
|
| 345 |
print(f"Prompt {i}: mean_reward={result['mean_reward']:.1f}")
|
| 346 |
|
| 347 |
+
if self._logger:
|
| 348 |
+
self._logger.log_iteration(step=i, prompt=prompt, eval_result=result)
|
| 349 |
+
|
| 350 |
self.results.sort(key=lambda r: r["mean_reward"], reverse=True)
|
| 351 |
best = self.results[0]
|
| 352 |
|
layer1/train.py
CHANGED
|
@@ -33,6 +33,7 @@ from layer1.grpo_trainer import (
|
|
| 33 |
PromptEvaluator,
|
| 34 |
build_meta_prompt,
|
| 35 |
)
|
|
|
|
| 36 |
from layer2.customer_sim import CustomerPersona, CustomerSimulator
|
| 37 |
from layer2.hf_agent import HFAgent
|
| 38 |
from personas.generate_personas import generate_personas
|
|
@@ -63,7 +64,11 @@ def load_evaluator(hf_token: str | None = None, use_llm_agent: bool = False) ->
|
|
| 63 |
def run_mock(args):
|
| 64 |
"""Run mock optimization with hand-written prompts."""
|
| 65 |
evaluator = load_evaluator(args.hf_token, use_llm_agent=args.llm_agent)
|
| 66 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 67 |
result = optimizer.optimize(num_episodes_per_prompt=args.episodes)
|
| 68 |
|
| 69 |
print(f"\n{'='*60}")
|
|
@@ -79,16 +84,29 @@ def run_mock(args):
|
|
| 79 |
json.dump(result, f, indent=2, default=str)
|
| 80 |
print(f"\nResults saved to {args.output}")
|
| 81 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 82 |
|
| 83 |
def run_train(args):
|
| 84 |
"""Run full GRPO training (requires GPU)."""
|
| 85 |
evaluator = load_evaluator(args.hf_token, use_llm_agent=args.llm_agent)
|
|
|
|
| 86 |
config = GRPOConfig(
|
| 87 |
num_training_steps=args.steps,
|
| 88 |
episodes_per_candidate=args.episodes,
|
| 89 |
output_dir=args.output_dir,
|
| 90 |
)
|
| 91 |
-
trainer = GRPOPromptTrainer(config=config, evaluator=evaluator)
|
| 92 |
trainer.setup_model()
|
| 93 |
trainer.train()
|
| 94 |
|
|
@@ -102,6 +120,18 @@ def run_train(args):
|
|
| 102 |
result = evaluator.evaluate_prompt(best_prompt, num_episodes=50)
|
| 103 |
print(f"\nEvaluation: mean_reward={result['mean_reward']:.1f}")
|
| 104 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 105 |
|
| 106 |
def run_eval(args):
|
| 107 |
"""Evaluate a single prompt."""
|
|
@@ -136,6 +166,18 @@ def main():
|
|
| 136 |
parser.add_argument("--prompt", type=str, default=None, help="Prompt to evaluate (eval mode)")
|
| 137 |
parser.add_argument("--llm-agent", action="store_true",
|
| 138 |
help="Use LLM (Llama 3.1) as the agent instead of rule-based")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 139 |
args = parser.parse_args()
|
| 140 |
|
| 141 |
if args.mode == "train":
|
|
|
|
| 33 |
PromptEvaluator,
|
| 34 |
build_meta_prompt,
|
| 35 |
)
|
| 36 |
+
from layer1.training_logger import TrainingLogger, ReportGenerator
|
| 37 |
from layer2.customer_sim import CustomerPersona, CustomerSimulator
|
| 38 |
from layer2.hf_agent import HFAgent
|
| 39 |
from personas.generate_personas import generate_personas
|
|
|
|
| 64 |
def run_mock(args):
|
| 65 |
"""Run mock optimization with hand-written prompts."""
|
| 66 |
evaluator = load_evaluator(args.hf_token, use_llm_agent=args.llm_agent)
|
| 67 |
+
training_logger = TrainingLogger(
|
| 68 |
+
log_dir=args.log_dir,
|
| 69 |
+
total_steps=len(MockPromptOptimizer.CANDIDATE_PROMPTS),
|
| 70 |
+
)
|
| 71 |
+
optimizer = MockPromptOptimizer(evaluator, logger=training_logger)
|
| 72 |
result = optimizer.optimize(num_episodes_per_prompt=args.episodes)
|
| 73 |
|
| 74 |
print(f"\n{'='*60}")
|
|
|
|
| 84 |
json.dump(result, f, indent=2, default=str)
|
| 85 |
print(f"\nResults saved to {args.output}")
|
| 86 |
|
| 87 |
+
if args.report:
|
| 88 |
+
print(f"\n{'='*60}")
|
| 89 |
+
print("GENERATING TRAINING REPORT...")
|
| 90 |
+
print(f"{'='*60}")
|
| 91 |
+
report_gen = ReportGenerator(evaluator, training_logger)
|
| 92 |
+
report_path = report_gen.generate_report(
|
| 93 |
+
output_dir=args.report_dir,
|
| 94 |
+
num_eval_episodes=args.eval_episodes,
|
| 95 |
+
num_example_customers=args.example_customers,
|
| 96 |
+
)
|
| 97 |
+
print(f"\nReport saved to {report_path}")
|
| 98 |
+
|
| 99 |
|
| 100 |
def run_train(args):
|
| 101 |
"""Run full GRPO training (requires GPU)."""
|
| 102 |
evaluator = load_evaluator(args.hf_token, use_llm_agent=args.llm_agent)
|
| 103 |
+
training_logger = TrainingLogger(log_dir=args.log_dir, total_steps=args.steps)
|
| 104 |
config = GRPOConfig(
|
| 105 |
num_training_steps=args.steps,
|
| 106 |
episodes_per_candidate=args.episodes,
|
| 107 |
output_dir=args.output_dir,
|
| 108 |
)
|
| 109 |
+
trainer = GRPOPromptTrainer(config=config, evaluator=evaluator, logger=training_logger)
|
| 110 |
trainer.setup_model()
|
| 111 |
trainer.train()
|
| 112 |
|
|
|
|
| 120 |
result = evaluator.evaluate_prompt(best_prompt, num_episodes=50)
|
| 121 |
print(f"\nEvaluation: mean_reward={result['mean_reward']:.1f}")
|
| 122 |
|
| 123 |
+
if args.report:
|
| 124 |
+
print(f"\n{'='*60}")
|
| 125 |
+
print("GENERATING TRAINING REPORT...")
|
| 126 |
+
print(f"{'='*60}")
|
| 127 |
+
report_gen = ReportGenerator(evaluator, training_logger)
|
| 128 |
+
report_path = report_gen.generate_report(
|
| 129 |
+
output_dir=args.report_dir,
|
| 130 |
+
num_eval_episodes=args.eval_episodes,
|
| 131 |
+
num_example_customers=args.example_customers,
|
| 132 |
+
)
|
| 133 |
+
print(f"\nReport saved to {report_path}")
|
| 134 |
+
|
| 135 |
|
| 136 |
def run_eval(args):
|
| 137 |
"""Evaluate a single prompt."""
|
|
|
|
| 166 |
parser.add_argument("--prompt", type=str, default=None, help="Prompt to evaluate (eval mode)")
|
| 167 |
parser.add_argument("--llm-agent", action="store_true",
|
| 168 |
help="Use LLM (Llama 3.1) as the agent instead of rule-based")
|
| 169 |
+
parser.add_argument("--report", action="store_true", default=True,
|
| 170 |
+
help="Generate training report after completion (default: True)")
|
| 171 |
+
parser.add_argument("--no-report", action="store_false", dest="report",
|
| 172 |
+
help="Skip report generation")
|
| 173 |
+
parser.add_argument("--report-dir", type=str, default="./reports",
|
| 174 |
+
help="Directory for report output")
|
| 175 |
+
parser.add_argument("--log-dir", type=str, default="./logs",
|
| 176 |
+
help="Directory for training logs")
|
| 177 |
+
parser.add_argument("--eval-episodes", type=int, default=30,
|
| 178 |
+
help="Episodes per checkpoint for report evaluation")
|
| 179 |
+
parser.add_argument("--example-customers", type=int, default=10,
|
| 180 |
+
help="Number of example customers in report")
|
| 181 |
args = parser.parse_args()
|
| 182 |
|
| 183 |
if args.mode == "train":
|
layer1/training_logger.py
ADDED
|
@@ -0,0 +1,421 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Training Logger & Report Generator for the RL Prompt Optimization pipeline.
|
| 3 |
+
|
| 4 |
+
TrainingLogger: Appends per-iteration logs to a text file and keeps structured data in memory.
|
| 5 |
+
ReportGenerator: After training, evaluates checkpoint prompts and produces a markdown report
|
| 6 |
+
with reward charts and side-by-side conversation comparisons.
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
from __future__ import annotations
|
| 10 |
+
|
| 11 |
+
import json
|
| 12 |
+
import logging
|
| 13 |
+
import os
|
| 14 |
+
import random
|
| 15 |
+
from datetime import datetime
|
| 16 |
+
from dataclasses import dataclass, field
|
| 17 |
+
from typing import Any, Callable
|
| 18 |
+
|
| 19 |
+
from layer0.reward import reward_fn
|
| 20 |
+
from layer2.customer_sim import CustomerPersona
|
| 21 |
+
|
| 22 |
+
logger = logging.getLogger(__name__)
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
class TrainingLogger:
    """Logs each training iteration to a text file and stores structured data in memory.

    Creates two timestamped files under ``log_dir``:
    - ``log_<ts>.txt`` — human-readable per-iteration log, appended as training runs
    - ``training_<ts>.json`` — structured dump, written by :meth:`save_json`
    """

    def __init__(self, log_dir: str = "./logs", total_steps: int | None = None):
        """Create the log directory and start a fresh text log.

        Args:
            log_dir: Directory for log files (created if missing).
            total_steps: Expected number of steps; used only for "Step X / N" display.
        """
        os.makedirs(log_dir, exist_ok=True)
        self.timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        self.log_path = os.path.join(log_dir, f"log_{self.timestamp}.txt")
        self.json_path = os.path.join(log_dir, f"training_{self.timestamp}.json")
        self.total_steps = total_steps
        self.iterations: list[dict[str, Any]] = []
        self._start_time = datetime.now()

        # Explicit UTF-8: the header contains an em dash, which raises
        # UnicodeEncodeError under an ASCII/locale-dependent default encoding.
        with open(self.log_path, "w", encoding="utf-8") as f:
            f.write(f"Training Log — {self._start_time.isoformat()}\n")
            f.write(f"{'=' * 60}\n\n")

    def log_iteration(self, step: int, prompt: str, eval_result: dict[str, Any]):
        """Log a single training iteration (one prompt evaluated).

        Args:
            step: 0-based training step index.
            prompt: The system prompt that was evaluated.
            eval_result: Evaluator output dict; missing keys default to 0 / empty.
        """
        entry = {
            "step": step,
            "prompt": prompt,
            "mean_reward": eval_result.get("mean_reward", 0.0),
            "min_reward": eval_result.get("min_reward", 0.0),
            "max_reward": eval_result.get("max_reward", 0.0),
            "num_episodes": eval_result.get("num_episodes", 0),
            "rewards": eval_result.get("rewards", []),
            "logs": eval_result.get("logs", []),
        }
        self.iterations.append(entry)

        total_str = f" / {self.total_steps}" if self.total_steps else ""
        with open(self.log_path, "a", encoding="utf-8") as f:
            f.write(f"=== Step {step}{total_str} ===\n")
            # Truncate very long prompts so the text log stays readable.
            f.write(f"Prompt: {prompt[:300]}{'...' if len(prompt) > 300 else ''}\n")
            f.write(f"Mean Reward: {entry['mean_reward']:.1f}\n")
            f.write(f"Min/Max: {entry['min_reward']:.1f} / {entry['max_reward']:.1f}\n")
            f.write(f"Episodes: {entry['num_episodes']}\n")
            f.write("---\n\n")

        logger.info("Logged step %d: mean_reward=%.1f", step, entry["mean_reward"])

    def save_json(self):
        """Save structured training data to JSON (per-episode "logs" omitted for size)."""
        data = {
            "timestamp": self.timestamp,
            "start_time": self._start_time.isoformat(),
            "end_time": datetime.now().isoformat(),
            "total_steps": self.total_steps,
            "iterations": [
                # Drop the verbose per-episode "logs" entries from the JSON dump.
                {k: v for k, v in it.items() if k != "logs"}
                for it in self.iterations
            ],
        }
        with open(self.json_path, "w", encoding="utf-8") as f:
            json.dump(data, f, indent=2, default=str)
        logger.info("Training data saved to %s", self.json_path)

    def get_checkpoint_indices(self) -> list[int]:
        """Return indices for up to 3 checkpoints: start, middle, end."""
        n = len(self.iterations)
        if n <= 1:
            return list(range(n))
        if n == 2:
            return [0, n - 1]
        return [0, n // 2, n - 1]
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
def _select_diverse_personas(
|
| 93 |
+
personas: list[CustomerPersona], count: int = 10
|
| 94 |
+
) -> list[CustomerPersona]:
|
| 95 |
+
"""Select a diverse set of personas covering intents, personalities, and SE types."""
|
| 96 |
+
buckets: dict[str, list[CustomerPersona]] = {}
|
| 97 |
+
for p in personas:
|
| 98 |
+
key = f"{p.true_intent}_{p.social_engineering}"
|
| 99 |
+
buckets.setdefault(key, []).append(p)
|
| 100 |
+
|
| 101 |
+
selected: list[CustomerPersona] = []
|
| 102 |
+
# Round-robin across buckets to ensure diversity
|
| 103 |
+
bucket_iters = {k: iter(v) for k, v in buckets.items()}
|
| 104 |
+
while len(selected) < count and bucket_iters:
|
| 105 |
+
empty_keys = []
|
| 106 |
+
for key, it in bucket_iters.items():
|
| 107 |
+
if len(selected) >= count:
|
| 108 |
+
break
|
| 109 |
+
try:
|
| 110 |
+
selected.append(next(it))
|
| 111 |
+
except StopIteration:
|
| 112 |
+
empty_keys.append(key)
|
| 113 |
+
for k in empty_keys:
|
| 114 |
+
del bucket_iters[k]
|
| 115 |
+
|
| 116 |
+
# If we still need more, fill randomly
|
| 117 |
+
if len(selected) < count:
|
| 118 |
+
remaining = [p for p in personas if p not in selected]
|
| 119 |
+
selected.extend(random.sample(remaining, min(count - len(selected), len(remaining))))
|
| 120 |
+
|
| 121 |
+
return selected
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
class ReportGenerator:
    """Generates a post-training markdown report with reward charts and
    side-by-side conversation comparisons across training checkpoints."""

    def __init__(self, evaluator: Any, training_logger: TrainingLogger):
        """
        Args:
            evaluator: Prompt evaluator exposing ``env`` (with ``personas`` and
                ``run_episode``), ``agent_fn`` and ``evaluate_prompt``.
            training_logger: The TrainingLogger that recorded this run.
        """
        self.evaluator = evaluator
        self.logger = training_logger

    def generate_report(
        self,
        output_dir: str = "./reports",
        num_eval_episodes: int = 30,
        num_example_customers: int = 10,
    ) -> str:
        """Generate the full training report. Returns path to the markdown report.

        Pipeline: pick start/mid/end checkpoints, re-evaluate each on a shared
        persona subset, render a reward chart, replay example conversations,
        and write everything into a markdown file under ``output_dir``.
        """
        os.makedirs(output_dir, exist_ok=True)
        ts = self.logger.timestamp

        # 1. Select checkpoint prompts (start / middle / end of training)
        indices = self.logger.get_checkpoint_indices()
        checkpoints = [self.logger.iterations[i] for i in indices]
        checkpoint_labels = self._make_labels(indices)

        # 2. Evaluate each checkpoint on the SAME personas for a fair comparison
        eval_personas = random.sample(
            self.evaluator.env.personas,
            min(num_eval_episodes, len(self.evaluator.env.personas)),
        )
        checkpoint_evals = []
        for i, cp in enumerate(checkpoints):
            logger.info(
                "Evaluating checkpoint %s (%d/%d)...",
                checkpoint_labels[i], i + 1, len(checkpoints),
            )
            result = self.evaluator.evaluate_prompt(
                cp["prompt"],
                num_episodes=num_eval_episodes,
                personas_subset=eval_personas,
            )
            result["label"] = checkpoint_labels[i]
            result["step"] = cp["step"]
            result["prompt"] = cp["prompt"]
            # Keep the (noisier) training-time reward alongside the fresh eval.
            result["training_mean_reward"] = cp["mean_reward"]
            checkpoint_evals.append(result)

        # 3. Generate reward chart
        chart_path = os.path.join(output_dir, f"reward_chart_{ts}.png")
        self._generate_chart(checkpoint_evals, chart_path)

        # 4. Run example conversations
        example_personas = _select_diverse_personas(
            self.evaluator.env.personas, num_example_customers
        )
        example_conversations = self._run_example_conversations(
            checkpoints, checkpoint_labels, example_personas
        )

        # 5. Compute per-checkpoint metrics
        checkpoint_metrics = self._compute_metrics(checkpoint_evals)

        # 6. Write markdown report
        report_path = os.path.join(output_dir, f"report_{ts}.md")
        self._write_report(
            report_path, chart_path, checkpoints, checkpoint_labels,
            checkpoint_evals, checkpoint_metrics, example_conversations,
        )

        # 7. Save logger JSON
        self.logger.save_json()

        logger.info("Report generated: %s", report_path)
        return report_path

    def _make_labels(self, indices: list[int]) -> list[str]:
        """Build human-readable labels like "Step 3 (Mid)" for checkpoint indices."""
        labels = []
        for i, idx in enumerate(indices):
            step = self.logger.iterations[idx]["step"]
            if i == 0:
                labels.append(f"Step {step} (Start)")
            elif i == len(indices) - 1:
                labels.append(f"Step {step} (Final)")
            else:
                labels.append(f"Step {step} (Mid)")
        return labels

    def _generate_chart(self, checkpoint_evals: list[dict], save_path: str):
        """Generate the reward-progression chart (PNG) using matplotlib.

        Imported lazily with the Agg backend so the module works headless and
        without matplotlib unless a report is actually generated.
        """
        import matplotlib
        matplotlib.use("Agg")
        import matplotlib.pyplot as plt

        fig, axes = plt.subplots(1, 2, figsize=(14, 5))

        # Left: Line chart of all training iterations
        ax1 = axes[0]
        steps = [it["step"] for it in self.logger.iterations]
        rewards = [it["mean_reward"] for it in self.logger.iterations]
        ax1.plot(steps, rewards, "b-o", markersize=4, alpha=0.7, label="Training reward")

        # Highlight checkpoints
        cp_steps = [e["step"] for e in checkpoint_evals]
        cp_rewards_training = [e["training_mean_reward"] for e in checkpoint_evals]
        ax1.scatter(cp_steps, cp_rewards_training, c="red", s=100, zorder=5, label="Checkpoints")
        for e in checkpoint_evals:
            ax1.annotate(
                # "Step N (Start)" -> "Start" etc. — annotate with the short tag only.
                e["label"].split("(")[1].rstrip(")"),
                (e["step"], e["training_mean_reward"]),
                textcoords="offset points", xytext=(0, 12), ha="center", fontsize=9,
            )

        ax1.set_xlabel("Training Step")
        ax1.set_ylabel("Mean Reward")
        ax1.set_title("Reward Progression During Training")
        ax1.legend(loc="lower right")
        ax1.grid(True, alpha=0.3)

        # Right: Grouped bar chart of checkpoint evaluation metrics
        ax2 = axes[1]
        labels = [e["label"] for e in checkpoint_evals]
        eval_rewards = [e["mean_reward"] for e in checkpoint_evals]
        x = range(len(labels))
        bars = ax2.bar(x, eval_rewards, color=["#e74c3c", "#f39c12", "#27ae60"])
        ax2.set_xticks(list(x))
        ax2.set_xticklabels([lab.split("(")[1].rstrip(")") for lab in labels])
        ax2.set_ylabel("Mean Reward (30-episode eval)")
        ax2.set_title("Checkpoint Comparison")
        ax2.grid(True, alpha=0.3, axis="y")

        # Value labels above each bar.
        for bar, val in zip(bars, eval_rewards):
            ax2.text(
                bar.get_x() + bar.get_width() / 2, bar.get_height() + 1,
                f"{val:.1f}", ha="center", va="bottom", fontweight="bold",
            )

        plt.tight_layout()
        plt.savefig(save_path, dpi=150, bbox_inches="tight")
        # Close the specific figure rather than "the current" one.
        plt.close(fig)
        logger.info("Chart saved to %s", save_path)

    def _run_example_conversations(
        self,
        checkpoints: list[dict],
        checkpoint_labels: list[str],
        example_personas: list[CustomerPersona],
    ) -> list[dict]:
        """Run each example persona through every checkpoint agent.

        Returns one dict per persona with a "conversations" list holding the
        episode transcript and outcome metrics for each checkpoint.
        """
        examples = []
        for pi, persona in enumerate(example_personas):
            customer_data = {
                "persona_id": persona.id,
                "true_intent": persona.true_intent,
                "personality": persona.personality,
                "social_engineering": persona.social_engineering,
                "conversations": [],
            }
            for ci, cp in enumerate(checkpoints):
                log = self.evaluator.env.run_episode(
                    system_prompt=cp["prompt"],
                    agent_fn=self.evaluator.agent_fn,
                    persona=persona,
                )
                r = reward_fn(log)
                customer_data["conversations"].append({
                    "label": checkpoint_labels[ci],
                    "reward": r,
                    "turns": log.turns,
                    "intent_correct": log.intent_correct,
                    "injection_attempted": log.injection_attempted,
                    "injection_succeeded": log.injection_succeeded,
                    "messages": log.messages,
                })
            examples.append(customer_data)
            logger.info("Example %d/%d complete", pi + 1, len(example_personas))
        return examples

    def _compute_metrics(self, checkpoint_evals: list[dict]) -> list[dict]:
        """Compute aggregate metrics (accuracy, injection resistance, turns)
        for each checkpoint evaluation."""
        metrics = []
        for e in checkpoint_evals:
            logs = e.get("logs", [])
            total = len(logs)
            if total == 0:
                # No episode logs available — emit placeholder metrics.
                metrics.append({
                    "label": e["label"],
                    "mean_reward": 0,
                    "intent_accuracy": 0,
                    "injection_resistance": "N/A",
                    "avg_turns": 0,
                })
                continue

            correct = sum(1 for ep in logs if ep.get("intent_correct"))
            inj_attempted = sum(1 for ep in logs if ep.get("injection_attempted"))
            inj_succeeded = sum(1 for ep in logs if ep.get("injection_succeeded"))
            avg_turns = sum(ep.get("turns", 0) for ep in logs) / total

            metrics.append({
                "label": e["label"],
                "mean_reward": e["mean_reward"],
                "intent_accuracy": f"{correct / total:.0%}",
                # Fraction of injection attempts that were blocked.
                "injection_resistance": (
                    f"{(inj_attempted - inj_succeeded) / inj_attempted:.0%}"
                    if inj_attempted > 0
                    else "N/A (none attempted)"
                ),
                "avg_turns": f"{avg_turns:.1f}",
            })
        return metrics

    def _write_report(
        self,
        report_path: str,
        chart_path: str,
        checkpoints: list[dict],
        checkpoint_labels: list[str],
        checkpoint_evals: list[dict],
        checkpoint_metrics: list[dict],
        example_conversations: list[dict],
    ):
        """Write the final markdown report."""
        duration = datetime.now() - self.logger._start_time
        # Reference the chart by bare filename so the report stays valid when
        # the reports directory is moved or archived.
        chart_filename = os.path.basename(chart_path)

        lines = []
        lines.append(f"# Training Report — {self.logger._start_time.strftime('%Y-%m-%d %H:%M')}")
        lines.append("")
        lines.append("## Training Summary")
        lines.append(f"- **Total steps:** {len(self.logger.iterations)}")
        lines.append(f"- **Duration:** {duration.seconds // 60}m {duration.seconds % 60}s")
        lines.append(f"- **Checkpoints evaluated on:** {checkpoint_evals[0].get('num_episodes', 30)} episodes each")
        lines.append("")

        # Checkpoint prompts
        lines.append("## Checkpoint Prompts")
        lines.append("")
        for i, cp in enumerate(checkpoints):
            mr = cp["mean_reward"]
            lines.append(f"### {checkpoint_labels[i]} (Training Reward: {mr:.1f})")
            lines.append("")
            lines.append("```")
            lines.append(cp["prompt"])
            lines.append("```")
            lines.append("")

        # Reward chart — FIX: chart_filename was computed but never used and
        # the image line was an empty f-string, so the chart never appeared.
        lines.append("## Reward Progression")
        lines.append("")
        lines.append(f"![Reward Progression]({chart_filename})")
        lines.append("")

        # Metrics table
        lines.append("## Checkpoint Comparison")
        lines.append("")
        lines.append("| Checkpoint | Mean Reward | Intent Accuracy | Injection Resistance | Avg Turns |")
        lines.append("|------------|-------------|-----------------|----------------------|-----------|")
        for m in checkpoint_metrics:
            lines.append(
                f"| {m['label']} | {m['mean_reward']:.1f} | {m['intent_accuracy']} "
                f"| {m['injection_resistance']} | {m['avg_turns']} |"
            )
        lines.append("")

        # Example conversations — FIX: counts reflect the actual number of
        # customers/checkpoints instead of a hard-coded "10 ... 3"
        # (identical output under the default arguments).
        lines.append(
            f"## Example Conversations ({len(example_conversations)} Customers "
            f"x {len(checkpoints)} Agents)"
        )
        lines.append("")
        customer_letters = "ABCDEFGHIJ"
        for ci, cust in enumerate(example_conversations):
            letter = customer_letters[ci] if ci < len(customer_letters) else str(ci)
            lines.append(
                f"### Customer {letter} — Persona {cust['persona_id']} "
                f"({cust['true_intent']}, {cust['personality']}, "
                f"SE={cust['social_engineering']})"
            )
            lines.append("")

            for conv in cust["conversations"]:
                status = ""
                if conv["injection_attempted"]:
                    status = " | INJECTION " + ("BLOCKED" if not conv["injection_succeeded"] else "SUCCEEDED")
                lines.append(
                    f"#### {conv['label']} (Reward: {conv['reward']:.1f}, "
                    f"Turns: {conv['turns']}, "
                    f"Intent: {'correct' if conv['intent_correct'] else 'wrong'}"
                    f"{status})"
                )
                lines.append("")
                for msg in conv["messages"]:
                    if isinstance(msg, dict):
                        role = msg.get("role", "unknown").capitalize()
                        content = msg.get("content", "")
                        lines.append(f"**{role}:** {content}")
                        lines.append("")
                lines.append("---")
                lines.append("")

        lines.append("")

        # Explicit UTF-8: the title contains an em dash, which would raise
        # UnicodeEncodeError under an ASCII/locale-dependent default encoding.
        with open(report_path, "w", encoding="utf-8") as f:
            f.write("\n".join(lines))
|
pyproject.toml
CHANGED
|
@@ -17,6 +17,7 @@ dependencies = [
|
|
| 17 |
"pydantic>=2.0",
|
| 18 |
"python-dotenv>=1.0.0",
|
| 19 |
"gradio>=4.0.0",
|
|
|
|
| 20 |
]
|
| 21 |
|
| 22 |
[project.optional-dependencies]
|
|
|
|
| 17 |
"pydantic>=2.0",
|
| 18 |
"python-dotenv>=1.0.0",
|
| 19 |
"gradio>=4.0.0",
|
| 20 |
+
"matplotlib>=3.7.0",
|
| 21 |
]
|
| 22 |
|
| 23 |
[project.optional-dependencies]
|