BiliSakura commited on
Commit
80a7dc6
·
verified ·
1 Parent(s): 3594262

Update all files for SegEarth-OV

Browse files
Files changed (1) hide show
  1. convert_to_safetensors.py +113 -0
convert_to_safetensors.py ADDED
@@ -0,0 +1,113 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Convert CLIP and SAM3 checkpoints to safetensors format.
4
+ Run from repo root: python convert_to_safetensors.py
5
+ """
6
+ import argparse
7
+ from pathlib import Path
8
+
9
+ import torch
10
+ from safetensors.torch import save_file
11
+
12
+
13
+ def convert_clip(source_dir: Path, output_dir: Path = None) -> Path:
14
+ """Convert CLIP pytorch_model.bin to model.safetensors."""
15
+ output_dir = output_dir or source_dir
16
+ bin_path = source_dir / "pytorch_model.bin"
17
+ out_path = output_dir / "model.safetensors"
18
+
19
+ if out_path.exists():
20
+ print(f"CLIP: {out_path} already exists, skip")
21
+ return out_path
22
+
23
+ if bin_path.exists():
24
+ print(f"CLIP: Loading from {bin_path}...")
25
+ state_dict = torch.load(bin_path, map_location="cpu", weights_only=True)
26
+ else:
27
+ print(f"CLIP: Loading from HuggingFace...")
28
+ from transformers import CLIPModel
29
+ model = CLIPModel.from_pretrained(str(source_dir) if source_dir.exists() else "openai/clip-vit-base-patch16")
30
+ state_dict = model.state_dict()
31
+
32
+ state_dict = {k: v.float() if v.dtype in (torch.float16, torch.bfloat16) else v for k, v in state_dict.items()}
33
+ save_file(state_dict, str(out_path))
34
+ print(f"CLIP: Saved to {out_path}")
35
+ if bin_path.exists():
36
+ bin_path.unlink()
37
+ print(f"CLIP: Removed {bin_path}")
38
+ return out_path
39
+
40
+
41
+ def _extract_sam3_state_dict(ckpt: dict) -> dict:
42
+ """Extract SAM3 image model state dict from checkpoint (same logic as sam3._load_checkpoint)."""
43
+ if "model" in ckpt and isinstance(ckpt["model"], dict):
44
+ ckpt = ckpt["model"]
45
+ sam3_image_ckpt = {
46
+ k.replace("detector.", ""): v for k, v in ckpt.items() if "detector" in k
47
+ }
48
+ return sam3_image_ckpt
49
+
50
+
51
+ def convert_sam3(pt_path: Path, output_path: Path = None) -> Path:
52
+ """Convert SAM3 sam3.pt to model.safetensors (image model weights only)."""
53
+ output_path = output_path or pt_path.parent / "model.safetensors"
54
+
55
+ print(f"SAM3: Loading from {pt_path}...")
56
+ ckpt = torch.load(pt_path, map_location="cpu", weights_only=True)
57
+ state_dict = _extract_sam3_state_dict(ckpt)
58
+
59
+ state_dict = {k: v.float() if v.dtype in (torch.float16, torch.bfloat16) else v for k, v in state_dict.items()}
60
+ save_file(state_dict, str(output_path))
61
+ print(f"SAM3: Saved to {output_path}")
62
+ return output_path
63
+
64
+
65
+ def copy_sam3_safetensors(source: Path, dest_dir: Path) -> Path:
66
+ """Copy HF model.safetensors (detector_model.* keys) to SegEarth-OV. Pipeline maps keys on load."""
67
+ dest = dest_dir / "model.safetensors"
68
+ if source.exists():
69
+ import shutil
70
+ shutil.copy2(source, dest)
71
+ print(f"SAM3: Copied {source} -> {dest}")
72
+ return dest
73
+ return None
74
+
75
+
76
+ def main():
77
+ parser = argparse.ArgumentParser(description="Convert CLIP and SAM3 checkpoints to safetensors")
78
+ parser.add_argument("--clip", action="store_true", help="Convert CLIP only")
79
+ parser.add_argument("--sam3", action="store_true", help="Convert SAM3 only")
80
+ parser.add_argument("--all", action="store_true", help="Convert all (default when no --clip/--sam3)")
81
+ args = parser.parse_args()
82
+
83
+ repo = Path(__file__).parent
84
+ do_both = not args.clip and not args.sam3
85
+
86
+ if args.clip or do_both:
87
+ clip_dir = repo / "OV" / "weights" / "backbone" / "clip-vit-base-patch16"
88
+ if clip_dir.exists():
89
+ convert_clip(clip_dir)
90
+ else:
91
+ print(f"CLIP: {clip_dir} not found, skip")
92
+
93
+ if args.sam3 or do_both:
94
+ sam3_dir = repo / "OV-3" / "weights" / "backbone" / "sam3"
95
+ hf_safetensors = Path("/data/projects/models/hf_models/facebook/sam3/model.safetensors")
96
+ sam3_pt = sam3_dir / "sam3.pt"
97
+ if hf_safetensors.exists():
98
+ copy_sam3_safetensors(hf_safetensors, sam3_dir)
99
+ if sam3_pt.exists():
100
+ sam3_pt.unlink()
101
+ print(f"SAM3: Removed {sam3_pt}")
102
+ elif sam3_pt.exists():
103
+ st_path = convert_sam3(sam3_pt)
104
+ sam3_pt.unlink()
105
+ print(f"SAM3: Removed {sam3_pt}")
106
+ else:
107
+ print(f"SAM3: Neither {hf_safetensors} nor {sam3_pt} found, skip")
108
+
109
+ print("Done.")
110
+
111
+
112
+ if __name__ == "__main__":
113
+ main()