# AalapP's picture
# Update app.py
# b697bff verified
import streamlit as st
import os
# ==========================================
# 0. GLOBAL CONFIG (Super Lightweight)
# ==========================================
# set_page_config must be the first Streamlit call executed in the script.
st.set_page_config(page_title="AI Pneumonia Assistant", page_icon="🫁", layout="wide")
# Fixed mojibake: the boot message previously rendered as "βœ…" (UTF-8 check
# mark mis-decoded as Latin-1) instead of the intended check-mark emoji.
st.write("✅ App booted successfully.")
# Define paths only (no file checks yet) — existence is validated lazily
# when the models are actually loaded, to keep app boot instant.
BASE_DIR = "kaggle_checkpoints"
VIT_DIR_NAME = "vit_best_model"
RESNET_DIR_NAME = "resnet_best_model"
VIT_PATH = os.path.join(BASE_DIR, VIT_DIR_NAME)
RESNET_PATH = os.path.join(BASE_DIR, RESNET_DIR_NAME)
DETECTOR_PATH = "faster_rcnn_epoch_5.pth"
# ==========================================
# 1. MODEL LOADING (Cached & Lazy)
# ==========================================
@st.cache_resource
def load_ensemble():
    """
    Load the full inference stack onto the best available device.

    This function performs all the heavy lifting (ViT + ResNet-101
    classifiers plus the Faster R-CNN detector). It is ONLY called when
    the user clicks 'Start AI Engine', and @st.cache_resource ensures it
    runs at most once per server process.

    Returns:
        tuple: (vit_model, vit_processor, resnet_model, resnet_processor,
                detector, device)

    Raises:
        FileNotFoundError: if either classifier checkpoint directory is
            missing its config.json (e.g. the Docker build failed).
    """
    # --- LAZY IMPORTS (Crucial: Don't import PyTorch at top level) ---
    import torch
    from transformers import AutoModelForImageClassification, AutoImageProcessor
    from torchvision.models.detection import fasterrcnn_resnet50_fpn
    from torchvision.models.detection.faster_rcnn import FastRCNNPredictor

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Helper for safe loading: local checkpoints may lack
    # preprocessor_config.json, so fall back to the hub checkpoint
    # the model was fine-tuned from.
    def load_processor_safely(model_path, fallback_ckpt):
        try:
            return AutoImageProcessor.from_pretrained(model_path)
        except Exception:
            return AutoImageProcessor.from_pretrained(fallback_ckpt)

    # Fail fast with a clear message instead of an opaque transformers error.
    def check_checkpoint(path):
        if not os.path.exists(os.path.join(path, "config.json")):
            raise FileNotFoundError(f"Critical: {path} not found. Did Docker build fail?")

    # --- A. Load ViT ---
    check_checkpoint(VIT_PATH)
    vit_model = AutoModelForImageClassification.from_pretrained(VIT_PATH).to(device)
    vit_processor = load_processor_safely(VIT_PATH, "google/vit-base-patch16-384")
    vit_model.eval()

    # --- B. Load ResNet-101 ---
    # Fix: the ResNet checkpoint was previously loaded without the existence
    # check that ViT got, producing a confusing error when the dir is missing.
    check_checkpoint(RESNET_PATH)
    resnet_model = AutoModelForImageClassification.from_pretrained(RESNET_PATH).to(device)
    resnet_processor = load_processor_safely(RESNET_PATH, "microsoft/resnet-101")
    resnet_model.eval()

    # --- C. Load Object Detector ---
    # Rebuild the architecture with a 2-class head (background + opacity),
    # then load fine-tuned weights if present; otherwise warn and keep the
    # randomly-initialized weights (best-effort, matches original behavior).
    detector = fasterrcnn_resnet50_fpn(weights=None)
    in_features = detector.roi_heads.box_predictor.cls_score.in_features
    detector.roi_heads.box_predictor = FastRCNNPredictor(in_features, 2)
    if os.path.exists(DETECTOR_PATH):
        state_dict = torch.load(DETECTOR_PATH, map_location=device)
        detector.load_state_dict(state_dict)
    else:
        print(f"⚠️ Warning: Detector weights not found at {DETECTOR_PATH}. Using random weights.")
    detector.to(device)
    detector.eval()

    return vit_model, vit_processor, resnet_model, resnet_processor, detector, device
# ==========================================
# 2. PAGE: TECHNICAL REPORT (Default)
# ==========================================
def run_technical_report():
    """Render the static 'Technical Retrospective' page.

    Pure presentation: headline achievements, an ASCII architecture
    diagram, solved challenges, and validation metrics. No models are
    loaded and no computation happens here.
    """
    # Fix: removed `import pandas as pd` — it was never used in this page.

    st.title("πŸ”¬ Technical Retrospective")
    st.markdown("### Efficient Medical Imaging Pipeline – RSNA Pneumonia Detection")
    st.markdown("---")

    # --- 1. KEY ACHIEVEMENTS ---
    st.header("1. Engineering Impact")
    st.success("""
**πŸš€ Clinical Safety:** Reduced dangerous false negatives by a **factor of 3** compared to accuracy-focused baselines by implementing Weighted Cross-Entropy Loss (74.2% Recall).
**⚑ Cost Efficiency:** Achieved **<0.5s latency per scan** by deploying a "Gatekeeper" inference pipeline that filters normal cases before triggering expensive object detection.
**πŸ’Ύ Resource Optimization:** Enabled SOTA-level training on **consumer hardware (16GB VRAM)** using Gradient Accumulation and Mixed Precision (FP16).
""")

    # --- 2. ARCHITECTURE ---
    st.header("2. System Architecture")
    col1, col2 = st.columns([1, 1.2])
    with col1:
        st.write("""
**Hybrid Vision Pipeline:**
* **ViT-Base:** Captures global context/symmetry.
* **ResNet-101:** Captures local texture anomalies.
* **Fusion:** Soft-voting ensemble mitigates inductive bias of standalone CNNs.
""")
    with col2:
        st.code("""
[ INPUT X-RAY ]
β”‚
β”Œβ”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”
[ ViT-Base ] [ ResNet-101 ]
β”‚ β”‚
└───┐ β”Œβ”€β”€β”€β”˜
[ VOTING ENSEMBLE ]
β”‚
[ RISK SCORE (0-1) ]
β”‚
β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”
[ > Threshold? ] [ < Threshold ]
β”‚ β”‚
[ OBJECT DETECTOR ] [ NEGATIVE ]
""", language="text")

    # --- 3. CHALLENGES ---
    st.header("3. Challenges Solved")
    with st.expander("βš–οΈ Addressed Class Imbalance (3:1 Ratio)", expanded=True):
        st.write("Implemented **Weighted Cross-Entropy Loss** to penalize missing a 'Lung Opacity' case 3x more heavily than a 'Normal' case.")
    with st.expander("🎯 Aligned Metrics with Clinical Needs"):
        st.write("Shifted optimization target from Accuracy to **Macro-Average Recall**, ensuring the model acts as a reliable 'Safety Net'.")

    # --- 4. PERFORMANCE ---
    st.header("4. Validation Benchmarks")
    metrics_col1, metrics_col2, metrics_col3 = st.columns(3)
    metrics_col1.metric("Macro-Avg Recall", "74.2%")
    metrics_col2.metric("Lung Opacity Recall", "74.0%")
    metrics_col3.metric("Normal Recall", "90.3%")
# ==========================================
# 3. PAGE: LIVE DEMO (Button Protected)
# ==========================================
def run_demo():
    """Render the interactive diagnostic page.

    The heavy models are loaded lazily behind a button and cached in
    st.session_state, so subsequent Streamlit reruns skip the load.
    Pipeline per uploaded image: ViT + ResNet soft-voting ensemble →
    opacity risk score → optional Faster R-CNN localization when the
    risk exceeds the sidebar threshold (the "Gatekeeper" pattern).
    """
    st.title("🫁 Hybrid-AI Radiologist Assistant")

    # --- SESSION STATE CHECK ---
    # We check if models are loaded. If not, we show the "Start" button.
    if "models_loaded" not in st.session_state:
        st.session_state.models_loaded = False

    if not st.session_state.models_loaded:
        st.info("The AI engine is currently cold. Initialize it to start diagnostics.")
        # This button is the ONLY thing that triggers the heavy load
        if st.button("πŸš€ Initialize AI Engine (Loads ~500MB Models)"):
            with st.spinner("Waking up the GPU... This takes about 45 seconds..."):
                try:
                    # Load and store in session state
                    st.session_state.models = load_ensemble()
                    st.session_state.models_loaded = True
                    st.rerun()  # Refresh page to show the tool
                except Exception as e:
                    st.error(f"Failed to load models: {e}")
                    st.stop()
        else:
            # Stop here if button not clicked
            st.stop()

    # --- IF WE GET HERE, MODELS ARE LOADED ---
    # Unpack models from session state
    vit, vit_proc, resnet, resnet_proc, detector, device = st.session_state.models

    # Lazy imports for the demo logic
    import torch
    import torch.nn.functional as F
    from PIL import Image, ImageDraw
    from torchvision import transforms
    import pandas as pd

    st.markdown("### ViT + ResNet-101 Ensemble Pipeline")

    # --- CONTROLS ---
    st.sidebar.markdown("---")
    st.sidebar.header("βš™οΈ Settings")
    sensitivity = st.sidebar.slider("Opacity Threshold", 0.0, 1.0, 0.35)
    force_detect = st.sidebar.checkbox("🚨 Force Specialist Check", value=False)

    # NOTE(review): label mentions DICOM but only jpg/png/jpeg are accepted —
    # confirm whether DICOM support was intended.
    uploaded_file = st.file_uploader("Upload X-Ray (DICOM/JPG/PNG)", type=["jpg", "png", "jpeg"])
    if uploaded_file:
        image = Image.open(uploaded_file).convert("RGB")
        col1, col2 = st.columns(2)
        with col1:
            st.image(image, caption="Scan", use_container_width=True)

        if st.button("Run Diagnostics", type="primary"):
            with st.spinner("Analyzing..."):
                # 1. ViT (global context)
                inputs_vit = vit_proc(images=image, return_tensors="pt").to(device)
                with torch.no_grad():
                    logits_vit = vit(**inputs_vit).logits
                probs_vit = F.softmax(logits_vit, dim=-1).cpu().numpy()[0]

                # 2. ResNet (local texture)
                inputs_res = resnet_proc(images=image, return_tensors="pt").to(device)
                with torch.no_grad():
                    logits_res = resnet(**inputs_res).logits
                probs_res = F.softmax(logits_res, dim=-1).cpu().numpy()[0]

                # 3. Ensemble: unweighted soft-voting average
                probs_avg = (probs_vit + probs_res) / 2.0
                id2label = vit.config.id2label
                labels = [id2label[i] for i in range(len(probs_avg))]

                # Decision — fix: look up the opacity class defensively.
                # A bare next() raised an unhandled StopIteration whenever
                # the model's labels did not contain "Lung_Opacity".
                opacity_idx = next(
                    (i for i, label in id2label.items() if "Lung_Opacity" in label),
                    None,
                )
                if opacity_idx is None:
                    st.error("Model labels do not contain 'Lung_Opacity'; cannot compute risk.")
                    st.stop()
                opacity_risk = probs_avg[opacity_idx]

                with col2:
                    st.write("### Assessment")
                    df = pd.DataFrame(probs_avg, index=labels, columns=["Conf"])
                    st.bar_chart(df)
                    color = "red" if opacity_risk > sensitivity else "green"
                    st.markdown(f"**Risk:** <span style='color:{color}; font-size:24px'><b>{opacity_risk:.1%}</b></span>", unsafe_allow_html=True)

                # 4. Detector — only runs past the risk gate (or when forced),
                # keeping the expensive localization off the common path.
                if (opacity_risk > sensitivity) or force_detect:
                    det_img = image.resize((320, 320))
                    det_tensor = transforms.ToTensor()(det_img).to(device)
                    with torch.no_grad():
                        output = detector([det_tensor])[0]
                    keep = output['scores'].cpu().numpy() > 0.5
                    boxes = output['boxes'].cpu().numpy()[keep]
                    with col2:
                        if len(boxes) > 0:
                            st.error(f"**Pathology Localized: {len(boxes)} regions**")
                            draw_img = det_img.copy()
                            draw = ImageDraw.Draw(draw_img)
                            for box in boxes:
                                draw.rectangle(list(box), outline="red", width=3)
                            st.image(draw_img, use_container_width=True)
                        else:
                            st.info("High risk, but no specific box localized.")
                else:
                    with col2:
                        st.success("βœ… Negative")
# ==========================================
# 4. NAVIGATION
# ==========================================
st.sidebar.title("Navigation")
# Map each sidebar option to its page renderer and dispatch via lookup.
# Unknown values (impossible from st.radio, but kept for parity with the
# original if/elif chain) fall through to a no-op.
_PAGES = {
    "Technical Report": run_technical_report,
    "Live Diagnostic Tool": run_demo,
}
page = st.sidebar.radio("Go to", list(_PAGES))
_PAGES.get(page, lambda: None)()