hamsteryang commited on
Commit
ea236cf
·
1 Parent(s): ed114bb

update app.py 2023-08-31-08:12

Browse files
Files changed (1) hide show
  1. app.py +20 -35
app.py CHANGED
@@ -8,14 +8,13 @@ from typing import Tuple, Dict
8
  import torchvision
9
 
10
  from torch import nn
11
-
12
  from torchvision.models import densenet121
13
 
14
- def create_densenet121_model(num_classes: int = 1, seed: int = 42):
15
  """Creates a DenseNet121 model and transforms."""
16
 
17
  # Create DenseNet121 model
18
- model = densenet121(pretrained=False) # Set to False since we will be loading our own weights
19
 
20
  # Freeze all layers in base model
21
  for param in model.parameters():
@@ -25,7 +24,6 @@ def create_densenet121_model(num_classes: int = 1, seed: int = 42):
25
  torch.manual_seed(seed)
26
  model.classifier = nn.Linear(model.classifier.in_features, num_classes)
27
 
28
- # You might want to use the appropriate transforms for densenet121 here
29
  transforms = torchvision.transforms.Compose([
30
  torchvision.transforms.Resize(224),
31
  torchvision.transforms.ToTensor(),
@@ -35,26 +33,15 @@ def create_densenet121_model(num_classes: int = 1, seed: int = 42):
35
  return model, transforms
36
 
37
  # Create densenet121 model
38
- densenet, densenet_transforms = create_densenet121_model(num_classes=1)
39
 
40
  # Load saved weights
41
- # densenet.load_state_dict(torch.load("FL_global_model.pt", map_location=torch.device("cpu")))
42
  state_dict = torch.load("FL_global_model.pt", map_location=torch.device("cpu"))
43
- print("==============")
44
- print(state_dict.keys())
45
- print("==============")
46
  model_weights = state_dict["model"]
47
- densenet.load_state_dict(model_weights,strict=False)
48
- '''
49
- weights = {k: torch.from_numpy(v).to(self.device) if isinstance(v, np.ndarray) else v.to(self.device) for k, v in weights.items()}
50
- # creat new state_dict and del fc.weight andfc.bias
51
- new_state_dict = {k: v for k, v in weights.items() if k not in ["fc.weight", "fc.bias"]}
52
- self.model.load_state_dict(new_state_dict, strict=False)
53
- '''
54
 
55
  def predict(img) -> Tuple[Dict, float]:
56
- """Transforms and performs a prediction on img and returns prediction and time taken.
57
- """
58
  # Start the timer
59
  start_time = timer()
60
 
@@ -64,12 +51,12 @@ def predict(img) -> Tuple[Dict, float]:
64
  # Put model into evaluation mode and turn on inference mode
65
  densenet.eval()
66
  with torch.inference_mode():
67
- # Pass the transformed image through the model and turn the prediction logits into prediction probabilities
68
- pred_probs = torch.sigmoid(densenet(img)).squeeze()
69
 
70
- # Create a prediction label and prediction probability dictionary for each prediction class (this is the required format for Gradio's output parameter)
71
  pred_labels_and_probs = {
72
- 'Normal': 1-pred_probs.item(), 'Pneumonia': pred_probs.item()}
 
 
73
 
74
  # Calculate the prediction time
75
  pred_time = round(timer() - start_time, 5)
@@ -77,23 +64,21 @@ def predict(img) -> Tuple[Dict, float]:
77
  # Return the prediction dictionary and prediction time
78
  return pred_labels_and_probs, pred_time
79
 
80
-
81
  example_list = [[f"examples/example{i+1}.jpg"] for i in range(3)]
82
- # Create title, description and article strings
83
  title = "ChestXray Classification"
84
- description = "An Alexnet computer vision model to classify images of Xray Chest images as Normal or Pneumonia."
85
  article = "Created at (https://github.com/azizche/chest_xray_Classification)."
86
 
87
- # Create the Gradio demo
88
- demo = gr.Interface(fn=predict, # mapping function from input to output
89
- inputs=gr.Image(type="pil"), # what are the inputs?
90
- outputs=[gr.Label(num_top_classes=2, label="Predictions"), # what are the outputs?
91
- gr.Number(label="Prediction time (s)")], # our fn has two outputs, therefore we have two outputs
92
- examples=example_list,
93
- title=title,
94
- description=description,
95
- article=article)
96
-
97
 
98
  # Launch the demo!
99
  demo.launch()
 
8
  import torchvision
9
 
10
  from torch import nn
 
11
  from torchvision.models import densenet121
12
 
13
+ def create_densenet121_model(num_classes: int = 2, seed: int = 42):
14
  """Creates a DenseNet121 model and transforms."""
15
 
16
  # Create DenseNet121 model
17
+ model = densenet121(weights=None) # Set to None since we will be loading our own weights
18
 
19
  # Freeze all layers in base model
20
  for param in model.parameters():
 
24
  torch.manual_seed(seed)
25
  model.classifier = nn.Linear(model.classifier.in_features, num_classes)
26
 
 
27
  transforms = torchvision.transforms.Compose([
28
  torchvision.transforms.Resize(224),
29
  torchvision.transforms.ToTensor(),
 
33
  return model, transforms
34
 
35
  # Create densenet121 model
36
+ densenet, densenet_transforms = create_densenet121_model()
37
 
38
  # Load saved weights
 
39
  state_dict = torch.load("FL_global_model.pt", map_location=torch.device("cpu"))
 
 
 
40
  model_weights = state_dict["model"]
41
+ densenet.load_state_dict(model_weights,strict=True) # Set strict to True since we now expect it to match
 
 
 
 
 
 
42
 
43
  def predict(img) -> Tuple[Dict, float]:
44
+ """Transforms and performs a prediction on img and returns prediction and time taken."""
 
45
  # Start the timer
46
  start_time = timer()
47
 
 
51
  # Put model into evaluation mode and turn on inference mode
52
  densenet.eval()
53
  with torch.inference_mode():
54
+ pred_probs = torch.softmax(densenet(img), dim=1).squeeze()
 
55
 
 
56
  pred_labels_and_probs = {
57
+ 'Normal': pred_probs[0].item(),
58
+ 'Nodules': pred_probs[1].item()
59
+ }
60
 
61
  # Calculate the prediction time
62
  pred_time = round(timer() - start_time, 5)
 
64
  # Return the prediction dictionary and prediction time
65
  return pred_labels_and_probs, pred_time
66
 
 
67
  example_list = [[f"examples/example{i+1}.jpg"] for i in range(3)]
68
+
69
  title = "ChestXray Classification"
70
+ description = "A Densenet121 computer vision model to classify images of Xray Chest images as Normal or Nodules."
71
  article = "Created at (https://github.com/azizche/chest_xray_Classification)."
72
 
73
+ demo = gr.Interface(
74
+ fn=predict,
75
+ inputs=gr.Image(type="pil"),
76
+ outputs=[gr.Label(num_top_classes=2, label="Predictions"), gr.Number(label="Prediction time (s)")],
77
+ examples=example_list,
78
+ title=title,
79
+ description=description,
80
+ article=article
81
+ )
 
82
 
83
  # Launch the demo!
84
  demo.launch()