# GeneMamba/scripts/push_to_hub.py
# NOTE(review): the original lines here were scraped Hugging Face page chrome
# (uploader, commit message "Sync latest GeneMamba docs and next-token training
# updates", hash ea25230) — converted to comments so the module stays valid Python.
"""
Utility script to push GeneMamba model to Hugging Face Hub.
Usage:
python scripts/push_to_hub.py --model_path ./my_checkpoint --repo_name username/GeneMamba-custom
Requirements:
- Hugging Face CLI: huggingface-cli login
- Git LFS installed (for large model files)
"""
import os
import shutil
import argparse
import json
from pathlib import Path
from huggingface_hub import HfApi
def collect_local_files(root: Path) -> set:
    """Return the set of repo-relative POSIX paths for every file under *root*.

    Compiled-bytecode artifacts (``__pycache__`` directories and ``.pyc``
    files) are excluded so they are never counted as part of the model repo
    when computing which remote files are stale.
    """
    return {
        entry.relative_to(root).as_posix()
        for entry in root.rglob("*")
        if entry.is_file()
        and "__pycache__" not in entry.parts
        and entry.suffix != ".pyc"
    }
def normalize_config_for_hf(config_path: Path) -> None:
    """Rewrite ``config.json`` in place so transformers' AutoModel machinery
    can load the checkpoint via ``trust_remote_code``.

    Maps legacy GeneMamba keys (``d_model``/``mamba_layer``) onto the
    HF-standard names, fills in missing defaults without overwriting existing
    values, and pins the ``auto_map`` entries to the bundled modeling files.
    """
    config = json.loads(config_path.read_text(encoding="utf-8"))

    # A config carrying either legacy key came from an old training checkpoint.
    legacy_checkpoint_config = ("d_model" in config) or ("mamba_layer" in config)
    if "d_model" in config and "hidden_size" not in config:
        config["hidden_size"] = config["d_model"]
    if "mamba_layer" in config and "num_hidden_layers" not in config:
        config["num_hidden_layers"] = config["mamba_layer"]

    config["model_type"] = "genemamba"

    # Defaults only apply when the key is absent; existing values win.
    for key, default in (
        ("architectures", ["GeneMambaModel"]),
        ("max_position_embeddings", 2048),
        ("intermediate_size", 2048),
        ("hidden_dropout_prob", 0.1),
        ("initializer_range", 0.02),
    ):
        config.setdefault(key, default)

    # Legacy checkpoints that trained with the "gate" pooling mode are forced
    # to "mean"; otherwise "mean" is only supplied when the key is missing.
    if legacy_checkpoint_config and config.get("mamba_mode") == "gate":
        config["mamba_mode"] = "mean"
    else:
        config.setdefault("mamba_mode", "mean")

    for key, default in (
        ("embedding_pooling", "mean"),
        ("num_labels", 2),
        ("pad_token_id", 1),
        ("bos_token_id", 0),
        ("eos_token_id", 2),
        ("use_cache", True),
        ("torch_dtype", "float32"),
        ("transformers_version", "4.40.2"),
    ):
        config.setdefault(key, default)

    # Always overwrite auto_map so the repo points at the synced module files.
    config["auto_map"] = {
        "AutoConfig": "configuration_genemamba.GeneMambaConfig",
        "AutoModel": "modeling_genemamba.GeneMambaModel",
        "AutoModelForMaskedLM": "modeling_genemamba.GeneMambaForMaskedLM",
        "AutoModelForSequenceClassification": "modeling_genemamba.GeneMambaForSequenceClassification",
    }

    config_path.write_text(
        json.dumps(config, indent=2) + "\n", encoding="utf-8"
    )
def main() -> int:
    """Validate a local GeneMamba checkpoint, stage the custom modeling code
    next to it, and push the directory to the Hugging Face Hub.

    Returns:
        0 on success, 1 on any validation or upload failure — suitable for
        passing directly to ``SystemExit``.
    """
    project_root = Path(__file__).resolve().parent.parent
    parser = argparse.ArgumentParser(
        description="Push a GeneMamba model to Hugging Face Hub"
    )
    parser.add_argument(
        "--model_path",
        default=str(project_root),
        help="Path to local model directory. Defaults to project root.",
    )
    parser.add_argument(
        "--repo_name",
        required=True,
        help="Target repo name on Hub (format: username/repo-name)",
    )
    parser.add_argument(
        "--private",
        action="store_true",
        help="Make the repository private",
    )
    parser.add_argument(
        "--commit_message",
        default="Upload GeneMamba model",
        help="Git commit message",
    )
    parser.add_argument(
        "--sync_delete",
        action="store_true",
        help="Delete remote files not present locally (useful to remove stale folders)",
    )
    args = parser.parse_args()

    # Resolve once and use this path everywhere below; the original code
    # rebound model_path to an unresolved Path mid-way through, which could
    # diverge from the path that was validated here.
    model_path = Path(args.model_path).resolve()

    # A converted checkpoint directory lacks the custom modeling code, so
    # pushing it directly would produce an unloadable repo.
    if "converted_checkpoints" in model_path.parts:
        print("βœ— ERROR: model_path cannot be inside 'converted_checkpoints'.")
        print(f"   - Received: {model_path}")
        print(f"   - Use project root instead: {project_root}")
        return 1
    if not model_path.exists() or not model_path.is_dir():
        print(f"βœ— ERROR: model_path is not a valid directory: {model_path}")
        return 1

    print("=" * 80)
    print("GeneMamba Model Upload to Hugging Face Hub")
    print("=" * 80)

    # Step 1: Check model files
    print(f"\n[Step 1] Checking model files in '{model_path}'...")
    required_files = ["config.json"]
    optional_files = ["model.safetensors", "pytorch_model.bin", "tokenizer.json"]
    for file in required_files:
        if not (model_path / file).exists():
            print(f"βœ— ERROR: Required file '{file}' not found!")
            return 1
    print("βœ“ All required files present")

    # Optional weight/tokenizer files are reported but never block the upload.
    found_optional = [file for file in optional_files if (model_path / file).exists()]
    print(f"βœ“ Found optional files: {', '.join(found_optional) if found_optional else 'none'}")

    # Step 2: Copy model definition files so trust_remote_code loading works.
    print("\n[Step 2] Preparing model files...")
    try:
        model_files = [
            "modeling_genemamba.py",
            "configuration_genemamba.py",
            "modeling_outputs.py",
            "README.md",
        ]
        print("   - Syncing model definition files...")
        for file in model_files:
            src = project_root / file
            dst = model_path / file
            if not src.exists():
                print(f"   βœ— Missing source file: {file}")
                return 1
            shutil.copy(src, dst)
            print(f"   βœ“ Synced {file}")
        normalize_config_for_hf(model_path / "config.json")
        print("   - Normalized config.json for custom AutoModel loading")
        print("βœ“ Model files prepared")
    except Exception as e:
        print(f"βœ— ERROR: {e}")
        import traceback
        traceback.print_exc()
        return 1

    # Step 3: Push to Hub
    print("\n[Step 3] Pushing to Hub...")
    print(f"   - Target repo: {args.repo_name}")
    print(f"   - Private: {args.private}")
    print(f"   - Commit message: {args.commit_message}")
    print(f"   - Sync delete: {args.sync_delete}")
    try:
        api = HfApi()
        api.create_repo(repo_id=args.repo_name, private=args.private, exist_ok=True)
        api.upload_folder(
            folder_path=str(model_path),
            repo_id=args.repo_name,
            repo_type="model",
            commit_message=args.commit_message,
        )
        if args.sync_delete:
            # Remove remote files that no longer exist locally, but never
            # touch .gitattributes (it carries the LFS tracking rules).
            print("   - Syncing remote deletions...")
            local_files = collect_local_files(model_path)
            remote_files = set(api.list_repo_files(repo_id=args.repo_name, repo_type="model"))
            protected_files = {".gitattributes"}
            stale_files = sorted(
                p for p in remote_files
                if p not in local_files and p not in protected_files
            )
            for stale_path in stale_files:
                api.delete_file(
                    path_in_repo=stale_path,
                    repo_id=args.repo_name,
                    repo_type="model",
                    commit_message=f"Remove stale file: {stale_path}",
                )
            print(f"   βœ“ Removed {len(stale_files)} stale remote files")
        print("βœ“ Model pushed successfully!")
        print(f"   - URL: https://huggingface.co/{args.repo_name}")
    except Exception as e:
        print(f"βœ— ERROR during push: {e}")
        print("\nTroubleshooting:")
        print("   1. Make sure you're logged in: huggingface-cli login")
        print("   2. Check that you own the repo or have write access")
        print(f"   3. If repo doesn't exist, create it first: huggingface-cli repo create {args.repo_name}")
        return 1

    print("\n" + "=" * 80)
    print("Upload Complete!")
    print("=" * 80)
    print("\nYou can now load the model with:")
    print("   from transformers import AutoModel")
    print(f"   model = AutoModel.from_pretrained('{args.repo_name}', trust_remote_code=True)")
    return 0
if __name__ == "__main__":
    # exit() is a site.py convenience that may be absent (e.g. `python -S`);
    # raising SystemExit propagates main()'s return code without extra imports.
    raise SystemExit(main())