| import logging | |
| from pathlib import Path | |
| import torch | |
| logger = logging.getLogger(__name__) | |
def save_debug_train_data(args, *, rollout_id, rollout_data):
    """Write one rollout's training data to disk for debugging.

    Does nothing when ``args.save_debug_train_data`` is ``None``. Otherwise
    that attribute is used as a ``str.format`` template that receives the
    keyword fields ``rollout_id`` and ``rank`` (the value returned by
    ``torch.distributed.get_rank()``). Missing parent directories of the
    formatted path are created, and a dict with the keys ``rollout_id``,
    ``rank`` and ``rollout_data`` is written there via ``torch.save``.

    Args:
        args: Object exposing a ``save_debug_train_data`` attribute (the
            path template, or ``None`` to disable dumping).
        rollout_id: Identifier substituted into the path and stored in the
            saved dict.
        rollout_data: Object stored under the ``rollout_data`` key.
    """
    template = args.save_debug_train_data
    if template is None:
        # Debug dumping is disabled; nothing to write.
        return

    rank = torch.distributed.get_rank()
    target = Path(template.format(rollout_id=rollout_id, rank=rank))
    logger.info(f"Save debug train data to {target}")

    # The template may point into a directory that does not exist yet.
    target.parent.mkdir(parents=True, exist_ok=True)

    payload = {
        "rollout_id": rollout_id,
        "rank": rank,
        "rollout_data": rollout_data,
    }
    torch.save(payload, target)