|
|
"""
|
|
|
Utility to convert PyTorch (.pt) checkpoints to Hugging Face (.bin) format
|
|
|
python -m utils.convert_checkpoints --checkpoints checkpoints/stdp_model_epoch_15.pt checkpoints/stdp_model_epoch_20.pt --output hf_stdp_model
|
|
|
"""
|
|
|
import os
|
|
|
import torch
|
|
|
import logging
|
|
|
import argparse
|
|
|
import datetime
|
|
|
from pathlib import Path
|
|
|
from typing import Dict, Any, Optional
|
|
|
import json
|
|
|
import shutil
|
|
|
|
|
|
|
|
|
# Configure root logging once at import time: timestamped, module-named records.
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")


# Module-level logger shared by all conversion helpers below.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
def convert_stdp_checkpoint(
    checkpoint_path: str,
    output_dir: str,
    config_path: Optional[str] = None
) -> str:
    """
    Convert STDP/SNN PyTorch checkpoint to Hugging Face format.

    Writes three files into ``output_dir``: ``pytorch_model.bin`` (the
    extracted weights), ``config.json`` (HF-style metadata), and a small
    ``README.md``.

    Args:
        checkpoint_path: Path to the .pt checkpoint file
        output_dir: Directory to save the converted model
        config_path: Optional path to config.json file; if it contains an
            ``STDP_CONFIG`` key, that sub-dict is merged into the model config.

    Returns:
        Path to the converted model directory (``output_dir``).

    Raises:
        Exception: any load/save error is logged and re-raised.
    """
    logger.info(f"Converting checkpoint: {checkpoint_path}")

    os.makedirs(output_dir, exist_ok=True)

    try:
        # NOTE(review): torch.load unpickles arbitrary objects -- only run
        # this on trusted checkpoints (weights_only=True would be safer
        # where the torch version and checkpoint contents allow it).
        checkpoint = torch.load(checkpoint_path, map_location="cpu")

        # Recover the training epoch from filenames like
        # "stdp_model_epoch_15.pt"; leave None if the pattern doesn't match.
        checkpoint_filename = os.path.basename(checkpoint_path)
        epoch = None
        if "epoch_" in checkpoint_filename:
            try:
                epoch = int(checkpoint_filename.split("epoch_")[1].split(".")[0])
            except (ValueError, IndexError):
                pass  # filename doesn't follow the epoch naming scheme

        # Base HF-style config; extended below from the checkpoint itself
        # and/or an external config file.
        config = {
            "model_type": "stdp_snn",
            "architectures": ["STDPSpikeNeuralNetwork"],
            "epoch": epoch,
            "original_checkpoint": checkpoint_path,
            "conversion_date": str(datetime.datetime.now())
        }

        is_dict_checkpoint = isinstance(checkpoint, dict)

        if is_dict_checkpoint and "config" in checkpoint:
            config.update(checkpoint["config"])

        if config_path and os.path.exists(config_path):
            with open(config_path, 'r') as f:
                file_config = json.load(f)
            if "STDP_CONFIG" in file_config:
                config.update(file_config["STDP_CONFIG"])

        # Extract the weights, trying the common checkpoint layouts in order.
        # Fix: the membership tests are only valid on dict checkpoints; the
        # original ran `"model_state_dict" in checkpoint` unguarded, which
        # raises TypeError for a bare-tensor checkpoint.
        if is_dict_checkpoint:
            if "model_state_dict" in checkpoint:
                model_weights = checkpoint["model_state_dict"]
            elif "state_dict" in checkpoint:
                model_weights = checkpoint["state_dict"]
            elif "weights" in checkpoint:
                model_weights = {"weights": checkpoint["weights"]}
            elif "synaptic_weights" in checkpoint:
                model_weights = {"synaptic_weights": checkpoint["synaptic_weights"]}
            else:
                # No known key: assume the whole dict is already a state dict.
                model_weights = checkpoint
        else:
            # Non-dict checkpoint (e.g. a bare tensor): save it as-is.
            model_weights = checkpoint

        # Fix: this is a file path, not a directory -- renamed from
        # the misleading "model_dir".
        weights_path = os.path.join(output_dir, "pytorch_model.bin")
        torch.save(model_weights, weights_path)
        logger.info(f"Saved model weights to {weights_path}")

        config_file = os.path.join(output_dir, "config.json")
        with open(config_file, 'w') as f:
            json.dump(config, f, indent=2)
        logger.info(f"Saved model config to {config_file}")

        # Minimal per-model README describing provenance of the conversion.
        readme_file = os.path.join(output_dir, "README.md")
        with open(readme_file, 'w') as f:
            f.write("# Converted STDP/SNN Model\n\n")
            f.write(f"This model was converted from PyTorch checkpoint: `{checkpoint_path}`\n\n")
            f.write(f"Converted on: {config['conversion_date']}\n")
            if epoch is not None:
                f.write(f"Training epoch: {epoch}\n")

        return output_dir

    except Exception as e:
        logger.error(f"Error converting checkpoint: {e}")
        raise
|
|
|
|
|
|
def prepare_for_hf_upload(
    checkpoint_paths: list,
    output_dir: str,
    config_path: Optional[str] = None,
    include_code: bool = True
) -> str:
    """
    Prepare multiple checkpoints for HF upload with code.

    Each checkpoint is converted into its own subdirectory of ``output_dir``
    (named after the checkpoint file). Optionally copies the project source
    files and writes a standalone ``inference.py``, then writes a top-level
    README listing every converted model.

    Args:
        checkpoint_paths: List of paths to checkpoint files
        output_dir: Directory to save the prepared model
        config_path: Optional path to config.json file
        include_code: Whether to include inference code

    Returns:
        Path to the prepared directory (``output_dir``).
    """
    os.makedirs(output_dir, exist_ok=True)

    # Convert every checkpoint into its own subdirectory named after the file.
    converted_models = []
    for cp_path in checkpoint_paths:
        model_name = os.path.splitext(os.path.basename(cp_path))[0]
        model_dir = os.path.join(output_dir, model_name)
        converted_models.append(convert_stdp_checkpoint(cp_path, model_dir, config_path))

    if include_code:
        # Project source files expected one directory above this utils module.
        code_files = [
            "communicator_STDP.py",
            "config.py",
            "model_Custm.py"
        ]

        # Hoisted out of the loop: the project root does not change per file.
        project_root = os.path.dirname(os.path.dirname(__file__))
        for code_file in code_files:
            src_path = os.path.join(project_root, code_file)
            if os.path.exists(src_path):
                dst_path = os.path.join(output_dir, code_file)
                shutil.copy2(src_path, dst_path)
                logger.info(f"Copied {code_file} to {dst_path}")
            else:
                # Fix: a missing source file was previously skipped silently.
                logger.warning(f"Code file not found, skipping: {src_path}")

        # Standalone loader/inference script shipped alongside the models.
        inference_script = '''
import torch
import os
import json
import argparse
from pathlib import Path

def load_stdp_model(model_dir):
    """Load STDP model from directory."""
    weights_path = os.path.join(model_dir, "pytorch_model.bin")
    config_path = os.path.join(model_dir, "config.json")

    # Load weights
    weights = torch.load(weights_path, map_location="cpu")

    # Load config
    with open(config_path, 'r') as f:
        config = json.load(f)

    return weights, config

def main():
    parser = argparse.ArgumentParser(description="Run inference with STDP model")
    parser.add_argument("--model", type=str, required=True, help="Model directory")
    parser.add_argument("--input", type=str, required=True, help="Input text or file")
    args = parser.parse_args()

    # Load model
    weights, config = load_stdp_model(args.model)
    print(f"Loaded model from {args.model}")
    print(f"Model config: {json.dumps(config, indent=2)}")

    # Get input
    input_text = args.input
    if os.path.exists(args.input):
        with open(args.input, 'r') as f:
            input_text = f.read()

    print(f"Input text: {input_text[:100]}...")

    # Run inference using communicator_STDP if available
    try:
        from communicator_STDP import CommSTDP
        communicator = CommSTDP({}, device="cpu")
        result = communicator.process(input_text, weights)
        print(f"Result: {result}")
    except ImportError:
        print("communicator_STDP not available. Weights loaded successfully.")
        print(f"Weights shape: {weights.shape if hasattr(weights, 'shape') else '[dict of tensors]'}")

if __name__ == "__main__":
    main()
'''

        inference_path = os.path.join(output_dir, "inference.py")
        with open(inference_path, 'w') as f:
            f.write(inference_script.strip())
        logger.info(f"Created inference script: {inference_path}")

    # Top-level README listing every converted model and basic usage.
    readme_file = os.path.join(output_dir, "README.md")
    with open(readme_file, 'w') as f:
        f.write("# STDP/SNN Trained Models\n\n")
        f.write("This repository contains STDP/SNN models converted from PyTorch checkpoints for use with Hugging Face's infrastructure.\n\n")
        f.write("## Models Included\n\n")
        # Fix: enumerate(start=1) instead of manual i+1 arithmetic.
        for idx, model in enumerate(converted_models, start=1):
            f.write(f"{idx}. `{os.path.basename(model)}`\n")

        f.write("\n## Usage\n\n")
        f.write("```python\n")
        f.write("from transformers import AutoModel\n\n")
        f.write("# Load the model\n")
        f.write("model = AutoModel.from_pretrained('your-username/your-model-name')\n")
        f.write("```\n\n")
        f.write("Or use the included inference.py script:\n\n")
        f.write("```bash\npython inference.py --model ./stdp_model_epoch_15 --input 'Your input text here'\n```")

    logger.info(f"Prepared {len(converted_models)} models for HF upload in {output_dir}")
    return output_dir
|
|
|
|
|
|
if __name__ == "__main__":
    # Command-line entry point: convert one or more checkpoints and stage
    # them (plus optional inference code) for a Hugging Face upload.
    arg_parser = argparse.ArgumentParser(
        description="Convert PyTorch checkpoints to Hugging Face format"
    )
    arg_parser.add_argument(
        "--checkpoints", nargs="+", required=True,
        help="Paths to checkpoint files"
    )
    arg_parser.add_argument(
        "--output", type=str, default="hf_model",
        help="Output directory"
    )
    arg_parser.add_argument(
        "--config", type=str,
        help="Path to config.json file"
    )
    arg_parser.add_argument(
        "--no-code", action="store_true",
        help="Don't include inference code"
    )

    cli_args = arg_parser.parse_args()

    prepare_for_hf_upload(
        cli_args.checkpoints,
        cli_args.output,
        cli_args.config,
        not cli_args.no_code,
    )
|
|
|
|