Mrhuman1 commited on
Commit
4f3fb6c
ยท
verified ยท
1 Parent(s): 8bb8257

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +41 -27
app.py CHANGED
@@ -4,8 +4,9 @@ import torch.nn as nn
4
  import torchvision.transforms as transforms
5
  from efficientnet_pytorch import EfficientNet
6
  from PIL import Image
 
7
 
8
- # Define the HardSwish activation function
9
  class HardSwish(nn.Module):
10
  def __init__(self):
11
  super(HardSwish, self).__init__()
@@ -13,7 +14,7 @@ class HardSwish(nn.Module):
13
  def forward(self, x):
14
  return x * (torch.clamp(x + 3, 0, 6) / 6)
15
 
16
- # Define the model class
17
  class CustomEfficientNet(nn.Module):
18
  def __init__(self, num_classes):
19
  super(CustomEfficientNet, self).__init__()
@@ -30,14 +31,16 @@ class CustomEfficientNet(nn.Module):
30
  return self.model(x)
31
 
32
  # Class names
33
- class_names = ['No Finding', 'Enlarged Cardiomediastinum', 'Cardiomegaly', 'Lung Opacity',
34
- 'Lung Lesion', 'Edema', 'Consolidation', 'Pneumonia', 'Atelectasis', 'Pneumothorax',
35
- 'Pleural Effusion', 'Pleural Other', 'Fracture', 'Support Devices']
 
 
36
 
37
  # Device configuration
38
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
39
 
40
- # Load the model
41
  model = CustomEfficientNet(num_classes=14)
42
  checkpoint = torch.load('Final_global_model.pth.tar', map_location=device)
43
  if 'state_dict' in checkpoint:
@@ -47,7 +50,7 @@ else:
47
  model = model.to(device)
48
  model.eval()
49
 
50
- # Transformations for input image
51
  transform = transforms.Compose([
52
  transforms.Resize((300, 300)),
53
  transforms.ToTensor(),
@@ -55,42 +58,53 @@ transform = transforms.Compose([
55
  ])
56
 
57
  # Prediction function
58
- def predict(image):
59
  if image is None:
60
  raise ValueError("โŒ Error: No image uploaded.")
61
 
62
- # Ensure image is in RGB mode
63
  if not isinstance(image, Image.Image):
64
  image = Image.fromarray(image)
65
 
66
  if image.mode != 'RGB':
67
  image = image.convert('RGB')
68
 
69
- # Preprocess the image
70
  img = transform(image).unsqueeze(0).to(device)
71
 
72
- # Prediction
73
  with torch.no_grad():
74
  outputs = model(img)
75
  probs = torch.sigmoid(outputs).cpu().numpy()[0]
76
 
77
- # Prepare results
78
  results = {class_names[i]: float(probs[i]) for i in range(len(class_names))}
79
  sorted_results = dict(sorted(results.items(), key=lambda item: item[1], reverse=True))
80
- top5 = {k: v for k, v in list(sorted_results.items())[:5]}
81
-
82
- return top5
83
-
84
- # Gradio Interface
85
- interface = gr.Interface(
86
- fn=predict,
87
- inputs=gr.Image(type="pil"),
88
- outputs=gr.Label(num_top_classes=5),
89
- title="Chest X-ray Disease Classification",
90
- description="Upload a chest X-ray image (JPG, PNG, BMP, TIFF, etc.) to get disease predictions.\n\nTop 5 diseases are shown with their probability.",
91
- theme="default",
92
- allow_flagging="never"
93
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
94
 
95
  if __name__ == "__main__":
96
- interface.launch()
 
4
  import torchvision.transforms as transforms
5
  from efficientnet_pytorch import EfficientNet
6
  from PIL import Image
7
+ from datetime import datetime
8
 
9
+ # Define HardSwish activation
10
  class HardSwish(nn.Module):
11
  def __init__(self):
12
  super(HardSwish, self).__init__()
 
14
  def forward(self, x):
15
  return x * (torch.clamp(x + 3, 0, 6) / 6)
16
 
17
+ # Define custom EfficientNet model
18
  class CustomEfficientNet(nn.Module):
19
  def __init__(self, num_classes):
20
  super(CustomEfficientNet, self).__init__()
 
31
  return self.model(x)
32
 
33
  # Class names
34
+ class_names = [
35
+ 'No Finding', 'Enlarged Cardiomediastinum', 'Cardiomegaly', 'Lung Opacity',
36
+ 'Lung Lesion', 'Edema', 'Consolidation', 'Pneumonia', 'Atelectasis', 'Pneumothorax',
37
+ 'Pleural Effusion', 'Pleural Other', 'Fracture', 'Support Devices'
38
+ ]
39
 
40
  # Device configuration
41
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
42
 
43
+ # Load model
44
  model = CustomEfficientNet(num_classes=14)
45
  checkpoint = torch.load('Final_global_model.pth.tar', map_location=device)
46
  if 'state_dict' in checkpoint:
 
50
  model = model.to(device)
51
  model.eval()
52
 
53
+ # Transformations
54
  transform = transforms.Compose([
55
  transforms.Resize((300, 300)),
56
  transforms.ToTensor(),
 
58
  ])
59
 
60
  # Prediction function
61
+ def predict(patient_name, scan_date, image):
62
  if image is None:
63
  raise ValueError("โŒ Error: No image uploaded.")
64
 
65
+ # Ensure image is RGB
66
  if not isinstance(image, Image.Image):
67
  image = Image.fromarray(image)
68
 
69
  if image.mode != 'RGB':
70
  image = image.convert('RGB')
71
 
 
72
  img = transform(image).unsqueeze(0).to(device)
73
 
 
74
  with torch.no_grad():
75
  outputs = model(img)
76
  probs = torch.sigmoid(outputs).cpu().numpy()[0]
77
 
 
78
  results = {class_names[i]: float(probs[i]) for i in range(len(class_names))}
79
  sorted_results = dict(sorted(results.items(), key=lambda item: item[1], reverse=True))
80
+ top5 = {k: f"{v*100:.2f}%" for k, v in list(sorted_results.items())[:5]}
81
+
82
+ summary = f"๐Ÿ“‹ **Patient Name**: {patient_name}\n๐Ÿ“… **Scan Date**: {scan_date.strftime('%Y-%m-%d')}\n\n### Top 5 Predictions"
83
+ return summary, top5
84
+
85
+ # Gradio UI
86
+ with gr.Blocks(theme=gr.themes.Soft()) as demo:
87
+ gr.Markdown(
88
+ """
89
+ # ๐Ÿฉบ Chest X-ray Disease Classifier
90
+ Upload a chest X-ray and get the top 5 predicted diseases with probability scores.
91
+ """
92
+ )
93
+ with gr.Row():
94
+ with gr.Column():
95
+ patient_name = gr.Textbox(label="Patient Name", placeholder="Enter full name...")
96
+ scan_date = gr.Date(label="Scan Date", value=datetime.today)
97
+ image = gr.Image(label="Chest X-ray Image", type="pil")
98
+ predict_button = gr.Button("๐Ÿ” Predict")
99
+ with gr.Column():
100
+ summary = gr.Markdown()
101
+ output = gr.Label(num_top_classes=5)
102
+
103
+ predict_button.click(
104
+ predict,
105
+ inputs=[patient_name, scan_date, image],
106
+ outputs=[summary, output]
107
+ )
108
 
109
  if __name__ == "__main__":
110
+ demo.launch()