| from typing import List, Tuple |
|
|
| import torch |
| from torch import nn |
| from torch.utils.checkpoint import checkpoint_sequential |
|
|
| from .utils import ( |
| band_widths_from_specs, |
| check_no_gap, |
| check_no_overlap, |
| check_nonzero_bandwidth, |
| ) |
|
|
|
|
class NormFC(nn.Module):
    """LayerNorm followed by a Linear projection for one frequency band.

    The input's last dimension must be ``in_channels * bandwidth * 2``: the
    channel, frequency-bin and real/imaginary values of one band flattened
    together.

    Args:
        emb_dim: Output feature size of the Linear projection.
        bandwidth: Number of frequency bins in the band.
        in_channels: Number of input channels.
        normalize_channel_independently: Not supported; must be False.
        treat_channel_as_feature: Only True is supported (channels are folded
            into the feature dimension).
        use_checkpoint: If True, wrap the norm + projection in activation
            checkpointing while gradients are enabled (recompute in backward
            instead of storing intermediates). Defaults to False.

    Raises:
        NotImplementedError: If an unsupported option is requested.
    """

    def __init__(
        self,
        emb_dim: int,
        bandwidth: int,
        in_channels: int,
        normalize_channel_independently: bool = False,
        treat_channel_as_feature: bool = True,
        use_checkpoint: bool = False,
    ) -> None:
        super().__init__()

        if not treat_channel_as_feature:
            raise NotImplementedError(
                "NormFC only supports treat_channel_as_feature=True"
            )
        if normalize_channel_independently:
            raise NotImplementedError(
                "NormFC does not support normalize_channel_independently=True"
            )

        self.treat_channel_as_feature = treat_channel_as_feature
        self.use_checkpoint = use_checkpoint

        # Real and imaginary parts are flattened into the feature dimension.
        reim = 2
        # Channels are always treated as features (enforced above), so the
        # LayerNorm and the Linear both see every channel/bin/re-im value.
        fc_in = in_channels * bandwidth * reim

        norm = nn.LayerNorm(fc_in)
        fc = nn.Linear(fc_in, emb_dim)
        self.combined = nn.Sequential(norm, fc)

    def forward(self, xb: torch.Tensor) -> torch.Tensor:
        """Normalize and project ``xb``.

        Args:
            xb: Tensor whose last dimension is ``in_channels * bandwidth * 2``.

        Returns:
            Tensor with the same leading dimensions and a last dimension of
            ``emb_dim``.
        """
        # NOTE: checkpoint_sequential(..., segments=1) checkpoints nothing (its
        # last segment always runs normally), so checkpointing is requested
        # explicitly and only when it can actually save memory.
        if self.use_checkpoint and torch.is_grad_enabled():
            return torch.utils.checkpoint.checkpoint(
                self.combined, xb, use_reentrant=False
            )
        return self.combined(xb)
|
|
|
|
class BandSplitModule(nn.Module):
    """Split a complex spectrogram into frequency bands and embed each band.

    Each ``(fstart, fend)`` pair in ``band_specs`` selects frequency bins
    ``fstart:fend``. That band is flattened per time frame and projected to
    ``emb_dim`` features by its own ``NormFC`` (LayerNorm + Linear).

    Args:
        band_specs: ``(start, end)`` frequency-bin bounds, one per band. They
            are used directly as slice bounds in ``forward``, so presumably
            they must be integers in practice — TODO confirm.
        emb_dim: Embedding size produced for every band.
        in_channels: Number of input channels.
        require_no_overlap: If True, validate that the bands do not overlap.
        require_no_gap: If True, validate that the bands leave no gap.
        normalize_channel_independently: Forwarded to each ``NormFC``.
        treat_channel_as_feature: Forwarded to each ``NormFC``.
    """

    def __init__(
        self,
        band_specs: List[Tuple[float, float]],
        emb_dim: int,
        in_channels: int,
        require_no_overlap: bool = False,
        require_no_gap: bool = True,
        normalize_channel_independently: bool = False,
        treat_channel_as_feature: bool = True,
    ) -> None:
        super().__init__()

        check_nonzero_bandwidth(band_specs)
        if require_no_gap:
            check_no_gap(band_specs)
        if require_no_overlap:
            check_no_overlap(band_specs)

        self.band_specs = band_specs
        self.band_widths = band_widths_from_specs(band_specs)
        self.n_bands = len(band_specs)
        self.emb_dim = emb_dim

        # Plain eager modules: torch.compile(..., disable=True) returns the
        # module unchanged, so wrapping here only added torch._dynamo import
        # overhead. Compile the whole model externally if compilation is wanted.
        self.norm_fc_modules = nn.ModuleList(
            [
                NormFC(
                    emb_dim=emb_dim,
                    bandwidth=bw,
                    in_channels=in_channels,
                    normalize_channel_independently=normalize_channel_independently,
                    treat_channel_as_feature=treat_channel_as_feature,
                )
                for bw in self.band_widths
            ]
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Embed every frequency band of ``x``.

        Args:
            x: Complex tensor (``torch.view_as_real`` is applied to it),
                presumably of shape ``(batch, in_channels, n_freq, n_time)``.

        Returns:
            Real tensor of shape ``(batch, n_bands, n_time, emb_dim)``.
        """
        batch, _, _, n_time = x.shape

        # Match the precision of x's real/imaginary parts so complex128 input
        # is not silently truncated to the default float32.
        z = torch.zeros(
            size=(batch, self.n_bands, n_time, self.emb_dim),
            device=x.device,
            dtype=x.real.dtype,
        )

        # Move time next to batch and frequency last, so each band can be
        # sliced on the last axis and flattened per time frame.
        x = torch.permute(x, (0, 3, 1, 2)).contiguous()

        for i, (nfm, (fstart, fend)) in enumerate(
            zip(self.norm_fc_modules, self.band_specs)
        ):
            # Band slice -> append a real/imag axis -> flatten channel, bin
            # and real/imag into one feature vector per (batch, time) frame.
            xb = torch.view_as_real(x[:, :, :, fstart:fend])
            xb = torch.reshape(xb, (batch, n_time, -1))
            z[:, i, :, :] = nfm(xb)

        return z
|
|