File size: 4,450 Bytes
957e2dc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
import argbind
import sounddevice as sd
import numpy as np
import yaml
import torch
import os
from typing import Union

import sys
import warnings

sys.path.append('.')
warnings.filterwarnings('ignore', category=UserWarning)

from src.data.dataproperties import DataProperties
from src.attacks.online import Streamer, VoiceBoxStreamer
from src.constants import MODELS_DIR, CONDITIONING_FILENAME


def get_streams(input_name: str, output_name: str, block_size: int) -> tuple[sd.InputStream, sd.OutputStream]:
    """
    Get input and output audio stream objects for the given devices.

    :param input_name: Input device index (as a string) or device name.
        May be ``None`` to use the system default device.
    :param output_name: Output device index (as a string) or device name.
    :param block_size: Block size in samples for both streams.
    :return: A ``(InputStream, OutputStream)`` pair, mono, at the
        project sample rate. Streams are created but not started.
    """
    def _as_device(name):
        # sounddevice accepts either an integer index or a name string.
        # Catch TypeError as well as ValueError so that None (meaning
        # "default device") passes through instead of crashing int().
        try:
            return int(name)
        except (TypeError, ValueError):
            return name

    sample_rate = DataProperties.get('sample_rate')
    return (
        sd.InputStream(device=_as_device(input_name),
                       samplerate=sample_rate,
                       channels=1,
                       blocksize=block_size),
        sd.OutputStream(device=_as_device(output_name),
                        samplerate=sample_rate,
                        channels=1,
                        blocksize=block_size)
    )


def get_model_streamer(device: str, conditioning_path: str) -> Streamer:
    """
    Build a recurrent Streamer around the pretrained VoiceBox model.

    Loads the model config (YAML), checkpoint weights, and conditioning
    tensor from disk, then assembles them into a ready-to-use Streamer.

    :param device: Torch device string for loading and inference.
    :param conditioning_path: Path to the saved conditioning tensor.
    :return: A Streamer wrapping the loaded VoiceBox model.
    """
    # TODO: query the attack type instead of hard-coding 'voicebox'.
    voicebox_dir = os.path.join(MODELS_DIR, 'voicebox')

    with open(os.path.join(voicebox_dir, 'voicebox_final.yaml')) as config_file:
        model_config = yaml.safe_load(config_file)

    weights = torch.load(os.path.join(voicebox_dir, 'voicebox_final.pt'),
                         map_location=device)
    conditioning = torch.load(conditioning_path, map_location=device)

    model = VoiceBoxStreamer(**model_config)
    model.load_state_dict(weights)
    # Conditioning is stored flat; the model expects a (1, 1, D) tensor.
    model.condition_vector = conditioning.reshape(1, 1, -1)

    return Streamer(
        model=model,
        device=device,
        lookahead_frames=model_config['bottleneck_lookahead_frames'],
        recurrent=True
    )


def to_model(x: np.ndarray, device: str) -> torch.Tensor:
    """Copy a mono audio buffer into a float32 (1, 1, N) tensor on `device`."""
    audio = torch.tensor(x, dtype=torch.float32)
    return audio.reshape(1, 1, -1).to(device)


def from_model(x: torch.Tensor) -> np.ndarray:
    """Convert a model output tensor to an (N, 1) numpy array for playback."""
    column = x.detach().cpu().reshape(-1, 1)
    return column.numpy()


@argbind.bind(without_prefix=True)
def main(
    input: str = None,
    output: str = '',
    device: str = 'cpu',
    num_frames: int = 4,
    pass_through: bool = False,
    conditioning_path: str = CONDITIONING_FILENAME
):
    # NOTE: a plain (non-f) string so it is actually stored as __doc__;
    # the original f-string was evaluated and discarded, leaving no help text.
    """
    Uses a streaming implementation of an attack to perturb incoming audio

    :param input: Index or name of input audio interface. Defaults to the
        system default device.
    :type input: str, optional
    :param output: Index or name of output audio interface. Defaults to ''.
    :type output: str, optional
    :param device: Device to process the attack on. Should be either 'cpu'
        or 'cuda:X'. Defaults to 'cpu'.
    :type device: str, optional
    :param num_frames: Number of overlapping model frames to process at one
        iteration. Defaults to 4.
    :type num_frames: int
    :param pass_through: If True, the voicebox perturbation is not applied and
        the input will be identical to the output. This is for demo purposes.
        The input and output audio will remain at 16 kHz.
    :type pass_through: bool, optional
    :param conditioning_path: Path to conditioning tensor.
        Default: CONDITIONING_FILENAME from src.constants.
    :type conditioning_path: str
    """
    streamer = get_model_streamer(device, conditioning_path)
    input_stream, output_stream = get_streams(input, output, streamer.hop_length)

    # Overlapping window types need enough samples to cover `num_frames`
    # frames; otherwise a single hop of samples suffices per iteration.
    if streamer.win_type in ['hann', 'triangular']:
        input_samples = (num_frames - 1) * streamer.hop_length + streamer.window_length
    else:
        input_samples = streamer.hop_length

    print("Ready to process audio")
    input_stream.start()
    output_stream.start()
    try:
        while True:
            # Overflow status from the device is intentionally ignored.
            frames, _ = input_stream.read(input_samples)
            if pass_through:
                # Demo mode: echo input directly without perturbation.
                output_stream.write(frames)
                continue
            out = streamer.feed(to_model(frames, device))
            output_stream.write(from_model(out))
    except KeyboardInterrupt:
        print("Stopping")
    finally:
        # Always release the audio streams, even on unexpected errors
        # (the original only stopped them on KeyboardInterrupt).
        input_stream.stop()
        output_stream.stop()


if __name__ == "__main__":
    # Parse CLI arguments via argbind, then run main() inside the
    # bound-argument scope so the decorated defaults are overridden.
    args = argbind.parse_args()
    with argbind.scope(args):
        main()