# app.py — Streamlit entry point (Hugging Face Spaces deployment)
# NOTE(review): the original paste carried a Spaces "Runtime error" status banner
# here (scrape residue, not source code); replaced with this comment header.
import streamlit as st
import os

# ==========================================
# 0. GLOBAL CONFIG (Super Lightweight)
# ==========================================
# st.set_page_config must be the first Streamlit command executed in the script.
st.set_page_config(page_title="AI Pneumonia Assistant", page_icon="π«", layout="wide")
# Early boot marker so a blank page can be distinguished from a crashed app.
st.write("β App booted successfully.")

# define paths only (no file checks yet)
# NOTE(review): "π«" / "β" above look like mojibake of emoji (likely 🫁 / ✅) —
# confirm against the original UTF-8 source before changing them.
BASE_DIR = "kaggle_checkpoints"             # root dir baked in at Docker build time
VIT_DIR_NAME = "vit_best_model"             # HF-format ViT classifier checkpoint
RESNET_DIR_NAME = "resnet_best_model"       # HF-format ResNet-101 classifier checkpoint
VIT_PATH = os.path.join(BASE_DIR, VIT_DIR_NAME)
RESNET_PATH = os.path.join(BASE_DIR, RESNET_DIR_NAME)
DETECTOR_PATH = "faster_rcnn_epoch_5.pth"   # raw state_dict for the Faster R-CNN head
| # ========================================== | |
| # 1. MODEL LOADING (Cached & Lazy) | |
| # ========================================== | |
def load_ensemble():
    """Load the full inference stack: ViT + ResNet-101 classifiers and a Faster R-CNN detector.

    This function performs all the heavy lifting and is ONLY called when the
    user clicks 'Start AI Engine'. Heavy libraries (torch / transformers /
    torchvision) are imported lazily inside the function so the app can boot
    without paying their import cost.

    Returns:
        tuple: (vit_model, vit_processor, resnet_model, resnet_processor,
                detector, device) — all models in eval mode, moved to `device`.

    Raises:
        FileNotFoundError: if the ViT checkpoint directory is missing.
    """
    # --- LAZY IMPORTS (Crucial: Don't import PyTorch at top level) ---
    import torch
    from transformers import AutoModelForImageClassification, AutoImageProcessor
    from torchvision.models.detection import fasterrcnn_resnet50_fpn
    from torchvision.models.detection.faster_rcnn import FastRCNNPredictor

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Helper for safe loading
    def load_processor_safely(model_path, fallback_ckpt):
        # Local checkpoints may lack preprocessor_config.json; fall back to the
        # published base checkpoint's processor in that case.
        try:
            return AutoImageProcessor.from_pretrained(model_path)
        except Exception:
            return AutoImageProcessor.from_pretrained(fallback_ckpt)

    # --- A. Load ViT ---
    if not os.path.exists(os.path.join(VIT_PATH, "config.json")):
        raise FileNotFoundError(f"Critical: {VIT_PATH} not found. Did Docker build fail?")
    vit_model = AutoModelForImageClassification.from_pretrained(VIT_PATH).to(device)
    vit_processor = load_processor_safely(VIT_PATH, "google/vit-base-patch16-384")
    vit_model.eval()

    # --- B. Load ResNet-101 ---
    resnet_model = AutoModelForImageClassification.from_pretrained(RESNET_PATH).to(device)
    resnet_processor = load_processor_safely(RESNET_PATH, "microsoft/resnet-101")
    resnet_model.eval()

    # --- C. Load Object Detector ---
    # Build the architecture weight-free, then swap the box-predictor head for
    # our 2-class problem (background + lung opacity) before loading weights.
    detector = fasterrcnn_resnet50_fpn(weights=None)
    in_features = detector.roi_heads.box_predictor.cls_score.in_features
    detector.roi_heads.box_predictor = FastRCNNPredictor(in_features, 2)
    if os.path.exists(DETECTOR_PATH):
        # FIX: weights_only=True restricts deserialization to tensors — a plain
        # state_dict needs no pickled code, so this closes the arbitrary-code
        # execution hole of default torch.load (requires PyTorch >= 1.13).
        state_dict = torch.load(DETECTOR_PATH, map_location=device, weights_only=True)
        detector.load_state_dict(state_dict)
    else:
        print(f"β οΈ Warning: Detector weights not found at {DETECTOR_PATH}. Using random weights.")
    detector.to(device)
    detector.eval()

    return vit_model, vit_processor, resnet_model, resnet_processor, detector, device
| # ========================================== | |
| # 2. PAGE: TECHNICAL REPORT (Default) | |
| # ========================================== | |
def run_technical_report():
    """Render the static 'Technical Retrospective' page.

    Pure presentation — headers, metric cards and expanders summarizing the
    RSNA pneumonia pipeline. Performs no model inference and reads no state.
    (FIX: removed a lazy `import pandas as pd` that was never used in this
    function body.)
    """
    st.title("π¬ Technical Retrospective")
    st.markdown("### Efficient Medical Imaging Pipeline β RSNA Pneumonia Detection")
    st.markdown("---")

    # --- 1. KEY ACHIEVEMENTS ---
    st.header("1. Engineering Impact")
    st.success("""
**π Clinical Safety:** Reduced dangerous false negatives by a **factor of 3** compared to accuracy-focused baselines by implementing Weighted Cross-Entropy Loss (74.2% Recall).
**β‘ Cost Efficiency:** Achieved **<0.5s latency per scan** by deploying a "Gatekeeper" inference pipeline that filters normal cases before triggering expensive object detection.
**πΎ Resource Optimization:** Enabled SOTA-level training on **consumer hardware (16GB VRAM)** using Gradient Accumulation and Mixed Precision (FP16).
""")

    # --- 2. ARCHITECTURE ---
    st.header("2. System Architecture")
    col1, col2 = st.columns([1, 1.2])
    with col1:
        st.write("""
**Hybrid Vision Pipeline:**
* **ViT-Base:** Captures global context/symmetry.
* **ResNet-101:** Captures local texture anomalies.
* **Fusion:** Soft-voting ensemble mitigates inductive bias of standalone CNNs.
""")
    with col2:
        # ASCII flow diagram of the gatekeeper pipeline (box chars are mojibake
        # in the recovered source; alignment was lost in the paste).
        st.code("""
[ INPUT X-RAY ]
β
ββββββββ΄βββββββ
[ ViT-Base ] [ ResNet-101 ]
β β
βββββ βββββ
[ VOTING ENSEMBLE ]
β
[ RISK SCORE (0-1) ]
β
ββββββββββ΄βββββββββ
[ > Threshold? ] [ < Threshold ]
β β
[ OBJECT DETECTOR ] [ NEGATIVE ]
""", language="text")

    # --- 3. CHALLENGES ---
    st.header("3. Challenges Solved")
    with st.expander("βοΈ Addressed Class Imbalance (3:1 Ratio)", expanded=True):
        st.write("Implemented **Weighted Cross-Entropy Loss** to penalize missing a 'Lung Opacity' case 3x more heavily than a 'Normal' case.")
    with st.expander("π― Aligned Metrics with Clinical Needs"):
        st.write("Shifted optimization target from Accuracy to **Macro-Average Recall**, ensuring the model acts as a reliable 'Safety Net'.")

    # --- 4. PERFORMANCE ---
    st.header("4. Validation Benchmarks")
    metrics_col1, metrics_col2, metrics_col3 = st.columns(3)
    metrics_col1.metric("Macro-Avg Recall", "74.2%")
    metrics_col2.metric("Lung Opacity Recall", "74.0%")
    metrics_col3.metric("Normal Recall", "90.3%")
| # ========================================== | |
| # 3. PAGE: LIVE DEMO (Button Protected) | |
| # ========================================== | |
def run_demo():
    """Render the interactive diagnostic page.

    Flow: (1) lazily load the model ensemble behind an explicit button and
    cache it in st.session_state; (2) classify an uploaded X-ray with the
    ViT + ResNet soft-voting ensemble; (3) if opacity risk exceeds the
    sidebar threshold (or the override is on), run the Faster R-CNN
    detector and draw localized boxes.
    """
    st.title("π« Hybrid-AI Radiologist Assistant")

    # --- SESSION STATE CHECK ---
    # We check if models are loaded. If not, we show the "Start" button.
    if "models_loaded" not in st.session_state:
        st.session_state.models_loaded = False

    if not st.session_state.models_loaded:
        st.info("The AI engine is currently cold. Initialize it to start diagnostics.")
        # This button is the ONLY thing that triggers the heavy load
        if st.button("π Initialize AI Engine (Loads ~500MB Models)"):
            with st.spinner("Waking up the GPU... This takes about 45 seconds..."):
                try:
                    # Load and store in session state
                    st.session_state.models = load_ensemble()
                    st.session_state.models_loaded = True
                    st.rerun()  # Refresh page to show the tool
                except Exception as e:
                    st.error(f"Failed to load models: {e}")
                    st.stop()
        else:
            # Stop here if button not clicked
            st.stop()

    # --- IF WE GET HERE, MODELS ARE LOADED ---
    # Unpack models from session state
    vit, vit_proc, resnet, resnet_proc, detector, device = st.session_state.models

    # Lazy imports for the demo logic
    import torch
    import torch.nn.functional as F
    from PIL import Image, ImageDraw
    from torchvision import transforms
    import pandas as pd

    st.markdown("### ViT + ResNet-101 Ensemble Pipeline")

    # --- CONTROLS ---
    st.sidebar.markdown("---")
    st.sidebar.header("βοΈ Settings")
    sensitivity = st.sidebar.slider("Opacity Threshold", 0.0, 1.0, 0.35)
    force_detect = st.sidebar.checkbox("π¨ Force Specialist Check", value=False)

    # FIX: old label advertised DICOM, but only JPG/PNG/JPEG are accepted below.
    uploaded_file = st.file_uploader("Upload X-Ray (JPG/PNG)", type=["jpg", "png", "jpeg"])

    if uploaded_file:
        image = Image.open(uploaded_file).convert("RGB")
        col1, col2 = st.columns(2)
        with col1:
            st.image(image, caption="Scan", use_container_width=True)

        if st.button("Run Diagnostics", type="primary"):
            with st.spinner("Analyzing..."):
                # 1. ViT forward pass
                inputs_vit = vit_proc(images=image, return_tensors="pt").to(device)
                with torch.no_grad():
                    logits_vit = vit(**inputs_vit).logits
                    probs_vit = F.softmax(logits_vit, dim=-1).cpu().numpy()[0]

                # 2. ResNet forward pass
                inputs_res = resnet_proc(images=image, return_tensors="pt").to(device)
                with torch.no_grad():
                    logits_res = resnet(**inputs_res).logits
                    probs_res = F.softmax(logits_res, dim=-1).cpu().numpy()[0]

                # 3. Soft-voting ensemble: average the two probability vectors
                probs_avg = (probs_vit + probs_res) / 2.0
                id2label = vit.config.id2label
                labels = [id2label[i] for i in range(len(probs_avg))]

                # Decision — FIX: locate the opacity class with a defaulted
                # next() instead of letting a bare StopIteration escape when
                # the checkpoint's label map has no 'Lung_Opacity' entry.
                opacity_idx = next(
                    (i for i, label in id2label.items() if "Lung_Opacity" in label),
                    None,
                )
                if opacity_idx is None:
                    st.error("Model label map has no 'Lung_Opacity' class — cannot compute risk.")
                    st.stop()
                opacity_risk = probs_avg[opacity_idx]

                with col2:
                    st.write("### Assessment")
                    df = pd.DataFrame(probs_avg, index=labels, columns=["Conf"])
                    st.bar_chart(df)
                    color = "red" if opacity_risk > sensitivity else "green"
                    st.markdown(f"**Risk:** <span style='color:{color}; font-size:24px'><b>{opacity_risk:.1%}</b></span>", unsafe_allow_html=True)

                # 4. Gatekeeper: only run the expensive detector on high risk
                if (opacity_risk > sensitivity) or force_detect:
                    # NOTE(review): 320x320 presumably matches detector training
                    # resolution — confirm against the training pipeline.
                    det_img = image.resize((320, 320))
                    det_tensor = transforms.ToTensor()(det_img).to(device)
                    with torch.no_grad():
                        output = detector([det_tensor])[0]
                    keep = output['scores'].cpu().numpy() > 0.5  # confidence cutoff
                    boxes = output['boxes'].cpu().numpy()[keep]
                    with col2:
                        if len(boxes) > 0:
                            st.error(f"**Pathology Localized: {len(boxes)} regions**")
                            draw_img = det_img.copy()
                            draw = ImageDraw.Draw(draw_img)
                            for box in boxes:
                                draw.rectangle(list(box), outline="red", width=3)
                            st.image(draw_img, use_container_width=True)
                        else:
                            st.info("High risk, but no specific box localized.")
                else:
                    with col2:
                        st.success("β Negative")
# ==========================================
# 4. NAVIGATION
# ==========================================
# Sidebar router: page name -> render function. Dict insertion order defines
# the radio option order, so "Technical Report" stays the default selection.
_PAGES = {
    "Technical Report": run_technical_report,
    "Live Diagnostic Tool": run_demo,
}

st.sidebar.title("Navigation")
page = st.sidebar.radio("Go to", list(_PAGES))
_PAGES[page]()