#!/usr/bin/env bash
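
# Train an SE(3)-Transformer on a QM9 regression target.
# Positional arguments (all optional): batch size, AMP flag, number of
# epochs, learning rate, weight decay; defaults are set below.
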
BATCH_SIZE=${1:-240}
AMP=${2:-true}
NUM_EPOCHS=${3:-130}
LEARNING_RATE=${4:-0.01}
WEIGHT_DECAY=${5:-0.1}
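
# Example invocation overriding the defaults positionally
# (the script file name here is hypothetical):
#   bash train_qm9.sh 120 false 100 0.002 0.1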

# QM9 target property to predict ("homo" = highest occupied molecular
# orbital energy).
TASK=homo

# torch.distributed.run launches one training process per visible GPU
# (--nproc_per_node=gpu) on a single node.
python -m torch.distributed.run --nnodes=1 --nproc_per_node=gpu --max_restarts 0 --module \
    se3_transformer.runtime.training \
    --amp "$AMP" \
    --batch_size "$BATCH_SIZE" \
    --epochs "$NUM_EPOCHS" \
    --lr "$LEARNING_RATE" \
    --min_lr 0.00001 \
    --weight_decay "$WEIGHT_DECAY" \
    --use_layer_norm \
    --norm \
    --save_ckpt_path model_qm9.pth \
    --precompute_bases \
    --seed 42 \
    --task "$TASK"
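
# Minimal sketch of a single-process (non-distributed) run for debugging,
# assuming the training module accepts the same flags when invoked directly:
#   python -m se3_transformer.runtime.training \
#       --batch_size "$BATCH_SIZE" --epochs "$NUM_EPOCHS" --task "$TASK"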