# Hugging Face Space (status residue from page scrape: "Spaces: Sleeping")
import gradio as gr
from fastai.vision.all import *
from facenet_pytorch import MTCNN
import torch
from PIL import Image, ImageDraw, ImageFont, ImageFilter
# --- Module-level initialization -----------------------------------------
# Prefer the first CUDA GPU when present, otherwise run on the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print('Running on device: {}'.format(device))

# MTCNN face detector; keep_all=True returns every face found in the frame,
# post_process=False leaves pixel values untouched for the classifier crop.
mtcnn = MTCNN(margin=40, keep_all=True, post_process=False, device=device)

# fastai learner previously exported with learn.export().
learn = load_learner('export.pkl')

# Colour drawn for each emotion label on the output canvas.
_emotions = ('confused', 'attentive', 'bored', 'interested', 'frustrated', 'thoughtful')
_palette = ('orange', 'green', 'red', 'blue', 'purple', 'pink')
e_colors = dict(zip(_emotions, _palette))
# Prediction function
def predict(img):
    """Detect faces in *img* and draw one colour-coded emotion box per face.

    Args:
        img: input image (anything ``PILImage.create`` accepts, e.g. the PIL
            image delivered by the Gradio webcam component).

    Returns:
        The input image unchanged when no face is detected; otherwise a new
        RGBA canvas with a filled rectangle per face, coloured by the
        predicted emotion and labelled with the emotion name.
    """
    img = PILImage.create(img)
    boxes, _ = mtcnn.detect(img)
    if boxes is None:  # no faces detected
        return img

    # Draw onto a fresh white canvas rather than the source photo.
    o_img = Image.new("RGBA", img.size, color="white")
    draw = ImageDraw.Draw(o_img)
    try:
        font = ImageFont.truetype("arial.ttf", 20)
    except OSError:  # font file missing on this host -> built-in bitmap font
        font = ImageFont.load_default()

    for box in boxes:
        coords = tuple(map(int, box.tolist()))
        # Classify the face crop; pred is the vocab label for the top class.
        pred, pred_idx, probs = learn.predict(img.crop(coords))
        # str() normalises fastai's Category; unknown labels fall back to
        # gray instead of raising KeyError and killing the whole request.
        color = e_colors.get(str(pred), "gray")
        draw.rectangle(coords, fill=color, outline=(0, 0, 0), width=2)
        draw.text((coords[0] + 10, coords[1] + 10), str(pred).upper(),
                  font=font, fill="white")
    return o_img.filter(ImageFilter.SMOOTH_MORE)
# Title + description
# App heading shown above the UI.
title = "Students Emotion Classifier"
# HTML colour legend rendered under the heading.
# NOTE(review): the swatch colours here must stay in sync with the
# e_colors dict defined above — they are maintained by hand.
description = """
<div style="display: flex; flex-wrap: wrap;">
<div style="width: 200px; margin: 10px;">
<div style="background-color: orange; width: 50px; height: 50px;"></div><p>Confused</p>
</div>
<div style="width: 200px; margin: 10px;">
<div style="background-color: green; width: 50px; height: 50px;"></div><p>Attentive</p>
</div>
<div style="width: 200px; margin: 10px;">
<div style="background-color: red; width: 50px; height: 50px;"></div><p>Bored</p>
</div>
<div style="width: 200px; margin: 10px;">
<div style="background-color: blue; width: 50px; height: 50px;"></div><p>Interested</p>
</div>
<div style="width: 200px; margin: 10px;">
<div style="background-color: purple; width: 50px; height: 50px;"></div><p>Frustrated</p>
</div>
<div style="width: 200px; margin: 10px;">
<div style="background-color: pink; width: 50px; height: 50px;"></div><p>Thoughtful</p>
</div>
</div>
"""
# Gradio modern UI: webcam input on the left, annotated output on the right.
with gr.Blocks() as demo:
    gr.Markdown(f"# {title}")
    gr.HTML(description)
    with gr.Row():
        webcam = gr.Image(sources=["webcam"], type="pil", height=512, width=512, label="Webcam Input")
        output = gr.Image(type="pil", label="Predicted Output")
    run_btn = gr.Button("Classify Emotions")
    # Run the classifier on the current webcam frame when clicked.
    run_btn.click(fn=predict, inputs=webcam, outputs=output)

# Launch only when executed as a script (Spaces runs this file as __main__),
# so the module can be imported without starting a server.
if __name__ == "__main__":
    demo.launch()