massabaali commited on
Commit
f55a095
·
verified ·
1 Parent(s): ed8a7df

Upload CoLMbo model weights and code

Browse files
Files changed (36) hide show
  1. config.json +17 -0
  2. encoder/__pycache__/attentive_pooling.cpython-310.pyc +0 -0
  3. encoder/__pycache__/attentive_pooling.cpython-38.pyc +0 -0
  4. encoder/__pycache__/encoder.cpython-310.pyc +0 -0
  5. encoder/__pycache__/encoder.cpython-38.pyc +0 -0
  6. encoder/__pycache__/encoder.cpython-39.pyc +0 -0
  7. encoder/__pycache__/mha.cpython-310.pyc +0 -0
  8. encoder/__pycache__/mha.cpython-38.pyc +0 -0
  9. encoder/__pycache__/self_attn.cpython-310.pyc +0 -0
  10. encoder/__pycache__/self_attn.cpython-38.pyc +0 -0
  11. encoder/attentive_pooling.py +33 -0
  12. encoder/encoder.py +35 -0
  13. encoder/mha.py +62 -0
  14. encoder/self_attn.py +81 -0
  15. load_data/__pycache__/combineddataset.cpython-38.pyc +0 -0
  16. load_data/__pycache__/data_collactor.cpython-310.pyc +0 -0
  17. load_data/__pycache__/data_collactor.cpython-38.pyc +0 -0
  18. load_data/__pycache__/dataset.cpython-38.pyc +0 -0
  19. load_data/__pycache__/extract_fbanks.cpython-310.pyc +0 -0
  20. load_data/__pycache__/extract_fbanks.cpython-38.pyc +0 -0
  21. load_data/__pycache__/prepare_dataloader.cpython-310.pyc +0 -0
  22. load_data/__pycache__/prepare_dataloader.cpython-38.pyc +0 -0
  23. load_data/__pycache__/tears.cpython-38.pyc +0 -0
  24. load_data/__pycache__/timit.cpython-38.pyc +0 -0
  25. load_data/__pycache__/voxceleb.cpython-38.pyc +0 -0
  26. load_data/combineddataset.py +29 -0
  27. load_data/data_collactor.py +74 -0
  28. load_data/dataset.py +109 -0
  29. load_data/extract_fbanks.py +55 -0
  30. load_data/prepare_dataloader.py +22 -0
  31. load_data/tears.py +232 -0
  32. load_data/timit.py +102 -0
  33. load_data/voxceleb.py +63 -0
  34. mapper.py +245 -0
  35. pytorch_model.bin +3 -0
  36. wrapper.py +305 -0
config.json ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "model_type": "colmbo",
3
+ "architectures": [
4
+ "CoLMboModel"
5
+ ],
6
+ "auto_map": {
7
+ "AutoConfig": "modeling_colmbo.CoLMboConfig",
8
+ "AutoModel": "modeling_colmbo.CoLMboModel"
9
+ },
10
+ "n_mels": 80,
11
+ "embedding_dim": 192,
12
+ "channel": 1024,
13
+ "prefix_length": 10,
14
+ "gpt_model_name": "gpt2",
15
+ "sample_rate": 16000,
16
+ "torch_dtype": "float32"
17
+ }
encoder/__pycache__/attentive_pooling.cpython-310.pyc ADDED
Binary file (1.54 kB). View file
 
encoder/__pycache__/attentive_pooling.cpython-38.pyc ADDED
Binary file (1.53 kB). View file
 
encoder/__pycache__/encoder.cpython-310.pyc ADDED
Binary file (1.64 kB). View file
 
encoder/__pycache__/encoder.cpython-38.pyc ADDED
Binary file (1.66 kB). View file
 
encoder/__pycache__/encoder.cpython-39.pyc ADDED
Binary file (1.63 kB). View file
 
encoder/__pycache__/mha.cpython-310.pyc ADDED
Binary file (2.21 kB). View file
 
encoder/__pycache__/mha.cpython-38.pyc ADDED
Binary file (2.22 kB). View file
 
encoder/__pycache__/self_attn.cpython-310.pyc ADDED
Binary file (3.71 kB). View file
 
encoder/__pycache__/self_attn.cpython-38.pyc ADDED
Binary file (3.74 kB). View file
 
encoder/attentive_pooling.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
class SelfAttentionPooling(nn.Module):
    """
    Self-attention pooling over frame-level representations.

    Original paper: Self-Attention Encoding and Pooling for Speaker Recognition
    https://arxiv.org/pdf/2008.01077v1.pdf

    Produces the attention-weighted mean and the matching weighted standard
    deviation so callers can build a mean+std statistics vector.
    """

    def __init__(self, input_dim):
        super(SelfAttentionPooling, self).__init__()
        # One scalar attention logit per time step.
        self.W = nn.Linear(input_dim, 1)

    def forward(self, batch_rep, att_mask):
        """
        Args:
            batch_rep: (N, T, H) frame representations.
            att_mask: (N, T, ...) additive mask; only channel 0 is used.
                NOTE(review): the mask is *added* to the logits, so padded
                positions are expected to carry a large negative value and
                valid positions 0 — confirm against the caller.

        Returns:
            utter_rep: (N, H) attention-weighted mean.
            attn_out_std: (N, H) attention-weighted standard deviation.
        """
        # Per-frame attention logits: (N, T).
        # (Removed the unused `seq_len` local and the `softmax` alias from
        # the original implementation.)
        att_logits = self.W(batch_rep).squeeze(-1)
        # Keep only channel 0 of the mask and apply it additively.
        att_logits = att_mask[:, :, 0] + att_logits
        # Normalise over time: (N, T, 1).
        att_w = nn.functional.softmax(att_logits, dim=-1).unsqueeze(-1)
        # Attention-weighted mean over time.
        utter_rep = torch.sum(batch_rep * att_w, dim=1)
        # Attention-weighted standard deviation around that mean.
        attn_out_std = torch.sqrt(
            torch.sum(att_w * (batch_rep - utter_rep.unsqueeze(1)) ** 2, dim=1))

        return utter_rep, attn_out_std
encoder/encoder.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from speechbrain.lobes.models.ECAPA_TDNN import ECAPA_TDNN
3
+
4
class Model(torch.nn.Module):
    """Thin wrapper around SpeechBrain's ECAPA-TDNN speaker-embedding encoder."""

    def __init__(self, n_mels=80, embedding_dim=192, channel=512):
        super(Model, self).__init__()
        # ECAPA-TDNN takes five channel sizes; the final (attention) stage
        # is three times as wide as the four TDNN stages.
        layer_channels = [channel] * 4 + [channel * 3]
        self.model = ECAPA_TDNN(
            input_size=n_mels,
            lin_neurons=embedding_dim,
            channels=layer_channels,
        )

    def forward(self, x):
        # Drop the singleton channel dim: (B, 1, T, n_mels) -> (B, T, n_mels).
        frames = x.squeeze(1)
        embedding = self.model(frames)
        # ECAPA returns (B, 1, embedding_dim); collapse to (B, embedding_dim).
        return embedding.squeeze(1)
16
+
17
if __name__ == '__main__':
    # Smoke test: load pretrained ECAPA weights into the wrapper and run a
    # dummy forward pass.
    CHECKPOINT_PATH = (
        "/ocean/projects/cis220031p/abdulhan/AVIS_baseline/ECAPA/"
        "pretrained_models/spkrec-ecapa-voxceleb/embedding_model.ckpt"
    )

    model = Model(n_mels=80, embedding_dim=192, channel=1024)

    # map_location='cpu' so the script also runs on CPU-only hosts (the
    # original torch.load would fail there if the checkpoint was saved on GPU).
    checkpoint = torch.load(CHECKPOINT_PATH, map_location="cpu")

    # The checkpoint holds the raw ECAPA state dict; our wrapper stores the
    # encoder under the "model." attribute, so prefix every key.
    new_state_dict = {f"model.{k}": v for k, v in checkpoint.items()}
    model.load_state_dict(new_state_dict)

    # Inference mode for the sanity check.
    model.eval()

    # Dummy input is (B, 1, T, n_mels): 300 frames of 80 mel bins.
    # (The original comment labelled this (B, 1, n_mels, T), which
    # contradicted the tensor actually constructed.)
    dummy_input = torch.randn(1, 1, 300, 80)
    output = model(dummy_input)
    print(output.shape)
encoder/mha.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import math
4
+
5
class MultiHeadAttention(nn.Module):
    """
    Standard multi-head attention with an optional phoneme-probability
    penalty: when ``prob_phn`` and a positive ``lambda_val`` are supplied,
    ``lambda_val * prob_phn`` is subtracted from the attention scores.
    """

    def __init__(self, d_model, num_heads):
        super(MultiHeadAttention, self).__init__()
        # Ensure that the model dimension (d_model) is divisible by the number of heads
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"

        self.d_model = d_model            # model's dimension
        self.num_heads = num_heads        # number of attention heads
        self.d_k = d_model // num_heads   # per-head key/query/value dimension

        # Linear layers for transforming inputs
        self.W_q = nn.Linear(d_model, d_model)  # query transformation
        self.W_k = nn.Linear(d_model, d_model)  # key transformation
        self.W_v = nn.Linear(d_model, d_model)  # value transformation
        self.W_o = nn.Linear(d_model, d_model)  # output transformation

    def scaled_dot_product_attention(self, Q, K, V, prob_phn=None, mask=None, lambda_val=None):
        """Scaled dot-product attention over per-head tensors (B, H, T, d_k).

        Returns (output, mask); the mask is returned unmodified so callers
        can reuse it for pooling.
        """
        attn_scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)

        # Optional phoneme penalty. Fix: the original dereferenced prob_phn
        # and compared lambda_val unconditionally, which raised
        # AttributeError / TypeError whenever either was left at its
        # declared None default.
        if prob_phn is not None and lambda_val is not None and lambda_val > 0:
            # (B, T, T') -> (B, 1, T, T'); expand is a broadcast view, so it
            # does not copy memory.
            penalty = prob_phn.unsqueeze(1).expand(-1, self.num_heads, -1, -1)
            attn_scores = attn_scores - lambda_val * penalty.transpose(-2, -1)

        attn_mask = mask
        if mask is not None:
            # Broadcast the mask across heads; masked (== 0) positions get
            # -1e9 so they vanish after softmax.
            expanded = mask.unsqueeze(1).expand(-1, self.num_heads, -1, -1)
            attn_scores = attn_scores.masked_fill(expanded == 0, -1e9)

        attn_probs = torch.softmax(attn_scores, dim=-1).float()
        output = torch.matmul(attn_probs, V)
        return output, attn_mask

    def split_heads(self, x):
        # Reshape (B, T, d_model) -> (B, H, T, d_k) for multi-head attention.
        batch_size, seq_length, d_model = x.size()
        return x.view(batch_size, seq_length, self.num_heads, self.d_k).transpose(1, 2)

    def combine_heads(self, x):
        # Combine (B, H, T, d_k) back to (B, T, d_model).
        batch_size, _, seq_length, d_k = x.size()
        return x.transpose(1, 2).contiguous().view(batch_size, seq_length, self.d_model)

    def forward(self, Q, K, V, prob_phn=None, mask=None, lambda_val=None):
        # Apply linear transformations and split heads.
        Q = self.split_heads(self.W_q(Q))
        K = self.split_heads(self.W_k(K))
        V = self.split_heads(self.W_v(V))

        # Perform scaled dot-product attention.
        attn_output, attn_mask = self.scaled_dot_product_attention(
            Q, K, V, prob_phn, mask, lambda_val)

        # Combine heads and apply output transformation.
        output = self.W_o(self.combine_heads(attn_output))
        return output, attn_mask
encoder/self_attn.py ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from encoder.mha import MultiHeadAttention
4
+ from encoder.attentive_pooling import SelfAttentionPooling
5
+
6
class FlippedReLU(nn.Module):
    """Mirror image of ReLU: passes negative values through, zeroes the rest."""

    def __init__(self):
        super(FlippedReLU, self).__init__()

    def forward(self, x):
        # Keep x where it is strictly negative; emit zero elsewhere.
        suppressed = torch.zeros_like(x)
        return torch.where(x < 0, x, suppressed)
12
+
13
class PositionWiseFeedForward(nn.Module):
    """Two-layer position-wise MLP: Linear -> ReLU -> Linear."""

    def __init__(self, d_model, d_ff):
        super(PositionWiseFeedForward, self).__init__()
        self.fc1 = nn.Linear(d_model, d_ff)
        self.fc2 = nn.Linear(d_ff, d_model)
        self.relu = nn.ReLU()

    def forward(self, x):
        # Expand to d_ff, apply the non-linearity, project back to d_model.
        hidden = self.relu(self.fc1(x))
        return self.fc2(hidden)
22
+
23
class EncoderLayer(nn.Module):
    """Post-norm transformer encoder layer: self-attention then feed-forward,
    each with a residual connection followed by LayerNorm."""

    def __init__(self, d_model, num_heads, d_ff, dropout):
        super(EncoderLayer, self).__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads)
        self.feed_forward = PositionWiseFeedForward(d_model, d_ff)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, prob_phn=None, mask=None, lambda_val=None):
        # Self-attention sub-layer (residual + norm).
        attn_output, attn_mask = self.self_attn(
            x, x, x, prob_phn=prob_phn, mask=mask, lambda_val=lambda_val)
        x = self.norm1(x + self.dropout(attn_output))

        # Feed-forward sub-layer (residual + norm).
        x = self.norm2(x + self.dropout(self.feed_forward(x)))
        return x, attn_mask
38
+
39
+
40
class TransformerSelfAttention(nn.Module):
    def __init__(self, input_dim, num_heads, dim_feedforward, number_Of_spks, dropout=0.0):
        """EncoderBlock.

        Args:
            input_dim: Dimensionality of the input
            num_heads: Number of heads to use in the attention block
            dim_feedforward: Dimensionality of the hidden layer in the MLP
                (the actual hidden width used everywhere below is dim_feedforward*8)
            number_Of_spks: Size of the speaker classification output layer
            dropout: Dropout probability to use in the dropout layers
        """
        super().__init__()
        # Attention layer
        self.self_mha_attn = EncoderLayer(input_dim, num_heads, dim_feedforward*8,dropout)
        self.attn_pooling = SelfAttentionPooling(input_dim)
        # Two parallel embedding heads over the pooled mean+std vector.
        self.emb1 = nn.Linear(input_dim*2, dim_feedforward*8)
        self.emb2 = nn.Linear(input_dim*2, dim_feedforward*8)
        # emb2 starts as a copy of emb1's parameters (clone, not tied — the
        # two heads can diverge during training).
        self.emb2.weight.data = self.emb1.weight.data.clone()
        self.emb2.bias.data = self.emb1.bias.data.clone()
        self.bn = nn.BatchNorm1d(dim_feedforward*8)
        self.act = nn.ReLU(inplace=True)
        self.dropout = nn.Dropout(dropout)
        self.classifier = nn.Linear(dim_feedforward*8, number_Of_spks)
        self.flipped_relu = FlippedReLU()


    def forward(self, x, prob_phn=None, mask=None, lambda_val=None):
        # Attention part
        attn_out, attn_mask = self.self_mha_attn(x,prob_phn=prob_phn, mask=mask, lambda_val=lambda_val)
        # NOTE(review): assumes mask was provided — attn_mask is None when
        # mask is None and .squeeze(1) would then raise; confirm callers
        # always pass a mask.
        attn_mask= attn_mask.squeeze(1)
        # Mean+std statistics pooling -> (N, 2*input_dim).
        attn_out_mean,attn_out_std = self.attn_pooling(attn_out,attn_mask)
        attn_concat = torch.cat((attn_out_mean, attn_out_std),dim=1).to(dtype=torch.float32)

        # Positive half of the embedding (ReLU keeps activations >= 0).
        emb1 = self.emb1(attn_concat).to(dtype=torch.float32)
        emb1 = self.act(emb1)

        # Negative half (FlippedReLU keeps activations <= 0).
        emb2 = self.emb2(attn_concat).to(dtype=torch.float32)
        emb2 = self.flipped_relu(emb2)

        # At initialisation emb1 + emb2 equals the raw linear output (the
        # two heads start from identical cloned weights); they decouple as
        # training updates each head separately.
        emb = emb1 + emb2
        emb = self.bn(emb)
        x = self.classifier(emb)
        return x,emb
load_data/__pycache__/combineddataset.cpython-38.pyc ADDED
Binary file (1.41 kB). View file
 
load_data/__pycache__/data_collactor.cpython-310.pyc ADDED
Binary file (4.31 kB). View file
 
load_data/__pycache__/data_collactor.cpython-38.pyc ADDED
Binary file (4.32 kB). View file
 
load_data/__pycache__/dataset.cpython-38.pyc ADDED
Binary file (2.85 kB). View file
 
load_data/__pycache__/extract_fbanks.cpython-310.pyc ADDED
Binary file (2.39 kB). View file
 
load_data/__pycache__/extract_fbanks.cpython-38.pyc ADDED
Binary file (2.44 kB). View file
 
load_data/__pycache__/prepare_dataloader.cpython-310.pyc ADDED
Binary file (855 Bytes). View file
 
load_data/__pycache__/prepare_dataloader.cpython-38.pyc ADDED
Binary file (851 Bytes). View file
 
load_data/__pycache__/tears.cpython-38.pyc ADDED
Binary file (6.77 kB). View file
 
load_data/__pycache__/timit.cpython-38.pyc ADDED
Binary file (3.15 kB). View file
 
load_data/__pycache__/voxceleb.cpython-38.pyc ADDED
Binary file (1.82 kB). View file
 
load_data/combineddataset.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import random
3
+ from torch.utils.data import Dataset, DataLoader
4
+
5
class CombinedDataset(Dataset):
    """Mixes two datasets, drawing each item from one or the other at random.

    Args:
        dataset1 (Dataset): first source (e.g. TIMITDataset).
        dataset2 (Dataset): second source (e.g. EARS).
        switch_prob (float): probability of drawing from ``dataset1``
            (default 0.5).
    """

    def __init__(self, dataset1, dataset2, switch_prob=0.5):
        self.dataset1 = dataset1
        self.dataset2 = dataset2
        self.len1 = len(dataset1)
        self.len2 = len(dataset2)
        self.switch_prob = switch_prob

    def __len__(self):
        # Report the length of the longer source.
        return max(self.len1, self.len2)

    def __getitem__(self, idx):
        # Coin flip picks the source; the index wraps around the shorter one.
        use_first = random.random() < self.switch_prob
        if use_first:
            return self.dataset1[idx % self.len1]
        return self.dataset2[idx % self.len2]
load_data/data_collactor.py ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from transformers import AutoFeatureExtractor
3
+ from dataclasses import dataclass, field
4
+ from typing import Any, Dict, List, Optional, Union
5
+ from preprocessing.ast_processor import ast
6
+ from util_stats.local_stats import local_extract_phn_frame_probs
7
+ from util_stats.global_stats import global_extract_phn_frame_probs
8
+ import numpy as np
9
+ import pickle
10
+ import torch.nn.functional as F
11
+
12
+ from load_data.extract_fbanks import Mel_Spectrogram
13
+
14
# Shared mel-spectrogram front end used by the collator below.
extractor = Mel_Spectrogram()

# Speaker-label inventories loaded once at import time.
# NOTE(review): the paths are relative to the current working directory, so
# importing this module from elsewhere will fail with FileNotFoundError —
# confirm whether the files should be resolved relative to this module.
with open('new_lbl2ind.pkl', 'rb') as f:
    lbl2ind = pickle.load(f)
with open('new_spk.pkl', 'rb') as f:
    unique_speaker_ids = pickle.load(f)
# change the labels
# Number of speaker classes for the one-hot labels built in the collator.
number_Of_spks = len(unique_speaker_ids)
22
+
23
+
24
@dataclass
class DataCollatorWithPadding:
    """
    Collates a list of per-sample dicts into a training batch.

    NOTE(review): despite the name and the padding-related fields below,
    ``__call__`` does not pad — it stacks the audio tensors directly with
    ``torch.stack``, which requires every sample in the batch to have the
    same duration (the datasets crop/pad to a fixed 3 s). The padding
    fields are accepted for interface compatibility but unused here.
    """


    # Accepted for HF-collator interface compatibility; unused in __call__.
    padding: Union[bool, str] = True
    max_length: Optional[int] = None
    pad_to_multiple_of: Optional[int] = None
    pad_to_multiple_of_labels: Optional[int] = None
    # "global"/"local" phoneme-statistics switch; unused in __call__.
    flag_global_local: Optional[str] = None
    dic_train_phn_frequency: Optional [dict] = None
    dic_train_frame_frequency: Optional [dict] = None
    lbl2ind: Optional [dict] = None

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        # split inputs and labels since they have to be of different lengths and need
        # different padding methods
        batch={}
        # Raw waveforms with the channel dim dropped: list of (num_samples,) tensors.
        batch['input_values']= [features[idx]['audio_tensor'].squeeze(0) for idx in range(len(features))]
        batch["prompt"] = [features[idx]["prompt"] for idx in range(len(features))]
        batch["answer"] = [features[idx]["answer"] for idx in range(len(features))]
        batch["filename"] = [features[idx]["filename"] for idx in range(len(features))]
        # batch["no_hot_encode"] = torch.tensor([lbl2ind[features[idx]['sid']] for idx in range(len(features))])
        # NOTE(review): speaker labels are hard-coded to class 0 (the lbl2ind
        # lookup above is commented out), so "labels" is a constant one-hot —
        # confirm this is intentional for the current training stage.
        batch["no_hot_encode"] = torch.tensor([0 for idx in range(len(features))])
        # if batch["no_hot_encode"].numel():
        batch["labels"]= F.one_hot(batch["no_hot_encode"], number_Of_spks)
        # Stack waveforms and convert to mel features via the module-level extractor.
        batch['input_values'] = extractor(torch.stack(batch['input_values']))
        return batch
load_data/dataset.py ADDED
@@ -0,0 +1,109 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from glob import glob
3
+ import torchaudio
4
+ from torch.utils.data import Dataset
5
+ import pandas as pd
6
+ from PIL import Image
7
+ import pickle
8
+ from copy import deepcopy
9
+ from glob import glob
10
+ import random
11
+ from sklearn.model_selection import train_test_split
12
+ import json
13
+ import os
14
+ import numpy as np
15
+ import librosa
16
+ import torch
17
+ import soundfile as sf
18
+ import pandas as pd
19
+ import random
20
+
21
class EARS(Dataset):
    """
    EARS dataset for 10 sec or shorter segments.

    Each item is a fixed 3 s crop (or zero-padded shorter segment) of one
    audio file, normalised to zero mean / unit variance.

    Returns:
        audio_tensor: torch.Tensor of shape (1, 3 * sample_rate)
        filename: str, path relative to the dataset root
        sid: str (e.g. p103), speaker id (first path component of filename)
        prompt: str or None, question prompt
        answer: str or None, reference answer
    """
    def __init__(self, root, data_path, meta_path,utterance_path, prompts_path, sample_rate, train_mapper=False, split="train"):
        super().__init__()
        self.root = root

        # Segment index: list of {"filename", "start", "end", ...} entries.
        with open(f"{data_path}", "r") as f:
            self.data = json.load(f)

        # Speaker metadata (loaded but not used in __getitem__).
        with open(f"{meta_path}", "r") as f:
            self.meta = json.load(f)

        # Utterance transcripts (loaded but not used in __getitem__).
        with open(f"{utterance_path}", "r") as f:
            self.utterance = json.load(f)

        # Per-speaker list of (prompt, answer) pairs.
        with open(f"{prompts_path}", "r") as f:
            self.prompts = json.load(f)

        self.new_data = []
        if train_mapper:
            # Expand each audio segment into 10 QA pairs sampled without
            # replacement from its speaker's prompt pool.
            for d in self.data:
                file_name = d["filename"]
                # Speaker id is the first path component of the filename.
                sid = file_name.split("/")[0]
                temp = random.sample(self.prompts[sid], 10)
                for qa in temp:
                    self.new_data.append({"filename": file_name,
                                          "start": d["start"],
                                          "end": d["end"],
                                          "prompt": qa[0],
                                          "answer": qa[1]})
        else:
            self.new_data = self.data
        if split == "train":
            random.shuffle(self.new_data)

        self.sample_rate = sample_rate

    def __len__(self):
        return len(self.new_data)

    def __getitem__(self, idx):
        entry = self.new_data[idx]
        filename = entry["filename"]
        sid = filename.split("/")[0]
        audio_path = os.path.join(self.root, filename)

        # Load audio
        audio, sample_rate = torchaudio.load(audio_path)
        start_sample, end_sample = entry["start"], entry["end"]

        # Resample if needed.
        # NOTE(review): "start"/"end" are applied to the *resampled* signal,
        # so they are assumed to already be expressed in target-rate sample
        # units — confirm against how the index JSON was built.
        if sample_rate != self.sample_rate:
            audio = torchaudio.transforms.Resample(sample_rate, self.sample_rate)(audio)

        # Segment length in samples, per the index metadata.
        total_samples = end_sample - start_sample
        num_samples_3s = 3 * self.sample_rate  # 3 seconds worth of samples

        # Select a random 3 s window within the annotated range.
        if total_samples >= num_samples_3s:
            start_offset = random.randint(start_sample, end_sample - num_samples_3s)
            end_offset = start_offset + num_samples_3s
            audio = audio[:, start_offset:end_offset]
        else:
            # If less than 3 s, take the full segment and zero-pad on the right.
            pad_size = num_samples_3s - total_samples
            audio = audio[:, start_sample:end_sample]
            audio = torch.nn.functional.pad(audio, (0, pad_size))

        # Normalize to zero mean / unit variance; epsilon guards silent clips.
        mean = torch.mean(audio)
        std = torch.std(audio)
        audio = (audio - mean) / (std + 1e-8)

        return {
            "audio_tensor": audio,
            "filename": filename,
            "sid": sid,
            # prompt/answer exist only for train_mapper-expanded entries.
            "prompt": entry.get("prompt", None),
            "answer": entry.get("answer", None),
        }
load_data/extract_fbanks.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import librosa
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+
6
class PreEmphasis(torch.nn.Module):
    """Applies the pre-emphasis filter y[t] = x[t] - coef * x[t-1]."""

    def __init__(self, coef: float = 0.97):
        super(PreEmphasis, self).__init__()
        self.coef = coef
        # conv1d computes cross-correlation, so the filter taps are stored
        # reversed; shape (out_channels=1, in_channels=1, kernel=2).
        kernel = torch.FloatTensor([-self.coef, 1.]).unsqueeze(0).unsqueeze(0)
        self.register_buffer('flipped_filter', kernel)

    def forward(self, inputs: torch.tensor) -> torch.tensor:
        assert len(
            inputs.size()) == 2, 'The number of dimensions of inputs tensor must be 2!'
        # Reflect-pad one sample on the left so the output length matches the input.
        padded = F.pad(inputs.unsqueeze(1), (1, 0), 'reflect')
        return F.conv1d(padded, self.flipped_filter).squeeze(1)
24
+
25
+
26
class Mel_Spectrogram(nn.Module):
    """Waveform -> normalised mel-feature front end.

    Pipeline: pre-emphasis -> STFT magnitude -> log -> mel projection ->
    instance norm. Output shape is (B, 1, T, n_mels).

    NOTE(review): the log is taken on the linear magnitude spectrogram
    *before* the mel projection, whereas conventional log-mel applies the
    filterbank first — confirm this ordering is intentional.
    """

    def __init__(self, sample_rate=16000, n_fft=512, win_length=400, hop=160, n_mels=80, coef=0.97, requires_grad=False):
        super(Mel_Spectrogram, self).__init__()
        self.n_fft = n_fft
        self.n_mels = n_mels
        self.win_length = win_length
        self.hop = hop

        self.pre_emphasis = PreEmphasis(coef)
        # Mel filterbank, optionally learnable via requires_grad.
        mel_basis = librosa.filters.mel(
            sr=sample_rate, n_fft=n_fft, n_mels=n_mels)
        self.mel_basis = nn.Parameter(
            torch.FloatTensor(mel_basis), requires_grad=requires_grad)
        self.instance_norm = nn.InstanceNorm1d(num_features=n_mels)
        # Analysis window stored as a frozen Parameter so it follows the
        # module across .to()/device moves.
        window = torch.hamming_window(self.win_length)
        self.window = nn.Parameter(
            torch.FloatTensor(window), requires_grad=False)

    def forward(self, x):
        # x: (B, num_samples) raw waveform.
        x = self.pre_emphasis(x)
        x = torch.stft(x, n_fft=self.n_fft, hop_length=self.hop,
                       window=self.window, win_length=self.win_length, return_complex=True)
        x = torch.abs(x)
        x += 1e-9           # floor to avoid log(0)
        x = torch.log(x)
        # Project onto mel bands: (n_mels, n_freq) @ (B, n_freq, T) -> (B, n_mels, T).
        x = torch.matmul(self.mel_basis, x)
        x = self.instance_norm(x)
        x = x.permute(0, 2, 1)   # (B, n_mels, T) -> (B, T, n_mels)
        x = x.unsqueeze(1)       # add channel dim -> (B, 1, T, n_mels)
        return x
load_data/prepare_dataloader.py ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torch.utils.data import Dataset, DataLoader
2
+ from torch.utils.data.distributed import DistributedSampler
3
+ from preprocessing.ast_processor import ast
4
+ from load_data.data_collactor import DataCollatorWithPadding
5
+
6
def prepare_dataloader(dataset: Dataset, batch_size: int, valid_train_flag: str):
    """Build a DistributedSampler-backed DataLoader for one data split.

    Args:
        dataset: the split to wrap.
        batch_size: per-process batch size.
        valid_train_flag: one of "train", "valid" or "test".

    Raises:
        ValueError: if valid_train_flag is not a recognised split name
            (the original if/elif chain silently fell through and raised
            UnboundLocalError on data_collator instead).
    """
    if valid_train_flag not in ("train", "valid", "test"):
        raise ValueError(f"unknown valid_train_flag: {valid_train_flag!r}")
    # All three splits currently use the same dynamic-padding collator; the
    # original built an identical object in three separate branches.
    data_collator = DataCollatorWithPadding(padding=True)
    return DataLoader(
        dataset,
        batch_size=batch_size,
        pin_memory=True,
        # Shuffling is delegated to DistributedSampler; keep shuffle=False.
        shuffle=False,
        sampler=DistributedSampler(dataset),
        collate_fn=data_collator
    )
22
+
load_data/tears.py ADDED
@@ -0,0 +1,232 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch.utils.data import Dataset
3
+ import json
4
+ import torchaudio
5
+ import os
6
+ from typing import Optional, Dict, Any, List, Tuple
7
+ import pandas as pd
8
+ import warnings
9
+ import random
10
+ from pathlib import Path
11
+ from collections import defaultdict
12
+
13
+
14
+
15
+
16
class TEARSDataset(Dataset):
    """
    TEARS dataset class that loads audio and associated metadata/responses.

    Args:
        json_path (str): Path to the JSON file containing TEARS data
        tears_root (str): Root directory containing TEARS audio files
        sample_rate (int, optional): Target sample rate for audio. Defaults to 16000.
        duration (float, optional): Target duration in seconds. Defaults to 3.0.
        normalize_audio (bool, optional): Whether to normalize audio. Defaults to True.
        augment (bool, optional): Whether to apply random sox/noise
            augmentations at load time. Defaults to True.

    Returns:
        Dict containing:
            - audio_tensor: torch.Tensor of shape (1, duration * sample_rate)
            - sid: str, speaker identifier
            - metadata: dict containing speaker metadata
            - prompt: str or None, randomly selected prompt
            - answer: str or None, corresponding response
            - filename: str, path to audio file
    """
    def __init__(
        self,
        json_path: str,
        tears_root: str,
        sample_rate: int = 16000,
        duration: float = 3.0,
        normalize_audio: bool = True,
        augment: bool = True
    ):
        super().__init__()

        # Load the JSON data
        with open(json_path, 'r') as f:
            self.data = json.load(f)

        self.tears_root = Path(tears_root)
        self.sample_rate = sample_rate
        self.duration = duration
        self.normalize_audio = normalize_audio
        # Fixed crop/pad length in samples.
        self.target_samples = int(duration * sample_rate)
        self.augment = augment

    def __len__(self) -> int:
        return len(self.data)

    def augment_audio(self, waveform, sample_rate):
        """Apply 1-4 randomly chosen augmentations in random order.

        NOTE(review): only 'time_stretch', 'pitch_shift' and 'add_noise'
        can actually fire. 'spec_aug' is in the choice list but has no
        matching branch (silent no-op), and the 'frequency_mask',
        'time_mask', 'reverb' and duplicate 'pitch_shift' branches below
        are unreachable because those names are never in
        augmentation_choices. The frequency/time-mask branches also
        reference an undefined name `T` (torchaudio.transforms is not
        imported as T), so they would raise NameError if ever reached.
        The `sample_rate` parameter is ignored — 16000 is hard-coded.
        """
        # Randomly select augmentation methods
        augmentation_choices = ['time_stretch', 'pitch_shift', 'add_noise', 'spec_aug']
        random.shuffle(augmentation_choices)

        for aug in augmentation_choices[:random.randint(1, len(augmentation_choices))]:
            if aug == 'time_stretch':
                # Speed change followed by a rate effect to keep 16 kHz output.
                rate = random.uniform(0.8, 1.25)
                effect = [['speed', str(rate)], ['rate', str(16000)]]
                waveform, _ = torchaudio.sox_effects.apply_effects_tensor(
                    waveform, 16000, effects=effect
                )

            elif aug == 'pitch_shift':
                # NOTE(review): the inner comprehension discards its loop
                # variable, so the random.choice draw is ignored and the
                # effect is always [['pitch', str(n_steps * 100)]] — likely
                # not what was intended. There is also no trailing 'rate'
                # effect here, unlike the (dead) branch below.
                n_steps = random.randint(-4, 4)
                effect = [['pitch', str(n)] for n in [n_steps*100 for n in [random.choice([-2, -1, 1, 2])]]]
                waveform, _ = torchaudio.sox_effects.apply_effects_tensor(waveform, 16000, effect)

            elif aug == 'add_noise':
                # Additive white noise at a random low amplitude.
                noise = torch.randn_like(waveform) * random.uniform(0.001, 0.015)
                waveform = waveform + noise

            elif aug == 'frequency_mask':
                # Unreachable (see docstring); `T` is undefined here.
                freq_mask = T.FrequencyMasking(freq_mask_param=random.randint(15, 30))
                waveform = freq_mask(waveform)

            elif aug == 'time_mask':
                # Unreachable (see docstring); `T` is undefined here.
                time_mask = T.TimeMasking(time_mask_param=random.randint(20, 80))
                waveform = time_mask(waveform)

            elif aug == 'reverb':
                # Unreachable (see docstring).
                effect = [['reverb', '-w', str(random.randint(10, 50))]]
                waveform, _ = torchaudio.sox_effects.apply_effects_tensor(waveform, 16000, effect)

            elif aug == 'pitch_shift':
                # Dead branch: shadowed by the identical condition above.
                steps = random.randint(-2, 2)
                effect = [['pitch', str(steps * 100)], ['rate', '16000']]
                waveform, _ = torchaudio.sox_effects.apply_effects_tensor(waveform, 16000, effect)

        return waveform

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        # Get sample data
        sample = self.data[idx]

        # Get file path
        audio_path = str(self.tears_root / sample['audio_path'])

        # Load and process audio
        try:
            audio, sr = torchaudio.load(audio_path)

            # Resample if necessary
            if sr != self.sample_rate:
                audio = torchaudio.transforms.Resample(sr, self.sample_rate)(audio)

            if self.augment:
                audio = self.augment_audio(audio, self.sample_rate)

            # Normalize if requested; epsilon guards silent (zero-std) clips.
            if self.normalize_audio:
                mean = torch.mean(audio)
                std = torch.std(audio)
                audio = (audio - mean) / (std + 1e-8)

            # Handle duration
            num_samples = audio.shape[1]

            if num_samples >= self.target_samples:
                # Randomly crop to target duration
                start_sample = random.randint(0, num_samples - self.target_samples)
                audio = audio[:, start_sample:start_sample + self.target_samples]
            else:
                # Pad if shorter than target duration
                pad_size = self.target_samples - num_samples
                audio = torch.nn.functional.pad(audio, (0, pad_size))

        except Exception as e:
            # Broad catch keeps training alive on a bad file; a zero tensor
            # stands in for the failed sample.
            warnings.warn(f"Error loading audio file {audio_path}: {str(e)}")
            # Return zero tensor if audio loading fails
            audio = torch.zeros(1, self.target_samples)

        # Get prompt and response
        prompts = sample.get('prompts', [])
        responses = sample.get('responses', [])

        # Pick one aligned prompt/response pair at random; None when the
        # two lists are absent or of mismatched length.
        if prompts and responses and len(prompts) == len(responses):
            rand_idx = random.randint(0, len(prompts) - 1)
            prompt = prompts[rand_idx]
            response = responses[rand_idx].replace("\n", " ").strip()
        else:
            prompt = None
            response = None

        return {
            'audio_tensor': audio,
            'sid': sample['speaker']['id'],
            'metadata': sample['speaker'],
            'prompt': prompt,
            'answer': response,
            'filename': str(audio_path)
        }



    @staticmethod
    def redistribute_speakers(
        json_paths: Dict[str, str],
        split_ratios: Dict[str, float],
        seed: int = 42
    ) -> Dict[str, List[Dict]]:
        """
        Redistribute speakers across splits according to given ratios.

        Speaker-disjoint: all of a speaker's samples land in exactly one split.

        Args:
            json_paths: Dict mapping split names to json file paths
            split_ratios: Dict mapping split names to desired ratios (should sum to 1)
            seed: Random seed for reproducibility

        Returns:
            Dict mapping split names to lists of samples
        """
        random.seed(seed)

        # Collect all samples and group by speaker
        speaker_samples = defaultdict(list)
        for split, path in json_paths.items():
            with open(path, 'r') as f:
                data = json.load(f)
            for sample in data:
                speaker_samples[sample['speaker']['id']].append(sample)

        # Get list of all speakers
        all_speakers = list(speaker_samples.keys())
        random.shuffle(all_speakers)

        # Calculate number of speakers for each split
        total_speakers = len(all_speakers)
        split_speakers = {
            split: int(ratio * total_speakers)
            for split, ratio in split_ratios.items()
        }

        # Adjust for rounding errors: int() floors, so the counts can only
        # undershoot; remainder is always >= 0 when ratios sum to 1.
        remainder = total_speakers - sum(split_speakers.values())
        if remainder > 0:
            # Add remaining speakers to first split
            split_speakers[list(split_speakers.keys())[0]] += remainder

        # Distribute speakers to splits
        new_splits = defaultdict(list)
        current_idx = 0

        for split, num_speakers in split_speakers.items():
            split_speaker_ids = all_speakers[current_idx:current_idx + num_speakers]
            for speaker_id in split_speaker_ids:
                new_splits[split].extend(speaker_samples[speaker_id])
            current_idx += num_speakers

        return new_splits

    @staticmethod
    def save_splits(splits: Dict[str, List[Dict]], output_dir: str):
        """Save redistributed splits to JSON files (one file per split)."""
        output_dir = Path(output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)

        for split_name, samples in splits.items():
            output_path = output_dir / f"tears_dataset_{split_name}_with_responses.json"
            with open(output_path, 'w') as f:
                json.dump(samples, f, indent=2)
232
+
load_data/timit.py ADDED
@@ -0,0 +1,102 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch.utils.data import Dataset
3
+ import json
4
+ import torchaudio
5
+ import os
6
+ from typing import Optional, Dict, Any, List, Tuple
7
+ import pandas as pd
8
+ import warnings
9
+ import random
10
+
11
class TIMITDataset(Dataset):
    """TIMIT dataset yielding fixed-length (3 s) audio crops plus one
    randomly chosen prompt/response pair per utterance.

    Args:
        json_path: Path to the JSON file containing TIMIT sample records.
            Each record must provide 'audio_path' and 'speaker'->'id', and
            may provide parallel 'prompts' and 'responses' lists.
        timit_root: Root directory containing TIMIT audio files.
        sample_rate: Target sample rate for audio. Defaults to 16000.
        normalize_audio: Whether to mean/std-normalize the waveform.
            Defaults to True.

    __getitem__ returns a dict with:
        - 'audio_tensor': torch.Tensor of shape (channels, 3 * sample_rate)
        - 'sid': speaker identifier string
        - 'prompt' / 'answer': one randomly selected prompt/response pair,
          or None when the record has none (or mismatched counts)
        - 'filename': full path of the loaded audio file
    """
    def __init__(
        self,
        json_path: str,
        timit_root: str,
        sample_rate: int = 16000,
        normalize_audio: bool = True
    ):
        super().__init__()

        # Load the JSON manifest once up front.
        with open(json_path, 'r') as f:
            self.data = json.load(f)

        self.timit_root = timit_root
        self.sample_rate = sample_rate
        self.normalize_audio = normalize_audio

    def __len__(self) -> int:
        return len(self.data)

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        sample = self.data[idx]
        audio_path = os.path.join(self.timit_root, sample['audio_path'])

        audio, sr = torchaudio.load(audio_path)

        if sr != self.sample_rate:
            audio = torchaudio.transforms.Resample(sr, self.sample_rate)(audio)

        # Fix: this flag was previously ignored and normalization always ran.
        if self.normalize_audio:
            mean = torch.mean(audio)
            std = torch.std(audio)
            audio = (audio - mean) / (std + 1e-8)

        # Crop a random 3-second window, or zero-pad shorter clips up to 3 s.
        num_samples = audio.shape[1]
        num_samples_3s = 3 * self.sample_rate

        if num_samples >= num_samples_3s:
            start_sample = random.randint(0, num_samples - num_samples_3s)
            end_sample = start_sample + num_samples_3s
            audio = audio[:, start_sample:end_sample]
        else:
            pad_size = num_samples_3s - num_samples
            audio = torch.nn.functional.pad(audio, (0, pad_size))

        prompts = sample.get('prompts', [])
        answers = sample.get('responses', [])

        # Only use prompt/response pairs when the two lists line up.
        if prompts and answers and len(prompts) == len(answers):
            rand_idx = random.randint(0, len(prompts) - 1)
            prompt = prompts[rand_idx]
            answer = answers[rand_idx].replace("\n", " ").strip()  # Clean response
        else:
            prompt = None
            answer = None

        return {
            'audio_tensor': audio,
            'sid': sample['speaker']['id'],
            'prompt': prompt,
            'answer': answer,
            'filename': audio_path,
        }
+ }
load_data/voxceleb.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch.utils.data import Dataset
3
+ import pandas as pd
4
+ import torchaudio
5
+ import random
6
+ import os
7
+
8
class ZeroShotDataset(Dataset):
    """Zero-shot evaluation dataset backed by a CSV manifest.

    Each row must provide 'File_Path' (relative to `root`), 'Ground_Truth'
    and 'Prompt' columns. Audio is resampled to 16 kHz, mean/std normalized
    and cropped/padded to exactly 3 seconds.

    Args:
        csv_path (str): Path to the CSV manifest.
        transform (callable, optional): Optional transform applied to the
            raw waveform (before resampling).
        root (str): Base directory that 'File_Path' entries are relative to.
            Previously hard-coded; kept as the default for compatibility.
    """

    DEFAULT_ROOT = "/ocean/projects/cis220031p/psamal/preprocess_TIMIT/"

    def __init__(self, csv_path, transform=None, root=DEFAULT_ROOT):
        self.data = pd.read_csv(csv_path)
        self.transform = transform
        self.sample_rate = 16000
        self.root = root

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        row = self.data.iloc[idx]

        # Load audio file (path is relative to the configurable root).
        audio, sr = torchaudio.load(os.path.join(self.root, row["File_Path"]))

        # Apply transformation if provided.
        if self.transform:
            audio = self.transform(audio)

        if sr != self.sample_rate:
            audio = torchaudio.transforms.Resample(sr, self.sample_rate)(audio)

        # Per-utterance mean/std normalization.
        mean = torch.mean(audio)
        std = torch.std(audio)
        audio = (audio - mean) / (std + 1e-8)

        # Crop a random 3-second window, or zero-pad shorter clips up to 3 s.
        num_samples = audio.shape[1]
        num_samples_3s = 3 * self.sample_rate

        if num_samples >= num_samples_3s:
            start_sample = random.randint(0, num_samples - num_samples_3s)
            end_sample = start_sample + num_samples_3s
            audio = audio[:, start_sample:end_sample]
        else:
            pad_size = num_samples_3s - num_samples
            audio = torch.nn.functional.pad(audio, (0, pad_size))

        return {
            "sid": "WBT0",  # fixed placeholder speaker id for zero-shot eval
            "audio_tensor": audio,
            "answer": row["Ground_Truth"],
            "prompt": row["Prompt"],
            'filename': row["File_Path"],
        }
+ }
mapper.py ADDED
@@ -0,0 +1,245 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ from torch.nn import functional as nnf
6
+ from typing import Tuple, Optional
7
+
8
def get_sid_mapper(map_type: str, emb_size, prefix_size: int, gpt_embedding_size: int, prefix_length: int, clip_length: int, num_layers: int):
    """Build the speaker-ID prefix mapping network.

    Returns an MLP or TransformerMapper (per `map_type`) with all parameters
    left trainable. Raises ValueError for an unknown `map_type`.
    """
    if map_type == 'mlp':
        hidden = (gpt_embedding_size * prefix_length) // 2
        mapper = MLP(emb_size, (prefix_size, hidden, gpt_embedding_size * prefix_length))
    elif map_type == 'transformer':
        mapper = TransformerMapper(emb_size, prefix_size, gpt_embedding_size,
                                   prefix_length, clip_length, int(num_layers / 2))
    else:
        raise ValueError(f"Unknown mapping type {map_type}")

    # Mapper weights are always trainable (the LLM may be frozen separately).
    for param in mapper.parameters():
        param.requires_grad = True

    return mapper
+
24
def get_text_mapper(map_type: str, emb_size, prefix_size: int, gpt_embedding_size: int, prefix_length: int, clip_length: int, num_layers: int):
    """Build the text prefix mapping network.

    Returns an MLP or TransformerMapperSeq (per `map_type`) with all
    parameters left trainable. Raises ValueError for an unknown `map_type`.
    """
    if map_type == 'mlp':
        hidden = (gpt_embedding_size * prefix_length) // 2
        mapper = MLP(emb_size, (prefix_size, hidden, gpt_embedding_size * prefix_length))
    elif map_type == 'transformer':
        mapper = TransformerMapperSeq(emb_size, prefix_size, gpt_embedding_size,
                                      prefix_length, clip_length, int(num_layers / 2))
    else:
        raise ValueError(f"Unknown mapping type {map_type}")

    # Mapper weights are always trainable (the LLM may be frozen separately).
    for param in mapper.parameters():
        param.requires_grad = True

    return mapper
39
+
40
+
41
def init_layer(layer):
    """Initialize a Linear or Convolutional layer: Xavier-uniform weight,
    zeroed bias (when the layer has one)."""
    nn.init.xavier_uniform_(layer.weight)
    bias = getattr(layer, 'bias', None)
    if bias is not None:
        bias.data.fill_(0.)
48
+
49
def init_bn(bn):
    """Initialize a Batchnorm/LayerNorm layer to the identity transform
    (weight = 1, bias = 0)."""
    nn.init.zeros_(bn.bias)
    nn.init.ones_(bn.weight)
53
+
54
class Projection(nn.Module):
    """Residual projection head: maps d_in -> d_out via linear1, refines it
    with a GELU + linear2 + dropout branch, and layer-norms the sum."""

    def __init__(self, d_in: int, d_out: int, p: float = 0.5) -> None:
        super().__init__()
        self.linear1 = nn.Linear(d_in, d_out, bias=False)
        self.linear2 = nn.Linear(d_out, d_out, bias=False)
        self.layer_norm = nn.LayerNorm(d_out)
        self.drop = nn.Dropout(p)
        self.init_weight()

    def init_weight(self):
        # Xavier for the linears; identity (weight=1, bias=0) for the norm.
        for linear in (self.linear1, self.linear2):
            init_layer(linear)
        init_bn(self.layer_norm)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        projected = self.linear1(x)
        refined = self.drop(self.linear2(nnf.gelu(projected)))
        return self.layer_norm(projected + refined)
74
+
75
+
76
class MLP(nn.Module):
    """Plain feed-forward stack: Linear layers separated by an activation
    (Tanh by default), with no activation after the final layer.

    `emb_size` is stored but not otherwise used (a projection front-end was
    removed in an earlier revision).
    """

    def __init__(self, emb_size, sizes: Tuple[int, ...], bias=True, act=nn.Tanh):
        super(MLP, self).__init__()
        self.emb_size = emb_size
        modules = []
        last_linear = len(sizes) - 2
        for idx, (d_in, d_out) in enumerate(zip(sizes[:-1], sizes[1:])):
            modules.append(nn.Linear(d_in, d_out, bias=bias))
            if idx < last_linear:
                modules.append(act())
        self.model = nn.Sequential(*modules)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.model(x)
93
+
94
+
95
class MlpTransformer(nn.Module):
    """Transformer feed-forward block: fc1 -> act -> dropout -> fc2 -> dropout.
    Output dim defaults to the input dim when `out_d` is None."""

    def __init__(self, in_dim, h_dim, out_d: Optional[int] = None, act=nnf.relu, dropout=0.):
        super().__init__()
        if out_d is None:
            out_d = in_dim
        self.fc1 = nn.Linear(in_dim, h_dim)
        self.act = act
        self.fc2 = nn.Linear(h_dim, out_d)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        hidden = self.dropout(self.act(self.fc1(x)))
        return self.dropout(self.fc2(hidden))
111
+
112
class MultiHeadAttention(nn.Module):
    """Multi-head (optionally cross-) attention.

    Queries are computed from `x` (dim_self); keys and values from `y`
    (dim_ref), which defaults to `x` for self-attention. Returns both the
    projected output and the raw attention weights.
    """

    def __init__(self, dim_self, dim_ref, num_heads, bias=True, dropout=0.):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim_self // num_heads
        # 1/sqrt(head_dim) scaling applied to the query-key dot products.
        self.scale = head_dim ** -0.5
        self.to_queries = nn.Linear(dim_self, dim_self, bias=bias)
        # Single projection producing keys and values stacked along one axis.
        self.to_keys_values = nn.Linear(dim_ref, dim_self * 2, bias=bias)
        self.project = nn.Linear(dim_self, dim_self)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, y=None, mask=None):
        # x: (b, n, dim_self); y: (b, m, dim_ref), or None for self-attention.
        y = y if y is not None else x
        b, n, c = x.shape
        _, m, d = y.shape
        # b n h dh
        queries = self.to_queries(x).reshape(b, n, self.num_heads, c // self.num_heads)
        # b m 2 h dh
        keys_values = self.to_keys_values(y).reshape(b, m, 2, self.num_heads, c // self.num_heads)
        keys, values = keys_values[:, :, 0], keys_values[:, :, 1]
        # (b, n, m, h): per-head scaled dot products between queries and keys.
        attention = torch.einsum('bnhd,bmhd->bnmh', queries, keys) * self.scale
        if mask is not None:
            if mask.dim() == 2:
                # Broadcast a per-key mask over the query dimension.
                mask = mask.unsqueeze(1)
            # True positions are excluded from the softmax below.
            attention = attention.masked_fill(mask.unsqueeze(3), float("-inf"))
        attention = attention.softmax(dim=2)  # normalize over key positions m
        out = torch.einsum('bnmh,bmhd->bnhd', attention, values).reshape(b, n, c)
        out = self.project(out)
        # NOTE(review): self.dropout is constructed but never applied in
        # forward — presumably intended on the attention weights; confirm.
        return out, attention
142
+
143
+
144
class TransformerLayer(nn.Module):
    """Pre-norm transformer block: residual attention followed by a residual
    MLP. `y` supplies the key/value sequence for cross-attention (None for
    self-attention)."""

    def __init__(self, dim_self, dim_ref, num_heads, mlp_ratio=4., bias=False, dropout=0., act=nnf.relu,
                 norm_layer: nn.Module = nn.LayerNorm):
        super().__init__()
        self.norm1 = norm_layer(dim_self)
        self.attn = MultiHeadAttention(dim_self, dim_ref, num_heads, bias=bias, dropout=dropout)
        self.norm2 = norm_layer(dim_self)
        self.mlp = MlpTransformer(dim_self, int(dim_self * mlp_ratio), act=act, dropout=dropout)

    def forward_with_attention(self, x, y=None, mask=None):
        # Same computation as forward(), but also returns the attention map.
        attended, attention = self.attn(self.norm1(x), y, mask)
        x = x + attended
        x = x + self.mlp(self.norm2(x))
        return x, attention

    def forward(self, x, y=None, mask=None):
        x = x + self.attn(self.norm1(x), y, mask)[0]
        return x + self.mlp(self.norm2(x))
164
+
165
+
166
class Transformer(nn.Module):
    """Stack of TransformerLayers.

    With enc_dec=True the layer count is doubled and layers alternate
    cross-attention (even indices, attending to `y`) with self-attention
    (odd indices). Otherwise all layers attend to `y` (or to `x` itself
    when `y` is None).
    """

    def __init__(self, dim_self: int, num_heads: int, num_layers: int, dim_ref: Optional[int] = None,
                 mlp_ratio: float = 2., act=nnf.relu, norm_layer: nn.Module = nn.LayerNorm, enc_dec: bool = False):
        super(Transformer, self).__init__()
        if dim_ref is None:
            dim_ref = dim_self
        self.enc_dec = enc_dec
        if enc_dec:
            # One cross layer + one self layer per requested layer.
            num_layers = num_layers * 2
        blocks = []
        for idx in range(num_layers):
            if enc_dec and idx % 2 == 1:
                # Odd enc_dec layers: self-attention (keys/values from x).
                blocks.append(TransformerLayer(dim_self, dim_self, num_heads, mlp_ratio, act=act, norm_layer=norm_layer))
            else:
                # Even enc_dec layers (cross) and all non-enc_dec layers.
                blocks.append(TransformerLayer(dim_self, dim_ref, num_heads, mlp_ratio, act=act, norm_layer=norm_layer))
        self.layers = nn.ModuleList(blocks)

    def forward_with_attention(self, x, y=None, mask=None):
        attentions = []
        for layer in self.layers:
            x, attn = layer.forward_with_attention(x, y, mask)
            attentions.append(attn)
        return x, attentions

    def forward(self, x, y=None, mask=None):
        for idx, layer in enumerate(self.layers):
            if self.enc_dec and idx % 2 == 0:
                x = layer(x, y)        # cross-attention to y
            elif self.enc_dec:
                x = layer(x, x, mask)  # self-attention
            else:
                x = layer(x, y, mask)
        return x
200
+
201
+
202
class TransformerMapper(nn.Module):
    """Maps a single embedding vector to a sequence of prefix embeddings.

    The input (batch, dim_clip) is expanded by a linear layer into
    `clip_length` tokens, concatenated with learned prefix queries, run
    through a transformer, and the transformed prefix positions are
    returned: (batch, prefix_length, dim_embedding).

    Bug fix: forward() projects the input whenever `emb_size` is not None,
    but the projector was never created in __init__ (the line had been
    commented out), so any non-None `emb_size` raised AttributeError. The
    projector is restored.
    """

    def __init__(self, emb_size, dim_clip: int, dim_embedding: int, prefix_length: int, clip_length: int, num_layers: int = 8):
        super(TransformerMapper, self).__init__()
        self.emb_size = emb_size
        if self.emb_size is not None:
            # forward() depends on this attribute when emb_size is given.
            self.projector = Projection(emb_size, dim_clip)
        self.clip_length = clip_length
        self.transformer = Transformer(dim_embedding, 8, num_layers)
        self.linear = nn.Linear(dim_clip, clip_length * dim_embedding)
        self.prefix_const = nn.Parameter(torch.randn(prefix_length, dim_embedding), requires_grad=True)

    def forward(self, x):
        # Optionally project (batch, emb_size) down/up to dim_clip.
        if self.emb_size is not None:
            x = self.projector(x)
        # (batch, dim_clip) -> (batch, clip_length, dim_embedding)
        x = self.linear(x).view(x.shape[0], self.clip_length, -1)
        # Learned prefix queries, expanded per batch element.
        prefix = self.prefix_const.unsqueeze(0).expand(x.shape[0], *self.prefix_const.shape)
        # (batch, clip_length + prefix_length, dim_embedding)
        prefix = torch.cat((x, prefix), dim=1)
        # Keep only the transformed prefix positions.
        out = self.transformer(prefix)[:, self.clip_length:]
        return out
224
+
225
class TransformerMapperSeq(nn.Module):
    """Maps a sequence of embeddings to prefix embeddings.

    The input is reshaped to (batch, clip_length, dim_embedding),
    concatenated with learned prefix queries, run through a transformer,
    and the transformed prefix positions are returned.
    `emb_size` is stored but unused (a projection front-end was removed).
    """

    def __init__(self, emb_size, dim_clip: int, dim_embedding: int, prefix_length: int, clip_length: int, num_layers: int = 8):
        super(TransformerMapperSeq, self).__init__()
        self.emb_size = emb_size
        self.clip_length = clip_length
        self.transformer = Transformer(dim_embedding, 8, num_layers)
        self.prefix_const = nn.Parameter(torch.randn(prefix_length, dim_embedding), requires_grad=True)

    def forward(self, x):
        batch = x.shape[0]
        tokens = x.view(batch, self.clip_length, -1)
        queries = self.prefix_const.unsqueeze(0).expand(batch, *self.prefix_const.shape)
        combined = torch.cat((tokens, queries), dim=1)
        # Keep only the transformed prefix positions.
        return self.transformer(combined)[:, self.clip_length:]
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e0d80efbeffb56f4038bf9d320d15b5377d12b1cb85833e908d9f0f6b5c2bbab
3
+ size 2066033810
wrapper.py ADDED
@@ -0,0 +1,305 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ from transformers import AutoTokenizer
3
+ import os
4
+ import torch
5
+ from collections import OrderedDict
6
+ import librosa
7
+ from importlib_resources import files
8
+ import yaml
9
+ import argparse
10
+ import torchaudio
11
+ import torchaudio.transforms as T
12
+ import collections
13
+ import random
14
+ import numpy as np
15
+ import torch
16
+ import torch.nn as nn
17
+ import torch.nn.functional as F
18
+ from torch.nn.parallel import DistributedDataParallel as DDP
19
+ import logging
20
+ from glob import glob
21
+
22
+ from mapper import get_sid_mapper, get_text_mapper
23
+ from transformers import GPT2LMHeadModel
24
+ from transformers import AutoTokenizer
25
+
26
+
27
class ExpWrapper():
    """Bundles the GPT-2 decoder, the speaker-ID prefix mapper and the
    tokenizer, plus checkpoint save/load helpers, batch collation, prompt
    preprocessing and beam-search generation.

    Fixes vs. the previous revision:
      - the duplicated `load_model` definition was removed;
      - `default_collate` no longer crashes on numpy arrays or unsupported
        element types (the regex / error-message attributes it referenced
        were never defined);
      - `load_mapper` no longer runs a dead `sorted(glob(...))[-1]` whose
        result was discarded (it raised IndexError on an empty directory);
      - `load_sid_model` returns the loaded model instead of discarding it;
      - deprecated `pad_to_max_length=True` replaced by
        `padding="max_length"` (same behavior).
    """

    # Error template used by default_collate (mirrors torch's own collate).
    default_collate_err_msg_format = (
        "default_collate: batch must contain tensors, numpy arrays, numbers, "
        "dicts or lists; found {}")

    def __init__(self, config_wrapper, gpu_id):
        """Build decoder, mapper and tokenizer from a config dict.

        Args:
            config_wrapper: dict providing tok_len, text_prefix_length,
                sid_prefix_length, norm_sid_emb, text_decoder, map_type,
                prefix_size, sid_prefix_length_clip and num_layers.
            gpu_id: device index (or device string) the modules move to.
        """
        self.tok_len = config_wrapper['tok_len']
        self.text_prefix_length = config_wrapper['text_prefix_length']
        self.sid_prefix_length = config_wrapper['sid_prefix_length']
        self.norm_sid_emb = config_wrapper['norm_sid_emb']
        self.gpu_id = gpu_id
        self.gpt = GPT2LMHeadModel.from_pretrained(config_wrapper['text_decoder'])
        self.gpt = self.gpt.to(self.gpu_id)

        # Token embedding width of the decoder (e.g. 768 for GPT-2 base).
        self.gpt_embedding_size = self.gpt.transformer.wte.weight.shape[1]

        self.sid_mapper = get_sid_mapper(config_wrapper["map_type"], None,
                                         config_wrapper["prefix_size"], self.gpt_embedding_size,
                                         config_wrapper["sid_prefix_length"], config_wrapper["sid_prefix_length_clip"],
                                         config_wrapper["num_layers"])
        self.sid_mapper = self.sid_mapper.to(self.gpu_id)

        self.tokenizer = AutoTokenizer.from_pretrained(config_wrapper['text_decoder'])
        self.tokenizer.add_special_tokens({'pad_token': '!'})

    def init_mapper(self):
        """Wrap the mapper in DistributedDataParallel for multi-GPU training."""
        self.sid_mapper = DDP(self.sid_mapper, device_ids=[self.gpu_id], find_unused_parameters=True)

    def freeze_llm(self):
        """Freeze both the mapper and the GPT-2 decoder."""
        for param in self.sid_mapper.parameters():
            param.requires_grad = False
        for param in self.gpt.parameters():
            param.requires_grad = False

    def default_collate(self, batch):
        r"""Puts each data field into a tensor with outer dimension batch size."""
        elem = batch[0]
        elem_type = type(elem)
        if isinstance(elem, torch.Tensor):
            out = None
            if torch.utils.data.get_worker_info() is not None:
                # In a background worker, stack directly into shared memory
                # to avoid an extra copy.
                numel = sum([x.numel() for x in batch])
                storage = elem.storage()._new_shared(numel)
                out = elem.new(storage)
            return torch.stack(batch, 0, out=out)
        elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
                and elem_type.__name__ != 'string_':
            if elem_type.__name__ == 'ndarray' or elem_type.__name__ == 'memmap':
                # String/object dtype arrays cannot become tensors.
                # (Previously this referenced an undefined attribute.)
                import re
                if re.search('[SaUO]', elem.dtype.str) is not None:
                    raise TypeError(
                        self.default_collate_err_msg_format.format(elem.dtype))
                return self.default_collate([torch.as_tensor(b) for b in batch])
            elif elem.shape == ():  # numpy scalars
                return torch.as_tensor(batch)
        elif isinstance(elem, float):
            return torch.tensor(batch, dtype=torch.float64)
        elif isinstance(elem, int):
            return torch.tensor(batch)
        elif isinstance(elem, collections.abc.Mapping):
            return {key: self.default_collate([d[key] for d in batch]) for key in elem}
        elif isinstance(elem, tuple) and hasattr(elem, '_fields'):  # namedtuple
            return elem_type(*(self.default_collate(samples) for samples in zip(*batch)))
        elif isinstance(elem, collections.abc.Sequence):
            # check to make sure that the elements in batch have consistent size
            it = iter(batch)
            elem_size = len(next(it))
            if not all(len(elem) == elem_size for elem in it):
                raise RuntimeError(
                    'each element in list of batch should be of equal size')
            transposed = zip(*batch)
            return [self.default_collate(samples) for samples in transposed]

        raise TypeError(self.default_collate_err_msg_format.format(elem_type))

    def load_model(self, st, model):
        """Load a state dict into `model`, stripping a DDP 'module.' prefix
        on retry when the direct load fails. Returns the model."""
        try:
            model.load_state_dict(st)
        except Exception:
            for key in list(st.keys()):
                if "module." in key:
                    st[key.replace("module.", "")] = st.pop(key)
            model.load_state_dict(st)
        return model

    def load_sid_model(self, sid_model, snapshot_path, sid_ck_name):
        """Load the speaker-ID model checkpoint and return the loaded model.

        Previously the loaded model / metadata were computed and discarded.
        """
        loc = f"cuda:{self.gpu_id}"
        sid_model_path = f"{snapshot_path}/{sid_ck_name}"
        snapshot = torch.load(sid_model_path, map_location=loc)
        sid_model = self.load_model(snapshot["sid_model"], sid_model)
        logging.info(
            f"Loaded sid model from epoch {snapshot['epochs_run']} "
            f"(val_loss={snapshot['val_loss']})")
        return sid_model

    def load_mapper(self, snapshot_path, mapper_ck_name):
        """Load mapper weights from the named checkpoint and record its epoch."""
        loc = f"cuda:{self.gpu_id}"
        mapper_path = f"{snapshot_path}/{mapper_ck_name}"
        snapshot = torch.load(mapper_path, map_location=loc)

        self.sid_mapper = self.load_model(snapshot["sid_mapper"], self.sid_mapper)

        self.epochs_run = snapshot["epochs_run"]
        logging.info(f"Resuming training from mapper at Epoch {self.epochs_run}")

    def save_mapper(self, epoch, snapshot_path, val_epoch_ce_llm):
        """Persist mapper weights and the current epoch under snapshot_path."""
        mapper = {
            "sid_mapper": self.sid_mapper.state_dict(),
            "epochs_run": epoch,
        }
        torch.save(mapper, f"{snapshot_path}/unfrozen_mapper_epoch_{str(epoch).zfill(4)}_val_epoch_ce_llm_{val_epoch_ce_llm}.pt")
        logging.info(f"Epoch {epoch} | Training mapper saved at {snapshot_path}")

    def preprocess_prompt(self, texts):
        r"""Tokenize a list of prompts (fixed length 10) and collate them."""
        tokenized_texts = []
        for ttext in texts:
            tok = self.tokenizer.encode_plus(
                text=ttext, add_special_tokens=True,
                max_length=10,
                padding="max_length", return_tensors="pt", truncation=True)
            for key in tok.keys():
                tok[key] = tok[key].reshape(-1).to(self.gpu_id)
            tokenized_texts.append(tok)
        return self.default_collate(tokenized_texts)

    def preprocess_prompt_single(self, texts):
        r"""Tokenize a single prompt string (fixed length 10) and collate it."""
        tokenized_texts = []
        tok = self.tokenizer.encode_plus(
            text=texts, add_special_tokens=True,
            max_length=10,
            padding="max_length", return_tensors="pt", truncation=True)
        for key in tok.keys():
            tok[key] = tok[key].reshape(-1).to(self.gpu_id)
        tokenized_texts.append(tok)
        return self.default_collate(tokenized_texts)

    def preprocess_text(self, texts):
        r"""Tokenize target texts (with an EOS marker appended) to tok_len."""
        tokenized_texts = []
        for ttext in texts:
            ttext = ttext + ' <|endoftext|>'
            tok = self.tokenizer.encode_plus(
                text=ttext, add_special_tokens=True,
                max_length=self.tok_len,
                padding="max_length", return_tensors="pt", truncation=True)
            for key in tok.keys():
                tok[key] = tok[key].reshape(-1).to(self.gpu_id)
            tokenized_texts.append(tok)
        return self.default_collate(tokenized_texts)

    def _get_text_embeddings(self, preprocessed_texts):
        r"""Look up (frozen) GPT-2 token embeddings for tokenized text."""
        with torch.no_grad():
            texts_embed = self.gpt.transformer.wte(preprocessed_texts['input_ids'])
        return texts_embed

    def get_sid_prefix(self, sid_embeddings):
        r"""Map speaker embeddings to the prefix fed to the LM.

        Input: (batch, emb); output: (batch, sid_prefix_length, gpt_emb).
        """
        if self.norm_sid_emb:
            # L2-normalize each embedding before mapping.
            sid_embeddings = sid_embeddings / sid_embeddings.norm(2, -1).reshape(-1, 1)

        sids_prefix = self.sid_mapper(sid_embeddings).contiguous().view(
            -1, self.sid_prefix_length, self.gpt_embedding_size)
        return sids_prefix

    def get_prompt_prefix(self, texts):
        r"""Tokenize prompts and return (embeddings, tokenized batch)."""
        preprocessed_texts = self.preprocess_prompt(texts)
        texts_embed = self._get_text_embeddings(preprocessed_texts)
        return texts_embed, preprocessed_texts

    def get_prompt_prefix_single(self, texts):
        r"""Tokenize one prompt and return (embeddings, tokenized batch)."""
        preprocessed_texts = self.preprocess_prompt_single(texts)
        texts_embed = self._get_text_embeddings(preprocessed_texts)
        return texts_embed, preprocessed_texts

    def get_text_prefix(self, texts):
        r"""Tokenize target texts and return (embeddings, tokenized batch)."""
        preprocessed_texts = self.preprocess_text(texts)
        texts_embed = self._get_text_embeddings(preprocessed_texts)
        return texts_embed, preprocessed_texts

    def generate_beam(self, beam_size: int = 1, sids_prefix=None, entry_length=80, temperature=1., stop_token: str = ' <|endoftext|>'):
        """Beam-search decode conditioned on the speaker prefix embeddings.

        Args:
            beam_size: number of beams kept at each step.
            sids_prefix: (1, prefix_len, emb) prefix embeddings to start from.
            entry_length: maximum number of generated tokens.
            temperature: logit temperature (values <= 0 are treated as 1.0).
            stop_token: string whose first token id terminates a beam.

        Returns:
            List of decoded strings sorted by descending beam score.
        """
        stop_token_index = self.tokenizer.encode(stop_token)[0]
        tokens = None
        scores = None
        device = next(self.gpt.parameters()).device
        seq_lengths = torch.ones(beam_size, device=device)
        is_stopped = torch.zeros(beam_size, device=device, dtype=torch.bool)
        with torch.no_grad():
            generated = sids_prefix  # start from the sid prefix embeddings
            for i in range(entry_length):
                outputs = self.gpt(inputs_embeds=generated)
                logits = outputs.logits
                logits = logits[:, -1, :] / (temperature if temperature > 0 else 1.0)
                logits = logits.softmax(-1).log()
                if scores is None:
                    # First step: fan the single prefix out into beam_size beams.
                    scores, next_tokens = logits.topk(beam_size, -1)
                    generated = generated.expand(beam_size, *generated.shape[1:])
                    next_tokens, scores = next_tokens.permute(1, 0), scores.squeeze(0)
                    if tokens is None:
                        tokens = next_tokens
                    else:
                        tokens = tokens.expand(beam_size, *tokens.shape[1:])
                        tokens = torch.cat((tokens, next_tokens), dim=1)
                else:
                    # Stopped beams may only extend with token 0 at zero cost.
                    logits[is_stopped] = -float(np.inf)
                    logits[is_stopped, 0] = 0
                    scores_sum = scores[:, None] + logits
                    seq_lengths[~is_stopped] += 1
                    # Rank candidate continuations by length-normalized score.
                    scores_sum_average = scores_sum / seq_lengths[:, None]
                    scores_sum_average, next_tokens = scores_sum_average.view(-1).topk(beam_size, -1)
                    next_tokens_source = next_tokens // scores_sum.shape[1]
                    seq_lengths = seq_lengths[next_tokens_source]
                    next_tokens = next_tokens % scores_sum.shape[1]
                    next_tokens = next_tokens.unsqueeze(1)
                    tokens = tokens[next_tokens_source]
                    tokens = torch.cat((tokens, next_tokens), dim=1)
                    generated = generated[next_tokens_source]
                    scores = scores_sum_average * seq_lengths
                    is_stopped = is_stopped[next_tokens_source]

                next_token_embed = self.gpt.transformer.wte(next_tokens.squeeze()).view(generated.shape[0], 1, -1)
                generated = torch.cat((generated, next_token_embed), dim=1)
                is_stopped = is_stopped + next_tokens.eq(stop_token_index).squeeze()
                if is_stopped.all():
                    break
        scores = scores / seq_lengths
        output_list = tokens.cpu().numpy()
        output_texts = [self.tokenizer.decode(output[:int(length)]) for output, length in zip(output_list, seq_lengths)]
        order = scores.argsort(descending=True)
        output_texts = [output_texts[i] for i in order]
        return output_texts
304
+
305
+