DrSyedFaizan commited on
Commit
9a6b4fa
·
verified ·
1 Parent(s): 3c3063a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +303 -303
app.py CHANGED
@@ -1,304 +1,304 @@
1
- import streamlit as st
2
- import torch
3
- import torchvision
4
- import torchmetrics
5
- import pytorch_lightning as pl
6
- import numpy as np
7
- import cv2
8
- import time
9
- import pydicom
10
- import nibabel as nib
11
- import io
12
- from torchvision import transforms
13
- from PIL import Image
14
-
15
- # Load the trained model
16
- class PneumoniaModel(pl.LightningModule):
17
- def __init__(self):
18
- super(PneumoniaModel, self).__init__()
19
- self.model = torchvision.models.resnet18()
20
- self.model.conv1 = torch.nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
21
- self.model.fc = torch.nn.Linear(in_features=512, out_features=1, bias=True)
22
-
23
- self.loss_fn = torch.nn.BCEWithLogitsLoss(pos_weight=torch.tensor([5.0]))
24
- self.val_acc = torchmetrics.Accuracy(task="binary")
25
- self.val_auc = torchmetrics.AUROC(task="binary")
26
- self.val_outputs = []
27
-
28
- def forward(self, data):
29
- return self.model(data)
30
-
31
- def validation_step(self, batch, batch_idx):
32
- x_ray, label = batch
33
- label = label.float()
34
- pred = self(x_ray)[:, 0]
35
- loss = self.loss_fn(pred, label)
36
- self.val_outputs.append({"preds": pred, "targets": label})
37
- return loss
38
-
39
- def on_validation_epoch_end(self):
40
- all_preds = torch.cat([x["preds"] for x in self.val_outputs]).cpu().numpy()
41
- all_targets = torch.cat([x["targets"] for x in self.val_outputs]).cpu().numpy()
42
- self.val_outputs.clear()
43
-
44
- def configure_optimizers(self):
45
- return torch.optim.Adam(self.model.parameters(), lr=1e-4)
46
-
47
- # Load trained model weights
48
- model = PneumoniaModel()
49
- checkpoint = torch.load("weights_3.ckpt", map_location=torch.device('cpu'))
50
- state_dict = checkpoint["state_dict"]
51
- model.load_state_dict(state_dict)
52
- model.eval()
53
-
54
- # Preprocessing function
55
- def preprocess_image(image):
56
- transform = transforms.Compose([
57
- transforms.ToPILImage(),
58
- transforms.Resize((256, 256)),
59
- transforms.ToTensor(),
60
- transforms.Normalize(mean=[0.5], std=[0.5])
61
- ])
62
- return transform(image).unsqueeze(0)
63
-
64
- # Function to load and preprocess different file types
65
- def load_image(file_path, file_type):
66
- file_type = file_type.lower()
67
-
68
- try:
69
- if file_type in ["png", "jpg", "jpeg"]:
70
- # For file objects from streamlit
71
- if hasattr(file_path, 'read'):
72
- image = Image.open(file_path).convert("L") # Convert to grayscale
73
- else:
74
- image = Image.open(file_path).convert("L")
75
- image = np.array(image)
76
-
77
- elif file_type == "dcm":
78
- # For file objects from streamlit
79
- if hasattr(file_path, 'read'):
80
- # Create a temporary BytesIO object
81
- temp_file = io.BytesIO(file_path.read())
82
- file_path.seek(0) # Reset pointer for future reads
83
- dicom_data = pydicom.dcmread(temp_file)
84
- else:
85
- dicom_data = pydicom.dcmread(file_path)
86
-
87
- image = dicom_data.pixel_array
88
-
89
- elif file_type in ["nii", "nii.gz"]:
90
- # For file objects from streamlit
91
- if hasattr(file_path, 'read'):
92
- # We need to save temporarily for nibabel
93
- with open("temp_file." + file_type, "wb") as f:
94
- f.write(file_path.read())
95
- file_path.seek(0) # Reset pointer for future reads
96
- nifti_data = nib.load("temp_file." + file_type)
97
- # Clean up the temp file
98
- import os
99
- try:
100
- os.remove("temp_file." + file_type)
101
- except:
102
- pass # Ignore cleanup errors
103
- else:
104
- nifti_data = nib.load(file_path)
105
-
106
- image = nifti_data.get_fdata()
107
- image = np.squeeze(image) # Only one squeeze needed
108
-
109
- else:
110
- return None
111
-
112
- # Common processing for all image types
113
- # Normalize to 0-255 range if needed
114
- if image.max() > 1.0 and image.max() <= 255:
115
- # Already in 0-255 range, no need to normalize
116
- pass
117
- else:
118
- # Normalize to 0-255
119
- image = np.uint8(255 * (image - np.min(image)) / (np.max(image) - np.min(image) + 1e-10)) # Added small value to prevent division by zero
120
-
121
- # Resize to model's expected input size
122
- image = cv2.resize(image, (256, 256))
123
-
124
- # Apply the preprocessing and return tensor
125
- return preprocess_image(image)
126
-
127
- except Exception as e:
128
- import traceback
129
- st.error(f"Error processing image: {str(e)}")
130
- st.code(traceback.format_exc())
131
- return None
132
-
133
- # Streamlit Web App
134
- st.set_page_config(
135
- page_title="PneumoFind",
136
- page_icon="🫁",
137
- layout="wide",
138
- initial_sidebar_state="expanded"
139
- )
140
-
141
- # Custom CSS
142
- st.markdown("""
143
- <style>
144
- .main-header {
145
- font-size: 3rem;
146
- color: #3498db;
147
- text-align: center;
148
- margin-bottom: 1rem;
149
- font-weight: 700;
150
- }
151
- .sub-header {
152
- font-size: 1.5rem;
153
- color: #2c3e50;
154
- text-align: center;
155
- margin-bottom: 2rem;
156
- }
157
- .result-normal {
158
- padding: 20px;
159
- border-radius: 10px;
160
- background-color: #2ecc71;
161
- color: white;
162
- text-align: center;
163
- font-size: 2rem;
164
- font-weight: bold;
165
- margin: 20px 0;
166
- }
167
- .result-pneumonia {
168
- padding: 20px;
169
- border-radius: 10px;
170
- background-color: #e74c3c;
171
- color: white;
172
- text-align: center;
173
- font-size: 2rem;
174
- font-weight: bold;
175
- margin: 20px 0;
176
- }
177
- .upload-section {
178
- background-color: #f8f9fa;
179
- padding: 30px;
180
- border-radius: 15px;
181
- box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1);
182
- margin-bottom: 30px;
183
- }
184
- .footer {
185
- text-align: center;
186
- color: #7f8c8d;
187
- font-size: 0.9rem;
188
- padding: 20px;
189
- border-top: 1px solid #eee;
190
- margin-top: 40px;
191
- }
192
- .stImage img {
193
- border-radius: 10px;
194
- box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1);
195
- }
196
- </style>
197
- """, unsafe_allow_html=True)
198
-
199
- # Header
200
- st.markdown("<h1 class='main-header'>PneumoFind</h1>", unsafe_allow_html=True)
201
- st.markdown("<h2 class='sub-header'>Advanced AI-Powered Pneumonia Detection</h2>", unsafe_allow_html=True)
202
-
203
- # Sidebar
204
- with st.sidebar:
205
- st.image("steth.png")
206
- st.markdown("## About PneumoFind")
207
- st.info(
208
- "PneumoFind uses deep learning to analyze chest X-rays and "
209
- "detect signs of pneumonia with high accuracy. Upload your medical "
210
- "image for instant analysis."
211
- )
212
-
213
- st.markdown("## Supported Formats")
214
- st.markdown("- X-ray Images (PNG, JPG, JPEG)")
215
-
216
-
217
- st.markdown("## Interpretation")
218
- st.success("**Normal**: No signs of pneumonia detected")
219
- st.error("**Pneumonia**: Signs of pneumonia detected")
220
-
221
- # Main content
222
- st.markdown("<div class='upload-section'>", unsafe_allow_html=True)
223
- uploaded_file = st.file_uploader("Upload an X-ray image for analysis", type=["png", "jpg", "jpeg"])
224
- st.markdown("</div>", unsafe_allow_html=True)
225
-
226
- # Process image if uploaded
227
- if uploaded_file is not None:
228
- col1, col2 = st.columns(2)
229
-
230
- with col1:
231
- st.markdown("### Uploaded Image")
232
- st.image(uploaded_file, caption="X-ray Image", use_container_width=True)
233
-
234
- with col2:
235
- st.markdown("### Analysis Results")
236
-
237
- # Progress bar for analysis simulation
238
- with st.spinner("Analyzing image..."):
239
- # Process the image
240
- file_extension = uploaded_file.name.split(".")[-1]
241
- processed_image = load_image(uploaded_file, file_extension)
242
-
243
- if processed_image is not None:
244
- # Process with model
245
- progress_bar = st.progress(0)
246
- for i in range(100):
247
- time.sleep(0.01) # Add a small delay for visual effect
248
- progress_bar.progress(i + 1)
249
-
250
- # Get prediction
251
- with torch.no_grad():
252
- output = model(processed_image) # Model outputs raw logits
253
- probability = torch.sigmoid(output).item() # Convert logits to probability
254
- prediction = "Pneumonia Detected" if probability > 0.15 else "No Pneumonia Detected"
255
-
256
- # Display results
257
- if probability > 0.15:
258
- st.markdown(f"<div class='result-pneumonia'>{prediction}</div>", unsafe_allow_html=True)
259
- st.warning(f"Confidence Score: {probability:.2f}") # Display correct probability
260
- st.markdown("#### Recommendation")
261
- st.error("Please consult a healthcare professional for proper diagnosis and treatment.")
262
- else:
263
- st.markdown(f"<div class='result-normal'>{prediction}</div>", unsafe_allow_html=True)
264
- st.info(f"Confidence Score: {1 - probability:.2f}") # Correct confidence display
265
- st.markdown("#### Recommendation")
266
- st.success("X-ray appears normal. Continue regular health monitoring.")
267
-
268
- else:
269
- st.error("Error: File format not supported or corrupted image.")
270
- else:
271
- # Display sample image gallery when no file is uploaded
272
- st.markdown("### Sample X-rays")
273
- st.info("Upload an X-ray image to get started. Here are example images for reference.")
274
-
275
- sample_col1, sample_col2 = st.columns(2)
276
- with sample_col1:
277
- st.image("nopneumoniaxray.png",
278
- caption="Example of a normal chest X-ray", width=300)
279
- with sample_col2:
280
- st.image("pneumoniaxray.png",
281
- caption="Example of a pneumonia chest X-ray", width=300)
282
-
283
- # Informational section
284
- st.markdown("## About Pneumonia")
285
- expander = st.expander("Learn more about pneumonia")
286
- with expander:
287
- st.markdown("""
288
- Pneumonia is an infection that inflames the air sacs in one or both lungs. The air sacs may fill with fluid or pus,
289
- causing cough with phlegm or pus, fever, chills, and difficulty breathing. Various organisms, including bacteria,
290
- viruses and fungi, can cause pneumonia.
291
-
292
- **Common symptoms include:**
293
- - Chest pain when breathing or coughing
294
- - Confusion or changes in mental awareness (in adults age 65 and older)
295
- - Cough, which may produce phlegm
296
- - Fatigue
297
- - Fever, sweating and shaking chills
298
- - Lower than normal body temperature (in adults older than age 65 and people with weak immune systems)
299
- - Nausea, vomiting or diarrhea
300
- - Shortness of breath
301
- """)
302
-
303
- # Footer
304
  st.markdown("<div class='footer'>App Developed by Syed Faizan | © 2025 PneumoFind</div>", unsafe_allow_html=True)
 
1
+ import streamlit as st
2
+ import torch
3
+ import torchvision
4
+ import torchmetrics
5
+ import pytorch_lightning as pl
6
+ import numpy as np
7
+ import cv2
8
+ import time
9
+ import pydicom
10
+ import nibabel as nib
11
+ import io
12
+ from torchvision import transforms
13
+ from PIL import Image
14
+
15
+ # Load the trained model
16
+ class PneumoniaModel(pl.LightningModule):
17
+ def __init__(self):
18
+ super(PneumoniaModel, self).__init__()
19
+ self.model = torchvision.models.resnet18()
20
+ self.model.conv1 = torch.nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
21
+ self.model.fc = torch.nn.Linear(in_features=512, out_features=1, bias=True)
22
+
23
+ self.loss_fn = torch.nn.BCEWithLogitsLoss(pos_weight=torch.tensor([5.0]))
24
+ self.val_acc = torchmetrics.Accuracy(task="binary")
25
+ self.val_auc = torchmetrics.AUROC(task="binary")
26
+ self.val_outputs = []
27
+
28
+ def forward(self, data):
29
+ return self.model(data)
30
+
31
+ def validation_step(self, batch, batch_idx):
32
+ x_ray, label = batch
33
+ label = label.float()
34
+ pred = self(x_ray)[:, 0]
35
+ loss = self.loss_fn(pred, label)
36
+ self.val_outputs.append({"preds": pred, "targets": label})
37
+ return loss
38
+
39
+ def on_validation_epoch_end(self):
40
+ all_preds = torch.cat([x["preds"] for x in self.val_outputs]).cpu().numpy()
41
+ all_targets = torch.cat([x["targets"] for x in self.val_outputs]).cpu().numpy()
42
+ self.val_outputs.clear()
43
+
44
+ def configure_optimizers(self):
45
+ return torch.optim.Adam(self.model.parameters(), lr=1e-4)
46
+
47
+ # Load trained model weights
48
+ model = PneumoniaModel()
49
+ checkpoint = torch.load("weights_3.ckpt", map_location=torch.device('cpu'), weights_only=False)
50
+ state_dict = checkpoint["state_dict"]
51
+ model.load_state_dict(state_dict)
52
+ model.eval()
53
+
54
+ # Preprocessing function
55
+ def preprocess_image(image):
56
+ transform = transforms.Compose([
57
+ transforms.ToPILImage(),
58
+ transforms.Resize((256, 256)),
59
+ transforms.ToTensor(),
60
+ transforms.Normalize(mean=[0.5], std=[0.5])
61
+ ])
62
+ return transform(image).unsqueeze(0)
63
+
64
+ # Function to load and preprocess different file types
65
+ def load_image(file_path, file_type):
66
+ file_type = file_type.lower()
67
+
68
+ try:
69
+ if file_type in ["png", "jpg", "jpeg"]:
70
+ # For file objects from streamlit
71
+ if hasattr(file_path, 'read'):
72
+ image = Image.open(file_path).convert("L") # Convert to grayscale
73
+ else:
74
+ image = Image.open(file_path).convert("L")
75
+ image = np.array(image)
76
+
77
+ elif file_type == "dcm":
78
+ # For file objects from streamlit
79
+ if hasattr(file_path, 'read'):
80
+ # Create a temporary BytesIO object
81
+ temp_file = io.BytesIO(file_path.read())
82
+ file_path.seek(0) # Reset pointer for future reads
83
+ dicom_data = pydicom.dcmread(temp_file)
84
+ else:
85
+ dicom_data = pydicom.dcmread(file_path)
86
+
87
+ image = dicom_data.pixel_array
88
+
89
+ elif file_type in ["nii", "nii.gz"]:
90
+ # For file objects from streamlit
91
+ if hasattr(file_path, 'read'):
92
+ # We need to save temporarily for nibabel
93
+ with open("temp_file." + file_type, "wb") as f:
94
+ f.write(file_path.read())
95
+ file_path.seek(0) # Reset pointer for future reads
96
+ nifti_data = nib.load("temp_file." + file_type)
97
+ # Clean up the temp file
98
+ import os
99
+ try:
100
+ os.remove("temp_file." + file_type)
101
+ except:
102
+ pass # Ignore cleanup errors
103
+ else:
104
+ nifti_data = nib.load(file_path)
105
+
106
+ image = nifti_data.get_fdata()
107
+ image = np.squeeze(image) # Only one squeeze needed
108
+
109
+ else:
110
+ return None
111
+
112
+ # Common processing for all image types
113
+ # Normalize to 0-255 range if needed
114
+ if image.max() > 1.0 and image.max() <= 255:
115
+ # Already in 0-255 range, no need to normalize
116
+ pass
117
+ else:
118
+ # Normalize to 0-255
119
+ image = np.uint8(255 * (image - np.min(image)) / (np.max(image) - np.min(image) + 1e-10)) # Added small value to prevent division by zero
120
+
121
+ # Resize to model's expected input size
122
+ image = cv2.resize(image, (256, 256))
123
+
124
+ # Apply the preprocessing and return tensor
125
+ return preprocess_image(image)
126
+
127
+ except Exception as e:
128
+ import traceback
129
+ st.error(f"Error processing image: {str(e)}")
130
+ st.code(traceback.format_exc())
131
+ return None
132
+
133
+ # Streamlit Web App
134
+ st.set_page_config(
135
+ page_title="PneumoFind",
136
+ page_icon="🫁",
137
+ layout="wide",
138
+ initial_sidebar_state="expanded"
139
+ )
140
+
141
+ # Custom CSS
142
+ st.markdown("""
143
+ <style>
144
+ .main-header {
145
+ font-size: 3rem;
146
+ color: #3498db;
147
+ text-align: center;
148
+ margin-bottom: 1rem;
149
+ font-weight: 700;
150
+ }
151
+ .sub-header {
152
+ font-size: 1.5rem;
153
+ color: #2c3e50;
154
+ text-align: center;
155
+ margin-bottom: 2rem;
156
+ }
157
+ .result-normal {
158
+ padding: 20px;
159
+ border-radius: 10px;
160
+ background-color: #2ecc71;
161
+ color: white;
162
+ text-align: center;
163
+ font-size: 2rem;
164
+ font-weight: bold;
165
+ margin: 20px 0;
166
+ }
167
+ .result-pneumonia {
168
+ padding: 20px;
169
+ border-radius: 10px;
170
+ background-color: #e74c3c;
171
+ color: white;
172
+ text-align: center;
173
+ font-size: 2rem;
174
+ font-weight: bold;
175
+ margin: 20px 0;
176
+ }
177
+ .upload-section {
178
+ background-color: #f8f9fa;
179
+ padding: 30px;
180
+ border-radius: 15px;
181
+ box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1);
182
+ margin-bottom: 30px;
183
+ }
184
+ .footer {
185
+ text-align: center;
186
+ color: #7f8c8d;
187
+ font-size: 0.9rem;
188
+ padding: 20px;
189
+ border-top: 1px solid #eee;
190
+ margin-top: 40px;
191
+ }
192
+ .stImage img {
193
+ border-radius: 10px;
194
+ box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1);
195
+ }
196
+ </style>
197
+ """, unsafe_allow_html=True)
198
+
199
+ # Header
200
+ st.markdown("<h1 class='main-header'>PneumoFind</h1>", unsafe_allow_html=True)
201
+ st.markdown("<h2 class='sub-header'>Advanced AI-Powered Pneumonia Detection</h2>", unsafe_allow_html=True)
202
+
203
+ # Sidebar
204
+ with st.sidebar:
205
+ st.image("steth.png")
206
+ st.markdown("## About PneumoFind")
207
+ st.info(
208
+ "PneumoFind uses deep learning to analyze chest X-rays and "
209
+ "detect signs of pneumonia with high accuracy. Upload your medical "
210
+ "image for instant analysis."
211
+ )
212
+
213
+ st.markdown("## Supported Formats")
214
+ st.markdown("- X-ray Images (PNG, JPG, JPEG)")
215
+
216
+
217
+ st.markdown("## Interpretation")
218
+ st.success("**Normal**: No signs of pneumonia detected")
219
+ st.error("**Pneumonia**: Signs of pneumonia detected")
220
+
221
+ # Main content
222
+ st.markdown("<div class='upload-section'>", unsafe_allow_html=True)
223
+ uploaded_file = st.file_uploader("Upload an X-ray image for analysis", type=["png", "jpg", "jpeg"])
224
+ st.markdown("</div>", unsafe_allow_html=True)
225
+
226
+ # Process image if uploaded
227
+ if uploaded_file is not None:
228
+ col1, col2 = st.columns(2)
229
+
230
+ with col1:
231
+ st.markdown("### Uploaded Image")
232
+ st.image(uploaded_file, caption="X-ray Image", use_container_width=True)
233
+
234
+ with col2:
235
+ st.markdown("### Analysis Results")
236
+
237
+ # Progress bar for analysis simulation
238
+ with st.spinner("Analyzing image..."):
239
+ # Process the image
240
+ file_extension = uploaded_file.name.split(".")[-1]
241
+ processed_image = load_image(uploaded_file, file_extension)
242
+
243
+ if processed_image is not None:
244
+ # Process with model
245
+ progress_bar = st.progress(0)
246
+ for i in range(100):
247
+ time.sleep(0.01) # Add a small delay for visual effect
248
+ progress_bar.progress(i + 1)
249
+
250
+ # Get prediction
251
+ with torch.no_grad():
252
+ output = model(processed_image) # Model outputs raw logits
253
+ probability = torch.sigmoid(output).item() # Convert logits to probability
254
+ prediction = "Pneumonia Detected" if probability > 0.15 else "No Pneumonia Detected"
255
+
256
+ # Display results
257
+ if probability > 0.15:
258
+ st.markdown(f"<div class='result-pneumonia'>{prediction}</div>", unsafe_allow_html=True)
259
+ st.warning(f"Confidence Score: {probability:.2f}") # Display correct probability
260
+ st.markdown("#### Recommendation")
261
+ st.error("Please consult a healthcare professional for proper diagnosis and treatment.")
262
+ else:
263
+ st.markdown(f"<div class='result-normal'>{prediction}</div>", unsafe_allow_html=True)
264
+ st.info(f"Confidence Score: {1 - probability:.2f}") # Correct confidence display
265
+ st.markdown("#### Recommendation")
266
+ st.success("X-ray appears normal. Continue regular health monitoring.")
267
+
268
+ else:
269
+ st.error("Error: File format not supported or corrupted image.")
270
+ else:
271
+ # Display sample image gallery when no file is uploaded
272
+ st.markdown("### Sample X-rays")
273
+ st.info("Upload an X-ray image to get started. Here are example images for reference.")
274
+
275
+ sample_col1, sample_col2 = st.columns(2)
276
+ with sample_col1:
277
+ st.image("nopneumoniaxray.png",
278
+ caption="Example of a normal chest X-ray", width=300)
279
+ with sample_col2:
280
+ st.image("pneumoniaxray.png",
281
+ caption="Example of a pneumonia chest X-ray", width=300)
282
+
283
+ # Informational section
284
+ st.markdown("## About Pneumonia")
285
+ expander = st.expander("Learn more about pneumonia")
286
+ with expander:
287
+ st.markdown("""
288
+ Pneumonia is an infection that inflames the air sacs in one or both lungs. The air sacs may fill with fluid or pus,
289
+ causing cough with phlegm or pus, fever, chills, and difficulty breathing. Various organisms, including bacteria,
290
+ viruses and fungi, can cause pneumonia.
291
+
292
+ **Common symptoms include:**
293
+ - Chest pain when breathing or coughing
294
+ - Confusion or changes in mental awareness (in adults age 65 and older)
295
+ - Cough, which may produce phlegm
296
+ - Fatigue
297
+ - Fever, sweating and shaking chills
298
+ - Lower than normal body temperature (in adults older than age 65 and people with weak immune systems)
299
+ - Nausea, vomiting or diarrhea
300
+ - Shortness of breath
301
+ """)
302
+
303
+ # Footer
304
  st.markdown("<div class='footer'>App Developed by Syed Faizan | © 2025 PneumoFind</div>", unsafe_allow_html=True)