Update app.py
Browse files
app.py
CHANGED
|
@@ -10,13 +10,21 @@ from typing import Tuple
|
|
| 10 |
# Device configuration
|
| 11 |
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| 12 |
|
| 13 |
-
#
|
| 14 |
CLASS_NAMES = ["Mild", "Moderate", "No_DR", "Proliferate_DR", "Severe"]
|
| 15 |
LESION_COLORS = {
|
| 16 |
0: [0, 0, 0], # Background (black)
|
| 17 |
1: [255, 255, 0], # Bright lesions (yellow)
|
| 18 |
2: [255, 0, 0] # Red lesions (red)
|
| 19 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 20 |
|
| 21 |
# ====================== UNET ARCHITECTURE ======================
|
| 22 |
class UNet(nn.Module):
|
|
@@ -199,11 +207,21 @@ def main():
|
|
| 199 |
probabilities = ps[0].cpu().numpy() * 100
|
| 200 |
|
| 201 |
st.subheader("Classification Results")
|
| 202 |
-
|
| 203 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 204 |
st.write("No diabetic retinopathy detected - no segmentation needed.")
|
| 205 |
else:
|
| 206 |
-
st.error(f"
|
|
|
|
|
|
|
|
|
|
|
|
|
| 207 |
st.write("**Confidence Levels:**")
|
| 208 |
for name, prob in zip(CLASS_NAMES, probabilities):
|
| 209 |
st.progress(int(prob))
|
|
|
|
| 10 |
# Device configuration
|
| 11 |
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| 12 |
|
| 13 |
+
# ====================== CONSTANTS ======================
|
| 14 |
CLASS_NAMES = ["Mild", "Moderate", "No_DR", "Proliferate_DR", "Severe"]
|
| 15 |
LESION_COLORS = {
|
| 16 |
0: [0, 0, 0], # Background (black)
|
| 17 |
1: [255, 255, 0], # Bright lesions (yellow)
|
| 18 |
2: [255, 0, 0] # Red lesions (red)
|
| 19 |
}
|
| 20 |
+
UK_GRADES = {
|
| 21 |
+
"No_DR": "R0 – No retinopathy",
|
| 22 |
+
"Mild": "R1 – Background DR",
|
| 23 |
+
"Moderate": "R1 – Background DR",
|
| 24 |
+
"Severe": "R2 – Pre-proliferative DR",
|
| 25 |
+
"Proliferate_DR": "R3 – Proliferative DR"
|
| 26 |
+
}
|
| 27 |
+
|
| 28 |
|
| 29 |
# ====================== UNET ARCHITECTURE ======================
|
| 30 |
class UNet(nn.Module):
|
|
|
|
| 207 |
probabilities = ps[0].cpu().numpy() * 100
|
| 208 |
|
| 209 |
st.subheader("Classification Results")
|
| 210 |
+
predicted_class_name = CLASS_NAMES[pred_class]
|
| 211 |
+
uk_grade = UK_GRADES[predicted_class_name]
|
| 212 |
+
|
| 213 |
+
if predicted_class_name == "No_DR":
|
| 214 |
+
st.success(f"""
|
| 215 |
+
**Prediction:** {predicted_class_name}
|
| 216 |
+
**UK Grade:** {uk_grade}
|
| 217 |
+
""")
|
| 218 |
st.write("No diabetic retinopathy detected - no segmentation needed.")
|
| 219 |
else:
|
| 220 |
+
st.error(f"""
|
| 221 |
+
**Prediction:** {predicted_class_name}
|
| 222 |
+
**UK Grade:** {uk_grade}
|
| 223 |
+
""")
|
| 224 |
+
|
| 225 |
st.write("**Confidence Levels:**")
|
| 226 |
for name, prob in zip(CLASS_NAMES, probabilities):
|
| 227 |
st.progress(int(prob))
|