Spaces:
Running on Zero
Running on Zero
File size: 1,837 Bytes
import re
import pandas as pd
from typing import Callable, Optional, Tuple
def extract_epoch_step_from_checkpoint_str(s: str) -> Tuple[float, float]:
    """Pull the ``epoch=<number>`` and ``step=<number>`` values out of a checkpoint name.

    Each number may be an integer or a decimal (``3`` or ``3.5``). A field that
    does not occur in ``s`` is reported as NaN; if ``s`` itself is missing
    (``pd.isna``), both fields are NaN.

    Returns:
        A ``(epoch, step)`` tuple of floats.
    """
    # Missing names (None / NaN) carry no information at all.
    if pd.isna(s):
        return float('nan'), float('nan')

    values = []
    # Order matters: epoch first, then step, to match the returned tuple.
    for pattern in (
        r'epoch=([0-9]+(?:\.[0-9]+)?)',
        r'step=([0-9]+(?:\.[0-9]+)?)',
    ):
        hit = re.search(pattern, s)
        if hit is None:
            values.append(float('nan'))
        else:
            values.append(float(hit.group(1)))

    epoch_value, step_value = values
    return epoch_value, step_value
def _first_present(payload, keys):
    """Return the first non-None value stored under any of ``keys``, else None."""
    for key in keys:
        value = payload.get(key)
        if value is not None:
            return value
    return None


def _read_last_checkpoint_position(
    load_last_ckpt_fn: Optional[Callable[[], dict]],
) -> Optional[Tuple[float, float]]:
    """Return ``(epoch, step)`` from the loader's payload, or None if unusable."""
    # No loader supplied is the documented default, not an error: stay quiet.
    if load_last_ckpt_fn is None:
        return None

    # Function-scope import keeps this block self-contained.
    import logging
    logger = logging.getLogger(__name__)

    try:
        payload = load_last_ckpt_fn()
        epoch = _first_present(payload, ('epoch', 'epoch_idx'))
        step = _first_present(payload, ('global_step', 'step'))
        if epoch is None or step is None:
            logger.warning(
                "'last' checkpoint payload is missing epoch/global_step; "
                "using fallback position"
            )
            return None
        # Convert both values before returning so the caller never writes a
        # half-resolved position into the frame.
        return float(epoch), float(step)
    except Exception:
        # Deliberate best-effort: the loader is an arbitrary callable and may
        # fail in any way. Record the failure instead of hiding it, then let
        # the caller fall back.
        logger.warning(
            "Could not read epoch/step from the 'last' checkpoint; "
            "using fallback position",
            exc_info=True,
        )
        return None


def resolve_last_checkpoint_positions(
    df: pd.DataFrame,
    load_last_ckpt_fn: Optional[Callable[[], dict]] = None,
) -> pd.DataFrame:
    """Assign epoch/step to rows whose checkpoint name refers to 'last'.

    A row is selected when the string form of its ``checkpoint`` value contains
    ``last`` as a whole word (regex ``\\blast\\b``), so names such as ``last``
    or ``last.ckpt`` both match.

    The position is taken from the payload returned by ``load_last_ckpt_fn``:
    epoch from ``'epoch'`` (or ``'epoch_idx'``), step from ``'global_step'``
    (or ``'step'``). If no loader is given, the loader raises, or the payload
    lacks usable values, the selected rows are placed one past the current
    maximum non-NaN epoch and step in ``df`` (treating an all-NaN column as 0).

    Args:
        df: Frame with ``checkpoint``, ``epoch`` and ``step`` columns.
            It is modified in place.
        load_last_ckpt_fn: Optional zero-argument callable returning the
            checkpoint payload as a dict-like object.

    Returns:
        The same ``df`` object that was passed in.

    Raises:
        KeyError: If ``df`` lacks the ``checkpoint`` column, or contains
            matching rows but lacks the ``epoch`` or ``step`` column.
    """
    last_mask = df['checkpoint'].astype(str).str.contains(r'\blast\b', na=False)
    if not last_mask.any():
        return df

    # Fallback position: one past the largest known epoch/step.
    max_epoch = df['epoch'].dropna().max()
    max_step = df['step'].dropna().max()
    max_epoch = 0.0 if pd.isna(max_epoch) else max_epoch
    max_step = 0.0 if pd.isna(max_step) else max_step

    position = _read_last_checkpoint_position(load_last_ckpt_fn)
    if position is None:
        position = (max_epoch + 1.0, max_step + 1.0)

    epoch_value, step_value = position
    df.loc[last_mask, 'epoch'] = epoch_value
    df.loc[last_mask, 'step'] = step_value
    return df
|