VJBharathkumar commited on
Commit
1abfaa3
·
verified ·
1 Parent(s): e1e0672

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +141 -37
src/streamlit_app.py CHANGED
@@ -1,40 +1,144 @@
1
- import altair as alt
2
  import numpy as np
3
- import pandas as pd
4
  import streamlit as st
 
 
 
5
 
6
- """
7
- # Welcome to Streamlit!
8
-
9
- Edit `/streamlit_app.py` to customize this app to your heart's desire :heart:.
10
- If you have any questions, checkout our [documentation](https://docs.streamlit.io) and [community
11
- forums](https://discuss.streamlit.io).
12
-
13
- In the meantime, below is an example of what you can do with just a few lines of code:
14
- """
15
-
16
- num_points = st.slider("Number of points in spiral", 1, 10000, 1100)
17
- num_turns = st.slider("Number of turns in spiral", 1, 300, 31)
18
-
19
- indices = np.linspace(0, 1, num_points)
20
- theta = 2 * np.pi * num_turns * indices
21
- radius = indices
22
-
23
- x = radius * np.cos(theta)
24
- y = radius * np.sin(theta)
25
-
26
- df = pd.DataFrame({
27
- "x": x,
28
- "y": y,
29
- "idx": indices,
30
- "rand": np.random.randn(num_points),
31
- })
32
-
33
- st.altair_chart(alt.Chart(df, height=700, width=700)
34
- .mark_point(filled=True)
35
- .encode(
36
- x=alt.X("x", axis=None),
37
- y=alt.Y("y", axis=None),
38
- color=alt.Color("idx", legend=None, scale=alt.Scale()),
39
- size=alt.Size("rand", legend=None, scale=alt.Scale(range=[1, 150])),
40
- ))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import io
2
  import numpy as np
 
3
  import streamlit as st
4
+ import tensorflow as tf
5
+ from tensorflow import keras
6
+ import pydicom
7
 
8
+ # ----------------------------------------------------
9
+ # App Configuration
10
+ # ----------------------------------------------------
11
+ st.set_page_config(
12
+ page_title="Pneumonia Detection (Chest X-ray) Clinical Decision Support",
13
+ layout="centered"
14
+ )
15
+
16
+ st.title("Pneumonia Detection (Chest X-ray) – Clinical Decision Support")
17
+ st.caption(
18
+ "Upload one or more Chest X-ray DICOM files (.dcm). "
19
+ "Adjust the decision threshold and submit to obtain a probability-based binary prediction. "
20
+ "This system is intended for clinical decision support and does not replace professional medical judgment."
21
+ )
22
+
23
+ # ----------------------------------------------------
24
+ # Load Model
25
+ # ----------------------------------------------------
26
+ MODEL_PATH = "model.keras"
27
+
28
+ @st.cache_resource
29
+ def load_model():
30
+ try:
31
+ return keras.models.load_model(MODEL_PATH)
32
+ except Exception:
33
+ keras.config.enable_unsafe_deserialization()
34
+ return keras.models.load_model(MODEL_PATH, safe_mode=False)
35
+
36
+ model = load_model()
37
+
38
+ input_shape = model.input_shape
39
+ img_size = int(input_shape[1]) if input_shape and input_shape[1] else 256
40
+ expected_channels = int(input_shape[-1]) if input_shape and input_shape[-1] else 3
41
+
42
+ # ----------------------------------------------------
43
+ # Threshold Slider (DEFAULT = 0.37 for ResNet)
44
+ # ----------------------------------------------------
45
+ st.subheader("Model Parameters")
46
+
47
+ threshold = st.slider(
48
+ "Decision Threshold",
49
+ min_value=0.01,
50
+ max_value=0.99,
51
+ value=0.37, # <-- DEFAULT CHANGED HERE
52
+ step=0.01,
53
+ help="If predicted probability ≥ threshold → Pneumonia. Otherwise → Not Pneumonia."
54
+ )
55
+
56
+ # ----------------------------------------------------
57
+ # File Upload
58
+ # ----------------------------------------------------
59
+ st.subheader("Upload Chest X-ray DICOM Files")
60
+
61
+ uploaded_files = st.file_uploader(
62
+ "Select one or multiple DICOM files (.dcm)",
63
+ type=["dcm"],
64
+ accept_multiple_files=True
65
+ )
66
+
67
+ col1, col2 = st.columns(2)
68
+ with col1:
69
+ submit = st.button("Submit", type="primary", use_container_width=True)
70
+ with col2:
71
+ clear = st.button("Clear", use_container_width=True)
72
+
73
+ if clear:
74
+ st.experimental_rerun()
75
+
76
+ # ----------------------------------------------------
77
+ # Helper Functions
78
+ # ----------------------------------------------------
79
+ def read_dicom(file):
80
+ data = file.read()
81
+ dcm = pydicom.dcmread(io.BytesIO(data))
82
+ img = dcm.pixel_array.astype(np.float32)
83
+
84
+ img = (img - img.min()) / (img.max() - img.min() + 1e-8)
85
+ return img
86
+
87
+ def preprocess(img):
88
+ x = tf.convert_to_tensor(img[..., None], dtype=tf.float32)
89
+ x = tf.image.resize(x, (img_size, img_size))
90
+ x = tf.clip_by_value(x, 0.0, 1.0)
91
+ x = x.numpy()
92
+
93
+ # If model expects 3 channels (ResNet)
94
+ if expected_channels == 3 and x.shape[-1] == 1:
95
+ x = np.repeat(x, 3, axis=-1)
96
+
97
+ x = np.expand_dims(x, axis=0)
98
+ return x.astype(np.float32)
99
+
100
+ def get_probability(x):
101
+ prediction = model.predict(x, verbose=0)
102
+
103
+ if isinstance(prediction, (list, tuple)):
104
+ prob = float(np.ravel(prediction[-1])[0])
105
+ else:
106
+ prob = float(np.ravel(prediction)[0])
107
+
108
+ return max(0.0, min(1.0, prob))
109
+
110
+ # ----------------------------------------------------
111
+ # Inference Section
112
+ # ----------------------------------------------------
113
+ st.subheader("Prediction Results")
114
+
115
+ if submit:
116
+ if not uploaded_files:
117
+ st.warning("Please upload at least one DICOM file before clicking Submit.")
118
+ else:
119
+ with st.spinner("Processing uploaded file(s)..."):
120
+ for file in uploaded_files:
121
+ try:
122
+ image_array = read_dicom(file)
123
+ x_input = preprocess(image_array)
124
+ probability = get_probability(x_input)
125
+
126
+ predicted_label = "Pneumonia" if probability >= threshold else "Not Pneumonia"
127
+
128
+ st.write(
129
+ f"For the uploaded file '{file.name}', the model estimates a pneumonia probability of "
130
+ f"{probability * 100:.2f}%. Based on the selected decision threshold of {threshold:.2f}, "
131
+ f"the predicted outcome is '{predicted_label}'."
132
+ )
133
+
134
+ except Exception as e:
135
+ st.error(
136
+ f"For the uploaded file '{file.name}', the system could not generate a prediction. "
137
+ f"Reason: {str(e)}."
138
+ )
139
+
140
+ st.divider()
141
+ st.caption(
142
+ "Clinical Notice: This application is designed for decision support purposes only. "
143
+ "Final diagnosis and treatment decisions must be made by qualified healthcare professionals."
144
+ )