import os

import gradio as gr
import numpy as np
from keras.models import load_model
from keras.utils import load_img, img_to_array
from tensorflow.image import resize
from PIL import Image

# Spatial size (height, width) images are resized to before prediction.
INPUT_SIZE = (48, 48)


# gradio interface
def classify_image(image):
    """Classify the emotion shown in *image* using the global ``model``.

    Args:
        image: The image delivered by the Gradio ``Image`` input component;
            any value ``keras.utils.img_to_array`` accepts (e.g. a PIL image
            or a NumPy array).

    Returns:
        A string of the form ``'Prediction: <label>, probability: <p>%'``
        for the highest-scoring class.

    Raises:
        ValueError: If ``model.channelno`` is neither 1 nor 3.
    """
    # Convert to a float array and scale pixel values by 1/255.
    input_arr = img_to_array(image) / 255
    input_arr_resh = resize(input_arr, INPUT_SIZE).numpy()

    if model.channelno == 1:
        # Model expects inputs of shape (48, 48, 1): average the channels.
        model_input = input_arr_resh.mean(axis=2).reshape(*INPUT_SIZE, 1)
    elif model.channelno == 3:
        # Model expects inputs of shape (48, 48, 3): use the array as-is.
        model_input = input_arr_resh
    else:
        # Without this, `predictions` would be unbound below.
        raise ValueError(
            f"Unsupported channel count {model.channelno!r}; expected 1 or 3"
        )

    # Keras predicts on batches: add a leading batch axis of size 1.
    predictions = model.predictor.predict(np.expand_dims(model_input, axis=0))

    pr_emotion = model.labeldict[predictions.argmax()]
    prob = predictions.max() * 100
    returnstr = f'Prediction: {pr_emotion}, probability: {prob:4.1f}%'

    # Log the full per-class distribution (in percent) for debugging.
    predictions_f = ['%s:%5.2f' % (model.labeldict[i], p * 100)
                     for i, p in enumerate(predictions[0])]
    print(predictions_f)
    return returnstr


class ModelClass:
    """A Keras emotion classifier together with its label map and input channels.

    Attributes:
        name: File name of the model inside the ``models`` directory.
        predictor: The Keras model loaded from ``models/<name>``.
        labeldict: Maps an output index of ``predictor`` to an emotion label.
        channelno: Number of input channels the model expects (1 or 3).
    """

    def __init__(self, name='EDA_CNN.h5'):
        self.name = name
        # Load the file this instance was asked for, so the weights always
        # agree with the label map and channel count chosen below.
        self.predictor = load_model(os.path.join("models", name))

        if name == "model_mobilenet_oncleandata_valacc078.h5":
            # This model was trained on a 4-class subset.
            self.labeldict = {0: 'fear', 1: 'Angry', 2: 'Neutral', 3: 'Happy'}
        else:
            self.labeldict = {0: 'Angry', 1: 'Disgust', 2: 'Fear', 3: 'Happy',
                              4: 'Sad', 5: 'Surprise', 6: 'Neutral'}

        # Only the EDA CNN takes grayscale input; the others take RGB.
        self.channelno = 1 if name == "EDA_CNN.h5" else 3


# Model file (inside ./models) served by the app. Other available options:
#   "EDA_CNN.h5"                                (grayscale input, 7 classes)
#   "model_mobilenet_oncleandata_valacc078.h5"  (RGB input, 4 classes)
modeltouse = "MobileNet12blocks_wdgenaug_onrawdata_valacc063.h5"
model = ModelClass(modeltouse)

image = gr.inputs.Image(shape=INPUT_SIZE)
label = gr.outputs.Label()

# Example images matching the selected model's channel count
# ('Disgust' is intentionally left out).
example_emotions = ['Happy', 'Neutral', 'Fear', 'Angry', 'Sad', 'Surprise']
examples = ['%s_48_48_%d.png' % (emotion, model.channelno)
            for emotion in example_emotions]

intf = gr.Interface(fn=classify_image, inputs=image, outputs=label,
                    examples=examples)
intf.launch(inline=False)