| """ | |
| Logging util | |
| @Author: penhe@microsoft.com | |
| """ | |
| """ Utils for torch jit tracing custom operators/functions | |
| """ | |
| import os | |
def traceable(cls):
    """Decorator over custom functions.

    There is an issue for tracing custom python torch Function; use this
    decorator to work around it.

    e.g.
        @traceable
        class MyOp(torch.autograd.Function):
            xxx

    The decorated class is replaced by a stand-in whose ``apply`` checks the
    ``JIT_TRACE`` environment variable on every call:

    * when it equals ``'true'`` (case-insensitive), ``cls.forward`` is called
      directly with the stand-in class itself passed as the context argument;
    * otherwise the call is forwarded unchanged to ``cls.apply``.

    Args:
        cls: class providing ``forward(ctx, *args)`` and ``apply(*args)``.

    Returns:
        The stand-in class, carrying the ``__name__`` and ``__doc__`` of
        ``cls``.
    """
    class _Function(object):
        # Both hooks are static so they behave the same whether they are
        # looked up on the class or on an instance; as plain functions an
        # instance lookup would bind the instance as the first positional
        # argument and shift every real argument by one.
        @staticmethod
        def apply(*args):
            jit_trace = os.getenv('JIT_TRACE', 'False').lower() == 'true'
            if jit_trace:
                # The stand-in class doubles as the ``ctx`` object.
                return cls.forward(_Function, *args)
            return cls.apply(*args)

        @staticmethod
        def save_for_backward(*args):
            # Deliberate no-op: lets ``forward`` call
            # ``ctx.save_for_backward(...)`` on the stand-in context.
            pass

    _Function.__name__ = cls.__name__
    _Function.__doc__ = cls.__doc__
    return _Function
class TraceMode():
    """Trace context used when tracing modules that contain custom
    operators/Functions.

    Entering the context sets the ``JIT_TRACE`` environment variable to
    ``'True'``. Leaving it restores whatever value the variable had on entry,
    or removes it if it was unset, so contexts can be nested and a
    pre-existing setting is not lost.
    """

    def __init__(self):
        # One saved value per active ``with`` on this instance (LIFO).
        self._saved = []

    def __enter__(self):
        # ``None`` records "variable was not set on entry".
        self._saved.append(os.environ.get('JIT_TRACE'))
        os.environ['JIT_TRACE'] = 'True'
        return self

    def __exit__(self, exp_value, exp_type, trace):
        # NOTE: the interpreter passes (exc_type, exc_value, traceback)
        # positionally; the parameter names are kept as-is for interface
        # stability. Nothing is returned, so exceptions are not suppressed.
        previous = self._saved.pop() if self._saved else None
        if previous is None:
            # pop() rather than del: tolerate the variable already being gone.
            os.environ.pop('JIT_TRACE', None)
        else:
            os.environ['JIT_TRACE'] = previous