| """
|
| Complete Forgery Localization Network
|
| MobileNetV3-Small Encoder + UNet-Lite Decoder
|
| """
|
|
|
| import torch
|
| import torch.nn as nn
|
| from typing import Tuple, List, Optional
|
|
|
| from .encoder import MobileNetV3Encoder
|
| from .decoder import UNetLiteDecoder
|
|
|
|
|
class ForgeryLocalizationNetwork(nn.Module):
    """
    Complete network for forgery localization.

    Architecture:
    - Encoder: MobileNetV3-Small (``MobileNetV3Encoder``), optionally loaded
      with pretrained weights (documented as ImageNet; confirm in the encoder
      module).
    - Decoder: UNet-Lite (``UNetLiteDecoder``) with skip connections, built
      from the channel sizes reported by the encoder.
    - Output: raw per-pixel forgery *logits* with ``model.output_channels``
      channels (default 1). Use :meth:`get_probability_map` for sigmoid
      probabilities or :meth:`predict` for a thresholded mask.

    NOTE(review): the inference helpers only disable autograd; they do NOT
    switch the module to ``eval()`` mode. If the encoder/decoder contain
    BatchNorm or Dropout layers, call ``model.eval()`` first, otherwise
    training-mode behaviour still applies.
    """

    def __init__(self, config):
        """
        Initialize the network.

        Args:
            config: Configuration object queried via ``config.get(key, default)``
                with dotted keys. Reads ``model.encoder.pretrained``
                (default True) and ``model.output_channels`` (default 1).
        """
        super().__init__()

        self.config = config

        # Encoder backbone; ``pretrained`` is forwarded unchanged.
        pretrained = config.get('model.encoder.pretrained', True)
        self.encoder = MobileNetV3Encoder(pretrained=pretrained)

        # The decoder is sized from the encoder's reported feature channels,
        # presumably so the skip connections match the encoder outputs.
        encoder_channels = self.encoder.get_feature_channels()
        output_channels = config.get('model.output_channels', 1)
        self.decoder = UNetLiteDecoder(
            encoder_channels=encoder_channels,
            output_channels=output_channels,
        )

        print("ForgeryLocalizationNetwork initialized")
        print(f"Total parameters: {self.count_parameters():,}")

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, List[torch.Tensor]]:
        """
        Run the encoder followed by the decoder.

        Args:
            x: Input image tensor, documented as (B, 3, H, W).

        Returns:
            output: Raw forgery logits produced by the decoder (not
                probabilities); documented as (B, 1, H, W) for the default
                single output channel.
            decoder_features: Intermediate decoder features, returned for
                hybrid feature extraction.
        """
        encoder_features = self.encoder(x)
        output, decoder_features = self.decoder(encoder_features)
        return output, decoder_features

    def _forward_no_grad(
        self, x: torch.Tensor
    ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
        """Run the full model with autograd disabled.

        Calls the module instance (``self(x)``) rather than ``self.forward(x)``
        so that registered forward hooks / pre-hooks are executed, as PyTorch
        recommends. Does not change train/eval mode.
        """
        with torch.no_grad():
            return self(x)

    def predict(self, x: torch.Tensor, threshold: float = 0.5) -> torch.Tensor:
        """
        Predict a binary mask.

        Args:
            x: Input image tensor, documented as (B, 3, H, W).
            threshold: Probability threshold for binarization. The comparison
                is strict, so a probability exactly equal to ``threshold``
                maps to 0.

        Returns:
            Float tensor of 0.0/1.0 values with the same shape as the logits
            returned by :meth:`forward`.
        """
        probs = self.get_probability_map(x)
        return (probs > threshold).float()

    def get_probability_map(self, x: torch.Tensor) -> torch.Tensor:
        """
        Get the sigmoid probability map, computed without autograd.

        Args:
            x: Input image tensor, documented as (B, 3, H, W).

        Returns:
            ``torch.sigmoid`` of the logits from :meth:`forward` (same shape).
        """
        logits, _ = self._forward_no_grad(x)
        # logits were produced under no_grad, so the result does not track
        # gradients either.
        return torch.sigmoid(logits)

    def count_parameters(self) -> int:
        """Count the elements of all parameters with ``requires_grad=True``."""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)

    def get_decoder_features(self, x: torch.Tensor) -> List[torch.Tensor]:
        """
        Get decoder features for hybrid feature extraction, without autograd.

        Args:
            x: Input image tensor, documented as (B, 3, H, W).

        Returns:
            The ``decoder_features`` list returned by :meth:`forward`.
        """
        _, decoder_features = self._forward_no_grad(x)
        return decoder_features
|
|
|
|
|
def get_model(config) -> ForgeryLocalizationNetwork:
    """Build and return a ``ForgeryLocalizationNetwork``.

    Args:
        config: Configuration object, handed unchanged to the network
            constructor.

    Returns:
        A newly constructed ``ForgeryLocalizationNetwork`` instance.
    """
    network = ForgeryLocalizationNetwork(config)
    return network
|
|
|