SeaSky1027 committed on
Commit
8e60cc8
·
1 Parent(s): db20897

Add CLAP & HiFiGAN

Browse files
CLAP/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
import os
import sys

# Put this package's directory on sys.path so that its submodules
# (htsat, data, ...) can also be imported by their bare names — the
# vendored modules use flat imports internally.
dir_path = os.path.dirname(os.path.abspath(__file__))
sys.path.append(dir_path)

# Public entry point of the package.
from .clap_module import CLAP_Module
CLAP/clap_model.py ADDED
@@ -0,0 +1,165 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ from pathlib import Path
3
+ from dataclasses import dataclass
4
+
5
+ import torch
6
+ from torch import nn
7
+ import torch.nn.functional as F
8
+
9
+ from transformers import RobertaModel
10
+ from .htsat import create_htsat_model
11
+
12
+ BASE_DIR = Path(__file__).resolve().parent
13
+
14
class MLPLayers(nn.Module):
    """Stack of Linear -> activation -> Dropout layers.

    The trailing activation and dropout are stripped so the output is the
    raw result of the final Linear layer.
    """

    def __init__(self, units=None, nonlin=nn.ReLU(), dropout=0.1):
        """
        Args:
            units: layer widths, e.g. [512, 512, 512] (the default when None).
                None sentinel replaces the original mutable-list default,
                which was shared across all calls.
            nonlin: activation module inserted between Linear layers.
            dropout: dropout probability applied after each activation.
        """
        super(MLPLayers, self).__init__()
        if units is None:
            units = [512, 512, 512]
        self.nonlin = nonlin
        self.dropout = dropout

        sequence = []
        for u0, u1 in zip(units[:-1], units[1:]):
            sequence.append(nn.Linear(u0, u1))
            sequence.append(self.nonlin)
            sequence.append(nn.Dropout(self.dropout))
        # drop the final activation + dropout so the stack ends on a Linear
        sequence = sequence[:-2]

        self.sequential = nn.Sequential(*sequence)

    def forward(self, X):
        """Apply the MLP: (..., units[0]) -> (..., units[-1])."""
        X = self.sequential(X)
        return X
32
+
33
+
34
# Audio Config Class
@dataclass
class CLAPAudioCfp:
    """Audio-branch configuration; fields are overwritten from the JSON
    model config via CLAPAudioCfp(**audio_cfg) in CLAP.__init__."""
    model_type: str = "PANN"
    model_name: str = "Cnn14"
    sample_rate: int = 48000  # expected input sample rate (Hz)
    # Param
    audio_length: int = 1024
    window_size: int = 1024   # STFT window size (n_fft)
    hop_size: int = 1024      # STFT hop length
    fmin: int = 50            # mel filterbank low cutoff (Hz)
    fmax: int = 14000         # mel filterbank high cutoff (Hz)
    class_num: int = 527      # output class count (527 — presumably AudioSet; confirm)
    mel_bins: int = 64        # number of mel bands
    clip_samples: int = 480000  # samples per clip (10 s at 48 kHz)
49
+
50
+
51
@dataclass
class CLAPTextCfg:
    """Text-branch configuration; populated from the JSON model config via
    CLAPTextCfg(**text_cfg) in CLAP.__init__."""
    context_length: int  # token sequence length; sizes the causal attention mask
    vocab_size: int
    width: int
    heads: int
    layers: int
    model_type: str  # e.g. 'roberta' (set by CLAP_Module.__init__)
59
+
60
+
61
class CLAP(nn.Module):
    """Contrastive Language-Audio Pretraining model.

    An HTSAT audio encoder and a pretrained RoBERTa text encoder, each
    followed by a projection head mapping into a shared 512-d embedding
    space.
    """

    def __init__(
        self,
        embed_dim: int,
        audio_cfg: CLAPAudioCfp,
        text_cfg: CLAPTextCfg,
    ):
        """
        Args:
            embed_dim: dimension of the HTSAT "embedding" output (input to
                the audio projection head).
            audio_cfg: audio-branch config. NOTE(review): despite the
                annotation, callers pass a plain dict (CLAP(**model_cfg));
                it is converted to a dataclass below.
            text_cfg: text-branch config dict, converted likewise.
        """
        super().__init__()

        audio_cfg = CLAPAudioCfp(**audio_cfg)
        text_cfg = CLAPTextCfg(**text_cfg)
        self.context_length = text_cfg.context_length

        mlp_act_layer = nn.ReLU()

        # audio branch
        self.audio_branch = create_htsat_model(audio_cfg)

        # audio branch parameters
        self.audio_transform = MLPLayers(units=[512,
                                                512,
                                                512], dropout=0.1)

        # projects the HTSAT embedding (embed_dim) into the shared 512-d space
        self.audio_projection = nn.Sequential(
            nn.Linear(embed_dim, 512),
            mlp_act_layer,
            nn.Linear(512, 512)
        )

        # text branch
        self.text_branch = RobertaModel.from_pretrained('roberta-base')

        # text branch parameters
        self.text_transform = MLPLayers(units=[512,
                                               512,
                                               512], dropout=0.1)
        # projects RoBERTa's 768-d pooler output into the shared 512-d space
        self.text_projection = nn.Sequential(
            nn.Linear(768, 512),
            mlp_act_layer,
            nn.Linear(512, 512)
        )

        # learnable temperatures for contrastive logits (CLIP-style init: log(1/0.07))
        self.logit_scale_a = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
        self.logit_scale_t = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
        # non-persistent: the mask is rebuilt from context_length, not checkpointed
        self.register_buffer("attn_mask", self.build_attention_mask(), persistent=False)

        self.init_text_branch_parameters()

    def init_text_branch_parameters(self):
        """Reset the logit temperature parameters to log(1/0.07)."""
        nn.init.constant_(self.logit_scale_a, np.log(1 / 0.07))
        nn.init.constant_(self.logit_scale_t, np.log(1 / 0.07))

    def build_attention_mask(self):
        """Build a causal additive attention mask (-inf above the diagonal)."""
        # lazily create causal attention mask, with full attention between the vision tokens
        # pytorch uses additive attention mask; fill with -inf
        mask = torch.empty(self.context_length, self.context_length)
        mask.fill_(float("-inf"))
        mask.triu_(1)  # zero out the lower diagonal
        return mask

    def get_word_embedding(self, data):
        """Return per-token RoBERTa word embeddings.

        Args:
            data: tokenizer output dict; must contain 'input_ids'. All
                values are moved to the model's device (mutated in place).
        """
        device = next(self.parameters()).device
        for k in data:
            data[k] = data[k].to(device)

        word_embeds = self.text_branch.embeddings.word_embeddings(
            data['input_ids'].to(device=device, non_blocking=True)
        )

        return word_embeds

    def get_text_embedding(self, data, normalize=False):
        """Return sentence-level text embeddings (projection of RoBERTa's
        pooler output).

        Args:
            data: tokenizer output with 'input_ids' and 'attention_mask';
                values are moved to the model's device (mutated in place).
            normalize: L2-normalize the result when True.
        """
        device = next(self.parameters()).device
        for k in data:
            data[k] = data[k].to(device)

        x = self.text_branch(
            input_ids=data["input_ids"].to(device=device, non_blocking=True),
            attention_mask=data["attention_mask"].to(
                device=device, non_blocking=True
            ),
        )["pooler_output"]
        text_embeds = self.text_projection(x)

        if normalize:
            text_embeds = F.normalize(text_embeds, dim=-1)

        return text_embeds

    def get_audio_embedding(self, data, normalize=False):
        """Return audio embeddings.

        Args:
            data: list of per-sample feature dicts (from get_audio_features);
                values are batched by stacking along a new first dim.
            normalize: L2-normalize the result when True.
        """
        device = next(self.parameters()).device
        input_dict = {}
        keys = data[0].keys()
        for k in keys:
            input_dict[k] = torch.cat([d[k].unsqueeze(0) for d in data], dim=0).to(device)

        audio_embeds = self.audio_branch(input_dict, mixup_lambda=None, device=device)["embedding"]
        audio_embeds = self.audio_projection(audio_embeds)

        if normalize:
            audio_embeds = F.normalize(audio_embeds, dim=-1)
        return audio_embeds
165
+
CLAP/clap_module.py ADDED
@@ -0,0 +1,141 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Contrastive Language-Audio Pretraining Model from LAION
3
+ --------------------------------------------------------
4
+ Paper: https://arxiv.org/abs/2211.06687
5
+ Authors (equal contributions): Ke Chen, Yusong Wu, Tianyu Zhang, Yuchen Hui
6
+ Support: LAION
7
+ """
8
+ import os
9
+ import json
10
+ import torch
11
+ import librosa
12
+ import torchaudio
13
+ import transformers
14
+ import numpy as np
15
+ from pathlib import Path
16
+ from packaging import version
17
+
18
+ from .data import get_audio_features
19
+ from .data import int16_to_float32, float32_to_int16
20
+ from .clap_model import CLAP
21
+
22
+ from transformers import RobertaTokenizer
23
+ import wget
24
+
25
+ BASE_DIR = Path(__file__).resolve().parent
26
+
27
class CLAP_Module(torch.nn.Module):
    """User-facing wrapper around the CLAP model.

    Loads the JSON model config, builds the tokenizer and model, loads (or
    downloads) a checkpoint, and exposes audio/text embedding and CLAP-score
    helpers.
    """

    def __init__(self, amodel='HTSAT-tiny', tmodel='roberta') -> None:
        """
        Args:
            amodel: audio model name; selects model_configs/<amodel>.json.
            tmodel: text model type written into the config (e.g. 'roberta').
        """
        super(CLAP_Module, self).__init__()

        config_path = os.path.join(BASE_DIR, 'model_configs', f'{amodel}.json')
        with open(config_path, "r") as f:
            model_cfg = json.load(f)

        self.tokenize = RobertaTokenizer.from_pretrained("roberta-base")

        model_cfg["text_cfg"]["model_type"] = tmodel
        model = CLAP(**model_cfg)

        self.model = model
        self.model_cfg = model_cfg

    def tokenizer(self, text):
        """Tokenize `text` to fixed-length (77-token) pytorch tensors."""
        result = self.tokenize(
            text,
            padding="max_length",
            truncation=True,
            max_length=77,
            return_tensors="pt",
        )
        return result

    def load_ckpt(self, ckpt_folder_path, ckpt_name):
        """Load a checkpoint, downloading it from Hugging Face if missing.

        Args:
            ckpt_folder_path: local directory holding (or receiving) the file.
            ckpt_name: checkpoint filename on disk and on the hub.
        """
        ckpt_path = os.path.join(ckpt_folder_path, ckpt_name)

        if os.path.exists(ckpt_path):
            print(f'Load checkpoint from {ckpt_path}')
        else:
            download_link = 'https://huggingface.co/lukewys/laion_clap/resolve/main/'
            print(f'Download checkpoint from {download_link + ckpt_name}.')
            ckpt_path = wget.download(download_link + ckpt_name, ckpt_folder_path)
            print('Download completed!')
            print()

        checkpoint = torch.load(ckpt_path, map_location='cpu', weights_only=False)

        if isinstance(checkpoint, dict) and "state_dict" in checkpoint:
            state_dict = checkpoint["state_dict"]
        else:
            state_dict = checkpoint

        # strip a DataParallel "module." prefix if present
        if next(iter(state_dict.items()))[0].startswith("module"):
            state_dict = {k[7:]: v for k, v in state_dict.items()}

        # transformers >= 4.31 no longer has position_ids in RoBERTa's state
        # dict; drop the key if present. (The previous unconditional `del`
        # raised KeyError for checkpoints saved without it.)
        if version.parse(transformers.__version__) >= version.parse("4.31.0"):
            state_dict.pop("text_branch.embeddings.position_ids", None)

        self.model.load_state_dict(state_dict)

    def get_audio_embedding(self, x, sr=16000, normalize=False, use_tensor=True):
        """Embed audio into the shared CLAP space.

        Args:
            x: file path, waveform tensor, or a list of either. File inputs
                are loaded at 48 kHz; tensor inputs are resampled from `sr`.
            sr: sample rate of waveform tensor inputs.
            normalize: L2-normalize embeddings.
            use_tensor: return a torch tensor; otherwise a numpy array.
        """
        self.model.eval()
        if isinstance(x, str):
            x = [x]

        audio_input = []
        for audio_waveform in x:

            if isinstance(audio_waveform, str):
                # load the waveform of the shape (T,), should resample to 48000
                audio_waveform, _ = librosa.load(audio_waveform, sr=48000)
            elif sr != 48000:
                audio_waveform = torchaudio.functional.resample(audio_waveform, orig_freq=sr, new_freq=48000)

            if isinstance(audio_waveform, torch.Tensor):
                audio_waveform = audio_waveform.numpy()

            # quantize: round-trip through int16 to match training preprocessing
            audio_waveform = int16_to_float32(float32_to_int16(audio_waveform))
            audio_waveform = torch.from_numpy(audio_waveform).float()

            temp_dict = {}
            temp_dict = get_audio_features(
                temp_dict, audio_waveform, 480000,
                data_truncating='rand_trunc',
                data_filling='repeatpad',
                audio_cfg=self.model_cfg['audio_cfg'],
                require_grad=audio_waveform.requires_grad
            )

            audio_input.append(temp_dict)

        audio_embed = self.model.get_audio_embedding(audio_input, normalize)

        if not use_tensor:
            audio_embed = audio_embed.detach().cpu().numpy()

        return audio_embed

    def get_text_embedding(self, x, normalize=False, use_tensor=True):
        """Embed text.

        Returns:
            (sentence_embeds, word_embeds, sequence_lengths) where
            sequence_lengths is the index of the last non-padding token
            per sentence.
        """
        self.model.eval()
        if isinstance(x, str):
            x = [x]

        token_data = self.tokenizer(x)
        sequence_lengths = (torch.ne(token_data['attention_mask'], 0).sum(-1) - 1)
        sentence_embeds = self.model.get_text_embedding(token_data, normalize)
        word_embeds = self.model.get_word_embedding(token_data)

        if not use_tensor:
            sentence_embeds = sentence_embeds.detach().cpu().numpy()
            word_embeds = word_embeds.detach().cpu().numpy()

        return sentence_embeds, word_embeds, sequence_lengths

    def get_clap_score(self, text, audio, sr=16000):
        """Cosine similarity between text and audio CLAP embeddings.

        Fix: `sr` is now forwarded to get_audio_embedding — previously it
        was ignored and 16000 was hard-coded.
        """
        sentence_embeds, _word_embeds, _sequence_lengths = self.get_text_embedding(text, normalize=True)
        audio_embeds = self.get_audio_embedding(audio, sr=sr, normalize=True)

        clap_score = torch.nn.functional.cosine_similarity(sentence_embeds, audio_embeds, dim=-1)

        return clap_score
CLAP/data.py ADDED
@@ -0,0 +1,137 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torchaudio
3
+ import torchvision
4
+ import numpy as np
5
+ from contextlib import suppress
6
+ import torch.nn.functional as F
7
+
8
def int16_to_float32(x):
    """Map int16-scaled samples onto float32 values in roughly [-1, 1]."""
    scaled = x / 32767.0
    return scaled.astype(np.float32)
10
+
11
def float32_to_int16(x):
    """Clip samples to [-1, 1] and quantize them to int16 range."""
    clipped = np.clip(x, a_min=-1., a_max=1.)
    return (clipped * 32767.).astype(np.int16)
14
+
15
def get_mel(audio_data, audio_cfg):
    """Compute a log-mel spectrogram of `audio_data`.

    Args:
        audio_data: 1-D waveform tensor.
        audio_cfg: dict with 'sample_rate', 'window_size', 'hop_size',
            'mel_bins', 'fmin', 'fmax'.
    Returns:
        (T, n_mels) log-mel spectrogram on audio_data's device.
    """
    # mel shape: (n_mels, T)
    # NOTE(review): the MelSpectrogram transform is rebuilt on every call;
    # callers in a hot loop may want to cache it per audio_cfg.
    mel_tf = torchaudio.transforms.MelSpectrogram(
        sample_rate=audio_cfg['sample_rate'],
        n_fft=audio_cfg['window_size'],
        win_length=audio_cfg['window_size'],
        hop_length=audio_cfg['hop_size'],
        center=True,
        pad_mode="reflect",
        power=2.0,
        norm=None,
        onesided=True,
        n_mels=audio_cfg['mel_bins'],
        f_min=audio_cfg['fmin'],
        f_max=audio_cfg['fmax']
    ).to(audio_data.device)

    mel = mel_tf(audio_data)

    # we use log mel spectrogram as input
    mel = torchaudio.transforms.AmplitudeToDB(top_db=None)(mel)
    return mel.T  # (T, n_mels)
37
+
38
def get_audio_features(sample, audio_data, max_len, data_truncating, data_filling, audio_cfg, require_grad=False):
    """
    Calculate and add audio features to sample.
    Sample: a dict containing all the data of current sample.
    audio_data: a tensor of shape (T) containing audio data.
    max_len: the maximum length of audio data.
    data_truncating: the method of truncating data ('rand_trunc' or 'fusion').
    data_filling: the method of filling data ('repeatpad', 'pad' or 'repeat').
    audio_cfg: a dict containing audio configuration. Comes from model_cfg['audio_cfg'].
    require_grad: whether to require gradient for audio data.
        This is useful when we want to apply gradient-based classifier-guidance.

    Mutates and returns `sample`, adding: 'waveform' (length max_len),
    'longer' (1-element bool tensor: clip exceeded max_len), and
    'mel_fusion' when data_truncating == 'fusion'.
    """
    # run without autograd unless gradients through the features are needed
    grad_fn = suppress if require_grad else torch.no_grad
    with grad_fn():
        if len(audio_data) > max_len:
            if data_truncating == "rand_trunc":
                # random crop happens below; just flag that the clip was longer
                longer = torch.tensor([True])
            elif data_truncating == "fusion":
                # fusion
                mel = get_mel(audio_data, audio_cfg)
                # split to three parts
                chunk_frames = max_len // audio_cfg['hop_size'] + 1  # the +1 related to how the spectrogram is computed
                total_frames = mel.shape[0]
                if chunk_frames == total_frames:
                    # there is a corner case where the audio length is
                    # larger than max_len but smaller than max_len+hop_size.
                    # In this case, we just use the whole audio.
                    mel_fusion = torch.stack([mel, mel, mel, mel], dim=0)
                    sample["mel_fusion"] = mel_fusion
                    longer = torch.tensor([False])
                else:
                    # candidate start frames split into front/middle/back thirds
                    ranges = np.array_split(list(range(0, total_frames - chunk_frames + 1)), 3)

                    if len(ranges[1]) == 0:
                        # if the audio is too short, we just use the first chunk
                        ranges[1] = [0]
                    if len(ranges[2]) == 0:
                        # if the audio is too short, we just use the first chunk
                        ranges[2] = [0]
                    # randomly choose index for each part
                    idx_front = np.random.choice(ranges[0])
                    idx_middle = np.random.choice(ranges[1])
                    idx_back = np.random.choice(ranges[2])
                    # select mel
                    mel_chunk_front = mel[idx_front:idx_front + chunk_frames, :]
                    mel_chunk_middle = mel[idx_middle:idx_middle + chunk_frames, :]
                    mel_chunk_back = mel[idx_back:idx_back + chunk_frames, :]

                    # shrink the mel: a global downsampled view as the 4th channel
                    mel_shrink = torchvision.transforms.Resize(size=[chunk_frames, audio_cfg['mel_bins']])(mel[None])[0]

                    # stack
                    mel_fusion = torch.stack([mel_shrink, mel_chunk_front, mel_chunk_middle, mel_chunk_back], dim=0)
                    sample["mel_fusion"] = mel_fusion
                    longer = torch.tensor([True])
            else:
                raise NotImplementedError(
                    f"data_truncating {data_truncating} not implemented"
                )
            # random crop to max_len (for compatibility)
            overflow = len(audio_data) - max_len
            idx = np.random.randint(0, overflow + 1)
            audio_data = audio_data[idx: idx + max_len]

        else:  # padding if too short
            if len(audio_data) < max_len:  # do nothing if equal
                if data_filling == "repeatpad":
                    # repeat whole clip as many times as fits, then zero-pad the rest
                    n_repeat = int(max_len / len(audio_data))
                    audio_data = audio_data.repeat(n_repeat)

                    audio_data = F.pad(
                        audio_data,
                        (0, max_len - len(audio_data)),
                        mode="constant",
                        value=0,
                    )
                elif data_filling == "pad":
                    # zero-pad only
                    audio_data = F.pad(
                        audio_data,
                        (0, max_len - len(audio_data)),
                        mode="constant",
                        value=0,
                    )
                elif data_filling == "repeat":
                    # tile the clip and truncate to exactly max_len
                    n_repeat = int(max_len / len(audio_data))
                    audio_data = audio_data.repeat(n_repeat + 1)[:max_len]
                else:
                    raise NotImplementedError(
                        f"data_filling {data_filling} not implemented"
                    )
            if data_truncating == 'fusion':
                # short clips: all four fusion channels are the same mel
                mel = get_mel(audio_data, audio_cfg)
                mel_fusion = torch.stack([mel, mel, mel, mel], dim=0)
                sample["mel_fusion"] = mel_fusion
            longer = torch.tensor([False])

        sample["longer"] = longer
        sample["waveform"] = audio_data

    return sample
CLAP/htsat.py ADDED
@@ -0,0 +1,894 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Ke Chen
2
+ # knutchen@ucsd.edu
3
+ # HTS-AT: A HIERARCHICAL TOKEN-SEMANTIC AUDIO TRANSFORMER FOR SOUND CLASSIFICATION AND DETECTION
4
+ # Some layers designed on the model
5
+ # below codes are based and referred from https://github.com/microsoft/Swin-Transformer
6
+ # Swin Transformer for Computer Vision: https://arxiv.org/pdf/2103.14030.pdf
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.nn.functional as F
11
+ from itertools import repeat
12
+ import collections.abc
13
+ import math
14
+ import warnings
15
+
16
+ from torch.nn.init import _calculate_fan_in_and_fan_out
17
+ import torch.utils.checkpoint as checkpoint
18
+
19
+ import random
20
+
21
+ from torchlibrosa.stft import Spectrogram, LogmelFilterBank
22
+ from torchlibrosa.augmentation import SpecAugmentation
23
+
24
+ from itertools import repeat
25
+
26
def interpolate(x, ratio):
    """Interpolate data in time domain. This is used to compensate the
    resolution reduction in downsampling of a CNN.

    Args:
        x: (batch_size, time_steps, classes_num)
        ratio: int, ratio to interpolate
    Returns:
        upsampled: (batch_size, time_steps * ratio, classes_num)
    """
    # repeating each frame `ratio` times consecutively is exactly
    # repeat_interleave along the time axis
    return torch.repeat_interleave(x, ratio, dim=1)
40
+
41
def do_mixup(x, mixup_lambda):
    """Mix each sample with its batch-reversed counterpart.

    Args:
        x: (batch_size, ...)
        mixup_lambda: (batch_size,) per-sample mixing weights
    Returns:
        out: (batch_size, ...)
    """
    flipped = torch.flip(x, dims=[0])
    # move the batch axis last so lambda broadcasts per sample
    mixed = x.transpose(0, -1) * mixup_lambda + flipped.transpose(0, -1) * (1 - mixup_lambda)
    return mixed.transpose(0, -1)
54
+
55
+ # from PyTorch internals
56
# from PyTorch internals
def _ntuple(n):
    """Return a converter that broadcasts a scalar to an n-tuple and
    passes iterables through unchanged."""
    def parse(value):
        if not isinstance(value, collections.abc.Iterable):
            return tuple(repeat(value, n))
        return value
    return parse


to_1tuple = _ntuple(1)
to_2tuple = _ntuple(2)
to_3tuple = _ntuple(3)
to_4tuple = _ntuple(4)
to_ntuple = _ntuple
68
+
69
def drop_path(x, drop_prob: float = 0., training: bool = False):
    """Stochastic depth: randomly zero whole samples of a residual branch,
    rescaling survivors by 1/keep_prob so the expectation is unchanged.

    Same scheme as the DropConnect-style implementation popularized by
    EfficientNet-type networks; named "drop path" here following
    https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956.
    """
    if not training or drop_prob == 0.:
        return x
    keep_prob = 1.0 - drop_prob
    # one Bernoulli draw per sample, broadcast over all remaining dims
    mask_shape = (x.shape[0],) + (1,) * (x.ndim - 1)
    mask = torch.rand(mask_shape, dtype=x.dtype, device=x.device).add_(keep_prob).floor_()
    return x.div(keep_prob) * mask
85
+
86
+
87
class DropPath(nn.Module):
    """Module wrapper around drop_path (stochastic depth per sample);
    active only in training mode."""

    def __init__(self, drop_prob=None):
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        return drop_path(x, self.drop_prob, self.training)
96
+
97
class PatchEmbed(nn.Module):
    """ 2D Image to Patch Embedding

    Splits a (B, C, H, W) input into (possibly overlapping) patches via a
    strided Conv2d and optionally flattens them into a (B, N, embed_dim)
    token sequence.
    """
    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, norm_layer=None, flatten=True, patch_stride=16):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        patch_stride = to_2tuple(patch_stride)
        self.img_size = img_size
        self.patch_size = patch_size
        self.patch_stride = patch_stride
        # token grid is determined by the stride (patches may overlap when
        # patch_size > patch_stride)
        self.grid_size = (img_size[0] // patch_stride[0], img_size[1] // patch_stride[1])
        self.num_patches = self.grid_size[0] * self.grid_size[1]
        self.flatten = flatten
        self.in_chans = in_chans
        self.embed_dim = embed_dim

        # symmetric padding so each patch window is centered on its stride cell
        padding = ((patch_size[0] - patch_stride[0]) // 2, (patch_size[1] - patch_stride[1]) // 2)

        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_stride, padding=padding)
        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

    def forward(self, x):
        """Embed x: (B, C, H, W) -> (B, N, embed_dim) when flatten, else
        (B, embed_dim, H', W')."""
        B, C, H, W = x.shape
        assert H == self.img_size[0] and W == self.img_size[1], \
            f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
        x = self.proj(x)

        if self.flatten:
            x = x.flatten(2).transpose(1, 2)  # BCHW -> BNC
        x = self.norm(x)
        return x
129
+
130
class Mlp(nn.Module):
    """ MLP as used in Vision Transformer, MLP-Mixer and related networks:
    fc1 -> activation -> dropout -> fc2 -> dropout.
    """
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        hidden_features = hidden_features or in_features
        out_features = out_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))
149
+
150
def _no_grad_trunc_normal_(tensor, mean, std, a, b):
    """Fill `tensor` in place from N(mean, std^2) truncated to [a, b],
    without tracking gradients. Returns `tensor`."""
    # Cut & paste from PyTorch official master until it's in a few official releases - RW
    # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
    def norm_cdf(x):
        # Computes standard normal cumulative distribution function
        return (1. + math.erf(x / math.sqrt(2.))) / 2.

    if (mean < a - 2 * std) or (mean > b + 2 * std):
        warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
                      "The distribution of values may be incorrect.",
                      stacklevel=2)

    with torch.no_grad():
        # Values are generated by using a truncated uniform distribution and
        # then using the inverse CDF for the normal distribution.
        # Get upper and lower cdf values
        l = norm_cdf((a - mean) / std)
        u = norm_cdf((b - mean) / std)

        # Uniformly fill tensor with values from [l, u], then translate to
        # [2l-1, 2u-1].
        tensor.uniform_(2 * l - 1, 2 * u - 1)

        # Use inverse cdf transform for normal distribution to get truncated
        # standard normal
        tensor.erfinv_()

        # Transform to proper mean, std
        tensor.mul_(std * math.sqrt(2.))
        tensor.add_(mean)

        # Clamp to ensure it's in the proper range
        tensor.clamp_(min=a, max=b)
        return tensor
184
+
185
+
186
def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
    # type: (Tensor, float, float, float, float) -> Tensor
    r"""Fills the input Tensor with values drawn from a truncated
    normal distribution. The values are effectively drawn from the
    normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
    with values outside :math:`[a, b]` redrawn until they are within
    the bounds. The method used for generating the random values works
    best when :math:`a \leq \text{mean} \leq b`.
    Args:
        tensor: an n-dimensional `torch.Tensor`
        mean: the mean of the normal distribution
        std: the standard deviation of the normal distribution
        a: the minimum cutoff value
        b: the maximum cutoff value
    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.trunc_normal_(w)
    """
    # thin public wrapper; the in-place work happens under no_grad
    return _no_grad_trunc_normal_(tensor, mean, std, a, b)
205
+
206
+
207
def variance_scaling_(tensor, scale=1.0, mode='fan_in', distribution='normal'):
    """Initialize `tensor` in place with variance scaled by its fan size.

    Args:
        tensor: tensor to fill in place.
        scale: variance scale factor.
        mode: 'fan_in', 'fan_out' or 'fan_avg' — which fan to normalize by.
        distribution: 'truncated_normal', 'normal' or 'uniform'.

    Raises:
        ValueError: for an unknown `mode` or `distribution`.
    """
    fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
    if mode == 'fan_in':
        denom = fan_in
    elif mode == 'fan_out':
        denom = fan_out
    elif mode == 'fan_avg':
        denom = (fan_in + fan_out) / 2
    else:
        # previously an unknown mode crashed later with a confusing
        # NameError on `denom`; fail fast with a clear message instead
        raise ValueError(f"invalid mode {mode}")

    variance = scale / denom

    if distribution == "truncated_normal":
        # constant is stddev of standard normal truncated to (-2, 2)
        trunc_normal_(tensor, std=math.sqrt(variance) / .87962566103423978)
    elif distribution == "normal":
        tensor.normal_(std=math.sqrt(variance))
    elif distribution == "uniform":
        bound = math.sqrt(3 * variance)
        tensor.uniform_(-bound, bound)
    else:
        raise ValueError(f"invalid distribution {distribution}")
228
+
229
+
230
def lecun_normal_(tensor):
    """LeCun-normal init: fan_in variance scaling with a truncated normal."""
    variance_scaling_(tensor, mode='fan_in', distribution='truncated_normal')
232
+
233
def window_partition(x, window_size):
    """Split a feature map into non-overlapping square windows.

    Args:
        x: (B, H, W, C)
        window_size (int): window size
    Returns:
        windows: (num_windows*B, window_size, window_size, C)
    """
    batch, height, width, channels = x.shape
    grid = x.view(batch, height // window_size, window_size,
                  width // window_size, window_size, channels)
    windows = grid.permute(0, 1, 3, 2, 4, 5).reshape(-1, window_size, window_size, channels)
    return windows
245
+
246
+
247
def window_reverse(windows, window_size, H, W):
    """Inverse of window_partition: reassemble windows into a feature map.

    Args:
        windows: (num_windows*B, window_size, window_size, C)
        window_size (int): Window size
        H (int): Height of image
        W (int): Width of image
    Returns:
        x: (B, H, W, C)
    """
    batch = int(windows.shape[0] / (H * W / window_size / window_size))
    grid = windows.view(batch, H // window_size, W // window_size,
                        window_size, window_size, -1)
    return grid.permute(0, 1, 3, 2, 4, 5).reshape(batch, H, W, -1)
261
+
262
+
263
class WindowAttention(nn.Module):
    r""" Window based multi-head self attention (W-MSA) module with relative position bias.
    It supports both of shifted and non-shifted window.
    Args:
        dim (int): Number of input channels.
        window_size (tuple[int]): The height and width of the window.
        num_heads (int): Number of attention heads.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
        proj_drop (float, optional): Dropout ratio of output. Default: 0.0
    """

    def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):

        super().__init__()
        self.dim = dim
        self.window_size = window_size  # Wh, Ww
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5

        # define a parameter table of relative position bias
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads))  # 2*Wh-1 * 2*Ww-1, nH

        # get pair-wise relative position index for each token inside the window
        coords_h = torch.arange(self.window_size[0])
        coords_w = torch.arange(self.window_size[1])
        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
        relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0
        relative_coords[:, :, 1] += self.window_size[1] - 1
        # row-major flattening of the 2-D relative offset into one index
        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
        relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
        self.register_buffer("relative_position_index", relative_position_index)

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        trunc_normal_(self.relative_position_bias_table, std=.02)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, mask=None):
        """
        Args:
            x: input features with shape of (num_windows*B, N, C)
            mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
        Returns:
            (output features of shape (num_windows*B, N, C), attention weights)
        """
        B_, N, C = x.shape
        # one projection produces q, k and v: (3, B_, nH, N, C//nH)
        qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)

        q = q * self.scale
        attn = (q @ k.transpose(-2, -1))

        # look up the learned bias for every token pair in the window
        relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
            self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)  # Wh*Ww,Wh*Ww,nH
        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
        attn = attn + relative_position_bias.unsqueeze(0)

        if mask is not None:
            # additive mask per window (used for shifted windows)
            nW = mask.shape[0]
            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, N, N)
            attn = self.softmax(attn)
        else:
            attn = self.softmax(attn)

        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x, attn

    def extra_repr(self):
        return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}'
345
+
346
+
347
+ # We use the model based on Swintransformer Block, therefore we can use the swin-transformer pretrained model
348
class SwinTransformerBlock(nn.Module):
    r""" Swin Transformer Block.

    Pre-norm transformer block with (shifted-)window multi-head self-attention:
    LN -> W-MSA/SW-MSA -> residual, then LN -> MLP -> residual.  Unlike the
    stock Swin block, ``forward`` also returns the block's attention map.

    Args:
        dim (int): Number of input channels.
        input_resolution (tuple[int]): Input resolution.
        num_heads (int): Number of attention heads.
        window_size (int): Window size.
        shift_size (int): Shift size for SW-MSA.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float, optional): Stochastic depth rate. Default: 0.0
        act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
        norm_before_mlp (str): 'ln' or 'bn' — normalization applied before the MLP.
    """

    def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
                 act_layer=nn.GELU, norm_layer=nn.LayerNorm, norm_before_mlp='ln'):
        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.num_heads = num_heads
        self.window_size = window_size
        self.shift_size = shift_size
        self.mlp_ratio = mlp_ratio
        self.norm_before_mlp = norm_before_mlp
        if min(self.input_resolution) <= self.window_size:
            # if window size is larger than input resolution, we don't partition windows
            self.shift_size = 0
            self.window_size = min(self.input_resolution)
        assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"

        self.norm1 = norm_layer(dim)
        self.attn = WindowAttention(
            dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
            qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)

        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        if self.norm_before_mlp == 'ln':
            self.norm2 = nn.LayerNorm(dim)
        elif self.norm_before_mlp == 'bn':
            # NOTE(review): this lambda constructs a *fresh, untrained*
            # BatchNorm1d on every call; it is never registered as a submodule
            # and its statistics/affine params are discarded each time —
            # confirm this is intentional.
            self.norm2 = lambda x: nn.BatchNorm1d(dim)(x.transpose(1, 2)).transpose(1, 2)
        else:
            raise NotImplementedError
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

        if self.shift_size > 0:
            # calculate attention mask for SW-MSA: windows that wrap around
            # after the cyclic shift must not attend across the wrap boundary.
            H, W = self.input_resolution
            img_mask = torch.zeros((1, H, W, 1))  # 1 H W 1
            h_slices = (slice(0, -self.window_size),
                        slice(-self.window_size, -self.shift_size),
                        slice(-self.shift_size, None))
            w_slices = (slice(0, -self.window_size),
                        slice(-self.window_size, -self.shift_size),
                        slice(-self.shift_size, None))
            cnt = 0
            # Label each of the 3x3 shift regions with a distinct id.
            for h in h_slices:
                for w in w_slices:
                    img_mask[:, h, w, :] = cnt
                    cnt += 1

            mask_windows = window_partition(img_mask, self.window_size)  # nW, window_size, window_size, 1
            mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
            # Pairs of positions with differing region ids get -100 (≈ -inf
            # after softmax), pairs within the same region get 0.
            attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
            attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
        else:
            attn_mask = None

        self.register_buffer("attn_mask", attn_mask)

    def forward(self, x):
        """x: (B, H*W, C) -> (same-shape output, attention map of this block)."""
        # pdb.set_trace()
        H, W = self.input_resolution
        # print("H: ", H)
        # print("W: ", W)
        # pdb.set_trace()
        B, L, C = x.shape
        # assert L == H * W, "input feature has wrong size"

        shortcut = x
        x = self.norm1(x)
        x = x.view(B, H, W, C)

        # cyclic shift
        if self.shift_size > 0:
            shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
        else:
            shifted_x = x

        # partition windows
        x_windows = window_partition(shifted_x, self.window_size)  # nW*B, window_size, window_size, C
        x_windows = x_windows.view(-1, self.window_size * self.window_size, C)  # nW*B, window_size*window_size, C

        # W-MSA/SW-MSA
        attn_windows, attn = self.attn(x_windows, mask=self.attn_mask)  # nW*B, window_size*window_size, C

        # merge windows
        attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
        shifted_x = window_reverse(attn_windows, self.window_size, H, W)  # B H' W' C

        # reverse cyclic shift
        if self.shift_size > 0:
            x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
        else:
            x = shifted_x
        x = x.view(B, H * W, C)

        # FFN
        x = shortcut + self.drop_path(x)
        x = x + self.drop_path(self.mlp(self.norm2(x)))

        return x, attn

    def extra_repr(self):
        return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \
               f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}"
469
+
470
+
471
+
472
class PatchMerging(nn.Module):
    r"""Patch Merging Layer.

    Halves the spatial resolution by grouping each 2x2 neighborhood of
    patches, concatenating their channels (4*C), normalizing, and projecting
    down to 2*C.

    Args:
        input_resolution (tuple[int]): Resolution of input feature.
        dim (int): Number of input channels.
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
    """

    def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
        super().__init__()
        self.input_resolution = input_resolution
        self.dim = dim
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
        self.norm = norm_layer(4 * dim)

    def forward(self, x):
        """x: (B, H*W, C) -> (B, H/2 * W/2, 2*C)."""
        H, W = self.input_resolution
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"
        assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."

        grid = x.view(B, H, W, C)

        # The four interleaved sub-grids of every 2x2 neighborhood, stacked
        # along the channel axis in (even,even), (odd,even), (even,odd),
        # (odd,odd) order.
        corners = [
            grid[:, row::2, col::2, :]  # each: B H/2 W/2 C
            for row, col in ((0, 0), (1, 0), (0, 1), (1, 1))
        ]
        merged = torch.cat(corners, -1).view(B, -1, 4 * C)  # B H/2*W/2 4*C

        merged = self.norm(merged)
        return self.reduction(merged)

    def extra_repr(self):
        return f"input_resolution={self.input_resolution}, dim={self.dim}"
512
+
513
+
514
class BasicLayer(nn.Module):
    """ A basic Swin Transformer layer for one stage.

    Runs ``depth`` SwinTransformerBlocks (alternating W-MSA / SW-MSA) and an
    optional patch-merging downsample.  Returns ``(x, attn)`` where ``attn``
    is the last block's attention map in training mode, or the mean over all
    blocks' attention maps in eval mode.

    Args:
        dim (int): Number of input channels.
        input_resolution (tuple[int]): Input resolution.
        depth (int): Number of blocks.
        num_heads (int): Number of attention heads.
        window_size (int): Local window size.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
        downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
        norm_before_mlp (str): 'ln' or 'bn', forwarded to each block.
    """

    def __init__(self, dim, input_resolution, depth, num_heads, window_size,
                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False,
                 norm_before_mlp='ln'):

        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.depth = depth
        self.use_checkpoint = use_checkpoint

        # build blocks; even-indexed blocks use plain windows, odd-indexed
        # blocks use shifted windows (shift = window_size // 2)
        self.blocks = nn.ModuleList([
            SwinTransformerBlock(dim=dim, input_resolution=input_resolution,
                                 num_heads=num_heads, window_size=window_size,
                                 shift_size=0 if (i % 2 == 0) else window_size // 2,
                                 mlp_ratio=mlp_ratio,
                                 qkv_bias=qkv_bias, qk_scale=qk_scale,
                                 drop=drop, attn_drop=attn_drop,
                                 drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
                                 norm_layer=norm_layer, norm_before_mlp=norm_before_mlp)
            for i in range(depth)])

        # patch merging layer
        if downsample is not None:
            self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)
        else:
            self.downsample = None

    def forward(self, x):
        attns = []
        for blk in self.blocks:
            if self.use_checkpoint:
                # BUG FIX: blk returns (x, attn); the original assigned the
                # whole tuple to x (breaking downsample) and left `attn`
                # unbound (NameError at the return below).
                x, attn = checkpoint.checkpoint(blk, x)
            else:
                x, attn = blk(x)
            if not self.training:
                attns.append(attn.unsqueeze(0))
        if self.downsample is not None:
            x = self.downsample(x)
        if not self.training:
            # Average the attention maps of all blocks for inspection.
            attn = torch.cat(attns, dim=0)
            attn = torch.mean(attn, dim=0)
        return x, attn

    def extra_repr(self):
        return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"
580
+
581
+
582
+ # The Core of HTSAT
583
class HTSAT_Swin_Transformer(nn.Module):
    r"""HTSAT based on the Swin Transformer.

    Converts a raw waveform into a log-mel spectrogram, reshapes it into a
    square "image", and classifies it with a hierarchy of Swin layers plus a
    token-semantic (tscam) conv head.

    Args:
        spec_size (int | tuple(int)): Input Spectrogram size. Default 256
        patch_size (int | tuple(int)): Patch size. Default: 4
        patch_stride (int | tuple(int)): Patch Stride for Frequency and Time Axis. Default: (4, 4)
        in_chans (int): Number of input image channels. Default: 1 (mono)
        num_classes (int): Number of classes for classification head. Default: 527
        embed_dim (int): Patch embedding dimension. Default: 96
        depths (tuple(int)): Depth of each HTSAT-Swin Transformer layer.
        num_heads (tuple(int)): Number of attention heads in different layers.
        window_size (int): Window size. Default: 8
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
        qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None
        drop_rate (float): Dropout rate. Default: 0
        attn_drop_rate (float): Attention dropout rate. Default: 0
        drop_path_rate (float): Stochastic depth rate. Default: 0.1
        norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
        ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
        patch_norm (bool): If True, add normalization after patch embedding. Default: True
        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
        norm_before_mlp (str): 'ln' or 'bn' for the blocks' pre-MLP norm.
        config (module): The configuration Module from config.py (provides
            mel_bins, window_size, hop_size, sample_rate, fmin, fmax).
    """

    def __init__(self, spec_size=256, patch_size=4, patch_stride=(4,4),
                 in_chans=1, num_classes=527,
                 embed_dim=96, depths=[2, 2, 6, 2], num_heads=[4, 8, 16, 32],
                 window_size=8, mlp_ratio=4., qkv_bias=True, qk_scale=None,
                 drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
                 norm_layer=nn.LayerNorm,
                 ape=False, patch_norm=True,
                 use_checkpoint=False, norm_before_mlp='ln', config = None, **kwargs):
        super(HTSAT_Swin_Transformer, self).__init__()

        self.config = config
        self.spec_size = spec_size
        self.patch_stride = patch_stride
        self.patch_size = patch_size
        self.window_size = window_size
        self.embed_dim = embed_dim
        self.depths = depths
        self.ape = ape
        self.in_chans = in_chans
        self.num_classes = num_classes
        self.num_heads = num_heads
        self.num_layers = len(self.depths)
        # Channel width after the final stage (doubles at each merge).
        self.num_features = int(self.embed_dim * 2 ** (self.num_layers - 1))

        self.drop_rate = drop_rate
        self.attn_drop_rate = attn_drop_rate
        self.drop_path_rate = drop_path_rate

        self.qkv_bias = qkv_bias
        # NOTE(review): the qk_scale argument is silently discarded here and
        # None is always used — confirm this is intentional.
        self.qk_scale = None

        self.patch_norm = patch_norm
        self.norm_layer = norm_layer if self.patch_norm else None
        self.norm_before_mlp = norm_before_mlp
        self.mlp_ratio = mlp_ratio

        self.use_checkpoint = use_checkpoint

        # process mel-spec ; used only once
        # How many mel-frequency strips are stacked to fill the square input.
        self.freq_ratio = self.spec_size // self.config.mel_bins
        window = 'hann'
        center = True
        pad_mode = 'reflect'
        ref = 1.0
        amin = 1e-10
        top_db = None
        self.interpolate_ratio = 32  # Downsampled ratio
        # Spectrogram extractor
        self.spectrogram_extractor = Spectrogram(n_fft=config.window_size, hop_length=config.hop_size,
            win_length=config.window_size, window=window, center=center, pad_mode=pad_mode,
            freeze_parameters=True)
        # Logmel feature extractor
        self.logmel_extractor = LogmelFilterBank(sr=config.sample_rate, n_fft=config.window_size,
            n_mels=config.mel_bins, fmin=config.fmin, fmax=config.fmax, ref=ref, amin=amin, top_db=top_db,
            freeze_parameters=True)
        # Spec augmenter (applied only in training mode, see forward)
        self.spec_augmenter = SpecAugmentation(time_drop_width=64, time_stripes_num=2,
            freq_drop_width=8, freq_stripes_num=2)  # 2 2
        self.bn0 = nn.BatchNorm2d(self.config.mel_bins)


        # split spctrogram into non-overlapping patches
        self.patch_embed = PatchEmbed(
            img_size=self.spec_size, patch_size=self.patch_size, in_chans=self.in_chans,
            embed_dim=self.embed_dim, norm_layer=self.norm_layer, patch_stride = patch_stride,
        )

        num_patches = self.patch_embed.num_patches
        patches_resolution = self.patch_embed.grid_size
        self.patches_resolution = patches_resolution

        # absolute position embedding
        if self.ape:
            self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, self.embed_dim))
            trunc_normal_(self.absolute_pos_embed, std=.02)

        self.pos_drop = nn.Dropout(p=self.drop_rate)

        # stochastic depth
        dpr = [x.item() for x in torch.linspace(0, self.drop_path_rate, sum(self.depths))]  # stochastic depth decay rule

        # build layers
        self.layers = nn.ModuleList()
        for i_layer in range(self.num_layers):
            layer = BasicLayer(dim=int(self.embed_dim * 2 ** i_layer),
                               input_resolution=(patches_resolution[0] // (2 ** i_layer),
                                                 patches_resolution[1] // (2 ** i_layer)),
                               depth=self.depths[i_layer],
                               num_heads=self.num_heads[i_layer],
                               window_size=self.window_size,
                               mlp_ratio=self.mlp_ratio,
                               qkv_bias=self.qkv_bias, qk_scale=self.qk_scale,
                               drop=self.drop_rate, attn_drop=self.attn_drop_rate,
                               drop_path=dpr[sum(self.depths[:i_layer]):sum(self.depths[:i_layer + 1])],
                               norm_layer=self.norm_layer,
                               downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
                               use_checkpoint=use_checkpoint,
                               norm_before_mlp=self.norm_before_mlp)
            self.layers.append(layer)

        self.norm = self.norm_layer(self.num_features)
        self.avgpool = nn.AdaptiveAvgPool1d(1)
        self.maxpool = nn.AdaptiveMaxPool1d(1)

        # Frequency extent of the final feature map (tscam kernel height).
        SF = self.spec_size // (2 ** (len(self.depths) - 1)) // self.patch_stride[0] // self.freq_ratio
        self.tscam_conv = nn.Conv2d(
            in_channels = self.num_features,
            out_channels = self.num_classes,
            kernel_size = (SF,3),
            padding = (0,1)
        )
        self.head = nn.Linear(num_classes, num_classes)

        self.apply(self._init_weights)

    def _init_weights(self, m):
        """trunc-normal init for Linear weights, zeros/ones for LayerNorm."""
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    @torch.jit.ignore
    def no_weight_decay(self):
        # Parameter names excluded from weight decay by the optimizer setup.
        return {'absolute_pos_embed'}

    @torch.jit.ignore
    def no_weight_decay_keywords(self):
        return {'relative_position_bias_table'}


    def forward_features(self, x):
        # A deprecated optimization for using a hierarchical output from different blocks

        frames_num = x.shape[2]
        x = self.patch_embed(x)
        if self.ape:
            x = x + self.absolute_pos_embed
        x = self.pos_drop(x)
        for i, layer in enumerate(self.layers):
            x, attn = layer(x)
        # for x
        x = self.norm(x)
        B, N, C = x.shape
        # Recover the 2-D (freq, time) layout of the final token grid.
        SF = frames_num // (2 ** (len(self.depths) - 1)) // self.patch_stride[0]
        ST = frames_num // (2 ** (len(self.depths) - 1)) // self.patch_stride[1]
        x = x.permute(0,2,1).contiguous().reshape(B, C, SF, ST)
        B, C, F, T = x.shape
        # group 2D CNN: undo the freq_ratio stacking done in reshape_wav2img,
        # unfolding the stacked strips back onto the time axis.
        c_freq_bin = F // self.freq_ratio
        x = x.reshape(B, C, F // c_freq_bin, c_freq_bin, T)
        x = x.permute(0,1,3,2,4).contiguous().reshape(B, C, c_freq_bin, -1)
        # get latent_output
        fine_grained_latent_output = torch.mean(x, dim = 2)
        fine_grained_latent_output = interpolate(fine_grained_latent_output.permute(0,2,1).contiguous(), 8 * self.patch_stride[1])

        latent_output = self.avgpool(torch.flatten(x,2))
        latent_output = torch.flatten(latent_output, 1)

        # display the attention map, if needed

        # Token-semantic head: per-class activation over time.
        x = self.tscam_conv(x)
        x = torch.flatten(x, 2)  # B, C, T

        fpx = interpolate(torch.sigmoid(x).permute(0,2,1).contiguous(), 8 * self.patch_stride[1])

        x = self.avgpool(x)
        x = torch.flatten(x, 1)

        output_dict = {
            'framewise_output': fpx,  # already sigmoided
            'clipwise_output': torch.sigmoid(x),
            'fine_grained_embedding': fine_grained_latent_output,
            'embedding': latent_output
        }

        return output_dict

    def crop_wav(self, x, crop_size, spe_pos = None):
        """Crop each spectrogram along the time axis to `crop_size` frames,
        at a random position unless `spe_pos` fixes it."""
        time_steps = x.shape[2]
        tx = torch.zeros(x.shape[0], x.shape[1], crop_size, x.shape[3]).to(x.device)
        for i in range(len(x)):
            if spe_pos is None:
                crop_pos = random.randint(0, time_steps - crop_size - 1)
            else:
                crop_pos = spe_pos
            tx[i][0] = x[i, 0, crop_pos:crop_pos + crop_size,:]
        return tx

    # Reshape the wavform to a img size, if you want to use the pretrained swin transformer model
    def reshape_wav2img(self, x):
        """(B, C, T, F) spectrogram -> square image by cutting the time axis
        into freq_ratio strips and stacking them along frequency."""
        B, C, T, F = x.shape
        target_T = int(self.spec_size * self.freq_ratio)
        target_F = self.spec_size // self.freq_ratio
        assert T <= target_T and F <= target_F, "the wav size should less than or equal to the swin input size"
        # to avoid bicubic zero error
        if T < target_T:
            x = nn.functional.interpolate(x, (target_T, x.shape[3]), mode="bicubic", align_corners=True)
        if F < target_F:
            x = nn.functional.interpolate(x, (x.shape[2], target_F), mode="bicubic", align_corners=True)
        x = x.permute(0,1,3,2).contiguous()
        x = x.reshape(x.shape[0], x.shape[1], x.shape[2], self.freq_ratio, x.shape[3] // self.freq_ratio)
        # print(x.shape)
        x = x.permute(0,1,3,2,4).contiguous()
        x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3], x.shape[4])
        return x

    # Repeat the wavform to a img size, if you want to use the pretrained swin transformer model
    def repeat_wat2img(self, x, cur_pos):
        """Take a spec_size-wide time slice at `cur_pos` and tile it 4x along
        frequency to fill the square input."""
        B, C, T, F = x.shape
        target_T = int(self.spec_size * self.freq_ratio)
        target_F = self.spec_size // self.freq_ratio
        assert T <= target_T and F <= target_F, "the wav size should less than or equal to the swin input size"
        # to avoid bicubic zero error
        if T < target_T:
            x = nn.functional.interpolate(x, (target_T, x.shape[3]), mode="bicubic", align_corners=True)
        if F < target_F:
            x = nn.functional.interpolate(x, (x.shape[2], target_F), mode="bicubic", align_corners=True)
        x = x.permute(0,1,3,2).contiguous()  # B C F T
        x = x[:,:,:,cur_pos:cur_pos + self.spec_size]
        x = x.repeat(repeats = (1,1,4,1))
        return x

    def forward(self, x: torch.Tensor, mixup_lambda = None, infer_mode = False, device=None):  # out_feat_keys: List[str] = None):
        """x: dict with a "waveform" tensor; returns forward_features' dict.

        NOTE(review): `infer_mode` is accepted but unused in this body.
        """
        x = x["waveform"].to(device=device, non_blocking=True)
        x = self.spectrogram_extractor(x)  # (batch_size, 1, time_steps, freq_bins)
        x = self.logmel_extractor(x)  # (batch_size, 1, time_steps, mel_bins)
        x = x.transpose(1, 3)
        x = self.bn0(x)
        x = x.transpose(1, 3)
        if self.training:
            x = self.spec_augmenter(x)

        if self.training and mixup_lambda is not None:
            x = do_mixup(x, mixup_lambda)

        x = self.reshape_wav2img(x)
        output_dict = self.forward_features(x)

        return output_dict
851
+
852
def create_htsat_model(audio_cfg):
    """Build an HTSAT_Swin_Transformer sized by ``audio_cfg.model_name``.

    Args:
        audio_cfg: config object providing at least ``model_name`` (one of
            "tiny" / "base" / "large") and ``class_num``; the full object is
            passed through as the model's ``config``.

    Returns:
        HTSAT_Swin_Transformer: the configured model.

    Raises:
        RuntimeError: if ``model_name`` is unknown, or if construction fails
            (e.g. the audio cfg is missing required parameters); the original
            failure is chained as ``__cause__``.
    """
    # Only embed_dim / depths differ between the three sizes.
    size_params = {
        "tiny": dict(embed_dim=96, depths=[2, 2, 6, 2]),
        "base": dict(embed_dim=128, depths=[2, 2, 12, 2]),
        "large": dict(embed_dim=256, depths=[2, 2, 12, 2]),
    }
    # Validate explicitly instead of the original assert inside a bare
    # `except:` — the error type (RuntimeError) is preserved for callers.
    if getattr(audio_cfg, "model_name", None) not in size_params:
        raise RuntimeError(
            f'Import Model for {audio_cfg.model_name} not found, or the audio cfg parameters are not enough.'
        )
    try:
        return HTSAT_Swin_Transformer(
            spec_size=256,
            patch_size=4,
            patch_stride=(4, 4),
            num_classes=audio_cfg.class_num,
            num_heads=[4, 8, 16, 32],
            window_size=8,
            config=audio_cfg,
            **size_params[audio_cfg.model_name],
        )
    except Exception as err:
        # Chain the original exception so the real failure is not hidden
        # (the original bare `except:` discarded it entirely).
        raise RuntimeError(
            f'Import Model for {audio_cfg.model_name} not found, or the audio cfg parameters are not enough.'
        ) from err
894
+
CLAP/model_configs/HTSAT-base.json ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "embed_dim": 1024,
3
+ "audio_cfg": {
4
+ "audio_length": 1024,
5
+ "clip_samples": 480000,
6
+ "mel_bins": 64,
7
+ "sample_rate": 48000,
8
+ "window_size": 1024,
9
+ "hop_size": 480,
10
+ "fmin": 50,
11
+ "fmax": 14000,
12
+ "class_num": 527,
13
+ "model_type": "HTSAT",
14
+ "model_name": "base"
15
+ },
16
+ "text_cfg": {
17
+ "context_length": 77,
18
+ "vocab_size": 49408,
19
+ "width": 512,
20
+ "heads": 8,
21
+ "layers": 12
22
+ }
23
+ }
CLAP/model_configs/HTSAT-tiny.json ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "embed_dim": 768,
3
+ "audio_cfg": {
4
+ "audio_length": 1024,
5
+ "clip_samples": 480000,
6
+ "mel_bins": 64,
7
+ "sample_rate": 48000,
8
+ "window_size": 1024,
9
+ "hop_size": 480,
10
+ "fmin": 50,
11
+ "fmax": 14000,
12
+ "class_num": 527,
13
+ "model_type": "HTSAT",
14
+ "model_name": "tiny"
15
+ },
16
+ "text_cfg": {
17
+ "context_length": 77,
18
+ "vocab_size": 49408,
19
+ "width": 512,
20
+ "heads": 8,
21
+ "layers": 12
22
+ }
23
+ }
HiFiGAN/hifigan_model.py ADDED
@@ -0,0 +1,177 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from torch.nn import Conv1d, ConvTranspose1d
5
+ from torch.nn.utils import weight_norm, remove_weight_norm
6
+
7
+ LRELU_SLOPE = 0.1
8
+
9
+
10
def init_weights(m, mean=0.0, std=0.01):
    """Re-initialize convolution weights in-place from N(mean, std).

    Intended for ``module.apply``; modules whose class name does not contain
    "Conv" are left untouched.
    """
    if "Conv" in m.__class__.__name__:
        m.weight.data.normal_(mean, std)
14
+
15
+
16
def get_padding(kernel_size, dilation=1):
    """'Same'-style padding for a stride-1 dilated conv: (k - 1) * d // 2."""
    return (kernel_size - 1) * dilation // 2
18
+
19
class AttrDict(dict):
    """Dict whose keys are also readable/writable as attributes."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Aliasing __dict__ to the dict itself makes attribute access and
        # item access share one namespace.
        self.__dict__ = self
23
+
24
class ResBlock(torch.nn.Module):
    """HiFi-GAN residual block.

    ``convs1`` holds three weight-normalized convolutions with the given
    dilations, ``convs2`` three matching dilation-1 convolutions; each
    (dilated, plain) pair is applied leaky-ReLU-first with a residual
    connection around the pair.
    """

    def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)):
        super(ResBlock, self).__init__()
        self.h = h

        def _wn_conv(d):
            # Weight-normalized conv preserving sequence length ("same" pad).
            return weight_norm(
                Conv1d(
                    channels,
                    channels,
                    kernel_size,
                    1,
                    dilation=d,
                    padding=get_padding(kernel_size, d),
                )
            )

        self.convs1 = nn.ModuleList(
            [_wn_conv(d) for d in (dilation[0], dilation[1], dilation[2])]
        )
        self.convs1.apply(init_weights)

        self.convs2 = nn.ModuleList([_wn_conv(1) for _ in range(3)])
        self.convs2.apply(init_weights)

    def forward(self, x):
        for dilated, plain in zip(self.convs1, self.convs2):
            residual = x
            x = dilated(F.leaky_relu(x, LRELU_SLOPE))
            x = plain(F.leaky_relu(x, LRELU_SLOPE))
            x = x + residual
        return x

    def remove_weight_norm(self):
        # The call below resolves to the module-level torch helper — no local
        # binding shadows it inside this method.
        for conv in [*self.convs1, *self.convs2]:
            remove_weight_norm(conv)
114
+
115
+
116
class Generator(torch.nn.Module):
    """HiFi-GAN generator.

    Upsamples a mel-spectrogram (``h.num_mels`` channels) to a 1-channel
    waveform in [-1, 1] via a stack of transposed convolutions, each followed
    by a multi-receptive-field fusion (average of ``num_kernels`` ResBlocks).
    """

    def __init__(self, h):
        super(Generator, self).__init__()
        self.h = h
        self.num_kernels = len(h.resblock_kernel_sizes)
        self.num_upsamples = len(h.upsample_rates)
        self.conv_pre = weight_norm(
            Conv1d(h.num_mels, h.upsample_initial_channel, 7, 1, padding=3)
        )
        resblock = ResBlock

        # One transposed conv per upsample stage; channels halve each stage.
        self.ups = nn.ModuleList()
        for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)):
            self.ups.append(
                weight_norm(
                    ConvTranspose1d(
                        h.upsample_initial_channel // (2**i),
                        h.upsample_initial_channel // (2 ** (i + 1)),
                        k,
                        u,
                        padding=(k - u) // 2,
                    )
                )
            )

        # num_kernels parallel ResBlocks per stage, stored flat:
        # index = stage * num_kernels + kernel.
        self.resblocks = nn.ModuleList()
        for i in range(len(self.ups)):
            # `ch` intentionally leaks out of this loop: conv_post below uses
            # the final stage's channel width.
            ch = h.upsample_initial_channel // (2 ** (i + 1))
            for j, (k, d) in enumerate(
                zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)
            ):
                self.resblocks.append(resblock(h, ch, k, d))

        self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3))
        self.ups.apply(init_weights)
        self.conv_post.apply(init_weights)

    def forward(self, x):
        # x: (B, num_mels, frames) mel-spectrogram -> (B, 1, samples) waveform.
        x = self.conv_pre(x)
        for i in range(self.num_upsamples):
            x = F.leaky_relu(x, LRELU_SLOPE)
            x = self.ups[i](x)
            xs = None
            # Multi-receptive-field fusion: average this stage's ResBlocks.
            for j in range(self.num_kernels):
                if xs is None:
                    xs = self.resblocks[i * self.num_kernels + j](x)
                else:
                    xs += self.resblocks[i * self.num_kernels + j](x)
            x = xs / self.num_kernels
        x = F.leaky_relu(x)
        x = self.conv_post(x)
        x = torch.tanh(x)

        return x

    def remove_weight_norm(self):
        # Calls to `remove_weight_norm(...)` below resolve to the module-level
        # torch helper (no local binding shadows it in this method).
        for l in self.ups:
            remove_weight_norm(l)
        for l in self.resblocks:
            l.remove_weight_norm()
        remove_weight_norm(self.conv_pre)
        remove_weight_norm(self.conv_post)
HiFiGAN/inference.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import torch
4
+
5
+ from .hifigan_model import AttrDict, Generator
6
+
7
def torch_version_orig_mod_remove(state_dict):
    """Strip torch.compile's ``_orig_mod.`` prefix from generator weights.

    Returns a new state dict containing only the "generator" entry, with each
    key's ``_orig_mod.`` prefix removed (keys without it are kept unchanged —
    ``str.replace`` is a no-op in that case).
    """
    return {
        "generator": {
            key.replace("_orig_mod.", ""): value
            for key, value in state_dict["generator"].items()
        }
    }
18
+
19
def get_vocoder(sr, ckpt_path):
    """Load the HiFi-GAN vocoder from ``ckpt_path``.

    Reads ``hifigan_16k_64bins.json`` / ``hifigan_16k_64bins.ckpt`` from the
    directory, builds the Generator, loads the (prefix-cleaned) weights, and
    returns it in eval mode with weight norm removed for inference.
    ``sr`` is accepted for interface compatibility but not used here.
    """
    with open(os.path.join(ckpt_path, "hifigan_16k_64bins.json"), "r") as f:
        h = AttrDict(json.load(f))
    vocoder = Generator(h)

    # NOTE(review): torch.load unpickles arbitrary objects — only load
    # checkpoints from a trusted source.
    state = torch.load(
        os.path.join(ckpt_path, "hifigan_16k_64bins.ckpt"), map_location='cpu'
    )
    state = torch_version_orig_mod_remove(state)
    vocoder.load_state_dict(state["generator"])
    vocoder.eval()
    vocoder.remove_weight_norm()

    return vocoder
generator.py ADDED
@@ -0,0 +1,439 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn
3
+ from torch.nn import init
4
+ from torch.nn import functional as F
5
+
6
class SpectralNorm:
    """Forward-pre-hook implementing spectral normalization of a weight.

    One power-iteration step per forward pass estimates the weight matrix's
    largest singular value ``sigma``; the module's ``<name>`` attribute is
    replaced with ``weight / sigma`` before each forward call.
    """

    def __init__(self, name):
        # Name of the attribute being normalized (e.g. "weight").
        self.name = name

    def compute_weight(self, module):
        """Return (spectrally normalized weight, updated u, sigma)."""
        weight = getattr(module, self.name + '_orig')
        u = getattr(module, self.name + '_u')
        size = weight.size()
        # Flatten to 2-D: (out_features, everything else).
        weight_mat = weight.contiguous().view(size[0], -1)
        with torch.no_grad():
            # One power-iteration step toward the top singular pair (u, v);
            # no gradients flow through the iteration itself.
            v = weight_mat.t() @ u
            v = v / v.norm()
            u = weight_mat @ v
            u = u / u.norm()
        # sigma is computed outside no_grad so the division below is
        # differentiable w.r.t. the original weight.
        sigma = u @ weight_mat @ v
        weight_sn = weight / sigma

        return weight_sn, u, sigma

    @staticmethod
    def apply(module, name):
        """Move ``module.<name>`` to ``<name>_orig`` and install the hook."""
        fn = SpectralNorm(name)

        weight = getattr(module, name)
        # The trainable parameter lives under "<name>_orig"; "<name>" becomes
        # a buffer refreshed by the hook before every forward.
        del module._parameters[name]
        module.register_parameter(name + '_orig', weight)
        input_size = weight.size(0)
        # Persistent power-iteration vector, randomly initialized.
        u = weight.new_empty(input_size).normal_()
        module.register_buffer(name, weight)
        module.register_buffer(name + '_u', u)
        # Scalar buffer exposing the last estimated singular value.
        module.register_buffer(name + '_sv', torch.ones(1).squeeze())

        module.register_forward_pre_hook(fn)

        return fn

    def __call__(self, module, input):
        # Runs before every forward: refresh the normalized weight and stats.
        weight_sn, u, sigma = self.compute_weight(module)
        setattr(module, self.name, weight_sn)
        setattr(module, self.name + '_u', u)
        setattr(module, self.name + '_sv', sigma)
47
+
48
def spectral_norm(module, name='weight'):
    """Attach spectral normalization to ``module.<name>`` (in-place) and
    return the module so calls can be chained."""
    SpectralNorm.apply(module, name)
    return module
51
+
52
def spectral_init(module, gain=1):
    """Xavier-initialize ``module``'s weight (zeroing any bias), then wrap
    the module with spectral normalization and return it."""
    init.xavier_uniform_(module.weight, gain)
    bias = module.bias
    if bias is not None:
        bias.data.zero_()
    return spectral_norm(module)
57
+
58
class ConditionalNorm(nn.Module):
    """BatchNorm2d whose affine parameters come from a condition vector.

    The input is normalized without learned affine terms; per-channel scale
    (gamma) and shift (beta) are produced by two linear maps of ``condition``
    and broadcast over the spatial dimensions.
    """

    def __init__(self, in_channel, condition_dim):
        super().__init__()

        self.bn = nn.BatchNorm2d(in_channel, affine=False)
        self.linear1 = nn.Linear(condition_dim, in_channel)  # -> gamma
        self.linear2 = nn.Linear(condition_dim, in_channel)  # -> beta

    def forward(self, input, condition):
        normalized = self.bn(input)
        # (B, C) -> (B, C, 1, 1) for broadcasting over H, W.
        gamma = self.linear1(condition)[:, :, None, None]
        beta = self.linear2(condition)[:, :, None, None]
        return gamma * normalized + beta
74
+
75
class ConvBlock(nn.Module):
    """Pre-activation residual block with spectrally-normalized convolutions,
    optional conditional batch norm, and optional 2x up/down-sampling.

    A 1x1 spectrally-normalized skip projection is used whenever channels
    change or resampling occurs; otherwise the identity skip is used.
    ``condition1`` in ``forward`` is accepted but unused.
    """

    def __init__(self, in_channel, out_channel, kernel_size=[3, 3],
                 padding=1, stride=1, condition_dim=None, bn=True,
                 activation=F.relu, upsample=True, downsample=False):
        super().__init__()

        gain = 2 ** 0.5

        def _sn_conv(cin, cout):
            # Spectrally-normalized conv; bias is redundant when followed by
            # (conditional) batch norm.
            return spectral_init(
                nn.Conv2d(cin, cout, kernel_size, stride, padding,
                          bias=not bn),
                gain=gain,
            )

        self.conv1 = _sn_conv(in_channel, out_channel)
        self.conv2 = _sn_conv(out_channel, out_channel)

        self.skip_proj = bool(in_channel != out_channel or upsample or downsample)
        if self.skip_proj:
            self.conv_skip = spectral_init(nn.Conv2d(in_channel, out_channel,
                                                     1, 1, 0))

        self.upsample = upsample
        self.downsample = downsample
        self.activation = activation
        self.bn = bn

        if bn:
            self.norm1 = ConditionalNorm(in_channel, condition_dim)
            self.norm2 = ConditionalNorm(out_channel, condition_dim)

    def forward(self, input, condition=None, condition1=None):
        out = input

        # First stage: (cond-norm ->) act -> optional upsample -> conv.
        if self.bn:
            out = self.norm1(out, condition)
        out = self.activation(out)
        if self.upsample:
            out = F.interpolate(out, scale_factor=2, mode='nearest')
        out = self.conv1(out)

        # Second stage: (cond-norm ->) act -> conv.
        if self.bn:
            out = self.norm2(out, condition)
        out = self.activation(out)
        out = self.conv2(out)

        if self.downsample:
            out = F.avg_pool2d(out, 2)

        if not self.skip_proj:
            return out + input

        # Projected skip path, resampled to match the main branch.
        skip = input
        if self.upsample:
            skip = F.interpolate(skip, scale_factor=2, mode='nearest')
        skip = self.conv_skip(skip)
        if self.downsample:
            skip = F.avg_pool2d(skip, 2)
        return out + skip
135
+
136
class SelfAttention(nn.Module):
    """SAGAN-style self-attention over the flattened freq*time positions of a
    [bsz, channel, freq, time] map, with a zero-initialised gated residual."""

    def __init__(self, in_channel, embed_dim, gain=2 ** 0.5):
        super().__init__()

        self.query = spectral_init(nn.Conv1d(in_channel, embed_dim, 1),
                                   gain=gain)
        self.key = spectral_init(nn.Conv1d(in_channel, embed_dim, 1),
                                 gain=gain)
        self.value = spectral_init(nn.Conv1d(in_channel, in_channel, 1),
                                   gain=gain)

        # Residual gate starts at 0 so the block is initially the identity.
        self.gamma = nn.Parameter(torch.tensor(0.0))

    def forward(self, input):  # [bsz, channel, freq, time]
        original_shape = input.shape
        seq = input.view(original_shape[0], original_shape[1], -1)  # [bsz, channel, freq*time]
        q = self.query(seq).permute(0, 2, 1)
        k = self.key(seq)
        v = self.value(seq)
        scores = torch.bmm(q, k)  # [bsz, freq*time, freq*time]
        # NOTE(review): softmax is over dim 1 (not -1) and the map is applied
        # as value @ map — kept exactly as in the original implementation.
        attention_map = F.softmax(scores, 1)
        attended = torch.bmm(v, attention_map).view(*original_shape)
        out = self.gamma * attended + input

        return (out, attention_map)
162
+
163
class CrossAttention(nn.Module):
    """Cross-attention: positions of `input` (spectrogram pixels or a single
    sentence vector) attend over `condition` (sentence or word embeddings),
    with a zero-initialised gated residual connection."""

    def __init__(self, in_channel, cond_channel, embed_dim, gain=2 ** 0.5):
        super().__init__()

        self.key = spectral_init(nn.Conv1d(cond_channel, embed_dim, 1),
                                 gain=gain)
        self.value = spectral_init(nn.Conv1d(cond_channel, in_channel, 1),
                                   gain=gain)
        self.query = spectral_init(nn.Conv1d(in_channel, embed_dim, 1),
                                   gain=gain)

        # Residual gate starts at 0 so the block is initially the identity.
        self.gamma = nn.Parameter(torch.tensor(0.0))

    def forward(self, input, condition, sequence_lengths=None):
        # input : mel [bsz, channel, freq, time] or sentence [bsz, channel]
        # condition : sentence [bsz, channel] or word [bsz, word_num, channel]
        input_shape = input.shape
        if len(input.shape) == 4: # mel [bsz, channel, freq, time]
            batch_size, c, w, h = input.shape
            num = w * h
            x = input.reshape([batch_size, c, num]) #[bsz, channel, input_num]
        elif len(input.shape) == 2: # sentence [bsz, channel]
            batch_size, c = input.shape
            num = 1
            x = input.unsqueeze(2) # [bsz, channel, input_num]

        # Normalise the condition to [bsz, channel, cond_num].
        if len(condition.shape) == 2: # sentence [bsz, channel]
            condition = condition.unsqueeze(2) # [bsz, channel, cond_num]
        else: # word [bsz, word_num, channel]
            condition = condition.permute(0, 2, 1) # [bsz, channel, cond_num]

        query = self.query(x).permute(0, 2, 1) # [bsz, input_num, channel]
        key = self.key(condition) # [bsz, channel, cond_num]
        value = self.value(condition).permute(0, 2, 1) # [bsz, cond_num, channel]
        attention_map = torch.bmm(query, key) # [bsz, input_num, cond_num]

        if sequence_lengths is not None: # condition is word embedding
            total_len = condition.shape[2]

            # Padding mask: positions >= a sample's sequence length get a
            # large negative logit before softmax. The per-sample loop writes
            # the boolean comparison back into the int index tensor (0/1).
            mask = torch.tile(torch.arange(total_len), [batch_size, num, 1]).to(condition.device)
            for i in range(batch_size):
                sequence_lengths_i = sequence_lengths[i]
                mask[i,:,:] = mask[i,:,:] >= sequence_lengths_i.item()
            attention_map = attention_map + mask * (-1e9)

        attention_map = F.softmax(attention_map, dim=-1) # [bsz, input_num, cond_num]
        out = torch.bmm(attention_map, value).permute(0, 2, 1) # [bsz, channel, input_num]
        # NOTE(review): the permute below undoes the one above, then reshape
        # back to input_shape; .squeeze() drops every size-1 dim — with
        # batch_size == 1 this also removes the batch axis. Confirm callers
        # never pass bsz == 1 through this path.
        out = out.permute(0, 2, 1).reshape(input_shape).squeeze()
        out = self.gamma * out + input

        return out, attention_map
214
+
215
class Spec_Attention(nn.Module):
    """Axis-factorised attention for spectrograms: attends separately along
    the frequency and time axes of a [bsz, channel, freq, time] map,
    optionally conditioned on text embeddings. When `condition` is None the
    two axis summaries condition each other (self mode)."""

    def __init__(self, in_channel, cond_channel=None, embed_dim=64, gain=2 ** 0.5):
        super().__init__()
        # Unconditional case: keys/values are computed from the input itself.
        if cond_channel is None:
            cond_channel = in_channel

        self.f_query = spectral_init(nn.Conv1d(in_channel, embed_dim, 1),
                                     gain=gain)
        self.t_key = spectral_init(nn.Conv1d(cond_channel, embed_dim, 1),
                                   gain=gain)

        self.t_query = spectral_init(nn.Conv1d(in_channel, embed_dim, 1),
                                     gain=gain)
        self.f_key = spectral_init(nn.Conv1d(cond_channel, embed_dim, 1),
                                   gain=gain)

        self.value = spectral_init(nn.Conv1d(cond_channel, in_channel, 1),
                                   gain=gain)

        # Residual gate starts at 0 so the block is initially the identity.
        self.gamma = nn.Parameter(torch.tensor(0.0))

    def forward(self, input, condition=None, sequence_lengths=None):
        # input : mel [bsz, channel, freq, time]
        # condition : sentence [bsz, channel] or word [bsz, word_num, channel]
        # NOTE(review): if sequence_lengths is given while condition is None,
        # `condition.shape` below raises — callers presumably pass both or
        # neither; confirm.

        batch_size, c, f, t = input.shape

        # Axis summaries: mean over time -> per-frequency embedding; mean over
        # frequency -> per-time embedding.
        freq_embedding = input.mean(dim=3) # [bsz, channel, freq]
        time_embedding = input.mean(dim=2) # [bsz, channel, time]

        if condition is not None:
            if len(condition.shape) == 2: # sentence [bsz, channel]
                condition = condition.unsqueeze(2) # [bsz, channel, 1]
            else: # word [bsz, word_num, channel]
                condition = condition.permute(0, 2, 1) # [bsz, channel, cond_num]
            t_condition = condition
            f_condition = condition
        else:
            # Self mode: frequency queries attend over the time summary and
            # time queries attend over the frequency summary.
            t_condition = time_embedding
            f_condition = freq_embedding

        f_query = self.f_query(freq_embedding).permute(0, 2, 1) # [bsz, freq, channel]
        t_key = self.t_key(t_condition) # [bsz, channel, time] or [bsz, channel, cond_num]
        freq_cond_map = torch.bmm(f_query, t_key) # [bsz, freq, time] or [bsz, freq, cond_num]

        t_query = self.t_query(time_embedding).permute(0, 2, 1) # [bsz, time, channel]
        f_key = self.f_key(f_condition) # [bsz, channel, freq] or [bsz, channel, cond_num]
        time_cond_map = torch.bmm(t_query, f_key) # [bsz, time, freq] or [bsz, time, cond_num]

        if sequence_lengths is not None: # condition is word embedding
            total_len = condition.shape[2]

            # Boolean padding mask [bsz, 1, cond_num]; broadcasting pushes
            # padded word positions to a large negative logit before softmax.
            mask = torch.arange(total_len, device=condition.device)[None, None, :]
            mask = mask >= sequence_lengths[:, None, None]

            freq_cond_map = freq_cond_map + mask * (-1e9)
            time_cond_map = time_cond_map + mask * (-1e9)

        freq_cond_map = F.softmax(freq_cond_map, dim=-1) # [bsz, freq, time] or [bsz, freq, cond_num]
        time_cond_map = F.softmax(time_cond_map, dim=-1) # [bsz, time, freq] or [bsz, time, cond_num]

        if condition is None:
            # Self mode: combine the two axis maps into one per-pixel weight
            # and reweight the input's own values.
            freq_time_embedding = input.reshape([batch_size, c, f*t]) # [bsz, channel, freq*time]
            weight_map = torch.add(freq_cond_map, time_cond_map.permute(0, 2, 1)).reshape([batch_size, f*t]).unsqueeze(-1) # [bsz, freq*time, 1]
            value = self.value(freq_time_embedding).permute(0, 2, 1) # [bsz, freq*time, channel]
            out = torch.mul(value, weight_map).permute(0, 2, 1).reshape(batch_size, c, f, t) # [bsz, channel, freq, time]
        else:
            # Conditional mode: broadcast each axis map over the other axis,
            # then attend over the cond_num text positions.
            freq_cond_map = torch.tile(freq_cond_map.unsqueeze(2), [1, 1, t, 1]) # [bsz, freq, time, cond_num]
            time_cond_map = torch.tile(time_cond_map.unsqueeze(1), [1, f, 1, 1]) # [bsz, freq, time, cond_num]
            weight_map = torch.add(freq_cond_map, time_cond_map).reshape([batch_size, f*t, -1]) # [bsz, freq*time, cond_num]
            value = self.value(condition).permute(0, 2, 1) # [bsz, cond_num, channel]
            out = torch.bmm(weight_map, value).permute(0, 2, 1).reshape(batch_size, c, f, t) # [bsz, channel, freq, time]

        out = self.gamma * out + input

        return out, weight_map
291
+
292
class Multi_Triple_Attention(nn.Module):
    """Multi-head attention block mixing up to three branches per head:
    self-attention on the feature map, cross-attention to word embeddings,
    and cross-attention to the sentence embedding. Per head, branch outputs
    are concatenated, projected back to `in_channel`, and added residually;
    the head outputs are then fused and added to the input.

    Args:
        in_channel: channels of the input feature map [bsz, C, freq, time].
        sentence_embed_dim / word_embed_dim: channels of the text conditions.
        embed_dim: projection size inside each attention module.
        reverse: stored but not used in this block.
        gain: init gain for the fusing 1x1 convolutions.
        n_heads: number of attention heads.
        attention_list: comma-separated subset of "self,word,sentence".
        spec_attention: use Spec_Attention instead of Self/CrossAttention.
    """

    def __init__(self, in_channel, sentence_embed_dim=768, word_embed_dim=768, embed_dim=64, reverse=False, gain=2 ** 0.5, n_heads=2, attention_list="self,word,sentence", spec_attention=False):
        super().__init__()
        self.reverse = reverse
        self.n_heads = n_heads
        self.attention_list = attention_list.split(",")

        if "self" in self.attention_list:
            if spec_attention:
                self.self_attention_modules = nn.ModuleList([Spec_Attention(in_channel, embed_dim=embed_dim) for _ in range(self.n_heads)])
            else:
                self.self_attention_modules = nn.ModuleList([SelfAttention(in_channel, embed_dim=embed_dim) for _ in range(self.n_heads)])

        if "word" in self.attention_list:
            if spec_attention:
                self.cross_attention_for_word_modules = nn.ModuleList([Spec_Attention(in_channel, cond_channel=word_embed_dim, embed_dim=embed_dim) for _ in range(self.n_heads)])
            else:
                self.cross_attention_for_word_modules = nn.ModuleList([CrossAttention(in_channel, cond_channel=word_embed_dim, embed_dim=embed_dim) for _ in range(self.n_heads)])

        if "sentence" in self.attention_list:
            if spec_attention:
                self.cross_attention_for_sent_modules = nn.ModuleList([Spec_Attention(in_channel, cond_channel=sentence_embed_dim, embed_dim=embed_dim) for _ in range(self.n_heads)])
            else:
                self.cross_attention_for_sent_modules = nn.ModuleList([CrossAttention(in_channel, cond_channel=sentence_embed_dim, embed_dim=embed_dim) for _ in range(self.n_heads)])

        # BUGFIX: these per-head residual gates were previously held in a plain
        # Python list, which nn.Module never registers: they were excluded from
        # .parameters() (so never trained), missing from state_dict(), and left
        # on CPU by .to(device) — crashing with a device mismatch when the rest
        # of the model runs on GPU. nn.ParameterList registers them properly
        # while keeping the same self.gamma[head] indexing.
        self.gamma = nn.ParameterList([nn.Parameter(torch.tensor(0.0)) for _ in range(self.n_heads)])

        # Fuses the concatenated branch outputs back to in_channel (per head).
        self.conv_for_attention = spectral_init(nn.Conv1d(in_channel * len(self.attention_list), in_channel, 1), gain=gain)

        # Fuses the concatenated head outputs back to in_channel.
        self.out = spectral_init(nn.Conv1d(in_channel * self.n_heads, in_channel, 1), gain=gain)

    def forward(self, input, sentence_embedding, word_embedding, sequence_lengths):
        """Apply every head to `input` [bsz, C, freq, time] and return a
        residually-updated map of the same shape."""
        batch_size, c, f, t = input.shape
        x = input

        result = []
        for head in range(self.n_heads):
            out_list = []
            if "self" in self.attention_list:
                x_self, attention_map = self.self_attention_modules[head](x)
                out_list.append(x_self)
            if "word" in self.attention_list:
                x_word, attention_map = self.cross_attention_for_word_modules[head](x, word_embedding, sequence_lengths)
                out_list.append(x_word)
            if "sentence" in self.attention_list:
                x_sent, attention_map = self.cross_attention_for_sent_modules[head](x, sentence_embedding)
                out_list.append(x_sent)
            out = torch.cat(out_list, dim=1)
            # Project concat(branches) back to c channels, then gated residual.
            out = self.conv_for_attention(out.reshape([batch_size, c*len(out_list), f*t])).reshape([batch_size, c, f, t])
            out = self.gamma[head] * out + x
            result.append(out)
        x = torch.cat(result, dim=1)
        x = self.out(x.reshape([batch_size, c * self.n_heads, f*t])).reshape([batch_size, c, f, t])

        x = input + x

        return x
349
+
350
class Generator(nn.Module):
    """Text-conditioned spectrogram generator: maps a noise vector plus
    sentence/word embeddings to a [bsz, 1, 64, 1024] map through conditional
    ConvBlocks interleaved with Multi_Triple_Attention stages."""

    def __init__(self, model_config=None):
        super().__init__()

        # Default configuration used when none is supplied.
        # NOTE(review): 'g_chaneel' is a typo for 'g_channel', but it is the
        # key read below and the one existing configs use — renaming it would
        # break callers, so it is kept as-is.
        if model_config is None:
            model_config = {
                "noise_dim":128,
                "g_chaneel":128,
                "n_heads":10,
                "sentence_embed_dim":512,
                "word_embed_dim":768,
                "attention_list":["self,word,sentence", "word,sentence", "sentence"],
                "spec_attention":True,
            }

        self.noise_dim = model_config['noise_dim']
        self.channel = model_config['g_chaneel']
        self.n_heads = model_config['n_heads']

        self.sentence_embed_dim = model_config['sentence_embed_dim']
        self.word_embed_dim = model_config['word_embed_dim']

        # One attention_list entry per Multi_Triple_Attention stage below.
        self.attention_list = model_config['attention_list']
        self.spec_attention = model_config['spec_attention']

        # Channel schedule from the 2x32 seed map up to the final stage.
        channel_list = [self.channel, self.channel, self.channel//2, self.channel//2, self.channel//4, self.channel//4, self.channel//4, self.channel//8, self.channel//8]

        # Projects the noise vector to the initial [channel, 2, 32] map.
        self.lin_code = spectral_init(nn.Linear(self.noise_dim, channel_list[0] * 2 * 32))

        self.conv1 = ConvBlock(channel_list[0], channel_list[1], condition_dim=self.sentence_embed_dim)
        self.conv2 = ConvBlock(channel_list[1], channel_list[2], condition_dim=self.sentence_embed_dim)
        self.multi_triple_attention_1 = Multi_Triple_Attention(channel_list[2],
                                                               sentence_embed_dim=self.sentence_embed_dim,
                                                               word_embed_dim=self.word_embed_dim,
                                                               embed_dim=channel_list[2],
                                                               reverse=False,
                                                               n_heads=self.n_heads,
                                                               attention_list=self.attention_list[0],
                                                               spec_attention=self.spec_attention)

        self.conv3 = ConvBlock(channel_list[2], channel_list[3], condition_dim=self.sentence_embed_dim)
        # upsample=False: this stage refines at constant resolution.
        self.conv4 = ConvBlock(channel_list[3], channel_list[4], condition_dim=self.sentence_embed_dim, upsample=False)
        self.multi_triple_attention_2 = Multi_Triple_Attention(channel_list[4],
                                                               sentence_embed_dim=self.sentence_embed_dim,
                                                               word_embed_dim=self.word_embed_dim,
                                                               embed_dim=channel_list[4],
                                                               reverse=False,
                                                               n_heads=self.n_heads,
                                                               attention_list=self.attention_list[1],
                                                               spec_attention=self.spec_attention)

        self.conv5 = ConvBlock(channel_list[4], channel_list[5], condition_dim=self.sentence_embed_dim)
        self.conv6 = ConvBlock(channel_list[5], channel_list[6], condition_dim=self.sentence_embed_dim, upsample=False)
        self.multi_triple_attention_3 = Multi_Triple_Attention(channel_list[6],
                                                               sentence_embed_dim=self.sentence_embed_dim,
                                                               word_embed_dim=self.word_embed_dim,
                                                               embed_dim=channel_list[6],
                                                               reverse=False,
                                                               n_heads=self.n_heads,
                                                               attention_list=self.attention_list[2],
                                                               spec_attention=self.spec_attention)

        self.conv7 = ConvBlock(channel_list[6], channel_list[7], condition_dim=self.sentence_embed_dim)
        # channel_list[8] == channel_list[7] (both channel//8), so this BN
        # matches conv7's output channels.
        self.bn = nn.BatchNorm2d(channel_list[8])
        # 1x1 conv collapsing channels to a single spectrogram plane.
        self.colorize = spectral_init(nn.Conv1d(channel_list[8], 1, 1))

    def forward(self, z, sentence_embedding, word_embedding, sequence_lengths):
        """Generate a [bsz, 1, 64, 1024] map from noise `z` [bsz, noise_dim]
        and the sentence/word conditions."""
        batch_size = z.shape[0]

        x = self.lin_code(z)
        x = x.view(-1, self.channel, 2, 32) # [bsz, c, 2, 32]

        x = self.conv1(x, sentence_embedding) # [bsz, c, 4, 64]
        x = self.conv2(x, sentence_embedding) # [bsz, c, 8, 128]
        x = self.multi_triple_attention_1(x, sentence_embedding, word_embedding, sequence_lengths) # [bsz, c, 8, 128]

        x = self.conv3(x, sentence_embedding) # [bsz, c, 16, 256]
        x = self.conv4(x, sentence_embedding) # [bsz, c, 16, 256]
        x = self.multi_triple_attention_2(x, sentence_embedding, word_embedding, sequence_lengths) # [bsz, c, 16, 256]

        x = self.conv5(x, sentence_embedding) # [bsz, c, 32, 512]
        x = self.conv6(x, sentence_embedding) # [bsz, c, 32, 512]
        x = self.multi_triple_attention_3(x, sentence_embedding, word_embedding, sequence_lengths) # [bsz, c, 32, 512]

        x = self.conv7(x, sentence_embedding) # [bsz, c, 64, 1024]
        x = self.bn(x) # [bsz, c // 8, 64, 1024]
        x = F.relu(x)
        # Flatten spatial dims for the 1x1 Conv1d, then restore [64, 1024].
        x = self.colorize(x.reshape([batch_size, -1, 64*1024])).reshape([batch_size, 1, 64, 1024]) # [bsz, 1, 64, 1024]

        return x