Spaces:
Build error
Build error
SakibRumu
committed on
Update app.py
Browse files
app.py
CHANGED
|
@@ -1,79 +1,420 @@
|
|
|
|
|
| 1 |
import torch
|
| 2 |
import torch.nn as nn
|
| 3 |
-
import
|
| 4 |
-
|
|
|
|
| 5 |
from PIL import Image
|
| 6 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 7 |
|
| 8 |
-
# Define
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
self.cnn = models.resnet50(pretrained=True)
|
| 13 |
-
self.cnn = nn.Sequential(*list(self.cnn.children())[:-2])
|
| 14 |
-
self.channel_reduction = nn.Conv2d(2048, 64, kernel_size=1)
|
| 15 |
-
self.to_rgb = nn.Conv2d(64, 3, kernel_size=1)
|
| 16 |
-
self.transformer = ViTModel.from_pretrained("google/vit-base-patch16-224")
|
| 17 |
-
self.fc = nn.Sequential(
|
| 18 |
-
nn.Linear(768, 512),
|
| 19 |
-
nn.ReLU(),
|
| 20 |
-
nn.Dropout(0.3),
|
| 21 |
-
nn.Linear(512, num_classes)
|
| 22 |
-
)
|
| 23 |
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 31 |
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 36 |
|
| 37 |
-
# Transform
|
| 38 |
transform = transforms.Compose([
|
| 39 |
transforms.Resize((224, 224)),
|
| 40 |
transforms.ToTensor(),
|
| 41 |
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
|
| 42 |
])
|
| 43 |
|
| 44 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 45 |
def predict_emotion(image):
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
|
|
|
|
|
|
| 71 |
|
| 72 |
-
|
|
|
|
| 73 |
fn=predict_emotion,
|
| 74 |
-
inputs=gr.Image(type="pil"),
|
| 75 |
-
outputs=[
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gradio as gr
|
| 2 |
import torch
|
| 3 |
import torch.nn as nn
|
| 4 |
+
import torch.nn.functional as F
|
| 5 |
+
import torchvision.transforms as transforms
|
| 6 |
+
import numpy as np
|
| 7 |
from PIL import Image
|
| 8 |
+
import cv2
|
| 9 |
+
import dlib
|
| 10 |
+
import os
|
| 11 |
+
import requests
|
| 12 |
+
import bz2
|
| 13 |
+
import shutil
|
| 14 |
+
from efficientnet_pytorch import EfficientNet
|
| 15 |
|
| 16 |
+
# Define paths
# dlib's 68-point facial landmark model, downloaded at startup if missing.
SHAPE_PREDICTOR_URL = "http://dlib.net/files/shape_predictor_68_face_landmarks.dat.bz2"
SHAPE_PREDICTOR_PATH = "shape_predictor_68_face_landmarks.dat"
# Trained QuadStreamHLAViT checkpoint expected alongside the app.
MODEL_WEIGHTS_PATH = "quad_stream_model_rafdb.pth"  # Update if weights are in a different path
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 20 |
|
| 21 |
+
# Download and extract shape predictor if not present
def download_shape_predictor():
    """Fetch and decompress the dlib 68-point landmark model (idempotent).

    Streams SHAPE_PREDICTOR_URL to a temporary ``.bz2`` archive, checks the
    HTTP status, decompresses it to SHAPE_PREDICTOR_PATH, and removes the
    archive. Does nothing if the ``.dat`` file already exists.
    """
    if os.path.exists(SHAPE_PREDICTOR_PATH):
        print("Shape predictor already exists.")
        return

    archive_path = SHAPE_PREDICTOR_PATH + ".bz2"
    print("Downloading shape predictor...")
    # FIX: the original passed stream=True but then read response.content,
    # buffering the whole archive in memory and never checking the status.
    response = requests.get(SHAPE_PREDICTOR_URL, stream=True, timeout=60)
    response.raise_for_status()
    with open(archive_path, "wb") as f:
        for chunk in response.iter_content(chunk_size=1 << 20):
            f.write(chunk)
    print("Extracting shape predictor...")
    with bz2.BZ2File(archive_path, "rb") as f_in:
        with open(SHAPE_PREDICTOR_PATH, "wb") as f_out:
            shutil.copyfileobj(f_in, f_out)
    os.remove(archive_path)
    print("Shape predictor ready.")

download_shape_predictor()
|
| 38 |
+
|
| 39 |
+
# Initialize Dlib detector and predictor
detector = dlib.get_frontal_face_detector()
# 68-point landmark predictor loaded from the file fetched above.
predictor = dlib.shape_predictor(SHAPE_PREDICTOR_PATH)
|
| 42 |
+
|
| 43 |
+
# Class mapping for RAF-DB
# NOTE(review): assumed to match the label order used when training the
# checkpoint on RAF-DB — verify against the training script.
class_mapping = {
    0: "Surprise",
    1: "Fear",
    2: "Disgust",
    3: "Happiness",
    4: "Sadness",
    5: "Anger",
    6: "Neutral"
}
|
| 53 |
|
| 54 |
+
# Transform for input images
# Resize to the 224x224 input the backbone expects, then normalize with the
# standard ImageNet channel mean/std.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
|
| 60 |
|
| 61 |
+
# Function to extract landmark features
def extract_landmark_features(image):
    """Return a 14-d geometric feature vector for the first detected face.

    Features: eye distance, mouth width, nose-to-mouth-corner distances,
    eye-to-nose distances, eye-vector angle at the nose tip, mouth-center to
    eye distances, mouth aspect ratio, eyebrow-to-eye distances, and two
    action-unit proxies (AU6 cheek raiser, AU12 lip corner puller).

    Parameters: image -- PIL Image (assumed RGB -- TODO confirm caller always
    converts). Returns np.float32 array of shape (14,); all zeros when no
    face is detected.
    """
    def dist(p, q):
        # Euclidean distance between two (x, y) points.
        return float(np.hypot(p[0] - q[0], p[1] - q[1]))

    image_np = np.array(image)
    gray = cv2.cvtColor(image_np, cv2.COLOR_RGB2GRAY)
    # FIX: removed unused h, w locals present in the original.

    faces = detector(gray)
    if len(faces) == 0:
        # No face: neutral all-zero feature vector.
        return np.zeros(14, dtype=np.float32)

    shape = predictor(gray, faces[0])
    landmarks = [(shape.part(i).x, shape.part(i).y) for i in range(68)]

    # Named subset of the dlib 68-point layout.
    key_points = {
        'left_eye': landmarks[36],
        'right_eye': landmarks[45],
        'nose_tip': landmarks[30],
        'mouth_left': landmarks[48],
        'mouth_right': landmarks[54],
        'left_eyebrow': landmarks[19],
        'right_eyebrow': landmarks[24],
        'jaw_left': landmarks[5],
        'jaw_right': landmarks[11],
        'chin': landmarks[8],
        'left_lower_eyelid': landmarks[41],
        'right_lower_eyelid': landmarks[46],
        'left_cheek': landmarks[2],
        'right_cheek': landmarks[14]
    }

    features = []

    eye_dist = dist(key_points['left_eye'], key_points['right_eye'])
    features.append(eye_dist)

    mouth_width = dist(key_points['mouth_left'], key_points['mouth_right'])
    features.append(mouth_width)

    nose_to_mouth_left = dist(key_points['nose_tip'], key_points['mouth_left'])
    nose_to_mouth_right = dist(key_points['nose_tip'], key_points['mouth_right'])
    features.extend([nose_to_mouth_left, nose_to_mouth_right])

    left_eye_to_nose = dist(key_points['left_eye'], key_points['nose_tip'])
    right_eye_to_nose = dist(key_points['right_eye'], key_points['nose_tip'])
    features.extend([left_eye_to_nose, right_eye_to_nose])

    # Angle at the nose tip between the two eye vectors; epsilon guards the
    # division and clip keeps arccos in its domain under float error.
    vec1 = np.array([key_points['left_eye'][0] - key_points['nose_tip'][0],
                     key_points['left_eye'][1] - key_points['nose_tip'][1]])
    vec2 = np.array([key_points['right_eye'][0] - key_points['nose_tip'][0],
                     key_points['right_eye'][1] - key_points['nose_tip'][1]])
    cos_angle = np.dot(vec1, vec2) / (np.linalg.norm(vec1) * np.linalg.norm(vec2) + 1e-8)
    angle = np.arccos(np.clip(cos_angle, -1.0, 1.0))
    features.append(angle)

    mouth_center = ((key_points['mouth_left'][0] + key_points['mouth_right'][0]) / 2,
                    (key_points['mouth_left'][1] + key_points['mouth_right'][1]) / 2)
    features.extend([dist(mouth_center, key_points['left_eye']),
                     dist(mouth_center, key_points['right_eye'])])

    # Mouth width relative to the nose-to-mouth span.
    mouth_aspect_ratio = mouth_width / (nose_to_mouth_left + nose_to_mouth_right + 1e-8)
    features.append(mouth_aspect_ratio)

    features.extend([dist(key_points['left_eyebrow'], key_points['left_eye']),
                     dist(key_points['right_eyebrow'], key_points['right_eye'])])

    # AU6 (cheek raiser) proxy: average eyelid-to-cheek distance.
    left_au6 = dist(key_points['left_lower_eyelid'], key_points['left_cheek'])
    right_au6 = dist(key_points['right_lower_eyelid'], key_points['right_cheek'])
    features.append((left_au6 + right_au6) / 2)

    # AU12 (lip corner puller) proxy: mouth-corner-to-chin distance over mouth width.
    mouth_left_to_chin = dist(key_points['mouth_left'], key_points['chin'])
    mouth_right_to_chin = dist(key_points['mouth_right'], key_points['chin'])
    features.append((mouth_left_to_chin + mouth_right_to_chin) / (2 * (mouth_width + 1e-8)))

    return np.array(features, dtype=np.float32)
|
| 153 |
+
|
| 154 |
+
# Function to get landmark mask
def get_landmark_mask(image, target_size=(7, 7)):
    """Build a soft spatial attention mask over facial landmarks.

    Draws filled circles around key eye/mouth/eyebrow/jaw/cheek landmarks on
    a full-resolution mask, then downsamples to ``target_size`` (the
    backbone's 7x7 feature grid). Returns an all-ones mask (attention left
    unmodified) when no face is detected.
    """
    image_np = np.array(image)
    gray = cv2.cvtColor(image_np, cv2.COLOR_RGB2GRAY)
    h, w = image_np.shape[:2]

    faces = detector(gray)
    if len(faces) == 0:
        # No face: neutral mask so HLA attention passes through unchanged.
        return np.ones(target_size, dtype=np.float32)

    face = faces[0]
    shape = predictor(gray, face)
    landmarks = [(shape.part(i).x, shape.part(i).y) for i in range(68)]

    mask = np.zeros((h, w), dtype=np.float32)

    # dlib 68-point indices for the regions of interest.
    eye_indices = [36, 39, 42, 45]
    mouth_indices = [48, 54, 51, 57]
    eyebrow_indices = [19, 24]
    jaw_indices = [5, 11, 8]
    cheek_indices = [2, 14]
    # Concatenation ORDER matters: the radius rule below indexes into it.
    key_points = [landmarks[i] for i in eye_indices + mouth_indices + eyebrow_indices + jaw_indices + cheek_indices]

    for i, (x, y) in enumerate(key_points):
        # Positions 4-7 are the mouth points and 12-13 the cheeks in the
        # concatenation above; they get a larger radius. Order-sensitive.
        radius = 30 if i in [4, 5, 6, 7, 12, 13] else 20
        cv2.circle(mask, (x, y), radius, 1.0, -1)

    # Downsample to the feature-map grid; interpolation can exceed 1, so clip.
    mask = cv2.resize(mask, target_size, interpolation=cv2.INTER_LINEAR)
    mask = np.clip(mask, 0, 1)
    return mask
|
| 184 |
+
|
| 185 |
+
# Model definitions
class EfficientNetBackbone(nn.Module):
    """EfficientNet-B4 feature extractor reduced to 256 channels.

    Produces the shared feature map consumed by all four streams.
    """

    def __init__(self):
        super(EfficientNetBackbone, self).__init__()
        self.efficientnet = EfficientNet.from_pretrained('efficientnet-b4')
        # NOTE(review): replacing _conv_stem AFTER from_pretrained re-creates
        # the stem with random weights, discarding its pretrained values --
        # confirm this matches how the checkpoint was trained.
        self.efficientnet._conv_stem = nn.Conv2d(3, 48, kernel_size=3, stride=2, padding=1, bias=False)
        # 1x1 conv squeezing B4's 1792 output channels down to 256.
        self.channel_reducer = nn.Conv2d(1792, 256, kernel_size=1, stride=1, padding=0, bias=False)
        self.bn = nn.BatchNorm2d(256)
        nn.init.xavier_uniform_(self.channel_reducer.weight)

    def forward(self, x):
        # extract_features returns the final conv feature map (no head).
        x = self.efficientnet.extract_features(x)
        x = self.channel_reducer(x)
        x = self.bn(x)
        return x
|
| 200 |
+
|
| 201 |
+
class HLA(nn.Module):
    """Hybrid Landmark Attention.

    Spatial attention (optionally gated by a facial-landmark mask) followed
    by squeeze-excite style channel attention. Input and output are both
    (B, in_channels, H, W); the landmark mask, when given, is expected to
    cover a 7x7 spatial grid.
    """

    def __init__(self, in_channels=256, reduction=4):
        super(HLA, self).__init__()
        reduced_channels = in_channels // reduction
        # Two parallel 1x1 projections; their element-wise max forms the gate.
        self.spatial_branch1 = nn.Conv2d(in_channels, reduced_channels, 1)
        self.spatial_branch2 = nn.Conv2d(in_channels, reduced_channels, 1)
        self.sigmoid = nn.Sigmoid()
        self.channel_restore = nn.Conv2d(reduced_channels, in_channels, 1)
        # Squeeze-and-excite channel attention.
        self.channel_attention = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(in_channels, in_channels // reduction, 1, bias=False),
            nn.ReLU(),
            nn.Conv2d(in_channels // reduction, in_channels, 1, bias=False),
            nn.Sigmoid()
        )
        self.bn = nn.BatchNorm2d(in_channels, eps=1e-5)
        self.dropout = nn.Dropout2d(0.2)

    def forward(self, x, landmark_mask=None):
        b1 = self.spatial_branch1(x)
        b2 = self.spatial_branch2(x)
        spatial_attn = self.sigmoid(torch.max(b1, b2))
        spatial_attn = self.channel_restore(spatial_attn)

        if landmark_mask is not None:
            # FIX: as_tensor with explicit dtype AND device. The original
            # torch.tensor(...) call left a NumPy mask on the CPU, which
            # crashes when x lives on a GPU, and needlessly copied tensors.
            landmark_mask = torch.as_tensor(landmark_mask, dtype=x.dtype, device=x.device)
            landmark_mask = landmark_mask.view(-1, 1, 7, 7)
            spatial_attn = spatial_attn * landmark_mask

        spatial_attn = self.dropout(spatial_attn)
        spatial_out = x * spatial_attn
        channel_attn = self.channel_attention(spatial_out)
        channel_attn = self.dropout(channel_attn)
        out = spatial_out * channel_attn
        out = self.bn(out)
        return out
|
| 237 |
+
|
| 238 |
+
class ViT(nn.Module):
    """Small Vision-Transformer head over the 7x7 backbone feature map.

    Each spatial cell becomes one token (1x1 patch embedding); the CLS token
    after the encoder stack is the (B, embed_dim) output.
    """

    def __init__(self, in_channels=256, patch_size=1, embed_dim=768, num_layers=8, num_heads=12):
        super(ViT, self).__init__()
        self.patch_embed = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        # Assumes a fixed 7x7 input feature map.
        num_patches = (7 // patch_size) * (7 // patch_size)
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
        self.transformer = nn.ModuleList([
            nn.TransformerEncoderLayer(embed_dim, num_heads, dim_feedforward=1536, activation="gelu")
            for _ in range(num_layers)
        ])
        self.ln = nn.LayerNorm(embed_dim)
        self.bn = nn.BatchNorm1d(embed_dim, eps=1e-5)
        nn.init.xavier_uniform_(self.patch_embed.weight)
        nn.init.zeros_(self.patch_embed.bias)
        nn.init.normal_(self.cls_token, std=0.02)
        nn.init.normal_(self.pos_embed, std=0.02)

    def forward(self, x):
        x = self.patch_embed(x)
        # (B, E, 7, 7) -> (B, 49, E) token sequence.
        x = x.flatten(2).transpose(1, 2)
        cls_tokens = self.cls_token.expand(x.size(0), -1, -1)
        x = torch.cat([cls_tokens, x], dim=1)
        x = x + self.pos_embed
        # NOTE(review): TransformerEncoderLayer defaults to batch_first=False
        # (seq, batch, dim) but receives (batch, seq, dim) here. The trained
        # checkpoint presumably matches this layout -- do NOT "fix" without
        # retraining; confirm against the training code.
        for layer in self.transformer:
            x = layer(x)
        # Keep only the CLS token as the sequence summary.
        x = x[:, 0]
        x = self.ln(x)
        x = self.bn(x)
        return x
|
| 268 |
+
|
| 269 |
+
class IntensityStream(nn.Module):
    """Texture/gradient stream: fixed Sobel filters, conv, and self-attention.

    forward returns (features (B, 256), gradient-magnitude map, per-position
    channel variance); only the first element is used for fusion downstream.
    """

    def __init__(self, in_channels=256):
        super(IntensityStream, self).__init__()
        # Depthwise Sobel kernels, initialized to the classic 3x3 operators.
        sobel_x = torch.tensor([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]], dtype=torch.float32)
        sobel_y = torch.tensor([[-1, -2, -1], [0, 0, 0], [1, 2, 1]], dtype=torch.float32)
        self.sobel_x = nn.Conv2d(in_channels, in_channels, 3, padding=1, bias=False, groups=in_channels)
        self.sobel_y = nn.Conv2d(in_channels, in_channels, 3, padding=1, bias=False, groups=in_channels)
        # repeat promotes (3, 3) -> (in_channels, 1, 3, 3), the depthwise weight shape.
        self.sobel_x.weight.data = sobel_x.repeat(in_channels, 1, 1, 1)
        self.sobel_y.weight.data = sobel_y.repeat(in_channels, 1, 1, 1)
        self.conv = nn.Conv2d(in_channels, 128, 3, padding=1)
        self.bn = nn.BatchNorm2d(128, eps=1e-5)
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.attention = nn.MultiheadAttention(embed_dim=128, num_heads=1)
        nn.init.xavier_uniform_(self.conv.weight)
        nn.init.zeros_(self.conv.bias)

    def forward(self, x):
        gx = self.sobel_x(x)
        gy = self.sobel_y(x)
        # Edge-strength map; +1e-8 keeps sqrt differentiable at zero.
        grad_magnitude = torch.sqrt(gx**2 + gy**2 + 1e-8)
        # Variance across channels at each position, flattened to (B, H*W).
        variance = ((x - x.mean(dim=1, keepdim=True))**2).mean(dim=1).flatten(1)
        cnn_out = F.relu(self.conv(grad_magnitude))
        cnn_out = self.bn(cnn_out)
        # Global-average-pooled texture descriptor (B, 128).
        texture_out = self.pool(cnn_out).squeeze(-1).squeeze(-1)
        # MultiheadAttention expects (seq, batch, embed); tokens are spatial cells.
        attn_in = cnn_out.flatten(2).permute(2, 0, 1)
        # L2-normalize each token before attention (stabilizes the dot products).
        attn_in = attn_in / (attn_in.norm(dim=-1, keepdim=True) + 1e-8)
        attn_out, _ = self.attention(attn_in, attn_in, attn_in)
        # Mean over the spatial-token axis -> (B, 128) context descriptor.
        context_out = attn_out.mean(dim=0)
        out = torch.cat([texture_out, context_out], dim=1)
        return out, grad_magnitude, variance
|
| 299 |
+
|
| 300 |
+
class LandmarkStream(nn.Module):
    """Embed a 14-d landmark-geometry vector into the 768-d fusion space.

    Three Linear layers, each followed by BatchNorm; ReLU activation and
    dropout after the first two. Output shape: (B, embed_dim).
    """

    def __init__(self, input_dim=14, embed_dim=768):
        super(LandmarkStream, self).__init__()
        self.fc1 = nn.Linear(input_dim, 128)
        self.fc2 = nn.Linear(128, 256)
        self.fc3 = nn.Linear(256, embed_dim)
        self.bn1 = nn.BatchNorm1d(128)
        self.bn2 = nn.BatchNorm1d(256)
        self.bn3 = nn.BatchNorm1d(embed_dim)
        self.dropout = nn.Dropout(0.4)
        # Xavier weights / zero biases, matching the other streams' init.
        for linear in (self.fc1, self.fc2, self.fc3):
            nn.init.xavier_uniform_(linear.weight)
            nn.init.zeros_(linear.bias)

    def forward(self, x):
        hidden = self.dropout(F.relu(self.bn1(self.fc1(x))))
        hidden = self.dropout(F.relu(self.bn2(self.fc2(hidden))))
        return self.bn3(self.fc3(hidden))
|
| 324 |
+
|
| 325 |
+
class QuadStreamHLAViT(nn.Module):
    """Four-stream emotion classifier over a shared EfficientNet feature map.

    Streams: HLA attention, ViT, intensity (texture/gradient), and landmark
    geometry. Each is projected to 768-d, concatenated, and classified.
    forward returns (logits, hla_out, vit_out, grad_magnitude, variance);
    inference only consumes the logits.
    """

    def __init__(self, num_classes=7):
        super(QuadStreamHLAViT, self).__init__()
        self.backbone = EfficientNetBackbone()
        self.hla = HLA()
        self.vit = ViT()
        self.intensity = IntensityStream()
        self.landmark = LandmarkStream(input_dim=14, embed_dim=768)
        # Per-stream projections into the shared 768-d fusion space
        # (ViT and landmark streams already emit 768-d).
        self.fc_hla = nn.Linear(256*7*7, 768)
        self.fc_intensity = nn.Linear(256, 768)
        self.fusion_fc = nn.Linear(768*4, 512)
        self.bn_fusion = nn.BatchNorm1d(512, eps=1e-5)
        self.dropout = nn.Dropout(0.6)
        self.classifier = nn.Linear(512, num_classes)
        nn.init.xavier_uniform_(self.fc_hla.weight)
        nn.init.zeros_(self.fc_hla.bias)
        nn.init.xavier_uniform_(self.fc_intensity.weight)
        nn.init.zeros_(self.fc_intensity.bias)
        nn.init.xavier_uniform_(self.fusion_fc.weight)
        nn.init.zeros_(self.fusion_fc.bias)
        nn.init.xavier_uniform_(self.classifier.weight)
        nn.init.zeros_(self.classifier.bias)

    def forward(self, x, landmark_features, landmark_mask=None):
        features = self.backbone(x)
        hla_out = self.hla(features, landmark_mask)
        vit_out = self.vit(features)
        intensity_out, grad_magnitude, variance = self.intensity(features)
        landmark_out = self.landmark(landmark_features)
        # Flatten the (B, 256, 7, 7) HLA map before projecting to 768-d.
        hla_flat = self.fc_hla(hla_out.view(-1, 256*7*7))
        intensity_flat = self.fc_intensity(intensity_out)
        fused = torch.cat([hla_flat, vit_out, intensity_flat, landmark_out], dim=1)
        fused = F.relu(self.fusion_fc(fused))
        fused = self.bn_fusion(fused)
        fused = self.dropout(fused)
        logits = self.classifier(fused)
        return logits, hla_out, vit_out, grad_magnitude, variance
|
| 362 |
+
|
| 363 |
+
# Load model
model = QuadStreamHLAViT(num_classes=7)
if os.path.exists(MODEL_WEIGHTS_PATH):
    try:
        # weights_only=True avoids unpickling arbitrary objects from the checkpoint.
        model.load_state_dict(torch.load(MODEL_WEIGHTS_PATH, map_location=torch.device('cpu'), weights_only=True))
        print("Model weights loaded successfully.")
    except Exception as e:
        # Best-effort: the app still starts (with random weights) so the
        # Space shows a useful message instead of crashing at import time.
        print(f"Error loading model weights: {e}")
else:
    print(f"Model weights not found at {MODEL_WEIGHTS_PATH}. Please upload the weights.")
model.eval()
|
| 374 |
+
|
| 375 |
+
# Inference function
def predict_emotion(image):
    """Predict the facial emotion shown in *image*.

    Returns a (label, probabilities) pair where probabilities maps each
    emotion name to a 4-decimal string; on any failure returns
    ("Error", {"Message": ...}) so the Gradio UI shows the problem.
    """
    try:
        # Normalize the input to an RGB PIL image (Gradio may hand us an array).
        if isinstance(image, np.ndarray):
            image = Image.fromarray(image)
        rgb_image = image.convert("RGB")

        # Landmark-derived inputs for the model's auxiliary streams.
        lm_features = extract_landmark_features(rgb_image)
        lm_mask = get_landmark_mask(rgb_image)

        # Preprocess image and pack tensors with a batch dimension.
        img_tensor = transform(rgb_image).unsqueeze(0)
        lm_features_tensor = torch.tensor(lm_features, dtype=torch.float32).unsqueeze(0)

        # Forward pass; only the logits are needed at inference time.
        with torch.no_grad():
            outputs, _, _, _, _ = model(img_tensor, lm_features_tensor, lm_mask)
        probs = F.softmax(outputs, dim=1)[0]
        pred_label = torch.argmax(probs).item()

        # Per-class probabilities formatted for the JSON output component.
        prob_dict = {
            class_mapping[idx]: f"{probs[idx].item():.4f}"
            for idx in range(len(class_mapping))
        }
        return class_mapping[pred_label], prob_dict
    except Exception as e:
        return "Error", {"Message": f"Failed to process image: {str(e)}"}
|
| 404 |
|
| 405 |
+
# Gradio interface
iface = gr.Interface(
    fn=predict_emotion,
    inputs=gr.Image(type="pil", label="Upload an Image"),
    outputs=[
        gr.Textbox(label="Predicted Emotion"),
        gr.JSON(label="Emotion Probabilities")
    ],
    title="Facial Emotion Recognition with QuadStreamHLAViT",
    description="Upload an image to predict facial emotions (Surprise, Fear, Disgust, Happiness, Sadness, Anger, Neutral) using a QuadStreamHLAViT model trained on RAF-DB. Model accuracy: 82.31%.",
    # NOTE(review): allow_flagging is deprecated in Gradio 4.x (flagging_mode
    # replaces it) -- confirm which gradio version the Space pins.
    allow_flagging="never"
)

# Launch the app
if __name__ == "__main__":
    iface.launch()
|