File size: 13,453 Bytes
1e4fc28
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
# app/vit_utils.py
"""
Utilities for Vision Transformer (ViT) model preprocessing and prediction.
"""
import cv2
import numpy as np
from PIL import Image
from typing import Optional, Tuple, Dict, Any
from app.utils import preprocess_face  # Reuse face detection

def _run_cascade(gray_img, cascade_name: str, param_sets) -> list:
    """Run a single Haar cascade over *gray_img*.

    Tries each ``(scaleFactor, minNeighbors, minSize)`` tuple in *param_sets*
    in order and returns the first non-empty detection result.  Returns an
    empty list when the cascade file is missing/unloadable, detection raises,
    or no face is found.
    """
    try:
        cascade = cv2.CascadeClassifier(cv2.data.haarcascades + cascade_name)
        if cascade.empty():
            return []
        for scale_factor, min_neighbors, min_size in param_sets:
            faces = cascade.detectMultiScale(
                gray_img,
                scaleFactor=scale_factor,
                minNeighbors=min_neighbors,
                minSize=min_size,
                flags=cv2.CASCADE_SCALE_IMAGE,
            )
            if len(faces) > 0:
                return faces
    except Exception:
        pass
    return []


def preprocess_face_for_vit(
    image_path: str,
    detect_max_dim: int = 800,
    pad_ratio: float = 0.35,  # 0.35 keeps extra facial context (smiles need more context)
) -> Tuple[Optional[Image.Image], Optional[str]]:
    """
    Preprocess a face image for a Vision Transformer model.

    Unlike the grayscale 48x48 pipeline in ``app.utils``, ViT expects an RGB
    224x224 input, so the detected face is cropped from the RGB image and
    resized with bicubic interpolation.  Normalization is left to the ViT
    processor (no CLAHE here — it could shift the model's expected input
    distribution).

    Args:
        image_path: Path to the image file on disk.
        detect_max_dim: Images whose longest side exceeds this are downscaled
            before Haar detection for speed.
        pad_ratio: Fraction of the detected box width/height added as padding
            on each side before cropping.

    Returns:
        ``(PIL RGB image resized to 224x224, basename of image_path)`` or
        ``(None, None)`` when the file is unreadable or no face is found.
    """
    try:
        img = cv2.imread(image_path)
        if img is None:
            return None, None

        h0, w0 = img.shape[:2]
        # Keep an RGB copy for the final crop; detection runs on grayscale.
        img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        gray_full = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)

        # Downscale for faster detection if the image is huge.
        scale = 1.0
        max_side = max(w0, h0)
        if max_side > detect_max_dim:
            scale = detect_max_dim / float(max_side)
            small = cv2.resize(
                gray_full,
                (int(w0 * scale), int(h0 * scale)),
                interpolation=cv2.INTER_LINEAR,
            )
        else:
            small = gray_full.copy()

        # Enhance for detection.
        from app.utils import _enhance_for_detection
        small_enh = _enhance_for_detection(small)

        # Two parameter sets per cascade: the most commonly successful one
        # first, then a more permissive one for challenging cases.
        param_sets = [
            (1.05, 3, (20, 20)),
            (1.03, 2, (15, 15)),
        ]
        # Cascades ordered by reliability; later entries are fallbacks.
        # 3 cascades x 2 param sets = at most 6 attempts, fast path is 2.
        cascade_names = [
            "haarcascade_frontalface_default.xml",  # Most reliable
            "haarcascade_frontalface_alt.xml",      # Good fallback
            "haarcascade_frontalface_alt2.xml",     # Last resort
        ]

        faces = []
        for cascade_name in cascade_names:
            faces = _run_cascade(small_enh, cascade_name, param_sets)
            if len(faces) > 0:
                break

        # Fallback 1: retry on the non-enhanced image, single attempt only
        # (don't waste time on multiple parameter sets here).
        if len(faces) == 0:
            faces = _run_cascade(
                small, "haarcascade_frontalface_default.xml", [(1.05, 3, (20, 20))]
            )

        # Fallback 2: retry on the full-resolution image, but only when the
        # detection image was shrunk to less than half size — otherwise the
        # slow full-size pass gains little.
        detected_on_fullres = False
        if len(faces) == 0 and max_side > detect_max_dim and scale < 0.5:
            faces = _run_cascade(
                gray_full,
                "haarcascade_frontalface_default.xml",
                [(1.05, 2, (30, 30))],  # larger minSize for full resolution
            )
            # BUGFIX: remember these coordinates are already in
            # full-resolution space.  The original code rescaled them by
            # 1/scale anyway, producing a wrong (blown-up) crop box.
            detected_on_fullres = len(faces) > 0

        if len(faces) == 0:
            return None, None

        # Choose the largest detected face.
        faces = sorted(faces, key=lambda r: r[2] * r[3], reverse=True)
        (x_s, y_s, w_s, h_s) = faces[0]

        # Map coordinates back to the original resolution only when the
        # successful detection ran on the downscaled image.
        if scale < 1.0 and not detected_on_fullres:
            x = int(x_s / scale)
            y = int(y_s / scale)
            w = int(w_s / scale)
            h = int(h_s / scale)
        else:
            x, y, w, h = int(x_s), int(y_s), int(w_s), int(h_s)

        # Pad the bounding box, clamped to the image borders.
        pad_w = int(w * pad_ratio)
        pad_h = int(h * pad_ratio)
        x1 = max(0, x - pad_w)
        y1 = max(0, y - pad_h)
        x2 = min(w0, x + w + pad_w)
        y2 = min(h0, y + h + pad_h)

        # Crop from the RGB image (ViT wants color input, not grayscale).
        face_crop = img_rgb[y1:y2, x1:x2]

        # BICUBIC preserves the most detail for the 224x224 ViT input.
        face_pil = Image.fromarray(face_crop)
        face_pil = face_pil.resize((224, 224), Image.Resampling.BICUBIC)

        import os
        used_filename = os.path.basename(image_path) or "upload.jpg"
        return face_pil, used_filename

    except Exception as e:
        import logging
        logger = logging.getLogger(__name__)
        logger.exception(f"Exception in preprocess_face_for_vit for {image_path}: {e}")
        return None, None

def predict_with_vit(
    model_dict: Dict[str, Any],
    image: Image.Image,
    labels: list
) -> Tuple[int, float, Dict[str, float]]:
    """
    Run emotion prediction using a Vision Transformer model.

    Args:
        model_dict: ``{'model': model, 'processor': processor, 'type': 'vit'}``
            where ``processor`` turns a PIL image into model inputs and
            ``model`` is an image-classification model with ``.logits`` output.
        image: PIL Image (expected 224x224; converted to RGB if needed).
        labels: Ordered emotion labels, used as a fallback label mapping and
            to re-derive the predicted index after post-processing.

    Returns:
        ``(predicted_index, confidence, {label: probability})``.

    Notes:
        A post-processing step boosts the 'happy' probability when the model
        favors 'contempt'/'neutral' but 'happy' is in the top 3 — this model
        has a known happy/contempt confusion.
    """
    import logging
    import torch
    import torch.nn.functional as F

    logger = logging.getLogger(__name__)

    processor = model_dict['processor']
    model = model_dict['model']

    # Inputs may be RGBA or grayscale; ViT expects 3-channel RGB.
    if image.mode != 'RGB':
        image = image.convert('RGB')

    # The processor handles resizing/normalization.
    inputs = processor(image, return_tensors="pt")

    model.eval()
    # inference_mode() is faster than no_grad() for pure inference.
    with torch.inference_mode():
        outputs = model(**inputs)
        logits = outputs.logits

    # Softmax probabilities; no detach needed under inference_mode.
    probs_np = F.softmax(logits, dim=-1)[0].cpu().numpy()
    predicted_idx = int(torch.argmax(logits, dim=-1).item())
    confidence = float(probs_np[predicted_idx])

    # Canonical spelling for raw model label names.  Hoisted out of the loop
    # below — the original rebuilt this dict once per class.
    label_map = {
        'anger': 'angry',
        'disgust': 'disgust',
        'fear': 'fear',
        'happy': 'happy',
        'neutral': 'neutral',
        'sad': 'sad',
        'surprise': 'surprise',
        'contempt': 'contempt',
    }
    # Prefer the model's own id2label so index->label mapping is accurate.
    has_id2label = hasattr(model, 'config') and hasattr(model.config, 'id2label')

    all_probs: Dict[str, float] = {}
    for i, prob in enumerate(probs_np):
        if has_id2label:
            raw_label = model.config.id2label.get(i, f"class_{i}")
            all_probs[label_map.get(raw_label.lower(), raw_label.lower())] = float(prob)
        elif i < len(labels):
            all_probs[labels[i]] = float(prob)
        else:
            all_probs[f"class_{i}"] = float(prob)

    # Post-processing: counter the model's known happy/contempt confusion.
    happy_prob = all_probs.get('happy', 0.0)
    contempt_prob = all_probs.get('contempt', 0.0)
    neutral_prob = all_probs.get('neutral', 0.0)

    sorted_probs = sorted(all_probs.items(), key=lambda kv: kv[1], reverse=True)
    top_3_emotions = [name for name, _ in sorted_probs[:3]]

    # Boost only when happy is plausible (>0.05, in the top 3) but the model
    # leans heavily toward contempt or neutral.
    if ('happy' in top_3_emotions and happy_prob > 0.05
            and (contempt_prob > 0.4 or neutral_prob > 0.4)):
        # Boost happy by 30%, then take the difference back out of
        # contempt/neutral to roughly preserve total probability mass.
        boosted_happy = min(1.0, happy_prob * 1.3)
        reduction = (boosted_happy - happy_prob) / 2
        all_probs['happy'] = boosted_happy
        all_probs['contempt'] = max(0.0, contempt_prob - reduction)
        all_probs['neutral'] = max(0.0, neutral_prob - reduction)

        # Re-normalize so the probabilities sum to ~1.0 again.
        total = sum(all_probs.values())
        if total > 0:
            all_probs = {k: v / total for k, v in all_probs.items()}

        # Re-derive the predicted class from the adjusted probabilities.
        new_top_emotion = max(all_probs.items(), key=lambda kv: kv[1])[0]
        if new_top_emotion in labels:
            # NOTE(review): this index is in `labels` space while the
            # un-boosted path returns the model's own class index — callers
            # should rely on all_probs rather than the raw index; confirm.
            predicted_idx = labels.index(new_top_emotion)
            confidence = all_probs[new_top_emotion]
            logger.info(
                "[VIT] Post-processing: Boosted happy from %.3f to %.3f, new prediction: %s",
                happy_prob, all_probs.get('happy', 0.0), new_top_emotion,
            )
        else:
            logger.warning(
                "[VIT] Post-processing: Boosted happy but couldn't find label %s in labels list",
                new_top_emotion,
            )

    # BUGFIX: guard the lookup — the original accessed model.config.id2label
    # unconditionally here and crashed for models without it, even though the
    # mapping loop above explicitly handles that case.
    raw_top = model.config.id2label.get(predicted_idx, 'unknown') if has_id2label else 'unknown'
    logger.info("[VIT] Predicted index: %s, Raw label from model: %s", predicted_idx, raw_top)

    return predicted_idx, confidence, all_probs