rnmee commited on
Commit
8d8a5c4
·
verified ·
1 Parent(s): ee11704

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +16 -5
app.py CHANGED
@@ -8,8 +8,8 @@ import torchvision.transforms as transforms
8
  # Load models
9
  @st.cache_resource
10
  def load_models():
11
- # Load classification model
12
- classification_model = torch.load("resnet152_final.pt", map_location=torch.device("cpu"))
13
  classification_model.eval()
14
 
15
  # Load segmentation model
@@ -22,7 +22,7 @@ classification_model, segmentation_model = load_models()
22
 
23
  # Define preprocessing function for classification
24
  def preprocess_image(image):
25
- # Convert to grayscale & extract green channel
26
  image = np.array(image)
27
  green_channel = image[:, :, 1]
28
 
@@ -38,7 +38,7 @@ def preprocess_image(image):
38
 
39
  return transform(img_clahe).unsqueeze(0)
40
 
41
- # Define function for segmentation preprocessing (resize + normalize)
42
  def preprocess_segmentation(image):
43
  transform = transforms.Compose([
44
  transforms.Resize((512, 512)),
@@ -83,7 +83,18 @@ if uploaded_file:
83
  segmentation_mask = segmentation_output.squeeze().cpu().numpy()
84
  segmentation_mask = (segmentation_mask > 0.5).astype(np.uint8) * 255
85
 
 
 
 
86
  # Display segmentation result
87
- st.image(segmentation_mask, caption="Segmented Lesions", use_column_width=True)
 
 
 
 
 
 
 
 
88
  else:
89
  st.write("No DR detected. Segmentation not required.")
 
8
  # Load models
9
  @st.cache_resource
10
  def load_models():
11
+ # Load classification model (corrected filename)
12
+ classification_model = torch.load("classifier.pt", map_location=torch.device("cpu"))
13
  classification_model.eval()
14
 
15
  # Load segmentation model
 
22
 
23
  # Define preprocessing function for classification
24
  def preprocess_image(image):
25
+ # Convert to numpy array & extract green channel
26
  image = np.array(image)
27
  green_channel = image[:, :, 1]
28
 
 
38
 
39
  return transform(img_clahe).unsqueeze(0)
40
 
41
+ # Define function for segmentation preprocessing
42
  def preprocess_segmentation(image):
43
  transform = transforms.Compose([
44
  transforms.Resize((512, 512)),
 
83
  segmentation_mask = segmentation_output.squeeze().cpu().numpy()
84
  segmentation_mask = (segmentation_mask > 0.5).astype(np.uint8) * 255
85
 
86
+ # Convert mask to image
87
+ mask_image = Image.fromarray(segmentation_mask)
88
+
89
  # Display segmentation result
90
+ st.image(mask_image, caption="Segmented Lesions", use_column_width=True)
91
+
92
+ # Provide download button
93
+ st.download_button(
94
+ label="Download Segmented Mask",
95
+ data=mask_image.tobytes(),
96
+ file_name="segmented_mask.png",
97
+ mime="image/png"
98
+ )
99
  else:
100
  st.write("No DR detected. Segmentation not required.")