| import streamlit as st
|
| import torch
|
| import torch.nn as nn
|
| from torchvision import transforms
|
| from PIL import Image
|
| import io
|
|
|
class MedicalCNN(nn.Module):
    """Four-stage convolutional classifier for square RGB images.

    Each stage is Conv2d(3x3, padding=1) -> BatchNorm2d -> ReLU -> Dropout2d
    -> MaxPool2d(2, 2), so every stage halves the spatial size. The stages
    widen the channels 3 -> 32 -> 64 -> 128 -> 256 and are followed by a
    three-layer fully connected head that emits ``num_classes`` logits.

    Args:
        num_classes: Number of output logits.
        input_size: Height/width (in pixels) of the square input images.
            Defaults to 128, which yields the 256 * 8 * 8 flattened feature
            size used by the original architecture.

    Raises:
        ValueError: If ``input_size`` is too small to survive the four
            2x2 pooling steps (i.e. smaller than 16).
    """

    def __init__(self, num_classes, input_size=128):
        super().__init__()

        # (in_channels, out_channels, dropout probability) for each stage.
        stage_specs = (
            (3, 32, 0.2),
            (32, 64, 0.3),
            (64, 128, 0.4),
            (128, 256, 0.4),
        )

        # Layers are appended in a fixed order so the Sequential indices
        # (and therefore the state_dict keys) stay the same as before.
        feature_layers = []
        for in_channels, out_channels, drop_p in stage_specs:
            feature_layers.extend([
                nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(),
                nn.Dropout2d(drop_p),
                nn.MaxPool2d(2, 2),
            ])
        self.features = nn.Sequential(*feature_layers)

        # Every stage floors the spatial size by half; four stages -> // 16.
        pooled_size = input_size // (2 ** len(stage_specs))
        if pooled_size < 1:
            raise ValueError(
                f"input_size must be at least {2 ** len(stage_specs)}, "
                f"got {input_size}"
            )
        flat_features = stage_specs[-1][1] * pooled_size * pooled_size

        self.classifier = nn.Sequential(
            nn.Linear(flat_features, 1024),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(1024, 512),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(512, num_classes),
        )

    def forward(self, x):
        """Return raw class logits for a batch of images.

        ``x`` is run through the convolutional stages, flattened from
        dimension 1 onward, and passed through the fully connected head.
        No softmax is applied here.
        """
        x = self.features(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x
|
|
|
|
|
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = MedicalCNN(num_classes=2)

# NOTE(review): torch.load unpickles the checkpoint file; only load
# checkpoints from a trusted source (consider weights_only=True where the
# installed torch version supports it).
model.load_state_dict(torch.load("Disease_Detection_model.pth", map_location=device))

# Input tensors are moved to `device` before inference, so the model's
# parameters must live on the same device or the forward pass fails on CUDA.
model.to(device)

# Inference mode: disables dropout and makes BatchNorm use running statistics.
model.eval()
|
|
|
|
|
# Preprocessing for uploaded images: resize to 128x128, then convert the
# PIL image to a tensor.
_resize_step = transforms.Resize((128, 128))
_to_tensor_step = transforms.ToTensor()
test_transform = transforms.Compose([_resize_step, _to_tensor_step])

# Display labels, looked up by the predicted class index.
# NOTE(review): assumes this order matches the label order used when the
# checkpoint was trained - confirm against the training code.
class_names = ['NORMAL', 'PNEUMONIA']
|
|
|
# ---- Page header ----
st.title("Chest X-ray Disease Prediction")
st.write("Upload a chest X-ray image to predict NORMAL or PNEUMONIA.")

uploaded_file = st.file_uploader("Choose an X-ray image", type=["jpg", "jpeg", "png"])

if uploaded_file is not None:
    # A file with an accepted extension can still be corrupt or not an image
    # at all; PIL reports that as an OSError (UnidentifiedImageError is a
    # subclass). Show a readable message instead of a raw traceback.
    try:
        image = Image.open(uploaded_file).convert('RGB')
    except OSError:
        image = None
        st.error("Could not read the uploaded file as an image. Please upload a valid JPG, JPEG, or PNG.")

    if image is not None:
        st.image(image, caption="Uploaded X-ray", use_column_width=True)

        # Preprocess, add a batch dimension, and move to the inference device.
        img_tensor = test_transform(image).unsqueeze(0).to(device)

        # No gradients are needed for prediction.
        with torch.no_grad():
            outputs = model(img_tensor)
            probabilities = torch.softmax(outputs, dim=1)
            pred_idx = torch.argmax(probabilities, dim=1).item()
            confidence = probabilities[0, pred_idx].item()

        st.markdown(f"### Prediction: **{class_names[pred_idx]}**")
        st.markdown(f"Confidence: **{confidence:.2f}**")

        # Optional per-class probability chart for the single uploaded image.
        if st.button("Show Probability Chart"):
            st.bar_chart(probabilities.cpu().numpy()[0])

# ---- Page footer (always shown) ----
st.info("Supported formats: JPG, JPEG, PNG")
st.caption("Model: MedicalCNN | Classes: NORMAL, PNEUMONIA")
|
|
|