"""Training dynamics analyzer."""

import math
from pathlib import Path
from typing import Any, Dict, List, Optional

try:
    import matplotlib

    # Non-interactive backend: figures are only rendered to files.
    matplotlib.use("Agg")
    import matplotlib.pyplot as plt

    HAS_MATPLOTLIB = True
except ImportError:
    HAS_MATPLOTLIB = False


class TrainingDynamicsAnalyzer:
    """Analyzes and visualizes training metrics.

    Analysis items:
        - Loss curve: convergence patterns, spike detection
        - LR schedule: plotted for visual inspection (warmup / decay shape)
        - Gradient norm: training stability, clipping frequency
        - Throughput: tokens/sec stability, bottleneck detection

    Args:
        save_dir: Directory where plots are written; created if missing.
        grad_clip: Gradient-norm clipping threshold. It is used for the
            clipping-rate statistic and for the reference line in the
            gradient-norm plot. Defaults to 1.0.
    """

    # A loss more than this factor above the previous value counts as a spike.
    SPIKE_RATIO = 1.5
    # Print a warning when more than this percentage of steps reach the clip threshold.
    CLIP_WARN_PCT = 50
    # Metric series that are plotted against the shared step axis.
    _STEP_SERIES = ("train_loss", "learning_rate", "grad_norm", "tokens_per_sec")

    def __init__(self, save_dir: str = "./eval_results", grad_clip: float = 1.0):
        self.save_dir = Path(save_dir)
        self.save_dir.mkdir(parents=True, exist_ok=True)
        self.grad_clip = grad_clip

    # ------------------------------------------------------------------ analysis

    def analyze_metrics(self, metrics_history: Dict[str, list]) -> Dict[str, Any]:
        """Analyzes training metrics and prints a summary.

        Args:
            metrics_history: Trainer.metrics.history dictionary. Recognized keys:
                "train_loss", "grad_norm", "tokens_per_sec" and, optionally,
                "step" (step numbers used when reporting loss spikes). Missing or
                empty series are skipped.

        Returns:
            Analysis results. The keys "loss", "grad_norm" and "throughput" are
            present only when the corresponding series could be analyzed.
        """
        print("\n" + "=" * 70)
        print("🔬 Training Dynamics Analysis")
        print("=" * 70)

        analysis: Dict[str, Any] = {}

        losses = metrics_history.get("train_loss")
        if losses:
            analysis["loss"] = self._analyze_loss(losses, metrics_history.get("step") or [])

        gnorms = metrics_history.get("grad_norm")
        if gnorms:
            analysis["grad_norm"] = self._analyze_grad_norm(gnorms)

        tps = metrics_history.get("tokens_per_sec")
        if tps:
            throughput = self._analyze_throughput(tps)
            if throughput is not None:
                analysis["throughput"] = throughput

        return analysis

    def _analyze_loss(self, losses: list, steps: list) -> Dict[str, Any]:
        """Summarize the train-loss curve, detect spikes, and print the result."""
        result: Dict[str, Any] = {
            "initial": round(losses[0], 4),
            "final": round(losses[-1], 4),
            "minimum": round(min(losses), 4),
            "total_reduction": round(losses[0] - losses[-1], 4),
        }

        # Spike: a loss more than 50% (SPIKE_RATIO) above the previous value.
        # Report the logged step number when one exists for that index. Otherwise
        # fall back to the series index, so a short or missing "step" list cannot
        # raise IndexError.
        spikes = [
            {"step": steps[i] if i < len(steps) else i, "loss": round(cur, 4)}
            for i, (prev, cur) in enumerate(zip(losses, losses[1:]), start=1)
            if cur > prev * self.SPIKE_RATIO
        ]
        result["spikes"] = spikes

        print("\n 📉 Loss Analysis:")
        print(f" Initial: {result['initial']:.4f}")
        print(f" Final: {result['final']:.4f}")
        print(f" Minimum: {result['minimum']:.4f}")
        print(f" Reduction: {result['total_reduction']:.4f}")
        print(f" Spikes: {len(spikes)}")
        for spike in spikes[:5]:
            print(f" Step {spike['step']}: Loss = {spike['loss']}")
        return result

    def _analyze_grad_norm(self, gnorms: list) -> Dict[str, Any]:
        """Summarize gradient norms, including how often they reach the clip threshold."""
        count = len(gnorms)
        # NOTE(review): counting g >= grad_clip assumes the trainer logs norms
        # *before* clipping — confirm against the trainer.
        clipped = sum(1 for g in gnorms if g >= self.grad_clip)
        result = {
            "mean": round(sum(gnorms) / count, 4),
            "max": round(max(gnorms), 4),
            "min": round(min(gnorms), 4),
            "clipped_pct": round(clipped / count * 100, 1),
        }

        print("\n 📐 Gradient Norm Analysis:")
        print(f" Mean: {result['mean']:.4f}")
        print(f" Max: {result['max']:.4f}")
        print(f" Clipping rate: {result['clipped_pct']:.1f}%")
        if result["clipped_pct"] > self.CLIP_WARN_PCT:
            print(" ⚠️ Clipping is frequent → consider lowering LR or extending warmup")
        return result

    @staticmethod
    def _analyze_throughput(tps: list) -> Optional[Dict[str, int]]:
        """Summarize tokens/sec over positive samples and print the result.

        Returns None when no sample is positive.
        """
        valid = [t for t in tps if t > 0]
        if not valid:
            return None

        # The mean is computed once and reused for the (population) std deviation.
        mean = sum(valid) / len(valid)
        std = math.sqrt(sum((t - mean) ** 2 for t in valid) / len(valid))
        result = {
            "mean": round(mean),
            "std": round(std),
            "min": round(min(valid)),
            "max": round(max(valid)),
        }

        print("\n ⚡ Throughput Analysis:")
        print(f" Mean: {result['mean']:,} tokens/sec")
        print(f" StdDev: {result['std']:,}")
        print(f" Range: [{result['min']:,}, {result['max']:,}]")
        return result

    # ------------------------------------------------------------------ plotting

    def plot_training_curves(
        self,
        metrics_history: Dict[str, list],
        save_path: Optional[str] = None,
    ):
        """Visualizes training curves as a 4-panel chart.

        The panels are loss (with moving average and validation loss), learning
        rate, gradient norm, and throughput. A series whose length differs from
        the step axis is truncated to the common length instead of raising.

        Args:
            metrics_history: Trainer.metrics.history dictionary.
            save_path: Output image path. Defaults to
                <save_dir>/training_curves.png.
        """
        if not HAS_MATPLOTLIB:
            print("⚠️ matplotlib required: pip install matplotlib")
            return

        steps = self._resolve_steps(metrics_history)
        fig, axes = plt.subplots(2, 2, figsize=(16, 10))
        try:
            fig.suptitle("Training Dynamics", fontsize=16, fontweight="bold")
            self._plot_loss_panel(axes[0, 0], steps, metrics_history)
            self._plot_lr_panel(axes[0, 1], steps, metrics_history)
            self._plot_grad_norm_panel(axes[1, 0], steps, metrics_history)
            self._plot_throughput_panel(axes[1, 1], steps, metrics_history)

            plt.tight_layout()
            save_path = save_path or str(self.save_dir / "training_curves.png")
            fig.savefig(save_path, dpi=150, bbox_inches="tight")
            print(f"\n 📊 Training curves saved: {save_path}")
        finally:
            # Release the figure even if plotting or saving fails.
            plt.close(fig)

    def _plot_loss_panel(self, ax, steps: list, metrics_history: Dict[str, list]) -> None:
        """Panel (1): raw train loss, its moving average, and validation loss."""
        train_loss = metrics_history.get("train_loss")
        if train_loss:
            x, y = self._align(steps, train_loss)
            ax.plot(x, y, color="#2563eb", alpha=0.6, linewidth=0.8, label="Train Loss")
            # Moving average (smoothing), once there are enough points.
            if len(train_loss) > 20:
                window = min(50, len(train_loss) // 5)
                smoothed = self._moving_average(train_loss, window)
                # smoothed[k] averages train_loss[k : k + window]; plot it at the
                # step of the window's last element.
                x, y = self._align(steps[window - 1:], smoothed)
                ax.plot(x, y, color="#1d4ed8", linewidth=2,
                        label=f"Smoothed (window={window})")

        val_loss = metrics_history.get("val_loss")
        if val_loss:
            # NOTE(review): validation points are spread evenly over the train
            # steps; the actual evaluation steps are not recorded in the history.
            if steps:
                stride = max(1, len(steps) // len(val_loss))
                val_steps = steps[::stride]
            else:
                val_steps = list(range(len(val_loss)))
            x, y = self._align(val_steps, val_loss)
            ax.plot(x, y, "o-", color="#dc2626", linewidth=2, markersize=5, label="Val Loss")

        ax.set_xlabel("Step")
        ax.set_ylabel("Loss")
        ax.set_title("Training & Validation Loss")
        # Only draw a legend when something labeled was plotted.
        if ax.get_legend_handles_labels()[0]:
            ax.legend()
        ax.grid(True, alpha=0.3)

    def _plot_lr_panel(self, ax, steps: list, metrics_history: Dict[str, list]) -> None:
        """Panel (2): learning-rate schedule."""
        lrs = metrics_history.get("learning_rate")
        if lrs:
            x, y = self._align(steps, lrs)
            ax.plot(x, y, color="#059669", linewidth=2)
        ax.set_xlabel("Step")
        ax.set_ylabel("Learning Rate")
        ax.set_title("Learning Rate Schedule")
        ax.ticklabel_format(style="scientific", axis="y", scilimits=(0, 0))
        ax.grid(True, alpha=0.3)

    def _plot_grad_norm_panel(self, ax, steps: list, metrics_history: Dict[str, list]) -> None:
        """Panel (3): gradient norm with the clip threshold as a reference line."""
        gnorms = metrics_history.get("grad_norm")
        if gnorms:
            x, y = self._align(steps, gnorms)
            ax.plot(x, y, color="#d97706", alpha=0.6, linewidth=0.8)
            ax.axhline(y=self.grad_clip, color="red", linestyle="--", alpha=0.5,
                       label="Clip threshold")
            ax.legend()
        ax.set_xlabel("Step")
        ax.set_ylabel("Gradient Norm")
        ax.set_title(f"Gradient Norm (clipped at {self.grad_clip})")
        ax.grid(True, alpha=0.3)

    def _plot_throughput_panel(self, ax, steps: list, metrics_history: Dict[str, list]) -> None:
        """Panel (4): tokens/sec with its average as a reference line."""
        tps = metrics_history.get("tokens_per_sec")
        if tps:
            x, y = self._align(steps, tps)
            ax.plot(x, y, color="#7c3aed", alpha=0.6, linewidth=0.8)
            # Average only positive samples, matching analyze_metrics, which
            # excludes non-positive throughput values.
            valid = [t for t in tps if t > 0]
            if valid:
                avg_tps = sum(valid) / len(valid)
                ax.axhline(y=avg_tps, color="#7c3aed", linestyle="--", alpha=0.5,
                           label=f"Avg: {avg_tps:,.0f}")
                ax.legend()
        ax.set_xlabel("Step")
        ax.set_ylabel("Tokens/sec")
        ax.set_title("Training Throughput")
        ax.grid(True, alpha=0.3)

    def plot_position_loss(
        self,
        position_losses: List[float],
        save_path: Optional[str] = None,
    ):
        """Visualizes loss distribution by position.

        Args:
            position_losses: Loss value for each sequence position, in order.
            save_path: Output image path. Defaults to
                <save_dir>/position_loss.png.
        """
        if not HAS_MATPLOTLIB:
            print("⚠️ matplotlib required: pip install matplotlib")
            return

        fig, ax = plt.subplots(figsize=(12, 5))
        try:
            positions = list(range(len(position_losses)))
            ax.plot(positions, position_losses, color="#2563eb", linewidth=1.5)
            ax.fill_between(positions, position_losses, alpha=0.1, color="#2563eb")
            ax.set_xlabel("Position in Sequence", fontsize=12)
            ax.set_ylabel("Cross-Entropy Loss", fontsize=12)
            ax.set_title("Loss by Position (earlier positions have less context)",
                         fontsize=13, fontweight="bold")
            ax.grid(True, alpha=0.3)

            # Mark key regions
            if len(position_losses) > 100:
                early = position_losses[:50]
                late = position_losses[-200:]
                early_avg = sum(early) / len(early)
                # Divide by the actual slice length. With fewer than 200 positions
                # the slice is shorter than 200, and dividing by 200 would
                # understate the mean.
                late_avg = sum(late) / len(late)
                ax.axhline(y=early_avg, color="red", linestyle="--", alpha=0.4,
                           label=f"Early avg (0-50): {early_avg:.2f}")
                ax.axhline(y=late_avg, color="green", linestyle="--", alpha=0.4,
                           label=f"Late avg (-200): {late_avg:.2f}")
                ax.legend()

            plt.tight_layout()
            save_path = save_path or str(self.save_dir / "position_loss.png")
            fig.savefig(save_path, dpi=150, bbox_inches="tight")
            print(f" 📊 Position loss saved: {save_path}")
        finally:
            # Release the figure even if plotting or saving fails.
            plt.close(fig)

    # ------------------------------------------------------------------ helpers

    @classmethod
    def _resolve_steps(cls, metrics_history: Dict[str, list]) -> list:
        """Return the shared x-axis for the plots.

        This is the logged "step" list when it is present and non-empty.
        Otherwise it is 0..N-1, where N is the length of the longest plotted
        series, so every panel has x values even when "train_loss" is missing.
        """
        steps = metrics_history.get("step")
        if steps:
            return list(steps)
        longest = max((len(metrics_history.get(key) or []) for key in cls._STEP_SERIES),
                      default=0)
        return list(range(longest))

    @staticmethod
    def _align(xs: list, ys: list):
        """Truncate xs and ys to their common length so they can be plotted together."""
        n = min(len(xs), len(ys))
        return list(xs[:n]), list(ys[:n])

    @staticmethod
    def _moving_average(data: list, window: int) -> list:
        """Compute the trailing simple moving average of ``data``.

        Element k of the result is the mean of data[k : k + window]. The result
        therefore has len(data) - window + 1 entries, or none if window > len(data).

        Raises:
            ValueError: If window < 1.
        """
        if window < 1:
            raise ValueError(f"window must be >= 1, got {window}")
        return [sum(data[i:i + window]) / window for i in range(len(data) - window + 1)]