flpelerin committed on
Commit
b2ea2e8
·
verified ·
1 Parent(s): b53dc87

Update train.py

Browse files
Files changed (1) hide show
  1. train.py +61 -11
train.py CHANGED
@@ -1,4 +1,5 @@
1
- import os # Added for file operations
 
2
  import torch
3
  import math
4
  from transformers import GPT2Tokenizer
@@ -14,6 +15,7 @@ from util import generate_text, generate_name
14
  # ============================
15
  dataset_path = 'flpelerin/tinystories-100k'
16
  run_name = generate_name()
 
17
 
18
  num_epochs = 1
19
  batch_size = 4
@@ -84,13 +86,45 @@ model.to(device)
84
  parameters_count = sum(p.numel() for p in model.parameters())
85
  print(f"Model has {parameters_count:,} parameters")
86
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
87
  # ============================
88
  # Load Checkpoint if Exists
89
  # ============================
90
- checkpoint_path = 'pytorch_model.bin'
91
- if os.path.exists(checkpoint_path):
92
- model.load_state_dict(torch.load(checkpoint_path, map_location=device))
93
- print(f"Loaded model weights from {checkpoint_path}")
94
  else:
95
  print("No checkpoint found. Starting training from scratch.")
96
 
@@ -130,7 +164,8 @@ wandb.init(
130
  "d_inner": 768,
131
  "n_layers": 6,
132
 
133
- "device": str(device)
 
134
  }
135
  )
136
 
@@ -139,7 +174,7 @@ wandb.init(
139
  # ============================
140
  optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
141
  h_states = None
142
- step = 0
143
 
144
  for epoch in range(num_epochs):
145
  print(f"Starting Epoch {epoch + 1}/{num_epochs}")
@@ -148,9 +183,17 @@ for epoch in range(num_epochs):
148
  input_ids = torch.tensor(batch['input_ids']).to(device)
149
 
150
  # Reset hidden states if needed
151
- h_states = h_states if (step % reset_state_every != 0) else None
152
- avg_states = sum([torch.mean(h_states[i]).item() for i in range(len(h_states))]) / len(h_states) if h_states is not None else None
153
- var_states = torch.var(torch.cat(h_states, dim=0)).item() if h_states else None
 
 
 
 
 
 
 
 
154
 
155
  optimizer.zero_grad()
156
  _, h_states, loss = model.forward(input_ids, h_states)
@@ -162,7 +205,8 @@ for epoch in range(num_epochs):
162
  wandb.log({
163
  "loss": loss.item(),
164
  "average_hidden_state": avg_states,
165
- "variance_hidden_state": var_states
 
166
  })
167
  print(f"Epoch: {epoch + 1}/{num_epochs}, Step: {step}, Loss: {loss.item():.4f}, Hidden States: average = {avg_states}, variance = {var_states}")
168
 
@@ -189,6 +233,9 @@ for epoch in range(num_epochs):
189
  # Perform inference at specified steps
190
  if step % infer_every == 0:
191
  with torch.no_grad():
 
 
 
192
  # Select a single input from the current batch for inference
193
  sample_ids = input_ids[0][:input_len]
194
  input_text = tokenizer.decode(sample_ids, skip_special_tokens=True)
@@ -202,6 +249,9 @@ for epoch in range(num_epochs):
202
 
203
  # Perform checkpointing at specified steps
204
  if step % save_every == 0:
 
 
 
205
  torch.save(model.state_dict(), checkpoint_path)
206
  print(f"Saved model checkpoint at step {step} to {checkpoint_path}")
207
  # Optionally, log the checkpoint to W&B
 
1
+ import os # For file operations
2
+ import re # For regex operations
3
  import torch
4
  import math
5
  from transformers import GPT2Tokenizer
 
15
  # ============================
16
  dataset_path = 'flpelerin/tinystories-100k'
17
  run_name = generate_name()
18
+ model_name = run_name # Example: "mingru-a14c"
19
 
20
  num_epochs = 1
21
  batch_size = 4
 
86
  parameters_count = sum(p.numel() for p in model.parameters())
87
  print(f"Model has {parameters_count:,} parameters")
88
 
89
+ # ============================
90
+ # Setup Checkpoint Directory and Naming
91
+ # ============================
92
+ checkpoint_dir = run_name # Directory named after run_name
93
+ os.makedirs(checkpoint_dir, exist_ok=True)
94
+ checkpoint_pattern = re.compile(rf"^{re.escape(model_name)}-(\d+)k\.bin$")
95
+
96
+ def find_latest_checkpoint(directory, pattern):
97
+ """
98
+ Finds the checkpoint file with the highest step number in the specified directory.
99
+
100
+ Args:
101
+ directory (str): Path to the checkpoint directory.
102
+ pattern (re.Pattern): Compiled regex pattern to match checkpoint files.
103
+
104
+ Returns:
105
+ tuple: (checkpoint_path (str), step (int)) if found, else (None, 0)
106
+ """
107
+ max_step = 0
108
+ latest_ckpt = None
109
+ for filename in os.listdir(directory):
110
+ match = pattern.match(filename)
111
+ if match:
112
+ step = int(match.group(1))
113
+ if step > max_step:
114
+ max_step = step
115
+ latest_ckpt = filename
116
+ if latest_ckpt:
117
+ return os.path.join(directory, latest_ckpt), max_step
118
+ else:
119
+ return None, 0
120
+
121
  # ============================
122
  # Load Checkpoint if Exists
123
  # ============================
124
+ latest_ckpt_path, latest_step = find_latest_checkpoint(checkpoint_dir, checkpoint_pattern)
125
+ if latest_ckpt_path:
126
+ model.load_state_dict(torch.load(latest_ckpt_path, map_location=device))
127
+ print(f"Loaded model weights from {latest_ckpt_path} at step {latest_step}k")
128
  else:
129
  print("No checkpoint found. Starting training from scratch.")
130
 
 
164
  "d_inner": 768,
165
  "n_layers": 6,
166
 
167
+ "device": str(device),
168
+ "model_name": model_name # Log model_name
169
  }
170
  )
171
 
 
174
  # ============================
175
  optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
176
  h_states = None
177
+ step = latest_step # Start from the latest step if checkpoint was loaded
178
 
179
  for epoch in range(num_epochs):
180
  print(f"Starting Epoch {epoch + 1}/{num_epochs}")
 
183
  input_ids = torch.tensor(batch['input_ids']).to(device)
184
 
185
  # Reset hidden states if needed
186
+ h_states = h_states if (step % reset_state_every != 0 and h_states is not None) else None
187
+ if h_states is not None:
188
+ try:
189
+ avg_states = sum([torch.mean(h).item() for h in h_states]) / len(h_states)
190
+ var_states = torch.var(torch.cat(h_states, dim=0)).item()
191
+ except Exception as e:
192
+ avg_states = None
193
+ var_states = None
194
+ else:
195
+ avg_states = None
196
+ var_states = None
197
 
198
  optimizer.zero_grad()
199
  _, h_states, loss = model.forward(input_ids, h_states)
 
205
  wandb.log({
206
  "loss": loss.item(),
207
  "average_hidden_state": avg_states,
208
+ "variance_hidden_state": var_states,
209
+ "step": step
210
  })
211
  print(f"Epoch: {epoch + 1}/{num_epochs}, Step: {step}, Loss: {loss.item():.4f}, Hidden States: average = {avg_states}, variance = {var_states}")
212
 
 
233
  # Perform inference at specified steps
234
  if step % infer_every == 0:
235
  with torch.no_grad():
236
+ if input_ids.size(1) < input_len:
237
+ print("Input length is shorter than input_len. Skipping inference.")
238
+ continue
239
  # Select a single input from the current batch for inference
240
  sample_ids = input_ids[0][:input_len]
241
  input_text = tokenizer.decode(sample_ids, skip_special_tokens=True)
 
249
 
250
  # Perform checkpointing at specified steps
251
  if step % save_every == 0:
252
+ step_str = f"{step}k" # Format step with 'k', e.g., '750k'
253
+ checkpoint_filename = f"{model_name}-{step_str}.bin"
254
+ checkpoint_path = os.path.join(checkpoint_dir, checkpoint_filename)
255
  torch.save(model.state_dict(), checkpoint_path)
256
  print(f"Saved model checkpoint at step {step} to {checkpoint_path}")
257
  # Optionally, log the checkpoint to W&B