rnmee commited on
Commit
4938b3b
·
verified ·
1 Parent(s): 83adec3

Delete app.py

Browse files
Files changed (1) hide show
  1. app.py +0 -89
app.py DELETED
@@ -1,89 +0,0 @@
1
- import torch
2
- import streamlit as st
3
- import numpy as np
4
- import cv2
5
- from PIL import Image
6
- import torchvision.transforms as transforms
7
-
8
- # Load models
9
- @st.cache_resource
10
- def load_models():
11
- # Load classification model
12
- classification_model = torch.load("resnet152_final.pt", map_location=torch.device("cpu"))
13
- classification_model.eval()
14
-
15
- # Load segmentation model
16
- segmentation_model = torch.load("best_unet_model.pth", map_location=torch.device("cpu"))
17
- segmentation_model.eval()
18
-
19
- return classification_model, segmentation_model
20
-
21
- classification_model, segmentation_model = load_models()
22
-
23
- # Define preprocessing function for classification
24
- def preprocess_image(image):
25
- # Convert to grayscale & extract green channel
26
- image = np.array(image)
27
- green_channel = image[:, :, 1]
28
-
29
- # Apply CLAHE
30
- clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
31
- img_clahe = clahe.apply(green_channel)
32
-
33
- # Convert to tensor
34
- transform = transforms.Compose([
35
- transforms.ToTensor(),
36
- transforms.Normalize(mean=[0.5], std=[0.5])
37
- ])
38
-
39
- return transform(img_clahe).unsqueeze(0)
40
-
41
- # Define function for segmentation preprocessing (resize + normalize)
42
- def preprocess_segmentation(image):
43
- transform = transforms.Compose([
44
- transforms.Resize((512, 512)),
45
- transforms.ToTensor(),
46
- transforms.Normalize(mean=[0.5], std=[0.5])
47
- ])
48
-
49
- return transform(image).unsqueeze(0)
50
-
51
- # Streamlit UI
52
- st.title("Diabetic Retinopathy Detection, Classification & Segmentation")
53
-
54
- uploaded_file = st.file_uploader("Upload a Retinal Image", type=["jpg", "png", "jpeg"])
55
- if uploaded_file:
56
- image = Image.open(uploaded_file)
57
- st.image(image, caption="Uploaded Image", use_column_width=True)
58
-
59
- # Preprocess image for classification
60
- input_tensor = preprocess_image(image)
61
-
62
- # Run classification
63
- with torch.no_grad():
64
- output = classification_model(input_tensor)
65
- predicted_class = torch.argmax(output).item()
66
-
67
- # Display result
68
- dr_stages = ["No DR", "Mild", "Moderate", "Severe", "Proliferative DR"]
69
- st.write(f"**Diabetic Retinopathy Stage:** {dr_stages[predicted_class]}")
70
-
71
- # If DR detected, proceed to segmentation
72
- if predicted_class > 0:
73
- st.write("Lesion segmentation in progress...")
74
-
75
- # Preprocess for segmentation
76
- segmentation_input = preprocess_segmentation(image)
77
-
78
- # Run segmentation
79
- with torch.no_grad():
80
- segmentation_output = segmentation_model(segmentation_input)
81
-
82
- # Convert output to mask
83
- segmentation_mask = segmentation_output.squeeze().cpu().numpy()
84
- segmentation_mask = (segmentation_mask > 0.5).astype(np.uint8) * 255
85
-
86
- # Display segmentation result
87
- st.image(segmentation_mask, caption="Segmented Lesions", use_column_width=True)
88
- else:
89
- st.write("No DR detected. Segmentation not required.")