Update train.py
Browse files
train.py
CHANGED
|
@@ -14,8 +14,7 @@ from util import generate_text, generate_name
|
|
| 14 |
# Configuration Parameters
|
| 15 |
# ============================
|
| 16 |
dataset_path = 'flpelerin/tinystories-100k'
|
| 17 |
-
|
| 18 |
-
model_name = run_name # Example: "mingru-a14c"
|
| 19 |
|
| 20 |
num_epochs = 1
|
| 21 |
batch_size = 4
|
|
@@ -28,7 +27,7 @@ num_predict = 250
|
|
| 28 |
infer_every = 200
|
| 29 |
reset_state_every = 16
|
| 30 |
validate_every = 200
|
| 31 |
-
save_every = 500 #
|
| 32 |
|
| 33 |
# ============================
|
| 34 |
# Initialize the Device
|
|
@@ -87,45 +86,53 @@ parameters_count = sum(p.numel() for p in model.parameters())
|
|
| 87 |
print(f"Model has {parameters_count:,} parameters")
|
| 88 |
|
| 89 |
# ============================
|
| 90 |
-
#
|
| 91 |
# ============================
|
| 92 |
-
|
| 93 |
-
os.makedirs(checkpoint_dir, exist_ok=True)
|
| 94 |
-
checkpoint_pattern = re.compile(rf"^{re.escape(model_name)}-(\d+)k\.bin$")
|
| 95 |
|
| 96 |
-
def
|
| 97 |
"""
|
| 98 |
-
|
| 99 |
-
|
| 100 |
Args:
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
Returns:
|
| 105 |
-
tuple: (checkpoint_path (str), step (int)) if found, else (None, 0)
|
| 106 |
"""
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
max_step = step
|
| 115 |
-
latest_ckpt = filename
|
| 116 |
-
if latest_ckpt:
|
| 117 |
-
return os.path.join(directory, latest_ckpt), max_step
|
| 118 |
-
else:
|
| 119 |
-
return None, 0
|
| 120 |
|
| 121 |
# ============================
|
| 122 |
-
# Load Checkpoint if Exists
|
| 123 |
# ============================
|
| 124 |
-
|
| 125 |
-
|
| 126 |
-
|
| 127 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 128 |
else:
|
|
|
|
| 129 |
print("No checkpoint found. Starting training from scratch.")
|
| 130 |
|
| 131 |
# ============================
|
|
@@ -134,36 +141,28 @@ else:
|
|
| 134 |
wandb.login(key="860f8753998c6e6dc356914de07e8855aa2f9642")
|
| 135 |
wandb.init(
|
| 136 |
project="minGRU-Training",
|
| 137 |
-
name=
|
| 138 |
config={
|
| 139 |
"dataset_path": dataset_path,
|
| 140 |
-
|
| 141 |
"num_epochs": num_epochs,
|
| 142 |
"batch_size": batch_size,
|
| 143 |
"seq_length": seq_length,
|
| 144 |
"learning_rate": learning_rate,
|
| 145 |
-
|
| 146 |
"input_len": input_len,
|
| 147 |
"num_predict": num_predict,
|
| 148 |
-
|
| 149 |
"infer_every": infer_every,
|
| 150 |
"reset_state_every": reset_state_every,
|
| 151 |
"validate_every": validate_every,
|
| 152 |
"save_every": save_every, # Logging the new variable
|
| 153 |
-
|
| 154 |
"dataset_rows": tokenized_datasets['train'].num_rows,
|
| 155 |
"dataset_token_count": batch_size * seq_length,
|
| 156 |
-
|
| 157 |
"train_set_size": len(train_dataset),
|
| 158 |
"valid_set_size": len(valid_dataset),
|
| 159 |
-
|
| 160 |
"model_parameters": parameters_count,
|
| 161 |
-
|
| 162 |
"vocab_size": vocab_size,
|
| 163 |
"d_model": 384,
|
| 164 |
"d_inner": 768,
|
| 165 |
"n_layers": 6,
|
| 166 |
-
|
| 167 |
"device": str(device),
|
| 168 |
"model_name": model_name # Log model_name
|
| 169 |
}
|
|
@@ -183,7 +182,22 @@ for epoch in range(num_epochs):
|
|
| 183 |
input_ids = torch.tensor(batch['input_ids']).to(device)
|
| 184 |
|
| 185 |
# Reset hidden states if needed
|
| 186 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 187 |
if h_states is not None:
|
| 188 |
try:
|
| 189 |
avg_states = sum([torch.mean(h).item() for h in h_states]) / len(h_states)
|
|
@@ -195,12 +209,6 @@ for epoch in range(num_epochs):
|
|
| 195 |
avg_states = None
|
| 196 |
var_states = None
|
| 197 |
|
| 198 |
-
optimizer.zero_grad()
|
| 199 |
-
_, h_states, loss = model.forward(input_ids, h_states)
|
| 200 |
-
loss.backward()
|
| 201 |
-
optimizer.step()
|
| 202 |
-
|
| 203 |
-
step += 1
|
| 204 |
# Log step information
|
| 205 |
wandb.log({
|
| 206 |
"loss": loss.item(),
|
|
@@ -251,8 +259,15 @@ for epoch in range(num_epochs):
|
|
| 251 |
if step % save_every == 0:
|
| 252 |
step_str = f"{step}k" # Format step with 'k', e.g., '750k'
|
| 253 |
checkpoint_filename = f"{model_name}-{step_str}.bin"
|
| 254 |
-
checkpoint_path =
|
| 255 |
-
|
| 256 |
-
|
| 257 |
-
|
| 258 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 14 |
# Configuration Parameters
|
| 15 |
# ============================
|
| 16 |
dataset_path = 'flpelerin/tinystories-100k'
|
| 17 |
+
# Human-readable run identifier used in checkpoint filenames and W&B logs.
model_name = generate_name()  # Example: "mingru-a14c"
|
|
|
|
| 18 |
|
| 19 |
num_epochs = 1
|
| 20 |
batch_size = 4
|
|
|
|
| 27 |
infer_every = 200
|
| 28 |
reset_state_every = 16
|
| 29 |
validate_every = 200
|
| 30 |
+
save_every = 500 # Controls checkpointing frequency
|
| 31 |
|
| 32 |
# ============================
|
| 33 |
# Initialize the Device
|
|
|
|
| 86 |
print(f"Model has {parameters_count:,} parameters")
|
| 87 |
|
| 88 |
# ============================
|
| 89 |
+
# Symbolic Link Configuration
|
| 90 |
# ============================
|
| 91 |
+
symlink_path = 'pytorch_model.bin'
|
|
|
|
|
|
|
| 92 |
|
| 93 |
+
def update_symlink(target_path, symlink_path):
    """
    Creates or updates a symbolic link pointing to the target path.

    The link is first created under a temporary name and then moved into
    place with os.replace(), which is atomic on POSIX: there is no window
    in which symlink_path is missing, and a failed creation leaves any
    existing link untouched (the original remove-then-create sequence
    could delete the old link and then fail to create the new one).

    Args:
        target_path (str): The file path the symlink should point to.
        symlink_path (str): The symlink's path.
    """
    tmp_path = f"{symlink_path}.tmp"
    try:
        # Clear a stale temp entry left over from a previous failed attempt.
        if os.path.islink(tmp_path) or os.path.exists(tmp_path):
            os.remove(tmp_path)
        os.symlink(target_path, tmp_path)
        # Atomically swap the new link into place, replacing any existing
        # file or link at symlink_path.
        os.replace(tmp_path, symlink_path)
        print(f"Updated symlink: {symlink_path} -> {target_path}")
    except OSError as e:
        print(f"Warning: Failed to create symlink {symlink_path} -> {target_path}. Error: {e}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 108 |
|
| 109 |
# ============================
|
| 110 |
+
# Load Checkpoint from pytorch_model.bin if Exists
|
| 111 |
# ============================
|
| 112 |
+
# Attempt to resume from the pytorch_model.bin symlink if one is present.
# Any failure falls back to a fresh run (latest_step = 0).
if not os.path.exists(symlink_path):
    latest_step = 0
    print("No checkpoint found. Starting training from scratch.")
else:
    try:
        state_dict = torch.load(symlink_path, map_location=device)
        model.load_state_dict(state_dict)
        print(f"Loaded model weights from {symlink_path}")

        # Recover the step counter from the checkpoint filename the symlink targets.
        link_target = os.readlink(symlink_path) if os.path.islink(symlink_path) else None
        if not link_target:
            latest_step = 0
            print("Symlink does not point to a valid checkpoint. Starting from step 0.")
        else:
            step_match = re.search(
                rf"{re.escape(model_name)}-(\d+)k\.bin$",
                os.path.basename(link_target),
            )
            if step_match is None:
                latest_step = 0
                print("Could not extract step number from checkpoint filename. Starting from step 0.")
            else:
                latest_step = int(step_match.group(1))
                print(f"Resuming from step {latest_step}k")
    except Exception as e:
        print(f"Error loading model from {symlink_path}: {e}")
        latest_step = 0
        print("Starting training from scratch.")
|
| 137 |
|
| 138 |
# ============================
|
|
|
|
| 141 |
wandb.login(key="860f8753998c6e6dc356914de07e8855aa2f9642")
|
| 142 |
wandb.init(
|
| 143 |
project="minGRU-Training",
|
| 144 |
+
name=model_name,
|
| 145 |
config={
|
| 146 |
"dataset_path": dataset_path,
|
|
|
|
| 147 |
"num_epochs": num_epochs,
|
| 148 |
"batch_size": batch_size,
|
| 149 |
"seq_length": seq_length,
|
| 150 |
"learning_rate": learning_rate,
|
|
|
|
| 151 |
"input_len": input_len,
|
| 152 |
"num_predict": num_predict,
|
|
|
|
| 153 |
"infer_every": infer_every,
|
| 154 |
"reset_state_every": reset_state_every,
|
| 155 |
"validate_every": validate_every,
|
| 156 |
"save_every": save_every, # Logging the new variable
|
|
|
|
| 157 |
"dataset_rows": tokenized_datasets['train'].num_rows,
|
| 158 |
"dataset_token_count": batch_size * seq_length,
|
|
|
|
| 159 |
"train_set_size": len(train_dataset),
|
| 160 |
"valid_set_size": len(valid_dataset),
|
|
|
|
| 161 |
"model_parameters": parameters_count,
|
|
|
|
| 162 |
"vocab_size": vocab_size,
|
| 163 |
"d_model": 384,
|
| 164 |
"d_inner": 768,
|
| 165 |
"n_layers": 6,
|
|
|
|
| 166 |
"device": str(device),
|
| 167 |
"model_name": model_name # Log model_name
|
| 168 |
}
|
|
|
|
| 182 |
input_ids = torch.tensor(batch['input_ids']).to(device)
|
| 183 |
|
| 184 |
# Reset hidden states if needed
|
| 185 |
+
if step % reset_state_every == 0:
|
| 186 |
+
h_states = None
|
| 187 |
+
# Otherwise, keep existing hidden states
|
| 188 |
+
|
| 189 |
+
optimizer.zero_grad()
|
| 190 |
+
try:
|
| 191 |
+
_, h_states, loss = model.forward(input_ids, h_states)
|
| 192 |
+
loss.backward()
|
| 193 |
+
optimizer.step()
|
| 194 |
+
except Exception as e:
|
| 195 |
+
print(f"Error during training step {step + 1}: {e}")
|
| 196 |
+
continue # Skip to the next batch
|
| 197 |
+
|
| 198 |
+
step += 1
|
| 199 |
+
|
| 200 |
+
# Compute statistics of hidden states
|
| 201 |
if h_states is not None:
|
| 202 |
try:
|
| 203 |
avg_states = sum([torch.mean(h).item() for h in h_states]) / len(h_states)
|
|
|
|
| 209 |
avg_states = None
|
| 210 |
var_states = None
|
| 211 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 212 |
# Log step information
|
| 213 |
wandb.log({
|
| 214 |
"loss": loss.item(),
|
|
|
|
| 259 |
if step % save_every == 0:
|
| 260 |
step_str = f"{step}k" # Format step with 'k', e.g., '750k'
|
| 261 |
checkpoint_filename = f"{model_name}-{step_str}.bin"
|
| 262 |
+
checkpoint_path = checkpoint_filename
|
| 263 |
+
try:
|
| 264 |
+
torch.save(model.state_dict(), checkpoint_path)
|
| 265 |
+
print(f"Saved model checkpoint at step {step} to {checkpoint_path}")
|
| 266 |
+
|
| 267 |
+
# Update the symbolic link to point to this checkpoint
|
| 268 |
+
update_symlink(checkpoint_path, symlink_path)
|
| 269 |
+
|
| 270 |
+
# Optionally, log the checkpoint to W&B
|
| 271 |
+
# wandb.save(checkpoint_path)
|
| 272 |
+
except Exception as e:
|
| 273 |
+
print(f"Error saving checkpoint at step {step}: {e}")
|