# LLM-1B-Lab — llm_lab/evaluation/scaling.py
# Commit 858e8b2 (Vjeong): "docs: translate all Korean comments and docstrings to English"
"""Scaling Law analyzer."""
from pathlib import Path
from typing import Any, Dict, List, Optional
# matplotlib is optional: plotting degrades gracefully when it is missing.
try:
    import matplotlib

    # Non-interactive backend so figures can be saved on headless machines.
    matplotlib.use("Agg")
    import matplotlib.pyplot as plt

    HAS_MATPLOTLIB = True
except ImportError:
    HAS_MATPLOTLIB = False

# numpy is optional as well (used only by the plotting helpers).
try:
    import numpy as np

    HAS_NUMPY = True
except ImportError:
    HAS_NUMPY = False
class ScalingAnalyzer:
    """Analyzes Scaling Law across 10M → 100M → 1B models.

    Chinchilla Scaling Law (2022):
    - Optimal training: tokens ≈ 20 × number of parameters
    - Loss ∝ N^(-α) × D^(-β) (N=parameters, D=data)
    - α ≈ 0.076, β ≈ 0.095 (per the paper)

    Purpose of this analysis:
    - Verify whether our model follows the Scaling Law
    - Predict the effect of larger models / more data
    - Understand the optimal allocation of compute resources
    """

    def __init__(self, save_dir: str = "./eval_results"):
        """Create the analyzer and ensure the output directory exists.

        Args:
            save_dir: Directory where plots/results are written.
        """
        self.save_dir = Path(save_dir)
        self.save_dir.mkdir(parents=True, exist_ok=True)

    def analyze(
        self,
        model_results: List[Dict[str, Any]],
    ) -> Dict[str, Any]:
        """Comparatively analyzes results across multiple model sizes.

        Args:
            model_results: [
                {"name": "10M", "params": 10e6, "tokens": 1e9, "loss": 4.2, "ppl": 66.7},
                {"name": "100M", "params": 100e6, "tokens": 5e9, "loss": 3.5, "ppl": 33.1},
                {"name": "1B", "params": 1.1e9, "tokens": 10e9, "loss": 3.0, "ppl": 20.1},
            ]

        Returns:
            Analysis dictionary with keys "models", "scaling_efficiency"
            and "chinchilla_ratios"; empty dict when fewer than two
            model results are supplied.
        """
        if len(model_results) < 2:
            print("⚠️ Scaling analysis requires results from at least 2 models.")
            return {}

        print("\n" + "=" * 70)
        print("📈 Scaling Law Analysis")
        print("=" * 70)

        # ── Results table ──
        print(f"\n {'Model':<8} {'Parameters':>12} {'Tokens':>10} {'Loss':>8} {'PPL':>8}")
        print(f" {'─'*52}")
        for r in model_results:
            # Human-readable parameter count: "###M" below 1e9, "#.#B" above.
            params_str = f"{r['params']/1e6:.0f}M" if r["params"] < 1e9 else f"{r['params']/1e9:.1f}B"
            tokens_str = f"{r['tokens']/1e9:.1f}B"
            print(f" {r['name']:<8} {params_str:>12} {tokens_str:>10} {r['loss']:>8.4f} {r['ppl']:>8.2f}")

        # ── Scaling efficiency: pairwise deltas between consecutive sizes ──
        analysis = {"models": model_results, "scaling_efficiency": []}
        for i in range(1, len(model_results)):
            prev = model_results[i-1]
            curr = model_results[i]
            param_ratio = curr["params"] / prev["params"]
            loss_reduction = prev["loss"] - curr["loss"]
            # Relative PPL improvement, as a fraction of the previous model's PPL.
            ppl_reduction = (prev["ppl"] - curr["ppl"]) / prev["ppl"]
            efficiency = {
                "from": prev["name"],
                "to": curr["name"],
                "param_multiplier": round(param_ratio, 1),
                "loss_reduction": round(loss_reduction, 4),
                "ppl_reduction_pct": round(ppl_reduction * 100, 1),
            }
            analysis["scaling_efficiency"].append(efficiency)
            print(f"\n {prev['name']} → {curr['name']}:")
            print(f" Parameters ×{param_ratio:.1f}")
            print(f" Loss reduction: {loss_reduction:.4f}")
            print(f" PPL reduction: {ppl_reduction*100:.1f}%")

        # ── Chinchilla optimality check ──
        print("\n Chinchilla optimality check (tokens ≈ 20 × parameters):")
        for r in model_results:
            actual_ratio = r["tokens"] / r["params"]
            # 15x–25x tokens/param is treated as close enough to the 20x optimum.
            status = "✅ Optimal range" if 15 <= actual_ratio <= 25 else "⚠️ Out of range"
            print(f" {r['name']}: tokens/parameters = {actual_ratio:.1f}x "
                  f"(optimal: 20x) {status}")
        analysis["chinchilla_ratios"] = [
            {"name": r["name"], "ratio": round(r["tokens"] / r["params"], 1)}
            for r in model_results
        ]
        return analysis

    def plot_scaling_curves(
        self,
        model_results: List[Dict[str, Any]],
        save_path: Optional[str] = None,
    ):
        """Visualizes scaling curves: loss and PPL vs parameter count (log-log).

        Args:
            model_results: Same shape as in :meth:`analyze`.
            save_path: Output PNG path; defaults to save_dir/scaling_curves.png.
        """
        # NOTE: only matplotlib is used here. The previous version also
        # required HAS_NUMPY, but numpy is never referenced in this method
        # (and matplotlib itself cannot be imported without numpy), so the
        # gate is reduced to matplotlib alone.
        if not HAS_MATPLOTLIB:
            print("⚠️ matplotlib required: pip install matplotlib")
            return

        fig, axes = plt.subplots(1, 2, figsize=(14, 5))
        params = [r["params"] for r in model_results]
        losses = [r["loss"] for r in model_results]
        ppls = [r["ppl"] for r in model_results]
        names = [r["name"] for r in model_results]

        # ── Loss vs Parameters (log-log) ──
        ax = axes[0]
        ax.loglog(params, losses, "o-", color="#2563eb", linewidth=2, markersize=10)
        for p, loss_val, n in zip(params, losses, names):
            ax.annotate(f" {n}\n Loss={loss_val:.2f}", (p, loss_val), fontsize=9)
        ax.set_xlabel("Parameters", fontsize=12)
        ax.set_ylabel("Validation Loss", fontsize=12)
        ax.set_title("Loss vs Model Size (log-log)", fontsize=13, fontweight="bold")
        ax.grid(True, alpha=0.3)

        # ── PPL vs Parameters (log-log) ──
        ax = axes[1]
        ax.loglog(params, ppls, "s-", color="#dc2626", linewidth=2, markersize=10)
        for p, ppl_val, n in zip(params, ppls, names):
            ax.annotate(f" {n}\n PPL={ppl_val:.1f}", (p, ppl_val), fontsize=9)
        ax.set_xlabel("Parameters", fontsize=12)
        ax.set_ylabel("Perplexity", fontsize=12)
        ax.set_title("Perplexity vs Model Size (log-log)", fontsize=13, fontweight="bold")
        ax.grid(True, alpha=0.3)

        plt.tight_layout()
        save_path = save_path or str(self.save_dir / "scaling_curves.png")
        fig.savefig(save_path, dpi=150, bbox_inches="tight")
        print(f"\n 📊 Scaling curves saved: {save_path}")
        plt.close(fig)