Mrhuman1 commited on
Commit
2e8579c
·
verified ·
1 Parent(s): c7883ae

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +23 -81
app.py CHANGED
@@ -4,12 +4,8 @@ import torch.nn as nn
4
  import torchvision.transforms as transforms
5
  from efficientnet_pytorch import EfficientNet
6
  from PIL import Image
7
- import numpy as np
8
- import os
9
- from fpdf import FPDF
10
- from datetime import datetime
11
 
12
- # Define HardSwish activation
13
  class HardSwish(nn.Module):
14
  def __init__(self):
15
  super(HardSwish, self).__init__()
@@ -17,7 +13,7 @@ class HardSwish(nn.Module):
17
  def forward(self, x):
18
  return x * (torch.clamp(x + 3, 0, 6) / 6)
19
 
20
- # Define model class
21
  class CustomEfficientNet(nn.Module):
22
  def __init__(self, num_classes):
23
  super(CustomEfficientNet, self).__init__()
@@ -38,7 +34,7 @@ class_names = ['No Finding', 'Enlarged Cardiomediastinum', 'Cardiomegaly', 'Lung
38
  'Lung Lesion', 'Edema', 'Consolidation', 'Pneumonia', 'Atelectasis', 'Pneumothorax',
39
  'Pleural Effusion', 'Pleural Other', 'Fracture', 'Support Devices']
40
 
41
- # Device config
42
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
43
 
44
  # Load the model
@@ -51,104 +47,50 @@ else:
51
  model = model.to(device)
52
  model.eval()
53
 
54
- # Transform
55
  transform = transforms.Compose([
56
  transforms.Resize((300, 300)),
57
  transforms.ToTensor(),
58
  transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
59
  ])
60
 
61
- # Prediction and PDF generation
62
- def predict_and_generate_pdf(image, patient_name, xray_date):
63
  if image is None:
64
  raise ValueError("❌ Error: No image uploaded.")
65
 
 
66
  if not isinstance(image, Image.Image):
67
  image = Image.fromarray(image)
68
 
69
  if image.mode != 'RGB':
70
  image = image.convert('RGB')
71
 
72
- # Preprocess
73
  img = transform(image).unsqueeze(0).to(device)
74
 
75
- # Predict
76
  with torch.no_grad():
77
  outputs = model(img)
78
  probs = torch.sigmoid(outputs).cpu().numpy()[0]
79
 
80
- # Process results
81
  results = {class_names[i]: float(probs[i]) for i in range(len(class_names))}
82
  sorted_results = dict(sorted(results.items(), key=lambda item: item[1], reverse=True))
83
- top5 = list(sorted_results.items())[:5]
84
-
85
- # Generate Comment
86
- top_label, top_prob = top5[0]
87
- if top_label == "No Finding" and top_prob > 0.5:
88
- comment = "No major abnormal findings detected."
89
- elif top_prob > 0.5:
90
- comment = f"High likelihood of {top_label}."
91
- else:
92
- comment = f"Possibility of {top_label}, but with low confidence."
93
-
94
- # Save thumbnail
95
- image_path = "xray_temp.jpg"
96
- image.thumbnail((200, 200))
97
- image.save(image_path)
98
-
99
- # Create PDF
100
- pdf = FPDF()
101
- pdf.add_page()
102
- pdf.set_font("Arial", 'B', 16)
103
- pdf.cell(0, 10, "Chest X-ray Report", ln=True, align='C')
104
- pdf.ln(10)
105
-
106
- pdf.set_font("Arial", '', 12)
107
- pdf.cell(0, 10, f"Patient Name: {patient_name}", ln=True)
108
- pdf.cell(0, 10, f"X-ray Date: {xray_date}", ln=True)
109
- pdf.cell(0, 10, f"Report Generated: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}", ln=True)
110
- pdf.ln(10)
111
-
112
- pdf.image(image_path, w=80)
113
- pdf.ln(10)
114
-
115
- pdf.set_font("Arial", 'B', 14)
116
- pdf.cell(0, 10, "Top 5 Predictions:", ln=True)
117
- pdf.set_font("Arial", '', 12)
118
- for disease, prob in top5:
119
- pdf.cell(0, 10, f"{disease}: {prob*100:.2f}%", ln=True)
120
-
121
- pdf.ln(10)
122
- pdf.set_font("Arial", 'B', 14)
123
- pdf.cell(0, 10, "Comment:", ln=True)
124
- pdf.set_font("Arial", '', 12)
125
- pdf.multi_cell(0, 10, comment)
126
-
127
- # Save PDF
128
- pdf_output_path = "report.pdf"
129
- pdf.output(pdf_output_path)
130
-
131
- return pdf_output_path
132
 
133
  # Gradio Interface
134
- with gr.Blocks(theme="default") as demo:
135
- gr.Markdown("# 🩺 Chest X-ray Disease Classification")
136
- gr.Markdown("Upload a chest X-ray, enter patient's name and date, and download a PDF report.")
137
-
138
- with gr.Row():
139
- with gr.Column():
140
- image_input = gr.Image(type="pil", label="Upload Chest X-ray")
141
- name_input = gr.Textbox(label="Patient Name")
142
- date_input = gr.Textbox(label="X-ray Date (YYYY-MM-DD)", placeholder="2025-04-27")
143
- submit_btn = gr.Button("Analyze & Generate Report")
144
- with gr.Column():
145
- file_output = gr.File(label="Download Report (PDF)")
146
-
147
- submit_btn.click(
148
- fn=predict_and_generate_pdf,
149
- inputs=[image_input, name_input, date_input],
150
- outputs=file_output
151
- )
152
 
153
  if __name__ == "__main__":
154
- demo.launch()
 
4
  import torchvision.transforms as transforms
5
  from efficientnet_pytorch import EfficientNet
6
  from PIL import Image
 
 
 
 
7
 
8
+ # Define the HardSwish activation function
9
  class HardSwish(nn.Module):
10
  def __init__(self):
11
  super(HardSwish, self).__init__()
 
13
  def forward(self, x):
14
  return x * (torch.clamp(x + 3, 0, 6) / 6)
15
 
16
+ # Define the model class
17
  class CustomEfficientNet(nn.Module):
18
  def __init__(self, num_classes):
19
  super(CustomEfficientNet, self).__init__()
 
34
  'Lung Lesion', 'Edema', 'Consolidation', 'Pneumonia', 'Atelectasis', 'Pneumothorax',
35
  'Pleural Effusion', 'Pleural Other', 'Fracture', 'Support Devices']
36
 
37
+ # Device configuration
38
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
39
 
40
  # Load the model
 
47
  model = model.to(device)
48
  model.eval()
49
 
50
+ # Transformations for input image
51
  transform = transforms.Compose([
52
  transforms.Resize((300, 300)),
53
  transforms.ToTensor(),
54
  transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
55
  ])
56
 
57
+ # Prediction function
58
+ def predict(image):
59
  if image is None:
60
  raise ValueError("❌ Error: No image uploaded.")
61
 
62
+ # Ensure image is in RGB mode
63
  if not isinstance(image, Image.Image):
64
  image = Image.fromarray(image)
65
 
66
  if image.mode != 'RGB':
67
  image = image.convert('RGB')
68
 
69
+ # Preprocess the image
70
  img = transform(image).unsqueeze(0).to(device)
71
 
72
+ # Prediction
73
  with torch.no_grad():
74
  outputs = model(img)
75
  probs = torch.sigmoid(outputs).cpu().numpy()[0]
76
 
77
+ # Prepare results
78
  results = {class_names[i]: float(probs[i]) for i in range(len(class_names))}
79
  sorted_results = dict(sorted(results.items(), key=lambda item: item[1], reverse=True))
80
+ top5 = {k: v for k, v in list(sorted_results.items())[:5]}
81
+
82
+ return top5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
83
 
84
  # Gradio Interface
85
+ interface = gr.Interface(
86
+ fn=predict,
87
+ inputs=gr.Image(type="pil"),
88
+ outputs=gr.Label(num_top_classes=5),
89
+ title="Chest X-ray Disease Classification",
90
+ description="Upload a chest X-ray image (JPG, PNG, BMP, TIFF, etc.) to get disease predictions.\n\nTop 5 diseases are shown with their probability.",
91
+ theme="default",
92
+ allow_flagging="never"
93
+ )
 
 
 
 
 
 
 
 
 
94
 
95
  if __name__ == "__main__":
96
+ interface.launch()