import torch
import torch.nn as nn

import src.utils as utils


class FilmLayer(nn.Module):
    """FiLM conditioning layer: predicts a per-channel, per-frequency scale
    and bias from a conditioning embedding and applies them to the input."""

    def __init__(self, D, C, nF, groups=1):
        super().__init__()
        self.D = D    # conditioning embedding dimension
        self.C = C    # number of feature channels to modulate
        self.nF = nF  # number of frequency bins
        # 1x1 convolutions that map the embedding to per-(channel, freq)
        # FiLM scale and bias parameters.
        self.weight = nn.Conv1d(self.D, self.C * nF, 1, groups=groups)
        self.bias = nn.Conv1d(self.D, self.C * nF, 1, groups=groups)

    def forward(self, x: torch.Tensor, embedding: torch.Tensor):
        """
        x: (B, C, F, T)
        embedding: (B, D, 1)
        """
        B, _C, _F, T = x.shape
        # Project the embedding to FiLM parameters; the trailing singleton
        # dimension broadcasts the modulation over time.
        w = self.weight(embedding).reshape(B, self.C, _F, 1)
        b = self.bias(embedding).reshape(B, self.C, _F, 1)
        return x * w + b
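
# Minimal usage sketch for FilmLayer (hypothetical shapes; assumes the
# embedding is provided as (B, D, 1), which the reshape above requires):
#   film = FilmLayer(D=256, C=16, nF=65)
#   x = torch.randn(2, 16, 65, 100)  # (B, C, F, T)
#   emb = torch.randn(2, 256, 1)     # (B, D, 1)
#   y = film(x, emb)                 # -> (2, 16, 65, 100)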


class LayerNormPermuted(nn.LayerNorm):
    """LayerNorm applied over the channel dimension of a (B, C, T, F) tensor."""

    def forward(self, x):
        """
        Args:
            x: [B, C, T, F]
        """
        # Move channels last, normalize over them, then restore the layout.
        x = x.permute(0, 2, 3, 1)  # [B, T, F, C]
        x = super().forward(x)
        x = x.permute(0, 3, 1, 2)  # [B, C, T, F]
        return x
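
# Example (hypothetical): LayerNormPermuted(16) normalizes a (B, 16, T, F)
# tensor over its 16 channels, equivalent to nn.LayerNorm(16) on a
# channels-last view:
#   ln = LayerNormPermuted(16)
#   y = ln(torch.randn(2, 16, 10, 65))  # same shape out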


class Conv_Emb_Generator(nn.Module):
    """Convolutional front-end followed by a stack of blocks, producing a
    conversation embedding that is optionally FiLM-conditioned on a speaker
    embedding."""

    def __init__(
        self,
        block_model_name,
        block_model_params,
        spk_dim=256,
        n_srcs=1,
        n_fft=128,
        latent_dim=16,
        num_inputs=1,
        n_layers=6,
        use_first_ln=True,
        n_imics=1,
        lstm_fold_chunk=400,
        E=2,
        use_speaker_emb=True,
        one_emb=True,
        local_context_len=-1,
    ):
        super().__init__()
        self.n_srcs = n_srcs
        self.n_layers = n_layers
        self.num_inputs = num_inputs

        assert n_fft % 2 == 0
        n_freqs = n_fft // 2 + 1
        self.n_freqs = n_freqs
        self.latent_dim = latent_dim

        self.use_speaker_emb = use_speaker_emb
        self.one_emb = one_emb

        attn_approx_qk_dim = E * n_freqs
        self.n_fft = n_fft
        self.eps = 1.0e-5

        # Causal 2D conv: a kernel of t_ksize frames in time, padded in
        # frequency only; past time context comes from a streaming buffer.
        t_ksize = 3
        self.t_ksize = t_ksize
        ks, padding = (t_ksize, t_ksize), (0, 1)

        self.n_imics = n_imics
        if not use_speaker_emb:
            # Without a speaker embedding, one extra input channel pair is
            # expected (e.g. an auxiliary signal stacked with the mixture).
            self.n_imics = self.n_imics + 1

        # Input projection: real/imag parts of each mic -> latent_dim channels.
        module_list = [nn.Conv2d(2 * self.n_imics, latent_dim, ks, padding=padding)]
        if use_first_ln:
            module_list.append(LayerNormPermuted(latent_dim))
        self.conv = nn.Sequential(*module_list)

        self.embeds = nn.ModuleList([])
        self.local_context_len = local_context_len

        # Stack of n_layers blocks; only the final one is flagged as last.
        self.blocks = nn.ModuleList([])
        for _i in range(n_layers - 1):
            self.blocks.append(utils.import_attr(block_model_name)(
                emb_dim=latent_dim, n_freqs=n_freqs,
                approx_qk_dim=attn_approx_qk_dim,
                lstm_fold_chunk=lstm_fold_chunk, last=False,
                local_context_len=local_context_len, **block_model_params))
        self.blocks.append(utils.import_attr(block_model_name)(
            emb_dim=latent_dim, n_freqs=n_freqs,
            approx_qk_dim=attn_approx_qk_dim,
            lstm_fold_chunk=lstm_fold_chunk,
            local_context_len=local_context_len, last=True,
            **block_model_params))

        # FiLM conditioning: one layer per intermediate block, or a single
        # shared layer when one_emb is set.
        if self.use_speaker_emb and not self.one_emb:
            for _i in range(n_layers - 1):
                self.embeds.append(FilmLayer(spk_dim, latent_dim, n_freqs, 1))
        elif self.use_speaker_emb and self.one_emb:
            self.embeds.append(FilmLayer(spk_dim, latent_dim, n_freqs, 1))

    def init_buffers(self, batch_size, device):
        """Create fresh streaming state: conv input buffer, deconv buffer,
        and one (initially empty) buffer per block."""
        conv_buf = torch.zeros(batch_size, 2 * self.n_imics, self.t_ksize - 1,
                               self.n_freqs, device=device)
        deconv_buf = torch.zeros(batch_size, self.latent_dim, self.t_ksize - 1,
                                 self.n_freqs, device=device)
        block_buffers = {}
        for i in range(len(self.blocks)):
            block_buffers[f'buf{i}'] = None
        return dict(conv_buf=conv_buf, deconv_buf=deconv_buf,
                    block_bufs=block_buffers)

    def forward(self, current_input: torch.Tensor, embedding: torch.Tensor,
                input_state, quantized=False):
        """
        B: batch, M: mic, F: freq bin, C: real/imag, T: time frame
        D: dimension of the embedding vector
        current_input: (B, CM, T, F)
        embedding: (B, D, 1)
        Returns the conversation embedding and the updated streaming state.
        """
        n_batch, _, n_frames, n_freqs = current_input.shape
        batch = current_input

        if input_state is None:
            input_state = self.init_buffers(current_input.shape[0], current_input.device)

        conv_buf = input_state['conv_buf']
        gridnet_buf = input_state['block_bufs']

        # Prepend past time context: zeros in quantized mode, otherwise the
        # frames buffered from the previous call.
        if quantized:
            batch = nn.functional.pad(batch, (0, 0, self.t_ksize - 1, 0))
        else:
            batch = torch.cat((conv_buf, batch), dim=2)

        # Store the most recent frames back into the state for the next call
        # (the original assigned these to a local only, leaving the state stale).
        input_state['conv_buf'] = batch[:, :, -(self.t_ksize - 1):, :]
        batch = self.conv(batch)

        if self.use_speaker_emb:
            if not self.one_emb:
                # One FiLM layer per intermediate block.
                assert len(self.blocks) == self.n_layers
                assert len(self.embeds) == self.n_layers - 1
                for ii in range(self.n_layers - 1):
                    batch = batch.transpose(2, 3)  # (B, C, F, T) for FiLM
                    if ii > 0:
                        batch = self.embeds[ii - 1](batch, embedding)
                    batch = batch.transpose(2, 3)
                    batch, gridnet_buf[f'buf{ii}'] = self.blocks[ii](batch, gridnet_buf[f'buf{ii}'])

                batch = batch.transpose(2, 3)
                batch = self.embeds[-1](batch, embedding)
                batch = batch.transpose(2, 3)
                batch, gridnet_buf[f'buf{self.n_layers - 1}'] = \
                    self.blocks[self.n_layers - 1](batch, gridnet_buf[f'buf{self.n_layers - 1}'])
            else:
                # A single shared FiLM layer, applied before the second block.
                assert len(self.blocks) == self.n_layers
                assert len(self.embeds) == 1
                for ii in range(self.n_layers):
                    batch = batch.transpose(2, 3)
                    if ii == 1:
                        batch = self.embeds[ii - 1](batch, embedding)
                    batch = batch.transpose(2, 3)
                    batch, gridnet_buf[f'buf{ii}'] = self.blocks[ii](batch, gridnet_buf[f'buf{ii}'])
        else:
            # No speaker conditioning: run the block stack directly.
            assert len(self.blocks) == self.n_layers
            for ii in range(self.n_layers):
                batch, gridnet_buf[f'buf{ii}'] = self.blocks[ii](batch, gridnet_buf[f'buf{ii}'])

        conversation_emb = batch
        return conversation_emb, input_state
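
    # Streaming usage sketch (hypothetical driver code; `model` is a
    # constructed Conv_Emb_Generator and `spk_emb` a (B, spk_dim, 1) tensor):
    #   state = None
    #   for chunk in stft_chunks:          # each chunk: (B, 2*M, T_chunk, F)
    #       emb, state = model(chunk, spk_emb, state)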

    def edge_mode(self):
        # Switch every block to its edge/streaming-friendly mode.
        for i in range(len(self.blocks)):
            self.blocks[i].edge_mode()


if __name__ == "__main__":
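    # Minimal smoke test, a sketch under assumptions: `block_model_name` must
    # point at a real block class in this repo; the dotted path and empty
    # block_model_params below are placeholders, not the actual configuration.
    B, M, T, n_fft = 2, 1, 50, 128
    F = n_fft // 2 + 1
    model = Conv_Emb_Generator(
        block_model_name="src.blocks.Block",  # hypothetical dotted path
        block_model_params={},
        n_fft=n_fft,
        n_imics=M,
    )
    x = torch.randn(B, 2 * M, T, F)  # (B, CM, T, F) real/imag features
    spk = torch.randn(B, 256, 1)     # (B, spk_dim, 1) speaker embedding
    emb, state = model(x, spk, None)
    print(emb.shape)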