flpelerin commited on
Commit
a563d57
·
verified ·
1 Parent(s): b2ea2e8

Update train.py

Browse files
Files changed (1) hide show
  1. train.py +69 -54
train.py CHANGED
@@ -14,8 +14,7 @@ from util import generate_text, generate_name
14
  # Configuration Parameters
15
  # ============================
16
  dataset_path = 'flpelerin/tinystories-100k'
17
- run_name = generate_name()
18
- model_name = run_name # Example: "mingru-a14c"
19
 
20
  num_epochs = 1
21
  batch_size = 4
@@ -28,7 +27,7 @@ num_predict = 250
28
  infer_every = 200
29
  reset_state_every = 16
30
  validate_every = 200
31
- save_every = 500 # New variable to control checkpointing frequency
32
 
33
  # ============================
34
  # Initialize the Device
@@ -87,45 +86,53 @@ parameters_count = sum(p.numel() for p in model.parameters())
87
  print(f"Model has {parameters_count:,} parameters")
88
 
89
  # ============================
90
- # Setup Checkpoint Directory and Naming
91
  # ============================
92
- checkpoint_dir = run_name # Directory named after run_name
93
- os.makedirs(checkpoint_dir, exist_ok=True)
94
- checkpoint_pattern = re.compile(rf"^{re.escape(model_name)}-(\d+)k\.bin$")
95
 
96
def find_latest_checkpoint(directory, pattern):
    """Locate the newest checkpoint file in *directory*.

    Every filename is tested against *pattern*; the one whose captured
    step number (regex group 1) is largest wins.

    Args:
        directory (str): Path to the checkpoint directory.
        pattern (re.Pattern): Compiled regex matching checkpoint files,
            with group 1 holding the integer step count.

    Returns:
        tuple: (checkpoint_path (str), step (int)) for the newest match,
            else (None, 0) when nothing matches.
    """
    best_step = 0
    best_name = None
    for entry in os.listdir(directory):
        hit = pattern.match(entry)
        if hit is None:
            continue
        step = int(hit.group(1))
        if step > best_step:
            best_step, best_name = step, entry
    if best_name is None:
        return None, 0
    return os.path.join(directory, best_name), best_step
120
 
121
  # ============================
122
- # Load Checkpoint if Exists
123
  # ============================
124
- latest_ckpt_path, latest_step = find_latest_checkpoint(checkpoint_dir, checkpoint_pattern)
125
- if latest_ckpt_path:
126
- model.load_state_dict(torch.load(latest_ckpt_path, map_location=device))
127
- print(f"Loaded model weights from {latest_ckpt_path} at step {latest_step}k")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
128
  else:
 
129
  print("No checkpoint found. Starting training from scratch.")
130
 
131
  # ============================
@@ -134,36 +141,28 @@ else:
134
  wandb.login(key="860f8753998c6e6dc356914de07e8855aa2f9642")
135
  wandb.init(
136
  project="minGRU-Training",
137
- name=run_name,
138
  config={
139
  "dataset_path": dataset_path,
140
-
141
  "num_epochs": num_epochs,
142
  "batch_size": batch_size,
143
  "seq_length": seq_length,
144
  "learning_rate": learning_rate,
145
-
146
  "input_len": input_len,
147
  "num_predict": num_predict,
148
-
149
  "infer_every": infer_every,
150
  "reset_state_every": reset_state_every,
151
  "validate_every": validate_every,
152
  "save_every": save_every, # Logging the new variable
153
-
154
  "dataset_rows": tokenized_datasets['train'].num_rows,
155
  "dataset_token_count": batch_size * seq_length,
156
-
157
  "train_set_size": len(train_dataset),
158
  "valid_set_size": len(valid_dataset),
159
-
160
  "model_parameters": parameters_count,
161
-
162
  "vocab_size": vocab_size,
163
  "d_model": 384,
164
  "d_inner": 768,
165
  "n_layers": 6,
166
-
167
  "device": str(device),
168
  "model_name": model_name # Log model_name
169
  }
@@ -183,7 +182,22 @@ for epoch in range(num_epochs):
183
  input_ids = torch.tensor(batch['input_ids']).to(device)
184
 
185
  # Reset hidden states if needed
186
- h_states = h_states if (step % reset_state_every != 0 and h_states is not None) else None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
187
  if h_states is not None:
188
  try:
189
  avg_states = sum([torch.mean(h).item() for h in h_states]) / len(h_states)
@@ -195,12 +209,6 @@ for epoch in range(num_epochs):
195
  avg_states = None
196
  var_states = None
197
 
198
- optimizer.zero_grad()
199
- _, h_states, loss = model.forward(input_ids, h_states)
200
- loss.backward()
201
- optimizer.step()
202
-
203
- step += 1
204
  # Log step information
205
  wandb.log({
206
  "loss": loss.item(),
@@ -251,8 +259,15 @@ for epoch in range(num_epochs):
251
  if step % save_every == 0:
252
  step_str = f"{step}k" # Format step with 'k', e.g., '750k'
253
  checkpoint_filename = f"{model_name}-{step_str}.bin"
254
- checkpoint_path = os.path.join(checkpoint_dir, checkpoint_filename)
255
- torch.save(model.state_dict(), checkpoint_path)
256
- print(f"Saved model checkpoint at step {step} to {checkpoint_path}")
257
- # Optionally, log the checkpoint to W&B
258
- # wandb.save(checkpoint_path)
 
 
 
 
 
 
 
 
14
  # Configuration Parameters
15
  # ============================
16
  dataset_path = 'flpelerin/tinystories-100k'
17
+ model_name = generate_name() # Example: "mingru-a14c"
 
18
 
19
  num_epochs = 1
20
  batch_size = 4
 
27
  infer_every = 200
28
  reset_state_every = 16
29
  validate_every = 200
30
+ save_every = 500 # Controls checkpointing frequency
31
 
32
  # ============================
33
  # Initialize the Device
 
86
  print(f"Model has {parameters_count:,} parameters")
87
 
88
  # ============================
89
+ # Symbolic Link Configuration
90
  # ============================
91
+ symlink_path = 'pytorch_model.bin'
 
 
92
 
93
def update_symlink(target_path, symlink_path):
    """Create or refresh a symbolic link pointing at *target_path*.

    Any pre-existing file or link at *symlink_path* — including a broken
    symlink — is removed first, then the link is (re)created.  Failures
    are reported as warnings instead of raised, so a checkpoint save
    never aborts training just because the symlink could not be updated.

    Args:
        target_path (str): The file path the symlink should point to.
        symlink_path (str): The symlink's path.
    """
    try:
        # lexists() is true for regular files AND broken symlinks; it is
        # the idiomatic one-call form of `islink(p) or exists(p)`.
        if os.path.lexists(symlink_path):
            os.remove(symlink_path)
        os.symlink(target_path, symlink_path)
        print(f"Updated symlink: {symlink_path} -> {target_path}")
    except OSError as e:
        print(f"Warning: Failed to create symlink {symlink_path} -> {target_path}. Error: {e}")
 
 
 
 
 
 
108
 
109
# ============================
# Load Checkpoint from pytorch_model.bin if Exists
# ============================
latest_step = 0  # default: fresh run unless a checkpoint resolves below
if os.path.exists(symlink_path):
    try:
        state = torch.load(symlink_path, map_location=device)
        model.load_state_dict(state)
        print(f"Loaded model weights from {symlink_path}")

        # If the path is a real symlink, its target filename may encode
        # the training step as "<model_name>-<N>k.bin".
        target = os.readlink(symlink_path) if os.path.islink(symlink_path) else None
        if not target:
            print("Symlink does not point to a valid checkpoint. Starting from step 0.")
        else:
            m = re.search(rf"{re.escape(model_name)}-(\d+)k\.bin$", os.path.basename(target))
            if m:
                latest_step = int(m.group(1))
                print(f"Resuming from step {latest_step}k")
            else:
                print("Could not extract step number from checkpoint filename. Starting from step 0.")
    except Exception as e:
        # Broad catch is deliberate: a corrupt checkpoint should fall
        # back to training from scratch, not crash the run.
        print(f"Error loading model from {symlink_path}: {e}")
        latest_step = 0
        print("Starting training from scratch.")
else:
    print("No checkpoint found. Starting training from scratch.")
137
 
138
  # ============================
 
141
  wandb.login(key="860f8753998c6e6dc356914de07e8855aa2f9642")
142
  wandb.init(
143
  project="minGRU-Training",
144
+ name=model_name,
145
  config={
146
  "dataset_path": dataset_path,
 
147
  "num_epochs": num_epochs,
148
  "batch_size": batch_size,
149
  "seq_length": seq_length,
150
  "learning_rate": learning_rate,
 
151
  "input_len": input_len,
152
  "num_predict": num_predict,
 
153
  "infer_every": infer_every,
154
  "reset_state_every": reset_state_every,
155
  "validate_every": validate_every,
156
  "save_every": save_every, # Logging the new variable
 
157
  "dataset_rows": tokenized_datasets['train'].num_rows,
158
  "dataset_token_count": batch_size * seq_length,
 
159
  "train_set_size": len(train_dataset),
160
  "valid_set_size": len(valid_dataset),
 
161
  "model_parameters": parameters_count,
 
162
  "vocab_size": vocab_size,
163
  "d_model": 384,
164
  "d_inner": 768,
165
  "n_layers": 6,
 
166
  "device": str(device),
167
  "model_name": model_name # Log model_name
168
  }
 
182
  input_ids = torch.tensor(batch['input_ids']).to(device)
183
 
184
  # Reset hidden states if needed
185
+ if step % reset_state_every == 0:
186
+ h_states = None
187
+ # Otherwise, keep existing hidden states
188
+
189
+ optimizer.zero_grad()
190
+ try:
191
+ _, h_states, loss = model.forward(input_ids, h_states)
192
+ loss.backward()
193
+ optimizer.step()
194
+ except Exception as e:
195
+ print(f"Error during training step {step + 1}: {e}")
196
+ continue # Skip to the next batch
197
+
198
+ step += 1
199
+
200
+ # Compute statistics of hidden states
201
  if h_states is not None:
202
  try:
203
  avg_states = sum([torch.mean(h).item() for h in h_states]) / len(h_states)
 
209
  avg_states = None
210
  var_states = None
211
 
 
 
 
 
 
 
212
  # Log step information
213
  wandb.log({
214
  "loss": loss.item(),
 
259
  if step % save_every == 0:
260
  step_str = f"{step}k" # Format step with 'k', e.g., '750k'
261
  checkpoint_filename = f"{model_name}-{step_str}.bin"
262
+ checkpoint_path = checkpoint_filename
263
+ try:
264
+ torch.save(model.state_dict(), checkpoint_path)
265
+ print(f"Saved model checkpoint at step {step} to {checkpoint_path}")
266
+
267
+ # Update the symbolic link to point to this checkpoint
268
+ update_symlink(checkpoint_path, symlink_path)
269
+
270
+ # Optionally, log the checkpoint to W&B
271
+ # wandb.save(checkpoint_path)
272
+ except Exception as e:
273
+ print(f"Error saving checkpoint at step {step}: {e}")