|
|
import os |
|
|
import argparse |
|
|
import json |
|
|
import numpy as np |
|
|
from tqdm import tqdm |
|
|
import torch |
|
|
import torch.nn as nn |
|
|
from torch.utils.data import DataLoader, random_split |
|
|
from torch.cuda.amp import GradScaler, autocast |
|
|
from transformers import GPT2TokenizerFast, AdamW, get_linear_schedule_with_warmup |
|
|
from datasets import load_dataset |
|
|
from transformers import logging as hf_logging |
|
|
|
|
|
|
|
|
# Silence Hugging Face warnings (e.g. tokenizer/model load chatter) so the
# training progress bars stay readable.
hf_logging.set_verbosity_error()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class Config:
    """Centralized configuration for training.

    Groups model-architecture, optimization, checkpointing, and hardware
    settings in one mutable object that can be round-tripped through JSON
    via ``save`` / ``from_file``.
    """

    def __init__(self):
        # --- Model architecture ---
        self.vocab_size = 50257          # GPT-2 vocabulary size
        self.d_model = 512
        self.nhead = 8
        self.num_layers = 6
        self.dim_feedforward = 2048
        self.dropout = 0.1

        # --- Optimization ---
        self.batch_size = 32
        self.num_epochs = 3
        self.learning_rate = 5e-5
        self.weight_decay = 0.01
        self.warmup_steps = 0.1          # fraction of total steps used for LR warmup
        self.max_seq_length = 512
        self.gradient_accumulation_steps = 1
        self.max_grad_norm = 1.0
        self.seed = 42

        # --- Checkpointing ---
        self.output_dir = "./checkpoints"
        self.model_save_prefix = "reasoning_model"

        # --- Hardware ---
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.fp16 = torch.cuda.is_available()

    def save(self, path):
        """Save the configuration to *path* as JSON.

        Bug fix: ``torch.device`` is not JSON-serializable, so ``json.dump``
        on the raw ``__dict__`` raised ``TypeError``. The device is stored as
        its string form (e.g. "cuda"); ``from_file`` converts it back.
        Also tolerates a bare filename (empty dirname), which previously made
        ``os.makedirs("")`` raise.
        """
        directory = os.path.dirname(path)
        if directory:
            os.makedirs(directory, exist_ok=True)
        serializable = {
            key: (str(value) if isinstance(value, torch.device) else value)
            for key, value in self.__dict__.items()
        }
        with open(path, 'w') as f:
            json.dump(serializable, f, indent=2)

    @classmethod
    def from_file(cls, path):
        """Load a configuration from a JSON file produced by ``save``.

        Starts from the defaults, then overlays the stored values, so files
        written by older versions with fewer keys still load cleanly.
        """
        config = cls()
        with open(path, 'r') as f:
            config.__dict__.update(json.load(f))
        # Restore the torch.device object from its serialized string form.
        if isinstance(config.device, str):
            config.device = torch.device(config.device)
        return config
|
|
|
|
|
def load_and_preprocess_data(config):
    """Build train/validation DataLoaders over the reasoning dataset.

    Downloads the dataset and the GPT-2 tokenizer, tokenizes each
    instruction/answer pair into one fixed-length sequence, performs a
    deterministic 90/10 train/validation split, and wraps both splits in
    DataLoaders.

    Returns:
        Tuple of (train_loader, val_loader, tokenizer).
    """
    raw_dataset = load_dataset("ag2428/reasoningDataV4", split="train")

    tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
    # GPT-2 ships without a pad token; reuse EOS for padding.
    tokenizer.pad_token = tokenizer.eos_token

    def tokenize_function(examples):
        # Concatenate instruction and answer into a single training sequence.
        joined = [
            f"{inst}\n{ans}"
            for inst, ans in zip(examples["instruction"], examples["answer"])
        ]
        encoded = tokenizer(
            joined,
            max_length=config.max_seq_length,
            truncation=True,
            padding="max_length",
            return_tensors="pt"
        )
        # Standard causal-LM setup: labels mirror the inputs.
        # NOTE(review): padded positions keep the pad/EOS id here — confirm the
        # model's loss ignores them, otherwise padding contributes to the loss.
        encoded["labels"] = encoded["input_ids"].clone()
        return encoded

    tokenized = raw_dataset.map(
        tokenize_function,
        batched=True,
        remove_columns=raw_dataset.column_names,
        desc="Tokenizing dataset"
    )

    # Hold out 10% for validation; the seed makes the split reproducible.
    split = tokenized.train_test_split(test_size=0.1, seed=config.seed)
    parts = {"train": split["train"], "val": split["test"]}
    for part in parts.values():
        part.set_format(type='torch', columns=['input_ids', 'attention_mask', 'labels'])

    def make_loader(dataset, shuffle):
        # Shared DataLoader settings for both splits.
        return DataLoader(
            dataset,
            batch_size=config.batch_size,
            shuffle=shuffle,
            num_workers=4,
            pin_memory=True
        )

    train_loader = make_loader(parts["train"], True)
    val_loader = make_loader(parts["val"], False)

    return train_loader, val_loader, tokenizer
|
|
|
|
|
def train_epoch(model, train_loader, optimizer, scheduler, scaler, config, epoch):
    """Train the model for one epoch and return the mean training loss.

    Args:
        model: Model whose forward returns an object exposing ``.loss``.
        train_loader: DataLoader yielding dicts of 'input_ids',
            'attention_mask' and 'labels' tensors.
        optimizer: Optimizer stepped once per accumulation window.
        scheduler: LR scheduler stepped together with the optimizer.
        scaler: ``GradScaler`` used when ``config.fp16`` is True.
        config: Needs device, fp16, gradient_accumulation_steps,
            max_grad_norm attributes.
        epoch: Zero-based epoch index (used only for the progress-bar label).

    Returns:
        Average un-scaled loss over all batches in the epoch.
    """
    model.train()
    total_loss = 0.0
    num_batches = len(train_loader)
    progress_bar = tqdm(train_loader, desc=f"Epoch {epoch + 1}")

    for step, batch in enumerate(progress_bar):
        input_ids = batch['input_ids'].to(config.device)
        attention_mask = batch['attention_mask'].to(config.device)
        labels = batch['labels'].to(config.device)

        with autocast(enabled=config.fp16):
            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                labels=labels
            )
            # Scale down so gradients accumulated over several batches average
            # to one "virtual" large-batch gradient.
            loss = outputs.loss / config.gradient_accumulation_steps

        if config.fp16:
            scaler.scale(loss).backward()
        else:
            loss.backward()

        # Step on every accumulation boundary, and ALSO on the final batch so
        # a trailing partial window is not silently dropped (bug fix: leftover
        # gradients previously never produced an optimizer step when
        # num_batches % gradient_accumulation_steps != 0).
        at_boundary = (step + 1) % config.gradient_accumulation_steps == 0
        if at_boundary or (step + 1) == num_batches:
            if config.fp16:
                # Un-scale before clipping so the norm threshold is in true units.
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), config.max_grad_norm)
                scaler.step(optimizer)
                scaler.update()
            else:
                torch.nn.utils.clip_grad_norm_(model.parameters(), config.max_grad_norm)
                optimizer.step()

            scheduler.step()
            optimizer.zero_grad()

        # Undo the accumulation scaling when reporting the running loss.
        total_loss += loss.item() * config.gradient_accumulation_steps

        progress_bar.set_postfix({
            'loss': f"{total_loss / (step + 1):.4f}",
            'lr': f"{scheduler.get_last_lr()[0]:.2e}"
        })

    return total_loss / len(train_loader)
|
|
|
|
|
def evaluate(model, val_loader, config):
    """Evaluate the model on the validation set.

    Args:
        model: Model whose forward returns an object exposing ``.loss``.
        val_loader: DataLoader yielding dicts of 'input_ids',
            'attention_mask' and 'labels' tensors.
        config: Needs device and fp16 attributes.

    Returns:
        Mean validation loss; 0.0 for an empty loader.
    """
    model.eval()
    total_loss = 0.0

    # no_grad: evaluation needs no autograd graph, saving memory and time.
    with torch.no_grad():
        for batch in tqdm(val_loader, desc="Evaluating"):
            input_ids = batch['input_ids'].to(config.device)
            attention_mask = batch['attention_mask'].to(config.device)
            labels = batch['labels'].to(config.device)

            with autocast(enabled=config.fp16):
                outputs = model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    labels=labels
                )
                loss = outputs.loss

            total_loss += loss.item()

    # Bug fix: guard against an empty validation loader, which previously
    # raised ZeroDivisionError.
    return total_loss / max(len(val_loader), 1)
|
|
|
|
|
def save_checkpoint(model, optimizer, scheduler, epoch, config, is_best=False):
    """Save a training checkpoint under ``config.output_dir``.

    Args:
        model: Module whose ``state_dict()`` is persisted.
        optimizer: Optimizer whose ``state_dict()`` is persisted.
        scheduler: Scheduler whose ``state_dict()`` is persisted.
        epoch: Zero-based epoch index embedded in the filename.
        config: Needs ``output_dir`` and ``model_save_prefix``; its full
            ``__dict__`` is stored alongside the states for reproducibility.
        is_best: When True, write to the fixed "<prefix>_best.pt" path
            (overwriting any previous best) instead of a per-epoch file.
    """
    os.makedirs(config.output_dir, exist_ok=True)

    checkpoint = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
        'config': config.__dict__,
    }

    if is_best:
        filename = os.path.join(config.output_dir, f"{config.model_save_prefix}_best.pt")
    else:
        filename = os.path.join(config.output_dir, f"{config.model_save_prefix}_epoch_{epoch}.pt")

    torch.save(checkpoint, filename)
    # Bug fix: the message previously printed the literal "(unknown)" instead
    # of the actual checkpoint path.
    print(f"Checkpoint saved to {filename}")
|
|
|
|
|
def main():
    """CLI entry point: parse arguments, build the data/model/optimizer stack,
    and run the train / evaluate / checkpoint loop for ``config.num_epochs``.
    """
    parser = argparse.ArgumentParser(description="Train a reasoning model")
    parser.add_argument('--config', type=str, default=None, help="Path to config file")
    parser.add_argument('--output_dir', type=str, default=None, help="Output directory for checkpoints")
    parser.add_argument('--batch_size', type=int, default=None, help="Batch size")
    parser.add_argument('--num_epochs', type=int, default=None, help="Number of epochs")
    parser.add_argument('--learning_rate', type=float, default=None, help="Learning rate")
    parser.add_argument('--fp16', action='store_true', help="Use mixed precision training")
    args = parser.parse_args()

    # Load configuration from a file when provided, otherwise use defaults.
    if args.config:
        config = Config.from_file(args.config)
    else:
        config = Config()

    # Command-line flags override whatever came from the config file.
    # NOTE(review): these truthiness checks skip falsy values such as 0 —
    # harmless for these options, where 0/empty would be invalid anyway.
    if args.output_dir:
        config.output_dir = args.output_dir
    if args.batch_size:
        config.batch_size = args.batch_size
    if args.num_epochs:
        config.num_epochs = args.num_epochs
    if args.learning_rate:
        config.learning_rate = args.learning_rate
    if args.fp16:
        config.fp16 = True

    # Seed all RNGs for reproducibility (CPU, NumPy, and every CUDA device).
    torch.manual_seed(config.seed)
    np.random.seed(config.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(config.seed)

    os.makedirs(config.output_dir, exist_ok=True)

    # Persist the effective configuration next to the checkpoints.
    config.save(os.path.join(config.output_dir, "config.json"))

    print("Loading and preprocessing data...")
    train_loader, val_loader, tokenizer = load_and_preprocess_data(config)

    print("Initializing model...")

    # TODO: implement a real transformer language model here (using
    # config.d_model, config.nhead, config.num_layers, config.dim_feedforward,
    # config.dropout, config.vocab_size) and replace PlaceholderModel below.
    class PlaceholderModel(nn.Module):
        # Stand-in model: holds a loss function but raises on forward, so the
        # script fails loudly until a real model is plugged in.
        def __init__(self):
            super().__init__()
            # ignore_index uses the pad token (== EOS for GPT-2, set in
            # load_and_preprocess_data) so padded positions would not
            # contribute to the loss.
            self.loss_fn = nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id)

        def forward(self, input_ids, attention_mask, labels=None):
            raise NotImplementedError(
                "Please implement your transformer model and replace this placeholder. "
                "See the TODO comment in the code for more details."
            )

    model = PlaceholderModel()
    model = model.to(config.device)

    # Apply weight decay to all parameters except biases and LayerNorm weights
    # — the standard transformer fine-tuning recipe.
    no_decay = ['bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [
        {
            'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
            'weight_decay': config.weight_decay,
        },
        {
            'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
            'weight_decay': 0.0,
        }
    ]

    optimizer = AdamW(optimizer_grouped_parameters, lr=config.learning_rate)

    # One optimizer step per accumulation window; config.warmup_steps is a
    # FRACTION of total steps (0.1 by default), converted to a step count here.
    total_steps = len(train_loader) * config.num_epochs // config.gradient_accumulation_steps
    warmup_steps = int(total_steps * config.warmup_steps)

    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=warmup_steps,
        num_training_steps=total_steps
    )

    # GradScaler is a no-op when fp16 is disabled, so it is safe to construct
    # unconditionally and pass through to train_epoch.
    scaler = GradScaler(enabled=config.fp16)

    print("Starting training...")
    best_val_loss = float('inf')

    for epoch in range(config.num_epochs):
        train_loss = train_epoch(model, train_loader, optimizer, scheduler, scaler, config, epoch)

        val_loss = evaluate(model, val_loader, config)

        print(f"Epoch {epoch + 1}/{config.num_epochs}:")
        print(f"  Train loss: {train_loss:.4f}")
        print(f"  Val loss: {val_loss:.4f}")

        # Always keep a per-epoch checkpoint...
        save_checkpoint(model, optimizer, scheduler, epoch, config)

        # ...and additionally track the best model by validation loss.
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            save_checkpoint(model, optimizer, scheduler, epoch, config, is_best=True)

    print("Training complete!")
|
|
|
|
|
# Standard script guard: run training only when executed directly, not on import.
if __name__ == "__main__":
    main()