sandygmaharaj committed on
Commit
03ff362
·
verified ·
1 Parent(s): 1872c94

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +70 -50
app.py CHANGED
@@ -1,70 +1,90 @@
1
  import gradio as gr
2
  from fastai.vision.all import *
3
- import skimage
4
  from facenet_pytorch import MTCNN
5
  import torch
6
- import pandas as pd
7
  from PIL import Image, ImageDraw, ImageFont, ImageFilter
8
 
 
9
  device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
10
  print('Running on device: {}'.format(device))
11
 
 
12
  mtcnn = MTCNN(margin=40, keep_all=True, post_process=False, device=device)
13
 
 
14
  learn = load_learner('export.pkl')
15
 
16
- labels = learn.dls.vocab
17
- e_colors = {'confused':'orange', 'attentive':'green','bored':'red','interested':'blue','frustrated':'purple','thoughtful':'pink'}
18
- emotions = {"x": [], "y": [], "State": []}
 
 
 
 
 
 
19
 
 
20
  def predict(img):
21
  img = PILImage.create(img)
22
  boxes, _ = mtcnn.detect(img)
23
- o_img = Image.new("RGBA", img.size, color = "white")
 
 
 
 
24
  draw = ImageDraw.Draw(o_img)
 
 
 
 
 
 
25
  for box in boxes:
26
- coords = tuple(box.tolist())
27
- pred,pred_idx,probs = learn.predict(img.crop(coords))
28
- draw.rectangle(coords, fill=e_colors[pred], outline=(0, 0, 0), width=1)
29
- draw.text((coords[0]+10, coords[1]+10), pred[0].upper(), font=ImageFont.truetype("arial.ttf"), fill="#FFF")
 
 
30
  return o_img.filter(ImageFilter.SMOOTH_MORE)
31
-
32
- title = "Students emotion classifer"
33
- description = """<!DOCTYPE html>
34
- <html>
35
- <body>
36
- <div style="display: flex; flex-wrap: wrap;">
37
- <div style="width: 200px; margin: 10px;">
38
- <div style="background-color: orange; width: 50px; height: 50px;"></div>
39
- <p>Confused</p>
40
- </div>
41
- <div style="width: 200px; margin: 10px;">
42
- <div style="background-color: green; width: 50px; height: 50px;"></div>
43
- <p>Attentive</p>
44
- </div>
45
- <div style="width: 200px; margin: 10px;">
46
- <div style="background-color: red; width: 50px; height: 50px;"></div>
47
- <p>Bored</p>
48
- </div>
49
- <div style="width: 200px; margin: 10px;">
50
- <div style="background-color: blue; width: 50px; height: 50px;"></div>
51
- <p>Interested</p>
52
- </div>
53
- <div style="width: 200px; margin: 10px;">
54
- <div style="background-color: purple; width: 50px; height: 50px;"></div>
55
- <p>Frustrated</p>
56
- </div>
57
- <div style="width: 200px; margin: 10px;">
58
- <div style="background-color: pink; width: 50px; height: 50px;"></div>
59
- <p>Thoughtful</p>
60
- </div>
61
- </div>
62
- </body>
63
- </html>"""
64
-
65
- enable_queue=True
66
-
67
- #gr.Interface(fn=predict,inputs=gr.Image(source="webcam",shape=(512, 512)),outputs=gr.outputs.Label(num_top_classes=3),title=title,description=description,interpretation=interpretation,enable_queue=enable_queue).launch()
68
- #gr.Interface(fn=predict,inputs=gr.Image(source="webcam",shape=(512, 512)),outputs=gr.ScatterPlot(),title=title,description=description,interpretation=interpretation,enable_queue=enable_queue, share=True).launch()
69
- gr.Interface(fn=predict,inputs=gr.Image(sources=["webcam"],height=512, width=512),outputs=gr.Image(),title=title,description=description).launch()
70
- #gr.Interface(fn=predict,inputs=gr.Image(shape=(512, 512)),outputs=gr.Image(),title=title,description=description,enable_queue=enable_queue).launch()
 
1
  import gradio as gr
2
  from fastai.vision.all import *
 
3
  from facenet_pytorch import MTCNN
4
  import torch
 
5
  from PIL import Image, ImageDraw, ImageFont, ImageFilter
6
 
7
# ----- Runtime setup -----
# Prefer the first CUDA device when one is available; fall back to CPU.
use_cuda = torch.cuda.is_available()
device = torch.device('cuda:0' if use_cuda else 'cpu')
print('Running on device: {}'.format(device))

# Face detector: keep every detected face, no post-processing of crops.
mtcnn = MTCNN(margin=40, keep_all=True, post_process=False, device=device)

# Restore the exported fastai emotion classifier from disk.
learn = load_learner('export.pkl')
16
 
17
# Map each predicted emotion label to the color used for its face box.
e_colors = dict(
    confused='orange',
    attentive='green',
    bored='red',
    interested='blue',
    frustrated='purple',
    thoughtful='pink',
)
26
 
27
# Prediction function
def predict(img):
    """Detect faces in *img*, classify each face's emotion, and return an
    image with one color-coded box per face.

    The output is a fresh white canvas (the input photo is deliberately not
    shown): each detected face becomes a rectangle filled with the color
    mapped to its predicted emotion in ``e_colors``, labeled with the
    upper-cased emotion name.

    Parameters
    ----------
    img : anything accepted by fastai's ``PILImage.create``
        (ndarray from the Gradio webcam, a path, or a PIL image).

    Returns
    -------
    PIL.Image.Image
        Smoothed canvas with the emotion boxes, or the input image
        unchanged when no face is detected.
    """
    img = PILImage.create(img)
    boxes, _ = mtcnn.detect(img)

    if boxes is None:  # MTCNN returns None (not []) when no faces are found
        return img

    # Blank canvas; only the emotion boxes are rendered on it.
    o_img = Image.new("RGBA", img.size, color="white")
    draw = ImageDraw.Draw(o_img)

    # Arial is typically absent on Linux hosts (e.g. HF Spaces); fall back
    # to Pillow's built-in bitmap font. Catch only OSError — the error
    # truetype() raises for a missing/unreadable font — instead of a bare
    # except that would also swallow KeyboardInterrupt and real bugs.
    try:
        font = ImageFont.truetype("arial.ttf", 20)
    except OSError:
        font = ImageFont.load_default()

    for box in boxes:
        # MTCNN yields float coordinates; Pillow's crop/draw want ints.
        coords = tuple(map(int, box.tolist()))
        pred, pred_idx, probs = learn.predict(img.crop(coords))

        draw.rectangle(coords, fill=e_colors[pred], outline=(0, 0, 0), width=2)
        draw.text((coords[0] + 10, coords[1] + 10), pred.upper(), font=font, fill="white")

    return o_img.filter(ImageFilter.SMOOTH_MORE)
51
+
52
+
53
# ----- UI text -----
title = "Students Emotion Classifier"

# Legend shown above the interface: one color swatch per emotion, matching
# the box colors the classifier draws.
_legend = [
    ('orange', 'Confused'),
    ('green', 'Attentive'),
    ('red', 'Bored'),
    ('blue', 'Interested'),
    ('purple', 'Frustrated'),
    ('pink', 'Thoughtful'),
]
description = (
    '<div style="display: flex; flex-wrap: wrap;">'
    + ''.join(
        '<div style="width: 200px; margin: 10px;">'
        '<div style="background-color: {}; width: 50px; height: 50px;"></div>'
        '<p>{}</p>'
        '</div>'.format(color, label)
        for color, label in _legend
    )
    + '</div>'
)
77
+
78
# ----- Gradio UI (Blocks layout) -----
with gr.Blocks() as demo:
    gr.Markdown("# " + title)      # page heading
    gr.HTML(description)           # color legend for the emotion boxes

    # Webcam capture on the left, annotated result on the right.
    with gr.Row():
        cam_input = gr.Image(sources=["webcam"], type="pil", height=512, width=512, label="Webcam Input")
        result_img = gr.Image(type="pil", label="Predicted Output")

    classify_btn = gr.Button("Classify Emotions")
    classify_btn.click(fn=predict, inputs=cam_input, outputs=result_img)

demo.launch()