import argbind
import sounddevice as sd
import numpy as np
import yaml
import torch
import os
from typing import Union
import sys
import warnings

sys.path.append('.')
warnings.filterwarnings('ignore', category=UserWarning)

from src.data.dataproperties import DataProperties
from src.attacks.online import Streamer, VoiceBoxStreamer
from src.constants import MODELS_DIR, CONDITIONING_FILENAME


def get_streams(input_name: str, output_name: str, block_size: int) -> tuple[sd.InputStream, sd.OutputStream]:
    """
    Gets Input and Output stream objects.

    :param input_name: device index (as a numeric string) or device name; None
        selects the system default device
    :param output_name: device index (as a numeric string) or device name; None
        selects the system default device
    :param block_size: stream block size in samples
    :return: mono (InputStream, OutputStream) pair at the project sample rate
    """
    # sounddevice accepts an int index, a name string, or None (default device).
    # Numeric strings from the CLI must be converted to int; TypeError covers
    # the None default, which previously crashed int(None).
    try:
        input_name = int(input_name)
    except (ValueError, TypeError):
        pass
    try:
        output_name = int(output_name)
    except (ValueError, TypeError):
        pass
    return (
        sd.InputStream(device=input_name,
                       samplerate=DataProperties.get('sample_rate'),
                       channels=1,
                       blocksize=block_size),
        sd.OutputStream(device=output_name,
                        samplerate=DataProperties.get('sample_rate'),
                        channels=1,
                        blocksize=block_size)
    )


def get_model_streamer(device: str, conditioning_path: str) -> Streamer:
    """
    Load the VoiceBox checkpoint, config, and conditioning tensor, and wrap
    the model in a recurrent Streamer.

    :param device: torch device string, e.g. 'cpu' or 'cuda:0'
    :param conditioning_path: path to the saved conditioning tensor
    :return: Streamer ready to feed audio through
    """
    # TODO: Make a good way to query an attack type. For now, I'm going to hard code this.
    model_dir = os.path.join(MODELS_DIR, 'voicebox')
    checkpoint_path = os.path.join(model_dir, 'voicebox_final.pt')
    config_path = os.path.join(model_dir, 'voicebox_final.yaml')

    with open(config_path) as f:
        config = yaml.safe_load(f)

    # NOTE(review): torch.load unpickles arbitrary objects — only load
    # checkpoints from trusted sources.
    state_dict = torch.load(checkpoint_path, map_location=device)
    condition_tensor = torch.load(conditioning_path, map_location=device)

    model = VoiceBoxStreamer(**config)
    model.load_state_dict(state_dict)
    # Conditioning vector is broadcast as (batch=1, time=1, features).
    model.condition_vector = condition_tensor.reshape(1, 1, -1)

    streamer = Streamer(
        model=model,
        device=device,
        lookahead_frames=config['bottleneck_lookahead_frames'],
        recurrent=True
    )
    return streamer


def to_model(x: np.ndarray, device: str) -> torch.Tensor:
    """Reshape a (frames, channels) audio buffer to a (1, 1, n) tensor on `device`."""
    return torch.Tensor(x).view(1, 1, -1).to(device)


def from_model(x: torch.Tensor) -> np.ndarray:
    """Convert model output back to a (n, 1) numpy array for the output stream."""
    return x.detach().cpu().view(-1, 1).numpy()


@argbind.bind(without_prefix=True)
def main(
    input: str = None,
    output: str = '',
    device: str = 'cpu',
    num_frames: int = 4,
    pass_through: bool = False,
    conditioning_path: str = CONDITIONING_FILENAME
):
    # NOTE: this must be a plain docstring (not an f-string) — f-string
    # literals are evaluated and discarded, leaving __doc__ empty, which
    # breaks argbind's generated CLI help.
    """
    Uses a streaming implementation of an attack to perturb incoming audio

    :param input: Index or name of input audio interface. Defaults to the
        system default device
    :type input: str, optional
    :param output: Index or name of output audio interface. Defaults to the
        system default device
    :type output: str, optional
    :param device: Device on which to process the attack. Should be either
        'cpu' or 'cuda:X'. Defaults to 'cpu'
    :type device: str, optional
    :param num_frames: Number of overlapping model frames to process at one
        iteration. Defaults to 4
    :type num_frames: int
    :param pass_through: If True, the voicebox perturbation is not applied and
        the input will be identical to the output. This is for demo purposes.
        The input and output audio will remain at 16 kHz
    :type pass_through: bool, optional
    :param conditioning_path: Path to conditioning tensor.
        Defaults to CONDITIONING_FILENAME
    :type conditioning_path: str
    """
    streamer = get_model_streamer(device, conditioning_path)
    input_stream, output_stream = get_streams(input, output, streamer.hop_length)

    # Overlap-add windows need (num_frames - 1) hops plus one full window;
    # other window types consume exactly one hop per iteration.
    if streamer.win_type in ['hann', 'triangular']:
        input_samples = (num_frames - 1) * streamer.hop_length + streamer.window_length
    else:
        input_samples = streamer.hop_length

    print("Ready to process audio")
    input_stream.start()
    output_stream.start()
    try:
        while True:
            frames, _overflow = input_stream.read(input_samples)
            if pass_through:
                # Demo mode: echo input straight to output, no model pass.
                output_stream.write(frames)
                continue
            out = streamer.feed(to_model(frames, device))
            output_stream.write(from_model(out))
    except KeyboardInterrupt:
        print("Stopping")
    finally:
        # Always release the audio devices, even on unexpected errors
        # (previously they leaked on any exception other than Ctrl-C).
        input_stream.stop()
        output_stream.stop()


if __name__ == "__main__":
    args = argbind.parse_args()
    with argbind.scope(args):
        main()