Delete codebase/inference/convert_weight.py
Browse files
codebase/inference/convert_weight.py
DELETED
|
@@ -1,34 +0,0 @@
|
|
| 1 |
-
import torch
|
| 2 |
-
import open_clip
|
| 3 |
-
import os
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
def main():
|
| 7 |
-
# trained_ckpt_path = "/home/zilun/RS5M_v5/ckpt/epoch_5.pt"
|
| 8 |
-
# model, _, _ = open_clip.create_model_and_transforms("ViT-B/32", pretrained="openai")
|
| 9 |
-
|
| 10 |
-
trained_ckpt_path = "/home/zilun/RS5M_v5/ckpt/epoch_2.pt"
|
| 11 |
-
model, _, _ = open_clip.create_model_and_transforms("ViT-H/14", pretrained="openclip")
|
| 12 |
-
|
| 13 |
-
checkpoint = torch.load(trained_ckpt_path, map_location="cpu")["state_dict"]
|
| 14 |
-
sd = {k: v for k, v in checkpoint.items()}
|
| 15 |
-
for key in list(sd.keys()):
|
| 16 |
-
if "text_backbone." in key:
|
| 17 |
-
sd[key.replace("text_backbone.", '')] = sd[key]
|
| 18 |
-
del sd[key]
|
| 19 |
-
if "image_backbone" in key:
|
| 20 |
-
sd[key.replace("image_backbone.", "visual.")] = sd[key]
|
| 21 |
-
del sd[key]
|
| 22 |
-
|
| 23 |
-
msg = model.load_state_dict(sd, strict=False)
|
| 24 |
-
print(msg)
|
| 25 |
-
print("loaded RSCLIP")
|
| 26 |
-
|
| 27 |
-
torch.save(
|
| 28 |
-
model.state_dict(),
|
| 29 |
-
os.path.join("/home/zilun/RS5M_v5/ckpt", "RS5M_ViT-B-32.pt"),
|
| 30 |
-
)
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
if __name__ == "__main__":
|
| 34 |
-
main()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|