# Copyright 2020-2026 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
import time
from collections.abc import Callable
from transformers import Trainer
from transformers.integrations import is_mlflow_available, is_wandb_available
if is_wandb_available():
import wandb
if is_mlflow_available():
import mlflow
class ProfilingContext:
"""
Context manager for profiling code blocks with configurable logging.
This class handles timing of code execution and logging metrics to various backends (Weights & Biases, MLflow)
without being coupled to the Trainer class.
Args:
name (`str`):
Name of the profiling context. Used in the metric name.
report_to (`list` of `str`):
List of integrations to report metrics to (e.g., ["wandb", "mlflow"]).
is_main_process (`bool`, *optional*, defaults to `True`):
Whether this is the main process in distributed training. Metrics are only logged from the main process.
step (`int` or `None`, *optional*):
Training step to associate with the logged metrics.
metric_prefix (`str`, *optional*, defaults to `"profiling/Time taken"`):
Prefix for the metric name in logs.
Example:
```python
# Direct usage
from trl.extras.profiling import ProfilingContext
with ProfilingContext(
name="MyClass.expensive_operation",
report_to=["wandb"],
is_main_process=True,
step=100,
):
# Code to profile
result = expensive_computation()
# With Trainer (backwards compatible via profiling_context function)
from transformers import Trainer
from trl.extras.profiling import profiling_context
class MyTrainer(Trainer):
def some_method(self):
with profiling_context(self, "matrix_multiplication"):
result = matrix_multiply()
```
"""
def __init__(
self,
name: str,
report_to: list[str],
is_main_process: bool = True,
step: int | None = None,
metric_prefix: str = "profiling/Time taken",
):
self.name = name
self.report_to = report_to
self.is_main_process = is_main_process
self.step = step
self.metric_prefix = metric_prefix
self._start_time = None
def __enter__(self):
"""Start timing when entering the context."""
self._start_time = time.perf_counter()
return self
def __exit__(self, exc_type, exc_val, exc_tb):
"""Stop timing and log metrics when exiting the context."""
if self._start_time is not None:
duration = time.perf_counter() - self._start_time
self._log_metrics(duration)
return False
def _log_metrics(self, duration: float) -> None:
"""
Log profiling metrics to configured backends.
Args:
duration (`float`):
Execution time in seconds.
"""
if not self.is_main_process:
return
metric_name = f"{self.metric_prefix}: {self.name}"
metrics = {metric_name: duration}
# Log to Weights & Biases if configured
if "wandb" in self.report_to and is_wandb_available() and wandb.run is not None:
wandb.log(metrics)
# Log to MLflow if configured
if "mlflow" in self.report_to and is_mlflow_available() and mlflow.active_run() is not None:
mlflow.log_metrics(metrics, step=self.step)
def profiling_context(trainer: Trainer, name: str) -> ProfilingContext:
    """
    Build a [`ProfilingContext`] configured from a Trainer instance.

    Kept so that existing `with profiling_context(trainer, "...")` call sites continue to work while the timing and
    logging logic lives in the decoupled `ProfilingContext` class.

    Args:
        trainer (`~transformers.Trainer`):
            Trainer object supplying the logging configuration: `args.report_to`, `accelerator.is_main_process` and
            `state.global_step` are read from it.
        name (`str`):
            Name of the block to be profiled. Will be prefixed with the trainer class name.

    Returns:
        `ProfilingContext`: A configured profiling context manager.

    Example:

    ```python
    from transformers import Trainer
    from trl.extras.profiling import profiling_context


    class MyTrainer(Trainer):
        def some_method(self):
            A = np.random.rand(1000, 1000)
            B = np.random.rand(1000, 1000)
            with profiling_context(self, "matrix_multiplication"):
                # Code to profile: simulate a computationally expensive operation
                result = A @ B  # Matrix multiplication
    ```
    """
    # Metric label has the form "<TrainerClass>.<block name>".
    qualified_label = f"{trainer.__class__.__name__}.{name}"
    # The step is captured now, when the context is created, not when the block finishes.
    current_step = trainer.state.global_step
    reporting_targets = trainer.args.report_to
    on_main_process = trainer.accelerator.is_main_process

    return ProfilingContext(
        qualified_label,
        reporting_targets,
        is_main_process=on_main_process,
        step=current_step,
    )
def profiling_decorator(func: Callable) -> Callable:
    """
    Decorator to profile a function and log execution time using [`extras.profiling.profiling_context`].

    This decorator works with methods that have access to a trainer instance (typically as `self`). The profiling
    strategy is picked from the attributes present on `self`:

    - `state`, `args` and `accelerator` all present: the call is wrapped in `profiling_context(self, ...)`.
    - only `accelerator` guaranteed: the call is timed with a `ProfilingContext` that reports to no backend.
    - otherwise: the function is called directly, without profiling.

    Args:
        func (`Callable`):
            Function to be profiled. It is expected to be a method, i.e. to take `self` as first argument.

    Returns:
        `Callable`: Wrapped function that profiles execution time.

    Example:

    ```python
    from transformers import Trainer
    from trl.extras.profiling import profiling_decorator


    class MyTrainer(Trainer):
        @profiling_decorator
        def some_method(self):
            A = np.random.rand(1000, 1000)
            B = np.random.rand(1000, 1000)
            # Code to profile: simulate a computationally expensive operation
            result = A @ B
    ```
    """

    @functools.wraps(func)
    def wrapper(self, *args, **kwargs):
        has_accelerator = hasattr(self, "accelerator")

        # `profiling_context` reads `state`, `args` AND `accelerator`, so all three must exist before taking the
        # Trainer path; checking only `state`/`args` would crash with an AttributeError on objects that lack an
        # accelerator instead of falling back gracefully.
        if has_accelerator and hasattr(self, "state") and hasattr(self, "args"):
            context = profiling_context(self, func.__name__)
        # For non-Trainer objects (e.g., VLLMGeneration), use ProfilingContext directly
        elif has_accelerator:
            context = ProfilingContext(
                name=f"{self.__class__.__name__}.{func.__name__}",
                report_to=[],  # No reporting for non-Trainer objects without args
                is_main_process=self.accelerator.is_main_process,
                step=None,
            )
        else:
            # No profiling available, just run the function
            return func(self, *args, **kwargs)

        with context:
            return func(self, *args, **kwargs)

    return wrapper
|