# AudioInteraction / utils.py
# zhifeixie's picture
# Add files using upload-large-folder tool
# c7deb87 verified
"""Shard MiniOmni3_LM.pt into HF safetensors shards.
Output layout (under --out-dir):
model-00001-of-0000N.safetensors
model-00002-of-0000N.safetensors
...
model.safetensors.index.json
"""
import argparse
from pathlib import Path
import torch
from huggingface_hub import save_torch_state_dict
def _extract_state_dict(obj):
    """Return the ``Dict[str, Tensor]`` contained in a loaded checkpoint object.

    Two layouts are accepted:

    * a wrapped checkpoint, ``{"model": state_dict, ...}``;
    * a raw state dict, where every top-level value is a ``torch.Tensor``.

    Raises:
        SystemExit: if ``obj`` matches neither layout, or if ``obj["model"]``
            is a dict that contains a non-tensor value.
    """
    if isinstance(obj, dict) and isinstance(obj.get("model"), dict):
        wrapped = obj["model"]
        # Validate up front: the alias pass and safetensors both need real
        # tensors, and failing here gives a readable message instead of an
        # AttributeError deep inside the loop.
        for name, value in wrapped.items():
            if not isinstance(value, torch.Tensor):
                raise SystemExit(
                    f"Unexpected checkpoint structure: obj['model'][{name!r}] has "
                    f"type={type(value)}, expected torch.Tensor. "
                    f"Inspect the file and adjust this script (the variable `state_dict` "
                    f"must end up as Dict[str, Tensor])."
                )
        print("Detected wrapped checkpoint — using obj['model'].")
        return wrapped

    if isinstance(obj, dict) and all(isinstance(v, torch.Tensor) for v in obj.values()):
        print("Detected raw state_dict.")
        return obj

    raise SystemExit(
        f"Unexpected checkpoint structure: top-level type={type(obj)}. "
        f"Inspect the file and adjust this script (the variable `state_dict` "
        f"must end up as Dict[str, Tensor])."
    )


def _clone_aliased_tensors(state_dict):
    """Replace, in place, every tensor whose ``data_ptr()`` was already seen.

    The first key that owns a given data pointer keeps its tensor; every later
    key with the same pointer gets an independent ``clone()``, so the saved
    file holds one separate tensor per key.
    """
    seen_ptrs = set()
    # Iterate over a snapshot because entries are reassigned during the loop.
    for name, tensor in list(state_dict.items()):
        ptr = tensor.data_ptr()
        if ptr in seen_ptrs:
            state_dict[name] = tensor.clone()
        else:
            seen_ptrs.add(ptr)


def main():
    """Command-line entry point: load a ``.pt`` checkpoint and write safetensors shards.

    Options:
        --ckpt            path of the checkpoint to load.
        --out-dir         directory that receives the shards and the index file
                          (created if missing).
        --max-shard-size  HF-style size string passed to ``save_torch_state_dict``.

    Raises:
        SystemExit: if the checkpoint is not a (wrapped) ``Dict[str, Tensor]``.
    """
    ap = argparse.ArgumentParser()
    # NOTE(review): both path defaults are machine-specific; pass --ckpt/--out-dir
    # explicitly on any other machine.
    ap.add_argument("--ckpt", default="/Users/yansc-xzf/Desktop/工作/Mini-Omni3/github/omni3/Mini-Omni3/checkpoints/MiniOmni3_LM.pt")
    ap.add_argument("--out-dir", default="/Users/yansc-xzf/Desktop/工作/Mini-Omni3/github/omni3/Mini-Omni3/checkpoints/")
    ap.add_argument("--max-shard-size", default="4GB",
                    help="HF-style size string, e.g. '4GB', '2GB', '500MB'.")
    args = ap.parse_args()

    out_dir = Path(args.out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    print(f"Loading {args.ckpt} (this loads the full ~13 GB into RAM once)…")
    # SECURITY: weights_only=False runs the full pickle machinery, which can
    # execute code embedded in the file. Only use this on checkpoints you trust.
    obj = torch.load(args.ckpt, map_location="cpu", weights_only=False)

    # Most training scripts save either a raw state_dict, or {"model": state_dict, ...}.
    state_dict = _extract_state_dict(obj)

    # Clone tensors that alias an earlier entry so every key is written as its
    # own tensor. Cheap when there are no aliases.
    _clone_aliased_tensors(state_dict)

    print(f"Sharding to {out_dir} with max_shard_size={args.max_shard_size}…")
    # filename_pattern is left at the library default, which yields the
    # model-*.safetensors + model.safetensors.index.json layout from the
    # module docstring.
    save_torch_state_dict(
        state_dict=state_dict,
        save_directory=str(out_dir),
        max_shard_size=args.max_shard_size,
    )

    print("Done. Files written:")
    # Report only the shard/index files; out_dir may also contain the source
    # checkpoint or other unrelated entries.
    for p in sorted(out_dir.glob("model*.safetensors*")):
        if not p.is_file():
            continue
        sz = p.stat().st_size / (1024**3)
        print(f" {p.name} ({sz:.2f} GB)")
# Run the conversion only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()