Add fairseq
- fairseq/.github/ISSUE_TEMPLATE.md +3 -0
- fairseq/.github/ISSUE_TEMPLATE/bug_report.md +43 -0
- fairseq/.github/ISSUE_TEMPLATE/documentation.md +15 -0
- fairseq/.github/ISSUE_TEMPLATE/feature_request.md +24 -0
- fairseq/.github/ISSUE_TEMPLATE/how-to-question.md +33 -0
- fairseq/.github/PULL_REQUEST_TEMPLATE.md +16 -0
- fairseq/.github/stale.yml +30 -0
- fairseq/.github/workflows/build.yml +55 -0
- fairseq/.github/workflows/build_wheels.yml +41 -0
- fairseq/.gitignore +136 -0
- fairseq/.gitmodules +4 -0
- fairseq/CODE_OF_CONDUCT.md +77 -0
- fairseq/CONTRIBUTING.md +28 -0
- fairseq/LICENSE +21 -0
- fairseq/README.md +229 -0
- fairseq/build/lib.linux-x86_64-cpython-39/fairseq/version.py +1 -1
- fairseq/docs/Makefile +20 -0
- fairseq/docs/_static/theme_overrides.css +9 -0
- fairseq/docs/command_line_tools.rst +85 -0
- fairseq/docs/conf.py +134 -0
- fairseq/docs/criterions.rst +31 -0
- fairseq/docs/data.rst +58 -0
- fairseq/docs/docutils.conf +2 -0
- fairseq/docs/fairseq_logo.png +0 -0
- fairseq/docs/getting_started.rst +216 -0
- fairseq/docs/hydra_integration.md +284 -0
- fairseq/docs/index.rst +49 -0
- fairseq/docs/lr_scheduler.rst +34 -0
- fairseq/docs/make.bat +36 -0
- fairseq/docs/models.rst +104 -0
- fairseq/docs/modules.rst +9 -0
- fairseq/docs/optim.rst +38 -0
- fairseq/docs/overview.rst +74 -0
- fairseq/docs/requirements.txt +2 -0
- fairseq/docs/tasks.rst +61 -0
- fairseq/docs/tutorial_classifying_names.rst +415 -0
- fairseq/docs/tutorial_simple_lstm.rst +518 -0
- fairseq/examples/.gitignore +2 -0
- fairseq/examples/__init__.py +9 -0
- fairseq/examples/adaptive_span/README.md +90 -0
- fairseq/examples/adaptive_span/__init__.py +19 -0
- fairseq/examples/adaptive_span/adagrad_with_grad_clip.py +128 -0
- fairseq/examples/adaptive_span/adaptive_span_attention.py +160 -0
- fairseq/examples/adaptive_span/adaptive_span_loss.py +106 -0
- fairseq/examples/adaptive_span/adaptive_span_model.py +263 -0
- fairseq/examples/adaptive_span/adaptive_span_model_wrapper.py +145 -0
- fairseq/examples/adaptive_span/truncated_bptt_lm_task.py +281 -0
- fairseq/examples/backtranslation/README.md +297 -0
- fairseq/examples/backtranslation/deduplicate_lines.py +41 -0
- fairseq/examples/backtranslation/extract_bt_data.py +72 -0
fairseq/.github/ISSUE_TEMPLATE.md
ADDED
@@ -0,0 +1,3 @@
+## 👉 [Please follow one of these issue templates](https://github.com/pytorch/fairseq/issues/new/choose) 👈
+
+Note: to keep the backlog clean and actionable, issues may be immediately closed if they do not follow one of the above issue templates.
fairseq/.github/ISSUE_TEMPLATE/bug_report.md
ADDED
@@ -0,0 +1,43 @@
+---
+name: 🐛 Bug Report
+about: Submit a bug report to help us improve
+labels: 'bug, needs triage'
+---
+
+## 🐛 Bug
+
+<!-- A clear and concise description of what the bug is. -->
+
+### To Reproduce
+
+Steps to reproduce the behavior (**always include the command you ran**):
+
+1. Run cmd '....'
+2. See error
+
+<!-- If you have a code sample, error messages, stack traces, please provide it here as well -->
+
+
+#### Code sample
+<!-- Ideally attach a minimal code sample to reproduce the described issue.
+Minimal means having the shortest code but still preserving the bug. -->
+
+### Expected behavior
+
+<!-- A clear and concise description of what you expected to happen. -->
+
+### Environment
+
+- fairseq Version (e.g., 1.0 or main):
+- PyTorch Version (e.g., 1.0):
+- OS (e.g., Linux):
+- How you installed fairseq (`pip`, source):
+- Build command you used (if compiling from source):
+- Python version:
+- CUDA/cuDNN version:
+- GPU models and configuration:
+- Any other relevant information:
+
+### Additional context
+
+<!-- Add any other context about the problem here. -->
fairseq/.github/ISSUE_TEMPLATE/documentation.md
ADDED
@@ -0,0 +1,15 @@
+---
+name: 📚 Documentation/Typos
+about: Report an issue related to documentation or a typo
+labels: 'documentation, needs triage'
+---
+
+## 📚 Documentation
+
+For typos and doc fixes, please go ahead and:
+
+1. Create an issue.
+2. Fix the typo.
+3. Submit a PR.
+
+Thanks!
fairseq/.github/ISSUE_TEMPLATE/feature_request.md
ADDED
@@ -0,0 +1,24 @@
+---
+name: 🚀 Feature Request
+about: Submit a proposal/request for a new feature
+labels: 'enhancement, help wanted, needs triage'
+---
+
+## 🚀 Feature Request
+<!-- A clear and concise description of the feature proposal -->
+
+### Motivation
+
+<!-- Please outline the motivation for the proposal. Is your feature request related to a problem? e.g., I'm always frustrated when [...]. If this is related to another GitHub issue, please link here too -->
+
+### Pitch
+
+<!-- A clear and concise description of what you want to happen. -->
+
+### Alternatives
+
+<!-- A clear and concise description of any alternative solutions or features you've considered, if any. -->
+
+### Additional context
+
+<!-- Add any other context or screenshots about the feature request here. -->
fairseq/.github/ISSUE_TEMPLATE/how-to-question.md
ADDED
@@ -0,0 +1,33 @@
+---
+name: ❓ Questions/Help
+about: If you have questions, please first search existing issues and docs
+labels: 'question, needs triage'
+---
+
+## ❓ Questions and Help
+
+### Before asking:
+1. Search the issues.
+2. Search the docs.
+
+<!-- If you still can't find what you need: -->
+
+#### What is your question?
+
+#### Code
+
+<!-- Please paste a code snippet if your question requires it! -->
+
+#### What have you tried?
+
+#### What's your environment?
+
+- fairseq Version (e.g., 1.0 or main):
+- PyTorch Version (e.g., 1.0):
+- OS (e.g., Linux):
+- How you installed fairseq (`pip`, source):
+- Build command you used (if compiling from source):
+- Python version:
+- CUDA/cuDNN version:
+- GPU models and configuration:
+- Any other relevant information:
fairseq/.github/PULL_REQUEST_TEMPLATE.md
ADDED
@@ -0,0 +1,16 @@
+# Before submitting
+
+- [ ] Was this discussed/approved via a GitHub issue? (not needed for typos or doc improvements)
+- [ ] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/main/CONTRIBUTING.md)?
+- [ ] Did you make sure to update the docs?
+- [ ] Did you write any new necessary tests?
+
+## What does this PR do?
+Fixes # (issue).
+
+## PR review
+Anyone in the community is free to review the PR once the tests have passed.
+If we didn't discuss your PR in GitHub issues, there's a high chance it will not be merged.
+
+## Did you have fun?
+Make sure you had fun coding 🙃
fairseq/.github/stale.yml
ADDED
@@ -0,0 +1,30 @@
+# Configuration for probot-stale - https://github.com/probot/stale
+# Mostly copied from github.com/facebook/react/blob/master/.github/stale.yml
+# Number of days of inactivity before an issue becomes stale
+daysUntilStale: 90
+# Number of days of inactivity before a stale issue is closed
+daysUntilClose: 7
+# Issues with these labels will never be considered stale
+exemptLabels:
+  - bug
+# Label to use when marking an issue as stale
+staleLabel: stale
+issues:
+  # Comment to post when marking an issue as stale.
+  markComment: >
+    This issue has been automatically marked as stale.
+    **If this issue is still affecting you, please leave any comment** (for example, "bump"), and we'll keep it open.
+    We are sorry that we haven't been able to prioritize it yet. If you have any new additional information, please include it with your comment!
+  # Comment to post when closing a stale issue.
+  closeComment: >
+    Closing this issue after a prolonged period of inactivity. If this issue is still present in the latest release, please create a new issue with up-to-date information. Thank you!
+pulls:
+  # Comment to post when marking a pull request as stale.
+  markComment: >
+    This pull request has been automatically marked as stale.
+    **If this pull request is still relevant, please leave any comment** (for example, "bump"), and we'll keep it open.
+    We are sorry that we haven't been able to prioritize reviewing it yet. Your contribution is very much appreciated.
+  # Comment to post when closing a stale pull request.
+  closeComment: >
+    Closing this pull request after a prolonged period of inactivity. If this issue is still present in the latest release, please ask for this pull request to be reopened. Thank you!
+
fairseq/.github/workflows/build.yml
ADDED
@@ -0,0 +1,55 @@
+name: build
+
+on:
+  # Trigger the workflow on push to main or any pull request
+  push:
+    branches:
+      - main
+  pull_request:
+
+jobs:
+  build:
+
+    strategy:
+      max-parallel: 4
+      matrix:
+        platform: [ubuntu-latest, macos-latest]
+        python-version: [3.6, 3.7]
+
+    runs-on: ${{ matrix.platform }}
+
+    steps:
+    - uses: actions/checkout@v2
+
+    - name: Set up Python ${{ matrix.python-version }}
+      uses: actions/setup-python@v2
+      with:
+        python-version: ${{ matrix.python-version }}
+
+    - name: Conditionally install pytorch
+      if: matrix.platform == 'windows-latest'
+      run: pip3 install torch -f https://download.pytorch.org/whl/torch_stable.html
+
+    - name: Install locally
+      run: |
+        python -m pip install --upgrade pip
+        git submodule update --init --recursive
+        python setup.py build_ext --inplace
+        python -m pip install --editable .
+
+    - name: Install optional test requirements
+      run: |
+        python -m pip install iopath transformers pyarrow
+        python -m pip install git+https://github.com/facebookresearch/fairscale.git@main
+
+    - name: Lint with flake8
+      run: |
+        pip install flake8
+        # stop the build if there are Python syntax errors or undefined names
+        flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics --extend-exclude fairseq/model_parallel/megatron
+        # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
+        flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics --extend-exclude fairseq/model_parallel/megatron
+
+    - name: Run tests
+      run: |
+        python setup.py test
fairseq/.github/workflows/build_wheels.yml
ADDED
@@ -0,0 +1,41 @@
+name: build_wheels
+
+on:
+  push:
+    branches:
+      - v[0-9]+.[0-9]+.[x0-9]+
+    tags:
+      - v*
+
+jobs:
+  build_wheels:
+    name: Build wheels on ${{ matrix.os }}
+    runs-on: ${{ matrix.os }}
+    strategy:
+      matrix:
+        os: [ubuntu-latest, macos-latest]
+
+    steps:
+      - uses: actions/checkout@v2
+
+      - name: Install Python
+        uses: actions/setup-python@v2
+        with:
+          python-version: '3.7'
+
+      - name: Install cibuildwheel
+        run: |
+          python -m pip install cibuildwheel
+
+      - name: Build wheels for CPython
+        run: |
+          python -m cibuildwheel --output-dir dist
+        env:
+          CIBW_BUILD: "cp36-*64 cp37-*64 cp38-*64"
+          CIBW_MANYLINUX_X86_64_IMAGE: manylinux1
+          CIBW_BEFORE_BUILD: git submodule update --init --recursive && pip install .
+
+      - uses: actions/upload-artifact@v2
+        with:
+          name: wheels
+          path: ./dist/*.whl
fairseq/.gitignore
ADDED
@@ -0,0 +1,136 @@
+# JetBrains PyCharm IDE
+.idea/
+
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# macOS dir files
+.DS_Store
+
+# Distribution / packaging
+.Python
+env/
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+
+# Checkpoints
+checkpoints
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+.hypothesis/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# pyenv
+.python-version
+
+# celery beat schedule file
+celerybeat-schedule
+
+# SageMath parsed files
+*.sage.py
+
+# dotenv
+.env
+
+# virtualenv
+.venv
+venv/
+ENV/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+
+# Generated files
+/fairseq/temporal_convolution_tbc
+/fairseq/modules/*_layer/*_forward.cu
+/fairseq/modules/*_layer/*_backward.cu
+/fairseq/version.py
+
+# data
+data-bin/
+
+# reranking
+/examples/reranking/rerank_data
+
+# Cython-generated C++ source files
+/fairseq/data/data_utils_fast.cpp
+/fairseq/data/token_block_utils_fast.cpp
+
+# VSCODE
+.vscode/ftp-sync.json
+.vscode/settings.json
+
+# Experimental Folder
+experimental/*
+
+# Weights and Biases logs
+wandb/
fairseq/.gitmodules
ADDED
@@ -0,0 +1,4 @@
+[submodule "fairseq/model_parallel/megatron"]
+	path = fairseq/model_parallel/megatron
+	url = https://github.com/ngoyal2707/Megatron-LM
+	branch = fairseq
fairseq/CODE_OF_CONDUCT.md
ADDED
@@ -0,0 +1,77 @@
+# Code of Conduct
+
+## Our Pledge
+
+In the interest of fostering an open and welcoming environment, we as
+contributors and maintainers pledge to make participation in our project and
+our community a harassment-free experience for everyone, regardless of age, body
+size, disability, ethnicity, sex characteristics, gender identity and expression,
+level of experience, education, socio-economic status, nationality, personal
+appearance, race, religion, or sexual identity and orientation.
+
+## Our Standards
+
+Examples of behavior that contributes to creating a positive environment
+include:
+
+* Using welcoming and inclusive language
+* Being respectful of differing viewpoints and experiences
+* Gracefully accepting constructive criticism
+* Focusing on what is best for the community
+* Showing empathy towards other community members
+
+Examples of unacceptable behavior by participants include:
+
+* The use of sexualized language or imagery and unwelcome sexual attention or
+  advances
+* Trolling, insulting/derogatory comments, and personal or political attacks
+* Public or private harassment
+* Publishing others' private information, such as a physical or electronic
+  address, without explicit permission
+* Other conduct which could reasonably be considered inappropriate in a
+  professional setting
+
+## Our Responsibilities
+
+Project maintainers are responsible for clarifying the standards of acceptable
+behavior and are expected to take appropriate and fair corrective action in
+response to any instances of unacceptable behavior.
+
+Project maintainers have the right and responsibility to remove, edit, or
+reject comments, commits, code, wiki edits, issues, and other contributions
+that are not aligned to this Code of Conduct, or to ban temporarily or
+permanently any contributor for other behaviors that they deem inappropriate,
+threatening, offensive, or harmful.
+
+## Scope
+
+This Code of Conduct applies within all project spaces, and it also applies when
+an individual is representing the project or its community in public spaces.
+Examples of representing a project or community include using an official
+project e-mail address, posting via an official social media account, or acting
+as an appointed representative at an online or offline event. Representation of
+a project may be further defined and clarified by project maintainers.
+
+## Enforcement
+
+Instances of abusive, harassing, or otherwise unacceptable behavior may be
+reported by contacting the project team at <conduct@pytorch.org>. All
+complaints will be reviewed and investigated and will result in a response that
+is deemed necessary and appropriate to the circumstances. The project team is
+obligated to maintain confidentiality with regard to the reporter of an incident.
+Further details of specific enforcement policies may be posted separately.
+
+Project maintainers who do not follow or enforce the Code of Conduct in good
+faith may face temporary or permanent repercussions as determined by other
+members of the project's leadership.
+
+## Attribution
+
+This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4,
+available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html
+
+[homepage]: https://www.contributor-covenant.org
+
+For answers to common questions about this code of conduct, see
+https://www.contributor-covenant.org/faq
+
fairseq/CONTRIBUTING.md
ADDED
@@ -0,0 +1,28 @@
+# Contributing to Facebook AI Research Sequence-to-Sequence Toolkit (fairseq)
+We want to make contributing to this project as easy and transparent as
+possible.
+
+## Pull Requests
+We actively welcome your pull requests.
+
+1. Fork the repo and create your branch from `main`.
+2. If you've added code that should be tested, add tests.
+3. If you've changed APIs, update the documentation.
+4. Ensure the test suite passes.
+5. Make sure your code lints.
+6. If you haven't already, complete the Contributor License Agreement ("CLA").
+
+## Contributor License Agreement ("CLA")
+In order to accept your pull request, we need you to submit a CLA. You only need
+to do this once to work on any of Facebook's open source projects.
+
+Complete your CLA here: <https://code.facebook.com/cla>
+
+## Issues
+We use GitHub issues to track public bugs. Please ensure your description is
+clear and has sufficient instructions to be able to reproduce the issue.
+
+## License
+By contributing to Facebook AI Research Sequence-to-Sequence Toolkit (fairseq),
+you agree that your contributions will be licensed under the LICENSE file in
+the root directory of this source tree.
fairseq/LICENSE
ADDED
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) Facebook, Inc. and its affiliates.
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
fairseq/README.md
ADDED
@@ -0,0 +1,229 @@
+<p align="center">
+  <img src="docs/fairseq_logo.png" width="150">
+  <br />
+  <br />
+  <a href="https://github.com/pytorch/fairseq/blob/main/LICENSE"><img alt="MIT License" src="https://img.shields.io/badge/license-MIT-blue.svg" /></a>
+  <a href="https://github.com/pytorch/fairseq/releases"><img alt="Latest Release" src="https://img.shields.io/github/release/pytorch/fairseq.svg" /></a>
+  <a href="https://github.com/pytorch/fairseq/actions?query=workflow:build"><img alt="Build Status" src="https://github.com/pytorch/fairseq/workflows/build/badge.svg" /></a>
+  <a href="https://fairseq.readthedocs.io/en/latest/?badge=latest"><img alt="Documentation Status" src="https://readthedocs.org/projects/fairseq/badge/?version=latest" /></a>
+</p>
+
+--------------------------------------------------------------------------------
+
+Fairseq(-py) is a sequence modeling toolkit that allows researchers and
+developers to train custom models for translation, summarization, language
+modeling and other text generation tasks.
+
+We provide reference implementations of various sequence modeling papers:
+
+<details><summary>List of implemented papers</summary><p>
+
+* **Convolutional Neural Networks (CNN)**
+  + [Language Modeling with Gated Convolutional Networks (Dauphin et al., 2017)](examples/language_model/conv_lm/README.md)
+  + [Convolutional Sequence to Sequence Learning (Gehring et al., 2017)](examples/conv_seq2seq/README.md)
+  + [Classical Structured Prediction Losses for Sequence to Sequence Learning (Edunov et al., 2018)](https://github.com/pytorch/fairseq/tree/classic_seqlevel)
+  + [Hierarchical Neural Story Generation (Fan et al., 2018)](examples/stories/README.md)
+  + [wav2vec: Unsupervised Pre-training for Speech Recognition (Schneider et al., 2019)](examples/wav2vec/README.md)
+* **LightConv and DynamicConv models**
+  + [Pay Less Attention with Lightweight and Dynamic Convolutions (Wu et al., 2019)](examples/pay_less_attention_paper/README.md)
+* **Long Short-Term Memory (LSTM) networks**
+  + Effective Approaches to Attention-based Neural Machine Translation (Luong et al., 2015)
+* **Transformer (self-attention) networks**
+  + Attention Is All You Need (Vaswani et al., 2017)
+  + [Scaling Neural Machine Translation (Ott et al., 2018)](examples/scaling_nmt/README.md)
+  + [Understanding Back-Translation at Scale (Edunov et al., 2018)](examples/backtranslation/README.md)
+  + [Adaptive Input Representations for Neural Language Modeling (Baevski and Auli, 2018)](examples/language_model/README.adaptive_inputs.md)
+  + [Lexically constrained decoding with dynamic beam allocation (Post & Vilar, 2018)](examples/constrained_decoding/README.md)
+  + [Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context (Dai et al., 2019)](examples/truncated_bptt/README.md)
+  + [Adaptive Attention Span in Transformers (Sukhbaatar et al., 2019)](examples/adaptive_span/README.md)
+  + [Mixture Models for Diverse Machine Translation: Tricks of the Trade (Shen et al., 2019)](examples/translation_moe/README.md)
+  + [RoBERTa: A Robustly Optimized BERT Pretraining Approach (Liu et al., 2019)](examples/roberta/README.md)
+  + [Facebook FAIR's WMT19 News Translation Task Submission (Ng et al., 2019)](examples/wmt19/README.md)
+  + [Jointly Learning to Align and Translate with Transformer Models (Garg et al., 2019)](examples/joint_alignment_translation/README.md)
+  + [Multilingual Denoising Pre-training for Neural Machine Translation (Liu et al., 2020)](examples/mbart/README.md)
+  + [Neural Machine Translation with Byte-Level Subwords (Wang et al., 2020)](examples/byte_level_bpe/README.md)
+  + [Unsupervised Quality Estimation for Neural Machine Translation (Fomicheva et al., 2020)](examples/unsupervised_quality_estimation/README.md)
+  + [wav2vec 2.0: A Framework for Self-Supervised Learning of Speech Representations (Baevski et al., 2020)](examples/wav2vec/README.md)
+  + [Generating Medical Reports from Patient-Doctor Conversations Using Sequence-to-Sequence Models (Enarvi et al., 2020)](examples/pointer_generator/README.md)
+  + [Linformer: Self-Attention with Linear Complexity (Wang et al., 2020)](examples/linformer/README.md)
+  + [Cross-lingual Retrieval for Iterative Self-Supervised Training (Tran et al., 2020)](examples/criss/README.md)
+  + [Deep Transformers with Latent Depth (Li et al., 2020)](examples/latent_depth/README.md)
+  + [Unsupervised Cross-lingual Representation Learning for Speech Recognition (Conneau et al., 2020)](https://arxiv.org/abs/2006.13979)
+  + [Robust wav2vec 2.0: Analyzing Domain Shift in Self-Supervised Pre-Training (Hsu et al., 2021)](https://arxiv.org/abs/2104.01027)
+  + [Unsupervised Speech Recognition (Baevski et al., 2021)](https://arxiv.org/abs/2105.11084)
+* **Non-autoregressive Transformers**
+  + Non-Autoregressive Neural Machine Translation (Gu et al., 2017)
+  + Deterministic Non-Autoregressive Neural Sequence Modeling by Iterative Refinement (Lee et al., 2018)
+  + Insertion Transformer: Flexible Sequence Generation via Insertion Operations (Stern et al., 2019)
+  + Mask-Predict: Parallel Decoding of Conditional Masked Language Models (Ghazvininejad et al., 2019)
+  + [Levenshtein Transformer (Gu et al., 2019)](examples/nonautoregressive_translation/README.md)
+* **Finetuning**
+  + [Better Fine-Tuning by Reducing Representational Collapse (Aghajanyan et al., 2020)](examples/rxf/README.md)
+
+</p></details>
+
+### What's New:
+
+* September 2021 [`master` branch renamed to `main`](https://github.com/github/renaming).
+* July 2021 [Released DrNMT code](examples/discriminative_reranking_nmt/README.md)
+* July 2021 [Released Robust wav2vec 2.0 model](examples/wav2vec/README.md)
+* June 2021 [Released XLMR-XL and XLMR-XXL models](examples/xlmr/README.md)
+* May 2021 [Released Unsupervised Speech Recognition code](examples/wav2vec/unsupervised/README.md)
+* March 2021 [Added full parameter and optimizer state sharding + CPU offloading](examples/fully_sharded_data_parallel/README.md)
+* February 2021 [Added LASER training code](examples/laser/README.md)
+* December 2020: [Added Adaptive Attention Span code](examples/adaptive_span/README.md)
+* December 2020: [GottBERT model and code released](examples/gottbert/README.md)
+* November 2020: Adopted the [Hydra](https://github.com/facebookresearch/hydra) configuration framework
+  * [see documentation explaining how to use it for new and existing projects](docs/hydra_integration.md)
+* November 2020: [fairseq 0.10.0 released](https://github.com/pytorch/fairseq/releases/tag/v0.10.0)
+* October 2020: [Added R3F/R4F (Better Fine-Tuning) code](examples/rxf/README.md)
+* October 2020: [Deep Transformer with Latent Depth code released](examples/latent_depth/README.md)
+* October 2020: [Added CRISS models and code](examples/criss/README.md)
+
+<details><summary>Previous updates</summary><p>
+
+* September 2020: [Added Linformer code](examples/linformer/README.md)
+* September 2020: [Added pointer-generator networks](examples/pointer_generator/README.md)
+* August 2020: [Added lexically constrained decoding](examples/constrained_decoding/README.md)
+* August 2020: [wav2vec2 models and code released](examples/wav2vec/README.md)
+* July 2020: [Unsupervised Quality Estimation code released](examples/unsupervised_quality_estimation/README.md)
+* May 2020: [Follow fairseq on Twitter](https://twitter.com/fairseq)
+* April 2020: [Monotonic Multihead Attention code released](examples/simultaneous_translation/README.md)
+* April 2020: [Quant-Noise code released](examples/quant_noise/README.md)
+* April 2020: [Initial model parallel support and 11B parameters unidirectional LM released](examples/megatron_11b/README.md)
+* March 2020: [Byte-level BPE code released](examples/byte_level_bpe/README.md)
+* February 2020: [mBART model and code released](examples/mbart/README.md)
+* February 2020: [Added tutorial for back-translation](https://github.com/pytorch/fairseq/tree/main/examples/backtranslation#training-your-own-model-wmt18-english-german)
+* December 2019: [fairseq 0.9.0 released](https://github.com/pytorch/fairseq/releases/tag/v0.9.0)
+* November 2019: [VizSeq released (a visual analysis toolkit for evaluating fairseq models)](https://facebookresearch.github.io/vizseq/docs/getting_started/fairseq_example)
+* November 2019: [CamemBERT model and code released](examples/camembert/README.md)
+* November 2019: [BART model and code released](examples/bart/README.md)
+* November 2019: [XLM-R models and code released](examples/xlmr/README.md)
+* September 2019: [Nonautoregressive translation code released](examples/nonautoregressive_translation/README.md)
+* August 2019: [WMT'19 models released](examples/wmt19/README.md)
+* July 2019: fairseq relicensed under MIT license
+* July 2019: [RoBERTa models and code released](examples/roberta/README.md)
+* June 2019: [wav2vec models and code released](examples/wav2vec/README.md)
+
+</p></details>
+
+### Features:
+
+* multi-GPU training on one machine or across multiple machines (data and model parallel)
+* fast generation on both CPU and GPU with multiple search algorithms implemented:
+  + beam search
+  + Diverse Beam Search ([Vijayakumar et al., 2016](https://arxiv.org/abs/1610.02424))
+  + sampling (unconstrained, top-k and top-p/nucleus)
+  + [lexically constrained decoding](examples/constrained_decoding/README.md) (Post & Vilar, 2018)
+* [gradient accumulation](https://fairseq.readthedocs.io/en/latest/getting_started.html#large-mini-batch-training-with-delayed-updates) enables training with large mini-batches even on a single GPU
+* [mixed precision training](https://fairseq.readthedocs.io/en/latest/getting_started.html#training-with-half-precision-floating-point-fp16) (trains faster with less GPU memory on [NVIDIA tensor cores](https://developer.nvidia.com/tensor-cores))
+* [extensible](https://fairseq.readthedocs.io/en/latest/overview.html): easily register new models, criterions, tasks, optimizers and learning rate schedulers
+* [flexible configuration](docs/hydra_integration.md) based on [Hydra](https://github.com/facebookresearch/hydra) allowing a combination of code, command-line and file based configuration
+* [full parameter and optimizer state sharding](examples/fully_sharded_data_parallel/README.md)
+* [offloading parameters to CPU](examples/fully_sharded_data_parallel/README.md)
+
+We also provide [pre-trained models for translation and language modeling](#pre-trained-models-and-examples)
+with a convenient `torch.hub` interface:
+
+``` python
+en2de = torch.hub.load('pytorch/fairseq', 'transformer.wmt19.en-de.single_model')
+en2de.translate('Hello world', beam=5)
+# 'Hallo Welt'
+```
+
+See the PyTorch Hub tutorials for [translation](https://pytorch.org/hub/pytorch_fairseq_translation/)
+and [RoBERTa](https://pytorch.org/hub/pytorch_fairseq_roberta/) for more examples.
+
+# Requirements and Installation
+
+* [PyTorch](http://pytorch.org/) version >= 1.5.0
+* Python version >= 3.6
+* For training new models, you'll also need an NVIDIA GPU and [NCCL](https://github.com/NVIDIA/nccl)
+* **To install fairseq** and develop locally:
+
+``` bash
+git clone https://github.com/pytorch/fairseq
+cd fairseq
+pip install --editable ./
+
+# on macOS:
+# CFLAGS="-stdlib=libc++" pip install --editable ./
+
+# to install the latest stable release (0.10.x)
+# pip install fairseq
+```
+
+* **For faster training** install NVIDIA's [apex](https://github.com/NVIDIA/apex) library:
+
+``` bash
+git clone https://github.com/NVIDIA/apex
+cd apex
+pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" \
+  --global-option="--deprecated_fused_adam" --global-option="--xentropy" \
+  --global-option="--fast_multihead_attn" ./
+```
+
+* **For large datasets** install [PyArrow](https://arrow.apache.org/docs/python/install.html#using-pip): `pip install pyarrow`
+* If you use Docker make sure to increase the shared memory size either with `--ipc=host` or `--shm-size`
+  as command line options to `nvidia-docker run`.
+
+# Getting Started
+
+The [full documentation](https://fairseq.readthedocs.io/) contains instructions
+for getting started, training new models and extending fairseq with new model
+types and tasks.
+
+# Pre-trained models and examples
+
+We provide pre-trained models and pre-processed, binarized test sets for several tasks listed below,
+as well as example training and evaluation commands.
+
+* [Translation](examples/translation/README.md): convolutional and transformer models are available
+* [Language Modeling](examples/language_model/README.md): convolutional and transformer models are available
+
+We also have more detailed READMEs to reproduce results from specific papers:
+
+* [Cross-lingual Retrieval for Iterative Self-Supervised Training (Tran et al., 2020)](examples/criss/README.md)
+* [wav2vec 2.0: A Framework for Self-Supervised Learning of Speech Representations (Baevski et al., 2020)](examples/wav2vec/README.md)
+* [Unsupervised Quality Estimation for Neural Machine Translation (Fomicheva et al., 2020)](examples/unsupervised_quality_estimation/README.md)
+* [Training with Quantization Noise for Extreme Model Compression ({Fan*, Stock*} et al., 2020)](examples/quant_noise/README.md)
+* [Neural Machine Translation with Byte-Level Subwords (Wang et al., 2020)](examples/byte_level_bpe/README.md)
+* [Multilingual Denoising Pre-training for Neural Machine Translation (Liu et al., 2020)](examples/mbart/README.md)
+* [Reducing Transformer Depth on Demand with Structured Dropout (Fan et al., 2019)](examples/layerdrop/README.md)
+* [Jointly Learning to Align and Translate with Transformer Models (Garg et al., 2019)](examples/joint_alignment_translation/README.md)
+* [Levenshtein Transformer (Gu et al., 2019)](examples/nonautoregressive_translation/README.md)
+* [Facebook FAIR's WMT19 News Translation Task Submission (Ng et al., 2019)](examples/wmt19/README.md)
+* [RoBERTa: A Robustly Optimized BERT Pretraining Approach (Liu et al., 2019)](examples/roberta/README.md)
+* [wav2vec: Unsupervised Pre-training for Speech Recognition (Schneider et al., 2019)](examples/wav2vec/README.md)
+* [Mixture Models for Diverse Machine Translation: Tricks of the Trade (Shen et al., 2019)](examples/translation_moe/README.md)
+* [Pay Less Attention with Lightweight and Dynamic Convolutions (Wu et al., 2019)](examples/pay_less_attention_paper/README.md)
+* [Understanding Back-Translation at Scale (Edunov et al., 2018)](examples/backtranslation/README.md)
+* [Classical Structured Prediction Losses for Sequence to Sequence Learning (Edunov et al., 2018)](https://github.com/pytorch/fairseq/tree/classic_seqlevel)
+* [Hierarchical Neural Story Generation (Fan et al., 2018)](examples/stories/README.md)
+* [Scaling Neural Machine Translation (Ott et al., 2018)](examples/scaling_nmt/README.md)
+* [Convolutional Sequence to Sequence Learning (Gehring et al., 2017)](examples/conv_seq2seq/README.md)
+* [Language Modeling with Gated Convolutional Networks (Dauphin et al., 2017)](examples/language_model/README.conv.md)
+
+# Join the fairseq community
+
+* Twitter: https://twitter.com/fairseq
+* Facebook page: https://www.facebook.com/groups/fairseq.users
+* Google group: https://groups.google.com/forum/#!forum/fairseq-users
+
+# License
+
+fairseq(-py) is MIT-licensed.
+The license applies to the pre-trained models as well.
+
+# Citation
+
+Please cite as:
+
+``` bibtex
+@inproceedings{ott2019fairseq,
+  title = {fairseq: A Fast, Extensible Toolkit for Sequence Modeling},
+  author = {Myle Ott and Sergey Edunov and Alexei Baevski and Angela Fan and Sam Gross and Nathan Ng and David Grangier and Michael Auli},
+  booktitle = {Proceedings of NAACL-HLT 2019: Demonstrations},
+  year = {2019},
+}
+```
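The `torch.hub` snippet in the README above extends naturally to the pretrained RoBERTa models it mentions. A minimal sketch, assuming PyTorch and the fairseq hub dependencies are installed and the `roberta.base` checkpoint can be downloaded on first use; the printed shape is indicative, not guaranteed:

```python
# Hedged sketch: load a pretrained RoBERTa model through torch.hub,
# following the same pattern as the en2de translation example above.
import torch

roberta = torch.hub.load('pytorch/fairseq', 'roberta.base')
roberta.eval()  # disable dropout for deterministic features

# Encode a sentence with RoBERTa's BPE and extract top-layer features.
tokens = roberta.encode('Hello world!')
features = roberta.extract_features(tokens)
print(features.shape)  # e.g. torch.Size([1, 5, 768]) for roberta.base
```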
fairseq/build/lib.linux-x86_64-cpython-39/fairseq/version.py
CHANGED
@@ -1 +1 @@
-__version__ = "1.0.0a0+
+__version__ = "1.0.0a0+ce30da5"
fairseq/docs/Makefile
ADDED
@@ -0,0 +1,20 @@
+# Minimal makefile for Sphinx documentation
+#
+
+# You can set these variables from the command line.
+SPHINXOPTS    =
+SPHINXBUILD   = python -msphinx
+SPHINXPROJ    = fairseq
+SOURCEDIR     = .
+BUILDDIR      = _build
+
+# Put it first so that "make" without argument is like "make help".
+help:
+	@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
+
+.PHONY: help Makefile
+
+# Catch-all target: route all unknown targets to Sphinx using the new
+# "make mode" option.  $(O) is meant as a shortcut for $(SPHINXOPTS).
+%: Makefile
+	@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
fairseq/docs/_static/theme_overrides.css
ADDED
@@ -0,0 +1,9 @@
+.wy-table-responsive table td kbd {
+    white-space: nowrap;
+}
+.wy-table-responsive table td {
+    white-space: normal !important;
+}
+.wy-table-responsive {
+    overflow: visible !important;
+}
fairseq/docs/command_line_tools.rst
ADDED
@@ -0,0 +1,85 @@
+.. _Command-line Tools:
+
+Command-line Tools
+==================
+
+Fairseq provides several command-line tools for training and evaluating models:
+
+- :ref:`fairseq-preprocess`: Data pre-processing: build vocabularies and binarize training data
+- :ref:`fairseq-train`: Train a new model on one or multiple GPUs
+- :ref:`fairseq-generate`: Translate pre-processed data with a trained model
+- :ref:`fairseq-interactive`: Translate raw text with a trained model
+- :ref:`fairseq-score`: BLEU scoring of generated translations against reference translations
+- :ref:`fairseq-eval-lm`: Language model evaluation
+
+
+.. _fairseq-preprocess:
+
+fairseq-preprocess
+~~~~~~~~~~~~~~~~~~
+.. automodule:: fairseq_cli.preprocess
+
+.. argparse::
+    :module: fairseq.options
+    :func: get_preprocessing_parser
+    :prog: fairseq-preprocess
+
+
+.. _fairseq-train:
+
+fairseq-train
+~~~~~~~~~~~~~
+.. automodule:: fairseq_cli.train
+
+.. argparse::
+    :module: fairseq.options
+    :func: get_training_parser
+    :prog: fairseq-train
+
+
+.. _fairseq-generate:
+
+fairseq-generate
+~~~~~~~~~~~~~~~~
+.. automodule:: fairseq_cli.generate
+
+.. argparse::
+    :module: fairseq.options
+    :func: get_generation_parser
+    :prog: fairseq-generate
+
+
+.. _fairseq-interactive:
+
+fairseq-interactive
+~~~~~~~~~~~~~~~~~~~
+.. automodule:: fairseq_cli.interactive
+
+.. argparse::
+    :module: fairseq.options
+    :func: get_interactive_generation_parser
+    :prog: fairseq-interactive
+
+
+.. _fairseq-score:
+
+fairseq-score
+~~~~~~~~~~~~~
+.. automodule:: fairseq_cli.score
+
+.. argparse::
+    :module: fairseq_cli.score
+    :func: get_parser
+    :prog: fairseq-score
+
+
+.. _fairseq-eval-lm:
+
+fairseq-eval-lm
+~~~~~~~~~~~~~~~
+.. automodule:: fairseq_cli.eval_lm
+
+.. argparse::
+    :module: fairseq.options
+    :func: get_eval_lm_parser
+    :prog: fairseq-eval-lm
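Each of these console commands is a thin wrapper around an `argparse` parser defined in `fairseq.options` — the same `:func:` targets the directives above point at. A small sketch, assuming fairseq is installed, showing that those parsers can be built and inspected from Python as well:

```python
# Minimal sketch: the CLI tools documented above are driven by ordinary
# argparse parsers from fairseq.options, so their options can be listed
# programmatically instead of via `fairseq-train --help`.
from fairseq import options

# Parser behind fairseq-train (see get_training_parser above).
train_parser = options.get_training_parser()
train_parser.print_help()

# Parser behind fairseq-preprocess works the same way.
preprocess_parser = options.get_preprocessing_parser()
preprocess_parser.print_help()
```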
fairseq/docs/conf.py
ADDED
@@ -0,0 +1,134 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+#
+# fairseq documentation build configuration file, created by
+# sphinx-quickstart on Fri Aug 17 21:45:30 2018.
+#
+# This file is execfile()d with the current directory set to its
+# containing dir.
+#
+# Note that not all possible configuration values are present in this
+# autogenerated file.
+#
+# All configuration values have a default; values that are commented out
+# serve to show the default.
+
+# If extensions (or modules to document with autodoc) are in another directory,
+# add these directories to sys.path here. If the directory is relative to the
+# documentation root, use os.path.abspath to make it absolute, like shown here.
+
+import os
+import sys
+from fairseq import __version__
+
+
+# source code directory, relative to this file, for sphinx-autobuild
+sys.path.insert(0, os.path.abspath(".."))
+
+source_suffix = [".rst"]
+
+# -- General configuration ------------------------------------------------
+
+# If your documentation needs a minimal Sphinx version, state it here.
+#
+# needs_sphinx = '1.0'
+
+# Add any Sphinx extension module names here, as strings. They can be
+# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
+# ones.
+extensions = [
+    "sphinx.ext.autodoc",
+    "sphinx.ext.intersphinx",
+    "sphinx.ext.viewcode",
+    "sphinx.ext.napoleon",
+    "sphinxarg.ext",
+]
+
+# Add any paths that contain templates here, relative to this directory.
+templates_path = ["_templates"]
+
+# The master toctree document.
+master_doc = "index"
+
+# General information about the project.
+project = "fairseq"
+copyright = "Facebook AI Research (FAIR)"
+author = "Facebook AI Research (FAIR)"
+
+github_doc_root = "https://github.com/pytorch/fairseq/tree/main/docs/"
+
+# The version info for the project you're documenting, acts as replacement for
+# |version| and |release|, also used in various other places throughout the
+# built documents.
+#
+# The short X.Y version.
+version = __version__
+# The full version, including alpha/beta/rc tags.
+release = __version__
+
+# The language for content autogenerated by Sphinx. Refer to documentation
+# for a list of supported languages.
+#
+# This is also used if you do content translation via gettext catalogs.
+# Usually you set "language" from the command line for these cases.
+language = None
+
+# List of patterns, relative to source directory, that match files and
+# directories to ignore when looking for source files.
+# These patterns also affect html_static_path and html_extra_path
+exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"]
+
+# The name of the Pygments (syntax highlighting) style to use.
+pygments_style = "sphinx"
+highlight_language = "python"
+
+# If true, `todo` and `todoList` produce output, else they produce nothing.
+todo_include_todos = False
+
+
+# -- Options for HTML output ----------------------------------------------
+
+# The theme to use for HTML and HTML Help pages.  See the documentation for
+# a list of builtin themes.
+#
+html_theme = "sphinx_rtd_theme"
+
+# Theme options are theme-specific and customize the look and feel of a theme
+# further.  For a list of options available for each theme, see the
+# documentation.
+#
+# html_theme_options = {}
+
+# Add any paths that contain custom static files (such as style sheets) here,
+# relative to this directory. They are copied after the builtin static files,
+# so a file named "default.css" will overwrite the builtin "default.css".
+html_static_path = ["_static"]
+
+html_context = {
+    "css_files": [
+        "_static/theme_overrides.css",  # override wide tables in RTD theme
+    ],
+}
+
+# Custom sidebar templates, must be a dictionary that maps document names
+# to template names.
+#
+# This is required for the alabaster theme
+# refs: http://alabaster.readthedocs.io/en/latest/installation.html#sidebars
+# html_sidebars = {
+#     '**': [
+#         'about.html',
+#         'navigation.html',
+#         'relations.html',  # needs 'show_related': True theme option to display
+#         'searchbox.html',
+#         'donate.html',
+#     ]
+# }
+
+
+# Example configuration for intersphinx: refer to the Python standard library.
+intersphinx_mapping = {
+    "numpy": ("http://docs.scipy.org/doc/numpy/", None),
+    "python": ("https://docs.python.org/", None),
+    "torch": ("https://pytorch.org/docs/master/", None),
+}

fairseq/docs/criterions.rst
ADDED
@@ -0,0 +1,31 @@
.. role:: hidden
    :class: hidden-section

.. _Criterions:

Criterions
==========

Criterions compute the loss function given the model and batch, roughly::

    loss = criterion(model, batch)

.. automodule:: fairseq.criterions
    :members:

.. autoclass:: fairseq.criterions.FairseqCriterion
    :members:
    :undoc-members:

.. autoclass:: fairseq.criterions.adaptive_loss.AdaptiveLoss
    :members:
    :undoc-members:
.. autoclass:: fairseq.criterions.composite_loss.CompositeLoss
    :members:
    :undoc-members:
.. autoclass:: fairseq.criterions.cross_entropy.CrossEntropyCriterion
    :members:
    :undoc-members:
.. autoclass:: fairseq.criterions.label_smoothed_cross_entropy.LabelSmoothedCrossEntropyCriterion
    :members:
    :undoc-members:
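
To make the ``loss = criterion(model, batch)`` contract above concrete, here is
a minimal sketch of a custom criterion plug-in. It mirrors the shape of the
cross-entropy criterion listed above, but the ``my_cross_entropy`` name and the
simplified logging bookkeeping are illustrative assumptions, not part of this
diff::

    import torch.nn.functional as F

    from fairseq.criterions import FairseqCriterion, register_criterion

    @register_criterion('my_cross_entropy')  # hypothetical name, for illustration
    class MyCrossEntropyCriterion(FairseqCriterion):

        def forward(self, model, sample, reduce=True):
            # Run the model on the mini-batch's network inputs.
            net_output = model(**sample['net_input'])
            # Convert the model outputs to flattened log-probabilities.
            lprobs = model.get_normalized_probs(net_output, log_probs=True)
            lprobs = lprobs.view(-1, lprobs.size(-1))
            target = model.get_targets(sample, net_output).view(-1)
            # Score the targets, ignoring padding positions.
            loss = F.nll_loss(
                lprobs, target, ignore_index=self.padding_idx,
                reduction='sum' if reduce else 'none',
            )
            sample_size = sample['ntokens']
            logging_output = {'loss': loss.data, 'sample_size': sample_size}
            return loss, sample_size, logging_output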

fairseq/docs/data.rst
ADDED
@@ -0,0 +1,58 @@
.. role:: hidden
    :class: hidden-section

.. module:: fairseq.data

Data Loading and Utilities
==========================

.. _datasets:

Datasets
--------

**Datasets** define the data format and provide helpers for creating
mini-batches.

.. autoclass:: fairseq.data.FairseqDataset
    :members:
.. autoclass:: fairseq.data.LanguagePairDataset
    :members:
.. autoclass:: fairseq.data.MonolingualDataset
    :members:

**Helper Datasets**

These datasets wrap other :class:`fairseq.data.FairseqDataset` instances and
provide additional functionality:

.. autoclass:: fairseq.data.BacktranslationDataset
    :members:
.. autoclass:: fairseq.data.ConcatDataset
    :members:
.. autoclass:: fairseq.data.ResamplingDataset
    :members:
.. autoclass:: fairseq.data.RoundRobinZipDatasets
    :members:
.. autoclass:: fairseq.data.TransformEosDataset
    :members:


Dictionary
----------

.. autoclass:: fairseq.data.Dictionary
    :members:


Iterators
---------

.. autoclass:: fairseq.data.CountingIterator
    :members:
.. autoclass:: fairseq.data.EpochBatchIterator
    :members:
.. autoclass:: fairseq.data.GroupedIterator
    :members:
.. autoclass:: fairseq.data.ShardedIterator
    :members:
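
To give a feel for the central ``Dictionary`` class indexed above, here is a
small round-trip sketch (ours, not part of this diff); it only uses methods we
expect ``fairseq.data.Dictionary`` to expose, such as ``add_symbol``, ``index``
and ``string``::

    import torch
    from fairseq.data import Dictionary

    d = Dictionary()  # starts pre-populated with specials like <pad> and </s>
    for word in 'hello world hello'.split():
        d.add_symbol(word)  # registers the symbol and counts occurrences

    ids = torch.LongTensor([d.index(w) for w in 'hello world'.split()])
    print(ids.tolist())   # indices assigned after the special symbols
    print(d.string(ids))  # 'hello world'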

fairseq/docs/docutils.conf
ADDED
@@ -0,0 +1,2 @@
[writers]
option-limit=0

fairseq/docs/fairseq_logo.png
ADDED

fairseq/docs/getting_started.rst
ADDED
@@ -0,0 +1,216 @@
Evaluating Pre-trained Models
=============================

First, download a pre-trained model along with its vocabularies:

.. code-block:: console

    > curl https://dl.fbaipublicfiles.com/fairseq/models/wmt14.v2.en-fr.fconv-py.tar.bz2 | tar xvjf -

This model uses a `Byte Pair Encoding (BPE)
vocabulary <https://arxiv.org/abs/1508.07909>`__, so we'll have to apply
the encoding to the source text before it can be translated. This can be
done with the
`apply\_bpe.py <https://github.com/rsennrich/subword-nmt/blob/master/subword_nmt/apply_bpe.py>`__
script using the ``wmt14.en-fr.fconv-cuda/bpecodes`` file. ``@@`` is
used as a continuation marker and the original text can be easily
recovered with e.g. ``sed s/@@ //g`` or by passing the ``--remove-bpe``
flag to :ref:`fairseq-generate`. Prior to BPE, input text needs to be tokenized
using ``tokenizer.perl`` from
`mosesdecoder <https://github.com/moses-smt/mosesdecoder>`__.

Let's use :ref:`fairseq-interactive` to generate translations interactively.
Here, we use a beam size of 5 and preprocess the input with the Moses
tokenizer and the given Byte-Pair Encoding vocabulary. It will automatically
remove the BPE continuation markers and detokenize the output.

.. code-block:: console

    > MODEL_DIR=wmt14.en-fr.fconv-py
    > fairseq-interactive \
        --path $MODEL_DIR/model.pt $MODEL_DIR \
        --beam 5 --source-lang en --target-lang fr \
        --tokenizer moses \
        --bpe subword_nmt --bpe-codes $MODEL_DIR/bpecodes
    | loading model(s) from wmt14.en-fr.fconv-py/model.pt
    | [en] dictionary: 44206 types
    | [fr] dictionary: 44463 types
    | Type the input sentence and press return:
    Why is it rare to discover new marine mammal species?
    S-0     Why is it rare to discover new marine mam@@ mal species ?
    H-0     -0.0643349438905716     Pourquoi est-il rare de découvrir de nouvelles espèces de mammifères marins?
    P-0     -0.0763 -0.1849 -0.0956 -0.0946 -0.0735 -0.1150 -0.1301 -0.0042 -0.0321 -0.0171 -0.0052 -0.0062 -0.0015

This generation script produces three types of outputs: a line prefixed
with *S* is a copy of the (BPE-encoded) source sentence; *H* is the
hypothesis along with an average log-likelihood; and *P* is the
positional score per token position, including the
end-of-sentence marker, which is omitted from the text.

Other types of output lines you might see are *D*, the detokenized hypothesis;
*T*, the reference target; *A*, alignment info; and *E*, the history of
generation steps.

See the `README <https://github.com/pytorch/fairseq#pre-trained-models>`__ for a
full list of pre-trained models available.
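
The same pre-trained models can often be used directly from Python through
``torch.hub`` instead of the CLI. The snippet below is a sketch: the hub
identifier ``conv.wmt14.en-fr`` and the ``tokenizer``/``bpe`` arguments are
assumptions borrowed from the README's torch.hub examples and may differ for
other models::

    import torch

    # Load a pre-trained convolutional EN-FR model via torch.hub
    # (model name assumed from the README; adjust to an available model).
    en2fr = torch.hub.load('pytorch/fairseq', 'conv.wmt14.en-fr',
                           tokenizer='moses', bpe='subword_nmt')
    en2fr.eval()  # disable dropout

    # translate() tokenizes, applies BPE, decodes with beam search, then
    # removes the BPE markers and detokenizes the hypothesis.
    print(en2fr.translate('Why is it rare to discover new marine mammal species?', beam=5))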

Training a New Model
====================

The following tutorial is for machine translation. For an example of how
to use Fairseq for other tasks, such as :ref:`language modeling`, please see the
``examples/`` directory.

Data Pre-processing
-------------------

Fairseq contains example pre-processing scripts for several translation
datasets: IWSLT 2014 (German-English), WMT 2014 (English-French) and WMT
2014 (English-German). To pre-process and binarize the IWSLT dataset:

.. code-block:: console

    > cd examples/translation/
    > bash prepare-iwslt14.sh
    > cd ../..
    > TEXT=examples/translation/iwslt14.tokenized.de-en
    > fairseq-preprocess --source-lang de --target-lang en \
        --trainpref $TEXT/train --validpref $TEXT/valid --testpref $TEXT/test \
        --destdir data-bin/iwslt14.tokenized.de-en

This will write binarized data that can be used for model training to
``data-bin/iwslt14.tokenized.de-en``.

Training
--------

Use :ref:`fairseq-train` to train a new model. Here are a few example settings
that work well for the IWSLT 2014 dataset:

.. code-block:: console

    > mkdir -p checkpoints/fconv
    > CUDA_VISIBLE_DEVICES=0 fairseq-train data-bin/iwslt14.tokenized.de-en \
        --optimizer nag --lr 0.25 --clip-norm 0.1 --dropout 0.2 --max-tokens 4000 \
        --arch fconv_iwslt_de_en --save-dir checkpoints/fconv

By default, :ref:`fairseq-train` will use all available GPUs on your machine. Use the
``CUDA_VISIBLE_DEVICES`` environment variable to select specific GPUs and/or to
change the number of GPU devices that will be used.

Also note that the batch size is specified in terms of the maximum
number of tokens per batch (``--max-tokens``). You may need to use a
smaller value depending on the available GPU memory on your system.

Generation
----------

Once your model is trained, you can generate translations using
:ref:`fairseq-generate` **(for binarized data)** or
:ref:`fairseq-interactive` **(for raw text)**:

.. code-block:: console

    > fairseq-generate data-bin/iwslt14.tokenized.de-en \
        --path checkpoints/fconv/checkpoint_best.pt \
        --batch-size 128 --beam 5
    | [de] dictionary: 35475 types
    | [en] dictionary: 24739 types
    | data-bin/iwslt14.tokenized.de-en test 6750 examples
    | model fconv
    | loaded checkpoint trainings/fconv/checkpoint_best.pt
    S-721   danke .
    T-721   thank you .
    ...

To generate translations with only a CPU, use the ``--cpu`` flag. BPE
continuation markers can be removed with the ``--remove-bpe`` flag.

Advanced Training Options
=========================

Large mini-batch training with delayed updates
----------------------------------------------

The ``--update-freq`` option can be used to accumulate gradients from
multiple mini-batches and delay updating, creating a larger effective
batch size. Delayed updates can also improve training speed by reducing
inter-GPU communication costs and by saving idle time caused by variance
in workload across GPUs. See `Ott et al.
(2018) <https://arxiv.org/abs/1806.00187>`__ for more details.

To train on a single GPU with an effective batch size that is equivalent
to training on 8 GPUs:

.. code-block:: console

    > CUDA_VISIBLE_DEVICES=0 fairseq-train --update-freq 8 (...)

Training with half precision floating point (FP16)
--------------------------------------------------

.. note::

    FP16 training requires a Volta GPU and CUDA 9.1 or greater.

Recent GPUs enable efficient half precision floating point computation,
e.g., using `Nvidia Tensor Cores
<https://docs.nvidia.com/deeplearning/sdk/mixed-precision-training/index.html>`__.
Fairseq supports FP16 training with the ``--fp16`` flag:

.. code-block:: console

    > fairseq-train --fp16 (...)

Distributed training
--------------------

Distributed training in fairseq is implemented on top of ``torch.distributed``.
The easiest way to launch jobs is with the `torch.distributed.launch
<https://pytorch.org/docs/stable/distributed.html#launch-utility>`__ tool.

For example, to train a large English-German Transformer model on 2 nodes each
with 8 GPUs (in total 16 GPUs), run the following command on each node,
replacing ``node_rank=0`` with ``node_rank=1`` on the second node and making
sure to update ``--master_addr`` to the IP address of the first node:

.. code-block:: console

    > python -m torch.distributed.launch --nproc_per_node=8 \
        --nnodes=2 --node_rank=0 --master_addr="192.168.1.1" \
        --master_port=12345 \
        $(which fairseq-train) data-bin/wmt16_en_de_bpe32k \
        --arch transformer_vaswani_wmt_en_de_big --share-all-embeddings \
        --optimizer adam --adam-betas '(0.9, 0.98)' --clip-norm 0.0 \
        --lr-scheduler inverse_sqrt --warmup-init-lr 1e-07 --warmup-updates 4000 \
        --lr 0.0005 \
        --dropout 0.3 --weight-decay 0.0 --criterion label_smoothed_cross_entropy --label-smoothing 0.1 \
        --max-tokens 3584 \
        --max-epoch 70 \
        --fp16

On SLURM clusters, fairseq will automatically detect the number of nodes and
GPUs, but a port number must be provided:

.. code-block:: console

    > salloc --gpus=16 --nodes 2 (...)
    > srun fairseq-train --distributed-port 12345 (...)

Sharding very large datasets
----------------------------

It can be challenging to train over very large datasets, particularly if your
machine does not have much system RAM. Most tasks in fairseq support training
over "sharded" datasets, in which the original dataset has been preprocessed
into non-overlapping chunks (or "shards").

For example, instead of preprocessing all your data into a single "data-bin"
directory, you can split the data and create "data-bin1", "data-bin2", etc.
Then you can adapt your training command like so:

.. code-block:: console

    > fairseq-train data-bin1:data-bin2:data-bin3 (...)

Training will now iterate over each shard, one by one, with each shard
corresponding to an "epoch", thus reducing system memory usage.

fairseq/docs/hydra_integration.md
ADDED
@@ -0,0 +1,284 @@
## Hydra

[Hydra](https://github.com/facebookresearch/hydra) is an open-source Python
framework that simplifies the development of research and other complex
applications. The key feature is the ability to dynamically create a
hierarchical configuration by composition and override it through config files
and the command line. The name Hydra comes from its ability to run multiple
similar jobs - much like a Hydra with multiple heads.

## Motivation

Until recently, all components in fairseq were configured through a shared
`args` namespace that was created at application startup. Components declared
their own `add_args` method to update the argparse parser, hoping that the names
would not clash with arguments from other components. While this model works for
smaller applications, as fairseq grew and became integrated into other
applications, this became problematic. In order to determine how to configure
each component, one needed to a) examine what args were added by this component,
and b) read the code to figure out what shared arguments it is using that were
added in other places. Reproducing models involved sharing commands that often
contained dozens of command line switches.

The model described above is still supported by fairseq for backward
compatibility, but will be deprecated some time in the future.

New components in fairseq should now create a dataclass that encapsulates all
parameters required to configure this component. The dataclass is registered
along with the component, and fairseq takes care of constructing and providing
this configuration object to the component's constructor. Note that sharing
parameters can optionally still work, but one has to explicitly point to the
"source of truth" (see the inheritance example below). These changes make
components in fairseq more independent and re-usable by other applications: all
that is needed to create a component is to initialize its dataclass and
overwrite some of the defaults.

While configuring fairseq through the command line (using either the legacy
argparse based or the new Hydra based entry points) is still fully supported,
you can now take advantage of configuring fairseq completely or piece-by-piece
through hierarchical YAML configuration files. These files can also be shipped
as examples that others can use to run an identically configured job.

Additionally, Hydra has a rich and growing [library of
plugins](https://github.com/facebookresearch/hydra/tree/master/plugins) that
provide functionality such as hyperparameter sweeping (including using Bayesian
optimization through the [Ax](https://github.com/facebook/Ax) library), job
launching across various platforms, and more.

## Creating or migrating components

In general, each new (or updated) component should provide a companion
[dataclass](https://www.python.org/dev/peps/pep-0557/). These dataclasses are
typically located in the same file as the component and are passed as arguments
to the `register_*()` functions. Top-level configs that should be present in
every fairseq application are placed in the
[global](fairseq/dataclass/configs.py) config file and added to the
`FairseqConfig` object.

Each dataclass is a plain-old-data object, similar to a `NamedTuple`. These
classes are decorated with a `@dataclass` decorator, and typically inherit from
`FairseqDataclass` (which adds some functionality for backward compatibility).
Each field must have a type, and generally has metadata (such as a help string)
and a default value. Only primitive types or other config objects are allowed as
data types for each field.

#### Example:

```python
from dataclasses import dataclass, field
from fairseq.dataclass import FairseqDataclass

@dataclass
class InteractiveConfig(FairseqDataclass):
    buffer_size: int = field(
        default=0,
        metadata={
            "help": "read this many sentences into a buffer before processing them"
        },
    )
    input: str = field(
        default="-",
        metadata={"help": "file to read from; use - for stdin"},
    )
```

### Inheriting values

Some components require sharing a value. For example, a learning rate scheduler
and an optimizer may both need to know the initial learning rate value. One can
declare a field that, by default, will inherit its value from another config
node in the same hierarchy:

```python
@dataclass
class FairseqAdamConfig(FairseqDataclass):
    ...
    lr: List[float] = II("optimization.lr")
    ...
```

`II("optimization.lr")` is syntactic sugar for `"${optimization.lr}"`, which is
the value one can use in a YAML config file or through the command line to
achieve the same effect. Note that this assumes that there is an "optimization"
config object in the root config and it has a field called "lr".
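
As a quick sanity check of how such an interpolation resolves, here is a small
self-contained sketch using OmegaConf directly (this example is ours, not part
of the fairseq docs; it assumes `omegaconf` 2.x, the library Hydra builds on):

```python
from dataclasses import dataclass
from omegaconf import II, OmegaConf

@dataclass
class Optimization:
    lr: float = 0.25

@dataclass
class Optimizer:
    # same pattern as FairseqAdamConfig above: inherit from "optimization.lr"
    lr: float = II("optimization.lr")

cfg = OmegaConf.create({"optimization": Optimization(), "optimizer": Optimizer()})
print(cfg.optimizer.lr)   # 0.25, resolved through the interpolation
cfg.optimization.lr = 0.1
print(cfg.optimizer.lr)   # 0.1, follows the single source of truth
```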

### Tasks and Models

Creating Tasks and Models works the same as before, except that legacy
implementations now inherit from `LegacyFairseq*` base classes, while new
components inherit from `FairseqTask` and `FairseqModel` and provide a dataclass
to the `register_*()` functions.

#### Task example:

```python
@dataclass
class LanguageModelingConfig(FairseqDataclass):
    data: Optional[str] = field(
        default=None, metadata={"help": "path to data directory"}
    )
    ...

@register_task("language_modeling", dataclass=LanguageModelingConfig)
class LanguageModelingTask(FairseqTask):
    ...
    @classmethod
    def setup_task(cls, cfg: LanguageModelingConfig):
        ...
```

#### Model example:

```python
@dataclass
class TransformerLanguageModelConfig(FairseqDataclass):
    activation_fn: ChoiceEnum(utils.get_available_activation_fns()) = field(
        default="relu", metadata={"help": "activation function to use"}
    )
    dropout: float = field(default=0.1, metadata={"help": "dropout probability"})
    ...

@register_model("transformer_lm", dataclass=TransformerLanguageModelConfig)
class TransformerLanguageModel(FairseqLanguageModel):
    ...
    @classmethod
    def build_model(cls, cfg: TransformerLanguageModelConfig, task: FairseqTask):
        ...
```

### Other components

Other components work as before, but they now take their configuration dataclass
as the only constructor argument:

```python
@dataclass
class MosesTokenizerConfig(FairseqDataclass):
    source_lang: str = field(default="en", metadata={"help": "source language"})
    ...

@register_tokenizer("moses", dataclass=MosesTokenizerConfig)
class MosesTokenizer(object):
    def __init__(self, cfg: MosesTokenizerConfig):
        ...
```

Note that if you are adding a new registry for a new set of components, you need
to add it to the `FairseqConfig` object in `fairseq/dataclass/configs.py`:

```python
@dataclass
class FairseqConfig(object):
    ...
    my_new_registry: Any = None
```

## Training with `fairseq-hydra-train`

To fully take advantage of the configuration flexibility offered by Hydra, you
may want to train new models using the `fairseq-hydra-train` entry point. Legacy
CLI tools such as `fairseq-train` will remain supported for the foreseeable
future but will be deprecated eventually.

On startup, Hydra will create a configuration object that contains a hierarchy
of all the necessary dataclasses populated with their default values in the
code. The default values are overwritten by values found in YAML files in the
`fairseq/config` directory (which currently sets minimal defaults) and then
further overwritten by values provided through command line arguments.

Some of the most common use cases are shown below:

### 1. Override default values through command line:

```shell script
$ fairseq-hydra-train \
    distributed_training.distributed_world_size=1 \
    dataset.batch_size=2 \
    task.data=data-bin \
    model=transformer_lm/transformer_lm_gpt \
    task=language_modeling \
    optimization.max_update=5000
```

Note that along with explicitly providing values for parameters such as
`dataset.batch_size`, this also tells Hydra to overlay configuration found in
`fairseq/config/model/transformer_lm/transformer_lm_gpt.yaml` over the default
values in the dataclass. If you want to train a model without specifying a
particular architecture you can simply specify `model=transformer_lm`. This only
works for migrated tasks and models.

### 2. Replace bundled configs with an external config:

```shell script
$ fairseq-hydra-train \
    --config-dir /path/to/external/configs \
    --config-name wiki103
```

where `/path/to/external/configs/wiki103.yaml` contains:

```yaml
# @package _group_

model:
  _name: transformer_lm
distributed_training:
  distributed_world_size: 1
dataset:
  batch_size: 2
task:
  _name: language_modeling
  data: /path/to/data
  add_bos_token: false
  max_target_positions: 1024
optimization:
  max_update: 50000
  lr: [ 0.25 ]
criterion: cross_entropy
optimizer: adam
lr_scheduler:
  _name: cosine
```

Note that here the bundled configs from the `fairseq/config` directory are not
used; however, the defaults from each dataclass will still be used (unless
overwritten by your external config).

Additionally, you can choose to break up your configs by creating a directory
structure in the same location as your main config file, with the names of the
top-level fields (such as "model", "dataset", etc.), and placing config files
with meaningful names that would populate that specific section of your
top-level config file (for example, you might have
`model/small_transformer_lm.yaml`, `model/big_transformer_lm.yaml`, etc.). You
can then specify the correct configuration via the command line, defaults in the
main config, or even launch all of them as a sweep (see the Hydra documentation
on how to do this).

### 3. Add an external config directory to Hydra search path:

This allows combining default configuration (including using any bundled config
files), while specifying your own config files for some parts of the
configuration.

```shell script
$ fairseq-hydra-train \
    distributed_training.distributed_world_size=1 \
    dataset.batch_size=2 \
    task.data=/path/to/data/ \
    model=transformer_lm/2_layers \
    task=language_modeling \
    optimization.max_update=5000 \
    --config-dir /path/to/external/configs
```

where `/path/to/external/configs` has the following structure:
```
.
+-- model
|   +-- transformer_lm
|   |   +-- 2_layers.yaml
```

and `2_layers.yaml` contains a copy of `transformer_lm_gpt.yaml` but with
`decoder_layers` set to 2. You can add other configs to configure other
components as well.

fairseq/docs/index.rst
ADDED
@@ -0,0 +1,49 @@
.. fairseq documentation master file, created by
   sphinx-quickstart on Fri Aug 17 21:45:30 2018.
   You can adapt this file completely to your liking, but it should at least
   contain the root `toctree` directive.

:github_url: https://github.com/pytorch/fairseq


fairseq documentation
=====================

Fairseq is a sequence modeling toolkit written in `PyTorch
<http://pytorch.org/>`_ that allows researchers and developers to
train custom models for translation, summarization, language modeling and other
text generation tasks.

.. toctree::
    :maxdepth: 1
    :caption: Getting Started

    getting_started
    command_line_tools

.. toctree::
    :maxdepth: 1
    :caption: Extending Fairseq

    overview
    tutorial_simple_lstm
    tutorial_classifying_names

.. toctree::
    :maxdepth: 2
    :caption: Library Reference

    tasks
    models
    criterions
    optim
    lr_scheduler
    data
    modules


Indices and tables
==================

* :ref:`genindex`
* :ref:`search`

fairseq/docs/lr_scheduler.rst
ADDED
@@ -0,0 +1,34 @@
.. role:: hidden
    :class: hidden-section

.. _Learning Rate Schedulers:

Learning Rate Schedulers
========================

Learning Rate Schedulers update the learning rate over the course of training.
Learning rates can be updated after each update via :func:`step_update` or at
epoch boundaries via :func:`step`.

.. automodule:: fairseq.optim.lr_scheduler
    :members:

.. autoclass:: fairseq.optim.lr_scheduler.FairseqLRScheduler
    :members:
    :undoc-members:

.. autoclass:: fairseq.optim.lr_scheduler.cosine_lr_scheduler.CosineSchedule
    :members:
    :undoc-members:
.. autoclass:: fairseq.optim.lr_scheduler.fixed_schedule.FixedSchedule
    :members:
    :undoc-members:
.. autoclass:: fairseq.optim.lr_scheduler.inverse_square_root_schedule.InverseSquareRootSchedule
    :members:
    :undoc-members:
.. autoclass:: fairseq.optim.lr_scheduler.reduce_lr_on_plateau.ReduceLROnPlateau
    :members:
    :undoc-members:
.. autoclass:: fairseq.optim.lr_scheduler.triangular_lr_scheduler.TriangularSchedule
    :members:
    :undoc-members:
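
As one concrete example of the schedules indexed above, the inverse square root
schedule from the Transformer recipe warms the learning rate up linearly and
then decays it proportionally to the inverse square root of the update number.
The helper below is our own sketch of that shape (default values borrowed from
common fairseq training commands), not the class's actual implementation::

    def inverse_sqrt_lr(num_updates, warmup_updates=4000,
                        warmup_init_lr=1e-7, lr=5e-4):
        if num_updates < warmup_updates:
            # linear warmup from warmup_init_lr up to lr
            lr_step = (lr - warmup_init_lr) / warmup_updates
            return warmup_init_lr + num_updates * lr_step
        # afterwards, decay proportionally to 1/sqrt(num_updates)
        decay_factor = lr * warmup_updates ** 0.5
        return decay_factor * num_updates ** -0.5

    print(inverse_sqrt_lr(0))      # 1e-07 (start of warmup)
    print(inverse_sqrt_lr(4000))   # 0.0005 (peak, end of warmup)
    print(inverse_sqrt_lr(16000))  # 0.00025 (halved after 4x the steps)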

fairseq/docs/make.bat
ADDED
@@ -0,0 +1,36 @@
@ECHO OFF

pushd %~dp0

REM Command file for Sphinx documentation

if "%SPHINXBUILD%" == "" (
    set SPHINXBUILD=python -msphinx
)
set SOURCEDIR=.
set BUILDDIR=_build
set SPHINXPROJ=fairseq

if "%1" == "" goto help

%SPHINXBUILD% >NUL 2>NUL
if errorlevel 9009 (
    echo.
    echo.The Sphinx module was not found. Make sure you have Sphinx installed,
    echo.then set the SPHINXBUILD environment variable to point to the full
    echo.path of the 'sphinx-build' executable. Alternatively you may add the
    echo.Sphinx directory to PATH.
    echo.
    echo.If you don't have Sphinx installed, grab it from
    echo.http://sphinx-doc.org/
    exit /b 1
)

%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS%
goto end

:help
%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS%

:end
popd

fairseq/docs/models.rst
ADDED
@@ -0,0 +1,104 @@
.. role:: hidden
    :class: hidden-section

.. module:: fairseq.models

.. _Models:

Models
======

A Model defines the neural network's ``forward()`` method and encapsulates all
of the learnable parameters in the network. Each model also provides a set of
named *architectures* that define the precise network configuration (e.g.,
embedding dimension, number of layers, etc.).

Both the model type and architecture are selected via the ``--arch``
command-line argument. Once selected, a model may expose additional command-line
arguments for further configuration.

.. note::

    All fairseq Models extend :class:`BaseFairseqModel`, which in turn extends
    :class:`torch.nn.Module`. Thus any fairseq Model can be used as a
    stand-alone Module in other PyTorch code.


Convolutional Neural Networks (CNN)
-----------------------------------

.. module:: fairseq.models.fconv
.. autoclass:: fairseq.models.fconv.FConvModel
    :members:
.. autoclass:: fairseq.models.fconv.FConvEncoder
    :members:
    :undoc-members:
.. autoclass:: fairseq.models.fconv.FConvDecoder
    :members:


Long Short-Term Memory (LSTM) networks
--------------------------------------

.. module:: fairseq.models.lstm
.. autoclass:: fairseq.models.lstm.LSTMModel
    :members:
.. autoclass:: fairseq.models.lstm.LSTMEncoder
    :members:
.. autoclass:: fairseq.models.lstm.LSTMDecoder
    :members:


Transformer (self-attention) networks
-------------------------------------

.. module:: fairseq.models.transformer
.. autoclass:: fairseq.models.transformer.TransformerModel
    :members:
.. autoclass:: fairseq.models.transformer.TransformerEncoder
    :members:
.. autoclass:: fairseq.models.transformer.TransformerEncoderLayer
    :members:
.. autoclass:: fairseq.models.transformer.TransformerDecoder
    :members:
.. autoclass:: fairseq.models.transformer.TransformerDecoderLayer
    :members:


Adding new models
-----------------

.. currentmodule:: fairseq.models
.. autofunction:: fairseq.models.register_model
.. autofunction:: fairseq.models.register_model_architecture
.. autoclass:: fairseq.models.BaseFairseqModel
    :members:
    :undoc-members:
.. autoclass:: fairseq.models.FairseqEncoderDecoderModel
    :members:
    :undoc-members:
.. autoclass:: fairseq.models.FairseqEncoderModel
    :members:
    :undoc-members:
.. autoclass:: fairseq.models.FairseqLanguageModel
    :members:
    :undoc-members:
.. autoclass:: fairseq.models.FairseqMultiModel
    :members:
    :undoc-members:
.. autoclass:: fairseq.models.FairseqEncoder
    :members:
.. autoclass:: fairseq.models.CompositeEncoder
    :members:
.. autoclass:: fairseq.models.FairseqDecoder
    :members:


.. _Incremental decoding:

Incremental decoding
--------------------

.. autoclass:: fairseq.models.FairseqIncrementalDecoder
    :members:
    :undoc-members:

fairseq/docs/modules.rst
ADDED
@@ -0,0 +1,9 @@
Modules
=======

Fairseq provides several stand-alone :class:`torch.nn.Module` classes that may
be helpful when implementing a new :class:`~fairseq.models.BaseFairseqModel`.

.. automodule:: fairseq.modules
    :members:
    :undoc-members:
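
For instance, the attention module can be dropped into ordinary PyTorch code.
The sketch below is ours, not part of this diff; it assumes
``fairseq.modules.MultiheadAttention`` follows fairseq's usual time-major
convention of ``(seq_len, batch, embed_dim)`` tensors::

    import torch
    from fairseq.modules import MultiheadAttention

    attn = MultiheadAttention(embed_dim=16, num_heads=4)
    x = torch.randn(10, 2, 16)  # (seq_len, batch, embed_dim)
    out, attn_weights = attn(query=x, key=x, value=x)  # self-attention
    print(out.shape)  # torch.Size([10, 2, 16])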

fairseq/docs/optim.rst
ADDED
@@ -0,0 +1,38 @@
.. role:: hidden
    :class: hidden-section

.. _optimizers:

Optimizers
==========

Optimizers update the Model parameters based on the gradients.

.. automodule:: fairseq.optim
    :members:

.. autoclass:: fairseq.optim.FairseqOptimizer
    :members:
    :undoc-members:

.. autoclass:: fairseq.optim.adadelta.Adadelta
    :members:
    :undoc-members:
.. autoclass:: fairseq.optim.adagrad.Adagrad
    :members:
    :undoc-members:
.. autoclass:: fairseq.optim.adafactor.FairseqAdafactor
    :members:
    :undoc-members:
.. autoclass:: fairseq.optim.adam.FairseqAdam
    :members:
    :undoc-members:
.. autoclass:: fairseq.optim.fp16_optimizer.FP16Optimizer
    :members:
    :undoc-members:
.. autoclass:: fairseq.optim.nag.FairseqNAG
    :members:
    :undoc-members:
.. autoclass:: fairseq.optim.sgd.SGD
    :members:
    :undoc-members:

fairseq/docs/overview.rst
ADDED
@@ -0,0 +1,74 @@
Overview
========

Fairseq can be extended through user-supplied `plug-ins
<https://en.wikipedia.org/wiki/Plug-in_(computing)>`_. We support five kinds of
plug-ins:

- :ref:`Models` define the neural network architecture and encapsulate all of the
  learnable parameters.
- :ref:`Criterions` compute the loss function given the model outputs and targets.
- :ref:`Tasks` store dictionaries and provide helpers for loading/iterating over
  Datasets, initializing the Model/Criterion and calculating the loss.
- :ref:`Optimizers` update the Model parameters based on the gradients.
- :ref:`Learning Rate Schedulers` update the learning rate over the course of
  training.

**Training Flow**

Given a ``model``, ``criterion``, ``task``, ``optimizer`` and ``lr_scheduler``,
fairseq implements the following high-level training flow::

    for epoch in range(num_epochs):
        itr = task.get_batch_iterator(task.dataset('train'))
        for num_updates, batch in enumerate(itr):
            task.train_step(batch, model, criterion, optimizer)
            average_and_clip_gradients()
            optimizer.step()
            lr_scheduler.step_update(num_updates)
        lr_scheduler.step(epoch)

where the default implementation for ``task.train_step`` is roughly::

    def train_step(self, batch, model, criterion, optimizer, **unused):
        loss = criterion(model, batch)
        optimizer.backward(loss)
        return loss

**Registering new plug-ins**

New plug-ins are *registered* through a set of ``@register`` function
decorators, for example::

    @register_model('my_lstm')
    class MyLSTM(FairseqEncoderDecoderModel):
        (...)

Once registered, new plug-ins can be used with the existing :ref:`Command-line
Tools`. See the Tutorial sections for more detailed walkthroughs of how to add
new plug-ins.

**Loading plug-ins from another directory**

New plug-ins can be defined in a custom module stored on the user's system. In
order to import the module, and make the plug-in available to *fairseq*, the
command line supports the ``--user-dir`` flag that can be used to specify a
custom location for additional modules to load into *fairseq*.

For example, assuming this directory tree::

    /home/user/my-module/
    └── __init__.py

with ``__init__.py``::

    from fairseq.models import register_model_architecture
    from fairseq.models.transformer import transformer_vaswani_wmt_en_de_big

    @register_model_architecture('transformer', 'my_transformer')
    def transformer_mmt_big(args):
        transformer_vaswani_wmt_en_de_big(args)

it is possible to invoke the :ref:`fairseq-train` script with the new architecture with::

    fairseq-train ... --user-dir /home/user/my-module -a my_transformer --task translation

fairseq/docs/requirements.txt
ADDED
@@ -0,0 +1,2 @@
sphinx<2.0
sphinx-argparse

fairseq/docs/tasks.rst
ADDED
@@ -0,0 +1,61 @@
.. role:: hidden
    :class: hidden-section

.. module:: fairseq.tasks

.. _Tasks:

Tasks
=====

Tasks store dictionaries and provide helpers for loading/iterating over
Datasets, initializing the Model/Criterion and calculating the loss.

Tasks can be selected via the ``--task`` command-line argument. Once selected, a
task may expose additional command-line arguments for further configuration.

Example usage::

    # setup the task (e.g., load dictionaries)
    task = fairseq.tasks.setup_task(args)

    # build model and criterion
    model = task.build_model(args)
    criterion = task.build_criterion(args)

    # load datasets
    task.load_dataset('train')
    task.load_dataset('valid')

    # iterate over mini-batches of data
    batch_itr = task.get_batch_iterator(
        task.dataset('train'), max_tokens=4096,
    )
    for batch in batch_itr:
        # compute the loss
        loss, sample_size, logging_output = task.get_loss(
            model, criterion, batch,
        )
        loss.backward()


Translation
-----------

.. autoclass:: fairseq.tasks.translation.TranslationTask

.. _language modeling:

Language Modeling
-----------------

.. autoclass:: fairseq.tasks.language_modeling.LanguageModelingTask


Adding new tasks
----------------

.. autofunction:: fairseq.tasks.register_task
.. autoclass:: fairseq.tasks.FairseqTask
    :members:
    :undoc-members:
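
A minimal new-task skeleton, to complement the reference above, looks roughly
like the following. This is a sketch of ours (the ``dummy_classification`` name
and its single argument are made up for illustration); the classifying-names
tutorial below walks through a complete, working task::

    from fairseq.tasks import FairseqTask, register_task

    @register_task('dummy_classification')  # hypothetical task name
    class DummyClassificationTask(FairseqTask):

        @staticmethod
        def add_args(parser):
            # Tasks may expose additional command-line arguments.
            parser.add_argument('data', metavar='DIR', help='path to data')

        @classmethod
        def setup_task(cls, args, **kwargs):
            # A real task would load its dictionaries here before
            # constructing the task instance.
            return cls(args)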

fairseq/docs/tutorial_classifying_names.rst
ADDED
@@ -0,0 +1,415 @@
Tutorial: Classifying Names with a Character-Level RNN
======================================================

In this tutorial we will extend fairseq to support *classification* tasks. In
particular we will re-implement the PyTorch tutorial for `Classifying Names with
a Character-Level RNN <https://pytorch.org/tutorials/intermediate/char_rnn_classification_tutorial.html>`_
in fairseq. It is recommended to quickly skim that tutorial before beginning
this one.

This tutorial covers:

1. **Preprocessing the data** to create dictionaries.
2. **Registering a new Model** that encodes an input sentence with a simple RNN
   and predicts the output label.
3. **Registering a new Task** that loads our dictionaries and dataset.
4. **Training the Model** using the existing command-line tools.
5. **Writing an evaluation script** that imports fairseq and allows us to
   interactively evaluate our model on new inputs.


1. Preprocessing the data
-------------------------

The original tutorial provides raw data, but we'll work with a modified version
of the data that is already tokenized into characters and split into separate
train, valid and test sets.

Download and extract the data from here:
`tutorial_names.tar.gz <https://dl.fbaipublicfiles.com/fairseq/data/tutorial_names.tar.gz>`_

Once extracted, let's preprocess the data using the :ref:`fairseq-preprocess`
command-line tool to create the dictionaries. While this tool is primarily
intended for sequence-to-sequence problems, we're able to reuse it here by
treating the label as a "target" sequence of length 1. We'll also output the
preprocessed files in "raw" format using the ``--dataset-impl`` option to
enhance readability:

.. code-block:: console

    > fairseq-preprocess \
        --trainpref names/train --validpref names/valid --testpref names/test \
        --source-lang input --target-lang label \
        --destdir names-bin --dataset-impl raw

After running the above command you should see a new directory,
:file:`names-bin/`, containing the dictionaries for *inputs* and *labels*.


2. Registering a new Model
--------------------------

Next we'll register a new model in fairseq that will encode an input sentence
with a simple RNN and predict the output label. Compared to the original PyTorch
tutorial, our version will also work with batches of data and GPU Tensors.

First let's copy the simple RNN module implemented in the `PyTorch tutorial
<https://pytorch.org/tutorials/intermediate/char_rnn_classification_tutorial.html#creating-the-network>`_.
Create a new file named :file:`fairseq/models/rnn_classifier.py` with the
following contents::

    import torch
    import torch.nn as nn

    class RNN(nn.Module):

        def __init__(self, input_size, hidden_size, output_size):
            super(RNN, self).__init__()

            self.hidden_size = hidden_size

            self.i2h = nn.Linear(input_size + hidden_size, hidden_size)
            self.i2o = nn.Linear(input_size + hidden_size, output_size)
            self.softmax = nn.LogSoftmax(dim=1)

        def forward(self, input, hidden):
            combined = torch.cat((input, hidden), 1)
            hidden = self.i2h(combined)
            output = self.i2o(combined)
            output = self.softmax(output)
            return output, hidden

        def initHidden(self):
            return torch.zeros(1, self.hidden_size)

We must also *register* this model with fairseq using the
:func:`~fairseq.models.register_model` function decorator. Once the model is
registered we'll be able to use it with the existing :ref:`Command-line Tools`.

All registered models must implement the :class:`~fairseq.models.BaseFairseqModel`
interface, so we'll create a small wrapper class in the same file and register
it in fairseq with the name ``'rnn_classifier'``::

    from fairseq.models import BaseFairseqModel, register_model

    # Note: the register_model "decorator" should immediately precede the
    # definition of the Model class.

    @register_model('rnn_classifier')
    class FairseqRNNClassifier(BaseFairseqModel):

        @staticmethod
        def add_args(parser):
            # Models can override this method to add new command-line arguments.
            # Here we'll add a new command-line argument to configure the
            # dimensionality of the hidden state.
            parser.add_argument(
                '--hidden-dim', type=int, metavar='N',
                help='dimensionality of the hidden state',
            )

        @classmethod
        def build_model(cls, args, task):
            # Fairseq initializes models by calling the ``build_model()``
            # function. This provides more flexibility, since the returned model
            # instance can be of a different type than the one that was called.
            # In this case we'll just return a FairseqRNNClassifier instance.

            # Initialize our RNN module
            rnn = RNN(
                # We'll define the Task in the next section, but for now just
                # notice that the task holds the dictionaries for the "source"
                # (i.e., the input sentence) and "target" (i.e., the label).
                input_size=len(task.source_dictionary),
                hidden_size=args.hidden_dim,
                output_size=len(task.target_dictionary),
            )

            # Return the wrapped version of the module
            return FairseqRNNClassifier(
                rnn=rnn,
                input_vocab=task.source_dictionary,
            )

        def __init__(self, rnn, input_vocab):
            super(FairseqRNNClassifier, self).__init__()

            self.rnn = rnn
            self.input_vocab = input_vocab

            # The RNN module in the tutorial expects one-hot inputs, so we can
            # precompute the identity matrix to help convert from indices to
            # one-hot vectors. We register it as a buffer so that it is moved to
            # the GPU when ``cuda()`` is called.
            self.register_buffer('one_hot_inputs', torch.eye(len(input_vocab)))

        def forward(self, src_tokens, src_lengths):
            # The inputs to the ``forward()`` function are determined by the
            # Task, and in particular the ``'net_input'`` key in each
            # mini-batch. We'll define the Task in the next section, but for
            # now just know that *src_tokens* has shape `(batch, src_len)` and
            # *src_lengths* has shape `(batch)`.
            bsz, max_src_len = src_tokens.size()

            # Initialize the RNN hidden state. Compared to the original PyTorch
            # tutorial we'll also handle batched inputs and work on the GPU.
            hidden = self.rnn.initHidden()
            hidden = hidden.repeat(bsz, 1)  # expand for batched inputs
            hidden = hidden.to(src_tokens.device)  # move to GPU

            for i in range(max_src_len):
                # WARNING: The inputs have padding, so we should mask those
                # elements here so that padding doesn't affect the results.
                # This is left as an exercise for the reader. The padding symbol
                # is given by ``self.input_vocab.pad()`` and the unpadded length
                # of each input is given by *src_lengths*.

                # One-hot encode a batch of input characters.
                input = self.one_hot_inputs[src_tokens[:, i].long()]

                # Feed the input to our RNN.
                output, hidden = self.rnn(input, hidden)

            # Return the final output state for making a prediction
            return output
| 176 |
+
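The masking exercise mentioned in the ``WARNING`` comment above deserves a
sketch. One possible rewrite of the loop in ``forward()`` (illustrative only,
not fairseq's reference answer) carries the previous hidden state and output
through steps beyond each input's true length. It assumes right-padded inputs,
which matches the ``left_pad_source=False`` setting used by the Task below::

    for i in range(max_src_len):
        # One-hot encode a batch of input characters.
        input = self.one_hot_inputs[src_tokens[:, i].long()]

        # Run one RNN step for every sequence in the batch.
        new_output, new_hidden = self.rnn(input, hidden)

        # *keep* is 1.0 where step i is a real token and 0.0 where it is
        # padding, so padded steps leave the state untouched.
        keep = (i < src_lengths).type_as(new_hidden).unsqueeze(1)
        hidden = keep * new_hidden + (1 - keep) * hidden
        output = new_output if i == 0 else keep * new_output + (1 - keep) * output
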
Finally let's define a *named architecture* with the configuration for our
model. This is done with the :func:`~fairseq.models.register_model_architecture`
function decorator. Thereafter this named architecture can be used with the
``--arch`` command-line argument, e.g., ``--arch pytorch_tutorial_rnn``::

    from fairseq.models import register_model_architecture

    # The first argument to ``register_model_architecture()`` should be the name
    # of the model we registered above (i.e., 'rnn_classifier'). The function we
    # register here should take a single argument *args* and modify it in-place
    # to match the desired architecture.

    @register_model_architecture('rnn_classifier', 'pytorch_tutorial_rnn')
    def pytorch_tutorial_rnn(args):
        # We use ``getattr()`` to prioritize arguments that are explicitly given
        # on the command-line, so that the defaults defined below are only used
        # when no other value has been specified.
        args.hidden_dim = getattr(args, 'hidden_dim', 128)

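For example (an illustrative invocation, not from the original tutorial), once
the Task from the next section is registered, passing ``--hidden-dim``
explicitly overrides the default of 128 set above:

.. code-block:: console

  > fairseq-train names-bin \
    --task simple_classification \
    --arch pytorch_tutorial_rnn \
    --hidden-dim 256 \
    --optimizer adam --lr 0.001 --max-tokens 1000
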

3. Registering a new Task
-------------------------

Now we'll register a new :class:`~fairseq.tasks.FairseqTask` that will load our
dictionaries and dataset. Tasks can also control how the data is batched into
mini-batches, but in this tutorial we'll reuse the batching provided by
:class:`fairseq.data.LanguagePairDataset`.

Create a new file named :file:`fairseq/tasks/simple_classification.py` with the
following contents::

    import os
    import torch

    from fairseq.data import Dictionary, LanguagePairDataset
    from fairseq.tasks import LegacyFairseqTask, register_task


    @register_task('simple_classification')
    class SimpleClassificationTask(LegacyFairseqTask):

        @staticmethod
        def add_args(parser):
            # Add some command-line arguments for specifying where the data is
            # located and the maximum supported input length.
            parser.add_argument('data', metavar='FILE',
                                help='file prefix for data')
            parser.add_argument('--max-positions', default=1024, type=int,
                                help='max input length')

        @classmethod
        def setup_task(cls, args, **kwargs):
            # Here we can perform any setup required for the task. This may
            # include loading Dictionaries, initializing shared Embedding
            # layers, etc. In this case we'll just load the Dictionaries.
            input_vocab = Dictionary.load(os.path.join(args.data, 'dict.input.txt'))
            label_vocab = Dictionary.load(os.path.join(args.data, 'dict.label.txt'))
            print('| [input] dictionary: {} types'.format(len(input_vocab)))
            print('| [label] dictionary: {} types'.format(len(label_vocab)))

            return SimpleClassificationTask(args, input_vocab, label_vocab)

        def __init__(self, args, input_vocab, label_vocab):
            super().__init__(args)
            self.input_vocab = input_vocab
            self.label_vocab = label_vocab

        def load_dataset(self, split, **kwargs):
            """Load a given dataset split (e.g., train, valid, test)."""

            prefix = os.path.join(self.args.data, '{}.input-label'.format(split))

            # Read input sentences.
            sentences, lengths = [], []
            with open(prefix + '.input', encoding='utf-8') as file:
                for line in file:
                    sentence = line.strip()

                    # Tokenize the sentence, splitting on spaces
                    tokens = self.input_vocab.encode_line(
                        sentence, add_if_not_exist=False,
                    )

                    sentences.append(tokens)
                    lengths.append(tokens.numel())

            # Read labels.
            labels = []
            with open(prefix + '.label', encoding='utf-8') as file:
                for line in file:
                    label = line.strip()
                    labels.append(
                        # Convert label to a numeric ID.
                        torch.LongTensor([self.label_vocab.add_symbol(label)])
                    )

            assert len(sentences) == len(labels)
            print('| {} {} {} examples'.format(self.args.data, split, len(sentences)))

            # We reuse LanguagePairDataset since classification can be modeled
            # as a sequence-to-sequence task where the target sequence has
            # length 1.
            self.datasets[split] = LanguagePairDataset(
                src=sentences,
                src_sizes=lengths,
                src_dict=self.input_vocab,
                tgt=labels,
                tgt_sizes=torch.ones(len(labels)),  # targets have length 1
                tgt_dict=self.label_vocab,
                left_pad_source=False,
                # Since our target is a single class label, there's no need for
                # teacher forcing. If we set this to ``True`` then our Model's
                # ``forward()`` method would receive an additional argument
                # called *prev_output_tokens* that would contain a shifted
                # version of the target sequence.
                input_feeding=False,
            )

        def max_positions(self):
            """Return the max input length allowed by the task."""
            # The source should be less than *args.max_positions* and the
            # "target" has max length 1.
            return (self.args.max_positions, 1)

        @property
        def source_dictionary(self):
            """Return the source :class:`~fairseq.data.Dictionary`."""
            return self.input_vocab

        @property
        def target_dictionary(self):
            """Return the target :class:`~fairseq.data.Dictionary`."""
            return self.label_vocab

        # We could override this method if we wanted more control over how
        # batches are constructed, but it's not necessary for this tutorial
        # since we can reuse the batching provided by LanguagePairDataset.
        #
        # def get_batch_iterator(
        #     self, dataset, max_tokens=None, max_sentences=None, max_positions=None,
        #     ignore_invalid_inputs=False, required_batch_size_multiple=1,
        #     seed=1, num_shards=1, shard_id=0, num_workers=0, epoch=1,
        #     data_buffer_size=0, disable_iterator_cache=False,
        # ):
        #     (...)


4. Training the Model
---------------------

Now we're ready to train the model. We can use the existing :ref:`fairseq-train`
command-line tool for this, making sure to specify our new Task (``--task
simple_classification``) and Model architecture (``--arch
pytorch_tutorial_rnn``):

.. note::

  You can also configure the dimensionality of the hidden state by passing the
  ``--hidden-dim`` argument to :ref:`fairseq-train`.

.. code-block:: console

  > fairseq-train names-bin \
    --task simple_classification \
    --arch pytorch_tutorial_rnn \
    --optimizer adam --lr 0.001 --lr-shrink 0.5 \
    --max-tokens 1000
  (...)
  | epoch 027 | loss 1.200 | ppl 2.30 | wps 15728 | ups 119.4 | wpb 116 | bsz 116 | num_updates 3726 | lr 1.5625e-05 | gnorm 1.290 | clip 0% | oom 0 | wall 32 | train_wall 21
  | epoch 027 | valid on 'valid' subset | valid_loss 1.41304 | valid_ppl 2.66 | num_updates 3726 | best 1.41208
  | done training in 31.6 seconds

The model files should appear in the :file:`checkpoints/` directory.


5. Writing an evaluation script
-------------------------------

Finally we can write a short script to evaluate our model on new inputs. Create
a new file named :file:`eval_classifier.py` with the following contents::

    from fairseq import checkpoint_utils, data, options, tasks

    # Parse command-line arguments for generation
    parser = options.get_generation_parser(default_task='simple_classification')
    args = options.parse_args_and_arch(parser)

    # Setup task
    task = tasks.setup_task(args)

    # Load model
    print('| loading model from {}'.format(args.path))
    models, _model_args = checkpoint_utils.load_model_ensemble([args.path], task=task)
    model = models[0]

    while True:
        sentence = input('\nInput: ')

        # Tokenize into characters
        chars = ' '.join(list(sentence.strip()))
        tokens = task.source_dictionary.encode_line(
            chars, add_if_not_exist=False,
        )

        # Build mini-batch to feed to the model
        batch = data.language_pair_dataset.collate(
            samples=[{'id': -1, 'source': tokens}],  # bsz = 1
            pad_idx=task.source_dictionary.pad(),
            eos_idx=task.source_dictionary.eos(),
            left_pad_source=False,
            input_feeding=False,
        )

        # Feed batch to the model and get predictions
        preds = model(**batch['net_input'])

        # Print top 3 predictions and their log-probabilities
        top_scores, top_labels = preds[0].topk(k=3)
        for score, label_idx in zip(top_scores, top_labels):
            label_name = task.target_dictionary.string([label_idx])
            print('({:.2f})\t{}'.format(score, label_name))

Now we can evaluate our model interactively. Note that we have included the
original data path (:file:`names-bin/`) so that the dictionaries can be loaded:

.. code-block:: console

  > python eval_classifier.py names-bin --path checkpoints/checkpoint_best.pt
  | [input] dictionary: 64 types
  | [label] dictionary: 24 types
  | loading model from checkpoints/checkpoint_best.pt

  Input: Satoshi
  (-0.61) Japanese
  (-1.20) Arabic
  (-2.86) Italian

  Input: Sinbad
  (-0.30) Arabic
  (-1.76) English
  (-4.08) Russian
fairseq/docs/tutorial_simple_lstm.rst
ADDED
@@ -0,0 +1,518 @@
Tutorial: Simple LSTM
=====================

In this tutorial we will extend fairseq by adding a new
:class:`~fairseq.models.FairseqEncoderDecoderModel` that encodes a source
sentence with an LSTM and then passes the final hidden state to a second LSTM
that decodes the target sentence (without attention).

This tutorial covers:

1. **Writing an Encoder and Decoder** to encode/decode the source/target
   sentence, respectively.
2. **Registering a new Model** so that it can be used with the existing
   :ref:`Command-line tools`.
3. **Training the Model** using the existing command-line tools.
4. **Making generation faster** by modifying the Decoder to use
   :ref:`Incremental decoding`.


1. Building an Encoder and Decoder
----------------------------------

In this section we'll define a simple LSTM Encoder and Decoder. All Encoders
should implement the :class:`~fairseq.models.FairseqEncoder` interface and
Decoders should implement the :class:`~fairseq.models.FairseqDecoder` interface.
These interfaces themselves extend :class:`torch.nn.Module`, so FairseqEncoders
and FairseqDecoders can be written and used in the same ways as ordinary PyTorch
Modules.


Encoder
~~~~~~~

Our Encoder will embed the tokens in the source sentence, feed them to a
:class:`torch.nn.LSTM` and return the final hidden state. To create our encoder
save the following in a new file named :file:`fairseq/models/simple_lstm.py`::

    import torch.nn as nn
    from fairseq import utils
    from fairseq.models import FairseqEncoder

    class SimpleLSTMEncoder(FairseqEncoder):

        def __init__(
            self, args, dictionary, embed_dim=128, hidden_dim=128, dropout=0.1,
        ):
            super().__init__(dictionary)
            self.args = args

            # Our encoder will embed the inputs before feeding them to the LSTM.
            self.embed_tokens = nn.Embedding(
                num_embeddings=len(dictionary),
                embedding_dim=embed_dim,
                padding_idx=dictionary.pad(),
            )
            self.dropout = nn.Dropout(p=dropout)

            # We'll use a single-layer, unidirectional LSTM for simplicity.
            self.lstm = nn.LSTM(
                input_size=embed_dim,
                hidden_size=hidden_dim,
                num_layers=1,
                bidirectional=False,
                batch_first=True,
            )

        def forward(self, src_tokens, src_lengths):
            # The inputs to the ``forward()`` function are determined by the
            # Task, and in particular the ``'net_input'`` key in each
            # mini-batch. We discuss Tasks in the next tutorial, but for now just
            # know that *src_tokens* has shape `(batch, src_len)` and *src_lengths*
            # has shape `(batch)`.

            # Note that the source is typically padded on the left. This can be
            # configured by adding the `--left-pad-source "False"` command-line
            # argument, but here we'll make the Encoder handle either kind of
            # padding by converting everything to be right-padded.
            if self.args.left_pad_source:
                # Convert left-padding to right-padding.
                src_tokens = utils.convert_padding_direction(
                    src_tokens,
                    padding_idx=self.dictionary.pad(),
                    left_to_right=True
                )

            # Embed the source.
            x = self.embed_tokens(src_tokens)

            # Apply dropout.
            x = self.dropout(x)

            # Pack the sequence into a PackedSequence object to feed to the LSTM.
            x = nn.utils.rnn.pack_padded_sequence(x, src_lengths, batch_first=True)

            # Get the output from the LSTM.
            _outputs, (final_hidden, _final_cell) = self.lstm(x)

            # Return the Encoder's output. This can be any object and will be
            # passed directly to the Decoder.
            return {
                # this will have shape `(bsz, hidden_dim)`
                'final_hidden': final_hidden.squeeze(0),
            }

        # Encoders are required to implement this method so that we can rearrange
        # the order of the batch elements during inference (e.g., beam search).
        def reorder_encoder_out(self, encoder_out, new_order):
            """
            Reorder encoder output according to `new_order`.

            Args:
                encoder_out: output from the ``forward()`` method
                new_order (LongTensor): desired order

            Returns:
                `encoder_out` rearranged according to `new_order`
            """
            final_hidden = encoder_out['final_hidden']
            return {
                'final_hidden': final_hidden.index_select(0, new_order),
            }

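As a quick shape check (an illustrative snippet, not part of the original
tutorial), the Encoder can be exercised on a tiny hand-built batch; the
``Namespace`` stands in for the parsed command-line args::

    from argparse import Namespace
    import torch
    from fairseq.data import Dictionary

    d = Dictionary()
    for ch in 'abc':
        d.add_symbol(ch)

    encoder = SimpleLSTMEncoder(args=Namespace(left_pad_source=False), dictionary=d)
    src_tokens = torch.LongTensor([[d.index('a'), d.index('b'), d.index('c')]])
    out = encoder(src_tokens, src_lengths=torch.LongTensor([3]))
    print(out['final_hidden'].shape)  # torch.Size([1, 128])
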
Decoder
~~~~~~~

Our Decoder will predict the next word, conditioned on the Encoder's final
hidden state and an embedded representation of the previous target word -- which
is sometimes called *teacher forcing*. More specifically, we'll use a
:class:`torch.nn.LSTM` to produce a sequence of hidden states that we'll project
to the size of the output vocabulary to predict each target word.

::

    import torch
    from fairseq.models import FairseqDecoder

    class SimpleLSTMDecoder(FairseqDecoder):

        def __init__(
            self, dictionary, encoder_hidden_dim=128, embed_dim=128, hidden_dim=128,
            dropout=0.1,
        ):
            super().__init__(dictionary)

            # Our decoder will embed the inputs before feeding them to the LSTM.
            self.embed_tokens = nn.Embedding(
                num_embeddings=len(dictionary),
                embedding_dim=embed_dim,
                padding_idx=dictionary.pad(),
            )
            self.dropout = nn.Dropout(p=dropout)

            # We'll use a single-layer, unidirectional LSTM for simplicity.
            self.lstm = nn.LSTM(
                # For the first layer we'll concatenate the Encoder's final hidden
                # state with the embedded target tokens.
                input_size=encoder_hidden_dim + embed_dim,
                hidden_size=hidden_dim,
                num_layers=1,
                bidirectional=False,
            )

            # Define the output projection.
            self.output_projection = nn.Linear(hidden_dim, len(dictionary))

        # During training Decoders are expected to take the entire target sequence
        # (shifted right by one position) and produce logits over the vocabulary.
        # The *prev_output_tokens* tensor begins with the end-of-sentence symbol,
        # ``dictionary.eos()``, followed by the target sequence.
        def forward(self, prev_output_tokens, encoder_out):
            """
            Args:
                prev_output_tokens (LongTensor): previous decoder outputs of shape
                    `(batch, tgt_len)`, for teacher forcing
                encoder_out (Tensor, optional): output from the encoder, used for
                    encoder-side attention

            Returns:
                tuple:
                    - the last decoder layer's output of shape
                      `(batch, tgt_len, vocab)`
                    - the last decoder layer's attention weights of shape
                      `(batch, tgt_len, src_len)`
            """
            bsz, tgt_len = prev_output_tokens.size()

            # Extract the final hidden state from the Encoder.
            final_encoder_hidden = encoder_out['final_hidden']

            # Embed the target sequence, which has been shifted right by one
            # position and now starts with the end-of-sentence symbol.
            x = self.embed_tokens(prev_output_tokens)

            # Apply dropout.
            x = self.dropout(x)

            # Concatenate the Encoder's final hidden state to *every* embedded
            # target token.
            x = torch.cat(
                [x, final_encoder_hidden.unsqueeze(1).expand(bsz, tgt_len, -1)],
                dim=2,
            )

            # Using PackedSequence objects in the Decoder is harder than in the
            # Encoder, since the targets are not sorted in descending length order,
            # which is a requirement of ``pack_padded_sequence()``. Instead we'll
            # feed nn.LSTM directly.
            initial_state = (
                final_encoder_hidden.unsqueeze(0),  # hidden
                torch.zeros_like(final_encoder_hidden).unsqueeze(0),  # cell
            )
            output, _ = self.lstm(
                x.transpose(0, 1),  # convert to shape `(tgt_len, bsz, dim)`
                initial_state,
            )
            x = output.transpose(0, 1)  # convert to shape `(bsz, tgt_len, hidden)`

            # Project the outputs to the size of the vocabulary.
            x = self.output_projection(x)

            # Return the logits and ``None`` for the attention weights
            return x, None

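To make the teacher forcing convention concrete (an illustrative alignment, not
part of the original tutorial): if the target sentence is ``a b c </s>``, the
Decoder receives it shifted right by one position and is trained to predict the
unshifted sequence::

    # prev_output_tokens:  </s>  a     b     c
    # expected prediction: a     b     c     </s>
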
2. Registering the Model
------------------------

Now that we've defined our Encoder and Decoder we must *register* our model with
fairseq using the :func:`~fairseq.models.register_model` function decorator.
Once the model is registered we'll be able to use it with the existing
:ref:`Command-line Tools`.

All registered models must implement the
:class:`~fairseq.models.BaseFairseqModel` interface. For sequence-to-sequence
models (i.e., any model with a single Encoder and Decoder), we can instead
implement the :class:`~fairseq.models.FairseqEncoderDecoderModel` interface.

Create a small wrapper class in the same file and register it in fairseq with
the name ``'simple_lstm'``::

    from fairseq.models import FairseqEncoderDecoderModel, register_model

    # Note: the register_model "decorator" should immediately precede the
    # definition of the Model class.

    @register_model('simple_lstm')
    class SimpleLSTMModel(FairseqEncoderDecoderModel):

        @staticmethod
        def add_args(parser):
            # Models can override this method to add new command-line arguments.
            # Here we'll add some new command-line arguments to configure dropout
            # and the dimensionality of the embeddings and hidden states.
            parser.add_argument(
                '--encoder-embed-dim', type=int, metavar='N',
                help='dimensionality of the encoder embeddings',
            )
            parser.add_argument(
                '--encoder-hidden-dim', type=int, metavar='N',
                help='dimensionality of the encoder hidden state',
            )
            parser.add_argument(
                '--encoder-dropout', type=float, default=0.1,
                help='encoder dropout probability',
            )
            parser.add_argument(
                '--decoder-embed-dim', type=int, metavar='N',
                help='dimensionality of the decoder embeddings',
            )
            parser.add_argument(
                '--decoder-hidden-dim', type=int, metavar='N',
                help='dimensionality of the decoder hidden state',
            )
            parser.add_argument(
                '--decoder-dropout', type=float, default=0.1,
                help='decoder dropout probability',
            )

        @classmethod
        def build_model(cls, args, task):
            # Fairseq initializes models by calling the ``build_model()``
            # function. This provides more flexibility, since the returned model
            # instance can be of a different type than the one that was called.
            # In this case we'll just return a SimpleLSTMModel instance.

            # Initialize our Encoder and Decoder.
            encoder = SimpleLSTMEncoder(
                args=args,
                dictionary=task.source_dictionary,
                embed_dim=args.encoder_embed_dim,
                hidden_dim=args.encoder_hidden_dim,
                dropout=args.encoder_dropout,
            )
            decoder = SimpleLSTMDecoder(
                dictionary=task.target_dictionary,
                encoder_hidden_dim=args.encoder_hidden_dim,
                embed_dim=args.decoder_embed_dim,
                hidden_dim=args.decoder_hidden_dim,
                dropout=args.decoder_dropout,
            )
            model = SimpleLSTMModel(encoder, decoder)

            # Print the model architecture.
            print(model)

            return model

        # We could override the ``forward()`` if we wanted more control over how
        # the encoder and decoder interact, but it's not necessary for this
        # tutorial since we can inherit the default implementation provided by
        # the FairseqEncoderDecoderModel base class, which looks like:
        #
        # def forward(self, src_tokens, src_lengths, prev_output_tokens):
        #     encoder_out = self.encoder(src_tokens, src_lengths)
        #     decoder_out = self.decoder(prev_output_tokens, encoder_out)
        #     return decoder_out

Finally let's define a *named architecture* with the configuration for our
model. This is done with the :func:`~fairseq.models.register_model_architecture`
function decorator. Thereafter this named architecture can be used with the
``--arch`` command-line argument, e.g., ``--arch tutorial_simple_lstm``::

    from fairseq.models import register_model_architecture

    # The first argument to ``register_model_architecture()`` should be the name
    # of the model we registered above (i.e., 'simple_lstm'). The function we
    # register here should take a single argument *args* and modify it in-place
    # to match the desired architecture.

    @register_model_architecture('simple_lstm', 'tutorial_simple_lstm')
    def tutorial_simple_lstm(args):
        # We use ``getattr()`` to prioritize arguments that are explicitly given
        # on the command-line, so that the defaults defined below are only used
        # when no other value has been specified.
        args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 256)
        args.encoder_hidden_dim = getattr(args, 'encoder_hidden_dim', 256)
        args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 256)
        args.decoder_hidden_dim = getattr(args, 'decoder_hidden_dim', 256)


3. Training the Model
---------------------

Now we're ready to train the model. We can use the existing :ref:`fairseq-train`
command-line tool for this, making sure to specify our new Model architecture
(``--arch tutorial_simple_lstm``).

.. note::

  Make sure you've already preprocessed the data from the IWSLT example in the
  :file:`examples/translation/` directory.

.. code-block:: console

  > fairseq-train data-bin/iwslt14.tokenized.de-en \
    --arch tutorial_simple_lstm \
    --encoder-dropout 0.2 --decoder-dropout 0.2 \
    --optimizer adam --lr 0.005 --lr-shrink 0.5 \
    --max-tokens 12000
  (...)
  | epoch 052 | loss 4.027 | ppl 16.30 | wps 420805 | ups 39.7 | wpb 9841 | bsz 400 | num_updates 20852 | lr 1.95313e-05 | gnorm 0.218 | clip 0% | oom 0 | wall 529 | train_wall 396
  | epoch 052 | valid on 'valid' subset | valid_loss 4.74989 | valid_ppl 26.91 | num_updates 20852 | best 4.74954

The model files should appear in the :file:`checkpoints/` directory. While this
model architecture is not very good, we can use the :ref:`fairseq-generate` script to
generate translations and compute our BLEU score over the test set:

.. code-block:: console

  > fairseq-generate data-bin/iwslt14.tokenized.de-en \
    --path checkpoints/checkpoint_best.pt \
    --beam 5 \
    --remove-bpe
  (...)
  | Translated 6750 sentences (153132 tokens) in 17.3s (389.12 sentences/s, 8827.68 tokens/s)
  | Generate test with beam=5: BLEU4 = 8.18, 38.8/12.1/4.7/2.0 (BP=1.000, ratio=1.066, syslen=139865, reflen=131146)


4. Making generation faster
---------------------------

While autoregressive generation from sequence-to-sequence models is inherently
slow, our implementation above is especially slow because it recomputes the
entire sequence of Decoder hidden states for every output token (i.e., it is
``O(n^2)``). We can make this significantly faster by instead caching the
previous hidden states.

In fairseq this is called :ref:`Incremental decoding`. Incremental decoding is a
special mode at inference time where the Model only receives a single timestep
of input corresponding to the immediately previous output token (for teacher
forcing) and must produce the next output incrementally. Thus the model must
cache any long-term state that is needed about the sequence, e.g., hidden
states, convolutional states, etc.

To implement incremental decoding we will modify our model to implement the
:class:`~fairseq.models.FairseqIncrementalDecoder` interface. Compared to the
standard :class:`~fairseq.models.FairseqDecoder` interface, the incremental
decoder interface allows ``forward()`` methods to take an extra keyword argument
(*incremental_state*) that can be used to cache state across time-steps.

Let's replace our ``SimpleLSTMDecoder`` with an incremental one::

    import torch
    from fairseq.models import FairseqIncrementalDecoder

    class SimpleLSTMDecoder(FairseqIncrementalDecoder):

        def __init__(
            self, dictionary, encoder_hidden_dim=128, embed_dim=128, hidden_dim=128,
            dropout=0.1,
        ):
            # This remains the same as before.
            super().__init__(dictionary)
            self.embed_tokens = nn.Embedding(
                num_embeddings=len(dictionary),
                embedding_dim=embed_dim,
                padding_idx=dictionary.pad(),
            )
            self.dropout = nn.Dropout(p=dropout)
            self.lstm = nn.LSTM(
                input_size=encoder_hidden_dim + embed_dim,
                hidden_size=hidden_dim,
                num_layers=1,
                bidirectional=False,
            )
            self.output_projection = nn.Linear(hidden_dim, len(dictionary))

        # We now take an additional kwarg (*incremental_state*) for caching the
        # previous hidden and cell states.
        def forward(self, prev_output_tokens, encoder_out, incremental_state=None):
            if incremental_state is not None:
                # If the *incremental_state* argument is not ``None`` then we are
                # in incremental inference mode. While *prev_output_tokens* will
                # still contain the entire decoded prefix, we will only use the
                # last step and assume that the rest of the state is cached.
                prev_output_tokens = prev_output_tokens[:, -1:]

            # This remains the same as before.
            bsz, tgt_len = prev_output_tokens.size()
            final_encoder_hidden = encoder_out['final_hidden']
            x = self.embed_tokens(prev_output_tokens)
            x = self.dropout(x)
            x = torch.cat(
                [x, final_encoder_hidden.unsqueeze(1).expand(bsz, tgt_len, -1)],
                dim=2,
            )

            # We will now check the cache and load the cached previous hidden and
            # cell states, if they exist, otherwise we will initialize them to
            # zeros (as before). We will use the ``utils.get_incremental_state()``
            # and ``utils.set_incremental_state()`` helpers.
            initial_state = utils.get_incremental_state(
                self, incremental_state, 'prev_state',
            )
            if initial_state is None:
                # first time initialization, same as the original version
                initial_state = (
                    final_encoder_hidden.unsqueeze(0),  # hidden
                    torch.zeros_like(final_encoder_hidden).unsqueeze(0),  # cell
                )

            # Run one step of our LSTM.
            output, latest_state = self.lstm(x.transpose(0, 1), initial_state)

            # Update the cache with the latest hidden and cell states.
            utils.set_incremental_state(
                self, incremental_state, 'prev_state', latest_state,
            )

            # This remains the same as before
            x = output.transpose(0, 1)
            x = self.output_projection(x)
            return x, None

        # The ``FairseqIncrementalDecoder`` interface also requires implementing a
        # ``reorder_incremental_state()`` method, which is used during beam search
        # to select and reorder the incremental state.
        def reorder_incremental_state(self, incremental_state, new_order):
            # Load the cached state.
            prev_state = utils.get_incremental_state(
                self, incremental_state, 'prev_state',
            )

            # Reorder batches according to *new_order*.
            reordered_state = (
                prev_state[0].index_select(1, new_order),  # hidden
                prev_state[1].index_select(1, new_order),  # cell
            )

            # Update the cached state.
            utils.set_incremental_state(
                self, incremental_state, 'prev_state', reordered_state,
            )

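To see how the cache is threaded through the Decoder, here is a minimal greedy
decoding loop (a sketch for illustration only; in practice fairseq's sequence
generator drives this, handling beam search and state reordering). It assumes
``decoder``, ``encoder_out``, a target dictionary ``tgt_dict`` and a length
budget ``max_len`` are already in scope::

    incremental_state = {}
    tokens = torch.LongTensor([[tgt_dict.eos()]])  # generation starts with </s>
    for _ in range(max_len):
        logits, _ = decoder(tokens, encoder_out, incremental_state=incremental_state)
        next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)
        tokens = torch.cat([tokens, next_token], dim=1)
        if next_token.item() == tgt_dict.eos():
            break
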
Finally, we can rerun generation and observe the speedup:

.. code-block:: console

  # Before

  > fairseq-generate data-bin/iwslt14.tokenized.de-en \
    --path checkpoints/checkpoint_best.pt \
    --beam 5 \
    --remove-bpe
  (...)
  | Translated 6750 sentences (153132 tokens) in 17.3s (389.12 sentences/s, 8827.68 tokens/s)
  | Generate test with beam=5: BLEU4 = 8.18, 38.8/12.1/4.7/2.0 (BP=1.000, ratio=1.066, syslen=139865, reflen=131146)

  # After

  > fairseq-generate data-bin/iwslt14.tokenized.de-en \
    --path checkpoints/checkpoint_best.pt \
    --beam 5 \
    --remove-bpe
  (...)
  | Translated 6750 sentences (153132 tokens) in 5.5s (1225.54 sentences/s, 27802.94 tokens/s)
  | Generate test with beam=5: BLEU4 = 8.18, 38.8/12.1/4.7/2.0 (BP=1.000, ratio=1.066, syslen=139865, reflen=131146)
fairseq/examples/.gitignore
ADDED
@@ -0,0 +1,2 @@
!*/*.sh
!*/*.md
fairseq/examples/__init__.py
ADDED
@@ -0,0 +1,9 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

try:
    from fairseq.version import __version__  # noqa
except ImportError:
    pass
fairseq/examples/adaptive_span/README.md
ADDED
@@ -0,0 +1,90 @@
# Adaptive Span

Adaptive Span is a novel self-attention mechanism that can learn its optimal
attention span. This allows us to significantly extend the maximum context size
used in Transformers, while maintaining control over memory footprint and
computational time. It uses the Truncated BPTT technique for training,
as in [transformerXL](https://github.com/pytorch/fairseq/blob/main/examples/truncated_bptt/README.md).

Adaptive Span was introduced in the paper
[Adaptive Attention Span in Transformers](https://arxiv.org/abs/1905.07799),
which achieved state-of-the-art language modeling results at the time of publication.

We manage to reproduce their result in fairseq and keep most of the
[original implementation](https://github.com/facebookresearch/adaptive-span) untouched.
You can also refer to their sweep file if any hyperparameter combination is unclear.

##### 0. Setup

First you need to process the enwik8 dataset; we use the pre-tokenized dataset
from the [adaptive span paper](https://github.com/facebookresearch/adaptive-span/blob/master/get_data.sh).
Download the dataset, and then run:
```bash
fairseq-preprocess --only-source --trainpref ~/data/enwik8/train.txt \
    --validpref ~/data/enwik8/valid.txt --testpref ~/data/enwik8/test.txt \
    --destdir ~/data/enwik8/data-bin/ --joined-dictionary --workers 20
```

##### 1. Train an Adaptive Span model on enwik8

We will train a 12-layer Adaptive Span model following the [hyperparameters
used in the original
paper](https://github.com/facebookresearch/adaptive-span/blob/master/experiments/enwik8.sh).

The following command assumes 4 GPUs, so that the total batch size is 64
sequences (4 x 16). Training should take 2-3 days on 4 V100 GPUs:
```bash
CUDA_VISIBLE_DEVICES=0,1,2,3 fairseq-train \
    --user-dir examples/adaptive_span \
    --data ~/data/enwik8/data-bin/ \
    --fp16 --fp16-no-flatten-grads --max-update 600000 \
    --task truncated_bptt_lm --tokens-per-sample 512 --arch adaptive_span \
    --n-layer 12 --d-model 512 --n-head 8 --d-inner 2048 --dropout 0.3 \
    --attn-span 8192 --optimizer adagrad_with_grad_clip --adagrad-clip 0.03 \
    --validate-interval-updates 1000 \
    --lr-scheduler fixed --warmup-updates 32000 --batch-size-valid 32 \
    --lr 0.07 --criterion adaptive_span_loss --batch-size 16 --update-freq 1 \
    --seed 2 --log-format json --log-interval 25 --aux-loss-scaler 5e-07
```
This should land around 1.05 bpc on validation and 1.03 bpc on test. You can
lower `--aux-loss-scaler` for better performance (a longer span); it gives a
~0.03 bpc improvement over the transformerXL baseline here.
If training on a single GPU, set `--update-freq=4` to accumulate 4x gradients
and simulate training on 4 GPUs.

You can also reproduce the transformerXL result on enwik8 using this code base.
It should land around 1.06 bpc on test, matching the [original paper](https://github.com/kimiyoung/transformer-xl/blob/master/pytorch/run_enwik8_base.sh).
You can try it with:
```bash
CUDA_VISIBLE_DEVICES=0,1,2,3 fairseq-train \
    --user-dir examples/truncated_bptt \
    ~/data/enwik8/data-bin/ \
    --task truncated_bptt_lm --fp16 --max-update 400000 \
    --tokens-per-sample 512 --arch transformer_xl --n-layer 12 \
    --d-model 512 --n-head 8 --d-head 64 --d-inner 2048 --dropout 0.1 \
    --dropatt 0.0 --mem-len 512 --optimizer adam --clip-norm 0.25 \
    --lr-scheduler cosine --warmup-updates 0 \
    --min-lr 0.0 --lr 0.00025 --batch-size 15 \
    --update-freq 1 --seed 2 --log-format json --log-interval 25
```

##### 2. Evaluate
For Adaptive Span:
```bash
fairseq-eval-lm ~/data/enwik8/data-bin/ --path model/checkpoint_best.pt \
    --user-dir examples/adaptive_span \
    --task truncated_bptt_lm --batch-size 8 --tokens-per-sample 512 --gen-subset test
```
For Transformer-XL evaluation:
```bash
fairseq-eval-lm ~/data/enwik8/data-bin/ --path model/checkpoint_best.pt \
    --user-dir examples/truncated_bptt/ --task truncated_bptt_lm --batch-size 8 \
    --tokens-per-sample 80 \
    --model-overrides '{"mem_len":2100,"clamp_len":820,"same_length":True}' \
    --gen-subset valid
```

*Note:* During training the model saw 512 tokens of context
(`--tokens-per-sample=512`) with batch size 8. These settings match the
evaluation settings from [the original
paper](https://github.com/facebookresearch/adaptive-span/blob/master/experiments/enwik8.sh).
fairseq/examples/adaptive_span/__init__.py
ADDED
@@ -0,0 +1,19 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import importlib
import os

# automatically import any Python files in the current directory
cur_dir = os.path.dirname(__file__)
for file in os.listdir(cur_dir):
    path = os.path.join(cur_dir, file)
    if (
        not file.startswith("_")
        and not file.startswith(".")
        and (file.endswith(".py") or os.path.isdir(path))
    ):
        mod_name = file[: file.find(".py")] if file.endswith(".py") else file
        module = importlib.import_module(__name__ + "." + mod_name)
fairseq/examples/adaptive_span/adagrad_with_grad_clip.py
ADDED
@@ -0,0 +1,128 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from torch.optim import Adagrad

from fairseq.optim import LegacyFairseqOptimizer, register_optimizer


@register_optimizer("adagrad_with_grad_clip")
class FairseqAdagradWithGradClip(LegacyFairseqOptimizer):
    def __init__(self, args, params):
        super().__init__(args)
        self._optimizer = AdagradWithGradClip(params, **self.optimizer_config)

    @staticmethod
    def add_args(parser):
        """Add optimizer-specific arguments to the parser."""
        # fmt: off
        parser.add_argument('--weight-decay', '--wd', default=0.0, type=float, metavar='WD',
                            help='weight decay')
        parser.add_argument('--adagrad-clip', default=0.0, type=float, metavar='D',
                            help='internal grad clip')
        # fmt: on

    @property
    def optimizer_config(self):
        """
        Return a kwarg dictionary that will be used to override optimizer
        args stored in checkpoints. This allows us to load a checkpoint and
        resume training using a different set of optimizer args, e.g., with a
        different learning rate.
        """
        return {
            "lr": self.args.lr[0],
            "weight_decay": self.args.weight_decay,
            "grad_clip": self.args.adagrad_clip,
        }

    @property
    def supports_flat_params(self):
        return False

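# Rather than modifying the gradient tensor in place, the helper below rescales
# the per-update learning rate when the gradient's L2 norm exceeds the
# threshold, which has the same effect as clipping the gradient norm for that
# update.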
def _clip_grad(clr, grad, group_grad_clip):
    if group_grad_clip > 0:
        norm = grad.norm(2).item()
        if norm > group_grad_clip:
            clr *= group_grad_clip / (norm + 1e-10)
    return clr


class AdagradWithGradClip(Adagrad):
    """Adagrad algorithm with custom gradient clipping"""

    def __init__(
        self,
        params,
        lr=1e-2,
        lr_decay=0,
        weight_decay=0,
        initial_accumulator_value=0,
        grad_clip=0,
    ):
        Adagrad.__init__(
            self,
            params,
            lr=lr,
            lr_decay=lr_decay,
            weight_decay=weight_decay,
            initial_accumulator_value=initial_accumulator_value,
        )
        self.defaults["grad_clip"] = grad_clip
        self.param_groups[0].setdefault("grad_clip", grad_clip)

    def step(self, closure=None):
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None:
                    continue

                grad = p.grad.data
                state = self.state[p]

                state["step"] += 1

                if group["weight_decay"] != 0:
                    if p.grad.data.is_sparse:
                        raise RuntimeError(
                            "weight_decay option is "
                            "not compatible with sparse "
                            "gradients"
                        )
                    grad = grad.add(group["weight_decay"], p.data)

                clr = group["lr"] / (1 + (state["step"] - 1) * group["lr_decay"])

                # clip
                clr = _clip_grad(clr=clr, grad=grad, group_grad_clip=group["grad_clip"])

                if grad.is_sparse:
                    # the update is non-linear so indices must be unique
                    grad = grad.coalesce()
                    grad_indices = grad._indices()
                    grad_values = grad._values()
                    size = grad.size()

                    def make_sparse(values):
                        constructor = grad.new
                        if grad_indices.dim() == 0 or values.dim() == 0:
                            return constructor().resize_as_(grad)
                        return constructor(grad_indices, values, size)

                    state["sum"].add_(make_sparse(grad_values.pow(2)))
                    std = state["sum"]._sparse_mask(grad)
                    std_values = std._values().sqrt_().add_(1e-10)
                    p.data.add_(-clr, make_sparse(grad_values / std_values))
                else:
                    state["sum"].addcmul_(1, grad, grad)
                    std = state["sum"].sqrt().add_(1e-10)
                    p.data.addcdiv_(-clr, grad, std)

        return loss
fairseq/examples/adaptive_span/adaptive_span_attention.py
ADDED
@@ -0,0 +1,160 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class AdaptiveMask(nn.Module):
    """Soft masking function for adaptive size.
    It masks out the last K values of an input. The masking value
    goes from 1 to 0 gradually, so K can be learned with
    back-propagation.
    Args:
        max_size: maximum size (i.e. input dimension)
        ramp_size: size of the ramp going from 0 to 1
        init_val: initial size proportion not to be masked out
        shape: learn multiple sizes independent of each other
    """

    def __init__(self, max_size, ramp_size, init_val=0, shape=(1,)):
        nn.Module.__init__(self)
        self._max_size = max_size
        self._ramp_size = ramp_size
        self.current_val = nn.Parameter(torch.zeros(*shape) + init_val)
        mask_template = torch.linspace(1 - max_size, 0, steps=max_size)
        self.register_buffer("mask_template", mask_template)

    def forward(self, x):
        mask = self.mask_template.float() + self.current_val.float() * self._max_size
        mask = mask / self._ramp_size + 1
        mask = mask.clamp(0, 1)
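        # Illustration (example numbers, not from the code or the paper): with
        # max_size=8, ramp_size=2 and current_val=0.5, mask_template is
        # [-7, -6, ..., 0], so mask = clamp([-3, ..., 4] / 2 + 1, 0, 1)
        # = [0, 0, 0.5, 1, 1, 1, 1, 1]: the oldest positions are fully masked,
        # and a soft ramp of width ramp_size keeps the span differentiable
        # with respect to current_val.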
        if x.size(-1) < self._max_size:
            # the input could have been trimmed beforehand to save computation
            mask = mask.narrow(-1, self._max_size - x.size(-1), x.size(-1))
        x = (x * mask).type_as(x)
        return x

    def get_current_max_size(self, include_ramp=True):
        current_size = math.ceil(self.current_val.max().item() * self._max_size)
        if include_ramp:
            current_size += self._ramp_size
        current_size = max(0, min(self._max_size, current_size))
        return current_size

    def get_current_avg_size(self, include_ramp=True):
        current_size = math.ceil(
            self.current_val.float().mean().item() * self._max_size
        )
        if include_ramp:
            current_size += self._ramp_size
        current_size = max(0, min(self._max_size, current_size))
        return current_size

    def clamp_param(self):
        """this needs to be called after each update"""
        self.current_val.data.clamp_(0, 1)


class AdaptiveSpan(nn.Module):
    """Adaptive attention span for Transformer self-attention.
    This module learns an attention span length from data for each
    self-attention head.
    Args:
        attn_span: maximum attention span
        adapt_span_loss: loss coefficient for the span length
        adapt_span_ramp: length of the masking ramp
        adapt_span_init: initial size ratio
        adapt_span_cache: adapt cache size to reduce memory usage
    """

    def __init__(
        self,
        attn_span,
        adapt_span_ramp,
        adapt_span_init,
        n_head,
        adapt_span_layer,
        **kargs
    ):
        nn.Module.__init__(self)
        self._max_span = attn_span
        self._n_head = n_head
        self._adapt_span_layer = adapt_span_layer
        if self._adapt_span_layer:
            self._mask = AdaptiveMask(
                max_size=self._max_span,
                ramp_size=adapt_span_ramp,
                init_val=adapt_span_init,
            )
        else:
            self._mask = AdaptiveMask(
                max_size=self._max_span,
                ramp_size=adapt_span_ramp,
                init_val=adapt_span_init,
                shape=(n_head, 1, 1),
            )

    def forward(self, attn, normalize=True):
        """mask attention with the right span"""
        # batch and head dimensions are merged together, so separate them first
        self.clamp_param()
        if self._adapt_span_layer:
            attn = self._mask(attn)
        else:
            B = attn.size(0)  # batch size
            M = attn.size(1)  # block size
            attn = attn.reshape(B // self._n_head, self._n_head, M, -1)
            attn = self._mask(attn)
            attn = attn.view(B, M, -1)
        return attn

    def get_trim_len(self):
        """how much of memory can be trimmed to reduce computation"""
        L = self._max_span
        trim_len = min(L - 1, L - self._mask.get_current_max_size())
        # too fine granularity might be bad for the memory management
        trim_len = math.floor(trim_len / 64) * 64
        return trim_len

    def trim_memory(self, query, key, value, key_pe):
        """trim out unnecessary memory beforehand to reduce computation"""
        trim_len = self.get_trim_len()
        cache_size = key.size(1) - query.size(1)
|
| 128 |
+
trim_len_cache = trim_len - (self._max_span - cache_size)
|
| 129 |
+
if trim_len_cache > 0:
|
| 130 |
+
key = key[:, trim_len_cache:, :]
|
| 131 |
+
value = value[:, trim_len_cache:, :]
|
| 132 |
+
elif trim_len_cache < 0:
|
| 133 |
+
# cache is too short! this happens when validation resumes
|
| 134 |
+
# after a lot of updates.
|
| 135 |
+
key = F.pad(key, [0, 0, -trim_len_cache, 0])
|
| 136 |
+
value = F.pad(value, [0, 0, -trim_len_cache, 0])
|
| 137 |
+
if trim_len > 0:
|
| 138 |
+
if key_pe is not None:
|
| 139 |
+
key_pe = key_pe[:, :, trim_len:]
|
| 140 |
+
return key, value, key_pe
|
| 141 |
+
|
| 142 |
+
def get_cache_size(self):
|
| 143 |
+
"""determine how long the cache should be"""
|
| 144 |
+
trim_len = self.get_trim_len()
|
| 145 |
+
# give a buffer of 64 steps since a span might increase
|
| 146 |
+
# in future updates
|
| 147 |
+
return min(self._max_span, self._max_span - trim_len + 64)
|
| 148 |
+
|
| 149 |
+
def get_loss(self):
|
| 150 |
+
"""a loss term for regularizing the span length"""
|
| 151 |
+
return self._max_span * self._mask.current_val.float().mean()
|
| 152 |
+
|
| 153 |
+
def get_current_max_span(self):
|
| 154 |
+
return self._mask.get_current_max_size()
|
| 155 |
+
|
| 156 |
+
def get_current_avg_span(self):
|
| 157 |
+
return self._mask.get_current_avg_size()
|
| 158 |
+
|
| 159 |
+
def clamp_param(self):
|
| 160 |
+
self._mask.clamp_param()
|
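To make the ramp concrete, here is a standalone sketch (not part of the diff) of what `AdaptiveMask.forward` computes. With `max_size=8`, `ramp_size=4` and a learned `current_val=0.5`, the mask keeps the most recent positions at 1 and fades the oldest positions linearly toward 0, which is what makes the span length differentiable:

```python
import torch

max_size, ramp_size, current_val = 8, 4, 0.5
# index 0 is the position farthest in the past, index 7 the most recent
mask_template = torch.linspace(1 - max_size, 0, steps=max_size)  # [-7, ..., 0]
mask = (mask_template + current_val * max_size) / ramp_size + 1
mask = mask.clamp(0, 1)
print(mask)
# tensor([0.2500, 0.5000, 0.7500, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000])
```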
fairseq/examples/adaptive_span/adaptive_span_loss.py
ADDED
@@ -0,0 +1,106 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import math
from dataclasses import dataclass

import torch.nn.functional as F
from fairseq import metrics, utils
from fairseq.criterions import register_criterion
from fairseq.criterions.cross_entropy import CrossEntropyCriterion
from fairseq.dataclass import FairseqDataclass
from omegaconf import II


@dataclass
class AdaptiveSpanCriterionConfig(FairseqDataclass):
    sentence_avg: bool = II("optimization.sentence_avg")


@register_criterion("adaptive_span_loss", dataclass=AdaptiveSpanCriterionConfig)
class AdaptiveSpanCriterion(CrossEntropyCriterion):
    def __init__(self, task, sentence_avg):
        super().__init__(task, sentence_avg)

    def forward(self, model, sample, reduce=True):
        """Compute the loss for the given sample.

        Returns a tuple with three elements:
        1) the loss; here it is summed, different from the adaptive span code
        2) the sample size, which is used as the denominator for the gradient
        3) logging outputs to display while training
        """
        net_output = model(**sample["net_input"])
        loss, aux_loss, avg_span, max_span = self.compute_loss(
            model, net_output, sample, reduce=reduce
        )
        sample_size = (
            sample["target"].size(0) if self.sentence_avg else sample["ntokens"]
        )
        loss /= sample_size
        total_loss = loss + aux_loss
        sample_size = 1

        logging_output = {
            "loss": loss.data,
            "ntokens": sample["ntokens"],
            "nsentences": sample["target"].size(0),
            "sample_size": sample_size,
            "total_loss": total_loss.data,
            "avg_span": avg_span * sample_size,
            "max_span": max_span * sample_size,
        }
        return total_loss, sample_size, logging_output

    def compute_loss(self, model, net_output, sample, reduce=True):
        loss, _ = super().compute_loss(model, net_output, sample, reduce)
        aux_loss = model.get_aux_loss()
        avg_span = model.get_current_avg_span()
        max_span = model.get_current_max_span()
        return loss, aux_loss, avg_span, max_span

    @staticmethod
    def reduce_metrics(logging_outputs) -> None:
        """Aggregate logging outputs from data parallel training."""
        loss_sum = sum(log.get("loss", 0) for log in logging_outputs)
        ntokens = sum(log.get("ntokens", 0) for log in logging_outputs)
        sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)
        total_loss_sum = sum(log.get("total_loss", 0) for log in logging_outputs)
        avg_span_sum = sum(log.get("avg_span", 0) for log in logging_outputs)
        max_span_sum = sum(log.get("max_span", 0) for log in logging_outputs)

        # we divide by log(2) to convert the loss from base e to base 2
        metrics.log_scalar(
            "loss", loss_sum / sample_size / math.log(2), sample_size, round=3
        )
        metrics.log_scalar("avg_span", avg_span_sum / sample_size, sample_size, round=3)
        metrics.log_scalar("max_span", max_span_sum / sample_size, sample_size, round=3)
        # total loss contains the L1 norm on adaptive-span
        metrics.log_scalar(
            "total_loss",
            total_loss_sum / sample_size / math.log(2),
            sample_size,
            round=3,
        )
        if sample_size != ntokens:
            metrics.log_scalar(
                "nll_loss", loss_sum / ntokens / math.log(2), ntokens, round=3
            )
            metrics.log_derived(
                "ppl", lambda meters: utils.get_perplexity(meters["nll_loss"].avg)
            )
        else:
            metrics.log_derived(
                "ppl", lambda meters: utils.get_perplexity(meters["loss"].avg)
            )

    @staticmethod
    def logging_outputs_can_be_summed() -> bool:
        """
        Whether the logging outputs returned by `forward` can be summed
        across workers prior to calling `reduce_metrics`. Setting this
        to True will improve distributed training speed.
        """
        return True
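The pre-normalization in `forward` is easy to miss: the summed cross-entropy is divided by the usual sample size up front, and `sample_size` is then reported as 1 so the trainer does not divide the combined loss (per-token CE plus the span penalty) a second time. A sketch with illustrative dummy numbers, not taken from the diff:

```python
# Values below are made up purely to show the arithmetic.
summed_ce = 3510.0   # summed cross-entropy over the batch, in nats
ntokens = 512        # denominator when sentence_avg is False
aux_loss = 0.02      # adaptive-span L1 penalty, already a scalar

loss = summed_ce / ntokens     # per-token cross-entropy
total_loss = loss + aux_loss   # what is actually backpropagated
sample_size = 1                # trainer divides by this -> a no-op
print(total_loss)              # 6.87546875
```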
fairseq/examples/adaptive_span/adaptive_span_model.py
ADDED
@@ -0,0 +1,263 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import math

import torch
import torch.nn as nn
import torch.nn.functional as F

from fairseq.modules.layer_norm import LayerNorm

from .adaptive_span_attention import AdaptiveSpan

# Size notations:
# B = batch_size, H = d_model, M = block_size, L = attn_span


def _skew(X, pad_value):
    """shift every row 1 step to right"""
    # X = B x M x L
    B, M, L = X.size()
    X = F.pad(X, (0, M + 1), value=pad_value)  # B x M x (L+M+1)
    X = X.view(B, -1)  # B x ML+MM+M
    X = X[:, :-M]  # B x ML+MM
    X = X.view(B, M, M + L)  # B x M x L+M
    return X


def _unskew(X):
    """reverse _skew operation"""
    # X = B x M x L+M
    B, M, L = X.size()
    L -= M
    X = X.view(B, -1)  # B x ML+MM
    X = F.pad(X, (0, M))  # B x ML+MM+M
    X = X.view(B, M, M + L + 1)  # B x M x L+M+1
    X = X[:, :, :L]  # B x M x L
    return X


class SeqAttention(nn.Module):
    """Sequential self-attention layer.
    Each token will attend to its previous fixed number of steps.
    Note that attention doesn't include the current step itself.
    """

    def __init__(self, d_model, n_head, attn_span, dropout, adapt_span_layer, **kargs):
        nn.Module.__init__(self)
        self.dropout = nn.Dropout(dropout)
        self.d_model = d_model  # size of a single head
        self.attn_span = attn_span
        self.adaptive_span = AdaptiveSpan(
            attn_span=attn_span,
            n_head=n_head,
            adapt_span_layer=adapt_span_layer,
            **kargs
        )

    def forward(self, query, key, value, key_pe):
        # query size = B x M x H
        # key, value sizes = B x (M+L) x H

        key, value, key_pe = self.adaptive_span.trim_memory(query, key, value, key_pe)

        # compute attention from context
        # B x M (dest) x (M+L) (src)
        attn_cont = torch.matmul(query, key.transpose(-1, -2))
        attn_cont = _unskew(attn_cont)  # B x M x L

        # compute the effect of position embedding
        attn_pos = torch.matmul(query, key_pe)  # B x M x L_pos
        attn = attn_cont + attn_pos

        attn = attn / math.sqrt(self.d_model)  # B x M X L_pos

        attn = F.softmax(attn.float(), dim=-1).type_as(attn)

        # trim attention lengths according to the learned span
        attn = self.adaptive_span(attn)

        attn = self.dropout(attn)  # B x M X L_pos

        attn_cont = _skew(attn, 0)  # B x M X (L+M)
        out = torch.matmul(attn_cont, value)  # B x M x H
        return out

    def get_cache_size(self):
        return self.adaptive_span.get_cache_size()


class MultiHeadSeqAttention(nn.Module):
    def __init__(self, d_model, n_head, **kargs):
        nn.Module.__init__(self)
        assert d_model % n_head == 0
        self.n_head = n_head
        self.head_dim = d_model // n_head
        self.attn = SeqAttention(d_model=self.head_dim, n_head=n_head, **kargs)
        self.proj_query = nn.Linear(d_model, d_model, bias=False)
        nn.init.xavier_normal_(self.proj_query.weight)
        self.proj_out = nn.Linear(d_model, d_model, bias=False)
        nn.init.xavier_normal_(self.proj_out.weight)
        self.proj_val = nn.Linear(d_model, d_model, bias=False)
        nn.init.xavier_normal_(self.proj_val.weight)
        self.proj_key = nn.Linear(d_model, d_model, bias=False)
        nn.init.xavier_normal_(self.proj_key.weight)

    def head_reshape(self, x):
        K = self.n_head
        D = self.head_dim
        x = x.view(x.size()[:-1] + (K, D))  # B x (M+L) x K x D
        x = x.transpose(1, 2).contiguous()  # B x K x (M+L) x D
        x = x.view(-1, x.size(-2), x.size(-1))  # B_K x (M+L) x D
        return x

    def forward(self, query, key, value, key_pe):
        B = query.size(0)
        K = self.n_head
        D = self.head_dim
        M = query.size(1)

        query = self.proj_query(query)
        query = self.head_reshape(query)
        value = self.proj_val(value)
        value = self.head_reshape(value)
        key = self.proj_key(key)
        key = self.head_reshape(key)

        out = self.attn(query, key, value, key_pe)  # B_K x M x D
        out = out.view(B, K, M, D)  # B x K x M x D
        out = out.transpose(1, 2).contiguous()  # B x M x K x D
        out = out.view(B, M, -1)  # B x M x K_D
        out = self.proj_out(out)
        return out


class FeedForwardLayer(nn.Module):
    def __init__(self, d_model, d_inner, dropout, **kargs):
        nn.Module.__init__(self)
        self.fc1 = nn.Linear(d_model, d_inner)
        self.fc2 = nn.Linear(d_inner, d_model)
        nn.init.xavier_uniform_(self.fc1.weight)
        nn.init.xavier_uniform_(self.fc2.weight)
        self.dropout = nn.Dropout(dropout)

    def forward(self, h):
        h1 = F.relu(self.fc1(h))
        h1 = self.dropout(h1)
        h2 = self.fc2(h1)
        return h2


class TransformerSeqLayer(nn.Module):
    def __init__(self, d_model, **kargs):
        nn.Module.__init__(self)
        self.attn = MultiHeadSeqAttention(d_model=d_model, **kargs)
        self.norm1 = LayerNorm(d_model)
        self.ff = FeedForwardLayer(d_model=d_model, **kargs)
        self.norm2 = LayerNorm(d_model)

    def forward(self, h, h_cache, key_pe):
        # h = B x M x H
        # h_cache = B x L x H
        h_all = torch.cat([h_cache, h], dim=1)  # B x (M+L) x H
        attn_out = self.attn(h, h_all, h_all, key_pe)
        h = self.norm1(h + attn_out)  # B x M x H
        if self.ff is not None:
            ff_out = self.ff(h)
            out = self.norm2(h + ff_out)  # B x M x H
        else:
            out = h
        return out

    def get_cache_size(self):
        return self.attn.attn.get_cache_size()


class TransformerSeq(nn.Module):
    def __init__(
        self,
        vocab_size,
        d_model,
        n_head,
        n_layer,
        attn_span,
        emb_dropout,
        aux_loss_scaler,
        adapt_span_layer,
        **kargs
    ):
        nn.Module.__init__(self)
        # token embeddings
        self.in_emb = nn.Embedding(vocab_size, d_model)
        nn.init.normal_(self.in_emb.weight, mean=0, std=d_model ** -0.5)
        self.out_emb = nn.Linear(d_model, vocab_size)
        self.aux_loss_scaler = aux_loss_scaler
        if emb_dropout > 0:
            self.emb_dropout = nn.Dropout(emb_dropout)
        else:
            self.emb_dropout = None
        # position embeddings
        self.key_pe = nn.Parameter(torch.randn(1, d_model // n_head, attn_span))

        self.layers = nn.ModuleList()
        self.layers.extend(
            TransformerSeqLayer(
                d_model=d_model,
                n_head=n_head,
                attn_span=attn_span,
                adapt_span_layer=adapt_span_layer,
                **kargs
            )
            for _ in range(n_layer)
        )

    def forward(self, x, h_cache, target=None):
        # x size = B x M
        block_size = x.size(1)
        h = self.in_emb(x)  # B x M x H
        if self.emb_dropout is not None:
            h = self.emb_dropout(h)

        h_cache_next = []
        for l, layer in enumerate(self.layers):
            cache_size = layer.attn.attn.get_cache_size()
            if cache_size > block_size:
                h_cache_next_l = torch.cat(
                    [h_cache[l][:, -cache_size + block_size :, :], h], dim=1
                ).detach()
            else:
                h_cache_next_l = h[:, -cache_size:, :].detach()
            h_cache_next.append(h_cache_next_l)
            h = layer(h, h_cache[l], self.key_pe)  # B x M x H

        if self.emb_dropout is not None:
            h = self.emb_dropout(h)

        out = F.log_softmax(self.out_emb(h).float(), dim=-1).type_as(h)
        dummy_loss = None

        return out, h_cache_next, dummy_loss

    def get_aux_loss(self):
        loss = 0.0
        for layer in self.layers:
            loss += layer.attn.attn.adaptive_span.get_loss()
        return self.aux_loss_scaler * loss

    def get_current_max_span(self):
        max_span = 0.0
        for layer in self.layers:
            max_span = max(
                max_span, layer.attn.attn.adaptive_span.get_current_max_span()
            )
        return max_span

    def get_current_avg_span(self):
        avg_span = 0.0
        for layer in self.layers:
            avg_span += layer.attn.attn.adaptive_span.get_current_avg_span()
        return avg_span / len(self.layers)
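The `_skew`/`_unskew` pair is the trick that lets every query row attend to a sliding window with a single matmul: `_skew` shifts row `i` right by `i` positions so that each column lines up the same absolute source position across query rows, and `_unskew` inverts it. A self-contained sketch (the helper is restated here so the snippet runs on its own; it is not an extra copy in the diff) on a tiny tensor:

```python
import torch
import torch.nn.functional as F

def _skew(X, pad_value):
    # shift row i of each B x M x L slice right by i positions
    B, M, L = X.size()
    X = F.pad(X, (0, M + 1), value=pad_value)  # B x M x (L+M+1)
    X = X.view(B, -1)[:, :-M]                  # drop the trailing M pads
    return X.view(B, M, M + L)                 # B x M x (L+M)

X = torch.tensor([[[1, 2], [3, 4]]])  # B=1, M=2, L=2
print(_skew(X, 0))
# tensor([[[1, 2, 0, 0],
#          [0, 3, 4, 0]]])
```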
fairseq/examples/adaptive_span/adaptive_span_model_wrapper.py
ADDED
@@ -0,0 +1,145 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import logging
from dataclasses import dataclass
from typing import Dict, List, Optional

import torch
from fairseq.dataclass import FairseqDataclass
from fairseq.models import (
    FairseqIncrementalDecoder,
    FairseqLanguageModel,
    register_model,
)
from .adaptive_span_model import TransformerSeq as AdaptiveSpanTransformerModel


logger = logging.getLogger(__name__)


@dataclass
class AdaptiveSpanSmallConfig(FairseqDataclass):
    # defaults come from https://github.com/facebookresearch/adaptive-span/blob/master/experiments/enwik8_small.sh
    vocab_size: int = 50
    d_model: int = 256
    n_head: int = 4
    d_inner: int = 1024
    n_layer: int = 8
    attn_span: int = 1024
    dropout: float = 0.0
    emb_dropout: float = 0.0
    adapt_span_ramp: int = 32
    adapt_span_init: float = 0.0
    aux_loss_scaler: float = 0.000002
    adapt_span_layer: bool = False


@register_model("adaptive_span", dataclass=AdaptiveSpanSmallConfig)
class AdaptiveSpanTransformer(FairseqLanguageModel):
    @classmethod
    def build_model(cls, cfg: AdaptiveSpanSmallConfig, task):
        return cls(AdaptiveSpanDecoder(cfg, task))

    def get_aux_loss(self):
        return self.decoder.get_aux_loss()

    def get_current_max_span(self):
        return self.decoder.get_current_max_span()

    def get_current_avg_span(self):
        return self.decoder.get_current_avg_span()


class AdaptiveSpanDecoder(FairseqIncrementalDecoder):
    def __init__(self, cfg, task):
        super().__init__(task.target_dictionary)

        self.config = cfg
        config = AdaptiveSpanSmallConfig(
            vocab_size=len(task.target_dictionary),
            d_model=cfg.d_model,
            n_head=cfg.n_head,
            d_inner=cfg.d_inner,
            n_layer=cfg.n_layer,
            attn_span=cfg.attn_span,
            dropout=cfg.dropout,
            emb_dropout=cfg.emb_dropout,
            adapt_span_ramp=cfg.adapt_span_ramp,
            adapt_span_init=cfg.adapt_span_init,
            aux_loss_scaler=cfg.aux_loss_scaler,
            adapt_span_layer=cfg.adapt_span_layer,
        )
        logger.info(config)
        self.model = AdaptiveSpanTransformerModel(**config.__dict__)

        self._mems = None

    def forward(
        self,
        src_tokens,
        incremental_state: Optional[Dict[str, List[torch.Tensor]]] = None,
        encoder_out=None,
    ):
        bsz = src_tokens.size(0)
        if incremental_state is not None:  # used during inference
            mems = self.get_incremental_state("mems")
            src_tokens = src_tokens[:, -1:]  # only keep the most recent token
        else:
            mems = self._mems

        if mems is None:
            # first time init
            mems = self.init_hid_cache(bsz)
        output = self.model(x=src_tokens, h_cache=mems)
        if incremental_state is not None:
            self.set_incremental_state(incremental_state, "mems", output[1])
        else:
            self._mems = output[1]
        return (output[0],)

    def max_positions(self):
        return self.config.attn_span

    def init_hid_cache(self, batch_sz):
        hid = []
        for layer in self.model.layers:
            param = next(self.model.parameters())
            h = torch.zeros(
                batch_sz,
                layer.get_cache_size(),
                self.config.d_model,
                dtype=param.dtype,
                device=param.device,
            )
            hid.append(h)
        return hid

    def get_aux_loss(self):
        return self.model.get_aux_loss()

    def get_current_max_span(self):
        return self.model.get_current_max_span()

    def get_current_avg_span(self):
        return self.model.get_current_avg_span()

    def reorder_incremental_state(
        self,
        incremental_state: Dict[str, Dict[str, Optional[torch.Tensor]]],
        new_order: torch.Tensor,
    ):
        """Reorder incremental state.

        This will be called when the order of the input has changed from the
        previous time step. A typical use case is beam search, where the input
        order changes between time steps based on the selection of beams.
        """
        raise NotImplementedError("This is required for generation/beam search")
        # mems = self.get_incremental_state(incremental_state, "mems")
        # if mems is not None:
        #     new_mems = [mems_i.index_select(1, new_order) for mems_i in mems]
        #     self.set_incremental_state(incremental_state, "mems", new_mems)
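The wrapper above only adapts the pure-PyTorch `TransformerSeq` to fairseq's model API. As a minimal sanity-check sketch (this assumes fairseq is installed and the repository root is on `sys.path`, so the import path below is an assumption, not something the diff guarantees), one can run a single forward pass with zero-initialized hidden caches, mirroring what `init_hid_cache` does:

```python
import torch
from examples.adaptive_span.adaptive_span_model import TransformerSeq

model = TransformerSeq(
    vocab_size=50, d_model=64, n_head=4, d_inner=128, n_layer=2,
    attn_span=128, dropout=0.0, emb_dropout=0.0,
    adapt_span_ramp=32, adapt_span_init=0.0,
    aux_loss_scaler=2e-6, adapt_span_layer=False,
)
x = torch.randint(0, 50, (2, 16))  # batch B=2, block size M=16
# one zero cache per layer, sized by the layer's current cache size
h_cache = [torch.zeros(2, layer.get_cache_size(), 64) for layer in model.layers]
out, h_cache_next, _ = model(x, h_cache)
print(out.shape)             # torch.Size([2, 16, 50]) -- log-probs per token
print(model.get_aux_loss())  # span penalty; 0 here since adapt_span_init=0
```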
fairseq/examples/adaptive_span/truncated_bptt_lm_task.py
ADDED
@@ -0,0 +1,281 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import logging
import os
from dataclasses import dataclass, field
from typing import List, Optional, Tuple

import torch
from fairseq import utils
from fairseq.data import (
    Dictionary,
    TokenBlockDataset,
    data_utils,
    iterators,
)
from fairseq.dataclass import FairseqDataclass
from fairseq.distributed import utils as dist_utils
from fairseq.tasks import FairseqTask, register_task
from omegaconf import II


logger = logging.getLogger(__name__)


@dataclass
class TruncatedBPTTLMConfig(FairseqDataclass):
    data: str = field(default="???", metadata={"help": "path to data directory"})
    tokens_per_sample: int = field(
        default=1024,
        metadata={"help": "max number of tokens per sequence"},
    )
    batch_size: int = II("dataset.batch_size")
    # Some models use *max_target_positions* to know how many positional
    # embeddings to learn. We use II(...) to make it default to
    # *tokens_per_sample*, but in principle there could be more positional
    # embeddings than tokens in a single batch. This may also be irrelevant for
    # custom model implementations.
    max_target_positions: int = II("task.tokens_per_sample")
    # these will be populated automatically if not provided
    data_parallel_rank: Optional[int] = None
    data_parallel_size: Optional[int] = None


@register_task("truncated_bptt_lm", dataclass=TruncatedBPTTLMConfig)
class TruncatedBPTTLMTask(FairseqTask):
    def __init__(self, cfg: TruncatedBPTTLMConfig):
        super().__init__(cfg)

        if cfg.data_parallel_rank is None or cfg.data_parallel_size is None:
            if torch.distributed.is_initialized():
                cfg.data_parallel_rank = dist_utils.get_data_parallel_rank()
                cfg.data_parallel_size = dist_utils.get_data_parallel_world_size()
            else:
                cfg.data_parallel_rank = 0
                cfg.data_parallel_size = 1

        # load the dictionary
        paths = utils.split_paths(cfg.data)
        assert len(paths) > 0
        self.dictionary = Dictionary.load(os.path.join(paths[0], "dict.txt"))
        logger.info("dictionary: {} types".format(len(self.dictionary)))

    def load_dataset(self, split, epoch=1, combine=False, **kwargs):
        """Load a given dataset split (e.g., train, valid, test)"""

        # support sharded datasets
        paths = utils.split_paths(self.cfg.data)
        assert len(paths) > 0
        data_path = paths[(epoch - 1) % len(paths)]
        split_path = os.path.join(data_path, split)

        # each element of *data* will be a tensorized line from the original
        # text dataset, similar to ``open(split_path).readlines()``
        data = data_utils.load_indexed_dataset(
            split_path, self.dictionary, combine=combine
        )
        if data is None:
            raise FileNotFoundError(
                "Dataset not found: {} ({})".format(split, split_path)
            )

        # this is similar to ``data.view(-1).split(tokens_per_sample)``
        data = TokenBlockDataset(
            data,
            data.sizes,
            block_size=self.cfg.tokens_per_sample,
            pad=None,  # unused
            eos=None,  # unused
            break_mode="none",
        )

        self.datasets[split] = TruncatedBPTTDataset(
            data=data,
            bsz_per_shard=self.cfg.batch_size,
            shard_id=self.cfg.data_parallel_rank,
            num_shards=self.cfg.data_parallel_size,
        )

    def dataset(self, split):
        return self.datasets[split]

    def get_batch_iterator(
        self, dataset, num_workers=0, epoch=1, data_buffer_size=0, **kwargs
    ):
        return iterators.EpochBatchIterator(
            dataset=dataset,
            collate_fn=self._collate_fn,
            num_workers=num_workers,
            epoch=epoch,
            buffer_size=data_buffer_size,
            # we don't use the batching functionality from EpochBatchIterator;
            # instead every item in *dataset* is a whole batch
            batch_sampler=[[i] for i in range(len(dataset))],
            disable_shuffling=True,
        )

    def _collate_fn(self, items: List[List[torch.Tensor]]):
        # we don't use fairseq's batching functionality, so we expect a single
        # item (a whole batch) of type List[torch.Tensor]
        assert len(items) == 1

        # item will have shape B x T (the last batch may have length < T)
        id, item = items[0]
        item = data_utils.collate_tokens(item, pad_idx=self.source_dictionary.pad())
        B, T = item.size()

        # shift item one position over and append a padding token for the target
        target = torch.nn.functional.pad(
            item[:, 1:], (0, 1, 0, 0), value=self.target_dictionary.pad()
        )

        # fairseq expects batches to have the following structure
        return {
            "id": torch.tensor([id] * item.size(0)),
            "net_input": {
                "src_tokens": item,
            },
            "target": target,
            "nsentences": item.size(0),
            "ntokens": item.numel(),
        }

    def build_dataset_for_inference(
        self, src_tokens: List[torch.Tensor], src_lengths: List[int], **kwargs
    ) -> torch.utils.data.Dataset:
        eos = self.source_dictionary.eos()
        dataset = TokenBlockDataset(
            src_tokens,
            src_lengths,
            block_size=None,  # ignored for "eos" break mode
            pad=self.source_dictionary.pad(),
            eos=eos,
            break_mode="eos",
        )

        class Dataset(torch.utils.data.Dataset):
            def __getitem__(self, i):
                item = dataset[i]
                if item[-1] == eos:
                    # remove eos to support generating with a prefix
                    item = item[:-1]
                return (i, [item])

            def __len__(self):
                return len(dataset)

        return Dataset()

    def inference_step(
        self, generator, models, sample, prefix_tokens=None, constraints=None
    ):
        with torch.no_grad():
            if constraints is not None:
                raise NotImplementedError

            # SequenceGenerator doesn't use *src_tokens* directly, we need to
            # pass the *prefix_tokens* argument instead.
            if prefix_tokens is None and sample["net_input"]["src_tokens"].nelement():
                prefix_tokens = sample["net_input"]["src_tokens"]

            # begin generation with the end-of-sentence token
            bos_token = self.source_dictionary.eos()

            return generator.generate(
                models, sample, prefix_tokens=prefix_tokens, bos_token=bos_token
            )

    def eval_lm_dataloader(
        self,
        dataset,
        max_tokens: Optional[int] = 36000,
        batch_size: Optional[int] = None,
        max_positions: Optional[int] = None,
        num_shards: int = 1,
        shard_id: int = 0,
        num_workers: int = 1,
        data_buffer_size: int = 10,
        context_window: int = 0,
    ):
        if context_window > 0:
            raise NotImplementedError(
                "Transformer-XL doesn't need --context-window, try "
                "--model-overrides '{\"mem_len\":42}' instead "
            )
        return self.get_batch_iterator(
            dataset=dataset,
            max_tokens=max_tokens,
            max_sentences=batch_size,
            max_positions=max_positions,
            ignore_invalid_inputs=True,
            num_shards=num_shards,
            shard_id=shard_id,
            num_workers=num_workers,
            data_buffer_size=data_buffer_size,
        ).next_epoch_itr(shuffle=False)

    @property
    def source_dictionary(self):
        return self.dictionary

    @property
    def target_dictionary(self):
        return self.dictionary


class TruncatedBPTTDataset(torch.utils.data.Dataset):
    def __init__(
        self,
        data: List[torch.Tensor],  # ordered list of items
        bsz_per_shard,  # number of items processed per GPU per forward pass
        shard_id,  # current GPU ID
        num_shards,  # number of GPUs
    ):
        super().__init__()
        self.data = data

        def batchify(data, bsz):
            # Work out how cleanly we can divide the dataset into bsz parts.
            nbatch = data.size(0) // bsz
            # Trim off any extra elements that wouldn't cleanly fit (remainders).
            data = data.narrow(0, 0, nbatch * bsz)
            # Evenly divide the data across the bsz batches.
            data = data.view(bsz, -1).contiguous()
            return data

        # total number of sequences processed by all GPUs in each forward pass
        global_batch_size = bsz_per_shard * num_shards

        """
        With a 16 item dataset, bsz_per_shard=2 and num_shards=3,
        *indices* might look like:

            indices = [[0, 1],
                       [2, 3],
                       [4, 5],
                       [6, 7],
                       [8, 9],
                       [10, 11]]

        The size of the TruncatedBPTTDataset instance will be 2,
        and shard 1 will see items:

            [(0, [data[4], data[6]]),
             (1, [data[5], data[7]])]
        """
        indices = batchify(torch.arange(len(data)), global_batch_size)
        assert indices.size(0) == global_batch_size

        self.my_indices = indices[
            shard_id * bsz_per_shard : (shard_id + 1) * bsz_per_shard
        ]
        assert self.my_indices.size(0) == bsz_per_shard

    def __len__(self):
        return self.my_indices.size(1)

    def __getitem__(self, i) -> Tuple[int, List[torch.Tensor]]:
        return (i, [self.data[idx] for idx in self.my_indices[:, i]])
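The sharding logic in `TruncatedBPTTDataset` is worth seeing run. This sketch (not part of the diff) reproduces the docstring example above, 16 items with `bsz_per_shard=2` and `num_shards=3`: `batchify` keeps 12 indices and reshapes them so that column `i` holds the `i`-th consecutive block for every stream, and each shard slices out its own rows:

```python
import torch

indices = torch.arange(16)
global_bsz = 2 * 3                              # bsz_per_shard * num_shards
nbatch = indices.size(0) // global_bsz          # 2
indices = indices.narrow(0, 0, nbatch * global_bsz).view(global_bsz, -1)
print(indices.tolist())
# [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9], [10, 11]]

shard_id, bsz_per_shard = 1, 2
my = indices[shard_id * bsz_per_shard : (shard_id + 1) * bsz_per_shard]
print([my[:, i].tolist() for i in range(my.size(1))])
# [[4, 6], [5, 7]] -> shard 1 yields (0, [data[4], data[6]]) then (1, [data[5], data[7]])
```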
fairseq/examples/backtranslation/README.md
ADDED
@@ -0,0 +1,297 @@
# Understanding Back-Translation at Scale (Edunov et al., 2018)

This page includes pre-trained models from the paper [Understanding Back-Translation at Scale (Edunov et al., 2018)](https://arxiv.org/abs/1808.09381).

## Pre-trained models

Model | Description | Dataset | Download
---|---|---|---
`transformer.wmt18.en-de` | Transformer <br> ([Edunov et al., 2018](https://arxiv.org/abs/1808.09381)) <br> WMT'18 winner | [WMT'18 English-German](http://www.statmt.org/wmt18/translation-task.html) | [download (.tar.gz)](https://dl.fbaipublicfiles.com/fairseq/models/wmt18.en-de.ensemble.tar.gz) <br> See NOTE in the archive

## Example usage (torch.hub)

We require a few additional Python dependencies for preprocessing:
```bash
pip install subword_nmt sacremoses
```

Then to generate translations from the full model ensemble:
```python
import torch

# List available models
torch.hub.list('pytorch/fairseq')  # [..., 'transformer.wmt18.en-de', ... ]

# Load the WMT'18 En-De ensemble
en2de_ensemble = torch.hub.load(
    'pytorch/fairseq', 'transformer.wmt18.en-de',
    checkpoint_file='wmt18.model1.pt:wmt18.model2.pt:wmt18.model3.pt:wmt18.model4.pt:wmt18.model5.pt',
    tokenizer='moses', bpe='subword_nmt')

# The ensemble contains 5 models
len(en2de_ensemble.models)
# 5

# Translate
en2de_ensemble.translate('Hello world!')
# 'Hallo Welt!'
```

## Training your own model (WMT'18 English-German)

The following instructions can be adapted to reproduce the models from the paper.


#### Step 1. Prepare parallel data and optionally train a baseline (English-German) model

First download and preprocess the data:
```bash
# Download and prepare the data
cd examples/backtranslation/
bash prepare-wmt18en2de.sh
cd ../..

# Binarize the data
TEXT=examples/backtranslation/wmt18_en_de
fairseq-preprocess \
    --joined-dictionary \
    --source-lang en --target-lang de \
    --trainpref $TEXT/train --validpref $TEXT/valid --testpref $TEXT/test \
    --destdir data-bin/wmt18_en_de --thresholdtgt 0 --thresholdsrc 0 \
    --workers 20

# Copy the BPE code into the data-bin directory for future use
cp examples/backtranslation/wmt18_en_de/code data-bin/wmt18_en_de/code
```

(Optionally) Train a baseline model (English-German) using just the parallel data:
```bash
CHECKPOINT_DIR=checkpoints_en_de_parallel
fairseq-train --fp16 \
    data-bin/wmt18_en_de \
    --source-lang en --target-lang de \
    --arch transformer_wmt_en_de_big --share-all-embeddings \
    --dropout 0.3 --weight-decay 0.0 \
    --criterion label_smoothed_cross_entropy --label-smoothing 0.1 \
    --optimizer adam --adam-betas '(0.9, 0.98)' --clip-norm 0.0 \
    --lr 0.001 --lr-scheduler inverse_sqrt --warmup-updates 4000 \
    --max-tokens 3584 --update-freq 16 \
    --max-update 30000 \
    --save-dir $CHECKPOINT_DIR
# Note: the above command assumes 8 GPUs. Adjust `--update-freq` if you have a
# different number of GPUs.
```

Average the last 10 checkpoints:
```bash
python scripts/average_checkpoints.py \
    --inputs $CHECKPOINT_DIR \
    --num-epoch-checkpoints 10 \
    --output $CHECKPOINT_DIR/checkpoint.avg10.pt
```

Evaluate BLEU:
```bash
# tokenized BLEU on newstest2017:
bash examples/backtranslation/tokenized_bleu.sh \
    wmt17 \
    en-de \
    data-bin/wmt18_en_de \
    data-bin/wmt18_en_de/code \
    $CHECKPOINT_DIR/checkpoint.avg10.pt
# BLEU4 = 29.57, 60.9/35.4/22.9/15.5 (BP=1.000, ratio=1.014, syslen=63049, reflen=62152)
# compare to 29.46 in Table 1, which is also for tokenized BLEU

# generally it's better to report (detokenized) sacrebleu though:
bash examples/backtranslation/sacrebleu.sh \
    wmt17 \
    en-de \
    data-bin/wmt18_en_de \
    data-bin/wmt18_en_de/code \
    $CHECKPOINT_DIR/checkpoint.avg10.pt
# BLEU+case.mixed+lang.en-de+numrefs.1+smooth.exp+test.wmt17+tok.13a+version.1.4.3 = 29.0 60.6/34.7/22.4/14.9 (BP = 1.000 ratio = 1.013 hyp_len = 62099 ref_len = 61287)
```


#### Step 2. Back-translate monolingual German data

Train a reverse model (German-English) to do the back-translation:
```bash
CHECKPOINT_DIR=checkpoints_de_en_parallel
fairseq-train --fp16 \
    data-bin/wmt18_en_de \
    --source-lang de --target-lang en \
    --arch transformer_wmt_en_de_big --share-all-embeddings \
    --dropout 0.3 --weight-decay 0.0 \
    --criterion label_smoothed_cross_entropy --label-smoothing 0.1 \
    --optimizer adam --adam-betas '(0.9, 0.98)' --clip-norm 0.0 \
    --lr 0.001 --lr-scheduler inverse_sqrt --warmup-updates 4000 \
    --max-tokens 3584 --update-freq 16 \
    --max-update 30000 \
    --save-dir $CHECKPOINT_DIR
# Note: the above command assumes 8 GPUs. Adjust `--update-freq` if you have a
# different number of GPUs.
```

Let's evaluate the back-translation (BT) model to make sure it is well trained:
```bash
bash examples/backtranslation/sacrebleu.sh \
    wmt17 \
    de-en \
    data-bin/wmt18_en_de \
    data-bin/wmt18_en_de/code \
    $CHECKPOINT_DIR/checkpoint_best.pt
# BLEU+case.mixed+lang.de-en+numrefs.1+smooth.exp+test.wmt17+tok.13a+version.1.4.3 = 34.9 66.9/41.8/28.5/19.9 (BP = 0.983 ratio = 0.984 hyp_len = 63342 ref_len = 64399)
# compare to the best system from WMT'17 which scored 35.1: http://matrix.statmt.org/matrix/systems_list/1868
```

Next prepare the monolingual data:
```bash
# Download and prepare the monolingual data
# By default the script samples 25M monolingual sentences, which after
# deduplication should be just over 24M sentences. These are split into 25
# shards, each with 1M sentences (except for the last shard).
cd examples/backtranslation/
bash prepare-de-monolingual.sh
cd ../..

# Binarize each shard of the monolingual data
TEXT=examples/backtranslation/wmt18_de_mono
for SHARD in $(seq -f "%02g" 0 24); do \
    fairseq-preprocess \
        --only-source \
        --source-lang de --target-lang en \
        --joined-dictionary \
        --srcdict data-bin/wmt18_en_de/dict.de.txt \
        --testpref $TEXT/bpe.monolingual.dedup.${SHARD} \
        --destdir data-bin/wmt18_de_mono/shard${SHARD} \
        --workers 20; \
    cp data-bin/wmt18_en_de/dict.en.txt data-bin/wmt18_de_mono/shard${SHARD}/; \
done
```

Now we're ready to perform back-translation over the monolingual data. The
following command generates via sampling, but it's possible to use greedy
decoding (`--beam 1`), beam search (`--beam 5`),
top-k sampling (`--sampling --beam 1 --sampling-topk 10`), etc.:
```bash
mkdir backtranslation_output
for SHARD in $(seq -f "%02g" 0 24); do \
    fairseq-generate --fp16 \
        data-bin/wmt18_de_mono/shard${SHARD} \
        --path $CHECKPOINT_DIR/checkpoint_best.pt \
        --skip-invalid-size-inputs-valid-test \
        --max-tokens 4096 \
        --sampling --beam 1 \
    > backtranslation_output/sampling.shard${SHARD}.out; \
done
```

After BT, use the `extract_bt_data.py` script to re-combine the shards, extract
the back-translations and apply length ratio filters:
```bash
python examples/backtranslation/extract_bt_data.py \
    --minlen 1 --maxlen 250 --ratio 1.5 \
    --output backtranslation_output/bt_data --srclang en --tgtlang de \
    backtranslation_output/sampling.shard*.out

# Ensure lengths are the same:
# wc -l backtranslation_output/bt_data.{en,de}
#   21795614 backtranslation_output/bt_data.en
#   21795614 backtranslation_output/bt_data.de
#   43591228 total
```

Binarize the filtered BT data and combine it with the parallel data:
```bash
TEXT=backtranslation_output
fairseq-preprocess \
    --source-lang en --target-lang de \
    --joined-dictionary \
    --srcdict data-bin/wmt18_en_de/dict.en.txt \
    --trainpref $TEXT/bt_data \
    --destdir data-bin/wmt18_en_de_bt \
    --workers 20

# We want to train on the combined data, so we'll symlink the parallel + BT data
# in the wmt18_en_de_para_plus_bt directory. We link the parallel data as "train"
# and the BT data as "train1", so that fairseq will combine them automatically
# and so that we can use the `--upsample-primary` option to upsample the
# parallel data (if desired).
PARA_DATA=$(readlink -f data-bin/wmt18_en_de)
BT_DATA=$(readlink -f data-bin/wmt18_en_de_bt)
COMB_DATA=data-bin/wmt18_en_de_para_plus_bt
mkdir -p $COMB_DATA
for LANG in en de; do \
    ln -s ${PARA_DATA}/dict.$LANG.txt ${COMB_DATA}/dict.$LANG.txt; \
    for EXT in bin idx; do \
        ln -s ${PARA_DATA}/train.en-de.$LANG.$EXT ${COMB_DATA}/train.en-de.$LANG.$EXT; \
        ln -s ${BT_DATA}/train.en-de.$LANG.$EXT ${COMB_DATA}/train1.en-de.$LANG.$EXT; \
        ln -s ${PARA_DATA}/valid.en-de.$LANG.$EXT ${COMB_DATA}/valid.en-de.$LANG.$EXT; \
        ln -s ${PARA_DATA}/test.en-de.$LANG.$EXT ${COMB_DATA}/test.en-de.$LANG.$EXT; \
    done; \
done
```


#### Step 3. Train an English-German model over the combined parallel + BT data

Finally we can train a model over the parallel + BT data:
```bash
CHECKPOINT_DIR=checkpoints_en_de_parallel_plus_bt
fairseq-train --fp16 \
    data-bin/wmt18_en_de_para_plus_bt \
    --upsample-primary 16 \
    --source-lang en --target-lang de \
    --arch transformer_wmt_en_de_big --share-all-embeddings \
    --dropout 0.3 --weight-decay 0.0 \
    --criterion label_smoothed_cross_entropy --label-smoothing 0.1 \
    --optimizer adam --adam-betas '(0.9, 0.98)' --clip-norm 0.0 \
    --lr 0.0007 --lr-scheduler inverse_sqrt --warmup-updates 4000 \
    --max-tokens 3584 --update-freq 16 \
    --max-update 100000 \
    --save-dir $CHECKPOINT_DIR
# Note: the above command assumes 8 GPUs. Adjust `--update-freq` if you have a
# different number of GPUs.
```

Average the last 10 checkpoints:
```bash
python scripts/average_checkpoints.py \
    --inputs $CHECKPOINT_DIR \
    --num-epoch-checkpoints 10 \
    --output $CHECKPOINT_DIR/checkpoint.avg10.pt
```

Evaluate BLEU:
```bash
# tokenized BLEU on newstest2017:
bash examples/backtranslation/tokenized_bleu.sh \
    wmt17 \
    en-de \
    data-bin/wmt18_en_de \
    data-bin/wmt18_en_de/code \
    $CHECKPOINT_DIR/checkpoint.avg10.pt
# BLEU4 = 32.35, 64.4/38.9/26.2/18.3 (BP=0.977, ratio=0.977, syslen=60729, reflen=62152)
# compare to 32.35 in Table 1, which is also for tokenized BLEU

# generally it's better to report (detokenized) sacrebleu:
bash examples/backtranslation/sacrebleu.sh \
    wmt17 \
    en-de \
    data-bin/wmt18_en_de \
    data-bin/wmt18_en_de/code \
    $CHECKPOINT_DIR/checkpoint.avg10.pt
# BLEU+case.mixed+lang.en-de+numrefs.1+smooth.exp+test.wmt17+tok.13a+version.1.4.3 = 31.5 64.3/38.2/25.6/17.6 (BP = 0.971 ratio = 0.971 hyp_len = 59515 ref_len = 61287)
```


## Citation
```bibtex
@inproceedings{edunov2018backtranslation,
  title = {Understanding Back-Translation at Scale},
  author = {Edunov, Sergey and Ott, Myle and Auli, Michael and Grangier, David},
  booktitle = {Conference of the Association for Computational Linguistics (ACL)},
  year = 2018,
}
```
fairseq/examples/backtranslation/deduplicate_lines.py
ADDED
@@ -0,0 +1,41 @@
#!/usr/bin/python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import argparse
import fileinput
import hashlib
import sys
from multiprocessing import Pool


def get_hashes_and_lines(raw_line):
    # Hash the raw bytes so identical lines map to the same key.
    hash = hashlib.md5(raw_line).hexdigest()
    return hash, raw_line


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--workers", type=int, default=10)
    parser.add_argument("files", nargs="*", help="input files")
    args = parser.parse_args()

    seen = set()
    with fileinput.input(args.files, mode="rb") as h:
        # Hash lines in parallel; keep the first occurrence of each hash.
        # Note: imap_unordered means output order may differ from input order.
        pool = Pool(args.workers)
        results = pool.imap_unordered(get_hashes_and_lines, h, 1000)
        for i, (hash, raw_line) in enumerate(results):
            if hash not in seen:
                seen.add(hash)
                sys.stdout.buffer.write(raw_line)
            # Lightweight progress meter on stderr.
            if i % 1000000 == 0:
                print(i, file=sys.stderr, end="", flush=True)
            elif i % 100000 == 0:
                print(".", file=sys.stderr, end="", flush=True)
    print(file=sys.stderr, flush=True)


if __name__ == "__main__":
    main()
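
For context, a typical invocation of the script above (a sketch; `mono.de` and `mono.dedup.de` are hypothetical file names):
```bash
python examples/backtranslation/deduplicate_lines.py --workers 10 mono.de > mono.dedup.de
wc -l mono.de mono.dedup.de   # compare line counts before and after deduplication
```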
fairseq/examples/backtranslation/extract_bt_data.py
ADDED
@@ -0,0 +1,72 @@
#!/usr/bin/env python
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import argparse
import fileinput

from tqdm import tqdm


def main():
    parser = argparse.ArgumentParser(
        description=(
            "Extract back-translations from the stdout of fairseq-generate. "
            "If there are multiple hypotheses for a source, we only keep the first one. "
        )
    )
    parser.add_argument("--output", required=True, help="output prefix")
    parser.add_argument(
        "--srclang", required=True, help="source language (extracted from H-* lines)"
    )
    parser.add_argument(
        "--tgtlang", required=True, help="target language (extracted from S-* lines)"
    )
    parser.add_argument("--minlen", type=int, help="min length filter")
    parser.add_argument("--maxlen", type=int, help="max length filter")
    parser.add_argument("--ratio", type=float, help="ratio filter")
    parser.add_argument("files", nargs="*", help="input files")
    args = parser.parse_args()

    def validate(src, tgt):
        # Reject pairs that violate the (optional) length and length-ratio filters.
        srclen = len(src.split(" ")) if src != "" else 0
        tgtlen = len(tgt.split(" ")) if tgt != "" else 0
        if (
            (args.minlen is not None and (srclen < args.minlen or tgtlen < args.minlen))
            or (
                args.maxlen is not None
                and (srclen > args.maxlen or tgtlen > args.maxlen)
            )
            or (
                args.ratio is not None
                and (max(srclen, tgtlen) / float(min(srclen, tgtlen)) > args.ratio)
            )
        ):
            return False
        return True

    def safe_index(toks, index, default):
        try:
            return toks[index]
        except IndexError:
            return default

    tgt = None  # initialize so an H-* line before any S-* line cannot raise NameError
    with open(args.output + "." + args.srclang, "w") as src_h, open(
        args.output + "." + args.tgtlang, "w"
    ) as tgt_h:
        for line in tqdm(fileinput.input(args.files)):
            if line.startswith("S-"):
                # S-* lines carry the generation input, i.e. the target-side text.
                tgt = safe_index(line.rstrip().split("\t"), 1, "")
            elif line.startswith("H-"):
                if tgt is not None:
                    # H-* lines carry the hypothesis: score in field 1, text in field 2.
                    src = safe_index(line.rstrip().split("\t"), 2, "")
                    if validate(src, tgt):
                        print(src, file=src_h)
                        print(tgt, file=tgt_h)
                    tgt = None  # only keep the first hypothesis per source


if __name__ == "__main__":
    main()
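
For context, a sketch of how this script consumes captured `fairseq-generate` output (file names are hypothetical and the filter values are illustrative):
```bash
python examples/backtranslation/extract_bt_data.py \
    --minlen 1 --maxlen 250 --ratio 1.5 \
    --output extracted_bt_data --srclang en --tgtlang de \
    bt_output.log   # hypothetical file holding fairseq-generate stdout
# writes extracted_bt_data.en and extracted_bt_data.de
```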