# SoulX-Singer: soulxsinger/models/modules/flow_matching.py
# Uploaded by Xinsheng-Wang via huggingface_hub (commit c7f3ffb, verified)
# https://github.com/open-mmlab/Amphion/blob/main/models/svc/flow_matching_transformer/fmt_model.py
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import torch
import numpy as np
import torch.nn as nn
import math
from .llama import DiffLlama
import torch.nn.functional as F
class FlowMatchingTransformer(nn.Module):
    """Conditional flow-matching generator over mel-spectrograms.

    Training builds the optimal-transport path ``xt = (1 - (1 - sigma) * t) * z
    + t * x`` between Gaussian noise ``z`` and target mels ``x`` and asks a
    DiffLlama backbone to predict the flow field.  Inference integrates the
    predicted flow with a midpoint Euler scheme, using classifier-free
    guidance (CFG) with std-rescaling.
    """

    def __init__(
        self,
        mel_dim=100,
        hidden_size=1024,
        num_layers=12,
        num_heads=16,
        cfg_drop_prob=0.2,
        use_embedding=True,
        cond_codebook_size=1024,
        cond_scale_factor=1,
        sigma=1e-5,
        time_scheduler="linear",
        cfg=None,
    ):
        """
        Args:
            mel_dim: number of mel channels per frame.
            hidden_size: transformer width.
            num_layers: DiffLlama depth.
            num_heads: attention heads.
            cfg_drop_prob: probability of dropping prompt + condition during
                training (classifier-free guidance).
            use_embedding: True -> cond codes are token ids (nn.Embedding);
                False -> cond is a dense vector (nn.Linear).
            cond_codebook_size: code vocabulary size (or input dim when dense).
            cond_scale_factor: power-of-two upsampling factor from the cond
                frame rate to the mel frame rate (1 = no resampling).
            sigma: minimum noise level of the flow-matching path.
            time_scheduler: "cos" warps sampled timesteps with a cosine
                schedule; anything else leaves them uniform.
            cfg: optional attribute-style config whose fields override the
                keyword defaults above.
        """
        super().__init__()
        self.cfg = cfg
        if cfg is not None:
            mel_dim = getattr(cfg, "mel_dim", mel_dim)
            hidden_size = getattr(cfg, "hidden_size", hidden_size)
            num_layers = getattr(cfg, "num_layers", num_layers)
            num_heads = getattr(cfg, "num_heads", num_heads)
            cfg_drop_prob = getattr(cfg, "cfg_drop_prob", cfg_drop_prob)
            cond_codebook_size = getattr(cfg, "cond_codebook_size", cond_codebook_size)
            time_scheduler = getattr(cfg, "time_scheduler", time_scheduler)
            sigma = getattr(cfg, "sigma", sigma)
            cond_scale_factor = getattr(cfg, "cond_scale_factor", cond_scale_factor)

        self.mel_dim = mel_dim
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.cfg_drop_prob = cfg_drop_prob
        self.cond_codebook_size = cond_codebook_size
        self.time_scheduler = time_scheduler
        self.sigma = sigma
        self.cond_scale_factor = cond_scale_factor

        if use_embedding:
            self.cond_emb = nn.Embedding(cond_codebook_size, self.hidden_size)
        else:
            self.cond_emb = nn.Linear(cond_codebook_size, self.hidden_size)

        if cond_scale_factor != 1:
            self.do_resampling = True
            # Each ConvTranspose1d doubles the frame rate, so the overall
            # factor must be a power of two.
            assert np.log2(cond_scale_factor).is_integer()
            up_layers = []
            for _ in range(int(np.log2(cond_scale_factor))):
                up_layers.extend(
                    [
                        nn.ConvTranspose1d(
                            hidden_size, hidden_size, kernel_size=4, stride=2, padding=1
                        ),
                        nn.GELU(),
                    ]
                )
            self.resampling_layers = nn.Sequential(*up_layers)
        else:
            self.do_resampling = False

        ### REPA: align intermediate hidden states with Wav2Vec2Bert features. ###
        # BUGFIX: the original evaluated `"repa" in cfg` unconditionally,
        # which raises TypeError when cfg is None (the default).
        self.use_repa = cfg is not None and "repa" in cfg
        self.repa_layer_index = None
        if self.use_repa:
            self.repa_layer_index = cfg.repa.layer_index
            self.repa_mlp_layer = nn.Sequential(
                nn.Linear(hidden_size, hidden_size * 4),
                nn.SiLU(),
                nn.Linear(hidden_size * 4, cfg.repa.output_dim),
            )

        ### CTC: auxiliary ASR loss head. ###
        self.use_ctc = cfg is not None and "ctc" in cfg
        self.ctc_layer_index = None
        if self.use_ctc:
            self.ctc_layer_index = cfg.ctc.layer_index
            self.ctc_mlp_layer = nn.Sequential(
                nn.Linear(hidden_size, hidden_size * 4),
                nn.SiLU(),
                nn.Linear(hidden_size * 4, cfg.ctc.output_dim),
            )

        # Reset before building the backbone so DiffLlama keeps its own init.
        self.reset_parameters()

        self.diff_estimator = DiffLlama(
            mel_dim=mel_dim,
            hidden_size=hidden_size,
            num_heads=num_heads,
            num_layers=num_layers,
        )
        # NOTE: the original re-assigned self.sigma here a second time;
        # the duplicate assignment was removed (value is identical).

    @torch.no_grad()
    def forward_diffusion(self, x, t, is_prompt=None):
        """Build the noisy flow-matching input for training.

        Args:
            x: (B, T, mel_dim) target mels.
            t: (B,) timesteps in (0, 1].
            is_prompt: optional (B, T) 0/1 mask marking prompt frames; when
                None a random prompt prefix is sampled per batch element.

        Returns:
            Tuple of (xt, z, t, prompt_len, mask) where mask is (B, T, 1)
            with 1 on frames that were noised (non-prompt) and enter the loss.
        """
        new_t = t
        t = t.unsqueeze(-1).unsqueeze(-1)
        z = torch.randn(
            x.shape, dtype=x.dtype, device=x.device, requires_grad=False
        )  # (B, T, mel_dim)

        # With prob cfg_drop_prob drop the prompt entirely (CFG training).
        if torch.rand(1) <= self.cfg_drop_prob:
            prompt_len = torch.zeros(x.shape[0]).to(x)
            is_prompt = torch.zeros_like(x[:, :, 0])
        else:
            if is_prompt is None:
                low = min(x.shape[1] // 4, 5)
                # BUGFIX: guard high > low so torch.randint never receives an
                # empty range on very short sequences (high is exclusive).
                high = max(low + 1, int(x.shape[1] * 0.4))
                prompt_len = torch.randint(low, high, (x.shape[0],)).to(
                    x.device
                )  # (B,)
                # Mark the first prompt_len[b] frames of each row as prompt.
                is_prompt = torch.zeros_like(x[:, :, 0])  # (B, T)
                col_indices = (
                    torch.arange(is_prompt.shape[1])
                    .repeat(is_prompt.shape[0], 1)
                    .to(prompt_len)
                )  # (B, T)
                is_prompt[col_indices < prompt_len.unsqueeze(1)] = 1  # 1 if prompt
            else:
                prompt_len = is_prompt.sum(dim=1)  # (B,)

        mask = torch.ones_like(x[:, :, 0])  # 1 = noised, 0 = kept clean (prompt)
        mask[is_prompt.bool()] = 0
        mask = mask[:, :, None]

        # flow matching: xt = (1 - (1 - sigma) * t) * x0 + t * x; x0 ~ N(0, 1)
        # flow gt: x - (1 - sigma) * x0 = x - (1 - sigma) * noise
        xt = ((1 - (1 - self.sigma) * t) * z + t * x) * mask + x * (1 - mask)
        return xt, z, new_t, prompt_len, mask

    def loss_t(self, x, x_mask, t, cond=None, is_prompt=None):
        """Run one noisy forward pass and gather everything the losses need.

        Returns a dict whose "output" entry is
        (noise, x, flow_pred, final_mask, prompt_len); "repa"/"ctc" entries
        are added when those auxiliary heads are enabled.
        """
        xt, z, new_t, prompt_len, mask = self.forward_diffusion(x, t, is_prompt)
        noise = z

        # If the prompt was dropped (prompt_len == 0), also drop cond so the
        # CFG "unconditional" branch sees no condition at all.
        if cond is not None:
            keep = (prompt_len > 0).to(cond.dtype).to(cond.device)
            cond = cond * keep.unsqueeze(-1).unsqueeze(-1)

        dit_output = self.diff_estimator(xt, new_t, cond, x_mask, return_dict=True)
        flow_pred = dit_output["output"]  # (B, T, mel_dim)

        # Final mask for the loss: noised frames that are not padding.
        final_mask = mask * x_mask[..., None]  # (B, T, 1)

        results = {"output": (noise, x, flow_pred, final_mask, prompt_len)}
        if self.use_repa:
            repa_hidden_states = dit_output["hidden_states"][
                self.repa_layer_index
            ]  # (B, T, hidden_size)
            results["repa"] = self.repa_mlp_layer(repa_hidden_states)  # (B, T, repa_dim)
        if self.use_ctc:
            ctc_hidden_states = dit_output["hidden_states"][
                self.ctc_layer_index
            ]  # (B, T, hidden_size)
            results["ctc"] = self.ctc_mlp_layer(ctc_hidden_states)  # (B, T, ctc_dim)
        return results

    def compute_loss(self, x, x_mask, cond=None, is_prompt=None):
        """Sample a timestep per batch element and delegate to loss_t.

        Args:
            x: (B, T, mel_dim) target mels.
            x_mask: (B, T), 0 marks padding.
        """
        t = torch.rand(x.shape[0], device=x.device, requires_grad=False)
        t = torch.clamp(t, 1e-5, 1.0)
        # From CosyVoice: generation is harder near t = 0, so a cosine
        # scheduler biases sampled timesteps toward the start.
        if self.time_scheduler == "cos":
            t = 1 - torch.cos(t * math.pi * 0.5)
        return self.loss_t(x, x_mask, t, cond, is_prompt)

    def reset_parameters(self):
        """Re-initialize attention/conv/linear/embedding weights (std=0.02)."""

        def _reset_parameters(m):
            if isinstance(m, nn.MultiheadAttention):
                if m._qkv_same_embed_dim:
                    nn.init.normal_(m.in_proj_weight, std=0.02)
                else:
                    nn.init.normal_(m.q_proj_weight, std=0.02)
                    nn.init.normal_(m.k_proj_weight, std=0.02)
                    nn.init.normal_(m.v_proj_weight, std=0.02)
                if m.in_proj_bias is not None:
                    nn.init.constant_(m.in_proj_bias, 0.0)
                    nn.init.constant_(m.out_proj.bias, 0.0)
                if m.bias_k is not None:
                    nn.init.xavier_normal_(m.bias_k)
                if m.bias_v is not None:
                    nn.init.xavier_normal_(m.bias_v)
            elif isinstance(
                m, (nn.Conv1d, nn.ConvTranspose1d, nn.Conv2d, nn.ConvTranspose2d)
            ):
                m.weight.data.normal_(0.0, 0.02)
            elif isinstance(m, nn.Linear):
                m.weight.data.normal_(mean=0.0, std=0.02)
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.Embedding):
                m.weight.data.normal_(mean=0.0, std=0.02)
                if m.padding_idx is not None:
                    m.weight.data[m.padding_idx].zero_()

        self.apply(_reset_parameters)

    @torch.no_grad()
    def reverse_diffusion(
        self,
        cond,
        prompt,
        x_mask=None,
        prompt_mask=None,
        n_timesteps=10,
        cfg=1.0,
        rescale_cfg=0.75,
    ):
        """Generate mels by midpoint-Euler integration of the predicted flow.

        Args:
            cond: (B, prompt_len + target_len, hidden_size) condition embeddings.
            prompt: (B, prompt_len, mel_dim) acoustic prompt prepended to xt.
            x_mask / prompt_mask: optional (B, T) masks; default all-ones.
            n_timesteps: number of Euler steps.
            cfg: classifier-free guidance scale; <= 0 disables guidance.
            rescale_cfg: blend factor for std-rescaled guidance output.

        Returns:
            (B, target_len, mel_dim) generated mels.
        """
        h = 1.0 / n_timesteps
        prompt_len = prompt.shape[1]
        target_len = cond.shape[1] - prompt_len
        if x_mask is None:
            x_mask = torch.ones(cond.shape[0], target_len).to(cond.device)  # (B, T)
        if prompt_mask is None:
            prompt_mask = torch.ones(cond.shape[0], prompt_len).to(
                cond.device
            )  # (B, prompt_len)
        xt_mask = torch.cat([prompt_mask, x_mask], dim=1)

        z = torch.randn(
            (cond.shape[0], target_len, self.mel_dim),
            dtype=cond.dtype,
            device=cond.device,
            requires_grad=False,
        )
        xt = z
        # t from 0 to 1: x0 = z ~ N(0, 1); evaluate at interval midpoints.
        for i in range(n_timesteps):
            xt_input = torch.cat([prompt, xt], dim=1)
            t = (0 + (i + 0.5) * h) * torch.ones(
                z.shape[0], dtype=z.dtype, device=z.device
            )
            flow_pred = self.diff_estimator(xt_input, t, cond, xt_mask)
            flow_pred = flow_pred[:, prompt_len:, :]
            # CFG with std-rescaling (keeps the guided output's scale close
            # to the conditional prediction's).
            if cfg > 0:
                uncond_flow_pred = self.diff_estimator(
                    xt, t, torch.zeros_like(cond)[:, : xt.shape[1], :], x_mask
                )
                pos_flow_pred_std = flow_pred.std()
                flow_pred_cfg = flow_pred + cfg * (flow_pred - uncond_flow_pred)
                rescale_flow_pred = (
                    flow_pred_cfg * pos_flow_pred_std / flow_pred_cfg.std()
                )
                flow_pred = (
                    rescale_cfg * rescale_flow_pred + (1 - rescale_cfg) * flow_pred_cfg
                )
            dxt = flow_pred * h
            xt = xt + dxt
        return xt

    @torch.no_grad()
    def reverse_diffusion_v2(
        self,
        cond,
        prompt,
        x_mask=None,
        prompt_mask=None,
        n_timesteps=10,
        cfg=1.0,
        rescale_cfg=0.75,
    ):
        """Variant of reverse_diffusion with the prompt on BOTH sides of xt.

        cond must cover prompt + target + prompt frames; the generated region
        is the middle slice.  Otherwise identical to reverse_diffusion.
        """
        h = 1.0 / n_timesteps
        prompt_len = prompt.shape[1]
        target_len = cond.shape[1] - prompt_len * 2
        if x_mask is None:
            x_mask = torch.ones(cond.shape[0], target_len).to(cond.device)  # (B, T)
        if prompt_mask is None:
            prompt_mask = torch.ones(cond.shape[0], prompt_len).to(
                cond.device
            )  # (B, prompt_len)
        xt_mask = torch.cat([prompt_mask, x_mask, prompt_mask], dim=1)

        z = torch.randn(
            (cond.shape[0], target_len, self.mel_dim),
            dtype=cond.dtype,
            device=cond.device,
            requires_grad=False,
        )
        xt = z
        # t from 0 to 1: x0 = z ~ N(0, 1)
        for i in range(n_timesteps):
            xt_input = torch.cat([prompt, xt, prompt], dim=1)
            t = (0 + (i + 0.5) * h) * torch.ones(
                z.shape[0], dtype=z.dtype, device=z.device
            )
            flow_pred = self.diff_estimator(xt_input, t, cond, xt_mask)
            flow_pred = flow_pred[:, prompt_len:-prompt_len, :]
            # CFG with std-rescaling, as in reverse_diffusion.
            if cfg > 0:
                uncond_flow_pred = self.diff_estimator(
                    xt, t, torch.zeros_like(cond)[:, : xt.shape[1], :], x_mask
                )
                pos_flow_pred_std = flow_pred.std()
                flow_pred_cfg = flow_pred + cfg * (flow_pred - uncond_flow_pred)
                rescale_flow_pred = (
                    flow_pred_cfg * pos_flow_pred_std / flow_pred_cfg.std()
                )
                flow_pred = (
                    rescale_cfg * rescale_flow_pred + (1 - rescale_cfg) * flow_pred_cfg
                )
            dxt = flow_pred * h
            xt = xt + dxt
        return xt

    def forward(self, x, x_mask, cond_code, is_prompt=None):
        """Embed / resample the condition, align lengths, and compute losses.

        Args:
            x: (B, T, mel_dim) target mels.
            x_mask: (B, T), 0 marks padding.
            cond_code: (B, T') content codes — note T' may be at a lower
                frame rate than the mels (see cond_scale_factor).
        """
        T = x.shape[1]
        cond = self.cond_emb(cond_code)  # (B, T', hidden_size)
        if self.do_resampling:
            # Upsample the condition to the mel frame rate.
            cond = self.resampling_layers(cond.transpose(1, 2)).transpose(1, 2)
        # Align lengths exactly: trim, or pad by repeating the last frame.
        if cond.shape[1] >= T:
            cond = cond[:, :T, :]
        else:
            padding_frames = T - cond.shape[1]
            last_frame = cond[:, -1:, :]
            cond = torch.cat([cond, last_frame.repeat(1, padding_frames, 1)], dim=1)
        return self.compute_loss(x, x_mask, cond, is_prompt)
if __name__ == "__main__":
    # Smoke test: one training forward pass (L1 flow loss) and one sampling
    # run with dense (non-embedding) condition vectors.
    model_cfg = {
        "mel_dim": 128,
        "hidden_size": 256,
        "num_layers": 8,
        "num_heads": 8,
        "cfg_drop_prob": 0.2,
        "use_embedding": False,
        "cond_codebook_size": 256,
        "cond_scale_factor": 1,
        "sigma": 1e-5,
        "time_scheduler": "cos",
    }
    # BUGFIX: the original hard-coded "cuda", which crashes on CPU-only hosts.
    device = "cuda" if torch.cuda.is_available() else "cpu"

    x = torch.randn(2, 100, 128).to(device)
    x_mask = torch.ones(2, 100).to(device)
    # use_embedding=False, so cond_code is a dense (B, T, cond_codebook_size)
    # tensor rather than integer ids.
    cond_code = torch.randn(2, 100, 256).to(device)
    model = FlowMatchingTransformer(cfg=model_cfg, **model_cfg).to(device)

    outputs = model(x, x_mask, cond_code)
    print(outputs)

    noise, x, flow_pred, final_mask, prompt_len = outputs["output"]
    final_mask = final_mask.squeeze(-1)
    # flow gt: x - (1 - sigma) * noise
    flow_gt = x - (1 - 1e-5) * noise
    # Masked mean-over-channels L1, normalized by the number of valid frames.
    diff_loss = F.l1_loss(
        flow_pred, flow_gt, reduction="none"
    ).float() * final_mask.unsqueeze(-1)
    diff_loss = torch.mean(diff_loss, dim=2).sum() / final_mask.sum()
    print("diff_loss:", diff_loss.item())

    # Sampling: cond covers prompt (50) + target (100) frames.
    diffusion_cond = torch.randn(2, 150, 256).to(device)
    diffusion_cond_emb = model.cond_emb(diffusion_cond)
    diffusion_prompt = torch.randn(2, 50, 128).to(device)
    n_timesteps = 32
    generated = model.reverse_diffusion(
        diffusion_cond_emb,
        diffusion_prompt,
        n_timesteps=n_timesteps,
    )
    print("generated:", generated.shape)