File size: 2,846 Bytes
3d1c0e1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
# Copyright (c) 2025 FoundationVision
# SPDX-License-Identifier: MIT
import logging
import os
from contextlib import contextmanager, nullcontext
from datetime import datetime

import torch
import torch.distributed as dist
from torch.profiler import record_function as torch_record_function


class _TraceHandler:
    def __init__(self, save_path="/tmp/trace.json", logger=None, rank=None):
        self.logger = logger
        if logger is None:
            self.logger = logging.getLogger(__name__)

        self.logger.info(f"trace dump path: {save_path}")
        self.save_path = save_path + ".json.gz"
        self.rank = rank

    def __call__(self, prof):
        if self.logger is not None:
            self.logger.info(f"dump trace to {self.save_path}")
        prof.export_chrome_trace(self.save_path)

class torch_profiler:
    """
    Process-wide holder for a single ``torch.profiler.profile`` instance.

    usage:

    ```python
    import pnp

    pnp.torch_profiler.setup(output_folder="./", wait_steps=30)

    for step in range(100):
        pnp.torch_profiler.step()
        ...

        with pnp.torch_profiler.mark("fwd"):
            model.forward()

        ...

        with pnp.torch_profiler.mark("bwd"):
            loss.backward()

    ```

    Until ``setup(enabled=True)`` has been called, ``step()`` does nothing and
    ``mark`` is ``contextlib.nullcontext``, so instrumented training code runs
    unchanged when profiling is off.
    """

    # The running profiler, or None while profiling is disabled / not set up.
    _TP = None

    # Context-manager factory used to label a region of code. It must stay a
    # plain class attribute: wrapping it in @staticmethod/@property would make
    # ``torch_profiler.mark`` evaluate to a non-callable property object.
    # ``setup`` swaps it for ``torch.profiler.record_function``.
    mark = nullcontext

    @staticmethod
    def step():
        """Advance the profiler schedule by one step; no-op when not set up."""
        if torch_profiler._TP is None:
            return

        torch_profiler._TP.step()

    @staticmethod
    def setup(enabled=True, output_folder="./", file_prefix="", wait_steps=30):
        """
        Create and start the profiler and make ``mark`` record real ranges.

        enabled: if False, profiler will do nothing
        output_folder: the folder to dump trace
        wait_steps: start profiling after wait_steps(in your training loop)
        file_prefix: the prefix of the trace file for your custom
        """
        if not enabled:
            return

        os.makedirs(output_folder, exist_ok=True)

        # Without an initialised process group get_rank()/get_world_size()
        # raise, so a plain single-process run is treated as rank 0 of 1.
        if dist.is_available() and dist.is_initialized():
            world_size = dist.get_world_size()
            rank = dist.get_rank()
        else:
            world_size = 1
            rank = 0

        timestamp = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
        trace_path = (
            f"{output_folder}/{file_prefix}world_size-{world_size}-rank{rank}-{timestamp}"
        )

        torch_profiler._TP = torch.profiler.profile(
            activities=[
                torch.profiler.ProfilerActivity.CPU,
                torch.profiler.ProfilerActivity.CUDA,
            ],
            schedule=torch.profiler.schedule(
                wait=wait_steps,
                warmup=3,
                active=5,
                repeat=0,
            ),
            with_stack=True,
            record_shapes=True,
            profile_memory=False,
            on_trace_ready=_TraceHandler(trace_path, None, rank),
        )
        torch_profiler._TP.start()
        torch_profiler.mark = torch_record_function