| |
| |
| |
| |
| |
|
|
| import contextlib |
| import os |
| import pickle |
| import time |
|
|
| import torch |
|
|
| from torchtitan.config_manager import JobConfig |
| from torchtitan.tools.logging import logger |
|
|
| |
# Warmup steps handed to torch.profiler.schedule before each single active
# (recorded) step of a profiling cycle.
WARMUP = 3

# Value passed as ``max_entries`` to torch.cuda.memory._record_memory_history,
# i.e. the cap on allocator events kept in the recorded memory history.
MEMORY_SNAPSHOT_MAX_ENTRIES = 100_000
|
|
|
|
@contextlib.contextmanager
def maybe_enable_profiling(config: JobConfig, *, global_step: int = 0):
    """Optionally run the enclosed code under ``torch.profiler``.

    When ``config.profiling.enable_profiling`` is false this yields ``None``
    and has no other effect. Otherwise it yields an active
    ``torch.profiler.profile`` configured with a schedule of
    ``wait = profile_freq - (WARMUP + 1)`` idle steps, ``WARMUP`` warmup steps
    and one active step. Each finished cycle is exported as a Chrome trace to
    ``<dump_folder>/<save_traces_folder>/iteration_<step>/rank<rank>_trace.json``.

    Args:
        config: Job configuration. Reads ``config.profiling.enable_profiling``,
            ``config.profiling.save_traces_folder``,
            ``config.profiling.profile_freq`` and ``config.job.dump_folder``.
        global_step: Initial value for the profiler's step counter. Trace
            directory names are derived from it, e.g. when resuming training.

    Yields:
        The active ``torch.profiler.profile`` object, or ``None`` when
        profiling is disabled.

    Raises:
        ValueError: If ``profile_freq`` is smaller than ``WARMUP + 1``, leaving
            no room for the warmup and active steps in a cycle.
    """
    if not config.profiling.enable_profiling:
        yield None
        return

    trace_dir = os.path.join(
        config.job.dump_folder, config.profiling.save_traces_folder
    )
    profile_freq = config.profiling.profile_freq

    # Validate before any side effects (logging, directory creation).
    # A plain assert would be stripped under ``python -O``.
    warmup, active = WARMUP, 1
    wait = profile_freq - (active + warmup)
    if wait < 0:
        raise ValueError(
            "profile_freq must be greater than or equal to warmup + active"
        )

    # Fall back to rank 0 so profiling also works in a single process
    # without an initialized torch.distributed process group.
    if torch.distributed.is_available() and torch.distributed.is_initialized():
        rank = torch.distributed.get_rank()
    else:
        rank = 0

    def trace_handler(prof):
        """Export a Chrome trace for the cycle that just finished."""
        curr_trace_dir = os.path.join(trace_dir, f"iteration_{prof.step_num}")
        os.makedirs(curr_trace_dir, exist_ok=True)

        logger.info(f"Dumping profiler traces at step {prof.step_num}")
        begin = time.monotonic()
        prof.export_chrome_trace(
            os.path.join(curr_trace_dir, f"rank{rank}_trace.json")
        )
        logger.info(
            f"Finished dumping profiler traces in {time.monotonic() - begin:.2f} seconds"
        )

    logger.info(f"Profiling active. Traces will be saved at {trace_dir}")
    os.makedirs(trace_dir, exist_ok=True)

    # Always profile the CPU. Add a GPU activity only when a device backend is
    # present; passing None in ``activities`` is not a valid ProfilerActivity.
    activities = [torch.profiler.ProfilerActivity.CPU]
    if torch.cuda.is_available():
        activities.append(torch.profiler.ProfilerActivity.CUDA)
    elif hasattr(torch, "xpu") and torch.xpu.is_available():
        activities.append(torch.profiler.ProfilerActivity.XPU)

    with torch.profiler.profile(
        activities=activities,
        schedule=torch.profiler.schedule(wait=wait, warmup=warmup, active=active),
        on_trace_ready=trace_handler,
        record_shapes=True,
    ) as torch_profiler:
        # Align the profiler's step counter with the training step.
        # trace_handler names its output directory from prof.step_num.
        torch_profiler.step_num = global_step
        yield torch_profiler
|
|
|
|
@contextlib.contextmanager
def maybe_enable_memory_snapshot(config: JobConfig, *, global_step: int = 0):
    """Optionally record CUDA allocator history and dump memory snapshots.

    When ``config.profiling.enable_memory_snapshot`` is false this yields
    ``None`` and has no other effect. Otherwise it starts
    ``torch.cuda.memory._record_memory_history`` and yields a
    ``MemoryProfiler``. Each ``step()`` call on that object advances its
    counter and, every ``config.profiling.profile_freq`` steps, pickles
    ``torch.cuda.memory._snapshot()`` to
    ``<dump_folder>/<save_memory_snapshot_folder>/iteration_<step>/rank<rank>_memory_snapshot.pickle``.

    If the enclosed code raises ``torch.OutOfMemoryError``, a final snapshot
    is written to an ``iteration_<step>_exit`` directory and the error is
    re-raised rather than swallowed. Memory-history recording is stopped
    when the context exits, whether normally or via an exception.

    Args:
        config: Job configuration. Reads
            ``config.profiling.enable_memory_snapshot``,
            ``config.profiling.save_memory_snapshot_folder``,
            ``config.profiling.profile_freq`` and ``config.job.dump_folder``.
        global_step: Initial value of the profiler's step counter, e.g. when
            resuming training.

    Yields:
        A ``MemoryProfiler`` instance, or ``None`` when snapshotting is
        disabled.

    Raises:
        torch.OutOfMemoryError: Re-raised from the enclosed block after the
            exit snapshot has been written.
    """
    if not config.profiling.enable_memory_snapshot:
        yield None
        return

    snapshot_dir = os.path.join(
        config.job.dump_folder, config.profiling.save_memory_snapshot_folder
    )
    os.makedirs(snapshot_dir, exist_ok=True)

    # Fall back to rank 0 so snapshots also work in a single process
    # without an initialized torch.distributed process group.
    if torch.distributed.is_available() and torch.distributed.is_initialized():
        rank = torch.distributed.get_rank()
    else:
        rank = 0

    class MemoryProfiler:
        """Step counter that dumps a CUDA memory snapshot every ``freq`` steps."""

        def __init__(self, step_num: int, freq: int):
            # Start recording allocator events so _snapshot() includes history.
            torch.cuda.memory._record_memory_history(
                max_entries=MEMORY_SNAPSHOT_MAX_ENTRIES
            )
            self.step_num = step_num
            self.freq = freq

        def step(self, exit_ctx: bool = False):
            """Advance the counter and dump a snapshot if one is due.

            A snapshot is written when the new step number is a multiple of
            ``freq``, or unconditionally when ``exit_ctx`` is true.
            """
            self.step_num += 1
            if exit_ctx:
                # Label the exit dump with the counter value from before this
                # call. Presumably that is the last step the caller completed,
                # e.g. iteration_0_exit for a failure during the first step.
                curr_step = self.step_num - 1
                dir_name = f"iteration_{curr_step}_exit"
            elif self.step_num % self.freq == 0:
                curr_step = self.step_num
                dir_name = f"iteration_{curr_step}"
            else:
                return

            curr_snapshot_dir = os.path.join(snapshot_dir, dir_name)
            os.makedirs(curr_snapshot_dir, exist_ok=True)
            logger.info(f"Dumping memory snapshot at step {curr_step}")
            begin = time.monotonic()
            snapshot_path = os.path.join(
                curr_snapshot_dir, f"rank{rank}_memory_snapshot.pickle"
            )
            with open(snapshot_path, "wb") as output:
                pickle.dump(torch.cuda.memory._snapshot(), output)
            logger.info(
                f"Finished dumping memory snapshot in {time.monotonic() - begin:.2f} seconds"
            )

    logger.info(f"Memory profiler active. Snapshot will be saved at {snapshot_dir}")
    profiler = MemoryProfiler(global_step, config.profiling.profile_freq)
    try:
        yield profiler
    except torch.OutOfMemoryError:
        # Capture allocator state at the point of failure, then re-raise.
        # Swallowing the error here would make the caller's ``with`` block
        # exit as if training had succeeded.
        profiler.step(exit_ctx=True)
        raise
    finally:
        # History recording was started by MemoryProfiler.__init__.
        # Turn it off so it does not outlive this context.
        torch.cuda.memory._record_memory_history(enabled=None)
|
|