"""
UNet-Lite decoder for forgery localization.

Lightweight decoder built from encoder skip connections and depthwise
separable convolutions.
"""
from typing import List, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
class DepthwiseSeparableConv(nn.Module):
    """Depthwise separable convolution followed by BatchNorm and ReLU.

    A per-channel (``groups=in_channels``) convolution is followed by a 1x1
    pointwise convolution that mixes channels, then ``BatchNorm2d`` and an
    in-place ``ReLU``. Neither convolution carries a bias term.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        kernel_size: Spatial size of the depthwise kernel. Padding is
            ``kernel_size // 2``, so height/width are preserved for odd sizes.
    """

    def __init__(self, in_channels: int, out_channels: int, kernel_size: int = 3):
        super().__init__()
        pad = kernel_size // 2
        # Registration order (depthwise, pointwise, bn, relu) fixes the
        # state_dict layout and parameter-initialisation order.
        self.depthwise = nn.Conv2d(
            in_channels,
            in_channels,
            kernel_size=kernel_size,
            padding=pad,
            groups=in_channels,
            bias=False,
        )
        self.pointwise = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False)
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply depthwise conv -> pointwise conv -> BatchNorm -> ReLU."""
        return self.relu(self.bn(self.pointwise(self.depthwise(x))))
class DecoderBlock(nn.Module):
    """One decoder stage: resize to the skip's resolution, concatenate, refine.

    Args:
        in_channels: Channels of the tensor coming from the previous (deeper)
            decoder stage.
        skip_channels: Channels of the encoder skip-connection tensor.
        out_channels: Channels produced by this block.
    """

    def __init__(self, in_channels: int, skip_channels: int, out_channels: int):
        super().__init__()
        # conv1 consumes the resized input stacked with the skip features
        # along the channel axis; conv2 refines at the output width.
        self.conv1 = DepthwiseSeparableConv(in_channels + skip_channels, out_channels)
        self.conv2 = DepthwiseSeparableConv(out_channels, out_channels)

    def forward(self, x: torch.Tensor, skip: torch.Tensor) -> torch.Tensor:
        """Fuse ``x`` with ``skip`` and return the refined features.

        Args:
            x: Features from the previous decoder stage.
            skip: Encoder skip-connection features; their spatial size sets
                the size ``x`` is bilinearly resized to.

        Returns:
            Tensor with ``out_channels`` channels at ``skip``'s spatial size.
        """
        resized = F.interpolate(x, size=skip.shape[2:], mode='bilinear', align_corners=False)
        fused = torch.cat((resized, skip), dim=1)
        return self.conv2(self.conv1(fused))
class UNetLiteDecoder(nn.Module):
    """
    UNet-Lite decoder for forgery localization

    Features:
    - Skip connections from encoder stages
    - Bilinear upsampling
    - Depthwise separable convolutions for efficiency

    Attributes:
        decoder_channels: The decoder channel list as configured (after the
            default is applied).
        stage_channels: Channel count of each tensor returned in
            ``decoder_features`` by ``forward`` (one entry per encoder stage).
    """

    def __init__(self,
                 encoder_channels: List[int],
                 decoder_channels: Optional[List[int]] = None,
                 output_channels: int = 1):
        """
        Initialize decoder

        Args:
            encoder_channels: Encoder feature channels ordered shallow to deep,
                [stage0, ..., stageN].
            decoder_channels: Output channels of the initial conv followed by
                each decoder block, deepest stage first. Defaults to
                [256, 128, 64, 32, 16]. If the encoder has more stages than
                this list has entries, the last entry is reused; surplus
                entries are ignored.
            output_channels: Number of output channels (1 for binary mask)

        Raises:
            ValueError: If ``encoder_channels`` or ``decoder_channels`` is empty.
        """
        super().__init__()
        # Reverse encoder channels for decoder (bottom to top)
        bottom_up = list(encoder_channels)[::-1]
        if not bottom_up:
            raise ValueError("encoder_channels must contain at least one stage")
        # Default decoder channels if not provided
        if decoder_channels is None:
            decoder_channels = [256, 128, 64, 32, 16]
        if len(decoder_channels) == 0:
            raise ValueError("decoder_channels must contain at least one entry")

        num_stages = len(bottom_up)
        last_idx = len(decoder_channels) - 1
        # Resolve every stage's width up front so block inputs, block outputs
        # and the output head agree even when the encoder depth differs from
        # len(decoder_channels). Previously, a shallower encoder broke the
        # output head and a much deeper one raised IndexError.
        stage_channels = [decoder_channels[min(i, last_idx)] for i in range(num_stages)]

        # Initial convolution from deepest encoder features
        self.initial_conv = DepthwiseSeparableConv(bottom_up[0], stage_channels[0])

        # Block i upsamples decoder stage i and fuses encoder skip i + 1
        self.decoder_blocks = nn.ModuleList(
            DecoderBlock(stage_channels[i], bottom_up[i + 1], stage_channels[i + 1])
            for i in range(num_stages - 1)
        )

        # Output head applied after forward() resizes to full resolution; its
        # input width is that of the last decoder stage actually built.
        self.final_upsample = nn.Sequential(
            DepthwiseSeparableConv(stage_channels[-1], stage_channels[-1]),
            nn.Conv2d(stage_channels[-1], output_channels, kernel_size=1),
        )

        # Store decoder feature channels for feature extraction
        self.decoder_channels = decoder_channels
        self.stage_channels = stage_channels

        print("UNet-Lite decoder initialized")
        print(f"Encoder channels: {encoder_channels}")
        print(f"Decoder channels: {decoder_channels}")

    def forward(self, encoder_features: List[torch.Tensor]) -> Tuple[torch.Tensor, List[torch.Tensor]]:
        """
        Forward pass

        Args:
            encoder_features: Encoder feature maps ordered shallow to deep,
                [stage0, ..., stageN], each (B, C, H, W).

        Returns:
            output: Unnormalised scores of shape
                (B, output_channels, 2 * H0, 2 * W0), where (H0, W0) is the
                spatial size of ``encoder_features[0]``. No sigmoid/softmax
                is applied here.
            decoder_features: Decoder feature maps (initial conv output, then
                each block's output) for hybrid feature extraction.
        """
        # Reverse for bottom-up decoding
        features = encoder_features[::-1]

        x = self.initial_conv(features[0])
        decoder_features = [x]

        # Decoder blocks with skip connections
        for i, block in enumerate(self.decoder_blocks):
            x = block(x, features[i + 1])
            decoder_features.append(x)

        # stage0 is assumed to be at 1/2 input scale. Height and width are
        # doubled independently, so non-square inputs keep their aspect ratio.
        height, width = encoder_features[0].shape[-2:]
        x = F.interpolate(x, size=(height * 2, width * 2), mode='bilinear', align_corners=False)
        output = self.final_upsample(x)

        return output, decoder_features
def get_decoder(encoder_channels: List[int], config) -> UNetLiteDecoder:
    """
    Build a UNet-Lite decoder for the given encoder.

    Args:
        encoder_channels: Encoder feature channels, ordered shallow to deep.
        config: Configuration object. Only
            ``config.get('model.output_channels', 1)`` is read from it; the
            decoder channel widths fall back to the decoder's own default.

    Returns:
        A ``UNetLiteDecoder`` instance.
    """
    return UNetLiteDecoder(
        encoder_channels,
        output_channels=config.get('model.output_channels', 1),
    )
|