# Uploaded by guilinhu via huggingface_hub ("Upload folder using huggingface_hub"),
# commit df9f13e (verified).
import torch
import torch.nn as nn
import src.utils as utils
# from src.models.common.film import FiLM
class FilmLayer(nn.Module):
    """Feature-wise linear modulation (FiLM) conditioned on an embedding.

    Two 1x1 ``Conv1d`` layers map the conditioning embedding to a per-channel,
    per-frequency scale ``w`` and shift ``b``; the input is returned as
    ``x * w + b`` (broadcast over the last, time, axis).

    Args:
        D: channel dimension of the conditioning embedding (e.g. speaker dim 256).
        C: channel dimension of the modulated features (e.g. latent dim 16).
        nF: number of frequency bins; each conv emits ``C * nF`` channels.
        groups: ``groups`` argument forwarded to both ``Conv1d`` layers.
    """

    def __init__(self, D, C, nF, groups=1):
        super().__init__()
        self.D = D
        self.C = C
        self.nF = nF
        self.weight = nn.Conv1d(self.D, self.C * nF, 1, groups=groups)
        self.bias = nn.Conv1d(self.D, self.C * nF, 1, groups=groups)

    def forward(self, x: torch.Tensor, embedding: torch.Tensor) -> torch.Tensor:
        """Modulate ``x`` with a scale/shift predicted from ``embedding``.

        Args:
            x: features of shape (B, C, F, T), with ``F == nF``.
            embedding: conditioning tensor of shape (B, D, 1). A 2-D (B, D)
                tensor is also accepted and treated as (B, D, 1).

        Returns:
            Tensor with the same shape as ``x``.
        """
        if embedding.dim() == 2:
            # Conv1d would interpret a 2-D input as an unbatched (C_in, L)
            # tensor; add the length-1 axis so the batch dimension is kept.
            embedding = embedding.unsqueeze(-1)
        batch_size, _, n_freqs, _ = x.shape
        w = self.weight(embedding).reshape(batch_size, self.C, n_freqs, 1)  # (B, C, F, 1)
        b = self.bias(embedding).reshape(batch_size, self.C, n_freqs, 1)  # (B, C, F, 1)
        return x * w + b
class LayerNormPermuted(nn.LayerNorm):
    """``nn.LayerNorm`` over the channel axis of a channels-first 4-D tensor.

    Accepts the same constructor arguments as ``nn.LayerNorm``;
    ``normalized_shape`` should match the channel count C.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def forward(self, x):
        """Normalise ``x`` of shape [B, C, T, F] over C and return [B, C, T, F]."""
        channels_last = x.permute(0, 2, 3, 1)  # [B, T, F, C]
        normed = super().forward(channels_last)
        return normed.permute(0, 3, 1, 2)  # back to [B, C, T, F]
class Conv_Emb_Generator(nn.Module):
    """Causal conv front-end followed by a stack of blocks, optionally FiLM-conditioned.

    Each call does the following:

    1. Prepend ``t_ksize - 1`` frames of history on the time axis. The history
       comes from ``input_state`` or, in quantized mode, from zero padding.
    2. Apply ``self.conv``: a Conv2d, optionally followed by LayerNormPermuted.
    3. Run ``n_layers`` blocks built from ``block_model_name``.

    When ``use_speaker_emb`` is True, a FilmLayer driven by ``embedding`` is
    applied before every block except the first (``one_emb=False``), or only
    before the second block (``one_emb=True``).
    """

    def __init__(
        self,
        block_model_name,
        block_model_params,
        spk_dim=256,
        n_srcs=1,
        n_fft=128,
        latent_dim=16,
        num_inputs=1,
        n_layers=6,
        use_first_ln=True,
        n_imics=1,
        lstm_fold_chunk=400,
        E=2,
        use_speaker_emb=True,
        one_emb=True,
        local_context_len=-1,
    ):
        """Build the conv front-end, the block stack and the FiLM layers.

        Args:
            block_model_name: dotted path resolved with ``utils.import_attr``.
                The resulting class is instantiated ``n_layers`` times.
            block_model_params: extra keyword arguments passed to every block.
            spk_dim: channel dimension of the conditioning embedding.
            n_srcs: stored as an attribute; not read by this class.
            num_inputs: stored as an attribute; not read by this class.
            n_fft: FFT size. Must be even; gives ``n_fft // 2 + 1`` frequency bins.
            latent_dim: channel width produced by the conv and fed to the blocks.
            n_layers: number of blocks (>= 1). Only the last gets ``last=True``.
            use_first_ln: append a LayerNormPermuted after the input conv.
            n_imics: number of input mics. The conv takes ``2 * n_imics``
                channels, or ``2 * (n_imics + 1)`` without a speaker embedding.
            lstm_fold_chunk: forwarded unchanged to every block.
            E: blocks get ``approx_qk_dim = E * n_freqs``.
            use_speaker_emb: enable FiLM conditioning on ``embedding``.
            one_emb: use a single FiLM layer instead of one per block after the first.
            local_context_len: forwarded unchanged to every block.

        Raises:
            ValueError: if ``n_layers < 1``.
        """
        super().__init__()
        assert n_fft % 2 == 0
        if n_layers < 1:
            raise ValueError(f"n_layers must be >= 1, got {n_layers}")
        self.n_srcs = n_srcs
        self.n_layers = n_layers
        self.num_inputs = num_inputs
        self.n_fft = n_fft
        self.n_freqs = n_fft // 2 + 1
        self.latent_dim = latent_dim
        self.use_speaker_emb = use_speaker_emb
        self.one_emb = one_emb
        self.eps = 1.0e-5  # NOTE(review): not read anywhere in this class
        self.t_ksize = 3
        self.n_imics = n_imics
        if not use_speaker_emb:
            # One extra pair of input channels when no embedding is used.
            # NOTE(review): presumably the conditioning signal is fed through
            # the input instead; confirm with callers.
            self.n_imics = self.n_imics + 1

        # Kernel (t_ksize, t_ksize). There is no time padding because history
        # frames are prepended explicitly in forward. One bin of frequency
        # padding keeps F unchanged.
        conv_layers = [
            nn.Conv2d(
                2 * self.n_imics,
                latent_dim,
                (self.t_ksize, self.t_ksize),
                padding=(0, 1),
            )
        ]
        if use_first_ln:
            conv_layers.append(LayerNormPermuted(latent_dim))
        self.conv = nn.Sequential(*conv_layers)

        # ``embeds`` is registered before ``blocks`` but filled after them.
        # This keeps the original submodule (state_dict) order and the
        # parameter-initialisation order.
        self.embeds = nn.ModuleList()
        self.local_context_len = local_context_len
        block_cls = utils.import_attr(block_model_name)
        self.blocks = nn.ModuleList(
            block_cls(
                emb_dim=latent_dim,
                n_freqs=self.n_freqs,
                approx_qk_dim=E * self.n_freqs,
                lstm_fold_chunk=lstm_fold_chunk,
                last=(i == n_layers - 1),
                local_context_len=local_context_len,
                **block_model_params,
            )
            for i in range(n_layers)
        )
        if use_speaker_emb:
            n_films = 1 if one_emb else n_layers - 1
            for _ in range(n_films):
                self.embeds.append(FilmLayer(spk_dim, latent_dim, self.n_freqs, 1))

    def init_buffers(self, batch_size, device):
        """Create a zeroed streaming state for ``batch_size`` sequences.

        Returns a dict with these keys:

        - ``conv_buf``: shape (B, 2 * n_imics, t_ksize - 1, n_freqs). Holds the
          most recent input frames for the causal conv.
        - ``deconv_buf``: shape (B, latent_dim, t_ksize - 1, n_freqs). Not read
          by this class; presumably kept for state-format compatibility.
        - ``block_bufs``: maps ``'buf{i}'`` to block i's state. Each entry is
          None until that block's first call.
        """
        conv_buf = torch.zeros(
            batch_size, 2 * self.n_imics, self.t_ksize - 1, self.n_freqs, device=device
        )
        deconv_buf = torch.zeros(
            batch_size, self.latent_dim, self.t_ksize - 1, self.n_freqs, device=device
        )
        block_bufs = {f'buf{i}': None for i in range(len(self.blocks))}
        return dict(conv_buf=conv_buf, deconv_buf=deconv_buf, block_bufs=block_bufs)

    def _film_before(self, block_idx):
        """Return the FiLM layer applied right before block ``block_idx``, or None."""
        if not self.use_speaker_emb or block_idx == 0:
            return None
        if self.one_emb:
            return self.embeds[0] if block_idx == 1 else None
        return self.embeds[block_idx - 1]

    @staticmethod
    def _apply_film(film, batch, embedding):
        """Apply ``film`` to a [B, C, T, F] tensor.

        FilmLayer expects frequency before time, so the last two axes are
        swapped around the call.
        """
        return film(batch.transpose(2, 3), embedding).transpose(2, 3)

    def forward(
        self,
        current_input: torch.Tensor,
        embedding: torch.Tensor,
        input_state,
        quantized=False,
    ) -> "tuple[torch.Tensor, dict]":
        """Run one chunk through the conv front-end and the block stack.

        Shape notation: B is batch, M is mics, C is real/imag, T is frames,
        F is frequency bins and D is the embedding dimension.

        Args:
            current_input: (B, C*M, T, F) input features.
            embedding: conditioning tensor, passed unchanged to each FilmLayer.
                NOTE(review): FilmLayer feeds it to a Conv1d, so it is presumably
                (B, D, 1); confirm against callers. Unused when
                ``use_speaker_emb`` is False.
            input_state: state dict from ``init_buffers`` or a previous call.
                None starts from zero history. It is updated in place.
            quantized: if True, zero-pad the time axis instead of using the
                carried conv history.

        Returns:
            ``(features, input_state)``. ``features`` is the last block's
            output; the conv output fed to the blocks is (B, latent_dim, T, F).
        """
        if input_state is None:
            input_state = self.init_buffers(current_input.shape[0], current_input.device)

        history = self.t_ksize - 1
        if quantized:
            batch = nn.functional.pad(current_input, (0, 0, history, 0))
        else:
            batch = torch.cat((input_state['conv_buf'], current_input), dim=2)
        # Store the newest frames back into the state so the next chunk's
        # convolution continues from real history rather than zeros.
        input_state['conv_buf'] = batch[:, :, -history:, :]
        batch = self.conv(batch)  # [B, latent_dim, T, F]

        assert len(self.blocks) == self.n_layers
        if self.use_speaker_emb:
            assert len(self.embeds) == (1 if self.one_emb else self.n_layers - 1)

        block_bufs = input_state['block_bufs']
        for idx, block in enumerate(self.blocks):
            film = self._film_before(idx)
            if film is not None:
                batch = self._apply_film(film, batch, embedding)
            key = f'buf{idx}'
            batch, block_bufs[key] = block(batch, block_bufs[key])
        return batch, input_state

    def edge_mode(self):
        """Call ``edge_mode()`` on every block."""
        for block in self.blocks:
            block.edge_mode()
if __name__ == "__main__":
    # This module only defines nn.Module classes; running it directly is a no-op.
    ...