File size: 1,849 Bytes
d171350
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
"""Convert a midmid PyTorch checkpoint to safetensors + config.json.

Usage:
    python convert_checkpoint.py path/to/best.pt --output-dir ./model_upload

This produces:
    model_upload/model.safetensors   (weights only, no pickle)
    model_upload/config.json         (model hyperparameters)

Then upload to HF:
    huggingface-cli upload markury/midmid3-19m-0326 ./model_upload
"""

import argparse
import json
import textwrap
from pathlib import Path

import torch
from safetensors.torch import save_file


def main():
    """Convert a midmid ``.pt`` checkpoint into ``config.json`` + ``model.safetensors``.

    Command-line arguments:
        checkpoint: Path to the ``.pt`` file. It must load to a dict holding a
            ``"config"`` entry (written out as JSON) and a ``"model_state_dict"``
            entry (written out as safetensors).
        --output-dir: Destination directory, created if missing
            (default: ``./model_upload``).
        --repo-id: Hugging Face repo id shown in the final upload hint
            (default: ``markury/midmid3-19m-0326``).

    Exits through ``argparse``'s error path (usage + message, status 2) when the
    loaded checkpoint is not a dict or lacks either required key.
    """
    parser = argparse.ArgumentParser(description="Convert midmid checkpoint to safetensors")
    parser.add_argument("checkpoint", type=Path, help="Path to .pt checkpoint")
    parser.add_argument("--output-dir", type=Path, default=Path("model_upload"),
                        help="Output directory (default: ./model_upload)")
    parser.add_argument("--repo-id", default="markury/midmid3-19m-0326",
                        help="Hugging Face repo id used in the upload hint "
                             "(default: markury/midmid3-19m-0326)")
    args = parser.parse_args()

    print(f"Loading checkpoint: {args.checkpoint}")
    # SECURITY: weights_only=False unpickles arbitrary Python objects, which can
    # execute code embedded in the file. Only convert checkpoints you trust.
    # NOTE(review): presumably kept because the checkpoint stores non-tensor
    # objects the restricted loader rejects -- confirm whether weights_only=True
    # works for these checkpoints and switch if it does.
    ckpt = torch.load(args.checkpoint, map_location="cpu", weights_only=False)

    # Validate the layout up front so a bad file gives a clear message instead
    # of a KeyError, and no output directory is created for it.
    if not isinstance(ckpt, dict):
        parser.error(f"{args.checkpoint}: expected a dict checkpoint, "
                     f"got {type(ckpt).__name__}")
    missing = [key for key in ("config", "model_state_dict") if key not in ckpt]
    if missing:
        parser.error(f"{args.checkpoint}: missing checkpoint key(s): {', '.join(missing)}")

    args.output_dir.mkdir(parents=True, exist_ok=True)

    # Save config
    config = ckpt["config"]
    config_path = args.output_dir / "config.json"
    with open(config_path, "w", encoding="utf-8") as f:
        json.dump(config, f, indent=2)
    print(f"Config saved: {config_path}")
    # Indent every line of the pretty-printed JSON, not only the first one.
    print(textwrap.indent(json.dumps(config, indent=2), "  "))

    # Save weights as safetensors.
    # save_file refuses tensors that share storage (e.g. tied embedding/output
    # weights; torch.save/torch.load preserve such sharing) and non-contiguous
    # tensors. Cloning each tensor into its own contiguous buffer satisfies both,
    # at the cost of temporarily doubling the weights' memory.
    state_dict = {
        name: tensor.detach().clone(memory_format=torch.contiguous_format)
        for name, tensor in ckpt["model_state_dict"].items()
    }
    safetensors_path = args.output_dir / "model.safetensors"
    # {"format": "pt"} is the metadata Hugging Face tooling writes for PyTorch weights.
    save_file(state_dict, str(safetensors_path), metadata={"format": "pt"})
    print(f"Weights saved: {safetensors_path}")

    # Summary. This sums elements over every state-dict entry, so any registered
    # buffers are included and a tied tensor is counted once per key.
    n_params = sum(t.numel() for t in state_dict.values())
    print(f"  {n_params:,} parameters ({n_params / 1e6:.1f}M)")

    print("\nUpload to HF with:")
    print(f"  huggingface-cli upload {args.repo_id} {args.output_dir}")


if __name__ == "__main__":
    main()