| """Training dynamics analyzer.""" |
|
|
| import math |
| from pathlib import Path |
| from typing import Any, Dict, List, Optional |
|
|
| try: |
| import matplotlib |
| matplotlib.use("Agg") |
| import matplotlib.pyplot as plt |
| HAS_MATPLOTLIB = True |
| except ImportError: |
| HAS_MATPLOTLIB = False |
|
|
|
|
class TrainingDynamicsAnalyzer:
    """Analyzes and visualizes training metrics.

    Analysis items:
    - Loss curve: Convergence patterns, spike detection
    - LR schedule: Warmup + Cosine decay verification
    - Gradient Norm: Training stability, explosion/vanishing detection
    - Throughput: tokens/sec stability, bottleneck detection

    Plots are written as PNG files under ``save_dir``.
    """

    def __init__(self, save_dir: str = "./eval_results"):
        """Create the analyzer, ensuring ``save_dir`` exists for output files.

        Args:
            save_dir: Directory where plot PNGs are saved (created if absent).
        """
        self.save_dir = Path(save_dir)
        self.save_dir.mkdir(parents=True, exist_ok=True)

    def analyze_metrics(self, metrics_history: Dict[str, list]) -> Dict[str, Any]:
        """Analyzes training metrics and prints a human-readable summary.

        Args:
            metrics_history: Trainer.metrics.history dictionary. Recognized
                keys (each optional): "train_loss", "grad_norm",
                "tokens_per_sec", and "step" (per-entry step numbers used
                to label loss spikes).

        Returns:
            Analysis results: a dict with "loss", "grad_norm" and
            "throughput" sections, each present only when the corresponding
            metric list is non-empty.
        """
        # NOTE(review): the non-ASCII glyphs in the prints below look like
        # mojibake (e.g. "π¬" was presumably an emoji); kept byte-identical
        # here -- confirm the intended characters upstream.
        print("\n" + "=" * 70)
        print("π¬ Training Dynamics Analysis")
        print("=" * 70)

        analysis: Dict[str, Any] = {}

        # --- Loss: overall trend plus spike detection ---------------------
        if metrics_history.get("train_loss"):
            losses = metrics_history["train_loss"]
            analysis["loss"] = {
                "initial": round(losses[0], 4),
                "final": round(losses[-1], 4),
                "minimum": round(min(losses), 4),
                "total_reduction": round(losses[0] - losses[-1], 4),
            }

            # A "spike" is a >50% jump over the previous logged loss value.
            step_list = metrics_history.get("step", [])
            spikes = []
            for i in range(1, len(losses)):
                if losses[i] > losses[i - 1] * 1.5:
                    # Fall back to the raw index when the step list is absent
                    # or shorter than the loss list (was an IndexError risk).
                    step = step_list[i] if i < len(step_list) else i
                    spikes.append({"step": step, "loss": round(losses[i], 4)})

            analysis["loss"]["spikes"] = spikes

            print("\n π Loss Analysis:")
            print(f"   Initial: {analysis['loss']['initial']:.4f}")
            print(f"   Final: {analysis['loss']['final']:.4f}")
            print(f"   Minimum: {analysis['loss']['minimum']:.4f}")
            print(f"   Reduction: {analysis['loss']['total_reduction']:.4f}")
            print(f"   Spikes: {len(spikes)}")
            if spikes:
                for s in spikes[:5]:  # cap console noise at 5 entries
                    print(f"      Step {s['step']}: Loss = {s['loss']}")

        # --- Gradient norm: stability / clipping pressure -----------------
        if metrics_history.get("grad_norm"):
            gnorms = metrics_history["grad_norm"]
            analysis["grad_norm"] = {
                "mean": round(sum(gnorms) / len(gnorms), 4),
                "max": round(max(gnorms), 4),
                "min": round(min(gnorms), 4),
                # Share of steps at/above 1.0 -- the clip threshold assumed
                # throughout this file (see the axhline in the plot below).
                "clipped_pct": round(
                    sum(1 for g in gnorms if g >= 1.0) / len(gnorms) * 100, 1
                ),
            }

            print("\n π Gradient Norm Analysis:")
            print(f"   Mean: {analysis['grad_norm']['mean']:.4f}")
            print(f"   Max: {analysis['grad_norm']['max']:.4f}")
            print(f"   Clipping rate: {analysis['grad_norm']['clipped_pct']:.1f}%")
            if analysis["grad_norm"]["clipped_pct"] > 50:
                print("      β οΈ Clipping is frequent β consider lowering LR or extending warmup")

        # --- Throughput: tokens/sec mean and spread -----------------------
        if metrics_history.get("tokens_per_sec"):
            tps = metrics_history["tokens_per_sec"]
            tps_valid = [t for t in tps if t > 0]  # drop warmup/zero samples
            if tps_valid:
                # Hoist the mean out of the variance sum; it was recomputed
                # for every element, making the std pass quadratic.
                mean_tps = sum(tps_valid) / len(tps_valid)
                variance = sum((t - mean_tps) ** 2 for t in tps_valid) / len(tps_valid)
                analysis["throughput"] = {
                    "mean": round(mean_tps),
                    "std": round(variance ** 0.5),
                    "min": round(min(tps_valid)),
                    "max": round(max(tps_valid)),
                }

                print("\n β‘ Throughput Analysis:")
                print(f"   Mean: {analysis['throughput']['mean']:,} tokens/sec")
                print(f"   StdDev: {analysis['throughput']['std']:,}")
                print(f"   Range: [{analysis['throughput']['min']:,}, {analysis['throughput']['max']:,}]")

        return analysis

    def plot_training_curves(
        self,
        metrics_history: Dict[str, list],
        save_path: Optional[str] = None,
    ) -> None:
        """Visualizes training curves as a 4-panel chart.

        Panels: train/val loss, LR schedule, gradient norm, throughput.
        The figure is saved as a PNG (no interactive display; the module
        forces the "Agg" backend at import time).

        Args:
            metrics_history: Trainer.metrics.history dictionary.
            save_path: Output path; defaults to save_dir/training_curves.png.
        """
        if not HAS_MATPLOTLIB:
            print("β οΈ matplotlib required: pip install matplotlib")
            return

        fig, axes = plt.subplots(2, 2, figsize=(16, 10))
        fig.suptitle("Training Dynamics", fontsize=16, fontweight="bold")

        # x-axis: logged steps if present, else simple 0..N-1 indices.
        steps = metrics_history.get("step", list(range(len(metrics_history.get("train_loss", [])))))

        # Panel 1: raw + smoothed training loss, overlaid validation loss.
        ax = axes[0, 0]
        if metrics_history.get("train_loss"):
            ax.plot(steps[:len(metrics_history["train_loss"])],
                    metrics_history["train_loss"],
                    color="#2563eb", alpha=0.6, linewidth=0.8, label="Train Loss")

            # Overlay a moving average once there is enough data for a
            # meaningful window.
            if len(metrics_history["train_loss"]) > 20:
                window = min(50, len(metrics_history["train_loss"]) // 5)
                smoothed = self._moving_average(metrics_history["train_loss"], window)
                # The smoothed series is aligned to start at raw index window-1.
                ax.plot(steps[window-1:len(smoothed)+window-1],
                        smoothed, color="#1d4ed8", linewidth=2, label=f"Smoothed (window={window})")

        if metrics_history.get("val_loss"):
            # Validation is logged less often than training loss; spread its
            # points evenly across the step axis.
            val_steps = [steps[i] for i in range(0, len(steps),
                         max(1, len(steps)//len(metrics_history["val_loss"])))][:len(metrics_history["val_loss"])]
            ax.plot(val_steps, metrics_history["val_loss"],
                    "o-", color="#dc2626", linewidth=2, markersize=5, label="Val Loss")

        ax.set_xlabel("Step")
        ax.set_ylabel("Loss")
        ax.set_title("Training & Validation Loss")
        ax.legend()
        ax.grid(True, alpha=0.3)

        # Panel 2: learning-rate schedule.
        ax = axes[0, 1]
        if metrics_history.get("learning_rate"):
            ax.plot(steps[:len(metrics_history["learning_rate"])],
                    metrics_history["learning_rate"],
                    color="#059669", linewidth=2)
        ax.set_xlabel("Step")
        ax.set_ylabel("Learning Rate")
        ax.set_title("Learning Rate Schedule")
        ax.ticklabel_format(style="scientific", axis="y", scilimits=(0, 0))
        ax.grid(True, alpha=0.3)

        # Panel 3: gradient norm with the assumed clip threshold marked.
        ax = axes[1, 0]
        if metrics_history.get("grad_norm"):
            ax.plot(steps[:len(metrics_history["grad_norm"])],
                    metrics_history["grad_norm"],
                    color="#d97706", alpha=0.6, linewidth=0.8)
            ax.axhline(y=1.0, color="red", linestyle="--", alpha=0.5, label="Clip threshold")
            ax.legend()
        ax.set_xlabel("Step")
        ax.set_ylabel("Gradient Norm")
        ax.set_title("Gradient Norm (clipped at 1.0)")
        ax.grid(True, alpha=0.3)

        # Panel 4: throughput with its running average marked.
        ax = axes[1, 1]
        if metrics_history.get("tokens_per_sec"):
            tps = metrics_history["tokens_per_sec"]
            ax.plot(steps[:len(tps)], tps, color="#7c3aed", alpha=0.6, linewidth=0.8)
            if tps:
                avg_tps = sum(tps) / len(tps)
                ax.axhline(y=avg_tps, color="#7c3aed", linestyle="--", alpha=0.5,
                           label=f"Avg: {avg_tps:,.0f}")
                ax.legend()
        ax.set_xlabel("Step")
        ax.set_ylabel("Tokens/sec")
        ax.set_title("Training Throughput")
        ax.grid(True, alpha=0.3)

        plt.tight_layout()

        save_path = save_path or str(self.save_dir / "training_curves.png")
        fig.savefig(save_path, dpi=150, bbox_inches="tight")
        print(f"\n π Training curves saved: {save_path}")
        plt.close(fig)

    def plot_position_loss(
        self,
        position_losses: List[float],
        save_path: Optional[str] = None,
    ) -> None:
        """Visualizes loss distribution by position.

        Args:
            position_losses: Loss value per sequence position, index = position.
            save_path: Output path; defaults to save_dir/position_loss.png.
        """
        if not HAS_MATPLOTLIB:
            return

        fig, ax = plt.subplots(figsize=(12, 5))

        positions = list(range(len(position_losses)))
        ax.plot(positions, position_losses, color="#2563eb", linewidth=1.5)
        ax.fill_between(positions, position_losses, alpha=0.1, color="#2563eb")

        ax.set_xlabel("Position in Sequence", fontsize=12)
        ax.set_ylabel("Cross-Entropy Loss", fontsize=12)
        ax.set_title("Loss by Position (earlier positions have less context)", fontsize=13, fontweight="bold")
        ax.grid(True, alpha=0.3)

        # Mark early-context vs late-context averages when the sequence is
        # long enough for the comparison to be meaningful.
        if len(position_losses) > 100:
            early_avg = sum(position_losses[:50]) / 50
            tail = position_losses[-200:]
            # Divide by the actual tail length: the previous hard-coded /200
            # understated the average whenever 100 < len < 200.
            late_avg = sum(tail) / len(tail)
            ax.axhline(y=early_avg, color="red", linestyle="--", alpha=0.4,
                       label=f"Early avg (0-50): {early_avg:.2f}")
            ax.axhline(y=late_avg, color="green", linestyle="--", alpha=0.4,
                       label=f"Late avg (-200): {late_avg:.2f}")
            ax.legend()

        plt.tight_layout()

        save_path = save_path or str(self.save_dir / "position_loss.png")
        fig.savefig(save_path, dpi=150, bbox_inches="tight")
        print(f"   π Position loss saved: {save_path}")
        plt.close(fig)

    @staticmethod
    def _moving_average(data: list, window: int) -> list:
        """Compute the simple moving average of ``data`` over ``window``.

        Returns len(data) - window + 1 values; an empty list when the data
        is shorter than the window or the window is non-positive (the old
        version raised ZeroDivisionError for window <= 0). Uses an O(n)
        running sum instead of re-summing each window.
        """
        if window <= 0 or len(data) < window:
            return []
        acc = sum(data[:window])
        result = [acc / window]
        for i in range(window, len(data)):
            acc += data[i] - data[i - window]
            result.append(acc / window)
        return result
|
|