JKrishnanandhaa's picture
Upload 54 files
ff0e79e verified
raw
history blame
3.82 kB
"""
Complete Forgery Localization Network
MobileNetV3-Small Encoder + UNet-Lite Decoder
"""
import torch
import torch.nn as nn
from typing import Tuple, List, Optional
from .encoder import MobileNetV3Encoder
from .decoder import UNetLiteDecoder
class ForgeryLocalizationNetwork(nn.Module):
    """
    Complete network for forgery localization.

    Architecture:
        - Encoder: MobileNetV3-Small (ImageNet pretrained by default)
        - Decoder: UNet-Lite with skip connections
        - Output: forgery logit map (single channel by default); apply a
          sigmoid to turn it into a probability map
    """

    def __init__(self, config):
        """
        Build the encoder and decoder from the configuration.

        Args:
            config: Configuration object exposing ``get(key, default)``.
                Keys read here: ``'model.encoder.pretrained'`` (default
                ``True``) and ``'model.output_channels'`` (default ``1``).
        """
        super().__init__()
        self.config = config

        # Initialize encoder
        pretrained = config.get('model.encoder.pretrained', True)
        self.encoder = MobileNetV3Encoder(pretrained=pretrained)

        # Initialize decoder; it is sized from the channel counts the
        # encoder reports for its feature maps.
        encoder_channels = self.encoder.get_feature_channels()
        output_channels = config.get('model.output_channels', 1)
        self.decoder = UNetLiteDecoder(
            encoder_channels=encoder_channels,
            output_channels=output_channels
        )

        print("ForgeryLocalizationNetwork initialized")
        # NOTE: the figure below comes from count_parameters(), which only
        # counts parameters with requires_grad=True.
        print(f"Total parameters: {self.count_parameters():,}")

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, List[torch.Tensor]]:
        """
        Forward pass.

        Args:
            x: Input image tensor (B, 3, H, W)

        Returns:
            output: Forgery logits as produced by the decoder (documented
                upstream as (B, 1, H, W)); no sigmoid is applied here
            decoder_features: Decoder features for hybrid feature extraction
        """
        # Encode
        encoder_features = self.encoder(x)

        # Decode
        output, decoder_features = self.decoder(encoder_features)

        return output, decoder_features

    def predict(self, x: torch.Tensor, threshold: float = 0.5) -> torch.Tensor:
        """
        Predict a binary forgery mask.

        Gradients are disabled. The train/eval mode of the module is NOT
        changed; call ``.eval()`` beforehand if the sub-modules contain
        mode-dependent layers.

        Args:
            x: Input image tensor (B, 3, H, W)
            threshold: Probability threshold for binarization. A pixel is
                marked forged only if its probability is strictly greater
                than this value.

        Returns:
            Float mask of 0.0/1.0 values with the same shape as the
            probability map
        """
        probs = self.get_probability_map(x)
        return (probs > threshold).float()

    def get_probability_map(self, x: torch.Tensor) -> torch.Tensor:
        """
        Get the sigmoid-activated probability map.

        Gradients are disabled. The train/eval mode of the module is NOT
        changed.

        Args:
            x: Input image tensor (B, 3, H, W)

        Returns:
            Probability map with the same shape as the logits
        """
        with torch.no_grad():
            # Call the module (not .forward) so registered hooks still run.
            logits, _ = self(x)
            return torch.sigmoid(logits)

    def count_parameters(self) -> int:
        """Count trainable parameters (those with ``requires_grad=True``)."""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)

    def get_decoder_features(self, x: torch.Tensor) -> List[torch.Tensor]:
        """
        Get decoder features for hybrid feature extraction.

        Gradients are disabled. The train/eval mode of the module is NOT
        changed.

        Args:
            x: Input image tensor (B, 3, H, W)

        Returns:
            List of decoder features, exactly as returned by the decoder
        """
        with torch.no_grad():
            # Call the module (not .forward) so registered hooks still run.
            _, decoder_features = self(x)
        return decoder_features
def get_model(config) -> ForgeryLocalizationNetwork:
    """
    Build a forgery localization model from a configuration.

    Args:
        config: Configuration object, passed unchanged to the
            ForgeryLocalizationNetwork constructor

    Returns:
        A newly constructed ForgeryLocalizationNetwork
    """
    model = ForgeryLocalizationNetwork(config)
    return model