# Provenance (copied from a web file-viewer header): author Mrhuman1,
# commit ddaacb7 "Update app.py" (verified), 6.09 kB.
import streamlit as st

# Page config: st.set_page_config must be the first Streamlit command the
# script executes, so it runs immediately after importing streamlit.
_PAGE_CONFIG = dict(
    page_title="Chest X-ray Disease Classifier",
    page_icon="🩺",
    layout="centered",
)
st.set_page_config(**_PAGE_CONFIG)
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from efficientnet_pytorch import EfficientNet
from PIL import Image
from datetime import datetime
from io import BytesIO
from fpdf import FPDF # For PDF generation
# --- Define HardSwish activation ---
class HardSwish(nn.Module):
    """Hard-swish activation: ``x * clamp(x + 3, 0, 6) / 6``.

    The gate ``clamp(x + 3, 0, 6) / 6`` is piecewise linear in ``x``. The
    module has no learnable parameters, so it adds nothing to a state_dict.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        gate = torch.clamp(x + 3, 0, 6) / 6
        return x * gate
# --- Define Custom EfficientNet model ---
class CustomEfficientNet(nn.Module):
    """EfficientNet-B3 whose final ``_fc`` layer is swapped for an MLP head.

    Head layout: Linear(in_features -> 512) -> HardSwish -> Dropout(0.4)
    -> Linear(512 -> num_classes).

    The attribute names (``model`` and ``model._fc``) and the layer order in
    the head determine the state_dict keys. They must therefore stay in sync
    with the checkpoint that ``load_model`` reads.
    """

    def __init__(self, num_classes):
        super().__init__()
        backbone = EfficientNet.from_name('efficientnet-b3')
        in_features = backbone._fc.in_features
        backbone._fc = nn.Sequential(
            nn.Linear(in_features, 512),
            HardSwish(),
            nn.Dropout(p=0.4),
            nn.Linear(512, num_classes),
        )
        self.model = backbone

    def forward(self, x):
        return self.model(x)
# Disease class labels. Output unit i of the model is reported under
# class_names[i] by predict(), so this order must match the classifier head.
class_names = [
    'No Finding',
    'Enlarged Cardiomediastinum',
    'Cardiomegaly',
    'Lung Opacity',
    'Lung Lesion',
    'Edema',
    'Consolidation',
    'Pneumonia',
    'Atelectasis',
    'Pneumothorax',
    'Pleural Effusion',
    'Pleural Other',
    'Fracture',
    'Support Devices',
]

# Device configuration: use the GPU when CUDA is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
# Load model
@st.cache_resource
def load_model(checkpoint_path='Final_global_model.pth.tar'):
    """Build the classifier, load its trained weights and cache the result.

    Args:
        checkpoint_path: Path to the saved checkpoint. The file may hold either
            a bare state_dict or a dict wrapping one under ``'state_dict'``.
            The default is resolved relative to the current working directory.

    Returns:
        The model, moved to ``device`` and switched to eval mode.
    """
    # Size the head from the label list so the two can never drift apart.
    model = CustomEfficientNet(num_classes=len(class_names))
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    checkpoint = torch.load(checkpoint_path, map_location=device)
    state_dict = checkpoint['state_dict'] if 'state_dict' in checkpoint else checkpoint
    model.load_state_dict(state_dict)
    model = model.to(device)
    model.eval()
    return model


model = load_model()
# Image transforms:
#   1. resize to a fixed 300x300 input;
#   2. convert to a float tensor;
#   3. normalize each RGB channel with the mean/std below (these are the
#      widely used ImageNet statistics).
_NORM_MEAN = [0.485, 0.456, 0.406]
_NORM_STD = [0.229, 0.224, 0.225]
transform = transforms.Compose(
    [
        transforms.Resize((300, 300)),
        transforms.ToTensor(),
        transforms.Normalize(_NORM_MEAN, _NORM_STD),
    ]
)
# Predict function
def _to_model_rgb(image):
    """Return *image* as an 8-bit RGB PIL image for the transform pipeline.

    Pillow converts high-bit-depth grayscale modes ('I', 'I;16*', 'F') to
    8 bits by clipping values to 0..255 rather than rescaling them, so most
    16-bit X-rays would come out almost entirely white. Those modes are
    min-max stretched onto 0..255 first. A constant image has no range to
    stretch and falls back to the same clipping as before. Every other mode
    is converted exactly as before.
    """
    if image.mode == 'RGB':
        return image
    if image.mode in ('I', 'F') or image.mode.startswith('I;16'):
        import numpy as np  # local: only needed for high-bit-depth inputs

        pixels = np.asarray(image, dtype=np.float64)
        lo, hi = float(pixels.min()), float(pixels.max())
        if hi > lo:
            # Stretch the observed intensity range onto the full 8-bit range.
            pixels = (pixels - lo) * (255.0 / (hi - lo))
        image = Image.fromarray(np.clip(np.rint(pixels), 0, 255).astype(np.uint8))
    return image.convert('RGB')


def predict(image, top_k=5):
    """Score *image* against every class and return the most likely ones.

    Args:
        image: PIL image of a chest X-ray, in any mode (converted to RGB).
        top_k: Maximum number of classes to return (default 5).

    Returns:
        dict mapping class name -> probability (Python float), ordered from
        most to least likely, with at most ``top_k`` entries. Classes with
        equal scores keep their relative order from ``class_names``.
    """
    img = transform(_to_model_rgb(image)).unsqueeze(0).to(device)
    with torch.no_grad():
        outputs = model(img)
    # Sigmoid scores each class independently (multi-label style), so the
    # probabilities need not sum to 1.
    probs = torch.sigmoid(outputs).cpu().numpy()[0]
    scored = [(name, float(p)) for name, p in zip(class_names, probs)]
    scored.sort(key=lambda item: item[1], reverse=True)  # stable sort
    return dict(scored[:max(top_k, 0)])
# PDF generator
def _pdf_text(text):
    """Make *text* printable with FPDF's core fonts, which cover Latin-1 only.

    Markdown bold markers (``**``) are dropped because the PDF does not render
    markdown. Characters outside Latin-1 become '?' instead of raising a
    UnicodeEncodeError when the document is serialized.
    """
    return str(text).replace('**', '').encode('latin-1', 'replace').decode('latin-1')


def generate_pdf(name, date, image, comment):
    """Render a one-page X-ray report and return it as PDF bytes.

    Args:
        name: Patient name printed in the header (untrusted user input).
        date: Scan date; any object with ``strftime`` (e.g. ``datetime.date``).
        image: PIL image embedded in the report. Any mode is accepted; it is
            saved as JPEG.
        comment: AI analysis text; ``**`` bold markers are stripped.

    Returns:
        The PDF document as ``bytes``.
    """
    import os
    import tempfile

    pdf = FPDF()
    pdf.add_page()
    # Title
    pdf.set_font("Arial", 'B', 20)
    pdf.cell(0, 10, "Chest X-ray AI Report", ln=True, align="C")
    pdf.ln(10)
    # Patient Info
    pdf.set_font("Arial", '', 12)
    pdf.cell(0, 10, _pdf_text(f"Patient Name: {name}"), ln=True)
    pdf.cell(0, 10, f"Scan Date: {date.strftime('%Y-%m-%d')}", ln=True)
    pdf.ln(10)
    # X-ray image. JPEG cannot hold alpha, palette or 16-bit modes; saving
    # e.g. an RGBA PNG as JPEG raises OSError, so normalise to RGB first.
    if image.mode not in ('RGB', 'L'):
        image = image.convert('RGB')
    # FPDF.image() is given a file path, so the JPEG must exist on disk.
    # A unique temp file per call stops concurrent sessions from overwriting
    # each other's X-ray, and it is removed once the image is embedded.
    fd, tmp_path = tempfile.mkstemp(suffix='.jpg')
    try:
        with os.fdopen(fd, 'wb') as tmp:
            image.save(tmp, format='JPEG')
        pdf.image(tmp_path, x=30, w=150)  # resize image
    finally:
        os.remove(tmp_path)
    pdf.ln(10)
    # AI Comment
    pdf.set_font("Arial", 'B', 14)
    pdf.cell(0, 10, "AI Analysis:", ln=True)
    pdf.set_font("Arial", '', 12)
    pdf.multi_cell(0, 10, _pdf_text(comment))
    out = pdf.output(dest='S')
    # PyFPDF returns a Latin-1 str here, while fpdf2 returns a bytearray.
    return out.encode('latin1') if isinstance(out, str) else bytes(out)
# ----------------- Streamlit App -------------------
# Centered page header. The HTML is a fixed literal containing no user input,
# which is why rendering it with unsafe_allow_html is acceptable here.
_HEADER_HTML = (
    "\n"
    '<h1 style="text-align: center;">🩺 Chest X-ray Disease Classifier</h1>\n'
    '<p style="text-align: center;">Upload a chest X-ray image to get disease predictions and AI report.</p>\n'
)
st.markdown(_HEADER_HTML, unsafe_allow_html=True)
# Input form: nothing below reruns with new values until the submit button is
# pressed. The labels' emoji were previously mis-encoded (UTF-8 read through a
# legacy code page) and showed up as garbage like "πŸ‘€".
with st.form("prediction_form"):
    patient_name = st.text_input("👤 Patient Name", placeholder="Enter full name...")
    scan_date = st.date_input("📅 Scan Date", value=datetime.today())
    # "tif" is listed too: the uploader matches on file extension, and many
    # TIFF X-rays use the short extension.
    uploaded_file = st.file_uploader(
        "📤 Upload Chest X-ray Image",
        type=["png", "jpg", "jpeg", "bmp", "tiff", "tif"],
    )
    submit_button = st.form_submit_button("🔍 Analyze X-ray")
import re

# Process the form: validate input, run the model, show results and offer a
# PDF report. UI emoji were previously mis-encoded (e.g. "βœ…") and are fixed.
if submit_button:
    # Validation already ignored surrounding whitespace; display/report the
    # same trimmed name instead of the raw input.
    patient_name = patient_name.strip()
    if not uploaded_file:
        st.error("⚠️ Please upload a chest X-ray image.")
    elif not patient_name:
        st.error("⚠️ Please enter the patient's name.")
    else:
        try:
            image = Image.open(uploaded_file)
            image.load()  # decode now so a corrupt file fails here, not mid-report
        except (OSError, ValueError, Image.DecompressionBombError) as err:
            # UnidentifiedImageError (not an image) is a subclass of OSError.
            st.error(f"⚠️ Could not read the uploaded file as an image: {err}")
            st.stop()

        with st.spinner('🔎 Analyzing the X-ray...'):
            top5_predictions = predict(image)
        st.success('✅ Analysis Completed!')

        # Display Info
        st.markdown("---")
        st.subheader("📋 Patient Information")
        st.write(f"**Name:** {patient_name}")
        st.write(f"**Scan Date:** {scan_date.strftime('%Y-%m-%d')}")

        # Display Predictions (dict is ordered most -> least likely)
        st.markdown("---")
        st.subheader("🧪 Top 5 Predictions")
        top_disease, top_prob = next(iter(top5_predictions.items()))
        ai_comment = f"The most likely disease is **{top_disease}** with a probability of **{top_prob*100:.2f}%**."
        for disease, prob in top5_predictions.items():
            st.progress(prob)
            st.write(f"🔹 **{disease}** — {prob*100:.2f}%")

        st.markdown("---")
        st.subheader("🖼️ Uploaded X-ray Image (Resized)")
        width, height = image.size
        # Half-size preview; max(1, ...) keeps tiny images from collapsing to
        # a zero dimension, which PIL's resize rejects.
        resized_image = image.resize((max(1, width // 2), max(1, height // 2)))
        st.image(resized_image, caption="Uploaded Chest X-ray", use_column_width=False)

        st.markdown("---")
        st.subheader("💬 AI Comment")
        st.info(ai_comment)

        # Generate PDF
        pdf_bytes = generate_pdf(patient_name, scan_date, resized_image, ai_comment)
        # The name is untrusted input; keep only word characters, dots and
        # dashes in the download filename.
        safe_name = re.sub(r'[^\w.-]+', '_', patient_name)
        st.download_button(
            label="📄 Download PDF Report",
            data=pdf_bytes,
            file_name=f"{safe_name}_Xray_Report.pdf",
            mime="application/pdf",
        )