# EmotionTrack / app.py
# Author: reab5555 — commit c665983 ("Update app.py")
import gradio as gr
import cv2
import numpy as np
import clip
import torch
from PIL import Image
import torch.nn.functional as F
from facenet_pytorch import MTCNN
import matplotlib.pyplot as plt
# Zero-shot CLIP prompts, comma-separated. process_frame splits this on ", "
# and strips the " Face" suffix when labelling the results chart.
input_labels_X = "Happy Face, Sad Face, Angry Face, Fear Face, Disgust Face, Contempt Face, Nervous Face, Curious Face, Flirtatious Face, Ashamed Face, Bored Face, Confused Face, Calm Face, Proud Face, Guilty Face, Annoyed Face, Desperate Face, Jealous Face, Embarrassed Face, Uncomfortable Face, Helpless Face, Shy Face, Infatuated Face, Apathetic Face, Neutral Face"

# Prefer the GPU whenever torch can see one.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Name of the CLIP checkpoint currently held in `model`/`preprocess`;
# process_frame reloads both when a different name is selected in the UI.
current_model_name = "ViT-B/16"
model, preprocess = clip.load(current_model_name, device=device)

# Face detector: all three MTCNN detection thresholds set to 0.9 and a
# minimum face size of 50.
mtcnn = MTCNN(keep_all=False, device=device, thresholds=[0.9, 0.9, 0.9], min_face_size=50)
def _largest_box(boxes):
    """Return the box with the largest area from *boxes*, or ``None`` if there are none.

    Each box is indexed as ``[x1, y1, x2, y2]`` (corner coordinates), the same way
    ``process_frame`` uses it for cropping.
    """
    if boxes is None or len(boxes) == 0:
        return None
    return max(boxes, key=lambda box: (box[2] - box[0]) * (box[3] - box[1]))


def _face_box_to_pixels(box, frame_shape):
    """Convert a ``[x1, y1, x2, y2]`` box into integer pixel bounds clipped to the frame.

    Coordinates are truncated with ``int()`` and clamped to ``[0, width]`` / ``[0, height]``.
    Detector boxes are not guaranteed to lie inside the image, and a negative index would
    make NumPy slicing wrap around to the opposite side of the frame.

    Returns ``(x1, y1, x2, y2)``, or ``None`` if no part of the box lies inside the frame.
    """
    height, width = frame_shape[:2]
    x1, y1, x2, y2 = (int(v) for v in box)
    x1, x2 = (min(max(v, 0), width) for v in (x1, x2))
    y1, y2 = (min(max(v, 0), height) for v in (y1, y2))
    if x2 <= x1 or y2 <= y1:
        return None
    return x1, y1, x2, y2


def _classify_face(face_rgb):
    """Score an RGB face crop against every prompt in ``input_labels_X`` with the loaded CLIP model.

    Returns a list of ``(label, probability)`` pairs sorted by descending probability,
    where ``label`` is the prompt with its " Face" suffix removed.
    """
    input_labels = input_labels_X.split(", ")
    display_labels = [label.replace(" Face", "") for label in input_labels]
    image_tensor = preprocess(Image.fromarray(face_rgb)).unsqueeze(0).to(device)
    text = clip.tokenize(input_labels).to(device)
    with torch.no_grad():
        # model(image, text) encodes both inputs itself; separate encode_* calls would be wasted work.
        logits_per_image, _ = model(image_tensor, text)
        probabilities = F.softmax(logits_per_image[0], dim=0)
    return sorted(zip(display_labels, probabilities.tolist()), key=lambda pair: pair[1], reverse=True)


def _plot_top_emotions(labels_probs):
    """Render ``(label, probability)`` pairs as a horizontal bar chart and return it as an RGBA uint8 array."""
    fig, ax = plt.subplots(figsize=(10, 5), dpi=300)
    try:
        fig.subplots_adjust(left=0.3)
        labels, probs = zip(*labels_probs)
        ax.barh(labels, probs, color=plt.cm.tab20.colors)
        ax.set_xlabel('Probability')
        ax.set_title('Top 5 Emotions')
        ax.invert_yaxis()  # first (most probable) entry at the top
        ax.set_xticks(ax.get_xticks())
        ax.set_xticklabels([f'{x:.3f}' for x in ax.get_xticks()], rotation=0, ha='center')
        fig.tight_layout()
        fig.canvas.draw()
        # buffer_rgba() is already (height, width, 4) in device pixels; copy it before the figure is closed.
        return np.array(fig.canvas.buffer_rgba())
    finally:
        plt.close(fig)  # release the figure even if drawing fails


def process_frame(frame, selected_model):
    """Detect the largest face in *frame*, classify its emotion zero-shot with CLIP, and chart the top 5.

    Parameters
    ----------
    frame : np.ndarray
        Image array. A 3-channel frame is treated as BGR (OpenCV order) and converted to RGB.
        Any other layout is used as-is (copied, so the caller's array is not drawn on).
    selected_model : str
        CLIP model name for ``clip.load``. The global model is reloaded only when this differs
        from ``current_model_name``.

    Returns
    -------
    tuple
        ``(frame_rgb, plot_img)``. ``frame_rgb`` is the RGB frame with a green box around the
        largest detected face. ``plot_img`` is an RGBA uint8 bar chart of the top-5 emotions.
        ``plot_img`` is ``None`` when no usable face is found or an error occurs. ``frame_rgb``
        is ``None`` only if the error happens before the frame could be converted.
    """
    global model, preprocess, current_model_name
    frame_rgb = None  # bound up front so the error path can always return it
    try:
        if len(frame.shape) == 3 and frame.shape[2] == 3:
            frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        else:
            frame_rgb = frame.copy()

        # Load the selected model only if it differs from the current one.
        if selected_model != current_model_name:
            model, preprocess = clip.load(selected_model, device=device)
            current_model_name = selected_model

        boxes, _ = mtcnn.detect(Image.fromarray(frame_rgb))
        largest_box = _largest_box(boxes)
        if largest_box is None:
            return frame_rgb, None
        bounds = _face_box_to_pixels(largest_box, frame_rgb.shape)
        if bounds is None:
            return frame_rgb, None
        x1, y1, x2, y2 = bounds

        # Crop (and copy) BEFORE drawing so the green outline is not fed to the classifier.
        cropped_face = frame_rgb[y1:y2, x1:x2].copy()
        cv2.rectangle(frame_rgb, (x1, y1), (x2, y2), (0, 255, 0), 2)

        ranked = _classify_face(cropped_face)
        plot_img = _plot_top_emotions(ranked[:5])
        return frame_rgb, plot_img
    except Exception as e:
        # Best-effort for the UI: report the problem and fall back to the frame without a chart.
        print(f"An error occurred: {str(e)}")
        return frame_rgb, None
def process_video(input_video, selected_model, frame_index):
    """Read one frame from *input_video* and analyse it with ``process_frame``.

    Parameters
    ----------
    input_video : str or None
        Path accepted by ``cv2.VideoCapture``. ``None`` when no video is loaded.
    selected_model : str
        CLIP model name forwarded to ``process_frame``.
    frame_index : int or float
        Frame to seek to via ``CAP_PROP_POS_FRAMES``.

    Returns
    -------
    tuple
        ``process_frame``'s ``(frame_rgb, plot_img)``, or ``(None, None)`` if there is no video,
        the frame cannot be read, or an error occurs.
    """
    if input_video is None:
        return None, None
    try:
        cap = cv2.VideoCapture(input_video)
        try:
            cap.set(cv2.CAP_PROP_POS_FRAMES, frame_index)
            ret, frame = cap.read()
        finally:
            # Release the capture on every path, including a failed read or an exception.
            cap.release()
        if not ret:
            return None, None
        return process_frame(frame, selected_model)
    except Exception as e:
        print(f"An error occurred: {str(e)}")
        return None, None
def process_image(image, model):
    """Analyse an uploaded image with ``process_frame`` and return a displayable result.

    Parameters
    ----------
    image : np.ndarray or None
        Image from the ``gr.Image(type="numpy")`` input (Gradio supplies RGB arrays).
        ``None`` when the button is pressed without an upload.
    model : str
        CLIP model name forwarded to ``process_frame``. Inside this function it shadows the
        module-level ``model``.

    Returns
    -------
    tuple
        ``(PIL.Image or None, plot array or None)``.
    """
    if image is None:
        return None, None
    # process_frame always channel-swaps a 3-channel input because it expects OpenCV's BGR
    # order. Pre-swap the RGB upload here so the two swaps cancel and the face is analysed
    # in true RGB. Grayscale and other layouts pass through unchanged.
    if image.ndim == 3 and image.shape[2] == 3:
        image_bgr = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
    else:
        image_bgr = image
    frame_rgb, graph = process_frame(image_bgr, model)
    if frame_rgb is None:
        return None, graph
    if frame_rgb.dtype != np.uint8:
        # Assumes non-uint8 images are floats in [0, 1]; clip so out-of-range values cannot wrap.
        frame_rgb = (np.clip(frame_rgb, 0.0, 1.0) * 255).astype(np.uint8)
    # A PIL image keeps Gradio's colour handling unambiguous.
    return Image.fromarray(frame_rgb), graph
# Create the Gradio app using Blocks
with gr.Blocks() as app:
    gr.Markdown("# EmotionTrack (Zero-Shot)")
    with gr.Row():
        with gr.Column():
            with gr.Tab("Video"):
                model_dropdown_video = gr.Dropdown(choices=["ViT-B/32", "ViT-B/16", "ViT-L/14"], label="Model", value="ViT-B/16")
                gr.Markdown("Upload a video to detect faces and recognize emotions.")
                video_input = gr.Video()
                output_frame = gr.Image(label="Processed Frame")
                frame_slider = gr.Slider(minimum=0, maximum=100, step=1, label="Frame Index", value=0)
                output_graph = gr.Image(label="Results Graph")

                def update_slider_and_process(video, selected_model="ViT-B/16"):
                    """Fit the frame slider to a newly loaded video and analyse its first frame.

                    Uses the model chosen in the dropdown rather than a hard-coded one. When the
                    video is cleared (``None``), resets the slider and blanks both outputs.
                    """
                    if video is None:
                        return gr.update(maximum=100, value=0), None, None
                    cap = cv2.VideoCapture(video)
                    try:
                        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
                    finally:
                        cap.release()
                    # The reported frame count may be 0 or negative; keep the slider range valid.
                    max_index = max(total_frames - 1, 0)
                    processed_frame, graph = process_video(video, selected_model, 0)
                    # Reset the slider to 0 so it matches the frame being shown.
                    return gr.update(maximum=max_index, value=0), processed_frame, graph

                def update_frame(video, model, frame_idx):
                    """Analyse the frame at *frame_idx* with the selected model."""
                    processed_frame, graph = process_video(video, model, frame_idx)
                    return processed_frame, graph

                video_input.change(update_slider_and_process, inputs=[video_input, model_dropdown_video], outputs=[frame_slider, output_frame, output_graph])
                frame_slider.release(update_frame, inputs=[video_input, model_dropdown_video, frame_slider], outputs=[output_frame, output_graph])
                process_button_video = gr.Button("Process Frame")
                process_button_video.click(update_frame, inputs=[video_input, model_dropdown_video, frame_slider], outputs=[output_frame, output_graph])
            with gr.Tab("Image"):
                gr.Markdown("Upload an image to detect faces and recognize emotions.")
                image_input = gr.Image(type="numpy")
                model_dropdown_image = gr.Dropdown(choices=["ViT-B/32", "ViT-B/16", "ViT-L/14"], label="Model", value="ViT-L/14")
                process_button_image = gr.Button("Process Image")
                output_image = gr.Image(type="pil", label="Processed Image")
                output_image_graph = gr.Image(label="Results Graph")
                process_button_image.click(process_image, inputs=[image_input, model_dropdown_image], outputs=[output_image, output_image_graph])
# Launch the app with public link enabled
app.launch(share=True)