import io
from pathlib import Path
from typing import List, Optional, Tuple

import numpy as np
import streamlit as st
import torch
import torch.nn as nn
from PIL import Image

# NOTE(review): presumably a workaround that keeps Streamlit's module watcher
# from failing while inspecting ``torch.classes`` -- confirm before removing.
torch.classes.__path__ = []

# Fixed class mapping provided by user
CLASS_TO_LABEL = {
    0: "Adenocarcinoma",
    1: "Large Cell Carcinoma",
    2: "Normal",
    3: "Squamous Cell Carcinoma",
}


def _infer_num_classes_from_state(state_dict: dict) -> Optional[int]:
    """Guess the number of output classes from a checkpoint's state dict.

    Well-known classifier-head keys are checked first.  If none is present,
    the weights whose key mentions ``.classifier``/``.head`` or ends in
    ``fc.weight`` are scanned and the *last 2-D* one is used: 1-D weights
    (e.g. a normalisation layer in front of the head) are skipped because
    their first dimension is a feature width, not a class count.

    Args:
        state_dict: Mapping of parameter names to tensors (anything exposing
            a ``shape`` attribute).

    Returns:
        The first dimension of the selected weight, or ``None`` when no
        suitable key is found.
    """
    known_head_keys = (
        "classifier.2.weight",
        "head.fc.weight",
        "fc.weight",
        "classifier.weight",
    )
    for key in known_head_keys:
        if key in state_dict:
            return int(state_dict[key].shape[0])

    # Heuristic fallback: keep overwriting so the final matching 2-D weight
    # (the tail of the classifier) wins.
    inferred = None
    for key, value in state_dict.items():
        if not key.endswith(".weight"):
            continue
        if not (".classifier" in key or ".head" in key or key.endswith("fc.weight")):
            continue
        shape = getattr(value, "shape", None)
        if shape is not None and len(shape) == 2:
            inferred = int(shape[0])
    return inferred


def _infer_class_names(ckpt: dict, num_classes: int) -> List[str]:
    """Extract human-readable class names from checkpoint metadata.

    Looks, in order, for a list/tuple under ``classes``, ``class_names`` or
    ``labels``; an ``idx_to_class`` dict; a ``class_to_idx`` dict.  When
    nothing is found, generic ``"Class <i>"`` names are generated.

    Args:
        ckpt: The loaded checkpoint dictionary (may be empty).
        num_classes: Number of names to generate for the generic fallback.

    Returns:
        A list of class names.  Its length comes from the metadata when
        metadata is used, so it is not guaranteed to equal ``num_classes``.
    """
    # Common patterns
    for key in ("classes", "class_names", "labels"):
        value = ckpt.get(key)
        if isinstance(value, (list, tuple)):
            return list(value)

    idx_to_class = ckpt.get("idx_to_class")
    if isinstance(idx_to_class, dict):
        # Preferred: keys are exactly 0..n-1.
        try:
            return [idx_to_class[i] for i in range(len(idx_to_class))]
        except KeyError:
            pass
        # Keys may be numeric strings ("0", "1", ...); order them numerically
        # rather than trusting insertion order.
        try:
            ordered = sorted(idx_to_class.items(), key=lambda item: int(item[0]))
        except (TypeError, ValueError):
            # Keys are not index-like at all: fall back to stored order.
            return list(idx_to_class.values())
        return [name for _, name in ordered]

    class_to_idx = ckpt.get("class_to_idx")
    if isinstance(class_to_idx, dict):
        by_index = sorted(class_to_idx.items(), key=lambda item: item[1])
        return [name for name, _ in by_index]

    return [f"Class {i}" for i in range(num_classes)]


def _strip_dataparallel_prefix(state_dict: dict) -> dict:
    """Return a copy of ``state_dict`` with a leading ``module.`` removed from keys."""
    prefix = "module."
    return {
        (key[len(prefix):] if isinstance(key, str) and key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }


@st.cache_resource(show_spinner=True)
def load_model(weights_path: str) -> Tuple[nn.Module, List[str]]:
    """Build a torchvision ConvNeXt-Large and load weights from ``weights_path``.

    The checkpoint may be a bare state dict, a dict wrapping one under
    ``state_dict`` / ``model_state_dict``, or a pickled ``nn.Module``.
    Keys saved with a ``module.`` prefix are normalised before loading.

    Args:
        weights_path: Path to the checkpoint file.

    Returns:
        ``(model, class_names)`` with the model moved to CUDA when available
        (CPU otherwise) and switched to eval mode.

    Raises:
        RuntimeError: If the model cannot be constructed or loaded, including
            the case where no checkpoint key matches the architecture.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # SECURITY: torch.load may unpickle arbitrary objects depending on the
    # PyTorch version's ``weights_only`` default -- only load trusted files.
    ckpt = torch.load(weights_path, map_location=device)

    if isinstance(ckpt, nn.Module):
        # The whole model object was saved rather than its state dict.
        state_dict = ckpt.state_dict()
    elif isinstance(ckpt, dict):
        state_dict = ckpt.get("state_dict") or ckpt.get("model_state_dict") or ckpt
    else:
        state_dict = ckpt

    if isinstance(state_dict, dict):
        state_dict = _strip_dataparallel_prefix(state_dict)

    # Prefer fixed mapping if provided, otherwise infer
    if CLASS_TO_LABEL:
        num_classes = len(CLASS_TO_LABEL)
    else:
        num_classes = _infer_num_classes_from_state(state_dict) or 2

    model = None
    errors = []

    # Try torchvision ConvNeXt Large first
    try:
        from torchvision.models import convnext_large

        tv_model = convnext_large(weights=None)
        in_features = tv_model.classifier[2].in_features
        tv_model.classifier[2] = nn.Linear(in_features, num_classes)
        load_result = tv_model.load_state_dict(state_dict, strict=False)
        # strict=False hides key mismatches; if *every* expected key is
        # missing the network would silently run on random weights.
        if len(load_result.missing_keys) >= len(tv_model.state_dict()):
            raise RuntimeError(
                "no checkpoint key matched the ConvNeXt-Large architecture"
            )
        model = tv_model
    except Exception as e:  # collected and re-raised as RuntimeError below
        errors.append(f"torchvision load failed: {e}")

    if model is None:
        raise RuntimeError(
            "Failed to load model with the provided weights. " + " ; ".join(errors)
        )

    model.to(device)
    model.eval()

    if CLASS_TO_LABEL and len(CLASS_TO_LABEL) == num_classes:
        class_names = [CLASS_TO_LABEL[i] for i in range(num_classes)]
    else:
        class_names = _infer_class_names(ckpt if isinstance(ckpt, dict) else {}, num_classes)
    return model, class_names


def preprocess_image(img: Image.Image) -> torch.Tensor:
    """Convert a PIL image into a normalised ``(3, 224, 224)`` float32 tensor.

    Args:
        img: Input image in any PIL mode; converted to RGB when necessary.

    Returns:
        A channel-first tensor scaled to ``[0, 1]`` and then normalised with
        the per-channel mean/std constants below.
    """
    # Ensure RGB
    if img.mode != "RGB":
        img = img.convert("RGB")

    # Plain resize straight to 224x224.  This does NOT preserve the aspect
    # ratio and performs no centre crop.
    # NOTE(review): confirm this matches the preprocessing used in training.
    img = img.resize((224, 224))

    arr = np.array(img).astype("float32") / 255.0
    mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
    std = np.array([0.229, 0.224, 0.225], dtype=np.float32)
    arr = (arr - mean) / std
    arr = np.transpose(arr, (2, 0, 1))  # HWC -> CHW
    tensor = torch.from_numpy(arr)
    return tensor


def predict(model: nn.Module, tensor: torch.Tensor) -> Tuple[int, float, np.ndarray]:
    """Run a single-image forward pass and return the top prediction.

    Args:
        model: Network to evaluate; its first parameter determines the device.
        tensor: One un-batched input tensor (a batch dimension is added here).

    Returns:
        ``(index, confidence, probabilities)`` where ``probabilities`` is the
        softmax over dimension 1 for the single sample, ``index`` its argmax
        and ``confidence`` the probability at that index.
    """
    device = next(model.parameters()).device
    with torch.no_grad():
        logits = model(tensor.unsqueeze(0).to(device))
        # Some models return a tuple/list; the first element is used.
        if isinstance(logits, (list, tuple)):
            logits = logits[0]
        probs = torch.softmax(logits, dim=1).cpu().numpy()[0]
    idx = int(np.argmax(probs))
    conf = float(probs[idx])
    return idx, conf, probs


st.set_page_config(page_title="CT Scan Classifier", page_icon="🩺", layout="centered")

# Custom CSS for UI Enhancement
# NOTE(review): the style payload is blank in this copy of the file; restore
# the original <style> block here if one exists.
st.markdown(""" """, unsafe_allow_html=True)

# Resolve static asset directory robustly (works locally and on Streamlit Cloud)
APP_DIR = Path(__file__).parent.resolve()
# Candidate locations for the static asset folder; the first existing one is
# used, and the app-local "public" folder is the default when none exists.
_public_candidates = [
    APP_DIR / "public",
    Path.cwd() / "public",
    APP_DIR.parent / "public",
]
PUBLIC_DIR = next((p for p in _public_candidates if p.exists()), _public_candidates[0])

# --- HERO SECTION ---
st.title("Detect Chest Cancer with CTSense")
st.caption("Fast, Accurate, and Effortless!")

col1, col2 = st.columns([2, 1])

with col1:
    st.markdown(
        "\n"
        "Welcome to the future of chest cancer detection. With the power of CTSense, "
        "you can analyze your CT scans with just one click and receive fast, reliable "
        "insights powered by advanced AI technology.\n"
        "\n"
        "Start your scan now and experience precision made simple.\n",
        unsafe_allow_html=True,
    )
    # Replaced st.button with an HTML anchor link styled as a button.
    # NOTE(review): the original anchor markup was missing from this copy of
    # the file; this is a minimal reconstruction that jumps to the
    # "classifier" anchor emitted further down.  Restore the original
    # classes/ids if they are available.
    st.markdown(
        '<a href="#classifier" target="_self" '
        'style="display:inline-block;padding:0.5rem 1.25rem;border-radius:0.5rem;'
        'background:#ff4b4b;color:#ffffff;text-decoration:none;font-weight:600;">'
        "Start Detecting</a>",
        unsafe_allow_html=True,
    )

with col2:
    # Show the local hero image only when it exists; there is no remote
    # fallback in this file, so a missing asset is skipped instead of crashing.
    hero_local = PUBLIC_DIR / "1.png"
    if hero_local.is_file():
        st.image(str(hero_local), use_column_width=True, width=500)

# --- INFO SECTION ---
st.divider()
st.header("What You Need to Know About Chest Cancer")

st.subheader("What Is Chest Cancer?")
st.write(
    "Chest cancer refers to several types of cancers that form in the tissues of the lungs. "
    "These cancers grow uncontrollably and can interfere with your breathing, oxygen levels, and overall health. "
    "Some types grow slowly, while others spread quickly. Early detection is crucial."
)

st.subheader("Main Types of Chest Cancer")
st.caption("In our system, we detect these categories:")

# Row 1: Adenocarcinoma | Large Cell Carcinoma
row1_left, row1_right = st.columns(2)
with row1_left:
    with st.container(border=True):
        st.subheader("Adenocarcinoma")
        st.write(
            "A common type of lung cancer that starts in the glandular cells. "
            "It often grows in the outer parts of the lungs and is more likely to appear in non-smokers than other types."
        )
with row1_right:
    with st.container(border=True):
        st.subheader("Large Cell Carcinoma")
        st.write(
            "A more aggressive and large cancer that can appear anywhere in the lungs. "
            "It grows and spreads faster and is usually harder to treat if found late."
        )

# Row 2: Squamous Cell Carcinoma | Normal
row2_left, row2_right = st.columns(2)
with row2_left:
    with st.container(border=True):
        st.subheader("Squamous Cell Carcinoma")
        st.write(
            "This type begins in the thin, flat cells lining the airways. "
            "It often develops in the center of the lungs and is strongly linked to smoking."
        )
with row2_right:
    with st.container(border=True):
        st.subheader("Normal")
        st.write(
            "No signs of detectable cancer were found based on the uploaded scan. "
            "The AI did not identify any suspicious growths (cancer)."
        )

st.subheader("What Happens if It’s Left Untreated?")
st.write(
    "Without treatment, chest cancer can spread to other organs, reduce lung function, "
    "cause severe breathing issues, and become life-threatening. Early diagnosis significantly improves "
    "treatment options and survival rates."
)

st.subheader("How Do You Detect It?")
st.write(
    "Chest cancer often begins with mild or unclear symptoms like coughing, chest pain, or fatigue. "
    "Because these signs can be easily missed, doctors rely on **CT scans** to spot abnormalities."
)
st.write(
    "With CTSense AI, you can upload your chest scan and receive a fast, AI-powered analysis that helps identify "
    "the presence of cancer types such as Adenocarcinoma, Large Cell Carcinoma, and Squamous Cell Carcinoma."
)

st.divider()

# --- PREDICTION / CLASSIFIER SECTION ---
# Add an invisible anchor for the button to scroll to.
# NOTE(review): reconstructed -- in this copy of the file the string was split
# across two lines (a syntax error) and carried no markup.
st.markdown('<div id="classifier"></div>', unsafe_allow_html=True)

st.title("CT Scan Classifier (ConvNeXt Large)")

# Sidebar for Model Info & Graphs
with st.sidebar:
    st.subheader("CTSense")
    st.write("Using weights: `CTScan_ConvNeXtLarge.pth`")
    st.link_button("GitHub Repository", "https://github.com/Jasonnn13/FinalProjectComputerVision")

    st.subheader("Training Curves")
    shown_any = False
    for rel, label in [
        ("acc.png", "Accuracy"),
        ("loss.png", "Loss"),
    ]:
        img_path = PUBLIC_DIR / rel
        # Skip absent plots so the hint below is reachable and a missing file
        # does not raise inside st.image.
        if not img_path.is_file():
            continue
        st.caption(f"{label} (from {img_path.name})")
        st.image(str(img_path), use_column_width=True)
        shown_any = True
    if not shown_any:
        st.caption("Place images like public/acc.png and public/loss.png to display here.")


@st.cache_resource(show_spinner=False)
def _load_once():
    """Load the bundled checkpoint via ``load_model`` (cached, no spinner)."""
    return load_model("CTScan_ConvNeXtLarge.pth")


try:
    model, class_names = _load_once()
except Exception as e:  # top-level boundary: report to the user and halt
    st.error("Failed to load model. See details below.")
    st.exception(e)
    st.stop()

uploaded = st.file_uploader(
    "Upload CT image (PNG/JPG)",
    type=["png", "jpg", "jpeg"],
    accept_multiple_files=False,
)

if uploaded is not None:
    # getvalue() returns the full buffer regardless of the current stream
    # position, unlike read().
    image_bytes = uploaded.getvalue()
    try:
        img = Image.open(io.BytesIO(image_bytes))
        img.load()  # force decoding now so corrupt files fail here
    except OSError as exc:  # PIL's UnidentifiedImageError derives from OSError
        st.error("Could not read the uploaded file as an image.")
        st.exception(exc)
        st.stop()

    st.image(img, caption="Uploaded Image", use_column_width=True)

    if st.button("Predict", type="primary"):
        with st.spinner("Running inference..."):
            tensor = preprocess_image(img)
            idx, conf, probs = predict(model, tensor)

        # Guard against a class-name list shorter than the model's output.
        pred_label = class_names[idx] if idx < len(class_names) else f"Class {idx}"

        st.markdown("---")
        st.subheader("Prediction Result")

        col_res1, col_res2 = st.columns(2)
        with col_res1:
            st.success(f"**{pred_label}**")
        with col_res2:
            st.metric("Confidence", f"{conf:.2%}")
else:
    st.info("Please upload an image to begin.")