Spaces:
Runtime error
Runtime error
Update audio_diffusion_attacks_forhf/src/test_encoder_attack.py
Browse files
audio_diffusion_attacks_forhf/src/test_encoder_attack.py
CHANGED
|
@@ -36,7 +36,6 @@ from audiocraft.losses import (
|
|
| 36 |
'''
|
| 37 |
from audio_diffusion_attacks_forhf.src.music_gen import MusicGenEval
|
| 38 |
from audio_diffusion_attacks_forhf.src.speech_inference import XTTS_Eval
|
| 39 |
-
print("breakpoint 5")
|
| 40 |
|
| 41 |
# From https://pytorch.org/tutorials/beginner/audio_preprocessing_tutorial.html#loading-audio-data-into-tensor
|
| 42 |
def print_stats(waveform, sample_rate=None, src=None):
|
|
@@ -87,12 +86,14 @@ def poison_audio(waveform, sample_rate, encoders, audio_difference_weights=[1],
|
|
| 87 |
audio_folder: string, path to folder of audio files. Protected audio files will be saved in that folder.
|
| 88 |
encoders: encoders to protect against. See initialization at end of file.
|
| 89 |
'''
|
|
|
|
| 90 |
for encoder in encoders:
|
| 91 |
-
encoder.to(device='cuda')
|
| 92 |
encoder.eval()
|
| 93 |
for p in encoder.parameters():
|
| 94 |
p.requires_grad = False
|
| 95 |
|
|
|
|
| 96 |
audio_len=1000000
|
| 97 |
waveform, sample_rate = torchaudio.load(f"test_audio/Texas Sun.mp3")
|
| 98 |
if modality=="music":
|
|
|
|
| 36 |
'''
|
| 37 |
from audio_diffusion_attacks_forhf.src.music_gen import MusicGenEval
|
| 38 |
from audio_diffusion_attacks_forhf.src.speech_inference import XTTS_Eval
|
|
|
|
| 39 |
|
| 40 |
# From https://pytorch.org/tutorials/beginner/audio_preprocessing_tutorial.html#loading-audio-data-into-tensor
|
| 41 |
def print_stats(waveform, sample_rate=None, src=None):
|
|
|
|
| 86 |
audio_folder: string, path to folder of audio files. Protected audio files will be saved in that folder.
|
| 87 |
encoders: encoders to protect against. See initialization at end of file.
|
| 88 |
'''
|
| 89 |
+
print("breakpoint 1")
|
| 90 |
for encoder in encoders:
|
| 91 |
+
#Andy removed: encoder.to(device='cuda')
|
| 92 |
encoder.eval()
|
| 93 |
for p in encoder.parameters():
|
| 94 |
p.requires_grad = False
|
| 95 |
|
| 96 |
+
print("breakpoint 2")
|
| 97 |
audio_len=1000000
|
| 98 |
waveform, sample_rate = torchaudio.load(f"test_audio/Texas Sun.mp3")
|
| 99 |
if modality=="music":
|