import inspect
import itertools
import json
import math
import os
import random
import shutil
import time

import loguru
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm


# model architecture
class Head(nn.Module):
    """One head of self-attention with learned relative position embeddings."""

    def __init__(self, head_size, n_embd, dropout, block_size):
        super().__init__()
        self.head_size = head_size
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        self.pos_embeds = nn.Embedding(block_size, head_size)
        self.attn_dropout = nn.Dropout(dropout)
        self.register_buffer(
            "pos_embeds_indices", torch.arange(0, block_size, 1, dtype=torch.int64)
        )
        self.register_buffer(
            "causal_mask",
            ~torch.tril(torch.ones((block_size, block_size), dtype=torch.bool)),
        )

    def forward(self, x):
        B, T, C = x.shape
        k = self.key(x)    # (B, T, hs)
        q = self.query(x)  # (B, T, hs)
        v = self.value(x)  # (B, T, hs)
        # relative position scores via the "skew" trick: left-pad the
        # (query, offset) scores, reshape, and drop the first row so they
        # line up as (query, key) scores
        r = self.pos_embeds(self.pos_embeds_indices[-T:])  # (T, hs)
        s_rel = q @ r.t()             # (B, T, T)
        s_rel = F.pad(s_rel, (1, 0))  # (B, T, 1+T)
        s_rel = s_rel.reshape(B, s_rel.shape[-1], s_rel.shape[-2])  # (B, 1+T, T)
        s_rel = s_rel[:, 1:, :]       # (B, T, T)
        attn = ((q @ k.transpose(-2, -1)) + s_rel) * (self.head_size**-0.5)  # (B, T, T)
        attn = attn.masked_fill(self.causal_mask[:T, :T], float("-inf"))     # (B, T, T)
        attn = F.softmax(attn, dim=-1)  # (B, T, T)
        attn = self.attn_dropout(attn)  # (B, T, T)
        out = attn @ v                  # (B, T, hs)
        return out
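
# A minimal, self-contained sketch of the relative-position "skew" used in
# Head.forward above: left-padding the (query, offset) score matrix and
# reshaping realigns each row so scores index (query, key) pairs instead.
# `_demo_skew` is an illustrative helper added here; it is never called by
# the training code.
def _demo_skew(B=2, T=4, hs=8):
    q = torch.randn(B, T, hs)         # queries
    r = torch.randn(T, hs)            # one embedding per relative offset
    s_rel = F.pad(q @ r.t(), (1, 0))  # (B, T, 1+T)
    s_rel = s_rel.reshape(B, T + 1, T)[:, 1:, :]  # (B, T, T), key-aligned
    return s_rel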

class MultiHeadAttention(nn.Module):
    """Multiple heads of self-attention in parallel."""

    def __init__(self, num_heads, head_size, n_embd, dropout, block_size):
        super().__init__()
        self.heads = nn.ModuleList(
            [Head(head_size, n_embd, dropout, block_size) for _ in range(num_heads)]
        )
        self.proj = nn.Linear(n_embd, n_embd)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        out = torch.cat([h(x) for h in self.heads], dim=-1)
        out = self.dropout(self.proj(out))
        return out


class FeedForward(nn.Module):
    """A simple linear layer followed by a non-linearity."""

    def __init__(self, n_embd, ff_width, dropout):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_embd, ff_width * n_embd, bias=False),
            nn.SiLU(),
            nn.Linear(ff_width * n_embd, n_embd, bias=False),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        return self.net(x)


class Block(nn.Module):
    """Transformer block: communication followed by feedforward."""

    def __init__(self, n_embd, n_head, ff_width, dropout, block_size):
        super().__init__()
        head_size = n_embd // n_head
        self.sa = MultiHeadAttention(n_head, head_size, n_embd, dropout, block_size)
        self.ffwd = FeedForward(n_embd, ff_width, dropout)
        self.ln1 = nn.RMSNorm(normalized_shape=n_embd, eps=1e-5)
        self.ln2 = nn.RMSNorm(normalized_shape=n_embd, eps=1e-5)

    def forward(self, x):
        x = x + self.sa(self.ln1(x))
        x = x + self.ffwd(self.ln2(x))
        return x


class GPT(nn.Module):
    """Transformer architecture for flowshop scheduling."""

    def __init__(self, intermediate_schedules, vocab_size, n_embd, n_head, n_layer,
                 dropout, ff_width, block_size):
        super().__init__()
        self.jobs_embedding_table = nn.Embedding(vocab_size, n_embd)
        self.blocks = nn.Sequential(
            *[Block(n_embd=n_embd, n_head=n_head, ff_width=ff_width, dropout=dropout,
                    block_size=block_size) for _ in range(n_layer)]
        )
        self.ln_f = nn.LayerNorm(n_embd)
        self.makespan_head = nn.Linear(n_embd, 1)
        self.intermediate_schedules = intermediate_schedules

    def forward(self, idx, targets):
        # idx: (B, T) integer job ids; targets: (B, T) float makespans
        B, T = idx.shape
        x = self.jobs_embedding_table(idx)  # (B, T, C)
        x = self.blocks(x)  # (B, T, C)
        x = self.ln_f(x)    # (B, T, C)
        if not self.intermediate_schedules:
            # supervise only the makespan of the full schedule
            x = x[:, -1, :]  # (B, C)
            makespans = self.makespan_head(x).squeeze(-1)  # (B,)
            targets = targets[:, -1]  # (B,)
        else:
            # supervise the makespan of every schedule prefix
            x = x.reshape(B * T, -1)  # (B*T, C)
            makespans = self.makespan_head(x).squeeze(-1)  # (B*T,)
            targets = targets.reshape(B * T)  # (B*T,)
        loss = F.smooth_l1_loss(makespans, targets, beta=0.1)
        return makespans, loss

    def generate(self, job_embeds):
        """Predict the makespan of a schedule given its job embeddings.

        Input:
            job_embeds: a (T, C) tensor of job embeddings
        Output:
            makespan: a scalar tensor of the predicted makespan
        """
        job_embeds = job_embeds.unsqueeze(0)  # (1, T, C)
        x = self.blocks(job_embeds)  # (1, T, C)
        x = self.ln_f(x)             # (1, T, C)
        makespans = self.makespan_head(x).squeeze()  # (T,)
        makespan = makespans[-1]
        return makespan
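
# A minimal usage sketch of the model's training-time contract (all sizes below
# are made up for illustration; `_demo_gpt_forward` is never called by the
# script). Each input row is a permutation of job ids; with
# intermediate_schedules=True the model regresses one makespan per schedule
# prefix, so targets are (B, T) and predictions come back flattened as (B*T,).
def _demo_gpt_forward():
    model = GPT(intermediate_schedules=True, vocab_size=20, n_embd=32, n_head=4,
                n_layer=2, dropout=0.0, ff_width=4, block_size=20)
    idx = torch.stack([torch.randperm(20) for _ in range(8)])  # (B=8, T=20)
    targets = torch.randn(8, 20)  # per-prefix makespans (dummy values)
    makespans, loss = model(idx, targets)  # makespans: (160,), loss: scalar
    return makespans, loss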

def train(
    testing: bool,
    seed: int,
    data_dir: str,
    n_embd: int,
    n_head: int,
    n_layer: int,
    intermediate_schedules: bool,
    dropout: float,
    ff_width: int,
    train_batch_size: int,
    val_batch_size: int,
    nb_epochs: int,
    early_stopping_patience: int,
    checkpoint_interval_ratio: float,
    decay_lr: bool,
    lr_partitions_ratios: list[float],
    init_lr: float,
    max_lr: float,
    min_lr: float,
    lr_warmup_iters_ratio: float,
    lr_decay_iters_ratio: float,
    beta1: float,
    beta2: float,
    weight_decay: float,
    grad_clip: float,
    compile: bool,
    compile_mode: str,
    save_only_last_checkpoint: bool,
    output_dir: str,
):
    os.makedirs(output_dir, exist_ok=True)

    # check if the experiment termination flag file exists
    if not testing:
        if os.path.exists(os.path.join(output_dir, ".terminated_phase1")):
            print("Phase 1 already terminated. Exiting...")
            return
        if not os.path.exists(os.path.join(output_dir, "viz_train.ipynb")):
            shutil.copy("viz_train.ipynb", output_dir)
    else:
        # delete all files created by this script (log files, batch_losses.npy, etc.)
        files_to_delete = [
            "train.log",
            "train_parameters.json",
            "batch_losses.npy",
            "last_batch_loss_idx.npy",
            "val_losses.npy",
            "last_val_loss_idx.npy",
            "train_pbar_epoch.log",
            "train_pbar_val.log",
            ".terminated_phase1",
            "viz_train.ipynb",
        ]
        for f in files_to_delete:
            f_path = os.path.join(output_dir, f)
            if os.path.exists(f_path):
                os.remove(f_path)
        checkpoints_dir = os.path.join(output_dir, "checkpoints")
        if os.path.exists(checkpoints_dir):
            shutil.rmtree(checkpoints_dir)
        shutil.copy("viz_train.ipynb", output_dir)

    # check that a GPU is available
    assert torch.cuda.is_available(), (
        "This code requires a GPU to run. Please run it on a machine with a "
        "CUDA-compatible GPU."
    )
    device = "cuda"

    # setup logging
    loguru.logger.add(os.path.join(output_dir, "train.log"))

    # set random seeds
    torch.manual_seed(seed)
    random.seed(seed)
    np.random.seed(seed)

    # setup model architecture parameters
    with open(os.path.join(data_dir, "metadata.json"), "r") as f:
        metadata = json.load(f)
    block_size = metadata["nb_jobs"]  # context window size
    vocab_size = metadata["nb_jobs"]  # vocabulary size
    assert n_embd % n_head == 0  # embedding dim must split evenly across heads

    class FlowshopDataset(torch.utils.data.Dataset):
        """Dataset for the flowshop scheduling problem."""

        def __init__(self, dataset_path, split, load_in_memory=True):
            metadata_path = os.path.join(dataset_path, "metadata.json")
            with open(metadata_path, "r") as f:
                metadata = json.load(f)
            nb_samples = metadata["nb_samples"]
            nb_jobs = metadata["nb_jobs"]
            schedules_path = os.path.join(dataset_path, f"schedules_{split}.npy")
            schedules = np.lib.format.open_memmap(
                schedules_path, dtype=np.int32, mode="r", shape=(nb_samples, nb_jobs)
            )
            loguru.logger.info(
                f"Loaded schedules from {schedules_path} with shape {schedules.shape}"
            )
            makespans_path = os.path.join(dataset_path, f"makespans_{split}.npy")
            makespans = np.lib.format.open_memmap(
                makespans_path, dtype=np.float32, mode="r", shape=(nb_samples,)
            )
            loguru.logger.info(
                f"Loaded makespans from {makespans_path} with shape {makespans.shape}"
            )
            if load_in_memory:
                schedules = np.array(schedules)
                makespans = np.array(makespans)
            self.schedules = torch.from_numpy(schedules)
            self.makespans = torch.from_numpy(makespans)

        def __len__(self):
            return len(self.schedules)

        def __getitem__(self, idx):
            return self.schedules[idx], self.makespans[idx]

    train_dataset = FlowshopDataset(data_dir, split="train", load_in_memory=True)
    train_data_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=train_batch_size,
        shuffle=False,
        drop_last=False,
    )
    # log the shape of the items returned by the train_data_loader
    for schedules, makespans in train_data_loader:
        loguru.logger.info(f"schedules.shape: {schedules.shape}")
        loguru.logger.info(f"makespans.shape: {makespans.shape}")
        break

    # setup the training schedule: total iterations over all pseudo-epochs,
    # checkpointing interval, and learning-rate partitions
    nb_iters = nb_epochs * len(train_data_loader)
    checkpoint_interval = int(checkpoint_interval_ratio * len(train_data_loader))
    # the learning-rate schedule repeats over consecutive partitions of the run;
    # the last partition implicitly covers whatever iterations remain
    lr_partitions_ratios = lr_partitions_ratios + [None]
    lr_partitions_iters = [int(r * nb_iters) for r in lr_partitions_ratios[:-1]]
    lr_partitions_iters = lr_partitions_iters + [nb_iters - sum(lr_partitions_iters)]
    assert sum(lr_partitions_iters) == nb_iters
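
    # Worked example of the partition arithmetic above (values illustrative):
    # with nb_iters = 1000 and lr_partitions_ratios = [0.4], the code yields
    # lr_partitions_iters = [400, 600]: one warmup+cosine cycle of get_lr (below)
    # over the first 400 iterations, and a second cycle over the remaining 600.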
    def human_readable(num):
        """Make large parameter counts human-readable, e.g. 12_345_678 -> '12M'."""
        magnitude = 0
        while abs(num) >= 1000:
            magnitude += 1
            num /= 1000.0
        return "%.0f%s" % (num, ["", "K", "M", "G", "T", "P"][magnitude])

    def get_lr(it):
        """Get the learning rate for the current iteration."""
        # locate the partition that `it` falls into, then make `it` relative
        # to the start of that partition
        i = 0
        tmp_it = it - lr_partitions_iters[i]
        while tmp_it >= 0:
            i += 1
            tmp_it -= lr_partitions_iters[i]
        lr_partition_iters = lr_partitions_iters[i]
        warmup_iters = int(lr_partition_iters * lr_warmup_iters_ratio)
        lr_decay_iters = int(lr_partition_iters * lr_decay_iters_ratio)
        it = it - sum(lr_partitions_iters[:i])
        # 1) linear warmup from init_lr to max_lr for warmup_iters steps
        if it < warmup_iters:
            return (it / warmup_iters) * (max_lr - init_lr) + init_lr
        # 2) if it > lr_decay_iters, return min learning rate
        if it > lr_decay_iters:
            return min_lr
        # 3) in between, use cosine decay down to min learning rate
        decay_ratio = (it - warmup_iters) / (lr_decay_iters - warmup_iters)
        assert 0 <= decay_ratio <= 1
        coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))  # coeff decays from 1 to 0
        return min_lr + coeff * (max_lr - min_lr)
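
    # Worked example of get_lr within a single partition of 1000 iterations,
    # with lr_warmup_iters_ratio=0.1, lr_decay_iters_ratio=0.9, init_lr=1e-4,
    # max_lr=1e-3, min_lr=5e-5 (all values illustrative):
    #   it = 0   -> 1e-4     (start of linear warmup)
    #   it = 100 -> 1e-3     (warmup done, cosine decay starts)
    #   it = 500 -> 5.25e-4  (halfway through decay: coeff = 0.5)
    #   it > 900 -> 5e-5     (floor at min_lr)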

    # log parameters
    loguru.logger.info(f"data_dir: {data_dir}")
    loguru.logger.info(f"block_size: {block_size}")
    loguru.logger.info(f"vocab_size: {vocab_size}")
    loguru.logger.info(f"n_embd: {n_embd}")
    loguru.logger.info(f"n_head: {n_head}")
    loguru.logger.info(f"n_layer: {n_layer}")
    loguru.logger.info(f"ff_width: {ff_width}")
    loguru.logger.info(f"train_batch_size: {train_batch_size}")
    loguru.logger.info(f"val_batch_size: {val_batch_size}")
    loguru.logger.info(f"dropout: {dropout}")
    loguru.logger.info(f"nb_epochs: {nb_epochs}")
    loguru.logger.info(f"early_stopping_patience: {early_stopping_patience}")
    loguru.logger.info(f"nb_iters: {nb_iters}")
    loguru.logger.info(f"checkpoint_interval: {checkpoint_interval}")
    loguru.logger.info(f"decay_lr: {decay_lr}")
    loguru.logger.info(f"lr_partitions_ratios: {lr_partitions_ratios}")
    loguru.logger.info(f"lr_partitions_iters: {lr_partitions_iters}")
    loguru.logger.info(f"init_lr: {init_lr}")
    loguru.logger.info(f"max_lr: {max_lr}")
    loguru.logger.info(f"min_lr: {min_lr}")
    loguru.logger.info(f"lr_warmup_iters_ratio: {lr_warmup_iters_ratio}")
    loguru.logger.info(f"lr_decay_iters_ratio: {lr_decay_iters_ratio}")
    loguru.logger.info(f"beta1: {beta1}")
    loguru.logger.info(f"beta2: {beta2}")
    loguru.logger.info(f"weight_decay: {weight_decay}")
    loguru.logger.info(f"grad_clip: {grad_clip}")
    loguru.logger.info(f"compile: {compile}")
    loguru.logger.info(f"compile_mode: {compile_mode}")
    loguru.logger.info(f"intermediate_schedules: {intermediate_schedules}")
    loguru.logger.info(f"save_only_last_checkpoint: {save_only_last_checkpoint}")

    # save parameters into train_parameters.json
    train_params = {
        "data_dir": data_dir,
        "block_size": block_size,
        "vocab_size": vocab_size,
        "n_embd": n_embd,
        "n_head": n_head,
        "n_layer": n_layer,
        "ff_width": ff_width,
        "train_batch_size": train_batch_size,
        "val_batch_size": val_batch_size,
        "dropout": dropout,
        "nb_epochs": nb_epochs,
        "early_stopping_patience": early_stopping_patience,
        "nb_iters": nb_iters,
        "checkpoint_interval": checkpoint_interval,
        "decay_lr": decay_lr,
        "lr_partitions_ratios": lr_partitions_ratios,
        "lr_partitions_iters": lr_partitions_iters,
        "init_lr": init_lr,
        "max_lr": max_lr,
        "min_lr": min_lr,
        "lr_warmup_iters_ratio": lr_warmup_iters_ratio,
        "lr_decay_iters_ratio": lr_decay_iters_ratio,
        "beta1": beta1,
        "beta2": beta2,
        "weight_decay": weight_decay,
        "grad_clip": grad_clip,
        "compile": compile,
        "compile_mode": compile_mode,
        "intermediate_schedules": intermediate_schedules,
        "save_only_last_checkpoint": save_only_last_checkpoint,
    }
    with open(os.path.join(output_dir, "train_parameters.json"), "w") as f:
        json.dump(train_params, f, indent=4)
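
    # train_parameters.json mirrors the log above and can be reloaded
    # downstream; a sketch:
    #   with open(os.path.join(output_dir, "train_parameters.json")) as f:
    #       train_params = json.load(f)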

    # load the last checkpoint if it exists, otherwise initialize training from scratch
    try:
        last_checkpoint = torch.load(
            os.path.join(output_dir, "checkpoints", "last_checkpoint.pth")
        )
        start_epoch = last_checkpoint["epoch"]
        start_epoch_iter = last_checkpoint["epoch_iter"] + 1
        model_state_dict = last_checkpoint["model_state_dict"]
        optimizer_state_dict = last_checkpoint["optimizer_state_dict"]
        best_val_loss = last_checkpoint["best_val_loss"]
        patience_counter = last_checkpoint["patience_counter"]
        improved_this_epoch = last_checkpoint["improved_this_epoch"]
    except FileNotFoundError:
        os.makedirs(os.path.join(output_dir, "checkpoints"), exist_ok=True)
        start_epoch = 0
        start_epoch_iter = 0
        model_state_dict = None
        optimizer_state_dict = None
        best_val_loss = float("inf")
        patience_counter = 0
        improved_this_epoch = False

    # initialize the model
    model = GPT(
        intermediate_schedules=intermediate_schedules,
        vocab_size=vocab_size,
        n_embd=n_embd,
        n_head=n_head,
        n_layer=n_layer,
        dropout=dropout,
        ff_width=ff_width,
        block_size=block_size,
    ).to(device)
    loguru.logger.info(
        "The model has "
        f"{human_readable(sum(p.numel() for p in model.parameters() if p.requires_grad))}"
        " trainable parameters"
    )
    train_model = model
    if model_state_dict is not None:
        # strip the `_orig_mod.` prefix that torch.compile adds to state-dict
        # keys, so checkpoints saved from compiled runs load into the plain module
        model_state_dict = {
            k.removeprefix("_orig_mod."): v for k, v in model_state_dict.items()
        }
        train_model.load_state_dict(model_state_dict)
    if compile:
        train_model = torch.compile(train_model, mode=compile_mode)

    # initialize the optimizer
    param_dict = {pn: p for pn, p in train_model.named_parameters()}
    # filter out parameters that do not require grad
    param_dict = {pn: p for pn, p in param_dict.items() if p.requires_grad}
    # create optim groups: any parameter that is 2D will be weight decayed,
    # otherwise not; i.e. all weight tensors in matmuls + embeddings decay,
    # all biases and norm gains don't
    decay_params = [p for n, p in param_dict.items() if p.dim() >= 2]
    nodecay_params = [p for n, p in param_dict.items() if p.dim() < 2]
    optim_groups = [
        {"params": decay_params, "weight_decay": weight_decay},
        {"params": nodecay_params, "weight_decay": 0.0},
    ]
    num_decay_params = sum(p.numel() for p in decay_params)
    num_nodecay_params = sum(p.numel() for p in nodecay_params)
    loguru.logger.info(
        f"num decayed parameter tensors: {len(decay_params)}, "
        f"with {num_decay_params:,} parameters"
    )
    loguru.logger.info(
        f"num non-decayed parameter tensors: {len(nodecay_params)}, "
        f"with {num_nodecay_params:,} parameters"
    )
    # create the AdamW optimizer, using the fused version if it is available
    fused_available = "fused" in inspect.signature(torch.optim.AdamW).parameters
    use_fused = fused_available and ("cuda" in device)
    loguru.logger.info(f"using fused AdamW: {use_fused}")
    extra_args = dict(fused=True) if use_fused else dict()
    optimizer = torch.optim.AdamW(
        optim_groups, lr=init_lr, betas=(beta1, beta2), **extra_args
    )
    # load the optimizer state if it exists
    if optimizer_state_dict is not None:
        optimizer.load_state_dict(optimizer_state_dict)
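
    # For this model the dim-based split works out as follows (illustrative):
    #   decayed  (dim >= 2): jobs_embedding_table, Head.pos_embeds, and every
    #                        nn.Linear weight (key/query/value, proj, ffwd,
    #                        makespan_head)
    #   no decay (dim < 2):  RMSNorm/LayerNorm gains, the LayerNorm bias, and
    #                        the proj/makespan_head biases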

    # set the torch matmul precision to tf32
    torch.set_float32_matmul_precision("high")

    # initialize the np memmap arrays used to persist batch and validation losses
    batch_losses_path = os.path.join(output_dir, "batch_losses.npy")
    last_batch_loss_idx_path = os.path.join(output_dir, "last_batch_loss_idx.npy")
    val_losses_path = os.path.join(output_dir, "val_losses.npy")
    last_val_loss_idx_path = os.path.join(output_dir, "last_val_loss_idx.npy")
    try:
        batch_losses = np.lib.format.open_memmap(
            batch_losses_path, mode="r+", dtype=np.float32, shape=(nb_iters,)
        )
        last_batch_loss_idx = np.lib.format.open_memmap(
            last_batch_loss_idx_path, mode="r+", dtype=np.int32, shape=()
        )
        val_losses = np.lib.format.open_memmap(
            val_losses_path, mode="r+", dtype=np.float32,
            shape=(nb_epochs * math.ceil(len(train_data_loader) / checkpoint_interval),),
        )
        last_val_loss_idx = np.lib.format.open_memmap(
            last_val_loss_idx_path, mode="r+", dtype=np.int32, shape=()
        )
    except FileNotFoundError:
        batch_losses = np.lib.format.open_memmap(
            batch_losses_path, mode="w+", dtype=np.float32, shape=(nb_iters,)
        )
        last_batch_loss_idx = np.lib.format.open_memmap(
            last_batch_loss_idx_path, mode="w+", dtype=np.int32, shape=()
        )
        val_losses = np.lib.format.open_memmap(
            val_losses_path, mode="w+", dtype=np.float32,
            shape=(nb_epochs * math.ceil(len(train_data_loader) / checkpoint_interval),),
        )
        last_val_loss_idx = np.lib.format.open_memmap(
            last_val_loss_idx_path, mode="w+", dtype=np.int32, shape=()
        )
        last_batch_loss_idx[...] = 0
        last_batch_loss_idx.flush()
        last_val_loss_idx[...] = 0
        last_val_loss_idx.flush()
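
    # These arrays can be read while training runs; a sketch, assuming
    # matplotlib.pyplot is imported as plt on the reading side:
    #   losses = np.load(batch_losses_path, mmap_mode="r")
    #   last = int(np.load(last_batch_loss_idx_path, mmap_mode="r"))
    #   plt.plot(losses[: last + 1])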

    # create the data loader for validation
    val_data_loader = torch.utils.data.DataLoader(
        FlowshopDataset(data_dir, split="val", load_in_memory=True),
        batch_size=val_batch_size,
        shuffle=False,
    )

    # launch the training loop
    early_stop = False
    for epoch in range(start_epoch, nb_epochs):
        if early_stop:
            break
        if start_epoch_iter == 0:
            # only reset when not resuming mid-epoch after a failure
            improved_this_epoch = False
        # resume-after-failure logic:
        # create the generator, sampler, and data loader for this epoch
        generator = torch.Generator()
        generator.manual_seed(seed + epoch)
        train_sampler = torch.utils.data.RandomSampler(
            train_dataset, generator=generator
        )
        train_data_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=train_batch_size,
            sampler=train_sampler,
            # drop the last (smaller) batch when compiling to keep shapes static
            drop_last=False if not compile else True,
        )
        train_data_loader_iterator = iter(train_data_loader)
        # skip iterations up to start_epoch_iter when resuming within an epoch
        train_data_loader_iterator = itertools.islice(
            train_data_loader_iterator,
            start_epoch_iter,
            None,
        )

        # iterate over the training data loader
        for epoch_iter, (schedules_batch, makespans_batch) in (pbar := tqdm(
            enumerate(train_data_loader_iterator, start=start_epoch_iter),
            total=len(train_data_loader),
            initial=start_epoch_iter,
            desc=f"Epoch {epoch+1}/{nb_epochs}",
        )):
            with open(os.path.join(output_dir, "train_pbar_epoch.log"), "w") as f:
                f.write(str(pbar))
            # move the batch to the device
            schedules_batch = schedules_batch.to(device)
            makespans_batch = makespans_batch.to(device)
            # clear the gradients
            optimizer.zero_grad(set_to_none=True)
            # forward pass and loss
            makespans, loss = train_model(schedules_batch, makespans_batch)
            # backward pass
            loss.backward()
            # clip the gradients if grad_clip is set to a non-zero value
            if grad_clip != 0.0:
                torch.nn.utils.clip_grad_norm_(train_model.parameters(), grad_clip)
            # get the learning rate for the current iteration and set it in the optimizer
            current_iter = epoch * len(train_data_loader) + epoch_iter
            lr = get_lr(current_iter) if decay_lr else init_lr
            for param_group in optimizer.param_groups:
                param_group["lr"] = lr
            # update the parameters
            optimizer.step()
            # persist the training loss for the current iteration
            batch_losses[current_iter] = loss.item()
            last_batch_loss_idx[...] = current_iter
            batch_losses.flush()
            last_batch_loss_idx.flush()

            # validation and checkpointing
            if (
                (epoch_iter + 1) % checkpoint_interval == 0
                or (epoch_iter + 1) == len(train_data_loader)
            ):
                # initialize the total validation loss
                total_val_loss = 0
                train_model.eval()  # disable dropout during validation
                # iterate over the validation data loader
                for schedules_batch, makespans_batch in (pbar2 := tqdm(
                    val_data_loader,
                    desc=f"Validation {epoch+(epoch_iter+1)/len(train_data_loader):.2f}",
                )):
                    with open(os.path.join(output_dir, "train_pbar_val.log"), "w") as f:
                        f.write(str(pbar2))
                    # move the batch to the device
                    schedules_batch = schedules_batch.to(device)
                    makespans_batch = makespans_batch.to(device)
                    # compute the validation loss without gradient tracking,
                    # and update the total validation loss
                    with torch.no_grad():
                        makespans, loss = train_model(schedules_batch, makespans_batch)
                        total_val_loss += loss.item() * schedules_batch.size(0)
                with open(os.path.join(output_dir, "train_pbar_val.log"), "w") as f:
                    f.write(str(pbar2))
                train_model.train()
                # compute the total validation loss (averaging over the dataset)
                total_val_loss /= len(val_data_loader.dataset)
                val_loss_idx = (
                    epoch * math.ceil(len(train_data_loader) / checkpoint_interval)
                    + epoch_iter // checkpoint_interval
                )
                val_losses[val_loss_idx] = total_val_loss
                last_val_loss_idx[...] = val_loss_idx
                val_losses.flush()
                last_val_loss_idx.flush()

                # early stopping bookkeeping
                if total_val_loss < best_val_loss:
                    best_val_loss = total_val_loss
                    improved_this_epoch = True

                # save the checkpoint
                checkpoint = {
                    "epoch": epoch,
                    "epoch_iter": epoch_iter,
                    "model_state_dict": train_model.state_dict(),
                    "optimizer_state_dict": optimizer.state_dict(),
                    "validation_loss": total_val_loss,
                    "time": time.strftime("%Y_%m_%d_%H_%M_%S"),
                    "best_val_loss": best_val_loss,
                    "patience_counter": patience_counter,
                    "improved_this_epoch": improved_this_epoch,
                }
                torch.save(
                    checkpoint,
                    os.path.join(output_dir, "checkpoints", "last_checkpoint.pth"),
                )
                if not save_only_last_checkpoint:
                    torch.save(
                        checkpoint,
                        os.path.join(
                            output_dir,
                            "checkpoints",
                            f"checkpoint_epoch_{epoch+(epoch_iter+1)/len(train_data_loader):.2f}.pth",
                        ),
                    )
                if best_val_loss == total_val_loss:
                    torch.save(
                        checkpoint,
                        os.path.join(output_dir, "checkpoints", "best_checkpoint.pth"),
                    )

        with open(os.path.join(output_dir, "train_pbar_epoch.log"), "w") as f:
            f.write(str(pbar))
        # reset start_epoch_iter to 0 for the next epoch
        start_epoch_iter = 0
        # check whether early stopping should trigger at the end of the epoch
        if improved_this_epoch:
            patience_counter = 0
        else:
            patience_counter += 1
            if patience_counter >= early_stopping_patience:
                loguru.logger.info(
                    f"Early stopping triggered! Validation loss hasn't improved "
                    f"for {early_stopping_patience} epochs."
                )
                early_stop = True

    # log the best validation loss
    loguru.logger.info(f"Best validation loss: {best_val_loss:.4f}")

    # create the experiment termination flag file
    with open(os.path.join(output_dir, ".terminated_phase1"), "w") as f:
        pass
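
# Note on resuming: train() restarts from checkpoints/last_checkpoint.pth
# automatically (see the try/except above), so a crashed run can simply be
# relaunched with the same command line; the .terminated_phase1 flag file
# prevents a finished experiment from being re-run unless --testing is set.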

if __name__ == "__main__":
    # parse arguments
    from argparse import ArgumentParser, ArgumentTypeError

    def str2bool(s):
        """Parse boolean CLI flags explicitly; plain `type=bool` would treat
        any non-empty string (including "False") as True."""
        if s.lower() in ("true", "1", "yes"):
            return True
        if s.lower() in ("false", "0", "no"):
            return False
        raise ArgumentTypeError(f"invalid boolean value: {s!r}")

    parser = ArgumentParser()
    parser.add_argument("--testing", type=str2bool, required=True)
    parser.add_argument("--seed", type=int, required=True)
    parser.add_argument("--data_dir", type=str, required=True)
    parser.add_argument("--n_embd", type=int, required=True)
    parser.add_argument("--n_head", type=int, required=True)
    parser.add_argument("--n_layer", type=int, required=True)
    parser.add_argument("--intermediate_schedules", type=str2bool, required=True)
    parser.add_argument("--dropout", type=float, required=True)
    parser.add_argument("--ff_width", type=int, required=True)
    parser.add_argument("--train_batch_size", type=int, required=True)
    parser.add_argument("--val_batch_size", type=int, required=True)
    parser.add_argument("--nb_epochs", type=int, required=True)
    parser.add_argument("--early_stopping_patience", type=int, required=True)
    parser.add_argument("--checkpoint_interval_ratio", type=float, required=True)
    parser.add_argument("--decay_lr", type=str2bool, required=True)
    parser.add_argument(
        "--lr_partitions_ratios",
        type=lambda s: [float(item) for item in s.split(",")],
        help=(
            "Comma-separated list of floats summing to less than 1; the final "
            "partition implicitly covers the remainder (e.g., 0.1,0.5 gives "
            "partitions 0.1, 0.5, 0.4)"
        ),
        required=True,
    )
    parser.add_argument("--init_lr", type=float, required=True)
    parser.add_argument("--max_lr", type=float, required=True)
    parser.add_argument("--min_lr", type=float, required=True)
    parser.add_argument("--lr_warmup_iters_ratio", type=float, required=True)
    parser.add_argument("--lr_decay_iters_ratio", type=float, required=True)
    parser.add_argument("--beta1", type=float, required=True)
    parser.add_argument("--beta2", type=float, required=True)
    parser.add_argument("--weight_decay", type=float, required=True)
    parser.add_argument("--grad_clip", type=float, required=True)
    parser.add_argument("--compile", type=str2bool, required=True)
    parser.add_argument("--compile_mode", type=str, required=True)
    parser.add_argument("--save_only_last_checkpoint", type=str2bool, required=True)
    parser.add_argument("--output_dir", type=str, required=True)
    args = parser.parse_args()
    train(
        testing=args.testing,
        seed=args.seed,
        data_dir=args.data_dir,
        n_embd=args.n_embd,
        n_head=args.n_head,
        n_layer=args.n_layer,
        intermediate_schedules=args.intermediate_schedules,
        dropout=args.dropout,
        ff_width=args.ff_width,
        train_batch_size=args.train_batch_size,
        val_batch_size=args.val_batch_size,
        nb_epochs=args.nb_epochs,
        early_stopping_patience=args.early_stopping_patience,
        checkpoint_interval_ratio=args.checkpoint_interval_ratio,
        decay_lr=args.decay_lr,
        lr_partitions_ratios=args.lr_partitions_ratios,
        init_lr=args.init_lr,
        max_lr=args.max_lr,
        min_lr=args.min_lr,
        lr_warmup_iters_ratio=args.lr_warmup_iters_ratio,
        lr_decay_iters_ratio=args.lr_decay_iters_ratio,
        beta1=args.beta1,
        beta2=args.beta2,
        weight_decay=args.weight_decay,
        grad_clip=args.grad_clip,
        compile=args.compile,
        compile_mode=args.compile_mode,
        save_only_last_checkpoint=args.save_only_last_checkpoint,
        output_dir=args.output_dir,
    )
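
# Example invocation (the script filename `train.py`, the paths, and every
# value below are illustrative, not prescriptive):
#   python train.py \
#       --testing false --seed 0 --data_dir data/flowshop \
#       --n_embd 256 --n_head 8 --n_layer 6 --intermediate_schedules true \
#       --dropout 0.1 --ff_width 4 --train_batch_size 512 --val_batch_size 1024 \
#       --nb_epochs 10 --early_stopping_patience 3 --checkpoint_interval_ratio 0.25 \
#       --decay_lr true --lr_partitions_ratios 0.5 \
#       --init_lr 1e-4 --max_lr 1e-3 --min_lr 5e-5 \
#       --lr_warmup_iters_ratio 0.1 --lr_decay_iters_ratio 0.95 \
#       --beta1 0.9 --beta2 0.95 --weight_decay 0.1 --grad_clip 1.0 \
#       --compile true --compile_mode default --save_only_last_checkpoint true \
#       --output_dir outputs/exp1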