File size: 643 Bytes
d7b3a74 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 | import logging
from pathlib import Path
import torch
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
def save_debug_train_data(args, *, rollout_id, rollout_data):
    """Dump one rollout's training data to disk when debug saving is enabled.

    Does nothing when ``args.save_debug_train_data`` is ``None``. Otherwise
    that value is used as a ``str.format`` template, filled with the keyword
    fields ``rollout_id`` and ``rank``, and a dict with the keys
    ``rollout_id``, ``rank`` and ``rollout_data`` is written to the resulting
    path with ``torch.save``. Missing parent directories are created.

    Args:
        args: Object exposing a ``save_debug_train_data`` attribute (a path
            template string, or ``None`` to disable the dump).
        rollout_id: Value substituted into the path template and stored in
            the dump.
        rollout_data: Payload stored under the ``rollout_data`` key.
    """
    path_template = args.save_debug_train_data
    if path_template is None:
        return

    # get_rank() raises when no default process group has been initialised
    # (e.g. a single-process debug run), so fall back to rank 0 there instead
    # of crashing a helper that only exists for debugging.
    dist = torch.distributed
    if dist.is_available() and dist.is_initialized():
        rank = dist.get_rank()
    else:
        rank = 0

    path = Path(path_template.format(rollout_id=rollout_id, rank=rank))
    logger.info(f"Save debug train data to {path}")

    # The template may point into a directory that does not exist yet.
    path.parent.mkdir(parents=True, exist_ok=True)
    torch.save(
        dict(
            rollout_id=rollout_id,
            rank=rank,
            rollout_data=rollout_data,
        ),
        path,
    )
|