| |
| |
|
|
| import copy |
|
|
| import torch |
| import torch.nn as nn |
|
|
|
|
def MultiwayWrapper(args, module, dim=1):
    """Return *module* wrapped in a MultiwayNetwork when ``args.multiway``
    is truthy; otherwise return *module* unchanged."""
    if not args.multiway:
        return module
    return MultiwayNetwork(module, dim=dim)
|
|
|
|
def set_split_position(position):
    """Build a visitor (suitable for ``nn.Module.apply``) that sets
    ``split_position`` to *position* on every module exposing that attribute."""

    def _visit(module):
        # Only touch modules that opted in by defining the attribute.
        if hasattr(module, "split_position"):
            setattr(module, "split_position", position)

    return _visit
|
|
|
|
class MultiwayNetwork(nn.Module):
    """Run two branches (A and B) of the same module on complementary slices
    of the input.

    The input is split at ``split_position`` along dimension ``dim``: the
    first part goes through ``A``, the remainder through ``B``, and the two
    outputs are concatenated back along the same dimension. The sentinel
    values are: -1 (default) routes everything through ``A``; 0 routes
    everything through ``B``. ``split_position`` is typically broadcast to
    all submodules via ``set_split_position``.
    """

    def __init__(self, module, dim=1):
        super().__init__()
        self.dim = dim
        self.A = module
        self.B = copy.deepcopy(module)
        # Re-initialize B so the two branches do not start with tied weights.
        if hasattr(self.B, "reset_parameters"):
            self.B.reset_parameters()
        else:
            # Fix: composite containers (e.g. nn.Sequential) define no
            # top-level reset_parameters, so the original unconditional call
            # raised AttributeError. Re-initialize their leaves instead.
            self.B.apply(
                lambda m: m.reset_parameters()
                if hasattr(m, "reset_parameters")
                else None
            )
        self.split_position = -1

    def forward(self, x, **kwargs):
        if self.split_position == -1:
            # Sentinel: the whole input goes through branch A.
            return self.A(x, **kwargs)
        if self.split_position == 0:
            # Sentinel: the whole input goes through branch B.
            return self.B(x, **kwargs)
        x1, x2 = torch.split(
            x,
            [self.split_position, x.size(self.dim) - self.split_position],
            dim=self.dim,
        )
        # Process each slice with its own branch, then stitch back together.
        y1, y2 = self.A(x1, **kwargs), self.B(x2, **kwargs)
        return torch.cat([y1, y2], dim=self.dim)
|
|
|
|
class MutliwayEmbedding(MultiwayNetwork):
    """Multiway container built from two caller-supplied modules rather than
    a module and its re-initialized deep copy.

    NOTE: the misspelling "Mutliway" is preserved — the class name is public
    API. ``MultiwayNetwork.__init__`` is deliberately bypassed (via
    ``super(MultiwayNetwork, self)``) so no deepcopy/reset happens here.
    """

    def __init__(self, modules, dim=1):
        # Skip MultiwayNetwork.__init__ and go straight to nn.Module's.
        super(MultiwayNetwork, self).__init__()
        assert len(modules) == 2
        self.dim = dim
        self.A, self.B = modules
        self.split_position = -1
|
|