import gradio as gr
import numpy as np
import matplotlib.pyplot as plt
import math
from datetime import timedelta
from typing import Optional
|
|
custom_css = """
body {
    background: linear-gradient(135deg, #ffe6f2, #e6f7ff);
    font-family: "Nunito", "Rounded Mplus 1c", sans-serif;
}

/* Rounded corners for cards and components */
.gradio-container, .gr-block, .gr-box {
    border-radius: 20px !important;
}

/* Button styling */
button {
    background: #ff99cc !important;
    color: white !important;
    border-radius: 30px !important;
    padding: 10px 18px !important;
    font-weight: bold !important;
    transition: 0.2s;
}

button:hover {
    background: #ff66b3 !important;
    transform: scale(1.05);
}

/* Headings */
h1, h2, h3 {
    color: #ff6699 !important;
    text-shadow: 0px 1px 2px rgba(255, 100, 150, 0.3);
}
"""
|
|
# Peak per-GPU throughput in FLOP/s by precision; None marks unsupported formats.
# These follow vendor datasheet tensor-core peaks and should be read as upper
# bounds: real training throughput is peak * MFU * system utilization (below).
GPU_PRECISION_PEAKS = {
    "H200": {"float32": 9.89e14, "float16": 1.979e15, "bfloat16": 1.979e15, "fp8": 3.958e15},
    "H100": {"float32": 9.89e14, "float16": 1.979e15, "bfloat16": 1.979e15, "fp8": 3.958e15},
    "A100": {"float32": 19.5e12, "float16": 312e12, "bfloat16": 312e12, "fp8": None},
    "V100": {"float32": 7.8e12, "float16": 130e12, "bfloat16": None, "fp8": None},
    "RTX 5090": {"float32": 120e12, "float16": 240e12, "bfloat16": 240e12, "fp8": None},
    "RTX 5080": {"float32": 90e12, "float16": 180e12, "bfloat16": 180e12, "fp8": None},
    "RTX 5070": {"float32": 52e12, "float16": 104e12, "bfloat16": 104e12, "fp8": None},
    "RTX 5060": {"float32": 30e12, "float16": 60e12, "bfloat16": 60e12, "fp8": None},
    "RTX 4090": {"float32": 82.6e12, "float16": 165.2e12, "bfloat16": 165.2e12, "fp8": None},
    "RTX 4080": {"float32": 48.7e12, "float16": 97.4e12, "bfloat16": 97.4e12, "fp8": None},
    "RTX 4070": {"float32": 29.2e12, "float16": 58.4e12, "bfloat16": 58.4e12, "fp8": None},
    "RTX 4060": {"float32": 15.1e12, "float16": 30.2e12, "bfloat16": 30.2e12, "fp8": None},
    "RTX 3090": {"float32": 35.58e12, "float16": 71.16e12, "bfloat16": None, "fp8": None},
    "B200": {"float32": 18e15, "float16": 36e15, "bfloat16": 36e15, "fp8": 72e15},
}
|
|
# Effective-token multipliers (empirical fudge factors used by this app):
# the optimizer factor divides the token count, the dataset factor scales it.
OPTIMIZER_FACTORS = {"AdamW": 1.927, "AdamAuxMuon": 1.0}
DATASET_FACTORS = {"FineWeb Edu": 1.0, "OpenWebText": 0.14, "Custom": 1.05}
|
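# Example (hedged reading of these factors): with AdamW (factor 1.927), 140B raw
# tokens count as 140e9 / 1.927 ≈ 72.6e9 effective tokens in the scaling law,
# while OpenWebText (factor 0.14) shrinks the effective data size further.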
|
# Hoffmann et al. (2022) fit, converted from nats to bits (multiply by
# log2(e) ≈ 1.4426): L0 = 1.69 nats, A = 406.4, B = 410.7.
CHINCHILLA_DEFAULTS = {
    "L0_bits": 1.69 * 1.4426,
    "A_bits": 406.4 * 1.4426,
    "B_bits": 410.7 * 1.4426,
    "alpha": 0.34,
    "beta": 0.28,
}
|
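# Unit check (hedged): the Chinchilla irreducible loss is 1.69 nats; multiplying
# by log2(e) gives 1.69 * 1.4426 ≈ 2.44 bits, which is the L0_bits default.
# The UI converts back to nats by multiplying by ln(2) when bits are disabled.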
|
def predict_val_loss_chinchilla(
    model_params: float,
    total_tokens: float,
    dataset: str,
    optimizer: str,
    L0: float = CHINCHILLA_DEFAULTS["L0_bits"],
    A: float = CHINCHILLA_DEFAULTS["A_bits"],
    B: float = CHINCHILLA_DEFAULTS["B_bits"],
    alpha: float = CHINCHILLA_DEFAULTS["alpha"],
    beta: float = CHINCHILLA_DEFAULTS["beta"],
    eps: float = 1e-12,
) -> float:
    """Predict validation loss via the Chinchilla scaling law (in bits by default).

    L(N, D) = L0 + A / N**alpha + B / D**beta, where D is the token count
    rescaled by the dataset and optimizer factors above.
    """
    if model_params <= 0 or total_tokens <= 0:
        return float("inf")
    N = float(model_params)
    D = float(total_tokens) * float(DATASET_FACTORS[dataset]) / float(OPTIMIZER_FACTORS[optimizer])
    term_N = A / (N ** alpha + eps)
    term_D = B / (D ** beta + eps)
    return L0 + term_N + term_D
|
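# Worked example (hedged; numbers follow from the default coefficients above):
# a 7B-parameter model on 140B FineWeb Edu tokens with AdamW gives
#   D = 140e9 / 1.927 ≈ 72.6e9 effective tokens, and
#   L ≈ 2.438 + 586.3 / 7e9**0.34 + 592.5 / 72.6e9**0.28 ≈ 3.24 bits, i.e.
#   predict_val_loss_chinchilla(7e9, 140e9, "FineWeb Edu", "AdamW")  # ≈ 3.24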
|
def loss_curve_chinchilla(model_params: float, total_tokens: float, dataset: str, optimizer: str, steps: int = 200, **kwargs):
    """Return (tokens, losses) for plotting."""
    xs = np.logspace(np.log10(1e6), np.log10(max(total_tokens, 1e6)), num=steps, base=10.0)
    ys = [predict_val_loss_chinchilla(model_params, t, dataset, optimizer, **kwargs) for t in xs]
    return xs, ys
|
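# Usage sketch (hedged): sample log-spaced token counts from 1e6 up to the
# training budget and evaluate the law at each point, e.g.
#   xs, ys = loss_curve_chinchilla(7e9, 140e9, "FineWeb Edu", "AdamW")
#   # xs spans [1e6, 140e9]; ys decreases monotonically toward L0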
|
def pretty_time(seconds: float) -> str:
    if not np.isfinite(seconds) or seconds > 1e14:
        return "Cannot estimate / unrealistic"
    td = timedelta(seconds=float(seconds))
    days = td.days
    hours, remainder = divmod(td.seconds, 3600)
    minutes, secs = divmod(remainder, 60)
    return f"{days} days, {hours} hrs, {minutes} min, {int(secs)} sec"
|
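# Example: 90061 seconds is one day plus 1:01:01, so
#   pretty_time(90061)  ->  "1 days, 1 hrs, 1 min, 1 sec"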
|
def estimate_total_flops(params_count: float, total_tokens: float) -> float:
    """Standard: 6 FLOPs per param per token (forward + backward)."""
    return 6.0 * params_count * total_tokens
|
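# Back-of-envelope example: a 7B-parameter model over 140B tokens costs
#   estimate_total_flops(7e9, 140e9) == 6 * 7e9 * 140e9 ≈ 5.9e21 FLOPs.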
|
def get_peak_flops(gpu_name: str, precision: str) -> Optional[float]:
    gpu_info = GPU_PRECISION_PEAKS.get(gpu_name)
    if not gpu_info:
        return None
    # Normalize the UI label to a dict key: "FP8" / "fp-8" -> "fp8"; the other
    # choices ("float32", "float16", "bfloat16") are already lowercase keys.
    key = precision.strip().lower().replace("-", "")
    return gpu_info.get(key)
|
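# Usage sketch (values are the datasheet peaks assumed in GPU_PRECISION_PEAKS):
#   get_peak_flops("H100", "bfloat16")  ->  1.979e15
#   get_peak_flops("A100", "FP8")       ->  None  (no FP8 tensor cores)
# Effective training throughput is then peak * gpu_count * MFU * utilization,
# e.g. 8 * 1.979e15 * 0.35 * 0.5 ≈ 2.8e15 FLOP/s, so the 5.9e21 FLOPs above
# take about 2.1e6 s ≈ 25 days.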
|
def compute_one_model(
    optimizer: str,
    dataset: str,
    gpu_model: str,
    precision: str,
    gpu_count: int,
    mfu: float,
    utilization_overhead: float,
    seq_len: float,
    batch_size: float,
    steps_per_epoch: float,
    epochs: float,
    model_params_millions: float,
    total_tokens_override: Optional[float] = None,
    scaling_kwargs: Optional[dict] = None,
):
    scaling_kwargs = scaling_kwargs or {}
    model_params = float(model_params_millions) * 1e6

    if total_tokens_override is not None and total_tokens_override > 0:
        total_tokens = float(total_tokens_override)
    else:
        # Derive tokens from the schedule: tokens/step = seq_len * per-GPU batch * GPUs.
        total_tokens = float(seq_len) * float(batch_size) * float(gpu_count) * float(steps_per_epoch) * float(epochs)

    peak = get_peak_flops(gpu_model, precision)
    if peak is None:
        return {"error": f"⚠️ GPU '{gpu_model}' does not support precision '{precision}'."}

    total_flops = estimate_total_flops(model_params, total_tokens)
    effective_flops = float(gpu_count) * peak * float(mfu) * float(utilization_overhead)

    if effective_flops <= 0 or not np.isfinite(effective_flops):
        return {"error": "⚠️ Effective FLOPS ≤ 0; check MFU/utilization."}

    seconds = total_flops / effective_flops

    predicted_val_loss = predict_val_loss_chinchilla(
        model_params=model_params,
        total_tokens=total_tokens,
        dataset=dataset,
        optimizer=optimizer,
        **scaling_kwargs,
    )

    return {
        "total_tokens": total_tokens,
        "params_count": model_params,
        "total_flops": total_flops,
        "seconds": seconds,
        "time_str": pretty_time(seconds),
        "predicted_val_loss": predicted_val_loss,
    }
|
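# Usage sketch (hedged; mirrors the Model A defaults in the UI below):
#   compute_one_model("AdamW", "OpenWebText", "H100", "bfloat16", 8, 0.35, 0.5,
#                     2048, 256, 1000, 1, 7000.0, total_tokens_override=100e9)
# returns a dict with total FLOPs (6 * 7e9 * 100e9 = 4.2e21), a wall-clock
# estimate via pretty_time, and the Chinchilla loss prediction.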
|
def compute_precise_estimate(
    input_mode,
    total_tokens_input_B,
    optimizer_a, dataset_a,
    gpu_model_a, precision_a, gpu_count_a, mfu_a, utilization_overhead_a,
    seq_len_a, batch_size_a, steps_per_epoch_a, epochs_a, model_params_a,
    do_compare,
    optimizer_b, dataset_b,
    gpu_model_b, precision_b, gpu_count_b, mfu_b, utilization_overhead_b,
    seq_len_b, batch_size_b, steps_per_epoch_b, epochs_b, model_params_b,
    L0_val, A_val, B_val, alpha_val, beta_val, use_bits,
):
    total_tokens_override = None
    if input_mode == "By total tokens":
        try:
            total_tokens_override = float(total_tokens_input_B) * 1e9
            if total_tokens_override < 1e6:
                total_tokens_override = 1e6
        except Exception:
            pass  # fall back to the steps-based token count

    scaling_kwargs = {
        "L0": float(L0_val),
        "A": float(A_val),
        "B": float(B_val),
        "alpha": float(alpha_val),
        "beta": float(beta_val),
    }
    if not use_bits:
        # The UI coefficients default to bits; scale by ln(2) to report nats.
        ln2 = math.log(2)
        scaling_kwargs["L0"] *= ln2
        scaling_kwargs["A"] *= ln2
        scaling_kwargs["B"] *= ln2

    a_res = compute_one_model(
        optimizer_a, dataset_a, gpu_model_a, precision_a, gpu_count_a, mfu_a, utilization_overhead_a,
        seq_len_a, batch_size_a, steps_per_epoch_a, epochs_a, model_params_a,
        total_tokens_override=total_tokens_override,
        scaling_kwargs=scaling_kwargs,
    )
    if "error" in a_res:
        return a_res["error"], "", "", None

    b_res = None
    if do_compare:
        b_res = compute_one_model(
            optimizer_b, dataset_b, gpu_model_b, precision_b, gpu_count_b, mfu_b, utilization_overhead_b,
            seq_len_b, batch_size_b, steps_per_epoch_b, epochs_b, model_params_b,
            total_tokens_override=total_tokens_override,
            scaling_kwargs=scaling_kwargs,
        )

    flops_lines = [f"Model A: {a_res['total_flops']:.3e} FLOPs (params={a_res['params_count']:.1e}, tokens={a_res['total_tokens']:.1e})"]
    time_lines = [f"Model A: {a_res['time_str']}"]
    loss_lines = [f"Model A: {a_res['predicted_val_loss']:.4f} loss"]

    if do_compare:
        if "error" in (b_res or {}):
            err = b_res.get("error", "Unknown error")
            flops_lines.append(f"Model B: error: {err}")
            time_lines.append(f"Model B: error: {err}")
            loss_lines.append(f"Model B: error: {err}")
        else:
            flops_lines.append(f"Model B: {b_res['total_flops']:.3e} FLOPs (params={b_res['params_count']:.1e}, tokens={b_res['total_tokens']:.1e})")
            time_lines.append(f"Model B: {b_res['time_str']}")
            loss_lines.append(f"Model B: {b_res['predicted_val_loss']:.4f} loss")

    fig, ax = plt.subplots(figsize=(7, 5))
    xs_a, ys_a = loss_curve_chinchilla(
        model_params=a_res["params_count"],
        total_tokens=max(a_res["total_tokens"], 1e6),
        dataset=dataset_a,
        optimizer=optimizer_a,
        **scaling_kwargs,
    )
    ax.plot(xs_a, ys_a, label=f"Model A ({model_params_a:.0f}M)", linewidth=2)

    if do_compare and b_res and "params_count" in b_res:
        xs_b, ys_b = loss_curve_chinchilla(
            model_params=b_res["params_count"],
            total_tokens=max(b_res["total_tokens"], 1e6),
            dataset=dataset_b,
            optimizer=optimizer_b,
            **scaling_kwargs,
        )
        ax.plot(xs_b, ys_b, label=f"Model B ({model_params_b:.0f}M)", linestyle='--', linewidth=2)

    L0_plot = scaling_kwargs["L0"]
    ax.axhline(L0_plot, color='gray', linestyle=':', linewidth=1, label=f"Asymptotic loss $L_0$ = {L0_plot:.3f}")

    ax.set_xscale('log')
    ax.set_xlabel("Tokens seen (log scale)")
    ax.set_ylabel("Predicted validation loss (bits)" if use_bits else "Predicted validation loss (nats)")
    ax.set_title("Chinchilla scaling law: Loss vs Tokens")
    ax.grid(True, linestyle='-.', alpha=0.5)
    ax.legend()
    ax.set_ylim(bottom=max(0.0, L0_plot - 0.5), top=None)
    plt.tight_layout()
    ax.tick_params(axis='both', which='both',
                   direction='in',
                   labelbottom=True, labelleft=True,
                   bottom=True, top=False,
                   left=True, right=False)
    fig.subplots_adjust(top=0.85)

    return "\n".join(flops_lines), "\n".join(time_lines), "\n".join(loss_lines), fig
|
|
gpu_choices = list(GPU_PRECISION_PEAKS.keys())
precision_choices = ["float32", "float16", "bfloat16", "FP8"]
|
|
with gr.Blocks(
    title="LLM Training Estimator — Chinchilla Scaling Law",
    theme=gr.themes.Soft(),
    css=custom_css,
) as demo:
    gr.Markdown(r"""
# 🚀 LLM Training Estimator (Precision-aware)

- **GPU + Precision + Peak FLOPS + Effective throughput** → Training time & Loss
- **Validation loss prediction** via the **Chinchilla scaling law** (Hoffmann et al., 2022)

$$ L(N, D) = L_\infty + \frac{A}{N^\alpha} + \frac{B}{D^\beta} $$

- Compare two configurations side-by-side
- Supports **bits (log₂)** or **nats (ln)** loss units
- Based on empirical fits for English web text (FineWeb/OpenWebText)

> ✅ More accurate than the Kaplan-style power law (arXiv:2001.08361)

## 🌸 MFU & System Utilization Examples (8×H100)

| Parameters | MFU  | System Utilization |
|------------|------|--------------------|
| 182M       | ~39% | ~0.5               |
| 560M       | ~47% | ~0.5               |
| 1B+        | ~50% | ~0.5               |
""")
|
    with gr.Row():
        with gr.Column(scale=1):
            input_mode = gr.Radio(
                ["By total tokens", "By steps (derived from seq/batch/steps/epochs)"],
                value="By total tokens",
                label="Token input mode",
            )
            total_tokens_input_B = gr.Number(value=100.0, label="Total tokens (B) — used if 'By total tokens' selected")

            with gr.Accordion("Model A Configuration", open=True):
                model_params_a = gr.Number(value=7000.0, label="Model params (millions)")
                optimizer_a = gr.Dropdown(list(OPTIMIZER_FACTORS.keys()), label="Optimizer", value="AdamW")
                dataset_a = gr.Dropdown(list(DATASET_FACTORS.keys()), label="Dataset", value="OpenWebText")
                gpu_model_a = gr.Dropdown(gpu_choices, label="GPU model", value="H100")
                precision_a = gr.Dropdown(precision_choices, label="Precision", value="bfloat16")
                gpu_count_a = gr.Slider(1, 1024, value=8, step=1, label="GPU count")
                mfu_a = gr.Slider(0.01, 1.0, value=0.35, step=0.01, label="MFU (Model FLOPs Utilization)")
                utilization_overhead_a = gr.Slider(0.05, 1.0, value=0.5, step=0.01, label="System utilization (incl. comms, IO)")
                seq_len_a = gr.Number(value=2048, label="Sequence length (tokens)")
                batch_size_a = gr.Number(value=256, label="Batch size (per GPU)")
                steps_per_epoch_a = gr.Number(value=1000, label="Steps per epoch")
                epochs_a = gr.Number(value=1, label="Epochs")
|
|
            do_compare = gr.Checkbox(label="Compare with Model B", value=False)

            with gr.Accordion("Model B Configuration", open=False):
                model_params_b = gr.Number(value=70000.0, label="Model params (millions)")
                optimizer_b = gr.Dropdown(list(OPTIMIZER_FACTORS.keys()), label="Optimizer", value="AdamW")
                dataset_b = gr.Dropdown(list(DATASET_FACTORS.keys()), label="Dataset", value="OpenWebText")
                gpu_model_b = gr.Dropdown(gpu_choices, label="GPU model", value="A100")
                precision_b = gr.Dropdown(precision_choices, label="Precision", value="float16")
                gpu_count_b = gr.Slider(1, 1024, value=64, step=1, label="GPU count")
                mfu_b = gr.Slider(0.01, 1.0, value=0.25, step=0.01, label="MFU")
                utilization_overhead_b = gr.Slider(0.05, 1.0, value=0.4, step=0.01, label="System utilization")
                seq_len_b = gr.Number(value=2048, label="Sequence length")
                batch_size_b = gr.Number(value=2048, label="Batch size (per GPU)")
                steps_per_epoch_b = gr.Number(value=1000, label="Steps per epoch")
                epochs_b = gr.Number(value=1, label="Epochs")
|
|
with gr.Accordion("Scaling Law Parameters (Chinchilla)", open=False): |
|
|
use_bits = gr.Checkbox(value=False, label="Loss in bits(log2)") |
|
|
L0_val = gr.Number(value=CHINCHILLA_DEFAULTS["L0_bits"], label="L∞ (irreducible loss)") |
|
|
A_val = gr.Number(value=CHINCHILLA_DEFAULTS["A_bits"], label="A (model-size coefficient)") |
|
|
B_val = gr.Number(value=CHINCHILLA_DEFAULTS["B_bits"], label="B (data-size coefficient)") |
|
|
alpha_val = gr.Number(value=CHINCHILLA_DEFAULTS["alpha"], label="α (model exponent)") |
|
|
beta_val = gr.Number(value=CHINCHILLA_DEFAULTS["beta"], label="β (data exponent)") |
|
|
|
|
|
gr.Markdown(r""" |
|
|
# 💡 Tip: |
|
|
> | | $$ L_\infty $$ | $$ A $$ | $$ B $$ | |
|
|
> |----------------- |------------- |------------- |--------------- | |
|
|
> |code/math data |increase |- |- | |
|
|
> |high-quality data |- |reduce |reduce | |
|
|
|
|
|
> Default: English web text (FineWeb-like). |
|
|
""") |
|
|
|
|
|
|
|
|
|
|
|
        with gr.Column(scale=1):
            flops_out = gr.Textbox(label="Total Compute (FLOPs)", lines=3)
            time_out = gr.Textbox(label="Estimated Training Time", lines=2)
            loss_out = gr.Textbox(label="Predicted Validation Loss", lines=2)
            plot_out = gr.Plot(label="Loss vs Tokens Curve")
            run_btn = gr.Button("🚀 Estimate / Compare", variant="primary")

    run_btn.click(
        fn=compute_precise_estimate,
        inputs=[
            input_mode, total_tokens_input_B,
            optimizer_a, dataset_a, gpu_model_a, precision_a, gpu_count_a, mfu_a, utilization_overhead_a,
            seq_len_a, batch_size_a, steps_per_epoch_a, epochs_a, model_params_a,
            do_compare,
            optimizer_b, dataset_b, gpu_model_b, precision_b, gpu_count_b, mfu_b, utilization_overhead_b,
            seq_len_b, batch_size_b, steps_per_epoch_b, epochs_b, model_params_b,
            L0_val, A_val, B_val, alpha_val, beta_val, use_bits,
        ],
        outputs=[flops_out, time_out, loss_out, plot_out],
    )
|
|
gr.Markdown(""" |
|
|
--- |
|
|
📚 References: |
|
|
- [Chinchilla](https://arxiv.org/abs/2203.15556): *Training Compute-Optimal Large Language Models* |
|
|
- [GPT-3](https://arxiv.org/abs/2005.14165) (note: arXiv:2001.08361 is earlier version) |
|
|
- [MFU definition](https://arxiv.org/abs/2205.14451) |
|
|
- [Muon](https://arxiv.org/abs/2502.16982) |
|
|
""") |
|
|
|
|
|
if __name__ == "__main__":
    # theme/css are gr.Blocks() arguments; launch() takes no styling kwargs.
    demo.launch()
|
|
|