Spaces:
Running on T4
Running on T4
| """ | |
| Training Logger & Report Generator for the RL Prompt Optimization pipeline. | |
| TrainingLogger: Appends per-iteration logs to a text file and keeps structured data in memory. | |
| ReportGenerator: After training, evaluates checkpoint prompts and produces a markdown report | |
| with reward charts and side-by-side conversation comparisons. | |
| """ | |
| from __future__ import annotations | |
| import json | |
| import logging | |
| import os | |
| import random | |
| from datetime import datetime | |
| from typing import Any | |
| from layer0.reward import reward_fn | |
| from layer2.customer_sim import CustomerPersona | |
| logger = logging.getLogger(__name__) | |
class TrainingLogger:
    """Logs each training iteration to a text file and stores structured data.

    Produces two incrementally updated artifacts in ``log_dir``:

    * ``log_<timestamp>.txt``       — human-readable per-step log (appended)
    * ``training_<timestamp>.json`` — structured snapshot, rewritten each step

    Every write is flushed and fsync'd so the data survives a crash mid-run.
    """

    def __init__(self, log_dir: str = "./logs", total_steps: int | None = None):
        """Create ``log_dir`` if needed and start a fresh timestamped log file.

        Args:
            log_dir: Directory for log artifacts (created if missing).
            total_steps: Expected number of training steps; used only for the
                "Step X / N" display. May be None when unknown.
        """
        os.makedirs(log_dir, exist_ok=True)
        self.timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        self.log_path = os.path.join(log_dir, f"log_{self.timestamp}.txt")
        self.json_path = os.path.join(log_dir, f"training_{self.timestamp}.json")
        self.total_steps = total_steps
        self.iterations: list[dict[str, Any]] = []
        self._start_time = datetime.now()
        self._on_step_callbacks: list[Any] = []
        with open(self.log_path, "w") as f:
            # BUGFIX: the separator was a mojibake "β" (mis-encoded em dash).
            f.write(f"Training Log — {self._start_time.isoformat()}\n")
            f.write(f"{'=' * 60}\n\n")
            f.flush()
            os.fsync(f.fileno())

    def add_on_step_callback(self, callback):
        """Register a callback called after each step: callback(step, eval_result, prompt)."""
        self._on_step_callbacks.append(callback)

    def log_iteration(self, step: int, prompt: str, eval_result: dict[str, Any]):
        """Log a single training iteration (one prompt evaluated).

        Appends a human-readable entry to the text log, persists the JSON
        snapshot, then notifies registered callbacks. Callback failures are
        logged and swallowed so an observer can never abort training.

        Args:
            step: 1-based training step number.
            prompt: The system prompt that was evaluated this step.
            eval_result: Dict with ``mean_reward``, ``min_reward``,
                ``max_reward``, ``num_episodes``, ``rewards``, ``logs``
                (missing keys default to 0 / empty).
        """
        entry = {
            "step": step,
            "prompt": prompt,
            "mean_reward": eval_result.get("mean_reward", 0.0),
            "min_reward": eval_result.get("min_reward", 0.0),
            "max_reward": eval_result.get("max_reward", 0.0),
            "num_episodes": eval_result.get("num_episodes", 0),
            "rewards": eval_result.get("rewards", []),
            "logs": eval_result.get("logs", []),
        }
        self.iterations.append(entry)
        total_str = f" / {self.total_steps}" if self.total_steps else ""
        with open(self.log_path, "a") as f:
            f.write(f"=== Step {step}{total_str} ===\n")
            # Prompts can be very long — truncate the preview at 300 chars.
            f.write(f"Prompt: {prompt[:300]}{'...' if len(prompt) > 300 else ''}\n")
            f.write(f"Mean Reward: {entry['mean_reward']:.1f}\n")
            f.write(f"Min/Max: {entry['min_reward']:.1f} / {entry['max_reward']:.1f}\n")
            f.write(f"Episodes: {entry['num_episodes']}\n")
            f.write("---\n\n")
            f.flush()
            os.fsync(f.fileno())
        # Incremental save — persist JSON after every step so data survives crashes.
        self.save_json()
        logger.info("Logged step %d: mean_reward=%.1f", step, entry["mean_reward"])
        # Notify callbacks (e.g. a Supabase uploader); never let one kill training.
        for cb in self._on_step_callbacks:
            try:
                cb(step, eval_result, prompt)
            except Exception as e:
                logger.error("Step callback failed: %s", e)

    def save_json(self):
        """Save structured training data to JSON (conversation logs excluded).

        Per-episode ``logs`` are dropped to keep the file small; ``default=str``
        stringifies anything json can't serialize natively (e.g. datetimes).
        """
        data = {
            "timestamp": self.timestamp,
            "start_time": self._start_time.isoformat(),
            "end_time": datetime.now().isoformat(),
            "total_steps": self.total_steps,
            "iterations": [
                {k: v for k, v in it.items() if k != "logs"}
                for it in self.iterations
            ],
        }
        with open(self.json_path, "w") as f:
            json.dump(data, f, indent=2, default=str)
            f.flush()
            os.fsync(f.fileno())
        logger.info("Training data saved to %s", self.json_path)

    def get_checkpoint_indices(self) -> list[int]:
        """Return indices for up to 3 checkpoints: start, middle, end.

        With 0 or 1 iterations returns whatever exists; with 2, first and last;
        otherwise first, middle, and last.
        """
        n = len(self.iterations)
        if n <= 1:
            return list(range(n))
        if n == 2:
            return [0, n - 1]
        return [0, n // 2, n - 1]

    def generate_raw_summary(self) -> dict[str, Any]:
        """
        Generate a raw training summary with arrays for easy post-hoc analysis.

        Returns a dict with:
        - steps: [1, 2, 3, ...]
        - mean_rewards: [30.0, 45.0, ...]
        - min_rewards: [10.0, 20.0, ...]
        - max_rewards: [50.0, 70.0, ...]
        - all_episode_rewards: [[r1, r2, ...], [r1, r2, ...], ...] (per step)
        - per_episode_metrics: [{turns, intent_correct, ...}, ...] (flattened)
        - prompts: ["prompt at step 1", ...]
        - best_step: step index with highest mean reward
        - best_mean_reward: highest mean reward achieved
        - total_episodes: number of flattened per-episode entries
        - duration_seconds: total training time
        """
        steps = []
        mean_rewards = []
        min_rewards = []
        max_rewards = []
        all_episode_rewards = []
        per_episode_metrics = []
        prompts = []
        for it in self.iterations:
            steps.append(it["step"])
            mean_rewards.append(it["mean_reward"])
            min_rewards.append(it["min_reward"])
            max_rewards.append(it["max_reward"])
            all_episode_rewards.append(it.get("rewards", []))
            prompts.append(it["prompt"])
            # Flatten per-episode metrics for this step.
            for ei, log in enumerate(it.get("logs", [])):
                per_episode_metrics.append({
                    "step": it["step"],
                    "episode": ei,
                    # logs and rewards may be misaligned — guard the index.
                    "reward": it["rewards"][ei] if ei < len(it.get("rewards", [])) else None,
                    "turns": log.get("turns", 0),
                    "intent_captured": log.get("intent_captured", False),
                    "intent_correct": log.get("intent_correct", False),
                    "true_intent": log.get("true_intent", ""),
                    "agent_intent": log.get("agent_intent", ""),
                    "injection_attempted": log.get("injection_attempted", False),
                    "injection_succeeded": log.get("injection_succeeded", False),
                    "api_call_made": log.get("api_call_made", False),
                    "api_call_correct": log.get("api_call_correct", False),
                })
        duration = (datetime.now() - self._start_time).total_seconds()
        best_idx = mean_rewards.index(max(mean_rewards)) if mean_rewards else 0
        return {
            "steps": steps,
            "mean_rewards": mean_rewards,
            "min_rewards": min_rewards,
            "max_rewards": max_rewards,
            "all_episode_rewards": all_episode_rewards,
            "per_episode_metrics": per_episode_metrics,
            "prompts": prompts,
            "best_step": steps[best_idx] if steps else None,
            "best_mean_reward": max(mean_rewards) if mean_rewards else None,
            "total_episodes": len(per_episode_metrics),
            "duration_seconds": round(duration, 1),
        }

    def save_raw_summary(self, output_dir: str | None = None) -> str:
        """Save the raw summary to a JSON file. Returns the file path.

        Args:
            output_dir: Target directory; defaults to the directory holding
                the training JSON.
        """
        summary = self.generate_raw_summary()
        save_dir = output_dir or os.path.dirname(self.json_path)
        os.makedirs(save_dir, exist_ok=True)
        path = os.path.join(save_dir, f"raw_summary_{self.timestamp}.json")
        with open(path, "w") as f:
            json.dump(summary, f, indent=2, default=str)
            f.flush()
            os.fsync(f.fileno())
        logger.info("Raw training summary saved to %s", path)
        return path
| def _select_diverse_personas( | |
| personas: list[CustomerPersona], count: int = 10 | |
| ) -> list[CustomerPersona]: | |
| """Select a diverse set of personas covering intents, personalities, and SE types.""" | |
| buckets: dict[str, list[CustomerPersona]] = {} | |
| for p in personas: | |
| key = f"{p.true_intent}_{p.social_engineering}" | |
| buckets.setdefault(key, []).append(p) | |
| selected: list[CustomerPersona] = [] | |
| # Round-robin across buckets to ensure diversity | |
| bucket_iters = {k: iter(v) for k, v in buckets.items()} | |
| while len(selected) < count and bucket_iters: | |
| empty_keys = [] | |
| for key, it in bucket_iters.items(): | |
| if len(selected) >= count: | |
| break | |
| try: | |
| selected.append(next(it)) | |
| except StopIteration: | |
| empty_keys.append(key) | |
| for k in empty_keys: | |
| del bucket_iters[k] | |
| # If we still need more, fill randomly | |
| if len(selected) < count: | |
| remaining = [p for p in personas if p not in selected] | |
| selected.extend(random.sample(remaining, min(count - len(selected), len(remaining)))) | |
| return selected | |
class ReportGenerator:
    """Generates a training report with charts and conversation comparisons.

    After training, re-evaluates up to three checkpoint prompts
    (start / mid / final) on a shared persona set, renders a reward chart,
    replays example conversations against every checkpoint, and writes a
    markdown report tying it all together.
    """

    def __init__(self, evaluator: Any, training_logger: TrainingLogger):
        """
        Args:
            evaluator: Object exposing ``evaluate_prompt(...)``, ``agent_fn``,
                and an ``env`` with ``personas`` and ``run_episode(...)``.
            training_logger: Logger holding the per-iteration training history.
        """
        self.evaluator = evaluator
        self.logger = training_logger

    def generate_report(
        self,
        output_dir: str = "./reports",
        num_eval_episodes: int = 30,
        num_example_customers: int = 10,
    ) -> str:
        """Generate the full training report. Returns path to the markdown report.

        Args:
            output_dir: Directory for the report and chart (created if missing).
            num_eval_episodes: Episodes per checkpoint evaluation.
            num_example_customers: Diverse personas replayed through every
                checkpoint for side-by-side comparison.
        """
        os.makedirs(output_dir, exist_ok=True)
        ts = self.logger.timestamp
        # 1. Select checkpoint prompts (start / middle / end of training).
        indices = self.logger.get_checkpoint_indices()
        checkpoints = [self.logger.iterations[i] for i in indices]
        checkpoint_labels = self._make_labels(indices)
        # 2. Evaluate each checkpoint on the SAME personas for a fair comparison.
        eval_personas = random.sample(
            self.evaluator.env.personas,
            min(num_eval_episodes, len(self.evaluator.env.personas)),
        )
        checkpoint_evals = []
        for i, cp in enumerate(checkpoints):
            logger.info(
                "Evaluating checkpoint %s (%d/%d)...",
                checkpoint_labels[i], i + 1, len(checkpoints),
            )
            result = self.evaluator.evaluate_prompt(
                cp["prompt"],
                num_episodes=num_eval_episodes,
                personas_subset=eval_personas,
            )
            result["label"] = checkpoint_labels[i]
            result["step"] = cp["step"]
            result["prompt"] = cp["prompt"]
            result["training_mean_reward"] = cp["mean_reward"]
            checkpoint_evals.append(result)
        # 3. Generate reward chart.
        chart_path = os.path.join(output_dir, f"reward_chart_{ts}.png")
        self._generate_chart(checkpoint_evals, chart_path)
        # 4. Run example conversations.
        example_personas = _select_diverse_personas(
            self.evaluator.env.personas, num_example_customers
        )
        example_conversations = self._run_example_conversations(
            checkpoints, checkpoint_labels, example_personas
        )
        # 5. Compute per-checkpoint metrics.
        checkpoint_metrics = self._compute_metrics(checkpoint_evals)
        # 6. Write markdown report.
        report_path = os.path.join(output_dir, f"report_{ts}.md")
        self._write_report(
            report_path, chart_path, checkpoints, checkpoint_labels,
            checkpoint_evals, checkpoint_metrics, example_conversations,
        )
        # 7. Save logger JSON.
        self.logger.save_json()
        logger.info("Report generated: %s", report_path)
        return report_path

    def _make_labels(self, indices: list[int]) -> list[str]:
        """Build human-readable labels ("Step N (Start/Mid/Final)") for checkpoints."""
        labels = []
        for i, idx in enumerate(indices):
            step = self.logger.iterations[idx]["step"]
            if i == 0:
                labels.append(f"Step {step} (Start)")
            elif i == len(indices) - 1:
                labels.append(f"Step {step} (Final)")
            else:
                labels.append(f"Step {step} (Mid)")
        return labels

    def _generate_chart(self, checkpoint_evals: list[dict], save_path: str):
        """Generate reward progression chart using matplotlib.

        Left panel: mean reward per training step with checkpoints highlighted.
        Right panel: bar chart of the checkpoint re-evaluation rewards.
        """
        # Imported lazily so the rest of the module works without matplotlib;
        # "Agg" backend renders to file without a display.
        import matplotlib
        matplotlib.use("Agg")
        import matplotlib.pyplot as plt
        fig, axes = plt.subplots(1, 2, figsize=(14, 5))
        # Left: line chart of all training iterations.
        ax1 = axes[0]
        steps = [it["step"] for it in self.logger.iterations]
        rewards = [it["mean_reward"] for it in self.logger.iterations]
        ax1.plot(steps, rewards, "b-o", markersize=4, alpha=0.7, label="Training reward")
        # Highlight checkpoints.
        cp_steps = [e["step"] for e in checkpoint_evals]
        cp_rewards_training = [e["training_mean_reward"] for e in checkpoint_evals]
        ax1.scatter(cp_steps, cp_rewards_training, c="red", s=100, zorder=5, label="Checkpoints")
        for e in checkpoint_evals:
            ax1.annotate(
                # "Step N (Start)" -> "Start"
                e["label"].split("(")[1].rstrip(")"),
                (e["step"], e["training_mean_reward"]),
                textcoords="offset points", xytext=(0, 12), ha="center", fontsize=9,
            )
        ax1.set_xlabel("Training Step")
        ax1.set_ylabel("Mean Reward")
        ax1.set_title("Reward Progression During Training")
        ax1.legend(loc="lower right")
        ax1.grid(True, alpha=0.3)
        # Right: bar chart of checkpoint evaluation rewards.
        ax2 = axes[1]
        labels = [e["label"] for e in checkpoint_evals]
        eval_rewards = [e["mean_reward"] for e in checkpoint_evals]
        x = range(len(labels))
        bars = ax2.bar(x, eval_rewards, color=["#e74c3c", "#f39c12", "#27ae60"])
        ax2.set_xticks(list(x))
        ax2.set_xticklabels([l.split("(")[1].rstrip(")") for l in labels])
        ax2.set_ylabel(f"Mean Reward ({checkpoint_evals[0].get('num_episodes', '?')}-episode eval)")
        ax2.set_title("Checkpoint Comparison")
        ax2.grid(True, alpha=0.3, axis="y")
        for bar, val in zip(bars, eval_rewards):
            ax2.text(
                bar.get_x() + bar.get_width() / 2, bar.get_height() + 1,
                f"{val:.1f}", ha="center", va="bottom", fontweight="bold",
            )
        plt.tight_layout()
        plt.savefig(save_path, dpi=150, bbox_inches="tight")
        plt.close()
        logger.info("Chart saved to %s", save_path)

    def _run_example_conversations(
        self,
        checkpoints: list[dict],
        checkpoint_labels: list[str],
        example_personas: list[CustomerPersona],
    ) -> list[dict]:
        """Run each example persona through all checkpoint agents.

        Returns one dict per persona with its metadata and a ``conversations``
        list holding the episode outcome for every checkpoint.
        """
        examples = []
        for pi, persona in enumerate(example_personas):
            customer_data = {
                "persona_id": persona.id,
                "true_intent": persona.true_intent,
                "personality": persona.personality,
                "social_engineering": persona.social_engineering,
                "conversations": [],
            }
            for ci, cp in enumerate(checkpoints):
                log = self.evaluator.env.run_episode(
                    system_prompt=cp["prompt"],
                    agent_fn=self.evaluator.agent_fn,
                    persona=persona,
                )
                r = reward_fn(log)
                customer_data["conversations"].append({
                    "label": checkpoint_labels[ci],
                    "reward": r,
                    "turns": log.turns,
                    "intent_correct": log.intent_correct,
                    "injection_attempted": log.injection_attempted,
                    "injection_succeeded": log.injection_succeeded,
                    "messages": log.messages,
                })
            examples.append(customer_data)
            logger.info("Example %d/%d complete", pi + 1, len(example_personas))
        return examples

    def _compute_metrics(self, checkpoint_evals: list[dict]) -> list[dict]:
        """Compute aggregate metrics for each checkpoint evaluation.

        Produces, per checkpoint: mean reward, intent accuracy, injection
        resistance (share of attempted injections that were blocked), and
        average turns. Percentages are pre-formatted strings for the report.
        """
        metrics = []
        for e in checkpoint_evals:
            logs = e.get("logs", [])
            total = len(logs)
            if total == 0:
                # No episode logs — emit a placeholder row instead of dividing by zero.
                metrics.append({
                    "label": e["label"],
                    "mean_reward": 0,
                    "intent_accuracy": 0,
                    "injection_resistance": "N/A",
                    "avg_turns": 0,
                })
                continue
            correct = sum(1 for l in logs if l.get("intent_correct"))
            inj_attempted = sum(1 for l in logs if l.get("injection_attempted"))
            inj_succeeded = sum(1 for l in logs if l.get("injection_succeeded"))
            avg_turns = sum(l.get("turns", 0) for l in logs) / total
            metrics.append({
                "label": e["label"],
                "mean_reward": e["mean_reward"],
                "intent_accuracy": f"{correct / total:.0%}",
                "injection_resistance": (
                    f"{(inj_attempted - inj_succeeded) / inj_attempted:.0%}"
                    if inj_attempted > 0
                    else "N/A (none attempted)"
                ),
                "avg_turns": f"{avg_turns:.1f}",
            })
        return metrics

    def _write_report(
        self,
        report_path: str,
        chart_path: str,
        checkpoints: list[dict],
        checkpoint_labels: list[str],
        checkpoint_evals: list[dict],
        checkpoint_metrics: list[dict],
        example_conversations: list[dict],
    ):
        """Write the final markdown report to ``report_path``."""
        duration = datetime.now() - self.logger._start_time
        # BUGFIX: timedelta.seconds silently drops whole days; use total_seconds().
        elapsed = int(duration.total_seconds())
        chart_filename = os.path.basename(chart_path)
        lines = []
        # BUGFIX: title separator was a mojibake "β" (mis-encoded em dash).
        lines.append(f"# Training Report — {self.logger._start_time.strftime('%Y-%m-%d %H:%M')}")
        lines.append("")
        lines.append("## Training Summary")
        lines.append(f"- **Total steps:** {len(self.logger.iterations)}")
        lines.append(f"- **Duration:** {elapsed // 60}m {elapsed % 60}s")
        lines.append(f"- **Checkpoints evaluated on:** {checkpoint_evals[0].get('num_episodes', 30)} episodes each")
        lines.append("")
        # Checkpoint prompts.
        lines.append("## Checkpoint Prompts")
        lines.append("")
        for i, cp in enumerate(checkpoints):
            mr = cp["mean_reward"]
            lines.append(f"### {checkpoint_labels[i]} (Training Reward: {mr:.1f})")
            lines.append("")
            lines.append("```")
            lines.append(cp["prompt"])
            lines.append("```")
            lines.append("")
        # Reward chart.
        lines.append("## Reward Progression")
        lines.append("")
        # BUGFIX: this line was an empty f-string, so the chart (whose filename
        # was computed above but never used) was never embedded in the report.
        lines.append(f"![Reward Chart]({chart_filename})")
        lines.append("")
        # Metrics table.
        lines.append("## Checkpoint Comparison")
        lines.append("")
        lines.append("| Checkpoint | Mean Reward | Intent Accuracy | Injection Resistance | Avg Turns |")
        lines.append("|------------|-------------|-----------------|----------------------|-----------|")
        for m in checkpoint_metrics:
            lines.append(
                f"| {m['label']} | {m['mean_reward']:.1f} | {m['intent_accuracy']} "
                f"| {m['injection_resistance']} | {m['avg_turns']} |"
            )
        lines.append("")
        # Example conversations.
        lines.append(f"## Example Conversations ({len(example_conversations)} Customers x {len(checkpoints)} Agents)")
        lines.append("")
        customer_letters = "ABCDEFGHIJ"
        for ci, cust in enumerate(example_conversations):
            # Fall back to the numeric index once letters run out.
            letter = customer_letters[ci] if ci < len(customer_letters) else str(ci)
            lines.append(
                f"### Customer {letter} — Persona {cust['persona_id']} "
                f"({cust['true_intent']}, {cust['personality']}, "
                f"SE={cust['social_engineering']})"
            )
            lines.append("")
            for conv in cust["conversations"]:
                status = ""
                if conv["injection_attempted"]:
                    status = " | INJECTION " + ("BLOCKED" if not conv["injection_succeeded"] else "SUCCEEDED")
                lines.append(
                    f"#### {conv['label']} (Reward: {conv['reward']:.1f}, "
                    f"Turns: {conv['turns']}, "
                    f"Intent: {'correct' if conv['intent_correct'] else 'wrong'}"
                    f"{status})"
                )
                lines.append("")
                for msg in conv["messages"]:
                    if isinstance(msg, dict):
                        role = msg.get("role", "unknown").capitalize()
                        content = msg.get("content", "")
                        lines.append(f"**{role}:** {content}")
                lines.append("")
                lines.append("---")
                lines.append("")
            lines.append("")
        with open(report_path, "w") as f:
            f.write("\n".join(lines))
            f.flush()
            os.fsync(f.fileno())