"""
Utility to convert PyTorch (.pt) checkpoints to Hugging Face (.bin) format.

Usage::

    python -m utils.convert_checkpoints --checkpoints checkpoints/stdp_model_epoch_15.pt checkpoints/stdp_model_epoch_20.pt --output hf_stdp_model
"""
import argparse
import datetime
import json
import logging
import os
import re
import shutil
from pathlib import Path
from typing import Any, Dict, Optional

import torch

logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)

# Matches the first "epoch_<digits>" in a checkpoint filename, e.g. "stdp_model_epoch_15.pt".
_EPOCH_PATTERN = re.compile(r"epoch_(\d+)")

# Checkpoint keys whose value is already a state dict, checked in priority order.
_STATE_DICT_KEYS = ("model_state_dict", "state_dict")
# Checkpoint keys whose value is a raw weight object; it is re-wrapped under the same key.
_RAW_WEIGHT_KEYS = ("weights", "synaptic_weights")

# Project modules copied next to the converted models when include_code=True.
_CODE_FILES = ("communicator_STDP.py", "config.py", "model_Custm.py")

# Stand-alone script written as inference.py when include_code=True.
_INFERENCE_SCRIPT = '''
import torch
import os
import json
import argparse
from pathlib import Path


def load_stdp_model(model_dir):
    """Load STDP model from directory."""
    weights_path = os.path.join(model_dir, "pytorch_model.bin")
    config_path = os.path.join(model_dir, "config.json")

    # Load weights
    weights = torch.load(weights_path, map_location="cpu")

    # Load config
    with open(config_path, 'r') as f:
        config = json.load(f)

    return weights, config


def main():
    parser = argparse.ArgumentParser(description="Run inference with STDP model")
    parser.add_argument("--model", type=str, required=True, help="Model directory")
    parser.add_argument("--input", type=str, required=True, help="Input text or file")
    args = parser.parse_args()

    # Load model
    weights, config = load_stdp_model(args.model)
    print(f"Loaded model from {args.model}")
    print(f"Model config: {json.dumps(config, indent=2)}")

    # Get input
    input_text = args.input
    if os.path.exists(args.input):
        with open(args.input, 'r') as f:
            input_text = f.read()

    print(f"Input text: {input_text[:100]}...")

    # Run inference using communicator_STDP if available
    try:
        from communicator_STDP import CommSTDP
        communicator = CommSTDP({}, device="cpu")
        result = communicator.process(input_text, weights)
        print(f"Result: {result}")
    except ImportError:
        print("communicator_STDP not available. Weights loaded successfully.")
        print(f"Weights shape: {weights.shape if hasattr(weights, 'shape') else '[dict of tensors]'}")


if __name__ == "__main__":
    main()
'''


def _extract_epoch(filename: str) -> Optional[int]:
    """Return the integer after the first ``epoch_<digits>`` in *filename*, or None if absent."""
    match = _EPOCH_PATTERN.search(filename)
    return int(match.group(1)) if match else None


def _extract_weights(checkpoint: Any) -> Any:
    """Return the weights payload of a loaded checkpoint object.

    Dict checkpoints are searched for known keys in priority order. Anything else,
    including a dict with no recognized key or a non-dict object such as a bare
    tensor, is returned unchanged and treated as the weights themselves.
    """
    # Only dicts support key lookup here; `"key" in tensor` would raise instead of falling back.
    if isinstance(checkpoint, dict):
        for key in _STATE_DICT_KEYS:
            if key in checkpoint:
                return checkpoint[key]
        for key in _RAW_WEIGHT_KEYS:
            if key in checkpoint:
                return {key: checkpoint[key]}
    return checkpoint


def _build_config(
    checkpoint: Any,
    checkpoint_path: str,
    epoch: Optional[int],
    config_path: Optional[str],
) -> Dict[str, Any]:
    """Assemble config.json contents.

    Later sources override earlier ones: built-in defaults, then the checkpoint's
    own ``"config"`` dict, then the ``"STDP_CONFIG"`` section of *config_path*.
    """
    config: Dict[str, Any] = {
        "model_type": "stdp_snn",
        "architectures": ["STDPSpikeNeuralNetwork"],
        "epoch": epoch,
        "original_checkpoint": checkpoint_path,
        "conversion_date": str(datetime.datetime.now()),
    }

    if isinstance(checkpoint, dict) and "config" in checkpoint:
        checkpoint_config = checkpoint["config"]
        if isinstance(checkpoint_config, dict):
            config.update(checkpoint_config)
        else:
            logger.warning(
                f"Ignoring non-dict 'config' entry in {checkpoint_path} "
                f"(type {type(checkpoint_config).__name__})"
            )

    if config_path:
        if os.path.exists(config_path):
            with open(config_path, "r", encoding="utf-8") as f:
                file_config = json.load(f)
            if "STDP_CONFIG" in file_config:
                config.update(file_config["STDP_CONFIG"])
        else:
            # The caller asked for this file explicitly; don't drop it silently.
            logger.warning(f"Config file not found, ignoring: {config_path}")

    return config


def convert_stdp_checkpoint(
    checkpoint_path: str,
    output_dir: str,
    config_path: Optional[str] = None,
) -> str:
    """
    Convert STDP/SNN PyTorch checkpoint to Hugging Face format.

    Writes ``pytorch_model.bin``, ``config.json`` and ``README.md`` into *output_dir*.
    *output_dir* is created if needed.

    Args:
        checkpoint_path: Path to the .pt checkpoint file
        output_dir: Directory to save the converted model
        config_path: Optional path to a JSON file whose "STDP_CONFIG" section is merged
            into the generated config

    Returns:
        Path to the converted model directory (``output_dir``)

    Raises:
        Any error from loading, converting or writing is logged and re-raised.
    """
    logger.info(f"Converting checkpoint: {checkpoint_path}")
    os.makedirs(output_dir, exist_ok=True)

    try:
        # SECURITY NOTE: torch.load unpickles the file. On PyTorch versions where
        # weights_only defaults to False (before 2.6), that can execute arbitrary
        # code -- only convert checkpoints from trusted sources.
        checkpoint = torch.load(checkpoint_path, map_location="cpu")

        epoch = _extract_epoch(os.path.basename(checkpoint_path))
        config = _build_config(checkpoint, checkpoint_path, epoch, config_path)
        model_weights = _extract_weights(checkpoint)

        weights_path = os.path.join(output_dir, "pytorch_model.bin")
        torch.save(model_weights, weights_path)
        logger.info(f"Saved model weights to {weights_path}")

        config_file = os.path.join(output_dir, "config.json")
        with open(config_file, "w", encoding="utf-8") as f:
            # default=str is deliberate: checkpoint-supplied configs may hold non-JSON
            # values (devices, paths, ...). Record their string form rather than
            # aborting after the weights have already been written.
            json.dump(config, f, indent=2, default=str)
        logger.info(f"Saved model config to {config_file}")

        readme_file = os.path.join(output_dir, "README.md")
        with open(readme_file, "w", encoding="utf-8") as f:
            f.write("# Converted STDP/SNN Model\n\n")
            f.write(f"This model was converted from PyTorch checkpoint: `{checkpoint_path}`\n\n")
            f.write(f"Converted on: {config['conversion_date']}\n")
            if epoch is not None:
                f.write(f"Training epoch: {epoch}\n")

        return output_dir
    except Exception as e:
        logger.error(f"Error converting checkpoint: {e}")
        raise


def _copy_code_files(output_dir: str) -> None:
    """Copy the project's inference-time modules into *output_dir*; warn about any that are missing."""
    # This module lives in <project>/utils/, so the project root is two levels up.
    # abspath keeps that correct when __file__ is relative (e.g. run from inside utils/).
    project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
    for file_name in _CODE_FILES:
        src_path = os.path.join(project_root, file_name)
        if not os.path.exists(src_path):
            logger.warning(f"Code file not found, skipping: {src_path}")
            continue
        dst_path = os.path.join(output_dir, file_name)
        shutil.copy2(src_path, dst_path)
        logger.info(f"Copied {file_name} to {dst_path}")


def _write_repo_readme(output_dir: str, converted_models: list, include_code: bool) -> None:
    """Write the top-level README listing *converted_models* and how to load them."""
    readme_file = os.path.join(output_dir, "README.md")
    with open(readme_file, "w", encoding="utf-8") as f:
        f.write("# STDP/SNN Trained Models\n\n")
        f.write("This repository contains STDP/SNN models converted from PyTorch checkpoints for use with Hugging Face's infrastructure.\n\n")
        f.write("## Models Included\n\n")
        for index, model in enumerate(converted_models, start=1):
            f.write(f"{index}. `{os.path.basename(model)}`\n")

        f.write("\n## Usage\n\n")
        # The models are plain PyTorch weights plus a custom config (no root config.json,
        # no registered transformers architecture), so AutoModel cannot load them.
        f.write(
            "Each model directory contains `pytorch_model.bin` and `config.json`. "
            "These are plain PyTorch weights with a custom config, not a registered "
            "`transformers` architecture, so load them with PyTorch:\n\n"
        )
        f.write("```python\n")
        f.write("import json\nimport torch\n\n")
        f.write("weights = torch.load('MODEL_DIR/pytorch_model.bin', map_location='cpu')\n")
        f.write("with open('MODEL_DIR/config.json') as f:\n    config = json.load(f)\n")
        f.write("```\n")

        if include_code:
            example = os.path.basename(converted_models[0]) if converted_models else "MODEL_DIR"
            f.write("\nOr use the included inference.py script:\n\n")
            f.write(f"```bash\npython inference.py --model ./{example} --input 'Your input text here'\n```\n")


def prepare_for_hf_upload(
    checkpoint_paths: list,
    output_dir: str,
    config_path: Optional[str] = None,
    include_code: bool = True,
) -> str:
    """
    Prepare multiple checkpoints for HF upload with code.

    Each checkpoint is converted into ``<output_dir>/<checkpoint stem>/``. A top-level
    README is always written. When *include_code* is true, the project's inference
    modules and a stand-alone ``inference.py`` are also added.

    Args:
        checkpoint_paths: List of paths to checkpoint files
        output_dir: Directory to save the prepared model
        config_path: Optional path to config.json file
        include_code: Whether to include inference code

    Returns:
        Path to the prepared directory
    """
    os.makedirs(output_dir, exist_ok=True)

    converted_models = []
    for cp_path in checkpoint_paths:
        model_name = os.path.splitext(os.path.basename(cp_path))[0]
        model_dir = os.path.join(output_dir, model_name)
        converted_models.append(convert_stdp_checkpoint(cp_path, model_dir, config_path))

    if include_code:
        _copy_code_files(output_dir)
        inference_path = os.path.join(output_dir, "inference.py")
        with open(inference_path, "w", encoding="utf-8") as f:
            f.write(_INFERENCE_SCRIPT.strip() + "\n")
        logger.info(f"Created inference script: {inference_path}")

    _write_repo_readme(output_dir, converted_models, include_code)

    logger.info(f"Prepared {len(converted_models)} models for HF upload in {output_dir}")
    return output_dir


def main() -> None:
    """Command-line entry point: parse arguments and run prepare_for_hf_upload."""
    parser = argparse.ArgumentParser(description="Convert PyTorch checkpoints to Hugging Face format")
    parser.add_argument("--checkpoints", nargs="+", required=True, help="Paths to checkpoint files")
    parser.add_argument("--output", type=str, default="hf_model", help="Output directory")
    parser.add_argument("--config", type=str, help="Path to config.json file")
    parser.add_argument("--no-code", action="store_true", help="Don't include inference code")
    args = parser.parse_args()

    prepare_for_hf_upload(
        args.checkpoints,
        args.output,
        config_path=args.config,
        include_code=not args.no_code,
    )


if __name__ == "__main__":
    main()