rnmee commited on
Commit
d1e252d
·
verified ·
1 Parent(s): 8cd493f

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +89 -0
app.py ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import streamlit as st
3
+ import numpy as np
4
+ import cv2
5
+ from PIL import Image
6
+ import torchvision.transforms as transforms
7
+
8
+ # Load models
9
+ @st.cache_resource
10
+ def load_models():
11
+ # Load classification model
12
+ classification_model = torch.load("resnet152_final.pt", map_location=torch.device("cpu"))
13
+ classification_model.eval()
14
+
15
+ # Load segmentation model
16
+ segmentation_model = torch.load("best_unet_model.pth", map_location=torch.device("cpu"))
17
+ segmentation_model.eval()
18
+
19
+ return classification_model, segmentation_model
20
+
21
+ classification_model, segmentation_model = load_models()
22
+
23
+ # Define preprocessing function for classification
24
+ def preprocess_image(image):
25
+ # Convert to grayscale & extract green channel
26
+ image = np.array(image)
27
+ green_channel = image[:, :, 1]
28
+
29
+ # Apply CLAHE
30
+ clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
31
+ img_clahe = clahe.apply(green_channel)
32
+
33
+ # Convert to tensor
34
+ transform = transforms.Compose([
35
+ transforms.ToTensor(),
36
+ transforms.Normalize(mean=[0.5], std=[0.5])
37
+ ])
38
+
39
+ return transform(img_clahe).unsqueeze(0)
40
+
41
+ # Define function for segmentation preprocessing (resize + normalize)
42
+ def preprocess_segmentation(image):
43
+ transform = transforms.Compose([
44
+ transforms.Resize((512, 512)),
45
+ transforms.ToTensor(),
46
+ transforms.Normalize(mean=[0.5], std=[0.5])
47
+ ])
48
+
49
+ return transform(image).unsqueeze(0)
50
+
51
+ # Streamlit UI
52
+ st.title("Diabetic Retinopathy Detection, Classification & Segmentation")
53
+
54
+ uploaded_file = st.file_uploader("Upload a Retinal Image", type=["jpg", "png", "jpeg"])
55
+ if uploaded_file:
56
+ image = Image.open(uploaded_file)
57
+ st.image(image, caption="Uploaded Image", use_column_width=True)
58
+
59
+ # Preprocess image for classification
60
+ input_tensor = preprocess_image(image)
61
+
62
+ # Run classification
63
+ with torch.no_grad():
64
+ output = classification_model(input_tensor)
65
+ predicted_class = torch.argmax(output).item()
66
+
67
+ # Display result
68
+ dr_stages = ["No DR", "Mild", "Moderate", "Severe", "Proliferative DR"]
69
+ st.write(f"**Diabetic Retinopathy Stage:** {dr_stages[predicted_class]}")
70
+
71
+ # If DR detected, proceed to segmentation
72
+ if predicted_class > 0:
73
+ st.write("Lesion segmentation in progress...")
74
+
75
+ # Preprocess for segmentation
76
+ segmentation_input = preprocess_segmentation(image)
77
+
78
+ # Run segmentation
79
+ with torch.no_grad():
80
+ segmentation_output = segmentation_model(segmentation_input)
81
+
82
+ # Convert output to mask
83
+ segmentation_mask = segmentation_output.squeeze().cpu().numpy()
84
+ segmentation_mask = (segmentation_mask > 0.5).astype(np.uint8) * 255
85
+
86
+ # Display segmentation result
87
+ st.image(segmentation_mask, caption="Segmented Lesions", use_column_width=True)
88
+ else:
89
+ st.write("No DR detected. Segmentation not required.")