jingyi49 committed on
Commit
3273452
·
verified ·
1 Parent(s): a375d05

Upload 4 files

Browse files
.gitattributes CHANGED
@@ -35,3 +35,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
  cp_8_gm_post/vae_00035000 filter=lfs diff=lfs merge=lfs -text
37
  cp_8_vgg3/vae_00620000 filter=lfs diff=lfs merge=lfs -text
 
 
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
  cp_8_gm_post/vae_00035000 filter=lfs diff=lfs merge=lfs -text
37
  cp_8_vgg3/vae_00620000 filter=lfs diff=lfs merge=lfs -text
38
+ cp_8_vq3/vae_00430000 filter=lfs diff=lfs merge=lfs -text
cp_8_vq3/config_96.json ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "input_training_wav_list": "/cto_studio/lijingyi/recon/filelist.train",
3
+ "input_validation_wav_list": "/cto_studio/lijingyi/recon/filelist.val",
4
+ "test_input_wavs_dir":"/cto_studio/vistring/zhaozhiyuan/datasets/AudioSet/wavs/test",
5
+ "test_wav_output_dir":"cosmos96_output",
6
+
7
+ "batch_size": 64,
8
+ "learning_rate": 0.0004,
9
+ "adam_b1": 0.9,
10
+ "adam_b2": 0.999,
11
+ "lr_decay": 0.999,
12
+ "seed": 1234,
13
+ "training_epochs": 3100,
14
+ "stdout_interval":5,
15
+ "checkpoint_interval": 5000,
16
+ "summary_interval": 100,
17
+ "validation_interval": 5000,
18
+ "checkpoint_path": "cp_cosmos96_vq3",
19
+ "checkpoint_file_load_Encoder": "cp_cosmos96_vq3/encoder_01000000",
20
+ "checkpoint_file_load_Decoder": "cp_cosmos96_vq3/decoder_01000000",
21
+
22
+ "segment_size": 24320,
23
+ "num_mels": 96,
24
+ "n_fft": 1024,
25
+ "hop_length": 256,
26
+
27
+ "sampling_rate": 44100,
28
+ "num_workers": 4
29
+ }
cp_8_vq3/logs/events.out.tfevents.1757242860.dgx056.scc.idea.1476461.0 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4b65909a44d61d7f3fb99fbf30b1955be1f94f5ccbc32de1077b9c91276f4ab2
3
+ size 11966985
cp_8_vq3/train_96.py ADDED
@@ -0,0 +1,227 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import warnings
2
+ warnings.simplefilter(action='ignore', category=FutureWarning)
3
+ import itertools
4
+ import os
5
+ import time
6
+ import argparse
7
+ import json
8
+ import torch
9
+ import torch.nn.functional as F
10
+ from torch.utils.tensorboard import SummaryWriter
11
+ from torch.utils.data import DistributedSampler, DataLoader
12
+ import torch.multiprocessing as mp
13
+ from torch.distributed import init_process_group
14
+ from torch.nn.parallel import DistributedDataParallel
15
+ from models import amplitude_loss
16
+ from dataset import Dataset, mel_spectrogram, get_dataset_filelist
17
+ import cosmos_tokenizer
18
+ from discrete_img import DiscreteImageTokenizer
19
+ from utils import AttrDict, build_env, plot_spectrogram, scan_checkpoint, load_checkpoint, save_checkpoint
20
+ from vocos import Vocos
21
+ import shutil
22
+ torch.backends.cudnn.benchmark = True
23
+
24
+ import torch.multiprocessing as mp
25
+ mp.set_start_method("spawn", force=True)
26
+
27
+ from CosmosTokenizer.cosmos_tokenizer.modules import DecoderType, DiscreteQuantizer, EncoderType
28
+
29
# Keyword arguments unpacked into DiscreteImageTokenizer(**params) by train().
# The quantizer / encoder / decoder entries are passed as enum member *names*
# (strings), not as the enum members themselves.
params = {
    "attn_resolutions": [6, 12],
    "channels": 128,
    "channels_mult": [2, 4, 4],
    "dropout": 0.0,
    "in_channels": 1,
    "spatial_compression": 8,
    "num_res_blocks": 2,
    "out_channels": 1,
    "resolution": 96,
    "patch_size": 2,
    "patch_method": "haar",
    "z_channels": 256,
    "z_factor": 2,
    "quantizer": DiscreteQuantizer.VQ.name,
    "embedding_dim": 64,
    "num_embeddings": 8192,
    "num_quantizers": 1,
    "name": "DI",
    "encoder": EncoderType.Default.name,
    "decoder": DecoderType.Default.name,
}
51
+
52
def train(h):
    """Train the discrete tokenizer on spectrogram batches using one GPU.

    The model is built from the module-level ``params`` and optimised with
    AdamW on ``5 * L1(input, reconstruction) + 0.25 * quant_loss``. Progress
    is printed to stdout, scalars and figures go to TensorBoard under
    ``<checkpoint_path>/logs``, and weights are saved periodically as
    ``<checkpoint_path>/vae_<steps>``.

    Args:
        h: Attribute-style hyper-parameter object. Fields read here: seed,
            checkpoint_path, learning_rate, adam_b1, adam_b2, lr_decay,
            input_training_wav_list, input_validation_wav_list, segment_size,
            n_fft, num_mels, hop_length, sampling_rate, num_workers,
            batch_size, training_epochs, stdout_interval, checkpoint_interval,
            summary_interval and validation_interval.
    """
    torch.cuda.manual_seed(h.seed)
    device = torch.device('cuda:0')
    model = DiscreteImageTokenizer(**params).to(device)

    print("Model: ")
    print(model)
    os.makedirs(h.checkpoint_path, exist_ok=True)
    print("checkpoints directory : ", h.checkpoint_path)

    # The directory was created just above, so it can be scanned
    # unconditionally; this keeps cp_model defined on every path.
    cp_model = scan_checkpoint(h.checkpoint_path, 'vae_')

    # NOTE(review): a checkpoint is used as a warm start only. The step
    # counter and the LR schedule restart from scratch, as in the original
    # code; the optimizer state is not part of the saved checkpoint.
    steps = 0
    last_epoch = -1
    if cp_model is not None:
        state_dict_vae = load_checkpoint(cp_model, device)
        # The whole model's state dict is stored under the 'encoder' key.
        model.load_state_dict(state_dict_vae['encoder'])

    optim_g = torch.optim.AdamW(model.parameters(), h.learning_rate,
                                betas=(h.adam_b1, h.adam_b2))
    scheduler_g = torch.optim.lr_scheduler.ExponentialLR(
        optim_g, gamma=h.lr_decay, last_epoch=last_epoch)

    training_filelist, validation_filelist = get_dataset_filelist(
        h.input_training_wav_list, h.input_validation_wav_list)

    trainset = Dataset(training_filelist, h.segment_size, h.n_fft, h.num_mels,
                       h.hop_length, h.sampling_rate, shuffle=True,
                       device=device, train=True)
    train_loader = DataLoader(trainset, num_workers=h.num_workers,
                              shuffle=False,
                              sampler=None,
                              batch_size=h.batch_size,
                              pin_memory=True,
                              drop_last=True)

    validset = Dataset(validation_filelist, h.segment_size, h.n_fft,
                       h.num_mels, h.hop_length, h.sampling_rate,
                       shuffle=False, device=device, train=False)
    validation_loader = DataLoader(validset, num_workers=1, shuffle=False,
                                   sampler=None,
                                   batch_size=1,
                                   pin_memory=True,
                                   drop_last=True)

    sw = SummaryWriter(os.path.join(h.checkpoint_path, 'logs'))

    model.train()

    for epoch in range(max(0, last_epoch), h.training_epochs):
        start = time.time()
        print("Epoch: {}".format(epoch + 1))

        for batch in train_loader:
            start_b = time.time()
            # presumably a batch of log-mel spectrograms -- confirm in Dataset
            y_mel = batch.to(device, non_blocking=True)
            out_train = model(y_mel)
            y_g_mel = out_train["reconstructions"]

            # Generator update: reconstruction L1 plus weighted quantizer loss.
            optim_g.zero_grad()
            L_M = F.l1_loss(y_mel, y_g_mel) * 5.0
            quant_loss = out_train["quant_loss"].mean()
            L_G = L_M + quant_loss * 0.25
            L_G.backward()
            optim_g.step()

            # STDOUT logging
            if steps % h.stdout_interval == 0:
                # The amplitude loss is reported only, never optimised, so it
                # is computed here without building an autograd graph.
                with torch.no_grad():
                    Mel_L2_error = (amplitude_loss(y_mel, y_g_mel) * 25.0).item()

                print('Steps : {:d}, Gen Loss Total : {:4.3f}, Mel Spectrogram Loss : {:4.3f}, Mel Spectrogram L2 Loss : {:4.3f}, Quant Loss : {:4.3f}, s/b : {:4.3f}'.
                      format(steps, L_G.item(), L_M.item(), Mel_L2_error,
                             quant_loss.item(), time.time() - start_b))

            # checkpointing
            if steps % h.checkpoint_interval == 0 and steps != 0:
                checkpoint_path = "{}/vae_{:08d}".format(h.checkpoint_path, steps)
                save_checkpoint(checkpoint_path,
                                {'encoder': model.state_dict(),
                                 'steps': steps,
                                 'epoch': epoch})

            # Tensorboard summary logging. Values come straight from this
            # step's losses, so this does not depend on the stdout branch
            # having run.
            if steps % h.summary_interval == 0:
                sw.add_scalar("Training/Generator_Total_Loss", L_G.item(), steps)
                sw.add_scalar("Training/Mel_Spectrogram_Loss", L_M.item(), steps)

            # Validation (also runs at step 0 to log the ground-truth figures)
            if steps % h.validation_interval == 0:
                _run_validation(model, validation_loader, sw, steps, device)

            steps += 1

        scheduler_g.step()

        print('Time taken for epoch {} is {} sec\n'.format(epoch + 1, int(time.time() - start)))

    # Flush pending TensorBoard events once training is finished.
    sw.close()


def _run_validation(model, validation_loader, sw, steps, device):
    """Evaluate on the validation loader and log the results to TensorBoard.

    Logs the mean of ``5 * L1`` and of ``25 * amplitude_loss`` over all
    validation batches, plus spectrogram figures for the first five batches
    (ground truth only at step 0, reconstruction on every call). The model is
    switched to eval mode for the pass and back to train mode before returning.
    """
    model.eval()
    torch.cuda.empty_cache()
    mel_err_total = 0.0
    mel_l2_err_total = 0.0
    num_batches = 0

    with torch.no_grad():
        for j, batch in enumerate(validation_loader):
            y_mel = batch.to(device, non_blocking=True)
            # In eval mode the output is read by attribute, as the original
            # code did (the train-mode output is indexed by key instead).
            y_g_mel = model(y_mel).reconstructions
            mel_err_total += (F.l1_loss(y_mel, y_g_mel) * 5.0).item()
            mel_l2_err_total += (amplitude_loss(y_mel, y_g_mel) * 25.0).item()
            num_batches += 1

            if j <= 4:
                if steps == 0:
                    y_plot = (y_mel[0, 0] * 5.0).cpu().float().numpy()
                    sw.add_figure('gt/y_mel_{}'.format(j), plot_spectrogram(y_plot), steps)

                # Plot the reconstruction of THIS validation sample, not the
                # reconstruction of the most recent training batch.
                y_plot_g = (y_g_mel[0, 0] * 5.0).cpu().float().numpy()
                sw.add_figure('generated/y_g_mel_{}'.format(j), plot_spectrogram(y_plot_g), steps)

    # An empty validation set yields no batches; skip the averages rather
    # than fail on an undefined loop index.
    if num_batches > 0:
        sw.add_scalar("Validation/Mel_Spectrogram_loss", mel_err_total / num_batches, steps)
        sw.add_scalar("Validation/Mel_Spectrogram_L2_loss", mel_l2_err_total / num_batches, steps)

    model.train()
192
+
193
+
194
+ def main():
195
+ print('Initializing Training Process..')
196
+
197
+ config_file = 'config_96.json'
198
+
199
+ with open(config_file) as f:
200
+ data = f.read()
201
+
202
+ json_config = json.loads(data)
203
+ h = AttrDict(json_config)
204
+ build_env(config_file, 'config_96.json', h.checkpoint_path)
205
+
206
+ src = "train_96.py"
207
+ dst_dir = h.checkpoint_path
208
+ os.makedirs(dst_dir, exist_ok=True)
209
+ dst = os.path.join(dst_dir, "train_96.py")
210
+ if not os.path.exists(src):
211
+ raise FileNotFoundError(f"{src} 不存在!")
212
+ shutil.copyfile(src, dst)
213
+ print(f"已将 {src} 复制到 {dst}")
214
+
215
+
216
+ torch.manual_seed(h.seed)
217
+ if torch.cuda.is_available():
218
+ torch.cuda.manual_seed(h.seed)
219
+ else:
220
+ pass
221
+
222
+ train(h)
223
+
224
+
225
+ if __name__ == '__main__':
226
+ main()
227
+
cp_8_vq3/vae_00430000 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:50e8b4e9da2208f25197a4a303afa4c340d45fe147f3c466b6490175d67c758f
3
+ size 357630932