| | import functools
|
| | import numpy as np
|
| |
|
| | import torch
|
| | import torch.nn as nn
|
| |
|
| | from sgmse.util.registry import Registry
|
| |
|
| |
|
# Registry named "Backbone" for backbone network classes. NOTE(review): presumably backbone
# modules register themselves here and are later looked up by name — confirm against
# sgmse.util.registry.Registry.
BackboneRegistry = Registry("Backbone")
|
| |
|
| |
|
class GaussianFourierProjection(nn.Module):
    """Encode time steps with fixed Gaussian random Fourier features.

    Args:
        embed_dim: Target embedding width. In the real-valued case only
            ``embed_dim // 2`` random frequencies are drawn, and their sin and cos
            features are concatenated, so the output has ``2 * (embed_dim // 2)``
            columns. In the complex-valued case ``embed_dim`` frequencies are drawn.
        scale: Multiplier applied to the standard-normal frequency samples.
        complex_valued: If True, ``forward`` returns ``exp(1j * phase)`` instead of
            the concatenation ``[sin(phase), cos(phase)]``.
    """

    def __init__(self, embed_dim, scale=16, complex_valued=False):
        super().__init__()
        self.complex_valued = complex_valued
        # Real-valued output spends two features (sin, cos) per frequency, so draw half
        # as many frequencies. A complex exponential already carries both parts in one
        # entry, so no halving is needed.
        num_freqs = embed_dim if complex_valued else embed_dim // 2
        # Frequencies are sampled once at construction and frozen (requires_grad=False).
        self.W = nn.Parameter(torch.randn(num_freqs) * scale, requires_grad=False)

    def forward(self, t):
        """Embed the time steps ``t``.

        ``t[:, None]`` is broadcast against the frequency vector; this assumes ``t``
        is 1-D (one scalar per batch element) — TODO confirm with callers.
        """
        phase = t[:, None] * self.W[None, :] * 2*np.pi
        if not self.complex_valued:
            return torch.cat([torch.sin(phase), torch.cos(phase)], dim=-1)
        return torch.exp(1j * phase)
|
| |
|
| |
|
class DiffusionStepEmbedding(nn.Module):
    """Diffusion-Step embedding as in DiffWave / Vaswani et al. 2017.

    Each time step ``t`` is multiplied by the frequencies ``10**(4 * j / (d - 1))`` for
    ``j = 0, ..., d - 1``, i.e. ``d`` log-spaced frequencies from 1 to 10**4. Here ``d`` is
    ``self.embed_dim``. With a single frequency (``d == 1``) that frequency is 1.

    Args:
        embed_dim: Target embedding width. In the real-valued case sin and cos features
            are concatenated, so only ``embed_dim // 2`` frequencies are used (and stored
            in ``self.embed_dim``). In the complex-valued case ``embed_dim`` frequencies
            are used.
        complex_valued: If True, ``forward`` returns ``exp(1j * t * freq)`` instead of
            ``[sin(t * freq), cos(t * freq)]``.
    """

    def __init__(self, embed_dim, complex_valued=False):
        super().__init__()
        self.complex_valued = complex_valued
        if not complex_valued:
            # Two real features (sin, cos) per frequency, so halve the frequency count.
            embed_dim = embed_dim // 2
        self.embed_dim = embed_dim

    def forward(self, t):
        """Embed the time steps ``t``.

        ``t[:, None]`` is broadcast against the frequency vector; this assumes ``t`` is 1-D
        — TODO confirm with callers.
        """
        # With one frequency, d - 1 == 0 would make the exponent 0/0 = NaN and poison the
        # whole embedding. Clamping the denominator yields the single frequency 10**0 == 1
        # and leaves every d > 1 unchanged.
        denom = max(self.embed_dim - 1, 1)
        fac = 10**(4*torch.arange(self.embed_dim, device=t.device) / denom)
        inner = t[:, None] * fac[None, :]
        if self.complex_valued:
            return torch.exp(1j * inner)
        else:
            return torch.cat([torch.sin(inner), torch.cos(inner)], dim=-1)
|
| |
|
| |
|
class ComplexLinear(nn.Module):
    """Linear layer that optionally acts on complex-valued inputs.

    With ``complex_valued=True`` two independent real ``nn.Linear`` layers, ``re`` and
    ``im``, play the roles of the real and imaginary parts of a complex weight. They are
    combined with ``(A + iB)(a + ib) = (Aa - Bb) + i(Ab + Ba)``, and each sub-layer keeps
    its own bias. With ``complex_valued=False`` this is a single ``nn.Linear`` stored as
    ``lin``.
    """
    def __init__(self, input_dim, output_dim, complex_valued):
        super().__init__()
        self.complex_valued = complex_valued
        if not self.complex_valued:
            self.lin = nn.Linear(input_dim, output_dim)
            return
        # ``re`` is built before ``im`` so seeded initialisation stays reproducible.
        self.re = nn.Linear(input_dim, output_dim)
        self.im = nn.Linear(input_dim, output_dim)

    def forward(self, x):
        if not self.complex_valued:
            return self.lin(x)
        x_re, x_im = x.real, x.imag
        real_part = self.re(x_re) - self.im(x_im)
        imag_part = self.re(x_im) + self.im(x_re)
        return real_part + 1j * imag_part
|
| |
|
| |
|
class FeatureMapDense(nn.Module):
    """Dense (optionally complex) projection whose output is shaped like a 1x1 feature map.

    An input of shape ``(..., input_dim)`` is projected by ``ComplexLinear`` to
    ``(..., output_dim)``. Two trailing singleton axes are then appended, giving
    ``(..., output_dim, 1, 1)``. NOTE(review): presumably this is so the result broadcasts
    over spatial feature maps — confirm at call sites.
    """

    def __init__(self, input_dim, output_dim, complex_valued=False):
        super().__init__()
        self.complex_valued = complex_valued
        self.dense = ComplexLinear(input_dim, output_dim, complex_valued=complex_valued)

    def forward(self, x):
        projected = self.dense(x)
        return projected.unsqueeze(-1).unsqueeze(-1)
|
| |
|
| |
|
def torch_complex_from_reim(re, im):
    """Combine separate real and imaginary tensors into a single complex tensor.

    ``re`` and ``im`` are stacked along a new trailing axis of size 2, which
    ``torch.view_as_complex`` reinterprets as (real, imag) pairs.
    """
    pairs = torch.stack((re, im), dim=-1)
    return torch.view_as_complex(pairs)
|
| |
|
| |
|
class ArgsComplexMultiplicationWrapper(nn.Module):
    """Adapted from `asteroid`'s `complex_nn.py`, allowing args/kwargs to be passed through forward().

    Make a complex-valued module `F` from a real-valued module `f` by applying
    complex multiplication rules:

        F(a + i b) = f1(a) - f2(b) + i (f1(b) + f2(a))

    where `f1` (``self.re_module``) and `f2` (``self.im_module``) are instances of `f`
    that do *not* share weights.

    Args:
        module_cls (callable): A class or function that returns a Torch module/functional.
            Constructor of `f` in the formula above. Called 2x with `*args`, `**kwargs`,
            to construct the real and imaginary component modules.
    """

    def __init__(self, module_cls, *args, **kwargs):
        super().__init__()
        self.re_module = module_cls(*args, **kwargs)
        self.im_module = module_cls(*args, **kwargs)

    def forward(self, x, *args, **kwargs):
        """Apply the complex rule above to ``x`` (its ``.real`` / ``.imag`` parts are used).

        Any extra ``args`` / ``kwargs`` are passed unchanged to all four sub-module calls.
        The result is assembled with ``torch_complex_from_reim``.
        """
        # Call order is re(a), im(b), re(b), im(a), as in the original single expression.
        # This matters for sub-modules whose forward updates internal state.
        real = self.re_module(x.real, *args, **kwargs) - self.im_module(x.imag, *args, **kwargs)
        imag = self.re_module(x.imag, *args, **kwargs) + self.im_module(x.real, *args, **kwargs)
        return torch_complex_from_reim(real, imag)
|
| |
|
| |
|
# Complex-valued convolution layers. Calling either alias builds an
# ArgsComplexMultiplicationWrapper around two independent real nn.Conv2d /
# nn.ConvTranspose2d instances, both constructed with the same call arguments.
ComplexConv2d = functools.partial(ArgsComplexMultiplicationWrapper, nn.Conv2d)
ComplexConvTranspose2d = functools.partial(ArgsComplexMultiplicationWrapper, nn.ConvTranspose2d)
|
| |
|