Spaces:
Running
Running
| """ | |
PV Defect Classification — Flask Demo
Loads an ONNX model from the local "models" directory and serves a web
interface for real-time photovoltaic panel defect classification.
| """ | |
| import os | |
| import time | |
| import numpy as np | |
| from PIL import Image | |
| from flask import Flask, render_template, request, jsonify, send_file | |
| import onnxruntime as ort | |
# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------
# Directory, next to this file, that is scanned for the exported ONNX model.
MODEL_DIR = os.path.join(os.path.dirname(__file__), "models")

# Label reported for each model output index (index 0 -> first entry).
CLASS_NAMES = ["DEFECTIVE", "NORMAL"]

# Side length, in pixels, that every input image is resized to.
IMG_SIZE = 224

# Per-channel (R, G, B) mean and standard deviation applied after scaling
# pixel values to [0, 1].
# NOTE(review): these equal the widely used ImageNet statistics; presumably
# they match what the model was trained with — confirm.
MEAN = np.array((0.485, 0.456, 0.406), dtype=np.float32)
STD = np.array((0.229, 0.224, 0.225), dtype=np.float32)

app = Flask(__name__)
| # Load ONNX model | |
def find_onnx_model(model_dir=None):
    """Return the path of the ONNX model to load, or ``None`` if there is none.

    Args:
        model_dir: Directory to scan. Defaults to the module-level
            ``MODEL_DIR`` when omitted.

    Returns:
        The full path of the alphabetically first regular file in
        ``model_dir`` whose name ends with ``.onnx``, or ``None`` when the
        directory does not exist or contains no such file.
    """
    if model_dir is None:
        model_dir = MODEL_DIR

    try:
        # os.listdir() returns entries in arbitrary order; sort so that the
        # same model is picked on every start when several are present.
        entries = sorted(os.listdir(model_dir))
    except (FileNotFoundError, NotADirectoryError):
        # A missing models directory is reported the same way as an empty one.
        return None

    for name in entries:
        candidate = os.path.join(model_dir, name)
        # Skip directories that merely happen to be named "*.onnx".
        if name.endswith(".onnx") and os.path.isfile(candidate):
            return candidate
    return None
# Load the model once at import time so every request reuses the same session.
model_path = find_onnx_model()
if model_path:
    session = ort.InferenceSession(model_path)
    # Name of the model's first input; used as the feed key for session.run().
    input_name = session.get_inputs()[0].name
    print(f"[INFO] Model loaded: {os.path.basename(model_path)}")
else:
    # Keep the app importable without a model; handlers check `session`
    # before running inference.
    session = None
    # Bind the name in both branches so it is never undefined at module level.
    input_name = None
    print("[ERROR] No ONNX model file found in MODEL_DIR. Inference endpoint will be unavailable.")
| # Preprocessing | |
def preprocess(image: Image.Image) -> np.ndarray:
    """Turn a PIL image into the batched, normalised tensor fed to the model.

    The image is converted to RGB, resized to ``IMG_SIZE`` x ``IMG_SIZE``,
    scaled to [0, 1], normalised per channel with ``MEAN`` / ``STD`` and
    rearranged from height-width-channel to channel-height-width layout.

    Args:
        image: Source image in any PIL mode.

    Returns:
        Array of shape ``[1, 3, IMG_SIZE, IMG_SIZE]``.
    """
    resized = image.convert("RGB").resize((IMG_SIZE, IMG_SIZE))
    pixels = np.array(resized, dtype=np.float32) / 255.0  # [H, W, 3] in [0, 1]
    normalised = (pixels - MEAN) / STD                    # broadcast over channels
    channels_first = np.transpose(normalised, (2, 0, 1))  # [3, H, W]
    return np.expand_dims(channels_first, axis=0)         # [1, 3, H, W]
def softmax(x):
    """Convert raw scores into probabilities that sum to 1.

    The maximum score is subtracted before exponentiating so that large
    values do not overflow. Normalisation is over all elements of ``x``.

    Args:
        x: Array-like of scores.

    Returns:
        NumPy array of the same shape as ``x``.
    """
    shifted = np.subtract(x, np.max(x))
    weights = np.exp(shifted)
    return weights / np.sum(weights)
| # Routes | |
# NOTE(review): the view was not registered with the app; the "/" rule is
# inferred from the handler name — confirm it matches the deployed URLs.
@app.route("/")
def index():
    """Render the demo page, passing it the loaded model's file name.

    Returns:
        The rendered ``index.html`` template. ``model_name`` is the model
        file's base name, or ``"No model loaded"`` when no model was found.
    """
    if model_path:
        model_name = os.path.basename(model_path)
    else:
        model_name = "No model loaded"
    return render_template("index.html", model_name=model_name)
# NOTE(review): the view was not registered with the app; the "/predict" rule
# is inferred from the handler name (POST because it reads an uploaded file)
# — confirm it matches what the front end requests.
@app.route("/predict", methods=["POST"])
def predict():
    """Classify an uploaded image and return the result as JSON.

    Expects a multipart form upload under the field name ``file``.

    Returns:
        On success, a JSON object with:
            ``prediction``: label from ``CLASS_NAMES`` with the highest
                probability.
            ``confidence``: that probability as a percentage (one decimal).
            ``latency_ms``: time spent in ``session.run`` only, in ms.
            ``probabilities``: percentage per label in ``CLASS_NAMES``.
        On failure, ``{"error": <message>}`` with status 400 (missing,
        unnamed or unreadable upload) or 500 (model not loaded, or an
        unexpected failure during preprocessing or inference).
    """
    if session is None:
        return jsonify({"error": "Inference service unavailable: model not loaded."}), 500

    upload = request.files.get("file")
    if upload is None:
        return jsonify({"error": "No file uploaded."}), 400
    if upload.filename == "":
        return jsonify({"error": "Empty filename."}), 400

    try:
        # The context manager releases Pillow's resources as soon as the
        # tensor (a copy of the pixel data) has been built.
        with Image.open(upload.stream) as image:
            tensor = preprocess(image)
    except (OSError, ValueError) as exc:
        # Pillow reports unrecognised or truncated image data as OSError
        # (UnidentifiedImageError is a subclass): that is a client error.
        return jsonify({"error": f"Could not read image: {exc}"}), 400
    except Exception as exc:
        app.logger.exception("Image preprocessing failed")
        return jsonify({"error": str(exc)}), 500

    try:
        # perf_counter() is monotonic and high-resolution, unlike time.time(),
        # which can jump when the system clock is adjusted.
        started = time.perf_counter()
        outputs = session.run(None, {input_name: tensor})
        latency_ms = (time.perf_counter() - started) * 1000

        # First output, first (only) batch row.
        probs = softmax(outputs[0][0])
        pred_idx = int(np.argmax(probs))

        return jsonify({
            "prediction": CLASS_NAMES[pred_idx],
            "confidence": round(float(probs[pred_idx]) * 100, 1),
            "latency_ms": round(latency_ms, 1),
            "probabilities": {
                name: round(float(prob) * 100, 1)
                for name, prob in zip(CLASS_NAMES, probs)
            },
        })
    except Exception as exc:
        # Last-resort boundary: keep the JSON error contract for the client
        # but record the traceback server-side.
        app.logger.exception("Inference failed")
        return jsonify({"error": str(exc)}), 500
# Directory, next to this file, holding the bundled demo images.
SAMPLES_DIR = os.path.join(os.path.dirname(__file__), "test_images")


# NOTE(review): the view was not registered with the app; the rule is inferred
# from the handler's name and its ``label`` argument — confirm it matches what
# the front end requests.
@app.route("/sample/<label>")
def sample(label):
    """Serve a pre-stored sample image for demo purposes.

    Args:
        label: Either ``"defective"`` or ``"normal"``.

    Returns:
        The file ``<label>.jpg`` (preferred) or ``<label>.png`` from
        ``SAMPLES_DIR``, or a JSON error with status 404 when the label is
        unknown or no matching file exists.
    """
    # The fixed allow-list also guarantees that `label` cannot carry path
    # separators into the file-system lookup below.
    if label not in ("defective", "normal"):
        return jsonify({"error": "Unknown label"}), 404

    for ext, mime in (("jpg", "image/jpeg"), ("png", "image/png")):
        path = os.path.join(SAMPLES_DIR, f"{label}.{ext}")
        # isfile() rather than exists(): a directory with that name must not
        # be handed to send_file().
        if os.path.isfile(path):
            return send_file(path, mimetype=mime)

    return jsonify({"error": f"No sample found for '{label}'"}), 404
# NOTE(review): the view was not registered with the app; the "/health" rule
# is inferred from the handler name — confirm it matches the deployment's
# configured health-check path.
@app.route("/health")
def health():
    """Health check endpoint — useful for cloud deployment.

    Returns:
        JSON with ``status`` (always ``"ok"``), ``model_loaded`` (whether an
        inference session exists) and ``model_file`` (base name of the model
        file, or ``null`` when none was found).
    """
    model_file = os.path.basename(model_path) if model_path else None
    return jsonify({
        "status": "ok",
        "model_loaded": session is not None,
        "model_file": model_file,
    })
if __name__ == "__main__":
    # The server listens on all interfaces, so the Werkzeug debugger (which
    # allows arbitrary code execution from the browser) must not be on by
    # default. Enable it explicitly for local work with FLASK_DEBUG=1.
    debug_enabled = os.environ.get("FLASK_DEBUG", "").strip().lower() in ("1", "true", "yes")
    # Allow the hosting platform to choose the port; 5000 remains the default.
    port = int(os.environ.get("PORT", "5000"))
    app.run(debug=debug_enabled, host="0.0.0.0", port=port)