| """Scaling Law analyzer.""" |
|
|
| from pathlib import Path |
| from typing import Any, Dict, List, Optional |
|
|
| try: |
| import matplotlib |
| matplotlib.use("Agg") |
| import matplotlib.pyplot as plt |
| HAS_MATPLOTLIB = True |
| except ImportError: |
| HAS_MATPLOTLIB = False |
|
|
| try: |
| import numpy as np |
| HAS_NUMPY = True |
| except ImportError: |
| HAS_NUMPY = False |
|
|
|
|
class ScalingAnalyzer:
    """Analyzes scaling-law behavior across 10M → 100M → 1B models.

    Chinchilla scaling law (Hoffmann et al., 2022):
    - Compute-optimal training: tokens ≈ 20 × number of parameters.
    - Loss ∝ N^(-α) × D^(-β)  (N = parameters, D = training tokens).
    - α ≈ 0.076, β ≈ 0.095 (per the paper).

    Purpose of this analysis:
    - Verify whether our models follow the scaling law.
    - Predict the effect of larger models / more data.
    - Understand the optimal allocation of compute resources.
    """

    def __init__(self, save_dir: str = "./eval_results"):
        """Store the plot output directory, creating it if it does not exist.

        Args:
            save_dir: Directory where scaling-curve plots are saved.
        """
        self.save_dir = Path(save_dir)
        self.save_dir.mkdir(parents=True, exist_ok=True)

    def analyze(
        self,
        model_results: List[Dict[str, Any]],
    ) -> Dict[str, Any]:
        """Comparatively analyzes results across multiple model sizes.

        Args:
            model_results: Per-model result dicts ordered from smallest to
                largest model (adjacent pairs are compared), e.g.::

                    [
                        {"name": "10M",  "params": 10e6,  "tokens": 1e9,  "loss": 4.2, "ppl": 66.7},
                        {"name": "100M", "params": 100e6, "tokens": 5e9,  "loss": 3.5, "ppl": 33.1},
                        {"name": "1B",   "params": 1.1e9, "tokens": 10e9, "loss": 3.0, "ppl": 20.1},
                    ]

        Returns:
            Dict with keys ``"models"``, ``"scaling_efficiency"`` and
            ``"chinchilla_ratios"``; empty dict when fewer than two
            results are supplied.
        """
        if len(model_results) < 2:
            print("⚠️ Scaling analysis requires results from at least 2 models.")
            return {}

        print("\n" + "=" * 70)
        print("📊 Scaling Law Analysis")
        print("=" * 70)

        # Summary table of the raw per-model numbers.
        print(f"\n {'Model':<8} {'Parameters':>12} {'Tokens':>10} {'Loss':>8} {'PPL':>8}")
        print(f" {'─'*52}")
        for r in model_results:
            params_str = f"{r['params']/1e6:.0f}M" if r["params"] < 1e9 else f"{r['params']/1e9:.1f}B"
            tokens_str = f"{r['tokens']/1e9:.1f}B"
            print(f" {r['name']:<8} {params_str:>12} {tokens_str:>10} {r['loss']:>8.4f} {r['ppl']:>8.2f}")

        analysis: Dict[str, Any] = {"models": model_results, "scaling_efficiency": []}

        # Pairwise gains between adjacent model sizes: how much does each
        # step up in parameter count buy in loss / perplexity?
        for prev, curr in zip(model_results, model_results[1:]):
            param_ratio = curr["params"] / prev["params"]
            loss_reduction = prev["loss"] - curr["loss"]
            ppl_reduction = (prev["ppl"] - curr["ppl"]) / prev["ppl"]

            analysis["scaling_efficiency"].append({
                "from": prev["name"],
                "to": curr["name"],
                "param_multiplier": round(param_ratio, 1),
                "loss_reduction": round(loss_reduction, 4),
                "ppl_reduction_pct": round(ppl_reduction * 100, 1),
            })

            print(f"\n {prev['name']} → {curr['name']}:")
            print(f" Parameters ×{param_ratio:.1f}")
            print(f" Loss reduction: {loss_reduction:.4f}")
            print(f" PPL reduction: {ppl_reduction*100:.1f}%")

        # Chinchilla optimality: compute-optimal training uses roughly
        # 20 tokens per parameter; 15x-25x is treated as "close enough".
        print("\n Chinchilla optimality check (tokens ≈ 20 × parameters):")
        for r in model_results:
            actual_ratio = r["tokens"] / r["params"]
            status = "✅ optimal range" if 15 <= actual_ratio <= 25 else "⚠️ out of range"
            print(f" {r['name']}: tokens/parameters = {actual_ratio:.1f}x "
                  f"(optimal: 20x) {status}")

        analysis["chinchilla_ratios"] = [
            {"name": r["name"], "ratio": round(r["tokens"] / r["params"], 1)}
            for r in model_results
        ]

        return analysis

    def plot_scaling_curves(
        self,
        model_results: List[Dict[str, Any]],
        save_path: Optional[str] = None,
    ):
        """Plots loss and perplexity vs. parameter count on log-log axes.

        Args:
            model_results: Same structure as :meth:`analyze`.
            save_path: Output PNG path; defaults to
                ``<save_dir>/scaling_curves.png``.
        """
        # numpy is never used in this method, so only matplotlib is required.
        if not HAS_MATPLOTLIB:
            print("⚠️ matplotlib required: pip install matplotlib")
            return

        fig, axes = plt.subplots(1, 2, figsize=(14, 5))

        params = [r["params"] for r in model_results]
        losses = [r["loss"] for r in model_results]
        ppls = [r["ppl"] for r in model_results]
        names = [r["name"] for r in model_results]

        # Left panel: validation loss vs model size.
        ax = axes[0]
        ax.loglog(params, losses, "o-", color="#2563eb", linewidth=2, markersize=10)
        for p, loss, n in zip(params, losses, names):
            ax.annotate(f" {n}\n Loss={loss:.2f}", (p, loss), fontsize=9)
        ax.set_xlabel("Parameters", fontsize=12)
        ax.set_ylabel("Validation Loss", fontsize=12)
        ax.set_title("Loss vs Model Size (log-log)", fontsize=13, fontweight="bold")
        ax.grid(True, alpha=0.3)

        # Right panel: perplexity vs model size.
        ax = axes[1]
        ax.loglog(params, ppls, "s-", color="#dc2626", linewidth=2, markersize=10)
        for p, pp, n in zip(params, ppls, names):
            ax.annotate(f" {n}\n PPL={pp:.1f}", (p, pp), fontsize=9)
        ax.set_xlabel("Parameters", fontsize=12)
        ax.set_ylabel("Perplexity", fontsize=12)
        ax.set_title("Perplexity vs Model Size (log-log)", fontsize=13, fontweight="bold")
        ax.grid(True, alpha=0.3)

        plt.tight_layout()

        save_path = save_path or str(self.save_dir / "scaling_curves.png")
        fig.savefig(save_path, dpi=150, bbox_inches="tight")
        print(f"\n 📈 Scaling curves saved: {save_path}")
        plt.close(fig)
|
|