# Uploaded by junaid17 -- commit aab5f08 ("Upload 5 files"), verified.
import io

import pandas as pd
import streamlit as st
import torch
import torch.nn as nn
from PIL import Image
from torchvision import transforms
class MedicalCNN(nn.Module):
    """Four-stage convolutional classifier used for chest X-ray prediction.

    Each feature stage is ``Conv2d(3x3, padding=1) -> BatchNorm2d -> ReLU ->
    Dropout2d -> MaxPool2d(2, 2)``. The padded convolution keeps the spatial
    size and the pooling halves it, so after four stages a square
    ``input_size x input_size`` RGB image becomes a ``256 x s x s`` feature map
    with ``s = input_size // 16``. That map is flattened and fed to a
    three-layer MLP head.

    Submodule names and ordering are identical to the original definition, so
    ``state_dict`` checkpoints saved from it (128x128 inputs) still load.

    Args:
        num_classes: Number of output logits.
        input_size: Side length in pixels of the square input the classifier
            head is sized for. The default of 128 reproduces the original
            ``256 * 8 * 8`` head input.

    Raises:
        ValueError: If ``input_size`` is below 16, which would shrink the
            feature map to zero pixels.
    """

    def __init__(self, num_classes: int, input_size: int = 128) -> None:
        super().__init__()
        # Four 2x2 max-pools each halve (floor) the side length.
        spatial = input_size // 16
        if spatial < 1:
            raise ValueError(
                f"input_size must be at least 16 pixels, got {input_size}"
            )

        self.features = nn.Sequential(
            # Stage 1: 3 -> 32 channels
            nn.Conv2d(3, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.Dropout2d(0.2),
            nn.MaxPool2d(2, 2),
            # Stage 2: 32 -> 64 channels
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Dropout2d(0.3),
            nn.MaxPool2d(2, 2),
            # Stage 3: 64 -> 128 channels
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.Dropout2d(0.4),
            nn.MaxPool2d(2, 2),
            # Stage 4: 128 -> 256 channels
            nn.Conv2d(128, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.Dropout2d(0.4),
            nn.MaxPool2d(2, 2),
        )
        self.classifier = nn.Sequential(
            # Width derived from input_size instead of a hard-coded 256 * 8 * 8.
            nn.Linear(256 * spatial * spatial, 1024),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(1024, 512),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(512, num_classes),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return raw (pre-softmax) class logits for a batch of images.

        ``x`` is expected to be ``(N, 3, input_size, input_size)``; inputs of
        another size produce a flattened width that does not match the first
        Linear layer.
        """
        x = self.features(x)
        x = torch.flatten(x, 1)  # keep the batch dim, flatten C*H*W
        return self.classifier(x)
# Load the trained model once per server process. Streamlit re-executes this
# whole script on every widget interaction, so without caching the checkpoint
# would be read from disk and deserialized again on every click.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


@st.cache_resource
def _load_model(path="Disease_Detection_model.pth"):
    """Build MedicalCNN, load the checkpoint at *path*, and return it in eval mode on ``device``."""
    net = MedicalCNN(num_classes=2)
    # weights_only=True restricts unpickling to tensors and plain containers,
    # so a tampered .pth file cannot execute arbitrary code when loaded.
    state_dict = torch.load(path, map_location=device, weights_only=True)
    net.load_state_dict(state_dict)
    # load_state_dict copies values into the module's existing (CPU) parameters;
    # the module itself must be moved, otherwise inputs sent to `device` later
    # fail with a CPU/CUDA device mismatch on GPU hosts.
    net.to(device)
    net.eval()
    return net


model = _load_model()
# Inference-time preprocessing: resize to the 128x128 resolution MedicalCNN's
# classifier head is sized for, then convert the PIL image to a float tensor.
# NOTE(review): intended to mirror the training pipeline's `test_transform`;
# no Normalize step is applied here -- confirm training did not normalize either.
test_transform = transforms.Compose(
    [
        transforms.Resize(size=(128, 128)),
        transforms.ToTensor(),
    ]
)

# Human-readable labels indexed by the model's output position.
# Assumed to match the class index order used during training -- confirm.
class_names = ["NORMAL", "PNEUMONIA"]
# --- Streamlit UI: upload an X-ray, run the model, show the prediction ---
st.title("Chest X-ray Disease Prediction")
st.write("Upload a chest X-ray image to predict NORMAL or PNEUMONIA.")

uploaded_file = st.file_uploader("Choose an X-ray image", type=["jpg", "jpeg", "png"])

if uploaded_file is not None:
    try:
        image = Image.open(uploaded_file).convert('RGB')
    except (OSError, Image.DecompressionBombError) as exc:
        # Corrupt, truncated or non-image bytes (PIL's UnidentifiedImageError is
        # an OSError subclass) or an oversized image: report it in the UI rather
        # than letting Streamlit render a raw traceback, then end this run.
        st.error(f"Could not read the uploaded file as an image: {exc}")
        st.stop()

    st.image(image, caption="Uploaded X-ray", use_column_width=True)

    # unsqueeze(0) adds the batch dimension the model expects.
    img_tensor = test_transform(image).unsqueeze(0).to(device)
    with torch.no_grad():
        outputs = model(img_tensor)
        probabilities = torch.softmax(outputs, dim=1)
    pred_idx = torch.argmax(probabilities, dim=1).item()
    confidence = probabilities[0, pred_idx].item()

    st.markdown(f"### Prediction: **{class_names[pred_idx]}**")
    st.markdown(f"Confidence: **{confidence:.2f}**")

    if st.button("Show Probability Chart"):
        # Index by class name so the bars are labelled instead of showing 0/1.
        probs = probabilities[0].cpu().tolist()
        st.bar_chart(pd.DataFrame({"probability": probs}, index=class_names))

st.info("Supported formats: JPG, JPEG, PNG")
st.caption("Model: MedicalCNN | Classes: NORMAL, PNEUMONIA")