Spaces:
Sleeping
Sleeping
File size: 4,450 Bytes
957e2dc |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 |
import argbind
import sounddevice as sd
import numpy as np
import yaml
import torch
import os
from typing import Union
import sys
import warnings
# Put the current working directory on the import path before the `src.*`
# imports below run; presumably the script is meant to be launched from the
# repository root - TODO confirm.
sys.path.append('.')
# Silence every UserWarning for the remainder of the process.
warnings.filterwarnings('ignore', category=UserWarning)
from src.data.dataproperties import DataProperties
from src.attacks.online import Streamer, VoiceBoxStreamer
from src.constants import MODELS_DIR, CONDITIONING_FILENAME
def _resolve_device(spec):
    """Turn a command-line device argument into an index, a name, or None."""
    if spec is None:
        # No device requested: hand None to sounddevice unchanged instead of
        # letting int(None) raise a TypeError.
        return None
    try:
        return int(spec)
    except (TypeError, ValueError):
        # Not an integer index, so keep it as given (e.g. a device name).
        return spec


def get_streams(input_name: Union[str, int, None], output_name: Union[str, int, None], block_size: int) -> tuple[sd.InputStream, sd.OutputStream]:
    """
    Gets Input and Output stream objects

    Both streams are mono and use the sample rate stored under
    ``'sample_rate'`` in ``DataProperties``. The streams are created but not
    started.

    :param input_name: Index or name of the input audio interface. Values that
        parse as an integer are used as a device index, anything else is passed
        on as given. ``None`` is passed through unchanged.
    :type input_name: str or int, optional
    :param output_name: Index or name of the output audio interface, handled
        the same way as ``input_name``.
    :type output_name: str or int, optional
    :param block_size: Block size, in frames, used for both streams.
    :type block_size: int
    :return: The ``(input_stream, output_stream)`` pair.
    :rtype: tuple[sd.InputStream, sd.OutputStream]
    """
    sample_rate = DataProperties.get('sample_rate')

    input_stream = sd.InputStream(device=_resolve_device(input_name),
                                  samplerate=sample_rate,
                                  channels=1,
                                  blocksize=block_size)
    try:
        output_stream = sd.OutputStream(device=_resolve_device(output_name),
                                        samplerate=sample_rate,
                                        channels=1,
                                        blocksize=block_size)
    except Exception:
        # Do not leave the already opened input stream dangling when the
        # output side cannot be opened.
        input_stream.close()
        raise
    return input_stream, output_stream
def get_model_streamer(device: str, conditioning_path: str) -> Streamer:
    """
    Builds the Streamer for the hard-coded VoiceBox checkpoint

    The YAML configuration and the weights are read from the ``voicebox``
    folder under ``MODELS_DIR``; the conditioning tensor is read from
    ``conditioning_path``. Both torch files are mapped onto ``device``.

    :param device: Device passed to ``torch.load`` and to the Streamer
    :type device: str
    :param conditioning_path: Path to the saved conditioning tensor
    :type conditioning_path: str
    :return: A recurrent Streamer wrapping the loaded model
    :rtype: Streamer
    """
    # TODO: Make a good way to query an attack type. For now, I'm going to hard code this.
    voicebox_dir = os.path.join(MODELS_DIR, 'voicebox')
    weights_file = os.path.join(voicebox_dir, 'voicebox_final.pt')
    config_file = os.path.join(voicebox_dir, 'voicebox_final.yaml')

    with open(config_file) as handle:
        hparams = yaml.safe_load(handle)

    weights = torch.load(weights_file, map_location=device)
    conditioning = torch.load(conditioning_path, map_location=device)

    # The configuration entries double as the model's constructor arguments.
    voicebox = VoiceBoxStreamer(**hparams)
    voicebox.load_state_dict(weights)
    # Store the conditioning as a (1, 1, N) tensor on the model.
    voicebox.condition_vector = conditioning.reshape(1, 1, -1)

    return Streamer(
        model=voicebox,
        device=device,
        lookahead_frames=hparams['bottleneck_lookahead_frames'],
        recurrent=True
    )
def to_model(x: np.ndarray, device: str) -> torch.Tensor:
    """
    Converts an audio block into a (1, 1, num_samples) tensor on ``device``
    """
    as_tensor = torch.Tensor(x)
    # Collapse everything into the last axis behind two singleton axes.
    batched = as_tensor.view(1, 1, -1)
    return batched.to(device)
def from_model(x: torch.Tensor) -> np.ndarray:
    """
    Converts a model output tensor into a (num_samples, 1) numpy array
    """
    # Drop autograd tracking and bring the data to host memory first.
    on_host = x.detach().cpu()
    as_column = on_host.view(-1, 1)
    return as_column.numpy()
@argbind.bind(without_prefix=True)
def main(
    input: str = None,
    output: str = '',
    device: str = 'cpu',
    num_frames: int = 4,
    pass_through: bool = False,
    conditioning_path: str = CONDITIONING_FILENAME
):
    """
    Uses a streaming implementation of an attack to perturb incoming audio

    Runs until interrupted with Ctrl-C. Both audio streams are stopped and
    closed when the loop ends, whether by interrupt or by an error.

    :param input: Index or name of input audio interface. Defaults to None.
    :type input: str, optional
    :param output: Index or name of output audio interface. Defaults to an
        empty string.
    :type output: str, optional
    :param device: Device used to process the attack. Should be either 'cpu'
        or 'cuda:X'. Defaults to 'cpu'.
    :type device: str, optional
    :param num_frames: Number of overlapping model frames to process at one
        iteration. Only used when the streamer's window type is 'hann' or
        'triangular'. Defaults to 4.
    :type num_frames: int, optional
    :param pass_through: If True, the voicebox perturbation is not applied and
        the captured audio is written to the output unchanged. This is for demo
        purposes. The model is still loaded because its hop length sets the
        stream block size. Defaults to False.
    :type pass_through: bool, optional
    :param conditioning_path: Path to conditioning tensor. Defaults to
        ``CONDITIONING_FILENAME`` from ``src.constants``.
    :type conditioning_path: str, optional
    """
    # NOTE: this must be a plain string literal. An f-string in this position
    # is an ordinary expression, so the function would have no __doc__.
    streamer = get_model_streamer(device, conditioning_path)
    input_stream, output_stream = get_streams(input, output, streamer.hop_length)

    if streamer.win_type in ['hann', 'triangular']:
        # Windowed processing: read enough samples for `num_frames` windows
        # spaced one hop apart.
        input_samples = (num_frames - 1) * streamer.hop_length + streamer.window_length
    else:
        input_samples = streamer.hop_length

    print("Ready to process audio")
    # Entering the streams starts them; leaving the block stops and closes
    # them on every exit path, not only on Ctrl-C.
    with input_stream, output_stream:
        try:
            while True:
                # read() also reports an overflow flag; it is not acted upon.
                frames, _ = input_stream.read(input_samples)
                if pass_through:
                    output_stream.write(frames)
                    continue
                perturbed = streamer.feed(to_model(frames, device))
                output_stream.write(from_model(perturbed))
        except KeyboardInterrupt:
            print("Stopping")
if __name__ == "__main__":
    # Parse the command line, then call main() inside the argbind scope so the
    # bound function is invoked with the parsed values in effect.
    parsed_args = argbind.parse_args()
    with argbind.scope(parsed_args):
        main()
|