FaceInsight_AI / src /models /face_model.py
vaisagan's picture
Upload src/models/face_model.py with huggingface_hub
7694c24 verified
"""
Multi-task face model: MobileNetV2 backbone → gender head + age head.
gender : CrossEntropyLoss (2-class)
age : SmoothL1Loss (regression, label normalised 0-1)
"""
from __future__ import annotations
from typing import Tuple
import torch
import torch.nn as nn
from torchvision import models
from torchvision.models import MobileNet_V2_Weights
class FaceModel(nn.Module):
    """Multi-task face network: MobileNetV2 features feeding two heads.

    The convolutional feature extractor of a torchvision MobileNetV2 is
    followed by global average pooling and a shared fully connected trunk.
    Two small heads branch off the trunk:

    * ``gender_head`` emits 2 raw logits per sample.
    * ``age_head`` emits one sigmoid-squashed scalar per sample, i.e. a
      value in [0, 1].

    Args:
        pretrained: When true, initialise the backbone with
            ``MobileNet_V2_Weights.IMAGENET1K_V1``; otherwise no weights
            are requested from torchvision.
        dropout: Dropout probability applied at the end of the shared trunk.
    """

    def __init__(self, pretrained: bool = True, dropout: float = 0.3) -> None:
        super().__init__()
        weights = MobileNet_V2_Weights.IMAGENET1K_V1 if pretrained else None
        backbone = models.mobilenet_v2(weights=weights)

        # Feature extractor (all layers except the final classifier).
        self.features = backbone.features
        # Global average pooling; the flatten step lives in ``self.shared``.
        self.pool = nn.AdaptiveAvgPool2d(1)

        hidden = 512
        self.shared = nn.Sequential(
            nn.Flatten(),
            # Ask the backbone for its feature width instead of hard-coding
            # it (1280 for the default MobileNetV2 configuration).
            nn.Linear(backbone.last_channel, hidden),
            nn.BatchNorm1d(hidden),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout),
        )

        # Gender head: two-class logits.
        self.gender_head = nn.Sequential(
            nn.Linear(hidden, 128),
            nn.ReLU(inplace=True),
            nn.Linear(128, 2),
        )

        # Age head: scalar regression.
        self.age_head = nn.Sequential(
            nn.Linear(hidden, 128),
            nn.ReLU(inplace=True),
            nn.Linear(128, 1),
            nn.Sigmoid(),  # output in [0, 1] matching normalised labels
        )

    def forward(
        self, x: torch.Tensor
    ) -> "Tuple[torch.Tensor, torch.Tensor]":
        """Run the backbone, the shared trunk and both heads.

        Args:
            x: Input batch for the MobileNetV2 feature extractor
                (presumably ``(N, 3, H, W)`` images - confirm against the
                data pipeline).

        Returns:
            ``(gender_logits, age_pred)`` where ``gender_logits`` has shape
            ``(N, 2)`` and ``age_pred`` has shape ``(N,)`` with values in
            [0, 1].
        """
        x = self.features(x)
        x = self.pool(x)
        x = self.shared(x)
        gender_logits = self.gender_head(x)
        # The age head yields (N, 1); drop the trailing axis.
        age_pred = self.age_head(x).squeeze(1)
        return gender_logits, age_pred

    def _backbone_is_frozen(self) -> bool:
        """Return True when the backbone has parameters and none need grad."""
        params = list(self.features.parameters())
        return bool(params) and not any(p.requires_grad for p in params)

    def train(self, mode: bool = True) -> "FaceModel":
        """Set the training mode, keeping a frozen backbone in eval mode.

        Clearing ``requires_grad`` does not stop BatchNorm layers from
        updating their running statistics while in training mode, so a
        frozen backbone would otherwise still change during warm-up.
        """
        super().train(mode)
        if mode and self._backbone_is_frozen():
            self.features.eval()
        return self

    def freeze_backbone(self) -> None:
        """Stop gradient updates and BatchNorm statistic updates in the backbone."""
        for p in self.features.parameters():
            p.requires_grad = False
        # Modules start in training mode, so switch immediately rather than
        # waiting for the next ``train()`` call.
        self.features.eval()

    def unfreeze_backbone(self) -> None:
        """Make the backbone trainable again and resync its train/eval mode."""
        for p in self.features.parameters():
            p.requires_grad = True
        self.features.train(self.training)
def build_model(cfg, device: torch.device) -> FaceModel:
    """Create a pretrained ``FaceModel`` ready for the warm-up phase.

    Args:
        cfg: Not read by this function; the hyper-parameters below are
            fixed in the body.
        device: Device the model is moved to before being returned.

    Returns:
        A ``FaceModel`` built with ``pretrained=True`` and ``dropout=0.3``,
        with ``freeze_backbone()`` already applied, located on ``device``.
    """
    net = FaceModel(pretrained=True, dropout=0.3)
    # Warm-up phase: only the layers outside the backbone are trained.
    net.freeze_backbone()
    net = net.to(device)
    return net
def load_model(path: str, device: torch.device) -> FaceModel:
    """Restore a ``FaceModel`` from a checkpoint file, ready for inference.

    The checkpoint may be either a dict that stores the weights under the
    ``"model_state_dict"`` key or a bare ``state_dict``.

    Args:
        path: Location of the checkpoint file.
        device: Device the tensors are mapped to and the model is moved to.

    Returns:
        The model on ``device``, switched to evaluation mode.
    """
    # No ImageNet initialisation: the checkpoint supplies the weights.
    model = FaceModel(pretrained=False)

    # SECURITY: torch.load is pickle-based and, depending on the installed
    # torch version and its ``weights_only`` default, can execute arbitrary
    # code from the file. Only load checkpoints from trusted sources.
    checkpoint = torch.load(path, map_location=device)

    # Accept both the wrapped form and a bare state_dict.
    if isinstance(checkpoint, dict) and "model_state_dict" in checkpoint:
        state_dict = checkpoint["model_state_dict"]
    else:
        state_dict = checkpoint

    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()
    return model