# Wildnerve-tlm01_Hybrid_Model / utils / convert_checkpoints.py
# Provenance: uploaded by WildnerveAI ("Upload 20 files", commit 0861a59, verified).
"""
Utility to convert PyTorch (.pt) checkpoints to Hugging Face (.bin) format
python -m utils.convert_checkpoints --checkpoints checkpoints/stdp_model_epoch_15.pt checkpoints/stdp_model_epoch_20.pt --output hf_stdp_model
"""
import os
import torch
import logging
import argparse
import datetime
from pathlib import Path
from typing import Dict, Any, Optional
import json
import shutil
# Configure module-level logging
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)
def convert_stdp_checkpoint(
    checkpoint_path: str,
    output_dir: str,
    config_path: Optional[str] = None
) -> str:
    """
    Convert an STDP/SNN PyTorch checkpoint to Hugging Face format.

    Writes three files into *output_dir*: ``pytorch_model.bin`` (the
    extracted weights), ``config.json`` (model metadata), and a short
    ``README.md`` recording the conversion's provenance.

    Args:
        checkpoint_path: Path to the .pt checkpoint file.
        output_dir: Directory to save the converted model (created if missing).
        config_path: Optional path to a config.json file; its
            ``STDP_CONFIG`` section, if present, is merged into the config.

    Returns:
        Path to the converted model directory (``output_dir``).

    Raises:
        Exception: Any failure during load/convert/save is logged and
            re-raised unchanged.
    """
    logger.info(f"Converting checkpoint: {checkpoint_path}")
    os.makedirs(output_dir, exist_ok=True)
    try:
        # NOTE(review): torch.load unpickles arbitrary Python objects —
        # only run this on checkpoints from trusted sources.
        checkpoint = torch.load(checkpoint_path, map_location="cpu")

        # Recover the training epoch from a "...epoch_<N>.<ext>" filename.
        checkpoint_filename = os.path.basename(checkpoint_path)
        epoch = None
        if "epoch_" in checkpoint_filename:
            try:
                epoch = int(checkpoint_filename.split("epoch_")[1].split(".")[0])
            except (ValueError, IndexError):
                pass  # filename doesn't follow the convention; leave epoch unset

        # Base config; the merges below may override these keys.
        config = {
            "model_type": "stdp_snn",
            "architectures": ["STDPSpikeNeuralNetwork"],
            "epoch": epoch,
            "original_checkpoint": checkpoint_path,
            "conversion_date": str(datetime.datetime.now())
        }

        # Merge config embedded in the checkpoint, but only when it is a
        # mapping — a non-dict "config" entry would make dict.update() raise.
        if isinstance(checkpoint, dict) and isinstance(checkpoint.get("config"), dict):
            config.update(checkpoint["config"])

        # Merge the STDP section of an external config file, if provided.
        if config_path and os.path.exists(config_path):
            with open(config_path, 'r') as f:
                file_config = json.load(f)
            if "STDP_CONFIG" in file_config:
                config.update(file_config["STDP_CONFIG"])

        # Pull the weights out of whichever layout the checkpoint uses.
        if not isinstance(checkpoint, dict):
            # Bare object (e.g. a tensor): membership tests below would
            # fail on it, so treat the object itself as the weights.
            model_weights = checkpoint
        elif "model_state_dict" in checkpoint:
            model_weights = checkpoint["model_state_dict"]
        elif "state_dict" in checkpoint:
            model_weights = checkpoint["state_dict"]
        elif "weights" in checkpoint:
            model_weights = {"weights": checkpoint["weights"]}
        elif "synaptic_weights" in checkpoint:
            model_weights = {"synaptic_weights": checkpoint["synaptic_weights"]}
        else:
            # No recognized wrapper key: assume the dict itself is the weights.
            model_weights = checkpoint

        # Save the weights under the filename HF tooling expects.
        # (This is a file path, not a directory — renamed from "model_dir".)
        weights_path = os.path.join(output_dir, "pytorch_model.bin")
        torch.save(model_weights, weights_path)
        logger.info(f"Saved model weights to {weights_path}")

        # Save the assembled config next to the weights.
        config_file = os.path.join(output_dir, "config.json")
        with open(config_file, 'w') as f:
            json.dump(config, f, indent=2)
        logger.info(f"Saved model config to {config_file}")

        # Minimal per-model README recording where the weights came from.
        readme_file = os.path.join(output_dir, "README.md")
        with open(readme_file, 'w') as f:
            f.write("# Converted STDP/SNN Model\n\n")
            f.write(f"This model was converted from PyTorch checkpoint: `{checkpoint_path}`\n\n")
            f.write(f"Converted on: {config['conversion_date']}\n")
            if epoch is not None:
                f.write(f"Training epoch: {epoch}\n")

        return output_dir
    except Exception as e:
        logger.error(f"Error converting checkpoint: {e}")
        raise
def _copy_code_files(output_dir: str) -> None:
    """Copy the project's inference-support code files into *output_dir*."""
    code_files = [
        "communicator_STDP.py",
        "config.py",
        "model_Custm.py"
    ]
    # Resolve the repository root from this module's ABSOLUTE location.
    # Without abspath, os.path.dirname(__file__) can be "" when the module
    # is launched from inside its own directory, making the double-dirname
    # resolve relative to the wrong place.
    repo_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
    for file in code_files:
        src_path = os.path.join(repo_root, file)
        if os.path.exists(src_path):
            dst_path = os.path.join(output_dir, file)
            shutil.copy2(src_path, dst_path)
            logger.info(f"Copied {file} to {dst_path}")


def _write_inference_script(output_dir: str) -> str:
    """Write a standalone inference.py into *output_dir*; return its path."""
    inference_script = '''
import torch
import os
import json
import argparse
from pathlib import Path

def load_stdp_model(model_dir):
    """Load STDP model from directory."""
    weights_path = os.path.join(model_dir, "pytorch_model.bin")
    config_path = os.path.join(model_dir, "config.json")

    # Load weights
    weights = torch.load(weights_path, map_location="cpu")

    # Load config
    with open(config_path, 'r') as f:
        config = json.load(f)

    return weights, config

def main():
    parser = argparse.ArgumentParser(description="Run inference with STDP model")
    parser.add_argument("--model", type=str, required=True, help="Model directory")
    parser.add_argument("--input", type=str, required=True, help="Input text or file")
    args = parser.parse_args()

    # Load model
    weights, config = load_stdp_model(args.model)
    print(f"Loaded model from {args.model}")
    print(f"Model config: {json.dumps(config, indent=2)}")

    # Get input
    input_text = args.input
    if os.path.exists(args.input):
        with open(args.input, 'r') as f:
            input_text = f.read()

    print(f"Input text: {input_text[:100]}...")

    # Run inference using communicator_STDP if available
    try:
        from communicator_STDP import CommSTDP
        communicator = CommSTDP({}, device="cpu")
        result = communicator.process(input_text, weights)
        print(f"Result: {result}")
    except ImportError:
        print("communicator_STDP not available. Weights loaded successfully.")
        print(f"Weights shape: {weights.shape if hasattr(weights, 'shape') else '[dict of tensors]'}")

if __name__ == "__main__":
    main()
'''
    inference_path = os.path.join(output_dir, "inference.py")
    with open(inference_path, 'w') as f:
        f.write(inference_script.strip())
    logger.info(f"Created inference script: {inference_path}")
    return inference_path


def _write_upload_readme(output_dir: str, converted_models: list) -> None:
    """Write the top-level README listing every converted model."""
    readme_file = os.path.join(output_dir, "README.md")
    with open(readme_file, 'w') as f:
        f.write("# STDP/SNN Trained Models\n\n")
        f.write("This repository contains STDP/SNN models converted from PyTorch checkpoints for use with Hugging Face's infrastructure.\n\n")
        f.write("## Models Included\n\n")
        for i, model in enumerate(converted_models, 1):
            f.write(f"{i}. `{os.path.basename(model)}`\n")
        f.write("\n## Usage\n\n")
        f.write("```python\n")
        f.write("from transformers import AutoModel\n\n")
        f.write("# Load the model\n")
        f.write("model = AutoModel.from_pretrained('your-username/your-model-name')\n")
        f.write("```\n\n")
        f.write("Or use the included inference.py script:\n\n")
        f.write("```bash\npython inference.py --model ./stdp_model_epoch_15 --input 'Your input text here'\n```")


def prepare_for_hf_upload(
    checkpoint_paths: list,
    output_dir: str,
    config_path: Optional[str] = None,
    include_code: bool = True
) -> str:
    """
    Prepare multiple checkpoints for HF upload with code.

    Each checkpoint is converted into its own subdirectory of
    *output_dir* (named after the checkpoint file); optionally the
    supporting code files and an inference script are copied alongside,
    and a top-level README is generated.

    Args:
        checkpoint_paths: List of paths to checkpoint files.
        output_dir: Directory to save the prepared model.
        config_path: Optional path to config.json file, forwarded to
            convert_stdp_checkpoint.
        include_code: Whether to include inference code.

    Returns:
        Path to the prepared directory (``output_dir``).
    """
    os.makedirs(output_dir, exist_ok=True)

    # Convert every checkpoint into its own subdirectory.
    converted_models = []
    for cp_path in checkpoint_paths:
        model_name = os.path.splitext(os.path.basename(cp_path))[0]
        model_dir = os.path.join(output_dir, model_name)
        converted_models.append(convert_stdp_checkpoint(cp_path, model_dir, config_path))

    # Inference code (support modules + script) is opt-out via include_code.
    if include_code:
        _copy_code_files(output_dir)
        _write_inference_script(output_dir)

    _write_upload_readme(output_dir, converted_models)
    logger.info(f"Prepared {len(converted_models)} models for HF upload in {output_dir}")
    return output_dir
if __name__ == "__main__":
    # Command-line entry point: convert one or more checkpoints and stage
    # them (plus inference code unless --no-code) for a Hugging Face upload.
    cli = argparse.ArgumentParser(
        description="Convert PyTorch checkpoints to Hugging Face format"
    )
    cli.add_argument(
        "--checkpoints", nargs="+", required=True,
        help="Paths to checkpoint files"
    )
    cli.add_argument(
        "--output", type=str, default="hf_model",
        help="Output directory"
    )
    cli.add_argument("--config", type=str, help="Path to config.json file")
    cli.add_argument(
        "--no-code", action="store_true",
        help="Don't include inference code"
    )
    cli_args = cli.parse_args()

    prepare_for_hf_upload(
        cli_args.checkpoints,
        cli_args.output,
        cli_args.config,
        not cli_args.no_code,
    )