File size: 3,859 Bytes
364ef4d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
# Version string of this configuration; quoted so it always loads as a string.
version: '1.0.0'

model:
  # Diffusion model definition. Throughout this file a `target` is a dotted
  # path naming the component, and the sibling `params` mapping holds its
  # settings.
  base_learning_rate: 0.00001
  target: mug.diffusion.diffusion.DDPM
  params:
    # presumably the linear noise-schedule endpoints and step count — confirm
    linear_start: 0.0001
    linear_end: 0.02
    log_every_t: 100
    timesteps: 1000
    # Same value (16) as the UNet in/out channels and the first-stage
    # ddconfig.z_channels below.
    z_channels: 16
    # NOTE(review): looks like max_note_frame (4096, per the comments in the
    # data section) divided by 2^num_down with num_down = 3 — confirm.
    z_length: 512
    parameterization: eps
    loss_type: smooth_l1
    monitor: 'val/loss_simple'

    # UNet component of the diffusion model.
    unet_config:
      target: mug.diffusion.unet.UNetModel
      params:
        in_channels: 16
        model_channels: 128
        out_channels: 16
        attention_resolutions: [8, 4, 2]
        num_res_blocks: 2
        channel_mult: [1, 2, 3, 4]
        num_heads: 8
        # Same value as cond_stage_config.params.embed_dim.
        context_dim: 128
        dropout: 0.0
        # Of the three layer switches only s4_layer is enabled.
        lstm_last: false
        lstm_layer: false
        s4_layer: true
        audio_channels: [256, 512, 512, 512]
        use_checkpoint: false

    # First-stage autoencoder.
    first_stage_config:
      target: mug.firststage.autoencoder.AutoencoderKL
      params:
        monitor: 'val/loss'
        kl_weight: 0.000001
        ddconfig:
          x_channels: 16  # key_count * 4
          middle_channels: 64
          z_channels: 16
          num_groups: 8
          channel_mult: [1, 2, 4, 4]  # num_down = len(ch_mult)-1
          num_res_blocks: 1
        # torch.nn.Identity stands in for the loss here; the reconstruction
        # loss below is kept commented out.
        lossconfig:
          target: torch.nn.Identity
          # target: mug.firststage.losses.ManiaReconstructLoss
          # params:
          #   weight_start_offset: 0.5
          #   weight_holding: 0.5
          #   weight_end_offset: 0.2
          #   label_smoothing: 0.001

    # Beatmap-feature conditioning; reads the same feature YAML that
    # data.params.common_params.feature_yaml points to.
    cond_stage_config:
      target: mug.cond.feature.BeatmapFeatureEmbedder
      params:
        path_to_yaml: 'configs/mug/mania_beatmap_features.yaml'
        embed_dim: 128

    # Audio encoder; n_freq has the same value as n_mels in the data section.
    wave_stage_config:
      target: mug.cond.wave.MelspectrogramScaleEncoder1D
      params:
        n_freq: 128
        middle_channels: 128
        attention_resolutions: [128, 256, 512]
        num_res_blocks: 2
        num_heads: 8
        num_groups: 32
        dropout: 0.0
        use_checkpoint: true
        channel_mult: [1, 1, 1, 1, 2, 2, 2, 4, 4, 4]


data:
  # Data module configuration: loader settings, the `common_params` mapping,
  # and the train / validation dataset definitions.
  target: main.DataModuleFromConfig
  params:
    batch_size: 48
    wrap: false
    # num_workers: 0
    num_workers: 7
    # presumably shared by the train and validation datasets below —
    # confirm in DataModuleFromConfig
    common_params:
      # NOTE(review): empty list — presumably to be filled with the dataset
      # list file(s) before training; confirm.
      txt_file: []
      sr: 22050
      n_fft: 512
      max_audio_frame: 32768
      audio_note_window_ratio: 8
      n_mels: 128
      cache_dir: "data/audio_cache/"
      with_audio: true
      with_feature: true
      feature_yaml: "configs/mug/mania_beatmap_features.yaml"
      # Quantities derived from the values above:
      # audio_window_frame = n_fft / sr / 4 = 0.00580499 s
      # note_window_frame = audio_note_window_ratio * audio_window_frame = 0.04643990 s
      # max_duration = audio_window_frame * max_audio_frame = 190.2179 s = 3 min 10 s
      # max_note_frame = max_audio_frame / audio_note_window_ratio = 4096

      # old =========== (computed for earlier settings, not the values above)
      # audio_window_frame = n_fft / sr / 4 = 0.02321995 s
      # note_window_frame = audio_note_window_ratio * audio_window_frame = 0.04643990 s
      # max_duration = audio_window_frame * max_audio_frame = 380.4357 s = 6 min 20 s
      # max_note_frame = max_audio_frame / audio_note_window_ratio = 8192
    train:
      target: mug.data.dataset.OsuTrainDataset
      params:
        # presumably per-sample augmentation probabilities (`*_p`) and their
        # settings — confirm in OsuTrainDataset
        mirror_p: 0.5
        feature_dropout_p: 0.5
        mirror_at_interval_p: 0
        rate_p: 0.2
        rate: [0.75, 1.3]
        freq_mask_p: 0.0
        freq_mask_num: 15

    validation:
      target: mug.data.dataset.OsuValidDataset
      params: {}
      # Disabled params entry, kept for reference:
      #   test_txt_file: "data\\mug\\local_mania_4k_test.txt"


lightning:
  # Callback and trainer settings.
  callbacks:
    # Beatmap logging callback: batch index 0, 'val' split, count 16.
    beatmap_logger:
      target: mug.data.dataset.BeatmapLogger
      params:
        log_batch_idx: [0]
        splits: ['val']
        count: 16

  trainer:
    benchmark: true
    # NOTE(review): `dp` looks like the legacy data-parallel accelerator name;
    # confirm it is still accepted by the pinned trainer version.
    accelerator: dp
    accumulate_grad_batches: 1
    # precision: 16