Spaces:
Sleeping
Sleeping
Commit
·
ea236cf
1
Parent(s):
ed114bb
update app.py 2023-08-31-08:12
Browse files
app.py
CHANGED
|
@@ -8,14 +8,13 @@ from typing import Tuple, Dict
|
|
| 8 |
import torchvision
|
| 9 |
|
| 10 |
from torch import nn
|
| 11 |
-
|
| 12 |
from torchvision.models import densenet121
|
| 13 |
|
| 14 |
-
def create_densenet121_model(num_classes: int =
|
| 15 |
"""Creates a DenseNet121 model and transforms."""
|
| 16 |
|
| 17 |
# Create DenseNet121 model
|
| 18 |
-
model = densenet121(
|
| 19 |
|
| 20 |
# Freeze all layers in base model
|
| 21 |
for param in model.parameters():
|
|
@@ -25,7 +24,6 @@ def create_densenet121_model(num_classes: int = 1, seed: int = 42):
|
|
| 25 |
torch.manual_seed(seed)
|
| 26 |
model.classifier = nn.Linear(model.classifier.in_features, num_classes)
|
| 27 |
|
| 28 |
-
# You might want to use the appropriate transforms for densenet121 here
|
| 29 |
transforms = torchvision.transforms.Compose([
|
| 30 |
torchvision.transforms.Resize(224),
|
| 31 |
torchvision.transforms.ToTensor(),
|
|
@@ -35,26 +33,15 @@ def create_densenet121_model(num_classes: int = 1, seed: int = 42):
|
|
| 35 |
return model, transforms
|
| 36 |
|
| 37 |
# Create densenet121 model
|
| 38 |
-
densenet, densenet_transforms = create_densenet121_model(
|
| 39 |
|
| 40 |
# Load saved weights
|
| 41 |
-
# densenet.load_state_dict(torch.load("FL_global_model.pt", map_location=torch.device("cpu")))
|
| 42 |
state_dict = torch.load("FL_global_model.pt", map_location=torch.device("cpu"))
|
| 43 |
-
print("==============")
|
| 44 |
-
print(state_dict.keys())
|
| 45 |
-
print("==============")
|
| 46 |
model_weights = state_dict["model"]
|
| 47 |
-
densenet.load_state_dict(model_weights,strict=
|
| 48 |
-
'''
|
| 49 |
-
weights = {k: torch.from_numpy(v).to(self.device) if isinstance(v, np.ndarray) else v.to(self.device) for k, v in weights.items()}
|
| 50 |
-
# creat new state_dict and del fc.weight andfc.bias
|
| 51 |
-
new_state_dict = {k: v for k, v in weights.items() if k not in ["fc.weight", "fc.bias"]}
|
| 52 |
-
self.model.load_state_dict(new_state_dict, strict=False)
|
| 53 |
-
'''
|
| 54 |
|
| 55 |
def predict(img) -> Tuple[Dict, float]:
|
| 56 |
-
"""Transforms and performs a prediction on img and returns prediction and time taken.
|
| 57 |
-
"""
|
| 58 |
# Start the timer
|
| 59 |
start_time = timer()
|
| 60 |
|
|
@@ -64,12 +51,12 @@ def predict(img) -> Tuple[Dict, float]:
|
|
| 64 |
# Put model into evaluation mode and turn on inference mode
|
| 65 |
densenet.eval()
|
| 66 |
with torch.inference_mode():
|
| 67 |
-
|
| 68 |
-
pred_probs = torch.sigmoid(densenet(img)).squeeze()
|
| 69 |
|
| 70 |
-
# Create a prediction label and prediction probability dictionary for each prediction class (this is the required format for Gradio's output parameter)
|
| 71 |
pred_labels_and_probs = {
|
| 72 |
-
'Normal':
|
|
|
|
|
|
|
| 73 |
|
| 74 |
# Calculate the prediction time
|
| 75 |
pred_time = round(timer() - start_time, 5)
|
|
@@ -77,23 +64,21 @@ def predict(img) -> Tuple[Dict, float]:
|
|
| 77 |
# Return the prediction dictionary and prediction time
|
| 78 |
return pred_labels_and_probs, pred_time
|
| 79 |
|
| 80 |
-
|
| 81 |
example_list = [[f"examples/example{i+1}.jpg"] for i in range(3)]
|
| 82 |
-
|
| 83 |
title = "ChestXray Classification"
|
| 84 |
-
description = "
|
| 85 |
article = "Created at (https://github.com/azizche/chest_xray_Classification)."
|
| 86 |
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
|
| 97 |
|
| 98 |
# Launch the demo!
|
| 99 |
demo.launch()
|
|
|
|
| 8 |
import torchvision
|
| 9 |
|
| 10 |
from torch import nn
|
|
|
|
| 11 |
from torchvision.models import densenet121
|
| 12 |
|
| 13 |
+
def create_densenet121_model(num_classes: int = 2, seed: int = 42):
|
| 14 |
"""Creates a DenseNet121 model and transforms."""
|
| 15 |
|
| 16 |
# Create DenseNet121 model
|
| 17 |
+
model = densenet121(weights=None) # Set to None since we will be loading our own weights
|
| 18 |
|
| 19 |
# Freeze all layers in base model
|
| 20 |
for param in model.parameters():
|
|
|
|
| 24 |
torch.manual_seed(seed)
|
| 25 |
model.classifier = nn.Linear(model.classifier.in_features, num_classes)
|
| 26 |
|
|
|
|
| 27 |
transforms = torchvision.transforms.Compose([
|
| 28 |
torchvision.transforms.Resize(224),
|
| 29 |
torchvision.transforms.ToTensor(),
|
|
|
|
| 33 |
return model, transforms
|
| 34 |
|
| 35 |
# Create densenet121 model
|
| 36 |
+
densenet, densenet_transforms = create_densenet121_model()
|
| 37 |
|
| 38 |
# Load saved weights
|
|
|
|
| 39 |
state_dict = torch.load("FL_global_model.pt", map_location=torch.device("cpu"))
|
|
|
|
|
|
|
|
|
|
| 40 |
model_weights = state_dict["model"]
|
| 41 |
+
densenet.load_state_dict(model_weights,strict=True) # Set strict to True since we now expect it to match
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 42 |
|
| 43 |
def predict(img) -> Tuple[Dict, float]:
|
| 44 |
+
"""Transforms and performs a prediction on img and returns prediction and time taken."""
|
|
|
|
| 45 |
# Start the timer
|
| 46 |
start_time = timer()
|
| 47 |
|
|
|
|
| 51 |
# Put model into evaluation mode and turn on inference mode
|
| 52 |
densenet.eval()
|
| 53 |
with torch.inference_mode():
|
| 54 |
+
pred_probs = torch.softmax(densenet(img), dim=1).squeeze()
|
|
|
|
| 55 |
|
|
|
|
| 56 |
pred_labels_and_probs = {
|
| 57 |
+
'Normal': pred_probs[0].item(),
|
| 58 |
+
'Nodules': pred_probs[1].item()
|
| 59 |
+
}
|
| 60 |
|
| 61 |
# Calculate the prediction time
|
| 62 |
pred_time = round(timer() - start_time, 5)
|
|
|
|
| 64 |
# Return the prediction dictionary and prediction time
|
| 65 |
return pred_labels_and_probs, pred_time
|
| 66 |
|
|
|
|
| 67 |
example_list = [[f"examples/example{i+1}.jpg"] for i in range(3)]
|
| 68 |
+
|
| 69 |
title = "ChestXray Classification"
|
| 70 |
+
description = "A Densenet121 computer vision model to classify images of Xray Chest images as Normal or Nodules."
|
| 71 |
article = "Created at (https://github.com/azizche/chest_xray_Classification)."
|
| 72 |
|
| 73 |
+
demo = gr.Interface(
|
| 74 |
+
fn=predict,
|
| 75 |
+
inputs=gr.Image(type="pil"),
|
| 76 |
+
outputs=[gr.Label(num_top_classes=2, label="Predictions"), gr.Number(label="Prediction time (s)")],
|
| 77 |
+
examples=example_list,
|
| 78 |
+
title=title,
|
| 79 |
+
description=description,
|
| 80 |
+
article=article
|
| 81 |
+
)
|
|
|
|
| 82 |
|
| 83 |
# Launch the demo!
|
| 84 |
demo.launch()
|