# Provenance: commit cda8304 ("Add application file"). The original header lines
# were a repository-page scrape artifact and are preserved here as a comment so
# the module parses.
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio
import sys
import os
# Import sibling modules
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
from deep_watermark import WatermarkDiscriminator
from robust_watermark import WatermarkDetector
import numpy as np
import math
try:
# Try importing audioseal. If not installed, we might mock it or fail.
# The user requested integration, so we assume it's available or we install it.
# For this code, we'll use torch.hub if package not found, or just assume it's there.
import audioseal
except ImportError:
print("Warning: 'audioseal' not found. AudioSeal branch will fail if used.")
class WatermarkExpertBranch(nn.Module):
    """
    Differentiable front end for the robust spectral watermark detector.

    Correlates a whitened STFT magnitude spectrogram against the detector's
    known watermark block and summarizes the correlation scores as three
    statistics (max / mean / std), returned as a (B, 3) feature tensor.
    """

    def __init__(self, sample_rate=16000, n_fft=1024, hop_length=256):
        super().__init__()
        self.detector = WatermarkDetector(sample_rate=sample_rate, n_fft=n_fft, hop_length=hop_length)
        # Register the watermark block as a buffer so it moves with the model.
        self.register_buffer('watermark_kernel', self.detector.watermark_block.unsqueeze(1))
        # Fix: the analysis window is likewise registered as a buffer so that
        # .to(device)/.cuda() move it with the model (a plain attribute would
        # not). persistent=False keeps it out of the state_dict, so checkpoints
        # saved before this change still load with strict=True.
        self.register_buffer('window', torch.hann_window(n_fft), persistent=False)
        self.n_fft = n_fft
        self.hop_length = hop_length

    def forward(self, waveform):
        """waveform: (B, 1, T) -> (B, 3) [max, mean, std] correlation stats."""
        # 1. STFT magnitude. The .to() is now a no-op safety net, since the
        # window buffer already follows the module's device.
        window = self.window.to(waveform.device)
        stft = torch.stft(waveform.squeeze(1), n_fft=self.n_fft, hop_length=self.hop_length,
                          window=window, return_complex=True, center=True)
        magnitude = torch.abs(stft)  # (B, F, T)
        # 2. Whitening: subtract a temporally smoothed spectrum, then scale to
        # unit standard deviation per sample so correlation scores are comparable.
        mag_unsqueezed = magnitude.unsqueeze(1)  # (B, 1, F, T)
        smoothed = F.avg_pool2d(
            mag_unsqueezed,
            kernel_size=(1, 15),
            stride=1,
            padding=(0, 7)
        )
        whitened = mag_unsqueezed - smoothed
        whitened = whitened / (torch.std(whitened, dim=(2, 3), keepdim=True) + 1e-6)
        # 3. Correlation against the zero-mean, unit-norm watermark kernel.
        # Kernel shape presumably (1, 1, F, T_block) — TODO confirm against
        # WatermarkDetector.watermark_block.
        kernel = self.watermark_kernel
        kernel = kernel - torch.mean(kernel)
        kernel = kernel / (torch.norm(kernel) + 1e-6)
        # conv2d with a full-height kernel collapses the frequency axis.
        correlation_map = F.conv2d(whitened, kernel)  # (B, 1, 1, T_out)
        scores = correlation_map.squeeze(1).squeeze(1)  # (B, T_out)
        # 4. Summary statistics over time.
        max_score = torch.max(scores, dim=1, keepdim=True)[0]
        mean_score = torch.mean(scores, dim=1, keepdim=True)
        std_score = torch.std(scores, dim=1, keepdim=True)
        return torch.cat([max_score, mean_score, std_score], dim=1)  # (B, 3)
class SynthArtifactBranch(nn.Module):
    """
    ResNet-like CNN that detects synthetic artifacts from Mel-spectrograms.

    Input:  (B, 1, T) waveform
    Output: (B, 32) embedding
    """

    def __init__(self, sample_rate=16000):
        # Fix: the class originally defined __init__ twice; the first
        # (parameter-less) definition was dead code, completely shadowed by
        # this one, and has been removed.
        super().__init__()
        self.mel_transform = torchaudio.transforms.MelSpectrogram(sample_rate=sample_rate, n_mels=64)
        self.amplitude_to_db = torchaudio.transforms.AmplitudeToDB()
        self.conv1 = nn.Conv2d(1, 16, kernel_size=3, stride=1, padding=1)
        self.bn1 = nn.BatchNorm2d(16)
        self.relu = nn.ReLU(inplace=True)
        self.layer1 = nn.Sequential(
            nn.Conv2d(16, 32, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True)
        )
        self.layer2 = nn.Sequential(
            nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True)
        )
        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(64, 32)  # Embedding size 32

    def forward(self, x):
        """x: (B, 1, T) waveform -> (B, 32) embedding."""
        mel = self.mel_transform(x)      # (B, 1, n_mels, time)
        mel = self.amplitude_to_db(mel)  # log-compress the dynamic range
        x = self.conv1(mel)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.pool(x)
        x = x.flatten(1)
        x = self.fc(x)
        return x  # (B, 32)
class LFCCBranch(nn.Module):
    """
    Embeds Linear-Frequency Cepstral Coefficients (LFCC) with a small
    ResNet-style CNN. LFCCs are well suited to exposing the high-frequency
    artifacts typical of synthetic speech.

    Input:  (B, 1, T) waveform
    Output: (B, 32) embedding
    """

    def __init__(self, sample_rate=16000):
        super().__init__()
        # LFCC front end (linear filterbank cepstra).
        self.lfcc_transform = torchaudio.transforms.LFCC(
            sample_rate=sample_rate,
            n_lfcc=40,
            speckwargs={"n_fft": 1024, "win_length": 400, "hop_length": 160}
        )
        # ResNet-style encoder mirroring the other spectral branches.
        self.conv1 = nn.Conv2d(1, 16, kernel_size=3, stride=1, padding=1)
        self.bn1 = nn.BatchNorm2d(16)
        self.relu = nn.ReLU(inplace=True)
        self.layer1 = nn.Sequential(
            nn.Conv2d(16, 32, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True)
        )
        self.layer2 = nn.Sequential(
            nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True)
        )
        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(64, 32)  # 32-dim embedding

    def forward(self, x):
        """x: (B, 1, T) waveform -> (B, 32) embedding."""
        feats = self.lfcc_transform(x)  # (B, 1, n_lfcc, time)
        # Cepstral coefficients are already log-domain; no dB conversion needed.
        h = self.relu(self.bn1(self.conv1(feats)))
        h = self.layer2(self.layer1(h))
        return self.fc(self.pool(h).flatten(1))  # (B, 32)
class AudioSealBranch(nn.Module):
    """
    Wraps the frozen, pre-trained AudioSeal watermark detector and maps its
    per-frame probability output to a 32-dim embedding.

    If the detector cannot be loaded (e.g. 'audioseal' is not installed),
    the branch degrades gracefully and emits all-zero embeddings.
    """

    def __init__(self):
        super().__init__()
        try:
            # Only the detector half of AudioSeal is needed here.
            self.detector = audioseal.AudioSeal.load_detector("audioseal_detector_16bits")
            self.detector.eval()  # Freeze
            for param in self.detector.parameters():
                param.requires_grad = False
        except Exception as e:
            print(f"Error loading AudioSeal: {e}")
            self.detector = None
        # Trainable head: time-average the (B, 2, T) probability map, then
        # project the 2 pooled values to a 32-dim embedding.
        self.pool = nn.AdaptiveAvgPool1d(1)
        self.fc = nn.Linear(2, 32)

    def forward(self, waveform):
        """waveform: (B, 1, T) at 16 kHz (presumably) -> (B, 32) embedding."""
        if self.detector is None:
            return torch.zeros(waveform.size(0), 32).to(waveform.device)
        # The detector is frozen, so its forward runs without autograd.
        with torch.no_grad():
            result = self.detector(waveform, 16000)
        # Some versions return a tuple; the probability map comes first.
        probs = result[0] if isinstance(result, tuple) else result
        if not isinstance(probs, torch.Tensor):
            # Defensive fallback if the detector returned something unexpected.
            return torch.zeros(waveform.size(0), 32).to(waveform.device)
        pooled = self.pool(probs).squeeze(2)  # (B, 2) mean probability over time
        return self.fc(pooled)                # (B, 32)
class SincConv(nn.Module):
    """
    Sinc-based 1-D convolution (SincNet-style).

    Each output channel is a learnable band-pass filter parameterized by its
    low cutoff and bandwidth; the bank is initialized uniformly on the mel
    scale. Filters are built in the time domain on every forward pass so the
    cutoffs remain differentiable.
    """

    def __init__(self, out_channels, kernel_size, sample_rate=16000, min_low_hz=50, min_band_hz=50):
        super().__init__()
        # Force an odd kernel so the filter has a well-defined center tap.
        if kernel_size % 2 == 0:
            kernel_size = kernel_size + 1
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.sample_rate = sample_rate
        self.min_low_hz = min_low_hz
        self.min_band_hz = min_band_hz
        # Mel-scale initialization of the band edges.
        low_hz = 30
        high_hz = sample_rate / 2 - (min_low_hz + min_band_hz)
        mel = np.linspace(self.to_mel(low_hz), self.to_mel(high_hz), self.out_channels + 1)
        hz = self.to_hz(mel)
        # Learnable filter parameters: low edge and bandwidth per channel.
        self.low_hz_ = nn.Parameter(torch.Tensor(hz[:-1]).view(-1, 1))
        self.band_hz_ = nn.Parameter(torch.Tensor(np.diff(hz)).view(-1, 1))
        # Hamming window to taper the truncated sinc filters.
        self.window_ = torch.hamming_window(self.kernel_size, periodic=False)

    def to_mel(self, hz):
        """Hz -> mel."""
        return 2595 * np.log10(1 + hz / 700)

    def to_hz(self, mel):
        """mel -> Hz."""
        return 700 * (10 ** (mel / 2595) - 1)

    def forward(self, x):
        """x: (B, 1, T) -> (B, out_channels, T - kernel_size + 1); no padding."""
        # Effective band edges: low kept positive, high clamped below Nyquist.
        low = self.min_low_hz + torch.abs(self.low_hz_)
        high = torch.clamp(low + self.min_band_hz + torch.abs(self.band_hz_),
                           self.min_low_hz, self.sample_rate / 2)
        # Fix: the original also computed f_times_t_low/f_times_t_high (two
        # matmuls) and `band`, none of which were used — dead code removed.
        # Time axis in seconds, one row per channel: (out_channels, kernel_size).
        t = torch.arange(-(self.kernel_size - 1) / 2, (self.kernel_size - 1) / 2 + 1).to(x.device).view(1, -1) / self.sample_rate
        t = t.repeat(self.out_channels, 1)
        # Ideal band-pass in time: 2*f2*sinc(2*f2*t) - 2*f1*sinc(2*f1*t),
        # written as sin(2*pi*f*t) / (pi*t). The 1e-6 guards the t=0 division;
        # that tap is overwritten with the analytic limit below.
        band_pass_left = torch.sin(2 * math.pi * high * t) / (math.pi * t + 1e-6)
        band_pass_right = torch.sin(2 * math.pi * low * t) / (math.pi * t + 1e-6)
        # At t = 0 the limit of 2*f*sinc(2*f*t) is 2*f.
        center_idx = int((self.kernel_size - 1) / 2)
        band_pass_left[:, center_idx] = 2 * high[:, 0]
        band_pass_right[:, center_idx] = 2 * low[:, 0]
        filters = band_pass_left - band_pass_right
        # Taper with the Hamming window to reduce spectral leakage.
        filters = filters * self.window_.to(filters.device)
        return F.conv1d(x, filters.view(self.out_channels, 1, self.kernel_size))
class RawWaveBranch(nn.Module):
    """
    Branch 6: raw-waveform analysis using SincNet-style layers.
    Detects fine-grained temporal artifacts directly from samples.

    Input:  (B, 1, T) waveform
    Output: (B, 32) embedding
    """

    def __init__(self, sample_rate=16000):
        super().__init__()
        # Learnable band-pass front end operating on raw samples.
        self.sinc_conv = SincConv(out_channels=32, kernel_size=129, sample_rate=sample_rate)
        # Two downsampling CNN stages after the sinc front end.
        self.layer1 = self._down_block(32, 64)
        self.layer2 = self._down_block(64, 128)
        self.pool = nn.AdaptiveAvgPool1d(1)
        self.fc = nn.Linear(128, 32)

    @staticmethod
    def _down_block(c_in, c_out):
        """Stride-2 Conv1d + BN + LeakyReLU + MaxPool (4x total downsampling)."""
        return nn.Sequential(
            nn.Conv1d(c_in, c_out, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm1d(c_out),
            nn.LeakyReLU(0.2),
            nn.MaxPool1d(2)
        )

    def forward(self, x):
        """x: (B, 1, T) -> (B, 32) embedding."""
        h = self.sinc_conv(x)  # (B, 32, T')
        h = self.layer2(self.layer1(h))
        return self.fc(self.pool(h).flatten(1))
class UniversalDetector(nn.Module):
    """
    Multi-branch detector that fuses six audio feature extractors into a
    shared head producing two logits: [Logit_Watermarked, Logit_Synth].
    """

    def __init__(self, sample_rate=16000):
        super().__init__()
        self.watermark_expert = WatermarkExpertBranch(sample_rate)  # Branch 1: correlation stats (3)
        self.synth_artifact = SynthArtifactBranch(sample_rate)      # Branch 2: mel artifacts (32)
        self.lfcc_branch = LFCCBranch(sample_rate)                  # Branch 3: LFCC features (32)
        self.deep_watermark = WatermarkDiscriminator()              # Branch 4: pre-trained discriminator (1)
        self.audioseal_branch = AudioSealBranch()                   # Branch 5: AudioSeal SOTA (32)
        self.raw_wave_branch = RawWaveBranch(sample_rate)           # Branch 6: SincNet raw-wave (32)
        # Fusion MLP over the concatenated features:
        # 3 + 32 + 1 + 32 + 32 + 32 = 132 inputs -> 2 logits.
        self.fusion_head = nn.Sequential(
            nn.Linear(132, 128),
            nn.BatchNorm1d(128),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(64, 2)  # [Logit_Watermarked, Logit_Synth]
        )

    def forward(self, waveform):
        """waveform: (B, 1, T) -> (B, 2) logits."""
        # Concatenation order must stay fixed — the fusion head was sized for it.
        branch_feats = [
            self.watermark_expert(waveform),  # (B, 3)  correlation statistics
            self.synth_artifact(waveform),    # (B, 32) mel-spectrogram embedding
            self.deep_watermark(waveform),    # (B, 1)  discriminator probability
            self.lfcc_branch(waveform),       # (B, 32) LFCC embedding
            self.audioseal_branch(waveform),  # (B, 32) AudioSeal embedding
            self.raw_wave_branch(waveform),   # (B, 32) raw-waveform embedding
        ]
        fused = torch.cat(branch_feats, dim=1)  # (B, 132)
        return self.fusion_head(fused)          # (B, 2)