rnmee commited on
Commit
68295e1
·
verified ·
1 Parent(s): 969f85e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +40 -41
app.py CHANGED
@@ -1,46 +1,55 @@
1
  import torch
2
  import streamlit as st
3
- import torchvision.models as models
4
  import numpy as np
5
  import cv2
6
  from PIL import Image
7
  import torchvision.transforms as transforms
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
 
 
9
  @st.cache_resource
10
  def load_models():
11
- # Load classification model architecture
12
- classification_model = models.resnet152(pretrained=False) # Use the same architecture as before
13
- num_ftrs = classification_model.fc.in_features
14
- classification_model.fc = torch.nn.Linear(num_ftrs, 5) # Assuming 5 DR stages
15
- classification_checkpoint = torch.load("classifier.pt", map_location=torch.device("cpu"))
16
- classification_model.load_state_dict(classification_checkpoint["model_state_dict"])
17
  classification_model.eval()
18
-
19
  # Load segmentation model
20
  segmentation_model = torch.load("best_unet_model.pth", map_location=torch.device("cpu"))
21
  segmentation_model.eval()
22
-
23
  return classification_model, segmentation_model
24
 
25
-
26
  classification_model, segmentation_model = load_models()
27
 
28
  # Define preprocessing function for classification
29
  def preprocess_image(image):
30
- # Convert to numpy array & extract green channel
31
  image = np.array(image)
32
- green_channel = image[:, :, 1]
33
-
34
- # Apply CLAHE
35
  clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
36
  img_clahe = clahe.apply(green_channel)
37
-
38
- # Convert to tensor
39
  transform = transforms.Compose([
40
  transforms.ToTensor(),
41
  transforms.Normalize(mean=[0.5], std=[0.5])
42
  ])
43
-
44
  return transform(img_clahe).unsqueeze(0)
45
 
46
  # Define function for segmentation preprocessing
@@ -50,7 +59,6 @@ def preprocess_segmentation(image):
50
  transforms.ToTensor(),
51
  transforms.Normalize(mean=[0.5], std=[0.5])
52
  ])
53
-
54
  return transform(image).unsqueeze(0)
55
 
56
  # Streamlit UI
@@ -60,46 +68,37 @@ uploaded_file = st.file_uploader("Upload a Retinal Image", type=["jpg", "png", "
60
  if uploaded_file:
61
  image = Image.open(uploaded_file)
62
  st.image(image, caption="Uploaded Image", use_column_width=True)
63
-
64
  # Preprocess image for classification
65
  input_tensor = preprocess_image(image)
66
-
67
  # Run classification
68
  with torch.no_grad():
69
  output = classification_model(input_tensor)
70
  predicted_class = torch.argmax(output).item()
71
-
72
- # Display result
73
  dr_stages = ["No DR", "Mild", "Moderate", "Severe", "Proliferative DR"]
74
  st.write(f"**Diabetic Retinopathy Stage:** {dr_stages[predicted_class]}")
75
-
76
  # If DR detected, proceed to segmentation
77
  if predicted_class > 0:
78
  st.write("Lesion segmentation in progress...")
79
-
80
- # Preprocess for segmentation
81
  segmentation_input = preprocess_segmentation(image)
82
-
83
- # Run segmentation
84
  with torch.no_grad():
85
  segmentation_output = segmentation_model(segmentation_input)
86
 
87
- # Convert output to mask
88
  segmentation_mask = segmentation_output.squeeze().cpu().numpy()
89
  segmentation_mask = (segmentation_mask > 0.5).astype(np.uint8) * 255
90
-
91
- # Convert mask to image
 
 
92
  mask_image = Image.fromarray(segmentation_mask)
93
-
94
- # Display segmentation result
95
- st.image(mask_image, caption="Segmented Lesions", use_column_width=True)
96
-
97
- # Provide download button
98
- st.download_button(
99
- label="Download Segmented Mask",
100
- data=mask_image.tobytes(),
101
- file_name="segmented_mask.png",
102
- mime="image/png"
103
- )
104
  else:
105
  st.write("No DR detected. Segmentation not required.")
 
1
  import torch
2
  import streamlit as st
 
3
  import numpy as np
4
  import cv2
5
  from PIL import Image
6
  import torchvision.transforms as transforms
7
+ import torchvision.models as models
8
+ import io
9
+
10
+ # Define the classification model structure (ResNet152 with modified FC layer)
11
+ class ResNetClassifier(torch.nn.Module):
12
+ def __init__(self, num_classes=5):
13
+ super(ResNetClassifier, self).__init__()
14
+ self.model = models.resnet152(pretrained=False)
15
+ self.model.fc = torch.nn.Sequential(
16
+ torch.nn.Linear(self.model.fc.in_features, 512),
17
+ torch.nn.ReLU(),
18
+ torch.nn.Linear(512, num_classes)
19
+ )
20
+
21
+ def forward(self, x):
22
+ return self.model(x)
23
 
24
+ # Load models
25
  @st.cache_resource
26
  def load_models():
27
+ # Load classification model
28
+ classification_model = ResNetClassifier()
29
+ classifier_checkpoint = torch.load("classifier.pt", map_location=torch.device("cpu"))
30
+ classification_model.load_state_dict(classifier_checkpoint["model_state_dict"])
 
 
31
  classification_model.eval()
32
+
33
  # Load segmentation model
34
  segmentation_model = torch.load("best_unet_model.pth", map_location=torch.device("cpu"))
35
  segmentation_model.eval()
36
+
37
  return classification_model, segmentation_model
38
 
 
39
  classification_model, segmentation_model = load_models()
40
 
41
  # Define preprocessing function for classification
42
  def preprocess_image(image):
 
43
  image = np.array(image)
44
+ green_channel = image[:, :, 1] # Extract green channel
45
+
 
46
  clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
47
  img_clahe = clahe.apply(green_channel)
48
+
 
49
  transform = transforms.Compose([
50
  transforms.ToTensor(),
51
  transforms.Normalize(mean=[0.5], std=[0.5])
52
  ])
 
53
  return transform(img_clahe).unsqueeze(0)
54
 
55
  # Define function for segmentation preprocessing
 
59
  transforms.ToTensor(),
60
  transforms.Normalize(mean=[0.5], std=[0.5])
61
  ])
 
62
  return transform(image).unsqueeze(0)
63
 
64
  # Streamlit UI
 
68
  if uploaded_file:
69
  image = Image.open(uploaded_file)
70
  st.image(image, caption="Uploaded Image", use_column_width=True)
71
+
72
  # Preprocess image for classification
73
  input_tensor = preprocess_image(image)
74
+
75
  # Run classification
76
  with torch.no_grad():
77
  output = classification_model(input_tensor)
78
  predicted_class = torch.argmax(output).item()
79
+
 
80
  dr_stages = ["No DR", "Mild", "Moderate", "Severe", "Proliferative DR"]
81
  st.write(f"**Diabetic Retinopathy Stage:** {dr_stages[predicted_class]}")
82
+
83
  # If DR detected, proceed to segmentation
84
  if predicted_class > 0:
85
  st.write("Lesion segmentation in progress...")
86
+
 
87
  segmentation_input = preprocess_segmentation(image)
88
+
 
89
  with torch.no_grad():
90
  segmentation_output = segmentation_model(segmentation_input)
91
 
 
92
  segmentation_mask = segmentation_output.squeeze().cpu().numpy()
93
  segmentation_mask = (segmentation_mask > 0.5).astype(np.uint8) * 255
94
+
95
+ st.image(segmentation_mask, caption="Segmented Lesions", use_column_width=True)
96
+
97
+ # Provide download option
98
  mask_image = Image.fromarray(segmentation_mask)
99
+ buf = io.BytesIO()
100
+ mask_image.save(buf, format="PNG")
101
+ byte_im = buf.getvalue()
102
+ st.download_button("Download Segmentation Mask", data=byte_im, file_name="segmentation_mask.png", mime="image/png")
 
 
 
 
 
 
 
103
  else:
104
  st.write("No DR detected. Segmentation not required.")