Roman190928 commited on
Commit
9ed2149
·
verified ·
1 Parent(s): 9cfa7c8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +36 -48
app.py CHANGED
@@ -1,20 +1,17 @@
1
- # roman_estimator_with_presets.py
2
  import gradio as gr
3
  import math
4
 
5
  # ------------------------
6
  # GPU presets: TFLOPs (units: TFLOPs)
7
- # Each GPU: dict of dtype -> TFLOPs (tera-FLOPs)
8
- # These are presets / approximations — editable by user via the TFLOPs field.
9
  # ------------------------
10
  GPUS = {
11
  "T4": {"FP32": 8.1, "FP16": 65.0, "INT4": 260.0},
12
  "RTX 3060": {"FP32": 13.0, "FP16": 26.0, "INT4": 52.0},
13
  "RTX 3090": {"FP32": 35.0, "FP16": 70.0, "INT4": 140.0},
14
  "RTX 4090": {"FP32": 83.0, "FP16": 166.0, "INT4": 332.0},
15
- "A100 80GB": {"FP32": 19.5, "FP16": 150.0, "INT4": 600.0}, # FP32 architecture throughput is lower than FP16
16
  "H100 SXM": {"FP32": 30.0, "FP16": 300.0, "INT4": 1200.0},
17
- "Custom": {"FP32": 1.0, "FP16": 1.0, "INT4": 1.0}, # placeholder for manual GPUs
18
  }
19
 
20
  # ------------------------
@@ -56,48 +53,32 @@ def estimate_time(params_m: float,
56
  dtype: str,
57
  tf_override: float,
58
  utilization_pct: float):
59
- """
60
- params_m: model params in millions (e.g., 100 -> 100M)
61
- tokens_b: training tokens in billions (e.g., 1.5 -> 1.5B)
62
- selected_gpu: key in GPUS
63
- dtype: "FP32" / "FP16" / "INT4"
64
- tf_override: numeric TFLOPs (if >0 will override preset)
65
- utilization_pct: fraction 0..100 representing real-world utilization
66
- """
67
- # validation
68
  if params_m <= 0 or tokens_b <= 0:
69
  return "Enter positive values for parameters and tokens."
70
 
71
  params = params_m * 1e6
72
  tokens = tokens_b * 1e9
73
 
74
- # choose TFLOPs: override if user entered > 0, else use preset
75
  if tf_override is not None and tf_override > 0:
76
  chosen_tf = float(tf_override)
77
  source = "manual override"
78
  else:
79
- # safe fallback
80
  try:
81
  chosen_tf = float(GPUS[selected_gpu].get(dtype, 0.0))
82
  source = f"preset ({selected_gpu} / {dtype})"
83
  except Exception:
84
- chosen_tf = 0.0
85
- source = "preset-missing"
86
 
87
  if chosen_tf <= 0:
88
  return "Couldn't determine GPU TFLOPs. Pick a GPU or enter TFLOPs manually."
89
 
90
- # convert to FLOPs/sec
91
- dtype_mul = 1.0 # GPUS already store per-dtype TFLOPs, so no extra multiplier
92
  gpu_flops_per_sec = chosen_tf * 1e12 * (max(0.001, utilization_pct / 100.0))
93
 
94
- # FLOPs estimate (industry rule of thumb)
95
- flops_total = 6 * params * tokens # total training FLOPs (approx)
96
  seconds = flops_total / gpu_flops_per_sec
97
  hours = seconds / 3600.0
98
  days = hours / 24.0
99
 
100
- # extras
101
  seq_len = 2048.0
102
  steps = max(1.0, tokens / seq_len)
103
  flops_per_step = flops_total / steps if steps > 0 else 0.0
@@ -117,39 +98,45 @@ def estimate_time(params_m: float,
117
  f"FLOPs / step (avg): {flops_per_step:.3e}",
118
  ]
119
 
120
- # warning for suspicious override values
121
  if tf_override and tf_override > 0 and selected_gpu != "Custom":
122
  out.append("")
123
- out.append("⚠️ Note: you overrode the preset TFLOPs. Make sure the value is in TFLOPs (e.g., 150 for A100 FP16-like).")
124
 
125
  return "\n".join(out)
126
 
127
- # ------------------------
128
- # Helper to return preset TFLOPs for UI update
129
- # ------------------------
130
  def preset_tf_for_ui(selected_gpu: str, dtype: str):
131
- """Return the preset TFLOPs number for the selected GPU+dtype (or 0 if missing)."""
132
- val = 0.0
133
  if selected_gpu in GPUS:
134
- val = GPUS[selected_gpu].get(dtype, 0.0)
135
- return val
 
 
 
 
 
 
 
 
 
 
 
 
 
136
 
137
  # ------------------------
138
  # Build UI
139
  # ------------------------
140
- with gr.Blocks(css=CSS, title="Roman's Training Time Estimator") as demo:
141
- # set a default theme class
142
  gr.HTML("<script>document.documentElement.className='theme-blue';</script>")
143
 
144
  with gr.Column(elem_classes="card"):
145
- with gr.Row(elem_classes="header-row"):
146
  gr.Markdown("## 🧠 Roman’s Training Time Estimator")
147
- with gr.Row(elem_classes="theme-buttons"):
148
  btn_blue = gr.Button("Blue", elem_classes="btn-theme")
149
  btn_green = gr.Button("Green", elem_classes="btn-theme")
150
  btn_purple = gr.Button("Purple", elem_classes="btn-theme")
151
 
152
- # Model & hardware card
153
  with gr.Column(elem_classes="card"):
154
  gr.Markdown("### Model & Hardware")
155
  with gr.Row():
@@ -162,31 +149,32 @@ with gr.Blocks(css=CSS, title="Roman's Training Time Estimator") as demo:
162
  tf_override = gr.Number(value=preset_tf_for_ui("A100 80GB", "FP16"), label="GPU TFLOPs (teraFLOPs) — editable", precision=3)
163
  utilization = gr.Slider(minimum=1, maximum=100, value=80, step=1, label="Hardware Utilization (%) — realistic throughput")
164
 
165
- # Result card
166
  with gr.Column(elem_classes="card"):
167
  gr.Markdown("### Estimate")
168
  result = gr.Textbox(lines=12, interactive=False, elem_classes="result-box", label="Result")
169
  run_btn = gr.Button("Estimate Training Time", elem_classes="btn-theme")
170
 
171
- # Wire interactions
172
- # When GPU or dtype changes, update tf_override value to the preset for that combo
 
 
173
  def _update_tf(selected_gpu, dtype):
174
  return gr.update(value=preset_tf_for_ui(selected_gpu, dtype))
175
-
176
  gpu_dropdown.change(_update_tf, inputs=[gpu_dropdown, dtype_dropdown], outputs=[tf_override])
177
  dtype_dropdown.change(_update_tf, inputs=[gpu_dropdown, dtype_dropdown], outputs=[tf_override])
178
 
179
- # Run button computes estimate
180
  run_btn.click(estimate_time,
181
  inputs=[params, tokens, gpu_dropdown, dtype_dropdown, tf_override, utilization],
182
  outputs=[result])
183
 
184
- # Theme buttons (JS simply toggles class on documentElement)
185
- btn_blue.click(None, None, None, _js="() => { document.documentElement.className='theme-blue'; return []; }")
186
- btn_green.click(None, None, None, _js="() => { document.documentElement.className='theme-green'; return []; }")
187
- btn_purple.click(None, None, None, _js="() => { document.documentElement.className='theme-purple'; return []; }")
188
 
189
- gr.HTML("<div class='small-muted'>Tip: GPU preset values are TFLOPs per dtype. You can edit the TFLOPs number to override. Utilization reduces theoretical peak to realistic throughput.</div>")
190
 
 
191
  if __name__ == "__main__":
192
- demo.launch()
 
 
1
  import gradio as gr
2
  import math
3
 
4
  # ------------------------
5
  # GPU presets: TFLOPs (units: TFLOPs)
 
 
6
  # ------------------------
7
  GPUS = {
8
  "T4": {"FP32": 8.1, "FP16": 65.0, "INT4": 260.0},
9
  "RTX 3060": {"FP32": 13.0, "FP16": 26.0, "INT4": 52.0},
10
  "RTX 3090": {"FP32": 35.0, "FP16": 70.0, "INT4": 140.0},
11
  "RTX 4090": {"FP32": 83.0, "FP16": 166.0, "INT4": 332.0},
12
+ "A100 80GB": {"FP32": 19.5, "FP16": 150.0, "INT4": 600.0},
13
  "H100 SXM": {"FP32": 30.0, "FP16": 300.0, "INT4": 1200.0},
14
+ "Custom": {"FP32": 1.0, "FP16": 1.0, "INT4": 1.0},
15
  }
16
 
17
  # ------------------------
 
53
  dtype: str,
54
  tf_override: float,
55
  utilization_pct: float):
 
 
 
 
 
 
 
 
 
56
  if params_m <= 0 or tokens_b <= 0:
57
  return "Enter positive values for parameters and tokens."
58
 
59
  params = params_m * 1e6
60
  tokens = tokens_b * 1e9
61
 
 
62
  if tf_override is not None and tf_override > 0:
63
  chosen_tf = float(tf_override)
64
  source = "manual override"
65
  else:
 
66
  try:
67
  chosen_tf = float(GPUS[selected_gpu].get(dtype, 0.0))
68
  source = f"preset ({selected_gpu} / {dtype})"
69
  except Exception:
70
+ return "Couldn't determine GPU TFLOPs. Pick a GPU or enter TFLOPs manually."
 
71
 
72
  if chosen_tf <= 0:
73
  return "Couldn't determine GPU TFLOPs. Pick a GPU or enter TFLOPs manually."
74
 
 
 
75
  gpu_flops_per_sec = chosen_tf * 1e12 * (max(0.001, utilization_pct / 100.0))
76
 
77
+ flops_total = 6 * params * tokens
 
78
  seconds = flops_total / gpu_flops_per_sec
79
  hours = seconds / 3600.0
80
  days = hours / 24.0
81
 
 
82
  seq_len = 2048.0
83
  steps = max(1.0, tokens / seq_len)
84
  flops_per_step = flops_total / steps if steps > 0 else 0.0
 
98
  f"FLOPs / step (avg): {flops_per_step:.3e}",
99
  ]
100
 
 
101
  if tf_override and tf_override > 0 and selected_gpu != "Custom":
102
  out.append("")
103
+ out.append("⚠️ Note: you overrode the preset TFLOPs. Ensure the value is in TFLOPs (e.g., 150 for A100 FP16-like).")
104
 
105
  return "\n".join(out)
106
 
 
 
 
107
  def preset_tf_for_ui(selected_gpu: str, dtype: str):
 
 
108
  if selected_gpu in GPUS:
109
+ return GPUS[selected_gpu].get(dtype, 0.0)
110
+ return 0.0
111
+
112
+ # ------------------------
113
+ # Theme setter (returns HTML snippet to run client-side JS)
114
+ # ------------------------
115
+ def set_theme(theme_name: str):
116
+ # map button label -> class name used in CSS
117
+ cls = {
118
+ "Blue": "theme-blue",
119
+ "Green": "theme-green",
120
+ "Purple": "theme-purple",
121
+ }.get(theme_name, "theme-blue")
122
+ # return script that sets the root class
123
+ return f"<script>document.documentElement.className='{cls}';</script>"
124
 
125
  # ------------------------
126
  # Build UI
127
  # ------------------------
128
+ with gr.Blocks() as demo:
129
+ # initial theme set (runs immediately on load)
130
  gr.HTML("<script>document.documentElement.className='theme-blue';</script>")
131
 
132
  with gr.Column(elem_classes="card"):
133
+ with gr.Row():
134
  gr.Markdown("## 🧠 Roman’s Training Time Estimator")
135
+ with gr.Row():
136
  btn_blue = gr.Button("Blue", elem_classes="btn-theme")
137
  btn_green = gr.Button("Green", elem_classes="btn-theme")
138
  btn_purple = gr.Button("Purple", elem_classes="btn-theme")
139
 
 
140
  with gr.Column(elem_classes="card"):
141
  gr.Markdown("### Model & Hardware")
142
  with gr.Row():
 
149
  tf_override = gr.Number(value=preset_tf_for_ui("A100 80GB", "FP16"), label="GPU TFLOPs (teraFLOPs) — editable", precision=3)
150
  utilization = gr.Slider(minimum=1, maximum=100, value=80, step=1, label="Hardware Utilization (%) — realistic throughput")
151
 
 
152
  with gr.Column(elem_classes="card"):
153
  gr.Markdown("### Estimate")
154
  result = gr.Textbox(lines=12, interactive=False, elem_classes="result-box", label="Result")
155
  run_btn = gr.Button("Estimate Training Time", elem_classes="btn-theme")
156
 
157
+ # invisible HTML target used to inject theme-changing script
158
+ theme_script = gr.HTML(value="")
159
+
160
+ # update TF override when gpu/dtype change
161
  def _update_tf(selected_gpu, dtype):
162
  return gr.update(value=preset_tf_for_ui(selected_gpu, dtype))
 
163
  gpu_dropdown.change(_update_tf, inputs=[gpu_dropdown, dtype_dropdown], outputs=[tf_override])
164
  dtype_dropdown.change(_update_tf, inputs=[gpu_dropdown, dtype_dropdown], outputs=[tf_override])
165
 
166
+ # button clicks
167
  run_btn.click(estimate_time,
168
  inputs=[params, tokens, gpu_dropdown, dtype_dropdown, tf_override, utilization],
169
  outputs=[result])
170
 
171
+ # theme buttons now call the Python set_theme and return HTML that runs client-side
172
+ btn_blue.click(set_theme, inputs=["Blue"], outputs=[theme_script])
173
+ btn_green.click(set_theme, inputs=["Green"], outputs=[theme_script])
174
+ btn_purple.click(set_theme, inputs=["Purple"], outputs=[theme_script])
175
 
176
+ gr.HTML("<div class='small-muted'>Tip: GPU presets are TFLOPs per dtype. You can edit the TFLOPs number to override. Utilization reduces theoretical peak to realistic throughput.</div>")
177
 
178
+ # pass CSS to launch (Gradio 6.0+ API)
179
  if __name__ == "__main__":
180
+ demo.launch(css=CSS, title="Roman's Training Time Estimator")