ashishkblink committed
Commit ea2fcef · verified · 1 Parent(s): 5ee4d87

Upload f5_tts/train/finetune_cli.py with huggingface_hub

Files changed (1)
  1. f5_tts/train/finetune_cli.py +192 -0
f5_tts/train/finetune_cli.py ADDED
@@ -0,0 +1,192 @@
+ import argparse
+ import os
+ import shutil
+ import torch
+
+ from cached_path import cached_path
+ from f5_tts.model import CFM, UNetT, DiT, Trainer
+ from f5_tts.model.utils import get_tokenizer
+ from f5_tts.model.dataset import load_dataset
+ from importlib.resources import files
+
+ from accelerate import Accelerator
+
+ accelerator = Accelerator()
+ print(f"Using mixed precision: {accelerator.mixed_precision}")
+
+ # -------------------------- Dataset Settings --------------------------- #
+ target_sample_rate = 24000
+ n_mel_channels = 100
+ hop_length = 256
+ win_length = 1024
+ n_fft = 1024
+ mel_spec_type = "vocos"  # 'vocos' or 'bigvgan'
+
+
+ # -------------------------- Argument Parsing --------------------------- #
+ def parse_args():
+     # batch_size_per_gpu = 1000 setting for an 8GB GPU
+     # batch_size_per_gpu = 1600 setting for a 12GB GPU
+     # batch_size_per_gpu = 2000 setting for a 16GB GPU
+     # batch_size_per_gpu = 3200 setting for a 24GB GPU
+
+     # num_warmup_updates = 300 for ~5000 samples (about 10 hours of audio)
+
+     # adjust save_per_updates and last_per_steps to the checkpoint interval you need
+
+     parser = argparse.ArgumentParser(description="Train CFM Model")
+
+     parser.add_argument(
+         "--exp_name", type=str, default="F5TTS_Base", choices=["F5TTS_Base", "E2TTS_Base"], help="Experiment name"
+     )
+     parser.add_argument("--dataset_name", type=str, default="Emilia_ZH_EN", help="Name of the dataset to use")
+     parser.add_argument("--learning_rate", type=float, default=1e-5, help="Learning rate for training")
+     parser.add_argument("--batch_size_per_gpu", type=int, default=3200, help="Batch size per GPU")
+     parser.add_argument(
+         "--batch_size_type", type=str, default="frame", choices=["frame", "sample"], help="Batch size type"
+     )
+     parser.add_argument("--max_samples", type=int, default=64, help="Max sequences per batch")
+     parser.add_argument("--grad_accumulation_steps", type=int, default=1, help="Gradient accumulation steps")
+     parser.add_argument("--max_grad_norm", type=float, default=1.0, help="Max gradient norm for clipping")
+     parser.add_argument("--epochs", type=int, default=700, help="Number of training epochs")
+     parser.add_argument("--num_warmup_updates", type=int, default=1500, help="Warmup steps")
+     parser.add_argument("--save_per_updates", type=int, default=4000, help="Save checkpoint every X steps")
+     parser.add_argument("--last_per_steps", type=int, default=40000, help="Save last checkpoint every X steps")
+     parser.add_argument("--finetune", type=bool, default=True, help="Fine-tune from a pretrained checkpoint")
+     parser.add_argument("--pretrain", type=str, default=None, help="Path to the pretrained checkpoint")
+     parser.add_argument(
+         "--tokenizer", type=str, default="pinyin", choices=["pinyin", "char", "custom"], help="Tokenizer type"
+     )
+     parser.add_argument(
+         "--tokenizer_path",
+         type=str,
+         default=None,
+         help="Path to custom tokenizer vocab file (only used if tokenizer = 'custom')",
+     )
+     parser.add_argument(
+         "--log_samples",
+         type=bool,
+         default=False,
+         help="Log inference samples at each checkpoint save step",
+     )
+     parser.add_argument("--logger", type=str, default=None, choices=["wandb", "tensorboard"], help="Logger to use")
+     parser.add_argument(
+         "--bnb_optimizer",
+         type=bool,
+         default=False,
+         help="Use 8-bit Adam optimizer from bitsandbytes",
+     )
+     parser.add_argument("--ckpt_dir", required=True, type=str)
+     parser.add_argument("--data_dir", required=True, type=str)
+     parser.add_argument("--wandb_resume_id", type=str, default=None)
+     parser.add_argument("--resume", type=bool, default=False, help="Resume fine-tuning")
+
+     return parser.parse_args()
+
+
+ # -------------------------- Training Settings -------------------------- #
+
+
+ def main():
+     args = parse_args()
+
+     # checkpoint_path = str(files("f5_tts").joinpath(f"../../ckpts/{args.dataset_name}"))
+     checkpoint_path = args.ckpt_dir
+
+     # Model parameters based on experiment name
+     if args.exp_name == "F5TTS_Base":
+         wandb_resume_id = args.wandb_resume_id
+         print("wandb resume id is: ", wandb_resume_id)
+         model_cls = DiT
+         model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
+         # ckpt_path = "/home/tts/ttsteam/repos/F5-TTS/runs/indic_langs_11_1hr/ckpt/model_1200000.pt"
+         # if args.finetune:
+         #     if args.pretrain is None:
+         #         ckpt_path = str(cached_path("hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.pt"))
+         #     else:
+         #         ckpt_path = args.pretrain
+     elif args.exp_name == "E2TTS_Base":
+         wandb_resume_id = None
+         model_cls = UNetT
+         model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
+         # if args.finetune:
+         #     if args.pretrain is None:
+         #         ckpt_path = str(cached_path("hf://SWivid/E2-TTS/E2TTS_Base/model_1200000.pt"))
+         #     else:
+         #         ckpt_path = args.pretrain
+
+     if args.finetune and not args.resume:
+         if not os.path.isdir(checkpoint_path):
+             os.makedirs(checkpoint_path, exist_ok=True)
+
+         file_checkpoint = os.path.join(checkpoint_path, "model_last.pt")
+
+         # if not os.path.isfile(file_checkpoint):  ## UNRELIABLE, if too slow on multinode, can lead to some nodes training from scratch
+         #     # shutil.copy2(load_from, file_checkpoint)
+         #     ckpt = torch.load(args.load_from, weights_only=True, map_location="cpu")
+         #     del ckpt["step"]
+         #     torch.save(ckpt, file_checkpoint)
+         #     del ckpt
+         #     print("copy checkpoint for finetune", load_from, file_checkpoint)
+
+     # Use the tokenizer and tokenizer_path provided in the command line arguments
+     tokenizer = args.tokenizer
+     if tokenizer == "custom":
+         if not args.tokenizer_path:
+             raise ValueError("Custom tokenizer selected, but no tokenizer_path provided.")
+         tokenizer_path = args.tokenizer_path
+     else:
+         tokenizer_path = args.dataset_name
+
+     vocab_char_map, vocab_size = get_tokenizer(tokenizer_path, tokenizer)
+
+     print("\nvocab : ", vocab_size)
+     print("\nvocoder : ", mel_spec_type)
+
+     mel_spec_kwargs = dict(
+         n_fft=n_fft,
+         hop_length=hop_length,
+         win_length=win_length,
+         n_mel_channels=n_mel_channels,
+         target_sample_rate=target_sample_rate,
+         mel_spec_type=mel_spec_type,
+     )
+
+     model = CFM(
+         transformer=model_cls(**model_cfg, text_num_embeds=vocab_size, mel_dim=n_mel_channels),
+         mel_spec_kwargs=mel_spec_kwargs,
+         vocab_char_map=vocab_char_map,
+     )
+
+     trainer = Trainer(
+         model,
+         args.epochs,
+         args.learning_rate,
+         num_warmup_updates=args.num_warmup_updates,
+         save_per_updates=args.save_per_updates,
+         checkpoint_path=checkpoint_path,
+         batch_size=args.batch_size_per_gpu,
+         batch_size_type=args.batch_size_type,
+         max_samples=args.max_samples,
+         grad_accumulation_steps=args.grad_accumulation_steps,
+         max_grad_norm=args.max_grad_norm,
+         logger=args.logger,
+         wandb_project=args.dataset_name,
+         wandb_run_name=args.exp_name,
+         wandb_resume_id=wandb_resume_id,
+         log_samples=args.log_samples,
+         last_per_steps=args.last_per_steps,
+         bnb_optimizer=args.bnb_optimizer,
+     )
+
+     train_dataset = load_dataset(args.dataset_name, tokenizer, mel_spec_kwargs=mel_spec_kwargs, data_dir=args.data_dir)
+
+     trainer.train(
+         train_dataset,
+         resumable_with_seed=666,  # seed for shuffling dataset
+         num_workers=16,
+     )
+
+
+ if __name__ == "__main__":
+     main()
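
Since the script instantiates an Accelerator, it is typically started through accelerate launch. A minimal invocation sketch (the dataset name and directory paths below are placeholders to adapt to your setup, and batch_size_per_gpu should follow the GPU-memory guidance in the comments inside parse_args):

    accelerate launch f5_tts/train/finetune_cli.py \
        --exp_name F5TTS_Base \
        --dataset_name my_dataset \
        --ckpt_dir ckpts/my_dataset \
        --data_dir data/my_dataset \
        --batch_size_per_gpu 3200 \
        --tokenizer pinyin \
        --logger tensorboard

--ckpt_dir and --data_dir are the only required flags; every other option falls back to the defaults defined in parse_args.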