# Source: fff / X-Codec-2.0 / lightning_module.py
# Uploaded by novateur: "Add files using upload-large-folder tool"
# Commit 1cba6e2 (verified)
import os
import random
import hydra
import numpy as np
import librosa
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import pytorch_lightning as pl
from vq import CodecEncoder, CodecDecoderVocos
from module import HiFiGANMultiPeriodDiscriminator, SpecDiscriminator
from criterions import GANLoss, MultiResolutionMelSpectrogramLoss, MultiResolutionSTFTLoss
from common.schedulers import WarmupLR
from transformers import AutoModel
from vq.module import SemanticDecoder,SemanticEncoder
from transformers import AutoFeatureExtractor, Wav2Vec2BertModel
import sys
# Optional speaker-verification model (depends on fairseq). When it cannot be
# loaded, the validation stage skips the speaker-similarity ("sim") metric.
try:
    sys.path.append('/apdcephfs/private_jishengpeng2/work/shengpeng/research/X-Codec-2.0/UniSpeech/downstreams/speaker_verification')
    from verification import init_model
    model_spk = init_model('wavlm_large', '/apdcephfs/private_jishengpeng2/work/shengpeng/research/X-Codec-2.0/ckpt/wavlm_large_finetune.pth')
except Exception as e:
    # Best-effort load: report the failure and fall back to "no speaker model".
    print(f"[警告] 说话人验证模型加载失败,验证阶段将跳过 sim 指标: {e}")
    model_spk = None
    _SPK_MODEL_AVAILABLE = False
else:
    _SPK_MODEL_AVAILABLE = True
class CodecLightningModule(pl.LightningModule):
    """Lightning module for adversarial training of the speech codec.

    Holds the codec encoder and decoder, two discriminators, a frozen
    semantic model and the loss criteria. Optimization is manual
    (``automatic_optimization`` is disabled in ``__init__``): the
    discriminators and the generator are stepped in turn inside
    ``training_step``.
    """
def __init__(self, cfg):
    """Build the networks and loss criteria and enable manual optimization.

    Args:
        cfg: Configuration object. It is stored on ``self.cfg`` and read by
            ``construct_model`` and ``construct_criteria``.
    """
    super().__init__()
    # Both optimizers are stepped by hand in ``training_step``.
    self.automatic_optimization = False
    self.cfg = cfg
    # Original working directory as reported by Hydra.
    self.ocwd = hydra.utils.get_original_cwd()
    self.construct_model()
    self.construct_criteria()
    self.save_hyperparameters()
def construct_model(self):
    """Instantiate every sub-network used by this module.

    Reads ``self.cfg.model`` and creates, in this order: the codec encoder
    (``CodecEnc``), the ``CodecDecoderVocos`` decoder (``generator``), the
    multi-period and spectral discriminators, the optional
    speaker-verification model, the linear projections around the latent,
    the semantic encoder/decoder and the frozen w2v-BERT 2.0 semantic model.

    The creation order is kept stable because it determines the order in
    which parameters are randomly initialised.
    """
    # Codec encoder.
    enccfg = self.cfg.model.codec_encoder
    self.CodecEnc = CodecEncoder(
        ngf=enccfg.ngf,
        up_ratios=enccfg.up_ratios,
        dilations=enccfg.dilations,
        hidden_dim=enccfg.hidden_dim,
        depth=enccfg.depth,
        heads=enccfg.heads,
        pos_meb_dim=enccfg.pos_meb_dim,
    )

    # Codec decoder.
    deccfg = self.cfg.model.codec_decoder
    self.generator = CodecDecoderVocos(
        hidden_dim=deccfg.hidden_dim,
        depth=deccfg.depth,
        heads=deccfg.heads,
        pos_meb_dim=deccfg.pos_meb_dim,
        hop_length=320,
        vq_num_quantizers=deccfg.vq_num_quantizers,      # number of VQ quantizers
        vq_dim=deccfg.vq_dim,                            # VQ dimension
        vq_commit_weight=deccfg.vq_commit_weight,        # VQ commitment weight
        vq_weight_init=deccfg.vq_weight_init,            # VQ weight initialisation
        vq_full_commit_loss=deccfg.vq_full_commit_loss,  # use the full commitment loss
        codebook_size=deccfg.codebook_size,              # codebook size
        codebook_dim=deccfg.codebook_dim,                # codebook dimension
    )

    # Multi-period discriminator.
    mpdcfg = self.cfg.model.mpd
    self.discriminator = HiFiGANMultiPeriodDiscriminator(
        periods=mpdcfg.periods,
        max_downsample_channels=mpdcfg.max_downsample_channels,
        channels=mpdcfg.channels,
        channel_increasing_factor=mpdcfg.channel_increasing_factor,
    )

    # Spectral discriminator.
    mstftcfg = self.cfg.model.mstft
    self.spec_discriminator = SpecDiscriminator(
        stft_params=mstftcfg.stft_params,
        in_channels=mstftcfg.in_channels,
        out_channels=mstftcfg.out_channels,
        kernel_sizes=mstftcfg.kernel_sizes,
        channels=mstftcfg.channels,
        max_downsample_channels=mstftcfg.max_downsample_channels,
        downsample_scales=mstftcfg.downsample_scales,
        use_weight_norm=mstftcfg.use_weight_norm,
    )

    if _SPK_MODEL_AVAILABLE:
        # The speaker model only scores validation samples and is never
        # handed to an optimizer, so freeze it: its parameters must not
        # look trainable to autograd or to distributed wrappers.
        self.model_spk = model_spk.eval()
        self.model_spk.requires_grad_(False)

    # Channel width used throughout the semantic branch. ``fc_prior``
    # consumes the semantic features concatenated with the codec encoder
    # output, hence ``feat_dim + feat_dim`` input features.
    # NOTE(review): presumably this must equal both the semantic model's
    # hidden size and the CodecEnc output width -- confirm.
    feat_dim = 1024
    self.fc_prior = nn.Linear(feat_dim + feat_dim, deccfg.vq_dim)
    self.fc_post_a = nn.Linear(deccfg.vq_dim, deccfg.hidden_dim)
    self.fc_post_s = nn.Linear(deccfg.vq_dim, feat_dim)

    self.SemanticDecoder_module = SemanticDecoder(feat_dim, feat_dim, feat_dim)
    self.SemanticEncoder_module = SemanticEncoder(feat_dim, feat_dim, feat_dim)

    # Frozen semantic model; ``forward`` reads one of its hidden states.
    self.semantic_model = Wav2Vec2BertModel.from_pretrained("/apdcephfs/private_jishengpeng2/work/shengpeng/research/X-Codec-2.0/ckpt/w2v-bert-2.0", output_hidden_states=True)
    self.semantic_model.eval()
    self.semantic_model.requires_grad_(False)
def construct_criteria(self):
    """Create the loss modules selected by ``self.cfg.train``.

    The mel, STFT and feature-matching losses are optional and controlled by
    the ``use_mel_loss`` / ``use_stft_loss`` / ``use_feat_match_loss`` flags.
    The GAN, L1 and L2 losses are always registered. The resulting
    ``nn.ModuleDict`` is stored on ``self.criteria`` and printed.
    """
    train_cfg = self.cfg.train
    criteria = nn.ModuleDict()
    self.criteria = criteria

    if train_cfg.use_mel_loss:
        criteria['mel_loss'] = MultiResolutionMelSpectrogramLoss(sample_rate=self.cfg.dataset.sr)

    if train_cfg.use_stft_loss:
        stft_params = train_cfg.stft_loss_params
        criteria['stft_loss'] = MultiResolutionSTFTLoss(
            fft_sizes=stft_params.fft_sizes,
            hop_sizes=stft_params.hop_sizes,
            win_sizes=stft_params.win_lengths,
        )

    if train_cfg.use_feat_match_loss:
        criteria['fm_loss'] = nn.L1Loss()

    # Always-on criteria.
    criteria['gan_loss'] = GANLoss()
    criteria['l1_loss'] = nn.L1Loss()
    criteria['l2_loss'] = nn.MSELoss()

    print(self.criteria)
def forward(self, batch):
    """Encode a batch, reconstruct the waveform and score the semantic branch.

    Quantization is currently bypassed: the projected latent is fed straight
    to the decoder (``vq=False``).

    Args:
        batch: Mapping with ``'wav'`` (waveform tensor; a channel axis is
            inserted at dim 1) and ``'feats'`` (input features of the
            semantic model, indexed as ``feats[:, 0, :, :]``).

    Returns:
        dict with
            ``'gt_wav'``: the input waveform with the inserted channel axis,
            ``'gen_wav'``: the first output of the generator,
            ``'semantic_recon_loss'``: MSE between the reconstructed and
            the target semantic features.
    """
    wav = batch['wav']
    feats = batch['feats']

    # Acoustic branch.
    acoustic_emb = self.CodecEnc(wav.unsqueeze(1)).transpose(1, 2)

    # Semantic branch. The semantic model is frozen, but Lightning's
    # ``train()`` call switches every submodule back to training mode, which
    # would enable its stochastic layers and make the targets noisy. Force
    # eval mode before extracting them.
    if self.semantic_model.training:
        self.semantic_model.eval()
    with torch.no_grad():
        semantic_target = self.semantic_model(feats[:, 0, :, :]).hidden_states[16]
    semantic_target = semantic_target.detach().transpose(1, 2)
    semantic_emb = self.SemanticEncoder_module(semantic_target)

    # Concatenate the semantic embedding and the encoder output along dim 1
    # and project to the VQ dimension.
    # NOTE(review): original comment gives the result shape as [8, 1024, 300]
    # for one particular configuration -- confirm.
    vq_emb = torch.cat([semantic_emb, acoustic_emb], dim=1)
    vq_emb = self.fc_prior(vq_emb.transpose(1, 2)).transpose(1, 2)

    # VQ disabled: the "post-quantization" embedding is the latent itself.
    vq_post_emb = vq_emb

    semantic_recon = self.fc_post_s(vq_post_emb.transpose(1, 2)).transpose(1, 2)
    semantic_recon = self.SemanticDecoder_module(semantic_recon)

    y_, _ = self.generator(
        self.fc_post_a(vq_post_emb.transpose(1, 2)),
        vq=False,
    )
    y = wav.unsqueeze(1)

    return {
        'gt_wav': y,
        'gen_wav': y_,
        'semantic_recon_loss': F.mse_loss(semantic_recon, semantic_target),
    }
@torch.inference_mode()
def inference(self, wav, feats=None):
    """Reconstruct waveforms through the same VQ-bypassed path as ``forward``.

    The previous body fed the raw encoder output to the generator (skipping
    the semantic branch, ``fc_prior`` and ``fc_post_a``) and called
    ``.squeeze`` on the tuple the generator returns, so it could not run.

    Args:
        wav: Waveform tensor; a channel axis is inserted at dim 1 before
            encoding.
        feats: Input features of the semantic model, indexed as
            ``feats[:, 0, :, :]`` exactly like in ``forward``. Required,
            because the latent is built from both branches.

    Returns:
        The first output of the generator with dim 1 squeezed out.

    Raises:
        ValueError: If ``feats`` is not provided.
    """
    if feats is None:
        raise ValueError(
            "inference() requires `feats`: the latent is built from the codec "
            "encoder output concatenated with the semantic features"
        )

    acoustic_emb = self.CodecEnc(wav.unsqueeze(1)).transpose(1, 2)

    # Keep the frozen semantic model deterministic (see ``forward``).
    if self.semantic_model.training:
        self.semantic_model.eval()
    semantic_emb = self.semantic_model(feats[:, 0, :, :]).hidden_states[16]
    semantic_emb = self.SemanticEncoder_module(semantic_emb.transpose(1, 2))

    joint = torch.cat([semantic_emb, acoustic_emb], dim=1)
    # ``forward`` transposes after ``fc_prior`` and again before
    # ``fc_post_a``; the two transposes cancel, so chain the layers directly.
    latent = self.fc_prior(joint.transpose(1, 2))
    y_, _ = self.generator(self.fc_post_a(latent), vq=False)
    return y_.squeeze(1)
def compute_disc_loss(self, batch, output):
    """Compute the discriminator loss for real versus generated audio.

    Args:
        batch: Unused here; kept for a signature parallel to
            ``compute_gen_loss``.
        output: Result of ``forward``; ``'gt_wav'`` and ``'gen_wav'`` are read.

    Returns:
        dict with ``'real_loss'`` and ``'fake_loss'`` (sums over every
        discriminator output) and ``'disc_loss'`` (their sum scaled by
        ``cfg.train.lambdas.lambda_disc``).
    """
    real_wav = output['gt_wav']
    # Detached so this loss cannot send gradients into the generator.
    fake_wav = output['gen_wav'].detach()

    real_terms, fake_terms = [], []

    def accumulate(real_outs, fake_outs):
        # Only the last entry of each sub-discriminator output is scored.
        for idx, real_out in enumerate(real_outs):
            real_term, fake_term = self.criteria['gan_loss'].disc_loss(
                real_out[-1], fake_outs[idx][-1]
            )
            real_terms.append(real_term)
            fake_terms.append(fake_term)

    accumulate(self.discriminator(real_wav), self.discriminator(fake_wav))
    if hasattr(self, 'spec_discriminator'):
        accumulate(self.spec_discriminator(real_wav), self.spec_discriminator(fake_wav))

    real_loss = sum(real_terms)
    fake_loss = sum(fake_terms)
    return {
        'real_loss': real_loss,
        'fake_loss': fake_loss,
        'disc_loss': self.cfg.train.lambdas.lambda_disc * (real_loss + fake_loss),
    }
def compute_gen_loss(self, batch, output):
    """Compute the generator objective for one batch.

    The total is the weighted sum of: the optional mel loss, the adversarial
    loss over both discriminators, the optional feature-matching losses and
    the semantic reconstruction loss. Weights come from
    ``cfg.train.lambdas``. The VQ loss is currently disabled.

    The discriminators are frozen while the loss is built and are always
    unfrozen again, even if a sub-loss raises, so a failed batch cannot
    leave them permanently untrainable.

    Args:
        batch: Unused here; kept for a signature parallel to
            ``compute_disc_loss``.
        output: Result of ``forward``; ``'gt_wav'``, ``'gen_wav'`` and
            ``'semantic_recon_loss'`` are read.

    Returns:
        dict of the individual loss terms plus ``'gen_loss'`` (the total).
    """
    y, y_ = output['gt_wav'], output['gen_wav']
    semantic_recon_loss = output['semantic_recon_loss']
    cfg = self.cfg.train
    lambdas = cfg.lambdas
    has_spec_disc = hasattr(self, 'spec_discriminator')

    gen_loss = 0.0
    output_dict = {}

    self.set_discriminator_gradients(False)
    try:
        # Mel spectrogram loss.
        if cfg.use_mel_loss:
            mel_loss = self.criteria['mel_loss'](y_.squeeze(1), y.squeeze(1))
            gen_loss += mel_loss * lambdas.lambda_mel_loss
            output_dict['mel_loss'] = mel_loss

        # Adversarial loss: only the last entry of each sub-discriminator
        # output is scored.
        gan_criterion = self.criteria['gan_loss']
        p_ = self.discriminator(y_)
        adv_loss_list = [gan_criterion.gen_loss(fake_out[-1]) for fake_out in p_]
        if has_spec_disc:
            sd_p_ = self.spec_discriminator(y_)
            adv_loss_list.extend(gan_criterion.gen_loss(fake_out[-1]) for fake_out in sd_p_)
        adv_loss = sum(adv_loss_list)
        gen_loss += adv_loss * lambdas.lambda_adv
        output_dict['adv_loss'] = adv_loss

        # Feature-matching loss on every entry except the last one.
        if cfg.use_feat_match_loss:
            fm_criterion = self.criteria['fm_loss']

            fm_loss = 0.0
            with torch.no_grad():
                p = self.discriminator(y)
            for fake_feats, real_feats in zip(p_, p):
                for j in range(len(fake_feats) - 1):
                    fm_loss += fm_criterion(fake_feats[j], real_feats[j].detach())
            gen_loss += fm_loss * lambdas.lambda_feat_match_loss
            output_dict['fm_loss'] = fm_loss

            if has_spec_disc:
                spec_fm_loss = 0.0
                with torch.no_grad():
                    sd_p = self.spec_discriminator(y)
                for fake_feats, real_feats in zip(sd_p_, sd_p):
                    for j in range(len(fake_feats) - 1):
                        spec_fm_loss += fm_criterion(fake_feats[j], real_feats[j].detach())
                gen_loss += spec_fm_loss * lambdas.lambda_feat_match_loss
                output_dict['spec_fm_loss'] = spec_fm_loss

        # Semantic reconstruction loss.
        output_dict['semantic_recon_loss'] = semantic_recon_loss
        gen_loss += semantic_recon_loss * lambdas.lambda_semantic_loss
    finally:
        self.set_discriminator_gradients(True)

    output_dict['gen_loss'] = gen_loss
    return output_dict
def training_step(self, batch, batch_idx):
    """Run one manual-optimization step: discriminators first, then generator.

    A single forward pass is shared by both updates. Each update zeroes the
    gradients, back-propagates, clips the gradient norm, steps the optimizer
    and then steps its scheduler. Both loss dictionaries are logged at the end.

    Args:
        batch: Training batch passed to ``forward``.
        batch_idx: Index of the batch (unused).
    """
    output = self(batch)
    gen_opt, disc_opt = self.optimizers()
    gen_sche, disc_sche = self.lr_schedulers()

    def update(optimizer, scheduler, loss, clip_val):
        optimizer.zero_grad()
        self.manual_backward(loss)
        self.clip_gradients(
            optimizer,
            gradient_clip_val=clip_val,
            gradient_clip_algorithm='norm',
        )
        optimizer.step()
        scheduler.step()

    # Discriminator update.
    disc_losses = self.compute_disc_loss(batch, output)
    update(disc_opt, disc_sche, disc_losses['disc_loss'], self.cfg.train.disc_grad_clip)

    # Generator update.
    gen_losses = self.compute_gen_loss(batch, output)
    update(gen_opt, gen_sche, gen_losses['gen_loss'], self.cfg.train.gen_grad_clip)

    # Log both groups of losses with identical settings.
    for losses in (disc_losses, gen_losses):
        self.log_dict(
            losses,
            on_step=True,
            on_epoch=True,
            prog_bar=True,
            logger=True,
            batch_size=self.cfg.dataset.train.batch_size,
            sync_dist=True,
        )
def validation_step(self, batch, batch_idx):
    """Score one validation batch.

    With a speaker-verification model attached (``self.model_spk``), the
    metric is the mean cosine similarity between the speaker embeddings of
    the reference and the generated audio, logged as ``val/sim``. Without
    it, the mel loss is logged as ``val/mel_loss`` instead. If the mel loss
    is not configured either, nothing is logged.

    Metrics are logged with ``sync_dist=True`` so that multi-process runs
    report a value reduced over all ranks, as the training logs do.

    Args:
        batch: Validation batch passed to ``forward``.
        batch_idx: Index of the batch (unused).

    Returns:
        ``{'sim': ...}``, ``{'mel_loss': ...}`` or an empty dict.
    """
    output = self(batch)
    y = output['gt_wav']
    y_ = output['gen_wav']
    log_kwargs = dict(on_step=False, on_epoch=True, prog_bar=True, logger=True, sync_dist=True)

    # ``model_spk`` is only attached in ``construct_model`` when the optional
    # speaker model could be loaded.
    speaker_model = getattr(self, 'model_spk', None)
    if speaker_model is None:
        # The mel criterion only exists when ``cfg.train.use_mel_loss`` is set.
        if 'mel_loss' not in self.criteria:
            return {}
        mel_loss = self.criteria['mel_loss'](y_.squeeze(1), y.squeeze(1))
        self.log('val/mel_loss', mel_loss, **log_kwargs)
        return {'mel_loss': mel_loss}

    embeddings1 = speaker_model(y.squeeze(1))
    embeddings2 = speaker_model(y_.squeeze(1))
    sim = F.cosine_similarity(embeddings1, embeddings2).mean()
    self.log('val/sim', sim, **log_kwargs)
    return {'sim': sim}
def test_step(self, batch, batch_idx):
    """Placeholder test hook; it currently performs no work."""
    # Test-time logic can be implemented here.
    pass
def configure_optimizers(self):
    """Create the AdamW optimizers and warm-up schedulers for both players.

    The discriminator optimizer covers the multi-period and the spectral
    discriminators. The generator optimizer covers the codec encoder, the
    decoder, the three linear projections and the semantic encoder/decoder.
    Hyper-parameters come from ``cfg.train``.

    Returns:
        ``([gen_opt, disc_opt], [gen_sche, disc_sche])``.
    """
    from itertools import chain

    train_cfg = self.cfg.train

    # Discriminator side.
    disc_modules = [self.discriminator, self.spec_discriminator]
    disc_params = chain(*[module.parameters() for module in disc_modules])

    # Generator side.
    gen_modules = [
        self.CodecEnc,
        self.generator,
        self.fc_prior,
        self.fc_post_a,
        self.fc_post_s,
        self.SemanticDecoder_module,
        self.SemanticEncoder_module,
    ]
    gen_params = chain(*[module.parameters() for module in gen_modules])

    # Optimizers.
    gen_opt = optim.AdamW(gen_params, **train_cfg.gen_optim_params)
    disc_opt = optim.AdamW(disc_params, **train_cfg.disc_optim_params)

    # Learning-rate schedulers.
    gen_sche = WarmupLR(gen_opt, **train_cfg.gen_schedule_params)
    disc_sche = WarmupLR(disc_opt, **train_cfg.disc_schedule_params)

    print(f'Generator optim: {gen_opt}')
    print(f'Discriminator optim: {disc_opt}')
    return [gen_opt, disc_opt], [gen_sche, disc_sche]
def set_discriminator_gradients(self, flag=True):
    """Enable or disable gradients for every discriminator parameter.

    Args:
        flag: Value assigned to ``requires_grad`` of each parameter of
            ``self.discriminator`` and, when present, of
            ``self.spec_discriminator``.
    """
    targets = [self.discriminator]
    if hasattr(self, 'spec_discriminator'):
        targets.append(self.spec_discriminator)
    for module in targets:
        for param in module.parameters():
            param.requires_grad = flag