Spaces: Runtime error

Commit 5174b1f · Parent(s): b6c245f
Create app.py

app.py ADDED
@@ -0,0 +1,114 @@
import json
from pathlib import Path

import cv2
import gradio as gr
import numpy as np
import torch
from torchvision import models, transforms
from torchvision.models.feature_extraction import create_feature_extractor
from transformers import ViTForImageClassification

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

labels = json.loads(Path("labels.json").read_text())

# Load ResNet-50
resnet50 = models.resnet50(pretrained=True).to(device)
resnet50.eval()

# Create ResNet feature extractor
feature_extractor = create_feature_extractor(resnet50, return_nodes=["layer4", "fc"])
fc_layer_weights = resnet50.fc.weight

# Load ViT
vit = ViTForImageClassification.from_pretrained("raedinkhaled/vit-base-mri").to(device)
vit.eval()


normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

preprocess = transforms.Compose(
    [transforms.Resize((224, 224)), transforms.ToTensor(), normalize]
)

examples = sorted([f.as_posix() for f in Path("examples").glob("*")])


def get_cam(img_tensor):
    # Class activation map: project the layer4 feature maps through the fc
    # weights of the predicted class.
    output = feature_extractor(img_tensor)
    cnn_features = output["layer4"].squeeze()
    class_id = output["fc"].argmax()

    cam = fc_layer_weights[class_id].matmul(cnn_features.flatten(1))
    cam = cam.reshape(cnn_features.shape[1], cnn_features.shape[2])

    return cam.cpu().numpy(), labels[class_id]


def get_attention_mask(img_tensor):
    # Attention rollout: propagate attention across layers to estimate how much
    # each image patch contributes to the [CLS] token.
    result = vit(img_tensor, output_attentions=True)
    class_id = result[0].argmax()
    attention_probs = torch.stack(result[1]).squeeze(1)

    # Average the attention at each layer over all heads
    attention_probs = torch.mean(attention_probs, dim=1)

    # Account for residual connections
    residual = torch.eye(attention_probs.size(-1)).to(device)
    attention_probs = 0.5 * attention_probs + 0.5 * residual

    # Normalize by layer
    attention_probs = attention_probs / attention_probs.sum(dim=-1).unsqueeze(-1)

    # Multiply the attention matrices layer by layer
    attention_rollout = attention_probs[0]
    for i in range(1, attention_probs.size(0)):
        attention_rollout = torch.matmul(attention_probs[i], attention_rollout)

    # Attention between the cls token and the image patches
    mask = attention_rollout[0, 1:]
    mask_size = np.sqrt(mask.size(0)).astype(int)
    mask = mask.reshape(mask_size, mask_size)

    return mask.cpu().numpy(), labels[class_id]


def plot_mask_on_image(image, mask):
    # Min-max normalize the mask to [0, 1], then scale to 8-bit
    mask = (mask - mask.min()) / (mask.max() - mask.min())
    mask = (255 * mask).astype(np.uint8)
    mask = cv2.resize(mask, image.size)

    # Overlay a JET heatmap of the mask on the original image
    heatmap = cv2.applyColorMap(mask, cv2.COLORMAP_JET)
    result = heatmap * 0.3 + np.array(image) * 0.5
    return result.astype(np.uint8)


def inference(img):
    img_tensor = preprocess(img).unsqueeze(0).to(device)

    with torch.no_grad():
        cam, resnet_label = get_cam(img_tensor)
        attention_mask, vit_label = get_attention_mask(img_tensor)

    cam_result = plot_mask_on_image(img, cam)
    rollout_result = plot_mask_on_image(img, attention_mask)

    return resnet_label, cam_result, vit_label, rollout_result


if __name__ == "__main__":
    interface = gr.Interface(
        fn=inference,
        inputs=gr.inputs.Image(type="pil", label="Input Image"),
        outputs=[
            gr.outputs.Label(num_top_classes=1, type="auto", label="ResNet Label"),
            gr.outputs.Image(type="auto", label="ResNet CAM"),
            gr.outputs.Label(num_top_classes=1, type="auto", label="ViT Label"),
            gr.outputs.Image(type="auto", label="raedinkhaled/vit-base-mri CAM"),
        ],
        examples=examples,
        title="Transformer Explainability On Our Pre-Trained Model",
        live=True,
    )
    interface.launch()
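
The code above reads a labels.json file from the working directory and indexes it with the integer class id returned by argmax(), so it is presumably a JSON array of class-name strings. A minimal sketch of producing such a file, with hypothetical placeholder names rather than the model's real classes:

# Hypothetical sketch only: the real labels.json ships with the Space and its
# class names are not shown in this commit. The assumed layout is a plain JSON
# array whose index is the class id produced by argmax().
import json
from pathlib import Path

placeholder_labels = ["class_0", "class_1", "class_2"]  # replace with the real names
Path("labels.json").write_text(json.dumps(placeholder_labels))

With labels.json and an examples/ folder of sample images in place, the demo can be started locally with python app.py.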