Spaces:
Running
Running
| """ | |
| Multi-task face model: MobileNetV2 backbone → gender head + age head. | |
| gender : CrossEntropyLoss (2-class) | |
| age : SmoothL1Loss (regression, label normalised 0-1) | |
| """ | |
| from __future__ import annotations | |
| from typing import Tuple | |
| import torch | |
| import torch.nn as nn | |
| from torchvision import models | |
| from torchvision.models import MobileNet_V2_Weights | |
class FaceModel(nn.Module):
    """Multi-task face network: MobileNetV2 features feeding two heads.

    The backbone's convolutional feature extractor is globally average-pooled,
    passed through a shared fully connected block, and then split into a
    2-way gender classifier and a sigmoid-bounded scalar age regressor.

    Args:
        pretrained: Load the ``IMAGENET1K_V1`` backbone weights when true;
            otherwise the backbone is built without pretrained weights.
        dropout: Dropout probability applied at the end of the shared block.
    """

    # Class-level default so the flag is always readable; it is flipped by
    # freeze_backbone() / unfreeze_backbone() and consulted by train().
    _backbone_frozen: bool = False

    def __init__(self, pretrained: bool = True, dropout: float = 0.3) -> None:
        super().__init__()
        weights = MobileNet_V2_Weights.IMAGENET1K_V1 if pretrained else None
        backbone = models.mobilenet_v2(weights=weights)

        # Feature extractor (all layers except the final classifier).
        self.features = backbone.features
        # Global average pooling collapses the spatial map to 1x1.
        self.pool = nn.AdaptiveAvgPool2d(1)

        # Channel count of the backbone's last feature map (1280 for the
        # default MobileNetV2). Read from the model rather than hard-coded so
        # the shared block always matches the backbone it is attached to.
        feat_dim = backbone.last_channel
        hidden = 512
        self.shared = nn.Sequential(
            nn.Flatten(),
            nn.Linear(feat_dim, hidden),
            nn.BatchNorm1d(hidden),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout),
        )

        # Gender head: two logits (binary classification).
        self.gender_head = nn.Sequential(
            nn.Linear(hidden, 128),
            nn.ReLU(inplace=True),
            nn.Linear(128, 2),
        )

        # Age head: scalar regression.
        self.age_head = nn.Sequential(
            nn.Linear(hidden, 128),
            nn.ReLU(inplace=True),
            nn.Linear(128, 1),
            nn.Sigmoid(),  # output in [0, 1] matching normalised labels
        )

    def forward(
        self, x: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run the backbone, the shared block and both heads.

        Args:
            x: Image batch accepted by the MobileNetV2 feature extractor.

        Returns:
            ``(gender_logits, age_pred)``: the gender head's two logits per
            sample, and the age head's sigmoid output with its trailing
            size-1 dimension squeezed away.
        """
        x = self.features(x)
        x = self.pool(x)
        x = self.shared(x)
        gender_logits = self.gender_head(x)
        age_pred = self.age_head(x).squeeze(1)
        return gender_logits, age_pred

    def train(self, mode: bool = True) -> "FaceModel":
        """Set the training mode, keeping a frozen backbone in eval mode.

        ``requires_grad = False`` stops weight updates but does not stop
        BatchNorm layers from updating their running statistics during
        training-mode forward passes. While the backbone is frozen it is
        therefore held in eval mode regardless of ``mode``.
        """
        super().train(mode)
        if self._backbone_frozen:
            self.features.eval()
        return self

    def freeze_backbone(self) -> None:
        """Freeze the backbone: no gradients and no BatchNorm stat updates."""
        for p in self.features.parameters():
            p.requires_grad = False
        self._backbone_frozen = True
        # Take effect immediately, not only on the next train() call.
        self.features.eval()

    def unfreeze_backbone(self) -> None:
        """Make the backbone trainable again and resync its train/eval mode."""
        for p in self.features.parameters():
            p.requires_grad = True
        self._backbone_frozen = False
        # Follow whatever mode the rest of the model is currently in.
        self.features.train(self.training)
def build_model(cfg, device: torch.device) -> FaceModel:
    """Create a pretrained FaceModel prepared for the warm-up phase.

    Args:
        cfg: Accepted for interface compatibility; not read by this function.
        device: Device the model is moved to before being returned.

    Returns:
        A ``FaceModel`` built with pretrained weights and dropout 0.3, after
        ``freeze_backbone()`` has been called on it, placed on ``device``.
    """
    net = FaceModel(dropout=0.3, pretrained=True)
    # Warm-up phase: train heads only.
    net.freeze_backbone()
    net = net.to(device)
    return net
def load_model(path: str, device: torch.device) -> FaceModel:
    """Restore a FaceModel from a checkpoint file, ready for inference.

    The checkpoint must be a mapping with a ``"model_state_dict"`` entry.

    NOTE(review): depending on the installed torch version's ``weights_only``
    default, ``torch.load`` may unpickle arbitrary objects; only load
    checkpoints from trusted sources.

    Args:
        path: Location of the checkpoint file.
        device: Device used both for ``map_location`` and for the model.

    Returns:
        The model with the stored weights loaded, on ``device``, in eval mode.
    """
    # Pretrained weights are skipped; the checkpoint supplies the weights.
    net = FaceModel(pretrained=False)
    checkpoint = torch.load(path, map_location=device)
    weights = checkpoint["model_state_dict"]
    net.load_state_dict(weights)
    return net.to(device).eval()