resberry commited on
Commit
e30af19
·
verified ·
1 Parent(s): 0f9b243

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +86 -110
app.py CHANGED
@@ -7,84 +7,105 @@ from PIL import Image
7
  import os
8
  import time
9
 
10
- # Define the transform for the input image
 
 
11
  transform = transforms.Compose([
12
- transforms.Resize((150, 150)),
13
  transforms.ToTensor(),
14
- transforms.Normalize((0.5,), (0.5,))
 
 
 
15
  ])
16
 
17
- # Load the trained ResNet50 model
 
 
18
  class FineTunedResNet(nn.Module):
19
  def __init__(self, num_classes=4):
20
- super(FineTunedResNet, self).__init__()
21
- self.resnet = models.resnet50(weights=models.ResNet50_Weights.DEFAULT) # Load pre-trained ResNet50
 
 
22
 
23
- # Replace the fully connected layer with more layers and batch normalization
24
  self.resnet.fc = nn.Sequential(
25
- nn.Linear(self.resnet.fc.in_features, 1024), # First additional layer
26
  nn.BatchNorm1d(1024),
27
  nn.ReLU(),
28
  nn.Dropout(0.5),
29
- nn.Linear(1024, 512), # Second additional layer
 
30
  nn.BatchNorm1d(512),
31
  nn.ReLU(),
32
  nn.Dropout(0.5),
33
- nn.Linear(512, 256), # Third additional layer
 
34
  nn.BatchNorm1d(256),
35
  nn.ReLU(),
36
  nn.Dropout(0.5),
37
- nn.Linear(256, num_classes) # Output layer
 
38
  )
39
 
40
  def forward(self, x):
41
  return self.resnet(x)
42
 
43
- model = FineTunedResNet(num_classes=4)
44
- model_path = 'models/final_fine_tuned_resnet50.pth'
 
 
45
 
46
- if not os.path.exists(model_path):
47
- raise FileNotFoundError(f"The model file '{model_path}' does not exist. Please check the path.")
48
 
49
- model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
 
50
  model.eval()
 
 
 
 
 
 
 
 
 
 
 
51
 
52
- # Define a function to make predictions
53
- def predict(image):
54
- start_time = time.time() # Start the timer
55
- image = transform(image).unsqueeze(0) # Transform and add batch dimension
56
-
57
  with torch.no_grad():
58
  output = model(image)
59
- probabilities = F.softmax(output, dim=1)[0]
60
- top_prob, top_class = torch.topk(probabilities, 3)
61
- classes = ['🦠 COVID', '🫁 Normal', '🦠 Pneumonia', '🦠 TB'] # Adjust based on the classes in your model
62
-
63
- end_time = time.time() # End the timer
64
- prediction_time = end_time - start_time # Calculate the prediction time
65
-
66
- # Format the result string
67
- result = f"Top Predictions:\n"
68
- for i in range(top_prob.size(0)):
69
- result += f"{classes[top_class[i]]}: Score {top_prob[i].item()}\n"
70
- result += f"Prediction Time: {prediction_time:.2f} seconds"
71
-
72
  return result
73
 
74
- # Example images with labels
 
 
75
  examples = [
76
- ['examples/Pneumonia/02009view1_frontal.jpg', '🦠 Pneumonia'],
77
- ['examples/Pneumonia/02055view1_frontal.jpg', '🦠 Pneumonia'],
78
- ['examples/Pneumonia/03152view1_frontal.jpg', '🦠 Pneumonia'],
79
- ['examples/COVID/11547_2020_1200_Fig3_HTML-a.png', '🦠 COVID'],
80
- ['examples/COVID/11547_2020_1200_Fig3_HTML-b.png', '🦠 COVID'],
81
- ['examples/COVID/11547_2020_1203_Fig1_HTML-b.png', '🦠 COVID'],
82
- ['examples/Normal/06bc1cfe-23a0-43a4-a01b-dfa10314bbb0.jpg', '🫁 Normal'],
83
- ['examples/Normal/08ae6c0b-d044-4de2-a410-b3cf8dc65868.jpg', '🫁 Normal'],
84
- ['examples/Normal/IM-0178-0001.jpeg', '🫁 Normal']
85
  ]
86
 
87
- # Load visualization images
 
 
88
  visualization_images = [
89
  "pictures/1.png",
90
  "pictures/2.png",
@@ -93,86 +114,41 @@ visualization_images = [
93
  "pictures/5.png"
94
  ]
95
 
96
- # Function to display visualization images
97
  def display_visualizations():
98
- return [Image.open(image) for image in visualization_images]
99
-
100
- # Custom CSS to enhance appearance (injected via HTML)
101
- custom_css = """
102
- <style>
103
- body {
104
- font-family: 'Arial', sans-serif;
105
- background-color: #f5f5f5;
106
- }
107
- .gradio-container {
108
- background-color: #ffffff;
109
- border: 1px solid #e6e6e6;
110
- box-shadow: 0 4px 8px rgba(0, 0, 0, 0.1);
111
- border-radius: 10px;
112
- padding: 20px;
113
- }
114
- .gradio-title {
115
- color: #333333;
116
- font-weight: bold;
117
- font-size: 24px;
118
- margin-bottom: 10px;
119
- }
120
- .gradio-description {
121
- color: #666666;
122
- font-size: 16px;
123
- margin-bottom: 20px;
124
- }
125
- .gradio-image {
126
- border-radius: 10px;
127
- }
128
- .gradio-button {
129
- background-color: #007bff;
130
- color: #ffffff;
131
- border: none;
132
- padding: 10px 20px;
133
- border-radius: 5px;
134
- cursor: pointer;
135
- }
136
- .gradio-button:hover {
137
- background-color: #0056b3;
138
- }
139
- .gradio-label {
140
- color: #007bff;
141
- font-weight: bold;
142
- }
143
- </style>
144
- """
145
 
146
- # Create Gradio interfaces
 
 
147
  prediction_interface = gr.Interface(
148
  fn=predict,
149
- inputs=gr.Image(type="pil", label="Upload Chest X-ray Image"),
150
- outputs=gr.Label(label="Predicted Disease"),
151
  examples=examples,
 
152
  title="Lung Disease Detection XVI",
153
- description=f"""
154
- Upload a chest X-ray image to detect lung diseases such as 🦠 COVID-19, 🦠 Pneumonia, 🫁 Normal, or 🦠 TB.
155
- Use the example images to see how the model works.
156
- {custom_css}
157
- """
158
  )
159
 
160
  visualization_interface = gr.Interface(
161
  fn=display_visualizations,
162
  inputs=None,
163
- outputs=[gr.Image(type="pil", label=f"Visualization {i+1}") for i in range(len(visualization_images))],
164
- title="Model Performance Visualizations",
165
- description=f"""
166
- Here are some visualizations that depict the performance of the model during training and testing.
167
- {custom_css}
168
- """
169
  )
170
 
171
- # Combine interfaces into a tabbed interface
172
  app = gr.TabbedInterface(
173
  interface_list=[prediction_interface, visualization_interface],
174
  tab_names=["Predict", "Model Performance"]
175
  )
176
 
177
- # Launch the interface
178
- app.launch(share=True)
 
 
 
7
  import os
8
  import time
9
 
10
+ # =========================
11
+ # Image preprocessing
12
+ # =========================
13
  transform = transforms.Compose([
14
+ transforms.Resize((224, 224)), # Required for ResNet50
15
  transforms.ToTensor(),
16
+ transforms.Normalize(
17
+ mean=[0.485, 0.456, 0.406],
18
+ std=[0.229, 0.224, 0.225]
19
+ )
20
  ])
21
 
22
+ # =========================
23
+ # Model definition
24
+ # =========================
25
  class FineTunedResNet(nn.Module):
26
  def __init__(self, num_classes=4):
27
+ super().__init__()
28
+ self.resnet = models.resnet50(
29
+ weights=models.ResNet50_Weights.DEFAULT
30
+ )
31
 
 
32
  self.resnet.fc = nn.Sequential(
33
+ nn.Linear(self.resnet.fc.in_features, 1024),
34
  nn.BatchNorm1d(1024),
35
  nn.ReLU(),
36
  nn.Dropout(0.5),
37
+
38
+ nn.Linear(1024, 512),
39
  nn.BatchNorm1d(512),
40
  nn.ReLU(),
41
  nn.Dropout(0.5),
42
+
43
+ nn.Linear(512, 256),
44
  nn.BatchNorm1d(256),
45
  nn.ReLU(),
46
  nn.Dropout(0.5),
47
+
48
+ nn.Linear(256, num_classes)
49
  )
50
 
51
  def forward(self, x):
52
  return self.resnet(x)
53
 
54
+ # =========================
55
+ # Load model
56
+ # =========================
57
+ MODEL_PATH = "models/final_fine_tuned_resnet50.pth"
58
 
59
+ if not os.path.exists(MODEL_PATH):
60
+ raise FileNotFoundError(f"Model not found: {MODEL_PATH}")
61
 
62
+ model = FineTunedResNet(num_classes=4)
63
+ model.load_state_dict(torch.load(MODEL_PATH, map_location="cpu"))
64
  model.eval()
65
+ model.to("cpu")
66
+
67
+ CLASSES = ["🦠 COVID", "🫁 Normal", "🦠 Pneumonia", "🦠 TB"]
68
+
69
+ # =========================
70
+ # Prediction function
71
+ # =========================
72
+ def predict(image: Image.Image) -> str:
73
+ start = time.time()
74
+
75
+ image = transform(image).unsqueeze(0)
76
 
 
 
 
 
 
77
  with torch.no_grad():
78
  output = model(image)
79
+ probs = F.softmax(output, dim=1)[0]
80
+ top_probs, top_idxs = torch.topk(probs, 3)
81
+
82
+ elapsed = time.time() - start
83
+
84
+ result = "Top Predictions:\n\n"
85
+ for prob, idx in zip(top_probs, top_idxs):
86
+ result += f"{CLASSES[idx]} {prob.item():.4f}\n"
87
+
88
+ result += f"\n⏱️ Prediction Time: {elapsed:.2f} seconds"
 
 
 
89
  return result
90
 
91
+ # =========================
92
+ # Example images
93
+ # =========================
94
  examples = [
95
+ ["examples/Pneumonia/02009view1_frontal.jpg"],
96
+ ["examples/Pneumonia/02055view1_frontal.jpg"],
97
+ ["examples/Pneumonia/03152view1_frontal.jpg"],
98
+ ["examples/COVID/11547_2020_1200_Fig3_HTML-a.png"],
99
+ ["examples/COVID/11547_2020_1200_Fig3_HTML-b.png"],
100
+ ["examples/COVID/11547_2020_1203_Fig1_HTML-b.png"],
101
+ ["examples/Normal/06bc1cfe-23a0-43a4-a01b-dfa10314bbb0.jpg"],
102
+ ["examples/Normal/08ae6c0b-d044-4de2-a410-b3cf8dc65868.jpg"],
103
+ ["examples/Normal/IM-0178-0001.jpeg"]
104
  ]
105
 
106
+ # =========================
107
+ # Visualization images
108
+ # =========================
109
  visualization_images = [
110
  "pictures/1.png",
111
  "pictures/2.png",
 
114
  "pictures/5.png"
115
  ]
116
 
 
117
  def display_visualizations():
118
+ return [Image.open(path) for path in visualization_images]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
119
 
120
+ # =========================
121
+ # Gradio interfaces
122
+ # =========================
123
  prediction_interface = gr.Interface(
124
  fn=predict,
125
+ inputs=gr.Image(type="pil", label="Upload Chest X-ray"),
126
+ outputs=gr.Textbox(label="Prediction Result"),
127
  examples=examples,
128
+ cache_examples=False, # IMPORTANT for HF Spaces
129
  title="Lung Disease Detection XVI",
130
+ description="""
131
+ Upload a chest X-ray image to detect:
132
+ 🦠 COVID-19 🦠 Pneumonia 🫁 Normal 🦠 Tuberculosis
133
+ """
 
134
  )
135
 
136
  visualization_interface = gr.Interface(
137
  fn=display_visualizations,
138
  inputs=None,
139
+ outputs=[
140
+ gr.Image(type="pil", label=f"Visualization {i+1}")
141
+ for i in range(len(visualization_images))
142
+ ],
143
+ title="Model Performance Visualizations"
 
144
  )
145
 
 
146
  app = gr.TabbedInterface(
147
  interface_list=[prediction_interface, visualization_interface],
148
  tab_names=["Predict", "Model Performance"]
149
  )
150
 
151
+ # =========================
152
+ # Launch (HF Spaces safe)
153
+ # =========================
154
+ app.launch()