ashishkblink committed on
Commit
5fcbaac
·
verified ·
1 Parent(s): c067e56

Upload f5_tts/train/train.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. f5_tts/train/train.py +75 -0
f5_tts/train/train.py ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # training script.
2
+
3
+ import os
4
+ from importlib.resources import files
5
+
6
+ import hydra
7
+
8
+ from f5_tts.model import CFM, DiT, Trainer, UNetT
9
+ from f5_tts.model.dataset import load_dataset
10
+ from f5_tts.model.utils import get_tokenizer
11
+
12
+ os.chdir(str(files("f5_tts").joinpath("../.."))) # change working directory to root of project (local editable)
13
+
14
+
15
@hydra.main(version_base="1.3", config_path=str(files("f5_tts").joinpath("configs")), config_name=None)
def main(cfg):
    """Build the model and trainer from a Hydra config and run training.

    Args:
        cfg: Hydra config object. This function reads values from its
            ``model``, ``datasets``, ``optim`` and ``ckpts`` groups.

    Raises:
        ValueError: If ``cfg.model.name`` contains neither "F5TTS" nor
            "E2TTS", so no transformer backbone can be selected.
    """
    tokenizer = cfg.model.tokenizer
    mel_spec_type = cfg.model.mel_spec.mel_spec_type
    exp_name = f"{cfg.model.name}_{mel_spec_type}_{cfg.model.tokenizer}_{cfg.datasets.name}"

    # set text tokenizer
    # For every tokenizer except "custom" the dataset name is passed to
    # get_tokenizer; "custom" uses the explicitly configured path instead.
    if tokenizer != "custom":
        tokenizer_path = cfg.datasets.name
    else:
        tokenizer_path = cfg.model.tokenizer_path
    vocab_char_map, vocab_size = get_tokenizer(tokenizer_path, tokenizer)

    # set model
    if "F5TTS" in cfg.model.name:
        model_cls = DiT
    elif "E2TTS" in cfg.model.name:
        model_cls = UNetT
    else:
        # Fail with a clear message instead of an UnboundLocalError on
        # ``model_cls`` further down when the name matches no known backbone.
        raise ValueError(
            f"Unsupported model name {cfg.model.name!r}: expected it to contain 'F5TTS' or 'E2TTS'."
        )
    # Always None in this script; forwarded to the Trainer unchanged.
    wandb_resume_id = None

    model = CFM(
        transformer=model_cls(**cfg.model.arch, text_num_embeds=vocab_size, mel_dim=cfg.model.mel_spec.n_mel_channels),
        mel_spec_kwargs=cfg.model.mel_spec,
        vocab_char_map=vocab_char_map,
    )

    # init trainer
    trainer = Trainer(
        model,
        epochs=cfg.optim.epochs,
        learning_rate=cfg.optim.learning_rate,
        num_warmup_updates=cfg.optim.num_warmup_updates,
        save_per_updates=cfg.ckpts.save_per_updates,
        # Checkpoint directory is resolved relative to two levels above the f5_tts package.
        checkpoint_path=str(files("f5_tts").joinpath(f"../../{cfg.ckpts.save_dir}")),
        batch_size=cfg.datasets.batch_size_per_gpu,
        batch_size_type=cfg.datasets.batch_size_type,
        max_samples=cfg.datasets.max_samples,
        grad_accumulation_steps=cfg.optim.grad_accumulation_steps,
        max_grad_norm=cfg.optim.max_grad_norm,
        logger=cfg.ckpts.logger,
        wandb_project="CFM-TTS",
        wandb_run_name=exp_name,
        wandb_resume_id=wandb_resume_id,
        last_per_steps=cfg.ckpts.last_per_steps,
        log_samples=True,
        bnb_optimizer=cfg.optim.bnb_optimizer,
        mel_spec_type=mel_spec_type,
        is_local_vocoder=cfg.model.vocoder.is_local,
        local_vocoder_path=cfg.model.vocoder.local_path,
    )

    train_dataset = load_dataset(cfg.datasets.name, tokenizer, mel_spec_kwargs=cfg.model.mel_spec)
    trainer.train(
        train_dataset,
        num_workers=cfg.datasets.num_workers,
        resumable_with_seed=666,  # seed for shuffling dataset
    )
72
+
73
+
74
# Run training only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()