File size: 5,066 Bytes
6b430c5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
"""
StegaStamp Encoder - PyTorch Implementation

PyTorch port of the TensorFlow StegaStamp encoder, for inference.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F


class StegaStampEncoder(nn.Module):
    """U-Net mapping a 100-bit secret plus an RGB image to an additive residual.

    PyTorch port of the TensorFlow StegaStamp encoder. The convolutions whose
    TensorFlow ``padding='same'`` behaviour is asymmetric are built with
    ``padding=0`` and padded explicitly by :meth:`_conv_same`. These are the
    stride-2 downsampling convs and the 2x2 upsampling convs, and the explicit
    padding puts the extra zero row/column at the bottom/right, as TensorFlow
    does. Module and parameter names/shapes are unchanged, so existing state
    dicts still load.

    Args:
        height: Stored as ``self.height``. The forward pass itself works on
            400x400 images.
        width: Stored as ``self.width``. The forward pass itself works on
            400x400 images.
    """

    def __init__(self, height=400, width=400):
        super().__init__()
        self.height = height
        self.width = width

        # Secret processing - converts 100-bit secret to 7500 values (50x50x3)
        self.secret_dense = nn.Linear(100, 7500)

        # Encoder path - downsampling. The stride-2 convs are padded at call
        # time (see _conv_same), hence padding=0 here.
        self.conv1 = nn.Conv2d(6, 32, 3, padding=1)  # 6 = 3 (secret) + 3 (image)
        self.conv2 = nn.Conv2d(32, 32, 3, stride=2, padding=0)
        self.conv3 = nn.Conv2d(32, 64, 3, stride=2, padding=0)
        self.conv4 = nn.Conv2d(64, 128, 3, stride=2, padding=0)
        self.conv5 = nn.Conv2d(128, 256, 3, stride=2, padding=0)

        # Decoder path - 2x nearest upsampling, a 2x2 conv (padded at call
        # time, see _conv_same), then concatenation with the skip connection.
        self.up6 = nn.Conv2d(256, 128, 2, padding=0)
        self.conv6 = nn.Conv2d(256, 128, 3, padding=1)  # 256 = 128 (up6) + 128 (conv4)

        self.up7 = nn.Conv2d(128, 64, 2, padding=0)
        self.conv7 = nn.Conv2d(128, 64, 3, padding=1)  # 128 = 64 (up7) + 64 (conv3)

        self.up8 = nn.Conv2d(64, 32, 2, padding=0)
        self.conv8 = nn.Conv2d(64, 32, 3, padding=1)  # 64 = 32 (up8) + 32 (conv2)

        self.up9 = nn.Conv2d(32, 32, 2, padding=0)
        self.conv9 = nn.Conv2d(70, 32, 3, padding=1)  # 70 = 32 (up9) + 32 (conv1) + 6 (inputs)
        # Kept so state dicts that contain its weights still load; its output
        # is not used by forward() (the residual head reads conv9).
        self.conv10 = nn.Conv2d(32, 32, 3, padding=1)

        # Output layer - produces residual
        self.residual = nn.Conv2d(32, 3, 1)

    @staticmethod
    def _conv_same(conv, x):
        """Apply ``conv`` to NCHW ``x`` using TensorFlow-style 'SAME' padding.

        ``conv`` must be built with ``padding=0``. The output spatial size is
        ``ceil(in / stride)``. When the total padding needed is odd, the extra
        zero row/column goes on the bottom/right, as TensorFlow does. This
        differs from PyTorch's symmetric ``padding=`` argument.
        """
        pad = []
        # F.pad takes (left, right, top, bottom): last dimension (W) first.
        for size, k, s in zip(reversed(x.shape[-2:]),
                              reversed(conv.kernel_size),
                              reversed(conv.stride)):
            out = -(-size // s)  # ceil(size / s)
            total = max((out - 1) * s + k - size, 0)
            pad += [total // 2, total - total // 2]
        return conv(F.pad(x, pad))

    def _up(self, up_conv, x):
        """Upsample ``x`` 2x (nearest), then apply ``up_conv`` + ReLU with SAME padding."""
        return F.relu(self._conv_same(up_conv, F.interpolate(x, scale_factor=2, mode='nearest')))

    def forward(self, secret, image):
        """Compute the residual that encodes ``secret`` into ``image``.

        Args:
            secret: torch.Tensor (batch, 100) - binary secret to encode.
            image: torch.Tensor (batch, 3, 400, 400) - image in [0, 1] range,
                NCHW format.

        Returns:
            torch.Tensor (batch, 3, 400, 400) - residual to add to the image
            (raw output of the final 1x1 conv, no activation).

        Raises:
            ValueError: if the spatial size of ``image`` differs from the
                400x400 map the secret is upsampled to.
        """
        # Normalize inputs to [-0.5, 0.5] (matches TF implementation)
        secret = secret - 0.5
        image = image - 0.5

        # Process secret: 100 -> 7500 -> reshape to 3x50x50 -> upsample to 3x400x400.
        # NOTE(review): view(-1, 3, 50, 50) reads the 7500 dense outputs
        # channel-first. If the TF reference reshapes them channel-last to
        # (50, 50, 3), as the "50x50x3" note in __init__ suggests, and the
        # Linear weights were copied without reordering their output rows,
        # this should be view(-1, 50, 50, 3).permute(0, 3, 1, 2). Confirm
        # against the weight-conversion code before changing.
        secret = F.relu(self.secret_dense(secret))
        secret = secret.view(-1, 3, 50, 50)
        secret_enlarged = F.interpolate(secret, size=(400, 400), mode='nearest')

        if image.shape[-2:] != secret_enlarged.shape[-2:]:
            raise ValueError(
                f"expected image with spatial size {tuple(secret_enlarged.shape[-2:])}, "
                f"got {tuple(image.shape[-2:])}"
            )

        # Concatenate secret and image along channel dimension
        inputs = torch.cat([secret_enlarged, image], dim=1)

        # Encoder path (outputs kept for the skip connections).
        # 400 -> 200 -> 100 -> 50 -> 25 for a 400x400 input.
        conv1 = F.relu(self.conv1(inputs))
        conv2 = F.relu(self._conv_same(self.conv2, conv1))
        conv3 = F.relu(self._conv_same(self.conv3, conv2))
        conv4 = F.relu(self._conv_same(self.conv4, conv3))
        conv5 = F.relu(self._conv_same(self.conv5, conv4))

        # Decoder path: each upsampled map has exactly the skip's size, so no
        # resizing is needed before concatenation.
        up6 = self._up(self.up6, conv5)
        conv6 = F.relu(self.conv6(torch.cat([conv4, up6], dim=1)))

        up7 = self._up(self.up7, conv6)
        conv7 = F.relu(self.conv7(torch.cat([conv3, up7], dim=1)))

        up8 = self._up(self.up8, conv7)
        conv8 = F.relu(self.conv8(torch.cat([conv2, up8], dim=1)))

        up9 = self._up(self.up9, conv8)
        conv9 = F.relu(self.conv9(torch.cat([conv1, up9, inputs], dim=1)))

        # The residual head reads conv9. self.conv10 was previously evaluated
        # and then discarded, so skipping it saves a full-resolution 3x3 conv
        # without changing the result.
        # NOTE(review): presumably this mirrors the reference model. Feeding
        # conv10 to the head instead would change the outputs of pretrained
        # weights.
        return self.residual(conv9)


if __name__ == "__main__":
    # Smoke test: one forward pass on random inputs, then check the output shape.
    print("Testing StegaStamp Encoder...")

    encoder = StegaStampEncoder()
    encoder.eval()
    print(f"Total parameters: {sum(p.numel() for p in encoder.parameters()):,}")

    # Create dummy inputs: the encoder documents a binary secret and an
    # image in [0, 1], so generate exactly that.
    batch_size = 2
    secret = torch.randint(0, 2, (batch_size, 100)).float()
    image = torch.rand(batch_size, 3, 400, 400)

    # Forward pass
    with torch.no_grad():
        residual = encoder(secret, image)

    print(f"Input secret shape: {secret.shape}")
    print(f"Input image shape: {image.shape}")
    print(f"Output residual shape: {residual.shape}")
    # Explicit check (not `assert`, which is stripped under `python -O`).
    if residual.shape != image.shape:
        raise SystemExit(
            f"Unexpected residual shape {tuple(residual.shape)}, "
            f"expected {tuple(image.shape)}"
        )
    print("✓ Encoder test passed!")