# SE-AlexNet / modeling.py
# Uploaded by JiayuMBao using huggingface_hub (commit 6e88f15, verified); file size 15.6 kB.
"""
SE-AlexNet: Model Architectures for Emotion Recognition with Squeeze-and-Excitation.
This file defines the exact model architectures matching the trained weight files.
Each model class mirrors the state_dict key structure of the original .pth files
to ensure seamless weight loading.
Architecture Summary:
- AlexNet : Standard AlexNet, num_classes=11 (pretrained on AffectNet 11-way)
- AlexNetWithSE_L1 : SE block after all conv layers (256ch), reduction r ∈ {2,4,8,16,32}
- AlexNetWithSE_L2 : SE block between fc6→fc7 (4096dim), reduction=16 (fixed)
- AlexNetWithSE_L3 : SELayer inside classifier[0] (256ch), reduction r ∈ {2,4,8,16,32}
- VGG16 : Standard VGG16, num_classes=2 (binary Happy/Sad)
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
# =============================================================================
# 1. Standard AlexNet (used by RawAlexNet)
# =============================================================================
class AlexNet(nn.Module):
    """
    Plain AlexNet backbone for facial emotion recognition.

    Compared with torchvision's AlexNet, the classifier has no Dropout in
    front of fc6 (a single Dropout sits between fc6 and fc7) and the head
    defaults to 11 outputs for AffectNet pretraining.

    There is no adaptive pooling, so the conv stack must produce a
    256 x 6 x 6 map (as it does for 224 x 224 inputs) for fc6 to accept it.

    State dict key structure:
        features.{0,3,6,8,10}.weight / .bias             [conv1..conv5]
        classifier.0.weight → Linear(9216, 4096)         [fc6]
        classifier.3.weight → Linear(4096, 4096)         [fc7]
        classifier.5.weight → Linear(4096, num_classes)  [fc8]
    """

    def __init__(self, num_classes=11):
        super().__init__()

        # Convolutional feature extractor; list positions become the
        # numeric sub-module names that appear in the state dict.
        conv_stack = [
            nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(64, 192, kernel_size=5, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(192, 384, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(384, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
        ]
        self.features = nn.Sequential(*conv_stack)

        # Fully connected head: fc6 → ReLU → Dropout → fc7 → ReLU → fc8.
        head = [
            nn.Linear(256 * 6 * 6, 4096),   # 0: fc6
            nn.ReLU(inplace=True),          # 1
            nn.Dropout(p=0.5),              # 2
            nn.Linear(4096, 4096),          # 3: fc7
            nn.ReLU(inplace=True),          # 4
            nn.Linear(4096, num_classes),   # 5: fc8
        ]
        self.classifier = nn.Sequential(*head)

    def forward(self, x):
        """Return class logits of shape (batch, num_classes) for image batch ``x``."""
        feature_maps = self.features(x)
        flat = feature_maps.flatten(1)  # keep the batch dim, merge C/H/W
        return self.classifier(flat)
# =============================================================================
# 2. SE-AlexNet Location 1: SE block after last conv, before FC
# =============================================================================
class SEBlock_L1(nn.Module):
    """
    Squeeze-and-Excitation gate for 4D feature maps (B, C, H, W).

    Each channel is averaged over its spatial extent, passed through a
    bias-free bottleneck MLP (C → C // reduction → C) and a sigmoid, and the
    resulting per-channel weight rescales the input.

    State dict keys (when attached as ``se_block``):
        se_block.fc1.weight → Linear(C, C // reduction)
        se_block.fc2.weight → Linear(C // reduction, C)
    ``global_pool`` and ``sigmoid`` carry no parameters.
    """

    def __init__(self, in_channels, reduction=32):
        super().__init__()
        squeezed = in_channels // reduction  # width of the bottleneck
        self.global_pool = nn.AdaptiveAvgPool2d(1)
        self.fc1 = nn.Linear(in_channels, squeezed, bias=False)
        self.fc2 = nn.Linear(squeezed, in_channels, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return ``x`` scaled channel-wise by the learned gate (same shape as ``x``)."""
        batch, channels, _height, _width = x.size()
        # Squeeze: one scalar per (sample, channel).
        pooled = self.global_pool(x).view(batch, channels)
        # Excite: bottleneck MLP followed by a sigmoid gate in (0, 1).
        hidden = F.relu(self.fc1(pooled))
        gate = self.sigmoid(self.fc2(hidden))
        gate = gate.view(batch, channels, 1, 1).expand_as(x)
        return x * gate
class AlexNetWithSE_L1(nn.Module):
    """
    SE-AlexNet with the SE block placed between the conv features and the classifier.

    The SE block re-weights the 256 channels produced by the last conv stage
    before they are flattened and fed to fc6. There is no adaptive pooling,
    so the conv stack must yield a 256 x 6 x 6 map (as it does for
    224 x 224 inputs).

    Args:
        num_classes: size of the final fc8 output (default 8).
        reduction: bottleneck ratio of the SE block (default 32).

    State dict key structure:
        features.*                                       (AlexNet conv layers)
        se_block.fc1.weight, se_block.fc2.weight
        classifier.1.weight → Linear(9216, 4096)         [fc6]
        classifier.4.weight → Linear(4096, 4096)         [fc7]
        classifier.6.weight → Linear(4096, num_classes)  [fc8]
    """

    def __init__(self, num_classes=8, reduction=32):
        super(AlexNetWithSE_L1, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(64, 192, kernel_size=5, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(192, 384, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(384, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
        )
        self.se_block = SEBlock_L1(256, reduction=reduction)
        self.classifier = nn.Sequential(
            nn.Dropout(),                   # 0
            nn.Linear(9216, 4096),          # 1: fc6
            nn.ReLU(inplace=True),          # 2
            nn.Dropout(),                   # 3
            nn.Linear(4096, 4096),          # 4: fc7
            nn.ReLU(inplace=True),          # 5
            nn.Linear(4096, num_classes),   # 6: fc8
        )

    def forward(self, x):
        """Return class logits of shape (batch, num_classes) for image batch ``x``."""
        x = self.features(x)
        x = self.se_block(x)
        # torch.flatten (unlike .view) also accepts non-contiguous activations,
        # e.g. when the input batch uses the channels_last memory format, and
        # matches how the other models in this file flatten their features.
        x = torch.flatten(x, start_dim=1)
        x = self.classifier(x)
        return x
# =============================================================================
# 3. SE-AlexNet Location 2: SE block between fc6 and fc7 (in classifier)
# =============================================================================
class SEBlock_L2(nn.Module):
    """
    Squeeze-and-Excitation gate for flat feature vectors (e.g. the 4096-dim fc6 output).

    The input already has one value per feature, so no pooling is performed:
    a biased bottleneck MLP (channel → channel // reduction → channel)
    followed by a sigmoid yields an element-wise gate.

    NOTE: every squeeze-labeled L2 variant uses reduction=16. The squeeze
    label (2/4/8/16/32) refers to a training hyperparameter, not to this
    architectural reduction ratio.

    State dict keys (when installed as ``classifier.3``):
        classifier.3.fc1.weight, classifier.3.fc1.bias → Linear(4096, 256)
        classifier.3.fc2.weight, classifier.3.fc2.bias → Linear(256, 4096)
    """

    def __init__(self, channel, reduction=16):
        super().__init__()
        bottleneck = channel // reduction
        self.fc1 = nn.Linear(channel, bottleneck, bias=True)
        self.relu = nn.ReLU(inplace=True)
        self.fc2 = nn.Linear(bottleneck, channel, bias=True)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return ``x`` multiplied element-wise by its sigmoid gate."""
        gate = self.sigmoid(self.fc2(self.relu(self.fc1(x))))
        return x * gate
class AlexNetWithSE_L2(nn.Module):
    """
    SE-AlexNet with the SE block between fc6 and fc7.

    The conv stack has the same layout as torchvision's AlexNet ``features``
    and is followed by an AdaptiveAvgPool2d to 6 x 6, so the classifier
    always receives 9216 features.

    Args:
        num_classes: size of the final fc8 output (default 8).
        reduction: accepted for API compatibility only. The SE block always
            uses reduction=16, because all L2 variants share one
            architecture regardless of their squeeze label.

    State dict key structure:
        features.*                                       (AlexNet conv layers)
        avgpool                                          (AdaptiveAvgPool2d, no params)
        classifier.1.weight   → Linear(9216, 4096)       [fc6]
        classifier.3.fc1/fc2  → SEBlock(4096, red=16)    [SE]
        classifier.5.weight   → Linear(4096, 4096)       [fc7]
        classifier.7.weight   → Linear(4096, num_classes) [fc8]
    """

    def __init__(self, num_classes=8, reduction=16):
        super(AlexNetWithSE_L2, self).__init__()
        # Defined inline instead of taking ``.features`` from a throwaway
        # torchvision AlexNet: same layers and state-dict keys, but without
        # allocating that model's unused classifier, without the deprecated
        # ``pretrained=`` argument, and without requiring torchvision.
        self.features = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(64, 192, kernel_size=5, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(192, 384, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(384, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
        )
        self.avgpool = nn.AdaptiveAvgPool2d((6, 6))
        self.classifier = nn.Sequential(
            nn.Dropout(),                     # 0
            nn.Linear(256 * 6 * 6, 4096),     # 1: fc6
            nn.ReLU(inplace=True),            # 2
            SEBlock_L2(4096, reduction=16),   # 3: SE (fixed reduction=16)
            nn.Dropout(),                     # 4
            nn.Linear(4096, 4096),            # 5: fc7
            nn.ReLU(inplace=True),            # 6
            nn.Linear(4096, num_classes),     # 7: fc8
        )

    def forward(self, x):
        """Return class logits of shape (batch, num_classes) for image batch ``x``."""
        x = self.features(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x
# =============================================================================
# 4. SE-AlexNet Location 3: SELayer inside classifier[0]
# =============================================================================
class SELayer_L3(nn.Module):
    """
    Squeeze-and-Excitation layer for 4D feature maps (B, C, H, W).

    The excitation MLP lives in an ``nn.Sequential`` named ``fc`` so that its
    parameters get the indexed names fc.0 / fc.2 (fc.1 and fc.3 are the
    parameter-free ReLU and Sigmoid).

    State dict keys (when installed as ``classifier.0``):
        classifier.0.fc.0.weight → Linear(C, C // reduction)
        classifier.0.fc.2.weight → Linear(C // reduction, C)
    """

    def __init__(self, channel, reduction=32):
        super().__init__()
        hidden_width = channel // reduction
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        excitation = [
            nn.Linear(channel, hidden_width, bias=False),   # fc.0
            nn.ReLU(inplace=True),                          # fc.1
            nn.Linear(hidden_width, channel, bias=False),   # fc.2
            nn.Sigmoid(),                                   # fc.3
        ]
        self.fc = nn.Sequential(*excitation)

    def forward(self, x):
        """Return ``x`` scaled channel-wise by the learned gate (same shape as ``x``)."""
        batch, channels, _height, _width = x.size()
        squeezed = self.avg_pool(x).view(batch, channels)
        scale = self.fc(squeezed).view(batch, channels, 1, 1)
        return x * scale.expand_as(x)
class AlexNetWithSE_L3(nn.Module):
    """
    SE-AlexNet whose SE layer is stored as ``classifier[0]`` (after the conv
    features, before fc6).

    Because the SE layer works on 4D maps while the rest of the classifier
    works on flat vectors, ``forward`` applies ``classifier[0]`` on its own,
    flattens, and then runs the remaining classifier layers.

    State dict key structure:
        features.*                                       (AlexNet conv layers)
        classifier.0.fc.0 / fc.2                         (SELayer on 256ch)
        classifier.1.weight → Linear(9216, 4096)         [fc6]
        classifier.4.weight → Linear(4096, 4096)         [fc7]
        classifier.6.weight → Linear(4096, num_classes)  [fc8]
    """

    def __init__(self, num_classes=11, reduction=32):
        super().__init__()

        conv_stack = [
            nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(64, 192, kernel_size=5, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(192, 384, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(384, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
        ]
        self.features = nn.Sequential(*conv_stack)

        head = [
            SELayer_L3(256, reduction=reduction),   # 0: SE
            nn.Linear(9216, 4096),                  # 1: fc6 (no Dropout before)
            nn.ReLU(inplace=True),                  # 2
            nn.Dropout(p=0.5),                      # 3
            nn.Linear(4096, 4096),                  # 4: fc7
            nn.ReLU(inplace=True),                  # 5
            nn.Linear(4096, num_classes),           # 6: fc8
        ]
        self.classifier = nn.Sequential(*head)

    def forward(self, x):
        """Return class logits of shape (batch, num_classes) for image batch ``x``."""
        se_layer = self.classifier[0]     # expects 4D feature maps
        fc_layers = self.classifier[1:]   # expects flat (B, 9216) vectors

        gated = se_layer(self.features(x))          # (B,256,6,6) → (B,256,6,6)
        flat = torch.flatten(gated, start_dim=1)    # → (B, 9216)
        return fc_layers(flat)
# =============================================================================
# 5. VGG16
# =============================================================================
class VGG16(nn.Module):
    """
    Plain VGG16 for binary emotion classification (Happy vs Sad).

    There is no adaptive pooling, so the conv stack must yield a
    512 x 7 x 7 map (as it does for 224 x 224 inputs) for fc6 to accept it.

    State dict key structure:
        features.*                                       (13 conv + 5 pool layers)
        classifier.0.weight → Linear(25088, 4096)        [fc6]
        classifier.3.weight → Linear(4096, 4096)         [fc7]
        classifier.6.weight → Linear(4096, num_classes)  [fc8]
    """

    def __init__(self, num_classes=2):
        super().__init__()
        # Integers are conv output widths; 'M' marks a 2x2 max-pool.
        layout = [
            64, 64, 'M',
            128, 128, 'M',
            256, 256, 256, 'M',
            512, 512, 512, 'M',
            512, 512, 512, 'M',
        ]
        self.features = self._make_features(layout)

        head = [
            nn.Linear(512 * 7 * 7, 4096),   # 0: fc6
            nn.ReLU(True),                  # 1
            nn.Dropout(p=0.5),              # 2
            nn.Linear(4096, 4096),          # 3: fc7
            nn.ReLU(True),                  # 4
            nn.Dropout(p=0.5),              # 5
            nn.Linear(4096, num_classes),   # 6: fc8
        ]
        self.classifier = nn.Sequential(*head)

    @staticmethod
    def _make_features(cfg):
        """Turn a VGG layout list into an ``nn.Sequential`` of conv/ReLU/pool layers.

        Each integer entry adds a 3x3 Conv2d (padding 1) with that many output
        channels followed by an in-place ReLU; each ``'M'`` adds a 2x2
        max-pool with stride 2. The first conv expects 3 input channels.
        """
        modules = []
        prev_channels = 3
        for entry in cfg:
            if entry != 'M':
                modules.append(
                    nn.Conv2d(prev_channels, entry, kernel_size=3, padding=1)
                )
                modules.append(nn.ReLU(True))
                prev_channels = entry
                continue
            modules.append(nn.MaxPool2d(kernel_size=2, stride=2))
        return nn.Sequential(*modules)

    def forward(self, x):
        """Return class logits of shape (batch, num_classes) for image batch ``x``."""
        feature_maps = self.features(x)
        flat = feature_maps.flatten(1)
        return self.classifier(flat)
# =============================================================================
# Model Registry — maps config.model_type to class
# =============================================================================
# Lookup table from the ``model_type`` string found in a config to the class
# that builds that architecture.
MODEL_REGISTRY = dict(
    alexnet=AlexNet,
    se_alexnet_l1=AlexNetWithSE_L1,
    se_alexnet_l2=AlexNetWithSE_L2,
    se_alexnet_l3=AlexNetWithSE_L3,
    vgg16=VGG16,
)
def load_model_from_config(config, weights_path=None, device='cpu'):
    """
    Build a model from a config dict and optionally load weights.

    Args:
        config: dict with keys 'model_type' (a key of MODEL_REGISTRY) and
            'num_classes', and optionally 'reduction' (forwarded to the
            model constructor when present).
        weights_path: optional ``str`` or ``os.PathLike`` pointing to the
            weights. Files ending in ``.safetensors`` are read with
            safetensors; anything else is read with ``torch.load`` using
            ``weights_only=True``.
        device: torch device the model is moved to (also used as
            ``map_location`` for ``torch.load``).

    Returns:
        model: nn.Module on ``device``, switched to eval mode.

    Raises:
        KeyError: if 'model_type' or 'num_classes' is missing from ``config``
            or the model type is not registered.
        RuntimeError: if the weights do not match the architecture exactly
            (``load_state_dict`` is called with ``strict=True``).
    """
    model_type = config['model_type']
    try:
        model_cls = MODEL_REGISTRY[model_type]
    except KeyError:
        # Same exception type as before, but say what would have been valid.
        raise KeyError(
            f"unknown model_type {model_type!r}; "
            f"expected one of {sorted(MODEL_REGISTRY)}"
        ) from None

    # Build constructor kwargs
    kwargs = {'num_classes': config['num_classes']}
    if 'reduction' in config:
        kwargs['reduction'] = config['reduction']
    model = model_cls(**kwargs)

    if weights_path is not None:
        import os

        # Normalise pathlib.Path and other os.PathLike objects so the suffix
        # check below works for them as well as for plain strings.
        weights_file = os.fspath(weights_path)
        if weights_file.endswith('.safetensors'):
            from safetensors.torch import load_file
            state_dict = load_file(weights_file)
        else:
            state_dict = torch.load(
                weights_file, map_location=device, weights_only=True
            )

        # Handle checkpoint wrappers: some training scripts store the weights
        # under a top-level 'model' or 'state_dict' entry.
        for wrapper_key in ('model', 'state_dict'):
            if wrapper_key in state_dict:
                state_dict = state_dict[wrapper_key]
                break

        model.load_state_dict(state_dict, strict=True)

    model.to(device)
    model.eval()
    return model