#!/usr/bin/env python3
"""
Convert GeneMamba checkpoint to HuggingFace compatible format.

This script converts an existing GeneMamba checkpoint (from the original
training) to be compatible with the HuggingFace Transformers library.

Usage:
    python scripts/convert_checkpoint.py \
        --input_checkpoint /path/to/original/checkpoint \
        --output_dir /path/to/output
"""

import argparse
import json
import os
import shutil
import sys
from pathlib import Path


def _build_hf_config(original_config):
    """Map the original GeneMamba config dict onto a HuggingFace-style config.

    Args:
        original_config: Parsed contents of the original ``config.json``.

    Returns:
        A new dict with HuggingFace-recognized keys, falling back to the
        GeneMamba defaults when a key is absent from the original config.
    """
    return {
        # Model type (CRITICAL for HuggingFace to recognize the model)
        "model_type": "genemamba",
        # Architecture info
        "architectures": ["GeneMambaModel"],
        # Vocabulary and sequence
        "vocab_size": original_config.get("vocab_size", 25426),
        "max_position_embeddings": original_config.get("max_position_embeddings", 2048),
        # Model dimensions (original config uses "d_model" / "mamba_layer")
        "hidden_size": original_config.get("d_model", 512),
        "num_hidden_layers": original_config.get("mamba_layer", 24),
        "intermediate_size": 2048,
        # Regularization
        "hidden_dropout_prob": 0.1,
        "initializer_range": 0.02,
        # Mamba-specific
        "mamba_mode": original_config.get("mamba_mode", "gate"),
        "embedding_pooling": original_config.get("embedding_pooling", "mean"),
        # Task-specific
        "num_labels": 2,
        "pad_token_id": 1,
        "eos_token_id": 2,
        "bos_token_id": 0,
        "use_cache": True,
        # Metadata
        "torch_dtype": original_config.get("torch_dtype", "float32"),
        "transformers_version": "4.40.2",
    }


def _copy_tokenizer_files(input_path, output_path):
    """Copy tokenizer artifacts into the output dir, creating a basic
    ``tokenizer_config.json`` when the original checkpoint lacks one.

    Args:
        input_path: Path to the original checkpoint directory.
        output_path: Path to the converted checkpoint directory.
    """
    tokenizer_file = input_path / "tokenizer.json"
    tokenizer_config_file = input_path / "tokenizer_config.json"

    if tokenizer_file.exists():
        shutil.copy2(tokenizer_file, output_path / "tokenizer.json")
        print("✓ Copied tokenizer.json")
    else:
        print("⚠ tokenizer.json not found (optional)")

    if tokenizer_config_file.exists():
        shutil.copy2(tokenizer_config_file, output_path / "tokenizer_config.json")
        print("✓ Copied tokenizer_config.json")
    else:
        print("⚠ tokenizer_config.json not found (will be created)")
        # Fall back to a minimal config so the HF tokenizer loader still works.
        basic_tokenizer_config = {
            "add_bos_token": True,
            "add_eos_token": False,
            "add_prefix_space": False,
            "bos_token": "<|begin_of_sequence|>",
            "eos_token": "<|end_of_sequence|>",
            "model_max_length": 2048,
            "pad_token": "<|pad|>",
            "tokenizer_class": "PreTrainedTokenizerFast",
            "unk_token": "<|unk|>",
        }
        with open(output_path / "tokenizer_config.json", "w", encoding="utf-8") as f:
            json.dump(basic_tokenizer_config, f, indent=2)
        print("✓ Created tokenizer_config.json")

    # Copy special tokens map (optional)
    special_tokens_map = input_path / "special_tokens_map.json"
    if special_tokens_map.exists():
        shutil.copy2(special_tokens_map, output_path / "special_tokens_map.json")
        print("✓ Copied special_tokens_map.json")


def convert_checkpoint(input_checkpoint_path, output_dir):
    """
    Convert a GeneMamba checkpoint to HuggingFace format.

    Args:
        input_checkpoint_path: Path to the original checkpoint directory
        output_dir: Output directory for the converted checkpoint

    Raises:
        FileNotFoundError: If the input directory, its ``config.json``,
            or its ``model.safetensors`` is missing.
    """
    input_path = Path(input_checkpoint_path)
    output_path = Path(output_dir)

    # Verify input checkpoint exists
    if not input_path.exists():
        raise FileNotFoundError(f"Input checkpoint not found: {input_path}")

    # Check for required files
    config_file = input_path / "config.json"
    model_file = input_path / "model.safetensors"

    if not config_file.exists():
        raise FileNotFoundError(f"config.json not found in {input_path}")
    if not model_file.exists():
        raise FileNotFoundError(f"model.safetensors not found in {input_path}")

    print(f"[Step 1] Reading original checkpoint from: {input_path}")

    # Create output directory
    output_path.mkdir(parents=True, exist_ok=True)

    # Read original config
    with open(config_file, "r", encoding="utf-8") as f:
        original_config = json.load(f)

    print("[Step 2] Converting config.json...")
    hf_config = _build_hf_config(original_config)

    # Save new config
    new_config_path = output_path / "config.json"
    with open(new_config_path, "w", encoding="utf-8") as f:
        json.dump(hf_config, f, indent=2)
    print(f"✓ Saved config.json to {new_config_path}")

    # Copy model weights (safetensors are copied verbatim; only the config changes)
    print("[Step 3] Copying model weights...")
    output_model_file = output_path / "model.safetensors"
    shutil.copy2(model_file, output_model_file)
    print(f"✓ Copied model.safetensors ({os.path.getsize(model_file) / 1e9:.2f} GB)")

    # Copy tokenizer files if they exist
    print("[Step 4] Copying tokenizer files...")
    _copy_tokenizer_files(input_path, output_path)

    print("\n" + "=" * 70)
    print("✓ CONVERSION COMPLETE!")
    print("=" * 70)
    print("\nModel info:")
    print("  Architecture: GeneMamba")
    print(f"  Model Type: {hf_config['model_type']}")
    print(f"  Hidden Size: {hf_config['hidden_size']}")
    print(f"  Num Layers: {hf_config['num_hidden_layers']}")
    print(f"  Vocab Size: {hf_config['vocab_size']}")
    print(f"\nConverted checkpoint saved to: {output_path}")
    print("\nNext step - Upload to HuggingFace Hub:")
    print("  python scripts/push_to_hub.py \\")
    print(f"    --model_path {output_path} \\")
    # FIX: original printed a bare "/" here; show a real placeholder instead.
    print("    --repo_name <username>/<repo-name>")


def main():
    """Parse CLI arguments and run the conversion; exits non-zero on failure."""
    parser = argparse.ArgumentParser(
        description="Convert GeneMamba checkpoint to HuggingFace format",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  # Convert 24L-512D model
  python scripts/convert_checkpoint.py \\
      --input_checkpoint /project/zhiwei/cq5/LLM_checkpoints/GeneMamba/GeneMamba2_24l_512d/1/10m/checkpoint-31250 \\
      --output_dir ./converted_checkpoints/GeneMamba2_24l_512d

  # Convert 48L-768D model
  python scripts/convert_checkpoint.py \\
      --input_checkpoint /project/zhiwei/cq5/LLM_checkpoints/GeneMamba/GeneMamba2_48l_768d/1/4m/checkpoint-31250 \\
      --output_dir ./converted_checkpoints/GeneMamba2_48l_768d
""",
    )
    parser.add_argument(
        "--input_checkpoint",
        required=True,
        help="Path to original GeneMamba checkpoint directory",
    )
    parser.add_argument(
        "--output_dir",
        required=True,
        help="Output directory for HuggingFace compatible checkpoint",
    )

    args = parser.parse_args()

    try:
        convert_checkpoint(args.input_checkpoint, args.output_dir)
    except Exception as e:
        # Surface the failure to the user and to the shell exit status.
        print(f"\n✗ ERROR: {str(e)}")
        sys.exit(1)  # FIX: was bare exit(1); sys.exit is the portable form


if __name__ == "__main__":
    main()