"""Sweep training runs over every exhaustive-dataset configuration.

One ``train`` call is made per (nb_jobs, nb_machines, top_k) combination,
using a fixed set of hyperparameters; only the data and output directories
vary between runs. Runs execute sequentially, with nb_jobs as the outermost
axis and top_k as the innermost.
"""
from train import train

# Grid axes, listed from outermost to innermost iteration.
NB_JOBS_VALUES = [7, 8, 9]
NB_MACHINES_VALUES = [2, 3, 4, 5, 6]
TOP_K_VALUES = [0, 1, 2, 3, 4]


def _build_config(nb_jobs, nb_machines, top_k):
    """Return the keyword arguments for a single ``train`` call.

    The dataset directory is derived from the grid point; the output
    directory is a fixed subfolder of it.
    """
    dataset_dir = f"../datasets/exhaustive_{nb_jobs}_{nb_machines}/top_{top_k}"
    return {
        "testing": False,
        "seed": 97,
        "data_dir": dataset_dir,
        "n_embd": 64,
        "n_head": 4,
        "n_layer": 2,
        "ff_width": 4,
        "intermediate_schedules": True,
        "train_batch_size": 128,
        "val_batch_size": 256,
        # NOTE(review): patience (15) exceeds nb_epochs (5), so early stopping
        # presumably can never trigger here — confirm this is intended.
        "nb_epochs": 5,
        "early_stopping_patience": 15,
        "dropout": 0.0,
        "checkpoint_interval_ratio": 1.0,
        "decay_lr": True,
        "lr_partitions_ratios": [0.66],
        "init_lr": 1e-4,
        "max_lr": 1e-3,
        "min_lr": 5e-5,
        "lr_warmup_iters_ratio": 0.1,
        "lr_decay_iters_ratio": 0.95,
        "beta1": 0.9,
        "beta2": 0.95,
        # 5e-0 == 5.0; written this way to match the "Wd5e-0" run-name suffix.
        "weight_decay": 5e-0,
        "grad_clip": 1.0,
        "compile": "",
        "compile_mode": "default",
        "save_only_last_checkpoint": False,
        "output_dir": f"{dataset_dir}/train_Sm_Wd5e-0",
    }


# Materialize every configuration up front, in grid order.
params = [
    _build_config(nb_jobs, nb_machines, top_k)
    for nb_jobs in NB_JOBS_VALUES
    for nb_machines in NB_MACHINES_VALUES
    for top_k in TOP_K_VALUES
]

for run_kwargs in params:
    train(**run_kwargs)