winzerprince committed
Commit f9a8949 · verified · 1 parent: 5013dec

Add modeling_vit_emotion.py

Files changed (1):
  1. modeling_vit_emotion.py +135 -0
modeling_vit_emotion.py ADDED
@@ -0,0 +1,135 @@
+ """
+ Vision Transformer (ViT) Model Definition for Emotion Regression
+
+ This file defines the ViT model architecture used for valence-arousal prediction.
+ """
+
+ import torch
+ import torch.nn as nn
+ from transformers import ViTModel, ViTConfig
+
+
+ class ViTForEmotionRegression(nn.Module):
+     """
+     Vision Transformer for emotion regression (valence and arousal prediction).
+
+     Architecture:
+     - Pre-trained ViT backbone (google/vit-base-patch16-224-in21k)
+     - Custom regression head for 2D emotion prediction
+     - Dropout for regularization
+     """
+
+     def __init__(self, model_name='google/vit-base-patch16-224-in21k',
+                  num_emotions=2, freeze_backbone=False, dropout=0.1):
+         super().__init__()
+
+         # Load the pre-trained ViT backbone; fall back to random init if unavailable
+         try:
+             self.vit = ViTModel.from_pretrained(model_name)
+             print(f"✅ Loaded pre-trained ViT from {model_name}")
+         except Exception as e:
+             print(f"⚠️ Could not load pre-trained model: {e}")
+             print("   Initializing with random weights...")
+             config = ViTConfig()
+             self.vit = ViTModel(config)
+
+         # Freeze the backbone if specified, so only the head is trained
+         if freeze_backbone:
+             for param in self.vit.parameters():
+                 param.requires_grad = False
+             print("❄️ ViT backbone frozen")
+
+         # Hidden size from the ViT config (768 for ViT-Base)
+         hidden_size = self.vit.config.hidden_size
+
+         # Regression head for emotion prediction (named 'head' to match the saved checkpoint)
+         # Architecture: 768 -> 512 -> 128 -> 2
+         self.head = nn.Sequential(
+             nn.LayerNorm(hidden_size),       # [0] weight: [768], bias: [768]
+             nn.Dropout(dropout),
+             nn.Linear(hidden_size, 512),     # [2] weight: [512, 768], bias: [512]
+             nn.GELU(),
+             nn.Dropout(dropout),
+             nn.Linear(512, 128),             # [5] weight: [128, 512], bias: [128]
+             nn.GELU(),
+             nn.Dropout(dropout),
+             nn.Linear(128, num_emotions),    # [8] weight: [2, 128], bias: [2]
+             nn.Tanh()                        # Output in range [-1, 1]
+         )
+
+     def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
+         """
+         Forward pass through the model.
+
+         Args:
+             pixel_values: Input images tensor of shape (batch_size, 3, 224, 224)
+
+         Returns:
+             Emotion predictions tensor of shape (batch_size, 2) [valence, arousal]
+         """
+         # Use the [CLS] token embedding as the image representation
+         outputs = self.vit(pixel_values)
+         cls_output = outputs.last_hidden_state[:, 0]
+
+         # Predict emotions
+         emotion_predictions = self.head(cls_output)
+         return emotion_predictions
+
+
+ class MobileViTStudent(nn.Module):
+     """
+     Lightweight MobileViT student model for emotion regression.
+     Used in the distilled version for faster inference.
+     """
+
+     def __init__(self, num_emotions=2, dropout=0.1):
+         super().__init__()
+
+         # Lightweight CNN stem: 3 -> 32 channels, halves spatial resolution
+         self.conv_stem = nn.Sequential(
+             nn.Conv2d(3, 32, kernel_size=3, stride=2, padding=1),
+             nn.BatchNorm2d(32),
+             nn.ReLU(inplace=True),
+         )
+
+         # Depthwise-separable convolution blocks (MobileNet-style), each halving resolution
+         self.blocks = nn.Sequential(
+             self._make_mb_block(32, 64, stride=2),
+             self._make_mb_block(64, 128, stride=2),
+             self._make_mb_block(128, 256, stride=2),
+         )
+
+         # Global pooling to a 256-dim feature vector
+         self.global_pool = nn.AdaptiveAvgPool2d(1)
+
+         # Regression head (named 'head' to match the saved checkpoint)
+         self.head = nn.Sequential(
+             nn.Flatten(),
+             nn.Linear(256, 128),
+             nn.ReLU(inplace=True),
+             nn.Dropout(dropout),
+             nn.Linear(128, num_emotions),
+             nn.Tanh()
+         )
+
+     def _make_mb_block(self, in_channels, out_channels, stride=1):
+         """Create a MobileNet-style depthwise-separable convolution block"""
+         return nn.Sequential(
+             # Depthwise convolution
+             nn.Conv2d(in_channels, in_channels, kernel_size=3,
+                       stride=stride, padding=1, groups=in_channels),
+             nn.BatchNorm2d(in_channels),
+             nn.ReLU(inplace=True),
+             # Pointwise convolution
+             nn.Conv2d(in_channels, out_channels, kernel_size=1),
+             nn.BatchNorm2d(out_channels),
+             nn.ReLU(inplace=True),
+         )
+
+     def forward(self, x):
+         """Forward pass: stem -> blocks -> global pool -> regression head"""
+         x = self.conv_stem(x)
+         x = self.blocks(x)
+         x = self.global_pool(x)
+         emotions = self.head(x)
+         return emotions
+ return emotions