Update train.py
Browse files
train.py
CHANGED
|
@@ -1,4 +1,5 @@
|
|
| 1 |
-
import os #
|
|
|
|
| 2 |
import torch
|
| 3 |
import math
|
| 4 |
from transformers import GPT2Tokenizer
|
|
@@ -14,6 +15,7 @@ from util import generate_text, generate_name
|
|
| 14 |
# ============================
|
| 15 |
dataset_path = 'flpelerin/tinystories-100k'
|
| 16 |
run_name = generate_name()
|
|
|
|
| 17 |
|
| 18 |
num_epochs = 1
|
| 19 |
batch_size = 4
|
|
@@ -84,13 +86,45 @@ model.to(device)
|
|
| 84 |
parameters_count = sum(p.numel() for p in model.parameters())
|
| 85 |
print(f"Model has {parameters_count:,} parameters")
|
| 86 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 87 |
# ============================
|
| 88 |
# Load Checkpoint if Exists
|
| 89 |
# ============================
|
| 90 |
-
|
| 91 |
-
if
|
| 92 |
-
model.load_state_dict(torch.load(
|
| 93 |
-
print(f"Loaded model weights from {
|
| 94 |
else:
|
| 95 |
print("No checkpoint found. Starting training from scratch.")
|
| 96 |
|
|
@@ -130,7 +164,8 @@ wandb.init(
|
|
| 130 |
"d_inner": 768,
|
| 131 |
"n_layers": 6,
|
| 132 |
|
| 133 |
-
"device": str(device)
|
|
|
|
| 134 |
}
|
| 135 |
)
|
| 136 |
|
|
@@ -139,7 +174,7 @@ wandb.init(
|
|
| 139 |
# ============================
|
| 140 |
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
|
| 141 |
h_states = None
|
| 142 |
-
step =
|
| 143 |
|
| 144 |
for epoch in range(num_epochs):
|
| 145 |
print(f"Starting Epoch {epoch + 1}/{num_epochs}")
|
|
@@ -148,9 +183,17 @@ for epoch in range(num_epochs):
|
|
| 148 |
input_ids = torch.tensor(batch['input_ids']).to(device)
|
| 149 |
|
| 150 |
# Reset hidden states if needed
|
| 151 |
-
h_states = h_states if (step % reset_state_every != 0) else None
|
| 152 |
-
|
| 153 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 154 |
|
| 155 |
optimizer.zero_grad()
|
| 156 |
_, h_states, loss = model.forward(input_ids, h_states)
|
|
@@ -162,7 +205,8 @@ for epoch in range(num_epochs):
|
|
| 162 |
wandb.log({
|
| 163 |
"loss": loss.item(),
|
| 164 |
"average_hidden_state": avg_states,
|
| 165 |
-
"variance_hidden_state": var_states
|
|
|
|
| 166 |
})
|
| 167 |
print(f"Epoch: {epoch + 1}/{num_epochs}, Step: {step}, Loss: {loss.item():.4f}, Hidden States: average = {avg_states}, variance = {var_states}")
|
| 168 |
|
|
@@ -189,6 +233,9 @@ for epoch in range(num_epochs):
|
|
| 189 |
# Perform inference at specified steps
|
| 190 |
if step % infer_every == 0:
|
| 191 |
with torch.no_grad():
|
|
|
|
|
|
|
|
|
|
| 192 |
# Select a single input from the current batch for inference
|
| 193 |
sample_ids = input_ids[0][:input_len]
|
| 194 |
input_text = tokenizer.decode(sample_ids, skip_special_tokens=True)
|
|
@@ -202,6 +249,9 @@ for epoch in range(num_epochs):
|
|
| 202 |
|
| 203 |
# Perform checkpointing at specified steps
|
| 204 |
if step % save_every == 0:
|
|
|
|
|
|
|
|
|
|
| 205 |
torch.save(model.state_dict(), checkpoint_path)
|
| 206 |
print(f"Saved model checkpoint at step {step} to {checkpoint_path}")
|
| 207 |
# Optionally, log the checkpoint to W&B
|
|
|
|
| 1 |
+
import os # For file operations
|
| 2 |
+
import re # For regex operations
|
| 3 |
import torch
|
| 4 |
import math
|
| 5 |
from transformers import GPT2Tokenizer
|
|
|
|
| 15 |
# ============================
|
| 16 |
dataset_path = 'flpelerin/tinystories-100k'
|
| 17 |
run_name = generate_name()
|
| 18 |
+
model_name = run_name # Example: "mingru-a14c"
|
| 19 |
|
| 20 |
num_epochs = 1
|
| 21 |
batch_size = 4
|
|
|
|
| 86 |
parameters_count = sum(p.numel() for p in model.parameters())
|
| 87 |
print(f"Model has {parameters_count:,} parameters")
|
| 88 |
|
| 89 |
+
# ============================
|
| 90 |
+
# Setup Checkpoint Directory and Naming
|
| 91 |
+
# ============================
|
| 92 |
+
checkpoint_dir = run_name # Directory named after run_name
|
| 93 |
+
os.makedirs(checkpoint_dir, exist_ok=True)
|
| 94 |
+
checkpoint_pattern = re.compile(rf"^{re.escape(model_name)}-(\d+)k\.bin$")
|
| 95 |
+
|
| 96 |
+
def find_latest_checkpoint(directory, pattern):
    """
    Find the checkpoint file with the highest step number in a directory.

    Args:
        directory (str): Path to the checkpoint directory.
        pattern (re.Pattern): Compiled regex matching checkpoint filenames;
            group(1) must capture the integer step number
            (e.g. r"^name-(\d+)k\.bin$").

    Returns:
        tuple: (checkpoint_path (str), step (int)) for the checkpoint with
        the largest step, or (None, 0) if no file matches.
    """
    max_step = 0
    latest_ckpt = None
    for filename in os.listdir(directory):
        match = pattern.match(filename)
        if match:
            step = int(match.group(1))
            # "latest_ckpt is None" guard: without it, a checkpoint saved at
            # step 0 (e.g. "name-0k.bin", produced because the training loop
            # saves when step % save_every == 0) would never satisfy
            # step > max_step and the function would wrongly report that no
            # checkpoint exists.
            if latest_ckpt is None or step > max_step:
                max_step = step
                latest_ckpt = filename
    if latest_ckpt is not None:
        return os.path.join(directory, latest_ckpt), max_step
    return None, 0
|
| 120 |
+
|
| 121 |
# ============================
|
| 122 |
# Load Checkpoint if Exists
|
| 123 |
# ============================
|
| 124 |
+
latest_ckpt_path, latest_step = find_latest_checkpoint(checkpoint_dir, checkpoint_pattern)
|
| 125 |
+
if latest_ckpt_path:
|
| 126 |
+
model.load_state_dict(torch.load(latest_ckpt_path, map_location=device))
|
| 127 |
+
print(f"Loaded model weights from {latest_ckpt_path} at step {latest_step}k")
|
| 128 |
else:
|
| 129 |
print("No checkpoint found. Starting training from scratch.")
|
| 130 |
|
|
|
|
| 164 |
"d_inner": 768,
|
| 165 |
"n_layers": 6,
|
| 166 |
|
| 167 |
+
"device": str(device),
|
| 168 |
+
"model_name": model_name # Log model_name
|
| 169 |
}
|
| 170 |
)
|
| 171 |
|
|
|
|
| 174 |
# ============================
|
| 175 |
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
|
| 176 |
h_states = None
|
| 177 |
+
step = latest_step # Start from the latest step if checkpoint was loaded
|
| 178 |
|
| 179 |
for epoch in range(num_epochs):
|
| 180 |
print(f"Starting Epoch {epoch + 1}/{num_epochs}")
|
|
|
|
| 183 |
input_ids = torch.tensor(batch['input_ids']).to(device)
|
| 184 |
|
| 185 |
# Reset hidden states if needed
|
| 186 |
+
h_states = h_states if (step % reset_state_every != 0 and h_states is not None) else None
|
| 187 |
+
if h_states is not None:
|
| 188 |
+
try:
|
| 189 |
+
avg_states = sum([torch.mean(h).item() for h in h_states]) / len(h_states)
|
| 190 |
+
var_states = torch.var(torch.cat(h_states, dim=0)).item()
|
| 191 |
+
except Exception as e:
|
| 192 |
+
avg_states = None
|
| 193 |
+
var_states = None
|
| 194 |
+
else:
|
| 195 |
+
avg_states = None
|
| 196 |
+
var_states = None
|
| 197 |
|
| 198 |
optimizer.zero_grad()
|
| 199 |
_, h_states, loss = model.forward(input_ids, h_states)
|
|
|
|
| 205 |
wandb.log({
|
| 206 |
"loss": loss.item(),
|
| 207 |
"average_hidden_state": avg_states,
|
| 208 |
+
"variance_hidden_state": var_states,
|
| 209 |
+
"step": step
|
| 210 |
})
|
| 211 |
print(f"Epoch: {epoch + 1}/{num_epochs}, Step: {step}, Loss: {loss.item():.4f}, Hidden States: average = {avg_states}, variance = {var_states}")
|
| 212 |
|
|
|
|
| 233 |
# Perform inference at specified steps
|
| 234 |
if step % infer_every == 0:
|
| 235 |
with torch.no_grad():
|
| 236 |
+
if input_ids.size(1) < input_len:
|
| 237 |
+
print("Input length is shorter than input_len. Skipping inference.")
|
| 238 |
+
continue
|
| 239 |
# Select a single input from the current batch for inference
|
| 240 |
sample_ids = input_ids[0][:input_len]
|
| 241 |
input_text = tokenizer.decode(sample_ids, skip_special_tokens=True)
|
|
|
|
| 249 |
|
| 250 |
# Perform checkpointing at specified steps
|
| 251 |
if step % save_every == 0:
|
| 252 |
+
step_str = f"{step}k" # Format step with 'k', e.g., '750k'
|
| 253 |
+
checkpoint_filename = f"{model_name}-{step_str}.bin"
|
| 254 |
+
checkpoint_path = os.path.join(checkpoint_dir, checkpoint_filename)
|
| 255 |
torch.save(model.state_dict(), checkpoint_path)
|
| 256 |
print(f"Saved model checkpoint at step {step} to {checkpoint_path}")
|
| 257 |
# Optionally, log the checkpoint to W&B
|