# Truncated Backpropagation Through Time (BPTT)

Truncated BPTT is a useful technique for training language models on very long
sequences. Typically a long sequence is split into chunks and a language model
is trained over the chunks sequentially. The LM may condition on previous
chunks, but gradients only flow through the current chunk. This technique was
the basis for the paper [Transformer-XL: Attentive Language Models Beyond a
Fixed-Length Context](https://arxiv.org/abs/1901.02860), which achieved
state-of-the-art language modeling results at the time of publication.
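
Conceptually, the training loop looks something like the sketch below. This is
an illustrative PyTorch sketch, not fairseq's actual code; the
`model(source, state)` interface returning `(logits, new_state)` is a
hypothetical stand-in for Transformer-XL's memory mechanism:

```python
import torch.nn.functional as F

def train_truncated_bptt(model, optimizer, tokens, chunk_size):
    """Truncated BPTT over one long 1-D LongTensor of tokens (illustrative).

    Assumes `model(source, state)` returns (logits, new_state), where
    `state` is a list of tensors caching activations from previous chunks.
    """
    state = None
    for i in range(0, tokens.size(0) - chunk_size, chunk_size):
        source = tokens[i : i + chunk_size]
        target = tokens[i + 1 : i + 1 + chunk_size]

        # The model may attend to `state` (previous chunks), but the
        # backward pass below only sees the current chunk's graph.
        logits, state = model(source, state)
        loss = F.cross_entropy(logits, target)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Detach the carried-over state so the next chunk conditions on it
        # without backpropagating into previous chunks.
        if state is not None:
            state = [s.detach() for s in state]
```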

It is slightly tricky to implement Truncated BPTT efficiently in fairseq, since
we need to iterate over the data sequentially and disable any batch shuffling
logic. The code provided in this example illustrates how to implement Truncated
BPTT in fairseq by overriding ``FairseqTask::get_batch_iterator`` to iterate
over the data sequentially. Crucially, this example supports batching and
multi-GPU (data parallel) training.
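
The actual implementation lives in the code under `examples/truncated_bptt/`.
As a rough sketch of the batching scheme (again illustrative, not fairseq's
real iterator), the flat token stream is reshaped into `batch_size` parallel
streams, and each batch takes the next `tokens_per_sample` slice of every
stream, so that row `b` of batch `t+1` continues exactly where row `b` of
batch `t` left off:

```python
import torch

def sequential_batches(tokens, batch_size, tokens_per_sample):
    """Yield (source, target) batches that traverse the corpus in order.

    Hypothetical sketch; the real code overrides
    FairseqTask::get_batch_iterator and also handles sharding across GPUs.
    """
    # Trim and reshape the flat 1-D token stream into `batch_size`
    # contiguous streams, one per batch row.
    n = tokens.size(0) // batch_size
    streams = tokens[: n * batch_size].view(batch_size, n)
    # Walk the streams left to right with no shuffling, so the model's
    # memory from batch t is valid context for batch t+1.
    for i in range(0, n - tokens_per_sample, tokens_per_sample):
        source = streams[:, i : i + tokens_per_sample]
        target = streams[:, i + 1 : i + 1 + tokens_per_sample]
        yield source, target
```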

##### 0. Setup

First, see the general [language modeling README](README.md) for instructions on
preprocessing the WikiText-103 data.

##### 1. Train a Transformer-XL model on WikiText-103

We will train a 16-layer Transformer-XL model following the [hyperparameters
used in the original
paper](https://github.com/kimiyoung/transformer-xl/blob/master/pytorch/run_wt103_base.sh).
The following command assumes 4 GPUs, so that the total batch size is 60
sequences (15 x 4). Training should take ~24 hours on 4 V100 GPUs:
```bash
CUDA_VISIBLE_DEVICES=0,1,2,3 fairseq-train \
    --user-dir examples/truncated_bptt \
    data-bin/wikitext-103/ \
    --task truncated_bptt_lm --tokens-per-sample 150 \
    --batch-size 15 --max-update 200000 \
    --arch transformer_xl --n-layer 16 --d-model 410 --n-head 10 \
    --d-head 41 --d-inner 2100 --dropout 0.1 --dropatt 0.0 --mem-len 150 \
    --optimizer adam --clip-norm 0.25 \
    --lr-scheduler cosine --warmup-updates 0 --min-lr 0.0 --lr 0.00025 \
    --log-format json --log-interval 25 \
    --fp16
```

If training on a single GPU, set `--update-freq=4` to accumulate 4x gradients
and simulate training on 4 GPUs.
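
Under the hood, `--update-freq` is standard gradient accumulation. A generic
PyTorch-style sketch of the idea (not fairseq's trainer code; `batches` and
`loss_fn` are hypothetical placeholders):

```python
def train_with_accumulation(model, optimizer, batches, loss_fn, update_freq=4):
    """Accumulate gradients over `update_freq` mini-batches per optimizer
    step, matching the effective batch size of `update_freq` GPUs."""
    optimizer.zero_grad()
    for step, (source, target) in enumerate(batches, start=1):
        # Scale so the summed gradients average over the virtual batch.
        loss = loss_fn(model(source), target) / update_freq
        loss.backward()  # gradients accumulate in .grad across iterations
        if step % update_freq == 0:
            optimizer.step()
            optimizer.zero_grad()
```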

##### 2. Evaluate

```bash
fairseq-eval-lm data-bin/wikitext-103/ \
    --path checkpoints/checkpoint_best.pt \
    --user-dir examples/truncated_bptt/ \
    --task truncated_bptt_lm \
    --batch-size 1 --required-batch-size-multiple 1 \
    --model-overrides '{"mem_len":640,"clamp_len":400,"same_length":True}' \
    --tokens-per-sample 64
# ... | INFO | fairseq_cli.eval_lm | num. model params: 151123537
# ... | INFO | fairseq_cli.eval_lm | Evaluated 245569 tokens in 83.1s (2956.82 tokens/s)
# ... | INFO | fairseq_cli.eval_lm | Loss (base 2): 4.5668, Perplexity: 23.70
# Compare to 24.0 test perplexity from the paper
```

*Note:* During training the model saw 150 tokens of context
(``--tokens-per-sample=150``) and 150 extra memory tokens (``--mem-len=150``).
During evaluation we measure perplexity on sequences of 64 tokens
(``--tokens-per-sample=64``) and increase the memory length
(``--model-overrides='{"mem_len":640}'``). These settings match the evaluation
settings from [the original
paper](https://github.com/kimiyoung/transformer-xl/blob/master/pytorch/run_wt103_base.sh).