thomas-schweich commited on
Commit
761dcf3
·
1 Parent(s): d6d0f4e

Add patience clock chart to dashboard (inferred from val loss)

Browse files
Files changed (2) hide show
  1. pawn/dashboard/charts.py +54 -0
  2. pawn/dashboard/sol.py +10 -0
pawn/dashboard/charts.py CHANGED
@@ -298,3 +298,57 @@ def val_accuracy_chart(records: list[dict], x_key: str, run_type: str):
298
  ("val_top5", "Val Top-5", COLORS["green"]),
299
  ]
300
  return make_chart(records, x_key, specs, title="Validation Accuracy", y_title="Rate")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
298
  ("val_top5", "Val Top-5", COLORS["green"]),
299
  ]
300
  return make_chart(records, x_key, specs, title="Validation Accuracy", y_title="Rate")
301
+
302
+
303
+ def patience_chart(val_records: list[dict], x_key: str = "step",
304
+ patience_limit: int = 10) -> "go.Figure":
305
+ """Infer patience counter from val loss records and plot it.
306
+
307
+ Patience resets to 0 when val loss improves, increments by 1 otherwise.
308
+ Shows a horizontal line at the patience limit.
309
+ """
310
+ import plotly.graph_objects as go
311
+
312
+ if not val_records:
313
+ fig = go.Figure()
314
+ fig.update_layout(**LAYOUT, title="Patience (early stopping)")
315
+ return fig
316
+
317
+ best_loss = float("inf")
318
+ steps = []
319
+ counters = []
320
+
321
+ counter = 0
322
+ for r in val_records:
323
+ vl = r.get("val/loss")
324
+ s = r.get(x_key)
325
+ if vl is None or s is None:
326
+ continue
327
+ if vl < best_loss:
328
+ best_loss = vl
329
+ counter = 0
330
+ else:
331
+ counter += 1
332
+ steps.append(s)
333
+ counters.append(counter)
334
+
335
+ fig = go.Figure()
336
+ fig.add_trace(go.Scatter(
337
+ x=steps, y=counters, mode="lines+markers",
338
+ name="Patience counter",
339
+ line=dict(color=COLORS["orange"], width=2),
340
+ marker=dict(size=4),
341
+ ))
342
+ fig.add_hline(
343
+ y=patience_limit, line_dash="dash", line_color=COLORS["red"],
344
+ annotation_text=f"limit ({patience_limit})",
345
+ annotation_position="top left",
346
+ )
347
+ fig.update_layout(
348
+ **LAYOUT,
349
+ title="Patience (early stopping)",
350
+ xaxis_title=x_key.capitalize(),
351
+ yaxis_title="Evals without improvement",
352
+ yaxis=dict(range=[0, patience_limit + 2]),
353
+ )
354
+ return fig
pawn/dashboard/sol.py CHANGED
@@ -327,6 +327,16 @@ def MetricsCharts():
327
  with solara.Columns([1, 1]):
328
  ChartWithInfo(charts.gpu_chart(train, x_key), desc("gpu"))
329
  ChartWithInfo(charts.time_chart(train, x_key, run_type), desc("time"))
 
 
 
 
 
 
 
 
 
 
330
 
331
 
332
  @solara.component
 
327
  with solara.Columns([1, 1]):
328
  ChartWithInfo(charts.gpu_chart(train, x_key), desc("gpu"))
329
  ChartWithInfo(charts.time_chart(train, x_key, run_type), desc("time"))
330
+ if val:
331
+ patience = config.get("training", {}).get("patience", 10)
332
+ if isinstance(patience, str):
333
+ patience = int(patience)
334
+ with solara.Columns([1]):
335
+ ChartWithInfo(
336
+ charts.patience_chart(val, x_key, patience_limit=patience),
337
+ "Consecutive evals without val loss improvement. "
338
+ "Training stops when this reaches the patience limit."
339
+ )
340
 
341
 
342
  @solara.component