"""
Train a new Wildnerve model with parameters loaded from config.json.
"""

import os
import logging
import argparse
from typing import Optional

import torch

from config import app_config, get_model_architecture_params

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)


def train_model(
    specialization: str,
    dataset_path: str,
    output_dir: str,
    num_epochs: Optional[int] = None,
    batch_size: Optional[int] = None,
    learning_rate: Optional[float] = None,
    device: Optional[str] = None
):
    """Train a model with parameters from config.json."""
    arch_params = get_model_architecture_params()
    logger.info(f"Loaded architecture parameters from config: {arch_params}")

    # Hyperparameter precedence: explicit arguments first, then config values,
    # then the hard-coded defaults at the end of this block.
    if hasattr(app_config, "TRAINING_CONFIG"):
        training_config = app_config.TRAINING_CONFIG
        num_epochs = num_epochs or getattr(training_config, "NUM_EPOCHS", 10)
        learning_rate = learning_rate or getattr(training_config, "LEARNING_RATE", 1e-4)
    elif hasattr(app_config, "TRANSFORMER_CONFIG"):
        transformer_config = app_config.TRANSFORMER_CONFIG
        num_epochs = num_epochs or getattr(transformer_config, "NUM_EPOCHS", 10)
        learning_rate = learning_rate or getattr(transformer_config, "LEARNING_RATE", 1e-4)

    if hasattr(app_config, "DATA_LOADER_CONFIG"):
        data_loader_config = app_config.DATA_LOADER_CONFIG
        batch_size = batch_size or getattr(data_loader_config, "BATCH_SIZE", 32)

    # Fallbacks in case neither config block supplied a value.
    num_epochs = num_epochs or 10
    batch_size = batch_size or 32
    learning_rate = learning_rate or 1e-4

    os.makedirs(output_dir, exist_ok=True)

    # Accept either an explicit device string (e.g. "cuda") or auto-detect.
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    else:
        device = torch.device(device)
    logger.info(f"Using device: {device}")

    try:
        # Imports are deferred so that a missing dependency is reported through
        # the error handler below instead of crashing at module import time.
        from model_Custm import Wildnerve_tlm01
        from transformers import AutoTokenizer
        from torch.utils.data import DataLoader, Dataset
        import json

        model_name = (getattr(app_config.TRANSFORMER_CONFIG, "MODEL_NAME", "gpt2")
                      if hasattr(app_config, "TRANSFORMER_CONFIG") else "gpt2")

        tokenizer = AutoTokenizer.from_pretrained(model_name)
        if tokenizer.pad_token is None:
            # GPT-2-style tokenizers ship without a pad token; reuse EOS for padding.
            tokenizer.pad_token = tokenizer.eos_token

        logger.info(f"Loading dataset from {dataset_path}")
        with open(dataset_path, 'r') as f:
            data = json.load(f)
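
        # The dataset file is assumed to be a JSON array of objects, each with a
        # "text" field, e.g. (contents illustrative):
        #
        #   [{"text": "first training document ..."},
        #    {"text": "second training document ..."}]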

        class TextDataset(Dataset):
            def __init__(self, texts, tokenizer, max_length):
                self.encodings = tokenizer(texts, truncation=True, padding="max_length",
                                           max_length=max_length, return_tensors="pt")

            def __getitem__(self, idx):
                item = {key: val[idx] for key, val in self.encodings.items()}
                # Standard causal-LM convention: labels are a copy of the input ids.
                item["labels"] = item["input_ids"].clone()
                return item

            def __len__(self):
                return len(self.encodings["input_ids"])

        texts = [item["text"] for item in data]

        train_dataset = TextDataset(texts, tokenizer, arch_params["max_seq_length"])
        train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

        logger.info("Training with parameters:")
        logger.info(f"- specialization: {specialization}")
        logger.info(f"- model_name: {model_name}")
        logger.info(f"- embedding_dim: {arch_params['embedding_dim']}")
        logger.info(f"- hidden_dim: {arch_params['hidden_dim']}")
        logger.info(f"- num_heads: {arch_params['num_heads']}")
        logger.info(f"- num_layers: {arch_params['num_layers']}")
        logger.info(f"- vocab_size: {arch_params['vocab_size']}")
        logger.info(f"- num_epochs: {num_epochs}")
        logger.info(f"- batch_size: {batch_size}")
        logger.info(f"- learning_rate: {learning_rate}")

        model = Wildnerve_tlm01(
            vocab_size=arch_params["vocab_size"],
            specialization=specialization,
            dataset_path=dataset_path,
            model_name=model_name,
            embedding_dim=arch_params["embedding_dim"],
            num_heads=arch_params["num_heads"],
            hidden_dim=arch_params["hidden_dim"],
            num_layers=arch_params["num_layers"],
            output_size=arch_params["vocab_size"],
            dropout=arch_params.get("dropout", 0.1),
            max_seq_length=arch_params["max_seq_length"],
            tokenizer=tokenizer
        )
        model.to(device)

        optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

        logger.info(f"Starting training for {num_epochs} epochs")
        for epoch in range(num_epochs):
            model.train()
            total_loss = 0

            for batch_idx, batch in enumerate(train_dataloader):
                batch = {k: v.to(device) for k, v in batch.items()}

                outputs = model(batch["input_ids"],
                                attention_mask=batch.get("attention_mask"))

                # Flatten (batch, seq, vocab) logits and (batch, seq) labels for
                # token-level cross-entropy.
                loss = torch.nn.functional.cross_entropy(
                    outputs.view(-1, outputs.size(-1)),
                    batch["labels"].view(-1)
                )
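
                # Note: logits and labels are compared position-for-position here,
                # which assumes Wildnerve_tlm01 already emits next-token predictions
                # aligned with the inputs. If the model instead returns logits for
                # the current position, shift both by one before computing the loss.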

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                total_loss += loss.item()

                if (batch_idx + 1) % 10 == 0:
                    logger.info(f"Epoch {epoch+1}/{num_epochs}, Batch {batch_idx+1}/{len(train_dataloader)}, "
                                f"Loss: {loss.item():.4f}")

            avg_loss = total_loss / len(train_dataloader)
            logger.info(f"Epoch {epoch+1}/{num_epochs} completed. Average loss: {avg_loss:.4f}")

            # End-of-epoch checkpoint: weights, optimizer state, and the
            # architecture parameters needed to rebuild the model.
            checkpoint_path = os.path.join(output_dir, f"model_epoch_{epoch+1}.bin")
            torch.save({
                "model_state_dict": model.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
                "epoch": epoch,
                "loss": avg_loss,
                "config": {
                    "embedding_dim": arch_params["embedding_dim"],
                    "hidden_dim": arch_params["hidden_dim"],
                    "num_heads": arch_params["num_heads"],
                    "num_layers": arch_params["num_layers"],
                    "vocab_size": arch_params["vocab_size"]
                }
            }, checkpoint_path)
            logger.info(f"Saved checkpoint to {checkpoint_path}")

        final_model_path = os.path.join(output_dir, f"{specialization}_final_model.bin")
        torch.save({
            "model_state_dict": model.state_dict(),
            "config": {
                "embedding_dim": arch_params["embedding_dim"],
                "hidden_dim": arch_params["hidden_dim"],
                "num_heads": arch_params["num_heads"],
                "num_layers": arch_params["num_layers"],
                "vocab_size": arch_params["vocab_size"]
            }
        }, final_model_path)
        logger.info(f"Training completed. Final model saved to {final_model_path}")

        return final_model_path

    except Exception as e:
        logger.error(f"Error during training: {e}", exc_info=True)
        return None


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Train a Wildnerve model")
    parser.add_argument("--specialization", type=str, default="general", help="Model specialization")
    parser.add_argument("--dataset", type=str, required=True, help="Path to the dataset file")
    parser.add_argument("--output", type=str, default="./checkpoints", help="Output directory")
    parser.add_argument("--epochs", type=int, help="Number of training epochs (overrides config)")
    parser.add_argument("--batch-size", type=int, help="Batch size (overrides config)")
    parser.add_argument("--learning-rate", type=float, help="Learning rate (overrides config)")
    parser.add_argument("--device", type=str, help="Device to use (cuda or cpu)")

    args = parser.parse_args()

    train_model(
        specialization=args.specialization,
        dataset_path=args.dataset,
        output_dir=args.output,
        num_epochs=args.epochs,
        batch_size=args.batch_size,
        learning_rate=args.learning_rate,
        device=args.device
    )