File size: 2,482 Bytes
cd9e010
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
import os
import gradio as gr
import numpy as np
from keras.models import load_model
from keras.utils import load_img, img_to_array
from tensorflow.image import resize

from PIL import Image

#gradio interface
def classify_image(image):
    """Predict the emotion shown in ``image`` with the module-level ``model``.

    The image is converted to an array, divided by 255, resized to 48x48 and
    arranged according to ``model.channelno`` before being passed to
    ``model.predictor.predict``. The per-class scores are also printed to
    stdout as percentages.

    Args:
        image: Image accepted by ``img_to_array`` (e.g. a PIL image).
            NOTE(review): the 3-channel path forwards the channels as-is,
            so it presumably requires an RGB input - confirm with caller.

    Returns:
        str: ``'Prediction: <label>, probability: <pp.p>%'`` for the class
        with the highest score.

    Raises:
        ValueError: If ``model.channelno`` is neither 1 nor 3. Previously
            this case fell through and failed with ``UnboundLocalError``.
    """
    # Convert the input to a numpy array and normalize by 255.
    pixels = img_to_array(image) / 255
    resized = resize(pixels, (48, 48)).numpy()

    if model.channelno == 1:
        # Model expects inputs of shape (48,48,1): average over the channel
        # axis, then restore a single trailing channel dimension.
        batch = np.array([resized.mean(axis=2).reshape(48, 48, 1)])
    elif model.channelno == 3:
        # Model expects inputs of shape (48,48,3): only add the batch axis.
        batch = np.expand_dims(resized, axis=0)
    else:
        raise ValueError(
            f'Unsupported number of channels: {model.channelno!r} '
            '(expected 1 or 3)'
        )

    predictions = model.predictor.predict(batch)

    pr_emotion = model.labeldict[predictions.argmax()]
    prob = predictions.max() * 100
    returnstr = f'Prediction: {pr_emotion}, probability: {prob:4.1f}%'

    # Log every class score (in percent) for the single image in the batch.
    predictions_f = [
        '%s:%5.2f' % (model.labeldict[i], p * 100)
        for i, p in enumerate(predictions[0])
    ]
    print(predictions_f)
    return returnstr

class ModelClass:
    """Bundle a saved model with the metadata needed to use it.

    Args:
        name: File name of the saved model inside the ``models`` directory
            (the path is relative to the current working directory).
    """

    def __init__(self, name='EDA_CNN.h5'):
        self.name = name
        # Load the model named by the ``name`` argument. The previous code
        # read the module-level ``modeltouse`` here, so the argument (and its
        # default) was ignored for loading while the label map and channel
        # count below were still derived from ``name``.
        self.predictor = load_model(os.path.join("models", name))

        # Maps the index of a model output to its emotion label.
        if name == "model_mobilenet_oncleandata_valacc078.h5":
            self.labeldict = {0: 'fear', 1: 'Angry', 2: 'Neutral', 3: 'Happy'}
        else:
            self.labeldict = {0: 'Angry', 1: 'Disgust', 2: 'Fear', 3: 'Happy', 4: 'Sad', 5: 'Surprise', 6: 'Neutral'}

        # Number of input channels: 1 for "EDA_CNN.h5", 3 for every other
        # model file.
        self.channelno = 1 if name == "EDA_CNN.h5" else 3

# Model file (looked up inside the "models" directory) served by this app.
# Alternatives that can be swapped in:
#   "EDA_CNN.h5"
#   "model_mobilenet_oncleandata_valacc078.h5"
modeltouse = "MobileNet12blocks_wdgenaug_onrawdata_valacc063.h5"

model = ModelClass(modeltouse)

# Gradio input/output components.
image = gr.inputs.Image(shape=(48, 48))
label = gr.outputs.Label()

# One example image per emotion; the file-name suffix is the channel count
# of the selected model. The 'Disgust' example is currently disabled.
examples = [
    '%s_48_48_%d.png' % (emotion, model.channelno)
    for emotion in ('Happy', 'Neutral', 'Fear', 'Angry', 'Sad', 'Surprise')
]

# Quick manual check without the UI:
# image = Image.open('./Happy_48_48_%d.png'%model.channelno)
# classify_image(image)
intf = gr.Interface(
    fn=classify_image,
    inputs=image,
    outputs=label,
    examples=examples,
)
intf.launch(inline=False)