File size: 3,992 Bytes
7fcff4a
 
 
 
 
 
 
 
 
 
 
5764d7a
7fcff4a
 
1082fc8
7fcff4a
 
 
 
 
 
 
 
1082fc8
7fcff4a
 
 
 
 
 
 
 
 
 
 
a07581b
7fcff4a
 
a07581b
7fcff4a
 
1082fc8
7fcff4a
 
 
 
 
 
 
 
 
 
 
 
 
 
1082fc8
7fcff4a
 
 
 
 
 
 
 
 
a07581b
7fcff4a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5764d7a
 
 
 
 
 
 
 
 
 
 
 
 
 
7fcff4a
 
 
 
 
 
 
 
 
 
 
1082fc8
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
"""
PV Defect Classification — Flask Demo
Loads the best ONNX model and serves a web interface for
real-time photovoltaic panel defect classification.

"""

import os
import time
import numpy as np
from PIL import Image
from flask import Flask, render_template, request, jsonify, send_file
import onnxruntime as ort

# Configuration
# Directory searched for the *.onnx model file (next to this source file).
MODEL_DIR = os.path.join(os.path.dirname(__file__), "models")
# Model output index -> label; the order must match the model's output order.
CLASS_NAMES = ["DEFECTIVE", "NORMAL"]
# Side length (pixels) of the square image fed to the network.
IMG_SIZE = 224
# Per-channel RGB normalisation statistics. These are the widely used
# ImageNet values; presumably the ones used at training time — confirm.
MEAN = np.asarray([0.485, 0.456, 0.406], dtype=np.float32)
STD = np.asarray([0.229, 0.224, 0.225], dtype=np.float32)

# Flask application object; every route below is registered on it via @app.route.
app = Flask(__name__)

# Load ONNX model 
def find_onnx_model(model_dir=None):
    """Return the path of the ``.onnx`` model file to serve, or ``None``.

    Args:
        model_dir: Directory to search. Defaults to ``MODEL_DIR`` when omitted.

    Returns:
        Full path of the alphabetically first regular file whose name ends
        in ``.onnx`` (case-insensitive), or ``None`` when the directory is
        missing, unreadable, or holds no such file.
    """
    directory = MODEL_DIR if model_dir is None else model_dir
    try:
        # os.listdir() order is arbitrary; sorting makes the pick reproducible.
        entries = sorted(os.listdir(directory))
    except OSError:
        # A missing/unreadable models folder should take the same
        # "no model loaded" path as an empty one instead of crashing at import.
        return None
    for name in entries:
        path = os.path.join(directory, name)
        # Skip directories that merely happen to be named "*.onnx".
        if name.lower().endswith(".onnx") and os.path.isfile(path):
            return path
    return None

model_path = find_onnx_model()
session = None
input_name = None  # bound even when no model is loaded, so the name always exists
if model_path:
    try:
        session = ort.InferenceSession(model_path)
        # Requests are fed through the model's first declared input.
        input_name = session.get_inputs()[0].name
        print(f"[INFO] Model loaded: {os.path.basename(model_path)}")
    except Exception as exc:
        # A corrupt or incompatible model file degrades to the same
        # "model not loaded" state as a missing one instead of crashing the
        # app at import time; request handlers check `session is None`.
        session = None
        input_name = None
        print(f"[ERROR] Failed to load ONNX model '{os.path.basename(model_path)}': {exc}")
else:
    print("[ERROR] No ONNX model file found in MODEL_DIR. Inference endpoint will be unavailable.")


#  Preprocessing
def preprocess(image: Image.Image) -> np.ndarray:
    """Turn a PIL image into a normalised float32 NCHW batch of one.

    The image is forced to RGB and resized to ``IMG_SIZE`` x ``IMG_SIZE``.
    Pixels are scaled to [0, 1] and standardised with the per-channel
    ``MEAN``/``STD``. The result is reordered from HWC to CHW with a leading
    batch axis of size 1.
    """
    rgb = image.convert("RGB")
    resized = rgb.resize((IMG_SIZE, IMG_SIZE))
    pixels = np.array(resized, dtype=np.float32) / 255.0   # (H, W, 3), values in [0, 1]
    standardised = (pixels - MEAN) / STD                   # broadcast over the channel axis
    chw = standardised.transpose(2, 0, 1)                  # (3, H, W)
    return np.expand_dims(chw, axis=0)                     # (1, 3, H, W)


def softmax(x, axis=None):
    """Return the softmax of ``x``, normalised along ``axis``.

    The maximum is subtracted before exponentiating so large scores do not
    overflow.

    Args:
        x: Array-like of scores (logits).
        axis: Axis to normalise over. ``None`` (the default, and the original
            behaviour) normalises over all elements, which is what a single
            1-D logit vector needs. Pass ``axis=-1`` to normalise each row of
            a batch independently.

    Returns:
        ``np.ndarray`` with the same shape as ``x`` whose entries along
        ``axis`` sum to 1.
    """
    x = np.asarray(x)
    # keepdims=True keeps the reductions broadcastable for any axis choice.
    e = np.exp(x - np.max(x, axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)


# Routes 
@app.route("/")
def index():
    """Serve the demo page, labelled with the active model's file name."""
    if not model_path:
        return render_template("index.html", model_name="No model loaded")
    return render_template("index.html", model_name=os.path.basename(model_path))


@app.route("/predict", methods=["POST"])
def predict():
    """Classify an uploaded image and return the result as JSON.

    Expects a multipart/form-data POST with the image in the ``file`` field.

    Returns:
        On success, JSON with these fields, all rounded to one decimal:

        - ``prediction``: the class name.
        - ``confidence``: percent.
        - ``latency_ms``: time spent in ``session.run`` only.
        - ``probabilities``: percent per class.

        On failure, JSON ``{"error": ...}`` with status 400 for a missing,
        unnamed or undecodable upload. Status 500 is returned when the model
        is not loaded or preprocessing/inference fails.
    """
    if session is None:
        return jsonify({"error": "Inference service unavailable: model not loaded."}), 500

    if "file" not in request.files:
        return jsonify({"error": "No file uploaded."}), 400

    file = request.files["file"]
    # `not` also rejects a missing (None) filename, not just "".
    if not file.filename:
        return jsonify({"error": "Empty filename."}), 400

    # Decode separately so a bad upload is reported as a client error (400)
    # rather than lumped in with server-side failures.
    try:
        with Image.open(file.stream) as image:
            tensor = preprocess(image)
    except (OSError, Image.DecompressionBombError):
        # OSError covers PIL.UnidentifiedImageError and truncated image data.
        return jsonify({"error": "Invalid or unreadable image file."}), 400
    except Exception as e:
        app.logger.exception("Image preprocessing failed")
        return jsonify({"error": str(e)}), 500

    try:
        # perf_counter() is monotonic and high-resolution; time.time() can
        # jump with wall-clock adjustments.
        t0 = time.perf_counter()
        outputs = session.run(None, {input_name: tensor})
        latency_ms = (time.perf_counter() - t0) * 1000

        # First output, first batch item; treated as one logit per class.
        probs = softmax(outputs[0][0])
        if len(probs) != len(CLASS_NAMES):
            raise ValueError(
                f"Model returned {len(probs)} scores but {len(CLASS_NAMES)} "
                "class names are configured."
            )
        pred_idx = int(np.argmax(probs))
        confidence = float(probs[pred_idx]) * 100
    except Exception as e:
        app.logger.exception("Inference failed")
        return jsonify({"error": str(e)}), 500

    return jsonify({
        "prediction": CLASS_NAMES[pred_idx],
        "confidence": round(confidence, 1),
        "latency_ms": round(latency_ms, 1),
        "probabilities": {
            name: round(float(p) * 100, 1)
            for name, p in zip(CLASS_NAMES, probs)
        },
    })


SAMPLES_DIR = os.path.join(os.path.dirname(__file__), "test_images")

@app.route("/sample/<label>")
def sample(label):
    """Return the bundled demo image for ``label`` ("defective" or "normal").

    A ``.jpg`` file is preferred over a ``.png`` when both exist. A JSON 404
    is returned for an unknown label or when no sample file is present.
    """
    if label in ("defective", "normal"):
        candidates = (("jpg", "image/jpeg"), ("png", "image/png"))
        for extension, mimetype in candidates:
            candidate = os.path.join(SAMPLES_DIR, f"{label}.{extension}")
            if os.path.exists(candidate):
                return send_file(candidate, mimetype=mimetype)
        return jsonify({"error": f"No sample found for '{label}'"}), 404
    return jsonify({"error": "Unknown label"}), 404


@app.route("/health")
def health():
    """Report liveness plus whether (and which) ONNX model is loaded.

    Intended as a lightweight probe for cloud deployment health checks.
    """
    model_file = None
    if model_path:
        model_file = os.path.basename(model_path)
    payload = {
        "status": "ok",
        "model_loaded": session is not None,
        "model_file": model_file,
    }
    return jsonify(payload)


if __name__ == "__main__":
    # Never hard-code debug=True while listening on 0.0.0.0: the Werkzeug
    # debugger allows arbitrary code execution from the browser. Leaving
    # `debug` unset lets Flask's app.run() honour the FLASK_DEBUG environment
    # variable, so local development can still opt in with FLASK_DEBUG=1.
    # PORT lets hosting platforms choose the port; 5000 remains the default.
    port = int(os.environ.get("PORT", "5000"))
    app.run(host="0.0.0.0", port=port)