Spaces:
Running
Running
Commit ·
9553b06
1
Parent(s): 4921bbf
add oom predictor
Browse files
app.py
CHANGED
|
@@ -71,9 +71,10 @@ def format_config_display(config):
|
|
| 71 |
|
| 72 |
def process_yaml_and_plot(config):
|
| 73 |
if not config:
|
| 74 |
-
return None, None, "No configuration loaded"
|
| 75 |
-
fig1, fig2 = plot_memory_breakdown(**config)
|
| 76 |
-
|
|
|
|
| 77 |
|
| 78 |
with gr.Blocks() as demo:
|
| 79 |
with gr.Row():
|
|
@@ -109,9 +110,9 @@ with gr.Blocks() as demo:
|
|
| 109 |
zero_stage = gr.Radio([0, 1], value=0, label="ZeRO Stage")
|
| 110 |
|
| 111 |
manual_submit = gr.Button("Calculate Memory (Manual Input)")
|
| 112 |
-
|
| 113 |
with gr.Column(scale=2):
|
| 114 |
config_display = gr.Markdown(label="Configuration Values")
|
|
|
|
| 115 |
plot1 = gr.Plot(label="Memory Component Breakdown")
|
| 116 |
plot2 = gr.Plot(label="Aggregate Memory Metrics")
|
| 117 |
|
|
@@ -119,14 +120,14 @@ with gr.Blocks() as demo:
|
|
| 119 |
yaml_file.change(
|
| 120 |
lambda x: process_yaml_and_plot(load_config_from_yaml_file(x) if x else None),
|
| 121 |
inputs=[yaml_file],
|
| 122 |
-
outputs=[plot1, plot2, config_display]
|
| 123 |
)
|
| 124 |
|
| 125 |
# Handle YAML text input
|
| 126 |
yaml_submit.click(
|
| 127 |
lambda x: process_yaml_and_plot(load_config_from_yaml_content(x) if x else None),
|
| 128 |
inputs=[yaml_text],
|
| 129 |
-
outputs=[plot1, plot2, config_display]
|
| 130 |
)
|
| 131 |
|
| 132 |
# Handle manual input
|
|
@@ -156,7 +157,7 @@ with gr.Blocks() as demo:
|
|
| 156 |
seq_len, mbs, batch_accum, tp, pp, dp, zero_stage,
|
| 157 |
tie_word_embeddings
|
| 158 |
],
|
| 159 |
-
outputs=[plot1, plot2, config_display]
|
| 160 |
)
|
| 161 |
|
| 162 |
if __name__ == "__main__":
|
|
|
|
| 71 |
|
| 72 |
def process_yaml_and_plot(config):
|
| 73 |
if not config:
|
| 74 |
+
return None, None, "No configuration loaded", None
|
| 75 |
+
fig1, fig2, memory_usage_peak_tbi = plot_memory_breakdown(**config)
|
| 76 |
+
oom_prediction = "OOM" if memory_usage_peak_tbi > 75000 else "No OOM"
|
| 77 |
+
return fig1, fig2, format_config_display(config), oom_prediction
|
| 78 |
|
| 79 |
with gr.Blocks() as demo:
|
| 80 |
with gr.Row():
|
|
|
|
| 110 |
zero_stage = gr.Radio([0, 1], value=0, label="ZeRO Stage")
|
| 111 |
|
| 112 |
manual_submit = gr.Button("Calculate Memory (Manual Input)")
|
|
|
|
| 113 |
with gr.Column(scale=2):
|
| 114 |
config_display = gr.Markdown(label="Configuration Values")
|
| 115 |
+
oom_display = gr.Text(label="OOM Prediction")
|
| 116 |
plot1 = gr.Plot(label="Memory Component Breakdown")
|
| 117 |
plot2 = gr.Plot(label="Aggregate Memory Metrics")
|
| 118 |
|
|
|
|
| 120 |
yaml_file.change(
|
| 121 |
lambda x: process_yaml_and_plot(load_config_from_yaml_file(x) if x else None),
|
| 122 |
inputs=[yaml_file],
|
| 123 |
+
outputs=[plot1, plot2, config_display, oom_display]
|
| 124 |
)
|
| 125 |
|
| 126 |
# Handle YAML text input
|
| 127 |
yaml_submit.click(
|
| 128 |
lambda x: process_yaml_and_plot(load_config_from_yaml_content(x) if x else None),
|
| 129 |
inputs=[yaml_text],
|
| 130 |
+
outputs=[plot1, plot2, config_display, oom_display]
|
| 131 |
)
|
| 132 |
|
| 133 |
# Handle manual input
|
|
|
|
| 157 |
seq_len, mbs, batch_accum, tp, pp, dp, zero_stage,
|
| 158 |
tie_word_embeddings
|
| 159 |
],
|
| 160 |
+
outputs=[plot1, plot2, config_display, oom_display]
|
| 161 |
)
|
| 162 |
|
| 163 |
if __name__ == "__main__":
|