Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -1,70 +1,90 @@
|
|
| 1 |
import gradio as gr
|
| 2 |
from fastai.vision.all import *
|
| 3 |
-
import skimage
|
| 4 |
from facenet_pytorch import MTCNN
|
| 5 |
import torch
|
| 6 |
-
import pandas as pd
|
| 7 |
from PIL import Image, ImageDraw, ImageFont, ImageFilter
|
| 8 |
|
|
|
|
| 9 |
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
|
| 10 |
print('Running on device: {}'.format(device))
|
| 11 |
|
|
|
|
| 12 |
mtcnn = MTCNN(margin=40, keep_all=True, post_process=False, device=device)
|
| 13 |
|
|
|
|
| 14 |
learn = load_learner('export.pkl')
|
| 15 |
|
| 16 |
-
|
| 17 |
-
e_colors = {
|
| 18 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 19 |
|
|
|
|
| 20 |
def predict(img):
|
| 21 |
img = PILImage.create(img)
|
| 22 |
boxes, _ = mtcnn.detect(img)
|
| 23 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 24 |
draw = ImageDraw.Draw(o_img)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 25 |
for box in boxes:
|
| 26 |
-
coords = tuple(box.tolist())
|
| 27 |
-
pred,pred_idx,probs = learn.predict(img.crop(coords))
|
| 28 |
-
|
| 29 |
-
draw.
|
|
|
|
|
|
|
| 30 |
return o_img.filter(ImageFilter.SMOOTH_MORE)
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
|
|
|
| 1 |
import gradio as gr
|
| 2 |
from fastai.vision.all import *
|
|
|
|
| 3 |
from facenet_pytorch import MTCNN
|
| 4 |
import torch
|
|
|
|
| 5 |
from PIL import Image, ImageDraw, ImageFont, ImageFilter
|
| 6 |
|
| 7 |
# Device setup: prefer the first CUDA GPU when available, else CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print('Running on device: {}'.format(device))

# Face detector: MTCNN with a 40px margin around each box; keep_all=True
# returns every face in the frame, post_process=False skips pixel
# normalization (only the bounding boxes are used downstream).
mtcnn = MTCNN(margin=40, keep_all=True, post_process=False, device=device)

# Load model: fastai learner exported as 'export.pkl'
# (assumed to sit next to this script — confirm deployment layout).
learn = load_learner('export.pkl')

# Emotion colors: maps each predicted class label to the rectangle color
# drawn in predict(). Keys are assumed to match the learner's class
# vocabulary exactly — predict() looks labels up in this dict.
e_colors = {
    'confused': 'orange',
    'attentive': 'green',
    'bored': 'red',
    'interested': 'blue',
    'frustrated': 'purple',
    'thoughtful': 'pink'
}
|
| 26 |
|
| 27 |
+
# Prediction function
|
| 28 |
def predict(img):
    """Detect faces in *img* and draw one labeled, colored box per face.

    Args:
        img: anything ``PILImage.create`` accepts (file path, ndarray,
            or PIL image) — Gradio passes a PIL image here.

    Returns:
        The input image unchanged when no face is detected; otherwise a
        white RGBA canvas of the same size with a filled rectangle and
        upper-cased emotion label per face, lightly smoothed.
    """
    img = PILImage.create(img)
    boxes, _ = mtcnn.detect(img)

    if boxes is None:  # no faces detected
        return img

    # Annotations go on a fresh white canvas, not on the input frame.
    o_img = Image.new("RGBA", img.size, color="white")
    draw = ImageDraw.Draw(o_img)

    try:
        font = ImageFont.truetype("arial.ttf", 20)
    except OSError:
        # Font file not present on this host (common on Linux servers);
        # fall back to PIL's built-in bitmap font. Was a bare except,
        # which also swallowed unrelated errors.
        font = ImageFont.load_default()

    for box in boxes:
        # MTCNN returns float coordinates; PIL drawing wants ints.
        coords = tuple(map(int, box.tolist()))
        pred, pred_idx, probs = learn.predict(img.crop(coords))
        # fastai may return a Category (str subclass); normalize so dict
        # lookup and .upper() behave predictably.
        pred = str(pred)

        # Unknown labels fall back to gray instead of raising KeyError,
        # so a vocabulary/e_colors mismatch degrades gracefully.
        color = e_colors.get(pred, 'gray')
        draw.rectangle(coords, fill=color, outline=(0, 0, 0), width=2)
        draw.text((coords[0] + 10, coords[1] + 10), pred.upper(),
                  font=font, fill="white")

    return o_img.filter(ImageFilter.SMOOTH_MORE)
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
# Title + description: the color legend is generated from e_colors so the
# UI legend can never drift out of sync with the colors predict() draws.
# (Previously the six swatch <div>s were hand-duplicated literals.)
title = "Students Emotion Classifier"


def _legend_html():
    """Build the flex-box legend HTML: one color swatch per emotion."""
    items = "\n".join(
        f'<div style="width: 200px; margin: 10px;">\n'
        f'<div style="background-color: {color}; width: 50px; height: 50px;">'
        f'</div><p>{label.capitalize()}</p>\n'
        f'</div>'
        for label, color in e_colors.items()
    )
    return f'\n<div style="display: flex; flex-wrap: wrap;">\n{items}\n</div>\n'


description = _legend_html()
|
| 77 |
+
|
| 78 |
+
# Gradio UI: Blocks layout — title and legend on top, webcam input and
# annotated output side by side, with a button wiring predict() between them.
with gr.Blocks() as demo:
    gr.Markdown(f"# {title}")
    gr.HTML(description)

    with gr.Row():
        # sources=["webcam"] restricts input to a live camera frame;
        # type="pil" hands predict() a PIL image directly.
        webcam = gr.Image(sources=["webcam"], type="pil", height=512, width=512, label="Webcam Input")
        output = gr.Image(type="pil", label="Predicted Output")

    # Classify the current frame on click (no live streaming).
    run_btn = gr.Button("Classify Emotions")
    run_btn.click(fn=predict, inputs=webcam, outputs=output)

demo.launch()
|