import io
from typing import List, Optional, Tuple
import numpy as np
import streamlit as st
import torch
torch.classes.__path__ = []
import torch.nn as nn
from PIL import Image
from pathlib import Path
# Fixed index -> label mapping (provided by the user). When non-empty, its
# length decides the number of output classes used by load_model.
CLASS_TO_LABEL = dict(
    enumerate(
        (
            "Adenocarcinoma",
            "Large Cell Carcinoma",
            "Normal",
            "Squamous Cell Carcinoma",
        )
    )
)
def _infer_num_classes_from_state(state_dict: dict) -> Optional[int]:
    """Guess the number of output classes from a model ``state_dict``.

    A few well-known classifier-head keys are tried first. If none of them is
    present, the keys are scanned from the end of the dict for a 2-D
    ``*.weight`` entry whose name contains ``classifier`` or ``head`` or ends
    with ``fc.weight``.

    Args:
        state_dict: Mapping from parameter name to an array-like object that
            exposes ``.shape``.

    Returns:
        The size of the first dimension of the selected weight, or ``None``
        when no suitable entry is found.
    """
    candidates = (
        "classifier.2.weight",
        "head.fc.weight",
        "fc.weight",
        "classifier.weight",
    )
    for key in candidates:
        if key in state_dict:
            return int(state_dict[key].shape[0])

    # Fallback: walk the keys from the tail so the *last* head layer wins, and
    # accept only 2-D weights. A 1-D weight (e.g. a normalization layer that
    # sits inside the classifier) would report the feature width rather than
    # the number of classes.
    for key in reversed(list(state_dict)):
        if not isinstance(key, str) or not key.endswith(".weight"):
            continue
        looks_like_head = (
            "classifier" in key or "head" in key or key.endswith("fc.weight")
        )
        if not looks_like_head:
            continue
        shape = getattr(state_dict[key], "shape", None)
        if shape is not None and len(shape) == 2:
            return int(shape[0])
    return None
def _infer_class_names(ckpt: dict, num_classes: int) -> List[str]:
    """Derive ordered class names from checkpoint metadata.

    Lookup order:
      1. ``classes`` / ``class_names`` / ``labels`` holding a list or tuple.
      2. ``idx_to_class`` dict, ordered by index.
      3. ``class_to_idx`` dict, ordered by its index values.
      4. Generic ``"Class i"`` placeholders.

    Args:
        ckpt: Checkpoint dictionary (may be empty).
        num_classes: Number of placeholder names to generate when the
            checkpoint carries no class information.

    Returns:
        Class names ordered by class index.
    """
    for key in ("classes", "class_names", "labels"):
        names = ckpt.get(key)
        if isinstance(names, (list, tuple)):
            return list(names)

    idx_to_class = ckpt.get("idx_to_class")
    if isinstance(idx_to_class, dict):
        try:
            return [idx_to_class[i] for i in range(len(idx_to_class))]
        except KeyError:
            pass
        # Keys are not exactly 0..n-1 ints (e.g. "0", "1" after a JSON round
        # trip, or sparse indices): order numerically so labels still line up
        # with the model's output indices.
        try:
            ordered = sorted(idx_to_class.items(), key=lambda item: int(item[0]))
        except (TypeError, ValueError):
            # Keys are not numeric at all; insertion order is the best we have.
            return list(idx_to_class.values())
        return [name for _, name in ordered]

    class_to_idx = ckpt.get("class_to_idx")
    if isinstance(class_to_idx, dict):
        by_index = sorted(class_to_idx.items(), key=lambda item: item[1])
        return [name for name, _ in by_index]

    return [f"Class {i}" for i in range(num_classes)]
@st.cache_resource(show_spinner=True)
def load_model(weights_path: str) -> Tuple[nn.Module, List[str]]:
    """Load a torchvision ConvNeXt-Large classifier from ``weights_path``.

    The checkpoint may be a bare state dict or a dict wrapping one under
    ``state_dict`` / ``model_state_dict``. The final classifier layer is
    replaced to match the number of classes, the weights are loaded
    non-strictly, and the model is moved to the selected device in eval mode.

    Args:
        weights_path: Path to the checkpoint file.

    Returns:
        ``(model, class_names)`` where ``class_names`` comes from
        ``CLASS_TO_LABEL`` when its size matches, otherwise from checkpoint
        metadata.

    Raises:
        RuntimeError: If the model cannot be built or the checkpoint does not
            match any of its parameters.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # NOTE(security): torch.load can unpickle arbitrary objects. Only point
    # this at checkpoint files from a trusted source.
    ckpt = torch.load(weights_path, map_location=device)
    if isinstance(ckpt, dict):
        state_dict = ckpt.get("state_dict") or ckpt.get("model_state_dict") or ckpt
    else:
        state_dict = ckpt

    # Prefer the fixed mapping if provided, otherwise infer from the weights.
    if CLASS_TO_LABEL:
        num_classes = len(CLASS_TO_LABEL)
    else:
        num_classes = _infer_num_classes_from_state(state_dict) or 2

    model = None
    errors = []
    try:
        from torchvision.models import convnext_large

        tv_model = convnext_large(weights=None)
        in_features = tv_model.classifier[2].in_features
        tv_model.classifier[2] = nn.Linear(in_features, num_classes)

        # Checkpoints saved from a DataParallel-wrapped model prefix every key
        # with "module."; strip it so the keys line up with this model.
        prefix = "module."
        cleaned_state = {
            (k[len(prefix):] if isinstance(k, str) and k.startswith(prefix) else k): v
            for k, v in state_dict.items()
        }

        # strict=False tolerates partial mismatches, but it also silently
        # accepts a checkpoint that matches nothing, which would leave the
        # model with random weights. Treat that case as a load failure.
        result = tv_model.load_state_dict(cleaned_state, strict=False)
        expected_keys = tv_model.state_dict()
        if expected_keys and len(result.missing_keys) >= len(expected_keys):
            raise RuntimeError(
                "checkpoint keys do not match any ConvNeXt-Large parameter"
            )
        model = tv_model
    except Exception as e:
        errors.append(f"torchvision load failed: {e}")

    if model is None:
        raise RuntimeError(
            "Failed to load model with the provided weights. "
            + " ; ".join(errors)
        )

    model.to(device)
    model.eval()

    if CLASS_TO_LABEL and len(CLASS_TO_LABEL) == num_classes:
        class_names = [CLASS_TO_LABEL[i] for i in range(num_classes)]
    else:
        class_names = _infer_class_names(ckpt if isinstance(ckpt, dict) else {}, num_classes)
    return model, class_names
def preprocess_image(img: Image.Image) -> torch.Tensor:
    """Turn a PIL image into a normalized channel-first float tensor.

    Steps: convert to RGB when needed, resize directly to 224x224 (a plain
    resize: the aspect ratio is not preserved and nothing is cropped), divide
    by 255, standardize each channel with the fixed mean/std constants below
    (they match the commonly used ImageNet values), and reorder HWC -> CHW.

    Args:
        img: Input image in any PIL mode.

    Returns:
        A float32 tensor with the channel axis first; no batch axis is added.
    """
    rgb = img if img.mode == "RGB" else img.convert("RGB")
    resized = rgb.resize((224, 224))

    pixels = np.array(resized).astype("float32") / 255.0
    channel_mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
    channel_std = np.array([0.229, 0.224, 0.225], dtype=np.float32)
    normalized = (pixels - channel_mean) / channel_std

    # HWC -> CHW, the layout the model is fed with in predict().
    return torch.from_numpy(np.transpose(normalized, (2, 0, 1)))
def predict(model: nn.Module, tensor: torch.Tensor) -> Tuple[int, float, np.ndarray]:
    """Run a single-sample forward pass and return the top prediction.

    Args:
        model: Module to evaluate; the input is moved to the device of its
            first parameter.
        tensor: One unbatched sample; a batch axis is prepended here.

    Returns:
        ``(index, confidence, probabilities)``: the argmax class index, its
        softmax probability, and the full probability vector as a NumPy array.
    """
    target_device = next(model.parameters()).device
    with torch.no_grad():
        output = model(tensor.unsqueeze(0).to(target_device))
        # Some models return (logits, aux, ...); keep only the first item.
        if isinstance(output, (list, tuple)):
            output = output[0]
        probabilities = torch.softmax(output, dim=1).cpu().numpy()[0]
    best = int(np.argmax(probabilities))
    return best, float(probabilities[best]), probabilities
# Page-level configuration: browser tab title, icon, and a centered layout.
st.set_page_config(page_title="CT Scan Classifier", page_icon="🩺", layout="centered")
# Custom CSS for UI Enhancement
# Injects a <style> block via st.markdown (unsafe_allow_html=True so the raw
# markup is passed through). It styles the headings, the `a.custom-btn`
# anchor used for the "Start Detecting" link, and the metric value shown in
# the prediction result. The nested-stVerticalBlock rule is an empty
# placeholder.
st.markdown("""
<style>
/* Main Background & Fonts */
h1, h2, h3 {
font-family: 'Helvetica Neue', sans-serif;
}
h1 {
font-weight: 700;
color: #0f52ba; /* Medical Blue */
}
/* Info Cards Styling */
div[data-testid="stVerticalBlock"] > div[data-testid="stVerticalBlock"] {
/* Generic adjustment for nested blocks if needed */
}
/* Custom Button for "Start Detecting" (Anchor Link) */
a.custom-btn {
display: inline-block;
padding: 0.6em 1.2em;
margin-top: 20px;
color: #ffffff !important;
background-color: #ff4b4b;
border-radius: 8px;
text-decoration: none;
font-weight: 600;
text-align: center;
transition: all 0.2s ease-in-out;
box-shadow: 0 4px 6px rgba(0,0,0,0.1);
}
a.custom-btn:hover {
background-color: #ff3333;
transform: translateY(-2px);
box-shadow: 0 6px 8px rgba(0,0,0,0.15);
}
/* Style for the metrics/prediction result */
div[data-testid="stMetricValue"] {
font-size: 1.5rem;
}
</style>
""", unsafe_allow_html=True)
# Locate the static asset folder. Candidates are tried in order: next to this
# file, under the current working directory, then one level above this file.
# If none exists, the first candidate is used as the nominal location.
APP_DIR = Path(__file__).parent.resolve()
_public_candidates = [
    base / "public" for base in (APP_DIR, Path.cwd(), APP_DIR.parent)
]
PUBLIC_DIR = next(filter(Path.exists, _public_candidates), _public_candidates[0])
# --- HERO SECTION ---
# Two-column header: intro text and the "Start Detecting" link on the left,
# the hero image on the right.
st.title("Detect Chest Cancer with CTSense")
st.caption("Fast, Accurate, and Effortless!")
col1, col2 = st.columns([2, 1])
with col1:
    st.markdown(
        """
<div style="font-size: 1.1em; color: #444; line-height: 1.6;">
Welcome to the future of chest cancer detection. With the power of <b>CTSense</b>,
you can analyze your <b>CT scans</b> with just one click and receive fast, reliable insights
powered by advanced AI technology.
<br><br>
Start your scan now and experience precision made simple.
</div>
""", unsafe_allow_html=True
    )
    # HTML anchor styled as a button (instead of st.button) so that clicking
    # it jumps to the #prediction-section anchor further down the page.
    st.markdown('<a href="#prediction-section" class="custom-btn">Start Detecting</a>', unsafe_allow_html=True)
with col2:
    # Show the local hero image only when it is actually present. No remote
    # fallback URL is defined in this file, so a missing asset is skipped
    # rather than handed to st.image as a non-existent path.
    hero_local = PUBLIC_DIR / "1.png"
    if hero_local.is_file():
        # use_column_width=True already sizes the image to the column, so no
        # explicit width is passed alongside it.
        st.image(str(hero_local), use_column_width=True)
# --- INFO SECTION ---
# Static educational content: an overview, one bordered card per category
# (two cards per row), and two short Q&A style paragraphs.
st.divider()
st.header("What You Need to Know About Chest Cancer")

st.subheader("What Is Chest Cancer?")
st.write(
    "Chest cancer refers to several types of cancers that form in the tissues of the lungs. "
    "These cancers grow uncontrollably and can interfere with your breathing, oxygen levels, and overall health. "
    "Some types grow slowly, while others spread quickly. Early detection is crucial."
)

st.subheader("Main Types of Chest Cancer")
st.caption("In our system, we detect these categories:")

# Each inner tuple is one row of (title, description) cards.
_CARD_ROWS = (
    # Row 1: Adenocarcinoma | Large Cell Carcinoma
    (
        (
            "Adenocarcinoma",
            "A common type of lung cancer that starts in the glandular cells. "
            "It often grows in the outer parts of the lungs and is more likely to appear in non-smokers than other types.",
        ),
        (
            "Large Cell Carcinoma",
            "A more aggressive and large cancer that can appear anywhere in the lungs. "
            "It grows and spreads faster and is usually harder to treat if found late.",
        ),
    ),
    # Row 2: Squamous Cell Carcinoma | Normal
    (
        (
            "Squamous Cell Carcinoma",
            "This type begins in the thin, flat cells lining the airways. "
            "It often develops in the center of the lungs and is strongly linked to smoking.",
        ),
        (
            "Normal",
            "No signs of detectable cancer were found based on the uploaded scan. "
            "The AI did not identify any suspicious growths (cancer).",
        ),
    ),
)
for _row in _CARD_ROWS:
    for _column, (_title, _body) in zip(st.columns(2), _row):
        with _column:
            with st.container(border=True):
                st.subheader(_title)
                st.write(_body)

# Remaining sections: a sub-heading followed by one or more paragraphs.
_QA_SECTIONS = (
    (
        "What Happens if It’s Left Untreated?",
        (
            "Without treatment, chest cancer can spread to other organs, reduce lung function, "
            "cause severe breathing issues, and become life-threatening. Early diagnosis significantly improves "
            "treatment options and survival rates.",
        ),
    ),
    (
        "How Do You Detect It?",
        (
            "Chest cancer often begins with mild or unclear symptoms like coughing, chest pain, or fatigue. "
            "Because these signs can be easily missed, doctors rely on **CT scans** to spot abnormalities.",
            "With CTSense AI, you can upload your chest scan and receive a fast, AI-powered analysis that helps identify "
            "the presence of cancer types such as Adenocarcinoma, Large Cell Carcinoma, and Squamous Cell Carcinoma.",
        ),
    ),
)
for _heading, _paragraphs in _QA_SECTIONS:
    st.subheader(_heading)
    for _paragraph in _paragraphs:
        st.write(_paragraph)

st.divider()
# --- PREDICTION / CLASSIFIER SECTION ---
# Invisible anchor that the "Start Detecting" link (#prediction-section)
# scrolls to.
st.markdown('<div id="prediction-section"></div>', unsafe_allow_html=True)
st.title("CT Scan Classifier (ConvNeXt Large)")

# Sidebar for Model Info & Graphs
with st.sidebar:
    st.subheader("CTSense")
    st.write("Using weights: `CTScan_ConvNeXtLarge.pth`")
    st.link_button("GitHub Repository", "https://github.com/Jasonnn13/FinalProjectComputerVision")

    st.subheader("Training Curves")
    shown_any = False
    for rel, label in [
        ("acc.png", "Accuracy"),
        ("loss.png", "Loss"),
    ]:
        img_path = PUBLIC_DIR / rel
        # Skip curves that are not on disk. Without this check shown_any was
        # always set, so the hint below could never be displayed and missing
        # files were passed straight to st.image.
        if not img_path.is_file():
            continue
        st.caption(f"{label} (from {img_path.name})")
        st.image(str(img_path), use_column_width=True)
        shown_any = True
    if not shown_any:
        st.caption("Place images like public/acc.png and public/loss.png to display here.")
@st.cache_resource(show_spinner=False)
def _load_once():
    """Return ``load_model``'s result for the bundled checkpoint file.

    Wrapped in ``st.cache_resource`` with the spinner disabled; the weights
    path is relative, so it is resolved against the current working directory.
    """
    weights_file = "CTScan_ConvNeXtLarge.pth"
    return load_model(weights_file)
# Load the model up front; on failure show the error and halt this run so the
# upload widgets below are never rendered without a model.
try:
    model, class_names = _load_once()
except Exception as e:
    st.error("Failed to load model. See details below.")
    st.exception(e)
    st.stop()

uploaded = st.file_uploader(
    "Upload CT image (PNG/JPG)", type=["png", "jpg", "jpeg"], accept_multiple_files=False
)

if uploaded is not None:
    image_bytes = uploaded.read()
    try:
        img = Image.open(io.BytesIO(image_bytes))
        # Image.open is lazy; decode now so a corrupt or truncated upload is
        # reported here instead of failing later during display or inference.
        img.load()
    except OSError:
        st.error("Could not read the uploaded file as an image. Please upload a valid PNG or JPG.")
        st.stop()

    st.image(img, caption="Uploaded Image", use_column_width=True)

    if st.button("Predict", type="primary"):
        with st.spinner("Running inference..."):
            tensor = preprocess_image(img)
            idx, conf, probs = predict(model, tensor)

        # Guard against a model with more outputs than known class names.
        pred_label = class_names[idx] if idx < len(class_names) else f"Class {idx}"

        st.markdown("---")
        st.subheader("Prediction Result")
        col_res1, col_res2 = st.columns(2)
        with col_res1:
            st.success(f"**{pred_label}**")
        with col_res2:
            st.metric("Confidence", f"{conf:.2%}")
else:
    st.info("Please upload an image to begin.")
|