File size: 1,837 Bytes
0f5513d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
import logging
import re
from typing import Callable, Optional, Tuple

import pandas as pd


def extract_epoch_step_from_checkpoint_str(s: str) -> Tuple[float, float]:
    """Parse the ``epoch=`` and ``step=`` numbers out of a checkpoint name.

    Args:
        s: Checkpoint name, e.g. ``"epoch=3-step=1200.ckpt"``. May be a pandas
            missing value (``None`` / ``NaN``).

    Returns:
        ``(epoch, step)`` as floats. An element is NaN when its ``key=`` token
        is not found; both are NaN when *s* is missing.
    """
    nan = float('nan')
    if pd.isna(s):
        return nan, nan

    def _first_number(pattern: str) -> float:
        # Value of the first capture group of the first match, or NaN.
        found = re.search(pattern, s)
        return nan if found is None else float(found.group(1))

    epoch = _first_number(r'epoch=([0-9]+(?:\.[0-9]+)?)')
    step = _first_number(r'step=([0-9]+(?:\.[0-9]+)?)')
    return epoch, step


def _first_non_none(mapping: dict, keys: Tuple[str, ...]):
    """Return the first value in *mapping* under *keys* that is not None, else None."""
    for key in keys:
        value = mapping.get(key)
        if value is not None:
            return value
    return None


def _epoch_step_from_loader(
    load_last_ckpt_fn: Optional[Callable[[], dict]],
) -> Optional[Tuple[float, float]]:
    """Return ``(epoch, step)`` read from the 'last' checkpoint payload, or None.

    None means the position could not be determined: no loader was given,
    the loader raised, the payload has no non-None epoch/step under the
    accepted keys, or a value could not be converted to float.
    """
    if load_last_ckpt_fn is None:
        return None
    logger = logging.getLogger(__name__)
    try:
        ckpt = load_last_ckpt_fn()
        epoch = _first_non_none(ckpt, ('epoch', 'epoch_idx'))
        step = _first_non_none(ckpt, ('global_step', 'step'))
        if epoch is None or step is None:
            logger.warning(
                "Last checkpoint payload lacks epoch/global_step; "
                "placing 'last' after the other checkpoints.")
            return None
        return float(epoch), float(step)
    except Exception:
        # Deliberately broad: the loader is caller-supplied and may raise any
        # exception type; resolving the 'last' position is best-effort.
        logger.warning(
            "Could not read epoch/step from the last checkpoint; "
            "placing 'last' after the other checkpoints.", exc_info=True)
        return None


def _epoch_step_after_others(
    df: pd.DataFrame,
    last_mask: pd.Series,
) -> Tuple[float, float]:
    """Return ``(max epoch + 1, max step + 1)`` over the non-'last' rows.

    A missing maximum (no other rows, or all NaN) counts as 0. The 'last' rows
    themselves are excluded so that repeated calls yield the same position.
    """
    others = df.loc[~last_mask]
    max_epoch = others['epoch'].dropna().max()
    max_step = others['step'].dropna().max()
    max_epoch = 0.0 if pd.isna(max_epoch) else float(max_epoch)
    max_step = 0.0 if pd.isna(max_step) else float(max_step)
    return max_epoch + 1.0, max_step + 1.0


def resolve_last_checkpoint_positions(
    df: pd.DataFrame,
    load_last_ckpt_fn: Optional[Callable[[], dict]] = None,
) -> pd.DataFrame:
    """For rows whose checkpoint name is 'last', fill in their epoch/step.

    Rows are selected when the ``checkpoint`` column contains the word
    ``last`` (regex ``\\blast\\b``). Their ``epoch``/``step`` are taken from the
    payload returned by *load_last_ckpt_fn*. The epoch is read from
    ``'epoch'``, else ``'epoch_idx'``; the step from ``'global_step'``, else
    ``'step'``.

    If no loader is given, the loader fails, or the payload lacks those keys,
    the 'last' rows are placed one past the maximum epoch/step of the *other*
    rows (0 when there is none). A warning is logged when a provided loader
    fails.

    Args:
        df: Frame with ``checkpoint``, ``epoch`` and ``step`` columns.
            It is modified in place.
        load_last_ckpt_fn: Optional zero-argument callable returning the
            'last' checkpoint payload as a mapping.

    Returns:
        The same *df* object, updated in place.
    """
    last_mask = df['checkpoint'].astype(str).str.contains(r'\blast\b', na=False)
    if not last_mask.any():
        return df

    position = _epoch_step_from_loader(load_last_ckpt_fn)
    if position is None:
        position = _epoch_step_after_others(df, last_mask)
    epoch, step = position

    df.loc[last_mask, 'epoch'] = epoch
    df.loc[last_mask, 'step'] = step
    return df