File size: 381 Bytes
9dd3461 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 |
import torch
from typing import TypeVar
from contextlib import contextmanager
# Generic type variable for use in type annotations. It is not referenced by the
# functions in this module; presumably other modules import it — verify before removing.
T = TypeVar("T")
# returns if all are the same mode
def all_same_mode(modes):
return all(tuple(mode == modes[0] for mode in modes))
@contextmanager
def no_dispatch():
guard = torch._C._DisableTorchDispatch() # type: ignore[attr-defined]
try:
yield
finally:
del guard
|