Mrhuman1 commited on
Commit
ab120a7
·
verified ·
1 Parent(s): 67246b5

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +154 -0
app.py ADDED
@@ -0,0 +1,154 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import torch.nn as nn
4
+ import torchvision.transforms as transforms
5
+ from efficientnet_pytorch import EfficientNet
6
+ from PIL import Image
7
+ import numpy as np
8
+ import os
9
+ from fpdf import FPDF
10
+ from datetime import datetime
11
+
12
+ # Define HardSwish activation
13
+ class HardSwish(nn.Module):
14
+ def __init__(self):
15
+ super(HardSwish, self).__init__()
16
+
17
+ def forward(self, x):
18
+ return x * (torch.clamp(x + 3, 0, 6) / 6)
19
+
20
+ # Define model class
21
+ class CustomEfficientNet(nn.Module):
22
+ def __init__(self, num_classes):
23
+ super(CustomEfficientNet, self).__init__()
24
+ self.model = EfficientNet.from_name('efficientnet-b3')
25
+ num_ftrs = self.model._fc.in_features
26
+ self.model._fc = nn.Sequential(
27
+ nn.Linear(num_ftrs, 512),
28
+ HardSwish(),
29
+ nn.Dropout(p=0.4),
30
+ nn.Linear(512, num_classes)
31
+ )
32
+
33
+ def forward(self, x):
34
+ return self.model(x)
35
+
36
+ # Class names
37
+ class_names = ['No Finding', 'Enlarged Cardiomediastinum', 'Cardiomegaly', 'Lung Opacity',
38
+ 'Lung Lesion', 'Edema', 'Consolidation', 'Pneumonia', 'Atelectasis', 'Pneumothorax',
39
+ 'Pleural Effusion', 'Pleural Other', 'Fracture', 'Support Devices']
40
+
41
+ # Device config
42
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
43
+
44
+ # Load the model
45
+ model = CustomEfficientNet(num_classes=14)
46
+ checkpoint = torch.load('Final_global_model.pth.tar', map_location=device)
47
+ if 'state_dict' in checkpoint:
48
+ model.load_state_dict(checkpoint['state_dict'])
49
+ else:
50
+ model.load_state_dict(checkpoint)
51
+ model = model.to(device)
52
+ model.eval()
53
+
54
+ # Transform
55
+ transform = transforms.Compose([
56
+ transforms.Resize((300, 300)),
57
+ transforms.ToTensor(),
58
+ transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
59
+ ])
60
+
61
+ # Prediction and PDF generation
62
+ def predict_and_generate_pdf(image, patient_name, xray_date):
63
+ if image is None:
64
+ raise ValueError("❌ Error: No image uploaded.")
65
+
66
+ if not isinstance(image, Image.Image):
67
+ image = Image.fromarray(image)
68
+
69
+ if image.mode != 'RGB':
70
+ image = image.convert('RGB')
71
+
72
+ # Preprocess
73
+ img = transform(image).unsqueeze(0).to(device)
74
+
75
+ # Predict
76
+ with torch.no_grad():
77
+ outputs = model(img)
78
+ probs = torch.sigmoid(outputs).cpu().numpy()[0]
79
+
80
+ # Process results
81
+ results = {class_names[i]: float(probs[i]) for i in range(len(class_names))}
82
+ sorted_results = dict(sorted(results.items(), key=lambda item: item[1], reverse=True))
83
+ top5 = list(sorted_results.items())[:5]
84
+
85
+ # Generate Comment
86
+ top_label, top_prob = top5[0]
87
+ if top_label == "No Finding" and top_prob > 0.5:
88
+ comment = "No major abnormal findings detected."
89
+ elif top_prob > 0.5:
90
+ comment = f"High likelihood of {top_label}."
91
+ else:
92
+ comment = f"Possibility of {top_label}, but with low confidence."
93
+
94
+ # Save thumbnail
95
+ image_path = "xray_temp.jpg"
96
+ image.thumbnail((200, 200))
97
+ image.save(image_path)
98
+
99
+ # Create PDF
100
+ pdf = FPDF()
101
+ pdf.add_page()
102
+ pdf.set_font("Arial", 'B', 16)
103
+ pdf.cell(0, 10, "Chest X-ray Report", ln=True, align='C')
104
+ pdf.ln(10)
105
+
106
+ pdf.set_font("Arial", '', 12)
107
+ pdf.cell(0, 10, f"Patient Name: {patient_name}", ln=True)
108
+ pdf.cell(0, 10, f"X-ray Date: {xray_date}", ln=True)
109
+ pdf.cell(0, 10, f"Report Generated: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}", ln=True)
110
+ pdf.ln(10)
111
+
112
+ pdf.image(image_path, w=80)
113
+ pdf.ln(10)
114
+
115
+ pdf.set_font("Arial", 'B', 14)
116
+ pdf.cell(0, 10, "Top 5 Predictions:", ln=True)
117
+ pdf.set_font("Arial", '', 12)
118
+ for disease, prob in top5:
119
+ pdf.cell(0, 10, f"{disease}: {prob*100:.2f}%", ln=True)
120
+
121
+ pdf.ln(10)
122
+ pdf.set_font("Arial", 'B', 14)
123
+ pdf.cell(0, 10, "Comment:", ln=True)
124
+ pdf.set_font("Arial", '', 12)
125
+ pdf.multi_cell(0, 10, comment)
126
+
127
+ # Save PDF
128
+ pdf_output_path = "report.pdf"
129
+ pdf.output(pdf_output_path)
130
+
131
+ return pdf_output_path
132
+
133
+ # Gradio Interface
134
+ with gr.Blocks(theme="default") as demo:
135
+ gr.Markdown("# 🩺 Chest X-ray Disease Classification")
136
+ gr.Markdown("Upload a chest X-ray, enter patient's name and date, and download a PDF report.")
137
+
138
+ with gr.Row():
139
+ with gr.Column():
140
+ image_input = gr.Image(type="pil", label="Upload Chest X-ray")
141
+ name_input = gr.Textbox(label="Patient Name")
142
+ date_input = gr.Textbox(label="X-ray Date (YYYY-MM-DD)", placeholder="2025-04-27")
143
+ submit_btn = gr.Button("Analyze & Generate Report")
144
+ with gr.Column():
145
+ file_output = gr.File(label="Download Report (PDF)")
146
+
147
+ submit_btn.click(
148
+ fn=predict_and_generate_pdf,
149
+ inputs=[image_input, name_input, date_input],
150
+ outputs=file_output
151
+ )
152
+
153
+ if __name__ == "__main__":
154
+ demo.launch()