totes-emosh / app /app_utils.py
drdeception
feat: six-emotion replication challenge — totes-emosh EmotionMap build
0d27c43
Raw
History Blame Contribute Delete
2.21 kB
"""
File: app_utils.py
Author: Dr. Gordon Wright
Description: Utility functions for the static facial-expression recogniser.
The dynamic / video sweep pipeline lives in a separate app.
License: MIT License
"""
import torch
import numpy as np
from PIL import Image
import cv2
from pytorch_grad_cam.utils.image import show_cam_on_image
from app.model import pth_model_static, cam, pth_processing
from app.face_landmarker import detect as detect_face, bbox_from_landmarks
from app.config import DICT_EMO
def _to_rgb_array(inp):
    """Coerce a PIL image or array-like into a writable ``(H, W, 3)`` RGB array.

    Returns ``None`` when *inp* is ``None`` or cannot be interpreted as an
    image. Grayscale and RGBA inputs are converted to RGB because
    ``show_cam_on_image`` adds a 3-channel heatmap to the face crop and
    cannot broadcast against any other channel count.
    """
    if inp is None:
        return None
    try:
        image = inp if isinstance(inp, Image.Image) else Image.fromarray(np.asarray(inp))
        # np.array (not asarray) so callers get a writable copy, as before.
        return np.array(image.convert("RGB"))
    except (TypeError, ValueError, IndexError):
        # Unsupported dtype / shape: treat like "no usable image".
        return None


def preprocess_image_and_predict(inp):
    """Detect a face in the input, classify the expression, and return
    everything the tile renderer needs.

    Args:
        inp: A PIL image or an array-like image (typically ``(H, W, 3)``
            uint8 RGB). Grayscale / RGBA inputs are converted to RGB first.

    Returns:
        A 6-tuple ``(face_crop, heatmap, confidences, blendshapes,
        landmarks, bbox)``:

        - ``face_crop``: the face region sliced from the RGB input.
        - ``heatmap``: Grad-CAM overlay from ``show_cam_on_image``.
        - ``confidences``: ``{DICT_EMO[i]: probability}`` for each class the
          model outputs.
        - ``blendshapes``: passed through from ``detect_face``.
        - ``landmarks``: list of ``(lm.x, lm.y)`` pairs from the detector.
        - ``bbox``: ``(startX, startY, endX, endY)`` in the *original* image
          coordinates, needed downstream so the landmark wireframe can be
          redrawn at the same crop as the face.

        Returns ``(None,) * 6`` if the input is missing or unreadable, no
        face is detected, or the face box is empty.
    """
    no_face = (None,) * 6

    frame = _to_rgb_array(inp)
    if frame is None or frame.size == 0:
        return no_face
    h, w = frame.shape[:2]

    landmarks, blendshapes = detect_face(Image.fromarray(frame))
    if landmarks is None:
        return no_face

    startX, startY, endX, endY = bbox_from_landmarks(landmarks, w, h)
    cur_face = frame[startY:endY, startX:endX]
    if cur_face.size == 0:
        # Degenerate / off-frame box: nothing to classify.
        return no_face

    cur_face_n = pth_processing(Image.fromarray(cur_face))

    with torch.no_grad():
        logits = pth_model_static(cur_face_n)
        # .cpu() is a no-op on CPU and makes .numpy() safe for GPU models.
        prediction = torch.nn.functional.softmax(logits, dim=1).cpu().numpy()[0]
    # Iterate the model's actual outputs rather than a hard-coded class count.
    confidences = {DICT_EMO[i]: float(p) for i, p in enumerate(prediction)}

    # Grad-CAM needs gradients, so it must run outside torch.no_grad().
    grayscale_cam = cam(input_tensor=cur_face_n)[0, :]
    # Resize the crop to the CAM mask's own (H, W); cv2.resize takes (W, H).
    mask_h, mask_w = grayscale_cam.shape[:2]
    cur_face_hm = np.float32(cv2.resize(cur_face, (mask_w, mask_h))) / 255
    heatmap = show_cam_on_image(cur_face_hm, grayscale_cam, use_rgb=True)

    bbox = (startX, startY, endX, endY)
    landmark_pts = [(lm.x, lm.y) for lm in landmarks]
    return cur_face, heatmap, confidences, blendshapes, landmark_pts, bbox