File size: 5,281 Bytes
5ab5efe | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 | # app.py
import streamlit as st
import torch
import torch.nn as nn
from PIL import Image
import numpy as np
from torchvision import transforms
from models.model import DeepLearningModel
import joblib
from scripts.xai_eval import convert_to_gradcam
import cv2
# Storage bucket name (presumably Google Cloud, given the Vertex AI setting
# below); not referenced by the code shown in this file.
BUCKET_NAME = "aipi540-cv"
# Vertex AI endpoint identifier; intentionally left blank here.
VERTEX_AI_ENDPOINT = ""
# Display labels, indexed by the integer class id the model predicts.
class_names = [
    "Normal",
    "Mild Diabetic Retinopathy",
    "Severe Diabetic Retinopathy",
]
# TODO: original author noted "need to change the following code" -- confirm
# what is still outstanding in this handler.
class ModelHandler:
    """Bundle the inference device with the image transform the model expects."""

    def __init__(self):
        gpu_available = torch.cuda.is_available()
        self.device = torch.device("cuda" if gpu_available else "cpu")

        # Per-channel statistics used for normalization. These are not the
        # standard ImageNet values; presumably they were computed on the
        # training fundus images. They must match whatever the checkpoint was
        # trained with -- NOTE(review): confirm before changing.
        channel_mean = [0.3205, 0.2244, 0.1613]
        channel_std = [0.2996, 0.2158, 0.1711]

        pipeline = [
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=channel_mean, std=channel_std),
        ]
        self.transform = transforms.Compose(pipeline)

    def preprocess_image(self, image):
        """Apply ``self.transform``, add a leading batch axis, and move to ``self.device``."""
        batched = self.transform(image).unsqueeze(0)
        return batched.to(self.device)
@st.cache_resource
def load_model(model_type, checkpoint_path="models/vgg16_model.pth"):
    """Load the VGG16-based classifier together with its preprocessing function.

    Streamlit re-executes the whole script on every widget interaction, so
    without caching the checkpoint would be read from disk on every rerun.
    ``st.cache_resource`` keeps one loaded model per (model_type,
    checkpoint_path) for the server process; the returned model object is
    therefore shared across reruns and sessions.

    Args:
        model_type: Label of the requested model. Not used by the body --
            only ``DeepLearningModel`` is ever built; kept so existing callers
            keep working.
        checkpoint_path: Path to the saved ``state_dict``.

    Returns:
        ``(model, preprocess)``: the model moved to the handler's device and
        switched to eval mode, and ``ModelHandler.preprocess_image`` bound to
        the same handler (so inputs land on the same device).
    """
    handler = ModelHandler()
    device = handler.device

    model = DeepLearningModel()
    # A state_dict only contains tensors and plain containers, so restrict
    # unpickling to those instead of allowing arbitrary pickled objects.
    state_dict = torch.load(checkpoint_path, map_location=device, weights_only=True)
    model.load_state_dict(state_dict)
    model = model.to(device)
    # load_state_dict/.to are already called unconditionally above, so the
    # model is a torch Module and always has eval().
    model.eval()
    return model, handler.preprocess_image
# Prediction function
def predict(model, image_tensor):
    """Classify a preprocessed image with ``model``.

    Args:
        model: Module mapping ``image_tensor`` to per-class logits with the
            class axis at dim 1. It is switched to eval mode as a side effect.
        image_tensor: Input batch. It is moved to the device the model's
            weights live on. Only a batch of one is supported: ``.item()``
            below raises for larger batches.

    Returns:
        ``(predicted_class, class_probabilities)``: the argmax class id as a
        Python int, and the softmax probabilities of the first batch item as
        a NumPy array.
    """
    model.eval()
    # Follow the model's own device rather than re-deriving it from CUDA
    # availability, so a model deliberately kept on CPU still works on a GPU
    # host. Parameter-free modules fall back to the previous rule.
    first_param = next(model.parameters(), None)
    if first_param is not None:
        device = first_param.device
    else:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    with torch.no_grad():
        logits = model(image_tensor.to(device))
        probabilities = torch.softmax(logits, dim=1)
        predicted_class = torch.argmax(probabilities, dim=1).item()
        class_probabilities = probabilities[0].cpu().numpy()
    return predicted_class, class_probabilities
def generate_gradcam(model, image_tensor):
    """Generate a colorized Grad-CAM heatmap for ``image_tensor``.

    Args:
        model: Model passed to ``convert_to_gradcam`` to build the CAM object.
        image_tensor: Preprocessed input batch; presumably a single image of
            shape (1, C, H, W) -- TODO confirm (the squeeze below assumes the
            batch axis is the only size-1 axis to drop).

    Returns:
        A ``uint8`` RGB color image on success. On any failure, a string
        starting with ``"Grad-CAM Error: "`` instead of raising; callers
        distinguish the two with ``isinstance(result, str)``.
    """
    try:
        cam = convert_to_gradcam(model)
        raw = cam(input_tensor=image_tensor, targets=None)

        # Accept either a tensor or an ndarray and drop the batch dimension.
        if isinstance(raw, torch.Tensor):
            raw = raw.detach().cpu().numpy()
        cam_map = np.asarray(raw, dtype=np.float32).squeeze()

        # Min-max scale to [0, 1]. A constant map has zero range; dividing by
        # it would produce NaNs (and garbage after the uint8 cast), so render
        # it as all zeros instead.
        lo, hi = cam_map.min(), cam_map.max()
        if hi > lo:
            cam_map = (cam_map - lo) / (hi - lo)
        else:
            cam_map = np.zeros_like(cam_map)

        cam_u8 = (255 * cam_map).astype(np.uint8)
        colored = cv2.applyColorMap(cam_u8, cv2.COLORMAP_JET)
        # OpenCV color maps are BGR; st.image interprets arrays as RGB by
        # default, so convert to avoid swapped red/blue channels.
        return cv2.cvtColor(colored, cv2.COLOR_BGR2RGB)
    except Exception as e:
        return f"Grad-CAM Error: {str(e)}"
# the streamlit app
def _render_sidebar():
    """Render the sidebar: which model is in use, plus an 'About' panel."""
    # Only the deep learning model (VGG16) is offered.
    st.sidebar.header("Model Information")
    st.sidebar.write("Using Deep Learning Model (VGG16)")
    st.sidebar.header("About")
    st.sidebar.markdown("""
This system aims to detect diabetic retinopathy (DR) from fundus images.
### Model:
- **Deep Learning Model**: VGG16-based architecture
### Classes:
- Normal (No DR)
- Mild DR
- Severe DR
""")


def _render_analysis(model, preprocess, image):
    """Classify ``image``, then render the prediction and a Grad-CAM heatmap."""
    processed_image = preprocess(image)
    with st.spinner("Analyzing image..."):
        predicted_class, class_probs = predict(model, processed_image)
    st.success("Analysis Complete!")

    # Display prediction results.
    st.header("Prediction Results")
    st.write(f"**Predicted Condition:** {class_names[predicted_class]}")
    st.write("**Class Probabilities:**")
    st.json({class_names[i]: float(p) for i, p in enumerate(class_probs)})

    # generate_gradcam reports failures as a string instead of raising.
    with st.spinner("Generating XAI..."):
        heatmap = generate_gradcam(model, processed_image)
    st.header("Grad-CAM Explanation")
    if isinstance(heatmap, str):
        st.error(heatmap)
    else:
        st.image(heatmap, caption="Grad-CAM Heatmap", use_container_width=True)


def main():
    """Streamlit entry point: upload a fundus image, classify it, explain it."""
    st.set_page_config(
        page_title="Diabetic Retinopathy Prediction",
        page_icon="👁️",
        layout="wide",
    )
    st.title("👁️ Diabetic Retinopathy Detection System")
    st.write("Upload a fundus image to detect diabetic retinopathy severity")

    _render_sidebar()
    # Load the deep learning model (VGG16).
    model, preprocess = load_model("Deep Learning Model")

    st.header("Image Upload")
    uploaded_file = st.file_uploader(
        "Choose a fundus image",
        type=["jpg", "jpeg", "png"],
    )
    if uploaded_file is None:
        return

    # A file with an image extension may still be unreadable or truncated;
    # PIL raises OSError (UnidentifiedImageError is a subclass). Report it
    # instead of crashing the page.
    try:
        image = Image.open(uploaded_file).convert('RGB')
    except OSError as e:
        st.error(f"Could not read the uploaded file as an image: {e}")
        return
    st.image(image, caption="Uploaded Image", use_container_width=True)

    if st.button("Analyze Image"):
        try:
            _render_analysis(model, preprocess, image)
        except Exception as e:
            # Top-level UI boundary: surface any failure to the user.
            st.error(f"Error during analysis: {str(e)}")
# Script entry point (`streamlit run app.py` executes this module as __main__).
if __name__ == "__main__":
    main()