File size: 7,585 Bytes
1fa3c6c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
# Copyright 2020-2026 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import functools
import time
from collections.abc import Callable

from transformers import Trainer
from transformers.integrations import is_mlflow_available, is_wandb_available


if is_wandb_available():
    import wandb

if is_mlflow_available():
    import mlflow


class ProfilingContext:
    """

    Context manager for profiling code blocks with configurable logging.



    This class handles timing of code execution and logging metrics to various backends (Weights & Biases, MLflow)

    without being coupled to the Trainer class.



    Args:

        name (`str`):

            Name of the profiling context. Used in the metric name.

        report_to (`list` of `str`):

            List of integrations to report metrics to (e.g., ["wandb", "mlflow"]).

        is_main_process (`bool`, *optional*, defaults to `True`):

            Whether this is the main process in distributed training. Metrics are only logged from the main process.

        step (`int` or `None`, *optional*):

            Training step to associate with the logged metrics.

        metric_prefix (`str`, *optional*, defaults to `"profiling/Time taken"`):

            Prefix for the metric name in logs.



    Example:

    ```python

    # Direct usage

    from trl.extras.profiling import ProfilingContext



    with ProfilingContext(

        name="MyClass.expensive_operation",

        report_to=["wandb"],

        is_main_process=True,

        step=100,

    ):

        # Code to profile

        result = expensive_computation()



    # With Trainer (backwards compatible via profiling_context function)

    from transformers import Trainer

    from trl.extras.profiling import profiling_context





    class MyTrainer(Trainer):

        def some_method(self):

            with profiling_context(self, "matrix_multiplication"):

                result = matrix_multiply()

    ```

    """

    def __init__(

        self,

        name: str,

        report_to: list[str],

        is_main_process: bool = True,

        step: int | None = None,

        metric_prefix: str = "profiling/Time taken",

    ):
        self.name = name
        self.report_to = report_to
        self.is_main_process = is_main_process
        self.step = step
        self.metric_prefix = metric_prefix
        self._start_time = None

    def __enter__(self):
        """Start timing when entering the context."""
        self._start_time = time.perf_counter()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Stop timing and log metrics when exiting the context."""
        if self._start_time is not None:
            duration = time.perf_counter() - self._start_time
            self._log_metrics(duration)
        return False

    def _log_metrics(self, duration: float) -> None:
        """

        Log profiling metrics to configured backends.



        Args:

            duration (`float`):

                Execution time in seconds.

        """
        if not self.is_main_process:
            return

        metric_name = f"{self.metric_prefix}: {self.name}"
        metrics = {metric_name: duration}

        # Log to Weights & Biases if configured
        if "wandb" in self.report_to and is_wandb_available() and wandb.run is not None:
            wandb.log(metrics)

        # Log to MLflow if configured
        if "mlflow" in self.report_to and is_mlflow_available() and mlflow.active_run() is not None:
            mlflow.log_metrics(metrics, step=self.step)


def profiling_context(trainer: Trainer, name: str) -> ProfilingContext:
    """
    Build a [`ProfilingContext`] whose settings are read from a Trainer instance.

    This helper keeps the historical `profiling_context(trainer, name)` call style working while delegating the
    actual timing and logging to the decoupled `ProfilingContext` class.

    Args:
        trainer (`~transformers.Trainer`):
            Trainer object providing the logging configuration. The function reads `trainer.state.global_step`,
            `trainer.args.report_to` and `trainer.accelerator.is_main_process`.
        name (`str`):
            Name of the block to be profiled. Will be prefixed with the trainer class name.

    Returns:
        `ProfilingContext`: A configured profiling context manager.

    Example:
    ```python
    import numpy as np
    from transformers import Trainer
    from trl.extras.profiling import profiling_context


    class MyTrainer(Trainer):
        def some_method(self):
            A = np.random.rand(1000, 1000)
            B = np.random.rand(1000, 1000)
            with profiling_context(self, "matrix_multiplication"):
                # Code to profile: simulate a computationally expensive operation
                result = A @ B  # Matrix multiplication
    ```
    """
    owner = trainer.__class__.__name__
    # Gather everything the context needs from the trainer up front, then hand it over in one go.
    settings = {
        "name": f"{owner}.{name}",
        "step": trainer.state.global_step,
        "report_to": trainer.args.report_to,
        "is_main_process": trainer.accelerator.is_main_process,
    }
    return ProfilingContext(**settings)


def profiling_decorator(func: Callable) -> Callable:
    """
    Decorator to profile a method and log execution time using [`extras.profiling.profiling_context`].

    The behavior depends on the attributes found on the instance the method is called on (its `self`):

    - `accelerator`, `state` and `args` present (Trainer-like object): the call is wrapped in
      `profiling_context(self, func.__name__)`, which reads the logging configuration from the trainer.
    - only `accelerator` present (non-Trainer object): the call is timed by a `ProfilingContext` created with an
      empty `report_to`, so nothing is reported to any backend.
    - no `accelerator`: the method is called directly, without any profiling.

    Args:
        func (`Callable`):
            Method to be profiled. It must take the instance as its first positional argument.

    Returns:
        `Callable`: Wrapped function that profiles execution time.

    Example:
    ```python
    import numpy as np
    from transformers import Trainer
    from trl.extras.profiling import profiling_decorator


    class MyTrainer(Trainer):
        @profiling_decorator
        def some_method(self):
            A = np.random.rand(1000, 1000)
            B = np.random.rand(1000, 1000)
            # Code to profile: simulate a computationally expensive operation
            result = A @ B
    ```
    """

    @functools.wraps(func)
    def wrapper(self, *args, **kwargs):
        # Both profiling paths below read `self.accelerator.is_main_process`. Without an accelerator there is no
        # way to profile, so just run the function instead of failing with an AttributeError.
        if not hasattr(self, "accelerator"):
            return func(self, *args, **kwargs)

        if hasattr(self, "state") and hasattr(self, "args"):
            # Trainer-like object: take name prefix, step and reporting targets from the trainer.
            context = profiling_context(self, func.__name__)
        else:
            # For non-Trainer objects (e.g., VLLMGeneration), use ProfilingContext directly
            context = ProfilingContext(
                name=f"{self.__class__.__name__}.{func.__name__}",
                report_to=[],  # No reporting for non-Trainer objects without args
                is_main_process=self.accelerator.is_main_process,
                step=None,
            )

        with context:
            return func(self, *args, **kwargs)

    return wrapper