File size: 9,308 Bytes
b2ea2e8 2fc11ed aed0c2d 2fc11ed c9340cb 0f2d26d cb65f81 9b48427 62466de 2fc11ed 62466de 2fc11ed 62466de 307f9d8 2fc11ed 307f9d8 a563d57 62466de 0f2d26d 2fc11ed 62466de 2fc11ed b53dc87 aed0c2d 0f2d26d 2fc11ed 29a6938 87a4d20 29a6938 2fc11ed 0f2d26d 62466de 0f2d26d beb0496 0f2d26d 2fc11ed 0f2d26d 62466de 0f2d26d 2fc11ed 62466de b2ea2e8 a563d57 b2ea2e8 a563d57 b2ea2e8 a563d57 b2ea2e8 a563d57 b2ea2e8 a563d57 b2ea2e8 a563d57 b2ea2e8 b53dc87 a563d57 b53dc87 a563d57 b53dc87 62466de ac8f63b a563d57 62466de b53dc87 62466de 83c3c6d b2ea2e8 62466de 2fc11ed 0f2d26d b53dc87 0f2d26d 62466de f10ccaa 62466de 2fc11ed 0f2d26d 2fc11ed 0f2d26d a563d57 b2ea2e8 2fc11ed aed0c2d be0031f b2ea2e8 aed0c2d de45ebc 2fc11ed 0f2d26d 2fc11ed 0f2d26d 2fc11ed 0f2d26d aed0c2d 0f2d26d 62466de 0f2d26d b2ea2e8 0f2d26d aed0c2d b53dc87 b2ea2e8 a563d57 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 |
import os # For file operations
import re # For regex operations
import torch
import math
from transformers import GPT2Tokenizer
from datasets import load_dataset
import numpy as np
import wandb # Import W&B library
from model import minGRULM
from util import generate_text, generate_name
# ============================
# Configuration Parameters
# ============================
dataset_path = 'flpelerin/tinystories-100k'  # Hugging Face hub dataset id
model_name = generate_name() # Example: "mingru-a14c"
num_epochs = 1
batch_size = 4        # rows per optimizer step
seq_length = 256      # tokens per row
learning_rate = 1e-4
input_len = 50        # prompt length (tokens) for periodic sample generation
num_predict = 250     # tokens to generate during periodic sample generation
infer_every = 200     # run sample generation every N training steps
reset_state_every = 16  # drop the carried hidden state every N steps (truncated-BPTT window)
validate_every = 200  # run validation every N steps; also sizes the val split (1/validate_every)
save_every = 500 # Controls checkpointing frequency
# ============================
# Initialize the Device
# ============================
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Total context size is {batch_size * seq_length} tokens")
# ============================
# Initialize the Tokenizer
# ============================
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
tokenizer.pad_token = tokenizer.eos_token  # GPT-2 ships no pad token; reuse EOS for padding
vocab_size = tokenizer.vocab_size
print(f"Tokenizer has {vocab_size} unique tokens")
# ============================
# Load and Preprocess Dataset
# ============================
dataset = load_dataset(dataset_path)
def process_function(examples):
    """Tokenize a batch of raw text rows into fixed-length token-id rows.

    Each row is padded/truncated to exactly batch_size * seq_length tokens
    so the training loop can consume rows of a uniform size.
    """
    row_length = seq_length * batch_size
    return tokenizer(
        examples['text'],
        padding='max_length',
        truncation=True,
        max_length=row_length,
    )
tokenized_datasets = dataset.map(process_function, batched=True)
print(f"Dataset has {tokenized_datasets['train'].num_rows} rows of {batch_size} times {seq_length} tokens")
# ============================
# Split Dataset into Train and Validation
# ============================
# test_size = 1/validate_every keeps the validation pass cheap relative to
# how often it runs (here: 0.5% of the rows).
split_dataset = tokenized_datasets['train'].train_test_split(test_size=(1/validate_every))
train_dataset = split_dataset['train']
valid_dataset = split_dataset['test']
print(f"Training set size: {len(train_dataset)}")
print(f"Validation set size: {len(valid_dataset)}")
# ============================
# Initialize the Model
# ============================
# Hyperparameters of the minGRU language model (defined in model.py;
# presumably d_model = embedding width, d_inner = FFN width — confirm there).
model = minGRULM(
    vocab_size = vocab_size,
    d_model = 384,
    d_inner = 768,
    n_layers = 6
)
model.to(device)
parameters_count = sum(p.numel() for p in model.parameters())
print(f"Model has {parameters_count:,} parameters")
# ============================
# Symbolic Link Configuration
# ============================
# Stable path that always points at the most recent checkpoint; it is both
# read at startup (resume) and re-pointed after every save.
symlink_path = 'pytorch_model.bin'
def update_symlink(target_path, symlink_path):
    """
    Point *symlink_path* at *target_path*, replacing whatever is there.

    Args:
        target_path (str): The file path the symlink should point to.
        symlink_path (str): The symlink's path.
    """
    try:
        # Clear the slot first: islink() catches dangling links that
        # exists() would miss, exists() catches plain files.
        occupied = os.path.islink(symlink_path) or os.path.exists(symlink_path)
        if occupied:
            os.remove(symlink_path)
        os.symlink(target_path, symlink_path)
        print(f"Updated symlink: {symlink_path} -> {target_path}")
    except OSError as e:
        # Best-effort: symlinks may be unsupported (e.g. some filesystems);
        # training should not die over a convenience link.
        print(f"Warning: Failed to create symlink {symlink_path} -> {target_path}. Error: {e}")
# ============================
# Load Checkpoint from pytorch_model.bin if Exists
# ============================
# Resume from the symlinked latest checkpoint when present; any failure
# (corrupt file, shape mismatch) falls back to training from scratch.
if os.path.exists(symlink_path):
    try:
        model.load_state_dict(torch.load(symlink_path, map_location=device))
        print(f"Loaded model weights from {symlink_path}")
    except Exception as e:
        print(f"Error loading model from {symlink_path}: {e}")
        print("Starting training from scratch.")
else:
    print("No checkpoint found. Starting training from scratch.")
# ============================
# Initialize the Weights and Biases Run
# ============================
# SECURITY FIX: the API key was previously hard-coded in source. Never commit
# credentials — read the key from the environment instead (wandb also honors
# WANDB_API_KEY natively); with no key set, fall back to cached/interactive login.
_wandb_api_key = os.environ.get("WANDB_API_KEY")
if _wandb_api_key:
    wandb.login(key=_wandb_api_key)
else:
    wandb.login()
wandb.init(
    project="minGRU-Training",
    name=model_name,
    config={
        "dataset_path": dataset_path,
        "num_epochs": num_epochs,
        "batch_size": batch_size,
        "seq_length": seq_length,
        "learning_rate": learning_rate,
        "input_len": input_len,
        "num_predict": num_predict,
        "infer_every": infer_every,
        "reset_state_every": reset_state_every,
        "validate_every": validate_every,
        "save_every": save_every,
        "dataset_rows": tokenized_datasets['train'].num_rows,
        "dataset_token_count": batch_size * seq_length,
        "train_set_size": len(train_dataset),
        "valid_set_size": len(valid_dataset),
        "model_parameters": parameters_count,
        "vocab_size": vocab_size,
        "d_model": model.d_model,
        "d_inner": model.d_inner,
        "n_layers": model.n_layers,
        "device": str(device),
        "model_name": model_name
    }
)
# ============================
# Training Loop with Validation and Checkpointing
# ============================
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
h_states = None  # hidden states carried between steps (truncated BPTT)
step = 0
for epoch in range(num_epochs):
    print(f"Starting Epoch {epoch + 1}/{num_epochs}")
    for i in range(0, len(train_dataset), batch_size):
        batch = train_dataset[i:i + batch_size]
        input_ids = torch.tensor(batch['input_ids']).to(device)
        # Reset hidden states at window boundaries; otherwise carry them over.
        if step % reset_state_every == 0:
            h_states = None
        optimizer.zero_grad()
        try:
            _, h_states, loss = model.forward(input_ids, h_states)
            loss.backward()
            optimizer.step()
            # BUG FIX: detach the carried states from the autograd graph.
            # Without this, the next step's backward() tries to traverse the
            # previous step's already-freed graph and raises — an error the
            # broad except below used to swallow, silently skipping steps.
            if h_states is not None:
                h_states = [h.detach() for h in h_states]
        except Exception as e:
            print(f"Error during training step {step + 1}: {e}")
            continue  # Skip to the next batch
        step += 1
        # Compute summary statistics of the hidden states for logging.
        if h_states is not None:
            try:
                avg_states = sum([torch.mean(h).item() for h in h_states]) / len(h_states)
                var_states = torch.var(torch.cat(h_states, dim=0)).item()
            except Exception:
                avg_states = None
                var_states = None
        else:
            avg_states = None
            var_states = None
        # Log step information
        wandb.log({
            "loss": loss.item(),
            "average_hidden_state": avg_states,
            "variance_hidden_state": var_states,
            "step": step
        })
        print(f"Epoch: {epoch + 1}/{num_epochs}, Step: {step}, Loss: {loss.item():.4f}, Hidden States: average = {avg_states}, variance = {var_states}")
        # Perform validation at specified intervals
        if step % validate_every == 0:
            model.eval()  # disable any training-only layers during validation
            validation_loss = 0.0
            valid_steps = 0
            with torch.no_grad():
                for vi in range(0, len(valid_dataset), batch_size):
                    val_batch = valid_dataset[vi:vi + batch_size]
                    val_input_ids = torch.tensor(val_batch['input_ids']).to(device)
                    _, _, val_loss = model.forward(val_input_ids, None)
                    validation_loss += val_loss.item()
                    valid_steps += 1
            model.train()
            avg_validation_loss = validation_loss / valid_steps if valid_steps > 0 else float('inf')
            wandb.log({"validation_loss": avg_validation_loss, "step": step})
            print(f"----- Validation after Step {step}: Average Loss = {avg_validation_loss:.4f} -----")
        # Perform inference at specified steps
        if step % infer_every == 0:
            with torch.no_grad():
                if input_ids.size(1) < input_len:
                    # BUG FIX: this used to `continue`, which aborted the whole
                    # batch iteration and silently skipped the checkpointing
                    # branch below. Now only the inference itself is skipped.
                    print("Input length is shorter than input_len. Skipping inference.")
                else:
                    # Use a single prompt from the current batch for generation.
                    sample_ids = input_ids[0][:input_len]
                    input_text = tokenizer.decode(sample_ids, skip_special_tokens=True)
                    print(f"Input for Inference: {input_text}")
                    prompt = sample_ids.unsqueeze(0)  # Shape: [1, input_len]
                    generated_text = generate_text(model, tokenizer, prompt, num_predict)
                    print(f"Generated Text:\n{generated_text}\n")
                    # Optionally, log generated text (e.g., as HTML to preserve formatting)
                    # wandb.log({"generated_text": wandb.Html(f"<pre>{generated_text}</pre>")}, step=step)
        # Perform checkpointing at specified steps
        if step % save_every == 0:
            # BUG FIX: the old name appended a literal 'k' to the raw step
            # count ('500' -> '500k'), overstating it by 1000x. Use the
            # plain step number in the filename instead.
            checkpoint_filename = f"{model_name}-{step}.bin"
            checkpoint_path = checkpoint_filename
            try:
                torch.save(model.state_dict(), checkpoint_path)
                print(f"Saved model checkpoint at step {step} to {checkpoint_path}")
                # Re-point the 'latest' symlink at this checkpoint
                update_symlink(checkpoint_path, symlink_path)
                # Optionally, log the checkpoint to W&B
                # wandb.save(checkpoint_path)
            except Exception as e:
                print(f"Error saving checkpoint at step {step}: {e}")