Claude committed on
Commit
506d641
·
unverified ·
1 Parent(s): 4ceb26f

Add training report & logging system with reward charts and conversation comparisons

Browse files

Adds TrainingLogger (per-iteration text log + JSON) and ReportGenerator
(matplotlib reward charts, 3-checkpoint prompt comparison evaluated on
30 episodes each, 10 diverse customer conversation examples showing
agent improvement). Integrates into both MockPromptOptimizer and
GRPOPromptTrainer via optional logger parameter.

https://claude.ai/code/session_01DPirJ78YYN4fJUvUFJ5D6V

Dockerfile CHANGED
@@ -4,7 +4,7 @@ WORKDIR /app
4
 
5
  COPY . .
6
 
7
- RUN pip install --no-cache-dir gradio huggingface-hub requests pydantic
8
 
9
  EXPOSE 7860
10
 
 
4
 
5
  COPY . .
6
 
7
+ RUN pip install --no-cache-dir gradio huggingface-hub requests pydantic matplotlib python-dotenv
8
 
9
  EXPOSE 7860
10
 
layer1/grpo_trainer.py CHANGED
@@ -143,11 +143,13 @@ class GRPOPromptTrainer:
143
  Requires GPU and train dependencies: pip install -e ".[train]"
144
  """
145
 
146
- def __init__(self, config: GRPOConfig, evaluator: PromptEvaluator):
147
  self.config = config
148
  self.evaluator = evaluator
149
  self._model = None
150
  self._tokenizer = None
 
 
151
 
152
  def setup_model(self):
153
  """Load model with Unsloth LoRA quantization."""
@@ -197,6 +199,14 @@ class GRPOPromptTrainer:
197
  rewards.append(result["mean_reward"])
198
  logger.info("Prompt reward: %.1f", result["mean_reward"])
199
 
 
 
 
 
 
 
 
 
200
  return rewards
201
 
202
  def train(self):
@@ -315,9 +325,10 @@ class MockPromptOptimizer:
315
  ),
316
  ]
317
 
318
- def __init__(self, evaluator: PromptEvaluator):
319
  self.evaluator = evaluator
320
  self.results: list[dict[str, Any]] = []
 
321
 
322
  def optimize(self, num_episodes_per_prompt: int = 10) -> dict[str, Any]:
323
  """Evaluate all candidate prompts and return the best one."""
@@ -333,6 +344,9 @@ class MockPromptOptimizer:
333
  self.results.append(result)
334
  print(f"Prompt {i}: mean_reward={result['mean_reward']:.1f}")
335
 
 
 
 
336
  self.results.sort(key=lambda r: r["mean_reward"], reverse=True)
337
  best = self.results[0]
338
 
 
143
  Requires GPU and train dependencies: pip install -e ".[train]"
144
  """
145
 
146
+ def __init__(self, config: GRPOConfig, evaluator: PromptEvaluator, logger=None):
147
  self.config = config
148
  self.evaluator = evaluator
149
  self._model = None
150
  self._tokenizer = None
151
+ self._logger = logger
152
+ self._current_step = 0
153
 
154
  def setup_model(self):
155
  """Load model with Unsloth LoRA quantization."""
 
199
  rewards.append(result["mean_reward"])
200
  logger.info("Prompt reward: %.1f", result["mean_reward"])
201
 
202
+ if self._logger:
203
+ self._logger.log_iteration(
204
+ step=self._current_step,
205
+ prompt=system_prompt,
206
+ eval_result=result,
207
+ )
208
+
209
+ self._current_step += 1
210
  return rewards
211
 
212
  def train(self):
 
325
  ),
326
  ]
327
 
328
+ def __init__(self, evaluator: PromptEvaluator, logger=None):
329
  self.evaluator = evaluator
330
  self.results: list[dict[str, Any]] = []
331
+ self._logger = logger
332
 
333
  def optimize(self, num_episodes_per_prompt: int = 10) -> dict[str, Any]:
334
  """Evaluate all candidate prompts and return the best one."""
 
344
  self.results.append(result)
345
  print(f"Prompt {i}: mean_reward={result['mean_reward']:.1f}")
346
 
347
+ if self._logger:
348
+ self._logger.log_iteration(step=i, prompt=prompt, eval_result=result)
349
+
350
  self.results.sort(key=lambda r: r["mean_reward"], reverse=True)
351
  best = self.results[0]
352
 
layer1/train.py CHANGED
@@ -33,6 +33,7 @@ from layer1.grpo_trainer import (
33
  PromptEvaluator,
34
  build_meta_prompt,
35
  )
 
36
  from layer2.customer_sim import CustomerPersona, CustomerSimulator
37
  from layer2.hf_agent import HFAgent
38
  from personas.generate_personas import generate_personas
@@ -63,7 +64,11 @@ def load_evaluator(hf_token: str | None = None, use_llm_agent: bool = False) ->
63
  def run_mock(args):
64
  """Run mock optimization with hand-written prompts."""
65
  evaluator = load_evaluator(args.hf_token, use_llm_agent=args.llm_agent)
66
- optimizer = MockPromptOptimizer(evaluator)
 
 
 
 
67
  result = optimizer.optimize(num_episodes_per_prompt=args.episodes)
68
 
69
  print(f"\n{'='*60}")
@@ -79,16 +84,29 @@ def run_mock(args):
79
  json.dump(result, f, indent=2, default=str)
80
  print(f"\nResults saved to {args.output}")
81
 
 
 
 
 
 
 
 
 
 
 
 
 
82
 
83
  def run_train(args):
84
  """Run full GRPO training (requires GPU)."""
85
  evaluator = load_evaluator(args.hf_token, use_llm_agent=args.llm_agent)
 
86
  config = GRPOConfig(
87
  num_training_steps=args.steps,
88
  episodes_per_candidate=args.episodes,
89
  output_dir=args.output_dir,
90
  )
91
- trainer = GRPOPromptTrainer(config=config, evaluator=evaluator)
92
  trainer.setup_model()
93
  trainer.train()
94
 
@@ -102,6 +120,18 @@ def run_train(args):
102
  result = evaluator.evaluate_prompt(best_prompt, num_episodes=50)
103
  print(f"\nEvaluation: mean_reward={result['mean_reward']:.1f}")
104
 
 
 
 
 
 
 
 
 
 
 
 
 
105
 
106
  def run_eval(args):
107
  """Evaluate a single prompt."""
@@ -136,6 +166,18 @@ def main():
136
  parser.add_argument("--prompt", type=str, default=None, help="Prompt to evaluate (eval mode)")
137
  parser.add_argument("--llm-agent", action="store_true",
138
  help="Use LLM (Llama 3.1) as the agent instead of rule-based")
 
 
 
 
 
 
 
 
 
 
 
 
139
  args = parser.parse_args()
140
 
141
  if args.mode == "train":
 
33
  PromptEvaluator,
34
  build_meta_prompt,
35
  )
36
+ from layer1.training_logger import TrainingLogger, ReportGenerator
37
  from layer2.customer_sim import CustomerPersona, CustomerSimulator
38
  from layer2.hf_agent import HFAgent
39
  from personas.generate_personas import generate_personas
 
64
  def run_mock(args):
65
  """Run mock optimization with hand-written prompts."""
66
  evaluator = load_evaluator(args.hf_token, use_llm_agent=args.llm_agent)
67
+ training_logger = TrainingLogger(
68
+ log_dir=args.log_dir,
69
+ total_steps=len(MockPromptOptimizer.CANDIDATE_PROMPTS),
70
+ )
71
+ optimizer = MockPromptOptimizer(evaluator, logger=training_logger)
72
  result = optimizer.optimize(num_episodes_per_prompt=args.episodes)
73
 
74
  print(f"\n{'='*60}")
 
84
  json.dump(result, f, indent=2, default=str)
85
  print(f"\nResults saved to {args.output}")
86
 
87
+ if args.report:
88
+ print(f"\n{'='*60}")
89
+ print("GENERATING TRAINING REPORT...")
90
+ print(f"{'='*60}")
91
+ report_gen = ReportGenerator(evaluator, training_logger)
92
+ report_path = report_gen.generate_report(
93
+ output_dir=args.report_dir,
94
+ num_eval_episodes=args.eval_episodes,
95
+ num_example_customers=args.example_customers,
96
+ )
97
+ print(f"\nReport saved to {report_path}")
98
+
99
 
100
  def run_train(args):
101
  """Run full GRPO training (requires GPU)."""
102
  evaluator = load_evaluator(args.hf_token, use_llm_agent=args.llm_agent)
103
+ training_logger = TrainingLogger(log_dir=args.log_dir, total_steps=args.steps)
104
  config = GRPOConfig(
105
  num_training_steps=args.steps,
106
  episodes_per_candidate=args.episodes,
107
  output_dir=args.output_dir,
108
  )
109
+ trainer = GRPOPromptTrainer(config=config, evaluator=evaluator, logger=training_logger)
110
  trainer.setup_model()
111
  trainer.train()
112
 
 
120
  result = evaluator.evaluate_prompt(best_prompt, num_episodes=50)
121
  print(f"\nEvaluation: mean_reward={result['mean_reward']:.1f}")
122
 
123
+ if args.report:
124
+ print(f"\n{'='*60}")
125
+ print("GENERATING TRAINING REPORT...")
126
+ print(f"{'='*60}")
127
+ report_gen = ReportGenerator(evaluator, training_logger)
128
+ report_path = report_gen.generate_report(
129
+ output_dir=args.report_dir,
130
+ num_eval_episodes=args.eval_episodes,
131
+ num_example_customers=args.example_customers,
132
+ )
133
+ print(f"\nReport saved to {report_path}")
134
+
135
 
136
  def run_eval(args):
137
  """Evaluate a single prompt."""
 
166
  parser.add_argument("--prompt", type=str, default=None, help="Prompt to evaluate (eval mode)")
167
  parser.add_argument("--llm-agent", action="store_true",
168
  help="Use LLM (Llama 3.1) as the agent instead of rule-based")
169
+ parser.add_argument("--report", action="store_true", default=True,
170
+ help="Generate training report after completion (default: True)")
171
+ parser.add_argument("--no-report", action="store_false", dest="report",
172
+ help="Skip report generation")
173
+ parser.add_argument("--report-dir", type=str, default="./reports",
174
+ help="Directory for report output")
175
+ parser.add_argument("--log-dir", type=str, default="./logs",
176
+ help="Directory for training logs")
177
+ parser.add_argument("--eval-episodes", type=int, default=30,
178
+ help="Episodes per checkpoint for report evaluation")
179
+ parser.add_argument("--example-customers", type=int, default=10,
180
+ help="Number of example customers in report")
181
  args = parser.parse_args()
182
 
183
  if args.mode == "train":
layer1/training_logger.py ADDED
@@ -0,0 +1,421 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Training Logger & Report Generator for the RL Prompt Optimization pipeline.
3
+
4
+ TrainingLogger: Appends per-iteration logs to a text file and keeps structured data in memory.
5
+ ReportGenerator: After training, evaluates checkpoint prompts and produces a markdown report
6
+ with reward charts and side-by-side conversation comparisons.
7
+ """
8
+
9
+ from __future__ import annotations
10
+
11
+ import json
12
+ import logging
13
+ import os
14
+ import random
15
+ from datetime import datetime
16
+ from dataclasses import dataclass, field
17
+ from typing import Any, Callable
18
+
19
+ from layer0.reward import reward_fn
20
+ from layer2.customer_sim import CustomerPersona
21
+
22
+ logger = logging.getLogger(__name__)
23
+
24
+
25
class TrainingLogger:
    """Logs each training iteration to a text file and stores structured data.

    Creates two timestamped files inside ``log_dir``:
      - ``log_<ts>.txt``: human-readable per-iteration log, appended to as
        training runs.
      - ``training_<ts>.json``: structured dump, written by :meth:`save_json`.
    """

    def __init__(self, log_dir: str = "./logs", total_steps: int | None = None):
        """Create the log directory (if needed) and start a fresh text log.

        Args:
            log_dir: Directory for log files; created if missing.
            total_steps: Expected number of steps, used only for the
                "Step X / N" display in the text log. May be None if unknown.
        """
        os.makedirs(log_dir, exist_ok=True)
        self.timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        self.log_path = os.path.join(log_dir, f"log_{self.timestamp}.txt")
        self.json_path = os.path.join(log_dir, f"training_{self.timestamp}.json")
        self.total_steps = total_steps
        self.iterations: list[dict[str, Any]] = []
        self._start_time = datetime.now()

        # Explicit utf-8: the header contains an em dash, which may not be
        # representable in the platform default encoding (e.g. some Windows
        # code pages).
        with open(self.log_path, "w", encoding="utf-8") as f:
            f.write(f"Training Log — {self._start_time.isoformat()}\n")
            f.write(f"{'=' * 60}\n\n")

    def log_iteration(self, step: int, prompt: str, eval_result: dict[str, Any]):
        """Log a single training iteration (one prompt evaluated).

        Args:
            step: Zero-based training step index.
            prompt: The system prompt that was evaluated.
            eval_result: Evaluator output dict; missing keys default to 0 / [].
        """
        entry = {
            "step": step,
            "prompt": prompt,
            "mean_reward": eval_result.get("mean_reward", 0.0),
            "min_reward": eval_result.get("min_reward", 0.0),
            "max_reward": eval_result.get("max_reward", 0.0),
            "num_episodes": eval_result.get("num_episodes", 0),
            "rewards": eval_result.get("rewards", []),
            "logs": eval_result.get("logs", []),
        }
        self.iterations.append(entry)

        total_str = f" / {self.total_steps}" if self.total_steps else ""
        with open(self.log_path, "a", encoding="utf-8") as f:
            f.write(f"=== Step {step}{total_str} ===\n")
            # Truncate very long prompts so the text log stays readable.
            f.write(f"Prompt: {prompt[:300]}{'...' if len(prompt) > 300 else ''}\n")
            f.write(f"Mean Reward: {entry['mean_reward']:.1f}\n")
            f.write(f"Min/Max: {entry['min_reward']:.1f} / {entry['max_reward']:.1f}\n")
            f.write(f"Episodes: {entry['num_episodes']}\n")
            f.write("---\n\n")  # plain string: no placeholders, no f-string needed

        logger.info("Logged step %d: mean_reward=%.1f", step, entry["mean_reward"])

    def save_json(self):
        """Save structured training data to JSON (per-episode logs excluded)."""
        data = {
            "timestamp": self.timestamp,
            "start_time": self._start_time.isoformat(),
            "end_time": datetime.now().isoformat(),
            "total_steps": self.total_steps,
            "iterations": [
                # Drop the per-episode "logs": they can be large and may not
                # be JSON-serializable.
                {k: v for k, v in it.items() if k != "logs"}
                for it in self.iterations
            ],
        }
        with open(self.json_path, "w", encoding="utf-8") as f:
            json.dump(data, f, indent=2, default=str)
        logger.info("Training data saved to %s", self.json_path)

    def get_checkpoint_indices(self) -> list[int]:
        """Return indices for up to 3 checkpoints: start, middle, end.

        With 0 or 1 logged iterations, returns whatever exists; with 2,
        returns both; otherwise first, median, and last.
        """
        n = len(self.iterations)
        if n <= 1:
            return list(range(n))
        if n == 2:
            return [0, n - 1]
        return [0, n // 2, n - 1]
90
+
91
+
92
def _select_diverse_personas(
    personas: list[CustomerPersona], count: int = 10
) -> list[CustomerPersona]:
    """Pick up to *count* personas spread across intent / SE combinations.

    Personas are grouped by (true_intent, social_engineering) and drawn
    round-robin from the groups, so the sample covers as many combinations
    as possible. Any shortfall is topped up with a random sample of the
    personas not yet chosen.
    """
    groups: dict[str, list[CustomerPersona]] = {}
    for persona in personas:
        group_key = f"{persona.true_intent}_{persona.social_engineering}"
        groups.setdefault(group_key, []).append(persona)

    chosen: list[CustomerPersona] = []
    # One iterator per group; drained groups are dropped between passes.
    pools = {name: iter(members) for name, members in groups.items()}
    while pools and len(chosen) < count:
        drained = []
        for name, pool in pools.items():
            if len(chosen) >= count:
                break
            candidate = next(pool, None)
            if candidate is None:
                drained.append(name)
            else:
                chosen.append(candidate)
        for name in drained:
            del pools[name]

    # Top up with a random sample if the round-robin fell short.
    if len(chosen) < count:
        leftovers = [p for p in personas if p not in chosen]
        chosen.extend(random.sample(leftovers, min(count - len(chosen), len(leftovers))))

    return chosen
122
+
123
+
124
class ReportGenerator:
    """Generates a training report with charts and conversation comparisons."""

    def __init__(self, evaluator: Any, training_logger: TrainingLogger):
        """
        Args:
            evaluator: PromptEvaluator-like object exposing ``evaluate_prompt``,
                ``agent_fn``, and ``env`` (with ``personas`` and ``run_episode``).
            training_logger: The TrainingLogger that recorded this run.
        """
        self.evaluator = evaluator
        self.logger = training_logger

    def generate_report(
        self,
        output_dir: str = "./reports",
        num_eval_episodes: int = 30,
        num_example_customers: int = 10,
    ) -> str:
        """Generate the full training report. Returns path to the markdown report.

        Args:
            output_dir: Directory for the markdown report and chart PNG.
            num_eval_episodes: Episodes per checkpoint for the comparison eval.
            num_example_customers: Number of example customers to showcase.
        """
        os.makedirs(output_dir, exist_ok=True)
        ts = self.logger.timestamp

        # 1. Select checkpoint prompts (start / mid / final).
        indices = self.logger.get_checkpoint_indices()
        checkpoints = [self.logger.iterations[i] for i in indices]
        checkpoint_labels = self._make_labels(indices)

        # 2. Evaluate each checkpoint on the SAME personas for a fair comparison.
        eval_personas = random.sample(
            self.evaluator.env.personas,
            min(num_eval_episodes, len(self.evaluator.env.personas)),
        )
        checkpoint_evals = []
        for i, cp in enumerate(checkpoints):
            logger.info(
                "Evaluating checkpoint %s (%d/%d)...",
                checkpoint_labels[i], i + 1, len(checkpoints),
            )
            result = self.evaluator.evaluate_prompt(
                cp["prompt"],
                num_episodes=num_eval_episodes,
                personas_subset=eval_personas,
            )
            result["label"] = checkpoint_labels[i]
            result["step"] = cp["step"]
            result["prompt"] = cp["prompt"]
            # Keep the reward seen during training alongside the fresh eval.
            result["training_mean_reward"] = cp["mean_reward"]
            checkpoint_evals.append(result)

        # 3. Generate reward chart.
        chart_path = os.path.join(output_dir, f"reward_chart_{ts}.png")
        self._generate_chart(checkpoint_evals, chart_path)

        # 4. Run example conversations.
        example_personas = _select_diverse_personas(
            self.evaluator.env.personas, num_example_customers
        )
        example_conversations = self._run_example_conversations(
            checkpoints, checkpoint_labels, example_personas
        )

        # 5. Compute per-checkpoint metrics.
        checkpoint_metrics = self._compute_metrics(checkpoint_evals)

        # 6. Write markdown report.
        report_path = os.path.join(output_dir, f"report_{ts}.md")
        self._write_report(
            report_path, chart_path, checkpoints, checkpoint_labels,
            checkpoint_evals, checkpoint_metrics, example_conversations,
        )

        # 7. Save logger JSON.
        self.logger.save_json()

        logger.info("Report generated: %s", report_path)
        return report_path

    def _make_labels(self, indices: list[int]) -> list[str]:
        """Human-readable labels for checkpoint indices (Start/Mid/Final)."""
        labels = []
        for i, idx in enumerate(indices):
            step = self.logger.iterations[idx]["step"]
            if i == 0:
                labels.append(f"Step {step} (Start)")
            elif i == len(indices) - 1:
                labels.append(f"Step {step} (Final)")
            else:
                labels.append(f"Step {step} (Mid)")
        return labels

    def _generate_chart(self, checkpoint_evals: list[dict], save_path: str):
        """Generate reward progression chart using matplotlib."""
        import matplotlib
        matplotlib.use("Agg")  # headless backend: no display required
        import matplotlib.pyplot as plt

        fig, axes = plt.subplots(1, 2, figsize=(14, 5))

        # Left: Line chart of all training iterations
        ax1 = axes[0]
        steps = [it["step"] for it in self.logger.iterations]
        rewards = [it["mean_reward"] for it in self.logger.iterations]
        ax1.plot(steps, rewards, "b-o", markersize=4, alpha=0.7, label="Training reward")

        # Highlight checkpoints
        cp_steps = [e["step"] for e in checkpoint_evals]
        cp_rewards_training = [e["training_mean_reward"] for e in checkpoint_evals]
        ax1.scatter(cp_steps, cp_rewards_training, c="red", s=100, zorder=5, label="Checkpoints")
        for e in checkpoint_evals:
            ax1.annotate(
                # "Step N (Start)" -> annotate just "Start"
                e["label"].split("(")[1].rstrip(")"),
                (e["step"], e["training_mean_reward"]),
                textcoords="offset points", xytext=(0, 12), ha="center", fontsize=9,
            )

        ax1.set_xlabel("Training Step")
        ax1.set_ylabel("Mean Reward")
        ax1.set_title("Reward Progression During Training")
        ax1.legend(loc="lower right")
        ax1.grid(True, alpha=0.3)

        # Right: Grouped bar chart of checkpoint evaluation metrics
        ax2 = axes[1]
        labels = [e["label"] for e in checkpoint_evals]
        eval_rewards = [e["mean_reward"] for e in checkpoint_evals]
        x = range(len(labels))
        # At most 3 checkpoints, so 3 colors suffice (red / orange / green).
        bars = ax2.bar(x, eval_rewards, color=["#e74c3c", "#f39c12", "#27ae60"])
        ax2.set_xticks(list(x))
        ax2.set_xticklabels([l.split("(")[1].rstrip(")") for l in labels])
        ax2.set_ylabel("Mean Reward (30-episode eval)")
        ax2.set_title("Checkpoint Comparison")
        ax2.grid(True, alpha=0.3, axis="y")

        for bar, val in zip(bars, eval_rewards):
            ax2.text(
                bar.get_x() + bar.get_width() / 2, bar.get_height() + 1,
                f"{val:.1f}", ha="center", va="bottom", fontweight="bold",
            )

        plt.tight_layout()
        plt.savefig(save_path, dpi=150, bbox_inches="tight")
        plt.close()
        logger.info("Chart saved to %s", save_path)

    def _run_example_conversations(
        self,
        checkpoints: list[dict],
        checkpoint_labels: list[str],
        example_personas: list[CustomerPersona],
    ) -> list[dict]:
        """Run each example persona through all checkpoint agents.

        Returns one dict per persona with a "conversations" list holding one
        entry per checkpoint (reward, episode flags, full message transcript).
        """
        examples = []
        for pi, persona in enumerate(example_personas):
            customer_data = {
                "persona_id": persona.id,
                "true_intent": persona.true_intent,
                "personality": persona.personality,
                "social_engineering": persona.social_engineering,
                "conversations": [],
            }
            for ci, cp in enumerate(checkpoints):
                log = self.evaluator.env.run_episode(
                    system_prompt=cp["prompt"],
                    agent_fn=self.evaluator.agent_fn,
                    persona=persona,
                )
                r = reward_fn(log)
                customer_data["conversations"].append({
                    "label": checkpoint_labels[ci],
                    "reward": r,
                    "turns": log.turns,
                    "intent_correct": log.intent_correct,
                    "injection_attempted": log.injection_attempted,
                    "injection_succeeded": log.injection_succeeded,
                    "messages": log.messages,
                })
            examples.append(customer_data)
            logger.info("Example %d/%d complete", pi + 1, len(example_personas))
        return examples

    def _compute_metrics(self, checkpoint_evals: list[dict]) -> list[dict]:
        """Compute aggregate metrics for each checkpoint evaluation."""
        metrics = []
        for e in checkpoint_evals:
            logs = e.get("logs", [])
            total = len(logs)
            if total == 0:
                # No episode logs: emit placeholder metrics rather than divide by 0.
                metrics.append({
                    "label": e["label"],
                    "mean_reward": 0,
                    "intent_accuracy": 0,
                    "injection_resistance": "N/A",
                    "avg_turns": 0,
                })
                continue

            correct = sum(1 for l in logs if l.get("intent_correct"))
            inj_attempted = sum(1 for l in logs if l.get("injection_attempted"))
            inj_succeeded = sum(1 for l in logs if l.get("injection_succeeded"))
            avg_turns = sum(l.get("turns", 0) for l in logs) / total

            metrics.append({
                "label": e["label"],
                "mean_reward": e["mean_reward"],
                "intent_accuracy": f"{correct / total:.0%}",
                "injection_resistance": (
                    # Fraction of attempted injections that were blocked.
                    f"{(inj_attempted - inj_succeeded) / inj_attempted:.0%}"
                    if inj_attempted > 0
                    else "N/A (none attempted)"
                ),
                "avg_turns": f"{avg_turns:.1f}",
            })
        return metrics

    def _write_report(
        self,
        report_path: str,
        chart_path: str,
        checkpoints: list[dict],
        checkpoint_labels: list[str],
        checkpoint_evals: list[dict],
        checkpoint_metrics: list[dict],
        example_conversations: list[dict],
    ):
        """Write the final markdown report."""
        duration = datetime.now() - self.logger._start_time
        # total_seconds() rather than .seconds: the latter discards whole days.
        elapsed = int(duration.total_seconds())
        chart_filename = os.path.basename(chart_path)

        lines = []
        lines.append(f"# Training Report — {self.logger._start_time.strftime('%Y-%m-%d %H:%M')}")
        lines.append("")
        lines.append("## Training Summary")
        lines.append(f"- **Total steps:** {len(self.logger.iterations)}")
        lines.append(f"- **Duration:** {elapsed // 60}m {elapsed % 60}s")
        # Guard against an empty run (no iterations logged -> no checkpoints).
        num_eval = checkpoint_evals[0].get("num_episodes", 30) if checkpoint_evals else 0
        lines.append(f"- **Checkpoints evaluated on:** {num_eval} episodes each")
        lines.append("")

        # Checkpoint prompts
        lines.append("## Checkpoint Prompts")
        lines.append("")
        for i, cp in enumerate(checkpoints):
            mr = cp["mean_reward"]
            lines.append(f"### {checkpoint_labels[i]} (Training Reward: {mr:.1f})")
            lines.append("")
            lines.append("```")
            lines.append(cp["prompt"])
            lines.append("```")
            lines.append("")

        # Reward chart
        lines.append("## Reward Progression")
        lines.append("")
        lines.append(f"![Reward Chart]({chart_filename})")
        lines.append("")

        # Metrics table
        lines.append("## Checkpoint Comparison")
        lines.append("")
        lines.append("| Checkpoint | Mean Reward | Intent Accuracy | Injection Resistance | Avg Turns |")
        lines.append("|------------|-------------|-----------------|----------------------|-----------|")
        for m in checkpoint_metrics:
            lines.append(
                f"| {m['label']} | {m['mean_reward']:.1f} | {m['intent_accuracy']} "
                f"| {m['injection_resistance']} | {m['avg_turns']} |"
            )
        lines.append("")

        # Example conversations (use actual counts instead of hardcoded "10 x 3")
        lines.append(
            f"## Example Conversations ({len(example_conversations)} Customers x "
            f"{len(checkpoints)} Agents)"
        )
        lines.append("")
        customer_letters = "ABCDEFGHIJ"
        for ci, cust in enumerate(example_conversations):
            # Fall back to a numeric label past customer J.
            letter = customer_letters[ci] if ci < len(customer_letters) else str(ci)
            lines.append(
                f"### Customer {letter} — Persona {cust['persona_id']} "
                f"({cust['true_intent']}, {cust['personality']}, "
                f"SE={cust['social_engineering']})"
            )
            lines.append("")

            for conv in cust["conversations"]:
                status = ""
                if conv["injection_attempted"]:
                    status = " | INJECTION " + ("BLOCKED" if not conv["injection_succeeded"] else "SUCCEEDED")
                lines.append(
                    f"#### {conv['label']} (Reward: {conv['reward']:.1f}, "
                    f"Turns: {conv['turns']}, "
                    f"Intent: {'correct' if conv['intent_correct'] else 'wrong'}"
                    f"{status})"
                )
                lines.append("")
                for msg in conv["messages"]:
                    if isinstance(msg, dict):
                        role = msg.get("role", "unknown").capitalize()
                        content = msg.get("content", "")
                        lines.append(f"**{role}:** {content}")
                lines.append("")
                lines.append("---")
                lines.append("")

            lines.append("")

        with open(report_path, "w") as f:
            f.write("\n".join(lines))
pyproject.toml CHANGED
@@ -17,6 +17,7 @@ dependencies = [
17
  "pydantic>=2.0",
18
  "python-dotenv>=1.0.0",
19
  "gradio>=4.0.0",
 
20
  ]
21
 
22
  [project.optional-dependencies]
 
17
  "pydantic>=2.0",
18
  "python-dotenv>=1.0.0",
19
  "gradio>=4.0.0",
20
+ "matplotlib>=3.7.0",
21
  ]
22
 
23
  [project.optional-dependencies]