# Roman's Training Time Estimator — Gradio app that estimates LLM training
# wall-clock time from model size, token count, and per-GPU TFLOPs presets.
| import gradio as gr | |
| import math | |
# ------------------------
# GPU presets: TFLOPs (units: TFLOPs)
# ------------------------
# Per-GPU peak throughput keyed by marketing name. Each entry maps a precision
# key ("FP32", "FP16", "INT4") to peak TFLOPs; 0.0 means "no figure available"
# and causes estimate_time() to ask for a manual TFLOPs override.
# NOTE(review): nearly every row uses FP16 = exactly 2 * FP32, and the
# RTX 50xx rows duplicate the 40xx rows value-for-value; several figures
# (e.g. the Ada/Blackwell and datacenter "estimates") look like placeholders
# or sparse-tensor numbers — verify against vendor spec sheets before trusting.
GPUS = {
    # Turing / consumer
    "RTX 2060": {"FP32": 6.50, "FP16": 13.00, "INT4": 0.0},
    "RTX 2060 12GB": {"FP32": 7.20, "FP16": 14.40, "INT4": 0.0},
    "RTX 2060 SUPER": {"FP32": 8.90, "FP16": 17.80, "INT4": 0.0},
    "RTX 2070": {"FP32": 8.90, "FP16": 16.00, "INT4": 0.0},
    "RTX 2070 SUPER": {"FP32": 9.10, "FP16": 18.20, "INT4": 0.0},
    "RTX 2080": {"FP32": 10.10, "FP16": 20.20, "INT4": 0.0},
    "RTX 2080 SUPER": {"FP32": 11.15, "FP16": 22.30, "INT4": 0.0},
    "RTX 2080 Ti": {"FP32": 13.45, "FP16": 26.90, "INT4": 544.0},
    # Ampere / consumer
    "RTX 3050": {"FP32": 9.10, "FP16": 18.20, "INT4": 0.0},
    "RTX 3060": {"FP32": 12.70, "FP16": 25.40, "INT4": 0.0},
    "RTX 3060 Ti": {"FP32": 16.20, "FP16": 32.40, "INT4": 0.0},
    "RTX 3070": {"FP32": 20.30, "FP16": 40.60, "INT4": 0.0},
    "RTX 3070 Ti": {"FP32": 22.30, "FP16": 44.60, "INT4": 0.0},
    "RTX 3080": {"FP32": 29.80, "FP16": 59.60, "INT4": 1248.0},
    "RTX 3080 Ti": {"FP32": 34.10, "FP16": 68.20, "INT4": 1248.0},
    "RTX 3090": {"FP32": 35.58, "FP16": 71.16, "INT4": 1248.0},
    "RTX 3090 Ti": {"FP32": 40.00, "FP16": 80.00, "INT4": 1248.0},
    # Ada / Lovelace consumer
    "RTX 4050": {"FP32": 16.90, "FP16": 33.80, "INT4": 0.0},
    "RTX 4060": {"FP32": 31.10, "FP16": 62.20, "INT4": 0.0},
    "RTX 4060 Ti": {"FP32": 45.60, "FP16": 91.20, "INT4": 0.0},
    "RTX 4070": {"FP32": 75.00, "FP16": 150.00, "INT4": 0.0},
    "RTX 4070 Ti": {"FP32": 92.20, "FP16": 184.40, "INT4": 0.0},
    "RTX 4080": {"FP32": 144.00, "FP16": 288.00, "INT4": 0.0},
    "RTX 4080 SUPER": {"FP32": 167.60, "FP16": 335.20, "INT4": 0.0},
    "RTX 4090": {"FP32": 201.00, "FP16": 402.00, "INT4": 1676.0},
    # Blackwell consumer (RTX 50xx series)
    # NOTE(review): identical to the 40xx rows above — likely placeholders.
    "RTX 5050": {"FP32": 16.90, "FP16": 33.80, "INT4": 0.0},
    "RTX 5060": {"FP32": 31.10, "FP16": 62.20, "INT4": 0.0},
    "RTX 5060 Ti": {"FP32": 45.60, "FP16": 91.20, "INT4": 0.0},
    "RTX 5070": {"FP32": 75.00, "FP16": 150.00, "INT4": 0.0},
    "RTX 5070 Ti": {"FP32": 92.20, "FP16": 184.40, "INT4": 0.0},
    "RTX 5080": {"FP32": 144.00, "FP16": 288.00, "INT4": 0.0},
    "RTX 5090": {"FP32": 201.00, "FP16": 402.00, "INT4": 1676.0},
    # Data center / Tesla / A-series
    "Tesla T4": {"FP32": 8.10, "FP16": 65.13, "INT4": 0.0},
    "Tesla V100": {"FP32": 15.70, "FP16": 31.40, "INT4": 0.0},
    "NVIDIA A10": {"FP32": 31.20, "FP16": 62.40, "INT4": 0.0},
    "A100": {"FP32": 19.50, "FP16": 39.00, "INT4": 624.0},
    "A100 80GB": {"FP32": 19.50, "FP16": 39.00, "INT4": 624.0},
    # Hopper / Blackwell datacenter estimates
    "H100": {"FP32": 300.0, "FP16": 600.0, "INT4": 3000.0},
    "B100": {"FP32": 400.0, "FP16": 800.0, "INT4": 4000.0},
    "B200": {"FP32": 500.0, "FP16": 1000.0, "INT4": 5000.0},
    # AMD (kept for completeness)
    "RX 5500 XT": {"FP32": 5.20, "FP16": 10.40, "INT4": 0.0},
    "RX 5600 XT": {"FP32": 10.80, "FP16": 21.60, "INT4": 0.0},
    "RX 5700": {"FP32": 14.40, "FP16": 28.80, "INT4": 0.0},
    "RX 5700 XT": {"FP32": 16.20, "FP16": 32.40, "INT4": 0.0},
    "RX 6600": {"FP32": 17.90, "FP16": 35.80, "INT4": 0.0},
    "RX 6600 XT": {"FP32": 20.00, "FP16": 40.00, "INT4": 0.0},
    "RX 6700 XT": {"FP32": 23.00, "FP16": 46.00, "INT4": 0.0},
    "RX 6800": {"FP32": 30.00, "FP16": 60.00, "INT4": 0.0},
    "RX 6800 XT": {"FP32": 34.00, "FP16": 68.00, "INT4": 0.0},
    "RX 6900 XT": {"FP32": 40.00, "FP16": 80.00, "INT4": 0.0},
    "RX 7600": {"FP32": 25.00, "FP16": 50.00, "INT4": 0.0},
    "RX 7700 XT": {"FP32": 35.00, "FP16": 70.00, "INT4": 0.0},
    "RX 7900 XT": {"FP32": 40.00, "FP16": 80.00, "INT4": 0.0},
    "RX 7900 XTX": {"FP32": 61.10, "FP16": 122.20, "INT4": 0.0},
    # AMD MI / CDNA datacenter
    "MI50": {"FP32": 13.70, "FP16": 27.40, "INT4": 0.0},
    "MI100": {"FP32": 23.10, "FP16": 46.20, "INT4": 0.0},
    "MI200": {"FP32": 300.0, "FP16": 600.0, "INT4": 3000.0},
    "MI300": {"FP32": 400.0, "FP16": 800.0, "INT4": 4000.0},
    "MI355X": {"FP32": 157, "FP16": 2500, "INT4": 10000},
    # Hopper / Grace superchips
    "H200": {"FP32": 350.0, "FP16": 700.0, "INT4": 3500.0},
    "GH200": {"FP32": 300.0, "FP16": 600.0, "INT4": 3000.0},  # H100-class GPU + Grace CPU
    "GB10": {"FP32": 400.0, "FP16": 800.0, "INT4": 4000.0},  # dev module, Blackwell-class
    # Ada Lovelace datacenter
    "L20": {"FP32": 44.0, "FP16": 88.0, "INT4": 700.0},
    "A40": {"FP32": 37.4, "FP16": 74.8, "INT4": 600.0},
    "A2": {"FP32": 4.5, "FP16": 9.0, "INT4": 160.0},
    # RTX Ada workstation GPUs
    "RTX A2000": {"FP32": 8.0, "FP16": 16.0, "INT4": 0.0},
    "RTX A4000": {"FP32": 19.2, "FP16": 38.4, "INT4": 0.0},
    "RTX A4500": {"FP32": 23.7, "FP16": 47.4, "INT4": 0.0},
    "RTX A5000": {"FP32": 27.8, "FP16": 55.6, "INT4": 0.0},
    "RTX A6000 Ada": {"FP32": 91.1, "FP16": 182.2, "INT4": 1450.0},
}
# ------------------------
# CSS / Theme variables
# ------------------------
# Custom CSS handed to the Gradio app. Theming works entirely client-side:
# the .theme-* classes override the CSS custom properties declared on :root,
# and THEME_BUTTONS_HTML swaps the class on <html> to switch palettes.
CSS = r"""
:root { --bg:#071233; --card:#07112a; --accent:#2563eb; --text:#e8f0ff; --muted:#9fb6e8; }
body { background: var(--bg); color:var(--text); font-family: Inter, system-ui, -apple-system, "Segoe UI", Roboto, "Helvetica Neue", Arial; }
.gradio-container { max-width: 920px; margin: 14px auto; padding: 12px; }
/* card */
.card { background: var(--card); border-radius:12px; padding:14px; box-shadow: 0 8px 26px rgba(2,6,23,0.5); border:1px solid rgba(255,255,255,0.03); }
/* accent and buttons */
.btn-theme { background:transparent; color:var(--accent); border:1px solid var(--accent); padding:8px 12px; border-radius:10px; cursor:pointer; }
.btn-theme:hover { background: rgba(255,255,255,0.02); }
/* result */
.result-box { background: linear-gradient(180deg, rgba(255,255,255,0.01), rgba(255,255,255,0.02)); border-radius:8px; padding:10px; border:1px solid rgba(255,255,255,0.03); color:var(--text); font-weight:600; }
/* small text */
.small-muted { color: var(--muted); font-size:0.92em; }
/* themes */
.theme-blue { --bg:#071233; --card:#07112a; --accent:#2563eb; --text:#e8f0ff; --muted:#9fb6e8; }
.theme-green{ --bg:#07120a; --card:#07120a; --accent:#16a34a; --text:#e8fff0; --muted:#9fe8b0; }
.theme-purple{ --bg:#120521; --card:#15061a; --accent:#8b5cf6; --text:#f2e8ff; --muted:#c9b8f6; }
/* minor Gradio element tweaks */
input[type="number"], .gradio-number { background: transparent; color: var(--text); border-radius:6px; }
/* theme button row */
.theme-btn-row { display:flex; gap:8px; align-items:center; }
"""
# ------------------------
# Core logic
# ------------------------
def estimate_time(params_m: float,
                  tokens_b: float,
                  selected_gpu: str,
                  dtype: str,
                  tf_override: float,
                  utilization_pct: float,
                  gpu_count: float) -> str:
    """Estimate wall-clock training time from the ~6*N*D FLOPs rule of thumb.

    Args:
        params_m: Model size in millions of parameters.
        tokens_b: Training tokens in billions.
        selected_gpu: Key into ``GPUS`` used when no override is supplied.
        dtype: Precision key ("FP32" / "FP16" / "INT4") looked up in the preset.
        tf_override: Manual per-GPU TFLOPs; takes precedence when > 0.
        utilization_pct: Realistic fraction of peak throughput, in percent.
        gpu_count: Number of identical GPUs available.

    Returns:
        A human-readable multi-line report, or a short error message when
        inputs are missing or invalid.
    """
    # gr.Number widgets can deliver None when left blank; comparing None <= 0
    # would raise TypeError, so guard explicitly before the range checks.
    if params_m is None or tokens_b is None or params_m <= 0 or tokens_b <= 0:
        return "Enter positive values for parameters and tokens."
    if gpu_count is None or gpu_count <= 0:
        return "Enter a positive number of GPUs."
    params = params_m * 1e6
    tokens = tokens_b * 1e9
    # Per-GPU TFLOPs: a positive manual override wins over the preset table.
    if tf_override is not None and tf_override > 0:
        chosen_tf_per_gpu = float(tf_override)
        source = "manual override"
    else:
        try:
            chosen_tf_per_gpu = float(GPUS[selected_gpu].get(dtype, 0.0))
            source = f"preset ({selected_gpu} / {dtype})"
        except KeyError:  # unknown GPU name — only failure mode of the lookup
            return "Couldn't determine GPU TFLOPs. Pick a GPU or enter TFLOPs manually."
    if chosen_tf_per_gpu <= 0:
        return "Couldn't determine GPU TFLOPs. Pick a GPU or enter TFLOPs manually."
    # Aggregate across GPUs, then derate by utilization (floored at 0.1% to
    # avoid a zero divisor) -> effective FLOPs/sec.
    total_tf = chosen_tf_per_gpu * float(gpu_count)
    gpu_flops_per_sec = total_tf * 1e12 * max(0.001, utilization_pct / 100.0)
    # Standard transformer-training approximation: ~6 FLOPs per parameter
    # per token (forward + backward).
    flops_total = 6 * params * tokens
    seconds = flops_total / gpu_flops_per_sec
    hours = seconds / 3600.0
    days = hours / 24.0
    seq_len = 2048.0  # assumed context length for the rough step count
    steps = max(1.0, tokens / seq_len)
    flops_per_step = flops_total / steps  # steps >= 1 by construction
    # Sanity warnings for clearly out-of-scale inputs.
    warnings = []
    if gpu_count >= 10000:
        warnings.append("⚠️ Wow that's a lot of GPUs — are you sure? Check units (e.g., 8 not 800k).")
    if total_tf > 1e6:
        warnings.append("⚠️ Total TFLOPs exceed 1e6 TFLOPs (exaFLOPs scale) — results are rough estimates.")
    out = [
        "🔥 Roman's Training Time Estimator",
        "",
        f"Model params: {params_m:,.1f} M",
        f"Training tokens: {tokens_b:,.3f} B",
        f"Total training FLOPs (approx): {flops_total:.3e}",
        "",
        f"Hardware source: {source}",
        f"Per-GPU TFLOPs: {chosen_tf_per_gpu:.3f} TFLOPs",
        f"GPU count: {int(gpu_count):,}",
        f"Total effective TFLOPs (before utilization): {total_tf:,.3f} TFLOPs",
        f"Utilization: {utilization_pct:.0f}%",
        "",
        f"⏱️ Wall-clock estimate: {hours:,.2f} hours (~{days:,.2f} days)",
        f"Steps (rough, seq_len=2048): {steps:,.0f} steps",
        f"FLOPs / step (avg): {flops_per_step:.3e}",
    ]
    if warnings:
        out.append("")
        out.extend(warnings)
    # NOTE(review): "Custom" is never among the dropdown choices, so this
    # condition reduces to "an override was given"; kept for compatibility.
    if tf_override and tf_override > 0 and selected_gpu != "Custom":
        out.append("")
        out.append("⚠️ Note: you overrode the preset TFLOPs. Ensure the value is in TFLOPs (e.g., 150 for A100 FP16-like).")
    return "\n".join(out)
def preset_tf_for_ui(selected_gpu: str, dtype: str):
    """Return the preset TFLOPs for a GPU/dtype pair, or 0.0 when unknown."""
    spec = GPUS.get(selected_gpu)
    return spec.get(dtype, 0.0) if spec is not None else 0.0
# ------------------------
# Build UI
# ------------------------
# Inline HTML for theme buttons with client-side onclick handlers.
# Each button swaps the class on <html>, which re-binds the CSS custom
# properties defined by the .theme-* rules — no server round-trip needed.
THEME_BUTTONS_HTML = """
<div class="theme-btn-row">
<button class="btn-theme" onclick="document.documentElement.className='theme-blue'">Blue</button>
<button class="btn-theme" onclick="document.documentElement.className='theme-green'">Green</button>
<button class="btn-theme" onclick="document.documentElement.className='theme-purple'">Purple</button>
</div>
"""
# Custom CSS must be supplied to the Blocks constructor — Blocks.launch()
# has no css parameter, so the previous launch(css=CSS) never applied it.
with gr.Blocks(css=CSS) as demo:
    # initial theme set (runs immediately on load)
    gr.HTML("<script>document.documentElement.className='theme-blue';</script>")
    with gr.Column(elem_classes="card"):
        with gr.Row():
            gr.Markdown("## 🧠 Roman’s Training Time Estimator")
            # render the theme buttons as raw HTML so onclick works client-side instantly
            gr.HTML(THEME_BUTTONS_HTML)
    with gr.Column(elem_classes="card"):
        gr.Markdown("### Model & Hardware")
        with gr.Row():
            params = gr.Slider(minimum=1, maximum=20000, value=100, step=0.1, label="Model Parameters (Millions)")
            tokens = gr.Number(value=1.0, label="Training Tokens (Billions)")
        with gr.Row():
            gpu_dropdown = gr.Dropdown(choices=list(GPUS.keys()), value="A100 80GB", label="GPU Preset (changes TFLOPs below)")
            dtype_dropdown = gr.Dropdown(choices=["FP32", "FP16", "INT4"], value="FP16", label="Training Precision / DType")
        with gr.Row():
            tf_override = gr.Number(value=preset_tf_for_ui("A100 80GB", "FP16"), label="GPU TFLOPs (teraFLOPs) — editable", precision=3)
            utilization = gr.Slider(minimum=1, maximum=100, value=80, step=1, label="Hardware Utilization (%) — realistic throughput")
        with gr.Row():
            gpu_count = gr.Number(value=1, label="GPU Count (how many of the chosen preset you have)", precision=0)
    with gr.Column(elem_classes="card"):
        gr.Markdown("### Estimate")
        result = gr.Textbox(lines=14, interactive=False, elem_classes="result-box", label="Result")
        run_btn = gr.Button("Estimate Training Time", elem_classes="btn-theme")

    # Keep the editable TFLOPs field in sync whenever the GPU or dtype changes.
    def _update_tf(selected_gpu, dtype):
        return gr.update(value=preset_tf_for_ui(selected_gpu, dtype))

    gpu_dropdown.change(_update_tf, inputs=[gpu_dropdown, dtype_dropdown], outputs=[tf_override])
    dtype_dropdown.change(_update_tf, inputs=[gpu_dropdown, dtype_dropdown], outputs=[tf_override])

    # Run button computes estimate
    run_btn.click(estimate_time,
                  inputs=[params, tokens, gpu_dropdown, dtype_dropdown, tf_override, utilization, gpu_count],
                  outputs=[result])

    gr.HTML("<div class='small-muted'>Tip: GPU presets are TFLOPs per dtype. You can edit the TFLOPs number to override. Utilization reduces theoretical peak to realistic throughput.</div>")
    gr.HTML("<div class='small-muted'>Thanks to the contributions from Reality123b</div>")

if __name__ == "__main__":
    demo.launch()