import os
import tempfile
from datetime import datetime
from io import BytesIO

import streamlit as st
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from efficientnet_pytorch import EfficientNet
from fpdf import FPDF  # For PDF generation
from PIL import Image

# Page config - must be the first Streamlit command executed by the script.
st.set_page_config(
    page_title="Chest X-ray Disease Classifier",
    page_icon="🩺",
    layout="centered",
)


# --- Define HardSwish activation ---
class HardSwish(nn.Module):
    """Hard-swish activation: ``x * clamp(x + 3, 0, 6) / 6``."""

    def forward(self, x):
        return x * (torch.clamp(x + 3, 0, 6) / 6)


# --- Define Custom EfficientNet model ---
class CustomEfficientNet(nn.Module):
    """EfficientNet-B3 backbone with a two-layer classification head.

    NOTE: the attribute name ``model`` and the layer order inside ``_fc``
    determine the checkpoint's state_dict keys; do not rename or reorder them.
    """

    def __init__(self, num_classes):
        super().__init__()
        self.model = EfficientNet.from_name('efficientnet-b3')
        num_ftrs = self.model._fc.in_features
        self.model._fc = nn.Sequential(
            nn.Linear(num_ftrs, 512),
            HardSwish(),
            nn.Dropout(p=0.4),
            nn.Linear(512, num_classes),
        )

    def forward(self, x):
        return self.model(x)


# Disease class labels, in the same order as the model's output units
# (predict() maps output index i to class_names[i]).
class_names = [
    'No Finding', 'Enlarged Cardiomediastinum', 'Cardiomegaly', 'Lung Opacity',
    'Lung Lesion', 'Edema', 'Consolidation', 'Pneumonia', 'Atelectasis',
    'Pneumothorax', 'Pleural Effusion', 'Pleural Other', 'Fracture',
    'Support Devices',
]

# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


# Load model once per server process (cached across reruns and sessions).
@st.cache_resource
def load_model():
    """Build the network, load weights from the local checkpoint, and set eval mode."""
    model = CustomEfficientNet(num_classes=len(class_names))
    # NOTE(review): torch >= 2.6 defaults torch.load to weights_only=True; if this
    # checkpoint stores non-tensor objects it may need weights_only=False, which
    # unpickles arbitrary objects and must only be used with trusted files.
    checkpoint = torch.load('Final_global_model.pth.tar', map_location=device)
    # Accept either a full training checkpoint ({'state_dict': ...}) or a bare state_dict.
    if 'state_dict' in checkpoint:
        model.load_state_dict(checkpoint['state_dict'])
    else:
        model.load_state_dict(checkpoint)
    model = model.to(device)
    model.eval()
    return model


model = load_model()

# Image transforms. The normalization values are the ImageNet channel mean/std;
# presumably they match the training-time preprocessing — confirm if retraining.
transform = transforms.Compose([
    transforms.Resize((300, 300)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])


# Predict function
def predict(image, top_k=5):
    """Run the classifier on a PIL image and return the highest-scoring labels.

    Args:
        image: PIL image in any mode; it is converted to RGB for the model
            (the caller's image object is not modified).
        top_k: how many labels to return (default 5).

    Returns:
        dict mapping label -> sigmoid probability (Python float), ordered by
        descending probability. Each label is scored independently (sigmoid,
        not softmax), so the values need not sum to 1.
    """
    if image.mode != 'RGB':
        image = image.convert('RGB')
    batch = transform(image).unsqueeze(0).to(device)
    with torch.no_grad():
        logits = model(batch)
    probs = torch.sigmoid(logits).cpu().numpy()[0].tolist()
    # sorted() is stable, so ties keep class_names order.
    ranked = sorted(zip(class_names, probs), key=lambda item: item[1], reverse=True)
    return dict(ranked[:top_k])


def _pdf_text(text):
    """Return *text* restricted to Latin-1, the only charset FPDF's core fonts support.

    Characters outside Latin-1 are replaced with '?' instead of crashing PDF output.
    """
    return str(text).encode('latin-1', 'replace').decode('latin-1')


# PDF generator
def generate_pdf(name, date, image, comment):
    """Build the PDF report and return it as bytes.

    Args:
        name: patient name (non-Latin-1 characters are replaced with '?').
        date: object with ``strftime`` (the scan date).
        image: PIL image to embed; converted to RGB first if JPEG cannot store its mode.
        comment: AI comment text; markdown bold markers (``**``) are stripped.
    """
    pdf = FPDF()
    pdf.add_page()

    # Title
    pdf.set_font("Arial", 'B', 20)
    pdf.cell(0, 10, "Chest X-ray AI Report", ln=True, align="C")
    pdf.ln(10)

    # Patient Info
    pdf.set_font("Arial", '', 12)
    pdf.cell(0, 10, _pdf_text(f"Patient Name: {name}"), ln=True)
    pdf.cell(0, 10, f"Scan Date: {date.strftime('%Y-%m-%d')}", ln=True)
    pdf.ln(10)

    # X-ray image. JPEG cannot store e.g. RGBA/P/LA/16-bit modes, so normalize first.
    if image.mode not in ('RGB', 'L'):
        image = image.convert('RGB')
    # FPDF's image() takes a filename; use a unique temp file per call so
    # concurrent sessions never share or overwrite each other's image.
    tmp = tempfile.NamedTemporaryFile(suffix=".jpg", delete=False)
    try:
        with tmp:
            image.save(tmp, format='JPEG')
        pdf.image(tmp.name, x=30, w=150)  # image data is read during this call
    finally:
        os.remove(tmp.name)
    pdf.ln(10)

    # AI Comment
    pdf.set_font("Arial", 'B', 14)
    pdf.cell(0, 10, "AI Analysis:", ln=True)
    pdf.set_font("Arial", '', 12)
    pdf.multi_cell(0, 10, _pdf_text(comment.replace("**", "")))

    # PyFPDF 1.7 returns a Latin-1 str here; fpdf2 returns a bytearray.
    output = pdf.output(dest='S')
    if isinstance(output, str):
        return output.encode('latin1')
    return bytes(output)


def _open_uploaded_image(uploaded):
    """Decode an uploaded file into a PIL image, or return None if it is not a readable image."""
    try:
        image = Image.open(uploaded)
        image.load()  # force a full decode so corrupt/truncated files fail here
    except OSError:  # PIL.UnidentifiedImageError is a subclass of OSError
        return None
    return image


def _render_report(image, patient_name, scan_date):
    """Run the model on *image* and render results, preview, comment and PDF download."""
    with st.spinner('🔎 Analyzing the X-ray...'):
        top5_predictions = predict(image)
    st.success('✅ Analysis Completed!')

    # Display Info
    st.markdown("---")
    st.subheader("📋 Patient Information")
    st.write(f"**Name:** {patient_name}")
    st.write(f"**Scan Date:** {scan_date.strftime('%Y-%m-%d')}")

    # Display Predictions
    st.markdown("---")
    st.subheader("🧪 Top 5 Predictions")
    top_disease, top_prob = next(iter(top5_predictions.items()))
    ai_comment = (
        f"The most likely disease is **{top_disease}** "
        f"with a probability of **{top_prob*100:.2f}%**."
    )
    for disease, prob in top5_predictions.items():
        st.progress(prob)
        st.write(f"🔹 **{disease}** — {prob*100:.2f}%")

    st.markdown("---")
    st.subheader("🖼️ Uploaded X-ray Image (Resized)")
    width, height = image.size
    # Clamp to 1px so very small uploads don't produce an invalid 0-size resize.
    resized_image = image.resize((max(1, width // 2), max(1, height // 2)))
    st.image(resized_image, caption="Uploaded Chest X-ray", use_column_width=False)

    st.markdown("---")
    st.subheader("💬 AI Comment")
    st.info(ai_comment)

    # Generate PDF
    pdf_bytes = generate_pdf(patient_name, scan_date, resized_image, ai_comment)
    st.download_button(
        label="📄 Download PDF Report",
        data=pdf_bytes,
        file_name=f"{patient_name.replace(' ', '_')}_Xray_Report.pdf",
        mime="application/pdf",
    )


# ----------------- Streamlit App -------------------
st.markdown(
    """

🩺 Chest X-ray Disease Classifier

Upload a chest X-ray image to get disease predictions and AI report.

""",
    unsafe_allow_html=True,
)

with st.form("prediction_form"):
    patient_name = st.text_input("👤 Patient Name", placeholder="Enter full name...")
    scan_date = st.date_input("📅 Scan Date", value=datetime.today())
    uploaded_file = st.file_uploader(
        "📤 Upload Chest X-ray Image",
        type=["png", "jpg", "jpeg", "bmp", "tiff", "tif"],
    )
    submit_button = st.form_submit_button("🔍 Analyze X-ray")

# Process the form
if submit_button:
    if not uploaded_file:
        st.error("⚠️ Please upload a chest X-ray image.")
    elif not patient_name.strip():
        st.error("⚠️ Please enter the patient's name.")
    else:
        uploaded_image = _open_uploaded_image(uploaded_file)
        if uploaded_image is None:
            st.error("⚠️ The uploaded file could not be read as an image.")
        else:
            _render_report(uploaded_image, patient_name.strip(), scan_date)