ashishkblink commited on
Commit
fb2f21a
·
verified ·
1 Parent(s): 7866f97

Upload f5_tts/model/trainer.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. f5_tts/model/trainer.py +380 -0
f5_tts/model/trainer.py ADDED
@@ -0,0 +1,380 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import gc
4
+ import os
5
+
6
+ import torch
7
+ import torchaudio
8
+ import wandb
9
+ from accelerate import Accelerator
10
+ from accelerate.utils import DistributedDataParallelKwargs
11
+ from ema_pytorch import EMA
12
+ from torch.optim import AdamW
13
+ from torch.optim.lr_scheduler import LinearLR, SequentialLR
14
+ from torch.utils.data import DataLoader, Dataset, SequentialSampler
15
+ from tqdm import tqdm
16
+
17
+ from f5_tts.model import CFM
18
+ from f5_tts.model.dataset import DynamicBatchSampler, collate_fn
19
+ from f5_tts.model.utils import default, exists
20
+
21
+ # trainer
22
+
23
+
24
+ class Trainer:
25
+ def __init__(
26
+ self,
27
+ model: CFM,
28
+ epochs,
29
+ learning_rate,
30
+ num_warmup_updates=20000,
31
+ save_per_updates=1000,
32
+ checkpoint_path=None,
33
+ batch_size=32,
34
+ batch_size_type: str = "sample",
35
+ max_samples=32,
36
+ grad_accumulation_steps=1,
37
+ max_grad_norm=1.0,
38
+ noise_scheduler: str | None = None,
39
+ duration_predictor: torch.nn.Module | None = None,
40
+ logger: str | None = "wandb", # "wandb" | "tensorboard" | None
41
+ wandb_project="test_e2-tts",
42
+ wandb_run_name="test_run",
43
+ wandb_resume_id: str = None,
44
+ log_samples: bool = False,
45
+ last_per_steps=None,
46
+ accelerate_kwargs: dict = dict(),
47
+ ema_kwargs: dict = dict(),
48
+ bnb_optimizer: bool = False,
49
+ mel_spec_type: str = "vocos", # "vocos" | "bigvgan"
50
+ is_local_vocoder: bool = False, # use local path vocoder
51
+ local_vocoder_path: str = "", # local vocoder path
52
+ ):
53
+ ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
54
+
55
+ if logger == "wandb" and not wandb.api.api_key:
56
+ logger = None
57
+ print(f"Using logger: {logger}")
58
+ self.log_samples = log_samples
59
+
60
+ self.accelerator = Accelerator(
61
+ log_with=logger if logger == "wandb" else None,
62
+ kwargs_handlers=[ddp_kwargs],
63
+ gradient_accumulation_steps=grad_accumulation_steps,
64
+ **accelerate_kwargs,
65
+ )
66
+
67
+ self.logger = logger
68
+ if self.logger == "wandb":
69
+ if exists(wandb_resume_id):
70
+ init_kwargs = {"wandb": {"resume": "allow", "name": wandb_run_name, "id": wandb_resume_id}}
71
+ else:
72
+ init_kwargs = {"wandb": {"resume": "allow", "name": wandb_run_name}}
73
+
74
+ self.accelerator.init_trackers(
75
+ project_name=wandb_project,
76
+ init_kwargs=init_kwargs,
77
+ config={
78
+ "epochs": epochs,
79
+ "learning_rate": learning_rate,
80
+ "num_warmup_updates": num_warmup_updates,
81
+ "batch_size": batch_size,
82
+ "batch_size_type": batch_size_type,
83
+ "max_samples": max_samples,
84
+ "grad_accumulation_steps": grad_accumulation_steps,
85
+ "max_grad_norm": max_grad_norm,
86
+ "gpus": self.accelerator.num_processes,
87
+ "noise_scheduler": noise_scheduler,
88
+ },
89
+ )
90
+
91
+ elif self.logger == "tensorboard":
92
+ from torch.utils.tensorboard import SummaryWriter
93
+
94
+ self.writer = SummaryWriter(log_dir=f"runs/{wandb_run_name}")
95
+
96
+ self.model = model
97
+
98
+ if self.is_main:
99
+ self.ema_model = EMA(model, include_online_model=False, **ema_kwargs)
100
+ self.ema_model.to(self.accelerator.device)
101
+
102
+ self.epochs = epochs
103
+ self.num_warmup_updates = num_warmup_updates
104
+ self.save_per_updates = save_per_updates
105
+ self.last_per_steps = default(last_per_steps, save_per_updates * grad_accumulation_steps)
106
+ self.checkpoint_path = default(checkpoint_path, "ckpts/test_e2-tts")
107
+
108
+ self.batch_size = batch_size
109
+ self.batch_size_type = batch_size_type
110
+ self.max_samples = max_samples
111
+ self.grad_accumulation_steps = grad_accumulation_steps
112
+ self.max_grad_norm = max_grad_norm
113
+
114
+ # mel vocoder config
115
+ self.vocoder_name = mel_spec_type
116
+ self.is_local_vocoder = is_local_vocoder
117
+ self.local_vocoder_path = local_vocoder_path
118
+
119
+ self.noise_scheduler = noise_scheduler
120
+
121
+ self.duration_predictor = duration_predictor
122
+
123
+ if bnb_optimizer:
124
+ import bitsandbytes as bnb
125
+
126
+ self.optimizer = bnb.optim.AdamW8bit(model.parameters(), lr=learning_rate)
127
+ else:
128
+ self.optimizer = AdamW(model.parameters(), lr=learning_rate)
129
+ self.model, self.optimizer = self.accelerator.prepare(self.model, self.optimizer)
130
+
131
+ @property
132
+ def is_main(self):
133
+ return self.accelerator.is_main_process
134
+
135
+ def save_checkpoint(self, step, last=False):
136
+ self.accelerator.wait_for_everyone()
137
+ if self.is_main:
138
+ checkpoint = dict(
139
+ model_state_dict=self.accelerator.unwrap_model(self.model).state_dict(),
140
+ optimizer_state_dict=self.accelerator.unwrap_model(self.optimizer).state_dict(),
141
+ ema_model_state_dict=self.ema_model.state_dict(),
142
+ scheduler_state_dict=self.scheduler.state_dict(),
143
+ step=step,
144
+ )
145
+ if not os.path.exists(self.checkpoint_path):
146
+ os.makedirs(self.checkpoint_path)
147
+ if last:
148
+ self.accelerator.save(checkpoint, f"{self.checkpoint_path}/model_last.pt")
149
+ print(f"Saved last checkpoint at step {step}")
150
+ else:
151
+ self.accelerator.save(checkpoint, f"{self.checkpoint_path}/model_{step}.pt")
152
+
153
+ def load_checkpoint(self):
154
+ if (
155
+ not exists(self.checkpoint_path)
156
+ or not os.path.exists(self.checkpoint_path)
157
+ or not any(filename.endswith(".pt") for filename in os.listdir(self.checkpoint_path))
158
+ ):
159
+ return 0
160
+
161
+ self.accelerator.wait_for_everyone()
162
+ if "model_last.pt" in os.listdir(self.checkpoint_path):
163
+ latest_checkpoint = "model_last.pt"
164
+ else:
165
+ latest_checkpoint = sorted(
166
+ [f for f in os.listdir(self.checkpoint_path) if f.endswith(".pt")],
167
+ key=lambda x: int("".join(filter(str.isdigit, x))),
168
+ )[-1]
169
+ # checkpoint = torch.load(f"{self.checkpoint_path}/{latest_checkpoint}", map_location=self.accelerator.device) # rather use accelerator.load_state ಥ_ಥ
170
+ print("Loading checkpoint from: ", f"{self.checkpoint_path}/{latest_checkpoint}")
171
+ checkpoint = torch.load(f"{self.checkpoint_path}/{latest_checkpoint}", weights_only=True, map_location="cpu")
172
+
173
+ # patch for backward compatibility, 305e3ea
174
+ for key in ["ema_model.mel_spec.mel_stft.mel_scale.fb", "ema_model.mel_spec.mel_stft.spectrogram.window"]:
175
+ if key in checkpoint["ema_model_state_dict"]:
176
+ del checkpoint["ema_model_state_dict"][key]
177
+
178
+ if self.is_main:
179
+ self.ema_model.load_state_dict(checkpoint["ema_model_state_dict"])
180
+
181
+ if "step" in checkpoint:
182
+ # patch for backward compatibility, 305e3ea
183
+ for key in ["mel_spec.mel_stft.mel_scale.fb", "mel_spec.mel_stft.spectrogram.window"]:
184
+ if key in checkpoint["model_state_dict"]:
185
+ del checkpoint["model_state_dict"][key]
186
+
187
+ self.accelerator.unwrap_model(self.model).load_state_dict(checkpoint["model_state_dict"])
188
+ self.accelerator.unwrap_model(self.optimizer).load_state_dict(checkpoint["optimizer_state_dict"])
189
+ if self.scheduler:
190
+ self.scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
191
+ # step = checkpoint["step"]
192
+ # step = 0
193
+ # print("checkpoint step is: ", step, " CHANGE LINE 192 IN /projects/data/ttsteam/repos/f5/src/f5_tts/model/trainer.py TO FIX THIS!!!!")
194
+ else:
195
+ checkpoint["model_state_dict"] = {
196
+ k.replace("ema_model.", ""): v
197
+ for k, v in checkpoint["ema_model_state_dict"].items()
198
+ if k not in ["initted", "step"]
199
+ }
200
+ self.accelerator.unwrap_model(self.model).load_state_dict(checkpoint["model_state_dict"])
201
+ step = 0
202
+
203
+ del checkpoint
204
+ gc.collect()
205
+ return step
206
+
207
+ def train(self, train_dataset: Dataset, num_workers=16, resumable_with_seed: int = None):
208
+ if self.log_samples:
209
+ from f5_tts.infer.utils_infer import cfg_strength, load_vocoder, nfe_step, sway_sampling_coef
210
+
211
+ vocoder = load_vocoder(
212
+ vocoder_name=self.vocoder_name, is_local=self.is_local_vocoder, local_path=self.local_vocoder_path
213
+ )
214
+ target_sample_rate = self.accelerator.unwrap_model(self.model).mel_spec.target_sample_rate
215
+ log_samples_path = f"{self.checkpoint_path}/samples"
216
+ os.makedirs(log_samples_path, exist_ok=True)
217
+
218
+ if exists(resumable_with_seed):
219
+ generator = torch.Generator()
220
+ generator.manual_seed(resumable_with_seed)
221
+ else:
222
+ generator = None
223
+
224
+ if self.batch_size_type == "sample":
225
+ train_dataloader = DataLoader(
226
+ train_dataset,
227
+ collate_fn=collate_fn,
228
+ num_workers=num_workers,
229
+ pin_memory=True,
230
+ persistent_workers=True,
231
+ batch_size=self.batch_size,
232
+ shuffle=True,
233
+ generator=generator,
234
+ )
235
+ elif self.batch_size_type == "frame":
236
+ self.accelerator.even_batches = False
237
+ sampler = SequentialSampler(train_dataset)
238
+ batch_sampler = DynamicBatchSampler(
239
+ sampler, self.batch_size, max_samples=self.max_samples, random_seed=resumable_with_seed, drop_last=False
240
+ )
241
+ train_dataloader = DataLoader(
242
+ train_dataset,
243
+ collate_fn=collate_fn,
244
+ num_workers=num_workers,
245
+ pin_memory=True,
246
+ persistent_workers=True,
247
+ batch_sampler=batch_sampler,
248
+ )
249
+ else:
250
+ raise ValueError(f"batch_size_type must be either 'sample' or 'frame', but received {self.batch_size_type}")
251
+
252
+ # accelerator.prepare() dispatches batches to devices;
253
+ # which means the length of dataloader calculated before, should consider the number of devices
254
+ warmup_steps = (
255
+ self.num_warmup_updates * self.accelerator.num_processes
256
+ ) # consider a fixed warmup steps while using accelerate multi-gpu ddp
257
+ print("Warm Up steps are: ", warmup_steps)
258
+ # otherwise by default with split_batches=False, warmup steps change with num_processes
259
+ total_steps = len(train_dataloader) * self.epochs / self.grad_accumulation_steps
260
+ decay_steps = total_steps - warmup_steps
261
+ warmup_scheduler = LinearLR(self.optimizer, start_factor=1e-8, end_factor=1.0, total_iters=warmup_steps)
262
+ decay_scheduler = LinearLR(self.optimizer, start_factor=1.0, end_factor=1e-8, total_iters=decay_steps)
263
+ self.scheduler = SequentialLR(
264
+ self.optimizer, schedulers=[warmup_scheduler, decay_scheduler], milestones=[warmup_steps]
265
+ )
266
+ train_dataloader, self.scheduler = self.accelerator.prepare(
267
+ train_dataloader, self.scheduler
268
+ ) # actual steps = 1 gpu steps / gpus
269
+ start_step = self.load_checkpoint()
270
+ global_step = start_step
271
+
272
+ if exists(resumable_with_seed):
273
+ orig_epoch_step = len(train_dataloader)
274
+ skipped_epoch = int(start_step // orig_epoch_step)
275
+ skipped_batch = start_step % orig_epoch_step
276
+ skipped_dataloader = self.accelerator.skip_first_batches(train_dataloader, num_batches=skipped_batch)
277
+ else:
278
+ skipped_epoch = 0
279
+
280
+ for epoch in range(skipped_epoch, self.epochs):
281
+ self.model.train()
282
+ if exists(resumable_with_seed) and epoch == skipped_epoch:
283
+ progress_bar = tqdm(
284
+ skipped_dataloader,
285
+ desc=f"Epoch {epoch+1}/{self.epochs}",
286
+ unit="step",
287
+ disable=not self.accelerator.is_local_main_process,
288
+ initial=skipped_batch,
289
+ total=orig_epoch_step,
290
+ )
291
+ else:
292
+ progress_bar = tqdm(
293
+ train_dataloader,
294
+ desc=f"Epoch {epoch+1}/{self.epochs}",
295
+ unit="step",
296
+ disable=not self.accelerator.is_local_main_process,
297
+ )
298
+
299
+ for batch in progress_bar:
300
+
301
+ with self.accelerator.accumulate(self.model):
302
+ text_inputs = batch["text"]
303
+ mel_spec = batch["mel"].permute(0, 2, 1)
304
+ mel_lengths = batch["mel_lengths"]
305
+ if mel_spec.shape[0] * mel_spec.shape[1] > 38000: # Hacky Fix for incorrect dynamic batching
306
+ continue
307
+
308
+ # TODO. add duration predictor training
309
+ if self.duration_predictor is not None and self.accelerator.is_local_main_process:
310
+ dur_loss = self.duration_predictor(mel_spec, lens=batch.get("durations"))
311
+ self.accelerator.log({"duration loss": dur_loss.item()}, step=global_step)
312
+
313
+ loss, cond, pred = self.model(
314
+ mel_spec, text=text_inputs, lens=mel_lengths, noise_scheduler=self.noise_scheduler
315
+ )
316
+ self.accelerator.backward(loss)
317
+
318
+ if self.max_grad_norm > 0 and self.accelerator.sync_gradients:
319
+ self.accelerator.clip_grad_norm_(self.model.parameters(), self.max_grad_norm)
320
+
321
+ self.optimizer.step()
322
+ self.scheduler.step()
323
+ self.optimizer.zero_grad()
324
+
325
+ if self.is_main:
326
+ self.ema_model.update()
327
+
328
+ global_step += 1
329
+
330
+ if self.accelerator.is_local_main_process:
331
+ self.accelerator.log({"loss": loss.item(), "lr": self.scheduler.get_last_lr()[0]}, step=global_step)
332
+ if self.logger == "tensorboard":
333
+ self.writer.add_scalar("loss", loss.item(), global_step)
334
+ self.writer.add_scalar("lr", self.scheduler.get_last_lr()[0], global_step)
335
+
336
+ progress_bar.set_postfix(step=str(global_step), loss=loss.item())
337
+
338
+ if global_step % (self.save_per_updates * self.grad_accumulation_steps) == 0:
339
+ self.save_checkpoint(global_step)
340
+
341
+ if self.log_samples and self.accelerator.is_local_main_process:
342
+ ref_audio_len = mel_lengths[0]
343
+ infer_text = [
344
+ text_inputs[0] + ([" "] if isinstance(text_inputs[0], list) else " ") + text_inputs[0]
345
+ ]
346
+ with torch.inference_mode():
347
+ generated, _ = self.accelerator.unwrap_model(self.model).sample(
348
+ cond=mel_spec[0][:ref_audio_len].unsqueeze(0),
349
+ text=infer_text,
350
+ duration=ref_audio_len * 2,
351
+ steps=nfe_step,
352
+ cfg_strength=cfg_strength,
353
+ sway_sampling_coef=sway_sampling_coef,
354
+ )
355
+ generated = generated.to(torch.float32)
356
+ gen_mel_spec = generated[:, ref_audio_len:, :].permute(0, 2, 1).to(self.accelerator.device)
357
+ ref_mel_spec = batch["mel"][0].unsqueeze(0)
358
+ if self.vocoder_name == "vocos":
359
+ gen_audio = vocoder.decode(gen_mel_spec).cpu()
360
+ ref_audio = vocoder.decode(ref_mel_spec).cpu()
361
+ elif self.vocoder_name == "bigvgan":
362
+ gen_audio = vocoder(gen_mel_spec).squeeze(0).cpu()
363
+ ref_audio = vocoder(ref_mel_spec).squeeze(0).cpu()
364
+
365
+ torchaudio.save(f"{log_samples_path}/step_{global_step}_gen.wav", gen_audio, target_sample_rate)
366
+ torchaudio.save(f"{log_samples_path}/step_{global_step}_ref.wav", ref_audio, target_sample_rate)
367
+
368
+ if global_step % self.last_per_steps == 0:
369
+ self.save_checkpoint(global_step, last=True)
370
+
371
+ # Debugging
372
+
373
+ print(torch.cuda.memory_allocated() / 1e9, "GB allocated")
374
+ print(torch.cuda.memory_reserved() / 1e9, "GB reserved")
375
+ torch.cuda.empty_cache()
376
+ gc.collect()
377
+
378
+ self.save_checkpoint(global_step, last=True)
379
+
380
+ self.accelerator.end_training()