Spaces:
Sleeping
Sleeping
| import argbind | |
| import sounddevice as sd | |
| import numpy as np | |
| import yaml | |
| import torch | |
| import os | |
| from typing import Union | |
| import sys | |
| import warnings | |
| sys.path.append('.') | |
| warnings.filterwarnings('ignore', category=UserWarning) | |
| from src.data.dataproperties import DataProperties | |
| from src.attacks.online import Streamer, VoiceBoxStreamer | |
| from src.constants import MODELS_DIR, CONDITIONING_FILENAME | |
def _parse_device(spec):
    """Return ``spec`` as an int device index when possible, else unchanged."""
    try:
        return int(spec)
    except (TypeError, ValueError):
        # TypeError covers ``None`` (meaning "use the default device");
        # ValueError covers non-numeric device names. Both are forwarded
        # to sounddevice untouched.
        return spec


def get_streams(input_name: str, output_name: str, block_size: int) -> tuple[sd.InputStream, sd.OutputStream]:
    """
    Gets Input and Output stream objects

    :param input_name: Index or name of the input audio interface. A value that
        parses as an integer is passed to sounddevice as a device index; anything
        else (including ``None``) is passed through unchanged.
    :param output_name: Index or name of the output audio interface, interpreted
        the same way as ``input_name``.
    :param block_size: Block size, in samples, requested for both streams.
    :return: Tuple ``(input_stream, output_stream)``. Both streams are mono and
        use the sample rate reported by ``DataProperties.get('sample_rate')``.
        The streams are created but not started.
    """
    input_device = _parse_device(input_name)
    output_device = _parse_device(output_name)
    sample_rate = DataProperties.get('sample_rate')

    return (
        sd.InputStream(device=input_device,
                       samplerate=sample_rate,
                       channels=1,
                       blocksize=block_size),
        sd.OutputStream(device=output_device,
                        samplerate=sample_rate,
                        channels=1,
                        blocksize=block_size)
    )
def get_model_streamer(device: str, conditioning_path: str) -> Streamer:
    """
    Builds a ``Streamer`` around the VoiceBox model stored under ``MODELS_DIR``

    The model configuration is read from ``voicebox/voicebox_final.yaml`` and the
    weights from ``voicebox/voicebox_final.pt``; the conditioning tensor is loaded
    from ``conditioning_path`` and reshaped to ``(1, 1, -1)`` before being attached
    to the model.

    :param device: Device onto which the checkpoint and conditioning tensor are
        mapped, and which is handed to the ``Streamer``.
    :param conditioning_path: Path to the saved conditioning tensor.
    :return: A recurrent ``Streamer`` wrapping the loaded model.
    """
    # TODO: Make a good way to query an attack type. For now, I'm going to hard code this.
    voicebox_dir = os.path.join(MODELS_DIR, 'voicebox')
    weights_file = os.path.join(voicebox_dir, 'voicebox_final.pt')
    settings_file = os.path.join(voicebox_dir, 'voicebox_final.yaml')

    with open(settings_file) as handle:
        settings = yaml.safe_load(handle)

    weights = torch.load(weights_file, map_location=device)
    conditioning = torch.load(conditioning_path, map_location=device)

    # The YAML file supplies the constructor keyword arguments verbatim.
    voicebox = VoiceBoxStreamer(**settings)
    voicebox.load_state_dict(weights)
    voicebox.condition_vector = conditioning.reshape(1, 1, -1)

    return Streamer(
        model=voicebox,
        device=device,
        lookahead_frames=settings['bottleneck_lookahead_frames'],
        recurrent=True
    )
def to_model(x: np.ndarray, device: str) -> torch.Tensor:
    """
    Converts an array of audio samples into a ``(1, 1, n_samples)`` tensor on ``device``

    :param x: Array of samples; all of its elements are flattened into the last axis.
    :param device: Target torch device for the returned tensor.
    :return: Tensor of shape ``(1, 1, -1)`` located on ``device``.
    """
    samples = torch.Tensor(x)
    batched = samples.view(1, 1, -1)
    return batched.to(device)
def from_model(x: torch.Tensor) -> np.ndarray:
    """
    Converts a model output tensor into a single-column NumPy array

    :param x: Tensor produced by the model; it is detached and moved to the CPU.
    :return: Array of shape ``(n_samples, 1)``.
    """
    host_tensor = x.detach().cpu()
    column = host_tensor.view(-1, 1)
    return column.numpy()
def main(
    input: str = None,
    output: str = '',
    device: str = 'cpu',
    num_frames: int = 4,
    pass_through: bool = False,
    conditioning_path: str = CONDITIONING_FILENAME
):
    """
    Uses a streaming implementation of an attack to perturb incoming audio

    Audio is read from the input stream, passed through the streamer, and written
    to the output stream in a loop until a ``KeyboardInterrupt`` is received.

    :param input: Index or name of input audio interface. Defaults to None
    :type input: str, optional
    :param output: Index or name of output audio interface. Defaults to ''
    :type output: str, optional
    :param device: Device used to process the attack. Should be either 'cpu' or 'cuda:X'.
        Defaults to 'cpu'.
    :type device: str, optional
    :param num_frames: Number of overlapping model frames to process at one iteration.
        Only used when the streamer's window type is 'hann' or 'triangular'.
        Defaults to 4
    :type num_frames: int, optional
    :param pass_through: If True, the voicebox perturbation is not applied and the
        captured audio is written to the output unchanged. This is for demo purposes.
        Defaults to False
    :type pass_through: bool, optional
    :param conditioning_path: Path to conditioning tensor. Defaults to
        ``CONDITIONING_FILENAME``
    :type conditioning_path: str, optional
    """
    # NOTE: this must be a plain string literal. An f-string in this position is
    # not stored as the function's docstring.
    streamer = get_model_streamer(device, conditioning_path)
    input_stream, output_stream = get_streams(input, output, streamer.hop_length)

    # Windowed streamers consume `num_frames` overlapping frames per iteration;
    # otherwise a single hop is read each time.
    if streamer.win_type in ['hann', 'triangular']:
        input_samples = (num_frames - 1) * streamer.hop_length + streamer.window_length
    else:
        input_samples = streamer.hop_length

    print("Ready to process audio")

    # Entering the streams starts them; leaving the block stops and closes them
    # on every exit path, not only on Ctrl-C.
    with input_stream, output_stream:
        try:
            while True:
                frames, _ = input_stream.read(input_samples)
                if pass_through:
                    output_stream.write(frames)
                    continue
                out = streamer.feed(to_model(frames, device))
                output_stream.write(from_model(out))
        except KeyboardInterrupt:
            print("Stopping")
if __name__ == "__main__":
    # Parse the command line with argbind and run main() inside the resulting
    # argument scope.
    # NOTE(review): no @argbind.bind decorator on main is visible in this file;
    # confirm main is bound, otherwise the parsed arguments may not reach it.
    cli_args = argbind.parse_args()
    with argbind.scope(cli_args):
        main()