pv-classifier / app.py
AsamiYukiko
refactor: replace tutorial-style messages with production-grade log output
a07581b
"""
PV Defect Classification — Flask Demo
Loads the best ONNX model and serves a web interface for
real-time photovoltaic panel defect classification.
"""
import os
import time
import numpy as np
from PIL import Image
from flask import Flask, render_template, request, jsonify, send_file
import onnxruntime as ort
# Config
MODEL_DIR = os.path.join(os.path.dirname(__file__), "models")
# Index order must match the model's output logits: index 0 -> DEFECTIVE,
# index 1 -> NORMAL.  NOTE(review): confirm against the training label map.
CLASS_NAMES = ["DEFECTIVE", "NORMAL"]
IMG_SIZE = 224  # model input resolution (IMG_SIZE x IMG_SIZE)
# ImageNet channel-wise normalisation statistics (RGB order), applied in preprocess().
MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)
app = Flask(__name__)
# Load ONNX model
def find_onnx_model(model_dir=None):
    """Return the path of the first .onnx file found in *model_dir*.

    Args:
        model_dir: Directory to scan. Defaults to the module-level
            MODEL_DIR when None (preserves the original call signature).

    Returns:
        Path to the alphabetically first ``.onnx`` file, or None when the
        directory does not exist or contains no ONNX model.
    """
    directory = MODEL_DIR if model_dir is None else model_dir
    # Guard: os.listdir raises FileNotFoundError on a missing directory,
    # which would crash the app at import time on a fresh checkout.
    if not os.path.isdir(directory):
        return None
    # Sort so repeated startups pick the same model regardless of the
    # filesystem's listing order (os.listdir order is arbitrary).
    for fname in sorted(os.listdir(directory)):
        if fname.endswith(".onnx"):
            return os.path.join(directory, fname)
    return None
# Resolve and load the serving model once, at import time.
model_path = find_onnx_model()
if model_path is None:
    session = None
    print("[ERROR] No ONNX model file found in MODEL_DIR. Inference endpoint will be unavailable.")
else:
    session = ort.InferenceSession(model_path)
    # Cache the graph's input tensor name for the /predict hot path.
    input_name = session.get_inputs()[0].name
    print(f"[INFO] Model loaded: {os.path.basename(model_path)}")
# Preprocessing
def preprocess(image: Image.Image) -> np.ndarray:
    """Convert a PIL image into the float32 tensor the ONNX model expects.

    Pipeline: force RGB, resize to IMG_SIZE x IMG_SIZE, scale pixel values
    to [0, 1], apply the channel-wise MEAN/STD normalisation, then move
    channels first and prepend a batch axis.
    """
    rgb = image.convert("RGB").resize((IMG_SIZE, IMG_SIZE))
    pixels = np.asarray(rgb, dtype=np.float32) / 255.0   # [H, W, 3], in [0, 1]
    normalised = (pixels - MEAN) / STD
    chw = np.transpose(normalised, (2, 0, 1))            # [3, H, W]
    return np.expand_dims(chw, axis=0)                   # [1, 3, H, W]
def softmax(x, axis=None):
    """Numerically stable softmax.

    Subtracting the maximum before exponentiating prevents overflow for
    large logits without changing the result (softmax is shift-invariant).

    Args:
        x: Array of logits.
        axis: Axis along which to normalise. The default None normalises
            over the whole array — identical to the original behaviour and
            correct for the 1-D logit vectors used in /predict. Pass an
            axis (e.g. -1) to handle batched logits.

    Returns:
        Array of the same shape whose values sum to 1 along ``axis``.
    """
    shifted = np.exp(x - np.max(x, axis=axis, keepdims=True))
    return shifted / shifted.sum(axis=axis, keepdims=True)
# Routes
@app.route("/")
def index():
    """Render the landing page, displaying which model (if any) is active."""
    if model_path:
        model_name = os.path.basename(model_path)
    else:
        model_name = "No model loaded"
    return render_template("index.html", model_name=model_name)
@app.route("/predict", methods=["POST"])
def predict():
    """Classify an uploaded image via the loaded ONNX session.

    Expects a multipart form with a ``file`` field. Responds with JSON
    containing the predicted class name, confidence (percent), per-class
    probabilities (percent), and inference latency in milliseconds.

    Error responses:
        400 — no file field, empty filename, or undecodable image.
        500 — model not loaded, or an unexpected inference failure.
    """
    # Local import so the module-level import block stays unchanged.
    from PIL import UnidentifiedImageError

    if session is None:
        return jsonify({"error": "Inference service unavailable: model not loaded."}), 500
    if "file" not in request.files:
        return jsonify({"error": "No file uploaded."}), 400
    upload = request.files["file"]
    if upload.filename == "":
        return jsonify({"error": "Empty filename."}), 400
    try:
        image = Image.open(upload.stream)
        tensor = preprocess(image)
        # Inference with timing
        t0 = time.time()
        outputs = session.run(None, {input_name: tensor})
        latency_ms = (time.time() - t0) * 1000
        logits = outputs[0][0]
        probs = softmax(logits)
        pred_idx = int(np.argmax(probs))
        confidence = float(probs[pred_idx]) * 100
        return jsonify({
            "prediction": CLASS_NAMES[pred_idx],
            "confidence": round(confidence, 1),
            "latency_ms": round(latency_ms, 1),
            "probabilities": {
                CLASS_NAMES[i]: round(float(probs[i]) * 100, 1)
                for i in range(len(CLASS_NAMES))
            }
        })
    except UnidentifiedImageError:
        # A non-image upload is a client error, not a server fault —
        # previously this surfaced as a 500 with a raw exception string.
        return jsonify({"error": "Uploaded file is not a valid image."}), 400
    except Exception as e:
        # Boundary handler: report unexpected failures as JSON instead of
        # letting Flask emit an HTML error page.
        return jsonify({"error": str(e)}), 500
# Directory of pre-stored demo images served by the /sample/<label> route.
SAMPLES_DIR = os.path.join(os.path.dirname(__file__), "test_images")
@app.route("/sample/<label>")
def sample(label):
    """Serve a pre-stored demo image (jpg preferred, then png) for *label*."""
    if label in ("defective", "normal"):
        # Try each supported format in order; serve the first file present.
        for ext, mime in (("jpg", "image/jpeg"), ("png", "image/png")):
            candidate = os.path.join(SAMPLES_DIR, f"{label}.{ext}")
            if os.path.exists(candidate):
                return send_file(candidate, mimetype=mime)
        return jsonify({"error": f"No sample found for '{label}'"}), 404
    return jsonify({"error": "Unknown label"}), 404
@app.route("/health")
def health():
    """Report liveness and model availability (for cloud health probes)."""
    loaded = session is not None
    model_file = os.path.basename(model_path) if model_path else None
    payload = {
        "status": "ok",
        "model_loaded": loaded,
        "model_file": model_file,
    }
    return jsonify(payload)
if __name__ == "__main__":
    # NOTE(review): debug=True enables the interactive Werkzeug debugger, and
    # host="0.0.0.0" makes it reachable from other machines — acceptable for a
    # local demo, but this configuration must not be deployed to production.
    app.run(debug=True, host="0.0.0.0", port=5000)