# emotion-detection-api / app/vit_utils.py
# Uploaded by HimAJ ("upload 32 files for the ml", commit 1e4fc28, verified)
# app/vit_utils.py
"""
Utilities for Vision Transformer (ViT) model preprocessing and prediction.
"""
import cv2
import numpy as np
from PIL import Image
from typing import Optional, Tuple, Dict, Any
from app.utils import preprocess_face # Reuse face detection
def preprocess_face_for_vit(
    image_path: str,
    detect_max_dim: int = 800,
    pad_ratio: float = 0.35,  # 0.35 adds extra facial context - helps happy detection (smiles need context)
) -> Tuple[Optional[Image.Image], Optional[str]]:
    """
    Preprocess a face for a Vision Transformer model.

    ViT needs RGB images at 224x224 (not grayscale 48x48), so the face is
    detected on a grayscale copy but cropped from the RGB image and resized
    with bicubic interpolation.

    Args:
        image_path: Path of the image file to read with OpenCV.
        detect_max_dim: Longest image side above which the image is
            downscaled before Haar-cascade detection (keeps detection fast).
        pad_ratio: Fraction of the detected box width/height added as
            padding on each side before cropping.

    Returns:
        (PIL.Image resized to 224x224, basename of image_path), or
        (None, None) if the file cannot be read, no face is found, or an
        unexpected error occurs (logged, never raised).
    """
    try:
        img = cv2.imread(image_path)
        if img is None:
            return None, None
        h0, w0 = img.shape[:2]

        # Keep an RGB copy for the final crop (ViT wants color, not gray).
        img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        gray_full = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)

        # Downscale for faster detection if the image is huge.
        scale = 1.0
        max_side = max(w0, h0)
        if max_side > detect_max_dim:
            scale = detect_max_dim / float(max_side)
            small = cv2.resize(
                gray_full,
                (int(w0 * scale), int(h0 * scale)),
                interpolation=cv2.INTER_LINEAR,
            )
        else:
            small = gray_full.copy()

        # Enhance contrast for detection only; the crop uses the raw RGB.
        from app.utils import _enhance_for_detection
        small_enh = _enhance_for_detection(small)

        # Optimized face detection: 2 cascades x 2 param sets = 4 attempts
        # (fast path); a 3rd-cascade fallback adds 2 more (6 max). Balances
        # speed (4 attempts) with reliability (6 if needed).
        cascade_paths_primary = [
            "haarcascade_frontalface_default.xml",  # Most reliable
            "haarcascade_frontalface_alt.xml",      # Good fallback
        ]
        cascade_paths_fallback = [
            "haarcascade_frontalface_alt2.xml",     # Last resort
        ]

        faces = []
        # BUGFIX: track whether the winning detection ran on the
        # full-resolution image so its coordinates are not rescaled below.
        detected_full_res = False

        # Primary: try 2 cascades with 2 param sets each (fast path).
        for cascade_name in cascade_paths_primary:
            if len(faces) > 0:
                break
            try:
                face_cascade = cv2.CascadeClassifier(cv2.data.haarcascades + cascade_name)
                if face_cascade.empty():
                    continue
                # Attempt 1: most common successful params (catches 90%+ of faces).
                faces = face_cascade.detectMultiScale(
                    small_enh,
                    scaleFactor=1.05,
                    minNeighbors=3,
                    minSize=(20, 20),
                    flags=cv2.CASCADE_SCALE_IMAGE,
                )
                # Attempt 2: more permissive (catches challenging cases).
                if len(faces) == 0:
                    faces = face_cascade.detectMultiScale(
                        small_enh,
                        scaleFactor=1.03,
                        minNeighbors=2,
                        minSize=(15, 15),
                        flags=cv2.CASCADE_SCALE_IMAGE,
                    )
            except Exception:
                continue

        # Fallback: only try the 3rd cascade if the primary ones failed.
        if len(faces) == 0:
            for cascade_name in cascade_paths_fallback:
                if len(faces) > 0:
                    break
                try:
                    face_cascade = cv2.CascadeClassifier(cv2.data.haarcascades + cascade_name)
                    if face_cascade.empty():
                        continue
                    # Try with permissive params.
                    for scale_factor, min_neighbors, min_size in [
                        (1.05, 3, (20, 20)),
                        (1.03, 2, (15, 15)),
                    ]:
                        faces = face_cascade.detectMultiScale(
                            small_enh,
                            scaleFactor=scale_factor,
                            minNeighbors=min_neighbors,
                            minSize=min_size,
                            flags=cv2.CASCADE_SCALE_IMAGE,
                        )
                        if len(faces) > 0:
                            break
                except Exception:
                    continue

        # Fallback 1: retry on the non-enhanced image if enhancement hurt
        # detection. Single attempt with the best params (keeps this cheap).
        if len(faces) == 0:
            try:
                face_cascade = cv2.CascadeClassifier(cv2.data.haarcascades + "haarcascade_frontalface_default.xml")
                if not face_cascade.empty():
                    faces = face_cascade.detectMultiScale(
                        small,  # Use original, not enhanced
                        scaleFactor=1.05,
                        minNeighbors=3,
                        minSize=(20, 20),
                        flags=cv2.CASCADE_SCALE_IMAGE,
                    )
            except Exception:
                pass

        # Fallback 2: full-size detection, only when (1) still no face,
        # (2) the image was actually downscaled, and (3) the downscale was
        # aggressive (scale < 0.5, i.e. image is 2x+ larger). Prevents slow
        # full-size detection on images only slightly over detect_max_dim.
        if len(faces) == 0 and max_side > detect_max_dim and scale < 0.5:
            try:
                face_cascade = cv2.CascadeClassifier(cv2.data.haarcascades + "haarcascade_frontalface_default.xml")
                if not face_cascade.empty():
                    # Single attempt with permissive params (full-size is slow).
                    faces = face_cascade.detectMultiScale(
                        gray_full,
                        scaleFactor=1.05,
                        minNeighbors=2,
                        minSize=(30, 30),  # Larger min size for full-res
                        flags=cv2.CASCADE_SCALE_IMAGE,
                    )
                    if len(faces) > 0:
                        # These coordinates are already at original resolution.
                        detected_full_res = True
            except Exception:
                pass

        if len(faces) == 0:
            return None, None

        # Choose the largest detected face.
        faces = sorted(faces, key=lambda r: r[2] * r[3], reverse=True)
        x_s, y_s, w_s, h_s = faces[0]

        # Map coordinates back to original scale only when detection ran on
        # the downscaled image. BUGFIX: the previous version rescaled even
        # when Fallback 2 detected on the full-size image, which inflated the
        # box by 1/scale and produced a wrong (often out-of-bounds) crop.
        if scale < 1.0 and not detected_full_res:
            x = int(x_s / scale)
            y = int(y_s / scale)
            w = int(w_s / scale)
            h = int(h_s / scale)
        else:
            x, y, w, h = x_s, y_s, w_s, h_s

        # Pad the bounding box, clamped to the image borders.
        pad_w = int(w * pad_ratio)
        pad_h = int(h * pad_ratio)
        x1 = max(0, x - pad_w)
        y1 = max(0, y - pad_h)
        x2 = min(w0, x + w + pad_w)
        y2 = min(h0, y + h + pad_h)

        # Crop from the RGB image (not grayscale) and resize to the ViT input
        # size. BICUBIC keeps detail for emotion recognition; no CLAHE here
        # because the ViT processor handles normalization and CLAHE can
        # interfere with the model's expected input distribution.
        face_crop = img_rgb[y1:y2, x1:x2]
        face_pil = Image.fromarray(face_crop)
        face_pil = face_pil.resize((224, 224), Image.Resampling.BICUBIC)

        import os
        used_filename = os.path.basename(image_path) or "upload.jpg"
        return face_pil, used_filename
    except Exception as e:
        import logging
        logger = logging.getLogger(__name__)
        logger.exception(f"Exception in preprocess_face_for_vit for {image_path}: {e}")
        return None, None
def predict_with_vit(
    model_dict: Dict[str, Any],
    image: "Image.Image",
    labels: list
) -> Tuple[int, float, Dict[str, float]]:
    """
    Run prediction using a Vision Transformer model.

    Args:
        model_dict: {'model': model, 'processor': processor, 'type': 'vit'}
        image: PIL Image (ideally 224x224 RGB; converted to RGB if not).
        labels: Canonical emotion labels, in model-output index order.

    Returns:
        (predicted_index, confidence, {emotion_name: probability})
    """
    processor = model_dict['processor']
    model = model_dict['model']

    # Ensure image is RGB (some images might be RGBA or grayscale).
    if image.mode != 'RGB':
        image = image.convert('RGB')

    # The processor handles resizing/normalization for the ViT.
    inputs = processor(image, return_tensors="pt")

    import torch
    import torch.nn.functional as F

    model.eval()
    # inference_mode() is faster than no_grad() for pure inference.
    with torch.inference_mode():
        outputs = model(**inputs)
        logits = outputs.logits
        probs = F.softmax(logits, dim=-1)
        # Direct indexing; no detach needed inside inference_mode.
        probs_np = probs[0].cpu().numpy()

    predicted_idx = int(torch.argmax(logits, dim=-1).item())
    confidence = float(probs_np[predicted_idx])

    # Build the probabilities dict from the model's own id2label when
    # available, so the mapping matches the checkpoint's label order.
    # Hoisted out of the loop (was rebuilt per iteration): the raw->canonical
    # name table and the config availability check.
    label_map = {
        'anger': 'angry',
        'disgust': 'disgust',
        'fear': 'fear',
        'happy': 'happy',
        'neutral': 'neutral',
        'sad': 'sad',
        'surprise': 'surprise',
        'contempt': 'contempt'
    }
    has_id2label = hasattr(model, 'config') and hasattr(model.config, 'id2label')
    all_probs = {}
    for i, prob in enumerate(probs_np):
        if has_id2label:
            raw_label = model.config.id2label.get(i, f"class_{i}")
            normalized_label = label_map.get(raw_label.lower(), raw_label.lower())
            all_probs[normalized_label] = float(prob)
        elif i < len(labels):
            all_probs[labels[i]] = float(prob)
        else:
            all_probs[f"class_{i}"] = float(prob)

    # Post-processing heuristic: the model has a known happy/contempt
    # confusion. If 'happy' carries some mass (>0.05) and sits in the top 3
    # while contempt or neutral dominates (>0.4), nudge 'happy' upward.
    happy_prob = all_probs.get('happy', 0.0)
    contempt_prob = all_probs.get('contempt', 0.0)
    neutral_prob = all_probs.get('neutral', 0.0)
    sorted_probs = sorted(all_probs.items(), key=lambda x: x[1], reverse=True)
    top_3_emotions = [e[0] for e in sorted_probs[:3]]
    # (The original checked `happy_prob > 0.05` twice; merged into one test.)
    if ('happy' in top_3_emotions and happy_prob > 0.05
            and (contempt_prob > 0.4 or neutral_prob > 0.4)):
        # Boost happy by 30% (helps correct misclassifications).
        boost_factor = 1.3
        boosted_happy = min(1.0, happy_prob * boost_factor)
        # Take the boosted mass evenly from contempt and neutral
        # (clamped at 0; the renormalization below restores sum ~1.0).
        reduction = (boosted_happy - happy_prob) / 2
        all_probs['happy'] = boosted_happy
        all_probs['contempt'] = max(0.0, contempt_prob - reduction)
        all_probs['neutral'] = max(0.0, neutral_prob - reduction)
        total = sum(all_probs.values())
        if total > 0:
            all_probs = {k: v / total for k, v in all_probs.items()}
        # Re-pick the winning class after the boost.
        new_top_emotion = max(all_probs.items(), key=lambda x: x[1])[0]
        if new_top_emotion in labels:
            predicted_idx = labels.index(new_top_emotion)
            confidence = all_probs[new_top_emotion]
            print(f"[VIT] Post-processing: Boosted happy from {happy_prob:.3f} to {all_probs.get('happy', 0.0):.3f}, new prediction: {new_top_emotion}")
        else:
            # Fall back to the pre-boost prediction if the label is unknown.
            print(f"[VIT] Post-processing: Boosted happy but couldn't find label {new_top_emotion} in labels list")
        print(f"[VIT] Predicted index: {predicted_idx}, Raw label from model: {model.config.id2label.get(predicted_idx, 'unknown')}")

    return predicted_idx, confidence, all_probs