Spaces:
Build error
Build error
| #!/usr/bin/env python3 | |
| import inspect | |
| from typing import Any | |
| import torch.nn as nn | |
class InputIdentity(nn.Module):
    r"""
    Identity module: ``forward`` returns its input unchanged.

    ``ModelInputWrapper`` routes each model input through one of these so the
    input has its own module, which lets layer attribution methods target it.
    """

    def __init__(self, input_name: str) -> None:
        r"""
        The identity operation

        Args:
            input_name (str):
                The name of the input this layer is associated with. For
                debugging purposes; it is shown in the module's ``repr``.
        """
        super().__init__()
        self.input_name = input_name

    def forward(self, x):
        r"""Return ``x`` unchanged (same object, no copy)."""
        return x

    def extra_repr(self) -> str:
        # Surface the associated input name when this module (or a model
        # containing it) is printed, e.g. ``InputIdentity(input_name='x')``.
        # Without this, ``input_name`` is invisible in ``print(model)``.
        return f"input_name={self.input_name!r}"
class ModelInputWrapper(nn.Module):
    def __init__(self, module_to_wrap: nn.Module) -> None:
        r"""
        This is a convenience class. It wraps a model by first feeding each of
        the model's inputs through its own layer (one per input) and then
        feeding the (unmodified) inputs to the underlying model
        (`module_to_wrap`). Each input is fed through an `InputIdentity`
        layer/module. This class does not change how you feed inputs to your
        model, so feel free to use your model as you normally would.

        To access a wrapped input layer, use the `input_maps` ModuleDict. For
        example, the module for input "x" is
        `my_wrapped_module.input_maps["x"]`.

        This lets you use layer attribution methods on inputs, so inputs and
        ordinary layers can be mixed in the same attribution call. This is
        especially useful for multimodal models that take both discrete
        features (mapped to embeddings, such as text) and regular continuous
        feature vectors.

        Notes:
            - Since inputs are mapped with the identity, attributing to the
              input/feature can be done with either the input or output of the
              layer. That is, attributing to an input/feature doesn't depend on
              whether attribute_to_layer_input is True or False for
              LayerIntegratedGradients.
            - Input layers are created for every named parameter of
              `module_to_wrap.forward` that can be passed positionally, and
              for its keyword-only parameters. Arguments with no named
              parameter (those collected by `*args` / `**kwargs` in the
              wrapped `forward`) are passed through without an input layer.
            - Please refer to the multimodal tutorial or unit tests
              (test/attr/test_layer_wrapper.py) for an example.

        Args:
            module_to_wrap (nn.Module):
                The model/module you want to wrap
        """
        super().__init__()
        self.module = module_to_wrap

        # ``inspect.signature`` drops only an already-bound ``self``. It
        # therefore also handles a ``forward`` that is a staticmethod or a
        # plain function assigned to the instance (where blindly skipping the
        # first name would lose a real input). It also follows
        # ``functools.wraps``-decorated forwards.
        params = list(inspect.signature(module_to_wrap.forward).parameters.values())

        # Names that can be supplied positionally, in declaration order. These
        # are used to match positional call arguments to their input layers.
        self.arg_name_list = [
            p.name
            for p in params
            if p.kind
            in (inspect.Parameter.POSITIONAL_ONLY, inspect.Parameter.POSITIONAL_OR_KEYWORD)
        ]
        # Keyword-only parameters also get input layers so they can be
        # attributed when passed by keyword.
        kwonly_names = [
            p.name for p in params if p.kind == inspect.Parameter.KEYWORD_ONLY
        ]

        self.input_maps = nn.ModuleDict(
            {name: InputIdentity(name) for name in self.arg_name_list + kwonly_names}
        )

    def forward(self, *args, **kwargs) -> Any:
        r"""
        Route each argument through its input layer, then call the wrapped
        module with the resulting (identical) values.

        Positional arguments are matched to `arg_name_list` in order. Extra
        positional arguments beyond the named ones are forwarded untouched.
        Keyword arguments without an entry in `input_maps` are also forwarded
        untouched, so the wrapped module decides whether they are valid. This
        replaces the previous behaviour of raising ``KeyError``.

        Returns:
            Whatever `module_to_wrap` returns for these arguments.
        """
        mapped_args = [
            self.input_maps[name](arg)
            for name, arg in zip(self.arg_name_list, args)
        ]
        # zip() stops at the shorter sequence; keep any unnamed trailing args.
        mapped_args.extend(args[len(mapped_args):])

        mapped_kwargs = {
            name: self.input_maps[name](value) if name in self.input_maps else value
            for name, value in kwargs.items()
        }
        return self.module(*mapped_args, **mapped_kwargs)