|
|
import os |
|
|
import re |
|
|
import torch |
|
|
import math |
|
|
from transformers import GPT2Tokenizer |
|
|
from datasets import load_dataset |
|
|
import numpy as np |
|
|
import wandb |
|
|
|
|
|
from model import minGRULM |
|
|
from util import generate_text, generate_name |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# HuggingFace dataset to train on.
dataset_path = 'flpelerin/tinystories-100k'

# Randomly generated run/checkpoint name (see util.generate_name).
model_name = generate_name()

# Core training hyperparameters.
num_epochs = 1
batch_size = 4
seq_length = 256
learning_rate = 1e-4

# Sampling config: prompt length (tokens) and number of tokens to generate.
input_len = 50
num_predict = 250

# Periodic actions, all measured in optimizer steps.
infer_every = 200        # generate a text sample every N steps
reset_state_every = 16   # drop the carried hidden state every N steps
validate_every = 200     # run validation every N steps; also sets the
                         # held-out fraction to 1/validate_every (see split below)
save_every = 500         # write a checkpoint every N steps
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Prefer the GPU when one is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

print(f"Total context size is {batch_size * seq_length} tokens")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# GPT-2 BPE tokenizer. GPT-2 ships without a pad token, so reuse EOS for
# padding (required by padding='max_length' in process_function).
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
tokenizer.pad_token = tokenizer.eos_token

vocab_size = tokenizer.vocab_size
print(f"Tokenizer has {vocab_size} unique tokens")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Load the raw dataset (used below as a 'train' split with a 'text' column).
dataset = load_dataset(dataset_path)
|
|
|
|
|
def process_function(examples):
    """Tokenize a batch of examples.

    Each text is padded/truncated to one full training context of
    seq_length * batch_size token ids.
    """
    context_tokens = seq_length * batch_size
    return tokenizer(
        examples['text'],
        truncation=True,
        padding='max_length',
        max_length=context_tokens,
    )
|
|
|
|
|
# Tokenize the whole dataset up front; each row becomes one full training
# context of batch_size * seq_length token ids.
tokenized_datasets = dataset.map(process_function, batched=True)
print(f"Dataset has {tokenized_datasets['train'].num_rows} rows of {batch_size} times {seq_length} tokens")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Hold out a small fraction of the data for periodic validation; the
# fraction is tied to how often validation runs (1/validate_every).
held_out_fraction = 1 / validate_every
split_dataset = tokenized_datasets['train'].train_test_split(test_size=held_out_fraction)

train_dataset = split_dataset['train']
valid_dataset = split_dataset['test']

print(f"Training set size: {len(train_dataset)}")
print(f"Validation set size: {len(valid_dataset)}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Instantiate the language model (architecture defined in model.py).
model = minGRULM(
    vocab_size = vocab_size,
    d_model = 384,
    d_inner = 768,
    n_layers = 6
)

model.to(device)

# Total parameter count, for logging only.
parameters_count = sum(p.numel() for p in model.parameters())
print(f"Model has {parameters_count:,} parameters")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Stable path that always points at the most recent checkpoint
# (updated via update_symlink after each save, read at startup to resume).
symlink_path = 'pytorch_model.bin'
|
|
|
|
|
def update_symlink(target_path, symlink_path):
    """
    Creates or updates a symbolic link pointing to the target path.

    The link is first created under a temporary name and then moved into
    place with os.replace(), which atomically overwrites any existing
    file or link on POSIX. (The previous remove-then-create sequence had
    a window during which the symlink did not exist, so a concurrent
    reader could miss the checkpoint entirely.)

    Args:
        target_path (str): The file path the symlink should point to.
        symlink_path (str): The symlink's path.
    """
    tmp_path = symlink_path + '.tmp'
    try:
        # Clear any leftover temp link from a previously failed attempt.
        # lexists() also detects a broken symlink, which exists() misses.
        if os.path.lexists(tmp_path):
            os.remove(tmp_path)
        os.symlink(target_path, tmp_path)
        # Atomic swap into the final name.
        os.replace(tmp_path, symlink_path)
        print(f"Updated symlink: {symlink_path} -> {target_path}")
    except OSError as e:
        # Best-effort: a failed symlink update must not abort training.
        print(f"Warning: Failed to create symlink {symlink_path} -> {target_path}. Error: {e}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Resume from the latest checkpoint (via the symlink) if one exists;
# otherwise train from randomly initialized weights.
if os.path.exists(symlink_path):
    try:
        model.load_state_dict(torch.load(symlink_path, map_location=device))
        print(f"Loaded model weights from {symlink_path}")

    except Exception as e:
        # Fall through to fresh training on any load failure
        # (e.g. incompatible weights or a corrupt file).
        print(f"Error loading model from {symlink_path}: {e}")
        print("Starting training from scratch.")
else:
    print("No checkpoint found. Starting training from scratch.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# SECURITY NOTE(review): this W&B API key is hard-coded in source and must
# be considered leaked -- rotate it and load it from the environment
# (e.g. the WANDB_API_KEY variable) instead of committing it.
wandb.login(key="860f8753998c6e6dc356914de07e8855aa2f9642")

# Start a W&B run and record every knob that defines this experiment so
# runs are reproducible/comparable from the dashboard alone.
wandb.init(
    project="minGRU-Training",
    name=model_name,
    config={
        "dataset_path": dataset_path,
        "num_epochs": num_epochs,
        "batch_size": batch_size,
        "seq_length": seq_length,
        "learning_rate": learning_rate,
        "input_len": input_len,
        "num_predict": num_predict,
        "infer_every": infer_every,
        "reset_state_every": reset_state_every,
        "validate_every": validate_every,
        "save_every": save_every,
        "dataset_rows": tokenized_datasets['train'].num_rows,
        "dataset_token_count": batch_size * seq_length,
        "train_set_size": len(train_dataset),
        "valid_set_size": len(valid_dataset),
        "model_parameters": parameters_count,
        "vocab_size": vocab_size,
        "d_model": model.d_model,
        "d_inner": model.d_inner,
        "n_layers": model.n_layers,
        "device": str(device),
        "model_name": model_name
    }
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# AdamW over all model parameters.
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

# Hidden state carried across batches (None = start fresh) and the
# global optimizer-step counter used by all the periodic actions.
h_states = None
step = 0
|
|
|
|
|
# Main training loop: one optimizer step per batch, with periodic
# validation, sample generation, and checkpointing (all keyed on `step`).
for epoch in range(num_epochs):
    print(f"Starting Epoch {epoch + 1}/{num_epochs}")
    for i in range(0, len(train_dataset), batch_size):
        batch = train_dataset[i:i + batch_size]
        input_ids = torch.tensor(batch['input_ids']).to(device)

        # Periodically drop the carried hidden state so it cannot drift
        # indefinitely across unrelated documents.
        if step % reset_state_every == 0:
            h_states = None

        optimizer.zero_grad()
        try:
            # NOTE(review): h_states is carried into the next step without
            # an explicit .detach() here; this assumes model.forward
            # detaches the returned states itself -- confirm in model.py,
            # otherwise backward() would traverse already-freed graphs.
            _, h_states, loss = model.forward(input_ids, h_states)
            loss.backward()
            optimizer.step()
        except Exception as e:
            # Best-effort: skip a bad batch rather than abort the run.
            # `step` is deliberately not incremented for a failed batch.
            print(f"Error during training step {step + 1}: {e}")
            continue

        step += 1

        # Summarize the hidden state (mean and variance) for monitoring.
        if h_states is not None:
            try:
                avg_states = sum([torch.mean(h).item() for h in h_states]) / len(h_states)
                var_states = torch.var(torch.cat(h_states, dim=0)).item()
            except Exception as e:
                avg_states = None
                var_states = None
        else:
            avg_states = None
            var_states = None

        wandb.log({
            "loss": loss.item(),
            "average_hidden_state": avg_states,
            "variance_hidden_state": var_states,
            "step": step
        })
        print(f"Epoch: {epoch + 1}/{num_epochs}, Step: {step}, Loss: {loss.item():.4f}, Hidden States: average = {avg_states}, variance = {var_states}")

        # ---- periodic validation over the held-out split ----
        if step % validate_every == 0:
            validation_loss = 0.0
            valid_steps = 0

            with torch.no_grad():
                for vi in range(0, len(valid_dataset), batch_size):
                    val_batch = valid_dataset[vi:vi + batch_size]
                    val_input_ids = torch.tensor(val_batch['input_ids']).to(device)

                    # Validation always starts from a fresh hidden state.
                    _, _, val_loss = model.forward(val_input_ids, None)
                    validation_loss += val_loss.item()
                    valid_steps += 1

            avg_validation_loss = validation_loss / valid_steps if valid_steps > 0 else float('inf')

            wandb.log({"validation_loss": avg_validation_loss, "step": step})
            print(f"----- Validation after Step {step}: Average Loss = {avg_validation_loss:.4f} -----")

        # ---- periodic sample generation ----
        if step % infer_every == 0:
            with torch.no_grad():
                if input_ids.size(1) < input_len:
                    # FIX: the original used `continue` here, which also
                    # skipped the checkpoint-saving block below whenever
                    # the batch was too short to seed a prompt.
                    print("Input length is shorter than input_len. Skipping inference.")
                else:
                    # Use the first input_len tokens of the first sequence
                    # in the current batch as the generation prompt.
                    sample_ids = input_ids[0][:input_len]
                    input_text = tokenizer.decode(sample_ids, skip_special_tokens=True)
                    print(f"Input for Inference: {input_text}")

                    prompt = sample_ids.unsqueeze(0)
                    generated_text = generate_text(model, tokenizer, prompt, num_predict)
                    print(f"Generated Text:\n{generated_text}\n")

        # ---- periodic checkpointing ----
        if step % save_every == 0:
            # FIX: the original named the file f"{model_name}-{step}k.bin",
            # falsely labelling e.g. step 500 as "500k".
            checkpoint_path = f"{model_name}-{step}.bin"
            try:
                torch.save(model.state_dict(), checkpoint_path)
                print(f"Saved model checkpoint at step {step} to {checkpoint_path}")

                # Keep the stable symlink pointing at the newest checkpoint.
                update_symlink(checkpoint_path, symlink_path)
            except Exception as e:
                # Saving is best-effort; training continues on failure.
                print(f"Error saving checkpoint at step {step}: {e}")