Spaces:
Running
on
Zero
Running
on
Zero
| from __future__ import annotations | |
| import contextlib | |
| import copy | |
| import os | |
| from pathlib import Path | |
| from random import random | |
| from typing import Callable | |
| import torch | |
| from torch import nn | |
| from f5_tts.model import DiT, UNetT | |
| from f5_tts.model.utils import (default, exists, lens_to_mask, list_str_to_idx, | |
| list_str_to_tensor, mask_from_frac_lengths, | |
| sample_consecutive_steps, sample_from_list) | |
| from guidance_model import Guidance | |
class UniModel(nn.Module):
    """Bundle a trainable student DiT with the guidance model used to train it.

    Three deep copies of the teacher ``model`` are created:

    * ``real_unet`` and ``fake_unet`` are handed to ``Guidance`` and stored
      as ``self.guidance_model`` (``real_unet`` has ``time_embed2`` removed).
    * ``self.feedforward_model`` is the student; it is trainable and also has
      ``time_embed2`` removed.

    Args:
        model: Teacher DiT. When ``checkpoint_path`` is given, its weights
            are loaded into this instance before it is copied.
        checkpoint_path: Directory containing ``.pt`` checkpoints. An empty
            string skips loading.
        second_time: When true, ``forward`` passes the sampled student time
            to the guidance model under the ``"second_time"`` key; otherwise
            ``None`` is passed.
        use_fp16, real_guidance_scale, fake_guidance_scale, gen_cls_loss,
        sway_coeff: Forwarded unchanged to ``Guidance``.
        vocab_char_map: Optional character-to-index map used to tokenize
            ``text`` when it is given as a list of strings.
        frac_lengths_mask: ``(low, high)`` range from which the masked
            fraction of each sequence is drawn uniformly.

    Raises:
        FileNotFoundError: ``checkpoint_path`` contains no ``.pt`` file.
    """

    def __init__(
        self,
        model: DiT,  # teacher model (dit model)
        checkpoint_path: str = "",
        second_time: bool = True,
        use_fp16: bool = True,
        real_guidance_scale: float = 2.0,
        fake_guidance_scale: float = 0.0,
        gen_cls_loss: bool = False,
        sway_coeff: float = -1.0,
        vocab_char_map: dict[str, int] | None = None,
        frac_lengths_mask: tuple[float, float] = (0.7, 1.0),
    ):
        super().__init__()

        if checkpoint_path != "":
            latest_checkpoint = self._latest_checkpoint_name(
                os.listdir(checkpoint_path)
            )
            if latest_checkpoint is None:
                raise FileNotFoundError(
                    f"no .pt checkpoint found in {checkpoint_path!r}"
                )
            checkpoint = torch.load(
                os.path.join(checkpoint_path, latest_checkpoint),
                weights_only=True,
                map_location="cpu",
            )

            self.scale = checkpoint.get("scale", 1.0)
            print(f"Loaded teacher model with scale: {self.scale}")

            # Only the DiT ("transformer.*") weights are loaded. strict=False
            # means missing or unexpected keys are ignored silently.
            model.load_state_dict(
                self._teacher_state_dict(checkpoint), strict=False
            )
        else:
            self.scale = 1.0

        real_unet = copy.deepcopy(model)
        real_unet.time_embed2 = None
        fake_unet = copy.deepcopy(model)

        # Guidance internally uses real_unet and fake_unet, both initialized
        # from the teacher.
        self.guidance_model = Guidance(
            real_unet=real_unet,
            fake_unet=fake_unet,
            use_fp16=use_fp16,
            real_guidance_scale=real_guidance_scale,
            fake_guidance_scale=fake_guidance_scale,
            gen_cls_loss=gen_cls_loss,
            sway_coeff=sway_coeff,
        )

        # The student starts from the teacher weights as well.
        self.feedforward_model = copy.deepcopy(model)
        self.feedforward_model.requires_grad_(True)
        self.feedforward_model.time_embed2 = None

        self.vocab_char_map = vocab_char_map
        self.frac_lengths_mask = frac_lengths_mask
        self.second_time = second_time  # fake_unet.time_embed2 is not None

    @staticmethod
    def _latest_checkpoint_name(filenames) -> str | None:
        """Choose which checkpoint file to load from a directory listing.

        ``model_last.pt`` wins when present. Otherwise the ``.pt`` file whose
        digits, concatenated, form the largest integer is chosen. Names
        without any digit rank below every numbered file. Returns ``None``
        when the listing holds no ``.pt`` file.
        """
        if "model_last.pt" in filenames:
            return "model_last.pt"

        candidates = [name for name in filenames if name.endswith(".pt")]
        if not candidates:
            return None

        def step_of(name: str) -> int:
            digits = "".join(filter(str.isdigit, name))
            return int(digits) if digits else -1

        return sorted(candidates, key=step_of)[-1]

    @staticmethod
    def _teacher_state_dict(checkpoint: dict) -> dict:
        """Extract the DiT weights from a loaded checkpoint dict.

        A checkpoint with a top-level ``"step"`` key is read from
        ``"model_state_dict"``. Otherwise the weights come from
        ``"ema_model_state_dict"`` with the ``"ema_model."`` marker removed
        and the ``"initted"`` / ``"step"`` entries dropped. Only keys
        starting with ``"transformer."`` are kept, and only that leading
        prefix is stripped.
        """
        if "step" in checkpoint:
            state = checkpoint["model_state_dict"]
        else:
            state = {
                k.replace("ema_model.", ""): v
                for k, v in checkpoint["ema_model_state_dict"].items()
                if k not in ["initted", "step"]
            }

        prefix = "transformer."
        return {
            k[len(prefix):]: v for k, v in state.items() if k.startswith(prefix)
        }

    def _student_step(self, target, reference, text, step_time, rand_span_mask):
        """Noise ``target`` to ``step_time`` and run the student once.

        Returns ``(noisy_input, cond, output)``. ``output`` is the student's
        predicted mel, with positions outside ``rand_span_mask`` overwritten
        by ``reference``.
        """
        noise = torch.randn_like(target)
        t = step_time.unsqueeze(-1).unsqueeze(-1)
        noisy_input = (1 - t) * noise + t * target
        # The masked span is zeroed in the condition; the rest is visible.
        cond = torch.where(
            rand_span_mask[..., None], torch.zeros_like(target), target
        )

        flow = self.feedforward_model(
            x=noisy_input,
            cond=cond,
            text=text,
            time=step_time,
            drop_audio_cond=False,
            drop_text=False,  # make sure no cfg is used
        )  # flow prediction

        # predicted mel spectrogram
        output = noisy_input + (1 - t) * flow
        output[~rand_span_mask] = reference[~rand_span_mask]
        return noisy_input, cond, output

    def forward(
        self,
        inp: float["b n d"],  # mel
        text: int["b nt"] | list[str],
        *,
        lens: int["b"] | None = None,
        student_steps: list[float] | torch.Tensor | None = None,
        update_generator: bool = False,
    ):
        """Run one student rollout and route it to the guidance model.

        The student is run twice. The first pass runs without gradients at
        the earlier of two consecutive student steps, starting from ``inp``.
        The second pass runs at the later step, starting from the first
        pass's output. Gradients are tracked in the second pass only when
        ``update_generator`` is true.

        Args:
            inp: Mel input; the first two dimensions are read as
                ``(batch, seq_len)``.
            text: Token tensor, or a list of strings. Strings are tokenized
                with ``self.vocab_char_map`` when it is set, otherwise with
                ``list_str_to_tensor``.
            lens: Per-item lengths. Defaults to ``seq_len`` for every item.
            student_steps: Student time steps, as a sequence or 1-D tensor.
                Defaults to ``[0, 0.25, 0.5, 0.75]``.
            update_generator: If true, call the guidance model with
                ``generator_turn=True``; otherwise call it with
                ``guidance_turn=True`` on the detached student output.

        Returns:
            ``(loss_dict, log_dict)`` as returned by ``self.guidance_model``.
            On the generator turn the log dict additionally carries
            ``"ground_truth"``, ``"generator_input"``, ``"generator_output"``,
            ``"generator_cond"`` and ``"time"``.
        """
        if student_steps is None:
            # Built per call so no list is shared between invocations.
            student_steps = [0, 0.25, 0.5, 0.75]

        batch, seq_len = inp.shape[:2]
        device = inp.device

        # handle text as string
        if isinstance(text, list):
            if exists(self.vocab_char_map):
                text = list_str_to_idx(text, self.vocab_char_map).to(device)
            else:
                text = list_str_to_tensor(text).to(device)
            assert text.shape[0] == batch

        # lens and mask
        if not exists(lens):
            lens = torch.full((batch,), seq_len, device=device)
        # useless here, as collate_fn will pad to max length in batch
        mask = lens_to_mask(lens, length=seq_len)

        # The per-item sample only serves as a template (shape, dtype,
        # device). Every item in the batch then shares one pair of
        # consecutive steps.
        time = sample_from_list(student_steps, batch).to(device)
        c_time, p_time = sample_consecutive_steps(student_steps)
        time = torch.ones_like(time) * c_time
        p_time = torch.ones_like(time) * p_time

        frac_lengths = (
            torch.zeros((batch,), device=device)
            .float()
            .uniform_(*self.frac_lengths_mask)
        )
        rand_span_mask = mask_from_frac_lengths(lens, frac_lengths)
        if exists(mask):
            rand_span_mask &= mask

        # First pass: use the student output at the previous step as the
        # input of the second pass.
        with torch.no_grad():
            _, _, first_output = self._student_step(
                inp, inp, text, p_time, rand_span_mask
            )

        # Second pass: forward diffusion from the first-pass output.
        x1 = first_output
        grad_ctx = contextlib.nullcontext() if update_generator else torch.no_grad()
        with grad_ctx:
            phi, cond, output = self._student_step(
                x1, inp, text, time, rand_span_mask
            )

        second_time = time if self.second_time else None
        real_data = {
            "inp": inp,
            "text": text,
            "rand_span_mask": rand_span_mask,
        }

        if update_generator:
            # A plain Python number has no ``.mean()``, so wrap it first.
            # Tensor schedules are compared exactly as before.
            last_step = student_steps[-1]
            if not torch.is_tensor(last_step):
                last_step = torch.tensor(float(last_step), device=device)

            generator_data_dict = {
                "inp": output,
                "text": text,
                "rand_span_mask": rand_span_mask,
                "second_time": second_time,
                "mse_loss": time.mean() == last_step.mean(),
                "real_data": real_data,
            }
            generator_loss_dict, generator_log_dict = self.guidance_model(
                generator_turn=True,
                guidance_turn=False,
                generator_data_dict=generator_data_dict,
                guidance_data_dict=None,
            )
            # NOTE(review): "ground_truth" holds the first-pass student
            # output (the target of the second noising step), not ``inp``.
            # Confirm this is intended.
            generator_log_dict["ground_truth"] = x1
            generator_log_dict["generator_input"] = phi
            generator_log_dict["generator_output"] = output
            generator_log_dict["generator_cond"] = cond
            generator_log_dict["time"] = time
            return generator_loss_dict, generator_log_dict

        guidance_data_dict = {
            "inp": output.detach(),
            "text": text,
            "rand_span_mask": rand_span_mask,
            "second_time": second_time,
            "real_data": real_data,
        }
        guidance_loss_dict, guidance_log_dict = self.guidance_model(
            generator_turn=False,
            guidance_turn=True,
            generator_data_dict=None,
            guidance_data_dict=guidance_data_dict,
        )
        return guidance_loss_dict, guidance_log_dict
if __name__ == "__main__":
    # Manual smoke test: one guidance turn and one generator turn on a saved
    # batch, followed by a backward pass through each summed loss. It needs
    # a CUDA device, the checkpoint directory below and a local "batch.pt".
    from f5_tts.model.utils import get_tokenizer

    tokenizer = "pinyin"  # 'pinyin', 'char', or 'custom'
    # If tokenizer = 'custom', set this to the tokenizer file to use
    # (should be vocab.txt). Otherwise the dataset name is used.
    tokenizer_path = None
    dataset_name = "Emilia_ZH_EN"
    if tokenizer != "custom":
        tokenizer_path = dataset_name
    vocab_char_map, vocab_size = get_tokenizer(tokenizer_path, tokenizer)

    dit = DiT(
        dim=1024,
        depth=22,
        heads=16,
        ff_mult=2,
        text_dim=512,
        conv_layers=4,
        text_num_embeds=vocab_size,
        mel_dim=100,
    )
    model = UniModel(
        dit,
        checkpoint_path="/data4/F5TTS/ckpts/F5TTS_Base_norm_flow_8GPU_vocos_pinyin_Emilia_ZH_EN",
        gen_cls_loss=True,
        vocab_char_map=vocab_char_map,
        frac_lengths_mask=(0.7, 1.0),
    ).cuda()

    # "batch.pt" is one batch taken from the training dataloader and stored
    # with torch.save. It must provide "mel", "text" and "mel_lengths".
    batch = torch.load("batch.pt")
    inp = batch["mel"].permute(0, 2, 1).cuda()
    text = batch["text"]
    lens = batch["mel_lengths"].cuda()

    # Evenly spaced student steps in [0, 1) with the end point dropped.
    num_student_step = 4
    student_steps = torch.linspace(0.0, 1.0, num_student_step + 1)[:-1]

    with torch.autocast(device_type="cuda", dtype=torch.float16):
        guidance_loss_dict, guidance_log_dict = model(
            inp,
            text,
            lens=lens,
            update_generator=False,
            student_steps=student_steps,
        )
        generator_loss_dict, generator_log_dict = model(
            inp,
            text,
            lens=lens,
            update_generator=True,
            student_steps=student_steps,
        )

    print(guidance_loss_dict)
    print(generator_loss_dict)

    guidance_loss = sum(
        guidance_loss_dict[key]
        for key in ("loss_fake_mean", "guidance_cls_loss")
    )
    generator_loss = sum(
        generator_loss_dict[key]
        for key in ("loss_dm", "loss_ctc", "loss_sim", "gen_cls_loss", "loss_mse")
    )

    guidance_loss.backward()
    generator_loss.backward()