File size: 9,700 Bytes
0861a59
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
"""

Utility to convert PyTorch (.pt) checkpoints to Hugging Face (.bin) format

python -m utils.convert_checkpoints --checkpoints checkpoints/stdp_model_epoch_15.pt checkpoints/stdp_model_epoch_20.pt --output hf_stdp_model

"""
import os
import torch
import logging
import argparse
import datetime
from pathlib import Path
from typing import Dict, Any, Optional
import json
import shutil

# Module-level logging configuration for the conversion utility.
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)

def convert_stdp_checkpoint(
    checkpoint_path: str,
    output_dir: str,
    config_path: Optional[str] = None,
) -> str:
    """
    Convert an STDP/SNN PyTorch checkpoint to Hugging Face layout.

    Writes ``pytorch_model.bin``, ``config.json`` and a short ``README.md``
    into *output_dir*, creating the directory if needed.

    Args:
        checkpoint_path: Path to the .pt checkpoint file.
        output_dir: Directory to save the converted model.
        config_path: Optional path to a config.json file whose
            ``STDP_CONFIG`` section is merged into the model config.

    Returns:
        Path to the converted model directory (``output_dir``).

    Raises:
        Exception: Any error raised while loading or writing files is
            logged and re-raised unchanged.
    """
    logger.info(f"Converting checkpoint: {checkpoint_path}")

    os.makedirs(output_dir, exist_ok=True)

    try:
        # NOTE(review): torch.load unpickles arbitrary objects -- only run
        # on trusted checkpoints (consider weights_only=True on newer torch).
        checkpoint = torch.load(checkpoint_path, map_location="cpu")

        # Recover the training epoch from a filename like "..._epoch_15.pt".
        checkpoint_filename = os.path.basename(checkpoint_path)
        epoch = None
        if "epoch_" in checkpoint_filename:
            try:
                epoch = int(checkpoint_filename.split("epoch_")[1].split(".")[0])
            except (ValueError, IndexError):
                pass  # filename doesn't follow the epoch naming scheme

        # Base config; entries may be overridden by the merges below
        # (checkpoint-embedded config first, then the external file).
        config = {
            "model_type": "stdp_snn",
            "architectures": ["STDPSpikeNeuralNetwork"],
            "epoch": epoch,
            "original_checkpoint": checkpoint_path,
            "conversion_date": str(datetime.datetime.now())
        }

        # Merge config stored inside the checkpoint, if any.
        if isinstance(checkpoint, dict) and "config" in checkpoint:
            config.update(checkpoint["config"])

        # Merge the STDP section of an external config file, if provided.
        if config_path and os.path.exists(config_path):
            with open(config_path, 'r') as f:
                file_config = json.load(f)
            if "STDP_CONFIG" in file_config:
                config.update(file_config["STDP_CONFIG"])

        # Locate the weights under the common checkpoint layouts.
        # BUGFIX: the membership tests must only run on dicts -- a raw
        # tensor checkpoint would raise on `"key" in tensor` before ever
        # reaching the intended fallback branch.
        if isinstance(checkpoint, dict):
            if "model_state_dict" in checkpoint:
                model_weights = checkpoint["model_state_dict"]
            elif "state_dict" in checkpoint:
                model_weights = checkpoint["state_dict"]
            elif "weights" in checkpoint:
                model_weights = {"weights": checkpoint["weights"]}
            elif "synaptic_weights" in checkpoint:
                model_weights = {"synaptic_weights": checkpoint["synaptic_weights"]}
            else:
                # No recognized wrapper key: assume the dict IS the state dict.
                model_weights = checkpoint
        else:
            # Non-dict checkpoint (e.g. a bare tensor): treat it as the weights.
            model_weights = checkpoint

        # Save weights under the HF-conventional filename.
        weights_path = os.path.join(output_dir, "pytorch_model.bin")
        torch.save(model_weights, weights_path)
        logger.info(f"Saved model weights to {weights_path}")

        # Save the model config alongside the weights.
        config_file = os.path.join(output_dir, "config.json")
        with open(config_file, 'w') as f:
            json.dump(config, f, indent=2)
        logger.info(f"Saved model config to {config_file}")

        # Create a simple per-model README.
        readme_file = os.path.join(output_dir, "README.md")
        with open(readme_file, 'w') as f:
            f.write(f"# Converted STDP/SNN Model\n\n")
            f.write(f"This model was converted from PyTorch checkpoint: `{checkpoint_path}`\n\n")
            f.write(f"Converted on: {config['conversion_date']}\n")
            if epoch is not None:
                f.write(f"Training epoch: {epoch}\n")

        return output_dir

    except Exception as e:
        logger.error(f"Error converting checkpoint: {e}")
        raise

# Standalone inference script written into the output directory so the
# uploaded repository is runnable without this conversion utility.
_INFERENCE_SCRIPT = '''
import torch
import os
import json
import argparse
from pathlib import Path


def load_stdp_model(model_dir):
    """Load STDP model from directory."""
    weights_path = os.path.join(model_dir, "pytorch_model.bin")
    config_path = os.path.join(model_dir, "config.json")

    # Load weights
    weights = torch.load(weights_path, map_location="cpu")

    # Load config
    with open(config_path, 'r') as f:
        config = json.load(f)

    return weights, config


def main():
    parser = argparse.ArgumentParser(description="Run inference with STDP model")
    parser.add_argument("--model", type=str, required=True, help="Model directory")
    parser.add_argument("--input", type=str, required=True, help="Input text or file")
    args = parser.parse_args()

    # Load model
    weights, config = load_stdp_model(args.model)
    print(f"Loaded model from {args.model}")
    print(f"Model config: {json.dumps(config, indent=2)}")

    # Get input (treat --input as a file path if one exists)
    input_text = args.input
    if os.path.exists(args.input):
        with open(args.input, 'r') as f:
            input_text = f.read()

    print(f"Input text: {input_text[:100]}...")

    # Run inference using communicator_STDP if available
    try:
        from communicator_STDP import CommSTDP
        communicator = CommSTDP({}, device="cpu")
        result = communicator.process(input_text, weights)
        print(f"Result: {result}")
    except ImportError:
        print("communicator_STDP not available. Weights loaded successfully.")
        print(f"Weights shape: {weights.shape if hasattr(weights, 'shape') else '[dict of tensors]'}")


if __name__ == "__main__":
    main()
'''


def _copy_code_files(output_dir: str) -> None:
    """Copy the project source files needed for inference into *output_dir*.

    Missing source files are skipped silently (best-effort copy).
    """
    code_files = [
        "communicator_STDP.py",
        "config.py",
        "model_Custm.py"
    ]
    # Project root is assumed to be one level above this module's package.
    repo_root = os.path.dirname(os.path.dirname(__file__))
    for file in code_files:
        src_path = os.path.join(repo_root, file)
        if os.path.exists(src_path):
            dst_path = os.path.join(output_dir, file)
            shutil.copy2(src_path, dst_path)
            logger.info(f"Copied {file} to {dst_path}")


def _write_inference_script(output_dir: str) -> None:
    """Write the standalone inference.py script into *output_dir*."""
    inference_path = os.path.join(output_dir, "inference.py")
    with open(inference_path, 'w') as f:
        f.write(_INFERENCE_SCRIPT.strip())
    logger.info(f"Created inference script: {inference_path}")


def _write_repo_readme(output_dir: str, converted_models: list) -> None:
    """Write the top-level README listing the converted models."""
    readme_file = os.path.join(output_dir, "README.md")
    with open(readme_file, 'w') as f:
        f.write("# STDP/SNN Trained Models\n\n")
        f.write("This repository contains STDP/SNN models converted from PyTorch checkpoints for use with Hugging Face's infrastructure.\n\n")
        f.write("## Models Included\n\n")
        for i, model in enumerate(converted_models):
            f.write(f"{i+1}. `{os.path.basename(model)}`\n")

        f.write("\n## Usage\n\n")
        f.write("```python\n")
        f.write("from transformers import AutoModel\n\n")
        f.write("# Load the model\n")
        f.write("model = AutoModel.from_pretrained('your-username/your-model-name')\n")
        f.write("```\n\n")
        f.write("Or use the included inference.py script:\n\n")
        f.write("```bash\npython inference.py --model ./stdp_model_epoch_15 --input 'Your input text here'\n```")


def prepare_for_hf_upload(
    checkpoint_paths: list,
    output_dir: str,
    config_path: Optional[str] = None,
    include_code: bool = True
) -> str:
    """
    Prepare multiple checkpoints for Hugging Face upload, optionally with code.

    Each checkpoint is converted into its own subdirectory of *output_dir*
    (named after the checkpoint file), then project code, an inference
    script and a top-level README are added.

    Args:
        checkpoint_paths: List of paths to checkpoint files.
        output_dir: Directory to save the prepared model repository.
        config_path: Optional path to a config.json file forwarded to
            each conversion.
        include_code: Whether to include project source and inference code.

    Returns:
        Path to the prepared directory (``output_dir``).
    """
    os.makedirs(output_dir, exist_ok=True)

    # Convert each checkpoint into its own subdirectory.
    converted_models = []
    for cp_path in checkpoint_paths:
        model_name = os.path.splitext(os.path.basename(cp_path))[0]
        model_dir = os.path.join(output_dir, model_name)
        converted_models.append(convert_stdp_checkpoint(cp_path, model_dir, config_path))

    if include_code:
        _copy_code_files(output_dir)
        _write_inference_script(output_dir)

    _write_repo_readme(output_dir, converted_models)

    logger.info(f"Prepared {len(converted_models)} models for HF upload in {output_dir}")
    return output_dir

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Convert PyTorch checkpoints to Hugging Face format")
    parser.add_argument("--checkpoints", nargs="+", required=True, help="Paths to checkpoint files")
    parser.add_argument("--output", type=str, default="hf_model", help="Output directory")
    parser.add_argument("--config", type=str, help="Path to config.json file")
    parser.add_argument("--no-code", action="store_true", help="Don't include inference code")
    
    args = parser.parse_args()
    
    prepare_for_hf_upload(
        args.checkpoints,
        args.output,
        args.config,
        not args.no_code
    )