"""
Train a new Wildnerve model with parameters loaded from config.json.
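
Example invocation (the file name shown is illustrative):
    python train.py --dataset data/train.json --specialization general --output ./checkpoints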
"""
import os
import sys
import torch
import logging
import argparse
from typing import Optional
# Import configuration
from config import app_config, get_model_architecture_params
# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
def train_model(
specialization: str,
dataset_path: str,
output_dir: str,
num_epochs: Optional[int] = None,
batch_size: Optional[int] = None,
learning_rate: Optional[float] = None,
device: Optional[str] = None
):
"""Train a model with parameters from config.json"""
# Get model architecture parameters from config.json
arch_params = get_model_architecture_params()
logger.info(f"Loaded architecture parameters from config: {arch_params}")
# Get training parameters from config.json
if hasattr(app_config, "TRAINING_CONFIG"):
training_config = app_config.TRAINING_CONFIG
num_epochs = num_epochs or getattr(training_config, "NUM_EPOCHS", 10)
learning_rate = learning_rate or getattr(training_config, "LEARNING_RATE", 1e-4)
elif hasattr(app_config, "TRANSFORMER_CONFIG"):
transformer_config = app_config.TRANSFORMER_CONFIG
num_epochs = num_epochs or getattr(transformer_config, "NUM_EPOCHS", 10)
learning_rate = learning_rate or getattr(transformer_config, "LEARNING_RATE", 1e-4)
# Get data loader parameters from config.json
if hasattr(app_config, "DATA_LOADER_CONFIG"):
data_loader_config = app_config.DATA_LOADER_CONFIG
batch_size = batch_size or getattr(data_loader_config, "BATCH_SIZE", 32)
    # Explicit arguments take precedence; anything still unset falls back to these defaults
num_epochs = num_epochs or 10
batch_size = batch_size or 32
learning_rate = learning_rate or 1e-4
# Create output directory
os.makedirs(output_dir, exist_ok=True)
    # Resolve the device: honor an explicit argument, otherwise prefer CUDA when available
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    device = torch.device(device)
logger.info(f"Using device: {device}")
try:
# Import necessary modules
from model_Custm import Wildnerve_tlm01
from transformers import AutoTokenizer
from torch.utils.data import DataLoader, Dataset
import json
        # Get the base model name from config (used to select the tokenizer); default to gpt2
        if hasattr(app_config, "TRANSFORMER_CONFIG"):
            model_name = getattr(app_config.TRANSFORMER_CONFIG, "MODEL_NAME", "gpt2")
        else:
            model_name = "gpt2"
# Initialize the tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_name)
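        # GPT-2-style tokenizers ship without a pad token; reuse the EOS token for padding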
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
# Load dataset
logger.info(f"Loading dataset from {dataset_path}")
        with open(dataset_path, 'r', encoding='utf-8') as f:
data = json.load(f)
# Create a simple dataset class
class TextDataset(Dataset):
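            """Tokenizes the full list of texts up front and serves fixed-length examples."""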
def __init__(self, texts, tokenizer, max_length):
self.encodings = tokenizer(texts, truncation=True, padding="max_length",
max_length=max_length, return_tensors="pt")
def __getitem__(self, idx):
item = {key: val[idx] for key, val in self.encodings.items()}
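                # Use the input ids themselves as labels (standard language-modeling setup)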
item["labels"] = item["input_ids"].clone()
return item
def __len__(self):
return len(self.encodings["input_ids"])
        # Extract texts; the dataset is expected to be a JSON list of objects, each with a "text" field
texts = [item["text"] for item in data]
# Create dataset and dataloader
train_dataset = TextDataset(texts, tokenizer, arch_params["max_seq_length"])
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
# Log key parameters
logger.info(f"Training with parameters:")
logger.info(f"- specialization: {specialization}")
logger.info(f"- model_name: {model_name}")
logger.info(f"- embedding_dim: {arch_params['embedding_dim']}")
logger.info(f"- hidden_dim: {arch_params['hidden_dim']}")
logger.info(f"- num_heads: {arch_params['num_heads']}")
logger.info(f"- num_layers: {arch_params['num_layers']}")
logger.info(f"- vocab_size: {arch_params['vocab_size']}")
logger.info(f"- num_epochs: {num_epochs}")
logger.info(f"- batch_size: {batch_size}")
logger.info(f"- learning_rate: {learning_rate}")
# Initialize the model with architecture parameters from config
model = Wildnerve_tlm01(
vocab_size=arch_params["vocab_size"],
specialization=specialization,
dataset_path=dataset_path,
model_name=model_name,
embedding_dim=arch_params["embedding_dim"],
num_heads=arch_params["num_heads"],
hidden_dim=arch_params["hidden_dim"],
num_layers=arch_params["num_layers"],
output_size=arch_params["vocab_size"],
dropout=arch_params.get("dropout", 0.1),
max_seq_length=arch_params["max_seq_length"],
tokenizer=tokenizer
)
# Move model to the device
model.to(device)
# Set up optimizer
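        # AdamW decouples weight decay from the gradient update; a common default for transformer training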
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
# Training loop
logger.info(f"Starting training for {num_epochs} epochs")
for epoch in range(num_epochs):
model.train()
total_loss = 0
for batch_idx, batch in enumerate(train_dataloader):
# Move batch to device
batch = {k: v.to(device) for k, v in batch.items()}
# Forward pass
outputs = model(batch["input_ids"],
attention_mask=batch.get("attention_mask"))
# Calculate loss
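                # NOTE: this assumes the model's logits are already aligned with the labels;
                # if it returns standard next-token logits, shift logits/labels by one position
                # and mask padding positions (e.g. via ignore_index) before computing the loss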
loss = torch.nn.functional.cross_entropy(
outputs.view(-1, outputs.size(-1)),
batch["labels"].view(-1)
)
# Backward pass
optimizer.zero_grad()
loss.backward()
optimizer.step()
# Track loss
total_loss += loss.item()
if (batch_idx + 1) % 10 == 0:
logger.info(f"Epoch {epoch+1}/{num_epochs}, Batch {batch_idx+1}/{len(train_dataloader)}, "
f"Loss: {loss.item():.4f}")
avg_loss = total_loss / len(train_dataloader)
logger.info(f"Epoch {epoch+1}/{num_epochs} completed. Average loss: {avg_loss:.4f}")
# Save checkpoint
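            # Each checkpoint bundles the model weights, optimizer state and the
            # architecture params needed to rebuild the model when resuming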
checkpoint_path = os.path.join(output_dir, f"model_epoch_{epoch+1}.bin")
torch.save({
"model_state_dict": model.state_dict(),
"optimizer_state_dict": optimizer.state_dict(),
"epoch": epoch,
"loss": avg_loss,
"config": {
"embedding_dim": arch_params["embedding_dim"],
"hidden_dim": arch_params["hidden_dim"],
"num_heads": arch_params["num_heads"],
"num_layers": arch_params["num_layers"],
"vocab_size": arch_params["vocab_size"]
}
}, checkpoint_path)
logger.info(f"Saved checkpoint to {checkpoint_path}")
# Save final model
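        # The final file keeps only the weights and architecture params (no optimizer state);
        # reload with model.load_state_dict(torch.load(path)["model_state_dict"])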
final_model_path = os.path.join(output_dir, f"{specialization}_final_model.bin")
torch.save({
"model_state_dict": model.state_dict(),
"config": {
"embedding_dim": arch_params["embedding_dim"],
"hidden_dim": arch_params["hidden_dim"],
"num_heads": arch_params["num_heads"],
"num_layers": arch_params["num_layers"],
"vocab_size": arch_params["vocab_size"]
}
}, final_model_path)
logger.info(f"Training completed. Final model saved to {final_model_path}")
return final_model_path
except Exception as e:
logger.error(f"Error during training: {e}", exc_info=True)
return None
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Train a Wildnerve model")
parser.add_argument("--specialization", type=str, default="general", help="Model specialization")
parser.add_argument("--dataset", type=str, required=True, help="Path to the dataset file")
parser.add_argument("--output", type=str, default="./checkpoints", help="Output directory")
parser.add_argument("--epochs", type=int, help="Number of training epochs (overrides config)")
parser.add_argument("--batch-size", type=int, help="Batch size (overrides config)")
parser.add_argument("--learning-rate", type=float, help="Learning rate (overrides config)")
parser.add_argument("--device", type=str, help="Device to use (cuda or cpu)")
    args = parser.parse_args()
    result = train_model(
        specialization=args.specialization,
        dataset_path=args.dataset,
        output_dir=args.output,
        num_epochs=args.epochs,
        batch_size=args.batch_size,
        learning_rate=args.learning_rate,
        device=args.device
    )
    # train_model returns None on failure, so surface that as a non-zero exit code
    if result is None:
        sys.exit(1)