File size: 10,770 Bytes
6db0507
 
 
08e8363
 
6db0507
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cbc33fe
6db0507
cbc33fe
6db0507
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cbc33fe
 
 
6db0507
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
08e8363
6db0507
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
import io
from typing import List, Optional, Tuple

import numpy as np
import streamlit as st
import torch
torch.classes.__path__ = []
import torch.nn as nn
from PIL import Image

from pathlib import Path


# Fixed class mapping provided by user: model output index -> display label.
# The labels are listed alphabetically, which is the order torchvision's
# ImageFolder assigns indices in — presumably how the checkpoint was trained;
# confirm before reordering.
CLASS_TO_LABEL = dict(
	enumerate(
		(
			"Adenocarcinoma",
			"Large Cell Carcinoma",
			"Normal",
			"Squamous Cell Carcinoma",
		)
	)
)


def _infer_num_classes_from_state(state_dict: dict) -> Optional[int]:
	candidates = [
		"classifier.2.weight",
		"head.fc.weight",
		"fc.weight",
		"classifier.weight",
	]
	for k in candidates:
		if k in state_dict:
			return int(state_dict[k].shape[0])
	# Try to find any linear layer weight at the tail of classifier
	keys = [k for k in state_dict.keys() if k.endswith(".weight")]
	for k in keys:
		if ".classifier" in k or ".head" in k or k.endswith("fc.weight"):
			try:
				return int(state_dict[k].shape[0])
			except Exception:
				pass
	return None


def _infer_class_names(ckpt: dict, num_classes: int) -> List[str]:
	# Common patterns
	for key in ("classes", "class_names", "labels"):
		if isinstance(ckpt.get(key), (list, tuple)):
			return list(ckpt[key])
	if isinstance(ckpt.get("idx_to_class"), dict):
		# Ensure ordered by index
		mapping = ckpt["idx_to_class"]
		try:
			return [mapping[i] for i in range(len(mapping))]
		except Exception:
			# Fallback arbitrary order
			return list(mapping.values())
	if isinstance(ckpt.get("class_to_idx"), dict):
		inv = sorted(ckpt["class_to_idx"].items(), key=lambda x: x[1])
		return [name for name, _ in inv]
	return [f"Class {i}" for i in range(num_classes)]


def _extract_state_dict(ckpt):
	"""Pull the parameter mapping out of whatever ``torch.load`` returned.

	Handles three checkpoint shapes:

	* a bare state dict;
	* a training checkpoint that nests the weights under ``"state_dict"`` or
	  ``"model_state_dict"``;
	* a pickled ``nn.Module``.

	If every key carries the ``"module."`` prefix added when a
	``nn.DataParallel``/``DistributedDataParallel``-wrapped model is saved, the
	prefix is stripped so the keys match an unwrapped model.
	"""
	if isinstance(ckpt, nn.Module):
		state_dict = ckpt.state_dict()
	elif isinstance(ckpt, dict):
		state_dict = ckpt.get("state_dict") or ckpt.get("model_state_dict") or ckpt
	else:
		state_dict = ckpt
	prefix = "module."
	if (
		isinstance(state_dict, dict)
		and state_dict
		and all(isinstance(k, str) and k.startswith(prefix) for k in state_dict)
	):
		state_dict = {k[len(prefix):]: v for k, v in state_dict.items()}
	return state_dict


@st.cache_resource(show_spinner=True)
def load_model(weights_path: str) -> Tuple[nn.Module, List[str]]:
	"""Build a torchvision ConvNeXt-Large classifier and load *weights_path* into it.

	The classifier head (``classifier[2]``) is replaced with a linear layer
	sized to ``len(CLASS_TO_LABEL)``. When that mapping is empty, the size is
	inferred from the checkpoint (default 2).

	Args:
		weights_path: Path to a checkpoint readable by ``torch.load``.

	Returns:
		``(model, class_names)``. The model is in eval mode on CUDA when
		available, otherwise CPU. ``class_names`` holds one display label per
		output index.

	Raises:
		RuntimeError: if the model cannot be built, or if the checkpoint does
			not supply every parameter the model expects. Without this check,
			the missing parameters would silently keep random values.
	"""
	device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
	# NOTE(review): torch.load may unpickle arbitrary objects; only load trusted files.
	ckpt = torch.load(weights_path, map_location=device)
	state_dict = _extract_state_dict(ckpt)

	# Prefer fixed mapping if provided, otherwise infer
	if CLASS_TO_LABEL:
		num_classes = len(CLASS_TO_LABEL)
	else:
		num_classes = _infer_num_classes_from_state(state_dict) or 2

	model = None
	errors = []

	# torchvision ConvNeXt Large is the only architecture attempted.
	try:
		from torchvision.models import convnext_large
		tv_model = convnext_large(weights=None)
		in_features = tv_model.classifier[2].in_features
		tv_model.classifier[2] = nn.Linear(in_features, num_classes)
		# strict=False tolerates extra checkpoint entries. Missing entries,
		# however, mean those parameters were never loaded, so reject them.
		incompatible = tv_model.load_state_dict(state_dict, strict=False)
		missing = list(incompatible.missing_keys)
		if missing:
			preview = ", ".join(missing[:5]) + (" ..." if len(missing) > 5 else "")
			raise RuntimeError(
				f"checkpoint does not provide {len(missing)} model parameter(s) "
				f"({preview}); they would keep random initial values"
			)
		model = tv_model
	except Exception as e:
		errors.append(f"torchvision load failed: {e}")

	if model is None:
		raise RuntimeError(
			"Failed to load model with the provided weights. "
			+ " ; ".join(errors)
		)

	model.to(device)
	model.eval()

	if CLASS_TO_LABEL and len(CLASS_TO_LABEL) == num_classes:
		class_names = [CLASS_TO_LABEL[i] for i in range(num_classes)]
	else:
		class_names = _infer_class_names(ckpt if isinstance(ckpt, dict) else {}, num_classes)
	return model, class_names


def preprocess_image(img: Image.Image, size: int = 224) -> torch.Tensor:
	"""Convert a PIL image into a normalized CHW float tensor for the model.

	Steps: convert to RGB, resize to ``size x size``, scale pixel values to
	[0, 1], then normalize each channel with the ImageNet mean/std.

	Args:
		img: Input image in any PIL mode.
		size: Output side length in pixels (default 224).

	Returns:
		A ``float32`` tensor of shape ``(3, size, size)`` with no batch dimension.
	"""
	# Ensure RGB
	# NOTE(review): high-bit-depth modes (e.g. 16-bit grayscale PNGs) may not
	# map cleanly onto 8-bit RGB here — verify if such scans are expected.
	if img.mode != "RGB":
		img = img.convert("RGB")
	# Plain resize to a square: this does NOT preserve aspect ratio or center
	# crop, so non-square inputs are stretched. Presumably this matches the
	# training pipeline — confirm before changing it.
	img = img.resize((size, size))

	arr = np.array(img).astype("float32") / 255.0
	mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
	std = np.array([0.229, 0.224, 0.225], dtype=np.float32)
	arr = (arr - mean) / std
	# HWC (PIL/NumPy layout) -> CHW (PyTorch layout).
	arr = np.transpose(arr, (2, 0, 1))
	tensor = torch.from_numpy(arr)
	return tensor


def predict(model: nn.Module, tensor: torch.Tensor) -> Tuple[int, float, np.ndarray]:
	"""Classify a single preprocessed image tensor.

	The tensor gets a leading batch axis, is moved to the device holding the
	model's first parameter, and is run under ``torch.no_grad()``. If the model
	returns a list/tuple, its first element is treated as the logits.

	Returns:
		``(index, confidence, probabilities)``. ``probabilities`` is the softmax
		over the classes as a NumPy array, ``index`` is its argmax, and
		``confidence`` is the probability at that index.
	"""
	target_device = next(model.parameters()).device
	with torch.no_grad():
		batch = tensor.unsqueeze(0).to(target_device)
		output = model(batch)
		primary = output[0] if isinstance(output, (list, tuple)) else output
		distribution = torch.softmax(primary, dim=1).cpu().numpy()[0]
		best = int(np.argmax(distribution))
		score = float(distribution[best])
	return best, score, distribution


# Page metadata; set_page_config must be the first Streamlit command that renders.
_PAGE_CONFIG = dict(page_title="CT Scan Classifier", page_icon="🩺", layout="centered")
st.set_page_config(**_PAGE_CONFIG)

# Custom CSS for UI Enhancement: heading fonts/colors, the "Start Detecting"
# anchor-button, and the metric value size. Injected as raw HTML.
_CUSTOM_CSS = """
<style>
    /* Main Background & Fonts */
    h1, h2, h3 {
        font-family: 'Helvetica Neue', sans-serif;
    }
    h1 {
        font-weight: 700;
        color: #0f52ba; /* Medical Blue */
    }
    
    /* Info Cards Styling */
    div[data-testid="stVerticalBlock"] > div[data-testid="stVerticalBlock"] {
        /* Generic adjustment for nested blocks if needed */
    }
    
    /* Custom Button for "Start Detecting" (Anchor Link) */
    a.custom-btn {
        display: inline-block;
        padding: 0.6em 1.2em;
        margin-top: 20px;
        color: #ffffff !important;
        background-color: #ff4b4b;
        border-radius: 8px;
        text-decoration: none;
        font-weight: 600;
        text-align: center;
        transition: all 0.2s ease-in-out;
        box-shadow: 0 4px 6px rgba(0,0,0,0.1);
    }
    a.custom-btn:hover {
        background-color: #ff3333;
        transform: translateY(-2px);
        box-shadow: 0 6px 8px rgba(0,0,0,0.15);
    }

    /* Style for the metrics/prediction result */
    div[data-testid="stMetricValue"] {
        font-size: 1.5rem;
    }
</style>
"""
st.markdown(body=_CUSTOM_CSS, unsafe_allow_html=True)

# Resolve static asset directory robustly (works locally and on Streamlit Cloud).
# Search order: beside this script, the current working directory, then one
# level above the script. If none exists, the script-local path is used anyway.
APP_DIR = Path(__file__).parent.resolve()
_public_candidates = [base / "public" for base in (APP_DIR, Path.cwd(), APP_DIR.parent)]
PUBLIC_DIR = _public_candidates[0]
for _candidate in _public_candidates:
	if _candidate.exists():
		PUBLIC_DIR = _candidate
		break

# --- HERO SECTION ---
st.title("Detect Chest Cancer with CTSense")
st.caption("Fast, Accurate, and Effortless!")
col1, col2 = st.columns([2, 1])
with col1:
	st.markdown(
		"""
		<div style="font-size: 1.1em; color: #444; line-height: 1.6;">
		Welcome to the future of chest cancer detection. With the power of <b>CTSense</b>, 
		you can analyze your <b>CT scans</b> with just one click and receive fast, reliable insights 
		powered by advanced AI technology.
		<br><br>
		Start your scan now and experience precision made simple.
		</div>
		""", unsafe_allow_html=True
	)
	# An HTML anchor styled as a button (class "custom-btn" in the page CSS)
	# that jumps to the #prediction-section anchor further down the page.
	st.markdown('<a href="#prediction-section" class="custom-btn">Start Detecting</a>', unsafe_allow_html=True)

with col2:
	# Decorative hero image. It is only rendered when the file exists, so a
	# missing asset does not make st.image raise and halt the rest of the
	# page. (width was dropped: use_column_width=True overrides it.)
	hero_local = PUBLIC_DIR / "1.png"
	if hero_local.is_file():
		st.image(str(hero_local), use_column_width=True)

# --- INFO SECTION ---
# Static educational copy. The four cards mirror the model's output classes
# and are laid out two per row inside bordered containers.
_INFO_CARD_ROWS = (
	(
		(
			"Adenocarcinoma",
			"A common type of lung cancer that starts in the glandular cells. "
			"It often grows in the outer parts of the lungs and is more likely to appear in non-smokers than other types.",
		),
		(
			"Large Cell Carcinoma",
			"A more aggressive and large cancer that can appear anywhere in the lungs. "
			"It grows and spreads faster and is usually harder to treat if found late.",
		),
	),
	(
		(
			"Squamous Cell Carcinoma",
			"This type begins in the thin, flat cells lining the airways. "
			"It often develops in the center of the lungs and is strongly linked to smoking.",
		),
		(
			"Normal",
			"No signs of detectable cancer were found based on the uploaded scan. "
			"The AI did not identify any suspicious growths (cancer).",
		),
	),
)

# Sections rendered after the cards: (subheader, paragraphs written in order).
_INFO_TAIL_SECTIONS = (
	(
		"What Happens if It’s Left Untreated?",
		(
			"Without treatment, chest cancer can spread to other organs, reduce lung function, "
			"cause severe breathing issues, and become life-threatening. Early diagnosis significantly improves "
			"treatment options and survival rates.",
		),
	),
	(
		"How Do You Detect It?",
		(
			"Chest cancer often begins with mild or unclear symptoms like coughing, chest pain, or fatigue. "
			"Because these signs can be easily missed, doctors rely on **CT scans** to spot abnormalities.",
			"With CTSense AI, you can upload your chest scan and receive a fast, AI-powered analysis that helps identify "
			"the presence of cancer types such as Adenocarcinoma, Large Cell Carcinoma, and Squamous Cell Carcinoma.",
		),
	),
)

st.divider()
st.header("What You Need to Know About Chest Cancer")

st.subheader("What Is Chest Cancer?")
st.write(
	"Chest cancer refers to several types of cancers that form in the tissues of the lungs. "
	"These cancers grow uncontrollably and can interfere with your breathing, oxygen levels, and overall health. "
	"Some types grow slowly, while others spread quickly. Early detection is crucial."
)

st.subheader("Main Types of Chest Cancer")
st.caption("In our system, we detect these categories:")

for _card_row in _INFO_CARD_ROWS:
	for _slot, (_card_title, _card_body) in zip(st.columns(2), _card_row):
		with _slot, st.container(border=True):
			st.subheader(_card_title)
			st.write(_card_body)

for _heading, _paragraphs in _INFO_TAIL_SECTIONS:
	st.subheader(_heading)
	for _paragraph in _paragraphs:
		st.write(_paragraph)

st.divider()

# --- PREDICTION / CLASSIFIER SECTION ---
# Invisible anchor targeted by the hero section's "Start Detecting" link.
st.markdown('<div id="prediction-section"></div>', unsafe_allow_html=True)

st.title("CT Scan Classifier (ConvNeXt Large)")

# Sidebar for Model Info & Graphs
with st.sidebar:
	st.subheader("CTSense")
	st.write("Using weights: `CTScan_ConvNeXtLarge.pth`")

	st.link_button("GitHub Repository", "https://github.com/Jasonnn13/FinalProjectComputerVision")

	st.subheader("Training Curves")
	# Render only the plots that exist. A missing file would make st.image
	# raise and stop the script before the model is loaded. Because shown_any
	# is now set only on success, the hint below is reachable.
	shown_any = False
	for rel, label in [
		("acc.png", "Accuracy"),
		("loss.png", "Loss"),
	]:
		img_path = PUBLIC_DIR / rel
		if not img_path.is_file():
			continue
		st.caption(f"{label} (from {img_path.name})")
		st.image(str(img_path), use_column_width=True)
		shown_any = True

	if not shown_any:
		st.caption("Place images like public/acc.png and public/loss.png to display here.")


# Checkpoint file name. It is looked up in the working directory first (the
# original behavior), then next to this script, then one level up — mirroring
# how the public/ asset directory is resolved.
WEIGHTS_FILENAME = "CTScan_ConvNeXtLarge.pth"


def _resolve_weights_path() -> str:
	"""Return the first existing location of WEIGHTS_FILENAME, else the bare name."""
	for base in (Path.cwd(), APP_DIR, APP_DIR.parent):
		candidate = base / WEIGHTS_FILENAME
		if candidate.is_file():
			return str(candidate)
	# Nothing found: keep the relative name so torch.load reports the failure.
	return WEIGHTS_FILENAME


@st.cache_resource(show_spinner=False)
def _load_once():
	"""Load ``(model, class_names)`` once and reuse it across reruns via st.cache_resource."""
	return load_model(_resolve_weights_path())


try:
	model, class_names = _load_once()
except Exception as e:
	st.error("Failed to load model. See details below.")
	st.exception(e)
	st.stop()


uploaded = st.file_uploader(
	"Upload CT image (PNG/JPG)", type=["png", "jpg", "jpeg"], accept_multiple_files=False
)

if uploaded is not None:
	# getvalue() returns the full upload regardless of the buffer's read position.
	image_bytes = uploaded.getvalue()
	try:
		img = Image.open(io.BytesIO(image_bytes))
		# Image.open only reads the header; force a full decode so corrupt or
		# truncated files are reported here instead of crashing inference.
		img.load()
	except (OSError, Image.DecompressionBombError) as err:
		st.error(f"Could not read the uploaded file as an image: {err}")
		st.stop()
	st.image(img, caption="Uploaded Image", use_column_width=True)

	if st.button("Predict", type="primary"):
		with st.spinner("Running inference..."):
			tensor = preprocess_image(img)
			idx, conf, _ = predict(model, tensor)

		# Guard against a head larger than the known label list.
		pred_label = class_names[idx] if idx < len(class_names) else f"Class {idx}"

		st.markdown("---")
		st.subheader("Prediction Result")
		col_res1, col_res2 = st.columns(2)
		with col_res1:
			st.success(f"**{pred_label}**")
		with col_res2:
			st.metric("Confidence", f"{conf:.2%}")

else:
	st.info("Please upload an image to begin.")