"""
How to Use prepare_hf_transformer_training.py Safely
Here's a secure way to prepare and upload your model to Hugging Face:
Step 1: Navigate to Your Project Directory
cd C:/Users/User/OneDrive/Documents/tlm
Step 2: Set Up Authentication for Hugging Face
huggingface-cli login
Step 3: Run the Preparation Script
python -m utils.prepare_hf_transformer_training --stdp_checkpoint "checkpoints/stdp_model_epoch_20.pt" --output_dir "C:/Users/User/OneDrive/Documents/tlm/Wildnerve-tlm_HF/hf_upload"
Step 4: Initialize Git and Upload to Hugging Face
cd hf_upload
git init
git add .
git commit -m "Add TLM model with STDP checkpoint"
git remote add origin https://huggingface.co/YOUR-USERNAME/Wildnerve-tlm01
git pull origin main --allow-unrelated-histories
git push origin main
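
Optional: instead of the git steps above, the prepared folder can be uploaded with the
huggingface_hub Python API. A minimal sketch, assuming huggingface_hub is installed and
you are already logged in (the repo id is a placeholder):

    from huggingface_hub import HfApi
    api = HfApi()
    api.upload_folder(folder_path="hf_upload", repo_id="YOUR-USERNAME/Wildnerve-tlm01", repo_type="model")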
"""
import os
import shutil
import logging
import argparse
from pathlib import Path
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")


def prepare_training_package(
    stdp_checkpoint_path,
    output_dir="hf_transformer_training",
    include_all=False
):
    """Prepare a clean training package for Hugging Face with the STDP checkpoint.

    Args:
        stdp_checkpoint_path: Path to the STDP checkpoint file.
        output_dir: Directory in which to create the package.
        include_all: Whether to include all supporting files (utils, analyzers, etc.).
    """
    os.makedirs(output_dir, exist_ok=True)

    # Core files needed for transformer training
    essential_files = [
        # Core components
        "app.py",
        "main.py",
        "config.json",
        "config.py",
        "inference.py",
        # Model implementations
        "model_List.py",
        "model_Custm.py",
        "model_PrTr.py",
        "model_Combn.py",
        "model_manager.py",
        # Communication components
        "communicator.py",
        "communicator_STDP.py",
        # Data and training
        "tokenizer.py",
        "trainer.py",
        "dataloader.py",
        "dataset.py",
        "data",
        # STDP specific components
        "STDP_Communicator/datasets_stdp.py",
        "STDP_Communicator/train_stdp.py",
        # Utils (only essential ones)
        "utils/convert_checkpoints.py",
    ]
    # Additional support files (only included if include_all=True)
    additional_files = [
        "utils/transformer_utils.py",
        "utils/smartHybridAttention.py",
        "utils/sentence_transformer_utils.py",
        "utils/output_formatter.py",
        "emergency_monitor.py",
    ]

    # Choose which files to copy
    required_files = essential_files + (additional_files if include_all else [])

    logger.info(f"Starting package preparation in {output_dir}")
    logger.info(f"Including {'all' if include_all else 'only essential'} files")

    # Track successful and failed copies
    copied_files = []
    missing_files = []
    # Copy files
    for file_path in required_files:
        src = Path(file_path)
        if not src.exists():
            logger.warning(f"File {file_path} not found, skipping")
            missing_files.append(str(src))
            continue

        # Create destination directories
        dst = Path(output_dir) / src
        os.makedirs(dst.parent, exist_ok=True)

        # Copy file or directory
        try:
            if src.is_dir():
                shutil.copytree(src, dst, dirs_exist_ok=True)
            else:
                shutil.copy2(src, dst)
            copied_files.append(str(src))
            logger.info(f"Copied {src} to {dst}")
        except Exception as e:
            logger.error(f"Error copying {src}: {e}")
    # Copy STDP checkpoint
    if os.path.exists(stdp_checkpoint_path):
        stdp_dst = Path(output_dir) / "checkpoints" / Path(stdp_checkpoint_path).name
        os.makedirs(stdp_dst.parent, exist_ok=True)
        try:
            shutil.copy2(stdp_checkpoint_path, stdp_dst)
            logger.info(f"Copied STDP checkpoint to {stdp_dst}")
            copied_files.append(str(stdp_checkpoint_path))
        except Exception as e:
            logger.error(f"Error copying checkpoint: {e}")
            missing_files.append(str(stdp_checkpoint_path))
    else:
        logger.warning(f"STDP checkpoint not found at {stdp_checkpoint_path}")
        missing_files.append(str(stdp_checkpoint_path))
    # Create the Hugging Face training script inside the package
    create_transformer_training_script(output_dir, stdp_checkpoint_path)

    # Create requirements.txt if not already copied
    if "requirements.txt" not in copied_files:
        create_requirements(output_dir)
        copied_files.append("requirements.txt (generated)")

    # Create README.md if not already copied
    if "README.md" not in copied_files:
        create_readme(output_dir, stdp_checkpoint_path)
        copied_files.append("README.md (generated)")

    # Summarize what was done
    logger.info(f"Package prepared in {output_dir}")
    logger.info(f"Copied {len(copied_files)} files: {', '.join(copied_files[:5])}...")
    if missing_files:
        logger.warning(f"Missing {len(missing_files)} files: {', '.join(missing_files)}")

    return output_dir
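
# Example (paths are illustrative): the package can also be built programmatically
# instead of via the CLI entry point at the bottom of this file:
#   prepare_training_package("checkpoints/stdp_model_epoch_20.pt", output_dir="hf_upload", include_all=True)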


def create_transformer_training_script(output_dir, stdp_checkpoint_path):
    """Create a script to load the STDP checkpoint and train the transformer."""
    # The embedded script uses single-quoted docstrings so it does not terminate
    # the surrounding triple-quoted template string.
    script = """
import os
import torch
import logging
from config import load_config, app_config
from tokenizer import TokenizerWrapper
from model_manager import ModelManager
from dataloader import prepare_data_loaders
from trainer import Trainer, EarlyStopping

# Configure logging
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)


def train_transformer(stdp_checkpoint_path):
    '''Train the transformer component after loading STDP weights.'''
    logger.info(f"Starting transformer training with STDP checkpoint: {stdp_checkpoint_path}")

    # Initialize components
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    logger.info(f"Using device: {device}")

    # Create tokenizer
    tokenizer = TokenizerWrapper()

    # Get model manager
    model_manager = ModelManager()

    # Get specialization
    specialization = app_config.TRANSFORMER_CONFIG.specialization

    # Load STDP weights
    if os.path.exists(stdp_checkpoint_path):
        try:
            stdp_checkpoint = torch.load(stdp_checkpoint_path, map_location=device)
            logger.info(f"Loaded STDP checkpoint from {stdp_checkpoint_path}")
            # Now integrate STDP weights with the transformer model if needed;
            # this depends on your specific architecture.
        except Exception as e:
            logger.error(f"Error loading STDP checkpoint: {e}")
    else:
        logger.warning(f"STDP checkpoint not found at {stdp_checkpoint_path}")

    # Get model and move to device
    model = model_manager.get_model(specialization)
    model.to(device)

    # Get data loaders
    data_path = app_config.DATASET_PATHS.get(specialization)
    if not data_path or not os.path.exists(data_path):
        # Use a default dataset path
        data_path = next(iter(app_config.DATASET_PATHS.values()))
        logger.warning(f"Dataset for {specialization} not found, using {data_path}")

    train_loader, val_loader = prepare_data_loaders(
        data_path,
        tokenizer,
        batch_size=app_config.TRANSFORMER_CONFIG.BATCH_SIZE
    )

    # Set up checkpoint directory
    checkpoint_dir = os.path.join("checkpoints", "transformer")
    os.makedirs(checkpoint_dir, exist_ok=True)

    # Set up early stopping
    early_stopping = EarlyStopping(
        patience=app_config.TRAINING_CONFIG.PATIENCE,
        delta=app_config.TRAINING_CONFIG.DELTA,
        verbose=True,
        path=os.path.join(checkpoint_dir, "best_model.pt")
    )

    # Create trainer
    trainer = Trainer(
        model=model,
        tokenizer=tokenizer,
        train_dataloader=train_loader,
        val_dataloader=val_loader,
        device=device,
        early_stopping=early_stopping,
        checkpoint_dir=checkpoint_dir,
        total_epochs=app_config.TRAINING_CONFIG.TRANSFORMER_NUM_EPOCHS
    )

    # Train the model
    logger.info("Starting transformer training...")
    trainer.train()

    # Save final model
    final_model_path = os.path.join(checkpoint_dir, "final_model.pt")
    torch.save({
        'model_state_dict': model.state_dict(),
        'config': {
            'transformer_epochs': app_config.TRAINING_CONFIG.TRANSFORMER_NUM_EPOCHS,
            'stdp_epochs': 20,  # Assuming the STDP checkpoint is from epoch 20
            'specialization': specialization
        }
    }, final_model_path)
    logger.info(f"Final model saved to {final_model_path}")

    return final_model_path


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Train transformer after STDP")
    parser.add_argument("--stdp_checkpoint", type=str, default="checkpoints/stdp_model_epoch_20.pt",
                        help="Path to pre-trained STDP checkpoint")
    args = parser.parse_args()

    # Train transformer
    train_transformer(args.stdp_checkpoint)
"""
    script_path = os.path.join(output_dir, "train_transformer_hf.py")
    with open(script_path, "w") as f:
        f.write(script.strip())
    logger.info(f"Created training script at {script_path}")
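
# The generated script is intended to be run from inside the prepared package directory,
# for example (checkpoint path is illustrative):
#   python train_transformer_hf.py --stdp_checkpoint checkpoints/stdp_model_epoch_20.pt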


def create_requirements(output_dir):
    """Create requirements.txt file with all necessary dependencies."""
    requirements = [
        "torch>=2.0.0",
        "transformers>=4.30.0",
        "datasets>=2.12.0",
        "pydantic>=2.0.0",
        "sentence-transformers>=2.2.2",
        "scikit-learn>=1.2.2",
        "numpy>=1.24.0",
        "pandas>=2.0.0",
        "tqdm>=4.65.0",
        "matplotlib>=3.7.1",
        "snntorch>=0.7.0"
    ]
    with open(os.path.join(output_dir, "requirements.txt"), "w") as f:
        f.write("\n".join(requirements))
    logger.info("Created requirements.txt")


def create_readme(output_dir, stdp_checkpoint_path):
    """Create README with model information and usage instructions."""
    readme = f"""# Wildnerve-tlm01: Transformer Language Model with STDP

This repository contains the Wildnerve-tlm01 model, a transformer-based language model enhanced with
STDP (Spike-Timing-Dependent Plasticity) for improved learning capabilities.

## Pre-trained STDP Checkpoint

The STDP component was trained for 20 epochs and saved in: `{os.path.basename(stdp_checkpoint_path)}`

## Model Architecture

Wildnerve-tlm01 combines:
- Transformer architecture for language understanding
- Spiking Neural Network (SNN) with STDP for biological learning
- Smart Hybrid Attention for efficient processing

## Usage
"""
    with open(os.path.join(output_dir, "README.md"), "w") as f:
        f.write(readme)
    logger.info("Created README.md")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Prepare Hugging Face training package")
    parser.add_argument("--stdp_checkpoint", type=str, default="checkpoints/stdp_model_epoch_20.pt",
                        help="Path to pre-trained STDP checkpoint")
    parser.add_argument("--output_dir", type=str, default="hf_upload",
                        help="Output directory for training package")
    parser.add_argument("--include_all", action="store_true",
                        help="Include additional supporting files")
    args = parser.parse_args()

    prepare_training_package(args.stdp_checkpoint, args.output_dir, args.include_all)