# Audio watermarking: psychoacoustic spread-spectrum embedder and correlation detector.
import torch
import torchaudio
import numpy as np
import matplotlib.pyplot as plt
from typing import Tuple, Optional
import os
import soundfile as sf

# Phase 1: Global Parameters
SAMPLE_RATE = 16000  # working sampling rate (Hz); inputs are resampled to this
N_FFT = 1024  # STFT size (frequency resolution: SAMPLE_RATE / N_FFT per bin)
HOP_LENGTH = N_FFT // 4  # 256
WIN_LENGTH = N_FFT  # analysis window spans the full FFT frame
WATERMARK_KEY = 42  # Default key
def load_audio(path: str) -> Tuple[torch.Tensor, int]:
    """Robust audio loading using soundfile.

    Args:
        path: Path of the audio file to read.

    Returns:
        Tuple of (waveform, sample_rate) where waveform is a float32 tensor
        shaped (channels, samples).

    Raises:
        Re-raises whatever soundfile raises for unreadable files, after
        logging the failure.
    """
    try:
        data, sr = sf.read(path)
    except Exception as e:
        print(f"Error loading {path} with soundfile: {e}")
        # Bare raise keeps the original traceback intact (vs. `raise e`,
        # which appends this frame as the apparent origin).
        raise
    # Soundfile returns (Frames, Channels) or (Frames,); convert to
    # torch's conventional (Channels, Samples) layout.
    if data.ndim == 1:
        waveform = torch.from_numpy(data).unsqueeze(0)  # (1, T)
    else:
        waveform = torch.from_numpy(data.T)  # (C, T)
    return waveform.float(), sr
def save_audio(path: str, waveform: torch.Tensor, sample_rate: int):
    """Write a (channels, samples) waveform tensor to disk via soundfile."""
    # soundfile expects (frames, channels), so transpose after detaching
    # and moving to host memory.
    samples = waveform.detach().cpu().numpy().T
    sf.write(path, samples, sample_rate)
class WatermarkEmbedder:
    """Embed a key-seeded spread-spectrum watermark into an audio file.

    A pseudo-random noise block (seeded by `key`) is tiled along time and
    added to the STFT magnitude, scaled by a simplified psychoacoustic
    masking threshold so the modification stays largely inaudible.
    """

    def __init__(self, sample_rate: int = SAMPLE_RATE, n_fft: int = N_FFT,
                 hop_length: int = HOP_LENGTH, key: int = WATERMARK_KEY,
                 alpha: float = 3.0):  # Increased alpha for better detection
        self.sample_rate = sample_rate
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.key = key  # seed for the pseudo-random watermark pattern
        self.alpha = alpha  # watermark strength multiplier
        self.window = torch.hann_window(self.n_fft)

    def _get_masking_threshold(self, magnitude: torch.Tensor) -> torch.Tensor:
        """
        Calculates a simplified psychoacoustic masking threshold.

        Combines the absolute threshold of hearing with a crude
        frequency-spreading estimate of simultaneous masking.

        Args:
            magnitude: STFT magnitude shaped (channels, freq_bins, frames).

        Returns:
            Per-bin masking threshold, same shape as `magnitude`.
        """
        freqs = torch.linspace(0, self.sample_rate / 2, magnitude.shape[-2])
        # 1. Absolute Threshold of Hearing (ATH), in dB then converted to amplitude.
        f_khz = freqs / 1000.0
        f_khz = torch.clamp(f_khz, min=0.02)  # Clamp to 20Hz to prevent overflow
        ath = 3.64 * (f_khz ** -0.8) - 6.5 * torch.exp(-0.6 * (f_khz - 3.3) ** 2) + 1e-3 * (f_khz ** 4)
        ath = 10 ** (ath / 20)  # Convert dB to amplitude
        ath = ath.view(1, -1, 1).to(magnitude.device)
        # 2. Tonal Masking: spread each bin's energy to its neighbours along
        # the frequency axis with a small symmetric kernel.
        mag_unsqueezed = magnitude.unsqueeze(1)
        kernel = torch.tensor([0.1, 0.3, 1.0, 0.3, 0.1], device=magnitude.device).view(1, 1, -1, 1)
        spread_energy = torch.nn.functional.conv2d(
            mag_unsqueezed,
            kernel,
            padding=(2, 0)
        ).squeeze(1)
        # Final threshold: never below ATH, otherwise a fraction of the
        # locally spread energy.
        masking_threshold = torch.max(ath, spread_energy * 0.1)
        return masking_threshold

    def embed(self, audio_path: str, output_path: str, visualize: bool = False):
        """Embed the watermark into `audio_path` and write `output_path`.

        Stereo input is processed in mid/side form (watermark goes into the
        mid channel only); the result keeps the original channel count.

        Returns:
            The watermarked waveform tensor, shaped like the (normalized)
            input.
        """
        # 1. Input normalization: resample to the working rate and peak-normalize.
        waveform, sr = load_audio(audio_path)
        if sr != self.sample_rate:
            resampler = torchaudio.transforms.Resample(sr, self.sample_rate)
            waveform = resampler(waveform)
        max_val = torch.abs(waveform).max()
        if max_val > 0:
            waveform = waveform / max_val
        original_channels = waveform.shape[0]
        if original_channels == 2:
            # Mid/side split: watermark the mid, pass the side through untouched.
            mid = (waveform[0] + waveform[1]) / 2
            side = (waveform[0] - waveform[1]) / 2
            target_signal = mid.unsqueeze(0)
        else:
            target_signal = waveform
        # 2. STFT
        self.window = self.window.to(waveform.device)
        stft = torch.stft(
            target_signal,
            n_fft=self.n_fft,
            hop_length=self.hop_length,
            window=self.window,
            return_complex=True,
            center=True
        )
        magnitude = torch.abs(stft)
        phase = torch.angle(stft)
        # 3. Psychoacoustic masking threshold controls how much we may inject.
        masking_threshold = self._get_masking_threshold(magnitude)
        # 4. Watermark generation: one key-seeded noise block of 48 frames
        # (~0.75 s at hop 256 / 16 kHz), tiled to cover the whole signal.
        frames_per_block = 48
        freq_bins = magnitude.shape[1]
        g = torch.Generator(device=waveform.device)
        g.manual_seed(self.key)
        watermark_block = (torch.rand((1, freq_bins, frames_per_block), generator=g, device=waveform.device) * 2) - 1
        total_frames = magnitude.shape[2]
        num_repeats = (total_frames // frames_per_block) + 1
        watermark_full = watermark_block.repeat(1, 1, num_repeats)
        watermark_full = watermark_full[:, :, :total_frames]
        # 5. Injection: signed noise scaled by strength and masking headroom.
        injection_signal = (self.alpha * watermark_full * masking_threshold)
        magnitude_mod = magnitude + injection_signal
        # Visualization
        if visualize:
            self._plot_embedding_stats(magnitude, masking_threshold, injection_signal, magnitude_mod, output_path)
        # 6. Reconstruction. BUGFIX: a strongly negative watermark sample can
        # push the modified magnitude below zero, and torch.polar with a
        # negative modulus silently flips the phase by pi; clamp to keep the
        # polar form valid.
        magnitude_mod = torch.clamp(magnitude_mod, min=0.0)
        stft_mod = torch.polar(magnitude_mod, phase)
        reconstructed = torch.istft(
            stft_mod,
            n_fft=self.n_fft,
            hop_length=self.hop_length,
            window=self.window,
            center=True,
            length=target_signal.shape[-1]
        )
        if original_channels == 2:
            # Recombine mid/side back into left/right.
            rec_mid = reconstructed.squeeze(0)
            rec_l = rec_mid + side
            rec_r = rec_mid - side
            final_audio = torch.stack([rec_l, rec_r])
        else:
            final_audio = reconstructed
        save_audio(output_path, final_audio, self.sample_rate)
        return final_audio

    def _plot_embedding_stats(self, magnitude, masking_threshold, injection, magnitude_mod, output_path):
        """Generates plots for the embedding process."""
        # All panels are shown in dB; 1e-6 floor avoids log10(0).
        mag_np = 20 * torch.log10(magnitude[0] + 1e-6).cpu().numpy()
        mask_np = 20 * torch.log10(masking_threshold[0] + 1e-6).cpu().numpy()
        inj_np = 20 * torch.log10(torch.abs(injection[0]) + 1e-6).cpu().numpy()
        mod_np = 20 * torch.log10(magnitude_mod[0] + 1e-6).cpu().numpy()
        plt.figure(figsize=(15, 10))
        plt.subplot(2, 2, 1)
        plt.imshow(mag_np, aspect='auto', origin='lower', cmap='inferno')
        plt.title("Original Spectrogram (dB)")
        plt.colorbar(format='%+2.0f dB')
        plt.subplot(2, 2, 2)
        plt.imshow(mask_np, aspect='auto', origin='lower', cmap='viridis')
        plt.title("Masking Threshold (dB)")
        plt.colorbar(format='%+2.0f dB')
        plt.subplot(2, 2, 3)
        plt.imshow(inj_np, aspect='auto', origin='lower', cmap='magma')
        plt.title("Injected Watermark Signal (dB)")
        plt.colorbar(format='%+2.0f dB')
        plt.subplot(2, 2, 4)
        plt.imshow(mod_np, aspect='auto', origin='lower', cmap='inferno')
        plt.title("Watermarked Spectrogram (dB)")
        plt.colorbar(format='%+2.0f dB')
        plt.tight_layout()
        plot_path = os.path.splitext(output_path)[0] + "_embedding_analysis_aa.png"
        plt.savefig(plot_path)
        plt.close()
        print(f"Saved embedding analysis to {plot_path}")
class WatermarkDetector:
    """Detect the key-seeded spread-spectrum watermark by sliding correlation.

    Regenerates the embedder's reference noise block from the shared key and
    slides it along the whitened spectrogram; a high correlation peak at any
    offset indicates the watermark (robust to time cropping).
    """

    def __init__(self, sample_rate: int = SAMPLE_RATE, n_fft: int = N_FFT,
                 hop_length: int = HOP_LENGTH, key: int = WATERMARK_KEY):
        self.sample_rate = sample_rate
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.key = key
        self.window = torch.hann_window(self.n_fft)
        # Rebuild the exact reference watermark block the embedder injected.
        self.frames_per_block = 48
        bins = n_fft // 2 + 1
        rng = torch.Generator()
        rng.manual_seed(self.key)
        self.watermark_block = (torch.rand((1, bins, self.frames_per_block), generator=rng) * 2) - 1

    def detect(self, audio_path: str, threshold: float = 0.05, visualize: bool = False) -> bool:
        """Return True when the peak correlation score exceeds `threshold`."""
        # --- Preprocess: mix to mono, then resample to the working rate.
        signal, native_sr = load_audio(audio_path)
        if signal.shape[0] > 1:
            signal = torch.mean(signal, dim=0, keepdim=True)
        if native_sr != self.sample_rate:
            signal = torchaudio.transforms.Resample(native_sr, self.sample_rate)(signal)
        # --- STFT magnitude.
        self.window = self.window.to(signal.device)
        spectrum = torch.stft(
            signal,
            n_fft=self.n_fft,
            hop_length=self.hop_length,
            window=self.window,
            return_complex=True,
            center=True
        )
        mag = torch.abs(spectrum)
        # --- Whitening: subtract a 15-frame moving average along time so the
        # slow spectral envelope (e.g. speech formants) drops out. Kernel 15
        # with stride 1 and padding 7 preserves the frame count exactly.
        smoothed = torch.nn.functional.avg_pool2d(
            mag.unsqueeze(1),
            kernel_size=(1, 15),
            stride=1,
            padding=(0, 7)
        )
        residual = (mag.unsqueeze(1) - smoothed).squeeze(1)
        # Normalize to unit variance so the score scale is input-independent.
        residual = residual / (torch.std(residual) + 1e-6)
        # --- Correlation against the zero-mean, unit-norm reference block.
        template = self.watermark_block.to(signal.device).unsqueeze(1)
        template = template - torch.mean(template)
        template = template / (torch.norm(template) + 1e-6)
        probe = residual.unsqueeze(1)
        # Guard: the sliding window needs at least one full block of frames.
        if probe.shape[-1] < template.shape[-1]:
            print("Warning: Input audio too short for detection.")
            return False
        scores = torch.nn.functional.conv2d(probe, template).squeeze()
        if scores.numel() == 0:
            return False
        max_score = torch.max(scores).item()
        print(f"Max Correlation Score: {max_score}")
        if visualize:
            self._plot_detection_stats(residual, scores, max_score, threshold, audio_path)
        return max_score > threshold

    def _plot_detection_stats(self, whitened, scores, max_score, threshold, audio_path):
        """Generates plots for the detection process."""
        spec_img = whitened[0].cpu().numpy()
        trace = scores.cpu().numpy()
        plt.figure(figsize=(15, 8))
        plt.subplot(2, 1, 1)
        plt.imshow(spec_img, aspect='auto', origin='lower', cmap='coolwarm', vmin=-3, vmax=3)
        plt.title("Whitened Spectrogram (Signal - Smoothed Background)")
        plt.colorbar()
        plt.subplot(2, 1, 2)
        plt.plot(trace)
        plt.axhline(y=threshold, color='r', linestyle='--', label=f'Threshold ({threshold})')
        plt.axhline(y=max_score, color='g', linestyle=':', label=f'Max Score ({max_score:.2f})')
        plt.title("Correlation Score (Sliding Window)")
        plt.xlabel("Time Frame Index")
        plt.ylabel("Pearson Correlation")
        plt.legend()
        plt.tight_layout()
        plot_path = os.path.splitext(audio_path)[0] + "_detection_analysis_aa.png"
        plt.savefig(plot_path)
        plt.close()
        print(f"Saved detection analysis to {plot_path}")
if __name__ == "__main__":
    # Smoke test: embed into synthetic audio, detect on the full file, then
    # detect on a 2-second crop (checks time-shift robustness).
    # NOTE: removed the redundant `import os` here — os is already imported
    # at module level.
    embedder = WatermarkEmbedder(alpha=5.0)  # Stronger watermark for test
    detector = WatermarkDetector()

    # Create dummy audio: 440 Hz tone plus Gaussian noise, 5 s at 16 kHz.
    sr = 16000
    duration = 5
    t = torch.linspace(0, duration, sr * duration)
    audio = 0.5 * torch.sin(2 * torch.pi * 440 * t) + 0.1 * torch.randn_like(t)
    audio = audio.unsqueeze(0)

    test_file = "test_original.wav"
    watermarked_file = "test_watermarked.wav"
    save_audio(test_file, audio, sr)

    print("Embedding watermark...")
    embedder.embed(test_file, watermarked_file, visualize=True)

    print("Detecting watermark...")
    detected = detector.detect(watermarked_file, threshold=0.02, visualize=True)
    print(f"Detected: {detected}")

    print("Testing with crop...")
    wm_audio, _ = load_audio(watermarked_file)
    crop_start = sr * 1
    crop_end = sr * 3  # 2 seconds
    cropped_audio = wm_audio[:, crop_start:crop_end]
    cropped_file = "test_cropped.wav"
    save_audio(cropped_file, cropped_audio, sr)
    detected_crop = detector.detect(cropped_file, threshold=0.02)
    print(f"Detected in crop: {detected_crop}")