AsamiYukiko commited on
Commit
7fcff4a
·
1 Parent(s): d0f7d2d

Add Flask PV defect classifier app with EfficientNet-B0 ONNX model

Browse files
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ *.onnx.data filter=lfs diff=lfs merge=lfs -text
Dockerfile ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ── PV Defect Classifier — Docker Image ──
2
+ # Lightweight Python image, ~350MB final size
3
+ FROM python:3.11-slim
4
+
5
+ WORKDIR /app
6
+
7
+ # Install dependencies first (cache layer)
8
+ COPY requirements.txt .
9
+ RUN pip install --no-cache-dir -r requirements.txt
10
+
11
+ # Copy app code
12
+ COPY app.py .
13
+ COPY templates/ templates/
14
+ COPY static/ static/
15
+ COPY models/ models/
16
+
17
+ # Expose port (HF Spaces requires 7860)
18
+ EXPOSE 7860
19
+
20
+ # Production server (gunicorn instead of Flask dev server)
21
+ RUN pip install --no-cache-dir gunicorn
22
+
23
+ # Run with gunicorn: 2 workers, bind to all interfaces
24
+ CMD ["gunicorn", "--bind", "0.0.0.0:7860", "--workers", "2", "--timeout", "30", "app:app"]
app.py ADDED
@@ -0,0 +1,120 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ PV Defect Classification — Flask Demo
3
+ ======================================
4
+ Loads the best ONNX model and serves a web interface for
5
+ real-time photovoltaic panel defect classification.
6
+
7
+ Usage:
8
+ 1. Put your .onnx model file in the /models folder
9
+ 2. pip install flask onnxruntime pillow numpy
10
+ 3. python app.py
11
+ 4. Open http://localhost:7860 in your browser
12
+ """
13
+
14
+ import os
15
+ import time
16
+ import numpy as np
17
+ from PIL import Image
18
+ from flask import Flask, render_template, request, jsonify
19
+ import onnxruntime as ort
20
+
21
+ # ── Config ────────────────────────────────────────────────────
22
+ MODEL_DIR = os.path.join(os.path.dirname(__file__), "models")
23
+ CLASS_NAMES = ["DEFECTIVE", "NORMAL"]
24
+ IMG_SIZE = 224
25
+ MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
26
+ STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)
27
+
28
+ app = Flask(__name__)
29
+
30
+ # ── Load ONNX model ──────────────────────────────────────────
31
+ def find_onnx_model():
32
+ """Auto-detect the first .onnx file in /models."""
33
+ for f in os.listdir(MODEL_DIR):
34
+ if f.endswith(".onnx"):
35
+ return os.path.join(MODEL_DIR, f)
36
+ return None
37
+
38
+ model_path = find_onnx_model()
39
+ if model_path:
40
+ session = ort.InferenceSession(model_path)
41
+ input_name = session.get_inputs()[0].name
42
+ print(f"✅ Loaded model: {os.path.basename(model_path)}")
43
+ else:
44
+ session = None
45
+ print("⚠️ No .onnx file found in /models — place your model there and restart.")
46
+
47
+
48
+ # ── Preprocessing (same as val_tf in your notebook) ──────────
49
+ def preprocess(image: Image.Image) -> np.ndarray:
50
+ """Resize, normalise, and convert PIL image to ONNX input tensor."""
51
+ img = image.convert("RGB").resize((IMG_SIZE, IMG_SIZE))
52
+ arr = np.array(img, dtype=np.float32) / 255.0 # [H, W, 3]
53
+ arr = (arr - MEAN) / STD # normalise
54
+ arr = arr.transpose(2, 0, 1) # [3, H, W]
55
+ return arr[np.newaxis, ...] # [1, 3, H, W]
56
+
57
+
58
+ def softmax(x):
59
+ e = np.exp(x - np.max(x))
60
+ return e / e.sum()
61
+
62
+
63
+ # ── Routes ────────────────────────────────────────────────────
64
+ @app.route("/")
65
+ def index():
66
+ model_name = os.path.basename(model_path) if model_path else "No model loaded"
67
+ return render_template("index.html", model_name=model_name)
68
+
69
+
70
+ @app.route("/predict", methods=["POST"])
71
+ def predict():
72
+ if session is None:
73
+ return jsonify({"error": "No model loaded. Put a .onnx file in /models."}), 500
74
+
75
+ if "file" not in request.files:
76
+ return jsonify({"error": "No file uploaded."}), 400
77
+
78
+ file = request.files["file"]
79
+ if file.filename == "":
80
+ return jsonify({"error": "Empty filename."}), 400
81
+
82
+ try:
83
+ image = Image.open(file.stream)
84
+ tensor = preprocess(image)
85
+
86
+ # Inference with timing
87
+ t0 = time.time()
88
+ outputs = session.run(None, {input_name: tensor})
89
+ latency_ms = (time.time() - t0) * 1000
90
+
91
+ logits = outputs[0][0]
92
+ probs = softmax(logits)
93
+ pred_idx = int(np.argmax(probs))
94
+ confidence = float(probs[pred_idx]) * 100
95
+
96
+ return jsonify({
97
+ "prediction": CLASS_NAMES[pred_idx],
98
+ "confidence": round(confidence, 1),
99
+ "latency_ms": round(latency_ms, 1),
100
+ "probabilities": {
101
+ CLASS_NAMES[i]: round(float(probs[i]) * 100, 1)
102
+ for i in range(len(CLASS_NAMES))
103
+ }
104
+ })
105
+ except Exception as e:
106
+ return jsonify({"error": str(e)}), 500
107
+
108
+
109
+ @app.route("/health")
110
+ def health():
111
+ """Health check endpoint — useful for cloud deployment."""
112
+ return jsonify({
113
+ "status": "ok",
114
+ "model_loaded": session is not None,
115
+ "model_file": os.path.basename(model_path) if model_path else None
116
+ })
117
+
118
+
119
+ if __name__ == "__main__":
120
+ app.run(debug=False, host="0.0.0.0", port=7860)
models/efficientnet_b0.onnx ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8082715f227b63acbf3a4c36cc5cec25df19eca44ed46e45b0058a6563addfa0
3
+ size 602059
models/efficientnet_b0.onnx.data ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:eb3333d9ed900411fcb851cadfc58f0665f7b1db6ca39cdf80492e4d76163cc0
3
+ size 15990784
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ flask==3.0.0
2
+ onnxruntime==1.17.0
3
+ Pillow==10.2.0
4
+ numpy==1.26.3
templates/index.html ADDED
@@ -0,0 +1,296 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <!DOCTYPE html>
2
+ <html lang="en">
3
+ <head>
4
+ <meta charset="UTF-8">
5
+ <meta name="viewport" content="width=device-width, initial-scale=1.0">
6
+ <title>PV Defect Classifier</title>
7
+ <style>
8
+ * { box-sizing: border-box; margin: 0; padding: 0; }
9
+ body {
10
+ font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif;
11
+ background: #f5f7fa;
12
+ color: #1a1a2e;
13
+ min-height: 100vh;
14
+ display: flex;
15
+ flex-direction: column;
16
+ align-items: center;
17
+ }
18
+ .header {
19
+ width: 100%;
20
+ background: linear-gradient(135deg, #1a3c6e 0%, #2e5e8e 100%);
21
+ color: white;
22
+ padding: 2rem;
23
+ text-align: center;
24
+ }
25
+ .header h1 { font-size: 1.6rem; font-weight: 600; }
26
+ .header p { font-size: 0.85rem; opacity: 0.8; margin-top: 0.4rem; }
27
+ .container {
28
+ max-width: 560px;
29
+ width: 100%;
30
+ padding: 2rem 1.5rem;
31
+ }
32
+
33
+ /* Upload area */
34
+ .upload-area {
35
+ border: 2px dashed #c0cfe0;
36
+ border-radius: 12px;
37
+ padding: 2.5rem 1.5rem;
38
+ text-align: center;
39
+ background: white;
40
+ cursor: pointer;
41
+ transition: border-color 0.2s, background 0.2s;
42
+ }
43
+ .upload-area:hover, .upload-area.dragover {
44
+ border-color: #2e5e8e;
45
+ background: #eef3fa;
46
+ }
47
+ .upload-area svg { width: 48px; height: 48px; stroke: #7a8fa8; }
48
+ .upload-area p { color: #5a6d82; margin-top: 0.8rem; font-size: 0.9rem; }
49
+ .upload-area .hint { font-size: 0.75rem; color: #9aa8b8; margin-top: 0.3rem; }
50
+ #file-input { display: none; }
51
+
52
+ /* Preview */
53
+ .preview-section { margin-top: 1.5rem; display: none; }
54
+ .preview-section.show { display: block; }
55
+ .preview-img {
56
+ width: 100%;
57
+ max-height: 300px;
58
+ object-fit: contain;
59
+ border-radius: 8px;
60
+ background: #e8ecf2;
61
+ }
62
+ .btn {
63
+ display: block;
64
+ width: 100%;
65
+ margin-top: 1rem;
66
+ padding: 0.75rem;
67
+ font-size: 1rem;
68
+ font-weight: 500;
69
+ border: none;
70
+ border-radius: 8px;
71
+ cursor: pointer;
72
+ background: #1a3c6e;
73
+ color: white;
74
+ transition: background 0.2s;
75
+ }
76
+ .btn:hover { background: #2e5e8e; }
77
+ .btn:disabled { background: #9aa8b8; cursor: not-allowed; }
78
+
79
+ /* Result card */
80
+ .result-card {
81
+ margin-top: 1.5rem;
82
+ background: white;
83
+ border-radius: 12px;
84
+ padding: 1.5rem;
85
+ display: none;
86
+ box-shadow: 0 2px 8px rgba(0,0,0,0.06);
87
+ }
88
+ .result-card.show { display: block; }
89
+ .result-label {
90
+ font-size: 1.4rem;
91
+ font-weight: 600;
92
+ text-align: center;
93
+ padding: 0.5rem 0;
94
+ }
95
+ .result-label.defective { color: #c0392b; }
96
+ .result-label.normal { color: #0f6e56; }
97
+
98
+ .metrics {
99
+ display: grid;
100
+ grid-template-columns: 1fr 1fr 1fr;
101
+ gap: 12px;
102
+ margin-top: 1rem;
103
+ }
104
+ .metric {
105
+ text-align: center;
106
+ background: #f5f7fa;
107
+ padding: 0.8rem 0.5rem;
108
+ border-radius: 8px;
109
+ }
110
+ .metric .value { font-size: 1.2rem; font-weight: 600; color: #1a3c6e; }
111
+ .metric .label { font-size: 0.7rem; color: #7a8fa8; margin-top: 2px; }
112
+
113
+ .prob-bar-container { margin-top: 1.2rem; }
114
+ .prob-row {
115
+ display: flex;
116
+ align-items: center;
117
+ gap: 8px;
118
+ margin-bottom: 6px;
119
+ font-size: 0.8rem;
120
+ }
121
+ .prob-name { width: 80px; text-align: right; color: #5a6d82; }
122
+ .prob-bar-bg {
123
+ flex: 1;
124
+ height: 10px;
125
+ background: #e8ecf2;
126
+ border-radius: 5px;
127
+ overflow: hidden;
128
+ }
129
+ .prob-bar {
130
+ height: 100%;
131
+ border-radius: 5px;
132
+ transition: width 0.5s ease;
133
+ }
134
+ .prob-bar.defective { background: #e74c3c; }
135
+ .prob-bar.normal { background: #1d9e75; }
136
+ .prob-pct { width: 44px; font-weight: 500; color: #1a1a2e; }
137
+
138
+ .model-badge {
139
+ text-align: center;
140
+ margin-top: 1rem;
141
+ font-size: 0.7rem;
142
+ color: #9aa8b8;
143
+ }
144
+
145
+ /* Loading spinner */
146
+ .spinner { display: none; text-align: center; margin-top: 1rem; }
147
+ .spinner.show { display: block; }
148
+ .spinner::after {
149
+ content: '';
150
+ display: inline-block;
151
+ width: 28px; height: 28px;
152
+ border: 3px solid #c0cfe0;
153
+ border-top-color: #1a3c6e;
154
+ border-radius: 50%;
155
+ animation: spin 0.7s linear infinite;
156
+ }
157
+ @keyframes spin { to { transform: rotate(360deg); } }
158
+ </style>
159
+ </head>
160
+ <body>
161
+
162
+ <div class="header">
163
+ <h1>PV Defect Classifier</h1>
164
+ <p>Upload a photovoltaic cell image for real-time defect classification</p>
165
+ </div>
166
+
167
+ <div class="container">
168
+
169
+ <!-- Upload -->
170
+ <div class="upload-area" id="drop-zone" onclick="document.getElementById('file-input').click()">
171
+ <svg xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24" stroke-width="1.5" stroke="currentColor">
172
+ <path stroke-linecap="round" stroke-linejoin="round" d="M3 16.5v2.25A2.25 2.25 0 005.25 21h13.5A2.25 2.25 0 0021 18.75V16.5m-13.5-9L12 3m0 0l4.5 4.5M12 3v13.5"/>
173
+ </svg>
174
+ <p>Click or drag an image here</p>
175
+ <span class="hint">Supports JPG, PNG — PV electroluminescence images</span>
176
+ </div>
177
+ <input type="file" id="file-input" accept="image/*">
178
+
179
+ <!-- Preview -->
180
+ <div class="preview-section" id="preview-section">
181
+ <img id="preview-img" class="preview-img" alt="Preview">
182
+ <button class="btn" id="classify-btn" onclick="classify()">Classify</button>
183
+ </div>
184
+
185
+ <!-- Loading -->
186
+ <div class="spinner" id="spinner"></div>
187
+
188
+ <!-- Result -->
189
+ <div class="result-card" id="result-card">
190
+ <div class="result-label" id="result-label"></div>
191
+ <div class="metrics">
192
+ <div class="metric">
193
+ <div class="value" id="confidence">—</div>
194
+ <div class="label">Confidence</div>
195
+ </div>
196
+ <div class="metric">
197
+ <div class="value" id="latency">—</div>
198
+ <div class="label">Latency</div>
199
+ </div>
200
+ <div class="metric">
201
+ <div class="value" id="model-size">—</div>
202
+ <div class="label">Model</div>
203
+ </div>
204
+ </div>
205
+ <div class="prob-bar-container" id="prob-bars"></div>
206
+ <div class="model-badge">Model: <span id="model-name">{{ model_name }}</span></div>
207
+ </div>
208
+
209
+ </div>
210
+
211
+ <script>
212
+ const dropZone = document.getElementById('drop-zone');
213
+ const fileInput = document.getElementById('file-input');
214
+ const previewSection = document.getElementById('preview-section');
215
+ const previewImg = document.getElementById('preview-img');
216
+ let selectedFile = null;
217
+
218
+ // Drag & drop
219
+ dropZone.addEventListener('dragover', e => { e.preventDefault(); dropZone.classList.add('dragover'); });
220
+ dropZone.addEventListener('dragleave', () => dropZone.classList.remove('dragover'));
221
+ dropZone.addEventListener('drop', e => {
222
+ e.preventDefault();
223
+ dropZone.classList.remove('dragover');
224
+ if (e.dataTransfer.files.length) handleFile(e.dataTransfer.files[0]);
225
+ });
226
+
227
+ fileInput.addEventListener('change', () => {
228
+ if (fileInput.files.length) handleFile(fileInput.files[0]);
229
+ });
230
+
231
+ function handleFile(file) {
232
+ selectedFile = file;
233
+ const reader = new FileReader();
234
+ reader.onload = e => {
235
+ previewImg.src = e.target.result;
236
+ previewSection.classList.add('show');
237
+ document.getElementById('result-card').classList.remove('show');
238
+ };
239
+ reader.readAsDataURL(file);
240
+ }
241
+
242
+ async function classify() {
243
+ if (!selectedFile) return;
244
+ const btn = document.getElementById('classify-btn');
245
+ const spinner = document.getElementById('spinner');
246
+ const resultCard = document.getElementById('result-card');
247
+
248
+ btn.disabled = true;
249
+ spinner.classList.add('show');
250
+ resultCard.classList.remove('show');
251
+
252
+ const formData = new FormData();
253
+ formData.append('file', selectedFile);
254
+
255
+ try {
256
+ const res = await fetch('/predict', { method: 'POST', body: formData });
257
+ const data = await res.json();
258
+
259
+ if (data.error) {
260
+ alert('Error: ' + data.error);
261
+ return;
262
+ }
263
+
264
+ // Update result
265
+ const label = document.getElementById('result-label');
266
+ label.textContent = data.prediction;
267
+ label.className = 'result-label ' + data.prediction.toLowerCase();
268
+
269
+ document.getElementById('confidence').textContent = data.confidence + '%';
270
+ document.getElementById('latency').textContent = data.latency_ms + 'ms';
271
+ document.getElementById('model-size').textContent = document.getElementById('model-name').textContent.replace('.onnx','');
272
+
273
+ // Probability bars
274
+ const barsDiv = document.getElementById('prob-bars');
275
+ barsDiv.innerHTML = '';
276
+ for (const [cls, pct] of Object.entries(data.probabilities)) {
277
+ barsDiv.innerHTML += `
278
+ <div class="prob-row">
279
+ <span class="prob-name">${cls}</span>
280
+ <div class="prob-bar-bg"><div class="prob-bar ${cls.toLowerCase()}" style="width:${pct}%"></div></div>
281
+ <span class="prob-pct">${pct}%</span>
282
+ </div>`;
283
+ }
284
+
285
+ resultCard.classList.add('show');
286
+ } catch (e) {
287
+ alert('Request failed: ' + e.message);
288
+ } finally {
289
+ btn.disabled = false;
290
+ spinner.classList.remove('show');
291
+ }
292
+ }
293
+ </script>
294
+
295
+ </body>
296
+ </html>