File size: 6,443 Bytes
ff0e79e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
"""

UNet-Lite Decoder for forgery localization

Lightweight decoder with skip connections, depthwise separable convolutions

"""

from typing import List, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F


class DepthwiseSeparableConv(nn.Module):
    """Depthwise conv -> 1x1 pointwise conv -> BatchNorm -> ReLU.

    Splitting a dense convolution into a per-channel spatial filter and a
    1x1 channel-mixing filter keeps the block cheap.

    Args:
        in_channels: Channels of the incoming tensor.
        out_channels: Channels produced by the pointwise convolution.
        kernel_size: Spatial size of the depthwise kernel; padding is
            ``kernel_size // 2``, which preserves H and W for odd sizes.
    """

    def __init__(self, in_channels: int, out_channels: int, kernel_size: int = 3):
        super().__init__()

        same_pad = kernel_size // 2

        # groups == in_channels gives one spatial filter per input channel.
        self.depthwise = nn.Conv2d(
            in_channels,
            in_channels,
            kernel_size=kernel_size,
            padding=same_pad,
            groups=in_channels,
            bias=False,
        )
        # No bias: the BatchNorm right after it carries its own shift.
        self.pointwise = nn.Conv2d(in_channels, out_channels, 1, bias=False)
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply depthwise, pointwise, normalization and activation in order."""
        return self.relu(self.bn(self.pointwise(self.depthwise(x))))


class DecoderBlock(nn.Module):
    """One decoder stage: upsample, fuse with an encoder skip, refine twice."""

    def __init__(self, in_channels: int, skip_channels: int, out_channels: int):
        """
        Build the two convolutions of this stage.

        Args:
            in_channels: Channels of the tensor coming from the previous
                decoder stage.
            skip_channels: Channels of the encoder skip-connection tensor.
            out_channels: Channels produced by this stage.
        """
        super().__init__()

        # conv1 consumes the channel-wise concatenation of both inputs.
        self.conv1 = DepthwiseSeparableConv(in_channels + skip_channels, out_channels)
        self.conv2 = DepthwiseSeparableConv(out_channels, out_channels)

    def forward(self, x: torch.Tensor, skip: torch.Tensor) -> torch.Tensor:
        """
        Run the stage.

        Args:
            x: Output of the previous decoder stage.
            skip: Encoder feature map whose spatial size ``x`` is resized to.

        Returns:
            Refined features with ``out_channels`` channels at the spatial
            size of ``skip``.
        """
        # Bilinearly resize to the skip tensor's spatial dimensions.
        upsampled = F.interpolate(x, size=skip.shape[2:], mode='bilinear', align_corners=False)
        fused = torch.cat((upsampled, skip), dim=1)
        return self.conv2(self.conv1(fused))


class UNetLiteDecoder(nn.Module):
    """
    UNet-Lite decoder for forgery localization.

    Features:
    - Skip connections from encoder stages
    - Bilinear upsampling
    - Depthwise separable convolutions for efficiency
    """

    def __init__(self,
                 encoder_channels: List[int],
                 decoder_channels: Optional[List[int]] = None,
                 output_channels: int = 1):
        """
        Initialize decoder.

        Args:
            encoder_channels: Channel counts of the encoder features, ordered
                shallowest to deepest ``[stage0, ..., stageN]``.
            decoder_channels: Requested channel widths of the decoder stages,
                deepest first. Entry 0 is the output of the initial
                convolution on the deepest encoder feature; entry ``i + 1`` is
                the output of decoder block ``i``. When there are more decoder
                blocks than entries, the last entry is reused. Defaults to
                ``[256, 128, 64, 32, 16]``.
            output_channels: Number of output channels (1 for binary mask).

        Raises:
            ValueError: If ``encoder_channels`` or ``decoder_channels`` is empty.
        """
        super().__init__()

        if not encoder_channels:
            raise ValueError("encoder_channels must contain at least one stage")

        # Default decoder channels if not provided
        if decoder_channels is None:
            decoder_channels = [256, 128, 64, 32, 16]
        if not decoder_channels:
            raise ValueError("decoder_channels must contain at least one entry")

        # Deepest encoder stage first, matching the bottom-up decoding order.
        reversed_encoder = encoder_channels[::-1]

        # Initial convolution from deepest encoder features
        self.initial_conv = DepthwiseSeparableConv(reversed_encoder[0], decoder_channels[0])

        # Track the channel width actually flowing through the decoder. Each
        # block is fed the previous block's real output width (rather than
        # decoder_channels[i]), so the chain stays consistent even when the
        # encoder and decoder channel lists have different lengths.
        current_ch = decoder_channels[0]
        feature_channels = [current_ch]

        self.decoder_blocks = nn.ModuleList()
        last_idx = len(decoder_channels) - 1
        for i, skip_ch in enumerate(reversed_encoder[1:]):
            # Reuse the last configured width once the list is exhausted.
            out_ch = decoder_channels[min(i + 1, last_idx)]
            self.decoder_blocks.append(DecoderBlock(current_ch, skip_ch, out_ch))
            current_ch = out_ch
            feature_channels.append(out_ch)

        # Output head sized to the channels the last stage really produces.
        # This equals decoder_channels[-1] whenever the encoder has at least
        # len(decoder_channels) stages (the only configurations that
        # previously worked); otherwise the old head had mismatched channels.
        self.final_upsample = nn.Sequential(
            DepthwiseSeparableConv(current_ch, current_ch),
            nn.Conv2d(current_ch, output_channels, kernel_size=1),
        )

        # Store decoder feature channels for feature extraction
        self.decoder_channels = decoder_channels
        # Channel count of each tensor in the ``decoder_features`` list
        # returned by forward(), in the same (deepest-first) order.
        self.feature_channels = feature_channels

        print("UNet-Lite decoder initialized")
        print(f"Encoder channels: {encoder_channels}")
        print(f"Decoder channels: {decoder_channels}")

    def forward(self, encoder_features: List[torch.Tensor]) -> tuple:
        """
        Decode encoder features into a localization map.

        Args:
            encoder_features: Encoder feature maps ordered shallowest to
                deepest ``[stage0, ..., stageN]``. ``stage0`` is assumed to be
                at 1/2 of the input resolution.

        Returns:
            output: Unnormalized scores of shape
                ``(B, output_channels, 2 * H0, 2 * W0)``, where ``(H0, W0)`` is
                the spatial size of ``encoder_features[0]``. The final layer
                is a plain 1x1 convolution with no activation; presumably a
                sigmoid is applied downstream to obtain probabilities.
            decoder_features: The initial-conv output followed by every
                decoder block output, deepest first, for hybrid extraction.
        """
        # Reverse for bottom-up decoding
        features = encoder_features[::-1]

        x = self.initial_conv(features[0])

        # Store decoder features for hybrid feature extraction
        decoder_features = [x]

        # Decoder blocks with skip connections
        for i, block in enumerate(self.decoder_blocks):
            x = block(x, features[i + 1])
            decoder_features.append(x)

        # Upsample back to the input resolution. Height and width are doubled
        # independently so non-square inputs keep their aspect ratio.
        height = encoder_features[0].shape[2]
        width = encoder_features[0].shape[3]
        x = F.interpolate(x, size=(height * 2, width * 2), mode='bilinear', align_corners=False)
        output = self.final_upsample(x)

        return output, decoder_features


def get_decoder(encoder_channels: List[int], config) -> UNetLiteDecoder:
    """
    Build a ``UNetLiteDecoder`` from a configuration object.

    Args:
        encoder_channels: Channel counts of the encoder features, shallowest
            first.
        config: Object exposing ``get(key, default)``. Only the key
            ``'model.output_channels'`` is read (default 1).

    Returns:
        A decoder that uses the default decoder channel widths.
    """
    return UNetLiteDecoder(
        encoder_channels,
        output_channels=config.get('model.output_channels', 1),
    )