Mrhuman1 commited on
Commit
2c1ab36
·
verified ·
1 Parent(s): b7a8f2d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +17 -71
app.py CHANGED
@@ -1,14 +1,10 @@
1
- import gradio as gr
2
  import torch
3
  import torch.nn as nn
4
  import torchvision.transforms as transforms
5
  from efficientnet_pytorch import EfficientNet
6
  from PIL import Image
7
- from fpdf import FPDF
8
- import os
9
- from datetime import datetime
10
 
11
- # Define HardSwish activation function
12
  class HardSwish(nn.Module):
13
  def __init__(self):
14
  super(HardSwish, self).__init__()
@@ -16,7 +12,7 @@ class HardSwish(nn.Module):
16
  def forward(self, x):
17
  return x * (torch.clamp(x + 3, 0, 6) / 6)
18
 
19
- # Define Custom EfficientNet
20
  class CustomEfficientNet(nn.Module):
21
  def __init__(self, num_classes):
22
  super(CustomEfficientNet, self).__init__()
@@ -40,7 +36,7 @@ class_names = ['No Finding', 'Enlarged Cardiomediastinum', 'Cardiomegaly', 'Lung
40
  # Device configuration
41
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
42
 
43
- # Load model
44
  model = CustomEfficientNet(num_classes=14)
45
  checkpoint = torch.load('Final_global_model.pth.tar', map_location=device)
46
  if 'state_dict' in checkpoint:
@@ -50,97 +46,47 @@ else:
50
  model = model.to(device)
51
  model.eval()
52
 
53
- # Transformations
54
  transform = transforms.Compose([
55
  transforms.Resize((300, 300)),
56
  transforms.ToTensor(),
57
  transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
58
  ])
59
 
60
- # Prediction and PDF generation function
61
- def predict(image, patient_name, xray_date):
62
  if image is None:
63
  raise ValueError("❌ Error: No image uploaded.")
64
- if not patient_name.strip():
65
- raise ValueError("❌ Error: Patient name is required.")
66
- if not xray_date.strip():
67
- raise ValueError("❌ Error: X-ray date is required.")
68
 
69
- # Ensure correct image mode
70
  if not isinstance(image, Image.Image):
71
  image = Image.fromarray(image)
72
 
73
  if image.mode != 'RGB':
74
  image = image.convert('RGB')
75
 
76
- # Preprocess
77
  img = transform(image).unsqueeze(0).to(device)
78
 
79
- # Predict
80
  with torch.no_grad():
81
  outputs = model(img)
82
  probs = torch.sigmoid(outputs).cpu().numpy()[0]
83
 
 
84
  results = {class_names[i]: float(probs[i]) for i in range(len(class_names))}
85
  sorted_results = dict(sorted(results.items(), key=lambda item: item[1], reverse=True))
86
- top5 = list(sorted_results.items())[:5]
87
-
88
- # Doctor's comment
89
- top_label, top_prob = top5[0]
90
- if top_label == "No Finding" and top_prob > 0.5:
91
- comment = "✅ No major abnormal findings detected."
92
- elif top_prob > 0.5:
93
- comment = f"⚠️ High likelihood of {top_label}."
94
- else:
95
- comment = f"🔎 Possible {top_label}, but confidence is low."
96
-
97
- # Generate PDF report
98
- pdf = FPDF()
99
- pdf.add_page()
100
- pdf.set_font("Arial", 'B', 16)
101
- pdf.cell(0, 10, "Chest X-ray Disease Report", ln=True, align='C')
102
- pdf.ln(10)
103
-
104
- # Patient Details
105
- pdf.set_font("Arial", '', 12)
106
- pdf.cell(0, 10, f"Patient Name: {patient_name}", ln=True)
107
- pdf.cell(0, 10, f"X-ray Date: {xray_date}", ln=True)
108
- pdf.cell(0, 10, f"Report Date: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}", ln=True)
109
- pdf.ln(10)
110
-
111
- # Top 5 Predictions
112
- pdf.set_font("Arial", 'B', 14)
113
- pdf.cell(0, 10, "Top 5 Predictions:", ln=True)
114
- pdf.set_font("Arial", '', 12)
115
-
116
- for disease, prob in top5:
117
- pdf.cell(0, 10, f"{disease}: {prob*100:.2f}%", ln=True)
118
-
119
- pdf.ln(10)
120
-
121
- # Comment Section
122
- pdf.set_font("Arial", 'B', 14)
123
- pdf.cell(0, 10, "Doctor's Comment:", ln=True)
124
- pdf.set_font("Arial", '', 12)
125
- pdf.multi_cell(0, 10, comment)
126
-
127
- # Save PDF
128
- output_pdf_path = "chest_xray_report.pdf"
129
- pdf.output(output_pdf_path)
130
-
131
- return output_pdf_path
132
 
133
  # Gradio Interface
134
  interface = gr.Interface(
135
  fn=predict,
136
- inputs=[
137
- gr.Image(type="pil", label="Upload Chest X-ray Image"),
138
- gr.Textbox(label="Patient Name"),
139
- gr.Textbox(label="Date of X-ray (YYYY-MM-DD)", placeholder="e.g. 2025-04-27")
140
- ],
141
- outputs=gr.File(label="Download PDF Report"),
142
- title="Chest X-ray Disease Classification with Report",
143
- description="Upload an X-ray, enter patient details, and download a detailed PDF report.",
144
  theme="default",
145
  allow_flagging="never"
146
  )
 
 
1
  import torch
2
  import torch.nn as nn
3
  import torchvision.transforms as transforms
4
  from efficientnet_pytorch import EfficientNet
5
  from PIL import Image
 
 
 
6
 
7
+ # Define the HardSwish activation function
8
  class HardSwish(nn.Module):
9
  def __init__(self):
10
  super(HardSwish, self).__init__()
 
12
  def forward(self, x):
13
  return x * (torch.clamp(x + 3, 0, 6) / 6)
14
 
15
+ # Define the model class
16
  class CustomEfficientNet(nn.Module):
17
  def __init__(self, num_classes):
18
  super(CustomEfficientNet, self).__init__()
 
36
  # Device configuration
37
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
38
 
39
+ # Load the model
40
  model = CustomEfficientNet(num_classes=14)
41
  checkpoint = torch.load('Final_global_model.pth.tar', map_location=device)
42
  if 'state_dict' in checkpoint:
 
46
  model = model.to(device)
47
  model.eval()
48
 
49
+ # Transformations for input image
50
  transform = transforms.Compose([
51
  transforms.Resize((300, 300)),
52
  transforms.ToTensor(),
53
  transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
54
  ])
55
 
56
+ # Prediction function
57
+ def predict(image):
58
  if image is None:
59
  raise ValueError("❌ Error: No image uploaded.")
 
 
 
 
60
 
61
+ # Ensure image is in RGB mode
62
  if not isinstance(image, Image.Image):
63
  image = Image.fromarray(image)
64
 
65
  if image.mode != 'RGB':
66
  image = image.convert('RGB')
67
 
68
+ # Preprocess the image
69
  img = transform(image).unsqueeze(0).to(device)
70
 
71
+ # Prediction
72
  with torch.no_grad():
73
  outputs = model(img)
74
  probs = torch.sigmoid(outputs).cpu().numpy()[0]
75
 
76
+ # Prepare results
77
  results = {class_names[i]: float(probs[i]) for i in range(len(class_names))}
78
  sorted_results = dict(sorted(results.items(), key=lambda item: item[1], reverse=True))
79
+ top5 = {k: v for k, v in list(sorted_results.items())[:5]}
80
+
81
+ return top5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
82
 
83
  # Gradio Interface
84
  interface = gr.Interface(
85
  fn=predict,
86
+ inputs=gr.Image(type="pil"),
87
+ outputs=gr.Label(num_top_classes=5),
88
+ title="Chest X-ray Disease Classification",
89
+ description="Upload a chest X-ray image (JPG, PNG, BMP, TIFF, etc.) to get disease predictions.\n\nTop 5 diseases are shown with their probability.",
 
 
 
 
90
  theme="default",
91
  allow_flagging="never"
92
  )