ESmike commited on
Commit
36ccd8b
·
verified ·
1 Parent(s): 7c22697

chore: add message extractor classes.

Browse files
Files changed (1) hide show
  1. modeling_message_extractor.py +64 -0
modeling_message_extractor.py ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from huggingface_hub import PyTorchModelHubMixin
4
+
5
+ class ConvBNRelu(nn.Module):
6
+ """
7
+ Building block used in HiDDeN network. Is a sequence of Convolution, Batch Normalization, and ReLU activation
8
+ """
9
+
10
+ def __init__(self, channels_in, channels_out):
11
+ super(ConvBNRelu, self).__init__()
12
+
13
+ self.layers = nn.Sequential(
14
+ nn.Conv2d(channels_in, channels_out, 3, stride=1, padding=1),
15
+ nn.BatchNorm2d(channels_out, eps=1e-3),
16
+ nn.GELU()
17
+ )
18
+
19
+ def forward(self, x):
20
+ return self.layers(x)
21
+
22
+
23
+ class HiddenDecoder(nn.Module):
24
+ """
25
+ Decoder module. Receives a watermarked image and extracts the watermark.
26
+ """
27
+
28
+ def __init__(self, num_blocks, num_bits, channels, redundancy=1):
29
+ super(HiddenDecoder, self).__init__()
30
+
31
+ layers = [ConvBNRelu(3, channels)]
32
+ for _ in range(num_blocks - 1):
33
+ layers.append(ConvBNRelu(channels, channels))
34
+
35
+ layers.append(ConvBNRelu(channels, num_bits * redundancy))
36
+ layers.append(nn.AdaptiveAvgPool2d(output_size=(1, 1)))
37
+ self.layers = nn.Sequential(*layers)
38
+
39
+ self.linear = nn.Linear(num_bits * redundancy, num_bits * redundancy)
40
+
41
+ self.num_bits = num_bits
42
+ self.redundancy = redundancy
43
+
44
+ def forward(self, img_w):
45
+ x = self.layers(img_w) # b d 1 1
46
+ x = x.squeeze(-1).squeeze(-1) # b d
47
+ x = self.linear(x)
48
+
49
+ x = x.view(-1, self.num_bits, self.redundancy) # b k*r -> b k r
50
+ x = torch.sum(x, dim=-1) # b k r -> b k
51
+
52
+ return x
53
+
54
+
55
+ class MsgExtractor(nn.Module, PyTorchModelHubMixin):
56
+ def __init__(self, hidden_decoder: nn.Module, in_features: int, out_features: int):
57
+ super().__init__()
58
+ self.hidden_decoder = hidden_decoder
59
+ self.head = nn.Linear(in_features, out_features)
60
+
61
+ def forward(self, x):
62
+ x = self.hidden_decoder(x)
63
+ x = self.head(x)
64
+ return x