File size: 3,818 Bytes
ff0e79e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
"""

Complete Forgery Localization Network

MobileNetV3-Small Encoder + UNet-Lite Decoder

"""

import torch
import torch.nn as nn
from typing import Tuple, List, Optional

from .encoder import MobileNetV3Encoder
from .decoder import UNetLiteDecoder


class ForgeryLocalizationNetwork(nn.Module):
    """

    Complete network for forgery localization

    

    Architecture:

    - Encoder: MobileNetV3-Small (ImageNet pretrained)

    - Decoder: UNet-Lite with skip connections

    - Output: Single-channel forgery probability map

    """
    
    def __init__(self, config):
        """

        Initialize network

        

        Args:

            config: Configuration object

        """
        super().__init__()
        
        self.config = config
        
        # Initialize encoder
        pretrained = config.get('model.encoder.pretrained', True)
        self.encoder = MobileNetV3Encoder(pretrained=pretrained)
        
        # Initialize decoder
        encoder_channels = self.encoder.get_feature_channels()
        output_channels = config.get('model.output_channels', 1)
        self.decoder = UNetLiteDecoder(
            encoder_channels=encoder_channels,
            output_channels=output_channels
        )
        
        print(f"ForgeryLocalizationNetwork initialized")
        print(f"Total parameters: {self.count_parameters():,}")
    
    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, List[torch.Tensor]]:
        """

        Forward pass

        

        Args:

            x: Input image tensor (B, 3, H, W)

        

        Returns:

            output: Forgery probability map (B, 1, H, W) - logits

            decoder_features: Decoder features for hybrid feature extraction

        """
        # Encode
        encoder_features = self.encoder(x)
        
        # Decode
        output, decoder_features = self.decoder(encoder_features)
        
        return output, decoder_features
    
    def predict(self, x: torch.Tensor, threshold: float = 0.5) -> torch.Tensor:
        """

        Predict binary mask

        

        Args:

            x: Input image tensor (B, 3, H, W)

            threshold: Probability threshold for binarization

        

        Returns:

            Binary mask (B, 1, H, W)

        """
        with torch.no_grad():
            logits, _ = self.forward(x)
            probs = torch.sigmoid(logits)
            mask = (probs > threshold).float()
        
        return mask
    
    def get_probability_map(self, x: torch.Tensor) -> torch.Tensor:
        """

        Get probability map

        

        Args:

            x: Input image tensor (B, 3, H, W)

        

        Returns:

            Probability map (B, 1, H, W)

        """
        with torch.no_grad():
            logits, _ = self.forward(x)
            probs = torch.sigmoid(logits)
        
        return probs
    
    def count_parameters(self) -> int:
        """Count total trainable parameters"""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)
    
    def get_decoder_features(self, x: torch.Tensor) -> List[torch.Tensor]:
        """

        Get decoder features for hybrid feature extraction

        

        Args:

            x: Input image tensor (B, 3, H, W)

        

        Returns:

            List of decoder features

        """
        with torch.no_grad():
            _, decoder_features = self.forward(x)
        
        return decoder_features


def get_model(config) -> ForgeryLocalizationNetwork:
    """Build a ForgeryLocalizationNetwork from a configuration object.

    Args:
        config: Configuration object, forwarded unchanged to the
            ForgeryLocalizationNetwork constructor.

    Returns:
        A newly constructed ForgeryLocalizationNetwork instance.
    """
    model = ForgeryLocalizationNetwork(config)
    return model