|
|
"""
|
|
|
StegaStamp Encoder - PyTorch Implementation

PyTorch port of the TensorFlow StegaStamp encoder, for inference.
|
|
|
"""
|
|
|
import torch
|
|
|
import torch.nn as nn
|
|
|
import torch.nn.functional as F
|
|
|
|
|
|
|
|
|
class StegaStampEncoder(nn.Module):
    """StegaStamp U-Net encoder: predicts a residual that embeds a secret.

    The secret is projected to a 3x50x50 plane, upsampled to 400x400,
    concatenated with the image and passed through a 4-level U-Net whose
    final 1x1 convolution yields a 3-channel residual.

    TensorFlow ``padding='same'`` semantics are reproduced explicitly. For
    the even spatial sizes used here (400 -> 200 -> 100 -> 50 -> 25), 'same'
    padding for the stride-2 3x3 convolutions and the stride-1 2x2
    up-convolutions is 0 pixels before and 1 after. PyTorch's symmetric
    ``padding=`` argument cannot express that, so those layers are built with
    ``padding=0`` and padded in :meth:`_conv_same`. Only the (non-persistent)
    ``padding`` setting differs from a naive port; parameter names and shapes,
    and therefore state dicts, are unchanged.

    Args:
        height: stored as ``self.height``; ``forward`` is currently fixed to
            400x400 inputs regardless of this value.
        width: stored as ``self.width``; see ``height``.
    """

    def __init__(self, height=400, width=400):
        super().__init__()
        self.height = height
        self.width = width

        # 100 secret values -> 7500 = 3 * 50 * 50 features (reshaped in forward).
        self.secret_dense = nn.Linear(100, 7500)

        # Contracting path. Input has 6 channels: 3 secret-plane + 3 image.
        self.conv1 = nn.Conv2d(6, 32, 3, padding=1)
        # Stride-2 convs: TF-'same' padding is applied manually in _conv_same.
        self.conv2 = nn.Conv2d(32, 32, 3, stride=2, padding=0)
        self.conv3 = nn.Conv2d(32, 64, 3, stride=2, padding=0)
        self.conv4 = nn.Conv2d(64, 128, 3, stride=2, padding=0)
        self.conv5 = nn.Conv2d(128, 256, 3, stride=2, padding=0)

        # Expanding path: each 2x2 "up" conv follows a 2x nearest upsample and
        # is likewise padded manually; each conv6..conv8 consumes up + skip.
        self.up6 = nn.Conv2d(256, 128, 2, padding=0)
        self.conv6 = nn.Conv2d(256, 128, 3, padding=1)
        self.up7 = nn.Conv2d(128, 64, 2, padding=0)
        self.conv7 = nn.Conv2d(128, 64, 3, padding=1)
        self.up8 = nn.Conv2d(64, 32, 2, padding=0)
        self.conv8 = nn.Conv2d(64, 32, 3, padding=1)
        self.up9 = nn.Conv2d(32, 32, 2, padding=0)
        # 70 = 32 (conv1) + 32 (up9) + 6 (raw inputs) concatenated channels.
        self.conv9 = nn.Conv2d(70, 32, 3, padding=1)
        # Kept so checkpoints containing conv10 weights still load strictly;
        # forward does not use its output (the residual comes from conv9).
        self.conv10 = nn.Conv2d(32, 32, 3, padding=1)

        # 1x1 projection to the 3-channel residual; no activation is applied.
        self.residual = nn.Conv2d(32, 3, 1)

    @staticmethod
    def _tf_same_pad(x, kernel_size, stride):
        """Zero-pad an NCHW tensor the way TensorFlow ``padding='same'`` does.

        For a spatial dim of size ``n``, TF pads a total of
        ``max((ceil(n / s) - 1) * s + k - n, 0)`` pixels. When that total is
        odd, the extra pixel goes after (bottom/right), not before.

        Args:
            x: tensor whose last two dims are (height, width).
            kernel_size: (kh, kw) of the convolution that will consume ``x``.
            stride: (sh, sw) of that convolution.

        Returns:
            The zero-padded tensor.
        """
        pads = []
        # F.pad lists padding from the last dim backwards: (left, right, top, bottom).
        for n, k, s in ((x.shape[-1], kernel_size[1], stride[1]),
                        (x.shape[-2], kernel_size[0], stride[0])):
            total = max((-(-n // s) - 1) * s + k - n, 0)  # -(-n // s) == ceil(n / s)
            pads += [total // 2, total - total // 2]
        return F.pad(x, pads)

    def _conv_same(self, conv, x):
        """Apply ``conv`` (built with padding=0) with TF-'same' padding, then ReLU."""
        return F.relu(conv(self._tf_same_pad(x, conv.kernel_size, conv.stride)))

    def _up_conv(self, conv, x):
        """Do a 2x nearest-neighbour upsample, then a TF-'same'-padded ``conv`` + ReLU."""
        return self._conv_same(conv, F.interpolate(x, scale_factor=2, mode='nearest'))

    def forward(self, secret, image):
        """
        Args:
            secret: torch.Tensor (batch, 100) - binary secret to encode
            image: torch.Tensor (batch, 3, 400, 400) - image in [0, 1] range, NCHW format

        Returns:
            residual: torch.Tensor (batch, 3, 400, 400) - residual to add to image
        """
        # Shift both inputs by -0.5 (centres [0, 1] data around zero).
        secret = secret - 0.5
        image = image - 0.5

        # Secret -> 3x50x50 plane -> nearest upsample to 400x400.
        secret = F.relu(self.secret_dense(secret))
        # NOTE(review): this view reads the 7500 dense outputs channel-first
        # (C, H, W). TensorFlow tensors are channel-last by default; if the TF
        # model reshaped to (50, 50, 3) and the dense weights were copied
        # without permuting their output rows, this should instead be
        # view(-1, 50, 50, 3).permute(0, 3, 1, 2). Confirm against the converter.
        secret = secret.view(-1, 3, 50, 50)
        secret_enlarged = F.interpolate(secret, size=(400, 400), mode='nearest')

        inputs = torch.cat([secret_enlarged, image], dim=1)

        # Contracting path: 400 -> 200 -> 100 -> 50 -> 25.
        conv1 = F.relu(self.conv1(inputs))
        conv2 = self._conv_same(self.conv2, conv1)
        conv3 = self._conv_same(self.conv3, conv2)
        conv4 = self._conv_same(self.conv4, conv3)
        conv5 = self._conv_same(self.conv5, conv4)

        # Expanding path. With 'same' padding each up-conv output already has
        # its skip tensor's spatial size, so no resize is needed before cat.
        up6 = self._up_conv(self.up6, conv5)
        conv6 = F.relu(self.conv6(torch.cat([conv4, up6], dim=1)))
        up7 = self._up_conv(self.up7, conv6)
        conv7 = F.relu(self.conv7(torch.cat([conv3, up7], dim=1)))
        up8 = self._up_conv(self.up8, conv7)
        conv8 = F.relu(self.conv8(torch.cat([conv2, up8], dim=1)))
        up9 = self._up_conv(self.up9, conv8)
        conv9 = F.relu(self.conv9(torch.cat([conv1, up9, inputs], dim=1)))

        # The residual is projected from conv9. conv10 is not evaluated: its
        # output never reached the residual, so computing it only cost time.
        # NOTE(review): presumably this mirrors the reference TF model, where
        # conv10 is likewise unused; confirm before switching to conv10.
        return self.residual(conv9)
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Smoke test: run a random batch through the encoder and check the
    # residual has the documented (batch, 3, 400, 400) shape.
    print("Testing StegaStamp Encoder...")

    encoder = StegaStampEncoder().eval()
    print(f"Total parameters: {sum(p.numel() for p in encoder.parameters()):,}")

    batch_size = 2
    # forward() documents a binary secret, so feed 0/1 bits, not Gaussian noise.
    secret = torch.randint(0, 2, (batch_size, 100)).float()
    image = torch.rand(batch_size, 3, 400, 400)

    with torch.no_grad():
        residual = encoder(secret, image)

    print(f"Input secret shape: {secret.shape}")
    print(f"Input image shape: {image.shape}")
    print(f"Output residual shape: {residual.shape}")

    # Check the contract before reporting success (explicit raise survives -O).
    expected_shape = (batch_size, 3, 400, 400)
    if tuple(residual.shape) != expected_shape:
        raise AssertionError(
            f"expected residual shape {expected_shape}, got {tuple(residual.shape)}"
        )
    print("✓ Encoder test passed!")