| """
|
| UNet-Lite Decoder for forgery localization
|
| Lightweight decoder with skip connections, depthwise separable convolutions
|
| """
|
|
|
| import torch
|
| import torch.nn as nn
|
| import torch.nn.functional as F
|
| from typing import List
|
|
|
|
|
class DepthwiseSeparableConv(nn.Module):
    """Depthwise separable convolution for efficiency.

    Pipeline: a grouped (one group per input channel) ``kernel_size`` conv,
    then a 1x1 pointwise conv that mixes channels, then BatchNorm2d and an
    in-place ReLU. Both convolutions are built with ``bias=False``; the
    BatchNorm after the pointwise conv provides the affine shift.
    """

    def __init__(self, in_channels: int, out_channels: int, kernel_size: int = 3):
        super().__init__()
        # kernel_size // 2 keeps the spatial size unchanged for odd kernels.
        pad = kernel_size // 2
        self.depthwise = nn.Conv2d(
            in_channels,
            in_channels,
            kernel_size=kernel_size,
            padding=pad,
            groups=in_channels,
            bias=False,
        )
        self.pointwise = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False)
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply depthwise conv -> pointwise conv -> BatchNorm -> ReLU."""
        return self.relu(self.bn(self.pointwise(self.depthwise(x))))
|
|
|
|
|
class DecoderBlock(nn.Module):
    """Single decoder block with skip connection.

    Bilinearly resizes the incoming decoder features to the skip tensor's
    spatial size, concatenates them with the skip along dim 1 (decoder
    features first), and refines the result with two depthwise separable
    convolutions.
    """

    def __init__(self, in_channels: int, skip_channels: int, out_channels: int):
        """
        Initialize decoder block

        Args:
            in_channels: Input channels from previous decoder stage
            skip_channels: Channels from encoder skip connection
            out_channels: Output channels
        """
        super().__init__()
        # conv1 consumes the concatenation of decoder features and skip features.
        self.conv1 = DepthwiseSeparableConv(in_channels + skip_channels, out_channels)
        self.conv2 = DepthwiseSeparableConv(out_channels, out_channels)

    def forward(self, x: torch.Tensor, skip: torch.Tensor) -> torch.Tensor:
        """
        Forward pass

        Args:
            x: Input from previous decoder stage
            skip: Skip connection from encoder

        Returns:
            Decoded features with ``out_channels`` channels at the spatial
            size of ``skip``.
        """
        upsampled = F.interpolate(x, size=skip.shape[2:], mode='bilinear', align_corners=False)
        merged = torch.cat([upsampled, skip], dim=1)
        return self.conv2(self.conv1(merged))
|
|
|
|
|
class UNetLiteDecoder(nn.Module):
    """
    UNet-Lite decoder for forgery localization

    Features:
    - Skip connections from encoder stages
    - Bilinear upsampling
    - Depthwise separable convolutions for efficiency
    """

    def __init__(self,
                 encoder_channels: List[int],
                 decoder_channels: List[int] = None,
                 output_channels: int = 1):
        """
        Initialize decoder

        Args:
            encoder_channels: List of encoder feature channels [stage0, ..., stageN];
                must contain at least one stage.
            decoder_channels: Output channels of the initial conv followed by each
                decoder block, deepest first. Defaults to [256, 128, 64, 32, 16].
                If there are more decoder blocks than entries, the last entry is
                reused; surplus entries are ignored.
            output_channels: Number of output channels (1 for binary mask)

        Raises:
            ValueError: If ``encoder_channels`` or ``decoder_channels`` is empty.
        """
        super().__init__()

        if decoder_channels is None:
            decoder_channels = [256, 128, 64, 32, 16]
        if not encoder_channels:
            raise ValueError("encoder_channels must contain at least one stage")
        if not decoder_channels:
            raise ValueError("decoder_channels must contain at least one entry")

        # Decode from the deepest encoder stage towards the shallowest.
        reversed_encoder = list(encoder_channels)[::-1]

        self.initial_conv = DepthwiseSeparableConv(reversed_encoder[0], decoder_channels[0])

        self.decoder_blocks = nn.ModuleList()
        # Track the channel count actually produced so far, so each block's input
        # (and the final head) matches the previous stage's real output even when
        # decoder_channels is shorter or longer than the encoder list.
        prev_ch = decoder_channels[0]
        for i, skip_ch in enumerate(reversed_encoder[1:]):
            out_ch = decoder_channels[min(i + 1, len(decoder_channels) - 1)]
            self.decoder_blocks.append(DecoderBlock(prev_ch, skip_ch, out_ch))
            prev_ch = out_ch

        # Prediction head. Despite its name it does not resize (forward() does
        # that); the attribute name is kept so existing state_dict keys still match.
        self.final_upsample = nn.Sequential(
            DepthwiseSeparableConv(prev_ch, prev_ch),
            nn.Conv2d(prev_ch, output_channels, kernel_size=1),
        )

        self.decoder_channels = decoder_channels

        print("UNet-Lite decoder initialized")
        print(f"Encoder channels: {encoder_channels}")
        print(f"Decoder channels: {decoder_channels}")

    def forward(self, encoder_features: List[torch.Tensor]) -> tuple:
        """
        Forward pass

        Args:
            encoder_features: List of encoder features [stage0, ..., stageN]

        Returns:
            output: Forgery logit map (B, output_channels, 2*H0, 2*W0), where
                (H0, W0) is the spatial size of ``encoder_features[0]``. No
                sigmoid is applied here.
            decoder_features: List of decoder features for hybrid extraction
                (initial conv output followed by each block's output).

        Raises:
            ValueError: If fewer feature maps are given than the decoder was
                built for.
        """
        expected = len(self.decoder_blocks) + 1
        if len(encoder_features) < expected:
            raise ValueError(
                f"expected at least {expected} encoder feature maps, "
                f"got {len(encoder_features)}"
            )

        features = encoder_features[::-1]

        x = self.initial_conv(features[0])
        decoder_features = [x]

        for block, skip in zip(self.decoder_blocks, features[1:]):
            x = block(x, skip)
            decoder_features.append(x)

        # Output is 2x the spatial size of the shallowest encoder stage
        # (presumably stage0 is at half the input resolution — TODO confirm).
        # Height and width are scaled independently so non-square inputs keep
        # their aspect ratio.
        height, width = encoder_features[0].shape[2], encoder_features[0].shape[3]
        x = F.interpolate(x, size=(height * 2, width * 2), mode='bilinear', align_corners=False)
        output = self.final_upsample(x)

        return output, decoder_features
|
|
|
|
|
def get_decoder(encoder_channels: List[int], config) -> UNetLiteDecoder:
    """
    Build a UNetLiteDecoder from a configuration object.

    Args:
        encoder_channels: Encoder feature channels, passed through unchanged.
        config: Configuration object exposing ``get(key, default)``; only the
            ``'model.output_channels'`` key is read (default 1).

    Returns:
        Decoder instance using the default decoder channel widths.
    """
    return UNetLiteDecoder(
        encoder_channels,
        output_channels=config.get('model.output_channels', 1),
    )
|
|
|