Commit ·
8bdee23
1
Parent(s): 82293e4
fix: Display warning when trainer_state.json is missing for model metrics
Browse files- app.py +3 -2
- gradio_wrapper.py +11 -6
app.py
CHANGED
|
@@ -78,6 +78,7 @@ with gr.Blocks(theme=gr.themes.Monochrome(), title="Multi-Class Classification (
|
|
| 78 |
choices=hf_models,
|
| 79 |
value="darklorddad/Model-Swin-Transformer-88"
|
| 80 |
)
|
|
|
|
| 81 |
with gr.Column(visible=False) as inf_plots_container:
|
| 82 |
with gr.Row():
|
| 83 |
inf_plot_loss = gr.Plot(label="Loss")
|
|
@@ -105,12 +106,12 @@ with gr.Blocks(theme=gr.themes.Monochrome(), title="Multi-Class Classification (
|
|
| 105 |
metrics_model_path.change(
|
| 106 |
fn=show_model_charts,
|
| 107 |
inputs=[metrics_model_path],
|
| 108 |
-
outputs=inf_plots + [inf_plots_container]
|
| 109 |
)
|
| 110 |
demo.load(
|
| 111 |
fn=show_model_charts,
|
| 112 |
inputs=[metrics_model_path],
|
| 113 |
-
outputs=inf_plots + [inf_plots_container]
|
| 114 |
)
|
| 115 |
with gr.Tab("Data Preparation"):
|
| 116 |
with gr.Accordion("Organise Raw Dataset", open=False):
|
|
|
|
| 78 |
choices=hf_models,
|
| 79 |
value="darklorddad/Model-Swin-Transformer-88"
|
| 80 |
)
|
| 81 |
+
metrics_status = gr.Markdown(visible=False)
|
| 82 |
with gr.Column(visible=False) as inf_plots_container:
|
| 83 |
with gr.Row():
|
| 84 |
inf_plot_loss = gr.Plot(label="Loss")
|
|
|
|
| 106 |
metrics_model_path.change(
|
| 107 |
fn=show_model_charts,
|
| 108 |
inputs=[metrics_model_path],
|
| 109 |
+
outputs=inf_plots + [inf_plots_container, metrics_status]
|
| 110 |
)
|
| 111 |
demo.load(
|
| 112 |
fn=show_model_charts,
|
| 113 |
inputs=[metrics_model_path],
|
| 114 |
+
outputs=inf_plots + [inf_plots_container, metrics_status]
|
| 115 |
)
|
| 116 |
with gr.Tab("Data Preparation"):
|
| 117 |
with gr.Accordion("Organise Raw Dataset", open=False):
|
gradio_wrapper.py
CHANGED
|
@@ -268,7 +268,8 @@ def show_model_charts(model_path):
|
|
| 268 |
and returns metric plots.
|
| 269 |
"""
|
| 270 |
if not model_path:
|
| 271 |
-
|
|
|
|
| 272 |
|
| 273 |
search_path = model_path
|
| 274 |
if not os.path.isdir(model_path):
|
|
@@ -277,11 +278,13 @@ def show_model_charts(model_path):
|
|
| 277 |
search_path = snapshot_download(repo_id=model_path)
|
| 278 |
print(f"Model '{model_path}' downloaded to: {search_path}")
|
| 279 |
except RepositoryNotFoundError:
|
|
|
|
| 280 |
print(f"Error: Hugging Face model repository not found: {model_path}")
|
| 281 |
-
return (None,) * 11 + (gr.update(visible=False),)
|
| 282 |
except Exception as e:
|
|
|
|
| 283 |
print(f"An error occurred while downloading the model '{model_path}': {e}")
|
| 284 |
-
return (None,) * 11 + (gr.update(visible=False),)
|
| 285 |
|
| 286 |
json_path = None
|
| 287 |
for root, _, files in os.walk(search_path):
|
|
@@ -290,8 +293,9 @@ def show_model_charts(model_path):
|
|
| 290 |
break
|
| 291 |
|
| 292 |
if not json_path:
|
|
|
|
| 293 |
print(f"trainer_state.json not found in '{search_path}'")
|
| 294 |
-
return (None,) * 11 + (gr.update(visible=False),)
|
| 295 |
|
| 296 |
print(f"Found trainer_state.json at: {json_path}")
|
| 297 |
try:
|
|
@@ -301,11 +305,12 @@ def show_model_charts(model_path):
|
|
| 301 |
figures.get('Gradient Norm'), figures.get('F1 Scores'), figures.get('Precision'),
|
| 302 |
figures.get('Recall'), figures.get('Epoch'), figures.get('Eval Runtime'),
|
| 303 |
figures.get('Eval Samples/sec'), figures.get('Eval Steps/sec'),
|
| 304 |
-
gr.update(visible=True),
|
| 305 |
)
|
| 306 |
except Exception as e:
|
|
|
|
| 307 |
print(f"Error generating plots for {json_path}: {e}")
|
| 308 |
-
return (None,) * 11 + (gr.update(visible=False),)
|
| 309 |
|
| 310 |
|
| 311 |
def run_plot_metrics(json_path):
|
|
|
|
| 268 |
and returns metric plots.
|
| 269 |
"""
|
| 270 |
if not model_path:
|
| 271 |
+
no_model_msg = "### Please select a model."
|
| 272 |
+
return (None,) * 11 + (gr.update(visible=False), gr.update(value=no_model_msg, visible=True))
|
| 273 |
|
| 274 |
search_path = model_path
|
| 275 |
if not os.path.isdir(model_path):
|
|
|
|
| 278 |
search_path = snapshot_download(repo_id=model_path)
|
| 279 |
print(f"Model '{model_path}' downloaded to: {search_path}")
|
| 280 |
except RepositoryNotFoundError:
|
| 281 |
+
msg = f"### ⚠️ Error\n\nHugging Face model repository not found: `{model_path}`"
|
| 282 |
print(f"Error: Hugging Face model repository not found: {model_path}")
|
| 283 |
+
return (None,) * 11 + (gr.update(visible=False), gr.update(value=msg, visible=True))
|
| 284 |
except Exception as e:
|
| 285 |
+
msg = f"### ⚠️ Error\n\nAn error occurred while downloading the model `{model_path}`: {e}"
|
| 286 |
print(f"An error occurred while downloading the model '{model_path}': {e}")
|
| 287 |
+
return (None,) * 11 + (gr.update(visible=False), gr.update(value=msg, visible=True))
|
| 288 |
|
| 289 |
json_path = None
|
| 290 |
for root, _, files in os.walk(search_path):
|
|
|
|
| 293 |
break
|
| 294 |
|
| 295 |
if not json_path:
|
| 296 |
+
msg = f"### ⚠️ Warning\n\n`trainer_state.json` not found for model `{model_path}`. Cannot display training metrics."
|
| 297 |
print(f"trainer_state.json not found in '{search_path}'")
|
| 298 |
+
return (None,) * 11 + (gr.update(visible=False), gr.update(value=msg, visible=True))
|
| 299 |
|
| 300 |
print(f"Found trainer_state.json at: {json_path}")
|
| 301 |
try:
|
|
|
|
| 305 |
figures.get('Gradient Norm'), figures.get('F1 Scores'), figures.get('Precision'),
|
| 306 |
figures.get('Recall'), figures.get('Epoch'), figures.get('Eval Runtime'),
|
| 307 |
figures.get('Eval Samples/sec'), figures.get('Eval Steps/sec'),
|
| 308 |
+
gr.update(visible=True), gr.update(visible=False)
|
| 309 |
)
|
| 310 |
except Exception as e:
|
| 311 |
+
msg = f"### ⚠️ Error\n\nAn error occurred while generating plots for `{model_path}`: {e}"
|
| 312 |
print(f"Error generating plots for {json_path}: {e}")
|
| 313 |
+
return (None,) * 11 + (gr.update(visible=False), gr.update(value=msg, visible=True))
|
| 314 |
|
| 315 |
|
| 316 |
def run_plot_metrics(json_path):
|