ArchCoder committed on
Commit
d57f983
·
verified ·
1 Parent(s): c530021

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +24 -5
app.py CHANGED
@@ -41,6 +41,8 @@ class DoubleConv(nn.Module):
41
  def forward(self, x):
42
  return self.conv(x)
43
 
 
 
44
  class AttentionBlock(nn.Module):
45
  def __init__(self, F_g, F_l, F_int):
46
  super(AttentionBlock, self).__init__()
@@ -67,7 +69,7 @@ class AttentionBlock(nn.Module):
67
  x1 = self.W_x(x)
68
  psi = self.relu(g1 + x1)
69
  psi = self.psi(psi)
70
- return x * psi, psi
71
 
72
  class AttentionUNET(nn.Module):
73
  def __init__(self, in_channels=1, out_channels=1, features=[32, 64, 128, 256]):
@@ -261,17 +263,34 @@ def apply_tta(model, input_tensor):
261
  return torch.mean(torch.stack(predictions), dim=0)
262
 
263
  def generate_attention_heatmap(attention_maps):
264
- """Generate attention heatmap"""
265
  if not attention_maps:
266
  return np.zeros((256, 256, 3))
267
 
268
- # Combine attention maps
269
- combined_att = torch.mean(torch.stack(attention_maps), dim=0).squeeze().cpu().numpy()
270
- combined_att = cv2.resize(combined_att, (256, 256))
 
 
 
 
 
 
 
 
 
 
 
 
 
271
  combined_att = (combined_att - combined_att.min()) / (combined_att.max() - combined_att.min() + 1e-8)
 
 
272
  heatmap = cv2.applyColorMap((combined_att * 255).astype(np.uint8), cv2.COLORMAP_JET)
 
273
  return heatmap
274
 
 
275
  def analyze_image(image, ground_truth, filename):
276
  """Main analysis function"""
277
  if model is None:
 
41
  def forward(self, x):
42
  return self.conv(x)
43
 
44
+
45
+ # Also, make sure your AttentionBlock.forward() returns the attention map:
46
  class AttentionBlock(nn.Module):
47
  def __init__(self, F_g, F_l, F_int):
48
  super(AttentionBlock, self).__init__()
 
69
  x1 = self.W_x(x)
70
  psi = self.relu(g1 + x1)
71
  psi = self.psi(psi)
72
+ return x * psi, psi # Return both attended features AND attention map
73
 
74
  class AttentionUNET(nn.Module):
75
  def __init__(self, in_channels=1, out_channels=1, features=[32, 64, 128, 256]):
 
263
  return torch.mean(torch.stack(predictions), dim=0)
264
 
265
def generate_attention_heatmap(attention_maps):
    """Combine per-layer attention maps into a single JET-colored heatmap.

    Parameters
    ----------
    attention_maps : list of torch.Tensor
        Attention tensors collected from the model. They may have different
        spatial resolutions; each is squeezed and resized to 256x256 before
        averaging.

    Returns
    -------
    np.ndarray
        A 256x256x3 uint8 BGR heatmap (cv2.COLORMAP_JET), or an all-zero
        uint8 image of the same shape when no attention maps are provided.
    """
    if not attention_maps:
        # FIX: match the dtype of the colormapped output (uint8) so both
        # branches return a uniform image type; the original returned
        # float64 zeros here, which downstream display code may mishandle.
        return np.zeros((256, 256, 3), dtype=np.uint8)

    # Resize every map to a common size before averaging — maps coming from
    # different decoder stages have different spatial resolutions, so a
    # direct torch.stack/mean would fail on shape mismatch.
    target_size = (256, 256)
    resized_maps = []
    for att_map in attention_maps:
        # NOTE(review): assumes squeeze() yields a 2-D array — confirm the
        # attention tensors are shaped (1, 1, H, W).
        att_np = att_map.squeeze().cpu().numpy()
        resized_maps.append(cv2.resize(att_np, target_size))

    # Safe to average now that all maps share the same shape.
    combined_att = np.mean(resized_maps, axis=0)

    # Min-max normalize to [0, 1]; the epsilon avoids division by zero on a
    # constant (flat) attention map.
    combined_att = (combined_att - combined_att.min()) / (combined_att.max() - combined_att.min() + 1e-8)

    # Colorize: scale to [0, 255] and apply OpenCV's JET colormap (BGR order).
    heatmap = cv2.applyColorMap((combined_att * 255).astype(np.uint8), cv2.COLORMAP_JET)

    return heatmap
292
 
293
+
294
  def analyze_image(image, ground_truth, filename):
295
  """Main analysis function"""
296
  if model is None: