| import itertools
|
| import random
|
| import os
|
| import time
|
| import torch
|
| import math
|
| import inspect
|
| import torch.nn as nn
|
| import torch.nn.functional as F
|
| import numpy as np
|
| import json
|
| from tqdm import tqdm
|
| import loguru
|
| import shutil
|
|
|
|
|
|
|
|
|
|
|
class Head(nn.Module):
    """A single self-attention head with learned relative position scores.

    Relative positions are folded into the attention logits with the
    pad/reshape/slice "skew" procedure (Music-Transformer style), so no
    (T, T) relative-embedding matrix is ever materialized per pair.
    """

    def __init__(self, head_size, n_embd, dropout, block_size):
        super().__init__()
        self.head_size = head_size
        # Per-head projections; no bias, as is conventional for attention.
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        # One learned embedding per relative offset (up to block_size).
        self.pos_embeds = nn.Embedding(block_size, head_size)
        self.attn_dropout = nn.Dropout(dropout)
        self.register_buffer(
            "pos_embeds_indices",
            torch.arange(block_size, dtype=torch.long),
        )
        # True above the diagonal -> positions that must NOT attend.
        self.register_buffer(
            "causal_mask",
            ~torch.tril(torch.ones((block_size, block_size), dtype=torch.bool)),
        )

    def _skew(self, rel_scores, batch):
        """Shift (B, T, T) query-vs-relative-offset scores into alignment.

        Pad one column on the left, reinterpret as (B, T+1, T), and drop the
        first row; afterwards rel_scores[b, i, j] holds the score of query i
        against relative offset (j - i).
        """
        padded = F.pad(rel_scores, (1, 0))
        padded = padded.reshape(batch, padded.shape[-1], padded.shape[-2])
        return padded[:, 1:, :]

    def forward(self, x):
        """Attend over x of shape (B, T, n_embd); returns (B, T, head_size)."""
        B, T, C = x.shape
        keys = self.key(x)
        queries = self.query(x)
        values = self.value(x)
        # Embeddings for the T closest relative offsets.
        rel = self.pos_embeds(self.pos_embeds_indices[-T:])
        rel_scores = self._skew(queries @ rel.t(), B)
        # Content scores plus relative scores, jointly scaled by 1/sqrt(d).
        scores = ((queries @ keys.transpose(-2, -1)) + rel_scores) * (self.head_size**-0.5)
        scores = scores.masked_fill(self.causal_mask[:T, :T], float("-inf"))
        weights = scores.softmax(dim=-1)
        weights = self.attn_dropout(weights)
        return weights @ values
|
|
|
|
|
|
|
|
|
class MultiHeadAttention(nn.Module):
    """Several attention heads run side by side, then mixed by a projection."""

    def __init__(self, num_heads, head_size, n_embd, dropout, block_size):
        super().__init__()
        self.heads = nn.ModuleList(
            Head(head_size, n_embd, dropout, block_size) for _ in range(num_heads)
        )
        # Mixes the concatenated head outputs back into n_embd channels.
        self.proj = nn.Linear(n_embd, n_embd)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Apply every head to x and project the concatenated result."""
        per_head = [head(x) for head in self.heads]
        concatenated = torch.cat(per_head, dim=-1)
        return self.dropout(self.proj(concatenated))
|
|
|
|
|
|
|
|
|
class FeedForward(nn.Module):
    """Position-wise MLP: expand by a factor of ff_width, SiLU, project back."""

    def __init__(self, n_embd, ff_width, dropout):
        super().__init__()
        hidden = ff_width * n_embd
        self.net = nn.Sequential(
            nn.Linear(n_embd, hidden, bias=False),
            nn.SiLU(),
            nn.Linear(hidden, n_embd, bias=False),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        """Apply the MLP independently at every position of x."""
        return self.net(x)
|
|
|
|
|
|
|
|
|
class Block(nn.Module):
    """Pre-norm transformer block: attention then MLP, each with a residual."""

    def __init__(self, n_embd, n_head, ff_width, dropout, block_size):
        super().__init__()
        # Split the embedding evenly across heads.
        self.sa = MultiHeadAttention(n_head, n_embd // n_head, n_embd, dropout, block_size)
        self.ffwd = FeedForward(n_embd, ff_width, dropout)
        self.ln1 = nn.RMSNorm(normalized_shape=n_embd, eps=1e-5)
        self.ln2 = nn.RMSNorm(normalized_shape=n_embd, eps=1e-5)

    def forward(self, x):
        """Residual attention sub-layer followed by residual MLP sub-layer."""
        x = x + self.sa(self.ln1(x))
        return x + self.ffwd(self.ln2(x))
|
|
|
|
|
|
|
|
|
class GPT(nn.Module):
    """Transformer that regresses flowshop makespans from job schedules.

    Constructor args:
        intermediate_schedules: if True, supervise the makespan prediction at
            every prefix position of the schedule; if False, only at the last
            position.
        vocab_size: number of distinct jobs (one embedding per job id).
        n_embd, n_head, n_layer, dropout, ff_width, block_size: transformer
            hyperparameters; block_size is the maximum schedule length.
    """

    def __init__(self, intermediate_schedules, vocab_size, n_embd, n_head, n_layer, dropout, ff_width, block_size):
        super().__init__()
        self.jobs_embedding_table = nn.Embedding(vocab_size, n_embd)
        self.blocks = nn.Sequential(*[Block(n_embd=n_embd, n_head=n_head, ff_width=ff_width, dropout=dropout, block_size=block_size) for _ in range(n_layer)])
        self.ln_f = nn.LayerNorm(n_embd)
        # Scalar regression head: one predicted makespan per position.
        self.makespan_head = nn.Linear(n_embd, 1)
        self.intermediate_schedules = intermediate_schedules

    def forward(self, idx, targets):
        """Return (predicted makespans, smooth-L1 loss against targets).

        Input:
            - idx: (B, T) int tensor of job ids.
            - targets: (B, T) float tensor of target makespans per prefix;
              only the last column is used when intermediate_schedules is
              False.
        """
        B, T = idx.shape

        x = self.jobs_embedding_table(idx)
        x = self.blocks(x)
        x = self.ln_f(x)
        # BUGFIX: the original used bare .squeeze(), which also collapses the
        # batch dimension when B == 1 (or B*T == 1) and silently changes the
        # output shape. Squeeze only the trailing feature dimension instead.
        if not self.intermediate_schedules:
            makespans = self.makespan_head(x[:, -1, :]).squeeze(-1)  # (B,)
            targets = targets[:, -1]                                 # (B,)
        else:
            makespans = self.makespan_head(x).squeeze(-1).reshape(B * T)  # (B*T,)
            targets = targets.reshape(B * T)

        loss = F.smooth_l1_loss(makespans, targets, beta=0.1)
        return makespans, loss

    def generate(self, job_embeds):
        """
        Function to generate Transformer predicted makespan from input schedule of job embeddings.
        Input:
            - job_embeds: a (T,C,) tensor of job embeddings
        Output:
            - makespan: a scalar tensor of the predicted makespan
        """

        x = self.blocks(job_embeds.unsqueeze(0))  # (1, T, C)
        x = self.ln_f(x)
        # BUGFIX: index explicitly instead of .squeeze()[-1]; the original
        # raised an IndexError when T == 1 because squeeze() produced a
        # 0-dim tensor that cannot be indexed.
        return self.makespan_head(x)[0, -1, 0]
|
|
|
|
|
|
|
if __name__ == "__main__":

    from argparse import ArgumentParser, ArgumentTypeError

    def _str2bool(value):
        """Parse an explicit true/false command-line value.

        BUGFIX: the original used ``type=bool``, a well-known argparse trap —
        ``bool("False")`` is ``True`` because any non-empty string is truthy,
        so e.g. ``--testing False`` silently enabled the flag.
        """
        lowered = value.lower()
        if lowered in ("true", "1", "yes"):
            return True
        if lowered in ("false", "0", "no"):
            return False
        raise ArgumentTypeError(f"expected a boolean value, got {value!r}")

    parser = ArgumentParser()
    parser.add_argument("--testing", type=_str2bool, required=True)
    parser.add_argument("--seed", type=int, required=True)
    parser.add_argument("--data_dir", type=str, required=True)
    parser.add_argument("--n_embd", type=int, required=True)
    parser.add_argument("--n_head", type=int, required=True)
    parser.add_argument("--n_layer", type=int, required=True)
    parser.add_argument("--intermediate_schedules", type=_str2bool, required=True)
    parser.add_argument("--dropout", type=float, required=True)
    parser.add_argument("--ff_width", type=int, required=True)
    parser.add_argument("--train_batch_size", type=int, required=True)
    parser.add_argument("--val_batch_size", type=int, required=True)
    parser.add_argument("--nb_epochs", type=int, required=True)
    parser.add_argument("--early_stopping_patience", type=int, required=True)
    parser.add_argument("--checkpoint_interval_ratio", type=float, required=True)
    parser.add_argument("--decay_lr", type=_str2bool, required=True)
    parser.add_argument("--lr_partitions_ratios", type=lambda s: [float(item) for item in s.split(',')], help='Comma-separated list of floats that do not add up to 1 (e.g., 0.1,0.5,1)', required=True)
    parser.add_argument("--init_lr", type=float, required=True)
    parser.add_argument("--max_lr", type=float, required=True)
    parser.add_argument("--min_lr", type=float, required=True)
    parser.add_argument("--lr_warmup_iters_ratio", type=float, required=True)
    parser.add_argument("--lr_decay_iters_ratio", type=float, required=True)
    parser.add_argument("--beta1", type=float, required=True)
    parser.add_argument("--beta2", type=float, required=True)
    parser.add_argument("--weight_decay", type=float, required=True)
    parser.add_argument("--grad_clip", type=float, required=True)
    parser.add_argument("--compile", type=_str2bool, required=True)
    parser.add_argument("--compile_mode", type=str, required=True)
    parser.add_argument("--save_only_last_checkpoint", type=_str2bool, required=True)
    parser.add_argument("--output_dir", type=str, required=True)
    args = parser.parse_args()

    os.makedirs(args.output_dir, exist_ok=True)

    if not args.testing:
        # Normal run: never clobber a finished phase-1 training.
        if os.path.exists(os.path.join(args.output_dir, ".terminated_phase1")):
            print("Phase 1 already terminated. Exiting...")
            raise SystemExit(0)

        if not os.path.exists(os.path.join(args.output_dir, "viz_train.ipynb")):
            shutil.copy("viz_train.ipynb", args.output_dir)
    else:
        # Testing run: wipe every artifact a previous run may have left.
        files_to_delete = [
            "train.log",
            "train_parameters.json",
            "batch_losses.npy",
            "last_batch_loss_idx.npy",
            "val_losses.npy",
            "last_val_loss_idx.npy",
            "train_pbar_epoch.log",
            "train_pbar_val.log",
            ".terminated_phase1",
            "viz_train.ipynb",
        ]
        for f in files_to_delete:
            f_path = os.path.join(args.output_dir, f)
            if os.path.exists(f_path):
                os.remove(f_path)

        checkpoints_dir = os.path.join(args.output_dir, "checkpoints")
        if os.path.exists(checkpoints_dir):
            shutil.rmtree(checkpoints_dir)
        shutil.copy("viz_train.ipynb", args.output_dir)

    assert torch.cuda.is_available(), "This code requires a GPU to run. Please run it on a machine with a CUDA-compatible GPU."
    device = "cuda"

    loguru.logger.add(os.path.join(args.output_dir, "train.log"))

    # Seed every RNG the run touches for reproducibility.
    torch.manual_seed(args.seed)
    random.seed(args.seed)
    np.random.seed(args.seed)

    # Model geometry comes from the dataset: one token per job, and the
    # context window is the full schedule length.
    with open(os.path.join(args.data_dir, "metadata.json"), "r") as f:
        metadata = json.load(f)
    block_size = metadata["nb_jobs"]
    vocab_size = metadata["nb_jobs"]
    n_embd = args.n_embd
    n_head = args.n_head
    assert n_embd % n_head == 0
    n_layer = args.n_layer
    intermediate_schedules = args.intermediate_schedules
    ff_width = args.ff_width

    train_batch_size = args.train_batch_size
    val_batch_size = args.val_batch_size
    nb_epochs = args.nb_epochs
    early_stopping_patience = args.early_stopping_patience
    dropout = args.dropout

    class FlowshopDataset(torch.utils.data.Dataset):
        """(schedule, makespan) pairs stored as memory-mapped .npy files.

        With load_in_memory=False the arrays stay memmapped on disk;
        otherwise they are copied into RAM tensors up front.
        """

        def __init__(self, dataset_path, split, load_in_memory=True):
            metadata_path = os.path.join(dataset_path, "metadata.json")
            with open(metadata_path, "r") as f:
                metadata = json.load(f)
            nb_samples = metadata["nb_samples"]
            nb_jobs = metadata["nb_jobs"]
            schedules_path = os.path.join(dataset_path, f"schedules_{split}.npy")
            # NOTE: dtype/shape are ignored by open_memmap in read mode (the
            # .npy header wins); they are kept here as documentation.
            schedules = np.lib.format.open_memmap(
                schedules_path, dtype=np.int32, mode="r", shape=(nb_samples, nb_jobs)
            )
            loguru.logger.info(f"Loaded schedules from {schedules_path} with shape {schedules.shape}")
            makespans_path = os.path.join(dataset_path, f"makespans_{split}.npy")
            makespans = np.lib.format.open_memmap(
                makespans_path, dtype=np.float32, mode="r", shape=(nb_samples,)
            )
            loguru.logger.info(f"Loaded makespans from {makespans_path} with shape {makespans.shape}")
            if load_in_memory:
                schedules = np.array(schedules)
                makespans = np.array(makespans)
            self.schedules = torch.from_numpy(schedules)
            self.makespans = torch.from_numpy(makespans)

        def __len__(self):
            return len(self.schedules)

        def __getitem__(self, idx):
            return self.schedules[idx], self.makespans[idx]

    train_dataset = FlowshopDataset(args.data_dir, split="train", load_in_memory=False)
    train_data_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=train_batch_size,
        shuffle=False,
        drop_last=False,
    )

    # Log the batch shapes once, as a sanity check.
    for schedules, makespans in train_data_loader:
        loguru.logger.info(f"schedules.shape: {schedules.shape}")
        loguru.logger.info(f"makespans.shape: {makespans.shape}")
        break

    # NOTE(review): nb_iters is derived from this non-dropping loader; the
    # per-epoch loader below uses drop_last when compiling, so the last
    # few slots of batch_losses may stay unused in that configuration.
    nb_iters = nb_epochs * len(train_data_loader)
    checkpoint_interval = int(args.checkpoint_interval_ratio * len(train_data_loader))
    decay_lr = args.decay_lr
    # The final partition's ratio is implicit: it absorbs whatever iteration
    # budget the explicit ratios leave over (hence the trailing None).
    lr_partitions_ratios = args.lr_partitions_ratios + [None]
    lr_partitions_iters = [int(r * nb_iters) for r in lr_partitions_ratios[:-1]]
    lr_partitions_iters = lr_partitions_iters + [nb_iters - sum(lr_partitions_iters)]
    assert sum(lr_partitions_iters) == nb_iters
    init_lr = args.init_lr
    max_lr = args.max_lr
    min_lr = args.min_lr
    lr_warmup_iters_ratio = args.lr_warmup_iters_ratio
    lr_decay_iters_ratio = args.lr_decay_iters_ratio
    beta1 = args.beta1
    beta2 = args.beta2
    weight_decay = args.weight_decay
    grad_clip = args.grad_clip
    # Renamed from `compile` to avoid shadowing the builtin.
    compile_model = args.compile
    compile_mode = args.compile_mode
    save_only_last_checkpoint = args.save_only_last_checkpoint

    def human_readable(num):
        """Format a large count with a 1000-based magnitude suffix (K/M/G/...)."""
        magnitude = 0
        while abs(num) >= 1000:
            magnitude += 1
            num /= 1000.0
        return "%.0f%s" % (num, ["", "K", "M", "G", "T", "P"][magnitude])

    def get_lr(it):
        """Learning rate at global iteration `it`.

        The schedule is split into partitions (see lr_partitions_iters); each
        partition restarts a linear warmup followed by cosine decay to min_lr.
        """
        # Locate the partition containing `it`.
        i = 0
        tmp_it = it - lr_partitions_iters[i]
        while tmp_it >= 0:
            i += 1
            tmp_it -= lr_partitions_iters[i]

        lr_partition_iters = lr_partitions_iters[i]
        warmup_iters = int(lr_partition_iters * lr_warmup_iters_ratio)
        lr_decay_iters = int(lr_partition_iters * lr_decay_iters_ratio)

        # Offset of `it` inside its partition.
        it = it - sum(lr_partitions_iters[:i])

        # Linear warmup from init_lr to max_lr.
        if it < warmup_iters:
            return (it / warmup_iters) * (max_lr - init_lr) + init_lr

        # Past the decay horizon: flat at min_lr.
        if it > lr_decay_iters:
            return min_lr

        # Cosine decay from max_lr to min_lr.
        decay_ratio = (it - warmup_iters) / (lr_decay_iters - warmup_iters)
        assert 0 <= decay_ratio <= 1
        coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))
        return min_lr + coeff * (max_lr - min_lr)

    loguru.logger.info(f"data_dir: {args.data_dir}")
    loguru.logger.info(f"block_size: {block_size}")
    loguru.logger.info(f"vocab_size: {vocab_size}")
    loguru.logger.info(f"n_embd: {n_embd}")
    loguru.logger.info(f"n_head: {n_head}")
    loguru.logger.info(f"n_layer: {n_layer}")
    loguru.logger.info(f"ff_width: {ff_width}")
    loguru.logger.info(f"train_batch_size: {train_batch_size}")
    loguru.logger.info(f"val_batch_size: {val_batch_size}")
    loguru.logger.info(f"dropout: {dropout}")
    loguru.logger.info(f"nb_epochs: {nb_epochs}")
    loguru.logger.info(f"early_stopping_patience: {early_stopping_patience}")
    loguru.logger.info(f"nb_iters: {nb_iters}")
    loguru.logger.info(f"checkpoint_interval: {checkpoint_interval}")
    loguru.logger.info(f"decay_lr: {decay_lr}")
    loguru.logger.info(f"lr_partitions_ratios: {lr_partitions_ratios}")
    loguru.logger.info(f"lr_partitions_iters: {lr_partitions_iters}")
    loguru.logger.info(f"init_lr: {init_lr}")
    loguru.logger.info(f"max_lr: {max_lr}")
    loguru.logger.info(f"min_lr: {min_lr}")
    loguru.logger.info(f"lr_warmup_iters_ratio: {lr_warmup_iters_ratio}")
    loguru.logger.info(f"lr_decay_iters_ratio: {lr_decay_iters_ratio}")
    loguru.logger.info(f"beta1: {beta1}")
    loguru.logger.info(f"beta2: {beta2}")
    loguru.logger.info(f"weight_decay: {weight_decay}")
    loguru.logger.info(f"grad_clip: {grad_clip}")
    loguru.logger.info(f"compile: {compile_model}")
    loguru.logger.info(f"compile_mode: {compile_mode}")
    loguru.logger.info(f"intermediate_schedules: {intermediate_schedules}")
    loguru.logger.info(f"save_only_last_checkpoint: {save_only_last_checkpoint}")

    # Persist the full configuration next to the run artifacts.
    # (The duplicate `import json` that used to sit here was removed; json is
    # imported at the top of the file.)
    train_params = {
        "data_dir": args.data_dir,
        "block_size": block_size,
        "vocab_size": vocab_size,
        "n_embd": n_embd,
        "n_head": n_head,
        "n_layer": n_layer,
        "ff_width": ff_width,
        "train_batch_size": train_batch_size,
        "val_batch_size": val_batch_size,
        "dropout": dropout,
        "nb_epochs": nb_epochs,
        "early_stopping_patience": early_stopping_patience,
        "nb_iters": nb_iters,
        "checkpoint_interval": checkpoint_interval,
        "decay_lr": decay_lr,
        "lr_partitions_ratios": lr_partitions_ratios,
        "lr_partitions_iters": lr_partitions_iters,
        "init_lr": init_lr,
        "max_lr": max_lr,
        "min_lr": min_lr,
        "lr_warmup_iters_ratio": lr_warmup_iters_ratio,
        "lr_decay_iters_ratio": lr_decay_iters_ratio,
        "beta1": beta1,
        "beta2": beta2,
        "weight_decay": weight_decay,
        "grad_clip": grad_clip,
        "compile": compile_model,
        "compile_mode": compile_mode,
        "intermediate_schedules": intermediate_schedules,
        "save_only_last_checkpoint": save_only_last_checkpoint,
    }
    with open(os.path.join(args.output_dir, "train_parameters.json"), "w") as f:
        json.dump(train_params, f, indent=4)

    # Resume from the last checkpoint if one exists; otherwise start fresh.
    try:
        last_checkpoint = torch.load(os.path.join(args.output_dir, "checkpoints", "last_checkpoint.pth"))
        start_epoch = last_checkpoint["epoch"]
        start_epoch_iter = last_checkpoint["epoch_iter"] + 1
        model_state_dict = last_checkpoint["model_state_dict"]
        optimizer_state_dict = last_checkpoint["optimizer_state_dict"]
        best_val_loss = last_checkpoint["best_val_loss"]
        patience_counter = last_checkpoint["patience_counter"]
        improved_this_epoch = last_checkpoint["improved_this_epoch"]
    except FileNotFoundError:
        os.makedirs(os.path.join(args.output_dir, "checkpoints"), exist_ok=True)
        start_epoch = 0
        start_epoch_iter = 0
        model_state_dict = None
        optimizer_state_dict = None
        best_val_loss = float("inf")
        patience_counter = 0
        improved_this_epoch = False

    model = GPT(
        intermediate_schedules=intermediate_schedules,
        vocab_size=vocab_size,
        n_embd=n_embd,
        n_head=n_head,
        n_layer=n_layer,
        dropout=dropout,
        ff_width=ff_width,
        block_size=block_size
    ).to(device)
    loguru.logger.info(f"The model has {human_readable(sum(p.numel() for p in model.parameters() if p.requires_grad))} trainable parameters")
    train_model = model
    if model_state_dict is not None:
        train_model.load_state_dict(model_state_dict)
    if compile_model:
        # The compiled wrapper shares parameters with `model`.
        train_model = torch.compile(train_model, mode=compile_mode)

    # Split parameters into weight-decayed (matrices/embeddings, dim >= 2)
    # and non-decayed (biases, norm gains) groups, GPT-style.
    param_dict = {pn: p for pn, p in train_model.named_parameters()}
    param_dict = {pn: p for pn, p in param_dict.items() if p.requires_grad}

    decay_params = [p for n, p in param_dict.items() if p.dim() >= 2]
    nodecay_params = [p for n, p in param_dict.items() if p.dim() < 2]
    optim_groups = [
        {"params": decay_params, "weight_decay": weight_decay},
        {"params": nodecay_params, "weight_decay": 0.0},
    ]
    num_decay_params = sum(p.numel() for p in decay_params)
    num_nodecay_params = sum(p.numel() for p in nodecay_params)
    loguru.logger.info(f"num decayed parameter tensors: {len(decay_params)}, with {num_decay_params:,} parameters")
    loguru.logger.info(f"num non-decayed parameter tensors: {len(nodecay_params)}, with {num_nodecay_params:,} parameters")

    # Use the fused AdamW kernel when this torch build supports it on CUDA.
    fused_available = "fused" in inspect.signature(torch.optim.AdamW).parameters
    use_fused = fused_available and ("cuda" in device)
    loguru.logger.info(f"using fused AdamW: {use_fused}")
    extra_args = dict(fused=True) if use_fused else dict()
    optimizer = torch.optim.AdamW(optim_groups, lr=init_lr, betas=(beta1, beta2), **extra_args)

    if optimizer_state_dict is not None:
        optimizer.load_state_dict(optimizer_state_dict)

    # Allow TF32 matmuls on Ampere+ for speed.
    torch.set_float32_matmul_precision("high")

    # Loss curves are persisted as memmapped arrays so an external notebook
    # can watch training progress live; the *_idx scalars mark the last
    # written slot.
    batch_losses_path = os.path.join(args.output_dir, "batch_losses.npy")
    last_batch_loss_idx_path = os.path.join(args.output_dir, "last_batch_loss_idx.npy")
    val_losses_path = os.path.join(args.output_dir, "val_losses.npy")
    last_val_loss_idx_path = os.path.join(args.output_dir, "last_val_loss_idx.npy")

    try:
        batch_losses = np.lib.format.open_memmap(batch_losses_path, mode="r+", dtype=np.float32, shape=(nb_iters,))
        last_batch_loss_idx = np.lib.format.open_memmap(last_batch_loss_idx_path, mode="r+", dtype=np.int32, shape=())
        val_losses = np.lib.format.open_memmap(val_losses_path, mode="r+", dtype=np.float32, shape=(nb_epochs * math.ceil(len(train_data_loader)/checkpoint_interval),))
        last_val_loss_idx = np.lib.format.open_memmap(last_val_loss_idx_path, mode="r+", dtype=np.int32, shape=())
    except FileNotFoundError:
        batch_losses = np.lib.format.open_memmap(batch_losses_path, mode="w+", dtype=np.float32, shape=(nb_iters,))
        last_batch_loss_idx = np.lib.format.open_memmap(last_batch_loss_idx_path, mode="w+", dtype=np.int32, shape=())
        val_losses = np.lib.format.open_memmap(val_losses_path, mode="w+", dtype=np.float32, shape=(nb_epochs * math.ceil(len(train_data_loader)/checkpoint_interval),))
        last_val_loss_idx = np.lib.format.open_memmap(last_val_loss_idx_path, mode="w+", dtype=np.int32, shape=())
        last_batch_loss_idx[...] = 0
        last_batch_loss_idx.flush()
        last_val_loss_idx[...] = 0
        last_val_loss_idx.flush()

    val_data_loader = torch.utils.data.DataLoader(
        FlowshopDataset(args.data_dir, split="val", load_in_memory=True),
        batch_size=val_batch_size,
        shuffle=False,
    )

    early_stop = False
    for epoch in range(start_epoch, nb_epochs):
        if early_stop:
            break
        # On a resume mid-epoch, keep the recorded improvement flag.
        if start_epoch_iter == 0:
            improved_this_epoch = False

        # Re-seed the sampler per epoch so shuffling is both different each
        # epoch and reproducible on resume.
        generator = torch.Generator()
        generator.manual_seed(args.seed + epoch)
        train_sampler = torch.utils.data.RandomSampler(
            train_dataset,
            generator=generator
        )
        train_data_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=train_batch_size,
            sampler=train_sampler,
            # Fixed batch shapes keep torch.compile from re-tracing.
            drop_last=compile_model,
        )
        # Skip the batches already consumed before a mid-epoch resume.
        train_data_loader_iterator = itertools.islice(
            iter(train_data_loader),
            start_epoch_iter,
            None,
        )

        for epoch_iter, (schedules_batch, makespans_batch) in (pbar := tqdm(
            enumerate(train_data_loader_iterator, start=start_epoch_iter),
            total=len(train_data_loader),
            initial=start_epoch_iter,
            desc=f"Epoch {epoch+1}/{nb_epochs}",
        )):
            # Mirror the progress bar to a file for remote monitoring.
            with open(os.path.join(args.output_dir, "train_pbar_epoch.log"), "w") as f:
                f.write(str(pbar))

            schedules_batch = schedules_batch.to(device)
            makespans_batch = makespans_batch.to(device)

            optimizer.zero_grad(set_to_none=True)

            makespans, loss = train_model(schedules_batch, makespans_batch)

            loss.backward()

            if grad_clip != 0.0:
                torch.nn.utils.clip_grad_norm_(train_model.parameters(), grad_clip)

            # Manual LR schedule: set the rate for this global iteration.
            current_iter = epoch * len(train_data_loader) + epoch_iter
            lr = get_lr(current_iter) if decay_lr else init_lr
            for param_group in optimizer.param_groups:
                param_group["lr"] = lr

            optimizer.step()

            batch_losses[current_iter] = loss.item()
            last_batch_loss_idx[...] = current_iter
            batch_losses.flush()
            last_batch_loss_idx.flush()

            # Validate every checkpoint_interval batches and at epoch end.
            if (epoch_iter + 1) % checkpoint_interval == 0 or (epoch_iter + 1) == len(train_data_loader):

                total_val_loss = 0

                # BUGFIX: the original validated in train mode, so dropout
                # stayed active and noised/inflated the validation loss.
                train_model.eval()
                for val_schedules, val_makespans in (pbar2 := tqdm(
                    val_data_loader,
                    desc=f"Validation {epoch+(epoch_iter+1)/len(train_data_loader):.2f}",
                )):
                    with open(os.path.join(args.output_dir, "train_pbar_val.log"), "w") as f:
                        f.write(str(pbar2))

                    val_schedules = val_schedules.to(device)
                    val_makespans = val_makespans.to(device)

                    with torch.no_grad():
                        _, val_loss = train_model(val_schedules, val_makespans)
                        # Weight by batch size so the mean below is exact
                        # even when the last batch is smaller.
                        total_val_loss += val_loss.item() * val_schedules.size(0)

                with open(os.path.join(args.output_dir, "train_pbar_val.log"), "w") as f:
                    f.write(str(pbar2))
                train_model.train()

                total_val_loss /= len(val_data_loader.dataset)
                val_loss_idx = epoch * math.ceil(len(train_data_loader)/checkpoint_interval) + epoch_iter // checkpoint_interval
                val_losses[val_loss_idx] = total_val_loss
                last_val_loss_idx[...] = val_loss_idx
                val_losses.flush()
                last_val_loss_idx.flush()

                if total_val_loss < best_val_loss:
                    best_val_loss = total_val_loss
                    improved_this_epoch = True

                checkpoint = {
                    "epoch": epoch,
                    "epoch_iter": epoch_iter,
                    # BUGFIX: save the underlying module's state dict, not the
                    # torch.compile wrapper's — the wrapper prefixes keys with
                    # "_orig_mod." which breaks load_state_dict on resume.
                    "model_state_dict": model.state_dict(),
                    "optimizer_state_dict": optimizer.state_dict(),
                    "validation_loss": total_val_loss,
                    "time": time.strftime("%Y_%m_%d_%H_%M_%S"),
                    "best_val_loss": best_val_loss,
                    "patience_counter": patience_counter,
                    "improved_this_epoch": improved_this_epoch,
                }
                torch.save(
                    checkpoint,
                    os.path.join(args.output_dir, "checkpoints", "last_checkpoint.pth")
                )
                if not save_only_last_checkpoint:
                    torch.save(
                        checkpoint,
                        os.path.join(args.output_dir, "checkpoints", f"checkpoint_epoch_{epoch+(epoch_iter+1)/len(train_data_loader):.2f}.pth")
                    )
                if best_val_loss == total_val_loss:
                    torch.save(
                        checkpoint,
                        os.path.join(args.output_dir, "checkpoints", "best_checkpoint.pth")
                    )

        # Final progress-bar snapshot for this epoch.
        with open(os.path.join(args.output_dir, "train_pbar_epoch.log"), "w") as f:
            f.write(str(pbar))

        # Only the first (resumed) epoch starts mid-way.
        start_epoch_iter = 0

        if improved_this_epoch:
            patience_counter = 0
        else:
            patience_counter += 1
            if patience_counter >= early_stopping_patience:
                loguru.logger.info(f"Early stopping triggered! Validation loss hasn't improved for {early_stopping_patience} epochs.")
                early_stop = True

    loguru.logger.info(f"Best validation loss: {best_val_loss:.4f}")

    # Marker file: tells future runs that phase 1 finished.
    with open(os.path.join(args.output_dir, ".terminated_phase1"), "w") as f:
        pass
|
|
|
| |