hamsteryang commited on
Commit
0c73815
·
1 Parent(s): 415d148

update app.py 2023-08-31 07:48

Browse files
Files changed (1) hide show
  1. app.py +7 -2
app.py CHANGED
@@ -40,8 +40,13 @@ densenet, densenet_transforms = create_densenet121_model(num_classes=1)
40
  # Load saved weights
41
  densenet.load_state_dict(torch.load("FL_global_model.pt", map_location=torch.device("cpu")))
42
  model_weights = state_dict["model"]
43
- densenet.load_state_dict(model_weights)
44
-
 
 
 
 
 
45
 
46
  def predict(img) -> Tuple[Dict, float]:
47
  """Transforms and performs a prediction on img and returns prediction and time taken.
 
40
  # Load saved weights
41
  densenet.load_state_dict(torch.load("FL_global_model.pt", map_location=torch.device("cpu")))
42
  model_weights = state_dict["model"]
43
+ densenet.load_state_dict(model_weights,strict=False)
44
+ '''
45
+ weights = {k: torch.from_numpy(v).to(self.device) if isinstance(v, np.ndarray) else v.to(self.device) for k, v in weights.items()}
46
+ # creat new state_dict and del fc.weight andfc.bias
47
+ new_state_dict = {k: v for k, v in weights.items() if k not in ["fc.weight", "fc.bias"]}
48
+ self.model.load_state_dict(new_state_dict, strict=False)
49
+ '''
50
 
51
  def predict(img) -> Tuple[Dict, float]:
52
  """Transforms and performs a prediction on img and returns prediction and time taken.