algorythmtechnologies committed
Commit 288c71b · verified · Parent: 3a0f81c

Update supernova/train.py

Files changed (1): supernova/train.py (+7, -24)
supernova/train.py CHANGED
@@ -1,4 +1,3 @@
-# train.py (improved)
 import argparse
 import json
 import math
@@ -138,14 +137,13 @@ def train(
 
     # dataset and dataloader
     sources = load_sources_from_yaml(data_config_path)
-    # TODO: improve TokenChunkDataset to perform token-packing (pack multiple short examples into one sequence)
     ds = TokenChunkDataset(tok, sources, seq_len=seq_len, eos_token_id=tok.eos_token_id)
-
     sampler = DistributedSampler(ds) if ddp else None
+
+    # NOTE: NO shuffle for IterableDataset!
     dl = DataLoader(
         ds,
         batch_size=batch_size,
-        shuffle=(sampler is None),
         sampler=sampler,
         num_workers=num_workers,
         pin_memory=pin_memory,
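
The removed shuffle=(sampler is None) argument is the substantive change here: assuming TokenChunkDataset is an iterable-style dataset, as the added NOTE suggests, PyTorch's DataLoader rejects any shuffle flag for an IterableDataset at construction time. A minimal, self-contained sketch of that behavior with a toy stand-in (ToyStream is hypothetical, not part of this repo):

import torch
from torch.utils.data import DataLoader, IterableDataset

class ToyStream(IterableDataset):
    # hypothetical stand-in for TokenChunkDataset (iterable-style, no __len__)
    def __iter__(self):
        for i in range(8):
            yield torch.tensor([i])

# a truthy shuffle raises ValueError for iterable-style datasets
try:
    DataLoader(ToyStream(), batch_size=2, shuffle=True)
except ValueError as err:
    print("rejected:", err)

# leaving shuffle unset works; ordering is whatever the dataset yields
dl = DataLoader(ToyStream(), batch_size=2)
print([batch.tolist() for batch in dl])
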
@@ -153,7 +151,7 @@ def train(
         drop_last=True,
     )
 
-    # optimizer with simple parameter grouping to avoid weight decay on norms/bias
+    # optimizer
     def param_groups(model):
         decay, no_decay = [], []
         for n, p in model.named_parameters():
@@ -169,22 +167,14 @@ def train(
         ]
 
     optimizer = torch.optim.AdamW(param_groups(model), lr=lr, betas=(0.9, 0.95), eps=1e-8)
-
-    # scheduler
     scheduler = get_cosine_schedule_with_warmup(optimizer, num_warmup_steps=warmup_steps, num_training_steps=max_steps)
-
-    # AMP scaler
     scaler = torch.cuda.amp.GradScaler(enabled=(device.type == "cuda"))
-
-    # EMA
     ema = EMA(model if not ddp else model.module, decay=ema_decay) if use_ema else None
 
-    # logging + checkpoint dir
     os.makedirs(out_dir, exist_ok=True)
     writer = SummaryWriter(log_dir=os.path.join(out_dir, "runs")) if use_tensorboard and (not ddp or local_rank == 0) else None
 
-    # validation dataset (simple split: user should provide a separate validation YAML ideally)
-    # TODO: Implement a proper validation dataset pipeline. For now, we use a small random subset of training data.
+    # validation
    val_ds = None
    val_dl = None
 
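The trimmed comment above described param_groups as "simple parameter grouping to avoid weight decay on norms/bias". The full helper is not visible in this diff; a minimal sketch of that common grouping, with a placeholder weight_decay value, might look like:

import torch
from torch import nn

def param_groups(model, weight_decay=0.1):
    # biases and 1-D weights (LayerNorm/RMSNorm scales, etc.) get no decay;
    # matrices and embeddings get the full weight decay
    decay, no_decay = [], []
    for name, p in model.named_parameters():
        if not p.requires_grad:
            continue
        (no_decay if p.ndim < 2 or name.endswith(".bias") else decay).append(p)
    return [
        {"params": decay, "weight_decay": weight_decay},
        {"params": no_decay, "weight_decay": 0.0},
    ]

model = nn.Sequential(nn.Linear(16, 16), nn.LayerNorm(16))
optimizer = torch.optim.AdamW(param_groups(model), lr=3e-4, betas=(0.9, 0.95), eps=1e-8)
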
 
@@ -194,7 +184,6 @@ def train(
     if resume_from and os.path.exists(resume_from):
         ckpt = torch.load(resume_from, map_location=device)
         model_state = ckpt["model_state_dict"]
-        # if ddp, load into module
         target = model.module if ddp else model
         target.load_state_dict(model_state)
         optimizer.load_state_dict(ckpt.get("optimizer_state_dict", {}))
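
The resume block is easy to get wrong under DDP (weights must land in model.module). A sketch of the same logic as a standalone helper; the checkpoint keys mirror the script, while the "step" key and the stricter optimizer guard are assumptions:

import os
import torch

def maybe_resume(model, optimizer, resume_from, device, ddp=False):
    # load the checkpoint onto the target device and route weights to
    # model.module when the model is wrapped in DistributedDataParallel
    if not (resume_from and os.path.exists(resume_from)):
        return 0
    ckpt = torch.load(resume_from, map_location=device)
    target = model.module if ddp else model
    target.load_state_dict(ckpt["model_state_dict"])
    if "optimizer_state_dict" in ckpt:  # slightly stricter than .get(..., {})
        optimizer.load_state_dict(ckpt["optimizer_state_dict"])
    return ckpt.get("step", 0)  # the "step" key is an assumption about the checkpoint layout
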
@@ -234,7 +223,6 @@ def train(
        running_loss += loss.item()
 
        if micro % grad_accum == 0:
-            # gradient clipping
            if clip_grad_norm is not None:
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), clip_grad_norm)
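
Only the clipping is visible in this hunk; scaler.unscale_(optimizer) has to run before clip_grad_norm_ because the gradients are still loss-scaled until then. A self-contained sketch of the surrounding AMP + gradient-accumulation step, with a toy model and data standing in for the real ones and variable names mirroring the script:

import torch
from torch import nn

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = nn.Linear(8, 1).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
scaler = torch.cuda.amp.GradScaler(enabled=(device.type == "cuda"))
grad_accum, clip_grad_norm = 4, 1.0

for micro in range(1, 9):
    x = torch.randn(2, 8, device=device)
    with torch.autocast(device_type=device.type, enabled=(device.type == "cuda")):
        loss = model(x).pow(2).mean() / grad_accum  # divide so accumulated grads average out
    scaler.scale(loss).backward()
    if micro % grad_accum == 0:
        if clip_grad_norm is not None:
            scaler.unscale_(optimizer)  # bring grads back to true scale before clipping
            torch.nn.utils.clip_grad_norm_(model.parameters(), clip_grad_norm)
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad(set_to_none=True)
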
@@ -246,7 +234,6 @@ def train(
 
            if ema:
                ema.update(model if not ddp else model.module)
-
            step += 1
 
            # logging
@@ -266,24 +253,22 @@ def train(
            # periodic validation
            if validate_every and step % validate_every == 0:
                if val_dl is None:
-                    # Create a proper validation dataset with a small subset of training sources
                    val_sources = []
                    for source in sources[:min(3, len(sources))]:
                        val_source = DataSource(
                            name=f"{source.name}_val",
-                            hf_path="wikitext",  # Use a reliable, small dataset for validation
+                            hf_path="wikitext",
                            hf_name="wikitext-2-v1",
                            split="validation",
                            text_field="text",
                            weight=1,
-                            streaming=False  # Don't stream validation data
+                            streaming=False
                        )
                        val_sources.append(val_source)
                    val_ds = TokenChunkDataset(val_sources, seq_len=seq_len, eos_token_id=tok.eos_token_id)
-                    val_dl = DataLoader(val_ds, batch_size=batch_size, shuffle=False, num_workers=0, pin_memory=True, drop_last=False)
+                    val_dl = DataLoader(val_ds, batch_size=batch_size, num_workers=0, pin_memory=True, drop_last=False)
 
                model.eval()
-                # optionally swap in EMA weights for evaluation
                if ema:
                    ema.store(model if not ddp else model.module)
                    ema.copy_to(model if not ddp else model.module)
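
The fallback validation source is now pinned to hf_path="wikitext" with config wikitext-2-v1, a small public dataset whose rows carry a plain text field. A quick sketch of what that source resolves to via the Hugging Face datasets library (assuming it is installed):

from datasets import load_dataset

# wikitext-2-v1 ships train/validation/test splits; each row is {"text": "..."}
val = load_dataset("wikitext", "wikitext-2-v1", split="validation")
print(len(val), repr(val[0]["text"])[:80])
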
@@ -303,12 +288,10 @@ def train(
                    writer.add_scalar("val/loss", mean_val, step)
                    print(f"[eval] step={step} val_loss={mean_val:.6f}")
 
-                # restore weights
                if ema:
                    ema.restore(model if not ddp else model.module)
                model.train()
 
-                # early stop / best model saving
                if mean_val < best_val_loss:
                    best_val_loss = mean_val
                    no_improve_steps = 0
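
These last two hunks lean on the EMA object's store / copy_to / restore sequence: stash the live weights, evaluate with the averaged ones, then put the live weights back. The repo's EMA class is not shown in this diff; a minimal sketch of an implementation with that interface:

import torch

class EMA:
    # minimal sketch: shadow weights updated as an exponential moving average,
    # plus store/copy_to/restore for temporarily evaluating with the averages
    def __init__(self, model, decay=0.999):
        self.decay = decay
        self.shadow = {n: p.detach().clone() for n, p in model.named_parameters()}
        self.backup = {}

    @torch.no_grad()
    def update(self, model):
        for n, p in model.named_parameters():
            self.shadow[n].mul_(self.decay).add_(p.detach(), alpha=1.0 - self.decay)

    def store(self, model):
        self.backup = {n: p.detach().clone() for n, p in model.named_parameters()}

    @torch.no_grad()
    def copy_to(self, model):
        for n, p in model.named_parameters():
            p.copy_(self.shadow[n])

    @torch.no_grad()
    def restore(self, model):
        for n, p in model.named_parameters():
            p.copy_(self.backup[n])

The call order matters: store before copy_to, and restore only after evaluation, otherwise the live training weights are overwritten by the averages.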
 
 