Zilun commited on
Commit
457dd08
·
1 Parent(s): d3beb1f

Delete codebase/inference/convert_weight.py

Browse files
codebase/inference/convert_weight.py DELETED
@@ -1,34 +0,0 @@
1
- import torch
2
- import open_clip
3
- import os
4
-
5
-
6
- def main():
7
- # trained_ckpt_path = "/home/zilun/RS5M_v5/ckpt/epoch_5.pt"
8
- # model, _, _ = open_clip.create_model_and_transforms("ViT-B/32", pretrained="openai")
9
-
10
- trained_ckpt_path = "/home/zilun/RS5M_v5/ckpt/epoch_2.pt"
11
- model, _, _ = open_clip.create_model_and_transforms("ViT-H/14", pretrained="openclip")
12
-
13
- checkpoint = torch.load(trained_ckpt_path, map_location="cpu")["state_dict"]
14
- sd = {k: v for k, v in checkpoint.items()}
15
- for key in list(sd.keys()):
16
- if "text_backbone." in key:
17
- sd[key.replace("text_backbone.", '')] = sd[key]
18
- del sd[key]
19
- if "image_backbone" in key:
20
- sd[key.replace("image_backbone.", "visual.")] = sd[key]
21
- del sd[key]
22
-
23
- msg = model.load_state_dict(sd, strict=False)
24
- print(msg)
25
- print("loaded RSCLIP")
26
-
27
- torch.save(
28
- model.state_dict(),
29
- os.path.join("/home/zilun/RS5M_v5/ckpt", "RS5M_ViT-B-32.pt"),
30
- )
31
-
32
-
33
- if __name__ == "__main__":
34
- main()