Create app.py
Browse files
app.py
ADDED
|
@@ -0,0 +1,120 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import streamlit as st
|
| 2 |
+
import torch
|
| 3 |
+
import matplotlib.pyplot as plt
|
| 4 |
+
from torchvision import transforms
|
| 5 |
+
from PIL import Image
|
| 6 |
+
import torch.nn as nn
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
from PathDino import get_pathDino_model
|
| 10 |
+
|
| 11 |
+
import os
|
| 12 |
+
os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'
|
| 13 |
+
|
| 14 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 15 |
+
|
| 16 |
+
# Load PathDino model and image transforms
|
| 17 |
+
model, image_transforms = get_pathDino_model("PathDino512.pth")
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
st.sidebar.markdown("### PathDino")
|
| 21 |
+
st.sidebar.markdown(
|
| 22 |
+
"PathDino is a lightweight histology transformer consisting of just five small vision transformer blocks. "
|
| 23 |
+
"PathDino is a customized ViT architecture, finely tuned to the nuances of histological images. It not only exhibits "
|
| 24 |
+
"superior performance but also effectively reduces susceptibility to overfitting, a common challenge in histology "
|
| 25 |
+
"image analysis.\n\n"
|
| 26 |
+
)
|
| 27 |
+
|
| 28 |
+
default_image_url_compare = "images/HistRotate.png"
|
| 29 |
+
st.sidebar.image(default_image_url_compare, caption='A 360 rotation augmentation for training models on histopathology images. Unlike training on natural images where the rotation may change the context of the visual data, rotating a histopathology patch does not change the context and it improves the learning process for better reliable embedding learning.', width=500)
|
| 30 |
+
|
| 31 |
+
default_image_url_compare = "images/FigPathDino_parameters_FLOPs_compare.png"
|
| 32 |
+
st.sidebar.image(default_image_url_compare, caption='PathDino Vs its counterparts. Number of Parameters (Millions) vs the patch-level retrieval with macro avg F-score of majority vote (MV@5) on CAMELYON16 dataset. The bubble size represents the FLOPs.', width=500)
|
| 33 |
+
|
| 34 |
+
default_image_url_compare = "images/ActivationMap.png"
|
| 35 |
+
st.sidebar.image(default_image_url_compare, caption='Attention Visualization. When visualizing attention patterns, our PathDino transformer outperforms HIPT-small and DinoSSLPath, despite being trained on a smaller dataset of 6 million TCGA patches. In contrast, DinoSSLPath and HIPT were trained on much larger datasets, with 19 million and 104 million TCGA patches, respectively.', width=500)
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def visualize_attention_ViT(model, img, patch_size=16):
|
| 40 |
+
attention_list = []
|
| 41 |
+
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
| 42 |
+
w_featmap = img.shape[-2] // patch_size
|
| 43 |
+
h_featmap = img.shape[-1] // patch_size
|
| 44 |
+
attentions = model.get_last_selfattention(img.to(device))
|
| 45 |
+
nh = attentions.shape[1] # number of head
|
| 46 |
+
# we keep only the output patch attention
|
| 47 |
+
attentions = attentions[0, :, 0, 1:].reshape(nh, -1)
|
| 48 |
+
attentions = attentions.reshape(nh, w_featmap, h_featmap)
|
| 49 |
+
attentions = nn.functional.interpolate(attentions.unsqueeze(0), scale_factor=patch_size, mode="nearest")[0].detach().numpy()
|
| 50 |
+
for j in range(nh):
|
| 51 |
+
attention_list.append(attentions[j])
|
| 52 |
+
return attention_list
|
| 53 |
+
|
| 54 |
+
# Define the function to generate activation maps
|
| 55 |
+
def generate_activation_maps(image):
|
| 56 |
+
preprocess = transforms.Compose([
|
| 57 |
+
transforms.Resize((512, 512)),
|
| 58 |
+
transforms.CenterCrop(512),
|
| 59 |
+
transforms.ToTensor(),
|
| 60 |
+
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) # Normalize the tensors
|
| 61 |
+
])
|
| 62 |
+
image_tensor = preprocess(image)
|
| 63 |
+
img = image_tensor.unsqueeze(0).to(device)
|
| 64 |
+
# Generate activation maps
|
| 65 |
+
with torch.no_grad():
|
| 66 |
+
attention_list = visualize_attention_ViT(model=model, img=img, patch_size=16)
|
| 67 |
+
return attention_list
|
| 68 |
+
|
| 69 |
+
# Streamlit UI
|
| 70 |
+
st.title("PathDino - Compact ViT for Histolopathology Image Analysis")
|
| 71 |
+
st.write("Upload a histology image to view the activation maps.")
|
| 72 |
+
|
| 73 |
+
# uploaded_image = st.file_uploader("Upload an image", type=["jpg", "png", "jpeg"])
|
| 74 |
+
uploaded_image = "images/HistRotate.png"
|
| 75 |
+
uploaded_image = st.file_uploader("Upload an image", type=["jpg", "png", "jpeg"])
|
| 76 |
+
|
| 77 |
+
if uploaded_image is not None:
|
| 78 |
+
columns = st.columns(3)
|
| 79 |
+
columns[1].image(uploaded_image, caption="Uploaded Image", width=300)
|
| 80 |
+
|
| 81 |
+
# Load the image and apply preprocessing
|
| 82 |
+
uploaded_image = Image.open(uploaded_image).convert('RGB')
|
| 83 |
+
attention_list = generate_activation_maps(uploaded_image)
|
| 84 |
+
print(len(attention_list))
|
| 85 |
+
st.subheader(f"Attention Maps of the input image")
|
| 86 |
+
columns = st.columns(len(attention_list)//2)
|
| 87 |
+
columns2 = st.columns(len(attention_list)//2)
|
| 88 |
+
for index, col in enumerate(columns):
|
| 89 |
+
# Create a plot
|
| 90 |
+
plt.plot(512, 512)
|
| 91 |
+
|
| 92 |
+
# Remove x and y axis labels
|
| 93 |
+
plt.xticks([]) # Hide x-axis ticks and labels
|
| 94 |
+
plt.yticks([]) # Hide y-axis ticks and labels
|
| 95 |
+
|
| 96 |
+
# Alternatively, if you only want to hide the labels and keep the ticks:
|
| 97 |
+
plt.gca().axes.get_xaxis().set_visible(False)
|
| 98 |
+
plt.gca().axes.get_yaxis().set_visible(False)
|
| 99 |
+
|
| 100 |
+
plt.imshow(attention_list[index])
|
| 101 |
+
col.pyplot(plt)
|
| 102 |
+
plt.close()
|
| 103 |
+
|
| 104 |
+
for index, col in enumerate(columns2):
|
| 105 |
+
|
| 106 |
+
index = index + len(attention_list)//2
|
| 107 |
+
# Create a plot
|
| 108 |
+
plt.plot(512, 512)
|
| 109 |
+
|
| 110 |
+
# Remove x and y axis labels
|
| 111 |
+
plt.xticks([]) # Hide x-axis ticks and labels
|
| 112 |
+
plt.yticks([]) # Hide y-axis ticks and labels
|
| 113 |
+
|
| 114 |
+
# Alternatively, if you only want to hide the labels and keep the ticks:
|
| 115 |
+
plt.gca().axes.get_xaxis().set_visible(False)
|
| 116 |
+
plt.gca().axes.get_yaxis().set_visible(False)
|
| 117 |
+
|
| 118 |
+
plt.imshow(attention_list[index])
|
| 119 |
+
col.pyplot(plt)
|
| 120 |
+
plt.close()
|