Mrhuman1 commited on
Commit
588ecb6
·
verified ·
1 Parent(s): cee555d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +42 -55
app.py CHANGED
@@ -7,8 +7,8 @@ from PIL import Image
7
  from fpdf import FPDF
8
  import os
9
  from datetime import datetime
10
- print("This app is not accurate , use with caution, images uploaded with bad pixels may effect the result")
11
- # Define HardSwish activation
12
  class HardSwish(nn.Module):
13
  def __init__(self):
14
  super(HardSwish, self).__init__()
@@ -16,7 +16,7 @@ class HardSwish(nn.Module):
16
  def forward(self, x):
17
  return x * (torch.clamp(x + 3, 0, 6) / 6)
18
 
19
- # Model class
20
  class CustomEfficientNet(nn.Module):
21
  def __init__(self, num_classes):
22
  super(CustomEfficientNet, self).__init__()
@@ -37,10 +37,10 @@ class_names = ['No Finding', 'Enlarged Cardiomediastinum', 'Cardiomegaly', 'Lung
37
  'Lung Lesion', 'Edema', 'Consolidation', 'Pneumonia', 'Atelectasis', 'Pneumothorax',
38
  'Pleural Effusion', 'Pleural Other', 'Fracture', 'Support Devices']
39
 
40
- # Device config
41
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
42
 
43
- # Load the model
44
  model = CustomEfficientNet(num_classes=14)
45
  checkpoint = torch.load('Final_global_model.pth.tar', map_location=device)
46
  if 'state_dict' in checkpoint:
@@ -50,22 +50,27 @@ else:
50
  model = model.to(device)
51
  model.eval()
52
 
53
- # Transform
54
  transform = transforms.Compose([
55
  transforms.Resize((300, 300)),
56
  transforms.ToTensor(),
57
  transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
58
  ])
59
 
60
- # Prediction and PDF generation
61
- def predict_and_generate_pdf(image, patient_name, xray_date):
 
 
 
 
62
  if image is None:
63
  raise ValueError("❌ Error: No image uploaded.")
64
  if not patient_name.strip():
65
  raise ValueError("❌ Error: Patient name is required.")
66
- if not xray_date:
67
  raise ValueError("❌ Error: X-ray date is required.")
68
 
 
69
  if not isinstance(image, Image.Image):
70
  image = Image.fromarray(image)
71
 
@@ -80,88 +85,70 @@ def predict_and_generate_pdf(image, patient_name, xray_date):
80
  outputs = model(img)
81
  probs = torch.sigmoid(outputs).cpu().numpy()[0]
82
 
83
- # Process results
84
  results = {class_names[i]: float(probs[i]) for i in range(len(class_names))}
85
  sorted_results = dict(sorted(results.items(), key=lambda item: item[1], reverse=True))
86
  top5 = list(sorted_results.items())[:5]
87
 
88
- # Comment Logic
89
  top_label, top_prob = top5[0]
90
  if top_label == "No Finding" and top_prob > 0.5:
91
  comment = "✅ No major abnormal findings detected."
92
  elif top_prob > 0.5:
93
  comment = f"⚠️ High likelihood of {top_label}."
94
  else:
95
- comment = f"🔎 Possibility of {top_label}, but low confidence."
96
-
97
- # Save thumbnail
98
- image_path = "xray_thumbnail.jpg"
99
- image_copy = image.copy()
100
- image_copy.thumbnail((100, 100))
101
- image_copy.save(image_path)
102
 
103
- # Create PDF
104
  pdf = FPDF()
105
  pdf.add_page()
106
- pdf.set_font("Arial", 'B', 18)
107
- pdf.cell(0, 10, "Chest X-ray Analysis Report", ln=True, align='C')
108
  pdf.ln(10)
109
 
110
  # Patient Details
111
  pdf.set_font("Arial", '', 12)
112
  pdf.cell(0, 10, f"Patient Name: {patient_name}", ln=True)
113
- pdf.cell(0, 10, f"X-ray Date: {xray_date}", ln=True)
114
  pdf.cell(0, 10, f"Report Generated: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}", ln=True)
115
  pdf.ln(10)
116
 
117
- # X-ray Image
118
- pdf.image(image_path, x=80, w=50)
119
- pdf.ln(10)
120
-
121
  # Top 5 Predictions
122
  pdf.set_font("Arial", 'B', 14)
123
  pdf.cell(0, 10, "Top 5 Predictions:", ln=True)
124
  pdf.set_font("Arial", '', 12)
125
 
126
- pdf.set_fill_color(230, 230, 230)
127
  for disease, prob in top5:
128
- pdf.cell(100, 10, disease, 1, 0, 'L', fill=True)
129
- pdf.cell(40, 10, f"{prob*100:.2f}%", 1, 1, 'C', fill=True)
130
 
131
  pdf.ln(10)
132
 
133
- # Comments
134
  pdf.set_font("Arial", 'B', 14)
135
  pdf.cell(0, 10, "Doctor's Comment:", ln=True)
136
  pdf.set_font("Arial", '', 12)
137
- pdf.set_fill_color(240, 248, 255)
138
- pdf.multi_cell(0, 10, comment, fill=True)
139
 
140
- # Save PDF
141
- output_pdf_path = "chest_xray_report.pdf"
142
- pdf.output(output_pdf_path)
 
143
 
144
- return output_pdf_path
145
 
146
  # Gradio Interface
147
- with gr.Blocks(theme="default") as demo:
148
- gr.Markdown("# 🩺 Chest X-ray Disease Classification App")
149
- gr.Markdown("Upload a chest X-ray, enter patient information, and generate a detailed PDF report.")
150
-
151
- with gr.Row():
152
- with gr.Column():
153
- image_input = gr.Image(type="pil", label="Upload Chest X-ray Image")
154
- name_input = gr.Textbox(label="Patient Name")
155
- date_input = gr.Textbox(label="Date of X-ray (YYYY-MM-DD)", placeholder="e.g. 2025-04-27")
156
- submit_button = gr.Button("Analyze & Generate PDF Report")
157
- with gr.Column():
158
- file_output = gr.File(label="Download Generated Report (PDF)")
159
-
160
- submit_button.click(
161
- fn=predict_and_generate_pdf,
162
- inputs=[image_input, name_input, date_input],
163
- outputs=file_output
164
- )
165
 
166
  if __name__ == "__main__":
167
- demo.launch()
 
7
  from fpdf import FPDF
8
  import os
9
  from datetime import datetime
10
+
11
+ # Define HardSwish activation function
12
  class HardSwish(nn.Module):
13
  def __init__(self):
14
  super(HardSwish, self).__init__()
 
16
  def forward(self, x):
17
  return x * (torch.clamp(x + 3, 0, 6) / 6)
18
 
19
+ # Define Custom EfficientNet
20
  class CustomEfficientNet(nn.Module):
21
  def __init__(self, num_classes):
22
  super(CustomEfficientNet, self).__init__()
 
37
  'Lung Lesion', 'Edema', 'Consolidation', 'Pneumonia', 'Atelectasis', 'Pneumothorax',
38
  'Pleural Effusion', 'Pleural Other', 'Fracture', 'Support Devices']
39
 
40
+ # Device configuration
41
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
42
 
43
+ # Load model
44
  model = CustomEfficientNet(num_classes=14)
45
  checkpoint = torch.load('Final_global_model.pth.tar', map_location=device)
46
  if 'state_dict' in checkpoint:
 
50
  model = model.to(device)
51
  model.eval()
52
 
53
+ # Transformations
54
  transform = transforms.Compose([
55
  transforms.Resize((300, 300)),
56
  transforms.ToTensor(),
57
  transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
58
  ])
59
 
60
+ # Helper to sanitize filename
61
+ def clean_filename(name):
62
+ return "".join(c for c in name if c.isalnum() or c in (' ', '_', '-')).rstrip()
63
+
64
+ # Prediction and PDF generation function
65
+ def predict(image, patient_name, xray_date):
66
  if image is None:
67
  raise ValueError("❌ Error: No image uploaded.")
68
  if not patient_name.strip():
69
  raise ValueError("❌ Error: Patient name is required.")
70
+ if not xray_date.strip():
71
  raise ValueError("❌ Error: X-ray date is required.")
72
 
73
+ # Ensure correct image mode
74
  if not isinstance(image, Image.Image):
75
  image = Image.fromarray(image)
76
 
 
85
  outputs = model(img)
86
  probs = torch.sigmoid(outputs).cpu().numpy()[0]
87
 
 
88
  results = {class_names[i]: float(probs[i]) for i in range(len(class_names))}
89
  sorted_results = dict(sorted(results.items(), key=lambda item: item[1], reverse=True))
90
  top5 = list(sorted_results.items())[:5]
91
 
92
+ # Doctor's comment
93
  top_label, top_prob = top5[0]
94
  if top_label == "No Finding" and top_prob > 0.5:
95
  comment = "✅ No major abnormal findings detected."
96
  elif top_prob > 0.5:
97
  comment = f"⚠️ High likelihood of {top_label}."
98
  else:
99
+ comment = f"🔎 Possible {top_label}, but confidence is low."
 
 
 
 
 
 
100
 
101
+ # Generate PDF
102
  pdf = FPDF()
103
  pdf.add_page()
104
+ pdf.set_font("Arial", 'B', 16)
105
+ pdf.cell(0, 10, "Chest X-ray Disease Report", ln=True, align='C')
106
  pdf.ln(10)
107
 
108
  # Patient Details
109
  pdf.set_font("Arial", '', 12)
110
  pdf.cell(0, 10, f"Patient Name: {patient_name}", ln=True)
111
+ pdf.cell(0, 10, f"Date of X-ray: {xray_date}", ln=True)
112
  pdf.cell(0, 10, f"Report Generated: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}", ln=True)
113
  pdf.ln(10)
114
 
 
 
 
 
115
  # Top 5 Predictions
116
  pdf.set_font("Arial", 'B', 14)
117
  pdf.cell(0, 10, "Top 5 Predictions:", ln=True)
118
  pdf.set_font("Arial", '', 12)
119
 
 
120
  for disease, prob in top5:
121
+ pdf.cell(0, 10, f"{disease}: {prob*100:.2f}%", ln=True)
 
122
 
123
  pdf.ln(10)
124
 
125
+ # Comment Section
126
  pdf.set_font("Arial", 'B', 14)
127
  pdf.cell(0, 10, "Doctor's Comment:", ln=True)
128
  pdf.set_font("Arial", '', 12)
129
+ pdf.multi_cell(0, 10, comment)
 
130
 
131
+ # Save PDF with clean filename
132
+ safe_name = clean_filename(patient_name.replace(" ", "_"))
133
+ report_filename = f"{safe_name}_{xray_date}_Report.pdf"
134
+ pdf.output(report_filename)
135
 
136
+ return report_filename
137
 
138
  # Gradio Interface
139
+ interface = gr.Interface(
140
+ fn=predict,
141
+ inputs=[
142
+ gr.Image(type="pil", label="Upload Chest X-ray Image"),
143
+ gr.Textbox(label="Patient Name"),
144
+ gr.Textbox(label="Date of X-ray (YYYY-MM-DD)", placeholder="e.g. 2025-04-27")
145
+ ],
146
+ outputs=gr.File(label="Download PDF Report"),
147
+ title="Chest X-ray Disease Classification with Report",
148
+ description="Upload an X-ray, enter patient details, and download a detailed PDF report.",
149
+ theme="default",
150
+ allow_flagging="never"
151
+ )
 
 
 
 
 
152
 
153
  if __name__ == "__main__":
154
+ interface.launch()