Buckets:
| # Copyright (c) 2025 SandAI. All Rights Reserved. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| from functools import wraps | |
| from typing import Any, Callable, TypeVar, cast | |
| import torch | |
| # fixed the mypy type check missing bug | |
| # when a func is wrapped | |
| # issue: https://stackoverflow.com/questions/65621789/mypy-untyped-decorator-makes-function-my-method-untyped | |
| F = TypeVar("F", bound=Callable[..., Any]) | |
| # global var for torch.autograd.profiler.emit_nvtx | |
| _EMIT_NVTX_CTX: None | torch.autograd.profiler.emit_nvtx = None | |
| def nvtx_range_push(event_name: str) -> None: | |
| """torch.ops.magi_compile.nvtx_range_push""" | |
| torch.cuda.nvtx.range_push(event_name) | |
| def _(event_name: str) -> None: | |
| pass | |
| def nvtx_range_pop() -> None: | |
| """torch.ops.magi_compile.nvtx_range_pop""" | |
| torch.cuda.nvtx.range_pop() | |
| def _() -> None: | |
| pass | |
| # NOTE: since torch.compile does not support @contextlib.contextmanager, | |
| # we use the class-based context manager | |
| class add_nvtx_event: | |
| """ | |
| Context manager to add an NVTX event around a code block. | |
| Args: | |
| event_name (str): The name of the event to be recorded. | |
| """ | |
| def __init__(self, event_name: str): | |
| self.enter_name = event_name | |
| def __enter__(self): | |
| if torch.compiler.is_compiling(): | |
| # NOTE: torch.compile supports neither retrieving the attributes from "self" | |
| # nor modifying a variable not in the current scope | |
| # so we have no choice but assign a constant event name when compiling | |
| nvtx_range_push("torch compile region") | |
| else: | |
| torch.cuda.nvtx.range_push(self.enter_name) | |
| return self | |
| def __exit__(self, *excinfo): | |
| if torch.compiler.is_compiling(): | |
| nvtx_range_pop() | |
| else: | |
| torch.cuda.nvtx.range_pop() | |
| def instrument_nvtx(func: F) -> F: | |
| """ | |
| Decorator that records an NVTX range for the duration of the function call. | |
| Args: | |
| func (Callable): The function to be decorated. | |
| Returns: | |
| Callable: The wrapped function that is now being profiled. | |
| """ | |
| def wrapped_fn(*args, **kwargs): | |
| if torch.compiler.is_compiling(): | |
| # NOTE: we can not access func.__qualname__ when compiling | |
| # thus use func.__name__ instead | |
| func_name = func.__name__ | |
| else: | |
| func_name = func.__qualname__ | |
| with add_nvtx_event(func_name): | |
| ret_val = func(*args, **kwargs) | |
| return ret_val | |
| return cast(F, wrapped_fn) | |
Xet Storage Details
- Size:
- 3.26 kB
- Xet hash:
- 73fe249f8705b03091fba316a13acf7d2b6c9eae99b3cf21a129b3b473219b0f
·
Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.