#!/usr/bin/env python3
"""
Convert GeneMamba checkpoint to HuggingFace compatible format.
This script converts an existing GeneMamba checkpoint (from the original training)
to be compatible with the HuggingFace Transformers library.
Usage:
python scripts/convert_checkpoint.py \
--input_checkpoint /path/to/original/checkpoint \
--output_dir /path/to/output
"""
import os
import json
import shutil
import argparse
from pathlib import Path
def convert_checkpoint(input_checkpoint_path, output_dir):
    """
    Convert a GeneMamba checkpoint to HuggingFace format.

    Reads the original ``config.json``, rewrites it with the keys the
    HuggingFace Transformers library expects (``model_type``,
    ``architectures``, standard dimension names), copies the model weights
    and any tokenizer files into the output directory, and prints a summary.

    Args:
        input_checkpoint_path: Path to the original checkpoint directory.
            Must contain ``config.json`` and ``model.safetensors``.
        output_dir: Output directory for the converted checkpoint
            (created if it does not exist).

    Raises:
        FileNotFoundError: If the input directory, ``config.json``, or
            ``model.safetensors`` is missing.
    """
    input_path = Path(input_checkpoint_path)
    output_path = Path(output_dir)

    # Verify input checkpoint exists
    if not input_path.exists():
        raise FileNotFoundError(f"Input checkpoint not found: {input_path}")

    # Check for required files
    config_file = input_path / "config.json"
    model_file = input_path / "model.safetensors"
    tokenizer_file = input_path / "tokenizer.json"
    tokenizer_config_file = input_path / "tokenizer_config.json"

    if not config_file.exists():
        raise FileNotFoundError(f"config.json not found in {input_path}")
    if not model_file.exists():
        raise FileNotFoundError(f"model.safetensors not found in {input_path}")

    print(f"[Step 1] Reading original checkpoint from: {input_path}")

    # Create output directory
    output_path.mkdir(parents=True, exist_ok=True)

    # Read original config
    with open(config_file, 'r') as f:
        original_config = json.load(f)

    print("[Step 2] Converting config.json...")

    # Create new HuggingFace-compatible config.  Fallback defaults match the
    # original GeneMamba training setup; original keys use GeneMamba naming
    # (d_model, mamba_layer) and are mapped to the HF-standard names.
    hf_config = {
        # Model type (CRITICAL for HuggingFace to recognize the model)
        "model_type": "genemamba",
        # Architecture info
        "architectures": ["GeneMambaModel"],
        # Vocabulary and sequence
        "vocab_size": original_config.get("vocab_size", 25426),
        "max_position_embeddings": original_config.get("max_position_embeddings", 2048),
        # Model dimensions
        "hidden_size": original_config.get("d_model", 512),
        "num_hidden_layers": original_config.get("mamba_layer", 24),
        "intermediate_size": 2048,
        # Regularization
        "hidden_dropout_prob": 0.1,
        "initializer_range": 0.02,
        # Mamba-specific
        "mamba_mode": original_config.get("mamba_mode", "gate"),
        "embedding_pooling": original_config.get("embedding_pooling", "mean"),
        # Task-specific
        "num_labels": 2,
        "pad_token_id": 1,
        "eos_token_id": 2,
        "bos_token_id": 0,
        "use_cache": True,
        # Metadata
        "torch_dtype": original_config.get("torch_dtype", "float32"),
        "transformers_version": "4.40.2",
    }

    # Save new config
    new_config_path = output_path / "config.json"
    with open(new_config_path, 'w') as f:
        json.dump(hf_config, f, indent=2)
    print(f"✓ Saved config.json to {new_config_path}")

    # Copy model weights (copy2 preserves metadata/timestamps)
    print("[Step 3] Copying model weights...")
    output_model_file = output_path / "model.safetensors"
    shutil.copy2(model_file, output_model_file)
    print(f"✓ Copied model.safetensors ({model_file.stat().st_size / 1e9:.2f} GB)")

    # Copy tokenizer files if they exist
    print("[Step 4] Copying tokenizer files...")
    if tokenizer_file.exists():
        shutil.copy2(tokenizer_file, output_path / "tokenizer.json")
        print("✓ Copied tokenizer.json")
    else:
        print("⚠ tokenizer.json not found (optional)")

    if tokenizer_config_file.exists():
        shutil.copy2(tokenizer_config_file, output_path / "tokenizer_config.json")
        print("✓ Copied tokenizer_config.json")
    else:
        print("⚠ tokenizer_config.json not found (will be created)")
        # Create a basic tokenizer config if it doesn't exist
        basic_tokenizer_config = {
            "add_bos_token": True,
            "add_eos_token": False,
            "add_prefix_space": False,
            "bos_token": "<|begin_of_sequence|>",
            "eos_token": "<|end_of_sequence|>",
            "model_max_length": 2048,
            "pad_token": "<|pad|>",
            "tokenizer_class": "PreTrainedTokenizerFast",
            "unk_token": "<|unk|>",
        }
        with open(output_path / "tokenizer_config.json", 'w') as f:
            json.dump(basic_tokenizer_config, f, indent=2)
        print("✓ Created tokenizer_config.json")

    # Copy special tokens map (optional, silently skipped if absent)
    special_tokens_map = input_path / "special_tokens_map.json"
    if special_tokens_map.exists():
        shutil.copy2(special_tokens_map, output_path / "special_tokens_map.json")
        print("✓ Copied special_tokens_map.json")

    print("\n" + "=" * 70)
    print("✓ CONVERSION COMPLETE!")
    print("=" * 70)
    print("\nModel info:")
    print(f"  Architecture: GeneMamba")
    print(f"  Model Type: {hf_config['model_type']}")
    print(f"  Hidden Size: {hf_config['hidden_size']}")
    print(f"  Num Layers: {hf_config['num_hidden_layers']}")
    print(f"  Vocab Size: {hf_config['vocab_size']}")
    print(f"\nConverted checkpoint saved to: {output_path}")
    print("\nNext step - Upload to HuggingFace Hub:")
    print("  python scripts/push_to_hub.py \\")
    print(f"    --model_path {output_path} \\")
    print("    --repo_name <your_username>/<repo_name>")
def main():
    """
    Command-line entry point.

    Parses ``--input_checkpoint`` and ``--output_dir`` (both required) and
    runs the conversion, exiting with status 1 on any failure so shell
    pipelines can detect errors.
    """
    parser = argparse.ArgumentParser(
        description="Convert GeneMamba checkpoint to HuggingFace format",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  # Convert 24L-512D model
  python scripts/convert_checkpoint.py \\
    --input_checkpoint /project/zhiwei/cq5/LLM_checkpoints/GeneMamba/GeneMamba2_24l_512d/1/10m/checkpoint-31250 \\
    --output_dir ./converted_checkpoints/GeneMamba2_24l_512d

  # Convert 48L-768D model
  python scripts/convert_checkpoint.py \\
    --input_checkpoint /project/zhiwei/cq5/LLM_checkpoints/GeneMamba/GeneMamba2_48l_768d/1/4m/checkpoint-31250 \\
    --output_dir ./converted_checkpoints/GeneMamba2_48l_768d
""")
    parser.add_argument(
        "--input_checkpoint",
        required=True,
        help="Path to original GeneMamba checkpoint directory"
    )
    parser.add_argument(
        "--output_dir",
        required=True,
        help="Output directory for HuggingFace compatible checkpoint"
    )
    args = parser.parse_args()

    # Broad catch is deliberate: this is the top-level CLI boundary, and any
    # failure should surface as a message plus a non-zero exit code.
    try:
        convert_checkpoint(args.input_checkpoint, args.output_dir)
    except Exception as e:
        print(f"\n✗ ERROR: {e}")
        # raise SystemExit instead of the site-module exit() builtin, which
        # is not guaranteed to exist (e.g. under `python -S`).
        raise SystemExit(1)


if __name__ == "__main__":
    main()