Spaces:
Sleeping
Sleeping
| """ | |
| UNet-Lite Decoder for forgery localization | |
| Lightweight decoder with skip connections, depthwise separable convolutions | |
| """ | |
from typing import List, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
class DepthwiseSeparableConv(nn.Module):
    """Depthwise convolution followed by a 1x1 pointwise projection.

    The block applies, in order: a per-channel (grouped) spatial convolution,
    a 1x1 convolution that mixes channels, batch normalization and an
    in-place ReLU. Neither convolution has a bias term.
    """

    def __init__(self, in_channels: int, out_channels: int, kernel_size: int = 3):
        """
        Build the layers.

        Args:
            in_channels: Number of channels of the incoming feature map
            out_channels: Number of channels produced by the pointwise conv
            kernel_size: Spatial kernel size of the depthwise conv
        """
        super().__init__()

        # With stride 1, padding of kernel_size // 2 keeps the spatial size
        # unchanged for odd kernel sizes.
        same_padding = kernel_size // 2

        # groups == in_channels gives each input channel its own spatial filter.
        self.depthwise = nn.Conv2d(
            in_channels,
            in_channels,
            kernel_size=kernel_size,
            padding=same_padding,
            groups=in_channels,
            bias=False,
        )
        # The 1x1 conv is the only place where channels are mixed.
        self.pointwise = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False)
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run depthwise -> pointwise -> batch norm -> ReLU on ``x``."""
        spatially_filtered = self.depthwise(x)
        projected = self.pointwise(spatially_filtered)
        return self.relu(self.bn(projected))
class DecoderBlock(nn.Module):
    """One decoder stage: upsample, fuse with an encoder skip, then refine."""

    def __init__(self, in_channels: int, skip_channels: int, out_channels: int):
        """
        Build the two refinement convolutions.

        Args:
            in_channels: Channels arriving from the previous decoder stage
            skip_channels: Channels of the encoder skip connection
            out_channels: Channels produced by this block
        """
        super().__init__()

        # conv1 consumes the channel-wise concatenation of both inputs,
        # so its input width is the sum of the two.
        self.conv1 = DepthwiseSeparableConv(in_channels + skip_channels, out_channels)
        self.conv2 = DepthwiseSeparableConv(out_channels, out_channels)

    def forward(self, x: torch.Tensor, skip: torch.Tensor) -> torch.Tensor:
        """
        Decode one stage.

        Args:
            x: Features from the previous decoder stage
            skip: Encoder features to fuse in; their spatial size sets the
                output resolution

        Returns:
            Refined features at the spatial size of ``skip``
        """
        # Resize to the skip tensor's spatial dims so the two can be concatenated.
        upsampled = F.interpolate(x, size=skip.shape[2:], mode='bilinear', align_corners=False)

        # Stack along the channel axis, decoder features first.
        fused = torch.cat([upsampled, skip], dim=1)

        return self.conv2(self.conv1(fused))
class UNetLiteDecoder(nn.Module):
    """
    UNet-Lite decoder for forgery localization

    Features:
    - Skip connections from encoder stages
    - Bilinear upsampling
    - Depthwise separable convolutions for efficiency

    Attributes:
        initial_conv: Conv applied to the deepest encoder feature map
        decoder_blocks: One DecoderBlock per remaining encoder stage,
            ordered from deepest to shallowest
        final_upsample: Output head (separable conv + 1x1 conv)
        decoder_channels: Channel width of each entry of the
            ``decoder_features`` list returned by ``forward``
    """

    def __init__(self,
                 encoder_channels: List[int],
                 decoder_channels: Optional[List[int]] = None,
                 output_channels: int = 1):
        """
        Initialize decoder

        Args:
            encoder_channels: List of encoder feature channels [stage0, ..., stageN]
            decoder_channels: Decoder output channels, deepest stage first.
                Defaults to [256, 128, 64, 32, 16]. Exactly one width is used
                per encoder stage: surplus entries are ignored and a list that
                is too short is padded with its last value.
            output_channels: Number of output channels (1 for binary mask)

        Raises:
            ValueError: If ``encoder_channels`` or ``decoder_channels`` is empty
        """
        super().__init__()

        if len(encoder_channels) == 0:
            raise ValueError("encoder_channels must contain at least one stage")

        # Default decoder channels if not provided
        if decoder_channels is None:
            decoder_channels = [256, 128, 64, 32, 16]
        if len(decoder_channels) == 0:
            raise ValueError("decoder_channels must not be empty")

        # Resolve one width per decoder stage. Sizing every layer from this
        # list (instead of indexing decoder_channels directly) keeps the
        # blocks and the output head consistent when the two lists differ
        # in length.
        num_stages = len(encoder_channels)
        stage_channels = list(decoder_channels[:num_stages])
        stage_channels.extend(
            [decoder_channels[-1]] * (num_stages - len(stage_channels))
        )

        # Reverse encoder channels for decoder (bottom to top)
        deepest_first = encoder_channels[::-1]

        # Initial convolution from deepest encoder features
        self.initial_conv = DepthwiseSeparableConv(deepest_first[0], stage_channels[0])

        # Decoder blocks: block i takes stage i's output, fuses the next
        # shallower encoder feature and emits stage i + 1's width.
        self.decoder_blocks = nn.ModuleList([
            DecoderBlock(in_ch, skip_ch, out_ch)
            for in_ch, skip_ch, out_ch in zip(
                stage_channels, deepest_first[1:], stage_channels[1:]
            )
        ])

        # Output head, sized from the width the last stage actually produces.
        head_channels = stage_channels[-1]
        self.final_upsample = nn.Sequential(
            DepthwiseSeparableConv(head_channels, head_channels),
            nn.Conv2d(head_channels, output_channels, kernel_size=1)
        )

        # Store decoder feature channels for feature extraction
        self.decoder_channels = stage_channels

        print("UNet-Lite decoder initialized")
        print(f"Encoder channels: {encoder_channels}")
        print(f"Decoder channels: {stage_channels}")

    def forward(self, encoder_features: List[torch.Tensor]) -> Tuple[torch.Tensor, List[torch.Tensor]]:
        """
        Forward pass

        Args:
            encoder_features: List of encoder features [stage0, ..., stageN],
                one tensor per entry of the ``encoder_channels`` given at
                construction

        Returns:
            output: Raw output of the final 1x1 convolution, with
                ``output_channels`` channels and twice the spatial size of
                ``encoder_features[0]``. No activation is applied here.
            decoder_features: List of decoder features for hybrid extraction
                (initial conv output followed by each block's output)

        Raises:
            ValueError: If the number of encoder features does not match the
                number of stages the decoder was built for
        """
        expected_stages = len(self.decoder_blocks) + 1
        if len(encoder_features) != expected_stages:
            raise ValueError(
                f"Expected {expected_stages} encoder features, "
                f"got {len(encoder_features)}"
            )

        # Reverse for bottom-up decoding
        features = encoder_features[::-1]

        # Initial convolution
        x = self.initial_conv(features[0])

        # Store decoder features for hybrid feature extraction
        decoder_features = [x]

        # Decoder blocks with skip connections
        for block, skip in zip(self.decoder_blocks, features[1:]):
            x = block(x, skip)
            decoder_features.append(x)

        # Final upsampling to twice the resolution of the shallowest encoder
        # feature. Height and width are scaled independently so non-square
        # inputs keep their aspect ratio.
        # NOTE(review): assumes stage0 is at 1/2 of the input resolution -
        # confirm against the encoder.
        height, width = encoder_features[0].shape[-2:]
        x = F.interpolate(x, size=(height * 2, width * 2), mode='bilinear', align_corners=False)

        output = self.final_upsample(x)

        return output, decoder_features
def get_decoder(encoder_channels: List[int], config) -> UNetLiteDecoder:
    """
    Build a UNetLiteDecoder from configuration.

    Args:
        encoder_channels: Encoder feature channels, forwarded unchanged
        config: Object exposing ``get(key, default)``; the
            'model.output_channels' entry is read, falling back to 1

    Returns:
        A new UNetLiteDecoder using the default decoder channel widths
    """
    return UNetLiteDecoder(
        encoder_channels,
        output_channels=config.get('model.output_channels', 1),
    )