Upload CoLMbo model weights and code
Browse files- config.json +17 -0
- encoder/__pycache__/attentive_pooling.cpython-310.pyc +0 -0
- encoder/__pycache__/attentive_pooling.cpython-38.pyc +0 -0
- encoder/__pycache__/encoder.cpython-310.pyc +0 -0
- encoder/__pycache__/encoder.cpython-38.pyc +0 -0
- encoder/__pycache__/encoder.cpython-39.pyc +0 -0
- encoder/__pycache__/mha.cpython-310.pyc +0 -0
- encoder/__pycache__/mha.cpython-38.pyc +0 -0
- encoder/__pycache__/self_attn.cpython-310.pyc +0 -0
- encoder/__pycache__/self_attn.cpython-38.pyc +0 -0
- encoder/attentive_pooling.py +33 -0
- encoder/encoder.py +35 -0
- encoder/mha.py +62 -0
- encoder/self_attn.py +81 -0
- load_data/__pycache__/combineddataset.cpython-38.pyc +0 -0
- load_data/__pycache__/data_collactor.cpython-310.pyc +0 -0
- load_data/__pycache__/data_collactor.cpython-38.pyc +0 -0
- load_data/__pycache__/dataset.cpython-38.pyc +0 -0
- load_data/__pycache__/extract_fbanks.cpython-310.pyc +0 -0
- load_data/__pycache__/extract_fbanks.cpython-38.pyc +0 -0
- load_data/__pycache__/prepare_dataloader.cpython-310.pyc +0 -0
- load_data/__pycache__/prepare_dataloader.cpython-38.pyc +0 -0
- load_data/__pycache__/tears.cpython-38.pyc +0 -0
- load_data/__pycache__/timit.cpython-38.pyc +0 -0
- load_data/__pycache__/voxceleb.cpython-38.pyc +0 -0
- load_data/combineddataset.py +29 -0
- load_data/data_collactor.py +74 -0
- load_data/dataset.py +109 -0
- load_data/extract_fbanks.py +55 -0
- load_data/prepare_dataloader.py +22 -0
- load_data/tears.py +232 -0
- load_data/timit.py +102 -0
- load_data/voxceleb.py +63 -0
- mapper.py +245 -0
- pytorch_model.bin +3 -0
- wrapper.py +305 -0
config.json
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"model_type": "colmbo",
|
| 3 |
+
"architectures": [
|
| 4 |
+
"CoLMboModel"
|
| 5 |
+
],
|
| 6 |
+
"auto_map": {
|
| 7 |
+
"AutoConfig": "modeling_colmbo.CoLMboConfig",
|
| 8 |
+
"AutoModel": "modeling_colmbo.CoLMboModel"
|
| 9 |
+
},
|
| 10 |
+
"n_mels": 80,
|
| 11 |
+
"embedding_dim": 192,
|
| 12 |
+
"channel": 1024,
|
| 13 |
+
"prefix_length": 10,
|
| 14 |
+
"gpt_model_name": "gpt2",
|
| 15 |
+
"sample_rate": 16000,
|
| 16 |
+
"torch_dtype": "float32"
|
| 17 |
+
}
|
encoder/__pycache__/attentive_pooling.cpython-310.pyc
ADDED
|
Binary file (1.54 kB). View file
|
|
|
encoder/__pycache__/attentive_pooling.cpython-38.pyc
ADDED
|
Binary file (1.53 kB). View file
|
|
|
encoder/__pycache__/encoder.cpython-310.pyc
ADDED
|
Binary file (1.64 kB). View file
|
|
|
encoder/__pycache__/encoder.cpython-38.pyc
ADDED
|
Binary file (1.66 kB). View file
|
|
|
encoder/__pycache__/encoder.cpython-39.pyc
ADDED
|
Binary file (1.63 kB). View file
|
|
|
encoder/__pycache__/mha.cpython-310.pyc
ADDED
|
Binary file (2.21 kB). View file
|
|
|
encoder/__pycache__/mha.cpython-38.pyc
ADDED
|
Binary file (2.22 kB). View file
|
|
|
encoder/__pycache__/self_attn.cpython-310.pyc
ADDED
|
Binary file (3.71 kB). View file
|
|
|
encoder/__pycache__/self_attn.cpython-38.pyc
ADDED
|
Binary file (3.74 kB). View file
|
|
|
encoder/attentive_pooling.py
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
|
| 4 |
+
class SelfAttentionPooling(nn.Module):
    """
    Self-attention pooling over frame-level representations.

    Produces an attention-weighted mean and the matching attention-weighted
    standard deviation of the input sequence.

    Original Paper: Self-Attention Encoding and Pooling for Speaker Recognition
    https://arxiv.org/pdf/2008.01077v1.pdf
    """
    def __init__(self, input_dim):
        super(SelfAttentionPooling, self).__init__()
        # Projects each frame to a single (unnormalized) attention logit.
        self.W = nn.Linear(input_dim, 1)

    def forward(self, batch_rep, att_mask):
        """
        Args:
            batch_rep: (N, T, H) — N: batch size, T: sequence length, H: hidden dim.
            att_mask: additive attention mask; only the first column
                (att_mask[:, :, 0]) is used, one value per frame. Padded
                positions are expected to carry a large negative value so
                they vanish after the softmax — TODO confirm against caller.

        Returns:
            utter_rep: (N, H) attention-weighted mean.
            attn_out_std: (N, H) attention-weighted standard deviation.
        """
        softmax = nn.functional.softmax
        att_logits = self.W(batch_rep).squeeze(-1)
        # Reduce the (N, T, T) mask to one additive value per frame.
        att_mask = att_mask[:, :, 0]
        att_logits = att_mask + att_logits
        att_w = softmax(att_logits, dim=-1).unsqueeze(-1)
        # Weighted mean over the time axis.
        utter_rep = torch.sum(batch_rep * att_w, dim=1)
        # Weighted standard deviation around the weighted mean.
        attn_out_std = torch.sqrt(torch.sum(att_w * (batch_rep - utter_rep.unsqueeze(1))**2, dim=1))

        return utter_rep, attn_out_std
|
encoder/encoder.py
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from speechbrain.lobes.models.ECAPA_TDNN import ECAPA_TDNN
|
| 3 |
+
|
| 4 |
+
class Model(torch.nn.Module):
    """Thin wrapper around SpeechBrain's ECAPA-TDNN speaker encoder.

    Maps a batch of log-mel features to fixed-size speaker embeddings.
    """

    def __init__(self, n_mels=80, embedding_dim=192, channel=512):
        super(Model, self).__init__()
        # ECAPA-TDNN takes five channel widths; the last is 3x the base width.
        layer_channels = [channel] * 4 + [channel * 3]
        self.model = ECAPA_TDNN(
            input_size=n_mels,
            lin_neurons=embedding_dim,
            channels=layer_channels,
        )

    def forward(self, x):
        """x: (B, 1, T, n_mels) -> (B, embedding_dim)."""
        features = x.squeeze(1)
        embedding = self.model(features)
        return embedding.squeeze(1)
|
| 16 |
+
|
| 17 |
+
if __name__ == '__main__':
    # Smoke test: load pretrained ECAPA weights into the wrapper and run a dummy batch.
    model = Model(n_mels=80, embedding_dim=192, channel=1024)

    # Load the pretrained model checkpoint on CPU so this also works on
    # GPU-less hosts (the original torch.load would fail if the checkpoint
    # was saved on CUDA and no GPU is present).
    checkpoint = torch.load(
        "/ocean/projects/cis220031p/abdulhan/AVIS_baseline/ECAPA/pretrained_models/spkrec-ecapa-voxceleb/embedding_model.ckpt",
        map_location="cpu",
    )

    # The checkpoint holds the bare ECAPA_TDNN state dict; prefix keys with
    # 'model.' to match this wrapper's attribute name.
    new_state_dict = {f"model.{k}": v for k, v in checkpoint.items()}
    model.load_state_dict(new_state_dict)

    model.eval()

    # Test with dummy input (B, 1, T, n_mels)
    dummy_input = torch.randn(1, 1, 300, 80)
    output = model(dummy_input)
    print(output.shape)
|
encoder/mha.py
ADDED
|
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import math
|
| 4 |
+
|
| 5 |
+
class MultiHeadAttention(nn.Module):
    """Multi-head attention with an optional phoneme-probability bias.

    When both ``prob_phn`` and a positive ``lambda_val`` are supplied, the
    attention scores are penalized by ``lambda_val * prob_phn`` (broadcast
    across heads) before the softmax.
    """

    def __init__(self, d_model, num_heads):
        super(MultiHeadAttention, self).__init__()
        # Ensure that the model dimension (d_model) is divisible by the number of heads
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"

        self.d_model = d_model        # model dimension
        self.num_heads = num_heads    # number of attention heads
        self.d_k = d_model // num_heads  # per-head key/query/value width

        # Linear projections for queries, keys, values and output.
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

    def scaled_dot_product_attention(self, Q, K, V, prob_phn=None, mask=None, lambda_val=None):
        """Scaled dot-product attention over pre-split heads.

        Q, K, V: (N, heads, T, d_k). Returns (output, original mask).
        """
        attn_scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
        # Bug fix: the declared defaults prob_phn=None / lambda_val=None used
        # to crash (None.unsqueeze / None > 0). Only apply the phoneme bias
        # when both are actually provided and lambda_val is positive.
        if prob_phn is not None and lambda_val is not None and lambda_val > 0:
            # (N, T, 1) -> (N, 1, T, 1), then a broadcast view across heads
            # (expand does not copy memory).
            prob_phn = prob_phn.unsqueeze(1).expand(-1, self.num_heads, -1, -1)
            attn_scores = attn_scores - lambda_val * prob_phn.transpose(-2, -1)
        attn_mask = mask
        if mask is not None:
            # Broadcast the mask across heads and block masked positions.
            mask = mask.unsqueeze(1)
            mask = mask.expand(-1, self.num_heads, -1, -1)
            attn_scores = attn_scores.masked_fill(mask == 0, -1e9)
        attn_probs = torch.softmax(attn_scores, dim=-1)
        attn_probs = attn_probs.float()
        output = torch.matmul(attn_probs, V)
        return output, attn_mask

    def split_heads(self, x):
        # (N, T, d_model) -> (N, heads, T, d_k)
        batch_size, seq_length, d_model = x.size()
        return x.view(batch_size, seq_length, self.num_heads, self.d_k).transpose(1, 2)

    def combine_heads(self, x):
        # (N, heads, T, d_k) -> (N, T, d_model)
        batch_size, _, seq_length, d_k = x.size()
        return x.transpose(1, 2).contiguous().view(batch_size, seq_length, self.d_model)

    def forward(self, Q, K, V, prob_phn=None, mask=None, lambda_val=None):
        """Returns (attended output (N, T, d_model), the mask passed in)."""
        # Apply linear transformations and split heads
        Q = self.split_heads(self.W_q(Q))
        K = self.split_heads(self.W_k(K))
        V = self.split_heads(self.W_v(V))

        # Perform scaled dot-product attention
        attn_output, attn_mask = self.scaled_dot_product_attention(Q, K, V, prob_phn, mask, lambda_val)

        # Combine heads and apply output transformation
        output = self.W_o(self.combine_heads(attn_output))
        return output, attn_mask
|
encoder/self_attn.py
ADDED
|
@@ -0,0 +1,81 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
from encoder.mha import MultiHeadAttention
|
| 4 |
+
from encoder.attentive_pooling import SelfAttentionPooling
|
| 5 |
+
|
| 6 |
+
class FlippedReLU(nn.Module):
    """Mirror image of ReLU: passes negatives through, zeroes everything else.

    Elementwise equivalent to min(x, 0).
    """

    def __init__(self):
        super(FlippedReLU, self).__init__()

    def forward(self, x):
        # clamp from above at zero: negatives survive, non-negatives become 0.
        return torch.clamp(x, max=0)
|
| 12 |
+
|
| 13 |
+
class PositionWiseFeedForward(nn.Module):
    """Two-layer position-wise MLP (d_model -> d_ff -> d_model) with ReLU."""

    def __init__(self, d_model, d_ff):
        super(PositionWiseFeedForward, self).__init__()
        self.fc1 = nn.Linear(d_model, d_ff)
        self.fc2 = nn.Linear(d_ff, d_model)
        self.relu = nn.ReLU()

    def forward(self, x):
        # Expand, apply the nonlinearity, then project back down.
        hidden = self.relu(self.fc1(x))
        return self.fc2(hidden)
|
| 22 |
+
|
| 23 |
+
class EncoderLayer(nn.Module):
    """Single post-norm transformer encoder layer.

    Self-attention followed by a position-wise feed-forward network, each
    wrapped in a residual connection with dropout and LayerNorm.
    """

    def __init__(self, d_model, num_heads, d_ff, dropout):
        super(EncoderLayer, self).__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads)
        self.feed_forward = PositionWiseFeedForward(d_model, d_ff)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, prob_phn=None, mask=None, lambda_val=None):
        # Attention sub-layer: residual + dropout + LayerNorm.
        attn_output, attn_mask = self.self_attn(
            x, x, x, prob_phn=prob_phn, mask=mask, lambda_val=lambda_val
        )
        x = self.norm1(x + self.dropout(attn_output))
        # Feed-forward sub-layer: residual + dropout + LayerNorm.
        x = self.norm2(x + self.dropout(self.feed_forward(x)))
        return x, attn_mask
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
class TransformerSelfAttention(nn.Module):
    def __init__(self, input_dim, num_heads, dim_feedforward, number_Of_spks, dropout=0.0):
        """EncoderBlock.

        Args:
            input_dim: Dimensionality of the input
            num_heads: Number of heads to use in the attention block
            dim_feedforward: Dimensionality of the hidden layer in the MLP
            number_Of_spks: Number of speaker classes for the classifier head
            dropout: Dropout probability to use in the dropout layers
        """
        super().__init__()
        # Attention layer
        self.self_mha_attn = EncoderLayer(input_dim, num_heads, dim_feedforward*8,dropout)
        self.attn_pooling = SelfAttentionPooling(input_dim)
        # Two parallel projections over the pooled [mean; std] statistics.
        self.emb1 = nn.Linear(input_dim*2, dim_feedforward*8)
        self.emb2 = nn.Linear(input_dim*2, dim_feedforward*8)
        # emb2 is initialized as an exact copy of emb1's parameters; the two
        # branches differ only in activation (ReLU vs FlippedReLU) at init.
        self.emb2.weight.data = self.emb1.weight.data.clone()
        self.emb2.bias.data = self.emb1.bias.data.clone()
        self.bn = nn.BatchNorm1d(dim_feedforward*8)
        self.act = nn.ReLU(inplace=True)
        self.dropout = nn.Dropout(dropout)
        self.classifier = nn.Linear(dim_feedforward*8, number_Of_spks)
        self.flipped_relu = FlippedReLU()


    def forward(self, x, prob_phn=None, mask=None, lambda_val=None):
        # Attention part
        attn_out, attn_mask = self.self_mha_attn(x,prob_phn=prob_phn, mask=mask, lambda_val=lambda_val)
        # NOTE(review): attn_mask must not be None here (squeeze would fail);
        # callers are expected to always pass a mask — confirm.
        attn_mask= attn_mask.squeeze(1)
        attn_out_mean,attn_out_std = self.attn_pooling(attn_out,attn_mask)
        # Statistics pooling: concatenate weighted mean and std -> (N, 2H).
        attn_concat = torch.cat((attn_out_mean, attn_out_std),dim=1).to(dtype=torch.float32)

        # Positive branch: keeps the non-negative part of the activation.
        emb1 = self.emb1(attn_concat).to(dtype=torch.float32)
        emb1 = self.act(emb1)

        # Negative branch: keeps the non-positive part.
        emb2 = self.emb2(attn_concat).to(dtype=torch.float32)
        emb2 = self.flipped_relu(emb2)

        # Sum recombines both halves of the activation range before batch norm.
        emb = emb1 + emb2
        emb = self.bn(emb)
        x = self.classifier(emb)
        # Returns (speaker logits, pooled embedding).
        return x,emb
|
load_data/__pycache__/combineddataset.cpython-38.pyc
ADDED
|
Binary file (1.41 kB). View file
|
|
|
load_data/__pycache__/data_collactor.cpython-310.pyc
ADDED
|
Binary file (4.31 kB). View file
|
|
|
load_data/__pycache__/data_collactor.cpython-38.pyc
ADDED
|
Binary file (4.32 kB). View file
|
|
|
load_data/__pycache__/dataset.cpython-38.pyc
ADDED
|
Binary file (2.85 kB). View file
|
|
|
load_data/__pycache__/extract_fbanks.cpython-310.pyc
ADDED
|
Binary file (2.39 kB). View file
|
|
|
load_data/__pycache__/extract_fbanks.cpython-38.pyc
ADDED
|
Binary file (2.44 kB). View file
|
|
|
load_data/__pycache__/prepare_dataloader.cpython-310.pyc
ADDED
|
Binary file (855 Bytes). View file
|
|
|
load_data/__pycache__/prepare_dataloader.cpython-38.pyc
ADDED
|
Binary file (851 Bytes). View file
|
|
|
load_data/__pycache__/tears.cpython-38.pyc
ADDED
|
Binary file (6.77 kB). View file
|
|
|
load_data/__pycache__/timit.cpython-38.pyc
ADDED
|
Binary file (3.15 kB). View file
|
|
|
load_data/__pycache__/voxceleb.cpython-38.pyc
ADDED
|
Binary file (1.82 kB). View file
|
|
|
load_data/combineddataset.py
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import random
|
| 3 |
+
from torch.utils.data import Dataset, DataLoader
|
| 4 |
+
|
| 5 |
+
class CombinedDataset(Dataset):
    """Mixes two datasets, drawing each sample from one of them at random.

    Args:
        dataset1 (Dataset): first dataset (e.g., TIMITDataset).
        dataset2 (Dataset): second dataset (e.g., EARS).
        switch_prob (float): probability of drawing from dataset1 (default 0.5).
    """

    def __init__(self, dataset1, dataset2, switch_prob=0.5):
        self.dataset1 = dataset1
        self.dataset2 = dataset2
        self.len1 = len(dataset1)
        self.len2 = len(dataset2)
        self.switch_prob = switch_prob

    def __len__(self):
        # Report the longer dataset's length so all of it can be reached.
        return max(self.len1, self.len2)

    def __getitem__(self, idx):
        # Coin flip picks the source; the index wraps around the shorter set.
        use_first = random.random() < self.switch_prob
        if use_first:
            source, length = self.dataset1, self.len1
        else:
            source, length = self.dataset2, self.len2
        return source[idx % length]
|
load_data/data_collactor.py
ADDED
|
@@ -0,0 +1,74 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
from transformers import AutoFeatureExtractor
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Union
from preprocessing.ast_processor import ast
from util_stats.local_stats import local_extract_phn_frame_probs
from util_stats.global_stats import global_extract_phn_frame_probs
import numpy as np
import pickle
import torch.nn.functional as F

from load_data.extract_fbanks import Mel_Spectrogram

# Shared mel-spectrogram front end used by DataCollatorWithPadding below.
extractor = Mel_Spectrogram()

# NOTE(review): these pickles are loaded at import time from the current
# working directory — importing this module fails if the files are missing;
# consider lazy loading or configurable paths.
with open('new_lbl2ind.pkl', 'rb') as f:
    lbl2ind = pickle.load(f)
with open('new_spk.pkl', 'rb') as f:
    unique_speaker_ids = pickle.load(f)
# change the labels
# Number of speaker classes; sizes the one-hot labels in the collator.
number_Of_spks = len(unique_speaker_ids)
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
@dataclass
class DataCollatorWithPadding:
    """
    Data collator that will dynamically pad the inputs received.
    Args:
        processor (:class:`~transformers.Wav2Vec2Processor`)
            The processor used for proccessing the data.
        padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`):
            Select a strategy to pad the returned sequences (according to the model's padding side and padding index)
            among:
            * :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
              sequence if provided).
            * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the
              maximum acceptable input length for the model if that argument is not provided.
            * :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of
              different lengths).
        max_length (:obj:`int`, `optional`):
            Maximum length of the ``input_values`` of the returned list and optionally padding length (see above).
        max_length_labels (:obj:`int`, `optional`):
            Maximum length of the ``labels`` returned list and optionally padding length (see above).
        pad_to_multiple_of (:obj:`int`, `optional`):
            If set will pad the sequence to a multiple of the provided value.
            This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >=
            7.5 (Volta).
    """

    padding: Union[bool, str] = True
    max_length: Optional[int] = None
    pad_to_multiple_of: Optional[int] = None
    pad_to_multiple_of_labels: Optional[int] = None
    flag_global_local: Optional[str] = None
    dic_train_phn_frequency: Optional [dict] = None
    dic_train_frame_frequency: Optional [dict] = None
    lbl2ind: Optional [dict] = None

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        # split inputs and labels since they have to be of different lengths and need
        # different padding methods
        batch={}
        batch['input_values']= [features[idx]['audio_tensor'].squeeze(0) for idx in range(len(features))]
        batch["prompt"] = [features[idx]["prompt"] for idx in range(len(features))]
        batch["answer"] = [features[idx]["answer"] for idx in range(len(features))]
        batch["filename"] = [features[idx]["filename"] for idx in range(len(features))]
        # batch["no_hot_encode"] = torch.tensor([lbl2ind[features[idx]['sid']] for idx in range(len(features))])
        # NOTE(review): speaker labels are hard-coded to class 0 (the lbl2ind
        # lookup above is commented out), so every sample gets the same
        # one-hot label — confirm this is intentional.
        batch["no_hot_encode"] = torch.tensor([0 for idx in range(len(features))])
        # if batch["no_hot_encode"].numel():
        batch["labels"]= F.one_hot(batch["no_hot_encode"], number_Of_spks)
        # torch.stack requires all clips in the batch to have equal length
        # (the datasets crop/pad to a fixed duration upstream).
        batch['input_values'] = extractor(torch.stack(batch['input_values']))
        return batch
|
load_data/dataset.py
ADDED
|
@@ -0,0 +1,109 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from glob import glob
|
| 3 |
+
import torchaudio
|
| 4 |
+
from torch.utils.data import Dataset
|
| 5 |
+
import pandas as pd
|
| 6 |
+
from PIL import Image
|
| 7 |
+
import pickle
|
| 8 |
+
from copy import deepcopy
|
| 9 |
+
from glob import glob
|
| 10 |
+
import random
|
| 11 |
+
from sklearn.model_selection import train_test_split
|
| 12 |
+
import json
|
| 13 |
+
import os
|
| 14 |
+
import numpy as np
|
| 15 |
+
import librosa
|
| 16 |
+
import torch
|
| 17 |
+
import soundfile as sf
|
| 18 |
+
import pandas as pd
|
| 19 |
+
import random
|
| 20 |
+
|
| 21 |
+
class EARS(Dataset):
    """
    EARS dataset for 10sec or less that 10sec segments.
    Returns:
        audio: torch.Tensor in (1,16000) or (1, <16000), audio waveform
        sid: str (p103), speaker id
        metadict: dict, metadata
        caption: str, caption
        alignment: list
    """
    def __init__(self, root, data_path, meta_path,utterance_path, prompts_path, sample_rate, train_mapper=False, split="train"):
        super().__init__()
        self.root = root

        # Segment list: entries with filename / start / end (sample offsets).
        with open(f"{data_path}", "r") as f:
            self.data = json.load(f)

        # Per-speaker metadata (loaded but not read in this class).
        with open(f"{meta_path}", "r") as f:
            self.meta = json.load(f)

        # Utterance transcriptions (loaded but not read in this class).
        with open(f"{utterance_path}", "r") as f:
            self.utterance = json.load(f)

        # Per-speaker (prompt, answer) pairs keyed by speaker id.
        with open(f"{prompts_path}", "r") as f:
            self.prompts = json.load(f)

        self.new_data = []
        if train_mapper:
            # Expand each audio segment into 10 randomly sampled Q/A pairs
            # for its speaker (the leading path component of the filename).
            for d in self.data:
                file_name = d["filename"]
                sid = file_name.split("/")[0]
                temp = random.sample(self.prompts[sid], 10)
                for qa in temp:
                    self.new_data.append({"filename": file_name,
                                          "start": d["start"],
                                          "end": d["end"],
                                          "prompt": qa[0],
                                          "answer": qa[1]})
        else:
            self.new_data = self.data
        if split == "train":
            random.shuffle(self.new_data)

        self.sample_rate = sample_rate

    def __len__(self):
        return len(self.new_data)

    def __getitem__(self, idx):
        entry = self.new_data[idx]
        filename = entry["filename"]
        # Speaker id is the first path component of the filename.
        sid = filename.split("/")[0]
        audio_path = os.path.join(self.root, filename)

        # Load audio
        audio, sample_rate = torchaudio.load(audio_path)
        start_sample, end_sample = entry["start"], entry["end"]

        # Resample if needed
        # NOTE(review): start/end offsets appear to be in the ORIGINAL file's
        # sample units but are applied AFTER resampling — if on-disk rate
        # differs from self.sample_rate the crop window is misplaced; confirm
        # all source files are already at self.sample_rate.
        if sample_rate != self.sample_rate:
            audio = torchaudio.transforms.Resample(sample_rate, self.sample_rate)(audio)

        # Compute duration in samples
        total_samples = end_sample - start_sample
        num_samples_3s = 3 * self.sample_rate  # 3 seconds worth of samples

        # Select a random 3s window within the available range
        if total_samples >= num_samples_3s:
            start_offset = random.randint(start_sample, end_sample - num_samples_3s)
            end_offset = start_offset + num_samples_3s
            audio = audio[:, start_offset:end_offset]
        else:
            # If less than 3s, take full segment and pad
            pad_size = num_samples_3s - total_samples
            audio = audio[:, start_sample:end_sample]
            audio = torch.nn.functional.pad(audio, (0, pad_size))

        # Normalize to zero mean / unit variance (epsilon guards silence).
        mean = torch.mean(audio)
        std = torch.std(audio)
        audio = (audio - mean) / (std + 1e-8)

        return {
            "audio_tensor": audio,
            "filename": filename,
            "sid": sid,
            "prompt": entry.get("prompt", None),
            "answer": entry.get("answer", None),
        }
|
load_data/extract_fbanks.py
ADDED
|
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import librosa
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn as nn
|
| 4 |
+
import torch.nn.functional as F
|
| 5 |
+
|
| 6 |
+
class PreEmphasis(torch.nn.Module):
    """First-order pre-emphasis filter: y[t] = x[t] - coef * x[t-1].

    Implemented as a 1-D convolution. The kernel is stored flipped because
    PyTorch's conv1d computes cross-correlation, not true convolution.
    """

    def __init__(self, coef: float = 0.97):
        super(PreEmphasis, self).__init__()
        self.coef = coef
        # Flipped FIR kernel, shaped (out_ch=1, in_ch=1, width=2).
        kernel = torch.FloatTensor([-self.coef, 1.]).unsqueeze(0).unsqueeze(0)
        self.register_buffer('flipped_filter', kernel)

    def forward(self, inputs: torch.tensor) -> torch.tensor:
        assert len(
            inputs.size()) == 2, 'The number of dimensions of inputs tensor must be 2!'
        # Reflect-pad one sample on the left so output length equals input length.
        padded = F.pad(inputs.unsqueeze(1), (1, 0), 'reflect')
        return F.conv1d(padded, self.flipped_filter).squeeze(1)
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
class Mel_Spectrogram(nn.Module):
    """Log-mel feature extractor.

    Pipeline: pre-emphasis -> STFT magnitude -> log -> mel filterbank ->
    instance norm; returns features shaped (B, 1, T, n_mels).

    NOTE(review): torch.log is applied to the magnitude spectrogram BEFORE
    the mel projection (log-then-mel, not the conventional mel-then-log).
    Presumably this matches how the shipped checkpoints were trained —
    confirm before "fixing" the order.
    """
    def __init__(self, sample_rate=16000, n_fft=512, win_length=400, hop=160, n_mels=80, coef=0.97, requires_grad=False):
        super(Mel_Spectrogram, self).__init__()
        self.n_fft = n_fft
        self.n_mels = n_mels
        self.win_length = win_length
        self.hop = hop

        self.pre_emphasis = PreEmphasis(coef)
        # Mel filterbank (n_mels, n_fft//2 + 1) from librosa.
        mel_basis = librosa.filters.mel(
            sr=sample_rate, n_fft=n_fft, n_mels=n_mels)
        # Stored as a Parameter so it can optionally be fine-tuned;
        # requires_grad=False (the default) keeps it fixed.
        self.mel_basis = nn.Parameter(
            torch.FloatTensor(mel_basis), requires_grad=requires_grad)
        self.instance_norm = nn.InstanceNorm1d(num_features=n_mels)
        window = torch.hamming_window(self.win_length)
        self.window = nn.Parameter(
            torch.FloatTensor(window), requires_grad=False)

    def forward(self, x):
        # x: (B, num_samples) raw waveforms.
        x = self.pre_emphasis(x)
        x = torch.stft(x, n_fft=self.n_fft, hop_length=self.hop,
                       window=self.window, win_length=self.win_length, return_complex=True)
        x = torch.abs(x)
        # Floor before the log to avoid log(0).
        x += 1e-9
        x = torch.log(x)
        x = torch.matmul(self.mel_basis, x)
        x = self.instance_norm(x)
        # (B, n_mels, T) -> (B, T, n_mels) -> (B, 1, T, n_mels)
        x = x.permute(0, 2, 1)
        x = x.unsqueeze(1)
        return x
|
load_data/prepare_dataloader.py
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from torch.utils.data import Dataset, DataLoader
|
| 2 |
+
from torch.utils.data.distributed import DistributedSampler
|
| 3 |
+
from preprocessing.ast_processor import ast
|
| 4 |
+
from load_data.data_collactor import DataCollatorWithPadding
|
| 5 |
+
|
| 6 |
+
def prepare_dataloader(dataset: Dataset, batch_size: int, valid_train_flag: str):
    """Build a distributed DataLoader for the given split.

    Args:
        dataset: the dataset to wrap.
        batch_size: per-process batch size.
        valid_train_flag: one of "train", "valid" or "test".

    Returns:
        A DataLoader using a DistributedSampler and the padding collator.

    Raises:
        ValueError: if valid_train_flag is not a recognized split name.
    """
    # All three splits use an identical collator (the original had three
    # duplicated branches). Validate the flag explicitly so a typo fails
    # loudly instead of raising an unbound-name NameError further down.
    if valid_train_flag not in ("train", "valid", "test"):
        raise ValueError(f"Unknown valid_train_flag: {valid_train_flag!r}")
    data_collator = DataCollatorWithPadding(padding=True)
    return DataLoader(
        dataset,
        batch_size=batch_size,
        pin_memory=True,
        shuffle=False,  # DistributedSampler owns shuffling/sharding
        sampler=DistributedSampler(dataset),
        collate_fn=data_collator,
    )
|
load_data/tears.py
ADDED
|
@@ -0,0 +1,232 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from torch.utils.data import Dataset
|
| 3 |
+
import json
|
| 4 |
+
import torchaudio
|
| 5 |
+
import os
|
| 6 |
+
from typing import Optional, Dict, Any, List, Tuple
|
| 7 |
+
import pandas as pd
|
| 8 |
+
import warnings
|
| 9 |
+
import random
|
| 10 |
+
from pathlib import Path
|
| 11 |
+
from collections import defaultdict
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class TEARSDataset(Dataset):
    """
    TEARS dataset yielding a fixed-length audio crop plus a randomly chosen
    (prompt, response) pair and the speaker metadata from a JSON manifest.

    Args:
        json_path (str): Path to the JSON file containing TEARS data.
        tears_root (str): Root directory containing TEARS audio files.
        sample_rate (int, optional): Target sample rate for audio. Defaults to 16000.
        duration (float, optional): Target duration in seconds. Defaults to 3.0.
        normalize_audio (bool, optional): Whether to normalize audio. Defaults to True.
        augment (bool, optional): Whether to apply random augmentations. Defaults to True.

    Each item is a dict containing:
        - 'audio_tensor': torch.Tensor of shape (1, int(duration * sample_rate))
        - 'sid': str, speaker identifier
        - 'metadata': dict containing speaker metadata
        - 'prompt': str or None, randomly selected prompt
        - 'answer': str or None, corresponding response (newlines stripped)
        - 'filename': str, path to the audio file
    """

    def __init__(
        self,
        json_path: str,
        tears_root: str,
        sample_rate: int = 16000,
        duration: float = 3.0,
        normalize_audio: bool = True,
        augment: bool = True
    ):
        super().__init__()

        # Load the whole JSON manifest once up front.
        with open(json_path, 'r') as f:
            self.data = json.load(f)

        self.tears_root = Path(tears_root)
        self.sample_rate = sample_rate
        self.duration = duration
        self.normalize_audio = normalize_audio
        # Number of samples in one target-length clip.
        self.target_samples = int(duration * sample_rate)
        self.augment = augment

    def __len__(self) -> int:
        return len(self.data)

    def augment_audio(self, waveform, sample_rate):
        """Apply 1..N randomly chosen augmentations to `waveform`.

        Fixes vs. the original implementation:
        - the choice list now only contains augmentations that actually have a
          branch ('spec_aug' was listed but never implemented, while 'reverb'
          was implemented but unreachable; 'frequency_mask'/'time_mask'
          referenced an undefined name `T` and a duplicate 'pitch_shift'
          branch was dead code);
        - 'pitch_shift' previously built its sox effect via a nested
          comprehension that discarded part of its own randomness and omitted
          the required trailing 'rate' effect;
        - the hard-coded 16000 is replaced by the `sample_rate` argument.
        """
        augmentation_choices = ['time_stretch', 'pitch_shift', 'add_noise', 'reverb']
        random.shuffle(augmentation_choices)

        # Apply a random-sized prefix of the shuffled augmentation list.
        for aug in augmentation_choices[:random.randint(1, len(augmentation_choices))]:
            if aug == 'time_stretch':
                rate = random.uniform(0.8, 1.25)
                # 'speed' changes rate and pitch together; 'rate' resamples back.
                effect = [['speed', str(rate)], ['rate', str(sample_rate)]]
                waveform, _ = torchaudio.sox_effects.apply_effects_tensor(
                    waveform, sample_rate, effects=effect
                )

            elif aug == 'pitch_shift':
                # sox 'pitch' takes cents: 100 cents per semitone.
                steps = random.choice([-2, -1, 1, 2])
                effect = [['pitch', str(steps * 100)], ['rate', str(sample_rate)]]
                waveform, _ = torchaudio.sox_effects.apply_effects_tensor(
                    waveform, sample_rate, effect
                )

            elif aug == 'add_noise':
                # Low-amplitude white noise.
                noise = torch.randn_like(waveform) * random.uniform(0.001, 0.015)
                waveform = waveform + noise

            elif aug == 'reverb':
                effect = [['reverb', '-w', str(random.randint(10, 50))]]
                waveform, _ = torchaudio.sox_effects.apply_effects_tensor(
                    waveform, sample_rate, effect
                )

        return waveform

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        sample = self.data[idx]
        audio_path = str(self.tears_root / sample['audio_path'])

        try:
            audio, sr = torchaudio.load(audio_path)

            # Resample to the configured rate if necessary.
            if sr != self.sample_rate:
                audio = torchaudio.transforms.Resample(sr, self.sample_rate)(audio)

            if self.augment:
                audio = self.augment_audio(audio, self.sample_rate)

            # Per-utterance mean/variance normalization.
            if self.normalize_audio:
                mean = torch.mean(audio)
                std = torch.std(audio)
                audio = (audio - mean) / (std + 1e-8)

            # Random crop to the target length, or right-pad with zeros.
            num_samples = audio.shape[1]
            if num_samples >= self.target_samples:
                start_sample = random.randint(0, num_samples - self.target_samples)
                audio = audio[:, start_sample:start_sample + self.target_samples]
            else:
                pad_size = self.target_samples - num_samples
                audio = torch.nn.functional.pad(audio, (0, pad_size))

        except Exception as e:
            warnings.warn(f"Error loading audio file {audio_path}: {str(e)}")
            # Fall back to a silent clip so one bad file cannot kill a run.
            audio = torch.zeros(1, self.target_samples)

        # Pick one aligned (prompt, response) pair at random.
        prompts = sample.get('prompts', [])
        responses = sample.get('responses', [])

        if prompts and responses and len(prompts) == len(responses):
            rand_idx = random.randint(0, len(prompts) - 1)
            prompt = prompts[rand_idx]
            response = responses[rand_idx].replace("\n", " ").strip()
        else:
            prompt = None
            response = None

        return {
            'audio_tensor': audio,
            'sid': sample['speaker']['id'],
            'metadata': sample['speaker'],
            'prompt': prompt,
            'answer': response,
            'filename': str(audio_path)
        }

    @staticmethod
    def redistribute_speakers(
        json_paths: Dict[str, str],
        split_ratios: Dict[str, float],
        seed: int = 42
    ) -> Dict[str, List[Dict]]:
        """
        Redistribute speakers across splits according to given ratios.

        Args:
            json_paths: Dict mapping split names to json file paths.
            split_ratios: Dict mapping split names to desired ratios (should sum to 1).
            seed: Random seed for reproducibility.

        Returns:
            Dict mapping split names to lists of samples; all samples of a
            given speaker land in the same split.
        """
        random.seed(seed)

        # Group every sample by its speaker id across all input splits.
        speaker_samples = defaultdict(list)
        for split, path in json_paths.items():
            with open(path, 'r') as f:
                data = json.load(f)
            for sample in data:
                speaker_samples[sample['speaker']['id']].append(sample)

        all_speakers = list(speaker_samples.keys())
        random.shuffle(all_speakers)

        # Number of speakers per split, truncating fractional counts.
        total_speakers = len(all_speakers)
        split_speakers = {
            split: int(ratio * total_speakers)
            for split, ratio in split_ratios.items()
        }

        # int() truncation can leave speakers unassigned; give them to the first split.
        remainder = total_speakers - sum(split_speakers.values())
        if remainder > 0:
            split_speakers[list(split_speakers.keys())[0]] += remainder

        # Hand out contiguous runs of shuffled speakers to each split.
        new_splits = defaultdict(list)
        current_idx = 0
        for split, num_speakers in split_speakers.items():
            for speaker_id in all_speakers[current_idx:current_idx + num_speakers]:
                new_splits[split].extend(speaker_samples[speaker_id])
            current_idx += num_speakers

        return new_splits

    @staticmethod
    def save_splits(splits: Dict[str, List[Dict]], output_dir: str):
        """Save redistributed splits to JSON files under `output_dir`."""
        output_dir = Path(output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)

        for split_name, samples in splits.items():
            output_path = output_dir / f"tears_dataset_{split_name}_with_responses.json"
            with open(output_path, 'w') as f:
                json.dump(samples, f, indent=2)
|
| 232 |
+
|
load_data/timit.py
ADDED
|
@@ -0,0 +1,102 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from torch.utils.data import Dataset
|
| 3 |
+
import json
|
| 4 |
+
import torchaudio
|
| 5 |
+
import os
|
| 6 |
+
from typing import Optional, Dict, Any, List, Tuple
|
| 7 |
+
import pandas as pd
|
| 8 |
+
import warnings
|
| 9 |
+
import random
|
| 10 |
+
|
| 11 |
+
class TIMITDataset(Dataset):
    """
    TIMIT dataset yielding a 3-second audio crop plus a randomly chosen
    (prompt, response) pair for the utterance's speaker.

    Args:
        json_path (str): Path to the JSON file containing TIMIT data.
        timit_root (str): Root directory containing TIMIT audio files.
        sample_rate (int, optional): Target sample rate for audio. Defaults to 16000.
        normalize_audio (bool, optional): Whether to apply mean/variance
            normalization. Defaults to True.

    Each item is a dict containing (the original docstring also advertised
    'metadata', 'phonemes', 'words' and 'text' keys that were never returned):
        - 'audio_tensor': torch.Tensor of shape (1, 3 * sample_rate)
        - 'sid': str, speaker identifier
        - 'prompt': str or None, randomly selected prompt
        - 'answer': str or None, corresponding response (newlines stripped)
        - 'filename': str, path to the audio file
    """

    def __init__(
        self,
        json_path: str,
        timit_root: str,
        sample_rate: int = 16000,
        normalize_audio: bool = True
    ):
        super().__init__()

        # Load the whole JSON manifest once up front.
        with open(json_path, 'r') as f:
            self.data = json.load(f)

        self.timit_root = timit_root
        self.sample_rate = sample_rate
        self.normalize_audio = normalize_audio

    def __len__(self) -> int:
        return len(self.data)

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        sample = self.data[idx]
        audio_path = os.path.join(self.timit_root, sample['audio_path'])

        # Load audio first, then resample to the configured rate if needed.
        audio, sr = torchaudio.load(audio_path)
        if sr != self.sample_rate:
            audio = torchaudio.transforms.Resample(sr, self.sample_rate)(audio)

        # Fix: the flag was stored but normalization previously ran
        # unconditionally, making `normalize_audio=False` a no-op.
        if self.normalize_audio:
            mean = torch.mean(audio)
            std = torch.std(audio)
            audio = (audio - mean) / (std + 1e-8)

        num_samples = audio.shape[1]
        num_samples_3s = 3 * self.sample_rate  # samples in a 3-second window

        # Random 3-second crop, or right-pad with zeros if the clip is shorter.
        if num_samples >= num_samples_3s:
            start_sample = random.randint(0, num_samples - num_samples_3s)
            end_sample = start_sample + num_samples_3s
            audio = audio[:, start_sample:end_sample]
        else:
            pad_size = num_samples_3s - num_samples
            audio = torch.nn.functional.pad(audio, (0, pad_size))

        # Pick one aligned (prompt, response) pair at random.
        prompts = sample.get('prompts', [])
        answers = sample.get('responses', [])

        if prompts and answers and len(prompts) == len(answers):
            rand_idx = random.randint(0, len(prompts) - 1)
            prompt = prompts[rand_idx]
            answer = answers[rand_idx].replace("\n", " ").strip()  # clean response
        else:
            prompt = None
            answer = None

        return {
            'audio_tensor': audio,
            'sid': sample['speaker']['id'],
            'prompt': prompt,
            'answer': answer,
            'filename': audio_path,
        }
|
load_data/voxceleb.py
ADDED
|
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from torch.utils.data import Dataset
|
| 3 |
+
import pandas as pd
|
| 4 |
+
import torchaudio
|
| 5 |
+
import random
|
| 6 |
+
import os
|
| 7 |
+
|
| 8 |
+
class ZeroShotDataset(Dataset):
    """Zero-shot evaluation dataset backed by a CSV with File_Path, Prompt
    and Ground_Truth columns.
    """

    # Original hard-coded cluster path, kept as the default so existing
    # callers are unaffected.
    DEFAULT_ROOT = "/ocean/projects/cis220031p/psamal/preprocess_TIMIT/"

    def __init__(self, csv_path, transform=None, root=None):
        """
        Args:
            csv_path (str): Path to the CSV file.
            transform (callable, optional): Optional transform applied to the audio.
            root (str, optional): Root directory prepended to File_Path values.
                Defaults to the previously hard-coded cluster path.
        """
        self.data = pd.read_csv(csv_path)
        self.transform = transform
        self.sample_rate = 16000
        # Fix: the root was hard-coded inside __getitem__; now configurable.
        self.root = root if root is not None else self.DEFAULT_ROOT

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        row = self.data.iloc[idx]

        # Load audio file
        audio, sr = torchaudio.load(os.path.join(self.root, row["File_Path"]))

        # Apply transformation if provided (note: runs on the raw-rate audio,
        # before resampling — preserved from the original ordering).
        if self.transform:
            audio = self.transform(audio)

        # Resample to the target rate if needed.
        if sr != self.sample_rate:
            audio = torchaudio.transforms.Resample(sr, self.sample_rate)(audio)

        # Per-utterance mean/variance normalization.
        mean = torch.mean(audio)
        std = torch.std(audio)
        audio = (audio - mean) / (std + 1e-8)

        num_samples = audio.shape[1]
        num_samples_3s = 3 * self.sample_rate  # samples in a 3-second window

        # Random 3-second crop, or right-pad with zeros if shorter.
        if num_samples >= num_samples_3s:
            start_sample = random.randint(0, num_samples - num_samples_3s)
            end_sample = start_sample + num_samples_3s
            audio = audio[:, start_sample:end_sample]
        else:
            pad_size = num_samples_3s - num_samples
            audio = torch.nn.functional.pad(audio, (0, pad_size))

        return {
            "sid": "WBT0",  # fixed speaker id used for zero-shot evaluation
            "audio_tensor": audio,
            "answer": row["Ground_Truth"],
            "prompt": row["Prompt"],
            # "prompt": random.choice(["What is the dialect of the person?", "Based on the voice of the person, please specify the dialect of the person?", row["Prompt"]]),
            'filename': row["File_Path"],
        }
|
mapper.py
ADDED
|
@@ -0,0 +1,245 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
from torch.nn import functional as nnf
|
| 6 |
+
from typing import Tuple, Optional
|
| 7 |
+
|
| 8 |
+
def get_sid_mapper(map_type: str, emb_size, prefix_size: int, gpt_embedding_size: int, prefix_length: int, clip_length: int, num_layers: int):
    """Build the speaker-id -> GPT-prefix mapper.

    'mlp' yields a 3-layer MLP expanding to prefix_length GPT embeddings;
    'transformer' yields a TransformerMapper with half the requested layers.
    All mapper parameters are marked trainable before returning.
    """
    if map_type == 'mlp':
        flat_out = gpt_embedding_size * prefix_length
        mapper = MLP(emb_size, (prefix_size, flat_out // 2, flat_out))
    elif map_type == 'transformer':
        mapper = TransformerMapper(emb_size, prefix_size, gpt_embedding_size,
                                   prefix_length, clip_length, int(num_layers / 2))
    else:
        raise ValueError(f"Unknown mapping type {map_type}")

    # Mapper weights are always trained, even if the decoder is frozen elsewhere.
    for param in mapper.parameters():
        param.requires_grad = True

    return mapper
|
| 23 |
+
|
| 24 |
+
def get_text_mapper(map_type: str, emb_size, prefix_size: int, gpt_embedding_size: int, prefix_length: int, clip_length: int, num_layers: int):
    """Build the text -> GPT-prefix mapper.

    'mlp' yields a 3-layer MLP expanding to prefix_length GPT embeddings;
    'transformer' yields a TransformerMapperSeq with half the requested layers.
    All mapper parameters are marked trainable before returning.
    """
    if map_type == 'mlp':
        flat_out = gpt_embedding_size * prefix_length
        mapper = MLP(emb_size, (prefix_size, flat_out // 2, flat_out))
    elif map_type == 'transformer':
        mapper = TransformerMapperSeq(emb_size, prefix_size, gpt_embedding_size,
                                      prefix_length, clip_length, int(num_layers / 2))
    else:
        raise ValueError(f"Unknown mapping type {map_type}")

    # Mapper weights are always trained, even if the decoder is frozen elsewhere.
    for param in mapper.parameters():
        param.requires_grad = True

    return mapper
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def init_layer(layer):
    """Xavier-initialize a Linear or Convolutional layer; zero its bias if present."""
    nn.init.xavier_uniform_(layer.weight)

    # Layers built with bias=False expose bias=None; skip those.
    bias = getattr(layer, 'bias', None)
    if bias is not None:
        bias.data.fill_(0.)
|
| 48 |
+
|
| 49 |
+
def init_bn(bn):
    """Initialize a Batchnorm layer to the identity affine transform."""
    # weight (scale) -> 1, bias (shift) -> 0; order of the two fills is irrelevant.
    bn.weight.data.fill_(1.)
    bn.bias.data.fill_(0.)
|
| 53 |
+
|
| 54 |
+
class Projection(nn.Module):
    """Two-layer projection head: linear, GELU-gated refinement with dropout,
    residual add, then LayerNorm.
    """

    def __init__(self, d_in: int, d_out: int, p: float = 0.5) -> None:
        super().__init__()
        self.linear1 = nn.Linear(d_in, d_out, bias=False)
        self.linear2 = nn.Linear(d_out, d_out, bias=False)
        self.layer_norm = nn.LayerNorm(d_out)
        self.drop = nn.Dropout(p)

        self.init_weight()

    def init_weight(self):
        """Xavier-init both linears; reset the LayerNorm affine params to identity."""
        for linear in (self.linear1, self.linear2):
            init_layer(linear)
        # init_bn works on any module with .weight/.bias, including LayerNorm.
        init_bn(self.layer_norm)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        projected = self.linear1(x)
        refined = self.drop(self.linear2(nnf.gelu(projected)))
        return self.layer_norm(projected + refined)
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
class MLP(nn.Module):
    """Plain feed-forward stack: Linear layers with an activation between all
    but the final pair (no activation after the output layer).
    """

    def __init__(self, emb_size, sizes: Tuple[int, ...], bias=True, act=nn.Tanh):
        super(MLP, self).__init__()
        # Kept for interface compatibility; not used by the current forward.
        self.emb_size = emb_size
        modules = []
        last_hidden = len(sizes) - 2
        for i, (d_in, d_out) in enumerate(zip(sizes[:-1], sizes[1:])):
            modules.append(nn.Linear(d_in, d_out, bias=bias))
            if i < last_hidden:
                modules.append(act())
        self.model = nn.Sequential(*modules)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.model(x)
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
class MlpTransformer(nn.Module):
    """Transformer feed-forward sub-block: Linear -> act -> dropout -> Linear -> dropout."""

    def __init__(self, in_dim, h_dim, out_d: Optional[int] = None, act=nnf.relu, dropout=0.):
        super().__init__()
        # Output width defaults to the input width (standard FFN shape).
        self.fc1 = nn.Linear(in_dim, h_dim)
        self.act = act
        self.fc2 = nn.Linear(h_dim, in_dim if out_d is None else out_d)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        hidden = self.dropout(self.act(self.fc1(x)))
        return self.dropout(self.fc2(hidden))
|
| 111 |
+
|
| 112 |
+
class MultiHeadAttention(nn.Module):
    """Multi-head (self- or cross-) attention with a fused key/value projection.

    `dim_self` is the query/output width; `dim_ref` is the width of the
    reference sequence being attended over (equal to `dim_self` for
    self-attention).
    """

    def __init__(self, dim_self, dim_ref, num_heads, bias=True, dropout=0.):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim_self // num_heads
        # 1/sqrt(head_dim) scaling applied to the attention logits.
        self.scale = head_dim ** -0.5
        self.to_queries = nn.Linear(dim_self, dim_self, bias=bias)
        # Single projection producing keys and values, stacked in the output dim.
        self.to_keys_values = nn.Linear(dim_ref, dim_self * 2, bias=bias)
        self.project = nn.Linear(dim_self, dim_self)
        # NOTE(review): this dropout is constructed but never applied in
        # forward() — confirm whether attention dropout was intended.
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, y=None, mask=None):
        """Attend from `x` (b, n, dim_self) to `y` (b, m, dim_ref).

        Returns:
            out: (b, n, dim_self) attended, projected context.
            attention: (b, n, m, heads) post-softmax attention weights.
        """
        # Self-attention when no reference sequence is given.
        y = y if y is not None else x
        b, n, c = x.shape
        _, m, d = y.shape
        # b n h dh
        queries = self.to_queries(x).reshape(b, n, self.num_heads, c // self.num_heads)
        # b m 2 h dh
        keys_values = self.to_keys_values(y).reshape(b, m, 2, self.num_heads, c // self.num_heads)
        keys, values = keys_values[:, :, 0], keys_values[:, :, 1]
        # Scaled dot-product logits per head: (b, n, m, h).
        attention = torch.einsum('bnhd,bmhd->bnmh', queries, keys) * self.scale
        if mask is not None:
            if mask.dim() == 2:
                # Broadcast a (b, m) key-padding mask over the query dimension.
                mask = mask.unsqueeze(1)
            # True positions in the mask receive -inf, i.e. zero attention weight.
            attention = attention.masked_fill(mask.unsqueeze(3), float("-inf"))
        # Normalize over the reference (m) dimension.
        attention = attention.softmax(dim=2)
        out = torch.einsum('bnmh,bmhd->bnhd', attention, values).reshape(b, n, c)
        out = self.project(out)
        return out, attention
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
class TransformerLayer(nn.Module):
    """Pre-norm transformer block: residual attention followed by a residual MLP."""

    def __init__(self, dim_self, dim_ref, num_heads, mlp_ratio=4., bias=False, dropout=0., act=nnf.relu,
                 norm_layer: nn.Module = nn.LayerNorm):
        super().__init__()
        self.norm1 = norm_layer(dim_self)
        self.attn = MultiHeadAttention(dim_self, dim_ref, num_heads, bias=bias, dropout=dropout)
        self.norm2 = norm_layer(dim_self)
        self.mlp = MlpTransformer(dim_self, int(dim_self * mlp_ratio), act=act, dropout=dropout)

    def forward(self, x, y=None, mask=None):
        attended, _ = self.attn(self.norm1(x), y, mask)
        x = x + attended
        return x + self.mlp(self.norm2(x))

    def forward_with_attention(self, x, y=None, mask=None):
        """Same as forward(), but also returns the attention weights."""
        attended, attention = self.attn(self.norm1(x), y, mask)
        x = x + attended
        x = x + self.mlp(self.norm2(x))
        return x, attention
|
| 164 |
+
|
| 165 |
+
|
| 166 |
+
class Transformer(nn.Module):
    """Stack of TransformerLayer blocks.

    With `enc_dec=True` the stack doubles in depth and alternates
    cross-attention (even indices, attending to `y`) with self-attention
    (odd indices). Otherwise every layer receives (x, y, mask) directly.
    """

    def __init__(self, dim_self: int, num_heads: int, num_layers: int, dim_ref: Optional[int] = None,
                 mlp_ratio: float = 2., act=nnf.relu, norm_layer: nn.Module = nn.LayerNorm, enc_dec: bool = False):
        super(Transformer, self).__init__()
        dim_ref = dim_ref if dim_ref is not None else dim_self
        self.enc_dec = enc_dec
        if enc_dec:
            # Each logical layer becomes a (cross, self) pair.
            num_layers = num_layers * 2
        blocks = []
        for i in range(num_layers):
            if enc_dec and i % 2 == 1:
                # Self-attention layer of an enc-dec pair.
                blocks.append(TransformerLayer(dim_self, dim_self, num_heads, mlp_ratio, act=act, norm_layer=norm_layer))
            else:
                # Cross-attention layer of an enc-dec pair, or any plain-stack layer.
                blocks.append(TransformerLayer(dim_self, dim_ref, num_heads, mlp_ratio, act=act, norm_layer=norm_layer))
        self.layers = nn.ModuleList(blocks)

    def forward(self, x, y=None, mask=None):
        for i, layer in enumerate(self.layers):
            if not self.enc_dec:
                x = layer(x, y, mask)
            elif i % 2 == 0:
                # Cross-attention step (original code passes no mask here).
                x = layer(x, y)
            else:
                x = layer(x, x, mask)
        return x

    def forward_with_attention(self, x, y=None, mask=None):
        """Run all layers, collecting each layer's attention weights."""
        attentions = []
        for layer in self.layers:
            x, att = layer.forward_with_attention(x, y, mask)
            attentions.append(att)
        return x, attentions
|
| 200 |
+
|
| 201 |
+
|
| 202 |
+
class TransformerMapper(nn.Module):
    """Map a single embedding to `prefix_length` GPT prefix tokens via a transformer.

    Fix: forward() referenced self.projector, but its construction had been
    commented out of __init__, so any non-None `emb_size` raised
    AttributeError. The Projection is now built whenever emb_size is given,
    matching the commented-out intent and the sibling mappers.
    """

    def __init__(self, emb_size, dim_clip: int, dim_embedding: int, prefix_length: int, clip_length: int, num_layers: int = 8):
        super(TransformerMapper, self).__init__()
        self.emb_size = emb_size
        if self.emb_size is not None:
            # Project the raw embedding into the expected clip dimension.
            self.projector = Projection(emb_size, dim_clip)
        self.clip_length = clip_length
        self.transformer = Transformer(dim_embedding, 8, num_layers)
        # Expands the (projected) embedding into clip_length transformer inputs.
        self.linear = nn.Linear(dim_clip, clip_length * dim_embedding)
        # Learned prefix queries whose transformed values become the output prefix.
        self.prefix_const = nn.Parameter(torch.randn(prefix_length, dim_embedding), requires_grad=True)

    def forward(self, x):
        """x: (batch, dim_clip) — or (batch, emb_size) when a projector is used.

        Returns (batch, prefix_length, dim_embedding).
        """
        if self.emb_size is not None:
            x = self.projector(x)
        x = self.linear(x).view(x.shape[0], self.clip_length, -1)
        prefix = self.prefix_const.unsqueeze(0).expand(x.shape[0], *self.prefix_const.shape)
        # Concatenate inputs and learned queries; keep only the query outputs.
        prefix = torch.cat((x, prefix), dim=1)  # batch x (clip_length + prefix_length) x dim
        out = self.transformer(prefix)[:, self.clip_length:]
        return out
|
| 224 |
+
|
| 225 |
+
class TransformerMapperSeq(nn.Module):
    """Map a sequence embedding to `prefix_length` GPT prefix tokens via a transformer."""

    def __init__(self, emb_size, dim_clip: int, dim_embedding: int, prefix_length: int, clip_length: int, num_layers: int = 8):
        super(TransformerMapperSeq, self).__init__()
        # Kept for interface compatibility; no projection is applied here.
        self.emb_size = emb_size
        self.clip_length = clip_length
        self.transformer = Transformer(dim_embedding, 8, num_layers)
        # Learned prefix queries whose transformed values become the output prefix.
        self.prefix_const = nn.Parameter(torch.randn(prefix_length, dim_embedding), requires_grad=True)

    def forward(self, x):
        """Reshape x to (batch, clip_length, -1), append the learned queries,
        run the transformer, and return only the query outputs:
        (batch, prefix_length, dim_embedding).
        """
        batch = x.shape[0]
        tokens = x.view(batch, self.clip_length, -1)
        queries = self.prefix_const.unsqueeze(0).expand(batch, *self.prefix_const.shape)
        stacked = torch.cat((tokens, queries), dim=1)
        return self.transformer(stacked)[:, self.clip_length:]
|
pytorch_model.bin
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:e0d80efbeffb56f4038bf9d320d15b5377d12b1cb85833e908d9f0f6b5c2bbab
|
| 3 |
+
size 2066033810
|
wrapper.py
ADDED
|
@@ -0,0 +1,305 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
from transformers import AutoTokenizer
|
| 3 |
+
import os
|
| 4 |
+
import torch
|
| 5 |
+
from collections import OrderedDict
|
| 6 |
+
import librosa
|
| 7 |
+
from importlib_resources import files
|
| 8 |
+
import yaml
|
| 9 |
+
import argparse
|
| 10 |
+
import torchaudio
|
| 11 |
+
import torchaudio.transforms as T
|
| 12 |
+
import collections
|
| 13 |
+
import random
|
| 14 |
+
import numpy as np
|
| 15 |
+
import torch
|
| 16 |
+
import torch.nn as nn
|
| 17 |
+
import torch.nn.functional as F
|
| 18 |
+
from torch.nn.parallel import DistributedDataParallel as DDP
|
| 19 |
+
import logging
|
| 20 |
+
from glob import glob
|
| 21 |
+
|
| 22 |
+
from mapper import get_sid_mapper, get_text_mapper
|
| 23 |
+
from transformers import GPT2LMHeadModel
|
| 24 |
+
from transformers import AutoTokenizer
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
class ExpWrapper():
    """Wrapper tying together the GPT-2 text decoder, the speaker-ID-to-prefix
    mapper and the tokenizer used for speaker-conditioned caption generation.

    Holds the models on a single GPU (``gpu_id``); DDP wrapping of the mapper
    is done separately via :meth:`init_mapper`.
    """

    def __init__(self, config_wrapper, gpu_id):
        """Build decoder, SID mapper and tokenizer from a config dict.

        Args:
            config_wrapper: dict with keys ``tok_len``, ``text_prefix_length``,
                ``sid_prefix_length``, ``norm_sid_emb``, ``text_decoder``
                (HF model name/path), ``map_type``, ``prefix_size``,
                ``sid_prefix_length_clip`` and ``num_layers``.
            gpu_id: device index (or device string) the models are moved to.
        """
        self.tok_len = config_wrapper['tok_len']
        self.text_prefix_length = config_wrapper['text_prefix_length']
        self.sid_prefix_length = config_wrapper['sid_prefix_length']
        self.norm_sid_emb = config_wrapper['norm_sid_emb']
        self.gpu_id = gpu_id
        # Pretrained GPT-2 LM head used as the caption decoder.
        self.gpt = GPT2LMHeadModel.from_pretrained(config_wrapper['text_decoder'])
        self.gpt = self.gpt.to(self.gpu_id)
        # for param in self.gpt.parameters():
        #     param.requires_grad = False

        # Width of GPT-2's token-embedding table; the mapper projects into it.
        self.gpt_embedding_size = self.gpt.transformer.wte.weight.shape[1]

        self.sid_mapper = get_sid_mapper(config_wrapper["map_type"],None,
                        config_wrapper["prefix_size"], self.gpt_embedding_size,
                        config_wrapper["sid_prefix_length"], config_wrapper["sid_prefix_length_clip"],
                        config_wrapper["num_layers"])


        # self.text_mapper = get_text_mapper(config_wrapper["map_type"], None,
        #                 config_wrapper["prefix_size"], self.gpt_embedding_size,
        #                 config_wrapper["text_prefix_length"], config_wrapper["text_prefix_length_clip"],
        #                 config_wrapper["num_layers"])
        # # this is temporary
        # if config_wrapper["checkpoint_path"]:
        #     checkpoint = torch.load(config_wrapper["checkpoint_path"])
        #     state_dict = checkpoint['model']
        #     text_project_weights = {k.replace('caption_decoder.text_project.',''): v for k, v in state_dict.items()
        #                     if 'caption_decoder.text_project' in k}
        #     self.text_mapper.load_state_dict(text_project_weights)

        self.sid_mapper = self.sid_mapper.to(self.gpu_id)
        # self.text_mapper = self.text_mapper.to(self.gpu_id)
        self.tokenizer = AutoTokenizer.from_pretrained(config_wrapper['text_decoder'])
        # GPT-2 has no pad token; '!' (token id 0) is repurposed as padding.
        self.tokenizer.add_special_tokens({'pad_token': '!'})
|
| 63 |
+
|
| 64 |
+
def init_mapper(self):
|
| 65 |
+
self.sid_mapper = DDP(self.sid_mapper, device_ids=[self.gpu_id], find_unused_parameters=True)
|
| 66 |
+
|
| 67 |
+
def freeze_llm(self):
|
| 68 |
+
for param in self.sid_mapper.parameters():
|
| 69 |
+
param.requires_grad = False
|
| 70 |
+
for param in self.gpt.parameters():
|
| 71 |
+
param.requires_grad = False
|
| 72 |
+
|
| 73 |
+
def default_collate(self, batch):
|
| 74 |
+
r"""Puts each data field into a tensor with outer dimension batch size"""
|
| 75 |
+
elem = batch[0]
|
| 76 |
+
elem_type = type(elem)
|
| 77 |
+
if isinstance(elem, torch.Tensor):
|
| 78 |
+
out = None
|
| 79 |
+
if torch.utils.data.get_worker_info() is not None:
|
| 80 |
+
# If we're in a background process, concatenate directly into a
|
| 81 |
+
# shared memory tensor to avoid an extra copy
|
| 82 |
+
numel = sum([x.numel() for x in batch])
|
| 83 |
+
storage = elem.storage()._new_shared(numel)
|
| 84 |
+
out = elem.new(storage)
|
| 85 |
+
return torch.stack(batch, 0, out=out)
|
| 86 |
+
elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
|
| 87 |
+
and elem_type.__name__ != 'string_':
|
| 88 |
+
if elem_type.__name__ == 'ndarray' or elem_type.__name__ == 'memmap':
|
| 89 |
+
# array of string classes and object
|
| 90 |
+
if self.np_str_obj_array_pattern.search(elem.dtype.str) is not None:
|
| 91 |
+
raise TypeError(
|
| 92 |
+
self.default_collate_err_msg_format.format(elem.dtype))
|
| 93 |
+
|
| 94 |
+
return self.default_collate([torch.as_tensor(b) for b in batch])
|
| 95 |
+
elif elem.shape == (): # scalars
|
| 96 |
+
return torch.as_tensor(batch)
|
| 97 |
+
elif isinstance(elem, float):
|
| 98 |
+
return torch.tensor(batch, dtype=torch.float64)
|
| 99 |
+
elif isinstance(elem, int):
|
| 100 |
+
return torch.tensor(batch)
|
| 101 |
+
elif isinstance(elem, collections.abc.Mapping):
|
| 102 |
+
return {key: self.default_collate([d[key] for d in batch]) for key in elem}
|
| 103 |
+
elif isinstance(elem, tuple) and hasattr(elem, '_fields'): # namedtuple
|
| 104 |
+
return elem_type(*(self.default_collate(samples) for samples in zip(*batch)))
|
| 105 |
+
elif isinstance(elem, collections.abc.Sequence):
|
| 106 |
+
# check to make sure that the elements in batch have consistent size
|
| 107 |
+
it = iter(batch)
|
| 108 |
+
elem_size = len(next(it))
|
| 109 |
+
if not all(len(elem) == elem_size for elem in it):
|
| 110 |
+
raise RuntimeError(
|
| 111 |
+
'each element in list of batch should be of equal size')
|
| 112 |
+
transposed = zip(*batch)
|
| 113 |
+
return [self.default_collate(samples) for samples in transposed]
|
| 114 |
+
|
| 115 |
+
raise TypeError(self.default_collate_err_msg_format.format(elem_type))
|
| 116 |
+
|
| 117 |
+
def load_model(self, st, model):
|
| 118 |
+
try:
|
| 119 |
+
model.load_state_dict(st)
|
| 120 |
+
except:
|
| 121 |
+
for key in list(st.keys()):
|
| 122 |
+
if "module." in key:
|
| 123 |
+
st[key.replace("module.", "")] = st.pop(key)
|
| 124 |
+
model.load_state_dict(st)
|
| 125 |
+
return model
|
| 126 |
+
|
| 127 |
+
def load_model(self, st, model):
|
| 128 |
+
try:
|
| 129 |
+
model.load_state_dict(st)
|
| 130 |
+
except:
|
| 131 |
+
for key in list(st.keys()):
|
| 132 |
+
if "module." in key:
|
| 133 |
+
st[key.replace("module.", "")] = st.pop(key)
|
| 134 |
+
model.load_state_dict(st)
|
| 135 |
+
return model
|
| 136 |
+
|
| 137 |
+
def load_sid_model(self, sid_model, snapshot_path, sid_ck_name):
|
| 138 |
+
loc = f"cuda:{self.gpu_id}"
|
| 139 |
+
# sid_model_path = sorted(glob(f"{snapshot_path}/sid_model_epoch_*.pt"),
|
| 140 |
+
# key=lambda x: float(x.split('_')[-1].replace('.pt', '')))[0]
|
| 141 |
+
sid_model_path = f"{snapshot_path}/{sid_ck_name}"
|
| 142 |
+
snapshot = torch.load(sid_model_path, map_location=loc)
|
| 143 |
+
sid_model = self.load_model(snapshot["sid_model"], sid_model)
|
| 144 |
+
best_val_loss = snapshot["val_loss"]
|
| 145 |
+
epochs_run = snapshot["epochs_run"]
|
| 146 |
+
|
| 147 |
+
def load_mapper(self, snapshot_path, mapper_ck_name):
|
| 148 |
+
loc = f"cuda:{self.gpu_id}"
|
| 149 |
+
mapper_path = sorted(glob(f"{snapshot_path}/mapper_*.pt"))[-1]
|
| 150 |
+
mapper_path = f"{snapshot_path}/{mapper_ck_name}"
|
| 151 |
+
snapshot = torch.load(mapper_path, map_location=loc)
|
| 152 |
+
|
| 153 |
+
self.sid_mapper = self.load_model(snapshot["sid_mapper"],self.sid_mapper)
|
| 154 |
+
# self.text_mapper = self.load_model(snapshot["text_mapper"],self.text_mapper)
|
| 155 |
+
|
| 156 |
+
self.epochs_run = snapshot["epochs_run"]
|
| 157 |
+
logging.info(f"Resuming training from mapper at Epoch {self.epochs_run}")
|
| 158 |
+
|
| 159 |
+
def save_mapper(self, epoch, snapshot_path, val_epoch_ce_llm):
|
| 160 |
+
mapper = {
|
| 161 |
+
# "text_mapper": self.text_mapper.state_dict(),
|
| 162 |
+
"sid_mapper": self.sid_mapper.state_dict(),
|
| 163 |
+
"epochs_run": epoch,
|
| 164 |
+
}
|
| 165 |
+
part = snapshot_path
|
| 166 |
+
torch.save(mapper, f"{part}/unfrozen_mapper_epoch_{str(epoch).zfill(4)}_val_epoch_ce_llm_{val_epoch_ce_llm}.pt")
|
| 167 |
+
logging.info(f"Epoch {epoch} | Training mapper saved at {snapshot_path}")
|
| 168 |
+
|
| 169 |
+
def preprocess_prompt(self, texts): # true false
|
| 170 |
+
r"""Load list of prompts and return tokenized text"""
|
| 171 |
+
tokenized_texts = []
|
| 172 |
+
for ttext in texts:
|
| 173 |
+
tok = self.tokenizer.encode_plus(
|
| 174 |
+
text=ttext, add_special_tokens=True,
|
| 175 |
+
max_length=10,
|
| 176 |
+
pad_to_max_length=True, return_tensors="pt", truncation=True)
|
| 177 |
+
for key in tok.keys():
|
| 178 |
+
tok[key] = tok[key].reshape(-1).to(self.gpu_id)
|
| 179 |
+
tokenized_texts.append(tok)
|
| 180 |
+
return self.default_collate(tokenized_texts)
|
| 181 |
+
|
| 182 |
+
def preprocess_prompt_single(self, texts): # true false
|
| 183 |
+
r"""Load list of prompts and return tokenized text"""
|
| 184 |
+
tokenized_texts = []
|
| 185 |
+
tok = self.tokenizer.encode_plus(
|
| 186 |
+
text=texts, add_special_tokens=True,
|
| 187 |
+
max_length=10,
|
| 188 |
+
pad_to_max_length=True, return_tensors="pt", truncation=True)
|
| 189 |
+
for key in tok.keys():
|
| 190 |
+
tok[key] = tok[key].reshape(-1).to(self.gpu_id)
|
| 191 |
+
tokenized_texts.append(tok)
|
| 192 |
+
return self.default_collate(tokenized_texts)
|
| 193 |
+
|
| 194 |
+
|
| 195 |
+
def preprocess_text(self, texts): # true false
|
| 196 |
+
r"""Load list of prompts and return tokenized text"""
|
| 197 |
+
tokenized_texts = []
|
| 198 |
+
for ttext in texts:
|
| 199 |
+
ttext = ttext + ' <|endoftext|>'
|
| 200 |
+
tok = self.tokenizer.encode_plus(
|
| 201 |
+
text=ttext, add_special_tokens=True,
|
| 202 |
+
max_length=self.tok_len,
|
| 203 |
+
pad_to_max_length=True, return_tensors="pt", truncation=True)
|
| 204 |
+
for key in tok.keys():
|
| 205 |
+
tok[key] = tok[key].reshape(-1).to(self.gpu_id)
|
| 206 |
+
tokenized_texts.append(tok)
|
| 207 |
+
return self.default_collate(tokenized_texts)
|
| 208 |
+
|
| 209 |
+
def _get_text_embeddings(self, preprocessed_texts):
|
| 210 |
+
r"""Load preprocessed prompts and return a prompt embeddings"""
|
| 211 |
+
with torch.no_grad():
|
| 212 |
+
texts_embed = self.gpt.transformer.wte(preprocessed_texts['input_ids'])
|
| 213 |
+
return texts_embed
|
| 214 |
+
|
| 215 |
+
def get_sid_prefix(self, sid_embeddings):
|
| 216 |
+
r"""Produces audio embedding which is fed to LM"""
|
| 217 |
+
if self.norm_sid_emb:
|
| 218 |
+
sid_embeddings = sid_embeddings / sid_embeddings.norm(2, -1).reshape(-1,1)
|
| 219 |
+
|
| 220 |
+
# raise SystemError(sid_embeddings.shape) # torch.Size([2, 1024])
|
| 221 |
+
sids_prefix = self.sid_mapper(sid_embeddings).contiguous().view(-1, self.sid_prefix_length, self.gpt_embedding_size)
|
| 222 |
+
# raise SystemError(sids_prefix.shape) # torch.Size([2, 40, 768]) batch_size, seq_len, embed_size
|
| 223 |
+
return sids_prefix
|
| 224 |
+
|
| 225 |
+
def get_prompt_prefix(self, texts):
|
| 226 |
+
r"""Load list of text prompts and return prompt prefix and prompt embeddings"""
|
| 227 |
+
preprocessed_texts = self.preprocess_prompt(texts)
|
| 228 |
+
print(preprocessed_texts)
|
| 229 |
+
texts_embed = self._get_text_embeddings(preprocessed_texts)
|
| 230 |
+
return texts_embed, preprocessed_texts
|
| 231 |
+
def get_prompt_prefix_single(self, texts):
|
| 232 |
+
r"""Load list of text prompts and return prompt prefix and prompt embeddings"""
|
| 233 |
+
preprocessed_texts = self.preprocess_prompt_single(texts)
|
| 234 |
+
texts_embed = self._get_text_embeddings(preprocessed_texts)
|
| 235 |
+
return texts_embed, preprocessed_texts
|
| 236 |
+
|
| 237 |
+
def get_text_prefix(self, texts):
|
| 238 |
+
r"""Load list of text prompts and return prompt prefix and prompt embeddings"""
|
| 239 |
+
preprocessed_texts = self.preprocess_text(texts)
|
| 240 |
+
texts_embed = self._get_text_embeddings(preprocessed_texts)
|
| 241 |
+
return texts_embed, preprocessed_texts
|
| 242 |
+
|
| 243 |
+
    def generate_beam(self, beam_size: int = 1, sids_prefix=None, entry_length=80, temperature=1., stop_token: str = ' <|endoftext|>'):
        """Beam-search decode captions from a speaker-ID prefix embedding.

        Args:
            beam_size: number of beams kept at each step.
            sids_prefix: prefix embeddings fed as ``inputs_embeds`` to GPT-2.
                Assumed shape (1, prefix_len, gpt_embedding_size) — TODO
                confirm against callers; passing None (the default) would
                crash inside ``self.gpt``.
            entry_length: maximum number of tokens generated per beam.
            temperature: logit temperature; values <= 0 fall back to 1.0.
            stop_token: string whose first token id terminates a beam.

        Returns:
            List of decoded strings, sorted by average log-probability
            (best first).
        """
        stop_token_index = self.tokenizer.encode(stop_token)[0]
        tokens = None
        scores = None
        device = next(self.gpt.parameters()).device
        # Per-beam generated length; used to length-normalise scores.
        seq_lengths = torch.ones(beam_size, device=device)
        # Per-beam flag: True once a beam has emitted the stop token.
        is_stopped = torch.zeros(beam_size, device=device, dtype=torch.bool)
        with torch.no_grad():
            generated = sids_prefix # sid embedding
            for i in range(entry_length):
                outputs = self.gpt(inputs_embeds=generated)
                logits = outputs.logits
                # Last-position logits only; temperature <= 0 means no scaling.
                logits = logits[:, -1, :] / (temperature if temperature > 0 else 1.0)
                # Log-probabilities (log-softmax via softmax().log()).
                logits = logits.softmax(-1).log()
                if scores is None:
                    # First step: fan the single prefix out into beam_size beams.
                    scores, next_tokens = logits.topk(beam_size, -1)
                    generated = generated.expand(beam_size, *generated.shape[1:])
                    next_tokens, scores = next_tokens.permute(1, 0), scores.squeeze(0)
                    if tokens is None:
                        tokens = next_tokens
                    else:
                        tokens = tokens.expand(beam_size, *tokens.shape[1:])
                        tokens = torch.cat((tokens, next_tokens), dim=1)
                else:
                    # Finished beams contribute only their existing score:
                    # force all their continuations to -inf except token 0.
                    logits[is_stopped] = -float(np.inf)
                    logits[is_stopped, 0] = 0
                    # Candidate scores for every (beam, vocab) continuation.
                    scores_sum = scores[:, None] + logits
                    seq_lengths[~is_stopped] += 1
                    # Length-normalised scores; top-k over the flattened
                    # (beam * vocab) candidates selects the new beams.
                    scores_sum_average = scores_sum / seq_lengths[:, None]
                    scores_sum_average, next_tokens = scores_sum_average.view(-1).topk(beam_size, -1)
                    # Recover source beam and vocab id from the flat index.
                    # NOTE(review): `//` on tensors triggered a floor_divide
                    # deprecation warning in some torch versions — verify
                    # against the pinned torch release.
                    next_tokens_source = next_tokens // scores_sum.shape[1]
                    seq_lengths = seq_lengths[next_tokens_source]
                    next_tokens = next_tokens % scores_sum.shape[1]
                    next_tokens = next_tokens.unsqueeze(1)
                    # Re-gather all per-beam state to follow the chosen beams.
                    tokens = tokens[next_tokens_source]
                    tokens = torch.cat((tokens, next_tokens), dim=1)
                    generated = generated[next_tokens_source]
                    # Store un-normalised scores again (avg * length).
                    scores = scores_sum_average * seq_lengths
                    is_stopped = is_stopped[next_tokens_source]

                # Append the chosen tokens' embeddings to the running inputs.
                next_token_embed = self.gpt.transformer.wte(next_tokens.squeeze()).view(generated.shape[0], 1, -1)
                generated = torch.cat((generated, next_token_embed), dim=1)
                # Bool + bool acts as logical OR here.
                is_stopped = is_stopped + next_tokens.eq(stop_token_index).squeeze()
                if is_stopped.all():
                    break
        # Final ranking uses length-normalised scores.
        scores = scores / seq_lengths
        output_list = tokens.cpu().numpy()
        ############ Shuo added for attn plot ###########
        # token_list = []
        # text_list = []
        # for output, length in zip(output_list, seq_lengths):
        #     for item in output[:int(length)]:
        #         token_list.append(item)
        #         text_list.append(self.tokenizer.decode(item))
        ############ Shuo added for attn plot ###########
        output_texts = [self.tokenizer.decode(output[:int(length)]) for output, length in zip(output_list, seq_lengths)]
        order = scores.argsort(descending=True)
        #output_texts = [[output_texts[i], scores[i].item()] for i in order]
        output_texts = [output_texts[i] for i in order]
        return output_texts
        # return output_texts, token_list, text_list
|
| 304 |
+
|
| 305 |
+
|