anzheCheng commited on
Commit
19ccf99
·
verified ·
1 Parent(s): 6fdc2e7

Upload export_to_hub.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. export_to_hub.py +105 -0
export_to_hub.py ADDED
@@ -0,0 +1,105 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import os
3
+ from pathlib import Path
4
+
5
+ from huggingface_hub import HfApi
6
+
7
+ try:
8
+ from .eigen_moe import HFEigenMoE, default_hub_checkpoint_filename
9
+ except ImportError:
10
+ from eigen_moe import HFEigenMoE, default_hub_checkpoint_filename
11
+
12
+
13
+ def parse_args():
14
+ parser = argparse.ArgumentParser(description="Export and optionally push EMoE to Hugging Face Hub.")
15
+ parser.add_argument("--checkpoint", type=str, required=True, help="Path to local .pth checkpoint.")
16
+ parser.add_argument("--output-dir", type=str, default="./hf_export", help="Directory for save_pretrained output.")
17
+ parser.add_argument("--repo-id", type=str, default="", help="Hub repo id, e.g. anzheCheng/EMoE.")
18
+ parser.add_argument("--push", action="store_true", help="Push exported files to --repo-id.")
19
+ parser.add_argument("--upload-original-checkpoint", action="store_true", help="Also upload the raw .pth file.")
20
+ parser.add_argument("--checkpoint-filename", type=str, default="", help="Filename to store raw checkpoint as in Hub.")
21
+ parser.add_argument("--vit-model-name", type=str, default="vit_base_patch16_224")
22
+ parser.add_argument("--num-classes", type=int, default=1000)
23
+ parser.add_argument("--strict", action="store_true", help="Strict checkpoint loading.")
24
+
25
+ # MoE hyperparameters
26
+ parser.add_argument("--experts", type=int, default=8)
27
+ parser.add_argument("--r", type=int, default=128)
28
+ parser.add_argument("--bottleneck", type=int, default=192)
29
+ parser.add_argument("--tau", type=float, default=1.0)
30
+ parser.add_argument("--router-mode", choices=["soft", "top1", "top2"], default="top1")
31
+ parser.add_argument("--alpha", type=float, default=1.0)
32
+ parser.add_argument("--blocks", type=str, default="last6")
33
+ parser.add_argument("--apply-to-patches-only", dest="apply_to_patches_only", action="store_true")
34
+ parser.add_argument("--no-apply-to-patches-only", dest="apply_to_patches_only", action="store_false")
35
+ parser.add_argument("--ortho-lambda", type=float, default=1e-3)
36
+ parser.add_argument("--freeze-backbone", dest="freeze_backbone", action="store_true")
37
+ parser.add_argument("--no-freeze-backbone", dest="freeze_backbone", action="store_false")
38
+ parser.add_argument("--unfreeze-layernorm", action="store_true", default=False)
39
+ parser.add_argument("--backbone-pretrained", action="store_true", default=False)
40
+ parser.set_defaults(apply_to_patches_only=True, freeze_backbone=True)
41
+ return parser.parse_args()
42
+
43
+
44
+ def main():
45
+ args = parse_args()
46
+ checkpoint_path = Path(args.checkpoint)
47
+ if not checkpoint_path.exists():
48
+ raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")
49
+
50
+ moe_config = {
51
+ "experts": args.experts,
52
+ "r": args.r,
53
+ "bottleneck": args.bottleneck,
54
+ "tau": args.tau,
55
+ "router_mode": args.router_mode,
56
+ "alpha": args.alpha,
57
+ "blocks": args.blocks,
58
+ "apply_to_patches_only": args.apply_to_patches_only,
59
+ "ortho_lambda": args.ortho_lambda,
60
+ "freeze_backbone": args.freeze_backbone,
61
+ "unfreeze_layernorm": args.unfreeze_layernorm,
62
+ }
63
+
64
+ model = HFEigenMoE(
65
+ vit_model_name=args.vit_model_name,
66
+ num_classes=args.num_classes,
67
+ backbone_pretrained=args.backbone_pretrained,
68
+ moe_config=moe_config,
69
+ )
70
+ missing, unexpected = model.load_checkpoint(
71
+ str(checkpoint_path),
72
+ map_location="cpu",
73
+ strict=args.strict,
74
+ )
75
+ print(f"Loaded checkpoint: missing_keys={len(missing)} unexpected_keys={len(unexpected)}")
76
+
77
+ os.makedirs(args.output_dir, exist_ok=True)
78
+ model.save_pretrained(args.output_dir)
79
+ print(f"Saved Hub format model to: {args.output_dir}")
80
+
81
+ if not args.push:
82
+ return
83
+
84
+ if not args.repo_id:
85
+ raise ValueError("--repo-id is required when using --push.")
86
+
87
+ print(f"Pushing to Hub repo: {args.repo_id}")
88
+ model.push_to_hub(args.repo_id)
89
+
90
+ if args.upload_original_checkpoint:
91
+ upload_name = args.checkpoint_filename or default_hub_checkpoint_filename(args.vit_model_name)
92
+ if not upload_name:
93
+ upload_name = checkpoint_path.name
94
+ api = HfApi()
95
+ api.upload_file(
96
+ path_or_fileobj=str(checkpoint_path),
97
+ path_in_repo=upload_name,
98
+ repo_id=args.repo_id,
99
+ repo_type="model",
100
+ )
101
+ print(f"Uploaded original checkpoint as: {upload_name}")
102
+
103
+
104
+ if __name__ == "__main__":
105
+ main()