# Source: Synthatic_Audio_detetcion / robust_watermark.py
# (repository scrape header — uploaded by "Ubuntu", commit cda8304,
#  commit message: "Add application file")
import torch
import torchaudio
import numpy as np
import matplotlib.pyplot as plt
from typing import Tuple, Optional
import os
import soundfile as sf
# Phase 1: Global Parameters
SAMPLE_RATE = 16000  # Working sample rate in Hz; all audio is resampled to this.
N_FFT = 1024  # STFT size in samples.
HOP_LENGTH = N_FFT // 4 # 256 samples between successive STFT frames (75% overlap).
WIN_LENGTH = N_FFT  # NOTE(review): defined but never passed to torch.stft below (which defaults win_length to n_fft anyway) — confirm intent.
WATERMARK_KEY = 42 # Default key: seed for the pseudo-random watermark pattern.
def load_audio(path: str) -> Tuple[torch.Tensor, int]:
    """Robust audio loading using soundfile.

    Args:
        path: Path to the audio file on disk.

    Returns:
        Tuple of ``(waveform, sample_rate)`` where ``waveform`` is a float32
        tensor shaped ``(channels, samples)``.

    Raises:
        Exception: re-raises whatever soundfile raised, after logging the path.
    """
    try:
        data, sr = sf.read(path)
    except Exception as e:
        print(f"Error loading {path} with soundfile: {e}")
        raise  # bare raise preserves the original traceback
    # soundfile returns (frames,) for mono or (frames, channels) otherwise;
    # convert to the (channels, samples) layout used throughout this module.
    if data.ndim == 1:
        waveform = torch.from_numpy(data).unsqueeze(0)  # (1, T)
    else:
        waveform = torch.from_numpy(data.T)  # (C, T)
    return waveform.float(), sr
def save_audio(path: str, waveform: torch.Tensor, sample_rate: int):
    """Robust audio saving using soundfile.

    ``waveform`` is expected in the module's (channels, samples) layout;
    soundfile wants (frames, channels), hence the transpose.
    """
    samples = waveform.detach().cpu().numpy().T
    sf.write(path, samples, sample_rate)
class WatermarkEmbedder:
    """Embeds a key-seeded spread-spectrum watermark into an audio file.

    A pseudo-random noise block (seeded by ``key``) is tiled along time and
    added to the STFT magnitude, scaled by a simplified psychoacoustic
    masking threshold so that the modification stays largely inaudible.
    """

    def __init__(self, sample_rate: int = SAMPLE_RATE, n_fft: int = N_FFT,
                 hop_length: int = HOP_LENGTH, key: int = WATERMARK_KEY,
                 alpha: float = 3.0):  # Increased alpha for better detection
        """
        Args:
            sample_rate: Working sample rate; inputs are resampled to it.
            n_fft: STFT size.
            hop_length: Hop between successive STFT frames.
            key: Seed for the pseudo-random watermark pattern; the detector
                must use the same key.
            alpha: Watermark strength relative to the masking threshold.
        """
        self.sample_rate = sample_rate
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.key = key
        self.alpha = alpha
        self.window = torch.hann_window(self.n_fft)

    def _get_masking_threshold(self, magnitude: torch.Tensor) -> torch.Tensor:
        """
        Calculates a simplified psychoacoustic masking threshold.

        Args:
            magnitude: STFT magnitude, shape (channels, freq_bins, frames).

        Returns:
            Per-bin threshold tensor with the same shape as ``magnitude``.
        """
        freqs = torch.linspace(0, self.sample_rate / 2, magnitude.shape[-2])
        # 1. Absolute Threshold of Hearing (ATH), Terhardt-style approximation.
        # NOTE(review): the ATH formula assumes absolute SPL calibration, which
        # a peak-normalized signal does not have — treat it as a heuristic floor.
        f_khz = freqs / 1000.0
        f_khz = torch.clamp(f_khz, min=0.02)  # Clamp to 20Hz to prevent overflow
        ath = 3.64 * (f_khz ** -0.8) - 6.5 * torch.exp(-0.6 * (f_khz - 3.3) ** 2) + 1e-3 * (f_khz ** 4)
        ath = 10 ** (ath / 20)  # Convert dB to amplitude
        ath = ath.view(1, -1, 1).to(magnitude.device)
        # 2. Tonal masking: spread each bin's energy over +/-2 neighbouring
        # frequency bins (conv along the frequency axis only).
        mag_unsqueezed = magnitude.unsqueeze(1)
        kernel = torch.tensor([0.1, 0.3, 1.0, 0.3, 0.1], device=magnitude.device).view(1, 1, -1, 1)
        spread_energy = torch.nn.functional.conv2d(
            mag_unsqueezed,
            kernel,
            padding=(2, 0)
        ).squeeze(1)
        # Final threshold: whichever is larger — the hearing floor or 10% of
        # the locally spread energy.
        masking_threshold = torch.max(ath, spread_energy * 0.1)
        return masking_threshold

    def embed(self, audio_path: str, output_path: str, visualize: bool = False):
        """Watermark ``audio_path`` and write the result to ``output_path``.

        Args:
            audio_path: Input audio file.
            output_path: Destination for the watermarked audio.
            visualize: If True, save a diagnostic plot next to ``output_path``.

        Returns:
            The watermarked waveform tensor, shape (channels, samples).
        """
        # 1. Input normalization: resample to the working rate, peak-normalize.
        waveform, sr = load_audio(audio_path)
        if sr != self.sample_rate:
            resampler = torchaudio.transforms.Resample(sr, self.sample_rate)
            waveform = resampler(waveform)
        max_val = torch.abs(waveform).max()
        if max_val > 0:
            waveform = waveform / max_val
        # Stereo is processed as mid/side; only the mid channel is watermarked,
        # and the untouched side channel is restored after reconstruction.
        original_channels = waveform.shape[0]
        if original_channels == 2:
            mid = (waveform[0] + waveform[1]) / 2
            side = (waveform[0] - waveform[1]) / 2
            target_signal = mid.unsqueeze(0)
        else:
            target_signal = waveform
        # 2. STFT
        self.window = self.window.to(waveform.device)
        stft = torch.stft(
            target_signal,
            n_fft=self.n_fft,
            hop_length=self.hop_length,
            window=self.window,
            return_complex=True,
            center=True
        )
        magnitude = torch.abs(stft)
        phase = torch.angle(stft)
        # 3. Psychoacoustic masking threshold.
        masking_threshold = self._get_masking_threshold(magnitude)
        # 4. Watermark generation: one noise block of 48 frames (~0.77 s at
        # hop 256 / 16 kHz), tiled along time so cropped audio still contains
        # whole blocks for the detector's sliding correlation.
        frames_per_block = 48
        freq_bins = magnitude.shape[1]
        g = torch.Generator(device=waveform.device)
        g.manual_seed(self.key)
        watermark_block = (torch.rand((1, freq_bins, frames_per_block), generator=g, device=waveform.device) * 2) - 1
        total_frames = magnitude.shape[2]
        num_repeats = (total_frames // frames_per_block) + 1
        watermark_full = watermark_block.repeat(1, 1, num_repeats)
        watermark_full = watermark_full[:, :, :total_frames]
        # 5. Injection. Clamp at zero: the watermark is bipolar, so the sum
        # can go negative, and a negative magnitude fed to torch.polar
        # silently flips the phase and audibly distorts the signal.
        injection_signal = (self.alpha * watermark_full * masking_threshold)
        magnitude_mod = torch.clamp(magnitude + injection_signal, min=0.0)
        # Visualization
        if visualize:
            self._plot_embedding_stats(magnitude, masking_threshold, injection_signal, magnitude_mod, output_path)
        # 6. Reconstruction with the original phase.
        stft_mod = torch.polar(magnitude_mod, phase)
        reconstructed = torch.istft(
            stft_mod,
            n_fft=self.n_fft,
            hop_length=self.hop_length,
            window=self.window,
            center=True,
            length=target_signal.shape[-1]
        )
        if original_channels == 2:
            rec_mid = reconstructed.squeeze(0)
            rec_l = rec_mid + side
            rec_r = rec_mid - side
            final_audio = torch.stack([rec_l, rec_r])
        else:
            final_audio = reconstructed
        save_audio(output_path, final_audio, self.sample_rate)
        return final_audio

    def _plot_embedding_stats(self, magnitude, masking_threshold, injection, magnitude_mod, output_path):
        """Generates plots for the embedding process.

        Saves a 2x2 panel (original / mask / injection / watermarked, all in
        dB) next to ``output_path`` and prints the plot path.
        """
        # 1e-6 floor avoids log10(0) = -inf in silent bins.
        mag_np = 20 * torch.log10(magnitude[0] + 1e-6).cpu().numpy()
        mask_np = 20 * torch.log10(masking_threshold[0] + 1e-6).cpu().numpy()
        inj_np = 20 * torch.log10(torch.abs(injection[0]) + 1e-6).cpu().numpy()
        mod_np = 20 * torch.log10(magnitude_mod[0] + 1e-6).cpu().numpy()
        plt.figure(figsize=(15, 10))
        plt.subplot(2, 2, 1)
        plt.imshow(mag_np, aspect='auto', origin='lower', cmap='inferno')
        plt.title("Original Spectrogram (dB)")
        plt.colorbar(format='%+2.0f dB')
        plt.subplot(2, 2, 2)
        plt.imshow(mask_np, aspect='auto', origin='lower', cmap='viridis')
        plt.title("Masking Threshold (dB)")
        plt.colorbar(format='%+2.0f dB')
        plt.subplot(2, 2, 3)
        plt.imshow(inj_np, aspect='auto', origin='lower', cmap='magma')
        plt.title("Injected Watermark Signal (dB)")
        plt.colorbar(format='%+2.0f dB')
        plt.subplot(2, 2, 4)
        plt.imshow(mod_np, aspect='auto', origin='lower', cmap='inferno')
        plt.title("Watermarked Spectrogram (dB)")
        plt.colorbar(format='%+2.0f dB')
        plt.tight_layout()
        plot_path = os.path.splitext(output_path)[0] + "_embedding_analysis_aa.png"
        plt.savefig(plot_path)
        plt.close()
        print(f"Saved embedding analysis to {plot_path}")
class WatermarkDetector:
    """Detects the key-seeded spread-spectrum watermark via correlation.

    The detector regenerates the reference noise block from the key,
    whitens the candidate audio's spectrogram, then slides the block along
    time looking for a normalized-correlation peak above a threshold.
    """

    def __init__(self, sample_rate: int = SAMPLE_RATE, n_fft: int = N_FFT,
                 hop_length: int = HOP_LENGTH, key: int = WATERMARK_KEY):
        self.sample_rate = sample_rate
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.key = key
        self.window = torch.hann_window(self.n_fft)
        # Rebuild the reference watermark block from the key (must mirror
        # the embedder's generation exactly: same seed, shape and scaling).
        self.frames_per_block = 48
        n_bins = n_fft // 2 + 1
        rng = torch.Generator()
        rng.manual_seed(self.key)
        self.watermark_block = (torch.rand((1, n_bins, self.frames_per_block), generator=rng) * 2) - 1

    def detect(self, audio_path: str, threshold: float = 0.05, visualize: bool = False) -> bool:
        """Return True when the peak correlation exceeds ``threshold``."""
        # 1. Preprocessing: mix down to mono, then resample to the working rate.
        waveform, sr = load_audio(audio_path)
        if waveform.shape[0] > 1:
            waveform = torch.mean(waveform, dim=0, keepdim=True)
        if sr != self.sample_rate:
            waveform = torchaudio.transforms.Resample(sr, self.sample_rate)(waveform)
        # 2. Spectrogram.
        self.window = self.window.to(waveform.device)
        spectrum = torch.stft(
            waveform,
            n_fft=self.n_fft,
            hop_length=self.hop_length,
            window=self.window,
            return_complex=True,
            center=True
        )
        magnitude = torch.abs(spectrum)
        # 3. Whitening: subtract a 15-frame moving average along time to
        # suppress slowly varying structure (speech formants etc.), keeping
        # the broadband watermark residue. kernel 15 / padding 7 / stride 1
        # preserves the time dimension exactly.
        mag4d = magnitude.unsqueeze(1)
        background = torch.nn.functional.avg_pool2d(
            mag4d,
            kernel_size=(1, 15),
            stride=1,
            padding=(0, 7)
        )
        residual = (mag4d - background).squeeze(1)
        residual = residual / (torch.std(residual) + 1e-6)  # normalize variance
        # 4. Correlation: slide a zero-mean, unit-norm copy of the reference
        # block along time and take the peak response.
        probe = residual.unsqueeze(1)
        template = self.watermark_block.to(waveform.device).unsqueeze(1)
        template = template - torch.mean(template)
        template = template / (torch.norm(template) + 1e-6)
        # Guard: audio shorter than one watermark block cannot be checked.
        if probe.shape[-1] < template.shape[-1]:
            print("Warning: Input audio too short for detection.")
            return False
        scores = torch.nn.functional.conv2d(probe, template).squeeze()
        if scores.numel() == 0:
            return False
        max_score = torch.max(scores).item()
        print(f"Max Correlation Score: {max_score}")
        if visualize:
            self._plot_detection_stats(residual, scores, max_score, threshold, audio_path)
        return max_score > threshold

    def _plot_detection_stats(self, whitened, scores, max_score, threshold, audio_path):
        """Plot the whitened spectrogram and the sliding correlation curve."""
        spec_img = whitened[0].cpu().numpy()
        curve = scores.cpu().numpy()
        plt.figure(figsize=(15, 8))
        plt.subplot(2, 1, 1)
        plt.imshow(spec_img, aspect='auto', origin='lower', cmap='coolwarm', vmin=-3, vmax=3)
        plt.title("Whitened Spectrogram (Signal - Smoothed Background)")
        plt.colorbar()
        plt.subplot(2, 1, 2)
        plt.plot(curve)
        plt.axhline(y=threshold, color='r', linestyle='--', label=f'Threshold ({threshold})')
        plt.axhline(y=max_score, color='g', linestyle=':', label=f'Max Score ({max_score:.2f})')
        plt.title("Correlation Score (Sliding Window)")
        plt.xlabel("Time Frame Index")
        plt.ylabel("Pearson Correlation")
        plt.legend()
        plt.tight_layout()
        plot_path = os.path.splitext(audio_path)[0] + "_detection_analysis_aa.png"
        plt.savefig(plot_path)
        plt.close()
        print(f"Saved detection analysis to {plot_path}")
if __name__ == "__main__":
    # Smoke test: embed a watermark into synthetic audio, then verify it is
    # detected both in the full file and in a 2-second crop.
    # (Removed a redundant `import os` — os is already imported at file top.)
    embedder = WatermarkEmbedder(alpha=5.0)  # Stronger watermark for test
    detector = WatermarkDetector()
    # Create dummy audio: a 440 Hz tone plus light Gaussian noise.
    sr = 16000
    duration = 5
    t = torch.linspace(0, duration, sr * duration)
    audio = 0.5 * torch.sin(2 * torch.pi * 440 * t) + 0.1 * torch.randn_like(t)
    audio = audio.unsqueeze(0)
    test_file = "test_original.wav"
    watermarked_file = "test_watermarked.wav"
    save_audio(test_file, audio, sr)
    print("Embedding watermark...")
    embedder.embed(test_file, watermarked_file, visualize=True)
    print("Detecting watermark...")
    detected = detector.detect(watermarked_file, threshold=0.02, visualize=True)
    print(f"Detected: {detected}")
    # Robustness check: the repeating watermark block should survive cropping.
    print("Testing with crop...")
    wm_audio, _ = load_audio(watermarked_file)
    crop_start = sr * 1
    crop_end = sr * 3  # 2 seconds
    cropped_audio = wm_audio[:, crop_start:crop_end]
    cropped_file = "test_cropped.wav"
    save_audio(cropped_file, cropped_audio, sr)
    detected_crop = detector.detect(cropped_file, threshold=0.02)
    print(f"Detected in crop: {detected_crop}")