Spaces:
Sleeping
Sleeping
Added Grad-CAM
Browse files- app.py +97 -2
- requirements.txt +2 -1
- static/index.html +51 -27
app.py
CHANGED
|
@@ -8,6 +8,11 @@ from dotenv import load_dotenv
|
|
| 8 |
from model_loader import load_alexnet_model, preprocess_image
|
| 9 |
from flask_cors import CORS
|
| 10 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 11 |
load_dotenv(override=True)
|
| 12 |
|
| 13 |
# HF sets PORT dynamically. Fall back to 7860 locally.
|
|
@@ -55,6 +60,96 @@ def load_image(file_stream_or_path):
|
|
| 55 |
return Image.open(file_stream_or_path).convert("RGB")
|
| 56 |
return Image.open(file_stream_or_path).convert("RGB")
|
| 57 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 58 |
def run_inference(img: Image.Image) -> Dict[str, Any]:
|
| 59 |
input_tensor = preprocess_image(img).to(DEVICE)
|
| 60 |
with torch.no_grad():
|
|
@@ -78,7 +173,7 @@ def predict_alexnet() -> Any:
|
|
| 78 |
return jsonify({"error": "Empty file."}), 400
|
| 79 |
try:
|
| 80 |
img = load_image(file.stream)
|
| 81 |
-
result =
|
| 82 |
return jsonify(result)
|
| 83 |
except Exception as e:
|
| 84 |
return jsonify({"error": f"Failed to process image: {e}"}), 400
|
|
@@ -103,7 +198,7 @@ def predict_preset() -> Any:
|
|
| 103 |
|
| 104 |
try:
|
| 105 |
img = load_image(path)
|
| 106 |
-
result =
|
| 107 |
result.update({"preset": key, "path": path})
|
| 108 |
return jsonify(result)
|
| 109 |
except Exception as e:
|
|
|
|
| 8 |
from model_loader import load_alexnet_model, preprocess_image
|
| 9 |
from flask_cors import CORS
|
| 10 |
|
| 11 |
+
from io import BytesIO
|
| 12 |
+
import base64
|
| 13 |
+
import numpy as np
|
| 14 |
+
|
| 15 |
+
|
| 16 |
load_dotenv(override=True)
|
| 17 |
|
| 18 |
# HF sets PORT dynamically. Fall back to 7860 locally.
|
|
|
|
| 60 |
return Image.open(file_stream_or_path).convert("RGB")
|
| 61 |
return Image.open(file_stream_or_path).convert("RGB")
|
| 62 |
|
| 63 |
+
def generate_gradcam(img_pil: Image.Image, target_idx: int) -> str:
|
| 64 |
+
"""
|
| 65 |
+
Returns a data URL (PNG) of the Grad-CAM overlay for the target class.
|
| 66 |
+
"""
|
| 67 |
+
model.eval()
|
| 68 |
+
orig_w, orig_h = img_pil.size
|
| 69 |
+
|
| 70 |
+
# Last conv of standard AlexNet
|
| 71 |
+
target_layer = model.features[10]
|
| 72 |
+
|
| 73 |
+
activations = []
|
| 74 |
+
gradients = []
|
| 75 |
+
|
| 76 |
+
def fwd_hook(_, __, out):
|
| 77 |
+
# Save activations (detached) and attach a tensor hook to capture gradients
|
| 78 |
+
activations.append(out.detach())
|
| 79 |
+
out.register_hook(lambda g: gradients.append(g.detach().clone()))
|
| 80 |
+
|
| 81 |
+
handle = target_layer.register_forward_hook(fwd_hook)
|
| 82 |
+
try:
|
| 83 |
+
# Forward
|
| 84 |
+
input_tensor = preprocess_image(img_pil).to(DEVICE)
|
| 85 |
+
output = model(input_tensor) # [1, C]
|
| 86 |
+
|
| 87 |
+
# Backward on the selected class
|
| 88 |
+
if target_idx < 0 or target_idx >= output.shape[1]:
|
| 89 |
+
raise ValueError(f"target_idx {target_idx} out of range for output dim {output.shape[1]}")
|
| 90 |
+
score = output[0, target_idx]
|
| 91 |
+
model.zero_grad(set_to_none=True)
|
| 92 |
+
score.backward()
|
| 93 |
+
|
| 94 |
+
# Ensure hooks fired
|
| 95 |
+
if not activations or not gradients:
|
| 96 |
+
raise RuntimeError("Grad-CAM hooks did not capture activations/gradients")
|
| 97 |
+
|
| 98 |
+
A = activations[-1] # [1, C, H, W]
|
| 99 |
+
dA = gradients[-1] # [1, C, H, W]
|
| 100 |
+
|
| 101 |
+
# Weights: global-average-pool the gradients
|
| 102 |
+
weights = dA.mean(dim=(2, 3), keepdim=True) # [1, C, 1, 1]
|
| 103 |
+
cam = (weights * A).sum(dim=1, keepdim=False) # [1, H, W]
|
| 104 |
+
cam = torch.relu(cam)[0] # [H, W]
|
| 105 |
+
|
| 106 |
+
# Normalize to [0,1]
|
| 107 |
+
cam -= cam.min()
|
| 108 |
+
if cam.max() > 0:
|
| 109 |
+
cam /= cam.max()
|
| 110 |
+
|
| 111 |
+
# Resize CAM to original image size
|
| 112 |
+
cam_np = cam.detach().cpu().numpy()
|
| 113 |
+
cam_img = Image.fromarray((cam_np * 255).astype(np.uint8), mode="L")
|
| 114 |
+
cam_img = cam_img.resize((orig_w, orig_h), resample=Image.BILINEAR)
|
| 115 |
+
|
| 116 |
+
# Red alpha overlay
|
| 117 |
+
heat_rgba = Image.new("RGBA", (orig_w, orig_h), (255, 0, 0, 0))
|
| 118 |
+
heat_rgba.putalpha(cam_img)
|
| 119 |
+
base = img_pil.convert("RGBA")
|
| 120 |
+
overlayed = Image.alpha_composite(base, heat_rgba)
|
| 121 |
+
|
| 122 |
+
# Encode to data URL
|
| 123 |
+
buff = BytesIO()
|
| 124 |
+
overlayed.save(buff, format="PNG")
|
| 125 |
+
b64 = base64.b64encode(buff.getvalue()).decode("utf-8")
|
| 126 |
+
return f"data:image/png;base64,{b64}"
|
| 127 |
+
finally:
|
| 128 |
+
handle.remove() # <-- remove the actual handle you registered
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
def run_inference_with_gradcam(img: Image.Image) -> Dict[str, Any]:
|
| 132 |
+
"""Run softmax inference and also compute Grad-CAM for the predicted class."""
|
| 133 |
+
# Regular inference (no grad) for probabilities
|
| 134 |
+
input_tensor = preprocess_image(img).to(DEVICE)
|
| 135 |
+
with torch.no_grad():
|
| 136 |
+
output = model(input_tensor)
|
| 137 |
+
probabilities = F.softmax(output[0], dim=0).detach().cpu()
|
| 138 |
+
|
| 139 |
+
pred_prob, pred_idx = torch.max(probabilities, dim=0)
|
| 140 |
+
predicted_class = classes[int(pred_idx)]
|
| 141 |
+
|
| 142 |
+
# Grad-CAM for predicted index
|
| 143 |
+
gradcam_data_url = generate_gradcam(img, int(pred_idx))
|
| 144 |
+
|
| 145 |
+
return {
|
| 146 |
+
"class": predicted_class,
|
| 147 |
+
"confidence": float(pred_prob),
|
| 148 |
+
"probabilities": {cls: float(prob) for cls, prob in zip(classes, probabilities.tolist())},
|
| 149 |
+
"gradcam": gradcam_data_url,
|
| 150 |
+
}
|
| 151 |
+
|
| 152 |
+
|
| 153 |
def run_inference(img: Image.Image) -> Dict[str, Any]:
|
| 154 |
input_tensor = preprocess_image(img).to(DEVICE)
|
| 155 |
with torch.no_grad():
|
|
|
|
| 173 |
return jsonify({"error": "Empty file."}), 400
|
| 174 |
try:
|
| 175 |
img = load_image(file.stream)
|
| 176 |
+
result = run_inference_with_gradcam(img) # << changed
|
| 177 |
return jsonify(result)
|
| 178 |
except Exception as e:
|
| 179 |
return jsonify({"error": f"Failed to process image: {e}"}), 400
|
|
|
|
| 198 |
|
| 199 |
try:
|
| 200 |
img = load_image(path)
|
| 201 |
+
result = run_inference_with_gradcam(img) # << changed
|
| 202 |
result.update({"preset": key, "path": path})
|
| 203 |
return jsonify(result)
|
| 204 |
except Exception as e:
|
requirements.txt
CHANGED
|
@@ -2,4 +2,5 @@ flask>=3.0.0
|
|
| 2 |
pillow>=10.0.0
|
| 3 |
gunicorn>=21.2.0
|
| 4 |
python-dotenv>=1.0.0
|
| 5 |
-
Flask-Cors>=4.0.0
|
|
|
|
|
|
| 2 |
pillow>=10.0.0
|
| 3 |
gunicorn>=21.2.0
|
| 4 |
python-dotenv>=1.0.0
|
| 5 |
+
Flask-Cors>=4.0.0
|
| 6 |
+
numpy>=1.24.0
|
static/index.html
CHANGED
|
@@ -155,6 +155,12 @@
|
|
| 155 |
<div class="probabilities-title">All Class Probabilities</div>
|
| 156 |
<div id="probabilitiesList"></div>
|
| 157 |
</div>
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 158 |
</div>
|
| 159 |
|
| 160 |
<div class="error-message" id="errorMessage"></div>
|
|
@@ -180,6 +186,10 @@
|
|
| 180 |
const loadingSpinner = document.getElementById('loadingSpinner');
|
| 181 |
const presetGrid = document.getElementById('presetGrid');
|
| 182 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 183 |
let currentFile = null;
|
| 184 |
let currentPreset = null; // 'TP' | 'TN' | 'FN' | 'FP' | null
|
| 185 |
|
|
@@ -344,35 +354,49 @@
|
|
| 344 |
}
|
| 345 |
});
|
| 346 |
|
| 347 |
-
|
| 348 |
-
|
| 349 |
-
|
| 350 |
-
|
| 351 |
-
|
| 352 |
-
|
| 353 |
-
|
| 354 |
-
|
| 355 |
-
|
| 356 |
-
|
| 357 |
-
|
| 358 |
-
<div class="probability-label">
|
| 359 |
-
<span class="class-name" style="${isTop?'font-weight:700;color:#667eea;':''}">${className}</span>
|
| 360 |
-
<span class="class-prob" style="${isTop?'font-weight:700;color:#667eea;':''}">${probPercent}%</span>
|
| 361 |
-
</div>
|
| 362 |
-
<div class="probability-bar-bg">
|
| 363 |
-
<div class="probability-bar" style="width:0%;" data-width="${probPercent}"></div>
|
| 364 |
-
</div>
|
| 365 |
-
`;
|
| 366 |
-
probabilitiesList.appendChild(div);
|
| 367 |
-
});
|
| 368 |
-
resultsSection.classList.add('active');
|
| 369 |
-
setTimeout(() => {
|
| 370 |
-
probabilitiesList.querySelectorAll('.probability-bar').forEach(bar => {
|
| 371 |
-
bar.style.width = bar.getAttribute('data-width') + '%';
|
| 372 |
-
});
|
| 373 |
-
}, 100);
|
| 374 |
}
|
| 375 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 376 |
function showLoading() { loadingSpinner.classList.add('active'); classifyBtn.disabled = true; }
|
| 377 |
function hideLoading() { loadingSpinner.classList.remove('active'); classifyBtn.disabled = false; }
|
| 378 |
function hideResults() { resultsSection.classList.remove('active'); }
|
|
|
|
| 155 |
<div class="probabilities-title">All Class Probabilities</div>
|
| 156 |
<div id="probabilitiesList"></div>
|
| 157 |
</div>
|
| 158 |
+
|
| 159 |
+
<div class="gradcam-container" id="gradcamContainer" style="display:none; margin:16px 0 20px;">
|
| 160 |
+
<div class="probabilities-title" style="margin-bottom:10px;">Grad-CAM (Predicted Class)</div>
|
| 161 |
+
<img id="gradcamImage" class="preview-image" alt="Grad-CAM visualization" style="max-width:480px; width:100%; border-radius:10px; box-shadow:0 4px 20px rgba(0,0,0,0.08);" />
|
| 162 |
+
</div>
|
| 163 |
+
|
| 164 |
</div>
|
| 165 |
|
| 166 |
<div class="error-message" id="errorMessage"></div>
|
|
|
|
| 186 |
const loadingSpinner = document.getElementById('loadingSpinner');
|
| 187 |
const presetGrid = document.getElementById('presetGrid');
|
| 188 |
|
| 189 |
+
const gradcamContainer = document.getElementById('gradcamContainer');
|
| 190 |
+
const gradcamImage = document.getElementById('gradcamImage');
|
| 191 |
+
|
| 192 |
+
|
| 193 |
let currentFile = null;
|
| 194 |
let currentPreset = null; // 'TP' | 'TN' | 'FN' | 'FP' | null
|
| 195 |
|
|
|
|
| 354 |
}
|
| 355 |
});
|
| 356 |
|
| 357 |
+
function displayResults(result) {
|
| 358 |
+
predictedClass.textContent = result.class;
|
| 359 |
+
confidenceScore.textContent = `${(result.confidence * 100).toFixed(2)}% Confidence`;
|
| 360 |
+
|
| 361 |
+
// --- NEW: Grad-CAM rendering ---
|
| 362 |
+
if (result.gradcam) {
|
| 363 |
+
gradcamImage.src = result.gradcam;
|
| 364 |
+
gradcamContainer.style.display = 'block';
|
| 365 |
+
} else {
|
| 366 |
+
gradcamContainer.style.display = 'none';
|
| 367 |
+
gradcamImage.removeAttribute('src');
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 368 |
}
|
| 369 |
|
| 370 |
+
const sortedProbs = Object.entries(result.probabilities)
|
| 371 |
+
.sort(([, a], [, b]) => b - a)
|
| 372 |
+
.slice(0, 10);
|
| 373 |
+
|
| 374 |
+
probabilitiesList.innerHTML = '';
|
| 375 |
+
sortedProbs.forEach(([className, prob], index) => {
|
| 376 |
+
const probPercent = (prob * 100).toFixed(2);
|
| 377 |
+
const isTop = index === 0;
|
| 378 |
+
const div = document.createElement('div');
|
| 379 |
+
div.className = 'probability-item';
|
| 380 |
+
div.innerHTML = `
|
| 381 |
+
<div class="probability-label">
|
| 382 |
+
<span class="class-name" style="${isTop ? 'font-weight:700;color:#667eea;' : ''}">${className}</span>
|
| 383 |
+
<span class="class-prob" style="${isTop ? 'font-weight:700;color:#667eea;' : ''}">${probPercent}%</span>
|
| 384 |
+
</div>
|
| 385 |
+
<div class="probability-bar-bg">
|
| 386 |
+
<div class="probability-bar" style="width:0%;" data-width="${probPercent}"></div>
|
| 387 |
+
</div>
|
| 388 |
+
`;
|
| 389 |
+
probabilitiesList.appendChild(div);
|
| 390 |
+
});
|
| 391 |
+
|
| 392 |
+
resultsSection.classList.add('active');
|
| 393 |
+
setTimeout(() => {
|
| 394 |
+
probabilitiesList.querySelectorAll('.probability-bar').forEach(bar => {
|
| 395 |
+
bar.style.width = bar.getAttribute('data-width') + '%';
|
| 396 |
+
});
|
| 397 |
+
}, 100);
|
| 398 |
+
}
|
| 399 |
+
|
| 400 |
function showLoading() { loadingSpinner.classList.add('active'); classifyBtn.disabled = true; }
|
| 401 |
function hideLoading() { loadingSpinner.classList.remove('active'); classifyBtn.disabled = false; }
|
| 402 |
function hideResults() { resultsSection.classList.remove('active'); }
|