nouamanetazi HF Staff commited on
Commit
9553b06
·
1 Parent(s): 4921bbf

add oom predictor

Browse files
Files changed (1) hide show
  1. app.py +8 -7
app.py CHANGED
@@ -71,9 +71,10 @@ def format_config_display(config):
71
 
72
  def process_yaml_and_plot(config):
73
  if not config:
74
- return None, None, "No configuration loaded"
75
- fig1, fig2 = plot_memory_breakdown(**config)
76
- return fig1, fig2, format_config_display(config)
 
77
 
78
  with gr.Blocks() as demo:
79
  with gr.Row():
@@ -109,9 +110,9 @@ with gr.Blocks() as demo:
109
  zero_stage = gr.Radio([0, 1], value=0, label="ZeRO Stage")
110
 
111
  manual_submit = gr.Button("Calculate Memory (Manual Input)")
112
-
113
  with gr.Column(scale=2):
114
  config_display = gr.Markdown(label="Configuration Values")
 
115
  plot1 = gr.Plot(label="Memory Component Breakdown")
116
  plot2 = gr.Plot(label="Aggregate Memory Metrics")
117
 
@@ -119,14 +120,14 @@ with gr.Blocks() as demo:
119
  yaml_file.change(
120
  lambda x: process_yaml_and_plot(load_config_from_yaml_file(x) if x else None),
121
  inputs=[yaml_file],
122
- outputs=[plot1, plot2, config_display]
123
  )
124
 
125
  # Handle YAML text input
126
  yaml_submit.click(
127
  lambda x: process_yaml_and_plot(load_config_from_yaml_content(x) if x else None),
128
  inputs=[yaml_text],
129
- outputs=[plot1, plot2, config_display]
130
  )
131
 
132
  # Handle manual input
@@ -156,7 +157,7 @@ with gr.Blocks() as demo:
156
  seq_len, mbs, batch_accum, tp, pp, dp, zero_stage,
157
  tie_word_embeddings
158
  ],
159
- outputs=[plot1, plot2, config_display]
160
  )
161
 
162
  if __name__ == "__main__":
 
71
 
72
  def process_yaml_and_plot(config):
73
  if not config:
74
+ return None, None, "No configuration loaded", None
75
+ fig1, fig2, memory_usage_peak_tbi = plot_memory_breakdown(**config)
76
+ oom_prediction = "OOM" if memory_usage_peak_tbi > 75000 else "No OOM"
77
+ return fig1, fig2, format_config_display(config), oom_prediction
78
 
79
  with gr.Blocks() as demo:
80
  with gr.Row():
 
110
  zero_stage = gr.Radio([0, 1], value=0, label="ZeRO Stage")
111
 
112
  manual_submit = gr.Button("Calculate Memory (Manual Input)")
 
113
  with gr.Column(scale=2):
114
  config_display = gr.Markdown(label="Configuration Values")
115
+ oom_display = gr.Text(label="OOM Prediction")
116
  plot1 = gr.Plot(label="Memory Component Breakdown")
117
  plot2 = gr.Plot(label="Aggregate Memory Metrics")
118
 
 
120
  yaml_file.change(
121
  lambda x: process_yaml_and_plot(load_config_from_yaml_file(x) if x else None),
122
  inputs=[yaml_file],
123
+ outputs=[plot1, plot2, config_display, oom_display]
124
  )
125
 
126
  # Handle YAML text input
127
  yaml_submit.click(
128
  lambda x: process_yaml_and_plot(load_config_from_yaml_content(x) if x else None),
129
  inputs=[yaml_text],
130
+ outputs=[plot1, plot2, config_display, oom_display]
131
  )
132
 
133
  # Handle manual input
 
157
  seq_len, mbs, batch_accum, tp, pp, dp, zero_stage,
158
  tie_word_embeddings
159
  ],
160
+ outputs=[plot1, plot2, config_display, oom_display]
161
  )
162
 
163
  if __name__ == "__main__":