File size: 2,557 Bytes
7694c24
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
"""
Multi-task face model: MobileNetV2 backbone → gender head + age head.

  gender : CrossEntropyLoss  (2-class)
  age    : SmoothL1Loss      (regression, label normalised 0-1)
"""

from __future__ import annotations

from typing import Tuple
import torch
import torch.nn as nn
from torchvision import models
from torchvision.models import MobileNet_V2_Weights


class FaceModel(nn.Module):
    """Multi-task face model: MobileNetV2 features -> shared MLP -> gender and age heads.

    Args:
        pretrained: If True, initialise the backbone with torchvision's
            ``MobileNet_V2_Weights.IMAGENET1K_V1``; otherwise random init.
        dropout: Dropout probability applied at the end of the shared projection.
        hidden: Width of the shared projection layer (default 512).
        head_hidden: Width of the hidden layer inside each task head (default 128).

    ``forward`` returns ``(gender_logits, age_pred)``:
        gender_logits: ``(batch, 2)`` raw logits (no softmax applied).
        age_pred: ``(batch,)`` values in ``[0, 1]`` produced by a sigmoid.
    """

    def __init__(
        self,
        pretrained: bool = True,
        dropout: float = 0.3,
        hidden: int = 512,
        head_hidden: int = 128,
    ) -> None:
        super().__init__()

        weights = MobileNet_V2_Weights.IMAGENET1K_V1 if pretrained else None
        backbone = models.mobilenet_v2(weights=weights)

        # Feature extractor: every backbone layer except its ImageNet classifier.
        self.features = backbone.features
        # Channel count of the final feature map (1280 for the default
        # width_mult). Read from the backbone rather than hard-coded so the
        # projection always matches the features it receives.
        feat_dim = backbone.last_channel

        # Global average pooling -> (batch, feat_dim, 1, 1); flattened in `shared`.
        self.pool = nn.AdaptiveAvgPool2d(1)

        # Shared projection used by both heads. NOTE: BatchNorm1d requires a
        # batch size > 1 while the module is in training mode.
        self.shared = nn.Sequential(
            nn.Flatten(),
            nn.Linear(feat_dim, hidden),
            nn.BatchNorm1d(hidden),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout),
        )

        # Gender head: two-class logits (pair with CrossEntropyLoss).
        self.gender_head = nn.Sequential(
            nn.Linear(hidden, head_hidden),
            nn.ReLU(inplace=True),
            nn.Linear(head_hidden, 2),
        )

        # Age head: scalar regression.
        self.age_head = nn.Sequential(
            nn.Linear(hidden, head_hidden),
            nn.ReLU(inplace=True),
            nn.Linear(head_hidden, 1),
            nn.Sigmoid(),   # output in [0, 1] matching normalised labels
        )

    def forward(
        self, x: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run both heads on a batch of images.

        Args:
            x: Image batch fed to the MobileNetV2 feature stack; presumably
               ``(batch, 3, H, W)`` normalised like the ImageNet weights
               expect — TODO confirm against the data pipeline.

        Returns:
            ``(gender_logits, age_pred)`` with shapes ``(batch, 2)`` and ``(batch,)``.
        """
        x = self.features(x)
        x = self.pool(x)
        x = self.shared(x)
        gender_logits = self.gender_head(x)
        # Drop the trailing size-1 dim so age_pred is (batch,) like the labels.
        age_pred = self.age_head(x).squeeze(1)
        return gender_logits, age_pred

    def freeze_backbone(self) -> None:
        """Stop gradient updates to the backbone feature extractor.

        Only ``requires_grad`` is toggled. The BatchNorm layers inside
        ``self.features`` still update their running statistics whenever
        the model is in training mode.
        """
        for p in self.features.parameters():
            p.requires_grad = False

    def unfreeze_backbone(self) -> None:
        """Re-enable gradient updates for every backbone parameter."""
        for p in self.features.parameters():
            p.requires_grad = True


def build_model(
    cfg,
    device: torch.device,
    *,
    pretrained: bool = True,
    dropout: float = 0.3,
    freeze_backbone: bool = True,
) -> FaceModel:
    """Create a FaceModel ready for training and move it to ``device``.

    Args:
        cfg: Training configuration. Currently not read by this function;
            kept so existing call sites keep working.
        device: Device to place the model on.
        pretrained: Load ImageNet weights for the backbone (default True).
        dropout: Dropout probability for the shared projection (default 0.3).
        freeze_backbone: If True (default), freeze the backbone so the first
            (warm-up) phase trains only the shared layers and heads.

    Returns:
        The model on ``device``.
    """
    model = FaceModel(pretrained=pretrained, dropout=dropout)
    if freeze_backbone:
        model.freeze_backbone()   # warm-up phase: train heads only
    return model.to(device)


def load_model(path: str, device: torch.device) -> FaceModel:
    """Restore a FaceModel from a checkpoint file and prepare it for inference.

    Two checkpoint layouts are accepted:
        * a training checkpoint dict with the weights under ``"model_state_dict"``;
        * a bare ``state_dict`` (e.g. written by ``torch.save(model.state_dict(), path)``).

    Args:
        path: Filesystem path of the checkpoint.
        device: Device to map tensors to and to place the model on.

    Returns:
        The model on ``device``, in eval mode.

    Raises:
        RuntimeError: From ``load_state_dict`` when the keys or shapes do not
            match the FaceModel architecture.
    """
    # Backbone weights come from the checkpoint, so skip the ImageNet download.
    model = FaceModel(pretrained=False)
    # SECURITY: torch.load unpickles the file; only load trusted checkpoints.
    checkpoint = torch.load(path, map_location=device)
    if isinstance(checkpoint, dict) and "model_state_dict" in checkpoint:
        state_dict = checkpoint["model_state_dict"]
    else:
        state_dict = checkpoint
    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()   # fix BatchNorm statistics and disable dropout
    return model