""" Train a new Wildnerve model with parameters loaded from config.json. """ import os import sys import torch import logging import argparse from pathlib import Path from typing import Dict, Any, Optional, List, Tuple # Import configuration from config import app_config, get_model_architecture_params # Configure logging logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s') logger = logging.getLogger(__name__) def train_model( specialization: str, dataset_path: str, output_dir: str, num_epochs: Optional[int] = None, batch_size: Optional[int] = None, learning_rate: Optional[float] = None, device: Optional[str] = None ): """Train a model with parameters from config.json""" # Get model architecture parameters from config.json arch_params = get_model_architecture_params() logger.info(f"Loaded architecture parameters from config: {arch_params}") # Get training parameters from config.json if hasattr(app_config, "TRAINING_CONFIG"): training_config = app_config.TRAINING_CONFIG num_epochs = num_epochs or getattr(training_config, "NUM_EPOCHS", 10) learning_rate = learning_rate or getattr(training_config, "LEARNING_RATE", 1e-4) elif hasattr(app_config, "TRANSFORMER_CONFIG"): transformer_config = app_config.TRANSFORMER_CONFIG num_epochs = num_epochs or getattr(transformer_config, "NUM_EPOCHS", 10) learning_rate = learning_rate or getattr(transformer_config, "LEARNING_RATE", 1e-4) # Get data loader parameters from config.json if hasattr(app_config, "DATA_LOADER_CONFIG"): data_loader_config = app_config.DATA_LOADER_CONFIG batch_size = batch_size or getattr(data_loader_config, "BATCH_SIZE", 32) # Use command-line values as overrides, or fall back to defaults num_epochs = num_epochs or 10 batch_size = batch_size or 32 learning_rate = learning_rate or 1e-4 # Create output directory os.makedirs(output_dir, exist_ok=True) # Set device if device is None: device = torch.device("cuda" if torch.cuda.is_available() else "cpu") logger.info(f"Using device: {device}") try: # Import necessary modules from model_Custm import Wildnerve_tlm01 from transformers import AutoTokenizer from torch.utils.data import DataLoader, Dataset import json # Get model name from config model_name = getattr(app_config.TRANSFORMER_CONFIG, "MODEL_NAME", "gpt2") if hasattr(app_config, "TRANSFORMER_CONFIG") else "gpt2" # Initialize the tokenizer tokenizer = AutoTokenizer.from_pretrained(model_name) if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token # Load dataset logger.info(f"Loading dataset from {dataset_path}") with open(dataset_path, 'r') as f: data = json.load(f) # Create a simple dataset class class TextDataset(Dataset): def __init__(self, texts, tokenizer, max_length): self.encodings = tokenizer(texts, truncation=True, padding="max_length", max_length=max_length, return_tensors="pt") def __getitem__(self, idx): item = {key: val[idx] for key, val in self.encodings.items()} item["labels"] = item["input_ids"].clone() return item def __len__(self): return len(self.encodings["input_ids"]) # Extract texts from your dataset texts = [item["text"] for item in data] # Create dataset and dataloader train_dataset = TextDataset(texts, tokenizer, arch_params["max_seq_length"]) train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True) # Log key parameters logger.info(f"Training with parameters:") logger.info(f"- specialization: {specialization}") logger.info(f"- model_name: {model_name}") logger.info(f"- embedding_dim: {arch_params['embedding_dim']}") 
logger.info(f"- hidden_dim: {arch_params['hidden_dim']}") logger.info(f"- num_heads: {arch_params['num_heads']}") logger.info(f"- num_layers: {arch_params['num_layers']}") logger.info(f"- vocab_size: {arch_params['vocab_size']}") logger.info(f"- num_epochs: {num_epochs}") logger.info(f"- batch_size: {batch_size}") logger.info(f"- learning_rate: {learning_rate}") # Initialize the model with architecture parameters from config model = Wildnerve_tlm01( vocab_size=arch_params["vocab_size"], specialization=specialization, dataset_path=dataset_path, model_name=model_name, embedding_dim=arch_params["embedding_dim"], num_heads=arch_params["num_heads"], hidden_dim=arch_params["hidden_dim"], num_layers=arch_params["num_layers"], output_size=arch_params["vocab_size"], dropout=arch_params.get("dropout", 0.1), max_seq_length=arch_params["max_seq_length"], tokenizer=tokenizer ) # Move model to the device model.to(device) # Set up optimizer optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate) # Training loop logger.info(f"Starting training for {num_epochs} epochs") for epoch in range(num_epochs): model.train() total_loss = 0 for batch_idx, batch in enumerate(train_dataloader): # Move batch to device batch = {k: v.to(device) for k, v in batch.items()} # Forward pass outputs = model(batch["input_ids"], attention_mask=batch.get("attention_mask")) # Calculate loss loss = torch.nn.functional.cross_entropy( outputs.view(-1, outputs.size(-1)), batch["labels"].view(-1) ) # Backward pass optimizer.zero_grad() loss.backward() optimizer.step() # Track loss total_loss += loss.item() if (batch_idx + 1) % 10 == 0: logger.info(f"Epoch {epoch+1}/{num_epochs}, Batch {batch_idx+1}/{len(train_dataloader)}, " f"Loss: {loss.item():.4f}") avg_loss = total_loss / len(train_dataloader) logger.info(f"Epoch {epoch+1}/{num_epochs} completed. Average loss: {avg_loss:.4f}") # Save checkpoint checkpoint_path = os.path.join(output_dir, f"model_epoch_{epoch+1}.bin") torch.save({ "model_state_dict": model.state_dict(), "optimizer_state_dict": optimizer.state_dict(), "epoch": epoch, "loss": avg_loss, "config": { "embedding_dim": arch_params["embedding_dim"], "hidden_dim": arch_params["hidden_dim"], "num_heads": arch_params["num_heads"], "num_layers": arch_params["num_layers"], "vocab_size": arch_params["vocab_size"] } }, checkpoint_path) logger.info(f"Saved checkpoint to {checkpoint_path}") # Save final model final_model_path = os.path.join(output_dir, f"{specialization}_final_model.bin") torch.save({ "model_state_dict": model.state_dict(), "config": { "embedding_dim": arch_params["embedding_dim"], "hidden_dim": arch_params["hidden_dim"], "num_heads": arch_params["num_heads"], "num_layers": arch_params["num_layers"], "vocab_size": arch_params["vocab_size"] } }, final_model_path) logger.info(f"Training completed. 
        logger.info(f"Training completed. Final model saved to {final_model_path}")
        return final_model_path

    except Exception as e:
        logger.error(f"Error during training: {e}", exc_info=True)
        return None


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Train a Wildnerve model")
    parser.add_argument("--specialization", type=str, default="general", help="Model specialization")
    parser.add_argument("--dataset", type=str, required=True, help="Path to the dataset file")
    parser.add_argument("--output", type=str, default="./checkpoints", help="Output directory")
    parser.add_argument("--epochs", type=int, help="Number of training epochs (overrides config)")
    parser.add_argument("--batch-size", type=int, help="Batch size (overrides config)")
    parser.add_argument("--learning-rate", type=float, help="Learning rate (overrides config)")
    parser.add_argument("--device", type=str, help="Device to use (cuda or cpu)")

    args = parser.parse_args()

    train_model(
        specialization=args.specialization,
        dataset_path=args.dataset,
        output_dir=args.output,
        num_epochs=args.epochs,
        batch_size=args.batch_size,
        learning_rate=args.learning_rate,
        device=args.device,
    )
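
# Example invocation (file and directory names here are hypothetical; the
# dataset is assumed to be a JSON list of objects with a "text" field, which
# is what TextDataset above consumes):
#
#   python train.py --dataset data/train.json --output ./checkpoints \
#       --epochs 3 --batch-size 16 --learning-rate 5e-5
#
# Omitted flags fall back to config.json values, then to the hard-coded
# defaults (10 epochs, batch size 32, learning rate 1e-4).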