p-alonso commited on
Commit
1217850
·
verified ·
1 Parent(s): ab48227

Copy train-time config to config_train.gin

Browse files
Files changed (1) hide show
  1. config_train.gin +98 -0
config_train.gin ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Parameters for AudioDataModule:
2
+ # ==============================================================================
3
+ AudioDataModule.num_workers = 20
4
+
5
+ # Parameters for AudioDataset:
6
+ # ==============================================================================
7
+ AudioDataset.half_precision = True
8
+ AudioDataset.mono = True
9
+ AudioDataset.new_freq = 16000
10
+ AudioDataset.num_frames = 480000
11
+ AudioDataset.orig_freq = 16000
12
+
13
+ # Parameters for build_dev_datamodule:
14
+ # ==============================================================================
15
+ build_dev_datamodule.datamodule = @discotube
16
+
17
+ # Parameters for build_module:
18
+ # ==============================================================================
19
+ build_module.ckpt_path = 'model.ckpt'
20
+ build_module.module = @modules.maskingmodel.MaskingModel
21
+ build_module.net = @nets.conformer.Conformer
22
+ build_module.representation = @nets.melspectrogram.MelSpectrogram
23
+
24
+ # Parameters for Conformer:
25
+ # ==============================================================================
26
+ Conformer.alpha_deepnorm = 2.6321480259049848
27
+ Conformer.beta_deepnorm = 0.022386873579657126
28
+ Conformer.conv_kernel_size = 5
29
+ Conformer.depth = 24
30
+ Conformer.dropout = 0.2
31
+ Conformer.embed_dim = 1024
32
+ Conformer.input_dropout = 0.0
33
+ Conformer.mlp_ratio = 4.0
34
+ Conformer.mlp_residual_factor = 4.0
35
+ Conformer.num_heads = 8
36
+ Conformer.num_patches = 460
37
+ Conformer.use_deepnorm = True
38
+ Conformer.use_rope = True
39
+
40
+ # Parameters for CosineAnnealingCallback:
41
+ # ==============================================================================
42
+ CosineAnnealingCallback.eta_min = 1e-07
43
+ CosineAnnealingCallback.warmup_steps = 30000
44
+
45
+ # Parameters for DiscotubeAudioDataModule:
46
+ # ==============================================================================
47
+ DiscotubeAudioDataModule.batch_size = 32
48
+ DiscotubeAudioDataModule.data_dir = ''
49
+ DiscotubeAudioDataModule.filelist_train = ''
50
+ DiscotubeAudioDataModule.filelist_val = ''
51
+
52
+ # Parameters for MaskingModel:
53
+ # ==============================================================================
54
+ MaskingModel.codebook_dim = 16
55
+ MaskingModel.codebook_size = 8192
56
+ MaskingModel.diff_input = False
57
+ MaskingModel.lr = 0.0001
58
+ MaskingModel.mask_prob = 0.6
59
+ MaskingModel.mask_seconds = 0.4
60
+ MaskingModel.num_codebooks = 4
61
+ MaskingModel.plot_tokens = False
62
+ MaskingModel.seed = 0
63
+ MaskingModel.weight_decay = 0.01
64
+
65
+ # Parameters for MelSpectrogram:
66
+ # ==============================================================================
67
+ MelSpectrogram.freq_mask_param = 0
68
+ MelSpectrogram.hop_len = 256
69
+ MelSpectrogram.mel_scale = 'slaney'
70
+ MelSpectrogram.n_mel = 96
71
+ MelSpectrogram.norm = 'slaney'
72
+ MelSpectrogram.norm_mean = 2.06755686098554
73
+ MelSpectrogram.norm_std = 1.268292820667291
74
+ MelSpectrogram.power = 2
75
+ MelSpectrogram.sr = 16000
76
+ MelSpectrogram.stretch_factor = 1
77
+ MelSpectrogram.time_mask_param = 0
78
+ MelSpectrogram.win_len = 512
79
+ MelSpectrogram.patch_size = (96, 4)
80
+
81
+ # Parameters for train:
82
+ # ==============================================================================
83
+ train.params = \
84
+ {'accelerator': 'gpu',
85
+ 'devices': 4,
86
+ 'log_every_n_steps': 50,
87
+ 'max_steps': 400000,
88
+ 'num_nodes': 1,
89
+ 'num_sanity_val_steps': 0,
90
+ 'precision': 'bf16-mixed',
91
+ 'strategy': 'ddp_find_unused_parameters_true'}
92
+ train.wandb_params = \
93
+ {'entity': 'mtg-upf',
94
+ 'group': 'masking_conformer',
95
+ 'name': 'mask_conformer_rope_multi4_large',
96
+ 'offline': True,
97
+ 'project': 'mtg-ssl',
98
+ 'save_dir': '/gpfs/projects/upf97/logs/'}