| import logging |
| import time |
| import traceback |
| from pathlib import Path |
|
|
| import torch |
|
|
| from slime.utils.memory_utils import print_memory |
|
|
| logger = logging.getLogger(__name__) |
|
|
|
|
class TrainProfiler:
    """Drive the optional profilers that surround the training loop.

    Two "overall" profilers may be created, both gated on the
    ``"train_overall"`` entry of ``args.profile_target``:

    * a torch profiler, when ``args.use_pytorch_profiler`` is truthy;
    * a memory profiler, when ``args.record_memory_history`` is truthy.

    The per-loop helpers (``iterate_train_actor`` / ``iterate_train_log_probs``)
    wrap an iterator so that each consumed item counts as one profiler step.
    """

    def __init__(self, args):
        """Create the enabled profilers; the memory profiler starts right away.

        Args:
            args: Parsed run configuration. Reads ``use_pytorch_profiler``,
                ``record_memory_history`` and ``profile_target``; the object is
                kept on ``self.args`` for later use.
        """
        self.args = args
        overall_target = "train_overall"

        # The torch profiler is only built here; it is started in on_init_end().
        self._torch_profiler_overall = None
        if args.use_pytorch_profiler and overall_target in args.profile_target:
            self._torch_profiler_overall = _create_torch_profiler(args, name=overall_target)

        # The memory profiler, in contrast, begins recording during construction.
        self._memory_profiler_overall = None
        if args.record_memory_history and overall_target in args.profile_target:
            memory_profiler = _BaseMemoryProfiler.create(args)
            self._memory_profiler_overall = memory_profiler
            memory_profiler.start()

    def on_init_end(self):
        """Start the overall torch profiler, if one was created."""
        torch_profiler = self._torch_profiler_overall
        if torch_profiler is None:
            return
        torch_profiler.start()

    def step(self, rollout_id: int):
        """Advance the overall profilers by one rollout.

        The torch profiler (if any) is stepped on every call. The memory
        profiler (if any) is stopped when ``rollout_id`` equals
        ``args.memory_snapshot_num_steps - 1``; with that option unset it is
        never stopped from here.

        Args:
            rollout_id: Index of the rollout that has just been processed.
        """
        torch_profiler = self._torch_profiler_overall
        if torch_profiler is not None:
            torch_profiler.step()

        memory_profiler = self._memory_profiler_overall
        if memory_profiler is None:
            return

        snapshot_steps = self.args.memory_snapshot_num_steps
        if snapshot_steps is None:
            return

        if rollout_id == snapshot_steps - 1:
            memory_profiler.stop()

    def iterate_train_actor(self, iterator):
        """Wrap ``iterator`` with the ``"train_actor"`` profiling loop."""
        return _profile_simple_loop(iterator, self.args, name="train_actor")

    def iterate_train_log_probs(self, iterator):
        """Wrap ``iterator`` with the ``"train_log_probs"`` profiling loop."""
        return _profile_simple_loop(iterator, self.args, name="train_log_probs")
|
|
|
|
def _profile_simple_loop(iterator, args, name):
    """Yield the items of ``iterator``, optionally profiling each iteration.

    When ``args.use_pytorch_profiler`` is falsy or ``name`` is not listed in
    ``args.profile_target`` the items are passed through untouched. Otherwise
    a torch profiler is created for ``name``, started before the first item
    and stepped once after the consumer has finished with each item.

    The profiler is always stopped when the loop ends, whether the iterator is
    exhausted, an exception propagates, or the generator is closed early.
    Without the explicit ``stop()`` a loop that ends while the schedule is
    still warming up or recording would leave the profiler running and never
    hand the collected trace to ``on_trace_ready``.

    Args:
        iterator: Iterable whose items are forwarded to the caller.
        args: Run configuration; reads ``use_pytorch_profiler`` and
            ``profile_target`` (plus whatever ``_create_torch_profiler`` needs).
        name: Profile target name, also used to label the trace.

    Yields:
        Every item of ``iterator``, in order.
    """
    if not (args.use_pytorch_profiler and (name in args.profile_target)):
        yield from iterator
        return

    torch_profiler = _create_torch_profiler(args, name=name)
    torch_profiler.start()
    try:
        for item in iterator:
            yield item
            # Execution resumes here once the consumer asks for the next item,
            # i.e. after the work for ``item`` is done: that is one step.
            torch_profiler.step()
    finally:
        # Finalises any in-flight trace and releases the profiler.
        torch_profiler.stop()
|
|
|
|
def _create_torch_profiler(args, name):
    """Build (but do not start) a ``torch.profiler.profile`` for one target.

    The schedule is derived from ``args.profile_step_start`` and
    ``args.profile_step_end``: ``max(start - 1, 0)`` steps of ``wait``, one
    ``warmup`` step when ``start > 0`` (none otherwise), ``end - start``
    ``active`` steps, and a single cycle (``repeat=1``).

    Traces go through ``tensorboard_trace_handler`` into
    ``args.tensorboard_dir``, gzip-compressed, with a worker name built from
    ``name`` and the current ``torch.distributed`` rank.

    Args:
        args: Run configuration; reads ``profile_step_start``,
            ``profile_step_end`` and ``tensorboard_dir``.
        name: Label of the profiled target, used in the trace worker name.

    Returns:
        The configured, not yet started, profiler object.
    """
    start_step = args.profile_step_start
    idle_steps = max(start_step - 1, 0)
    warmup_steps = int(start_step > 0)
    recorded_steps = args.profile_step_end - start_step

    profiling_schedule = torch.profiler.schedule(
        wait=idle_steps,
        warmup=warmup_steps,
        active=recorded_steps,
        repeat=1,
    )

    worker_label = f"{name}_rank_{torch.distributed.get_rank()}"
    trace_handler = torch.profiler.tensorboard_trace_handler(
        args.tensorboard_dir,
        worker_name=worker_label,
        use_gzip=True,
    )

    return torch.profiler.profile(
        schedule=profiling_schedule,
        on_trace_ready=trace_handler,
        record_shapes=True,
        with_stack=True,
        profile_memory=True,
        with_flops=True,
    )
|
|
|
|
class _BaseMemoryProfiler:
    """Base class for memory profilers: picks the dump path, declares start/stop."""

    @staticmethod
    def create(args):
        """Instantiate the profiler selected by ``args.memory_recorder``.

        Args:
            args: Run configuration; ``memory_recorder`` must be ``"torch"``
                or ``"memray"``.

        Returns:
            A new instance of the matching profiler class.

        Raises:
            KeyError: If ``args.memory_recorder`` is any other value.
        """
        registry = {
            "torch": _TorchMemoryProfiler,
            "memray": _MemrayMemoryProfiler,
        }
        profiler_cls = registry[args.memory_recorder]
        return profiler_cls(args)

    def __init__(self, args):
        """Compute the file the snapshot will be written to.

        The file lives in ``args.memory_snapshot_dir``; its name combines the
        current ``time.time()`` value, the ``torch.distributed`` rank and
        ``args.memory_snapshot_path``.
        """
        snapshot_dir = Path(args.memory_snapshot_dir)
        file_name = (
            f"memory_snapshot_time{time.time()}_rank{torch.distributed.get_rank()}_{args.memory_snapshot_path}"
        )
        self._path_dump = snapshot_dir / file_name

    def start(self):
        """Begin recording; subclasses must override."""
        raise NotImplementedError

    def stop(self):
        """Finish recording and write the dump; subclasses must override."""
        raise NotImplementedError
|
|
|
|
class _TorchMemoryProfiler(_BaseMemoryProfiler):
    """Memory profiler built on the ``torch.cuda.memory`` history/snapshot hooks."""

    def start(self):
        """Enable CUDA memory-history recording and register an OOM observer.

        The observer logs the out-of-memory event, prints the current Python
        stack, dumps a snapshot to ``self._path_dump`` and calls
        ``print_memory("when oom")``.
        """
        logger.info("Attach OOM dump memory history.")

        history_options = {
            "max_entries": 1000000,
            "stacks": "all",
        }
        torch.cuda.memory._record_memory_history(**history_options)

        # The function and parameter names are kept as-is: the ``{x=}`` fields
        # in the log message and the printed stack both show them.
        def oom_observer(device, alloc, device_alloc, device_free):
            message = (
                f"Observe OOM, will dump snapshot to {self._path_dump}. "
                f"({device=} {alloc=} {device_alloc=} {device_free=}; stacktrace is as follows)"
            )
            logger.info(message)
            traceback.print_stack()
            torch.cuda.memory._dump_snapshot(self._path_dump)
            print_memory("when oom")

        torch._C._cuda_attach_out_of_memory_observer(oom_observer)

    def stop(self):
        """Dump a snapshot to ``self._path_dump``, then call
        ``_record_memory_history(enabled=None)``.

        NOTE(review): ``enabled=None`` presumably turns recording off; confirm
        against the PyTorch memory documentation.
        """
        snapshot_path = self._path_dump
        logger.info(f"Dump memory snapshot to: {snapshot_path}")
        torch.cuda.memory._dump_snapshot(snapshot_path)
        torch.cuda.memory._record_memory_history(enabled=None)
|
|
|
|
class _MemrayMemoryProfiler(_BaseMemoryProfiler):
    """Memory profiler that records allocations with a ``memray.Tracker``."""

    def __init__(self, args):
        """Set up the dump path and require ``args.memory_snapshot_num_steps``.

        Raises:
            AssertionError: If ``args.memory_snapshot_num_steps`` is ``None``.
        """
        super().__init__(args)
        num_steps = args.memory_snapshot_num_steps
        assert num_steps is not None, "In memray, must provide --memory-snapshot-num-steps"

    def start(self):
        """Create the tracker writing to ``self._path_dump`` and enter it."""
        logger.info("Memray tracker started.")
        # Imported lazily; presumably so memray is only needed when this
        # recorder is actually selected.
        import memray

        tracker = memray.Tracker(
            file_name=self._path_dump,
            native_traces=True,
        )
        self._tracker = tracker
        # Entered by hand because the tracker must stay active until stop().
        tracker.__enter__()

    def stop(self):
        """Exit the tracker that ``start()`` entered."""
        dump_path = self._path_dump
        logger.info(f"Memray tracker stopped and dump snapshot to: {dump_path}")
        self._tracker.__exit__(None, None, None)
|
|