# LLM-1B-Lab — llm_lab/evaluation/dynamics.py
# Author: Vjeong
# Commit a671953: "Fix gradient clipping thresholds in dynamics and checklist modules"
"""Training dynamics analyzer."""
import math
from pathlib import Path
from typing import Any, Dict, List, Optional
try:
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
HAS_MATPLOTLIB = True
except ImportError:
HAS_MATPLOTLIB = False
class TrainingDynamicsAnalyzer:
    """Analyzes and visualizes training metrics.

    Analysis items:
        - Loss curve: convergence patterns, spike detection
        - LR schedule: warmup + cosine decay verification
        - Gradient norm: training stability, explosion/vanishing detection
        - Throughput: tokens/sec stability, bottleneck detection
    """

    def __init__(self, save_dir: str = "./eval_results", clip_threshold: float = 1.0):
        """
        Args:
            save_dir: Directory where plots are written (created if missing).
            clip_threshold: Gradient-norm clipping threshold used to compute
                the clipping rate and to draw the clip line in the plots.
                Defaults to 1.0, matching the previously hard-coded value.
        """
        self.save_dir = Path(save_dir)
        self.save_dir.mkdir(parents=True, exist_ok=True)
        self.clip_threshold = clip_threshold

    def analyze_metrics(self, metrics_history: Dict[str, list]) -> Dict[str, Any]:
        """Analyzes training metrics.

        Args:
            metrics_history: Trainer.metrics.history dictionary. Recognized keys:
                "train_loss", "grad_norm", "tokens_per_sec", and optionally
                "step" (step indices parallel to the metric lists).

        Returns:
            Analysis results with "loss", "grad_norm" and "throughput" entries,
            each present only when the corresponding metric was recorded.
        """
        print("\n" + "=" * 70)
        print("πŸ”¬ Training Dynamics Analysis")
        print("=" * 70)
        analysis: Dict[str, Any] = {}
        if metrics_history.get("train_loss"):
            analysis["loss"] = self._analyze_loss(metrics_history)
        if metrics_history.get("grad_norm"):
            analysis["grad_norm"] = self._analyze_grad_norm(metrics_history["grad_norm"])
        if metrics_history.get("tokens_per_sec"):
            throughput = self._analyze_throughput(metrics_history["tokens_per_sec"])
            if throughput is not None:
                analysis["throughput"] = throughput
        return analysis

    def _analyze_loss(self, metrics_history: Dict[str, list]) -> Dict[str, Any]:
        """Summarize the loss curve and detect spikes (>=50% jump over the previous value)."""
        losses = metrics_history["train_loss"]
        steps = metrics_history.get("step")
        result: Dict[str, Any] = {
            "initial": round(losses[0], 4),
            "final": round(losses[-1], 4),
            "minimum": round(min(losses), 4),
            "total_reduction": round(losses[0] - losses[-1], 4),
        }
        spikes = []
        for i in range(1, len(losses)):
            if losses[i] > losses[i - 1] * 1.5:
                # Fall back to the raw index when no step list was recorded,
                # or when it is shorter than the loss list (previously this
                # indexed without a bounds check and could raise IndexError).
                step = steps[i] if steps and i < len(steps) else i
                spikes.append({"step": step, "loss": round(losses[i], 4)})
        result["spikes"] = spikes
        print(f"\n  πŸ“‰ Loss Analysis:")
        print(f"     Initial: {result['initial']:.4f}")
        print(f"     Final: {result['final']:.4f}")
        print(f"     Minimum: {result['minimum']:.4f}")
        print(f"     Reduction: {result['total_reduction']:.4f}")
        print(f"     Spikes: {len(spikes)}")
        if spikes:
            for s in spikes[:5]:
                print(f"       Step {s['step']}: Loss = {s['loss']}")
        return result

    def _analyze_grad_norm(self, gnorms: List[float]) -> Dict[str, Any]:
        """Summarize gradient norms and the fraction of steps at/above the clip threshold."""
        threshold = self.clip_threshold
        result = {
            "mean": round(sum(gnorms) / len(gnorms), 4),
            "max": round(max(gnorms), 4),
            "min": round(min(gnorms), 4),
            "clipped_pct": round(
                sum(1 for g in gnorms if g >= threshold) / len(gnorms) * 100, 1
            ),
        }
        print(f"\n  πŸ“ Gradient Norm Analysis:")
        print(f"     Mean: {result['mean']:.4f}")
        print(f"     Max: {result['max']:.4f}")
        print(f"     Clipping rate: {result['clipped_pct']:.1f}%")
        if result["clipped_pct"] > 50:
            print(f"     ⚠️ Clipping is frequent β†’ consider lowering LR or extending warmup")
        return result

    @staticmethod
    def _analyze_throughput(tps: List[float]) -> Optional[Dict[str, Any]]:
        """Summarize throughput over valid (> 0) samples; returns None if none are valid."""
        valid = [t for t in tps if t > 0]
        if not valid:
            return None
        mean = sum(valid) / len(valid)
        # Compute the mean once; the old code re-summed the list inside the
        # variance generator, which was O(n^2) over the history length.
        variance = sum((t - mean) ** 2 for t in valid) / len(valid)
        result = {
            "mean": round(mean),
            "std": round(variance ** 0.5),
            "min": round(min(valid)),
            "max": round(max(valid)),
        }
        print(f"\n  ⚑ Throughput Analysis:")
        print(f"     Mean: {result['mean']:,} tokens/sec")
        print(f"     StdDev: {result['std']:,}")
        print(f"     Range: [{result['min']:,}, {result['max']:,}]")
        return result

    def plot_training_curves(
        self,
        metrics_history: Dict[str, list],
        save_path: Optional[str] = None,
    ):
        """Visualizes training curves as a 4-panel chart (loss, LR, grad norm, throughput)."""
        if not HAS_MATPLOTLIB:
            print("⚠️ matplotlib required: pip install matplotlib")
            return
        fig, axes = plt.subplots(2, 2, figsize=(16, 10))
        fig.suptitle("Training Dynamics", fontsize=16, fontweight="bold")
        steps = metrics_history.get("step", list(range(len(metrics_history.get("train_loss", [])))))
        # ── (1) Loss ──
        ax = axes[0, 0]
        if metrics_history.get("train_loss"):
            ax.plot(steps[:len(metrics_history["train_loss"])],
                    metrics_history["train_loss"],
                    color="#2563eb", alpha=0.6, linewidth=0.8, label="Train Loss")
            # Moving average (smoothing); shift x by window-1 because the
            # moving average only starts once a full window is available.
            if len(metrics_history["train_loss"]) > 20:
                window = min(50, len(metrics_history["train_loss"]) // 5)
                smoothed = self._moving_average(metrics_history["train_loss"], window)
                ax.plot(steps[window-1:len(smoothed)+window-1],
                        smoothed, color="#1d4ed8", linewidth=2, label=f"Smoothed (window={window})")
        if metrics_history.get("val_loss"):
            # Val loss is logged less often; spread its points evenly over the steps.
            val_steps = [steps[i] for i in range(0, len(steps),
                         max(1, len(steps)//len(metrics_history["val_loss"])))][:len(metrics_history["val_loss"])]
            ax.plot(val_steps, metrics_history["val_loss"],
                    "o-", color="#dc2626", linewidth=2, markersize=5, label="Val Loss")
        ax.set_xlabel("Step")
        ax.set_ylabel("Loss")
        ax.set_title("Training & Validation Loss")
        ax.legend()
        ax.grid(True, alpha=0.3)
        # ── (2) Learning Rate ──
        ax = axes[0, 1]
        if metrics_history.get("learning_rate"):
            ax.plot(steps[:len(metrics_history["learning_rate"])],
                    metrics_history["learning_rate"],
                    color="#059669", linewidth=2)
        ax.set_xlabel("Step")
        ax.set_ylabel("Learning Rate")
        ax.set_title("Learning Rate Schedule")
        ax.ticklabel_format(style="scientific", axis="y", scilimits=(0, 0))
        ax.grid(True, alpha=0.3)
        # ── (3) Gradient Norm ──
        ax = axes[1, 0]
        if metrics_history.get("grad_norm"):
            ax.plot(steps[:len(metrics_history["grad_norm"])],
                    metrics_history["grad_norm"],
                    color="#d97706", alpha=0.6, linewidth=0.8)
            ax.axhline(y=self.clip_threshold, color="red", linestyle="--", alpha=0.5,
                       label="Clip threshold")
            ax.legend()
        ax.set_xlabel("Step")
        ax.set_ylabel("Gradient Norm")
        ax.set_title(f"Gradient Norm (clipped at {self.clip_threshold})")
        ax.grid(True, alpha=0.3)
        # ── (4) Throughput ──
        ax = axes[1, 1]
        if metrics_history.get("tokens_per_sec"):
            tps = metrics_history["tokens_per_sec"]
            ax.plot(steps[:len(tps)], tps, color="#7c3aed", alpha=0.6, linewidth=0.8)
            if tps:
                # NOTE(review): this average includes zero samples, unlike the
                # analysis above which filters them out — confirm intended.
                avg_tps = sum(tps) / len(tps)
                ax.axhline(y=avg_tps, color="#7c3aed", linestyle="--", alpha=0.5,
                           label=f"Avg: {avg_tps:,.0f}")
                ax.legend()
        ax.set_xlabel("Step")
        ax.set_ylabel("Tokens/sec")
        ax.set_title("Training Throughput")
        ax.grid(True, alpha=0.3)
        plt.tight_layout()
        save_path = save_path or str(self.save_dir / "training_curves.png")
        fig.savefig(save_path, dpi=150, bbox_inches="tight")
        print(f"\n  πŸ“Š Training curves saved: {save_path}")
        plt.close(fig)

    def plot_position_loss(
        self,
        position_losses: List[float],
        save_path: Optional[str] = None,
    ):
        """Visualizes per-position loss across the sequence (earlier positions have less context)."""
        if not HAS_MATPLOTLIB:
            return
        fig, ax = plt.subplots(figsize=(12, 5))
        positions = list(range(len(position_losses)))
        ax.plot(positions, position_losses, color="#2563eb", linewidth=1.5)
        ax.fill_between(positions, position_losses, alpha=0.1, color="#2563eb")
        ax.set_xlabel("Position in Sequence", fontsize=12)
        ax.set_ylabel("Cross-Entropy Loss", fontsize=12)
        ax.set_title("Loss by Position (earlier positions have less context)", fontsize=13, fontweight="bold")
        ax.grid(True, alpha=0.3)
        # Mark key regions
        if len(position_losses) > 100:
            early_avg = sum(position_losses[:50]) / 50
            # Divide by the actual tail length: the old code divided by a
            # hard-coded 200 and understated the average for 100 < len < 200.
            tail = position_losses[-200:]
            late_avg = sum(tail) / len(tail)
            ax.axhline(y=early_avg, color="red", linestyle="--", alpha=0.4,
                       label=f"Early avg (0-50): {early_avg:.2f}")
            ax.axhline(y=late_avg, color="green", linestyle="--", alpha=0.4,
                       label=f"Late avg (-200): {late_avg:.2f}")
            ax.legend()
        plt.tight_layout()
        save_path = save_path or str(self.save_dir / "position_loss.png")
        fig.savefig(save_path, dpi=150, bbox_inches="tight")
        print(f"  πŸ“Š Position loss saved: {save_path}")
        plt.close(fig)

    @staticmethod
    def _moving_average(data: list, window: int) -> list:
        """Compute the simple moving average; output has len(data) - window + 1 entries."""
        result = []
        for i in range(window - 1, len(data)):
            avg = sum(data[i - window + 1 : i + 1]) / window
            result.append(avg)
        return result