Spaces:
Build error
Build error
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import torchaudio | |
| import sys | |
| import os | |
| # Import sibling modules | |
| sys.path.append(os.path.dirname(os.path.abspath(__file__))) | |
| from deep_watermark import WatermarkDiscriminator | |
| from robust_watermark import WatermarkDetector | |
| import numpy as np | |
| import math | |
| try: | |
| # Try importing audioseal. If not installed, we might mock it or fail. | |
| # The user requested integration, so we assume it's available or we install it. | |
| # For this code, we'll use torch.hub if package not found, or just assume it's there. | |
| import audioseal | |
| except ImportError: | |
| print("Warning: 'audioseal' not found. AudioSeal branch will fail if used.") | |
| class WatermarkExpertBranch(nn.Module): | |
| """ | |
| Differentiable implementation of the Robust Watermark Detector. | |
| Extracts correlation features. | |
| """ | |
| def __init__(self, sample_rate=16000, n_fft=1024, hop_length=256): | |
| super().__init__() | |
| self.detector = WatermarkDetector(sample_rate=sample_rate, n_fft=n_fft, hop_length=hop_length) | |
| # We need to register the watermark block as a buffer so it moves with the model | |
| self.register_buffer('watermark_kernel', self.detector.watermark_block.unsqueeze(1)) | |
| self.window = torch.hann_window(n_fft) | |
| self.n_fft = n_fft | |
| self.hop_length = hop_length | |
| def forward(self, waveform): | |
| # waveform: (B, 1, T) | |
| # 1. STFT | |
| window = self.window.to(waveform.device) | |
| stft = torch.stft(waveform.squeeze(1), n_fft=self.n_fft, hop_length=self.hop_length, | |
| window=window, return_complex=True, center=True) | |
| magnitude = torch.abs(stft) # (B, F, T) | |
| # 2. Whitening (Spectral Smoothing) | |
| mag_unsqueezed = magnitude.unsqueeze(1) # (B, 1, F, T) | |
| smoothed = torch.nn.functional.avg_pool2d( | |
| mag_unsqueezed, | |
| kernel_size=(1, 15), | |
| stride=1, | |
| padding=(0, 7) | |
| ) | |
| whitened = mag_unsqueezed - smoothed | |
| whitened = whitened / (torch.std(whitened, dim=(2,3), keepdim=True) + 1e-6) | |
| # 3. Correlation | |
| # Kernel: (1, 1, F, T_block) | |
| kernel = self.watermark_kernel | |
| kernel = kernel - torch.mean(kernel) | |
| kernel = kernel / (torch.norm(kernel) + 1e-6) | |
| # Conv2d | |
| # Input: (B, 1, F, T) | |
| # Weight: (1, 1, F, T_block) | |
| correlation_map = torch.nn.functional.conv2d(whitened, kernel) # (B, 1, 1, T_out) | |
| scores = correlation_map.squeeze(1).squeeze(1) # (B, T_out) | |
| # 4. Feature Extraction | |
| # Max score, Mean score, Std score | |
| max_score = torch.max(scores, dim=1, keepdim=True)[0] | |
| mean_score = torch.mean(scores, dim=1, keepdim=True) | |
| std_score = torch.std(scores, dim=1, keepdim=True) | |
| return torch.cat([max_score, mean_score, std_score], dim=1) # (B, 3) | |
| class SynthArtifactBranch(nn.Module): | |
| """ | |
| ResNet-like CNN to detect synthetic artifacts from Mel-Spectrograms. | |
| """ | |
| def __init__(self): | |
| super().__init__() | |
| self.mel_transform = torchaudio.transforms.MelSpectrogram(sample_rate=16000, n_mels=64) | |
| self.conv1 = nn.Conv2d(1, 16, kernel_size=3, stride=1, padding=1) | |
| self.bn1 = nn.BatchNorm2d(16) | |
| self.relu = nn.ReLU() | |
| self.pool = nn.MaxPool2d(2, 2) | |
| def __init__(self, sample_rate=16000): | |
| super().__init__() | |
| self.mel_transform = torchaudio.transforms.MelSpectrogram(sample_rate=sample_rate, n_mels=64) | |
| self.amplitude_to_db = torchaudio.transforms.AmplitudeToDB() | |
| self.conv1 = nn.Conv2d(1, 16, kernel_size=3, stride=1, padding=1) | |
| self.bn1 = nn.BatchNorm2d(16) | |
| self.relu = nn.ReLU(inplace=True) | |
| self.layer1 = nn.Sequential( | |
| nn.Conv2d(16, 32, kernel_size=3, stride=2, padding=1), | |
| nn.BatchNorm2d(32), | |
| nn.ReLU(inplace=True) | |
| ) | |
| self.layer2 = nn.Sequential( | |
| nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1), | |
| nn.BatchNorm2d(64), | |
| nn.ReLU(inplace=True) | |
| ) | |
| self.pool = nn.AdaptiveAvgPool2d((1, 1)) | |
| self.fc = nn.Linear(64, 32) # Embedding size 32 | |
| def forward(self, x): | |
| # x: (B, 1, T) | |
| mel = self.mel_transform(x) # (B, 1, n_mels, time) | |
| mel = self.amplitude_to_db(mel) | |
| x = self.conv1(mel) | |
| x = self.bn1(x) | |
| x = self.relu(x) | |
| x = self.layer1(x) | |
| x = self.layer2(x) | |
| x = self.pool(x) | |
| x = x.flatten(1) | |
| x = self.fc(x) | |
| return x # (B, 32) | |
| class LFCCBranch(nn.Module): | |
| def __init__(self, sample_rate=16000): | |
| super().__init__() | |
| # LFCC: Linear Frequency Cepstral Coefficients | |
| # Good for detecting high-frequency artifacts in synthetic speech | |
| self.lfcc_transform = torchaudio.transforms.LFCC( | |
| sample_rate=sample_rate, | |
| n_lfcc=40, | |
| speckwargs={"n_fft": 1024, "win_length": 400, "hop_length": 160} | |
| ) | |
| # ResNet-style Encoder (Reuse similar architecture) | |
| self.conv1 = nn.Conv2d(1, 16, kernel_size=3, stride=1, padding=1) | |
| self.bn1 = nn.BatchNorm2d(16) | |
| self.relu = nn.ReLU(inplace=True) | |
| self.layer1 = nn.Sequential( | |
| nn.Conv2d(16, 32, kernel_size=3, stride=2, padding=1), | |
| nn.BatchNorm2d(32), | |
| nn.ReLU(inplace=True) | |
| ) | |
| self.layer2 = nn.Sequential( | |
| nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1), | |
| nn.BatchNorm2d(64), | |
| nn.ReLU(inplace=True) | |
| ) | |
| self.pool = nn.AdaptiveAvgPool2d((1, 1)) | |
| self.fc = nn.Linear(64, 32) # Embedding size 32 | |
| def forward(self, x): | |
| # x: (B, 1, T) | |
| lfcc = self.lfcc_transform(x) # (B, 1, n_lfcc, time) | |
| # LFCC is already coefficients, no need for dB conversion | |
| x = self.conv1(lfcc) | |
| x = self.bn1(x) | |
| x = self.relu(x) | |
| x = self.layer1(x) | |
| x = self.layer2(x) | |
| x = self.pool(x) | |
| x = x.flatten(1) | |
| x = self.fc(x) | |
| return x # (B, 32) | |
| class AudioSealBranch(nn.Module): | |
| def __init__(self): | |
| super().__init__() | |
| # Load pre-trained detector from installed package | |
| try: | |
| # We use the 'detector' part only. | |
| self.detector = audioseal.AudioSeal.load_detector("audioseal_detector_16bits") | |
| self.detector.eval() # Freeze | |
| for param in self.detector.parameters(): | |
| param.requires_grad = False | |
| except Exception as e: | |
| print(f"Error loading AudioSeal: {e}") | |
| self.detector = None | |
| # Feature Extraction Head | |
| # AudioSeal outputs a probability map (B, 2, T) | |
| self.pool = nn.AdaptiveAvgPool1d(1) | |
| self.fc = nn.Linear(2, 32) # Map to 32-dim embedding | |
| def forward(self, waveform): | |
| # waveform: (B, 1, T) | |
| if self.detector is None: | |
| return torch.zeros(waveform.size(0), 32).to(waveform.device) | |
| with torch.no_grad(): | |
| # AudioSeal expects (B, 1, T) | |
| # Returns: (B, 2, T) -> [Prob_Watermark, Prob_Message] | |
| # Note: AudioSeal might expect 16kHz. | |
| result = self.detector(waveform,16000) | |
| # AudioSeal returns a tuple (watermark_prob, message_prob) or similar. | |
| # We want the watermark probability map. | |
| if isinstance(result, tuple): | |
| probs = result[0] # Assume first element is the probability map | |
| else: | |
| probs = result | |
| # Ensure it's a tensor | |
| if not isinstance(probs, torch.Tensor): | |
| # Fallback if something is wrong | |
| return torch.zeros(waveform.size(0), 32).to(waveform.device) | |
| probs = probs[:, :, :] # (B, 2, T) | |
| # 1. Global Statistics (Mean probability across time) | |
| global_stats = self.pool(probs).squeeze(2) # (B, 2) | |
| # 2. Map to Embedding | |
| embedding = self.fc(global_stats) # (B, 32) | |
| return embedding | |
| class SincConv(nn.Module): | |
| """ | |
| Sinc-based convolution layer. | |
| Initializes filters as band-pass filters (Sinc functions). | |
| """ | |
| def __init__(self, out_channels, kernel_size, sample_rate=16000, min_low_hz=50, min_band_hz=50): | |
| super().__init__() | |
| if kernel_size % 2 == 0: | |
| kernel_size = kernel_size + 1 | |
| self.out_channels = out_channels | |
| self.kernel_size = kernel_size | |
| self.sample_rate = sample_rate | |
| self.min_low_hz = min_low_hz | |
| self.min_band_hz = min_band_hz | |
| # Initialize filters | |
| # We learn low_hz and band_hz | |
| low_hz = 30 | |
| high_hz = sample_rate / 2 - (min_low_hz + min_band_hz) | |
| # Mel-scale initialization | |
| mel = np.linspace(self.to_mel(low_hz), self.to_mel(high_hz), self.out_channels + 1) | |
| hz = self.to_hz(mel) | |
| # Filter parameters (Learnable) | |
| self.low_hz_ = nn.Parameter(torch.Tensor(hz[:-1]).view(-1, 1)) | |
| self.band_hz_ = nn.Parameter(torch.Tensor(np.diff(hz)).view(-1, 1)) | |
| # Hamming window | |
| self.window_ = torch.hamming_window(self.kernel_size, periodic=False) | |
| def to_mel(self, hz): | |
| return 2595 * np.log10(1 + hz / 700) | |
| def to_hz(self, mel): | |
| return 700 * (10 ** (mel / 2595) - 1) | |
| def forward(self, x): | |
| # Calculate actual frequencies | |
| low = self.min_low_hz + torch.abs(self.low_hz_) | |
| high = torch.clamp(low + self.min_band_hz + torch.abs(self.band_hz_), self.min_low_hz, self.sample_rate/2) | |
| band = (high - low)[:, 0] | |
| # Create filters in time domain | |
| f_times_t_low = torch.matmul(low, (2 * math.pi * torch.arange(-(self.kernel_size-1)/2, (self.kernel_size-1)/2 + 1).to(x.device).view(1, -1) / self.sample_rate)) | |
| f_times_t_high = torch.matmul(high, (2 * math.pi * torch.arange(-(self.kernel_size-1)/2, (self.kernel_size-1)/2 + 1).to(x.device).view(1, -1) / self.sample_rate)) | |
| # Sinc function: sin(x)/x | |
| # Bandpass = 2 * (f_high * sinc(2*pi*f_high*t) - f_low * sinc(2*pi*f_low*t)) | |
| # Note: We use a simplified implementation for stability | |
| # Ideally we use full sinc formula. Here we approximate or use standard conv if too complex for this snippet. | |
| # But let's try to be correct. | |
| # Standard Sinc: | |
| # band_pass = 2 * f2 * sinc(2*pi*f2*t) - 2 * f1 * sinc(2*pi*f1*t) | |
| # We can implement sinc manually | |
| t = torch.arange(-(self.kernel_size-1)/2, (self.kernel_size-1)/2 + 1).to(x.device).view(1, -1) / self.sample_rate | |
| t = t.repeat(self.out_channels, 1) | |
| low_f = low | |
| high_f = high | |
| # Sinc(x) = sin(x)/x | |
| # 2*f*sinc(2*pi*f*t) = 2*f * sin(2*pi*f*t) / (2*pi*f*t) = sin(2*pi*f*t) / (pi*t) | |
| band_pass_left = torch.sin(2 * math.pi * high_f * t) / (math.pi * t + 1e-6) | |
| band_pass_right = torch.sin(2 * math.pi * low_f * t) / (math.pi * t + 1e-6) | |
| # Handle t=0 | |
| # at t=0, limit is 2*f | |
| center_idx = int((self.kernel_size-1)/2) | |
| band_pass_left[:, center_idx] = 2 * high_f[:, 0] | |
| band_pass_right[:, center_idx] = 2 * low_f[:, 0] | |
| filters = band_pass_left - band_pass_right | |
| # Apply window | |
| filters = filters * self.window_.to(filters.device) | |
| return F.conv1d(x, filters.view(self.out_channels, 1, self.kernel_size)) | |
| class RawWaveBranch(nn.Module): | |
| """ | |
| Branch 6: Raw Waveform Analysis using SincNet-style layers. | |
| Detects fine-grained temporal artifacts. | |
| """ | |
| def __init__(self, sample_rate=16000): | |
| super().__init__() | |
| # Sinc Conv Layer | |
| self.sinc_conv = SincConv(out_channels=32, kernel_size=129, sample_rate=sample_rate) | |
| # Standard CNN layers following SincConv | |
| self.layer1 = nn.Sequential( | |
| nn.Conv1d(32, 64, kernel_size=3, stride=2, padding=1), | |
| nn.BatchNorm1d(64), | |
| nn.LeakyReLU(0.2), | |
| nn.MaxPool1d(2) | |
| ) | |
| self.layer2 = nn.Sequential( | |
| nn.Conv1d(64, 128, kernel_size=3, stride=2, padding=1), | |
| nn.BatchNorm1d(128), | |
| nn.LeakyReLU(0.2), | |
| nn.MaxPool1d(2) | |
| ) | |
| self.pool = nn.AdaptiveAvgPool1d(1) | |
| self.fc = nn.Linear(128, 32) | |
| def forward(self, x): | |
| # x: (B, 1, T) | |
| x = self.sinc_conv(x) # (B, 32, T') | |
| x = self.layer1(x) | |
| x = self.layer2(x) | |
| x = self.pool(x).flatten(1) | |
| x = self.fc(x) | |
| return x | |
| class UniversalDetector(nn.Module): | |
| def __init__(self, sample_rate=16000): | |
| super().__init__() | |
| # Branch 1: Watermark Expert | |
| self.watermark_expert = WatermarkExpertBranch(sample_rate) | |
| # Branch 2: Synth Artifacts | |
| self.synth_artifact = SynthArtifactBranch(sample_rate) | |
| # Branch 3: LFCC Features | |
| self.lfcc_branch = LFCCBranch(sample_rate) | |
| # Branch 4: Deep Watermark (Pre-trained Discriminator) | |
| self.deep_watermark = WatermarkDiscriminator() | |
| # Branch 5: AudioSeal (SOTA) | |
| self.audioseal_branch = AudioSealBranch() | |
| # Branch 6: RawWave Expert (SincNet) | |
| self.raw_wave_branch = RawWaveBranch(sample_rate) | |
| # Fusion Head | |
| # Inputs: | |
| # - WM Expert: 3 | |
| # - Synth Artifact: 32 | |
| # - Deep WM: 1 | |
| # - LFCC: 32 | |
| # - AudioSeal: 32 | |
| # - RawWave: 32 | |
| # Total: 3 + 32 + 1 + 32 + 32 + 32 = 132 | |
| self.fusion_head = nn.Sequential( | |
| nn.Linear(132, 128), | |
| nn.BatchNorm1d(128), | |
| nn.ReLU(), | |
| nn.Dropout(0.3), | |
| nn.Linear(128, 64), | |
| nn.ReLU(), | |
| nn.Dropout(0.3), | |
| nn.Linear(64, 2) # 2 Outputs: [Logit_Watermarked, Logit_Synth] | |
| ) | |
| def forward(self, waveform): | |
| # waveform: (B, 1, T) | |
| # 1. Expert Features | |
| wm_feats = self.watermark_expert(waveform) # (B, 3) | |
| # 2. Synth Features | |
| synth_emb = self.synth_artifact(waveform) # (B, 32) | |
| # 3. LFCC Features | |
| lfcc_emb = self.lfcc_branch(waveform) # (B, 32) | |
| # 4. Deep Features | |
| deep_prob = self.deep_watermark(waveform) # (B, 1) | |
| # 5. AudioSeal Features | |
| audioseal_emb = self.audioseal_branch(waveform) # (B, 32) | |
| # 6. RawWave Features | |
| raw_emb = self.raw_wave_branch(waveform) # (B, 32) | |
| # Fusion | |
| features = torch.cat([wm_feats, synth_emb, deep_prob, lfcc_emb, audioseal_emb, raw_emb], dim=1) # (B, 132) | |
| logits = self.fusion_head(features) # (B, 2) | |
| return logits | |