File size: 5,020 Bytes
12aef23 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 |
#!/usr/bin/env bash
# Train a fairseq translation model (en -> de) on WMT23-filtered data.
# The decode/eval stage is present but commented out at the bottom of the file.
# -e: abort on error, -u: error on unset variables, -x: trace commands,
# pipefail: a pipeline fails if any stage fails.
set -euxo pipefail
train_device=0,1,2,3,4,5,6,7
eval_device=0
# xzq-fairseq
root_dir=$(dirname "$PWD")
src_lang=en
tgt_lang=de
threshold=0.7
data_name=wmt23
# pair_lang=${src_lang}-${tgt_lang}
task_name=${src_lang}2${tgt_lang}
data_dir=$root_dir/data/${task_name}/${threshold}
raw_data_dir=$data_dir/raw
trainable_data_dir=$data_dir/trainable_data
## eval & decode params
decode_max_tokens=2048
beam=5
nbest=1
lenpen=1.0
## common params
criterion=label_smoothed_cross_entropy
label_smoothing=0.1
seed=42
max_epoch=40
keep_last_epochs=1
keep_best_checkpoints=5
patience=5
num_workers=8
# specified param
conf_name=transformer_big
# Global batch = num_gpus * max-tokens * gradient-accumulation-steps. For language
# pairs with large training sets (tens of millions of pairs), a global batch of
# 100k+ tokens works well.
# Hyper-parameters for each supported architecture preset.
if [[ "$conf_name" == "transformer_big" ]]; then
  arch=transformer_vaswani_wmt_en_de_big
  activation_fn=relu
  encoder_ffn_embed_dim=4096
  share_all_embeddings=0
  share_decoder_input_output_embed=1
  # NOTE(review): "learing_rate" is a typo for "learning_rate", but the variable
  # is referenced under this exact name when the fairseq-train command is built
  # below, so the spelling is kept for compatibility.
  learing_rate=1e-3
  warmup=4000
  max_tokens=8192
  weight_decay=0.0
  dropout=0.3
  gradient_accumulation_steps=4
else
  # Bare `exit` would propagate echo's status (0) and report success on an
  # unknown preset; exit 1 explicitly and send the diagnostic to stderr.
  echo "unknown conf_name=$conf_name" >&2
  exit 1
fi
model_dir=$root_dir/exps/$task_name/${threshold}/${conf_name}_${data_name}
mkdir -p "$model_dir"
# Snapshot this script next to the checkpoints for reproducibility.
cp "${BASH_SOURCE[0]}" "$model_dir"
# Count GPUs by splitting the comma-separated device list (no awk fork needed).
IFS=',' read -r -a _train_devices <<< "$train_device"
gpu_num=${#_train_devices[@]}
export CUDA_VISIBLE_DEVICES=$train_device
# Assemble the fairseq-train invocation as a single string; optional flags are
# appended below and the final string is launched via `eval` under nohup.
# The inner single quotes around '(0.9, 0.98)' are required so the adam-betas
# tuple survives the later `eval` as one argument.
# NOTE(review): $learing_rate (sic) matches the spelling where it is defined.
cmd="fairseq-train $trainable_data_dir \
--distributed-world-size $gpu_num -s $src_lang -t $tgt_lang \
--arch $arch \
--fp16 \
--optimizer adam --clip-norm 0.0 \
--lr-scheduler inverse_sqrt --warmup-init-lr 1e-07 --warmup-updates $warmup \
--lr $learing_rate --adam-betas '(0.9, 0.98)' \
--weight-decay $weight_decay \
--dropout $dropout \
--criterion $criterion --label-smoothing $label_smoothing \
--max-epoch $max_epoch \
--max-tokens $max_tokens \
--update-freq $gradient_accumulation_steps \
--activation-fn $activation_fn \
--encoder-ffn-embed-dim $encoder_ffn_embed_dim \
--seed $seed \
--num-workers $num_workers \
--no-epoch-checkpoints \
--keep-last-epochs $keep_last_epochs \
--keep-best-checkpoints $keep_best_checkpoints \
--patience $patience \
--no-progress-bar \
--log-interval 100 \
--task "translation" \
--ddp-backend no_c10d \
--save-dir $model_dir \
--tensorboard-logdir $model_dir"
# Append optional flags to the training command.
# [[ ]] is used instead of [ ] so an empty/unset variable cannot produce a
# "unary operator expected" error; ${var:-0} defaults to "off".
if [[ "${share_all_embeddings:-0}" -eq 1 ]]; then
  cmd=${cmd}" --share-all-embeddings "
fi
if [[ "${share_decoder_input_output_embed:-0}" -eq 1 ]]; then
  cmd=${cmd}" --share-decoder-input-output-embed "
fi
# ${max_update:=0} also ASSIGNS 0 to max_update when it is unset or empty.
if [[ "${max_update:=0}" -ne 0 ]]; then
  cmd=${cmd}" --max-update $max_update"
fi
# Launch training in the background; stdout/stderr are appended to train.log
# so the script can return immediately (the decode stage below is commented
# out and would otherwise `wait`).
cur_time=$(date +"%Y-%m-%d %H:%M:%S")
echo "=============$cur_time===================" >> "$model_dir/train.log"
cmd="nohup ${cmd} >> $model_dir/train.log 2>&1 &"
eval "$cmd"
# wait
# ### decode
# checkpoint_path=$model_dir/checkpoint_best.pt
# save_dir=$model_dir/decode_result
# mkdir -p $save_dir
# cp ${BASH_SOURCE[0]} $save_dir
# declare -A gen_subset_dict
# gen_subset_dict=([test]=flores [test1]=wmt22 [test2]=wmt23)
# for gen_subset in ${!gen_subset_dict[*]}
# do
# decode_file=$save_dir/decode_${gen_subset_dict[$gen_subset]}_beam${beam}_lenpen${lenpen}.$tgt_lang
# pure_file=$save_dir/pure_decode_${gen_subset_dict[$gen_subset]}_beam${beam}_lenpen${lenpen}.$tgt_lang
# CUDA_VISIBLE_DEVICES=$eval_device fairseq-generate \
# $trainable_data_dir \
# -s $src_lang -t $tgt_lang \
# # NOTE(review): $user_dir is never defined anywhere in this script — define
# # it or drop this flag before un-commenting the decode stage.
# --user-dir $user_dir \
# --gen-subset $gen_subset \
# --path $checkpoint_path \
# --max-tokens $decode_max_tokens \
# --beam $beam \
# --nbest $nbest \
# --lenpen $lenpen \
# --seed $seed \
# --remove-bpe | tee $decode_file
# ### eval
# # purify file
# grep ^H $decode_file | LC_ALL=C sort -V | cut -f3- | perl $root_dir/mosesdecoder/scripts/tokenizer/detokenizer.perl -l $tgt_lang > $pure_file
# eval_file=$model_dir/eval_${gen_subset_dict[$gen_subset]}.log
# cur_time=`date +"%Y-%m-%d %H:%M:%S"`
# echo "=============$cur_time===================" >> $eval_file
# echo $checkpoint_path >> $eval_file
# tail -n1 $decode_file >> $eval_file # multi-bleu
# # get score
# src_file=$raw_data_dir/test.${gen_subset_dict[$gen_subset]}.$src_lang
# ref_file=$raw_data_dir/test.${gen_subset_dict[$gen_subset]}.$tgt_lang
# sacrebleu_file=$save_dir/sacrebleu.${gen_subset_dict[$gen_subset]}.beam${beam}_lenpen${lenpen}
# comet22_file=$save_dir/comet22.${gen_subset_dict[$gen_subset]}.beam${beam}_lenpen${lenpen}
# sacrebleu $ref_file -i $pure_file -w 2 >> $eval_file
# comet-score -s $src_file -t $pure_file -r $ref_file --model $root_dir/wmt22-comet-da/checkpoints/model.ckpt | tee $comet22_file
# echo "Comet22 Score" >> $eval_file
# tail -n1 $comet22_file >> $eval_file # keep only the corpus-average COMET score
# echo -e "decode finished! \n decode tokenized file in $decode_file \n detokenized file in $pure_file \n sacrebleu file in $eval_file"
# done
|