flpelerin
committed on
Commit · aed0c2d
1 Parent(s): 87a4d20
adding wandb
train.py
CHANGED
@@ -3,6 +3,7 @@ import math
 from transformers import GPT2Tokenizer
 from datasets import load_dataset
 import numpy as np
+import wandb  # Import W&B library

 from model import minGRULM
 from util import generate_text
@@ -10,21 +11,16 @@ from util import generate_text
 # ============================
 # Configuration Parameters
 # ============================
-
 dataset_path = 'flpelerin/tinystories-10k'
-
 num_epochs = 1
 batch_size = 4
 seq_length = 256
 learning_rate = 1e-4
 infer_step = 50
-
 input_len = 50
 num_predict = 250
-
 reset_state_every = 16
 validate_every = 100  # Perform validation every 100 training steps
-
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 print(f"Total context size is {batch_size * seq_length} tokens")

@@ -33,10 +29,25 @@ tokenizer.pad_token = tokenizer.eos_token
 vocab_size = tokenizer.vocab_size
 print(f"Tokenzier has {vocab_size} unique tokens")

+# Initialize W&B Logging
+wandb.init(project="minGRU-Training", config={
+    "dataset_path": dataset_path,
+    "num_epochs": num_epochs,
+    "batch_size": batch_size,
+    "seq_length": seq_length,
+    "learning_rate": learning_rate,
+    "infer_step": infer_step,
+    "input_len": input_len,
+    "num_predict": num_predict,
+    "reset_state_every": reset_state_every,
+    "validate_every": validate_every,
+    "device": str(device),
+    settings=wandb.Settings(api_key="860f8753998c6e6dc356914de07e8855aa2f9642")
+})
+
 # ============================
 # Load and Preprocess Dataset
 # ============================
-
 dataset = load_dataset(dataset_path)

 def process_function(examples):
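Note: as committed, `settings=wandb.Settings(api_key=...)` sits inside the `config` dict literal, which Python rejects (keyword syntax is not valid inside `{...}`), and it hardcodes a W&B API key in the source. A minimal sketch of one way to restructure the call, assuming the key is supplied through the `WANDB_API_KEY` environment variable (or a prior `wandb.login()`) rather than committed:

import os
import wandb

# Assumes the key was exported beforehand, e.g. `export WANDB_API_KEY=...`;
# wandb.init() picks it up from the environment, so Settings(api_key=...) is not needed.
assert os.environ.get("WANDB_API_KEY"), "set WANDB_API_KEY before running"

wandb.init(
    project="minGRU-Training",
    config={  # same entries as in the hunk above; values come from the variables defined earlier in train.py
        "dataset_path": dataset_path,
        "num_epochs": num_epochs,
        "batch_size": batch_size,
        "seq_length": seq_length,
        "learning_rate": learning_rate,
        "infer_step": infer_step,
        "input_len": input_len,
        "num_predict": num_predict,
        "reset_state_every": reset_state_every,
        "validate_every": validate_every,
        "device": str(device),
    },
)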
@@ -49,23 +60,22 @@ def process_function(examples):

 tokenized_datasets = dataset.map(process_function, batched=True)
 print(f"Dataset has {tokenized_datasets['train'].num_rows} rows of {batch_size} times {seq_length} tokens")
+wandb.log({"dataset_rows": tokenized_datasets['train'].num_rows, "dataset_token_count": batch_size * seq_length}) # Log dataset stats

 # ============================
 # Split Dataset into Train and Validation
 # ============================
-
-# Split the training set into 90% train and 10% validation
 split_dataset = tokenized_datasets['train'].train_test_split(test_size=1 / validate_every)
 train_dataset = split_dataset['train']
 valid_dataset = split_dataset['test']

 print(f"Training set size: {len(train_dataset)}")
 print(f"Validation set size: {len(valid_dataset)}")
+wandb.log({"train_set_size": len(train_dataset), "valid_set_size": len(valid_dataset)}) # Log set sizes

 # ============================
 # Initialize the Model
 # ============================
-
 model = minGRULM(
     vocab_size = vocab_size,
     d_model = 384,
@@ -75,6 +85,7 @@ model = minGRULM(

 model.to(device)
 print(f"Model has {sum(p.numel() for p in model.parameters()):,} parameters")
+wandb.log({"model_parameters": sum(p.numel() for p in model.parameters())}) # Log model parameter count

 optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

@@ -83,7 +94,6 @@ h_states = None
 # ============================
 # Training Loop with Validation
 # ============================
-
 step = 0
 for epoch in range(num_epochs):
     print(f"Starting Epoch {epoch + 1}/{num_epochs}")
@@ -92,10 +102,9 @@ for epoch in range(num_epochs):
         input_ids = torch.tensor(batch['input_ids']).to(device)

         # Reset hidden states if needed
-        h_states = h_states if (step % reset_state_every != 0) else None
+        h_states = h_states if (step % reset_state_every!= 0) else None
         str_states = (
-            ''.join(['{:.3f}, '.format(h_states[0][0][0][j].item()) for j in range(10)])
-            if h_states is not None else 'None'
+            ''.join(['{:.3f}, '.format(h_states[0][0][0][j].item()) for j in range(10)]) if h_states is not None else 'None'
         )

         optimizer.zero_grad()
@@ -104,6 +113,13 @@
         optimizer.step()

         step += 1
+        # Log step information
+        wandb.log({
+            "step": step,
+            "epoch": epoch + 1,
+            "loss": loss.item(),
+            "hidden_state": str_states
+        })
         print(f"Epoch: {epoch + 1}/{num_epochs}, Step: {step}, Loss: {loss.item():.4f}, Hidden State: {str_states}")

         # Perform validation at specified intervals
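The per-step logging above fires on every batch. A hedged variant (the `log_every` throttle and the `train/` metric prefixes are illustrative additions, not part of this commit) that namespaces the metrics and passes the global step explicitly, so later validation points land on the same x-axis:

log_every = 10  # hypothetical throttle; the committed code logs every step

if step % log_every == 0:
    wandb.log(
        {"train/loss": loss.item(), "train/epoch": epoch + 1},
        step=step,  # explicit step keeps train and validation curves aligned
    )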
@@ -122,12 +138,9 @@ for epoch in range(num_epochs):
                 validation_loss += val_loss.item()
                 valid_steps += 1

-                # Optionally, limit the number of batches for faster validation
-                # Uncomment the following lines to validate on only the first 100 batches
-                # if valid_steps >= 100:
-                # break
-
             avg_validation_loss = validation_loss / valid_steps if valid_steps > 0 else float('inf')
+            # Log validation loss
+            wandb.log({"validation_loss": avg_validation_loss, "step": step})
             print(f"----- Validation after Step {step}: Average Loss = {avg_validation_loss:.4f} -----")
             model.train() # Switch back to training mode

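If the `"step"` key is meant to serve as the x-axis for `validation_loss`, W&B can be told that once, right after `wandb.init`. A small sketch of that optional refinement (not in this commit):

# Declare "step" as the x-axis for the validation metric (call once after wandb.init)
wandb.define_metric("step")
wandb.define_metric("validation_loss", step_metric="step")

# The existing call in the hunk above then plots against that axis:
# wandb.log({"validation_loss": avg_validation_loss, "step": step})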
@@ -142,3 +155,5 @@ for epoch in range(num_epochs):
             prompt = sample_ids.unsqueeze(0) # Shape: [1, input_len]
             generated_text = generate_text(model, tokenizer, prompt, num_predict)
             print(f"Generated Text:\n{generated_text}\n")
+            # Optionally, log generated text (e.g., as HTML to preserve formatting)
+            # wandb.log({"generated_text": wandb.Html(f"<pre>{generated_text}</pre>")}, step=step)
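If the commented-out generated-text logging is ever enabled, it is also worth closing the run when the script ends. A short sketch (the wandb.Html wrapping mirrors the commented line in the last hunk; the wandb.finish() call is an addition, not part of this commit):

# At the end of training: log the sample as HTML and close the run
wandb.log({"generated_text": wandb.Html(f"<pre>{generated_text}</pre>")}, step=step)
wandb.finish()  # flush pending data and mark the run finished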