"""Benchmark VoiceBox streaming configurations on a reference utterance.

For each (bottleneck_hidden_size, bottleneck_feedforward_size,
spec_encoder_hidden_size) tuple in ``tests``, the script streams a test
utterance through a ``VoiceBoxStreamer`` in fixed-size chunks, then reports
the real-time factor and the speaker-embedding distance between the streamed
output and the clean reference, and writes the processed audio to disk.
"""
import warnings

import torch
import librosa
import soundfile as sf
from tqdm import tqdm

from src.attacks.offline.perturbation.voicebox import projection
from src.attacks.online import Streamer, VoiceBoxStreamer
from src.models import ResNetSE34V2, SpeakerVerificationModel
from src.constants import MODELS_DIR, TEST_DIR, PPG_PRETRAINED_PATH

warnings.filterwarnings("ignore")
torch.set_num_threads(1)  # single-threaded CPU run for a stable RTF measurement

device = 'cpu'
lookahead = 5          # lookahead frames handed to the streamer
signal_length = 64_000  # samples processed (4 s at 16 kHz)
chunk_size = 640        # samples per streamed chunk (40 ms at 16 kHz)

# Load the test utterance as a (1, 1, T) float tensor at 16 kHz mono.
test_audio = torch.Tensor(
    librosa.load(TEST_DIR / 'data' / 'test.wav', sr=16_000, mono=True)[0]
).unsqueeze(0).unsqueeze(0)

# (bottleneck_hidden_size, bottleneck_feedforward_size, spec_encoder_hidden_size)
tests = [
    (512, 512, 512)
]

resnet_model = SpeakerVerificationModel(model=ResNetSE34V2())

# Nothing in this script needs gradients: inference_mode skips autograd
# bookkeeping and guarantees outputs are safe to pass to .numpy() below.
with torch.inference_mode():
    condition_vector = resnet_model(test_audio)

    for (bottleneck_hidden_size,
         bottleneck_feedforward_size,
         spec_encoder_hidden_size) in tests:
        print(
            f"""
    ====================================
    bottleneck_hidden_size: {bottleneck_hidden_size}
    bottleneck_feedforward_size: {bottleneck_feedforward_size}
    spec_encoder_hidden_size: {spec_encoder_hidden_size}
    """
        )

        streamer = Streamer(
            VoiceBoxStreamer(
                win_length=256,
                bottleneck_type='lstm',
                bottleneck_skip=True,
                bottleneck_depth=2,
                bottleneck_lookahead_frames=5,
                bottleneck_hidden_size=bottleneck_hidden_size,
                bottleneck_feedforward_size=bottleneck_feedforward_size,
                conditioning_dim=512,
                spec_encoder_mlp_depth=2,
                spec_encoder_hidden_size=spec_encoder_hidden_size,
                spec_encoder_lookahead_frames=0,
                ppg_encoder_path=PPG_PRETRAINED_PATH,
                ppg_encoder_depth=2,
                ppg_encoder_hidden_size=256,
                projection_norm='inf',
                control_eps=0.5,
                n_bands=128
            ),
            device,
            hop_length=128,
            window_length=256,
            win_type='hann',
            lookahead_frames=lookahead,
            recurrent=True
        )

        # map_location is required so a checkpoint saved on GPU still loads
        # on this CPU-only run (torch.load would otherwise try to restore
        # CUDA tensors and fail on machines without a GPU).
        streamer.model.load_state_dict(
            torch.load(
                MODELS_DIR / 'voicebox' / 'voicebox_final.pt',
                map_location=device,
            )
        )
        streamer.condition_vector = condition_vector

        # Stream the signal chunk by chunk, collecting the processed output.
        output_chunks = []
        for i in tqdm(range(0, signal_length, chunk_size)):
            signal_chunk = test_audio[..., i:i + chunk_size]
            out = streamer.feed(signal_chunk)
            output_chunks.append(out)
        # Drain whatever the streamer buffered for lookahead.
        output_chunks.append(streamer.flush())

        output_audio = torch.cat(output_chunks, dim=-1)
        output_embedding = resnet_model(output_audio)

        print(
            f"""
    RTF: {streamer.real_time_factor}
    Embedding Distance: {resnet_model.distance_fn(output_embedding, condition_vector)}
    ====================================
    """
        )

        # NOTE(review): fixed filename — with multiple configurations each
        # iteration overwrites output.wav; only the last run's audio survives.
        sf.write(
            'output.wav',
            output_audio.numpy().squeeze(),
            16_000,
        )