# stegastamp/encoder.py
# Hugging Face page header (scrape residue): KingTechnician,
# "Upload folder using huggingface_hub", commit 6b430c5 (verified).
"""
StegaStamp Encoder - PyTorch Implementation
Converts the TensorFlow encoder to PyTorch for inference
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
class StegaStampEncoder(nn.Module):
    """U-Net style encoder that maps a (secret, image) pair to an image residual.

    The 100-value secret is projected to a 3x50x50 map, upsampled to the image
    resolution, concatenated with the image and passed through a
    downsampling/upsampling convolutional stack with skip connections.

    Layer attribute names are part of the checkpoint format (``state_dict``
    keys) and must not be renamed.
    """

    def __init__(self, height=400, width=400):
        """Build the layers.

        Args:
            height: Nominal image height. Stored for reference only; ``forward``
                sizes everything from the image tensor it receives.
            width: Nominal image width. Stored for reference only.
        """
        super().__init__()
        self.height = height
        self.width = width

        # Secret processing - converts 100-bit secret to 7500 values (50x50x3)
        self.secret_dense = nn.Linear(100, 7500)

        # Encoder path - downsampling
        self.conv1 = nn.Conv2d(6, 32, 3, padding=1)  # 6 = 3 (secret) + 3 (image)
        self.conv2 = nn.Conv2d(32, 32, 3, stride=2, padding=1)
        self.conv3 = nn.Conv2d(32, 64, 3, stride=2, padding=1)
        self.conv4 = nn.Conv2d(64, 128, 3, stride=2, padding=1)
        self.conv5 = nn.Conv2d(128, 256, 3, stride=2, padding=1)

        # Decoder path - upsampling with skip connections.
        # The 2x2 convs use padding=0 (output shrinks by one pixel); the result
        # is resized back to the skip connection's size in forward().
        self.up6 = nn.Conv2d(256, 128, 2, padding=0)
        self.conv6 = nn.Conv2d(256, 128, 3, padding=1)  # 256 = 128 (up6) + 128 (conv4)
        self.up7 = nn.Conv2d(128, 64, 2, padding=0)
        self.conv7 = nn.Conv2d(128, 64, 3, padding=1)  # 128 = 64 (up7) + 64 (conv3)
        self.up8 = nn.Conv2d(64, 32, 2, padding=0)
        self.conv8 = nn.Conv2d(64, 32, 3, padding=1)  # 64 = 32 (up8) + 32 (conv2)
        self.up9 = nn.Conv2d(32, 32, 2, padding=0)
        self.conv9 = nn.Conv2d(70, 32, 3, padding=1)  # 70 = 32 (up9) + 32 (conv1) + 6 (inputs)
        # conv10 is registered but not applied in forward(): the residual head
        # reads conv9. It is kept so parameter counts and state_dict keys stay
        # compatible with existing checkpoints.
        self.conv10 = nn.Conv2d(32, 32, 3, padding=1)

        # Output layer - produces residual
        self.residual = nn.Conv2d(32, 3, 1)

    @staticmethod
    def _upsample_to(layer, features, skip):
        """Upsample ``features`` 2x, apply ``layer`` + ReLU, resize to ``skip``'s H x W."""
        doubled = F.interpolate(features, scale_factor=2, mode='nearest')
        activated = F.relu(layer(doubled))
        # The unpadded 2x2 conv loses one pixel per axis; snap back to the skip size.
        return F.interpolate(activated, size=skip.shape[2:], mode='nearest')

    def forward(self, secret, image):
        """Compute the residual that embeds ``secret`` into ``image``.

        Args:
            secret: torch.Tensor (batch, 100) - binary secret to encode
            image: torch.Tensor (batch, 3, H, W) - image in [0, 1] range, NCHW
                format (400x400 is the nominal size)

        Returns:
            residual: torch.Tensor (batch, 3, H, W) - residual to add to image
        """
        # Normalize inputs to [-0.5, 0.5] (matches TF implementation)
        secret = secret - 0.5
        image = image - 0.5

        # Process secret: 100 -> 7500 -> reshape to 3x50x50 -> upsample to image size.
        # NOTE(review): if the dense weights are ported from a channels-last (50, 50, 3)
        # TF reshape, this channels-first view orders the values differently - confirm.
        secret = F.relu(self.secret_dense(secret))
        secret = secret.view(-1, 3, 50, 50)
        # Size from the image itself rather than a hard-coded 400x400 so the
        # concatenation below is valid for any input resolution.
        secret_enlarged = F.interpolate(secret, size=image.shape[-2:], mode='nearest')

        # Concatenate secret and image along channel dimension
        inputs = torch.cat([secret_enlarged, image], dim=1)

        # Encoder path (downsampling with skip connections saved)
        conv1 = F.relu(self.conv1(inputs))
        conv2 = F.relu(self.conv2(conv1))
        conv3 = F.relu(self.conv3(conv2))
        conv4 = F.relu(self.conv4(conv3))
        conv5 = F.relu(self.conv5(conv4))

        # Decoder path (upsampling with skip connections)
        # NOTE(review): unpadded 2x2 conv + nearest resize stands in for a
        # 'same'-padded conv - confirm numerical parity if TF weights are loaded.
        up6 = self._upsample_to(self.up6, conv5, conv4)
        conv6 = F.relu(self.conv6(torch.cat([conv4, up6], dim=1)))

        up7 = self._upsample_to(self.up7, conv6, conv3)
        conv7 = F.relu(self.conv7(torch.cat([conv3, up7], dim=1)))

        up8 = self._upsample_to(self.up8, conv7, conv2)
        conv8 = F.relu(self.conv8(torch.cat([conv2, up8], dim=1)))

        up9 = self._upsample_to(self.up9, conv8, conv1)
        conv9 = F.relu(self.conv9(torch.cat([conv1, up9, inputs], dim=1)))

        # Generate residual (no activation on output). The head reads conv9;
        # conv10's output was never consumed, so the full-resolution conv10
        # pass is skipped. Presumably this mirrors the reference TF graph -
        # confirm before routing the head through conv10.
        return self.residual(conv9)
if __name__ == "__main__":
    # Smoke test: build the encoder, run one forward pass on dummy data and
    # verify that the residual has the same shape as the input image.
    print("Testing StegaStamp Encoder...")
    encoder = StegaStampEncoder()
    encoder.eval()  # inference-only check
    print(f"Total parameters: {sum(p.numel() for p in encoder.parameters()):,}")

    # Create dummy inputs: the secret is documented as binary, so draw 0/1
    # bits rather than Gaussian noise.
    batch_size = 2
    secret = torch.randint(0, 2, (batch_size, 100)).float()
    image = torch.rand(batch_size, 3, 400, 400)

    # Forward pass
    with torch.no_grad():
        residual = encoder(secret, image)

    print(f"Input secret shape: {secret.shape}")
    print(f"Input image shape: {image.shape}")
    print(f"Output residual shape: {residual.shape}")

    # Fail loudly instead of printing success unconditionally.
    if residual.shape != image.shape:
        raise RuntimeError(
            f"Residual shape {tuple(residual.shape)} does not match "
            f"image shape {tuple(image.shape)}"
        )
    print("✓ Encoder test passed!")