# --- HuggingFace Hub page metadata captured with the source (not code) ---
# RikkaBotan's picture
# Update app.py
# 77b572e verified
# raw
# history blame
# 18.7 kB
# save as llm_training_estimator_precisions.py
import gradio as gr
import numpy as np
import matplotlib.pyplot as plt
import math
from datetime import timedelta
# Custom CSS for the Gradio UI: pastel gradient background, rounded
# cards/components, pink pill-shaped buttons with a hover scale effect,
# and pink headings. The comments embedded in the CSS below are part of
# the string literal sent to the browser and are intentionally untouched.
custom_css = """
body {
background: linear-gradient(135deg, #ffe6f2, #e6f7ff);
font-family: "Nunito", "Rounded Mplus 1c", sans-serif;
}
/* カードやコンポーネントの角丸 */
.gradio-container, .gr-block, .gr-box {
border-radius: 20px !important;
}
/* ボタンのデザイン */
button {
background: #ff99cc !important;
color: white !important;
border-radius: 30px !important;
padding: 10px 18px !important;
font-weight: bold !important;
transition: 0.2s;
}
button:hover {
background: #ff66b3 !important;
transform: scale(1.05);
}
/* 見出し */
h1, h2, h3 {
color: #ff6699 !important;
text-shadow: 0px 1px 2px rgba(255, 100, 150, 0.3);
}
"""
# -----------------------------
# Theoretical peak FLOPS (FLOPS/sec) per GPU x precision pair.
# Precision keys are matched case-insensitively by get_peak_flops
# ("fp8", "FP8", "Fp8" all resolve to "fp8").
# None marks a precision the GPU does not support.
# NOTE(review): figures look like vendor dense-tensor-core peaks — confirm
# whether sparsity-accelerated numbers were used for the consumer cards.
GPU_PRECISION_PEAKS = {
    "H200": {"float32": 9.89e14, "float16": 1.979e15, "bfloat16": 1.979e15, "fp8": 3.958e15},
    "H100": {"float32": 9.89e14, "float16": 1.979e15, "bfloat16": 1.979e15, "fp8": 3.958e15},
    "A100": {"float32": 19.5e12, "float16": 312e12, "bfloat16": 312e12, "fp8": None},
    "V100": {"float32": 7.8e12, "float16": 130e12, "bfloat16": None, "fp8": None},
    "RTX 5090": {"float32": 120e12, "float16": 240e12, "bfloat16": 240e12, "fp8": None},
    "RTX 5080": {"float32": 90e12, "float16": 180e12, "bfloat16": 180e12, "fp8": None},
    "RTX 5070": {"float32": 52e12, "float16": 104e12, "bfloat16": 104e12, "fp8": None},
    "RTX 5060": {"float32": 30e12, "float16": 60e12, "bfloat16": 60e12, "fp8": None},
    "RTX 4090": {"float32": 82.6e12, "float16": 165.2e12, "bfloat16": 165.2e12, "fp8": None},
    "RTX 4080": {"float32": 48.7e12, "float16": 97.4e12, "bfloat16": 97.4e12, "fp8": None},
    "RTX 4070": {"float32": 29.2e12, "float16": 58.4e12, "bfloat16": 58.4e12, "fp8": None},
    "RTX 4060": {"float32": 15.1e12, "float16": 30.2e12, "bfloat16": 30.2e12, "fp8": None},
    "RTX 3090": {"float32": 35.58e12, "float16": 71.16e12, "bfloat16": None, "fp8": None},
    "B200": {"float32": 18e15, "float16": 36e15, "bfloat16": 36e15, "fp8": 72e15},
}
# -----------------------------
# Correction factors applied to the effective token count D before it is
# fed to the scaling law (see predict_val_loss_chinchilla).
# NOTE(review): values appear hand-tuned — e.g. AdamW 1.927 presumably means
# AdamW needs ~1.9x the data of the Muon-style optimizer; confirm the source.
OPTIMIZER_FACTORS = {"AdamW": 1.927, "AdamAuxMuon": 1.0}
DATASET_FACTORS = {"FineWeb Edu": 1.0, "OpenWebText": 0.14, "Custom": 1.05}
# -----------------------------
# Chinchilla-style scaling law (loss in BITS, log2)
# Reference: Hoffmann et al. (2022), Table 1 — "OpenWebText-like"
# https://arxiv.org/abs/2203.15556
CHINCHILLA_DEFAULTS = {
    "L0_bits": 1.69*1.4426,  # irreducible loss: 1.69 nats x log2(e) ≈ 1.4426 → bits
    "A_bits": 406.4*1.4426,  # model-size coefficient A, converted nats → bits
    "B_bits": 410.7*1.4426,  # data-size coefficient B, converted nats → bits
    "alpha": 0.34,  # model-size exponent
    "beta": 0.28,  # data-size exponent
    # set False to use nats (multiply all by ln(2)).
    # NOTE(review): this key is not read by any function in this module;
    # the UI uses its own `use_bits` checkbox instead — confirm intent.
    "in_bits": False,
}
def predict_val_loss_chinchilla(
    model_params: float,  # N: raw parameter count (e.g., 7e9)
    total_tokens: float,  # D: raw token count (e.g., 1.4e12)
    dataset: str,
    optimizer: str,
    L0: float = CHINCHILLA_DEFAULTS["L0_bits"],
    A: float = CHINCHILLA_DEFAULTS["A_bits"],
    B: float = CHINCHILLA_DEFAULTS["B_bits"],
    alpha: float = CHINCHILLA_DEFAULTS["alpha"],
    beta: float = CHINCHILLA_DEFAULTS["beta"],
    eps: float = 1e-12,
) -> float:
    """Predict validation loss via the Chinchilla law L = L0 + A/N^alpha + B/D^beta.

    The effective token count D is scaled by a dataset-quality factor and
    divided by an optimizer-efficiency factor before entering the formula.
    Returns +inf for non-positive N or D. Units (bits vs nats) follow the
    units of L0/A/B supplied by the caller.
    """
    if model_params <= 0 or total_tokens <= 0:
        return float("inf")
    n = float(model_params)
    # Effective data: more for high-quality datasets, less for data-hungry optimizers.
    d = float(total_tokens) * float(DATASET_FACTORS[dataset]) / float(OPTIMIZER_FACTORS[optimizer])
    # eps keeps both denominators strictly positive.
    return L0 + A / (n ** alpha + eps) + B / (d ** beta + eps)
def loss_curve_chinchilla(model_params: float, total_tokens: float, dataset: str, optimizer: str, steps: int = 200, **kwargs):
    """Sample the scaling-law curve on a log-spaced token grid.

    Returns (tokens, losses) suitable for plotting; the grid spans
    1e6 tokens up to `total_tokens` (floored at 1e6).
    """
    token_grid = np.logspace(
        np.log10(1e6),
        np.log10(max(total_tokens, 1e6)),
        num=steps,
        base=10.0,
    )
    losses = [
        predict_val_loss_chinchilla(model_params, tok, dataset, optimizer, **kwargs)
        for tok in token_grid
    ]
    return token_grid, losses
# -----------------------------
def pretty_time(seconds: float) -> str:
    """Format a duration in seconds as 'D days, H hrs, M min, S sec'.

    Non-finite durations (inf/NaN) and absurdly large ones (> 1e14 s)
    are reported as not estimable.
    """
    if not np.isfinite(seconds) or seconds > 1e14:
        return "推定不能 / 非現実的"
    delta = timedelta(seconds=float(seconds))
    hrs, leftover = divmod(delta.seconds, 3600)
    mins, secs = divmod(leftover, 60)
    return f"{delta.days} days, {hrs} hrs, {mins} min, {int(secs)} sec"
def estimate_total_flops(params_count: float, total_tokens: float) -> float:
    """Total training compute via the standard 6*N*D heuristic.

    Six FLOPs per parameter per token accounts for the forward and
    backward passes combined.
    """
    flops_per_param_token = 6.0
    return flops_per_param_token * params_count * total_tokens
def get_peak_flops(gpu_name: str, precision: str) -> float:
    """Look up the theoretical peak FLOPS for a GPU at a given precision.

    Precision matching is case-insensitive and hyphen-tolerant
    ("FP8", "Fp8", "fp-8" all resolve to "fp8").

    Returns:
        The peak FLOPS/sec, or None when the GPU is unknown or does not
        support the requested precision.
    """
    gpu_info = GPU_PRECISION_PEAKS.get(gpu_name)
    if not gpu_info:
        return None
    # Normalize: trim whitespace, lowercase, drop hyphens.
    # (The original also had a no-op .replace("fp", "fp") and a redundant
    # `or key == "fp8"` clause; both removed — startswith covers it.)
    key = precision.strip().lower().replace("-", "")
    if key.startswith("fp8"):
        key = "fp8"
    return gpu_info.get(key)
# -----------------------------
def compute_one_model(
    optimizer: str,
    dataset: str,
    gpu_model: str,
    precision: str,
    gpu_count: int,
    mfu: float,
    utilization_overhead: float,
    seq_len: float,
    batch_size: float,
    steps_per_epoch: float,
    epochs: float,
    model_params_millions: float,  # model size, in millions of parameters
    total_tokens_override: float = None,  # raw token count (not billions!)
    scaling_kwargs: dict = None
):
    """Estimate compute, wall-clock time, and predicted loss for one config.

    Returns a result dict on success, or {"error": message} when the
    GPU/precision pair is unsupported or effective throughput is invalid.
    """
    law_kwargs = scaling_kwargs or {}
    n_params = float(model_params_millions) * 1e6  # millions -> raw count

    # Token budget: the explicit override wins; otherwise derive it from
    # the training schedule. steps_per_epoch is treated as a *global*
    # step count, the convention in most frameworks (not per-GPU).
    if total_tokens_override is not None and total_tokens_override > 0:
        n_tokens = float(total_tokens_override)
    else:
        n_tokens = (
            float(seq_len) * float(batch_size) * float(gpu_count)
            * float(steps_per_epoch) * float(epochs)
        )

    peak = get_peak_flops(gpu_model, precision)
    if peak is None:
        return {
            "error": f"⚠️ GPU '{gpu_model}' does not support precision '{precision}' (normalized to key '{precision.lower()}')."
        }

    required_flops = estimate_total_flops(n_params, n_tokens)
    # Delivered throughput = hardware peak scaled by MFU and system overhead.
    delivered_flops = float(gpu_count) * peak * float(mfu) * float(utilization_overhead)
    if delivered_flops <= 0 or not np.isfinite(delivered_flops):
        return {"error": "⚠️ Effective FLOPS ≤ 0 — check MFU/utilization."}
    wall_seconds = required_flops / delivered_flops

    loss = predict_val_loss_chinchilla(
        model_params=n_params,
        total_tokens=n_tokens,
        dataset=dataset,
        optimizer=optimizer,
        **law_kwargs
    )
    return {
        "total_tokens": n_tokens,
        "params_count": n_params,
        "total_flops": required_flops,
        "seconds": wall_seconds,
        "time_str": pretty_time(wall_seconds),
        "predicted_val_loss": loss,
    }
# -----------------------------
def compute_precise_estimate(
    input_mode,
    total_tokens_input_B,  # user-facing value, in billions (1e9)
    optimizer_a, dataset_a,
    gpu_model_a, precision_a, gpu_count_a, mfu_a, utilization_overhead_a,
    seq_len_a, batch_size_a, steps_per_epoch_a, epochs_a, model_params_a,
    do_compare,
    optimizer_b, dataset_b,
    gpu_model_b, precision_b, gpu_count_b, mfu_b, utilization_overhead_b,
    seq_len_b, batch_size_b, steps_per_epoch_b, epochs_b, model_params_b,
    L0_val, A_val, B_val, alpha_val, beta_val, use_bits
):
    """Gradio callback: estimate FLOPs, training time, and predicted loss.

    Runs compute_one_model for Model A (and Model B when do_compare is
    set) and renders a Chinchilla loss-vs-tokens plot.

    Returns:
        (flops_text, time_text, loss_text, matplotlib_figure); on a
        Model A error, (error_message, "", "", None).
    """
    # Convert the user-facing "billions of tokens" input to a raw count.
    total_tokens_override = None
    if input_mode == "By total tokens":
        try:
            total_tokens_override = float(total_tokens_input_B) * 1e9
            if total_tokens_override < 1e6:
                total_tokens_override = 1e6  # min sanity
        except Exception:
            pass  # fall back to schedule-derived tokens
    # Scaling law config (UI values are entered in bits).
    scaling_kwargs = {
        "L0": float(L0_val),
        "A": float(A_val),
        "B": float(B_val),
        "alpha": float(alpha_val),
        "beta": float(beta_val),
    }
    if not use_bits:  # convert from bits → nats
        ln2 = math.log(2)
        scaling_kwargs["L0"] *= ln2
        scaling_kwargs["A"] *= ln2
        scaling_kwargs["B"] *= ln2
    a_res = compute_one_model(
        optimizer_a, dataset_a, gpu_model_a, precision_a, gpu_count_a, mfu_a, utilization_overhead_a,
        seq_len_a, batch_size_a, steps_per_epoch_a, epochs_a, model_params_a,
        total_tokens_override=total_tokens_override,
        scaling_kwargs=scaling_kwargs
    )
    if "error" in a_res:
        return a_res["error"], "", "", None
    b_res = None
    if do_compare:
        b_res = compute_one_model(
            optimizer_b, dataset_b, gpu_model_b, precision_b, gpu_count_b, mfu_b, utilization_overhead_b,
            seq_len_b, batch_size_b, steps_per_epoch_b, epochs_b, model_params_b,
            total_tokens_override=total_tokens_override,
            scaling_kwargs=scaling_kwargs
        )
    # Text summaries (one line per model).
    flops_lines = [f"Model A: {a_res['total_flops']:.3e} FLOPs (params={a_res['params_count']:.1e}, tokens={a_res['total_tokens']:.1e})"]
    time_lines = [f"Model A: {a_res['time_str']}"]
    loss_lines = [f"Model A: {a_res['predicted_val_loss']:.4f} loss"]
    if do_compare:
        if "error" in (b_res or {}):
            err = b_res.get("error", "Unknown error")
            flops_lines.append(f"Model B: error — {err}")
            time_lines.append(f"Model B: error — {err}")
            loss_lines.append(f"Model B: error — {err}")
        else:
            flops_lines.append(f"Model B: {b_res['total_flops']:.3e} FLOPs (params={b_res['params_count']:.1e}, tokens={b_res['total_tokens']:.1e})")
            time_lines.append(f"Model B: {b_res['time_str']}")
            loss_lines.append(f"Model B: {b_res['predicted_val_loss']:.4f} loss")
    # Loss-vs-tokens plot.
    fig, ax = plt.subplots(figsize=(7, 5))
    xs_a, ys_a = loss_curve_chinchilla(
        model_params=a_res["params_count"],
        total_tokens=max(a_res["total_tokens"], 1e6),
        dataset=dataset_a,
        optimizer=optimizer_a,
        **scaling_kwargs
    )
    ax.plot(xs_a, ys_a, label=f"Model A ({model_params_a:.0f}M)", linewidth=2)
    if do_compare and b_res and "params_count" in b_res:
        xs_b, ys_b = loss_curve_chinchilla(
            model_params=b_res["params_count"],
            total_tokens=max(b_res["total_tokens"], 1e6),
            dataset=dataset_b,
            optimizer=optimizer_b,
            **scaling_kwargs
        )
        ax.plot(xs_b, ys_b, label=f"Model B ({model_params_b:.0f}M)", linestyle='--', linewidth=2)
    # Add asymptotic (irreducible) loss line.
    L0_plot = scaling_kwargs["L0"]
    ax.axhline(L0_plot, color='gray', linestyle=':', linewidth=1, label=f"Asymptotic loss $L_0$ = {L0_plot:.3f}")
    ax.set_xscale('log')
    ax.set_xlabel("Tokens seen (log scale)")
    ax.set_ylabel("Predicted validation loss (bits)" if use_bits else "Predicted validation loss (nats)")
    ax.set_title("Chinchilla scaling law: Loss vs Tokens")
    ax.grid(True, linestyle='-.', alpha=0.5)
    ax.legend()
    ax.set_ylim(bottom=max(0.0, L0_plot - 0.5), top=None)
    # FIX: operate on this figure/axes explicitly instead of pyplot's
    # implicit "current figure/axes" global state (plt.tight_layout /
    # plt.tick_params) — safer when several requests render concurrently.
    fig.tight_layout()
    ax.tick_params(axis='both', which='both',
                   direction='in',    # ticks point inward
                   labelbottom=True,  # show x tick labels
                   labelleft=True,    # show y tick labels
                   bottom=True, top=False,
                   left=True, right=False)
    fig.subplots_adjust(top=0.85)
    return "\n".join(flops_lines), "\n".join(time_lines), "\n".join(loss_lines), fig
# -----------------------------
# Dropdown choices derived from the capability table so the UI stays in
# sync with GPU_PRECISION_PEAKS.
gpu_choices = list(GPU_PRECISION_PEAKS.keys())
precision_choices = ["float32", "float16", "bfloat16", "FP8"]

# FIX: `theme` and `css` are gr.Blocks(...) constructor arguments;
# Blocks.launch() does not accept them, so the original
# `demo.launch(theme=..., css=custom_css)` raised a TypeError and the
# custom CSS was never applied. They are passed to the constructor here.
with gr.Blocks(
    title="LLM Training Estimator — Chinchilla Scaling Law",
    theme=gr.themes.Soft(),
    css=custom_css,
) as demo:
    gr.Markdown(r"""
# 🚀 LLM Training Estimator (Precision-aware)
- **GPU + Precision + Peak FLOPS + Effective throughput** → Training time & Loss
- **Validation loss prediction** via **Chinchilla scaling law** (Hoffmann et al., 2022)
$$ L(N, D) = L_\infty + \frac{A}{N^\alpha} + \frac{B}{D^\beta} $$
- Compare two configurations side-by-side
- Supports **bits (log₂)** or **nats (ln)** loss units
- Based on empirical fits for English web text (FineWeb/OpenWebText)
> ✅ More accurate than GPT-3-style power-law (arXiv:2001.08361)
## 🌸 MFU & System Utilization Examples (8×H100)
|Parameters |MFU | System Utilization |
|------------ |------------- |------------------- |
|182M | ~39% | ~0.5 |
|560M | ~47% | ~0.5 |
|1B~ | ~50% | ~0.5 |
""")
    with gr.Row():
        # Left column: all inputs.
        with gr.Column(scale=1):
            input_mode = gr.Radio(
                ["By total tokens", "By steps (derived from seq/batch/steps/epochs)"],
                value="By total tokens",
                label="Token input mode"
            )
            total_tokens_input_B = gr.Number(value=100.0, label="Total tokens (B) — used if 'By total tokens' selected")
            with gr.Accordion("Model A Configuration", open=True):
                model_params_a = gr.Number(value=7000.0, label="Model params (millions)")
                optimizer_a = gr.Dropdown(list(OPTIMIZER_FACTORS.keys()), label="Optimizer", value="AdamW")
                dataset_a = gr.Dropdown(list(DATASET_FACTORS.keys()), label="Dataset", value="OpenWebText")
                gpu_model_a = gr.Dropdown(gpu_choices, label="GPU model", value="H100")
                precision_a = gr.Dropdown(precision_choices, label="Precision", value="bfloat16")
                gpu_count_a = gr.Slider(1, 1024, value=8, step=1, label="GPU count")
                mfu_a = gr.Slider(0.01, 1.0, value=0.35, step=0.01, label="MFU (Model FLOPs Utilization)")
                utilization_overhead_a = gr.Slider(0.05, 1.0, value=0.5, step=0.01, label="System utilization (incl. comms, IO)")
                seq_len_a = gr.Number(value=2048, label="Sequence length (tokens)")
                batch_size_a = gr.Number(value=256, label="Global batch size*")  # global, not per-GPU
                steps_per_epoch_a = gr.Number(value=1000, label="Steps per epoch")
                epochs_a = gr.Number(value=1, label="Epochs")
            do_compare = gr.Checkbox(label="Compare with Model B", value=False)
            with gr.Accordion("Model B Configuration", open=False):
                model_params_b = gr.Number(value=70000.0, label="Model params (millions)")
                optimizer_b = gr.Dropdown(list(OPTIMIZER_FACTORS.keys()), label="Optimizer", value="AdamW")
                dataset_b = gr.Dropdown(list(DATASET_FACTORS.keys()), label="Dataset", value="OpenWebText")
                gpu_model_b = gr.Dropdown(gpu_choices, label="GPU model", value="A100")
                precision_b = gr.Dropdown(precision_choices, label="Precision", value="float16")
                gpu_count_b = gr.Slider(1, 1024, value=64, step=1, label="GPU count")
                mfu_b = gr.Slider(0.01, 1.0, value=0.25, step=0.01, label="MFU")
                utilization_overhead_b = gr.Slider(0.05, 1.0, value=0.4, step=0.01, label="System utilization")
                seq_len_b = gr.Number(value=2048, label="Sequence length")
                batch_size_b = gr.Number(value=2048, label="Global batch size*")
                steps_per_epoch_b = gr.Number(value=1000, label="Steps per epoch")
                epochs_b = gr.Number(value=1, label="Epochs")
            with gr.Accordion("Scaling Law Parameters (Chinchilla)", open=False):
                use_bits = gr.Checkbox(value=False, label="Loss in bits(log2)")
                L0_val = gr.Number(value=CHINCHILLA_DEFAULTS["L0_bits"], label="L∞ (irreducible loss)")
                A_val = gr.Number(value=CHINCHILLA_DEFAULTS["A_bits"], label="A (model-size coefficient)")
                B_val = gr.Number(value=CHINCHILLA_DEFAULTS["B_bits"], label="B (data-size coefficient)")
                alpha_val = gr.Number(value=CHINCHILLA_DEFAULTS["alpha"], label="α (model exponent)")
                beta_val = gr.Number(value=CHINCHILLA_DEFAULTS["beta"], label="β (data exponent)")
                gr.Markdown(r"""
# 💡 Tip:
> | | $$ L_\infty $$ | $$ A $$ | $$ B $$ |
> |----------------- |------------- |------------- |--------------- |
> |code/math data |increase |- |- |
> |high-quality data |- |reduce |reduce |
> Default: English web text (FineWeb-like).
""")
        # Right column: outputs.
        with gr.Column(scale=1):
            flops_out = gr.Textbox(label="Total Compute (FLOPs)", lines=3)
            time_out = gr.Textbox(label="Estimated Training Time", lines=2)
            loss_out = gr.Textbox(label="Predicted Validation Loss", lines=2)
            plot_out = gr.Plot(label="Loss vs Tokens Curve")
            run_btn = gr.Button("🚀 Estimate / Compare", variant="primary")
    # Wire the button; input order must match compute_precise_estimate's signature.
    run_btn.click(
        fn=compute_precise_estimate,
        inputs=[
            input_mode, total_tokens_input_B,
            optimizer_a, dataset_a, gpu_model_a, precision_a, gpu_count_a, mfu_a, utilization_overhead_a,
            seq_len_a, batch_size_a, steps_per_epoch_a, epochs_a, model_params_a,
            do_compare,
            optimizer_b, dataset_b, gpu_model_b, precision_b, gpu_count_b, mfu_b, utilization_overhead_b,
            seq_len_b, batch_size_b, steps_per_epoch_b, epochs_b, model_params_b,
            L0_val, A_val, B_val, alpha_val, beta_val, use_bits
        ],
        outputs=[flops_out, time_out, loss_out, plot_out]
    )
    gr.Markdown("""
---
📚 References:
- [Chinchilla](https://arxiv.org/abs/2203.15556): *Training Compute-Optimal Large Language Models*
- [GPT-3](https://arxiv.org/abs/2005.14165) (note: arXiv:2001.08361 is earlier version)
- [MFU definition](https://arxiv.org/abs/2205.14451)
- [Muon](https://arxiv.org/abs/2502.16982)
""")

if __name__ == "__main__":
    demo.launch()