| |
| """ |
| CLI for SynthID watermark inference on single images. |
| |
| Example: |
| python infer.py --checkpoint model.pt --image path/to/image.png --size 512 |
| """ |
|
|
| import argparse |
|
|
| import cv2 |
| import numpy as np |
| import pywt |
| import torch |
| import torch.nn as nn |
| from torchvision import models |
|
|
|
|
| |
| |
| |
|
|
def wavelet_denoise(channel, wavelet="db4", level=3):
    """Denoise a 2-D array by soft-thresholding its wavelet detail coefficients.

    Noise is estimated from the finest-level horizontal detail band as
    ``median(|cH1|) / 0.6745``. The universal threshold
    ``sigma * sqrt(2 * ln(N))`` is then applied, with soft thresholding, to
    every detail band. The approximation band is left untouched.

    Args:
        channel: 2-D array (a single image plane). NaN -> 0, +inf -> 1,
            -inf -> 0 before decomposition.
        wavelet: PyWavelets wavelet name.
        level: decomposition depth passed to ``pywt.wavedec2``.

    Returns:
        The reconstructed (denoised) array, cropped to ``channel``'s shape
        (``waverec2`` may return an extra row/column for odd sizes).
    """
    channel = np.nan_to_num(channel, nan=0.0, posinf=1.0, neginf=0.0)
    # wavedec2 returns [cA_n, (cH_n, cV_n, cD_n), ..., (cH_1, cV_1, cD_1)],
    # so coeffs[-1][0] is the finest horizontal detail band.
    # NOTE(review): the textbook estimator uses the diagonal band (cD_1);
    # cH_1 is kept here, presumably to match the training preprocessing.
    detail = coeffs_finest_h = None  # placeholder names reset below
    coeffs = pywt.wavedec2(channel, wavelet, level=level)
    detail = coeffs[-1][0]
    sigma = np.median(np.abs(detail)) / 0.6745
    threshold = sigma * np.sqrt(2 * np.log(channel.size))

    if threshold > 0:
        new_coeffs = [coeffs[0]]
        for details in coeffs[1:]:
            # Exact-zero coefficients divide by zero inside pywt's soft rule;
            # those terms are clipped away, so the warning is noise.
            with np.errstate(invalid="ignore", divide="ignore"):
                new_details = tuple(pywt.threshold(d, threshold, mode="soft") for d in details)
            new_coeffs.append(new_details)
    else:
        # A zero threshold happens when most finest-level coefficients are exactly
        # zero (flat or black regions). Soft-thresholding at 0 is the identity,
        # but pywt evaluates it as d * (1 - 0/|d|), which is NaN (0/0) for every
        # zero coefficient and would poison the reconstruction. So skip it.
        new_coeffs = coeffs

    denoised = pywt.waverec2(new_coeffs, wavelet)
    return denoised[: channel.shape[0], : channel.shape[1]]
|
|
|
|
def build_carrier_mask(size: int) -> np.ndarray:
    """Return a ``(size, size)`` float32 map with 1.0 at the known carrier bins.

    Offsets are measured from the centre bin ``size // 2`` (the layout produced
    by ``fftshift``). Each offset is marked together with its point-reflection
    through the centre. Bins that fall outside the grid are silently skipped,
    so small sizes get fewer (or no) marked cells.
    """
    offsets = (
        (14, 14), (-14, -14),
        (126, 14), (-126, -14),
        (98, -14), (-98, 14),
        (128, 128), (-128, -128),
    )
    centre = size // 2
    grid = np.zeros((size, size), dtype=np.float32)
    for dy, dx in offsets:
        # +1 marks the offset itself, -1 its mirror image about the centre.
        for sign in (1, -1):
            row = centre + sign * dy
            col = centre + sign * dx
            if row < 0 or row >= size or col < 0 or col >= size:
                continue
            grid[row, col] = 1.0
    return grid
|
|
|
|
def fft_log_magnitude(gray: np.ndarray) -> np.ndarray:
    """Return the centred log-magnitude spectrum of ``gray``, min-max scaled to [0, 1].

    The spectrum is ``log1p(|fftshift(fft2(gray))|)``. A 1e-8 epsilon in the
    denominator keeps a constant spectrum (e.g. an all-zero input) at 0
    instead of dividing by zero. The result is cast to float32.
    """
    centred = np.fft.fftshift(np.fft.fft2(gray))
    compressed = np.log1p(np.abs(centred))
    lo = compressed.min()
    hi = compressed.max()
    scaled = (compressed - lo) / (hi - lo + 1e-8)
    return scaled.astype(np.float32)
|
|
|
|
| |
| |
| |
|
|
class DualStreamWatermarkNet(nn.Module):
    """RGB+residual branch + frequency branch fused for watermark detection.

    The input is split along the channel axis. The first ``spatial_in``
    channels go to a torchvision ResNet, whose stem conv is replaced to accept
    ``spatial_in`` channels and whose fc head is removed. The remaining
    channels go to a small strided-conv frequency branch that is average-pooled
    to a 128-d vector. The two feature vectors are concatenated and mapped to a
    single logit.

    Args:
        spatial_in: number of leading input channels routed to the ResNet branch.
        freq_in: number of trailing input channels routed to the frequency branch.
        hidden_dim: width of the classifier's hidden layer.
        pretrained: load ImageNet weights for the backbone. conv1 is always
            replaced (re-initialised) because its input width changes.
        backbone: ``"resnet34"``; any other value selects resnet18.
    """

    def __init__(
        self,
        spatial_in: int = 4,
        freq_in: int = 1,
        hidden_dim: int = 256,
        pretrained: bool = False,
        backbone: str = "resnet18",
    ):
        super().__init__()
        # Remember the channel split so forward() routes exactly the channels
        # the replaced conv1 was built for (previously hard-coded to 4).
        self.spatial_in = spatial_in

        if backbone == "resnet34":
            weights = models.ResNet34_Weights.IMAGENET1K_V1 if pretrained else None
            self.spatial = models.resnet34(weights=weights)
        else:
            weights = models.ResNet18_Weights.IMAGENET1K_V1 if pretrained else None
            self.spatial = models.resnet18(weights=weights)

        # Read the backbone's pooled feature width before discarding the fc head
        # (512 for resnet18/34), instead of hard-coding it in the classifier.
        spatial_dim = self.spatial.fc.in_features
        self.spatial.conv1 = nn.Conv2d(spatial_in, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.spatial.fc = nn.Identity()

        # Frequency branch: three stride-2 conv/BN/ReLU stages, pooled to (N, 128, 1, 1).
        freq_layers = [
            nn.Conv2d(freq_in, 32, kernel_size=5, stride=2, padding=2),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.AdaptiveAvgPool2d((1, 1)),
        ]
        self.freq = nn.Sequential(*freq_layers)

        freq_dim = 128  # output channels of the last freq conv above
        fusion_dim = spatial_dim + freq_dim
        self.classifier = nn.Sequential(
            nn.Linear(fusion_dim, hidden_dim),
            nn.ReLU(inplace=True),
            nn.Dropout(0.2),
            nn.Linear(hidden_dim, 1),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Score a batch.

        Args:
            x: 4-D tensor ``(N, C, H, W)``. Channels ``[:spatial_in]`` feed the
                ResNet branch and ``[spatial_in:]`` feed the frequency branch.

        Returns:
            Raw logits of shape ``(N,)`` (apply sigmoid for a probability).
        """
        # getattr fallback keeps whole-module pickles created before
        # ``spatial_in`` was stored working with the old fixed split of 4.
        split = getattr(self, "spatial_in", 4)
        spatial = x[:, :split, :, :]
        freq = x[:, split:, :, :]

        s_feat = self.spatial(spatial)
        f_feat = self.freq(freq).flatten(1)

        fused = torch.cat([s_feat, f_feat], dim=1)
        logit = self.classifier(fused).squeeze(1)
        return logit
|
|
|
|
def build_model(total_channels: int, hidden_dim: int = 256, pretrained: bool = False, backbone: str = "resnet34") -> DualStreamWatermarkNet:
    """Construct a DualStreamWatermarkNet sized for ``total_channels`` input channels.

    The first 4 channels always go to the spatial branch. Every remaining
    channel goes to the frequency branch, which is given at least one input
    channel even when ``total_channels`` is 4 or less.
    """
    n_spatial = 4
    config = dict(
        spatial_in=n_spatial,
        freq_in=max(1, total_channels - n_spatial),
        hidden_dim=hidden_dim,
        pretrained=pretrained,
        backbone=backbone,
    )
    return DualStreamWatermarkNet(**config)
|
|
|
|
def preprocess(path: str, size: int = 512, use_fft: bool = True, use_carrier: bool = True) -> torch.Tensor:
    """Read an image file and build the model input tensor of shape (1, C, size, size).

    Channel order:
      * 3 RGB channels scaled to [0, 1];
      * 1 noise residual: each RGB plane minus its wavelet-denoised version,
        averaged over the three planes;
      * 1 min-max scaled log-magnitude FFT of the grayscale image (if ``use_fft``);
      * 1 fixed carrier-frequency mask (if ``use_carrier``).

    NaN/inf values in the stacked array are replaced (NaN -> 0, +inf -> 1,
    -inf -> -1) before conversion to a float32 tensor.

    Raises:
        FileNotFoundError: if ``cv2.imread`` cannot read ``path``.
    """
    bgr = cv2.imread(path)
    if bgr is None:
        raise FileNotFoundError(path)
    rgb = cv2.resize(cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB), (size, size))
    unit = rgb.astype(np.float32) / 255.0

    # Per-plane high-frequency residual, stored as float32 like the planes themselves.
    noise = np.empty((size, size, 3), dtype=np.float32)
    for plane_idx in range(3):
        plane = unit[:, :, plane_idx]
        noise[:, :, plane_idx] = plane - wavelet_denoise(plane)

    stack = [unit.transpose(2, 0, 1), noise.mean(axis=2)[np.newaxis]]
    if use_fft:
        luma = cv2.cvtColor(rgb, cv2.COLOR_RGB2GRAY).astype(np.float32) / 255.0
        stack.append(fft_log_magnitude(luma)[np.newaxis])
    if use_carrier:
        stack.append(build_carrier_mask(size)[np.newaxis])

    merged = np.nan_to_num(np.concatenate(stack, axis=0), nan=0.0, posinf=1.0, neginf=-1.0)
    return torch.from_numpy(merged).unsqueeze(0).float()
|
|
|
|
def main():
    """Command-line entry point: print a watermark probability for one image.

    Steps:
      1. Load the checkpoint.
      2. Rebuild the network from the channel count (``ckpt["channels"]``,
         default 6) and backbone (``ckpt["args"]["backbone"]``, default
         "resnet34") stored in it, then load its weights.
      3. Preprocess the image with the requested channel options.
      4. Print the sigmoid probability and a 0.5-threshold decision.

    Any parameter-name mismatch between checkpoint and model is reported on
    stderr. A channel-count mismatch between the CLI flags and the checkpoint
    exits with a usage error.
    """
    import sys

    parser = argparse.ArgumentParser(description="SynthID watermark inference.")
    parser.add_argument("--checkpoint", type=str, required=True, help="Path to trained checkpoint (.pt).")
    parser.add_argument("--image", type=str, required=True, help="Path to image.")
    parser.add_argument("--size", type=int, default=512, help="Resize target.")
    parser.add_argument("--no-fft", action="store_true", help="Disable FFT channel.")
    parser.add_argument("--no-carrier-mask", action="store_true", help="Disable carrier mask channel.")
    args = parser.parse_args()

    # SECURITY NOTE: torch.load unpickles the file and can execute arbitrary code.
    # Only load checkpoints from trusted sources. weights_only=True is safer, but
    # it rejects checkpoints that store arbitrary Python objects.
    ckpt = torch.load(args.checkpoint, map_location="cpu")
    total_channels = ckpt.get("channels", 6)
    # Training args may be stored as a dict, as an argparse.Namespace, or be absent/None.
    train_args = ckpt.get("args") or {}
    if isinstance(train_args, argparse.Namespace):
        train_args = vars(train_args)
    backbone = train_args.get("backbone", "resnet34")
    model = build_model(total_channels=total_channels, pretrained=False, backbone=backbone)

    # strict=False tolerates key mismatches. Ignoring them silently would leave
    # layers at random initialization and print a meaningless score, so report them.
    incompatible = model.load_state_dict(ckpt["model_state"], strict=False)
    if incompatible.missing_keys:
        shown = ", ".join(incompatible.missing_keys[:5])
        print(
            f"Warning: {len(incompatible.missing_keys)} model parameter(s) missing from the "
            f"checkpoint and left at random initialization (e.g. {shown}).",
            file=sys.stderr,
        )
    if incompatible.unexpected_keys:
        shown = ", ".join(incompatible.unexpected_keys[:5])
        print(
            f"Warning: ignoring {len(incompatible.unexpected_keys)} unexpected checkpoint "
            f"key(s) (e.g. {shown}).",
            file=sys.stderr,
        )
    model.eval()

    x = preprocess(args.image, size=args.size, use_fft=not args.no_fft, use_carrier=not args.no_carrier_mask)
    # The frequency branch was built for (total_channels - 4) inputs. Disabling a
    # channel the checkpoint was trained with would otherwise fail inside a conv.
    if x.shape[1] != total_channels:
        parser.error(
            f"preprocessed image has {x.shape[1]} channels but the checkpoint expects "
            f"{total_channels}; check --no-fft / --no-carrier-mask against the training setup."
        )
    with torch.no_grad():
        logit = model(x)
    prob = torch.sigmoid(logit).item()
    print(f"Watermark probability: {prob:.4f}")
    print(f"Decision (threshold 0.5): {'WATERMARKED' if prob >= 0.5 else 'CLEAN'}")


if __name__ == "__main__":
    main()
|
|