| """ |
| SE-AlexNet: Model Architectures for Emotion Recognition with Squeeze-and-Excitation. |
| |
| This file defines the exact model architectures matching the trained weight files. |
| Each model class mirrors the state_dict key structure of the original .pth files |
| to ensure seamless weight loading. |
| |
| Architecture Summary: |
| - AlexNet : Standard AlexNet, num_classes=11 (pretrained on AffectNet 11-way) |
| - AlexNetWithSE_L1 : SE block after all conv layers (256ch), reduction r ∈ {2,4,8,16,32} |
| - AlexNetWithSE_L2 : SE block between fc6→fc7 (4096dim), reduction=16 (fixed) |
| - AlexNetWithSE_L3 : SELayer inside classifier[0] (256ch), reduction r ∈ {2,4,8,16,32} |
| - VGG16 : Standard VGG16, num_classes=2 (binary Happy/Sad) |
| """ |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
|
|
| |
| |
| |
class AlexNet(nn.Module):
    """
    Plain AlexNet used for facial emotion recognition (AffectNet pretraining).

    Compared with torchvision's AlexNet this variant:
      - has no Dropout in front of fc6; its single Dropout sits between fc6
        and fc7,
      - has no AdaptiveAvgPool2d, so the conv stack must already produce
        256x6x6 maps (224x224 inputs do),
      - defaults to num_classes=11.

    State dict layout (must match the trained .pth files):
        features.{0,3,6,8,10}.weight/bias   conv1 .. conv5
        classifier.0.weight/bias  -> Linear(9216, 4096)          [fc6]
        classifier.3.weight/bias  -> Linear(4096, 4096)          [fc7]
        classifier.5.weight/bias  -> Linear(4096, num_classes)   [fc8]
    """

    def __init__(self, num_classes=11):
        super().__init__()

        conv_stack = [
            nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(64, 192, kernel_size=5, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(192, 384, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(384, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
        ]
        self.features = nn.Sequential(*conv_stack)

        fc6 = nn.Linear(256 * 6 * 6, 4096)
        fc7 = nn.Linear(4096, 4096)
        fc8 = nn.Linear(4096, num_classes)
        # Index positions here define the classifier.{0,3,5} state-dict keys.
        self.classifier = nn.Sequential(
            fc6, nn.ReLU(inplace=True), nn.Dropout(p=0.5),
            fc7, nn.ReLU(inplace=True),
            fc8,
        )

    def forward(self, x):
        feature_maps = self.features(x)
        return self.classifier(feature_maps.flatten(1))
|
|
|
|
| |
| |
| |
class SEBlock_L1(nn.Module):
    """
    Squeeze-and-Excitation gate for 4D feature maps (N, C, H, W).

    Squeeze: average-pool every channel down to one value.
    Excite:  fc1 -> ReLU -> fc2 -> sigmoid yields one weight per channel.
    Output:  the input scaled channel-wise by those weights (same shape).

    Both linear layers are bias-free, so the state dict (as ``se_block``) is:
        se_block.fc1.weight -> Linear(in_channels, in_channels // reduction)
        se_block.fc2.weight -> Linear(in_channels // reduction, in_channels)
    ``global_pool`` and ``sigmoid`` carry no parameters.
    """

    def __init__(self, in_channels, reduction=32):
        super().__init__()
        bottleneck = in_channels // reduction
        self.global_pool = nn.AdaptiveAvgPool2d(1)
        self.fc1 = nn.Linear(in_channels, bottleneck, bias=False)
        self.fc2 = nn.Linear(bottleneck, in_channels, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        batch, channels, _height, _width = x.size()
        descriptor = self.global_pool(x).view(batch, channels)
        hidden = F.relu(self.fc1(descriptor))
        channel_weights = self.sigmoid(self.fc2(hidden)).view(batch, channels, 1, 1)
        return x * channel_weights.expand_as(x)
|
|
|
|
class AlexNetWithSE_L1(nn.Module):
    """
    SE-AlexNet variant "L1": an SE gate on the 256-channel output of the last
    conv stage, applied before the maps are flattened into the classifier.

    ``reduction`` sets the SE bottleneck width (256 // reduction).

    There is no adaptive pooling before fc6, so the conv stack must produce
    256x6x6 maps (224x224 inputs do).

    State dict key structure:
        features.*                          AlexNet conv layers
        se_block.fc1.weight, se_block.fc2.weight
        classifier.1.weight -> Linear(9216, 4096)         [fc6]
        classifier.4.weight -> Linear(4096, 4096)         [fc7]
        classifier.6.weight -> Linear(4096, num_classes)  [fc8]
    """

    def __init__(self, num_classes=8, reduction=32):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(64, 192, kernel_size=5, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(192, 384, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(384, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
        )
        self.se_block = SEBlock_L1(256, reduction=reduction)
        self.classifier = nn.Sequential(
            nn.Dropout(),
            nn.Linear(9216, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(inplace=True),
            nn.Linear(4096, num_classes),
        )

    def forward(self, x):
        x = self.features(x)
        x = self.se_block(x)
        # torch.flatten instead of .view: .view raises on non-contiguous maps
        # (e.g. a channels_last model/input), whereas flatten copies only
        # when needed. Element order is the same logical (C, H, W) order
        # either way, so the trained fc6 weights remain valid.
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x
|
|
|
|
| |
| |
| |
class SEBlock_L2(nn.Module):
    """
    Squeeze-and-Excitation gate for flat feature vectors (e.g. the 4096-d fc6
    activations in AlexNetWithSE_L2).

    The input is already a vector, so there is no pooling step. The block runs
    the excitation path  fc1 -> ReLU -> fc2 -> sigmoid  and multiplies the
    input by the resulting per-feature weights. The output has the same shape
    as the input.

    NOTE: AlexNetWithSE_L2 always builds this block with reduction=16. The
    squeeze label (2/4/8/16/32) of the trained variants is a training
    hyperparameter, not this architectural ratio.

    State dict keys (as classifier.3):
        classifier.3.fc1.weight / .bias -> Linear(4096, 256)
        classifier.3.fc2.weight / .bias -> Linear(256, 4096)
    """

    def __init__(self, channel, reduction=16):
        super().__init__()
        hidden = channel // reduction
        self.fc1 = nn.Linear(channel, hidden, bias=True)
        self.relu = nn.ReLU(inplace=True)
        self.fc2 = nn.Linear(hidden, channel, bias=True)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        # The in-place ReLU acts on fc1's fresh output, never on x itself.
        gate = self.sigmoid(self.fc2(self.relu(self.fc1(x))))
        return x * gate
|
|
|
|
class AlexNetWithSE_L2(nn.Module):
    """
    SE-AlexNet variant "L2": an SE gate on the 4096-d fc6 activations, placed
    between fc6 and fc7 inside the classifier.

    The conv trunk uses the same layer layout as torchvision's AlexNet
    feature extractor. Its output is adaptively pooled to 6x6 before being
    flattened into fc6.

    Args:
        num_classes: output size of fc8.
        reduction: accepted for interface parity with the other SE variants
            but IGNORED. The SE block is always built with reduction=16,
            because every trained L2 weight file uses that ratio
            (classifier.3.fc1 is Linear(4096, 256)).

    State dict key structure:
        features.*              -> AlexNet conv layers (torchvision layout)
        avgpool                 -> AdaptiveAvgPool2d, no params
        classifier.1.weight     -> Linear(9216, 4096)          [fc6]
        classifier.3.fc1 / fc2  -> SEBlock_L2(4096, reduction=16)
        classifier.5.weight     -> Linear(4096, 4096)          [fc7]
        classifier.7.weight     -> Linear(4096, num_classes)   [fc8]
    """

    def __init__(self, num_classes=8, reduction=16):
        super().__init__()
        # Same layer sequence as torchvision.models.alexnet().features, so
        # the features.{0,3,6,8,10} keys of the trained weights still match.
        # Building it directly avoids three problems with the old approach:
        # the torchvision import, torchvision's deprecated ``pretrained=``
        # argument, and allocating/initialising a whole throw-away AlexNet
        # (including its ~58M-parameter classifier) just to keep its conv
        # trunk.
        self.features = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(64, 192, kernel_size=5, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(192, 384, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(384, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
        )
        self.avgpool = nn.AdaptiveAvgPool2d((6, 6))
        # NOTE: ``reduction`` is intentionally unused here; the SE bottleneck
        # is fixed at 4096 // 16 = 256 to match the trained weights.
        self.classifier = nn.Sequential(
            nn.Dropout(),
            nn.Linear(256 * 6 * 6, 4096),
            nn.ReLU(inplace=True),
            SEBlock_L2(4096, reduction=16),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(inplace=True),
            nn.Linear(4096, num_classes),
        )

    def forward(self, x):
        x = self.features(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x
|
|
|
|
| |
| |
| |
class SELayer_L3(nn.Module):
    """
    Squeeze-and-Excitation layer for 4D feature maps (N, C, H, W).

    The excitation MLP is stored as one nn.Sequential named ``fc``. This
    gives the state dict the same fc.<i> sub-indices as the trained files:
        fc.0  Linear(C, C // reduction, bias=False)
        fc.1  ReLU
        fc.2  Linear(C // reduction, C, bias=False)
        fc.3  Sigmoid
    Only fc.0 and fc.2 hold parameters, so the keys (as classifier.0) are
        classifier.0.fc.0.weight, classifier.0.fc.2.weight

    forward() average-pools each channel, runs the MLP, and rescales the
    input channel-wise with the result. The output has the input's shape.
    """

    def __init__(self, channel, reduction=32):
        super().__init__()
        squeezed = channel // reduction
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        excitation = [
            nn.Linear(channel, squeezed, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(squeezed, channel, bias=False),
            nn.Sigmoid(),
        ]
        self.fc = nn.Sequential(*excitation)

    def forward(self, x):
        n, c, _h, _w = x.size()
        descriptor = self.avg_pool(x).view(n, c)
        scale = self.fc(descriptor).view(n, c, 1, 1)
        return x * scale.expand_as(x)
|
|
|
|
class AlexNetWithSE_L3(nn.Module):
    """
    SE-AlexNet variant "L3": an SELayer stored as classifier[0]. It gates the
    256-channel conv output before that output is flattened into fc6.

    classifier[0] works on 4D maps, while the remaining classifier layers
    work on flat vectors. So forward() runs classifier[0] by hand, flattens,
    and then applies classifier[1:] in order.

    State dict key structure:
        features.*                       AlexNet conv layers
        classifier.0.fc.0 / fc.2         SELayer_L3(256, reduction)
        classifier.1.weight -> Linear(9216, 4096)         [fc6]
        classifier.4.weight -> Linear(4096, 4096)         [fc7]
        classifier.6.weight -> Linear(4096, num_classes)  [fc8]
    """

    def __init__(self, num_classes=11, reduction=32):
        super().__init__()
        trunk = [
            nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(64, 192, kernel_size=5, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(192, 384, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(384, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
        ]
        self.features = nn.Sequential(*trunk)

        head = [
            SELayer_L3(256, reduction=reduction),
            nn.Linear(9216, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.5),
            nn.Linear(4096, 4096),
            nn.ReLU(inplace=True),
            nn.Linear(4096, num_classes),
        ]
        self.classifier = nn.Sequential(*head)

    def forward(self, x):
        se_layer, *fc_layers = self.classifier
        out = se_layer(self.features(x))
        out = out.flatten(1)
        for layer in fc_layers:
            out = layer(out)
        return out
|
|
|
|
| |
| |
| |
class VGG16(nn.Module):
    """
    VGG16 adapted for binary emotion classification (Happy vs Sad).

    The trunk has 13 conv layers (3x3, padding 1, each followed by an
    in-place ReLU) arranged in five stages, each stage ending in a 2x2
    max-pool. There is no adaptive pooling, so the trunk must produce
    512x7x7 maps (224x224 inputs do).

    State dict key structure:
        features.*                                        conv layers
        classifier.0.weight -> Linear(25088, 4096)        [fc6]
        classifier.3.weight -> Linear(4096, 4096)         [fc7]
        classifier.6.weight -> Linear(4096, num_classes)  [fc8]
    """

    def __init__(self, num_classes=2):
        super().__init__()
        stage_layout = [
            64, 64, 'M',
            128, 128, 'M',
            256, 256, 256, 'M',
            512, 512, 512, 'M',
            512, 512, 512, 'M',
        ]
        self.features = self._make_features(stage_layout)
        self.classifier = nn.Sequential(
            nn.Linear(512 * 7 * 7, 4096), nn.ReLU(inplace=True), nn.Dropout(p=0.5),
            nn.Linear(4096, 4096), nn.ReLU(inplace=True), nn.Dropout(p=0.5),
            nn.Linear(4096, num_classes),
        )

    @staticmethod
    def _make_features(cfg):
        """Build the conv trunk: ints are conv widths, 'M' is a 2x2 max-pool."""
        modules = []
        channels = 3
        for entry in cfg:
            if entry == 'M':
                modules.append(nn.MaxPool2d(kernel_size=2, stride=2))
                continue
            modules.extend((
                nn.Conv2d(channels, entry, kernel_size=3, padding=1),
                nn.ReLU(inplace=True),
            ))
            channels = entry
        return nn.Sequential(*modules)

    def forward(self, x):
        flat = self.features(x).flatten(start_dim=1)
        return self.classifier(flat)
|
|
|
|
| |
| |
| |
# Maps the ``model_type`` string of a config to the model class to build.
MODEL_REGISTRY = dict(
    alexnet=AlexNet,
    se_alexnet_l1=AlexNetWithSE_L1,
    se_alexnet_l2=AlexNetWithSE_L2,
    se_alexnet_l3=AlexNetWithSE_L3,
    vgg16=VGG16,
)
|
|
|
|
def load_model_from_config(config, weights_path=None, device='cpu'):
    """
    Build a model from a config dict and optionally load weights into it.

    Args:
        config: mapping with keys
            'model_type'  - a key of MODEL_REGISTRY
            'num_classes' - passed to the model constructor
            'reduction'   - optional; passed to the constructor only when
                            present (meant for the SE variants)
        weights_path: optional path (str, bytes or os.PathLike) to a
            ``.safetensors`` file or to a file readable by ``torch.load``.
            If the loaded object wraps the weights under a top-level
            'model' key (checked first) or 'state_dict' key, that inner
            mapping is used.
        device: target device; also used as ``map_location`` for torch.load.

    Returns:
        The model, moved to ``device`` and switched to eval mode.

    Raises:
        KeyError: if 'model_type' or 'num_classes' is missing, or if
            ``model_type`` is not a key of MODEL_REGISTRY.
        RuntimeError: from ``load_state_dict`` when the weights do not match
            the architecture exactly (strict loading).
    """
    import os

    model_type = config['model_type']
    try:
        model_cls = MODEL_REGISTRY[model_type]
    except KeyError:
        raise KeyError(
            f"Unknown model_type {model_type!r}; "
            f"expected one of {sorted(MODEL_REGISTRY)}"
        ) from None

    kwargs = {'num_classes': config['num_classes']}
    if 'reduction' in config:
        kwargs['reduction'] = config['reduction']
    model = model_cls(**kwargs)

    if weights_path is not None:
        # Normalise str / bytes / pathlib.Path to str so the suffix check
        # below also works for non-str path-like arguments.
        path = os.fsdecode(weights_path)
        if path.endswith('.safetensors'):
            from safetensors.torch import load_file
            state_dict = load_file(path)
        else:
            # weights_only=True restricts unpickling to tensors and plain
            # containers instead of arbitrary Python objects.
            state_dict = torch.load(path, map_location=device, weights_only=True)

        # Training checkpoints often nest the weights one level down.
        for wrapper_key in ('model', 'state_dict'):
            if wrapper_key in state_dict:
                state_dict = state_dict[wrapper_key]
                break

        model.load_state_dict(state_dict, strict=True)

    model.to(device)
    model.eval()
    return model
|
|