Mrhuman1 commited on
Commit
ad42d37
Β·
verified Β·
1 Parent(s): 2dda4d8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +51 -32
app.py CHANGED
@@ -1,4 +1,13 @@
1
  import streamlit as st
 
 
 
 
 
 
 
 
 
2
  import torch
3
  import torch.nn as nn
4
  import torchvision.transforms as transforms
@@ -6,7 +15,7 @@ from efficientnet_pytorch import EfficientNet
6
  from PIL import Image
7
  from datetime import datetime
8
 
9
- # Define HardSwish activation
10
  class HardSwish(nn.Module):
11
  def __init__(self):
12
  super(HardSwish, self).__init__()
@@ -14,7 +23,7 @@ class HardSwish(nn.Module):
14
  def forward(self, x):
15
  return x * (torch.clamp(x + 3, 0, 6) / 6)
16
 
17
- # Define custom EfficientNet model
18
  class CustomEfficientNet(nn.Module):
19
  def __init__(self, num_classes):
20
  super(CustomEfficientNet, self).__init__()
@@ -30,17 +39,17 @@ class CustomEfficientNet(nn.Module):
30
  def forward(self, x):
31
  return self.model(x)
32
 
33
- # Class names
34
  class_names = [
35
  'No Finding', 'Enlarged Cardiomediastinum', 'Cardiomegaly', 'Lung Opacity',
36
  'Lung Lesion', 'Edema', 'Consolidation', 'Pneumonia', 'Atelectasis', 'Pneumothorax',
37
  'Pleural Effusion', 'Pleural Other', 'Fracture', 'Support Devices'
38
  ]
39
 
40
- # Device configuration
41
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
42
 
43
- # Load model
44
  @st.cache_resource
45
  def load_model():
46
  model = CustomEfficientNet(num_classes=14)
@@ -55,14 +64,14 @@ def load_model():
55
 
56
  model = load_model()
57
 
58
- # Transformations
59
  transform = transforms.Compose([
60
  transforms.Resize((300, 300)),
61
  transforms.ToTensor(),
62
  transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
63
  ])
64
 
65
- # Prediction function
66
  def predict(image):
67
  if image.mode != 'RGB':
68
  image = image.convert('RGB')
@@ -74,45 +83,55 @@ def predict(image):
74
 
75
  results = {class_names[i]: float(probs[i]) for i in range(len(class_names))}
76
  sorted_results = dict(sorted(results.items(), key=lambda item: item[1], reverse=True))
77
- top5 = {k: f"{v*100:.2f}%" for k, v in list(sorted_results.items())[:5]}
78
 
79
  return top5
80
 
81
- # ---------------- Streamlit App -----------------
82
-
83
- st.set_page_config(page_title="Chest X-ray Disease Classifier", page_icon="🩺", layout="centered")
84
-
85
  st.markdown(
86
  """
87
- <h1 style="text-align:center;">🩺 Chest X-ray Disease Classifier</h1>
88
- <p style="text-align:center;">Upload a chest X-ray image to get disease predictions.</p>
89
  """,
90
  unsafe_allow_html=True
91
  )
92
 
93
- # Input form
94
- with st.form("input_form"):
95
- patient_name = st.text_input("Patient Name", placeholder="Enter full name...")
96
- scan_date = st.date_input("Scan Date", value=datetime.today())
97
- uploaded_file = st.file_uploader("Upload Chest X-ray Image", type=["png", "jpg", "jpeg", "bmp", "tiff"])
98
- submitted = st.form_submit_button("πŸ” Predict")
99
 
100
- if submitted:
101
- if uploaded_file is not None and patient_name:
 
 
 
 
 
 
 
102
  image = Image.open(uploaded_file)
103
 
104
- with st.spinner('Analyzing the X-ray...'):
105
  top5_predictions = predict(image)
106
 
107
- st.success('βœ… Prediction completed!')
108
 
 
109
  st.markdown("---")
110
- st.markdown(f"**πŸ“‹ Patient Name:** {patient_name}")
111
- st.markdown(f"**πŸ“… Scan Date:** {scan_date.strftime('%Y-%m-%d')}")
112
- st.markdown("### πŸ§ͺ Top 5 Predictions:")
113
- for disease, probability in top5_predictions.items():
114
- st.write(f"πŸ”Ή **{disease}**: {probability}")
115
 
116
- st.image(image, caption="Uploaded Chest X-ray", use_column_width=True)
117
- else:
118
- st.error("⚠️ Please fill all fields and upload an image.")
 
 
 
 
 
 
 
 
1
  import streamlit as st
2
+
3
+ # ⚑️ Page config must be FIRST
4
+ st.set_page_config(
5
+ page_title="Chest X-ray Disease Classifier",
6
+ page_icon="🩺",
7
+ layout="centered"
8
+ )
9
+
10
+ # Imports
11
  import torch
12
  import torch.nn as nn
13
  import torchvision.transforms as transforms
 
15
  from PIL import Image
16
  from datetime import datetime
17
 
18
+ # --- Define HardSwish activation ---
19
  class HardSwish(nn.Module):
20
  def __init__(self):
21
  super(HardSwish, self).__init__()
 
23
  def forward(self, x):
24
  return x * (torch.clamp(x + 3, 0, 6) / 6)
25
 
26
+ # --- Define Custom EfficientNet model ---
27
  class CustomEfficientNet(nn.Module):
28
  def __init__(self, num_classes):
29
  super(CustomEfficientNet, self).__init__()
 
39
  def forward(self, x):
40
  return self.model(x)
41
 
42
+ # Disease class labels
43
  class_names = [
44
  'No Finding', 'Enlarged Cardiomediastinum', 'Cardiomegaly', 'Lung Opacity',
45
  'Lung Lesion', 'Edema', 'Consolidation', 'Pneumonia', 'Atelectasis', 'Pneumothorax',
46
  'Pleural Effusion', 'Pleural Other', 'Fracture', 'Support Devices'
47
  ]
48
 
49
+ # --- Device configuration ---
50
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
51
 
52
+ # --- Load model ---
53
  @st.cache_resource
54
  def load_model():
55
  model = CustomEfficientNet(num_classes=14)
 
64
 
65
  model = load_model()
66
 
67
+ # --- Transformations ---
68
  transform = transforms.Compose([
69
  transforms.Resize((300, 300)),
70
  transforms.ToTensor(),
71
  transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
72
  ])
73
 
74
+ # --- Prediction function ---
75
  def predict(image):
76
  if image.mode != 'RGB':
77
  image = image.convert('RGB')
 
83
 
84
  results = {class_names[i]: float(probs[i]) for i in range(len(class_names))}
85
  sorted_results = dict(sorted(results.items(), key=lambda item: item[1], reverse=True))
86
+ top5 = {k: v for k, v in list(sorted_results.items())[:5]}
87
 
88
  return top5
89
 
90
+ # --- Streamlit UI ---
 
 
 
91
  st.markdown(
92
  """
93
+ <h1 style="text-align: center;">🩺 Chest X-ray Disease Classifier</h1>
94
+ <p style="text-align: center;">Upload a chest X-ray image to get disease predictions.<br>Top 5 diseases shown with probability scores.</p>
95
  """,
96
  unsafe_allow_html=True
97
  )
98
 
99
+ st.write("") # small space
100
+
101
+ with st.form("prediction_form"):
102
+ patient_name = st.text_input("πŸ‘€ Patient Name", placeholder="Enter patient's full name...")
103
+ scan_date = st.date_input("πŸ“… Scan Date", value=datetime.today())
104
+ uploaded_file = st.file_uploader("πŸ“€ Upload Chest X-ray Image", type=["png", "jpg", "jpeg", "bmp", "tiff"])
105
 
106
+ submit_button = st.form_submit_button("πŸ” Analyze X-ray")
107
+
108
+ # Handle submission
109
+ if submit_button:
110
+ if not uploaded_file:
111
+ st.error("⚠️ Please upload a chest X-ray image.")
112
+ elif not patient_name.strip():
113
+ st.error("⚠️ Please enter the patient's name.")
114
+ else:
115
  image = Image.open(uploaded_file)
116
 
117
+ with st.spinner('πŸ”Ž Analyzing the X-ray...'):
118
  top5_predictions = predict(image)
119
 
120
+ st.success('βœ… Analysis Completed!')
121
 
122
+ # --- Show results ---
123
  st.markdown("---")
124
+ st.subheader("πŸ“‹ Patient Information")
125
+ st.write(f"**Name:** {patient_name}")
126
+ st.write(f"**Scan Date:** {scan_date.strftime('%Y-%m-%d')}")
 
 
127
 
128
+ st.markdown("---")
129
+ st.subheader("πŸ§ͺ Top 5 Predicted Diseases")
130
+
131
+ for disease, prob in top5_predictions.items():
132
+ st.progress(prob) # show progress bar
133
+ st.write(f"πŸ”Ή **{disease}** β€” {prob*100:.2f}%")
134
+
135
+ st.markdown("---")
136
+ st.subheader("πŸ–ΌοΈ Uploaded X-ray Image")
137
+ st.image(image, use_column_width=True)