"""StegaStamp encoder, PyTorch implementation.

Converts the TensorFlow encoder to PyTorch for inference.
"""

import torch
import torch.nn as nn
import torch.nn.functional as F


class StegaStampEncoder(nn.Module):
    """U-Net style network mapping a (secret, image) pair to an image residual.

    The 100-value secret is projected to a 3x50x50 map, enlarged to the image
    resolution, concatenated with the image and passed through four stride-2
    downsampling convolutions followed by four upsampling stages with skip
    connections. The output is a 3-channel residual at the image resolution.
    """

    def __init__(self, height=400, width=400):
        """Create the layers.

        Args:
            height: Nominal image height. Stored as ``self.height``; the
                forward pass takes the actual size from the image tensor.
            width: Nominal image width. Stored as ``self.width``.
        """
        super().__init__()
        self.height = height
        self.width = width

        # Secret processing - converts 100-bit secret to 7500 values (50x50x3)
        self.secret_dense = nn.Linear(100, 7500)

        # Encoder path - downsampling
        self.conv1 = nn.Conv2d(6, 32, 3, padding=1)  # 6 = 3 (secret) + 3 (image)
        self.conv2 = nn.Conv2d(32, 32, 3, stride=2, padding=1)
        self.conv3 = nn.Conv2d(32, 64, 3, stride=2, padding=1)
        self.conv4 = nn.Conv2d(64, 128, 3, stride=2, padding=1)
        self.conv5 = nn.Conv2d(128, 256, 3, stride=2, padding=1)

        # Decoder path - upsampling with skip connections.
        # The 2x2 convolutions use padding=0 (output shrinks by one pixel) and
        # the result is resized to match the skip connection.
        # NOTE(review): this resize is not the same operation as TensorFlow's
        # "same" padding for an even kernel; if exact parity with the TF model
        # matters, verify against reference outputs.
        self.up6 = nn.Conv2d(256, 128, 2, padding=0)
        self.conv6 = nn.Conv2d(256, 128, 3, padding=1)  # 256 = 128 (up6) + 128 (conv4)

        self.up7 = nn.Conv2d(128, 64, 2, padding=0)
        self.conv7 = nn.Conv2d(128, 64, 3, padding=1)  # 128 = 64 (up7) + 64 (conv3)

        self.up8 = nn.Conv2d(64, 32, 2, padding=0)
        self.conv8 = nn.Conv2d(64, 32, 3, padding=1)  # 64 = 32 (up8) + 32 (conv2)

        self.up9 = nn.Conv2d(32, 32, 2, padding=0)
        self.conv9 = nn.Conv2d(70, 32, 3, padding=1)  # 70 = 32 (up9) + 32 (conv1) + 6 (inputs)

        # conv10 is not on the output path (see forward); it stays registered
        # so that state dicts containing "conv10.*" keys keep loading.
        self.conv10 = nn.Conv2d(32, 32, 3, padding=1)

        # Output layer - produces residual
        self.residual = nn.Conv2d(32, 3, 1)

    @staticmethod
    def _upsample_to(conv, features, skip):
        """Upsample 2x, apply ``conv`` + ReLU, then resize to ``skip``'s H x W."""
        doubled = F.interpolate(features, scale_factor=2, mode="nearest")
        activated = F.relu(conv(doubled))
        return F.interpolate(activated, size=skip.shape[2:], mode="nearest")

    def forward(self, secret, image):
        """Compute the residual that embeds ``secret`` into ``image``.

        Args:
            secret: torch.Tensor (batch, 100) - binary secret to encode.
            image: torch.Tensor (batch, 3, H, W) - image in [0, 1] range,
                NCHW format. H = W = 400 in the default configuration.

        Returns:
            torch.Tensor (batch, 3, H, W) - residual to add to the image.
        """
        # Normalize inputs to [-0.5, 0.5] (matches TF implementation)
        secret = secret - 0.5
        image = image - 0.5

        # Process secret: 100 -> 7500 -> reshape to 3x50x50 -> enlarge to the
        # image resolution (previously hard-coded to 400x400).
        # NOTE(review): the reshape is channel-first; confirm the dense weights
        # were converted for this layout if they come from a channel-last model.
        secret = F.relu(self.secret_dense(secret))
        secret = secret.view(-1, 3, 50, 50)
        secret_enlarged = F.interpolate(secret, size=image.shape[2:], mode="nearest")

        # Concatenate secret and image along channel dimension
        inputs = torch.cat([secret_enlarged, image], dim=1)

        # Encoder path (each activation is kept for its skip connection)
        conv1 = F.relu(self.conv1(inputs))
        conv2 = F.relu(self.conv2(conv1))
        conv3 = F.relu(self.conv3(conv2))
        conv4 = F.relu(self.conv4(conv3))
        conv5 = F.relu(self.conv5(conv4))

        # Decoder path (upsample, merge with the matching encoder activation)
        up6 = self._upsample_to(self.up6, conv5, conv4)
        conv6 = F.relu(self.conv6(torch.cat([conv4, up6], dim=1)))

        up7 = self._upsample_to(self.up7, conv6, conv3)
        conv7 = F.relu(self.conv7(torch.cat([conv3, up7], dim=1)))

        up8 = self._upsample_to(self.up8, conv7, conv2)
        conv8 = F.relu(self.conv8(torch.cat([conv2, up8], dim=1)))

        up9 = self._upsample_to(self.up9, conv8, conv1)
        conv9 = F.relu(self.conv9(torch.cat([conv1, up9, inputs], dim=1)))

        # Generate residual (no activation on output). The residual is read
        # from conv9, exactly as before; the conv10 activation was computed but
        # never used, so that wasted convolution is no longer evaluated.
        # NOTE(review): presumably this wiring mirrors the reference model -
        # confirm before routing the output through conv10.
        return self.residual(conv9)


def main():
    """Run a forward pass on dummy data and report the tensor shapes.

    Raises:
        RuntimeError: If the residual shape differs from the image shape.
    """
    print("Testing StegaStamp Encoder...")

    encoder = StegaStampEncoder()
    print(f"Total parameters: {sum(p.numel() for p in encoder.parameters()):,}")

    # Create dummy inputs; the secret is made of 0/1 bits like a real payload.
    batch_size = 2
    secret = torch.randint(0, 2, (batch_size, 100)).float()
    image = torch.rand(batch_size, 3, encoder.height, encoder.width)

    # Forward pass
    with torch.no_grad():
        residual = encoder(secret, image)

    print(f"Input secret shape: {secret.shape}")
    print(f"Input image shape: {image.shape}")
    print(f"Output residual shape: {residual.shape}")

    if residual.shape != image.shape:
        raise RuntimeError(
            f"Residual shape {tuple(residual.shape)} does not match "
            f"image shape {tuple(image.shape)}"
        )
    print("✓ Encoder test passed!")


if __name__ == "__main__":
    main()