| | import functools |
| | from inspect import signature |
| | from .common_op_utils import _basic_validation |
| |
|
| | """ |
| | Common utilities to register ops on ShardedTensor, ReplicatedTensor |
| | and PartialTensor. |
| | """ |
| |
|
| | def _register_op(op, func, op_table): |
| | """ |
| | Performs basic validation and registers the provided op in the given |
| | op_table. |
| | """ |
| | if len(signature(func).parameters) != 4: |
| | raise TypeError( |
| | f'Custom sharded op function expects signature: ' |
| | f'(types, args, kwargs, process_group), but received ' |
| | f'signature: {signature(func)}') |
| |
|
| | op_table[op] = func |
| |
|
def _decorator_func(wrapped_func, op, op_table):
    """
    Wrap ``wrapped_func`` so that each call first runs ``_basic_validation``,
    register that wrapper under ``op`` in ``op_table``, and return it.

    Args:
        wrapped_func: Implementation called as
            ``wrapped_func(types, args, kwargs, process_group)``.
        op: Key the wrapper is registered under; also passed to
            ``_basic_validation`` on every call.
        op_table: Mapping the wrapper is registered into via ``_register_op``.

    Returns:
        The registered wrapper, which carries ``wrapped_func``'s metadata.

    Raises:
        TypeError: Propagated from ``_register_op`` when the signature check
            fails.
    """
    def validated_op(types, args, kwargs, process_group):
        # Validate the op's inputs before delegating to the real implementation.
        _basic_validation(op, args, kwargs)
        return wrapped_func(types, args, kwargs, process_group)

    # Copy __name__/__doc__/etc. and set __wrapped__. inspect.signature()
    # follows __wrapped__, so the arity check inside _register_op ends up
    # examining wrapped_func's parameters.
    validated_op = functools.wraps(wrapped_func)(validated_op)

    _register_op(op, validated_op, op_table)
    return validated_op
| |
|