algorythmtechnologies commited on
Commit
ca8d994
·
verified ·
1 Parent(s): 333e53d

Update supernova/train.py

Browse files
Files changed (1) hide show
  1. supernova/train.py +23 -28
supernova/train.py CHANGED
@@ -31,7 +31,7 @@ def compute_grad_norm(model: nn.Module, debug: bool = False) -> float:
31
  grad_count += 1
32
  param_norm = p.grad.data.float().norm(2).item()
33
  total += param_norm * param_norm
34
- if debug and param_norm > 1e-8: # Only print non-zero gradients
35
  print(f" {name}: grad_norm={param_norm:.6f}")
36
  elif debug:
37
  print(f" {name}: NO GRAD")
@@ -46,13 +46,13 @@ def atomic_save(obj: Dict[str, Any], path: str):
46
  torch.save(obj, tmp)
47
  os.replace(tmp, path)
48
 
49
- def save_safetensors(model_state_dict: Dict[str, torch.Tensor], path: str):
50
  """Save model weights in safetensors format."""
51
  try:
52
  tmp = path + ".tmp"
53
  save_file(model_state_dict, tmp)
54
  os.replace(tmp, path)
55
- print(f"Saved safetensors to {path}")
56
  except Exception as e:
57
  print(f"Warning: Failed to save safetensors: {e}")
58
 
@@ -111,14 +111,13 @@ def train(
111
  num_workers: int = 4,
112
  pin_memory: bool = True,
113
  compile_model: bool = False,
114
- save_safetensors: bool = True,
115
  ):
116
  # reproducibility
117
  torch.manual_seed(seed)
118
  torch.cuda.manual_seed_all(seed)
119
  import random
120
  random.seed(seed)
121
- # performance flags
122
  torch.backends.cudnn.benchmark = True
123
 
124
  # device / distributed
@@ -136,7 +135,6 @@ def train(
136
  assert tok.vocab_size == cfg.vocab_size, "Tokenizer vocab size mismatch."
137
 
138
  model = SupernovaModel(cfg)
139
- # optional: enable gradient checkpointing for memory saving if model supports it
140
  if hasattr(model, "gradient_checkpointing_enable"):
141
  try:
142
  model.gradient_checkpointing_enable()
@@ -145,24 +143,19 @@ def train(
145
 
146
  model.to(device)
147
 
148
- # double-check params
149
  total_params = sum(p.numel() for p in model.parameters())
150
  assert total_params == 25_000_000, f"Model has {total_params} params, expected 25,000,000"
151
 
152
- # optional compile (PyTorch 2.0)
153
  if compile_model:
154
  try:
155
  model = torch.compile(model)
156
  except Exception as e:
157
  print("torch.compile not available/failed:", e)
158
 
159
- # DDP wrap
160
  if ddp:
161
  model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[local_rank], find_unused_parameters=False)
162
 
163
- # dataset and dataloader
164
  sources = load_sources_from_yaml(data_config_path)
165
- # TODO: improve TokenChunkDataset to perform token-packing (pack multiple short examples into one sequence)
166
  ds = TokenChunkDataset(
167
  tokenizer=tok,
168
  sources=sources,
@@ -171,7 +164,6 @@ def train(
171
  )
172
  sampler = DistributedSampler(ds) if ddp else None
173
 
174
- # NOTE: NO shuffle for IterableDataset!
175
  dl = DataLoader(
176
  ds,
177
  batch_size=batch_size,
@@ -182,7 +174,6 @@ def train(
182
  drop_last=True,
183
  )
184
 
185
- # optimizer
186
  def param_groups(model):
187
  decay, no_decay = [], []
188
  for n, p in model.named_parameters():
@@ -199,20 +190,16 @@ def train(
199
 
200
  optimizer = torch.optim.AdamW(param_groups(model), lr=lr, betas=(0.9, 0.95), eps=1e-8)
201
  scheduler = get_cosine_schedule_with_warmup(optimizer, num_warmup_steps=warmup_steps, num_training_steps=max_steps)
202
- # AMP scaler
203
  scaler = torch.cuda.amp.GradScaler(enabled=(device.type == "cuda"))
204
 
205
- # EMA
206
  ema = EMA(model if not ddp else model.module, decay=ema_decay) if use_ema else None
207
 
208
  os.makedirs(out_dir, exist_ok=True)
209
  writer = SummaryWriter(log_dir=os.path.join(out_dir, "runs")) if use_tensorboard and (not ddp or local_rank == 0) else None
210
 
211
- # validation
212
  val_ds = None
213
  val_dl = None
214
 
215
- # resume
216
  start_step = 0
217
  best_val_loss = float("inf")
218
  if resume_from and os.path.exists(resume_from):
@@ -236,12 +223,11 @@ def train(
236
  running_loss = 0.0
237
  t0 = time.time()
238
  no_improve_steps = 0
239
- early_stop_patience = 10_000 # you can tune this
240
 
241
- # training loop
242
  while step < max_steps:
243
  if sampler is not None:
244
- sampler.set_epoch(step) # shuffle differently per epoch for DDP
245
 
246
  for batch in dl:
247
  x, y = batch
@@ -262,10 +248,8 @@ def train(
262
  scaler.unscale_(optimizer)
263
  torch.nn.utils.clip_grad_norm_(model.parameters(), clip_grad_norm)
264
 
265
- # Compute gradient norm BEFORE clearing gradients (only when needed for logging)
266
  grad_norm = None
267
  if (step + 1) % 50 == 0 and (not ddp or local_rank == 0):
268
- # Enable debug mode for first few steps to diagnose gradient issues
269
  debug_gradients = step < 5
270
  grad_norm = compute_grad_norm(model if not ddp else model.module, debug=debug_gradients)
271
 
@@ -278,7 +262,6 @@ def train(
278
  ema.update(model if not ddp else model.module)
279
  step += 1
280
 
281
- # logging
282
  if step % 50 == 0 and (not ddp or local_rank == 0) and grad_norm is not None:
283
  avg_loss = running_loss * grad_accum / 50.0
284
  running_loss = 0.0
@@ -291,11 +274,8 @@ def train(
291
  writer.add_scalar("train/lr", lr_now, step)
292
  t0 = time.time()
293
 
294
- # periodic validation
295
  if validate_every and step % validate_every == 0:
296
  if val_dl is None:
297
- # Use a proper validation dataset with wikitext-2 validation split
298
- # This provides more reliable validation than using training data subsets
299
  val_sources = []
300
  for source in sources[:min(3, len(sources))]:
301
  val_source = DataSource(
@@ -344,7 +324,22 @@ def train(
344
  if mean_val < best_val_loss:
345
  best_val_loss = mean_val
346
  no_improve_steps = 0
347
- best_path = os.path.join(out_dir, f"supernova_best_step{step}.pt")
348
  model_state = model.module.state_dict() if ddp else model.state_dict()
349
  ckpt = {
350
- "model_state_dict": model_state
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
  grad_count += 1
32
  param_norm = p.grad.data.float().norm(2).item()
33
  total += param_norm * param_norm
34
+ if debug and param_norm > 1e-8:
35
  print(f" {name}: grad_norm={param_norm:.6f}")
36
  elif debug:
37
  print(f" {name}: NO GRAD")
 
46
  torch.save(obj, tmp)
47
  os.replace(tmp, path)
48
 
49
+ def save_safetensors_checkpoint(model_state_dict: Dict[str, torch.Tensor], path: str):
50
  """Save model weights in safetensors format."""
51
  try:
52
  tmp = path + ".tmp"
53
  save_file(model_state_dict, tmp)
54
  os.replace(tmp, path)
55
+ print(f"Saved safetensors to {path}")
56
  except Exception as e:
57
  print(f"Warning: Failed to save safetensors: {e}")
58
 
 
111
  num_workers: int = 4,
112
  pin_memory: bool = True,
113
  compile_model: bool = False,
114
+ export_safetensors: bool = True,
115
  ):
116
  # reproducibility
117
  torch.manual_seed(seed)
118
  torch.cuda.manual_seed_all(seed)
119
  import random
120
  random.seed(seed)
 
121
  torch.backends.cudnn.benchmark = True
122
 
123
  # device / distributed
 
135
  assert tok.vocab_size == cfg.vocab_size, "Tokenizer vocab size mismatch."
136
 
137
  model = SupernovaModel(cfg)
 
138
  if hasattr(model, "gradient_checkpointing_enable"):
139
  try:
140
  model.gradient_checkpointing_enable()
 
143
 
144
  model.to(device)
145
 
 
146
  total_params = sum(p.numel() for p in model.parameters())
147
  assert total_params == 25_000_000, f"Model has {total_params} params, expected 25,000,000"
148
 
 
149
  if compile_model:
150
  try:
151
  model = torch.compile(model)
152
  except Exception as e:
153
  print("torch.compile not available/failed:", e)
154
 
 
155
  if ddp:
156
  model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[local_rank], find_unused_parameters=False)
157
 
 
158
  sources = load_sources_from_yaml(data_config_path)
 
159
  ds = TokenChunkDataset(
160
  tokenizer=tok,
161
  sources=sources,
 
164
  )
165
  sampler = DistributedSampler(ds) if ddp else None
166
 
 
167
  dl = DataLoader(
168
  ds,
169
  batch_size=batch_size,
 
174
  drop_last=True,
175
  )
176
 
 
177
  def param_groups(model):
178
  decay, no_decay = [], []
179
  for n, p in model.named_parameters():
 
190
 
191
  optimizer = torch.optim.AdamW(param_groups(model), lr=lr, betas=(0.9, 0.95), eps=1e-8)
192
  scheduler = get_cosine_schedule_with_warmup(optimizer, num_warmup_steps=warmup_steps, num_training_steps=max_steps)
 
193
  scaler = torch.cuda.amp.GradScaler(enabled=(device.type == "cuda"))
194
 
 
195
  ema = EMA(model if not ddp else model.module, decay=ema_decay) if use_ema else None
196
 
197
  os.makedirs(out_dir, exist_ok=True)
198
  writer = SummaryWriter(log_dir=os.path.join(out_dir, "runs")) if use_tensorboard and (not ddp or local_rank == 0) else None
199
 
 
200
  val_ds = None
201
  val_dl = None
202
 
 
203
  start_step = 0
204
  best_val_loss = float("inf")
205
  if resume_from and os.path.exists(resume_from):
 
223
  running_loss = 0.0
224
  t0 = time.time()
225
  no_improve_steps = 0
226
+ early_stop_patience = 10_000
227
 
 
228
  while step < max_steps:
229
  if sampler is not None:
230
+ sampler.set_epoch(step)
231
 
232
  for batch in dl:
233
  x, y = batch
 
248
  scaler.unscale_(optimizer)
249
  torch.nn.utils.clip_grad_norm_(model.parameters(), clip_grad_norm)
250
 
 
251
  grad_norm = None
252
  if (step + 1) % 50 == 0 and (not ddp or local_rank == 0):
 
253
  debug_gradients = step < 5
254
  grad_norm = compute_grad_norm(model if not ddp else model.module, debug=debug_gradients)
255
 
 
262
  ema.update(model if not ddp else model.module)
263
  step += 1
264
 
 
265
  if step % 50 == 0 and (not ddp or local_rank == 0) and grad_norm is not None:
266
  avg_loss = running_loss * grad_accum / 50.0
267
  running_loss = 0.0
 
274
  writer.add_scalar("train/lr", lr_now, step)
275
  t0 = time.time()
276
 
 
277
  if validate_every and step % validate_every == 0:
278
  if val_dl is None:
 
 
279
  val_sources = []
280
  for source in sources[:min(3, len(sources))]:
281
  val_source = DataSource(
 
324
  if mean_val < best_val_loss:
325
  best_val_loss = mean_val
326
  no_improve_steps = 0
327
+ best_path_pt = os.path.join(out_dir, f"supernova_best_step{step}.pt")
328
  model_state = model.module.state_dict() if ddp else model.state_dict()
329
  ckpt = {
330
+ "model_state_dict": model_state,
331
+ "optimizer_state_dict": optimizer.state_dict(),
332
+ "scheduler_state_dict": scheduler.state_dict(),
333
+ "scaler_state_dict": (scaler.state_dict() if scaler else None),
334
+ "step": step,
335
+ "best_val_loss": best_val_loss,
336
+ "config": cfg.__dict__,
337
+ }
338
+ if not ddp or local_rank == 0:
339
+ atomic_save(ckpt, best_path_pt)
340
+ print(f"Saved best checkpoint to {best_path_pt}")
341
+
342
+ # Save safetensors
343
+ if export_safetensors:
344
+ best_path_st = os.path.join(out_dir, f"supernova_best_step{step}.safetensors")
345
+ save_safetensors_checkpoint(