algorythmtechnologies committed
Commit 3a0f81c · verified · 1 Parent(s): a55cadf

Update supernova/train.py

Files changed (1):
  1. supernova/train.py +24 -30
supernova/train.py CHANGED
@@ -15,11 +15,11 @@ from transformers import get_cosine_schedule_with_warmup
 from .config import ModelConfig
 from .model import SupernovaModel
 from .tokenizer import load_gpt2_tokenizer
-from .data import load_sources_from_yaml, TokenChunkDataset
+from .data import load_sources_from_yaml, TokenChunkDataset, DataSource
 
-# -----------------------
+# ------------------------------
 # Utilities
-# -----------------------
+# ------------------------------
 def compute_grad_norm(model: nn.Module) -> float:
     total = 0.0
     for p in model.parameters():
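The hunk shows only the first lines of compute_grad_norm. A minimal completion consistent with the visible signature, assuming the usual global-L2-norm convention; this body is a sketch, not the committed code:

import torch.nn as nn

def compute_grad_norm(model: nn.Module) -> float:
    # Sum squared L2 norms of every parameter gradient, then take the
    # square root to get the global gradient norm.
    total = 0.0
    for p in model.parameters():
        if p.grad is not None:
            total += p.grad.detach().norm(2).item() ** 2
    return total ** 0.5
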
@@ -61,9 +61,9 @@ class EMA:
                 p.data.copy_(self.backup[name])
         del self.backup
 
-# -----------------------
+# ------------------------------
 # Training loop
-# -----------------------
+# ------------------------------
 def train(
     config_path: str,
     data_config_path: str,
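Only the tail of the EMA restore path is visible above (p.data.copy_(self.backup[name]), del self.backup). A minimal sketch of the shadow-weight pattern those lines imply; the decay value and the update/apply_to method names are assumptions:

import torch
import torch.nn as nn

class EMA:
    def __init__(self, model: nn.Module, decay: float = 0.999):
        self.decay = decay
        self.shadow = {n: p.detach().clone()
                       for n, p in model.named_parameters() if p.requires_grad}

    @torch.no_grad()
    def update(self, model: nn.Module):
        # Blend current weights into the running shadow copy.
        for n, p in model.named_parameters():
            if n in self.shadow:
                self.shadow[n].mul_(self.decay).add_(p.detach(), alpha=1 - self.decay)

    def apply_to(self, model: nn.Module):
        # Stash live weights, then load the EMA weights for evaluation.
        self.backup = {n: p.detach().clone()
                       for n, p in model.named_parameters() if n in self.shadow}
        for n, p in model.named_parameters():
            if n in self.shadow:
                p.data.copy_(self.shadow[n])

    def restore(self, model: nn.Module):
        # Mirror of the lines visible in the hunk: put live weights back.
        for name, p in model.named_parameters():
            if name in self.backup:
                p.data.copy_(self.backup[name])
        del self.backup
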
@@ -153,7 +153,7 @@ def train(
         drop_last=True,
     )
 
-    # optimizer with simple parameter grouping example to avoid weight decay on norms/bias
+    # optimizer with simple parameter grouping to avoid weight decay on norms/bias
     def param_groups(model):
         decay, no_decay = [], []
         for n, p in model.named_parameters():
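The reworded comment describes grouping parameters so that norms and biases skip weight decay. A sketch of how such a helper typically finishes, assuming an ndim/name heuristic and an AdamW-style optimizer; the decay value is a placeholder:

import torch
import torch.nn as nn

def param_groups(model: nn.Module, weight_decay: float = 0.1):
    decay, no_decay = [], []
    for n, p in model.named_parameters():
        if not p.requires_grad:
            continue
        # 1-D tensors are biases and norm scales; keep them decay-free.
        if p.ndim < 2 or "norm" in n.lower() or n.endswith("bias"):
            no_decay.append(p)
        else:
            decay.append(p)
    return [
        {"params": decay, "weight_decay": weight_decay},
        {"params": no_decay, "weight_decay": 0.0},
    ]

# e.g. opt = torch.optim.AdamW(param_groups(model), lr=3e-4)
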
@@ -213,12 +213,12 @@ def train(
     running_loss = 0.0
     t0 = time.time()
     no_improve_steps = 0
-    early_stop_patience = 10_000  # you can tune this
+    early_stop_patience = 10_000  # you can tune this
 
     # training loop
     while step < max_steps:
         if sampler is not None:
-            sampler.set_epoch(step)  # shuffle differently per epoch for DDP
+            sampler.set_epoch(step)  # shuffle differently per epoch for DDP
 
         for batch in dl:
             x, y = batch
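sampler.set_epoch(...) only reseeds the sampler's shuffle; passing the outer step, as this hunk does, gives each pass a fresh permutation the same way an epoch counter would. A self-contained illustration with DistributedSampler; the toy dataset and world size are invented for the demo:

import torch
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler

ds = TensorDataset(torch.arange(100))
# Explicit num_replicas/rank so this runs without init_process_group.
sampler = DistributedSampler(ds, num_replicas=2, rank=0, shuffle=True)
dl = DataLoader(ds, batch_size=8, sampler=sampler)

for epoch in range(3):
    sampler.set_epoch(epoch)  # reseed: each pass sees a different order
    first = next(iter(dl))
    print(epoch, first[0].tolist())
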
@@ -266,26 +266,21 @@
             # periodic validation
             if validate_every and step % validate_every == 0:
                 if val_dl is None:
-                    # quick in-memory val split: take first N batches (user should replace with real val)
-                    # NOTE: for production, create a dedicated validation dataset.
-                    if val_dl is None:
-                        # Create a proper validation dataset with a small subset of training sources
-                        val_sources = []
-                        for source in sources[:min(3, len(sources))]:  # Use first few sources for validation
-                            val_source = DataSource(
-                                name=f"{source.name}_val",
-                                hf_path="wikitext",  # Use a reliable, small dataset for validation
-                                hf_name="wikitext-2-v1",
-                                split="validation",
-                                text_field="text",
-                                weight=1,
-                                streaming=False
-                            )
-                            val_sources.append(val_source)
-                        val_ds = TokenChunkDataset(val_sources, seq_len=seq_len, eos_token_id=tok.eos_token_id)
-                        val_dl = DataLoader(val_ds, batch_size=batch_size, shuffle=False, num_workers=0, pin_memory=True, drop_last=False)
-
-
+                    # Create a proper validation dataset with a small subset of training sources
+                    val_sources = []
+                    for source in sources[:min(3, len(sources))]:
+                        val_source = DataSource(
+                            name=f"{source.name}_val",
+                            hf_path="wikitext",  # Use a reliable, small dataset for validation
+                            hf_name="wikitext-2-v1",
+                            split="validation",
+                            text_field="text",
+                            weight=1,
+                            streaming=False  # Don't stream validation data
+                        )
+                        val_sources.append(val_source)
+                    val_ds = TokenChunkDataset(val_sources, seq_len=seq_len, eos_token_id=tok.eos_token_id)
+                    val_dl = DataLoader(val_ds, batch_size=batch_size, shuffle=False, num_workers=0, pin_memory=True, drop_last=False)
 
                 model.eval()
                 # optionally swap in EMA weights for evaluation
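Once val_dl exists, the code drops into model.eval() and, per the comment, may swap in EMA weights. A hedged sketch of what such a periodic validation pass usually computes; the forward/loss interface here is an assumption about SupernovaModel, not taken from the diff:

import math
import torch
import torch.nn.functional as F

@torch.no_grad()
def run_validation(model, val_dl, device, max_batches=50):
    model.eval()
    total, n = 0.0, 0
    for i, (x, y) in enumerate(val_dl):
        if i >= max_batches:
            break
        x, y = x.to(device), y.to(device)
        logits = model(x)  # assumed to return [batch, seq, vocab] logits
        loss = F.cross_entropy(logits.view(-1, logits.size(-1)), y.view(-1))
        total += loss.item()
        n += 1
    model.train()
    avg = total / max(n, 1)
    return avg, math.exp(avg)  # mean loss and perplexity
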
@@ -329,7 +324,7 @@
                     }
                     if not ddp or local_rank == 0:
                         atomic_save(ckpt, best_path)
-                        print(f"Saved best checkpoint to {best_path}")
+                        print(f"Saved best checkpoint to {best_path}")
                 else:
                     no_improve_steps += validate_every
                     if no_improve_steps >= early_stop_patience:
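atomic_save(ckpt, best_path) runs on rank 0 only, so other ranks never race on the file. The name suggests the standard write-to-temp-then-rename pattern; this body is an assumption about the helper, not the repo's code:

import os
import tempfile
import torch

def atomic_save(obj, path: str):
    # Write the checkpoint beside its destination, then os.replace() it in:
    # a crash mid-write leaves the old file intact, never a truncated one.
    d = os.path.dirname(os.path.abspath(path))
    fd, tmp = tempfile.mkstemp(dir=d, suffix=".tmp")
    try:
        with os.fdopen(fd, "wb") as f:
            torch.save(obj, f)
        os.replace(tmp, path)
    except BaseException:
        if os.path.exists(tmp):
            os.unlink(tmp)
        raise
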
@@ -376,7 +371,6 @@
     if writer:
         writer.close()
 
-
 if __name__ == "__main__":
     ap = argparse.ArgumentParser()
     ap.add_argument("--config", required=True)
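The entry point requires at least --config, and train() also takes data_config_path, so a matching flag presumably follows in the lines the hunk cuts off. A hypothetical direct invocation, with placeholder paths and the assumption that the remaining train() parameters have defaults:

from supernova.train import train

train(
    config_path="configs/model.yaml",      # placeholder path
    data_config_path="configs/data.yaml",  # placeholder path
)
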
 