Update supernova/train.py

supernova/train.py  CHANGED  (+7 −24)
@@ -1,4 +1,3 @@
-# train.py (improved)
 import argparse
 import json
 import math
@@ -138,14 +137,13 @@ def train(
 
     # dataset and dataloader
     sources = load_sources_from_yaml(data_config_path)
-    # TODO: improve TokenChunkDataset to perform token-packing (pack multiple short examples into one sequence)
     ds = TokenChunkDataset(tok, sources, seq_len=seq_len, eos_token_id=tok.eos_token_id)
-
     sampler = DistributedSampler(ds) if ddp else None
+
+    # NOTE: NO shuffle for IterableDataset!
     dl = DataLoader(
         ds,
         batch_size=batch_size,
-        shuffle=(sampler is None),
         sampler=sampler,
         num_workers=num_workers,
         pin_memory=pin_memory,
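Dropping shuffle=(sampler is None) lines up with the new "NO shuffle for IterableDataset" comment: DataLoader rejects the shuffle option outright when the dataset is iterable-style. A minimal standalone sketch of that behavior, assuming TokenChunkDataset is an IterableDataset as the comment implies:

    from torch.utils.data import DataLoader, IterableDataset

    class Stream(IterableDataset):
        def __iter__(self):
            return iter(range(8))

    DataLoader(Stream(), batch_size=2)                # fine
    DataLoader(Stream(), batch_size=2, shuffle=True)  # raises ValueError

DataLoader applies the same restriction to the sampler argument, so the retained sampler=DistributedSampler(ds) path would hit the same error for an iterable dataset under DDP; that part is unchanged by this commit.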
@@ -153,7 +151,7 @@ def train(
         drop_last=True,
     )
 
-    # optimizer
+    # optimizer
     def param_groups(model):
         decay, no_decay = [], []
         for n, p in model.named_parameters():
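Only the head of param_groups is visible in this hunk; the conventional split it implements is to exempt biases and other one-dimensional parameters (e.g. norm weights) from weight decay. A hedged sketch of that convention, with the decay value and the exact rule assumed rather than read from the full file:

    def param_groups(model, weight_decay=0.1):
        decay, no_decay = [], []
        for n, p in model.named_parameters():
            if not p.requires_grad:
                continue
            # no decay for biases and 1-D tensors such as norm weights
            (no_decay if p.ndim < 2 or n.endswith("bias") else decay).append(p)
        return [
            {"params": decay, "weight_decay": weight_decay},
            {"params": no_decay, "weight_decay": 0.0},
        ]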
@@ -169,22 +167,14 @@ def train(
         ]
 
     optimizer = torch.optim.AdamW(param_groups(model), lr=lr, betas=(0.9, 0.95), eps=1e-8)
-
-    # scheduler
     scheduler = get_cosine_schedule_with_warmup(optimizer, num_warmup_steps=warmup_steps, num_training_steps=max_steps)
-
-    # AMP scaler
     scaler = torch.cuda.amp.GradScaler(enabled=(device.type == "cuda"))
-
-    # EMA
     ema = EMA(model if not ddp else model.module, decay=ema_decay) if use_ema else None
 
-    # logging + checkpoint dir
     os.makedirs(out_dir, exist_ok=True)
     writer = SummaryWriter(log_dir=os.path.join(out_dir, "runs")) if use_tensorboard and (not ddp or local_rank == 0) else None
 
-    # validation
-    # TODO: Implement a proper validation dataset pipeline. For now, we use a small random subset of training data.
+    # validation
    val_ds = None
    val_dl = None
 
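For reference, get_cosine_schedule_with_warmup (assuming the transformers implementation with its default num_cycles=0.5) scales the base lr by a multiplier equivalent to:

    import math

    def lr_multiplier(step, warmup_steps, max_steps):
        if step < warmup_steps:                  # linear warmup
            return step / max(1, warmup_steps)
        progress = (step - warmup_steps) / max(1, max_steps - warmup_steps)
        return 0.5 * (1.0 + math.cos(math.pi * progress))   # cosine decay to 0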
@@ -194,7 +184,6 @@ def train(
     if resume_from and os.path.exists(resume_from):
         ckpt = torch.load(resume_from, map_location=device)
         model_state = ckpt["model_state_dict"]
-        # if ddp, load into module
         target = model.module if ddp else model
         target.load_state_dict(model_state)
         optimizer.load_state_dict(ckpt.get("optimizer_state_dict", {}))
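One caveat in the resume path: ckpt.get("optimizer_state_dict", {}) falls back to an empty dict, which Optimizer.load_state_dict raises a KeyError on (it expects "state" and "param_groups" keys). A guarded load, sketched here as a hypothetical alternative rather than part of this commit:

    # skip the optimizer restore if the checkpoint predates optimizer saving
    if "optimizer_state_dict" in ckpt:
        optimizer.load_state_dict(ckpt["optimizer_state_dict"])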
@@ -234,7 +223,6 @@ def train(
             running_loss += loss.item()
 
             if micro % grad_accum == 0:
-                # gradient clipping
                 if clip_grad_norm is not None:
                     scaler.unscale_(optimizer)
                     torch.nn.utils.clip_grad_norm_(model.parameters(), clip_grad_norm)
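The unscale-before-clip ordering here is the pattern the torch.cuda.amp documentation prescribes: gradients must be divided back out of the loss scale before clip_grad_norm_ can measure their true norm. A sketch of the full accumulation step for context (the surrounding backward/step/update calls are assumed, not shown in this hunk):

    scaler.scale(loss).backward()
    if clip_grad_norm is not None:
        scaler.unscale_(optimizer)    # grads back in fp32 scale
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip_grad_norm)
    scaler.step(optimizer)            # internally skips the step on inf/nan grads
    scaler.update()
    optimizer.zero_grad(set_to_none=True)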
@@ -246,7 +234,6 @@ def train(
 
                 if ema:
                     ema.update(model if not ddp else model.module)
-
                 step += 1
 
                 # logging
@@ -266,24 +253,22 @@ def train(
                 # periodic validation
                 if validate_every and step % validate_every == 0:
                     if val_dl is None:
-                        # Create a proper validation dataset with a small subset of training sources
                         val_sources = []
                         for source in sources[:min(3, len(sources))]:
                             val_source = DataSource(
                                 name=f"{source.name}_val",
-                                hf_path="wikitext",
+                                hf_path="wikitext",
                                 hf_name="wikitext-2-v1",
                                 split="validation",
                                 text_field="text",
                                 weight=1,
-                                streaming=False
+                                streaming=False
                             )
                             val_sources.append(val_source)
                         val_ds = TokenChunkDataset(val_sources, seq_len=seq_len, eos_token_id=tok.eos_token_id)
-                        val_dl = DataLoader(val_ds, batch_size=batch_size,
+                        val_dl = DataLoader(val_ds, batch_size=batch_size, num_workers=0, pin_memory=True, drop_last=False)
 
                     model.eval()
-                    # optionally swap in EMA weights for evaluation
                     if ema:
                         ema.store(model if not ddp else model.module)
                         ema.copy_to(model if not ddp else model.module)
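Worth flagging: the training dataset is built as TokenChunkDataset(tok, sources, ...), but the validation dataset above omits the tokenizer argument. Assuming the same signature, the intended call is presumably:

    val_ds = TokenChunkDataset(tok, val_sources, seq_len=seq_len, eos_token_id=tok.eos_token_id)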
@@ -303,12 +288,10 @@ def train(
                         writer.add_scalar("val/loss", mean_val, step)
                     print(f"[eval] step={step} val_loss={mean_val:.6f}")
 
-                    # restore weights
                     if ema:
                         ema.restore(model if not ddp else model.module)
                     model.train()
 
-                    # early stop / best model saving
                     if mean_val < best_val_loss:
                         best_val_loss = mean_val
                         no_improve_steps = 0
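The store → copy_to → evaluate → restore sequence in the last two hunks implies an EMA helper with roughly this shape; a hedged sketch, since the real class lives elsewhere in the repo and may differ:

    import torch

    class EMA:
        def __init__(self, model, decay=0.999):
            self.decay = decay
            self.shadow = {n: p.detach().clone() for n, p in model.named_parameters()}
            self.backup = {}

        @torch.no_grad()
        def update(self, model):
            # shadow <- decay * shadow + (1 - decay) * param
            for n, p in model.named_parameters():
                self.shadow[n].mul_(self.decay).add_(p.detach(), alpha=1 - self.decay)

        def store(self, model):
            # stash the live weights before evaluation
            self.backup = {n: p.detach().clone() for n, p in model.named_parameters()}

        @torch.no_grad()
        def copy_to(self, model):
            # swap the EMA weights into the model
            for n, p in model.named_parameters():
                p.copy_(self.shadow[n])

        @torch.no_grad()
        def restore(self, model):
            # put the live training weights back
            for n, p in model.named_parameters():
                p.copy_(self.backup[n])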