# Extracted from a hosted file viewer (Spaces status: Sleeping; file size: 3,818 bytes; revision ff0e79e).
"""
Complete Forgery Localization Network
MobileNetV3-Small Encoder + UNet-Lite Decoder
"""
import torch
import torch.nn as nn
from typing import Tuple, List, Optional
from .encoder import MobileNetV3Encoder
from .decoder import UNetLiteDecoder
class ForgeryLocalizationNetwork(nn.Module):
    """
    Complete network for forgery localization.

    Architecture:
        - Encoder: MobileNetV3-Small (pretrained weights requested by default)
        - Decoder: UNet-Lite with skip connections
        - Output: Forgery logit map (single channel by default)

    Note:
        The inference helpers (``predict``, ``get_probability_map`` and
        ``get_decoder_features``) disable gradient tracking but do not switch
        the module to evaluation mode. If the encoder/decoder contain
        mode-dependent layers (e.g. batch norm, dropout), call ``eval()``
        on the model before using them.
    """

    def __init__(self, config):
        """
        Build the encoder and decoder from the configuration.

        Args:
            config: Configuration object exposing ``get(key, default)``.
                Keys read here:
                ``'model.encoder.pretrained'`` (default ``True``) and
                ``'model.output_channels'`` (default ``1``).
        """
        super().__init__()
        self.config = config

        # Initialize encoder
        pretrained = config.get('model.encoder.pretrained', True)
        self.encoder = MobileNetV3Encoder(pretrained=pretrained)

        # Initialize decoder; its input widths come from the encoder itself
        encoder_channels = self.encoder.get_feature_channels()
        output_channels = config.get('model.output_channels', 1)
        self.decoder = UNetLiteDecoder(
            encoder_channels=encoder_channels,
            output_channels=output_channels,
        )

        print("ForgeryLocalizationNetwork initialized")
        # NOTE: the reported figure counts trainable parameters only
        # (see count_parameters), despite the "Total" wording.
        print(f"Total parameters: {self.count_parameters():,}")

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, List[torch.Tensor]]:
        """
        Forward pass.

        Args:
            x: Input image tensor (B, 3, H, W)

        Returns:
            output: Forgery map logits (B, 1, H, W) - apply a sigmoid to
                obtain probabilities
            decoder_features: Decoder features for hybrid feature extraction
        """
        # Encode
        encoder_features = self.encoder(x)
        # Decode
        output, decoder_features = self.decoder(encoder_features)
        return output, decoder_features

    def predict(self, x: torch.Tensor, threshold: float = 0.5) -> torch.Tensor:
        """
        Predict a binary mask.

        Args:
            x: Input image tensor (B, 3, H, W)
            threshold: Probability threshold for binarization; a pixel is
                marked only when its probability is strictly greater

        Returns:
            Binary mask (B, 1, H, W) as a float tensor of 0.0 / 1.0 values
        """
        # Reuse the probability helper so both methods share one code path.
        probs = self.get_probability_map(x)
        return (probs > threshold).float()

    def get_probability_map(self, x: torch.Tensor) -> torch.Tensor:
        """
        Get the probability map.

        Args:
            x: Input image tensor (B, 3, H, W)

        Returns:
            Probability map (B, 1, H, W), computed without gradient tracking
        """
        with torch.no_grad():
            # Call the module (not .forward) so registered hooks still run.
            logits, _ = self(x)
            return torch.sigmoid(logits)

    def count_parameters(self) -> int:
        """Count total trainable parameters (those with ``requires_grad``)."""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)

    def get_decoder_features(self, x: torch.Tensor) -> List[torch.Tensor]:
        """
        Get decoder features for hybrid feature extraction.

        Args:
            x: Input image tensor (B, 3, H, W)

        Returns:
            List of decoder features, computed without gradient tracking
        """
        with torch.no_grad():
            # Call the module (not .forward) so registered hooks still run.
            _, decoder_features = self(x)
        return decoder_features
def get_model(config) -> ForgeryLocalizationNetwork:
    """
    Build a forgery localization model from a configuration.

    Args:
        config: Configuration object, forwarded unchanged to the
            ``ForgeryLocalizationNetwork`` constructor.

    Returns:
        A newly constructed ``ForgeryLocalizationNetwork``.
    """
    network = ForgeryLocalizationNetwork(config)
    return network
|