File size: 2,569 Bytes
c7deb87
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
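"""Shard a PyTorch checkpoint (by default MiniOmni3_LM.pt) into HF safetensors shards.

Output layout (under --out-dir):
    model-00001-of-0000N.safetensors
    model-00002-of-0000N.safetensors
    ...
    model.safetensors.index.json

If the whole state dict fits within --max-shard-size, a single
model.safetensors is written instead and no index file is produced.
"""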
import argparse
from pathlib import Path

import torch
from huggingface_hub import save_torch_state_dict


def _extract_state_dict(obj):
    """Return the ``Dict[str, Tensor]`` held by a loaded checkpoint object.

    Accepts either a raw state dict or a wrapper dict whose ``"model"`` (checked
    first) or ``"state_dict"`` entry is the state dict.

    Raises:
        SystemExit: if no state dict can be found, if it is empty, or if it holds
            values that are not tensors (safetensors can only store tensors).
    """
    if not isinstance(obj, dict):
        state_dict = None
    else:
        # A wrapper key only counts when its value is itself a dict; a raw state
        # dict could legitimately have a tensor stored under the name "model".
        state_dict = next(
            (obj[key] for key in ("model", "state_dict")
             if isinstance(obj.get(key), dict)),
            None,
        )
        if state_dict is not None:
            used = "model" if isinstance(obj.get("model"), dict) else "state_dict"
            print(f"Detected wrapped checkpoint — using obj[{used!r}].")
        elif all(isinstance(v, torch.Tensor) for v in obj.values()):
            state_dict = obj
            print("Detected raw state_dict.")

    if state_dict is None:
        raise SystemExit(
            f"Unexpected checkpoint structure: top-level type={type(obj)}. "
            f"Inspect the file and adjust this script (the variable `state_dict` "
            f"must end up as Dict[str, Tensor])."
        )
    if not state_dict:
        raise SystemExit("State dict is empty; nothing to shard.")
    non_tensor = [k for k, v in state_dict.items() if not isinstance(v, torch.Tensor)]
    if non_tensor:
        raise SystemExit(
            f"State dict has {len(non_tensor)} non-tensor entries, which safetensors "
            f"cannot store (first few: {non_tensor[:5]}). Remove or convert them."
        )
    return state_dict


def _unshare_storage(state_dict):
    """Clone, in place, every tensor whose storage an earlier entry already uses.

    Safetensors refuses tensors that share storage (tied weights, or views into
    one fused buffer). Comparing ``data_ptr()`` alone misses views that start at
    different offsets, so the check is keyed on the underlying storage instead.
    Cloning keeps every key in the output rather than letting the saver drop
    duplicates.

    Returns:
        The names of the entries that were replaced by a clone.
    """
    owners = {}
    cloned = []
    for name, tensor in list(state_dict.items()):
        if tensor.device.type == "meta":
            continue  # no real storage to share
        storage = tensor.untyped_storage()
        ptr = storage.data_ptr()
        if ptr == 0 or storage.nbytes() == 0:
            continue  # empty storages cannot alias anything meaningful
        key = (tensor.device, ptr)
        if key in owners:
            state_dict[name] = tensor.clone()
            cloned.append(name)
        else:
            owners[key] = name
    return cloned


def main():
    """Convert a PyTorch checkpoint into sharded safetensors files plus an index.

    Command-line options:
        --ckpt            path of the checkpoint to convert.
        --out-dir         directory receiving the output (created if missing).
        --max-shard-size  per-shard size limit passed to ``save_torch_state_dict``.

    Exits via ``SystemExit`` with a message when the checkpoint file is missing or
    its contents cannot be turned into a non-empty ``Dict[str, Tensor]``.
    """
    ap = argparse.ArgumentParser()
    # NOTE: these defaults are absolute paths on the original author's machine;
    # pass --ckpt / --out-dir explicitly anywhere else.
    ap.add_argument("--ckpt", default="/Users/yansc-xzf/Desktop/工作/Mini-Omni3/github/omni3/Mini-Omni3/checkpoints/MiniOmni3_LM.pt")
    ap.add_argument("--out-dir", default="/Users/yansc-xzf/Desktop/工作/Mini-Omni3/github/omni3/Mini-Omni3/checkpoints/")
    ap.add_argument("--max-shard-size", default="4GB",
                    help="HF-style size string, e.g. '4GB', '2GB', '500MB'.")
    args = ap.parse_args()

    ckpt = Path(args.ckpt)
    if not ckpt.is_file():
        raise SystemExit(f"Checkpoint not found: {ckpt}")
    out_dir = Path(args.out_dir)

    # Decimal GB, matching how HF parses size strings such as "4GB".
    size_gb = ckpt.stat().st_size / 1e9
    print(f"Loading {ckpt} ({size_gb:.2f} GB; the full file is loaded into RAM once)…")
    # SECURITY: weights_only=False unpickles arbitrary Python objects and can run
    # code embedded in the file. Only use this on checkpoints you trust.
    obj = torch.load(ckpt, map_location="cpu", weights_only=False)
    state_dict = _extract_state_dict(obj)
    # Release the wrapper so any other entries it carried are not kept alive
    # while the shards are written.
    del obj

    cloned = _unshare_storage(state_dict)
    if cloned:
        print(f"Cloned {len(cloned)} tensor(s) that shared storage with another entry.")

    # Created only after validation, so a failed run leaves no stray directory.
    out_dir.mkdir(parents=True, exist_ok=True)
    print(f"Sharding to {out_dir} with max_shard_size={args.max_shard_size}…")
    save_torch_state_dict(
        state_dict=state_dict,
        save_directory=str(out_dir),
        max_shard_size=args.max_shard_size,
        # filename_pattern defaults to "model{suffix}.safetensors": the suffix is
        # "-00001-of-0000N" when sharded and empty for a single file.
    )
    print("Done. Files written:")
    # Only list the converter's outputs; --out-dir may also hold the source
    # checkpoint and other unrelated files.
    for p in sorted(out_dir.glob("model*.safetensors*")):
        if p.is_file():
            print(f"  {p.name}  ({p.stat().st_size / 1e9:.2f} GB)")


if __name__ == "__main__":
    main()