# NOTE: captured from a Hugging Face Spaces page (Space status: "Sleeping").
import cv2
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image

import clip
from facenet_pytorch import MTCNN
# --- Module-level state ------------------------------------------------------
# Comma-separated emotion prompts handed to CLIP as zero-shot class labels.
input_labels_X = "Happy Face, Sad Face, Angry Face, Fear Face, Disgust Face, Contempt Face, Nervous Face, Curious Face, Flirtatious Face, Ashamed Face, Bored Face, Confused Face, Calm Face, Proud Face, Guilty Face, Annoyed Face, Desperate Face, Jealous Face, Embarrassed Face, Uncomfortable Face, Helpless Face, Shy Face, Infatuated Face, Apathetic Face, Neutral Face"

# Prefer the GPU when one is available.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the default CLIP checkpoint once at startup; process_frame() swaps it
# lazily when the user picks a different model in the UI.
_DEFAULT_MODEL = "ViT-B/16"
model, preprocess = clip.load(_DEFAULT_MODEL, device=device)
current_model_name = _DEFAULT_MODEL

# MTCNN face detector: single best face, strict thresholds, ignore tiny faces.
mtcnn = MTCNN(keep_all=False, device=device, thresholds=[0.9, 0.9, 0.9], min_face_size=50)
def process_frame(frame, selected_model):
    """Detect the largest face in ``frame``, classify its emotion with CLIP,
    and render a bar chart of the top-5 probabilities.

    Args:
        frame: image array; 3-channel input is assumed BGR (OpenCV order)
            and converted to RGB, anything else passes through unchanged.
        selected_model: CLIP model name; the checkpoint is reloaded lazily
            when it differs from the currently loaded one.

    Returns:
        ``(annotated_rgb_frame, plot_image)`` where ``plot_image`` is an RGBA
        uint8 numpy array of the probability chart, or ``None`` when no face
        was found or an error occurred.
    """
    global model, preprocess, current_model_name
    frame_rgb = frame  # defined up front so the error path below never NameErrors
    try:
        # Lazily swap in a different CLIP checkpoint when the UI selection changes.
        if selected_model != current_model_name:
            model, preprocess = clip.load(selected_model, device=device)
            current_model_name = selected_model

        # OpenCV frames arrive BGR; PIL/CLIP expect RGB.
        if len(frame.shape) == 3 and frame.shape[2] == 3:
            frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)

        # Detect faces and keep only the largest bounding box.
        boxes, _ = mtcnn.detect(Image.fromarray(frame_rgb))
        if boxes is None or len(boxes) == 0:
            # No face: return the (possibly converted) frame without a chart
            # instead of silently falling through and returning None.
            return frame_rgb, None
        largest_face = max(boxes, key=lambda box: (box[2] - box[0]) * (box[3] - box[1]))

        x1, y1, x2, y2 = map(int, largest_face)
        # MTCNN boxes can extend slightly beyond the image; clamp so the crop
        # below is never empty.
        img_h, img_w = frame_rgb.shape[:2]
        x1, y1 = max(x1, 0), max(y1, 0)
        x2, y2 = min(x2, img_w), min(y2, img_h)
        cv2.rectangle(frame_rgb, (x1, y1), (x2, y2), (0, 255, 0), 2)
        cropped_face = frame_rgb[y1:y2, x1:x2]

        # Preprocess the face crop into a CLIP input tensor.
        frame_tensor = preprocess(Image.fromarray(cropped_face)).unsqueeze(0).to(device)

        # Tokenize the emotion prompts; display labels drop the " Face" suffix.
        input_labels = input_labels_X.split(", ")
        input_labels_no_face = [label.replace(" Face", "") for label in input_labels]
        text = clip.tokenize(input_labels).to(device)

        with torch.no_grad():
            # The joint forward pass already yields the image-text logits;
            # separate encode_image/encode_text calls would be dead compute.
            logit_per_image, _ = model(frame_tensor, text)
            probabilities = F.softmax(logit_per_image[0], dim=0)

        combined_labels_probs = sorted(
            zip(input_labels_no_face, probabilities.tolist()),
            key=lambda pair: pair[1],
            reverse=True,
        )
        return frame_rgb, _render_top5_chart(combined_labels_probs[:5])
    except Exception as e:
        print(f"An error occurred: {str(e)}")
        return frame_rgb, None


def _render_top5_chart(top_five_labels_probs):
    """Render (label, probability) pairs as a horizontal bar chart and return
    it as an RGBA uint8 numpy array."""
    fig, ax = plt.subplots(figsize=(10, 5), dpi=300)
    plt.subplots_adjust(left=0.3)
    labels, probs = zip(*top_five_labels_probs)
    ax.barh(labels, probs, color=plt.cm.tab20.colors)
    ax.set_xlabel('Probability')
    ax.set_title('Top 5 Emotions')
    ax.invert_yaxis()  # highest probability on top
    ax.set_xticks(ax.get_xticks())
    ax.set_xticklabels([f'{x:.3f}' for x in ax.get_xticks()], rotation=0, ha='center')
    plt.tight_layout()
    # Rasterize the figure into a numpy array for display in Gradio.
    fig.canvas.draw()
    plot_img = np.frombuffer(fig.canvas.buffer_rgba(), dtype=np.uint8)
    plot_img = plot_img.reshape(fig.canvas.get_width_height()[::-1] + (4,))
    plt.close(fig)  # free the figure to avoid matplotlib memory growth
    return plot_img
def process_video(input_video, selected_model, frame_index):
    """Read frame ``frame_index`` from ``input_video`` and run emotion
    detection on it.

    Args:
        input_video: path to the video file.
        selected_model: CLIP model name forwarded to ``process_frame``.
        frame_index: zero-based index of the frame to process.

    Returns:
        ``(processed_frame, graph)`` from ``process_frame``, or
        ``(None, None)`` when the frame cannot be read or an error occurs.
    """
    cap = None
    try:
        cap = cv2.VideoCapture(input_video)
        cap.set(cv2.CAP_PROP_POS_FRAMES, frame_index)
        ret, frame = cap.read()
        if not ret:
            return None, None
        return process_frame(frame, selected_model)
    except Exception as e:
        print(f"An error occurred: {str(e)}")
        return None, None
    finally:
        # Release the capture on every path — the original leaked it on the
        # early return and on exceptions.
        if cap is not None:
            cap.release()
def process_image(image, model):
    """Run emotion detection on a single uploaded image.

    Args:
        image: numpy image from Gradio.
        model: CLIP model name selected in the UI.

    Returns:
        ``(PIL.Image or None, graph or None)`` — the annotated frame and the
        probability chart; ``(None, None)`` when processing failed outright.
    """
    # Only 3-channel images need the channel swap; checking ndim first keeps
    # 2-D grayscale arrays from raising IndexError on image.shape[2].
    if image.ndim == 3 and image.shape[2] == 3:
        image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    else:
        image_rgb = image
    frame_rgb, graph = process_frame(image_rgb, model)
    if frame_rgb is None:
        # process_frame failed before producing an annotated frame.
        return None, None
    # Float images (e.g. values in [0, 1]) must be rescaled before PIL conversion.
    if frame_rgb.dtype != np.uint8:
        frame_rgb = (frame_rgb * 255).astype(np.uint8)
    # Convert to PIL Image to ensure correct color display in Gradio.
    return Image.fromarray(frame_rgb), graph
# --- Gradio UI ---------------------------------------------------------------
with gr.Blocks() as app:
    gr.Markdown("# EmotionTrack (Zero-Shot)")
    with gr.Row():
        with gr.Column():
            with gr.Tab("Video"):
                model_dropdown_video = gr.Dropdown(choices=["ViT-B/32", "ViT-B/16", "ViT-L/14"], label="Model", value="ViT-B/16")
                gr.Markdown("Upload a video to detect faces and recognize emotions.")
                video_input = gr.Video()
                output_frame = gr.Image(label="Processed Frame")
                frame_slider = gr.Slider(minimum=0, maximum=100, step=1, label="Frame Index", value=0)
                output_graph = gr.Image(label="Results Graph")

                def update_slider_and_process(video, selected_model):
                    """On upload: size the slider to the video and process frame 0."""
                    cap = cv2.VideoCapture(video)
                    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
                    cap.release()
                    processed_frame, graph = process_video(video, selected_model, 0)
                    # Keep the slider maximum non-negative even for an
                    # empty/unreadable video (total_frames == 0).
                    return gr.update(maximum=max(total_frames - 1, 0)), processed_frame, graph

                def update_frame(video, model, frame_idx):
                    """Process the frame currently selected on the slider."""
                    return process_video(video, model, frame_idx)

                # The upload handler now honors the model dropdown instead of
                # hard-coding "ViT-B/16".
                video_input.change(update_slider_and_process, inputs=[video_input, model_dropdown_video], outputs=[frame_slider, output_frame, output_graph])
                frame_slider.release(update_frame, inputs=[video_input, model_dropdown_video, frame_slider], outputs=[output_frame, output_graph])
                process_button_video = gr.Button("Process Frame")
                process_button_video.click(update_frame, inputs=[video_input, model_dropdown_video, frame_slider], outputs=[output_frame, output_graph])
            with gr.Tab("Image"):
                gr.Markdown("Upload an image to detect faces and recognize emotions.")
                image_input = gr.Image(type="numpy")
                model_dropdown_image = gr.Dropdown(choices=["ViT-B/32", "ViT-B/16", "ViT-L/14"], label="Model", value="ViT-L/14")
                process_button_image = gr.Button("Process Image")
                output_image = gr.Image(type="pil", label="Processed Image")
                output_image_graph = gr.Image(label="Results Graph")
                process_button_image.click(process_image, inputs=[image_input, model_dropdown_image], outputs=[output_image, output_image_graph])

# Launch the app with a public share link enabled.
app.launch(share=True)