| | import inspect |
| | import torch |
| |
|
| | from detectron2.utils.env import TORCH_VERSION |
| |
|
# Newer torch exposes a public FX-tracing predicate; record whether it could be
# imported so is_fx_tracing() can choose which detection strategy to use.
try:
    from torch.fx._symbolic_trace import is_fx_tracing as is_fx_tracing_current
except ImportError:
    tracing_current_exists = False
else:
    tracing_current_exists = True

# Fallback: the original ``torch.nn.Module.__call__`` reference that
# is_fx_tracing_legacy() compares against.
try:
    from torch.fx._symbolic_trace import _orig_module_call
except ImportError:
    tracing_legacy_exists = False
else:
    tracing_legacy_exists = True
| |
|
| |
|
@torch.jit.ignore
def is_fx_tracing_legacy() -> bool:
    """Report whether torch.fx is symbolically tracing a module (legacy check).

    The check relies on torch.fx replacing ``torch.nn.Module.__call__`` while it
    traces, so the current attribute no longer is the saved original
    ``_orig_module_call``. Handy for gating module logic that cannot run under
    symbolic tracing.
    """
    active_call = torch.nn.Module.__call__
    return not (active_call is _orig_module_call)
| |
|
| |
|
def is_fx_tracing() -> bool:
    """Report whether execution is currently inside Torch FX symbolic tracing.

    Returns ``False`` when running as TorchScript (``torch.jit.is_scripting()``).
    Otherwise it asks torch's own ``is_fx_tracing`` (imported here as
    ``is_fx_tracing_current``) when that import succeeded and
    ``TORCH_VERSION >= (1, 10)``, then falls back to ``is_fx_tracing_legacy``,
    and finally to ``False``.
    """
    if torch.jit.is_scripting():
        return False
    if tracing_current_exists and TORCH_VERSION >= (1, 10):
        return is_fx_tracing_current()
    if tracing_legacy_exists:
        return is_fx_tracing_legacy()
    # Neither detection mechanism is available on this torch build: answer
    # "not tracing" so that assert_fx_safe() still performs its check.
    return False
| |
|
| |
|
def assert_fx_safe(condition: bool, message: str) -> torch.Tensor:
    """An FX-tracing safe version of assert.

    Avoids erroneous type assertion triggering when types are masked inside an
    ``fx.proxy.Proxy`` object during tracing. The check is skipped entirely when
    running as TorchScript or while ``is_fx_tracing()`` reports FX tracing.

    Args:
        condition: either a boolean expression or a string holding a Python
            expression to test. A string is evaluated with the globals and
            locals of the code that called ``assert_fx_safe``. If this assert
            triggers an exception when tracing due to dynamic control flow, try
            encasing the expression in quotation marks and supplying it as a
            string.
        message: message used by ``torch._assert`` when the condition fails.

    Returns:
        ``torch.zeros(1)`` if the check was skipped, ``torch.ones(1)`` if it ran
        and passed.

    Raises:
        AssertionError: (via ``torch._assert``) if the condition is false.
    """
    if torch.jit.is_scripting() or is_fx_tracing():
        return torch.zeros(1)
    # A string condition must be evaluated in *our caller's* scope. Capture
    # that frame here: inside the helper, ``f_back`` would point at this
    # function, whose only locals are ``condition`` and ``message``.
    caller_frame = inspect.currentframe().f_back if isinstance(condition, str) else None
    return _do_assert_fx_safe(condition, message, caller_frame)


def _do_assert_fx_safe(condition: bool, message: str, caller_frame=None) -> torch.Tensor:
    """Perform the check behind ``assert_fx_safe``.

    Args:
        condition: a boolean, or a string expression evaluated with ``eval``
            against ``caller_frame``'s globals and locals.
        message: message passed to ``torch._assert``.
        caller_frame: frame supplying the namespace for a string ``condition``.
            Defaults to the frame that called this function directly.

    Returns:
        ``torch.ones(1)`` if the check passed; ``torch.zeros(1)`` if evaluating
        it raised ``torch.fx.proxy.TraceError`` and the check was skipped.
    """
    try:
        if isinstance(condition, str):
            if caller_frame is None:
                caller_frame = inspect.currentframe().f_back
            # NOTE: eval executes arbitrary code; condition strings must come
            # from trusted source code, never from external input.
            condition = eval(condition, caller_frame.f_globals, caller_frame.f_locals)
        torch._assert(condition, message)
        return torch.ones(1)
    except torch.fx.proxy.TraceError as e:
        # Typically reached when FX tracing was not detected up front by
        # is_fx_tracing() and the condition depended on a traced Proxy value.
        print(
            "Found a non-FX compatible assertion. Skipping the check. Failure is shown below\n"
            + str(e)
        )
        # Report "skipped" with a concrete tensor, like assert_fx_safe's skip path.
        return torch.zeros(1)
    finally:
        # Release the frame reference promptly to avoid keeping frames alive
        # through reference cycles.
        del caller_frame
| |
|