Spaces:
Runtime error
Runtime error
| import streamlit as st | |
# Page configuration. Kept ahead of every other Streamlit call because
# the page config "must be first".
st.set_page_config(
    layout="centered",
    page_icon="π©Ί",
    page_title="Chest X-ray Disease Classifier",
)
| import torch | |
| import torch.nn as nn | |
| import torchvision.transforms as transforms | |
| from efficientnet_pytorch import EfficientNet | |
| from PIL import Image | |
| from datetime import datetime | |
| from io import BytesIO | |
| from fpdf import FPDF # For PDF generation | |
# --- Hard-swish activation ---
class HardSwish(nn.Module):
    """Parameter-free activation: ``x * (clamp(x + 3, 0, 6) / 6)`` elementwise."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        """Return the hard-swish of tensor ``x`` (same shape as the input)."""
        # Piecewise-linear gate in [0, 1]; computed first so the rounding
        # order matches ``x * (clamp / 6)``.
        gate = (x + 3).clamp(min=0, max=6) / 6
        return x * gate
# --- EfficientNet-B3 with a replacement classification head ---
class CustomEfficientNet(nn.Module):
    """EfficientNet-B3 whose ``_fc`` layer is swapped for a two-layer head.

    The head is ``Linear(in_features, 512) -> HardSwish -> Dropout(0.4) ->
    Linear(512, num_classes)``. The wrapped network is stored as
    ``self.model`` so checkpoint keys keep the ``model.`` prefix.
    """

    def __init__(self, num_classes):
        super().__init__()
        backbone = EfficientNet.from_name('efficientnet-b3')
        # Layer order inside the Sequential fixes the state-dict indices
        # (``_fc.0`` and ``_fc.3``), so it must not be rearranged.
        head = nn.Sequential(
            nn.Linear(backbone._fc.in_features, 512),
            HardSwish(),
            nn.Dropout(p=0.4),
            nn.Linear(512, num_classes),
        )
        backbone._fc = head
        self.model = backbone

    def forward(self, x):
        """Run ``x`` through the wrapped network and return its raw outputs."""
        return self.model(x)
# Output labels, in the same order as the model's output units
# (``predict`` pairs ``class_names[i]`` with output ``i``).
class_names = [
    'No Finding',
    'Enlarged Cardiomediastinum',
    'Cardiomegaly',
    'Lung Opacity',
    'Lung Lesion',
    'Edema',
    'Consolidation',
    'Pneumonia',
    'Atelectasis',
    'Pneumothorax',
    'Pleural Effusion',
    'Pleural Other',
    'Fracture',
    'Support Devices',
]
# Compute device: CUDA when torch reports it as available, otherwise CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
# Load model
# Streamlit re-executes this script on every widget interaction. Caching the
# loader keeps one model instance alive instead of re-reading the checkpoint
# on each rerun.
# NOTE(review): st.cache_resource exists only in recent Streamlit releases;
# confirm the version pinned for this app.
@st.cache_resource
def load_model():
    """Build the classifier and load weights from ``Final_global_model.pth.tar``.

    The checkpoint may be either a bare state dict or a dict that holds the
    state dict under the ``'state_dict'`` key; both layouts are accepted.

    Returns:
        The ``CustomEfficientNet`` instance, moved to ``device`` and switched
        to evaluation mode.
    """
    model = CustomEfficientNet(num_classes=14)
    # NOTE(security): torch.load unpickles the file, so the checkpoint must
    # come from a trusted source.
    checkpoint = torch.load('Final_global_model.pth.tar', map_location=device)
    state_dict = checkpoint['state_dict'] if 'state_dict' in checkpoint else checkpoint
    model.load_state_dict(state_dict)
    model = model.to(device)
    model.eval()
    return model


model = load_model()
# Preprocessing applied to every image before inference:
# resize to 300x300, convert to a tensor, then normalize each channel.
transform = transforms.Compose([
    transforms.Resize(size=(300, 300)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])
# Predict function
def predict(image):
    """Score a PIL image and return the five highest-probability labels.

    The image is converted to RGB when needed, preprocessed with
    ``transform``, and passed through ``model`` as a batch of one. A sigmoid
    is applied to each output independently.

    Returns:
        A dict mapping label -> probability (float), holding the five
        largest entries in descending order of probability.
    """
    rgb = image if image.mode == 'RGB' else image.convert('RGB')
    batch = transform(rgb).unsqueeze(0).to(device)
    with torch.no_grad():
        logits = model(batch)
        probabilities = torch.sigmoid(logits).cpu().numpy()[0]

    scored = [(label, float(probabilities[idx])) for idx, label in enumerate(class_names)]
    # list.sort is stable, so ties keep their class_names order.
    scored.sort(key=lambda pair: pair[1], reverse=True)
    return dict(scored[:5])
# PDF generator
def _latin1_safe(text):
    """Return ``text`` with characters outside Latin-1 replaced by ``?``."""
    return str(text).encode('latin-1', 'replace').decode('latin-1')


def generate_pdf(name, date, image, comment):
    """Render a one-page PDF report and return it as ``bytes``.

    Args:
        name: Patient name printed under the title.
        date: Object providing ``strftime``; printed as ``YYYY-MM-DD``.
        image: PIL image embedded below the patient details.
        comment: Free text printed under the "AI Analysis:" heading.

    Returns:
        The finished PDF document as a ``bytes`` object.
    """
    import os
    import tempfile

    pdf = FPDF()
    pdf.add_page()

    # Title
    pdf.set_font("Arial", 'B', 20)
    pdf.cell(0, 10, "Chest X-ray AI Report", ln=True, align="C")
    pdf.ln(10)

    # Patient info. The built-in fonts only cover Latin-1, so anything else
    # in the user-supplied name is replaced rather than crashing the export.
    pdf.set_font("Arial", '', 12)
    pdf.cell(0, 10, f"Patient Name: {_latin1_safe(name)}", ln=True)
    pdf.cell(0, 10, f"Scan Date: {date.strftime('%Y-%m-%d')}", ln=True)
    pdf.ln(10)

    # JPEG cannot store alpha or palette modes (e.g. RGBA/P PNG uploads).
    if image.mode not in ('L', 'RGB', 'CMYK'):
        image = image.convert('RGB')

    # The image is handed to fpdf by filename. A unique temporary file keeps
    # concurrent sessions from overwriting each other's image, and it is
    # removed once the document has been rendered.
    tmp = tempfile.NamedTemporaryFile(suffix='.jpg', delete=False)
    try:
        with tmp:
            image.save(tmp, format='JPEG')
        pdf.image(tmp.name, x=30, w=150)  # resize image
        pdf.ln(10)

        # AI comment
        pdf.set_font("Arial", 'B', 14)
        pdf.cell(0, 10, "AI Analysis:", ln=True)
        pdf.set_font("Arial", '', 12)
        pdf.multi_cell(0, 10, _latin1_safe(comment))

        rendered = pdf.output(dest='S')
    finally:
        os.remove(tmp.name)

    # Depending on the installed fpdf package, output() yields either a
    # Latin-1 ``str`` or a bytes-like object; normalize both to ``bytes``.
    if isinstance(rendered, str):
        return rendered.encode('latin1')
    return bytes(rendered)
# ----------------- Streamlit App -------------------
# Centered page title and tagline, emitted as raw HTML.
_HEADER_HTML = """
<h1 style="text-align: center;">π©Ί Chest X-ray Disease Classifier</h1>
<p style="text-align: center;">Upload a chest X-ray image to get disease predictions and AI report.</p>
"""
st.markdown(_HEADER_HTML, unsafe_allow_html=True)
# Input form: the values are only acted on once the submit button is pressed.
with st.form(key="prediction_form"):
    patient_name = st.text_input(
        label="π€ Patient Name",
        placeholder="Enter full name...",
    )
    scan_date = st.date_input(
        label="π Scan Date",
        value=datetime.today(),
    )
    uploaded_file = st.file_uploader(
        label="π€ Upload Chest X-ray Image",
        type=["png", "jpg", "jpeg", "bmp", "tiff"],
    )
    submit_button = st.form_submit_button(label="π Analyze X-ray")
# Process the form: validate the inputs, run the classifier, show the
# results and offer the PDF report for download.
if submit_button:
    if not uploaded_file:
        st.error("β οΈ Please upload a chest X-ray image.")
    elif not patient_name.strip():
        st.error("β οΈ Please enter the patient's name.")
    else:
        # A file with an accepted extension can still be corrupt or not an
        # image at all; decode it up front so the failure is reported to the
        # user instead of surfacing as a traceback later.
        try:
            image = Image.open(uploaded_file)
            image.load()
        except OSError:
            image = None
            st.error("The uploaded file could not be read as an image.")

        if image is not None:
            with st.spinner('π Analyzing the X-ray...'):
                top5_predictions = predict(image)
            st.success('β Analysis Completed!')

            # Display Info
            st.markdown("---")
            st.subheader("π Patient Information")
            st.write(f"**Name:** {patient_name}")
            st.write(f"**Scan Date:** {scan_date.strftime('%Y-%m-%d')}")

            # Display Predictions
            st.markdown("---")
            st.subheader("π§ͺ Top 5 Predictions")
            # predict() returns entries in descending order, so the first
            # item is the most probable label.
            top_label, top_prob = next(iter(top5_predictions.items()))
            ai_comment = f"The most likely disease is **{top_label}** with a probability of **{top_prob*100:.2f}%**."
            for disease, prob in top5_predictions.items():
                st.progress(prob)
                st.write(f"πΉ **{disease}** β {prob*100:.2f}%")

            st.markdown("---")
            st.subheader("πΌοΈ Uploaded X-ray Image (Resized)")
            width, height = image.size
            # Halve each side, but never let a dimension reach 0
            # (resize rejects empty sizes).
            half_size = (max(1, width // 2), max(1, height // 2))
            resized_image = image.resize(half_size)
            st.image(resized_image, caption="Uploaded Chest X-ray", use_column_width=False)

            st.markdown("---")
            st.subheader("π¬ AI Comment")
            st.info(ai_comment)

            # Generate PDF. The on-screen comment uses Markdown bold markers;
            # strip them so the PDF text does not show literal asterisks.
            pdf_comment = ai_comment.replace("**", "")
            pdf_bytes = generate_pdf(patient_name, scan_date, resized_image, pdf_comment)
            st.download_button(
                label="π Download PDF Report",
                data=pdf_bytes,
                file_name=f"{patient_name.replace(' ', '_')}_Xray_Report.pdf",
                mime="application/pdf"
            )