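"""Phase-1 training script: a GPT-style transformer that regresses the makespan of
flowshop schedules.

A schedule is a permutation of job indices; the model embeds it, runs it through
causal self-attention blocks with relative position embeddings, and predicts the
makespan of the full schedule (or of every prefix, when intermediate_schedules is
set). See the __main__ block at the bottom for the command-line interface.
"""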
import itertools
import random
import os
import time
import torch
import math
import inspect
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import json
from tqdm import tqdm
import loguru
import shutil
# model architecture
class Head(nn.Module):
"""One head of self-attention."""
def __init__(self, head_size, n_embd, dropout, block_size):
super().__init__()
self.head_size = head_size
self.key = nn.Linear(n_embd, head_size, bias=False)
self.query = nn.Linear(n_embd, head_size, bias=False)
self.value = nn.Linear(n_embd, head_size, bias=False)
self.pos_embeds = nn.Embedding(block_size, head_size)
self.attn_dropout = nn.Dropout(dropout)
        self.register_buffer(
            "pos_embeds_indices",
            torch.arange(block_size, dtype=torch.int64)
        )
        # True strictly above the diagonal: the future positions a query must not attend to
        self.register_buffer(
            "causal_mask",
            ~torch.tril(torch.ones((block_size, block_size), dtype=torch.bool))
        )
# ======
    def forward(self, x):
        B, T, C = x.shape
        k = self.key(x) # (B, T, head_size)
        q = self.query(x) # (B, T, head_size)
        v = self.value(x) # (B, T, head_size)
        # relative position scores via the "skewing" trick (as in Music Transformer):
        # pad, reshape and slice to move query-vs-relative-position scores into the
        # (query, key) layout without materializing a (T, T, head_size) tensor
        r = self.pos_embeds(self.pos_embeds_indices[-T:]) # (T, head_size)
        s_rel = q @ r.t() # (B, T, T)
        s_rel = F.pad(s_rel, (1, 0)) # (B, T, 1+T)
        s_rel = s_rel.reshape(B, s_rel.shape[-1], s_rel.shape[-2]) # (B, 1+T, T)
        s_rel = s_rel[:, 1:, :] # (B, T, T)
        attn = ((q @ k.transpose(-2, -1)) + s_rel) * (self.head_size**-0.5) # (B, T, T)
        attn = attn.masked_fill(self.causal_mask[:T, :T], float("-inf")) # (B, T, T)
        attn = F.softmax(attn, dim=-1) # (B, T, T)
        attn = self.attn_dropout(attn) # (B, T, T)
        out = attn @ v # (B, T, head_size)
        return out
# ======
# ======
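# A small worked example of the skewing trick above (assuming T=2, one batch element,
# and writing a_ij = q_i . r_j, where r_{T-1} encodes relative distance 0, r_{T-2}
# distance -1, and so on):
#   s_rel         = [[a00, a01], [a10, a11]]
#   pad (1, 0)    -> [[0, a00, a01], [0, a10, a11]]
#   reshape       -> [[0, a00], [a01, 0], [a10, a11]]
#   slice [1:, :] -> [[a01, 0], [a10, a11]]
# so entry (i, j) ends up holding a_{i, T-1-(i-j)}, the score of query i for relative
# distance j - i; the leftover upper-triangle entries are removed by the causal mask.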
class MultiHeadAttention(nn.Module):
"""Multiple heads of self-attention in parallel."""
def __init__(self, num_heads, head_size, n_embd, dropout, block_size):
super().__init__()
self.heads = nn.ModuleList([Head(head_size, n_embd, dropout, block_size) for _ in range(num_heads)])
self.proj = nn.Linear(n_embd, n_embd)
self.dropout = nn.Dropout(dropout)
# ======
def forward(self, x):
out = torch.cat([h(x) for h in self.heads], dim=-1)
out = self.dropout(self.proj(out))
return out
# ======
# ======
class FeedForward(nn.Module):
"""A simple linear layer followed by a non-linearity."""
def __init__(self, n_embd, ff_width, dropout):
super().__init__()
self.net = nn.Sequential(
nn.Linear(n_embd, ff_width * n_embd, bias=False),
nn.SiLU(),
nn.Linear(ff_width * n_embd, n_embd, bias=False),
nn.Dropout(dropout),
)
# ======
def forward(self, x):
return self.net(x)
# ======
# ======
class Block(nn.Module):
"""Transformer block: communication followed by feedforward."""
def __init__(self, n_embd, n_head, ff_width, dropout, block_size):
super().__init__()
head_size = n_embd // n_head
self.sa = MultiHeadAttention(n_head, head_size, n_embd, dropout, block_size)
self.ffwd = FeedForward(n_embd, ff_width, dropout)
self.ln1 = nn.RMSNorm(normalized_shape=n_embd, eps=1e-5)
self.ln2 = nn.RMSNorm(normalized_shape=n_embd, eps=1e-5)
# ======
def forward(self, x):
x = x + self.sa(self.ln1(x))
x = x + self.ffwd(self.ln2(x))
return x
# ======
# ======
class GPT(nn.Module):
"""Transformer architecture for flowshop scheduling."""
def __init__(self, intermediate_schedules, vocab_size, n_embd, n_head, n_layer, dropout, ff_width, block_size):
super().__init__()
self.jobs_embedding_table = nn.Embedding(vocab_size, n_embd)
self.blocks = nn.Sequential(*[Block(n_embd=n_embd, n_head=n_head, ff_width=ff_width, dropout=dropout, block_size=block_size) for _ in range(n_layer)])
self.ln_f = nn.LayerNorm(n_embd)
self.makespan_head = nn.Linear(n_embd, 1)
self.intermediate_schedules = intermediate_schedules
# ======
def forward(self, idx, targets):
B, T = idx.shape
        # idx is a (B,T,) integer tensor of job indices; targets is a (B,T,) float tensor of makespans
x = self.jobs_embedding_table(idx) # (B,T,C,)
x = self.blocks(x) # (B,T,C,)
x = self.ln_f(x) # (B,T,C,)
        if not self.intermediate_schedules:
            # supervise only the makespan of the complete schedule (last position)
            x = x[:, -1, :] # (B,C,)
            makespans = self.makespan_head(x).squeeze() # (B,)
            targets = targets[:,-1].squeeze() # (B,)
        else:
            # supervise the makespan of every schedule prefix (one target per position)
            x = x.reshape(B*T, -1) # (B*T,C)
            makespans = self.makespan_head(x).squeeze() # (B*T,)
            targets = targets.reshape(B*T) # (B*T,)
# ======
        # smooth L1 (Huber-style) loss; beta sets the quadratic-to-linear transition point
        loss = F.smooth_l1_loss(makespans, targets, beta=0.1)
return makespans, loss
# ======
def generate(self, job_embeds):
"""
Function to generate Transformer predicted makespan from input schedule of job embeddings.
Input:
- job_embeds: a (T,C,) tensor of job embeddings
Output:
- makespan: a scalar tensor of the predicted makespan
"""
job_embeds = job_embeds.unsqueeze(0) # (1,T,C)
x = self.blocks(job_embeds) # (1,T,C,)
x = self.ln_f(x) # (1,T,C,)
makespans = self.makespan_head(x).squeeze() # (T,)
makespan = makespans[-1]
return makespan
# ======
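# A minimal usage sketch of the model above (hypothetical shapes and values, not part
# of the training pipeline):
#   model = GPT(intermediate_schedules=False, vocab_size=20, n_embd=64, n_head=4,
#               n_layer=2, dropout=0.0, ff_width=4, block_size=20)
#   idx = torch.randint(0, 20, (8, 20))   # batch of 8 schedules over 20 jobs
#   targets = torch.rand(8, 20)           # makespan target for every prefix
#   makespans, loss = model(idx, targets) # makespans: (8,), loss: scalar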
def train(
testing: bool,
seed: int,
data_dir: str,
n_embd: int,
n_head: int,
n_layer: int,
intermediate_schedules: bool,
dropout: float,
ff_width: int,
train_batch_size: int,
val_batch_size: int,
nb_epochs: int,
early_stopping_patience: int,
checkpoint_interval_ratio: float,
decay_lr: bool,
lr_partitions_ratios: list[float],
init_lr: float,
max_lr: float,
min_lr: float,
lr_warmup_iters_ratio: float,
lr_decay_iters_ratio: float,
beta1: float,
beta2: float,
weight_decay: float,
grad_clip: float,
compile: bool,
compile_mode: str,
save_only_last_checkpoint: bool,
output_dir: str,
):
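    """Phase-1 training loop: fits the makespan predictor with periodic validation,
    checkpointing, early stopping, and crash-safe resumption from the last checkpoint."""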
os.makedirs(output_dir, exist_ok=True)
    # a normal (non-testing) run exits early if a previous run already completed (termination flag present)
if not testing:
if os.path.exists(os.path.join(output_dir, ".terminated_phase1")):
print("Phase 1 already terminated. Exiting...")
return
# ======
if not os.path.exists(os.path.join(output_dir, "viz_train.ipynb")):
shutil.copy("viz_train.ipynb", output_dir)
# ======
else:
        # testing mode: delete everything created by a previous run of this script (logs, loss arrays, flag files, ...)
files_to_delete = [
"train.log",
"train_parameters.json",
"batch_losses.npy",
"last_batch_loss_idx.npy",
"val_losses.npy",
"last_val_loss_idx.npy",
"train_pbar_epoch.log",
"train_pbar_val.log",
".terminated_phase1",
"viz_train.ipynb",
]
for f in files_to_delete:
f_path = os.path.join(output_dir, f)
if os.path.exists(f_path): os.remove(f_path)
# ======
checkpoints_dir = os.path.join(output_dir, "checkpoints")
if os.path.exists(checkpoints_dir): shutil.rmtree(checkpoints_dir)
shutil.copy("viz_train.ipynb", output_dir)
# ======
# check if GPU is available
assert torch.cuda.is_available(), "This code requires a GPU to run. Please run it on a machine with a CUDA-compatible GPU."
device = "cuda"
# setup logging
loguru.logger.add(os.path.join(output_dir, "train.log"))
# set random seeds
torch.manual_seed(seed)
random.seed(seed)
np.random.seed(seed)
# setup model architecture parameters
with open(os.path.join(data_dir, "metadata.json"), "r") as f:
metadata = json.load(f)
block_size = metadata["nb_jobs"] # context window size
vocab_size = metadata["nb_jobs"] # vocabulary size
    # n_embd, n_head, n_layer, intermediate_schedules and ff_width are used directly
    # from the function arguments
    assert n_embd % n_head == 0, "n_embd must be divisible by n_head"
# setup training parameters and utils
    # train_batch_size, val_batch_size, nb_epochs (number of pseudo-epochs to train for),
    # early_stopping_patience (epochs without improvement before stopping early) and
    # dropout are used directly from the function arguments
class FlowshopDataset(torch.utils.data.Dataset):
"""Dataset for flowshop scheduling problem."""
def __init__(self, dataset_path, split, load_in_memory=True):
metadata_path = os.path.join(dataset_path, "metadata.json")
with open(metadata_path, "r") as f:
metadata = json.load(f)
nb_samples = metadata["nb_samples"]
nb_jobs = metadata["nb_jobs"]
schedules_path = os.path.join(dataset_path, f"schedules_{split}.npy")
            # note: in read mode, open_memmap takes shape/dtype from the .npy header;
            # the arguments below only document the expected layout
            schedules = np.lib.format.open_memmap(
                schedules_path, dtype=np.int32, mode="r", shape=(nb_samples, nb_jobs)
            )
loguru.logger.info(f"Loaded schedules from {schedules_path} with shape {schedules.shape}")
makespans_path = os.path.join(dataset_path, f"makespans_{split}.npy")
makespans = np.lib.format.open_memmap(
makespans_path, dtype=np.float32, mode="r", shape=(nb_samples,)
)
loguru.logger.info(f"Loaded makespans from {makespans_path} with shape {makespans.shape}")
if load_in_memory:
schedules = np.array(schedules)
makespans = np.array(makespans)
self.schedules = torch.from_numpy(schedules)
self.makespans = torch.from_numpy(makespans)
# ======
def __len__(self):
return len(self.schedules)
# ======
def __getitem__(self, idx):
return self.schedules[idx], self.makespans[idx]
# ======
train_dataset = FlowshopDataset(data_dir, split="train", load_in_memory=True)
train_data_loader = torch.utils.data.DataLoader(
train_dataset,
batch_size=train_batch_size,
shuffle=False,
drop_last=False,
)
## log the shape of the items returned by the train_data_loader
    schedules, makespans = next(iter(train_data_loader))
    loguru.logger.info(f"schedules.shape: {schedules.shape}")
    loguru.logger.info(f"makespans.shape: {makespans.shape}")
nb_iters = nb_epochs * len(train_data_loader)
    checkpoint_interval = max(1, int(checkpoint_interval_ratio * len(train_data_loader))) # at least 1 to avoid a modulo by zero
    # the final lr partition is implicit: it covers whatever iterations remain after the
    # explicit ratios, so lr_partitions_ratios must sum to strictly less than 1
    lr_partitions_ratios = lr_partitions_ratios + [None]
    lr_partitions_iters = [int(r * nb_iters) for r in lr_partitions_ratios[:-1]]
    lr_partitions_iters = lr_partitions_iters + [nb_iters - sum(lr_partitions_iters)]
    assert all(p > 0 for p in lr_partitions_iters), "lr_partitions_ratios must sum to less than 1"
    assert sum(lr_partitions_iters) == nb_iters
    # the remaining schedule/optimizer/compile arguments are used as passed; typical
    # values: init_lr=1e-4, max_lr=1e-3, min_lr=5e-5, lr_warmup_iters_ratio=0.1,
    # lr_decay_iters_ratio=0.95, weight_decay=1e-1, grad_clip=1.0
    def human_readable(num):
        """Format large numbers (e.g. parameter counts) as human-readable strings."""
magnitude = 0
while abs(num) >= 1000:
magnitude += 1
num /= 1000.0
return "%.0f%s" % (num, ["", "K", "M", "G", "T", "P"][magnitude])
# ======
    def get_lr(it):
        """Get the learning rate for the current (global) iteration."""
        # find the lr partition that `it` falls into
        i = 0
        tmp_it = it - lr_partitions_iters[i]
        while tmp_it >= 0:
            i += 1
            tmp_it -= lr_partitions_iters[i]
        lr_partition_iters = lr_partitions_iters[i]
        warmup_iters = int(lr_partition_iters * lr_warmup_iters_ratio)
        lr_decay_iters = int(lr_partition_iters * lr_decay_iters_ratio)
        # re-index `it` relative to the start of its partition
        it = it - sum(lr_partitions_iters[:i])
# 1) linear warmup for warmup_iters steps
if it < warmup_iters:
return (it / warmup_iters) * (max_lr - init_lr) + init_lr
# 2) if it > lr_decay_iters, return min learning rate
if it > lr_decay_iters:
return min_lr
# 3) in between, use cosine decay down to min learning rate
decay_ratio = (it - warmup_iters) / (lr_decay_iters - warmup_iters)
assert 0 <= decay_ratio <= 1
coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio)) # coeff ranges 0..1
return min_lr + coeff * (max_lr - min_lr)
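    # Example (assuming nb_iters=1000 and lr_partitions_ratios=[0.4]): partitions are
    # [400, 600] iterations long; within each partition the lr warms up linearly from
    # init_lr to max_lr over the first lr_warmup_iters_ratio of the partition, cosine-decays
    # to min_lr by lr_decay_iters_ratio of the partition, holds min_lr until the partition
    # ends, and the warmup/decay cycle then restarts in the next partition.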
# ======
# log parameters
loguru.logger.info(f"data_dir: {data_dir}")
loguru.logger.info(f"block_size: {block_size}")
loguru.logger.info(f"vocab_size: {vocab_size}")
loguru.logger.info(f"n_embd: {n_embd}")
loguru.logger.info(f"n_head: {n_head}")
loguru.logger.info(f"n_layer: {n_layer}")
loguru.logger.info(f"ff_width: {ff_width}")
loguru.logger.info(f"train_batch_size: {train_batch_size}")
loguru.logger.info(f"val_batch_size: {val_batch_size}")
loguru.logger.info(f"dropout: {dropout}")
loguru.logger.info(f"nb_epochs: {nb_epochs}")
loguru.logger.info(f"early_stopping_patience: {early_stopping_patience}")
loguru.logger.info(f"nb_iters: {nb_iters}")
loguru.logger.info(f"checkpoint_interval: {checkpoint_interval}")
loguru.logger.info(f"decay_lr: {decay_lr}")
loguru.logger.info(f"lr_partitions_ratios: {lr_partitions_ratios}")
loguru.logger.info(f"lr_partitions_iters: {lr_partitions_iters}")
loguru.logger.info(f"init_lr: {init_lr}")
loguru.logger.info(f"max_lr: {max_lr}")
loguru.logger.info(f"min_lr: {min_lr}")
loguru.logger.info(f"lr_warmup_iters_ratio: {lr_warmup_iters_ratio}")
loguru.logger.info(f"lr_decay_iters_ratio: {lr_decay_iters_ratio}")
loguru.logger.info(f"beta1: {beta1}")
loguru.logger.info(f"beta2: {beta2}")
loguru.logger.info(f"weight_decay: {weight_decay}")
loguru.logger.info(f"grad_clip: {grad_clip}")
loguru.logger.info(f"compile: {compile}")
loguru.logger.info(f"compile_mode: {compile_mode}")
loguru.logger.info(f"intermediate_schedules: {intermediate_schedules}")
loguru.logger.info(f"save_only_last_checkpoint: {save_only_last_checkpoint}")
# save parameters into a train_parameters.json
train_params = {
"data_dir": data_dir,
"block_size": block_size,
"vocab_size": vocab_size,
"n_embd": n_embd,
"n_head": n_head,
"n_layer": n_layer,
"ff_width": ff_width,
"train_batch_size": train_batch_size,
"val_batch_size": val_batch_size,
"dropout": dropout,
"nb_epochs": nb_epochs,
"early_stopping_patience": early_stopping_patience,
"nb_iters": nb_iters,
"checkpoint_interval": checkpoint_interval,
"decay_lr": decay_lr,
"lr_partitions_ratios": lr_partitions_ratios,
"lr_partitions_iters": lr_partitions_iters,
"init_lr": init_lr,
"max_lr": max_lr,
"min_lr": min_lr,
"lr_warmup_iters_ratio": lr_warmup_iters_ratio,
"lr_decay_iters_ratio": lr_decay_iters_ratio,
"beta1": beta1,
"beta2": beta2,
"weight_decay": weight_decay,
"grad_clip": grad_clip,
"compile": compile,
"compile_mode": compile_mode,
"intermediate_schedules": intermediate_schedules,
"save_only_last_checkpoint": save_only_last_checkpoint,
}
with open(os.path.join(output_dir, "train_parameters.json"), "w") as f: json.dump(train_params, f, indent=4)
# load the last checkpoint if it exists, otherwise initialize the training from scratch
try:
last_checkpoint = torch.load(os.path.join(output_dir, "checkpoints", "last_checkpoint.pth"))
start_epoch = last_checkpoint["epoch"]
start_epoch_iter = last_checkpoint["epoch_iter"] + 1
model_state_dict = last_checkpoint["model_state_dict"]
optimizer_state_dict = last_checkpoint["optimizer_state_dict"]
best_val_loss = last_checkpoint["best_val_loss"]
patience_counter = last_checkpoint["patience_counter"]
improved_this_epoch = last_checkpoint["improved_this_epoch"]
except FileNotFoundError:
os.makedirs(os.path.join(output_dir, "checkpoints"), exist_ok=True)
start_epoch = 0
start_epoch_iter = 0
model_state_dict = None
optimizer_state_dict = None
best_val_loss = float("inf")
patience_counter = 0
improved_this_epoch = False
# initialize the model
model = GPT(
intermediate_schedules=intermediate_schedules,
vocab_size=vocab_size,
n_embd=n_embd,
n_head=n_head,
n_layer=n_layer,
dropout=dropout,
ff_width=ff_width,
block_size=block_size
).to(device)
loguru.logger.info(f"The model has {human_readable(sum(p.numel() for p in model.parameters() if p.requires_grad))} trainable parameters")
train_model = model
if model_state_dict is not None:
train_model.load_state_dict(model_state_dict)
if compile:
train_model = torch.compile(train_model, mode=compile_mode)
# initialize the optimizer
param_dict = {pn: p for pn, p in train_model.named_parameters()}
## filter out those that do not require grad
param_dict = {pn: p for pn, p in param_dict.items() if p.requires_grad}
    ## create optim groups: any parameter that is at least 2D gets weight decay, the rest don't,
    ## i.e. all weight tensors in matmuls + embeddings decay; biases and norm parameters don't.
decay_params = [p for n, p in param_dict.items() if p.dim() >= 2]
nodecay_params = [p for n, p in param_dict.items() if p.dim() < 2]
optim_groups = [
{"params": decay_params, "weight_decay": weight_decay},
{"params": nodecay_params, "weight_decay": 0.0},
]
num_decay_params = sum(p.numel() for p in decay_params)
num_nodecay_params = sum(p.numel() for p in nodecay_params)
loguru.logger.info(f"num decayed parameter tensors: {len(decay_params)}, with {num_decay_params:,} parameters")
loguru.logger.info(f"num non-decayed parameter tensors: {len(nodecay_params)}, with {num_nodecay_params:,} parameters")
## create AdamW optimizer and use the fused version if it is available
fused_available = "fused" in inspect.signature(torch.optim.AdamW).parameters
use_fused = fused_available and ("cuda" in device)
loguru.logger.info(f"using fused AdamW: {use_fused}")
extra_args = dict(fused=True) if use_fused else dict()
optimizer = torch.optim.AdamW(optim_groups, lr=init_lr, betas=(beta1, beta2), **extra_args)
## load the optimizer state if it exists
if optimizer_state_dict is not None:
optimizer.load_state_dict(optimizer_state_dict)
    # allow TF32 matmuls ("high" precision) for faster training on Ampere+ GPUs
torch.set_float32_matmul_precision("high")
    # initialize np memmap arrays that persist batch/validation losses and the
    # last-written indices to disk, so progress survives crashes
batch_losses_path = os.path.join(output_dir, "batch_losses.npy")
last_batch_loss_idx_path = os.path.join(output_dir, "last_batch_loss_idx.npy")
val_losses_path = os.path.join(output_dir, "val_losses.npy")
last_val_loss_idx_path = os.path.join(output_dir, "last_val_loss_idx.npy")
    # resume the existing loss arrays if present (in "r+" mode shape/dtype come from
    # the .npy header, not the arguments); otherwise create them
    try:
batch_losses = np.lib.format.open_memmap(batch_losses_path, mode="r+", dtype=np.float32, shape=(nb_iters,))
last_batch_loss_idx = np.lib.format.open_memmap(last_batch_loss_idx_path, mode="r+", dtype=np.int32, shape=())
val_losses = np.lib.format.open_memmap(val_losses_path, mode="r+", dtype=np.float32, shape=(nb_epochs * math.ceil(len(train_data_loader)/checkpoint_interval),))
last_val_loss_idx = np.lib.format.open_memmap(last_val_loss_idx_path, mode="r+", dtype=np.int32, shape=())
except FileNotFoundError:
batch_losses = np.lib.format.open_memmap(batch_losses_path, mode="w+", dtype=np.float32, shape=(nb_iters,))
last_batch_loss_idx = np.lib.format.open_memmap(last_batch_loss_idx_path, mode="w+", dtype=np.int32, shape=())
val_losses = np.lib.format.open_memmap(val_losses_path, mode="w+", dtype=np.float32, shape=(nb_epochs * math.ceil(len(train_data_loader)/checkpoint_interval),))
last_val_loss_idx = np.lib.format.open_memmap(last_val_loss_idx_path, mode="w+", dtype=np.int32, shape=())
last_batch_loss_idx[...] = 0
last_batch_loss_idx.flush()
last_val_loss_idx[...] = 0
last_val_loss_idx.flush()
# create data_loader for validation
val_data_loader = torch.utils.data.DataLoader(
FlowshopDataset(data_dir, split="val", load_in_memory=True),
batch_size=val_batch_size,
shuffle=False,
)
# launch the training loop
early_stop = False
for epoch in range(start_epoch, nb_epochs):
if early_stop: break
if start_epoch_iter == 0: improved_this_epoch = False
        # resume-after-failure logic: recreate this epoch's seeded sampler, then skip the batches already processed
## create the generator, sampler, data loader
generator = torch.Generator()
generator.manual_seed(seed + epoch)
train_sampler = torch.utils.data.RandomSampler(
train_dataset,
generator=generator
)
train_data_loader = torch.utils.data.DataLoader(
train_dataset,
batch_size=train_batch_size,
sampler=train_sampler,
            drop_last=compile, # drop a trailing partial batch when compiling, keeping batch shapes static
)
train_data_loader_iterator = iter(train_data_loader)
## skip the iterations until start_iter for the epoch
train_data_loader_iterator = itertools.islice(
train_data_loader_iterator,
start_epoch_iter,
None,
)
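        # e.g. resuming with start_epoch_iter=3 skips batches 0..2; because the sampler's
        # generator is seeded with seed + epoch, the resumed epoch replays the exact batch
        # order of the interrupted run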
# iterate over the training data loader
for epoch_iter, (schedules_batch, makespans_batch) in (pbar:=tqdm(
enumerate(train_data_loader_iterator, start=start_epoch_iter),
total=len(train_data_loader),
initial=start_epoch_iter,
desc=f"Epoch {epoch+1}/{nb_epochs}",
)):
with open(os.path.join(output_dir, "train_pbar_epoch.log"), "w") as f: f.write(str(pbar))
# move the batch to the device
schedules_batch = schedules_batch.to(device)
makespans_batch = makespans_batch.to(device)
# clear the gradients
optimizer.zero_grad(set_to_none=True)
# forward pass and compute the loss
makespans, loss = train_model(schedules_batch, makespans_batch)
# backward pass
loss.backward()
# clip the gradients if grad_clip is set to a non-zero value
if grad_clip != 0.0: torch.nn.utils.clip_grad_norm_(train_model.parameters(), grad_clip)
# get the learning rate for the current iteration and set it in the optimizer
current_iter = epoch * len(train_data_loader) + epoch_iter
lr = get_lr(current_iter) if decay_lr else init_lr
for param_group in optimizer.param_groups: param_group["lr"] = lr
# update the parameters
optimizer.step()
# save the training loss for the current iteration
batch_losses[current_iter] = loss.item()
last_batch_loss_idx[...] = current_iter
batch_losses.flush()
last_batch_loss_idx.flush()
# validation and checkpointing
if (epoch_iter + 1) % checkpoint_interval == 0 or (epoch_iter + 1) == len(train_data_loader):
                # put the model in eval mode so dropout is disabled during validation
                train_model.eval()
                # initialize the total validation loss
                total_val_loss = 0
# iterate over the validation data loader
for schedules_batch, makespans_batch in (pbar2:=tqdm(
val_data_loader,
desc=f"Validation {epoch+(epoch_iter+1)/len(train_data_loader):.2f}",
)):
with open(os.path.join(output_dir, "train_pbar_val.log"), "w") as f: f.write(str(pbar2))
# move the batch to the device
schedules_batch = schedules_batch.to(device)
makespans_batch = makespans_batch.to(device)
# compute the validation loss without gradient tracking, and update the total validation loss
with torch.no_grad():
makespans, loss = train_model(schedules_batch, makespans_batch)
total_val_loss += loss.item() * schedules_batch.size(0)
# ======
with open(os.path.join(output_dir, "train_pbar_val.log"), "w") as f: f.write(str(pbar2))
                # switch back to training mode (re-enable dropout)
                train_model.train()
                # average the accumulated validation loss over the dataset
                total_val_loss /= len(val_data_loader.dataset)
val_loss_idx = epoch * math.ceil(len(train_data_loader)/checkpoint_interval) + epoch_iter // checkpoint_interval
val_losses[val_loss_idx] = total_val_loss
last_val_loss_idx[...] = val_loss_idx
val_losses.flush()
last_val_loss_idx.flush()
# early stopping check
if total_val_loss < best_val_loss:
best_val_loss = total_val_loss
improved_this_epoch = True
# save the checkpoint
checkpoint = {
"epoch": epoch,
"epoch_iter": epoch_iter,
"model_state_dict": train_model.state_dict(),
"optimizer_state_dict": optimizer.state_dict(),
"validation_loss": total_val_loss,
"time": time.strftime("%Y_%m_%d_%H_%M_%S"),
"best_val_loss": best_val_loss,
"patience_counter": patience_counter,
"improved_this_epoch": improved_this_epoch,
}
torch.save(
checkpoint,
os.path.join(output_dir, "checkpoints", "last_checkpoint.pth")
)
if not save_only_last_checkpoint:
torch.save(
checkpoint,
os.path.join(output_dir, "checkpoints", f"checkpoint_epoch_{epoch+(epoch_iter+1)/len(train_data_loader):.2f}.pth")
)
                # best_val_loss was updated above on improvement, so equality means this checkpoint is the new best
                if best_val_loss == total_val_loss:
torch.save(
checkpoint,
os.path.join(output_dir, "checkpoints", "best_checkpoint.pth")
)
# ======
# ======
# ======
with open(os.path.join(output_dir, "train_pbar_epoch.log"), "w") as f: f.write(str(pbar))
# set the start_epoch_iter to 0 for the next epoch
start_epoch_iter = 0
# check if early stopping should be triggered at the end of the epoch
if improved_this_epoch:
patience_counter = 0
else:
patience_counter += 1
if patience_counter >= early_stopping_patience:
loguru.logger.info(f"Early stopping triggered! Validation loss hasn't improved for {early_stopping_patience} epochs.")
early_stop = True
# ======
# log the best validation loss
loguru.logger.info(f"Best validation loss: {best_val_loss:.4f}")
# create experiment termination flag file
with open(os.path.join(output_dir, ".terminated_phase1"), "w") as f:
pass
# ======
# ======
if __name__ == "__main__":
# parse arguments
    from argparse import ArgumentParser
    def str2bool(s):
        """Parse boolean flags: bare `type=bool` treats any non-empty string (even "False") as True."""
        return str(s).lower() in ("true", "1", "yes")
    parser = ArgumentParser()
    parser.add_argument("--testing", type=str2bool, required=True)
parser.add_argument("--seed", type=int, required=True)
parser.add_argument("--data_dir", type=str, required=True)
parser.add_argument("--n_embd", type=int, required=True)
parser.add_argument("--n_head", type=int, required=True)
parser.add_argument("--n_layer", type=int, required=True)
parser.add_argument("--intermediate_schedules", type=bool, required=True)
parser.add_argument("--dropout", type=float, required=True)
parser.add_argument("--ff_width", type=int, required=True)
parser.add_argument("--train_batch_size", type=int, required=True)
parser.add_argument("--val_batch_size", type=int, required=True)
parser.add_argument("--nb_epochs", type=int, required=True)
parser.add_argument("--early_stopping_patience", type=int, required=True)
parser.add_argument("--checkpoint_interval_ratio", type=float, required=True)
parser.add_argument("--decay_lr", type=bool, required=True)
parser.add_argument("--lr_partitions_ratios", type=lambda s: [float(item) for item in s.split(',')], help='Comma-separated list of floats that do not add up to 1 (e.g., 0.1,0.5,1)', required=True)
parser.add_argument("--init_lr", type=float, required=True)
parser.add_argument("--max_lr", type=float, required=True)
parser.add_argument("--min_lr", type=float, required=True)
parser.add_argument("--lr_warmup_iters_ratio", type=float, required=True)
parser.add_argument("--lr_decay_iters_ratio", type=float, required=True)
parser.add_argument("--beta1", type=float, required=True)
parser.add_argument("--beta2", type=float, required=True)
parser.add_argument("--weight_decay", type=float, required=True)
parser.add_argument("--grad_clip", type=float, required=True)
parser.add_argument("--compile", type=bool, required=True)
parser.add_argument("--compile_mode", type=str, required=True)
parser.add_argument("--save_only_last_checkpoint", type=bool, required=True)
parser.add_argument("--output_dir", type=str, required=True)
args = parser.parse_args()
train(
testing=args.testing,
seed=args.seed,
data_dir=args.data_dir,
n_embd=args.n_embd,
n_head=args.n_head,
n_layer=args.n_layer,
intermediate_schedules=args.intermediate_schedules,
dropout=args.dropout,
ff_width=args.ff_width,
train_batch_size=args.train_batch_size,
val_batch_size=args.val_batch_size,
nb_epochs=args.nb_epochs,
early_stopping_patience=args.early_stopping_patience,
checkpoint_interval_ratio=args.checkpoint_interval_ratio,
decay_lr=args.decay_lr,
lr_partitions_ratios=args.lr_partitions_ratios,
init_lr=args.init_lr,
max_lr=args.max_lr,
min_lr=args.min_lr,
lr_warmup_iters_ratio=args.lr_warmup_iters_ratio,
lr_decay_iters_ratio=args.lr_decay_iters_ratio,
beta1=args.beta1,
beta2=args.beta2,
weight_decay=args.weight_decay,
grad_clip=args.grad_clip,
compile=args.compile,
compile_mode=args.compile_mode,
save_only_last_checkpoint=args.save_only_last_checkpoint,
output_dir=args.output_dir,
)
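# A hypothetical invocation (script name, paths and hyperparameter values below are
# placeholders, not recommendations):
#   python train.py --testing false --seed 0 --data_dir <data_dir> \
#       --n_embd 64 --n_head 4 --n_layer 2 --intermediate_schedules true \
#       --dropout 0.1 --ff_width 4 --train_batch_size 256 --val_batch_size 512 \
#       --nb_epochs 10 --early_stopping_patience 3 --checkpoint_interval_ratio 0.25 \
#       --decay_lr true --lr_partitions_ratios 0.5 --init_lr 1e-4 --max_lr 1e-3 \
#       --min_lr 5e-5 --lr_warmup_iters_ratio 0.1 --lr_decay_iters_ratio 0.95 \
#       --beta1 0.9 --beta2 0.95 --weight_decay 0.1 --grad_clip 1.0 \
#       --compile false --compile_mode default --save_only_last_checkpoint true \
#       --output_dir <output_dir>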