""" Utility script to push GeneMamba model to Hugging Face Hub. Usage: python scripts/push_to_hub.py --model_path ./my_checkpoint --repo_name username/GeneMamba-custom Requirements: - Hugging Face CLI: huggingface-cli login - Git LFS installed (for large model files) """ import os import shutil import argparse import json from pathlib import Path from huggingface_hub import HfApi def collect_local_files(root: Path): files = set() for path in root.rglob("*"): if not path.is_file(): continue if "__pycache__" in path.parts: continue if path.suffix == ".pyc": continue files.add(path.relative_to(root).as_posix()) return files def normalize_config_for_hf(config_path: Path): with config_path.open("r", encoding="utf-8") as f: config = json.load(f) if "d_model" in config and "hidden_size" not in config: config["hidden_size"] = config["d_model"] if "mamba_layer" in config and "num_hidden_layers" not in config: config["num_hidden_layers"] = config["mamba_layer"] legacy_checkpoint_config = ("d_model" in config) or ("mamba_layer" in config) config["model_type"] = "genemamba" config.setdefault("architectures", ["GeneMambaModel"]) config.setdefault("max_position_embeddings", 2048) config.setdefault("intermediate_size", 2048) config.setdefault("hidden_dropout_prob", 0.1) config.setdefault("initializer_range", 0.02) if legacy_checkpoint_config and config.get("mamba_mode") == "gate": config["mamba_mode"] = "mean" else: config.setdefault("mamba_mode", "mean") config.setdefault("embedding_pooling", "mean") config.setdefault("num_labels", 2) config.setdefault("pad_token_id", 1) config.setdefault("bos_token_id", 0) config.setdefault("eos_token_id", 2) config.setdefault("use_cache", True) config.setdefault("torch_dtype", "float32") config.setdefault("transformers_version", "4.40.2") config["auto_map"] = { "AutoConfig": "configuration_genemamba.GeneMambaConfig", "AutoModel": "modeling_genemamba.GeneMambaModel", "AutoModelForMaskedLM": "modeling_genemamba.GeneMambaForMaskedLM", "AutoModelForSequenceClassification": "modeling_genemamba.GeneMambaForSequenceClassification", } with config_path.open("w", encoding="utf-8") as f: json.dump(config, f, indent=2) f.write("\n") def main(): project_root = Path(__file__).resolve().parent.parent parser = argparse.ArgumentParser( description="Push a GeneMamba model to Hugging Face Hub" ) parser.add_argument( "--model_path", default=str(project_root), help="Path to local model directory. Defaults to project root.", ) parser.add_argument( "--repo_name", required=True, help="Target repo name on Hub (format: username/repo-name)", ) parser.add_argument( "--private", action="store_true", help="Make the repository private", ) parser.add_argument( "--commit_message", default="Upload GeneMamba model", help="Git commit message", ) parser.add_argument( "--sync_delete", action="store_true", help="Delete remote files not present locally (useful to remove stale folders)", ) args = parser.parse_args() model_path = Path(args.model_path).resolve() if "converted_checkpoints" in model_path.parts: print("✗ ERROR: model_path cannot be inside 'converted_checkpoints'.") print(f" - Received: {model_path}") print(f" - Use project root instead: {project_root}") return 1 if not model_path.exists() or not model_path.is_dir(): print(f"✗ ERROR: model_path is not a valid directory: {model_path}") return 1 print("=" * 80) print("GeneMamba Model Upload to Hugging Face Hub") print("=" * 80) # Step 1: Check model files print(f"\n[Step 1] Checking model files in '{model_path}'...") required_files = ["config.json"] optional_files = ["model.safetensors", "pytorch_model.bin", "tokenizer.json"] for file in required_files: filepath = os.path.join(str(model_path), file) if not os.path.exists(filepath): print(f"✗ ERROR: Required file '{file}' not found!") return 1 print(f"✓ All required files present") # Check optional files found_optional = [] for file in optional_files: filepath = os.path.join(str(model_path), file) if os.path.exists(filepath): found_optional.append(file) print(f"✓ Found optional files: {', '.join(found_optional) if found_optional else 'none'}") # Step 2: Copy model definition files print(f"\n[Step 2] Preparing model files...") try: model_path = Path(args.model_path) script_dir = Path(__file__).parent.parent # Files to copy for custom model support model_files = [ "modeling_genemamba.py", "configuration_genemamba.py", "modeling_outputs.py", "README.md", ] print(" - Syncing model definition files...") for file in model_files: src = script_dir / file dst = model_path / file if not src.exists(): print(f" ✗ Missing source file: {file}") return 1 shutil.copy(src, dst) print(f" ✓ Synced {file}") config_path = model_path / "config.json" normalize_config_for_hf(config_path) print(" - Normalized config.json for custom AutoModel loading") print("✓ Model files prepared") except Exception as e: print(f"✗ ERROR: {e}") import traceback traceback.print_exc() return 1 # Step 3: Push to Hub print(f"\n[Step 3] Pushing to Hub...") print(f" - Target repo: {args.repo_name}") print(f" - Private: {args.private}") print(f" - Commit message: {args.commit_message}") print(f" - Sync delete: {args.sync_delete}") try: api = HfApi() api.create_repo(repo_id=args.repo_name, private=args.private, exist_ok=True) api.upload_folder( folder_path=str(model_path), repo_id=args.repo_name, repo_type="model", commit_message=args.commit_message, ) if args.sync_delete: print(" - Syncing remote deletions...") local_files = collect_local_files(model_path) remote_files = set(api.list_repo_files(repo_id=args.repo_name, repo_type="model")) protected_files = {".gitattributes"} stale_files = sorted( [p for p in remote_files if p not in local_files and p not in protected_files] ) for stale_path in stale_files: api.delete_file( path_in_repo=stale_path, repo_id=args.repo_name, repo_type="model", commit_message=f"Remove stale file: {stale_path}", ) print(f" ✓ Removed {len(stale_files)} stale remote files") print(f"✓ Model pushed successfully!") print(f" - URL: https://huggingface.co/{args.repo_name}") except Exception as e: print(f"✗ ERROR during push: {e}") print(f"\nTroubleshooting:") print(f" 1. Make sure you're logged in: huggingface-cli login") print(f" 2. Check that you own the repo or have write access") print(f" 3. If repo doesn't exist, create it first: huggingface-cli repo create {args.repo_name}") return 1 print("\n" + "=" * 80) print("Upload Complete!") print("=" * 80) print(f"\nYou can now load the model with:") print(f" from transformers import AutoModel") print(f" model = AutoModel.from_pretrained('{args.repo_name}', trust_remote_code=True)") return 0 if __name__ == "__main__": exit(main())