Jasonnn13 commited on
Commit
6db0507
·
verified ·
1 Parent(s): 40151ee

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +345 -37
src/streamlit_app.py CHANGED
@@ -1,40 +1,348 @@
1
- import altair as alt
 
 
2
  import numpy as np
3
- import pandas as pd
4
  import streamlit as st
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
 
6
- """
7
- # Welcome to Streamlit!
8
-
9
- Edit `/streamlit_app.py` to customize this app to your heart's desire :heart:.
10
- If you have any questions, checkout our [documentation](https://docs.streamlit.io) and [community
11
- forums](https://discuss.streamlit.io).
12
-
13
- In the meantime, below is an example of what you can do with just a few lines of code:
14
- """
15
-
16
- num_points = st.slider("Number of points in spiral", 1, 10000, 1100)
17
- num_turns = st.slider("Number of turns in spiral", 1, 300, 31)
18
-
19
- indices = np.linspace(0, 1, num_points)
20
- theta = 2 * np.pi * num_turns * indices
21
- radius = indices
22
-
23
- x = radius * np.cos(theta)
24
- y = radius * np.sin(theta)
25
-
26
- df = pd.DataFrame({
27
- "x": x,
28
- "y": y,
29
- "idx": indices,
30
- "rand": np.random.randn(num_points),
31
- })
32
-
33
- st.altair_chart(alt.Chart(df, height=700, width=700)
34
- .mark_point(filled=True)
35
- .encode(
36
- x=alt.X("x", axis=None),
37
- y=alt.Y("y", axis=None),
38
- color=alt.Color("idx", legend=None, scale=alt.Scale()),
39
- size=alt.Size("rand", legend=None, scale=alt.Scale(range=[1, 150])),
40
- ))
 
1
+ import io
2
+ from typing import List, Optional, Tuple
3
+
4
  import numpy as np
 
5
  import streamlit as st
6
+ import torch
7
+ torch.classes.__path__ = []
8
+ import torch.nn as nn
9
+ from PIL import Image
10
+
11
+ from pathlib import Path
12
+
13
+
14
+ # Fixed class mapping provided by user
15
+ CLASS_TO_LABEL = {
16
+ 0: "Adenocarcinoma",
17
+ 1: "Large Cell Carcinoma",
18
+ 2: "Normal",
19
+ 3: "Squamous Cell Carcinoma",
20
+ }
21
+
22
+
23
+ def _infer_num_classes_from_state(state_dict: dict) -> Optional[int]:
24
+ candidates = [
25
+ "classifier.2.weight",
26
+ "head.fc.weight",
27
+ "fc.weight",
28
+ "classifier.weight",
29
+ ]
30
+ for k in candidates:
31
+ if k in state_dict:
32
+ return int(state_dict[k].shape[0])
33
+ # Try to find any linear layer weight at the tail of classifier
34
+ keys = [k for k in state_dict.keys() if k.endswith(".weight")]
35
+ for k in keys:
36
+ if ".classifier" in k or ".head" in k or k.endswith("fc.weight"):
37
+ try:
38
+ return int(state_dict[k].shape[0])
39
+ except Exception:
40
+ pass
41
+ return None
42
+
43
+
44
+ def _infer_class_names(ckpt: dict, num_classes: int) -> List[str]:
45
+ # Common patterns
46
+ for key in ("classes", "class_names", "labels"):
47
+ if isinstance(ckpt.get(key), (list, tuple)):
48
+ return list(ckpt[key])
49
+ if isinstance(ckpt.get("idx_to_class"), dict):
50
+ # Ensure ordered by index
51
+ mapping = ckpt["idx_to_class"]
52
+ try:
53
+ return [mapping[i] for i in range(len(mapping))]
54
+ except Exception:
55
+ # Fallback arbitrary order
56
+ return list(mapping.values())
57
+ if isinstance(ckpt.get("class_to_idx"), dict):
58
+ inv = sorted(ckpt["class_to_idx"].items(), key=lambda x: x[1])
59
+ return [name for name, _ in inv]
60
+ return [f"Class {i}" for i in range(num_classes)]
61
+
62
+
63
+ @st.cache_resource(show_spinner=True)
64
+ def load_model(weights_path: str) -> Tuple[nn.Module, List[str]]:
65
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
66
+ ckpt = torch.load(weights_path, map_location=device)
67
+ if isinstance(ckpt, dict):
68
+ state_dict = ckpt.get("state_dict") or ckpt.get("model_state_dict") or ckpt
69
+ else:
70
+ state_dict = ckpt
71
+
72
+ # Prefer fixed mapping if provided, otherwise infer
73
+ if CLASS_TO_LABEL:
74
+ num_classes = len(CLASS_TO_LABEL)
75
+ else:
76
+ num_classes = _infer_num_classes_from_state(state_dict) or 2
77
+
78
+ model = None
79
+ errors = []
80
+
81
+ # Try torchvision ConvNeXt Large first
82
+ try:
83
+ from torchvision.models import convnext_large
84
+ tv_model = convnext_large(weights=None)
85
+ in_features = tv_model.classifier[2].in_features
86
+ tv_model.classifier[2] = nn.Linear(in_features, num_classes)
87
+ tv_model.load_state_dict(state_dict, strict=False)
88
+ model = tv_model
89
+ except Exception as e:
90
+ errors.append(f"torchvision load failed: {e}")
91
+
92
+ if model is None:
93
+ raise RuntimeError(
94
+ "Failed to load model with the provided weights. "
95
+ + " ; ".join(errors)
96
+ )
97
+
98
+ model.to(device)
99
+ model.eval()
100
+
101
+ if CLASS_TO_LABEL and len(CLASS_TO_LABEL) == num_classes:
102
+ class_names = [CLASS_TO_LABEL[i] for i in range(num_classes)]
103
+ else:
104
+ class_names = _infer_class_names(ckpt if isinstance(ckpt, dict) else {}, num_classes)
105
+ return model, class_names
106
+
107
+
108
+ def preprocess_image(img: Image.Image) -> torch.Tensor:
109
+ # Ensure RGB
110
+ if img.mode != "RGB":
111
+ img = img.convert("RGB")
112
+ # Resize to 224 while keeping aspect ratio via center-crop like behavior
113
+ img = img.resize((224, 224))
114
+
115
+ arr = np.array(img).astype("float32") / 255.0
116
+ mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
117
+ std = np.array([0.229, 0.224, 0.225], dtype=np.float32)
118
+ arr = (arr - mean) / std
119
+ arr = np.transpose(arr, (2, 0, 1))
120
+ tensor = torch.from_numpy(arr)
121
+ return tensor
122
+
123
+
124
+ def predict(model: nn.Module, tensor: torch.Tensor) -> Tuple[int, float, np.ndarray]:
125
+ device = next(model.parameters()).device
126
+ with torch.no_grad():
127
+ logits = model(tensor.unsqueeze(0).to(device))
128
+ if isinstance(logits, (list, tuple)):
129
+ logits = logits[0]
130
+ probs = torch.softmax(logits, dim=1).cpu().numpy()[0]
131
+ idx = int(np.argmax(probs))
132
+ conf = float(probs[idx])
133
+ return idx, conf, probs
134
+
135
+
136
+ st.set_page_config(page_title="CT Scan Classifier", page_icon="🩺", layout="centered")
137
+
138
+ # Custom CSS for UI Enhancement
139
+ st.markdown("""
140
+ <style>
141
+ /* Main Background & Fonts */
142
+ h1, h2, h3 {
143
+ font-family: 'Helvetica Neue', sans-serif;
144
+ }
145
+ h1 {
146
+ font-weight: 700;
147
+ color: #0f52ba; /* Medical Blue */
148
+ }
149
+
150
+ /* Info Cards Styling */
151
+ div[data-testid="stVerticalBlock"] > div[data-testid="stVerticalBlock"] {
152
+ /* Generic adjustment for nested blocks if needed */
153
+ }
154
+
155
+ /* Custom Button for "Start Detecting" (Anchor Link) */
156
+ a.custom-btn {
157
+ display: inline-block;
158
+ padding: 0.6em 1.2em;
159
+ margin-top: 20px;
160
+ color: #ffffff !important;
161
+ background-color: #ff4b4b;
162
+ border-radius: 8px;
163
+ text-decoration: none;
164
+ font-weight: 600;
165
+ text-align: center;
166
+ transition: all 0.2s ease-in-out;
167
+ box-shadow: 0 4px 6px rgba(0,0,0,0.1);
168
+ }
169
+ a.custom-btn:hover {
170
+ background-color: #ff3333;
171
+ transform: translateY(-2px);
172
+ box-shadow: 0 6px 8px rgba(0,0,0,0.15);
173
+ }
174
+
175
+ /* Style for the metrics/prediction result */
176
+ div[data-testid="stMetricValue"] {
177
+ font-size: 1.5rem;
178
+ }
179
+ </style>
180
+ """, unsafe_allow_html=True)
181
+
182
+ # Resolve static asset directory robustly (works locally and on Streamlit Cloud)
183
+ APP_DIR = Path(__file__).parent.resolve()
184
+ _public_candidates = [
185
+ APP_DIR / "public",
186
+ Path.cwd() / "public",
187
+ APP_DIR.parent / "public",
188
+ ]
189
+ PUBLIC_DIR = next((p for p in _public_candidates if p.exists()), _public_candidates[0])
190
+
191
+ # --- HERO SECTION ---
192
+ st.title("Detect Chest Cancer with CTSense")
193
+ st.caption("Fast, Accurate, and Effortless!")
194
+ col1, col2 = st.columns([2, 1])
195
+ with col1:
196
+ st.markdown(
197
+ """
198
+ <div style="font-size: 1.1em; color: #444; line-height: 1.6;">
199
+ Welcome to the future of chest cancer detection. With the power of <b>CTSense</b>,
200
+ you can analyze your <b>CT scans</b> with just one click and receive fast, reliable insights
201
+ powered by advanced AI technology.
202
+ <br><br>
203
+ Start your scan now and experience precision made simple.
204
+ </div>
205
+ """, unsafe_allow_html=True
206
+ )
207
+ # Replaced st.button with an HTML anchor link styled as a button
208
+ st.markdown('<a href="#prediction-section" class="custom-btn">Start Detecting</a>', unsafe_allow_html=True)
209
+
210
+ with col2:
211
+ # Prefer local static image if present; fallback to remote URL
212
+ hero_local = PUBLIC_DIR / "1.png"
213
+ st.image(str(hero_local), use_column_width=True, width=500)
214
+
215
+ # --- INFO SECTION ---
216
+ st.divider()
217
+ st.header("What You Need to Know About Chest Cancer")
218
+
219
+ st.subheader("What Is Chest Cancer?")
220
+ st.write(
221
+ "Chest cancer refers to several types of cancers that form in the tissues of the lungs. "
222
+ "These cancers grow uncontrollably and can interfere with your breathing, oxygen levels, and overall health. "
223
+ "Some types grow slowly, while others spread quickly. Early detection is crucial."
224
+ )
225
+
226
+ st.subheader("Main Types of Chest Cancer")
227
+ st.caption("In our system, we detect these categories:")
228
+
229
+ # Row 1: Adenocarcinoma | Large Cell Carcinoma
230
+ row1_left, row1_right = st.columns(2)
231
+ with row1_left:
232
+ with st.container(border=True):
233
+ st.subheader("Adenocarcinoma")
234
+ st.write(
235
+ "A common type of lung cancer that starts in the glandular cells. "
236
+ "It often grows in the outer parts of the lungs and is more likely to appear in non-smokers than other types."
237
+ )
238
+ with row1_right:
239
+ with st.container(border=True):
240
+ st.subheader("Large Cell Carcinoma")
241
+ st.write(
242
+ "A more aggressive and large cancer that can appear anywhere in the lungs. "
243
+ "It grows and spreads faster and is usually harder to treat if found late."
244
+ )
245
+
246
+ # Row 2: Squamous Cell Carcinoma | Normal
247
+ row2_left, row2_right = st.columns(2)
248
+ with row2_left:
249
+ with st.container(border=True):
250
+ st.subheader("Squamous Cell Carcinoma")
251
+ st.write(
252
+ "This type begins in the thin, flat cells lining the airways. "
253
+ "It often develops in the center of the lungs and is strongly linked to smoking."
254
+ )
255
+ with row2_right:
256
+ with st.container(border=True):
257
+ st.subheader("Normal")
258
+ st.write(
259
+ "No signs of detectable cancer were found based on the uploaded scan. "
260
+ "The AI did not identify any suspicious growths (cancer)."
261
+ )
262
+
263
+ st.subheader("What Happens if It’s Left Untreated?")
264
+ st.write(
265
+ "Without treatment, chest cancer can spread to other organs, reduce lung function, "
266
+ "cause severe breathing issues, and become life-threatening. Early diagnosis significantly improves "
267
+ "treatment options and survival rates."
268
+ )
269
+
270
+ st.subheader("How Do You Detect It?")
271
+ st.write(
272
+ "Chest cancer often begins with mild or unclear symptoms like coughing, chest pain, or fatigue. "
273
+ "Because these signs can be easily missed, doctors rely on **CT scans** to spot abnormalities."
274
+ )
275
+ st.write(
276
+ "With CTSense AI, you can upload your chest scan and receive a fast, AI-powered analysis that helps identify "
277
+ "the presence of cancer types such as Adenocarcinoma, Large Cell Carcinoma, and Squamous Cell Carcinoma."
278
+ )
279
+
280
+ st.divider()
281
+
282
+ # --- PREDICTION / CLASSIFIER SECTION ---
283
+ # Add an invisible anchor for the button to scroll to
284
+ st.markdown('<div id="prediction-section"></div>', unsafe_allow_html=True)
285
+
286
+ st.title("CT Scan Classifier (ConvNeXt Large)")
287
+
288
+ # Sidebar for Model Info & Graphs
289
+ with st.sidebar:
290
+ st.subheader("CTSense")
291
+ st.write("Using weights: `CTScan_ConvNeXtLarge.pth`")
292
+
293
+ st.link_button("GitHub Repository", "https://github.com/Jasonnn13/FinalProjectComputerVision")
294
+
295
+ st.subheader("Training Curves")
296
+ shown_any = False
297
+ for rel, label in [
298
+ ("acc.png", "Accuracy"),
299
+ ("loss.png", "Loss"),
300
+ ]:
301
+ img_path = PUBLIC_DIR / rel
302
+ st.caption(f"{label} (from {img_path.name})")
303
+ st.image(str(img_path), use_column_width=True)
304
+ shown_any = True
305
+
306
+ if not shown_any:
307
+ st.caption("Place images like public/acc.png and public/loss.png to display here.")
308
+
309
+
310
+ @st.cache_resource(show_spinner=False)
311
+ def _load_once():
312
+ return load_model("CTScan_ConvNeXtLarge.pth")
313
+
314
+
315
+ try:
316
+ model, class_names = _load_once()
317
+ except Exception as e:
318
+ st.error("Failed to load model. See details below.")
319
+ st.exception(e)
320
+ st.stop()
321
+
322
+
323
+ uploaded = st.file_uploader(
324
+ "Upload CT image (PNG/JPG)", type=["png", "jpg", "jpeg"], accept_multiple_files=False
325
+ )
326
+
327
+ if uploaded is not None:
328
+ image_bytes = uploaded.read()
329
+ img = Image.open(io.BytesIO(image_bytes))
330
+ st.image(img, caption="Uploaded Image", use_column_width=True)
331
+
332
+ if st.button("Predict", type="primary"):
333
+ with st.spinner("Running inference..."):
334
+ tensor = preprocess_image(img)
335
+ idx, conf, probs = predict(model, tensor)
336
+
337
+ pred_label = class_names[idx] if idx < len(class_names) else f"Class {idx}"
338
+
339
+ st.markdown("---")
340
+ st.subheader("Prediction Result")
341
+ col_res1, col_res2 = st.columns(2)
342
+ with col_res1:
343
+ st.success(f"**{pred_label}**")
344
+ with col_res2:
345
+ st.metric("Confidence", f"{conf:.2%}")
346
 
347
+ else:
348
+ st.info("Please upload an image to begin.")