priyansh-nagar commited on
Commit
24d0ae3
·
verified ·
1 Parent(s): 02717b0

Upload convert_to_state.py

Browse files
Files changed (1) hide show
  1. convert_to_state.py +14 -0
convert_to_state.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torchvision import models
3
+
4
+ # Step 1: create the model architecture
5
+ model = models.efficientnet_b0(pretrained=False)
6
+ model.classifier[1] = torch.nn.Linear(model.classifier[1].in_features, 2)
7
+
8
+ # Step 2: load your previous checkpoint
9
+ state_dict = torch.load("models/efficientnet_b0_ffpp_c23.pth", map_location="cpu")
10
+ model.load_state_dict(state_dict, strict=False)
11
+
12
+ # Step 3: save only the weights
13
+ torch.save(model.state_dict(), "models/deeptrust_weights.pt")
14
+ print("Saved state_dict as models/deeptrust_weights.pt")