gajavegs committed on
Commit
259a0f2
·
1 Parent(s): 52af76f

Added Grad-CAM

Browse files
Files changed (3) hide show
  1. app.py +97 -2
  2. requirements.txt +2 -1
  3. static/index.html +51 -27
app.py CHANGED
@@ -8,6 +8,11 @@ from dotenv import load_dotenv
8
  from model_loader import load_alexnet_model, preprocess_image
9
  from flask_cors import CORS
10
 
 
 
 
 
 
11
  load_dotenv(override=True)
12
 
13
  # HF sets PORT dynamically. Fall back to 7860 locally.
@@ -55,6 +60,96 @@ def load_image(file_stream_or_path):
55
  return Image.open(file_stream_or_path).convert("RGB")
56
  return Image.open(file_stream_or_path).convert("RGB")
57
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58
  def run_inference(img: Image.Image) -> Dict[str, Any]:
59
  input_tensor = preprocess_image(img).to(DEVICE)
60
  with torch.no_grad():
@@ -78,7 +173,7 @@ def predict_alexnet() -> Any:
78
  return jsonify({"error": "Empty file."}), 400
79
  try:
80
  img = load_image(file.stream)
81
- result = run_inference(img)
82
  return jsonify(result)
83
  except Exception as e:
84
  return jsonify({"error": f"Failed to process image: {e}"}), 400
@@ -103,7 +198,7 @@ def predict_preset() -> Any:
103
 
104
  try:
105
  img = load_image(path)
106
- result = run_inference(img)
107
  result.update({"preset": key, "path": path})
108
  return jsonify(result)
109
  except Exception as e:
 
8
  from model_loader import load_alexnet_model, preprocess_image
9
  from flask_cors import CORS
10
 
11
+ from io import BytesIO
12
+ import base64
13
+ import numpy as np
14
+
15
+
16
  load_dotenv(override=True)
17
 
18
  # HF sets PORT dynamically. Fall back to 7860 locally.
 
60
  return Image.open(file_stream_or_path).convert("RGB")
61
  return Image.open(file_stream_or_path).convert("RGB")
62
 
63
def generate_gradcam(img_pil: Image.Image, target_idx: int) -> str:
    """Return a PNG data URL of the Grad-CAM overlay for the target class.

    Runs one forward and one backward pass through the module-level ``model``,
    weights the activations of ``model.features[10]`` by the spatially
    averaged gradients of the target logit, and blends the resulting heat map
    in red over the input image at its original size.

    Args:
        img_pil: Image to explain; the overlay is rendered at ``img_pil.size``.
        target_idx: Index into the model's output dimension to explain.

    Returns:
        A ``data:image/png;base64,...`` string containing the overlay.

    Raises:
        ValueError: If ``target_idx`` is outside the model's output dimension.
        RuntimeError: If the hooks captured no activations or gradients.
    """
    model.eval()
    orig_w, orig_h = img_pil.size

    # Last conv of standard AlexNet
    target_layer = model.features[10]

    activations = []
    gradients = []

    def fwd_hook(_module, _inputs, out):
        # detach() alone shares storage with `out`; a following in-place
        # activation (torchvision's AlexNet uses ReLU(inplace=True)) would
        # overwrite the saved values, so take a private copy.
        activations.append(out.detach().clone())
        # Tensor hook fires during backward with d(score)/d(out).
        out.register_hook(lambda g: gradients.append(g.detach().clone()))

    handle = target_layer.register_forward_hook(fwd_hook)
    try:
        # Grad-CAM needs autograd even if the caller is inside torch.no_grad().
        with torch.enable_grad():
            input_tensor = preprocess_image(img_pil).to(DEVICE)
            output = model(input_tensor)  # [1, C]

            num_classes = output.shape[1]
            if not 0 <= target_idx < num_classes:
                raise ValueError(
                    f"target_idx {target_idx} out of range for output dim {num_classes}"
                )

            # Backward on the selected class logit only.
            model.zero_grad(set_to_none=True)
            output[0, target_idx].backward()

        if not activations or not gradients:
            raise RuntimeError("Grad-CAM hooks did not capture activations/gradients")

        acts = activations[-1]   # [1, C, H, W]
        grads = gradients[-1]    # [1, C, H, W]

        # Channel weights: global-average-pool the gradients.
        weights = grads.mean(dim=(2, 3), keepdim=True)  # [1, C, 1, 1]
        cam = torch.relu((weights * acts).sum(dim=1))[0]  # [H, W]

        # Normalize to [0, 1]; an all-zero map is left as zeros.
        cam = cam - cam.min()
        peak = cam.max()
        if peak > 0:
            cam = cam / peak

        # Resize CAM to the original image size. A 2-D uint8 array is read by
        # Pillow as mode "L", so no explicit mode argument is needed.
        cam_np = cam.cpu().numpy()
        cam_img = Image.fromarray((cam_np * 255).astype(np.uint8))
        cam_img = cam_img.resize((orig_w, orig_h), resample=Image.BILINEAR)

        # Solid red layer whose per-pixel alpha is the CAM intensity.
        heat_rgba = Image.new("RGBA", (orig_w, orig_h), (255, 0, 0, 0))
        heat_rgba.putalpha(cam_img)
        overlayed = Image.alpha_composite(img_pil.convert("RGBA"), heat_rgba)

        # Encode to data URL
        buff = BytesIO()
        overlayed.save(buff, format="PNG")
        b64 = base64.b64encode(buff.getvalue()).decode("utf-8")
        return f"data:image/png;base64,{b64}"
    finally:
        handle.remove()
        # backward() populated .grad on every parameter; release those
        # buffers so they do not stay allocated between requests.
        model.zero_grad(set_to_none=True)
129
+
130
+
131
def run_inference_with_gradcam(img: Image.Image) -> Dict[str, Any]:
    """Classify ``img`` and attach a Grad-CAM overlay for the winning class.

    Returns a dict with the predicted label (``"class"``), its softmax
    probability (``"confidence"``), the per-class probability map
    (``"probabilities"``) and the overlay as a data URL (``"gradcam"``).
    """
    # Probabilities come from a gradient-free forward pass.
    batch = preprocess_image(img).to(DEVICE)
    with torch.no_grad():
        logits = model(batch)
        probs = F.softmax(logits[0], dim=0).detach().cpu()

    top_prob, top_idx = torch.max(probs, dim=0)
    top_index = int(top_idx)
    label = classes[top_index]

    # generate_gradcam performs its own forward/backward pass for this class.
    overlay_url = generate_gradcam(img, top_index)

    per_class = {}
    for name, value in zip(classes, probs.tolist()):
        per_class[name] = float(value)

    return {
        "class": label,
        "confidence": float(top_prob),
        "probabilities": per_class,
        "gradcam": overlay_url,
    }
151
+
152
+
153
  def run_inference(img: Image.Image) -> Dict[str, Any]:
154
  input_tensor = preprocess_image(img).to(DEVICE)
155
  with torch.no_grad():
 
173
  return jsonify({"error": "Empty file."}), 400
174
  try:
175
  img = load_image(file.stream)
176
+ result = run_inference_with_gradcam(img) # << changed
177
  return jsonify(result)
178
  except Exception as e:
179
  return jsonify({"error": f"Failed to process image: {e}"}), 400
 
198
 
199
  try:
200
  img = load_image(path)
201
+ result = run_inference_with_gradcam(img) # << changed
202
  result.update({"preset": key, "path": path})
203
  return jsonify(result)
204
  except Exception as e:
requirements.txt CHANGED
@@ -2,4 +2,5 @@ flask>=3.0.0
2
  pillow>=10.0.0
3
  gunicorn>=21.2.0
4
  python-dotenv>=1.0.0
5
- Flask-Cors>=4.0.0
 
 
2
  pillow>=10.0.0
3
  gunicorn>=21.2.0
4
  python-dotenv>=1.0.0
5
+ Flask-Cors>=4.0.0
6
+ numpy>=1.24.0
static/index.html CHANGED
@@ -155,6 +155,12 @@
155
  <div class="probabilities-title">All Class Probabilities</div>
156
  <div id="probabilitiesList"></div>
157
  </div>
 
 
 
 
 
 
158
  </div>
159
 
160
  <div class="error-message" id="errorMessage"></div>
@@ -180,6 +186,10 @@
180
  const loadingSpinner = document.getElementById('loadingSpinner');
181
  const presetGrid = document.getElementById('presetGrid');
182
 
 
 
 
 
183
  let currentFile = null;
184
  let currentPreset = null; // 'TP' | 'TN' | 'FN' | 'FP' | null
185
 
@@ -344,35 +354,49 @@
344
  }
345
  });
346
 
347
- function displayResults(result) {
348
- predictedClass.textContent = result.class;
349
- confidenceScore.textContent = `${(result.confidence * 100).toFixed(2)}% Confidence`;
350
- const sortedProbs = Object.entries(result.probabilities).sort(([,a],[,b])=>b-a).slice(0,10);
351
- probabilitiesList.innerHTML = '';
352
- sortedProbs.forEach(([className, prob], index) => {
353
- const probPercent = (prob * 100).toFixed(2);
354
- const isTop = index === 0;
355
- const div = document.createElement('div');
356
- div.className = 'probability-item';
357
- div.innerHTML = `
358
- <div class="probability-label">
359
- <span class="class-name" style="${isTop?'font-weight:700;color:#667eea;':''}">${className}</span>
360
- <span class="class-prob" style="${isTop?'font-weight:700;color:#667eea;':''}">${probPercent}%</span>
361
- </div>
362
- <div class="probability-bar-bg">
363
- <div class="probability-bar" style="width:0%;" data-width="${probPercent}"></div>
364
- </div>
365
- `;
366
- probabilitiesList.appendChild(div);
367
- });
368
- resultsSection.classList.add('active');
369
- setTimeout(() => {
370
- probabilitiesList.querySelectorAll('.probability-bar').forEach(bar => {
371
- bar.style.width = bar.getAttribute('data-width') + '%';
372
- });
373
- }, 100);
374
  }
375
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
376
  function showLoading() { loadingSpinner.classList.add('active'); classifyBtn.disabled = true; }
377
  function hideLoading() { loadingSpinner.classList.remove('active'); classifyBtn.disabled = false; }
378
  function hideResults() { resultsSection.classList.remove('active'); }
 
155
  <div class="probabilities-title">All Class Probabilities</div>
156
  <div id="probabilitiesList"></div>
157
  </div>
158
+
159
+ <div class="gradcam-container" id="gradcamContainer" style="display:none; margin:16px 0 20px;">
160
+ <div class="probabilities-title" style="margin-bottom:10px;">Grad-CAM (Predicted Class)</div>
161
+ <img id="gradcamImage" class="preview-image" alt="Grad-CAM visualization" style="max-width:480px; width:100%; border-radius:10px; box-shadow:0 4px 20px rgba(0,0,0,0.08);" />
162
+ </div>
163
+
164
  </div>
165
 
166
  <div class="error-message" id="errorMessage"></div>
 
186
  const loadingSpinner = document.getElementById('loadingSpinner');
187
  const presetGrid = document.getElementById('presetGrid');
188
 
189
+ const gradcamContainer = document.getElementById('gradcamContainer');
190
+ const gradcamImage = document.getElementById('gradcamImage');
191
+
192
+
193
  let currentFile = null;
194
  let currentPreset = null; // 'TP' | 'TN' | 'FN' | 'FP' | null
195
 
 
354
  }
355
  });
356
 
357
function displayResults(result) {
    // Headline: predicted label and its confidence as a percentage.
    predictedClass.textContent = result.class;
    confidenceScore.textContent = `${(result.confidence * 100).toFixed(2)}% Confidence`;

    // The Grad-CAM panel is shown only when the response carries an overlay.
    if (!result.gradcam) {
        gradcamContainer.style.display = 'none';
        gradcamImage.removeAttribute('src');
    } else {
        gradcamImage.src = result.gradcam;
        gradcamContainer.style.display = 'block';
    }

    // Ten most probable classes, highest first.
    const topTen = Object.entries(result.probabilities)
        .sort((left, right) => right[1] - left[1])
        .slice(0, 10);

    probabilitiesList.innerHTML = '';
    topTen.forEach(([label, p], rank) => {
        const percent = (p * 100).toFixed(2);
        // Only the first (most probable) row is emphasised.
        const emphasis = rank === 0 ? 'font-weight:700;color:#667eea;' : '';
        const row = document.createElement('div');
        row.className = 'probability-item';
        row.innerHTML = `
            <div class="probability-label">
                <span class="class-name" style="${emphasis}">${label}</span>
                <span class="class-prob" style="${emphasis}">${percent}%</span>
            </div>
            <div class="probability-bar-bg">
                <div class="probability-bar" style="width:0%;" data-width="${percent}"></div>
            </div>
        `;
        probabilitiesList.appendChild(row);
    });

    resultsSection.classList.add('active');

    // Bars are inserted at width 0 and widened shortly afterwards
    // (presumably so a CSS width transition can animate; confirm in the stylesheet).
    setTimeout(() => {
        for (const bar of probabilitiesList.querySelectorAll('.probability-bar')) {
            bar.style.width = bar.getAttribute('data-width') + '%';
        }
    }, 100);
}
399
+
400
  function showLoading() { loadingSpinner.classList.add('active'); classifyBtn.disabled = true; }
401
  function hideLoading() { loadingSpinner.classList.remove('active'); classifyBtn.disabled = false; }
402
  function hideResults() { resultsSection.classList.remove('active'); }