ludaladila commited on
Commit
5ab5efe
·
1 Parent(s): c3a7b51

Add application file

Browse files
Files changed (1) hide show
  1. app.py +167 -0
app.py ADDED
@@ -0,0 +1,167 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # app.py
2
+ import streamlit as st
3
+ import torch
4
+ import torch.nn as nn
5
+ from PIL import Image
6
+ import numpy as np
7
+ from torchvision import transforms
8
+ from models.model import DeepLearningModel
9
+ import joblib
10
+ from scripts.xai_eval import convert_to_gradcam
11
+ import cv2
12
+
13
+
14
+ # Bucket name
15
+ BUCKET_NAME = "aipi540-cv"
16
+ VERTEX_AI_ENDPOINT = ""
17
+
18
+ # class type
19
+ class_names = ["Normal", "Mild Diabetic Retinopathy", "Severe Diabetic Retinopathy"]
20
+
21
+
22
+ # need to change the following code
23
+ class ModelHandler:
24
+ def __init__(self):
25
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
26
+ # Image preprocessing
27
+ self.transform = transforms.Compose([
28
+ transforms.Resize((224, 224)),
29
+ transforms.ToTensor(),
30
+ transforms.Normalize(mean=[0.3205, 0.2244, 0.1613],
31
+ std=[0.2996,0.2158, 0.1711]) ## I think this is out of date slightly? at least with what VGG uses
32
+ ])
33
+
34
+
35
+ # Preprocess the image
36
+ def preprocess_image(self, image):
37
+ """Preprocess the image for model input."""
38
+ img_tensor = self.transform(image)
39
+ return img_tensor.unsqueeze(0).to(self.device)
40
+
41
+ def load_model(model_type):
42
+ handler = ModelHandler()
43
+ device = handler.device
44
+
45
+ # Load the model
46
+ model = DeepLearningModel()
47
+ model.load_state_dict(torch.load("models/vgg16_model.pth", map_location=device))
48
+
49
+ model = model.to(device)
50
+ if hasattr(model, 'eval'):
51
+ model.eval()
52
+
53
+ return model, handler.preprocess_image
54
+
55
+ # Prediction function
56
+ def predict(model, image_tensor):
57
+ '''Predict the class of the input image using the given model.'''
58
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
59
+ model.eval()
60
+
61
+ with torch.no_grad():
62
+ image_tensor = image_tensor.to(device)
63
+ outputs = model(image_tensor)
64
+
65
+ # Convert outputs to probabilities and predicted class
66
+ probabilities = torch.softmax(outputs, dim=1)
67
+ predicted_class = torch.argmax(probabilities, dim=1).item()
68
+ class_probabilities = probabilities[0].cpu().numpy()
69
+
70
+ return predicted_class, class_probabilities
71
+
72
+ def generate_gradcam(model, image_tensor):
73
+ """Generate Grad-CAM heatmap for the input image using the given model."""
74
+ try:
75
+ cam = convert_to_gradcam(model)
76
+ heatmap = cam(input_tensor=image_tensor, targets=None)
77
+
78
+ # Remove batch dimension and convert to numpy array
79
+ if isinstance(heatmap, torch.Tensor):
80
+ heatmap = heatmap.squeeze().cpu().numpy()
81
+ else:
82
+ heatmap = heatmap.squeeze()
83
+
84
+ # Normalize the heatmap to [0, 1] range
85
+ heatmap = (heatmap - heatmap.min()) / (heatmap.max() - heatmap.min())
86
+ heatmap = np.uint8(255 * heatmap)
87
+
88
+ heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)
89
+
90
+ return heatmap
91
+ except Exception as e:
92
+ return f"Grad-CAM Error: {str(e)}"
93
+
94
+ # the streamlit app
95
+ def main():
96
+ st.set_page_config(
97
+ page_title="Diabetic Retinopathy Prediction",
98
+ page_icon="👁️",
99
+ layout="wide"
100
+ )
101
+
102
+ st.title("👁️ Diabetic Retinopathy Detection System")
103
+ st.write("Upload a fundus image to detect diabetic retinopathy severity")
104
+
105
+ # 只加载 Deep Learning Model (VGG16)
106
+ st.sidebar.header("Model Information")
107
+ st.sidebar.write("Using Deep Learning Model (VGG16)")
108
+
109
+ # 加载 Deep Learning Model
110
+ model, preprocess = load_model("Deep Learning Model")
111
+
112
+ st.sidebar.header("About")
113
+ st.sidebar.markdown("""
114
+ This system aims to detect diabetic retinopathy (DR) from fundus images.
115
+ ### Model:
116
+ - **Deep Learning Model**: VGG16-based architecture
117
+
118
+ ### Classes:
119
+ - Normal (No DR)
120
+ - Mild DR
121
+ - Severe DR
122
+ """)
123
+
124
+ st.header("Image Upload")
125
+ uploaded_file = st.file_uploader(
126
+ "Choose a fundus image",
127
+ type=["jpg", "jpeg", "png"]
128
+ )
129
+
130
+ if uploaded_file is not None:
131
+ image = Image.open(uploaded_file).convert('RGB')
132
+ st.image(image, caption="Uploaded Image", use_container_width=True)
133
+
134
+ if st.button("Analyze Image"):
135
+ try:
136
+ processed_image = preprocess(image)
137
+
138
+ with st.spinner("Analyzing image..."):
139
+ predicted_class, class_probs = predict(model, processed_image)
140
+
141
+ st.success("Analysis Complete!")
142
+
143
+ # Display prediction results
144
+ st.header("Prediction Results")
145
+ st.write(f"**Predicted Condition:** {class_names[predicted_class]}")
146
+ st.write("**Class Probabilities:**")
147
+ st.json({class_names[i]: float(class_probs[i]) for i in range(len(class_probs))})
148
+
149
+ # *
150
+ with st.spinner("Generating XAI..."):
151
+ heatmap = generate_gradcam(model, processed_image)
152
+
153
+ st.header("Grad-CAM Explanation")
154
+ if isinstance(heatmap, str):
155
+ st.error(heatmap)
156
+ else:
157
+ st.image(heatmap, caption="Grad-CAM Heatmap", use_container_width=True)
158
+
159
+ except Exception as e:
160
+ st.error(f"Error during analysis: {str(e)}")
161
+
162
+
163
+
164
+
165
+
166
+ if __name__ == "__main__":
167
+ main()