jingyi49 commited on
Commit
a375d05
·
verified ·
1 Parent(s): 00deded

Upload 4 files

Browse files
.gitattributes CHANGED
@@ -34,3 +34,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
  cp_8_gm_post/vae_00035000 filter=lfs diff=lfs merge=lfs -text
 
 
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
  cp_8_gm_post/vae_00035000 filter=lfs diff=lfs merge=lfs -text
37
+ cp_8_vgg3/vae_00620000 filter=lfs diff=lfs merge=lfs -text
cp_8_vgg3/config_96.json ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "input_training_wav_list": "/cto_studio/lijingyi/recon/filelist.train",
3
+ "input_validation_wav_list": "/cto_studio/lijingyi/recon/filelist.val",
4
+ "test_input_wavs_dir":"/cto_studio/vistring/zhaozhiyuan/datasets/AudioSet/wavs/test",
5
+ "test_wav_output_dir":"cosmos96_output",
6
+
7
+ "batch_size": 64,
8
+ "learning_rate": 0.0004,
9
+ "adam_b1": 0.9,
10
+ "adam_b2": 0.999,
11
+ "lr_decay": 0.999,
12
+ "seed": 1234,
13
+ "training_epochs": 3100,
14
+ "stdout_interval":5,
15
+ "checkpoint_interval": 5000,
16
+ "summary_interval": 100,
17
+ "validation_interval": 5000,
18
+ "checkpoint_path": "cp_8_vgg3",
19
+ "checkpoint_file_load_Encoder": "cp_8_vgg3/encoder_01000000",
20
+ "checkpoint_file_load_Decoder": "cp_8_vgg3/decoder_01000000",
21
+
22
+ "segment_size": 24320,
23
+ "num_mels": 96,
24
+ "n_fft": 1024,
25
+ "hop_length": 256,
26
+
27
+ "sampling_rate": 44100,
28
+ "num_workers": 4
29
+ }
cp_8_vgg3/logs/events.out.tfevents.1757473938.dgx056.scc.idea.898206.0 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0c62faac52d53906caaea59750455f83f8610fadeeccfb57575925b12e76062f
3
+ size 22458961
cp_8_vgg3/train_96_vgg.py ADDED
@@ -0,0 +1,233 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import warnings
2
+ warnings.simplefilter(action='ignore', category=FutureWarning)
3
+ import itertools
4
+ import os
5
+ import time
6
+ import argparse
7
+ import json
8
+ import torch
9
+ import torch.nn.functional as F
10
+ from torch.utils.tensorboard import SummaryWriter
11
+ from torch.utils.data import DistributedSampler, DataLoader
12
+ import torch.multiprocessing as mp
13
+ from torch.distributed import init_process_group
14
+ from torch.nn.parallel import DistributedDataParallel
15
+ from models import amplitude_loss
16
+ from dataset import Dataset, mel_spectrogram, get_dataset_filelist
17
+ import cosmos_tokenizer
18
+ from discrete_img import DiscreteImageTokenizer
19
+ from utils import AttrDict, build_env, plot_spectrogram, scan_checkpoint, load_checkpoint, save_checkpoint
20
+ from vocos import Vocos
21
+ import shutil
22
+ from vgg19 import VGG19
23
+
24
+ torch.backends.cudnn.benchmark = True
25
+
26
+ import torch.multiprocessing as mp
27
+ mp.set_start_method("spawn", force=True)
28
+
29
+ from CosmosTokenizer.cosmos_tokenizer.modules import DecoderType, DiscreteQuantizer, EncoderType
30
+
31
# Constructor kwargs for DiscreteImageTokenizer (Cosmos tokenizer).
# NOTE(review): resolution=96 appears to match "num_mels": 96 in
# config_96.json (mel-spectrograms treated as 1-channel images) — confirm.
params = dict(
    attn_resolutions=[6, 12],             # feature-map resolutions with attention
    channels=128,                         # base channel width
    channels_mult=[2, 4, 4],              # per-downsampling-level channel multipliers
    dropout=0.0,
    in_channels=1,                        # single-channel (spectrogram) input
    spatial_compression=8,                # overall spatial downsampling factor
    num_res_blocks=2,
    out_channels=1,
    resolution=96,
    patch_size=2,
    patch_method="haar",
    z_channels=256,
    z_factor=2,
    quantizer=DiscreteQuantizer.VQ.name,  # plain vector quantization
    embedding_dim=64,                     # codebook vector dimension
    num_embeddings=8192,                  # codebook size
    num_quantizers=1,
    name="DI",
    encoder=EncoderType.Default.name,
    decoder=DecoderType.Default.name,
)
53
+
54
def train(h):
    """Single-GPU training loop for the discrete mel-spectrogram tokenizer.

    Args:
        h: AttrDict of hyperparameters loaded from the JSON config
           (paths, optimizer settings, intervals, STFT/mel parameters).

    Side effects:
        Writes ``vae_<steps>`` checkpoints into ``h.checkpoint_path`` and
        TensorBoard summaries into ``<checkpoint_path>/logs``.
    """
    torch.cuda.manual_seed(h.seed)
    device = torch.device('cuda:{:d}'.format(0))

    model = DiscreteImageTokenizer(**params).to(device)
    # VGG19 network used as a perceptual feature loss.
    feature_extractor = VGG19().to(device)

    print("Model: ")
    print(model)
    os.makedirs(h.checkpoint_path, exist_ok=True)
    print("checkpoints directory : ", h.checkpoint_path)

    # Resume from the newest 'vae_*' checkpoint, if any.
    cp_model = scan_checkpoint(h.checkpoint_path, 'vae_')

    # NOTE: as in the original code, steps/epoch are intentionally reset to
    # zero on resume; only the model weights are restored.
    steps = 0
    last_epoch = -1
    if cp_model is not None:
        state_dict_vae = load_checkpoint(cp_model, device)
        # BUG FIX: the checkpoint was previously loaded into the model twice
        # (before and after optimizer creation); load exactly once.
        model.load_state_dict(state_dict_vae['encoder'])

    # BUG FIX: dropped the pointless itertools.chain() around a single
    # parameter iterator.
    optim_g = torch.optim.AdamW(model.parameters(), h.learning_rate,
                                betas=[h.adam_b1, h.adam_b2])

    scheduler_g = torch.optim.lr_scheduler.ExponentialLR(
        optim_g, gamma=h.lr_decay, last_epoch=last_epoch)

    training_filelist, validation_filelist = get_dataset_filelist(
        h.input_training_wav_list, h.input_validation_wav_list)

    trainset = Dataset(training_filelist, h.segment_size, h.n_fft, h.num_mels,
                       h.hop_length, h.sampling_rate, shuffle=True, device=device, train=True)

    train_loader = DataLoader(trainset, num_workers=h.num_workers, shuffle=False,
                              sampler=None,
                              batch_size=h.batch_size,
                              pin_memory=True,
                              drop_last=True)

    validset = Dataset(validation_filelist, h.segment_size, h.n_fft, h.num_mels,
                       h.hop_length, h.sampling_rate, shuffle=False, device=device, train=False)

    validation_loader = DataLoader(validset, num_workers=1, shuffle=False,
                                   sampler=None,
                                   batch_size=1,
                                   pin_memory=True,
                                   drop_last=True)

    sw = SummaryWriter(os.path.join(h.checkpoint_path, 'logs'))

    model.train()

    for epoch in range(max(0, last_epoch), h.training_epochs):

        start = time.time()
        print("Epoch: {}".format(epoch+1))

        for i, batch in enumerate(train_loader):
            start_b = time.time()
            # torch.autograd.Variable is a deprecated no-op wrapper; a plain
            # device transfer is equivalent.
            y_mel = batch.to(device, non_blocking=True)
            out_train = model(y_mel)
            y_g_mel = out_train["reconstructions"]

            # Generator step.
            optim_g.zero_grad()
            # L1 loss on log-mel spectra.
            L_M = F.l1_loss(y_mel, y_g_mel) * 5.0
            quant_loss = out_train["quant_loss"].mean()
            # Perceptual loss: rescale both spectrograms from [-1, 1] to
            # [0, 1], tile to 3 channels, and feed as one stacked VGG batch.
            feat_in = torch.cat((((y_g_mel + 1) / 2).repeat(1, 3, 1, 1),
                                 ((y_mel + 1) / 2).repeat(1, 3, 1, 1)), 0)
            feature_loss = feature_extractor(feat_in)
            print(f"feature loss:{feature_loss}")
            # BUG FIX: removed a grad-tracked amplitude_loss computed every
            # step but never used in the objective (dead work); the L2 metric
            # is still computed under no_grad for logging below.
            L_G = L_M + quant_loss*0.25 + feature_loss*0.5
            L_G.backward()
            optim_g.step()

            # STDOUT logging.
            if steps % h.stdout_interval == 0:
                with torch.no_grad():
                    Mel_error = (F.l1_loss(y_mel, y_g_mel)*5.0).item()
                    Mel_L2_error = (amplitude_loss(y_mel, y_g_mel)*25.0).item()

                print('Steps : {:d}, Gen Loss Total : {:4.3f}, Mel Spectrogram Loss : {:4.3f}, Mel Spectrogram L2 Loss : {:4.3f}, Quant Loss : {:4.3f}, s/b : {:4.3f}'.
                      format(steps, L_G, Mel_error, Mel_L2_error, quant_loss.item(), time.time() - start_b))

            # Checkpointing.
            if steps % h.checkpoint_interval == 0 and steps != 0:
                checkpoint_path = "{}/vae_{:08d}".format(h.checkpoint_path, steps)
                save_checkpoint(checkpoint_path,
                                {'encoder': model.state_dict(),
                                 'steps': steps,
                                 'epoch': epoch})

            # Tensorboard summary logging.
            if steps % h.summary_interval == 0:
                sw.add_scalar("Training/Generator_Total_Loss", L_G, steps)
                # BUG FIX: recompute the mel loss here instead of reusing
                # Mel_error from the stdout branch, which is only defined when
                # summary_interval happens to be a multiple of stdout_interval.
                with torch.no_grad():
                    sw.add_scalar("Training/Mel_Spectrogram_Loss",
                                  (F.l1_loss(y_mel, y_g_mel)*5.0).item(), steps)

            # Validation.
            if steps % h.validation_interval == 0:  # and steps != 0:
                model.eval()
                torch.cuda.empty_cache()
                val_Mel_err_tot = 0
                val_Mel_L2_err_tot = 0
                with torch.no_grad():
                    for j, batch in enumerate(validation_loader):
                        y_mel = batch.to(device, non_blocking=True)
                        out_eval = model(y_mel)
                        val_Mel_err_tot += (F.l1_loss(y_mel, out_eval.reconstructions)*5.0).item()
                        val_Mel_L2_err_tot += (amplitude_loss(y_mel, out_eval.reconstructions)*25.0).item()

                        if j <= 4:
                            if steps == 0:
                                # Ground-truth spectrograms only need plotting once.
                                y_plot = (y_mel[0, 0] * 5.0).cpu().float().numpy()
                                sw.add_figure('gt/y_mel_{}'.format(j), plot_spectrogram(y_plot), steps)

                            # BUG FIX: plot this validation sample's
                            # reconstruction; previously y_g_mel from the last
                            # *training* batch was plotted for every j.
                            y_plot_g = (out_eval.reconstructions[0, 0] * 5.0).cpu().float().numpy()
                            sw.add_figure('generated/y_g_mel_{}'.format(j), plot_spectrogram(y_plot_g), steps)

                    # NOTE(review): assumes the validation loader is non-empty,
                    # otherwise j is unbound — confirm filelist is never empty.
                    val_Mel_err = val_Mel_err_tot / (j+1)
                    val_Mel_L2_err = val_Mel_L2_err_tot / (j+1)
                    sw.add_scalar("Validation/Mel_Spectrogram_loss", val_Mel_err, steps)
                    sw.add_scalar("Validation/Mel_Spectrogram_L2_loss", val_Mel_L2_err, steps)

                model.train()

            steps += 1

        scheduler_g.step()

        print('Time taken for epoch {} is {} sec\n'.format(epoch + 1, int(time.time() - start)))
198
+
199
+
200
def main():
    """Load the JSON config, archive it and this script into the checkpoint
    directory for reproducibility, seed the RNGs, and start training."""
    print('Initializing Training Process..')

    config_file = 'config_96.json'

    # json.load replaces the previous read() + json.loads round trip.
    with open(config_file) as f:
        json_config = json.load(f)
    h = AttrDict(json_config)
    build_env(config_file, 'config_96.json', h.checkpoint_path)

    # Snapshot this training script next to the checkpoints so a run can
    # always be reproduced from its own directory.
    src = "train_96_vgg.py"
    dst_dir = h.checkpoint_path
    os.makedirs(dst_dir, exist_ok=True)
    dst = os.path.join(dst_dir, "train_96_vgg.py")
    if not os.path.exists(src):
        raise FileNotFoundError(f"{src} 不存在!")
    shutil.copyfile(src, dst)
    print(f"已将 {src} 复制到 {dst}")

    torch.manual_seed(h.seed)
    # Removed a dead `else: pass` branch; seeding CUDA is simply conditional.
    if torch.cuda.is_available():
        torch.cuda.manual_seed(h.seed)

    train(h)
229
+
230
+
231
# Script entry point: run training only when executed directly,
# not when imported as a module.
if __name__ == '__main__':
    main()
cp_8_vgg3/vae_00620000 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5894a97f8027c1fe8bee2c8d72e9eff06f0d0b3c86934ab1b94fce612b0d7c40
3
+ size 357630932