KMayanja commited on
Commit
d841d2a
·
verified ·
1 Parent(s): 8ce3da0

Upload 2 files

Browse files

### Onnx model
Uploaded `onnx` version of coffee classification model along with python file`onnx_server.py` that runs it in flask

Files changed (2) hide show
  1. coffee_model.onnx +3 -0
  2. onnx_server.py +55 -0
coffee_model.onnx ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:692fcb1f895602df35597f2e935c4c5cdf05bb2c15a62c08e69b0e1c881cfb4b
3
+ size 29999661
onnx_server.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from flask import Flask, request, jsonify, render_template
2
+ from flask_cors import CORS # Import CORS
3
+ import onnxruntime as rt
4
+ import numpy as np
5
+ import cv2
6
+ import io
7
+ from PIL import Image
8
+
9
+ app = Flask(__name__)
10
+
11
+ # Enable CORS for all routes
12
+ CORS(app)
13
+
14
+ # Load ONNX model
15
+ MODEL_PATH = "coffee_model.onnx" # Ensure the correct path
16
+ session = rt.InferenceSession(MODEL_PATH)
17
+
18
+ @app.route("/", methods=["GET"])
19
+ def home():
20
+ context = jsonify({"message": "ONNX Model API is running!"})
21
+ return render_template("home.html")
22
+
23
+ @app.route("/predict", methods=["POST"])
24
+ def predict():
25
+ try:
26
+ if 'file' not in request.files:
27
+ return jsonify({"error": "No file uploaded"}), 400
28
+
29
+ file = request.files['file']
30
+ image = Image.open(io.BytesIO(file.read()))
31
+ image = np.array(image)
32
+ image = cv2.resize(image, (224, 224)) # Resize to model input size
33
+ image = image.astype(np.float32) / 255.0 # Normalize
34
+ image = np.expand_dims(image, axis=0) # Add batch dimension
35
+
36
+ # Get input name from the model
37
+ input_name = session.get_inputs()[0].name
38
+
39
+ # Run model inference
40
+ result = session.run(None, {input_name: image})
41
+ prediction = np.argmax(result[0])
42
+ confidence = np.max(result[0])
43
+
44
+ # Map predictions to class names
45
+ classes = ['Health leaves', 'leaf rust', 'phoma']
46
+ predicted_class = classes[prediction]
47
+
48
+ return jsonify({"class": predicted_class, "confidence": float(confidence)})
49
+
50
+ except Exception as e:
51
+ print(f"Error: {e}") # Log error for debugging
52
+ return jsonify({"error": str(e)}), 500
53
+
54
+ if __name__ == "__main__":
55
+ app.run(host="0.0.0.0", port=5000, debug=True)