rnmee commited on
Commit
1ae3b71
·
verified ·
1 Parent(s): 3d986df

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +22 -4
app.py CHANGED
@@ -10,13 +10,21 @@ from typing import Tuple
10
  # Device configuration
11
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
12
 
13
- # Constants
14
  CLASS_NAMES = ["Mild", "Moderate", "No_DR", "Proliferate_DR", "Severe"]
15
  LESION_COLORS = {
16
  0: [0, 0, 0], # Background (black)
17
  1: [255, 255, 0], # Bright lesions (yellow)
18
  2: [255, 0, 0] # Red lesions (red)
19
  }
 
 
 
 
 
 
 
 
20
 
21
  # ====================== UNET ARCHITECTURE ======================
22
  class UNet(nn.Module):
@@ -199,11 +207,21 @@ def main():
199
  probabilities = ps[0].cpu().numpy() * 100
200
 
201
  st.subheader("Classification Results")
202
- if CLASS_NAMES[pred_class] == "No_DR": # Changed to compare class name
203
- st.success(f"**Prediction:** {CLASS_NAMES[pred_class]}")
 
 
 
 
 
 
204
  st.write("No diabetic retinopathy detected - no segmentation needed.")
205
  else:
206
- st.error(f"**Prediction:** {CLASS_NAMES[pred_class]}")
 
 
 
 
207
  st.write("**Confidence Levels:**")
208
  for name, prob in zip(CLASS_NAMES, probabilities):
209
  st.progress(int(prob))
 
10
  # Device configuration
11
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
12
 
13
+ # ====================== CONSTANTS ======================
14
  CLASS_NAMES = ["Mild", "Moderate", "No_DR", "Proliferate_DR", "Severe"]
15
  LESION_COLORS = {
16
  0: [0, 0, 0], # Background (black)
17
  1: [255, 255, 0], # Bright lesions (yellow)
18
  2: [255, 0, 0] # Red lesions (red)
19
  }
20
+ UK_GRADES = {
21
+ "No_DR": "R0 – No retinopathy",
22
+ "Mild": "R1 – Background DR",
23
+ "Moderate": "R1 – Background DR",
24
+ "Severe": "R2 – Pre-proliferative DR",
25
+ "Proliferate_DR": "R3 – Proliferative DR"
26
+ }
27
+
28
 
29
  # ====================== UNET ARCHITECTURE ======================
30
  class UNet(nn.Module):
 
207
  probabilities = ps[0].cpu().numpy() * 100
208
 
209
  st.subheader("Classification Results")
210
+ predicted_class_name = CLASS_NAMES[pred_class]
211
+ uk_grade = UK_GRADES[predicted_class_name]
212
+
213
+ if predicted_class_name == "No_DR":
214
+ st.success(f"""
215
+ **Prediction:** {predicted_class_name}
216
+ **UK Grade:** {uk_grade}
217
+ """)
218
  st.write("No diabetic retinopathy detected - no segmentation needed.")
219
  else:
220
+ st.error(f"""
221
+ **Prediction:** {predicted_class_name}
222
+ **UK Grade:** {uk_grade}
223
+ """)
224
+
225
  st.write("**Confidence Levels:**")
226
  for name, prob in zip(CLASS_NAMES, probabilities):
227
  st.progress(int(prob))