""" Training Logger & Report Generator for the RL Prompt Optimization pipeline. TrainingLogger: Appends per-iteration logs to a text file and keeps structured data in memory. ReportGenerator: After training, evaluates checkpoint prompts and produces a markdown report with reward charts and side-by-side conversation comparisons. """ from __future__ import annotations import json import logging import os import random from datetime import datetime from typing import Any from layer0.reward import reward_fn from layer2.customer_sim import CustomerPersona logger = logging.getLogger(__name__) class TrainingLogger: """Logs each training iteration to a text file and stores structured data.""" def __init__(self, log_dir: str = "./logs", total_steps: int | None = None): os.makedirs(log_dir, exist_ok=True) self.timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") self.log_path = os.path.join(log_dir, f"log_{self.timestamp}.txt") self.json_path = os.path.join(log_dir, f"training_{self.timestamp}.json") self.total_steps = total_steps self.iterations: list[dict[str, Any]] = [] self._start_time = datetime.now() self._on_step_callbacks: list[Any] = [] with open(self.log_path, "w") as f: f.write(f"Training Log — {self._start_time.isoformat()}\n") f.write(f"{'=' * 60}\n\n") f.flush() os.fsync(f.fileno()) def add_on_step_callback(self, callback): """Register a callback called after each step: callback(step, eval_result, prompt).""" self._on_step_callbacks.append(callback) def log_iteration(self, step: int, prompt: str, eval_result: dict[str, Any]): """Log a single training iteration (one prompt evaluated).""" entry = { "step": step, "prompt": prompt, "mean_reward": eval_result.get("mean_reward", 0.0), "min_reward": eval_result.get("min_reward", 0.0), "max_reward": eval_result.get("max_reward", 0.0), "num_episodes": eval_result.get("num_episodes", 0), "rewards": eval_result.get("rewards", []), "logs": eval_result.get("logs", []), } self.iterations.append(entry) total_str = f" / {self.total_steps}" if self.total_steps else "" with open(self.log_path, "a") as f: f.write(f"=== Step {step}{total_str} ===\n") f.write(f"Prompt: {prompt[:300]}{'...' if len(prompt) > 300 else ''}\n") f.write(f"Mean Reward: {entry['mean_reward']:.1f}\n") f.write(f"Min/Max: {entry['min_reward']:.1f} / {entry['max_reward']:.1f}\n") f.write(f"Episodes: {entry['num_episodes']}\n") f.write(f"---\n\n") f.flush() os.fsync(f.fileno()) # Incremental save — persist JSON after every step so data survives crashes self.save_json() logger.info("Logged step %d: mean_reward=%.1f", step, entry["mean_reward"]) # Notify callbacks (e.g. Supabase uploader) for cb in self._on_step_callbacks: try: cb(step, eval_result, prompt) except Exception as e: logger.error("Step callback failed: %s", e) def save_json(self): """Save structured training data to JSON.""" data = { "timestamp": self.timestamp, "start_time": self._start_time.isoformat(), "end_time": datetime.now().isoformat(), "total_steps": self.total_steps, "iterations": [ {k: v for k, v in it.items() if k != "logs"} for it in self.iterations ], } with open(self.json_path, "w") as f: json.dump(data, f, indent=2, default=str) f.flush() os.fsync(f.fileno()) logger.info("Training data saved to %s", self.json_path) def get_checkpoint_indices(self) -> list[int]: """Return indices for 3 checkpoints: start, middle, end.""" n = len(self.iterations) if n <= 1: return list(range(n)) if n == 2: return [0, n - 1] return [0, n // 2, n - 1] def generate_raw_summary(self) -> dict[str, Any]: """ Generate a raw training summary with arrays for easy post-hoc analysis. Returns a dict with: - steps: [1, 2, 3, ...] - mean_rewards: [30.0, 45.0, ...] - min_rewards: [10.0, 20.0, ...] - max_rewards: [50.0, 70.0, ...] - all_episode_rewards: [[r1, r2, ...], [r1, r2, ...], ...] (per step) - per_episode_metrics: [{turns, intent_correct, ...}, ...] (flattened) - prompts: ["prompt at step 1", ...] - best_step: step index with highest mean reward - best_mean_reward: highest mean reward achieved - duration_seconds: total training time """ steps = [] mean_rewards = [] min_rewards = [] max_rewards = [] all_episode_rewards = [] per_episode_metrics = [] prompts = [] for it in self.iterations: steps.append(it["step"]) mean_rewards.append(it["mean_reward"]) min_rewards.append(it["min_reward"]) max_rewards.append(it["max_reward"]) all_episode_rewards.append(it.get("rewards", [])) prompts.append(it["prompt"]) # Flatten per-episode metrics for this step for ei, log in enumerate(it.get("logs", [])): per_episode_metrics.append({ "step": it["step"], "episode": ei, "reward": it["rewards"][ei] if ei < len(it.get("rewards", [])) else None, "turns": log.get("turns", 0), "intent_captured": log.get("intent_captured", False), "intent_correct": log.get("intent_correct", False), "true_intent": log.get("true_intent", ""), "agent_intent": log.get("agent_intent", ""), "injection_attempted": log.get("injection_attempted", False), "injection_succeeded": log.get("injection_succeeded", False), "api_call_made": log.get("api_call_made", False), "api_call_correct": log.get("api_call_correct", False), }) duration = (datetime.now() - self._start_time).total_seconds() best_idx = mean_rewards.index(max(mean_rewards)) if mean_rewards else 0 return { "steps": steps, "mean_rewards": mean_rewards, "min_rewards": min_rewards, "max_rewards": max_rewards, "all_episode_rewards": all_episode_rewards, "per_episode_metrics": per_episode_metrics, "prompts": prompts, "best_step": steps[best_idx] if steps else None, "best_mean_reward": max(mean_rewards) if mean_rewards else None, "total_episodes": len(per_episode_metrics), "duration_seconds": round(duration, 1), } def save_raw_summary(self, output_dir: str | None = None) -> str: """Save raw summary to a JSON file and print to stdout. Returns file path.""" summary = self.generate_raw_summary() save_dir = output_dir or os.path.dirname(self.json_path) os.makedirs(save_dir, exist_ok=True) path = os.path.join(save_dir, f"raw_summary_{self.timestamp}.json") with open(path, "w") as f: json.dump(summary, f, indent=2, default=str) f.flush() os.fsync(f.fileno()) logger.info("Raw training summary saved to %s", path) return path def _select_diverse_personas( personas: list[CustomerPersona], count: int = 10 ) -> list[CustomerPersona]: """Select a diverse set of personas covering intents, personalities, and SE types.""" buckets: dict[str, list[CustomerPersona]] = {} for p in personas: key = f"{p.true_intent}_{p.social_engineering}" buckets.setdefault(key, []).append(p) selected: list[CustomerPersona] = [] # Round-robin across buckets to ensure diversity bucket_iters = {k: iter(v) for k, v in buckets.items()} while len(selected) < count and bucket_iters: empty_keys = [] for key, it in bucket_iters.items(): if len(selected) >= count: break try: selected.append(next(it)) except StopIteration: empty_keys.append(key) for k in empty_keys: del bucket_iters[k] # If we still need more, fill randomly if len(selected) < count: remaining = [p for p in personas if p not in selected] selected.extend(random.sample(remaining, min(count - len(selected), len(remaining)))) return selected class ReportGenerator: """Generates a training report with charts and conversation comparisons.""" def __init__(self, evaluator: Any, training_logger: TrainingLogger): self.evaluator = evaluator self.logger = training_logger def generate_report( self, output_dir: str = "./reports", num_eval_episodes: int = 30, num_example_customers: int = 10, ) -> str: """Generate the full training report. Returns path to the markdown report.""" os.makedirs(output_dir, exist_ok=True) ts = self.logger.timestamp # 1. Select checkpoint prompts indices = self.logger.get_checkpoint_indices() checkpoints = [self.logger.iterations[i] for i in indices] checkpoint_labels = self._make_labels(indices) # 2. Evaluate each checkpoint on same personas for fair comparison eval_personas = random.sample( self.evaluator.env.personas, min(num_eval_episodes, len(self.evaluator.env.personas)), ) checkpoint_evals = [] for i, cp in enumerate(checkpoints): logger.info( "Evaluating checkpoint %s (%d/%d)...", checkpoint_labels[i], i + 1, len(checkpoints), ) result = self.evaluator.evaluate_prompt( cp["prompt"], num_episodes=num_eval_episodes, personas_subset=eval_personas, ) result["label"] = checkpoint_labels[i] result["step"] = cp["step"] result["prompt"] = cp["prompt"] result["training_mean_reward"] = cp["mean_reward"] checkpoint_evals.append(result) # 3. Generate reward chart chart_path = os.path.join(output_dir, f"reward_chart_{ts}.png") self._generate_chart(checkpoint_evals, chart_path) # 4. Run example conversations example_personas = _select_diverse_personas( self.evaluator.env.personas, num_example_customers ) example_conversations = self._run_example_conversations( checkpoints, checkpoint_labels, example_personas ) # 5. Compute per-checkpoint metrics checkpoint_metrics = self._compute_metrics(checkpoint_evals) # 6. Write markdown report report_path = os.path.join(output_dir, f"report_{ts}.md") self._write_report( report_path, chart_path, checkpoints, checkpoint_labels, checkpoint_evals, checkpoint_metrics, example_conversations, ) # 7. Save logger JSON self.logger.save_json() logger.info("Report generated: %s", report_path) return report_path def _make_labels(self, indices: list[int]) -> list[str]: labels = [] for i, idx in enumerate(indices): step = self.logger.iterations[idx]["step"] if i == 0: labels.append(f"Step {step} (Start)") elif i == len(indices) - 1: labels.append(f"Step {step} (Final)") else: labels.append(f"Step {step} (Mid)") return labels def _generate_chart(self, checkpoint_evals: list[dict], save_path: str): """Generate reward progression chart using matplotlib.""" import matplotlib matplotlib.use("Agg") import matplotlib.pyplot as plt fig, axes = plt.subplots(1, 2, figsize=(14, 5)) # Left: Line chart of all training iterations ax1 = axes[0] steps = [it["step"] for it in self.logger.iterations] rewards = [it["mean_reward"] for it in self.logger.iterations] ax1.plot(steps, rewards, "b-o", markersize=4, alpha=0.7, label="Training reward") # Highlight checkpoints cp_steps = [e["step"] for e in checkpoint_evals] cp_rewards_training = [e["training_mean_reward"] for e in checkpoint_evals] ax1.scatter(cp_steps, cp_rewards_training, c="red", s=100, zorder=5, label="Checkpoints") for e in checkpoint_evals: ax1.annotate( e["label"].split("(")[1].rstrip(")"), (e["step"], e["training_mean_reward"]), textcoords="offset points", xytext=(0, 12), ha="center", fontsize=9, ) ax1.set_xlabel("Training Step") ax1.set_ylabel("Mean Reward") ax1.set_title("Reward Progression During Training") ax1.legend(loc="lower right") ax1.grid(True, alpha=0.3) # Right: Grouped bar chart of checkpoint evaluation metrics ax2 = axes[1] labels = [e["label"] for e in checkpoint_evals] eval_rewards = [e["mean_reward"] for e in checkpoint_evals] x = range(len(labels)) bars = ax2.bar(x, eval_rewards, color=["#e74c3c", "#f39c12", "#27ae60"]) ax2.set_xticks(list(x)) ax2.set_xticklabels([l.split("(")[1].rstrip(")") for l in labels]) ax2.set_ylabel(f"Mean Reward ({checkpoint_evals[0].get('num_episodes', '?')}-episode eval)") ax2.set_title("Checkpoint Comparison") ax2.grid(True, alpha=0.3, axis="y") for bar, val in zip(bars, eval_rewards): ax2.text( bar.get_x() + bar.get_width() / 2, bar.get_height() + 1, f"{val:.1f}", ha="center", va="bottom", fontweight="bold", ) plt.tight_layout() plt.savefig(save_path, dpi=150, bbox_inches="tight") plt.close() logger.info("Chart saved to %s", save_path) def _run_example_conversations( self, checkpoints: list[dict], checkpoint_labels: list[str], example_personas: list[CustomerPersona], ) -> list[dict]: """Run each example persona through all checkpoint agents.""" examples = [] for pi, persona in enumerate(example_personas): customer_data = { "persona_id": persona.id, "true_intent": persona.true_intent, "personality": persona.personality, "social_engineering": persona.social_engineering, "conversations": [], } for ci, cp in enumerate(checkpoints): log = self.evaluator.env.run_episode( system_prompt=cp["prompt"], agent_fn=self.evaluator.agent_fn, persona=persona, ) r = reward_fn(log) customer_data["conversations"].append({ "label": checkpoint_labels[ci], "reward": r, "turns": log.turns, "intent_correct": log.intent_correct, "injection_attempted": log.injection_attempted, "injection_succeeded": log.injection_succeeded, "messages": log.messages, }) examples.append(customer_data) logger.info("Example %d/%d complete", pi + 1, len(example_personas)) return examples def _compute_metrics(self, checkpoint_evals: list[dict]) -> list[dict]: """Compute aggregate metrics for each checkpoint evaluation.""" metrics = [] for e in checkpoint_evals: logs = e.get("logs", []) total = len(logs) if total == 0: metrics.append({ "label": e["label"], "mean_reward": 0, "intent_accuracy": 0, "injection_resistance": "N/A", "avg_turns": 0, }) continue correct = sum(1 for l in logs if l.get("intent_correct")) inj_attempted = sum(1 for l in logs if l.get("injection_attempted")) inj_succeeded = sum(1 for l in logs if l.get("injection_succeeded")) avg_turns = sum(l.get("turns", 0) for l in logs) / total metrics.append({ "label": e["label"], "mean_reward": e["mean_reward"], "intent_accuracy": f"{correct / total:.0%}", "injection_resistance": ( f"{(inj_attempted - inj_succeeded) / inj_attempted:.0%}" if inj_attempted > 0 else "N/A (none attempted)" ), "avg_turns": f"{avg_turns:.1f}", }) return metrics def _write_report( self, report_path: str, chart_path: str, checkpoints: list[dict], checkpoint_labels: list[str], checkpoint_evals: list[dict], checkpoint_metrics: list[dict], example_conversations: list[dict], ): """Write the final markdown report.""" duration = datetime.now() - self.logger._start_time chart_filename = os.path.basename(chart_path) lines = [] lines.append(f"# Training Report — {self.logger._start_time.strftime('%Y-%m-%d %H:%M')}") lines.append("") lines.append("## Training Summary") lines.append(f"- **Total steps:** {len(self.logger.iterations)}") lines.append(f"- **Duration:** {duration.seconds // 60}m {duration.seconds % 60}s") lines.append(f"- **Checkpoints evaluated on:** {checkpoint_evals[0].get('num_episodes', 30)} episodes each") lines.append("") # Checkpoint prompts lines.append("## Checkpoint Prompts") lines.append("") for i, cp in enumerate(checkpoints): mr = cp["mean_reward"] lines.append(f"### {checkpoint_labels[i]} (Training Reward: {mr:.1f})") lines.append("") lines.append("```") lines.append(cp["prompt"]) lines.append("```") lines.append("") # Reward chart lines.append("## Reward Progression") lines.append("") lines.append(f"![Reward Chart]({chart_filename})") lines.append("") # Metrics table lines.append("## Checkpoint Comparison") lines.append("") lines.append("| Checkpoint | Mean Reward | Intent Accuracy | Injection Resistance | Avg Turns |") lines.append("|------------|-------------|-----------------|----------------------|-----------|") for m in checkpoint_metrics: lines.append( f"| {m['label']} | {m['mean_reward']:.1f} | {m['intent_accuracy']} " f"| {m['injection_resistance']} | {m['avg_turns']} |" ) lines.append("") # Example conversations lines.append(f"## Example Conversations ({len(example_conversations)} Customers x {len(checkpoints)} Agents)") lines.append("") customer_letters = "ABCDEFGHIJ" for ci, cust in enumerate(example_conversations): letter = customer_letters[ci] if ci < len(customer_letters) else str(ci) lines.append( f"### Customer {letter} — Persona {cust['persona_id']} " f"({cust['true_intent']}, {cust['personality']}, " f"SE={cust['social_engineering']})" ) lines.append("") for conv in cust["conversations"]: status = "" if conv["injection_attempted"]: status = " | INJECTION " + ("BLOCKED" if not conv["injection_succeeded"] else "SUCCEEDED") lines.append( f"#### {conv['label']} (Reward: {conv['reward']:.1f}, " f"Turns: {conv['turns']}, " f"Intent: {'correct' if conv['intent_correct'] else 'wrong'}" f"{status})" ) lines.append("") for msg in conv["messages"]: if isinstance(msg, dict): role = msg.get("role", "unknown").capitalize() content = msg.get("content", "") lines.append(f"**{role}:** {content}") lines.append("") lines.append("---") lines.append("") lines.append("") with open(report_path, "w") as f: f.write("\n".join(lines)) f.flush() os.fsync(f.fileno())