"""Run train() once for every (nb_jobs, nb_machines, top_k) combination."""
from itertools import product

from train import train


def _build_config(nb_jobs: int, nb_machines: int, top_k: int) -> dict:
    """Return the keyword arguments passed to train() for one sweep point.

    All hyperparameters are fixed; only ``data_dir`` and ``output_dir``
    depend on the arguments.
    """
    # data_dir and output_dir share this prefix, so build it only once.
    dataset_dir = f"../datasets/exhaustive_{nb_jobs}_{nb_machines}/top_{top_k}"
    return {
        "testing": False,
        "seed": 97,
        "data_dir": dataset_dir,
        "n_embd": 64,
        "n_head": 4,
        "n_layer": 2,
        "ff_width": 4,
        "intermediate_schedules": True,
        "train_batch_size": 128,
        "val_batch_size": 256,
        "nb_epochs": 5,
        # NOTE(review): patience (15) exceeds nb_epochs (5); if patience is
        # counted in epochs, early stopping can never trigger -- confirm.
        "early_stopping_patience": 15,
        "dropout": 0.0,
        "checkpoint_interval_ratio": 1.0,
        "decay_lr": True,
        "lr_partitions_ratios": [0.66],
        "init_lr": 1e-4,
        "max_lr": 1e-3,
        "min_lr": 5e-5,
        "lr_warmup_iters_ratio": 0.1,
        "lr_decay_iters_ratio": 0.95,
        "beta1": 0.9,
        "beta2": 0.95,
        # 5e-0 == 5.0; the same value is spelled in the output_dir suffix.
        "weight_decay": 5e-0,
        "grad_clip": 1.0,
        # NOTE(review): empty string is falsy; presumably this disables
        # compilation -- verify against train()'s handling of "compile".
        "compile": "",
        "compile_mode": "default",
        "save_only_last_checkpoint": False,
        "output_dir": f"{dataset_dir}/train_Sm_Wd5e-0",
    }


# product() varies its last iterable fastest, so nb_jobs is the outermost
# axis and top_k the innermost: 3 * 5 * 5 = 75 configurations, all built
# before the first train() call.
params = [
    _build_config(nb_jobs, nb_machines, top_k)
    for nb_jobs, nb_machines, top_k in product(
        [7, 8, 9],
        [2, 3, 4, 5, 6],
        [0, 1, 2, 3, 4],
    )
]

for param in params:
    train(**param)