File size: 5,281 Bytes
5ab5efe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
# app.py
import streamlit as st
import torch
import torch.nn as nn
from PIL import Image
import numpy as np
from torchvision import transforms
from models.model import DeepLearningModel
import joblib
from scripts.xai_eval import convert_to_gradcam
import cv2


# Cloud storage bucket name (presumably GCS, given the Vertex AI setting
# below). Not referenced elsewhere in this module.
BUCKET_NAME = "aipi540-cv"
# Vertex AI endpoint identifier; intentionally left empty here.
VERTEX_AI_ENDPOINT = ""

# Human-readable labels, indexed by the class id the model predicts.
class_names = [
    "Normal",
    "Mild Diabetic Retinopathy",
    "Severe Diabetic Retinopathy",
]


# need to change the following code
class ModelHandler:
    """Own the inference device and the image preprocessing pipeline.

    The pipeline resizes an image to 224x224, converts it to a float tensor
    and normalizes every channel with fixed per-channel statistics.
    """

    def __init__(self):
        # Use the GPU whenever PyTorch can see one, otherwise the CPU.
        if torch.cuda.is_available():
            self.device = torch.device("cuda")
        else:
            self.device = torch.device("cpu")

        # Per-channel normalization statistics (channel order follows the
        # input image; the app feeds RGB images).
        # NOTE(review): these are not the ImageNet statistics that pretrained
        # VGG16 weights are usually paired with; presumably they were
        # computed on the fundus training set — confirm they match training.
        channel_mean = [0.3205, 0.2244, 0.1613]
        channel_std = [0.2996, 0.2158, 0.1711]

        self.transform = transforms.Compose(
            [
                transforms.Resize((224, 224)),
                transforms.ToTensor(),
                transforms.Normalize(mean=channel_mean, std=channel_std),
            ]
        )

    def preprocess_image(self, image):
        """Turn *image* into a batch of one, ready for the model.

        Applies ``self.transform``, prepends a batch dimension and moves the
        result to ``self.device``.
        """
        single = self.transform(image)
        batch = single.unsqueeze(0)
        return batch.to(self.device)

def load_model(model_type, weights_path="models/vgg16_model.pth"):
    """Build the VGG16-based classifier and load its trained weights.

    Args:
        model_type: display name of the requested model. Currently unused:
            the app only ships ``DeepLearningModel``; the parameter is kept
            so existing callers keep working.
        weights_path: path of the saved ``state_dict``. Defaults to the file
            the app has always loaded.

    Returns:
        ``(model, preprocess)``. ``model`` is in eval mode on the handler's
        device. ``preprocess`` is the bound ``ModelHandler.preprocess_image``,
        which moves inputs to that same device.
    """
    handler = ModelHandler()
    device = handler.device

    model = DeepLearningModel()
    # A state_dict only holds tensors in plain containers, so restrict
    # unpickling to those (weights_only=True, PyTorch >= 1.13) instead of
    # allowing arbitrary pickled objects to execute on load.
    state_dict = torch.load(weights_path, map_location=device, weights_only=True)
    model.load_state_dict(state_dict)

    model = model.to(device)
    # Inference mode (affects dropout / batch-norm layers).
    model.eval()

    return model, handler.preprocess_image

# Prediction function
def predict(model, image_tensor):
    """Predict the class of the input image using the given model.

    Args:
        model: a ``torch.nn.Module`` whose output is a ``(batch, num_classes)``
            tensor of logits. It is switched to eval mode as a side effect.
        image_tensor: the input batch. It is moved to the device holding the
            model's weights. A batch of exactly one is expected: ``.item()``
            on the argmax raises for larger batches.

    Returns:
        ``(predicted_class, class_probabilities)``: the argmax class index as
        an ``int`` and a NumPy array with the softmax probabilities of the
        first sample.
    """
    model.eval()

    # Send the input to wherever the model's weights actually live, so the
    # two always agree even if the model was not placed by load_model.
    # Parameter-less models fall back to the previous CUDA/CPU probe.
    first_param = next(model.parameters(), None)
    if first_param is not None:
        device = first_param.device
    else:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    with torch.no_grad():
        outputs = model(image_tensor.to(device))

    # Convert logits to probabilities and pick the most likely class.
    probabilities = torch.softmax(outputs, dim=1)
    predicted_class = torch.argmax(probabilities, dim=1).item()
    class_probabilities = probabilities[0].cpu().numpy()

    return predicted_class, class_probabilities

def generate_gradcam(model, image_tensor):
    """Generate a colour-mapped Grad-CAM heatmap for the input image.

    Args:
        model: the classifier, handed to ``convert_to_gradcam``.
        image_tensor: the preprocessed input batch (as produced by
            ``ModelHandler.preprocess_image``).

    Returns:
        On success, the ``uint8`` image returned by ``cv2.applyColorMap``
        (OpenCV channel order, i.e. BGR). On any failure, a string starting
        with ``"Grad-CAM Error:"``; callers tell the two apart with
        ``isinstance(result, str)``.
    """
    try:
        cam = convert_to_gradcam(model)
        # NOTE(review): targets=None presumably means "explain the
        # highest-scoring class" (pytorch-grad-cam convention) — confirm in
        # scripts.xai_eval.
        heatmap = cam(input_tensor=image_tensor, targets=None)

        # Work in NumPy and drop the batch dimension. detach() avoids the
        # error .numpy() raises on tensors that require grad.
        if isinstance(heatmap, torch.Tensor):
            heatmap = heatmap.detach().cpu().numpy()
        heatmap = np.squeeze(heatmap)

        # Rescale to [0, 1]. A constant map (for example an all-zero CAM)
        # would otherwise divide by zero and yield NaNs; render it as zeros.
        low, high = heatmap.min(), heatmap.max()
        if high > low:
            heatmap = (heatmap - low) / (high - low)
        else:
            heatmap = np.zeros_like(heatmap, dtype=np.float32)
        heatmap = (255 * heatmap).astype(np.uint8)

        return cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)
    except Exception as e:
        # Deliberate best-effort: the caller displays this message rather
        # than aborting the whole analysis.
        return f"Grad-CAM Error: {str(e)}"

# the streamlit app
def main():
    """Render the diabetic-retinopathy Streamlit app.

    Lays out the page and sidebar and loads the classifier. Once the user
    uploads an image and presses "Analyze Image", it shows the predicted
    class, the per-class probabilities and a Grad-CAM heatmap. Streamlit
    re-executes this function on every widget interaction.
    """
    st.set_page_config(
        page_title="Diabetic Retinopathy Prediction",
        page_icon="👁️",
        layout="wide"
    )

    st.title("👁️ Diabetic Retinopathy Detection System")
    st.write("Upload a fundus image to detect diabetic retinopathy severity")

    # Only the deep learning model (VGG16) is offered.
    st.sidebar.header("Model Information")
    st.sidebar.write("Using Deep Learning Model (VGG16)")

    # Load the deep learning model. st.cache_resource keeps the loaded model
    # across reruns, so the weights file is read once per server process
    # instead of on every upload / button click.
    # NOTE(review): the cached model is shared by all sessions, and Grad-CAM
    # (via convert_to_gradcam) presumably registers hooks on it, so
    # concurrent analyses may interfere — confirm this is acceptable.
    model, preprocess = st.cache_resource(load_model)("Deep Learning Model")

    st.sidebar.header("About")
    st.sidebar.markdown("""
    This system aims to detect diabetic retinopathy (DR) from fundus images.
    ### Model:
    - **Deep Learning Model**: VGG16-based architecture
    
    ### Classes:
    - Normal (No DR)
    - Mild DR
    - Severe DR
    """)

    st.header("Image Upload")
    uploaded_file = st.file_uploader(
        "Choose a fundus image",
        type=["jpg", "jpeg", "png"]
    )
    if uploaded_file is None:
        return

    # A corrupt or mislabelled upload would otherwise surface as a raw
    # traceback, since this happens outside the analysis try-block.
    try:
        image = Image.open(uploaded_file).convert('RGB')
    except (OSError, Image.DecompressionBombError) as e:
        st.error(f"Could not read the uploaded file as an image: {e}")
        return
    st.image(image, caption="Uploaded Image", use_container_width=True)

    if not st.button("Analyze Image"):
        return

    try:
        processed_image = preprocess(image)

        with st.spinner("Analyzing image..."):
            predicted_class, class_probs = predict(model, processed_image)

        st.success("Analysis Complete!")

        # Display prediction results
        st.header("Prediction Results")
        st.write(f"**Predicted Condition:** {class_names[predicted_class]}")
        st.write("**Class Probabilities:**")
        st.json({class_names[i]: float(p) for i, p in enumerate(class_probs)})

        # Explainability: Grad-CAM heatmap for the same preprocessed input.
        with st.spinner("Generating XAI..."):
            heatmap = generate_gradcam(model, processed_image)

        st.header("Grad-CAM Explanation")
        if isinstance(heatmap, str):
            # generate_gradcam reports failures as an error string.
            st.error(heatmap)
        else:
            # cv2.applyColorMap returns BGR; without channels="BGR" Streamlit
            # would read it as RGB and swap red and blue.
            st.image(
                heatmap,
                caption="Grad-CAM Heatmap",
                channels="BGR",
                use_container_width=True,
            )

    except Exception as e:
        # Top-level UI boundary: show the failure instead of a traceback.
        st.error(f"Error during analysis: {str(e)}")

            

                

if __name__ == "__main__":
    # Script entry point; `streamlit run app.py` also executes the file as __main__.
    main()