Shivdutta commited on
Commit
90aceef
·
verified ·
1 Parent(s): a6456c8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -7
app.py CHANGED
@@ -42,15 +42,14 @@ def inference(input_img, num_gradcam_images=1, target_layer_number=-1, transpare
42
  rgb_img = np.transpose(org_img, (1, 2, 0))
43
  visualization.append(show_cam_on_image(org_img/255, grayscale_cam, use_rgb=True, image_weight=transparency))
44
 
 
45
  num_rows = 2
46
- num_cols = (num_gradcam_images - 1) // num_rows + 1
47
- fig, axes = plt.subplots(num_rows, num_cols, figsize=(12, 6))
48
- for i, visualization_img in enumerate(visualization):
49
- row_index = i // num_cols
50
- col_index = i % num_cols
51
- ax = axes[row_index, col_index] if num_rows > 1 else axes[col_index]
52
- ax.imshow(visualization_img)
53
  ax.axis('off')
 
54
  plt.tight_layout()
55
  buffer = BytesIO()
56
  plt.savefig(buffer, format='png')
 
42
  rgb_img = np.transpose(org_img, (1, 2, 0))
43
  visualization.append(show_cam_on_image(org_img/255, grayscale_cam, use_rgb=True, image_weight=transparency))
44
 
45
+ fig = plt.figure(figsize=(12, 5))
46
  num_rows = 2
47
+ num_cols = 5
48
+ for i in range(len(visualizations)):
49
+ ax = fig.add_subplot(num_rows, num_cols, i + 1)
50
+ ax.imshow(visualizations[i])
 
 
 
51
  ax.axis('off')
52
+
53
  plt.tight_layout()
54
  buffer = BytesIO()
55
  plt.savefig(buffer, format='png')