# Provenance (copied from the Hugging Face file viewer; not code):
# author: hamoudaGue — last commit "Update app/app.py" (69d4d1e, verified)
# app/streamlit_app.py
import os
import streamlit as st
import torch
import torch.nn as nn
from torchvision import models, transforms
from PIL import Image
# CPU-only inference: no device detection is performed (the app targets HF Spaces).
device = torch.device(type="cpu")
# Class name <-> index mapping. Index 1 ("Pneumonia") is the positive class whose
# probability the single-logit model scores; index 0 is "Normal".
LABEL2IDX = dict(Normal=0, Pneumonia=1)
IDX2LABEL = dict(zip(LABEL2IDX.values(), LABEL2IDX.keys()))
# Preprocessing applied to every image before inference:
#   1. convert to grayscale, replicated to 3 channels (the backbones expect 3-channel input),
#   2. resize to a fixed 224x224,
#   3. convert to a float tensor,
#   4. normalise per channel with the standard ImageNet mean/std values.
# NOTE(review): presumably mirrors the training-time pipeline — confirm before changing.
tf = transforms.Compose(
    [
        transforms.Grayscale(num_output_channels=3),
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
def _binary_head(in_features):
    """Return the single-logit head (Dropout(0.3) -> Linear(in_features, 1))."""
    return nn.Sequential(
        nn.Dropout(0.3),
        nn.Linear(in_features, 1),
    )


def build_model(name):
    """Build an untrained classifier architecture with a single-logit head.

    Args:
        name: ``"MobileNetV2"`` selects MobileNetV2; any other value falls back
            to ResNet18 (the only other choice offered in the UI).

    Returns:
        An ``nn.Module`` with freshly initialised weights, meant to receive a
        checkpoint via ``load_state_dict``. The head replaces ``classifier[1]``
        (MobileNetV2) or ``fc`` (ResNet18) with ``nn.Sequential(Dropout, Linear)``.
        This yields parameter names such as ``classifier.1.1.weight`` /
        ``fc.1.weight``. That is presumably the layout the checkpoints were saved
        with, so keep it unchanged or ``load_state_dict`` will fail.
    """
    # ``weights=None`` is the supported spelling of the deprecated
    # ``pretrained=False``: random init and no download. The loaded checkpoint
    # replaces every parameter anyway.
    if name == "MobileNetV2":
        m = models.mobilenet_v2(weights=None)
        m.classifier[1] = _binary_head(m.classifier[1].in_features)
    else:  # ResNet18
        m = models.resnet18(weights=None)
        m.fc = _binary_head(m.fc.in_features)
    return m
@st.cache_resource
def _load_model_cached(name: str):
    """Build *name*, load its checkpoint and put it in eval mode; raise on failure.

    ``st.cache_resource`` stores only successful return values. An exception
    propagates without populating the cache, so a failed load is retried on the
    next call instead of being remembered as ``None`` forever.
    """
    mdl = build_model(name)
    # Checkpoint files are looked up relative to the current working directory
    # (on HF Spaces they live in the repo alongside the app).
    if name == "MobileNetV2":
        model_filename = "best_MobileNetV2_model.pth"
    else:  # ResNet18
        model_filename = "best_ResNet_model.pth"
    model_path = os.path.join(".", model_filename)
    # weights_only=True restricts unpickling to tensors and plain containers,
    # so the checkpoint cannot run arbitrary code when loaded.
    state_dict = torch.load(model_path, map_location=device, weights_only=True)
    mdl.load_state_dict(state_dict)
    mdl.to(device).eval()
    return mdl


def load_model(name: str):
    """Return the ready-to-use model for *name*, or ``None`` if loading failed.

    On failure the error is shown in the UI with ``st.error``. Because failures
    are not cached (see ``_load_model_cached``), a later call retries, which
    helps, for example, once a missing checkpoint file appears.
    """
    try:
        return _load_model_cached(name)
    except Exception as e:  # UI boundary: report any load problem to the user
        st.error(f"❌ Error loading model: {e}")
        return None
def predict(img, model, threshold=0.7):
    """Classify one PIL image with *model*.

    Returns:
        ``(pred_idx, p)``. ``p`` is the sigmoid of the model output (used by the
        UI as P(Pneumonia)). ``pred_idx`` is 1 when ``p`` is strictly greater than
        *threshold*, otherwise 0. Returns ``(None, None)`` when *model* is
        ``None``, e.g. because loading failed.
    """
    if model is None:
        return None, None
    batch = tf(img).unsqueeze(0).to(device)  # add a leading batch dimension of 1
    with torch.no_grad():
        logit = model(batch)
    prob = torch.sigmoid(logit).item()
    return (1 if prob > threshold else 0), prob
# Bundled demo images: selectbox label -> image file path (relative to the working directory).
SAMPLE_IMAGES = dict(
    [
        ("Sample Normal X-ray", "sample_normal.jpg"),
        ("Sample Pneumonia X-ray", "sample_pneumonia.jpg"),
    ]
)
# --------------- UI ----------------
st.set_page_config(page_title="Chest X-ray Classification", layout="wide")
st.title("🫁 Chest X-ray Classification (Normal vs Pneumonia)")
# Sidebar for controls
st.sidebar.header("⚙️ Settings")
model_name = st.sidebar.selectbox("Choose model", ["ResNet18", "MobileNetV2"])
threshold = st.sidebar.slider("Decision Threshold", 0.4, 0.9, 0.7, 0.05)
# Image source selection
src = st.sidebar.radio("Image source", ["Use sample image", "Upload your own"])
# Main layout with two columns: image on the left, prediction on the right
col_img, col_res = st.columns([1.2, 1])


def _load_and_show(source, caption, error_msg):
    """Open *source* (a path or uploaded file) with PIL and show it in the left column.

    Returns the image. If opening, decoding or displaying fails, shows
    ``"<error_msg>: <exception>"`` and returns ``None``.
    """
    try:
        img = Image.open(source)
        # Image.open is lazy; decode now so corrupt or truncated files are
        # reported here instead of failing later during display or inference.
        img.load()
        with col_img:
            st.image(img, caption=caption, width=400)
    except Exception as e:
        st.error(f"{error_msg}: {e}")
        return None
    return img


def _predict_panel(img):
    """Render the Predict button in the right column and, when pressed, the result for *img*."""
    with col_res:
        if not st.button("🔍 Predict", use_container_width=True):
            return
        with st.spinner("🔄 Loading model and predicting..."):
            try:
                model = load_model(model_name)
                pred_idx, prob_pos = predict(img, model, threshold)
            except Exception as e:
                # Report inference failures as such, not as image-loading errors.
                st.error(f"❌ Error during prediction: {e}")
                return
        if pred_idx is None:  # model failed to load; load_model already showed the error
            return
        st.success("🎯 Prediction Result")
        st.markdown(f"**Predicted Label:** **{IDX2LABEL[pred_idx]}**")
        st.metric("P(Pneumonia)", f"{prob_pos:.3f}")


img = None
if src == "Use sample image":
    selected_sample = st.selectbox("Choose sample image", list(SAMPLE_IMAGES.keys()))
    img = _load_and_show(
        SAMPLE_IMAGES[selected_sample],
        f"🖼️ {selected_sample}",
        "❌ Error loading sample image",
    )
else:  # Upload your own
    uploaded_file = st.file_uploader("Upload chest X-ray image (JPEG/PNG)",
                                     type=["jpg", "jpeg", "png"])
    if uploaded_file is not None:
        img = _load_and_show(
            uploaded_file,
            f"🖼️ {uploaded_file.name}",
            "❌ Error processing image",
        )
if img is not None:
    _predict_panel(img)
# Disclaimer
st.sidebar.markdown("---")
st.sidebar.warning(
    "**Disclaimer:** This is a demonstration tool. Always consult healthcare professionals for medical diagnosis."
)