darklorddad commited on
Commit
8bdee23
·
1 Parent(s): 82293e4

fix: Display warning when trainer_state.json is missing for model metrics

Browse files
Files changed (2) hide show
  1. app.py +3 -2
  2. gradio_wrapper.py +11 -6
app.py CHANGED
@@ -78,6 +78,7 @@ with gr.Blocks(theme=gr.themes.Monochrome(), title="Multi-Class Classification (
78
  choices=hf_models,
79
  value="darklorddad/Model-Swin-Transformer-88"
80
  )
 
81
  with gr.Column(visible=False) as inf_plots_container:
82
  with gr.Row():
83
  inf_plot_loss = gr.Plot(label="Loss")
@@ -105,12 +106,12 @@ with gr.Blocks(theme=gr.themes.Monochrome(), title="Multi-Class Classification (
105
  metrics_model_path.change(
106
  fn=show_model_charts,
107
  inputs=[metrics_model_path],
108
- outputs=inf_plots + [inf_plots_container]
109
  )
110
  demo.load(
111
  fn=show_model_charts,
112
  inputs=[metrics_model_path],
113
- outputs=inf_plots + [inf_plots_container]
114
  )
115
  with gr.Tab("Data Preparation"):
116
  with gr.Accordion("Organise Raw Dataset", open=False):
 
78
  choices=hf_models,
79
  value="darklorddad/Model-Swin-Transformer-88"
80
  )
81
+ metrics_status = gr.Markdown(visible=False)
82
  with gr.Column(visible=False) as inf_plots_container:
83
  with gr.Row():
84
  inf_plot_loss = gr.Plot(label="Loss")
 
106
  metrics_model_path.change(
107
  fn=show_model_charts,
108
  inputs=[metrics_model_path],
109
+ outputs=inf_plots + [inf_plots_container, metrics_status]
110
  )
111
  demo.load(
112
  fn=show_model_charts,
113
  inputs=[metrics_model_path],
114
+ outputs=inf_plots + [inf_plots_container, metrics_status]
115
  )
116
  with gr.Tab("Data Preparation"):
117
  with gr.Accordion("Organise Raw Dataset", open=False):
gradio_wrapper.py CHANGED
@@ -268,7 +268,8 @@ def show_model_charts(model_path):
268
  and returns metric plots.
269
  """
270
  if not model_path:
271
- return (None,) * 11 + (gr.update(visible=False),)
 
272
 
273
  search_path = model_path
274
  if not os.path.isdir(model_path):
@@ -277,11 +278,13 @@ def show_model_charts(model_path):
277
  search_path = snapshot_download(repo_id=model_path)
278
  print(f"Model '{model_path}' downloaded to: {search_path}")
279
  except RepositoryNotFoundError:
 
280
  print(f"Error: Hugging Face model repository not found: {model_path}")
281
- return (None,) * 11 + (gr.update(visible=False),)
282
  except Exception as e:
 
283
  print(f"An error occurred while downloading the model '{model_path}': {e}")
284
- return (None,) * 11 + (gr.update(visible=False),)
285
 
286
  json_path = None
287
  for root, _, files in os.walk(search_path):
@@ -290,8 +293,9 @@ def show_model_charts(model_path):
290
  break
291
 
292
  if not json_path:
 
293
  print(f"trainer_state.json not found in '{search_path}'")
294
- return (None,) * 11 + (gr.update(visible=False),)
295
 
296
  print(f"Found trainer_state.json at: {json_path}")
297
  try:
@@ -301,11 +305,12 @@ def show_model_charts(model_path):
301
  figures.get('Gradient Norm'), figures.get('F1 Scores'), figures.get('Precision'),
302
  figures.get('Recall'), figures.get('Epoch'), figures.get('Eval Runtime'),
303
  figures.get('Eval Samples/sec'), figures.get('Eval Steps/sec'),
304
- gr.update(visible=True),
305
  )
306
  except Exception as e:
 
307
  print(f"Error generating plots for {json_path}: {e}")
308
- return (None,) * 11 + (gr.update(visible=False),)
309
 
310
 
311
  def run_plot_metrics(json_path):
 
268
  and returns metric plots.
269
  """
270
  if not model_path:
271
+ no_model_msg = "### Please select a model."
272
+ return (None,) * 11 + (gr.update(visible=False), gr.update(value=no_model_msg, visible=True))
273
 
274
  search_path = model_path
275
  if not os.path.isdir(model_path):
 
278
  search_path = snapshot_download(repo_id=model_path)
279
  print(f"Model '{model_path}' downloaded to: {search_path}")
280
  except RepositoryNotFoundError:
281
+ msg = f"### ⚠️ Error\n\nHugging Face model repository not found: `{model_path}`"
282
  print(f"Error: Hugging Face model repository not found: {model_path}")
283
+ return (None,) * 11 + (gr.update(visible=False), gr.update(value=msg, visible=True))
284
  except Exception as e:
285
+ msg = f"### ⚠️ Error\n\nAn error occurred while downloading the model `{model_path}`: {e}"
286
  print(f"An error occurred while downloading the model '{model_path}': {e}")
287
+ return (None,) * 11 + (gr.update(visible=False), gr.update(value=msg, visible=True))
288
 
289
  json_path = None
290
  for root, _, files in os.walk(search_path):
 
293
  break
294
 
295
  if not json_path:
296
+ msg = f"### ⚠️ Warning\n\n`trainer_state.json` not found for model `{model_path}`. Cannot display training metrics."
297
  print(f"trainer_state.json not found in '{search_path}'")
298
+ return (None,) * 11 + (gr.update(visible=False), gr.update(value=msg, visible=True))
299
 
300
  print(f"Found trainer_state.json at: {json_path}")
301
  try:
 
305
  figures.get('Gradient Norm'), figures.get('F1 Scores'), figures.get('Precision'),
306
  figures.get('Recall'), figures.get('Epoch'), figures.get('Eval Runtime'),
307
  figures.get('Eval Samples/sec'), figures.get('Eval Steps/sec'),
308
+ gr.update(visible=True), gr.update(visible=False)
309
  )
310
  except Exception as e:
311
+ msg = f"### ⚠️ Error\n\nAn error occurred while generating plots for `{model_path}`: {e}"
312
  print(f"Error generating plots for {json_path}: {e}")
313
+ return (None,) * 11 + (gr.update(visible=False), gr.update(value=msg, visible=True))
314
 
315
 
316
  def run_plot_metrics(json_path):