#!/usr/bin/env python
# -*- coding: UTF-8 -*-
'''
@Project :Waveformer-main
@File    :CLAPSep.py
@IDE     :PyCharm
@Author  :Aisaka/Hao Ma @SDU
@Date    :2024/2/28 1:12 PM
'''
import os
import torch
import laion_clap
from torchmetrics.audio.snr import (
    scale_invariant_signal_noise_ratio as si_snr,
    signal_noise_ratio as snr)
from torchmetrics.audio.sdr import (
    signal_distortion_ratio as sdr,
    scale_invariant_signal_distortion_ratio as si_sdr)
import copy
import loralib as lora
from torchlibrosa import ISTFT, STFT, SpecAugmentation
from torchlibrosa.stft import magphase
import librosa
import pytorch_lightning as pl
import soundfile as sf


def loss_fn(pred, tgt):
    """Return the training loss: a negated 0.9 / 0.1 blend of SNR and SI-SNR.

    Both metrics are averaged over the batch; the sign is flipped so that
    minimising the loss maximises the signal-to-noise ratios.
    """
    return -0.9 * snr(pred, tgt).mean() - 0.1 * si_snr(pred, tgt).mean()


def set_module(model, submodule_key, module):
    """Replace the submodule at dotted path ``submodule_key`` with ``module``.

    Args:
        model: Root module to modify in place.
        submodule_key: Dotted attribute path, e.g. ``"layers.0.attn.qkv"``.
        module: Object assigned to the final attribute of the path.
    """
    tokens = submodule_key.split('.')
    sub_tokens = tokens[:-1]
    cur_mod = model
    for s in sub_tokens:
        cur_mod = getattr(cur_mod, s)
    setattr(cur_mod, tokens[-1], module)


def process_model(model, rank):
    """Swap every ``nn.Linear`` inside ``WindowAttention`` modules for LoRA.

    Each replacement ``lora.Linear`` re-uses the original layer's weight (and
    bias, when the layer has one), so only the LoRA matrices are new.

    Args:
        model: Module to patch in place.
        rank: LoRA rank ``r`` used for every replaced layer.

    Returns:
        The same ``model`` object, for convenience.
    """
    # Collect the targets first so the module tree is not mutated while it is
    # still being traversed.
    targets = []
    for n, module in model.named_modules():
        if 'WindowAttention' not in str(type(module)):
            continue
        for n_, layer in module.named_modules():
            if isinstance(layer, torch.nn.Linear):
                # Skip empty components so the key never starts or ends
                # with a dot.
                key = '.'.join(part for part in (n, n_) if part)
                if key:
                    targets.append((key, layer))

    for key, layer in targets:
        # nn.Linear always *has* a ``bias`` attribute (it is None when the
        # layer was built with bias=False), so hasattr() cannot be used here.
        has_bias = layer.bias is not None
        lora_layer = lora.Linear(layer.in_features, layer.out_features, r=rank,
                                 bias=has_bias, merge_weights=False)
        lora_layer.weight = layer.weight
        if has_bias:
            lora_layer.bias = layer.bias
        set_module(model, key, lora_layer)
    return model


class LightningModule(pl.LightningModule):
    """Lightning wrapper that trains a query-conditioned separation decoder.

    A frozen CLAP model supplies the query embeddings. A deep copy of its
    audio branch (optionally LoRA-tuned) encodes the mixture; forward hooks on
    that copy collect intermediate outputs into ``self.features``, which the
    decoder consumes as hidden state and skip connections.
    """

    def __init__(self, clap_model, decoder_model, lr, use_lora=False, rank=8, nfft=1024):
        """Build the module.

        Args:
            clap_model: CLAP wrapper; all of its parameters are frozen here.
            decoder_model: Mask decoder; must expose a ``phase`` attribute.
            lr: Learning rate passed to AdamW.
            use_lora: When true, inject LoRA layers into the copied audio
                branch and mark only those as trainable.
            rank: LoRA rank.
            nfft: FFT size (also used as window length) for STFT / ISTFT.
        """
        super().__init__()
        self.phase = decoder_model.phase
        self.lr = lr
        self.clap_model = clap_model
        for p in self.clap_model.parameters():
            p.requires_grad = False
        # The copy is taken before hooks are installed, so the feature hooks
        # only ever fire on this copy, not on the CLAP model's own branch.
        self.audio_branch = copy.deepcopy(self.clap_model.model.audio_branch)
        if use_lora:
            process_model(self.audio_branch, rank)
            lora.mark_only_lora_as_trainable(self.audio_branch, bias='lora_only')

        self.decoder_model = decoder_model
        self.stft = STFT(n_fft=nfft, hop_length=320,
                         win_length=nfft, window='hann', center=True, pad_mode='reflect',
                         freeze_parameters=True)
        self.istft = ISTFT(n_fft=nfft, hop_length=320,
                           win_length=nfft, window='hann', center=True, pad_mode='reflect',
                           freeze_parameters=True)
        self.features = self.install_forward_hooks()

    @staticmethod
    def _condition_embedding(embed_pos, embed_neg):
        """Concatenate the two query embeddings and L2-normalise the result."""
        return torch.nn.functional.normalize(
            torch.concat([embed_pos, embed_neg], dim=-1), dim=-1)

    def training_step(self, batch, batch_idx):
        """Run one optimisation step and return the loss.

        The query is a random convex mix of CLAP audio and text embeddings.
        One of three conditioning modes is then drawn: positive only (25%),
        negative only (25%), or both (50%).
        """
        # Both encoders are kept in eval mode during training.
        self.clap_model.eval()
        self.audio_branch.eval()
        mixed, mixed_resample, pos_cap, neg_cap, gt, pos_sample, neg_sample = batch
        real, imag = self.stft(mixed)
        mag, cos, sin = magphase(real, imag)
        with torch.no_grad():
            # Single mixing weight shared by the whole batch.
            mix_weight = torch.rand((1,)).type_as(gt)
            embed_pos_a, embed_neg_a = torch.chunk(
                self.clap_model.get_audio_embedding_from_data(
                    torch.concat([pos_sample, neg_sample], dim=0), use_tensor=True),
                dim=0, chunks=2)
            embed_pos_t, embed_neg_t = torch.chunk(
                self.clap_model.get_text_embedding(pos_cap + neg_cap, use_tensor=True),
                dim=0, chunks=2)
            embed_pos = mix_weight * embed_pos_a + (1 - mix_weight) * embed_pos_t
            embed_neg = mix_weight * embed_neg_a + (1 - mix_weight) * embed_neg_t

        # features[0] is the mixture magnitude; the hooks append the encoder
        # outputs behind it during the forward pass below.
        del self.features[:]
        self.features.append(mag)
        self.audio_branch({"waveform": mixed_resample})

        length = mixed.size(-1)
        mode = torch.rand((1,))
        if mode < 0.25:
            loss = self.cal_loss(embed_pos, torch.zeros_like(embed_pos), mag, cos, sin,
                                 length=length, gt=gt)
        elif mode < 0.5:
            loss = self.cal_loss(torch.zeros_like(embed_neg), embed_neg, mag, cos, sin,
                                 length=length, gt=gt)
        else:
            loss = self.cal_loss(embed_pos, embed_neg, mag, cos, sin,
                                 length=length, gt=gt)
        # Log a detached tensor instead of .item() to avoid forcing a
        # device-to-host sync on every step.
        self.log("train_loss", loss.detach(), on_epoch=True, prog_bar=True,
                 sync_dist=True, batch_size=len(mixed))
        del self.features[:]
        return loss

    def cal_loss(self, embed_p, embed_n, mag, cos, sin, length, gt):
        """Decode a mask from the currently captured features and score it.

        Reads ``self.features`` (last entry = hidden state, the rest = skip
        features), so the encoder must have been run beforehand.
        """
        embed = self._condition_embedding(embed_p, embed_n)
        mask = self.decoder_model(hidden_state=self.features[-1],
                                  skip_features=self.features[:-1], embed=embed)
        pred = self.wav_reconstruct(mask, mag, cos, sin, length=length)
        return loss_fn(pred, gt)

    def wav_reconstruct(self, mask, mag_x, cos_x, sin_x, length):
        """Apply the predicted mask to the mixture STFT and invert to audio.

        With ``self.phase`` set, ``mask`` is indexed as (magnitude, real,
        imag) and the mixture phase is rotated by the mask phase; otherwise
        ``mask`` scales the magnitude only and the mixture phase is reused.
        """
        # ref: https://github.com/Audio-AGI/AudioSep/blob/main/models/resunet.py
        # Y = |Y|cos∠Y + j|Y|sin∠Y
        #   = |Y|cos(∠X + ∠M) + j|Y|sin(∠X + ∠M)
        #   = |Y|(cos∠X cos∠M - sin∠X sin∠M) + j|Y|(sin∠X cos∠M + cos∠X sin∠M)
        if self.phase:
            mag_y = torch.nn.functional.relu_(mag_x * mask[0])
            _, mask_cos, mask_sin = magphase(mask[1], mask[2])
            cos_y = cos_x * mask_cos - sin_x * mask_sin
            sin_y = sin_x * mask_cos + cos_x * mask_sin
        else:
            mag_y = torch.nn.functional.relu_(mag_x * mask)
            cos_y = cos_x
            sin_y = sin_x
        pred = self.istft(mag_y * cos_y, mag_y * sin_y, length=length)
        return pred

    def validation_step(self, batch, batch_idx):
        """Log and return the SI-SNR improvement of the estimate over the mixture."""
        mixed, mixed_resample, label, neg_label, gt, _, _ = batch
        real, imag = self.stft(mixed)
        mag, cos, sin = magphase(real, imag)
        # Start from a clean buffer (as training/test do) so stale hook
        # outputs can never shift hidden_state / skip_features.
        del self.features[:]
        self.features.append(mag)
        with torch.no_grad():
            embed_pos = self.clap_model.get_text_embedding(label, use_tensor=True)
            embed_neg = self.clap_model.get_text_embedding(neg_label, use_tensor=True)
            # Same normalisation as the training path in cal_loss().
            embed = self._condition_embedding(embed_pos, embed_neg)
            self.audio_branch({"waveform": mixed_resample})
            mask = self.decoder_model(hidden_state=self.features[-1],
                                      skip_features=self.features[:-1], embed=embed)
            pred = self.wav_reconstruct(mask, mag, cos, sin, length=mixed.size(-1))
            loss = si_snr(pred, gt).mean() - si_snr(mixed, gt).mean()
        del self.features[:]
        self.log("val_loss", loss, on_epoch=True, prog_bar=True, sync_dist=True,
                 batch_size=len(mixed))
        return {"val_loss": loss}

    def test_step(self, batch, batch_idx):
        """Separate every ``|``-delimited label of one clip and write WAVs.

        For each label the remaining labels (comma-joined) form the negative
        query; with a single label the negative embedding is all zeros.
        Expects a batch of exactly one clip.
        """
        mixed, mixed_resample, labels, src_path = batch
        assert len(labels) == 1, "test_step expects batch size 1"
        src_path = src_path[0]
        # splitext instead of [:-4] so extensions of any length are handled.
        save_path = os.path.splitext(
            src_path.replace('/data/sound/audioset/audios_32k',
                             'data_engine_infer/audioset_separation_child_label'))[0]
        os.makedirs(save_path, exist_ok=True)
        real, imag = self.stft(mixed)
        mag, cos, sin = magphase(real, imag)
        with torch.no_grad():
            # The encoder pass does not depend on the label: run it once and
            # reuse the captured features for every query.
            del self.features[:]
            self.features.append(mag)
            self.audio_branch({"waveform": mixed_resample})
            features = list(self.features)
            del self.features[:]

            labels = labels[0].split('|')
            for idx, pos_label in enumerate(labels):
                embed_pos = self.clap_model.get_text_embedding(pos_label, use_tensor=True)
                neg_labels = [lab for j, lab in enumerate(labels) if j != idx]
                if neg_labels:
                    neg_label = ', '.join(neg_labels)
                    embed_neg = self.clap_model.get_text_embedding(neg_label, use_tensor=True)
                else:
                    embed_neg = torch.zeros_like(embed_pos)
                # only positive
                # embed = torch.concat([embed_pos, torch.zeros_like(embed_pos)], dim=1)
                # positive and negative (normalised as in cal_loss())
                embed = self._condition_embedding(embed_pos, embed_neg)
                mask = self.decoder_model(hidden_state=features[-1],
                                          skip_features=features[:-1], embed=embed)
                pred = self.wav_reconstruct(mask, mag, cos, sin, length=mixed.size(-1))
                sf.write(os.path.join(save_path, pos_label + '.wav'),
                         pred.squeeze().cpu().numpy(), samplerate=32000)

    def configure_optimizers(self):
        """Return AdamW plus a plateau scheduler driven by ``val_loss``.

        ``mode='max'`` because ``val_loss`` as logged by ``validation_step``
        is an SI-SNR improvement, i.e. higher is better.
        """
        optimizer = torch.optim.AdamW(self.parameters(), lr=self.lr)
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, mode='max', factor=0.3, patience=5, verbose=True, min_lr=5e-6)
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "interval": "epoch",
                "monitor": "val_loss"
            },
        }

    def install_forward_hooks(self):
        """Register all forward hooks and return the shared feature buffer.

        Returns:
            The list that the hooks on ``self.audio_branch`` append to on
            every forward pass (patch-embed output first, then the first
            element of each layer's output).
        """
        features = []
        # NOTE(review): this augmenter is a closure local, not a registered
        # submodule, so .eval()/.train()/.to() on this LightningModule do not
        # reach it - confirm that is intended.
        spec_augmenter = SpecAugmentation(time_drop_width=64, time_stripes_num=2,
                                          freq_drop_width=8, freq_stripes_num=2)

        def get_features_list(_, __, output):
            features.append(output)

        def get_features_list_basic_layer(_, __, output):
            features.append(output[0])

        def spec_augmentation_hook(_, __, out):
            # Swap dims 1 and 3 around the augmenter, then swap back.
            out = out.transpose(1, 3)
            out = spec_augmenter(out)
            return out.transpose(1, 3)

        def spectrogram_padding(_, __, out):
            # Pad the end of dim 2 up to 1024 (presumably the time axis of a
            # 4-D spectrogram - TODO confirm); a negative amount crops.
            return torch.nn.functional.pad(out, (0, 0, 0, 1024 - out.size(2)))

        # Augmentation goes on the frozen CLAP branch (query audio); padding
        # and feature capture go on the trainable copy (mixture audio).
        self.clap_model.model.audio_branch.bn0.register_forward_hook(spec_augmentation_hook)
        self.audio_branch.spectrogram_extractor.register_forward_hook(spectrogram_padding)
        self.audio_branch.patch_embed.register_forward_hook(get_features_list)
        for module in self.audio_branch.layers:
            module.register_forward_hook(get_features_list_basic_layer)
        return features

    # # this will only save tuned parameters during training
    # def on_save_checkpoint(self, checkpoint):
    #     weights = checkpoint['state_dict']
    #     new_dict = {}
    #     for k, v in weights.items():
    #         if any(e in k for e in ['lora', 'attn.qkv.bias', 'attn.proj.bias', 'decoder_model']):
    #             new_dict[k] = v
    #     checkpoint['state_dict'] = new_dict