File size: 2,514 Bytes
41bc8a8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
import os,torch
import sys

dir_path = os.path.abspath(os.path.dirname(__file__))

class config:

    data_path='./'
    model_name="OSS Shogun"
    user_name = "xxx" #for wandb
    
    desc = '''
        OSS
    '''

    train_file = "Semantic_tokens/SEMANTICS_/train.txt" 
    val_file = "Semantic_tokens/SEMANTICS_/val.txt"

    norms = torch.load(f"{dir_path}/mel_norms.pt")
    mu = norms["mean_val"]
    std = norms["std"]
    scale = True
    semantic_model_centroids = 10000 + 1
    seed_value = 3407

    t2s_checkpoint = "/omega/Models/FT_ENGLISH/T2S/1_latest.pt"
    ts_finetuning = True
    ts_wandb_logs = False
    text_loss_weight = 0.01
    t2s_position = 8192
    ts_batch_size = 1
    ts_epochs = 10
    ts_lr = 1e-5
    ts_weight_decay = 1e-4
    ts_eval_epoch = 1
    ts_num_workers = 8
    ts_gradient_accumulation_steps = 1  # EfBS of 128 for finetuning, 256 for pretraining, around 9k steps for sft for 2 epochs
    ts_eval_step = 10000

    langs = [
        "odia",
        "assamese",
        "thai",
        "gujrati",
        "russian",
        "japanese",
        "punjabi",
        "hindi",
        "manipuri",
        "korean",
        "bhojpuri",
        "sanskrit",
        "english",
        "french",
        "bodo",
        "malayalam",
        "telugu",
        "kannada",
        "dogri",
        "marathi",
        "german",
        "italian",
        "rajasthani",
        "spanish",
        "arabic",
        "urdu",
        "gujarati",
        "tamil",
        "bengali",
    ]

    lang_index = {i: j for j, i in enumerate(langs)}

    # Train s2a
    sa_wandb_logs = False
    joint_training = (False,)  # doesn't work
    checkpoint = "/omega/Models/" + "FT_ENGLISH/"
    sa_timesteps_max = 1000
    sa_batch_size = 32
    sa_epochs = 5000000
    gradient_accumulation_steps = 4
    sa_lr = 1e-4
    sa_weight_decay = 1e-2
    sa_eval_step = 10000
    sa_infer = True
    sa_infer_epoch = 1
    sa_num_workers = 24

    # Train Dvae (not using)
    dvae_wandb_logs = True
    dvae_batch_size = 128
    dvae_epochs = 5000
    dvae_lr = 3e-4
    dvae_weight_decay = 1e-2
    dvae_eval_epoch = 1
    dvae_infer = True
    dvae_infer_epoch = 1
    dvae_num_workers = 16

    # Acoustic Properties, Do not change
    CLIP_LENGTH = 500
    MAX_WAV_VALUE = 32768.0 - 1
    filter_length = 1024
    hop_length = 256  # 256
    window = "hann"
    win_length = 1024
    n_mel_channels = 100
    sampling_rate = 24000
    mel_fmin = 0.0
    mel_fmax = None
    normalize = True