|
|
|
|
|
|
|
|
''' |
|
|
@Project :Waveformer-main |
|
|
@File :CLAPSep.py |
|
|
@IDE :PyCharm |
|
|
@Author :Aisaka/Hao Ma @SDU |
|
|
@Date :2024/2/28 下午1:12 |
|
|
''' |
|
|
|
|
|
import torch |
|
|
import laion_clap |
|
|
from torchmetrics.audio.snr import( |
|
|
scale_invariant_signal_noise_ratio as si_snr, |
|
|
signal_noise_ratio as snr) |
|
|
from torchmetrics.audio.sdr import( |
|
|
signal_distortion_ratio as sdr, |
|
|
scale_invariant_signal_distortion_ratio as si_sdr) |
|
|
import copy |
|
|
import loralib as lora |
|
|
from torchlibrosa import ISTFT, STFT, SpecAugmentation |
|
|
from torchlibrosa.stft import magphase |
|
|
import librosa |
|
|
import pytorch_lightning as pl |
|
|
|
|
|
|
|
|
def loss_fn(pred, tgt):
    """Return the training loss: a negated blend of SNR and scale-invariant SNR.

    Both metrics are averaged over their outputs and combined with weights
    0.9 (SNR) and 0.1 (SI-SNR); the sign is flipped so that minimising the
    loss maximises the two ratios.

    Args:
        pred: Estimated waveform(s).
        tgt: Reference waveform(s), passed to the metrics alongside ``pred``.

    Returns:
        Scalar tensor ``-0.9 * mean(SNR) - 0.1 * mean(SI-SNR)``.
    """
    mean_snr = snr(pred, tgt).mean()
    mean_si_snr = si_snr(pred, tgt).mean()
    return -0.9 * mean_snr - 0.1 * mean_si_snr
|
|
|
|
|
|
|
|
def set_module(model, submodule_key, module):
    """Assign ``module`` at the dotted attribute path ``submodule_key`` under ``model``.

    Every path component except the last is resolved with ``getattr`` starting
    from ``model``; the last component is then (re)bound with ``setattr`` on
    the object reached that way.

    Args:
        model: Root object the path is resolved against.
        submodule_key: Dot-separated attribute path, e.g. ``"block.attn.qkv"``.
        module: Value to store under the final path component.

    Raises:
        AttributeError: If an intermediate path component does not exist.
    """
    *parent_path, leaf = submodule_key.split('.')
    owner = model
    for attr in parent_path:
        owner = getattr(owner, attr)
    setattr(owner, leaf, module)
|
|
|
|
|
|
|
|
def process_model(model, rank):
    """Swap the linear layers inside window-attention blocks for LoRA linears.

    Every sub-module whose type string contains ``'WindowAttention'`` is
    searched for ``torch.nn.Linear`` layers. Each one is replaced in place by
    a ``lora.Linear`` of the same size that re-uses the original weight (and
    bias, when present) parameter objects, so the pretrained values are kept
    and only the LoRA factors are new.

    Args:
        model: Module tree to modify in place.
        rank: LoRA rank ``r`` given to every replacement layer.

    Returns:
        The same ``model`` object, for call chaining.
    """
    # Snapshot both traversals: layers are replaced while we walk the tree, and
    # mutating a module tree under a live ``named_modules()`` generator is fragile.
    for _, module in list(model.named_modules()):
        if 'WindowAttention' not in str(type(module)):
            continue
        for child_name, layer in list(module.named_modules()):
            # An empty name is the attention block itself; only descendants are replaced.
            if not child_name or not isinstance(layer, torch.nn.Linear):
                continue
            # ``nn.Linear`` always has a ``bias`` attribute (``None`` when disabled),
            # so test the value rather than the attribute's existence.
            has_bias = layer.bias is not None
            lora_layer = lora.Linear(layer.in_features, layer.out_features, r=rank,
                                     bias=has_bias, merge_weights=False)
            # Share the pretrained parameters instead of copying them.
            lora_layer.weight = layer.weight
            if has_bias:
                lora_layer.bias = layer.bias
            # Resolve the path relative to the attention block, which also works
            # when that block is the root of ``model`` (its own name is then '').
            set_module(module, child_name, lora_layer)
    return model
|
|
|
|
|
|
|
|
class LightningModule(pl.LightningModule):
    """Lightning wrapper that trains a mask decoder conditioned on CLAP query embeddings.

    A frozen CLAP model provides text/audio query embeddings. A private deep
    copy of its audio branch (optionally LoRA-adapted) is run on the mixture,
    and forward hooks collect its intermediate activations. ``decoder_model``
    maps those activations plus the query embedding to a spectrogram mask that
    is applied to the mixture STFT and inverted back to a waveform.
    """

    def __init__(self, clap_model, decoder_model, lr, use_lora=False, rank=8, nfft=1024):
        """Build the module and install the feature-collecting forward hooks.

        Args:
            clap_model: CLAP wrapper; all of its parameters are frozen here.
            decoder_model: Mask decoder. Must expose a ``phase`` attribute.
            lr: Learning rate for the AdamW optimizer.
            use_lora: If true, LoRA layers are inserted into the copied audio
                branch and only LoRA parameters (plus the biases of LoRA
                layers) are left trainable there.
            rank: LoRA rank, used only when ``use_lora`` is true.
            nfft: FFT size and window length of the STFT / ISTFT pair.
        """
        super().__init__()
        # When truthy, ``wav_reconstruct`` treats the decoder output as three components.
        self.phase = decoder_model.phase
        self.lr = lr
        self.clap_model = clap_model
        for p in self.clap_model.parameters():
            p.requires_grad = False
        # Copied after freezing, so the copy starts with every parameter frozen too.
        self.audio_branch = copy.deepcopy(self.clap_model.model.audio_branch)
        if use_lora:
            process_model(self.audio_branch, rank)
            lora.mark_only_lora_as_trainable(self.audio_branch, bias='lora_only')

        self.decoder_model = decoder_model
        self.stft = STFT(n_fft=nfft, hop_length=320,
                         win_length=nfft, window='hann', center=True, pad_mode='reflect',
                         freeze_parameters=True)
        self.istft = ISTFT(n_fft=nfft, hop_length=320,
                           win_length=nfft, window='hann', center=True, pad_mode='reflect',
                           freeze_parameters=True)
        # Shared list that the forward hooks append to; always mutate it in place.
        self.features = self.install_forward_hooks()

    @staticmethod
    def _join_query_embeddings(embed_p, embed_n):
        """Concatenate positive and negative query embeddings and L2-normalise the result."""
        return torch.nn.functional.normalize(torch.concat([embed_p, embed_n], dim=-1), dim=-1)

    def _encode_mixture(self, mag, mixed_resample):
        """Refill ``self.features`` with ``mag`` followed by the hooked encoder activations."""
        # Clear in place: the hooks hold a reference to this exact list object,
        # and a leftover entry would shift hidden_state / skip_features.
        del self.features[:]
        self.features.append(mag)
        # The return value is unused; the forward hooks capture the activations.
        self.audio_branch({"waveform": mixed_resample})

    def _decode(self, embed, mag, cos, sin, length):
        """Run the decoder on the collected features and reconstruct the waveform."""
        mask = self.decoder_model(hidden_state=self.features[-1],
                                  skip_features=self.features[:-1],
                                  embed=embed)
        return self.wav_reconstruct(mask, mag, cos, sin, length=length)

    def training_step(self, batch, batch_idx):
        """Compute the training loss for one batch.

        The batch unpacks into the mixture, a second copy of the mixture fed
        to the audio branch, positive/negative captions, the target waveform
        and positive/negative audio query samples. Audio and text embeddings
        are blended with one random coefficient per step. The decoder is then
        conditioned on the positive query only (25%), the negative query only
        (25%) or both (50%); the unused side is replaced by zeros.

        Returns:
            The scalar loss tensor from ``cal_loss``.
        """
        # Keep both CLAP copies in eval mode while the Lightning module trains.
        self.clap_model.eval()
        self.audio_branch.eval()

        mixed, mixed_resample, pos_cap, neg_cap, gt, pos_sample, neg_sample = batch
        real, imag = self.stft(mixed)
        mag, cos, sin = magphase(real, imag)
        length = mixed.size(-1)

        with torch.no_grad():
            # Random interpolation weight between audio and text query embeddings.
            a = torch.rand((1,)).type_as(gt)
            embed_pos_a, embed_neg_a = torch.chunk(
                self.clap_model.get_audio_embedding_from_data(
                    torch.concat([pos_sample, neg_sample], dim=0), use_tensor=True),
                dim=0, chunks=2)
            embed_pos_t, embed_neg_t = torch.chunk(
                self.clap_model.get_text_embedding(pos_cap + neg_cap, use_tensor=True),
                dim=0, chunks=2)
            embed_pos = a * embed_pos_a + (1 - a) * embed_pos_t
            embed_neg = a * embed_neg_a + (1 - a) * embed_neg_t

        # Outside ``no_grad`` so gradients can reach trainable (LoRA) encoder weights.
        self._encode_mixture(mag, mixed_resample)

        draw = torch.rand((1,)).item()
        if draw < 0.25:
            loss = self.cal_loss(embed_pos, torch.zeros_like(embed_pos), mag, cos, sin, length=length, gt=gt)
        elif draw < 0.5:
            loss = self.cal_loss(torch.zeros_like(embed_neg), embed_neg, mag, cos, sin, length=length, gt=gt)
        else:
            loss = self.cal_loss(embed_pos, embed_neg, mag, cos, sin, length=length, gt=gt)

        # Log a detached tensor rather than ``loss.item()`` to avoid a device sync every step.
        self.log("train_loss", loss.detach(), on_epoch=True, prog_bar=True, sync_dist=True,
                 batch_size=len(mixed))
        del self.features[:]
        return loss

    def cal_loss(self, embed_p, embed_n, mag, cos, sin, length, gt):
        """Decode with the given query embeddings and return ``loss_fn(pred, gt)``.

        Expects ``self.features`` to already hold the activations for this
        batch (mixture magnitude first, deepest encoder output last).
        """
        embed = self._join_query_embeddings(embed_p, embed_n)
        pred = self._decode(embed, mag, cos, sin, length=length)
        return loss_fn(pred, gt)

    def wav_reconstruct(self, mask, mag_x, cos_x, sin_x, length):
        """Apply the decoder mask to the mixture spectrogram and invert it to a waveform.

        Args:
            mask: Decoder output. With ``self.phase`` set it is indexed as
                ``mask[0]`` (magnitude mask) and ``mask[1]``, ``mask[2]``
                (real/imaginary parts whose phase is added to the mixture
                phase); otherwise it is a single magnitude mask.
            mag_x, cos_x, sin_x: Magnitude and phase cosine/sine of the mixture STFT.
            length: Output length handed to the ISTFT.

        Returns:
            The reconstructed waveform produced by ``self.istft``.
        """
        if self.phase:
            mag_y = torch.nn.functional.relu_(mag_x * mask[0])
            _, mask_cos, mask_sin = magphase(mask[1], mask[2])
            # Angle-addition identities: rotate the mixture phase by the mask phase.
            cos_y = cos_x * mask_cos - sin_x * mask_sin
            sin_y = sin_x * mask_cos + cos_x * mask_sin
        else:
            # Magnitude-only masking re-uses the mixture phase unchanged.
            mag_y = torch.nn.functional.relu_(mag_x * mask)
            cos_y = cos_x
            sin_y = sin_x
        pred = self.istft(mag_y * cos_y, mag_y * sin_y, length=length)
        return pred

    def validation_step(self, batch, batch_idx):
        """Log the SI-SNR improvement over the unprocessed mixture as ``val_loss``.

        Higher is better, which is why the LR scheduler monitors it in
        ``'max'`` mode. Only text queries are used here.
        """
        mixed, mixed_resample, label, neg_label, gt, _, _ = batch
        real, imag = self.stft(mixed)
        mag, cos, sin = magphase(real, imag)
        with torch.no_grad():
            embed_pos = self.clap_model.get_text_embedding(label, use_tensor=True)
            embed_neg = self.clap_model.get_text_embedding(neg_label, use_tensor=True)
            # Same normalised conditioning as in training (see ``cal_loss``), so the
            # decoder sees the embedding scale it was trained with.
            embed = self._join_query_embeddings(embed_pos, embed_neg)
            # Resets the hook buffer before refilling it, as the other steps do.
            self._encode_mixture(mag, mixed_resample)
            pred = self._decode(embed, mag, cos, sin, length=mixed.size(-1))
        loss = si_snr(pred, gt).mean() - si_snr(mixed, gt).mean()
        del self.features[:]
        self.log("val_loss", loss, on_epoch=True, prog_bar=True, sync_dist=True, batch_size=len(mixed))
        return {"val_loss": loss}

    def on_test_start(self) -> None:
        """Reset the per-sample metric accumulators filled by ``test_step``."""
        self.sdr_vals = torch.tensor([])
        self.sdri_vals = torch.tensor([])
        self.sisdr_vals = torch.tensor([])
        self.sisdri_vals = torch.tensor([])

    def test_step(self, batch, batch_idx):
        """Separate one test batch and accumulate SDR / SI-SDR and their improvements.

        The "improvement" variants subtract the same metric computed on the
        unprocessed mixture. Results are moved to the CPU and concatenated
        onto the accumulators created in ``on_test_start``.
        """
        mixed, mixed_resample, label, neg_label, gt = batch
        real, imag = self.stft(mixed)
        mag, cos, sin = magphase(real, imag)
        with torch.no_grad():
            embed_pos_bached, embed_neg_bached = torch.chunk(
                self.clap_model.get_text_embedding(label + neg_label, use_tensor=True), chunks=2, dim=0)
            # Same normalised conditioning as in training (see ``cal_loss``).
            embed = self._join_query_embeddings(embed_pos_bached, embed_neg_bached)
            self._encode_mixture(mag, mixed_resample)
            pred = self._decode(embed, mag, cos, sin, length=mixed.size(-1))

            sisdr = si_sdr(pred, gt).cpu()
            self.sisdr_vals = torch.concat([self.sisdr_vals, sisdr])
            self.sisdri_vals = torch.concat([self.sisdri_vals, sisdr - si_sdr(mixed, gt).cpu()])
            sdr_ = sdr(pred, gt).cpu()
            self.sdr_vals = torch.concat([self.sdr_vals, sdr_])
            self.sdri_vals = torch.concat([self.sdri_vals, sdr_ - sdr(mixed, gt).cpu()])
        del self.features[:]

    def on_test_end(self) -> None:
        """Print mean and standard deviation of every accumulated test metric."""
        summaries = (
            ("SDR", self.sdr_vals),
            ("SDRi", self.sdri_vals),
            ("SISDR", self.sisdr_vals),
            ("SISDRi", self.sisdri_vals),
        )
        for label, vals in summaries:
            print(f"{label}-mean: {torch.mean(vals).cpu().numpy():.4f}, "
                  f"{label}-std: {torch.std(vals).cpu().numpy():.4f}")

    def configure_optimizers(self):
        """Return AdamW plus a plateau scheduler that tracks ``val_loss``.

        ``val_loss`` is an improvement score (larger is better), so the
        scheduler runs in ``'max'`` mode and is stepped once per epoch.
        """
        optimizer = torch.optim.AdamW(self.parameters(), lr=self.lr)
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.3, patience=5,
                                                               verbose=True, min_lr=5e-6)
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "interval": "epoch",
                "monitor": "val_loss"
            },
        }

    def install_forward_hooks(self):
        """Register the forward hooks and return the list they append to.

        Hooks on ``self.audio_branch`` record the output of ``patch_embed``
        and the first element of each entry of ``layers`` output, in forward
        order. A further hook pads the spectrogram extractor output, and a
        hook on the frozen CLAP model's own audio branch (not the copy)
        applies SpecAugment after ``bn0``.

        Returns:
            The list object the recording hooks append to. Callers must clear
            it in place (``del features[:]``) rather than rebinding it.
        """
        features = []
        # NOTE(review): this augmenter lives only in the closure (it is not a registered
        # sub-module), so ``train()`` / ``eval()`` on this module never reach it — confirm
        # that augmenting CLAP audio queries outside training is intended.
        spec_augmenter = SpecAugmentation(time_drop_width=64, time_stripes_num=2,
                                          freq_drop_width=8, freq_stripes_num=2)

        def get_features_list(_, __, output):
            features.append(output)

        def get_features_list_basic_layer(_, __, output):
            # The hooked layers return an indexable output; only its first element is kept.
            features.append(output[0])

        def spec_augmentation_hook(_, __, out):
            # Swap dims 1 and 3 around the augmenter, then restore the layout.
            out = out.transpose(1, 3)
            out = spec_augmenter(out)
            return out.transpose(1, 3)

        def spectrogram_padding(_, __, out):
            # Zero-pad at the end so that ``out.size(2)`` becomes 1024
            # (assumes a 4-D tensor, where the padded dim is dim 2 — TODO confirm).
            return torch.nn.functional.pad(out, (0, 0, 0, 1024 - out.size(2)))

        self.clap_model.model.audio_branch.bn0.register_forward_hook(spec_augmentation_hook)
        self.audio_branch.spectrogram_extractor.register_forward_hook(spectrogram_padding)
        self.audio_branch.patch_embed.register_forward_hook(get_features_list)
        for module in self.audio_branch.layers:
            module.register_forward_hook(get_features_list_basic_layer)
        return features
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|