Spaces:
Sleeping
Sleeping
| #!/usr/bin/env python3 | |
| # -*- coding: utf-8 -*- | |
| # Copyright 2019 Shigeki Karita | |
| # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) | |
| """Repeat the same layer definition.""" | |
| import torch | |
class MultiSequential(torch.nn.Sequential):
    """Multi-input multi-output torch.nn.Sequential with optional layer drop.

    Each child module is called as ``m(*args)``. Its return value is unpacked
    with ``*`` as the positional arguments of the next child, so intermediate
    children should return a tuple of arguments for their successor.
    """

    def __init__(self, *args, layer_drop_rate=0.0):
        """Initialize MultiSequential with layer_drop.

        Args:
            *args: Modules to run in order (forwarded to torch.nn.Sequential).
            layer_drop_rate (float): Probability of dropping out (skipping)
                each fn (layer) while in training mode. Layers are never
                skipped in eval mode.

        """
        super(MultiSequential, self).__init__(*args)
        self.layer_drop_rate = layer_drop_rate

    def forward(self, *args):
        """Run the child modules in sequence, randomly skipping some in training.

        Args:
            *args: Positional inputs for the first module.

        Returns:
            The return value of the last module that was executed, or the
            input tuple ``args`` itself if every module was skipped (or the
            container is empty).

        """
        # Only draw random numbers when layer drop can actually skip a layer.
        # uniform_() samples from [0, 1), so a rate <= 0 never skips anything,
        # and eval mode never skips. Drawing anyway would silently advance the
        # global torch RNG and perturb later random ops (e.g. seeded sampling).
        if self.training and self.layer_drop_rate > 0:
            keep = (
                torch.empty(len(self)).uniform_() >= self.layer_drop_rate
            ).tolist()
        else:
            keep = [True] * len(self)
        for m, keep_layer in zip(self, keep):
            if keep_layer:
                args = m(*args)
        return args
def repeat(N, fn, layer_drop_rate=0.0):
    """Build a MultiSequential from N modules produced by a factory.

    Args:
        N (int): Number of modules to create.
        fn (Callable): Factory invoked once per layer as ``fn(i)``, where ``i``
            is the layer index ``0 .. N - 1``. Its results become the children
            of the returned MultiSequential.
        layer_drop_rate (float): Probability of dropping out each fn (layer);
            passed through unchanged to MultiSequential.

    Returns:
        MultiSequential: Repeated model instance.

    """
    # Build every layer first (in index order), then wrap them in one container.
    layers = tuple(map(fn, range(N)))
    return MultiSequential(*layers, layer_drop_rate=layer_drop_rate)