SakibRumu commited on
Commit
d72a17d
·
verified ·
1 Parent(s): 183e2bc

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +404 -63
app.py CHANGED
@@ -1,79 +1,420 @@
 
1
  import torch
2
  import torch.nn as nn
3
- import gradio as gr
4
- from torchvision import models, transforms
 
5
  from PIL import Image
6
- from transformers import ViTModel
 
 
 
 
 
 
7
 
8
- # Define Hybrid CNN + Transformer
9
- class HybridCNNTransformer(nn.Module):
10
- def __init__(self, num_classes=7):
11
- super(HybridCNNTransformer, self).__init__()
12
- self.cnn = models.resnet50(pretrained=True)
13
- self.cnn = nn.Sequential(*list(self.cnn.children())[:-2])
14
- self.channel_reduction = nn.Conv2d(2048, 64, kernel_size=1)
15
- self.to_rgb = nn.Conv2d(64, 3, kernel_size=1)
16
- self.transformer = ViTModel.from_pretrained("google/vit-base-patch16-224")
17
- self.fc = nn.Sequential(
18
- nn.Linear(768, 512),
19
- nn.ReLU(),
20
- nn.Dropout(0.3),
21
- nn.Linear(512, num_classes)
22
- )
23
 
24
- def forward(self, x):
25
- x = self.cnn(x)
26
- x = self.channel_reduction(x)
27
- x = self.to_rgb(x)
28
- x = nn.functional.interpolate(x, size=(224, 224), mode="bilinear")
29
- x = self.transformer(pixel_values=x).last_hidden_state[:, 0, :]
30
- return self.fc(x)
 
 
 
 
 
 
 
 
31
 
32
- # Load model
33
- model = HybridCNNTransformer(num_classes=7)
34
- model.load_state_dict(torch.load("transformerHybrid_emotation_model.pth", map_location=torch.device('cpu')), strict=False)
35
- model.eval()
 
 
 
 
 
 
 
 
 
 
 
 
36
 
37
- # Transform
38
  transform = transforms.Compose([
39
  transforms.Resize((224, 224)),
40
  transforms.ToTensor(),
41
  transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
42
  ])
43
 
44
- # Prediction function
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45
  def predict_emotion(image):
46
- image = transform(image).unsqueeze(0)
47
- with torch.no_grad():
48
- output = model(image)
49
- probs = torch.nn.functional.softmax(output, dim=1)
50
- conf, pred = torch.max(probs, 1)
51
-
52
- labels = ["Angry", "Disgust", "Fear", "Happy", "Neutral", "Sad", "Surprise"]
53
- return labels[pred.item()], f"{conf.item() * 100:.2f}%"
54
-
55
- # Interface
56
- css = """
57
- body {
58
- background-color: #1e1e1e;
59
- color: white;
60
- }
61
- #component-1 {
62
- background-color: rgba(255, 255, 255, 0.7);
63
- padding: 20px;
64
- border-radius: 10px;
65
- }
66
- #component-2 {
67
- color: black;
68
- font-weight: bold;
69
- }
70
- """
 
 
71
 
72
- gr.Interface(
 
73
  fn=predict_emotion,
74
- inputs=gr.Image(type="pil"),
75
- outputs=[gr.Textbox(label="Predicted Emotion"), gr.Textbox(label="Confidence")],
76
- title="Emotion Classification",
77
- description="Upload an image to predict the emotion expressed using a Hybrid CNN + ViT model.",
78
- css=css
79
- ).launch()
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
  import torch
3
  import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ import torchvision.transforms as transforms
6
+ import numpy as np
7
  from PIL import Image
8
+ import cv2
9
+ import dlib
10
+ import os
11
+ import requests
12
+ import bz2
13
+ import shutil
14
+ from efficientnet_pytorch import EfficientNet
15
 
16
# Define paths
SHAPE_PREDICTOR_URL = "http://dlib.net/files/shape_predictor_68_face_landmarks.dat.bz2"
SHAPE_PREDICTOR_PATH = "shape_predictor_68_face_landmarks.dat"
MODEL_WEIGHTS_PATH = "quad_stream_model_rafdb.pth"  # Update if weights are in a different path

# Download and extract shape predictor if not present
def download_shape_predictor():
    """Ensure dlib's 68-point shape predictor model exists locally.

    Downloads the bz2-compressed model from dlib.net on first run,
    decompresses it next to the script, and removes the archive.
    No-op when the extracted .dat file is already present.

    Raises:
        requests.HTTPError: if the download URL returns an error status.
    """
    if os.path.exists(SHAPE_PREDICTOR_PATH):
        print("Shape predictor already exists.")
        return

    archive_path = SHAPE_PREDICTOR_PATH + ".bz2"
    print("Downloading shape predictor...")
    # Stream to disk in chunks; the previous response.content buffered the
    # whole compressed file in memory despite stream=True.
    response = requests.get(SHAPE_PREDICTOR_URL, stream=True, timeout=60)
    response.raise_for_status()  # fail loudly instead of saving an HTML error page
    with open(archive_path, "wb") as f:
        for chunk in response.iter_content(chunk_size=1 << 20):
            f.write(chunk)

    print("Extracting shape predictor...")
    with bz2.BZ2File(archive_path, "rb") as f_in:
        with open(SHAPE_PREDICTOR_PATH, "wb") as f_out:
            shutil.copyfileobj(f_in, f_out)
    os.remove(archive_path)
    print("Shape predictor ready.")
36
 
37
download_shape_predictor()

# Initialize Dlib detector and predictor
# HOG-based frontal face detector plus the 68-point landmark model fetched
# above; both are shared by extract_landmark_features() and get_landmark_mask().
detector = dlib.get_frontal_face_detector()
predictor = dlib.shape_predictor(SHAPE_PREDICTOR_PATH)

# Class mapping for RAF-DB
# Index -> emotion label; must match the label ordering used when the
# classifier head was trained (RAF-DB convention) — TODO confirm against
# the training code.
class_mapping = {
    0: "Surprise",
    1: "Fear",
    2: "Disgust",
    3: "Happiness",
    4: "Sadness",
    5: "Anger",
    6: "Neutral"
}
53
 
54
# Transform for input images
# Standard ImageNet preprocessing: resize to the 224x224 input expected by
# the backbone, convert to a CHW float tensor, normalize with ImageNet
# channel statistics.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
60
 
61
# Function to extract landmark features
def extract_landmark_features(image):
    """Compute a 14-dim geometric feature vector from dlib facial landmarks.

    Features, in order: inter-ocular distance, mouth width, nose->mouth
    corner distances (L/R), eye->nose distances (L/R), eye-nose-eye angle,
    mouth-center->eye distances (L/R), mouth aspect ratio, eyebrow->eye
    distances (L/R), a cheek-tightening proxy (AU6-like), and a lip-corner
    proxy (AU12-like).

    Args:
        image: PIL RGB image (anything np.array() turns into an HxWx3 RGB array).

    Returns:
        np.ndarray of shape (14,), float32; all zeros when no face is detected.
    """

    def dist(p, q):
        # Euclidean distance between two (x, y) points; same formula the
        # original repeated inline a dozen times.
        return np.sqrt((p[0] - q[0]) ** 2 + (p[1] - q[1]) ** 2)

    image_np = np.array(image)
    gray = cv2.cvtColor(image_np, cv2.COLOR_RGB2GRAY)

    faces = detector(gray)
    if len(faces) == 0:
        # No face: neutral all-zero vector keeps the pipeline running.
        return np.zeros(14, dtype=np.float32)

    # Only the first detected face is used.
    face = faces[0]
    shape = predictor(gray, face)
    landmarks = [(shape.part(i).x, shape.part(i).y) for i in range(68)]

    # Named subset of the iBUG 68-point landmark layout.
    key_points = {
        'left_eye': landmarks[36],
        'right_eye': landmarks[45],
        'nose_tip': landmarks[30],
        'mouth_left': landmarks[48],
        'mouth_right': landmarks[54],
        'left_eyebrow': landmarks[19],
        'right_eyebrow': landmarks[24],
        'jaw_left': landmarks[5],
        'jaw_right': landmarks[11],
        'chin': landmarks[8],
        'left_lower_eyelid': landmarks[41],
        'right_lower_eyelid': landmarks[46],
        'left_cheek': landmarks[2],
        'right_cheek': landmarks[14]
    }

    features = []

    # 1) inter-ocular distance
    eye_dist = dist(key_points['left_eye'], key_points['right_eye'])
    features.append(eye_dist)

    # 2) mouth width
    mouth_width = dist(key_points['mouth_left'], key_points['mouth_right'])
    features.append(mouth_width)

    # 3-4) nose tip to mouth corners
    nose_to_mouth_left = dist(key_points['nose_tip'], key_points['mouth_left'])
    nose_to_mouth_right = dist(key_points['nose_tip'], key_points['mouth_right'])
    features.extend([nose_to_mouth_left, nose_to_mouth_right])

    # 5-6) each eye to nose tip
    left_eye_to_nose = dist(key_points['left_eye'], key_points['nose_tip'])
    right_eye_to_nose = dist(key_points['right_eye'], key_points['nose_tip'])
    features.extend([left_eye_to_nose, right_eye_to_nose])

    # 7) angle at the nose tip between the two eye vectors
    vec1 = np.array([key_points['left_eye'][0] - key_points['nose_tip'][0],
                     key_points['left_eye'][1] - key_points['nose_tip'][1]])
    vec2 = np.array([key_points['right_eye'][0] - key_points['nose_tip'][0],
                     key_points['right_eye'][1] - key_points['nose_tip'][1]])
    cos_angle = np.dot(vec1, vec2) / (np.linalg.norm(vec1) * np.linalg.norm(vec2) + 1e-8)
    angle = np.arccos(np.clip(cos_angle, -1.0, 1.0))
    features.append(angle)

    # 8-9) mouth center to each eye
    mouth_center = ((key_points['mouth_left'][0] + key_points['mouth_right'][0]) / 2,
                    (key_points['mouth_left'][1] + key_points['mouth_right'][1]) / 2)
    features.extend([dist(mouth_center, key_points['left_eye']),
                     dist(mouth_center, key_points['right_eye'])])

    # 10) mouth aspect ratio: width relative to nose->corner distances
    mouth_aspect_ratio = mouth_width / (nose_to_mouth_left + nose_to_mouth_right + 1e-8)
    features.append(mouth_aspect_ratio)

    # 11-12) eyebrow raise: eyebrow-to-eye distance per side
    features.extend([dist(key_points['left_eyebrow'], key_points['left_eye']),
                     dist(key_points['right_eyebrow'], key_points['right_eye'])])

    # 13) AU6-like cheek-raiser proxy: lower eyelid to cheek, averaged L/R
    left_au6 = dist(key_points['left_lower_eyelid'], key_points['left_cheek'])
    right_au6 = dist(key_points['right_lower_eyelid'], key_points['right_cheek'])
    features.append((left_au6 + right_au6) / 2)

    # 14) AU12-like lip-corner-puller proxy: mouth corners to chin,
    # normalized by mouth width
    mouth_left_to_chin = dist(key_points['mouth_left'], key_points['chin'])
    mouth_right_to_chin = dist(key_points['mouth_right'], key_points['chin'])
    features.append((mouth_left_to_chin + mouth_right_to_chin) / (2 * (mouth_width + 1e-8)))

    return np.array(features, dtype=np.float32)
153
+
154
# Function to get landmark mask
def get_landmark_mask(image, target_size=(7, 7)):
    """Build a soft spatial-attention mask highlighting facial key regions.

    Filled circles of value 1.0 are painted around selected eye / mouth /
    eyebrow / jaw / cheek landmarks, then the map is resized to
    *target_size* and clipped to [0, 1]. Returns an all-ones mask (no
    masking) when no face is detected.
    """
    frame = np.array(image)
    gray = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY)
    height, width = frame.shape[:2]

    detections = detector(gray)
    if len(detections) == 0:
        # No face: neutral mask that leaves attention untouched.
        return np.ones(target_size, dtype=np.float32)

    shape = predictor(gray, detections[0])
    points = [(shape.part(idx).x, shape.part(idx).y) for idx in range(68)]

    mask = np.zeros((height, width), dtype=np.float32)

    # iBUG-68 indices grouped by region: eyes, mouth, eyebrows, jaw/chin, cheeks.
    selected = [36, 39, 42, 45,   # eye corners
                48, 54, 51, 57,   # mouth corners + lip midpoints
                19, 24,           # eyebrows
                5, 11, 8,         # jaw sides + chin
                2, 14]            # cheeks
    for pos, (cx, cy) in enumerate([points[idx] for idx in selected]):
        # Wider circles for the mouth points plus the chin and left-cheek entries.
        radius = 30 if pos in (4, 5, 6, 7, 12, 13) else 20
        cv2.circle(mask, (cx, cy), radius, 1.0, -1)

    mask = cv2.resize(mask, target_size, interpolation=cv2.INTER_LINEAR)
    return np.clip(mask, 0, 1)
184
+
185
# Model definitions
class EfficientNetBackbone(nn.Module):
    """EfficientNet-B4 feature extractor reduced to 256 channels.

    extract_features() yields a 1792-channel map; a 1x1 conv followed by
    BatchNorm brings it down to the 256 channels the downstream streams expect.
    """

    def __init__(self):
        super(EfficientNetBackbone, self).__init__()
        self.efficientnet = EfficientNet.from_pretrained('efficientnet-b4')
        # Re-declared stem conv (same geometry as B4's default stem).
        self.efficientnet._conv_stem = nn.Conv2d(3, 48, kernel_size=3, stride=2, padding=1, bias=False)
        self.channel_reducer = nn.Conv2d(1792, 256, kernel_size=1, stride=1, padding=0, bias=False)
        self.bn = nn.BatchNorm2d(256)
        nn.init.xavier_uniform_(self.channel_reducer.weight)

    def forward(self, x):
        feature_map = self.efficientnet.extract_features(x)
        reduced = self.channel_reducer(feature_map)
        return self.bn(reduced)
200
+
201
class HLA(nn.Module):
    """Hybrid Local Attention: spatial then channel attention over a 7x7 map.

    Spatial attention is the sigmoid of the element-wise max of two 1x1-conv
    branches (restored to full channel width), optionally modulated by an
    external facial landmark mask; channel attention is a
    squeeze-and-excitation style gate computed on the spatially attended
    features.
    """

    def __init__(self, in_channels=256, reduction=4):
        super(HLA, self).__init__()
        reduced_channels = in_channels // reduction
        self.spatial_branch1 = nn.Conv2d(in_channels, reduced_channels, 1)
        self.spatial_branch2 = nn.Conv2d(in_channels, reduced_channels, 1)
        self.sigmoid = nn.Sigmoid()
        self.channel_restore = nn.Conv2d(reduced_channels, in_channels, 1)
        # Squeeze-and-excitation style channel gate.
        self.channel_attention = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(in_channels, in_channels // reduction, 1, bias=False),
            nn.ReLU(),
            nn.Conv2d(in_channels // reduction, in_channels, 1, bias=False),
            nn.Sigmoid()
        )
        self.bn = nn.BatchNorm2d(in_channels, eps=1e-5)
        self.dropout = nn.Dropout2d(0.2)

    def forward(self, x, landmark_mask=None):
        """Apply landmark-guided spatial then channel attention to x.

        Args:
            x: (B, C, 7, 7) feature map.
            landmark_mask: optional (B, 7, 7) array or tensor of weights in [0, 1].
        """
        b1 = self.spatial_branch1(x)
        b2 = self.spatial_branch2(x)
        spatial_attn = self.sigmoid(torch.max(b1, b2))
        spatial_attn = self.channel_restore(spatial_attn)

        if landmark_mask is not None:
            # as_tensor: no copy (and no UserWarning) when the mask is already
            # a tensor, and — unlike the previous torch.tensor(...) — the mask
            # follows x's device, so this also works when the model is on GPU.
            landmark_mask = torch.as_tensor(landmark_mask, dtype=x.dtype, device=x.device)
            spatial_attn = spatial_attn * landmark_mask.view(-1, 1, 7, 7)

        spatial_attn = self.dropout(spatial_attn)
        spatial_out = x * spatial_attn
        channel_attn = self.channel_attention(spatial_out)
        channel_attn = self.dropout(channel_attn)
        out = spatial_out * channel_attn
        out = self.bn(out)
        return out
237
+
238
+ class ViT(nn.Module):
239
+ def __init__(self, in_channels=256, patch_size=1, embed_dim=768, num_layers=8, num_heads=12):
240
+ super(ViT, self).__init__()
241
+ self.patch_embed = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)
242
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
243
+ num_patches = (7 // patch_size) * (7 // patch_size)
244
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
245
+ self.transformer = nn.ModuleList([
246
+ nn.TransformerEncoderLayer(embed_dim, num_heads, dim_feedforward=1536, activation="gelu")
247
+ for _ in range(num_layers)
248
+ ])
249
+ self.ln = nn.LayerNorm(embed_dim)
250
+ self.bn = nn.BatchNorm1d(embed_dim, eps=1e-5)
251
+ nn.init.xavier_uniform_(self.patch_embed.weight)
252
+ nn.init.zeros_(self.patch_embed.bias)
253
+ nn.init.normal_(self.cls_token, std=0.02)
254
+ nn.init.normal_(self.pos_embed, std=0.02)
255
+
256
+ def forward(self, x):
257
+ x = self.patch_embed(x)
258
+ x = x.flatten(2).transpose(1, 2)
259
+ cls_tokens = self.cls_token.expand(x.size(0), -1, -1)
260
+ x = torch.cat([cls_tokens, x], dim=1)
261
+ x = x + self.pos_embed
262
+ for layer in self.transformer:
263
+ x = layer(x)
264
+ x = x[:, 0]
265
+ x = self.ln(x)
266
+ x = self.bn(x)
267
+ return x
268
+
269
class IntensityStream(nn.Module):
    """Texture/intensity stream driven by fixed depthwise Sobel gradients.

    Depthwise Sobel filters produce a gradient-magnitude map; a conv + BN
    summarizes it into a pooled 128-d texture vector, and single-head
    self-attention over the spatial positions yields a 128-d context
    vector. forward() returns their concatenation plus the raw gradient
    map and a per-position channel variance for the caller.
    """

    def __init__(self, in_channels=256):
        super(IntensityStream, self).__init__()
        # Fixed Sobel kernels, replicated once per channel (groups=in_channels).
        sobel_x = torch.tensor([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]], dtype=torch.float32)
        sobel_y = torch.tensor([[-1, -2, -1], [0, 0, 0], [1, 2, 1]], dtype=torch.float32)
        self.sobel_x = nn.Conv2d(in_channels, in_channels, 3, padding=1, bias=False, groups=in_channels)
        self.sobel_y = nn.Conv2d(in_channels, in_channels, 3, padding=1, bias=False, groups=in_channels)
        self.sobel_x.weight.data = sobel_x.repeat(in_channels, 1, 1, 1)
        self.sobel_y.weight.data = sobel_y.repeat(in_channels, 1, 1, 1)
        self.conv = nn.Conv2d(in_channels, 128, 3, padding=1)
        self.bn = nn.BatchNorm2d(128, eps=1e-5)
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.attention = nn.MultiheadAttention(embed_dim=128, num_heads=1)
        nn.init.xavier_uniform_(self.conv.weight)
        nn.init.zeros_(self.conv.bias)

    def forward(self, x):
        # Gradient magnitude from the two fixed Sobel responses.
        grad_magnitude = torch.sqrt(self.sobel_x(x) ** 2 + self.sobel_y(x) ** 2 + 1e-8)
        # Per-position variance across channels, flattened to (B, H*W).
        variance = ((x - x.mean(dim=1, keepdim=True)) ** 2).mean(dim=1).flatten(1)

        cnn_out = self.bn(F.relu(self.conv(grad_magnitude)))
        texture_out = self.pool(cnn_out).squeeze(-1).squeeze(-1)

        # Self-attention over the spatial positions (seq-first layout),
        # with the inputs L2-normalized along the feature dimension.
        attn_in = cnn_out.flatten(2).permute(2, 0, 1)
        attn_in = attn_in / (attn_in.norm(dim=-1, keepdim=True) + 1e-8)
        attn_out, _ = self.attention(attn_in, attn_in, attn_in)
        context_out = attn_out.mean(dim=0)

        return torch.cat([texture_out, context_out], dim=1), grad_magnitude, variance
299
+
300
class LandmarkStream(nn.Module):
    """Three-layer MLP embedding the 14-d landmark features into embed_dim."""

    def __init__(self, input_dim=14, embed_dim=768):
        super(LandmarkStream, self).__init__()
        self.fc1 = nn.Linear(input_dim, 128)
        self.fc2 = nn.Linear(128, 256)
        self.fc3 = nn.Linear(256, embed_dim)
        self.bn1 = nn.BatchNorm1d(128)
        self.bn2 = nn.BatchNorm1d(256)
        self.bn3 = nn.BatchNorm1d(embed_dim)
        self.dropout = nn.Dropout(0.4)
        # Xavier init for all three projections (order kept fc1 -> fc2 -> fc3
        # so seeded initialization matches the original).
        for fc in (self.fc1, self.fc2, self.fc3):
            nn.init.xavier_uniform_(fc.weight)
            nn.init.zeros_(fc.bias)

    def forward(self, x):
        hidden = self.dropout(F.relu(self.bn1(self.fc1(x))))
        hidden = self.dropout(F.relu(self.bn2(self.fc2(hidden))))
        # Final layer: BN only, no activation or dropout.
        return self.bn3(self.fc3(hidden))
324
+
325
class QuadStreamHLAViT(nn.Module):
    """Four-stream emotion classifier fusing HLA, ViT, intensity and landmark cues.

    Three streams (HLA attention, ViT head, intensity/texture) consume the
    shared EfficientNet feature map; the fourth embeds the geometric
    landmark vector. Each stream is projected to 768-d, concatenated, and
    classified into num_classes emotions.
    """

    def __init__(self, num_classes=7):
        super(QuadStreamHLAViT, self).__init__()
        self.backbone = EfficientNetBackbone()
        self.hla = HLA()
        self.vit = ViT()
        self.intensity = IntensityStream()
        self.landmark = LandmarkStream(input_dim=14, embed_dim=768)
        # Per-stream projections into the shared 768-d fusion space.
        self.fc_hla = nn.Linear(256 * 7 * 7, 768)
        self.fc_intensity = nn.Linear(256, 768)
        self.fusion_fc = nn.Linear(768 * 4, 512)
        self.bn_fusion = nn.BatchNorm1d(512, eps=1e-5)
        self.dropout = nn.Dropout(0.6)
        self.classifier = nn.Linear(512, num_classes)
        # Xavier init for every linear added here; order kept
        # fc_hla -> fc_intensity -> fusion_fc -> classifier so seeded
        # initialization matches the original.
        for linear in (self.fc_hla, self.fc_intensity, self.fusion_fc, self.classifier):
            nn.init.xavier_uniform_(linear.weight)
            nn.init.zeros_(linear.bias)

    def forward(self, x, landmark_features, landmark_mask=None):
        shared = self.backbone(x)

        hla_out = self.hla(shared, landmark_mask)
        vit_out = self.vit(shared)
        intensity_out, grad_magnitude, variance = self.intensity(shared)
        landmark_out = self.landmark(landmark_features)

        fused = torch.cat([
            self.fc_hla(hla_out.view(-1, 256 * 7 * 7)),
            vit_out,
            self.fc_intensity(intensity_out),
            landmark_out,
        ], dim=1)
        fused = self.dropout(self.bn_fusion(F.relu(self.fusion_fc(fused))))
        logits = self.classifier(fused)
        # Intermediate outputs are returned for downstream inspection/losses.
        return logits, hla_out, vit_out, grad_magnitude, variance
362
+
363
# Load model
model = QuadStreamHLAViT(num_classes=7)
if os.path.exists(MODEL_WEIGHTS_PATH):
    try:
        # weights_only=True restricts torch.load to tensor data (no arbitrary
        # pickled code); CPU map_location so the app runs without a GPU.
        model.load_state_dict(torch.load(MODEL_WEIGHTS_PATH, map_location=torch.device('cpu'), weights_only=True))
        print("Model weights loaded successfully.")
    except Exception as e:
        # Keep the app up (with randomly initialized weights) rather than
        # crashing at import time.
        print(f"Error loading model weights: {e}")
else:
    print(f"Model weights not found at {MODEL_WEIGHTS_PATH}. Please upload the weights.")
model.eval()
374
+
375
# Inference function
def predict_emotion(image):
    """Predict the facial emotion shown in *image*.

    Returns a (label, probabilities) pair: the predicted emotion name and a
    dict mapping each emotion to its softmax probability formatted to four
    decimals. On any failure returns ("Error", {"Message": ...}) instead of
    raising, so the UI stays responsive.
    """
    try:
        # Normalize the input to an RGB PIL image (gradio may hand us an array).
        if isinstance(image, np.ndarray):
            image = Image.fromarray(image)
        image = image.convert("RGB")

        # Geometric landmark features plus the 7x7 attention mask.
        lm_features = extract_landmark_features(image)
        lm_mask = get_landmark_mask(image)

        # Preprocess to a (1, 3, 224, 224) normalized tensor; landmark
        # features become a (1, 14) float tensor.
        img_tensor = transform(image).unsqueeze(0)
        lm_features_tensor = torch.tensor(lm_features, dtype=torch.float32).unsqueeze(0)

        with torch.no_grad():
            logits, _, _, _, _ = model(img_tensor, lm_features_tensor, lm_mask)
        probabilities = F.softmax(logits, dim=1)[0]
        predicted = class_mapping[torch.argmax(probabilities).item()]

        # Per-class probabilities as fixed-precision strings for the JSON view.
        prob_dict = {class_mapping[i]: f"{probabilities[i].item():.4f}"
                     for i in range(len(class_mapping))}

        return predicted, prob_dict
    except Exception as e:
        return "Error", {"Message": f"Failed to process image: {str(e)}"}
404
 
405
# Gradio interface
# Single-function app: image in, predicted label + per-class probability
# JSON out.
iface = gr.Interface(
    fn=predict_emotion,
    inputs=gr.Image(type="pil", label="Upload an Image"),
    outputs=[
        gr.Textbox(label="Predicted Emotion"),
        gr.JSON(label="Emotion Probabilities")
    ],
    title="Facial Emotion Recognition with QuadStreamHLAViT",
    description="Upload an image to predict facial emotions (Surprise, Fear, Disgust, Happiness, Sadness, Anger, Neutral) using a QuadStreamHLAViT model trained on RAF-DB. Model accuracy: 82.31%.",
    # NOTE(review): allow_flagging is deprecated in newer Gradio releases
    # (renamed to flagging_mode) — confirm against the pinned gradio version.
    allow_flagging="never"
)

# Launch the app
if __name__ == "__main__":
    iface.launch()