Zorrojurro commited on
Commit
464772c
Β·
verified Β·
1 Parent(s): 1c746d7

Upload web_app.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. web_app.py +254 -0
web_app.py ADDED
@@ -0,0 +1,254 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Flask Web Application β€” Thermal Pattern Analysis Interface.
4
+
5
+ Usage:
6
+ python web_app.py
7
+ β†’ Open http://localhost:5000
8
+ """
9
+
10
+ import os
11
+ import io
12
+ import base64
13
+ import torch
14
+ import cv2
15
+ import numpy as np
16
+ import matplotlib
17
+ matplotlib.use('Agg')
18
+ import matplotlib.pyplot as plt
19
+ import torch.nn as nn
20
+ from pathlib import Path
21
+ from flask import Flask, render_template, request, jsonify
22
+ from flask_cors import CORS
23
+
24
+ from src.utils.config import load_config, setup_device
25
+ from src.preprocessing.image_processor import ThermalImageProcessor
26
+ from src.models.anomaly_detector import ThermalPatternPipeline
27
+
28
+ app = Flask(__name__)
29
+ CORS(app)
30
+
31
+ # ── Global model state ───────────────────────────────────────────────
32
+ MODEL = None
33
+ CLASSIFIER = None
34
+ PROCESSOR = None
35
+ DEVICE = None
36
+
37
+
38
+ def load_model():
39
+ """Load model, classifier, and processor at startup."""
40
+ global MODEL, CLASSIFIER, PROCESSOR, DEVICE
41
+
42
+ config = load_config("configs/config.yaml")
43
+ DEVICE = setup_device(config)
44
+
45
+ MODEL = ThermalPatternPipeline.from_config(config).to(DEVICE)
46
+ CLASSIFIER = nn.Linear(config.model.feature_extractor.embedding_dim, 2).to(DEVICE)
47
+
48
+ ckpt_path = Path("checkpoints/best_model.pt")
49
+ if ckpt_path.exists():
50
+ ckpt = torch.load(ckpt_path, map_location=DEVICE, weights_only=False)
51
+ MODEL.load_state_dict(ckpt["model_state_dict"])
52
+ CLASSIFIER.load_state_dict(ckpt["classifier_state_dict"])
53
+ print(f" βœ“ Model loaded from {ckpt_path}")
54
+ else:
55
+ print(f" βœ— No checkpoint at {ckpt_path}")
56
+
57
+ MODEL.eval()
58
+ CLASSIFIER.eval()
59
+ PROCESSOR = ThermalImageProcessor.from_config(config)
60
+
61
+
62
+ def img_to_base64(img, cmap=None):
63
+ """Convert numpy image to base64-encoded PNG for HTML display."""
64
+ # Normalize to 0-255 uint8 if needed
65
+ if img.dtype == np.float32 or img.dtype == np.float64:
66
+ img_u8 = (np.clip(img, 0, 1) * 255).astype(np.uint8) if img.max() <= 1.0 else np.clip(img, 0, 255).astype(np.uint8)
67
+ else:
68
+ img_u8 = img.astype(np.uint8)
69
+
70
+ if cmap == 'jet':
71
+ # Grad-CAM heatmap
72
+ colored = cv2.applyColorMap(img_u8, cv2.COLORMAP_JET)
73
+ elif len(img_u8.shape) == 2:
74
+ # Grayscale β†’ apply thermal inferno colormap
75
+ colored = cv2.applyColorMap(img_u8, cv2.COLORMAP_INFERNO)
76
+ else:
77
+ # Already colored (like overlay)
78
+ colored = cv2.cvtColor(img_u8, cv2.COLOR_RGB2BGR) if img_u8.shape[2] == 3 else img_u8
79
+
80
+ _, buf = cv2.imencode('.png', colored)
81
+ return base64.b64encode(buf.tobytes()).decode('utf-8')
82
+
83
+
84
+ def compute_gradcam(input_tensor):
85
+ """Compute Grad-CAM heatmap."""
86
+ target_layer = MODEL.feature_extractor.layer4[-1].conv2
87
+ activations, gradients = {}, {}
88
+
89
+ def fwd_hook(m, i, o): activations["v"] = o.detach()
90
+ def bwd_hook(m, gi, go): gradients["v"] = go[0].detach()
91
+
92
+ fh = target_layer.register_forward_hook(fwd_hook)
93
+ bh = target_layer.register_full_backward_hook(bwd_hook)
94
+
95
+ try:
96
+ img = input_tensor.unsqueeze(0).to(DEVICE)
97
+ features = MODEL.feature_extractor(img)
98
+ MODEL.zero_grad()
99
+ features.max().backward()
100
+
101
+ acts = activations["v"].squeeze(0)
102
+ grads = gradients["v"].squeeze(0)
103
+ weights = grads.mean(dim=(1, 2))
104
+ cam = torch.relu((weights[:, None, None] * acts).sum(0))
105
+ cam = cam / (cam.max() + 1e-8)
106
+ cam = cam.cpu().numpy()
107
+ return cv2.resize(cam, (224, 224))
108
+ finally:
109
+ fh.remove()
110
+ bh.remove()
111
+
112
+
113
+ # ── Routes ────────────────────────────────────────────────────────────
114
+
115
+ @app.route("/")
116
+ def index():
117
+ return render_template("index.html")
118
+
119
+
120
+ @app.route("/analyze", methods=["POST"])
121
+ def analyze():
122
+ if "file" not in request.files:
123
+ return jsonify({"error": "No file uploaded"}), 400
124
+
125
+ file = request.files["file"]
126
+ file_bytes = np.frombuffer(file.read(), np.uint8)
127
+ img = cv2.imdecode(file_bytes, cv2.IMREAD_COLOR)
128
+
129
+ if img is None:
130
+ return jsonify({"error": "Cannot read image"}), 400
131
+
132
+ # Grayscale
133
+ gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) if len(img.shape) == 3 else img.copy()
134
+ original = gray.copy()
135
+
136
+ # Preprocessing steps
137
+ resized = PROCESSOR.resize(gray)
138
+ denoised = PROCESSOR.denoise(resized)
139
+ enhanced = PROCESSOR.enhance_contrast(denoised)
140
+ normalized = enhanced.astype(np.float32) / 255.0
141
+
142
+ # Inference
143
+ with torch.no_grad():
144
+ img_tensor = torch.from_numpy(normalized).unsqueeze(0) # [1, H, W]
145
+ sequence = img_tensor.unsqueeze(0).repeat(1, 5, 1, 1).unsqueeze(2) # [1, 5, 1, H, W]
146
+ sequence = sequence.to(DEVICE)
147
+
148
+ results = MODEL(sequence)
149
+ logits = CLASSIFIER(results["encoding"])
150
+ probs = torch.softmax(logits, dim=1)
151
+ anomaly_score = probs[0, 1].item()
152
+ prediction = "ABNORMAL" if anomaly_score > 0.5 else "NORMAL"
153
+ confidence = max(anomaly_score, 1 - anomaly_score) * 100
154
+
155
+ # Grad-CAM
156
+ gradcam = compute_gradcam(img_tensor)
157
+
158
+ # Create overlay
159
+ heatmap_colored = cv2.applyColorMap((gradcam * 255).astype(np.uint8), cv2.COLORMAP_JET)
160
+ base_bgr = cv2.cvtColor(enhanced, cv2.COLOR_GRAY2BGR)
161
+ overlay = cv2.addWeighted(base_bgr, 0.6, heatmap_colored, 0.4, 0)
162
+ overlay_rgb = cv2.cvtColor(overlay, cv2.COLOR_BGR2RGB)
163
+
164
+ # Encode images
165
+ response = {
166
+ "prediction": prediction,
167
+ "anomaly_score": round(anomaly_score, 4),
168
+ "confidence": round(confidence, 1),
169
+ "images": {
170
+ "original": img_to_base64(original),
171
+ "resized": img_to_base64(resized),
172
+ "denoised": img_to_base64(denoised),
173
+ "enhanced": img_to_base64(enhanced),
174
+ "normalized": img_to_base64(normalized),
175
+ "gradcam": img_to_base64(gradcam, cmap='jet'),
176
+ "overlay": img_to_base64(overlay_rgb),
177
+ }
178
+ }
179
+
180
+ return jsonify(response)
181
+
182
+
183
+ @app.route("/sample_images")
184
+ def sample_images():
185
+ """Return list of sample images from the dataset."""
186
+ import glob
187
+ samples = glob.glob("data/raw/Power Transformers/*.jpg")[:12]
188
+ names = [Path(s).name for s in samples]
189
+ return jsonify(names)
190
+
191
+
192
+ @app.route("/analyze_sample/<filename>")
193
+ def analyze_sample(filename):
194
+ """Analyze a sample image from the dataset."""
195
+ path = Path("data/raw/Power Transformers") / filename
196
+ if not path.exists():
197
+ return jsonify({"error": "Sample not found"}), 404
198
+
199
+ with open(path, "rb") as f:
200
+ from werkzeug.datastructures import FileStorage
201
+ file = FileStorage(f, filename=filename)
202
+ # Read the file manually
203
+ file_bytes = np.frombuffer(f.read(), np.uint8)
204
+
205
+ img = cv2.imdecode(file_bytes, cv2.IMREAD_COLOR)
206
+ if img is None:
207
+ return jsonify({"error": "Cannot read image"}), 400
208
+
209
+ gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) if len(img.shape) == 3 else img.copy()
210
+ original = gray.copy()
211
+ resized = PROCESSOR.resize(gray)
212
+ denoised = PROCESSOR.denoise(resized)
213
+ enhanced = PROCESSOR.enhance_contrast(denoised)
214
+ normalized = enhanced.astype(np.float32) / 255.0
215
+
216
+ with torch.no_grad():
217
+ img_tensor = torch.from_numpy(normalized).unsqueeze(0)
218
+ sequence = img_tensor.unsqueeze(0).repeat(1, 5, 1, 1).unsqueeze(2)
219
+ sequence = sequence.to(DEVICE)
220
+ results = MODEL(sequence)
221
+ logits = CLASSIFIER(results["encoding"])
222
+ probs = torch.softmax(logits, dim=1)
223
+ anomaly_score = probs[0, 1].item()
224
+ prediction = "ABNORMAL" if anomaly_score > 0.5 else "NORMAL"
225
+ confidence = max(anomaly_score, 1 - anomaly_score) * 100
226
+
227
+ gradcam = compute_gradcam(img_tensor)
228
+ heatmap_colored = cv2.applyColorMap((gradcam * 255).astype(np.uint8), cv2.COLORMAP_JET)
229
+ base_bgr = cv2.cvtColor(enhanced, cv2.COLOR_GRAY2BGR)
230
+ overlay = cv2.addWeighted(base_bgr, 0.6, heatmap_colored, 0.4, 0)
231
+ overlay_rgb = cv2.cvtColor(overlay, cv2.COLOR_BGR2RGB)
232
+
233
+ return jsonify({
234
+ "prediction": prediction,
235
+ "anomaly_score": round(anomaly_score, 4),
236
+ "confidence": round(confidence, 1),
237
+ "images": {
238
+ "original": img_to_base64(original),
239
+ "resized": img_to_base64(resized),
240
+ "denoised": img_to_base64(denoised),
241
+ "enhanced": img_to_base64(enhanced),
242
+ "normalized": img_to_base64(normalized),
243
+ "gradcam": img_to_base64(gradcam, cmap='jet'),
244
+ "overlay": img_to_base64(overlay_rgb),
245
+ }
246
+ })
247
+
248
+
249
+ if __name__ == "__main__":
250
+ print("Loading model...")
251
+ load_model()
252
+ port = int(os.environ.get("PORT", 5000))
253
+ print(f"Starting server on port {port}")
254
+ app.run(debug=False, host="0.0.0.0", port=port)