import importlib

import torch

lib = torch.library.Library("export", "FRAGMENT")
lib.define(
    "access_subclass_inner_tensor(Tensor src_subclass_tensor, str attr) -> Tensor"
)
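
# The implementation below is registered under both the "Autograd" and "Python"
# dispatch keys, so the same Python body runs whichever of those keys the
# dispatcher selects. Once defined, the op is callable as
# torch.ops.export.access_subclass_inner_tensor(t, "attr").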
@torch.library.impl(lib, "access_subclass_inner_tensor", "Autograd")
@torch.library.impl(lib, "access_subclass_inner_tensor", "Python")
def _access_subclass_inner_tensor(
    src_subclass_tensor: torch.Tensor, attr: str
) -> torch.Tensor:
    # Deferred import: only resolved when the op actually runs, not at module
    # import time.
    from torch.utils._python_dispatch import is_traceable_wrapper_subclass

    assert is_traceable_wrapper_subclass(src_subclass_tensor)
    val = getattr(src_subclass_tensor, attr, None)
    if val is None or not isinstance(val, torch.Tensor):
        raise RuntimeError(
            f"Attribute {attr} is not a tensor or doesn't exist in {src_subclass_tensor}"
        )
    return val
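
# Usage sketch (assumes TwoTensor, a traceable wrapper subclass from PyTorch's
# own test utilities, is available; any traceable wrapper subclass with a
# tensor attribute works the same way):
#
#     from torch.testing._internal.two_tensor import TwoTensor
#
#     t = TwoTensor(torch.randn(3), torch.randn(3))
#     inner = torch.ops.export.access_subclass_inner_tensor(t, "a")
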
def _call_custom_autograd_function_in_pre_dispatch(function_cls_name, *args, **kwargs):
    """
    Import a custom autograd function by its fully qualified string name and
    call it. This is pretty bad because:
    1) There is no schema

    Ideally we should automatically wrap custom autograd functions with a
    custom op, but that is too much work because we would need to schematize
    custom autograd functions. For now, we just hackily put the call in the IR.
    """
    # "mod.submod.FuncCls" -> ("mod.submod", "FuncCls")
    module_name, class_name = function_cls_name.rsplit(".", 1)
    module = importlib.import_module(module_name)
    function_cls = getattr(module, class_name)
    assert hasattr(function_cls, "apply")
    return function_cls.apply(*args, **kwargs)
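
# Usage sketch (mypkg.ops.MyScale is a hypothetical torch.autograd.Function,
# not something defined in or required by this module):
#
#     # in mypkg/ops.py
#     class MyScale(torch.autograd.Function):
#         @staticmethod
#         def forward(ctx, x):
#             return x * 2.0
#
#         @staticmethod
#         def backward(ctx, grad_out):
#             return grad_out * 2.0
#
#     out = _call_custom_autograd_function_in_pre_dispatch(
#         "mypkg.ops.MyScale", torch.randn(3, requires_grad=True)
#     )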