hamsteryang commited on
Commit
415d148
·
1 Parent(s): 7998818

update app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -4
app.py CHANGED
@@ -39,7 +39,8 @@ densenet, densenet_transforms = create_densenet121_model(num_classes=1)
39
 
40
  # Load saved weights
41
  densenet.load_state_dict(torch.load("FL_global_model.pt", map_location=torch.device("cpu")))
42
-
 
43
 
44
 
45
  def predict(img) -> Tuple[Dict, float]:
@@ -49,13 +50,13 @@ def predict(img) -> Tuple[Dict, float]:
49
  start_time = timer()
50
 
51
  # Transform the target image and add a batch dimension
52
- img = effnetb2_transforms(img).unsqueeze(0)
53
 
54
  # Put model into evaluation mode and turn on inference mode
55
- effnetb2.eval()
56
  with torch.inference_mode():
57
  # Pass the transformed image through the model and turn the prediction logits into prediction probabilities
58
- pred_probs = torch.sigmoid(effnetb2(img)).squeeze()
59
 
60
  # Create a prediction label and prediction probability dictionary for each prediction class (this is the required format for Gradio's output parameter)
61
  pred_labels_and_probs = {
 
39
 
40
  # Load saved weights
41
  densenet.load_state_dict(torch.load("FL_global_model.pt", map_location=torch.device("cpu")))
42
+ model_weights = state_dict["model"]
43
+ densenet.load_state_dict(model_weights)
44
 
45
 
46
  def predict(img) -> Tuple[Dict, float]:
 
50
  start_time = timer()
51
 
52
  # Transform the target image and add a batch dimension
53
+ img = densenet_transforms(img).unsqueeze(0)
54
 
55
  # Put model into evaluation mode and turn on inference mode
56
+ densenet.eval()
57
  with torch.inference_mode():
58
  # Pass the transformed image through the model and turn the prediction logits into prediction probabilities
59
+ pred_probs = torch.sigmoid(densenet(img)).squeeze()
60
 
61
  # Create a prediction label and prediction probability dictionary for each prediction class (this is the required format for Gradio's output parameter)
62
  pred_labels_and_probs = {