# Classifymoods / app.py
# OnurKerimoglu, commit cd9e010: "first working version of the gradio app"
import os
import gradio as gr
import numpy as np
from keras.models import load_model
from keras.utils import load_img, img_to_array
from tensorflow.image import resize
from PIL import Image
# Gradio interface callback
def classify_image(image):
    """Classify the facial emotion in ``image`` using the module-level ``model``.

    The image is converted to a float array scaled by 1/255 and resized to
    48x48. It is then adapted to the channel count the model expects
    (``model.channelno``):
    - for 1-channel models, the colour channels are averaged into one grey channel;
    - for 3-channel models, the array is passed through unchanged.

    Args:
        image: Image delivered by the Gradio input component. It may be anything
            ``keras.utils.img_to_array`` accepts (e.g. a numpy array or PIL image).

    Returns:
        A string ``'Prediction: <label>, probability: <p>%'`` for the
        highest-scoring class. The full per-class breakdown is printed to stdout.

    Raises:
        ValueError: if ``model.channelno`` is neither 1 nor 3. Previously this
            case surfaced as an UnboundLocalError on ``predictions``.
    """
    # Convert to a float numpy array and scale; assumes 8-bit pixel values -- TODO confirm.
    input_arr = img_to_array(image) / 255
    input_arr_resh = resize(input_arr, (48, 48)).numpy()
    if model.channelno == 1:
        # Model expects inputs of shape (48, 48, 1): average the colour channels.
        sample = input_arr_resh.mean(axis=2).reshape(48, 48, 1)
    elif model.channelno == 3:
        # Model expects inputs of shape (48, 48, 3).
        sample = input_arr_resh
    else:
        raise ValueError(
            f'Unsupported model.channelno: {model.channelno!r} (expected 1 or 3)'
        )
    # Add the batch axis -> (1, 48, 48, C) and run a single prediction.
    predictions = model.predictor.predict(np.expand_dims(sample, axis=0))
    scores = predictions[0]  # scores of the single sample in the batch
    pr_emotion = model.labeldict[scores.argmax()]
    prob = scores.max() * 100
    returnstr = f'Prediction: {pr_emotion}, probability: {prob:4.1f}%'
    # Debug output: every class with its score in percent.
    predictions_f = ['%s:%5.2f' % (model.labeldict[i], p * 100) for i, p in enumerate(scores)]
    print(predictions_f)
    return returnstr
class ModelClass:
    """A saved Keras emotion classifier together with its label map and input channel count.

    Attributes:
        name: File name of the saved model that was loaded.
        predictor: The model object returned by ``keras.models.load_model``.
        labeldict: Maps an output index of the model to an emotion label.
        channelno: Number of input channels the model expects (1 or 3).
    """

    def __init__(self, name='EDA_CNN.h5', model_dir='models'):
        """Load ``<model_dir>/<name>`` and select the matching label map and channel count.

        Args:
            name: File name of the saved Keras model.
            model_dir: Directory holding the model file. Defaults to "models",
                the directory the app has always used.
        """
        self.name = name
        # Load the file named by ``name`` (not the module-level ``modeltouse``),
        # so the label map and channel count below describe the model actually loaded.
        self.predictor = load_model(os.path.join(model_dir, name))
        if name == "model_mobilenet_oncleandata_valacc078.h5":
            # 4-class model.
            # NOTE(review): 'fear' is lowercase unlike the other labels -- confirm intended.
            self.labeldict = {0: 'fear', 1: 'Angry', 2: 'Neutral', 3: 'Happy'}
        else:
            # 7-class models.
            self.labeldict = {0: 'Angry', 1: 'Disgust', 2: 'Fear', 3: 'Happy', 4: 'Sad', 5: 'Surprise', 6: 'Neutral'}
        # Only EDA_CNN.h5 takes single-channel input; every other model takes 3 channels.
        self.channelno = 1 if name == "EDA_CNN.h5" else 3
# Model selection. Other model files referenced by this app:
#   "EDA_CNN.h5"                                (1 input channel, 7 labels)
#   "model_mobilenet_oncleandata_valacc078.h5"  (3 input channels, 4 labels)
modeltouse = "MobileNet12blocks_wdgenaug_onrawdata_valacc063.h5"
model = ModelClass(modeltouse)

# Gradio components: a 48x48 image in, a label out.
image = gr.inputs.Image(shape=(48,48))
label = gr.outputs.Label()

# Example images are expected in the working directory as
# "<Emotion>_48_48_<channelno>.png". The variant matching the loaded model's
# channel count is offered. (No 'Disgust' example is included.)
examples = [
    f'{emotion}_48_48_{model.channelno}.png'
    for emotion in ('Happy', 'Neutral', 'Fear', 'Angry', 'Sad', 'Surprise')
]

# Quick check without the UI:
#   classify_image(Image.open('./Happy_48_48_%d.png' % model.channelno))

intf = gr.Interface(fn=classify_image, inputs=image, outputs=label, examples=examples)
intf.launch(inline=False)