Spaces:
Sleeping
Sleeping
File size: 6,410 Bytes
f872c8a 33158ae f872c8a 17f3204 171115d 12a037e 33158ae 12a037e f872c8a 12a037e 8c0aa35 12a037e 5f82d8a 12a037e 8c0aa35 12a037e 8c0aa35 f872c8a 8c0aa35 f872c8a 12a037e 33158ae f872c8a 8c0aa35 f872c8a 8c0aa35 33158ae f872c8a 12a037e 5f82d8a 8c0aa35 dff04ad 33158ae 12a037e 33158ae 8c0aa35 f872c8a 33158ae 95af785 33158ae 12a037e 33158ae 12a037e 33158ae 12a037e 8c0aa35 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 |
import os
import sys
import uuid
from pathlib import Path
from contextlib import contextmanager
import ruptures as rpt
import numpy as np
import torch
import gradio as gr
import librosa
from pyharp.core import ModelCard, build_endpoint
from pyharp.media.audio import save_audio
from pyharp import LabelList, AudioLabel, OutputLabel
from audiotools import AudioSignal
from audioseal import AudioSeal
# Loudness target (dB) that process_fn normalizes its working copy of the input to.
LOUDNESS_DB = -16.0
# Rate the signal is resampled to before it is split into bands.
SAMPLE_RATE = 48_000
# Default `sample_rate` argument of split_bands.
ENCODEC_SAMPLE_RATE = 16_000
# Rate passed to the AudioSeal generator and detector calls.
AUDIOSEAL_SAMPLE_RATE = 16_000

# Metadata describing this model; handed to build_endpoint when the UI is built.
model_card = ModelCard(
    name="Meta AudioSeal Watermarking",
    description=(
        "Meta AudioSeal watermarking generation and detection model\n"
        "The watermark is applied under 16kHz."
    ),
    author=(
        "Robin San Roman, Pierre Fernandez, Alexandre Défossez, "
        "Teddy Furon, Tuan Tran, Hady Elsahar"
    ),
    tags=["watermarking"],
)

# Both checkpoints are loaded once at import time and switched to eval mode.
print("Initializing AudioSeal model...")
generator = AudioSeal.load_generator("audioseal_wm_16bits")
detector = AudioSeal.load_detector("audioseal_detector_16bits")
for _model in (generator, detector):
    _model.eval()
def load_audio(audio_path):
    """Decode an audio file into a mono waveform using librosa.

    Args:
        audio_path: Path of the audio file to read.

    Returns:
        The ``(wav, sr)`` pair produced by
        ``librosa.load(audio_path, mono=True)``.

    Raises:
        ValueError: If loading fails for any reason. The underlying
            exception is attached as ``__cause__`` so the original
            traceback is not lost.
    """
    try:
        # NOTE(review): no `sr=` is passed, so librosa's default target rate
        # applies rather than the file's native rate — confirm this is intended.
        wav, sr = librosa.load(audio_path, mono=True)
        return wav, sr
    except Exception as e:
        print(f"Audio preprocessing failed: {e}")
        # Chain explicitly so callers can inspect the real failure.
        raise ValueError(f"Failed to load audio: {e}") from e
@torch.no_grad()
def split_bands(signal: AudioSignal, sample_rate: float = ENCODEC_SAMPLE_RATE):
    """Split ``signal`` into a low band and a high band around ``sample_rate // 2``.

    Args:
        signal: Signal to split; it is cloned before each filter is applied.
        sample_rate: Rate whose half (floor division) is used as the crossover
            reference, and to which the low band is resampled.

    Returns:
        ``(low, high, loud_db)``: the low-passed band resampled to
        ``sample_rate``, the high-passed band, and the loudness of the low
        band measured before it was resampled.
    """
    nyquist = sample_rate // 2
    # The two cutoffs overlap: high-pass starts at 95% and low-pass ends at
    # 105% of the crossover frequency.
    high_cutoff = int(nyquist * 0.95)
    low_cutoff = int(nyquist * 1.05)

    high_band = signal.clone().high_pass(cutoffs=high_cutoff, zeros=51)
    low_band = signal.clone().low_pass(cutoffs=low_cutoff, zeros=51)
    loudness = low_band.loudness()
    return low_band.resample(sample_rate), high_band, loudness
@torch.no_grad()
def merge_bands(low, high, loud_db):
    """Recombine a low band with its high band.

    Args:
        low: Low-band signal; a clone is moved to ``high``'s device and
            resampled to ``high``'s sample rate.
        high: High-band signal that defines the target device, rate and length.
        loud_db: Loudness the low band is normalized to before summing.

    Returns:
        The sum of the length-matched, loudness-normalized low band and ``high``.
    """
    target_len = high.signal_length
    aligned = low.clone().to(high.device).resample(high.sample_rate)

    # Force the low band to exactly `target_len` samples: cut any excess first,
    # then zero-pad on the right if it is still short.
    aligned.audio_data = aligned.audio_data[..., :target_len]
    shortfall = max(0, target_len - aligned.signal_length)
    aligned.audio_data = torch.nn.functional.pad(aligned.audio_data, (0, shortfall))

    return aligned.normalize(loud_db) + high
@torch.no_grad()
def embed(signal: AudioSignal, embedder: torch.nn.Module):
    """Add the embedder's watermark to the low band of ``signal``.

    Args:
        signal: Input signal; it is cloned, so the argument is not modified here.
        embedder: Model exposing ``get_watermark(audio, sample_rate)``.

    Returns:
        The watermarked signal, resampled back to ``signal``'s sample rate and
        reshaped back to its channel count.
    """
    orig_sr = signal.sample_rate
    orig_ch = signal.num_channels
    is_multichannel = orig_ch > 1

    work = signal.clone().resample(SAMPLE_RATE)
    if is_multichannel:
        # Fold channels into the batch axis so every channel is handled as a
        # separate single-channel item.
        batch, channels, samples = work.audio_data.shape
        work.audio_data = work.audio_data.reshape(batch * channels, 1, samples)

    low, high, loud = split_bands(work.clone(), AUDIOSEAL_SAMPLE_RATE)
    watermark = embedder.get_watermark(low.audio_data, AUDIOSEAL_SAMPLE_RATE)
    low.audio_data = low.audio_data + watermark
    merged = merge_bands(low, high, loud)

    if is_multichannel:
        # Undo the channel folding performed above.
        _, folded_ch, samples = merged.audio_data.shape
        merged.audio_data = merged.audio_data.reshape(-1, orig_ch * folded_ch, samples)

    return merged.resample(orig_sr)
@torch.no_grad()
def detect(signal: AudioSignal, detector: torch.nn.Module):
    """Run the watermark detector on a mono, 16 kHz copy of ``signal``.

    Args:
        signal: Input signal; it is cloned before being down-mixed and resampled.
        detector: Model whose ``forward(audio, sample_rate=...)`` returns a
            pair; only the first element is used.

    Returns:
        A NumPy array holding ``result[0, 1, :]`` of the detector's first
        output (presumably the watermark probability over time for the first
        batch item — confirm against the AudioSeal documentation).
    """
    mono = signal.clone().to_mono().resample(AUDIOSEAL_SAMPLE_RATE)
    outputs = detector.forward(mono.audio_data, sample_rate=AUDIOSEAL_SAMPLE_RATE)
    result, _ = outputs
    scores = result[0, 1, :]
    return scores.detach().cpu().numpy()
def process_fn(inp_audio, option_text):
    """Watermark an audio file, or score it for an existing watermark.

    Args:
        inp_audio: Path of the input audio file (passed to ``load_audio``).
        option_text: ``"Generate Watermark"`` embeds a watermark; any other
            value runs detection.

    Returns:
        ``(audio, labels)``. When generating, ``audio`` is the result of
        ``save_audio`` on the watermarked signal and ``labels`` holds one
        label spanning the whole file. When detecting, ``audio`` is
        ``inp_audio`` unchanged and ``labels`` holds one label per segment of
        roughly constant detector score (empty if no scores were produced).

    Raises:
        ValueError: Propagated from ``load_audio`` when the file cannot be read.
    """
    audio_np, sr = load_audio(inp_audio)
    print(f"sr: {sr}, audio shape: {audio_np.shape}")
    if audio_np.ndim == 1:
        audio_np = audio_np[None, None, :]
    else:
        # librosa returns multi-channel audio as (channels, samples), which is
        # already the trailing layout of (batch, channels, samples); only the
        # batch axis has to be added. Transposing here would swap channels and
        # samples.
        audio_np = audio_np[None, ...]
    print(f"formatted audio: {audio_np.shape}")

    ori_sig = AudioSignal(torch.from_numpy(audio_np).float(), sample_rate=sr)
    # Measure loudness before the chain below, so the output can be restored to it.
    orig_loud = ori_sig.loudness()
    sig = ori_sig.to_mono().resample(SAMPLE_RATE).normalize(LOUDNESS_DB).ensure_max_of_audio()
    output_labels = LabelList()

    if option_text == "Generate Watermark":
        with torch.no_grad():
            wm_sig = embed(sig.clone(), generator).normalize(orig_loud).ensure_max_of_audio()
        output_labels.labels.append(
            AudioLabel(
                t=0,
                label="watermark: 1.0",
                duration=wm_sig.duration,
                description=f"watermark confidence: 1.0, start: 0.0s, end: {wm_sig.duration:.2f}s",
                color=OutputLabel.rgb_color_to_int(255, 0, 0),
                amplitude=1.0,
            )
        )
        return save_audio(wm_sig), output_labels

    with torch.no_grad():
        scores = detect(sig, detector)  # sampled at AUDIOSEAL_SAMPLE_RATE

    # Average the scores over 10 ms frames; label times are derived from the
    # same hop so the two cannot drift apart.
    hop = int(0.01 * AUDIOSEAL_SAMPLE_RATE)
    frames_per_second = AUDIOSEAL_SAMPLE_RATE / hop
    avg_curve = np.array(
        [np.mean(scores[start:start + hop]) for start in range(0, len(scores), hop)]
    )
    print(avg_curve.shape)
    if avg_curve.size == 0:
        # Nothing to segment; skip change-point detection on an empty curve.
        return inp_audio, output_labels

    # Change-point detection splits the curve into segments of similar score.
    bkps = rpt.Pelt(model="l2", min_size=1).fit_predict(avg_curve, 1.0)
    t0 = 0
    for t1 in bkps:
        print(t0, t1)
        # Plain Python float keeps NumPy scalar types out of the label fields.
        value = float(avg_curve[t0:t1].mean())
        start_s = t0 / frames_per_second
        end_s = t1 / frames_per_second
        output_labels.labels.append(
            AudioLabel(
                t=start_s,
                label=f"watermark: {value:.2f}",
                duration=(t1 - t0) / frames_per_second,
                description=(
                    f"watermark confidence: {value:.2f}, "
                    f"start: {start_s:.2f}s, end: {end_s:.2f}s"
                ),
                # Red channel grows and green channel shrinks with the score.
                color=OutputLabel.rgb_color_to_int(int(value * 255), int((1 - value) * 255), 0),
                amplitude=value * 2 - 1,
            )
        )
        t0 = t1
    return inp_audio, output_labels
# Gradio UI: one audio input plus a mode selector, wired to process_fn through
# build_endpoint.
with gr.Blocks() as app:
    gr.Markdown("## Meta AudioSeal Watermarking")

    # Inputs
    input_audio = gr.Audio(
        sources=["upload", "microphone"],
        type="filepath",
        label="Input Audio",
    )
    mode_choices = ["Generate Watermark", "Detect Watermark"]
    option_dropdown = gr.Dropdown(
        mode_choices,
        value=mode_choices[0],
        label="Option",
        info="Model Options",
    )

    # Outputs
    output_wav = gr.Audio(label="Watermarked Speech", type="filepath")
    output_label = gr.JSON(label="Watermark Confidence")

    # The component order must match process_fn's parameters and return values.
    _ = build_endpoint(
        model_card=model_card,
        input_components=[input_audio, option_dropdown],
        output_components=[output_wav, output_label],
        process_fn=process_fn,
    )
# Start the Gradio server only when this file is executed as a script.
if __name__ == "__main__":
    app.launch(share=True, show_error=True, debug=True)