Kompella Sri Aasrith Souri commited on
Commit
30ecce6
·
1 Parent(s): 10cc86d

fixed training datasets

Browse files
Files changed (3) hide show
  1. Supernova25million +1 -0
  2. supernova/train.py +23 -5
  3. train_main.py +72 -0
Supernova25million ADDED
@@ -0,0 +1 @@
 
 
1
+ Subproject commit 288c71bea4b8740818638d0e2dae0a647da22763
supernova/train.py CHANGED
@@ -139,7 +139,12 @@ def train(
139
  # dataset and dataloader
140
  sources = load_sources_from_yaml(data_config_path)
141
  # TODO: improve TokenChunkDataset to perform token-packing (pack multiple short examples into one sequence)
142
- ds = TokenChunkDataset(tok, sources, seq_len=seq_len, eos_token_id=tok.eos_token_id)
 
 
 
 
 
143
 
144
  sampler = DistributedSampler(ds) if ddp else None
145
  dl = DataLoader(
@@ -174,7 +179,10 @@ def train(
174
  scheduler = get_cosine_schedule_with_warmup(optimizer, num_warmup_steps=warmup_steps, num_training_steps=max_steps)
175
 
176
  # AMP scaler
177
- scaler = torch.cuda.amp.GradScaler(enabled=(device.type == "cuda"))
 
 
 
178
 
179
  # EMA
180
  ema = EMA(model if not ddp else model.module, decay=ema_decay) if use_ema else None
@@ -225,7 +233,8 @@ def train(
225
  x = x.to(device, non_blocking=True)
226
  y = y.to(device, non_blocking=True)
227
 
228
- with torch.cuda.amp.autocast(enabled=(device.type == "cuda")):
 
229
  logits, loss = model(x, y)
230
  loss = loss / grad_accum
231
 
@@ -268,7 +277,15 @@ def train(
268
  if val_dl is None:
269
  # quick in-memory val split: take first N batches (user should replace with real val)
270
  # NOTE: for production, create a dedicated validation dataset.
271
- val_ds = TokenChunkDataset(tok, sources[: max(1, len(sources) // 20)], seq_len=seq_len, eos_token_id=tok.eos_token_id)
 
 
 
 
 
 
 
 
272
  val_dl = DataLoader(val_ds, batch_size=batch_size, shuffle=False, num_workers=0, pin_memory=True, drop_last=False)
273
 
274
  model.eval()
@@ -284,7 +301,8 @@ def train(
284
  break
285
  vx = vx.to(device)
286
  vy = vy.to(device)
287
- with torch.cuda.amp.autocast(enabled=(device.type == "cuda")):
 
288
  _, vloss = model(vx, vy)
289
  val_losses.append(float(vloss.detach().cpu().item()))
290
  mean_val = float(sum(val_losses) / max(1, len(val_losses)))
 
139
  # dataset and dataloader
140
  sources = load_sources_from_yaml(data_config_path)
141
  # TODO: improve TokenChunkDataset to perform token-packing (pack multiple short examples into one sequence)
142
+ ds = TokenChunkDataset(
143
+ tokenizer=tok,
144
+ sources=sources,
145
+ seq_len=seq_len,
146
+ eos_token_id=tok.eos_token_id
147
+ )
148
 
149
  sampler = DistributedSampler(ds) if ddp else None
150
  dl = DataLoader(
 
179
  scheduler = get_cosine_schedule_with_warmup(optimizer, num_warmup_steps=warmup_steps, num_training_steps=max_steps)
180
 
181
  # AMP scaler
182
+ if device.type == "cuda":
183
+ scaler = torch.amp.GradScaler('cuda', enabled=True)
184
+ else:
185
+ scaler = torch.amp.GradScaler('cpu', enabled=False)
186
 
187
  # EMA
188
  ema = EMA(model if not ddp else model.module, decay=ema_decay) if use_ema else None
 
233
  x = x.to(device, non_blocking=True)
234
  y = y.to(device, non_blocking=True)
235
 
236
+ device_type = 'cuda' if device.type == 'cuda' else 'cpu'
237
+ with torch.amp.autocast(device_type, enabled=(device.type == "cuda")):
238
  logits, loss = model(x, y)
239
  loss = loss / grad_accum
240
 
 
277
  if val_dl is None:
278
  # quick in-memory val split: take first N batches (user should replace with real val)
279
  # NOTE: for production, create a dedicated validation dataset.
280
+ val_sources = sources[: max(1, len(sources) // 20)]
281
+ if not val_sources:
282
+ val_sources = sources[:1] # fallback to at least one source
283
+ val_ds = TokenChunkDataset(
284
+ tokenizer=tok,
285
+ sources=val_sources,
286
+ seq_len=seq_len,
287
+ eos_token_id=tok.eos_token_id
288
+ )
289
  val_dl = DataLoader(val_ds, batch_size=batch_size, shuffle=False, num_workers=0, pin_memory=True, drop_last=False)
290
 
291
  model.eval()
 
301
  break
302
  vx = vx.to(device)
303
  vy = vy.to(device)
304
+ device_type = 'cuda' if device.type == 'cuda' else 'cpu'
305
+ with torch.amp.autocast(device_type, enabled=(device.type == "cuda")):
306
  _, vloss = model(vx, vy)
307
  val_losses.append(float(vloss.detach().cpu().item()))
308
  mean_val = float(sum(val_losses) / max(1, len(val_losses)))
train_main.py ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Main training script - can be run directly without import issues.
4
+ This script imports and runs the training function from the supernova package.
5
+ """
6
+
7
+ import argparse
8
+ import sys
9
+ import os
10
+
11
+ # Add the current directory to Python path to ensure supernova package can be imported
12
+ sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
13
+
14
+ from supernova.train import train
15
+
16
+ def main():
17
+ parser = argparse.ArgumentParser(description="Train Supernova 25M model")
18
+ parser.add_argument("--config", required=True, help="Path to model config JSON")
19
+ parser.add_argument("--data", required=True, help="Path to data config YAML")
20
+ parser.add_argument("--seq-len", type=int, default=1024, help="Sequence length")
21
+ parser.add_argument("--batch-size", type=int, default=16, help="Batch size")
22
+ parser.add_argument("--grad-accum", type=int, default=8, help="Gradient accumulation steps")
23
+ parser.add_argument("--lr", type=float, default=3e-4, help="Learning rate")
24
+ parser.add_argument("--warmup-steps", type=int, default=2000, help="Warmup steps")
25
+ parser.add_argument("--max-steps", type=int, default=100000, help="Maximum training steps")
26
+ parser.add_argument("--save-every", type=int, default=10000, help="Save checkpoint every N steps")
27
+ parser.add_argument("--out-dir", default="checkpoints", help="Output directory")
28
+ parser.add_argument("--seed", type=int, default=42, help="Random seed")
29
+ parser.add_argument("--validate-every", type=int, default=1000, help="Validate every N steps")
30
+ parser.add_argument("--val-steps", type=int, default=100, help="Validation steps")
31
+ parser.add_argument("--clip-grad-norm", type=float, default=1.0, help="Gradient clipping norm")
32
+ parser.add_argument("--no-ema", action="store_true", help="Disable EMA")
33
+ parser.add_argument("--ema-decay", type=float, default=0.9999, help="EMA decay rate")
34
+ parser.add_argument("--resume-from", help="Resume from checkpoint")
35
+ parser.add_argument("--no-tensorboard", action="store_true", help="Disable tensorboard")
36
+ parser.add_argument("--ddp", action="store_true", help="Use distributed training")
37
+ parser.add_argument("--local-rank", type=int, default=0, help="Local rank for DDP")
38
+ parser.add_argument("--num-workers", type=int, default=4, help="DataLoader workers")
39
+ parser.add_argument("--no-pin-memory", action="store_true", help="Disable pin memory")
40
+ parser.add_argument("--compile-model", action="store_true", help="Use torch.compile")
41
+
42
+ args = parser.parse_args()
43
+
44
+ # Call the training function
45
+ train(
46
+ config_path=args.config,
47
+ data_config_path=args.data,
48
+ seq_len=args.seq_len,
49
+ batch_size=args.batch_size,
50
+ grad_accum=args.grad_accum,
51
+ lr=args.lr,
52
+ warmup_steps=args.warmup_steps,
53
+ max_steps=args.max_steps,
54
+ save_every=args.save_every,
55
+ out_dir=args.out_dir,
56
+ seed=args.seed,
57
+ validate_every=args.validate_every,
58
+ val_steps=args.val_steps,
59
+ clip_grad_norm=args.clip_grad_norm,
60
+ use_ema=not args.no_ema,
61
+ ema_decay=args.ema_decay,
62
+ resume_from=args.resume_from,
63
+ use_tensorboard=not args.no_tensorboard,
64
+ ddp=args.ddp,
65
+ local_rank=args.local_rank,
66
+ num_workers=args.num_workers,
67
+ pin_memory=not args.no_pin_memory,
68
+ compile_model=args.compile_model,
69
+ )
70
+
71
+ if __name__ == "__main__":
72
+ main()