""" Multi-task face model: MobileNetV2 backbone → gender head + age head. gender : CrossEntropyLoss (2-class) age : SmoothL1Loss (regression, label normalised 0-1) """ from __future__ import annotations from typing import Tuple import torch import torch.nn as nn from torchvision import models from torchvision.models import MobileNet_V2_Weights class FaceModel(nn.Module): def __init__(self, pretrained: bool = True, dropout: float = 0.3) -> None: super().__init__() weights = MobileNet_V2_Weights.IMAGENET1K_V1 if pretrained else None backbone = models.mobilenet_v2(weights=weights) # Feature extractor (all layers except the final classifier) self.features = backbone.features # Global average pooling + flatten → 1280-dim vector self.pool = nn.AdaptiveAvgPool2d(1) hidden = 512 self.shared = nn.Sequential( nn.Flatten(), nn.Linear(1280, hidden), nn.BatchNorm1d(hidden), nn.ReLU(inplace=True), nn.Dropout(dropout), ) # Gender head: binary self.gender_head = nn.Sequential( nn.Linear(hidden, 128), nn.ReLU(inplace=True), nn.Linear(128, 2), ) # Age head: scalar regression self.age_head = nn.Sequential( nn.Linear(hidden, 128), nn.ReLU(inplace=True), nn.Linear(128, 1), nn.Sigmoid(), # output in [0, 1] matching normalised labels ) def forward( self, x: torch.Tensor ) -> "Tuple[torch.Tensor, torch.Tensor]": x = self.features(x) x = self.pool(x) x = self.shared(x) gender_logits = self.gender_head(x) age_pred = self.age_head(x).squeeze(1) return gender_logits, age_pred def freeze_backbone(self) -> None: for p in self.features.parameters(): p.requires_grad = False def unfreeze_backbone(self) -> None: for p in self.features.parameters(): p.requires_grad = True def build_model(cfg, device: torch.device) -> FaceModel: model = FaceModel(pretrained=True, dropout=0.3) model.freeze_backbone() # warm-up phase: train heads only return model.to(device) def load_model(path: str, device: torch.device) -> FaceModel: model = FaceModel(pretrained=False) state = torch.load(path, map_location=device) model.load_state_dict(state["model_state_dict"]) model.to(device) model.eval() return model