jingyi49 commited on
Commit
00deded
·
verified ·
1 Parent(s): bffeaa4

Upload 6 files

Browse files
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ cp_8_gm_post/vae_00035000 filter=lfs diff=lfs merge=lfs -text
cp_8_gm_post/config_post.json ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "input_training_wav_list": "/cto_studio/lijingyi/recon/filelist.train",
3
+ "input_validation_wav_list": "/cto_studio/lijingyi/recon/filelist.val",
4
+ "test_input_wavs_dir":"/cto_studio/vistring/zhaozhiyuan/datasets/AudioSet/wavs/test",
5
+ "test_wav_output_dir":"cosmos96_output",
6
+
7
+ "batch_size": 64,
8
+ "learning_rate": 0.0002,
9
+ "adam_b1": 0.9,
10
+ "adam_b2": 0.999,
11
+ "lr_decay": 0.999,
12
+ "seed": 1234,
13
+ "training_epochs": 3100,
14
+ "stdout_interval":5,
15
+ "checkpoint_interval": 5000,
16
+ "summary_interval": 100,
17
+ "validation_interval": 5000,
18
+ "checkpoint_path": "cp_8_gm_post",
19
+ "checkpoint_file_load_Encoder": "cp_8_gm_post/encoder_01000000",
20
+ "checkpoint_file_load_Decoder": "cp_8_gm_post/decoder_01000000",
21
+
22
+ "segment_size": 24320,
23
+ "num_mels": 96,
24
+ "n_fft": 1024,
25
+ "hop_length": 256,
26
+
27
+ "sampling_rate": 44100,
28
+ "num_workers": 4
29
+ }
cp_8_gm_post/logs/events.out.tfevents.1758006916.dgx056.scc.idea.3848759.0 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4a3f40770fcb594317f1037d3a147c2b62ca84da160a40cec5f6a0b47514b3de
3
+ size 348211
cp_8_gm_post/logs/events.out.tfevents.1758016570.dgx056.scc.idea.896021.0 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3a3df1daa3157ca86f0678f14dba7d847ab5713dd13cc73b89f2282708f6d5d7
3
+ size 88
cp_8_gm_post/logs/events.out.tfevents.1758016704.dgx056.scc.idea.917530.0 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:33d1f8ba85ec481e3bc1556803d4e9f307f0026ac514fc6ffb56e737ea627932
3
+ size 1692944
cp_8_gm_post/train_96_post_gm.py ADDED
@@ -0,0 +1,234 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import warnings
2
+ warnings.simplefilter(action='ignore', category=FutureWarning)
3
+ import itertools
4
+ import os
5
+ import time
6
+ import argparse
7
+ import json
8
+ import torch
9
+ import torch.nn.functional as F
10
+ from torch.utils.tensorboard import SummaryWriter
11
+ from torch.utils.data import DistributedSampler, DataLoader
12
+ import torch.multiprocessing as mp
13
+ from torch.distributed import init_process_group
14
+ from torch.nn.parallel import DistributedDataParallel
15
+ from models import amplitude_loss
16
+ from dataset import Dataset, mel_spectrogram, get_dataset_filelist
17
+ import cosmos_tokenizer
18
+ from discrete_img import DiscreteImageTokenizer
19
+ from utils import AttrDict, build_env, plot_spectrogram, scan_checkpoint, load_checkpoint, save_checkpoint
20
+ from vocos import Vocos
21
+ import shutil
22
+ from vgg19 import VGG19
23
+
24
+ torch.backends.cudnn.benchmark = True
25
+
26
+ import torch.multiprocessing as mp
27
+ mp.set_start_method("spawn", force=True)
28
+
29
+ from CosmosTokenizer.cosmos_tokenizer.modules import DecoderType, DiscreteQuantizer, EncoderType
30
+ from gm_loss import GM
31
+
32
# Architecture hyper-parameters for the discrete image tokenizer.
# NOTE(review): values appear tuned for 96-bin mel spectrograms
# (resolution=96, in/out_channels=1) with 8x spatial compression —
# confirm against the DiscreteImageTokenizer constructor.
params = {
    "attn_resolutions": [6, 12],
    "channels": 128,
    "channels_mult": [2, 4, 4],
    "dropout": 0.0,
    "in_channels": 1,
    "spatial_compression": 8,
    "num_res_blocks": 2,
    "out_channels": 1,
    "resolution": 96,
    "patch_size": 2,
    "patch_method": "haar",
    "z_channels": 256,
    "z_factor": 2,
    "quantizer": DiscreteQuantizer.VQ.name,
    "embedding_dim": 64,
    "num_embeddings": 8192,
    "num_quantizers": 1,
    "name": "DI",
    "encoder": EncoderType.Default.name,
    "decoder": DecoderType.Default.name,
}
54
+
55
def train(h):
    """Single-GPU training loop for the discrete mel-spectrogram tokenizer.

    All hyper-parameters come from ``h`` (an AttrDict built from
    config_post.json): file lists, optimizer settings, mel/STFT parameters,
    and logging/checkpoint/validation intervals.  Checkpoints and
    TensorBoard event files are written under ``h.checkpoint_path``.
    """
    torch.cuda.manual_seed(h.seed)
    device = torch.device('cuda:{:d}'.format(0))
    model = DiscreteImageTokenizer(**params).to(device)
    feature_extractor = GM().to(device)

    print("Model: ")
    print(model)
    os.makedirs(h.checkpoint_path, exist_ok=True)
    print("checkpoints directory : ", h.checkpoint_path)

    # Resume from the newest 'vae_*' checkpoint if one exists.
    # cp_model is initialised so it is never unbound (the original relied on
    # the isdir() branch always being taken).
    cp_model = None
    if os.path.isdir(h.checkpoint_path):
        cp_model = scan_checkpoint(h.checkpoint_path, 'vae_')

    steps = 0
    if cp_model is None:
        state_dict_vae = None
        last_epoch = -1
    else:
        state_dict_vae = load_checkpoint(cp_model, device)
        # The checkpoint stores the full tokenizer state under 'encoder'.
        # Loaded exactly once here (the original loaded it a second,
        # redundant time after building the optimizer).
        model.load_state_dict(state_dict_vae['encoder'])
        # NOTE(review): step/epoch counters are intentionally reset on
        # resume (only the weights are restored) — confirm this is desired.
        steps = 0
        last_epoch = -1

    # itertools.chain() around a single iterable was a no-op; pass the
    # parameters directly.
    optim_g = torch.optim.AdamW(model.parameters(), h.learning_rate,
                                betas=[h.adam_b1, h.adam_b2])

    scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optim_g, gamma=h.lr_decay, last_epoch=last_epoch)

    training_filelist, validation_filelist = get_dataset_filelist(h.input_training_wav_list, h.input_validation_wav_list)

    trainset = Dataset(training_filelist, h.segment_size, h.n_fft, h.num_mels,
                       h.hop_length, h.sampling_rate, shuffle=True, device=device, train=True)

    train_loader = DataLoader(trainset, num_workers=h.num_workers, shuffle=False,
                              sampler=None,
                              batch_size=h.batch_size,
                              pin_memory=True,
                              drop_last=True)

    validset = Dataset(validation_filelist, h.segment_size, h.n_fft, h.num_mels,
                       h.hop_length, h.sampling_rate, shuffle=False, device=device, train=False)

    validation_loader = DataLoader(validset, num_workers=1, shuffle=False,
                                   sampler=None,
                                   batch_size=1,
                                   pin_memory=True,
                                   drop_last=True)

    sw = SummaryWriter(os.path.join(h.checkpoint_path, 'logs'))

    model.train()

    for epoch in range(max(0, last_epoch), h.training_epochs):

        start = time.time()
        print("Epoch: {}".format(epoch + 1))

        for i, batch in enumerate(train_loader):
            start_b = time.time()
            # torch.autograd.Variable is a deprecated no-op wrapper; a plain
            # .to() transfer is equivalent.
            y_mel = batch.to(device, non_blocking=True)
            out_train = model(y_mel)
            y_g_mel = out_train["reconstructions"]

            # Generator update.
            optim_g.zero_grad()
            # Losses defined on log mel spectra.
            L_M = F.l1_loss(y_mel, y_g_mel) * 5.0
            Mel_L2_error = amplitude_loss(y_mel, y_g_mel) * 25.0
            quant_loss = out_train["quant_loss"].mean()
            # Perceptual feature loss: rescale from [-1, 1] to [0, 1], tile the
            # single mel channel to 3 channels, and run reconstruction and
            # target through the extractor in one batched forward pass.
            feat_in = torch.cat((((y_g_mel + 1) / 2).repeat(1, 3, 1, 1),
                                 ((y_mel + 1) / 2).repeat(1, 3, 1, 1)), 0)
            feature_loss = feature_extractor(feat_in)
            print(f"feature loss:{feature_loss}")
            # Mel_L2_error is logged only; it does not contribute to L_G.
            L_G = L_M + quant_loss * 0.25 + feature_loss * 1e5
            L_G.backward()
            optim_g.step()

            # STDOUT logging
            if steps % h.stdout_interval == 0:
                with torch.no_grad():
                    Mel_error = (F.l1_loss(y_mel, y_g_mel) * 5.0).item()
                    Mel_L2_error = (amplitude_loss(y_mel, y_g_mel) * 25.0).item()
                    quant_loss = quant_loss.item()

                print('Steps : {:d}, Gen Loss Total : {:4.3f}, Mel Spectrogram Loss : {:4.3f}, Mel Spectrogram L2 Loss : {:4.3f}, Quant Loss : {:4.3f}, s/b : {:4.3f}'.
                      format(steps, L_G, Mel_error, Mel_L2_error, quant_loss, time.time() - start_b))

            # checkpointing
            if steps % h.checkpoint_interval == 0 and steps != 0:
                checkpoint_path = "{}/vae_{:08d}".format(h.checkpoint_path, steps)
                save_checkpoint(checkpoint_path,
                                {'encoder': model.state_dict(),
                                 'steps': steps,
                                 'epoch': epoch})

            # Tensorboard summary logging
            if steps % h.summary_interval == 0:
                # Recompute the mel L1 here instead of reusing the value from
                # the stdout branch, which was only bound when
                # stdout_interval happened to divide summary_interval.
                with torch.no_grad():
                    mel_l1 = (F.l1_loss(y_mel, y_g_mel) * 5.0).item()
                sw.add_scalar("Training/Generator_Total_Loss", L_G, steps)
                sw.add_scalar("Training/Mel_Spectrogram_Loss", mel_l1, steps)

            # Validation
            if steps % h.validation_interval == 0:  # and steps != 0:
                model.eval()
                torch.cuda.empty_cache()
                val_Mel_err_tot = 0
                val_Mel_L2_err_tot = 0
                val_batches = 0
                with torch.no_grad():
                    # Use a distinct name so the training batch's y_mel is not
                    # shadowed while validating.
                    for j, val_batch in enumerate(validation_loader):
                        y_val = val_batch.to(device, non_blocking=True)
                        out_eval = model(y_val)
                        val_Mel_err_tot += (F.l1_loss(y_val, out_eval.reconstructions) * 5.0).item()
                        val_Mel_L2_err_tot += (amplitude_loss(y_val, out_eval.reconstructions) * 25.0).item()
                        val_batches += 1

                        if j <= 4:
                            if steps == 0:
                                # Ground-truth spectrograms only need logging once.
                                y_plot = (y_val[0, 0] * 5.0).cpu().float().numpy()
                                sw.add_figure('gt/y_mel_{}'.format(j), plot_spectrogram(y_plot), steps)

                            # Fix: plot the *validation* reconstruction. The
                            # original plotted y_g_mel from the last training
                            # batch, which does not match the gt/ figures.
                            y_plot_g = (out_eval.reconstructions[0, 0] * 5.0).cpu().float().numpy()
                            sw.add_figure('generated/y_g_mel_{}'.format(j), plot_spectrogram(y_plot_g), steps)

                    # Guard against an empty validation loader (the original
                    # divided by j+1 with j possibly unbound).
                    if val_batches:
                        val_Mel_err = val_Mel_err_tot / val_batches
                        val_Mel_L2_err = val_Mel_L2_err_tot / val_batches
                        sw.add_scalar("Validation/Mel_Spectrogram_loss", val_Mel_err, steps)
                        sw.add_scalar("Validation/Mel_Spectrogram_L2_loss", val_Mel_L2_err, steps)

                model.train()

            steps += 1

        scheduler_g.step()

        print('Time taken for epoch {} is {} sec\n'.format(epoch + 1, int(time.time() - start)))
199
+
200
+
201
def main():
    """Entry point: load the JSON config, snapshot the config and this script
    into the checkpoint directory for reproducibility, seed the RNGs, and
    start training.

    Raises:
        FileNotFoundError: if the training script file is missing from the
            current working directory (it is copied next to the checkpoints).
    """
    print('Initializing Training Process..')

    config_file = 'config_post.json'

    # json.load reads and parses in one step (the original read the whole
    # file into a string and called json.loads on it).
    with open(config_file) as f:
        json_config = json.load(f)

    h = AttrDict(json_config)
    build_env(config_file, 'config_post.json', h.checkpoint_path)

    # Snapshot this training script alongside the checkpoints so each run
    # records the exact code that produced it.
    src = "train_96_post_gm.py"
    dst_dir = h.checkpoint_path
    os.makedirs(dst_dir, exist_ok=True)
    dst = os.path.join(dst_dir, "train_96_post_gm.py")
    if not os.path.exists(src):
        raise FileNotFoundError(f"{src} 不存在!")
    shutil.copyfile(src, dst)
    print(f"已将 {src} 复制到 {dst}")

    torch.manual_seed(h.seed)
    # Also seed CUDA when a GPU is present; the original's dead
    # `else: pass` branch has been removed.
    if torch.cuda.is_available():
        torch.cuda.manual_seed(h.seed)

    train(h)
230
+
231
+
232
# Run the training entry point only when executed as a script,
# not when imported as a module.
if __name__ == "__main__":
    main()
234
+
cp_8_gm_post/vae_00035000 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d3ee591485f64e5087595c71548d6e450ae1d582968ffb758ca7da8099a09e18
3
+ size 357630932