bakhil-aissa commited on
Commit
cd87932
·
verified ·
1 Parent(s): 09af1eb

Upload 2 files

Browse files
Files changed (2) hide show
  1. app_.py +56 -0
  2. requirements.txt +9 -0
app_.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import numpy as np
3
+ import onnxruntime as ort
4
+ from PIL import Image
5
+
6
+ session = ort.InferenceSession(
7
+ "dog_emotion_classifier.onnx",
8
+ providers=["CPUExecutionProvider"]
9
+ )
10
+
11
+
12
+ def normalize_image(image):
13
+
14
+ mean = np.array([0.485, 0.456, 0.406])
15
+ std = np.array([0.229, 0.224, 0.225])
16
+ image[0] = (image[0] - mean[0]) / std[0]
17
+ image[1] = (image[1] - mean[1]) / std[1]
18
+ image[2] = (image[2] - mean[2]) / std[2]
19
+ return image
20
+
21
+ def preprocess_image(image, is_link = False):
22
+
23
+ image = image.resize((288, 288),Image.BICUBIC)
24
+ image = np.array(image).astype(np.float32)
25
+ image = normalize_image(image)
26
+ return image
27
+ def predict(image_path , is_link = False):
28
+ image = preprocess_image(image_path , is_link)
29
+
30
+ image = np.reshape(image , (1, 3, 288, 288))
31
+ input_name = session.get_inputs()[0].name
32
+ outputs = session.run([session.get_outputs()[0].name], {input_name: image})
33
+ predictions = outputs[0]
34
+ probs = np.exp(predictions) / np.sum(np.exp(predictions))
35
+ probs = probs.flatten()
36
+ probs = probs.tolist()
37
+ to_plot = {'angry':probs[0], 'happy':probs[1], 'relaxed':probs[2], 'sad':probs[3]}
38
+
39
+ return to_plot
40
+
41
+
42
+ with gr.Blocks() as demo :
43
+ gr.Markdown("**Dog Emotion Classifier**")
44
+ with gr.Row():
45
+ with gr.Column():
46
+ gr.Markdown("Upload an image of a dog to predict its emotion.")
47
+ dog_image = gr.Image(type="pil")
48
+ button = gr.Button("Predict")
49
+
50
+ with gr.Column():
51
+ gr.Markdown("Results")
52
+ text_out = gr.Label()
53
+ value = button.click(predict, inputs=[dog_image], outputs=text_out)
54
+
55
+
56
+ demo.launch()
requirements.txt ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ gradio==5.42.0
2
+ numpy==1.21.5
3
+ onnxruntime==1.22.0
4
+ pandas==2.3.1
5
+ Pillow==9.0.1
6
+ Pillow==11.3.0
7
+ Requests==2.32.4
8
+ streamlit==1.44.1
9
+ torch==2.0.1