# NOTE: scraped page header, not part of the module --
# "Spaces: Sleeping", "File size: 2,116 Bytes", revision ff0e79e.
"""
MobileNetV3-Small Encoder for forgery localization
ImageNet pretrained, feature extraction mode
"""
import torch
import torch.nn as nn
import timm
from typing import List
class MobileNetV3Encoder(nn.Module):
    """
    MobileNetV3-Small encoder for document forgery detection.

    Wraps a timm ``mobilenetv3_small_100`` model created with
    ``features_only=True`` and ``out_indices=(0, 1, 2, 3, 4)``, so calling
    the backbone yields one feature map per requested stage.

    Chosen for:
    - Stroke-level and texture preservation
    - Robustness to compression and blur
    - Edge and CPU deployment efficiency

    Attributes set in ``__init__``:
    - ``backbone``: the timm feature-extraction model.
    - ``feature_channels``: per-stage channel counts as reported by
      ``backbone.feature_info.channels()``.
    """

    def __init__(self, pretrained: bool = True):
        """
        Initialize encoder.

        Args:
            pretrained: Whether to use ImageNet pretrained weights; forwarded
                unchanged to ``timm.create_model``.
        """
        super().__init__()

        # Load MobileNetV3-Small in feature-extraction mode.
        self.backbone = timm.create_model(
            'mobilenetv3_small_100',
            pretrained=pretrained,
            features_only=True,
            out_indices=(0, 1, 2, 3, 4),  # All feature stages
        )

        # Channel count of every returned stage, queried from timm rather
        # than hard-coded.
        # Expected for MobileNetV3-Small: [16, 16, 24, 48, 576]
        # (value noted by the original author -- confirm against the
        # installed timm version).
        self.feature_channels = self.backbone.feature_info.channels()

        # Plain string: there is nothing to interpolate in this message.
        print("MobileNetV3-Small encoder initialized")
        print(f"Feature channels: {self.feature_channels}")

    def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
        """
        Extract multi-scale features.

        Args:
            x: Input tensor (B, 3, H, W)

        Returns:
            Whatever the backbone returns for ``x``: the feature tensors at
            the different scales, one per configured output index.
        """
        return self.backbone(x)

    def get_feature_channels(self) -> List[int]:
        """
        Get feature channel dimensions for each stage.

        Returns:
            A new list holding the per-stage channel counts. A copy is
            returned so that callers editing the result cannot alter the
            encoder's own ``feature_channels``.
        """
        return list(self.feature_channels)
def get_encoder(config) -> MobileNetV3Encoder:
    """
    Factory function that builds the encoder from a configuration.

    Args:
        config: Configuration object. It must provide
            ``get(key, default)``; the ``'model.encoder.pretrained'`` key is
            looked up with ``True`` supplied as the default.

    Returns:
        A newly constructed ``MobileNetV3Encoder`` instance.
    """
    use_pretrained = config.get('model.encoder.pretrained', True)
    encoder = MobileNetV3Encoder(pretrained=use_pretrained)
    return encoder
|