Spaces:
Build error
Build error
Update app.py
Browse files
app.py
CHANGED
|
@@ -3,7 +3,7 @@ import gradio as gr
|
|
| 3 |
from PIL import Image
|
| 4 |
import torch
|
| 5 |
from transformers import ViTForImageClassification, ViTImageProcessor
|
| 6 |
-
from datasets import load_dataset
|
| 7 |
import matplotlib.pyplot as plt
|
| 8 |
import numpy as np
|
| 9 |
import cv2
|
|
@@ -14,7 +14,8 @@ processor = ViTImageProcessor.from_pretrained(model_name_or_path)
|
|
| 14 |
|
| 15 |
# Load dataset (adjust dataset_path accordingly)
|
| 16 |
dataset_path = "pawlo2013/chest_xray"
|
| 17 |
-
|
|
|
|
| 18 |
class_names = train_dataset.features["label"].names
|
| 19 |
|
| 20 |
# Load ViT model
|
|
@@ -33,7 +34,6 @@ model.eval()
|
|
| 33 |
def classify_and_visualize(img, device="cpu", discard_ratio=0.9, head_fusion="mean"):
|
| 34 |
img = img.convert("RGB")
|
| 35 |
processed_input = processor(images=img, return_tensors="pt").to(device)
|
| 36 |
-
|
| 37 |
processed_input = processed_input["pixel_values"].to(device)
|
| 38 |
|
| 39 |
with torch.no_grad():
|
|
@@ -77,9 +77,7 @@ def show_final_layer_attention_maps(
|
|
| 77 |
):
|
| 78 |
|
| 79 |
with torch.no_grad():
|
| 80 |
-
|
| 81 |
image = processed_input.squeeze(0)
|
| 82 |
-
|
| 83 |
image = image - image.min()
|
| 84 |
image = image / image.max()
|
| 85 |
|
|
@@ -105,7 +103,6 @@ def show_final_layer_attention_maps(
|
|
| 105 |
I = torch.eye(attention_heads_fused.size(-1)).to(device)
|
| 106 |
a = (attention_heads_fused + 1.0 * I) / 2
|
| 107 |
a = a / a.sum(dim=-1)
|
| 108 |
-
|
| 109 |
result = torch.matmul(a, result)
|
| 110 |
|
| 111 |
mask = result[0, 0, 1:]
|
|
@@ -114,7 +111,6 @@ def show_final_layer_attention_maps(
|
|
| 114 |
mask = mask / np.max(mask)
|
| 115 |
|
| 116 |
mask = cv2.resize(mask, (224, 224))
|
| 117 |
-
|
| 118 |
mask = (mask - np.min(mask)) / (np.max(mask) - np.min(mask))
|
| 119 |
heatmap = plt.cm.jet(mask)[:, :, :3]
|
| 120 |
|
|
@@ -127,7 +123,6 @@ def show_final_layer_attention_maps(
|
|
| 127 |
superimposed_img_pil = Image.fromarray(
|
| 128 |
(superimposed_img * 255).astype(np.uint8)
|
| 129 |
)
|
| 130 |
-
|
| 131 |
return superimposed_img_pil
|
| 132 |
|
| 133 |
|
|
|
|
| 3 |
from PIL import Image
|
| 4 |
import torch
|
| 5 |
from transformers import ViTForImageClassification, ViTImageProcessor
|
| 6 |
+
from datasets import load_dataset, DownloadConfig
|
| 7 |
import matplotlib.pyplot as plt
|
| 8 |
import numpy as np
|
| 9 |
import cv2
|
|
|
|
| 14 |
|
| 15 |
# Load dataset (adjust dataset_path accordingly)
|
| 16 |
dataset_path = "pawlo2013/chest_xray"
|
| 17 |
+
download_config = DownloadConfig(timeout=100, max_retries=10)
|
| 18 |
+
train_dataset = load_dataset(dataset_path, split="train", download_config=download_config)
|
| 19 |
class_names = train_dataset.features["label"].names
|
| 20 |
|
| 21 |
# Load ViT model
|
|
|
|
| 34 |
def classify_and_visualize(img, device="cpu", discard_ratio=0.9, head_fusion="mean"):
|
| 35 |
img = img.convert("RGB")
|
| 36 |
processed_input = processor(images=img, return_tensors="pt").to(device)
|
|
|
|
| 37 |
processed_input = processed_input["pixel_values"].to(device)
|
| 38 |
|
| 39 |
with torch.no_grad():
|
|
|
|
| 77 |
):
|
| 78 |
|
| 79 |
with torch.no_grad():
|
|
|
|
| 80 |
image = processed_input.squeeze(0)
|
|
|
|
| 81 |
image = image - image.min()
|
| 82 |
image = image / image.max()
|
| 83 |
|
|
|
|
| 103 |
I = torch.eye(attention_heads_fused.size(-1)).to(device)
|
| 104 |
a = (attention_heads_fused + 1.0 * I) / 2
|
| 105 |
a = a / a.sum(dim=-1)
|
|
|
|
| 106 |
result = torch.matmul(a, result)
|
| 107 |
|
| 108 |
mask = result[0, 0, 1:]
|
|
|
|
| 111 |
mask = mask / np.max(mask)
|
| 112 |
|
| 113 |
mask = cv2.resize(mask, (224, 224))
|
|
|
|
| 114 |
mask = (mask - np.min(mask)) / (np.max(mask) - np.min(mask))
|
| 115 |
heatmap = plt.cm.jet(mask)[:, :, :3]
|
| 116 |
|
|
|
|
| 123 |
superimposed_img_pil = Image.fromarray(
|
| 124 |
(superimposed_img * 255).astype(np.uint8)
|
| 125 |
)
|
|
|
|
| 126 |
return superimposed_img_pil
|
| 127 |
|
| 128 |
|