# app/vit_utils.py
"""
Utilities for Vision Transformer (ViT) model preprocessing and prediction.
"""
import logging
import os
from typing import Any, Dict, Optional, Tuple

import cv2
import torch
import torch.nn.functional as F
from PIL import Image

from app.utils import _enhance_for_detection  # reuse the shared detection enhancement

logger = logging.getLogger(__name__)

def preprocess_face_for_vit(
    image_path: str,
    detect_max_dim: int = 800,
    pad_ratio: float = 0.35,  # wider padding keeps more facial context, which helps "happy" detection (smiles need surrounding context)
) -> Tuple[Optional[Image.Image], Optional[str]]:
    """
    Preprocess a face for the Vision Transformer model.
    ViT expects 224x224 RGB input, not 48x48 grayscale.
    Returns: (PIL Image, filename), or (None, None) if no face is detected.
    """
    # Detect and crop the face (same cascade logic as app.utils), but keep
    # the crop in RGB and resize to 224x224 instead of converting to grayscale.
    try:
        img = cv2.imread(image_path)
        if img is None:
            return None, None
        h0, w0 = img.shape[:2]

        # Keep an RGB copy for the ViT crop; detection runs on grayscale.
        img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        gray_full = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)

        # Downscale for faster detection if the image is large.
        scale = 1.0
        max_side = max(w0, h0)
        if max_side > detect_max_dim:
            scale = detect_max_dim / float(max_side)
            small = cv2.resize(gray_full, (int(w0 * scale), int(h0 * scale)), interpolation=cv2.INTER_LINEAR)
        else:
            small = gray_full.copy()

        # Enhance for detection.
        small_enh = _enhance_for_detection(small)

        # Detection budget: 2 primary cascades x 2 parameter sets = 4 fast
        # attempts, then a third cascade as a fallback (+2 attempts, 6 max).
        # This balances speed on the common path with reliability on hard images.
        cascade_paths_primary = [
            "haarcascade_frontalface_default.xml",  # most reliable
            "haarcascade_frontalface_alt.xml",      # good fallback
        ]
        cascade_paths_fallback = [
            "haarcascade_frontalface_alt2.xml",     # last resort
        ]
        faces = []

        # Primary pass: two cascades with two parameter sets each (4 attempts).
        for cascade_name in cascade_paths_primary:
            if len(faces) > 0:
                break
            try:
                face_cascade = cv2.CascadeClassifier(cv2.data.haarcascades + cascade_name)
                if face_cascade.empty():
                    continue
                # Attempt 1: the most commonly successful parameters.
                faces = face_cascade.detectMultiScale(
                    small_enh,
                    scaleFactor=1.05,
                    minNeighbors=3,
                    minSize=(20, 20),
                    flags=cv2.CASCADE_SCALE_IMAGE,
                )
                # Attempt 2: more permissive parameters for challenging cases.
                if len(faces) == 0:
                    faces = face_cascade.detectMultiScale(
                        small_enh,
                        scaleFactor=1.03,
                        minNeighbors=2,
                        minSize=(15, 15),
                        flags=cv2.CASCADE_SCALE_IMAGE,
                    )
            except Exception:
                continue

        # Fallback pass: try the third cascade only if the primary pass failed.
        if len(faces) == 0:
            for cascade_name in cascade_paths_fallback:
                if len(faces) > 0:
                    break
                try:
                    face_cascade = cv2.CascadeClassifier(cv2.data.haarcascades + cascade_name)
                    if face_cascade.empty():
                        continue
                    # Try increasingly permissive parameters.
                    for scale_factor, min_neighbors, min_size in [
                        (1.05, 3, (20, 20)),
                        (1.03, 2, (15, 15)),
                    ]:
                        faces = face_cascade.detectMultiScale(
                            small_enh,
                            scaleFactor=scale_factor,
                            minNeighbors=min_neighbors,
                            minSize=min_size,
                            flags=cv2.CASCADE_SCALE_IMAGE,
                        )
                        if len(faces) > 0:
                            break
                except Exception:
                    continue
        # Fallback 1: retry on the original (non-enhanced) image if detection
        # on the enhanced image failed. Single attempt with the best parameters.
        if len(faces) == 0:
            try:
                face_cascade = cv2.CascadeClassifier(cv2.data.haarcascades + "haarcascade_frontalface_default.xml")
                if not face_cascade.empty():
                    faces = face_cascade.detectMultiScale(
                        small,  # original grayscale, not enhanced
                        scaleFactor=1.05,
                        minNeighbors=3,
                        minSize=(20, 20),
                        flags=cv2.CASCADE_SCALE_IMAGE,
                    )
            except Exception:
                pass
        # Fallback 2: try the full-size image only if (1) still no face was
        # found, (2) the image was actually downscaled, and (3) the downscale
        # was significant (scale < 0.5, i.e. the original is 2x+ larger).
        # This avoids slow full-size detection on images only slightly over
        # detect_max_dim.
        detected_on_full = False  # track where detection ran so coordinates are mapped correctly
        if len(faces) == 0 and max_side > detect_max_dim and scale < 0.5:
            try:
                face_cascade = cv2.CascadeClassifier(cv2.data.haarcascades + "haarcascade_frontalface_default.xml")
                if not face_cascade.empty():
                    # Full-size detection is slow, so make a single permissive attempt.
                    faces = face_cascade.detectMultiScale(
                        gray_full,
                        scaleFactor=1.05,
                        minNeighbors=2,
                        minSize=(30, 30),  # larger minimum size for full resolution
                        flags=cv2.CASCADE_SCALE_IMAGE,
                    )
                    if len(faces) > 0:
                        detected_on_full = True
            except Exception:
                pass

        if len(faces) == 0:
            return None, None

        # Choose the largest face.
        faces = sorted(faces, key=lambda r: r[2] * r[3], reverse=True)
        (x_s, y_s, w_s, h_s) = faces[0]

        # Map coordinates back to the original scale if detection ran on the
        # downscaled image; full-size detections are already in original coordinates.
        if scale < 1.0 and not detected_on_full:
            x = int(x_s / scale)
            y = int(y_s / scale)
            w = int(w_s / scale)
            h = int(h_s / scale)
        else:
            x, y, w, h = x_s, y_s, w_s, h_s
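        # Example: a 1600px-wide image is detected at scale 0.5, so a box found
        # at x_s = 100 in the small image maps back to x = 100 / 0.5 = 200.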
        # Pad the bounding box.
        pad_w = int(w * pad_ratio)
        pad_h = int(h * pad_ratio)
        x1 = max(0, x - pad_w)
        y1 = max(0, y - pad_h)
        x2 = min(w0, x + w + pad_w)
        y2 = min(h0, y + h + pad_h)

        # Crop the face from the RGB image (not the grayscale copy).
        face_crop = img_rgb[y1:y2, x1:x2]

        # Convert to a PIL image and resize to 224x224 (the ViT input size).
        # BICUBIC preserves detail, which matters for emotion recognition.
        # The ViT processor handles normalization, so CLAHE is not applied here:
        # it can shift the input away from the distribution the model expects.
        face_pil = Image.fromarray(face_crop)
        face_pil = face_pil.resize((224, 224), Image.Resampling.BICUBIC)

        used_filename = os.path.basename(image_path) or "upload.jpg"
        return face_pil, used_filename
    except Exception:
        logger.exception(f"Exception in preprocess_face_for_vit for {image_path}")
        return None, None
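
# A minimal usage sketch (assumes a local image file such as "sample.jpg"):
#
#   face_img, fname = preprocess_face_for_vit("sample.jpg")
#   if face_img is None:
#       print("No face detected")
#   else:
#       face_img.save("face_224.jpg")  # 224x224 RGB crop, ready for the ViT processor
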
def predict_with_vit(
    model_dict: Dict[str, Any],
    image: Image.Image,
    labels: list,
) -> Tuple[int, float, Dict[str, float]]:
    """
    Run prediction using the Vision Transformer model.
    Converts the input to RGB and lets the processor handle normalization.
    Args:
        model_dict: {'model': model, 'processor': processor, 'type': 'vit'}
        image: PIL Image (224x224 RGB)
        labels: list of emotion labels
    Returns:
        (predicted_index, confidence, all_probabilities_dict)
    """
    processor = model_dict['processor']
    model = model_dict['model']

    # Ensure the image is RGB (uploads may be RGBA or grayscale).
    if image.mode != 'RGB':
        image = image.convert('RGB')

    # Preprocess for ViT; the processor handles resizing and normalization.
    inputs = processor(image, return_tensors="pt")
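    # For a standard 224x224 ViT image processor, `inputs` is a dict whose
    # "pixel_values" tensor has shape (1, 3, 224, 224).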
    model.eval()
    # inference_mode() is slightly faster than no_grad() for pure inference.
    with torch.inference_mode():
        outputs = model(**inputs)
        logits = outputs.logits

    # Softmax over the class dimension to get probabilities.
    probs = F.softmax(logits, dim=-1)
    probs_np = probs[0].cpu().numpy()  # no detach needed inside inference_mode

    # Predicted class and its confidence.
    predicted_idx = int(torch.argmax(logits, dim=-1).item())
    confidence = float(probs_np[predicted_idx])
    # Build the probabilities dict. Prefer the model's own id2label so the
    # index-to-label mapping stays correct even if `labels` is ordered differently.
    label_map = {
        'anger': 'angry',
        'disgust': 'disgust',
        'fear': 'fear',
        'happy': 'happy',
        'neutral': 'neutral',
        'sad': 'sad',
        'surprise': 'surprise',
        'contempt': 'contempt',
    }
    has_id2label = hasattr(model, 'config') and hasattr(model.config, 'id2label')

    all_probs = {}
    for i, prob in enumerate(probs_np):
        if has_id2label:
            raw_label = model.config.id2label.get(i, f"class_{i}")
            normalized_label = label_map.get(raw_label.lower(), raw_label.lower())
            all_probs[normalized_label] = float(prob)
        elif i < len(labels):
            all_probs[labels[i]] = float(prob)
        else:
            all_probs[f"class_{i}"] = float(prob)
    # Post-processing heuristic: the model is known to confuse "happy" with
    # "contempt"/"neutral". If happy has non-trivial probability (>0.05) and is
    # in the top 3, but contempt or neutral is suspiciously high (>0.4), boost
    # happy and reduce the competing classes proportionally.
    happy_prob = all_probs.get('happy', 0.0)
    contempt_prob = all_probs.get('contempt', 0.0)
    neutral_prob = all_probs.get('neutral', 0.0)

    sorted_probs = sorted(all_probs.items(), key=lambda x: x[1], reverse=True)
    top_3_emotions = [e[0] for e in sorted_probs[:3]]

    if 'happy' in top_3_emotions and happy_prob > 0.05 and (contempt_prob > 0.4 or neutral_prob > 0.4):
        # Boost happy by 30% to correct the known misclassification pattern.
        boost_factor = 1.3
        boosted_happy = min(1.0, happy_prob * boost_factor)
        # Split the added mass between contempt and neutral so the distribution
        # stays roughly normalized before the renormalization below.
        reduction = (boosted_happy - happy_prob) / 2
        new_contempt = max(0.0, contempt_prob - reduction)
        new_neutral = max(0.0, neutral_prob - reduction)
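        # Worked example: happy=0.30, contempt=0.45 -> boosted_happy=0.39 and
        # reduction=0.045, so contempt drops to 0.405 (and neutral also loses
        # 0.045) before the renormalization below.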
        # Apply the adjusted probabilities.
        all_probs['happy'] = boosted_happy
        all_probs['contempt'] = new_contempt
        all_probs['neutral'] = new_neutral

        # Renormalize so the probabilities sum to ~1.0.
        total = sum(all_probs.values())
        if total > 0:
            all_probs = {k: v / total for k, v in all_probs.items()}

        # Recompute the prediction after boosting.
        new_top_emotion = max(all_probs.items(), key=lambda x: x[1])[0]
        if new_top_emotion in labels:
            predicted_idx = labels.index(new_top_emotion)
            confidence = all_probs[new_top_emotion]
            print(f"[VIT] Post-processing: boosted happy from {happy_prob:.3f} to {all_probs.get('happy', 0.0):.3f}, new prediction: {new_top_emotion}")
        else:
            # Fall back to the original prediction if the label is not in the list.
            print(f"[VIT] Post-processing: boosted happy but label {new_top_emotion} was not found in the labels list")

    if has_id2label:
        print(f"[VIT] Predicted index: {predicted_idx}, raw label from model: {model.config.id2label.get(predicted_idx, 'unknown')}")

    return predicted_idx, confidence, all_probs
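
# An end-to-end sketch (hypothetical model id and image path; in the real app
# the loaded model/processor are passed in via model_dict):
#
#   from transformers import AutoImageProcessor, AutoModelForImageClassification
#
#   model_name = "your-org/vit-emotion-model"  # placeholder, not a real checkpoint
#   model_dict = {
#       'model': AutoModelForImageClassification.from_pretrained(model_name),
#       'processor': AutoImageProcessor.from_pretrained(model_name),
#       'type': 'vit',
#   }
#   labels = ['angry', 'disgust', 'fear', 'happy', 'neutral', 'sad', 'surprise']
#
#   face_img, _ = preprocess_face_for_vit("sample.jpg")
#   if face_img is not None:
#       idx, conf, probs = predict_with_vit(model_dict, face_img, labels)
#       print(labels[idx] if idx < len(labels) else idx, round(conf, 3))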