import torch
import torch.nn as nn
from transformers import ViTModel, ViTConfig
from torchvision import transforms
import cv2
import numpy as np

class VisionEmotionModel(nn.Module):
    """
    Vision Transformer for facial emotion recognition.
    Fine-tuned on FER-2013/AffectNet datasets.
    """
    def __init__(self, num_emotions=7, pretrained=True):
        super().__init__()
        self.num_emotions = num_emotions

        # Load pre-trained ViT
        if pretrained:
            self.vit = ViTModel.from_pretrained('google/vit-base-patch16-224')
        else:
            config = ViTConfig()
            self.vit = ViTModel(config)

        # Freeze the ViT backbone so only the new heads are trained;
        # remove this loop (or unfreeze selectively) to fine-tune the backbone too
        for param in self.vit.parameters():
            param.requires_grad = False

        # Emotion classification head
        self.emotion_classifier = nn.Sequential(
            nn.Linear(self.vit.config.hidden_size, 512),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(512, num_emotions)
        )

        # Confidence estimation
        self.confidence_head = nn.Sequential(
            nn.Linear(self.vit.config.hidden_size, 256),
            nn.ReLU(),
            nn.Linear(256, 1),
            nn.Sigmoid()  # 0-1 confidence
        )

        # Image preprocessing
        self.transform = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

    def forward(self, x):
        """
        x: batch of images (B, C, H, W) or list of face crops
        Returns: emotion_logits, confidence
        """
        if isinstance(x, list):
            # Handle list of face images
            batch = torch.stack([self.transform(img) for img in x])
        else:
            batch = x

        outputs = self.vit(pixel_values=batch)
        cls_token = outputs.last_hidden_state[:, 0, :]  # [CLS] token

        emotion_logits = self.emotion_classifier(cls_token)
        confidence = self.confidence_head(cls_token)

        # Squeeze only the last dim so a batch of size 1 keeps its batch axis
        return emotion_logits, confidence.squeeze(-1)

    def detect_faces(self, frame):
        """
        Detect faces in a video frame using OpenCV.
        Returns list of face crops.
        """
        # Lazily load the Haar cascade once and cache it on the instance,
        # rather than reloading it from disk on every frame
        if not hasattr(self, '_face_cascade'):
            self._face_cascade = cv2.CascadeClassifier(
                cv2.data.haarcascades + 'haarcascade_frontalface_default.xml')
        gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
        faces = self._face_cascade.detectMultiScale(gray, scaleFactor=1.1, minNeighbors=4)

        face_crops = []
        for (x, y, w, h) in faces:
            face = frame[y:y+h, x:x+w]
            if face.size > 0:
                face_crops.append(face)

        return face_crops

    def extract_features(self, faces):
        """
        Extract emotion features from detected faces.
        """
        if not faces:
            return None, None

        with torch.no_grad():
            # Call the module (not .forward directly) so hooks still fire
            emotion_logits, confidence = self(faces)

        return emotion_logits, confidence
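
# A minimal smoke-test sketch, assuming randomly initialized weights
# (pretrained=False) so no checkpoint download is needed; real inference
# would use pretrained=True plus trained classifier heads.
if __name__ == "__main__":
    model = VisionEmotionModel(num_emotions=7, pretrained=False)
    model.eval()  # disable dropout for deterministic shapes
    dummy = torch.randn(2, 3, 224, 224)  # a batch of two preprocessed images
    with torch.no_grad():
        logits, conf = model(dummy)
    print(logits.shape)  # expected: torch.Size([2, 7])
    print(conf.shape)    # expected: torch.Size([2])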