# fer-inference / utils.py
# Last commit eec43fb by kshitizgajurel: "Deploy FER inference app" (4.88 kB)
from __future__ import annotations
import numpy as np
import cv2
import matplotlib
matplotlib.use('Agg') # non-interactive backend — safe for CLI/server use
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from PIL import Image
# Per-emotion colours for matplotlib charts, as hex strings.
EMOTION_COLORS = dict(
    happy='#2ecc71',     # green
    angry='#e74c3c',     # red
    sad='#3498db',       # blue
    fear='#e67e22',      # orange
    surprise='#f1c40f',  # yellow
    disgust='#9b59b6',   # purple
    neutral='#95a5a6',   # gray
)

# Per-emotion colours for the OpenCV drawing in draw_face_predictions.
# Tuples are in OpenCV's (B, G, R) channel order. They are hand-picked to be
# close in hue to EMOTION_COLORS; they are NOT exact conversions of those hex
# values (e.g. '#2ecc71' would be (113, 204, 46) in BGR).
EMOTION_BGR = dict(
    happy=(0, 200, 80),
    angry=(0, 60, 220),
    sad=(200, 80, 0),
    fear=(0, 130, 230),
    surprise=(0, 210, 240),
    disgust=(150, 50, 130),
    neutral=(140, 140, 140),
)
def _emotion_color_mpl(emotion: str) -> str:
    """Return the matplotlib colour for *emotion*; unknown labels get dark grey '#555555'."""
    try:
        return EMOTION_COLORS[emotion]
    except KeyError:
        return '#555555'
def _to_rgb_image(image):
    """Convert an OpenCV array (gray/BGR/BGRA) or PIL image into an RGB image for imshow."""
    if isinstance(image, np.ndarray):
        # FER inputs are often single-channel; COLOR_BGR2RGB rejects those.
        if image.ndim == 2 or image.shape[2] == 1:
            code = cv2.COLOR_GRAY2RGB
        elif image.shape[2] == 4:
            code = cv2.COLOR_BGRA2RGB
        else:
            code = cv2.COLOR_BGR2RGB
        return Image.fromarray(cv2.cvtColor(image, code))
    if isinstance(image, Image.Image):
        return image.convert('RGB')
    # Anything else is handed to imshow unchanged.
    return image


def visualize_prediction(image, result: dict, save_path: str | None = None):
    """
    Render the input image next to a horizontal bar chart of its emotion probabilities.

    Args:
        image: Input picture. A numpy array is interpreted in OpenCV channel
            order: 2-D or single-channel as grayscale, 3-channel as BGR,
            4-channel as BGRA. A PIL image is converted to RGB. Any other
            object is passed to ``imshow`` as-is.
        result: Prediction dict with keys ``'emotion'`` (label string),
            ``'confidence'`` (a fraction, rendered as a percentage) and
            ``'probabilities'`` (mapping label -> probability; bars appear
            top-to-bottom in its iteration order).
        save_path: If given, the figure is written there at 150 dpi and then
            closed. If None, ``plt.show()`` is called. This module selects the
            non-interactive 'Agg' backend at import time, so nothing is
            displayed unless the caller has switched to an interactive
            backend beforehand.

    Returns:
        The matplotlib ``Figure``. Callers can embed it (e.g. in a web UI) or
        close it. When *save_path* was given, it has already been closed.
    """
    pil_img = _to_rgb_image(image)

    probabilities = result['probabilities']
    emotions = list(probabilities)
    probs = [probabilities[e] for e in emotions]
    colors = [_emotion_color_mpl(e) for e in emotions]

    fig, (ax_img, ax_bar) = plt.subplots(1, 2, figsize=(12, 5))
    fig.suptitle(
        f"Prediction: {result['emotion'].upper()} ({result['confidence']*100:.1f}%)",
        fontsize=14, fontweight='bold',
    )

    # Left panel: the image itself.
    ax_img.imshow(pil_img)
    ax_img.axis('off')
    ax_img.set_title('Input Image')

    # Right panel: one horizontal bar per emotion, annotated with its percentage.
    y_pos = list(range(len(emotions)))
    bars = ax_bar.barh(y_pos, probs, color=colors, height=0.6)
    ax_bar.set_yticks(y_pos)
    ax_bar.set_yticklabels(emotions, fontsize=11)
    ax_bar.set_xlim(0, 1.0)
    ax_bar.set_xlabel('Probability')
    ax_bar.set_title('Emotion Probabilities')
    ax_bar.invert_yaxis()  # first emotion in the dict at the top
    for bar, prob in zip(bars, probs):
        ax_bar.text(
            bar.get_width() + 0.01, bar.get_y() + bar.get_height() / 2,
            f'{prob*100:.1f}%', va='center', fontsize=9,
        )

    # Work on `fig` directly, not on pyplot's global "current figure". If
    # another thread creates a figure in between, plt.savefig() would save
    # the wrong one.
    fig.tight_layout()
    if save_path:
        fig.savefig(save_path, bbox_inches='tight', dpi=150)
        print(f"[INFO] Visualization saved to: {save_path}")
        plt.close(fig)
    else:
        # Do not switch backends here. matplotlib.use()/switch_backend()
        # closes every open figure, including this one, and can raise on
        # machines without a display. The backend is the caller's choice.
        plt.show()
    return fig
def draw_face_predictions(image: np.ndarray, face_results: list[dict]) -> np.ndarray:
    """
    Annotate a BGR image with one labelled rectangle per face result.

    Each entry of *face_results* is read for ``'bbox'`` as ``(x, y, w, h)``,
    plus ``'emotion'`` and ``'confidence'``. Entries whose ``'bbox'`` is
    missing or None are skipped. The input array is left untouched; an
    annotated copy is returned.
    """
    # Text settings are the same for every face, so set them up once.
    font = cv2.FONT_HERSHEY_SIMPLEX
    scale, thickness = 0.7, 2
    text_color = (255, 255, 255)

    annotated = image.copy()
    for face in face_results:
        box = face.get('bbox')
        if box is None:
            continue

        x, y, w, h = map(int, box)
        emotion, conf = face['emotion'], face['confidence']
        color = EMOTION_BGR.get(emotion, (200, 200, 200))

        # Face rectangle.
        cv2.rectangle(annotated, (x, y), (x + w, y + h), color, 2)

        # Filled label background just above the box. The max() keeps it
        # from going past the top edge of the image.
        text = f"{emotion} {conf*100:.0f}%"
        (text_w, text_h), baseline = cv2.getTextSize(text, font, scale, thickness)
        text_y = max(y - 8, text_h + 4)
        cv2.rectangle(
            annotated,
            (x, text_y - text_h - 4),
            (x + text_w + 4, text_y + baseline),
            color,
            cv2.FILLED,
        )
        cv2.putText(
            annotated, text, (x + 2, text_y - 2),
            font, scale, text_color, thickness, cv2.LINE_AA,
        )

    return annotated
def plot_emotion_bars(probabilities: dict, title: str = '', save_path: str | None = None):
    """
    Draw a standalone horizontal bar chart of emotion probabilities.

    Args:
        probabilities: Mapping of emotion label -> probability. Bars appear
            top-to-bottom in the mapping's iteration order, each annotated
            with its value as a percentage.
        title: Optional axes title; no title is set when empty.
        save_path: If given, the figure is written there at 150 dpi and then
            closed. If None, ``plt.show()`` is called. This module selects the
            non-interactive 'Agg' backend at import time, so nothing is
            displayed unless the caller has switched to an interactive
            backend beforehand.

    Returns:
        The matplotlib ``Figure``. Callers can embed it or close it (an
        unsaved figure is otherwise left open in pyplot). When *save_path*
        was given, it has already been closed.
    """
    emotions = list(probabilities)
    probs = [probabilities[e] for e in emotions]
    colors = [_emotion_color_mpl(e) for e in emotions]

    fig, ax = plt.subplots(figsize=(7, 4))
    if title:
        ax.set_title(title, fontsize=13)

    y_pos = list(range(len(emotions)))
    bars = ax.barh(y_pos, probs, color=colors, height=0.6)
    ax.set_yticks(y_pos)
    ax.set_yticklabels(emotions, fontsize=11)
    ax.set_xlim(0, 1.0)
    ax.set_xlabel('Probability')
    ax.invert_yaxis()  # first emotion in the dict at the top
    for bar, prob in zip(bars, probs):
        ax.text(
            bar.get_width() + 0.01, bar.get_y() + bar.get_height() / 2,
            f'{prob*100:.1f}%', va='center', fontsize=9,
        )

    # Work on `fig` directly, not on pyplot's global "current figure". If
    # another thread creates a figure in between, plt.savefig() would save
    # the wrong one.
    fig.tight_layout()
    if save_path:
        fig.savefig(save_path, bbox_inches='tight', dpi=150)
        print(f"[INFO] Plot saved to: {save_path}")
        plt.close(fig)
    else:
        plt.show()
    return fig