Commit ·
761dcf3
1
Parent(s): d6d0f4e
Add patience clock chart to dashboard (inferred from val loss)
Browse files- pawn/dashboard/charts.py +54 -0
- pawn/dashboard/sol.py +10 -0
pawn/dashboard/charts.py
CHANGED
|
@@ -298,3 +298,57 @@ def val_accuracy_chart(records: list[dict], x_key: str, run_type: str):
|
|
| 298 |
("val_top5", "Val Top-5", COLORS["green"]),
|
| 299 |
]
|
| 300 |
return make_chart(records, x_key, specs, title="Validation Accuracy", y_title="Rate")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 298 |
("val_top5", "Val Top-5", COLORS["green"]),
|
| 299 |
]
|
| 300 |
return make_chart(records, x_key, specs, title="Validation Accuracy", y_title="Rate")
|
| 301 |
+
|
| 302 |
+
|
| 303 |
+
def patience_chart(val_records: list[dict], x_key: str = "step",
|
| 304 |
+
patience_limit: int = 10) -> "go.Figure":
|
| 305 |
+
"""Infer patience counter from val loss records and plot it.
|
| 306 |
+
|
| 307 |
+
Patience resets to 0 when val loss improves, increments by 1 otherwise.
|
| 308 |
+
Shows a horizontal line at the patience limit.
|
| 309 |
+
"""
|
| 310 |
+
import plotly.graph_objects as go
|
| 311 |
+
|
| 312 |
+
if not val_records:
|
| 313 |
+
fig = go.Figure()
|
| 314 |
+
fig.update_layout(**LAYOUT, title="Patience (early stopping)")
|
| 315 |
+
return fig
|
| 316 |
+
|
| 317 |
+
best_loss = float("inf")
|
| 318 |
+
steps = []
|
| 319 |
+
counters = []
|
| 320 |
+
|
| 321 |
+
counter = 0
|
| 322 |
+
for r in val_records:
|
| 323 |
+
vl = r.get("val/loss")
|
| 324 |
+
s = r.get(x_key)
|
| 325 |
+
if vl is None or s is None:
|
| 326 |
+
continue
|
| 327 |
+
if vl < best_loss:
|
| 328 |
+
best_loss = vl
|
| 329 |
+
counter = 0
|
| 330 |
+
else:
|
| 331 |
+
counter += 1
|
| 332 |
+
steps.append(s)
|
| 333 |
+
counters.append(counter)
|
| 334 |
+
|
| 335 |
+
fig = go.Figure()
|
| 336 |
+
fig.add_trace(go.Scatter(
|
| 337 |
+
x=steps, y=counters, mode="lines+markers",
|
| 338 |
+
name="Patience counter",
|
| 339 |
+
line=dict(color=COLORS["orange"], width=2),
|
| 340 |
+
marker=dict(size=4),
|
| 341 |
+
))
|
| 342 |
+
fig.add_hline(
|
| 343 |
+
y=patience_limit, line_dash="dash", line_color=COLORS["red"],
|
| 344 |
+
annotation_text=f"limit ({patience_limit})",
|
| 345 |
+
annotation_position="top left",
|
| 346 |
+
)
|
| 347 |
+
fig.update_layout(
|
| 348 |
+
**LAYOUT,
|
| 349 |
+
title="Patience (early stopping)",
|
| 350 |
+
xaxis_title=x_key.capitalize(),
|
| 351 |
+
yaxis_title="Evals without improvement",
|
| 352 |
+
yaxis=dict(range=[0, patience_limit + 2]),
|
| 353 |
+
)
|
| 354 |
+
return fig
|
pawn/dashboard/sol.py
CHANGED
|
@@ -327,6 +327,16 @@ def MetricsCharts():
|
|
| 327 |
with solara.Columns([1, 1]):
|
| 328 |
ChartWithInfo(charts.gpu_chart(train, x_key), desc("gpu"))
|
| 329 |
ChartWithInfo(charts.time_chart(train, x_key, run_type), desc("time"))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 330 |
|
| 331 |
|
| 332 |
@solara.component
|
|
|
|
| 327 |
with solara.Columns([1, 1]):
|
| 328 |
ChartWithInfo(charts.gpu_chart(train, x_key), desc("gpu"))
|
| 329 |
ChartWithInfo(charts.time_chart(train, x_key, run_type), desc("time"))
|
| 330 |
+
if val:
|
| 331 |
+
patience = config.get("training", {}).get("patience", 10)
|
| 332 |
+
if isinstance(patience, str):
|
| 333 |
+
patience = int(patience)
|
| 334 |
+
with solara.Columns([1]):
|
| 335 |
+
ChartWithInfo(
|
| 336 |
+
charts.patience_chart(val, x_key, patience_limit=patience),
|
| 337 |
+
"Consecutive evals without val loss improvement. "
|
| 338 |
+
"Training stops when this reaches the patience limit."
|
| 339 |
+
)
|
| 340 |
|
| 341 |
|
| 342 |
@solara.component
|