|
|
| """
|
| Flask Web Application — Thermal Pattern Analysis Interface.
|
|
|
| Usage:
|
| python web_app.py
|
| → Open http://localhost:5000
|
| """
|
|
|
import os
import io
import base64
import threading
from pathlib import Path

import torch
import cv2
import numpy as np
import matplotlib

matplotlib.use('Agg')

import matplotlib.pyplot as plt
import torch.nn as nn
from flask import Flask, render_template, request, jsonify
from flask_cors import CORS

from src.utils.config import load_config, setup_device
from src.preprocessing.image_processor import ThermalImageProcessor
from src.models.anomaly_detector import ThermalPatternPipeline
|
|
|
app = Flask(__name__)
CORS(app)  # enable flask_cors on the app with its default (all-routes) settings

# Runtime state read by the request handlers. Everything starts out as None
# and is populated by load_model().
MODEL = CLASSIFIER = PROCESSOR = DEVICE = None
|
|
|
|
|
def load_model(config_path="configs/config.yaml",
               ckpt_path="checkpoints/best_model.pt"):
    """Load the model, classifier head, and image processor into module globals.

    Populates ``MODEL``, ``CLASSIFIER``, ``PROCESSOR`` and ``DEVICE``. If no
    checkpoint exists at ``ckpt_path``, the networks keep their freshly
    constructed weights and a message is printed, so the server still starts.

    Args:
        config_path: config file passed to ``load_config``.
        ckpt_path: checkpoint file; it must contain the keys
            ``"model_state_dict"`` and ``"classifier_state_dict"``.
    """
    global MODEL, CLASSIFIER, PROCESSOR, DEVICE

    config = load_config(config_path)
    DEVICE = setup_device(config)

    MODEL = ThermalPatternPipeline.from_config(config).to(DEVICE)
    # Two-way head on the pipeline's encoding; the routes read column 1 of
    # its softmax as the anomaly probability.
    CLASSIFIER = nn.Linear(config.model.feature_extractor.embedding_dim, 2).to(DEVICE)

    ckpt_path = Path(ckpt_path)
    if ckpt_path.exists():
        # SECURITY: weights_only=False unpickles arbitrary Python objects and
        # can execute code -- only load checkpoints from a trusted source.
        ckpt = torch.load(ckpt_path, map_location=DEVICE, weights_only=False)
        MODEL.load_state_dict(ckpt["model_state_dict"])
        CLASSIFIER.load_state_dict(ckpt["classifier_state_dict"])
        print(f" ✓ Model loaded from {ckpt_path}")
    else:
        print(f" ✗ No checkpoint at {ckpt_path}")

    # Inference mode (affects dropout / batch-norm layers, if any).
    MODEL.eval()
    CLASSIFIER.eval()
    PROCESSOR = ThermalImageProcessor.from_config(config)
|
|
|
|
|
def img_to_base64(img, cmap=None):
    """Encode a numpy image as a base64 PNG string for embedding in HTML.

    Args:
        img: 2-D (grayscale) or 3-D (H, W, C) array of any numeric dtype.
            Floating-point images whose maximum is <= 1.0 are treated as
            [0, 1] data and scaled to [0, 255]. Other floating-point images
            are assumed to already be on a [0, 255] scale. Values outside
            [0, 255] saturate rather than wrap.
        cmap: ``'jet'`` applies OpenCV's JET colormap. Otherwise, 2-D images
            get the INFERNO colormap, 3-channel images are treated as RGB,
            and any other channel count is encoded as-is.

    Returns:
        The PNG bytes, base64-encoded, as a ``str``.

    Raises:
        ValueError: if OpenCV reports that PNG encoding failed.
    """
    if np.issubdtype(img.dtype, np.floating):
        # Covers float16 as well as float32/float64.
        if img.max() <= 1.0:
            img_u8 = (np.clip(img, 0, 1) * 255).astype(np.uint8)
        else:
            img_u8 = np.clip(img, 0, 255).astype(np.uint8)
    elif img.dtype == np.uint8:
        img_u8 = img.astype(np.uint8)
    else:
        # Clip via float64 before narrowing, so that e.g. uint16 or negative
        # ints saturate instead of wrapping modulo 256. The float64 detour also
        # avoids clip-bound overflow on small int types.
        img_u8 = np.clip(img.astype(np.float64), 0, 255).astype(np.uint8)

    if cmap == 'jet':
        colored = cv2.applyColorMap(img_u8, cv2.COLORMAP_JET)
    elif img_u8.ndim == 2:
        colored = cv2.applyColorMap(img_u8, cv2.COLORMAP_INFERNO)
    elif img_u8.shape[2] == 3:
        # OpenCV encodes BGR; callers pass RGB.
        colored = cv2.cvtColor(img_u8, cv2.COLOR_RGB2BGR)
    else:
        colored = img_u8

    ok, buf = cv2.imencode('.png', colored)
    if not ok:
        raise ValueError("cv2.imencode failed to encode the image as PNG")
    return base64.b64encode(buf.tobytes()).decode('utf-8')
|
|
|
|
|
# Serializes Grad-CAM runs. The hooks below are attached to a layer of the
# shared MODEL, and Flask's development server handles requests on multiple
# threads, so concurrent calls would otherwise record each other's tensors.
_GRADCAM_LOCK = threading.Lock()


def compute_gradcam(input_tensor):
    """Compute a Grad-CAM heatmap for one preprocessed image.

    Hooks ``MODEL.feature_extractor.layer4[-1].conv2`` and runs the feature
    extractor on ``input_tensor`` with gradients enabled. It back-propagates
    the maximum feature value, then weights the layer's activations by their
    spatially averaged gradients.

    Args:
        input_tensor: image tensor without a batch dimension (a batch dim is
            added here); its last two dims are the spatial size. The routes
            in this file pass shape (1, H, W).

    Returns:
        numpy array of shape (H, W), the input's spatial size, with values in
        [0, 1).
    """
    target_layer = MODEL.feature_extractor.layer4[-1].conv2
    activations, gradients = {}, {}

    def fwd_hook(module, inputs, output):
        # Ignore inference forwards run under torch.no_grad() by other request
        # threads (grad mode is thread-local); only the Grad-CAM pass counts.
        if torch.is_grad_enabled():
            activations["v"] = output.detach()

    def bwd_hook(module, grad_input, grad_output):
        gradients["v"] = grad_output[0].detach()

    height, width = input_tensor.shape[-2], input_tensor.shape[-1]

    with _GRADCAM_LOCK:
        fh = target_layer.register_forward_hook(fwd_hook)
        bh = target_layer.register_full_backward_hook(bwd_hook)
        try:
            # Grad-CAM needs an autograd graph even if the caller is inside
            # torch.no_grad().
            with torch.enable_grad():
                img = input_tensor.unsqueeze(0).to(DEVICE)
                features = MODEL.feature_extractor(img)
                MODEL.zero_grad()
                features.max().backward()

            acts = activations["v"].squeeze(0)
            grads = gradients["v"].squeeze(0)
            # Channel weights = gradients averaged over the spatial dims.
            weights = grads.mean(dim=(1, 2))
            cam = torch.relu((weights[:, None, None] * acts).sum(0))
            cam = cam / (cam.max() + 1e-8)
            cam = cam.cpu().numpy()
            # Match the input image size (cv2 dsize is (width, height)) so the
            # heatmap can be blended directly with the preprocessed image.
            return cv2.resize(cam, (width, height))
        finally:
            fh.remove()
            bh.remove()
            # Free the parameter .grad tensors created by backward() so they
            # are not kept alive between requests.
            MODEL.zero_grad(set_to_none=True)
|
|
|
|
|
|
|
|
|
@app.route("/")
def index():
    """Render the ``index.html`` template for the root URL."""
    page_template = "index.html"
    return render_template(page_template)
|
|
|
|
|
# Directory holding the bundled sample images exposed by the sample routes.
_SAMPLE_DIR = Path("data/raw/Power Transformers")


def _decode_image(raw_bytes):
    """Decode encoded image bytes (PNG/JPEG/...) to a BGR array, or None if unreadable."""
    # cv2.imdecode raises (rather than returning None) on an empty buffer.
    if not raw_bytes:
        return None
    return cv2.imdecode(np.frombuffer(raw_bytes, np.uint8), cv2.IMREAD_COLOR)


def _run_analysis(img):
    """Preprocess, classify, and explain a decoded image.

    Args:
        img: array as returned by ``cv2.imdecode`` (BGR colour, or grayscale).

    Returns:
        The JSON-serializable response dict shared by both analysis routes:
        prediction, anomaly score, confidence, and base64 PNG images of each
        pipeline stage.
    """
    if MODEL is None:
        # load_model() is only invoked from __main__; when the app is served
        # another way (e.g. a WSGI server) the globals are still unset.
        load_model()

    gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) if img.ndim == 3 else img.copy()
    original = gray.copy()

    resized = PROCESSOR.resize(gray)
    denoised = PROCESSOR.denoise(resized)
    enhanced = PROCESSOR.enhance_contrast(denoised)
    normalized = enhanced.astype(np.float32) / 255.0

    with torch.no_grad():
        img_tensor = torch.from_numpy(normalized).unsqueeze(0)  # (1, H, W)
        # Repeat the still image 5 times along dim 1 -> (1, 5, 1, H, W),
        # the sequence input the pipeline is called with.
        sequence = img_tensor.unsqueeze(0).repeat(1, 5, 1, 1).unsqueeze(2)
        sequence = sequence.to(DEVICE)

        results = MODEL(sequence)
        logits = CLASSIFIER(results["encoding"])
        probs = torch.softmax(logits, dim=1)
        anomaly_score = probs[0, 1].item()

    prediction = "ABNORMAL" if anomaly_score > 0.5 else "NORMAL"
    confidence = max(anomaly_score, 1 - anomaly_score) * 100

    # Grad-CAM needs gradients, so it runs outside the no_grad block.
    gradcam = compute_gradcam(img_tensor)
    if gradcam.shape[:2] != enhanced.shape[:2]:
        # addWeighted below requires both images to have the same size.
        gradcam = cv2.resize(gradcam, (enhanced.shape[1], enhanced.shape[0]))

    heatmap_colored = cv2.applyColorMap(
        (np.clip(gradcam, 0, 1) * 255).astype(np.uint8), cv2.COLORMAP_JET)
    base_bgr = cv2.cvtColor(enhanced, cv2.COLOR_GRAY2BGR)
    overlay = cv2.addWeighted(base_bgr, 0.6, heatmap_colored, 0.4, 0)
    overlay_rgb = cv2.cvtColor(overlay, cv2.COLOR_BGR2RGB)

    return {
        "prediction": prediction,
        "anomaly_score": round(anomaly_score, 4),
        "confidence": round(confidence, 1),
        "images": {
            "original": img_to_base64(original),
            "resized": img_to_base64(resized),
            "denoised": img_to_base64(denoised),
            "enhanced": img_to_base64(enhanced),
            "normalized": img_to_base64(normalized),
            "gradcam": img_to_base64(gradcam, cmap='jet'),
            "overlay": img_to_base64(overlay_rgb),
        },
    }


@app.route("/analyze", methods=["POST"])
def analyze():
    """Analyze an image uploaded in the multipart form field ``file``."""
    if "file" not in request.files:
        return jsonify({"error": "No file uploaded"}), 400

    img = _decode_image(request.files["file"].read())
    if img is None:
        return jsonify({"error": "Cannot read image"}), 400

    return jsonify(_run_analysis(img))


@app.route("/sample_images")
def sample_images():
    """Return the names of up to 12 ``.jpg`` sample images from the dataset."""
    import glob

    pattern = os.path.join(glob.escape(str(_SAMPLE_DIR)), "*.jpg")
    # Sort so the listing is deterministic; glob order is filesystem-dependent.
    samples = sorted(glob.glob(pattern))[:12]
    return jsonify([Path(s).name for s in samples])


@app.route("/analyze_sample/<filename>")
def analyze_sample(filename):
    """Analyze one of the bundled sample images, given by its file name."""
    # Accept only a plain file name: reject "..", "." and anything with a
    # path separator, so the route cannot reach outside the sample directory.
    if filename in ("", ".", "..") or Path(filename).name != filename:
        return jsonify({"error": "Sample not found"}), 404

    path = _SAMPLE_DIR / filename
    if not path.is_file():
        return jsonify({"error": "Sample not found"}), 404

    img = _decode_image(path.read_bytes())
    if img is None:
        return jsonify({"error": "Cannot read image"}), 400

    return jsonify(_run_analysis(img))
|
|
|
|
|
if __name__ == "__main__":
    # Load the networks before the server starts accepting requests.
    print("Loading model...")
    load_model()

    listen_port = int(os.getenv("PORT", 5000))
    print(f"Starting server on port {listen_port}")
    app.run(host="0.0.0.0", port=listen_port, debug=False)
|
|
|