| # mypy: allow-untyped-defs | |
| from torch.fx.proxy import Proxy | |
| from ._compatibility import compatibility | |
def annotate(val, type):
    """
    Attach a type to the node behind a Proxy, passing other values through.

    When ``val`` is a ``torch.fx.Proxy`` and its node does not yet carry a
    (truthy) type, ``type`` is stored on ``val.node.type``. Any value that is
    not a Proxy is handed back unchanged.

    Args:
        val (object): The value to annotate. Only Proxy instances are modified.
        type (object): The type to record on the Proxy's node.

    Returns:
        The very same ``val`` object that was passed in.

    Raises:
        RuntimeError: If ``val`` is a Proxy whose node already has a type.
    """
    # Non-Proxy values carry no node to annotate; return them as-is.
    if not isinstance(val, Proxy):
        return val

    target_node = val.node
    # A falsy existing type (e.g. None) counts as "not annotated yet".
    if target_node.type:
        raise RuntimeError(
            "Tried to annotate a value that already had a type on it!"
            f" Existing type is {target_node.type} "
            f"and new type is {type}. "
            "This could happen if you tried to annotate a function parameter "
            "value (in which case you should use the type slot "
            "on the function signature) or you called "
            "annotate on the same value twice"
        )

    target_node.type = type
    return val