# save as llm_training_estimator_precisions.py import gradio as gr import numpy as np import matplotlib.pyplot as plt import math from datetime import timedelta custom_css = """ body { background: linear-gradient(135deg, #ffe6f2, #e6f7ff); font-family: "Nunito", "Rounded Mplus 1c", sans-serif; } /* カードやコンポーネントの角丸 */ .gradio-container, .gr-block, .gr-box { border-radius: 20px !important; } /* ボタンのデザイン */ button { background: #ff99cc !important; color: white !important; border-radius: 30px !important; padding: 10px 18px !important; font-weight: bold !important; transition: 0.2s; } button:hover { background: #ff66b3 !important; transform: scale(1.05); } /* 見出し */ h1, h2, h3 { color: #ff6699 !important; text-shadow: 0px 1px 2px rgba(255, 100, 150, 0.3); } """ # ----------------------------- # GPU x Precision の理論ピーク FLOPS(FLOPS/sec) # Note: FP8 = "fp8", "FP8", "Fp8" など全て小文字でマッチ GPU_PRECISION_PEAKS = { "H200": {"float32": 9.89e14, "float16": 1.979e15, "bfloat16": 1.979e15, "fp8": 3.958e15}, "H100": {"float32": 9.89e14, "float16": 1.979e15, "bfloat16": 1.979e15, "fp8": 3.958e15}, "A100": {"float32": 19.5e12, "float16": 312e12, "bfloat16": 312e12, "fp8": None}, "V100": {"float32": 7.8e12, "float16": 130e12, "bfloat16": None, "fp8": None}, "RTX 5090": {"float32": 120e12, "float16": 240e12, "bfloat16": 240e12, "fp8": None}, "RTX 5080": {"float32": 90e12, "float16": 180e12, "bfloat16": 180e12, "fp8": None}, "RTX 5070": {"float32": 52e12, "float16": 104e12, "bfloat16": 104e12, "fp8": None}, "RTX 5060": {"float32": 30e12, "float16": 60e12, "bfloat16": 60e12, "fp8": None}, "RTX 4090": {"float32": 82.6e12, "float16": 165.2e12, "bfloat16": 165.2e12, "fp8": None}, "RTX 4080": {"float32": 48.7e12, "float16": 97.4e12, "bfloat16": 97.4e12, "fp8": None}, "RTX 4070": {"float32": 29.2e12, "float16": 58.4e12, "bfloat16": 58.4e12, "fp8": None}, "RTX 4060": {"float32": 15.1e12, "float16": 30.2e12, "bfloat16": 30.2e12, "fp8": None}, "RTX 3090": {"float32": 35.58e12, "float16": 71.16e12, "bfloat16": None, "fp8": None}, "B200": {"float32": 18e15, "float16": 36e15, "bfloat16": 36e15, "fp8": 72e15}, } # ----------------------------- OPTIMIZER_FACTORS = {"AdamW": 1.927, "AdamAuxMuon": 1.0} DATASET_FACTORS = {"FineWeb Edu": 1.0, "OpenWebText": 0.14, "Custom": 1.05} # ----------------------------- # Chinchilla-style scaling law (loss in BITS, log₂) # Reference: Hoffmann et al. (2022), Table 1 — "OpenWebText-like" # https://arxiv.org/abs/2203.15556 CHINCHILLA_DEFAULTS = { "L0_bits": 1.69*1.4426, "A_bits": 406.4*1.4426, "B_bits": 410.7*1.4426, "alpha": 0.34, "beta": 0.28, "in_bits": False, # set False to use nats (multiply all by ln(2)) } def predict_val_loss_chinchilla( model_params: float, # N: raw count (e.g., 7e9) total_tokens: float, # D: raw count (e.g., 1.4e12) dataset: str, optimizer: str, L0: float = CHINCHILLA_DEFAULTS["L0_bits"], A: float = CHINCHILLA_DEFAULTS["A_bits"], B: float = CHINCHILLA_DEFAULTS["B_bits"], alpha: float = CHINCHILLA_DEFAULTS["alpha"], beta: float = CHINCHILLA_DEFAULTS["beta"], eps: float = 1e-12, ) -> float: """Predict validation loss using Chinchilla scaling law (in bits by default).""" if model_params <= 0 or total_tokens <= 0: return float("inf") N = float(model_params) D = float(total_tokens) * float(DATASET_FACTORS[dataset]) / float(OPTIMIZER_FACTORS[optimizer]) term_N = A / (N ** alpha + eps) term_D = B / (D ** beta + eps) return L0 + term_N + term_D def loss_curve_chinchilla(model_params: float, total_tokens: float, dataset: str, optimizer: str, steps: int = 200, **kwargs): """Return (tokens, losses) for plotting.""" xs = np.logspace( np.log10(1e6), np.log10(max(total_tokens, 1e6)), num=steps, base=10.0 ) ys = [predict_val_loss_chinchilla(model_params, t, dataset, optimizer, **kwargs) for t in xs] return xs, ys # ----------------------------- def pretty_time(seconds: float) -> str: if not np.isfinite(seconds) or seconds > 1e14: return "推定不能 / 非現実的" td = timedelta(seconds=float(seconds)) days = td.days hours, remainder = divmod(td.seconds, 3600) minutes, seconds = divmod(remainder, 60) return f"{days} days, {hours} hrs, {minutes} min, {int(seconds)} sec" def estimate_total_flops(params_count: float, total_tokens: float) -> float: """Standard: 6 FLOPs per param per token (forward + backward).""" return 6.0 * params_count * total_tokens def get_peak_flops(gpu_name: str, precision: str) -> float: gpu_info = GPU_PRECISION_PEAKS.get(gpu_name) if not gpu_info: return None # Normalize precision string: case-insensitive, strip spaces, map "FP8" → "fp8" key = precision.strip().lower().replace("-", "").replace("fp", "fp") if key.startswith("fp8") or key == "fp8": key = "fp8" return gpu_info.get(key) # ----------------------------- def compute_one_model( optimizer: str, dataset: str, gpu_model: str, precision: str, gpu_count: int, mfu: float, utilization_overhead: float, seq_len: float, batch_size: float, steps_per_epoch: float, epochs: float, model_params_millions: float, # input in millions total_tokens_override: float = None, # in raw tokens (not billions!) scaling_kwargs: dict = None ): scaling_kwargs = scaling_kwargs or {} model_params = float(model_params_millions) * 1e6 # → raw count if total_tokens_override is not None and total_tokens_override > 0: total_tokens = float(total_tokens_override) else: # total tokens = seq_len * batch_size * steps_per_epoch * epochs * gpu_count # Note: steps_per_epoch is usually *global*, but user may input per-GPU. # We assume steps_per_epoch is *global* (standard in most frameworks) total_tokens = float(seq_len) * float(batch_size) * float(gpu_count) * float(steps_per_epoch) * float(epochs) peak = get_peak_flops(gpu_model, precision) if peak is None: return { "error": f"⚠️ GPU '{gpu_model}' does not support precision '{precision}' (normalized to key '{precision.lower()}')." } total_flops = estimate_total_flops(model_params, total_tokens) effective_flops = float(gpu_count) * peak * float(mfu) * float(utilization_overhead) if effective_flops <= 0 or not np.isfinite(effective_flops): return {"error": "⚠️ Effective FLOPS ≤ 0 — check MFU/utilization."} seconds = total_flops / effective_flops predicted_val_loss = predict_val_loss_chinchilla( model_params=model_params, total_tokens=total_tokens, dataset=dataset, optimizer=optimizer, **scaling_kwargs ) return { "total_tokens": total_tokens, "params_count": model_params, "total_flops": total_flops, "seconds": seconds, "time_str": pretty_time(seconds), "predicted_val_loss": predicted_val_loss } # ----------------------------- def compute_precise_estimate( input_mode, total_tokens_input_B, # in billions (1e9) optimizer_a, dataset_a, gpu_model_a, precision_a, gpu_count_a, mfu_a, utilization_overhead_a, seq_len_a, batch_size_a, steps_per_epoch_a, epochs_a, model_params_a, do_compare, optimizer_b, dataset_b, gpu_model_b, precision_b, gpu_count_b, mfu_b, utilization_overhead_b, seq_len_b, batch_size_b, steps_per_epoch_b, epochs_b, model_params_b, L0_val, A_val, B_val, alpha_val, beta_val, use_bits ): # Convert user-facing "B" to raw tokens total_tokens_override = None if input_mode == "By total tokens": try: total_tokens_override = float(total_tokens_input_B) * 1e9 if total_tokens_override < 1e6: total_tokens_override = 1e6 # min sanity except Exception: pass # Scaling law config scaling_kwargs = { "L0": float(L0_val), "A": float(A_val), "B": float(B_val), "alpha": float(alpha_val), "beta": float(beta_val), } if not use_bits: # convert from bits → nats ln2 = math.log(2) scaling_kwargs["L0"] *= ln2 scaling_kwargs["A"] *= ln2 scaling_kwargs["B"] *= ln2 a_res = compute_one_model( optimizer_a, dataset_a, gpu_model_a, precision_a, gpu_count_a, mfu_a, utilization_overhead_a, seq_len_a, batch_size_a, steps_per_epoch_a, epochs_a, model_params_a, total_tokens_override=total_tokens_override, scaling_kwargs=scaling_kwargs ) if "error" in a_res: return a_res["error"], "", "", None b_res = None if do_compare: b_res = compute_one_model( optimizer_b, dataset_b, gpu_model_b, precision_b, gpu_count_b, mfu_b, utilization_overhead_b, seq_len_b, batch_size_b, steps_per_epoch_b, epochs_b, model_params_b, total_tokens_override=total_tokens_override, scaling_kwargs=scaling_kwargs ) # Output formatting flops_lines = [f"Model A: {a_res['total_flops']:.3e} FLOPs (params={a_res['params_count']:.1e}, tokens={a_res['total_tokens']:.1e})"] time_lines = [f"Model A: {a_res['time_str']}"] loss_lines = [f"Model A: {a_res['predicted_val_loss']:.4f} loss"] if do_compare: if "error" in (b_res or {}): err = b_res.get("error", "Unknown error") flops_lines.append(f"Model B: error — {err}") time_lines.append(f"Model B: error — {err}") loss_lines.append(f"Model B: error — {err}") else: flops_lines.append(f"Model B: {b_res['total_flops']:.3e} FLOPs (params={b_res['params_count']:.1e}, tokens={b_res['total_tokens']:.1e})") time_lines.append(f"Model B: {b_res['time_str']}") loss_lines.append(f"Model B: {b_res['predicted_val_loss']:.4f} loss") # Plot fig, ax = plt.subplots(figsize=(7, 5)) xs_a, ys_a = loss_curve_chinchilla( model_params=a_res["params_count"], total_tokens=max(a_res["total_tokens"], 1e6), dataset=dataset_a, optimizer=optimizer_a, **scaling_kwargs ) ax.plot(xs_a, ys_a, label=f"Model A ({model_params_a:.0f}M)", linewidth=2) if do_compare and b_res and "params_count" in b_res: xs_b, ys_b = loss_curve_chinchilla( model_params=b_res["params_count"], total_tokens=max(b_res["total_tokens"], 1e6), dataset=dataset_b, optimizer=optimizer_b, **scaling_kwargs ) ax.plot(xs_b, ys_b, label=f"Model B ({model_params_b:.0f}M)", linestyle='--', linewidth=2) # Add asymptotic loss line L0_plot = scaling_kwargs["L0"] ax.axhline(L0_plot, color='gray', linestyle=':', linewidth=1, label=f"Asymptotic loss $L_0$ = {L0_plot:.3f}") ax.set_xscale('log') ax.set_xlabel("Tokens seen (log scale)") ax.set_ylabel("Predicted validation loss (bits)" if use_bits else "Predicted validation loss (nats)") ax.set_title("Chinchilla scaling law: Loss vs Tokens") ax.grid(True, linestyle='-.', alpha=0.5) ax.legend() ax.set_ylim(bottom=max(0.0, L0_plot - 0.5), top=None) plt.tight_layout() plt.tick_params(axis='both', which='both', direction='in', # 内向き labelbottom=True, # x軸のラベル非表示 labelleft=True, # y軸のラベル非表示 bottom=True, top=False, left=True, right=False) fig.subplots_adjust(top=0.85) return "\n".join(flops_lines), "\n".join(time_lines), "\n".join(loss_lines), fig # ----------------------------- gpu_choices = list(GPU_PRECISION_PEAKS.keys()) precision_choices = ["float32", "float16", "bfloat16", "FP8"] with gr.Blocks(title="LLM Training Estimator — Chinchilla Scaling Law") as demo: gr.Markdown(r""" # 🚀 LLM Training Estimator (Precision-aware) - **GPU + Precision + Peak FLOPS + Effective throughput** → Training time & Loss - **Validation loss prediction** via **Chinchilla scaling law** (Hoffmann et al., 2022) $$ L(N, D) = L_\infty + \frac{A}{N^\alpha} + \frac{B}{D^\beta} $$ - Compare two configurations side-by-side - Supports **bits (log₂)** or **nats (ln)** loss units - Based on empirical fits for English web text (FineWeb/OpenWebText) > ✅ More accurate than GPT-3-style power-law (arXiv:2001.08361) ## 🌸 MFU & System Utilization Examples (8×H100) |Parameters |MFU | System Utilization | |------------ |------------- |------------------- | |182M | ~39% | ~0.5 | |560M | ~47% | ~0.5 | |1B~ | ~50% | ~0.5 | """) with gr.Row(): with gr.Column(scale=1): input_mode = gr.Radio( ["By total tokens", "By steps (derived from seq/batch/steps/epochs)"], value="By total tokens", label="Token input mode" ) total_tokens_input_B = gr.Number(value=100.0, label="Total tokens (B) — used if 'By total tokens' selected") with gr.Accordion("Model A Configuration", open=True): model_params_a = gr.Number(value=7000.0, label="Model params (millions)") optimizer_a = gr.Dropdown(list(OPTIMIZER_FACTORS.keys()), label="Optimizer", value="AdamW") dataset_a = gr.Dropdown(list(DATASET_FACTORS.keys()), label="Dataset", value="OpenWebText") gpu_model_a = gr.Dropdown(gpu_choices, label="GPU model", value="H100") precision_a = gr.Dropdown(precision_choices, label="Precision", value="bfloat16") gpu_count_a = gr.Slider(1, 1024, value=8, step=1, label="GPU count") mfu_a = gr.Slider(0.01, 1.0, value=0.35, step=0.01, label="MFU (Model FLOPs Utilization)") utilization_overhead_a = gr.Slider(0.05, 1.0, value=0.5, step=0.01, label="System utilization (incl. comms, IO)") seq_len_a = gr.Number(value=2048, label="Sequence length (tokens)") batch_size_a = gr.Number(value=256, label="Global batch size*") # clarified steps_per_epoch_a = gr.Number(value=1000, label="Steps per epoch") epochs_a = gr.Number(value=1, label="Epochs") do_compare = gr.Checkbox(label="Compare with Model B", value=False) with gr.Accordion("Model B Configuration", open=False): model_params_b = gr.Number(value=70000.0, label="Model params (millions)") optimizer_b = gr.Dropdown(list(OPTIMIZER_FACTORS.keys()), label="Optimizer", value="AdamW") dataset_b = gr.Dropdown(list(DATASET_FACTORS.keys()), label="Dataset", value="OpenWebText") gpu_model_b = gr.Dropdown(gpu_choices, label="GPU model", value="A100") precision_b = gr.Dropdown(precision_choices, label="Precision", value="float16") gpu_count_b = gr.Slider(1, 1024, value=64, step=1, label="GPU count") mfu_b = gr.Slider(0.01, 1.0, value=0.25, step=0.01, label="MFU") utilization_overhead_b = gr.Slider(0.05, 1.0, value=0.4, step=0.01, label="System utilization") seq_len_b = gr.Number(value=2048, label="Sequence length") batch_size_b = gr.Number(value=2048, label="Global batch size*") steps_per_epoch_b = gr.Number(value=1000, label="Steps per epoch") epochs_b = gr.Number(value=1, label="Epochs") with gr.Accordion("Scaling Law Parameters (Chinchilla)", open=False): use_bits = gr.Checkbox(value=False, label="Loss in bits(log2)") L0_val = gr.Number(value=CHINCHILLA_DEFAULTS["L0_bits"], label="L∞ (irreducible loss)") A_val = gr.Number(value=CHINCHILLA_DEFAULTS["A_bits"], label="A (model-size coefficient)") B_val = gr.Number(value=CHINCHILLA_DEFAULTS["B_bits"], label="B (data-size coefficient)") alpha_val = gr.Number(value=CHINCHILLA_DEFAULTS["alpha"], label="α (model exponent)") beta_val = gr.Number(value=CHINCHILLA_DEFAULTS["beta"], label="β (data exponent)") gr.Markdown(r""" # 💡 Tip: > | | $$ L_\infty $$ | $$ A $$ | $$ B $$ | > |----------------- |------------- |------------- |--------------- | > |code/math data |increase |- |- | > |high-quality data |- |reduce |reduce | > Default: English web text (FineWeb-like). """) with gr.Column(scale=1): flops_out = gr.Textbox(label="Total Compute (FLOPs)", lines=3) time_out = gr.Textbox(label="Estimated Training Time", lines=2) loss_out = gr.Textbox(label="Predicted Validation Loss", lines=2) plot_out = gr.Plot(label="Loss vs Tokens Curve") run_btn = gr.Button("🚀 Estimate / Compare", variant="primary") run_btn.click( fn=compute_precise_estimate, inputs=[ input_mode, total_tokens_input_B, optimizer_a, dataset_a, gpu_model_a, precision_a, gpu_count_a, mfu_a, utilization_overhead_a, seq_len_a, batch_size_a, steps_per_epoch_a, epochs_a, model_params_a, do_compare, optimizer_b, dataset_b, gpu_model_b, precision_b, gpu_count_b, mfu_b, utilization_overhead_b, seq_len_b, batch_size_b, steps_per_epoch_b, epochs_b, model_params_b, L0_val, A_val, B_val, alpha_val, beta_val, use_bits ], outputs=[flops_out, time_out, loss_out, plot_out] ) gr.Markdown(""" --- 📚 References: - [Chinchilla](https://arxiv.org/abs/2203.15556): *Training Compute-Optimal Large Language Models* - [GPT-3](https://arxiv.org/abs/2005.14165) (note: arXiv:2001.08361 is earlier version) - [MFU definition](https://arxiv.org/abs/2104.04473) - [Muon](https://arxiv.org/abs/2502.16982) """) if __name__ == "__main__": demo.launch(theme=gr.themes.Soft(), css=custom_css)