| | |
| | |
| |
|
| | import contextlib |
| | import logging |
| | import os |
| | from dataclasses import dataclass |
| | from pathlib import Path |
| |
|
| | import torch.distributed |
| | import wandb |
| | import xformers.profiler |
| | from torch.profiler.profiler import profile |
| | from xformers.profiler import MemSnapshotsProfiler, PyTorchProfiler |
| |
|
| | from core.distributed import get_is_master |
| |
|
| |
|
@dataclass
class ProfilerArgs:
    """Configuration for optional profiling of a run.

    ``maybe_run_profiler`` turns these values into two schedule entries for
    ``xformers.profiler.profile``:

    * a memory-snapshot window from ``mem_warmup`` to ``mem_warmup + mem_steps``;
    * a PyTorch-profiler window from ``profile_warmup`` to
      ``profile_warmup + profile_steps``.
    """

    # Master switch: when False, no profiler is created at all.
    run: bool = False
    # Folder name joined onto the dump directory to hold the traces.
    trace_folder: str = "profiling"
    # Memory-snapshot window: starting step and number of steps.
    mem_warmup: int = 100
    mem_steps: int = 2
    # PyTorch-profiler window: starting step and number of steps.
    profile_warmup: int = 102
    profile_steps: int = 2
| |
|
| |
|
# Module-level logger named after this module so records identify their origin;
# records still propagate to any handlers configured on the root logger.
logger = logging.getLogger(__name__)
| |
|
| |
|
def perfetto_to_html(json_file, html_file):
    """Embed a JSON trace into a standalone HTML trace-viewer page.

    The page is built from the ``html/trace_viewer_embedder.html`` and
    ``html/trace_viewer_full.html`` templates shipped inside the ``viztracer``
    package. The trace JSON is inlined verbatim, apart from ``</script>``
    escaping.

    Args:
        json_file: Path (str or ``os.PathLike``) of the trace. A name ending in
            ``.gz`` is read through ``gzip``. The content is decoded as UTF-8.
        html_file: Destination path for the generated HTML. It is overwritten.
    """
    import gzip
    import string

    import viztracer

    root = os.path.dirname(viztracer.__file__)
    sub = {}
    with open(
        os.path.join(root, "html/trace_viewer_embedder.html"), encoding="utf-8"
    ) as f:
        tmpl = f.read()
    with open(os.path.join(root, "html/trace_viewer_full.html"), encoding="utf-8") as f:
        sub["trace_viewer_full"] = f.read()

    # Open the trace only here, under ``with``, so the handle is always closed
    # (previously it was opened before the template reads and leaked on error).
    # Check the suffix, not substring membership, so a ".gz" elsewhere in the
    # path does not select the gzip reader.
    opener = gzip.open if str(json_file).endswith(".gz") else open
    with opener(json_file, "rb") as j:
        content = j.read().decode("utf-8")
    # Escape closing script tags so the inlined JSON cannot end the enclosing
    # <script> element early.
    sub["json_data"] = content.replace("</script>", "<\\/script>")
    with open(html_file, "w", encoding="utf-8") as output_file:
        output_file.write(string.Template(tmpl).substitute(sub))
| |
|
| |
|
class PyTorchProfilerWandb(PyTorchProfiler):
    """xformers ``PyTorchProfiler`` that also logs the trace to Weights & Biases.

    After the base class handles a finished trace, the master rank converts the
    newest ``*.pt.trace.json*`` file under the main profiler's output directory
    to HTML. It then uploads that HTML as ``profile_trace``.
    """

    def __init__(self, main_profiler) -> None:
        # NOTE(review): super().__init__ is not called; presumably so that the
        # torch profiler options below replace the base class's. Confirm.
        self.main_profiler = main_profiler
        self.num_steps = 0
        self.pytorch_profiler = torch.profiler.profile(
            on_trace_ready=self._on_trace,
            profile_memory=True,
            record_shapes=True,
            with_stack=False,
            with_flops=True,
            activities=self.ACTIVITIES,
        )

    def _analyze_trace(self, prof: profile):
        """Run the base-class trace analysis, bracketed by log messages."""
        logger.info("Begin analyze trace")
        super()._analyze_trace(prof)
        logger.info("End analyze trace")

    @staticmethod
    def _html_path_for(trace_file: Path) -> Path:
        """Map ``name.json`` / ``name.json.gz`` to ``name.html`` in the same folder.

        Only the file name is rewritten, never parent directories. A ``.gz``
        suffix is dropped because the generated HTML is not compressed.
        """
        name = trace_file.name
        if name.endswith(".gz"):
            name = name[: -len(".gz")]
        if name.endswith(".json"):
            name = name[: -len(".json")]
        return trace_file.with_name(name + ".html")

    def _on_trace(self, prof: torch.profiler.profiler.profile) -> None:
        """Handle a finished trace and upload it to wandb when a run is active."""
        super()._on_trace(prof)
        if not (get_is_master() and wandb.run is not None):
            return
        output_dir = Path(self.main_profiler.output_dir)
        traces = list(output_dir.glob("profile_CPU_CUDA*/*.pt.trace.json*"))
        if not traces:
            # This runs inside the profiler callback; raising here (IndexError
            # before) would abort the caller over a missing upload.
            logger.warning(
                "No PyTorch trace found under %s; skipping wandb upload", output_dir
            )
            return
        # glob order is arbitrary; pick the most recently written trace.
        trace_file = max(traces, key=lambda p: p.stat().st_mtime)
        html_path = self._html_path_for(trace_file)
        perfetto_to_html(trace_file, html_path)
        wandb.log({"profile_trace": wandb.Html(str(html_path))})
| |
|
| |
|
class MemSnapshotsProfilerWandb(MemSnapshotsProfiler):
    """xformers ``MemSnapshotsProfiler`` that also logs its plot to Weights & Biases.

    On exit, the master rank uploads the newest HTML file found in
    ``memory_trace_plot/`` under the main profiler's output directory. It is
    logged as ``memory_trace``.
    """

    def __exit__(self, exc_type, exc_val, exc_tb):
        super().__exit__(exc_type, exc_val, exc_tb)
        if not (get_is_master() and wandb.run is not None):
            return
        # presumably ``main_profiler`` is set by the base class __init__; confirm.
        output_dir = Path(self.main_profiler.output_dir)
        plots = list(output_dir.glob("memory_trace_plot/*.html"))
        if not plots:
            logger.warning(
                "No memory trace plot found under %s; skipping wandb upload",
                output_dir,
            )
            return
        # glob order is arbitrary; pick the most recently written plot.
        plot_file = max(plots, key=lambda p: p.stat().st_mtime)
        # Open under ``with`` so the handle is closed after upload (it leaked
        # before). wandb.Html consumes the file object when constructed.
        with open(plot_file, encoding="utf-8") as html_file:
            wandb.log({"memory_trace": wandb.Html(html_file, inject=False)})
| |
|
| |
|
@contextlib.contextmanager
def maybe_run_profiler(dump_dir, module, config: ProfilerArgs):
    """Profile the enclosed code with xformers when ``config.run`` is set.

    Args:
        dump_dir: Base directory. Traces go to ``dump_dir/config.trace_folder``.
        module: Passed through to ``xformers.profiler.profile``.
        config: Profiling switch and schedule windows.

    Yields:
        The active xformers profiler when profiling is enabled, otherwise ``None``.
    """
    if not config.run:
        yield None
        return

    trace_dir = os.path.join(dump_dir, config.trace_folder)
    logger.info(f"Profiling active. Traces will be saved at {trace_dir}")

    # exist_ok avoids the check-then-create race of a separate exists() test.
    if get_is_master():
        os.makedirs(trace_dir, exist_ok=True)
    # Synchronize ranks so none starts writing before the master has created
    # the directory.
    if torch.distributed.is_initialized():
        torch.distributed.barrier()

    schedule = [
        (
            MemSnapshotsProfilerWandb,
            config.mem_warmup,
            config.mem_warmup + config.mem_steps,
        ),
        (
            PyTorchProfilerWandb,
            config.profile_warmup,
            config.profile_warmup + config.profile_steps,
        ),
    ]
    with xformers.profiler.profile(
        output_dir=trace_dir,
        module=module,
        schedule=schedule,
    ) as profiler:
        yield profiler
| |
|