File size: 8,231 Bytes
54cd552
 
 
 
 
 
 
 
 
 
 
 
 
 
ea25230
54cd552
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ea25230
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
54cd552
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ea25230
54cd552
 
ea25230
54cd552
 
 
ea25230
 
 
 
 
54cd552
ea25230
 
 
 
54cd552
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
"""
Utility script to push GeneMamba model to Hugging Face Hub.

Usage:
    python scripts/push_to_hub.py --model_path ./my_checkpoint --repo_name username/GeneMamba-custom

Requirements:
    - Hugging Face CLI: huggingface-cli login
    - Git LFS installed (for large model files)
"""

import os
import shutil
import argparse
import json
from pathlib import Path
from huggingface_hub import HfApi


def collect_local_files(root: Path):
    """Return the set of root-relative POSIX paths for all uploadable files.

    Walks *root* recursively and skips Python bytecode artifacts
    (``__pycache__`` directories and ``*.pyc`` files), which should never
    be pushed to the Hub.
    """
    return {
        entry.relative_to(root).as_posix()
        for entry in root.rglob("*")
        if entry.is_file()
        and "__pycache__" not in entry.parts
        and entry.suffix != ".pyc"
    }


def normalize_config_for_hf(config_path: Path):
    """Rewrite ``config.json`` in place so HF's AutoModel machinery can load it.

    Maps legacy checkpoint keys (``d_model``, ``mamba_layer``) onto the
    HF-standard names, fills in missing defaults, and registers the custom
    ``auto_map`` entries for trust_remote_code loading.
    """
    config = json.loads(config_path.read_text(encoding="utf-8"))

    # Mirror legacy keys onto HF-standard names without clobbering values
    # that are already present.
    if "d_model" in config and "hidden_size" not in config:
        config["hidden_size"] = config["d_model"]
    if "mamba_layer" in config and "num_hidden_layers" not in config:
        config["num_hidden_layers"] = config["mamba_layer"]

    # Presence of either legacy key marks this as an old-style checkpoint.
    legacy_checkpoint_config = ("d_model" in config) or ("mamba_layer" in config)

    config["model_type"] = "genemamba"

    for key, default in (
        ("architectures", ["GeneMambaModel"]),
        ("max_position_embeddings", 2048),
        ("intermediate_size", 2048),
        ("hidden_dropout_prob", 0.1),
        ("initializer_range", 0.02),
    ):
        config.setdefault(key, default)

    # Legacy checkpoints saved with gated pooling are remapped to mean
    # pooling; otherwise just default the key when absent.
    if legacy_checkpoint_config and config.get("mamba_mode") == "gate":
        config["mamba_mode"] = "mean"
    else:
        config.setdefault("mamba_mode", "mean")

    for key, default in (
        ("embedding_pooling", "mean"),
        ("num_labels", 2),
        ("pad_token_id", 1),
        ("bos_token_id", 0),
        ("eos_token_id", 2),
        ("use_cache", True),
        ("torch_dtype", "float32"),
        ("transformers_version", "4.40.2"),
    ):
        config.setdefault(key, default)

    config["auto_map"] = {
        "AutoConfig": "configuration_genemamba.GeneMambaConfig",
        "AutoModel": "modeling_genemamba.GeneMambaModel",
        "AutoModelForMaskedLM": "modeling_genemamba.GeneMambaForMaskedLM",
        "AutoModelForSequenceClassification": "modeling_genemamba.GeneMambaForSequenceClassification",
    }

    config_path.write_text(json.dumps(config, indent=2) + "\n", encoding="utf-8")


def main():
    """Validate a GeneMamba checkpoint directory, prepare it, and push to the Hub.

    Steps:
      1. Verify required/optional model files exist under ``--model_path``.
      2. Sync the custom model-definition files from the project root and
         normalize ``config.json`` for trust_remote_code loading.
      3. Create the repo (if needed) and upload the folder; optionally
         delete remote files that no longer exist locally.

    Returns:
        int: 0 on success, 1 on any validation or upload failure
        (suitable for ``sys.exit``).
    """
    project_root = Path(__file__).resolve().parent.parent

    parser = argparse.ArgumentParser(
        description="Push a GeneMamba model to Hugging Face Hub"
    )
    parser.add_argument(
        "--model_path",
        default=str(project_root),
        help="Path to local model directory. Defaults to project root.",
    )
    parser.add_argument(
        "--repo_name",
        required=True,
        help="Target repo name on Hub (format: username/repo-name)",
    )
    parser.add_argument(
        "--private",
        action="store_true",
        help="Make the repository private",
    )
    parser.add_argument(
        "--commit_message",
        default="Upload GeneMamba model",
        help="Git commit message",
    )
    parser.add_argument(
        "--sync_delete",
        action="store_true",
        help="Delete remote files not present locally (useful to remove stale folders)",
    )

    args = parser.parse_args()
    model_path = Path(args.model_path).resolve()

    # Refuse to publish from the conversion staging area: those folders hold
    # intermediate artifacts, not the canonical repo layout.
    if "converted_checkpoints" in model_path.parts:
        print("βœ— ERROR: model_path cannot be inside 'converted_checkpoints'.")
        print(f"  - Received: {model_path}")
        print(f"  - Use project root instead: {project_root}")
        return 1

    # Path.is_dir() is False for nonexistent paths, so this covers both checks.
    if not model_path.is_dir():
        print(f"βœ— ERROR: model_path is not a valid directory: {model_path}")
        return 1

    print("=" * 80)
    print("GeneMamba Model Upload to Hugging Face Hub")
    print("=" * 80)

    # Step 1: Check model files
    print(f"\n[Step 1] Checking model files in '{model_path}'...")

    required_files = ["config.json"]
    optional_files = ["model.safetensors", "pytorch_model.bin", "tokenizer.json"]

    for file in required_files:
        if not (model_path / file).exists():
            print(f"βœ— ERROR: Required file '{file}' not found!")
            return 1

    print("βœ“ All required files present")

    found_optional = [f for f in optional_files if (model_path / f).exists()]
    print(f"βœ“ Found optional files: {', '.join(found_optional) if found_optional else 'none'}")

    # Step 2: Copy model definition files
    print("\n[Step 2] Preparing model files...")

    try:
        # NOTE: the resolved, validated model_path from above is reused here.
        # (Do not rebuild it from args — an unresolved Path would bypass the
        # checks already performed.)

        # Files to copy for custom model support (trust_remote_code loading).
        model_files = [
            "modeling_genemamba.py",
            "configuration_genemamba.py",
            "modeling_outputs.py",
            "README.md",
        ]

        print("  - Syncing model definition files...")
        for file in model_files:
            src = project_root / file
            dst = model_path / file
            if not src.exists():
                print(f"    βœ— Missing source file: {file}")
                return 1
            # In the default configuration model_path == project_root, so
            # src and dst are the same file; shutil.copy would raise
            # SameFileError — skip the copy in that case.
            if src.resolve() != dst.resolve():
                shutil.copy(src, dst)
            print(f"    βœ“ Synced {file}")

        normalize_config_for_hf(model_path / "config.json")
        print("  - Normalized config.json for custom AutoModel loading")

        print("βœ“ Model files prepared")

    except Exception as e:
        print(f"βœ— ERROR: {e}")
        import traceback
        traceback.print_exc()
        return 1

    # Step 3: Push to Hub
    print("\n[Step 3] Pushing to Hub...")
    print(f"  - Target repo: {args.repo_name}")
    print(f"  - Private: {args.private}")
    print(f"  - Commit message: {args.commit_message}")
    print(f"  - Sync delete: {args.sync_delete}")

    try:
        api = HfApi()
        # exist_ok avoids failing when re-pushing to an existing repo.
        api.create_repo(repo_id=args.repo_name, private=args.private, exist_ok=True)
        api.upload_folder(
            folder_path=str(model_path),
            repo_id=args.repo_name,
            repo_type="model",
            commit_message=args.commit_message,
        )

        if args.sync_delete:
            # Remove remote files with no local counterpart so stale folders
            # don't linger on the Hub; .gitattributes is Hub-managed.
            print("  - Syncing remote deletions...")
            local_files = collect_local_files(model_path)
            remote_files = set(api.list_repo_files(repo_id=args.repo_name, repo_type="model"))
            protected_files = {".gitattributes"}
            stale_files = sorted(
                [p for p in remote_files if p not in local_files and p not in protected_files]
            )

            for stale_path in stale_files:
                api.delete_file(
                    path_in_repo=stale_path,
                    repo_id=args.repo_name,
                    repo_type="model",
                    commit_message=f"Remove stale file: {stale_path}",
                )
            print(f"    βœ“ Removed {len(stale_files)} stale remote files")

        print("βœ“ Model pushed successfully!")
        print(f"  - URL: https://huggingface.co/{args.repo_name}")

    except Exception as e:
        print(f"βœ— ERROR during push: {e}")
        print("\nTroubleshooting:")
        print("  1. Make sure you're logged in: huggingface-cli login")
        print("  2. Check that you own the repo or have write access")
        print(f"  3. If repo doesn't exist, create it first: huggingface-cli repo create {args.repo_name}")
        return 1

    print("\n" + "=" * 80)
    print("Upload Complete!")
    print("=" * 80)
    print("\nYou can now load the model with:")
    print("  from transformers import AutoModel")
    print(f"  model = AutoModel.from_pretrained('{args.repo_name}', trust_remote_code=True)")

    return 0


if __name__ == "__main__":
    # raise SystemExit propagates main()'s return code to the shell and,
    # unlike the builtin exit(), does not depend on the `site` module
    # (exit() is absent under `python -S` and in some embedded interpreters).
    raise SystemExit(main())