File size: 7,740 Bytes
2c1fa32
 
 
 
 
 
 
 
6d5f107
2c1fa32
 
cd680cf
2c1fa32
 
981e736
 
2c1fa32
 
c666484
2c1fa32
e5bacbc
caba546
e5bacbc
 
 
 
 
 
 
dde67a3
 
 
 
 
e5bacbc
 
 
 
 
 
 
 
 
 
 
 
dde67a3
f099e78
e5bacbc
6d5f107
e5bacbc
 
 
 
995585d
e5bacbc
 
 
 
 
 
 
 
 
 
 
 
 
 
995585d
e5bacbc
 
 
6d5f107
dde67a3
 
981e736
6d5f107
4749330
6d5f107
 
dde67a3
6d5f107
d94a411
 
981e736
 
 
6d5f107
f099e78
 
6d5f107
 
e5bacbc
dde67a3
e5bacbc
 
dde67a3
e5bacbc
6d5f107
e5bacbc
 
 
6d5f107
 
 
 
 
2662a5b
e5bacbc
dde67a3
e5bacbc
 
2662a5b
e5bacbc
 
2662a5b
2c1fa32
c665983
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d4deb5c
 
71973e5
d4deb5c
995585d
 
6d06054
981e736
995585d
 
0132173
d94a411
2662a5b
0132173
2662a5b
995585d
 
 
d94a411
2662a5b
 
995585d
2662a5b
 
 
 
 
995585d
6d06054
2662a5b
0132173
6d06054
995585d
 
 
 
 
c665983
2662a5b
995585d
2f3bfb9
d53087e
4749330
2662a5b
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
import gradio as gr
import cv2
import numpy as np
import clip
import torch
from PIL import Image
import torch.nn.functional as F
from facenet_pytorch import MTCNN
import matplotlib.pyplot as plt

# Global variables
# Comma-separated CLIP text prompts, one per emotion class; the " Face" suffix
# is stripped for display in process_frame.
input_labels_X = "Happy Face, Sad Face, Angry Face, Fear Face, Disgust Face, Contempt Face, Nervous Face, Curious Face, Flirtatious Face, Ashamed Face, Bored Face, Confused Face, Calm Face, Proud Face, Guilty Face, Annoyed Face, Desperate Face, Jealous Face, Embarrassed Face, Uncomfortable Face, Helpless Face, Shy Face, Infatuated Face, Apathetic Face, Neutral Face"

device = "cuda" if torch.cuda.is_available() else "cpu"
# Default CLIP checkpoint; process_frame reloads lazily when the UI picks another.
model, preprocess = clip.load("ViT-B/16", device=device)
current_model_name = "ViT-B/16"

# Initialize MTCNN for face detection
# keep_all=False returns only the best face; high thresholds + min_face_size=50
# trade recall for fewer false positives.
mtcnn = MTCNN(keep_all=False, device=device, thresholds=[0.9, 0.9, 0.9], min_face_size=50)

def process_frame(frame, selected_model):
    """Detect the largest face in a frame and classify its emotion with CLIP.

    Args:
        frame: numpy image array; assumed BGR (OpenCV order) when 3-channel.
        selected_model: CLIP model name (e.g. "ViT-B/16"); the global model is
            reloaded lazily when it differs from the currently loaded one.

    Returns:
        (frame_rgb, plot_img): the RGB frame with a green box drawn around the
        largest detected face, and an RGBA numpy image of the top-5 emotion
        bar chart — or None for plot_img when no face was found or an error
        occurred.
    """
    global model, preprocess, current_model_name

    # Defined up front so the except handler and the no-face path never hit an
    # unbound name (the original relied on UnboundLocalError being swallowed).
    frame_rgb = frame
    plot_img = None

    try:
        # Lazily swap in the requested CLIP checkpoint.
        if selected_model != current_model_name:
            model, preprocess = clip.load(selected_model, device=device)
            current_model_name = selected_model

        # OpenCV frames arrive BGR; PIL/CLIP expect RGB.
        if len(frame.shape) == 3 and frame.shape[2] == 3:
            frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        else:
            frame_rgb = frame

        # Detect faces; boxes is None when nothing is found.
        boxes, _ = mtcnn.detect(Image.fromarray(frame_rgb))

        # Keep only the largest detection (by box area).
        largest_face = None
        if boxes is not None and len(boxes) > 0:
            largest_face = max(boxes, key=lambda box: (box[2] - box[0]) * (box[3] - box[1]))

        if largest_face is not None:
            # MTCNN boxes are (x1, y1, x2, y2) corner coordinates.
            x, y, w, h = map(int, largest_face)
            cv2.rectangle(frame_rgb, (x, y), (w, h), (0, 255, 0), 2)
            cropped_face = frame_rgb[y:h, x:w]

            # Preprocess the crop into a batched tensor for CLIP.
            frame_tensor = preprocess(Image.fromarray(cropped_face)).unsqueeze(0).to(device)

            # Tokenize the emotion prompts; display labels drop the " Face" suffix.
            input_labels = input_labels_X.split(", ")
            input_labels_no_face = [label.replace(" Face", "") for label in input_labels]
            text = clip.tokenize(input_labels).to(device)

            with torch.no_grad():
                # model(image, text) encodes both modalities internally — the
                # original's separate encode_image/encode_text calls were
                # computed and discarded, doubling the inference work.
                logit_per_image, logit_per_text = model(frame_tensor, text)

                # Softmax over the per-image logits to get class probabilities.
                probabilities = F.softmax(logit_per_image[0], dim=0)

            # Pair labels with probabilities, highest first, keep the top 5.
            combined_labels_probs = sorted(
                zip(input_labels_no_face, probabilities.tolist()),
                key=lambda pair: pair[1],
                reverse=True,
            )
            top_five_labels_probs = combined_labels_probs[:5]

            plot_img = _render_top5_chart(top_five_labels_probs)

        return frame_rgb, plot_img
    except Exception as e:
        print(f"An error occurred: {str(e)}")
        return frame_rgb, None


def _render_top5_chart(top_five_labels_probs):
    """Render (label, probability) pairs as a horizontal bar chart.

    Returns an HxWx4 uint8 RGBA numpy array; the matplotlib figure is closed
    before returning so repeated calls do not accumulate figures.
    """
    fig, ax = plt.subplots(figsize=(10, 5), dpi=300)
    plt.subplots_adjust(left=0.3)

    labels, probs = zip(*top_five_labels_probs)
    ax.barh(labels, probs, color=plt.cm.tab20.colors)
    ax.set_xlabel('Probability')
    ax.set_title('Top 5 Emotions')
    ax.invert_yaxis()  # highest probability on top

    # Re-label the x ticks with 3-decimal probability values.
    ax.set_xticks(ax.get_xticks())
    ax.set_xticklabels([f'{x:.3f}' for x in ax.get_xticks()], rotation=0, ha='center')

    plt.tight_layout()

    # Rasterize the canvas into an RGBA array, then free the figure.
    fig.canvas.draw()
    plot_img = np.frombuffer(fig.canvas.buffer_rgba(), dtype=np.uint8)
    plot_img = plot_img.reshape(fig.canvas.get_width_height()[::-1] + (4,))
    plt.close(fig)
    return plot_img

def process_video(input_video, selected_model, frame_index):
    """Grab one frame from a video file and run emotion detection on it.

    Args:
        input_video: path to the video file.
        selected_model: CLIP model name, forwarded to process_frame.
        frame_index: zero-based index of the frame to seek to.

    Returns:
        (processed_frame, graph) from process_frame, or (None, None) when the
        frame cannot be read or an error occurs.
    """
    cap = None
    try:
        cap = cv2.VideoCapture(input_video)
        cap.set(cv2.CAP_PROP_POS_FRAMES, frame_index)

        ret, frame = cap.read()
        if not ret:
            # Seek past the end of the video, or unreadable file.
            return None, None

        return process_frame(frame, selected_model)
    except Exception as e:
        print(f"An error occurred: {str(e)}")
        return None, None
    finally:
        # Release the capture on every path — the original leaked it when the
        # frame read failed (early return) and on the exception path.
        if cap is not None:
            cap.release()

def process_image(image, model):
    """Run face/emotion detection on a still image from the Gradio widget.

    Args:
        image: numpy array from gr.Image(type="numpy") — RGB channel order.
        model: CLIP model name, forwarded to process_frame.

    Returns:
        (frame_pil, graph): PIL image with the face box drawn, and the
        results chart (or None when no face was found).
    """
    # Guard len(shape) too, so 2-D grayscale input no longer raises IndexError
    # (the original only checked image.shape[2]); matches process_frame's check.
    # NOTE: this swap turns the RGB Gradio array into BGR, and process_frame
    # swaps it back to RGB — the two conversions deliberately cancel.
    if len(image.shape) == 3 and image.shape[2] == 3:
        image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    else:
        image_rgb = image

    frame_rgb, graph = process_frame(image_rgb, model)

    # Normalize float output to uint8 for display.
    if frame_rgb.dtype != np.uint8:
        frame_rgb = (frame_rgb * 255).astype(np.uint8)

    # Hand Gradio a PIL image so colors display correctly.
    frame_pil = Image.fromarray(frame_rgb)

    return frame_pil, graph

# Create the Gradio app using Blocks
# UI layout: two tabs (Video, Image), each with its own model dropdown and
# outputs. Building the Blocks context runs at import time.
with gr.Blocks() as app:
    gr.Markdown("# EmotionTrack (Zero-Shot)")
    
    with gr.Row():
        with gr.Column():
            with gr.Tab("Video"):
                model_dropdown_video = gr.Dropdown(choices=["ViT-B/32", "ViT-B/16", "ViT-L/14"], label="Model", value="ViT-B/16")
                gr.Markdown("Upload a video to detect faces and recognize emotions.")
                video_input = gr.Video()
                output_frame = gr.Image(label="Processed Frame")
                frame_slider = gr.Slider(minimum=0, maximum=100, step=1, label="Frame Index", value=0)
                output_graph = gr.Image(label="Results Graph")

                def update_slider_and_process(video):
                    # On upload: size the slider to the video's frame count and
                    # immediately process frame 0 with the default model.
                    cap = cv2.VideoCapture(video)
                    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
                    cap.release()
                    processed_frame, graph = process_video(video, "ViT-B/16", 0)
                    return gr.update(maximum=total_frames-1), processed_frame, graph

                def update_frame(video, model, frame_idx):
                    # Re-run detection on the currently selected frame/model.
                    processed_frame, graph = process_video(video, model, frame_idx)
                    return processed_frame, graph

                # Re-process on upload, on slider release, and on button click.
                video_input.change(update_slider_and_process, inputs=[video_input], outputs=[frame_slider, output_frame, output_graph])
                frame_slider.release(update_frame, inputs=[video_input, model_dropdown_video, frame_slider], outputs=[output_frame, output_graph])
                
                process_button_video = gr.Button("Process Frame")
                process_button_video.click(update_frame, inputs=[video_input, model_dropdown_video, frame_slider], outputs=[output_frame, output_graph])

            with gr.Tab("Image"):
                gr.Markdown("Upload an image to detect faces and recognize emotions.")
                image_input = gr.Image(type="numpy")
                model_dropdown_image = gr.Dropdown(choices=["ViT-B/32", "ViT-B/16", "ViT-L/14"], label="Model", value="ViT-L/14")
                process_button_image = gr.Button("Process Image")
                
                output_image = gr.Image(type="pil", label="Processed Image")
                output_image_graph = gr.Image(label="Results Graph")
                
                process_button_image.click(process_image, inputs=[image_input, model_dropdown_image], outputs=[output_image, output_image_graph])

# Launch the app with public link enabled
# share=True creates a temporary public gradio.live URL.
app.launch(share=True)