|
|
import os,torch |
|
|
import sys |
|
|
|
|
|
dir_path = os.path.abspath(os.path.dirname(__file__)) |
|
|
|
|
|
class config: |
|
|
|
|
|
data_path='./' |
|
|
model_name="OSS Shogun" |
|
|
user_name = "xxx" |
|
|
|
|
|
desc = ''' |
|
|
OSS |
|
|
''' |
|
|
|
|
|
train_file = "Semantic_tokens/SEMANTICS_/train.txt" |
|
|
val_file = "Semantic_tokens/SEMANTICS_/val.txt" |
|
|
|
|
|
norms = torch.load(f"{dir_path}/mel_norms.pt") |
|
|
mu = norms["mean_val"] |
|
|
std = norms["std"] |
|
|
scale = True |
|
|
semantic_model_centroids = 10000 + 1 |
|
|
seed_value = 3407 |
|
|
|
|
|
t2s_checkpoint = "/omega/Models/FT_ENGLISH/T2S/1_latest.pt" |
|
|
ts_finetuning = True |
|
|
ts_wandb_logs = False |
|
|
text_loss_weight = 0.01 |
|
|
t2s_position = 8192 |
|
|
ts_batch_size = 1 |
|
|
ts_epochs = 10 |
|
|
ts_lr = 1e-5 |
|
|
ts_weight_decay = 1e-4 |
|
|
ts_eval_epoch = 1 |
|
|
ts_num_workers = 8 |
|
|
ts_gradient_accumulation_steps = 1 |
|
|
ts_eval_step = 10000 |
|
|
|
|
|
langs = [ |
|
|
"odia", |
|
|
"assamese", |
|
|
"thai", |
|
|
"gujrati", |
|
|
"russian", |
|
|
"japanese", |
|
|
"punjabi", |
|
|
"hindi", |
|
|
"manipuri", |
|
|
"korean", |
|
|
"bhojpuri", |
|
|
"sanskrit", |
|
|
"english", |
|
|
"french", |
|
|
"bodo", |
|
|
"malayalam", |
|
|
"telugu", |
|
|
"kannada", |
|
|
"dogri", |
|
|
"marathi", |
|
|
"german", |
|
|
"italian", |
|
|
"rajasthani", |
|
|
"spanish", |
|
|
"arabic", |
|
|
"urdu", |
|
|
"gujarati", |
|
|
"tamil", |
|
|
"bengali", |
|
|
] |
|
|
|
|
|
lang_index = {i: j for j, i in enumerate(langs)} |
|
|
|
|
|
|
|
|
sa_wandb_logs = False |
|
|
joint_training = (False,) |
|
|
checkpoint = "/omega/Models/" + "FT_ENGLISH/" |
|
|
sa_timesteps_max = 1000 |
|
|
sa_batch_size = 32 |
|
|
sa_epochs = 5000000 |
|
|
gradient_accumulation_steps = 4 |
|
|
sa_lr = 1e-4 |
|
|
sa_weight_decay = 1e-2 |
|
|
sa_eval_step = 10000 |
|
|
sa_infer = True |
|
|
sa_infer_epoch = 1 |
|
|
sa_num_workers = 24 |
|
|
|
|
|
|
|
|
dvae_wandb_logs = True |
|
|
dvae_batch_size = 128 |
|
|
dvae_epochs = 5000 |
|
|
dvae_lr = 3e-4 |
|
|
dvae_weight_decay = 1e-2 |
|
|
dvae_eval_epoch = 1 |
|
|
dvae_infer = True |
|
|
dvae_infer_epoch = 1 |
|
|
dvae_num_workers = 16 |
|
|
|
|
|
|
|
|
CLIP_LENGTH = 500 |
|
|
MAX_WAV_VALUE = 32768.0 - 1 |
|
|
filter_length = 1024 |
|
|
hop_length = 256 |
|
|
window = "hann" |
|
|
win_length = 1024 |
|
|
n_mel_channels = 100 |
|
|
sampling_rate = 24000 |
|
|
mel_fmin = 0.0 |
|
|
mel_fmax = None |
|
|
normalize = True |
|
|
|