mramjad commited on
Commit
ff588f1
·
verified ·
1 Parent(s): efafa2f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +38 -38
app.py CHANGED
@@ -1,39 +1,39 @@
1
- import gradio as gr
2
- import numpy as np
3
- from PIL import Image
4
- import onnxruntime as ort
5
-
6
- # Load the ONNX model
7
- model_path = "cifar10_model.onnx"
8
- ort_session = ort.InferenceSession(model_path)
9
-
10
- # CIFAR-10 class labels
11
- labels = [
12
- "airplane", "automobile", "bird", "cat", "deer",
13
- "dog", "frog", "horse", "ship", "truck"
14
- ]
15
-
16
- def preprocess_image(image):
17
- # Resize to 32x32 and normalize
18
- image = image.resize((32, 32))
19
- image = np.array(image).astype(np.float32) / 255.0
20
- # Reshape to (1, 3, 32, 32) [batch, channels, height, width]
21
- return np.expand_dims(image.transpose(2, 0, 1), axis=0)
22
-
23
- def predict(image):
24
- # Preprocess the image
25
- input_data = preprocess_image(image)
26
- # Run inference
27
- outputs = ort_session.run(None, {"serving_default_keras_tensor:0": input_data})[0]
28
- predicted_class_idx = np.argmax(outputs)
29
- return labels[predicted_class_idx]
30
-
31
- # Gradio Interface
32
- gr.Interface(
33
- fn=predict,
34
- inputs=gr.Image(type="pil"),
35
- outputs=gr.Label(num_top_classes=3),
36
- title="CIFAR-10 Classifier",
37
- description="Upload an image to classify it into CIFAR-10 categories.",
38
- allow_flagging="never" # Disable flagging
39
  ).launch(server_name="0.0.0.0", server_port=7860)
 
1
+ import gradio as gr
2
+ import numpy as np
3
+ from PIL import Image
4
+ import onnxruntime as ort
5
+
6
+ # Load the ONNX model
7
+ model_path = "cifar10_model.onnx"
8
+ ort_session = ort.InferenceSession(model_path)
9
+
10
+ # CIFAR-10 class labels
11
+ labels = [
12
+ "airplane", "automobile", "bird", "cat", "deer",
13
+ "dog", "frog", "horse", "ship", "truck"
14
+ ]
15
+
16
+ def preprocess_image(image):
17
+ # Resize to 32x32 and normalize
18
+ image = image.resize((32, 32))
19
+ image = np.array(image).astype(np.float32) / 255.0
20
+ # Reshape to (1, 3, 32, 32) [batch, channels, height, width]
21
+ return np.expand_dims(image.transpose(2, 0, 1), axis=0)
22
+
23
+ def predict(image):
24
+ # Preprocess the image
25
+ input_data = preprocess_image(image)
26
+ # Run inference
27
+ outputs = ort_session.run(None, {"serving_default_keras_tensor:0": input_data})[0]
28
+ predicted_class_idx = np.argmax(outputs)
29
+ return labels[predicted_class_idx]
30
+
31
+ # Gradio Interface
32
+ gr.Interface(
33
+ fn=predict,
34
+ inputs=gr.Image(type="pil"),
35
+ outputs=gr.Label(num_top_classes=3),
36
+ title="CIFAR-10 Classifier",
37
+ description="Upload an image to classify it into CIFAR-10 categories.",
38
+ allow_flagging="never" # Disable flagging
39
  ).launch(server_name="0.0.0.0", server_port=7860)