Spaces:
Runtime error
Runtime error
| #!/usr/bin/env python | |
| # -*- coding: UTF-8 -*- | |
| ''' | |
| @Project :Waveformer-main | |
| @File :CLAPSep.py | |
| @IDE :PyCharm | |
| @Author :Aisaka/Hao Ma @SDU | |
| @Date :2024/2/28 下午1:12 | |
| ''' | |
| import torch | |
| from torch import nn | |
| import torchaudio | |
| import laion_clap | |
| from .CLAPSep_decoder import HTSAT_Decoder | |
| import copy | |
| import loralib as lora | |
| from torchlibrosa import ISTFT, STFT | |
| from torchlibrosa.stft import magphase | |
| import librosa | |
def set_module(model, submodule_key, module):
    """Install ``module`` at the dotted attribute path ``submodule_key`` of ``model``.

    Every path component except the last is resolved with ``getattr``
    starting from ``model``; the last component is then assigned on that
    parent with ``setattr``, replacing whatever was there.

    Args:
        model: root object (typically an ``nn.Module``) to modify in place.
        submodule_key: dotted path such as ``"layers.0.attn.qkv"``.
        module: object to store at that path.
    """
    *parent_path, leaf_name = submodule_key.split('.')
    parent = model
    for attr_name in parent_path:
        parent = getattr(parent, attr_name)
    setattr(parent, leaf_name, module)
def process_model(model, rank):
    """Replace every ``nn.Linear`` inside ``WindowAttention`` modules with a LoRA layer.

    Each replacement ``lora.Linear`` shares the original layer's ``weight``
    (and ``bias``, when present) Parameter objects, so the pretrained
    projection is kept and only the low-rank adapter parameters are new.

    Args:
        model: module tree, modified in place.
        rank: LoRA rank ``r`` passed to ``lora.Linear``.

    Returns:
        The same ``model`` object, for call chaining.
    """
    # Snapshot the matches up front: set_module() below rewires the module
    # tree, so we must not be iterating a live named_modules() generator.
    attention_blocks = [(name, mod) for name, mod in model.named_modules()
                        if 'WindowAttention' in str(type(mod))]
    for block_name, block in attention_blocks:
        linears = [(sub_name, layer) for sub_name, layer in block.named_modules()
                   if isinstance(layer, torch.nn.Linear)]
        for sub_name, layer in linears:
            # nn.Linear always has a `bias` attribute (None when disabled), so
            # hasattr() cannot tell the two cases apart; test for None instead.
            has_bias = layer.bias is not None
            lora_layer = lora.Linear(layer.in_features, layer.out_features, r=rank,
                                     bias=has_bias, merge_weights=True)
            lora_layer.weight = layer.weight
            if has_bias:
                lora_layer.bias = layer.bias
            # named_modules() reports the root as '' -> avoid a leading '.'.
            key = f"{block_name}.{sub_name}" if block_name else sub_name
            set_module(model, key, lora_layer)
    return model
class CLAPSep(nn.Module):
    """Query-conditioned audio separation built on a LAION-CLAP HTSAT audio encoder.

    A frozen CLAP model is loaded from ``CLAP_path``. A deep copy of its audio
    branch (optionally LoRA-adapted) encodes the mixture. Forward hooks on that
    copy collect its intermediate outputs into ``self.features``.
    ``HTSAT_Decoder`` turns those outputs into a magnitude mask, conditioned on
    the normalised concatenation of a positive and a negative query embedding.
    The masked magnitude is recombined with the mixture phase and inverted with
    an ISTFT.

    Args:
        model_config: keyword arguments forwarded to ``HTSAT_Decoder``.
        CLAP_path: path of the LAION-CLAP checkpoint (HTSAT-base, no fusion).
        use_lora: wrap the ``Linear`` layers of the copied encoder's
            ``WindowAttention`` blocks with LoRA adapters.
        rank: LoRA rank used when ``use_lora`` is true.
        nfft: FFT size and window length of the STFT/ISTFT pair (hop is 320).
    """

    def __init__(self, model_config, CLAP_path, use_lora=True, rank=16, nfft=1024):
        super().__init__()
        # Input audio is taken as 32 kHz and resampled to 48 kHz before it
        # enters the CLAP audio branch.
        self.resampler = torchaudio.transforms.Resample(32000, 48000)
        self.clap_model = laion_clap.CLAP_Module(enable_fusion=False, amodel='HTSAT-base', device='cpu')
        self.clap_model.load_ckpt(CLAP_path)
        # The reference CLAP model stays frozen.
        for p in self.clap_model.parameters():
            p.requires_grad = False
        # Separate copy of the audio encoder used for separation. deepcopy
        # carries over requires_grad=False, so only LoRA adapters (if any)
        # are newly created parameters.
        self.audio_branch = copy.deepcopy(self.clap_model.model.audio_branch)
        if use_lora:
            process_model(self.audio_branch, rank)
        self.decoder_model = HTSAT_Decoder(**model_config)
        self.stft = STFT(n_fft=nfft, hop_length=320,
                         win_length=nfft, window='hann', center=True, pad_mode='reflect',
                         freeze_parameters=True)
        self.istft = ISTFT(n_fft=nfft, hop_length=320,
                           win_length=nfft, window='hann', center=True, pad_mode='reflect',
                           freeze_parameters=True)
        # List shared with the encoder's forward hooks; filled during each
        # pass of self.audio_branch and cleared after every inference call.
        self.features = self.install_forward_hooks()

    def wav_reconstruct(self, mask, mag_x, cos_x, sin_x, length):
        """Apply ``mask`` to the mixture magnitude and invert back to a waveform.

        The mixture phase (``cos_x``/``sin_x``) is reused unchanged. The masked
        magnitude is clamped to be non-negative with an in-place ReLU.

        Args:
            mask: decoder output, multiplied element-wise with ``mag_x``.
            mag_x, cos_x, sin_x: magnitude and phase terms from ``magphase``.
            length: number of output samples requested from the ISTFT.

        Returns:
            The reconstructed waveform tensor.
        """
        mag_y = torch.nn.functional.relu_(mag_x * mask)
        cos_y = cos_x
        sin_y = sin_x
        pred = self.istft(mag_y * cos_y, mag_y * sin_y, length=length)
        return pred

    def inference_from_data(self, mixed, embed_pos, embed_neg):
        """Separate ``mixed`` according to a positive and a negative query embedding.

        Puts the module in eval mode as a side effect.

        Args:
            mixed: mixture waveform, presumably (batch, samples) at 32 kHz;
                TODO confirm against callers.
            embed_pos: query embedding of the sound to keep.
            embed_neg: query embedding of the sound to remove. It is
                concatenated with ``embed_pos`` on the last dim and L2-normalised.

        Returns:
            Separated waveform whose last dimension equals ``mixed.size(-1)``.
        """
        self.eval()
        # Drop leftovers from an earlier call that raised; otherwise stale
        # activations would be fed to the decoder below.
        self.features.clear()
        try:
            real, imag = self.stft(mixed)
            mag, cos, sin = magphase(real, imag)
            # The mixture magnitude is the first skip feature. The hooks from
            # install_forward_hooks() append the encoder outputs after it.
            self.features.append(mag)
            with torch.no_grad():
                embed = torch.nn.functional.normalize(torch.concat([embed_pos, embed_neg], dim=-1), dim=-1)
                self.audio_branch({"waveform": self.resampler(mixed)})
                mask = self.decoder_model(hidden_state=self.features[-1],
                                          skip_features=self.features[:-1], embed=embed)
                pred = self.wav_reconstruct(mask, mag, cos, sin, length=mixed.size(-1))
        finally:
            # Always release the hooked activations, even if a step failed.
            self.features.clear()
        return pred

    def install_forward_hooks(self):
        """Register forward hooks on ``self.audio_branch`` and return the list they fill.

        Hooks installed:
          * ``spectrogram_extractor``: replaces its output with a version
            padded at the end of dim 2 to exactly 1024. ``F.pad`` with a
            negative amount crops, so longer inputs are truncated. Dim 2 is
            presumably the time-frame axis; TODO confirm.
          * ``patch_embed``: appends its output to the list.
          * each module in ``layers``: appends ``output[0]`` to the list.

        Returns:
            The (initially empty) list the hooks append to, in call order.
        """
        features = []

        def get_features_list(_, __, output):
            features.append(output)

        def get_features_list_basic_layer(_, __, output):
            # These layers return a tuple; only its first element is kept.
            features.append(output[0])

        def spectrogram_padding(_, __, out):
            # Returning a value from a forward hook replaces the module output.
            return torch.nn.functional.pad(out, (0, 0, 0, 1024 - out.size(2)))

        self.audio_branch.spectrogram_extractor.register_forward_hook(spectrogram_padding)
        self.audio_branch.patch_embed.register_forward_hook(get_features_list)
        for module in self.audio_branch.layers:
            module.register_forward_hook(get_features_list_basic_layer)
        return features
if __name__ == '__main__':
    # Demo: separate a 32 kHz mixture using text queries and write the result.
    import soundfile as sf

    model_config = {"lan_embed_dim": 1024,
                    "depths": [1, 1, 1, 1],
                    "embed_dim": 128,
                    "encoder_embed_dim": 128,
                    "phase": False,
                    "spec_factor": 8,
                    "d_attn": 640,
                    "n_masker_layer": 3,
                    "conv": False}
    CLAP_path = "./music_audioset_epoch_15_esc_90.14.pt"
    model = CLAPSep(model_config, CLAP_path)
    ckpt = torch.load('best_model.ckpt', map_location='cpu')
    model.load_state_dict(ckpt, strict=False)
    model.eval()
    audio, fs = librosa.load("./510_25.221254348754883_mixture.wav", sr=32000)

    pos_prompt = ''
    neg_prompt = 'A vehicle engine revving then powering down.'
    # inference_from_data() takes query *embeddings*, not raw prompts (the
    # previous pos_prompt=/neg_prompt= keywords raised a TypeError), so encode
    # both prompts with the frozen CLAP text tower first.
    with torch.no_grad():
        text_embed = model.clap_model.get_text_embedding([pos_prompt, neg_prompt], use_tensor=True)
    embed_pos, embed_neg = torch.chunk(text_embed, chunks=2, dim=0)
    # An empty prompt means "no query on this side" and is given a zero
    # embedding. NOTE(review): assumed convention; confirm against training code.
    if pos_prompt == '':
        embed_pos = torch.zeros_like(embed_pos)
    if neg_prompt == '':
        embed_neg = torch.zeros_like(embed_neg)

    pred = model.inference_from_data(torch.tensor(audio).unsqueeze(0), embed_pos, embed_neg)
    # librosa.load resampled to sr=32000, so fs == 32000 here.
    sf.write('./pred.wav', pred.squeeze().numpy(), fs)