MahatirTusher committed on
Commit
b06cc6f
·
verified ·
1 Parent(s): 13f3b95

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +3 -8
app.py CHANGED
@@ -3,7 +3,7 @@ import gradio as gr
3
  from PIL import Image
4
  import torch
5
  from transformers import ViTForImageClassification, ViTImageProcessor
6
- from datasets import load_dataset
7
  import matplotlib.pyplot as plt
8
  import numpy as np
9
  import cv2
@@ -14,7 +14,8 @@ processor = ViTImageProcessor.from_pretrained(model_name_or_path)
14
 
15
  # Load dataset (adjust dataset_path accordingly)
16
  dataset_path = "pawlo2013/chest_xray"
17
- train_dataset = load_dataset(dataset_path, split="train")
 
18
  class_names = train_dataset.features["label"].names
19
 
20
  # Load ViT model
@@ -33,7 +34,6 @@ model.eval()
33
  def classify_and_visualize(img, device="cpu", discard_ratio=0.9, head_fusion="mean"):
34
  img = img.convert("RGB")
35
  processed_input = processor(images=img, return_tensors="pt").to(device)
36
-
37
  processed_input = processed_input["pixel_values"].to(device)
38
 
39
  with torch.no_grad():
@@ -77,9 +77,7 @@ def show_final_layer_attention_maps(
77
  ):
78
 
79
  with torch.no_grad():
80
-
81
  image = processed_input.squeeze(0)
82
-
83
  image = image - image.min()
84
  image = image / image.max()
85
 
@@ -105,7 +103,6 @@ def show_final_layer_attention_maps(
105
  I = torch.eye(attention_heads_fused.size(-1)).to(device)
106
  a = (attention_heads_fused + 1.0 * I) / 2
107
  a = a / a.sum(dim=-1)
108
-
109
  result = torch.matmul(a, result)
110
 
111
  mask = result[0, 0, 1:]
@@ -114,7 +111,6 @@ def show_final_layer_attention_maps(
114
  mask = mask / np.max(mask)
115
 
116
  mask = cv2.resize(mask, (224, 224))
117
-
118
  mask = (mask - np.min(mask)) / (np.max(mask) - np.min(mask))
119
  heatmap = plt.cm.jet(mask)[:, :, :3]
120
 
@@ -127,7 +123,6 @@ def show_final_layer_attention_maps(
127
  superimposed_img_pil = Image.fromarray(
128
  (superimposed_img * 255).astype(np.uint8)
129
  )
130
-
131
  return superimposed_img_pil
132
 
133
 
 
3
  from PIL import Image
4
  import torch
5
  from transformers import ViTForImageClassification, ViTImageProcessor
6
+ from datasets import load_dataset, DownloadConfig
7
  import matplotlib.pyplot as plt
8
  import numpy as np
9
  import cv2
 
14
 
15
  # Load dataset (adjust dataset_path accordingly)
16
  dataset_path = "pawlo2013/chest_xray"
17
+ download_config = DownloadConfig(timeout=100, max_retries=10)
18
+ train_dataset = load_dataset(dataset_path, split="train", download_config=download_config)
19
  class_names = train_dataset.features["label"].names
20
 
21
  # Load ViT model
 
34
  def classify_and_visualize(img, device="cpu", discard_ratio=0.9, head_fusion="mean"):
35
  img = img.convert("RGB")
36
  processed_input = processor(images=img, return_tensors="pt").to(device)
 
37
  processed_input = processed_input["pixel_values"].to(device)
38
 
39
  with torch.no_grad():
 
77
  ):
78
 
79
  with torch.no_grad():
 
80
  image = processed_input.squeeze(0)
 
81
  image = image - image.min()
82
  image = image / image.max()
83
 
 
103
  I = torch.eye(attention_heads_fused.size(-1)).to(device)
104
  a = (attention_heads_fused + 1.0 * I) / 2
105
  a = a / a.sum(dim=-1)
 
106
  result = torch.matmul(a, result)
107
 
108
  mask = result[0, 0, 1:]
 
111
  mask = mask / np.max(mask)
112
 
113
  mask = cv2.resize(mask, (224, 224))
 
114
  mask = (mask - np.min(mask)) / (np.max(mask) - np.min(mask))
115
  heatmap = plt.cm.jet(mask)[:, :, :3]
116
 
 
123
  superimposed_img_pil = Image.fromarray(
124
  (superimposed_img * 255).astype(np.uint8)
125
  )
 
126
  return superimposed_img_pil
127
 
128