hamsteryang commited on
Commit
a623a02
·
1 Parent(s): c8ef432

update app.py 2023-08-31 07:59

Browse files
Files changed (1) hide show
  1. app.py +6 -2
app.py CHANGED
@@ -35,10 +35,14 @@ def create_densenet121_model(num_classes: int = 1, seed: int = 42):
35
  return model, transforms
36
 
37
  # Create densenet121 model
38
- densenet, densenet_transforms = create_densenet121_model(num_classes=1)
39
 
40
  # Load saved weights
41
- densenet.load_state_dict(torch.load("FL_global_model.pt", map_location=torch.device("cpu")))
 
 
 
 
42
  model_weights = state_dict["model"]
43
  densenet.load_state_dict(model_weights,strict=False)
44
  '''
 
35
  return model, transforms
36
 
37
  # Create densenet121 model
38
+ densenet, densenet_transforms = create_densenet121_model(num_classes=2)
39
 
40
  # Load saved weights
41
+ # densenet.load_state_dict(torch.load("FL_global_model.pt", map_location=torch.device("cpu")))
42
+ state_dict = torch.load("FL_global_model.pt", map_location=torch.device("cpu"))
43
+ print("==============")
44
+ print(state_dict.keys())
45
+ print("==============")
46
  model_weights = state_dict["model"]
47
  densenet.load_state_dict(model_weights,strict=False)
48
  '''