"""Shard MiniOmni3_LM.pt into HF safetensors shards. Output layout (under --out-dir): model-00001-of-0000N.safetensors model-00002-of-0000N.safetensors ... model.safetensors.index.json """ import argparse from pathlib import Path import torch from huggingface_hub import save_torch_state_dict def main(): ap = argparse.ArgumentParser() ap.add_argument("--ckpt", default="/Users/yansc-xzf/Desktop/工作/Mini-Omni3/github/omni3/Mini-Omni3/checkpoints/MiniOmni3_LM.pt") ap.add_argument("--out-dir", default="/Users/yansc-xzf/Desktop/工作/Mini-Omni3/github/omni3/Mini-Omni3/checkpoints/") ap.add_argument("--max-shard-size", default="4GB", help="HF-style size string, e.g. '4GB', '2GB', '500MB'.") args = ap.parse_args() out_dir = Path(args.out_dir) out_dir.mkdir(parents=True, exist_ok=True) print(f"Loading {args.ckpt} (this loads the full ~13 GB into RAM once)…") obj = torch.load(args.ckpt, map_location="cpu", weights_only=False) # Most training scripts save either a raw state_dict, or {"model": state_dict, ...}. if isinstance(obj, dict) and "model" in obj and isinstance(obj["model"], dict): state_dict = obj["model"] print("Detected wrapped checkpoint — using obj['model'].") elif isinstance(obj, dict) and all(isinstance(v, torch.Tensor) for v in obj.values()): state_dict = obj print("Detected raw state_dict.") else: raise SystemExit( f"Unexpected checkpoint structure: top-level type={type(obj)}. " f"Inspect the file and adjust this script (the variable `state_dict` " f"must end up as Dict[str, Tensor])." ) # Safetensors can't store shared/aliased tensors silently. Clone any duplicates # so save_torch_state_dict doesn't complain. Cheap when there are no aliases. seen = {} for k, v in list(state_dict.items()): ptr = v.data_ptr() if ptr in seen: state_dict[k] = v.clone() else: seen[ptr] = k print(f"Sharding to {out_dir} with max_shard_size={args.max_shard_size}…") save_torch_state_dict( state_dict=state_dict, save_directory=str(out_dir), max_shard_size=args.max_shard_size, # filename_pattern defaults to "model-{index:05d}-of-{total:05d}.safetensors" ) print("Done. Files written:") for p in sorted(out_dir.iterdir()): sz = p.stat().st_size / (1024**3) print(f" {p.name} ({sz:.2f} GB)") if __name__ == "__main__": main()