Spaces:
Sleeping
Sleeping
Commit
·
a623a02
1
Parent(s):
c8ef432
update app.py 2023-08-31 07:59
Browse files
app.py
CHANGED
|
@@ -35,10 +35,14 @@ def create_densenet121_model(num_classes: int = 1, seed: int = 42):
|
|
| 35 |
return model, transforms
|
| 36 |
|
| 37 |
# Create densenet121 model
|
| 38 |
-
densenet, densenet_transforms = create_densenet121_model(num_classes=
|
| 39 |
|
| 40 |
# Load saved weights
|
| 41 |
-
densenet.load_state_dict(torch.load("FL_global_model.pt", map_location=torch.device("cpu")))
|
|
|
|
|
|
|
|
|
|
|
|
|
| 42 |
model_weights = state_dict["model"]
|
| 43 |
densenet.load_state_dict(model_weights,strict=False)
|
| 44 |
'''
|
|
|
|
| 35 |
return model, transforms
|
| 36 |
|
| 37 |
# Create densenet121 model
|
| 38 |
+
densenet, densenet_transforms = create_densenet121_model(num_classes=2)
|
| 39 |
|
| 40 |
# Load saved weights
|
| 41 |
+
# densenet.load_state_dict(torch.load("FL_global_model.pt", map_location=torch.device("cpu")))
|
| 42 |
+
state_dict = torch.load("FL_global_model.pt", map_location=torch.device("cpu"))
|
| 43 |
+
print("==============")
|
| 44 |
+
print(state_dict.keys())
|
| 45 |
+
print("==============")
|
| 46 |
model_weights = state_dict["model"]
|
| 47 |
densenet.load_state_dict(model_weights,strict=False)
|
| 48 |
'''
|