| """ |
| Utility script to push GeneMamba model to Hugging Face Hub. |
| |
| Usage: |
| python scripts/push_to_hub.py --model_path ./my_checkpoint --repo_name username/GeneMamba-custom |
| |
| Requirements: |
| - Hugging Face CLI: huggingface-cli login |
| - Git LFS installed (for large model files) |
| """ |
|
|
import argparse
import json
import os
import shutil
import sys
from pathlib import Path

from huggingface_hub import HfApi
|
|
|
|
def collect_local_files(root: Path):
    """Return POSIX-style relative paths of every regular file under *root*.

    Python bytecode artifacts (files inside any ``__pycache__`` directory and
    ``.pyc`` files) are skipped so they never take part in the comparison
    against the remote repository contents.
    """
    return {
        candidate.relative_to(root).as_posix()
        for candidate in root.rglob("*")
        if candidate.is_file()
        and "__pycache__" not in candidate.parts
        and candidate.suffix != ".pyc"
    }
|
|
|
|
def normalize_config_for_hf(config_path: Path):
    """Rewrite *config_path* in place so the checkpoint loads via AutoModel.

    Legacy checkpoint keys (``d_model``/``mamba_layer``) are mirrored onto
    their Hugging Face-standard names, missing standard fields receive
    defaults, and ``auto_map`` is (re)pointed at the custom model code so
    ``trust_remote_code=True`` loading works.
    """
    config = json.loads(config_path.read_text(encoding="utf-8"))

    # Mirror legacy keys onto the HF names without clobbering explicit values.
    if "d_model" in config and "hidden_size" not in config:
        config["hidden_size"] = config["d_model"]
    if "mamba_layer" in config and "num_hidden_layers" not in config:
        config["num_hidden_layers"] = config["mamba_layer"]

    # Presence of either legacy key marks this as an old-format checkpoint.
    is_legacy = ("d_model" in config) or ("mamba_layer" in config)

    config["model_type"] = "genemamba"

    # Defaults applied before the mamba_mode handling, in the same order as
    # before so the JSON key insertion order is unchanged.
    pre_mode_defaults = (
        ("architectures", ["GeneMambaModel"]),
        ("max_position_embeddings", 2048),
        ("intermediate_size", 2048),
        ("hidden_dropout_prob", 0.1),
        ("initializer_range", 0.02),
    )
    for key, value in pre_mode_defaults:
        config.setdefault(key, value)

    # Legacy checkpoints that declared "gate" pooling are remapped to "mean";
    # everything else just gets "mean" as a default.
    if is_legacy and config.get("mamba_mode") == "gate":
        config["mamba_mode"] = "mean"
    else:
        config.setdefault("mamba_mode", "mean")

    post_mode_defaults = (
        ("embedding_pooling", "mean"),
        ("num_labels", 2),
        ("pad_token_id", 1),
        ("bos_token_id", 0),
        ("eos_token_id", 2),
        ("use_cache", True),
        ("torch_dtype", "float32"),
        ("transformers_version", "4.40.2"),
    )
    for key, value in post_mode_defaults:
        config.setdefault(key, value)

    # Always overwrite auto_map so stale module paths cannot survive.
    config["auto_map"] = {
        "AutoConfig": "configuration_genemamba.GeneMambaConfig",
        "AutoModel": "modeling_genemamba.GeneMambaModel",
        "AutoModelForMaskedLM": "modeling_genemamba.GeneMambaForMaskedLM",
        "AutoModelForSequenceClassification": "modeling_genemamba.GeneMambaForSequenceClassification",
    }

    config_path.write_text(json.dumps(config, indent=2) + "\n", encoding="utf-8")
|
|
|
|
def main():
    """Validate, prepare, and push a local GeneMamba checkpoint to the Hub.

    Steps:
      1. Verify required (and report optional) model files in ``model_path``.
      2. Copy the model definition files next to the weights and normalize
         ``config.json`` for ``trust_remote_code`` AutoModel loading.
      3. Create/update the Hub repo, upload the folder, and optionally delete
         remote files that no longer exist locally.

    Returns:
        int: process exit code — 0 on success, 1 on any failure.
    """
    # Project root is the directory above scripts/, where the model
    # definition files (modeling_genemamba.py etc.) live.
    project_root = Path(__file__).resolve().parent.parent

    parser = argparse.ArgumentParser(
        description="Push a GeneMamba model to Hugging Face Hub"
    )
    parser.add_argument(
        "--model_path",
        default=str(project_root),
        help="Path to local model directory. Defaults to project root.",
    )
    parser.add_argument(
        "--repo_name",
        required=True,
        help="Target repo name on Hub (format: username/repo-name)",
    )
    parser.add_argument(
        "--private",
        action="store_true",
        help="Make the repository private",
    )
    parser.add_argument(
        "--commit_message",
        default="Upload GeneMamba model",
        help="Git commit message",
    )
    parser.add_argument(
        "--sync_delete",
        action="store_true",
        help="Delete remote files not present locally (useful to remove stale folders)",
    )

    args = parser.parse_args()
    model_path = Path(args.model_path).resolve()

    # Converted checkpoint dirs lack the model definition files that must
    # ship alongside the weights, so refuse them outright.
    if "converted_checkpoints" in model_path.parts:
        print("β ERROR: model_path cannot be inside 'converted_checkpoints'.")
        print(f" - Received: {model_path}")
        print(f" - Use project root instead: {project_root}")
        return 1

    if not model_path.exists() or not model_path.is_dir():
        print(f"β ERROR: model_path is not a valid directory: {model_path}")
        return 1

    print("=" * 80)
    print("GeneMamba Model Upload to Hugging Face Hub")
    print("=" * 80)

    print(f"\n[Step 1] Checking model files in '{model_path}'...")

    required_files = ["config.json"]
    optional_files = ["model.safetensors", "pytorch_model.bin", "tokenizer.json"]

    # Pathlib throughout for consistency with the rest of the script
    # (previously this mixed in os.path).
    for name in required_files:
        if not (model_path / name).exists():
            print(f"β ERROR: Required file '{name}' not found!")
            return 1

    print(f"β All required files present")

    found_optional = [name for name in optional_files if (model_path / name).exists()]
    print(f"β Found optional files: {', '.join(found_optional) if found_optional else 'none'}")

    print(f"\n[Step 2] Preparing model files...")

    try:
        # NOTE: keep using the resolved, validated model_path from above.
        # (The old code rebuilt it here WITHOUT .resolve(), discarding the
        # path the guards had just checked.)
        model_files = [
            "modeling_genemamba.py",
            "configuration_genemamba.py",
            "modeling_outputs.py",
            "README.md",
        ]

        print(" - Syncing model definition files...")
        for name in model_files:
            # Copy from the (already resolved) project root rather than
            # recomputing an unresolved duplicate of it.
            src = project_root / name
            dst = model_path / name
            if not src.exists():
                print(f" β Missing source file: {name}")
                return 1
            shutil.copy(src, dst)
            print(f" β Synced {name}")

        config_path = model_path / "config.json"
        normalize_config_for_hf(config_path)
        print(" - Normalized config.json for custom AutoModel loading")

        print("β Model files prepared")

    except Exception as e:
        print(f"β ERROR: {e}")
        import traceback
        traceback.print_exc()
        return 1

    print(f"\n[Step 3] Pushing to Hub...")
    print(f" - Target repo: {args.repo_name}")
    print(f" - Private: {args.private}")
    print(f" - Commit message: {args.commit_message}")
    print(f" - Sync delete: {args.sync_delete}")

    try:
        api = HfApi()
        # exist_ok=True makes repo creation idempotent across re-runs.
        api.create_repo(repo_id=args.repo_name, private=args.private, exist_ok=True)
        api.upload_folder(
            folder_path=str(model_path),
            repo_id=args.repo_name,
            repo_type="model",
            commit_message=args.commit_message,
        )

        if args.sync_delete:
            print(" - Syncing remote deletions...")
            local_files = collect_local_files(model_path)
            remote_files = set(api.list_repo_files(repo_id=args.repo_name, repo_type="model"))
            # .gitattributes is Hub-managed (LFS tracking) — never delete it.
            protected_files = {".gitattributes"}
            stale_files = sorted(
                [p for p in remote_files if p not in local_files and p not in protected_files]
            )

            for stale_path in stale_files:
                api.delete_file(
                    path_in_repo=stale_path,
                    repo_id=args.repo_name,
                    repo_type="model",
                    commit_message=f"Remove stale file: {stale_path}",
                )
            print(f" β Removed {len(stale_files)} stale remote files")

        print(f"β Model pushed successfully!")
        print(f" - URL: https://huggingface.co/{args.repo_name}")

    except Exception as e:
        print(f"β ERROR during push: {e}")
        print(f"\nTroubleshooting:")
        print(f" 1. Make sure you're logged in: huggingface-cli login")
        print(f" 2. Check that you own the repo or have write access")
        print(f" 3. If repo doesn't exist, create it first: huggingface-cli repo create {args.repo_name}")
        return 1

    print("\n" + "=" * 80)
    print("Upload Complete!")
    print("=" * 80)
    print(f"\nYou can now load the model with:")
    print(f" from transformers import AutoModel")
    print(f" model = AutoModel.from_pretrained('{args.repo_name}', trust_remote_code=True)")

    return 0
|
|
|
|
if __name__ == "__main__":
    # sys.exit propagates main()'s integer return code to the shell.
    # The bare builtin exit() is injected by the `site` module and is not
    # guaranteed to exist (e.g. when running under `python -S`).
    sys.exit(main())
|
|