Spaces:
Running
on
T4
Running
on
T4
| import torch | |
| import typing as tp | |
| from .transformer import StreamingTransformer, create_sin_embedding | |
class UnetTransformer(StreamingTransformer):
    """U-net Transformer for processing sequences with optional skip connections.

    This transformer architecture incorporates U-net style skip connections
    between layers, which can be optionally enabled. It inherits from
    `StreamingTransformer`.

    When skip connections are enabled, the outputs of the first
    `num_layers // 2` layers are stored and later consumed, last stored first,
    by the last `num_layers // 2` layers. Each consumed skip is concatenated to
    the current features and linearly projected back to `d_model`. If
    `num_layers` is odd, the middle layer neither stores nor consumes a skip.

    Args:
        d_model (int): Dimension of the model, typically the number of expected features in the input.
        num_layers (int): Total number of layers in the transformer.
        skip_connections (bool, optional): Flag to determine whether skip connections should be used.
            Defaults to False.
        layer_dropout_p (float, optional): If given, Bernoulli probability of dropping a skip
            connection (training only). The value is clamped to [0, 1]. Defaults to None (no drop).
        **kwargs: Additional keyword arguments forwarded to `StreamingTransformer`.
    """

    def __init__(self, d_model: int, num_layers: int, skip_connections: bool = False,
                 layer_dropout_p: tp.Optional[float] = None, **kwargs):
        super().__init__(d_model=d_model,
                         num_layers=num_layers,
                         **kwargs)
        self.skip_connect = skip_connections
        if self.skip_connect:
            # One projection per skip connection: maps the concatenated
            # (features, skip) pair of size 2 * d_model back to d_model.
            self.skip_projections = torch.nn.ModuleList([torch.nn.Linear(d_model * 2, d_model)
                                                         for _ in range(num_layers // 2)])
        self.num_layers = num_layers
        # Clamp the drop probability to the valid [0, 1] range; None disables dropping.
        self.layer_drop_p = max(min(layer_dropout_p, 1.), 0.) if layer_dropout_p is not None else 0.0

    def forward(self, x: torch.Tensor, *args, **kwargs):
        """Run the input through all layers, applying skip connections if enabled.

        Args:
            x (torch.Tensor): 3-D input tensor, unpacked as (B, T, C).
            *args: Extra positional arguments passed to every layer via `_apply_layer`.
            **kwargs: Extra keyword arguments passed to every layer via `_apply_layer`.

        Returns:
            torch.Tensor: Output of the last layer.
        """
        B, T, C = x.shape

        # Resume from the stored per-sequence offsets if a streaming state exists.
        if 'offsets' in self._streaming_state:
            offsets = self._streaming_state['offsets']
        else:
            offsets = torch.zeros(B, dtype=torch.long, device=x.device)

        if self.positional_embedding in ['sin', 'sin_rope']:
            positions = torch.arange(T, device=x.device).view(1, -1, 1)
            positions = positions + offsets.view(-1, 1, 1)
            pos_emb = create_sin_embedding(positions, C, max_period=self.max_period, dtype=x.dtype)
            x = x + self.positional_scale * pos_emb

        # The first `num_skips` layers store a skip and the last `num_skips`
        # layers consume one. Starting the consuming half at
        # `num_layers - num_skips` (instead of `num_layers // 2`) keeps pushes
        # and pops balanced when `num_layers` is odd; for even `num_layers`
        # both expressions are identical.
        num_skips = self.num_layers // 2
        first_up_layer = self.num_layers - num_skips

        skip_connections: tp.List[torch.Tensor] = []
        for i, layer in enumerate(self.layers):
            if self.skip_connect and i >= first_up_layer:
                # In the second half of the layers, add the residual connection
                # and linearly project the concatenated features back to d_model.
                x = torch.cat([x, skip_connections.pop()], dim=-1)
                x = self.skip_projections[i - first_up_layer](x)
            x = self._apply_layer(layer, x, *args, **kwargs)
            if self.skip_connect and i < num_skips:
                if self.training and torch.rand(1,) < self.layer_drop_p:  # drop skip
                    # A dropped skip is replaced by zeros so the projection
                    # still receives an input of the expected size.
                    skip_connections.append(torch.zeros_like(x))
                else:
                    skip_connections.append(x)

        if self._is_streaming:
            self._streaming_state['offsets'] = offsets + T

        return x