"""Profiling helpers for training: torch profiler traces and memory snapshots."""

import logging
import time
import traceback
from pathlib import Path

import torch

from slime.utils.memory_utils import print_memory

logger = logging.getLogger(__name__)


class TrainProfiler:
    """Drive the optional "train_overall" torch profiler and memory profiler.

    Both profilers are opt-in and are only created when ``"train_overall"`` is
    listed in ``args.profile_target``:

    * ``args.use_pytorch_profiler`` enables the torch profiler. It is created
      in ``__init__`` but only started by :meth:`on_init_end`.
    * ``args.record_memory_history`` enables the memory profiler, which is
      created and started immediately in ``__init__``.
    """

    def __init__(self, args):
        self.args = args
        self._torch_profiler_overall = None
        self._memory_profiler_overall = None

        if args.use_pytorch_profiler and ("train_overall" in args.profile_target):
            self._torch_profiler_overall = _create_torch_profiler(args, name="train_overall")

        if args.record_memory_history and ("train_overall" in args.profile_target):
            self._memory_profiler_overall = _BaseMemoryProfiler.create(args)
            self._memory_profiler_overall.start()

    def on_init_end(self) -> None:
        """Start the overall torch profiler, if one was configured."""
        if self._torch_profiler_overall is not None:
            self._torch_profiler_overall.start()

    def step(self, rollout_id: int) -> None:
        """Advance the profilers by one step.

        Steps the overall torch profiler (if any). The memory profiler is
        stopped once ``rollout_id`` equals ``args.memory_snapshot_num_steps - 1``;
        when ``memory_snapshot_num_steps`` is ``None`` it is never stopped here.
        """
        if self._torch_profiler_overall is not None:
            self._torch_profiler_overall.step()

        if self._memory_profiler_overall is None:
            return

        num_steps = self.args.memory_snapshot_num_steps
        if num_steps is not None and rollout_id == num_steps - 1:
            self._memory_profiler_overall.stop()

    def iterate_train_actor(self, iterator):
        """Yield items from ``iterator``, profiling them under the "train_actor" target."""
        return _profile_simple_loop(iterator, self.args, name="train_actor")

    def iterate_train_log_probs(self, iterator):
        """Yield items from ``iterator``, profiling them under the "train_log_probs" target."""
        return _profile_simple_loop(iterator, self.args, name="train_log_probs")


def _profile_simple_loop(iterator, args, name):
    """Yield every item of ``iterator``, optionally wrapped in a torch profiler.

    When ``args.use_pytorch_profiler`` is falsy or ``name`` is not in
    ``args.profile_target`` the items are passed through untouched. Otherwise a
    profiler is started before the first item and stepped each time the
    consumer asks for the next item (i.e. after it has processed the previous
    one).
    """
    if not (args.use_pytorch_profiler and (name in args.profile_target)):
        yield from iterator
        return

    # Use the profiler as a context manager so that stop() always runs, even
    # when the iterator is exhausted (or the generator is closed) before the
    # scheduled active window has finished. Without this the profiler is left
    # running and a partially recorded trace is never handed to on_trace_ready.
    with _create_torch_profiler(args, name=name) as torch_profiler:
        for item in iterator:
            yield item
            torch_profiler.step()


def _create_torch_profiler(args, name):
    """Build (but do not start) a ``torch.profiler.profile`` for ``name``.

    The schedule is derived from ``args.profile_step_start`` and
    ``args.profile_step_end``; traces are written gzip-compressed to
    ``args.tensorboard_dir`` with a worker name of ``"{name}_rank_{rank}"``.

    Calls ``torch.distributed.get_rank()``, so this assumes the default
    process group is already initialized.
    """
    return torch.profiler.profile(
        schedule=torch.profiler.schedule(
            # TODO the train_actor and train_log_probs ones may need to have different args to control step
            wait=max(args.profile_step_start - 1, 0),
            warmup=1 if args.profile_step_start > 0 else 0,
            active=args.profile_step_end - args.profile_step_start,
            repeat=1,
        ),
        on_trace_ready=torch.profiler.tensorboard_trace_handler(
            args.tensorboard_dir,
            worker_name=f"{name}_rank_{torch.distributed.get_rank()}",
            use_gzip=True,
        ),
        record_shapes=True,
        with_stack=True,
        profile_memory=True,
        with_flops=True,
    )


class _BaseMemoryProfiler:
    """Common base for memory profilers; subclasses implement ``start``/``stop``."""

    @staticmethod
    def create(args):
        """Instantiate the profiler selected by ``args.memory_recorder``.

        Raises:
            KeyError: if ``args.memory_recorder`` is not ``"torch"`` or ``"memray"``.
        """
        profiler_cls = {
            "torch": _TorchMemoryProfiler,
            "memray": _MemrayMemoryProfiler,
        }[args.memory_recorder]
        return profiler_cls(args)

    def __init__(self, args):
        # The dump file name embeds the creation time and the distributed rank.
        self._path_dump = (
            Path(args.memory_snapshot_dir)
            / f"memory_snapshot_time{time.time()}_rank{torch.distributed.get_rank()}_{args.memory_snapshot_path}"
        )
        # Create the target directory up front: the first write may only
        # happen much later (possibly from inside the OOM observer), where a
        # missing directory would lose the snapshot.
        self._path_dump.parent.mkdir(parents=True, exist_ok=True)

    def start(self):
        """Begin recording. Must be implemented by subclasses."""
        raise NotImplementedError

    def stop(self):
        """Stop recording and write the snapshot. Must be implemented by subclasses."""
        raise NotImplementedError


class _TorchMemoryProfiler(_BaseMemoryProfiler):
    """Memory profiler backed by ``torch.cuda.memory`` history recording."""

    def start(self):
        """Enable CUDA memory history recording and attach an OOM observer."""
        logger.info("Attach OOM dump memory history.")

        torch.cuda.memory._record_memory_history(
            max_entries=1000000,
            # record stack information for the trace events
            # trace_alloc_record_context=True,
            stacks="all",
        )

        def oom_observer(device, alloc, device_alloc, device_free):
            # Log the OOM details, print the current Python stack, dump a
            # snapshot to the configured path, and report memory usage.
            logger.info(
                f"Observe OOM, will dump snapshot to {self._path_dump}. "
                f"({device=} {alloc=} {device_alloc=} {device_free=}; stacktrace is as follows)"
            )
            traceback.print_stack()
            torch.cuda.memory._dump_snapshot(self._path_dump)
            print_memory("when oom")

        torch._C._cuda_attach_out_of_memory_observer(oom_observer)

    def stop(self):
        """Dump a memory snapshot, then turn history recording off."""
        logger.info(f"Dump memory snapshot to: {self._path_dump}")
        torch.cuda.memory._dump_snapshot(self._path_dump)
        # enabled=None disables further history recording.
        torch.cuda.memory._record_memory_history(enabled=None)


class _MemrayMemoryProfiler(_BaseMemoryProfiler):
    """Memory profiler backed by a ``memray.Tracker``."""

    def __init__(self, args):
        super().__init__(args)
        assert args.memory_snapshot_num_steps is not None, "In memray, must provide --memory-snapshot-num-steps"

    def start(self):
        """Create a memray tracker writing to the dump path and enter it."""
        logger.info("Memray tracker started.")
        # Imported lazily so memray is only required when this recorder is used.
        import memray

        self._tracker = memray.Tracker(
            file_name=self._path_dump,
            native_traces=True,
        )
        # The tracker is a context manager; it is entered here and exited in
        # stop() because the tracked region spans separate method calls.
        self._tracker.__enter__()

    def stop(self):
        """Exit the memray tracker started by :meth:`start`."""
        logger.info(f"Memray tracker stopped and dump snapshot to: {self._path_dump}")
        self._tracker.__exit__(None, None, None)