Update supernova/train.py
supernova/train.py CHANGED (+24 -30)
@@ -15,11 +15,11 @@ from transformers import get_cosine_schedule_with_warmup
 from .config import ModelConfig
 from .model import SupernovaModel
 from .tokenizer import load_gpt2_tokenizer
-from .data import load_sources_from_yaml, TokenChunkDataset
+from .data import load_sources_from_yaml, TokenChunkDataset, DataSource
 
-#
+# ------------------------------
 # Utilities
-#
+# ------------------------------
 def compute_grad_norm(model: nn.Module) -> float:
     total = 0.0
     for p in model.parameters():
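Note: this hunk cuts off inside compute_grad_norm. For context, a global gradient-norm helper in this style usually finishes as in the sketch below; everything after the visible for loop is an assumption, not part of the diff.

import torch
import torch.nn as nn

def compute_grad_norm(model: nn.Module) -> float:
    # Global L2 norm across all parameter gradients (sketch).
    total = 0.0
    for p in model.parameters():
        if p.grad is not None:
            # accumulate the squared L2 norm of each gradient tensor
            total += float(p.grad.detach().norm(2)) ** 2
    return total ** 0.5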
@@ -61,9 +61,9 @@ class EMA:
                 p.data.copy_(self.backup[name])
         del self.backup
 
-#
+# ------------------------------
 # Training loop
-#
+# ------------------------------
 def train(
     config_path: str,
     data_config_path: str,
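Only the tail of EMA.restore is visible above. A minimal EMA class consistent with that backup/copy_ pattern is sketched below; the decay default and the shadow/apply_shadow/update names are illustrative assumptions, since the rest of the class sits outside this diff.

import torch
import torch.nn as nn

class EMA:
    # Exponential moving average of model weights (sketch).
    def __init__(self, model: nn.Module, decay: float = 0.999):
        self.model = model
        self.decay = decay
        self.shadow = {n: p.detach().clone()
                       for n, p in model.named_parameters() if p.requires_grad}

    @torch.no_grad()
    def update(self):
        # shadow <- decay * shadow + (1 - decay) * param
        for n, p in self.model.named_parameters():
            if n in self.shadow:
                self.shadow[n].mul_(self.decay).add_(p.detach(), alpha=1.0 - self.decay)

    def apply_shadow(self):
        # stash the live weights, then swap in the averaged ones for eval
        self.backup = {n: p.detach().clone()
                       for n, p in self.model.named_parameters() if n in self.shadow}
        for n, p in self.model.named_parameters():
            if n in self.shadow:
                p.data.copy_(self.shadow[n])

    def restore(self):
        # matches the two lines visible in this hunk
        for name, p in self.model.named_parameters():
            if name in self.backup:
                p.data.copy_(self.backup[name])
        del self.backup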
@@ -153,7 +153,7 @@ def train(
         drop_last=True,
     )
 
-    # optimizer with simple parameter grouping
+    # optimizer with simple parameter grouping to avoid weight decay on norms/bias
    def param_groups(model):
         decay, no_decay = [], []
         for n, p in model.named_parameters():
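The reworded comment states the intent: biases and normalization parameters should not receive weight decay. Whether train.py splits by parameter name or by dimensionality is not visible here; this sketch uses the common ndim heuristic, and the weight_decay default is illustrative.

import torch
import torch.nn as nn

def param_groups(model: nn.Module, weight_decay: float = 0.1):
    decay, no_decay = [], []
    for n, p in model.named_parameters():
        if not p.requires_grad:
            continue
        # 1-D tensors are biases and norm scales; keep them undecayed
        (no_decay if p.ndim < 2 else decay).append(p)
    return [
        {"params": decay, "weight_decay": weight_decay},
        {"params": no_decay, "weight_decay": 0.0},
    ]

# usage: opt = torch.optim.AdamW(param_groups(model), lr=3e-4)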
@@ -213,12 +213,12 @@ def train(
     running_loss = 0.0
     t0 = time.time()
     no_improve_steps = 0
-    early_stop_patience = 10_000
+    early_stop_patience = 10_000  # you can tune this
 
     # training loop
     while step < max_steps:
         if sampler is not None:
-            sampler.set_epoch(step)
+            sampler.set_epoch(step)  # shuffle differently per epoch for DDP
 
         for batch in dl:
             x, y = batch
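The new set_epoch comment deserves one clarification: DistributedSampler derives its shuffle permutation from whatever integer it is handed, so passing the changing step value (rather than a true epoch counter) still produces a fresh order on each outer iteration. A self-contained sketch of the pattern, runnable without DDP thanks to the sampler is None fallback:

import torch
import torch.distributed as dist
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler

ds = TensorDataset(torch.arange(1024))

# DistributedSampler only makes sense once the process group exists
sampler = DistributedSampler(ds) if dist.is_available() and dist.is_initialized() else None
dl = DataLoader(ds, batch_size=8, sampler=sampler, shuffle=(sampler is None))

step, max_steps = 0, 2
while step < max_steps:
    if sampler is not None:
        # any changing integer reseeds the shuffle; the diff uses `step`
        sampler.set_epoch(step)
    for (x,) in dl:
        pass
    step += 1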
@@ -266,26 +266,21 @@ def train(
             # periodic validation
             if validate_every and step % validate_every == 0:
                 if val_dl is None:
-                    #
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-                    val_sources.append(val_source)
-                    val_ds = TokenChunkDataset(val_sources, seq_len=seq_len, eos_token_id=tok.eos_token_id)
-                    val_dl = DataLoader(val_ds, batch_size=batch_size, shuffle=False, num_workers=0, pin_memory=True, drop_last=False)
-
-
+                    # Create a proper validation dataset with a small subset of training sources
+                    val_sources = []
+                    for source in sources[:min(3, len(sources))]:
+                        val_source = DataSource(
+                            name=f"{source.name}_val",
+                            hf_path="wikitext",  # Use a reliable, small dataset for validation
+                            hf_name="wikitext-2-v1",
+                            split="validation",
+                            text_field="text",
+                            weight=1,
+                            streaming=False  # Don't stream validation data
+                        )
+                        val_sources.append(val_source)
+                    val_ds = TokenChunkDataset(val_sources, seq_len=seq_len, eos_token_id=tok.eos_token_id)
+                    val_dl = DataLoader(val_ds, batch_size=batch_size, shuffle=False, num_workers=0, pin_memory=True, drop_last=False)
 
                 model.eval()
                 # optionally swap in EMA weights for evaluation
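Note that the added loop builds up to three copies of the same wikitext-2 validation split, since hf_path is hard-coded; only the names differ. Once val_dl exists, the evaluation step that follows (model.eval() plus the optional EMA swap) typically looks like this sketch; the model(x, targets=y) -> (logits, loss) signature and the apply_shadow/restore names are assumptions carried over from the EMA sketch above.

import math
import torch

@torch.no_grad()
def evaluate(model, val_dl, device, ema=None):
    if ema is not None:
        ema.apply_shadow()          # evaluate with the averaged weights
    model.eval()
    total, batches = 0.0, 0
    for x, y in val_dl:
        x, y = x.to(device), y.to(device)
        _, loss = model(x, targets=y)   # assumed signature
        total += loss.item()
        batches += 1
    model.train()
    if ema is not None:
        ema.restore()               # put the live weights back
    avg = total / max(batches, 1)
    return avg, math.exp(avg)       # mean loss and perplexity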
@@ -329,7 +324,7 @@ def train(
                     }
                     if not ddp or local_rank == 0:
                         atomic_save(ckpt, best_path)
-
+                        print(f"Saved best checkpoint to {best_path}")
                 else:
                     no_improve_steps += validate_every
                     if no_improve_steps >= early_stop_patience:
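atomic_save is defined elsewhere in the file. The conventional implementation, assumed here, writes to a temporary file in the same directory and then renames it over the target, so a crash mid-write never leaves a corrupt best checkpoint:

import os
import tempfile
import torch

def atomic_save(obj, path: str) -> None:
    # Write the checkpoint next to its destination, then rename atomically.
    d = os.path.dirname(path) or "."
    fd, tmp = tempfile.mkstemp(dir=d, suffix=".tmp")
    try:
        with os.fdopen(fd, "wb") as f:
            torch.save(obj, f)
            f.flush()
            os.fsync(f.fileno())    # make sure the bytes hit the disk
        os.replace(tmp, path)       # atomic on POSIX within one filesystem
    except BaseException:
        os.unlink(tmp)
        raise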
@@ -376,7 +371,6 @@
     if writer:
         writer.close()
 
-
 if __name__ == "__main__":
     ap = argparse.ArgumentParser()
     ap.add_argument("--config", required=True)
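The entrypoint is only partially visible. Judging from train()'s signature earlier in the diff, the main block presumably continues along these lines; the --data-config flag and the final call are assumptions for illustration.

import argparse

if __name__ == "__main__":
    ap = argparse.ArgumentParser()
    ap.add_argument("--config", required=True)
    # assumed: a second flag feeding train()'s data_config_path parameter
    ap.add_argument("--data-config", required=True)
    args = ap.parse_args()
    train(args.config, args.data_config)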