JKrishnanandhaa's picture
Upload 54 files
ff0e79e verified
raw
history blame
6.44 kB
"""
UNet-Lite Decoder for forgery localization
Lightweight decoder with skip connections, depthwise separable convolutions
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List
class DepthwiseSeparableConv(nn.Module):
    """Depthwise conv -> pointwise (1x1) conv -> BatchNorm -> ReLU.

    Factorizes a standard convolution into a per-channel spatial convolution
    followed by a 1x1 channel-mixing convolution, for efficiency.
    """

    def __init__(self, in_channels: int, out_channels: int, kernel_size: int = 3):
        """
        Args:
            in_channels: Channels of the input tensor.
            out_channels: Channels of the output tensor.
            kernel_size: Spatial kernel size of the depthwise conv. Padding is
                kernel_size // 2, so odd sizes keep the spatial size unchanged.
        """
        super().__init__()
        same_pad = kernel_size // 2
        # groups == in_channels: each spatial filter sees exactly one input channel.
        self.depthwise = nn.Conv2d(
            in_channels,
            in_channels,
            kernel_size=kernel_size,
            padding=same_pad,
            groups=in_channels,
            bias=False,
        )
        # Both convs are bias-free; presumably because the BatchNorm below
        # provides the learned shift.
        self.pointwise = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False)
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply depthwise conv, pointwise conv, batch-norm and ReLU in sequence."""
        return self.relu(self.bn(self.pointwise(self.depthwise(x))))
class DecoderBlock(nn.Module):
    """Decoder stage: resize to the skip's spatial size, concatenate, refine.

    The incoming decoder features are bilinearly resized to match the encoder
    skip tensor, stacked with it along the channel axis, then passed through
    two depthwise-separable conv layers.
    """

    def __init__(self, in_channels: int, skip_channels: int, out_channels: int):
        """
        Args:
            in_channels: Channels of the tensor from the previous decoder stage.
            skip_channels: Channels of the encoder skip-connection tensor.
            out_channels: Channels produced by this block.
        """
        super().__init__()
        # The first conv sees both inputs' channels stacked by torch.cat.
        self.conv1 = DepthwiseSeparableConv(in_channels + skip_channels, out_channels)
        self.conv2 = DepthwiseSeparableConv(out_channels, out_channels)

    def forward(self, x: torch.Tensor, skip: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: Features from the previous decoder stage.
            skip: Encoder features for this stage; its trailing spatial dims
                (shape[2:]) set the resize target.

        Returns:
            Features with out_channels channels at skip's spatial size.
        """
        target_hw = skip.shape[2:]
        resized = F.interpolate(x, size=target_hw, mode='bilinear', align_corners=False)
        fused = torch.cat((resized, skip), dim=1)
        return self.conv2(self.conv1(fused))
class UNetLiteDecoder(nn.Module):
    """
    UNet-Lite decoder for forgery localization

    Features:
    - Skip connections from encoder stages
    - Bilinear upsampling
    - Depthwise separable convolutions for efficiency
    """

    def __init__(self,
                 encoder_channels: List[int],
                 decoder_channels: List[int] = None,
                 output_channels: int = 1):
        """
        Initialize decoder

        Args:
            encoder_channels: Encoder feature channels [stage0, ..., stageN];
                the last entry is treated as the deepest stage and fed to the
                initial conv.
            decoder_channels: Output channels of the initial conv followed by
                each decoder block. If there are more decoder blocks than
                entries, the last entry is reused. Defaults to
                [256, 128, 64, 32, 16].
            output_channels: Number of output channels (1 for binary mask)
        """
        super().__init__()

        # Default decoder channels if not provided
        if decoder_channels is None:
            decoder_channels = [256, 128, 64, 32, 16]

        # Reverse encoder channels for decoding (deepest stage first)
        reversed_enc = encoder_channels[::-1]

        # Initial convolution from deepest encoder features
        self.initial_conv = DepthwiseSeparableConv(reversed_enc[0], decoder_channels[0])

        # Decoder blocks, one per remaining (shallower) encoder stage.
        # current_ch tracks what the previous stage actually emits. Each block's
        # input, and the final head, then match even when decoder_channels is
        # shorter or longer than the number of decoder stages.
        self.decoder_blocks = nn.ModuleList()
        current_ch = decoder_channels[0]
        last_idx = len(decoder_channels) - 1
        for i, skip_ch in enumerate(reversed_enc[1:]):
            out_ch = decoder_channels[min(i + 1, last_idx)]
            self.decoder_blocks.append(DecoderBlock(current_ch, skip_ch, out_ch))
            current_ch = out_ch

        # Output head: refine at full resolution, then project to output_channels.
        # No activation follows the final 1x1 conv.
        self.final_upsample = nn.Sequential(
            DepthwiseSeparableConv(current_ch, current_ch),
            nn.Conv2d(current_ch, output_channels, kernel_size=1)
        )

        # Store decoder feature channels for feature extraction
        self.decoder_channels = decoder_channels

        print("UNet-Lite decoder initialized")
        print(f"Encoder channels: {encoder_channels}")
        print(f"Decoder channels: {decoder_channels}")

    def forward(self, encoder_features: List[torch.Tensor]) -> tuple:
        """
        Forward pass

        Args:
            encoder_features: Encoder features [stage0, ..., stageN], in the
                same order as encoder_channels; assumed 4D (B, C, H, W).

        Returns:
            output: Raw (pre-activation) scores with output_channels channels
                at twice the spatial size of encoder_features[0]; no sigmoid
                is applied here.
            decoder_features: [initial_conv output, block 1 output, ...],
                deepest first, for hybrid feature extraction.
        """
        # Reverse for bottom-up decoding
        features = encoder_features[::-1]

        x = self.initial_conv(features[0])
        decoder_features = [x]

        # Decoder blocks with skip connections
        for i, block in enumerate(self.decoder_blocks):
            x = block(x, features[i + 1])
            decoder_features.append(x)

        # The shallowest encoder feature is assumed to be at 1/2 input
        # resolution (design assumption, not checked here). Scale H and W
        # independently so non-square inputs keep their aspect ratio.
        height = encoder_features[0].shape[2] * 2
        width = encoder_features[0].shape[3] * 2
        x = F.interpolate(x, size=(height, width), mode='bilinear', align_corners=False)
        output = self.final_upsample(x)

        return output, decoder_features
def get_decoder(encoder_channels: List[int], config) -> UNetLiteDecoder:
    """
    Build a UNetLiteDecoder from a configuration object.

    Only the number of output channels is taken from config (key
    'model.output_channels', falling back to 1). Decoder channel widths use
    the UNetLiteDecoder defaults.

    Args:
        encoder_channels: Encoder feature channels [stage0, ..., stageN]
        config: Configuration object exposing get(key, default)

    Returns:
        Decoder instance
    """
    n_out = config.get('model.output_channels', 1)
    decoder = UNetLiteDecoder(encoder_channels, output_channels=n_out)
    return decoder