Mrhuman1 commited on
Commit
2f96bed
·
verified ·
1 Parent(s): 2e8579c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +94 -23
app.py CHANGED
@@ -4,8 +4,11 @@ import torch.nn as nn
4
  import torchvision.transforms as transforms
5
  from efficientnet_pytorch import EfficientNet
6
  from PIL import Image
 
 
 
7
 
8
- # Define the HardSwish activation function
9
  class HardSwish(nn.Module):
10
  def __init__(self):
11
  super(HardSwish, self).__init__()
@@ -13,7 +16,7 @@ class HardSwish(nn.Module):
13
  def forward(self, x):
14
  return x * (torch.clamp(x + 3, 0, 6) / 6)
15
 
16
- # Define the model class
17
  class CustomEfficientNet(nn.Module):
18
  def __init__(self, num_classes):
19
  super(CustomEfficientNet, self).__init__()
@@ -34,7 +37,7 @@ class_names = ['No Finding', 'Enlarged Cardiomediastinum', 'Cardiomegaly', 'Lung
34
  'Lung Lesion', 'Edema', 'Consolidation', 'Pneumonia', 'Atelectasis', 'Pneumothorax',
35
  'Pleural Effusion', 'Pleural Other', 'Fracture', 'Support Devices']
36
 
37
- # Device configuration
38
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
39
 
40
  # Load the model
@@ -47,50 +50,118 @@ else:
47
  model = model.to(device)
48
  model.eval()
49
 
50
- # Transformations for input image
51
  transform = transforms.Compose([
52
  transforms.Resize((300, 300)),
53
  transforms.ToTensor(),
54
  transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
55
  ])
56
 
57
- # Prediction function
58
- def predict(image):
59
  if image is None:
60
  raise ValueError("❌ Error: No image uploaded.")
 
 
 
 
61
 
62
- # Ensure image is in RGB mode
63
  if not isinstance(image, Image.Image):
64
  image = Image.fromarray(image)
65
 
66
  if image.mode != 'RGB':
67
  image = image.convert('RGB')
68
 
69
- # Preprocess the image
70
  img = transform(image).unsqueeze(0).to(device)
71
 
72
- # Prediction
73
  with torch.no_grad():
74
  outputs = model(img)
75
  probs = torch.sigmoid(outputs).cpu().numpy()[0]
76
 
77
- # Prepare results
78
  results = {class_names[i]: float(probs[i]) for i in range(len(class_names))}
79
  sorted_results = dict(sorted(results.items(), key=lambda item: item[1], reverse=True))
80
- top5 = {k: v for k, v in list(sorted_results.items())[:5]}
81
-
82
- return top5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
83
 
84
  # Gradio Interface
85
- interface = gr.Interface(
86
- fn=predict,
87
- inputs=gr.Image(type="pil"),
88
- outputs=gr.Label(num_top_classes=5),
89
- title="Chest X-ray Disease Classification",
90
- description="Upload a chest X-ray image (JPG, PNG, BMP, TIFF, etc.) to get disease predictions.\n\nTop 5 diseases are shown with their probability.",
91
- theme="default",
92
- allow_flagging="never"
93
- )
 
 
 
 
 
 
 
 
 
94
 
95
  if __name__ == "__main__":
96
- interface.launch()
 
4
  import torchvision.transforms as transforms
5
  from efficientnet_pytorch import EfficientNet
6
  from PIL import Image
7
+ from fpdf import FPDF
8
+ import os
9
+ from datetime import datetime
10
 
11
+ # Define HardSwish activation
12
  class HardSwish(nn.Module):
13
  def __init__(self):
14
  super(HardSwish, self).__init__()
 
16
  def forward(self, x):
17
  return x * (torch.clamp(x + 3, 0, 6) / 6)
18
 
19
+ # Model class
20
  class CustomEfficientNet(nn.Module):
21
  def __init__(self, num_classes):
22
  super(CustomEfficientNet, self).__init__()
 
37
  'Lung Lesion', 'Edema', 'Consolidation', 'Pneumonia', 'Atelectasis', 'Pneumothorax',
38
  'Pleural Effusion', 'Pleural Other', 'Fracture', 'Support Devices']
39
 
40
+ # Device config
41
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
42
 
43
  # Load the model
 
50
  model = model.to(device)
51
  model.eval()
52
 
53
+ # Transform
54
  transform = transforms.Compose([
55
  transforms.Resize((300, 300)),
56
  transforms.ToTensor(),
57
  transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
58
  ])
59
 
60
+ # Prediction and PDF generation
61
+ def predict_and_generate_pdf(image, patient_name, xray_date):
62
  if image is None:
63
  raise ValueError("❌ Error: No image uploaded.")
64
+ if not patient_name.strip():
65
+ raise ValueError("❌ Error: Patient name is required.")
66
+ if not xray_date:
67
+ raise ValueError("❌ Error: X-ray date is required.")
68
 
 
69
  if not isinstance(image, Image.Image):
70
  image = Image.fromarray(image)
71
 
72
  if image.mode != 'RGB':
73
  image = image.convert('RGB')
74
 
75
+ # Preprocess
76
  img = transform(image).unsqueeze(0).to(device)
77
 
78
+ # Predict
79
  with torch.no_grad():
80
  outputs = model(img)
81
  probs = torch.sigmoid(outputs).cpu().numpy()[0]
82
 
83
+ # Process results
84
  results = {class_names[i]: float(probs[i]) for i in range(len(class_names))}
85
  sorted_results = dict(sorted(results.items(), key=lambda item: item[1], reverse=True))
86
+ top5 = list(sorted_results.items())[:5]
87
+
88
+ # Comment Logic
89
+ top_label, top_prob = top5[0]
90
+ if top_label == "No Finding" and top_prob > 0.5:
91
+ comment = "✅ No major abnormal findings detected."
92
+ elif top_prob > 0.5:
93
+ comment = f"⚠️ High likelihood of {top_label}."
94
+ else:
95
+ comment = f"🔎 Possibility of {top_label}, but low confidence."
96
+
97
+ # Save thumbnail
98
+ image_path = "xray_thumbnail.jpg"
99
+ image_copy = image.copy()
100
+ image_copy.thumbnail((100, 100))
101
+ image_copy.save(image_path)
102
+
103
+ # Create PDF
104
+ pdf = FPDF()
105
+ pdf.add_page()
106
+ pdf.set_font("Arial", 'B', 18)
107
+ pdf.cell(0, 10, "Chest X-ray Analysis Report", ln=True, align='C')
108
+ pdf.ln(10)
109
+
110
+ # Patient Details
111
+ pdf.set_font("Arial", '', 12)
112
+ pdf.cell(0, 10, f"Patient Name: {patient_name}", ln=True)
113
+ pdf.cell(0, 10, f"X-ray Date: {xray_date}", ln=True)
114
+ pdf.cell(0, 10, f"Report Generated: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}", ln=True)
115
+ pdf.ln(10)
116
+
117
+ # X-ray Image
118
+ pdf.image(image_path, x=80, w=50)
119
+ pdf.ln(10)
120
+
121
+ # Top 5 Predictions
122
+ pdf.set_font("Arial", 'B', 14)
123
+ pdf.cell(0, 10, "Top 5 Predictions:", ln=True)
124
+ pdf.set_font("Arial", '', 12)
125
+
126
+ pdf.set_fill_color(230, 230, 230)
127
+ for disease, prob in top5:
128
+ pdf.cell(100, 10, disease, 1, 0, 'L', fill=True)
129
+ pdf.cell(40, 10, f"{prob*100:.2f}%", 1, 1, 'C', fill=True)
130
+
131
+ pdf.ln(10)
132
+
133
+ # Comments
134
+ pdf.set_font("Arial", 'B', 14)
135
+ pdf.cell(0, 10, "Doctor's Comment:", ln=True)
136
+ pdf.set_font("Arial", '', 12)
137
+ pdf.set_fill_color(240, 248, 255)
138
+ pdf.multi_cell(0, 10, comment, fill=True)
139
+
140
+ # Save PDF
141
+ output_pdf_path = "chest_xray_report.pdf"
142
+ pdf.output(output_pdf_path)
143
+
144
+ return output_pdf_path
145
 
146
  # Gradio Interface
147
+ with gr.Blocks(theme="default") as demo:
148
+ gr.Markdown("# 🩺 Chest X-ray Disease Classification App")
149
+ gr.Markdown("Upload a chest X-ray, enter patient information, and generate a detailed PDF report.")
150
+
151
+ with gr.Row():
152
+ with gr.Column():
153
+ image_input = gr.Image(type="pil", label="Upload Chest X-ray Image")
154
+ name_input = gr.Textbox(label="Patient Name")
155
+ date_input = gr.Date(label="Date of X-ray")
156
+ submit_button = gr.Button("Analyze & Generate PDF Report")
157
+ with gr.Column():
158
+ file_output = gr.File(label="Download Generated Report (PDF)")
159
+
160
+ submit_button.click(
161
+ fn=predict_and_generate_pdf,
162
+ inputs=[image_input, name_input, date_input],
163
+ outputs=file_output
164
+ )
165
 
166
  if __name__ == "__main__":
167
+ demo.launch()