| |
| import os |
| import streamlit as st |
| import torch |
| import torch.nn as nn |
| from torchvision import models, transforms |
| from PIL import Image |
|
|
| |
# Every tensor and model in this app is placed on the CPU.
device = torch.device("cpu")

# Class name -> index. Index 1 ("Pneumonia") is the positive class that the
# single-logit models score.
LABEL2IDX = {
    "Normal": 0,
    "Pneumonia": 1,
}
# Reverse lookup (index -> class name), derived so the two tables cannot drift.
IDX2LABEL = {index: label for label, index in LABEL2IDX.items()}
|
|
| |
# Per-channel normalization statistics.
# NOTE(review): these match the commonly used ImageNet mean/std; presumably the
# same values were used when the checkpoints were trained - confirm.
_NORM_MEAN = [0.485, 0.456, 0.406]
_NORM_STD = [0.229, 0.224, 0.225]

# Preprocessing applied to every input image before it is fed to a model:
# replicate grayscale into 3 channels, resize to 224x224, convert to a tensor,
# then normalize each channel.
_preprocess_steps = [
    transforms.Grayscale(num_output_channels=3),
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(_NORM_MEAN, _NORM_STD),
]
tf = transforms.Compose(_preprocess_steps)
|
|
def build_model(name):
    """Build a randomly initialised single-logit classifier.

    Args:
        name: Architecture selector. ``"MobileNetV2"`` builds a MobileNetV2;
            any other value builds a ResNet18.

    Returns:
        The torchvision model with its final linear layer replaced by
        ``Sequential(Dropout(0.3), Linear(in_features, 1))``.
    """
    # `weights=None` replaces the deprecated `pretrained=False` argument, which
    # emits a deprecation warning on current torchvision releases. No
    # pretrained weights are downloaded either way.
    if name == "MobileNetV2":
        m = models.mobilenet_v2(weights=None)
        # Keep this exact nesting: the state-dict keys of the new head depend
        # on it, and load_model() loads checkpoints into this structure.
        m.classifier[1] = nn.Sequential(
            nn.Dropout(0.3),
            nn.Linear(m.classifier[1].in_features, 1),
        )
    else:
        m = models.resnet18(weights=None)
        m.fc = nn.Sequential(
            nn.Dropout(0.3),
            nn.Linear(m.fc.in_features, 1),
        )
    return m
|
|
# Checkpoint file per architecture; any other name uses the ResNet checkpoint,
# mirroring build_model()'s fallback.
_MODEL_FILENAMES = {"MobileNetV2": "best_MobileNetV2_model.pth"}
_DEFAULT_MODEL_FILENAME = "best_ResNet_model.pth"


@st.cache_resource
def _load_model_cached(name: str):
    """Build the architecture for *name*, load its checkpoint, and cache it.

    Raises whatever ``build_model``, ``torch.load`` or ``load_state_dict``
    raise. Failures are deliberately NOT turned into a return value here:
    Streamlit caches return values but not exceptions, so a failed load is
    retried on the next call instead of being cached.
    """
    mdl = build_model(name)
    model_filename = _MODEL_FILENAMES.get(name, _DEFAULT_MODEL_FILENAME)
    # Resolved relative to the current working directory.
    model_path = os.path.join(".", model_filename)
    state_dict = torch.load(model_path, map_location=device, weights_only=True)
    mdl.load_state_dict(state_dict)
    mdl.to(device).eval()
    return mdl


def load_model(name: str):
    """Load the trained model for *name*, reporting failures in the UI.

    Args:
        name: Architecture selector passed to ``build_model``.

    Returns:
        The model in eval mode on ``device``, or ``None`` if it could not be
        built or its checkpoint could not be loaded. On failure an error box
        is shown with ``st.error``.
    """
    try:
        return _load_model_cached(name)
    except Exception as e:
        # Broad on purpose: this is the UI boundary, and a missing file, an
        # unreadable checkpoint and a state-dict mismatch should all surface
        # as a message rather than crash the page.
        st.error(f"β Error loading model: {e}")
        return None


# Keep `load_model.clear()` working for callers that reset the resource cache.
load_model.clear = _load_model_cached.clear
|
|
def predict(img, model, threshold=0.7):
    """Score one image with a single-logit model and threshold the result.

    Args:
        img: Image accepted by the module-level ``tf`` transform
            (presumably a PIL image, since ``tf`` ends in ``ToTensor``).
        model: Callable model producing a single value for a batch of one,
            or ``None``.
        threshold: The prediction is class 1 only when the sigmoid
            probability is strictly greater than this value.

    Returns:
        ``(pred_idx, probability)`` where ``pred_idx`` is 0 or 1 and
        ``probability`` is the sigmoid of the model output as a float;
        ``(None, None)`` when *model* is ``None``.
    """
    if model is None:
        return None, None

    # Preprocess and add the batch dimension the model expects.
    batch = tf(img).unsqueeze(0)
    batch = batch.to(device)

    # Inference only: no autograd bookkeeping needed.
    with torch.no_grad():
        logit = model(batch)
        prob = torch.sigmoid(logit).item()

    label_idx = 1 if prob > threshold else 0
    return label_idx, prob
|
|
| |
# Display name -> image path for the bundled sample images. Paths are relative,
# so they resolve against the current working directory.
SAMPLE_IMAGES = {
    "Sample Normal X-ray": "sample_normal.jpg",
    "Sample Pneumonia X-ray": "sample_pneumonia.jpg",
}

# Page chrome.
st.set_page_config(page_title="Chest X-ray Classification", layout="wide")
st.title("π« Chest X-ray Classification (Normal vs Pneumonia)")

# Sidebar controls: model choice, decision threshold, and image source.
with st.sidebar:
    st.header("βοΈ Settings")
    model_name = st.selectbox("Choose model", ["ResNet18", "MobileNetV2"])
    threshold = st.slider("Decision Threshold", 0.4, 0.9, 0.7, 0.05)
    src = st.radio("Image source", ["Use sample image", "Upload your own"])

# Two-column main area: image on the left (slightly wider), result on the right.
col_img, col_res = st.columns([1.2, 1])
|
|
def _show_image_and_predict(img, caption, model_name, threshold, col_img, col_res):
    """Show *img* in ``col_img`` and, on button press, the prediction in ``col_res``.

    Shared by the sample-image and upload flows so both render identically.
    Nothing is rendered in ``col_res`` beyond the button unless it was pressed
    in this run and both the model load and the prediction succeeded.
    """
    with col_img:
        st.image(img, caption=caption, width=400)

    with col_res:
        if not st.button("π Predict", use_container_width=True):
            return

        with st.spinner("π Loading model and predicting..."):
            model = load_model(model_name)
            if model is None:
                # load_model() has already reported the failure.
                return

            pred_idx, prob_pos = predict(img, model, threshold)
            if pred_idx is None:
                return

            pred_label = IDX2LABEL[pred_idx]
            st.success("π― Prediction Result")
            st.markdown(f"**Predicted Label:** **{pred_label}**")
            st.metric("P(Pneumonia)", f"{prob_pos:.3f}")


if src == "Use sample image":
    selected_sample = st.selectbox("Choose sample image", list(SAMPLE_IMAGES))

    try:
        img = Image.open(SAMPLE_IMAGES[selected_sample])
        _show_image_and_predict(
            img, f"πΌοΈ {selected_sample}", model_name, threshold, col_img, col_res
        )
    except Exception as e:
        # UI boundary: show any failure (open, display, inference) in-page.
        st.error(f"β Error loading sample image: {e}")

else:
    uploaded_file = st.file_uploader(
        "Upload chest X-ray image (JPEG/PNG)",
        type=["jpg", "jpeg", "png"],
    )

    if uploaded_file is not None:
        try:
            img = Image.open(uploaded_file)
            _show_image_and_predict(
                img, f"πΌοΈ {uploaded_file.name}", model_name, threshold, col_img, col_res
            )
        except Exception as e:
            # UI boundary: show any failure (decode, display, inference) in-page.
            st.error(f"β Error processing image: {e}")
|
|
| |
# Sidebar footer: a horizontal rule followed by the medical disclaimer.
with st.sidebar:
    st.markdown("---")
    st.warning(
        "**Disclaimer:** This is a demonstration tool. Always consult healthcare professionals for medical diagnosis."
    )