"""EmotionTrack (Zero-Shot).

Gradio demo that finds the largest face in an image or a video frame with
MTCNN and scores it against a fixed list of emotion prompts using CLIP
zero-shot classification, returning the annotated frame and a bar chart of
the top-ranked emotions.
"""

import logging

import clip
import cv2
import gradio as gr
import matplotlib

# Gradio runs event handlers on worker threads; force the non-GUI Agg backend so
# figure creation never touches a GUI toolkit that must live on the main thread.
matplotlib.use("Agg")

import matplotlib.pyplot as plt  # noqa: E402  (must follow matplotlib.use)
import numpy as np
import torch
import torch.nn.functional as F
from facenet_pytorch import MTCNN
from PIL import Image

logger = logging.getLogger(__name__)

# Global variables
input_labels_X = (
    "Happy Face, Sad Face, Angry Face, Fear Face, Disgust Face, Contempt Face, "
    "Nervous Face, Curious Face, Flirtatious Face, Ashamed Face, Bored Face, "
    "Confused Face, Calm Face, Proud Face, Guilty Face, Annoyed Face, "
    "Desperate Face, Jealous Face, Embarrassed Face, Uncomfortable Face, "
    "Helpless Face, Shy Face, Infatuated Face, Apathetic Face, Neutral Face"
)
# Prompts fed to CLIP, and the same labels without the " Face" suffix for display.
INPUT_LABELS = input_labels_X.split(", ")
DISPLAY_LABELS = [label.replace(" Face", "") for label in INPUT_LABELS]
MODEL_CHOICES = ["ViT-B/32", "ViT-B/16", "ViT-L/14"]
TOP_K = 5

device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/16", device=device)
current_model_name = "ViT-B/16"

# clip.tokenize is the same for every CLIP variant, so the prompts are
# tokenized once here instead of on every frame.
text_tokens = clip.tokenize(INPUT_LABELS).to(device)

# Initialize MTCNN for face detection
mtcnn = MTCNN(keep_all=False, device=device, thresholds=[0.9, 0.9, 0.9], min_face_size=50)


def _ensure_model(selected_model):
    """Load `selected_model` into the module-level CLIP globals if it is not already active."""
    global model, preprocess, current_model_name
    # A missing selection (None / "") keeps whatever model is currently loaded.
    if selected_model and selected_model != current_model_name:
        model, preprocess = clip.load(selected_model, device=device)
        current_model_name = selected_model


def _to_rgb(frame):
    """Return a new RGB array for `frame`, treating colour input as OpenCV (BGR/BGRA) order."""
    channels = 1 if frame.ndim == 2 else frame.shape[2]
    if channels == 1:
        return cv2.cvtColor(frame, cv2.COLOR_GRAY2RGB)
    if channels == 3:
        return cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    if channels == 4:
        return cv2.cvtColor(frame, cv2.COLOR_BGRA2RGB)
    # Unknown layout: pass through, but copy so drawing never mutates the caller's array.
    return frame.copy()


def _largest_face_box(frame_rgb):
    """Return the largest detected face as integer (x1, y1, x2, y2) clamped to the frame, or None."""
    boxes, _ = mtcnn.detect(Image.fromarray(frame_rgb))
    if boxes is None or len(boxes) == 0:
        return None
    x1, y1, x2, y2 = max(boxes, key=lambda box: (box[2] - box[0]) * (box[3] - box[1]))
    height, width = frame_rgb.shape[:2]
    # MTCNN boxes may extend past the image edges; clamp so slicing stays valid.
    x1, x2 = (int(min(max(v, 0), width)) for v in (x1, x2))
    y1, y2 = (int(min(max(v, 0), height)) for v in (y1, y2))
    if x2 <= x1 or y2 <= y1:
        return None
    return x1, y1, x2, y2


def _classify_face(face_rgb):
    """Return (label, probability) pairs for every emotion prompt, highest probability first."""
    face_tensor = preprocess(Image.fromarray(face_rgb)).unsqueeze(0).to(device)
    with torch.no_grad():
        # model(image, text) encodes both inputs itself; separate encode_* calls are unnecessary.
        logits_per_image, _ = model(face_tensor, text_tokens)
    probabilities = F.softmax(logits_per_image[0], dim=0).tolist()
    return sorted(zip(DISPLAY_LABELS, probabilities), key=lambda pair: pair[1], reverse=True)


def _plot_top_emotions(labels_probs):
    """Render (label, probability) pairs as a horizontal bar chart and return it as an RGBA uint8 array."""
    fig, ax = plt.subplots(figsize=(10, 5), dpi=300)
    try:
        labels, probs = zip(*labels_probs)
        ax.barh(labels, probs, color=plt.cm.tab20.colors)
        ax.set_xlabel('Probability')
        ax.set_title(f'Top {len(labels)} Emotions')
        ax.invert_yaxis()  # highest probability at the top
        ax.set_xticks(ax.get_xticks())
        ax.set_xticklabels([f'{x:.3f}' for x in ax.get_xticks()], rotation=0, ha='center')
        fig.tight_layout()
        fig.canvas.draw()
        # np.array copies out of the renderer buffer, so the figure can be closed safely.
        return np.array(fig.canvas.buffer_rgba())
    finally:
        # Always release the figure; pyplot keeps a reference to every open figure.
        plt.close(fig)


def process_frame(frame, selected_model):
    """Detect the largest face in `frame` and score it against the emotion prompts.

    Args:
        frame: image as a numpy array. 3-channel input is treated as BGR
            (OpenCV order), 4-channel as BGRA, 2-D / 1-channel as grayscale.
        selected_model: CLIP model name (e.g. "ViT-B/16"); loaded on demand.

    Returns:
        (frame_rgb, plot_img): the RGB frame with the detected face outlined in
        green, and an RGBA bar chart of the top emotions. plot_img is None when
        no face is found or an error occurs; frame_rgb falls back to the input
        frame if the error happened before colour conversion.
    """
    frame_rgb = frame  # fallback so the error path never references an unbound name
    try:
        _ensure_model(selected_model)
        frame_rgb = _to_rgb(frame)

        box = _largest_face_box(frame_rgb)
        if box is None:
            return frame_rgb, None

        x1, y1, x2, y2 = box
        # Crop before drawing so the green outline does not leak into CLIP's input.
        face_rgb = frame_rgb[y1:y2, x1:x2].copy()
        cv2.rectangle(frame_rgb, (x1, y1), (x2, y2), (0, 255, 0), 2)

        ranked = _classify_face(face_rgb)
        return frame_rgb, _plot_top_emotions(ranked[:TOP_K])
    except Exception:
        logger.exception("An error occurred while processing a frame")
        return frame_rgb, None


def process_video(input_video, selected_model, frame_index):
    """Read frame `frame_index` from the video file `input_video` and run process_frame on it.

    Returns (processed_frame, graph), or (None, None) if there is no video, it
    cannot be opened, the frame cannot be read, or an error occurs.
    """
    if input_video is None:
        return None, None
    cap = None
    try:
        cap = cv2.VideoCapture(input_video)
        if not cap.isOpened():
            return None, None
        cap.set(cv2.CAP_PROP_POS_FRAMES, int(frame_index or 0))
        ret, frame = cap.read()
        if not ret:
            return None, None
        return process_frame(frame, selected_model)
    except Exception:
        logger.exception("An error occurred while processing the video")
        return None, None
    finally:
        # Release on every path, including the early returns above.
        if cap is not None:
            cap.release()


def process_image(image, model):
    """Run process_frame on an RGB(A)/grayscale numpy image from Gradio.

    `model` is the CLIP model name chosen in the UI. Returns
    (PIL image or None, graph or None).
    """
    if image is None:
        return None, None

    # process_frame expects OpenCV channel order, so convert Gradio's RGB(A) to BGR.
    channels = 1 if image.ndim == 2 else image.shape[2]
    if channels == 1:
        image_bgr = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)
    elif channels == 3:
        image_bgr = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
    elif channels == 4:
        image_bgr = cv2.cvtColor(image, cv2.COLOR_RGBA2BGR)
    else:
        image_bgr = image

    frame_rgb, graph = process_frame(image_bgr, model)
    if frame_rgb is None:
        return None, graph

    # Ensure the output is 8-bit (float images are assumed to be in [0, 1]).
    if frame_rgb.dtype != np.uint8:
        frame_rgb = (frame_rgb * 255).astype(np.uint8)

    # Convert to PIL Image to ensure correct color display in Gradio
    return Image.fromarray(frame_rgb), graph


def update_slider_and_process(video, selected_model="ViT-B/16"):
    """Reset the frame slider for a newly loaded video and process its first frame."""
    if video is None:
        # Video was cleared: reset the slider and blank the outputs.
        return gr.update(value=0, maximum=100), None, None
    cap = cv2.VideoCapture(video)
    try:
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    finally:
        cap.release()
    processed_frame, graph = process_video(video, selected_model, 0)
    # Some containers report 0 frames; never let the slider maximum go negative.
    return gr.update(value=0, maximum=max(total_frames - 1, 0)), processed_frame, graph


def update_frame(video, model, frame_idx):
    """Process the frame selected by the slider with the chosen CLIP model."""
    processed_frame, graph = process_video(video, model, frame_idx)
    return processed_frame, graph


# Create the Gradio app using Blocks
with gr.Blocks() as app:
    gr.Markdown("# EmotionTrack (Zero-Shot)")
    with gr.Row():
        with gr.Column():
            with gr.Tab("Video"):
                model_dropdown_video = gr.Dropdown(choices=MODEL_CHOICES, label="Model", value="ViT-B/16")
                gr.Markdown("Upload a video to detect faces and recognize emotions.")
                video_input = gr.Video()
                output_frame = gr.Image(label="Processed Frame")
                frame_slider = gr.Slider(minimum=0, maximum=100, step=1, label="Frame Index", value=0)
                output_graph = gr.Image(label="Results Graph")

                # Use the selected model for the first frame too (was hardcoded to ViT-B/16).
                video_input.change(
                    update_slider_and_process,
                    inputs=[video_input, model_dropdown_video],
                    outputs=[frame_slider, output_frame, output_graph],
                )
                frame_slider.release(
                    update_frame,
                    inputs=[video_input, model_dropdown_video, frame_slider],
                    outputs=[output_frame, output_graph],
                )
                process_button_video = gr.Button("Process Frame")
                process_button_video.click(
                    update_frame,
                    inputs=[video_input, model_dropdown_video, frame_slider],
                    outputs=[output_frame, output_graph],
                )

            with gr.Tab("Image"):
                gr.Markdown("Upload an image to detect faces and recognize emotions.")
                image_input = gr.Image(type="numpy")
                model_dropdown_image = gr.Dropdown(choices=MODEL_CHOICES, label="Model", value="ViT-L/14")
                process_button_image = gr.Button("Process Image")
                output_image = gr.Image(type="pil", label="Processed Image")
                output_image_graph = gr.Image(label="Results Graph")

                process_button_image.click(
                    process_image,
                    inputs=[image_input, model_dropdown_image],
                    outputs=[output_image, output_image_graph],
                )

# Launch the app with public link enabled (only when run as a script, not on import)
if __name__ == "__main__":
    app.launch(share=True)