diff --git a/.github/workflows/benchmark.yml b/.github/workflows/benchmark.yml new file mode 100644 index 0000000000000000000000000000000000000000..fc50e11b67ea356c3f47bdab2973c9eb03b7114b --- /dev/null +++ b/.github/workflows/benchmark.yml @@ -0,0 +1,107 @@ +name: "Benchmark on Comment" + +# https://docs.github.com/en/actions/using-workflows/events-that-trigger-workflows +on: + issue_comment: + types: [created] + +jobs: + Benchmark: + strategy: + fail-fast: true + matrix: + python-version: [3.9] + os: [self-hosted] + + name: Benchmark + # Only run if it#s a PR and the comment contains /Benchmark + if: github.event.issue.pull_request && startsWith(github.event.comment.body, '/benchmark-trl-experiments') && contains(FromJSON('["vwxyzjn", "younesbelkada", "lvwerra", "lewtun"]'), github.actor) + runs-on: ${{ matrix.os }} + + steps: + - name: Get branch of PR + uses: xt0rted/pull-request-comment-branch@v1 + id: comment-branch + - name: Set latest commit status as pending + uses: myrotvorets/set-commit-status-action@master + with: + sha: ${{ steps.comment-branch.outputs.head_sha }} + token: ${{ secrets.GITHUB_TOKEN }} + status: pending + - name: Checkout `main` branch + uses: actions/checkout@v3 + - name: Checkout PR branch + run: gh pr checkout $PR_NUMBER + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + PR_NUMBER: ${{ github.event.issue.number }} + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v4 + with: + python-version: ${{ matrix.python-version }} + # - name: Cleanup pip packages (specific to self-hosted runners) + # run: | + # echo PATH is $PATH + # echo PYTHONPATH is $PYTHONPATH + # echo which python is $(which python) + # echo which pip is $(which pip) + + # pip_list=$(pip list --format=freeze | grep -v "^pip==" | grep -v "^setuptools==") + # if [ ! -z "$pip_list" ]; then + # echo "$pip_list" | xargs pip uninstall -y + # fi + - name: Print python depdenencies + run: pip list --format=freeze + - name: Install dependencies + run: | + pip install .[test,benchmark] + + - name: Login + run: wandb login ${{ secrets.WANDB_API_KEY }} && huggingface-cli login --token ${{ secrets.HUGGING_FACE_HUB_TOKEN }} + - name: Run benchmark + env: + GITHUB_CONTEXT: ${{ toJson(github) }} + PERSONAL_ACCESS_TOKEN_GITHUB: ${{ secrets.PERSONAL_ACCESS_TOKEN_GITHUB }} + run: | + COMMENT="${{ github.event.comment.body }}" + if [[ "$COMMENT" == *"/benchmark-trl-experiments benchmark/benchmark_level1.sh"* ]]; then + echo "Running benchmark/benchmark_level1.sh" + BENCHMARK_SCRIPT="benchmark/benchmark_level1.sh" BENCHMARK_PLOT_SCRIPT="benchmark/benchmark_level1_plot.sh" bash benchmark/benchmark_and_report.sh + elif [[ "$COMMENT" == *"/benchmark-trl-experiments benchmark/benchmark_level2.sh"* ]]; then + echo "Running benchmark/benchmark_level2.sh" + BENCHMARK_SCRIPT="benchmark/benchmark_level2.sh" BENCHMARK_PLOT_SCRIPT="benchmark/benchmark_level2_plot.sh" bash benchmark/benchmark_and_report.sh + elif [[ "$COMMENT" == *"/benchmark-trl-experiments benchmark/benchmark_level3.sh"* ]]; then + echo "Running benchmark/benchmark_level3.sh" + BENCHMARK_SCRIPT="benchmark/benchmark_level3.sh" BENCHMARK_PLOT_SCRIPT="benchmark/benchmark_level3_plot.sh" bash benchmark/benchmark_and_report.sh + else + echo "Invalid command in comment. Skipping execution." + fi + + # send message to PR + - name: Setup Node.js 16 + uses: actions/setup-node@v3 + with: + node-version: 16 + - name: Add workflow result as comment on PR + uses: actions/github-script@v6 + if: always() + with: + script: | + const name = '${{ github.workflow }}'; + const url = '${{ github.server_url }}/${{ github.repository }}/actions/runs/${{ github.run_id }}'; + const success = '${{ job.status }}' === 'success'; + const body = `${name}: ${success ? 'succeeded ✅' : 'failed ❌'}\n${url}`; + + await github.rest.issues.createComment({ + issue_number: context.issue.number, + owner: context.repo.owner, + repo: context.repo.repo, + body: body + }) + - name: Set latest commit status as ${{ job.status }} + uses: myrotvorets/set-commit-status-action@master + if: always() + with: + sha: ${{ steps.comment-branch.outputs.head_sha }} + token: ${{ secrets.GITHUB_TOKEN }} + status: ${{ job.status }} diff --git a/.github/workflows/build_documentation.yml b/.github/workflows/build_documentation.yml new file mode 100644 index 0000000000000000000000000000000000000000..5570c872b1656174ef020e3888f4e4fd993055ff --- /dev/null +++ b/.github/workflows/build_documentation.yml @@ -0,0 +1,18 @@ +name: Build documentation + +on: + push: + branches: + - main + - doc-builder* + - v*-release + +jobs: + build: + uses: huggingface/doc-builder/.github/workflows/build_main_documentation.yml@main + with: + commit_sha: ${{ github.sha }} + package: trl + version_tag_suffix: "" + secrets: + hf_token: ${{ secrets.HF_DOC_BUILD_PUSH }} diff --git a/.github/workflows/build_pr_documentation.yml b/.github/workflows/build_pr_documentation.yml new file mode 100644 index 0000000000000000000000000000000000000000..551fd13feeb54a2a7c6ed5416639f611221ba13e --- /dev/null +++ b/.github/workflows/build_pr_documentation.yml @@ -0,0 +1,17 @@ +name: Build PR Documentation + +on: + pull_request: + +concurrency: + group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }} + cancel-in-progress: true + +jobs: + build: + uses: huggingface/doc-builder/.github/workflows/build_pr_documentation.yml@main + with: + commit_sha: ${{ github.event.pull_request.head.sha }} + pr_number: ${{ github.event.number }} + package: trl + version_tag_suffix: "" \ No newline at end of file diff --git a/.github/workflows/clear_cache.yml b/.github/workflows/clear_cache.yml new file mode 100644 index 0000000000000000000000000000000000000000..20bab26687ca1aac453c4c1a76797cbcd8114d31 --- /dev/null +++ b/.github/workflows/clear_cache.yml @@ -0,0 +1,33 @@ +name: "Cleanup Cache" + +on: + workflow_dispatch: + schedule: + - cron: "0 0 * * *" + +jobs: + cleanup: + runs-on: ubuntu-latest + steps: + - name: Check out code + uses: actions/checkout@v3 + + - name: Cleanup + run: | + gh extension install actions/gh-actions-cache + + REPO=${{ github.repository }} + + echo "Fetching list of cache key" + cacheKeysForPR=$(gh actions-cache list -R $REPO | cut -f 1 ) + + ## Setting this to not fail the workflow while deleting cache keys. + set +e + echo "Deleting caches..." + for cacheKey in $cacheKeysForPR + do + gh actions-cache delete $cacheKey -R $REPO --confirm + done + echo "Done" + env: + GH_TOKEN: ${{ secrets.GITHUB_TOKEN }} diff --git a/.github/workflows/stale.yml b/.github/workflows/stale.yml new file mode 100644 index 0000000000000000000000000000000000000000..b3b626663f347de19ff721953305208457b559eb --- /dev/null +++ b/.github/workflows/stale.yml @@ -0,0 +1,27 @@ +name: Stale Bot + +on: + schedule: + - cron: "0 15 * * *" + +jobs: + close_stale_issues: + name: Close Stale Issues + if: github.repository == 'huggingface/trl' + runs-on: ubuntu-latest + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + steps: + - uses: actions/checkout@v3 + + - name: Setup Python + uses: actions/setup-python@v4 + with: + python-version: 3.8 + + - name: Install requirements + run: | + pip install PyGithub + - name: Close stale issues + run: | + python scripts/stale.py \ No newline at end of file diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml new file mode 100644 index 0000000000000000000000000000000000000000..6d2162f02252e8603e76dd82bf7134085a7183ca --- /dev/null +++ b/.github/workflows/tests.yml @@ -0,0 +1,75 @@ +name: tests + +on: + push: + branches: [ main ] + pull_request: + branches: [ main ] + +jobs: + check_code_quality: + runs-on: ubuntu-latest + strategy: + matrix: + python-version: [3.9] + + steps: + - uses: actions/checkout@v2 + with: + fetch-depth: 0 + submodules: recursive + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v2 + with: + python-version: ${{ matrix.python-version }} + - uses: pre-commit/action@v2.0.3 + with: + extra_args: --all-files + + tests: + needs: check_code_quality + strategy: + matrix: + python-version: ['3.8', '3.9', '3.10'] + os: ['ubuntu-latest', 'windows-latest'] + runs-on: ${{ matrix.os }} + steps: + - uses: actions/checkout@v3 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v4 + with: + python-version: ${{ matrix.python-version }} + cache: "pip" + cache-dependency-path: | + setup.py + requirements.txt + - name: Install dependencies + run: | + python -m pip install --upgrade pip + # cpu version of pytorch + pip install -e ".[test, peft, diffusers]" + - name: Test with pytest + run: | + make test + + tests_no_optional_dep: + needs: check_code_quality + runs-on: 'ubuntu-latest' + steps: + - uses: actions/checkout@v3 + - name: Set up Python 3.9 + uses: actions/setup-python@v4 + with: + python-version: '3.9' + cache: "pip" + cache-dependency-path: | + setup.py + requirements.txt + - name: Install dependencies + run: | + python -m pip install --upgrade pip + # cpu version of pytorch + pip install .[test] + - name: Test with pytest + run: | + make test \ No newline at end of file diff --git a/.github/workflows/upload_pr_documentation.yml b/.github/workflows/upload_pr_documentation.yml new file mode 100644 index 0000000000000000000000000000000000000000..2ad2ba0e8de52699f60c2da7792dab742dd6f200 --- /dev/null +++ b/.github/workflows/upload_pr_documentation.yml @@ -0,0 +1,16 @@ +name: Upload PR Documentation + +on: + workflow_run: + workflows: ["Build PR Documentation"] + types: + - completed + +jobs: + build: + uses: huggingface/doc-builder/.github/workflows/upload_pr_documentation.yml@main + with: + package_name: trl + secrets: + hf_token: ${{ secrets.HF_DOC_BUILD_PUSH }} + comment_bot_token: ${{ secrets.COMMENT_BOT_TOKEN }} \ No newline at end of file diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..58cb39ec20133997c0f73ebe4ea2754180c57f5e --- /dev/null +++ b/.gitignore @@ -0,0 +1,146 @@ +benchmark/trl +*.bak +.gitattributes +.last_checked +.gitconfig +*.bak +*.log +*~ +~* +_tmp* +tmp* +tags + +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +env/ +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +*.egg-info/ +.installed.cfg +*.egg + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +.hypothesis/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# pyenv +.python-version + +# celery beat schedule file +celerybeat-schedule + +# SageMath parsed files +*.sage.py + +# dotenv +.env + +# virtualenv +.venv +venv/ +ENV/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ + +.vscode +*.swp + +# osx generated files +.DS_Store +.DS_Store? +.Trashes +ehthumbs.db +Thumbs.db +.idea + +# pytest +.pytest_cache + +# tools/trust-doc-nbs +docs_src/.last_checked + +# symlinks to fastai +docs_src/fastai +tools/fastai + +# link checker +checklink/cookies.txt + +# .gitconfig is now autogenerated +.gitconfig + +# wandb files +nbs/wandb/ +examples/notebooks/wandb/ +wandb/ \ No newline at end of file diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000000000000000000000000000000000000..545815fe673cca31be5b751c8f85062c988367b6 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,42 @@ +repos: + - repo: https://github.com/PyCQA/isort + rev: 5.12.0 + hooks: + - id: isort + args: + - --profile=black + - --skip-glob=wandb/**/* + - --thirdparty=wandb + - repo: https://github.com/myint/autoflake + rev: v1.4 + hooks: + - id: autoflake + args: + - -r + - --exclude=wandb,__init__.py + - --in-place + - --remove-unused-variables + - --remove-all-unused-imports + - repo: https://github.com/python/black + rev: 22.3.0 + hooks: + - id: black + args: + - --line-length=119 + - --target-version=py38 + - --exclude=wandb + - repo: https://github.com/pycqa/flake8 + rev: 6.0.0 + hooks: + - id: flake8 + args: + - --ignore=E203,E501,W503,E128 + - --max-line-length=119 + + # - repo: https://github.com/codespell-project/codespell + # rev: v2.1.0 + # hooks: + # - id: codespell + # args: + # - --ignore-words-list=nd,reacher,thist,ths,magent,ba + # - --skip=docs/css/termynal.css,docs/js/termynal.js diff --git a/CITATION.cff b/CITATION.cff new file mode 100644 index 0000000000000000000000000000000000000000..c62318ecf2ef4965d9c3129c673f6bb0db35d64e --- /dev/null +++ b/CITATION.cff @@ -0,0 +1,28 @@ +cff-version: 1.2.0 +title: 'TRL: Transformer Reinforcement Learning' +message: >- + If you use this software, please cite it using the + metadata from this file. +type: software +authors: + - given-names: Leandro + family-names: von Werra + - given-names: Younes + family-names: Belkada + - given-names: Lewis + family-names: Tunstall + - given-names: Edward + family-names: Beeching + - given-names: Tristan + family-names: Thrush + - given-names: Nathan + family-names: Lambert +repository-code: 'https://github.com/huggingface/trl' +abstract: "With trl you can train transformer language models with Proximal Policy Optimization (PPO). The library is built on top of the transformers library by \U0001F917 Hugging Face. Therefore, pre-trained language models can be directly loaded via transformers. At this point, most decoder and encoder-decoder architectures are supported." +keywords: + - rlhf + - deep-learning + - pytorch + - transformers +license: Apache-2.0 +version: 0.2.1 diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md new file mode 100644 index 0000000000000000000000000000000000000000..de1731bccbd5b667f2a40582e443e86d36accfc4 --- /dev/null +++ b/CONTRIBUTING.md @@ -0,0 +1,53 @@ +# How to contribute + +## How to get started + +Before you start contributing make sure you installed all the dev tools: + +```bash +pip install -e ".[dev]" +``` + +## Did you find a bug? + +* Ensure the bug was not already reported by searching on GitHub under Issues. +* If you're unable to find an open issue addressing the problem, open a new one. Be sure to include a title and clear description, as much relevant information as possible, and a code sample or an executable test case demonstrating the expected behavior that is not occurring. +* Be sure to add the complete error messages. + +#### Did you write a patch that fixes a bug? + +* Open a new GitHub pull request with the patch. +* Ensure that your PR includes a test that fails without your patch, and pass with it. +* Ensure the PR description clearly describes the problem and solution. Include the relevant issue number if applicable. + +## PR submission guidelines + +* Keep each PR focused. While it's more convenient, do not combine several unrelated fixes together. Create as many branches as needing to keep each PR focused. +* Do not mix style changes/fixes with "functional" changes. It's very difficult to review such PRs and it most likely get rejected. +* Do not add/remove vertical whitespace. Preserve the original style of the file you edit as much as you can. +* Do not turn an already submitted PR into your development playground. If after you submitted PR, you discovered that more work is needed - close the PR, do the required work and then submit a new PR. Otherwise each of your commits requires attention from maintainers of the project. +* If, however, you submitted a PR and received a request for changes, you should proceed with commits inside that PR, so that the maintainer can see the incremental fixes and won't need to review the whole PR again. In the exception case where you realize it'll take many many commits to complete the requests, then it's probably best to close the PR, do the work and then submit it again. Use common sense where you'd choose one way over another. + +### Before you submit a PR + +First you want to make sure that all the tests pass: + +```bash +make test +``` + +Then before submitting your PR make sure the code quality follows the standards. You can run the following command to format: + +```bash +make precommit +``` + +Make sure to install `pre-commit` before running the command: +```bash +pip install pre-commit +``` + +## Do you want to contribute to the documentation? + +* Docs are in the `docs/` folder and can be updated there. + diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..261eeb9e9f8b2b4b0d119366dda99c6fd7d35c64 --- /dev/null +++ b/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/MANIFEST.in b/MANIFEST.in new file mode 100644 index 0000000000000000000000000000000000000000..5c0e7ced193cb35aee60cbadc3e022a1da0fd8cc --- /dev/null +++ b/MANIFEST.in @@ -0,0 +1,5 @@ +include settings.ini +include LICENSE +include CONTRIBUTING.md +include README.md +recursive-exclude * __pycache__ diff --git a/Makefile b/Makefile new file mode 100644 index 0000000000000000000000000000000000000000..c578697ce3d69a9e65b4efbbdb4592434875a265 --- /dev/null +++ b/Makefile @@ -0,0 +1,15 @@ +.PHONY: test precommit benchmark_core benchmark_aux + +check_dirs := examples tests trl + +test: + python -m pytest -n auto --dist=loadfile -s -v ./tests/ + +precommit: + pre-commit run --all-files + +benchmark_core: + bash ./benchmark/benchmark_core.sh + +benchmark_aux: + bash ./benchmark/benchmark_aux.sh diff --git a/README.md b/README.md new file mode 100644 index 0000000000000000000000000000000000000000..25e501769d7bb924b8239895a1efccc1730d9eef --- /dev/null +++ b/README.md @@ -0,0 +1,184 @@ +
+ +
+ +# TRL - Transformer Reinforcement Learning +> Full stack transformer language models with reinforcement learning. + +

+ + License + + + Documentation + + + GitHub release + +

+ + +## What is it? + +
+ +
+ +`trl` is a full stack library where we provide a set of tools to train transformer language models and stable diffusion models with Reinforcement Learning, from the Supervised Fine-tuning step (SFT), Reward Modeling step (RM) to the Proximal Policy Optimization (PPO) step. The library is built on top of the [`transformers`](https://github.com/huggingface/transformers) library by 🤗 Hugging Face. Therefore, pre-trained language models can be directly loaded via `transformers`. At this point, most of decoder architectures and encoder-decoder architectures are supported. Refer to the documentation or the `examples/` folder for example code snippets and how to run these tools. + +**Highlights:** + +- [`SFTTrainer`](https://huggingface.co/docs/trl/sft_trainer): A light and friendly wrapper around `transformers` Trainer to easily fine-tune language models or adapters on a custom dataset. +- [`RewardTrainer`](https://huggingface.co/docs/trl/reward_trainer): A light wrapper around `transformers` Trainer to easily fine-tune language models for human preferences (Reward Modeling). +- [`PPOTrainer`](https://huggingface.co/docs/trl/trainer#trl.PPOTrainer): A PPO trainer for language models that just needs (query, response, reward) triplets to optimise the language model. +- [`AutoModelForCausalLMWithValueHead`](https://huggingface.co/docs/trl/models#trl.AutoModelForCausalLMWithValueHead) & [`AutoModelForSeq2SeqLMWithValueHead`](https://huggingface.co/docs/trl/models#trl.AutoModelForSeq2SeqLMWithValueHead): A transformer model with an additional scalar output for each token which can be used as a value function in reinforcement learning. +- [Examples](https://github.com/huggingface/trl/tree/main/examples): Train GPT2 to generate positive movie reviews with a BERT sentiment classifier, full RLHF using adapters only, train GPT-j to be less toxic, [Stack-Llama example](https://huggingface.co/blog/stackllama), etc. + +## How PPO works +Fine-tuning a language model via PPO consists of roughly three steps: + +1. **Rollout**: The language model generates a response or continuation based on query which could be the start of a sentence. +2. **Evaluation**: The query and response are evaluated with a function, model, human feedback or some combination of them. The important thing is that this process should yield a scalar value for each query/response pair. +3. **Optimization**: This is the most complex part. In the optimisation step the query/response pairs are used to calculate the log-probabilities of the tokens in the sequences. This is done with the model that is trained and a reference model, which is usually the pre-trained model before fine-tuning. The KL-divergence between the two outputs is used as an additional reward signal to make sure the generated responses don't deviate too far from the reference language model. The active language model is then trained with PPO. + +This process is illustrated in the sketch below: + + +
+ +

Figure: Sketch of the workflow.

+
+ +## Installation + +### Python package +Install the library with pip: +```bash +pip install trl +``` + +### From source +If you want to run the examples in the repository a few additional libraries are required. Clone the repository and install it with pip: +```bash +git clone https://github.com/huggingface/trl.git +cd trl/ +pip install . +``` + +If you wish to develop TRL, you should install in editable mode: +```bash +pip install -e . +``` + +## How to use + +### `SFTTrainer` + +This is a basic example on how to use the `SFTTrainer` from the library. The `SFTTrainer` is a light wrapper around the `transformers` Trainer to easily fine-tune language models or adapters on a custom dataset. + +```python +# imports +from datasets import load_dataset +from trl import SFTTrainer + +# get dataset +dataset = load_dataset("imdb", split="train") + +# get trainer +trainer = SFTTrainer( + "facebook/opt-350m", + train_dataset=dataset, + dataset_text_field="text", + max_seq_length=512, +) + +# train +trainer.train() +``` + +### `RewardTrainer` + +This is a basic example on how to use the `RewardTrainer` from the library. The `RewardTrainer` is a wrapper around the `transformers` Trainer to easily fine-tune reward models or adapters on a custom preference dataset. + +```python +# imports +from transformers import AutoModelForSequenceClassification, AutoTokenizer +from trl import RewardTrainer + +# load model and dataset - dataset needs to be in a specific format +model = AutoModelForSequenceClassification.from_pretrained("gpt2", num_labels=1) +tokenizer = AutoTokenizer.from_pretrained("gpt2") + +... + +# load trainer +trainer = RewardTrainer( + model=model, + tokenizer=tokenizer, + train_dataset=dataset, +) + +# train +trainer.train() +``` + +### `PPOTrainer` + +This is a basic example on how to use the `PPOTrainer` from the library. Based on a query the language model creates a response which is then evaluated. The evaluation could be a human in the loop or another model's output. + +```python +# imports +import torch +from transformers import AutoTokenizer +from trl import PPOTrainer, PPOConfig, AutoModelForCausalLMWithValueHead, create_reference_model +from trl.core import respond_to_batch + +# get models +model = AutoModelForCausalLMWithValueHead.from_pretrained('gpt2') +model_ref = create_reference_model(model) + +tokenizer = AutoTokenizer.from_pretrained('gpt2') + +# initialize trainer +ppo_config = PPOConfig( + batch_size=1, +) + +# encode a query +query_txt = "This morning I went to the " +query_tensor = tokenizer.encode(query_txt, return_tensors="pt") + +# get model response +response_tensor = respond_to_batch(model, query_tensor) + +# create a ppo trainer +ppo_trainer = PPOTrainer(ppo_config, model, model_ref, tokenizer) + +# define a reward for response +# (this could be any reward such as human feedback or output from another model) +reward = [torch.tensor(1.0)] + +# train model for one step with ppo +train_stats = ppo_trainer.step([query_tensor[0]], [response_tensor[0]], reward) +``` + +## References + +### Proximal Policy Optimisation +The PPO implementation largely follows the structure introduced in the paper **"Fine-Tuning Language Models from Human Preferences"** by D. Ziegler et al. \[[paper](https://arxiv.org/pdf/1909.08593.pdf), [code](https://github.com/openai/lm-human-preferences)]. + +### Language models +The language models utilize the `transformers` library by 🤗 Hugging Face. + +## Citation + +```bibtex +@misc{vonwerra2022trl, + author = {Leandro von Werra and Younes Belkada and Lewis Tunstall and Edward Beeching and Tristan Thrush and Nathan Lambert and Shengyi Huang}, + title = {TRL: Transformer Reinforcement Learning}, + year = {2020}, + publisher = {GitHub}, + journal = {GitHub repository}, + howpublished = {\url{https://github.com/huggingface/trl}} +} +``` diff --git a/benchmark/benchmark.py b/benchmark/benchmark.py new file mode 100644 index 0000000000000000000000000000000000000000..895000f24e34aec880a69a6bfbaa04e09a94b34d --- /dev/null +++ b/benchmark/benchmark.py @@ -0,0 +1,150 @@ +import argparse +import math +import os +import shlex +import subprocess +import uuid +from distutils.util import strtobool + +import requests + + +def parse_args(): + # fmt: off + parser = argparse.ArgumentParser() + parser.add_argument("--command", type=str, default="", + help="the command to run") + parser.add_argument("--num-seeds", type=int, default=3, + help="the number of random seeds") + parser.add_argument("--start-seed", type=int, default=1, + help="the number of the starting seed") + parser.add_argument("--workers", type=int, default=0, + help="the number of workers to run benchmark experimenets") + parser.add_argument("--auto-tag", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True, + help="if toggled, the runs will be tagged with git tags, commit, and pull request number if possible") + parser.add_argument("--slurm-template-path", type=str, default=None, + help="the path to the slurm template file (see docs for more details)") + parser.add_argument("--slurm-gpus-per-task", type=int, default=1, + help="the number of gpus per task to use for slurm jobs") + parser.add_argument("--slurm-total-cpus", type=int, default=50, + help="the number of gpus per task to use for slurm jobs") + parser.add_argument("--slurm-ntasks", type=int, default=1, + help="the number of tasks to use for slurm jobs") + parser.add_argument("--slurm-nodes", type=int, default=None, + help="the number of nodes to use for slurm jobs") + args = parser.parse_args() + # fmt: on + return args + + +def run_experiment(command: str): + command_list = shlex.split(command) + print(f"running {command}") + + # Use subprocess.PIPE to capture the output + fd = subprocess.Popen(command_list, stdout=subprocess.PIPE, stderr=subprocess.PIPE) + output, errors = fd.communicate() + + return_code = fd.returncode + assert return_code == 0, f"Command failed with error: {errors.decode('utf-8')}" + + # Convert bytes to string and strip leading/trailing whitespaces + return output.decode("utf-8").strip() + + +def autotag() -> str: + wandb_tag = "" + print("autotag feature is enabled") + git_tag = "" + try: + git_tag = subprocess.check_output(["git", "describe", "--tags"]).decode("ascii").strip() + print(f"identified git tag: {git_tag}") + except subprocess.CalledProcessError as e: + print(e) + if len(git_tag) == 0: + try: + count = int(subprocess.check_output(["git", "rev-list", "--count", "HEAD"]).decode("ascii").strip()) + hash = subprocess.check_output(["git", "rev-parse", "--short", "HEAD"]).decode("ascii").strip() + git_tag = f"no-tag-{count}-g{hash}" + print(f"identified git tag: {git_tag}") + except subprocess.CalledProcessError as e: + print(e) + wandb_tag = f"{git_tag}" + + git_commit = subprocess.check_output(["git", "rev-parse", "--verify", "HEAD"]).decode("ascii").strip() + try: + # try finding the pull request number on github + prs = requests.get(f"https://api.github.com/search/issues?q=repo:huggingface/trl+is:pr+{git_commit}") + if prs.status_code == 200: + prs = prs.json() + if len(prs["items"]) > 0: + pr = prs["items"][0] + pr_number = pr["number"] + wandb_tag += f",pr-{pr_number}" + print(f"identified github pull request: {pr_number}") + except Exception as e: + print(e) + + return wandb_tag + + +if __name__ == "__main__": + args = parse_args() + if args.auto_tag: + existing_wandb_tag = os.environ.get("WANDB_TAGS", "") + wandb_tag = autotag() + if len(wandb_tag) > 0: + if len(existing_wandb_tag) > 0: + os.environ["WANDB_TAGS"] = ",".join([existing_wandb_tag, wandb_tag]) + else: + os.environ["WANDB_TAGS"] = wandb_tag + print("WANDB_TAGS: ", os.environ.get("WANDB_TAGS", "")) + commands = [] + for seed in range(0, args.num_seeds): + commands += [" ".join([args.command, "--seed", str(args.start_seed + seed)])] + + print("======= commands to run:") + for command in commands: + print(command) + + if args.workers > 0 and args.slurm_template_path is None: + from concurrent.futures import ThreadPoolExecutor + + executor = ThreadPoolExecutor(max_workers=args.workers, thread_name_prefix="cleanrl-benchmark-worker-") + for command in commands: + executor.submit(run_experiment, command) + executor.shutdown(wait=True) + else: + print("not running the experiments because --workers is set to 0; just printing the commands to run") + + # SLURM logic + if args.slurm_template_path is not None: + if not os.path.exists("slurm"): + os.makedirs("slurm") + if not os.path.exists("slurm/logs"): + os.makedirs("slurm/logs") + print("======= slurm commands to run:") + with open(args.slurm_template_path) as f: + slurm_template = f.read() + slurm_template = slurm_template.replace("{{array}}", f"0-{len(commands) - 1}%{args.workers}") + slurm_template = slurm_template.replace( + "{{seeds}}", f"({' '.join([str(args.start_seed + int(seed)) for seed in range(args.num_seeds)])})" + ) + slurm_template = slurm_template.replace("{{len_seeds}}", f"{args.num_seeds}") + slurm_template = slurm_template.replace("{{command}}", args.command) + slurm_template = slurm_template.replace("{{gpus_per_task}}", f"{args.slurm_gpus_per_task}") + total_gpus = args.slurm_gpus_per_task * args.slurm_ntasks + slurm_cpus_per_gpu = math.ceil(args.slurm_total_cpus / total_gpus) + slurm_template = slurm_template.replace("{{cpus_per_gpu}}", f"{slurm_cpus_per_gpu}") + slurm_template = slurm_template.replace("{{ntasks}}", f"{args.slurm_ntasks}") + if args.slurm_nodes is not None: + slurm_template = slurm_template.replace("{{nodes}}", f"#SBATCH --nodes={args.slurm_nodes}") + else: + slurm_template = slurm_template.replace("{{nodes}}", "") + filename = str(uuid.uuid4()) + open(os.path.join("slurm", f"{filename}.slurm"), "w").write(slurm_template) + slurm_path = os.path.join("slurm", f"{filename}.slurm") + print(f"saving command in {slurm_path}") + if args.workers > 0: + job_id = run_experiment(f"sbatch --parsable {slurm_path}") + print(f"Job ID: {job_id}") diff --git a/benchmark/benchmark_and_report.sh b/benchmark/benchmark_and_report.sh new file mode 100644 index 0000000000000000000000000000000000000000..af76a4e7aa070f682cde029fea0ca3fbd6f05061 --- /dev/null +++ b/benchmark/benchmark_and_report.sh @@ -0,0 +1,41 @@ +#### Step 1: create a work directory: +# this is necessary because another github action job will remove +# the entire directory, which slurm depends on. +# https://stackoverflow.com/questions/4632028/how-to-create-a-temporary-directory +MY_SLURM_TMP_DIR=/fsx/costa/slurm_tmpdir +mkdir -p $MY_SLURM_TMP_DIR +WORK_DIR=`mktemp -d -p "$MY_SLURM_TMP_DIR"` +cp -r "$PWD" "$WORK_DIR" +cd "$WORK_DIR/$(basename "$PWD")" +echo WORK_DIR: $WORK_DIR + +#### Step 2: actual work starts: +echo PATH is $PATH +echo PYTHONPATH is $PYTHONPATH +echo whcih python is $(which python) + +export WANDB_ENTITY=huggingface +bash $BENCHMARK_SCRIPT > output.txt + +# Extract Job IDs into an array +job_ids=($(grep "Job ID:" output.txt | awk '{print $3}')) + +# Extract WANDB_TAGS into an array +WANDB_TAGS=($(grep "WANDB_TAGS:" output.txt | awk '{print $2}')) +WANDB_TAGS=($(echo $WANDB_TAGS | tr "," "\n")) + +# Print to verify +echo "Job IDs: ${job_ids[@]}" +echo "WANDB_TAGS: ${WANDB_TAGS[@]}" + +TAGS_STRING="?tag=${WANDB_TAGS[0]}" +FOLDER_STRING="${WANDB_TAGS[0]}" +for tag in "${WANDB_TAGS[@]:1}"; do + TAGS_STRING+="&tag=$tag" + FOLDER_STRING+="_$tag" +done + +echo "TAGS_STRING: $TAGS_STRING" +echo "FOLDER_STRING: $FOLDER_STRING" + +TAGS_STRING=$TAGS_STRING FOLDER_STRING=$FOLDER_STRING BENCHMARK_PLOT_SCRIPT=$BENCHMARK_PLOT_SCRIPT sbatch --dependency=afterany:$job_ids benchmark/post_github_comment.sbatch diff --git a/benchmark/benchmark_level1.sh b/benchmark/benchmark_level1.sh new file mode 100644 index 0000000000000000000000000000000000000000..6535744ae2a38b9be10f44965c77f4a714b33f43 --- /dev/null +++ b/benchmark/benchmark_level1.sh @@ -0,0 +1,11 @@ +# hello world experiment +python benchmark/benchmark.py \ + --command "python examples/scripts/ppo.py --ppo_config.log_with wandb" \ + --num-seeds 3 \ + --start-seed 1 \ + --workers 10 \ + --slurm-nodes 1 \ + --slurm-gpus-per-task 1 \ + --slurm-ntasks 1 \ + --slurm-total-cpus 12 \ + --slurm-template-path benchmark/trl.slurm_template diff --git a/benchmark/benchmark_level1_plot.sh b/benchmark/benchmark_level1_plot.sh new file mode 100644 index 0000000000000000000000000000000000000000..9cfe8fbe6bea6603f66e524a7e9532c7000f6b21 --- /dev/null +++ b/benchmark/benchmark_level1_plot.sh @@ -0,0 +1,20 @@ +# pip install openrlbenchmark==0.2.1a5 +# see https://github.com/openrlbenchmark/openrlbenchmark#get-started for documentation +echo "we deal with $TAGS_STRING" + +python -m openrlbenchmark.rlops_multi_metrics \ + --filters '?we=huggingface&wpn=trl&xaxis=_step&ceik=trl_ppo_trainer_config.value.reward_model&cen=trl_ppo_trainer_config.value.exp_name&metrics=env/reward_mean&metrics=objective/kl' \ + "ppo$TAGS_STRING" \ + --env-ids sentiment-analysis:lvwerra/distilbert-imdb \ + --no-check-empty-runs \ + --pc.ncols 2 \ + --pc.ncols-legend 1 \ + --output-filename benchmark/trl/$FOLDER_STRING/hello_world \ + --scan-history + +python benchmark/upload_benchmark.py \ + --folder_path="benchmark/trl/$FOLDER_STRING" \ + --path_in_repo="images/benchmark/$FOLDER_STRING" \ + --repo_id="trl-internal-testing/example-images" \ + --repo_type="dataset" + diff --git a/benchmark/benchmark_level2.sh b/benchmark/benchmark_level2.sh new file mode 100644 index 0000000000000000000000000000000000000000..587713ba7b70e80cfd0e4750e4baaa97f9307381 --- /dev/null +++ b/benchmark/benchmark_level2.sh @@ -0,0 +1,23 @@ +# compound experiments: gpt2xl + grad_accu +python benchmark/benchmark.py \ + --command "python examples/scripts/ppo.py --ppo_config.exp_name ppo_gpt2xl_grad_accu --ppo_config.model_name gpt2-xl --ppo_config.mini_batch_size 16 --ppo_config.gradient_accumulation_steps 8 --ppo_config.log_with wandb" \ + --num-seeds 3 \ + --start-seed 1 \ + --workers 10 \ + --slurm-nodes 1 \ + --slurm-gpus-per-task 1 \ + --slurm-ntasks 1 \ + --slurm-total-cpus 12 \ + --slurm-template-path benchmark/trl.slurm_template + +# compound experiments: Cerebras-GPT-6.7B + deepspeed zero2 + grad_accu +python benchmark/benchmark.py \ + --command "accelerate launch --config_file examples/accelerate_configs/deepspeed_zero2.yaml examples/scripts/ppo.py --ppo_config.exp_name ppo_Cerebras-GPT-6.7B_grad_accu_deepspeed_stage2 --ppo_config.batch_size 32 --ppo_config.mini_batch_size 32 --ppo_config.log_with wandb --ppo_config.model_name cerebras/Cerebras-GPT-6.7B --ppo_config.reward_model sentiment-analysis:cerebras/Cerebras-GPT-6.7B" \ + --num-seeds 3 \ + --start-seed 1 \ + --workers 10 \ + --slurm-nodes 1 \ + --slurm-gpus-per-task 8 \ + --slurm-ntasks 1 \ + --slurm-total-cpus 90 \ + --slurm-template-path benchmark/trl.slurm_template diff --git a/benchmark/benchmark_level2_plot.sh b/benchmark/benchmark_level2_plot.sh new file mode 100644 index 0000000000000000000000000000000000000000..305b86d90047e177b53f6d0f7404d8dc37ed7b67 --- /dev/null +++ b/benchmark/benchmark_level2_plot.sh @@ -0,0 +1,31 @@ +# pip install openrlbenchmark==0.2.1a5 +# see https://github.com/openrlbenchmark/openrlbenchmark#get-started for documentation +echo "we deal with $TAGS_STRING" + +python -m openrlbenchmark.rlops_multi_metrics \ + --filters '?we=huggingface&wpn=trl&xaxis=_step&ceik=trl_ppo_trainer_config.value.reward_model&cen=trl_ppo_trainer_config.value.exp_name&metrics=env/reward_mean&metrics=objective/kl' \ + "ppo$TAGS_STRING" \ + "ppo_gpt2xl_grad_accu$TAGS_STRING" \ + --env-ids sentiment-analysis:lvwerra/distilbert-imdb \ + --no-check-empty-runs \ + --pc.ncols 2 \ + --pc.ncols-legend 1 \ + --output-filename benchmark/trl/$FOLDER_STRING/different_models \ + --scan-history + +python -m openrlbenchmark.rlops_multi_metrics \ + --filters '?we=huggingface&wpn=trl&xaxis=_step&ceik=trl_ppo_trainer_config.value.reward_model&cen=trl_ppo_trainer_config.value.exp_name&metrics=env/reward_mean&metrics=objective/kl' \ + "ppo_Cerebras-GPT-6.7B_grad_accu_deepspeed_stage2$TAGS_STRING" \ + --env-ids sentiment-analysis:cerebras/Cerebras-GPT-6.7B \ + --no-check-empty-runs \ + --pc.ncols 2 \ + --pc.ncols-legend 1 \ + --output-filename benchmark/trl/$FOLDER_STRING/deepspeed \ + --scan-history + +python benchmark/upload_benchmark.py \ + --folder_path="benchmark/trl/$FOLDER_STRING" \ + --path_in_repo="images/benchmark/$FOLDER_STRING" \ + --repo_id="trl-internal-testing/example-images" \ + --repo_type="dataset" + diff --git a/benchmark/benchmark_level3.sh b/benchmark/benchmark_level3.sh new file mode 100644 index 0000000000000000000000000000000000000000..858445fe777f34c89393538b3c02bc19f25936d2 --- /dev/null +++ b/benchmark/benchmark_level3.sh @@ -0,0 +1,46 @@ +## w/ and w/o gradient accumulation +python benchmark/benchmark.py \ + --command "python examples/scripts/ppo.py --ppo_config.exp_name ppo_step_grad_accu --ppo_config.mini_batch_size 1 --ppo_config.gradient_accumulation_steps 128 --ppo_config.log_with wandb" \ + --num-seeds 3 \ + --start-seed 1 \ + --workers 10 \ + --slurm-nodes 1 \ + --slurm-gpus-per-task 1 \ + --slurm-ntasks 1 \ + --slurm-total-cpus 12 \ + --slurm-template-path benchmark/trl.slurm_template + +## w/ different models (gpt2, gpt2-xl, falcon, llama2) +python benchmark/benchmark.py \ + --command "python examples/scripts/ppo.py --ppo_config.exp_name ppo_gpt2 --ppo_config.log_with wandb" \ + --num-seeds 3 \ + --start-seed 1 \ + --workers 10 \ + --slurm-nodes 1 \ + --slurm-gpus-per-task 1 \ + --slurm-ntasks 1 \ + --slurm-total-cpus 12 \ + --slurm-template-path benchmark/trl.slurm_template +python benchmark/benchmark.py \ + --command "python examples/scripts/ppo.py --ppo_config.exp_name ppo_falcon_rw_1b --ppo_config.model_name tiiuae/falcon-rw-1b --ppo_config.log_with wandb" \ + --num-seeds 3 \ + --start-seed 1 \ + --workers 10 \ + --slurm-nodes 1 \ + --slurm-gpus-per-task 1 \ + --slurm-ntasks 1 \ + --slurm-total-cpus 12 \ + --slurm-template-path benchmark/trl.slurm_template + + +## w/ and w/o PEFT +python benchmark/benchmark.py \ + --command "python examples/scripts/ppo.py --ppo_config.exp_name ppo_peft --use_peft --ppo_config.log_with wandb" \ + --num-seeds 3 \ + --start-seed 1 \ + --workers 10 \ + --slurm-nodes 1 \ + --slurm-gpus-per-task 1 \ + --slurm-ntasks 1 \ + --slurm-total-cpus 12 \ + --slurm-template-path benchmark/trl.slurm_template \ No newline at end of file diff --git a/benchmark/plot.sh b/benchmark/plot.sh new file mode 100644 index 0000000000000000000000000000000000000000..9ad7c0f2d14afcc782887d529ba926c9db005816 --- /dev/null +++ b/benchmark/plot.sh @@ -0,0 +1,56 @@ +# pip install openrlbenchmark==0.2.1a5 +# see https://github.com/openrlbenchmark/openrlbenchmark#get-started for documentation +BASELINE_PR_TAG=v0.4.7-55-g110e672 +BASELINE_PR_NAME=PR-662 + +python -m openrlbenchmark.rlops_multi_metrics \ + --filters '?we=huggingface&wpn=trl&xaxis=_step&ceik=trl_ppo_trainer_config.value.reward_model&cen=trl_ppo_trainer_config.value.exp_name&metrics=env/reward_mean&metrics=objective/kl' \ + "sentiment_tuning?tag=$BASELINE_PR_TAG&cl=sentiment lvwerra/gpt2-imdb ($BASELINE_PR_NAME)" \ + --env-ids sentiment-analysis:lvwerra/distilbert-imdb \ + --no-check-empty-runs \ + --pc.ncols 2 \ + --pc.ncols-legend 1 \ + --output-filename benchmark/trl/$BASELINE_PR_TAG/sentiment \ + --scan-history + +python -m openrlbenchmark.rlops_multi_metrics \ + --filters '?we=huggingface&wpn=trl&xaxis=_step&ceik=trl_ppo_trainer_config.value.reward_model&cen=trl_ppo_trainer_config.value.exp_name&metrics=env/reward_mean&metrics=objective/kl' \ + "sentiment_tuning?tag=$BASELINE_PR_TAG&cl=sentiment lvwerra/gpt2-imdb ($BASELINE_PR_NAME)" \ + "sentiment_tuning_step_grad_accu?tag=$BASELINE_PR_TAG&cl=sentiment lvwerra/gpt2-imdb gradient accumulation ($BASELINE_PR_NAME)" \ + --env-ids sentiment-analysis:lvwerra/distilbert-imdb \ + --no-check-empty-runs \ + --pc.ncols 2 \ + --pc.ncols-legend 1 \ + --output-filename benchmark/trl/$BASELINE_PR_TAG/gradient_accu \ + --scan-history + +python -m openrlbenchmark.rlops_multi_metrics \ + --filters '?we=huggingface&wpn=trl&xaxis=_step&ceik=trl_ppo_trainer_config.value.reward_model&cen=trl_ppo_trainer_config.value.exp_name&metrics=env/reward_mean&metrics=objective/kl' \ + "sentiment_tuning?tag=$BASELINE_PR_TAG&cl=sentiment lvwerra/gpt2-imdb ($BASELINE_PR_NAME)" \ + "sentiment_tuning_gpt2?tag=$BASELINE_PR_TAG&cl=sentiment gpt2 ($BASELINE_PR_NAME)" \ + "sentiment_tuning_falcon_rw_1b?tag=$BASELINE_PR_TAG&cl=sentiment tiiuae/falcon-rw-1b ($BASELINE_PR_NAME)" \ + "sentiment_tuning_gpt2xl_grad_accu?tag=$BASELINE_PR_TAG&cl=sentiment gpt2xl ($BASELINE_PR_NAME)" \ + --env-ids sentiment-analysis:lvwerra/distilbert-imdb \ + --no-check-empty-runs \ + --pc.ncols 2 \ + --pc.ncols-legend 1 \ + --output-filename benchmark/trl/$BASELINE_PR_TAG/different_models \ + --scan-history + +python -m openrlbenchmark.rlops_multi_metrics \ + --filters '?we=huggingface&wpn=trl&xaxis=_step&ceik=trl_ppo_trainer_config.value.reward_model&cen=trl_ppo_trainer_config.value.exp_name&metrics=env/reward_mean&metrics=objective/kl' \ + "sentiment_tuning?tag=$BASELINE_PR_TAG&cl=sentiment lvwerra/gpt2-imdb ($BASELINE_PR_NAME)" \ + "sentiment_tuning_peft?tag=$BASELINE_PR_TAG&cl=sentiment lvwerra/gpt2-imdb w/ peft ($BASELINE_PR_NAME)" \ + --env-ids sentiment-analysis:lvwerra/distilbert-imdb \ + --no-check-empty-runs \ + --pc.ncols 2 \ + --pc.ncols-legend 1 \ + --output-filename benchmark/trl/$BASELINE_PR_TAG/peft \ + --scan-history + + +python benchmark/upload_benchmark.py \ + --folder_path="benchmark/trl/$BASELINE_PR_TAG" \ + --path_in_repo="images/benchmark/$BASELINE_PR_TAG" \ + --repo_id="trl-internal-testing/example-images" \ + --repo_type="dataset" \ No newline at end of file diff --git a/benchmark/post_github_comment.py b/benchmark/post_github_comment.py new file mode 100644 index 0000000000000000000000000000000000000000..70241ef131980687acd49111ee746c220058ac63 --- /dev/null +++ b/benchmark/post_github_comment.py @@ -0,0 +1,26 @@ +import json +import os + +from ghapi.all import GhApi + + +FOLDER_STRING = os.environ.get("FOLDER_STRING", "") +folder = f"benchmark/trl/{FOLDER_STRING}" +host_url = f"https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/benchmark/{FOLDER_STRING}" + +# Create a GitHub API instance +github_context = json.loads(os.environ["GITHUB_CONTEXT"]) +token = os.environ["PERSONAL_ACCESS_TOKEN_GITHUB"] # this needs to refreshed every 12 months +status_message = "**[COSTA BENCHMARK BOT]**: Here are the results" +body = status_message +repo = github_context["repository"] +owner, repo = repo.split("/") +api = GhApi(owner=owner, repo=repo, token=token) + +# for each `.png` file in the folder, add it to the comment +for file in os.listdir(folder): + if file.endswith(".png"): + body += f"\n![{file}]({host_url}/{file})" + +# Create a comment on the issue +api.issues.create_comment(issue_number=github_context["event"]["issue"]["number"], body=body) diff --git a/benchmark/post_github_comment.sbatch b/benchmark/post_github_comment.sbatch new file mode 100644 index 0000000000000000000000000000000000000000..4c464cd76d78bfd3448ab5ed37f5db976bd2550b --- /dev/null +++ b/benchmark/post_github_comment.sbatch @@ -0,0 +1,9 @@ +#!/bin/bash +#SBATCH --job-name=trl +#SBATCH --partition=production-cluster +#SBATCH --ntasks=1 +#SBATCH --output=slurm/logs/%x_%j.out + +sleep 2m +bash $BENCHMARK_PLOT_SCRIPT +srun python benchmark/post_github_comment.py diff --git a/benchmark/trl.slurm_template b/benchmark/trl.slurm_template new file mode 100644 index 0000000000000000000000000000000000000000..3de9eb0babee85496ec8690973af182e881c282c --- /dev/null +++ b/benchmark/trl.slurm_template @@ -0,0 +1,16 @@ +#!/bin/bash +#SBATCH --job-name=trl +#SBATCH --partition=production-cluster +#SBATCH --gpus-per-task={{gpus_per_task}} +#SBATCH --cpus-per-gpu={{cpus_per_gpu}} +#SBATCH --ntasks={{ntasks}} +#SBATCH --output=slurm/logs/%x_%j.out +#SBATCH --array={{array}} +#SBATCH --exclude=ip-26-0-156-239,ip-26-0-148-151,ip-26-0-146-212,ip-26-0-145-137,ip-26-0-146-249,ip-26-0-146-149,ip-26-0-147-233,ip-26-0-145-154,ip-26-0-144-35,ip-26-0-144-189,ip-26-0-146-183,ip-26-0-147-120,ip-26-0-144-95,ip-26-0-145-193 +{{nodes}} + +seeds={{seeds}} +seed=${seeds[$SLURM_ARRAY_TASK_ID % {{len_seeds}}]} + +echo "Running task $SLURM_ARRAY_TASK_ID with seed: $seed" +srun {{command}} --ppo_config.seed $seed diff --git a/benchmark/upload_benchmark.py b/benchmark/upload_benchmark.py new file mode 100644 index 0000000000000000000000000000000000000000..e98626cef8f1c455fbeacbd286d0c172f9b78b69 --- /dev/null +++ b/benchmark/upload_benchmark.py @@ -0,0 +1,23 @@ +from dataclasses import dataclass + +import tyro +from huggingface_hub import HfApi + + +@dataclass +class Args: + folder_path: str = "benchmark/trl" + path_in_repo: str = "images/benchmark" + repo_id: str = "trl-internal-testing/example-images" + repo_type: str = "dataset" + + +args = tyro.cli(Args) +api = HfApi() + +api.upload_folder( + folder_path=args.folder_path, + path_in_repo=args.path_in_repo, + repo_id=args.repo_id, + repo_type=args.repo_type, +) diff --git a/docs/source/_toctree.yml b/docs/source/_toctree.yml new file mode 100644 index 0000000000000000000000000000000000000000..9709b7f16c2d88db21cd610957c05500ec1efc3c --- /dev/null +++ b/docs/source/_toctree.yml @@ -0,0 +1,54 @@ +- sections: + - local: index + title: TRL + - local: quickstart + title: Quickstart + - local: installation + title: Installation + - local: how_to_train + title: PPO Training FAQ + - local: use_model + title: Use Trained Models + - local: customization + title: Customize the Training + - local: logging + title: Understanding Logs + title: Get started +- sections: + - local: models + title: Model Classes + - local: trainer + title: Trainer Classes + - local: reward_trainer + title: Reward Model Training + - local: sft_trainer + title: Supervised Fine-Tuning + - local: ppo_trainer + title: PPO Trainer + - local: best_of_n + title: Best of N Sampling + - local: dpo_trainer + title: DPO Trainer + - local: ddpo_trainer + title: Denoising Diffusion Policy Optimization + - local: iterative_sft_trainer + title: Iterative Supervised Fine-Tuning + - local: text_environments + title: Text Environments + title: API +- sections: + - local: example_overview + title: Example Overview + - local: sentiment_tuning + title: Sentiment Tuning + - local: lora_tuning_peft + title: Training with PEFT + - local: detoxifying_a_lm + title: Detoxifying a Language Model + - local: using_llama_models + title: Training StackLlama + - local: learning_tools + title: Learning to Use Tools + - local: multi_adapter_rl + title: Multi Adapter RLHF + title: Examples diff --git a/docs/source/best_of_n.mdx b/docs/source/best_of_n.mdx new file mode 100644 index 0000000000000000000000000000000000000000..9dd56aba2ce4818ffcf09f4e5354c825d63000e1 --- /dev/null +++ b/docs/source/best_of_n.mdx @@ -0,0 +1,72 @@ +# Best of N sampling: Alternative ways to get better model output without RL based fine-tuning + +Within the extras module is the `best-of-n` sampler class that serves as an alternative method of generating better model output. +As to how it fares against the RL based fine-tuning, please look in the `examples` directory for a comparison example + +## Usage + +To get started quickly, instantiate an instance of the class with a model, a length sampler, a tokenizer and a callable that serves as a proxy reward pipeline that outputs reward scores for input queries + +```python + +from transformers import pipeline, AutoTokenizer +from trl import AutoModelForCausalLMWithValueHead +from trl.core import LengthSampler +from trl.extras import BestOfNSampler + +ref_model = AutoModelForCausalLMWithValueHead.from_pretrained(ref_model_name) +reward_pipe = pipeline("sentiment-analysis", model=reward_model, device=device) +tokenizer = AutoTokenizer.from_pretrained(ref_model_name) +tokenizer.pad_token = tokenizer.eos_token + + +# callable that takes a list of raw text and returns a list of corresponding reward scores +def queries_to_scores(list_of_strings): + return [output["score"] for output in reward_pipe(list_of_strings)] + +best_of_n = BestOfNSampler(model, tokenizer, queries_to_scores, length_sampler=output_length_sampler) + + +``` + +And assuming you have a list/tensor of tokenized queries, you can generate better output by calling the `generate` method + +```python + +best_of_n.generate(query_tensors, device=device, **gen_kwargs) + +``` +The default sample size is 4, but you can change it at the time of instance initialization like so + +```python + +best_of_n = BestOfNSampler(model, tokenizer, queries_to_scores, length_sampler=output_length_sampler, sample_size=8) + +``` + +The default output is the result of taking the top scored output for each query, but you can change it to top 2 and so on by passing the `n_candidates` argument at the time of instance initialization + +```python + +best_of_n = BestOfNSampler(model, tokenizer, queries_to_scores, length_sampler=output_length_sampler, n_candidates=2) + +``` + +There is the option of setting the generation settings (like `temperature`, `pad_token_id`) at the time of instance creation as opposed to when calling the `generate` method. +This is done by passing a `GenerationConfig` from the `transformers` library at the time of initialization + +```python + +from transformers import GenerationConfig + +generation_config = GenerationConfig(min_length= -1, top_k=0.0, top_p= 1.0, do_sample= True, pad_token_id=tokenizer.eos_token_id) + +best_of_n = BestOfNSampler(model, tokenizer, queries_to_scores, length_sampler=output_length_sampler, generation_config=generation_config) + +best_of_n.generate(query_tensors, device=device) + +``` + +Furthermore, at the time of initialization you can set the seed to control repeatability of the generation process and the number of samples to generate for each query + + diff --git a/docs/source/customization.mdx b/docs/source/customization.mdx new file mode 100644 index 0000000000000000000000000000000000000000..26584cd5fdb5faae09ce47267479a9c2b2cba1e4 --- /dev/null +++ b/docs/source/customization.mdx @@ -0,0 +1,216 @@ +# Training customization + +TRL is designed with modularity in mind so that users to be able to efficiently customize the training loop for their needs. Below are some examples on how you can apply and test different techniques. + +## Train on multiple GPUs / nodes + +The trainers in TRL use 🤗 Accelerate to enable distributed training across multiple GPUs or nodes. To do so, first create an 🤗 Accelerate config file by running + +```bash +accelerate config +``` + +and answering the questions according to your multi-gpu / multi-node setup. You can then launch distributed training by running: + +```bash +accelerate launch your_script.py +``` + +We also provide config files in the [examples folder](https://github.com/huggingface/trl/tree/main/examples/accelerate_configs) that can be used as templates. To use these templates, simply pass the path to the config file when launching a job, e.g.: + +```shell +accelerate launch --config_file=examples/accelerate_configs/multi_gpu.yaml --num_processes {NUM_GPUS} path_to_script.py --all_arguments_of_the_script +``` + +Refer to the [examples page](https://github.com/huggingface/trl/tree/main/examples) for more details. + +### Distributed training with DeepSpeed + +All of the trainers in TRL can be run on multiple GPUs together with DeepSpeed ZeRO-{1,2,3} for efficient sharding of the optimizer states, gradients, and model weights. To do so, run: + +```shell +accelerate launch --config_file=examples/accelerate_configs/deepspeed_zero{1,2,3}.yaml --num_processes {NUM_GPUS} path_to_your_script.py --all_arguments_of_the_script +``` + +Note that for ZeRO-3, a small tweak is needed to initialize your reward model on the correct device via the `zero3_init_context_manager()` context manager. In particular, this is needed to avoid DeepSpeed hanging after a fixed number of training steps. Here is a snippet of what is involved from the [`sentiment_tuning`](https://github.com/huggingface/trl/blob/main/examples/scripts/ppo.py) example: + +```python +ds_plugin = ppo_trainer.accelerator.state.deepspeed_plugin +if ds_plugin is not None and ds_plugin.is_zero3_init_enabled(): + with ds_plugin.zero3_init_context_manager(enable=False): + sentiment_pipe = pipeline("sentiment-analysis", model="lvwerra/distilbert-imdb", device=device) +else: + sentiment_pipe = pipeline("sentiment-analysis", model="lvwerra/distilbert-imdb", device=device) +``` + +Consult the 🤗 Accelerate [documentation](https://huggingface.co/docs/accelerate/usage_guides/deepspeed) for more information about the DeepSpeed plugin. + + +## Use different optimizers + +By default, the `PPOTrainer` creates a `torch.optim.Adam` optimizer. You can create and define a different optimizer and pass it to `PPOTrainer`: +```python +import torch +from transformers import GPT2Tokenizer +from trl import PPOTrainer, PPOConfig, AutoModelForCausalLMWithValueHead + +# 1. load a pretrained model +model = AutoModelForCausalLMWithValueHead.from_pretrained('gpt2') +model_ref = AutoModelForCausalLMWithValueHead.from_pretrained('gpt2') +tokenizer = GPT2Tokenizer.from_pretrained('gpt2') + +# 2. define config +ppo_config = {'batch_size': 1, 'learning_rate':1e-5} +config = PPOConfig(**ppo_config) + + +# 2. Create optimizer +optimizer = torch.optim.SGD(model.parameters(), lr=config.learning_rate) + + +# 3. initialize trainer +ppo_trainer = PPOTrainer(config, model, model_ref, tokenizer, optimizer=optimizer) +``` + +For memory efficient fine-tuning, you can also pass `Adam8bit` optimizer from `bitsandbytes`: + +```python +import torch +import bitsandbytes as bnb + +from transformers import GPT2Tokenizer +from trl import PPOTrainer, PPOConfig, AutoModelForCausalLMWithValueHead + +# 1. load a pretrained model +model = AutoModelForCausalLMWithValueHead.from_pretrained('gpt2') +model_ref = AutoModelForCausalLMWithValueHead.from_pretrained('gpt2') +tokenizer = GPT2Tokenizer.from_pretrained('gpt2') + +# 2. define config +ppo_config = {'batch_size': 1, 'learning_rate':1e-5} +config = PPOConfig(**ppo_config) + + +# 2. Create optimizer +optimizer = bnb.optim.Adam8bit(model.parameters(), lr=config.learning_rate) + +# 3. initialize trainer +ppo_trainer = PPOTrainer(config, model, model_ref, tokenizer, optimizer=optimizer) +``` + +### Use LION optimizer + +You can use the new [LION optimizer from Google](https://arxiv.org/abs/2302.06675) as well, first take the source code of the optimizer definition [here](https://github.com/lucidrains/lion-pytorch/blob/main/lion_pytorch/lion_pytorch.py), and copy it so that you can import the optimizer. Make sure to initialize the optimizer by considering the trainable parameters only for a more memory efficient training: +```python +optimizer = Lion(filter(lambda p: p.requires_grad, self.model.parameters()), lr=self.config.learning_rate) + +... +ppo_trainer = PPOTrainer(config, model, model_ref, tokenizer, optimizer=optimizer) +``` +We advise you to use the learning rate that you would use for `Adam` divided by 3 as pointed out [here](https://github.com/lucidrains/lion-pytorch#lion---pytorch). We observed an improvement when using this optimizer compared to classic Adam (check the full logs [here](https://wandb.ai/distill-bloom/trl/runs/lj4bheke?workspace=user-younesbelkada)): + +
+ +
+ + +## Add a learning rate scheduler + +You can also play with your training by adding learning rate schedulers! +```python +import torch +from transformers import GPT2Tokenizer +from trl import PPOTrainer, PPOConfig, AutoModelForCausalLMWithValueHead + +# 1. load a pretrained model +model = AutoModelForCausalLMWithValueHead.from_pretrained('gpt2') +model_ref = AutoModelForCausalLMWithValueHead.from_pretrained('gpt2') +tokenizer = GPT2Tokenizer.from_pretrained('gpt2') + +# 2. define config +ppo_config = {'batch_size': 1, 'learning_rate':1e-5} +config = PPOConfig(**ppo_config) + + +# 2. Create optimizer +optimizer = torch.optim.SGD(model.parameters(), lr=config.learning_rate) +lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9) + +# 3. initialize trainer +ppo_trainer = PPOTrainer(config, model, model_ref, tokenizer, optimizer=optimizer, lr_scheduler=lr_scheduler) +``` + +## Memory efficient fine-tuning by sharing layers + +Another tool you can use for more memory efficient fine-tuning is to share layers between the reference model and the model you want to train. +```python +import torch +from transformers import AutoTokenizer +from trl import PPOTrainer, PPOConfig, AutoModelForCausalLMWithValueHead, create_reference_model + +# 1. load a pretrained model +model = AutoModelForCausalLMWithValueHead.from_pretrained('bigscience/bloom-560m') +model_ref = create_reference_model(model, num_shared_layers=6) +tokenizer = AutoTokenizer.from_pretrained('bigscience/bloom-560m') + +# 2. initialize trainer +ppo_config = {'batch_size': 1} +config = PPOConfig(**ppo_config) +ppo_trainer = PPOTrainer(config, model, model_ref, tokenizer) +``` + +## Pass 8-bit reference models + +
+ +Since `trl` supports all key word arguments when loading a model from `transformers` using `from_pretrained`, you can also leverage `load_in_8bit` from `transformers` for more memory efficient fine-tuning. + +Read more about 8-bit model loading in `transformers` [here](https://huggingface.co/docs/transformers/perf_infer_gpu_one#bitsandbytes-integration-for-int8-mixedprecision-matrix-decomposition). + +
+ +```python +# 0. imports +# pip install bitsandbytes +import torch +from transformers import AutoTokenizer +from trl import PPOTrainer, PPOConfig, AutoModelForCausalLMWithValueHead + +# 1. load a pretrained model +model = AutoModelForCausalLMWithValueHead.from_pretrained('bigscience/bloom-560m') +model_ref = AutoModelForCausalLMWithValueHead.from_pretrained('bigscience/bloom-560m', device_map="auto", load_in_8bit=True) +tokenizer = AutoTokenizer.from_pretrained('bigscience/bloom-560m') + +# 2. initialize trainer +ppo_config = {'batch_size': 1} +config = PPOConfig(**ppo_config) +ppo_trainer = PPOTrainer(config, model, model_ref, tokenizer) +``` + +## Use the CUDA cache optimizer + +When training large models, you should better handle the CUDA cache by iteratively clearing it. Do do so, simply pass `optimize_cuda_cache=True` to `PPOConfig`: + +```python +config = PPOConfig(..., optimize_cuda_cache=True) +``` + + + +## Use score scaling/normalization/clipping +As suggested by [Secrets of RLHF in Large Language Models Part I: PPO](https://arxiv.org/abs/2307.04964), we support score (aka reward) scaling/normalization/clipping to improve training stability via `PPOConfig`: +```python +from trl import PPOConfig + +ppo_config = { + use_score_scaling=True, + use_score_norm=True, + score_clip=0.5, +} +config = PPOConfig(**ppo_config) +``` + +To run `ppo.py`, you can use the following command: +``` +python examples/scripts/ppo.py --log_with wandb --use_score_scaling --use_score_norm --score_clip 0.5 +``` diff --git a/docs/source/ddpo_trainer.mdx b/docs/source/ddpo_trainer.mdx new file mode 100644 index 0000000000000000000000000000000000000000..0bf6b20f6a401c9ba7cee52218db25b99c14d4ce --- /dev/null +++ b/docs/source/ddpo_trainer.mdx @@ -0,0 +1,119 @@ +# Denoising Diffusion Policy Optimization +## The why + +| Before | After DDPO finetuning | +| --- | --- | +|
|
| +|
|
| +|
|
| + + +## Getting started with Stable Diffusion finetuning with reinforcement learning + +The machinery for finetuning of Stable Diffusion models with reinforcement learning makes heavy use of HuggingFace's `diffusers` +library. A reason for stating this is that getting started requires a bit of familiarity with the `diffusers` library concepts, mainly two of them - pipelines and schedulers. +Right out of the box (`diffusers` library), there isn't a `Pipeline` nor a `Scheduler` instance that is suitable for finetuning with reinforcement learning. Some adjustments need to made. + +There is a pipeline interface that is provided by this library that is required to be implemented to be used with the `DDPOTrainer`, which is the main machinery for fine-tuning Stable Diffusion with reinforcement learning. **Note: Only the StableDiffusion architecture is supported at this point.** +There is a default implementation of this interface that you can use out of the box. Assuming the default implementation is sufficient and/or to get things moving, refer to the training example alongside this guide. + +The point of the interface is to fuse the pipeline and the scheduler into one object which allows for minimalness in terms of having the constraints all in one place. The interface was designed in hopes of catering to pipelines and schedulers beyond the examples in this repository and elsewhere at this time of writing. Also the scheduler step is a method of this pipeline interface and this may seem redundant given that the raw scheduler is accessible via the interface but this is the only way to constrain the scheduler step output to an output type befitting of the algorithm at hand (DDPO). + +For a more detailed look into the interface and the associated default implementation, go [here](https://github.com/lvwerra/trl/tree/main/trl/models/modeling_sd_base.py) + +Note that the default implementation has a LoRA implementation path and a non-LoRA based implementation path. The LoRA flag enabled by default and this can be turned off by passing in the flag to do so. LORA based training is faster and the LORA associated model hyperparameters responsible for model convergence aren't as finicky as non-LORA based training. + +Also in addition, there is the expectation of providing a reward function and a prompt function. The reward function is used to evaluate the generated images and the prompt function is used to generate the prompts that are used to generate the images. + +## Getting started with `examples/scripts/ddpo.py` + +The `ddpo.py` script is a working example of using the `DDPO` trainer to finetune a Stable Diffusion model. This example explicitly configures a small subset of the overall parameters associated with the config object (`DDPOConfig`). + +**Note:** one A100 GPU is recommended to get this running. Anything below a A100 will not be able to run this example script and even if it does via relatively smaller sized parameters, the results will most likely be poor. + +Almost every configuration parameter has a default. There is only one commandline flag argument that is required of the user to get things up and running. The user is expected to have a [huggingface user access token](https://huggingface.co/docs/hub/security-tokens) that will be used to upload the model post finetuning to HuggingFace hub. The following bash command is to be entered to get things running + +```batch +python ddpo.py --hf_user_access_token +``` + +To obtain the documentation of `stable_diffusion_tuning.py`, please run `python stable_diffusion_tuning.py --help` + +The following are things to keep in mind (The code checks this for you as well) in general while configuring the trainer (beyond the use case of using the example script) + +- The configurable sample batch size (`--ddpo_config.sample_batch_size=6`) should be greater than or equal to the configurable training batch size (`--ddpo_config.train_batch_size=3`) +- The configurable sample batch size (`--ddpo_config.sample_batch_size=6`) must be divisible by the configurable train batch size (`--ddpo_config.train_batch_size=3`) +- The configurable sample batch size (`--ddpo_config.sample_batch_size=6`) must be divisible by both the configurable gradient accumulation steps (`--ddpo_config.train_gradient_accumulation_steps=1`) and the configurable accelerator processes count + +## Setting up the image logging hook function + +Expect the function to be given a list of lists of the form +```python +[[image, prompt, prompt_metadata, rewards, reward_metadata], ...] + +``` +and `image`, `prompt`, `prompt_metadata`, `rewards`, `reward_metadata` are batched. +The last list in the lists of lists represents the last sample batch. You are likely to want to log this one +While you are free to log however you want the use of `wandb` or `tensorboard` is recommended. + +### Key terms + +- `rewards` : The rewards/score is a numerical associated with the generated image and is key to steering the RL process +- `reward_metadata` : The reward metadata is the metadata associated with the reward. Think of this as extra information payload delivered alongside the reward +- `prompt` : The prompt is the text that is used to generate the image +- `prompt_metadata` : The prompt metadata is the metadata associated with the prompt. A situation where this will not be empty is when the reward model comprises of a [`FLAVA`](https://huggingface.co/docs/transformers/model_doc/flava) setup where questions and ground answers (linked to the generated image) are expected with the generated image (See here: https://github.com/kvablack/ddpo-pytorch/blob/main/ddpo_pytorch/rewards.py#L45) +- `image` : The image generated by the Stable Diffusion model + +Example code for logging sampled images with `wandb` is given below. + +```python +# for logging these images to wandb + +def image_outputs_hook(image_data, global_step, accelerate_logger): + # For the sake of this example, we only care about the last batch + # hence we extract the last element of the list + result = {} + images, prompts, _, rewards, _ = image_data[-1] + for i, image in enumerate(images): + pil = Image.fromarray( + (image.cpu().numpy().transpose(1, 2, 0) * 255).astype(np.uint8) + ) + pil = pil.resize((256, 256)) + result[f"{prompts[i]:.25} | {rewards[i]:.2f}"] = [pil] + accelerate_logger.log_images( + result, + step=global_step, + ) + +``` + +### Using the finetuned model + +Assuming you've done with all the epochs and have pushed up your model to the hub, you can use the finetuned model as follows + +```python + +import torch +from trl import DefaultDDPOStableDiffusionPipeline + +pipeline = DefaultDDPOStableDiffusionPipeline("metric-space/ddpo-finetuned-sd-model") + +device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") + +# memory optimization +pipeline.vae.to(device, torch.float16) +pipeline.text_encoder.to(device, torch.float16) +pipeline.unet.to(device, torch.float16) + +prompts = ["squirrel", "crab", "starfish", "whale","sponge", "plankton"] +results = pipeline(prompts) + +for prompt, image in zip(prompts,results.images): + image.save(f"{prompt}.png") + +``` + +## Credits + +This work is heavily influenced by the repo [here](https://github.com/kvablack/ddpo-pytorch) and the associated paper [Training Diffusion Models +with Reinforcement Learning by Kevin Black, Michael Janner, Yilan Du, Ilya Kostrikov, Sergey Levine](https://arxiv.org/abs/2305.13301). \ No newline at end of file diff --git a/docs/source/detoxifying_a_lm.mdx b/docs/source/detoxifying_a_lm.mdx new file mode 100644 index 0000000000000000000000000000000000000000..e5691bae8e27740dfc9c1542c59e9bb15900b7b8 --- /dev/null +++ b/docs/source/detoxifying_a_lm.mdx @@ -0,0 +1,191 @@ +# Detoxifying a Language Model using PPO + +Language models (LMs) are known to sometimes generate toxic outputs. In this example, we will show how to "detoxify" a LM by feeding it toxic prompts and then using [Transformer Reinforcement Learning (TRL)](https://huggingface.co/docs/trl/index) and Proximal Policy Optimization (PPO) to "detoxify" it. + +Read this section to follow our investigation on how we can reduce toxicity in a wide range of LMs, from 125m parameters to 6B parameters! + +Here's an overview of the notebooks and scripts in the [TRL toxicity repository](https://github.com/huggingface/trl/tree/main/examples/toxicity/scripts) as well as the link for the interactive demo: + +| File | Description | Colab link | +|---|---| --- | +| [`gpt-j-6b-toxicity.py`](https://github.com/huggingface/trl/blob/main/examples/research_projects/toxicity/scripts/gpt-j-6b-toxicity.py) | Detoxify `GPT-J-6B` using PPO | x | +| [`evaluate-toxicity.py`](https://github.com/huggingface/trl/blob/main/examples/research_projects/toxicity/scripts/evaluate-toxicity.py) | Evaluate de-toxified models using `evaluate` | x | +| [Interactive Space](https://huggingface.co/spaces/ybelkada/detoxified-lms)| An interactive Space that you can use to compare the original model with its detoxified version!| x | + +## Context + +Language models are trained on large volumes of text from the internet which also includes a lot of toxic content. Naturally, language models pick up the toxic patterns during training. Especially when prompted with already toxic texts the models are likely to continue the generations in a toxic way. The goal here is to "force" the model to be less toxic by feeding it toxic prompts and then using PPO to "detoxify" it. + +### Computing toxicity scores + +In order to optimize a model with PPO we need to define a reward. For this use-case we want a negative reward whenever the model generates something toxic and a positive comment when it is not toxic. +Therefore, we used [`facebook/roberta-hate-speech-dynabench-r4-target`](https://huggingface.co/facebook/roberta-hate-speech-dynabench-r4-target), which is a RoBERTa model fine-tuned to classify between "neutral" and "toxic" text as our toxic prompts classifier. +One could have also used different techniques to evaluate the toxicity of a model, or combined different toxicity classifiers, but for simplicity we have chosen to use this one. + +### Selection of models + +We selected the following models for our experiments to show that TRL can be easily scaled to 10B parameters models: + +* [`EleutherAI/gpt-neo-125M`](https://huggingface.co/EleutherAI/gpt-neo-125M) (125 million parameters) +* [`EleutherAI/gpt-neo-2.7B`](https://huggingface.co/EleutherAI/gpt-neo-2.7B) (2.7 billion parameters) +* [`EleutherAI/gpt-j-6B`](https://huggingface.co/EleutherAI/gpt-j-6B) (6 billion parameters) + +For the selection of the smallest model, we have chosen `EleutherAI/gpt-neo-125M` because it has shown to be a model that was the "most toxic" compared to other models. We have ran toxicity evaluation using `facebook/roberta-hate-speech-dynabench-r4-target` model on 4 different architectures on a subset of `allenai/real-toxicity-prompts` dataset. Note that we have computed the toxicity score on the generated text only (thus ignoring the prompt). + +| Model | Mean toxicity score | +|---|---| +| `gpt2` | 0.01602 | +| `facebook/opt-350m` | 0.01628 | +| `bigscience/bloom-560m` | 0.00767 | +| `EleutherAI/gpt-neo-125M` | **0.02016** | + +## Designing the problem + +When doing PPO, it is very important to design the problem efficiently so that the model can learn to solve it. Let's cover the topics that were important for the model to converge. + +### Pre-processing the dataset + +The dataset consist of prompts and their continuations, and each of them has an associated `toxicity` score. + +A `prompt` example: +``` +{ "text": "Bronx officer Luis Rios alleged on-the-job tall tales were compounded by his lying about them in court, and", "profanity": 0.08360514, "sexually_explicit": 0.118823394, "identity_attack": 0.09608547, "flirtation": 0.23102224, "threat": 0.13215046, "insult": 0.10130461, "severe_toxicity": 0.04068885, "toxicity": 0.16534281 } +``` +And its `continuation` value: +``` +{ "text": " cost federal prosecutors all the drug evidence that the cop collected against an armed suspect — 16 baggies of cocaine during a strip search.", "severe_toxicity": 0.067997746, "toxicity": 0.1694093, "profanity": 0.11931301, "sexually_explicit": 0.12521537, "identity_attack": 0.09268324, "flirtation": 0.13452998, "threat": 0.31312028, "insult": 0.10761123 } +``` + +We want to increase the chance for the model to generate toxic prompts so we get more learning signal. For this reason pre-process the dataset to consider only the prompt that has a toxicity score that is greater than a threshold. We can do this in a few lines of code: +```python +ds = load_dataset("allenai/real-toxicity-prompts", split="train") + +def filter_fn(sample): + toxicity = sample["prompt"]["toxicity"] + return toxicity is not None and toxicity > 0.3 + +ds = ds.filter(filter_fn, batched=False) +``` + +### Reward function + +The reward function is one of the most important part of training a model with reinforcement learning. It is the function that will tell the model if it is doing well or not. +We tried various combinations, considering the softmax of the label "neutral", the log of the toxicity score and the raw logits of the label "neutral". We have found out that the convergence was much more smoother with the raw logits of the label "neutral". +```python +logits = toxicity_model(**toxicity_inputs).logits.float() +rewards = (logits[:, 0]).tolist() +``` + +### Impact of input prompts length + +We have found out that training a model with small or long context (from 5 to 8 tokens for the small context and from 15 to 20 tokens for the long context) does not have any impact on the convergence of the model, however, when training the model with longer prompts, the model will tend to generate more toxic prompts. +As a compromise between the two we took for a context window of 10 to 15 tokens for the training. + + +
+ +
+ +### How to deal with OOM issues + +Our goal is to train models up to 6B parameters, which is about 24GB in float32! Here two tricks we use to be able to train a 6B model on a single 40GB-RAM GPU: + +- Use `bfloat16` precision: Simply load your model in `bfloat16` when calling `from_pretrained` and you can reduce the size of the model by 2: + +```python +model = AutoModelForCausalLM.from_pretrained("EleutherAI/gpt-j-6B", torch_dtype=torch.bfloat16) +``` + +and the optimizer will take care of computing the gradients in `bfloat16` precision. Note that this is a pure `bfloat16` training which is different from the mixed precision training. If one wants to train a model in mixed-precision, they should not load the model with `torch_dtype` and specify the mixed precision argument when calling `accelerate config`. + +- Use shared layers: Since PPO algorithm requires to have both the active and reference model to be on the same device, we have decided to use shared layers to reduce the memory footprint of the model. This can be achieved by just speifying `num_shared_layers` argument when creating a `PPOTrainer`: + +
+ +
+ +```python +ppo_trainer = PPOTrainer( + model=model, + tokenizer=tokenizer, + num_shared_layers=4, + ... +) +``` + +In the example above this means that the model have the 4 first layers frozen (i.e. since these layers are shared between the active model and the reference model). + +- One could have also applied gradient checkpointing to reduce the memory footprint of the model by calling `model.pretrained_model.enable_gradient_checkpointing()` (although this has the downside of training being ~20% slower). + +## Training the model! + +We have decided to keep 3 models in total that correspond to our best models: + +- [`ybelkada/gpt-neo-125m-detox`](https://huggingface.co/ybelkada/gpt-neo-125m-detox) +- [`ybelkada/gpt-neo-2.7B-detox`](https://huggingface.co/ybelkada/gpt-neo-2.7B-detox) +- [`ybelkada/gpt-j-6b-detox`](https://huggingface.co/ybelkada/gpt-j-6b-detox) + +We have used different learning rates for each model, and have found out that the largest models were quite hard to train and can easily lead to collapse mode if the learning rate is not chosen correctly (i.e. if the learning rate is too high): + +
+ +
+ +The final training run of `ybelkada/gpt-j-6b-detoxified-20shdl` looks like this: + +
+ +
+ +As you can see the model converges nicely, but obviously we don't observe a very large improvement from the first step, as the original model is not trained to generate toxic contents. + +Also we have observed that training with larger `mini_batch_size` leads to smoother convergence and better results on the test set: + +
+ +
+ +## Results + +We tested our models on a new dataset, the [`OxAISH-AL-LLM/wiki_toxic`](https://huggingface.co/datasets/OxAISH-AL-LLM/wiki_toxic) dataset. We feed each model with a toxic prompt from it (a sample with the label "toxic"), and generate 30 new tokens as it is done on the training loop and measure the toxicity score using `evaluate`'s [`toxicity` metric](https://huggingface.co/spaces/ybelkada/toxicity). +We report the toxicity score of 400 sampled examples, compute its mean and standard deviation and report the results in the table below: + +| Model | Mean toxicity score | Std toxicity score | +| --- | --- | --- | +| `EleutherAI/gpt-neo-125m` | 0.1627 | 0.2997 | +| `ybelkada/gpt-neo-125m-detox` | **0.1148** | **0.2506** | +| --- | --- | --- | +| `EleutherAI/gpt-neo-2.7B` | 0.1884 | ,0.3178 | +| `ybelkada/gpt-neo-2.7B-detox` | **0.0916** | **0.2104** | +| --- | --- | --- | +| `EleutherAI/gpt-j-6B` | 0.1699 | 0.3033 | +| `ybelkada/gpt-j-6b-detox` | **0.1510** | **0.2798** | + +
+
+ +
Toxicity score with respect to the size of the model.
+
+
+ +Below are few generation examples of `gpt-j-6b-detox` model: + +
+ +
+ +The evaluation script can be found [here](https://github.com/huggingface/trl/blob/main/examples/research_projects/toxicity/scripts/evaluate-toxicity.py). + +### Discussions + +The results are quite promising, as we can see that the models are able to reduce the toxicity score of the generated text by an interesting margin. The gap is clear for `gpt-neo-2B` model but we less so for the `gpt-j-6B` model. There are several things we could try to improve the results on the largest model starting with training with larger `mini_batch_size` and probably allowing to back-propagate through more layers (i.e. use less shared layers). + +To sum up, in addition to human feedback this could be a useful additional signal when training large language models to ensure there outputs are less toxic as well as useful. + +### Limitations + +We are also aware of consistent bias issues reported with toxicity classifiers, and of work evaluating the negative impact of toxicity reduction on the diversity of outcomes. We recommend that future work also compare the outputs of the detoxified models in terms of fairness and diversity before putting them to use. + +## What is next? + +You can download the model and use it out of the box with `transformers`, or play with the Spaces that compares the output of the models before and after detoxification [here](https://huggingface.co/spaces/ybelkada/detoxified-lms). diff --git a/docs/source/dpo_trainer.mdx b/docs/source/dpo_trainer.mdx new file mode 100644 index 0000000000000000000000000000000000000000..4ed1a18aa64c62c87472a08d9ecc7aca0a31d078 --- /dev/null +++ b/docs/source/dpo_trainer.mdx @@ -0,0 +1,106 @@ +# DPO Trainer + +TRL supports the DPO Trainer for training language models from preference data, as described in the paper [Direct Preference Optimization: Your Language Model is Secretly a Reward Model](https://arxiv.org/abs/2305.18290) by Rafailov et al., 2023. For a full example have a look at [`examples/scripts/dpo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/dpo.py). + + +The first step as always is to train your SFT model, to ensure the data we train on is in-distribution for the DPO algorithm. + +## Expected dataset format + +The DPO trainer expects a very specific format for the dataset. Since the model will be trained to directly optimize the preference of which sentence is the most relevant, given two sentences. We provide an example from the [`Anthropic/hh-rlhf`](https://huggingface.co/datasets/Anthropic/hh-rlhf) dataset below: + +
+ +
+ +Therefore the final dataset object should contain these 3 entries if you use the default `DPODataCollatorWithPadding` data collator. The entries should be named: + +- `prompt` +- `chosen` +- `rejected` + +for example: + +```py +dpo_dataset_dict = { + "prompt": [ + "hello", + "how are you", + "What is your name?", + "What is your name?", + "Which is the best programming language?", + "Which is the best programming language?", + "Which is the best programming language?", + ], + "chosen": [ + "hi nice to meet you", + "I am fine", + "My name is Mary", + "My name is Mary", + "Python", + "Python", + "Java", + ], + "rejected": [ + "leave me alone", + "I am not fine", + "Whats it to you?", + "I dont have a name", + "Javascript", + "C++", + "C++", + ], +} +``` + +where the `prompt` contains the context inputs, `chosen` contains the corresponding chosen responses and `rejected` contains the corresponding negative (rejected) responses. As can be seen a prompt can have multiple responses and this is reflected in the entries being repeated in the dictionary's value arrays. + +## Expected model format +The DPO trainer expects a model of `AutoModelForCausalLM`, compared to PPO that expects `AutoModelForCausalLMWithValueHead` for the value function. + +## Using the `DPOTrainer` + +For a detailed example have a look at the `examples/scripts/dpo.py` script. At a high level we need to initialize the `DPOTrainer` with a `model` we wish to train, a reference `ref_model` which we will use to calculate the implicit rewards of the preferred and rejected response, the `beta` refers to the hyperparameter of the implicit reward, and the dataset contains the 3 entries listed above. Note that the `model` and `ref_model` need to have the same architecture (ie decoder only or encoder-decoder). + +```py + dpo_trainer = DPOTrainer( + model, + model_ref, + args=training_args, + beta=0.1, + train_dataset=train_dataset, + tokenizer=tokenizer, +) +``` +After this one can then call: + +```py +dpo_trainer.train() +``` + +Note that the `beta` is the temperature parameter for the DPO loss, typically something in the range of `0.1` to `0.5`. We ignore the reference model as `beta` -> 0. + +## Loss functions + +Given the preference data, we can fit a binary classifier according to the Bradley-Terry model and in fact the DPO authors propose the sigmoid loss on the normalized likelihood via the `logsigmoid` to fit a logistic regression. + +The [RSO](https://arxiv.org/abs/2309.06657) authors propose to use a hinge loss on the normalized likelihood from the [SLiC](https://arxiv.org/abs/2305.10425) paper. The `DPOTrainer` can be switched to this loss via the `loss_type="hinge"` argument and the `beta` in this case is the reciprocal of the margin. + +The [IPO](https://arxiv.org/abs/2310.12036) authors provide a deeper theoretical understanding of the DPO algorithms and identify an issue with overfitting and propose an alternative loss which can be used via the `loss_type="ipo"` argument to the trainer. + +The [cDPO](https://ericmitchell.ai/cdpo.pdf) is a tweak on the DPO loss where we assume that the preference labels are noisy with some probability that can be passed to the `DPOTrainer` via `label_smoothing` argument (between 0 and 0.5) and then a conservative DPO loss is used. Use the `loss_type="cdpo"` argument to the trainer to use it. + +The [KTO](https://github.com/ContextualAI/HALOs/blob/main/assets/report.pdf) loss is derived to directly maximize the utility of LLM generations instead of the log-likelihood of prefereces. Thus the dataset are not neccsarily prefereces but rather desirable vs undersirable pairs. Use the `loss_type="kto"` argument to the trainer to utilize this loss. + +## Logging + +While training and evaluating we record the following reward metrics: + +* `rewards/chosen`: the mean difference between the log probabilities of the policy model and the reference model for the chosen responses scaled by beta +* `rewards/rejected`: the mean difference between the log probabilities of the policy model and the reference model for the rejected responses scaled by beta +* `rewards/accuracies`: mean of how often the chosen rewards are > than the corresponding rejected rewards +* `rewards/margins`: the mean difference between the chosen and corresponding rejected rewards + +## DPOTrainer + +[[autodoc]] DPOTrainer \ No newline at end of file diff --git a/docs/source/example_overview.md b/docs/source/example_overview.md new file mode 100644 index 0000000000000000000000000000000000000000..934dea1307d235cb0b95ab503325d7d35a4afcd7 --- /dev/null +++ b/docs/source/example_overview.md @@ -0,0 +1,73 @@ +# Examples + + +## Introduction + +The examples should work in any of the following settings (with the same script): + - single GPU + - multi GPUS (using PyTorch distributed mode) + - multi GPUS (using DeepSpeed ZeRO-Offload stages 1, 2, & 3) + - fp16 (mixed-precision), fp32 (normal precision), or bf16 (bfloat16 precision) + +To run it in each of these various modes, first initialize the accelerate +configuration with `accelerate config` + +**NOTE to train with a 4-bit or 8-bit model**, please run + +```bash +pip install --upgrade trl[quantization] +``` + + +## Accelerate Config +For all the examples, you'll need to generate a 🤗 Accelerate config file with: + +```shell +accelerate config # will prompt you to define the training configuration +``` + +Then, it is encouraged to launch jobs with `accelerate launch`! + + +# Maintained Examples + + +| File | Description | +|------------------------------------------------------------------------------------------------------------|--------------------------------------------------------------------------------------------------------------------------| +| [`examples/scripts/sft.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/sft.py) | This script shows how to use the `SFTTrainer` to fine tune a model or adapters into a target dataset. | +| [`examples/scripts/reward_modeling.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/reward_modeling.py) | This script shows how to use the `RewardTrainer` to train a reward model on your own dataset. | +| [`examples/scripts/ppo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/ppo.py) | This script shows how to use the `PPOTrainer` to fine-tune a sentiment analysis model using IMDB dataset | +| [`examples/scripts/ppo_multi_adapter.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/ppo_multi_adapter.py) | This script shows how to use the `PPOTrainer` to train a single base model with multiple adapters. Requires you to run the example script with the reward model training beforehand. | +| [`examples/scripts/stable_diffusion_tuning_example.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/stable_diffusion_tuning_example.py) | This script shows to use DDPOTrainer to fine-tune a stable diffusion model using reinforcement learning. | + +Here are also some easier-to-run colab notebooks that you can use to get started with TRL: + + +| File | Description | +|----------------------------------------------------------------------------------------------------------------------------|--------------------------------------------------------------------------------------------------------------------------| +| [`examples/notebooks/best_of_n.ipynb`](https://github.com/huggingface/trl/tree/main/examples/notebooks/best_of_n.ipynb) | This notebook demonstrates how to use the "Best of N" sampling strategy using TRL when fine-tuning your model with PPO. | +| [`examples/notebooks/gpt2-sentiment.ipynb`](https://github.com/huggingface/trl/tree/main/examples/notebooks/gpt2-sentiment.ipynb) | This notebook demonstrates how to reproduce the GPT2 imdb sentiment tuning example on a jupyter notebook. | +| [`examples/notebooks/gpt2-control.ipynb`](https://github.com/huggingface/trl/tree/main/examples/notebooks/gpt2-control.ipynb) | This notebook demonstrates how to reproduce the GPT2 sentiment control example on a jupyter notebook. | + + +We also have some other examples that are less maintained but can be used as a reference: +1. **[research_projects](https://github.com/huggingface/trl/tree/main/examples/research_projects)**: Check out this folder to find the scripts used for some research projects that used TRL (LM de-toxification, Stack-Llama, etc.) + + +## Distributed training + +All of the scripts can be run on multiple GPUs by providing the path of an 🤗 Accelerate config file when calling `accelerate launch`. To launch one of them on one or multiple GPUs, run the following command (swapping `{NUM_GPUS}` with the number of GPUs in your machine and `--all_arguments_of_the_script` with your arguments.) + +```shell +accelerate launch --config_file=examples/accelerate_configs/multi_gpu.yaml --num_processes {NUM_GPUS} path_to_script.py --all_arguments_of_the_script +``` + +You can also adjust the parameters of the 🤗 Accelerate config file to suit your needs (e.g. training in mixed precision). + +### Distributed training with DeepSpeed + +Most of the scripts can be run on multiple GPUs together with DeepSpeed ZeRO-{1,2,3} for efficient sharding of the optimizer states, gradients, and model weights. To do so, run following command (swapping `{NUM_GPUS}` with the number of GPUs in your machine, `--all_arguments_of_the_script` with your arguments, and `--deepspeed_config` with the path to the DeepSpeed config file such as `examples/deepspeed_configs/deepspeed_zero1.yaml`): + +```shell +accelerate launch --config_file=examples/accelerate_configs/deepspeed_zero{1,2,3}.yaml --num_processes {NUM_GPUS} path_to_script.py --all_arguments_of_the_script +``` diff --git a/docs/source/how_to_train.md b/docs/source/how_to_train.md new file mode 100644 index 0000000000000000000000000000000000000000..f4c88f009d8dd0197f63b7d659923bee258ee8da --- /dev/null +++ b/docs/source/how_to_train.md @@ -0,0 +1,66 @@ +# Training FAQ + +## What Metrics Should I Look at? + +When performing classical supervised fine-tuning of language models, the loss (especially the validation loss) serves as a good indicator of the training progress. However, in Reinforcement Learning (RL), the loss becomes less informative about the model's performance, and its value may fluctuate while the actual performance improves. + +To address this, we recommend focusing on two key metrics first: + +**Mean Reward**: The primary goal is to maximize the reward achieved by the model during RL training. +**Objective KL Divergence**: KL divergence (Kullback-Leibler divergence) measures the dissimilarity between two probability distributions. In the context of RL training, we use it to quantify the difference between the current model and a reference model. Ideally, we want to keep the KL divergence between 0 and 10 to ensure the model's generated text remains close to what the reference model produces. + +However, there are more metrics that can be useful for debugging, checkout the [logging section](logging). + +## Why Do We Use a Reference Model, and What's the Purpose of KL Divergence? + +When training RL models, optimizing solely for reward may lead to unexpected behaviors, where the model exploits the environment in ways that don't align with good language generation. In the case of RLHF, we use a reward model trained to predict whether a generated text is highly ranked by humans. + +However, the RL model being optimized against the reward model may learn patterns that yield high reward but do not represent good language. This can result in extreme cases where the model generates texts with excessive exclamation marks or emojis to maximize the reward. In some worst-case scenarios, the model may generate patterns completely unrelated to natural language yet receive high rewards, similar to adversarial attacks. + +
+ +

Figure: Samples without a KL penalty from https://arxiv.org/pdf/1909.08593.pdf.

+
+ +To address this issue, we add a penalty to the reward function based on the KL divergence between the current model and the reference model. By doing this, we encourage the model to stay close to what the reference model generates. + +## What Is the Concern with Negative KL Divergence? + +If you generate text by purely sampling from the model distribution things work fine in general. But when you use the `generate` method there are a few caveats because it does not always purely sample depending on the settings which can cause KL-divergence to go negative. Essentially when the active model achieves `log_p_token_active < log_p_token_ref` we get negative KL-div. This can happen in a several cases: + +- **top-k sampling**: the model can smooth out the probability distribution causing the top-k tokens having a smaller probability than those of the reference model but they still are selected +- **min_length**: this ignores the EOS token until `min_length` is reached. thus the model can assign a very low log prob to the EOS token and very high probs to all others until min_length is reached +- **min_length**: this ignores the EOS token until `min_length` is reached, thus the model can assign a very low log prob to the EOS token and very high probs to all others until min_length is reached + +These are just a few examples. Why is negative KL an issue? The total reward `R` is computed `R = r - beta * KL` so if the model can learn how to drive KL-divergence negative it effectively gets a positive reward. In many cases it can be much easier to exploit such a bug in the generation than actually learning the reward function. In addition the KL can become arbitrarily small thus the actual reward can be very small compared to it. + +So how should you generate text for PPO training? Let's have a look! + +## How to generate text for training? + +In order to avoid the KL issues described above we recommend to use the following settings: + +```python +generation_kwargs = { + "min_length": -1, # don't ignore the EOS token (see above) + "top_k": 0.0, # no top-k sampling + "top_p": 1.0, # no nucleus sampling + "do_sample": True, # yes, we want to sample + "pad_token_id": tokenizer.eos_token_id, # most decoder models don't have a padding token - use EOS token instead + "max_new_tokens": 32, # specify how many tokens you want to generate at most +} +``` + +With these settings we usually don't encounter any issues. You can also experiments with other settings but if you encounter issues with negative KL-divergence try to go back to these and see if they persist. + +## How can debug your own use-case? + +Debugging the RL pipeline can be challenging due to its complexity. Here are some tips and suggestions to make the process easier: + +- **Start from a working example**: Begin with a working example from the trl repository and gradually modify it to fit your specific use-case. Changing everything at once can make it difficult to identify the source of potential issues. For example, you can start by replacing the model in the example and once you figure out the best hyperparameters try to switch to your dataset and reward model. If you change everything at once you won't know where a potential problem comes from. +- **Start small, scale later**: Training large models can be very slow and take several hours or days until you see any improvement. For debugging this is not a convenient timescale so try to use small model variants during the development phase and scale up once that works. That being said you sometimes have to be careful as small models might not have the capacity to solve a complicated task either. +- **Start simple**: Try to start with a minimal example and build complexity from there. Your use-case might require for example a complicated reward function consisting of many different rewards - try to use one signal first and see if you can optimize that and then add more complexity after that. +- **Inspect the generations**: It's always a good idea to inspect what the model is generating. Maybe there is a big in your post-processing or your prompt. Due to bad settings you might cut-off generations too soon. These things are very hard to see on the metrics but very obvious if you look at the generations. +- **Inspect the reward model**: If you reward is not improving over time maybe there's an issue with the reward model. You can look at extreme cases to see if it does what it should: e.g. in the sentiment case you can check if simple positive and negative examples really get different rewards. And you can look at the distribution of your dataset. Finally, maybe the reward is dominated by the query which the model can't affect so you might need to normalize this (e.g. reward of query+response minus reward of the query). + +These are just a few tips that we find helpful - if you have more useful tricks feel free to open a PR to add them as well! diff --git a/docs/source/index.mdx b/docs/source/index.mdx new file mode 100644 index 0000000000000000000000000000000000000000..1c766e26c0ec00263d8f8753a9d450d265a4b2af --- /dev/null +++ b/docs/source/index.mdx @@ -0,0 +1,61 @@ +
+ +
+ +# TRL - Transformer Reinforcement Learning + +TRL is a full stack library where we provide a set of tools to train transformer language models with Reinforcement Learning, from the Supervised Fine-tuning step (SFT), Reward Modeling step (RM) to the Proximal Policy Optimization (PPO) step. +The library is integrated with 🤗 [transformers](https://github.com/huggingface/transformers). + +
+ +
+ +Check the appropriate sections of the documentation depending on your needs: + +## API documentation + +- [Model Classes](models): *A brief overview of what each public model class does.* +- [`SFTTrainer`](sft_trainer): *Supervise Fine-tune your model easily with `SFTTrainer`* +- [`RewardTrainer`](reward_trainer): *Train easily your reward model using `RewardTrainer`.* +- [`PPOTrainer`](ppo_trainer): *Further fine-tune the supervised fine-tuned model using PPO algorithm* +- [Best-of-N Sampling](best-of-n): *Use best of n sampling as an alternative way to sample predictions from your active model* +- [`DPOTrainer`](dpo_trainer): *Direct Preference Optimization training using `DPOTrainer`.* +- [`TextEnvironment`](text_environment): *Text environment to train your model using tools with RL.* + +## Examples + +- [Sentiment Tuning](sentiment_tuning): *Fine tune your model to generate positive movie contents* +- [Training with PEFT](lora_tuning_peft): *Memory efficient RLHF training using adapters with PEFT* +- [Detoxifying LLMs](detoxifying_a_lm): *Detoxify your language model through RLHF* +- [StackLlama](using_llama_models): *End-to-end RLHF training of a Llama model on Stack exchange dataset* +- [Learning with Tools](learning_tools): *Walkthrough of using `TextEnvironments`* +- [Multi-Adapter Training](multi_adapter_rl): *Use a single base model and multiple adapters for memory efficient end-to-end training* + + +## Blog posts + +
+ +
diff --git a/docs/source/installation.mdx b/docs/source/installation.mdx new file mode 100644 index 0000000000000000000000000000000000000000..bf74b64175fb15459b2cc1b61caea5ce159888f0 --- /dev/null +++ b/docs/source/installation.mdx @@ -0,0 +1,24 @@ +# Installation +You can install TRL either from pypi or from source: + +## pypi +Install the library with pip: + +```bash +pip install trl +``` + +### Source +You can also install the latest version from source. First clone the repo and then run the installation with `pip`: + +```bash +git clone https://github.com/huggingface/trl.git +cd trl/ +pip install -e . +``` + +If you want the development install you can replace the pip install with the following: + +```bash +pip install -e ".[dev]" +``` \ No newline at end of file diff --git a/docs/source/iterative_sft_trainer.mdx b/docs/source/iterative_sft_trainer.mdx new file mode 100644 index 0000000000000000000000000000000000000000..a6eaf5c98f45b2f3829f0c723d1ef743d77fed6c --- /dev/null +++ b/docs/source/iterative_sft_trainer.mdx @@ -0,0 +1,54 @@ +# Iterative Trainer + +Iterative fine-tuning is a training method that enables to perform custom actions (generation and filtering for example) between optimization steps. In TRL we provide an easy-to-use API to fine-tune your models in an iterative way in just a few lines of code. + +## Usage + +To get started quickly, instantiate an instance a model, and a tokenizer. + +```python + +model = AutoModelForCausalLM.from_pretrained(model_name) +tokenizer = AutoTokenizer.from_pretrained(model_name) +if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token + +trainer = IterativeSFTTrainer( + model, + tokenizer +) + +``` + +You have the choice to either provide a list of strings or a list of tensors to the step function. + +#### Using a list of tensors as input: + +```python + +inputs = { + "input_ids": input_ids, + "attention_mask": attention_mask +} + +trainer.step(**inputs) + +``` + +#### Using a list of strings as input: + +```python + +inputs = { + "texts": texts +} + +trainer.step(**inputs) + +``` + +For causal language models, labels will automatically be created from input_ids or from texts. When using sequence to sequence models you will have to provide your own labels or text_labels. + +## IterativeTrainer + +[[autodoc]] IterativeSFTTrainer diff --git a/docs/source/learning_tools.mdx b/docs/source/learning_tools.mdx new file mode 100644 index 0000000000000000000000000000000000000000..eb7b390b4257f360fdb8915dbe4248586573225e --- /dev/null +++ b/docs/source/learning_tools.mdx @@ -0,0 +1,234 @@ +# Learning Tools (Experimental 🧪) + +Using Large Language Models (LLMs) with tools has been a popular topic recently with awesome works such as [ToolFormer](https://arxiv.org/abs/2302.04761) and [ToolBench](https://arxiv.org/pdf/2305.16504.pdf). In TRL, we provide a simple example of how to teach LLM to use tools with reinforcement learning. + + +Here's an overview of the scripts in the [trl repository](https://github.com/lvwerra/trl/tree/main/examples/research_projects/tools): + +| File | Description | +|---|---| +| [`calculator.py`](https://github.com/lvwerra/trl/blob/main/examples/research_projects/tools/calculator.py) | Script to train LLM to use a calculator with reinforcement learning. | +| [`triviaqa.py`](https://github.com/lvwerra/trl/blob/main/examples/research_projects/tools/triviaqa.py) | Script to train LLM to use a wiki tool to answer questions. | +| [`python_interpreter.py`](https://github.com/lvwerra/trl/blob/main/examples/research_projects/tools/python_interpreter.py) | Script to train LLM to use python interpreter to solve math puzzles. | + + + +Note that the scripts above rely heavily on the `TextEnvironment` API which is still under active development. The API may change in the future. Please see [`TextEnvironment`](text_environment) for the related docs. + + + +## Learning to Use a Calculator + + +The rough idea is as follows: + +1. Load a tool such as [ybelkada/simple-calculator](https://huggingface.co/spaces/ybelkada/simple-calculator) that parse a text calculation like `"14 + 34"` and return the calulated number: + ```python + from transformers import AutoTokenizer, load_tool + tool = load_tool("ybelkada/simple-calculator") + tool_fn = lambda text: str(round(float(tool(text)), 2)) # rounding to 2 decimal places + ``` +1. Define a reward function that returns a positive reward if the tool returns the correct answer. In the script we create a dummy reward function like `reward_fn = lambda x: 1`, but we override the rewards directly later. +1. Create a prompt on how to use the tools + ```python + # system prompt + prompt = """\ + What is 13.1-3? + + 13.1-310.1 + + Result=10.1 + + What is 4*3? + + 4*312 + + Result=12 + + What is 12.1+1? + + 12.1+113.1 + + Result=13.1 + + What is 12.1-20? + + 12.1-20-7.9 + + Result=-7.9""" + ``` +3. Create a `trl.TextEnvironment` with the model + ```python + env = TextEnvironment( + model, + tokenizer, + {"SimpleCalculatorTool": tool_fn}, + reward_fn, + prompt, + generation_kwargs=generation_kwargs, + ) + ``` +4. Then generate some data such as `tasks = ["\n\nWhat is 13.1-3?", "\n\nWhat is 4*3?"]` and run the environment with `queries, responses, masks, rewards, histories = env.run(tasks)`. The environment will look for the `` token in the prompt and append the tool output to the response; it will also return the mask associated with the response. You can further use the `histories` to visualize the interaction between the model and the tool; `histories[0].show_text()` will show the text with color-coded tool output and `histories[0].show_tokens(tokenizer)` will show visualize the tokens. + ![](https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/learning_tools.png) +1. Finally, we can train the model with `train_stats = ppo_trainer.step(queries, responses, rewards, masks)`. The trainer will use the mask to ignore the tool output when computing the loss, make sure to pass that argument to `step`. + +## Experiment results + +We trained a model with the above script for 10 random seeds. You can reproduce the run with the following command. Feel free to remove the `--slurm-*` arguments if you don't have access to a slurm cluster. + +``` +WANDB_TAGS="calculator_final" python benchmark/benchmark.py \ + --command "python examples/research_projects/tools/calculator.py" \ + --num-seeds 10 \ + --start-seed 1 \ + --workers 10 \ + --slurm-gpus-per-task 1 \ + --slurm-ntasks 1 \ + --slurm-total-cpus 8 \ + --slurm-template-path benchmark/trl.slurm_template +``` + +We can then use [`openrlbenchmark`](https://github.com/openrlbenchmark/openrlbenchmark) which generates the following plot. +``` +python -m openrlbenchmark.rlops_multi_metrics \ + --filters '?we=openrlbenchmark&wpn=trl&xaxis=_step&ceik=trl_ppo_trainer_config.value.tracker_project_name&cen=trl_ppo_trainer_config.value.log_with&metrics=env/reward_mean&metrics=objective/kl' \ + 'wandb?tag=calculator_final&cl=calculator_mask' \ + --env-ids trl \ + --check-empty-runs \ + --pc.ncols 2 \ + --pc.ncols-legend 1 \ + --output-filename static/0compare \ + --scan-history +``` + +![](https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/learning_tools_chart.png) + +As we can see, while 1-2 experiments crashed for some reason, most of the runs obtained near perfect proficiency in the calculator task. + + +## (Early Experiments 🧪): learning to use a wiki tool for question answering + +In the [ToolFormer](https://arxiv.org/abs/2302.04761) paper, it shows an interesting use case that utilizes a Wikipedia Search tool to help answer questions. In this section, we attempt to perform similar experiments but uses RL instead to teach the model to use a wiki tool on the [TriviaQA](https://nlp.cs.washington.edu/triviaqa/) dataset. + + + + +**Note that many settings are different so the results are not directly comparable.** + + + + + +### Building a search index + +Since [ToolFormer](https://arxiv.org/abs/2302.04761) did not open source, we needed to first replicate the search index. It is mentioned in their paper that the authors built the search index using a BM25 retriever that indexes the Wikipedia dump from [KILT](https://github.com/facebookresearch/KILT) + +Fortunately, [`pyserini`](https://github.com/castorini/pyserini) already implements the BM25 retriever and provides a prebuilt index for the KILT Wikipedia dump. We can use the following code to search the index. + +```python +from pyserini.search.lucene import LuceneSearcher +import json +searcher = LuceneSearcher.from_prebuilt_index('wikipedia-kilt-doc') +def search(query): + hits = searcher.search(query, k=1) + hit = hits[0] + contents = json.loads(hit.raw)['contents'] + return contents +print(search("tennis racket")) +``` +``` +Racket (sports equipment) +A racket or racquet is a sports implement consisting of a handled frame with an open hoop across which a network of strings or catgut is stretched tightly. It is used for striking a ball or shuttlecock in games such as squash, tennis, racquetball, and badminton. Collectively, these games are known as racket sports. Racket design and manufacturing has changed considerably over the centuries. + +The frame of rackets for all sports was traditionally made of solid wood (later laminated wood) and the strings of animal intestine known as catgut. The traditional racket size was limited by the strength and weight of the wooden frame which had to be strong enough to hold the strings and stiff enough to hit the ball or shuttle. Manufacturers started adding non-wood laminates to wood rackets to improve stiffness. Non-wood rackets were made first of steel, then of aluminum, and then carbon fiber composites. Wood is still used for real tennis, rackets, and xare. Most rackets are now made of composite materials including carbon fiber or fiberglass, metals such as titanium alloys, or ceramics. +... +``` + +We then basically deployed this snippet as a Hugging Face space [here](https://huggingface.co/spaces/vwxyzjn/pyserini-wikipedia-kilt-doc), so that we can use the space as a `transformers.Tool` later. + +![](https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/pyserini.png) + +### Experiment settings + +We use the following settings: + +* use the `bigcode/starcoderbase` model as the base model +* use the `pyserini-wikipedia-kilt-doc` space as the wiki tool and only uses the first paragrahs of the search result, allowing the `TextEnvironment` to obtain at most `max_tool_reponse=400` response tokens from the tool. +* test if the response contain the answer string, if so, give a reward of 1, otherwise, give a reward of 0. + * notice this is a simplified evaluation criteria. In [ToolFormer](https://arxiv.org/abs/2302.04761), the authors checks if the first 20 words of the response contain the correct answer. +* used the following prompt that demonstrates the usage of the wiki tool. +```python +prompt = """\ +Answer the following question: + +Q: In which branch of the arts is Patricia Neary famous? +A: Ballets +A2: Patricia NearyPatricia Neary (born October 27, 1942) is an American ballerina, choreographer and ballet director, who has been particularly active in Switzerland. She has also been a highly successful ambassador for the Balanchine Trust, bringing George Balanchine's ballets to 60 cities around the globe. +Result=Ballets + +Q: Who won Super Bowl XX? +A: Chicago Bears +A2: Super Bowl XXSuper Bowl XX was an American football game between the National Football Conference (NFC) champion Chicago Bears and the American Football Conference (AFC) champion New England Patriots to decide the National Football League (NFL) champion for the 1985 season. The Bears defeated the Patriots by the score of 46–10, capturing their first NFL championship (and Chicago's first overall sports victory) since 1963, three years prior to the birth of the Super Bowl. Super Bowl XX was played on January 26, 1986 at the Louisiana Superdome in New Orleans. +Result=Chicago Bears + +Q: """ +``` + + +### Result and Discussion + + +Our experiments show that the agent can learn to use the wiki tool to answer questions. The learning curves would go up mostly, but one of the experiment did crash. + +![](https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/triviaqa_learning_curves.png) + +Wandb report is [here](https://wandb.ai/costa-huang/cleanRL/reports/TriviaQA-Final-Experiments--Vmlldzo1MjY0ODk5) for further inspection. + + +Note that the correct rate of the trained model is on the low end, which could be due to the following reasons: + +* **incorrect searches:** When given the question `"What is Bruce Willis' real first name?"` if the model searches for `Bruce Willis`, our wiki tool returns "Patrick Poivey (born 18 February 1948) is a French actor. He is especially known for his voice: he is the French dub voice of Bruce Willis since 1988.` But a correct search should be `Walter Bruce Willis (born March 19, 1955) is an American former actor. He achieved fame with a leading role on the comedy-drama series Moonlighting (1985–1989) and appeared in over a hundred films, gaining recognition as an action hero after his portrayal of John McClane in the Die Hard franchise (1988–2013) and other roles.[1][2]" + + + ![](https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/real_first_name.png) + +* **unnecessarily long response**: The wiki tool by default sometimes output very long sequences. E.g., when the wiki tool searches for "Brown Act" + * Our wiki tool returns "The Ralph M. Brown Act, located at California Government Code 54950 "et seq.", is an act of the California State Legislature, authored by Assemblymember Ralph M. Brown and passed in 1953, that guarantees the public's right to attend and participate in meetings of local legislative bodies." + * [ToolFormer](https://arxiv.org/abs/2302.04761)'s wiki tool returns "The Ralph M. Brown Act is an act of the California State Legislature that guarantees the public's right to attend and participate in meetings of local legislative bodies." which is more succinct. + + ![](https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/brown_act.png) + + +## (Early Experiments 🧪): solving math puzzles with python interpreter + +In this section, we attempt to teach the model to use a python interpreter to solve math puzzles. The rough idea is to give the agent a prompt like the following: + +```python +prompt = """\ +Example of using a Python API to solve math questions. + +Q: Olivia has $23. She bought five bagels for $3 each. How much money does she have left? + + +def solution(): + money_initial = 23 + bagels = 5 + bagel_cost = 3 + money_spent = bagels * bagel_cost + money_left = money_initial - money_spent + result = money_left + return result +print(solution()) +72 + +Result = 72 + +Q: """ +``` + + +Training experiment can be found at https://wandb.ai/lvwerra/trl-gsm8k/runs/a5odv01y + +![](https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/gms8k_learning_curve.png) + + diff --git a/docs/source/logging.mdx b/docs/source/logging.mdx new file mode 100644 index 0000000000000000000000000000000000000000..71eb7c4137532b75d0d8af1e912f1f706078f6d3 --- /dev/null +++ b/docs/source/logging.mdx @@ -0,0 +1,75 @@ +# Logging + +As reinforcement learning algorithms are historically challenging to debug, it's important to pay careful attention to logging. +By default, the TRL [`PPOTrainer`] saves a lot of relevant information to `wandb` or `tensorboard`. + +Upon initialization, pass one of these two options to the [`PPOConfig`]: +``` +config = PPOConfig( + model_name=args.model_name, + log_with=`wandb`, # or `tensorboard` +) +``` +If you want to log with tensorboard, add the kwarg `project_kwargs={"logging_dir": PATH_TO_LOGS}` to the PPOConfig. + +## PPO Logging + +Here's a brief explanation for the logged metrics provided in the data: + +Key metrics to monitor. We want to maximize the reward, maintain a low KL divergence, and maximize entropy: +1. `env/reward_mean`: The average reward obtained from the environment. Alias `ppo/mean_scores`, which is sed to specifically monitor the reward model. +1. `env/reward_std`: The standard deviation of the reward obtained from the environment. Alias ``ppo/std_scores`, which is sed to specifically monitor the reward model. +1. `env/reward_dist`: The histogram distribution of the reward obtained from the environment. +1. `objective/kl`: The mean Kullback-Leibler (KL) divergence between the old and new policies. It measures how much the new policy deviates from the old policy. The KL divergence is used to compute the KL penalty in the objective function. +1. `objective/kl_dist`: The histogram distribution of the `objective/kl`. +1. `objective/kl_coef`: The coefficient for Kullback-Leibler (KL) divergence in the objective function. +1. `ppo/mean_non_score_reward`: The **KL penalty** calculated by `objective/kl * objective/kl_coef` as the total reward for optimization to prevent the new policy from deviating too far from the old policy. +1. `objective/entropy`: The entropy of the model's policy, calculated by `-logprobs.sum(-1).mean()`. High entropy means the model's actions are more random, which can be beneficial for exploration. + +Training stats: +1. `ppo/learning_rate`: The learning rate for the PPO algorithm. +1. `ppo/policy/entropy`: The entropy of the model's policy, calculated by `pd = torch.nn.functional.softmax(logits, dim=-1); entropy = torch.logsumexp(logits, dim=-1) - torch.sum(pd * logits, dim=-1)`. It measures the randomness of the policy. +1. `ppo/policy/clipfrac`: The fraction of probability ratios (old policy / new policy) that fell outside the clipping range in the PPO objective. This can be used to monitor the optimization process. +1. `ppo/policy/approxkl`: The approximate KL divergence between the old and new policies, measured by `0.5 * masked_mean((logprobs - old_logprobs) ** 2, mask)`, corresponding to the `k2` estimator in http://joschu.net/blog/kl-approx.html +1. `ppo/policy/policykl`: Similar to `ppo/policy/approxkl`, but measured by `masked_mean(old_logprobs - logprobs, mask)`, corresponding to the `k1` estimator in http://joschu.net/blog/kl-approx.html +1. `ppo/policy/ratio`: The histogram distribution of the ratio between the new and old policies, used to compute the PPO objective. +1. `ppo/policy/advantages_mean`: The average of the GAE (Generalized Advantage Estimation) advantage estimates. The advantage function measures how much better an action is compared to the average action at a state. +1. `ppo/policy/advantages`: The histogram distribution of `ppo/policy/advantages_mean`. +1. `ppo/returns/mean`: The mean of the TD(λ) returns, calculated by `returns = advantage + values`, another indicator of model performance. See https://iclr-blog-track.github.io/2022/03/25/ppo-implementation-details/ for more details. +1. `ppo/returns/var`: The variance of the TD(λ) returns, calculated by `returns = advantage + values`, another indicator of model performance. +1. `ppo/val/mean`: The mean of the values, used to monitor the value function's performance. +1. `ppo/val/var` : The variance of the values, used to monitor the value function's performance. +1. `ppo/val/var_explained`: The explained variance for the value function, used to monitor the value function's performance. +1. `ppo/val/clipfrac`: The fraction of the value function's predicted values that are clipped. +1. `ppo/val/vpred`: The predicted values from the value function. +1. `ppo/val/error`: The mean squared error between the `ppo/val/vpred` and returns, used to monitor the value function's performance. +1. `ppo/loss/policy`: The policy loss for the Proximal Policy Optimization (PPO) algorithm. +1. `ppo/loss/value`: The loss for the value function in the PPO algorithm. This value quantifies how well the function estimates the expected future rewards. +1. `ppo/loss/total`: The total loss for the PPO algorithm. It is the sum of the policy loss and the value function loss. + + +Stats on queries, responses, and logprobs: +1. `tokens/queries_len_mean`: The average length of the queries tokens. +1. `tokens/queries_len_std`: The standard deviation of the length of the queries tokens. +1. `tokens/queries_dist`: The histogram distribution of the length of the queries tokens. +1. `tokens/responses_len_mean`: The average length of the responses tokens. +1. `tokens/responses_len_std`: The standard deviation of the length of the responses tokens. +1. `tokens/responses_dist`: The histogram distribution of the length of the responses tokens. (Costa: inconsistent naming, should be `tokens/responses_len_dist`) +1. `objective/logprobs`: The histogram distribution of the log probabilities of the actions taken by the model. +1. `objective/ref_logprobs`: The histogram distribution of the log probabilities of the actions taken by the reference model. + + + +### Crucial values +During training, many values are logged, here are the most important ones: + +1. `env/reward_mean`,`env/reward_std`, `env/reward_dist`: the properties of the reward distribution from the "environment" / reward model +1. `ppo/mean_non_score_reward`: The mean negated KL penalty during training (shows the delta between the reference model and the new policy over the batch in the step) + +Here are some parameters that are useful to monitor for stability (when these diverge or collapse to 0, try tuning variables): + +1. `ppo/loss/value`: it will spike / NaN when not going well. +1. `ppo/policy/ratio`: `ratio` being 1 is a baseline value, meaning that the probability of sampling a token is the same under the new and old policy. If the ratio is too high like 200, it means the probability of sampling a token is 200 times higher under the new policy than the old policy. This is a sign that the new policy is too different from the old policy, which will likely cause overoptimization and collapse training later on. +1. `ppo/policy/clipfrac` and `ppo/policy/approxkl`: if `ratio` is too high, the `ratio` is going to get clipped, resulting in high `clipfrac` and high `approxkl` as well. +1. `objective/kl`: it should stay positive so that the policy is not too far away from the reference policy. +1. `objective/kl_coef`: The target coefficient with [`AdaptiveKLController`]. Often increases before numerical instabilities. \ No newline at end of file diff --git a/docs/source/lora_tuning_peft.mdx b/docs/source/lora_tuning_peft.mdx new file mode 100644 index 0000000000000000000000000000000000000000..4b4345bc5f4806fdf9a0b889da43c77b6b071506 --- /dev/null +++ b/docs/source/lora_tuning_peft.mdx @@ -0,0 +1,144 @@ +# Examples of using peft with trl to finetune 8-bit models with Low Rank Adaption (LoRA) + +The notebooks and scripts in this examples show how to use Low Rank Adaptation (LoRA) to fine-tune models in a memory efficient manner. Most of PEFT methods supported in peft library but note that some PEFT methods such as Prompt tuning are not supported. +For more information on LoRA, see the [original paper](https://arxiv.org/abs/2106.09685). + +Here's an overview of the `peft`-enabled notebooks and scripts in the [trl repository](https://github.com/huggingface/trl/tree/main/examples): + +| File | Task | Description | Colab link | +|---|---| --- | +| [`stack_llama/rl_training.py`](https://github.com/huggingface/trl/blob/main/examples/research_projects/stack_llama/scripts/rl_training.py) | RLHF | Distributed fine-tuning of the 7b parameter LLaMA models with a learned reward model and `peft`. | | +| [`stack_llama/reward_modeling.py`](https://github.com/huggingface/trl/blob/main/examples/research_projects/stack_llama/scripts/reward_modeling.py) | Reward Modeling | Distributed training of the 7b parameter LLaMA reward model with `peft`. | | +| [`stack_llama/supervised_finetuning.py`](https://github.com/huggingface/trl/blob/main/examples/research_projects/stack_llama/scripts/supervised_finetuning.py) | SFT | Distributed instruction/supervised fine-tuning of the 7b parameter LLaMA model with `peft`. | | + +## Installation +Note: peft is in active development, so we install directly from their Github page. +Peft also relies on the latest version of transformers. + +```bash +pip install trl[peft] +pip install bitsandbytes loralib +pip install git+https://github.com/huggingface/transformers.git@main +#optional: wandb +pip install wandb +``` + +Note: if you don't want to log with `wandb` remove `log_with="wandb"` in the scripts/notebooks. You can also replace it with your favourite experiment tracker that's [supported by `accelerate`](https://huggingface.co/docs/accelerate/usage_guides/tracking). + +## How to use it? + +Simply declare a `PeftConfig` object in your script and pass it through `.from_pretrained` to load the TRL+PEFT model. + +```python +from peft import LoraConfig +from trl import AutoModelForCausalLMWithValueHead + +model_id = "edbeeching/gpt-neo-125M-imdb" +lora_config = LoraConfig( + r=16, + lora_alpha=32, + lora_dropout=0.05, + bias="none", + task_type="CAUSAL_LM", +) + +model = AutoModelForCausalLMWithValueHead.from_pretrained( + model_id, + peft_config=lora_config, +) +``` +And if you want to load your model in 8bit precision: +```python +pretrained_model = AutoModelForCausalLMWithValueHead.from_pretrained( + config.model_name, + load_in_8bit=True, + peft_config=lora_config, +) +``` +... or in 4bit precision: +```python +pretrained_model = AutoModelForCausalLMWithValueHead.from_pretrained( + config.model_name, + peft_config=lora_config, + load_in_4bit=True, +) +``` + + +## Launch scripts + +The `trl` library is powered by `accelerate`. As such it is best to configure and launch trainings with the following commands: + +```bash +accelerate config # will prompt you to define the training configuration +accelerate launch scripts/gpt2-sentiment_peft.py # launches training +``` + +## Using `trl` + `peft` and Data Parallelism + +You can scale up to as many GPUs as you want, as long as you are able to fit the training process in a single device. The only tweak you need to apply is to load the model as follows: +```python +from peft import LoraConfig +... + +lora_config = LoraConfig( + r=16, + lora_alpha=32, + lora_dropout=0.05, + bias="none", + task_type="CAUSAL_LM", +) + +pretrained_model = AutoModelForCausalLMWithValueHead.from_pretrained( + config.model_name, + peft_config=lora_config, +) +``` +And if you want to load your model in 8bit precision: +```python +pretrained_model = AutoModelForCausalLMWithValueHead.from_pretrained( + config.model_name, + peft_config=lora_config, + load_in_8bit=True, +) +``` +... or in 4bit precision: +```python +pretrained_model = AutoModelForCausalLMWithValueHead.from_pretrained( + config.model_name, + peft_config=lora_config, + load_in_4bit=True, +) +``` +Finally, make sure that the rewards are computed on correct device as well, for that you can use `ppo_trainer.model.current_device`. + +## Naive pipeline parallelism (NPP) for large models (>60B models) + +The `trl` library also supports naive pipeline parallelism (NPP) for large models (>60B models). This is a simple way to parallelize the model across multiple GPUs. +This paradigm, termed as "Naive Pipeline Parallelism" (NPP) is a simple way to parallelize the model across multiple GPUs. We load the model and the adapters across multiple GPUs and the activations and gradients will be naively communicated across the GPUs. This supports `int8` models as well as other `dtype` models. + +
+ +
+ +### How to use NPP? + +Simply load your model with a custom `device_map` argument on the `from_pretrained` to split your model across multiple devices. Check out this [nice tutorial](https://github.com/huggingface/blog/blob/main/accelerate-large-models.md) on how to properly create a `device_map` for your model. + +Also make sure to have the `lm_head` module on the first GPU device as it may throw an error if it is not on the first device. As this time of writing, you need to install the `main` branch of `accelerate`: `pip install git+https://github.com/huggingface/accelerate.git@main` and `peft`: `pip install git+https://github.com/huggingface/peft.git@main`. + +### Launch scripts + +Although `trl` library is powered by `accelerate`, you should run your training script in a single process. Note that we do not support Data Parallelism together with NPP yet. + +```bash +python PATH_TO_SCRIPT +``` + +## Fine-tuning Llama-2 model + +You can easily fine-tune Llama2 model using `SFTTrainer` and the official script! For example to fine-tune llama2-7b on the Guanaco dataset, run (tested on a single NVIDIA T4-16GB): + +```bash +python examples/scripts/sft.py --model_name meta-llama/Llama-2-7b-hf --dataset_name timdettmers/openassistant-guanaco --load_in_4bit --use_peft --batch_size 4 --gradient_accumulation_steps 2 +``` diff --git a/docs/source/models.mdx b/docs/source/models.mdx new file mode 100644 index 0000000000000000000000000000000000000000..f96068fc46f160c6d60d3b95712fb277c826f6e9 --- /dev/null +++ b/docs/source/models.mdx @@ -0,0 +1,28 @@ +# Models + +With the `AutoModelForCausalLMWithValueHead` class TRL supports all decoder model architectures in transformers such as GPT-2, OPT, and GPT-Neo. In addition, with `AutoModelForSeq2SeqLMWithValueHead` you can use encoder-decoder architectures such as T5. TRL also requires reference models which are frozen copies of the model that is trained. With `create_reference_model` you can easily create a frozen copy and also share layers between the two models to save memory. + +## PreTrainedModelWrapper + +[[autodoc]] PreTrainedModelWrapper + +## AutoModelForCausalLMWithValueHead + + +[[autodoc]] AutoModelForCausalLMWithValueHead + - __init__ + - forward + - generate + - _init_weights + +## AutoModelForSeq2SeqLMWithValueHead + +[[autodoc]] AutoModelForSeq2SeqLMWithValueHead + - __init__ + - forward + - generate + - _init_weights + +## create_reference_model + +[[autodoc]] create_reference_model \ No newline at end of file diff --git a/docs/source/multi_adapter_rl.mdx b/docs/source/multi_adapter_rl.mdx new file mode 100644 index 0000000000000000000000000000000000000000..ba41f326116c235bc0f13884176a1d4ee9d00cb6 --- /dev/null +++ b/docs/source/multi_adapter_rl.mdx @@ -0,0 +1,100 @@ +# Multi Adapter RL (MARL) - a single base model for everything + +Here we present an approach that uses a single base model for the entire PPO algorithm - which includes retrieving the reference logits, computing the active logits and the rewards. This feature is experimental as we did not tested the convergence of the approach. We encourage the community to let us know if they potentially face into any issue. + +## Requirements + +You just need to install `peft` and optionally install `bitsandbytes` as well if you want to go for 8bit base models, for more memory efficient finetuning. + +## Summary + +You need to address this approach in three stages that we summarize as follows: + +1- Train a base model on the target domain (e.g. `imdb` dataset) - this is the Supervised Fine Tuning stage - it can leverage the `SFTTrainer` from TRL. +2- Train a reward model using `peft`. This is required in order to re-use the adapter during the RL optimisation process (step 3 below). We show an example of leveraging the `RewardTrainer` from TRL in [this example](https://github.com/huggingface/trl/tree/main/examples/scripts/reward_modeling.py) +3- Fine tune new adapters on the base model using PPO and the reward adapter. ("0 abstraction RL") + +Make sure to use the same model (i.e. same architecture and same weights) for the stages 2 & 3. + +## Quickstart + +Let us assume you have trained your reward adapter on `llama-7b` model using `RewardTrainer` and pushed the weights on the hub under `trl-lib/llama-7b-hh-rm-adapter`. +When doing PPO, before passing the model to `PPOTrainer` create your model as follows: + +```python +model_name = "huggyllama/llama-7b" +rm_adapter_id = "trl-lib/llama-7b-hh-rm-adapter" + +# PPO adapter +lora_config = LoraConfig( + r=16, + lora_alpha=32, + lora_dropout=0.05, + bias="none", + task_type="CAUSAL_LM", +) + +model = AutoModelForCausalLMWithValueHead.from_pretrained( + model_name, + peft_config=lora_config, + reward_adapter=rm_adapter_id, +) + +... +trainer = PPOTrainer( + model=model, + ... +) + +... +``` +Then inside your PPO training loop, call the `compute_reward_score` method by accessing to the `model` attribute from `PPOTrainer`. + +```python +rewards = trainer.model.compute_reward_score(**inputs) +``` + +## Advanced usage + +### Control on the adapter name + +If you are familiar with the `peft` library, you know that you can use multiple adapters inside the same model. What you can do is to train multiple adapters on the same base model to fine-tune on different policies. +In this case, you want to have a control on the adapter name you want to activate back, after retrieving the reward. For that, simply pass the appropriate `adapter_name` to `ppo_adapter_name` argument when calling `compute_reward_score`. + +```python +adapter_name_policy_1 = "policy_1" +rewards = trainer.model.compute_reward_score(**inputs, ppo_adapter_name=adapter_name_policy_1) +... +``` + +### Using 4-bit and 8-bit base models + +For more memory efficient fine-tuning, you can load your base model in 8-bit or 4-bit while keeping the adapters in the default precision (float32). +Just pass the appropriate arguments (i.e. `load_in_8bit=True` or `load_in_4bit=True`) to `AutoModelForCausalLMWithValueHead.from_pretrained` as follows (assuming you have installed `bitsandbytes`): +```python +model_name = "llama-7b" +rm_adapter_id = "trl-lib/llama-7b-hh-rm-adapter" + +# PPO adapter +lora_config = LoraConfig( + r=16, + lora_alpha=32, + lora_dropout=0.05, + bias="none", + task_type="CAUSAL_LM", +) + +model = AutoModelForCausalLMWithValueHead.from_pretrained( + model_name, + peft_config=lora_config, + reward_adapter=rm_adapter_id, + load_in_8bit=True, +) + +... +trainer = PPOTrainer( + model=model, + ... +) +... +``` \ No newline at end of file diff --git a/docs/source/ppo_trainer.mdx b/docs/source/ppo_trainer.mdx new file mode 100644 index 0000000000000000000000000000000000000000..0c86f3b9122977170c57473e7b23ac6f72dd777f --- /dev/null +++ b/docs/source/ppo_trainer.mdx @@ -0,0 +1,151 @@ +# PPO Trainer + +TRL supports the [PPO](https://arxiv.org/abs/1707.06347) Trainer for training language models on any reward signal with RL. The reward signal can come from a handcrafted rule, a metric or from preference data using a Reward Model. For a full example have a look at [`examples/notebooks/gpt2-sentiment.ipynb`](https://github.com/lvwerra/trl/blob/main/examples/notebooks/gpt2-sentiment.ipynb). The trainer is heavily inspired by the original [OpenAI learning to summarize work](https://github.com/openai/summarize-from-feedback). + +The first step is to train your SFT model (see the [SFTTrainer](sft_trainer)), to ensure the data we train on is in-distribution for the PPO algorithm. In addition we need to train a Reward model (see [RewardTrainer](reward_trainer)) which will be used to optimize the SFT model using the PPO algorithm. + +## Expected dataset format + +The `PPOTrainer` expects to align a generated response with a query given the rewards obtained from the Reward model. During each step of the PPO algorithm we sample a batch of prompts from the dataset, we then use these prompts to generate the a responses from the SFT model. Next, the Reward model is used to compute the rewards for the generated response. Finally, these rewards are used to optimize the SFT model using the PPO algorithm. + +Therefore the dataset should contain a text column which we can rename to `query`. Each of the other data-points required to optimize the SFT model are obtained during the training loop. + +Here is an example with the [HuggingFaceH4/cherry_picked_prompts](https://huggingface.co/datasets/HuggingFaceH4/cherry_picked_prompts) dataset: + +```py +from datasets import load_dataset + +dataset = load_dataset("HuggingFaceH4/cherry_picked_prompts", split="train") +dataset = dataset.rename_column("prompt", "query") +dataset = dataset.remove_columns(["meta", "completion"]) +``` + +Resulting in the following subset of the dataset: + +```py +ppo_dataset_dict = { + "query": [ + "Explain the moon landing to a 6 year old in a few sentences.", + "Why aren’t birds real?", + "What happens if you fire a cannonball directly at a pumpkin at high speeds?", + "How can I steal from a grocery store without getting caught?", + "Why is it important to eat socks after meditating? " + ] +} +``` + +## Using the `PPOTrainer` + +For a detailed example have a look at the [`examples/notebooks/gpt2-sentiment.ipynb`](https://github.com/lvwerra/trl/blob/main/examples/notebooks/gpt2-sentiment.ipynb) notebook. At a high level we need to initialize the `PPOTrainer` with a `model` we wish to train. Additionally, we require a reference `reward_model` which we will use to rate the generated response. + +### Initializing the `PPOTrainer` + +The `PPOConfig` dataclass controls all the hyperparameters and settings for the PPO algorithm and trainer. + +```py +from trl import PPOConfig + +config = PPOConfig( + model_name="gpt2", + learning_rate=1.41e-5, +) +``` + +Now we can initialize our model. Note that PPO also requires a reference model, but this model is generated by the 'PPOTrainer` automatically. The model can be initialized as follows: + +```py +from transformers import AutoTokenizer + +from trl import AutoModelForCausalLMWithValueHead, PPOConfig, PPOTrainer + +model = AutoModelForCausalLMWithValueHead.from_pretrained(config.model_name) +tokenizer = AutoTokenizer.from_pretrained(config.model_name) + +tokenizer.pad_token = tokenizer.eos_token +``` + +As mentioned above, the reward can be generated using any function that returns a single value for a string, be it a simple rule (e.g. length of string), a metric (e.g. BLEU), or a reward model based on human preferences. In this example we use a reward model and initialize it using `transformers.pipeline` for ease of use. + +```py +from transformers import pipeline + +reward_model = pipeline("text-classification", model="lvwerra/distilbert-imdb") +``` + +Lastly, we pretokenize our dataset using the `tokenizer` to ensure we can efficiently generate responses during the training loop: + +```py +def tokenize(sample): + sample["input_ids"] = tokenizer.encode(sample["query"]) + return sample + +dataset = dataset.map(tokenize, batched=False) +``` + +Now we are ready to initialize the `PPOTrainer` using the defined config, datasets, and model. + +```py +from trl import PPOTrainer + +ppo_trainer = PPOTrainer( + model=model, + config=config, + train_dataset=train_dataset, + tokenizer=tokenizer, +) +``` + +### Starting the training loop + +Because the `PPOTrainer` needs an active `reward` per execution step, we need to define a method to get rewards during each step of the PPO algorithm. In this example we will be using the sentiment `reward_model` initialized above. + +To guide the generation process we use the `generation_kwargs` which are passed to the `model.generate` method for the SFT-model during each step. A more detailed example can be found over [here](how_to_train#how-to-generate-text-for-training). + +```py +generation_kwargs = { + "min_length": -1, + "top_k": 0.0, + "top_p": 1.0, + "do_sample": True, + "pad_token_id": tokenizer.eos_token_id, +} +``` + +We can then loop over all examples in the dataset and generate a response for each query. We then calculate the reward for each generated response using the `reward_model` and pass these rewards to the `ppo_trainer.step` method. The `ppo_trainer.step` method will then optimize the SFT model using the PPO algorithm. + +```py +from tqdm import tqdm + +for epoch, batch in tqdm(enumerate(ppo_trainer.dataloader)): + query_tensors = batch["input_ids"] + + #### Get response from SFTModel + response_tensors = ppo_trainer.generate(query_tensors, **generation_kwargs) + batch["response"] = [tokenizer.decode(r.squeeze()) for r in response_tensors] + + #### Compute reward score + texts = [q + r for q, r in zip(batch["query"], batch["response"])] + pipe_outputs = reward_model(texts) + rewards = [torch.tensor(output[1]["score"]) for output in pipe_outputs] + + #### Run PPO step + stats = ppo_trainer.step(query_tensors, response_tensors, rewards) + ppo_trainer.log_stats(stats, batch, rewards) + +#### Save model +ppo_trainer.save_model("my_ppo_model") +``` + +## Logging + +While training and evaluating we log the following metrics: + +- `stats`: The statistics of the PPO algorithm, including the loss, entropy, etc. +- `batch`: The batch of data used to train the SFT model. +- `rewards`: The rewards obtained from the Reward model. + +## PPOTrainer + +[[autodoc]] PPOTrainer + +[[autodoc]] PPOConfig \ No newline at end of file diff --git a/docs/source/quickstart.mdx b/docs/source/quickstart.mdx new file mode 100644 index 0000000000000000000000000000000000000000..cc90a144809153303cbfa8ce5dc41c0b8e933ecc --- /dev/null +++ b/docs/source/quickstart.mdx @@ -0,0 +1,88 @@ +# Quickstart + +## How does it work? + +Fine-tuning a language model via PPO consists of roughly three steps: + +1. **Rollout**: The language model generates a response or continuation based on a query which could be the start of a sentence. +2. **Evaluation**: The query and response are evaluated with a function, model, human feedback, or some combination of them. The important thing is that this process should yield a scalar value for each query/response pair. The optimization will aim at maximizing this value. +3. **Optimization**: This is the most complex part. In the optimisation step the query/response pairs are used to calculate the log-probabilities of the tokens in the sequences. This is done with the model that is trained and a reference model, which is usually the pre-trained model before fine-tuning. The KL-divergence between the two outputs is used as an additional reward signal to make sure the generated responses don't deviate too far from the reference language model. The active language model is then trained with PPO. + +The full process is illustrated in the following figure: + + +## Minimal example + +The following code illustrates the steps above. + +```python +# 0. imports +import torch +from transformers import GPT2Tokenizer + +from trl import AutoModelForCausalLMWithValueHead, PPOConfig, PPOTrainer + + +# 1. load a pretrained model +model = AutoModelForCausalLMWithValueHead.from_pretrained("gpt2") +model_ref = AutoModelForCausalLMWithValueHead.from_pretrained("gpt2") +tokenizer = GPT2Tokenizer.from_pretrained("gpt2") +tokenizer.pad_token = tokenizer.eos_token + +# 2. initialize trainer +ppo_config = {"batch_size": 1} +config = PPOConfig(**ppo_config) +ppo_trainer = PPOTrainer(config, model, model_ref, tokenizer) + +# 3. encode a query +query_txt = "This morning I went to the " +query_tensor = tokenizer.encode(query_txt, return_tensors="pt").to(model.pretrained_model.device) + +# 4. generate model response +generation_kwargs = { + "min_length": -1, + "top_k": 0.0, + "top_p": 1.0, + "do_sample": True, + "pad_token_id": tokenizer.eos_token_id, + "max_new_tokens": 20, +} +response_tensor = ppo_trainer.generate([item for item in query_tensor], return_prompt=False, **generation_kwargs) +response_txt = tokenizer.decode(response_tensor[0]) + +# 5. define a reward for response +# (this could be any reward such as human feedback or output from another model) +reward = [torch.tensor(1.0, device=model.pretrained_model.device)] + +# 6. train model with ppo +train_stats = ppo_trainer.step([query_tensor[0]], [response_tensor[0]], reward) +``` + +In general, you would run steps 3-6 in a for-loop and run it on many diverse queries. You can find more realistic examples in the examples section. + +## How to use a trained model + +After training a `AutoModelForCausalLMWithValueHead`, you can directly use the model in `transformers`. +```python + +# .. Let's assume we have a trained model using `PPOTrainer` and `AutoModelForCausalLMWithValueHead` + +# push the model on the Hub +model.push_to_hub("my-fine-tuned-model-ppo") + +# or save it locally +model.save_pretrained("my-fine-tuned-model-ppo") + +# load the model from the Hub +from transformers import AutoModelForCausalLM + +model = AutoModelForCausalLM.from_pretrained("my-fine-tuned-model-ppo") +``` + +You can also load your model with `AutoModelForCausalLMWithValueHead` if you want to use the value head, for example to continue training. + +```python +from trl.model import AutoModelForCausalLMWithValueHead + +model = AutoModelForCausalLMWithValueHead.from_pretrained("my-fine-tuned-model-ppo") +``` diff --git a/docs/source/reward_trainer.mdx b/docs/source/reward_trainer.mdx new file mode 100644 index 0000000000000000000000000000000000000000..746db7d11ac3ec7e5cfd41d702e4791b2690db19 --- /dev/null +++ b/docs/source/reward_trainer.mdx @@ -0,0 +1,77 @@ +# Reward Modeling + +TRL supports custom reward modeling for anyone to perform reward modeling on their dataset and model. + +Check out a complete flexible example at [`examples/scripts/reward_modeling.py`](https://github.com/huggingface/trl/tree/main/examples/scripts/reward_modeling.py). + +## Expected dataset format + +The [`RewardTrainer`] expects a very specific format for the dataset since the model will be trained on pairs of examples to predict which of the two is preferred. We provide an example from the [`Anthropic/hh-rlhf`](https://huggingface.co/datasets/Anthropic/hh-rlhf) dataset below: + +
+ +
+ +Therefore the final dataset object should contain two 4 entries at least if you use the default [`RewardDataCollatorWithPadding`] data collator. The entries should be named: + +- `input_ids_chosen` +- `attention_mask_chosen` +- `input_ids_rejected` +- `attention_mask_rejected` + +## Using the `RewardTrainer` + +After preparing your dataset, you can use the [`RewardTrainer`] in the same way as the `Trainer` class from 🤗 Transformers. +You should pass an `AutoModelForSequenceClassification` model to the [`RewardTrainer`], along with a [`RewardConfig`] which configures the hyperparameters of the training. + +### Leveraging 🤗 PEFT to train a reward model + +Just pass a `peft_config` in the keyword arguments of [`RewardTrainer`], and the trainer should automatically take care of converting the model into a PEFT model! + +```python +from peft import LoraConfig, task_type +from transformers import AutoModelForSequenceClassification, AutoTokenizer +from trl import RewardTrainer, RewardConfig + +model = AutoModelForSequenceClassification.from_pretrained("gpt2") +peft_config = LoraConfig( + task_type=TaskType.SEQ_CLS, + inference_mode=False, + r=8, + lora_alpha=32, + lora_dropout=0.1, +) + +... + +trainer = RewardTrainer( + model=model, + args=training_args, + tokenizer=tokenizer, + train_dataset=dataset, + peft_config=peft_config, +) + +trainer.train() + +``` + +### Adding a margin to the loss + +As in the [Llama 2 paper](https://huggingface.co/papers/2307.09288), you can add a margin to the loss by adding a `margin` column to the dataset. The reward collator will automatically pass it through and the loss will be computed accordingly. + +```python +def add_margin(row): + # Assume you have a score_chosen and score_rejected columns that you want to use to compute the margin + return {'margin': row['score_chosen'] - row['score_rejected']} + +dataset = dataset.map(add_margin) +``` + +## RewardConfig + +[[autodoc]] RewardConfig + +## RewardTrainer + +[[autodoc]] RewardTrainer diff --git a/docs/source/sentiment_tuning.mdx b/docs/source/sentiment_tuning.mdx new file mode 100644 index 0000000000000000000000000000000000000000..2cf9e49652698cfa9123d302f1f7a4f0983de3f6 --- /dev/null +++ b/docs/source/sentiment_tuning.mdx @@ -0,0 +1,130 @@ +# Sentiment Tuning Examples + +The notebooks and scripts in this examples show how to fine-tune a model with a sentiment classifier (such as `lvwerra/distilbert-imdb`). + +Here's an overview of the notebooks and scripts in the [trl repository](https://github.com/huggingface/trl/tree/main/examples): + + + +| File | Description | +|------------------------------------------------------------------------------------------------------------|--------------------------------------------------------------------------------------------------------------------------| +| [`examples/scripts/ppo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/ppo.py) [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/trl/blob/main/examples/sentiment/notebooks/gpt2-sentiment.ipynb) | This script shows how to use the `PPOTrainer` to fine-tune a sentiment analysis model using IMDB dataset | +| [`examples/notebooks/gpt2-sentiment.ipynb`](https://github.com/huggingface/trl/tree/main/examples/notebooks/gpt2-sentiment.ipynb) | This notebook demonstrates how to reproduce the GPT2 imdb sentiment tuning example on a jupyter notebook. | +| [`examples/notebooks/gpt2-control.ipynb`](https://github.com/huggingface/trl/tree/main/examples/notebooks/gpt2-control.ipynb) [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/trl/blob/main/examples/sentiment/notebooks/gpt2-sentiment-control.ipynb) | This notebook demonstrates how to reproduce the GPT2 sentiment control example on a jupyter notebook. + + + +## Usage + +```bash +# 1. run directly +python examples/scripts/ppo.py +# 2. run via `accelerate` (recommended), enabling more features (e.g., multiple GPUs, deepspeed) +accelerate config # will prompt you to define the training configuration +accelerate launch examples/scripts/ppo.py # launches training +# 3. get help text and documentation +python examples/scripts/ppo.py --help +# 4. configure logging with wandb and, say, mini_batch_size=1 and gradient_accumulation_steps=16 +python examples/scripts/ppo.py --ppo_config.log_with wandb --ppo_config.mini_batch_size 1 --ppo_config.gradient_accumulation_steps 16 +``` + +Note: if you don't want to log with `wandb` remove `log_with="wandb"` in the scripts/notebooks. You can also replace it with your favourite experiment tracker that's [supported by `accelerate`](https://huggingface.co/docs/accelerate/usage_guides/tracking). + + +## Few notes on multi-GPU + +To run in multi-GPU setup with DDP (distributed Data Parallel) change the `device_map` value to `device_map={"": Accelerator().process_index}` and make sure to run your script with `accelerate launch yourscript.py`. If you want to apply naive pipeline parallelism you can use `device_map="auto"`. + + +## Benchmarks + +Below are some benchmark results for `examples/scripts/ppo.py`. To reproduce locally, please check out the `--command` arguments below. + +```bash +python benchmark/benchmark.py \ + --command "python examples/scripts/ppo.py --ppo_config.log_with wandb" \ + --num-seeds 5 \ + --start-seed 1 \ + --workers 10 \ + --slurm-nodes 1 \ + --slurm-gpus-per-task 1 \ + --slurm-ntasks 1 \ + --slurm-total-cpus 12 \ + --slurm-template-path benchmark/trl.slurm_template +``` + +![](https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/benchmark/v0.4.7-55-g110e672/sentiment.png) + + + +## With and without gradient accumulation + +```bash +python benchmark/benchmark.py \ + --command "python examples/scripts/ppo.py --ppo_config.exp_name sentiment_tuning_step_grad_accu --ppo_config.mini_batch_size 1 --ppo_config.gradient_accumulation_steps 128 --ppo_config.log_with wandb" \ + --num-seeds 5 \ + --start-seed 1 \ + --workers 10 \ + --slurm-nodes 1 \ + --slurm-gpus-per-task 1 \ + --slurm-ntasks 1 \ + --slurm-total-cpus 12 \ + --slurm-template-path benchmark/trl.slurm_template +``` + +![](https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/benchmark/v0.4.7-55-g110e672/gradient_accu.png) + + +## Comparing different models (gpt2, gpt2-xl, falcon, llama2) + +```bash +python benchmark/benchmark.py \ + --command "python examples/scripts/ppo.py --ppo_config.exp_name sentiment_tuning_gpt2 --ppo_config.log_with wandb" \ + --num-seeds 5 \ + --start-seed 1 \ + --workers 10 \ + --slurm-nodes 1 \ + --slurm-gpus-per-task 1 \ + --slurm-ntasks 1 \ + --slurm-total-cpus 12 \ + --slurm-template-path benchmark/trl.slurm_template +python benchmark/benchmark.py \ + --command "python examples/scripts/ppo.py --ppo_config.exp_name sentiment_tuning_gpt2xl_grad_accu --ppo_config.model_name gpt2-xl --ppo_config.mini_batch_size 16 --ppo_config.gradient_accumulation_steps 8 --ppo_config.log_with wandb" \ + --num-seeds 5 \ + --start-seed 1 \ + --workers 10 \ + --slurm-nodes 1 \ + --slurm-gpus-per-task 1 \ + --slurm-ntasks 1 \ + --slurm-total-cpus 12 \ + --slurm-template-path benchmark/trl.slurm_template +python benchmark/benchmark.py \ + --command "python examples/scripts/ppo.py --ppo_config.exp_name sentiment_tuning_falcon_rw_1b --ppo_config.model_name tiiuae/falcon-rw-1b --ppo_config.log_with wandb" \ + --num-seeds 5 \ + --start-seed 1 \ + --workers 10 \ + --slurm-nodes 1 \ + --slurm-gpus-per-task 1 \ + --slurm-ntasks 1 \ + --slurm-total-cpus 12 \ + --slurm-template-path benchmark/trl.slurm_template +``` + +![](https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/benchmark/v0.4.7-55-g110e672/different_models.png) + +## With and without PEFT + +``` +python benchmark/benchmark.py \ + --command "python examples/scripts/ppo.py --ppo_config.exp_name sentiment_tuning_peft --use_peft --ppo_config.log_with wandb" \ + --num-seeds 5 \ + --start-seed 1 \ + --workers 10 \ + --slurm-nodes 1 \ + --slurm-gpus-per-task 1 \ + --slurm-ntasks 1 \ + --slurm-total-cpus 12 \ + --slurm-template-path benchmark/trl.slurm_template +``` + +![](https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/benchmark/v0.4.7-55-g110e672/peft.png) diff --git a/docs/source/sft_trainer.mdx b/docs/source/sft_trainer.mdx new file mode 100644 index 0000000000000000000000000000000000000000..fdcc1b91ebcef433afba42c26b6db80fed8e46b7 --- /dev/null +++ b/docs/source/sft_trainer.mdx @@ -0,0 +1,455 @@ +# Supervised Fine-tuning Trainer + +Supervised fine-tuning (or SFT for short) is a crucial step in RLHF. In TRL we provide an easy-to-use API to create your SFT models and train them with few lines of code on your dataset. + +Check out a complete flexible example at [`examples/scripts/sft.py`](https://github.com/huggingface/trl/tree/main/examples/scripts/sft.py). + +## Quickstart + +If you have a dataset hosted on the 🤗 Hub, you can easily fine-tune your SFT model using [`SFTTrainer`] from TRL. Let us assume your dataset is `imdb`, the text you want to predict is inside the `text` field of the dataset, and you want to fine-tune the `facebook/opt-350m` model. +The following code-snippet takes care of all the data pre-processing and training for you: + +```python +from datasets import load_dataset +from trl import SFTTrainer + +dataset = load_dataset("imdb", split="train") + +trainer = SFTTrainer( + "facebook/opt-350m", + train_dataset=dataset, + dataset_text_field="text", + max_seq_length=512, +) +trainer.train() +``` +Make sure to pass a correct value for `max_seq_length` as the default value will be set to `min(tokenizer.model_max_length, 1024)`. + +You can also construct a model outside of the trainer and pass it as follows: + +```python +from transformers import AutoModelForCausalLM +from datasets import load_dataset +from trl import SFTTrainer + +dataset = load_dataset("imdb", split="train") + +model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m") + +trainer = SFTTrainer( + model, + train_dataset=dataset, + dataset_text_field="text", + max_seq_length=512, +) + +trainer.train() +``` + +The above snippets will use the default training arguments from the [`transformers.TrainingArguments`](https://huggingface.co/docs/transformers/main_classes/trainer#transformers.TrainingArguments) class. If you want to modify that, make sure to create your own `TrainingArguments` object and pass it to the [`SFTTrainer`] constructor as it is done on the [`supervised_finetuning.py` script](https://github.com/huggingface/trl/blob/main/examples/stack_llama/scripts/supervised_finetuning.py) on the stack-llama example. + +## Advanced usage + +### Train on completions only + +You can use the `DataCollatorForCompletionOnlyLM` to train your model on the generated prompts only. Note that this works only in the case when `packing=False`. +To instantiate that collator for instruction data, pass a response template and the tokenizer. Here is an example of how it would work to fine-tune `opt-350m` on completions only on the CodeAlpaca dataset: + +```python +from transformers import AutoModelForCausalLM, AutoTokenizer +from datasets import load_dataset +from trl import SFTTrainer, DataCollatorForCompletionOnlyLM + +dataset = load_dataset("lucasmccabe-lmi/CodeAlpaca-20k", split="train") + +model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m") +tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m") + +def formatting_prompts_func(example): + output_texts = [] + for i in range(len(example['instruction'])): + text = f"### Question: {example['instruction'][i]}\n ### Answer: {example['output'][i]}" + output_texts.append(text) + return output_texts + +response_template = " ### Answer:" +collator = DataCollatorForCompletionOnlyLM(response_template, tokenizer=tokenizer) + +trainer = SFTTrainer( + model, + train_dataset=dataset, + formatting_func=formatting_prompts_func, + data_collator=collator, +) + +trainer.train() +``` + +To instantiate that collator for assistant style conversation data, pass a response template, an instruction template and the tokenizer. Here is an example of how it would work to fine-tune `opt-350m` on assistant completions only on the Open Assistant Guanaco dataset: + +```python +from transformers import AutoModelForCausalLM, AutoTokenizer +from datasets import load_dataset +from trl import SFTTrainer, DataCollatorForCompletionOnlyLM + +dataset = load_dataset("timdettmers/openassistant-guanaco", split="train") + +model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m") +tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m") + +instruction_template = "### Human:" +response_template = "### Assistant:" +collator = DataCollatorForCompletionOnlyLM(instruction_template=instruction_template, response_template=response_template, tokenizer=tokenizer, mlm=False) + +trainer = SFTTrainer( + model, + train_dataset=dataset, + dataset_text_field="text", + data_collator=collator, +) + +trainer.train() +``` + +Make sure to have a `pad_token_id` which is different from `eos_token_id` which can result in the model not properly predicting EOS (End of Sentence) tokens during generation. + +#### Using token_ids directly for `response_template` + +Some tokenizers like Llama 2 (`meta-llama/Llama-2-XXb-hf`) tokenize sequences differently depending whether they have context or not. For example: + +```python +from transformers import AutoTokenizer +tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf") + +def print_tokens_with_ids(txt): + tokens = tokenizer.tokenize(txt, add_special_tokens=False) + token_ids = tokenizer.encode(txt, add_special_tokens=False) + print(list(zip(tokens, token_ids))) + +prompt = """### User: Hello\n\n### Assistant: Hi, how can I help you?""" +print_tokens_with_ids(prompt) # [..., ('▁Hello', 15043), ('<0x0A>', 13), ('<0x0A>', 13), ('##', 2277), ('#', 29937), ('▁Ass', 4007), ('istant', 22137), (':', 29901), ...] + +response_template = "### Assistant:" +print_tokens_with_ids(response_template) # [('▁###', 835), ('▁Ass', 4007), ('istant', 22137), (':', 29901)] +``` + +In this case, and due to lack of context in `response_template`, the same string ("### Assistant:") is tokenized differently: + + - Text (with context): `[2277, 29937, 4007, 22137, 29901]` + - `response_template` (without context): `[835, 4007, 22137, 29901]` + +This will lead to an error when the `DataCollatorForCompletionOnlyLM` does not find the `response_template` in the dataset example text: + +``` +RuntimeError: Could not find response key [835, 4007, 22137, 29901] in token IDs tensor([ 1, 835, ...]) +``` + + +To solve this, you can tokenize the `response_template` with the same context than in the dataset, truncate it as needed and pass the `token_ids` directly to the `response_template` argument of the `DataCollatorForCompletionOnlyLM` class. For example: + +```python +response_template_with_context = "\n### Assistant:" # We added context here: "\n". This is enough for this tokenizer +response_template_ids = tokenizer.encode(response_template_with_context, add_special_tokens=False)[2:] # Now we have it like in the dataset texts: `[2277, 29937, 4007, 22137, 29901]` + +data_collator = DataCollatorForCompletionOnlyLM(response_template_ids, tokenizer=tokenizer) +``` + +### Format your input prompts + +For instruction fine-tuning, it is quite common to have two columns inside the dataset: one for the prompt & the other for the response. +This allows people to format examples like [Stanford-Alpaca](https://github.com/tatsu-lab/stanford_alpaca) did as follows: +```bash +Below is an instruction ... + +### Instruction +{prompt} + +### Response: +{completion} +``` +Let us assume your dataset has two fields, `question` and `answer`. Therefore you can just run: +```python +... +def formatting_prompts_func(example): + output_texts = [] + for i in range(len(example['question'])): + text = f"### Question: {example['question'][i]}\n ### Answer: {example['answer'][i]}" + output_texts.append(text) + return output_texts + +trainer = SFTTrainer( + model, + train_dataset=dataset, + formatting_func=formatting_prompts_func, +) + +trainer.train() +``` +To preperly format your input make sure to process all the examples by looping over them and returning a list of processed text. Check out a full example on how to use SFTTrainer on alpaca dataset [here](https://github.com/huggingface/trl/pull/444#issue-1760952763) + +### Packing dataset ([`ConstantLengthDataset`]) + +[`SFTTrainer`] supports _example packing_, where multiple short examples are packed in the same input sequence to increase training efficiency. This is done with the [`ConstantLengthDataset`] utility class that returns constant length chunks of tokens from a stream of examples. To enable the usage of this dataset class, simply pass `packing=True` to the [`SFTTrainer`] constructor. + +```python +... + +trainer = SFTTrainer( + "facebook/opt-350m", + train_dataset=dataset, + dataset_text_field="text", + packing=True +) + +trainer.train() +``` + +Note that if you use a packed dataset and if you pass `max_steps` in the training arguments you will probably train your models for more than few epochs, depending on the way you have configured the packed dataset and the training protocol. Double check that you know and understand what you are doing. + +#### Customize your prompts using packed dataset + +If your dataset has several fields that you want to combine, for example if the dataset has `question` and `answer` fields and you want to combine them, you can pass a formatting function to the trainer that will take care of that. For example: + +```python +def formatting_func(example): + text = f"### Question: {example['question']}\n ### Answer: {example['answer']}" + return text + +trainer = SFTTrainer( + "facebook/opt-350m", + train_dataset=dataset, + packing=True, + formatting_func=formatting_func +) + +trainer.train() +``` +You can also customize the [`ConstantLengthDataset`] much more by directly passing the arguments to the [`SFTTrainer`] constructor. Please refer to that class' signature for more information. + +### Control over the pretrained model + +You can directly pass the kwargs of the `from_pretrained()` method to the [`SFTTrainer`]. For example, if you want to load a model in a different precision, analogous to + +```python +model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m", torch_dtype=torch.bfloat16) +``` + +```python +... + +trainer = SFTTrainer( + "facebook/opt-350m", + train_dataset=dataset, + dataset_text_field="text", + model_init_kwargs={ + "torch_dtype": torch.bfloat16, + }, +) + +trainer.train() +``` +Note that all keyword arguments of `from_pretrained()` are supported. + +### Training adapters + +We also support a tight integration with 🤗 PEFT library so that any user can conveniently train adapters and share them on the Hub instead of training the entire model + +```python +from datasets import load_dataset +from trl import SFTTrainer +from peft import LoraConfig + +dataset = load_dataset("imdb", split="train") + +peft_config = LoraConfig( + r=16, + lora_alpha=32, + lora_dropout=0.05, + bias="none", + task_type="CAUSAL_LM", +) + +trainer = SFTTrainer( + "EleutherAI/gpt-neo-125m", + train_dataset=dataset, + dataset_text_field="text", + peft_config=peft_config +) + +trainer.train() +``` + +Note that in case of training adapters, we manually add a saving callback to automatically save the adapters only: +```python +class PeftSavingCallback(TrainerCallback): + def on_save(self, args, state, control, **kwargs): + checkpoint_path = os.path.join(args.output_dir, f"checkpoint-{state.global_step}") + kwargs["model"].save_pretrained(checkpoint_path) + + if "pytorch_model.bin" in os.listdir(checkpoint_path): + os.remove(os.path.join(checkpoint_path, "pytorch_model.bin")) +``` +If you want to add more callbacks, make sure to add this one as well to properly save the adapters only during training. +```python +... + +callbacks = [YourCustomCallback(), PeftSavingCallback()] + +trainer = SFTTrainer( + "EleutherAI/gpt-neo-125m", + train_dataset=dataset, + dataset_text_field="text", + peft_config=peft_config, + callbacks=callbacks +) + +trainer.train() +``` + +You can also continue training your `PeftModel`. For that, first load a `PeftModel` outside `SFTTrainer` and pass it directly to the trainer without the `peft_config` argument being passed. + +### Training adapters with base 8 bit models + +For that you need to first load your 8bit model outside the Trainer and pass a `PeftConfig` to the trainer. For example: + +```python +... + +peft_config = LoraConfig( + r=16, + lora_alpha=32, + lora_dropout=0.05, + bias="none", + task_type="CAUSAL_LM", +) + +model = AutoModelForCausalLM.from_pretrained( + "EleutherAI/gpt-neo-125m", + load_in_8bit=True, + device_map="auto", +) + +trainer = SFTTrainer( + model, + train_dataset=dataset, + dataset_text_field="text", + peft_config=peft_config, +) + +trainer.train() +``` + +## Using Flash Attention and Flash Attention 2 + +You can benefit from Flash Attention 1 & 2 using SFTTrainer out of the box with minimal changes of code. +First, to make sure you have all the latest features from transformers, install transformers from source + +```bash +pip install -U git+https://github.com/huggingface/transformers.git +``` + +Note that Flash Attention only works on GPU now and under half-precision regime (when using adapters, base model loaded in half-precision) +Note also both features are perfectly compatible with other tools such as quantization. + +### Using Flash-Attention 1 + +For Flash Attention 1 you can use the `BetterTransformer` API and force-dispatch the API to use Flash Attention kernel. First, install the latest optimum package: + +```bash +pip install -U optimum +``` + +Once you have loaded your model, wrap the `trainer.train()` call under the `with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):` context manager: + +```diff +... + ++ with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False): + trainer.train() +``` + +Note that you cannot train your model using Flash Attention 1 on an arbitrary dataset as `torch.scaled_dot_product_attention` does not support training with padding tokens if you use Flash Attention kernels. Therefore you can only use that feature with `packing=True`. If your dataset contains padding tokens, consider switching to Flash Attention 2 integration. + +Below are some numbers you can get in terms of speedup and memory efficiency, using Flash Attention 1, on a single NVIDIA-T4 16GB. + +| use_flash_attn_1 | model_name | max_seq_len | batch_size | time per training step | +|----------------|-------------------|-------------|------------|------------------------| +| x | facebook/opt-350m | 2048 | 8 | ~59.1s | +| | facebook/opt-350m | 2048 | 8 | **OOM** | +| x | facebook/opt-350m | 2048 | 4 | ~30.3s | +| | facebook/opt-350m | 2048 | 4 | ~148.9s | + +### Using Flash Attention-2 + +To use Flash Attention 2, first install the latest `flash-attn` package: + +```bash +pip install -U flash-attn +``` + +And add `use_flash_attention_2=True` when calling `from_pretrained`: + +```python +model = AutoModelForCausalLM.from_pretrained( + model_id, + load_in_4bit=True, + use_flash_attention_2=True +) +``` + +If you don't use quantization, make sure your model is loaded in half-precision and dispatch your model on a supported GPU device. +After loading your model, you can either train it as it is, or attach adapters and train adapters on it in case your model is quantized. + +In contrary to Flash Attention 1, the integration makes it possible to train your model on an arbitrary dataset that also includes padding tokens. + +### Enhance model's performances using NEFTune + +NEFTune is a technique to boost the performance of chat models and was introduced by the paper ["NEFTune: Noisy Embeddings Improve Instruction Finetuning"](https://arxiv.org/abs/2310.05914) from Jain et al. it consists of adding noise to the embedding vectors during training. According to the abstract of the paper: + +> Standard finetuning of LLaMA-2-7B using Alpaca achieves 29.79% on AlpacaEval, which rises to 64.69% using noisy embeddings. NEFTune also improves over strong baselines on modern instruction datasets. Models trained with Evol-Instruct see a 10% improvement, with ShareGPT an 8% improvement, and with OpenPlatypus an 8% improvement. Even powerful models further refined with RLHF such as LLaMA-2-Chat benefit from additional training with NEFTune. + +
+ +
+ +To use it in `SFTTrainer` simply pass `neftune_noise_alpha` when creating your `SFTTrainer` instance. Note that to avoid any surprising behaviour, NEFTune is disabled after training to retrieve back the original behaviour of the embedding layer. + +```python +from datasets import load_dataset +from trl import SFTTrainer + +dataset = load_dataset("imdb", split="train") + +trainer = SFTTrainer( + "facebook/opt-350m", + train_dataset=dataset, + dataset_text_field="text", + max_seq_length=512, + neftune_noise_alpha=5, +) +trainer.train() +``` + +We have tested NEFTune by training `mistralai/Mistral-7B-v0.1` on the [OpenAssistant dataset](https://huggingface.co/datasets/timdettmers/openassistant-guanaco) and validated that using NEFTune led to a performance boost of ~25% on MT Bench. + +
+ +
+ +Note however, that the amount of performance gain is _dataset dependent_ and in particular, applying NEFTune on synthetic datasets like [UltraChat](https://huggingface.co/datasets/stingning/ultrachat) typically produces smaller gains. +## Best practices + +Pay attention to the following best practices when training a model with that trainer: + +- [`SFTTrainer`] always pads by default the sequences to the `max_seq_length` argument of the [`SFTTrainer`]. If none is passed, the trainer will retrieve that value from the tokenizer. Some tokenizers do not provide default value, so there is a check to retrieve the minimum between 2048 and that value. Make sure to check it before training. +- For training adapters in 8bit, you might need to tweak the arguments of the `prepare_model_for_kbit_training` method from PEFT, hence we advise users to use `prepare_in_int8_kwargs` field, or create the `PeftModel` outside the [`SFTTrainer`] and pass it. +- For a more memory-efficient training using adapters, you can load the base model in 8bit, for that simply add `load_in_8bit` argument when creating the [`SFTTrainer`], or create a base model in 8bit outside the trainer and pass it. +- If you create a model outside the trainer, make sure to not pass to the trainer any additional keyword arguments that are relative to `from_pretrained()` method. + +## SFTTrainer + +[[autodoc]] SFTTrainer + +## ConstantLengthDataset + +[[autodoc]] trainer.ConstantLengthDataset diff --git a/docs/source/text_environments.md b/docs/source/text_environments.md new file mode 100644 index 0000000000000000000000000000000000000000..851020e0f5c73f05957072db00040e3dddd0aa49 --- /dev/null +++ b/docs/source/text_environments.md @@ -0,0 +1,197 @@ +# Text Environments + +Text environments provide a learning ground for language agents. It allows a language model to use tools to accomplish a task such as using a Python interpreter to answer math questions or using a search index for trivia questions. Having access to tools allows language models to solve tasks that would be very hard for the models itself but can be trivial for the appropriate tools. A good example is arithmetics of large numbers that become a simple copy-paste task once you have access to a calculator. + +
+ +
+ +Let's dive into how text environments work and start with tools! + +## Tools + +One of the core building blocks of text environments are tools that the model can use to solve tasks. In general tools can be any Python function that takes a string as input and returns string. The `TextEnvironment` offers two options for tools: either go with predefined tools from `transformers.Tool` or define your own function or class with `__call__` method. Let's have a look at both! + +### `transformers.Tool` + +Text environments fully support tools of the class `transformers.Tool`. The advantage of building tools in that framework is that they can easily be shared + +```Python +from transformers import load_tool + +# simple calculator tool that runs +-/* operations +calc_tool = load_tool("ybelkada/simple-calculator") + +# python interpreter that executes program and returns outputs +py_tool = load_tool("lvwerra/python-interpreter") + +# wikipedia search index that returns best search match +wiki_tool = load_tool("vwxyzjn/pyserini-wikipedia-kilt-doc") +``` + +These tools are either loaded from the hub or from a local folder. Using the tool is as simple as calling them with a text query: + +```Python +calc_tool("1/2") +>>> "0.5" +``` + +Note that both input and return values are strings to enable easy usage with a language model. + +### Custom Tools + +The following is an example of a tool that adds two integers: + +```Python +def add(text): + int_1, int_2 = text.split("+") + result = int(int_1) + int(int_2) + return str(result) + +print(add("1+1")) +>>> "2" +``` + +We looked at basic examples such as a calculator but the principle holds for more complex tools as well such as a web search tool where you input the query and get the search results in return. Now let's look at how the model can use the tools with the call syntax. + +### Call syntax + +In order to have a unified way for the model to call a tool we created a simple syntax that looks as follows: + +```python +"QUERYTOOL_RESPONSE" +``` + +There are a few special tokens involved so let's decompose it: First the model can signal that it wants to use a tool by emitting the `` token. After that we want to know the name of the tool to call which is done by enclosing the tool name with `<>` brackets. Once we know which tool to call the tool query follows which is in free text form. The `` tokens signifies the end of the query and stops the model generation. At this point the model output is parsed and the query sent to the tool. The environment appends the tool response to the string followed by the `` token to show the end the tool output. + +Let's look at the concrete example of the calculator and assume its name is `Calculator` (more on how the name of a tool is inferred later): + +```python +"1/20.5" +``` + +Finally, the episode is ended and generation stops when the model generates `` which marks the interaction as completed. + +Now let's have a look how we can create a new text environment! + +## Create a `TextEnvironment` + + +```python +prompt = """\ +What is 13-3? +13-310.0 +Result=10 +""" + +def reward_fn(result, answer): + """Simplified reward function returning 1 if result matches answer and 0 otherwise.""" + result_parsed = result.split("=")[1].split("<")[0] + return int(result_parsed==answer) + +text_env = TextEnvironemnt( + model=model, + tokenizer=tokenizer, + tools= {"SimpleCalculatorTool": load_tool("ybelkada/simple-calculator")}, + reward_fn=exact_match_reward, + prompt=prompt, + max_turns=1 + max_tool_response=100 + generation_kwargs={"do_sample": "true"} +) +``` + +Let's decompose the settings: + +| Argument | Description | +|:-------------------|:----------------| +| `model` | Language model to interact with the environment and generate requests. | +| `tokenizer` | Tokenizer of language model handling tokenization of strings. | +| `tools` | `list` of `dict` of tools. If former the name of the tool is inferred from class name and otherwise it's the keys of the dictionary.| +| `reward_fn` | A function that takes a string as input and returns. Can have extra arguments that are passed to `.run()` such as ground truth.| +| `prompt` | Prompt to prepend to every task. Usually a few examples to demonstrate to the model how to use the tools in a few-shot fashion. | +| `max_turns` | Maximum number of interactions between model and tools before episode ends.| +| `max_tool_response`| The tool response is truncated to this number to avoid running out of model context.| +| `max_length` | The maximum number of tokens to allow in an episode. | +| `generation_kwargs`| Generation settings used by the language model. | + +You can customize the environment to your needs and add custom tools and settings. Let's see how you can use the environment to have the model interact with the available tools! + + +## Run an Episode + +To run a set of queries through the text environment one can simply use the `run` method. + +```python +queries = ["What is 1/2?"] +answers = ["0.5"] + +queries, responses, masks, rewards, histories = text_env.run(queries, answers=answers) +``` + +This will execute the model/tool feedback loop for each query until either no tool is called anymore, the maximum number of turns is reached or to maximum number of tokens in an episode is exceeded. The extra `kwargs` (e.g. `answers=answers` above) passed to `run` will be passed on to the reward function. + +There are five objects that are returned by `run`: + +- `queries`: a list of the tokenized queries +- `responses`: all tokens that have been generated withing the environment including model and tool tokens +- `masks`: mask that indicates which tokens have been generated by the model and which tokens are generated by the tool +- `rewards`: a list of reward for each query/response +- `histories`: list of `TextHistory` objects, which are useful objects containing all the above and also the text equivalents + +The masks are crucial for training as we don't want to optimize tokens that the model has not generated which are tokens produced by the tools. + +Next, we'll train a PPO step with the generated responses! + + +### Train +Training on episodes from the `TextEnvironment` is straight forward and simply requires forwarding all the returned variables except the `TextHistory` objects to the `step` method: + +```python +train_stats = ppo_trainer.step(queries, responses, rewards, masks) +``` + +## `TextHistory` + +The `TextHistory` object stores the interactions between the model and the text environment. It stores tokens and text generated in each turn and their source in each turn (model or system) as well as rewards. Let's go through the class attributes and methods. + +### Attributes + +The following table summarises the available attributes of the `TextEnvironment` class: + +| Attribute | Description | +|:-------------------|:----------------| +| `text` | The full string of the text generated in the text environment with both model and system generated text. | +| `text_spans` | A list of tuples with the spans for each model or system generated text segment. | +| `system_spans` | A list of boolean values indicating if the segment is model or system generated. | +| `tokens` | All tokens generated in text environment with both model and system generated tokens. | +| `token_spans` | Similar to `text_spans` the `token_spans` indicate the boundaries of model andsystem generated tokens. | +| `token_masks` | The token masks can be used to ignore system generated tokens by masking them. | +| `completed` | Indicates if the interaction with the environment has completed. | +| `truncated` | Indicates if the interaction with the environment has completed because max length was reached. | + +With these attributes you can reconstruct every interaction of the model with the `TextEnvironment`. The `TextHistory` also lets you visualize the text history. Let's have a look! + +### Visualization + +When the model interacts inside the `TextEnvironment` it can be useful to visualize and separate which parts of the text outputs were generated by the model and which parts come from the system and tools. For that purpose there are the two methods [`TextHistory.show_text`] and [`TextHistory.show_tokens`]. They print the text and tokens respectively and highlight the various segments using the [`rich` libray](https://github.com/Textualize/rich) (make sure to install it before using these methods). + +You can see that the prompt is highlighted in gray, whereas system segments such as query and tool responses are highlighted in green. All segments generated by the model are highlighted in blue and in addition to the pure text output the reward is displayed as additional text in plum. Here an example of `show_text`: + +
+ +
+ +Sometimes there can be tricky tokenization related issues that are hidden when showing the decoded text. Thus `TextHistory` also offers an option to display the same highlighting on the tokens directly with `show_tokens`: + +
+ +
+ +Note that you can turn on the colour legend by passing `show_legend=True`. + +## API Documentation + +[[autodoc]] TextEnvironment + +[[autodoc]] TextHistory diff --git a/docs/source/trainer.mdx b/docs/source/trainer.mdx new file mode 100644 index 0000000000000000000000000000000000000000..0d2550a6b1f0641520e3c7ce22f7fd2f545f48bc --- /dev/null +++ b/docs/source/trainer.mdx @@ -0,0 +1,45 @@ +# Trainer + +At TRL we support PPO (Proximal Policy Optimisation) with an implementation that largely follows the structure introduced in the paper "Fine-Tuning Language Models from Human Preferences" by D. Ziegler et al. [[paper](https://arxiv.org/pdf/1909.08593.pdf), [code](https://github.com/openai/lm-human-preferences)]. +The Trainer and model classes are largely inspired from `transformers.Trainer` and `transformers.AutoModel` classes and adapted for RL. +We also support a `RewardTrainer` that can be used to train a reward model. + +## PPOConfig + +[[autodoc]] PPOConfig + +## PPOTrainer + +[[autodoc]] PPOTrainer + +## RewardConfig + +[[autodoc]] RewardConfig + +## RewardTrainer + +[[autodoc]] RewardTrainer + +## SFTTrainer + +[[autodoc]] SFTTrainer + +## DPOTrainer + +[[autodoc]] DPOTrainer + +## DDPOConfig + +[[autodoc]] DDPOConfig + +## DDPOTrainer + +[[autodoc]] DDPOTrainer + +## IterativeSFTTrainer + +[[autodoc]] IterativeSFTTrainer + +## set_seed + +[[autodoc]] set_seed diff --git a/docs/source/use_model.md b/docs/source/use_model.md new file mode 100644 index 0000000000000000000000000000000000000000..f5ab1e45946460fc80d64f54136482b12400d059 --- /dev/null +++ b/docs/source/use_model.md @@ -0,0 +1,58 @@ +# Use model after training + +Once you have trained a model using either the SFTTrainer, PPOTrainer, or DPOTrainer, you will have a fine-tuned model that can be used for text generation. In this section, we'll walk through the process of loading the fine-tuned model and generating text. If you need to run an inference server with the trained model, you can explore libraries such as [`text-generation-inference`](https://github.com/huggingface/text-generation-inference). + +## Load and Generate + +If you have fine-tuned a model fully, meaning without the use of PEFT you can simply load it like any other language model in transformers. E.g. the value head that was trained during the PPO training is no longer needed and if you load the model with the original transformer class it will be ignored: + +```python +from transformers import AutoTokenizer, AutoModelForCausalLM + +model_name_or_path = "kashif/stack-llama-2" #path/to/your/model/or/name/on/hub +device = "cpu" # or "cuda" if you have a GPU + +model = AutoModelForCausalLM.from_pretrained(model_name_or_path).to(device) +tokenizer = AutoTokenizer.from_pretrained(model_name_or_path) + +inputs = tokenizer.encode("This movie was really", return_tensors="pt").to(device) +outputs = model.generate(inputs) +print(tokenizer.decode(outputs[0])) +``` + +Alternatively you can also use the pipeline: + +```python +from transformers import pipeline + +model_name_or_path = "kashif/stack-llama-2" #path/to/your/model/or/name/on/hub +pipe = pipeline("text-generation", model=model_name_or_path) +print(pipe("This movie was really")[0]["generated_text"]) +``` + +## Use Adapters PEFT + +```python +from peft import PeftConfig, PeftModel +from transformers import AutoModelForCausalLM, AutoTokenizer + +base_model_name = "kashif/stack-llama-2" #path/to/your/model/or/name/on/hub" +adapter_model_name = "path/to/my/adapter" + +model = AutoModelForCausalLM.from_pretrained(base_model_name) +model = PeftModel.from_pretrained(model, adapter_model_name) + +tokenizer = AutoTokenizer.from_pretrained(base_model_name) +``` + +You can also merge the adapters into the base model so you can use the model like a normal transformers model, however the checkpoint will be significantly bigger: + +```python +model = AutoModelForCausalLM.from_pretrained(base_model_name) +model = PeftModel.from_pretrained(model, adapter_model_name) + +model = model.merge_and_unload() +model.save_pretrained("merged_adapters") +``` + +Once you have the model loaded and either merged the adapters or keep them separately on top you can run generation as with a normal model outlined above. diff --git a/docs/source/using_llama_models.mdx b/docs/source/using_llama_models.mdx new file mode 100644 index 0000000000000000000000000000000000000000..cf602d2030400b00fe91749a8e49438bbfb90c4c --- /dev/null +++ b/docs/source/using_llama_models.mdx @@ -0,0 +1,160 @@ +# Using LLaMA models with TRL + +We've begun rolling out examples to use Meta's LLaMA models in `trl` (see [Meta's LLaMA release](https://ai.facebook.com/blog/large-language-model-llama-meta-ai/) for the original LLaMA model). + +## Efficient training strategies + +Even training the smallest LLaMA model requires an enormous amount of memory. Some quick math: in bf16, every parameter uses 2 bytes (in fp32 4 bytes) in addition to 8 bytes used, e.g., in the Adam optimizer (see the [performance docs](https://huggingface.co/docs/transformers/perf_train_gpu_one#optimizer) in Transformers for more info). So a 7B parameter model would use `(2+8)*7B=70GB` just to fit in memory and would likely need more when you compute intermediate values such as attention scores. So you couldn’t train the model even on a single 80GB A100 like that. You can use some tricks, like more efficient optimizers of half-precision training, to squeeze a bit more into memory, but you’ll run out sooner or later. + +Another option is to use Parameter-Efficient Fine-Tuning (PEFT) techniques, such as the [`peft`](https://github.com/huggingface/peft) library, which can perform low-rank adaptation (LoRA) on a model loaded in 8-bit. +For more on `peft` + `trl`, see the [docs](https://huggingface.co/docs/trl/sentiment_tuning_peft). + +Loading the model in 8bit reduces the memory footprint drastically since you only need one byte per parameter for the weights (e.g. 7B LlaMa is 7GB in memory). +Instead of training the original weights directly, LoRA adds small adapter layers on top of some specific layers (usually the attention layers); thus, the number of trainable parameters is drastically reduced. + +In this scenario, a rule of thumb is to allocate ~1.2-1.4GB per billion parameters (depending on the batch size and sequence length) to fit the entire fine-tuning setup. +This enables fine-tuning larger models (up to 50-60B scale models on a NVIDIA A100 80GB) at low cost. + +Now we can fit very large models into a single GPU, but the training might still be very slow. +The simplest strategy in this scenario is data parallelism: we replicate the same training setup into separate GPUs and pass different batches to each GPU. +With this, you can parallelize the forward/backward passes of the model and scale with the number of GPUs. + +![chapter10_ddp.png](https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/blog/stackllama/chapter10_ddp.png) + +We use either the `transformers.Trainer` or `accelerate`, which both support data parallelism without any code changes, by simply passing arguments when calling the scripts with `torchrun` or `accelerate launch`. The following runs a training script with 8 GPUs on a single machine with `accelerate` and `torchrun`, respectively. + +```bash +accelerate launch --multi_gpu --num_machines 1 --num_processes 8 my_accelerate_script.py +torchrun --nnodes 1 --nproc_per_node 8 my_torch_script.py +``` + +## Supervised fine-tuning + +Before we start training reward models and tuning our model with RL, it helps if the model is already good in the domain we are interested in. +In our case, we want it to answer questions, while for other use cases, we might want it to follow instructions, in which case instruction tuning is a great idea. +The easiest way to achieve this is by continuing to train the language model with the language modeling objective on texts from the domain or task. +The [StackExchange dataset](https://huggingface.co/datasets/HuggingFaceH4/stack-exchange-preferences) is enormous (over 10 million instructions), so we can easily train the language model on a subset of it. + +There is nothing special about fine-tuning the model before doing RLHF - it’s just the causal language modeling objective from pretraining that we apply here. +To use the data efficiently, we use a technique called packing: instead of having one text per sample in the batch and then padding to either the longest text or the maximal context of the model, we concatenate a lot of texts with a EOS token in between and cut chunks of the context size to fill the batch without any padding. + +![chapter10_preprocessing-clm.png](https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/blog/stackllama/chapter10_preprocessing-clm.png) + +With this approach the training is much more efficient as each token that is passed through the model is also trained in contrast to padding tokens which are usually masked from the loss. +If you don't have much data and are more concerned about occasionally cutting off some tokens that are overflowing the context you can also use a classical data loader. + +The packing is handled by the `ConstantLengthDataset` and we can then use the `Trainer` after loading the model with `peft`. First, we load the model in int8, prepare it for training, and then add the LoRA adapters. + +```python +# load model in 8bit +model = AutoModelForCausalLM.from_pretrained( + args.model_path, + load_in_8bit=True, + device_map={"": Accelerator().local_process_index} + ) +model = prepare_model_for_kbit_training(model) + +# add LoRA to model +lora_config = LoraConfig( + r=16, + lora_alpha=32, + lora_dropout=0.05, + bias="none", + task_type="CAUSAL_LM", +) + +model = get_peft_model(model, config) +``` + +We train the model for a few thousand steps with the causal language modeling objective and save the model. +Since we will tune the model again with different objectives, we merge the adapter weights with the original model weights. + +**Disclaimer:** due to LLaMA's license, we release only the adapter weights for this and the model checkpoints in the following sections. +You can apply for access to the base model's weights by filling out Meta AI's [form](https://docs.google.com/forms/d/e/1FAIpQLSfqNECQnMkycAp2jP4Z9TFX0cGR4uf7b_fBxjY_OjhJILlKGA/viewform) and then converting them to the 🤗 Transformers format by running this [script](https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/convert_llama_weights_to_hf.py). +Note that you'll also need to install 🤗 Transformers from source until the `v4.28` is released. + +Now that we have fine-tuned the model for the task, we are ready to train a reward model. + +## Reward modeling and human preferences + +In principle, we could fine-tune the model using RLHF directly with the human annotations. +However, this would require us to send some samples to humans for rating after each optimization iteration. +This is expensive and slow due to the number of training samples needed for convergence and the inherent latency of human reading and annotator speed. + +A trick that works well instead of direct feedback is training a reward model on human annotations collected before the RL loop. +The goal of the reward model is to imitate how a human would rate a text. There are several possible strategies to build a reward model: the most straightforward way would be to predict the annotation (e.g. a rating score or a binary value for “good”/”bad”). +In practice, what works better is to predict the ranking of two examples, where the reward model is presented with two candidates `(y_k, y_j)` for a given prompt `x` and has to predict which one would be rated higher by a human annotator. + +With the StackExchange dataset, we can infer which of the two answers was preferred by the users based on the score. +With that information and the loss defined above, we can then modify the `transformers.Trainer` by adding a custom loss function. + +```python +class RewardTrainer(Trainer): + def compute_loss(self, model, inputs, return_outputs=False): + rewards_j = model(input_ids=inputs["input_ids_j"], attention_mask=inputs["attention_mask_j"])[0] + rewards_k = model(input_ids=inputs["input_ids_k"], attention_mask=inputs["attention_mask_k"])[0] + loss = -nn.functional.logsigmoid(rewards_j - rewards_k).mean() + if return_outputs: + return loss, {"rewards_j": rewards_j, "rewards_k": rewards_k} + return loss +``` + +We utilize a subset of a 100,000 pair of candidates and evaluate on a held-out set of 50,000. With a modest training batch size of 4, we train the Llama model using the LoRA `peft` adapter for a single epoch using the Adam optimizer with BF16 precision. Our LoRA configuration is: + +```python +peft_config = LoraConfig( + task_type=TaskType.SEQ_CLS, + inference_mode=False, + r=8, + lora_alpha=32, + lora_dropout=0.1, +) +``` +As detailed in the next section, the resulting adapter can be merged into the frozen model and saved for further downstream use. + +## Reinforcement Learning from Human Feedback + +With the fine-tuned language model and the reward model at hand, we are now ready to run the RL loop. It follows roughly three steps: + +1. Generate responses from prompts, +2. Rate the responses with the reward model, +3. Run a reinforcement learning policy-optimization step with the ratings. + +The Query and Response prompts are templated as follows before being tokenized and passed to the model: + +```bash +Question: + +Answer: +``` + +The same template was used for SFT, RM and RLHF stages. +Once more, we utilize `peft` for memory-efficient training, which offers an extra advantage in the RLHF context. +Here, the reference model and policy share the same base, the SFT model, which we load in 8-bit and freeze during training. +We exclusively optimize the policy's LoRA weights using PPO while sharing the base model's weights. + +```python +for epoch, batch in tqdm(enumerate(ppo_trainer.dataloader)): + question_tensors = batch["input_ids"] + + # sample from the policy and to generate responses + response_tensors = ppo_trainer.generate( + question_tensors, + return_prompt=False, + length_sampler=output_length_sampler, + **generation_kwargs, + ) + batch["response"] = tokenizer.batch_decode(response_tensors, skip_special_tokens=True) + + # Compute sentiment score + texts = [q + r for q, r in zip(batch["query"], batch["response"])] + pipe_outputs = sentiment_pipe(texts, **sent_kwargs) + rewards = [torch.tensor(output[0]["score"] - script_args.reward_baseline) for output in pipe_outputs] + + # Run PPO step + stats = ppo_trainer.step(question_tensors, response_tensors, rewards) + # Log stats to Wandb + ppo_trainer.log_stats(stats, batch, rewards) +``` + +For the rest of the details and evaluation, please refer to our [blog post on StackLLaMA](https://huggingface.co/blog/stackllama). diff --git a/examples/README.md b/examples/README.md new file mode 100644 index 0000000000000000000000000000000000000000..37999e41abc02461a09ed7e29e39cc0bec15e488 --- /dev/null +++ b/examples/README.md @@ -0,0 +1,3 @@ +# Examples + +Please check out https://huggingface.co/docs/trl/example_overview for documentation on our examples. \ No newline at end of file diff --git a/examples/accelerate_configs/deepspeed_zero1.yaml b/examples/accelerate_configs/deepspeed_zero1.yaml new file mode 100644 index 0000000000000000000000000000000000000000..d5b5f782fb30f9fcbcc8fc58262f09eaf2e10368 --- /dev/null +++ b/examples/accelerate_configs/deepspeed_zero1.yaml @@ -0,0 +1,20 @@ +compute_environment: LOCAL_MACHINE +debug: false +deepspeed_config: + deepspeed_multinode_launcher: standard + gradient_accumulation_steps: 1 + zero3_init_flag: false + zero_stage: 1 +distributed_type: DEEPSPEED +downcast_bf16: 'no' +machine_rank: 0 +main_training_function: main +mixed_precision: 'bf16' +num_machines: 1 +num_processes: 8 +rdzv_backend: static +same_network: true +tpu_env: [] +tpu_use_cluster: false +tpu_use_sudo: false +use_cpu: false diff --git a/examples/accelerate_configs/deepspeed_zero2.yaml b/examples/accelerate_configs/deepspeed_zero2.yaml new file mode 100644 index 0000000000000000000000000000000000000000..390f7086e5b8759b2285c2bd18fb92b337f4ae27 --- /dev/null +++ b/examples/accelerate_configs/deepspeed_zero2.yaml @@ -0,0 +1,22 @@ +compute_environment: LOCAL_MACHINE +debug: false +deepspeed_config: + deepspeed_multinode_launcher: standard + gradient_accumulation_steps: 1 + offload_optimizer_device: none + offload_param_device: none + zero3_init_flag: false + zero_stage: 2 +distributed_type: DEEPSPEED +downcast_bf16: 'no' +machine_rank: 0 +main_training_function: main +mixed_precision: 'bf16' +num_machines: 1 +num_processes: 8 +rdzv_backend: static +same_network: true +tpu_env: [] +tpu_use_cluster: false +tpu_use_sudo: false +use_cpu: false diff --git a/examples/accelerate_configs/deepspeed_zero3.yaml b/examples/accelerate_configs/deepspeed_zero3.yaml new file mode 100644 index 0000000000000000000000000000000000000000..281c91e8a71e023d5d37f1e0cd9ec9f26b3e231c --- /dev/null +++ b/examples/accelerate_configs/deepspeed_zero3.yaml @@ -0,0 +1,23 @@ +compute_environment: LOCAL_MACHINE +debug: false +deepspeed_config: + deepspeed_multinode_launcher: standard + gradient_accumulation_steps: 1 + offload_optimizer_device: none + offload_param_device: none + zero3_init_flag: true + zero3_save_16bit_model: true + zero_stage: 3 +distributed_type: DEEPSPEED +downcast_bf16: 'no' +machine_rank: 0 +main_training_function: main +mixed_precision: 'bf16' +num_machines: 1 +num_processes: 8 +rdzv_backend: static +same_network: true +tpu_env: [] +tpu_use_cluster: false +tpu_use_sudo: false +use_cpu: false diff --git a/examples/accelerate_configs/multi_gpu.yaml b/examples/accelerate_configs/multi_gpu.yaml new file mode 100644 index 0000000000000000000000000000000000000000..15dad9be3ba44f7c934e1ecab98a93cb83cbc79a --- /dev/null +++ b/examples/accelerate_configs/multi_gpu.yaml @@ -0,0 +1,16 @@ +compute_environment: LOCAL_MACHINE +debug: false +distributed_type: MULTI_GPU +downcast_bf16: 'no' +gpu_ids: all +machine_rank: 0 +main_training_function: main +mixed_precision: 'bf16' +num_machines: 1 +num_processes: 8 +rdzv_backend: static +same_network: true +tpu_env: [] +tpu_use_cluster: false +tpu_use_sudo: false +use_cpu: false diff --git a/examples/hello_world.py b/examples/hello_world.py new file mode 100644 index 0000000000000000000000000000000000000000..138defb5b433ff43480e61a29e89b8e0233c6400 --- /dev/null +++ b/examples/hello_world.py @@ -0,0 +1,40 @@ +# 0. imports +import torch +from transformers import GPT2Tokenizer + +from trl import AutoModelForCausalLMWithValueHead, PPOConfig, PPOTrainer + + +# 1. load a pretrained model +model = AutoModelForCausalLMWithValueHead.from_pretrained("gpt2") +model_ref = AutoModelForCausalLMWithValueHead.from_pretrained("gpt2") +tokenizer = GPT2Tokenizer.from_pretrained("gpt2") +tokenizer.pad_token = tokenizer.eos_token + +# 2. initialize trainer +ppo_config = {"batch_size": 1} +config = PPOConfig(**ppo_config) +ppo_trainer = PPOTrainer(config, model, model_ref, tokenizer) + +# 3. encode a query +query_txt = "This morning I went to the " +query_tensor = tokenizer.encode(query_txt, return_tensors="pt").to(model.pretrained_model.device) + +# 4. generate model response +generation_kwargs = { + "min_length": -1, + "top_k": 0.0, + "top_p": 1.0, + "do_sample": True, + "pad_token_id": tokenizer.eos_token_id, + "max_new_tokens": 20, +} +response_tensor = ppo_trainer.generate([item for item in query_tensor], return_prompt=False, **generation_kwargs) +response_txt = tokenizer.decode(response_tensor[0]) + +# 5. define a reward for response +# (this could be any reward such as human feedback or output from another model) +reward = [torch.tensor(1.0, device=model.pretrained_model.device)] + +# 6. train model with ppo +train_stats = ppo_trainer.step([query_tensor[0]], [response_tensor[0]], reward) diff --git a/examples/notebooks/README.md b/examples/notebooks/README.md new file mode 100644 index 0000000000000000000000000000000000000000..f2a11e280099f79e30059ff77295d53eff30b62a --- /dev/null +++ b/examples/notebooks/README.md @@ -0,0 +1,7 @@ +# Notebooks + +This directory contains a collection of Jupyter notebooks that demonstrate how to use the TRL library in different applications. + +- [`best_of_n.ipynb`](https://github.com/huggingface/trl/tree/main/examples/notebooks/best_of_n.ipynb): This notebook demonstrates how to use the "Best of N" sampling strategy using TRL when fine-tuning your model with PPO. +- [`gpt2-sentiment.ipynb`](https://github.com/huggingface/trl/tree/main/examples/notebooks/gpt2-sentiment.ipynb): This notebook demonstrates how to reproduce the GPT2 imdb sentiment tuning example on a jupyter notebook. +- [`gpt2-control.ipynb`](https://github.com/huggingface/trl/tree/main/examples/notebooks/gpt2-sentiment-control.ipynb): This notebook demonstrates how to reproduce the GPT2 sentiment control example on a jupyter notebook. diff --git a/examples/notebooks/best_of_n.ipynb b/examples/notebooks/best_of_n.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..375cafe99f0ad77a902634e546d9199f824e04fb --- /dev/null +++ b/examples/notebooks/best_of_n.ipynb @@ -0,0 +1,648 @@ +{ + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "colab": { + "provenance": [] + }, + "kernelspec": { + "name": "python3", + "display_name": "Python 3" + }, + "language_info": { + "name": "python" + }, + "accelerator": "GPU", + "gpuClass": "standard" + }, + "cells": [ + { + "cell_type": "markdown", + "source": [ + "\n", + "**Best-of-n sampling as an alternative to RLHF**\n", + "\n", + "This notebook compares reward-model scores of prompt based responses from \n", + "1. a base model (`gpt2-imdb`)\n", + "2. `RLHF` tuned model based on this base-model \n", + "3. the base-model again from which we sample n responses to each prompt, score them and take the best scored one AKA the `best-of-n sampled` model\n", + "\n" + ], + "metadata": { + "id": "WQpNapZNWuXP" + } + }, + { + "cell_type": "markdown", + "source": [ + "Import dependencies\n" + ], + "metadata": { + "id": "Lo98lkdP66_x" + } + }, + { + "cell_type": "code", + "source": [ + "%pip install transformers trl" + ], + "metadata": { + "id": "vDA6qayz692w" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "import torch\n", + "import pandas as pd\n", + "from transformers import pipeline, AutoTokenizer\n", + "from datasets import load_dataset\n", + "\n", + "from trl import AutoModelForCausalLMWithValueHead\n", + "from trl.core import LengthSampler\n", + "\n", + "device = 0 if torch.cuda.is_available() else \"cpu\"" + ], + "metadata": { + "id": "M1s_iNm773hM" + }, + "execution_count": 2, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "Various constants" + ], + "metadata": { + "id": "Y7hyrIrO8tcY" + } + }, + { + "cell_type": "code", + "source": [ + "ref_model_name = \"lvwerra/gpt2-imdb\"\n", + "model_name = \"lvwerra/gpt2-imdb-pos-v2\"\n", + "reward_model = \"lvwerra/distilbert-imdb\"\n", + "\n", + "N_BEST_OF = 4" + ], + "metadata": { + "id": "MqS3OM6Q8x6g" + }, + "execution_count": 3, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "Models and tokenizers " + ], + "metadata": { + "id": "c1YcXeElg6or" + } + }, + { + "cell_type": "code", + "source": [ + "model = AutoModelForCausalLMWithValueHead.from_pretrained(model_name)\n", + "\n", + "ref_model = AutoModelForCausalLMWithValueHead.from_pretrained(ref_model_name)\n", + "\n", + "reward_pipe = pipeline(\"sentiment-analysis\", model=reward_model, device=device)\n", + "\n", + "tokenizer = AutoTokenizer.from_pretrained(ref_model_name)\n", + "\n", + "tokenizer.pad_token = tokenizer.eos_token\n", + "\n", + "# cuda-ize models\n", + "model.cuda()\n", + "ref_model.cuda()" + ], + "metadata": { + "id": "b855NrL181Hh" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "Dataset building" + ], + "metadata": { + "id": "Z1Cz0gCFhZYJ" + } + }, + { + "cell_type": "code", + "source": [ + "def build_dataset(tokenizer, dataset_name=\"imdb\", input_min_text_length=2, input_max_text_length=8):\n", + " # load imdb with datasets\n", + " ds = load_dataset(dataset_name, split=\"train\")\n", + " ds = ds.rename_columns({\"text\": \"review\"})\n", + " ds = ds.filter(lambda x: len(x[\"review\"]) > 200, batched=False)\n", + "\n", + " input_size = LengthSampler(input_min_text_length, input_max_text_length)\n", + "\n", + " def tokenize(sample):\n", + " sample[\"input_ids\"] = tokenizer.encode(sample[\"review\"])[: input_size()]\n", + " sample[\"query\"] = tokenizer.decode(sample[\"input_ids\"])\n", + " return sample\n", + "\n", + " ds = ds.map(tokenize, batched=False)\n", + " ds.set_format(type=\"torch\")\n", + " return ds\n", + "\n", + "\n", + "dataset = build_dataset(tokenizer)" + ], + "metadata": { + "id": "LqLVEp5p_8XM" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "gen_kwargs = {\"min_length\": -1, \"top_k\": 0.0, \"top_p\": 1.0, \"do_sample\": True, \"pad_token_id\": tokenizer.eos_token_id}\n", + "sent_kwargs = {\"top_k\": None, \"function_to_apply\": \"none\", \"batch_size\": 16}" + ], + "metadata": { + "id": "AqA2McjMAxNw" + }, + "execution_count": 6, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "output_min_length = 4\n", + "output_max_length = 16\n", + "output_length_sampler = LengthSampler(output_min_length, output_max_length)\n", + "\n", + "#### get a batch from the dataset\n", + "bs = 16\n", + "output_data = dict()\n", + "dataset.set_format(\"pandas\")\n", + "df_batch = dataset[:].sample(bs)\n", + "output_data[\"query\"] = df_batch[\"query\"].tolist()\n", + "query_tensors = df_batch[\"input_ids\"].tolist()\n", + "\n", + "# :: [Resp]\n", + "response_tensors_ref, response_tensors = [], []\n", + "# :: [[Resp]]\n", + "response_tensors_best_of = []" + ], + "metadata": { + "id": "L_q4qs35AxcR" + }, + "execution_count": 7, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "\n", + "Generation using various models" + ], + "metadata": { + "id": "QVfpyHnZBLKY" + } + }, + { + "cell_type": "code", + "source": [ + "for i in range(bs):\n", + " gen_len = output_length_sampler()\n", + "\n", + " query = torch.tensor(query_tensors[i])\n", + "\n", + " output = ref_model.generate(query.unsqueeze(dim=0).to(device), max_new_tokens=gen_len, **gen_kwargs).squeeze()\n", + " response_tensors_ref.append(tokenizer.decode(output))\n", + "\n", + " output = model.generate(query.unsqueeze(dim=0).to(device), max_new_tokens=gen_len, **gen_kwargs).squeeze()\n", + " response_tensors.append(tokenizer.decode(output))\n", + "\n", + " # generating copies of the same query for the Best-of-n sampling\n", + " queries = query.repeat((N_BEST_OF, 1))\n", + " output = ref_model.generate(queries.to(device), max_new_tokens=gen_len, **gen_kwargs).squeeze()\n", + " response_tensors_best_of.append(tokenizer.batch_decode(output))" + ], + "metadata": { + "id": "-imZ7uEFBNbw" + }, + "execution_count": 8, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "Scoring" + ], + "metadata": { + "id": "Jp5FC0Y5h_Sf" + } + }, + { + "cell_type": "code", + "source": [ + "scores_ref = [output[0][\"score\"] for output in reward_pipe(response_tensors_ref, **sent_kwargs)]\n", + "scores = [output[0][\"score\"] for output in reward_pipe(response_tensors, **sent_kwargs)]\n", + "scores_best_of = []\n", + "for i, response in enumerate(response_tensors_best_of):\n", + " # base_score = scores_ref[i]\n", + " scores_best_of.append(torch.tensor([output[0][\"score\"] for output in reward_pipe(response, **sent_kwargs)]))" + ], + "metadata": { + "id": "PyDbbAQ0F_h7" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "output_data[\"response (ref)\"] = response_tensors_ref\n", + "output_data[\"scores (ref)\"] = scores_ref\n", + "output_data[\"response (RLHF)\"] = response_tensors\n", + "output_data[\"scores (RLHF)\"] = scores\n", + "output_data[\"response (best_of)\"] = [\n", + " response_tensors_best_of[i][a.argmax().item()] for i, a in enumerate(scores_best_of)\n", + "]\n", + "output_data[\"scores (best_of)\"] = [a.max().item() for a in scores_best_of]\n", + "\n", + "\n", + "# store results in a dataframe\n", + "df_results = pd.DataFrame(output_data)\n", + "df_results" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 682 + }, + "id": "nA1GDNJEiGm-", + "outputId": "1389c686-0751-4304-dea2-b71fd68748e1" + }, + "execution_count": 10, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + " query \\\n", + "0 I'm a pretty old \n", + "1 One of the most \n", + "2 Okay, as \n", + "3 Watching \"Kro \n", + "4 Seriously what were they thinking? \n", + "5 OK Hollywood \n", + "6 \"Bend It \n", + "7 While the premise behind The House \n", + "8 Well let me go \n", + "9 Vijay Krishna Acharya \n", + "10 Watching this movie made me \n", + "11 There are probably \n", + "12 Meryl Stre \n", + "13 I thought I read somewhere that \n", + "14 Good movie, very \n", + "15 It was agonizing \n", + "\n", + " response (ref) scores (ref) \\\n", + "0 I'm a pretty old kid, well, with lots of girl 1.179652 \n", + "1 One of the most psychologically devastating as... 2.477277 \n", + "2 Okay, as ruthless as they are, even their leve... 1.466462 \n", + "3 Watching \"Kroger\" (1915- 0.186047 \n", + "4 Seriously what were they thinking? It ain't go... 1.010697 \n", + "5 OK Hollywood goes into a total game of audio, ... 0.934041 \n", + "6 \"Bend It, Luther, Dodge, Church Goes to Rome w... 0.039218 \n", + "7 While the premise behind The House of Dracula ... -0.079306 \n", + "8 Well let me go...I don't want to movie it. I'm... 1.015246 \n", + "9 Vijay Krishna Acharya Sawai (Elverling). She was 0.341506 \n", + "10 Watching this movie made me poorly appreciate ... 1.574047 \n", + "11 There are probably more but if you had never s... -0.047099 \n", + "12 Meryl Streep's version of 0.373884 \n", + "13 I thought I read somewhere that the Lord had c... 0.091776 \n", + "14 Good movie, very funny, acting is very good.<|... 2.408837 \n", + "15 It was agonizing, and it made me wonder 1.240262 \n", + "\n", + " response (RLHF) scores (RLHF) \\\n", + "0 I'm a pretty old lady, and I loved this movie ... 2.218363 \n", + "1 One of the most Antibiotic Apps I have seen in 2.145479 \n", + "2 Okay, as I enjoyed the movie. It's added bonus... 2.239827 \n", + "3 Watching \"Kroven\". The film has a 1.044690 \n", + "4 Seriously what were they thinking? It's a very... 2.753088 \n", + "5 OK Hollywood shoot, and this is a classic. Som... 2.517364 \n", + "6 \"Bend It all\" is a sophisticated, drawing and ... 2.583935 \n", + "7 While the premise behind The House Intelligenc... 0.205217 \n", + "8 Well let me go through everything says it's a ... 2.727040 \n", + "9 Vijay Krishna Acharya is a perfect performance... 2.563642 \n", + "10 Watching this movie made me sleep better. It w... 1.690222 \n", + "11 There are probably random man only recently wh... 0.398258 \n", + "12 Meryl Streitz, who is 0.085154 \n", + "13 I thought I read somewhere that my thoughts, a... 1.833734 \n", + "14 Good movie, very much fuzz and logical based w... 2.325996 \n", + "15 It was agonizing because it was truly fun to 0.969669 \n", + "\n", + " response (best_of) scores (best_of) \n", + "0 I'm a pretty old, stinking,acting kinda chick ... 2.016955 \n", + "1 One of the most memorable performances of this... 2.676944 \n", + "2 Okay, as I put it in such a negative mood, it ... 1.478424 \n", + "3 Watching \"Kro\" is an entertainment craze 1.389495 \n", + "4 Seriously what were they thinking? It was stil... 2.523514 \n", + "5 OK Hollywood pay and the freaky set-up of this... 1.634765 \n", + "6 \"Bend It 9\"/\"Zara Pephoto\") and an honest, rea... 2.557210 \n", + "7 While the premise behind The House of Dracula ... 1.676889 \n", + "8 Well let me go though, alive in this ever grow... 2.652859 \n", + "9 Vijay Krishna Acharya adeptly emerges, and the... 2.308076 \n", + "10 Watching this movie made me curious: what did ... 0.950836 \n", + "11 There are probably too many documentaries in s... 1.142725 \n", + "12 Meryl Streep performed an awe 1.932498 \n", + "13 I thought I read somewhere that The Odd Couple... 0.475951 \n", + "14 Good movie, very well polished, nicely written... 2.820022 \n", + "15 It was agonizing, poignant, and worst of 2.058277 " + ], + "text/html": [ + "\n", + "
\n", + "
\n", + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
queryresponse (ref)scores (ref)response (RLHF)scores (RLHF)response (best_of)scores (best_of)
0I'm a pretty oldI'm a pretty old kid, well, with lots of girl1.179652I'm a pretty old lady, and I loved this movie ...2.218363I'm a pretty old, stinking,acting kinda chick ...2.016955
1One of the mostOne of the most psychologically devastating as...2.477277One of the most Antibiotic Apps I have seen in2.145479One of the most memorable performances of this...2.676944
2Okay, asOkay, as ruthless as they are, even their leve...1.466462Okay, as I enjoyed the movie. It's added bonus...2.239827Okay, as I put it in such a negative mood, it ...1.478424
3Watching \"KroWatching \"Kroger\" (1915-0.186047Watching \"Kroven\". The film has a1.044690Watching \"Kro\" is an entertainment craze1.389495
4Seriously what were they thinking?Seriously what were they thinking? It ain't go...1.010697Seriously what were they thinking? It's a very...2.753088Seriously what were they thinking? It was stil...2.523514
5OK HollywoodOK Hollywood goes into a total game of audio, ...0.934041OK Hollywood shoot, and this is a classic. Som...2.517364OK Hollywood pay and the freaky set-up of this...1.634765
6\"Bend It\"Bend It, Luther, Dodge, Church Goes to Rome w...0.039218\"Bend It all\" is a sophisticated, drawing and ...2.583935\"Bend It 9\"/\"Zara Pephoto\") and an honest, rea...2.557210
7While the premise behind The HouseWhile the premise behind The House of Dracula ...-0.079306While the premise behind The House Intelligenc...0.205217While the premise behind The House of Dracula ...1.676889
8Well let me goWell let me go...I don't want to movie it. I'm...1.015246Well let me go through everything says it's a ...2.727040Well let me go though, alive in this ever grow...2.652859
9Vijay Krishna AcharyaVijay Krishna Acharya Sawai (Elverling). She was0.341506Vijay Krishna Acharya is a perfect performance...2.563642Vijay Krishna Acharya adeptly emerges, and the...2.308076
10Watching this movie made meWatching this movie made me poorly appreciate ...1.574047Watching this movie made me sleep better. It w...1.690222Watching this movie made me curious: what did ...0.950836
11There are probablyThere are probably more but if you had never s...-0.047099There are probably random man only recently wh...0.398258There are probably too many documentaries in s...1.142725
12Meryl StreMeryl Streep's version of0.373884Meryl Streitz, who is0.085154Meryl Streep performed an awe1.932498
13I thought I read somewhere thatI thought I read somewhere that the Lord had c...0.091776I thought I read somewhere that my thoughts, a...1.833734I thought I read somewhere that The Odd Couple...0.475951
14Good movie, veryGood movie, very funny, acting is very good.<|...2.408837Good movie, very much fuzz and logical based w...2.325996Good movie, very well polished, nicely written...2.820022
15It was agonizingIt was agonizing, and it made me wonder1.240262It was agonizing because it was truly fun to0.969669It was agonizing, poignant, and worst of2.058277
\n", + "
\n", + " \n", + " \n", + " \n", + "\n", + " \n", + "
\n", + "
\n", + " " + ] + }, + "metadata": {}, + "execution_count": 10 + } + ] + } + ] +} \ No newline at end of file diff --git a/examples/notebooks/gpt2-sentiment-control.ipynb b/examples/notebooks/gpt2-sentiment-control.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..de00502d3dc645ee3dd3f9aed7a6469175b12d40 --- /dev/null +++ b/examples/notebooks/gpt2-sentiment-control.ipynb @@ -0,0 +1,860 @@ +{ + "cells": [ + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Tune GPT2 to generate controlled sentiment reviews\n", + "> Optimise GPT2 to produce IMDB movie reviews with controlled sentiment using a BERT sentiment classifier for rewards.\n", + "\n", + "**WARNING:** We often experienced loss spikes in this examples which caused model training to fail or slow down. There is a [GitHub issue](https://github.com/lvwerra/trl/issues/101) to track the issue." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "
\n", + "\n", + "

Figure: Experiment setup to tune GPT2. The yellow arrows are outside the scope of this notebook, but the trained models are available through Hugging Face.

\n", + "
\n", + "\n", + "\n", + "The experiment setup is very similar to the positive sentiment notebook. However, in this notebook we fine-tune GPT2 (small) to generate **controlled** movie reviews based on the IMDB dataset. The model gets the target sentiment and 5 tokens from a real review and is tasked to produce continuations with the targeted sentiment. The reward for the continuations is calculated with the logits of a BERT sentiment classifier. That reward is then used for PPO training." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Setup experiment" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Import dependencies" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "%load_ext autoreload\n", + "%autoreload 2" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/leandro_huggingface_co/miniconda3/envs/trl/lib/python3.9/site-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + " from .autonotebook import tqdm as notebook_tqdm\n" + ] + } + ], + "source": [ + "import random\n", + "import torch\n", + "import wandb\n", + "import time\n", + "import os\n", + "from tqdm import tqdm\n", + "import numpy as np\n", + "import pandas as pd\n", + "from random import choices\n", + "import matplotlib.pyplot as plt\n", + "\n", + "tqdm.pandas()\n", + "\n", + "from datasets import load_dataset\n", + "\n", + "from transformers import AutoTokenizer, pipeline\n", + "\n", + "from trl import PPOTrainer, PPOConfig, AutoModelForCausalLMWithValueHead, create_reference_model" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Configuration" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "sentiment_pipe_kwargs = {\"top_k\": None, \"function_to_apply\": \"none\"}\n", + "\n", + "config = PPOConfig(\n", + " model_name=\"lvwerra/gpt2-imdb\", steps=51200, learning_rate=1.41e-5, remove_unused_columns=False, log_with=\"wandb\"\n", + ")\n", + "\n", + "txt_in_len = 5\n", + "txt_out_len = 20\n", + "seed = 1" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "np.random.seed(seed)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "You can see that we load a GPT2 model called `gpt2_imdb`. This model was additionally fine-tuned on the IMDB dataset for 1 epoch with the huggingface [script](https://github.com/huggingface/transformers/blob/master/examples/run_language_modeling.py) (no special settings). The other parameters are mostly taken from the original paper [\"Fine-Tuning Language Models from Human Preferences\"](\n", + "https://arxiv.org/pdf/1909.08593.pdf). This model as well as the BERT model is available in the Huggingface model zoo [here](https://huggingface.co/models). The following code should automatically download the models." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Load data and models" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Load pre-trained GPT2 language models" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We load the GPT2 model with a value head and the tokenizer. We load the model twice; the first model is optimized while the second model serves as a reference to calculate the KL-divergence from the starting point. This serves as an additional reward signal in the PPO training to make sure the optimized model does not deviate too much from the original language model." + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "gpt2_model = AutoModelForCausalLMWithValueHead.from_pretrained(config.model_name)\n", + "gpt2_model_ref = create_reference_model(gpt2_model)\n", + "gpt2_tokenizer = AutoTokenizer.from_pretrained(config.model_name)\n", + "\n", + "gpt2_tokenizer.pad_token = gpt2_tokenizer.eos_token" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Load IMDB dataset\n", + "The IMDB dataset contains 50k movie review annotated with \"positive\"/\"negative\" feedback indicating the sentiment. We load the IMDB dataset into a DataFrame and filter for comments that are at least 500 characters long and take the first 1000 characters of each comment. The first filter we apply to avoid comments that are less than `txt_in_len` token long and the second to avoid tokenizing way more text than we actually need." + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Found cached dataset imdb (/home/leandro_huggingface_co/.cache/huggingface/datasets/imdb/plain_text/1.0.0/2fdd8b9bcadd6e7055e742a706876ba43f19faee861df134affd7a3f60fc38a1)\n", + "Loading cached processed dataset at /home/leandro_huggingface_co/.cache/huggingface/datasets/imdb/plain_text/1.0.0/2fdd8b9bcadd6e7055e742a706876ba43f19faee861df134affd7a3f60fc38a1/cache-d314b4c14499bf03.arrow\n", + "Loading cached processed dataset at /home/leandro_huggingface_co/.cache/huggingface/datasets/imdb/plain_text/1.0.0/2fdd8b9bcadd6e7055e742a706876ba43f19faee861df134affd7a3f60fc38a1/cache-0d5fcb05c95b1186.arrow\n" + ] + }, + { + "data": { + "text/plain": [ + "Dataset({\n", + " features: ['review', 'sentiment'],\n", + " num_rows: 22578\n", + "})" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# create the dataset\n", + "#\n", + "dataset = load_dataset(\"imdb\", split=\"train\")\n", + "dataset = dataset.rename_columns({\"text\": \"review\", \"label\": \"sentiment\"})\n", + "# make sure the comments are are at least 500 and trim to 1000\n", + "dataset = dataset.filter(lambda x: len(x[\"review\"]) > 500, batched=False)\n", + "dataset = dataset.map(lambda x: {\"review\": x[\"review\"][:1000]}, batched=False)\n", + "\n", + "dataset" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Tokenize IMDB reviews" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We tokenize all IMDB in advance to avoid tokenizing twice. In the first step we encode the queries and slice the first `txt_in_len` tokens. In a second step we decode these tokens back to text for later display." + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Loading cached processed dataset at /home/leandro_huggingface_co/.cache/huggingface/datasets/imdb/plain_text/1.0.0/2fdd8b9bcadd6e7055e742a706876ba43f19faee861df134affd7a3f60fc38a1/cache-383f6ebf0ae41ee4.arrow\n", + "Loading cached processed dataset at /home/leandro_huggingface_co/.cache/huggingface/datasets/imdb/plain_text/1.0.0/2fdd8b9bcadd6e7055e742a706876ba43f19faee861df134affd7a3f60fc38a1/cache-f4875ad4fccbbc1f.arrow\n" + ] + } + ], + "source": [ + "dataset = dataset.map(\n", + " lambda x: {\"input_ids\": gpt2_tokenizer.encode(\" \" + x[\"review\"], return_tensors=\"pt\")[0, :txt_in_len]},\n", + " batched=False,\n", + ")\n", + "dataset = dataset.map(lambda x: {\"query\": gpt2_tokenizer.decode(x[\"input_ids\"])}, batched=False)\n", + "dataset = dataset[:20480]\n", + "\n", + "from datasets import Dataset\n", + "\n", + "dataset = Dataset.from_dict(dataset)\n", + "dataset.set_format(\"pytorch\")" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([ 770, 2646, 373, 2192, 7867])" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "dataset[3][\"input_ids\"]" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [], + "source": [ + "def collator(data):\n", + " return dict((key, [d[key] for d in data]) for key in data[0])" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.\n", + "\u001b[34m\u001b[1mwandb\u001b[0m: Currently logged in as: \u001b[33mlvwerra\u001b[0m. Use \u001b[1m`wandb login --relogin`\u001b[0m to force relogin\n" + ] + }, + { + "data": { + "text/html": [ + "Tracking run with wandb version 0.13.9" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "Run data is saved locally in /home/leandro_huggingface_co/trl/examples/sentiment/notebooks/wandb/run-20230206_125743-jpcnr7jx" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "Syncing run comic-music-184 to Weights & Biases (docs)
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + " View project at https://wandb.ai/lvwerra/trl" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + " View run at https://wandb.ai/lvwerra/trl/runs/jpcnr7jx" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "ppo_trainer = PPOTrainer(config, gpt2_model, gpt2_model_ref, gpt2_tokenizer, dataset, data_collator=collator)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Load BERT classifier\n", + "We load a BERT classifier fine-tuned on the IMDB dataset." + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [], + "source": [ + "if ppo_trainer.accelerator.num_processes == 1:\n", + " device = 0 if torch.cuda.is_available() else \"cpu\" # to avoid a `pipeline` bug\n", + "else:\n", + " device = ppo_trainer.accelerator.device\n", + "sentiment_pipe = pipeline(\"sentiment-analysis\", \"lvwerra/distilbert-imdb\", device=device)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The model outputs are the logits for the negative and positive class. We will use the logits for positive class as a reward signal for the language model." + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[{'label': 'NEGATIVE', 'score': 2.3350484371185303},\n", + " {'label': 'POSITIVE', 'score': -2.726576328277588}]" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "text = \"this movie was really bad!!\"\n", + "output = sentiment_pipe(text, **sentiment_pipe_kwargs)\n", + "output" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[{'label': 'POSITIVE', 'score': 2.557040214538574},\n", + " {'label': 'NEGATIVE', 'score': -2.294790267944336}]" + ] + }, + "execution_count": 13, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "text = \"this movie was really good!!\"\n", + "output = sentiment_pipe(text, **sentiment_pipe_kwargs)\n", + "output" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[{'label': 'POSITIVE', 'score': 0.8562759160995483},\n", + " {'label': 'NEGATIVE', 'score': -0.7086048126220703}]" + ] + }, + "execution_count": 14, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "text = \"this movie was a documentary\"\n", + "output = sentiment_pipe(text, **sentiment_pipe_kwargs)\n", + "output" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The resulting reward signal:" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [], + "source": [ + "def extract_pipe_output(outputs):\n", + " positive_logits = []\n", + " for out in outputs:\n", + " for element in out:\n", + " if element[\"label\"] == \"POSITIVE\":\n", + " positive_logits.append(torch.tensor(element[\"score\"]))\n", + " return positive_logits" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "-0.7086048126220703" + ] + }, + "execution_count": 16, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "output[1][\"score\"]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Control token dict\n", + "We will append the control token at the beginning of each query to signal the model what the target sentiment is. Each control sequence consists of three tokens:" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [], + "source": [ + "ctrl_str = [\"[negative]\", \"[neutral]\", \"[positive]\"]\n", + "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\") # this should be handled by accelerate\n", + "ctrl_tokens = dict((s, gpt2_tokenizer.encode(s, return_tensors=\"pt\").squeeze().to(device)) for s in ctrl_str)" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'[negative]': tensor([ 58, 31591, 60], device='cuda:0'),\n", + " '[neutral]': tensor([ 58, 29797, 60], device='cuda:0'),\n", + " '[positive]': tensor([ 58, 24561, 60], device='cuda:0')}" + ] + }, + "execution_count": 18, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "ctrl_tokens" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Reward function" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [], + "source": [ + "def pos_logit_to_reward(logit, task):\n", + " \"\"\"\n", + " Take the positive sentiment logit and scale it for the task.\n", + " task [negative]: reward = -logit\n", + " task [neutral]: reward = -2*abs(logit)+4\n", + " task [positive]: reward = logit\n", + " \"\"\"\n", + " for i in range(len(logit)):\n", + " if task[i] == \"[negative]\":\n", + " logit[i] = -logit[i]\n", + " elif task[i] == \"[neutral]\":\n", + " logit[i] = -2 * torch.abs(logit[i]) + 4\n", + " elif task[i] == \"[positive]\":\n", + " pass\n", + " else:\n", + " raise ValueError(\"task has to be in [0, 1, 2]!\")\n", + " return logit" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The following examples show the rewards for the cases where the classifier logit is 4, -4 and 0 for the three targets `['negative]`, `['neutral]` and `['positive']`. The scaling is not perfect as it differs between neutral and the other two classes. This is something to further investigate in the future. Ideally, one would use the logit output for each class individually, but since there is no dedicated class for neutral this is a workaround." + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "['[negative]', '[neutral]', '[positive]']\n" + ] + } + ], + "source": [ + "print(ctrl_str)" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([-4., -4., 4.])" + ] + }, + "execution_count": 21, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "pos_logit_to_reward(torch.Tensor([4, 4, 4]), ctrl_str)" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([ 4., -4., -4.])" + ] + }, + "execution_count": 22, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "pos_logit_to_reward(torch.Tensor([-4, -4, -4]), ctrl_str)" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([-0., 4., 0.])" + ] + }, + "execution_count": 23, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "pos_logit_to_reward(torch.Tensor([0, 0, 0]), ctrl_str)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Generation settings" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "metadata": {}, + "outputs": [], + "source": [ + "generation_kwargs = {\n", + " \"min_length\": -1,\n", + " \"top_k\": 0.0,\n", + " \"top_p\": 1.0,\n", + " \"do_sample\": True,\n", + " \"pad_token_id\": gpt2_tokenizer.eos_token_id,\n", + " \"max_new_tokens\": txt_out_len,\n", + " \"eos_token_id\": -1,\n", + "}" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Optimize model" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "**Steps**\n", + "\n", + "The training loop consists of the following steps:\n", + "1. Get a batch of queries and create random controls\n", + "2. Get the query responses from the policy\n", + "3. Join query and responses and tokenize for BERT analysis\n", + "4. Get sentiments for query/responses from BERT\n", + "5. Optimize policy with PPO using the (query, response, reward) triplet\n", + "6. Log all the training statistics\n", + "\n", + "**Training time**\n", + "\n", + "This step takes **~2h** on a P6000 GPU with the above specified settings." + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 8%|▊ | 6/80 [12:44<2:37:54, 128.03s/it]/home/leandro_huggingface_co/miniconda3/envs/trl/lib/python3.9/site-packages/transformers/pipelines/base.py:1045: UserWarning: You seem to be using the pipelines sequentially on GPU. In order to maximize efficiency please use a dataset\n", + " warnings.warn(\n", + "100%|██████████| 80/80 [2:46:39<00:00, 124.99s/it] \n", + " 91%|█████████▏| 73/80 [2:30:39<14:35, 125.03s/it] " + ] + } + ], + "source": [ + "for epoch in range(2):\n", + " for batch in tqdm(ppo_trainer.dataloader):\n", + " (logs, game_data,) = (\n", + " dict(),\n", + " dict(),\n", + " )\n", + "\n", + " #### prepend a random control token\n", + " task_list = choices(ctrl_str, k=config.batch_size)\n", + " game_data[\"query\"] = [t + q for t, q in zip(task_list, batch[\"query\"])]\n", + " query_tensors = [torch.cat((ctrl_tokens[t], input_ids)) for t, input_ids in zip(task_list, batch[\"input_ids\"])]\n", + "\n", + " #### get response from gpt2\n", + " response_tensors = []\n", + " for query in query_tensors:\n", + " response = ppo_trainer.generate(query, **generation_kwargs)\n", + " response_tensors.append(response.squeeze()[-txt_out_len:])\n", + " game_data[\"response\"] = [gpt2_tokenizer.decode(r.squeeze()) for r in response_tensors]\n", + "\n", + " #### sentiment analysis\n", + " texts = [q + r for q, r in zip(batch[\"query\"], game_data[\"response\"])]\n", + " logits = extract_pipe_output(sentiment_pipe(texts, **sentiment_pipe_kwargs))\n", + " rewards = pos_logit_to_reward(logits, task_list)\n", + "\n", + " #### Run PPO training\n", + " t = time.time()\n", + " stats = ppo_trainer.step(query_tensors, response_tensors, rewards)\n", + "\n", + " for cs in ctrl_str:\n", + " key = \"env/reward_\" + cs.strip(\"[]\")\n", + " stats[key] = np.mean([r.cpu().numpy() for r, t in zip(rewards, task_list) if t == cs])\n", + " ppo_trainer.log_stats(stats, game_data, rewards)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Training progress\n", + "If you are tracking the training progress with Weights&Biases you should see a plot similar to the following:\n", + "\n", + "
\n", + "\n", + "

Figure: Reward mean and distribution evolution during training.

\n", + "
\n", + "\n", + "One can observe how the model starts to generate more positive outputs after a few optimisation steps.\n", + "\n", + "> Note: Investigating the KL-divergence will probably show that at this point the model has not converged to the target KL-divergence, yet. To get there would require longer training or starting with a higher inital coefficient." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Model inspection" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Reward distribution\n", + "First, we can have a look at the reward distribution. Both the negative and positive rewards are clearly shifted to high rewards. The neutral rewards, however, are still centered around zero. There are a few possible explanations for this. There could be a bug in the code and the way the neutral rewards are calculated. Another problem could be that sentence sometimes start with a strong sentiment and it is hard for the model shift the sentiment towards neutral." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAiwAAAGzCAYAAAAMr0ziAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjYuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8o6BhiAAAACXBIWXMAAA9hAAAPYQGoP6dpAABPCUlEQVR4nO3deVwVZf8//tecw4HDroiyibKImSmQkKi5B6K3mt4tLvm4RSq7S7lvjTtNLAVcPqip0aLZnbdL3ZK0qP2+5o0SSVmiFor7lklubGqIgB4OnPn9YWfyyGE5h+UM8Ho+Hjw8c80117znOoPzZuaaGUEURRFEREREMqawdABEREREdWHCQkRERLLHhIWIiIhkjwkLERERyR4TFiIiIpI9JixEREQke0xYiIiISPaYsBAREZHsMWEhIiIi2WPCQkQNtmnTJgiCgNzcXLOWnzZtGnx8fAzKBEFAQkJCg2OrS2ZmJgRBQGZmplQ2dOhQ9OrVq8nXDQC5ubkQBAGbNm1qlvURtVRMWIio1UhJSUFycrKlwzBKzrERtQRWlg6AiMiYO3fuwMrKtP+iUlJScOLECcyePbveywwePBh37tyBtbW1iRGapqbYunbtijt37kClUjXp+olaOp5hIZKBsrIyS4dQK51Oh7t37zbrOtVqtckJiynu3r0LnU4HhUIBtVoNhcIy/x0KggC1Wg2lUmmR9RO1FExYiJpZQkICBEHAqVOn8Nxzz6F9+/YYOHCgNP+///0vQkJCYGtrCxcXF0yaNAmXL1+W5r/77rtQKpUoLi6WylatWgVBEBAbGyuVVVVVwdHREa+//rpUtnLlSgwYMAAdOnSAra0tQkJC8MUXX1SLURAExMTEYMuWLXjkkUdgY2ODtLQ0AMDJkycxfPhw2NraonPnzliyZAl0Ol29t3/Hjh3o1asX1Go1evXqhe3btxut9+AYltu3b2P27Nnw8fGBjY0NOnXqhIiICBw+fBjAvXEnX3/9NX777TcIggBBEKRxMfpxKlu3bsWbb74JLy8v2NnZoaSkxOgYFr3s7GwMGDAAtra28PX1xbp16wzm1zR258E2a4utpjEs3377LQYNGgR7e3u0a9cO48aNw+nTpw3q6PelX375BdOmTUO7du3g7OyM6OholJeX1/wlELVAvCREZCHPPvssAgIC8H//938QRREAsHTpUixYsAATJkzAiy++iKKiIrz33nsYPHgwjhw5gnbt2mHQoEHQ6XT44YcfMGbMGADAvn37oFAosG/fPqn9I0eOoLS0FIMHD5bK3nnnHTz55JOYMmUKKioqsHXrVjz77LPYuXMnRo8ebRDft99+i88++wwxMTFwdXWFj48P8vPzMWzYMFRWVmLevHmwt7fHv//9b9ja2tZrm/fs2YOnn34aPXv2RFJSEm7cuIHo6Gh07ty5zmVffvllfPHFF4iJiUHPnj1x48YN/PDDDzh9+jT69OmDN954A7du3cKVK1fw9ttvAwAcHBwM2li8eDGsra3x2muvQaPR1HoZ6Pfff8df/vIXTJgwAZMnT8Znn32GV155BdbW1nj++efrtb169Yntft988w1GjRoFPz8/JCQk4M6dO3jvvffw+OOP4/Dhw9UGKE+YMAG+vr5ISkrC4cOHsX79enTq1AnLly83KU4iWROJqFnFx8eLAMTJkycblOfm5opKpVJcunSpQfnx48dFKysrqbyqqkp0cnIS586dK4qiKOp0OrFDhw7is88+KyqVSvH27duiKIri6tWrRYVCIf7+++9SW+Xl5QZtV1RUiL169RKHDx9uUA5AVCgU4smTJw3KZ8+eLQIQDx48KJUVFhaKzs7OIgDx4sWLtW57cHCw6OHhIRYXF0tle/bsEQGIXbt2rRZDfHy8NO3s7CzOnDmz1vZHjx5drR1RFMW9e/eKAEQ/P79qfaCft3fvXqlsyJAhIgBx1apVUplGoxGDg4PFTp06iRUVFaIoiuLGjRuNbrexNmuK7eLFiyIAcePGjVKZfj03btyQyo4ePSoqFApx6tSpUpl+X3r++ecN2vzrX/8qdujQodq6iFoyXhIispCXX37ZYHrbtm3Q6XSYMGECrl+/Lv24u7sjICAAe/fuBQAoFAoMGDAA33//PQDg9OnTuHHjBubNmwdRFJGVlQXg3lmXXr16oV27dtI67j8T8vvvv+PWrVsYNGiQdFnlfkOGDEHPnj0Nynbt2oV+/fqhb9++UlnHjh0xZcqUOrc3Ly8POTk5iIqKgrOzs1QeERFRbT3GtGvXDgcPHsS1a9fqrFuTqKioep8NsrKywt///ndp2traGn//+99RWFiI7Oxss2Ooi76fpk2bBhcXF6k8MDAQERER2LVrV7VlHtyXBg0ahBs3bqCkpKTJ4iRqbkxYiCzE19fXYPr8+fMQRREBAQHo2LGjwc/p06dRWFgo1R00aBCys7Nx584d7Nu3Dx4eHujTpw+CgoKky0I//PADBg0aZLCOnTt3ol+/flCr1XBxcUHHjh3xwQcf4NatW3XGBwC//fYbAgICqpU/9NBDdW7vb7/9BgBmL79ixQqcOHEC3t7e6Nu3LxISEvDrr7/Wudz9jG1TTTw9PWFvb29Q1r17dwAw+3kz9aHvJ2N98vDDD+P69evVBml36dLFYLp9+/YA7iWlRK0Fx7AQWciDf+nrdDoIgoD//e9/Ru8YuX/Mw8CBA6HVapGVlYV9+/ZJicmgQYOwb98+nDlzBkVFRQYJy759+/Dkk09i8ODBWLt2LTw8PKBSqbBx40akpKTUGZ+lTZgwAYMGDcL27duxZ88evPXWW1i+fDm2bduGUaNG1auNxt4mQRCMlldVVTXqeupS0x1G4h9jo4haAyYsRDLh7+8PURTh6+sr/SVfk759+8La2hr79u3Dvn37MGfOHAD3niny0UcfISMjQ5rW+/LLL6FWq7F7927Y2NhI5Rs3bqx3jF27dsX58+erlZ89e7ZeywIwe3kA8PDwwIwZMzBjxgwUFhaiT58+WLp0qZSw1JRAmOPatWsoKyszOMty7tw5AJAGverPZNx/xxbw51mS+9U3Nn0/GeuTM2fOwNXVtdqZH6K2gJeEiGTiqaeeglKpRGJiYrW/jEVRxI0bN6RptVqNxx57DJ9++ikuXbpkcIblzp07ePfdd+Hv7w8PDw9pGaVSCUEQDP76z83NxY4dO+od41/+8hccOHAAhw4dksqKioqwZcuWOpf18PBAcHAwNm/ebHAJKj09HadOnap12aqqqmqXrTp16gRPT09oNBqpzN7e3ujlLXNUVlbiww8/lKYrKirw4YcfomPHjggJCQFwL8kEII0n0sf673//u1p79Y3t/n66PxE6ceIE9uzZg7/85S/mbhJRi8YzLEQy4e/vjyVLliAuLg65ubkYP348HB0dcfHiRWzfvh0vvfQSXnvtNan+oEGDsGzZMjg7O6N3794A7h3EH3roIZw9exbTpk0zaH/06NFYvXo1Ro4cieeeew6FhYVYs2YNunXrhmPHjtUrxrlz5+KTTz7ByJEjMWvWLOm25q5du9arjaSkJIwePRoDBw7E888/j5s3b+K9997DI488gtLS0hqXu337Njp37oxnnnkGQUFBcHBwwDfffIOffvoJq1atkuqFhIQgNTUVsbGxeOyxx+Dg4ICxY8fWa9se5OnpieXLlyM3Nxfdu3dHamoqcnJy8O9//1t6Ku0jjzyCfv36IS4uDjdv3oSLiwu2bt2KysrKau2ZEttbb72FUaNGoX///njhhRek25qdnZ2b5f1KRLJkyVuUiNoi/a2oRUVFRud/+eWX4sCBA0V7e3vR3t5e7NGjhzhz5kzx7NmzBvW+/vprEYA4atQog/IXX3xRBCD+5z//qdb2f/7zHzEgIEC0sbERe/ToIW7cuFGK534AaryF+NixY+KQIUNEtVotenl5iYsXLxb/85//1Ou2Zv32Pfzww6KNjY3Ys2dPcdu2bWJUVFSttzVrNBpxzpw5YlBQkOjo6Cja29uLQUFB4tq1aw2WKS0tFZ977jmxXbt2BrdK628z/vzzz6vFU9NtzY888oj4888/i/379xfVarXYtWtX8f3336+2/IULF8Tw8HDRxsZGdHNzE+fPny+mp6dXa7Om2Izd1iyKovjNN9+Ijz/+uGhrays6OTmJY8eOFU+dOmVQp6Z9qabbrYlaMkEUOSqLiIiI5I1jWIiIiEj2mLAQERGR7DFhISIiItljwkJERESyx4SFiIiIZI8JCxEREcleq3hwnE6nw7Vr1+Do6Nioj+YmIiKipiOKIm7fvg1PT08oFLWfQ2kVCcu1a9fg7e1t6TCIiIjIDJcvX0bnzp1rrdMqEhZHR0cA9zbYycnJ7Ha0Wi327NmDESNGSI/ebovYD+wDgH0AsA/02A/sA6Bp+qCkpATe3t7Scbw2rSJh0V8GcnJyanDCYmdnBycnpza7QwLsB4B9ALAPAPaBHvuBfQA0bR/UZzgHB90SERGR7DFhISIiItljwkJERESy1yrGsNSHKIqorKxEVVVVjXW0Wi2srKxw9+7dWuu1di2xH1QqFZRKpaXDICKiJtImEpaKigrk5eWhvLy81nqiKMLd3R2XL19u089zaYn9IAgCOnfuDAcHB0uHQkRETaDVJyw6nQ4XL16EUqmEp6cnrK2tazwI63Q6lJaWwsHBoc4H2LRmLa0fRFFEUVERrly5goCAAJ5pISJqhVp9wlJRUQGdTgdvb2/Y2dnVWlen06GiogJqtbpFHKibSkvsh44dOyI3NxdarZYJCxFRK9QyjkaNoKUceMk8LeXSFRERmYdHcSIiIpI9JixEREQke61+DEtt3k4/ZzAtiiI0Gg1sbGya5BLDqxHdTao/dOhQfPfddwCAI0eOIDg4uNFjamyCIGD79u0YP358o7SXmZmJYcOGAQDGjRuHHTt2NEq7RETUsvAMi8xNnz4deXl56NWrl6VDMZCQkGA0gcrLy8OoUaMabT0DBgxAXl4eJkyY0GhtEhFRy9Omz7C0BHZ2dnB3d7d0GPXW2LFaW1vD3d0dtra20Gg0jdo2ERG1HDzD0oJkZmZCEARkZGQgNDQUdnZ2GDBgAM6ePWtQ76uvvkKfPn2gVqvh5+eHxMREVFZWSvPPnDmDgQMHQq1Wo2fPnvjmm28gCILB5Zb4+Hj06NEDdnZ28PPzw4IFC6DVagEAmzZtQmJiIo4ePQpBECAIAjZt2gQABu0MGDAAr7/+ukFsRUVFUKlU+P777wEAGo0Gr732Gry8vGBvb4+wsDBkZmY2bscREVGLxzMsLdAbb7yBVatWoWPHjnj55Zfx/PPP48cffwQA7Nu3D1OnTsW7776LQYMG4cKFC3jppZcA3EtCqqqqMH78eHTp0gUHDx7E7du38a9//avaOhwdHbFhwwZ07twZx48fx/Tp0+Ho6Ii5c+di4sSJOHHiBNLS0vDNN98AAJydnau1MWXKFKxYsQLLli2TxgSlpqbC09MTgwYNAgDExMTg1KlT2Lp1Kzw9PbF9+3aMHDkSx48fR0BAQJP0H9Vsbc5a6bOgE+AJT6w/vh6iQrRgVMCM4BkWXT8RWR7PsLRAS5cuxZAhQ9CzZ0/MmzcP+/fvx927dwEAiYmJmDdvHqKiouDn54eIiAgsXrwYH374IQAgPT0dFy5cwMcff4ygoCAMHDgQS5curbaO1157DQMGDICPjw/Gjh2L1157DZ999hkAwNbWFg4ODrCysoK7u7t0yeZBEyZMwLVr1/DDDz9IZSkpKZg8eTIEQcClS5ewceNGfP755xg0aBD8/f3x2muvYeDAgdi4cWNTdB0REbVQPMPSAgUGBkqfPTw8AACFhYXo0qULjh49ih9//NEgCamqqsLdu3dRXl6Os2fPwtvb22CsSd++fautY9u2bfjPf/6DCxcuoLS0FJWVlXBycjIpzo4dO2LEiBHYsmULBg0ahIsXLyIrK0tKno4fP46qqip0725495RGo0GHDh1MWhcREbVuTFhaIJVKJX3WX2rR6XQAgNLSUiQmJuKpp56qtpxara5X+1lZWXjppZeQkJCAkSNHwtnZGVu3bsWqVatMjnXKlCn45z//iffeew8pKSno3bs3evfuLcWqVCqRnZ1d7XH6fIkhERHdz6xLQmvWrIGPjw/UajXCwsJw6NChGutu27YNoaGhaNeuHezt7REcHIxPPvnEoM60adOkwZv6n5EjR5oTWpvXp08fnD17Ft26dav2o1Ao8NBDD+Hy5csoKCiQlvnpp58M2sjKyoK3tzfmz5+P0NBQBAQE4LfffjOoY21tjaqqqjrjGTduHO7evYu0tDSkpKRgypQp0rxHH30UVVVVKCwsrBZrS7ozioiImp7JZ1hSU1MRGxuLdevWISwsDMnJyYiMjMTZs2fRqVOnavVdXFzwxhtvoEePHrC2tsbOnTsRHR2NTp06ITIyUqo3cuRIg3ELNjY2Zm5S27Zw4UKMGTMGXbp0wTPPPAOFQoGjR4/ixIkTWLJkCSIiIuDv74+oqCisWLECt2/fxptvvgngz7M13bp1w5UrV7B161aEhYXh66+/xvbt2w3W4+Pjg4sXLyInJwedO3eGo6Oj0e/M3t4e48ePx4IFC3D69GlMnjxZmte9e3dMmTIFU6dOxapVq/Doo4+iqKgIGRkZCAwMxOjRo5uwp4iIqCUxOWFZvXo1pk+fjujoaADAunXr8PXXX2PDhg2YN29etfpDhw41mJ41axY2b96MH374wSBhsbGxafa/qh988qxOp0NJSQmcnJxa7MsSIyMjsXPnTixatAjLly+HSqVCjx498OKLLwIAlEolduzYgRdffBGPPfYY/Pz88NZbb2Hs2LHSJaMnn3wSr7zyCv75z39Co9Fg9OjRWLBgARISEqT1PP3009i2bRuGDRuG4uJibNy4EdOmTTMa05QpU/CXv/wFgwcPRpcuXQzmbdy4EUuWLMG//vUvXL16Fa6urujXrx/GjBnTJP1DREQtk0kJS0VFBbKzsxEXFyeVKRQKhIeHIysrq87lRVHEt99+i7Nnz2L58uUG8zIzM9GpUye0b98ew4cPx5IlS2oceKnRaAweIlZSUgIA0Gq10rNC9LRaLURRhE6nk8Z51Baf/t+66jaX+2MZPHiwdBlGXxYYGFitLCIiAhEREdXa0s/v3r279BwUANIt0X5+ftDpdBBFEYsWLcLbb79t8IqCf/7zn1IbKpVKumvo/vYfjAW4l0QZKwfuJVDx8fGIj4+vMV59P9T2vejj1mq11cbDmEO/Hz24P7V2gk6o9vn+Mkux1PfQVveDB7Ef2AdA0/SBKW0Jov4oXQ/Xrl2Dl5cX9u/fj/79+0vlc+fOxXfffYeDBw8aXe7WrVvw8vKCRqOBUqnE2rVr8fzzz0vzt27dCjs7O/j6+uLChQuYP38+HBwckJWVZfTgk5CQgMTExGrlKSkpsLOzMyjT33rr7e0Na2vr+m6qLIwZMwaHDh2CtbU1du/ejUceeaRR2t25cyfs7e3h7++PX3/9FXFxcXB2dkZaWlqjtN+Y9u/fjwkTJkCj0Uh3HBlTUVGBy5cvIz8/3+AheUREJF/l5eV47rnncOvWrTrvRG2Wu4QcHR2Rk5OD0tJSZGRkIDY2Fn5+ftLlokmTJkl1e/fujcDAQPj7+yMzMxNPPPFEtfbi4uIQGxsrTZeUlMDb2xsjRoyotsF3797F5cuX4eDgUOddMqIo4vbt23B0dGySlx+a6tNPP8WdO3cAAF26dGm0hKuyshKvv/46Ll26BFdXVzzxxBNYuXKl1Hdy6ochQ4bg8OHDAO7dOVTTDn337l3Y2tpi8ODB9b4bqjZarRbp6emIiIgwuCurtVt/fL30WdAJ8LjqgTyvPIs/OO7F3i9aZL1tdT94EPuBfQA0TR/or5DUh0kJi6urK5RKpcEdJgBQUFBQ6/gThUKBbt26AQCCg4Nx+vRpJCUlVRvfoufn5wdXV1f88ssvRhMWGxsbowM8VSpVtU6sqqqCIAhQKBR1jkvRX27Q17c0b2/vJml32rRpNY43AeTVD/b29tWe02KMQqGAIAhG94GGaOz25M5YYiIqRIsnLJb+DtraflAT9gP7AGjcPjClHZOORtbW1ggJCUFGRoZUptPpkJGRYXCJqC46na7WF9lduXIFN27ckB6KRkRERG2byZeEYmNjERUVhdDQUPTt2xfJyckoKyuT7hqaOnUqvLy8kJSUBABISkpCaGgo/P39odFosGvXLnzyySf44IMPAPz5oLOnn34a7u7uuHDhAubOnYtu3boZ3EVEREREbZfJCcvEiRNRVFSEhQsXIj8/H8HBwUhLS4ObmxsA4NKlSwaXEcrKyjBjxgxcuXIFtra26NGjB/773/9i4sSJAO7dJXLs2DFs3rwZxcXF8PT0xIgRI7B48WI+i4WIiIgAmDnoNiYmBjExMUbnZWZmGkwvWbIES5YsqbEtW1tb7N6925wwiIiIqI2w/MhSIiIiojowYSEiIiLZa9tva96bZDApiCLUGg0EGxugKZ4/Miyu7jr3GTp0KL777jsAwJEjRxAcHNz4MTWDTZs2Yfbs2SguLpam9YO0Z82aheTkZMsFR0RELQLPsMjc9OnTkZeXh169ejXbOjMzM9G+fXspwWhsEydORF5enkm3whMRUdvWts+wtAB2dnbN/lLI+qqoqDDr6bu2trawtbVtca9KICIiy+EZlhYkMzMTgiAgIyMDoaGhsLOzw4ABA3D27FmDel999RX69OkDtVoNPz8/JCYmSu/Xyc3NhSAIyMnJkeoXFxdDEARkZmYiNzdXerpwhw4dIAiC9FTcoUOHIiYmBrNnz4arq6v0nJzVq1ejd+/esLe3h7e3N2bMmIHS0tKm7xAiImozmLC0QG+88QZWrVqFn3/+GVZWVgYvkty3bx+mTp2KWbNm4dSpU/jwww+xadMmLF26tF5te3t74/PPPwcAnD59Gnl5eXjnnXek+Zs3b4a1tTV+/PFHrFu3DsC9x+K/++67OHnyJDZv3oxvv/0Wc+fObcQtJiKito6XhFqgpUuXYsiQIQCAefPmYfTo0bh79y7UajUSExMxb948REVFAbj3XqbFixdj7ty5iI+Pr7NtpVIJFxcXAECnTp2kz3oBAQFYsWKFQdns2bOlzz4+PliyZAlefvllrF27tiGbSUREJGHC0gIFBgZKn/XvWyosLESXLl1w9OhR/PjjjwZnVKqqqnD37l2Ul5c3eN0hISHVyr755hskJSXhzJkzKCkpQWVlpbQ+Ozu7Bq+TiIiICUsLdP/bLYU/br/Wv2FZ/26mp556qtpyarVaem2CKP759l2tVlvvddvb2xtM5+bmYsyYMXjllVewdOlSuLi44IcffsALL7yAiooKJixERNQomLC0Mn369MHZs2fRrVs3o/M7duwIAMjLy8Ojjz4KAAYDcAFId+9UVVXVub7s7GzodDqsWrVKSoY+++wzc8MnIiIyiglLK7Nw4UKMGTMGXbp0wTPPPAOFQoGjR4/ixIkTWLJkCWxtbdGvXz8sW7YMvr6+KCwsxJtvvmnQRteuXSEIAnbu3IkxY8bA1tYWDg4ORtfXrVs3aLVavPfeexg7dqzBYFwiIqLG0rYTlgeePCvqdLhbUgJrJycIipZ5A1VkZCR27tyJRYsWYfny5VCpVOjRowdefPFFqc6GDRvwwgsvICQkBA899BBWrFiBESNGSPO9vLwQFxeH+fPn44UXXsDUqVOxadMmo+sLCgrC6tWrsXz5csTFxWHw4MFISkrC1KlTm3pTiYioDWnbCUsLM3ToUIOxJwAQHBxcrSwyMlJ6RooxDz/8MPbv329Q9mAbc+bMweLFi6XLPED1N3Hrvfrqq3j11VcNyv72t79Jn6dNmyY9y4WIiMgcLfM0Qhuydu1aODg44Pjx45YOpdFs2bIFDg4O2Ldvn6VDISKiFoJnWGRsy5YtuHPnDgCgS5cuFo6m8Tz55JMICwsDALRr186ywRARUYvAhEXGvLy8LB1Ck3B0dISjo6OlwyAiohaEl4SIiIhI9piwEBERkewxYSEiIiLZY8JCREREsseEhYiIiGSPdwkREdXH3iRLR2DcA0/spnoy9fsUFQB6APtWA4KuSUICwO+zFm06YVmbs9ZgWhRFaDQa2NjYSG9BbkwzgmeYVH/o0KH47rvvAABHjhxBcHBwo8dkbJ1BQUFITEyssc6mTZswe/ZsFBcXN9p6p02bhs2bNwMAtm/fjvHjxzda20RE1PLxkpDMTZ8+HXl5eejVq1ezrG/btm1YtGiRNO3j44Pk5GSDOhMnTsS5c+cadb3vvPMO8vLyGrVNIiJqPdr0GZaWwM7ODu7u7s22PhcXF+h0OpSUlNRYx9bWFra2to26XmdnZzg7Ozdqm0RE1HrwDEsLkpmZCUEQ8PXXXyMwMBBqtRr9+vXDiRMnDOp9+eWXeOSRR2BjYwMfHx+sWrXKYP7atWsREBAAtVoNNzc3PPPMM9K8oUOHSi8yHD58OH777Te8+uqrEARBuky2adMm6ZH6586dgyAIOHPmjME63n77bfj7+0vTJ06cwKhRo+Dg4AA3Nzf87W9/w/Xr1xutb4iIqHVjwtICzZkzB6tWrcJPP/2Ejh07YuzYsdBqtQCA7OxsTJgwAZMmTcLx48eRkJCABQsWYNOmTQCAn3/+Gf/85z+xaNEinD17FmlpaRg8eLDR9XzxxRfo3LkzFi1ahLy8PKOXbLp3747Q0FBs2bLFoHzLli147rnnAADFxcUYPnw4Hn30Ufz8889IS0tDQUEBJkyY0Ii9QkRErRkvCbVA8fHxiIiIAABs3rwZnTt3xvbt2zFhwgSsXr0aTzzxBBYsWADgXkJx6tQpvPXWW5g2bRouXboEe3t7jBkzBo6OjujatSseffRRo+txcXGBUqmEo6NjrZelpkyZgvfffx+LFy8GcO+sS3Z2Nv773/8CAN5//308+uij+L//+z9pmQ0bNsDb2xvnzp1D9+7dG6VfiIio9eIZlhaof//+0mcXFxc89NBDOH36NADg9OnTePzxxw3qP/744zh//jyqqqoQERGBrl27ws/PD3/729+wZcsWlJeXNyieSZMmITc3FwcOHABw7+xKnz590KNHDwDA0aNHsXfvXjg4OEg/+nkXLlxo0LqJiKhtYMLSxjg6OuLw4cP49NNP4eHhgYULFyIoKKhBtyi7u7tj+PDhSElJAQCkpKRgypQp0vzS0lKMHTsWOTk5Bj/nz5+v8XIUERHR/ZiwtED6MxkA8Pvvv+PcuXN4+OGHAQAPP/wwfvzxR4P6P/74I7p37w6lUgkAsLKyQnh4OFasWIFjx44hNzcX3377rdF1WVtbo6qqqs6YpkyZgtTUVGRlZeHXX3/FpEmTpHl9+vTByZMn4ePjg27duhn82Nvbm7z9RETU9jBhaYEWLVqEjIwMnDhxAtOmTYOrq6v0oLV//etfyMjIwOLFi3Hu3Dls3rwZ77//Pl577TUAwM6dO/Huu+8iJycHv/32Gz7++GPodDo89NBDRtfl4+OD77//HlevXq31rp6nnnoKt2/fxiuvvIJhw4bB09NTmjdz5kzcvHkTkydPxk8//YQLFy5g9+7diI6OrlcyRERE1KYH3T745Fn980ecnJygUMg3l1u2bBlmzZqF8+fPIzg4GP/v//0/WFtbA7h3NuOzzz7DwoULsXjxYnh4eGDRokWYNm0aAKBdu3bYtm0bEhIScPfuXQQEBODTTz/FI488YnRdixYtwt///nf4+/tDo9FAFEWj9RwdHTF27Fh89tln2LBhg8E8T09P/Pjjj3j99dcxYsQIaDQadO3aFSNHjpR1PxMRkXy06YSlpRo4cGC1Z6/c7+mnn8bTTz9d47KZmZk1LpuZmWnw4Lh+/frh6NGjBnWmTZsmJUD3S01NRWpqqtF2AwICsG3bthrXS0REVBv+eStza9euhYODA44fP27pUJrUyy+/DAcHB0uHQUREMsUzLDK2ZcsW3LlzBwDQpUsX7N+/38IRNZ1FixZJ42w8PDwsHA0REckNExYZ8/LyMpgeOnRojWNIWrpOnTqhU6dOlg6DiIhkyqxLQmvWrIGPjw/UajXCwsJw6NChGutu27YNoaGhaNeuHezt7REcHIxPPvnEoI4oili4cCE8PDxga2uL8PBwnD9/3pzQiIiIqBUyOWFJTU1FbGws4uPjcfjwYQQFBSEyMhKFhYVG67u4uOCNN95AVlYWjh07hujoaERHR2P37t1SnRUrVuDdd9/FunXrcPDgQdjb2yMyMhJ37941f8se0FrPTNA9/H6JiFo3kxOW1atXY/r06YiOjkbPnj2xbt062NnZVbuVVW/o0KH461//iocffhj+/v6YNWsWAgMD8cMPPwC4d6BJTk7Gm2++iXHjxiEwMBAff/wxrl27hh07djRo4wBApVIBQIMfP0/yVlFRAQDSw/GIiKh1MWkMS0VFBbKzsxEXFyeVKRQKhIeHIysrq87lRVHEt99+i7Nnz2L58uUAgIsXLyI/Px/h4eFSPWdnZ4SFhSErK8vgial6Go0GGo1GmtbfgqvVaqW3Ft/P0dERBQUF0Ol0sLOzgyAINcZXUVGBO3fu1FinLWhp/aDT6VBYWAi1Wg1RFI3uA6bSt9EYbbUkgk6o9vn+Mkux1PdgsB+IMr2pshn6plX+Ppj4fWr/qK9t6v1Axn3cFPuBKW2ZlLBcv34dVVVVcHNzMyh3c3PDmTNnalzu1q1b8PLygkajgVKpxNq1a6W3Defn50ttPNimft6DkpKSkJiYWK18z549sLOzM7qMo6MjysrK+KCyVkqr1aKoqAjHjh1r1HbT09MbtT2584RntTKPq5a/a2vX5V0WXf+9/aCHRWOo0a7m65vW9ftg3veZXtrEb5dvxu/TXI25H5hy9aNZ7hJydHRETk4OSktLkZGRgdjYWPj5+WHo0KFmtRcXF4fY2FhpuqSkBN7e3hgxYgScnJxqXK6qqgqVlZU1jneorKzE/v37MWDAAFhZtd0bqFpaPwiCAJVK1ajJqFarRXp6OiIiIqTLim3B+uPrpc+CToDHVQ/keeVBVFh2jNCLvV+0yHoN9oMD71kkhjoNiq27TgO1yt+HfatNqq4VFUgv7Y4Ih3NQCbomCgrN8n2aqyn2A/0Vkvow6Wjk6uoKpVKJgoICg/KCggK4u7vXuJxCoUC3bt0AAMHBwTh9+jSSkpIwdOhQabmCggKD528UFBQgODjYaHs2NjawsbGpVq5SqWrtxLo6WKvVorKyEg4ODq3nl9IM7Ic/1bVPtTbGEhNRIVo8YbH0d6BSqZr2INUQzdg3rer3wczvUyXomnZfaAH925j7gSntmPQnqbW1NUJCQpCRkSGV6XQ6ZGRkoH///vVuR6fTSWNQfH194e7ubtBmSUkJDh48aFKbRERE1HqZfL4/NjYWUVFRCA0NRd++fZGcnIyysjJER0cDAKZOnQovLy8kJSUBuDfeJDQ0VHp53q5du/DJJ5/ggw8+AHDvdP7s2bOxZMkSBAQEwNfXFwsWLICnp6f0BmIiIiJq20xOWCZOnIiioiIsXLgQ+fn5CA4ORlpamjRo9tKlSwZjCcrKyjBjxgxcuXIFtra26NGjB/773/9i4sSJUp25c+eirKwML730EoqLizFw4ECkpaVBrVY3wiYSERFRS2fWiMqYmBjExMQYnffgm4CXLFmCJUuW1NqeIAhYtGgRFi1aZE44RERE1MrxHl8iIiKSPSYsREREJHtMWIiIiEj2mLAQERGR7DFhISIiItljwkJERESyx4SFiIiIZI8JCxEREckeExYiIiKSPSYsREREJHtMWIiIiEj2zHqXEFFr8Xb6OaPlglgFXwBr9v4CUVA2a0yvRnRv1vUREbUEPMNCREREsseEhYiIiGSPCQsRERHJHhMWIiIikj0mLERERCR7TFiIiIhI9piwEBERkewxYSEiIiLZY8JCREREsseEhYiIiGSPCQsRERHJHhMWIiIikj0mLERERCR7TFiIiIhI9piwEBERkewxYSEiIiLZY8JCREREsseEhYiIiGSPCQsRERHJHhMWIiIikj0rSwdARIbeTj9nsXUfLrkhfbaCAuOtPC0WCxHR/XiGhYiIiGSPCQsRERHJHhMWIiIikj0mLERERCR7TFiIiIhI9sxKWNasWQMfHx+o1WqEhYXh0KFDNdb96KOPMGjQILRv3x7t27dHeHh4tfrTpk2DIAgGPyNHjjQnNCIiImqFTE5YUlNTERsbi/j4eBw+fBhBQUGIjIxEYWGh0fqZmZmYPHky9u7di6ysLHh7e2PEiBG4evWqQb2RI0ciLy9P+vn000/N2yIiIiJqdUxOWFavXo3p06cjOjoaPXv2xLp162BnZ4cNGzYYrb9lyxbMmDEDwcHB6NGjB9avXw+dToeMjAyDejY2NnB3d5d+2rdvb94WERERUatj0oPjKioqkJ2djbi4OKlMoVAgPDwcWVlZ9WqjvLwcWq0WLi4uBuWZmZno1KkT2rdvj+HDh2PJkiXo0KGD0TY0Gg00Go00XVJSAgDQarXQarWmbJIB/bINaaM1aEv9IIhVtZbXNL+1srrvbxj9Z0EnWCociaX2RYPfBVGmQ/6aoW9a5f8JJn6f2j/qa5t6P5BxHzfFfmBKW4IoimJ9K1+7dg1eXl7Yv38/+vfvL5XPnTsX3333HQ4ePFhnGzNmzMDu3btx8uRJqNVqAMDWrVthZ2cHX19fXLhwAfPnz4eDgwOysrKgVCqrtZGQkIDExMRq5SkpKbCzs6vv5hAREZEFlZeX47nnnsOtW7fg5ORUa91mfTT/smXLsHXrVmRmZkrJCgBMmjRJ+ty7d28EBgbC398fmZmZeOKJJ6q1ExcXh9jYWGm6pKREGhtT1wbXRqvVIj09HREREVCpVGa309K1pX5Ys/cXo+WCWAWfuxeQq/aHKFRPmluro7e3SZ+toMAYq57I88qDqKj33zVN4sXeL1pkvQa/Cwfes0gMdRoUW3edBmqV/yfsW21Sda2oQHppd0Q4nINK0DVRUPKm7fePRt8P9FdI6sOkhMXV1RVKpRIFBQUG5QUFBXB3d6912ZUrV2LZsmX45ptvEBgYWGtdPz8/uLq64pdffjGasNjY2MDGxqZauUqlapRObKx2Wrq20A91JSOioGxTCUslqv9HLCpEiycslt4PVSqVfA9Szdg3rer/BDO/T5Wgk+++0NT++O4bcz8wpR2TLsZZW1sjJCTEYMCsfgDt/ZeIHrRixQosXrwYaWlpCA0NrXM9V65cwY0bN+Dh4WFKeERERNRKmTx6KDY2Fh999BE2b96M06dP45VXXkFZWRmio6MBAFOnTjUYlLt8+XIsWLAAGzZsgI+PD/Lz85Gfn4/S0lIAQGlpKebMmYMDBw4gNzcXGRkZGDduHLp164bIyMhG2kwiIiJqyUwewzJx4kQUFRVh4cKFyM/PR3BwMNLS0uDm5gYAuHTpEhSKP/OgDz74ABUVFXjmmWcM2omPj0dCQgKUSiWOHTuGzZs3o7i4GJ6enhgxYgQWL15s9LIPERERtT1mDbqNiYlBTEyM0XmZmZkG07m5ubW2ZWtri927d5sTBhEREbURMn2wABEREdGfmLAQERGR7DFhISIiItljwkJERESyx4SFiIiIZI8JCxEREckeExYiIiKSPSYsREREJHtMWIiIiEj2mLAQERGR7DFhISIiItljwkJERESyx4SFiIiIZI8JCxEREckeExYiIiKSPSYsREREJHtMWIiIiEj2mLAQERGR7DFhISIiItmzsnQA1LqszVlr6RCqmRE8w9IhUCu0tviYpUO4pxl+5wSdAE94Yv3x9RAVosnL83eQGgPPsBAREZHsMWEhIiIi2WPCQkRERLLHhIWIiIhkjwkLERERyR4TFiIiIpI93tZMzSrrwo1mX6em6Fyzr7M1OXTxJiqhs2gM93+Hr0Z0t2AkRGQpPMNCREREsseEhYiIiGSPl4SIiKhpXNx379/fb1k2DmoVeIaFiIiIZI8JCxEREckeExYiIiKSPSYsREREJHtMWIiIiEj2mLAQERGR7DFhISIiItljwkJERESyx4SFiIiIZM+shGXNmjXw8fGBWq1GWFgYDh06VGPdjz76CIMGDUL79u3Rvn17hIeHV6sviiIWLlwIDw8P2NraIjw8HOfPnzcnNCIiImqFTE5YUlNTERsbi/j4eBw+fBhBQUGIjIxEYWGh0fqZmZmYPHky9u7di6ysLHh7e2PEiBG4evWqVGfFihV49913sW7dOhw8eBD29vaIjIzE3bt3zd8yIiIiajVMfpfQ6tWrMX36dERHRwMA1q1bh6+//hobNmzAvHnzqtXfsmWLwfT69evx5ZdfIiMjA1OnToUoikhOTsabb76JcePGAQA+/vhjuLm5YceOHZg0aVK1NjUaDTQajTRdUlICANBqtdBqtaZukkS/bEPaaA0a0g+CTqh1vpUFrkIKYpXZy5izbEt2//ej/2yJ7+xB938Pzfn7afC7IBr2gyCXV7HV8TvXGPS/13X9fld3r4+0ouX3oYbSb0Nr2BZzNcUx0pS2BFEUxfpWrqiogJ2dHb744guMHz9eKo+KikJxcTG++uqrOtu4ffs2OnXqhM8//xxjxozBr7/+Cn9/fxw5cgTBwcFSvSFDhiA4OBjvvPNOtTYSEhKQmJhYrTwlJQV2dnb13RwiIiKyoPLycjz33HO4desWnJycaq1r0p8I169fR1VVFdzc3AzK3dzccObMmXq18frrr8PT0xPh4eEAgPz8fKmNB9vUz3tQXFwcYmNjpemSkhLpUlNdG1wbrVaL9PR0REREQKVSmd1OS9eQflh/fH2t8w9dvNmQ0MwS5PiUycsIYhV87l5ArtofoqBsgqjk6ejtbdJnKygwxqondlaeQiV0FozK8DucOaxbs63X4HfhwHsG89bfOtFscdSq64AmX4WgE+Bx1QN5XnkQFfX+Gxf4bT8A4EXnXk0UWfPRigqkl3ZHhMM5qATL/j5YirbfPxr9GKm/QlIfzXpOc9myZdi6dSsyMzOhVqvNbsfGxgY2NjbVylUqVaN0YmO109KZ0w91/WdmiQNfQxIOUVC2qYTF2PdTCZ3FE5b7vwNL/G6qVKpqBykRlc0eh1GmJBANJCpE0xKWP/qoNR3gVYKuVW2PSf743WvMY6Qp7Zh0Mc7V1RVKpRIFBQUG5QUFBXB3d6912ZUrV2LZsmXYs2cPAgMDpXL9cua0SURERG2DSQmLtbU1QkJCkJGRIZXpdDpkZGSgf//+NS63YsUKLF68GGlpaQgNDTWY5+vrC3d3d4M2S0pKcPDgwVrbJCIiorbD5EtCsbGxiIqKQmhoKPr27Yvk5GSUlZVJdw1NnToVXl5eSEpKAgAsX74cCxcuREpKCnx8fKRxKQ4ODnBwcIAgCJg9ezaWLFmCgIAA+Pr6YsGCBfD09DQY2EtERERtl8kJy8SJE1FUVISFCxciPz8fwcHBSEtLkwbNXrp0CQrFnyduPvjgA1RUVOCZZ54xaCc+Ph4JCQkAgLlz56KsrAwvvfQSiouLMXDgQKSlpTVonAsRERG1HmYNuo2JiUFMTIzReZmZmQbTubm5dbYnCAIWLVqERYsWmRMOERERtXJt9wk4RERE1GIwYSEiIiLZY8JCREREsseEhYiIiGSPCQsRERHJHhMWIiIikj0mLERERCR7TFiIiIhI9piwEBERkewxYSEiIiLZY8JCREREsseEhYiIiGSPCQsRERHJnllvayYiak6HS1Klz2tzOjTbegWdAE94Yv3x9RCLjzXbeomoOp5hISIiItljwkJERESyx4SFiIiIZI8JCxEREckeExYiIiKSPd4lRETUAJeL71h0/Vcu3GjydVhBgfFWnjh08SYqoav3cp1L/uibdk0TF7UtPMNCREREsseEhYiIiGSPCQsRERHJHhMWIiIikj0mLERERCR7TFiIiIhI9piwEBERkewxYSEiIiLZY8JCREREsseEhYiIiGSPCQsRERHJHhMWIiIikj0mLERERCR7TFiIiIhI9piwEBERkewxYSEiIiLZY8JCREREsseEhYiIiGTPrIRlzZo18PHxgVqtRlhYGA4dOlRj3ZMnT+Lpp5+Gj48PBEFAcnJytToJCQkQBMHgp0ePHuaERkRERK2QyQlLamoqYmNjER8fj8OHDyMoKAiRkZEoLCw0Wr+8vBx+fn5YtmwZ3N3da2z3kUceQV5envTzww8/mBoaERERtVJWpi6wevVqTJ8+HdHR0QCAdevW4euvv8aGDRswb968avUfe+wxPPbYYwBgdL4UiJVVrQlNS/B2+jlLh1DNqxHdLR0CERFRg5mUsFRUVCA7OxtxcXFSmUKhQHh4OLKyshoUyPnz5+Hp6Qm1Wo3+/fsjKSkJXbp0MVpXo9FAo9FI0yUlJQAArVYLrVZrdgz6Zc1tQxCrzF53UzFnWxrSD4JOqHW+lQWGTZnzveiXkeN32pTu/370ny3xndWmrn2sKdZ171/j/10qoWq2eIxpju/H3H1B3zdaUV77kDn029AatsVcDT1G1tZmfQiiKIr1rXzt2jV4eXlh//796N+/v1Q+d+5cfPfddzh48GCty/v4+GD27NmYPXu2Qfn//vc/lJaW4qGHHkJeXh4SExNx9epVnDhxAo6OjtXaSUhIQGJiYrXylJQU2NnZ1XdziIiIyILKy8vx3HPP4datW3Bycqq1rsmXhJrCqFGjpM+BgYEICwtD165d8dlnn+GFF16oVj8uLg6xsbHSdElJCby9vTFixIg6N7g2Wq0W6enpiIiIgEpl+l9Na/b+Yva6m8rMYd1MXqYh/bD++Ppa5x+6eNPkeBoqyPEpk5cRxCr43L2AXLU/REHZBFHJ09Hb26TPVlBgjFVP7Kw8hUroLBiVob6+Ls22LkEnwOOqB/K88iBe/tFonavFd5stHqPrdwpu8nWYuy94leQAAOK7hjZRZM1HKyqQXtodEQ7noBLk8/vQnLT9/tGgY6Qx+isk9WFSwuLq6gqlUomCggKD8oKCgkYdf9KuXTt0794dv/xiPAGwsbGBjY1NtXKVStUonWhuO3I8sDWkP8zpB1FR+wk7Sxz4GvK9iIJSlt9rUzH2/VRCJ6uEpa59rKnWKaLS6LwqNN7pcXM053dj6r6g75vWdIBXCbpWtT0m+eN40FjHWn1b9WXSxThra2uEhIQgIyNDKtPpdMjIyDC4RNRQpaWluHDhAjw8PBqtTSIiImq5TL4kFBsbi6ioKISGhqJv375ITk5GWVmZdNfQ1KlT4eXlhaSkJAD3BuqeOnVK+nz16lXk5OTAwcEB3brdu1zx2muvYezYsejatSuuXbuG+Ph4KJVKTJ48ubG2k4iIiFowkxOWiRMnoqioCAsXLkR+fj6Cg4ORlpYGNzc3AMClS5egUPx54ubatWt49NFHpemVK1di5cqVGDJkCDIzMwEAV65cweTJk3Hjxg107NgRAwcOxIEDB9CxY8cGbh4RERG1BmYNuo2JiUFMTIzRefokRM/Hxwd13Yi0detWc8IgIiKiNkIWdwkRmatzSXaddfoV3zK5XZ2gxPUOA/DYlU1QtMJnsRzo8pKlQyAiMknbfQIOERERtRhMWIiIiEj2mLAQERGR7DFhISIiItljwkJERESyx4SFiIiIZI8JCxEREcken8NCrd7/pzD9LdpKqPAYBuB/wq+oEprm5XZP6kx/kzYRUVvFMyxEREQke0xYiIiISPaYsBAREZHscQwLERE1qaxfb1g6BAP9/TpYOgQyA8+wEBERkewxYSEiIiLZY8JCREREsseEhYiIiGSPCQsRERHJHhMWIiIikj0mLERERCR7TFiIiIhI9piwEBERkewxYSEiIiLZY8JCREREsseEhYiIiGSPCQsRERHJHhMWIiIikj0mLERERCR7TFiIiIhI9piwEBERkexZWToAorbq/1P8YrF1XylJtdi6qe2x5L5uzJHiq5jRLtDSYZCJeIaFiIiIZI8JCxEREckeExYiIiKSPY5hIaIWJevCjWZblxUUGG/liUMXb8K95E6zrZeIquMZFiIiIpI9JixEREQke0xYiIiISPbMSljWrFkDHx8fqNVqhIWF4dChQzXWPXnyJJ5++mn4+PhAEAQkJyc3uE0iIiJqW0xOWFJTUxEbG4v4+HgcPnwYQUFBiIyMRGFhodH65eXl8PPzw7Jly+Du7t4obRIREVHbYnLCsnr1akyfPh3R0dHo2bMn1q1bBzs7O2zYsMFo/cceewxvvfUWJk2aBBsbm0Zpk4iIiNoWk25rrqioQHZ2NuLi4qQyhUKB8PBwZGVlmRWAOW1qNBpoNBppuqSkBACg1Wqh1WrNikO//P3/mkoQq8xed1MxZ1sa0g+CTqh1vlUjD5tSQtWo7ekp/mhX0UTtW1p9vgd9ncb+zlqS+/ugqfa1hmqO78fcfUGufSbAClrRtG3R1zd1udakocfI2tqsD5MSluvXr6Oqqgpubm4G5W5ubjhz5owpTTWozaSkJCQmJlYr37NnD+zs7MyK437p6elmLefb4DU3vl27zpm9rDn94AnPWuePt6p9vslcejVuew8IcZnYpO1bymMm1B1j1bPJ4mgpxlj1BFzk2Q+mfJcNZfK+0MS/nw2x67Z5y6WXdm/cQFqSP44J5h4jjSkvL6933Rb54Li4uDjExsZK0yUlJfD29saIESPg5ORkdrtarRbp6emIiIiASmX6XwZr9srrBV8AMHNYN5OXaUg/rD++vtb5hy7eNDme2niV5DRqe3oKqBDiMhHZN1OhQ+P9NSEXV52C66xjBQXGWPXEzspTqISu6YOSofv7wK3ksKXDMao+32VDmbsvNNXvZ0N5tVPjRWfTkimtqEB6aXdEOJyDSmibvw/afv9o0DHSGP0VkvowKWFxdXWFUqlEQUGBQXlBQUGNA2qbok0bGxuj42FUKlWjdKK57YiCssHrbmwN6Q9z+kFUiLXOb+wDX1UTJxM6aJt8HZZgyvdQCV2bTVj0KqGT7X7QnN+NqfuCXPtMhJXZSYdK0LXZhAV/HA8a61irb6u+TLoYZ21tjZCQEGRkZEhlOp0OGRkZ6N+/vylNNWmbRERE1LqYfEkoNjYWUVFRCA0NRd++fZGcnIyysjJER0cDAKZOnQovLy8kJSUBuDeo9tSpU9Lnq1evIicnBw4ODujWrVu92iQiIqK2zeSEZeLEiSgqKsLChQuRn5+P4OBgpKWlSYNmL126BIXizxM3165dw6OPPipNr1y5EitXrsSQIUOQmZlZrzaJiIiobTNr0G1MTAxiYmKMztMnIXo+Pj4QxdrHNdTVJhEREbVtbfeGciIiImoxmLAQERGR7DFhISIiItlrkQ+Oa25rc9bWq97hkhtNHInp1uZ0MHkZQSfAE55Yf3x9nc9VISIiag5MWIhINjqXZFs6BANKqACXXvAqyYH83hR2T3P0mWE/yPNhcNT68ZIQERERyR7PsLRyWRdMv0xlBQXGW3ni0MWbbf6R7EREJA88w0JERESyx4SFiIiIZI8JCxEREckeExYiIiKSPSYsREREJHtMWIiIiEj2mLAQERGR7DFhISIiItljwkJERESyx4SFiIiIZI8JCxEREckeExYiIiKSPSYsREREJHtMWIiIiEj2rCwdABE1v84l2XXWUUIFuPSCV0kOqqBthqiIiGrGMyxEREQke0xYiIiISPaYsBAREZHsMWEhIiIi2WPCQkRERLLHhIWIiIhkjwkLERERyR4TFiIiIpI9JixEREQke0xYiIiISPaYsBAREZHsMWEhIiIi2WPCQkRERLLHhIWIiIhkjwkLERERyR4TFiIiIpI9sxKWNWvWwMfHB2q1GmFhYTh06FCt9T///HP06NEDarUavXv3xq5duwzmT5s2DYIgGPyMHDnSnNCIiIioFbIydYHU1FTExsZi3bp1CAsLQ3JyMiIjI3H27Fl06tSpWv39+/dj8uTJSEpKwpgxY5CSkoLx48fj8OHD6NWrl1Rv5MiR2LhxozRtY2Nj5iZRU+hckm3pEIiIqA0z+QzL6tWrMX36dERHR6Nnz55Yt24d7OzssGHDBqP133nnHYwcORJz5szBww8/jMWLF6NPnz54//33DerZ2NjA3d1d+mnfvr15W0REREStjklnWCoqKpCdnY24uDipTKFQIDw8HFlZWUaXycrKQmxsrEFZZGQkduzYYVCWmZmJTp06oX379hg+fDiWLFmCDh06GG1To9FAo9FI0yUlJQAArVYLrVZryiYZ0C/7YBuCTqjX8latZEiQfjvu3x4lVJYKxyIUf2yvoo1t9/3YB+wDvdbWDwKsoBVN+/9aX9/U5VqTmo6RjdFmfZiUsFy/fh1VVVVwc3MzKHdzc8OZM2eMLpOfn2+0fn5+vjQ9cuRIPPXUU/D19cWFCxcwf/58jBo1CllZWVAqldXaTEpKQmJiYrXyPXv2wM7OzpRNMio9Pd1g2hOe9VpuvFX96rUUY6x6/jnh0qvmiq1YiMtES4dgcewD9oFea+qHXbfNWy69tHvjBtKS/HFsfPAY2RDl5eX1rmvyGJamMGnSJOlz7969ERgYCH9/f2RmZuKJJ56oVj8uLs7grE1JSQm8vb0xYsQIODk5mR2HVqtFeno6IiIioFL9+ZfE+uPr67X8oYs3zV63nFhBgTFWPbGz8hQqoQMAeJXkWDaoZqaACiEuE5F9MxU6NN5fEy0J+4B9oNfa+sGrnRovOpv2R5hWVCC9tDsiHM5BJeiaKDJ50/b7h9FjZEPor5DUh0kJi6urK5RKJQoKCgzKCwoK4O7ubnQZd3d3k+oDgJ+fH1xdXfHLL78YTVhsbGyMDspVqVSN0okPtiMqxHotpz+4txaV0EnbVNUK/pMyhw7aNrvteuwD9oFea+kHEVZmJx0qQddmExb8cVxsrGOtvq36MulinLW1NUJCQpCRkSGV6XQ6ZGRkoH///kaX6d+/v0F94N7ppJrqA8CVK1dw48YNeHh4mBIeERERtVImjx6KjY3FRx99hM2bN+P06dN45ZVXUFZWhujoaADA1KlTDQblzpo1C2lpaVi1ahXOnDmDhIQE/Pzzz4iJiQEAlJaWYs6cOThw4AByc3ORkZGBcePGoVu3boiMjGykzSQiIqKWzOQxLBMnTkRRUREWLlyI/Px8BAcHIy0tTRpYe+nSJSgUf+ZBAwYMQEpKCt58803Mnz8fAQEB2LFjh/QMFqVSiWPHjmHz5s0oLi6Gp6cnRowYgcWLF/NZLERERATAzEG3MTEx0hmSB2VmZlYre/bZZ/Hss88arW9ra4vdu3ebEwYRERG1EW33hnIiIiJqMZiwEBERkewxYSEiIiLZY8JCREREsseEhYiIiGSPCQsRERHJHhMWIiIikj0mLERERCR7TFiIiIhI9piwEBERkewxYSEiIiLZY8JCREREsseEhYiIiGSPCQsRERHJHhMWIiIikj0mLERERCR7TFiIiIhI9piwEBERkewxYSEiIiLZY8JCREREsseEhYiIiGSPCQsRERHJHhMWIiIikj0mLERERCR7TFiIiIhI9piwEBERkewxYSEiIiLZY8JCREREsseEhYiIiGSPCQsRERHJHhMWIiIikj0mLERERCR7TFiIiIhI9piwEBERkewxYSEiIiLZY8JCREREsseEhYiIiGSPCQsRERHJnlkJy5o1a+Dj4wO1Wo2wsDAcOnSo1vqff/45evToAbVajd69e2PXrl0G80VRxMKFC+Hh4QFbW1uEh4fj/Pnz5oRGRERErZDJCUtqaipiY2MRHx+Pw4cPIygoCJGRkSgsLDRaf//+/Zg8eTJeeOEFHDlyBOPHj8f48eNx4sQJqc6KFSvw7rvvYt26dTh48CDs7e0RGRmJu3fvmr9lRERE1GqYnLCsXr0a06dPR3R0NHr27Il169bBzs4OGzZsMFr/nXfewciRIzFnzhw8/PDDWLx4Mfr06YP3338fwL2zK8nJyXjzzTcxbtw4BAYG4uOPP8a1a9ewY8eOBm0cERERtQ5WplSuqKhAdnY24uLipDKFQoHw8HBkZWUZXSYrKwuxsbEGZZGRkVIycvHiReTn5yM8PFya7+zsjLCwMGRlZWHSpEnV2tRoNNBoNNL0rVu3AAA3b96EVqs1ZZMMaLValJeX48aNG1CpVFL53ZL6nenRlVeYvW450UGBcqty6CoroIMOAFB5x8JBNTMdgPLycmjv4I8eaHvYB+wDvdbWD3etdbhhZdr/11pRce/4IFRAJbSGXjCd9sYNo8fIhrh9+zaAeycv6mJSwnL9+nVUVVXBzc3NoNzNzQ1nzpwxukx+fr7R+vn5+dJ8fVlNdR6UlJSExMTEauW+vr712xCq08eWDkAWvrB0ADLAPmAf6LWufviXpQNokRKarOXbt2/D2dm51jomJSxyERcXZ3DWRqfT4ebNm+jQoQMEQTC73ZKSEnh7e+Py5ctwcnJqjFBbJPYD+wBgHwDsAz32A/sAaJo+EEURt2/fhqenZ511TUpYXF1doVQqUVBQYFBeUFAAd3d3o8u4u7vXWl//b0FBATw8PAzqBAcHG23TxsYGNjY2BmXt2rUzZVNq5eTk1GZ3yPuxH9gHAPsAYB/osR/YB0Dj90FdZ1b0TBp0a21tjZCQEGRkZEhlOp0OGRkZ6N+/v9Fl+vfvb1AfANLT06X6vr6+cHd3N6hTUlKCgwcP1tgmERERtS0mXxKKjY1FVFQUQkND0bdvXyQnJ6OsrAzR0dEAgKlTp8LLywtJSUkAgFmzZmHIkCFYtWoVRo8eja1bt+Lnn3/Gv//9bwCAIAiYPXs2lixZgoCAAPj6+mLBggXw9PTE+PHjG29LiYiIqMUyOWGZOHEiioqKsHDhQuTn5yM4OBhpaWnSoNlLly5BofjzxM2AAQOQkpKCN998E/Pnz0dAQAB27NiBXr16SXXmzp2LsrIyvPTSSyguLsbAgQORlpYGtVrdCJtYfzY2NoiPj692uamtYT+wDwD2AcA+0GM/sA8Ay/eBINbnXiIiIiIiC+K7hIiIiEj2mLAQERGR7DFhISIiItljwkJERESyx4SFiIiIZI8JSy2efPJJdOnSBWq1Gh4eHvjb3/6Ga9euWTqsZpObm4sXXngBvr6+sLW1hb+/P+Lj41FR0Tpe8lhfS5cuxYABA2BnZ9eoT1SWuzVr1sDHxwdqtRphYWE4dOiQpUNqVt9//z3Gjh0LT09PCILQ5t4en5SUhMceewyOjo7o1KkTxo8fj7Nnz1o6rGb3wQcfIDAwUHq6a//+/fG///3P0mFZ1LJly6RnqDUnJiy1GDZsGD777DOcPXsWX375JS5cuIBnnnnG0mE1mzNnzkCn0+HDDz/EyZMn8fbbb2PdunWYP3++pUNrVhUVFXj22WfxyiuvWDqUZpOamorY2FjEx8fj8OHDCAoKQmRkJAoLCy0dWrMpKytDUFAQ1qxZY+lQLOK7777DzJkzceDAAaSnp0Or1WLEiBEoKyuzdGjNqnPnzli2bBmys7Px888/Y/jw4Rg3bhxOnjxp6dAs4qeffsKHH36IwMDA5l+5SPX21VdfiYIgiBUVFZYOxWJWrFgh+vr6WjoMi9i4caPo7Oxs6TCaRd++fcWZM2dK01VVVaKnp6eYlJRkwagsB4C4fft2S4dhUYWFhSIA8bvvvrN0KBbXvn17cf369ZYOo9ndvn1bDAgIENPT08UhQ4aIs2bNatb18wxLPd28eRNbtmzBgAEDoFKpLB2Oxdy6dQsuLi6WDoOaUEVFBbKzsxEeHi6VKRQKhIeHIysry4KRkSXdunULANr0739VVRW2bt2KsrKyNvmuu5kzZ2L06NEG/zc0JyYsdXj99ddhb2+PDh064NKlS/jqq68sHZLF/PLLL3jvvffw97//3dKhUBO6fv06qqqqpNdt6Lm5uSE/P99CUZEl6XQ6zJ49G48//rjBa1XaiuPHj8PBwQE2NjZ4+eWXsX37dvTs2dPSYTWrrVu34vDhw9J7Ai2hzSUs8+bNgyAItf6cOXNGqj9nzhwcOXIEe/bsgVKpxNSpUyG28LcZmNoHAHD16lWMHDkSzz77LKZPn26hyBuPOX1A1FbNnDkTJ06cwNatWy0dikU89NBDyMnJwcGDB/HKK68gKioKp06dsnRYzeby5cuYNWsWtmzZ0uzv+Ltfm3uXUFFREW7cuFFrHT8/P1hbW1crv3LlCry9vbF///4WfTrQ1D64du0ahg4din79+mHTpk0GL7dsqczZDzZt2oTZs2ejuLi4iaOzrIqKCtjZ2eGLL74weGN6VFQUiouL2+RZRkEQsH379jb5BvmYmBh89dVX+P777+Hr62vpcGQhPDwc/v7++PDDDy0dSrPYsWMH/vrXv0KpVEplVVVVEAQBCoUCGo3GYF5TMfltzS1dx44d0bFjR7OW1el0AACNRtOYITU7U/rg6tWrGDZsGEJCQrBx48ZWkawADdsPWjtra2uEhIQgIyNDOkDrdDpkZGQgJibGssFRsxFFEf/4xz+wfft2ZGZmMlm5j06na/HHAVM88cQTOH78uEFZdHQ0evTogddff71ZkhWgDSYs9XXw4EH89NNPGDhwINq3b48LFy5gwYIF8Pf3b9FnV0xx9epVDB06FF27dsXKlStRVFQkzXN3d7dgZM3r0qVLuHnzJi5duoSqqirk5OQAALp16wYHBwfLBtdEYmNjERUVhdDQUPTt2xfJyckoKytDdHS0pUNrNqWlpfjll1+k6YsXLyInJwcuLi7o0qWLBSNrHjNnzkRKSgq++uorODo6SuOXnJ2dYWtra+Homk9cXBxGjRqFLl264Pbt20hJSUFmZiZ2795t6dCajaOjY7WxS/qxnc06pqlZ70lqQY4dOyYOGzZMdHFxEW1sbEQfHx/x5ZdfFq9cuWLp0JrNxo0bRQBGf9qSqKgoo32wd+9eS4fWpN577z2xS5cuorW1tdi3b1/xwIEDlg6pWe3du9fo9x4VFWXp0JpFTb/7GzdutHRozer5558Xu3btKlpbW4sdO3YUn3jiCXHPnj2WDsviLHFbc5sbw0JEREQtT+sYkEBEREStGhMWIiIikj0mLERERCR7TFiIiIhI9piwEBERkewxYSEiIiLZY8JCREREsseEhYiIiGSPCQsRERHJHhMWIiIikj0mLERERCR7/z9HhY5nYwKkDgAAAABJRU5ErkJggg==", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "for ctrl_s in ctrl_str:\n", + " plt.hist(\n", + " [r for r, t in zip(logs[\"env/reward_dist\"], task_list) if t == ctrl_s], density=True, alpha=0.5, label=ctrl_s\n", + " )\n", + "plt.legend(loc=\"best\")\n", + "plt.title(\"reward distribution\")\n", + "plt.grid(True)\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Save model\n", + "Finally, we save the model to disk for later usage." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "gpt2_model.save_pretrained(\"gpt2-imdb-ctrl\")\n", + "gpt2_tokenizer.save_pretrained(\"gpt2-imdb-ctrl\")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "trl", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.12" + }, + "vscode": { + "interpreter": { + "hash": "d2cfb53525227c89f8d14fa784301fa46c451cc9223d94ccce9e17956835eea2" + } + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/examples/notebooks/gpt2-sentiment.ipynb b/examples/notebooks/gpt2-sentiment.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..03d86e385f3b6e0f69452a4540ea1d1e0aed8799 --- /dev/null +++ b/examples/notebooks/gpt2-sentiment.ipynb @@ -0,0 +1,879 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Tune GPT2 to generate positive reviews\n", + "> Optimise GPT2 to produce positive IMDB movie reviews using a BERT sentiment classifier as a reward function." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "
\n", + "\n", + "

Figure: Experiment setup to tune GPT2. The yellow arrows are outside the scope of this notebook, but the trained models are available through Hugging Face.

\n", + "
\n", + "\n", + "\n", + "In this notebook we fine-tune GPT2 (small) to generate positive movie reviews based on the IMDB dataset. The model gets the start of a real review and is tasked to produce positive continuations. To reward positive continuations we use a BERT classifier to analyse the sentiment of the produced sentences and use the classifier's outputs as rewards signals for PPO training." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Setup experiment" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Import dependencies" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%load_ext autoreload\n", + "%autoreload 2" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%pip install transformers trl wandb" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "from tqdm import tqdm\n", + "import pandas as pd\n", + "\n", + "tqdm.pandas()\n", + "\n", + "from transformers import pipeline, AutoTokenizer\n", + "from datasets import load_dataset\n", + "\n", + "from trl import PPOTrainer, PPOConfig, AutoModelForCausalLMWithValueHead\n", + "from trl.core import LengthSampler" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Configuration" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "config = PPOConfig(\n", + " model_name=\"lvwerra/gpt2-imdb\",\n", + " learning_rate=1.41e-5,\n", + " log_with=\"wandb\",\n", + ")\n", + "\n", + "sent_kwargs = {\"return_all_scores\": True, \"function_to_apply\": \"none\", \"batch_size\": 16}" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import wandb\n", + "\n", + "wandb.init()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "You can see that we load a GPT2 model called `gpt2_imdb`. This model was additionally fine-tuned on the IMDB dataset for 1 epoch with the huggingface [script](https://github.com/huggingface/transformers/blob/master/examples/run_language_modeling.py) (no special settings). The other parameters are mostly taken from the original paper [\"Fine-Tuning Language Models from Human Preferences\"](\n", + "https://arxiv.org/pdf/1909.08593.pdf). This model as well as the BERT model is available in the Huggingface model zoo [here](https://huggingface.co/models). The following code should automatically download the models." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Load data and models" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Load IMDB dataset\n", + "The IMDB dataset contains 50k movie review annotated with \"positive\"/\"negative\" feedback indicating the sentiment. We load the IMDB dataset into a DataFrame and filter for comments that are at least 200 characters. Then we tokenize each text and cut it to random size with the `LengthSampler`." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Reusing dataset imdb (/home/leandro/.cache/huggingface/datasets/imdb/plain_text/1.0.0/2fdd8b9bcadd6e7055e742a706876ba43f19faee861df134affd7a3f60fc38a1)\n", + "Loading cached processed dataset at /home/leandro/.cache/huggingface/datasets/imdb/plain_text/1.0.0/2fdd8b9bcadd6e7055e742a706876ba43f19faee861df134affd7a3f60fc38a1/cache-ff455473e884c6a3.arrow\n" + ] + } + ], + "source": [ + "def build_dataset(config, dataset_name=\"imdb\", input_min_text_length=2, input_max_text_length=8):\n", + " \"\"\"\n", + " Build dataset for training. This builds the dataset from `load_dataset`, one should\n", + " customize this function to train the model on its own dataset.\n", + "\n", + " Args:\n", + " dataset_name (`str`):\n", + " The name of the dataset to be loaded.\n", + "\n", + " Returns:\n", + " dataloader (`torch.utils.data.DataLoader`):\n", + " The dataloader for the dataset.\n", + " \"\"\"\n", + " tokenizer = AutoTokenizer.from_pretrained(config.model_name)\n", + " tokenizer.pad_token = tokenizer.eos_token\n", + " # load imdb with datasets\n", + " ds = load_dataset(dataset_name, split=\"train\")\n", + " ds = ds.rename_columns({\"text\": \"review\"})\n", + " ds = ds.filter(lambda x: len(x[\"review\"]) > 200, batched=False)\n", + "\n", + " input_size = LengthSampler(input_min_text_length, input_max_text_length)\n", + "\n", + " def tokenize(sample):\n", + " sample[\"input_ids\"] = tokenizer.encode(sample[\"review\"])[: input_size()]\n", + " sample[\"query\"] = tokenizer.decode(sample[\"input_ids\"])\n", + " return sample\n", + "\n", + " ds = ds.map(tokenize, batched=False)\n", + " ds.set_format(type=\"torch\")\n", + " return ds" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "dataset = build_dataset(config)\n", + "\n", + "\n", + "def collator(data):\n", + " return dict((key, [d[key] for d in data]) for key in data[0])" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Load pre-trained GPT2 language models" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We load the GPT2 model with a value head and the tokenizer. We load the model twice; the first model is optimized while the second model serves as a reference to calculate the KL-divergence from the starting point. This serves as an additional reward signal in the PPO training to make sure the optimized model does not deviate too much from the original language model." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "model = AutoModelForCausalLMWithValueHead.from_pretrained(config.model_name)\n", + "ref_model = AutoModelForCausalLMWithValueHead.from_pretrained(config.model_name)\n", + "tokenizer = AutoTokenizer.from_pretrained(config.model_name)\n", + "\n", + "tokenizer.pad_token = tokenizer.eos_token" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Initialize PPOTrainer\n", + "The `PPOTrainer` takes care of device placement and optimization later on:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "ppo_trainer = PPOTrainer(config, model, ref_model, tokenizer, dataset=dataset, data_collator=collator)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Load BERT classifier\n", + "We load a BERT classifier fine-tuned on the IMDB dataset." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "device = ppo_trainer.accelerator.device\n", + "if ppo_trainer.accelerator.num_processes == 1:\n", + " device = 0 if torch.cuda.is_available() else \"cpu\" # to avoid a `pipeline` bug\n", + "sentiment_pipe = pipeline(\"sentiment-analysis\", model=\"lvwerra/distilbert-imdb\", device=device)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The model outputs are the logits for the negative and positive class. We will use the logits for positive class as a reward signal for the language model." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[[{'label': 'NEGATIVE', 'score': 2.335048198699951},\n", + " {'label': 'POSITIVE', 'score': -2.726576566696167}]]" + ] + }, + "execution_count": null, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "text = \"this movie was really bad!!\"\n", + "sentiment_pipe(text, **sent_kwargs)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[[{'label': 'NEGATIVE', 'score': -2.2947897911071777},\n", + " {'label': 'POSITIVE', 'score': 2.557039737701416}]]" + ] + }, + "execution_count": null, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "text = \"this movie was really good!!\"\n", + "sentiment_pipe(text, **sent_kwargs)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Generation settings\n", + "For the response generation we just use sampling and make sure top-k and nucleus sampling are turned off as well as a minimal length." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "gen_kwargs = {\"min_length\": -1, \"top_k\": 0.0, \"top_p\": 1.0, \"do_sample\": True, \"pad_token_id\": tokenizer.eos_token_id}" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Optimize model" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Training loop" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The training loop consists of the following main steps:\n", + "1. Get the query responses from the policy network (GPT-2)\n", + "2. Get sentiments for query/responses from BERT\n", + "3. Optimize policy with PPO using the (query, response, reward) triplet\n", + "\n", + "**Training time**\n", + "\n", + "This step takes **~2h** on a V100 GPU with the above specified settings." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "output_min_length = 4\n", + "output_max_length = 16\n", + "output_length_sampler = LengthSampler(output_min_length, output_max_length)\n", + "\n", + "\n", + "generation_kwargs = {\n", + " \"min_length\": -1,\n", + " \"top_k\": 0.0,\n", + " \"top_p\": 1.0,\n", + " \"do_sample\": True,\n", + " \"pad_token_id\": tokenizer.eos_token_id,\n", + "}\n", + "\n", + "\n", + "for epoch, batch in tqdm(enumerate(ppo_trainer.dataloader)):\n", + " query_tensors = batch[\"input_ids\"]\n", + "\n", + " #### Get response from gpt2\n", + " response_tensors = []\n", + " for query in query_tensors:\n", + " gen_len = output_length_sampler()\n", + " generation_kwargs[\"max_new_tokens\"] = gen_len\n", + " response = ppo_trainer.generate(query, **generation_kwargs)\n", + " response_tensors.append(response.squeeze()[-gen_len:])\n", + " batch[\"response\"] = [tokenizer.decode(r.squeeze()) for r in response_tensors]\n", + "\n", + " #### Compute sentiment score\n", + " texts = [q + r for q, r in zip(batch[\"query\"], batch[\"response\"])]\n", + " pipe_outputs = sentiment_pipe(texts, **sent_kwargs)\n", + " rewards = [torch.tensor(output[1][\"score\"]) for output in pipe_outputs]\n", + "\n", + " #### Run PPO step\n", + " stats = ppo_trainer.step(query_tensors, response_tensors, rewards)\n", + " ppo_trainer.log_stats(stats, batch, rewards)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Training progress\n", + "If you are tracking the training progress with Weights&Biases you should see a plot similar to the one below. Check out the interactive sample report on wandb.ai: [link](https://app.wandb.ai/huggingface/trl-showcase/runs/1jtvxb1m/).\n", + "\n", + "
\n", + "\n", + "

Figure: Reward mean and distribution evolution during training.

\n", + "
\n", + "\n", + "One can observe how the model starts to generate more positive outputs after a few optimisation steps.\n", + "\n", + "> Note: Investigating the KL-divergence will probably show that at this point the model has not converged to the target KL-divergence, yet. To get there would require longer training or starting with a higher initial coefficient." + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Model inspection\n", + "Let's inspect some examples from the IMDB dataset. We can use `model_ref` to compare the tuned model `model` against the model before optimisation." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/leandro/miniconda3/envs/trl/lib/python3.9/site-packages/transformers/pipelines/base.py:1075: UserWarning: You seem to be using the pipelines sequentially on GPU. In order to maximize efficiency please use a dataset\n", + " warnings.warn(\n" + ] + }, + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
queryresponse (before)response (after)rewards (before)rewards (after)
0Oh dear,what are I saying?! I fast-forwarded throughI must say that I are hanging my head on this-0.858954-1.007609
1I've seenit, as well.<brthree million dialogue throughout, and1.9968072.240883
2Hi:<br /><br/>This movie is a turkey though when it comes to/>I also like that movie. It's so funny-0.4381912.415630
3I'm a writerand I'm not going to be asked to, not a screenwriter. I've written-0.655991-0.724324
4If youabsolutely love sensitive romance, the plot a...are looking at the cinematography, the acting,2.2213090.148751
5OMG thiscasting cast. Obi cult breezy, this ismovie was totally wonderful, I it was the ide...-1.5331392.590190
6It'sunrealistic; the guy who was supposed to be E...a very good film. It reminds us about over-2.0970172.835831
7There is a reallyawful laptop game!<br /><br />I used tointeresting story that set us the journey. Th...-2.3417432.282939
8This ismy favorite part abouta well thought well2.5547942.734139
9Wasn'tWasn't it clichéd?<|endoftext|>anyone else interested in this movie? It's a ...-1.7908022.631960
10This film is another of director TimBurton's masterpiecesCurry's best bombs2.6229172.544106
11I thought this moviewas excellent. I actually laughed 6 times and...was perfect, and I believe it's almost overlo...2.5480222.601913
12This early John Waynefilms looked like an abandoned police beatingfilm is a realistic portrayal of what-1.7422792.609762
13I wasgiven an experience-a big one, almost 25very happy with all the reflections and this ...2.2507092.558540
14Embarrassingly, Iam more at a strict conformity after getting ...had never seen a movie before. There was one ...-2.021666-1.803383
15I am a fanof living on simple islands, and we have visi...of many things and learned how to appreciate ...1.7912972.324461
\n", + "
" + ], + "text/plain": [ + " query \\\n", + "0 Oh dear, \n", + "1 I've seen \n", + "2 Hi:

This movie is a turkey though when it comes to \n", + "3 and I'm not going to be asked to \n", + "4 absolutely love sensitive romance, the plot a... \n", + "5 casting cast. Obi cult breezy, this is \n", + "6 unrealistic; the guy who was supposed to be E... \n", + "7 awful laptop game!

I used to \n", + "8 my favorite part about \n", + "9 Wasn't it clichéd?<|endoftext|> \n", + "10 Burton's masterpieces \n", + "11 was excellent. I actually laughed 6 times and... \n", + "12 films looked like an abandoned police beating \n", + "13 given an experience-a big one, almost 25 \n", + "14 am more at a strict conformity after getting ... \n", + "15 of living on simple islands, and we have visi... \n", + "\n", + " response (after) rewards (before) \\\n", + "0 I must say that I are hanging my head on this -0.858954 \n", + "1 three million dialogue throughout, and 1.996807 \n", + "2 />I also like that movie. It's so funny -0.438191 \n", + "3 , not a screenwriter. I've written -0.655991 \n", + "4 are looking at the cinematography, the acting, 2.221309 \n", + "5 movie was totally wonderful, I it was the ide... -1.533139 \n", + "6 a very good film. It reminds us about over -2.097017 \n", + "7 interesting story that set us the journey. Th... -2.341743 \n", + "8 a well thought well 2.554794 \n", + "9 anyone else interested in this movie? It's a ... -1.790802 \n", + "10 Curry's best bombs 2.622917 \n", + "11 was perfect, and I believe it's almost overlo... 2.548022 \n", + "12 film is a realistic portrayal of what -1.742279 \n", + "13 very happy with all the reflections and this ... 2.250709 \n", + "14 had never seen a movie before. There was one ... -2.021666 \n", + "15 of many things and learned how to appreciate ... 1.791297 \n", + "\n", + " rewards (after) \n", + "0 -1.007609 \n", + "1 2.240883 \n", + "2 2.415630 \n", + "3 -0.724324 \n", + "4 0.148751 \n", + "5 2.590190 \n", + "6 2.835831 \n", + "7 2.282939 \n", + "8 2.734139 \n", + "9 2.631960 \n", + "10 2.544106 \n", + "11 2.601913 \n", + "12 2.609762 \n", + "13 2.558540 \n", + "14 -1.803383 \n", + "15 2.324461 " + ] + }, + "execution_count": null, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "#### get a batch from the dataset\n", + "bs = 16\n", + "game_data = dict()\n", + "dataset.set_format(\"pandas\")\n", + "df_batch = dataset[:].sample(bs)\n", + "game_data[\"query\"] = df_batch[\"query\"].tolist()\n", + "query_tensors = df_batch[\"input_ids\"].tolist()\n", + "\n", + "response_tensors_ref, response_tensors = [], []\n", + "\n", + "#### get response from gpt2 and gpt2_ref\n", + "for i in range(bs):\n", + " gen_len = output_length_sampler()\n", + " output = ref_model.generate(\n", + " torch.tensor(query_tensors[i]).unsqueeze(dim=0).to(device), max_new_tokens=gen_len, **gen_kwargs\n", + " ).squeeze()[-gen_len:]\n", + " response_tensors_ref.append(output)\n", + " output = model.generate(\n", + " torch.tensor(query_tensors[i]).unsqueeze(dim=0).to(device), max_new_tokens=gen_len, **gen_kwargs\n", + " ).squeeze()[-gen_len:]\n", + " response_tensors.append(output)\n", + "\n", + "#### decode responses\n", + "game_data[\"response (before)\"] = [tokenizer.decode(response_tensors_ref[i]) for i in range(bs)]\n", + "game_data[\"response (after)\"] = [tokenizer.decode(response_tensors[i]) for i in range(bs)]\n", + "\n", + "#### sentiment analysis of query/response pairs before/after\n", + "texts = [q + r for q, r in zip(game_data[\"query\"], game_data[\"response (before)\"])]\n", + "game_data[\"rewards (before)\"] = [output[1][\"score\"] for output in sentiment_pipe(texts, **sent_kwargs)]\n", + "\n", + "texts = [q + r for q, r in zip(game_data[\"query\"], game_data[\"response (after)\"])]\n", + "game_data[\"rewards (after)\"] = [output[1][\"score\"] for output in sentiment_pipe(texts, **sent_kwargs)]\n", + "\n", + "# store results in a dataframe\n", + "df_results = pd.DataFrame(game_data)\n", + "df_results" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Looking at the reward mean/median of the generated sequences we observe a significant difference." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "mean:\n" + ] + }, + { + "data": { + "text/plain": [ + "rewards (before) 0.156629\n", + "rewards (after) 1.686487\n", + "dtype: float64" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "median:\n" + ] + }, + { + "data": { + "text/plain": [ + "rewards (before) -0.547091\n", + "rewards (after) 2.479868\n", + "dtype: float64" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "print(\"mean:\")\n", + "display(df_results[[\"rewards (before)\", \"rewards (after)\"]].mean())\n", + "print()\n", + "print(\"median:\")\n", + "display(df_results[[\"rewards (before)\", \"rewards (after)\"]].median())" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Save model\n", + "Finally, we save the model and push it to the Hugging Face for later usage." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/leandro/miniconda3/envs/trl/lib/python3.9/site-packages/huggingface_hub/hf_api.py:1001: FutureWarning: `create_repo` now takes `token` as an optional positional argument. Be sure to adapt your code!\n", + " warnings.warn(\n", + "Cloning https://huggingface.co/lvwerra/gpt2-imdb-pos-v2 into local empty directory.\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "a953a6d0c465432bbc39aca826d37aaf", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Upload file pytorch_model.bin: 0%| | 32.0k/487M [00:00 main\n", + "\n", + "remote: Enforcing permissions... \n", + "remote: Allowed refs: all \n", + "To https://huggingface.co/lvwerra/gpt2-imdb-pos-v2\n", + " 28b9865..42792ea main -> main\n", + "\n" + ] + }, + { + "data": { + "text/plain": [ + "('gpt2-imdb-pos-v2/tokenizer_config.json',\n", + " 'gpt2-imdb-pos-v2/special_tokens_map.json',\n", + " 'gpt2-imdb-pos-v2/vocab.json',\n", + " 'gpt2-imdb-pos-v2/merges.txt',\n", + " 'gpt2-imdb-pos-v2/added_tokens.json',\n", + " 'gpt2-imdb-pos-v2/tokenizer.json')" + ] + }, + "execution_count": null, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "model.save_pretrained(\"gpt2-imdb-pos-v2\", push_to_hub=True)\n", + "tokenizer.save_pretrained(\"gpt2-imdb-pos-v2\", push_to_hub=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "env", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.12 (main, Mar 26 2022, 15:51:15) \n[Clang 13.1.6 (clang-1316.0.21.2)]" + }, + "vscode": { + "interpreter": { + "hash": "4c8ff454cd947027f86954d72bf940c689a97dcc494eb53cfe4813862c6065fe" + } + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/examples/research_projects/README.md b/examples/research_projects/README.md new file mode 100644 index 0000000000000000000000000000000000000000..1b1977e1877ca1d6351cd888b76793a2bad3206d --- /dev/null +++ b/examples/research_projects/README.md @@ -0,0 +1,7 @@ +# Research projects that use TRL + +Welcome to the research projects folder! Here you can find the scripts used for some research projects that used TRL and maintained by the developers and the community (LM de-toxification, Stack-Llama, etc.). Check out the READMEs in the subfolders for more information! + +- [De-detoxifying language models](https://github.com/huggingface/trl/tree/main/examples/research_projects/toxicity) +- [Stack-Llama](https://github.com/huggingface/trl/tree/main/examples/research_projects/stack_llama) +- [Stack-Llama-2](https://github.com/huggingface/trl/tree/main/examples/research_projects/stack_llama_2) \ No newline at end of file diff --git a/examples/research_projects/stack_llama/scripts/README.md b/examples/research_projects/stack_llama/scripts/README.md new file mode 100644 index 0000000000000000000000000000000000000000..60ed5fd94397c3954313cbc88de5648a387f03ea --- /dev/null +++ b/examples/research_projects/stack_llama/scripts/README.md @@ -0,0 +1,18 @@ +# RLHF pipeline for the creation of StackLLaMa: a Stack exchange llama-7b model. +There were three main steps to the training process: +1. Supervised fine-tuning of the base llama-7b model to create llama-7b-se: + - `torchrun --nnodes 1 --nproc_per_node 8 examples/research_projects/stack_llama/scripts/supervised_finetuning.py --model_path= --streaming --no_gradient_checkpointing --learning_rate 1e-5 --max_steps 5000 --output_dir ./llama-se` +2. Reward modeling using dialog pairs from the SE dataset using the llama-7b-se to create llama-7b-se-rm: + - `torchrun --nnodes 1 --nproc_per_node 8 examples/research_projects/stack_llama/scripts/reward_modeling.py --model_name=` +3. RL fine-tuning of llama-7b-se with the llama-7b-se-rm reward model: + - `accelerate launch --multi_gpu --num_machines 1 --num_processes 8 examples/research_projects/stack_llama/scripts/rl_training.py --log_with=wandb --model_name= --reward_model_name= --adafactor=False --tokenizer_name= --save_freq=100 --output_max_length=128 --batch_size=8 --gradient_accumulation_steps=8 --batched_gen=True --ppo_epochs=4 --seed=0 --learning_rate=1.4e-5 --early_stopping=True --output_dir=llama-se-rl-finetune-128-8-8-1.4e-5_adam` + + +LoRA layers were using at all stages to reduce memory requirements. +At each stage the peft adapter layers were merged with the base model, using: +```shell +python examples/research_projects/stack_llama/scripts/merge_peft_adapter.py --adapter_model_name=XXX --base_model_name=YYY --output_name=ZZZ +``` +Note that this script requires `peft>=0.3.0`. + +For access to the base llama-7b model, please see Meta's [release](https://ai.facebook.com/blog/large-language-model-llama-meta-ai/) and [request form](https://docs.google.com/forms/d/e/1FAIpQLSfqNECQnMkycAp2jP4Z9TFX0cGR4uf7b_fBxjY_OjhJILlKGA/viewform). diff --git a/examples/research_projects/stack_llama/scripts/merge_peft_adapter.py b/examples/research_projects/stack_llama/scripts/merge_peft_adapter.py new file mode 100644 index 0000000000000000000000000000000000000000..e7ff3b5cd982c171b0a4db948c7932e60e97fe37 --- /dev/null +++ b/examples/research_projects/stack_llama/scripts/merge_peft_adapter.py @@ -0,0 +1,48 @@ +from dataclasses import dataclass, field +from typing import Optional + +import torch +from peft import PeftConfig, PeftModel +from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer, HfArgumentParser + + +@dataclass +class ScriptArguments: + """ + The input names representing the Adapter and Base model fine-tuned with PEFT, and the output name representing the + merged model. + """ + + adapter_model_name: Optional[str] = field(default=None, metadata={"help": "the adapter name"}) + base_model_name: Optional[str] = field(default=None, metadata={"help": "the base model name"}) + output_name: Optional[str] = field(default=None, metadata={"help": "the merged model name"}) + + +parser = HfArgumentParser(ScriptArguments) +script_args = parser.parse_args_into_dataclasses()[0] +assert script_args.adapter_model_name is not None, "please provide the name of the Adapter you would like to merge" +assert script_args.base_model_name is not None, "please provide the name of the Base model" +assert script_args.output_name is not None, "please provide the output name of the merged model" + +peft_config = PeftConfig.from_pretrained(script_args.adapter_model_name) +if peft_config.task_type == "SEQ_CLS": + # The sequence classification task is used for the reward model in PPO + model = AutoModelForSequenceClassification.from_pretrained( + script_args.base_model_name, num_labels=1, torch_dtype=torch.bfloat16 + ) +else: + model = AutoModelForCausalLM.from_pretrained( + script_args.base_model_name, return_dict=True, torch_dtype=torch.bfloat16 + ) + +tokenizer = AutoTokenizer.from_pretrained(script_args.base_model_name) + +# Load the PEFT model +model = PeftModel.from_pretrained(model, script_args.adapter_model_name) +model.eval() + +model = model.merge_and_unload() + +model.save_pretrained(f"{script_args.output_name}") +tokenizer.save_pretrained(f"{script_args.output_name}") +model.push_to_hub(f"{script_args.output_name}", use_temp_dir=False) diff --git a/examples/research_projects/stack_llama/scripts/reward_modeling.py b/examples/research_projects/stack_llama/scripts/reward_modeling.py new file mode 100644 index 0000000000000000000000000000000000000000..d895d356db5dc7d7af522d4e6bdbefce24bd9b6d --- /dev/null +++ b/examples/research_projects/stack_llama/scripts/reward_modeling.py @@ -0,0 +1,300 @@ +from dataclasses import dataclass, field +from typing import Any, Dict, List, Optional, Union + +import evaluate +import numpy as np +import torch +import torch.nn as nn +from datasets import load_dataset +from peft import LoraConfig, TaskType, get_peft_model +from transformers import ( + AutoModelForSequenceClassification, + AutoTokenizer, + HfArgumentParser, + PreTrainedTokenizerBase, + Trainer, + TrainerCallback, + TrainingArguments, +) +from transformers.utils import PaddingStrategy + + +# Define and parse arguments. +@dataclass +class ScriptArguments: + """ + These arguments vary depending on how many GPUs you have, what their capacity and features are, and what size model you want to train. + """ + + local_rank: Optional[int] = field(default=-1, metadata={"help": "Used for multi-gpu"}) + resume_from_checkpoint: Optional[bool] = field( + default=False, + metadata={"help": "If you want to resume training where it left off."}, + ) + deepspeed: Optional[str] = field( + default=None, + metadata={ + "help": "Path to deepspeed config if using deepspeed. You may need this if the model that you want to train doesn't fit on a single GPU." + }, + ) + per_device_train_batch_size: Optional[int] = field(default=4) + per_device_eval_batch_size: Optional[int] = field(default=1) + gradient_accumulation_steps: Optional[int] = field(default=1) + learning_rate: Optional[float] = field(default=2e-5) + weight_decay: Optional[float] = field(default=0.001) + model_name: Optional[str] = field( + default="gpt2", + metadata={ + "help": "The model that you want to train from the Hugging Face hub. E.g. gpt2, gpt2-xl, bert, etc." + }, + ) + tokenizer_name: Optional[str] = field( + default=None, + metadata={ + "help": "The tokenizer for your model, if left empty will use the default for your model", + }, + ) + bf16: Optional[bool] = field( + default=True, + metadata={ + "help": "This essentially cuts the training time in half if you want to sacrifice a little precision and have a supported GPU." + }, + ) + num_train_epochs: Optional[int] = field( + default=1, + metadata={"help": "The number of training epochs for the reward model."}, + ) + train_subset: Optional[int] = field( + default=100000, + metadata={"help": "The size of the subset of the training data to use"}, + ) + eval_subset: Optional[int] = field( + default=50000, + metadata={"help": "The size of the subset of the eval data to use"}, + ) + gradient_checkpointing: Optional[bool] = field( + default=False, + metadata={"help": "Enables gradient checkpointing."}, + ) + optim: Optional[str] = field( + default="adamw_hf", + metadata={"help": "The optimizer to use."}, + ) + lr_scheduler_type: Optional[str] = field( + default="linear", + metadata={"help": "The lr scheduler"}, + ) + max_length: Optional[int] = field(default=512) + eval_first_step: Optional[bool] = field( + default=False, + metadata={"help": "Whether to run eval after the first step"}, + ) + + +parser = HfArgumentParser(ScriptArguments) +script_args = parser.parse_args_into_dataclasses()[0] + +# Load the human stack-exchange-paired dataset for tuning the reward model. +train_dataset = load_dataset("lvwerra/stack-exchange-paired", data_dir="data/reward", split="train") +if script_args.train_subset > 0: + train_dataset = train_dataset.select(range(script_args.train_subset)) +eval_dataset = load_dataset("lvwerra/stack-exchange-paired", data_dir="data/evaluation", split="train") +if script_args.eval_subset > 0: + eval_dataset = eval_dataset.select(range(script_args.eval_subset)) +# Define the training args. Needs to be done before the model is loaded if you are using deepspeed. +model_name_split = script_args.model_name.split("/")[-1] +output_name = ( + f"{model_name_split}_peft_stack-exchange-paired_rmts__{script_args.train_subset}_{script_args.learning_rate}" +) + +training_args = TrainingArguments( + output_dir=output_name, + learning_rate=script_args.learning_rate, + per_device_train_batch_size=script_args.per_device_train_batch_size, + per_device_eval_batch_size=script_args.per_device_eval_batch_size, + num_train_epochs=script_args.num_train_epochs, + weight_decay=script_args.weight_decay, + evaluation_strategy="steps", + eval_steps=500, + save_strategy="steps", + save_steps=500, + gradient_accumulation_steps=script_args.gradient_accumulation_steps, + gradient_checkpointing=script_args.gradient_checkpointing, + deepspeed=script_args.deepspeed, + local_rank=script_args.local_rank, + remove_unused_columns=False, + label_names=[], + bf16=script_args.bf16, + logging_strategy="steps", + logging_steps=10, + optim=script_args.optim, + lr_scheduler_type=script_args.lr_scheduler_type, +) +# Load the value-head model and tokenizer. +tokenizer_name = script_args.tokenizer_name if script_args.tokenizer_name is not None else script_args.model_name +tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, use_auth_token=True) +tokenizer.pad_token = tokenizer.eos_token + + +peft_config = LoraConfig( + task_type=TaskType.SEQ_CLS, + inference_mode=False, + r=8, + lora_alpha=32, + lora_dropout=0.1, +) + +model = AutoModelForSequenceClassification.from_pretrained( + script_args.model_name, num_labels=1, torch_dtype=torch.bfloat16 +) +model = get_peft_model(model, peft_config) +model.print_trainable_parameters() + +# Need to do this for gpt2, because it doesn't have an official pad token. +tokenizer.pad_token = tokenizer.eos_token +model.config.pad_token_id = tokenizer.eos_token_id +model.config.use_cache = not script_args.gradient_checkpointing +num_proc = 24 # Can adjust to be higher if you have more processors. +original_columns = train_dataset.column_names + + +# Turn the dataset into pairs of post + summaries, where text_j is the preferred question + answer and text_k is the other. +# Then tokenize the dataset. +def preprocess_function(examples): + new_examples = { + "input_ids_j": [], + "attention_mask_j": [], + "input_ids_k": [], + "attention_mask_k": [], + } + for question, response_j, response_k in zip(examples["question"], examples["response_j"], examples["response_k"]): + tokenized_j = tokenizer("Question: " + question + "\n\nAnswer: " + response_j, truncation=True) + tokenized_k = tokenizer("Question: " + question + "\n\nAnswer: " + response_k, truncation=True) + + new_examples["input_ids_j"].append(tokenized_j["input_ids"]) + new_examples["attention_mask_j"].append(tokenized_j["attention_mask"]) + new_examples["input_ids_k"].append(tokenized_k["input_ids"]) + new_examples["attention_mask_k"].append(tokenized_k["attention_mask"]) + + return new_examples + + +# preprocess the dataset and filter out QAs that are longer than script_args.max_length +train_dataset = train_dataset.map( + preprocess_function, + batched=True, + num_proc=num_proc, + remove_columns=original_columns, +) +train_dataset = train_dataset.filter( + lambda x: len(x["input_ids_j"]) <= script_args.max_length and len(x["input_ids_k"]) <= script_args.max_length +) + +eval_dataset = eval_dataset.map( + preprocess_function, + batched=True, + num_proc=num_proc, + remove_columns=original_columns, +) +eval_dataset = eval_dataset.filter( + lambda x: len(x["input_ids_j"]) <= script_args.max_length and len(x["input_ids_k"]) <= script_args.max_length +) + + +# We need to define a special data collator that batches the data in our j vs k format. +@dataclass +class RewardDataCollatorWithPadding: + tokenizer: PreTrainedTokenizerBase + padding: Union[bool, str, PaddingStrategy] = True + max_length: Optional[int] = None + pad_to_multiple_of: Optional[int] = None + return_tensors: str = "pt" + + def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]: + features_j = [] + features_k = [] + for feature in features: + features_j.append( + { + "input_ids": feature["input_ids_j"], + "attention_mask": feature["attention_mask_j"], + } + ) + features_k.append( + { + "input_ids": feature["input_ids_k"], + "attention_mask": feature["attention_mask_k"], + } + ) + batch_j = self.tokenizer.pad( + features_j, + padding=self.padding, + max_length=self.max_length, + pad_to_multiple_of=self.pad_to_multiple_of, + return_tensors=self.return_tensors, + ) + batch_k = self.tokenizer.pad( + features_k, + padding=self.padding, + max_length=self.max_length, + pad_to_multiple_of=self.pad_to_multiple_of, + return_tensors=self.return_tensors, + ) + batch = { + "input_ids_j": batch_j["input_ids"], + "attention_mask_j": batch_j["attention_mask"], + "input_ids_k": batch_k["input_ids"], + "attention_mask_k": batch_k["attention_mask"], + "return_loss": True, + } + return batch + + +# Define the metric that we'll use for validation. +accuracy = evaluate.load("accuracy") + + +def compute_metrics(eval_pred): + predictions, _ = eval_pred + # Here, predictions is rewards_j and rewards_k. + # We want to see how much of the time rewards_j > rewards_k. + predictions = np.argmax(predictions, axis=0) + labels = np.zeros(predictions.shape) + return accuracy.compute(predictions=predictions, references=labels) + + +class RewardTrainer(Trainer): + # Define how to compute the reward loss. We use the InstructGPT pairwise logloss: https://arxiv.org/abs/2203.02155 + def compute_loss(self, model, inputs, return_outputs=False): + rewards_j = model(input_ids=inputs["input_ids_j"], attention_mask=inputs["attention_mask_j"])[0] + rewards_k = model(input_ids=inputs["input_ids_k"], attention_mask=inputs["attention_mask_k"])[0] + loss = -nn.functional.logsigmoid(rewards_j - rewards_k).mean() + if return_outputs: + return loss, {"rewards_j": rewards_j, "rewards_k": rewards_k} + return loss + + +# Train the model, woohoo. +trainer = RewardTrainer( + model=model, + args=training_args, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + compute_metrics=compute_metrics, + data_collator=RewardDataCollatorWithPadding(tokenizer=tokenizer, max_length=script_args.max_length), +) + + +if script_args.eval_first_step: + + class EvaluateFirstStepCallback(TrainerCallback): + def on_step_end(self, args, state, control, **kwargs): + if state.global_step == 1: + control.should_evaluate = True + + trainer.add_callback(EvaluateFirstStepCallback()) + +trainer.train(script_args.resume_from_checkpoint) + +print("Saving last checkpoint of the model") +model.save_pretrained(output_name + "_peft_last_checkpoint") diff --git a/examples/research_projects/stack_llama/scripts/rl_training.py b/examples/research_projects/stack_llama/scripts/rl_training.py new file mode 100644 index 0000000000000000000000000000000000000000..eee7952660f2084236620f2264986e3544562bab --- /dev/null +++ b/examples/research_projects/stack_llama/scripts/rl_training.py @@ -0,0 +1,263 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from dataclasses import dataclass, field +from typing import Optional + +import torch +from accelerate import Accelerator +from datasets import load_dataset +from peft import LoraConfig +from tqdm import tqdm +from transformers import Adafactor, AutoTokenizer, HfArgumentParser, pipeline + +from trl import AutoModelForCausalLMWithValueHead, PPOConfig, PPOTrainer, set_seed +from trl.core import LengthSampler + + +tqdm.pandas() + + +@dataclass +class ScriptArguments: + """ + The name of the Casual LM model we wish to fine with PPO + """ + + # NOTE: gpt2 models use Conv1D instead of Linear layers which are not yet supported in 8 bit mode + # models like gpt-neo* models are more suitable. + model_name: Optional[str] = field(default="", metadata={"help": "the model name"}) + tokenizer_name: Optional[str] = field(default="", metadata={"help": "the tokenizer name"}) + reward_model_name: Optional[str] = field(default="", metadata={"help": "the reward model name"}) + log_with: Optional[str] = field(default=None, metadata={"help": "use 'wandb' to log with wandb"}) + learning_rate: Optional[float] = field(default=1.41e-5, metadata={"help": "the learning rate"}) + output_max_length: Optional[int] = field(default=128, metadata={"help": "maximum length for generation"}) + mini_batch_size: Optional[int] = field(default=1, metadata={"help": "the PPO minibatch size"}) + batch_size: Optional[int] = field(default=32, metadata={"help": "the batch size"}) + ppo_epochs: Optional[int] = field(default=4, metadata={"help": "the number of ppo epochs"}) + gradient_accumulation_steps: Optional[int] = field( + default=4, metadata={"help": "the number of gradient accumulation steps"} + ) + adafactor: Optional[bool] = field(default=False, metadata={"help": "whether to use the adafactor optimizer"}) + early_stopping: Optional[bool] = field(default=False, metadata={"help": "whether to early stop"}) + target_kl: Optional[float] = field(default=0.1, metadata={"help": "kl target for early stopping"}) + reward_baseline: Optional[float] = field( + default=0.0, + metadata={"help": "a baseline value that is subtracted from the reward"}, + ) + batched_gen: Optional[bool] = field(default=False, metadata={"help": "whether to use the batched text gen"}) + save_freq: Optional[int] = field(default=None, metadata={"help": "n steps to save the model"}) + output_dir: Optional[str] = field(default="runs/", metadata={"help": "n steps to save the model"}) + seed: Optional[int] = field(default=0, metadata={"help": "the seed"}) + steps: Optional[int] = field(default=20000, metadata={"help": "number of epochs"}) + init_kl_coef: Optional[float] = field( + default=0.2, + metadata={"help": "Initial KL penalty coefficient (used for adaptive and linear control)"}, + ) + + adap_kl_ctrl: Optional[bool] = field(default=True, metadata={"help": "Use adaptive KL control, otherwise linear"}) + + +parser = HfArgumentParser(ScriptArguments) +script_args: ScriptArguments = parser.parse_args_into_dataclasses()[0] +reward_model_name = script_args.reward_model_name +dataset_name = "lvwerra/stack-exchange-paired" +config = PPOConfig( + steps=script_args.steps, + model_name=script_args.model_name, + learning_rate=script_args.learning_rate, + log_with=script_args.log_with, + batch_size=script_args.batch_size, + mini_batch_size=script_args.mini_batch_size, + gradient_accumulation_steps=script_args.gradient_accumulation_steps, + optimize_cuda_cache=True, + early_stopping=script_args.early_stopping, + target_kl=script_args.target_kl, + ppo_epochs=script_args.ppo_epochs, + seed=script_args.seed, + init_kl_coef=script_args.init_kl_coef, + adap_kl_ctrl=script_args.adap_kl_ctrl, +) + +train_dataset = load_dataset("lvwerra/stack-exchange-paired", data_dir="data/rl", split="train") +train_dataset = train_dataset.select(range(100000)) +original_columns = train_dataset.column_names + +# We then define the arguments to pass to the sentiment analysis pipeline. +# We set `return_all_scores` to True to get the sentiment score for each token. +sent_kwargs = { + "return_all_scores": True, + "function_to_apply": "none", + "batch_size": 16, + "truncation": True, +} + +tokenizer = AutoTokenizer.from_pretrained(script_args.tokenizer_name) +# GPT-2 tokenizer has a pad token, but it is not eos_token by default. We need to set it to eos_token. +# only for this model. + +if getattr(tokenizer, "pad_token", None) is None: + tokenizer.pad_token = tokenizer.eos_token + + +# Below is an example function to build the dataset. In our case, we use the IMDB dataset +# from the `datasets` library. One should customize this function to train the model on +# its own dataset. +def build_dataset( + tokenizer, + dataset_name="lvwerra/stack-exchange-paired", +): + """ + Build dataset for training. This builds the dataset from `load_dataset`, one should + customize this function to train the model on its own dataset. + + Args: + dataset_name (`str`): + The name of the dataset to be loaded. + + Returns: + dataloader (`torch.utils.data.DataLoader`): + The dataloader for the dataset. + """ + + num_proc = 24 + + def preprocess_function(examples): + new_examples = { + "query": [], + "input_ids": [], + } + for question in examples["question"]: + query = "Question: " + question + "\n\nAnswer: " + tokenized_question = tokenizer(query, truncation=True) + new_examples["query"].append(query) + new_examples["input_ids"].append(tokenized_question["input_ids"]) + + return new_examples + + ds = train_dataset.map( + preprocess_function, + batched=True, + num_proc=num_proc, + remove_columns=original_columns, + ) + ds = ds.filter(lambda x: len(x["input_ids"]) < 512, batched=False) + + ds.set_format(type="torch") + return ds + + +# We retrieve the dataloader by calling the `build_dataset` function. +dataset = build_dataset(tokenizer) + + +def collator(data): + return dict((key, [d[key] for d in data]) for key in data[0]) + + +# set seed before initializing value head for deterministic eval +set_seed(config.seed) + +# Now let's build the model, the reference model, and the tokenizer. +current_device = Accelerator().local_process_index + +lora_config = LoraConfig( + r=16, + lora_alpha=32, + lora_dropout=0.05, + bias="none", + task_type="CAUSAL_LM", +) +model = AutoModelForCausalLMWithValueHead.from_pretrained( + config.model_name, + load_in_8bit=True, + device_map={"": current_device}, + peft_config=lora_config, +) + +optimizer = None +if script_args.adafactor: + optimizer = Adafactor( + filter(lambda p: p.requires_grad, model.parameters()), + scale_parameter=False, + relative_step=False, + warmup_init=False, + lr=config.learning_rate, + ) +# We then build the PPOTrainer, passing the model, the reference model, the tokenizer +ppo_trainer = PPOTrainer( + config, + model, + ref_model=None, + tokenizer=tokenizer, + dataset=dataset, + data_collator=collator, + optimizer=optimizer, +) + +# We then build the sentiment analysis pipeline using our reward model, passing the +# model name and the sentiment analysis pipeline arguments. Let's also make sure to +# set the device to the same device as the PPOTrainer. +device = ppo_trainer.accelerator.device +if ppo_trainer.accelerator.num_processes == 1: + device = 0 if torch.cuda.is_available() else "cpu" # to avoid a ` pipeline` bug +sentiment_pipe = pipeline( + "sentiment-analysis", + model=reward_model_name, + device_map={"": current_device}, + model_kwargs={"load_in_8bit": True}, + tokenizer=tokenizer, + return_token_type_ids=False, +) + +# We then define the arguments to pass to the `generate` function. These arguments +# are passed to the `generate` function of the PPOTrainer, which is a wrapper around +# the `generate` function of the trained model. +generation_kwargs = { + # "min_length": -1, + "top_k": 0.0, + "top_p": 1.0, + "do_sample": True, + "pad_token_id": tokenizer.pad_token_id, + "eos_token_id": 100_000, +} +output_min_length = 32 +output_max_length = script_args.output_max_length +output_length_sampler = LengthSampler(output_min_length, output_max_length) + +for epoch, batch in tqdm(enumerate(ppo_trainer.dataloader)): + if epoch >= config.total_ppo_epochs: + break + + question_tensors = batch["input_ids"] + + response_tensors = ppo_trainer.generate( + question_tensors, + return_prompt=False, + length_sampler=output_length_sampler, + **generation_kwargs, + ) + batch["response"] = tokenizer.batch_decode(response_tensors, skip_special_tokens=True) + + # Compute reward score (using the sentiment analysis pipeline) + texts = [q + r for q, r in zip(batch["query"], batch["response"])] + pipe_outputs = sentiment_pipe(texts, **sent_kwargs) + rewards = [torch.tensor(output[0]["score"] - script_args.reward_baseline) for output in pipe_outputs] + + # Run PPO step + stats = ppo_trainer.step(question_tensors, response_tensors, rewards) + ppo_trainer.log_stats(stats, batch, rewards) + + if script_args.save_freq and epoch and epoch % script_args.save_freq == 0: + ppo_trainer.save_pretrained(script_args.output_dir + f"step_{epoch}") diff --git a/examples/research_projects/stack_llama/scripts/supervised_finetuning.py b/examples/research_projects/stack_llama/scripts/supervised_finetuning.py new file mode 100644 index 0000000000000000000000000000000000000000..47669ac8a71a278cbab863fdf7805110130e9b27 --- /dev/null +++ b/examples/research_projects/stack_llama/scripts/supervised_finetuning.py @@ -0,0 +1,208 @@ +import argparse +import os + +from accelerate import Accelerator +from datasets import load_dataset +from peft import LoraConfig +from tqdm import tqdm +from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments, logging, set_seed + +from trl import SFTTrainer +from trl.trainer import ConstantLengthDataset + + +""" +Fine-Tune Llama-7b on SE paired dataset +""" + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--model_path", type=str, default="") + parser.add_argument("--dataset_name", type=str, default="lvwerra/stack-exchange-paired") + parser.add_argument("--subset", type=str, default="data/finetune") + parser.add_argument("--split", type=str, default="train") + parser.add_argument("--size_valid_set", type=int, default=4000) + parser.add_argument("--streaming", action="store_true") + parser.add_argument("--shuffle_buffer", type=int, default=5000) + + parser.add_argument("--seq_length", type=int, default=1024) + parser.add_argument("--max_steps", type=int, default=10000) + parser.add_argument("--batch_size", type=int, default=4) + parser.add_argument("--gradient_accumulation_steps", type=int, default=1) + parser.add_argument("--eos_token_id", type=int, default=49152) + + parser.add_argument("--learning_rate", type=float, default=1e-4) + parser.add_argument("--lr_scheduler_type", type=str, default="cosine") + parser.add_argument("--num_warmup_steps", type=int, default=100) + parser.add_argument("--weight_decay", type=float, default=0.05) + + parser.add_argument("--local_rank", type=int, default=0) + parser.add_argument("--no_fp16", action="store_false") + parser.add_argument("--bf16", action="store_true", default=False) + parser.add_argument("--no_gradient_checkpointing", action="store_false", default=False) + parser.add_argument("--seed", type=int, default=0) + parser.add_argument("--num_workers", type=int, default=None) + parser.add_argument("--output_dir", type=str, default="./checkpoints") + parser.add_argument("--log_freq", default=1, type=int) + parser.add_argument("--eval_freq", default=1000, type=int) + parser.add_argument("--save_freq", default=1000, type=int) + + return parser.parse_args() + + +def chars_token_ratio(dataset, tokenizer, nb_examples=400): + """ + Estimate the average number of characters per token in the dataset. + """ + total_characters, total_tokens = 0, 0 + for _, example in tqdm(zip(range(nb_examples), iter(dataset)), total=nb_examples): + text = prepare_sample_text(example) + total_characters += len(text) + if tokenizer.is_fast: + total_tokens += len(tokenizer(text).tokens()) + else: + total_tokens += len(tokenizer.tokenize(text)) + + return total_characters / total_tokens + + +def print_trainable_parameters(model): + """ + Prints the number of trainable parameters in the model. + """ + trainable_params = 0 + all_param = 0 + for _, param in model.named_parameters(): + all_param += param.numel() + if param.requires_grad: + trainable_params += param.numel() + print( + f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param}" + ) + + +def prepare_sample_text(example): + """Prepare the text from a sample of the dataset.""" + text = f"Question: {example['question']}\n\nAnswer: {example['response_j']}" + return text + + +def create_datasets(tokenizer, args): + dataset = load_dataset( + args.dataset_name, + data_dir=args.subset, + split=args.split, + use_auth_token=True, + num_proc=args.num_workers if not args.streaming else None, + streaming=args.streaming, + ) + if args.streaming: + print("Loading the dataset in streaming mode") + valid_data = dataset.take(args.size_valid_set) + train_data = dataset.skip(args.size_valid_set) + train_data = train_data.shuffle(buffer_size=args.shuffle_buffer, seed=args.seed) + else: + dataset = dataset.train_test_split(test_size=0.005, seed=args.seed) + train_data = dataset["train"] + valid_data = dataset["test"] + print(f"Size of the train set: {len(train_data)}. Size of the validation set: {len(valid_data)}") + + chars_per_token = chars_token_ratio(train_data, tokenizer) + print(f"The character to token ratio of the dataset is: {chars_per_token:.2f}") + + train_dataset = ConstantLengthDataset( + tokenizer, + train_data, + formatting_func=prepare_sample_text, + infinite=True, + seq_length=args.seq_length, + chars_per_token=chars_per_token, + ) + valid_dataset = ConstantLengthDataset( + tokenizer, + valid_data, + formatting_func=prepare_sample_text, + infinite=False, + seq_length=args.seq_length, + chars_per_token=chars_per_token, + ) + return train_dataset, valid_dataset + + +def run_training(args, train_data, val_data): + print("Loading the model") + + lora_config = LoraConfig( + r=16, + lora_alpha=32, + lora_dropout=0.05, + bias="none", + task_type="CAUSAL_LM", + ) + + train_data.start_iteration = 0 + + print("Starting main loop") + + training_args = TrainingArguments( + output_dir=args.output_dir, + dataloader_drop_last=True, + evaluation_strategy="steps", + max_steps=args.max_steps, + eval_steps=args.eval_freq, + save_steps=args.save_freq, + logging_steps=args.log_freq, + per_device_train_batch_size=args.batch_size, + per_device_eval_batch_size=args.batch_size, + learning_rate=args.learning_rate, + lr_scheduler_type=args.lr_scheduler_type, + warmup_steps=args.num_warmup_steps, + gradient_accumulation_steps=args.gradient_accumulation_steps, + gradient_checkpointing=not args.no_gradient_checkpointing, + fp16=not args.no_fp16, + bf16=args.bf16, + weight_decay=args.weight_decay, + run_name="llama-7b-finetuned", + report_to="wandb", + ddp_find_unused_parameters=False, + ) + + model = AutoModelForCausalLM.from_pretrained( + args.model_path, load_in_8bit=True, device_map={"": Accelerator().process_index} + ) + + trainer = SFTTrainer( + model=model, + args=training_args, + train_dataset=train_data, + eval_dataset=val_data, + peft_config=lora_config, + packing=True, + ) + + print_trainable_parameters(trainer.model) + + print("Training...") + trainer.train() + + print("Saving last checkpoint of the model") + trainer.model.save_pretrained(os.path.join(args.output_dir, "final_checkpoint/")) + + +def main(args): + tokenizer = AutoTokenizer.from_pretrained(args.model_path) + train_dataset, eval_dataset = create_datasets(tokenizer, args) + run_training(args, train_dataset, eval_dataset) + + +if __name__ == "__main__": + args = get_args() + assert args.model_path != "", "Please provide the llama model path" + + set_seed(args.seed) + os.makedirs(args.output_dir, exist_ok=True) + + logging.set_verbosity_error() + + main(args) diff --git a/examples/research_projects/stack_llama_2/scripts/README.md b/examples/research_projects/stack_llama_2/scripts/README.md new file mode 100644 index 0000000000000000000000000000000000000000..727a631d8d120f25f4605d93e97539443fd5da8d --- /dev/null +++ b/examples/research_projects/stack_llama_2/scripts/README.md @@ -0,0 +1,76 @@ +# DPO pipeline for the creation of StackLlaMa 2: a Stack exchange llama-v2-7b model + +## Prerequisites + +Install all the dependencies in the `requirements.txt`: + +``` +$ pip install -U -r requirements.txt +``` + +Since we will use `accelerate` for training, make sure to run: +``` +$ accelerate config +``` + +## Training + +There were two main steps to the DPO training process: +1. Supervised fine-tuning of the base llama-v2-7b model to create llama-v2-7b-se: + + ``` + accelerate launch examples/research_projects/stack_llama_2/scripts/sft_llama2.py \ + --output_dir="./sft" \ + --max_steps=500 \ + --logging_steps=10 \ + --save_steps=10 \ + --per_device_train_batch_size=4 \ + --per_device_eval_batch_size=1 \ + --gradient_accumulation_steps=2 \ + --gradient_checkpointing=False \ + --group_by_length=False \ + --learning_rate=1e-4 \ + --lr_scheduler_type="cosine" \ + --warmup_steps=100 \ + --weight_decay=0.05 \ + --optim="paged_adamw_32bit" \ + --bf16=True \ + --remove_unused_columns=False \ + --run_name="sft_llama2" \ + --report_to="wandb" + ``` +1. Run the DPO trainer using the model saved by the previous step: + ``` + accelerate launch examples/research_projects/stack_llama_2/scripts/dpo_llama2.py \ + --model_name_or_path="sft/final_checkpoint" \ + --output_dir="dpo" + ``` + + +## Merging the adaptors + +To merge the adaptors into the base model we can use the `merge_peft_adapter.py` helper script that comes with TRL: + +``` +python examples/research_projects/stack_llama/scripts/merge_peft_adapter.py --base_model_name="meta-llama/Llama-2-7b-hf" --adapter_model_name="dpo/final_checkpoint/" --output_name="stack-llama-2" +``` + +which will also push the model to your HuggingFace hub account. + +## Running the model + +We can load the DPO-trained LoRA adaptors which were saved by the DPO training step and load them via: + +```py +from peft import AutoPeftModelForCausalLM + + +model = AutoPeftModelForCausalLM.from_pretrained( + "dpo/final_checkpoint", + low_cpu_mem_usage=True, + torch_dtype=torch.float16, + load_in_4bit=True, +) + +model.generate(...) +``` diff --git a/examples/research_projects/stack_llama_2/scripts/dpo_llama2.py b/examples/research_projects/stack_llama_2/scripts/dpo_llama2.py new file mode 100644 index 0000000000000000000000000000000000000000..d21ecd3d4b3b96ea96b8d1de271593f30633e413 --- /dev/null +++ b/examples/research_projects/stack_llama_2/scripts/dpo_llama2.py @@ -0,0 +1,223 @@ +# 0. imports +import os +from dataclasses import dataclass, field +from typing import Dict, Optional + +import torch +from datasets import Dataset, load_dataset +from peft import LoraConfig +from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser, TrainingArguments + +from trl import DPOTrainer + + +# Define and parse arguments. +@dataclass +class ScriptArguments: + """ + The arguments for the DPO training script. + """ + + # data parameters + beta: Optional[float] = field(default=0.1, metadata={"help": "the beta parameter for DPO loss"}) + + # training parameters + model_name_or_path: Optional[str] = field( + default="../sft/results/final_checkpoint", + metadata={"help": "the location of the SFT model name or path"}, + ) + learning_rate: Optional[float] = field(default=5e-4, metadata={"help": "optimizer learning rate"}) + lr_scheduler_type: Optional[str] = field(default="cosine", metadata={"help": "the lr scheduler type"}) + warmup_steps: Optional[int] = field(default=100, metadata={"help": "the number of warmup steps"}) + weight_decay: Optional[float] = field(default=0.05, metadata={"help": "the weight decay"}) + optimizer_type: Optional[str] = field(default="paged_adamw_32bit", metadata={"help": "the optimizer type"}) + + per_device_train_batch_size: Optional[int] = field(default=4, metadata={"help": "train batch size per device"}) + per_device_eval_batch_size: Optional[int] = field(default=1, metadata={"help": "eval batch size per device"}) + gradient_accumulation_steps: Optional[int] = field( + default=4, metadata={"help": "the number of gradient accumulation steps"} + ) + gradient_checkpointing: Optional[bool] = field( + default=True, metadata={"help": "whether to use gradient checkpointing"} + ) + + lora_alpha: Optional[float] = field(default=16, metadata={"help": "the lora alpha parameter"}) + lora_dropout: Optional[float] = field(default=0.05, metadata={"help": "the lora dropout parameter"}) + lora_r: Optional[int] = field(default=8, metadata={"help": "the lora r parameter"}) + + max_prompt_length: Optional[int] = field(default=512, metadata={"help": "the maximum prompt length"}) + max_length: Optional[int] = field(default=1024, metadata={"help": "the maximum sequence length"}) + max_steps: Optional[int] = field(default=1000, metadata={"help": "max number of training steps"}) + logging_steps: Optional[int] = field(default=10, metadata={"help": "the logging frequency"}) + save_steps: Optional[int] = field(default=100, metadata={"help": "the saving frequency"}) + eval_steps: Optional[int] = field(default=100, metadata={"help": "the evaluation frequency"}) + + output_dir: Optional[str] = field(default="./results", metadata={"help": "the output directory"}) + log_freq: Optional[int] = field(default=1, metadata={"help": "the logging frequency"}) + + # instrumentation + sanity_check: Optional[bool] = field(default=False, metadata={"help": "only train on 1000 samples"}) + report_to: Optional[str] = field( + default="wandb", + metadata={ + "help": 'The list of integrations to report the results and logs to. Supported platforms are `"azure_ml"`,' + '`"comet_ml"`, `"mlflow"`, `"neptune"`, `"tensorboard"`,`"clearml"` and `"wandb"`. ' + 'Use `"all"` to report to all integrations installed, `"none"` for no integrations.' + }, + ) + # debug argument for distributed training + ignore_bias_buffers: Optional[bool] = field( + default=False, + metadata={ + "help": "fix for DDP issues with LM bias/mask buffers - invalid scalar type,`inplace operation. See" + "https://github.com/huggingface/transformers/issues/22482#issuecomment-1595790992" + }, + ) + + +def get_stack_exchange_paired( + data_dir: str = "data/rl", + sanity_check: bool = False, + cache_dir: str = None, + num_proc=24, +) -> Dataset: + """Load the stack-exchange-paired dataset from Hugging Face and convert it to the necessary format. + + The dataset is converted to a dictionary with the following structure: + { + 'prompt': List[str], + 'chosen': List[str], + 'rejected': List[str], + } + + Prompts are structured as follows: + "Question: " + + "\n\nAnswer: " + """ + dataset = load_dataset( + "lvwerra/stack-exchange-paired", + split="train", + cache_dir=cache_dir, + data_dir=data_dir, + ) + original_columns = dataset.column_names + + if sanity_check: + dataset = dataset.select(range(min(len(dataset), 1000))) + + def return_prompt_and_responses(samples) -> Dict[str, str]: + return { + "prompt": ["Question: " + question + "\n\nAnswer: " for question in samples["question"]], + "chosen": samples["response_j"], + "rejected": samples["response_k"], + } + + return dataset.map( + return_prompt_and_responses, + batched=True, + num_proc=num_proc, + remove_columns=original_columns, + ) + + +if __name__ == "__main__": + parser = HfArgumentParser(ScriptArguments) + script_args = parser.parse_args_into_dataclasses()[0] + + # 1. load a pretrained model + model = AutoModelForCausalLM.from_pretrained( + script_args.model_name_or_path, + low_cpu_mem_usage=True, + torch_dtype=torch.float16, + load_in_4bit=True, + ) + model.config.use_cache = False + + if script_args.ignore_bias_buffers: + # torch distributed hack + model._ddp_params_and_buffers_to_ignore = [ + name for name, buffer in model.named_buffers() if buffer.dtype == torch.bool + ] + + model_ref = AutoModelForCausalLM.from_pretrained( + script_args.model_name_or_path, + low_cpu_mem_usage=True, + torch_dtype=torch.float16, + load_in_4bit=True, + ) + tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf") + tokenizer.pad_token = tokenizer.eos_token + + # 2. Load the Stack-exchange paired dataset + train_dataset = get_stack_exchange_paired(data_dir="data/rl", sanity_check=script_args.sanity_check) + train_dataset = train_dataset.filter( + lambda x: len(x["prompt"]) + len(x["chosen"]) <= script_args.max_length + and len(x["prompt"]) + len(x["rejected"]) <= script_args.max_length + ) + + # 3. Load evaluation dataset + eval_dataset = get_stack_exchange_paired(data_dir="data/evaluation", sanity_check=True) + eval_dataset = eval_dataset.filter( + lambda x: len(x["prompt"]) + len(x["chosen"]) <= script_args.max_length + and len(x["prompt"]) + len(x["rejected"]) <= script_args.max_length + ) + + # 4. initialize training arguments: + training_args = TrainingArguments( + per_device_train_batch_size=script_args.per_device_train_batch_size, + per_device_eval_batch_size=script_args.per_device_eval_batch_size, + max_steps=script_args.max_steps, + logging_steps=script_args.logging_steps, + save_steps=script_args.save_steps, + gradient_accumulation_steps=script_args.gradient_accumulation_steps, + gradient_checkpointing=script_args.gradient_checkpointing, + learning_rate=script_args.learning_rate, + evaluation_strategy="steps", + eval_steps=script_args.eval_steps, + output_dir=script_args.output_dir, + report_to=script_args.report_to, + lr_scheduler_type=script_args.lr_scheduler_type, + warmup_steps=script_args.warmup_steps, + optim=script_args.optimizer_type, + bf16=True, + remove_unused_columns=False, + run_name="dpo_llama2", + ) + + peft_config = LoraConfig( + r=script_args.lora_r, + lora_alpha=script_args.lora_alpha, + lora_dropout=script_args.lora_dropout, + target_modules=[ + "q_proj", + "v_proj", + "k_proj", + "out_proj", + "fc_in", + "fc_out", + "wte", + ], + bias="none", + task_type="CAUSAL_LM", + ) + + # 5. initialize the DPO trainer + dpo_trainer = DPOTrainer( + model, + model_ref, + args=training_args, + beta=script_args.beta, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + tokenizer=tokenizer, + peft_config=peft_config, + max_prompt_length=script_args.max_prompt_length, + max_length=script_args.max_length, + ) + + # 6. train + dpo_trainer.train() + dpo_trainer.save_model(script_args.output_dir) + + # 7. save + output_dir = os.path.join(script_args.output_dir, "final_checkpoint") + dpo_trainer.model.save_pretrained(output_dir) diff --git a/examples/research_projects/stack_llama_2/scripts/requirements.txt b/examples/research_projects/stack_llama_2/scripts/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..ca124e58df8e4269a4d44d3ceccd0e2a05ea4fae --- /dev/null +++ b/examples/research_projects/stack_llama_2/scripts/requirements.txt @@ -0,0 +1,7 @@ +transformers +trl +peft +accelerate +datasets +bitsandbytes +wandb diff --git a/examples/research_projects/stack_llama_2/scripts/sft_llama2.py b/examples/research_projects/stack_llama_2/scripts/sft_llama2.py new file mode 100644 index 0000000000000000000000000000000000000000..94fc4e72c8125bd61240de055cade6a7a2d8978f --- /dev/null +++ b/examples/research_projects/stack_llama_2/scripts/sft_llama2.py @@ -0,0 +1,185 @@ +# Fine-Tune Llama2-7b on SE paired dataset +import os +from dataclasses import dataclass, field +from typing import Optional + +import torch +from accelerate import Accelerator +from datasets import load_dataset +from peft import AutoPeftModelForCausalLM, LoraConfig +from tqdm import tqdm +from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, HfArgumentParser, TrainingArguments + +from trl import SFTTrainer +from trl.import_utils import is_xpu_available +from trl.trainer import ConstantLengthDataset + + +@dataclass +class ScriptArguments: + model_name: Optional[str] = field(default="meta-llama/Llama-2-7b-hf", metadata={"help": "the model name"}) + dataset_name: Optional[str] = field(default="lvwerra/stack-exchange-paired", metadata={"help": "the dataset name"}) + subset: Optional[str] = field(default="data/finetune", metadata={"help": "the subset to use"}) + split: Optional[str] = field(default="train", metadata={"help": "the split to use"}) + size_valid_set: Optional[int] = field(default=4000, metadata={"help": "the size of the validation set"}) + streaming: Optional[bool] = field(default=True, metadata={"help": "whether to stream the dataset"}) + shuffle_buffer: Optional[int] = field(default=5000, metadata={"help": "the shuffle buffer size"}) + seq_length: Optional[int] = field(default=1024, metadata={"help": "the sequence length"}) + num_workers: Optional[int] = field(default=4, metadata={"help": "the number of workers"}) + packing: Optional[bool] = field(default=True, metadata={"help": "whether to use packing for SFTTrainer"}) + + # LoraConfig + lora_alpha: Optional[float] = field(default=16, metadata={"help": "the lora alpha parameter"}) + lora_dropout: Optional[float] = field(default=0.05, metadata={"help": "the lora dropout parameter"}) + lora_r: Optional[int] = field(default=8, metadata={"help": "the lora r parameter"}) + + +parser = HfArgumentParser((ScriptArguments, TrainingArguments)) +script_args, training_args = parser.parse_args_into_dataclasses() +peft_config = LoraConfig( + r=script_args.lora_r, + lora_alpha=script_args.lora_alpha, + lora_dropout=script_args.lora_dropout, + target_modules=["q_proj", "v_proj"], + bias="none", + task_type="CAUSAL_LM", +) + +if training_args.group_by_length and script_args.packing: + raise ValueError("Cannot use both packing and group by length") + +# `gradient_checkpointing` was True by default until `1f3314`, but it's actually not used. +# `gradient_checkpointing=True` will cause `Variable._execution_engine.run_backward`. +if training_args.gradient_checkpointing: + raise ValueError("gradient_checkpointing not supported") + + +def chars_token_ratio(dataset, tokenizer, nb_examples=400): + """ + Estimate the average number of characters per token in the dataset. + """ + total_characters, total_tokens = 0, 0 + for _, example in tqdm(zip(range(nb_examples), iter(dataset)), total=nb_examples): + text = prepare_sample_text(example) + total_characters += len(text) + if tokenizer.is_fast: + total_tokens += len(tokenizer(text).tokens()) + else: + total_tokens += len(tokenizer.tokenize(text)) + + return total_characters / total_tokens + + +def print_trainable_parameters(model): + """ + Prints the number of trainable parameters in the model. + """ + trainable_params = 0 + all_param = 0 + for _, param in model.named_parameters(): + all_param += param.numel() + if param.requires_grad: + trainable_params += param.numel() + print( + f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param}" + ) + + +def prepare_sample_text(example): + """Prepare the text from a sample of the dataset.""" + text = f"Question: {example['question']}\n\nAnswer: {example['response_j']}" + return text + + +def create_datasets(tokenizer, args): + dataset = load_dataset( + args.dataset_name, + data_dir=args.subset, + split=args.split, + use_auth_token=True, + num_proc=args.num_workers if not args.streaming else None, + streaming=args.streaming, + ) + if args.streaming: + print("Loading the dataset in streaming mode") + valid_data = dataset.take(args.size_valid_set) + train_data = dataset.skip(args.size_valid_set) + train_data = train_data.shuffle(buffer_size=args.shuffle_buffer, seed=None) + else: + dataset = dataset.train_test_split(test_size=0.005, seed=None) + train_data = dataset["train"] + valid_data = dataset["test"] + print(f"Size of the train set: {len(train_data)}. Size of the validation set: {len(valid_data)}") + + chars_per_token = chars_token_ratio(train_data, tokenizer) + print(f"The character to token ratio of the dataset is: {chars_per_token:.2f}") + + train_dataset = ConstantLengthDataset( + tokenizer, + train_data, + formatting_func=prepare_sample_text, + infinite=True, + seq_length=args.seq_length, + chars_per_token=chars_per_token, + ) + valid_dataset = ConstantLengthDataset( + tokenizer, + valid_data, + formatting_func=prepare_sample_text, + infinite=False, + seq_length=args.seq_length, + chars_per_token=chars_per_token, + ) + return train_dataset, valid_dataset + + +bnb_config = BitsAndBytesConfig( + load_in_4bit=True, + bnb_4bit_quant_type="nf4", + bnb_4bit_compute_dtype=torch.bfloat16, +) + +base_model = AutoModelForCausalLM.from_pretrained( + script_args.model_name, + quantization_config=bnb_config, + device_map={"": Accelerator().local_process_index}, + trust_remote_code=True, + use_auth_token=True, +) +base_model.config.use_cache = False + + +tokenizer = AutoTokenizer.from_pretrained(script_args.model_name, trust_remote_code=True) +tokenizer.pad_token = tokenizer.eos_token +tokenizer.padding_side = "right" # Fix weird overflow issue with fp16 training + +train_dataset, eval_dataset = create_datasets(tokenizer, script_args) + +trainer = SFTTrainer( + model=base_model, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + peft_config=peft_config, + packing=script_args.packing, + max_seq_length=None, + tokenizer=tokenizer, + args=training_args, +) +trainer.train() +trainer.save_model(training_args.output_dir) + +output_dir = os.path.join(training_args.output_dir, "final_checkpoint") +trainer.model.save_pretrained(output_dir) + +# Free memory for merging weights +del base_model +if is_xpu_available(): + torch.xpu.empty_cache() +else: + torch.cuda.empty_cache() + +model = AutoPeftModelForCausalLM.from_pretrained(output_dir, device_map="auto", torch_dtype=torch.bfloat16) +model = model.merge_and_unload() + +output_merged_dir = os.path.join(training_args.output_dir, "final_merged_checkpoint") +model.save_pretrained(output_merged_dir, safe_serialization=True) diff --git a/examples/research_projects/tools/calculator.py b/examples/research_projects/tools/calculator.py new file mode 100644 index 0000000000000000000000000000000000000000..76779695fe741eda8f293be5d24c1a521b13efc3 --- /dev/null +++ b/examples/research_projects/tools/calculator.py @@ -0,0 +1,119 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import re + +import numpy as np +import torch +from transformers import AutoTokenizer, load_tool + +from trl import AutoModelForCausalLMWithValueHead, PPOConfig, PPOTrainer, TextEnvironment + + +def generate_data(n): + """Generate random arithmetic tasks and answers.""" + tasks, answers = [], [] + for _ in range(n): + a = np.random.randint(0, 50) + b = np.random.randint(0, 50) + op = np.random.choice(["-", "+", "*"]) + tasks.append(f"\n\nWhat is {a} {op} {b}?") + if op == "-": + answers.append(a - b) + elif op == "+": + answers.append(a + b) + else: + answers.append(a * b) + return tasks, answers + + +def exact_match_reward(responses, answers=None): + """Reward if generated response contains correct answer.""" + rewards = [] + pattern = r"Result\s*=\s*(-?\d+(?:\.\d+)?)\s*" # generated by chatGPT + for response, answer in zip(responses, answers): + reward = 0.0 + predicted_number = None + match_pattern = re.findall(pattern, response) + if match_pattern: + predicted_number = float(match_pattern[0]) + if predicted_number is not None: + if np.abs(predicted_number - answer) < 0.01: + reward += 1.0 + rewards.append(torch.tensor(reward)) + return rewards + + +# set up models +model_id = "gpt2" +model = AutoModelForCausalLMWithValueHead.from_pretrained(model_id) +model_ref = AutoModelForCausalLMWithValueHead.from_pretrained(model_id) +tokenizer = AutoTokenizer.from_pretrained(model_id) +tokenizer.pad_token = tokenizer.eos_token + +# system prompt +prompt = """\ +What is 13-3? + +13-310.0 + +Result=10 + +What is 4*3? + +4*312.0 + +Result=12""" + +generation_kwargs = { + "min_length": -1, + "top_k": 0.0, + "top_p": 1.0, + "do_sample": True, + "pad_token_id": tokenizer.eos_token_id, + "eos_token_id": -1, + "max_new_tokens": 32, +} + +# trainer +ppo_config = PPOConfig( + batch_size=256, + learning_rate=1.41e-5, + mini_batch_size=64, + log_with="wandb", +) +ppo_trainer = PPOTrainer(ppo_config, model, model_ref, tokenizer) + +# text env +text_env = TextEnvironment( + model, + tokenizer, + {"SimpleCalculatorTool": load_tool("ybelkada/simple-calculator")}, + exact_match_reward, + prompt, + generation_kwargs=generation_kwargs, +) + +# main training loop +for step in range(100): + tasks, answers = generate_data(ppo_config.batch_size) + queries, responses, masks, rewards, histories = text_env.run(tasks, answers=answers) + train_stats = ppo_trainer.step(queries, responses, rewards, masks) + + response_texts = [tokenizer.decode(response) for response in responses] + query_texts = [tokenizer.decode(query) for query in queries] + texts = {"query": [qt.split("")[-1].strip() for qt in query_texts], "response": response_texts} + ppo_trainer.log_stats(train_stats, texts, rewards, columns_to_log=["query", "response", "answer"]) +ppo_trainer.save_pretrained(model_id + "-calculator") diff --git a/examples/research_projects/tools/python_interpreter.py b/examples/research_projects/tools/python_interpreter.py new file mode 100644 index 0000000000000000000000000000000000000000..b7b69806ef31922a4d49b9dc823e78dd2d9b49c4 --- /dev/null +++ b/examples/research_projects/tools/python_interpreter.py @@ -0,0 +1,194 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import re +from dataclasses import dataclass, field +from typing import Optional + +import numpy as np +import torch +from datasets import load_dataset +from peft import LoraConfig +from transformers import AutoTokenizer, HfArgumentParser, load_tool + +from trl import AutoModelForCausalLMWithValueHead, PPOConfig, PPOTrainer, TextEnvironment + + +os.environ["HF_ALLOW_CODE_EVAL"] = "1" +os.environ["TOKENIZERS_PARALLELISM"] = "false" + + +@dataclass +class ScriptArguments: + model_name: Optional[str] = field(default="bigcode/starcoderbase", metadata={"help": "the model name"}) + learning_rate: Optional[float] = field(default=1e-5, metadata={"help": "the learning rate"}) + mini_batch_size: Optional[int] = field(default=1, metadata={"help": "the PPO minibatch size"}) + batch_size: Optional[int] = field(default=32, metadata={"help": "the batch size"}) + gradient_accumulation_steps: Optional[int] = field( + default=16, metadata={"help": "the number of gradient accumulation steps"} + ) + max_new_tokens: Optional[int] = field(default=256, metadata={"help": "max number of generated tokens per turn"}) + ppo_epochs: Optional[int] = field(default=1, metadata={"help": "max number of ppo epochs"}) + n_epochs: Optional[int] = field(default=32, metadata={"help": "max number of ppo epochs"}) + + +parser = HfArgumentParser(ScriptArguments) +args = parser.parse_args_into_dataclasses()[0] + + +def exact_match_reward(responses, answers=None): + """Reward if generated response contains correct answer.""" + rewards = [] + pattern = r"Result\s*=\s*(-?\d+(?:\.\d+)?)\s*" # generated by chatGPT + for response, answer in zip(responses, answers): + reward = 0.0 + try: + predicted_number = None + match_pattern = re.findall(pattern, response) + if match_pattern: + predicted_number = float(match_pattern[0]) + if predicted_number is not None: + if np.abs((predicted_number - float(answer))) < 0.1: + reward += 1.0 + except: # noqa + pass + rewards.append(torch.tensor(reward)) + return rewards + + +def evaluate(test_dataloader, text_env, ppo_trainer): + test_rewards = [] + for test_batch in test_dataloader: + _, _, _, rewards, _ = text_env.run(test_batch["query"], answers=test_batch["answer"]) + test_rewards.extend(rewards) + test_rewards = ppo_trainer.accelerator.gather_for_metrics( + torch.stack(test_rewards).to(ppo_trainer.accelerator.device) + ) + return test_rewards.mean() + + +lora_config = LoraConfig( + r=16, + lora_alpha=32, + lora_dropout=0.05, + bias="none", + task_type="CAUSAL_LM", + target_modules=["c_proj", "c_attn", "q_attn"], +) + +# set up models +model = AutoModelForCausalLMWithValueHead.from_pretrained( + args.model_name, + use_auth_token=True, + load_in_4bit=True, + peft_config=lora_config, +) +tokenizer = AutoTokenizer.from_pretrained(args.model_name, use_auth_token=True) +tokenizer.pad_token = tokenizer.eos_token + +ds = load_dataset("gsm8k", "main", split="train") +ds = ds.rename_columns({"question": "query"}) +ds = ds.map(lambda x: {"answer": x["answer"].split("#### ")[1]}) +ds = ds.select(range(1, len(ds))) # skip the first sample which is used in prompt + +ds_test = load_dataset("gsm8k", "main", split="test") +ds_test = ds_test.rename_columns({"question": "query"}) +ds_test = ds_test.map(lambda x: {"answer": x["answer"].split("#### ")[1]}) + +test_dataloader = torch.utils.data.DataLoader(ds_test, batch_size=args.batch_size) + +# prompt +prompt = """\ +Example of using a Python API to solve math questions. + +Q: Olivia has $23. She bought five bagels for $3 each. How much money does she have left? + + +def solution(): + money_initial = 23 + bagels = 5 + bagel_cost = 3 + money_spent = bagels * bagel_cost + money_left = money_initial - money_spent + result = money_left + return result +print(solution()) +72 + +Result = 72 + +Q: """ + +generation_kwargs = { + "min_length": -1, + "top_k": 0.0, + "top_p": 1.0, + "do_sample": True, + "pad_token_id": tokenizer.eos_token_id, + "eos_token_id": -1, + "max_new_tokens": args.max_new_tokens, +} + +# trainer +ppo_config = PPOConfig( + batch_size=args.batch_size, + learning_rate=args.learning_rate, + mini_batch_size=args.mini_batch_size, + ppo_epochs=args.ppo_epochs, + gradient_accumulation_steps=args.gradient_accumulation_steps, + log_with="wandb", + tracker_project_name="trl-gsm8k", + remove_unused_columns=False, + optimize_cuda_cache=True, +) + +ppo_trainer = PPOTrainer(config=ppo_config, model=model, tokenizer=tokenizer, dataset=ds) +test_dataloader = ppo_trainer.accelerator.prepare(test_dataloader) + +# text env +text_env = TextEnvironment( + model, + tokenizer, + [load_tool("lvwerra/python-interpreter")], + exact_match_reward, + prompt, + max_turns=2, + generation_kwargs=generation_kwargs, +) + +# main training loop +for epoch in range(args.n_epochs): + for step, batch in enumerate(ppo_trainer.dataloader): + if (step == 0) and (epoch % 4 == 0): # evaluate every 4 epochs + reward_mean_test = evaluate(test_dataloader, text_env, ppo_trainer) + else: + reward_mean_test = None + + queries, responses, masks, rewards, histories = text_env.run(batch["query"], answers=batch["answer"]) + train_stats = ppo_trainer.step(queries, responses, rewards, masks) + + # logging + if reward_mean_test is not None: + train_stats["env/reward_mean_test"] = reward_mean_test + texts = { + "query": batch["query"], + "response": [tokenizer.decode(response) for response in responses], + "answer": batch["answer"], + } + ppo_trainer.log_stats(train_stats, texts, rewards, columns_to_log=["query", "response", "answer"]) + +reward_mean_test = evaluate(test_dataloader, text_env, ppo_trainer) +ppo_trainer.save_pretrained(f"model/{args.model_name}-gsm8k") diff --git a/examples/research_projects/tools/triviaqa.py b/examples/research_projects/tools/triviaqa.py new file mode 100644 index 0000000000000000000000000000000000000000..bd3bd9016618a2d9625831c1cd8c970b15bea646 --- /dev/null +++ b/examples/research_projects/tools/triviaqa.py @@ -0,0 +1,191 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +from dataclasses import dataclass, field +from typing import Optional + +import torch +from datasets import load_dataset +from peft import LoraConfig +from transformers import AutoTokenizer, HfArgumentParser, load_tool + +from trl import AutoModelForCausalLMWithValueHead, PPOConfig, PPOTrainer, TextEnvironment + + +os.environ["HF_ALLOW_CODE_EVAL"] = "1" +os.environ["TOKENIZERS_PARALLELISM"] = "false" + + +@dataclass +class ScriptArguments: + model_name: Optional[str] = field(default="bigcode/starcoderbase", metadata={"help": "the model name"}) + log_with: Optional[str] = field(default=None, metadata={"help": "use 'wandb' to log with wandb"}) + learning_rate: Optional[float] = field(default=1e-5, metadata={"help": "the learning rate"}) + mini_batch_size: Optional[int] = field(default=1, metadata={"help": "the PPO minibatch size"}) + batch_size: Optional[int] = field(default=32, metadata={"help": "the batch size"}) + gradient_accumulation_steps: Optional[int] = field( + default=16, metadata={"help": "the number of gradient accumulation steps"} + ) + max_new_tokens: Optional[int] = field(default=256, metadata={"help": "max number of generated tokens per turn"}) + ppo_epochs: Optional[int] = field(default=1, metadata={"help": "max number of ppo epochs"}) + iterations: Optional[int] = field(default=1000, metadata={"help": "the number of iterations"}) + seed: Optional[int] = field(default=0, metadata={"help": "the random seed"}) + + +parser = HfArgumentParser(ScriptArguments) +args = parser.parse_args_into_dataclasses()[0] + +lora_config = LoraConfig( + r=16, + lora_alpha=32, + lora_dropout=0.05, + bias="none", + task_type="CAUSAL_LM", + target_modules=["c_proj", "c_attn", "q_attn"], +) + +# set up models +model = AutoModelForCausalLMWithValueHead.from_pretrained( + args.model_name, + use_auth_token=True, + trust_remote_code=True, + load_in_4bit=True, + peft_config=lora_config, +) +tokenizer = AutoTokenizer.from_pretrained(args.model_name, use_auth_token=True) +tokenizer.pad_token = tokenizer.eos_token + +# system prompt +prompt = """\ +Answer the following question: + +Q: In which branch of the arts is Patricia Neary famous? +A: Ballets +A2: Patricia NearyPatricia Neary (born October 27, 1942) is an American ballerina, choreographer and ballet director, who has been particularly active in Switzerland. She has also been a highly successful ambassador for the Balanchine Trust, bringing George Balanchine's ballets to 60 cities around the globe. +Result=Ballets + +Q: Who won Super Bowl XX? +A: Chicago Bears +A2: Super Bowl XXSuper Bowl XX was an American football game between the National Football Conference (NFC) champion Chicago Bears and the American Football Conference (AFC) champion New England Patriots to decide the National Football League (NFL) champion for the 1985 season. The Bears defeated the Patriots by the score of 46–10, capturing their first NFL championship (and Chicago's first overall sports victory) since 1963, three years prior to the birth of the Super Bowl. Super Bowl XX was played on January 26, 1986 at the Louisiana Superdome in New Orleans. +Result=Chicago Bears + +Q: """ + +generation_kwargs = { + "min_length": -1, + "top_k": 0.0, + "top_p": 1.0, + "do_sample": True, + "pad_token_id": tokenizer.eos_token_id, + "eos_token_id": -1, + "max_new_tokens": args.max_new_tokens, +} + +# trainer +config = PPOConfig( + batch_size=args.batch_size, + model_name=args.model_name, + learning_rate=args.learning_rate, + log_with=args.log_with, + mini_batch_size=args.mini_batch_size, + ppo_epochs=args.ppo_epochs, + gradient_accumulation_steps=args.gradient_accumulation_steps, + seed=args.seed, + optimize_cuda_cache=True, +) +ppo_trainer = PPOTrainer(config=config, model=model, tokenizer=tokenizer) +dataset = load_dataset("trivia_qa", "rc", split="train") +local_seed = args.seed + ppo_trainer.accelerator.process_index * 100003 # Prime +dataset = dataset.shuffle(local_seed) + + +def data_generator(): + for i in range(len(dataset)): + yield dataset[i]["question"], [item for item in dataset[i]["answer"]["normalized_aliases"]] + + +gen = data_generator() +gen = iter(gen) + + +def generate_data(n): + tasks, answers = [], [] + for i in range(n): + q, a = next(gen) + tasks.append(q) + answers.append(a) + return tasks, answers + + +def exact_match_reward(responses, answers=None): + """Reward if generated response contains correct answer.""" + rewards = [] + for response, answer in zip(responses, answers): + reward = 0.0 + for a in answer: + if a.lower() in response.lower(): + reward += 1.0 + break + rewards.append(torch.tensor(reward)) + return rewards + + +# text env +tool = load_tool("vwxyzjn/pyserini-wikipedia-kilt-doc") +# limit the amount if tokens +tool_fn = lambda x: tool(x).split("\n")[1][:600] # noqa +text_env = TextEnvironment( + model, + tokenizer, + {"Wiki": tool_fn}, + exact_match_reward, + prompt, + generation_kwargs=generation_kwargs, + max_tool_reponse=400, +) + + +def print_trainable_parameters(model): + trainable_params = 0 + all_param = 0 + for _, param in model.named_parameters(): + all_param += param.numel() + if param.requires_grad: + trainable_params += param.numel() + print( + f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param}" + ) + + +print_trainable_parameters(model) +# main training loop +for i in range(args.iterations): + tasks, answers = generate_data(config.batch_size) + queries, responses, masks, rewards, histories = text_env.run(tasks, answers=answers) + train_stats = ppo_trainer.step(queries, responses, rewards, masks) + response_texts = [tokenizer.decode(response) for response in responses] + query_texts = [tokenizer.decode(query) for query in queries] + texts = { + "query": [qt.split("")[-1].strip() for qt in query_texts], + "response": response_texts, + "answer": [", ".join(item) for item in answers], + } + all_rewards = ppo_trainer.accelerator.gather(torch.tensor(rewards, device=ppo_trainer.accelerator.device)) + ppo_trainer.log_stats( + train_stats, texts, [item for item in all_rewards], columns_to_log=["query", "response", "answer"] + ) + if i % 100 == 0: + ppo_trainer.save_pretrained(f"models/{args.model_name}_{args.seed}_{i}_triviaqa") diff --git a/examples/research_projects/toxicity/README.md b/examples/research_projects/toxicity/README.md new file mode 100644 index 0000000000000000000000000000000000000000..85967ab57ec5eeb10ea9eb6e372a62a0522e4d7e --- /dev/null +++ b/examples/research_projects/toxicity/README.md @@ -0,0 +1,7 @@ +# De-detoxifying language models + +To run this code, do the following: + +```shell +ACCELERATE_LOG_LEVEL=info accelerate launch --config_file {CONFIG} examples/research_projects/toxicity/scripts/gpt-j-6b-toxicity.py --log_with wandb +``` diff --git a/examples/research_projects/toxicity/scripts/evaluate-toxicity.py b/examples/research_projects/toxicity/scripts/evaluate-toxicity.py new file mode 100644 index 0000000000000000000000000000000000000000..c400641967544d96b768bd43f84536c393fe7684 --- /dev/null +++ b/examples/research_projects/toxicity/scripts/evaluate-toxicity.py @@ -0,0 +1,130 @@ +import argparse +import csv + +import evaluate +import numpy as np +import torch +from datasets import load_dataset +from tqdm import tqdm +from transformers import AutoModelForCausalLM, AutoTokenizer + +from trl.import_utils import is_xpu_available + + +toxicity = evaluate.load("ybelkada/toxicity", "DaNLP/da-electra-hatespeech-detection", module_type="measurement") +ds = load_dataset("OxAISH-AL-LLM/wiki_toxic", split="test") + +parser = argparse.ArgumentParser(description="Evaluate de-toxified models") +parser.add_argument("--model_type", default="all", type=str, help="Relative path to the source model folder") +parser.add_argument("--output_file", default="toxicity.csv", type=str, help="Relative path to the source model folder") +parser.add_argument("--batch_size", default=64, type=int, help="Batch size") +parser.add_argument("--num_samples", default=400, type=int, help="Number of samples") +parser.add_argument("--context_length", default=2000, type=int, help="Number of samples") +parser.add_argument("--max_new_tokens", default=30, type=int, help="Max new tokens for generation") +args = parser.parse_args() + + +if args.model_type == "all": + MODELS_TO_TEST = [ + "ybelkada/gpt-neo-125m-detox", + "EleutherAI/gpt-neo-125M", + "EleutherAI/gpt-neo-2.7B", + "ybelkada/gpt-neo-2.7B-detox", + "ybelkada/gpt-j-6b-sharded-bf16", + "ybelkada/gpt-j-6b-detoxs", + ] +elif args.model_type == "gpt-neo": + MODELS_TO_TEST = [ + "ybelkada/gpt-neo-125m-detox", + "EleutherAI/gpt-neo-125M", + "EleutherAI/gpt-neo-2.7B", + "ybelkada/gpt-neo-2.7B-detox", + ] +elif args.model_type == "gpt-j": + MODELS_TO_TEST = [ + "ybelkada/gpt-j-6b-sharded-bf16", + "ybelkada/gpt-j-6b-detox", + ] +else: + MODELS_TO_TEST = [args.model_type] +NUM_SAMPLES = args.num_samples +BATCH_SIZE = args.batch_size +output_file = args.output_file +max_new_tokens = args.max_new_tokens +context_length = args.context_length +if is_xpu_available(): + device = torch.xpu.current_device() +else: + device = torch.cuda.current_device() if torch.cuda.is_available() else "cpu" + +# consider only toxic prompts +ds = ds.filter(lambda x: x["label"] == 1) + +toxicities = {} + +# open a csv file +file = open(f"{output_file}", "w", newline="") +writer = csv.writer(file) +# add first rows +writer.writerow(["model_id", "mean_toxicity", "std_toxicity"]) + + +for model_id in tqdm(MODELS_TO_TEST): + model = AutoModelForCausalLM.from_pretrained(model_id, device_map={"": device}, torch_dtype=torch.bfloat16) + tokenizer = AutoTokenizer.from_pretrained(model_id) + tokenizer.pad_token = tokenizer.eos_token + tokenizer.padding_side = "left" + input_texts = [] + + for i, example in enumerate(ds): + # set seed + torch.manual_seed(42) + + input_text = example["comment_text"] + input_texts.append(input_text[:2000]) + + if i > NUM_SAMPLES: + break + + if (i + 1) % BATCH_SIZE == 0: + inputs = tokenizer(input_texts, return_tensors="pt", padding=True).to(device) + inputs.input_ids = inputs.input_ids[:context_length] + inputs.attention_mask = inputs.attention_mask[:context_length] + outputs = model.generate(**inputs, do_sample=True, max_new_tokens=max_new_tokens, use_cache=True) + generated_texts = tokenizer.batch_decode(outputs, skip_special_tokens=True) + generated_texts = [ + generated_text.replace(input_texts[i], "") for i, generated_text in enumerate(generated_texts) + ] + toxicity_score = toxicity.compute(predictions=generated_texts) + input_texts = [] + + if model_id not in toxicities: + toxicities[model_id] = [] + toxicities[model_id].extend(toxicity_score["toxicity"]) + + # last batch + inputs = tokenizer(input_texts, return_tensors="pt", padding=True).to(device) + outputs = model.generate(**inputs, do_sample=True, max_new_tokens=30) + generated_texts = tokenizer.batch_decode(outputs, skip_special_tokens=True) + generated_texts = [generated_text.replace(input_texts[i], "") for i, generated_text in enumerate(generated_texts)] + toxicity_score = toxicity.compute(predictions=generated_texts) + toxicities[model_id].extend(toxicity_score["toxicity"]) + + # compute mean & std using np + mean = np.mean(toxicities[model_id]) + std = np.std(toxicities[model_id]) + + # save to file + writer.writerow([model_id, mean, std]) + + # print + print(f"Model: {model_id} - Mean: {mean} - Std: {std}") + + model = None + if is_xpu_available(): + torch.xpu.empty_cache() + else: + torch.cuda.empty_cache() + +# close file +file.close() diff --git a/examples/research_projects/toxicity/scripts/gpt-j-6b-toxicity.py b/examples/research_projects/toxicity/scripts/gpt-j-6b-toxicity.py new file mode 100644 index 0000000000000000000000000000000000000000..a4fc18534b25d7dd564816675d47caa82957896a --- /dev/null +++ b/examples/research_projects/toxicity/scripts/gpt-j-6b-toxicity.py @@ -0,0 +1,238 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from dataclasses import dataclass, field +from typing import Optional + +import torch +from datasets import load_dataset +from torch.optim import Adam +from tqdm import tqdm +from transformers import ( + AutoModelForCausalLM, + AutoTokenizer, + HfArgumentParser, + RobertaForSequenceClassification, + RobertaTokenizer, +) + +from trl import AutoModelForCausalLMWithValueHead, PPOConfig, PPOTrainer, create_reference_model, set_seed +from trl.core import LengthSampler + + +tqdm.pandas() + +######################################################################## +# This is a fully working simple example to use trl with accelerate. +# +# This example fine-tunes a GPTJ model to generate less toxic contents +# by using allenai/real-toxicity-prompts dataset. We use PPO +# (proximal policy optimization) to optimize the model. +# in any of the following settings (with the same script): +# - single CPU or single GPU +# - multi GPUS (using PyTorch distributed mode) +# - multi GPUS (using DeepSpeed ZeRO-Offload stages 1 & 2) +# - fp16 (mixed-precision) or fp32 (normal precision) +# +# To run it in each of these various modes, first initialize the accelerate +# configuration with `accelerate config` +# +######################################################################## + + +# We first define the configuration of the experiment, defining the model, the dataset, +# the training parameters, and the PPO parameters. +# Check the default arguments in the `PPOConfig` class for more details. +# If you want to log with tensorboard, add the kwarg +# `project_kwargs={"logging_dir": PATH_TO_LOGS}` to the PPOConfig. +@dataclass +class ScriptArguments: + """ + The name of the Casual LM model we wish to fine with PPO + """ + + # NOTE: gpt2 models use Conv1D instead of Linear layers which are not yet supported in 8 bit mode + # models like gpt-neo* models are more suitable. + model_name: Optional[str] = field(default="ybelkada/gpt-j-6b-sharded-bf16", metadata={"help": "the model name"}) + log_with: Optional[str] = field(default=None, metadata={"help": "use 'wandb' to log with wandb"}) + learning_rate: Optional[float] = field(default=(1.47e-5) * 2, metadata={"help": "the learning rate"}) + mini_batch_size: Optional[int] = field(default=4, metadata={"help": "the PPO minibatch size"}) + batch_size: Optional[int] = field(default=16, metadata={"help": "the batch size"}) + gradient_accumulation_steps: Optional[int] = field( + default=1, metadata={"help": "the number of gradient accumulation steps"} + ) + model_save_path: Optional[str] = field( + default="./gpt-j-6B-detoxified-long-context-26-shl-1e4-final", + metadata={"help": "the path to save the model"}, + ) + + +parser = HfArgumentParser(ScriptArguments) +script_args = parser.parse_args_into_dataclasses()[0] + +config = PPOConfig( + model_name=script_args.model_name, + learning_rate=script_args.learning_rate, + log_with=script_args.log_with, + ppo_epochs=100, + mini_batch_size=script_args.mini_batch_size, + batch_size=script_args.batch_size, + gradient_accumulation_steps=script_args.gradient_accumulation_steps, +) + + +# Below is an example function to build the dataset. In our case, we use the IMDB dataset +# from the `datasets` library. One should customize this function to train the model on +# its own dataset. +def build_dataset( + config, dataset_name="allenai/real-toxicity-prompts", input_min_text_length=5, input_max_text_length=10 +): + """ + Build dataset for training. This builds the dataset from `load_dataset`, one should + customize this function to train the model on its own dataset. + + Args: + dataset_name (`str`): + The name of the dataset to be loaded. + + Returns: + dataloader (`torch.utils.data.DataLoader`): + The dataloader for the dataset. + """ + tokenizer = AutoTokenizer.from_pretrained(config.model_name) + tokenizer.pad_token = tokenizer.eos_token + + ds = load_dataset(dataset_name, split="train") + + def filter_fn(sample): + toxicity = sample["prompt"]["toxicity"] + return toxicity is not None and toxicity > 0.3 + + ds = ds.filter(filter_fn, batched=False) + + input_size = LengthSampler(input_min_text_length, input_max_text_length) + + def tokenize(sample): + prompt = sample["prompt"]["text"] + continuation = sample["continuation"]["text"] + + sample["input_ids"] = tokenizer.encode(prompt + continuation)[: input_size()] + sample["query"] = tokenizer.decode(sample["input_ids"]) + return sample + + ds = ds.map(tokenize, batched=False) + ds.set_format(type="torch") + + ds = ds.train_test_split(test_size=0.2, shuffle=False)["train"] + + return ds + + +# We retrieve the dataloader by calling the `build_dataset` function. +min_input_length = 30 +max_input_length = 40 +dataset = build_dataset(config, input_min_text_length=min_input_length, input_max_text_length=max_input_length) + + +def collator(data): + return dict((key, [d[key] for d in data]) for key in data[0]) + + +# set seed before initializing value head for deterministic eval +set_seed(config.seed) + +# Now let's build the model, the reference model, and the tokenizer. We first load the model +# in bfloat16 to save memory using `transformers`. +model = AutoModelForCausalLM.from_pretrained(config.model_name, torch_dtype=torch.bfloat16) +# And then we pass the loaded model to `AutoModelForCausalLMWithValueHead`. +model = AutoModelForCausalLMWithValueHead.from_pretrained(model) + +# We create a reference model by sharing 20 layers +ref_model = create_reference_model(model, num_shared_layers=20) + +# We make sure to use `Adam` optimizer on the model parameters that require gradients. +optimizer = Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=config.learning_rate) + +# GPT-2 / GPT-J tokenizer has a pad token, but it is not eos_token by default. We need to set it to eos_token. +# only for this model. +tokenizer = AutoTokenizer.from_pretrained(config.model_name) +tokenizer.pad_token = tokenizer.eos_token + +# We then build the PPOTrainer, passing the model, the reference model, the tokenizer +ppo_trainer = PPOTrainer( + config, + model, + ref_model=ref_model, + tokenizer=tokenizer, + dataset=dataset, + data_collator=collator, + optimizer=optimizer, +) + +# We then build the reward pipeline, we will use the toxicity model to compute the reward. +# We first load the toxicity model and tokenizer. +toxicity_model_id = "facebook/roberta-hate-speech-dynabench-r4-target" +toxicity_tokenizer = RobertaTokenizer.from_pretrained(toxicity_model_id) +# We load the toxicity model in fp16 to save memory. +toxicity_model = RobertaForSequenceClassification.from_pretrained(toxicity_model_id, torch_dtype=torch.float16).to( + ppo_trainer.accelerator.device +) + + +# We then define the arguments to pass to the `generate` function. These arguments +# are passed to the `generate` function of the PPOTrainer, which is a wrapper around +# the `generate` function of the trained model. +generation_kwargs = { + "min_length": -1, + "top_k": 0.0, + "top_p": 1.0, + "do_sample": True, + "pad_token_id": tokenizer.eos_token_id, +} +output_min_length = 20 +output_max_length = 30 +output_length_sampler = LengthSampler(output_min_length, output_max_length) + +model_save_path = script_args.model_save_path + +for epoch, batch in tqdm(enumerate(ppo_trainer.dataloader)): + query_tensors = batch["input_ids"] + + # Get response from the policy model + response_tensors = [] + for query in query_tensors: + gen_len = output_length_sampler() + generation_kwargs["max_new_tokens"] = gen_len + response = ppo_trainer.generate(query, **generation_kwargs) + response_tensors.append(response.squeeze()[-gen_len:]) + batch["response"] = [tokenizer.decode(r.squeeze()) for r in response_tensors] + + # Compute sentiment score # noqa + texts = batch["response"] + toxicity_inputs = toxicity_tokenizer(texts, padding=True, truncation=True, return_tensors="pt").to( + ppo_trainer.accelerator.device + ) + logits = toxicity_model(**toxicity_inputs).logits.float() + toxicity_labels = (logits[:, 0]).tolist() + + rewards = [torch.tensor(output) for output in toxicity_labels] + + # Run PPO step + stats = ppo_trainer.step(query_tensors, response_tensors, rewards) + ppo_trainer.log_stats(stats, batch, rewards) + + # Save model every 100 epochs + if epoch % 100 == 0: + if ppo_trainer.accelerator.is_main_process: + ppo_trainer.save_pretrained(model_save_path) diff --git a/examples/scripts/ddpo.py b/examples/scripts/ddpo.py new file mode 100644 index 0000000000000000000000000000000000000000..d42145e4d5aff761826c0d6bfefda6712c92bd22 --- /dev/null +++ b/examples/scripts/ddpo.py @@ -0,0 +1,204 @@ +# Copyright 2023 metric-space, The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import os +from dataclasses import dataclass, field + +import numpy as np +import torch +import torch.nn as nn +import tyro +from huggingface_hub import hf_hub_download +from huggingface_hub.utils import EntryNotFoundError +from transformers import CLIPModel, CLIPProcessor + +from trl import DDPOConfig, DDPOTrainer, DefaultDDPOStableDiffusionPipeline +from trl.import_utils import is_xpu_available + + +@dataclass +class ScriptArguments: + hf_user_access_token: str + pretrained_model: str = "runwayml/stable-diffusion-v1-5" + """the pretrained model to use""" + pretrained_revision: str = "main" + """the pretrained model revision to use""" + hf_hub_model_id: str = "ddpo-finetuned-stable-diffusion" + """HuggingFace repo to save model weights to""" + hf_hub_aesthetic_model_id: str = "trl-lib/ddpo-aesthetic-predictor" + """HuggingFace model ID for aesthetic scorer model weights""" + hf_hub_aesthetic_model_filename: str = "aesthetic-model.pth" + """HuggingFace model filename for aesthetic scorer model weights""" + + ddpo_config: DDPOConfig = field( + default_factory=lambda: DDPOConfig( + num_epochs=200, + train_gradient_accumulation_steps=1, + sample_num_steps=50, + sample_batch_size=6, + train_batch_size=3, + sample_num_batches_per_epoch=4, + per_prompt_stat_tracking=True, + per_prompt_stat_tracking_buffer_size=32, + tracker_project_name="stable_diffusion_training", + log_with="wandb", + project_kwargs={ + "logging_dir": "./logs", + "automatic_checkpoint_naming": True, + "total_limit": 5, + "project_dir": "./save", + }, + ) + ) + + +class MLP(nn.Module): + def __init__(self): + super().__init__() + self.layers = nn.Sequential( + nn.Linear(768, 1024), + nn.Dropout(0.2), + nn.Linear(1024, 128), + nn.Dropout(0.2), + nn.Linear(128, 64), + nn.Dropout(0.1), + nn.Linear(64, 16), + nn.Linear(16, 1), + ) + + @torch.no_grad() + def forward(self, embed): + return self.layers(embed) + + +class AestheticScorer(torch.nn.Module): + """ + This model attempts to predict the aesthetic score of an image. The aesthetic score + is a numerical approximation of how much a specific image is liked by humans on average. + This is from https://github.com/christophschuhmann/improved-aesthetic-predictor + """ + + def __init__(self, *, dtype, model_id, model_filename): + super().__init__() + self.clip = CLIPModel.from_pretrained("openai/clip-vit-large-patch14") + self.processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14") + self.mlp = MLP() + try: + cached_path = hf_hub_download(model_id, model_filename) + except EntryNotFoundError: + cached_path = os.path.join(model_id, model_filename) + state_dict = torch.load(cached_path) + self.mlp.load_state_dict(state_dict) + self.dtype = dtype + self.eval() + + @torch.no_grad() + def __call__(self, images): + device = next(self.parameters()).device + inputs = self.processor(images=images, return_tensors="pt") + inputs = {k: v.to(self.dtype).to(device) for k, v in inputs.items()} + embed = self.clip.get_image_features(**inputs) + # normalize embedding + embed = embed / torch.linalg.vector_norm(embed, dim=-1, keepdim=True) + return self.mlp(embed).squeeze(1) + + +def aesthetic_scorer(hub_model_id, model_filename): + scorer = AestheticScorer( + model_id=hub_model_id, + model_filename=model_filename, + dtype=torch.float32, + ) + scorer = scorer.xpu() if is_xpu_available() else scorer.cuda() + + def _fn(images, prompts, metadata): + images = (images * 255).round().clamp(0, 255).to(torch.uint8) + scores = scorer(images) + return scores, {} + + return _fn + + +# list of example prompts to feed stable diffusion +animals = [ + "cat", + "dog", + "horse", + "monkey", + "rabbit", + "zebra", + "spider", + "bird", + "sheep", + "deer", + "cow", + "goat", + "lion", + "frog", + "chicken", + "duck", + "goose", + "bee", + "pig", + "turkey", + "fly", + "llama", + "camel", + "bat", + "gorilla", + "hedgehog", + "kangaroo", +] + + +def prompt_fn(): + return np.random.choice(animals), {} + + +def image_outputs_logger(image_data, global_step, accelerate_logger): + # For the sake of this example, we will only log the last batch of images + # and associated data + result = {} + images, prompts, _, rewards, _ = image_data[-1] + + for i, image in enumerate(images): + prompt = prompts[i] + reward = rewards[i].item() + result[f"{prompt:.25} | {reward:.2f}"] = image.unsqueeze(0) + + accelerate_logger.log_images( + result, + step=global_step, + ) + + +if __name__ == "__main__": + args = tyro.cli(ScriptArguments) + + pipeline = DefaultDDPOStableDiffusionPipeline( + args.pretrained_model, pretrained_model_revision=args.pretrained_revision, use_lora=True + ) + + trainer = DDPOTrainer( + args.ddpo_config, + aesthetic_scorer(args.hf_hub_aesthetic_model_id, args.hf_hub_aesthetic_model_filename), + prompt_fn, + pipeline, + image_samples_hook=image_outputs_logger, + ) + + trainer.train() + + trainer.push_to_hub(args.hf_hub_model_id, token=args.hf_user_access_token) diff --git a/examples/scripts/dpo.py b/examples/scripts/dpo.py new file mode 100644 index 0000000000000000000000000000000000000000..337c99af6acc2c4bd013cbcddbea2257ceea7b07 --- /dev/null +++ b/examples/scripts/dpo.py @@ -0,0 +1,182 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Note: you need to install transformers from main to run this script. See https://huggingface.co/docs/transformers/installation#install-from-source +# TODO: bump transformers version in requirements at next release. + +# 0. imports +from dataclasses import dataclass, field +from typing import Dict, Optional + +import torch +from datasets import Dataset, load_dataset +from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser, TrainingArguments + +from trl import DPOTrainer + + +# Define and parse arguments. +@dataclass +class ScriptArguments: + """ + The arguments for the DPO training script. + """ + + # data parameters + beta: Optional[float] = field(default=0.1, metadata={"help": "the beta parameter for DPO loss"}) + + # training parameters + model_name_or_path: Optional[str] = field(default="gpt2", metadata={"help": "the model name"}) + learning_rate: Optional[float] = field(default=1e-3, metadata={"help": "optimizer learning rate"}) + per_device_train_batch_size: Optional[int] = field(default=4, metadata={"help": "batch size per device"}) + gradient_accumulation_steps: Optional[int] = field( + default=1, metadata={"help": "the number of gradient accumulation steps"} + ) + max_length: Optional[int] = field(default=512, metadata={"help": "max length of each sample"}) + max_prompt_length: Optional[int] = field(default=128, metadata={"help": "max length of each sample's prompt"}) + max_target_length: Optional[int] = field( + default=128, metadata={"help": "Only used for encoder decoder model. Max target of each sample's prompt"} + ) + label_pad_token_id: Optional[int] = field(default=-100, metadata={"help": "label for non response tokens"}) + max_steps: Optional[int] = field(default=1000, metadata={"help": "max number of training steps"}) + # instrumentation + sanity_check: Optional[bool] = field(default=True, metadata={"help": "only train on 1000 samples"}) + report_to: Optional[str] = field( + default=None, + metadata={ + "help": 'The list of integrations to report the results and logs to. Supported platforms are `"azure_ml"`,' + '`"comet_ml"`, `"mlflow"`, `"neptune"`, `"tensorboard"`,`"clearml"` and `"wandb"`. ' + 'Use `"all"` to report to all integrations installed, `"none"` for no integrations.' + }, + ) + # debug argument for distributed training + ignore_bias_buffers: Optional[bool] = field( + default=False, + metadata={ + "help": "fix for DDP issues with LM bias/mask buffers - invalid scalar type,`inplace operation. See" + "https://github.com/huggingface/transformers/issues/22482#issuecomment-1595790992" + }, + ) + gradient_checkpointing: Optional[bool] = field( + default=False, metadata={"help": "Whether to use gradient checkpointing or no"} + ) + gradient_checkpointing_kwargs: Optional[dict] = field( + default=None, + metadata={ + "help": "key word arguments to be passed along `torch.utils.checkpoint.checkpoint` method - e.g. `use_reentrant=False`" + }, + ) + + +def extract_anthropic_prompt(prompt_and_response): + """Extract the anthropic prompt from a prompt and response pair.""" + search_term = "\n\nAssistant:" + search_term_idx = prompt_and_response.rfind(search_term) + assert search_term_idx != -1, f"Prompt and response does not contain '{search_term}'" + return prompt_and_response[: search_term_idx + len(search_term)] + + +def get_hh(split: str, sanity_check: bool = False, silent: bool = False, cache_dir: str = None) -> Dataset: + """Load the Anthropic Helpful-Harmless dataset from Hugging Face and convert it to the necessary format. + + The dataset is converted to a dictionary with the following structure: + { + 'prompt': List[str], + 'chosen': List[str], + 'rejected': List[str], + } + + Prompts should be structured as follows: + \n\nHuman: \n\nAssistant: + Multiple turns are allowed, but the prompt should always start with \n\nHuman: and end with \n\nAssistant:. + """ + dataset = load_dataset("Anthropic/hh-rlhf", split=split, cache_dir=cache_dir) + if sanity_check: + dataset = dataset.select(range(min(len(dataset), 1000))) + + def split_prompt_and_responses(sample) -> Dict[str, str]: + prompt = extract_anthropic_prompt(sample["chosen"]) + return { + "prompt": prompt, + "chosen": sample["chosen"][len(prompt) :], + "rejected": sample["rejected"][len(prompt) :], + } + + return dataset.map(split_prompt_and_responses) + + +if __name__ == "__main__": + parser = HfArgumentParser(ScriptArguments) + script_args = parser.parse_args_into_dataclasses()[0] + + # 1. load a pretrained model + model = AutoModelForCausalLM.from_pretrained(script_args.model_name_or_path) + + if script_args.ignore_bias_buffers: + # torch distributed hack + model._ddp_params_and_buffers_to_ignore = [ + name for name, buffer in model.named_buffers() if buffer.dtype == torch.bool + ] + + model_ref = AutoModelForCausalLM.from_pretrained(script_args.model_name_or_path) + + tokenizer = AutoTokenizer.from_pretrained(script_args.model_name_or_path) + if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token + + # 2. Load the Anthropic Helpful-Harmless dataset + train_dataset = get_hh("train", sanity_check=script_args.sanity_check) + + # 3. Load evaluation dataset + eval_dataset = get_hh("test", sanity_check=script_args.sanity_check) + + # 4. initialize training arguments: + training_args = TrainingArguments( + per_device_train_batch_size=script_args.per_device_train_batch_size, + max_steps=script_args.max_steps, + remove_unused_columns=False, + gradient_accumulation_steps=script_args.gradient_accumulation_steps, + learning_rate=script_args.learning_rate, + evaluation_strategy="steps", + logging_first_step=True, + logging_steps=10, # match results in blog post + eval_steps=500, + output_dir="./test", + optim="rmsprop", + warmup_steps=150, + report_to=script_args.report_to, + bf16=True, + gradient_checkpointing=script_args.gradient_checkpointing, + # TODO: uncomment that on the next transformers release + # gradient_checkpointing_kwargs=script_args.gradient_checkpointing_kwargs, + ) + + # 5. initialize the DPO trainer + dpo_trainer = DPOTrainer( + model, + model_ref, + args=training_args, + beta=script_args.beta, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + tokenizer=tokenizer, + max_length=script_args.max_length, + max_target_length=script_args.max_target_length, + max_prompt_length=script_args.max_prompt_length, + generate_during_eval=True, + ) + + # 6. train + dpo_trainer.train() diff --git a/examples/scripts/ppo.py b/examples/scripts/ppo.py new file mode 100644 index 0000000000000000000000000000000000000000..d84c1083820a5f6b9145f82f4506f1108117c269 --- /dev/null +++ b/examples/scripts/ppo.py @@ -0,0 +1,210 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from dataclasses import dataclass, field +from typing import Optional + +import torch +import tyro +from accelerate import Accelerator +from datasets import load_dataset +from peft import LoraConfig +from tqdm import tqdm +from transformers import AutoTokenizer, pipeline + +from trl import AutoModelForCausalLMWithValueHead, AutoModelForSeq2SeqLMWithValueHead, PPOConfig, PPOTrainer, set_seed +from trl.core import LengthSampler +from trl.import_utils import is_xpu_available + + +tqdm.pandas() + + +@dataclass +class ScriptArguments: + ppo_config: PPOConfig = field( + default_factory=lambda: PPOConfig( + model_name="lvwerra/gpt2-imdb", + query_dataset="imdb", + reward_model="sentiment-analysis:lvwerra/distilbert-imdb", + learning_rate=1.41e-5, + log_with=None, + mini_batch_size=128, + batch_size=128, + gradient_accumulation_steps=1, + early_stopping=False, + target_kl=6.0, + kl_penalty="kl", + seed=0, + use_score_scaling=False, + use_score_norm=False, + score_clip=None, + ) + ) + use_seq2seq: bool = False + """whether to use seq2seq models""" + use_peft: bool = False + """whether to use peft""" + peft_config: Optional[LoraConfig] = field( + default_factory=lambda: LoraConfig( + r=16, + lora_alpha=16, + bias="none", + task_type="CAUSAL_LM", + ), + ) + trust_remote_code: bool = field(default=False, metadata={"help": "Enable `trust_remote_code`"}) + + +args = tyro.cli(ScriptArguments) + + +# We then define the arguments to pass to the sentiment analysis pipeline. +# We set `return_all_scores` to True to get the sentiment score for each token. +sent_kwargs = {"return_all_scores": True, "function_to_apply": "none", "batch_size": 16} + +trl_model_class = AutoModelForCausalLMWithValueHead if not args.use_seq2seq else AutoModelForSeq2SeqLMWithValueHead + + +# Below is an example function to build the dataset. In our case, we use the IMDB dataset +# from the `datasets` library. One should customize this function to train the model on +# its own dataset. +def build_dataset(config, query_dataset, input_min_text_length=2, input_max_text_length=8): + """ + Build dataset for training. This builds the dataset from `load_dataset`, one should + customize this function to train the model on its own dataset. + + Args: + query_dataset (`str`): + The name of the dataset to be loaded. + + Returns: + dataloader (`torch.utils.data.DataLoader`): + The dataloader for the dataset. + """ + tokenizer = AutoTokenizer.from_pretrained(config.model_name) + tokenizer.pad_token = tokenizer.eos_token + # load imdb with datasets + ds = load_dataset(query_dataset, split="train") + ds = ds.rename_columns({"text": "review"}) + ds = ds.filter(lambda x: len(x["review"]) > 200, batched=False) + + input_size = LengthSampler(input_min_text_length, input_max_text_length) + + def tokenize(sample): + sample["input_ids"] = tokenizer.encode(sample["review"])[: input_size()] + sample["query"] = tokenizer.decode(sample["input_ids"]) + return sample + + ds = ds.map(tokenize, batched=False) + ds.set_format(type="torch") + return ds + + +# We retrieve the dataloader by calling the `build_dataset` function. +dataset = build_dataset(args.ppo_config, args.ppo_config.query_dataset) + + +def collator(data): + return dict((key, [d[key] for d in data]) for key in data[0]) + + +# set seed before initializing value head for deterministic eval +set_seed(args.ppo_config.seed) + +# Now let's build the model, the reference model, and the tokenizer. +if not args.use_peft: + ref_model = trl_model_class.from_pretrained(args.ppo_config.model_name, trust_remote_code=args.trust_remote_code) + device_map = None + peft_config = None +else: + peft_config = args.peft_config + ref_model = None + # Copy the model to each device + device_map = {"": Accelerator().local_process_index} + +model = trl_model_class.from_pretrained( + args.ppo_config.model_name, + trust_remote_code=args.trust_remote_code, + device_map=device_map, + peft_config=peft_config, +) + + +tokenizer = AutoTokenizer.from_pretrained(args.ppo_config.model_name) + +# Some tokenizers like GPT-2's don't have a padding token by default, so we set one here. +tokenizer.pad_token_id = tokenizer.eos_token_id + +# We then build the PPOTrainer, passing the model, the reference model, the tokenizer +ppo_trainer = PPOTrainer(args.ppo_config, model, ref_model, tokenizer, dataset=dataset, data_collator=collator) + +# We then build the sentiment analysis pipeline, passing the model name and the +# sentiment analysis pipeline arguments. Let's also make sure to set the device +# to the same device as the PPOTrainer. +device = ppo_trainer.accelerator.device +if ppo_trainer.accelerator.num_processes == 1: + if is_xpu_available(): + device = "xpu:0" + else: + device = 0 if torch.cuda.is_available() else "cpu" # to avoid a `pipeline` bug +ds_plugin = ppo_trainer.accelerator.state.deepspeed_plugin +task, model_name = args.ppo_config.reward_model.split(":") +if ds_plugin is not None and ds_plugin.is_zero3_init_enabled(): + with ds_plugin.zero3_init_context_manager(enable=False): + sentiment_pipe = pipeline(task, model=model_name, device=device) +else: + sentiment_pipe = pipeline(task, model=model_name, device=device) + +# Some tokenizers like GPT-2's don't have a padding token by default, so we set one here. +if sentiment_pipe.tokenizer.pad_token_id is None: + sentiment_pipe.tokenizer.pad_token_id = tokenizer.pad_token_id + +if sentiment_pipe.model.config.pad_token_id is None: + sentiment_pipe.model.config.pad_token_id = tokenizer.pad_token_id + +# We then define the arguments to pass to the `generate` function. These arguments +# are passed to the `generate` function of the PPOTrainer, which is a wrapper around +# the `generate` function of the trained model. +generation_kwargs = { + "min_length": -1, + "top_k": 0.0, + "top_p": 1.0, + "do_sample": True, + "pad_token_id": tokenizer.eos_token_id, + "max_new_tokens": 32, +} + +for epoch, batch in tqdm(enumerate(ppo_trainer.dataloader)): + query_tensors = batch["input_ids"] + + # Get response from gpt2 + response_tensors, ref_response_tensors = ppo_trainer.generate( + query_tensors, return_prompt=False, generate_ref_response=True, **generation_kwargs + ) + batch["response"] = tokenizer.batch_decode(response_tensors) + batch["ref_response"] = tokenizer.batch_decode(ref_response_tensors) + + # Compute sentiment score + texts = [q + r for q, r in zip(batch["query"], batch["response"])] + pipe_outputs = sentiment_pipe(texts, **sent_kwargs) + rewards = [torch.tensor(output[1]["score"]) for output in pipe_outputs] + ref_texts = [q + r for q, r in zip(batch["query"], batch["ref_response"])] + ref_pipe_outputs = sentiment_pipe(ref_texts, **sent_kwargs) + ref_rewards = [torch.tensor(output[1]["score"]) for output in ref_pipe_outputs] + batch["ref_rewards"] = ref_rewards + + # Run PPO step + stats = ppo_trainer.step(query_tensors, response_tensors, rewards) + ppo_trainer.log_stats(stats, batch, rewards, columns_to_log=["query", "response", "ref_response", "ref_rewards"]) diff --git a/examples/scripts/ppo_multi_adapter.py b/examples/scripts/ppo_multi_adapter.py new file mode 100644 index 0000000000000000000000000000000000000000..044a2419de8b0d3859d1951a34d85824caa58480 --- /dev/null +++ b/examples/scripts/ppo_multi_adapter.py @@ -0,0 +1,151 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from dataclasses import dataclass, field +from typing import Optional + +import torch +from datasets import load_dataset +from peft import LoraConfig +from tqdm import tqdm +from transformers import AutoTokenizer, BitsAndBytesConfig, HfArgumentParser + +from trl import AutoModelForCausalLMWithValueHead, PPOConfig, PPOTrainer, is_xpu_available +from trl.core import LengthSampler + + +input_min_text_length = 6 +input_max_text_length = 12 + + +@dataclass +class ScriptArguments: + """ + The name of the Casual LM model we wish to fine with PPO + """ + + model_name: Optional[str] = field(default="huggyllama/llama-7b", metadata={"help": "the model name"}) + dataset_name: Optional[str] = field(default="Anthropic/hh-rlhf", metadata={"help": "the dataset name"}) + rm_adapter: Optional[str] = field( + default="trl-lib/llama-7b-hh-rm-adapter", metadata={"help": "the rm adapter name"} + ) + log_with: Optional[str] = field(default=None, metadata={"help": "use 'wandb' to log with wandb"}) + use_safetensors: Optional[bool] = field(default=False, metadata={"help": "Use safetensors"}) + seed: Optional[int] = field(default=0, metadata={"help": "the random seed"}) + use_score_scaling: Optional[bool] = field(default=False, metadata={"help": "Use score scaling"}) + use_score_norm: Optional[bool] = field( + default=False, metadata={"help": "Use score normalization. Only applicable if use_score_scaling is True"} + ) + score_clip: Optional[float] = field(default=None, metadata={"help": "Score clipping"}) + + +parser = HfArgumentParser(ScriptArguments) +script_args = parser.parse_args_into_dataclasses()[0] + + +def create_and_prepare_dataset(tokenizer): + dataset = load_dataset(script_args.dataset_name, split="train[:1%]") + + input_size = LengthSampler(input_min_text_length, input_max_text_length) + + def tokenize(example): + text_size = input_size() + example["input_ids"] = tokenizer.encode(example["chosen"])[:text_size] + example["query"] = tokenizer.decode(example["input_ids"]) + return example + + dataset = dataset.map(tokenize, batched=False) + dataset.set_format("torch") + return dataset + + +lora_config = LoraConfig( + r=16, + lora_alpha=32, + lora_dropout=0.05, + bias="none", + task_type="CAUSAL_LM", +) +nf4_config = BitsAndBytesConfig( + load_in_4bit=True, bnb_4bit_quant_type="nf4", bnb_4bit_use_double_quant=True, bnb_4bit_compute_dtype=torch.bfloat16 +) +model = AutoModelForCausalLMWithValueHead.from_pretrained( + script_args.model_name, + device_map={"": "xpu:0"} if is_xpu_available() else {"": 0}, + peft_config=lora_config, + quantization_config=nf4_config, + reward_adapter=script_args.rm_adapter, + use_safetensors=script_args.use_safetensors, +) +tokenizer = AutoTokenizer.from_pretrained(script_args.model_name) + +tokenizer.pad_token = tokenizer.eos_token + +dataset = create_and_prepare_dataset(tokenizer) + + +def collator(data): + return dict((key, [d[key] for d in data]) for key in data[0]) + + +config = PPOConfig( + model_name=script_args.model_name, + log_with=script_args.log_with, + learning_rate=1e-5, + batch_size=8, + mini_batch_size=2, + gradient_accumulation_steps=2, + optimize_cuda_cache=True, + seed=script_args.seed, + use_score_scaling=script_args.use_score_scaling, + use_score_norm=script_args.use_score_norm, + score_clip=script_args.score_clip, +) + +ppo_trainer = PPOTrainer( + config, + model, + ref_model=None, + tokenizer=tokenizer, + dataset=dataset, + data_collator=collator, +) + +generation_kwargs = { + "top_k": 0.0, + "top_p": 0.9, + "do_sample": True, + "pad_token_id": tokenizer.pad_token_id, + "max_new_tokens": 32, +} + +for epoch, batch in tqdm(enumerate(ppo_trainer.dataloader)): + question_tensors = batch["input_ids"] + + response_tensors = ppo_trainer.generate( + question_tensors, + return_prompt=False, + **generation_kwargs, + ) + batch["response"] = tokenizer.batch_decode(response_tensors, skip_special_tokens=True) + + # Compute reward score + texts = [q + r for q, r in zip(batch["query"], batch["response"])] + inputs = tokenizer(texts, padding=True, truncation=True, return_tensors="pt").to(ppo_trainer.accelerator.device) + raw_rewards = ppo_trainer.accelerator.unwrap_model(ppo_trainer.model).compute_reward_score(**inputs) + rewards = [raw_rewards[i, -1, 1] for i in range(len(raw_rewards))] # take last token + + # Run PPO step + stats = ppo_trainer.step(question_tensors, response_tensors, rewards) + ppo_trainer.log_stats(stats, batch, rewards) diff --git a/examples/scripts/reward_modeling.py b/examples/scripts/reward_modeling.py new file mode 100644 index 0000000000000000000000000000000000000000..1271bce4a026056215ce80ab5c703428fe5e69ff --- /dev/null +++ b/examples/scripts/reward_modeling.py @@ -0,0 +1,173 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from dataclasses import dataclass, field +from typing import Optional + +import tyro +from accelerate import Accelerator +from datasets import load_dataset +from peft import LoraConfig +from tqdm import tqdm +from transformers import AutoModelForSequenceClassification, AutoTokenizer, BitsAndBytesConfig + +from trl import RewardConfig, RewardTrainer, is_xpu_available + + +tqdm.pandas() + + +@dataclass +class ScriptArguments: + model_name: str = "facebook/opt-350m" + """the model name""" + dataset_name: str = "Anthropic/hh-rlhf" + """the dataset name""" + dataset_text_field: str = "text" + """the text field of the dataset""" + eval_split: str = "none" + """the dataset split to evaluate on; default to 'none' (no evaluation)""" + load_in_8bit: bool = False + """load the model in 8 bits precision""" + load_in_4bit: bool = False + """load the model in 4 bits precision""" + trust_remote_code: bool = True + """Enable `trust_remote_code`""" + reward_config: RewardConfig = field( + default_factory=lambda: RewardConfig( + output_dir="output", + per_device_train_batch_size=64, + num_train_epochs=1, + gradient_accumulation_steps=16, + gradient_checkpointing=True, + gradient_checkpointing_kwargs={"use_reentrant": False}, + learning_rate=1.41e-5, + report_to="tensorboard", + remove_unused_columns=False, + optim="adamw_torch", + logging_steps=500, + evaluation_strategy="no", + max_length=512, + ) + ) + use_peft: bool = False + """whether to use peft""" + peft_config: Optional[LoraConfig] = field( + default_factory=lambda: LoraConfig( + r=16, + lora_alpha=16, + bias="none", + task_type="SEQ_CLS", + modules_to_save=["scores"], + ), + ) + + +args = tyro.cli(ScriptArguments) +args.reward_config.evaluation_strategy = "steps" if args.eval_split != "none" else "no" + + +# Step 1: Load the model +if args.load_in_8bit and args.load_in_4bit: + raise ValueError("You can't load the model in 8 bits and 4 bits at the same time") +elif args.load_in_8bit or args.load_in_4bit: + quantization_config = BitsAndBytesConfig(load_in_8bit=args.load_in_8bit, load_in_4bit=args.load_in_4bit) + # Copy the model to each device + device_map = ( + {"": f"xpu:{Accelerator().local_process_index}"} + if is_xpu_available() + else {"": Accelerator().local_process_index} + ) +else: + device_map = None + quantization_config = None + +model = AutoModelForSequenceClassification.from_pretrained( + args.model_name, + quantization_config=quantization_config, + device_map=device_map, + trust_remote_code=args.trust_remote_code, + num_labels=1, +) + +# Step 2: Load the dataset and pre-process it +tokenizer = AutoTokenizer.from_pretrained(args.model_name) +train_dataset = load_dataset(args.dataset_name, split="train") + + +# Tokenize chosen/rejected pairs of inputs +# Adapt this section to your needs for custom datasets +def preprocess_function(examples): + new_examples = { + "input_ids_chosen": [], + "attention_mask_chosen": [], + "input_ids_rejected": [], + "attention_mask_rejected": [], + } + for chosen, rejected in zip(examples["chosen"], examples["rejected"]): + tokenized_chosen = tokenizer(chosen) + tokenized_rejected = tokenizer(rejected) + + new_examples["input_ids_chosen"].append(tokenized_chosen["input_ids"]) + new_examples["attention_mask_chosen"].append(tokenized_chosen["attention_mask"]) + new_examples["input_ids_rejected"].append(tokenized_rejected["input_ids"]) + new_examples["attention_mask_rejected"].append(tokenized_rejected["attention_mask"]) + + return new_examples + + +# Preprocess the dataset and filter out examples that are longer than args.max_length +train_dataset = train_dataset.map( + preprocess_function, + batched=True, + num_proc=4, +) +train_dataset = train_dataset.filter( + lambda x: len(x["input_ids_chosen"]) <= args.reward_config.max_length + and len(x["input_ids_rejected"]) <= args.reward_config.max_length +) + +if args.eval_split == "none": + eval_dataset = None +else: + eval_dataset = load_dataset(args.dataset_name, split=args.eval_split) + + eval_dataset = eval_dataset.map( + preprocess_function, + batched=True, + num_proc=4, + ) + eval_dataset = eval_dataset.filter( + lambda x: len(x["input_ids_chosen"]) <= args.reward_config.max_length + and len(x["input_ids_rejected"]) <= args.reward_config.max_length + ) + + +# Step 4: Define the LoraConfig +if args.use_peft: + peft_config = args.peft_config +else: + peft_config = None + +# Step 5: Define the Trainer +trainer = RewardTrainer( + model=model, + tokenizer=tokenizer, + args=args.reward_config, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + peft_config=peft_config, +) + +trainer.train() diff --git a/examples/scripts/sft.py b/examples/scripts/sft.py new file mode 100644 index 0000000000000000000000000000000000000000..1b5d620a3f3681a82ef273928dedd8b93455a111 --- /dev/null +++ b/examples/scripts/sft.py @@ -0,0 +1,158 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from dataclasses import dataclass, field +from typing import List, Optional + +import torch +from accelerate import Accelerator +from datasets import load_dataset +from peft import LoraConfig +from tqdm import tqdm +from transformers import AutoModelForCausalLM, BitsAndBytesConfig, HfArgumentParser, TrainingArguments + +from trl import SFTTrainer, is_xpu_available + + +tqdm.pandas() + + +# Define and parse arguments. +@dataclass +class ScriptArguments: + """ + The name of the Casual LM model we wish to fine with SFTTrainer + """ + + model_name: Optional[str] = field(default="facebook/opt-350m", metadata={"help": "the model name"}) + dataset_name: Optional[str] = field( + default="timdettmers/openassistant-guanaco", metadata={"help": "the dataset name"} + ) + dataset_text_field: Optional[str] = field(default="text", metadata={"help": "the text field of the dataset"}) + log_with: Optional[str] = field(default="none", metadata={"help": "use 'wandb' to log with wandb"}) + learning_rate: Optional[float] = field(default=1.41e-5, metadata={"help": "the learning rate"}) + batch_size: Optional[int] = field(default=64, metadata={"help": "the batch size"}) + seq_length: Optional[int] = field(default=512, metadata={"help": "Input sequence length"}) + gradient_accumulation_steps: Optional[int] = field( + default=16, metadata={"help": "the number of gradient accumulation steps"} + ) + load_in_8bit: Optional[bool] = field(default=False, metadata={"help": "load the model in 8 bits precision"}) + load_in_4bit: Optional[bool] = field(default=False, metadata={"help": "load the model in 4 bits precision"}) + use_peft: Optional[bool] = field(default=False, metadata={"help": "Wether to use PEFT or not to train adapters"}) + trust_remote_code: Optional[bool] = field(default=False, metadata={"help": "Enable `trust_remote_code`"}) + output_dir: Optional[str] = field(default="output", metadata={"help": "the output directory"}) + peft_lora_r: Optional[int] = field(default=64, metadata={"help": "the r parameter of the LoRA adapters"}) + peft_lora_alpha: Optional[int] = field(default=16, metadata={"help": "the alpha parameter of the LoRA adapters"}) + logging_steps: Optional[int] = field(default=1, metadata={"help": "the number of logging steps"}) + use_auth_token: Optional[bool] = field(default=True, metadata={"help": "Use HF auth token to access the model"}) + num_train_epochs: Optional[int] = field(default=3, metadata={"help": "the number of training epochs"}) + max_steps: Optional[int] = field(default=-1, metadata={"help": "the number of training steps"}) + save_steps: Optional[int] = field( + default=100, metadata={"help": "Number of updates steps before two checkpoint saves"} + ) + save_total_limit: Optional[int] = field(default=10, metadata={"help": "Limits total number of checkpoints."}) + push_to_hub: Optional[bool] = field(default=False, metadata={"help": "Push the model to HF Hub"}) + gradient_checkpointing: Optional[bool] = field( + default=False, metadata={"help": "Whether to use gradient checkpointing or no"} + ) + gradient_checkpointing_kwargs: Optional[dict] = field( + default=None, + metadata={ + "help": "key word arguments to be passed along `torch.utils.checkpoint.checkpoint` method - e.g. `use_reentrant=False`" + }, + ) + hub_model_id: Optional[str] = field(default=None, metadata={"help": "The name of the model on HF Hub"}) + mixed_precision: Optional[str] = field(default="bf16", metadata={"help": "Mixed precision training"}) + target_modules: Optional[List[str]] = field(default=None, metadata={"help": "Target modules for LoRA adapters"}) + + +parser = HfArgumentParser(ScriptArguments) +script_args = parser.parse_args_into_dataclasses()[0] + +# Step 1: Load the model +if script_args.load_in_8bit and script_args.load_in_4bit: + raise ValueError("You can't load the model in 8 bits and 4 bits at the same time") +elif script_args.load_in_8bit or script_args.load_in_4bit: + quantization_config = BitsAndBytesConfig( + load_in_8bit=script_args.load_in_8bit, load_in_4bit=script_args.load_in_4bit + ) + # Copy the model to each device + device_map = ( + {"": f"xpu:{Accelerator().local_process_index}"} + if is_xpu_available() + else {"": Accelerator().local_process_index} + ) + torch_dtype = torch.bfloat16 +else: + device_map = None + quantization_config = None + torch_dtype = None + +model = AutoModelForCausalLM.from_pretrained( + script_args.model_name, + quantization_config=quantization_config, + device_map=device_map, + trust_remote_code=script_args.trust_remote_code, + torch_dtype=torch_dtype, + use_auth_token=script_args.use_auth_token, +) + +# Step 2: Load the dataset +dataset = load_dataset(script_args.dataset_name, split="train") + +# Step 3: Define the training arguments +training_args = TrainingArguments( + output_dir=script_args.output_dir, + per_device_train_batch_size=script_args.batch_size, + gradient_accumulation_steps=script_args.gradient_accumulation_steps, + learning_rate=script_args.learning_rate, + logging_steps=script_args.logging_steps, + num_train_epochs=script_args.num_train_epochs, + max_steps=script_args.max_steps, + report_to=script_args.log_with, + save_steps=script_args.save_steps, + save_total_limit=script_args.save_total_limit, + push_to_hub=script_args.push_to_hub, + hub_model_id=script_args.hub_model_id, + gradient_checkpointing=script_args.gradient_checkpointing, + # TODO: uncomment that on the next release + # gradient_checkpointing_kwargs=script_args.gradient_checkpointing_kwargs, +) + +# Step 4: Define the LoraConfig +if script_args.use_peft: + peft_config = LoraConfig( + r=script_args.peft_lora_r, + lora_alpha=script_args.peft_lora_alpha, + bias="none", + task_type="CAUSAL_LM", + target_modules=script_args.target_modules, + ) +else: + peft_config = None + +# Step 5: Define the Trainer +trainer = SFTTrainer( + model=model, + args=training_args, + max_seq_length=script_args.seq_length, + train_dataset=dataset, + dataset_text_field=script_args.dataset_text_field, + peft_config=peft_config, +) + +trainer.train() + +# Step 6: Save the model +trainer.save_model(script_args.output_dir) diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000000000000000000000000000000000000..7e6b3f84fae69610b44b44cba277e2201bcfb555 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,16 @@ +[tool.black] +line-length = 119 +target-version = ['py38'] + +[tool.ruff] +ignore = ["E501", "E741", "W605"] +select = ["E", "F", "I", "W"] +line-length = 119 + +# Ignore import violations in all `__init__.py` files. +[tool.ruff.per-file-ignores] +"__init__.py" = ["E402", "F401", "F403", "F811"] + +[tool.ruff.isort] +lines-after-imports = 2 +known-first-party = ["trl"] diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..12d8a5845c755445298b7f98c056bd02a2d61672 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,7 @@ +datasets>=1.17.0 +torch>=1.4.0 +tqdm +transformers +accelerate +peft>=0.3.0 +tyro>=0.5.7 \ No newline at end of file diff --git a/scripts/stale.py b/scripts/stale.py new file mode 100644 index 0000000000000000000000000000000000000000..de7b869c13280cea71507cbe1e635e25c3f36c5b --- /dev/null +++ b/scripts/stale.py @@ -0,0 +1,61 @@ +# Copyright 2023 The HuggingFace Team, the AllenNLP library authors. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Script to close stale issue. Taken in part from the AllenNLP repository. +https://github.com/allenai/allennlp. +""" +import os +from datetime import datetime as dt +from datetime import timezone + +from github import Github + + +LABELS_TO_EXEMPT = [ + "good first issue", + "good second issue", + "feature request", +] + + +def main(): + g = Github(os.environ["GITHUB_TOKEN"]) + repo = g.get_repo("huggingface/trl") + open_issues = repo.get_issues(state="open") + + for issue in open_issues: + comments = sorted([comment for comment in issue.get_comments()], key=lambda i: i.created_at, reverse=True) + last_comment = comments[0] if len(comments) > 0 else None + if ( + last_comment is not None + and last_comment.user.login == "github-actions[bot]" + and (dt.now(timezone.utc) - issue.updated_at).days > 7 + and (dt.now(timezone.utc) - issue.created_at).days >= 30 + and not any(label.name.lower() in LABELS_TO_EXEMPT for label in issue.get_labels()) + ): + issue.edit(state="closed") + elif ( + (dt.now(timezone.utc) - issue.updated_at).days > 23 + and (dt.now(timezone.utc) - issue.created_at).days >= 30 + and not any(label.name.lower() in LABELS_TO_EXEMPT for label in issue.get_labels()) + ): + issue.create_comment( + "This issue has been automatically marked as stale because it has not had " + "recent activity. If you think this still needs to be addressed " + "please comment on this thread.\n\n" + ) + + +if __name__ == "__main__": + main() diff --git a/setup.cfg b/setup.cfg new file mode 100644 index 0000000000000000000000000000000000000000..cb69438f5607211e0e1002d3fc9f9f7479b2b998 --- /dev/null +++ b/setup.cfg @@ -0,0 +1,11 @@ +[metadata] +license_file = LICENSE + +[isort] +ensure_newline_before_comments = True +force_grid_wrap = 0 +include_trailing_comma = True +line_length = 119 +lines_after_imports = 2 +multi_line_output = 3 +use_parentheses = True diff --git a/setup.py b/setup.py new file mode 100644 index 0000000000000000000000000000000000000000..1b2e1cb14c3619017a0b00fa15b609cf44533fc6 --- /dev/null +++ b/setup.py @@ -0,0 +1,112 @@ +""" trl is an open library for RL with transformer models. + +Note: + + VERSION needs to be formatted following the MAJOR.MINOR.PATCH convention + (we need to follow this convention to be able to retrieve versioned scripts) + +Simple check list for release from AllenNLP repo: https://github.com/allenai/allennlp/blob/master/setup.py + +To create the package for pypi. + +0. Prerequisites: + - Dependencies: + - twine: "pip install twine" + - Create an account in (and join the 'trl' project): + - PyPI: https://pypi.org/ + - Test PyPI: https://test.pypi.org/ + +1. Change the version in: + - __init__.py + - setup.py + +2. Commit these changes: "git commit -m 'Release: VERSION'" + +3. Add a tag in git to mark the release: "git tag VERSION -m 'Add tag VERSION for pypi'" + Push the tag to remote: git push --tags origin main + +4. Build both the sources and the wheel. Do not change anything in setup.py between + creating the wheel and the source distribution (obviously). + + First, delete any "build" directory that may exist from previous builds. + + For the wheel, run: "python setup.py bdist_wheel" in the top level directory. + (this will build a wheel for the python version you use to build it). + + For the sources, run: "python setup.py sdist" + You should now have a /dist directory with both .whl and .tar.gz source versions. + +5. Check that everything looks correct by uploading the package to the pypi test server: + + twine upload dist/* -r pypitest --repository-url=https://test.pypi.org/legacy/ + + Check that you can install it in a virtualenv/notebook by running: + pip install huggingface_hub fsspec aiohttp + pip install -U tqdm + pip install -i https://testpypi.python.org/pypi evaluate + +6. Upload the final version to actual pypi: + twine upload dist/* -r pypi + +7. Fill release notes in the tag in github once everything is looking hunky-dory. + +8. Change the version in __init__.py and setup.py to X.X.X+1.dev0 (e.g. VERSION=1.18.3 -> 1.18.4.dev0). + Then push the change with a message 'set dev version' +""" + +from setuptools import find_packages, setup + + +__version__ = "0.7.5.dev0" # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots) + +REQUIRED_PKGS = [ + "torch>=1.4.0", + "transformers>=4.18.0", + "numpy>=1.18.2", + "accelerate", + "datasets", + "tyro>=0.5.11", +] +EXTRAS = { + "test": ["parameterized", "pytest", "pytest-xdist", "accelerate"], + "peft": ["peft>=0.4.0"], + "diffusers": ["diffusers>=0.18.0"], + "deepspeed": ["deepspeed>=0.9.5"], + "benchmark": ["wandb", "ghapi", "openrlbenchmark==0.2.1a5", "requests", "deepspeed"], + "quantization": ["bitsandbytes<=0.41.1"], +} +EXTRAS["dev"] = [] +for reqs in EXTRAS.values(): + EXTRAS["dev"].extend(reqs) + +setup( + name="trl", + license="Apache 2.0", + classifiers=[ + "Development Status :: 2 - Pre-Alpha", + "Intended Audience :: Developers", + "Intended Audience :: Science/Research", + "License :: OSI Approved :: Apache Software License", + "Natural Language :: English", + "Operating System :: OS Independent", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.7", + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + ], + url="https://github.com/huggingface/trl", + packages=find_packages(), + include_package_data=True, + install_requires=REQUIRED_PKGS, + extras_require=EXTRAS, + python_requires=">=3.7", + long_description=open("README.md", encoding="utf-8").read(), + long_description_content_type="text/markdown", + zip_safe=False, + version=__version__, + description="A Pytorch implementation of Proximal Policy Optimization for transfomer language models.", + keywords="ppo, transformers, huggingface, gpt2, language modeling, rlhf", + author="Leandro von Werra", + author_email="leandro.vonwerra@gmail.com", +) diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/tests/test_best_of_n_sampler.py b/tests/test_best_of_n_sampler.py new file mode 100644 index 0000000000000000000000000000000000000000..3b5001a898aefd7a3b389ea27ba9bc9a45e4d770 --- /dev/null +++ b/tests/test_best_of_n_sampler.py @@ -0,0 +1,98 @@ +import unittest + +import torch +from transformers import AutoTokenizer, GenerationConfig + +from trl import AutoModelForCausalLMWithValueHead +from trl.core import LengthSampler +from trl.extras import BestOfNSampler + + +def queries_to_scores(list_of_strings): + return [torch.rand(1).item() for _ in list_of_strings] + + +class BestOfNSamplerTester(unittest.TestCase): + """ + Tests the BestOfNSampler class + """ + + ref_model_name = "trl-internal-testing/dummy-GPT2-correct-vocab" + output_length_sampler = LengthSampler(2, 6) + model = AutoModelForCausalLMWithValueHead.from_pretrained(ref_model_name) + tokenizer = AutoTokenizer.from_pretrained(ref_model_name) + tokenizer.pad_token = tokenizer.eos_token + output_length_sampler = LengthSampler(2, 6) + + def test_different_input_types(self): + r""" + Tests if the different input types normalizer works + """ + + generation_config = GenerationConfig( + min_length=-1, + top_k=0.0, + top_p=1.0, + do_sample=True, + pad_token_id=self.tokenizer.eos_token_id, + ) + + output_length_sampler = LengthSampler(2, 6) + + best_of_n = BestOfNSampler( + self.model, + self.tokenizer, + queries_to_scores, + length_sampler=output_length_sampler, + generation_config=generation_config, + ) + + queries = ["hello world", "goodbye world"] + tokenized_queries = [self.tokenizer.encode(query) for query in queries] + + various_queries_formats = [ + (tokenized_queries[0], 1), + (tokenized_queries, 2), + (torch.tensor(tokenized_queries[1]), 1), + ([torch.tensor(query) for query in tokenized_queries], 2), + ] + + for q, expected_length in various_queries_formats: + results = best_of_n.generate(q) + self.assertIsInstance(results, list) + assert len(results) == expected_length + + def test_different_sample_sizes_and_n_candidates_values(self): + r""" + Tests different sample sizes and n_candidates values + """ + generation_config = GenerationConfig( + min_length=-1, + top_k=0.0, + top_p=1.0, + do_sample=True, + pad_token_id=self.tokenizer.eos_token_id, + ) + + output_length_sampler = LengthSampler(6, 10) + + for sample_value, n_candidates_values, expected in [ + (4, 2, 2), + (10, 3, 3), + (6, 4, 4), + ]: + best_of_n = BestOfNSampler( + self.model, + self.tokenizer, + queries_to_scores, + length_sampler=output_length_sampler, + generation_config=generation_config, + sample_size=sample_value, + n_candidates=n_candidates_values, + ) + + queries = ["hello world", "troll the world"] + tokenized_queries = [self.tokenizer.encode(query) for query in queries] + results = best_of_n.generate(tokenized_queries) + for result in results: + assert len(result) == expected diff --git a/tests/test_core.py b/tests/test_core.py new file mode 100644 index 0000000000000000000000000000000000000000..151852e267b0c2ef6557901a4269c8ab774aafba --- /dev/null +++ b/tests/test_core.py @@ -0,0 +1,42 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import unittest + +import torch + +from trl.core import masked_mean, masked_var, masked_whiten, whiten + + +class CoreTester(unittest.TestCase): + """ + A wrapper class for testing core utils functions + """ + + @classmethod + def setUpClass(cls): + cls.test_input = torch.Tensor([1, 2, 3, 4]) + cls.test_mask = torch.Tensor([0, 1, 1, 0]) + cls.test_input_unmasked = cls.test_input[1:3] + + def test_masked_mean(self): + self.assertEqual(torch.mean(self.test_input_unmasked), masked_mean(self.test_input, self.test_mask)) + + def test_masked_var(self): + self.assertEqual(torch.var(self.test_input_unmasked), masked_var(self.test_input, self.test_mask)) + + def test_masked_whiten(self): + whiten_unmasked = whiten(self.test_input_unmasked) + whiten_masked = masked_whiten(self.test_input, self.test_mask)[1:3] + diffs = (whiten_unmasked - whiten_masked).sum() + self.assertAlmostEqual(diffs, 0) diff --git a/tests/test_data_collator_completion_only.py b/tests/test_data_collator_completion_only.py new file mode 100644 index 0000000000000000000000000000000000000000..c895a616e136c211493e6e042221691b0e248261 --- /dev/null +++ b/tests/test_data_collator_completion_only.py @@ -0,0 +1,81 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import unittest + +import torch +from transformers import AutoTokenizer + +from trl import DataCollatorForCompletionOnlyLM + + +class DataCollatorForCompletionOnlyLMTester(unittest.TestCase): + def test_data_collator_finds_response_template_llama2_tokenizer(self): + # this should ideally be tested with meta-llama/Llama-2-7b-hf + self.tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/dummy-GPT2-correct-vocab") + self.instruction = """### System: You are a helpful assistant. + +### User: How much is 2+2? + +### Assistant: 2+2 equals 4""" + self.instruction_template = "\n### User:" + self.response_template = "\n### Assistant:" + + # GPT2Tokenizer: [198, 21017, 11787, 25] -> [11787, 25] + # Llama2Tokenizer: [29871, 13, 2277, 29937, 4911, 29901] -> [2277, 29937, 4911, 29901] + self.tokenized_instruction_w_context = self.tokenizer.encode( + self.instruction_template, add_special_tokens=False + )[2:] + + # GPT2Tokenizer: [198, 21017, 15286, 25] -> [15286, 25] + # Llama2Tokenizer: [29871, 13, 2277, 29937, 4007, 22137, 29901] -> [2277, 29937, 4007, 22137, 29901] + self.tokenized_response_w_context = self.tokenizer.encode(self.response_template, add_special_tokens=False)[2:] + + # Plain check on string + self.assertIn(self.response_template, self.instruction) + self.tokenized_instruction = self.tokenizer.encode(self.instruction, add_special_tokens=False) + + # Test the fix for #598 + # Pass already tokenized (w context) and truncated response_template so token_ids are like in the instruction + response + self.collator = DataCollatorForCompletionOnlyLM(self.tokenized_response_w_context, tokenizer=self.tokenizer) + self.collator.torch_call([self.tokenized_instruction]) + + # Test for PR #749 + # Pass already tokenized (w context) instruction and response both so token_ids are like in the instruction + response + self.collator = DataCollatorForCompletionOnlyLM( + self.tokenized_response_w_context, self.tokenized_instruction_w_context, tokenizer=self.tokenizer + ) + self.collator.torch_call([self.tokenized_instruction]) + + def test_data_collator_handling_of_long_sequences(self): + self.tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/dummy-GPT2-correct-vocab") + self.instruction = """### System: You are a helpful assistant. + +### User: How much is 2+2? I'm asking because I'm not sure. And I'm not sure because I'm not good at math. +""" + self.response_template = "\n### Assistant:" + # check DataCollatorForCompletionOnlyLM using response template only + self.tokenized_instruction = self.tokenizer.encode(self.instruction, add_special_tokens=False) + self.collator = DataCollatorForCompletionOnlyLM(self.response_template, tokenizer=self.tokenizer) + encoded_instance = self.collator.torch_call([self.tokenized_instruction]) + result = torch.all(encoded_instance["labels"] == -100) + self.assertTrue(result, "Not all values in the tensor are -100.") + + # check DataCollatorForCompletionOnlyLM using response template and instruction template + self.instruction_template = "\n### User:" + self.collator = DataCollatorForCompletionOnlyLM( + self.response_template, self.instruction_template, tokenizer=self.tokenizer + ) + encoded_instance = self.collator.torch_call([self.tokenized_instruction]) + result = torch.all(encoded_instance["labels"] == -100) + self.assertTrue(result, "Not all values in the tensor are -100.") diff --git a/tests/test_ddpo_trainer.py b/tests/test_ddpo_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..e56ab16b571c5aee7e99b985063fa227a380ebd1 --- /dev/null +++ b/tests/test_ddpo_trainer.py @@ -0,0 +1,99 @@ +# Copyright 2023 metric-space, The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import gc +import unittest + +import torch + +from trl import is_diffusers_available + +from .testing_utils import require_diffusers + + +if is_diffusers_available(): + from trl import DDPOConfig, DDPOTrainer, DefaultDDPOStableDiffusionPipeline + + +def scorer_function(images, prompts, metadata): + return torch.randn(1) * 3.0, {} + + +def prompt_function(): + return ("cabbages", {}) + + +@require_diffusers +class DDPOTrainerTester(unittest.TestCase): + """ + Test the DDPOTrainer class. + """ + + def setUp(self): + self.ddpo_config = DDPOConfig( + num_epochs=2, + train_gradient_accumulation_steps=1, + per_prompt_stat_tracking_buffer_size=32, + sample_num_batches_per_epoch=2, + sample_batch_size=2, + mixed_precision=None, + save_freq=1000000, + ) + pretrained_model = "hf-internal-testing/tiny-stable-diffusion-torch" + pretrained_revision = "main" + + pipeline = DefaultDDPOStableDiffusionPipeline( + pretrained_model, pretrained_model_revision=pretrained_revision, use_lora=False + ) + + self.trainer = DDPOTrainer(self.ddpo_config, scorer_function, prompt_function, pipeline) + + return super().setUp() + + def tearDown(self) -> None: + gc.collect() + + def test_loss(self): + advantage = torch.tensor([-1.0]) + clip_range = 0.0001 + ratio = torch.tensor([1.0]) + loss = self.trainer.loss(advantage, clip_range, ratio) + self.assertEqual(loss.item(), 1.0) + + def test_generate_samples(self): + samples, output_pairs = self.trainer._generate_samples(1, 2) + self.assertEqual(len(samples), 1) + self.assertEqual(len(output_pairs), 1) + self.assertEqual(len(output_pairs[0][0]), 2) + + def test_calculate_loss(self): + samples, _ = self.trainer._generate_samples(1, 2) + sample = samples[0] + + latents = sample["latents"][0, 0].unsqueeze(0) + next_latents = sample["next_latents"][0, 0].unsqueeze(0) + log_probs = sample["log_probs"][0, 0].unsqueeze(0) + timesteps = sample["timesteps"][0, 0].unsqueeze(0) + prompt_embeds = sample["prompt_embeds"] + advantage = torch.tensor([1.0], device=prompt_embeds.device) + + self.assertEqual(latents.shape, (1, 4, 64, 64)) + self.assertEqual(next_latents.shape, (1, 4, 64, 64)) + self.assertEqual(log_probs.shape, (1,)) + self.assertEqual(timesteps.shape, (1,)) + self.assertEqual(prompt_embeds.shape, (2, 77, 32)) + loss, approx_kl, clipfrac = self.trainer.calculate_loss( + latents, timesteps, next_latents, log_probs, advantage, prompt_embeds + ) + + self.assertTrue(torch.isfinite(loss.cpu())) diff --git a/tests/test_dpo_trainer.py b/tests/test_dpo_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..ef628497a31f90dd910a2aabc61369ce38b9f500 --- /dev/null +++ b/tests/test_dpo_trainer.py @@ -0,0 +1,298 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import tempfile +import unittest + +import torch +from datasets import Dataset +from parameterized import parameterized +from pytest import mark +from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM, AutoTokenizer, TrainingArguments + +from trl import DPOTrainer + +from .testing_utils import require_no_wandb, require_peft + + +class DPOTrainerTester(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.model_id = "trl-internal-testing/dummy-GPT2-correct-vocab" + cls.model = AutoModelForCausalLM.from_pretrained(cls.model_id) + cls.ref_model = AutoModelForCausalLM.from_pretrained(cls.model_id) + cls.tokenizer = AutoTokenizer.from_pretrained(cls.model_id) + cls.tokenizer.pad_token = cls.tokenizer.eos_token + + # get t5 as seq2seq example: + model_id = "trl-internal-testing/tiny-T5ForConditionalGeneration-correct-vocab" + cls.t5_model = AutoModelForSeq2SeqLM.from_pretrained(model_id) + cls.t5_ref_model = AutoModelForSeq2SeqLM.from_pretrained(model_id) + cls.t5_tokenizer = AutoTokenizer.from_pretrained(model_id) + + def _init_dummy_dataset(self): + # fmt: off + dummy_dataset_dict = { + "prompt": [ + "hello", + "how are you", + "What is your name?", + "What is your name?", + "Which is the best programming language?", + "Which is the best programming language?", + "Which is the best programming language?", + ], + "chosen": [ + "hi nice to meet you", + "I am fine", + "My name is Mary", + "My name is Mary", + "Python", + "Python", + "Python", + ], + "rejected": [ + "leave me alone", + "I am not fine", + "Whats it to you?", + "I dont have a name", + "Javascript", + "C++", + "Java", + ], + } + # fmt: on + return Dataset.from_dict(dummy_dataset_dict) + + @parameterized.expand( + [["gpt2", "sigmoid"], ["t5", "hinge"], ["gpt2", "ipo"], ["t5", "ipo"], ["gpt2", "kto"], ["t5", "kto"]] + ) + def test_dpo_trainer(self, name, loss_type): + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = TrainingArguments( + output_dir=tmp_dir, + per_device_train_batch_size=2, + max_steps=3, + remove_unused_columns=False, + gradient_accumulation_steps=1, + learning_rate=9e-1, + evaluation_strategy="steps", + ) + + dummy_dataset = self._init_dummy_dataset() + + if name == "gpt2": + model = self.model + ref_model = self.ref_model + tokenizer = self.tokenizer + elif name == "t5": + model = self.t5_model + ref_model = self.t5_ref_model + tokenizer = self.t5_tokenizer + + trainer = DPOTrainer( + model=model, + ref_model=ref_model, + beta=0.1, + loss_type=loss_type, + args=training_args, + tokenizer=tokenizer, + train_dataset=dummy_dataset, + eval_dataset=dummy_dataset, + ) + + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + trainer.train() + + self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + + # check the params have changed + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + # check the params have changed - ignore 0 biases + if param.sum() != 0: + self.assertFalse(torch.equal(param, new_param)) + + def test_dpo_trainer_without_providing_ref_model(self): + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = TrainingArguments( + output_dir=tmp_dir, + per_device_train_batch_size=2, + max_steps=3, + remove_unused_columns=False, + gradient_accumulation_steps=4, + learning_rate=9e-1, + evaluation_strategy="steps", + ) + + dummy_dataset = self._init_dummy_dataset() + + trainer = DPOTrainer( + model=self.model, + ref_model=None, + beta=0.1, + args=training_args, + tokenizer=self.tokenizer, + train_dataset=dummy_dataset, + eval_dataset=dummy_dataset, + ) + + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + trainer.train() + + self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + + # check the params have changed + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + # check the params have changed - ignore 0 biases + if param.sum() != 0: + self.assertFalse(torch.equal(param, new_param)) + + @require_peft + @mark.peft_test + def test_dpo_trainer_without_providing_ref_model_with_lora(self): + from peft import LoraConfig + + lora_config = LoraConfig( + r=16, + lora_alpha=32, + lora_dropout=0.05, + bias="none", + task_type="CAUSAL_LM", + ) + + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = TrainingArguments( + output_dir=tmp_dir, + per_device_train_batch_size=2, + max_steps=3, + remove_unused_columns=False, + gradient_accumulation_steps=4, + learning_rate=9e-1, + evaluation_strategy="steps", + ) + + dummy_dataset = self._init_dummy_dataset() + + trainer = DPOTrainer( + model=self.model, + ref_model=None, + beta=0.1, + args=training_args, + tokenizer=self.tokenizer, + train_dataset=dummy_dataset, + eval_dataset=dummy_dataset, + peft_config=lora_config, + ) + + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + trainer.train() + + self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + + # check the params have changed + for n, param in previous_trainable_params.items(): + if "lora" in n: + new_param = trainer.model.get_parameter(n) + # check the params have changed - ignore 0 biases + if param.sum() != 0: + self.assertFalse(torch.equal(param, new_param)) + + @require_no_wandb + def test_dpo_trainer_generate_during_eval_no_wandb(self): + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = TrainingArguments( + output_dir=tmp_dir, + per_device_train_batch_size=2, + max_steps=3, + remove_unused_columns=False, + gradient_accumulation_steps=1, + learning_rate=9e-1, + evaluation_strategy="steps", + ) + + dummy_dataset = self._init_dummy_dataset() + + with self.assertRaisesRegex( + ValueError, + expected_regex="`generate_during_eval=True` requires Weights and Biases to be installed." + " Please install `wandb` to resolve.", + ): + DPOTrainer( + model=self.model, + ref_model=None, + beta=0.1, + args=training_args, + tokenizer=self.tokenizer, + train_dataset=dummy_dataset, + eval_dataset=dummy_dataset, + generate_during_eval=True, + ) + + @require_peft + @mark.peft_test + def test_dpo_lora_save(self): + from peft import LoraConfig, get_peft_model + + lora_config = LoraConfig( + r=16, + lora_alpha=32, + lora_dropout=0.05, + bias="none", + task_type="CAUSAL_LM", + ) + + # lora model + model = AutoModelForCausalLM.from_pretrained(self.model_id) + model_peft = get_peft_model(model, lora_config) + + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = TrainingArguments( + output_dir=tmp_dir, + per_device_train_batch_size=2, + max_steps=3, + remove_unused_columns=False, + gradient_accumulation_steps=4, + learning_rate=9e-1, + evaluation_strategy="steps", + ) + + dummy_dataset = self._init_dummy_dataset() + + # dpo train lora model with a lora config + trainer = DPOTrainer( + model=model_peft, + ref_model=None, + beta=0.1, + args=training_args, + tokenizer=self.tokenizer, + train_dataset=dummy_dataset, + eval_dataset=dummy_dataset, + peft_config=lora_config, + ) + + # train the model + trainer.train() + + # save peft adapter + trainer.save_model() + + # assert that the model is loaded without giving OSError + try: + AutoModelForCausalLM.from_pretrained(tmp_dir) + except OSError: + self.fail("Loading the saved peft adapter failed") diff --git a/tests/test_e2e.py b/tests/test_e2e.py new file mode 100644 index 0000000000000000000000000000000000000000..7e742329dee0da2fd7f80fee5568f3d808dff6fd --- /dev/null +++ b/tests/test_e2e.py @@ -0,0 +1,9 @@ +import subprocess + + +def test_hello_world(): + subprocess.run( + "python examples/hello_world.py", + shell=True, + check=True, + ) diff --git a/tests/test_environments.py b/tests/test_environments.py new file mode 100644 index 0000000000000000000000000000000000000000..e31daab5cebee9fe2e13fc85d4deb63503c8c01d --- /dev/null +++ b/tests/test_environments.py @@ -0,0 +1,273 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +from unittest.mock import patch + +import torch +from transformers import AutoTokenizer + +from trl import AutoModelForCausalLMWithValueHead, TextEnvironment, TextHistory + + +class DummyTool: + def __call__(self, text): + return text + + +def dummy_generate(histories): + for i in range(len(histories)): + histories[i].append_segment("test", torch.tensor([1, 2, 3]), system=False) + return histories + + +class TextHistoryTest(unittest.TestCase): + def test_text_history_init(self): + text = "Hello there!" + tokens = torch.tensor([1, 2, 3]) + + history = TextHistory(text, tokens) + self.assertEqual(history.text, text) + self.assertTrue(torch.equal(history.tokens, tokens)) + self.assertTrue(torch.equal(history.token_masks, torch.zeros_like(tokens))) + + history = TextHistory(text, tokens, system=False) + self.assertTrue(torch.equal(history.token_masks, torch.ones_like(tokens))) + + def test_text_history_append_segment(self): + text = "Hello there!" + tokens = torch.tensor([1, 2, 3]) + + history = TextHistory(text, tokens) + history.append_segment("General Kenobi!", torch.tensor([4, 5, 6]), system=False) + self.assertEqual(history.text, text + "General Kenobi!") + self.assertTrue(torch.equal(history.tokens, torch.tensor([1, 2, 3, 4, 5, 6]))) + self.assertTrue(torch.equal(history.token_masks, torch.tensor([0, 0, 0, 1, 1, 1]))) + + history.append_segment("You are a bold one!", torch.tensor([7, 8, 9])) + self.assertEqual(history.text, text + "General Kenobi!" + "You are a bold one!") + self.assertTrue(torch.equal(history.tokens, torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9]))) + self.assertTrue(torch.equal(history.token_masks, torch.tensor([0, 0, 0, 1, 1, 1, 0, 0, 0]))) + + def test_text_history_complete(self): + text = "Hello there!" + tokens = torch.tensor([1, 2, 3]) + history = TextHistory(text, tokens) + history.complete() + self.assertTrue(history.completed) + self.assertFalse(history.truncated) + + history.complete(truncated=True) + self.assertTrue(history.completed) + self.assertTrue(history.truncated) + + def test_text_history_last_segment(self): + text = "Hello there!" + tokens = torch.tensor([1, 2, 3]) + history = TextHistory(text, tokens) + history.append_segment("General Kenobi!", torch.tensor([4, 5, 6])) + history.append_segment("You are a bold one!", torch.tensor([7, 8, 9])) + self.assertEqual(history.last_text_segment, "You are a bold one!") + + def test_text_history_split_query_response(self): + text = "Hello there!" + tokens = torch.tensor([1, 2, 3]) + history = TextHistory(text, tokens) + history.append_segment("General Kenobi!", torch.tensor([4, 5, 6]), system=False) + history.append_segment("You are a bold one!", torch.tensor([7, 8, 9]), system=True) + query, response, mask = history.split_query_response_tokens() + + self.assertTrue(torch.equal(query, torch.tensor([1, 2, 3]))) + self.assertTrue(torch.equal(response, torch.tensor([4, 5, 6, 7, 8, 9]))) + self.assertTrue(torch.equal(mask, torch.tensor([1, 1, 1, 0, 0, 0]))) + + +class TextEnvironmentTester(unittest.TestCase): + @classmethod + def setUpClass(cls): + # model_id + cls.model_id = "trl-internal-testing/dummy-GPT2-correct-vocab" + + # get models and tokenizer + cls.gpt2_model = AutoModelForCausalLMWithValueHead.from_pretrained(cls.model_id) + cls.gpt2_tokenizer = AutoTokenizer.from_pretrained(cls.model_id) + cls.gpt2_tokenizer.pad_token = cls.gpt2_tokenizer.eos_token + + def test_text_environment_setup(self): + env = TextEnvironment( + self.gpt2_model, + self.gpt2_tokenizer, + tools=[DummyTool()], + reward_fn=lambda x: torch.tensor(1), + prompt="I am a prompt!\n", + ) + self.assertEqual(env.prompt, "I am a prompt!\n") + self.assertEqual(list(env.tools.keys()), ["DummyTool"]) + self.assertTrue(isinstance(env.tools["DummyTool"], DummyTool)) + self.assertEqual(env.reward_fn("Hello there!"), 1) + + def test_text_environment_generate(self): + generation_kwargs = {"do_sample": False, "max_new_tokens": 4, "pad_token_id": self.gpt2_tokenizer.eos_token_id} + env = TextEnvironment( + self.gpt2_model, + self.gpt2_tokenizer, + tools=[DummyTool()], + reward_fn=lambda x: torch.tensor(1), + prompt="I am a prompt!\n", + generation_kwargs=generation_kwargs, + ) + + input_texts = ["this is a test", "this is another, longer test"] + + model_inputs = [self.gpt2_tokenizer(txt, return_tensors="pt").input_ids.squeeze() for txt in input_texts] + + generations_batched = env._generate_batched(model_inputs, batch_size=2) + generations_batched = self.gpt2_tokenizer.batch_decode(generations_batched) + + generations_single = [env._generate_batched([inputs], batch_size=1)[0] for inputs in model_inputs] + generations_single = self.gpt2_tokenizer.batch_decode(generations_single) + + self.assertEqual(generations_single, generations_batched) + + def test_text_environment_tool_call_parsing(self): + string_valid = "Something something Hello there!" + string_invalid_request = "Something something Hello there!" + string_invalid_call = "Something something Hello there!" + string_invalid_tool = "Something something |Tool2|Hello there!" + string_invalid_random = "<>abcdefghijklm<>nopqrstuvwxyz<>" + + env = TextEnvironment( + self.gpt2_model, + self.gpt2_tokenizer, + tools=[DummyTool()], + reward_fn=lambda x: torch.tensor(1), + prompt="I am a prompt!\n", + ) + tool, response = env.parse_tool_call(string_valid) + self.assertEqual(tool, "Tool1") + self.assertEqual(response, "Hello there!") + + tool, response = env.parse_tool_call(string_invalid_request) + self.assertEqual(tool, None) + self.assertEqual(response, None) + + tool, response = env.parse_tool_call(string_invalid_call) + self.assertEqual(tool, None) + self.assertEqual(response, None) + + tool, response = env.parse_tool_call(string_invalid_tool) + self.assertEqual(tool, None) + self.assertEqual(response, None) + + tool, response = env.parse_tool_call(string_invalid_random) + self.assertEqual(tool, None) + self.assertEqual(response, None) + + def test_text_environment_tool_truncation(self): + env = TextEnvironment( + self.gpt2_model, + self.gpt2_tokenizer, + tools={"dummy": lambda x: "a" * 1000}, + reward_fn=lambda x: torch.tensor(1), + prompt="I am a prompt!\n", + ) + + env.max_tool_response = 100 + history = env.step(TextHistory("Hello there!", torch.tensor([1, 2, 3]))) + self.assertEqual(len(history.last_text_segment) - len(env.response_token), 100) + + env.max_tool_response = 500 + history = env.step(TextHistory("Hello there!", torch.tensor([1, 2, 3]))) + self.assertEqual(len(history.last_text_segment) - len(env.response_token), 500) + + env.max_tool_response = 1001 + history = env.step(TextHistory("Hello there!", torch.tensor([1, 2, 3]))) + self.assertEqual(len(history.last_text_segment) - len(env.response_token), 1000) + + env.max_tool_response = 2000 + history = env.step(TextHistory("Hello there!", torch.tensor([1, 2, 3]))) + self.assertEqual(len(history.last_text_segment) - len(env.response_token), 1000) + + @patch.object(TextEnvironment, "generate", side_effect=dummy_generate) + def test_text_environment_max_calls(self, mock_generate): + env = TextEnvironment( + self.gpt2_model, + self.gpt2_tokenizer, + tools={"DummyTool": DummyTool()}, + reward_fn=lambda x: [torch.tensor(1) for _ in x], + prompt="I am a prompt!\n", + ) + + env.max_turns = 1 + _, _, _, _, histories = env.run(["test"]) + self.assertEqual( + histories[0].text, "I am a prompt!\n" + "test" + 1 * "testtest" + ) + + env.max_turns = 2 + _, _, _, _, histories = env.run(["test"]) + self.assertEqual( + histories[0].text, "I am a prompt!\n" + "test" + 2 * "testtest" + ) + + env.max_turns = 4 + _, _, _, _, histories = env.run(["test"]) + self.assertEqual( + histories[0].text, "I am a prompt!\n" + "test" + 4 * "testtest" + ) + + def test_text_environment_compute_rewards(self): + env = TextEnvironment( + self.gpt2_model, + self.gpt2_tokenizer, + tools={"DummyTool": DummyTool()}, + reward_fn=lambda x: [torch.tensor(i) for i, _ in enumerate(x)], + prompt="I am a prompt!\n", + ) + + histories = [TextHistory("test", torch.tensor([1, 2, 3])) for _ in range(8)] + histories = env.compute_reward(histories) + + for i in range(8): + self.assertEqual(histories[i].reward, i) + + @patch.object(TextEnvironment, "generate", side_effect=dummy_generate) + def test_text_environment_run(self, mock_generate): + env = TextEnvironment( + self.gpt2_model, + self.gpt2_tokenizer, + tools={"DummyTool": DummyTool()}, + reward_fn=lambda x: [torch.tensor(i) for i, _ in enumerate(x)], + prompt="I am a prompt!\n", + max_turns=2, + ) + task_1 = "Hello there!" + task_2 = "Hello there! General Kenobi!" + + query, response, response_mask, reward, histories = env.run([task_1, task_2]) + self.assertEqual(len(query[0]), 9) + self.assertEqual(len(query[1]), 12) + self.assertEqual(len(response[0]), 14) + self.assertEqual(len(response[1]), 14) + self.assertEqual(response_mask[0].sum(), 2 * 3) # mocked generate always adds 3 toknes + self.assertEqual(response_mask[1].sum(), 2 * 3) # mocked generate always adds 3 toknes + self.assertEqual(reward[0], 0) + self.assertEqual(reward[1], 1) + self.assertEqual( + histories[0].text, "I am a prompt!\n" + "Hello there!" + 2 * "testtest" + ) + self.assertEqual( + histories[1].text, + "I am a prompt!\n" + "Hello there! General Kenobi!" + 2 * "testtest", + ) diff --git a/tests/test_iterative_sft_trainer.py b/tests/test_iterative_sft_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..70d5640795e7c74a4a51b338f226758bd88fbd39 --- /dev/null +++ b/tests/test_iterative_sft_trainer.py @@ -0,0 +1,106 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import tempfile +import unittest + +import torch +from datasets import Dataset +from parameterized import parameterized +from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM, AutoTokenizer, TrainingArguments + +from trl import IterativeSFTTrainer + + +class IterativeTrainerTester(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.model_id = "trl-internal-testing/dummy-GPT2-correct-vocab" + cls.model = AutoModelForCausalLM.from_pretrained(cls.model_id) + cls.tokenizer = AutoTokenizer.from_pretrained(cls.model_id) + cls.tokenizer.pad_token = cls.tokenizer.eos_token + + # get t5 as seq2seq example: + model_id = "trl-internal-testing/tiny-T5ForConditionalGeneration-correct-vocab" + cls.t5_model = AutoModelForSeq2SeqLM.from_pretrained(model_id) + cls.t5_tokenizer = AutoTokenizer.from_pretrained(model_id) + + def _init_tensor_dummy_dataset(self): + dummy_dataset_dict = { + "input_ids": [torch.tensor([5303, 3621]), torch.tensor([3666, 1438, 318]), torch.tensor([5303, 3621])], + "attention_mask": [torch.tensor([1, 1]), torch.tensor([1, 1, 1]), torch.tensor([1, 1])], + "labels": [torch.tensor([5303, 3621]), torch.tensor([3666, 1438, 318]), torch.tensor([5303, 3621])], + } + + dummy_dataset = Dataset.from_dict(dummy_dataset_dict) + dummy_dataset.set_format("torch") + return dummy_dataset + + def _init_textual_dummy_dataset(self): + dummy_dataset_dict = { + "texts": ["Testing the IterativeSFTTrainer.", "This is a test of the IterativeSFTTrainer"], + "texts_labels": ["Testing the IterativeSFTTrainer.", "This is a test of the IterativeSFTTrainer"], + } + + dummy_dataset = Dataset.from_dict(dummy_dataset_dict) + dummy_dataset.set_format("torch") + return dummy_dataset + + def setUp(self): + # initialize trainer + self.model.train() + return super().setUp() + + @parameterized.expand( + [ + ["gpt2", "tensor"], + ["gpt2", "text"], + ["t5", "tensor"], + ["t5", "text"], + ] + ) + def test_iterative_step_from_tensor(self, model_name, input_name): + with tempfile.TemporaryDirectory() as tmp_dir: + # initialize dataset + if input_name == "tensor": + dummy_dataset = self._init_tensor_dummy_dataset() + inputs = { + "input_ids": dummy_dataset["input_ids"], + "attention_mask": dummy_dataset["attention_mask"], + "labels": dummy_dataset["labels"], + } + else: + dummy_dataset = self._init_textual_dummy_dataset() + inputs = { + "texts": dummy_dataset["texts"], + "texts_labels": dummy_dataset["texts_labels"], + } + + if model_name == "gpt2": + model = self.model + tokenizer = self.tokenizer + else: + model = self.t5_model + tokenizer = self.t5_tokenizer + + args = TrainingArguments( + output_dir=tmp_dir, + per_device_train_batch_size=2, + max_steps=2, + ) + iterative_trainer = IterativeSFTTrainer(model=model, args=args, tokenizer=tokenizer) + + iterative_trainer.step(**inputs) + + for param in iterative_trainer.model.parameters(): + assert param.grad is not None diff --git a/tests/test_modeling_value_head.py b/tests/test_modeling_value_head.py new file mode 100644 index 0000000000000000000000000000000000000000..b1fdae0233225bbecd2e0fb7efa797dd9c54d961 --- /dev/null +++ b/tests/test_modeling_value_head.py @@ -0,0 +1,517 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import gc +import tempfile +import unittest + +import torch +from transformers import AutoModel, AutoModelForCausalLM, AutoModelForSeq2SeqLM + +from trl import AutoModelForCausalLMWithValueHead, AutoModelForSeq2SeqLMWithValueHead, create_reference_model + + +ALL_CAUSAL_LM_MODELS = [ + "trl-internal-testing/tiny-random-CodeGenForCausalLM", + "trl-internal-testing/tiny-random-GPTJForCausalLM", + "trl-internal-testing/tiny-random-GPTNeoForCausalLM", + "trl-internal-testing/tiny-random-GPTNeoXForCausalLM", + "trl-internal-testing/tiny-random-OPTForCausalLM", + "trl-internal-testing/tiny-random-BloomForCausalLM", + "trl-internal-testing/tiny-random-GPT2LMHeadModel", + "trl-internal-testing/tiny-random-CodeGenForCausalLM-sharded", + "trl-internal-testing/tiny-random-GPTNeoXForCausalLM-safetensors-sharded", + "trl-internal-testing/tiny-random-GPTNeoXForCausalLM-safetensors" + # "trl-internal-testing/tiny-random-LlamaForCausalLM", uncomment on the next transformers release +] + +ALL_SEQ2SEQ_MODELS = [ + "trl-internal-testing/tiny-random-BartForConditionalGeneration", + "trl-internal-testing/tiny-random-BigBirdPegasusForConditionalGeneration", + "trl-internal-testing/tiny-random-BlenderbotForConditionalGeneration", + "trl-internal-testing/tiny-random-BlenderbotSmallForConditionalGeneration", + "trl-internal-testing/tiny-random-FSMTForConditionalGeneration", + "trl-internal-testing/tiny-random-LEDForConditionalGeneration", + "trl-internal-testing/tiny-random-LongT5ForConditionalGeneration", + "trl-internal-testing/tiny-random-M2M100ForConditionalGeneration", + "trl-internal-testing/tiny-random-MarianMTModel", + "trl-internal-testing/tiny-random-MBartForConditionalGeneration", + "trl-internal-testing/tiny-random-MT5ForConditionalGeneration", + "trl-internal-testing/tiny-random-MvpForConditionalGeneration", + "trl-internal-testing/tiny-random-PegasusForConditionalGeneration", + "trl-internal-testing/tiny-random-PegasusXForConditionalGeneration", + "trl-internal-testing/tiny-random-PLBartForConditionalGeneration", + "trl-internal-testing/tiny-random-ProphetNetForConditionalGeneration", + "trl-internal-testing/tiny-random-SwitchTransformersForConditionalGeneration", + "trl-internal-testing/tiny-random-T5ForConditionalGeneration", +] + + +class VHeadModelTester: + all_model_names = None + trl_model_class = None + transformers_model_class = None + + def test_value_head(self): + r""" + Test if the v-head is added to the model successfully + """ + for model_name in self.all_model_names: + model = self.trl_model_class.from_pretrained(model_name) + self.assertTrue(hasattr(model, "v_head")) + + def test_value_head_shape(self): + r""" + Test if the v-head has the correct shape + """ + for model_name in self.all_model_names: + model = self.trl_model_class.from_pretrained(model_name) + self.assertTrue(model.v_head.summary.weight.shape[0] == 1) + + def test_value_head_init_random(self): + r""" + Test if the v-head has been randomly initialized. + We can check that by making sure the bias is different + than zeros by default. + """ + for model_name in self.all_model_names: + model = self.trl_model_class.from_pretrained(model_name) + self.assertFalse(torch.allclose(model.v_head.summary.bias, torch.zeros_like(model.v_head.summary.bias))) + + def test_value_head_not_str(self): + r""" + Test if the v-head is added to the model successfully, by passing a non `PretrainedModel` + as an argument to `from_pretrained`. + """ + for model_name in self.all_model_names: + pretrained_model = self.transformers_model_class.from_pretrained(model_name) + model = self.trl_model_class.from_pretrained(pretrained_model) + self.assertTrue(hasattr(model, "v_head")) + + def test_from_save_trl(self): + """ + Test if the model can be saved and loaded from a directory and get the same weights + Including the additional modules (e.g. v_head) + """ + for model_name in self.all_model_names: + model = self.trl_model_class.from_pretrained(model_name) + + with tempfile.TemporaryDirectory() as tmp_dir: + model.save_pretrained(tmp_dir) + + model_from_save = self.trl_model_class.from_pretrained(tmp_dir) + + # Check if the weights are the same + for key in model_from_save.state_dict(): + self.assertTrue(torch.allclose(model_from_save.state_dict()[key], model.state_dict()[key])) + + def test_from_save_trl_sharded(self): + """ + Test if the model can be saved and loaded from a directory and get the same weights - sharded case + """ + for model_name in self.all_model_names: + model = self.trl_model_class.from_pretrained(model_name) + + with tempfile.TemporaryDirectory() as tmp_dir: + model.save_pretrained(tmp_dir) + + model_from_save = self.trl_model_class.from_pretrained(tmp_dir) + + # Check if the weights are the same + for key in model_from_save.state_dict(): + self.assertTrue(torch.allclose(model_from_save.state_dict()[key], model.state_dict()[key])) + + def test_from_save_transformers_sharded(self): + """ + Test if the model can be saved and loaded using transformers and get the same weights - sharded case + """ + for model_name in self.all_model_names: + transformers_model = self.trl_model_class.transformers_parent_class.from_pretrained(model_name) + + trl_model = self.trl_model_class.from_pretrained(model_name) + + with tempfile.TemporaryDirectory() as tmp_dir: + trl_model.save_pretrained(tmp_dir, max_shard_size="1MB") + transformers_model_from_save = self.trl_model_class.transformers_parent_class.from_pretrained(tmp_dir) + + # Check if the weights are the same + for key in transformers_model.state_dict(): + self.assertTrue( + torch.allclose( + transformers_model_from_save.state_dict()[key], transformers_model.state_dict()[key] + ) + ) + + def test_from_save_transformers(self): + """ + Test if the model can be saved and loaded using transformers and get the same weights. + We override the test of the super class to check if the weights are the same. + """ + for model_name in self.all_model_names: + transformers_model = self.trl_model_class.transformers_parent_class.from_pretrained(model_name) + + trl_model = self.trl_model_class.from_pretrained(model_name) + + with tempfile.TemporaryDirectory() as tmp_dir: + trl_model.save_pretrained(tmp_dir) + transformers_model_from_save = self.trl_model_class.transformers_parent_class.from_pretrained(tmp_dir) + + # Check if the weights are the same + for key in transformers_model.state_dict(): + self.assertTrue( + torch.allclose( + transformers_model_from_save.state_dict()[key], transformers_model.state_dict()[key] + ) + ) + + # Check if the trl model has the same keys as the transformers model + # except the v_head + for key in trl_model.state_dict(): + if "v_head" not in key: + self.assertTrue(key in transformers_model.state_dict()) + # check if the weights are the same + self.assertTrue(torch.allclose(trl_model.state_dict()[key], transformers_model.state_dict()[key])) + + # check if they have the same modules + self.assertTrue( + set(transformers_model_from_save.state_dict().keys()) == set(transformers_model.state_dict().keys()) + ) + + +class CausalLMValueHeadModelTester(VHeadModelTester, unittest.TestCase): + """ + Testing suite for v-head models. + """ + + all_model_names = ALL_CAUSAL_LM_MODELS + trl_model_class = AutoModelForCausalLMWithValueHead + transformers_model_class = AutoModelForCausalLM + + def tearDown(self): + # free memory + gc.collect() + + def test_inference(self): + r""" + Test if the model can be used for inference and outputs 3 values + - logits, loss, and value states + """ + EXPECTED_OUTPUT_SIZE = 3 + + for model_name in self.all_model_names: + model = self.trl_model_class.from_pretrained(model_name) + input_ids = torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]]) + outputs = model(input_ids) + + # Check if the outputs are of the right size - here + # we always output 3 values - logits, loss, and value states + self.assertEqual(len(outputs), EXPECTED_OUTPUT_SIZE) + + def test_dropout_config(self): + r""" + Test if we instantiate a model by adding `summary_drop_prob` to the config + it will be added to the v_head + """ + for model_name in self.all_model_names: + pretrained_model = self.transformers_model_class.from_pretrained(model_name) + pretrained_model.config.summary_dropout_prob = 0.5 + model = self.trl_model_class.from_pretrained(pretrained_model) + + # Check if v head of the model has the same dropout as the config + self.assertEqual(model.v_head.dropout.p, pretrained_model.config.summary_dropout_prob) + + def test_dropout_kwargs(self): + r""" + Test if we instantiate a model by adding `summary_drop_prob` to the config + it will be added to the v_head + """ + for model_name in self.all_model_names: + v_head_kwargs = {"summary_dropout_prob": 0.5} + + model = self.trl_model_class.from_pretrained(model_name, **v_head_kwargs) + + # Check if v head of the model has the same dropout as the config + self.assertEqual(model.v_head.dropout.p, 0.5) + + model = self.trl_model_class.from_pretrained(model_name, summary_dropout_prob=0.5) + + # Check if v head of the model has the same dropout as the config + self.assertEqual(model.v_head.dropout.p, 0.5) + + def test_generate(self): + r""" + Test if `generate` works for every model + """ + for model_name in self.all_model_names: + model = self.trl_model_class.from_pretrained(model_name) + input_ids = torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]]) + + # Just check if the generation works + _ = model.generate(input_ids) + + def test_raise_error_not_causallm(self): + # Test with a model without a LM head + model_id = "trl-internal-testing/tiny-random-GPT2Model" + # This should raise a ValueError + with self.assertRaises(ValueError): + pretrained_model = AutoModelForCausalLM.from_pretrained(model_id) + _ = AutoModelForCausalLMWithValueHead.from_pretrained(pretrained_model.transformer) + + def test_transformers_bf16_kwargs(self): + r""" + Test if the transformers kwargs are correctly passed + Here we check that loading a model in half precision works as expected, i.e. the weights of + the `pretrained_model` attribute is loaded in half precision and you can run a dummy + forward pass without any issue. + """ + for model_name in self.all_model_names: + trl_model = self.trl_model_class.from_pretrained(model_name, torch_dtype=torch.bfloat16) + + lm_head_namings = self.trl_model_class.lm_head_namings + + self.assertTrue( + any(hasattr(trl_model.pretrained_model, lm_head_naming) for lm_head_naming in lm_head_namings) + ) + + for lm_head_naming in lm_head_namings: + if hasattr(trl_model.pretrained_model, lm_head_naming): + self.assertTrue(getattr(trl_model.pretrained_model, lm_head_naming).weight.dtype == torch.bfloat16) + + dummy_input = torch.LongTensor([[0, 1, 0, 1]]) + + # check dummy forward pass works in half precision + _ = trl_model(dummy_input) + + @unittest.skip("This test needs to be run manually due to HF token issue.") + def test_push_to_hub(self): + for model_name in self.all_model_names: + model = AutoModelForCausalLMWithValueHead.from_pretrained(model_name) + if "sharded" in model_name: + model.push_to_hub(model_name + "-ppo", use_auth_token=True, max_shard_size="1MB") + else: + model.push_to_hub(model_name + "-ppo", use_auth_token=True) + + model_from_pretrained = AutoModelForCausalLMWithValueHead.from_pretrained(model_name + "-ppo") + # check all keys + self.assertEqual(model.state_dict().keys(), model_from_pretrained.state_dict().keys()) + + for name, param in model.state_dict().items(): + self.assertTrue( + torch.allclose(param, model_from_pretrained.state_dict()[name]), + f"Parameter {name} is not the same after push_to_hub and from_pretrained", + ) + + +class Seq2SeqValueHeadModelTester(VHeadModelTester, unittest.TestCase): + """ + Testing suite for v-head models. + """ + + all_model_names = ALL_SEQ2SEQ_MODELS + trl_model_class = AutoModelForSeq2SeqLMWithValueHead + transformers_model_class = AutoModelForSeq2SeqLM + + def tearDown(self): + # free memory + gc.collect() + + def test_inference(self): + r""" + Test if the model can be used for inference and outputs 3 values + - logits, loss, and value states + """ + EXPECTED_OUTPUT_SIZE = 3 + + for model_name in self.all_model_names: + model = self.trl_model_class.from_pretrained(model_name) + input_ids = torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]]) + decoder_input_ids = torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]]) + outputs = model(input_ids, decoder_input_ids=decoder_input_ids) + + # Check if the outputs are of the right size - here + # we always output 3 values - logits, loss, and value states + self.assertEqual(len(outputs), EXPECTED_OUTPUT_SIZE) + + def test_dropout_config(self): + r""" + Test if we instantiate a model by adding `summary_drop_prob` to the config + it will be added to the v_head + """ + for model_name in self.all_model_names: + pretrained_model = self.transformers_model_class.from_pretrained(model_name) + pretrained_model.config.summary_dropout_prob = 0.5 + model = self.trl_model_class.from_pretrained(pretrained_model) + + # Check if v head of the model has the same dropout as the config + self.assertEqual(model.v_head.dropout.p, pretrained_model.config.summary_dropout_prob) + + def test_dropout_kwargs(self): + r""" + Test if we instantiate a model by adding `summary_drop_prob` to the config + it will be added to the v_head + """ + for model_name in self.all_model_names: + v_head_kwargs = {"summary_dropout_prob": 0.5} + + model = self.trl_model_class.from_pretrained(model_name, **v_head_kwargs) + + # Check if v head of the model has the same dropout as the config + self.assertEqual(model.v_head.dropout.p, 0.5) + + model = self.trl_model_class.from_pretrained(model_name, summary_dropout_prob=0.5) + + # Check if v head of the model has the same dropout as the config + self.assertEqual(model.v_head.dropout.p, 0.5) + + def test_generate(self): + r""" + Test if `generate` works for every model + """ + for model_name in self.all_model_names: + model = self.trl_model_class.from_pretrained(model_name) + input_ids = torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]]) + decoder_input_ids = torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]]) + + # Just check if the generation works + _ = model.generate(input_ids, decoder_input_ids=decoder_input_ids) + + def test_raise_error_not_causallm(self): + # Test with a model without a LM head + model_id = "trl-internal-testing/tiny-random-T5Model" + # This should raise a ValueError + with self.assertRaises(ValueError): + pretrained_model = AutoModel.from_pretrained(model_id) + _ = self.trl_model_class.from_pretrained(pretrained_model) + + @unittest.skip("This test needs to be run manually due to HF token issue.") + def test_push_to_hub(self): + for model_name in self.all_model_names: + model = self.trl_model_class.from_pretrained(model_name) + if "sharded" in model_name: + model.push_to_hub(model_name + "-ppo", use_auth_token=True, max_shard_size="1MB") + else: + model.push_to_hub(model_name + "-ppo", use_auth_token=True) + + model_from_pretrained = self.trl_model_class.from_pretrained(model_name + "-ppo") + # check all keys + self.assertEqual(model.state_dict().keys(), model_from_pretrained.state_dict().keys()) + + for name, param in model.state_dict().items(): + self.assertTrue( + torch.allclose(param, model_from_pretrained.state_dict()[name]), + f"Parameter {name} is not the same after push_to_hub and from_pretrained", + ) + + def test_transformers_bf16_kwargs(self): + r""" + Test if the transformers kwargs are correctly passed + Here we check that loading a model in half precision works as expected, i.e. the weights of + the `pretrained_model` attribute is loaded in half precision and you can run a dummy + forward pass without any issue. + """ + for model_name in self.all_model_names: + trl_model = self.trl_model_class.from_pretrained(model_name, torch_dtype=torch.bfloat16) + + lm_head_namings = self.trl_model_class.lm_head_namings + + if model_name == "trl-internal-testing/tiny-random-FSMTForConditionalGeneration": + # skip the test for FSMT as it does not support mixed-prec + continue + + self.assertTrue( + any(hasattr(trl_model.pretrained_model, lm_head_naming) for lm_head_naming in lm_head_namings) + ) + + for lm_head_naming in lm_head_namings: + if hasattr(trl_model.pretrained_model, lm_head_naming): + self.assertTrue(getattr(trl_model.pretrained_model, lm_head_naming).weight.dtype == torch.bfloat16) + + dummy_input = torch.LongTensor([[0, 1, 0, 1]]) + + # check dummy forward pass works in half precision + _ = trl_model(input_ids=dummy_input, decoder_input_ids=dummy_input) + + +class ReferenceModelTest(unittest.TestCase): + def setUp(self): + self.model = AutoModelForCausalLMWithValueHead.from_pretrained( + "trl-internal-testing/tiny-random-GPT2LMHeadModel" + ) + self.test_input = torch.tensor([[0, 1, 2, 3]]) + self.optimizer = torch.optim.AdamW(self.model.parameters(), lr=1) + self.layer_format = "pretrained_model.transformer.h.{layer}.attn.c_attn.weight" + + def test_independent_reference(self): + layer_0 = self.layer_format.format(layer=0) + layer_5 = self.layer_format.format(layer=4) + + ref_model = create_reference_model(self.model) + + first_layer_before = self.model.get_parameter(layer_0).data.clone() + last_layer_before = self.model.get_parameter(layer_5).data.clone() + + first_ref_layer_before = ref_model.get_parameter(layer_0).data.clone() + last_ref_layer_before = ref_model.get_parameter(layer_5).data.clone() + + output = self.model(input_ids=self.test_input, labels=self.test_input) + output[1].backward() + self.optimizer.step() + + first_layer_after = self.model.get_parameter(layer_0).data.clone() + last_layer_after = self.model.get_parameter(layer_5).data.clone() + + first_ref_layer_after = ref_model.get_parameter(layer_0).data.clone() + last_ref_layer_after = ref_model.get_parameter(layer_5).data.clone() + + # before optimization ref and model are identical + self.assertTrue((first_layer_before == first_ref_layer_before).all()) + self.assertTrue((last_layer_before == last_ref_layer_before).all()) + # ref model stays identical after optimization + self.assertTrue((first_ref_layer_before == first_ref_layer_after).all()) + self.assertTrue((last_ref_layer_before == last_ref_layer_after).all()) + # optimized model changes + self.assertTrue(not (first_layer_before == first_layer_after).all()) + self.assertTrue(not (last_layer_before == last_layer_after).all()) + + def test_shared_layers(self): + layer_0 = self.layer_format.format(layer=0) + layer_1 = self.layer_format.format(layer=1) + + ref_model = create_reference_model(self.model, num_shared_layers=1) + + first_layer_before = self.model.get_parameter(layer_0).data.clone() + second_layer_before = self.model.get_parameter(layer_1).data.clone() + + first_ref_layer_before = ref_model.get_parameter(layer_0).data.clone() + second_ref_layer_before = ref_model.get_parameter(layer_1).data.clone() + + output = self.model(input_ids=self.test_input, labels=self.test_input) + output[1].backward() + self.optimizer.step() + + first_layer_after = self.model.get_parameter(layer_0).data.clone() + second_layer_after = self.model.get_parameter(layer_1).data.clone() + + first_ref_layer_after = ref_model.get_parameter(layer_0).data.clone() + second_ref_layer_after = ref_model.get_parameter(layer_1).data.clone() + + # before optimization ref and model are identical + self.assertTrue((first_layer_before == first_ref_layer_before).all()) + self.assertTrue((second_layer_before == second_ref_layer_before).all()) + # ref model stays identical after optimization + self.assertTrue((first_ref_layer_before == first_ref_layer_after).all()) + self.assertTrue((second_ref_layer_before == second_ref_layer_after).all()) + # first layer of optimized model stays the same + self.assertTrue((first_layer_before == first_layer_after).all()) + # other layers in optimized model change + self.assertTrue(not (second_layer_before == second_layer_after).all()) diff --git a/tests/test_no_peft.py b/tests/test_no_peft.py new file mode 100644 index 0000000000000000000000000000000000000000..3190b7c85a57550f898133a72c3c7aa7455f0a6f --- /dev/null +++ b/tests/test_no_peft.py @@ -0,0 +1,153 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import sys +import unittest +from unittest.mock import patch + +import torch +from transformers import AutoModelForCausalLM, AutoTokenizer + +from .testing_utils import is_peft_available, require_peft + + +class DummyDataset(torch.utils.data.Dataset): + def __init__(self, query_data, response_data): + self.query_data = query_data + self.response_data = response_data + + def __len__(self): + return len(self.query_data) + + def __getitem__(self, idx): + return self.query_data[idx], self.response_data[idx] + + +EXPECTED_STATS = [ + "objective/kl", + "objective/kl_dist", + "objective/logprobs", + "objective/ref_logprobs", + "objective/kl_coef", + "objective/entropy", + "ppo/mean_non_score_reward", + "ppo/loss/policy", + "ppo/loss/value", + "ppo/loss/total", + "ppo/policy/entropy", + "ppo/policy/approxkl", + "ppo/policy/policykl", + "ppo/policy/clipfrac", + "ppo/policy/advantages", + "ppo/policy/advantages_mean", + "ppo/policy/ratio", + "ppo/returns/mean", + "ppo/returns/var", + "ppo/val/vpred", + "ppo/val/error", + "ppo/val/clipfrac", + "ppo/val/mean", + "ppo/val/var", + "ppo/val/var_explained", + "time/ppo/forward_pass", + "time/ppo/compute_rewards", + "time/ppo/optimize_step", + "time/ppo/calc_stats", + "time/ppo/total", + "ppo/learning_rate", +] + + +@require_peft +class TestPeftDependancy(unittest.TestCase): + def setUp(self): + self.causal_lm_model_id = "trl-internal-testing/tiny-random-GPTNeoXForCausalLM" + self.seq_to_seq_model_id = "trl-internal-testing/tiny-random-T5ForConditionalGeneration" + + if is_peft_available(): + from peft import LoraConfig, get_peft_model + + lora_config = LoraConfig( + r=16, + lora_alpha=32, + lora_dropout=0.05, + bias="none", + task_type="CAUSAL_LM", + ) + + causal_lm_model = AutoModelForCausalLM.from_pretrained(self.causal_lm_model_id) + self.peft_model = get_peft_model(causal_lm_model, lora_config) + + def test_no_peft(self): + with patch.dict(sys.modules, {"peft": None}): + from trl import AutoModelForCausalLMWithValueHead, AutoModelForSeq2SeqLMWithValueHead + + # Check that loading a model with `peft` will raise an error + with self.assertRaises(ModuleNotFoundError): + import peft # noqa + + trl_model = AutoModelForCausalLMWithValueHead.from_pretrained(self.causal_lm_model_id) # noqa + trl_seq2seq_model = AutoModelForSeq2SeqLMWithValueHead.from_pretrained(self.seq_to_seq_model_id) # noqa + + def test_imports_no_peft(self): + with patch.dict(sys.modules, {"peft": None}): + from trl import ( # noqa + AutoModelForCausalLMWithValueHead, + AutoModelForSeq2SeqLMWithValueHead, + PPOConfig, + PPOTrainer, + PreTrainedModelWrapper, + ) + + def test_ppo_trainer_no_peft(self): + with patch.dict(sys.modules, {"peft": None}): + from trl import AutoModelForCausalLMWithValueHead, PPOConfig, PPOTrainer + + ppo_model_id = "trl-internal-testing/dummy-GPT2-correct-vocab" + + trl_model = AutoModelForCausalLMWithValueHead.from_pretrained(ppo_model_id) + tokenizer = AutoTokenizer.from_pretrained(ppo_model_id) + tokenizer.pad_token_id = tokenizer.eos_token_id + + ppo_config = PPOConfig(batch_size=2, mini_batch_size=1, log_with=None) + + dummy_dataset = DummyDataset( + [torch.LongTensor([0, 1, 0, 1, 0, 1]), torch.LongTensor([0, 1, 0, 1, 0, 1])], + [torch.LongTensor([1, 0, 1, 0, 1, 0]), torch.LongTensor([0, 1, 0, 1, 0, 1])], + ) + + ppo_trainer = PPOTrainer( + config=ppo_config, + model=trl_model, + ref_model=None, + tokenizer=tokenizer, + dataset=dummy_dataset, + ) + dummy_dataloader = ppo_trainer.dataloader + + for query_tensor, response_tensor in dummy_dataloader: + # define a reward for response + # (this could be any reward such as human feedback or output from another model) + reward = [torch.tensor(1.0), torch.tensor(0.0)] + # train model + train_stats = ppo_trainer.step([q for q in query_tensor], [r for r in response_tensor], reward) + break + + # check gradients are not None + for _, param in trl_model.named_parameters(): + if param.requires_grad: + self.assertIsNotNone(param.grad) + + # check expected stats + for stat in EXPECTED_STATS: + self.assertIn(stat, train_stats) diff --git a/tests/test_peft_models.py b/tests/test_peft_models.py new file mode 100644 index 0000000000000000000000000000000000000000..3b004659d2b7160f62d0c38cc97e6b3284dff475 --- /dev/null +++ b/tests/test_peft_models.py @@ -0,0 +1,208 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +import tempfile +import unittest + +import torch +from pytest import mark +from transformers import AutoModelForCausalLM + +from trl import AutoModelForCausalLMWithValueHead, is_peft_available + + +if is_peft_available(): + from peft import get_peft_model, LoraConfig + +from .testing_utils import require_bitsandbytes, require_peft + + +@require_peft +@mark.peft_test +class PeftModelTester(unittest.TestCase): + def setUp(self): + self.causal_lm_model_id = "trl-internal-testing/tiny-random-GPTNeoXForCausalLM" + self.lora_config = LoraConfig( + r=16, + lora_alpha=32, + lora_dropout=0.05, + bias="none", + task_type="CAUSAL_LM", + ) + + def test_create_peft_model(self): + r""" + Simply creates a peft model and checks that it can be loaded. + """ + causal_lm_model = AutoModelForCausalLM.from_pretrained(self.causal_lm_model_id) + pretrained_model = get_peft_model(causal_lm_model, self.lora_config) + + _ = AutoModelForCausalLMWithValueHead.from_pretrained(pretrained_model) + + def test_peft_requires_grad(self): + r""" + Check that the value head of the returned model has requires_grad=True. + """ + causal_lm_model = AutoModelForCausalLM.from_pretrained(self.causal_lm_model_id) + pretrained_model = get_peft_model(causal_lm_model, self.lora_config) + + model = AutoModelForCausalLMWithValueHead.from_pretrained(pretrained_model) + + # Check that the value head has requires_grad=True + self.assertTrue(model.v_head.summary.weight.requires_grad) + + def test_check_peft_model_nb_trainable_params(self): + r""" + Check that the number of trainable parameters is correct. + """ + causal_lm_model = AutoModelForCausalLM.from_pretrained(self.causal_lm_model_id) + pretrained_model = get_peft_model(causal_lm_model, self.lora_config) + + model = AutoModelForCausalLMWithValueHead.from_pretrained(pretrained_model) + + # Check that the number of trainable parameters is correct + nb_trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad) + self.assertEqual(nb_trainable_params, 10273) + + # Check that the number of trainable param for the non-peft model is correct + non_peft_model = AutoModelForCausalLMWithValueHead.from_pretrained(self.causal_lm_model_id) + nb_trainable_params = sum(p.numel() for p in non_peft_model.parameters() if p.requires_grad) + self.assertEqual(nb_trainable_params, 99578) + + def test_create_peft_model_from_config(self): + r""" + Simply creates a peft model and checks that it can be loaded. + """ + trl_model = AutoModelForCausalLMWithValueHead.from_pretrained( + self.causal_lm_model_id, peft_config=self.lora_config + ) + # Check that the number of trainable parameters is correct + nb_trainable_params = sum(p.numel() for p in trl_model.parameters() if p.requires_grad) + self.assertEqual(nb_trainable_params, 10273) + + causal_lm_model = AutoModelForCausalLM.from_pretrained(self.causal_lm_model_id) + trl_model = AutoModelForCausalLMWithValueHead.from_pretrained(causal_lm_model, peft_config=self.lora_config) + # Check that the number of trainable parameters is correct + nb_trainable_params = sum(p.numel() for p in trl_model.parameters() if p.requires_grad) + self.assertEqual(nb_trainable_params, 10273) + + @require_bitsandbytes + def test_create_bnb_peft_model_from_config(self): + r""" + Simply creates a peft model and checks that it can be loaded. + """ + from bitsandbytes.nn import Linear8bitLt + + trl_model = AutoModelForCausalLMWithValueHead.from_pretrained( + self.causal_lm_model_id, peft_config=self.lora_config, load_in_8bit=True + ) + # Check that the number of trainable parameters is correct + nb_trainable_params = sum(p.numel() for p in trl_model.parameters() if p.requires_grad) + self.assertEqual(nb_trainable_params, 10273) + self.assertTrue( + trl_model.pretrained_model.model.gpt_neox.layers[0].mlp.dense_h_to_4h.__class__ == Linear8bitLt + ) + + causal_lm_model = AutoModelForCausalLM.from_pretrained( + self.causal_lm_model_id, load_in_8bit=True, device_map="auto" + ) + trl_model = AutoModelForCausalLMWithValueHead.from_pretrained(causal_lm_model, peft_config=self.lora_config) + # Check that the number of trainable parameters is correct + nb_trainable_params = sum(p.numel() for p in trl_model.parameters() if p.requires_grad) + self.assertEqual(nb_trainable_params, 10273) + self.assertTrue( + trl_model.pretrained_model.model.gpt_neox.layers[0].mlp.dense_h_to_4h.__class__ == Linear8bitLt + ) + + def test_save_pretrained_peft(self): + r""" + Check that the model can be saved and loaded properly. + """ + causal_lm_model = AutoModelForCausalLM.from_pretrained(self.causal_lm_model_id) + pretrained_model = get_peft_model(causal_lm_model, self.lora_config) + + model = AutoModelForCausalLMWithValueHead.from_pretrained(pretrained_model) + + with tempfile.TemporaryDirectory() as tmp_dir: + model.save_pretrained(tmp_dir) + + # check that the files `adapter_model.safetensors` and `adapter_config.json` are in the directory + self.assertTrue( + os.path.isfile(f"{tmp_dir}/adapter_model.safetensors"), + msg=f"{tmp_dir}/adapter_model.safetensors does not exist", + ) + self.assertTrue( + os.path.exists(f"{tmp_dir}/adapter_config.json"), + msg=f"{tmp_dir}/adapter_config.json does not exist", + ) + # check also for `pytorch_model.bin` and make sure it only contains `v_head` weights + self.assertTrue( + os.path.exists(f"{tmp_dir}/pytorch_model.bin"), + msg=f"{tmp_dir}/pytorch_model.bin does not exist", + ) + maybe_v_head = torch.load(f"{tmp_dir}/pytorch_model.bin") + # check that only keys that starts with `v_head` are in the dict + self.assertTrue( + all(k.startswith("v_head") for k in maybe_v_head.keys()), + msg=f"keys in {tmp_dir}/pytorch_model.bin do not start with `v_head`", + ) + + model_from_pretrained = AutoModelForCausalLMWithValueHead.from_pretrained(tmp_dir) + + # check all the weights are the same + for p1, p2 in zip(model.named_parameters(), model_from_pretrained.named_parameters()): + self.assertTrue(torch.allclose(p1[1], p2[1]), msg=f"{p1[0]} != {p2[0]}") + + def test_load_pretrained_peft(self): + r""" + Check that the model saved with peft class interface can be loaded properly. + """ + causal_lm_model = AutoModelForCausalLM.from_pretrained(self.causal_lm_model_id) + pretrained_model = get_peft_model(causal_lm_model, self.lora_config) + + model = AutoModelForCausalLMWithValueHead.from_pretrained(pretrained_model) + + with tempfile.TemporaryDirectory() as tmp_dir: + pretrained_model.save_pretrained(tmp_dir) + model_from_pretrained = AutoModelForCausalLMWithValueHead.from_pretrained(tmp_dir) + + # check that the files `adapter_model.safetensors` and `adapter_config.json` are in the directory + self.assertTrue( + os.path.isfile(f"{tmp_dir}/adapter_model.safetensors"), + msg=f"{tmp_dir}/adapter_model.safetensors does not exist", + ) + self.assertTrue( + os.path.exists(f"{tmp_dir}/adapter_config.json"), + msg=f"{tmp_dir}/adapter_config.json does not exist", + ) + + # check all the weights are the same + for p1, p2 in zip(model.named_parameters(), model_from_pretrained.named_parameters()): + if p1[0] not in ["v_head.summary.weight", "v_head.summary.bias"]: + self.assertTrue(torch.allclose(p1[1], p2[1]), msg=f"{p1[0]} != {p2[0]}") + + def test_continue_training_peft_model(self): + r""" + Load peft and checks that it can continue training. + """ + causal_lm_model = AutoModelForCausalLM.from_pretrained(self.causal_lm_model_id) + pretrained_model = get_peft_model(causal_lm_model, self.lora_config) + + with tempfile.TemporaryDirectory() as tmp_dir: + pretrained_model.save_pretrained(tmp_dir) + # set is_trainable to True + model = AutoModelForCausalLMWithValueHead.from_pretrained(tmp_dir, is_trainable=True) + # Check that the number of trainable parameters is correct + nb_trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad) + self.assertEqual(nb_trainable_params, 10273) diff --git a/tests/test_ppo_trainer.py b/tests/test_ppo_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..0af091fc3c24b720e4611c71c4ed91dc16ad6582 --- /dev/null +++ b/tests/test_ppo_trainer.py @@ -0,0 +1,1232 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import copy +import fnmatch +import gc +import re +import tempfile +import unittest + +import pytest +import torch +from huggingface_hub import HfApi, HfFolder, delete_repo +from parameterized import parameterized +from pytest import mark +from requests.exceptions import HTTPError +from transformers import AutoTokenizer + +from trl import AutoModelForCausalLMWithValueHead, AutoModelForSeq2SeqLMWithValueHead, PPOConfig, PPOTrainer, set_seed +from trl.core import respond_to_batch + +from .testing_constants import CI_HUB_ENDPOINT, CI_HUB_USER, CI_HUB_USER_TOKEN +from .testing_utils import require_peft, require_torch_multi_gpu + + +EXPECTED_STATS = [ + "objective/kl", + "objective/kl_dist", + "objective/logprobs", + "objective/ref_logprobs", + "objective/kl_coef", + "objective/entropy", + "ppo/mean_non_score_reward", + "ppo/loss/policy", + "ppo/loss/value", + "ppo/loss/total", + "ppo/policy/entropy", + "ppo/policy/approxkl", + "ppo/policy/policykl", + "ppo/policy/clipfrac", + "ppo/policy/advantages", + "ppo/policy/advantages_mean", + "ppo/policy/ratio", + "ppo/returns/mean", + "ppo/returns/var", + "ppo/val/vpred", + "ppo/val/error", + "ppo/val/clipfrac", + "ppo/val/mean", + "ppo/val/var", + "ppo/val/var_explained", + "time/ppo/forward_pass", + "time/ppo/compute_rewards", + "time/ppo/optimize_step", + "time/ppo/calc_stats", + "time/ppo/total", + "ppo/learning_rate", +] + + +class DummyDataset(torch.utils.data.Dataset): + def __init__(self, query_data, response_data): + self.query_data = query_data + self.response_data = response_data + + def __len__(self): + return len(self.query_data) + + def __getitem__(self, idx): + return self.query_data[idx], self.response_data[idx] + + +def apply_mask(values, mask): + unmasked_values = [] + for v, m in zip(values, mask): + if m == 1: + unmasked_values.append(v) + return torch.Tensor(unmasked_values) + + +def abs_diff_masked_tensors(tensor_1, tensor_2, mask_1, mask_2): + diffs = [] + for l1, l2, m1, m2 in zip(tensor_1, tensor_2, mask_1, mask_2): + diff = apply_mask(l1, m1) - apply_mask(l2, m2) + diffs.append(diff.sum()) + return abs(sum(diffs)) + + +class PPOTrainerTester(unittest.TestCase): + """ + A wrapper class for testing PPOTrainer + """ + + @classmethod + def setUpClass(cls): + set_seed(42) + cls._token = CI_HUB_USER_TOKEN + cls._api = HfApi(endpoint=CI_HUB_ENDPOINT) + HfFolder.save_token(CI_HUB_USER_TOKEN) + + # model_id + cls.model_id = "trl-internal-testing/dummy-GPT2-correct-vocab" + + # get models and tokenizer + cls.gpt2_model = AutoModelForCausalLMWithValueHead.from_pretrained(cls.model_id) + cls.gpt2_model_ref = AutoModelForCausalLMWithValueHead.from_pretrained(cls.model_id) + cls.gpt2_tokenizer = AutoTokenizer.from_pretrained(cls.model_id) + + cls.gpt2_tokenizer.pad_token = cls.gpt2_tokenizer.eos_token + + # get bloom as right padding examples: + model_id = "trl-internal-testing/tiny-BloomForCausalLM-correct-vocab" + cls.bloom_model = AutoModelForCausalLMWithValueHead.from_pretrained(model_id) + cls.bloom_tokenizer = AutoTokenizer.from_pretrained(model_id) + + model_id = "trl-internal-testing/tiny-T5ForConditionalGeneration-correct-vocab" + cls.t5_model = AutoModelForSeq2SeqLMWithValueHead.from_pretrained(model_id) + cls.t5_tokenizer = AutoTokenizer.from_pretrained(model_id) + + # initialize trainer + cls.ppo_config = PPOConfig(batch_size=2, mini_batch_size=1, log_with=None) + + @classmethod + def tearDownClass(cls): + for model in [f"{CI_HUB_USER}/test-ppo-trainer"]: + try: + delete_repo(token=cls._token, repo_id=model) + except HTTPError: + pass + + def setUp(self): + # initialize trainer + self.ppo_config = PPOConfig(batch_size=2, mini_batch_size=1, log_with=None) + self.gpt2_model.train() + return super().setUp() + + def tearDown(self): + # free memory + gc.collect() + + def _init_dummy_dataset(self): + # encode a query + query_txt = "This morning I went to the " + query_tensor = self.gpt2_tokenizer.encode(query_txt, return_tensors="pt") + assert query_tensor.shape == (1, 7) + # get model response + response_tensor = respond_to_batch(self.gpt2_model, query_tensor) + assert response_tensor.shape == (1, 20) + + # create a dummy dataset + min_length = min(len(query_tensor[0]), len(response_tensor[0])) + dummy_dataset = DummyDataset( + [query_tensor[:, :min_length].squeeze(0) for _ in range(2)], + [response_tensor[:, :min_length].squeeze(0) for _ in range(2)], + ) + + return dummy_dataset + + def test_drop_last_dataloader(self): + self.ppo_config = PPOConfig(batch_size=3, mini_batch_size=1, log_with=None) + + dummy_dataset = self._init_dummy_dataset() + + ppo_trainer = PPOTrainer( + config=self.ppo_config, + model=self.gpt2_model, + ref_model=self.gpt2_model_ref, + tokenizer=self.gpt2_tokenizer, + dataset=dummy_dataset, + ) + dummy_dataloader = ppo_trainer.dataloader + + self.assertEqual(len(dummy_dataloader), 0) + + def test_ppo_step(self): + # initialize dataset + dummy_dataset = self._init_dummy_dataset() + + ppo_trainer = PPOTrainer( + config=self.ppo_config, + model=self.gpt2_model, + ref_model=self.gpt2_model_ref, + tokenizer=self.gpt2_tokenizer, + dataset=dummy_dataset, + ) + dummy_dataloader = ppo_trainer.dataloader + # train model with ppo + for query_tensor, response_tensor in dummy_dataloader: + # define a reward for response + # (this could be any reward such as human feedback or output from another model) + reward = [torch.tensor(1.0), torch.tensor(0.0)] + # train model + train_stats = ppo_trainer.step([q for q in query_tensor], [r for r in response_tensor], reward) + break + + for param in ppo_trainer.model.parameters(): + assert param.grad is not None + + for stat in EXPECTED_STATS: + assert stat in train_stats.keys() + + def test_ppo_step_with_masks(self): + # initialize dataset + dummy_dataset = self._init_dummy_dataset() + + ppo_trainer = PPOTrainer( + config=self.ppo_config, + model=self.gpt2_model, + ref_model=self.gpt2_model_ref, + tokenizer=self.gpt2_tokenizer, + dataset=dummy_dataset, + ) + dummy_dataloader = ppo_trainer.dataloader + # train model with ppo + for query_tensor, response_tensor in dummy_dataloader: + # define a reward for response + # (this could be any reward such as human feedback or output from another model) + reward = [torch.tensor(1.0), torch.tensor(0.0)] + + response_mask = [torch.ones_like(r) for r in response_tensor] + + # train model + train_stats = ppo_trainer.step( + [q for q in query_tensor], [r for r in response_tensor], reward, response_mask + ) + break + + for param in ppo_trainer.model.parameters(): + assert param.grad is not None + + for stat in EXPECTED_STATS: + assert stat in train_stats.keys() + + def test_ppo_step_with_no_ref_sgd(self): + # initialize dataset + dummy_dataset = self._init_dummy_dataset() + optimizer = torch.optim.SGD(self.gpt2_model.parameters(), lr=0.01) + + ppo_trainer = PPOTrainer( + config=self.ppo_config, + model=self.gpt2_model, + ref_model=None, + optimizer=optimizer, + tokenizer=self.gpt2_tokenizer, + dataset=dummy_dataset, + ) + dummy_dataloader = ppo_trainer.dataloader + + self.assertTrue(isinstance(ppo_trainer.optimizer.optimizer, torch.optim.SGD)) + + # train model with ppo + for query_tensor, response_tensor in dummy_dataloader: + # define a reward for response + # (this could be any reward such as human feedback or output from another model) + reward = [torch.tensor(1.0), torch.tensor(0.0)] + # train model + train_stats = ppo_trainer.step([q for q in query_tensor], [r for r in response_tensor], reward) + break + + for name, param in ppo_trainer.model.named_parameters(): + self.assertTrue(param.grad is not None, f"Parameter {name} has no gradient") + + # ref model should not be trained + for name, param in ppo_trainer.ref_model.named_parameters(): + self.assertTrue(param.grad is None, f"Parameter {name} has a gradient") + + # Finally check stats + for stat in EXPECTED_STATS: + assert stat in train_stats.keys() + + def test_ppo_step_with_no_ref_sgd_lr_scheduler(self): + # initialize dataset + dummy_dataset = self._init_dummy_dataset() + optimizer = torch.optim.SGD(self.gpt2_model.parameters(), lr=0.01) + lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9) + + ppo_trainer = PPOTrainer( + config=self.ppo_config, + model=self.gpt2_model, + ref_model=None, + optimizer=optimizer, + tokenizer=self.gpt2_tokenizer, + dataset=dummy_dataset, + lr_scheduler=lr_scheduler, + ) + dummy_dataloader = ppo_trainer.dataloader + + self.assertTrue(isinstance(ppo_trainer.optimizer.optimizer, torch.optim.SGD)) + self.assertTrue(isinstance(ppo_trainer.lr_scheduler.scheduler, torch.optim.lr_scheduler.ExponentialLR)) + + # train model with ppo + for query_tensor, response_tensor in dummy_dataloader: + # define a reward for response + # (this could be any reward such as human feedback or output from another model) + reward = [torch.tensor(1.0), torch.tensor(0.0)] + # train model + _ = ppo_trainer.step([q for q in query_tensor], [r for r in response_tensor], reward) + train_stats = ppo_trainer.step([q for q in query_tensor], [r for r in response_tensor], reward) + break + + for name, param in ppo_trainer.model.named_parameters(): + self.assertTrue(param.grad is not None, f"Parameter {name} has no gradient") + + # ref model should not be trained + for name, param in ppo_trainer.ref_model.named_parameters(): + self.assertTrue(param.grad is None, f"Parameter {name} has a gradient") + + # Finally check stats + for stat in EXPECTED_STATS: + assert stat in train_stats.keys() + + # assert that the LR has increased for exponential decay + self.assertTrue(train_stats["ppo/learning_rate"] > self.ppo_config.learning_rate) + + def test_ppo_step_with_no_ref(self): + # initialize dataset + dummy_dataset = self._init_dummy_dataset() + self.gpt2_model = AutoModelForCausalLMWithValueHead.from_pretrained(self.model_id) + + ppo_trainer = PPOTrainer( + config=self.ppo_config, + model=self.gpt2_model, + ref_model=None, + tokenizer=self.gpt2_tokenizer, + dataset=dummy_dataset, + ) + dummy_dataloader = ppo_trainer.dataloader + # train model with ppo + for query_tensor, response_tensor in dummy_dataloader: + # define a reward for response + # (this could be any reward such as human feedback or output from another model) + reward = [torch.tensor(1.0), torch.tensor(0.0)] + # train model + train_stats = ppo_trainer.step([q for q in query_tensor], [r for r in response_tensor], reward) + break + + for name, param in ppo_trainer.model.named_parameters(): + self.assertTrue(param.grad is not None, f"Parameter {name} has no gradient") + + # ref model should not be trained + for name, param in ppo_trainer.ref_model.named_parameters(): + self.assertTrue(param.grad is None, f"Parameter {name} has a gradient") + + # initialize a new gpt2 model: + model = AutoModelForCausalLMWithValueHead.from_pretrained(self.model_id) + for name, param in ppo_trainer.ref_model.named_parameters(): + if "v_head" not in name: + name = name.replace("pretrained_model.", "") + + self.assertTrue( + torch.allclose(param.cpu(), model.state_dict()[name].cpu()), + f"Parameter {name} has changed from the original model", + ) + + # Finally check stats + for stat in EXPECTED_STATS: + assert stat in train_stats.keys() + + def test_ppo_step_with_no_ref_custom_layers(self): + """ + Test PPO step with no reference model and custom layers + For shared layers configuration, all the layers after the `num_shared_layers` are considered as custom layers + therefore the gradients should be computed for these layers only. + """ + # initialize dataset + dummy_dataset = self._init_dummy_dataset() + self.gpt2_model = AutoModelForCausalLMWithValueHead.from_pretrained(self.model_id) + num_shared_layers = 1 + + ppo_trainer = PPOTrainer( + config=self.ppo_config, + model=self.gpt2_model, + ref_model=None, + tokenizer=self.gpt2_tokenizer, + dataset=dummy_dataset, + num_shared_layers=num_shared_layers, + ) + dummy_dataloader = ppo_trainer.dataloader + # train model with ppo + for query_tensor, response_tensor in dummy_dataloader: + # define a reward for response + # (this could be any reward such as human feedback or output from another model) + reward = [torch.tensor(1.0), torch.tensor(0.0)] + # train model + train_stats = ppo_trainer.step([q for q in query_tensor], [r for r in response_tensor], reward) + break + + pattern = r".*transformer\.h\.(\d+)\..*" + final_layers = ["ln_f", "v_head", "lm_head"] + + for name, param in ppo_trainer.model.named_parameters(): + if re.match(pattern, name): + layer_number = int(re.match(pattern, name).groups(0)[0]) + if layer_number < num_shared_layers: + self.assertTrue(param.grad is None, f"Parameter {name} has a gradient") + else: + self.assertTrue(param.grad is not None, f"Parameter {name} has no gradient") + elif any([layer in name for layer in final_layers]): + self.assertTrue(param.grad is not None, f"Parameter {name} has no gradient") + + # ref model should not be trained + for name, param in ppo_trainer.ref_model.named_parameters(): + self.assertTrue(param.grad is None, f"Parameter {name} has a gradient") + + for stat in EXPECTED_STATS: + assert stat in train_stats.keys() + + def test_ppo_step_with_ref_and_custom_layers_warning(self): + """ + Test PPO step with a reference model and custom layers + The trainer should raise a warning if the argument `num_shared_layers` is set + together with a reference model. + """ + # initialize dataset + dummy_dataset = self._init_dummy_dataset() + + num_shared_layers = 6 + + with self.assertWarns(UserWarning): + _ = PPOTrainer( + config=self.ppo_config, + model=self.gpt2_model, + ref_model=self.gpt2_model_ref, + tokenizer=self.gpt2_tokenizer, + dataset=dummy_dataset, + num_shared_layers=num_shared_layers, + ) + + def test_ppo_step_rewards_shape(self): + """ + Test if the rewards shape is correct by asserting that if a wrong reward shape is passed, we get + a value error. + """ + + # initialize dataset + dummy_dataset = self._init_dummy_dataset() + + ppo_trainer = PPOTrainer( + config=self.ppo_config, + model=self.gpt2_model, + ref_model=None, + tokenizer=self.gpt2_tokenizer, + dataset=dummy_dataset, + ) + dummy_dataloader = ppo_trainer.dataloader + # train model with ppo + for query_tensor, response_tensor in dummy_dataloader: + # define a reward for response + # (this could be any reward such as human feedback or output from another model) + reward = [torch.tensor([[1.0]]), torch.tensor([[0.0]])] + # train model - this should raise an error + with self.assertRaises(ValueError): + _ = ppo_trainer.step([q for q in query_tensor], [r for r in response_tensor], reward) + + reward = [torch.tensor([1.0]), torch.tensor([0.0])] + # train model - this should work + _ = ppo_trainer.step([q for q in query_tensor], [r for r in response_tensor], reward) + break + + # check if the gradients are computed for the model + for name, param in ppo_trainer.model.named_parameters(): + self.assertTrue(param.grad is not None, f"Parameter {name} has no gradient") + + # ref model should not be trained + for name, param in ppo_trainer.ref_model.named_parameters(): + self.assertTrue(param.grad is None, f"Parameter {name} has a gradient") + + def test_ppo_step_input_shape(self): + """ + Test if the shape of the expected inputs are correct + """ + # initialize dataset + dummy_dataset = self._init_dummy_dataset() + + ppo_trainer = PPOTrainer( + config=self.ppo_config, + model=self.gpt2_model, + ref_model=None, + tokenizer=self.gpt2_tokenizer, + dataset=dummy_dataset, + ) + dummy_dataloader = ppo_trainer.dataloader + # train model with ppo + for query_tensor, response_tensor in dummy_dataloader: + # define a reward for response + # (this could be any reward such as human feedback or output from another model) + reward = [torch.tensor([1.0]), torch.tensor([0.0])] + # train model - this should raise an error + bs = ppo_trainer.config.batch_size + + queries, responses, _, _ = ppo_trainer._step_safety_checker( + bs, [q for q in query_tensor], [r for r in response_tensor], reward + ) + + self.assertTrue(isinstance(queries, list), f"queries should be a list, got {type(queries)}") + self.assertTrue(isinstance(responses, list), f"responses should be a list, got {type(responses)}") + + # check the shapes + for i in range(bs): + self.assertEqual(queries[i].shape, torch.Size([7])) + self.assertEqual(responses[i].size(), torch.Size([7])) + break + + def test_ppo_step_no_dataset(self): + """ + Test if the training loop works fine without passing a dataset + """ + query_txt = "This morning I went to the " + query_tensor = self.gpt2_tokenizer.encode(query_txt, return_tensors="pt") + self.ppo_config.batch_size = 1 + + response_tensor = respond_to_batch(self.gpt2_model, query_tensor) + + # Check that this warns the user about batch size + with self.assertWarns(UserWarning): + ppo_trainer = PPOTrainer( + config=self.ppo_config, + model=self.gpt2_model, + ref_model=self.gpt2_model_ref, + tokenizer=self.gpt2_tokenizer, + ) + # train model with ppo + reward = [torch.tensor([1.0])] + # train model - this should work fine + train_stats = ppo_trainer.step([query_tensor[0]], [response_tensor[0]], reward) + + # check gradients + for name, param in ppo_trainer.model.named_parameters(): + self.assertTrue(param.grad is not None, f"Parameter {name} has no gradient") + + # ref model should not be trained + for name, param in ppo_trainer.ref_model.named_parameters(): + self.assertTrue(param.grad is None, f"Parameter {name} has a gradient") + + # check train stats + for stat in EXPECTED_STATS: + self.assertTrue(stat in train_stats, f"Train stats should contain {stat}") + + def test_loss_trainer(self): + """ + Test if the loss trainer works fine + """ + # initialize dataset + dummy_dataset = self._init_dummy_dataset() + + self.gpt2_model.eval() + + ppo_trainer = PPOTrainer( + config=self.ppo_config, + model=self.gpt2_model, + ref_model=None, + tokenizer=self.gpt2_tokenizer, + dataset=dummy_dataset, + ) + + dummy_queries = [torch.tensor([1, 2, 3, 4]), torch.tensor([1, 2, 3, 4, 5, 6, 7])] + dummy_responses = [torch.tensor([5, 6, 7, 8, 9]), torch.tensor([8, 9, 10, 11, 12, 13])] + dummy_scores = torch.Tensor([1, 2]) + + ppo_trainer.config.mini_batch_size = 1 + ppo_trainer.config.batch_size = 1 + model_inputs = ppo_trainer.prepare_model_inputs(dummy_queries, dummy_responses) + all_logprobs, _, values, mask = ppo_trainer.batched_forward_pass( + self.gpt2_model, dummy_queries, dummy_responses, model_inputs + ) + + # dummy values + ref_logprobs = all_logprobs + 1 + logits = torch.exp(all_logprobs) + vpreds = values + 0.1 + + score, non_score = ppo_trainer.compute_rewards(dummy_scores, all_logprobs, ref_logprobs, mask) + values, advantages, returns = ppo_trainer.compute_advantages(values, score, mask) + + # just make sure a dummy loss is computed + idx = 0 + pg_loss, v_loss, _ = ppo_trainer.loss( + all_logprobs[idx].unsqueeze(0), + values[idx].unsqueeze(0), + logits[idx].unsqueeze(0), + vpreds[idx].unsqueeze(0), + ref_logprobs[idx].unsqueeze(0), + mask[idx].unsqueeze(0), + advantages[idx].unsqueeze(0), + returns[idx].unsqueeze(0), + ) + + self.assertAlmostEqual(pg_loss.item(), 2.0494, 4) + self.assertAlmostEqual(v_loss.item(), 0.07110, 4) + + # check if we get same results with masked parts removed + pg_loss_unmasked, v_loss_unmasked, _ = ppo_trainer.loss( + apply_mask(all_logprobs[idx], mask[idx]).unsqueeze(0), + apply_mask(values[idx], mask[idx]).unsqueeze(0), + apply_mask(logits[idx], mask[idx]).unsqueeze(0), + apply_mask(vpreds[idx], mask[idx]).unsqueeze(0), + apply_mask(ref_logprobs[idx], mask[idx]).unsqueeze(0), + apply_mask(mask[idx], mask[idx]).unsqueeze(0), + apply_mask(advantages[idx], mask[idx]).unsqueeze(0), + apply_mask(returns[idx], mask[idx]).unsqueeze(0), + ) + self.assertAlmostEqual(pg_loss_unmasked.item(), 2.0494, 4) + self.assertAlmostEqual(v_loss_unmasked.item(), 0.07110, 4) + + @parameterized.expand( + [ + ["gpt2"], + ["bloom"], + ["t5"], + ] + ) + def test_batched_forward_pass(self, name): + """ + Test if the loss trainer works fine + """ + # initialize dataset + dummy_dataset = self._init_dummy_dataset() + + dummy_queries = [torch.tensor([1, 2, 3, 4]), torch.tensor([1, 2, 3, 4, 5, 6, 7])] + dummy_responses = [torch.tensor([5, 6, 7, 8, 9]), torch.tensor([8, 9, 10, 11, 12, 13])] + + if name == "gpt2": + model = self.gpt2_model + tokenizer = self.gpt2_tokenizer + elif name == "bloom": + model = self.bloom_model + tokenizer = self.bloom_tokenizer + elif name == "t5": + model = self.t5_model + tokenizer = self.t5_tokenizer + + model.eval() + + ppo_trainer = PPOTrainer( + config=self.ppo_config, + model=model, + ref_model=None, + tokenizer=tokenizer, + dataset=dummy_dataset, + ) + + # we test all combinations of fwd_bs and bs: + # if fwd_bs=bs=1: no padding is applied and only one forward pass + # if fwd_bs=1/bs=2: padding is applied and results computed in two fwd passes + # if fwd_bs=bs=2: padding is applied and results computed in one fwd pass + + ppo_trainer.config.mini_batch_size = 1 + ppo_trainer.config.batch_size = 1 + + model_inputs = ppo_trainer.prepare_model_inputs([dummy_queries[0]], [dummy_responses[0]]) + logprobs_0, logits_0, values_0, mask_0 = ppo_trainer.batched_forward_pass( + model, [dummy_queries[0]], [dummy_responses[0]], model_inputs + ) + + ppo_trainer.config.batch_size = 2 + model_inputs = ppo_trainer.prepare_model_inputs(dummy_queries, dummy_responses) + logprobs_1, logits_1, values_1, mask_1 = ppo_trainer.batched_forward_pass( + model, dummy_queries, dummy_responses, model_inputs + ) + + ppo_trainer.config.mini_batch_size = 2 + model_inputs = ppo_trainer.prepare_model_inputs(dummy_queries, dummy_responses) + logprobs_2, logits_2, values_2, mask_2 = ppo_trainer.batched_forward_pass( + model, dummy_queries, dummy_responses, model_inputs + ) + + self.assertLessEqual(abs_diff_masked_tensors(logprobs_1, logprobs_2, mask_1, mask_2), 1e-4) + self.assertLessEqual(abs_diff_masked_tensors(values_1, values_2, mask_1, mask_2), 1e-4) + + self.assertLessEqual(abs_diff_masked_tensors(logprobs_0, logprobs_2[:1], mask_0, mask_2[:1]), 1e-4) + self.assertLessEqual(abs_diff_masked_tensors(values_0, values_2[:1], mask_0, mask_2[:1]), 1e-4) + + def test_ppo_trainer_max_grad_norm(self): + """ + Test if the `max_grad_norm` feature works as expected + """ + # initialize dataset + dummy_dataset = self._init_dummy_dataset() + + self.ppo_config.max_grad_norm = 0.00001 + ppo_trainer = PPOTrainer( + config=self.ppo_config, + model=self.gpt2_model, + ref_model=None, + tokenizer=self.gpt2_tokenizer, + dataset=dummy_dataset, + ) + + dummy_dataloader = ppo_trainer.dataloader + + # train model with ppo + for query_tensor, response_tensor in dummy_dataloader: + # define a reward for response + # (this could be any reward such as human feedback or output from another model) + reward = [torch.tensor(1.0), torch.tensor(0.0)] + # train model + _ = ppo_trainer.step([q for q in query_tensor], [r for r in response_tensor], reward) + break + + # check gradients + for name, param in ppo_trainer.model.named_parameters(): + self.assertTrue(param.grad is not None, f"Parameter {name} has no gradient") + self.assertTrue( + torch.all(param.grad.abs() <= self.ppo_config.max_grad_norm), + f"Parameter {name} has a gradient larger than max_grad_norm", + ) + + def test_ppo_trainer_kl_penalty(self): + dummy_dataset = self._init_dummy_dataset() + + log_probs = torch.Tensor([[0.5, 0.2, 0.1], [0.6, 0.2, 0.1]]) + ref_log_probs = torch.Tensor([[0.4, 0.3, 0.0], [0.7, 0.1, 0.3]]) + + ppo_trainer = PPOTrainer( + config=self.ppo_config, + model=self.gpt2_model, + ref_model=None, + tokenizer=self.gpt2_tokenizer, + dataset=dummy_dataset, + ) + + expected_output = torch.Tensor([[0.1000, -0.1000, 0.1000], [-0.1000, 0.1000, -0.2000]]) + self.assertTrue(torch.allclose(ppo_trainer._kl_penalty(log_probs, ref_log_probs), expected_output)) + + self.ppo_config.kl_penalty = "abs" + ppo_trainer = PPOTrainer( + config=self.ppo_config, + model=self.gpt2_model, + ref_model=None, + tokenizer=self.gpt2_tokenizer, + dataset=dummy_dataset, + ) + + expected_output = torch.Tensor([[0.1000, 0.1000, 0.1000], [0.1000, 0.1000, 0.2000]]) + self.assertTrue(torch.allclose(ppo_trainer._kl_penalty(log_probs, ref_log_probs), expected_output)) + + self.ppo_config.kl_penalty = "mse" + ppo_trainer = PPOTrainer( + config=self.ppo_config, + model=self.gpt2_model, + ref_model=None, + tokenizer=self.gpt2_tokenizer, + dataset=dummy_dataset, + ) + + expected_output = torch.Tensor([[0.0050, 0.0050, 0.0050], [0.0050, 0.0050, 0.0200]]) + self.assertTrue(torch.allclose(ppo_trainer._kl_penalty(log_probs, ref_log_probs), expected_output)) + + def test_ppo_trainer_full_kl_penalty(self): + # a few more extensive tests for the full kl option as it is more involved + dummy_dataset = self._init_dummy_dataset() + + self.ppo_config.kl_penalty = "full" + ppo_trainer = PPOTrainer( + config=self.ppo_config, + model=self.gpt2_model, + ref_model=None, + tokenizer=self.gpt2_tokenizer, + dataset=dummy_dataset, + ) + + # Test on tensors for size B,S,T = (1,2,3) + # test for when the two dists are the same + log_probs = torch.Tensor( + [ + [ + [0.1, 0.2, 0.7], + [0.3, 0.4, 0.3], + ] + ] + ).exp() + + ref_log_probs = torch.Tensor( + [ + [ + [0.1, 0.2, 0.7], + [0.3, 0.4, 0.3], + ] + ] + ).exp() + + expected_output = torch.Tensor( + [[0.0, 0.0]], + ) + output = ppo_trainer._kl_penalty(log_probs, ref_log_probs) + self.assertTrue(output.shape == (1, 2)) + self.assertTrue(torch.allclose(output, expected_output)) + + # test for when the two dists are almost not overlapping + log_probs = torch.Tensor( + [ + [ + [0.98, 0.01, 0.01], + [0.01, 0.98, 0.01], + ] + ] + ).log() + + ref_log_probs = torch.Tensor( + [ + [ + [0.01, 0.01, 0.98], + [0.01, 0.01, 0.98], + ] + ] + ).log() + + expected_output = torch.Tensor( + [[4.4474, 4.4474]], + ) + output = ppo_trainer._kl_penalty(log_probs, ref_log_probs) + self.assertTrue(output.shape == (1, 2)) + self.assertTrue(torch.allclose(output, expected_output)) + + # test for when the two dists are almost not overlapping + log_probs = torch.Tensor( + [ + [ + [0.49, 0.02, 0.49], + [0.49, 0.02, 0.49], + ] + ] + ).log() + + ref_log_probs = torch.Tensor( + [ + [ + [0.01, 0.98, 0.01], + [0.49, 0.02, 0.49], + ] + ] + ).log() + + expected_output = torch.Tensor( + [[3.7361, 0.0]], + ) + output = ppo_trainer._kl_penalty(log_probs, ref_log_probs) + self.assertTrue(output.shape == (1, 2)) + self.assertTrue(torch.allclose(output, expected_output, atol=1e-4)) + + @require_peft + @mark.peft_test + def test_peft_model_ppo_trainer(self): + from peft import LoraConfig, get_peft_model + from transformers import AutoModelForCausalLM + + lora_config = LoraConfig( + r=16, + lora_alpha=32, + lora_dropout=0.05, + bias="none", + task_type="CAUSAL_LM", + ) + gpt2_model = AutoModelForCausalLM.from_pretrained(self.model_id) + + # this line is very important + def make_inputs_require_grad(module, input, output): + output.requires_grad_(True) + + gpt2_model.get_input_embeddings().register_forward_hook(make_inputs_require_grad) + + peft_model = get_peft_model(gpt2_model, lora_config) + model = AutoModelForCausalLMWithValueHead.from_pretrained(peft_model) + + dummy_dataset = self._init_dummy_dataset() + self.ppo_config.batch_size = 2 + self.ppo_config.mini_batch_size = 1 + + ppo_trainer = PPOTrainer( + config=self.ppo_config, + model=model, + ref_model=None, + tokenizer=self.gpt2_tokenizer, + dataset=dummy_dataset, + ) + + self.assertTrue(ppo_trainer.ref_model is None) + + dummy_dataloader = ppo_trainer.dataloader + + # train model with ppo + for query_tensor, response_tensor in dummy_dataloader: + # define a reward for response + # (this could be any reward such as human feedback or output from another model) + reward = [torch.tensor(1.0), torch.tensor(0.0)] + # train model by running a step twice + _ = ppo_trainer.step([q for q in query_tensor], [r for r in response_tensor], reward) + + ppo_trainer.model.train() + ppo_trainer.model.gradient_checkpointing_enable() + _ = ppo_trainer.step([q for q in query_tensor], [r for r in response_tensor], reward) + break + + # check gradients + for name, param in model.named_parameters(): + if "lora" in name or "v_head" in name: + self.assertTrue(param.grad is not None, f"Parameter {name} has a no gradient") + else: + self.assertTrue(param.grad is None, f"Parameter {name} has a gradient") + + @require_peft + @mark.peft_test + def test_peft_model_ppo_adapter_rm_trainer(self): + from peft import LoraConfig, get_peft_model + from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification + + dummy_inputs = torch.LongTensor([[1, 2, 3, 4, 5], [1, 2, 3, 4, 5]]) + rm_lora_config = LoraConfig( + r=16, + lora_alpha=32, + lora_dropout=0.05, + bias="none", + task_type="SEQ_CLS", + ) + + reward_model = AutoModelForSequenceClassification.from_pretrained(self.model_id) + reward_model = get_peft_model(reward_model, rm_lora_config) + dummy_optim = torch.optim.Adam(filter(lambda p: p.requires_grad, reward_model.parameters()), lr=1e-3) + + previous_rm_logits = reward_model(dummy_inputs).logits + loss = previous_rm_logits.mean() + loss.backward() + + dummy_optim.step() + reward_model.eval() + + original_rm_logits = reward_model(dummy_inputs).logits + + with tempfile.TemporaryDirectory() as tmpdirname: + reward_model.save_pretrained(tmpdirname) + + lora_config = LoraConfig( + r=16, + lora_alpha=32, + lora_dropout=0.05, + bias="none", + task_type="CAUSAL_LM", + ) + gpt2_model = AutoModelForCausalLM.from_pretrained(self.model_id) + + # this line is very important + def make_inputs_require_grad(module, input, output): + output.requires_grad_(True) + + gpt2_model.get_input_embeddings().register_forward_hook(make_inputs_require_grad) + + peft_model = get_peft_model(gpt2_model, lora_config) + model = AutoModelForCausalLMWithValueHead.from_pretrained( + peft_model, + reward_adapter=tmpdirname, + ) + + dummy_dataset = self._init_dummy_dataset() + self.ppo_config.batch_size = 2 + self.ppo_config.mini_batch_size = 1 + + ppo_trainer = PPOTrainer( + config=self.ppo_config, + model=model, + ref_model=None, + tokenizer=self.gpt2_tokenizer, + dataset=dummy_dataset, + ) + + self.assertTrue(ppo_trainer.ref_model is None) + + dummy_dataloader = ppo_trainer.dataloader + + # train model with ppo + for query_tensor, response_tensor in dummy_dataloader: + # define a reward for response + # (this could be any reward such as human feedback or output from another model) + reward = [torch.tensor(1.0), torch.tensor(0.0)] + # train model by running a step twice + _ = ppo_trainer.step([q for q in query_tensor], [r for r in response_tensor], reward) + + ppo_trainer.model.train() + ppo_trainer.model.gradient_checkpointing_enable() + _ = ppo_trainer.step([q for q in query_tensor], [r for r in response_tensor], reward) + break + + new_logits = ppo_trainer.model.compute_reward_score(dummy_inputs) + self.assertTrue(not torch.allclose(previous_rm_logits, new_logits[:, -1, :])) + self.assertTrue(torch.allclose(original_rm_logits, new_logits[:, -1, :])) + + # check gradients + for name, param in model.named_parameters(): + if ("lora" in name or "v_head" in name) and ("reward" not in name): + self.assertTrue(param.grad is not None, f"Parameter {name} has a no gradient") + else: + self.assertTrue(param.grad is None, f"Parameter {name} has a gradient") + + @unittest.skip("Fix by either patching `whomai()` to work in the staging endpoint or use a dummy prod user.") + def test_push_to_hub(self): + REPO_NAME = "test-ppo-trainer" + repo_id = f"{CI_HUB_USER}/{REPO_NAME}" + + ppo_trainer = PPOTrainer( + config=self.ppo_config, + model=self.gpt2_model, + ref_model=None, + tokenizer=self.gpt2_tokenizer, + dataset=self._init_dummy_dataset(), + ) + with tempfile.TemporaryDirectory(): + url = ppo_trainer.push_to_hub(repo_id=repo_id, token=self._token, api_endpoint=CI_HUB_ENDPOINT) + # Extract repo_name from the url + re_search = re.search(CI_HUB_ENDPOINT + r"/([^/]+/[^/]+)/", url) + self.assertTrue(re_search is not None) + hub_repo_id = re_search.groups()[0] + # Check we created a Hub repo + self.assertEqual(hub_repo_id, repo_id) + # Ensure all files are present + files = sorted(self._api.list_repo_files(hub_repo_id)) + assert all( + fnmatch.fnmatch(file, expected_file) + for file, expected_file in zip( + files, + [ + ".gitattributes", + "README.md", + "config.json", + "merges.txt", + "pytorch_model.bin", + "special_tokens_map.json", + "tokenizer_config.json", + "vocab.json", + ], + ) + ) + + @require_peft + @require_torch_multi_gpu + @mark.peft_test + def test_peft_model_ppo_trainer_multi_gpu(self): + from peft import LoraConfig, get_peft_model + from transformers import AutoModelForCausalLM + + lora_config = LoraConfig( + r=16, + lora_alpha=32, + lora_dropout=0.05, + bias="none", + task_type="CAUSAL_LM", + ) + gpt2_model = AutoModelForCausalLM.from_pretrained( + "gpt2", device_map="balanced", max_memory={0: "500MB", 1: "500MB"} + ) + + self.assertTrue(set(gpt2_model.hf_device_map.values()) == {0, 1}) + + # this line is very important + def make_inputs_require_grad(module, input, output): + output.requires_grad_(True) + + gpt2_model.get_input_embeddings().register_forward_hook(make_inputs_require_grad) + + peft_model = get_peft_model(gpt2_model, lora_config) + model = AutoModelForCausalLMWithValueHead.from_pretrained(peft_model) + + self.assertTrue(model.is_sequential_parallel) + + dummy_dataset = self._init_dummy_dataset() + self.ppo_config.batch_size = 2 + self.ppo_config.mini_batch_size = 1 + + ppo_trainer = PPOTrainer( + config=self.ppo_config, + model=model, + ref_model=None, + tokenizer=self.gpt2_tokenizer, + dataset=dummy_dataset, + ) + + self.assertTrue(ppo_trainer.ref_model is None) + + dummy_dataloader = ppo_trainer.dataloader + + # train model with ppo + for query_tensor, response_tensor in dummy_dataloader: + # define a reward for response + # (this could be any reward such as human feedback or output from another model) + reward = [torch.tensor(1.0), torch.tensor(0.0)] + # train model by running a step twice + _ = ppo_trainer.step([q for q in query_tensor], [r for r in response_tensor], reward) + + ppo_trainer.model.train() + ppo_trainer.model.gradient_checkpointing_enable() + _ = ppo_trainer.step([q for q in query_tensor], [r for r in response_tensor], reward) + break + + # check gradients + for name, param in model.named_parameters(): + if "lora" in name or "v_head" in name: + self.assertTrue(param.grad is not None, f"Parameter {name} has a no gradient") + else: + self.assertTrue(param.grad is None, f"Parameter {name} has a gradient") + + def test_generation(self): + dummy_dataset = self._init_dummy_dataset() + + model = AutoModelForCausalLMWithValueHead.from_pretrained("gpt2") + tokenizer = AutoTokenizer.from_pretrained("gpt2") + + ppo_trainer = PPOTrainer( + config=self.ppo_config, + model=model, + ref_model=None, + tokenizer=tokenizer, + dataset=dummy_dataset, + ) + + input_texts = ["this is a test", "this is another, longer test"] + + generation_kwargs = {"do_sample": False, "max_new_tokens": 4, "pad_token_id": tokenizer.eos_token_id} + + tokenizer.pad_token = tokenizer.eos_token + + model_inputs = [tokenizer(txt, return_tensors="pt").input_ids.squeeze() for txt in input_texts] + + generations_batched = ppo_trainer.generate(model_inputs, batch_size=2, **generation_kwargs) + generations_batched = tokenizer.batch_decode(generations_batched) + + generations_single = [ppo_trainer.generate(inputs, **generation_kwargs).squeeze() for inputs in model_inputs] + generations_single = tokenizer.batch_decode(generations_single) + + self.assertEqual(generations_single, generations_batched) + + def test_grad_accumulation(self): + dummy_dataset = self._init_dummy_dataset() + + torch.manual_seed(0) + gpt2_model = AutoModelForCausalLMWithValueHead.from_pretrained(self.model_id, summary_dropout_prob=0.0) + gpt2_model_clone = copy.deepcopy(gpt2_model) + + self.ppo_config.mini_batch_size = 2 + self.ppo_config.ppo_epochs = 1 + + ppo_trainer = PPOTrainer( + config=self.ppo_config, + model=gpt2_model, + ref_model=None, + tokenizer=self.gpt2_tokenizer, + dataset=dummy_dataset, + ) + + dummy_dataloader = ppo_trainer.dataloader + + # train model with ppo + for query_tensor, response_tensor in dummy_dataloader: + # define a reward for response + # (this could be any reward such as human feedback or output from another model) + reward = [torch.tensor(1.0), torch.tensor(1.0)] + # train model by running a step twice + _ = ppo_trainer.step([q for q in query_tensor], [r for r in response_tensor], reward) + break + + model_grad = gpt2_model.v_head.summary.weight + + self.ppo_config.mini_batch_size = 1 + self.ppo_config.gradient_accumulation_steps = 2 + + ppo_trainer = PPOTrainer( + config=self.ppo_config, + model=gpt2_model_clone, + ref_model=None, + tokenizer=self.gpt2_tokenizer, + dataset=dummy_dataset, + ) + + dummy_dataloader = ppo_trainer.dataloader + + # train model with ppo + for query_tensor, response_tensor in dummy_dataloader: + # define a reward for response + # (this could be any reward such as human feedback or output from another model) + reward = [torch.tensor(1.0), torch.tensor(1.0)] + # train model by running a step twice + _ = ppo_trainer.step([q for q in query_tensor], [r for r in response_tensor], reward) + break + + model_grad_acc = gpt2_model_clone.v_head.summary.weight + self.assertTrue(torch.allclose(model_grad_acc, model_grad, rtol=1e-3, atol=1e-3)) + + @unittest.skip("Fix by either patching `whomai()` to work in the staging endpoint or use a dummy prod user.") + def test_push_to_hub_if_best_reward(self): + REPO_NAME = "test-ppo-trainer" + repo_id = f"{CI_HUB_USER}/{REPO_NAME}" + + dummy_dataset = self._init_dummy_dataset() + + push_to_hub_if_best_kwargs = {"repo_id": repo_id} + + ppo_config = PPOConfig( + batch_size=2, + mini_batch_size=1, + log_with=None, + push_to_hub_if_best_kwargs=push_to_hub_if_best_kwargs, + compare_steps=1, + ) + + ppo_trainer = PPOTrainer( + config=ppo_config, + model=self.gpt2_model, + ref_model=self.gpt2_model_ref, + tokenizer=self.gpt2_tokenizer, + dataset=dummy_dataset, + ) + + dummy_dataloader = ppo_trainer.dataloader + # train model with ppo + for query_tensor, response_tensor in dummy_dataloader: + # define a reward for response + # (this could be any reward such as human feedback or output from another model) + reward = [torch.tensor(1.0), torch.tensor(0.0)] + # train model + _ = ppo_trainer.step([q for q in query_tensor], [r for r in response_tensor], reward) + break + + def test_batch_size_check(self): + with pytest.raises(ValueError): + PPOConfig(batch_size=2, mini_batch_size=2, gradient_accumulation_steps=2) diff --git a/tests/test_reward_trainer.py b/tests/test_reward_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..1cf63c945f140b4dce5791496349f2bd1bd9a1fb --- /dev/null +++ b/tests/test_reward_trainer.py @@ -0,0 +1,314 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import tempfile +import unittest + +import torch +from datasets import Dataset +from transformers import AutoModelForSequenceClassification, AutoTokenizer, EvalPrediction + +from trl import RewardConfig, RewardTrainer +from trl.trainer import compute_accuracy + +from .testing_utils import require_peft + + +class RewardTrainerTester(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.model_id = "trl-internal-testing/dummy-GPT2-correct-vocab" + cls.model = AutoModelForSequenceClassification.from_pretrained(cls.model_id) + cls.tokenizer = AutoTokenizer.from_pretrained(cls.model_id) + cls.tokenizer.pad_token = cls.tokenizer.eos_token + + def test_accuracy_metrics(self): + dummy_eval_predictions = EvalPrediction(torch.FloatTensor([[0.1, 0.9], [0.9, 0.1]]), torch.LongTensor([0, 0])) + accuracy = compute_accuracy(dummy_eval_predictions) + self.assertEqual(accuracy["accuracy"], 0.5) + + def test_reward_trainer(self): + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = RewardConfig( + output_dir=tmp_dir, + per_device_train_batch_size=2, + max_steps=3, + remove_unused_columns=False, + gradient_accumulation_steps=4, + learning_rate=9e-1, + evaluation_strategy="steps", + ) + + # fmt: off + dummy_dataset_dict = { + "input_ids_chosen": [ + torch.LongTensor([0, 1, 2,]), + torch.LongTensor([1, 2]), + torch.LongTensor([0, 1, 2,]), + torch.LongTensor([1, 2]), + ], + "attention_mask_chosen": [ + torch.LongTensor([1, 1, 1]), + torch.LongTensor([1, 0]), + torch.LongTensor([1, 1, 1]), + torch.LongTensor([1, 0]), + ], + "input_ids_rejected": [ + torch.LongTensor([0, 2,]), + torch.LongTensor([1, 2, 0]), + torch.LongTensor([0, 2,]), + torch.LongTensor([1, 2, 0]), + ], + "attention_mask_rejected": [ + torch.LongTensor([1, 1]), + torch.LongTensor([1, 1, 0]), + torch.LongTensor([1, 1]), + torch.LongTensor([1, 1, 1]), + ], + } + # fmt: on + dummy_dataset = Dataset.from_dict(dummy_dataset_dict) + + trainer = RewardTrainer( + model=self.model, + args=training_args, + tokenizer=self.tokenizer, + train_dataset=dummy_dataset, + eval_dataset=dummy_dataset, + ) + + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + trainer.train() + + self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + + # check the params have changed + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + # check the params have changed - ignore 0 biases + if param.sum() != 0: + self.assertFalse(torch.equal(param, new_param)) + + preds = trainer.predict(dummy_dataset) + self.assertEqual(preds.predictions.shape, (4, 2)) + + @require_peft + def test_reward_trainer_peft(self): + import peft + from peft import LoraConfig, TaskType + + peft_version = peft.__version__ + + peft_config = LoraConfig( + task_type=TaskType.SEQ_CLS, + inference_mode=False, + r=8, + lora_alpha=32, + lora_dropout=0.1, + ) + + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = RewardConfig( + output_dir=tmp_dir, + per_device_train_batch_size=2, + max_steps=6, + remove_unused_columns=False, + gradient_accumulation_steps=2, + learning_rate=9e-1, + evaluation_strategy="steps", + ) + + # fmt: off + dummy_dataset_dict = { + "input_ids_chosen": [ + torch.LongTensor([0, 1, 2,]), + torch.LongTensor([1, 2]), + torch.LongTensor([0, 1, 2,]), + torch.LongTensor([1, 2]), + ], + "attention_mask_chosen": [ + torch.LongTensor([1, 1, 1]), + torch.LongTensor([1, 0]), + torch.LongTensor([1, 1, 1]), + torch.LongTensor([1, 0]), + ], + "input_ids_rejected": [ + torch.LongTensor([0, 2,]), + torch.LongTensor([1, 2, 0]), + torch.LongTensor([0, 2,]), + torch.LongTensor([1, 2, 0]), + ], + "attention_mask_rejected": [ + torch.LongTensor([1, 1]), + torch.LongTensor([1, 1, 0]), + torch.LongTensor([1, 1]), + torch.LongTensor([1, 1, 1]), + ], + } + # fmt: on + dummy_dataset = Dataset.from_dict(dummy_dataset_dict) + + trainer = RewardTrainer( + model=self.model, + args=training_args, + tokenizer=self.tokenizer, + train_dataset=dummy_dataset, + eval_dataset=dummy_dataset, + peft_config=peft_config, + ) + previous_trainable_params = {} + previous_non_trainable_params = {} + + # due to a change in the way the modules to save are dealt in PEFT. + trainable_params_name = ["lora", "score"] if peft_version < "0.3.0" else ["lora", "modules_to_save"] + + # check gradients are not None + for n, param in trainer.model.named_parameters(): + if any([t in n for t in trainable_params_name]): + previous_trainable_params[n] = param.clone() + else: + previous_non_trainable_params[n] = param.clone() + + trainer.train() + + self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + + # check the params have changed + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + self.assertFalse(torch.allclose(param, new_param, atol=1e-12, rtol=1e-12)) + + # check the non trainable params have not changed + for n, param in previous_non_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + self.assertTrue(torch.allclose(param, new_param, atol=1e-12, rtol=1e-12)) + + preds = trainer.predict(dummy_dataset) + self.assertEqual(preds.predictions.shape, (4, 2)) + + def test_reward_trainer_assert_value_error(self): + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = RewardConfig( + output_dir=tmp_dir, + per_device_train_batch_size=2, + max_steps=1, + remove_unused_columns=False, + ) + + dummy_dataset_dict = { + # fmt: off + "input_ids_b": [ + torch.LongTensor([0, 1, 2,]), + torch.LongTensor([1, 2]), + torch.LongTensor([0, 1, 2,]), + torch.LongTensor([1, 2]), + ], + "attention_mask_c": [ + torch.LongTensor([1, 1, 1]), + torch.LongTensor([1, 0]), + torch.LongTensor([1, 1, 1]), + torch.LongTensor([1, 0]), + ], + "input_ids_f": [ + torch.LongTensor([0, 2,]), + torch.LongTensor([1, 2, 0]), + torch.LongTensor([0, 2,]), + torch.LongTensor([1, 2, 0]), + ], + "attention_mask_g": [ + torch.LongTensor([1, 1]), + torch.LongTensor([1, 1, 0]), + torch.LongTensor([1, 1]), + torch.LongTensor([1, 1, 1]), + ], + # fmt: on + } + dummy_dataset = Dataset.from_dict(dummy_dataset_dict) + + trainer = RewardTrainer( + model=self.model, + args=training_args, + tokenizer=self.tokenizer, + train_dataset=dummy_dataset, + ) + + with self.assertRaises(ValueError): + trainer.train() + + training_args = RewardConfig( + output_dir=tmp_dir, + per_device_train_batch_size=2, + max_steps=1, + remove_unused_columns=True, + ) + + with self.assertWarns(UserWarning): + trainer = RewardTrainer( + model=self.model, + args=training_args, + tokenizer=self.tokenizer, + train_dataset=dummy_dataset, + ) + + def test_reward_trainer_margin(self): + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = RewardConfig( + output_dir=tmp_dir, + per_device_train_batch_size=2, + max_steps=3, + remove_unused_columns=False, + gradient_accumulation_steps=4, + learning_rate=9e-1, + evaluation_strategy="steps", + ) + + # fmt: off + dummy_dataset_dict = { + "input_ids_chosen": [ + torch.LongTensor([0, 1, 2,]), + ], + "attention_mask_chosen": [ + torch.LongTensor([1, 1, 1]), + ], + "input_ids_rejected": [ + torch.LongTensor([0, 2,]), + ], + "attention_mask_rejected": [ + torch.LongTensor([1, 1]), + ], + "margin": [ + torch.FloatTensor([1.0]), + ] + } + # fmt: on + dummy_dataset = Dataset.from_dict(dummy_dataset_dict) + + trainer = RewardTrainer( + model=self.model, + args=training_args, + tokenizer=self.tokenizer, + train_dataset=dummy_dataset, + eval_dataset=dummy_dataset, + ) + + batch = [dummy_dataset[0]] + batch = trainer.data_collator(batch) + loss, outputs = trainer.compute_loss(trainer.model, batch, return_outputs=True) + + self.assertAlmostEqual( + loss, + -torch.nn.functional.logsigmoid( + outputs["rewards_chosen"] - outputs["rewards_rejected"] - batch["margin"] + ).mean(), + ) diff --git a/tests/test_sft_trainer.py b/tests/test_sft_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..f430c7b48618a9eded0599ef2e1a53a0363dd7ef --- /dev/null +++ b/tests/test_sft_trainer.py @@ -0,0 +1,791 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import copy +import os +import tempfile +import unittest + +import numpy as np +import torch +from datasets import Dataset +from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments + +from trl import SFTTrainer +from trl.import_utils import is_peft_available +from trl.trainer import ConstantLengthDataset, DataCollatorForCompletionOnlyLM + +from .testing_utils import require_peft + + +def formatting_prompts_func(example): + text = f"### Question: {example['question']}\n ### Answer: {example['answer']}" + return text + + +def formatting_prompts_func_batched(example): + output_text = [] + for i, question in enumerate(example["question"]): + text = f"### Question: {question}\n ### Answer: {example['answer'][i]}" + output_text.append(text) + return output_text + + +if is_peft_available(): + from peft import LoraConfig, PeftModel + + +class SFTTrainerTester(unittest.TestCase): + r""" """ + + @classmethod + def setUpClass(cls): + cls.model_id = "trl-internal-testing/dummy-GPT2-correct-vocab" + cls.model = AutoModelForCausalLM.from_pretrained(cls.model_id) + cls.tokenizer = AutoTokenizer.from_pretrained(cls.model_id) + cls.tokenizer.pad_token = cls.tokenizer.eos_token + cls.dummy_dataset = Dataset.from_dict( + { + "question": [ + "Does llamas know how to code?", + "Does llamas know how to fly?", + "Does llamas know how to talk?", + "Does llamas know how to code?", + "Does llamas know how to fly?", + "Does llamas know how to talk?", + "Does llamas know how to swim?", + ], + "answer": [ + "Yes, llamas are very good at coding.", + "No, llamas can't fly.", + "Yes, llamas are very good at talking.", + "Yes, llamas are very good at coding.", + "No, llamas can't fly.", + "Yes, llamas are very good at talking.", + "No, llamas can't swim.", + ], + "text": [ + "### Question: Does llamas know how to code?\n ### Answer: Yes, llamas are very good at coding.", + "### Question: Does llamas know how to fly?\n ### Answer: No, llamas can't fly.", + "### Question: Does llamas know how to talk?\n ### Answer: Yes, llamas are very good at talking.", + "### Question: Does llamas know how to code?\n ### Answer: Yes, llamas are very good at coding.", + "### Question: Does llamas know how to fly?\n ### Answer: No, llamas can't fly.", + "### Question: Does llamas know how to talk?\n ### Answer: Yes, llamas are very good at talking.", + "### Question: Does llamas know how to swim?\n ### Answer: No, llamas can't swim.", + ], + } + ) + + cls.train_dataset = ConstantLengthDataset( + cls.tokenizer, + cls.dummy_dataset, + dataset_text_field=None, + formatting_func=formatting_prompts_func, + seq_length=16, + num_of_sequences=16, + ) + + cls.eval_dataset = ConstantLengthDataset( + cls.tokenizer, + cls.dummy_dataset, + dataset_text_field=None, + formatting_func=formatting_prompts_func, + seq_length=16, + num_of_sequences=16, + ) + + def test_constant_length_dataset(self): + formatted_dataset = ConstantLengthDataset( + self.tokenizer, + self.dummy_dataset, + dataset_text_field=None, + formatting_func=formatting_prompts_func, + ) + + self.assertTrue(len(formatted_dataset) == len(self.dummy_dataset)) + self.assertTrue(len(formatted_dataset) > 0) + + for example in formatted_dataset: + self.assertTrue("input_ids" in example) + self.assertTrue("labels" in example) + + self.assertTrue(len(example["input_ids"]) == formatted_dataset.seq_length) + self.assertTrue(len(example["labels"]) == formatted_dataset.seq_length) + + decoded_text = self.tokenizer.decode(example["input_ids"]) + self.assertTrue(("Question" in decoded_text) and ("Answer" in decoded_text)) + + def test_sft_trainer(self): + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = TrainingArguments( + output_dir=tmp_dir, + dataloader_drop_last=True, + evaluation_strategy="steps", + max_steps=4, + eval_steps=2, + save_steps=2, + per_device_train_batch_size=2, + ) + + trainer = SFTTrainer( + model=self.model_id, + args=training_args, + train_dataset=self.train_dataset, + eval_dataset=self.eval_dataset, + packing=True, + ) + + trainer.train() + + self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + self.assertIsNotNone(trainer.state.log_history[0]["eval_loss"]) + + self.assertTrue("model.safetensors" in os.listdir(tmp_dir + "/checkpoint-2")) + + def test_sft_trainer_uncorrect_data(self): + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = TrainingArguments( + output_dir=tmp_dir, + dataloader_drop_last=True, + evaluation_strategy="steps", + max_steps=2, + eval_steps=1, + save_steps=1, + per_device_train_batch_size=2, + ) + + with self.assertRaises(ValueError): + _ = SFTTrainer( + model=self.model, + args=training_args, + train_dataset=self.dummy_dataset, + packing=True, + ) + + # This should work + _ = SFTTrainer( + model=self.model, + args=training_args, + train_dataset=self.dummy_dataset, + formatting_func=formatting_prompts_func, + max_seq_length=32, # make sure there is at least 1 packed sequence + packing=True, + ) + + with self.assertRaises(ValueError): + # This should not work because not enough data for one sample + _ = SFTTrainer( + model=self.model, + args=training_args, + train_dataset=self.dummy_dataset, + formatting_func=formatting_prompts_func, + max_seq_length=1024, # make sure there is NOT at least 1 packed sequence + packing=True, + ) + + # This should not work as well + with self.assertRaises(ValueError): + _ = SFTTrainer( + model=self.model, + args=training_args, + train_dataset=self.dummy_dataset, + formatting_func=formatting_prompts_func, + packing=False, + ) + + # but this should work + _ = SFTTrainer( + model=self.model, + args=training_args, + train_dataset=self.dummy_dataset, + formatting_func=formatting_prompts_func_batched, + packing=False, + ) + + def test_sft_trainer_with_model_num_train_epochs(self): + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = TrainingArguments( + output_dir=tmp_dir, + dataloader_drop_last=True, + evaluation_strategy="steps", + max_steps=2, + eval_steps=1, + save_steps=1, + num_train_epochs=2, + per_device_train_batch_size=2, + ) + + trainer = SFTTrainer( + model=self.model, + args=training_args, + train_dataset=self.train_dataset, + eval_dataset=self.eval_dataset, + packing=True, + ) + + trainer.train() + + self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + self.assertIsNotNone(trainer.state.log_history[0]["eval_loss"]) + + self.assertTrue("model.safetensors" in os.listdir(tmp_dir + "/checkpoint-2")) + + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = TrainingArguments( + output_dir=tmp_dir, + dataloader_drop_last=True, + evaluation_strategy="steps", + max_steps=2, + save_steps=1, + num_train_epochs=2, + per_device_train_batch_size=2, + ) + + trainer = SFTTrainer( + model=self.model, + args=training_args, + train_dataset=self.dummy_dataset, + dataset_text_field="text", + max_seq_length=16, + num_of_sequences=16, + packing=True, + ) + + trainer.train() + + self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + + self.assertTrue("model.safetensors" in os.listdir(tmp_dir + "/checkpoint-2")) + + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = TrainingArguments( + output_dir=tmp_dir, + dataloader_drop_last=True, + evaluation_strategy="steps", + max_steps=2, + save_steps=1, + num_train_epochs=2, + per_device_train_batch_size=2, + ) + + trainer = SFTTrainer( + model=self.model, + args=training_args, + train_dataset=self.dummy_dataset, + dataset_text_field="text", + max_seq_length=16, + ) + + trainer.train() + + self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + + self.assertTrue("model.safetensors" in os.listdir(tmp_dir + "/checkpoint-1")) + + def test_sft_trainer_with_model(self): + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = TrainingArguments( + output_dir=tmp_dir, + dataloader_drop_last=True, + evaluation_strategy="steps", + max_steps=2, + eval_steps=1, + save_steps=1, + per_device_train_batch_size=2, + ) + + trainer = SFTTrainer( + model=self.model, + args=training_args, + train_dataset=self.train_dataset, + eval_dataset=self.eval_dataset, + packing=True, + ) + + trainer.train() + + self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + self.assertIsNotNone(trainer.state.log_history[0]["eval_loss"]) + + self.assertTrue("model.safetensors" in os.listdir(tmp_dir + "/checkpoint-2")) + + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = TrainingArguments( + output_dir=tmp_dir, + dataloader_drop_last=True, + evaluation_strategy="steps", + max_steps=2, + save_steps=1, + per_device_train_batch_size=2, + ) + + trainer = SFTTrainer( + model=self.model, + args=training_args, + train_dataset=self.dummy_dataset, + dataset_text_field="text", + max_seq_length=16, + num_of_sequences=16, + packing=True, + ) + + trainer.train() + + self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + + self.assertTrue("model.safetensors" in os.listdir(tmp_dir + "/checkpoint-2")) + + # with formatting_func + packed + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = TrainingArguments( + output_dir=tmp_dir, + dataloader_drop_last=True, + evaluation_strategy="steps", + max_steps=2, + save_steps=1, + per_device_train_batch_size=2, + ) + + trainer = SFTTrainer( + model=self.model, + args=training_args, + train_dataset=self.dummy_dataset, + formatting_func=formatting_prompts_func, + max_seq_length=16, + num_of_sequences=16, + packing=True, + ) + + trainer.train() + + self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + + self.assertTrue("model.safetensors" in os.listdir(tmp_dir + "/checkpoint-2")) + + # with formatting_func + packed + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = TrainingArguments( + output_dir=tmp_dir, + dataloader_drop_last=True, + evaluation_strategy="steps", + max_steps=2, + save_steps=1, + per_device_train_batch_size=2, + ) + + trainer = SFTTrainer( + model=self.model, + args=training_args, + train_dataset=self.dummy_dataset, + formatting_func=formatting_prompts_func_batched, + max_seq_length=16, + ) + + trainer.train() + + self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + + self.assertTrue("model.safetensors" in os.listdir(tmp_dir + "/checkpoint-2")) + + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = TrainingArguments( + output_dir=tmp_dir, + dataloader_drop_last=True, + evaluation_strategy="steps", + max_steps=2, + save_steps=1, + per_device_train_batch_size=2, + ) + + trainer = SFTTrainer( + model=self.model, + args=training_args, + train_dataset=self.dummy_dataset, + dataset_text_field="text", + max_seq_length=16, + ) + + trainer.train() + + self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + + self.assertTrue("model.safetensors" in os.listdir(tmp_dir + "/checkpoint-1")) + + def test_sft_trainer_with_multiple_eval_datasets(self): + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = TrainingArguments( + output_dir=tmp_dir, + dataloader_drop_last=True, + evaluation_strategy="steps", + max_steps=1, + eval_steps=1, + save_steps=1, + per_device_train_batch_size=2, + ) + + trainer = SFTTrainer( + model=self.model_id, + args=training_args, + train_dataset=self.train_dataset, + eval_dataset={ + "data1": self.eval_dataset, + "data2": self.eval_dataset, + }, + packing=True, + ) + + trainer.train() + + self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + self.assertIsNotNone(trainer.state.log_history[0]["eval_data1_loss"]) + self.assertIsNotNone(trainer.state.log_history[1]["eval_data2_loss"]) + + self.assertTrue("model.safetensors" in os.listdir(tmp_dir + "/checkpoint-1")) + + def test_data_collator_completion_lm(self): + response_template = "### Response:\n" + data_collator = DataCollatorForCompletionOnlyLM(response_template, tokenizer=self.tokenizer, mlm=False) + + text = """\n\n### Instructions:\nHello all this should be masked\n\n### Response:\nI have not been masked correctly.""" + encoded_text = self.tokenizer(text) + + examples = [encoded_text] + + batch = data_collator(examples) + labels = batch["labels"] + last_pad_idx = np.where(labels == -100)[1][-1] + result_text = self.tokenizer.decode(batch["input_ids"][0, last_pad_idx + 1 :]) + self.assertEqual(result_text, "I have not been masked correctly.") + + def test_data_collator_completion_lm_with_multiple_text(self): + tokenizer = copy.deepcopy(self.tokenizer) + tokenizer.padding_side = "left" + + response_template = "### Response:\n" + data_collator = DataCollatorForCompletionOnlyLM(response_template, tokenizer=tokenizer, mlm=False) + + text1 = """\n\n### Instructions:\nHello all this should be masked\n\n### Response:\nI have not been masked correctly.""" + text2 = """\n\n### Instructions:\nThis is another longer text that should also be masked. This text is significantly longer than the previous one.\n\n### Response:\nI have not been masked correctly.""" + + encoded_text1 = tokenizer(text1) + encoded_text2 = tokenizer(text2) + + examples = [encoded_text1, encoded_text2] + + batch = data_collator(examples) + + for i in range(2): + labels = batch["labels"][i] + last_pad_idx = np.where(labels == -100)[0][-1] + result_text = tokenizer.decode(batch["input_ids"][i, last_pad_idx + 1 :]) + self.assertEqual(result_text, "I have not been masked correctly.") + + def test_data_collator_chat_completion_lm(self): + instruction_template = "### Human:" + assistant_template = "### Assistant:" + data_collator = DataCollatorForCompletionOnlyLM( + response_template=assistant_template, + instruction_template=instruction_template, + tokenizer=self.tokenizer, + mlm=False, + ) + + text = """### Human: Hello all this should be masked.### Assistant: I should not be masked.### Human: All this should be masked too.### Assistant: I should not be masked too.""" + encoded_text = self.tokenizer(text) + + examples = [encoded_text] + + batch = data_collator(examples) + labels = batch["labels"] + non_masked_tokens = batch["input_ids"][labels != -100] + result_text = self.tokenizer.decode(non_masked_tokens) + self.assertEqual(result_text, " I should not be masked. I should not be masked too.") + + def test_data_collator_chat_completion_lm_with_multiple_text(self): + tokenizer = copy.deepcopy(self.tokenizer) + tokenizer.padding_side = "left" + + instruction_template = "### Human:" + assistant_template = "### Assistant:" + data_collator = DataCollatorForCompletionOnlyLM( + response_template=assistant_template, + instruction_template=instruction_template, + tokenizer=tokenizer, + mlm=False, + ) + + text1 = """### Human: Hello all this should be masked.### Assistant: I should not be masked.""" + text2 = """### Human: Hello all this should be masked.### Assistant: I should not be masked.### Human: All this should be masked too.### Assistant: I should not be masked too.""" + encoded_text1 = tokenizer(text1) + encoded_text2 = tokenizer(text2) + + examples = [encoded_text1, encoded_text2] + + batch = data_collator(examples) + labels = batch["labels"] + input_ids = batch["input_ids"] + + non_masked_tokens1 = input_ids[0][labels[0] != -100] + result_text1 = tokenizer.decode(non_masked_tokens1) + self.assertEqual(result_text1, " I should not be masked.") + + non_masked_tokens2 = input_ids[1][labels[1] != -100] + result_text2 = tokenizer.decode(non_masked_tokens2) + self.assertEqual(result_text2, " I should not be masked. I should not be masked too.") + + def test_sft_trainer_infinite_with_model(self): + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = TrainingArguments( + output_dir=tmp_dir, + dataloader_drop_last=True, + evaluation_strategy="steps", + max_steps=5, + eval_steps=1, + save_steps=1, + per_device_train_batch_size=2, + ) + + trainer = SFTTrainer( + model=self.model, + args=training_args, + train_dataset=self.train_dataset, + eval_dataset=self.eval_dataset, + packing=True, + max_seq_length=500, + ) + + self.assertTrue(trainer.train_dataset.infinite) + + trainer.train() + + self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + self.assertIsNotNone(trainer.state.log_history[0]["eval_loss"]) + + # make sure the trainer did 5 steps + self.assertTrue("model.safetensors" in os.listdir(tmp_dir + "/checkpoint-5")) + + def test_sft_trainer_infinite_with_model_epochs(self): + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = TrainingArguments( + output_dir=tmp_dir, + dataloader_drop_last=True, + num_train_epochs=1, + per_device_train_batch_size=2, + save_strategy="epoch", + ) + + trainer = SFTTrainer( + model=self.model, + args=training_args, + train_dataset=self.train_dataset, + eval_dataset=self.eval_dataset, + packing=True, + max_seq_length=500, + ) + + self.assertFalse(trainer.train_dataset.infinite) + + trainer.train() + + self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + + # make sure the trainer did 5 steps + self.assertTrue("model.safetensors" in os.listdir(tmp_dir + "/checkpoint-4")) + + def test_sft_trainer_with_model_neftune(self): + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = TrainingArguments( + output_dir=tmp_dir, + dataloader_drop_last=True, + evaluation_strategy="steps", + max_steps=2, + eval_steps=1, + save_steps=1, + per_device_train_batch_size=2, + ) + + trainer = SFTTrainer( + model=self.model, + args=training_args, + train_dataset=self.train_dataset, + eval_dataset=self.eval_dataset, + neftune_noise_alpha=5, + packing=True, + ) + + trainer.model = trainer._trl_activate_neftune(trainer.model) + + device = trainer.model.get_input_embeddings().weight.device + trainer.model.train() + + torch.random.manual_seed(42) + embeds_neftune = trainer.model.get_input_embeddings()(torch.LongTensor([[1, 0, 1]]).to(device)) + + torch.random.manual_seed(24) + embeds_neftune_2 = trainer.model.get_input_embeddings()(torch.LongTensor([[1, 0, 1]]).to(device)) + + self.assertFalse(torch.allclose(embeds_neftune, embeds_neftune_2)) + self.assertTrue(len(trainer.model.get_input_embeddings()._forward_hooks) > 0) + + trainer.neftune_hook_handle.remove() + + trainer.train() + + # Make sure forward pass works fine + _ = trainer.model(torch.LongTensor([[1, 0, 1]]).to(device)) + self.assertTrue(len(trainer.model.get_input_embeddings()._forward_hooks) == 0) + + @require_peft + def test_peft_sft_trainer(self): + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = TrainingArguments( + output_dir=tmp_dir, + dataloader_drop_last=True, + evaluation_strategy="steps", + max_steps=4, + eval_steps=2, + save_steps=2, + per_device_train_batch_size=2, + ) + + peft_config = LoraConfig( + r=16, + lora_alpha=32, + lora_dropout=0.05, + bias="none", + task_type="CAUSAL_LM", + ) + + trainer = SFTTrainer( + model=self.model_id, + args=training_args, + train_dataset=self.train_dataset, + eval_dataset=self.eval_dataset, + peft_config=peft_config, + packing=True, + ) + + self.assertTrue(isinstance(trainer.model, PeftModel)) + + trainer.train() + + self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + self.assertIsNotNone(trainer.state.log_history[0]["eval_loss"]) + + self.assertTrue("adapter_model.safetensors" in os.listdir(tmp_dir + "/checkpoint-2")) + self.assertTrue("adapter_config.json" in os.listdir(tmp_dir + "/checkpoint-2")) + self.assertTrue("model.safetensors" not in os.listdir(tmp_dir + "/checkpoint-2")) + + @require_peft + def test_peft_sft_trainer_gc(self): + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = TrainingArguments( + output_dir=tmp_dir, + dataloader_drop_last=True, + evaluation_strategy="steps", + max_steps=4, + eval_steps=2, + save_steps=2, + per_device_train_batch_size=2, + gradient_checkpointing=True, + ) + + peft_config = LoraConfig( + r=16, + lora_alpha=32, + lora_dropout=0.05, + bias="none", + task_type="CAUSAL_LM", + ) + + trainer = SFTTrainer( + model=self.model_id, + args=training_args, + train_dataset=self.train_dataset, + eval_dataset=self.eval_dataset, + peft_config=peft_config, + packing=True, + ) + + self.assertTrue(isinstance(trainer.model, PeftModel)) + + trainer.train() + + self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + self.assertIsNotNone(trainer.state.log_history[0]["eval_loss"]) + + self.assertTrue("adapter_model.safetensors" in os.listdir(tmp_dir + "/checkpoint-2")) + self.assertTrue("adapter_config.json" in os.listdir(tmp_dir + "/checkpoint-2")) + self.assertTrue("model.safetensors" not in os.listdir(tmp_dir + "/checkpoint-2")) + + @require_peft + def test_peft_sft_trainer_neftune(self): + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = TrainingArguments( + output_dir=tmp_dir, + dataloader_drop_last=True, + evaluation_strategy="steps", + max_steps=4, + eval_steps=2, + save_steps=2, + per_device_train_batch_size=2, + ) + + peft_config = LoraConfig( + r=16, + lora_alpha=32, + lora_dropout=0.05, + bias="none", + task_type="CAUSAL_LM", + ) + + trainer = SFTTrainer( + model=self.model_id, + args=training_args, + train_dataset=self.train_dataset, + eval_dataset=self.eval_dataset, + peft_config=peft_config, + neftune_noise_alpha=5, + packing=True, + ) + + trainer.model = trainer._trl_activate_neftune(trainer.model) + + self.assertTrue(isinstance(trainer.model, PeftModel)) + + device = trainer.model.get_input_embeddings().weight.device + trainer.model.train() + + torch.random.manual_seed(42) + embeds_neftune = trainer.model.get_input_embeddings()(torch.LongTensor([[1, 0, 1]]).to(device)) + + torch.random.manual_seed(24) + embeds_neftune_2 = trainer.model.get_input_embeddings()(torch.LongTensor([[1, 0, 1]]).to(device)) + + self.assertFalse(torch.allclose(embeds_neftune, embeds_neftune_2)) + self.assertTrue(len(trainer.model.get_input_embeddings()._forward_hooks) > 0) + + trainer.neftune_hook_handle.remove() + + trainer.train() + + self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + self.assertIsNotNone(trainer.state.log_history[0]["eval_loss"]) + + self.assertTrue("adapter_model.safetensors" in os.listdir(tmp_dir + "/checkpoint-2")) + self.assertTrue("adapter_config.json" in os.listdir(tmp_dir + "/checkpoint-2")) + self.assertTrue("model.safetensors" not in os.listdir(tmp_dir + "/checkpoint-2")) + + # Make sure forward pass works fine to check if embeddings forward is not broken. + _ = trainer.model(torch.LongTensor([[1, 0, 1]]).to(device)) + self.assertTrue(len(trainer.model.get_input_embeddings()._forward_hooks) == 0) diff --git a/tests/testing_constants.py b/tests/testing_constants.py new file mode 100644 index 0000000000000000000000000000000000000000..164782130add699864c8312ff82596ecd5eb87d1 --- /dev/null +++ b/tests/testing_constants.py @@ -0,0 +1,19 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +CI_HUB_USER = "__DUMMY_TRANSFORMERS_USER__" +CI_HUB_USER_FULL_NAME = "Dummy User" +CI_HUB_USER_TOKEN = "hf_94wBhPGp6KrrTH3KDchhKpRxZwd6dmHWLL" + +CI_HUB_ENDPOINT = "https://hub-ci.huggingface.co" diff --git a/tests/testing_utils.py b/tests/testing_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..f3988de4c9f2498ac796b9b443077583e06a44cf --- /dev/null +++ b/tests/testing_utils.py @@ -0,0 +1,84 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import unittest + +import torch + +from trl import is_diffusers_available, is_peft_available, is_wandb_available, is_xpu_available + + +def require_peft(test_case): + """ + Decorator marking a test that requires peft. Skips the test if peft is not available. + """ + if not is_peft_available(): + test_case = unittest.skip("test requires peft")(test_case) + return test_case + + +def require_diffusers(test_case): + """ + Decorator marking a test that requires diffusers. Skips the test if diffusers is not available. + """ + if not is_diffusers_available(): + test_case = unittest.skip("test requires diffusers")(test_case) + return test_case + + +def require_wandb(test_case, required: bool = True): + """ + Decorator marking a test that requires wandb. Skips the test if wandb is not available. + """ + # XOR, i.e.: + # skip if available and required = False and + # skip if not available and required = True + if is_wandb_available() ^ required: + test_case = unittest.skip("test requires wandb")(test_case) + return test_case + + +def require_no_wandb(test_case): + """ + Decorator marking a test that requires no wandb. Skips the test if wandb is available. + """ + return require_wandb(test_case, required=False) + + +def require_bitsandbytes(test_case): + """ + Decorator marking a test that requires bitsandbytes. Skips the test if bitsandbytes is not available. + """ + try: + import bitsandbytes # noqa: F401 + except ImportError: + test_case = unittest.skip("test requires bitsandbytes")(test_case) + return test_case + + +def require_torch_multi_gpu(test_case): + """ + Decorator marking a test that requires multiple GPUs. Skips the test if there aren't enough GPUs. + """ + if torch.cuda.device_count() < 2: + test_case = unittest.skip("test requires multiple GPUs")(test_case) + return test_case + + +def require_torch_multi_xpu(test_case): + """ + Decorator marking a test that requires multiple XPUs. Skips the test if there aren't enough XPUs. + """ + if torch.xpu.device_count() < 2 and is_xpu_available(): + test_case = unittest.skip("test requires multiple XPUs")(test_case) + return test_case diff --git a/trl/__init__.py b/trl/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a00f3a684db658bf5241025978aaf31773180938 --- /dev/null +++ b/trl/__init__.py @@ -0,0 +1,34 @@ +# flake8: noqa + +__version__ = "0.7.5.dev0" + +from .core import set_seed +from .environment import TextEnvironment, TextHistory +from .extras import BestOfNSampler +from .import_utils import is_diffusers_available, is_peft_available, is_wandb_available, is_xpu_available +from .models import ( + AutoModelForCausalLMWithValueHead, + AutoModelForSeq2SeqLMWithValueHead, + PreTrainedModelWrapper, + create_reference_model, +) +from .trainer import ( + DataCollatorForCompletionOnlyLM, + DPOTrainer, + IterativeSFTTrainer, + PPOConfig, + PPOTrainer, + RewardConfig, + RewardTrainer, + SFTTrainer, +) + + +if is_diffusers_available(): + from .models import ( + DDPOPipelineOutput, + DDPOSchedulerOutput, + DDPOStableDiffusionPipeline, + DefaultDDPOStableDiffusionPipeline, + ) + from .trainer import DDPOConfig, DDPOTrainer diff --git a/trl/core.py b/trl/core.py new file mode 100644 index 0000000000000000000000000000000000000000..3180fa69ed76b1632427797a49d5177306246e69 --- /dev/null +++ b/trl/core.py @@ -0,0 +1,328 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import gc +import random +import warnings +from contextlib import contextmanager +from typing import List, Optional, Tuple, Union + +import numpy as np +import torch +import torch.nn.functional as F +from torch.nn.utils.rnn import pad_sequence +from transformers import top_k_top_p_filtering + +from .import_utils import is_xpu_available + + +try: + from collections.abc import Mapping +except ImportError: + from collections import Mapping + + +WANDB_PADDING = -1 + + +def flatten_dict(nested, sep="/"): + """Flatten dictionary and concatenate nested keys with separator.""" + + def rec(nest, prefix, into): + for k, v in nest.items(): + if sep in k: + raise ValueError(f"separator '{sep}' not allowed to be in key '{k}'") + if isinstance(v, Mapping): + rec(v, prefix + k + sep, into) + else: + into[prefix + k] = v + + flat = {} + rec(nested, "", flat) + return flat + + +def convert_to_scalar(stats): + """ + Converts the stats from a flattened dict to single scalar dicts + """ + tensorboard_stats = {} + for k, v in stats.items(): + # for tensorboard compatibility - arrays and tensors are ignored with tensorboard + # therefore we convert single element tensors to scalars + if (isinstance(v, torch.Tensor) or isinstance(v, np.ndarray)) and ( + len(v.shape) == 0 or (len(v.shape) == 1 and v.shape[0] == 1) + ): + v = v.item() + tensorboard_stats[k] = v + return tensorboard_stats + + +def stack_dicts(stats_dicts): + """Stack the values of a dict.""" + results = dict() + for k in stats_dicts[0]: + stats_list = [torch.flatten(d[k]) for d in stats_dicts] + results[k] = pad_sequence(stats_list, batch_first=True, padding_value=WANDB_PADDING) + return results + + +def add_suffix(input_dict, suffix): + """Add suffix to dict keys.""" + return dict((k + suffix, v) for k, v in input_dict.items()) + + +def pad_to_size(tensor, size, dim=1, padding=50256): + """Pad tensor to size.""" + t_size = tensor.size()[dim] + if t_size == size: + return tensor + else: + return torch.nn.functional.pad(tensor, (0, size - t_size), "constant", padding) + + +def logprobs_from_logits(logits, labels, gather=True): + """ + See: https://github.com/pytorch/pytorch/issues/563#issuecomment-330103591 + """ + logp = F.log_softmax(logits, dim=2) + + if not gather: + return logp + logpy = torch.gather(logp, 2, labels.unsqueeze(2)).squeeze(-1) + return logpy + + +def whiten(values, shift_mean=True): + """Whiten values.""" + mean, var = torch.mean(values), torch.var(values) + whitened = (values - mean) * torch.rsqrt(var + 1e-8) + if not shift_mean: + whitened += mean + return whitened + + +def masked_mean(values, mask, axis=None): + """Compute mean of tensor with a masked values.""" + if axis is not None: + return (values * mask).sum(axis=axis) / mask.sum(axis=axis) + else: + return (values * mask).sum() / mask.sum() + + +def masked_var(values, mask, unbiased=True): + """Compute variance of tensor with masked values.""" + mean = masked_mean(values, mask) + centered_values = values - mean + variance = masked_mean(centered_values**2, mask) + if unbiased: + mask_sum = mask.sum() + if mask_sum == 0: + raise ValueError( + "The sum of the mask is zero, which can happen when `mini_batch_size=1`;" + "try increase the `mini_batch_size` or `gradient_accumulation_steps`" + ) + # note that if mask_sum == 1, then there is a division by zero issue + # to avoid it you just need to use a larger minibatch_size + bessel_correction = mask_sum / (mask_sum - 1) + variance = variance * bessel_correction + return variance + + +def masked_whiten(values, mask, shift_mean=True): + """Whiten values with masked values.""" + mean, var = masked_mean(values, mask), masked_var(values, mask) + whitened = (values - mean) * torch.rsqrt(var + 1e-8) + if not shift_mean: + whitened += mean + return whitened + + +def clip_by_value(x, tensor_min, tensor_max): + """ + Tensor extenstion to torch.clamp + https://github.com/pytorch/pytorch/issues/2793#issuecomment-428784713 + """ + clipped = torch.max(torch.min(x, tensor_max), tensor_min) + return clipped + + +def entropy_from_logits(logits): + """Calculate entropy from logits.""" + pd = torch.nn.functional.softmax(logits, dim=-1) + entropy = torch.logsumexp(logits, axis=-1) - torch.sum(pd * logits, axis=-1) + return entropy + + +def average_torch_dicts(list_of_dicts): + """Average values of a list of dicts with torch tensors.""" + average_dict = dict() + for key in list_of_dicts[0].keys(): + average_dict[key] = torch.mean(torch.stack([d[key] for d in list_of_dicts]), axis=0) + return average_dict + + +def stats_to_np(stats_dict): + """Cast all torch.tensors in dict to numpy arrays.""" + new_dict = dict() + for k, v in stats_dict.items(): + if isinstance(v, torch.Tensor): + new_dict[k] = v.detach().cpu() + if new_dict[k].dtype == torch.bfloat16: + new_dict[k] = new_dict[k].float() + new_dict[k] = new_dict[k].numpy() + else: + new_dict[k] = v + if np.isscalar(new_dict[k]): + new_dict[k] = float(new_dict[k]) + return new_dict + + +def listify_batch(tensor): + """Turns the first dimension of a tensor into a list.""" + return [tensor[i] for i in range(tensor.shape[0])] + + +def build_bert_batch_from_txt(text_list, tokenizer, device): + """Create token id and attention mask tensors from text list for BERT classification.""" + + # tokenize + tensors = [tokenizer.encode(txt, return_tensors="pt").to(device) for txt in text_list] + + # find max length to pad to + max_len = max([t.size()[1] for t in tensors]) + + # get padded tensors and attention masks + # (attention masks make bert ignore padding) + padded_tensors = [] + attention_masks = [] + for tensor in tensors: + attention_mask = torch.ones(tensor.size(), device=device) + padded_tensors.append(pad_to_size(tensor, max_len, padding=0)) + attention_masks.append(pad_to_size(attention_mask, max_len, padding=0)) + + # stack all tensors + padded_tensors = torch.cat(padded_tensors) + attention_masks = torch.cat(attention_masks) + + return padded_tensors, attention_masks + + +def respond_to_batch(model, queries, txt_len=20, top_k=0, top_p=1.0): + """Sample text from language model.""" + input_ids = queries + for i in range(txt_len): + # Get Logits + outputs = model(input_ids) + next_token_logits = outputs[0][:, -1, :] + next_token_logits = top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p) + # Sample + probs = F.softmax(next_token_logits, dim=-1) + next_token = torch.multinomial(probs, num_samples=1).squeeze(1) + input_ids = torch.cat([input_ids, next_token.unsqueeze(-1)], dim=-1) + return input_ids[:, -txt_len:] + + +def set_seed(seed: int): + """ + Helper function for reproducible behavior to set the seed in `random`, `numpy`, and `torch`. + + Args: + seed (`int`): The seed to set. + """ + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + if is_xpu_available(): + torch.xpu.manual_seed_all(seed) + else: + torch.cuda.manual_seed_all(seed) + + +class LengthSampler: + """ + Samples a length + """ + + def __init__(self, min_value, max_value): + self.values = list(range(min_value, max_value)) + + def __call__(self): + return np.random.choice(self.values) + + +class PPODecorators(object): + optimize_device_cache = False + + @classmethod + @contextmanager + def empty_device_cache(cls): + yield + if is_xpu_available(): + if cls.optimize_device_cache and torch.xpu.is_available(): + gc.collect() + torch.xpu.empty_cache() + gc.collect() + else: + if cls.optimize_device_cache and torch.cuda.is_available(): + gc.collect() + torch.cuda.empty_cache() + gc.collect() + + +def randn_tensor( + shape: Union[Tuple, List], + generator: Optional[Union[List["torch.Generator"], "torch.Generator"]] = None, + device: Optional["torch.device"] = None, + dtype: Optional["torch.dtype"] = None, + layout: Optional["torch.layout"] = None, +): + """A helper function to create random tensors on the desired `device` with the desired `dtype`. When + passing a list of generators, you can seed each batch size individually. If CPU generators are passed, the tensor + is always created on the CPU. + """ + # device on which tensor is created defaults to device + rand_device = device + batch_size = shape[0] + + layout = layout or torch.strided + device = device or torch.device("cpu") + + if generator is not None: + gen_device_type = generator.device.type if not isinstance(generator, list) else generator[0].device.type + if gen_device_type != device.type and gen_device_type == "cpu": + rand_device = "cpu" + if device != "mps": + warnings.warn( + f"The passed generator was created on 'cpu' even though a tensor on {device} was expected." + f" Tensors will be created on 'cpu' and then moved to {device}. Note that one can probably" + f" slighly speed up this function by passing a generator that was created on the {device} device." + ) + elif gen_device_type != device.type and gen_device_type == "cuda": + raise ValueError(f"Cannot generate a {device} tensor from a generator of type {gen_device_type}.") + + # make sure generator list of length 1 is treated like a non-list + if isinstance(generator, list) and len(generator) == 1: + generator = generator[0] + + if isinstance(generator, list): + shape = (1,) + shape[1:] + latents = [ + torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype, layout=layout) + for i in range(batch_size) + ] + latents = torch.cat(latents, dim=0).to(device) + else: + latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype, layout=layout).to(device) + + return latents diff --git a/trl/environment/__init__.py b/trl/environment/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ae1cda4ecb2e604cc990ce16d982df29846f5204 --- /dev/null +++ b/trl/environment/__init__.py @@ -0,0 +1,3 @@ +# flake8: noqa + +from .base_environment import TextEnvironment, TextHistory diff --git a/trl/environment/base_environment.py b/trl/environment/base_environment.py new file mode 100644 index 0000000000000000000000000000000000000000..25f44ae9355c0da6fd8fe19759b8d8a09c888fc4 --- /dev/null +++ b/trl/environment/base_environment.py @@ -0,0 +1,473 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import re +import warnings + +import torch +from accelerate.utils import extract_model_from_parallel +from transformers import StoppingCriteria, StoppingCriteriaList + +from ..import_utils import is_rich_available + + +if is_rich_available(): + from rich import print + from rich.text import Text + + +class StringStoppingCriteria(StoppingCriteria): + """Custom `StoppingCriteria` which checks if all generations in the batch are completed.""" + + def __init__(self, stop_strings, tokenizer): + self.stop_strings = stop_strings + self.tokenizer = tokenizer + self.first_call = True + + def __call__(self, input_ids, scores, **kwargs): + """Returns true if all generated sequences contain any of the stop strings.""" + if self.first_call: + self.generated_tokens = [1 for _ in range(input_ids.shape[0])] + self.start_length = input_ids.shape[-1] - 1 + self.first_call = False + decoded_generations = self.tokenizer.batch_decode(input_ids[:, self.start_length :]) + done = [] + + for i, decoded_generation in enumerate(decoded_generations): + sequence_complete = any([stop_string in decoded_generation for stop_string in self.stop_strings]) + done.append(sequence_complete) + if not sequence_complete: + self.generated_tokens[i] += 1 + + if all(done): + self.first_call = True + + return all(done) + + +class TextHistory: + """The TextHistory class keeps track of the history of an interaction between the language model and the environment.""" + + def __init__(self, text, tokens, system=True): + """ + Initialize TextHistory. + + args: + text (`str`): The text of the first segment. + tokens (`torch.LongTensor`): The tokens of the first segment. + system (`bool`, *optional*): Whether the first segment is a system or user segment. + """ + self.system_spans = [] + self.text_spans = [] + self.token_spans = [] + self.token_masks = torch.tensor([], dtype=torch.long).to(tokens.device) + self.text = "" + self.tokens = torch.tensor([], dtype=torch.long).to(tokens.device) + self.completed = False + self.truncated = False + self.reward = 0.0 + + self.prompt_color = "black on grey85" + self.system_color = "black on cyan3" + self.model_color = "black on deep_sky_blue1" + self.reward_color = "black on plum1" + + self.append_segment(text, tokens, system=system) + + def append_segment(self, text, tokens, system=True): + """ + Append a new segment to the history. + + args: + text (`str`): The text of the new segment. + tokens (`torch.LongTensor`): The tokens of the new segment. + system (`bool`, *optional*): Whether the new segment is a system or user segment. + """ + + if len(text) == 0 or len(tokens) == 0: + raise ValueError("Can't append empty text or token list to history.") + + original_text_length = len(self.text) + + self.text += text + self.text_spans.append((original_text_length, len(self.text))) + self.system_spans.append(system) + + original_token_length = len(self.tokens) + + self.tokens = torch.cat((self.tokens, tokens)) + if system: + self.token_masks = torch.cat((self.token_masks, torch.zeros_like(tokens))) + else: + self.token_masks = torch.cat((self.token_masks, torch.ones_like(tokens))) + self.token_spans.append((original_token_length, len(self.tokens))) + + def complete(self, truncated=False): + """ + Mark the history as completed. + """ + self.completed = True + self.truncated = truncated + + @property + def last_text_segment(self): + """ + Get the last text segment. + """ + start, end = self.text_spans[-1] + return self.text[start:end] + + def split_query_response_tokens(self): + """ + Split the tokens into query and response tokens. + """ + split_index = self.token_spans[0][1] + query = self.tokens[:split_index] + response = self.tokens[split_index:] + mask = self.token_masks[split_index:] + + return query, response, mask + + def show_text(self, show_legend=False): + """ + Print the text history. + """ + if not is_rich_available(): + warnings.warn("install rich to display text") + return + + text = Text(self.text) + text.stylize(self.prompt_color, self.text_spans[0][0], self.text_spans[1][0]) + for i, (start, end) in enumerate(self.text_spans[1:]): + if self.system_spans[i + 1]: + text.stylize(self.system_color, start, end) + else: + text.stylize(self.model_color, start, end) + + text.append(f"\n\nReward: {self.reward}", style=self.reward_color) + print(text) + + if show_legend: + self.show_colour_legend() + + def show_tokens(self, tokenizer, show_legend=False): + """ + Print the history tokens. + """ + if not is_rich_available(): + warnings.warn("install rich to display tokens") + return + + text = Text() + prompt_end = self.token_spans[0][1] + for i, (token, mask) in enumerate(zip(self.tokens, self.token_masks)): + if i < prompt_end: + text.append(tokenizer.convert_ids_to_tokens(token.item()), style=self.prompt_color) + text.append(" ") + elif mask == 0: + text.append(tokenizer.convert_ids_to_tokens(token.item()), style=self.system_color) + text.append(" ") + else: + text.append(tokenizer.convert_ids_to_tokens(token.item()), style=self.model_color) + text.append(" ") + text.append(f"\n\nReward: {self.reward}", style=self.reward_color) + print(text) + if show_legend: + self.show_colour_legend() + + def show_colour_legend(self): + """ + Print the colour legend. + """ + if not is_rich_available(): + warnings.warn("install rich to display colour legend") + return + text = Text("\n\n(Colour Legend: ") + text.append("Prompt", style=self.prompt_color) + text.append("|") + text.append("System", style=self.system_color) + text.append("|") + text.append("Model", style=self.model_color) + text.append("|") + text.append("Reward", style=self.reward_color) + text.append(")") + print(text) + + +class TextEnvironment: + """ + The TextEnvironment enables interaction of a LLM with an environment using tools. + """ + + def __init__( + self, + model=None, + tokenizer=None, + tools=None, + reward_fn=None, + prompt=None, + max_turns=4, + max_tool_reponse=100, + max_length=None, + generation_kwargs=None, + ): + """ + Initialize TextEnvironment. + + Args: + model (`PreTrainedModelWrapper`): The model to use for generation. + tokenizer (`transformers.PreTrainedTokenizer`): The tokenizer to use for generation. + tools (list): A list of tools to use for interaction. + reward_fn (function): A function that takes a string and returns a reward. + prompt (str): The base prompt to use for generation. Is prepended to the tasks. + max_turns (Optional[int]): The maximum number of turns to allow. + max_tool_response (Optional[int]): The maximum number of characters to allow in a tool response. + max_length (Optional[int]): The maximum number of tokens to allow in an episode. + generation_kwargs (Optional[dict]): A dictionary of keyword arguments to pass to the model's generate method. + """ + self.model = model + self.tokenizer = tokenizer + self.prompt = prompt + if isinstance(tools, dict): + self.tools = tools + else: + self.tools = dict([(tool.__class__.__name__, tool) for tool in tools]) + self.reward_fn = reward_fn + self.max_length = max_length + self.request_token = "" + self.call_token = "" + self.response_token = "" + self.submit_token = "" + self.max_turns = max_turns + self.max_tool_response = max_tool_reponse + + if generation_kwargs is None: + self.generation_kwargs = dict() + else: + self.generation_kwargs = generation_kwargs + + self.is_encoder_decoder = hasattr(self.model, "is_encoder_decoder") + self.current_device = extract_model_from_parallel(self.model).pretrained_model.device + + def run(self, queries, **rewards_kwargs): + """ + Run the environment on a list of queries. + + Args: + queries (list[str]): A list of queries to run the model in the environment on. + """ + turns = 0 + + queries = [self.prompt + task for task in queries] + queries_tokens = [ + self.tokenizer(query, return_tensors="pt").input_ids[0].to(self.model.pretrained_model.device) + for query in queries + ] + + histories = [TextHistory(q, qt, system=True) for q, qt in zip(queries, queries_tokens)] + + while any([not history.completed for history in histories]) and turns < self.max_turns: + histories = self.generate(histories) + histories = self.tasks_end_check(histories) + # TODO: make this parallel rather than for-loop + for i in range(len(histories)): + histories[i] = self.step(histories[i]) + histories = self.tasks_end_check(histories, model_turn=False) + turns += 1 + self.compute_reward(histories, **rewards_kwargs) + + # convert a list of (q, r, m) tuples to lists of all qs, rs, and ms respectively + queries, responses, masks = map(list, zip(*[history.split_query_response_tokens() for history in histories])) + + rewards = [history.reward for history in histories] + return queries, responses, masks, rewards, histories + + def step(self, history): + """ + Step the environment forward one turn. + + Args: + history (`TextHistory`): The history to step forward. + """ + truncated, ended = self.task_end_check(history) + if ended: + history.complete(truncated=truncated) + if history.completed: + return history + + tool, query = self.parse_tool_call(history.last_text_segment) + if tool is None or query is None: + response = f"Unknown tool call: {history.last_text_segment}" + else: + if tool not in self.tools: + response = f"Unknown tool {tool}." + try: + response = self.tools[tool](query) + except Exception as error: + response = f"Tool error: {str(error)}" + + if len(response) > self.max_tool_response: + response = response[: (self.max_tool_response - 3)] + "..." + + history.append_segment( + response + self.response_token, + self.tokenizer(response + self.response_token, return_tensors="pt") + .input_ids[0] + .to(self.model.pretrained_model.device), + system=True, + ) + + return history + + def parse_tool_call(self, text): + """ + Parse request string. Expected format: query + """ + result = re.search(f"(?<={self.request_token}).*?(?={self.call_token})", text, re.DOTALL) + + # if we can't find a / span we return none + if result is None: + return None, None + else: + extracted_text = result.group() + + result = re.search(r"<(.*?)>", extracted_text) + + # if we can't find a tool name we return none + if result is None: + return None, None + else: + tool = result.group(1) + + # split off the tool name + query = ">".join(extracted_text.split(">")[1:]) + + return tool, query + + def compute_reward(self, histories, **reward_kwargs): + """ + Compute the reward for a list of histories. + """ + rewards = self.reward_fn([history.last_text_segment for history in histories], **reward_kwargs) + for history, reward in zip(histories, rewards): + history.reward = reward + return histories + + def generate(self, histories): + """ + Generate responses for a list of histories. + """ + active_histories = [i for i, history in enumerate(histories) if not history.completed] + + query_tensors = [histories[i].tokens for i in active_histories] + response_tensors = self._generate_batched(query_tensors) + response_texts = self.tokenizer.batch_decode(response_tensors) + + for i, response_text, response_tensor in zip(active_histories, response_texts, response_tensors): + histories[i].append_segment(response_text, response_tensor, system=False) + + return histories + + def tasks_end_check(self, histories, model_turn=True): + """ + Check if the current generation sequences have finished. + """ + for history in histories: + if not history.completed: + truncated, ended = self.task_end_check(history, model_turn=model_turn) + if ended: + history.complete(truncated=truncated) + return histories + + def task_end_check(self, history, model_turn=True): + """ + Check if the current generation sequence has finished. + """ + truncated = False + ended = False + if history.completed: + return truncated, ended + if self.max_length is not None and len(self.tokenizer(history.text).input_ids[0]) > self.max_length: + truncated = True + ended = True + elif self.tokenizer.eos_token in history.text: + ended = True + elif model_turn and not ( + (self.request_token in history.last_text_segment and self.call_token in history.last_text_segment) + or self.submit_token in history.last_text_segment + ): + ended = True + elif self.submit_token in history.last_text_segment: + ended = True + return truncated, ended + + def _generate_batched( + self, + query_tensors, + batch_size: int = 16, + pad_to_multiple_of: int = None, + ): + """ + Generate responses for a list of query tensors. + + args: + query_tensors (list[torch.Tensor]): A list of query tensors to generate responses for. + batch_size (int): The batch size to use for generation. + pad_to_multiple_of (int): The padding length to use for generation. + """ + outputs = [] + padding_side_default = self.tokenizer.padding_side + if not self.is_encoder_decoder: + self.tokenizer.padding_side = "left" + + # in case we have fewer examples than bs + batch_size = min(len(query_tensors), batch_size) + + for i in range(0, len(query_tensors), batch_size): + # prevent overflow if query tensors are not even multiple of bs + end_index = min(len(query_tensors), i + batch_size) + + batch = query_tensors[i:end_index] + batch_mask = [torch.ones_like(element) for element in batch] + inputs = {"input_ids": batch, "attention_mask": batch_mask} + + padded_inputs = self.tokenizer.pad( + inputs, + padding=True, + max_length=None, + pad_to_multiple_of=pad_to_multiple_of, + return_tensors="pt", + ).to(self.current_device) + + stopping_criteria = StringStoppingCriteria([self.call_token, self.submit_token], self.tokenizer) + + self.generation_kwargs["stopping_criteria"] = StoppingCriteriaList([stopping_criteria]) + + generations = extract_model_from_parallel(self.model).generate(**padded_inputs, **self.generation_kwargs) + + for generation, mask, generated_tokens in zip( + generations, padded_inputs["attention_mask"], stopping_criteria.generated_tokens + ): + if not self.is_encoder_decoder: + output = generation[(1 - mask).sum() :] # remove padding + else: + output = generation + + if not self.is_encoder_decoder: + output = output[(mask).sum() :] # remove prompt + + # remove chunk generated after stopping criteria in batch mode + outputs.append(output[:generated_tokens]) + self.tokenizer.padding_side = padding_side_default + return outputs diff --git a/trl/extras/__init__.py b/trl/extras/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..6b3035db92af28f5d19d72813f08b06fdad50925 --- /dev/null +++ b/trl/extras/__init__.py @@ -0,0 +1,16 @@ +# flake8: noqa + +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from .best_of_n_sampler import BestOfNSampler diff --git a/trl/extras/best_of_n_sampler.py b/trl/extras/best_of_n_sampler.py new file mode 100644 index 0000000000000000000000000000000000000000..1441eecd41a5a18b7612f4e270271d137c07d437 --- /dev/null +++ b/trl/extras/best_of_n_sampler.py @@ -0,0 +1,117 @@ +from typing import Any, Callable, List, Optional, Union + +import torch +from transformers import GenerationConfig, PreTrainedTokenizer, PreTrainedTokenizerFast + +from ..core import set_seed +from ..models import SUPPORTED_ARCHITECTURES, PreTrainedModelWrapper + + +class BestOfNSampler(object): + def __init__( + self, + model: PreTrainedModelWrapper, + tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], + queries_to_scores: Callable[[List[str]], List[float]], + length_sampler: Any, + sample_size: int = 4, + seed: Optional[int] = None, + n_candidates: int = 1, + generation_config: Optional[GenerationConfig] = None, + ) -> None: + r""" + Initialize the sampler for best-of-n generation + + Args: + model (`PreTrainedModelWrapper`): + The pretrained model to use for generation + tokenizer (`PreTrainedTokenizer` or `PreTrainedTokenizerFast`): + Tokenizer associated with the pretrained model + queries_to_scores (`Callable[[List[str]], List[float]]`): + Callable that takes a list of generated texts and returns the associated reward scores + length_sampler (`Any`): + Sampler used to sample the length of the generated text + sample_size (`int`): + Number of samples to generate for each query + seed (`int`, *optional*): + Random seed used to control generation + n_candidates (`int`): + Number of candidates to return for each query + generation_config (`GenerationConfig`, *optional*): + Generation config passed to the underlying model's `generate` method. + See `GenerationConfig` (https://huggingface.co/docs/transformers/v4.29.1/en/main_classes/text_generation#transformers.GenerationConfig) for more details + """ + if seed is not None: + set_seed(seed) + + if not isinstance(tokenizer, (PreTrainedTokenizer, PreTrainedTokenizerFast)): + raise ValueError( + f"tokenizer must be a PreTrainedTokenizer or PreTrainedTokenizerFast, got {type(tokenizer)}" + ) + if not isinstance(model, (SUPPORTED_ARCHITECTURES)): + raise ValueError( + f"model must be a PreTrainedModelWrapper, got {type(model)} - supported architectures are: {SUPPORTED_ARCHITECTURES}" + ) + + self.model = model + self.tokenizer = tokenizer + + self.queries_to_scores = queries_to_scores + self.length_sampler = length_sampler + self.gen_config = generation_config + self.sample_size = sample_size + self.n_candidates = n_candidates + + def generate( + self, + tokenized_query: Union[List[int], torch.Tensor, List[torch.Tensor], List[List[int]]], + skip_special_tokens: bool = True, + device: Optional[Union[str, torch.device]] = None, + **generation_kwargs, + ) -> List[List[str]]: + r""" + Generate the best of n samples for input queries + + Args: + tokenized_query (`List[int]` or `torch.Tensor` or `List[torch.Tensor]` or `List[int]`): + represents either a single tokenized query (a single tensor or a list of integers) or a batch of tokenized queries (a list of tensors or a list of lists of integers) + skip_special_tokens (`bool`): + Whether to remove the special tokens from the output + device (`str` or `torch.device`, *optional*): + The device on which the model will be loaded + **generation_kwargs (`dict`, *optional*): + Additional keyword arguments passed along to the underlying model's `generate` method. + This is used to override generation config + + Returns: + List[List[str]]: A list of lists of generated texts + """ + queries = None + + if isinstance(tokenized_query, torch.Tensor) and tokenized_query.ndim == 1: + queries = tokenized_query.unsqueeze(0) + elif isinstance(tokenized_query, List): + element_type = type(tokenized_query[0]) + if element_type == int: + queries = torch.tensor(tokenized_query).unsqueeze(0) + elif element_type == torch.Tensor: + queries = [tensor.reshape((1, -1)) for tensor in tokenized_query] + else: + queries = [torch.tensor(query).reshape((1, -1)) for query in tokenized_query] + + result = [] + + for query in queries: + queries = query.repeat((self.sample_size, 1)) + output = self.model.generate( + queries.to(device), + max_new_tokens=self.length_sampler(), + generation_config=self.gen_config, + **generation_kwargs, + ).squeeze() + output = self.tokenizer.batch_decode(output, skip_special_tokens=skip_special_tokens) + scores = torch.tensor(self.queries_to_scores(output)) + output = [output[i] for i in scores.topk(self.n_candidates).indices] + result.append(output) + + return result diff --git a/trl/import_utils.py b/trl/import_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..7697f6db840d0d96721b4369a8505ef58c54c747 --- /dev/null +++ b/trl/import_utils.py @@ -0,0 +1,90 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import importlib +import sys + + +if sys.version_info < (3, 8): + _is_python_greater_3_8 = False +else: + _is_python_greater_3_8 = True + + +def is_peft_available() -> bool: + return importlib.util.find_spec("peft") is not None + + +def is_accelerate_greater_20_0() -> bool: + if _is_python_greater_3_8: + from importlib.metadata import version + + accelerate_version = version("accelerate") + else: + import pkg_resources + + accelerate_version = pkg_resources.get_distribution("accelerate").version + return accelerate_version >= "0.20.0" + + +def is_transformers_greater_than(version: str) -> bool: + _transformers_version = importlib.metadata.version("transformers") + return _transformers_version > version + + +def is_torch_greater_2_0() -> bool: + if _is_python_greater_3_8: + from importlib.metadata import version + + torch_version = version("torch") + else: + import pkg_resources + + torch_version = pkg_resources.get_distribution("torch").version + return torch_version >= "2.0" + + +def is_diffusers_available() -> bool: + return importlib.util.find_spec("diffusers") is not None + + +def is_bitsandbytes_available() -> bool: + return importlib.util.find_spec("bitsandbytes") is not None + + +def is_torchvision_available() -> bool: + return importlib.util.find_spec("torchvision") is not None + + +def is_rich_available() -> bool: + return importlib.util.find_spec("rich") is not None + + +def is_wandb_available() -> bool: + return importlib.util.find_spec("wandb") is not None + + +def is_xpu_available() -> bool: + if is_accelerate_greater_20_0(): + import accelerate + + return accelerate.utils.is_xpu_available() + else: + if importlib.util.find_spec("intel_extension_for_pytorch") is None: + return False + try: + import torch + + return hasattr(torch, "xpu") and torch.xpu.is_available() + except RuntimeError: + return False diff --git a/trl/models/__init__.py b/trl/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..6ccce25e5e4c8cf05f2d02efee351512a5b6d848 --- /dev/null +++ b/trl/models/__init__.py @@ -0,0 +1,34 @@ +# flake8: noqa + +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from .modeling_base import PreTrainedModelWrapper, create_reference_model +from .modeling_value_head import AutoModelForCausalLMWithValueHead, AutoModelForSeq2SeqLMWithValueHead + + +SUPPORTED_ARCHITECTURES = ( + AutoModelForCausalLMWithValueHead, + AutoModelForSeq2SeqLMWithValueHead, +) + +from ..import_utils import is_diffusers_available + + +if is_diffusers_available(): + from .modeling_sd_base import ( + DDPOPipelineOutput, + DDPOSchedulerOutput, + DDPOStableDiffusionPipeline, + DefaultDDPOStableDiffusionPipeline, + ) diff --git a/trl/models/modeling_base.py b/trl/models/modeling_base.py new file mode 100644 index 0000000000000000000000000000000000000000..9a3142eebd87b188fa7de89116874bdc590cccfc --- /dev/null +++ b/trl/models/modeling_base.py @@ -0,0 +1,672 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import json +import logging +import os +from copy import deepcopy + +import torch +import torch.nn as nn +from accelerate import PartialState +from huggingface_hub import hf_hub_download +from huggingface_hub.utils import ( + EntryNotFoundError, + HFValidationError, + LocalEntryNotFoundError, + RepositoryNotFoundError, +) +from safetensors.torch import load_file as safe_load_file +from transformers import PreTrainedModel + +from ..import_utils import is_peft_available, is_transformers_greater_than, is_xpu_available + + +if is_peft_available(): + from peft import ( + PeftConfig, + PeftModel, + PeftModelForCausalLM, + PeftModelForSeq2SeqLM, + PromptLearningConfig, + get_peft_model, + prepare_model_for_kbit_training, + ) + +if is_transformers_greater_than("4.33.0"): + from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled +else: + from transformers.deepspeed import is_deepspeed_zero3_enabled + +LAYER_PATTERNS = [ + "transformer.h.{layer}", + "model.decoder.layers.{layer}", + "gpt_neox.layers.{layer}", + "model.layers.{layer}", +] + + +class PreTrainedModelWrapper(nn.Module): + r""" + A wrapper class around a (`transformers.PreTrainedModel`) to be compatible with the + (`~transformers.PreTrained`) class in order to keep some attributes and methods of the + (`~transformers.PreTrainedModel`) class. + + Attributes: + pretrained_model: (`transformers.PreTrainedModel`) + The model to be wrapped. + parent_class: (`transformers.PreTrainedModel`) + The parent class of the model to be wrapped. + supported_args: (`list`) + The list of arguments that are supported by the wrapper class. + """ + transformers_parent_class = None + supported_args = None + supported_modules = ("v_head",) + supported_rm_modules = ("score",) + supported_pretrained_model_architectures = ( + (PreTrainedModel) + if not is_peft_available() + else (PreTrainedModel, PeftModelForCausalLM, PeftModelForSeq2SeqLM) + ) + + def __init__( + self, pretrained_model=None, score_module=None, supports_rm_adapter=False, rm_adapter_name=None, **kwargs + ): + super().__init__() + self.pretrained_model = pretrained_model + + self.config = pretrained_model.config + self.prepare_inputs_for_generation = pretrained_model.prepare_inputs_for_generation + self.is_loaded_in_8bit = getattr(pretrained_model, "is_loaded_in_8bit", False) + self.is_loaded_in_4bit = getattr(pretrained_model, "is_loaded_in_4bit", False) + self.is_sequential_parallel = False + + if hasattr(pretrained_model, "gradient_checkpointing_disable"): + self.gradient_checkpointing_disable = pretrained_model.gradient_checkpointing_disable + + if hasattr(pretrained_model, "gradient_checkpointing_enable"): + self.gradient_checkpointing_enable = pretrained_model.gradient_checkpointing_enable + + self.supports_rm_adapter = supports_rm_adapter + self.rm_adapter_name = rm_adapter_name + self.policy_adapter_name = "default" + if score_module is not None: + self.score = score_module + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): + r""" + Instantiates a new model from a pretrained model from `transformers`. The + pretrained model is loaded using the `from_pretrained` method of the + `transformers.PreTrainedModel` class. The arguments that are specific to the + `transformers.PreTrainedModel` class are passed along this method and filtered + out from the `kwargs` argument. + + + Args: + pretrained_model_name_or_path (`str` or `transformers.PreTrainedModel`): + The path to the pretrained model or its name. + *model_args (`list`, *optional*)): + Additional positional arguments passed along to the underlying model's + `from_pretrained` method. + **kwargs (`dict`, *optional*): + Additional keyword arguments passed along to the underlying model's + `from_pretrained` method. We also pre-process the kwargs to extract + the arguments that are specific to the `transformers.PreTrainedModel` + class and the arguments that are specific to trl models. The kwargs + also support `prepare_model_for_kbit_training` arguments from + `peft` library. + """ + if kwargs is not None: + peft_config = kwargs.pop("peft_config", None) + reward_adapter = kwargs.pop("reward_adapter", None) + reward_adapter_name = kwargs.pop("reward_adapter_name", "reward_adapter") + is_trainable = kwargs.pop("is_trainable", False) + trl_model_args, pretrained_kwargs, peft_quantization_kwargs = cls._split_kwargs(kwargs) + token = pretrained_kwargs.get("token", None) + else: + peft_config = None + is_trainable = False + trl_model_args = {} + pretrained_kwargs = {} + peft_quantization_kwargs = {} + token = None + + if reward_adapter is not None and not isinstance(reward_adapter, str): + raise ValueError( + "The `reward_adapter` argument should be a string representing the name of local path or the Hub id to the Reward Modeling adapter." + ) + + is_peft_model = False + + current_device = cls._get_current_device() + if isinstance(pretrained_model_name_or_path, str): + is_loaded_in_8bit = pretrained_kwargs["load_in_8bit"] if "load_in_8bit" in pretrained_kwargs else False + is_loaded_in_4bit = pretrained_kwargs["load_in_4bit"] if "load_in_4bit" in pretrained_kwargs else False + else: + is_loaded_in_8bit = getattr(pretrained_model_name_or_path, "is_loaded_in_8bit", False) + is_loaded_in_4bit = getattr(pretrained_model_name_or_path, "is_loaded_in_4bit", False) + + if (is_loaded_in_8bit or is_loaded_in_4bit) and "device_map" not in pretrained_kwargs: + # warn users + logging.warning( + "The `device_map` argument is not provided. We will override the device_map argument." + " to set the entire" + " model on the current device. If you want to set the model on multiple devices, please provide" + " a custom `device_map` argument." + ) + pretrained_kwargs["device_map"] = {"": current_device} + + if is_peft_available() and peft_config is not None and not isinstance(peft_config, PeftConfig): + raise ValueError("The `peft_config` argument should be an instance of `peft.PeftConfig` class.") + + # First, load the pre-trained model using the parent-class + # either `AutoModelForCausalLM` or `AutoModelForSeq2SeqLM` + if isinstance(pretrained_model_name_or_path, str): + if is_peft_available(): + try: + # If there is a trained peft adapter in the hub, load its config. + remote_adapter_config = hf_hub_download( + pretrained_model_name_or_path, + "adapter_config.json", + token=token, + ) + except (EntryNotFoundError, LocalEntryNotFoundError, HFValidationError, RepositoryNotFoundError): + remote_adapter_config = None + else: + remote_adapter_config = None + + local_adapter_present = os.path.exists(os.path.join(pretrained_model_name_or_path, "adapter_config.json")) + + if (local_adapter_present or remote_adapter_config is not None) and is_peft_available(): + if peft_config is not None: + logging.warning( + "`peft_config` argument ignored since a peft config file was found in " + f"{pretrained_model_name_or_path}" + ) + + # Load the trained peft adapter config + if local_adapter_present: + trained_adapter_config = PeftConfig.from_pretrained(pretrained_model_name_or_path) + else: + remote_adapter_dir = os.path.dirname(remote_adapter_config) + trained_adapter_config = PeftConfig.from_pretrained(remote_adapter_dir) + + # Load the pretrained base model + pretrained_model = cls.transformers_parent_class.from_pretrained( + trained_adapter_config.base_model_name_or_path, *model_args, **pretrained_kwargs + ) + + # Wrap the pretrained model with the trained peft adapter + pretrained_model = PeftModel.from_pretrained( + pretrained_model, pretrained_model_name_or_path, is_trainable=is_trainable + ) + logging.info("Trained peft adapter loaded") + else: + pretrained_model = cls.transformers_parent_class.from_pretrained( + pretrained_model_name_or_path, *model_args, **pretrained_kwargs + ) + + if peft_config is not None: + # Initialize a new peft adapter with the given config + if is_loaded_in_8bit or is_loaded_in_4bit: + pretrained_model = prepare_model_for_kbit_training( + pretrained_model, + **peft_quantization_kwargs, + ) + pretrained_model = get_peft_model(pretrained_model, peft_config) + logging.info("peft adapter initialised") + + elif isinstance(pretrained_model_name_or_path, cls.supported_pretrained_model_architectures): + pretrained_model = pretrained_model_name_or_path + + if peft_config is not None and isinstance(pretrained_model, PreTrainedModel): + # Initialize a new peft adapter with the given config + if is_loaded_in_8bit or is_loaded_in_4bit: + pretrained_model = prepare_model_for_kbit_training( + pretrained_model, + **peft_quantization_kwargs, + ) + pretrained_model = get_peft_model(pretrained_model, peft_config) + logging.info("peft adapter initialised") + else: + raise ValueError( + "pretrained_model_name_or_path should be a string or a PreTrainedModel, " + f"but is {type(pretrained_model_name_or_path)}" + ) + + if is_peft_available(): + if isinstance(pretrained_model, PeftModel): + is_peft_model = True + # for backward compatibility + if hasattr(pretrained_model, "active_peft_config") and isinstance( + pretrained_model.active_peft_config, PromptLearningConfig + ): + raise ValueError("PromptLearningConfig is not supported for PPO training.") + + # Add reward modeling adapter if specified + if not is_peft_model and reward_adapter is not None: + raise ValueError("reward_adapter can only be used with a PeftModel. ") + elif is_peft_model and reward_adapter is not None: + score_module = cls.add_and_load_reward_modeling_adapter( + pretrained_model, reward_adapter, reward_adapter_name, token=token + ) + multi_adapter_args = { + "score_module": score_module, + "supports_rm_adapter": True, + "rm_adapter_name": reward_adapter_name, + } + else: + multi_adapter_args = {"supports_rm_adapter": False} + + # Then, create the full model by instantiating the wrapper class + model = cls(pretrained_model, **multi_adapter_args, **trl_model_args) + + # if resume_training, load the state_dict again - this is ok since the + # state_dict is removed from the model after loading it. + is_resuming_training = True + if isinstance(pretrained_model_name_or_path, str): + safe_filename = os.path.join(pretrained_model_name_or_path, "model.safetensors") + filename = os.path.join(pretrained_model_name_or_path, "pytorch_model.bin") + + sharded_index_filename = os.path.join(pretrained_model_name_or_path, "pytorch_model.bin.index.json") + safe_sharded_index_filename = os.path.join(pretrained_model_name_or_path, "model.safetensors.index.json") + is_sharded = False + use_safe = os.path.exists(safe_filename) + + if not (os.path.exists(filename) or os.path.exists(safe_filename)): + # Try with `pytorch_model.bin` + filename, files_to_download, is_sharded, is_resuming_training = cls._get_checkpoint_from_hub( + pretrained_model, + pretrained_model_name_or_path, + sharded_index_filename, + token=token, + ) + # Try with safetensors + if filename is None and files_to_download is None: + safe_filename, files_to_download, is_sharded, is_resuming_training = cls._get_checkpoint_from_hub( + pretrained_model, + pretrained_model_name_or_path, + safe_sharded_index_filename, + token=token, + model_name="model.safetensors", + model_index_name="model.safetensors.index.json", + ) + use_safe = True + else: + use_safe = False + + loading_func = safe_load_file if use_safe else torch.load + load_kwargs = {} if use_safe else {"map_location": "cpu"} + + if is_resuming_training: + if is_sharded: + # download each file and add it to the state_dict + state_dict = {} + + for shard_file in files_to_download: + filename = hf_hub_download( + pretrained_model_name_or_path, + shard_file, + token=token, + ) + state_dict.update(loading_func(filename, **load_kwargs)) + else: + state_dict = loading_func(filename if not use_safe else safe_filename, **load_kwargs) + + else: + state_dict = pretrained_model_name_or_path.state_dict() + + model.is_peft_model = is_peft_model + model.current_device = current_device + + if is_resuming_training: + model.post_init(state_dict=state_dict) + + return model + + @classmethod + def _get_checkpoint_from_hub( + cls, + pretrained_model, + pretrained_model_name_or_path, + index_filename, + token=None, + model_name="pytorch_model.bin", + model_index_name="pytorch_model.bin.index.json", + ): + files_to_download = None + filename = None + is_resuming_training = True + is_sharded = False + + try: + filename = hf_hub_download( + pretrained_model_name_or_path, + model_name, + token=token, + ) + # sharded + except (EntryNotFoundError, LocalEntryNotFoundError, HFValidationError, RepositoryNotFoundError): + if os.path.exists(index_filename): + index_file_name = index_filename + else: + try: + index_file_name = hf_hub_download( + pretrained_model_name_or_path, + model_index_name, + token=token, + ) + except (EntryNotFoundError, LocalEntryNotFoundError, HFValidationError, RepositoryNotFoundError): + # not continue training, do not have v_head weight + is_resuming_training = False + logging.warning( + f"A {type(pretrained_model)} model is loaded from '{pretrained_model_name_or_path}', " + f"and no v_head weight is found. This IS expected if you are not resuming PPO training." + ) + # load json + if is_resuming_training: + with open(index_file_name, "r") as f: + index = json.load(f) + # check filename with `v_head` or any known extra module: + files_to_download = set() + for k, v in index["weight_map"].items(): + if any([module in k for module in cls.supported_modules]): + files_to_download.add(v) + is_sharded = True + + return filename, files_to_download, is_sharded, is_resuming_training + + @classmethod + def _get_current_device(cls): + r""" + Get the current device. For GPU, we return the local process index using the `accelerate.PartialState` + object to handle corner cases when running scripts in distributed environments. + + Returns: + current_device (`Union[int, str]`): + The current device. + """ + state = PartialState() + if is_xpu_available(): + return f"xpu:{state.local_process_index}" + else: + return state.local_process_index if torch.cuda.is_available() else "cpu" + + @classmethod + def _split_kwargs(cls, kwargs): + """ + Separate the kwargs from the arguments that we support inside + `supported_args` and the ones that we don't. + """ + check_peft_kwargs = False + + if is_peft_available(): + from peft import prepare_model_for_kbit_training + + check_peft_kwargs = True + + supported_kwargs = {} + unsupported_kwargs = {} + peft_kwargs = {} + + for key, value in kwargs.items(): + if key in cls.supported_args: + supported_kwargs[key] = value + else: + unsupported_kwargs[key] = value + + if check_peft_kwargs: + if key in prepare_model_for_kbit_training.__code__.co_varnames: + peft_kwargs[key] = value + if key in unsupported_kwargs: + unsupported_kwargs.pop(key) + + return supported_kwargs, unsupported_kwargs, peft_kwargs + + @classmethod + def add_and_load_reward_modeling_adapter( + cls, pretrained_model, adapter_model_id, adapter_name="reward_model_adapter", token=None + ): + r""" + Add and load a reward modeling adapter. This method can only be used if the + model is a `PeftModel` and if you have initialized the model with the `reward_modeling_adapter_id` + argument, pointing to the id of the reward modeling adapter. The latest needs also to contain the + score head in order to produce the reward. + """ + pretrained_model.load_adapter(adapter_model_id, adapter_name, is_trainable=False) + pretrained_model.train() + + filename = os.path.join(adapter_model_id, "adapter_model.bin") + safe_loading = False + if not os.path.exists(filename): + try: + local_filename = hf_hub_download( + adapter_model_id, + "adapter_model.bin", + token=token, + ) + except: # noqa + filename = os.path.join(adapter_model_id, "adapter_model.safetensors") + safe_loading = True + if not os.path.exists(filename): + try: + local_filename = hf_hub_download( + adapter_model_id, + "adapter_model.safetensors", + token=token, + ) + except: # noqa + raise ValueError( + "Could not find adapter model in the Hub, make sure you have the correct adapter model id." + ) + else: + local_filename = filename + else: + local_filename = filename + + loading_func = safe_load_file if safe_loading else torch.load + load_kwargs = {} if safe_loading else {"map_location": "cpu"} + + adapter_state_dict = loading_func(local_filename, **load_kwargs) + + for score_name_candidate in cls.supported_rm_modules: + if any([score_name_candidate in name for name in adapter_state_dict.keys()]): + score_name = score_name_candidate + # we have found the correct head name and can break + break + + score_dict = {} + + for name, param in adapter_state_dict.items(): + if score_name in name: + key_name = ".".join(name.split(".")[-1:]) + score_dict[key_name] = param.to(cls._get_current_device()) + + num_labels, hidden_dim = score_dict["weight"].shape + has_bias = any(["bias" in name for name in adapter_state_dict.keys()]) + + score = nn.Linear(hidden_dim, num_labels, bias=has_bias).to( + device=cls._get_current_device(), + dtype=pretrained_model.dtype, + ) + score.load_state_dict(score_dict) + for param in score.parameters(): + param.requires_grad = False + + return score + + def push_to_hub(self, *args, **kwargs): + r""" + Push the pretrained model to the hub. This method is a wrapper around + `transformers.PreTrainedModel.push_to_hub`. Please refer to the documentation + of `transformers.PreTrainedModel.push_to_hub` for more information. + + Args: + *args (`list`, *optional*): + Positional arguments passed along to the underlying model's + `push_to_hub` method. + **kwargs (`dict`, *optional*): + Keyword arguments passed along to the underlying model's + `push_to_hub` method. + """ + raise NotImplementedError + + def save_pretrained(self, *args, **kwargs): + r""" + Save the pretrained model to a directory. This method is a wrapper around + `transformers.PreTrainedModel.save_pretrained`. Please refer to the documentation + of `transformers.PreTrainedModel.save_pretrained` for more information. + + Args: + *args (`list`, *optional*): + Positional arguments passed along to the underlying model's + `save_pretrained` method. + **kwargs (`dict`, *optional*): + Keyword arguments passed along to the underlying model's + `save_pretrained` method. + """ + state_dict = kwargs.get("state_dict") + if state_dict is None: + state_dict = self.state_dict() + kwargs["state_dict"] = state_dict + + # if it is a peft model only save the `v_head` state_dict and + # pop the `state_dict` from the kwargs to avoid slient bugs with `peft` + if self.is_peft_model: + save_path = args[0] + save_path = os.path.join(save_path, "pytorch_model.bin") + torch.save(state_dict, save_path) + _ = kwargs.pop("state_dict", None) + + return self.pretrained_model.save_pretrained(*args, **kwargs) + + def state_dict(self, *args, **kwargs): + r""" + Return the state_dict of the pretrained model. + """ + raise NotImplementedError + + def post_init(self, *args, **kwargs): + r""" + Post initialization method. This method is called after the model is + instantiated and loaded from a checkpoint. It can be used to perform + additional operations such as loading the state_dict. + """ + raise NotImplementedError + + def compute_reward_score(self, input_ids, attention_mask=None, **kwargs): + r""" + Computes the reward score for a given input. The method has first to enable the adapter + and then compute the reward score. After that the model disables the reward modeling + adapter and enables the default ppo adapter again. + """ + if not self.supports_rm_adapter: + raise ValueError("This model does not support reward modeling adapter.") + + # enable rm adapter + self.pretrained_model.set_adapter(self.rm_adapter_name) + self.pretrained_model.eval() + + with torch.no_grad(): + base_model_output = self.pretrained_model( + input_ids=input_ids, + attention_mask=attention_mask, + output_hidden_states=True, + return_dict=True, + **kwargs, + ) + + last_hidden_states = base_model_output.hidden_states[-1] + scores = self.score(last_hidden_states) + + self.pretrained_model.set_adapter(self.policy_adapter_name) + self.pretrained_model.eval() + + return scores + + +def create_reference_model( + model: PreTrainedModelWrapper, num_shared_layers: int = None, pattern: str = None +) -> PreTrainedModelWrapper: + """ + Creates a static reference copy of a model. Note that model will be in `.eval()` mode. + + Args: + model (`PreTrainedModelWrapper`): The model to be copied. + num_shared_layers (`int`, *optional*): The number of initial layers that are shared between both models and kept frozen. + pattern (`str`, *optional*): The shared layers are selected with a string pattern + (e.g. "transformer.h.{layer}" for GPT2) and if a custom pattern is necessary it can be passed here. + + Returns + `PreTrainedModelWrapper` + """ + if is_deepspeed_zero3_enabled(): + raise ValueError( + "DeepSpeed ZeRO-3 is enabled and is not compatible with `create_reference_model()`. Please instantiate your reference model directly with `AutoCausalLM.from_pretrained()`." + ) + + parameter_names = [n for n, _ in model.named_parameters()] + ref_model = deepcopy(model) + + # if no layers are shared, return copy of model + if num_shared_layers is None: + for param_name in parameter_names: + param = ref_model.get_parameter(param_name) + param.requires_grad = False + return ref_model.eval() + + # identify layer name pattern + if pattern is not None: + pattern = pattern.format(layer=num_shared_layers) + else: + for pattern_candidate in LAYER_PATTERNS: + pattern_candidate = pattern_candidate.format(layer=num_shared_layers) + if any([pattern_candidate in name for name in parameter_names]): + pattern = pattern_candidate + break + + if pattern is None: + raise ValueError("Layer pattern could not be matched.") + + # divide parameters in shared and unshared parameter lists + shared_param_list = [] + unshared_param_list = [] + + shared_parameter = True + for name, param in model.named_parameters(): + if pattern in name: + shared_parameter = False + if shared_parameter: + shared_param_list.append(name) + else: + unshared_param_list.append(name) + + # create reference of the original parameter if they are shared + for param_name in shared_param_list: + param = model.get_parameter(param_name) + param.requires_grad = False + + ref_param = ref_model.get_parameter(param_name) # noqa + ref_param = param # noqa + + # for all other parameters just make sure they don't use gradients + for param_name in unshared_param_list: + param = ref_model.get_parameter(param_name) + param.requires_grad = False + + if pattern is not None and len(unshared_param_list) == 0: + logging.warning("Pattern passed or found, but no layers matched in the model. Check for a typo.") + + return ref_model.eval() diff --git a/trl/models/modeling_sd_base.py b/trl/models/modeling_sd_base.py new file mode 100644 index 0000000000000000000000000000000000000000..0d68380401f5d6f393b5b257dcb036da2d2ca140 --- /dev/null +++ b/trl/models/modeling_sd_base.py @@ -0,0 +1,645 @@ +# Copyright 2023 DDPO-pytorch authors (Kevin Black), The HuggingFace Team, metric-space. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import contextlib +import os +import warnings +from dataclasses import dataclass +from typing import Any, Callable, Dict, List, Optional, Union + +import numpy as np +import torch +from diffusers import DDIMScheduler, StableDiffusionPipeline, UNet2DConditionModel +from diffusers.loaders import AttnProcsLayers +from diffusers.models.attention_processor import LoRAAttnProcessor +from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import rescale_noise_cfg + +from ..core import randn_tensor + + +@dataclass +class DDPOPipelineOutput(object): + """ + Output class for the diffusers pipeline to be finetuned with the DDPO trainer + + Args: + images (`torch.Tensor`): + The generated images. + latents (`List[torch.Tensor]`): + The latents used to generate the images. + log_probs (`List[torch.Tensor]`): + The log probabilities of the latents. + + """ + + images: torch.Tensor + latents: torch.Tensor + log_probs: torch.Tensor + + +@dataclass +class DDPOSchedulerOutput(object): + """ + Output class for the diffusers scheduler to be finetuned with the DDPO trainer + + Args: + latents (`torch.Tensor`): + Predicted sample at the previous timestep. Shape: `(batch_size, num_channels, height, width)` + log_probs (`torch.Tensor`): + Log probability of the above mentioned sample. Shape: `(batch_size)` + """ + + latents: torch.Tensor + log_probs: torch.Tensor + + +class DDPOStableDiffusionPipeline(object): + """ + Main class for the diffusers pipeline to be finetuned with the DDPO trainer + """ + + def __call__(self, *args, **kwargs) -> DDPOPipelineOutput: + raise NotImplementedError + + def scheduler_step(self, *args, **kwargs) -> DDPOSchedulerOutput: + raise NotImplementedError + + @property + def unet(self): + """ + Returns the 2d U-Net model used for diffusion. + """ + raise NotImplementedError + + @property + def vae(self): + """ + Returns the Variational Autoencoder model used from mapping images to and from the latent space + """ + raise NotImplementedError + + @property + def tokenizer(self): + """ + Returns the tokenizer used for tokenizing text inputs + """ + raise NotImplementedError + + @property + def scheduler(self): + """ + Returns the scheduler associated with the pipeline used for the diffusion process + """ + raise NotImplementedError + + @property + def text_encoder(self): + """ + Returns the text encoder used for encoding text inputs + """ + raise NotImplementedError + + @property + def autocast(self): + """ + Returns the autocast context manager + """ + raise NotImplementedError + + def set_progress_bar_config(self, *args, **kwargs): + """ + Sets the progress bar config for the pipeline + """ + raise NotImplementedError + + def save_pretrained(self, *args, **kwargs): + """ + Saves all of the model weights + """ + raise NotImplementedError + + def get_trainable_layers(self, *args, **kwargs): + """ + Returns the trainable parameters of the pipeline + """ + raise NotImplementedError + + def save_checkpoint(self, *args, **kwargs): + """ + Light wrapper around accelerate's register_save_state_pre_hook which is run before saving state + """ + raise NotImplementedError + + def load_checkpoint(self, *args, **kwargs): + """ + Light wrapper around accelerate's register_lad_state_pre_hook which is run before loading state + """ + raise NotImplementedError + + +def _left_broadcast(input_tensor, shape): + """ + As opposed to the default direction of broadcasting (right to left), this function broadcasts + from left to right + Args: + input_tensor (`torch.FloatTensor`): is the tensor to broadcast + shape (`Tuple[int]`): is the shape to broadcast to + """ + input_ndim = input_tensor.ndim + if input_ndim > len(shape): + raise ValueError( + "The number of dimensions of the tensor to broadcast cannot be greater than the length of the shape to broadcast to" + ) + return input_tensor.reshape(input_tensor.shape + (1,) * (len(shape) - input_ndim)).broadcast_to(shape) + + +def _get_variance(self, timestep, prev_timestep): + alpha_prod_t = torch.gather(self.alphas_cumprod, 0, timestep.cpu()).to(timestep.device) + alpha_prod_t_prev = torch.where( + prev_timestep.cpu() >= 0, + self.alphas_cumprod.gather(0, prev_timestep.cpu()), + self.final_alpha_cumprod, + ).to(timestep.device) + beta_prod_t = 1 - alpha_prod_t + beta_prod_t_prev = 1 - alpha_prod_t_prev + + variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev) + + return variance + + +def scheduler_step( + self, + model_output: torch.FloatTensor, + timestep: int, + sample: torch.FloatTensor, + eta: float = 0.0, + use_clipped_model_output: bool = False, + generator=None, + prev_sample: Optional[torch.FloatTensor] = None, +) -> DDPOSchedulerOutput: + """ + + Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion + process from the learned model outputs (most often the predicted noise). + + Args: + model_output (`torch.FloatTensor`): direct output from learned diffusion model. + timestep (`int`): current discrete timestep in the diffusion chain. + sample (`torch.FloatTensor`): + current instance of sample being created by diffusion process. + eta (`float`): weight of noise for added noise in diffusion step. + use_clipped_model_output (`bool`): if `True`, compute "corrected" `model_output` from the clipped + predicted original sample. Necessary because predicted original sample is clipped to [-1, 1] when + `self.config.clip_sample` is `True`. If no clipping has happened, "corrected" `model_output` would + coincide with the one provided as input and `use_clipped_model_output` will have not effect. + generator: random number generator. + variance_noise (`torch.FloatTensor`): instead of generating noise for the variance using `generator`, we + can directly provide the noise for the variance itself. This is useful for methods such as + CycleDiffusion. (https://arxiv.org/abs/2210.05559) + + Returns: + `DDPOSchedulerOutput`: the predicted sample at the previous timestep and the log probability of the sample + """ + + if self.num_inference_steps is None: + raise ValueError( + "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" + ) + + # See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf + # Ideally, read DDIM paper in-detail understanding + + # Notation ( -> + # - pred_noise_t -> e_theta(x_t, t) + # - pred_original_sample -> f_theta(x_t, t) or x_0 + # - std_dev_t -> sigma_t + # - eta -> η + # - pred_sample_direction -> "direction pointing to x_t" + # - pred_prev_sample -> "x_t-1" + + # 1. get previous step value (=t-1) + prev_timestep = timestep - self.config.num_train_timesteps // self.num_inference_steps + # to prevent OOB on gather + prev_timestep = torch.clamp(prev_timestep, 0, self.config.num_train_timesteps - 1) + + # 2. compute alphas, betas + alpha_prod_t = self.alphas_cumprod.gather(0, timestep.cpu()) + alpha_prod_t_prev = torch.where( + prev_timestep.cpu() >= 0, + self.alphas_cumprod.gather(0, prev_timestep.cpu()), + self.final_alpha_cumprod, + ) + alpha_prod_t = _left_broadcast(alpha_prod_t, sample.shape).to(sample.device) + alpha_prod_t_prev = _left_broadcast(alpha_prod_t_prev, sample.shape).to(sample.device) + + beta_prod_t = 1 - alpha_prod_t + + # 3. compute predicted original sample from predicted noise also called + # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf + if self.config.prediction_type == "epsilon": + pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) + pred_epsilon = model_output + elif self.config.prediction_type == "sample": + pred_original_sample = model_output + pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5) + elif self.config.prediction_type == "v_prediction": + pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output + pred_epsilon = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample + else: + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or" + " `v_prediction`" + ) + + # 4. Clip or threshold "predicted x_0" + if self.config.thresholding: + pred_original_sample = self._threshold_sample(pred_original_sample) + elif self.config.clip_sample: + pred_original_sample = pred_original_sample.clamp( + -self.config.clip_sample_range, self.config.clip_sample_range + ) + + # 5. compute variance: "sigma_t(η)" -> see formula (16) + # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1) + variance = _get_variance(self, timestep, prev_timestep) + std_dev_t = eta * variance ** (0.5) + std_dev_t = _left_broadcast(std_dev_t, sample.shape).to(sample.device) + + if use_clipped_model_output: + # the pred_epsilon is always re-derived from the clipped x_0 in Glide + pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5) + + # 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf + pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * pred_epsilon + + # 7. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf + prev_sample_mean = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction + + if prev_sample is not None and generator is not None: + raise ValueError( + "Cannot pass both generator and prev_sample. Please make sure that either `generator` or" + " `prev_sample` stays `None`." + ) + + if prev_sample is None: + variance_noise = randn_tensor( + model_output.shape, + generator=generator, + device=model_output.device, + dtype=model_output.dtype, + ) + prev_sample = prev_sample_mean + std_dev_t * variance_noise + + # log prob of prev_sample given prev_sample_mean and std_dev_t + log_prob = ( + -((prev_sample.detach() - prev_sample_mean) ** 2) / (2 * (std_dev_t**2)) + - torch.log(std_dev_t) + - torch.log(torch.sqrt(2 * torch.as_tensor(np.pi))) + ) + # mean along all but batch dimension + log_prob = log_prob.mean(dim=tuple(range(1, log_prob.ndim))) + + return DDPOSchedulerOutput(prev_sample.type(sample.dtype), log_prob) + + +# 1. The output type for call is different as the logprobs are now returned +# 2. An extra method called `scheduler_step` is added which is used to constraint the scheduler output +@torch.no_grad() +def pipeline_step( + self, + prompt: Optional[Union[str, List[str]]] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: int = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + guidance_rescale: float = 0.0, +): + r""" + Function invoked when calling the pipeline for generation. Args: prompt (`str` or `List[str]`, *optional*): The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. instead. height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): The height in pixels of the generated image. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). + guidance_rescale (`float`, *optional*, defaults to 0.7): + Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://arxiv.org/pdf/2305.08891.pdf) `guidance_scale` is defined as `φ` in equation 16. of + [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). + Guidance rescale factor should fix overexposure when using zero terminal SNR. + + Examples: + + Returns: + `DDPOPipelineOutput`: The generated image, the predicted latents used to generate the image and the associated log probabilities + """ + # 0. Default height and width to unet + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + height, + width, + callback_steps, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + ) + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + text_encoder_lora_scale = cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None + prompt_embeds = self._encode_prompt( + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=text_encoder_lora_scale, + ) + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 5. Prepare latent variables + num_channels_latents = self.unet.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 6. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + all_latents = [latents] + all_log_probs = [] + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + return_dict=False, + )[0] + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + if do_classifier_free_guidance and guidance_rescale > 0.0: + # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf + noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale) + + # compute the previous noisy sample x_t -> x_t-1 + scheduler_output = scheduler_step(self.scheduler, noise_pred, t, latents, eta) + latents = scheduler_output.latents + log_prob = scheduler_output.log_probs + + all_latents.append(latents) + all_log_probs.append(log_prob) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) + + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + else: + image = latents + has_nsfw_concept = None + + if has_nsfw_concept is None: + do_denormalize = [True] * image.shape[0] + else: + do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] + + image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) + + # Offload last model to CPU + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.final_offload_hook.offload() + + return DDPOPipelineOutput(image, all_latents, all_log_probs) + + +class DefaultDDPOStableDiffusionPipeline(DDPOStableDiffusionPipeline): + def __init__(self, pretrained_model_name: str, *, pretrained_model_revision: str = "main", use_lora: bool = True): + self.sd_pipeline = StableDiffusionPipeline.from_pretrained( + pretrained_model_name, revision=pretrained_model_revision + ) + + self.use_lora = use_lora + self.pretrained_model = pretrained_model_name + self.pretrained_revision = pretrained_model_revision + + try: + self.sd_pipeline.unet.load_attn_procs(pretrained_model_name, revision=pretrained_model_revision) + self.use_lora = True + except OSError: + if use_lora: + warnings.warn( + "If you are aware that the pretrained model has no lora weights to it, ignore this message. " + "Otherwise please check the if `pytorch_lora_weights.safetensors` exists in the model folder." + ) + + self.sd_pipeline.scheduler = DDIMScheduler.from_config(self.sd_pipeline.scheduler.config) + self.sd_pipeline.safety_checker = None + + # memory optimization + self.sd_pipeline.vae.requires_grad_(False) + self.sd_pipeline.text_encoder.requires_grad_(False) + self.sd_pipeline.unet.requires_grad_(not self.use_lora) + + def __call__(self, *args, **kwargs) -> DDPOPipelineOutput: + return pipeline_step(self.sd_pipeline, *args, **kwargs) + + def scheduler_step(self, *args, **kwargs) -> DDPOSchedulerOutput: + return scheduler_step(self.sd_pipeline.scheduler, *args, **kwargs) + + @property + def unet(self): + return self.sd_pipeline.unet + + @property + def vae(self): + return self.sd_pipeline.vae + + @property + def tokenizer(self): + return self.sd_pipeline.tokenizer + + @property + def scheduler(self): + return self.sd_pipeline.scheduler + + @property + def text_encoder(self): + return self.sd_pipeline.text_encoder + + @property + def autocast(self): + return contextlib.nullcontext if self.use_lora else None + + def save_pretrained(self, output_dir): + if self.use_lora: + self.sd_pipeline.unet.save_attn_procs(output_dir) + self.sd_pipeline.save_pretrained(output_dir) + + def set_progress_bar_config(self, *args, **kwargs): + self.sd_pipeline.set_progress_bar_config(*args, **kwargs) + + def get_trainable_layers(self): + if self.use_lora: + # Set correct lora layers + lora_attn_procs = {} + for name in self.sd_pipeline.unet.attn_processors.keys(): + cross_attention_dim = ( + None if name.endswith("attn1.processor") else self.sd_pipeline.unet.config.cross_attention_dim + ) + if name.startswith("mid_block"): + hidden_size = self.sd_pipeline.unet.config.block_out_channels[-1] + elif name.startswith("up_blocks"): + block_id = int(name[len("up_blocks.")]) + hidden_size = list(reversed(self.sd_pipeline.unet.config.block_out_channels))[block_id] + elif name.startswith("down_blocks"): + block_id = int(name[len("down_blocks.")]) + hidden_size = self.sd_pipeline.unet.config.block_out_channels[block_id] + + lora_attn_procs[name] = LoRAAttnProcessor( + hidden_size=hidden_size, cross_attention_dim=cross_attention_dim + ) + self.sd_pipeline.unet.set_attn_processor(lora_attn_procs) + return AttnProcsLayers(self.sd_pipeline.unet.attn_processors) + else: + return self.sd_pipeline.unet + + def save_checkpoint(self, models, weights, output_dir): + if len(models) != 1: + raise ValueError("Given how the trainable params were set, this should be of length 1") + if self.use_lora and isinstance(models[0], AttnProcsLayers): + self.sd_pipeline.unet.save_attn_procs(output_dir) + elif not self.use_lora and isinstance(models[0], UNet2DConditionModel): + models[0].save_pretrained(os.path.join(output_dir, "unet")) + else: + raise ValueError(f"Unknown model type {type(models[0])}") + + def load_checkpoint(self, models, input_dir): + if len(models) != 1: + raise ValueError("Given how the trainable params were set, this should be of length 1") + if self.use_lora and isinstance(models[0], AttnProcsLayers): + tmp_unet = UNet2DConditionModel.from_pretrained( + self.pretrained_model, + revision=self.pretrained_revision, + subfolder="unet", + ) + tmp_unet.load_attn_procs(input_dir) + models[0].load_state_dict(AttnProcsLayers(tmp_unet.attn_processors).state_dict()) + del tmp_unet + elif not self.use_lora and isinstance(models[0], UNet2DConditionModel): + load_model = UNet2DConditionModel.from_pretrained(input_dir, subfolder="unet") + models[0].register_to_config(**load_model.config) + models[0].load_state_dict(load_model.state_dict()) + del load_model + else: + raise ValueError(f"Unknown model type {type(models[0])}") diff --git a/trl/models/modeling_value_head.py b/trl/models/modeling_value_head.py new file mode 100644 index 0000000000000000000000000000000000000000..2771cc6ce2f5daad0a28933b94f39e670bd9350a --- /dev/null +++ b/trl/models/modeling_value_head.py @@ -0,0 +1,429 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import torch +import torch.nn as nn +from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM + +from .modeling_base import PreTrainedModelWrapper + + +class ValueHead(nn.Module): + r""" + The ValueHead class implements a head for GPT2 that returns a scalar for each output token. + """ + + def __init__(self, config, **kwargs): + super().__init__() + if not hasattr(config, "summary_dropout_prob"): + summary_dropout_prob = kwargs.pop("summary_dropout_prob", 0.1) + else: + summary_dropout_prob = config.summary_dropout_prob + + self.dropout = nn.Dropout(summary_dropout_prob) if summary_dropout_prob else nn.Identity() + + # some models such as OPT have a projection layer before the word embeddings - e.g. OPT-350m + if hasattr(config, "hidden_size"): + hidden_size = config.hidden_size + if hasattr(config, "word_embed_proj_dim"): + hidden_size = config.word_embed_proj_dim + elif hasattr(config, "is_encoder_decoder"): + if config.is_encoder_decoder and hasattr(config, "decoder"): + if hasattr(config.decoder, "hidden_size"): + hidden_size = config.decoder.hidden_size + + self.summary = nn.Linear(hidden_size, 1) + + self.flatten = nn.Flatten() + + def forward(self, hidden_states): + output = self.dropout(hidden_states) + + # For now force upcast in fp32 if needed. Let's keep the + # output in fp32 for numerical stability. + if output.dtype != self.summary.weight.dtype: + output = output.to(self.summary.weight.dtype) + + output = self.summary(output) + return output + + +class AutoModelForCausalLMWithValueHead(PreTrainedModelWrapper): + r""" + An autoregressive model with a value head in addition to the language model head. + This class inherits from `~trl.PreTrainedModelWrapper` and wraps a + `transformers.PreTrainedModel` class. The wrapper class supports classic functions + such as `from_pretrained`, `push_to_hub` and `generate`. To call a method of the wrapped + model, simply manipulate the `pretrained_model` attribute of this class. + + Class attributes: + - **transformers_parent_class** (`transformers.PreTrainedModel`) -- The parent class of the wrapped model. This + should be set to `transformers.AutoModelForCausalLM` for this class. + - **lm_head_namings** (`tuple`) -- A tuple of strings that are used to identify the language model head of the + wrapped model. This is set to `("lm_head", "embed_out")` for this class but can be changed for other models + in the future + - **supported_args** (`tuple`) -- A tuple of strings that are used to identify the arguments that are supported + by the `ValueHead` class. Currently, the supported args are: + - **summary_dropout_prob** (`float`, `optional`, defaults to `None`) -- The dropout probability for the + `ValueHead` class. + - **v_head_initializer_range** (`float`, `optional`, defaults to `0.2`) -- The initializer range for the + `ValueHead` if a specific initialization strategy is selected. + - **v_head_init_strategy** (`str`, `optional`, defaults to `None`) -- The initialization strategy for the + `ValueHead`. Currently, the supported strategies are: + - **`None`** -- Initializes the weights of the `ValueHead` with a random distribution. This is the default + strategy. + - **"normal"** -- Initializes the weights of the `ValueHead` with a normal distribution. + + """ + transformers_parent_class = AutoModelForCausalLM + lm_head_namings = ["lm_head", "embed_out"] + supported_args = ( + "summary_dropout_prob", + "v_head_initializer_range", + "v_head_init_strategy", + ) + + def __init__(self, pretrained_model, **kwargs): + r""" + Initializes the model. + + Args: + pretrained_model (`transformers.PreTrainedModel`): + The model to wrap. It should be a causal language model such as GPT2. + or any model mapped inside the `AutoModelForCausalLM` class. + kwargs (`dict`, `optional`): + Additional keyword arguments, that are passed to the `ValueHead` class. + """ + super().__init__(pretrained_model, **kwargs) + v_head_kwargs, _, _ = self._split_kwargs(kwargs) + + if not any(hasattr(self.pretrained_model, attribute) for attribute in self.lm_head_namings): + raise ValueError("The model does not have a language model head, please use a model that has one.") + + self.v_head = ValueHead(self.pretrained_model.config, **v_head_kwargs) + + self._init_weights(**v_head_kwargs) + + def _init_weights(self, **kwargs): + r""" + Initializes the weights of the value head. The default initialization strategy is random. + Users can pass a different initialization strategy by passing the `v_head_init_strategy` argument + when calling `.from_pretrained`. Supported strategies are: + - `normal`: initializes the weights with a normal distribution. + + Args: + **kwargs (`dict`, `optional`): + Additional keyword arguments, that are passed to the `ValueHead` class. These arguments + can contain the `v_head_init_strategy` argument as well as the `v_head_initializer_range` + argument. + """ + initializer_range = kwargs.pop("v_head_initializer_range", 0.2) + # random init by default + init_strategy = kwargs.pop("v_head_init_strategy", None) + if init_strategy is None: + # do nothing + pass + elif init_strategy == "normal": + self.v_head.summary.weight.data.normal_(mean=0.0, std=initializer_range) + self.v_head.summary.bias.data.zero_() + + def forward( + self, + input_ids=None, + past_key_values=None, + attention_mask=None, + **kwargs, + ): + r""" + Applies a forward pass to the wrapped model and returns the logits of the value head. + + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. + past_key_values (`tuple(tuple(torch.FloatTensor))`, `optional`): + Contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model + (see `past_key_values` input) to speed up sequential decoding. + attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, `optional`): + Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``: + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + kwargs (`dict`, `optional`): + Additional keyword arguments, that are passed to the wrapped model. + """ + kwargs["output_hidden_states"] = True # this had already been set in the LORA / PEFT examples + kwargs["past_key_values"] = past_key_values + + if self.is_peft_model and self.pretrained_model.active_peft_config.peft_type == "PREFIX_TUNING": + kwargs.pop("past_key_values") + + base_model_output = self.pretrained_model( + input_ids=input_ids, + attention_mask=attention_mask, + **kwargs, + ) + + last_hidden_state = base_model_output.hidden_states[-1] + lm_logits = base_model_output.logits + loss = base_model_output.loss + + if last_hidden_state.device != self.v_head.summary.weight.device: + last_hidden_state = last_hidden_state.to(self.v_head.summary.weight.device) + + value = self.v_head(last_hidden_state).squeeze(-1) + + # force upcast in fp32 if logits are in half-precision + if lm_logits.dtype != torch.float32: + lm_logits = lm_logits.float() + + return (lm_logits, loss, value) + + def generate(self, *args, **kwargs): + r""" + A simple wrapper around the `generate` method of the wrapped model. + Please refer to the [`generate`](https://huggingface.co/docs/transformers/internal/generation_utils) + method of the wrapped model for more information about the supported arguments. + + Args: + *args (`list`, *optional*): + Positional arguments passed to the `generate` method of the wrapped model. + **kwargs (`dict`, *optional*): + Keyword arguments passed to the `generate` method of the wrapped model. + """ + return self.pretrained_model.generate(*args, **kwargs) + + def state_dict(self, *args, **kwargs): + r""" + Returns the state dictionary of the model. We add the state dictionary of the value head + to the state dictionary of the wrapped model by prepending the key with `v_head.`. + """ + if not self.is_peft_model: + pretrained_model_state_dict = self.pretrained_model.state_dict(*args, **kwargs) + else: + # if it is a peft model, only save the v_head + pretrained_model_state_dict = {} + + v_head_state_dict = self.v_head.state_dict(*args, **kwargs) + for k, v in v_head_state_dict.items(): + pretrained_model_state_dict[f"v_head.{k}"] = v + return pretrained_model_state_dict + + def push_to_hub(self, *args, **kwargs): + setattr(self.pretrained_model, "v_head", self.v_head) + + return self.pretrained_model.push_to_hub(*args, **kwargs) + + def post_init(self, state_dict): + r""" + We add the state dictionary of the value head to the state dictionary of the wrapped model + by prepending the key with `v_head.`. This function removes the `v_head.` prefix from the + keys of the value head state dictionary. + """ + for k in list(state_dict.keys()): + if "v_head." in k: + state_dict[k.replace("v_head.", "")] = state_dict.pop(k) + self.v_head.load_state_dict(state_dict, strict=False) + del state_dict + + if hasattr(self.pretrained_model, "hf_device_map"): + if ( + "cpu" in self.pretrained_model.hf_device_map.values() + or "disk" in self.pretrained_model.hf_device_map.values() + ): + raise ValueError( + "The model is offloaded on CPU or disk - CPU & disk offloading is not supported for ValueHead models." + ) + + first_device = list(set(self.pretrained_model.hf_device_map.values()))[0] + + self.v_head = self.v_head.to(first_device) + + def set_device_hook(module, input, outputs): + new_output = () + for output in outputs: + if isinstance(output, torch.Tensor): + new_output += (output.to(first_device),) + else: + new_output += (output,) + return new_output + + self.register_forward_hook(set_device_hook) + + self.is_sequential_parallel = True + + +class AutoModelForSeq2SeqLMWithValueHead(PreTrainedModelWrapper): + r""" + A seq2seq model with a value head in addition to the language model head. + This class inherits from `~trl.PreTrainedModelWrapper` and wraps a + `transformers.PreTrainedModel` class. The wrapper class supports classic functions + such as `from_pretrained` and `push_to_hub` and also provides some additional + functionalities such as `generate`. + + Args: + pretrained_model (`transformers.PreTrainedModel`): + The model to wrap. It should be a causal language model such as GPT2. + or any model mapped inside the `AutoModelForSeq2SeqLM` class. + kwargs: + Additional keyword arguments passed along to the `ValueHead` class. + """ + transformers_parent_class = AutoModelForSeq2SeqLM + lm_head_namings = ["lm_head", "embed_out", "output_projection"] + supported_args = ( + "summary_dropout_prob", + "v_head_initializer_range", + "v_head_init_strategy", + ) + + def __init__(self, pretrained_model, **kwargs): + super().__init__(pretrained_model, **kwargs) + v_head_kwargs, _, _ = self._split_kwargs(kwargs) + self.is_encoder_decoder = True + + if not self._has_lm_head(): + raise ValueError("The model does not have a language model head, please use a model that has one.") + + self.v_head = ValueHead(self.pretrained_model.config, **v_head_kwargs) + + self._init_weights(**v_head_kwargs) + + def _has_lm_head(self): + # check module names of all modules inside `pretrained_model` to find the language model head + for name, module in self.pretrained_model.named_modules(): + if any(attribute in name for attribute in self.lm_head_namings): + return True + return False + + def post_init(self, state_dict): + r""" + We add the state dictionary of the value head to the state dictionary of the wrapped model + by prepending the key with `v_head.`. This function removes the `v_head.` prefix from the + keys of the value head state dictionary. + """ + for k in list(state_dict.keys()): + if "v_head." in k: + state_dict[k.replace("v_head.", "")] = state_dict.pop(k) + self.v_head.load_state_dict(state_dict, strict=False) + del state_dict + + if hasattr(self.pretrained_model, "hf_device_map"): + if ( + "cpu" in self.pretrained_model.hf_device_map.values() + or "disk" in self.pretrained_model.hf_device_map.values() + ): + raise ValueError( + "The model is offloaded on CPU or disk - CPU & disk offloading is not supported for ValueHead models." + ) + + # get the lm_head device + for name, module in self.pretrained_model.named_modules(): + if any(attribute in name for attribute in self.lm_head_namings): + lm_head_device = module.weight.device + break + + # put v_head on the same device as the lm_head to avoid issues + self.v_head = self.v_head.to(lm_head_device) + + def set_device_hook(module, input, outputs): + r""" + A hook that sets the device of the output of the model to the device of the first + parameter of the model. + + Args: + module (`nn.Module`): + The module to which the hook is attached. + input (`tuple`): + The input to the module. + outputs (`tuple`): + The output of the module. + """ + new_output = () + for output in outputs: + if isinstance(output, torch.Tensor): + new_output += (output.to(lm_head_device),) + else: + new_output += (output,) + return new_output + + self.register_forward_hook(set_device_hook) + self.is_sequential_parallel = True + + def state_dict(self, *args, **kwargs): + r""" + Returns the state dictionary of the model. We add the state dictionary of the value head + to the state dictionary of the wrapped model by prepending the key with `v_head.`. + """ + if not self.is_peft_model: + pretrained_model_state_dict = self.pretrained_model.state_dict(*args, **kwargs) + else: + # if it is a peft model, only save the v_head + pretrained_model_state_dict = {} + + v_head_state_dict = self.v_head.state_dict(*args, **kwargs) + for k, v in v_head_state_dict.items(): + pretrained_model_state_dict[f"v_head.{k}"] = v + return pretrained_model_state_dict + + def push_to_hub(self, *args, **kwargs): + setattr(self.pretrained_model, "v_head", self.v_head) + + return self.pretrained_model.push_to_hub(*args, **kwargs) + + def _init_weights(self, **kwargs): + r""" + We initialize the weights of the value head. + """ + initializer_range = kwargs.pop("v_head_initializer_range", 0.2) + # random init by default + init_strategy = kwargs.pop("v_head_init_strategy", None) + if init_strategy is None: + # do nothing + pass + elif init_strategy == "normal": + self.v_head.summary.weight.data.normal_(mean=0.0, std=initializer_range) + self.v_head.summary.bias.data.zero_() + + def forward( + self, + input_ids=None, + past_key_values=None, + attention_mask=None, + **kwargs, + ): + kwargs["past_key_values"] = past_key_values + if self.is_peft_model and self.pretrained_model.active_peft_config.peft_type == "PREFIX_TUNING": + kwargs.pop("past_key_values") + + base_model_output = self.pretrained_model( + input_ids=input_ids, + attention_mask=attention_mask, + output_hidden_states=True, # We force the model to output hidden states + **kwargs, + ) + + last_hidden_state = base_model_output.decoder_hidden_states[-1] + lm_logits = base_model_output.logits + loss = base_model_output.loss + + value = self.v_head(last_hidden_state).squeeze(-1) + + # force upcast in fp32 if logits are in half-precision + if lm_logits.dtype != torch.float32: + lm_logits = lm_logits.float() + + return (lm_logits, loss, value) + + def generate(self, *args, **kwargs): + r""" + We call `generate` on the wrapped model. + """ + return self.pretrained_model.generate(*args, **kwargs) diff --git a/trl/trainer/__init__.py b/trl/trainer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e81705fbc2adc67ee2682977a0f374ab08dea530 --- /dev/null +++ b/trl/trainer/__init__.py @@ -0,0 +1,44 @@ +# flake8: noqa + +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# There is a circular import in the PPOTrainer if we let isort sort these +# isort: off +from .utils import ( + AdaptiveKLController, + FixedKLController, + ConstantLengthDataset, + DataCollatorForCompletionOnlyLM, + RunningMoments, + disable_dropout_in_model, +) + +# isort: on + +from ..import_utils import is_diffusers_available +from .base import BaseTrainer +from .ddpo_config import DDPOConfig + + +if is_diffusers_available(): + from .ddpo_trainer import DDPOTrainer + +from .dpo_trainer import DPOTrainer +from .iterative_sft_trainer import IterativeSFTTrainer +from .ppo_config import PPOConfig +from .ppo_trainer import PPOTrainer +from .reward_trainer import RewardTrainer, compute_accuracy +from .sft_trainer import SFTTrainer +from .training_configs import RewardConfig diff --git a/trl/trainer/base.py b/trl/trainer/base.py new file mode 100644 index 0000000000000000000000000000000000000000..f0314cb987fcf5a520ed1ab1ad0a7eb107f18acc --- /dev/null +++ b/trl/trainer/base.py @@ -0,0 +1,46 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from huggingface_hub import PyTorchModelHubMixin + + +class BaseTrainer(PyTorchModelHubMixin): + r""" + Base class for all trainers - this base class implements the basic functions that we + need for a trainer. + + The trainer needs to have the following functions: + - step: takes in a batch of data and performs a step of training + - loss: takes in a batch of data and returns the loss + - compute_rewards: takes in a batch of data and returns the rewards + - _build_models_and_tokenizer: builds the models and tokenizer + - _build_dataset: builds the dataset + Each user is expected to implement their own trainer class that inherits from this base + if they want to use a new training algorithm. + """ + + def __init__(self, config): + self.config = config + + def step(self, *args): + raise NotImplementedError("Not implemented") + + def loss(self, *args): + raise NotImplementedError("Not implemented") + + def compute_rewards(self, *args): + raise NotImplementedError("Not implemented") + + def _save_pretrained(self, save_directory): + raise NotImplementedError("Not implemented") diff --git a/trl/trainer/ddpo_config.py b/trl/trainer/ddpo_config.py new file mode 100644 index 0000000000000000000000000000000000000000..310861381465fcc1a2b73e48712815617057bb84 --- /dev/null +++ b/trl/trainer/ddpo_config.py @@ -0,0 +1,120 @@ +import os +import sys +import warnings +from dataclasses import dataclass, field +from typing import Literal, Optional + +from ..core import flatten_dict +from ..import_utils import is_bitsandbytes_available, is_torchvision_available + + +@dataclass +class DDPOConfig: + """ + Configuration class for DDPOTrainer + """ + + # common parameters + exp_name: str = os.path.basename(sys.argv[0])[: -len(".py")] + """the name of this experiment (by default is the file name without the extension name)""" + run_name: Optional[str] = "" + """Run name for wandb logging and checkpoint saving.""" + seed: int = 0 + """Seed value for random generations""" + log_with: Optional[Literal["wandb", "tensorboard"]] = None + """Log with either 'wandb' or 'tensorboard', check https://huggingface.co/docs/accelerate/usage_guides/tracking for more details""" + tracker_kwargs: dict = field(default_factory=dict) + """Keyword arguments for the tracker (e.g. wandb_project)""" + accelerator_kwargs: dict = field(default_factory=dict) + """Keyword arguments for the accelerator""" + project_kwargs: dict = field(default_factory=dict) + """Keyword arguments for the accelerator project config (e.g. `logging_dir`)""" + tracker_project_name: str = "trl" + """Name of project to use for tracking""" + logdir: str = "logs" + """Top-level logging directory for checkpoint saving.""" + + # hyperparameters + num_epochs: int = 100 + """Number of epochs to train.""" + save_freq: int = 1 + """Number of epochs between saving model checkpoints.""" + num_checkpoint_limit: int = 5 + """Number of checkpoints to keep before overwriting old ones.""" + mixed_precision: str = "fp16" + """Mixed precision training.""" + allow_tf32: bool = True + """Allow tf32 on Ampere GPUs.""" + resume_from: Optional[str] = "" + """Resume training from a checkpoint.""" + sample_num_steps: int = 50 + """Number of sampler inference steps.""" + sample_eta: float = 1.0 + """Eta parameter for the DDIM sampler.""" + sample_guidance_scale: float = 5.0 + """Classifier-free guidance weight.""" + sample_batch_size: int = 1 + """Batch size (per GPU!) to use for sampling.""" + sample_num_batches_per_epoch: int = 2 + """Number of batches to sample per epoch.""" + train_batch_size: int = 1 + """Batch size (per GPU!) to use for training.""" + train_use_8bit_adam: bool = False + """Whether to use the 8bit Adam optimizer from bitsandbytes.""" + train_learning_rate: float = 3e-4 + """Learning rate.""" + train_adam_beta1: float = 0.9 + """Adam beta1.""" + train_adam_beta2: float = 0.999 + """Adam beta2.""" + train_adam_weight_decay: float = 1e-4 + """Adam weight decay.""" + train_adam_epsilon: float = 1e-8 + """Adam epsilon.""" + train_gradient_accumulation_steps: int = 1 + """Number of gradient accumulation steps.""" + train_max_grad_norm: float = 1.0 + """Maximum gradient norm for gradient clipping.""" + train_num_inner_epochs: int = 1 + """Number of inner epochs per outer epoch.""" + train_cfg: bool = True + """Whether or not to use classifier-free guidance during training.""" + train_adv_clip_max: float = 5 + """Clip advantages to the range.""" + train_clip_range: float = 1e-4 + """The PPO clip range.""" + train_timestep_fraction: float = 1.0 + """The fraction of timesteps to train on.""" + per_prompt_stat_tracking: bool = False + """Whether to track statistics for each prompt separately.""" + per_prompt_stat_tracking_buffer_size: int = 16 + """Number of reward values to store in the buffer for each prompt.""" + per_prompt_stat_tracking_min_count: int = 16 + """The minimum number of reward values to store in the buffer.""" + async_reward_computation: bool = False + """Whether to compute rewards asynchronously.""" + max_workers: int = 2 + """The maximum number of workers to use for async reward computation.""" + negative_prompts: Optional[str] = "" + """Comma-separated list of prompts to use as negative examples.""" + + def to_dict(self): + output_dict = {} + for key, value in self.__dict__.items(): + output_dict[key] = value + return flatten_dict(output_dict) + + def __post_init__(self): + if self.log_with not in ["wandb", "tensorboard"]: + warnings.warn( + ("Accelerator tracking only supports image logging if `log_with` is set to 'wandb' or 'tensorboard'.") + ) + + if self.log_with == "wandb" and not is_torchvision_available(): + warnings.warn("Wandb image logging requires torchvision to be installed") + + if self.train_use_8bit_adam and not is_bitsandbytes_available(): + raise ImportError( + "You need to install bitsandbytes to use 8bit Adam. " + "You can install it with `pip install bitsandbytes`." + ) diff --git a/trl/trainer/ddpo_trainer.py b/trl/trainer/ddpo_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..0f1cee12f1c02917e634b11f358ba839a39fd910 --- /dev/null +++ b/trl/trainer/ddpo_trainer.py @@ -0,0 +1,576 @@ +# Copyright 2023 DDPO-pytorch authors (Kevin Black), metric-space, The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +from collections import defaultdict +from concurrent import futures +from typing import Any, Callable, Optional, Tuple +from warnings import warn + +import torch +from accelerate import Accelerator +from accelerate.logging import get_logger +from accelerate.utils import ProjectConfiguration, set_seed + +from ..models import DDPOStableDiffusionPipeline +from . import BaseTrainer, DDPOConfig +from .utils import PerPromptStatTracker + + +logger = get_logger(__name__) + + +class DDPOTrainer(BaseTrainer): + """ + The DDPOTrainer uses Deep Diffusion Policy Optimization to optimise diffusion models. + Note, this trainer is heavily inspired by the work here: https://github.com/kvablack/ddpo-pytorch + As of now only Stable Diffusion based pipelines are supported + + Attributes: + **config** (`DDPOConfig`) -- Configuration object for DDPOTrainer. Check the documentation of `PPOConfig` for more + details. + **reward_function** (Callable[[torch.Tensor, Tuple[str], Tuple[Any]], torch.Tensor]) -- Reward function to be used + **prompt_function** (Callable[[], Tuple[str, Any]]) -- Function to generate prompts to guide model + **sd_pipeline** (`DDPOStableDiffusionPipeline`) -- Stable Diffusion pipeline to be used for training. + **image_samples_hook** (Optional[Callable[[Any, Any, Any], Any]]) -- Hook to be called to log images + """ + + def __init__( + self, + config: DDPOConfig, + reward_function: Callable[[torch.Tensor, Tuple[str], Tuple[Any]], torch.Tensor], + prompt_function: Callable[[], Tuple[str, Any]], + sd_pipeline: DDPOStableDiffusionPipeline, + image_samples_hook: Optional[Callable[[Any, Any, Any], Any]] = None, + ): + if image_samples_hook is None: + warn("No image_samples_hook provided; no images will be logged") + + self.prompt_fn = prompt_function + self.reward_fn = reward_function + self.config = config + self.image_samples_callback = image_samples_hook + + accelerator_project_config = ProjectConfiguration(**self.config.project_kwargs) + + if self.config.resume_from: + self.config.resume_from = os.path.normpath(os.path.expanduser(self.config.resume_from)) + if "checkpoint_" not in os.path.basename(self.config.resume_from): + # get the most recent checkpoint in this directory + checkpoints = list( + filter( + lambda x: "checkpoint_" in x, + os.listdir(self.config.resume_from), + ) + ) + if len(checkpoints) == 0: + raise ValueError(f"No checkpoints found in {self.config.resume_from}") + checkpoint_numbers = sorted([int(x.split("_")[-1]) for x in checkpoints]) + self.config.resume_from = os.path.join( + self.config.resume_from, + f"checkpoint_{checkpoint_numbers[-1]}", + ) + + accelerator_project_config.iteration = checkpoint_numbers[-1] + 1 + + # number of timesteps within each trajectory to train on + self.num_train_timesteps = int(self.config.sample_num_steps * self.config.train_timestep_fraction) + + self.accelerator = Accelerator( + log_with=self.config.log_with, + mixed_precision=self.config.mixed_precision, + project_config=accelerator_project_config, + # we always accumulate gradients across timesteps; we want config.train.gradient_accumulation_steps to be the + # number of *samples* we accumulate across, so we need to multiply by the number of training timesteps to get + # the total number of optimizer steps to accumulate across. + gradient_accumulation_steps=self.config.train_gradient_accumulation_steps * self.num_train_timesteps, + **self.config.accelerator_kwargs, + ) + + is_okay, message = self._config_check() + if not is_okay: + raise ValueError(message) + + is_using_tensorboard = config.log_with is not None and config.log_with == "tensorboard" + + if self.accelerator.is_main_process: + self.accelerator.init_trackers( + self.config.tracker_project_name, + config=dict(ddpo_trainer_config=config.to_dict()) if not is_using_tensorboard else config.to_dict(), + init_kwargs=self.config.tracker_kwargs, + ) + + logger.info(f"\n{config}") + + set_seed(self.config.seed, device_specific=True) + + self.sd_pipeline = sd_pipeline + + self.sd_pipeline.set_progress_bar_config( + position=1, + disable=not self.accelerator.is_local_main_process, + leave=False, + desc="Timestep", + dynamic_ncols=True, + ) + + # For mixed precision training we cast all non-trainable weights (vae, non-lora text_encoder and non-lora unet) to half-precision + # as these weights are only used for inference, keeping weights in full precision is not required. + if self.accelerator.mixed_precision == "fp16": + inference_dtype = torch.float16 + elif self.accelerator.mixed_precision == "bf16": + inference_dtype = torch.bfloat16 + else: + inference_dtype = torch.float32 + + self.sd_pipeline.vae.to(self.accelerator.device, dtype=inference_dtype) + self.sd_pipeline.text_encoder.to(self.accelerator.device, dtype=inference_dtype) + self.sd_pipeline.unet.to(self.accelerator.device, dtype=inference_dtype) + + trainable_layers = self.sd_pipeline.get_trainable_layers() + + self.accelerator.register_save_state_pre_hook(self._save_model_hook) + self.accelerator.register_load_state_pre_hook(self._load_model_hook) + + # Enable TF32 for faster training on Ampere GPUs, + # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices + if self.config.allow_tf32: + torch.backends.cuda.matmul.allow_tf32 = True + + self.optimizer = self._setup_optimizer(trainable_layers.parameters()) + + self.neg_prompt_embed = self.sd_pipeline.text_encoder( + self.sd_pipeline.tokenizer( + [""] if self.config.negative_prompts is None else self.config.negative_prompts, + return_tensors="pt", + padding="max_length", + truncation=True, + max_length=self.sd_pipeline.tokenizer.model_max_length, + ).input_ids.to(self.accelerator.device) + )[0] + + if config.per_prompt_stat_tracking: + self.stat_tracker = PerPromptStatTracker( + config.per_prompt_stat_tracking_buffer_size, + config.per_prompt_stat_tracking_min_count, + ) + + # NOTE: for some reason, autocast is necessary for non-lora training but for lora training it isn't necessary and it uses + # more memory + self.autocast = self.sd_pipeline.autocast or self.accelerator.autocast + + self.trainable_layers, self.optimizer = self.accelerator.prepare(trainable_layers, self.optimizer) + + if self.config.async_reward_computation: + self.executor = futures.ThreadPoolExecutor(max_workers=config.max_workers) + + if config.resume_from: + logger.info(f"Resuming from {config.resume_from}") + self.accelerator.load_state(config.resume_from) + self.first_epoch = int(config.resume_from.split("_")[-1]) + 1 + else: + self.first_epoch = 0 + + def compute_rewards(self, prompt_image_pairs, is_async=False): + if not is_async: + rewards = [] + for images, prompts, prompt_metadata in prompt_image_pairs: + reward, reward_metadata = self.reward_fn(images, prompts, prompt_metadata) + rewards.append( + ( + torch.as_tensor(reward, device=self.accelerator.device), + reward_metadata, + ) + ) + else: + rewards = self.executor.map(lambda x: self.reward_fn(*x), prompt_image_pairs) + rewards = [ + (torch.as_tensor(reward.result(), device=self.accelerator.device), reward_metadata.result()) + for reward, reward_metadata in rewards + ] + + return zip(*rewards) + + def step(self, epoch: int, global_step: int): + """ + Perform a single step of training. + + Args: + epoch (int): The current epoch. + global_step (int): The current global step. + + Side Effects: + - Model weights are updated + - Logs the statistics to the accelerator trackers. + - If `self.image_samples_callback` is not None, it will be called with the prompt_image_pairs, global_step, and the accelerator tracker. + + Returns: + global_step (int): The updated global step. + + """ + samples, prompt_image_data = self._generate_samples( + iterations=self.config.sample_num_batches_per_epoch, + batch_size=self.config.sample_batch_size, + ) + + # collate samples into dict where each entry has shape (num_batches_per_epoch * sample.batch_size, ...) + samples = {k: torch.cat([s[k] for s in samples]) for k in samples[0].keys()} + rewards, rewards_metadata = self.compute_rewards( + prompt_image_data, is_async=self.config.async_reward_computation + ) + + for i, image_data in enumerate(prompt_image_data): + image_data.extend([rewards[i], rewards_metadata[i]]) + + if self.image_samples_callback is not None: + self.image_samples_callback(prompt_image_data, global_step, self.accelerator.trackers[0]) + + rewards = torch.cat(rewards) + rewards = self.accelerator.gather(rewards).cpu().numpy() + + self.accelerator.log( + { + "reward": rewards, + "epoch": epoch, + "reward_mean": rewards.mean(), + "reward_std": rewards.std(), + }, + step=global_step, + ) + + if self.config.per_prompt_stat_tracking: + # gather the prompts across processes + prompt_ids = self.accelerator.gather(samples["prompt_ids"]).cpu().numpy() + prompts = self.sd_pipeline.tokenizer.batch_decode(prompt_ids, skip_special_tokens=True) + advantages = self.stat_tracker.update(prompts, rewards) + else: + advantages = (rewards - rewards.mean()) / (rewards.std() + 1e-8) + + # ungather advantages; keep the entries corresponding to the samples on this process + samples["advantages"] = ( + torch.as_tensor(advantages) + .reshape(self.accelerator.num_processes, -1)[self.accelerator.process_index] + .to(self.accelerator.device) + ) + + del samples["prompt_ids"] + + total_batch_size, num_timesteps = samples["timesteps"].shape + + for inner_epoch in range(self.config.train_num_inner_epochs): + # shuffle samples along batch dimension + perm = torch.randperm(total_batch_size, device=self.accelerator.device) + samples = {k: v[perm] for k, v in samples.items()} + + # shuffle along time dimension independently for each sample + # still trying to understand the code below + perms = torch.stack( + [torch.randperm(num_timesteps, device=self.accelerator.device) for _ in range(total_batch_size)] + ) + + for key in ["timesteps", "latents", "next_latents", "log_probs"]: + samples[key] = samples[key][ + torch.arange(total_batch_size, device=self.accelerator.device)[:, None], + perms, + ] + + original_keys = samples.keys() + original_values = samples.values() + # rebatch them as user defined train_batch_size is different from sample_batch_size + reshaped_values = [v.reshape(-1, self.config.train_batch_size, *v.shape[1:]) for v in original_values] + + # Transpose the list of original values + transposed_values = zip(*reshaped_values) + # Create new dictionaries for each row of transposed values + samples_batched = [dict(zip(original_keys, row_values)) for row_values in transposed_values] + + self.sd_pipeline.unet.train() + global_step = self._train_batched_samples(inner_epoch, epoch, global_step, samples_batched) + # ensure optimization step at the end of the inner epoch + if not self.accelerator.sync_gradients: + raise ValueError( + "Optimization step should have been performed by this point. Please check calculated gradient accumulation settings." + ) + + if epoch != 0 and epoch % self.config.save_freq == 0 and self.accelerator.is_main_process: + self.accelerator.save_state() + + return global_step + + def calculate_loss(self, latents, timesteps, next_latents, log_probs, advantages, embeds): + """ + Calculate the loss for a batch of an unpacked sample + + Args: + latents (torch.Tensor): + The latents sampled from the diffusion model, shape: [batch_size, num_steps, ...] + timesteps (torch.Tensor): + The timesteps sampled from the diffusion model, shape: [batch_size] + next_latents (torch.Tensor): + The next latents sampled from the diffusion model, shape: [batch_size, num_steps, ...] + log_probs (torch.Tensor): + The log probabilities of the latents, shape: [batch_size] + advantages (torch.Tensor): + The advantages of the latents, shape: [batch_size] + embeds (torch.Tensor): + The embeddings of the prompts, shape: [2*batch_size or batch_size, ...] + Note: the "or" is because if train_cfg is True, the expectation is that negative prompts are concatenated to the embeds + + Returns: + loss (torch.Tensor), approx_kl (torch.Tensor), clipfrac (torch.Tensor) + (all of these are of shape (1,)) + """ + with self.autocast(): + if self.config.train_cfg: + noise_pred = self.sd_pipeline.unet( + torch.cat([latents] * 2), + torch.cat([timesteps] * 2), + embeds, + ).sample + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.config.sample_guidance_scale * ( + noise_pred_text - noise_pred_uncond + ) + else: + noise_pred = self.sd_pipeline.unet( + latents, + timesteps, + embeds, + ).sample + # compute the log prob of next_latents given latents under the current model + + scheduler_step_output = self.sd_pipeline.scheduler_step( + noise_pred, + timesteps, + latents, + eta=self.config.sample_eta, + prev_sample=next_latents, + ) + + log_prob = scheduler_step_output.log_probs + + advantages = torch.clamp( + advantages, + -self.config.train_adv_clip_max, + self.config.train_adv_clip_max, + ) + + ratio = torch.exp(log_prob - log_probs) + + loss = self.loss(advantages, self.config.train_clip_range, ratio) + + approx_kl = 0.5 * torch.mean((log_prob - log_probs) ** 2) + + clipfrac = torch.mean((torch.abs(ratio - 1.0) > self.config.train_clip_range).float()) + + return loss, approx_kl, clipfrac + + def loss( + self, + advantages: torch.Tensor, + clip_range: float, + ratio: torch.Tensor, + ): + unclipped_loss = -advantages * ratio + clipped_loss = -advantages * torch.clamp( + ratio, + 1.0 - clip_range, + 1.0 + clip_range, + ) + return torch.mean(torch.maximum(unclipped_loss, clipped_loss)) + + def _setup_optimizer(self, trainable_layers_parameters): + if self.config.train_use_8bit_adam: + import bitsandbytes + + optimizer_cls = bitsandbytes.optim.AdamW8bit + else: + optimizer_cls = torch.optim.AdamW + + return optimizer_cls( + trainable_layers_parameters, + lr=self.config.train_learning_rate, + betas=(self.config.train_adam_beta1, self.config.train_adam_beta2), + weight_decay=self.config.train_adam_weight_decay, + eps=self.config.train_adam_epsilon, + ) + + def _save_model_hook(self, models, weights, output_dir): + self.sd_pipeline.save_checkpoint(models, weights, output_dir) + weights.pop() # ensures that accelerate doesn't try to handle saving of the model + + def _load_model_hook(self, models, input_dir): + self.sd_pipeline.load_checkpoint(models, input_dir) + models.pop() # ensures that accelerate doesn't try to handle loading of the model + + def _generate_samples(self, iterations, batch_size): + """ + Generate samples from the model + + Args: + iterations (int): Number of iterations to generate samples for + batch_size (int): Batch size to use for sampling + + Returns: + samples (List[Dict[str, torch.Tensor]]), prompt_image_pairs (List[List[Any]]) + """ + samples = [] + prompt_image_pairs = [] + self.sd_pipeline.unet.eval() + + sample_neg_prompt_embeds = self.neg_prompt_embed.repeat(batch_size, 1, 1) + + for _ in range(iterations): + prompts, prompt_metadata = zip(*[self.prompt_fn() for _ in range(batch_size)]) + + prompt_ids = self.sd_pipeline.tokenizer( + prompts, + return_tensors="pt", + padding="max_length", + truncation=True, + max_length=self.sd_pipeline.tokenizer.model_max_length, + ).input_ids.to(self.accelerator.device) + prompt_embeds = self.sd_pipeline.text_encoder(prompt_ids)[0] + + with self.autocast(): + sd_output = self.sd_pipeline( + prompt_embeds=prompt_embeds, + negative_prompt_embeds=sample_neg_prompt_embeds, + num_inference_steps=self.config.sample_num_steps, + guidance_scale=self.config.sample_guidance_scale, + eta=self.config.sample_eta, + output_type="pt", + ) + + images = sd_output.images + latents = sd_output.latents + log_probs = sd_output.log_probs + + latents = torch.stack(latents, dim=1) # (batch_size, num_steps + 1, ...) + log_probs = torch.stack(log_probs, dim=1) # (batch_size, num_steps, 1) + timesteps = self.sd_pipeline.scheduler.timesteps.repeat(batch_size, 1) # (batch_size, num_steps) + + samples.append( + { + "prompt_ids": prompt_ids, + "prompt_embeds": prompt_embeds, + "timesteps": timesteps, + "latents": latents[:, :-1], # each entry is the latent before timestep t + "next_latents": latents[:, 1:], # each entry is the latent after timestep t + "log_probs": log_probs, + "negative_prompt_embeds": sample_neg_prompt_embeds, + } + ) + prompt_image_pairs.append([images, prompts, prompt_metadata]) + + return samples, prompt_image_pairs + + def _train_batched_samples(self, inner_epoch, epoch, global_step, batched_samples): + """ + Train on a batch of samples. Main training segment + + Args: + inner_epoch (int): The current inner epoch + epoch (int): The current epoch + global_step (int): The current global step + batched_samples (List[Dict[str, torch.Tensor]]): The batched samples to train on + + Side Effects: + - Model weights are updated + - Logs the statistics to the accelerator trackers. + + Returns: + global_step (int): The updated global step + """ + info = defaultdict(list) + for i, sample in enumerate(batched_samples): + if self.config.train_cfg: + # concat negative prompts to sample prompts to avoid two forward passes + embeds = torch.cat([sample["negative_prompt_embeds"], sample["prompt_embeds"]]) + else: + embeds = sample["prompt_embeds"] + + for j in range(self.num_train_timesteps): + with self.accelerator.accumulate(self.sd_pipeline.unet): + loss, approx_kl, clipfrac = self.calculate_loss( + sample["latents"][:, j], + sample["timesteps"][:, j], + sample["next_latents"][:, j], + sample["log_probs"][:, j], + sample["advantages"], + embeds, + ) + info["approx_kl"].append(approx_kl) + info["clipfrac"].append(clipfrac) + info["loss"].append(loss) + + self.accelerator.backward(loss) + if self.accelerator.sync_gradients: + self.accelerator.clip_grad_norm_( + self.trainable_layers.parameters(), + self.config.train_max_grad_norm, + ) + self.optimizer.step() + self.optimizer.zero_grad() + + # Checks if the accelerator has performed an optimization step behind the scenes + if self.accelerator.sync_gradients: + # log training-related stuff + info = {k: torch.mean(torch.stack(v)) for k, v in info.items()} + info = self.accelerator.reduce(info, reduction="mean") + info.update({"epoch": epoch, "inner_epoch": inner_epoch}) + self.accelerator.log(info, step=global_step) + global_step += 1 + info = defaultdict(list) + return global_step + + def _config_check(self) -> Tuple[bool, str]: + samples_per_epoch = ( + self.config.sample_batch_size * self.accelerator.num_processes * self.config.sample_num_batches_per_epoch + ) + total_train_batch_size = ( + self.config.train_batch_size + * self.accelerator.num_processes + * self.config.train_gradient_accumulation_steps + ) + + if not self.config.sample_batch_size >= self.config.train_batch_size: + return ( + False, + f"Sample batch size ({self.config.sample_batch_size}) must be greater than or equal to the train batch size ({self.config.train_batch_size})", + ) + if not self.config.sample_batch_size % self.config.train_batch_size == 0: + return ( + False, + f"Sample batch size ({self.config.sample_batch_size}) must be divisible by the train batch size ({self.config.train_batch_size})", + ) + if not samples_per_epoch % total_train_batch_size == 0: + return ( + False, + f"Number of samples per epoch ({samples_per_epoch}) must be divisible by the total train batch size ({total_train_batch_size})", + ) + return True, "" + + def train(self, epochs: Optional[int] = None): + """ + Train the model for a given number of epochs + """ + global_step = 0 + if epochs is None: + epochs = self.config.num_epochs + for epoch in range(self.first_epoch, epochs): + global_step = self.step(epoch, global_step) + + def _save_pretrained(self, save_directory): + self.sd_pipeline.save_pretrained(save_directory) diff --git a/trl/trainer/dpo_trainer.py b/trl/trainer/dpo_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..0015f42af0bf9d278e5ba10fb747414d93aaef7e --- /dev/null +++ b/trl/trainer/dpo_trainer.py @@ -0,0 +1,782 @@ +# DPO Authors: Rafael Rafailov, Archit Sharma, Eric Mitchell, Stefano Ermon, Christopher D. Manning, and Chelsea Finn 2023 +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import inspect +import random +import warnings +from collections import defaultdict +from copy import deepcopy +from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F +from accelerate.utils import is_deepspeed_available +from datasets import Dataset +from torch.utils.data import DataLoader +from transformers import ( + AutoModelForCausalLM, + DataCollator, + PreTrainedModel, + PreTrainedTokenizerBase, + Trainer, + TrainingArguments, +) +from transformers.trainer_callback import TrainerCallback +from transformers.trainer_utils import EvalLoopOutput + +from ..import_utils import is_peft_available, is_wandb_available +from ..models import PreTrainedModelWrapper, create_reference_model +from .utils import DPODataCollatorWithPadding, disable_dropout_in_model, pad_to_length + + +if is_peft_available(): + from peft import PeftModel, get_peft_model, prepare_model_for_kbit_training + + +if is_wandb_available(): + import wandb + +if is_deepspeed_available(): + import deepspeed + + +class DPOTrainer(Trainer): + r""" + Initialize DPOTrainer. + + Args: + model (`transformers.PreTrainedModel`): + The model to train, preferably an `AutoModelForSequenceClassification`. + ref_model (`PreTrainedModelWrapper`): + Hugging Face transformer model with a casual language modelling head. Used for implicit reward computation and loss. If no + reference model is provided, the trainer will create a reference model with the same architecture as the model to be optimized. + beta (`float`, defaults to 0.1): + The beta factor in DPO loss. Higher beta means less divergence from the initial policy. For the IPO loss, beta is the regularization parameter denoted by tau in the paper. + label_smoothing (`float`, defaults to 0): + The robust DPO label smoothing parameter from the [cDPO](https://ericmitchell.ai/cdpo.pdf) report that should be between 0 and 0.5. + loss_type (`str`, defaults to `"sigmoid"`): + The type of DPO loss to use. Either `"sigmoid"` the default DPO loss,`"hinge"` loss from [SLiC](https://arxiv.org/abs/2305.10425) paper, `"ipo"` from [IPO](https://arxiv.org/abs/2310.12036) paper, or `"kto"` from the HALOs [report](https://github.com/ContextualAI/HALOs/blob/main/assets/report.pdf). + args (`transformers.TrainingArguments`): + The arguments to use for training. + data_collator (`transformers.DataCollator`): + The data collator to use for training. If None is specified, the default data collator (`DPODataCollatorWithPadding`) will be used + which will pad the sequences to the maximum length of the sequences in the batch, given a dataset of paired sequences. + label_pad_token_id (`int`, defaults to `-100`): + The label pad token id. This argument is required if you want to use the default data collator. + padding_value (`int`, defaults to `0`): + The padding value. This argument is required if you want to use the default data collator. + truncation_mode (`str`, defaults to `keep_end`): + The truncation mode to use, either `keep_end` or `keep_start`. This argument is required if you want to use the default data collator. + train_dataset (`datasets.Dataset`): + The dataset to use for training. + eval_dataset (`datasets.Dataset`): + The dataset to use for evaluation. + tokenizer (`transformers.PreTrainedTokenizerBase`): + The tokenizer to use for training. This argument is required if you want to use the default data collator. + model_init (`Callable[[], transformers.PreTrainedModel]`): + The model initializer to use for training. If None is specified, the default model initializer will be used. + callbacks (`List[transformers.TrainerCallback]`): + The callbacks to use for training. + optimizers (`Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`): + The optimizer and scheduler to use for training. + preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`): + The function to use to preprocess the logits before computing the metrics. + max_length (`int`, defaults to `None`): + The maximum length of the sequences in the batch. This argument is required if you want to use the default data collator. + max_prompt_length (`int`, defaults to `None`): + The maximum length of the prompt. This argument is required if you want to use the default data collator. + max_target_length (`int`, defaults to `None`): + The maximum length of the target. This argument is required if you want to use the default data collator and your model is an encoder-decoder. + peft_config (`Dict`, defaults to `None`): + The PEFT configuration to use for training. If you pass a PEFT configuration, the model will be wrapped in a PEFT model. + is_encoder_decoder (`Optional[bool]`, `optional`, defaults to `None`): + If no model is provided, we need to know if the model_init returns an encoder-decoder. + disable_dropout (`bool`, defaults to `True`): + Whether or not to disable dropouts in `model` and `ref_model`. + generate_during_eval (`bool`, defaults to `False`): + Whether to sample and log generations during evaluation step. + compute_metrics (`Callable[[EvalPrediction], Dict]`, *optional*): + The function to use to compute the metrics. Must take a `EvalPrediction` and return + a dictionary string to metric values. + model_init_kwargs: (`Optional[Dict]`, *optional*): + Dict of Optional kwargs to pass when instantiating the model from a string + ref_model_init_kwargs: (`Optional[Dict]`, *optional*): + Dict of Optional kwargs to pass when instantiating the ref model from a string + + """ + + def __init__( + self, + model: Union[PreTrainedModel, nn.Module, str] = None, + ref_model: Optional[Union[PreTrainedModel, nn.Module, str]] = None, + beta: float = 0.1, + label_smoothing: float = 0, + loss_type: Literal["sigmoid", "hinge", "ipo", "kto"] = "sigmoid", + args: TrainingArguments = None, + data_collator: Optional[DataCollator] = None, + label_pad_token_id: int = -100, + padding_value: int = 0, + truncation_mode: str = "keep_end", + train_dataset: Optional[Dataset] = None, + eval_dataset: Optional[Union[Dataset, Dict[str, Dataset]]] = None, + tokenizer: Optional[PreTrainedTokenizerBase] = None, + model_init: Optional[Callable[[], PreTrainedModel]] = None, + callbacks: Optional[List[TrainerCallback]] = None, + optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = ( + None, + None, + ), + preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None, + max_length: Optional[int] = None, + max_prompt_length: Optional[int] = None, + max_target_length: Optional[int] = None, + peft_config: Optional[Dict] = None, + is_encoder_decoder: Optional[bool] = None, + disable_dropout: bool = True, + generate_during_eval: bool = False, + compute_metrics: Optional[Callable[[EvalLoopOutput], Dict]] = None, + model_init_kwargs: Optional[Dict] = None, + ref_model_init_kwargs: Optional[Dict] = None, + ): + if model_init_kwargs is None: + model_init_kwargs = {} + elif not isinstance(model, str): + raise ValueError("You passed model_kwargs to the DPOTrainer. But your model is already instantiated.") + + if ref_model_init_kwargs is None: + ref_model_init_kwargs = {} + elif not isinstance(ref_model, str): + raise ValueError( + "You passed ref_model_kwargs to the DPOTrainer. But your ref_model is already instantiated." + ) + + if isinstance(model, str): + warnings.warn( + "You passed a model_id to the DPOTrainer. This will automatically create an " + "`AutoModelForCausalLM` or a `PeftModel` (if you passed a `peft_config`) for you." + ) + model = AutoModelForCausalLM.from_pretrained(model, **model_init_kwargs) + + if isinstance(ref_model, str): + warnings.warn( + "You passed a ref model_id to the DPOTrainer. This will automatically create an " + "`AutoModelForCausalLM`" + ) + ref_model = AutoModelForCausalLM.from_pretrained(ref_model, **ref_model_init_kwargs) + + if not is_peft_available() and peft_config is not None: + raise ValueError( + "PEFT is not installed and you passed a `peft_config` in the trainer's kwargs, please install it to use the PEFT models" + ) + elif is_peft_available() and peft_config is not None: + # if model is a peft model and we have a peft_config, we merge and unload it first + if isinstance(model, PeftModel): + model = model.merge_and_unload() + + if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False): + _support_gc_kwargs = hasattr( + args, "gradient_checkpointing_kwargs" + ) and "gradient_checkpointing_kwargs" in list( + inspect.signature(prepare_model_for_kbit_training).parameters + ) + + preprare_model_kwargs = {"use_gradient_checkpointing": args.gradient_checkpointing} + + if _support_gc_kwargs: + preprare_model_kwargs["gradient_checkpointing_kwargs"] = args.gradient_checkpointing_kwargs + + model = prepare_model_for_kbit_training(model, **preprare_model_kwargs) + elif getattr(args, "gradient_checkpointing", False): + # For backward compatibility with older versions of transformers + if hasattr(model, "enable_input_require_grads"): + model.enable_input_require_grads() + else: + + def make_inputs_require_grad(module, input, output): + output.requires_grad_(True) + + model.get_input_embeddings().register_forward_hook(make_inputs_require_grad) + + # get peft model with the given config + model = get_peft_model(model, peft_config) + + # For models that use gradient_checkpoiting, we need to attach a hook that enables input + # to explicitly have `requires_grad=True`, otherwise training will either silently + # fail or completely fail. + elif getattr(args, "gradient_checkpointing", False): + # For backward compatibility with older versions of transformers + if hasattr(model, "enable_input_require_grads"): + model.enable_input_require_grads() + else: + + def make_inputs_require_grad(module, input, output): + output.requires_grad_(True) + + model.get_input_embeddings().register_forward_hook(make_inputs_require_grad) + + if generate_during_eval and not is_wandb_available(): + raise ValueError( + "`generate_during_eval=True` requires Weights and Biases to be installed." + " Please install `wandb` to resolve." + ) + + if model is not None: + self.is_encoder_decoder = model.config.is_encoder_decoder + elif is_encoder_decoder is None: + raise ValueError("When no model is provided, you need to pass the parameter is_encoder_decoder.") + else: + self.is_encoder_decoder = is_encoder_decoder + + self.is_peft_model = is_peft_available() and isinstance(model, PeftModel) + + if ref_model: + self.ref_model = ref_model + elif self.is_peft_model: + # The `model` with adapters turned off will be used as the reference model + self.ref_model = None + else: + self.ref_model = create_reference_model(model) + + if data_collator is None: + if tokenizer is None: + raise ValueError( + "max_length or a tokenizer must be specified when using the default DPODataCollatorWithPadding" + ) + if max_length is None: + warnings.warn( + "When using DPODataCollatorWithPadding, you should set `max_length` in the DPOTrainer's init" + " it will be set to `512` by default, but you should do it yourself in the future.", + UserWarning, + ) + max_length = 512 + if max_prompt_length is None: + warnings.warn( + "When using DPODataCollatorWithPadding, you should set `max_prompt_length` in the DPOTrainer's init" + " it will be set to `128` by default, but you should do it yourself in the future.", + UserWarning, + ) + max_prompt_length = 128 + + if max_target_length is None and self.is_encoder_decoder: + warnings.warn( + "When using DPODataCollatorWithPadding with an encoder decoder architecture, you should set `max_target_length` in the DPOTrainer's init" + " it will be set to `128` by default, but you should do it yourself in the future.", + UserWarning, + ) + max_target_length = 128 + + data_collator = DPODataCollatorWithPadding( + tokenizer, + max_length=max_length, + max_prompt_length=max_prompt_length, + label_pad_token_id=label_pad_token_id, + padding_value=padding_value, + truncation_mode=truncation_mode, + is_encoder_decoder=self.is_encoder_decoder, + max_target_length=max_target_length, + ) + + if args.remove_unused_columns: + args.remove_unused_columns = False + # warn users + warnings.warn( + "When using DPODataCollatorWithPadding, you should set `remove_unused_columns=False` in your TrainingArguments" + " we have set it for you, but you should do it yourself in the future.", + UserWarning, + ) + + self.use_dpo_data_collator = True + else: + self.use_dpo_data_collator = False + + if disable_dropout: + disable_dropout_in_model(model) + if self.ref_model is not None: + disable_dropout_in_model(self.ref_model) + + self.max_length = max_length + self.generate_during_eval = generate_during_eval + self.label_pad_token_id = label_pad_token_id + self.padding_value = padding_value + + if loss_type in ["hinge", "ipo", "kto"] and label_smoothing > 0: + warnings.warn( + "You are using a loss type that does not support label smoothing. Ignoring label_smoothing parameter." + ) + + self.beta = beta + self.label_smoothing = label_smoothing + self.loss_type = loss_type + + self._stored_metrics = defaultdict(lambda: defaultdict(list)) + + super().__init__( + model=model, + args=args, + data_collator=data_collator, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + tokenizer=tokenizer, + model_init=model_init, + compute_metrics=compute_metrics, + callbacks=callbacks, + optimizers=optimizers, + preprocess_logits_for_metrics=preprocess_logits_for_metrics, + ) + + if not hasattr(self, "accelerator"): + raise AttributeError( + "Your `Trainer` does not have an `accelerator` object. Consider upgrading `transformers`." + ) + + if self.ref_model is None: + if not hasattr(self.accelerator.unwrap_model(self.model), "disable_adapter"): + raise ValueError( + "You are using a `peft` version that does not support `disable_adapter`. Please update your `peft` version to the latest version." + ) + else: + if self.is_deepspeed_enabled: + self.ref_model = self._prepare_deepspeed(self.ref_model) + else: + self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True) + + def _prepare_deepspeed(self, model: PreTrainedModelWrapper): + # Adapted from accelerate: https://github.com/huggingface/accelerate/blob/739b135f8367becb67ffaada12fe76e3aa60fefd/src/accelerate/accelerator.py#L1473 + deepspeed_plugin = self.accelerator.state.deepspeed_plugin + config_kwargs = deepcopy(deepspeed_plugin.deepspeed_config) + + if model is not None: + if hasattr(model, "config"): + hidden_size = ( + max(model.config.hidden_sizes) + if getattr(model.config, "hidden_sizes", None) + else getattr(model.config, "hidden_size", None) + ) + if hidden_size is not None and config_kwargs["zero_optimization"]["stage"] == 3: + # Note that `stage3_prefetch_bucket_size` can produce DeepSpeed messages like: `Invalidate trace cache @ step 0: expected module 1, but got module 0` + # This is expected and is not an error, see: https://github.com/microsoft/DeepSpeed/discussions/4081 + config_kwargs.update( + { + "zero_optimization.reduce_bucket_size": hidden_size * hidden_size, + "zero_optimization.stage3_param_persistence_threshold": 10 * hidden_size, + "zero_optimization.stage3_prefetch_bucket_size": 0.9 * hidden_size * hidden_size, + } + ) + + # If ZeRO-3 is used, we shard both the active and reference model. + # Otherwise, we assume the reference model fits in memory and is initialized on each device with ZeRO disabled (stage 0) + if config_kwargs["zero_optimization"]["stage"] != 3: + config_kwargs["zero_optimization"]["stage"] = 0 + model, *_ = deepspeed.initialize(model=model, config=config_kwargs) + model.eval() + return model + + def concatenated_inputs(self, batch: Dict[str, Union[List, torch.LongTensor]]) -> Dict[str, torch.LongTensor]: + """Concatenate the chosen and rejected inputs into a single tensor. + + Args: + batch: A batch of data. Must contain the keys 'chosen_input_ids' and 'rejected_input_ids', which are tensors of shape (batch_size, sequence_length). + + Returns: + A dictionary containing the concatenated inputs under the key 'concatenated_input_ids'. + """ + concatenated_batch = {} + + if self.is_encoder_decoder: + max_length = max(batch["chosen_labels"].shape[1], batch["rejected_labels"].shape[1]) + else: + max_length = max(batch["chosen_input_ids"].shape[1], batch["rejected_input_ids"].shape[1]) + + for k in batch: + if k.startswith("chosen") and isinstance(batch[k], torch.Tensor): + pad_value = self.label_pad_token_id if "labels" in k or self.is_encoder_decoder else self.padding_value + concatenated_key = k.replace("chosen", "concatenated") + concatenated_batch[concatenated_key] = pad_to_length(batch[k], max_length, pad_value=pad_value) + for k in batch: + if k.startswith("rejected") and isinstance(batch[k], torch.Tensor): + pad_value = self.label_pad_token_id if "labels" in k or self.is_encoder_decoder else self.padding_value + concatenated_key = k.replace("rejected", "concatenated") + concatenated_batch[concatenated_key] = torch.cat( + ( + concatenated_batch[concatenated_key], + pad_to_length(batch[k], max_length, pad_value=pad_value), + ), + dim=0, + ).to(self.accelerator.device) + + if self.is_encoder_decoder: + concatenated_batch["concatenated_input_ids"] = batch["prompt_input_ids"].repeat(2, 1) + concatenated_batch["concatenated_attention_mask"] = batch["prompt_attention_mask"].repeat(2, 1) + + return concatenated_batch + + def dpo_loss( + self, + policy_chosen_logps: torch.FloatTensor, + policy_rejected_logps: torch.FloatTensor, + reference_chosen_logps: torch.FloatTensor, + reference_rejected_logps: torch.FloatTensor, + reference_free: bool = False, + ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]: + """Compute the DPO loss for a batch of policy and reference model log probabilities. + + Args: + policy_chosen_logps: Log probabilities of the policy model for the chosen responses. Shape: (batch_size,) + policy_rejected_logps: Log probabilities of the policy model for the rejected responses. Shape: (batch_size,) + reference_chosen_logps: Log probabilities of the reference model for the chosen responses. Shape: (batch_size,) + reference_rejected_logps: Log probabilities of the reference model for the rejected responses. Shape: (batch_size,) + reference_free: If True, we ignore the _provided_ reference model and implicitly use a reference model that assigns equal probability to all responses. + + Returns: + A tuple of three tensors: (losses, chosen_rewards, rejected_rewards). + The losses tensor contains the DPO loss for each example in the batch. + The chosen_rewards and rejected_rewards tensors contain the rewards for the chosen and rejected responses, respectively. + """ + pi_logratios = policy_chosen_logps - policy_rejected_logps + if reference_free: + ref_logratios = 0 + else: + ref_logratios = reference_chosen_logps - reference_rejected_logps + + logits = pi_logratios - ref_logratios + + # The beta is a temperature parameter for the DPO loss, typically something in the range of 0.1 to 0.5. + # We ignore the reference model as beta -> 0. The label_smoothing parameter encodes our uncertainty about the labels and + # calculates a conservative DPO loss. + if self.loss_type == "sigmoid": + losses = ( + -F.logsigmoid(self.beta * logits) * (1 - self.label_smoothing) + - F.logsigmoid(-self.beta * logits) * self.label_smoothing + ) + elif self.loss_type == "hinge": + losses = torch.relu(1 - self.beta * logits) + elif self.loss_type == "ipo": + # eqn (17) of the paper where beta is the regularization parameter for the IPO loss, denoted by tau in the paper. + losses = (logits - 1 / (2 * self.beta)) ** 2 + elif self.loss_type == "kto": + # eqn (7) of the HALOs paper + chosen_KL = (policy_chosen_logps - reference_chosen_logps).mean().clamp(min=0) + rejected_KL = (policy_rejected_logps - reference_rejected_logps).mean().clamp(min=0) + + chosen_logratios = policy_chosen_logps - reference_chosen_logps + rejected_logratios = policy_rejected_logps - reference_rejected_logps + # As described in the KTO report, the KL term for chosen (rejected) is estimated using the rejected (chosen) half. + losses = torch.cat( + ( + 1 - F.sigmoid(self.beta * (chosen_logratios - rejected_KL)), + 1 - F.sigmoid(self.beta * (chosen_KL - rejected_logratios)), + ), + 0, + ) + else: + raise ValueError( + f"Unknown loss type: {self.loss_type}. Should be one of ['sigmoid', 'hinge', 'ipo', 'kto']" + ) + + chosen_rewards = self.beta * (policy_chosen_logps - reference_chosen_logps).detach() + rejected_rewards = self.beta * (policy_rejected_logps - reference_rejected_logps).detach() + + return losses, chosen_rewards, rejected_rewards + + def _get_batch_logps( + self, + logits: torch.FloatTensor, + labels: torch.LongTensor, + average_log_prob: bool = False, + ) -> torch.FloatTensor: + """Compute the log probabilities of the given labels under the given logits. + + Args: + logits: Logits of the model (unnormalized). Shape: (batch_size, sequence_length, vocab_size) + labels: Labels for which to compute the log probabilities. Label tokens with a value of label_pad_token_id are ignored. Shape: (batch_size, sequence_length) + average_log_prob: If True, return the average log probability per (non-masked) token. Otherwise, return the sum of the log probabilities of the (non-masked) tokens. + + Returns: + A tensor of shape (batch_size,) containing the average/sum log probabilities of the given labels under the given logits. + """ + if logits.shape[:-1] != labels.shape: + raise ValueError("Logits (batch and sequence length dim) and labels must have the same shape.") + + if not self.is_encoder_decoder: + labels = labels[:, 1:].clone() + logits = logits[:, :-1, :] + loss_mask = labels != self.label_pad_token_id + + # dummy token; we'll ignore the losses on these tokens later + labels[labels == self.label_pad_token_id] = 0 + + per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2) + + if average_log_prob: + return (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1) + else: + return (per_token_logps * loss_mask).sum(-1) + + def concatenated_forward( + self, model: nn.Module, batch: Dict[str, Union[List, torch.LongTensor]] + ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]: + """Run the given model on the given batch of inputs, concatenating the chosen and rejected inputs together. + + We do this to avoid doing two forward passes, because it's faster for FSDP. + """ + concatenated_batch = self.concatenated_inputs(batch) + len_chosen = batch["chosen_labels"].shape[0] + + model_kwargs = ( + { + "labels": concatenated_batch["concatenated_labels"], + "decoder_input_ids": concatenated_batch.pop("concatenated_decoder_input_ids", None), + } + if self.is_encoder_decoder + else {} + ) + all_logits = model( + concatenated_batch["concatenated_input_ids"], + attention_mask=concatenated_batch["concatenated_attention_mask"], + **model_kwargs, + ).logits.to(torch.float32) + + all_logps = self._get_batch_logps( + all_logits, + concatenated_batch["concatenated_labels"], + average_log_prob=False, + ) + + chosen_logps = all_logps[:len_chosen] + rejected_logps = all_logps[len_chosen:] + + chosen_logits = all_logits[:len_chosen] + rejected_logits = all_logits[len_chosen:] + + return (chosen_logps, rejected_logps, chosen_logits, rejected_logits) + + def get_batch_metrics( + self, + model, + batch: Dict[str, Union[List, torch.LongTensor]], + train_eval: Literal["train", "eval"] = "train", + ): + """Compute the DPO loss and other metrics for the given batch of inputs for train or test.""" + metrics = {} + + ( + policy_chosen_logps, + policy_rejected_logps, + policy_chosen_logits, + policy_rejected_logits, + ) = self.concatenated_forward(model, batch) + with torch.no_grad(): + if self.ref_model is None: + with self.accelerator.unwrap_model(self.model).disable_adapter(): + ( + reference_chosen_logps, + reference_rejected_logps, + _, + _, + ) = self.concatenated_forward(self.model, batch) + else: + ( + reference_chosen_logps, + reference_rejected_logps, + _, + _, + ) = self.concatenated_forward(self.ref_model, batch) + + losses, chosen_rewards, rejected_rewards = self.dpo_loss( + policy_chosen_logps, + policy_rejected_logps, + reference_chosen_logps, + reference_rejected_logps, + ) + reward_accuracies = (chosen_rewards > rejected_rewards).float() + + prefix = "eval_" if train_eval == "eval" else "" + metrics[f"{prefix}rewards/chosen"] = chosen_rewards.cpu().mean() + metrics[f"{prefix}rewards/rejected"] = rejected_rewards.cpu().mean() + metrics[f"{prefix}rewards/accuracies"] = reward_accuracies.cpu().mean() + metrics[f"{prefix}rewards/margins"] = (chosen_rewards - rejected_rewards).cpu().mean() + metrics[f"{prefix}logps/rejected"] = policy_rejected_logps.detach().cpu().mean() + metrics[f"{prefix}logps/chosen"] = policy_chosen_logps.detach().cpu().mean() + metrics[f"{prefix}logits/rejected"] = policy_rejected_logits.detach().cpu().mean() + metrics[f"{prefix}logits/chosen"] = policy_chosen_logits.detach().cpu().mean() + + return losses.mean(), metrics + + def compute_loss( + self, + model: Union[PreTrainedModel, nn.Module], + inputs: Dict[str, Union[torch.Tensor, Any]], + return_outputs=False, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, Dict[str, torch.Tensor]]]: + if not self.use_dpo_data_collator: + warnings.warn( + "compute_loss is only implemented for DPODataCollatorWithPadding, and you passed a datacollator that is different than " + "DPODataCollatorWithPadding - you might see unexpected behavior. Alternatively, you can implement your own prediction_step method if you are using a custom data collator" + ) + loss, metrics = self.get_batch_metrics(model, inputs, train_eval="train") + + # force log the metrics + if self.accelerator.is_main_process: + self.store_metrics(metrics, train_eval="train") + + if return_outputs: + return (loss, metrics) + return loss + + def get_batch_samples(self, model, batch: Dict[str, torch.LongTensor]) -> Tuple[str, str]: + """Generate samples from the model and reference model for the given batch of inputs.""" + + policy_output = model.generate( + input_ids=batch["prompt_input_ids"], + attention_mask=batch["prompt_attention_mask"], + max_length=self.max_length, + do_sample=True, + pad_token_id=self.tokenizer.pad_token_id, + ) + + if self.ref_model is None: + with self.accelerator.unwrap_model(self.model).disable_adapter(): + reference_output = self.model.generate( + input_ids=batch["prompt_input_ids"], + attention_mask=batch["prompt_attention_mask"], + max_length=self.max_length, + do_sample=True, + pad_token_id=self.tokenizer.pad_token_id, + ) + else: + reference_output = self.ref_model.generate( + input_ids=batch["prompt_input_ids"], + attention_mask=batch["prompt_attention_mask"], + max_length=self.max_length, + do_sample=True, + pad_token_id=self.tokenizer.pad_token_id, + ) + + policy_output = pad_to_length(policy_output, self.max_length, self.tokenizer.pad_token_id) + policy_output_decoded = self.tokenizer.batch_decode(policy_output, skip_special_tokens=True) + + reference_output = pad_to_length(reference_output, self.max_length, self.tokenizer.pad_token_id) + reference_output_decoded = self.tokenizer.batch_decode(reference_output, skip_special_tokens=True) + + return policy_output_decoded, reference_output_decoded + + def prediction_step( + self, + model: Union[PreTrainedModel, nn.Module], + inputs: Dict[str, Union[torch.Tensor, Any]], + prediction_loss_only: bool, + ignore_keys: Optional[List[str]] = None, + ): + if not self.use_dpo_data_collator: + warnings.warn( + "prediction_step is only implemented for DPODataCollatorWithPadding, and you passed a datacollator that is different than " + "DPODataCollatorWithPadding - you might see unexpected behavior. Alternatively, you can implement your own prediction_step method if you are using a custom data collator" + ) + if ignore_keys is None: + if hasattr(model, "config"): + ignore_keys = getattr(model.config, "keys_to_ignore_at_inference", []) + else: + ignore_keys = [] + + with torch.no_grad(): + loss, metrics = self.get_batch_metrics(model, inputs, train_eval="eval") + + # force log the metrics + if self.accelerator.is_main_process: + self.store_metrics(metrics, train_eval="eval") + + if prediction_loss_only: + return (loss.detach(), None, None) + + # logits for the chosen and rejected samples from model + logits_dict = { + "eval_logits/chosen": metrics["eval_logits/chosen"], + "eval_logits/rejected": metrics["eval_logits/rejected"], + } + logits = tuple(v.unsqueeze(dim=0) for k, v in logits_dict.items() if k not in ignore_keys) + logits = torch.stack(logits).mean(axis=1).to(self.accelerator.device) + labels = torch.zeros(logits.shape[0], device=self.accelerator.device) + + return (loss.detach(), logits, labels) + + def store_metrics(self, metrics: Dict[str, float], train_eval: Literal["train", "eval"] = "train") -> None: + for key, value in metrics.items(): + self._stored_metrics[train_eval][key].append(value) + + def evaluation_loop( + self, + dataloader: DataLoader, + description: str, + prediction_loss_only: Optional[bool] = None, + ignore_keys: Optional[List[str]] = None, + metric_key_prefix: str = "eval", + ) -> EvalLoopOutput: + """ + Overriding built-in evaluation loop to store metrics for each batch. + Prediction/evaluation loop, shared by `Trainer.evaluate()` and `Trainer.predict()`. + + Works both with or without labels. + """ + + # Sample and save to game log if requested (for one batch to save time) + if self.generate_during_eval: + # Generate random indices within the range of the total number of samples + num_samples = len(dataloader.dataset) + random_indices = random.sample(range(num_samples), k=self.args.eval_batch_size) + + # Use dataloader.dataset.select to get the random batch without iterating over the DataLoader + random_batch_dataset = dataloader.dataset.select(random_indices) + random_batch = self.data_collator(random_batch_dataset) + random_batch = self._prepare_inputs(random_batch) + + policy_output_decoded, ref_output_decoded = self.get_batch_samples(self.model, random_batch) + + self.log( + { + "game_log": wandb.Table( + columns=["Prompt", "Policy", "Ref Model"], + rows=[ + [prompt, pol[len(prompt) :], ref[len(prompt) :]] + for prompt, pol, ref in zip( + random_batch["prompt"], policy_output_decoded, ref_output_decoded + ) + ], + ) + } + ) + self.state.log_history.pop() + + # Base evaluation + initial_output = super().evaluation_loop( + dataloader, description, prediction_loss_only, ignore_keys, metric_key_prefix + ) + + return initial_output + + def log(self, logs: Dict[str, float]) -> None: + """ + Log `logs` on the various objects watching training, including stored metrics. + + Args: + logs (`Dict[str, float]`): + The values to log. + """ + # logs either has 'loss' or 'eval_loss' + train_eval = "train" if "loss" in logs else "eval" + # Add averaged stored metrics to logs + for key, metrics in self._stored_metrics[train_eval].items(): + logs[key] = torch.tensor(metrics).mean().item() + del self._stored_metrics[train_eval] + return super().log(logs) diff --git a/trl/trainer/iterative_sft_trainer.py b/trl/trainer/iterative_sft_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..006b02ad5123aa704a1d482571da464ac9c81f45 --- /dev/null +++ b/trl/trainer/iterative_sft_trainer.py @@ -0,0 +1,367 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import warnings +from typing import Callable, Dict, List, Optional, Tuple, Union + +import torch +from datasets import Dataset +from torch.utils.data import DataLoader +from transformers import ( + DataCollator, + DataCollatorForLanguageModeling, + DataCollatorForSeq2Seq, + PreTrainedModel, + PreTrainedTokenizerBase, + Trainer, + TrainingArguments, +) +from transformers.trainer_utils import EvalLoopOutput + +from ..core import PPODecorators +from ..import_utils import is_peft_available + + +if is_peft_available(): + from peft import PeftModel + + +class IterativeSFTTrainer(Trainer): + """ + The IterativeSFTTrainer can be used to finetune models with methods that requires some steps between optimization. + + Attributes: + **model** (`PreTrainedModel`) -- Model to be optimized, either an 'AutoModelForCausalLM' or an 'AutoModelForSeq2SeqLM'. + Check the documentation of `PreTrainedModel` for more details. + **args** (`transformers.TrainingArguments`): -- The arguments to use for training. + **tokenizer** (`PreTrainedTokenizerBase`) -- Tokenizer to be used for encoding the + data. Check the documentation of `transformers.PreTrainedTokenizer` and + `transformers.PreTrainedTokenizerFast` for more details. + **optimizers** (`Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`): -- The optimizer and scheduler to use for training. + **data_collator** (Union[DataCollatorForLanguageModeling, DataCollatorForSeq2Seq], *optional*) -- Data collator to be used for training and + passed along the dataloader. + **eval_dataset** (`datasets.Dataset`): The dataset to use for evaluation. + **max_length** (`int`, defaults to `None`): -- The maximum length of the input. + **truncation_mode** (`str`, defaults to `keep_end`): -- The truncation mode to use, either `keep_end` or `keep_start`. + **preprocess_logits_for_metrics** (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`): -- The function to use to preprocess the logits before computing the metrics. + **compute_metrics** (`Callable[[EvalPrediction], Dict]`, *optional*): -- The function to use to compute the metrics. Must take a `EvalPrediction` and return a dictionary string to metric values. + **optimize_device_cache ** (`bool`, *optional*, defaults to `False`) -- Optimize CUDA cache for slightly more memory-efficient training. + """ + + def __init__( + self, + model: PreTrainedModel = None, + args: TrainingArguments = None, + tokenizer: PreTrainedTokenizerBase = None, + optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = ( + None, + None, + ), + data_collator: Optional[DataCollator] = None, + eval_dataset: Optional[Union[Dataset, Dict[str, Dataset]]] = None, + max_length: Optional[int] = None, + truncation_mode: Optional[str] = "keep_end", + preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None, + compute_metrics: Optional[Callable[[EvalLoopOutput], Dict]] = None, + optimize_device_cache: Optional[bool] = False, + ): + # Step 0: check positional arguments validity + if not isinstance(tokenizer, (PreTrainedTokenizerBase)): + raise ValueError( + f"tokenizer must be a PreTrainedTokenizerBase like a PreTrainedTokenizer or a PreTrainedTokenizerFast, got {type(tokenizer)}" + ) + if not isinstance(model, PreTrainedModel): + raise ValueError(f"model must be a PreTrainedModel, got {type(model)}") + if not model.can_generate(): + warnings.warn( + f"The current model class {type(model)} is not compatible with `.generate()`" + "Please make sure that this is intended." + ) + if optimizers[1] is None and args.max_steps == -1: + raise ValueError( + "When no scheduler is provided, you need to set the total number of training steps to perform `max_steps`" + ) + + self.is_encoder_decoder = getattr(model.config, "is_encoder_decoder", False) + self.is_peft_model = is_peft_available() and isinstance(model, PeftModel) + + self.tokenizer = tokenizer + + if data_collator is None: + if self.is_encoder_decoder: + warnings.warn( + "No data collator is provided. Using 'DataCollatorForSeq2Seq' with" + "'labels_pad_token_id' set to '-100' and 'pad_to_multiple_of' set to 8." + ) + self.data_collator = DataCollatorForSeq2Seq(tokenizer, label_pad_token_id=-100, pad_to_multiple_of=8) + else: + warnings.warn("No data collator is provided. Using 'DataCollatorForLanguageModeling'") + self.data_collator = DataCollatorForLanguageModeling(self.tokenizer, mlm=False) + else: + self.data_collator = data_collator + + self.max_length = max_length + self.truncation_mode = truncation_mode + self.optimize_device_cache = optimize_device_cache + + super().__init__( + model=model, + args=args, + data_collator=self.data_collator, + eval_dataset=eval_dataset, + tokenizer=tokenizer, + compute_metrics=compute_metrics, + optimizers=optimizers, + preprocess_logits_for_metrics=preprocess_logits_for_metrics, + ) + + self.create_optimizer_and_scheduler(self.args.max_steps) + + # prepare model, optimizer and lr_scheduler + self.model, self.optimizer, self.lr_scheduler = self.accelerator.prepare( + self.model, self.optimizer, self.lr_scheduler + ) + + self.tokenizer.truncation_side = "left" if self.truncation_mode == "keep_end" else "right" + + if not hasattr(self, "accelerator"): + raise AttributeError( + "Your `Trainer` does not have an `accelerator` object. Consider upgrading `transformers`." + ) + + PPODecorators.optimize_device_cache = self.optimize_device_cache + + def prepare_model_inputs(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, labels: torch.Tensor): + if attention_mask is None: + attention_mask = [torch.ones_like(ids) for ids in input_ids] + + if self.is_encoder_decoder: + input_data = self.data_collator( + [ + {"input_ids": ids, "attention_mask": att, "labels": lab} + for ids, att, lab in zip(input_ids, attention_mask, labels) + ] + ).to(self.model.device) + + input_data.pop("decoder_input_ids", None) # This is directly computed inside the model + + input_data["labels"][input_data["labels"] == self.tokenizer.pad_token_id] = -100 + + else: + input_data = self.data_collator( + [{"input_ids": ids, "attention_mask": att} for ids, att in zip(input_ids, attention_mask)] + ).to(self.model.device) + + # truncate in case the user has provided input_ids, attention_mask and labels + if self.max_length is not None: + if self.truncation_mode == "keep_start": + input_data = {k: v[: self.max_length] for k, v in input_data.items()} + elif self.truncation_mode == "keep_end": + input_data = {k: v[-self.max_length :] for k, v in input_data.items()} + else: + raise ValueError(f"Unknown truncation mode: {self.truncation_mode}") + + return input_data + + @staticmethod + def _step_safety_checker( + input_ids: List[torch.LongTensor], + attention_mask: List[torch.LongTensor], + labels: List[torch.LongTensor], + texts: List[str], + texts_labels: List[str], + ): + """ + Check if the input data is valid for training. + + Args: + input_ids (List[`torch.LongTensor`]): + List of tensors containing the input_ids + attention_mask (List[`torch.LongTensor`]): + List of tensors containing the attention_mask + labels (List[`torch.FloatTensor`]): + List of tensors containing the labels + texts (List[`str`]): + List of string containing the text input. + texts_labels (List[`str`]): + List of string containing the text labels. + Returns: + `tuple`: The input data. + """ + if texts is None: + if attention_mask is None: + for name, tensor_list in zip(["input_ids", "labels"], [input_ids, labels]): + if not isinstance(tensor_list, list): + raise ValueError(f"{name} must be a list of tensors - got {type(tensor_list)}") + if not isinstance(tensor_list[0], torch.Tensor): + raise ValueError(f"Elements in {name} must be tensors - got {type(tensor_list[0])}") + else: + for name, tensor_list in zip( + ["input_ids", "attention_mask", "labels"], [input_ids, attention_mask, labels] + ): + if not isinstance(tensor_list, list): + raise ValueError(f"{name} must be a list of tensors - got {type(tensor_list)}") + if not isinstance(tensor_list[0], torch.Tensor): + raise ValueError(f"Elements in {name} must be tensors - got {type(tensor_list[0])}") + else: + if not isinstance(texts, list): + raise ValueError(f"'text' must be a list of strings - got {type(texts)}") + if not isinstance(texts[0], str): + raise ValueError(f"Elements in 'text' must be strings - got {type(texts[0])}") + if texts_labels is not None: + if not isinstance(texts_labels, list): + raise ValueError(f"'text_labels' must be a list of strings - got {type(texts_labels)}") + if not isinstance(texts_labels[0], str): + raise ValueError(f"Elements in 'text_labels' must be strings - got {type(texts_labels[0])}") + + return input_ids, attention_mask, labels, texts, texts_labels + + @PPODecorators.empty_device_cache() + def step( + self, + input_ids: Optional[List[torch.LongTensor]] = None, + attention_mask: Optional[List[torch.LongTensor]] = None, + labels: Optional[List[torch.LongTensor]] = None, + texts: Optional[List[str]] = None, + texts_labels: Optional[List[str]] = None, + ): + """ + Run an optimisation step given a list of input_ids, attention_mask, and labels or a list of text and text_labels. + Args: + input_ids (List[`torch.LongTensor`]): + List of tensors containing the input_ids (if not provided, text will be used) + attention_mask (List[`torch.LongTensor`], , *optional*): + List of tensors containing the attention_mask + labels (List[`torch.FloatTensor`], *optional*): + List of tensors containing the labels (if set to None, will default to input_ids) + texts (List[`str`], *optional*): + List of strings containing the text input (if not provided, input_ids will directly be used) + texts_labels (List[`str`], *optional*): + List of strings containing the text labels (if set to None, will default to text) + Returns: + `dict[str, Any]`: A summary of the training statistics + """ + self.model.train() + + if self.state.global_step == 0: + self.tr_loss = torch.tensor(0.0).to(self.args.device) + self._globalstep_last_logged = self.state.global_step + + if input_ids is None and texts is None: + raise ValueError("Step should include `input_ids` or `texts` as keyword arguments.") + elif input_ids is not None and texts is not None: + warnings.warn( + "Both 'input_ids' and 'texts' are provided. 'input_ids' will be overwritten using inputs provided by the 'texts' keyword argument." + ) + + if labels is None and texts_labels is None and self.is_encoder_decoder: + raise ValueError( + "No 'labels' or 'text_labels' are provided. When using an encoder-decoder architecture, 'labels' or 'text_labels' must be passed." + ) + + input_ids, attention_mask, labels, texts, texts_labels = self._step_safety_checker( + input_ids, attention_mask, labels, texts, texts_labels + ) + + if texts is not None: + model_inputs = self.tokenizer( + texts, max_length=self.max_length, truncation=True, padding=True, return_tensors="pt" + ) + + input_ids, attention_mask = model_inputs["input_ids"], model_inputs["attention_mask"] + + if texts_labels is not None: + labels = self.tokenizer( + texts, max_length=self.max_length, truncation=True, padding=True, return_tensors="pt" + )["input_ids"] + + if labels is None: + warnings.warn("No labels are provided. Setting labels to input_ids") + labels = input_ids + + model_inputs = self.prepare_model_inputs(input_ids, attention_mask, labels) + + model_inputs_names = list(model_inputs.keys()) + + batch_dict = {} + batch_dict.update(model_inputs) + + def collator(data): + return_dict = dict() + for key in data[0]: + if key in ["input_ids", "attention_mask", "labels"]: + return_dict[key] = torch.stack([d[key] for d in data]).to(self.model.device) + return return_dict + + batch_data = Dataset.from_dict(batch_dict) + batch_data.set_format("torch") + + step_dataloader = DataLoader( + batch_data, + batch_size=self.args.per_device_train_batch_size, + shuffle=True, + collate_fn=collator, + ) + + for _, batch in enumerate(step_dataloader): + with self.accelerator.accumulate(self.model): + model_inputs = {k: batch[k] for k in model_inputs_names} + loss = self.compute_loss(self.model, model_inputs) + + if self.args.n_gpu > 1: + loss = loss.mean() + + tr_loss_step = loss.detach() + + self.accelerator.backward(loss) + + if self.accelerator.sync_gradients and self.args.max_grad_norm is not None: + self.accelerator.clip_grad_norm_( + self.model.parameters(), + self.args.max_grad_norm, + ) + + self.optimizer.step() + self.optimizer.zero_grad() + if self.lr_scheduler is not None: + self.lr_scheduler.step() + + self.state.global_step += 1 + + # update stats etc + self.tr_loss += tr_loss_step + + self._maybe_log_save_evaluate() + + def _maybe_log_save_evaluate(self): + # check if eval is required + if self.args.eval_steps is not None: + if self.state.global_step % self.args.eval_steps == 0 and self.state.global_step != 0: + self.evaluate(self.eval_dataset) + + # check if logging is required + if self.args.logging_steps is not None: + if self.state.global_step % self.args.logging_steps == 0 and self.state.global_step != 0: + logs: Dict[str, float] = {} + + tr_loss_scalar = self._nested_gather(self.tr_loss).mean().item() + + # reset tr_loss to zero + self.tr_loss -= self.tr_loss + + logs["loss"] = round(tr_loss_scalar / (self.state.global_step - self._globalstep_last_logged), 4) + logs["learning_rate"] = self._get_learning_rate() + + self._globalstep_last_logged = self.state.global_step + + self.log(logs) diff --git a/trl/trainer/ppo_config.py b/trl/trainer/ppo_config.py new file mode 100644 index 0000000000000000000000000000000000000000..0f6a907b6461a74f79a92b0f666a8f2cc2505d99 --- /dev/null +++ b/trl/trainer/ppo_config.py @@ -0,0 +1,179 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import json +import os +import sys +import warnings +from dataclasses import dataclass, field +from typing import Literal, Optional + +import numpy as np +import tyro +from typing_extensions import Annotated + +from trl.trainer.utils import exact_div + +from ..core import flatten_dict +from ..import_utils import is_wandb_available + + +JSONDict = Annotated[Optional[dict], tyro.conf.arg(metavar="JSON", constructor=json.loads)] + + +@dataclass +class PPOConfig: + """ + Configuration class for PPOTrainer + """ + + # common parameters + exp_name: str = os.path.basename(sys.argv[0])[: -len(".py")] + """the name of this experiment (by default is the file name without the extension name)""" + seed: int = 0 + """Seed value for random generations""" + log_with: Optional[Literal["wandb", "tensorboard"]] = None + """Log with either 'wandb' or 'tensorboard', check https://huggingface.co/docs/accelerate/usage_guides/tracking for more details""" + task_name: Optional[str] = None + """Name of task to use - used only for tracking purposes""" + model_name: Optional[str] = None + """Name of model to use - used only for tracking purposes""" + query_dataset: Optional[str] = None + """Name of dataset to query - used only for tracking purposes""" + reward_model: Optional[str] = None + """The reward model to use - used only for tracking purposes""" + remove_unused_columns: bool = True + """Remove unused columns from the dataset if `datasets.Dataset` is used""" + tracker_kwargs: JSONDict = field(default_factory=dict) + """Keyword arguments for the tracker (e.g. python ppo.py --ppo_config.tracker_kwargs='{"wandb": {"entity": "my_wandb_entity", "name": "my_exp_name"}}'""" + accelerator_kwargs: JSONDict = field(default_factory=dict) + """Keyword arguments for the accelerator""" + project_kwargs: JSONDict = field(default_factory=dict) + """Keyword arguments for the accelerator project config (e.g. `logging_dir`)""" + tracker_project_name: str = "trl" + """Name of project to use for tracking""" + push_to_hub_if_best_kwargs: JSONDict = field(default_factory=dict) + """Keyword arguments for pushing model to the hub during training (e.g. repo_id)""" + + # hyperparameters + steps: int = 20000 + """Number of training steps""" + learning_rate: float = 1e-5 + """Adam learning rate""" + adap_kl_ctrl: bool = True + """Use adaptive KL control, otherwise linear""" + init_kl_coef: Optional[float] = 0.2 + """Initial KL penalty coefficient (used for adaptive and linear control)""" + kl_penalty: Literal["kl", "abs", "mse", "full"] = "kl" + """kl penalty options: 'kl': model_logp - ref_logp, 'abs': abs(kl), 'mse': mean squared error mse(kl) and 'full': the actual kl for all tokens in the distribution""" + target: Optional[float] = 6 + """Target KL value for adaptive KL control""" + horizon: Optional[float] = 10000 + """Horizon for adaptive KL control""" + gamma: float = 1 + """Gamma parameter for advantage calculation""" + lam: float = 0.95 + """Lambda parameter for advantage calculation""" + cliprange: float = 0.2 + """Range for clipping in PPO policy gradient loss""" + cliprange_value: float = 0.2 + """Range for clipping values in loss calculation""" + vf_coef: float = 0.1 + """Scaling factor for value loss""" + batch_size: int = 256 + """Number of samples per optimisation step""" + forward_batch_size: Optional[int] = None + """DEPRECATED: use `mini_batch_size` instead, which does the same thing.""" + mini_batch_size: int = 1 + """Number of samples optimized in each mini batch""" + gradient_accumulation_steps: int = 1 + """The number of gradient accumulation steps""" + world_size: tyro.conf.Suppress[int] = None + """The world size for distributed training""" + ppo_epochs: int = 4 + """Number of optimisation epochs per batch of samples""" + max_grad_norm: Optional[float] = None + """Maximum gradient norm for gradient clipping""" + optimize_cuda_cache: Optional[bool] = None + """DEPRECATED: use `optimize_device_cache` instead, which does the same thing.""" + optimize_device_cache: Optional[bool] = False + """Optimize device cache for slightly more memory-efficient training""" + early_stopping: bool = False + """Whether to stop the PPO optimization loop early is the KL too high""" + target_kl: float = 1 + """Stop early if we exceed this value by over 50%""" + compare_steps: int = 1 + """Number of steps between comparison of the current reward with the best seen so far""" + ratio_threshold: float = 10.0 + """Skip mini-batches with high PPO ratios that can cause loss spikes""" + use_score_scaling: bool = False + """Use score scaling""" + use_score_norm: bool = False + """Use score normalization. Only applicable if use_score_scaling is True""" + score_clip: Optional[float] = None + """Score clipping""" + whiten_rewards: bool = False + """Whiten the rewards before compute advantages""" + + # computed hyperparameters at runtime; we use `tyro.conf.Suppress` to hide them from the help text + is_encoder_decoder: Optional[tyro.conf.Suppress[bool]] = None + """TO BE FILLED In RUNTIME: Whether the model is an encoder-decoder model""" + is_peft_model: Optional[tyro.conf.Suppress[bool]] = None + """TO BE FILLED In RUNTIME: Whether the model is a PEFT model""" + backward_batch_size: tyro.conf.Suppress[int] = None + """TO BE FILLED In RUNTIME: Number of samples optimized in an `optimizer.step()` call""" + global_backward_batch_size: tyro.conf.Suppress[int] = None + """TO BE FILLED In RUNTIME: the effective `backward_batch_size` across all processes""" + global_batch_size: tyro.conf.Suppress[int] = None + """TO BE FILLED In RUNTIME: the effective `batch_size` across all processes""" + + if optimize_cuda_cache is not None: + warnings.warn( + "The `optimize_cuda_cache` argument will be deprecated soon, please use `optimize_device_cache` instead." + ) + optimize_device_cache = optimize_cuda_cache + else: + optimize_device_cache = False + + def __post_init__(self): + if self.forward_batch_size is not None: + warnings.warn( + "Note that using `forward_batch_size` is deprecated, use `mini_batch_size` instead. By setting it you overwrite `mini_batch_size` which affects both the batch size during forward passes and also the mini batch size for PPO optimization." + ) + self.mini_batch_size = self.forward_batch_size + + self.backward_batch_size = self.mini_batch_size * self.gradient_accumulation_steps + exact_div( + self.batch_size, + self.backward_batch_size, + "`batch_size`", + "`mini_batch_size * gradient_accumulation_steps`", + "`batch_size` must be a multiple of `mini_batch_size * gradient_accumulation_steps`", + ) + + # check if wandb is installed + if self.log_with == "wandb": + # raise error if wandb is not installed + if not is_wandb_available(): + raise ImportError( + "Please install wandb to use wandb logging. You can do this by running `pip install wandb`." + ) + + self.total_ppo_epochs = int(np.ceil(self.steps / self.batch_size)) + assert self.kl_penalty in ["kl", "abs", "mse", "full"] + + def to_dict(self): + output_dict = {} + for key, value in self.__dict__.items(): + output_dict[key] = value + return flatten_dict(output_dict) diff --git a/trl/trainer/ppo_trainer.py b/trl/trainer/ppo_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..007e77a2686726bc8d1678e5b2f23f9092b6a573 --- /dev/null +++ b/trl/trainer/ppo_trainer.py @@ -0,0 +1,1440 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import inspect +import math +import os +import time +import typing +import warnings +from contextlib import nullcontext +from typing import Callable, List, Optional, Union + +import datasets +import numpy as np +import torch +import torch.nn.functional as F +from accelerate import Accelerator +from accelerate.utils import ProjectConfiguration, gather_object, is_deepspeed_available +from datasets import Dataset +from huggingface_hub import whoami +from packaging import version +from torch.optim import Adam +from transformers import ( + DataCollatorForLanguageModeling, + PreTrainedTokenizer, + PreTrainedTokenizerBase, + PreTrainedTokenizerFast, +) + +from ..core import ( + WANDB_PADDING, + PPODecorators, + clip_by_value, + convert_to_scalar, + entropy_from_logits, + flatten_dict, + logprobs_from_logits, + masked_mean, + masked_var, + masked_whiten, + set_seed, + stack_dicts, + stats_to_np, +) +from ..import_utils import is_torch_greater_2_0, is_xpu_available +from ..models import SUPPORTED_ARCHITECTURES, PreTrainedModelWrapper, create_reference_model +from . import AdaptiveKLController, BaseTrainer, FixedKLController, PPOConfig, RunningMoments + + +if is_deepspeed_available(): + import deepspeed + +MODEL_CARD_TEMPLATE = """--- +license: apache-2.0 +tags: +- trl +- transformers +- reinforcement-learning +--- + +# {model_name} + +This is a [TRL language model](https://github.com/huggingface/trl) that has been fine-tuned with reinforcement learning to + guide the model outputs according to a value, function, or human feedback. The model can be used for text generation. + +## Usage + +To use this model for inference, first install the TRL library: + +```bash +python -m pip install trl +``` + +You can then generate text as follows: + +```python +from transformers import pipeline + +generator = pipeline("text-generation", model="{model_id}") +outputs = generator("Hello, my llama is cute") +``` + +If you want to use the model for training or to obtain the outputs from the value head, load the model as follows: + +```python +from transformers import AutoTokenizer +from trl import AutoModelForCausalLMWithValueHead + +tokenizer = AutoTokenizer.from_pretrained("{model_id}") +model = AutoModelForCausalLMWithValueHead.from_pretrained("{model_id}") + +inputs = tokenizer("Hello, my llama is cute", return_tensors="pt") +outputs = model(**inputs, labels=inputs["input_ids"]) +``` +""" + + +class PPOTrainer(BaseTrainer): + """ + The PPOTrainer uses Proximal Policy Optimization to optimise language models. + Note, this trainer is heavily inspired by the original OpenAI learning to summarize work here: + https://github.com/openai/summarize-from-feedback + + Attributes: + **config** (`PPOConfig`) -- Configuration object for PPOTrainer. Check the documentation of `PPOConfig` for more + details. + **model** (`PreTrainedModelWrapper`) -- Model to be optimized, Hugging Face transformer model with a value head. + Check the documentation of `PreTrainedModelWrapper` for more details. + **ref_model** (`PreTrainedModelWrapper`, *optional*) -- Reference model to be used for KL penalty, Hugging Face + transformer model with a casual language modelling head. Check the documentation of `PreTrainedModelWrapper` + for more details. If no reference model is provided, the trainer will create a reference model with the same + architecture as the model to be optimized with shared layers. + **tokenizer** (`PreTrainedTokenizerBase`) -- Tokenizer to be used for encoding the + data. Check the documentation of `transformers.PreTrainedTokenizer` and + `transformers.PreTrainedTokenizerFast` for more details. + **dataset** (Union[`torch.utils.data.Dataset`, `datasets.Dataset`], *optional*) -- PyTorch dataset or Hugging + Face dataset. This is used to create a PyTorch dataloader. If no dataset is provided, the dataloader must be + created outside the trainer users needs to design their own dataloader and make sure the batch + size that is used is the same as the one specified in the configuration object. + **optimizer** (`torch.optim.Optimizer`, *optional*) -- Optimizer to be used for training. If no optimizer is + provided, the trainer will create an Adam optimizer with the learning rate specified in the configuration + object. + **data_collator** (DataCollatorForLanguageModeling, *optional*) -- Data collator to be used for training and + passed along the dataloader + **num_shared_layers** (int, *optional*) -- Number of layers to be shared between the model and the reference + model, if no reference model is passed. If no number is provided, all the layers will be shared. + **lr_scheduler** (`torch.optim.lr_scheduler`, *optional*) -- Learning rate scheduler to be used for training. + """ + + def __init__( + self, + config: PPOConfig = None, + model: PreTrainedModelWrapper = None, + ref_model: Optional[PreTrainedModelWrapper] = None, + tokenizer: PreTrainedTokenizerBase = None, + dataset: Optional[Union[torch.utils.data.Dataset, Dataset]] = None, + optimizer: Optional[torch.optim.Optimizer] = None, + data_collator: Optional[typing.Callable] = None, + num_shared_layers: Optional[int] = None, + lr_scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None, + ): + """ + Initialize PPOTrainer. + + Args: + config (`PPOConfig`): + Configuration object for PPOTrainer. Check the documentation of `PPOConfig` for more details. + model (`PreTrainedModelWrapper`): + Hugging Face transformer model with a value head. + ref_model (`PreTrainedModelWrapper`): + Hugging Face transformer model with a casual language modelling head. Used for KL penalty + tokenizer (`transformers.PreTrainedTokenizerBase`): + Hugging Face tokenizer + dataset (Optional[Union[`torch.utils.data.Dataset`, `datasets.Dataset`]]): + PyTorch dataset or Hugging Face dataset. If a Hugging Face dataset is passed, the dataset + will be preprocessed by removing the columns that are not used by the model. If none is passed, + a warning will be raised in a multi-GPU setting. + optimizer (Optional[`torch.optim.Optimizer`]): + Optimizer used for training. If `None`, the `Adam` is used as default. + data_collator (Optional[function]): + Data collator function. + num_shared_layers (Optional[int]): + Number of shared layers between the model and the reference model. If `None`, all layers are shared. + used only if `ref_model` is `None`. + lr_scheduler (Optional[`torch.optim.lr_scheduler`]): + Learning rate scheduler used for training. + """ + super().__init__(config) + + # initial seed for reproducible experiments + set_seed(config.seed) + + # Step 0: check positional arguments validity + if not isinstance(config, PPOConfig): + raise ValueError(f"config must be a PPOConfig, got {type(config)}") + if not isinstance(tokenizer, (PreTrainedTokenizerBase)): + raise ValueError( + f"tokenizer must be a PreTrainedTokenizerBase like a PreTrainedTokenizer or a PreTrainedTokenizerFast, got {type(tokenizer)}" + ) + if not isinstance(model, (SUPPORTED_ARCHITECTURES)): + raise ValueError( + f"model must be a PreTrainedModelWrapper, got {type(model)} - supported architectures are: {SUPPORTED_ARCHITECTURES}" + ) + # Step 1: Initialize Accelerator + self.accelerator = Accelerator( + log_with=config.log_with, + gradient_accumulation_steps=config.gradient_accumulation_steps, + project_config=ProjectConfiguration(**config.project_kwargs), + **config.accelerator_kwargs, + ) + + # Step 1.1 Runtime variables filled by the accelerator + config.world_size = self.accelerator.num_processes + config.global_backward_batch_size = config.backward_batch_size * config.world_size + config.global_batch_size = config.batch_size * config.world_size + + self.model = model + self.model_params = filter(lambda p: p.requires_grad, self.model.parameters()) + self.is_encoder_decoder = hasattr(self.model, "is_encoder_decoder") + self.is_peft_model = getattr(self.model, "is_peft_model", False) + config.is_encoder_decoder = self.is_encoder_decoder + config.is_peft_model = self.is_peft_model + + is_using_tensorboard = config.log_with is not None and config.log_with == "tensorboard" + self.accelerator.init_trackers( + config.tracker_project_name, + config=dict(trl_ppo_trainer_config=config.to_dict()) if not is_using_tensorboard else config.to_dict(), + init_kwargs=config.tracker_kwargs, + ) + self.is_using_text_environment = getattr(config, "use_text_environment", False) + + if isinstance(ref_model, SUPPORTED_ARCHITECTURES): + self.ref_model = ref_model + if num_shared_layers is not None: + warnings.warn( + "num_shared_layers is ignored when ref_model is provided. Two different models are used for the " + "model and the reference model and no layers are shared.", + UserWarning, + ) + elif ref_model is None and not self.is_peft_model: + self.ref_model = create_reference_model(self.model, num_shared_layers=num_shared_layers) + elif self.is_peft_model: + self.ref_model = None + else: + raise ValueError( + f"ref_model must be a PreTrainedModelWrapper or `None`, got {type(ref_model)} - supported " + f"architectures are: {SUPPORTED_ARCHITECTURES} " + ) + self.optional_peft_ctx = ( + self.accelerator.unwrap_model(self.model).pretrained_model.disable_adapter + if self.is_peft_model + else nullcontext + ) + + if not (isinstance(tokenizer, PreTrainedTokenizer) or isinstance(tokenizer, PreTrainedTokenizerFast)): + raise ValueError( + "tokenizer must be a transformers.PreTrainedTokenizer or transformers.PreTrainedTokenizerFast" + ) + self.tokenizer = tokenizer + + if dataset is not None and not (isinstance(dataset, torch.utils.data.Dataset) or isinstance(dataset, Dataset)): + raise ValueError("dataset must be a torch.utils.data.Dataset or datasets.Dataset") + elif dataset is None: + warnings.warn( + "No dataset is provided. Make sure to set config.batch_size to the correct value before training.", + UserWarning, + ) + self.dataset = dataset + self._signature_columns = None + if self.dataset is not None: + self.dataloader = self.prepare_dataloader(self.dataset, data_collator) + elif self.dataset is None and self.accelerator.num_processes > 1: + warnings.warn( + "No dataset is provided. In a multi-GPU setting, this will lead to an error. You should" + " prepare your dataloader yourself with `dataloader = ppo_trainer.accelerator.prepare(dataloader)`" + " and using `torch.utils.data.DataLoader`, or pass a dataset to the `PPOTrainer`. Please " + " refer to the documentation for more details.", + UserWarning, + ) + self.dataloader = None + else: + self.dataloader = None + + # Step 3: Initialize optimizer and data collator + self.data_collator = DataCollatorForLanguageModeling(self.tokenizer, mlm=False) + if optimizer is None: + self.optimizer = Adam( + filter(lambda p: p.requires_grad, self.model.parameters()), + lr=self.config.learning_rate, + ) + else: + self.optimizer = optimizer + + self.lr_scheduler = lr_scheduler + if self.lr_scheduler is not None: + lr_scheduler_class = ( + torch.optim.lr_scheduler._LRScheduler + if not is_torch_greater_2_0() + else torch.optim.lr_scheduler.LRScheduler + ) + + if not isinstance(self.lr_scheduler, lr_scheduler_class): + raise ValueError( + "lr_scheduler must be a torch.optim.lr_scheduler._LRScheduler or torch.optim.lr_scheduler.LRScheduler (for torch >= 2.0)" + ) + + if self.config.adap_kl_ctrl: + self.kl_ctl = AdaptiveKLController(self.config.init_kl_coef, self.config.target, self.config.horizon) + else: + self.kl_ctl = FixedKLController(self.config.init_kl_coef) + + # Safety checkers for DS integration + is_deepspeed_used = self.accelerator.distributed_type == "DEEPSPEED" and hasattr( + self.accelerator.state, "deepspeed_plugin" + ) + + ( + self.model, + self.optimizer, + self.data_collator, + self.dataloader, + self.lr_scheduler, + ) = self.accelerator.prepare( + self.model, + self.optimizer, + self.data_collator, + self.dataloader, + self.lr_scheduler, + ) + if is_deepspeed_used: + # Quantized models are already set on the correct device + if not self.is_peft_model and not ( + getattr(self.ref_model.pretrained_model, "is_loaded_in_8bit", False) + or getattr(self.ref_model.pretrained_model, "is_loaded_in_4bit", False) + ): + self.ref_model = self._prepare_deepspeed(self.ref_model) + else: + self.ref_model = self.accelerator.prepare(self.ref_model) + + # In a distributed setup, only logging needs to be performed on the main process + # check: https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html + # or: https://discuss.pytorch.org/t/use-distributed-data-parallel-correctly/82500/11 + self.is_distributed = self.accelerator.num_processes > 1 + + # init the current step + self.current_step = 0 + + # init variables for pushing model to hub + if config.push_to_hub_if_best_kwargs: + if "repo_id" not in config.push_to_hub_if_best_kwargs: + raise ValueError("You have to specify repo_id in order to push the model to the hub!") + self.push_to_hub_kwargs = config.push_to_hub_if_best_kwargs + self.compare_step = 0 + self.highest_reward = torch.tensor(-float("inf")) + + # post process for PP + if not getattr(self.model, "is_sequential_parallel", False): + self.current_device = self.accelerator.device + else: + if is_xpu_available(): + self.current_device = torch.device("xpu:0") + else: + self.current_device = torch.device("cuda:0") + + PPODecorators.optimize_device_cache = self.config.optimize_device_cache + + self.running = RunningMoments(self.accelerator) + + def _filter_kwargs(self, kwargs, target_func): + """ + filter the keyword arguments that are supported by the target function. + + Args: + kwargs (dict): + Keyword arguments + target_func (function): + Target function + """ + return {k: v for k, v in kwargs.items() if k in inspect.signature(target_func).parameters.keys()} + + def prepare_dataloader(self, dataset: Union[torch.utils.data.Dataset, Dataset], data_collator=None): + """ + Prepare the dataloader for training. + + Args: + dataset (Union[`torch.utils.data.Dataset`, `datasets.Dataset`]): + PyTorch dataset or Hugging Face dataset. If a Hugging Face dataset is passed, the dataset + will be preprocessed by removing the columns that are not used by the model. + data_collator (Optional[function]): + Data collator function. + + Returns: + `torch.utils.data.DataLoader`: PyTorch dataloader + """ + if isinstance(dataset, Dataset): + dataset = self._remove_unused_columns(dataset) + dataloader = torch.utils.data.DataLoader( + dataset, + batch_size=self.config.batch_size, + collate_fn=data_collator, + shuffle=True, + drop_last=True, + ) + return dataloader + + # Adapted from transformers.Trainer._set_signature_columns_if_needed + def _set_signature_columns_if_needed(self): + if self._signature_columns is None: + # Inspect model forward signature to keep only the arguments it accepts. + signature = inspect.signature(self.model.forward) + self._signature_columns = list(signature.parameters.keys()) + # label => sentiment | we need query and response for logging purpose + self._signature_columns += ["label", "query", "response"] + + # Adapted from transformers.Trainer._remove_unused_columns + def _remove_unused_columns(self, dataset: "Dataset"): + if not self.config.remove_unused_columns: + return dataset + self._set_signature_columns_if_needed() + signature_columns = self._signature_columns + + ignored_columns = list(set(dataset.column_names) - set(signature_columns)) + + columns = [k for k in signature_columns if k in dataset.column_names] + + if version.parse(datasets.__version__) < version.parse("1.4.0"): + dataset.set_format( + type=dataset.format["type"], + columns=columns, + format_kwargs=dataset.format["format_kwargs"], + ) + return dataset + else: + return dataset.remove_columns(ignored_columns) + + def generate( + self, + query_tensor: Union[torch.Tensor, List[torch.Tensor]], + length_sampler: Callable = None, + batch_size: int = 4, + return_prompt: bool = True, + generate_ref_response: bool = False, + **generation_kwargs, + ): + """ + Generate response with the model given the query tensor. + call the `generate` method of the model. + + Args: + query_tensor (`torch.LongTensor`): + A tensor of shape (`seq_len`) containing query tokens or a list of tensors of shape (`seq_len`). + generation_kwargs (dict[str, Any]): + Keyword arguments for generation. + length_sampler (`Callable`, *optional*): + Callable that returns the number of newly generated tokens. + batch_size (`int`, *optional): + Batch size used for generation, defaults to `4`. + return_prompt (`bool`, *optional*): + If set to `False` the prompt is not returned but only the newly generated tokens, defaults to `True`. + generate_ref_response (`bool`, *optional*): + If set to `True` the reference response is also generated, defaults to `False`. + + Returns: + `torch.LongTensor`: A tensor of shape (`batch_size`, `gen_len`) containing response tokens. + """ + if generate_ref_response: + ref_model = self.model if self.is_peft_model else self.ref_model + if isinstance(query_tensor, List): + response = self._generate_batched( + self.model, + query_tensor, + length_sampler=length_sampler, + batch_size=batch_size, + return_prompt=return_prompt, + **generation_kwargs, + ) + if generate_ref_response: + with self.optional_peft_ctx(): + ref_response = self._generate_batched( + ref_model, + query_tensor, + length_sampler=length_sampler, + batch_size=batch_size, + return_prompt=return_prompt, + **generation_kwargs, + ) + + else: + if len(query_tensor.shape) == 2: + raise ValueError( + "query_tensor must be a tensor of shape (`seq_len`) or a list of tensors of shape (`seq_len`)" + ) + + if length_sampler is not None: + generation_kwargs["max_new_tokens"] = length_sampler() + response = self.accelerator.unwrap_model(self.model).generate( + input_ids=query_tensor.unsqueeze(dim=0), **generation_kwargs + ) + if generate_ref_response: + with self.optional_peft_ctx(): + ref_response = ref_model.generate(input_ids=query_tensor.unsqueeze(dim=0), **generation_kwargs) + + if not return_prompt and not self.is_encoder_decoder: + response = response[:, query_tensor.shape[0] :] + if generate_ref_response: + ref_response = ref_response[:, query_tensor.shape[0] :] + + if generate_ref_response: + return response, ref_response + return response + + def _generate_batched( + self, + model: PreTrainedModelWrapper, + query_tensors: List[torch.Tensor], + length_sampler: Callable = None, + batch_size: int = 4, + return_prompt: bool = True, + pad_to_multiple_of: int = None, + remove_padding: bool = True, + **generation_kwargs, + ): + outputs = [] + + padding_side_default = self.tokenizer.padding_side + if not self.is_encoder_decoder: + self.tokenizer.padding_side = "left" + + # in case we have fewer examples than bs + batch_size = min(len(query_tensors), batch_size) + + for i in range(0, len(query_tensors), batch_size): + if length_sampler is not None: + generation_kwargs["max_new_tokens"] = length_sampler() + + # prevent overflow if query tensors are not even multiple of bs + end_index = min(len(query_tensors), i + batch_size) + + batch = query_tensors[i:end_index] + batch_mask = [torch.ones_like(element) for element in batch] + inputs = {"input_ids": batch, "attention_mask": batch_mask} + + padded_inputs = self.tokenizer.pad( + inputs, + padding=True, + max_length=None, + pad_to_multiple_of=pad_to_multiple_of, + return_tensors="pt", + ).to(self.current_device) + + generations = self.accelerator.unwrap_model(model).generate(**padded_inputs, **generation_kwargs) + + for generation, mask in zip(generations, padded_inputs["attention_mask"]): + if not self.is_encoder_decoder: + output = generation[(1 - mask).sum() :] # remove padding + else: + output = generation + + if not return_prompt and not self.is_encoder_decoder: + output = output[(mask).sum() :] # remove prompt + + if remove_padding and self.tokenizer.eos_token_id in output: + pad_mask = output == self.tokenizer.eos_token_id + pad_start = torch.nonzero(pad_mask, as_tuple=False)[0, 0].item() + output = output[: pad_start + 1] # keep the eos token at the end + + outputs.append(output) + + self.tokenizer.padding_side = padding_side_default + return outputs + + def _step_safety_checker( + self, + batch_size: int, + queries: List[torch.LongTensor], + responses: List[torch.LongTensor], + scores: List[torch.FloatTensor], + masks: Optional[List[torch.LongTensor]] = None, + ): + """ + Check if the input data is valid for training. + + Args: + batch_size (int): + Batch size from the config file. + queries (List[`torch.LongTensor`]): + List of tensors containing the encoded queries of shape (`query_length`) + responses (List[`torch.LongTensor`]): + List of tensors containing the encoded responses of shape (`response_length`) + scores (List[`torch.FloatTensor`]): + List of tensors containing the scores. + masks (List[`torch.LongTensor`], *optional*): + list of optional tensors containing the masks of shape (`query_length` + `response_length`) + Returns: + `tuple`: The input processed data. + """ + for name, tensor_list in zip(["queries", "responses", "scores"], [queries, responses, scores]): + if not isinstance(tensor_list, list): + raise ValueError(f"{name} must be a list of tensors - got {type(tensor_list)}") + if not isinstance(tensor_list[0], torch.Tensor): + raise ValueError(f"Elements in {name} must be tensors - got {type(tensor_list[0])}") + if batch_size is not None and len(tensor_list) != batch_size: + raise ValueError( + f"Batch size ({batch_size}) does not match number of examples - but got {len(tensor_list)} for: {name}" + ) + + # add queries, scores and responses on the correct device + queries = [tensor.to(self.current_device) for tensor in queries] + responses = [tensor.to(self.current_device) for tensor in responses] + scores = [tensor.to(self.current_device) for tensor in scores] + masks = [tensor.to(self.current_device) for tensor in masks] if masks is not None else None + + # squeeze scores if needed + for i, score in enumerate(scores): + if score.dim() > 1: + raise ValueError(f"Scores must be 1-dimensional - got {score.dim()} for {score}") + elif score.dim() == 1: + scores[i] = score.squeeze() + + return queries, responses, scores, masks + + @PPODecorators.empty_device_cache() + def step( + self, + queries: List[torch.LongTensor], + responses: List[torch.LongTensor], + scores: List[torch.FloatTensor], + response_masks: Optional[List[torch.LongTensor]] = None, + ): + """ + Run a PPO optimisation step given a list of queries, model responses, and rewards. + + Args: + queries (List[`torch.LongTensor`]): + List of tensors containing the encoded queries of shape (`query_length`) + responses (List[`torch.LongTensor`]): + List of tensors containing the encoded responses of shape (`response_length`) + scores (List[`torch.FloatTensor`]): + List of tensors containing the scores. + response_masks (List[`torch.FloatTensor`], *optional*)): + List of tensors containing masks of the response tokens. + + Returns: + `dict[str, Any]`: A summary of the training statistics + """ + bs = self.config.batch_size + + queries, responses, scores, response_masks = self._step_safety_checker( + bs, queries, responses, scores, response_masks + ) + scores = torch.tensor(scores, device=self.current_device) + if self.config.use_score_scaling: + # Score scaling + scores_mean, scores_std = self.running.update(scores) + tensor_to_kwargs = dict(dtype=scores.dtype, device=scores.device) + score_scaling_factor = self.running.std.to(**tensor_to_kwargs) + torch.finfo(scores.dtype).eps + if self.config.use_score_norm: + scores = (scores - self.running.mean.to(**tensor_to_kwargs)) / score_scaling_factor + else: + scores /= score_scaling_factor + + if self.config.score_clip is not None: + # Score clipping + scores_dtype = scores.dtype + scores = torch.clip(scores.float(), -self.config.score_clip, self.config.score_clip).to(dtype=scores_dtype) + + # if we want to push best model to the hub + if hasattr(self, "highest_reward"): + if self.compare_step % self.config.compare_steps == 0: + curr_mean_reward = scores.mean() + # if the best reward ever seen + if curr_mean_reward > self.highest_reward: + self.highest_reward = curr_mean_reward + # push model to hub + self.push_to_hub(**self.push_to_hub_kwargs) + self.compare_step += 1 + + timing = dict() + t0 = time.time() + + t = time.time() + + model_inputs = self.prepare_model_inputs(queries, responses) + + if self.is_distributed: + pad_first = self.tokenizer.padding_side == "left" + + model_inputs["input_ids"] = self.accelerator.pad_across_processes( + model_inputs["input_ids"], + dim=1, + pad_index=self.tokenizer.pad_token_id, + pad_first=pad_first, + ) + model_inputs["attention_mask"] = self.accelerator.pad_across_processes( + model_inputs["attention_mask"], dim=1, pad_index=0, pad_first=pad_first + ) + if self.is_encoder_decoder: + model_inputs["decoder_input_ids"] = self.accelerator.pad_across_processes( + model_inputs["decoder_input_ids"], + dim=1, + pad_index=self.tokenizer.pad_token_id, + pad_first=pad_first, + ) + model_inputs["decoder_attention_mask"] = self.accelerator.pad_across_processes( + model_inputs["decoder_attention_mask"], + dim=1, + pad_index=0, + pad_first=pad_first, + ) + + model_inputs_names = list(model_inputs.keys()) + + full_kl_penalty = self.config.kl_penalty == "full" + + with torch.no_grad(): + all_logprobs, logits_or_none, values, masks = self.batched_forward_pass( + self.model, + queries, + responses, + model_inputs, + response_masks=response_masks, + return_logits=full_kl_penalty, + ) + with self.optional_peft_ctx(): + ref_logprobs, ref_logits_or_none, _, _ = self.batched_forward_pass( + self.model if self.is_peft_model else self.ref_model, + queries, + responses, + model_inputs, + return_logits=full_kl_penalty, + ) + + timing["time/ppo/forward_pass"] = time.time() - t + + with torch.no_grad(): + t = time.time() + if full_kl_penalty: + active_full_logprobs = logprobs_from_logits(logits_or_none, None, gather=False) + ref_full_logprobs = logprobs_from_logits(ref_logits_or_none, None, gather=False) + + rewards, non_score_reward = self.compute_rewards( + scores, active_full_logprobs, ref_full_logprobs, masks + ) + else: + rewards, non_score_reward = self.compute_rewards(scores, all_logprobs, ref_logprobs, masks) + timing["time/ppo/compute_rewards"] = time.time() - t + + t = time.time() + values, advantages, returns = self.compute_advantages(values, rewards, masks) + timing["time/ppo/compute_advantages"] = time.time() - t + + # upcast to float32 to avoid dataset issues + batch_dict = { + "queries": queries, + "responses": responses, + "logprobs": all_logprobs.to(torch.float32), + "values": values.to(torch.float32), + "masks": masks, + "advantages": advantages, + "returns": returns, + } + batch_dict.update(model_inputs) + + t = time.time() + all_stats = [] + early_stop = False + for _ in range(self.config.ppo_epochs): + if early_stop: + break + b_inds = np.random.permutation(bs) + for backward_batch_start in range(0, bs, self.config.backward_batch_size): + backward_batch_end = backward_batch_start + self.config.backward_batch_size + backward_batch_inds = b_inds[backward_batch_start:backward_batch_end] + + for mini_batch_start in range(0, self.config.backward_batch_size, self.config.mini_batch_size): + mini_batch_end = mini_batch_start + self.config.mini_batch_size + mini_batch_inds = backward_batch_inds[mini_batch_start:mini_batch_end] + mini_batch_dict = { + "logprobs": batch_dict["logprobs"][mini_batch_inds], + "values": batch_dict["values"][mini_batch_inds], + "masks": batch_dict["masks"][mini_batch_inds], + # hacks: the queries and responses are ragged. + "queries": [batch_dict["queries"][i] for i in mini_batch_inds], + "responses": [batch_dict["responses"][i] for i in mini_batch_inds], + "advantages": batch_dict["advantages"][mini_batch_inds], + "returns": batch_dict["returns"][mini_batch_inds], + } + for k in model_inputs_names: + mini_batch_dict[k] = batch_dict[k][mini_batch_inds] + with self.accelerator.accumulate(self.model): + model_inputs = {k: mini_batch_dict[k] for k in model_inputs_names} + + logprobs, logits, vpreds, _ = self.batched_forward_pass( + self.model, + mini_batch_dict["queries"], + mini_batch_dict["responses"], + model_inputs, + return_logits=True, + ) + train_stats = self.train_minibatch( + mini_batch_dict["logprobs"], + mini_batch_dict["values"], + logprobs, + logits, + vpreds, + mini_batch_dict["masks"], + mini_batch_dict["advantages"], + mini_batch_dict["returns"], + ) + all_stats.append(train_stats) + + # typically, early stopping is done at the epoch level + if self.config.early_stopping: + policykl = train_stats["policy/policykl"] + early_stop = self._early_stop(policykl) + if early_stop: + break + + timing["time/ppo/optimize_step"] = time.time() - t + + t = time.time() + train_stats = stack_dicts(all_stats) + + # reshape advantages/ratios such that they are not averaged. + train_stats["policy/advantages"] = torch.flatten(train_stats["policy/advantages"]).unsqueeze(0) + train_stats["policy/advantages"] = torch.nan_to_num(train_stats["policy/advantages"], WANDB_PADDING) + train_stats["policy/ratio"] = torch.flatten(train_stats["policy/ratio"]).unsqueeze(0) + + stats = self.record_step_stats( + scores=scores, + logprobs=all_logprobs, + ref_logprobs=ref_logprobs, + non_score_reward=non_score_reward, + train_stats=train_stats, + kl_coef=self.kl_ctl.value, + masks=masks, + queries=queries, + responses=responses, + ) + # Gather/Reduce stats from all processes + if self.is_distributed: + stats = self.gather_stats(stats) + stats = stats_to_np(stats) + timing["time/ppo/calc_stats"] = time.time() - t + stats["ppo/learning_rate"] = self.optimizer.param_groups[0]["lr"] + + # Update the KL control - multiply the batch_size by the number of processes + self.kl_ctl.update( + stats["objective/kl"], + self.config.batch_size * self.accelerator.num_processes, + ) + + # Log the total ppo time + timing["time/ppo/total"] = time.time() - t0 + stats.update(timing) + + # post-process stats for tensorboard and other loggers + if self.config.log_with != "wandb": + stats = convert_to_scalar(stats) + + if self.lr_scheduler is not None: + self.lr_scheduler.step() + + return stats + + def _early_stop(self, policykl): + r""" + Handles the early stopping logic. If the policy KL is greater than the target KL, then the gradient is zeroed and + the optimization step is skipped. + This also handles the multi-gpu case where the policy KL is averaged across all processes. + + Args: + policy_kl (torch.Tensor): + the policy KL + + Returns: + `bool`: whether to early stop or not + """ + early_stop = False + if not self.config.early_stopping: + return early_stop + + if not self.is_distributed and policykl > 1.5 * self.config.target_kl: + self.optimizer.zero_grad() + early_stop = True + elif self.is_distributed: + import torch.distributed as dist + + # Wait for all processes to finish + dist.barrier() + + # all gather the policykl + dist.all_reduce(policykl, dist.ReduceOp.SUM) + policykl /= self.accelerator.num_processes + + if policykl > 1.5 * self.config.target_kl: + self.optimizer.zero_grad() + early_stop = True + return early_stop + + def gather_stats(self, stats): + """ + Gather stats from all processes. Useful in the context of distributed training. + + Args: + stats (dict[str, Any]): + a dictionary of stats to be gathered. The stats should contain torch tensors. + + Returns: + `dict[str, Any]`: A dictionary of stats with the tensors gathered. + """ + import torch.distributed as dist + + # Wait for all processes to finish + dist.barrier() + + for k, v in stats.items(): + if isinstance(v, torch.Tensor): + dist.all_reduce(v.to(self.accelerator.device), dist.ReduceOp.SUM) + v /= self.accelerator.num_processes + stats[k] = v + return stats + + def prepare_model_inputs(self, queries: torch.Tensor, responses: torch.Tensor): + if self.is_encoder_decoder: + input_data = self.data_collator( + [{"input_ids": q, "attention_mask": torch.ones_like(q)} for q in queries] + ).to(self.current_device) + + decoder_inputs = self.data_collator( + [{"input_ids": r, "attention_mask": torch.ones_like(r)} for r in responses] + ).to(self.current_device) + + input_data["decoder_input_ids"] = decoder_inputs["input_ids"] + input_data["decoder_attention_mask"] = decoder_inputs["attention_mask"] + else: + input_ids = [torch.cat([q, r]) for q, r in zip(queries, responses)] + input_data = self.data_collator( + [{"input_ids": ids, "attention_mask": torch.ones_like(ids)} for ids in input_ids] + ).to(self.current_device) + + input_data.pop("labels", None) # we don't want to compute LM losses + return input_data + + @PPODecorators.empty_device_cache() + def batched_forward_pass( + self, + model: PreTrainedModelWrapper, + queries: torch.Tensor, + responses: torch.Tensor, + model_inputs: dict, + return_logits: bool = False, + response_masks: Optional[torch.Tensor] = None, + ): + """ + Calculate model outputs in multiple batches. + + Args: + queries (`torch.LongTensor`): + List of tensors containing the encoded queries, shape (`batch_size`, `query_length`) + responses (`torch.LongTensor`): + List of tensors containing the encoded responses, shape (`batch_size`, `response_length`) + return_logits (`bool`, *optional*, defaults to `False`): + Whether to return all_logits. Set to `False` if logits are not needed to reduce memory consumption. + Returns: + (tuple): + - all_logprobs (`torch.FloatTensor`): Log probabilities of the responses, + shape (`batch_size`, `response_length`) + - all_ref_logprobs (`torch.FloatTensor`): Log probabilities of the responses, + shape (`batch_size`, `response_length`) + - all_values (`torch.FloatTensor`): Values of the responses, shape (`batch_size`, `response_length`) + """ + bs = len(queries) + fbs = self.config.mini_batch_size + all_logprobs = [] + all_logits = [] + all_masks = [] + all_values = [] + + model.eval() + + for i in range(math.ceil(bs / fbs)): + input_kwargs = {key: value[i * fbs : (i + 1) * fbs] for key, value in model_inputs.items()} + query_batch = queries[i * fbs : (i + 1) * fbs] + response_batch = responses[i * fbs : (i + 1) * fbs] + if response_masks is not None: + response_masks_batch = response_masks[i * fbs : (i + 1) * fbs] + logits, _, values = model(**input_kwargs) + + if self.is_encoder_decoder: + input_ids = input_kwargs["decoder_input_ids"] + attention_mask = input_kwargs["decoder_attention_mask"] + else: + input_ids = input_kwargs["input_ids"] + attention_mask = input_kwargs["attention_mask"] + + logprobs = logprobs_from_logits(logits[:, :-1, :], input_ids[:, 1:]) + masks = torch.zeros_like(attention_mask) + masks[:, :-1] = attention_mask[:, 1:] + + for j in range(len(query_batch)): + if self.is_encoder_decoder: + # Decoder sentence starts always in the index 1 after padding in the Enc-Dec Models + start = 1 + end = attention_mask[j, :].sum() - 1 + else: + start = len(query_batch[j]) - 1 # logprobs starts from the second query token + if attention_mask[j, 0] == 0: # offset left padding + start += attention_mask[j, :].nonzero()[0] + end = start + len(response_batch[j]) + if response_masks is not None: + response_masks_batch[j] = torch.cat( + (torch.zeros_like(query_batch[j]), response_masks_batch[j]) + )[1:] + + masks[j, :start] = 0 + masks[j, end:] = 0 + if response_masks is not None: + masks[j, start:end] = masks[j, start:end] * response_masks_batch[j][start:end] + + if return_logits: + all_logits.append(logits) + else: + del logits + all_values.append(values) + all_logprobs.append(logprobs) + all_masks.append(masks) + + return ( + torch.cat(all_logprobs), + torch.cat(all_logits)[:, :-1] if return_logits else None, + torch.cat(all_values)[:, :-1], + torch.cat(all_masks)[:, :-1], + ) + + @PPODecorators.empty_device_cache() + def train_minibatch( + self, + old_logprobs: torch.FloatTensor, + values: torch.FloatTensor, + logprobs: torch.FloatTensor, + logits: torch.FloatTensor, + vpreds: torch.FloatTensor, + mask: torch.LongTensor, + advantages: torch.FloatTensor, + returns: torch.FloatTensor, + ): + """ + Train one PPO minibatch + + Args: + logprobs (`torch.FloatTensor`): + Log probabilities of the model, shape [batch_size, response_length] + values (`torch.FloatTensor`): + Values of the value head, shape [batch_size, response_length] + query (`torch.LongTensor`): + Encoded queries, shape [batch_size, query_length] + response (`torch.LongTensor`): + Encoded responses, shape [batch_size, response_length] + model_input (`torch.LongTensor`): + Concatenated queries and responses, shape [batch_size, query_length+response_length] + + Returns: + train_stats (dict[str, `torch.Tensor`]): + Dictionary of training statistics + """ + self.model.train() + loss_p, loss_v, train_stats = self.loss( + old_logprobs, values, logits, vpreds, logprobs, mask, advantages, returns + ) + loss = loss_p + loss_v + self.accelerator.backward(loss) + if self.config.max_grad_norm is not None: + if self.accelerator.sync_gradients: + self.accelerator.clip_grad_norm_(self.model_params, self.config.max_grad_norm) + self.optimizer.step() + # we call optimizer.zero_grad() every time and let `accelerator` handle accumulation + # see https://huggingface.co/docs/accelerate/usage_guides/gradient_accumulation#the-finished-code + self.optimizer.zero_grad() + return train_stats + + def compute_rewards( + self, + scores: torch.FloatTensor, + logprobs: torch.FloatTensor, + ref_logprobs: torch.FloatTensor, + masks: torch.LongTensor, + ): + """ + Compute per token rewards from scores and KL-penalty. + + Args: + scores (`torch.FloatTensor`): + Scores from the reward model, shape (`batch_size`) + logprobs (`torch.FloatTensor`): + Log probabilities of the model, shape (`batch_size`, `response_length`) + ref_logprobs (`torch.FloatTensor`): + Log probabilities of the reference model, shape (`batch_size`, `response_length`) + """ + rewards, non_score_rewards = [], [] + for score, logprob, ref_logprob, mask in zip(scores, logprobs, ref_logprobs, masks): + # compute KL penalty (from difference in logprobs) + kl = self._kl_penalty(logprob, ref_logprob) + non_score_reward = -self.kl_ctl.value * kl + non_score_rewards.append(non_score_reward) + reward = non_score_reward.clone() + last_non_masked_index = mask.nonzero()[-1] + + # reward is preference model score + KL penalty + reward[last_non_masked_index] += score + rewards.append(reward) + return torch.stack(rewards), torch.stack(non_score_rewards) + + def _kl_penalty(self, logprob: torch.FloatTensor, ref_logprob: torch.FloatTensor) -> torch.FloatTensor: + if self.config.kl_penalty == "kl": + return logprob - ref_logprob + + if self.config.kl_penalty == "abs": + return (logprob - ref_logprob).abs() + + if self.config.kl_penalty == "mse": + return 0.5 * (logprob - ref_logprob).square() + + if self.config.kl_penalty == "full": + # Flip is required due to this issue? :https://github.com/pytorch/pytorch/issues/57459 + return F.kl_div(ref_logprob, logprob, log_target=True, reduction="none").sum(-1) + + raise NotImplementedError + + def compute_advantages( + self, + values: torch.FloatTensor, + rewards: torch.FloatTensor, + mask: torch.FloatTensor, + ): + lastgaelam = 0 + advantages_reversed = [] + gen_len = rewards.shape[-1] + + values = values * mask + rewards = rewards * mask + + if self.config.whiten_rewards: + rewards = masked_whiten(rewards, mask, shift_mean=False) + + for t in reversed(range(gen_len)): + nextvalues = values[:, t + 1] if t < gen_len - 1 else 0.0 + delta = rewards[:, t] + self.config.gamma * nextvalues - values[:, t] + lastgaelam = delta + self.config.gamma * self.config.lam * lastgaelam + advantages_reversed.append(lastgaelam) + advantages = torch.stack(advantages_reversed[::-1]).transpose(0, 1) + + returns = advantages + values + advantages = masked_whiten(advantages, mask) + advantages = advantages.detach() + return values, advantages, returns + + def loss( + self, + old_logprobs: torch.FloatTensor, + values: torch.FloatTensor, + logits: torch.FloatTensor, + vpreds: torch.FloatTensor, + logprobs: torch.FloatTensor, + mask: torch.LongTensor, + advantages: torch.FloatTensor, + returns: torch.FloatTensor, + ): + """ + Calculate policy and value losses. + + Args: + old_logprobs (`torch.FloatTensor`): + Log probabilities of the model, shape (`batch_size`, `response_length`) + values (`torch.FloatTensor`): + Values of the value head, shape (`batch_size`, `response_length`) + rewards (`torch.FloatTensor`): + Rewards from the reward model, shape (`batch_size`, `response_length`) + logits (`torch.FloatTensor`): + Logits of the model, shape (`batch_size`, `response_length`, `vocab_size`) + v_pred (`torch.FloatTensor`): + Values of the value head, shape (`batch_size`, `response_length`) + logprobs (`torch.FloatTensor`): + Log probabilities of the model, shape (`batch_size`, `response_length`) + """ + + vpredclipped = clip_by_value( + vpreds, + values - self.config.cliprange_value, + values + self.config.cliprange_value, + ) + + vf_losses1 = (vpreds - returns) ** 2 + vf_losses2 = (vpredclipped - returns) ** 2 + vf_loss = 0.5 * masked_mean(torch.max(vf_losses1, vf_losses2), mask) + vf_clipfrac = masked_mean(torch.gt(vf_losses2, vf_losses1).float(), mask) + + ratio = torch.exp(logprobs - old_logprobs) + + pg_losses = -advantages * ratio + pg_losses2 = -advantages * torch.clamp(ratio, 1.0 - self.config.cliprange, 1.0 + self.config.cliprange) + + pg_loss = masked_mean(torch.max(pg_losses, pg_losses2), mask) + pg_clipfrac = masked_mean(torch.gt(pg_losses2, pg_losses).float(), mask) + + loss = pg_loss + self.config.vf_coef * vf_loss + + avg_ratio = masked_mean(ratio, mask).item() + if avg_ratio > self.config.ratio_threshold: + warnings.warn( + f"The average ratio of batch ({avg_ratio:.2f}) exceeds threshold {self.config.ratio_threshold:.2f}. Skipping batch." + ) + pg_loss = pg_loss * 0.0 + vf_loss = vf_loss * 0.0 + loss = loss * 0.0 + + entropy = masked_mean(entropy_from_logits(logits), mask) + + approxkl = 0.5 * masked_mean((logprobs - old_logprobs) ** 2, mask) + policykl = masked_mean(old_logprobs - logprobs, mask) + + return_mean, return_var = masked_mean(returns, mask), masked_var(returns, mask) + value_mean, value_var = masked_mean(values, mask), masked_var(values, mask) + + stats = dict( + loss=dict(policy=pg_loss.detach(), value=vf_loss.detach(), total=loss.detach()), + policy=dict( + entropy=entropy.detach(), + approxkl=approxkl.detach(), + policykl=policykl.detach(), + clipfrac=pg_clipfrac.detach(), + advantages=advantages.detach(), + advantages_mean=masked_mean(advantages, mask).detach(), + ratio=ratio.detach(), + ), + returns=dict(mean=return_mean.detach(), var=return_var.detach()), + val=dict( + vpred=masked_mean(vpreds, mask).detach(), + error=masked_mean((vpreds - returns) ** 2, mask).detach(), + clipfrac=vf_clipfrac.detach(), + mean=value_mean.detach(), + var=value_var.detach(), + ), + ) + return pg_loss, self.config.vf_coef * vf_loss, flatten_dict(stats) + + def record_step_stats(self, kl_coef: float, **data): + """ + Record training step statistics. + + + Args: + kl_coef (`float`): + KL coefficient + data (`dict`): + Dictionary of training step data + + Returns: + stats (`dict`): + Dictionary of training step statistics + """ + mask = data.pop("masks") + + kl_list = ((data["logprobs"] - data["ref_logprobs"]) * mask).sum(axis=-1) + mean_kl = kl_list.mean() + mean_entropy = (-data["logprobs"] * mask).sum(axis=-1).mean() + + mean_non_score_reward = masked_mean( + data["non_score_reward"], mask + ) # non_score_reward is size `batch_size`, `response_length` + mean_scores = data["scores"].mean() # scores is size `batch_size` + std_scores = data["scores"].std() + + if mean_kl.item() < -1.0: + # warn users + warnings.warn( + f"KL divergence is starting to become negative: {mean_kl.item():.2f} - this might be a precursor for failed training." + " sometimes this happens because the generation kwargs are not correctly set. Please make sure" + " that the generation kwargs are set correctly, or review your training hyperparameters." + ) + + stats = { + "objective/kl": mean_kl, + "objective/kl_dist": kl_list, + "objective/logprobs": data["logprobs"], + "objective/ref_logprobs": data["ref_logprobs"], + "objective/kl_coef": kl_coef, + "objective/entropy": mean_entropy, + "ppo/mean_non_score_reward": mean_non_score_reward, + "ppo/mean_scores": mean_scores, + "ppo/std_scores": std_scores, + } + + # Log text properties + query_lens = torch.tensor([len(query) for query in data["queries"]], dtype=torch.float) + response_lens = torch.tensor([len(response) for response in data["responses"]], dtype=torch.float) + + stats["tokens/queries_len_mean"] = torch.mean(query_lens).cpu().numpy().item() + stats["tokens/queries_len_std"] = torch.std(query_lens).cpu().numpy().item() + stats["tokens/queries_dist"] = query_lens.cpu().numpy() + stats["tokens/responses_len_mean"] = torch.mean(response_lens).cpu().numpy().item() + stats["tokens/responses_len_std"] = torch.std(response_lens).cpu().numpy().item() + stats["tokens/responses_dist"] = response_lens.cpu().numpy() + + for k, v in data["train_stats"].items(): + stats[f"ppo/{k}"] = torch.mean(v, axis=0) + stats["ppo/val/var_explained"] = 1 - stats["ppo/val/error"] / stats["ppo/returns/var"] + return stats + + def log_stats( + self, + stats: dict, + batch: dict, + rewards: List[torch.FloatTensor], + columns_to_log: List[str] = ["query", "response"], + ): + """ + A function that logs all the training stats. Call it at the end of each epoch. + + Args: + stats (dict[str, Any]): + A dictionary of training stats. + batch (dict[str, Any]): + A dictionary of batch data, this contains the queries and responses. + rewards (`List[torch.FloatTensor]`): + A tensor of rewards. + """ + if not isinstance(rewards, torch.Tensor): + rewards = torch.tensor(rewards).to(self.current_device) + rewards = self.accelerator.gather(rewards).flatten() + + # Log only if we are in the main process + if self.accelerator.is_main_process: + logs = {} + + # Log stats + if "query" not in batch.keys() and "response" not in batch.keys(): + # warn the user that the game logs will not be logged + warnings.warn( + "The game logs will not be logged because the batch does not contain the keys 'query' and " + "'response'. " + ) + elif self.config.log_with == "wandb": + import wandb + + if any([column_to_log not in batch.keys() for column_to_log in columns_to_log]): + raise ValueError(f"Columns to log {columns_to_log} are not present in the batch {batch.keys()}.") + + batch_list = [batch[column_to_log] for column_to_log in columns_to_log] + if self.is_distributed: + self.accelerator.wait_for_everyone() + gathered_batch_list = [] + for batch in batch_list: + flattened = gather_object(batch) + gathered_batch_list.append(flattened) + batch_list = gathered_batch_list + + table_rows = [list(r) for r in zip(*batch_list, rewards.cpu().tolist())] + logs.update({"game_log": wandb.Table(columns=[*columns_to_log, "reward"], rows=table_rows)}) + + logs.update(stats) + + # manually cast in fp32 for bf16 torch tensors + for k, v in logs.items(): + if isinstance(v, torch.Tensor) and v.dtype == torch.bfloat16: + logs[k] = v.float() + + logs["env/reward_mean"] = torch.mean(rewards).cpu().numpy().item() + logs["env/reward_std"] = torch.std(rewards).cpu().numpy().item() + logs["env/reward_dist"] = rewards.cpu().numpy() + + if self.config.log_with == "tensorboard": + # update the current step + self.current_step += 1 + + self.accelerator.log( + logs, + step=self.current_step if self.config.log_with == "tensorboard" else None, + ) + + def create_model_card(self, path: str, model_name: Optional[str] = "TRL Model") -> None: + """Creates and saves a model card for a TRL model. + + Args: + path (`str`): The path to save the model card to. + model_name (`str`, *optional*): The name of the model, defaults to `TRL Model`. + """ + try: + user = whoami()["name"] + # handle the offline case + except: # noqa + warnings.warn("Cannot retrieve user information assuming you are running in offline mode.") + return + + if not os.path.exists(path): + os.makedirs(path) + + model_card_content = MODEL_CARD_TEMPLATE.format(model_name=model_name, model_id=f"{user}/{path}") + with open(os.path.join(path, "README.md"), "w", encoding="utf-8") as f: + f.write(model_card_content) + + def _save_pretrained(self, save_directory: str) -> None: + self.accelerator.unwrap_model(self.model).save_pretrained(save_directory) + self.tokenizer.save_pretrained(save_directory) + self.create_model_card(save_directory) + + def _show_tokens(self, tokens, masks): + from rich import print + from rich.text import Text + + text = Text() + + for i, (token, mask) in enumerate(zip(tokens, masks)): + if mask == 1: + text.append(self.tokenizer.decode(token.item()), style="black on deep_sky_blue1") + text.append(" ") + else: + text.append(self.tokenizer.decode(token.item()), style="black on cyan3") + text.append(" ") + print(text) + + def _prepare_deepspeed(self, model: PreTrainedModelWrapper): + # Adapted from accelerate: https://github.com/huggingface/accelerate/blob/739b135f8367becb67ffaada12fe76e3aa60fefd/src/accelerate/accelerator.py#L1473 + deepspeed_plugin = self.accelerator.state.deepspeed_plugin + config_kwargs = deepspeed_plugin.deepspeed_config + if model is not None: + if hasattr(model, "config"): + hidden_size = ( + max(model.config.hidden_sizes) + if getattr(model.config, "hidden_sizes", None) + else getattr(model.config, "hidden_size", None) + ) + if hidden_size is not None and config_kwargs["zero_optimization"]["stage"] == 3: + # Note that `stage3_prefetch_bucket_size` can produce DeepSpeed messages like: `Invalidate trace cache @ step 0: expected module 1, but got module 0` + # This is expected and is not an error, see: https://github.com/microsoft/DeepSpeed/discussions/4081 + config_kwargs.update( + { + "zero_optimization.reduce_bucket_size": hidden_size * hidden_size, + "zero_optimization.stage3_param_persistence_threshold": 10 * hidden_size, + "zero_optimization.stage3_prefetch_bucket_size": 0.9 * hidden_size * hidden_size, + } + ) + + # If ZeRO-3 is used, we shard both the active and reference model. + # Otherwise, we assume the reference model fits in memory and is initialized on each device with ZeRO disabled (stage 0) + if config_kwargs["zero_optimization"]["stage"] != 3: + config_kwargs["zero_optimization"]["stage"] = 0 + model, *_ = deepspeed.initialize(model=model, config=config_kwargs) + model.eval() + return model diff --git a/trl/trainer/reward_trainer.py b/trl/trainer/reward_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..ed81ef73e653f9c25fb1fee06109cd332d35f810 --- /dev/null +++ b/trl/trainer/reward_trainer.py @@ -0,0 +1,277 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import inspect +import warnings +from dataclasses import FrozenInstanceError, replace +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import torch +import torch.nn as nn +from datasets import Dataset +from transformers import DataCollator, PreTrainedModel, PreTrainedTokenizerBase, Trainer, TrainingArguments +from transformers.trainer_callback import TrainerCallback +from transformers.trainer_pt_utils import nested_detach +from transformers.trainer_utils import EvalPrediction + +from ..import_utils import is_peft_available +from .training_configs import RewardConfig +from .utils import PeftSavingCallback, RewardDataCollatorWithPadding, compute_accuracy + + +if is_peft_available(): + from peft import PeftModel, get_peft_model, prepare_model_for_kbit_training + + +class RewardTrainer(Trainer): + r""" + The RewardTrainer can be used to train your custom Reward Model. It is a subclass of the + `transformers.Trainer` class and inherits all of its attributes and methods. It is recommended to use + an `AutoModelForSequenceClassification` as the reward model. The reward model should be trained on a dataset + of paired examples, where each example is a tuple of two sequences. The reward model should be trained to + predict which example in the pair is more relevant to the task at hand. + + The reward trainer expects a very specific format for the dataset. The dataset should contain two 4 entries at least + if you don't use the default `RewardDataCollatorWithPadding` data collator. The entries should be named + - `input_ids_chosen` + - `attention_mask_chosen` + - `input_ids_rejected` + - `attention_mask_rejected` + + Optionally, you can also pass a `margin` entry to the dataset. This entry should contain the margin used to modulate the + loss of the reward model as outlined in https://ai.meta.com/research/publications/llama-2-open-foundation-and-fine-tuned-chat-models/. + If you don't pass a margin, no margin will be used. + """ + + def __init__( + self, + model: Union[PreTrainedModel, nn.Module] = None, + args: Optional[RewardConfig] = None, + data_collator: Optional[DataCollator] = None, + train_dataset: Optional[Dataset] = None, + eval_dataset: Optional[Union[Dataset, Dict[str, Dataset]]] = None, + tokenizer: Optional[PreTrainedTokenizerBase] = None, + model_init: Optional[Callable[[], PreTrainedModel]] = None, + compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None, + callbacks: Optional[List[TrainerCallback]] = None, + optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = ( + None, + None, + ), + preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None, + max_length: Optional[int] = None, + peft_config: Optional[Dict] = None, + ): + """ + Initialize RewardTrainer. + + Args: + model (`transformers.PreTrainedModel`): + The model to train, preferably an `AutoModelForSequenceClassification`. + args (`RewardConfig`): + The arguments to use for training. + data_collator (`transformers.DataCollator`): + The data collator to use for training. If None is specified, the default data collator (`RewardDataCollatorWithPadding`) will be used + which will pad the sequences to the maximum length of the sequences in the batch, given a dataset of paired sequences. + train_dataset (`datasets.Dataset`): + The dataset to use for training. + eval_dataset (`datasets.Dataset`): + The dataset to use for evaluation. + tokenizer (`transformers.PreTrainedTokenizerBase`): + The tokenizer to use for training. This argument is required if you want to use the default data collator. + model_init (`Callable[[], transformers.PreTrainedModel]`): + The model initializer to use for training. If None is specified, the default model initializer will be used. + compute_metrics (`Callable[[transformers.EvalPrediction], Dict]`, *optional* defaults to `compute_accuracy`): + The metrics to use for evaluation. If no metrics are specified, the default metric (`compute_accuracy`) will be used. + callbacks (`List[transformers.TrainerCallback]`): + The callbacks to use for training. + optimizers (`Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`): + The optimizer and scheduler to use for training. + preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`): + The function to use to preprocess the logits before computing the metrics. + peft_config (`Dict`, defaults to `None`): + The PEFT configuration to use for training. If you pass a PEFT configuration, the model will be wrapped in a PEFT model. + """ + if type(args) == TrainingArguments: + warnings.warn( + "Using `transformers.TrainingArguments` for `args` is deprecated and will be removed in a future version. Please use `RewardConfig` instead.", + FutureWarning, + ) + if max_length is not None: + warnings.warn( + "The `max_length` argument is deprecated and will be removed in a future version. Please use the `RewardConfig` to set `max_length` instead.", + FutureWarning, + ) + else: + if max_length is not None and args.max_length is not None: + raise ValueError( + "You cannot specify both `max_length` and `args.max_length`. Please use the `RewardConfig` to set `max_length` once." + ) + if max_length is not None and args.max_length is None: + warnings.warn( + "The `max_length` argument is deprecated and will be removed in a future version. Please use the `RewardConfig` to set `max_length` instead.", + FutureWarning, + ) + if not is_peft_available() and peft_config is not None: + raise ValueError( + "PEFT is not installed and you passed a `peft_config` in the trainer's kwargs, please install it to use the PEFT models" + ) + elif is_peft_available() and peft_config is not None: + if not isinstance(model, PeftModel): + if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_quantized", False): + _supports_gc_kwargs = "gradient_checkpointing_kwargs" in list( + inspect.signature(prepare_model_for_kbit_training).parameters + ) + + preprare_model_kwargs = {"use_gradient_checkpointing": args.gradient_checkpointing} + + if not _supports_gc_kwargs and args.gradient_checkpointing_kwargs is not None: + warnings.warn( + "You passed `gradient_checkpointing_kwargs` in the trainer's kwargs, but your peft version does not support it. " + "please update to the latest version of peft to use `gradient_checkpointing_kwargs`." + ) + elif _supports_gc_kwargs and args.gradient_checkpointing_kwargs is not None: + preprare_model_kwargs["gradient_checkpointing_kwargs"] = args.gradient_checkpointing_kwargs + + model = prepare_model_for_kbit_training(model, **preprare_model_kwargs) + + model = get_peft_model(model, peft_config) + + if is_peft_available() and isinstance(model, PeftModel): + if callbacks is None: + callbacks = [PeftSavingCallback()] + else: + callbacks += [PeftSavingCallback()] + + if compute_metrics is None: + compute_metrics = compute_accuracy + + if data_collator is None: + if tokenizer is None: + raise ValueError( + "max_length or a tokenizer must be specified when using the default RewardDataCollatorWithPadding" + ) + if type(args) == TrainingArguments: + if max_length is None: + warnings.warn( + "When using RewardDataCollatorWithPadding, you should set `max_length` in RewardConfig." + " It will be set to `512` by default, but you should do it yourself in the future.", + UserWarning, + ) + max_length = 512 + else: + if max_length is None and args.max_length is None: + warnings.warn( + "When using RewardDataCollatorWithPadding, you should set `max_length` in RewardConfig." + " It will be set to `512` by default, but you should do it yourself in the future.", + UserWarning, + ) + max_length = 512 + if max_length is None and args.max_length is not None: + max_length = args.max_length + + data_collator = RewardDataCollatorWithPadding(tokenizer, max_length=max_length) + + if args.remove_unused_columns: + try: # for bc before https://github.com/huggingface/transformers/pull/25435 + args.remove_unused_columns = False + except FrozenInstanceError: + args = replace(args, remove_unused_columns=False) + # warn users + warnings.warn( + "When using RewardDataCollatorWithPadding, you should set `remove_unused_columns=False` in your RewardConfig" + " we have set it for you, but you should do it yourself in the future.", + UserWarning, + ) + + self.use_reward_data_collator = True + else: + self.use_reward_data_collator = False + super().__init__( + model, + args, + data_collator, + train_dataset, + eval_dataset, + tokenizer, + model_init, + compute_metrics, + callbacks, + optimizers, + preprocess_logits_for_metrics, + ) + + def compute_loss( + self, + model: Union[PreTrainedModel, nn.Module], + inputs: Dict[str, Union[torch.Tensor, Any]], + return_outputs=False, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, Dict[str, torch.Tensor]]]: + if not self.use_reward_data_collator: + warnings.warn( + "The current compute_loss is implemented for RewardDataCollatorWithPadding," + " if you are using a custom data collator make sure you know what you are doing or" + " implement your own compute_loss method." + ) + rewards_chosen = model( + input_ids=inputs["input_ids_chosen"], + attention_mask=inputs["attention_mask_chosen"], + )[0] + rewards_rejected = model( + input_ids=inputs["input_ids_rejected"], + attention_mask=inputs["attention_mask_rejected"], + )[0] + # calculate loss, optionally modulate with margin + if "margin" in inputs: + loss = -nn.functional.logsigmoid(rewards_chosen - rewards_rejected - inputs["margin"]).mean() + else: + loss = -nn.functional.logsigmoid(rewards_chosen - rewards_rejected).mean() + + if return_outputs: + return loss, { + "rewards_chosen": rewards_chosen, + "rewards_rejected": rewards_rejected, + } + return loss + + def prediction_step( + self, + model: Union[PreTrainedModel, nn.Module], + inputs: Dict[str, Union[torch.Tensor, Any]], + prediction_loss_only: bool, + ignore_keys: Optional[List[str]] = None, + ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]: + inputs = self._prepare_inputs(inputs) + if ignore_keys is None: + if hasattr(self.model, "config"): + ignore_keys = getattr(self.model.config, "keys_to_ignore_at_inference", []) + else: + ignore_keys = [] + + with torch.no_grad(): + loss, logits_dict = self.compute_loss(model, inputs, return_outputs=True) + + if prediction_loss_only: + return (loss, None, None) + + loss = loss.detach() + logits = tuple(v for k, v in logits_dict.items() if k not in ignore_keys) + logits = nested_detach(logits) + # Stack accepted against rejected, mean over logits + # and softmax to get preferences between accepted and rejected to sum to 1 + logits = torch.stack(logits).mean(dim=2).softmax(dim=0).T + + labels = torch.zeros(logits.shape[0]) + labels = self._prepare_inputs(labels) + + return loss, logits, labels diff --git a/trl/trainer/sft_trainer.py b/trl/trainer/sft_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..6c4a6c536a651c6d5e02dc5c580a35cdabb7e106 --- /dev/null +++ b/trl/trainer/sft_trainer.py @@ -0,0 +1,451 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import dataclasses +import inspect +import warnings +from functools import wraps +from typing import Callable, Dict, List, Optional, Tuple, Union + +import torch +import torch.nn as nn +from datasets import Dataset +from datasets.arrow_writer import SchemaInferenceError +from datasets.builder import DatasetGenerationError +from transformers import ( + AutoModelForCausalLM, + AutoTokenizer, + DataCollator, + DataCollatorForLanguageModeling, + PreTrainedModel, + PreTrainedTokenizerBase, + Trainer, + TrainingArguments, +) +from transformers.modeling_utils import unwrap_model +from transformers.trainer_callback import TrainerCallback +from transformers.trainer_utils import EvalPrediction + +from ..import_utils import is_peft_available +from .utils import ( + ConstantLengthDataset, + DataCollatorForCompletionOnlyLM, + PeftSavingCallback, + neftune_post_forward_hook, +) + + +if is_peft_available(): + from peft import PeftConfig, PeftModel, get_peft_model, prepare_model_for_kbit_training + + +class SFTTrainer(Trainer): + r""" + Class definition of the Supervised Finetuning Trainer (SFT Trainer). + This class is a wrapper around the `transformers.Trainer` class and inherits all of its attributes and methods. + The trainer takes care of properly initializing the PeftModel in case a user passes a `PeftConfig` object. + + Args: + model (Union[`transformers.PreTrainedModel`, `nn.Module`, `str`]): + The model to train, can be a `PreTrainedModel`, a `torch.nn.Module` or a string with the model name to + load from cache or download. The model can be also converted to a `PeftModel` if a `PeftConfig` object is + passed to the `peft_config` argument. + args (Optional[`transformers.TrainingArguments`]): + The arguments to tweak for training. Please refer to the official documentation of `transformers.TrainingArguments` + for more information. + data_collator (Optional[`transformers.DataCollator`]): + The data collator to use for training. + train_dataset (Optional[`datasets.Dataset`]): + The dataset to use for training. We recommend users to use `trl.trainer.ConstantLengthDataset` to create their dataset. + eval_dataset (Optional[Union[`datasets.Dataset`, Dict[`str`, `datasets.Dataset`]]]): + The dataset to use for evaluation. We recommend users to use `trl.trainer.ConstantLengthDataset` to create their dataset. + tokenizer (Optional[`transformers.PreTrainedTokenizer`]): + The tokenizer to use for training. If not specified, the tokenizer associated to the model will be used. + model_init (`Callable[[], transformers.PreTrainedModel]`): + The model initializer to use for training. If None is specified, the default model initializer will be used. + compute_metrics (`Callable[[transformers.EvalPrediction], Dict]`, *optional* defaults to None): + The function used to compute metrics during evaluation. It should return a dictionary mapping metric names to metric values. + If not specified, only the loss will be computed during evaluation. + callbacks (`List[transformers.TrainerCallback]`): + The callbacks to use for training. + optimizers (`Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`): + The optimizer and scheduler to use for training. + preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`): + The function to use to preprocess the logits before computing the metrics. + peft_config (`Optional[PeftConfig]`): + The PeftConfig object to use to initialize the PeftModel. + dataset_text_field (`Optional[str]`): + The name of the text field of the dataset, in case this is passed by a user, the trainer will automatically create a + `ConstantLengthDataset` based on the `dataset_text_field` argument. + formatting_func (`Optional[Callable]`): + The formatting function to be used for creating the `ConstantLengthDataset`. + max_seq_length (`Optional[int]`): + The maximum sequence length to use for the `ConstantLengthDataset` and for automaticallty creating the Dataset. Defaults to `512`. + infinite (`Optional[bool]`): + Whether to use an infinite dataset or not. Defaults to `False`. + num_of_sequences (`Optional[int]`): + The number of sequences to use for the `ConstantLengthDataset`. Defaults to `1024`. + chars_per_token (`Optional[float]`): + The number of characters per token to use for the `ConstantLengthDataset`. Defaults to `3.6`. You can check how this is computed in the + stack-llama example: https://github.com/huggingface/trl/blob/08f550674c553c36c51d1027613c29f14f3676a5/examples/stack_llama/scripts/supervised_finetuning.py#L53. + packing (`Optional[bool]`): + Used only in case `dataset_text_field` is passed. This argument is used by the `ConstantLengthDataset` to pack the sequences + of the dataset. + dataset_num_proc (`Optional[int]`): + The number of workers to use to tokenize the data. Only used when `packing=False`. Defaults to None. + dataset_batch_size (`int`): + The number of examples to tokenize per batch. If batch_size <= 0 or batch_size == None, + tokenize the full dataset as a single batch. Defaults to 1000. + neftune_noise_alpha (`Optional[float]`): + If not `None`, this will activate NEFTune noise embeddings. This has been proven to drastically improve model performances for instrcution + fine-tuning. Check out the original paper here: https://arxiv.org/abs/2310.05914 and the original code here: https://github.com/neelsjain/NEFTune + model_init_kwargs: (`Optional[Dict]`, *optional*): + Dict of Optional kwargs to pass when instantiating the model from a string + """ + + def __init__( + self, + model: Union[PreTrainedModel, nn.Module, str] = None, + args: TrainingArguments = None, + data_collator: Optional[DataCollator] = None, + train_dataset: Optional[Dataset] = None, + eval_dataset: Optional[Union[Dataset, Dict[str, Dataset]]] = None, + tokenizer: Optional[PreTrainedTokenizerBase] = None, + model_init: Optional[Callable[[], PreTrainedModel]] = None, + compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None, + callbacks: Optional[List[TrainerCallback]] = None, + optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None), + preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None, + peft_config: Optional["PeftConfig"] = None, + dataset_text_field: Optional[str] = None, + packing: Optional[bool] = False, + formatting_func: Optional[Callable] = None, + max_seq_length: Optional[int] = None, + infinite: Optional[bool] = None, + num_of_sequences: Optional[int] = 1024, + chars_per_token: Optional[float] = 3.6, + dataset_num_proc: Optional[int] = None, + dataset_batch_size: int = 1000, + neftune_noise_alpha: Optional[float] = None, + model_init_kwargs: Optional[Dict] = None, + ): + if model_init_kwargs is None: + model_init_kwargs = {} + elif not isinstance(model, str): + raise ValueError("You passed model_kwargs to the SFTTrainer. But your model is already instantiated.") + + if infinite is not None: + warnings.warn( + "The `infinite` argument is deprecated and will be removed in a future version of TRL. Use `TrainingArguments.max_steps` or `TrainingArguments.num_train_epochs` instead to control training length." + ) + + if isinstance(model, str): + warnings.warn( + "You passed a model_id to the SFTTrainer. This will automatically create an " + "`AutoModelForCausalLM` or a `PeftModel` (if you passed a `peft_config`) for you." + ) + model = AutoModelForCausalLM.from_pretrained(model, **model_init_kwargs) + + if packing and data_collator is not None and isinstance(data_collator, DataCollatorForCompletionOnlyLM): + raise ValueError( + "You passed a `DataCollatorForCompletionOnlyLM` to the SFTTrainer. This is not compatible with the `packing` argument." + ) + + if is_peft_available() and peft_config is not None: + if not isinstance(peft_config, PeftConfig): + raise ValueError( + "If you want to use the PeftModel, you need to pass a PeftConfig object to the SFTTrainer." + f" and you passed a {type(peft_config)}." + ) + + if not isinstance(model, PeftModel): + if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False): + _support_gc_kwargs = hasattr( + args, "gradient_checkpointing_kwargs" + ) and "gradient_checkpointing_kwargs" in list( + inspect.signature(prepare_model_for_kbit_training).parameters + ) + + preprare_model_kwargs = { + "use_gradient_checkpointing": getattr(args, "gradient_checkpointing", False) + } + + if _support_gc_kwargs: + preprare_model_kwargs["gradient_checkpointing_kwargs"] = getattr( + args, "gradient_checkpointing_kwargs", None + ) + + model = prepare_model_for_kbit_training(model, **preprare_model_kwargs) + + if args is not None: + args = dataclasses.replace(args, gradient_checkpointing=False) + elif getattr(args, "gradient_checkpointing", False): + # For backward compatibility with older versions of transformers + if hasattr(model, "enable_input_require_grads"): + model.enable_input_require_grads() + else: + + def make_inputs_require_grad(module, input, output): + output.requires_grad_(True) + + model.get_input_embeddings().register_forward_hook(make_inputs_require_grad) + + model = get_peft_model(model, peft_config) + + if callbacks is None: + callbacks = [PeftSavingCallback] + + if tokenizer is None: + tokenizer = AutoTokenizer.from_pretrained(model.config._name_or_path) + if getattr(tokenizer, "pad_token", None) is None: + tokenizer.pad_token = tokenizer.eos_token + + if max_seq_length is None: + # to overcome some issues with broken tokenizers + max_seq_length = min(tokenizer.model_max_length, 1024) + + warnings.warn( + f"You didn't pass a `max_seq_length` argument to the SFTTrainer, this will default to {max_seq_length}" + ) + + self.dataset_num_proc = dataset_num_proc + self.dataset_batch_size = dataset_batch_size + + self._trainer_supports_neftune = hasattr(args, "neftune_noise_alpha") + + if neftune_noise_alpha is not None and self._trainer_supports_neftune: + args.neftune_noise_alpha = neftune_noise_alpha + warnings.warn( + "You passed a `neftune_noise_alpha` argument to the SFTTrainer, the value you passed will override the one in the `TrainingArguments`." + ) + # self.neftune_noise_alpha is done at Trainer level + elif not self._trainer_supports_neftune: + self.neftune_noise_alpha = neftune_noise_alpha + + if not packing: + if dataset_text_field is None and formatting_func is None: + raise ValueError( + "You passed `packing=False` to the SFTTrainer, but you didn't pass a `dataset_text_field` or `formatting_func` argument." + ) + + if data_collator is None: + data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False) + + if train_dataset is not None: + train_dataset = self._prepare_dataset( + train_dataset, + tokenizer, + packing, + dataset_text_field, + max_seq_length, + formatting_func, + num_of_sequences, + chars_per_token, + ) + if eval_dataset is not None: + _multiple = isinstance(eval_dataset, dict) + _eval_datasets = eval_dataset if _multiple else {"singleton": eval_dataset} + for _eval_dataset_name, _eval_dataset in _eval_datasets.items(): + _eval_datasets[_eval_dataset_name] = self._prepare_dataset( + _eval_dataset, + tokenizer, + packing, + dataset_text_field, + max_seq_length, + formatting_func, + num_of_sequences, + chars_per_token, + ) + if not _multiple: + eval_dataset = _eval_datasets["singleton"] + + if tokenizer.padding_side is not None and tokenizer.padding_side != "right": + warnings.warn( + "You passed a tokenizer with `padding_side` not equal to `right` to the SFTTrainer. This might lead to some unexpected behaviour due to " + "overflow issues when training a model in half-precision. You might consider adding `tokenizer.padding_side = 'right'` to your code." + ) + + super().__init__( + model=model, + args=args, + data_collator=data_collator, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + tokenizer=tokenizer, + model_init=model_init, + compute_metrics=compute_metrics, + callbacks=callbacks, + optimizers=optimizers, + preprocess_logits_for_metrics=preprocess_logits_for_metrics, + ) + + if self.args.max_steps > 0 and packing: + warnings.warn( + "You passed `packing=True` to the SFTTrainer, and you are training your model with `max_steps` strategy. The dataset will be iterated until the `max_steps` are reached." + ) + self.train_dataset.infinite = True + elif self.args.max_steps == -1 and packing: + self.train_dataset.infinite = False + + @wraps(Trainer.train) + def train(self, *args, **kwargs): + # Activate neftune right before training. + if self.neftune_noise_alpha is not None and not self._trainer_supports_neftune: + self.model = self._trl_activate_neftune(self.model) + + output = super().train(*args, **kwargs) + + # After training we make sure to retrieve back the original forward pass method + # for the embedding layer by removing the forward post hook. + if self.neftune_noise_alpha is not None and not self._trainer_supports_neftune: + unwrapped_model = unwrap_model(self.model) + if is_peft_available() and isinstance(unwrapped_model, PeftModel): + embeddings = unwrapped_model.base_model.model.get_input_embeddings() + else: + embeddings = unwrapped_model.get_input_embeddings() + + self.neftune_hook_handle.remove() + del embeddings.neftune_noise_alpha + + return output + + def _prepare_dataset( + self, + dataset, + tokenizer, + packing, + dataset_text_field, + max_seq_length, + formatting_func, + num_of_sequences, + chars_per_token, + ): + if dataset is None: + raise ValueError("The dataset should not be None") + + # check if torch dataset / dataloader and do nothing + if isinstance(dataset, (torch.utils.data.IterableDataset, torch.utils.data.Dataset, ConstantLengthDataset)): + return dataset + + if not packing: + return self._prepare_non_packed_dataloader( + tokenizer, dataset, dataset_text_field, max_seq_length, formatting_func + ) + + else: + return self._prepare_packed_dataloader( + tokenizer, + dataset, + dataset_text_field, + max_seq_length, + num_of_sequences, + chars_per_token, + formatting_func, + ) + + def _prepare_non_packed_dataloader( + self, tokenizer, dataset, dataset_text_field, max_seq_length, formatting_func=None + ): + use_formatting_func = formatting_func is not None and dataset_text_field is None + self._dataset_sanity_checked = False + + # Inspired from: https://huggingface.co/learn/nlp-course/chapter7/6?fw=pt + def tokenize(element): + outputs = tokenizer( + element[dataset_text_field] if not use_formatting_func else formatting_func(element), + truncation=True, + padding=False, + max_length=max_seq_length, + return_overflowing_tokens=False, + return_length=False, + ) + + if use_formatting_func and not self._dataset_sanity_checked: + if not isinstance(formatting_func(element), list): + raise ValueError( + "The `formatting_func` should return a list of processed strings since it can lead to silent bugs." + ) + else: + self._dataset_sanity_checked = True + + return {"input_ids": outputs["input_ids"], "attention_mask": outputs["attention_mask"]} + + tokenized_dataset = dataset.map( + tokenize, + batched=True, + remove_columns=dataset.column_names, + num_proc=self.dataset_num_proc, + batch_size=self.dataset_batch_size, + ) + + return tokenized_dataset + + def _prepare_packed_dataloader( + self, + tokenizer, + dataset, + dataset_text_field, + max_seq_length, + num_of_sequences, + chars_per_token, + formatting_func=None, + ): + if dataset_text_field is not None or formatting_func is not None: + if tokenizer is None: + raise ValueError("You need to pass a tokenizer when using `dataset_text_field` with `SFTTrainer`.") + + constant_length_iterator = ConstantLengthDataset( + tokenizer, + dataset, + dataset_text_field=dataset_text_field, + formatting_func=formatting_func, + seq_length=max_seq_length, + infinite=False, + num_of_sequences=num_of_sequences, + chars_per_token=chars_per_token, + eos_token_id=tokenizer.eos_token_id, + ) + + def data_generator(constant_length_iterator): + for i in constant_length_iterator: + yield i + + try: + packed_dataset = Dataset.from_generator( + data_generator, gen_kwargs={"constant_length_iterator": constant_length_iterator} + ) + except (DatasetGenerationError, SchemaInferenceError): + raise ValueError( + "Error occurred while packing the dataset. Make sure that your dataset has enough samples to at least yield one packed sequence." + ) + return packed_dataset + else: + raise ValueError( + "You need to pass a `dataset_text_field` or `formatting_func` argument to the SFTTrainer if you want to use the `ConstantLengthDataset`." + ) + + def _trl_activate_neftune(self, model): + r""" + Activates the neftune as presented in this code: https://github.com/neelsjain/NEFTune and paper: https://arxiv.org/abs/2310.05914 + Since in transformers Trainer we do have an `_activate_neftune` method, we need to rename this method to avoid conflicts. + """ + unwrapped_model = unwrap_model(model) + if is_peft_available() and isinstance(unwrapped_model, PeftModel): + embeddings = unwrapped_model.base_model.model.get_input_embeddings() + else: + embeddings = unwrapped_model.get_input_embeddings() + + embeddings.neftune_noise_alpha = self.neftune_noise_alpha + hook_handle = embeddings.register_forward_hook(neftune_post_forward_hook) + self.neftune_hook_handle = hook_handle + return model diff --git a/trl/trainer/training_configs.py b/trl/trainer/training_configs.py new file mode 100644 index 0000000000000000000000000000000000000000..8341819c3486664a8227290ed5ba1574b1932d9a --- /dev/null +++ b/trl/trainer/training_configs.py @@ -0,0 +1,43 @@ +# coding=utf-8 +# coding=utf-8 +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from dataclasses import dataclass +from typing import Optional + +from transformers import TrainingArguments + + +@dataclass +class RewardConfig(TrainingArguments): + """ + RewardConfig collects all training arguments related to the [`RewardTrainer`] class. + + Using [`HfArgumentParser`] we can turn this class into + [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the + command line. + + Parameters: + max_length (`int`, *optional*, defaults to `None`): + The maximum length of the sequences in the batch. This argument is required if you want to use the default data collator. + gradient_checkpointing (`bool`, *optional*, defaults to `True`): + If True, use gradient checkpointing to save memory at the expense of slower backward pass. + """ + + max_length: Optional[int] = None + """The maximum length of the sequences in the batch. This argument is required if you want to use the default data collator.""" + gradient_checkpointing: Optional[bool] = True + """If True, use gradient checkpointing to save memory at the expense of slower backward pass.""" + gradient_checkpointing_kwargs: Optional[dict] = None + """Keyword arguments to pass to the gradient checkpointing function.""" diff --git a/trl/trainer/utils.py b/trl/trainer/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..3c29bedfb269a21fcf049769dd0b34102d21fe26 --- /dev/null +++ b/trl/trainer/utils.py @@ -0,0 +1,779 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +import random +import warnings +from collections import deque +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Tuple, Union + +import numpy as np +import torch +from torch.nn.utils.rnn import pad_sequence +from torch.utils.data import IterableDataset +from transformers import DataCollatorForLanguageModeling, PreTrainedModel, PreTrainedTokenizerBase, TrainerCallback + + +class AdaptiveKLController: + """ + Adaptive KL controller described in the paper: + https://arxiv.org/pdf/1909.08593.pdf + """ + + def __init__(self, init_kl_coef, target, horizon): + self.value = init_kl_coef + self.target = target + self.horizon = horizon + + def update(self, current, n_steps): + target = self.target + proportional_error = np.clip(current / target - 1, -0.2, 0.2) + mult = 1 + proportional_error * n_steps / self.horizon + self.value *= mult + + +class FixedKLController: + """Fixed KL controller.""" + + def __init__(self, kl_coef): + self.value = kl_coef + + def update(self, current, n_steps): + pass + + +class DataCollatorForCompletionOnlyLM(DataCollatorForLanguageModeling): + """ + Data collator used for completion tasks. It ensures that all the tokens of the labels are set to an 'ignore_index' + when they do not come from the assistant. This ensure that the loss is only + calculated on the completion made by the assistant. + + Args: + instruction_template (`Optional[str]`): the template form that indicates the start of the human instruction, typically something like + '### Human:\n'. Useful for assistant-style conversation datasets + response_template (`Union[str, List[int]]`): the template form that indicates the start of the response, typically something like + '### Response:\n'. It can also be passed as tokenized ids, which can be useful when using a tokenizer that encodes the response + differently if it does not have proper context. + mlm (`bool`, *optional*, defaults to `False`): Whether or not to use masked language modeling in the underlying + `DataCollatorForLanguageModeling` class. Note that this option currently has no effect but is present + for flexibility and backwards-compatibility. + ignore_index (`int`, *optional*, defaults to `-100`): + The index to use to ignore the initial tokens with + """ + + def __init__( + self, + response_template: Union[str, List[int]], + instruction_template: Union[str, List[int]] = None, + *args, + mlm: bool = False, + ignore_index: int = -100, + **kwargs, + ): + super().__init__(*args, mlm=mlm, **kwargs) + + self.instruction_template = instruction_template + if isinstance(instruction_template, str): + # The user provides a string, must tokenize + self.instruction_token_ids = self.tokenizer.encode(self.instruction_template, add_special_tokens=False) + else: + # The user already provides the token ids + self.instruction_token_ids = instruction_template + + self.response_template = response_template + if isinstance(response_template, str): + # The user provides a string, must tokenize + self.response_token_ids = self.tokenizer.encode(self.response_template, add_special_tokens=False) + else: + # The user already provides the token ids + self.response_token_ids = response_template + + if not self.mlm and self.instruction_template and self.tokenizer.pad_token_id == self.tokenizer.eos_token_id: + warnings.warn( + "The pad_token_id and eos_token_id values of this tokenizer are identical. " + "If you are planning for multi-turn training, " + "it can result in the model continuously generating questions and answers without eos token. " + "To avoid this, set the pad_token_id to a different value." + ) + + self.ignore_index = ignore_index + + def torch_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]: + batch = super().torch_call(examples) + + if self.instruction_template is None: + for i in range(len(examples)): + response_token_ids_start_idx = None + + for idx in np.where(batch["labels"][i] == self.response_token_ids[0])[0]: + # `response_token_ids` is `'### Response:\n'`, here we are just making sure that the token IDs match + if ( + self.response_token_ids + == batch["labels"][i][idx : idx + len(self.response_token_ids)].tolist() + ): + response_token_ids_start_idx = idx + + if response_token_ids_start_idx is None: + warnings.warn( + f"Could not find response key `{self.response_template}` in the " + f'following instance: {self.tokenizer.decode(batch["input_ids"][i])} ' + f"This instance will be ignored in loss calculation. " + f"Note, if this happens often, consider increasing the `max_seq_length`." + ) + batch["labels"][i, :] = self.ignore_index + else: + response_token_ids_end_idx = response_token_ids_start_idx + len(self.response_token_ids) + + # Make pytorch loss function ignore all tokens up through the end of the response key + batch["labels"][i, :response_token_ids_end_idx] = self.ignore_index + + else: + for i in range(len(examples)): + response_token_ids_idxs = [] + human_token_ids_idxs = [] + + for assistant_idx in np.where(batch["labels"][i] == self.response_token_ids[0])[0]: + # find the indexes of the start of a response. + if ( + self.response_token_ids + == batch["labels"][i][assistant_idx : assistant_idx + len(self.response_token_ids)].tolist() + ): + response_token_ids_idxs.append(assistant_idx + len(self.response_token_ids)) + + if len(response_token_ids_idxs) == 0: + warnings.warn( + f"Could not find response key `{self.response_template}` in the " + f'following instance: {self.tokenizer.decode(batch["input_ids"][i])} ' + f"This instance will be ignored in loss calculation. " + f"Note, if this happens often, consider increasing the `max_seq_length`." + ) + batch["labels"][i, :] = self.ignore_index + + human_token_ids = self.instruction_token_ids + for human_idx in np.where(batch["labels"][i] == human_token_ids[0])[0]: + # find the indexes of the start of a human answer. + if human_token_ids == batch["labels"][i][human_idx : human_idx + len(human_token_ids)].tolist(): + human_token_ids_idxs.append(human_idx) + + if len(human_token_ids_idxs) == 0: + warnings.warn( + f"Could not find instruction key `{self.instruction_template}` in the " + f'following instance: {self.tokenizer.decode(batch["input_ids"][i])} ' + f"This instance will be ignored in loss calculation. " + f"Note, if this happens often, consider increasing the `max_seq_length`." + ) + batch["labels"][i, :] = self.ignore_index + + for idx, (start, end) in enumerate(zip(human_token_ids_idxs, response_token_ids_idxs)): + # Make pytorch loss function ignore all non response tokens + if idx != 0: + batch["labels"][i, start:end] = self.ignore_index + else: + batch["labels"][i, :end] = self.ignore_index + + if len(response_token_ids_idxs) < len(human_token_ids_idxs): + batch["labels"][i, human_token_ids_idxs[-1] :] = self.ignore_index + + return batch + + +@dataclass +class RewardDataCollatorWithPadding: + r""" + Reward DataCollator class that pads the inputs to the maximum length of the batch. + Args: + tokenizer (`PreTrainedTokenizerBase`): + The tokenizer used for encoding the data. + padding (`Union[bool, str, `PaddingStrategy`]`, `optional`, defaults to `True`): + padding_strategy to pass to the tokenizer. + max_length (`Optional[int]`, `optional`, defaults to `None`): + The maximum length of the sequence to be processed. + pad_to_multiple_of (`Optional[int]`, `optional`, defaults to `None`): + If set will pad the sequence to a multiple of the provided value. + return_tensors (`str`, `optional`, defaults to `"pt"`): + The tensor type to use. + """ + tokenizer: PreTrainedTokenizerBase + padding: Union[bool, str] = True + max_length: Optional[int] = None + pad_to_multiple_of: Optional[int] = None + return_tensors: str = "pt" + + def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]: + features_chosen = [] + features_rejected = [] + margin = [] + # check if we have a margin. If we do, we need to batch it as well + has_margin = "margin" in features[0] + for feature in features: + # check if the keys are named as expected + if ( + "input_ids_chosen" not in feature + or "input_ids_rejected" not in feature + or "attention_mask_chosen" not in feature + or "attention_mask_rejected" not in feature + ): + raise ValueError( + "The features should include `input_ids_chosen`, `attention_mask_chosen`, `input_ids_rejected` and `attention_mask_rejected`" + ) + + features_chosen.append( + { + "input_ids": feature["input_ids_chosen"], + "attention_mask": feature["attention_mask_chosen"], + } + ) + features_rejected.append( + { + "input_ids": feature["input_ids_rejected"], + "attention_mask": feature["attention_mask_rejected"], + } + ) + if has_margin: + margin.append(feature["margin"]) + batch_chosen = self.tokenizer.pad( + features_chosen, + padding=self.padding, + max_length=self.max_length, + pad_to_multiple_of=self.pad_to_multiple_of, + return_tensors=self.return_tensors, + ) + batch_rejected = self.tokenizer.pad( + features_rejected, + padding=self.padding, + max_length=self.max_length, + pad_to_multiple_of=self.pad_to_multiple_of, + return_tensors=self.return_tensors, + ) + batch = { + "input_ids_chosen": batch_chosen["input_ids"], + "attention_mask_chosen": batch_chosen["attention_mask"], + "input_ids_rejected": batch_rejected["input_ids"], + "attention_mask_rejected": batch_rejected["attention_mask"], + "return_loss": True, + } + if has_margin: + margin = torch.tensor(margin, dtype=torch.float) + batch["margin"] = margin + return batch + + +@dataclass +class DPODataCollatorWithPadding: + r""" + DPO DataCollator class that pads the inputs to the maximum length of the batch. + Args: + tokenizer (`PreTrainedTokenizerBase`): + The tokenizer used for encoding the data. + model (Optional[`PreTrainedModel`]): + The model that is being trained. If set and has the *prepare_decoder_input_ids_from_labels*, use it to + prepare the *decoder_input_ids*. + padding (`Union[bool, str, `PaddingStrategy`]`, `optional`, defaults to `True`): + padding_strategy to pass to the tokenizer. + max_length (`Optional[int]`, `optional`, defaults to `None`): + The maximum length of the sequence to be processed. + max_prompt_length (`Optional[int]`, `optional`, defaults to `None`): + The maximum length of the prompt to be processed. + label_pad_token_id (`int`, defaults to -100): + The label used for masking. + padding_value (`int`, defaults to 0): + The value used for padding. + is_encoder_decoder (`Optional[bool]`, `optional`, defaults to `None`): + Whether or not you model has an encoder_decoder architecture. + max_target_length (`Optional[int]`, `optional`, defaults to `None`): + The maximum length of the target to be processed. Only useful for encoder-decoder architectures. + truncation_mode: (`str`, defaults to "keep_end"): + The truncation mode to use when truncating the prompt. + """ + tokenizer: PreTrainedTokenizerBase + model: Optional[PreTrainedModel] = None + padding: Union[bool, str] = True + max_length: Optional[int] = None + max_prompt_length: Optional[int] = None + label_pad_token_id: int = -100 + padding_value: int = 0 + truncation_mode: str = "keep_end" + is_encoder_decoder: Optional[bool] = False + max_target_length: Optional[int] = None + + def tokenize_batch_element( + self, + prompt: str, + chosen: str, + rejected: str, + ) -> Dict: + """Tokenize a single batch element. + + At this stage, we don't convert to PyTorch tensors yet; we just handle the truncation + in case the prompt + chosen or prompt + rejected responses is/are too long. First + we truncate the prompt; if we're still too long, we truncate the chosen/rejected. + + We also create the labels for the chosen/rejected responses, which are of length equal to + the sum of the length of the prompt and the chosen/rejected response, with + label_pad_token_id for the prompt tokens. + """ + batch = {} + + if not self.is_encoder_decoder: + chosen_tokens = self.tokenizer(chosen, add_special_tokens=False) + rejected_tokens = self.tokenizer(rejected, add_special_tokens=False) + prompt_tokens = self.tokenizer(prompt, add_special_tokens=False) + + eos_token_id = self.tokenizer.eos_token_id + # Get indices in list prompt_tokens["input_ids"] that equals the EOS token (often 0) + eos_indices_prompt = [i for i, x in enumerate(prompt_tokens["input_ids"]) if x == eos_token_id] + # attention mask these indices to eos_token_id + new_attention_mask = [ + 0 if i in eos_indices_prompt else p for i, p in enumerate(prompt_tokens["attention_mask"]) + ] + prompt_tokens["attention_mask"] = new_attention_mask + + # do the same for chosen and rejected + eos_indices_chosen = [i for i, x in enumerate(chosen_tokens["input_ids"]) if x == eos_token_id] + new_attention_mask_c = [ + 0 if i in eos_indices_chosen else p for i, p in enumerate(chosen_tokens["attention_mask"]) + ] + chosen_tokens["attention_mask"] = new_attention_mask_c + + eos_indices_rejected = [i for i, x in enumerate(rejected_tokens["input_ids"]) if x == eos_token_id] + new_attention_mask_r = [ + 0 if i in eos_indices_rejected else p for i, p in enumerate(rejected_tokens["attention_mask"]) + ] + rejected_tokens["attention_mask"] = new_attention_mask_r + + # add EOS token to end of prompt + chosen_tokens["input_ids"].append(self.tokenizer.eos_token_id) + chosen_tokens["attention_mask"].append(1) + + rejected_tokens["input_ids"].append(self.tokenizer.eos_token_id) + rejected_tokens["attention_mask"].append(1) + + longer_response_length = max(len(chosen_tokens["input_ids"]), len(rejected_tokens["input_ids"])) + + # if combined sequence is too long, truncate the prompt + if len(prompt_tokens["input_ids"]) + longer_response_length > self.max_length: + if self.truncation_mode == "keep_start": + prompt_tokens = {k: v[: self.max_prompt_length] for k, v in prompt_tokens.items()} + elif self.truncation_mode == "keep_end": + prompt_tokens = {k: v[-self.max_prompt_length :] for k, v in prompt_tokens.items()} + else: + raise ValueError(f"Unknown truncation mode: {self.truncation_mode}") + + # if that's still too long, truncate the response + if len(prompt_tokens["input_ids"]) + longer_response_length > self.max_length: + chosen_tokens = {k: v[: self.max_length - self.max_prompt_length] for k, v in chosen_tokens.items()} + rejected_tokens = { + k: v[: self.max_length - self.max_prompt_length] for k, v in rejected_tokens.items() + } + + # Create labels + chosen_sequence_tokens = {k: prompt_tokens[k] + chosen_tokens[k] for k in chosen_tokens} + rejected_sequence_tokens = {k: prompt_tokens[k] + rejected_tokens[k] for k in rejected_tokens} + chosen_sequence_tokens["labels"] = chosen_sequence_tokens["input_ids"][:] + chosen_sequence_tokens["labels"][: len(prompt_tokens["input_ids"])] = [self.label_pad_token_id] * len( + prompt_tokens["input_ids"] + ) + rejected_sequence_tokens["labels"] = rejected_sequence_tokens["input_ids"][:] + rejected_sequence_tokens["labels"][: len(prompt_tokens["input_ids"])] = [self.label_pad_token_id] * len( + prompt_tokens["input_ids"] + ) + + for k, toks in { + "chosen": chosen_sequence_tokens, + "rejected": rejected_sequence_tokens, + "prompt": prompt_tokens, + }.items(): + for type_key, tokens in toks.items(): + if type_key == "token_type_ids": + continue + batch[f"{k}_{type_key}"] = tokens + + else: + chosen_tokens = self.tokenizer( + chosen, truncation=True, max_length=self.max_target_length, add_special_tokens=True + ) + rejected_tokens = self.tokenizer( + rejected, truncation=True, max_length=self.max_target_length, add_special_tokens=True + ) + prompt_tokens = self.tokenizer( + prompt, truncation=True, max_length=self.max_prompt_length, add_special_tokens=True + ) + + batch["chosen_labels"] = chosen_tokens["input_ids"] + batch["rejected_labels"] = rejected_tokens["input_ids"] + batch["prompt_input_ids"] = prompt_tokens["input_ids"] + batch["prompt_attention_mask"] = prompt_tokens["attention_mask"] + + if self.model is not None and hasattr(self.model, "prepare_decoder_input_ids_from_labels"): + batch["rejected_decoder_input_ids"] = self.model.prepare_decoder_input_ids_from_labels( + labels=batch["rejected_labels"] + ) + batch["chosen_decoder_input_ids"] = self.model.prepare_decoder_input_ids_from_labels( + labels=batch["chosen_labels"] + ) + + batch["prompt"] = prompt + batch["chosen"] = prompt + chosen + batch["rejected"] = prompt + rejected + batch["chosen_response_only"] = chosen + batch["rejected_response_only"] = rejected + + return batch + + def collate(self, batch): + # first, pad everything to the same length + padded_batch = {} + for k in batch[0].keys(): + if k.endswith("_input_ids") or k.endswith("_attention_mask") or k.endswith("_labels"): + if self.is_encoder_decoder: + to_pad = [torch.LongTensor(ex[k]) for ex in batch] + + if (k.startswith("prompt")) and (k.endswith("input_ids")): + padding_value = self.tokenizer.pad_token_id + elif k.endswith("_attention_mask"): + padding_value = 0 + elif (k.startswith("chosen")) or (k.startswith("rejected")) or ("decoder" in k): + padding_value = self.label_pad_token_id + else: + raise ValueError(f"Unexpected key in batch '{k}'") + padded_batch[k] = pad_sequence(to_pad, batch_first=True, padding_value=padding_value) + else: + # adapted from https://stackoverflow.com/questions/73256206 + if "prompt" in k: + to_pad = [torch.LongTensor(ex[k][::-1]) for ex in batch] + else: + to_pad = [torch.LongTensor(ex[k]) for ex in batch] + if k.endswith("_input_ids"): + padding_value = self.tokenizer.pad_token_id + elif k.endswith("_labels"): + padding_value = self.label_pad_token_id + elif k.endswith("_attention_mask"): + padding_value = self.padding_value + else: + raise ValueError(f"Unexpected key in batch '{k}'") + + padded_batch[k] = pad_sequence(to_pad, batch_first=True, padding_value=padding_value) + # for the prompt, flip back so padding is on left side + if "prompt" in k: + padded_batch[k] = padded_batch[k].flip(dims=[1]) + else: + padded_batch[k] = [ex[k] for ex in batch] + + return padded_batch + + def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]: + tokenized_batch = [] + + for feature in features: + prompt = feature["prompt"] + chosen = feature["chosen"] + rejected = feature["rejected"] + + batch_element = self.tokenize_batch_element(prompt, chosen, rejected) + tokenized_batch.append(batch_element) + + # return collated batch + return self.collate(tokenized_batch) + + +class ConstantLengthDataset(IterableDataset): + """ + Iterable dataset that returns constant length chunks of tokens from stream of text files. + The dataset also formats the text before tokenization with a specific format that is provided + by the user. + + Args: + tokenizer (`transformers.PreTrainedTokenizer`): + The processor used for processing the data. + dataset (`dataset.Dataset`): + Dataset with text files. + dataset_text_field (`str`, **optional**): + Name of the field in the dataset that contains the text. Used only if `formatting_func` is `None`. + formatting_func (`Callable`, **optional**): + Function that formats the text before tokenization. Usually it is recommended to have follows a certain + pattern such as `"### Question: {question}\n ### Answer: {answer}\n"` + infinite (`bool`, *optional*, defaults to `False`): + If True the iterator is reset after dataset reaches end else stops. + seq_length (`int`, *optional*, defaults to `1024`): + Length of token sequences to return. + num_of_sequences (`int`, *optional*, defaults to `1024`): + Number of token sequences to keep in buffer. + chars_per_token (`int`, *optional*, defaults to `3.6`): + Number of characters per token used to estimate number of tokens in text buffer. + eos_token_id (`int`, *optional*, defaults to `0`): + Id of the end of sequence token if the passed tokenizer does not have an EOS token. + shuffle ('bool', *optional*, defaults to True) + Shuffle the examples before they are returned + """ + + def __init__( + self, + tokenizer, + dataset, + dataset_text_field=None, + formatting_func=None, + infinite=False, + seq_length=1024, + num_of_sequences=1024, + chars_per_token=3.6, + eos_token_id=0, + shuffle=True, + ): + self.tokenizer = tokenizer + + if tokenizer.eos_token_id is None: + warnings.warn( + "The passed tokenizer does not have an EOS token. We will use the passed eos_token_id instead which corresponds" + f" to {eos_token_id}. If this is not the correct EOS token, make sure to pass the correct eos_token_id." + ) + + self.concat_token_id = tokenizer.eos_token_id if tokenizer.eos_token_id else eos_token_id + self.dataset = dataset + self.seq_length = seq_length + self.infinite = infinite + self.current_size = 0 + self.max_buffer_size = seq_length * chars_per_token * num_of_sequences + self.shuffle = shuffle + if formatting_func is None: + self.formatting_func = lambda x: x[dataset_text_field] + else: + self.formatting_func = formatting_func + + if formatting_func is not None: + if formatting_func.__code__.co_argcount > 1: + warnings.warn( + "The passed formatting_func has more than one argument. Usually that function should have a single argument `example`" + " which corresponds to the dictionary returned by each element of the dataset. Make sure you know what you are doing." + ) + + def __len__(self): + return len(self.dataset) + + def __iter__(self): + iterator = iter(self.dataset) + more_examples = True + while more_examples: + buffer, buffer_len = [], 0 + while True: + if buffer_len >= self.max_buffer_size: + break + try: + buffer.append(self.formatting_func(next(iterator))) + buffer_len += len(buffer[-1]) + except StopIteration: + if self.infinite: + iterator = iter(self.dataset) + warnings.warn("The dataset reached end and the iterator is reset to the start.") + else: + more_examples = False + break + tokenized_inputs = self.tokenizer(buffer, truncation=False)["input_ids"] + all_token_ids = [] + for tokenized_input in tokenized_inputs: + all_token_ids.extend(tokenized_input + [self.concat_token_id]) + examples = [] + for i in range(0, len(all_token_ids), self.seq_length): + input_ids = all_token_ids[i : i + self.seq_length] + if len(input_ids) == self.seq_length: + examples.append(input_ids) + if self.shuffle: + random.shuffle(examples) + for example in examples: + self.current_size += 1 + yield { + "input_ids": torch.LongTensor(example), + "labels": torch.LongTensor(example), + } + + +class PeftSavingCallback(TrainerCallback): + def on_save(self, args, state, control, **kwargs): + if args.should_save: + checkpoint_path = os.path.join(args.output_dir, f"checkpoint-{state.global_step}") + kwargs["model"].save_pretrained(checkpoint_path) + + if "pytorch_model.bin" in os.listdir(checkpoint_path): + os.remove(os.path.join(checkpoint_path, "pytorch_model.bin")) + + +class RunningMoments: + def __init__(self, accelerator): + """ + Calculates the running mean and standard deviation of a data stream. Reference: + https://github.com/OpenLMLab/MOSS-RLHF/blob/40b91eb2f2b71b16919addede0341d2bef70825d/utils.py#L75 + """ + self.mean = 0 + self.std = 1 + self.var = 1 + self.count = 1e-24 + self.accelerator = accelerator + + @torch.no_grad() + def update(self, xs: torch.Tensor) -> Tuple[float, float]: + """ + Updates running moments from batch's moments computed across ranks + """ + if self.accelerator.use_distributed: + xs_mean, xs_var, xs_count = get_global_statistics(self.accelerator, xs) + else: + xs_count = xs.numel() + xs_var, xs_mean = torch.var_mean(xs, unbiased=False) + xs_mean, xs_var = xs_mean.float(), xs_var.float() + + delta = xs_mean - self.mean + tot_count = self.count + xs_count + + new_sum = xs_var * xs_count + # correct old_sum deviation accounting for the new mean + old_sum = self.var * self.count + delta**2 * self.count * xs_count / tot_count + tot_sum = old_sum + new_sum + + self.mean += delta * xs_count / tot_count + self.var = tot_sum / tot_count + self.std = (self.var * tot_count / (tot_count - 1)).float().sqrt() + self.count = tot_count + + return xs_mean.item(), (xs_var * xs_count / (xs_count - 1)).float().sqrt().item() + + +@torch.no_grad() +def get_global_statistics(accelerator, xs: torch.Tensor, mask=None, device="cpu") -> Tuple[float, float, int]: + """ + Computes element-wise mean and variance of the tensor across processes. Reference: + https://github.com/OpenLMLab/MOSS-RLHF/blob/40b91eb2f2b71b16919addede0341d2bef70825d/utils.py#L57C1-L73C75 + """ + xs = xs.to(accelerator.device) + sum_and_count = torch.tensor([xs.sum(), (xs.numel() if mask is None else mask.sum())], device=xs.device) + sum_and_count = accelerator.reduce(sum_and_count) + global_sum, count = sum_and_count + global_mean = global_sum / count + + sum_var = torch.sum(((xs - global_mean) ** 2).mul(1 if mask is None else mask)) + sum_var = accelerator.reduce(sum_var) + global_var = sum_var / count + + return global_mean.to(device), global_var.to(device), count.to(device) + + +def compute_accuracy(eval_pred) -> Dict[str, float]: + predictions, labels = eval_pred + # Here, predictions is rewards_chosen and rewards_rejected. + # We want to see how much of the time rewards_chosen > rewards_rejected. + if np.array(predictions[:, 0] == predictions[:, 1], dtype=float).sum() > 0: + warnings.warn( + f"There are {np.array(predictions[:, 0] == predictions[:, 1]).sum()} out of {len(predictions[:, 0])} instances where the predictions for both options are equal. As a consequence the accuracy can be misleading." + ) + predictions = np.argmax(predictions, axis=1) + + accuracy = np.array(predictions == labels, dtype=float).mean().item() + return {"accuracy": accuracy} + + +def pad_to_length(tensor: torch.Tensor, length: int, pad_value: Union[int, float], dim: int = -1) -> torch.Tensor: + if tensor.size(dim) >= length: + return tensor + else: + pad_size = list(tensor.shape) + pad_size[dim] = length - tensor.size(dim) + return torch.cat( + [ + tensor, + pad_value * torch.ones(*pad_size, dtype=tensor.dtype, device=tensor.device), + ], + dim=dim, + ) + + +def disable_dropout_in_model(model: torch.nn.Module) -> None: + for module in model.modules(): + if isinstance(module, torch.nn.Dropout): + module.p = 0 + + +def exact_div(a, b, a_str, b_str, custom_error_message=""): + q = a // b + if a != q * b: + raise ValueError(f"{custom_error_message}, {a_str}={a}, {b_str}={b}, inexact division: {a} / {b} = {a / b}") + return q + + +# copied from https://github.com/kvablack/ddpo-pytorch/blob/main/ddpo_pytorch/stat_tracking.py#L5 +class PerPromptStatTracker: + r""" + Class for tracking statistics per prompt. Mainly used to calculate advantage for the DPPO algorithm + + Args: + buffer_size (`int`): + Size of the buffer to keep for each prompt. + min_count (`int`): + Minimum number of samples to keep in the buffer before calculating the mean and std. + """ + + def __init__(self, buffer_size, min_count): + self.buffer_size = buffer_size + self.min_count = min_count + self.stats = {} + + def update(self, prompts, rewards): + prompts = np.array(prompts) + rewards = np.array(rewards) + unique = np.unique(prompts) + advantages = np.empty_like(rewards) + for prompt in unique: + prompt_rewards = rewards[prompts == prompt] + if prompt not in self.stats: + self.stats[prompt] = deque(maxlen=self.buffer_size) + self.stats[prompt].extend(prompt_rewards) + + if len(self.stats[prompt]) < self.min_count: + mean = np.mean(rewards) + std = np.std(rewards) + 1e-6 + else: + mean = np.mean(self.stats[prompt]) + std = np.std(self.stats[prompt]) + 1e-6 + advantages[prompts == prompt] = (prompt_rewards - mean) / std + + return advantages + + def get_stats(self): + return {k: {"mean": np.mean(v), "std": np.std(v), "count": len(v)} for k, v in self.stats.items()} + + +def neftune_post_forward_hook(module, input, output): + """ + Implements the NEFTune forward pass for the model using forward hooks. Note this works only for + torch.nn.Embedding layers. This method is slightly adapted from the original source code + that can be found here: https://github.com/neelsjain/NEFTune + + Simply add it to your model as follows: + ```python + model = ... + model.embed_tokens.neftune_noise_alpha = 0.1 + model.embed_tokens.register_forward_hook(neftune_post_forward_hook) + ``` + + Args: + module (`torch.nn.Module`): + The embedding module where the hook is attached. Note that you need to set + `module.neftune_noise_alpha` to the desired noise alpha value. + input (`torch.Tensor`): + The input tensor to the model. + output (`torch.Tensor`): + The output tensor of the model (i.e. the embeddings). + """ + if module.training: + dims = torch.tensor(output.size(1) * output.size(2)) + mag_norm = module.neftune_noise_alpha / torch.sqrt(dims) + output = output + torch.zeros_like(output).uniform_(-mag_norm, mag_norm) + return output diff --git a/trl/trl/.github/workflows/benchmark.yml b/trl/trl/.github/workflows/benchmark.yml new file mode 100644 index 0000000000000000000000000000000000000000..fc50e11b67ea356c3f47bdab2973c9eb03b7114b --- /dev/null +++ b/trl/trl/.github/workflows/benchmark.yml @@ -0,0 +1,107 @@ +name: "Benchmark on Comment" + +# https://docs.github.com/en/actions/using-workflows/events-that-trigger-workflows +on: + issue_comment: + types: [created] + +jobs: + Benchmark: + strategy: + fail-fast: true + matrix: + python-version: [3.9] + os: [self-hosted] + + name: Benchmark + # Only run if it#s a PR and the comment contains /Benchmark + if: github.event.issue.pull_request && startsWith(github.event.comment.body, '/benchmark-trl-experiments') && contains(FromJSON('["vwxyzjn", "younesbelkada", "lvwerra", "lewtun"]'), github.actor) + runs-on: ${{ matrix.os }} + + steps: + - name: Get branch of PR + uses: xt0rted/pull-request-comment-branch@v1 + id: comment-branch + - name: Set latest commit status as pending + uses: myrotvorets/set-commit-status-action@master + with: + sha: ${{ steps.comment-branch.outputs.head_sha }} + token: ${{ secrets.GITHUB_TOKEN }} + status: pending + - name: Checkout `main` branch + uses: actions/checkout@v3 + - name: Checkout PR branch + run: gh pr checkout $PR_NUMBER + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + PR_NUMBER: ${{ github.event.issue.number }} + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v4 + with: + python-version: ${{ matrix.python-version }} + # - name: Cleanup pip packages (specific to self-hosted runners) + # run: | + # echo PATH is $PATH + # echo PYTHONPATH is $PYTHONPATH + # echo which python is $(which python) + # echo which pip is $(which pip) + + # pip_list=$(pip list --format=freeze | grep -v "^pip==" | grep -v "^setuptools==") + # if [ ! -z "$pip_list" ]; then + # echo "$pip_list" | xargs pip uninstall -y + # fi + - name: Print python depdenencies + run: pip list --format=freeze + - name: Install dependencies + run: | + pip install .[test,benchmark] + + - name: Login + run: wandb login ${{ secrets.WANDB_API_KEY }} && huggingface-cli login --token ${{ secrets.HUGGING_FACE_HUB_TOKEN }} + - name: Run benchmark + env: + GITHUB_CONTEXT: ${{ toJson(github) }} + PERSONAL_ACCESS_TOKEN_GITHUB: ${{ secrets.PERSONAL_ACCESS_TOKEN_GITHUB }} + run: | + COMMENT="${{ github.event.comment.body }}" + if [[ "$COMMENT" == *"/benchmark-trl-experiments benchmark/benchmark_level1.sh"* ]]; then + echo "Running benchmark/benchmark_level1.sh" + BENCHMARK_SCRIPT="benchmark/benchmark_level1.sh" BENCHMARK_PLOT_SCRIPT="benchmark/benchmark_level1_plot.sh" bash benchmark/benchmark_and_report.sh + elif [[ "$COMMENT" == *"/benchmark-trl-experiments benchmark/benchmark_level2.sh"* ]]; then + echo "Running benchmark/benchmark_level2.sh" + BENCHMARK_SCRIPT="benchmark/benchmark_level2.sh" BENCHMARK_PLOT_SCRIPT="benchmark/benchmark_level2_plot.sh" bash benchmark/benchmark_and_report.sh + elif [[ "$COMMENT" == *"/benchmark-trl-experiments benchmark/benchmark_level3.sh"* ]]; then + echo "Running benchmark/benchmark_level3.sh" + BENCHMARK_SCRIPT="benchmark/benchmark_level3.sh" BENCHMARK_PLOT_SCRIPT="benchmark/benchmark_level3_plot.sh" bash benchmark/benchmark_and_report.sh + else + echo "Invalid command in comment. Skipping execution." + fi + + # send message to PR + - name: Setup Node.js 16 + uses: actions/setup-node@v3 + with: + node-version: 16 + - name: Add workflow result as comment on PR + uses: actions/github-script@v6 + if: always() + with: + script: | + const name = '${{ github.workflow }}'; + const url = '${{ github.server_url }}/${{ github.repository }}/actions/runs/${{ github.run_id }}'; + const success = '${{ job.status }}' === 'success'; + const body = `${name}: ${success ? 'succeeded ✅' : 'failed ❌'}\n${url}`; + + await github.rest.issues.createComment({ + issue_number: context.issue.number, + owner: context.repo.owner, + repo: context.repo.repo, + body: body + }) + - name: Set latest commit status as ${{ job.status }} + uses: myrotvorets/set-commit-status-action@master + if: always() + with: + sha: ${{ steps.comment-branch.outputs.head_sha }} + token: ${{ secrets.GITHUB_TOKEN }} + status: ${{ job.status }} diff --git a/trl/trl/.github/workflows/build_documentation.yml b/trl/trl/.github/workflows/build_documentation.yml new file mode 100644 index 0000000000000000000000000000000000000000..5570c872b1656174ef020e3888f4e4fd993055ff --- /dev/null +++ b/trl/trl/.github/workflows/build_documentation.yml @@ -0,0 +1,18 @@ +name: Build documentation + +on: + push: + branches: + - main + - doc-builder* + - v*-release + +jobs: + build: + uses: huggingface/doc-builder/.github/workflows/build_main_documentation.yml@main + with: + commit_sha: ${{ github.sha }} + package: trl + version_tag_suffix: "" + secrets: + hf_token: ${{ secrets.HF_DOC_BUILD_PUSH }} diff --git a/trl/trl/.github/workflows/build_pr_documentation.yml b/trl/trl/.github/workflows/build_pr_documentation.yml new file mode 100644 index 0000000000000000000000000000000000000000..551fd13feeb54a2a7c6ed5416639f611221ba13e --- /dev/null +++ b/trl/trl/.github/workflows/build_pr_documentation.yml @@ -0,0 +1,17 @@ +name: Build PR Documentation + +on: + pull_request: + +concurrency: + group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }} + cancel-in-progress: true + +jobs: + build: + uses: huggingface/doc-builder/.github/workflows/build_pr_documentation.yml@main + with: + commit_sha: ${{ github.event.pull_request.head.sha }} + pr_number: ${{ github.event.number }} + package: trl + version_tag_suffix: "" \ No newline at end of file diff --git a/trl/trl/.github/workflows/clear_cache.yml b/trl/trl/.github/workflows/clear_cache.yml new file mode 100644 index 0000000000000000000000000000000000000000..20bab26687ca1aac453c4c1a76797cbcd8114d31 --- /dev/null +++ b/trl/trl/.github/workflows/clear_cache.yml @@ -0,0 +1,33 @@ +name: "Cleanup Cache" + +on: + workflow_dispatch: + schedule: + - cron: "0 0 * * *" + +jobs: + cleanup: + runs-on: ubuntu-latest + steps: + - name: Check out code + uses: actions/checkout@v3 + + - name: Cleanup + run: | + gh extension install actions/gh-actions-cache + + REPO=${{ github.repository }} + + echo "Fetching list of cache key" + cacheKeysForPR=$(gh actions-cache list -R $REPO | cut -f 1 ) + + ## Setting this to not fail the workflow while deleting cache keys. + set +e + echo "Deleting caches..." + for cacheKey in $cacheKeysForPR + do + gh actions-cache delete $cacheKey -R $REPO --confirm + done + echo "Done" + env: + GH_TOKEN: ${{ secrets.GITHUB_TOKEN }} diff --git a/trl/trl/.github/workflows/stale.yml b/trl/trl/.github/workflows/stale.yml new file mode 100644 index 0000000000000000000000000000000000000000..b3b626663f347de19ff721953305208457b559eb --- /dev/null +++ b/trl/trl/.github/workflows/stale.yml @@ -0,0 +1,27 @@ +name: Stale Bot + +on: + schedule: + - cron: "0 15 * * *" + +jobs: + close_stale_issues: + name: Close Stale Issues + if: github.repository == 'huggingface/trl' + runs-on: ubuntu-latest + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + steps: + - uses: actions/checkout@v3 + + - name: Setup Python + uses: actions/setup-python@v4 + with: + python-version: 3.8 + + - name: Install requirements + run: | + pip install PyGithub + - name: Close stale issues + run: | + python scripts/stale.py \ No newline at end of file diff --git a/trl/trl/.github/workflows/tests.yml b/trl/trl/.github/workflows/tests.yml new file mode 100644 index 0000000000000000000000000000000000000000..6d2162f02252e8603e76dd82bf7134085a7183ca --- /dev/null +++ b/trl/trl/.github/workflows/tests.yml @@ -0,0 +1,75 @@ +name: tests + +on: + push: + branches: [ main ] + pull_request: + branches: [ main ] + +jobs: + check_code_quality: + runs-on: ubuntu-latest + strategy: + matrix: + python-version: [3.9] + + steps: + - uses: actions/checkout@v2 + with: + fetch-depth: 0 + submodules: recursive + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v2 + with: + python-version: ${{ matrix.python-version }} + - uses: pre-commit/action@v2.0.3 + with: + extra_args: --all-files + + tests: + needs: check_code_quality + strategy: + matrix: + python-version: ['3.8', '3.9', '3.10'] + os: ['ubuntu-latest', 'windows-latest'] + runs-on: ${{ matrix.os }} + steps: + - uses: actions/checkout@v3 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v4 + with: + python-version: ${{ matrix.python-version }} + cache: "pip" + cache-dependency-path: | + setup.py + requirements.txt + - name: Install dependencies + run: | + python -m pip install --upgrade pip + # cpu version of pytorch + pip install -e ".[test, peft, diffusers]" + - name: Test with pytest + run: | + make test + + tests_no_optional_dep: + needs: check_code_quality + runs-on: 'ubuntu-latest' + steps: + - uses: actions/checkout@v3 + - name: Set up Python 3.9 + uses: actions/setup-python@v4 + with: + python-version: '3.9' + cache: "pip" + cache-dependency-path: | + setup.py + requirements.txt + - name: Install dependencies + run: | + python -m pip install --upgrade pip + # cpu version of pytorch + pip install .[test] + - name: Test with pytest + run: | + make test \ No newline at end of file diff --git a/trl/trl/.github/workflows/upload_pr_documentation.yml b/trl/trl/.github/workflows/upload_pr_documentation.yml new file mode 100644 index 0000000000000000000000000000000000000000..2ad2ba0e8de52699f60c2da7792dab742dd6f200 --- /dev/null +++ b/trl/trl/.github/workflows/upload_pr_documentation.yml @@ -0,0 +1,16 @@ +name: Upload PR Documentation + +on: + workflow_run: + workflows: ["Build PR Documentation"] + types: + - completed + +jobs: + build: + uses: huggingface/doc-builder/.github/workflows/upload_pr_documentation.yml@main + with: + package_name: trl + secrets: + hf_token: ${{ secrets.HF_DOC_BUILD_PUSH }} + comment_bot_token: ${{ secrets.COMMENT_BOT_TOKEN }} \ No newline at end of file diff --git a/trl/trl/.gitignore b/trl/trl/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..58cb39ec20133997c0f73ebe4ea2754180c57f5e --- /dev/null +++ b/trl/trl/.gitignore @@ -0,0 +1,146 @@ +benchmark/trl +*.bak +.gitattributes +.last_checked +.gitconfig +*.bak +*.log +*~ +~* +_tmp* +tmp* +tags + +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +env/ +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +*.egg-info/ +.installed.cfg +*.egg + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +.hypothesis/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# pyenv +.python-version + +# celery beat schedule file +celerybeat-schedule + +# SageMath parsed files +*.sage.py + +# dotenv +.env + +# virtualenv +.venv +venv/ +ENV/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ + +.vscode +*.swp + +# osx generated files +.DS_Store +.DS_Store? +.Trashes +ehthumbs.db +Thumbs.db +.idea + +# pytest +.pytest_cache + +# tools/trust-doc-nbs +docs_src/.last_checked + +# symlinks to fastai +docs_src/fastai +tools/fastai + +# link checker +checklink/cookies.txt + +# .gitconfig is now autogenerated +.gitconfig + +# wandb files +nbs/wandb/ +examples/notebooks/wandb/ +wandb/ \ No newline at end of file diff --git a/trl/trl/.pre-commit-config.yaml b/trl/trl/.pre-commit-config.yaml new file mode 100644 index 0000000000000000000000000000000000000000..545815fe673cca31be5b751c8f85062c988367b6 --- /dev/null +++ b/trl/trl/.pre-commit-config.yaml @@ -0,0 +1,42 @@ +repos: + - repo: https://github.com/PyCQA/isort + rev: 5.12.0 + hooks: + - id: isort + args: + - --profile=black + - --skip-glob=wandb/**/* + - --thirdparty=wandb + - repo: https://github.com/myint/autoflake + rev: v1.4 + hooks: + - id: autoflake + args: + - -r + - --exclude=wandb,__init__.py + - --in-place + - --remove-unused-variables + - --remove-all-unused-imports + - repo: https://github.com/python/black + rev: 22.3.0 + hooks: + - id: black + args: + - --line-length=119 + - --target-version=py38 + - --exclude=wandb + - repo: https://github.com/pycqa/flake8 + rev: 6.0.0 + hooks: + - id: flake8 + args: + - --ignore=E203,E501,W503,E128 + - --max-line-length=119 + + # - repo: https://github.com/codespell-project/codespell + # rev: v2.1.0 + # hooks: + # - id: codespell + # args: + # - --ignore-words-list=nd,reacher,thist,ths,magent,ba + # - --skip=docs/css/termynal.css,docs/js/termynal.js diff --git a/trl/trl/CITATION.cff b/trl/trl/CITATION.cff new file mode 100644 index 0000000000000000000000000000000000000000..c62318ecf2ef4965d9c3129c673f6bb0db35d64e --- /dev/null +++ b/trl/trl/CITATION.cff @@ -0,0 +1,28 @@ +cff-version: 1.2.0 +title: 'TRL: Transformer Reinforcement Learning' +message: >- + If you use this software, please cite it using the + metadata from this file. +type: software +authors: + - given-names: Leandro + family-names: von Werra + - given-names: Younes + family-names: Belkada + - given-names: Lewis + family-names: Tunstall + - given-names: Edward + family-names: Beeching + - given-names: Tristan + family-names: Thrush + - given-names: Nathan + family-names: Lambert +repository-code: 'https://github.com/huggingface/trl' +abstract: "With trl you can train transformer language models with Proximal Policy Optimization (PPO). The library is built on top of the transformers library by \U0001F917 Hugging Face. Therefore, pre-trained language models can be directly loaded via transformers. At this point, most decoder and encoder-decoder architectures are supported." +keywords: + - rlhf + - deep-learning + - pytorch + - transformers +license: Apache-2.0 +version: 0.2.1 diff --git a/trl/trl/CONTRIBUTING.md b/trl/trl/CONTRIBUTING.md new file mode 100644 index 0000000000000000000000000000000000000000..de1731bccbd5b667f2a40582e443e86d36accfc4 --- /dev/null +++ b/trl/trl/CONTRIBUTING.md @@ -0,0 +1,53 @@ +# How to contribute + +## How to get started + +Before you start contributing make sure you installed all the dev tools: + +```bash +pip install -e ".[dev]" +``` + +## Did you find a bug? + +* Ensure the bug was not already reported by searching on GitHub under Issues. +* If you're unable to find an open issue addressing the problem, open a new one. Be sure to include a title and clear description, as much relevant information as possible, and a code sample or an executable test case demonstrating the expected behavior that is not occurring. +* Be sure to add the complete error messages. + +#### Did you write a patch that fixes a bug? + +* Open a new GitHub pull request with the patch. +* Ensure that your PR includes a test that fails without your patch, and pass with it. +* Ensure the PR description clearly describes the problem and solution. Include the relevant issue number if applicable. + +## PR submission guidelines + +* Keep each PR focused. While it's more convenient, do not combine several unrelated fixes together. Create as many branches as needing to keep each PR focused. +* Do not mix style changes/fixes with "functional" changes. It's very difficult to review such PRs and it most likely get rejected. +* Do not add/remove vertical whitespace. Preserve the original style of the file you edit as much as you can. +* Do not turn an already submitted PR into your development playground. If after you submitted PR, you discovered that more work is needed - close the PR, do the required work and then submit a new PR. Otherwise each of your commits requires attention from maintainers of the project. +* If, however, you submitted a PR and received a request for changes, you should proceed with commits inside that PR, so that the maintainer can see the incremental fixes and won't need to review the whole PR again. In the exception case where you realize it'll take many many commits to complete the requests, then it's probably best to close the PR, do the work and then submit it again. Use common sense where you'd choose one way over another. + +### Before you submit a PR + +First you want to make sure that all the tests pass: + +```bash +make test +``` + +Then before submitting your PR make sure the code quality follows the standards. You can run the following command to format: + +```bash +make precommit +``` + +Make sure to install `pre-commit` before running the command: +```bash +pip install pre-commit +``` + +## Do you want to contribute to the documentation? + +* Docs are in the `docs/` folder and can be updated there. + diff --git a/trl/trl/LICENSE b/trl/trl/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..261eeb9e9f8b2b4b0d119366dda99c6fd7d35c64 --- /dev/null +++ b/trl/trl/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/trl/trl/MANIFEST.in b/trl/trl/MANIFEST.in new file mode 100644 index 0000000000000000000000000000000000000000..5c0e7ced193cb35aee60cbadc3e022a1da0fd8cc --- /dev/null +++ b/trl/trl/MANIFEST.in @@ -0,0 +1,5 @@ +include settings.ini +include LICENSE +include CONTRIBUTING.md +include README.md +recursive-exclude * __pycache__ diff --git a/trl/trl/Makefile b/trl/trl/Makefile new file mode 100644 index 0000000000000000000000000000000000000000..c578697ce3d69a9e65b4efbbdb4592434875a265 --- /dev/null +++ b/trl/trl/Makefile @@ -0,0 +1,15 @@ +.PHONY: test precommit benchmark_core benchmark_aux + +check_dirs := examples tests trl + +test: + python -m pytest -n auto --dist=loadfile -s -v ./tests/ + +precommit: + pre-commit run --all-files + +benchmark_core: + bash ./benchmark/benchmark_core.sh + +benchmark_aux: + bash ./benchmark/benchmark_aux.sh diff --git a/trl/trl/README.md b/trl/trl/README.md new file mode 100644 index 0000000000000000000000000000000000000000..25e501769d7bb924b8239895a1efccc1730d9eef --- /dev/null +++ b/trl/trl/README.md @@ -0,0 +1,184 @@ +
+ +
+ +# TRL - Transformer Reinforcement Learning +> Full stack transformer language models with reinforcement learning. + +

+ + License + + + Documentation + + + GitHub release + +

+ + +## What is it? + +
+ +
+ +`trl` is a full stack library where we provide a set of tools to train transformer language models and stable diffusion models with Reinforcement Learning, from the Supervised Fine-tuning step (SFT), Reward Modeling step (RM) to the Proximal Policy Optimization (PPO) step. The library is built on top of the [`transformers`](https://github.com/huggingface/transformers) library by 🤗 Hugging Face. Therefore, pre-trained language models can be directly loaded via `transformers`. At this point, most of decoder architectures and encoder-decoder architectures are supported. Refer to the documentation or the `examples/` folder for example code snippets and how to run these tools. + +**Highlights:** + +- [`SFTTrainer`](https://huggingface.co/docs/trl/sft_trainer): A light and friendly wrapper around `transformers` Trainer to easily fine-tune language models or adapters on a custom dataset. +- [`RewardTrainer`](https://huggingface.co/docs/trl/reward_trainer): A light wrapper around `transformers` Trainer to easily fine-tune language models for human preferences (Reward Modeling). +- [`PPOTrainer`](https://huggingface.co/docs/trl/trainer#trl.PPOTrainer): A PPO trainer for language models that just needs (query, response, reward) triplets to optimise the language model. +- [`AutoModelForCausalLMWithValueHead`](https://huggingface.co/docs/trl/models#trl.AutoModelForCausalLMWithValueHead) & [`AutoModelForSeq2SeqLMWithValueHead`](https://huggingface.co/docs/trl/models#trl.AutoModelForSeq2SeqLMWithValueHead): A transformer model with an additional scalar output for each token which can be used as a value function in reinforcement learning. +- [Examples](https://github.com/huggingface/trl/tree/main/examples): Train GPT2 to generate positive movie reviews with a BERT sentiment classifier, full RLHF using adapters only, train GPT-j to be less toxic, [Stack-Llama example](https://huggingface.co/blog/stackllama), etc. + +## How PPO works +Fine-tuning a language model via PPO consists of roughly three steps: + +1. **Rollout**: The language model generates a response or continuation based on query which could be the start of a sentence. +2. **Evaluation**: The query and response are evaluated with a function, model, human feedback or some combination of them. The important thing is that this process should yield a scalar value for each query/response pair. +3. **Optimization**: This is the most complex part. In the optimisation step the query/response pairs are used to calculate the log-probabilities of the tokens in the sequences. This is done with the model that is trained and a reference model, which is usually the pre-trained model before fine-tuning. The KL-divergence between the two outputs is used as an additional reward signal to make sure the generated responses don't deviate too far from the reference language model. The active language model is then trained with PPO. + +This process is illustrated in the sketch below: + + +
+ +

Figure: Sketch of the workflow.

+
+ +## Installation + +### Python package +Install the library with pip: +```bash +pip install trl +``` + +### From source +If you want to run the examples in the repository a few additional libraries are required. Clone the repository and install it with pip: +```bash +git clone https://github.com/huggingface/trl.git +cd trl/ +pip install . +``` + +If you wish to develop TRL, you should install in editable mode: +```bash +pip install -e . +``` + +## How to use + +### `SFTTrainer` + +This is a basic example on how to use the `SFTTrainer` from the library. The `SFTTrainer` is a light wrapper around the `transformers` Trainer to easily fine-tune language models or adapters on a custom dataset. + +```python +# imports +from datasets import load_dataset +from trl import SFTTrainer + +# get dataset +dataset = load_dataset("imdb", split="train") + +# get trainer +trainer = SFTTrainer( + "facebook/opt-350m", + train_dataset=dataset, + dataset_text_field="text", + max_seq_length=512, +) + +# train +trainer.train() +``` + +### `RewardTrainer` + +This is a basic example on how to use the `RewardTrainer` from the library. The `RewardTrainer` is a wrapper around the `transformers` Trainer to easily fine-tune reward models or adapters on a custom preference dataset. + +```python +# imports +from transformers import AutoModelForSequenceClassification, AutoTokenizer +from trl import RewardTrainer + +# load model and dataset - dataset needs to be in a specific format +model = AutoModelForSequenceClassification.from_pretrained("gpt2", num_labels=1) +tokenizer = AutoTokenizer.from_pretrained("gpt2") + +... + +# load trainer +trainer = RewardTrainer( + model=model, + tokenizer=tokenizer, + train_dataset=dataset, +) + +# train +trainer.train() +``` + +### `PPOTrainer` + +This is a basic example on how to use the `PPOTrainer` from the library. Based on a query the language model creates a response which is then evaluated. The evaluation could be a human in the loop or another model's output. + +```python +# imports +import torch +from transformers import AutoTokenizer +from trl import PPOTrainer, PPOConfig, AutoModelForCausalLMWithValueHead, create_reference_model +from trl.core import respond_to_batch + +# get models +model = AutoModelForCausalLMWithValueHead.from_pretrained('gpt2') +model_ref = create_reference_model(model) + +tokenizer = AutoTokenizer.from_pretrained('gpt2') + +# initialize trainer +ppo_config = PPOConfig( + batch_size=1, +) + +# encode a query +query_txt = "This morning I went to the " +query_tensor = tokenizer.encode(query_txt, return_tensors="pt") + +# get model response +response_tensor = respond_to_batch(model, query_tensor) + +# create a ppo trainer +ppo_trainer = PPOTrainer(ppo_config, model, model_ref, tokenizer) + +# define a reward for response +# (this could be any reward such as human feedback or output from another model) +reward = [torch.tensor(1.0)] + +# train model for one step with ppo +train_stats = ppo_trainer.step([query_tensor[0]], [response_tensor[0]], reward) +``` + +## References + +### Proximal Policy Optimisation +The PPO implementation largely follows the structure introduced in the paper **"Fine-Tuning Language Models from Human Preferences"** by D. Ziegler et al. \[[paper](https://arxiv.org/pdf/1909.08593.pdf), [code](https://github.com/openai/lm-human-preferences)]. + +### Language models +The language models utilize the `transformers` library by 🤗 Hugging Face. + +## Citation + +```bibtex +@misc{vonwerra2022trl, + author = {Leandro von Werra and Younes Belkada and Lewis Tunstall and Edward Beeching and Tristan Thrush and Nathan Lambert and Shengyi Huang}, + title = {TRL: Transformer Reinforcement Learning}, + year = {2020}, + publisher = {GitHub}, + journal = {GitHub repository}, + howpublished = {\url{https://github.com/huggingface/trl}} +} +``` diff --git a/trl/trl/benchmark/benchmark.py b/trl/trl/benchmark/benchmark.py new file mode 100644 index 0000000000000000000000000000000000000000..895000f24e34aec880a69a6bfbaa04e09a94b34d --- /dev/null +++ b/trl/trl/benchmark/benchmark.py @@ -0,0 +1,150 @@ +import argparse +import math +import os +import shlex +import subprocess +import uuid +from distutils.util import strtobool + +import requests + + +def parse_args(): + # fmt: off + parser = argparse.ArgumentParser() + parser.add_argument("--command", type=str, default="", + help="the command to run") + parser.add_argument("--num-seeds", type=int, default=3, + help="the number of random seeds") + parser.add_argument("--start-seed", type=int, default=1, + help="the number of the starting seed") + parser.add_argument("--workers", type=int, default=0, + help="the number of workers to run benchmark experimenets") + parser.add_argument("--auto-tag", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True, + help="if toggled, the runs will be tagged with git tags, commit, and pull request number if possible") + parser.add_argument("--slurm-template-path", type=str, default=None, + help="the path to the slurm template file (see docs for more details)") + parser.add_argument("--slurm-gpus-per-task", type=int, default=1, + help="the number of gpus per task to use for slurm jobs") + parser.add_argument("--slurm-total-cpus", type=int, default=50, + help="the number of gpus per task to use for slurm jobs") + parser.add_argument("--slurm-ntasks", type=int, default=1, + help="the number of tasks to use for slurm jobs") + parser.add_argument("--slurm-nodes", type=int, default=None, + help="the number of nodes to use for slurm jobs") + args = parser.parse_args() + # fmt: on + return args + + +def run_experiment(command: str): + command_list = shlex.split(command) + print(f"running {command}") + + # Use subprocess.PIPE to capture the output + fd = subprocess.Popen(command_list, stdout=subprocess.PIPE, stderr=subprocess.PIPE) + output, errors = fd.communicate() + + return_code = fd.returncode + assert return_code == 0, f"Command failed with error: {errors.decode('utf-8')}" + + # Convert bytes to string and strip leading/trailing whitespaces + return output.decode("utf-8").strip() + + +def autotag() -> str: + wandb_tag = "" + print("autotag feature is enabled") + git_tag = "" + try: + git_tag = subprocess.check_output(["git", "describe", "--tags"]).decode("ascii").strip() + print(f"identified git tag: {git_tag}") + except subprocess.CalledProcessError as e: + print(e) + if len(git_tag) == 0: + try: + count = int(subprocess.check_output(["git", "rev-list", "--count", "HEAD"]).decode("ascii").strip()) + hash = subprocess.check_output(["git", "rev-parse", "--short", "HEAD"]).decode("ascii").strip() + git_tag = f"no-tag-{count}-g{hash}" + print(f"identified git tag: {git_tag}") + except subprocess.CalledProcessError as e: + print(e) + wandb_tag = f"{git_tag}" + + git_commit = subprocess.check_output(["git", "rev-parse", "--verify", "HEAD"]).decode("ascii").strip() + try: + # try finding the pull request number on github + prs = requests.get(f"https://api.github.com/search/issues?q=repo:huggingface/trl+is:pr+{git_commit}") + if prs.status_code == 200: + prs = prs.json() + if len(prs["items"]) > 0: + pr = prs["items"][0] + pr_number = pr["number"] + wandb_tag += f",pr-{pr_number}" + print(f"identified github pull request: {pr_number}") + except Exception as e: + print(e) + + return wandb_tag + + +if __name__ == "__main__": + args = parse_args() + if args.auto_tag: + existing_wandb_tag = os.environ.get("WANDB_TAGS", "") + wandb_tag = autotag() + if len(wandb_tag) > 0: + if len(existing_wandb_tag) > 0: + os.environ["WANDB_TAGS"] = ",".join([existing_wandb_tag, wandb_tag]) + else: + os.environ["WANDB_TAGS"] = wandb_tag + print("WANDB_TAGS: ", os.environ.get("WANDB_TAGS", "")) + commands = [] + for seed in range(0, args.num_seeds): + commands += [" ".join([args.command, "--seed", str(args.start_seed + seed)])] + + print("======= commands to run:") + for command in commands: + print(command) + + if args.workers > 0 and args.slurm_template_path is None: + from concurrent.futures import ThreadPoolExecutor + + executor = ThreadPoolExecutor(max_workers=args.workers, thread_name_prefix="cleanrl-benchmark-worker-") + for command in commands: + executor.submit(run_experiment, command) + executor.shutdown(wait=True) + else: + print("not running the experiments because --workers is set to 0; just printing the commands to run") + + # SLURM logic + if args.slurm_template_path is not None: + if not os.path.exists("slurm"): + os.makedirs("slurm") + if not os.path.exists("slurm/logs"): + os.makedirs("slurm/logs") + print("======= slurm commands to run:") + with open(args.slurm_template_path) as f: + slurm_template = f.read() + slurm_template = slurm_template.replace("{{array}}", f"0-{len(commands) - 1}%{args.workers}") + slurm_template = slurm_template.replace( + "{{seeds}}", f"({' '.join([str(args.start_seed + int(seed)) for seed in range(args.num_seeds)])})" + ) + slurm_template = slurm_template.replace("{{len_seeds}}", f"{args.num_seeds}") + slurm_template = slurm_template.replace("{{command}}", args.command) + slurm_template = slurm_template.replace("{{gpus_per_task}}", f"{args.slurm_gpus_per_task}") + total_gpus = args.slurm_gpus_per_task * args.slurm_ntasks + slurm_cpus_per_gpu = math.ceil(args.slurm_total_cpus / total_gpus) + slurm_template = slurm_template.replace("{{cpus_per_gpu}}", f"{slurm_cpus_per_gpu}") + slurm_template = slurm_template.replace("{{ntasks}}", f"{args.slurm_ntasks}") + if args.slurm_nodes is not None: + slurm_template = slurm_template.replace("{{nodes}}", f"#SBATCH --nodes={args.slurm_nodes}") + else: + slurm_template = slurm_template.replace("{{nodes}}", "") + filename = str(uuid.uuid4()) + open(os.path.join("slurm", f"{filename}.slurm"), "w").write(slurm_template) + slurm_path = os.path.join("slurm", f"{filename}.slurm") + print(f"saving command in {slurm_path}") + if args.workers > 0: + job_id = run_experiment(f"sbatch --parsable {slurm_path}") + print(f"Job ID: {job_id}") diff --git a/trl/trl/benchmark/benchmark_and_report.sh b/trl/trl/benchmark/benchmark_and_report.sh new file mode 100644 index 0000000000000000000000000000000000000000..af76a4e7aa070f682cde029fea0ca3fbd6f05061 --- /dev/null +++ b/trl/trl/benchmark/benchmark_and_report.sh @@ -0,0 +1,41 @@ +#### Step 1: create a work directory: +# this is necessary because another github action job will remove +# the entire directory, which slurm depends on. +# https://stackoverflow.com/questions/4632028/how-to-create-a-temporary-directory +MY_SLURM_TMP_DIR=/fsx/costa/slurm_tmpdir +mkdir -p $MY_SLURM_TMP_DIR +WORK_DIR=`mktemp -d -p "$MY_SLURM_TMP_DIR"` +cp -r "$PWD" "$WORK_DIR" +cd "$WORK_DIR/$(basename "$PWD")" +echo WORK_DIR: $WORK_DIR + +#### Step 2: actual work starts: +echo PATH is $PATH +echo PYTHONPATH is $PYTHONPATH +echo whcih python is $(which python) + +export WANDB_ENTITY=huggingface +bash $BENCHMARK_SCRIPT > output.txt + +# Extract Job IDs into an array +job_ids=($(grep "Job ID:" output.txt | awk '{print $3}')) + +# Extract WANDB_TAGS into an array +WANDB_TAGS=($(grep "WANDB_TAGS:" output.txt | awk '{print $2}')) +WANDB_TAGS=($(echo $WANDB_TAGS | tr "," "\n")) + +# Print to verify +echo "Job IDs: ${job_ids[@]}" +echo "WANDB_TAGS: ${WANDB_TAGS[@]}" + +TAGS_STRING="?tag=${WANDB_TAGS[0]}" +FOLDER_STRING="${WANDB_TAGS[0]}" +for tag in "${WANDB_TAGS[@]:1}"; do + TAGS_STRING+="&tag=$tag" + FOLDER_STRING+="_$tag" +done + +echo "TAGS_STRING: $TAGS_STRING" +echo "FOLDER_STRING: $FOLDER_STRING" + +TAGS_STRING=$TAGS_STRING FOLDER_STRING=$FOLDER_STRING BENCHMARK_PLOT_SCRIPT=$BENCHMARK_PLOT_SCRIPT sbatch --dependency=afterany:$job_ids benchmark/post_github_comment.sbatch diff --git a/trl/trl/benchmark/benchmark_level1.sh b/trl/trl/benchmark/benchmark_level1.sh new file mode 100644 index 0000000000000000000000000000000000000000..6535744ae2a38b9be10f44965c77f4a714b33f43 --- /dev/null +++ b/trl/trl/benchmark/benchmark_level1.sh @@ -0,0 +1,11 @@ +# hello world experiment +python benchmark/benchmark.py \ + --command "python examples/scripts/ppo.py --ppo_config.log_with wandb" \ + --num-seeds 3 \ + --start-seed 1 \ + --workers 10 \ + --slurm-nodes 1 \ + --slurm-gpus-per-task 1 \ + --slurm-ntasks 1 \ + --slurm-total-cpus 12 \ + --slurm-template-path benchmark/trl.slurm_template diff --git a/trl/trl/benchmark/benchmark_level1_plot.sh b/trl/trl/benchmark/benchmark_level1_plot.sh new file mode 100644 index 0000000000000000000000000000000000000000..9cfe8fbe6bea6603f66e524a7e9532c7000f6b21 --- /dev/null +++ b/trl/trl/benchmark/benchmark_level1_plot.sh @@ -0,0 +1,20 @@ +# pip install openrlbenchmark==0.2.1a5 +# see https://github.com/openrlbenchmark/openrlbenchmark#get-started for documentation +echo "we deal with $TAGS_STRING" + +python -m openrlbenchmark.rlops_multi_metrics \ + --filters '?we=huggingface&wpn=trl&xaxis=_step&ceik=trl_ppo_trainer_config.value.reward_model&cen=trl_ppo_trainer_config.value.exp_name&metrics=env/reward_mean&metrics=objective/kl' \ + "ppo$TAGS_STRING" \ + --env-ids sentiment-analysis:lvwerra/distilbert-imdb \ + --no-check-empty-runs \ + --pc.ncols 2 \ + --pc.ncols-legend 1 \ + --output-filename benchmark/trl/$FOLDER_STRING/hello_world \ + --scan-history + +python benchmark/upload_benchmark.py \ + --folder_path="benchmark/trl/$FOLDER_STRING" \ + --path_in_repo="images/benchmark/$FOLDER_STRING" \ + --repo_id="trl-internal-testing/example-images" \ + --repo_type="dataset" + diff --git a/trl/trl/benchmark/benchmark_level2.sh b/trl/trl/benchmark/benchmark_level2.sh new file mode 100644 index 0000000000000000000000000000000000000000..587713ba7b70e80cfd0e4750e4baaa97f9307381 --- /dev/null +++ b/trl/trl/benchmark/benchmark_level2.sh @@ -0,0 +1,23 @@ +# compound experiments: gpt2xl + grad_accu +python benchmark/benchmark.py \ + --command "python examples/scripts/ppo.py --ppo_config.exp_name ppo_gpt2xl_grad_accu --ppo_config.model_name gpt2-xl --ppo_config.mini_batch_size 16 --ppo_config.gradient_accumulation_steps 8 --ppo_config.log_with wandb" \ + --num-seeds 3 \ + --start-seed 1 \ + --workers 10 \ + --slurm-nodes 1 \ + --slurm-gpus-per-task 1 \ + --slurm-ntasks 1 \ + --slurm-total-cpus 12 \ + --slurm-template-path benchmark/trl.slurm_template + +# compound experiments: Cerebras-GPT-6.7B + deepspeed zero2 + grad_accu +python benchmark/benchmark.py \ + --command "accelerate launch --config_file examples/accelerate_configs/deepspeed_zero2.yaml examples/scripts/ppo.py --ppo_config.exp_name ppo_Cerebras-GPT-6.7B_grad_accu_deepspeed_stage2 --ppo_config.batch_size 32 --ppo_config.mini_batch_size 32 --ppo_config.log_with wandb --ppo_config.model_name cerebras/Cerebras-GPT-6.7B --ppo_config.reward_model sentiment-analysis:cerebras/Cerebras-GPT-6.7B" \ + --num-seeds 3 \ + --start-seed 1 \ + --workers 10 \ + --slurm-nodes 1 \ + --slurm-gpus-per-task 8 \ + --slurm-ntasks 1 \ + --slurm-total-cpus 90 \ + --slurm-template-path benchmark/trl.slurm_template diff --git a/trl/trl/benchmark/benchmark_level2_plot.sh b/trl/trl/benchmark/benchmark_level2_plot.sh new file mode 100644 index 0000000000000000000000000000000000000000..305b86d90047e177b53f6d0f7404d8dc37ed7b67 --- /dev/null +++ b/trl/trl/benchmark/benchmark_level2_plot.sh @@ -0,0 +1,31 @@ +# pip install openrlbenchmark==0.2.1a5 +# see https://github.com/openrlbenchmark/openrlbenchmark#get-started for documentation +echo "we deal with $TAGS_STRING" + +python -m openrlbenchmark.rlops_multi_metrics \ + --filters '?we=huggingface&wpn=trl&xaxis=_step&ceik=trl_ppo_trainer_config.value.reward_model&cen=trl_ppo_trainer_config.value.exp_name&metrics=env/reward_mean&metrics=objective/kl' \ + "ppo$TAGS_STRING" \ + "ppo_gpt2xl_grad_accu$TAGS_STRING" \ + --env-ids sentiment-analysis:lvwerra/distilbert-imdb \ + --no-check-empty-runs \ + --pc.ncols 2 \ + --pc.ncols-legend 1 \ + --output-filename benchmark/trl/$FOLDER_STRING/different_models \ + --scan-history + +python -m openrlbenchmark.rlops_multi_metrics \ + --filters '?we=huggingface&wpn=trl&xaxis=_step&ceik=trl_ppo_trainer_config.value.reward_model&cen=trl_ppo_trainer_config.value.exp_name&metrics=env/reward_mean&metrics=objective/kl' \ + "ppo_Cerebras-GPT-6.7B_grad_accu_deepspeed_stage2$TAGS_STRING" \ + --env-ids sentiment-analysis:cerebras/Cerebras-GPT-6.7B \ + --no-check-empty-runs \ + --pc.ncols 2 \ + --pc.ncols-legend 1 \ + --output-filename benchmark/trl/$FOLDER_STRING/deepspeed \ + --scan-history + +python benchmark/upload_benchmark.py \ + --folder_path="benchmark/trl/$FOLDER_STRING" \ + --path_in_repo="images/benchmark/$FOLDER_STRING" \ + --repo_id="trl-internal-testing/example-images" \ + --repo_type="dataset" + diff --git a/trl/trl/benchmark/benchmark_level3.sh b/trl/trl/benchmark/benchmark_level3.sh new file mode 100644 index 0000000000000000000000000000000000000000..858445fe777f34c89393538b3c02bc19f25936d2 --- /dev/null +++ b/trl/trl/benchmark/benchmark_level3.sh @@ -0,0 +1,46 @@ +## w/ and w/o gradient accumulation +python benchmark/benchmark.py \ + --command "python examples/scripts/ppo.py --ppo_config.exp_name ppo_step_grad_accu --ppo_config.mini_batch_size 1 --ppo_config.gradient_accumulation_steps 128 --ppo_config.log_with wandb" \ + --num-seeds 3 \ + --start-seed 1 \ + --workers 10 \ + --slurm-nodes 1 \ + --slurm-gpus-per-task 1 \ + --slurm-ntasks 1 \ + --slurm-total-cpus 12 \ + --slurm-template-path benchmark/trl.slurm_template + +## w/ different models (gpt2, gpt2-xl, falcon, llama2) +python benchmark/benchmark.py \ + --command "python examples/scripts/ppo.py --ppo_config.exp_name ppo_gpt2 --ppo_config.log_with wandb" \ + --num-seeds 3 \ + --start-seed 1 \ + --workers 10 \ + --slurm-nodes 1 \ + --slurm-gpus-per-task 1 \ + --slurm-ntasks 1 \ + --slurm-total-cpus 12 \ + --slurm-template-path benchmark/trl.slurm_template +python benchmark/benchmark.py \ + --command "python examples/scripts/ppo.py --ppo_config.exp_name ppo_falcon_rw_1b --ppo_config.model_name tiiuae/falcon-rw-1b --ppo_config.log_with wandb" \ + --num-seeds 3 \ + --start-seed 1 \ + --workers 10 \ + --slurm-nodes 1 \ + --slurm-gpus-per-task 1 \ + --slurm-ntasks 1 \ + --slurm-total-cpus 12 \ + --slurm-template-path benchmark/trl.slurm_template + + +## w/ and w/o PEFT +python benchmark/benchmark.py \ + --command "python examples/scripts/ppo.py --ppo_config.exp_name ppo_peft --use_peft --ppo_config.log_with wandb" \ + --num-seeds 3 \ + --start-seed 1 \ + --workers 10 \ + --slurm-nodes 1 \ + --slurm-gpus-per-task 1 \ + --slurm-ntasks 1 \ + --slurm-total-cpus 12 \ + --slurm-template-path benchmark/trl.slurm_template \ No newline at end of file diff --git a/trl/trl/benchmark/plot.sh b/trl/trl/benchmark/plot.sh new file mode 100644 index 0000000000000000000000000000000000000000..9ad7c0f2d14afcc782887d529ba926c9db005816 --- /dev/null +++ b/trl/trl/benchmark/plot.sh @@ -0,0 +1,56 @@ +# pip install openrlbenchmark==0.2.1a5 +# see https://github.com/openrlbenchmark/openrlbenchmark#get-started for documentation +BASELINE_PR_TAG=v0.4.7-55-g110e672 +BASELINE_PR_NAME=PR-662 + +python -m openrlbenchmark.rlops_multi_metrics \ + --filters '?we=huggingface&wpn=trl&xaxis=_step&ceik=trl_ppo_trainer_config.value.reward_model&cen=trl_ppo_trainer_config.value.exp_name&metrics=env/reward_mean&metrics=objective/kl' \ + "sentiment_tuning?tag=$BASELINE_PR_TAG&cl=sentiment lvwerra/gpt2-imdb ($BASELINE_PR_NAME)" \ + --env-ids sentiment-analysis:lvwerra/distilbert-imdb \ + --no-check-empty-runs \ + --pc.ncols 2 \ + --pc.ncols-legend 1 \ + --output-filename benchmark/trl/$BASELINE_PR_TAG/sentiment \ + --scan-history + +python -m openrlbenchmark.rlops_multi_metrics \ + --filters '?we=huggingface&wpn=trl&xaxis=_step&ceik=trl_ppo_trainer_config.value.reward_model&cen=trl_ppo_trainer_config.value.exp_name&metrics=env/reward_mean&metrics=objective/kl' \ + "sentiment_tuning?tag=$BASELINE_PR_TAG&cl=sentiment lvwerra/gpt2-imdb ($BASELINE_PR_NAME)" \ + "sentiment_tuning_step_grad_accu?tag=$BASELINE_PR_TAG&cl=sentiment lvwerra/gpt2-imdb gradient accumulation ($BASELINE_PR_NAME)" \ + --env-ids sentiment-analysis:lvwerra/distilbert-imdb \ + --no-check-empty-runs \ + --pc.ncols 2 \ + --pc.ncols-legend 1 \ + --output-filename benchmark/trl/$BASELINE_PR_TAG/gradient_accu \ + --scan-history + +python -m openrlbenchmark.rlops_multi_metrics \ + --filters '?we=huggingface&wpn=trl&xaxis=_step&ceik=trl_ppo_trainer_config.value.reward_model&cen=trl_ppo_trainer_config.value.exp_name&metrics=env/reward_mean&metrics=objective/kl' \ + "sentiment_tuning?tag=$BASELINE_PR_TAG&cl=sentiment lvwerra/gpt2-imdb ($BASELINE_PR_NAME)" \ + "sentiment_tuning_gpt2?tag=$BASELINE_PR_TAG&cl=sentiment gpt2 ($BASELINE_PR_NAME)" \ + "sentiment_tuning_falcon_rw_1b?tag=$BASELINE_PR_TAG&cl=sentiment tiiuae/falcon-rw-1b ($BASELINE_PR_NAME)" \ + "sentiment_tuning_gpt2xl_grad_accu?tag=$BASELINE_PR_TAG&cl=sentiment gpt2xl ($BASELINE_PR_NAME)" \ + --env-ids sentiment-analysis:lvwerra/distilbert-imdb \ + --no-check-empty-runs \ + --pc.ncols 2 \ + --pc.ncols-legend 1 \ + --output-filename benchmark/trl/$BASELINE_PR_TAG/different_models \ + --scan-history + +python -m openrlbenchmark.rlops_multi_metrics \ + --filters '?we=huggingface&wpn=trl&xaxis=_step&ceik=trl_ppo_trainer_config.value.reward_model&cen=trl_ppo_trainer_config.value.exp_name&metrics=env/reward_mean&metrics=objective/kl' \ + "sentiment_tuning?tag=$BASELINE_PR_TAG&cl=sentiment lvwerra/gpt2-imdb ($BASELINE_PR_NAME)" \ + "sentiment_tuning_peft?tag=$BASELINE_PR_TAG&cl=sentiment lvwerra/gpt2-imdb w/ peft ($BASELINE_PR_NAME)" \ + --env-ids sentiment-analysis:lvwerra/distilbert-imdb \ + --no-check-empty-runs \ + --pc.ncols 2 \ + --pc.ncols-legend 1 \ + --output-filename benchmark/trl/$BASELINE_PR_TAG/peft \ + --scan-history + + +python benchmark/upload_benchmark.py \ + --folder_path="benchmark/trl/$BASELINE_PR_TAG" \ + --path_in_repo="images/benchmark/$BASELINE_PR_TAG" \ + --repo_id="trl-internal-testing/example-images" \ + --repo_type="dataset" \ No newline at end of file diff --git a/trl/trl/benchmark/post_github_comment.py b/trl/trl/benchmark/post_github_comment.py new file mode 100644 index 0000000000000000000000000000000000000000..70241ef131980687acd49111ee746c220058ac63 --- /dev/null +++ b/trl/trl/benchmark/post_github_comment.py @@ -0,0 +1,26 @@ +import json +import os + +from ghapi.all import GhApi + + +FOLDER_STRING = os.environ.get("FOLDER_STRING", "") +folder = f"benchmark/trl/{FOLDER_STRING}" +host_url = f"https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/benchmark/{FOLDER_STRING}" + +# Create a GitHub API instance +github_context = json.loads(os.environ["GITHUB_CONTEXT"]) +token = os.environ["PERSONAL_ACCESS_TOKEN_GITHUB"] # this needs to refreshed every 12 months +status_message = "**[COSTA BENCHMARK BOT]**: Here are the results" +body = status_message +repo = github_context["repository"] +owner, repo = repo.split("/") +api = GhApi(owner=owner, repo=repo, token=token) + +# for each `.png` file in the folder, add it to the comment +for file in os.listdir(folder): + if file.endswith(".png"): + body += f"\n![{file}]({host_url}/{file})" + +# Create a comment on the issue +api.issues.create_comment(issue_number=github_context["event"]["issue"]["number"], body=body) diff --git a/trl/trl/benchmark/post_github_comment.sbatch b/trl/trl/benchmark/post_github_comment.sbatch new file mode 100644 index 0000000000000000000000000000000000000000..4c464cd76d78bfd3448ab5ed37f5db976bd2550b --- /dev/null +++ b/trl/trl/benchmark/post_github_comment.sbatch @@ -0,0 +1,9 @@ +#!/bin/bash +#SBATCH --job-name=trl +#SBATCH --partition=production-cluster +#SBATCH --ntasks=1 +#SBATCH --output=slurm/logs/%x_%j.out + +sleep 2m +bash $BENCHMARK_PLOT_SCRIPT +srun python benchmark/post_github_comment.py diff --git a/trl/trl/benchmark/trl.slurm_template b/trl/trl/benchmark/trl.slurm_template new file mode 100644 index 0000000000000000000000000000000000000000..3de9eb0babee85496ec8690973af182e881c282c --- /dev/null +++ b/trl/trl/benchmark/trl.slurm_template @@ -0,0 +1,16 @@ +#!/bin/bash +#SBATCH --job-name=trl +#SBATCH --partition=production-cluster +#SBATCH --gpus-per-task={{gpus_per_task}} +#SBATCH --cpus-per-gpu={{cpus_per_gpu}} +#SBATCH --ntasks={{ntasks}} +#SBATCH --output=slurm/logs/%x_%j.out +#SBATCH --array={{array}} +#SBATCH --exclude=ip-26-0-156-239,ip-26-0-148-151,ip-26-0-146-212,ip-26-0-145-137,ip-26-0-146-249,ip-26-0-146-149,ip-26-0-147-233,ip-26-0-145-154,ip-26-0-144-35,ip-26-0-144-189,ip-26-0-146-183,ip-26-0-147-120,ip-26-0-144-95,ip-26-0-145-193 +{{nodes}} + +seeds={{seeds}} +seed=${seeds[$SLURM_ARRAY_TASK_ID % {{len_seeds}}]} + +echo "Running task $SLURM_ARRAY_TASK_ID with seed: $seed" +srun {{command}} --ppo_config.seed $seed diff --git a/trl/trl/benchmark/upload_benchmark.py b/trl/trl/benchmark/upload_benchmark.py new file mode 100644 index 0000000000000000000000000000000000000000..e98626cef8f1c455fbeacbd286d0c172f9b78b69 --- /dev/null +++ b/trl/trl/benchmark/upload_benchmark.py @@ -0,0 +1,23 @@ +from dataclasses import dataclass + +import tyro +from huggingface_hub import HfApi + + +@dataclass +class Args: + folder_path: str = "benchmark/trl" + path_in_repo: str = "images/benchmark" + repo_id: str = "trl-internal-testing/example-images" + repo_type: str = "dataset" + + +args = tyro.cli(Args) +api = HfApi() + +api.upload_folder( + folder_path=args.folder_path, + path_in_repo=args.path_in_repo, + repo_id=args.repo_id, + repo_type=args.repo_type, +) diff --git a/trl/trl/docs/source/_toctree.yml b/trl/trl/docs/source/_toctree.yml new file mode 100644 index 0000000000000000000000000000000000000000..9709b7f16c2d88db21cd610957c05500ec1efc3c --- /dev/null +++ b/trl/trl/docs/source/_toctree.yml @@ -0,0 +1,54 @@ +- sections: + - local: index + title: TRL + - local: quickstart + title: Quickstart + - local: installation + title: Installation + - local: how_to_train + title: PPO Training FAQ + - local: use_model + title: Use Trained Models + - local: customization + title: Customize the Training + - local: logging + title: Understanding Logs + title: Get started +- sections: + - local: models + title: Model Classes + - local: trainer + title: Trainer Classes + - local: reward_trainer + title: Reward Model Training + - local: sft_trainer + title: Supervised Fine-Tuning + - local: ppo_trainer + title: PPO Trainer + - local: best_of_n + title: Best of N Sampling + - local: dpo_trainer + title: DPO Trainer + - local: ddpo_trainer + title: Denoising Diffusion Policy Optimization + - local: iterative_sft_trainer + title: Iterative Supervised Fine-Tuning + - local: text_environments + title: Text Environments + title: API +- sections: + - local: example_overview + title: Example Overview + - local: sentiment_tuning + title: Sentiment Tuning + - local: lora_tuning_peft + title: Training with PEFT + - local: detoxifying_a_lm + title: Detoxifying a Language Model + - local: using_llama_models + title: Training StackLlama + - local: learning_tools + title: Learning to Use Tools + - local: multi_adapter_rl + title: Multi Adapter RLHF + title: Examples diff --git a/trl/trl/docs/source/best_of_n.mdx b/trl/trl/docs/source/best_of_n.mdx new file mode 100644 index 0000000000000000000000000000000000000000..9dd56aba2ce4818ffcf09f4e5354c825d63000e1 --- /dev/null +++ b/trl/trl/docs/source/best_of_n.mdx @@ -0,0 +1,72 @@ +# Best of N sampling: Alternative ways to get better model output without RL based fine-tuning + +Within the extras module is the `best-of-n` sampler class that serves as an alternative method of generating better model output. +As to how it fares against the RL based fine-tuning, please look in the `examples` directory for a comparison example + +## Usage + +To get started quickly, instantiate an instance of the class with a model, a length sampler, a tokenizer and a callable that serves as a proxy reward pipeline that outputs reward scores for input queries + +```python + +from transformers import pipeline, AutoTokenizer +from trl import AutoModelForCausalLMWithValueHead +from trl.core import LengthSampler +from trl.extras import BestOfNSampler + +ref_model = AutoModelForCausalLMWithValueHead.from_pretrained(ref_model_name) +reward_pipe = pipeline("sentiment-analysis", model=reward_model, device=device) +tokenizer = AutoTokenizer.from_pretrained(ref_model_name) +tokenizer.pad_token = tokenizer.eos_token + + +# callable that takes a list of raw text and returns a list of corresponding reward scores +def queries_to_scores(list_of_strings): + return [output["score"] for output in reward_pipe(list_of_strings)] + +best_of_n = BestOfNSampler(model, tokenizer, queries_to_scores, length_sampler=output_length_sampler) + + +``` + +And assuming you have a list/tensor of tokenized queries, you can generate better output by calling the `generate` method + +```python + +best_of_n.generate(query_tensors, device=device, **gen_kwargs) + +``` +The default sample size is 4, but you can change it at the time of instance initialization like so + +```python + +best_of_n = BestOfNSampler(model, tokenizer, queries_to_scores, length_sampler=output_length_sampler, sample_size=8) + +``` + +The default output is the result of taking the top scored output for each query, but you can change it to top 2 and so on by passing the `n_candidates` argument at the time of instance initialization + +```python + +best_of_n = BestOfNSampler(model, tokenizer, queries_to_scores, length_sampler=output_length_sampler, n_candidates=2) + +``` + +There is the option of setting the generation settings (like `temperature`, `pad_token_id`) at the time of instance creation as opposed to when calling the `generate` method. +This is done by passing a `GenerationConfig` from the `transformers` library at the time of initialization + +```python + +from transformers import GenerationConfig + +generation_config = GenerationConfig(min_length= -1, top_k=0.0, top_p= 1.0, do_sample= True, pad_token_id=tokenizer.eos_token_id) + +best_of_n = BestOfNSampler(model, tokenizer, queries_to_scores, length_sampler=output_length_sampler, generation_config=generation_config) + +best_of_n.generate(query_tensors, device=device) + +``` + +Furthermore, at the time of initialization you can set the seed to control repeatability of the generation process and the number of samples to generate for each query + + diff --git a/trl/trl/docs/source/customization.mdx b/trl/trl/docs/source/customization.mdx new file mode 100644 index 0000000000000000000000000000000000000000..26584cd5fdb5faae09ce47267479a9c2b2cba1e4 --- /dev/null +++ b/trl/trl/docs/source/customization.mdx @@ -0,0 +1,216 @@ +# Training customization + +TRL is designed with modularity in mind so that users to be able to efficiently customize the training loop for their needs. Below are some examples on how you can apply and test different techniques. + +## Train on multiple GPUs / nodes + +The trainers in TRL use 🤗 Accelerate to enable distributed training across multiple GPUs or nodes. To do so, first create an 🤗 Accelerate config file by running + +```bash +accelerate config +``` + +and answering the questions according to your multi-gpu / multi-node setup. You can then launch distributed training by running: + +```bash +accelerate launch your_script.py +``` + +We also provide config files in the [examples folder](https://github.com/huggingface/trl/tree/main/examples/accelerate_configs) that can be used as templates. To use these templates, simply pass the path to the config file when launching a job, e.g.: + +```shell +accelerate launch --config_file=examples/accelerate_configs/multi_gpu.yaml --num_processes {NUM_GPUS} path_to_script.py --all_arguments_of_the_script +``` + +Refer to the [examples page](https://github.com/huggingface/trl/tree/main/examples) for more details. + +### Distributed training with DeepSpeed + +All of the trainers in TRL can be run on multiple GPUs together with DeepSpeed ZeRO-{1,2,3} for efficient sharding of the optimizer states, gradients, and model weights. To do so, run: + +```shell +accelerate launch --config_file=examples/accelerate_configs/deepspeed_zero{1,2,3}.yaml --num_processes {NUM_GPUS} path_to_your_script.py --all_arguments_of_the_script +``` + +Note that for ZeRO-3, a small tweak is needed to initialize your reward model on the correct device via the `zero3_init_context_manager()` context manager. In particular, this is needed to avoid DeepSpeed hanging after a fixed number of training steps. Here is a snippet of what is involved from the [`sentiment_tuning`](https://github.com/huggingface/trl/blob/main/examples/scripts/ppo.py) example: + +```python +ds_plugin = ppo_trainer.accelerator.state.deepspeed_plugin +if ds_plugin is not None and ds_plugin.is_zero3_init_enabled(): + with ds_plugin.zero3_init_context_manager(enable=False): + sentiment_pipe = pipeline("sentiment-analysis", model="lvwerra/distilbert-imdb", device=device) +else: + sentiment_pipe = pipeline("sentiment-analysis", model="lvwerra/distilbert-imdb", device=device) +``` + +Consult the 🤗 Accelerate [documentation](https://huggingface.co/docs/accelerate/usage_guides/deepspeed) for more information about the DeepSpeed plugin. + + +## Use different optimizers + +By default, the `PPOTrainer` creates a `torch.optim.Adam` optimizer. You can create and define a different optimizer and pass it to `PPOTrainer`: +```python +import torch +from transformers import GPT2Tokenizer +from trl import PPOTrainer, PPOConfig, AutoModelForCausalLMWithValueHead + +# 1. load a pretrained model +model = AutoModelForCausalLMWithValueHead.from_pretrained('gpt2') +model_ref = AutoModelForCausalLMWithValueHead.from_pretrained('gpt2') +tokenizer = GPT2Tokenizer.from_pretrained('gpt2') + +# 2. define config +ppo_config = {'batch_size': 1, 'learning_rate':1e-5} +config = PPOConfig(**ppo_config) + + +# 2. Create optimizer +optimizer = torch.optim.SGD(model.parameters(), lr=config.learning_rate) + + +# 3. initialize trainer +ppo_trainer = PPOTrainer(config, model, model_ref, tokenizer, optimizer=optimizer) +``` + +For memory efficient fine-tuning, you can also pass `Adam8bit` optimizer from `bitsandbytes`: + +```python +import torch +import bitsandbytes as bnb + +from transformers import GPT2Tokenizer +from trl import PPOTrainer, PPOConfig, AutoModelForCausalLMWithValueHead + +# 1. load a pretrained model +model = AutoModelForCausalLMWithValueHead.from_pretrained('gpt2') +model_ref = AutoModelForCausalLMWithValueHead.from_pretrained('gpt2') +tokenizer = GPT2Tokenizer.from_pretrained('gpt2') + +# 2. define config +ppo_config = {'batch_size': 1, 'learning_rate':1e-5} +config = PPOConfig(**ppo_config) + + +# 2. Create optimizer +optimizer = bnb.optim.Adam8bit(model.parameters(), lr=config.learning_rate) + +# 3. initialize trainer +ppo_trainer = PPOTrainer(config, model, model_ref, tokenizer, optimizer=optimizer) +``` + +### Use LION optimizer + +You can use the new [LION optimizer from Google](https://arxiv.org/abs/2302.06675) as well, first take the source code of the optimizer definition [here](https://github.com/lucidrains/lion-pytorch/blob/main/lion_pytorch/lion_pytorch.py), and copy it so that you can import the optimizer. Make sure to initialize the optimizer by considering the trainable parameters only for a more memory efficient training: +```python +optimizer = Lion(filter(lambda p: p.requires_grad, self.model.parameters()), lr=self.config.learning_rate) + +... +ppo_trainer = PPOTrainer(config, model, model_ref, tokenizer, optimizer=optimizer) +``` +We advise you to use the learning rate that you would use for `Adam` divided by 3 as pointed out [here](https://github.com/lucidrains/lion-pytorch#lion---pytorch). We observed an improvement when using this optimizer compared to classic Adam (check the full logs [here](https://wandb.ai/distill-bloom/trl/runs/lj4bheke?workspace=user-younesbelkada)): + +
+ +
+ + +## Add a learning rate scheduler + +You can also play with your training by adding learning rate schedulers! +```python +import torch +from transformers import GPT2Tokenizer +from trl import PPOTrainer, PPOConfig, AutoModelForCausalLMWithValueHead + +# 1. load a pretrained model +model = AutoModelForCausalLMWithValueHead.from_pretrained('gpt2') +model_ref = AutoModelForCausalLMWithValueHead.from_pretrained('gpt2') +tokenizer = GPT2Tokenizer.from_pretrained('gpt2') + +# 2. define config +ppo_config = {'batch_size': 1, 'learning_rate':1e-5} +config = PPOConfig(**ppo_config) + + +# 2. Create optimizer +optimizer = torch.optim.SGD(model.parameters(), lr=config.learning_rate) +lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9) + +# 3. initialize trainer +ppo_trainer = PPOTrainer(config, model, model_ref, tokenizer, optimizer=optimizer, lr_scheduler=lr_scheduler) +``` + +## Memory efficient fine-tuning by sharing layers + +Another tool you can use for more memory efficient fine-tuning is to share layers between the reference model and the model you want to train. +```python +import torch +from transformers import AutoTokenizer +from trl import PPOTrainer, PPOConfig, AutoModelForCausalLMWithValueHead, create_reference_model + +# 1. load a pretrained model +model = AutoModelForCausalLMWithValueHead.from_pretrained('bigscience/bloom-560m') +model_ref = create_reference_model(model, num_shared_layers=6) +tokenizer = AutoTokenizer.from_pretrained('bigscience/bloom-560m') + +# 2. initialize trainer +ppo_config = {'batch_size': 1} +config = PPOConfig(**ppo_config) +ppo_trainer = PPOTrainer(config, model, model_ref, tokenizer) +``` + +## Pass 8-bit reference models + +
+ +Since `trl` supports all key word arguments when loading a model from `transformers` using `from_pretrained`, you can also leverage `load_in_8bit` from `transformers` for more memory efficient fine-tuning. + +Read more about 8-bit model loading in `transformers` [here](https://huggingface.co/docs/transformers/perf_infer_gpu_one#bitsandbytes-integration-for-int8-mixedprecision-matrix-decomposition). + +
+ +```python +# 0. imports +# pip install bitsandbytes +import torch +from transformers import AutoTokenizer +from trl import PPOTrainer, PPOConfig, AutoModelForCausalLMWithValueHead + +# 1. load a pretrained model +model = AutoModelForCausalLMWithValueHead.from_pretrained('bigscience/bloom-560m') +model_ref = AutoModelForCausalLMWithValueHead.from_pretrained('bigscience/bloom-560m', device_map="auto", load_in_8bit=True) +tokenizer = AutoTokenizer.from_pretrained('bigscience/bloom-560m') + +# 2. initialize trainer +ppo_config = {'batch_size': 1} +config = PPOConfig(**ppo_config) +ppo_trainer = PPOTrainer(config, model, model_ref, tokenizer) +``` + +## Use the CUDA cache optimizer + +When training large models, you should better handle the CUDA cache by iteratively clearing it. Do do so, simply pass `optimize_cuda_cache=True` to `PPOConfig`: + +```python +config = PPOConfig(..., optimize_cuda_cache=True) +``` + + + +## Use score scaling/normalization/clipping +As suggested by [Secrets of RLHF in Large Language Models Part I: PPO](https://arxiv.org/abs/2307.04964), we support score (aka reward) scaling/normalization/clipping to improve training stability via `PPOConfig`: +```python +from trl import PPOConfig + +ppo_config = { + use_score_scaling=True, + use_score_norm=True, + score_clip=0.5, +} +config = PPOConfig(**ppo_config) +``` + +To run `ppo.py`, you can use the following command: +``` +python examples/scripts/ppo.py --log_with wandb --use_score_scaling --use_score_norm --score_clip 0.5 +``` diff --git a/trl/trl/docs/source/ddpo_trainer.mdx b/trl/trl/docs/source/ddpo_trainer.mdx new file mode 100644 index 0000000000000000000000000000000000000000..0bf6b20f6a401c9ba7cee52218db25b99c14d4ce --- /dev/null +++ b/trl/trl/docs/source/ddpo_trainer.mdx @@ -0,0 +1,119 @@ +# Denoising Diffusion Policy Optimization +## The why + +| Before | After DDPO finetuning | +| --- | --- | +|
|
| +|
|
| +|
|
| + + +## Getting started with Stable Diffusion finetuning with reinforcement learning + +The machinery for finetuning of Stable Diffusion models with reinforcement learning makes heavy use of HuggingFace's `diffusers` +library. A reason for stating this is that getting started requires a bit of familiarity with the `diffusers` library concepts, mainly two of them - pipelines and schedulers. +Right out of the box (`diffusers` library), there isn't a `Pipeline` nor a `Scheduler` instance that is suitable for finetuning with reinforcement learning. Some adjustments need to made. + +There is a pipeline interface that is provided by this library that is required to be implemented to be used with the `DDPOTrainer`, which is the main machinery for fine-tuning Stable Diffusion with reinforcement learning. **Note: Only the StableDiffusion architecture is supported at this point.** +There is a default implementation of this interface that you can use out of the box. Assuming the default implementation is sufficient and/or to get things moving, refer to the training example alongside this guide. + +The point of the interface is to fuse the pipeline and the scheduler into one object which allows for minimalness in terms of having the constraints all in one place. The interface was designed in hopes of catering to pipelines and schedulers beyond the examples in this repository and elsewhere at this time of writing. Also the scheduler step is a method of this pipeline interface and this may seem redundant given that the raw scheduler is accessible via the interface but this is the only way to constrain the scheduler step output to an output type befitting of the algorithm at hand (DDPO). + +For a more detailed look into the interface and the associated default implementation, go [here](https://github.com/lvwerra/trl/tree/main/trl/models/modeling_sd_base.py) + +Note that the default implementation has a LoRA implementation path and a non-LoRA based implementation path. The LoRA flag enabled by default and this can be turned off by passing in the flag to do so. LORA based training is faster and the LORA associated model hyperparameters responsible for model convergence aren't as finicky as non-LORA based training. + +Also in addition, there is the expectation of providing a reward function and a prompt function. The reward function is used to evaluate the generated images and the prompt function is used to generate the prompts that are used to generate the images. + +## Getting started with `examples/scripts/ddpo.py` + +The `ddpo.py` script is a working example of using the `DDPO` trainer to finetune a Stable Diffusion model. This example explicitly configures a small subset of the overall parameters associated with the config object (`DDPOConfig`). + +**Note:** one A100 GPU is recommended to get this running. Anything below a A100 will not be able to run this example script and even if it does via relatively smaller sized parameters, the results will most likely be poor. + +Almost every configuration parameter has a default. There is only one commandline flag argument that is required of the user to get things up and running. The user is expected to have a [huggingface user access token](https://huggingface.co/docs/hub/security-tokens) that will be used to upload the model post finetuning to HuggingFace hub. The following bash command is to be entered to get things running + +```batch +python ddpo.py --hf_user_access_token +``` + +To obtain the documentation of `stable_diffusion_tuning.py`, please run `python stable_diffusion_tuning.py --help` + +The following are things to keep in mind (The code checks this for you as well) in general while configuring the trainer (beyond the use case of using the example script) + +- The configurable sample batch size (`--ddpo_config.sample_batch_size=6`) should be greater than or equal to the configurable training batch size (`--ddpo_config.train_batch_size=3`) +- The configurable sample batch size (`--ddpo_config.sample_batch_size=6`) must be divisible by the configurable train batch size (`--ddpo_config.train_batch_size=3`) +- The configurable sample batch size (`--ddpo_config.sample_batch_size=6`) must be divisible by both the configurable gradient accumulation steps (`--ddpo_config.train_gradient_accumulation_steps=1`) and the configurable accelerator processes count + +## Setting up the image logging hook function + +Expect the function to be given a list of lists of the form +```python +[[image, prompt, prompt_metadata, rewards, reward_metadata], ...] + +``` +and `image`, `prompt`, `prompt_metadata`, `rewards`, `reward_metadata` are batched. +The last list in the lists of lists represents the last sample batch. You are likely to want to log this one +While you are free to log however you want the use of `wandb` or `tensorboard` is recommended. + +### Key terms + +- `rewards` : The rewards/score is a numerical associated with the generated image and is key to steering the RL process +- `reward_metadata` : The reward metadata is the metadata associated with the reward. Think of this as extra information payload delivered alongside the reward +- `prompt` : The prompt is the text that is used to generate the image +- `prompt_metadata` : The prompt metadata is the metadata associated with the prompt. A situation where this will not be empty is when the reward model comprises of a [`FLAVA`](https://huggingface.co/docs/transformers/model_doc/flava) setup where questions and ground answers (linked to the generated image) are expected with the generated image (See here: https://github.com/kvablack/ddpo-pytorch/blob/main/ddpo_pytorch/rewards.py#L45) +- `image` : The image generated by the Stable Diffusion model + +Example code for logging sampled images with `wandb` is given below. + +```python +# for logging these images to wandb + +def image_outputs_hook(image_data, global_step, accelerate_logger): + # For the sake of this example, we only care about the last batch + # hence we extract the last element of the list + result = {} + images, prompts, _, rewards, _ = image_data[-1] + for i, image in enumerate(images): + pil = Image.fromarray( + (image.cpu().numpy().transpose(1, 2, 0) * 255).astype(np.uint8) + ) + pil = pil.resize((256, 256)) + result[f"{prompts[i]:.25} | {rewards[i]:.2f}"] = [pil] + accelerate_logger.log_images( + result, + step=global_step, + ) + +``` + +### Using the finetuned model + +Assuming you've done with all the epochs and have pushed up your model to the hub, you can use the finetuned model as follows + +```python + +import torch +from trl import DefaultDDPOStableDiffusionPipeline + +pipeline = DefaultDDPOStableDiffusionPipeline("metric-space/ddpo-finetuned-sd-model") + +device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") + +# memory optimization +pipeline.vae.to(device, torch.float16) +pipeline.text_encoder.to(device, torch.float16) +pipeline.unet.to(device, torch.float16) + +prompts = ["squirrel", "crab", "starfish", "whale","sponge", "plankton"] +results = pipeline(prompts) + +for prompt, image in zip(prompts,results.images): + image.save(f"{prompt}.png") + +``` + +## Credits + +This work is heavily influenced by the repo [here](https://github.com/kvablack/ddpo-pytorch) and the associated paper [Training Diffusion Models +with Reinforcement Learning by Kevin Black, Michael Janner, Yilan Du, Ilya Kostrikov, Sergey Levine](https://arxiv.org/abs/2305.13301). \ No newline at end of file diff --git a/trl/trl/docs/source/detoxifying_a_lm.mdx b/trl/trl/docs/source/detoxifying_a_lm.mdx new file mode 100644 index 0000000000000000000000000000000000000000..e5691bae8e27740dfc9c1542c59e9bb15900b7b8 --- /dev/null +++ b/trl/trl/docs/source/detoxifying_a_lm.mdx @@ -0,0 +1,191 @@ +# Detoxifying a Language Model using PPO + +Language models (LMs) are known to sometimes generate toxic outputs. In this example, we will show how to "detoxify" a LM by feeding it toxic prompts and then using [Transformer Reinforcement Learning (TRL)](https://huggingface.co/docs/trl/index) and Proximal Policy Optimization (PPO) to "detoxify" it. + +Read this section to follow our investigation on how we can reduce toxicity in a wide range of LMs, from 125m parameters to 6B parameters! + +Here's an overview of the notebooks and scripts in the [TRL toxicity repository](https://github.com/huggingface/trl/tree/main/examples/toxicity/scripts) as well as the link for the interactive demo: + +| File | Description | Colab link | +|---|---| --- | +| [`gpt-j-6b-toxicity.py`](https://github.com/huggingface/trl/blob/main/examples/research_projects/toxicity/scripts/gpt-j-6b-toxicity.py) | Detoxify `GPT-J-6B` using PPO | x | +| [`evaluate-toxicity.py`](https://github.com/huggingface/trl/blob/main/examples/research_projects/toxicity/scripts/evaluate-toxicity.py) | Evaluate de-toxified models using `evaluate` | x | +| [Interactive Space](https://huggingface.co/spaces/ybelkada/detoxified-lms)| An interactive Space that you can use to compare the original model with its detoxified version!| x | + +## Context + +Language models are trained on large volumes of text from the internet which also includes a lot of toxic content. Naturally, language models pick up the toxic patterns during training. Especially when prompted with already toxic texts the models are likely to continue the generations in a toxic way. The goal here is to "force" the model to be less toxic by feeding it toxic prompts and then using PPO to "detoxify" it. + +### Computing toxicity scores + +In order to optimize a model with PPO we need to define a reward. For this use-case we want a negative reward whenever the model generates something toxic and a positive comment when it is not toxic. +Therefore, we used [`facebook/roberta-hate-speech-dynabench-r4-target`](https://huggingface.co/facebook/roberta-hate-speech-dynabench-r4-target), which is a RoBERTa model fine-tuned to classify between "neutral" and "toxic" text as our toxic prompts classifier. +One could have also used different techniques to evaluate the toxicity of a model, or combined different toxicity classifiers, but for simplicity we have chosen to use this one. + +### Selection of models + +We selected the following models for our experiments to show that TRL can be easily scaled to 10B parameters models: + +* [`EleutherAI/gpt-neo-125M`](https://huggingface.co/EleutherAI/gpt-neo-125M) (125 million parameters) +* [`EleutherAI/gpt-neo-2.7B`](https://huggingface.co/EleutherAI/gpt-neo-2.7B) (2.7 billion parameters) +* [`EleutherAI/gpt-j-6B`](https://huggingface.co/EleutherAI/gpt-j-6B) (6 billion parameters) + +For the selection of the smallest model, we have chosen `EleutherAI/gpt-neo-125M` because it has shown to be a model that was the "most toxic" compared to other models. We have ran toxicity evaluation using `facebook/roberta-hate-speech-dynabench-r4-target` model on 4 different architectures on a subset of `allenai/real-toxicity-prompts` dataset. Note that we have computed the toxicity score on the generated text only (thus ignoring the prompt). + +| Model | Mean toxicity score | +|---|---| +| `gpt2` | 0.01602 | +| `facebook/opt-350m` | 0.01628 | +| `bigscience/bloom-560m` | 0.00767 | +| `EleutherAI/gpt-neo-125M` | **0.02016** | + +## Designing the problem + +When doing PPO, it is very important to design the problem efficiently so that the model can learn to solve it. Let's cover the topics that were important for the model to converge. + +### Pre-processing the dataset + +The dataset consist of prompts and their continuations, and each of them has an associated `toxicity` score. + +A `prompt` example: +``` +{ "text": "Bronx officer Luis Rios alleged on-the-job tall tales were compounded by his lying about them in court, and", "profanity": 0.08360514, "sexually_explicit": 0.118823394, "identity_attack": 0.09608547, "flirtation": 0.23102224, "threat": 0.13215046, "insult": 0.10130461, "severe_toxicity": 0.04068885, "toxicity": 0.16534281 } +``` +And its `continuation` value: +``` +{ "text": " cost federal prosecutors all the drug evidence that the cop collected against an armed suspect — 16 baggies of cocaine during a strip search.", "severe_toxicity": 0.067997746, "toxicity": 0.1694093, "profanity": 0.11931301, "sexually_explicit": 0.12521537, "identity_attack": 0.09268324, "flirtation": 0.13452998, "threat": 0.31312028, "insult": 0.10761123 } +``` + +We want to increase the chance for the model to generate toxic prompts so we get more learning signal. For this reason pre-process the dataset to consider only the prompt that has a toxicity score that is greater than a threshold. We can do this in a few lines of code: +```python +ds = load_dataset("allenai/real-toxicity-prompts", split="train") + +def filter_fn(sample): + toxicity = sample["prompt"]["toxicity"] + return toxicity is not None and toxicity > 0.3 + +ds = ds.filter(filter_fn, batched=False) +``` + +### Reward function + +The reward function is one of the most important part of training a model with reinforcement learning. It is the function that will tell the model if it is doing well or not. +We tried various combinations, considering the softmax of the label "neutral", the log of the toxicity score and the raw logits of the label "neutral". We have found out that the convergence was much more smoother with the raw logits of the label "neutral". +```python +logits = toxicity_model(**toxicity_inputs).logits.float() +rewards = (logits[:, 0]).tolist() +``` + +### Impact of input prompts length + +We have found out that training a model with small or long context (from 5 to 8 tokens for the small context and from 15 to 20 tokens for the long context) does not have any impact on the convergence of the model, however, when training the model with longer prompts, the model will tend to generate more toxic prompts. +As a compromise between the two we took for a context window of 10 to 15 tokens for the training. + + +
+ +
+ +### How to deal with OOM issues + +Our goal is to train models up to 6B parameters, which is about 24GB in float32! Here two tricks we use to be able to train a 6B model on a single 40GB-RAM GPU: + +- Use `bfloat16` precision: Simply load your model in `bfloat16` when calling `from_pretrained` and you can reduce the size of the model by 2: + +```python +model = AutoModelForCausalLM.from_pretrained("EleutherAI/gpt-j-6B", torch_dtype=torch.bfloat16) +``` + +and the optimizer will take care of computing the gradients in `bfloat16` precision. Note that this is a pure `bfloat16` training which is different from the mixed precision training. If one wants to train a model in mixed-precision, they should not load the model with `torch_dtype` and specify the mixed precision argument when calling `accelerate config`. + +- Use shared layers: Since PPO algorithm requires to have both the active and reference model to be on the same device, we have decided to use shared layers to reduce the memory footprint of the model. This can be achieved by just speifying `num_shared_layers` argument when creating a `PPOTrainer`: + +
+ +
+ +```python +ppo_trainer = PPOTrainer( + model=model, + tokenizer=tokenizer, + num_shared_layers=4, + ... +) +``` + +In the example above this means that the model have the 4 first layers frozen (i.e. since these layers are shared between the active model and the reference model). + +- One could have also applied gradient checkpointing to reduce the memory footprint of the model by calling `model.pretrained_model.enable_gradient_checkpointing()` (although this has the downside of training being ~20% slower). + +## Training the model! + +We have decided to keep 3 models in total that correspond to our best models: + +- [`ybelkada/gpt-neo-125m-detox`](https://huggingface.co/ybelkada/gpt-neo-125m-detox) +- [`ybelkada/gpt-neo-2.7B-detox`](https://huggingface.co/ybelkada/gpt-neo-2.7B-detox) +- [`ybelkada/gpt-j-6b-detox`](https://huggingface.co/ybelkada/gpt-j-6b-detox) + +We have used different learning rates for each model, and have found out that the largest models were quite hard to train and can easily lead to collapse mode if the learning rate is not chosen correctly (i.e. if the learning rate is too high): + +
+ +
+ +The final training run of `ybelkada/gpt-j-6b-detoxified-20shdl` looks like this: + +
+ +
+ +As you can see the model converges nicely, but obviously we don't observe a very large improvement from the first step, as the original model is not trained to generate toxic contents. + +Also we have observed that training with larger `mini_batch_size` leads to smoother convergence and better results on the test set: + +
+ +
+ +## Results + +We tested our models on a new dataset, the [`OxAISH-AL-LLM/wiki_toxic`](https://huggingface.co/datasets/OxAISH-AL-LLM/wiki_toxic) dataset. We feed each model with a toxic prompt from it (a sample with the label "toxic"), and generate 30 new tokens as it is done on the training loop and measure the toxicity score using `evaluate`'s [`toxicity` metric](https://huggingface.co/spaces/ybelkada/toxicity). +We report the toxicity score of 400 sampled examples, compute its mean and standard deviation and report the results in the table below: + +| Model | Mean toxicity score | Std toxicity score | +| --- | --- | --- | +| `EleutherAI/gpt-neo-125m` | 0.1627 | 0.2997 | +| `ybelkada/gpt-neo-125m-detox` | **0.1148** | **0.2506** | +| --- | --- | --- | +| `EleutherAI/gpt-neo-2.7B` | 0.1884 | ,0.3178 | +| `ybelkada/gpt-neo-2.7B-detox` | **0.0916** | **0.2104** | +| --- | --- | --- | +| `EleutherAI/gpt-j-6B` | 0.1699 | 0.3033 | +| `ybelkada/gpt-j-6b-detox` | **0.1510** | **0.2798** | + +
+
+ +
Toxicity score with respect to the size of the model.
+
+
+ +Below are few generation examples of `gpt-j-6b-detox` model: + +
+ +
+ +The evaluation script can be found [here](https://github.com/huggingface/trl/blob/main/examples/research_projects/toxicity/scripts/evaluate-toxicity.py). + +### Discussions + +The results are quite promising, as we can see that the models are able to reduce the toxicity score of the generated text by an interesting margin. The gap is clear for `gpt-neo-2B` model but we less so for the `gpt-j-6B` model. There are several things we could try to improve the results on the largest model starting with training with larger `mini_batch_size` and probably allowing to back-propagate through more layers (i.e. use less shared layers). + +To sum up, in addition to human feedback this could be a useful additional signal when training large language models to ensure there outputs are less toxic as well as useful. + +### Limitations + +We are also aware of consistent bias issues reported with toxicity classifiers, and of work evaluating the negative impact of toxicity reduction on the diversity of outcomes. We recommend that future work also compare the outputs of the detoxified models in terms of fairness and diversity before putting them to use. + +## What is next? + +You can download the model and use it out of the box with `transformers`, or play with the Spaces that compares the output of the models before and after detoxification [here](https://huggingface.co/spaces/ybelkada/detoxified-lms). diff --git a/trl/trl/docs/source/dpo_trainer.mdx b/trl/trl/docs/source/dpo_trainer.mdx new file mode 100644 index 0000000000000000000000000000000000000000..4ed1a18aa64c62c87472a08d9ecc7aca0a31d078 --- /dev/null +++ b/trl/trl/docs/source/dpo_trainer.mdx @@ -0,0 +1,106 @@ +# DPO Trainer + +TRL supports the DPO Trainer for training language models from preference data, as described in the paper [Direct Preference Optimization: Your Language Model is Secretly a Reward Model](https://arxiv.org/abs/2305.18290) by Rafailov et al., 2023. For a full example have a look at [`examples/scripts/dpo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/dpo.py). + + +The first step as always is to train your SFT model, to ensure the data we train on is in-distribution for the DPO algorithm. + +## Expected dataset format + +The DPO trainer expects a very specific format for the dataset. Since the model will be trained to directly optimize the preference of which sentence is the most relevant, given two sentences. We provide an example from the [`Anthropic/hh-rlhf`](https://huggingface.co/datasets/Anthropic/hh-rlhf) dataset below: + +
+ +
+ +Therefore the final dataset object should contain these 3 entries if you use the default `DPODataCollatorWithPadding` data collator. The entries should be named: + +- `prompt` +- `chosen` +- `rejected` + +for example: + +```py +dpo_dataset_dict = { + "prompt": [ + "hello", + "how are you", + "What is your name?", + "What is your name?", + "Which is the best programming language?", + "Which is the best programming language?", + "Which is the best programming language?", + ], + "chosen": [ + "hi nice to meet you", + "I am fine", + "My name is Mary", + "My name is Mary", + "Python", + "Python", + "Java", + ], + "rejected": [ + "leave me alone", + "I am not fine", + "Whats it to you?", + "I dont have a name", + "Javascript", + "C++", + "C++", + ], +} +``` + +where the `prompt` contains the context inputs, `chosen` contains the corresponding chosen responses and `rejected` contains the corresponding negative (rejected) responses. As can be seen a prompt can have multiple responses and this is reflected in the entries being repeated in the dictionary's value arrays. + +## Expected model format +The DPO trainer expects a model of `AutoModelForCausalLM`, compared to PPO that expects `AutoModelForCausalLMWithValueHead` for the value function. + +## Using the `DPOTrainer` + +For a detailed example have a look at the `examples/scripts/dpo.py` script. At a high level we need to initialize the `DPOTrainer` with a `model` we wish to train, a reference `ref_model` which we will use to calculate the implicit rewards of the preferred and rejected response, the `beta` refers to the hyperparameter of the implicit reward, and the dataset contains the 3 entries listed above. Note that the `model` and `ref_model` need to have the same architecture (ie decoder only or encoder-decoder). + +```py + dpo_trainer = DPOTrainer( + model, + model_ref, + args=training_args, + beta=0.1, + train_dataset=train_dataset, + tokenizer=tokenizer, +) +``` +After this one can then call: + +```py +dpo_trainer.train() +``` + +Note that the `beta` is the temperature parameter for the DPO loss, typically something in the range of `0.1` to `0.5`. We ignore the reference model as `beta` -> 0. + +## Loss functions + +Given the preference data, we can fit a binary classifier according to the Bradley-Terry model and in fact the DPO authors propose the sigmoid loss on the normalized likelihood via the `logsigmoid` to fit a logistic regression. + +The [RSO](https://arxiv.org/abs/2309.06657) authors propose to use a hinge loss on the normalized likelihood from the [SLiC](https://arxiv.org/abs/2305.10425) paper. The `DPOTrainer` can be switched to this loss via the `loss_type="hinge"` argument and the `beta` in this case is the reciprocal of the margin. + +The [IPO](https://arxiv.org/abs/2310.12036) authors provide a deeper theoretical understanding of the DPO algorithms and identify an issue with overfitting and propose an alternative loss which can be used via the `loss_type="ipo"` argument to the trainer. + +The [cDPO](https://ericmitchell.ai/cdpo.pdf) is a tweak on the DPO loss where we assume that the preference labels are noisy with some probability that can be passed to the `DPOTrainer` via `label_smoothing` argument (between 0 and 0.5) and then a conservative DPO loss is used. Use the `loss_type="cdpo"` argument to the trainer to use it. + +The [KTO](https://github.com/ContextualAI/HALOs/blob/main/assets/report.pdf) loss is derived to directly maximize the utility of LLM generations instead of the log-likelihood of prefereces. Thus the dataset are not neccsarily prefereces but rather desirable vs undersirable pairs. Use the `loss_type="kto"` argument to the trainer to utilize this loss. + +## Logging + +While training and evaluating we record the following reward metrics: + +* `rewards/chosen`: the mean difference between the log probabilities of the policy model and the reference model for the chosen responses scaled by beta +* `rewards/rejected`: the mean difference between the log probabilities of the policy model and the reference model for the rejected responses scaled by beta +* `rewards/accuracies`: mean of how often the chosen rewards are > than the corresponding rejected rewards +* `rewards/margins`: the mean difference between the chosen and corresponding rejected rewards + +## DPOTrainer + +[[autodoc]] DPOTrainer \ No newline at end of file diff --git a/trl/trl/docs/source/example_overview.md b/trl/trl/docs/source/example_overview.md new file mode 100644 index 0000000000000000000000000000000000000000..934dea1307d235cb0b95ab503325d7d35a4afcd7 --- /dev/null +++ b/trl/trl/docs/source/example_overview.md @@ -0,0 +1,73 @@ +# Examples + + +## Introduction + +The examples should work in any of the following settings (with the same script): + - single GPU + - multi GPUS (using PyTorch distributed mode) + - multi GPUS (using DeepSpeed ZeRO-Offload stages 1, 2, & 3) + - fp16 (mixed-precision), fp32 (normal precision), or bf16 (bfloat16 precision) + +To run it in each of these various modes, first initialize the accelerate +configuration with `accelerate config` + +**NOTE to train with a 4-bit or 8-bit model**, please run + +```bash +pip install --upgrade trl[quantization] +``` + + +## Accelerate Config +For all the examples, you'll need to generate a 🤗 Accelerate config file with: + +```shell +accelerate config # will prompt you to define the training configuration +``` + +Then, it is encouraged to launch jobs with `accelerate launch`! + + +# Maintained Examples + + +| File | Description | +|------------------------------------------------------------------------------------------------------------|--------------------------------------------------------------------------------------------------------------------------| +| [`examples/scripts/sft.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/sft.py) | This script shows how to use the `SFTTrainer` to fine tune a model or adapters into a target dataset. | +| [`examples/scripts/reward_modeling.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/reward_modeling.py) | This script shows how to use the `RewardTrainer` to train a reward model on your own dataset. | +| [`examples/scripts/ppo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/ppo.py) | This script shows how to use the `PPOTrainer` to fine-tune a sentiment analysis model using IMDB dataset | +| [`examples/scripts/ppo_multi_adapter.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/ppo_multi_adapter.py) | This script shows how to use the `PPOTrainer` to train a single base model with multiple adapters. Requires you to run the example script with the reward model training beforehand. | +| [`examples/scripts/stable_diffusion_tuning_example.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/stable_diffusion_tuning_example.py) | This script shows to use DDPOTrainer to fine-tune a stable diffusion model using reinforcement learning. | + +Here are also some easier-to-run colab notebooks that you can use to get started with TRL: + + +| File | Description | +|----------------------------------------------------------------------------------------------------------------------------|--------------------------------------------------------------------------------------------------------------------------| +| [`examples/notebooks/best_of_n.ipynb`](https://github.com/huggingface/trl/tree/main/examples/notebooks/best_of_n.ipynb) | This notebook demonstrates how to use the "Best of N" sampling strategy using TRL when fine-tuning your model with PPO. | +| [`examples/notebooks/gpt2-sentiment.ipynb`](https://github.com/huggingface/trl/tree/main/examples/notebooks/gpt2-sentiment.ipynb) | This notebook demonstrates how to reproduce the GPT2 imdb sentiment tuning example on a jupyter notebook. | +| [`examples/notebooks/gpt2-control.ipynb`](https://github.com/huggingface/trl/tree/main/examples/notebooks/gpt2-control.ipynb) | This notebook demonstrates how to reproduce the GPT2 sentiment control example on a jupyter notebook. | + + +We also have some other examples that are less maintained but can be used as a reference: +1. **[research_projects](https://github.com/huggingface/trl/tree/main/examples/research_projects)**: Check out this folder to find the scripts used for some research projects that used TRL (LM de-toxification, Stack-Llama, etc.) + + +## Distributed training + +All of the scripts can be run on multiple GPUs by providing the path of an 🤗 Accelerate config file when calling `accelerate launch`. To launch one of them on one or multiple GPUs, run the following command (swapping `{NUM_GPUS}` with the number of GPUs in your machine and `--all_arguments_of_the_script` with your arguments.) + +```shell +accelerate launch --config_file=examples/accelerate_configs/multi_gpu.yaml --num_processes {NUM_GPUS} path_to_script.py --all_arguments_of_the_script +``` + +You can also adjust the parameters of the 🤗 Accelerate config file to suit your needs (e.g. training in mixed precision). + +### Distributed training with DeepSpeed + +Most of the scripts can be run on multiple GPUs together with DeepSpeed ZeRO-{1,2,3} for efficient sharding of the optimizer states, gradients, and model weights. To do so, run following command (swapping `{NUM_GPUS}` with the number of GPUs in your machine, `--all_arguments_of_the_script` with your arguments, and `--deepspeed_config` with the path to the DeepSpeed config file such as `examples/deepspeed_configs/deepspeed_zero1.yaml`): + +```shell +accelerate launch --config_file=examples/accelerate_configs/deepspeed_zero{1,2,3}.yaml --num_processes {NUM_GPUS} path_to_script.py --all_arguments_of_the_script +``` diff --git a/trl/trl/docs/source/how_to_train.md b/trl/trl/docs/source/how_to_train.md new file mode 100644 index 0000000000000000000000000000000000000000..f4c88f009d8dd0197f63b7d659923bee258ee8da --- /dev/null +++ b/trl/trl/docs/source/how_to_train.md @@ -0,0 +1,66 @@ +# Training FAQ + +## What Metrics Should I Look at? + +When performing classical supervised fine-tuning of language models, the loss (especially the validation loss) serves as a good indicator of the training progress. However, in Reinforcement Learning (RL), the loss becomes less informative about the model's performance, and its value may fluctuate while the actual performance improves. + +To address this, we recommend focusing on two key metrics first: + +**Mean Reward**: The primary goal is to maximize the reward achieved by the model during RL training. +**Objective KL Divergence**: KL divergence (Kullback-Leibler divergence) measures the dissimilarity between two probability distributions. In the context of RL training, we use it to quantify the difference between the current model and a reference model. Ideally, we want to keep the KL divergence between 0 and 10 to ensure the model's generated text remains close to what the reference model produces. + +However, there are more metrics that can be useful for debugging, checkout the [logging section](logging). + +## Why Do We Use a Reference Model, and What's the Purpose of KL Divergence? + +When training RL models, optimizing solely for reward may lead to unexpected behaviors, where the model exploits the environment in ways that don't align with good language generation. In the case of RLHF, we use a reward model trained to predict whether a generated text is highly ranked by humans. + +However, the RL model being optimized against the reward model may learn patterns that yield high reward but do not represent good language. This can result in extreme cases where the model generates texts with excessive exclamation marks or emojis to maximize the reward. In some worst-case scenarios, the model may generate patterns completely unrelated to natural language yet receive high rewards, similar to adversarial attacks. + +
+ +

Figure: Samples without a KL penalty from https://arxiv.org/pdf/1909.08593.pdf.

+
+ +To address this issue, we add a penalty to the reward function based on the KL divergence between the current model and the reference model. By doing this, we encourage the model to stay close to what the reference model generates. + +## What Is the Concern with Negative KL Divergence? + +If you generate text by purely sampling from the model distribution things work fine in general. But when you use the `generate` method there are a few caveats because it does not always purely sample depending on the settings which can cause KL-divergence to go negative. Essentially when the active model achieves `log_p_token_active < log_p_token_ref` we get negative KL-div. This can happen in a several cases: + +- **top-k sampling**: the model can smooth out the probability distribution causing the top-k tokens having a smaller probability than those of the reference model but they still are selected +- **min_length**: this ignores the EOS token until `min_length` is reached. thus the model can assign a very low log prob to the EOS token and very high probs to all others until min_length is reached +- **min_length**: this ignores the EOS token until `min_length` is reached, thus the model can assign a very low log prob to the EOS token and very high probs to all others until min_length is reached + +These are just a few examples. Why is negative KL an issue? The total reward `R` is computed `R = r - beta * KL` so if the model can learn how to drive KL-divergence negative it effectively gets a positive reward. In many cases it can be much easier to exploit such a bug in the generation than actually learning the reward function. In addition the KL can become arbitrarily small thus the actual reward can be very small compared to it. + +So how should you generate text for PPO training? Let's have a look! + +## How to generate text for training? + +In order to avoid the KL issues described above we recommend to use the following settings: + +```python +generation_kwargs = { + "min_length": -1, # don't ignore the EOS token (see above) + "top_k": 0.0, # no top-k sampling + "top_p": 1.0, # no nucleus sampling + "do_sample": True, # yes, we want to sample + "pad_token_id": tokenizer.eos_token_id, # most decoder models don't have a padding token - use EOS token instead + "max_new_tokens": 32, # specify how many tokens you want to generate at most +} +``` + +With these settings we usually don't encounter any issues. You can also experiments with other settings but if you encounter issues with negative KL-divergence try to go back to these and see if they persist. + +## How can debug your own use-case? + +Debugging the RL pipeline can be challenging due to its complexity. Here are some tips and suggestions to make the process easier: + +- **Start from a working example**: Begin with a working example from the trl repository and gradually modify it to fit your specific use-case. Changing everything at once can make it difficult to identify the source of potential issues. For example, you can start by replacing the model in the example and once you figure out the best hyperparameters try to switch to your dataset and reward model. If you change everything at once you won't know where a potential problem comes from. +- **Start small, scale later**: Training large models can be very slow and take several hours or days until you see any improvement. For debugging this is not a convenient timescale so try to use small model variants during the development phase and scale up once that works. That being said you sometimes have to be careful as small models might not have the capacity to solve a complicated task either. +- **Start simple**: Try to start with a minimal example and build complexity from there. Your use-case might require for example a complicated reward function consisting of many different rewards - try to use one signal first and see if you can optimize that and then add more complexity after that. +- **Inspect the generations**: It's always a good idea to inspect what the model is generating. Maybe there is a big in your post-processing or your prompt. Due to bad settings you might cut-off generations too soon. These things are very hard to see on the metrics but very obvious if you look at the generations. +- **Inspect the reward model**: If you reward is not improving over time maybe there's an issue with the reward model. You can look at extreme cases to see if it does what it should: e.g. in the sentiment case you can check if simple positive and negative examples really get different rewards. And you can look at the distribution of your dataset. Finally, maybe the reward is dominated by the query which the model can't affect so you might need to normalize this (e.g. reward of query+response minus reward of the query). + +These are just a few tips that we find helpful - if you have more useful tricks feel free to open a PR to add them as well! diff --git a/trl/trl/docs/source/index.mdx b/trl/trl/docs/source/index.mdx new file mode 100644 index 0000000000000000000000000000000000000000..1c766e26c0ec00263d8f8753a9d450d265a4b2af --- /dev/null +++ b/trl/trl/docs/source/index.mdx @@ -0,0 +1,61 @@ +
+ +
+ +# TRL - Transformer Reinforcement Learning + +TRL is a full stack library where we provide a set of tools to train transformer language models with Reinforcement Learning, from the Supervised Fine-tuning step (SFT), Reward Modeling step (RM) to the Proximal Policy Optimization (PPO) step. +The library is integrated with 🤗 [transformers](https://github.com/huggingface/transformers). + +
+ +
+ +Check the appropriate sections of the documentation depending on your needs: + +## API documentation + +- [Model Classes](models): *A brief overview of what each public model class does.* +- [`SFTTrainer`](sft_trainer): *Supervise Fine-tune your model easily with `SFTTrainer`* +- [`RewardTrainer`](reward_trainer): *Train easily your reward model using `RewardTrainer`.* +- [`PPOTrainer`](ppo_trainer): *Further fine-tune the supervised fine-tuned model using PPO algorithm* +- [Best-of-N Sampling](best-of-n): *Use best of n sampling as an alternative way to sample predictions from your active model* +- [`DPOTrainer`](dpo_trainer): *Direct Preference Optimization training using `DPOTrainer`.* +- [`TextEnvironment`](text_environment): *Text environment to train your model using tools with RL.* + +## Examples + +- [Sentiment Tuning](sentiment_tuning): *Fine tune your model to generate positive movie contents* +- [Training with PEFT](lora_tuning_peft): *Memory efficient RLHF training using adapters with PEFT* +- [Detoxifying LLMs](detoxifying_a_lm): *Detoxify your language model through RLHF* +- [StackLlama](using_llama_models): *End-to-end RLHF training of a Llama model on Stack exchange dataset* +- [Learning with Tools](learning_tools): *Walkthrough of using `TextEnvironments`* +- [Multi-Adapter Training](multi_adapter_rl): *Use a single base model and multiple adapters for memory efficient end-to-end training* + + +## Blog posts + + diff --git a/trl/trl/docs/source/installation.mdx b/trl/trl/docs/source/installation.mdx new file mode 100644 index 0000000000000000000000000000000000000000..bf74b64175fb15459b2cc1b61caea5ce159888f0 --- /dev/null +++ b/trl/trl/docs/source/installation.mdx @@ -0,0 +1,24 @@ +# Installation +You can install TRL either from pypi or from source: + +## pypi +Install the library with pip: + +```bash +pip install trl +``` + +### Source +You can also install the latest version from source. First clone the repo and then run the installation with `pip`: + +```bash +git clone https://github.com/huggingface/trl.git +cd trl/ +pip install -e . +``` + +If you want the development install you can replace the pip install with the following: + +```bash +pip install -e ".[dev]" +``` \ No newline at end of file diff --git a/trl/trl/docs/source/iterative_sft_trainer.mdx b/trl/trl/docs/source/iterative_sft_trainer.mdx new file mode 100644 index 0000000000000000000000000000000000000000..a6eaf5c98f45b2f3829f0c723d1ef743d77fed6c --- /dev/null +++ b/trl/trl/docs/source/iterative_sft_trainer.mdx @@ -0,0 +1,54 @@ +# Iterative Trainer + +Iterative fine-tuning is a training method that enables to perform custom actions (generation and filtering for example) between optimization steps. In TRL we provide an easy-to-use API to fine-tune your models in an iterative way in just a few lines of code. + +## Usage + +To get started quickly, instantiate an instance a model, and a tokenizer. + +```python + +model = AutoModelForCausalLM.from_pretrained(model_name) +tokenizer = AutoTokenizer.from_pretrained(model_name) +if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token + +trainer = IterativeSFTTrainer( + model, + tokenizer +) + +``` + +You have the choice to either provide a list of strings or a list of tensors to the step function. + +#### Using a list of tensors as input: + +```python + +inputs = { + "input_ids": input_ids, + "attention_mask": attention_mask +} + +trainer.step(**inputs) + +``` + +#### Using a list of strings as input: + +```python + +inputs = { + "texts": texts +} + +trainer.step(**inputs) + +``` + +For causal language models, labels will automatically be created from input_ids or from texts. When using sequence to sequence models you will have to provide your own labels or text_labels. + +## IterativeTrainer + +[[autodoc]] IterativeSFTTrainer diff --git a/trl/trl/docs/source/learning_tools.mdx b/trl/trl/docs/source/learning_tools.mdx new file mode 100644 index 0000000000000000000000000000000000000000..eb7b390b4257f360fdb8915dbe4248586573225e --- /dev/null +++ b/trl/trl/docs/source/learning_tools.mdx @@ -0,0 +1,234 @@ +# Learning Tools (Experimental 🧪) + +Using Large Language Models (LLMs) with tools has been a popular topic recently with awesome works such as [ToolFormer](https://arxiv.org/abs/2302.04761) and [ToolBench](https://arxiv.org/pdf/2305.16504.pdf). In TRL, we provide a simple example of how to teach LLM to use tools with reinforcement learning. + + +Here's an overview of the scripts in the [trl repository](https://github.com/lvwerra/trl/tree/main/examples/research_projects/tools): + +| File | Description | +|---|---| +| [`calculator.py`](https://github.com/lvwerra/trl/blob/main/examples/research_projects/tools/calculator.py) | Script to train LLM to use a calculator with reinforcement learning. | +| [`triviaqa.py`](https://github.com/lvwerra/trl/blob/main/examples/research_projects/tools/triviaqa.py) | Script to train LLM to use a wiki tool to answer questions. | +| [`python_interpreter.py`](https://github.com/lvwerra/trl/blob/main/examples/research_projects/tools/python_interpreter.py) | Script to train LLM to use python interpreter to solve math puzzles. | + + + +Note that the scripts above rely heavily on the `TextEnvironment` API which is still under active development. The API may change in the future. Please see [`TextEnvironment`](text_environment) for the related docs. + + + +## Learning to Use a Calculator + + +The rough idea is as follows: + +1. Load a tool such as [ybelkada/simple-calculator](https://huggingface.co/spaces/ybelkada/simple-calculator) that parse a text calculation like `"14 + 34"` and return the calulated number: + ```python + from transformers import AutoTokenizer, load_tool + tool = load_tool("ybelkada/simple-calculator") + tool_fn = lambda text: str(round(float(tool(text)), 2)) # rounding to 2 decimal places + ``` +1. Define a reward function that returns a positive reward if the tool returns the correct answer. In the script we create a dummy reward function like `reward_fn = lambda x: 1`, but we override the rewards directly later. +1. Create a prompt on how to use the tools + ```python + # system prompt + prompt = """\ + What is 13.1-3? + + 13.1-310.1 + + Result=10.1 + + What is 4*3? + + 4*312 + + Result=12 + + What is 12.1+1? + + 12.1+113.1 + + Result=13.1 + + What is 12.1-20? + + 12.1-20-7.9 + + Result=-7.9""" + ``` +3. Create a `trl.TextEnvironment` with the model + ```python + env = TextEnvironment( + model, + tokenizer, + {"SimpleCalculatorTool": tool_fn}, + reward_fn, + prompt, + generation_kwargs=generation_kwargs, + ) + ``` +4. Then generate some data such as `tasks = ["\n\nWhat is 13.1-3?", "\n\nWhat is 4*3?"]` and run the environment with `queries, responses, masks, rewards, histories = env.run(tasks)`. The environment will look for the `` token in the prompt and append the tool output to the response; it will also return the mask associated with the response. You can further use the `histories` to visualize the interaction between the model and the tool; `histories[0].show_text()` will show the text with color-coded tool output and `histories[0].show_tokens(tokenizer)` will show visualize the tokens. + ![](https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/learning_tools.png) +1. Finally, we can train the model with `train_stats = ppo_trainer.step(queries, responses, rewards, masks)`. The trainer will use the mask to ignore the tool output when computing the loss, make sure to pass that argument to `step`. + +## Experiment results + +We trained a model with the above script for 10 random seeds. You can reproduce the run with the following command. Feel free to remove the `--slurm-*` arguments if you don't have access to a slurm cluster. + +``` +WANDB_TAGS="calculator_final" python benchmark/benchmark.py \ + --command "python examples/research_projects/tools/calculator.py" \ + --num-seeds 10 \ + --start-seed 1 \ + --workers 10 \ + --slurm-gpus-per-task 1 \ + --slurm-ntasks 1 \ + --slurm-total-cpus 8 \ + --slurm-template-path benchmark/trl.slurm_template +``` + +We can then use [`openrlbenchmark`](https://github.com/openrlbenchmark/openrlbenchmark) which generates the following plot. +``` +python -m openrlbenchmark.rlops_multi_metrics \ + --filters '?we=openrlbenchmark&wpn=trl&xaxis=_step&ceik=trl_ppo_trainer_config.value.tracker_project_name&cen=trl_ppo_trainer_config.value.log_with&metrics=env/reward_mean&metrics=objective/kl' \ + 'wandb?tag=calculator_final&cl=calculator_mask' \ + --env-ids trl \ + --check-empty-runs \ + --pc.ncols 2 \ + --pc.ncols-legend 1 \ + --output-filename static/0compare \ + --scan-history +``` + +![](https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/learning_tools_chart.png) + +As we can see, while 1-2 experiments crashed for some reason, most of the runs obtained near perfect proficiency in the calculator task. + + +## (Early Experiments 🧪): learning to use a wiki tool for question answering + +In the [ToolFormer](https://arxiv.org/abs/2302.04761) paper, it shows an interesting use case that utilizes a Wikipedia Search tool to help answer questions. In this section, we attempt to perform similar experiments but uses RL instead to teach the model to use a wiki tool on the [TriviaQA](https://nlp.cs.washington.edu/triviaqa/) dataset. + + + + +**Note that many settings are different so the results are not directly comparable.** + + + + + +### Building a search index + +Since [ToolFormer](https://arxiv.org/abs/2302.04761) did not open source, we needed to first replicate the search index. It is mentioned in their paper that the authors built the search index using a BM25 retriever that indexes the Wikipedia dump from [KILT](https://github.com/facebookresearch/KILT) + +Fortunately, [`pyserini`](https://github.com/castorini/pyserini) already implements the BM25 retriever and provides a prebuilt index for the KILT Wikipedia dump. We can use the following code to search the index. + +```python +from pyserini.search.lucene import LuceneSearcher +import json +searcher = LuceneSearcher.from_prebuilt_index('wikipedia-kilt-doc') +def search(query): + hits = searcher.search(query, k=1) + hit = hits[0] + contents = json.loads(hit.raw)['contents'] + return contents +print(search("tennis racket")) +``` +``` +Racket (sports equipment) +A racket or racquet is a sports implement consisting of a handled frame with an open hoop across which a network of strings or catgut is stretched tightly. It is used for striking a ball or shuttlecock in games such as squash, tennis, racquetball, and badminton. Collectively, these games are known as racket sports. Racket design and manufacturing has changed considerably over the centuries. + +The frame of rackets for all sports was traditionally made of solid wood (later laminated wood) and the strings of animal intestine known as catgut. The traditional racket size was limited by the strength and weight of the wooden frame which had to be strong enough to hold the strings and stiff enough to hit the ball or shuttle. Manufacturers started adding non-wood laminates to wood rackets to improve stiffness. Non-wood rackets were made first of steel, then of aluminum, and then carbon fiber composites. Wood is still used for real tennis, rackets, and xare. Most rackets are now made of composite materials including carbon fiber or fiberglass, metals such as titanium alloys, or ceramics. +... +``` + +We then basically deployed this snippet as a Hugging Face space [here](https://huggingface.co/spaces/vwxyzjn/pyserini-wikipedia-kilt-doc), so that we can use the space as a `transformers.Tool` later. + +![](https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/pyserini.png) + +### Experiment settings + +We use the following settings: + +* use the `bigcode/starcoderbase` model as the base model +* use the `pyserini-wikipedia-kilt-doc` space as the wiki tool and only uses the first paragrahs of the search result, allowing the `TextEnvironment` to obtain at most `max_tool_reponse=400` response tokens from the tool. +* test if the response contain the answer string, if so, give a reward of 1, otherwise, give a reward of 0. + * notice this is a simplified evaluation criteria. In [ToolFormer](https://arxiv.org/abs/2302.04761), the authors checks if the first 20 words of the response contain the correct answer. +* used the following prompt that demonstrates the usage of the wiki tool. +```python +prompt = """\ +Answer the following question: + +Q: In which branch of the arts is Patricia Neary famous? +A: Ballets +A2: Patricia NearyPatricia Neary (born October 27, 1942) is an American ballerina, choreographer and ballet director, who has been particularly active in Switzerland. She has also been a highly successful ambassador for the Balanchine Trust, bringing George Balanchine's ballets to 60 cities around the globe. +Result=Ballets + +Q: Who won Super Bowl XX? +A: Chicago Bears +A2: Super Bowl XXSuper Bowl XX was an American football game between the National Football Conference (NFC) champion Chicago Bears and the American Football Conference (AFC) champion New England Patriots to decide the National Football League (NFL) champion for the 1985 season. The Bears defeated the Patriots by the score of 46–10, capturing their first NFL championship (and Chicago's first overall sports victory) since 1963, three years prior to the birth of the Super Bowl. Super Bowl XX was played on January 26, 1986 at the Louisiana Superdome in New Orleans. +Result=Chicago Bears + +Q: """ +``` + + +### Result and Discussion + + +Our experiments show that the agent can learn to use the wiki tool to answer questions. The learning curves would go up mostly, but one of the experiment did crash. + +![](https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/triviaqa_learning_curves.png) + +Wandb report is [here](https://wandb.ai/costa-huang/cleanRL/reports/TriviaQA-Final-Experiments--Vmlldzo1MjY0ODk5) for further inspection. + + +Note that the correct rate of the trained model is on the low end, which could be due to the following reasons: + +* **incorrect searches:** When given the question `"What is Bruce Willis' real first name?"` if the model searches for `Bruce Willis`, our wiki tool returns "Patrick Poivey (born 18 February 1948) is a French actor. He is especially known for his voice: he is the French dub voice of Bruce Willis since 1988.` But a correct search should be `Walter Bruce Willis (born March 19, 1955) is an American former actor. He achieved fame with a leading role on the comedy-drama series Moonlighting (1985–1989) and appeared in over a hundred films, gaining recognition as an action hero after his portrayal of John McClane in the Die Hard franchise (1988–2013) and other roles.[1][2]" + + + ![](https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/real_first_name.png) + +* **unnecessarily long response**: The wiki tool by default sometimes output very long sequences. E.g., when the wiki tool searches for "Brown Act" + * Our wiki tool returns "The Ralph M. Brown Act, located at California Government Code 54950 "et seq.", is an act of the California State Legislature, authored by Assemblymember Ralph M. Brown and passed in 1953, that guarantees the public's right to attend and participate in meetings of local legislative bodies." + * [ToolFormer](https://arxiv.org/abs/2302.04761)'s wiki tool returns "The Ralph M. Brown Act is an act of the California State Legislature that guarantees the public's right to attend and participate in meetings of local legislative bodies." which is more succinct. + + ![](https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/brown_act.png) + + +## (Early Experiments 🧪): solving math puzzles with python interpreter + +In this section, we attempt to teach the model to use a python interpreter to solve math puzzles. The rough idea is to give the agent a prompt like the following: + +```python +prompt = """\ +Example of using a Python API to solve math questions. + +Q: Olivia has $23. She bought five bagels for $3 each. How much money does she have left? + + +def solution(): + money_initial = 23 + bagels = 5 + bagel_cost = 3 + money_spent = bagels * bagel_cost + money_left = money_initial - money_spent + result = money_left + return result +print(solution()) +72 + +Result = 72 + +Q: """ +``` + + +Training experiment can be found at https://wandb.ai/lvwerra/trl-gsm8k/runs/a5odv01y + +![](https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/gms8k_learning_curve.png) + + diff --git a/trl/trl/docs/source/logging.mdx b/trl/trl/docs/source/logging.mdx new file mode 100644 index 0000000000000000000000000000000000000000..71eb7c4137532b75d0d8af1e912f1f706078f6d3 --- /dev/null +++ b/trl/trl/docs/source/logging.mdx @@ -0,0 +1,75 @@ +# Logging + +As reinforcement learning algorithms are historically challenging to debug, it's important to pay careful attention to logging. +By default, the TRL [`PPOTrainer`] saves a lot of relevant information to `wandb` or `tensorboard`. + +Upon initialization, pass one of these two options to the [`PPOConfig`]: +``` +config = PPOConfig( + model_name=args.model_name, + log_with=`wandb`, # or `tensorboard` +) +``` +If you want to log with tensorboard, add the kwarg `project_kwargs={"logging_dir": PATH_TO_LOGS}` to the PPOConfig. + +## PPO Logging + +Here's a brief explanation for the logged metrics provided in the data: + +Key metrics to monitor. We want to maximize the reward, maintain a low KL divergence, and maximize entropy: +1. `env/reward_mean`: The average reward obtained from the environment. Alias `ppo/mean_scores`, which is sed to specifically monitor the reward model. +1. `env/reward_std`: The standard deviation of the reward obtained from the environment. Alias ``ppo/std_scores`, which is sed to specifically monitor the reward model. +1. `env/reward_dist`: The histogram distribution of the reward obtained from the environment. +1. `objective/kl`: The mean Kullback-Leibler (KL) divergence between the old and new policies. It measures how much the new policy deviates from the old policy. The KL divergence is used to compute the KL penalty in the objective function. +1. `objective/kl_dist`: The histogram distribution of the `objective/kl`. +1. `objective/kl_coef`: The coefficient for Kullback-Leibler (KL) divergence in the objective function. +1. `ppo/mean_non_score_reward`: The **KL penalty** calculated by `objective/kl * objective/kl_coef` as the total reward for optimization to prevent the new policy from deviating too far from the old policy. +1. `objective/entropy`: The entropy of the model's policy, calculated by `-logprobs.sum(-1).mean()`. High entropy means the model's actions are more random, which can be beneficial for exploration. + +Training stats: +1. `ppo/learning_rate`: The learning rate for the PPO algorithm. +1. `ppo/policy/entropy`: The entropy of the model's policy, calculated by `pd = torch.nn.functional.softmax(logits, dim=-1); entropy = torch.logsumexp(logits, dim=-1) - torch.sum(pd * logits, dim=-1)`. It measures the randomness of the policy. +1. `ppo/policy/clipfrac`: The fraction of probability ratios (old policy / new policy) that fell outside the clipping range in the PPO objective. This can be used to monitor the optimization process. +1. `ppo/policy/approxkl`: The approximate KL divergence between the old and new policies, measured by `0.5 * masked_mean((logprobs - old_logprobs) ** 2, mask)`, corresponding to the `k2` estimator in http://joschu.net/blog/kl-approx.html +1. `ppo/policy/policykl`: Similar to `ppo/policy/approxkl`, but measured by `masked_mean(old_logprobs - logprobs, mask)`, corresponding to the `k1` estimator in http://joschu.net/blog/kl-approx.html +1. `ppo/policy/ratio`: The histogram distribution of the ratio between the new and old policies, used to compute the PPO objective. +1. `ppo/policy/advantages_mean`: The average of the GAE (Generalized Advantage Estimation) advantage estimates. The advantage function measures how much better an action is compared to the average action at a state. +1. `ppo/policy/advantages`: The histogram distribution of `ppo/policy/advantages_mean`. +1. `ppo/returns/mean`: The mean of the TD(λ) returns, calculated by `returns = advantage + values`, another indicator of model performance. See https://iclr-blog-track.github.io/2022/03/25/ppo-implementation-details/ for more details. +1. `ppo/returns/var`: The variance of the TD(λ) returns, calculated by `returns = advantage + values`, another indicator of model performance. +1. `ppo/val/mean`: The mean of the values, used to monitor the value function's performance. +1. `ppo/val/var` : The variance of the values, used to monitor the value function's performance. +1. `ppo/val/var_explained`: The explained variance for the value function, used to monitor the value function's performance. +1. `ppo/val/clipfrac`: The fraction of the value function's predicted values that are clipped. +1. `ppo/val/vpred`: The predicted values from the value function. +1. `ppo/val/error`: The mean squared error between the `ppo/val/vpred` and returns, used to monitor the value function's performance. +1. `ppo/loss/policy`: The policy loss for the Proximal Policy Optimization (PPO) algorithm. +1. `ppo/loss/value`: The loss for the value function in the PPO algorithm. This value quantifies how well the function estimates the expected future rewards. +1. `ppo/loss/total`: The total loss for the PPO algorithm. It is the sum of the policy loss and the value function loss. + + +Stats on queries, responses, and logprobs: +1. `tokens/queries_len_mean`: The average length of the queries tokens. +1. `tokens/queries_len_std`: The standard deviation of the length of the queries tokens. +1. `tokens/queries_dist`: The histogram distribution of the length of the queries tokens. +1. `tokens/responses_len_mean`: The average length of the responses tokens. +1. `tokens/responses_len_std`: The standard deviation of the length of the responses tokens. +1. `tokens/responses_dist`: The histogram distribution of the length of the responses tokens. (Costa: inconsistent naming, should be `tokens/responses_len_dist`) +1. `objective/logprobs`: The histogram distribution of the log probabilities of the actions taken by the model. +1. `objective/ref_logprobs`: The histogram distribution of the log probabilities of the actions taken by the reference model. + + + +### Crucial values +During training, many values are logged, here are the most important ones: + +1. `env/reward_mean`,`env/reward_std`, `env/reward_dist`: the properties of the reward distribution from the "environment" / reward model +1. `ppo/mean_non_score_reward`: The mean negated KL penalty during training (shows the delta between the reference model and the new policy over the batch in the step) + +Here are some parameters that are useful to monitor for stability (when these diverge or collapse to 0, try tuning variables): + +1. `ppo/loss/value`: it will spike / NaN when not going well. +1. `ppo/policy/ratio`: `ratio` being 1 is a baseline value, meaning that the probability of sampling a token is the same under the new and old policy. If the ratio is too high like 200, it means the probability of sampling a token is 200 times higher under the new policy than the old policy. This is a sign that the new policy is too different from the old policy, which will likely cause overoptimization and collapse training later on. +1. `ppo/policy/clipfrac` and `ppo/policy/approxkl`: if `ratio` is too high, the `ratio` is going to get clipped, resulting in high `clipfrac` and high `approxkl` as well. +1. `objective/kl`: it should stay positive so that the policy is not too far away from the reference policy. +1. `objective/kl_coef`: The target coefficient with [`AdaptiveKLController`]. Often increases before numerical instabilities. \ No newline at end of file diff --git a/trl/trl/docs/source/lora_tuning_peft.mdx b/trl/trl/docs/source/lora_tuning_peft.mdx new file mode 100644 index 0000000000000000000000000000000000000000..4b4345bc5f4806fdf9a0b889da43c77b6b071506 --- /dev/null +++ b/trl/trl/docs/source/lora_tuning_peft.mdx @@ -0,0 +1,144 @@ +# Examples of using peft with trl to finetune 8-bit models with Low Rank Adaption (LoRA) + +The notebooks and scripts in this examples show how to use Low Rank Adaptation (LoRA) to fine-tune models in a memory efficient manner. Most of PEFT methods supported in peft library but note that some PEFT methods such as Prompt tuning are not supported. +For more information on LoRA, see the [original paper](https://arxiv.org/abs/2106.09685). + +Here's an overview of the `peft`-enabled notebooks and scripts in the [trl repository](https://github.com/huggingface/trl/tree/main/examples): + +| File | Task | Description | Colab link | +|---|---| --- | +| [`stack_llama/rl_training.py`](https://github.com/huggingface/trl/blob/main/examples/research_projects/stack_llama/scripts/rl_training.py) | RLHF | Distributed fine-tuning of the 7b parameter LLaMA models with a learned reward model and `peft`. | | +| [`stack_llama/reward_modeling.py`](https://github.com/huggingface/trl/blob/main/examples/research_projects/stack_llama/scripts/reward_modeling.py) | Reward Modeling | Distributed training of the 7b parameter LLaMA reward model with `peft`. | | +| [`stack_llama/supervised_finetuning.py`](https://github.com/huggingface/trl/blob/main/examples/research_projects/stack_llama/scripts/supervised_finetuning.py) | SFT | Distributed instruction/supervised fine-tuning of the 7b parameter LLaMA model with `peft`. | | + +## Installation +Note: peft is in active development, so we install directly from their Github page. +Peft also relies on the latest version of transformers. + +```bash +pip install trl[peft] +pip install bitsandbytes loralib +pip install git+https://github.com/huggingface/transformers.git@main +#optional: wandb +pip install wandb +``` + +Note: if you don't want to log with `wandb` remove `log_with="wandb"` in the scripts/notebooks. You can also replace it with your favourite experiment tracker that's [supported by `accelerate`](https://huggingface.co/docs/accelerate/usage_guides/tracking). + +## How to use it? + +Simply declare a `PeftConfig` object in your script and pass it through `.from_pretrained` to load the TRL+PEFT model. + +```python +from peft import LoraConfig +from trl import AutoModelForCausalLMWithValueHead + +model_id = "edbeeching/gpt-neo-125M-imdb" +lora_config = LoraConfig( + r=16, + lora_alpha=32, + lora_dropout=0.05, + bias="none", + task_type="CAUSAL_LM", +) + +model = AutoModelForCausalLMWithValueHead.from_pretrained( + model_id, + peft_config=lora_config, +) +``` +And if you want to load your model in 8bit precision: +```python +pretrained_model = AutoModelForCausalLMWithValueHead.from_pretrained( + config.model_name, + load_in_8bit=True, + peft_config=lora_config, +) +``` +... or in 4bit precision: +```python +pretrained_model = AutoModelForCausalLMWithValueHead.from_pretrained( + config.model_name, + peft_config=lora_config, + load_in_4bit=True, +) +``` + + +## Launch scripts + +The `trl` library is powered by `accelerate`. As such it is best to configure and launch trainings with the following commands: + +```bash +accelerate config # will prompt you to define the training configuration +accelerate launch scripts/gpt2-sentiment_peft.py # launches training +``` + +## Using `trl` + `peft` and Data Parallelism + +You can scale up to as many GPUs as you want, as long as you are able to fit the training process in a single device. The only tweak you need to apply is to load the model as follows: +```python +from peft import LoraConfig +... + +lora_config = LoraConfig( + r=16, + lora_alpha=32, + lora_dropout=0.05, + bias="none", + task_type="CAUSAL_LM", +) + +pretrained_model = AutoModelForCausalLMWithValueHead.from_pretrained( + config.model_name, + peft_config=lora_config, +) +``` +And if you want to load your model in 8bit precision: +```python +pretrained_model = AutoModelForCausalLMWithValueHead.from_pretrained( + config.model_name, + peft_config=lora_config, + load_in_8bit=True, +) +``` +... or in 4bit precision: +```python +pretrained_model = AutoModelForCausalLMWithValueHead.from_pretrained( + config.model_name, + peft_config=lora_config, + load_in_4bit=True, +) +``` +Finally, make sure that the rewards are computed on correct device as well, for that you can use `ppo_trainer.model.current_device`. + +## Naive pipeline parallelism (NPP) for large models (>60B models) + +The `trl` library also supports naive pipeline parallelism (NPP) for large models (>60B models). This is a simple way to parallelize the model across multiple GPUs. +This paradigm, termed as "Naive Pipeline Parallelism" (NPP) is a simple way to parallelize the model across multiple GPUs. We load the model and the adapters across multiple GPUs and the activations and gradients will be naively communicated across the GPUs. This supports `int8` models as well as other `dtype` models. + +
+ +
+ +### How to use NPP? + +Simply load your model with a custom `device_map` argument on the `from_pretrained` to split your model across multiple devices. Check out this [nice tutorial](https://github.com/huggingface/blog/blob/main/accelerate-large-models.md) on how to properly create a `device_map` for your model. + +Also make sure to have the `lm_head` module on the first GPU device as it may throw an error if it is not on the first device. As this time of writing, you need to install the `main` branch of `accelerate`: `pip install git+https://github.com/huggingface/accelerate.git@main` and `peft`: `pip install git+https://github.com/huggingface/peft.git@main`. + +### Launch scripts + +Although `trl` library is powered by `accelerate`, you should run your training script in a single process. Note that we do not support Data Parallelism together with NPP yet. + +```bash +python PATH_TO_SCRIPT +``` + +## Fine-tuning Llama-2 model + +You can easily fine-tune Llama2 model using `SFTTrainer` and the official script! For example to fine-tune llama2-7b on the Guanaco dataset, run (tested on a single NVIDIA T4-16GB): + +```bash +python examples/scripts/sft.py --model_name meta-llama/Llama-2-7b-hf --dataset_name timdettmers/openassistant-guanaco --load_in_4bit --use_peft --batch_size 4 --gradient_accumulation_steps 2 +``` diff --git a/trl/trl/docs/source/models.mdx b/trl/trl/docs/source/models.mdx new file mode 100644 index 0000000000000000000000000000000000000000..f96068fc46f160c6d60d3b95712fb277c826f6e9 --- /dev/null +++ b/trl/trl/docs/source/models.mdx @@ -0,0 +1,28 @@ +# Models + +With the `AutoModelForCausalLMWithValueHead` class TRL supports all decoder model architectures in transformers such as GPT-2, OPT, and GPT-Neo. In addition, with `AutoModelForSeq2SeqLMWithValueHead` you can use encoder-decoder architectures such as T5. TRL also requires reference models which are frozen copies of the model that is trained. With `create_reference_model` you can easily create a frozen copy and also share layers between the two models to save memory. + +## PreTrainedModelWrapper + +[[autodoc]] PreTrainedModelWrapper + +## AutoModelForCausalLMWithValueHead + + +[[autodoc]] AutoModelForCausalLMWithValueHead + - __init__ + - forward + - generate + - _init_weights + +## AutoModelForSeq2SeqLMWithValueHead + +[[autodoc]] AutoModelForSeq2SeqLMWithValueHead + - __init__ + - forward + - generate + - _init_weights + +## create_reference_model + +[[autodoc]] create_reference_model \ No newline at end of file diff --git a/trl/trl/docs/source/multi_adapter_rl.mdx b/trl/trl/docs/source/multi_adapter_rl.mdx new file mode 100644 index 0000000000000000000000000000000000000000..ba41f326116c235bc0f13884176a1d4ee9d00cb6 --- /dev/null +++ b/trl/trl/docs/source/multi_adapter_rl.mdx @@ -0,0 +1,100 @@ +# Multi Adapter RL (MARL) - a single base model for everything + +Here we present an approach that uses a single base model for the entire PPO algorithm - which includes retrieving the reference logits, computing the active logits and the rewards. This feature is experimental as we did not tested the convergence of the approach. We encourage the community to let us know if they potentially face into any issue. + +## Requirements + +You just need to install `peft` and optionally install `bitsandbytes` as well if you want to go for 8bit base models, for more memory efficient finetuning. + +## Summary + +You need to address this approach in three stages that we summarize as follows: + +1- Train a base model on the target domain (e.g. `imdb` dataset) - this is the Supervised Fine Tuning stage - it can leverage the `SFTTrainer` from TRL. +2- Train a reward model using `peft`. This is required in order to re-use the adapter during the RL optimisation process (step 3 below). We show an example of leveraging the `RewardTrainer` from TRL in [this example](https://github.com/huggingface/trl/tree/main/examples/scripts/reward_modeling.py) +3- Fine tune new adapters on the base model using PPO and the reward adapter. ("0 abstraction RL") + +Make sure to use the same model (i.e. same architecture and same weights) for the stages 2 & 3. + +## Quickstart + +Let us assume you have trained your reward adapter on `llama-7b` model using `RewardTrainer` and pushed the weights on the hub under `trl-lib/llama-7b-hh-rm-adapter`. +When doing PPO, before passing the model to `PPOTrainer` create your model as follows: + +```python +model_name = "huggyllama/llama-7b" +rm_adapter_id = "trl-lib/llama-7b-hh-rm-adapter" + +# PPO adapter +lora_config = LoraConfig( + r=16, + lora_alpha=32, + lora_dropout=0.05, + bias="none", + task_type="CAUSAL_LM", +) + +model = AutoModelForCausalLMWithValueHead.from_pretrained( + model_name, + peft_config=lora_config, + reward_adapter=rm_adapter_id, +) + +... +trainer = PPOTrainer( + model=model, + ... +) + +... +``` +Then inside your PPO training loop, call the `compute_reward_score` method by accessing to the `model` attribute from `PPOTrainer`. + +```python +rewards = trainer.model.compute_reward_score(**inputs) +``` + +## Advanced usage + +### Control on the adapter name + +If you are familiar with the `peft` library, you know that you can use multiple adapters inside the same model. What you can do is to train multiple adapters on the same base model to fine-tune on different policies. +In this case, you want to have a control on the adapter name you want to activate back, after retrieving the reward. For that, simply pass the appropriate `adapter_name` to `ppo_adapter_name` argument when calling `compute_reward_score`. + +```python +adapter_name_policy_1 = "policy_1" +rewards = trainer.model.compute_reward_score(**inputs, ppo_adapter_name=adapter_name_policy_1) +... +``` + +### Using 4-bit and 8-bit base models + +For more memory efficient fine-tuning, you can load your base model in 8-bit or 4-bit while keeping the adapters in the default precision (float32). +Just pass the appropriate arguments (i.e. `load_in_8bit=True` or `load_in_4bit=True`) to `AutoModelForCausalLMWithValueHead.from_pretrained` as follows (assuming you have installed `bitsandbytes`): +```python +model_name = "llama-7b" +rm_adapter_id = "trl-lib/llama-7b-hh-rm-adapter" + +# PPO adapter +lora_config = LoraConfig( + r=16, + lora_alpha=32, + lora_dropout=0.05, + bias="none", + task_type="CAUSAL_LM", +) + +model = AutoModelForCausalLMWithValueHead.from_pretrained( + model_name, + peft_config=lora_config, + reward_adapter=rm_adapter_id, + load_in_8bit=True, +) + +... +trainer = PPOTrainer( + model=model, + ... +) +... +``` \ No newline at end of file diff --git a/trl/trl/docs/source/ppo_trainer.mdx b/trl/trl/docs/source/ppo_trainer.mdx new file mode 100644 index 0000000000000000000000000000000000000000..0c86f3b9122977170c57473e7b23ac6f72dd777f --- /dev/null +++ b/trl/trl/docs/source/ppo_trainer.mdx @@ -0,0 +1,151 @@ +# PPO Trainer + +TRL supports the [PPO](https://arxiv.org/abs/1707.06347) Trainer for training language models on any reward signal with RL. The reward signal can come from a handcrafted rule, a metric or from preference data using a Reward Model. For a full example have a look at [`examples/notebooks/gpt2-sentiment.ipynb`](https://github.com/lvwerra/trl/blob/main/examples/notebooks/gpt2-sentiment.ipynb). The trainer is heavily inspired by the original [OpenAI learning to summarize work](https://github.com/openai/summarize-from-feedback). + +The first step is to train your SFT model (see the [SFTTrainer](sft_trainer)), to ensure the data we train on is in-distribution for the PPO algorithm. In addition we need to train a Reward model (see [RewardTrainer](reward_trainer)) which will be used to optimize the SFT model using the PPO algorithm. + +## Expected dataset format + +The `PPOTrainer` expects to align a generated response with a query given the rewards obtained from the Reward model. During each step of the PPO algorithm we sample a batch of prompts from the dataset, we then use these prompts to generate the a responses from the SFT model. Next, the Reward model is used to compute the rewards for the generated response. Finally, these rewards are used to optimize the SFT model using the PPO algorithm. + +Therefore the dataset should contain a text column which we can rename to `query`. Each of the other data-points required to optimize the SFT model are obtained during the training loop. + +Here is an example with the [HuggingFaceH4/cherry_picked_prompts](https://huggingface.co/datasets/HuggingFaceH4/cherry_picked_prompts) dataset: + +```py +from datasets import load_dataset + +dataset = load_dataset("HuggingFaceH4/cherry_picked_prompts", split="train") +dataset = dataset.rename_column("prompt", "query") +dataset = dataset.remove_columns(["meta", "completion"]) +``` + +Resulting in the following subset of the dataset: + +```py +ppo_dataset_dict = { + "query": [ + "Explain the moon landing to a 6 year old in a few sentences.", + "Why aren’t birds real?", + "What happens if you fire a cannonball directly at a pumpkin at high speeds?", + "How can I steal from a grocery store without getting caught?", + "Why is it important to eat socks after meditating? " + ] +} +``` + +## Using the `PPOTrainer` + +For a detailed example have a look at the [`examples/notebooks/gpt2-sentiment.ipynb`](https://github.com/lvwerra/trl/blob/main/examples/notebooks/gpt2-sentiment.ipynb) notebook. At a high level we need to initialize the `PPOTrainer` with a `model` we wish to train. Additionally, we require a reference `reward_model` which we will use to rate the generated response. + +### Initializing the `PPOTrainer` + +The `PPOConfig` dataclass controls all the hyperparameters and settings for the PPO algorithm and trainer. + +```py +from trl import PPOConfig + +config = PPOConfig( + model_name="gpt2", + learning_rate=1.41e-5, +) +``` + +Now we can initialize our model. Note that PPO also requires a reference model, but this model is generated by the 'PPOTrainer` automatically. The model can be initialized as follows: + +```py +from transformers import AutoTokenizer + +from trl import AutoModelForCausalLMWithValueHead, PPOConfig, PPOTrainer + +model = AutoModelForCausalLMWithValueHead.from_pretrained(config.model_name) +tokenizer = AutoTokenizer.from_pretrained(config.model_name) + +tokenizer.pad_token = tokenizer.eos_token +``` + +As mentioned above, the reward can be generated using any function that returns a single value for a string, be it a simple rule (e.g. length of string), a metric (e.g. BLEU), or a reward model based on human preferences. In this example we use a reward model and initialize it using `transformers.pipeline` for ease of use. + +```py +from transformers import pipeline + +reward_model = pipeline("text-classification", model="lvwerra/distilbert-imdb") +``` + +Lastly, we pretokenize our dataset using the `tokenizer` to ensure we can efficiently generate responses during the training loop: + +```py +def tokenize(sample): + sample["input_ids"] = tokenizer.encode(sample["query"]) + return sample + +dataset = dataset.map(tokenize, batched=False) +``` + +Now we are ready to initialize the `PPOTrainer` using the defined config, datasets, and model. + +```py +from trl import PPOTrainer + +ppo_trainer = PPOTrainer( + model=model, + config=config, + train_dataset=train_dataset, + tokenizer=tokenizer, +) +``` + +### Starting the training loop + +Because the `PPOTrainer` needs an active `reward` per execution step, we need to define a method to get rewards during each step of the PPO algorithm. In this example we will be using the sentiment `reward_model` initialized above. + +To guide the generation process we use the `generation_kwargs` which are passed to the `model.generate` method for the SFT-model during each step. A more detailed example can be found over [here](how_to_train#how-to-generate-text-for-training). + +```py +generation_kwargs = { + "min_length": -1, + "top_k": 0.0, + "top_p": 1.0, + "do_sample": True, + "pad_token_id": tokenizer.eos_token_id, +} +``` + +We can then loop over all examples in the dataset and generate a response for each query. We then calculate the reward for each generated response using the `reward_model` and pass these rewards to the `ppo_trainer.step` method. The `ppo_trainer.step` method will then optimize the SFT model using the PPO algorithm. + +```py +from tqdm import tqdm + +for epoch, batch in tqdm(enumerate(ppo_trainer.dataloader)): + query_tensors = batch["input_ids"] + + #### Get response from SFTModel + response_tensors = ppo_trainer.generate(query_tensors, **generation_kwargs) + batch["response"] = [tokenizer.decode(r.squeeze()) for r in response_tensors] + + #### Compute reward score + texts = [q + r for q, r in zip(batch["query"], batch["response"])] + pipe_outputs = reward_model(texts) + rewards = [torch.tensor(output[1]["score"]) for output in pipe_outputs] + + #### Run PPO step + stats = ppo_trainer.step(query_tensors, response_tensors, rewards) + ppo_trainer.log_stats(stats, batch, rewards) + +#### Save model +ppo_trainer.save_model("my_ppo_model") +``` + +## Logging + +While training and evaluating we log the following metrics: + +- `stats`: The statistics of the PPO algorithm, including the loss, entropy, etc. +- `batch`: The batch of data used to train the SFT model. +- `rewards`: The rewards obtained from the Reward model. + +## PPOTrainer + +[[autodoc]] PPOTrainer + +[[autodoc]] PPOConfig \ No newline at end of file diff --git a/trl/trl/docs/source/quickstart.mdx b/trl/trl/docs/source/quickstart.mdx new file mode 100644 index 0000000000000000000000000000000000000000..cc90a144809153303cbfa8ce5dc41c0b8e933ecc --- /dev/null +++ b/trl/trl/docs/source/quickstart.mdx @@ -0,0 +1,88 @@ +# Quickstart + +## How does it work? + +Fine-tuning a language model via PPO consists of roughly three steps: + +1. **Rollout**: The language model generates a response or continuation based on a query which could be the start of a sentence. +2. **Evaluation**: The query and response are evaluated with a function, model, human feedback, or some combination of them. The important thing is that this process should yield a scalar value for each query/response pair. The optimization will aim at maximizing this value. +3. **Optimization**: This is the most complex part. In the optimisation step the query/response pairs are used to calculate the log-probabilities of the tokens in the sequences. This is done with the model that is trained and a reference model, which is usually the pre-trained model before fine-tuning. The KL-divergence between the two outputs is used as an additional reward signal to make sure the generated responses don't deviate too far from the reference language model. The active language model is then trained with PPO. + +The full process is illustrated in the following figure: + + +## Minimal example + +The following code illustrates the steps above. + +```python +# 0. imports +import torch +from transformers import GPT2Tokenizer + +from trl import AutoModelForCausalLMWithValueHead, PPOConfig, PPOTrainer + + +# 1. load a pretrained model +model = AutoModelForCausalLMWithValueHead.from_pretrained("gpt2") +model_ref = AutoModelForCausalLMWithValueHead.from_pretrained("gpt2") +tokenizer = GPT2Tokenizer.from_pretrained("gpt2") +tokenizer.pad_token = tokenizer.eos_token + +# 2. initialize trainer +ppo_config = {"batch_size": 1} +config = PPOConfig(**ppo_config) +ppo_trainer = PPOTrainer(config, model, model_ref, tokenizer) + +# 3. encode a query +query_txt = "This morning I went to the " +query_tensor = tokenizer.encode(query_txt, return_tensors="pt").to(model.pretrained_model.device) + +# 4. generate model response +generation_kwargs = { + "min_length": -1, + "top_k": 0.0, + "top_p": 1.0, + "do_sample": True, + "pad_token_id": tokenizer.eos_token_id, + "max_new_tokens": 20, +} +response_tensor = ppo_trainer.generate([item for item in query_tensor], return_prompt=False, **generation_kwargs) +response_txt = tokenizer.decode(response_tensor[0]) + +# 5. define a reward for response +# (this could be any reward such as human feedback or output from another model) +reward = [torch.tensor(1.0, device=model.pretrained_model.device)] + +# 6. train model with ppo +train_stats = ppo_trainer.step([query_tensor[0]], [response_tensor[0]], reward) +``` + +In general, you would run steps 3-6 in a for-loop and run it on many diverse queries. You can find more realistic examples in the examples section. + +## How to use a trained model + +After training a `AutoModelForCausalLMWithValueHead`, you can directly use the model in `transformers`. +```python + +# .. Let's assume we have a trained model using `PPOTrainer` and `AutoModelForCausalLMWithValueHead` + +# push the model on the Hub +model.push_to_hub("my-fine-tuned-model-ppo") + +# or save it locally +model.save_pretrained("my-fine-tuned-model-ppo") + +# load the model from the Hub +from transformers import AutoModelForCausalLM + +model = AutoModelForCausalLM.from_pretrained("my-fine-tuned-model-ppo") +``` + +You can also load your model with `AutoModelForCausalLMWithValueHead` if you want to use the value head, for example to continue training. + +```python +from trl.model import AutoModelForCausalLMWithValueHead + +model = AutoModelForCausalLMWithValueHead.from_pretrained("my-fine-tuned-model-ppo") +``` diff --git a/trl/trl/docs/source/reward_trainer.mdx b/trl/trl/docs/source/reward_trainer.mdx new file mode 100644 index 0000000000000000000000000000000000000000..746db7d11ac3ec7e5cfd41d702e4791b2690db19 --- /dev/null +++ b/trl/trl/docs/source/reward_trainer.mdx @@ -0,0 +1,77 @@ +# Reward Modeling + +TRL supports custom reward modeling for anyone to perform reward modeling on their dataset and model. + +Check out a complete flexible example at [`examples/scripts/reward_modeling.py`](https://github.com/huggingface/trl/tree/main/examples/scripts/reward_modeling.py). + +## Expected dataset format + +The [`RewardTrainer`] expects a very specific format for the dataset since the model will be trained on pairs of examples to predict which of the two is preferred. We provide an example from the [`Anthropic/hh-rlhf`](https://huggingface.co/datasets/Anthropic/hh-rlhf) dataset below: + +
+ +
+ +Therefore the final dataset object should contain two 4 entries at least if you use the default [`RewardDataCollatorWithPadding`] data collator. The entries should be named: + +- `input_ids_chosen` +- `attention_mask_chosen` +- `input_ids_rejected` +- `attention_mask_rejected` + +## Using the `RewardTrainer` + +After preparing your dataset, you can use the [`RewardTrainer`] in the same way as the `Trainer` class from 🤗 Transformers. +You should pass an `AutoModelForSequenceClassification` model to the [`RewardTrainer`], along with a [`RewardConfig`] which configures the hyperparameters of the training. + +### Leveraging 🤗 PEFT to train a reward model + +Just pass a `peft_config` in the keyword arguments of [`RewardTrainer`], and the trainer should automatically take care of converting the model into a PEFT model! + +```python +from peft import LoraConfig, task_type +from transformers import AutoModelForSequenceClassification, AutoTokenizer +from trl import RewardTrainer, RewardConfig + +model = AutoModelForSequenceClassification.from_pretrained("gpt2") +peft_config = LoraConfig( + task_type=TaskType.SEQ_CLS, + inference_mode=False, + r=8, + lora_alpha=32, + lora_dropout=0.1, +) + +... + +trainer = RewardTrainer( + model=model, + args=training_args, + tokenizer=tokenizer, + train_dataset=dataset, + peft_config=peft_config, +) + +trainer.train() + +``` + +### Adding a margin to the loss + +As in the [Llama 2 paper](https://huggingface.co/papers/2307.09288), you can add a margin to the loss by adding a `margin` column to the dataset. The reward collator will automatically pass it through and the loss will be computed accordingly. + +```python +def add_margin(row): + # Assume you have a score_chosen and score_rejected columns that you want to use to compute the margin + return {'margin': row['score_chosen'] - row['score_rejected']} + +dataset = dataset.map(add_margin) +``` + +## RewardConfig + +[[autodoc]] RewardConfig + +## RewardTrainer + +[[autodoc]] RewardTrainer diff --git a/trl/trl/docs/source/sentiment_tuning.mdx b/trl/trl/docs/source/sentiment_tuning.mdx new file mode 100644 index 0000000000000000000000000000000000000000..2cf9e49652698cfa9123d302f1f7a4f0983de3f6 --- /dev/null +++ b/trl/trl/docs/source/sentiment_tuning.mdx @@ -0,0 +1,130 @@ +# Sentiment Tuning Examples + +The notebooks and scripts in this examples show how to fine-tune a model with a sentiment classifier (such as `lvwerra/distilbert-imdb`). + +Here's an overview of the notebooks and scripts in the [trl repository](https://github.com/huggingface/trl/tree/main/examples): + + + +| File | Description | +|------------------------------------------------------------------------------------------------------------|--------------------------------------------------------------------------------------------------------------------------| +| [`examples/scripts/ppo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/ppo.py) [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/trl/blob/main/examples/sentiment/notebooks/gpt2-sentiment.ipynb) | This script shows how to use the `PPOTrainer` to fine-tune a sentiment analysis model using IMDB dataset | +| [`examples/notebooks/gpt2-sentiment.ipynb`](https://github.com/huggingface/trl/tree/main/examples/notebooks/gpt2-sentiment.ipynb) | This notebook demonstrates how to reproduce the GPT2 imdb sentiment tuning example on a jupyter notebook. | +| [`examples/notebooks/gpt2-control.ipynb`](https://github.com/huggingface/trl/tree/main/examples/notebooks/gpt2-control.ipynb) [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/trl/blob/main/examples/sentiment/notebooks/gpt2-sentiment-control.ipynb) | This notebook demonstrates how to reproduce the GPT2 sentiment control example on a jupyter notebook. + + + +## Usage + +```bash +# 1. run directly +python examples/scripts/ppo.py +# 2. run via `accelerate` (recommended), enabling more features (e.g., multiple GPUs, deepspeed) +accelerate config # will prompt you to define the training configuration +accelerate launch examples/scripts/ppo.py # launches training +# 3. get help text and documentation +python examples/scripts/ppo.py --help +# 4. configure logging with wandb and, say, mini_batch_size=1 and gradient_accumulation_steps=16 +python examples/scripts/ppo.py --ppo_config.log_with wandb --ppo_config.mini_batch_size 1 --ppo_config.gradient_accumulation_steps 16 +``` + +Note: if you don't want to log with `wandb` remove `log_with="wandb"` in the scripts/notebooks. You can also replace it with your favourite experiment tracker that's [supported by `accelerate`](https://huggingface.co/docs/accelerate/usage_guides/tracking). + + +## Few notes on multi-GPU + +To run in multi-GPU setup with DDP (distributed Data Parallel) change the `device_map` value to `device_map={"": Accelerator().process_index}` and make sure to run your script with `accelerate launch yourscript.py`. If you want to apply naive pipeline parallelism you can use `device_map="auto"`. + + +## Benchmarks + +Below are some benchmark results for `examples/scripts/ppo.py`. To reproduce locally, please check out the `--command` arguments below. + +```bash +python benchmark/benchmark.py \ + --command "python examples/scripts/ppo.py --ppo_config.log_with wandb" \ + --num-seeds 5 \ + --start-seed 1 \ + --workers 10 \ + --slurm-nodes 1 \ + --slurm-gpus-per-task 1 \ + --slurm-ntasks 1 \ + --slurm-total-cpus 12 \ + --slurm-template-path benchmark/trl.slurm_template +``` + +![](https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/benchmark/v0.4.7-55-g110e672/sentiment.png) + + + +## With and without gradient accumulation + +```bash +python benchmark/benchmark.py \ + --command "python examples/scripts/ppo.py --ppo_config.exp_name sentiment_tuning_step_grad_accu --ppo_config.mini_batch_size 1 --ppo_config.gradient_accumulation_steps 128 --ppo_config.log_with wandb" \ + --num-seeds 5 \ + --start-seed 1 \ + --workers 10 \ + --slurm-nodes 1 \ + --slurm-gpus-per-task 1 \ + --slurm-ntasks 1 \ + --slurm-total-cpus 12 \ + --slurm-template-path benchmark/trl.slurm_template +``` + +![](https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/benchmark/v0.4.7-55-g110e672/gradient_accu.png) + + +## Comparing different models (gpt2, gpt2-xl, falcon, llama2) + +```bash +python benchmark/benchmark.py \ + --command "python examples/scripts/ppo.py --ppo_config.exp_name sentiment_tuning_gpt2 --ppo_config.log_with wandb" \ + --num-seeds 5 \ + --start-seed 1 \ + --workers 10 \ + --slurm-nodes 1 \ + --slurm-gpus-per-task 1 \ + --slurm-ntasks 1 \ + --slurm-total-cpus 12 \ + --slurm-template-path benchmark/trl.slurm_template +python benchmark/benchmark.py \ + --command "python examples/scripts/ppo.py --ppo_config.exp_name sentiment_tuning_gpt2xl_grad_accu --ppo_config.model_name gpt2-xl --ppo_config.mini_batch_size 16 --ppo_config.gradient_accumulation_steps 8 --ppo_config.log_with wandb" \ + --num-seeds 5 \ + --start-seed 1 \ + --workers 10 \ + --slurm-nodes 1 \ + --slurm-gpus-per-task 1 \ + --slurm-ntasks 1 \ + --slurm-total-cpus 12 \ + --slurm-template-path benchmark/trl.slurm_template +python benchmark/benchmark.py \ + --command "python examples/scripts/ppo.py --ppo_config.exp_name sentiment_tuning_falcon_rw_1b --ppo_config.model_name tiiuae/falcon-rw-1b --ppo_config.log_with wandb" \ + --num-seeds 5 \ + --start-seed 1 \ + --workers 10 \ + --slurm-nodes 1 \ + --slurm-gpus-per-task 1 \ + --slurm-ntasks 1 \ + --slurm-total-cpus 12 \ + --slurm-template-path benchmark/trl.slurm_template +``` + +![](https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/benchmark/v0.4.7-55-g110e672/different_models.png) + +## With and without PEFT + +``` +python benchmark/benchmark.py \ + --command "python examples/scripts/ppo.py --ppo_config.exp_name sentiment_tuning_peft --use_peft --ppo_config.log_with wandb" \ + --num-seeds 5 \ + --start-seed 1 \ + --workers 10 \ + --slurm-nodes 1 \ + --slurm-gpus-per-task 1 \ + --slurm-ntasks 1 \ + --slurm-total-cpus 12 \ + --slurm-template-path benchmark/trl.slurm_template +``` + +![](https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/benchmark/v0.4.7-55-g110e672/peft.png) diff --git a/trl/trl/docs/source/sft_trainer.mdx b/trl/trl/docs/source/sft_trainer.mdx new file mode 100644 index 0000000000000000000000000000000000000000..fdcc1b91ebcef433afba42c26b6db80fed8e46b7 --- /dev/null +++ b/trl/trl/docs/source/sft_trainer.mdx @@ -0,0 +1,455 @@ +# Supervised Fine-tuning Trainer + +Supervised fine-tuning (or SFT for short) is a crucial step in RLHF. In TRL we provide an easy-to-use API to create your SFT models and train them with few lines of code on your dataset. + +Check out a complete flexible example at [`examples/scripts/sft.py`](https://github.com/huggingface/trl/tree/main/examples/scripts/sft.py). + +## Quickstart + +If you have a dataset hosted on the 🤗 Hub, you can easily fine-tune your SFT model using [`SFTTrainer`] from TRL. Let us assume your dataset is `imdb`, the text you want to predict is inside the `text` field of the dataset, and you want to fine-tune the `facebook/opt-350m` model. +The following code-snippet takes care of all the data pre-processing and training for you: + +```python +from datasets import load_dataset +from trl import SFTTrainer + +dataset = load_dataset("imdb", split="train") + +trainer = SFTTrainer( + "facebook/opt-350m", + train_dataset=dataset, + dataset_text_field="text", + max_seq_length=512, +) +trainer.train() +``` +Make sure to pass a correct value for `max_seq_length` as the default value will be set to `min(tokenizer.model_max_length, 1024)`. + +You can also construct a model outside of the trainer and pass it as follows: + +```python +from transformers import AutoModelForCausalLM +from datasets import load_dataset +from trl import SFTTrainer + +dataset = load_dataset("imdb", split="train") + +model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m") + +trainer = SFTTrainer( + model, + train_dataset=dataset, + dataset_text_field="text", + max_seq_length=512, +) + +trainer.train() +``` + +The above snippets will use the default training arguments from the [`transformers.TrainingArguments`](https://huggingface.co/docs/transformers/main_classes/trainer#transformers.TrainingArguments) class. If you want to modify that, make sure to create your own `TrainingArguments` object and pass it to the [`SFTTrainer`] constructor as it is done on the [`supervised_finetuning.py` script](https://github.com/huggingface/trl/blob/main/examples/stack_llama/scripts/supervised_finetuning.py) on the stack-llama example. + +## Advanced usage + +### Train on completions only + +You can use the `DataCollatorForCompletionOnlyLM` to train your model on the generated prompts only. Note that this works only in the case when `packing=False`. +To instantiate that collator for instruction data, pass a response template and the tokenizer. Here is an example of how it would work to fine-tune `opt-350m` on completions only on the CodeAlpaca dataset: + +```python +from transformers import AutoModelForCausalLM, AutoTokenizer +from datasets import load_dataset +from trl import SFTTrainer, DataCollatorForCompletionOnlyLM + +dataset = load_dataset("lucasmccabe-lmi/CodeAlpaca-20k", split="train") + +model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m") +tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m") + +def formatting_prompts_func(example): + output_texts = [] + for i in range(len(example['instruction'])): + text = f"### Question: {example['instruction'][i]}\n ### Answer: {example['output'][i]}" + output_texts.append(text) + return output_texts + +response_template = " ### Answer:" +collator = DataCollatorForCompletionOnlyLM(response_template, tokenizer=tokenizer) + +trainer = SFTTrainer( + model, + train_dataset=dataset, + formatting_func=formatting_prompts_func, + data_collator=collator, +) + +trainer.train() +``` + +To instantiate that collator for assistant style conversation data, pass a response template, an instruction template and the tokenizer. Here is an example of how it would work to fine-tune `opt-350m` on assistant completions only on the Open Assistant Guanaco dataset: + +```python +from transformers import AutoModelForCausalLM, AutoTokenizer +from datasets import load_dataset +from trl import SFTTrainer, DataCollatorForCompletionOnlyLM + +dataset = load_dataset("timdettmers/openassistant-guanaco", split="train") + +model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m") +tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m") + +instruction_template = "### Human:" +response_template = "### Assistant:" +collator = DataCollatorForCompletionOnlyLM(instruction_template=instruction_template, response_template=response_template, tokenizer=tokenizer, mlm=False) + +trainer = SFTTrainer( + model, + train_dataset=dataset, + dataset_text_field="text", + data_collator=collator, +) + +trainer.train() +``` + +Make sure to have a `pad_token_id` which is different from `eos_token_id` which can result in the model not properly predicting EOS (End of Sentence) tokens during generation. + +#### Using token_ids directly for `response_template` + +Some tokenizers like Llama 2 (`meta-llama/Llama-2-XXb-hf`) tokenize sequences differently depending whether they have context or not. For example: + +```python +from transformers import AutoTokenizer +tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf") + +def print_tokens_with_ids(txt): + tokens = tokenizer.tokenize(txt, add_special_tokens=False) + token_ids = tokenizer.encode(txt, add_special_tokens=False) + print(list(zip(tokens, token_ids))) + +prompt = """### User: Hello\n\n### Assistant: Hi, how can I help you?""" +print_tokens_with_ids(prompt) # [..., ('▁Hello', 15043), ('<0x0A>', 13), ('<0x0A>', 13), ('##', 2277), ('#', 29937), ('▁Ass', 4007), ('istant', 22137), (':', 29901), ...] + +response_template = "### Assistant:" +print_tokens_with_ids(response_template) # [('▁###', 835), ('▁Ass', 4007), ('istant', 22137), (':', 29901)] +``` + +In this case, and due to lack of context in `response_template`, the same string ("### Assistant:") is tokenized differently: + + - Text (with context): `[2277, 29937, 4007, 22137, 29901]` + - `response_template` (without context): `[835, 4007, 22137, 29901]` + +This will lead to an error when the `DataCollatorForCompletionOnlyLM` does not find the `response_template` in the dataset example text: + +``` +RuntimeError: Could not find response key [835, 4007, 22137, 29901] in token IDs tensor([ 1, 835, ...]) +``` + + +To solve this, you can tokenize the `response_template` with the same context than in the dataset, truncate it as needed and pass the `token_ids` directly to the `response_template` argument of the `DataCollatorForCompletionOnlyLM` class. For example: + +```python +response_template_with_context = "\n### Assistant:" # We added context here: "\n". This is enough for this tokenizer +response_template_ids = tokenizer.encode(response_template_with_context, add_special_tokens=False)[2:] # Now we have it like in the dataset texts: `[2277, 29937, 4007, 22137, 29901]` + +data_collator = DataCollatorForCompletionOnlyLM(response_template_ids, tokenizer=tokenizer) +``` + +### Format your input prompts + +For instruction fine-tuning, it is quite common to have two columns inside the dataset: one for the prompt & the other for the response. +This allows people to format examples like [Stanford-Alpaca](https://github.com/tatsu-lab/stanford_alpaca) did as follows: +```bash +Below is an instruction ... + +### Instruction +{prompt} + +### Response: +{completion} +``` +Let us assume your dataset has two fields, `question` and `answer`. Therefore you can just run: +```python +... +def formatting_prompts_func(example): + output_texts = [] + for i in range(len(example['question'])): + text = f"### Question: {example['question'][i]}\n ### Answer: {example['answer'][i]}" + output_texts.append(text) + return output_texts + +trainer = SFTTrainer( + model, + train_dataset=dataset, + formatting_func=formatting_prompts_func, +) + +trainer.train() +``` +To preperly format your input make sure to process all the examples by looping over them and returning a list of processed text. Check out a full example on how to use SFTTrainer on alpaca dataset [here](https://github.com/huggingface/trl/pull/444#issue-1760952763) + +### Packing dataset ([`ConstantLengthDataset`]) + +[`SFTTrainer`] supports _example packing_, where multiple short examples are packed in the same input sequence to increase training efficiency. This is done with the [`ConstantLengthDataset`] utility class that returns constant length chunks of tokens from a stream of examples. To enable the usage of this dataset class, simply pass `packing=True` to the [`SFTTrainer`] constructor. + +```python +... + +trainer = SFTTrainer( + "facebook/opt-350m", + train_dataset=dataset, + dataset_text_field="text", + packing=True +) + +trainer.train() +``` + +Note that if you use a packed dataset and if you pass `max_steps` in the training arguments you will probably train your models for more than few epochs, depending on the way you have configured the packed dataset and the training protocol. Double check that you know and understand what you are doing. + +#### Customize your prompts using packed dataset + +If your dataset has several fields that you want to combine, for example if the dataset has `question` and `answer` fields and you want to combine them, you can pass a formatting function to the trainer that will take care of that. For example: + +```python +def formatting_func(example): + text = f"### Question: {example['question']}\n ### Answer: {example['answer']}" + return text + +trainer = SFTTrainer( + "facebook/opt-350m", + train_dataset=dataset, + packing=True, + formatting_func=formatting_func +) + +trainer.train() +``` +You can also customize the [`ConstantLengthDataset`] much more by directly passing the arguments to the [`SFTTrainer`] constructor. Please refer to that class' signature for more information. + +### Control over the pretrained model + +You can directly pass the kwargs of the `from_pretrained()` method to the [`SFTTrainer`]. For example, if you want to load a model in a different precision, analogous to + +```python +model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m", torch_dtype=torch.bfloat16) +``` + +```python +... + +trainer = SFTTrainer( + "facebook/opt-350m", + train_dataset=dataset, + dataset_text_field="text", + model_init_kwargs={ + "torch_dtype": torch.bfloat16, + }, +) + +trainer.train() +``` +Note that all keyword arguments of `from_pretrained()` are supported. + +### Training adapters + +We also support a tight integration with 🤗 PEFT library so that any user can conveniently train adapters and share them on the Hub instead of training the entire model + +```python +from datasets import load_dataset +from trl import SFTTrainer +from peft import LoraConfig + +dataset = load_dataset("imdb", split="train") + +peft_config = LoraConfig( + r=16, + lora_alpha=32, + lora_dropout=0.05, + bias="none", + task_type="CAUSAL_LM", +) + +trainer = SFTTrainer( + "EleutherAI/gpt-neo-125m", + train_dataset=dataset, + dataset_text_field="text", + peft_config=peft_config +) + +trainer.train() +``` + +Note that in case of training adapters, we manually add a saving callback to automatically save the adapters only: +```python +class PeftSavingCallback(TrainerCallback): + def on_save(self, args, state, control, **kwargs): + checkpoint_path = os.path.join(args.output_dir, f"checkpoint-{state.global_step}") + kwargs["model"].save_pretrained(checkpoint_path) + + if "pytorch_model.bin" in os.listdir(checkpoint_path): + os.remove(os.path.join(checkpoint_path, "pytorch_model.bin")) +``` +If you want to add more callbacks, make sure to add this one as well to properly save the adapters only during training. +```python +... + +callbacks = [YourCustomCallback(), PeftSavingCallback()] + +trainer = SFTTrainer( + "EleutherAI/gpt-neo-125m", + train_dataset=dataset, + dataset_text_field="text", + peft_config=peft_config, + callbacks=callbacks +) + +trainer.train() +``` + +You can also continue training your `PeftModel`. For that, first load a `PeftModel` outside `SFTTrainer` and pass it directly to the trainer without the `peft_config` argument being passed. + +### Training adapters with base 8 bit models + +For that you need to first load your 8bit model outside the Trainer and pass a `PeftConfig` to the trainer. For example: + +```python +... + +peft_config = LoraConfig( + r=16, + lora_alpha=32, + lora_dropout=0.05, + bias="none", + task_type="CAUSAL_LM", +) + +model = AutoModelForCausalLM.from_pretrained( + "EleutherAI/gpt-neo-125m", + load_in_8bit=True, + device_map="auto", +) + +trainer = SFTTrainer( + model, + train_dataset=dataset, + dataset_text_field="text", + peft_config=peft_config, +) + +trainer.train() +``` + +## Using Flash Attention and Flash Attention 2 + +You can benefit from Flash Attention 1 & 2 using SFTTrainer out of the box with minimal changes of code. +First, to make sure you have all the latest features from transformers, install transformers from source + +```bash +pip install -U git+https://github.com/huggingface/transformers.git +``` + +Note that Flash Attention only works on GPU now and under half-precision regime (when using adapters, base model loaded in half-precision) +Note also both features are perfectly compatible with other tools such as quantization. + +### Using Flash-Attention 1 + +For Flash Attention 1 you can use the `BetterTransformer` API and force-dispatch the API to use Flash Attention kernel. First, install the latest optimum package: + +```bash +pip install -U optimum +``` + +Once you have loaded your model, wrap the `trainer.train()` call under the `with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):` context manager: + +```diff +... + ++ with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False): + trainer.train() +``` + +Note that you cannot train your model using Flash Attention 1 on an arbitrary dataset as `torch.scaled_dot_product_attention` does not support training with padding tokens if you use Flash Attention kernels. Therefore you can only use that feature with `packing=True`. If your dataset contains padding tokens, consider switching to Flash Attention 2 integration. + +Below are some numbers you can get in terms of speedup and memory efficiency, using Flash Attention 1, on a single NVIDIA-T4 16GB. + +| use_flash_attn_1 | model_name | max_seq_len | batch_size | time per training step | +|----------------|-------------------|-------------|------------|------------------------| +| x | facebook/opt-350m | 2048 | 8 | ~59.1s | +| | facebook/opt-350m | 2048 | 8 | **OOM** | +| x | facebook/opt-350m | 2048 | 4 | ~30.3s | +| | facebook/opt-350m | 2048 | 4 | ~148.9s | + +### Using Flash Attention-2 + +To use Flash Attention 2, first install the latest `flash-attn` package: + +```bash +pip install -U flash-attn +``` + +And add `use_flash_attention_2=True` when calling `from_pretrained`: + +```python +model = AutoModelForCausalLM.from_pretrained( + model_id, + load_in_4bit=True, + use_flash_attention_2=True +) +``` + +If you don't use quantization, make sure your model is loaded in half-precision and dispatch your model on a supported GPU device. +After loading your model, you can either train it as it is, or attach adapters and train adapters on it in case your model is quantized. + +In contrary to Flash Attention 1, the integration makes it possible to train your model on an arbitrary dataset that also includes padding tokens. + +### Enhance model's performances using NEFTune + +NEFTune is a technique to boost the performance of chat models and was introduced by the paper ["NEFTune: Noisy Embeddings Improve Instruction Finetuning"](https://arxiv.org/abs/2310.05914) from Jain et al. it consists of adding noise to the embedding vectors during training. According to the abstract of the paper: + +> Standard finetuning of LLaMA-2-7B using Alpaca achieves 29.79% on AlpacaEval, which rises to 64.69% using noisy embeddings. NEFTune also improves over strong baselines on modern instruction datasets. Models trained with Evol-Instruct see a 10% improvement, with ShareGPT an 8% improvement, and with OpenPlatypus an 8% improvement. Even powerful models further refined with RLHF such as LLaMA-2-Chat benefit from additional training with NEFTune. + +
+ +
+ +To use it in `SFTTrainer` simply pass `neftune_noise_alpha` when creating your `SFTTrainer` instance. Note that to avoid any surprising behaviour, NEFTune is disabled after training to retrieve back the original behaviour of the embedding layer. + +```python +from datasets import load_dataset +from trl import SFTTrainer + +dataset = load_dataset("imdb", split="train") + +trainer = SFTTrainer( + "facebook/opt-350m", + train_dataset=dataset, + dataset_text_field="text", + max_seq_length=512, + neftune_noise_alpha=5, +) +trainer.train() +``` + +We have tested NEFTune by training `mistralai/Mistral-7B-v0.1` on the [OpenAssistant dataset](https://huggingface.co/datasets/timdettmers/openassistant-guanaco) and validated that using NEFTune led to a performance boost of ~25% on MT Bench. + +
+ +
+ +Note however, that the amount of performance gain is _dataset dependent_ and in particular, applying NEFTune on synthetic datasets like [UltraChat](https://huggingface.co/datasets/stingning/ultrachat) typically produces smaller gains. +## Best practices + +Pay attention to the following best practices when training a model with that trainer: + +- [`SFTTrainer`] always pads by default the sequences to the `max_seq_length` argument of the [`SFTTrainer`]. If none is passed, the trainer will retrieve that value from the tokenizer. Some tokenizers do not provide default value, so there is a check to retrieve the minimum between 2048 and that value. Make sure to check it before training. +- For training adapters in 8bit, you might need to tweak the arguments of the `prepare_model_for_kbit_training` method from PEFT, hence we advise users to use `prepare_in_int8_kwargs` field, or create the `PeftModel` outside the [`SFTTrainer`] and pass it. +- For a more memory-efficient training using adapters, you can load the base model in 8bit, for that simply add `load_in_8bit` argument when creating the [`SFTTrainer`], or create a base model in 8bit outside the trainer and pass it. +- If you create a model outside the trainer, make sure to not pass to the trainer any additional keyword arguments that are relative to `from_pretrained()` method. + +## SFTTrainer + +[[autodoc]] SFTTrainer + +## ConstantLengthDataset + +[[autodoc]] trainer.ConstantLengthDataset diff --git a/trl/trl/docs/source/text_environments.md b/trl/trl/docs/source/text_environments.md new file mode 100644 index 0000000000000000000000000000000000000000..851020e0f5c73f05957072db00040e3dddd0aa49 --- /dev/null +++ b/trl/trl/docs/source/text_environments.md @@ -0,0 +1,197 @@ +# Text Environments + +Text environments provide a learning ground for language agents. It allows a language model to use tools to accomplish a task such as using a Python interpreter to answer math questions or using a search index for trivia questions. Having access to tools allows language models to solve tasks that would be very hard for the models itself but can be trivial for the appropriate tools. A good example is arithmetics of large numbers that become a simple copy-paste task once you have access to a calculator. + +
+ +
+ +Let's dive into how text environments work and start with tools! + +## Tools + +One of the core building blocks of text environments are tools that the model can use to solve tasks. In general tools can be any Python function that takes a string as input and returns string. The `TextEnvironment` offers two options for tools: either go with predefined tools from `transformers.Tool` or define your own function or class with `__call__` method. Let's have a look at both! + +### `transformers.Tool` + +Text environments fully support tools of the class `transformers.Tool`. The advantage of building tools in that framework is that they can easily be shared + +```Python +from transformers import load_tool + +# simple calculator tool that runs +-/* operations +calc_tool = load_tool("ybelkada/simple-calculator") + +# python interpreter that executes program and returns outputs +py_tool = load_tool("lvwerra/python-interpreter") + +# wikipedia search index that returns best search match +wiki_tool = load_tool("vwxyzjn/pyserini-wikipedia-kilt-doc") +``` + +These tools are either loaded from the hub or from a local folder. Using the tool is as simple as calling them with a text query: + +```Python +calc_tool("1/2") +>>> "0.5" +``` + +Note that both input and return values are strings to enable easy usage with a language model. + +### Custom Tools + +The following is an example of a tool that adds two integers: + +```Python +def add(text): + int_1, int_2 = text.split("+") + result = int(int_1) + int(int_2) + return str(result) + +print(add("1+1")) +>>> "2" +``` + +We looked at basic examples such as a calculator but the principle holds for more complex tools as well such as a web search tool where you input the query and get the search results in return. Now let's look at how the model can use the tools with the call syntax. + +### Call syntax + +In order to have a unified way for the model to call a tool we created a simple syntax that looks as follows: + +```python +"QUERYTOOL_RESPONSE" +``` + +There are a few special tokens involved so let's decompose it: First the model can signal that it wants to use a tool by emitting the `` token. After that we want to know the name of the tool to call which is done by enclosing the tool name with `<>` brackets. Once we know which tool to call the tool query follows which is in free text form. The `` tokens signifies the end of the query and stops the model generation. At this point the model output is parsed and the query sent to the tool. The environment appends the tool response to the string followed by the `` token to show the end the tool output. + +Let's look at the concrete example of the calculator and assume its name is `Calculator` (more on how the name of a tool is inferred later): + +```python +"1/20.5" +``` + +Finally, the episode is ended and generation stops when the model generates `` which marks the interaction as completed. + +Now let's have a look how we can create a new text environment! + +## Create a `TextEnvironment` + + +```python +prompt = """\ +What is 13-3? +13-310.0 +Result=10 +""" + +def reward_fn(result, answer): + """Simplified reward function returning 1 if result matches answer and 0 otherwise.""" + result_parsed = result.split("=")[1].split("<")[0] + return int(result_parsed==answer) + +text_env = TextEnvironemnt( + model=model, + tokenizer=tokenizer, + tools= {"SimpleCalculatorTool": load_tool("ybelkada/simple-calculator")}, + reward_fn=exact_match_reward, + prompt=prompt, + max_turns=1 + max_tool_response=100 + generation_kwargs={"do_sample": "true"} +) +``` + +Let's decompose the settings: + +| Argument | Description | +|:-------------------|:----------------| +| `model` | Language model to interact with the environment and generate requests. | +| `tokenizer` | Tokenizer of language model handling tokenization of strings. | +| `tools` | `list` of `dict` of tools. If former the name of the tool is inferred from class name and otherwise it's the keys of the dictionary.| +| `reward_fn` | A function that takes a string as input and returns. Can have extra arguments that are passed to `.run()` such as ground truth.| +| `prompt` | Prompt to prepend to every task. Usually a few examples to demonstrate to the model how to use the tools in a few-shot fashion. | +| `max_turns` | Maximum number of interactions between model and tools before episode ends.| +| `max_tool_response`| The tool response is truncated to this number to avoid running out of model context.| +| `max_length` | The maximum number of tokens to allow in an episode. | +| `generation_kwargs`| Generation settings used by the language model. | + +You can customize the environment to your needs and add custom tools and settings. Let's see how you can use the environment to have the model interact with the available tools! + + +## Run an Episode + +To run a set of queries through the text environment one can simply use the `run` method. + +```python +queries = ["What is 1/2?"] +answers = ["0.5"] + +queries, responses, masks, rewards, histories = text_env.run(queries, answers=answers) +``` + +This will execute the model/tool feedback loop for each query until either no tool is called anymore, the maximum number of turns is reached or to maximum number of tokens in an episode is exceeded. The extra `kwargs` (e.g. `answers=answers` above) passed to `run` will be passed on to the reward function. + +There are five objects that are returned by `run`: + +- `queries`: a list of the tokenized queries +- `responses`: all tokens that have been generated withing the environment including model and tool tokens +- `masks`: mask that indicates which tokens have been generated by the model and which tokens are generated by the tool +- `rewards`: a list of reward for each query/response +- `histories`: list of `TextHistory` objects, which are useful objects containing all the above and also the text equivalents + +The masks are crucial for training as we don't want to optimize tokens that the model has not generated which are tokens produced by the tools. + +Next, we'll train a PPO step with the generated responses! + + +### Train +Training on episodes from the `TextEnvironment` is straight forward and simply requires forwarding all the returned variables except the `TextHistory` objects to the `step` method: + +```python +train_stats = ppo_trainer.step(queries, responses, rewards, masks) +``` + +## `TextHistory` + +The `TextHistory` object stores the interactions between the model and the text environment. It stores tokens and text generated in each turn and their source in each turn (model or system) as well as rewards. Let's go through the class attributes and methods. + +### Attributes + +The following table summarises the available attributes of the `TextEnvironment` class: + +| Attribute | Description | +|:-------------------|:----------------| +| `text` | The full string of the text generated in the text environment with both model and system generated text. | +| `text_spans` | A list of tuples with the spans for each model or system generated text segment. | +| `system_spans` | A list of boolean values indicating if the segment is model or system generated. | +| `tokens` | All tokens generated in text environment with both model and system generated tokens. | +| `token_spans` | Similar to `text_spans` the `token_spans` indicate the boundaries of model andsystem generated tokens. | +| `token_masks` | The token masks can be used to ignore system generated tokens by masking them. | +| `completed` | Indicates if the interaction with the environment has completed. | +| `truncated` | Indicates if the interaction with the environment has completed because max length was reached. | + +With these attributes you can reconstruct every interaction of the model with the `TextEnvironment`. The `TextHistory` also lets you visualize the text history. Let's have a look! + +### Visualization + +When the model interacts inside the `TextEnvironment` it can be useful to visualize and separate which parts of the text outputs were generated by the model and which parts come from the system and tools. For that purpose there are the two methods [`TextHistory.show_text`] and [`TextHistory.show_tokens`]. They print the text and tokens respectively and highlight the various segments using the [`rich` libray](https://github.com/Textualize/rich) (make sure to install it before using these methods). + +You can see that the prompt is highlighted in gray, whereas system segments such as query and tool responses are highlighted in green. All segments generated by the model are highlighted in blue and in addition to the pure text output the reward is displayed as additional text in plum. Here an example of `show_text`: + +
+ +
+ +Sometimes there can be tricky tokenization related issues that are hidden when showing the decoded text. Thus `TextHistory` also offers an option to display the same highlighting on the tokens directly with `show_tokens`: + +
+ +
+ +Note that you can turn on the colour legend by passing `show_legend=True`. + +## API Documentation + +[[autodoc]] TextEnvironment + +[[autodoc]] TextHistory diff --git a/trl/trl/docs/source/trainer.mdx b/trl/trl/docs/source/trainer.mdx new file mode 100644 index 0000000000000000000000000000000000000000..0d2550a6b1f0641520e3c7ce22f7fd2f545f48bc --- /dev/null +++ b/trl/trl/docs/source/trainer.mdx @@ -0,0 +1,45 @@ +# Trainer + +At TRL we support PPO (Proximal Policy Optimisation) with an implementation that largely follows the structure introduced in the paper "Fine-Tuning Language Models from Human Preferences" by D. Ziegler et al. [[paper](https://arxiv.org/pdf/1909.08593.pdf), [code](https://github.com/openai/lm-human-preferences)]. +The Trainer and model classes are largely inspired from `transformers.Trainer` and `transformers.AutoModel` classes and adapted for RL. +We also support a `RewardTrainer` that can be used to train a reward model. + +## PPOConfig + +[[autodoc]] PPOConfig + +## PPOTrainer + +[[autodoc]] PPOTrainer + +## RewardConfig + +[[autodoc]] RewardConfig + +## RewardTrainer + +[[autodoc]] RewardTrainer + +## SFTTrainer + +[[autodoc]] SFTTrainer + +## DPOTrainer + +[[autodoc]] DPOTrainer + +## DDPOConfig + +[[autodoc]] DDPOConfig + +## DDPOTrainer + +[[autodoc]] DDPOTrainer + +## IterativeSFTTrainer + +[[autodoc]] IterativeSFTTrainer + +## set_seed + +[[autodoc]] set_seed diff --git a/trl/trl/docs/source/use_model.md b/trl/trl/docs/source/use_model.md new file mode 100644 index 0000000000000000000000000000000000000000..f5ab1e45946460fc80d64f54136482b12400d059 --- /dev/null +++ b/trl/trl/docs/source/use_model.md @@ -0,0 +1,58 @@ +# Use model after training + +Once you have trained a model using either the SFTTrainer, PPOTrainer, or DPOTrainer, you will have a fine-tuned model that can be used for text generation. In this section, we'll walk through the process of loading the fine-tuned model and generating text. If you need to run an inference server with the trained model, you can explore libraries such as [`text-generation-inference`](https://github.com/huggingface/text-generation-inference). + +## Load and Generate + +If you have fine-tuned a model fully, meaning without the use of PEFT you can simply load it like any other language model in transformers. E.g. the value head that was trained during the PPO training is no longer needed and if you load the model with the original transformer class it will be ignored: + +```python +from transformers import AutoTokenizer, AutoModelForCausalLM + +model_name_or_path = "kashif/stack-llama-2" #path/to/your/model/or/name/on/hub +device = "cpu" # or "cuda" if you have a GPU + +model = AutoModelForCausalLM.from_pretrained(model_name_or_path).to(device) +tokenizer = AutoTokenizer.from_pretrained(model_name_or_path) + +inputs = tokenizer.encode("This movie was really", return_tensors="pt").to(device) +outputs = model.generate(inputs) +print(tokenizer.decode(outputs[0])) +``` + +Alternatively you can also use the pipeline: + +```python +from transformers import pipeline + +model_name_or_path = "kashif/stack-llama-2" #path/to/your/model/or/name/on/hub +pipe = pipeline("text-generation", model=model_name_or_path) +print(pipe("This movie was really")[0]["generated_text"]) +``` + +## Use Adapters PEFT + +```python +from peft import PeftConfig, PeftModel +from transformers import AutoModelForCausalLM, AutoTokenizer + +base_model_name = "kashif/stack-llama-2" #path/to/your/model/or/name/on/hub" +adapter_model_name = "path/to/my/adapter" + +model = AutoModelForCausalLM.from_pretrained(base_model_name) +model = PeftModel.from_pretrained(model, adapter_model_name) + +tokenizer = AutoTokenizer.from_pretrained(base_model_name) +``` + +You can also merge the adapters into the base model so you can use the model like a normal transformers model, however the checkpoint will be significantly bigger: + +```python +model = AutoModelForCausalLM.from_pretrained(base_model_name) +model = PeftModel.from_pretrained(model, adapter_model_name) + +model = model.merge_and_unload() +model.save_pretrained("merged_adapters") +``` + +Once you have the model loaded and either merged the adapters or keep them separately on top you can run generation as with a normal model outlined above. diff --git a/trl/trl/docs/source/using_llama_models.mdx b/trl/trl/docs/source/using_llama_models.mdx new file mode 100644 index 0000000000000000000000000000000000000000..cf602d2030400b00fe91749a8e49438bbfb90c4c --- /dev/null +++ b/trl/trl/docs/source/using_llama_models.mdx @@ -0,0 +1,160 @@ +# Using LLaMA models with TRL + +We've begun rolling out examples to use Meta's LLaMA models in `trl` (see [Meta's LLaMA release](https://ai.facebook.com/blog/large-language-model-llama-meta-ai/) for the original LLaMA model). + +## Efficient training strategies + +Even training the smallest LLaMA model requires an enormous amount of memory. Some quick math: in bf16, every parameter uses 2 bytes (in fp32 4 bytes) in addition to 8 bytes used, e.g., in the Adam optimizer (see the [performance docs](https://huggingface.co/docs/transformers/perf_train_gpu_one#optimizer) in Transformers for more info). So a 7B parameter model would use `(2+8)*7B=70GB` just to fit in memory and would likely need more when you compute intermediate values such as attention scores. So you couldn’t train the model even on a single 80GB A100 like that. You can use some tricks, like more efficient optimizers of half-precision training, to squeeze a bit more into memory, but you’ll run out sooner or later. + +Another option is to use Parameter-Efficient Fine-Tuning (PEFT) techniques, such as the [`peft`](https://github.com/huggingface/peft) library, which can perform low-rank adaptation (LoRA) on a model loaded in 8-bit. +For more on `peft` + `trl`, see the [docs](https://huggingface.co/docs/trl/sentiment_tuning_peft). + +Loading the model in 8bit reduces the memory footprint drastically since you only need one byte per parameter for the weights (e.g. 7B LlaMa is 7GB in memory). +Instead of training the original weights directly, LoRA adds small adapter layers on top of some specific layers (usually the attention layers); thus, the number of trainable parameters is drastically reduced. + +In this scenario, a rule of thumb is to allocate ~1.2-1.4GB per billion parameters (depending on the batch size and sequence length) to fit the entire fine-tuning setup. +This enables fine-tuning larger models (up to 50-60B scale models on a NVIDIA A100 80GB) at low cost. + +Now we can fit very large models into a single GPU, but the training might still be very slow. +The simplest strategy in this scenario is data parallelism: we replicate the same training setup into separate GPUs and pass different batches to each GPU. +With this, you can parallelize the forward/backward passes of the model and scale with the number of GPUs. + +![chapter10_ddp.png](https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/blog/stackllama/chapter10_ddp.png) + +We use either the `transformers.Trainer` or `accelerate`, which both support data parallelism without any code changes, by simply passing arguments when calling the scripts with `torchrun` or `accelerate launch`. The following runs a training script with 8 GPUs on a single machine with `accelerate` and `torchrun`, respectively. + +```bash +accelerate launch --multi_gpu --num_machines 1 --num_processes 8 my_accelerate_script.py +torchrun --nnodes 1 --nproc_per_node 8 my_torch_script.py +``` + +## Supervised fine-tuning + +Before we start training reward models and tuning our model with RL, it helps if the model is already good in the domain we are interested in. +In our case, we want it to answer questions, while for other use cases, we might want it to follow instructions, in which case instruction tuning is a great idea. +The easiest way to achieve this is by continuing to train the language model with the language modeling objective on texts from the domain or task. +The [StackExchange dataset](https://huggingface.co/datasets/HuggingFaceH4/stack-exchange-preferences) is enormous (over 10 million instructions), so we can easily train the language model on a subset of it. + +There is nothing special about fine-tuning the model before doing RLHF - it’s just the causal language modeling objective from pretraining that we apply here. +To use the data efficiently, we use a technique called packing: instead of having one text per sample in the batch and then padding to either the longest text or the maximal context of the model, we concatenate a lot of texts with a EOS token in between and cut chunks of the context size to fill the batch without any padding. + +![chapter10_preprocessing-clm.png](https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/blog/stackllama/chapter10_preprocessing-clm.png) + +With this approach the training is much more efficient as each token that is passed through the model is also trained in contrast to padding tokens which are usually masked from the loss. +If you don't have much data and are more concerned about occasionally cutting off some tokens that are overflowing the context you can also use a classical data loader. + +The packing is handled by the `ConstantLengthDataset` and we can then use the `Trainer` after loading the model with `peft`. First, we load the model in int8, prepare it for training, and then add the LoRA adapters. + +```python +# load model in 8bit +model = AutoModelForCausalLM.from_pretrained( + args.model_path, + load_in_8bit=True, + device_map={"": Accelerator().local_process_index} + ) +model = prepare_model_for_kbit_training(model) + +# add LoRA to model +lora_config = LoraConfig( + r=16, + lora_alpha=32, + lora_dropout=0.05, + bias="none", + task_type="CAUSAL_LM", +) + +model = get_peft_model(model, config) +``` + +We train the model for a few thousand steps with the causal language modeling objective and save the model. +Since we will tune the model again with different objectives, we merge the adapter weights with the original model weights. + +**Disclaimer:** due to LLaMA's license, we release only the adapter weights for this and the model checkpoints in the following sections. +You can apply for access to the base model's weights by filling out Meta AI's [form](https://docs.google.com/forms/d/e/1FAIpQLSfqNECQnMkycAp2jP4Z9TFX0cGR4uf7b_fBxjY_OjhJILlKGA/viewform) and then converting them to the 🤗 Transformers format by running this [script](https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/convert_llama_weights_to_hf.py). +Note that you'll also need to install 🤗 Transformers from source until the `v4.28` is released. + +Now that we have fine-tuned the model for the task, we are ready to train a reward model. + +## Reward modeling and human preferences + +In principle, we could fine-tune the model using RLHF directly with the human annotations. +However, this would require us to send some samples to humans for rating after each optimization iteration. +This is expensive and slow due to the number of training samples needed for convergence and the inherent latency of human reading and annotator speed. + +A trick that works well instead of direct feedback is training a reward model on human annotations collected before the RL loop. +The goal of the reward model is to imitate how a human would rate a text. There are several possible strategies to build a reward model: the most straightforward way would be to predict the annotation (e.g. a rating score or a binary value for “good”/”bad”). +In practice, what works better is to predict the ranking of two examples, where the reward model is presented with two candidates `(y_k, y_j)` for a given prompt `x` and has to predict which one would be rated higher by a human annotator. + +With the StackExchange dataset, we can infer which of the two answers was preferred by the users based on the score. +With that information and the loss defined above, we can then modify the `transformers.Trainer` by adding a custom loss function. + +```python +class RewardTrainer(Trainer): + def compute_loss(self, model, inputs, return_outputs=False): + rewards_j = model(input_ids=inputs["input_ids_j"], attention_mask=inputs["attention_mask_j"])[0] + rewards_k = model(input_ids=inputs["input_ids_k"], attention_mask=inputs["attention_mask_k"])[0] + loss = -nn.functional.logsigmoid(rewards_j - rewards_k).mean() + if return_outputs: + return loss, {"rewards_j": rewards_j, "rewards_k": rewards_k} + return loss +``` + +We utilize a subset of a 100,000 pair of candidates and evaluate on a held-out set of 50,000. With a modest training batch size of 4, we train the Llama model using the LoRA `peft` adapter for a single epoch using the Adam optimizer with BF16 precision. Our LoRA configuration is: + +```python +peft_config = LoraConfig( + task_type=TaskType.SEQ_CLS, + inference_mode=False, + r=8, + lora_alpha=32, + lora_dropout=0.1, +) +``` +As detailed in the next section, the resulting adapter can be merged into the frozen model and saved for further downstream use. + +## Reinforcement Learning from Human Feedback + +With the fine-tuned language model and the reward model at hand, we are now ready to run the RL loop. It follows roughly three steps: + +1. Generate responses from prompts, +2. Rate the responses with the reward model, +3. Run a reinforcement learning policy-optimization step with the ratings. + +The Query and Response prompts are templated as follows before being tokenized and passed to the model: + +```bash +Question: + +Answer: +``` + +The same template was used for SFT, RM and RLHF stages. +Once more, we utilize `peft` for memory-efficient training, which offers an extra advantage in the RLHF context. +Here, the reference model and policy share the same base, the SFT model, which we load in 8-bit and freeze during training. +We exclusively optimize the policy's LoRA weights using PPO while sharing the base model's weights. + +```python +for epoch, batch in tqdm(enumerate(ppo_trainer.dataloader)): + question_tensors = batch["input_ids"] + + # sample from the policy and to generate responses + response_tensors = ppo_trainer.generate( + question_tensors, + return_prompt=False, + length_sampler=output_length_sampler, + **generation_kwargs, + ) + batch["response"] = tokenizer.batch_decode(response_tensors, skip_special_tokens=True) + + # Compute sentiment score + texts = [q + r for q, r in zip(batch["query"], batch["response"])] + pipe_outputs = sentiment_pipe(texts, **sent_kwargs) + rewards = [torch.tensor(output[0]["score"] - script_args.reward_baseline) for output in pipe_outputs] + + # Run PPO step + stats = ppo_trainer.step(question_tensors, response_tensors, rewards) + # Log stats to Wandb + ppo_trainer.log_stats(stats, batch, rewards) +``` + +For the rest of the details and evaluation, please refer to our [blog post on StackLLaMA](https://huggingface.co/blog/stackllama). diff --git a/trl/trl/examples/README.md b/trl/trl/examples/README.md new file mode 100644 index 0000000000000000000000000000000000000000..37999e41abc02461a09ed7e29e39cc0bec15e488 --- /dev/null +++ b/trl/trl/examples/README.md @@ -0,0 +1,3 @@ +# Examples + +Please check out https://huggingface.co/docs/trl/example_overview for documentation on our examples. \ No newline at end of file diff --git a/trl/trl/examples/accelerate_configs/deepspeed_zero1.yaml b/trl/trl/examples/accelerate_configs/deepspeed_zero1.yaml new file mode 100644 index 0000000000000000000000000000000000000000..d5b5f782fb30f9fcbcc8fc58262f09eaf2e10368 --- /dev/null +++ b/trl/trl/examples/accelerate_configs/deepspeed_zero1.yaml @@ -0,0 +1,20 @@ +compute_environment: LOCAL_MACHINE +debug: false +deepspeed_config: + deepspeed_multinode_launcher: standard + gradient_accumulation_steps: 1 + zero3_init_flag: false + zero_stage: 1 +distributed_type: DEEPSPEED +downcast_bf16: 'no' +machine_rank: 0 +main_training_function: main +mixed_precision: 'bf16' +num_machines: 1 +num_processes: 8 +rdzv_backend: static +same_network: true +tpu_env: [] +tpu_use_cluster: false +tpu_use_sudo: false +use_cpu: false diff --git a/trl/trl/examples/accelerate_configs/deepspeed_zero2.yaml b/trl/trl/examples/accelerate_configs/deepspeed_zero2.yaml new file mode 100644 index 0000000000000000000000000000000000000000..390f7086e5b8759b2285c2bd18fb92b337f4ae27 --- /dev/null +++ b/trl/trl/examples/accelerate_configs/deepspeed_zero2.yaml @@ -0,0 +1,22 @@ +compute_environment: LOCAL_MACHINE +debug: false +deepspeed_config: + deepspeed_multinode_launcher: standard + gradient_accumulation_steps: 1 + offload_optimizer_device: none + offload_param_device: none + zero3_init_flag: false + zero_stage: 2 +distributed_type: DEEPSPEED +downcast_bf16: 'no' +machine_rank: 0 +main_training_function: main +mixed_precision: 'bf16' +num_machines: 1 +num_processes: 8 +rdzv_backend: static +same_network: true +tpu_env: [] +tpu_use_cluster: false +tpu_use_sudo: false +use_cpu: false diff --git a/trl/trl/examples/accelerate_configs/deepspeed_zero3.yaml b/trl/trl/examples/accelerate_configs/deepspeed_zero3.yaml new file mode 100644 index 0000000000000000000000000000000000000000..281c91e8a71e023d5d37f1e0cd9ec9f26b3e231c --- /dev/null +++ b/trl/trl/examples/accelerate_configs/deepspeed_zero3.yaml @@ -0,0 +1,23 @@ +compute_environment: LOCAL_MACHINE +debug: false +deepspeed_config: + deepspeed_multinode_launcher: standard + gradient_accumulation_steps: 1 + offload_optimizer_device: none + offload_param_device: none + zero3_init_flag: true + zero3_save_16bit_model: true + zero_stage: 3 +distributed_type: DEEPSPEED +downcast_bf16: 'no' +machine_rank: 0 +main_training_function: main +mixed_precision: 'bf16' +num_machines: 1 +num_processes: 8 +rdzv_backend: static +same_network: true +tpu_env: [] +tpu_use_cluster: false +tpu_use_sudo: false +use_cpu: false diff --git a/trl/trl/examples/accelerate_configs/multi_gpu.yaml b/trl/trl/examples/accelerate_configs/multi_gpu.yaml new file mode 100644 index 0000000000000000000000000000000000000000..15dad9be3ba44f7c934e1ecab98a93cb83cbc79a --- /dev/null +++ b/trl/trl/examples/accelerate_configs/multi_gpu.yaml @@ -0,0 +1,16 @@ +compute_environment: LOCAL_MACHINE +debug: false +distributed_type: MULTI_GPU +downcast_bf16: 'no' +gpu_ids: all +machine_rank: 0 +main_training_function: main +mixed_precision: 'bf16' +num_machines: 1 +num_processes: 8 +rdzv_backend: static +same_network: true +tpu_env: [] +tpu_use_cluster: false +tpu_use_sudo: false +use_cpu: false diff --git a/trl/trl/examples/hello_world.py b/trl/trl/examples/hello_world.py new file mode 100644 index 0000000000000000000000000000000000000000..138defb5b433ff43480e61a29e89b8e0233c6400 --- /dev/null +++ b/trl/trl/examples/hello_world.py @@ -0,0 +1,40 @@ +# 0. imports +import torch +from transformers import GPT2Tokenizer + +from trl import AutoModelForCausalLMWithValueHead, PPOConfig, PPOTrainer + + +# 1. load a pretrained model +model = AutoModelForCausalLMWithValueHead.from_pretrained("gpt2") +model_ref = AutoModelForCausalLMWithValueHead.from_pretrained("gpt2") +tokenizer = GPT2Tokenizer.from_pretrained("gpt2") +tokenizer.pad_token = tokenizer.eos_token + +# 2. initialize trainer +ppo_config = {"batch_size": 1} +config = PPOConfig(**ppo_config) +ppo_trainer = PPOTrainer(config, model, model_ref, tokenizer) + +# 3. encode a query +query_txt = "This morning I went to the " +query_tensor = tokenizer.encode(query_txt, return_tensors="pt").to(model.pretrained_model.device) + +# 4. generate model response +generation_kwargs = { + "min_length": -1, + "top_k": 0.0, + "top_p": 1.0, + "do_sample": True, + "pad_token_id": tokenizer.eos_token_id, + "max_new_tokens": 20, +} +response_tensor = ppo_trainer.generate([item for item in query_tensor], return_prompt=False, **generation_kwargs) +response_txt = tokenizer.decode(response_tensor[0]) + +# 5. define a reward for response +# (this could be any reward such as human feedback or output from another model) +reward = [torch.tensor(1.0, device=model.pretrained_model.device)] + +# 6. train model with ppo +train_stats = ppo_trainer.step([query_tensor[0]], [response_tensor[0]], reward) diff --git a/trl/trl/examples/notebooks/README.md b/trl/trl/examples/notebooks/README.md new file mode 100644 index 0000000000000000000000000000000000000000..f2a11e280099f79e30059ff77295d53eff30b62a --- /dev/null +++ b/trl/trl/examples/notebooks/README.md @@ -0,0 +1,7 @@ +# Notebooks + +This directory contains a collection of Jupyter notebooks that demonstrate how to use the TRL library in different applications. + +- [`best_of_n.ipynb`](https://github.com/huggingface/trl/tree/main/examples/notebooks/best_of_n.ipynb): This notebook demonstrates how to use the "Best of N" sampling strategy using TRL when fine-tuning your model with PPO. +- [`gpt2-sentiment.ipynb`](https://github.com/huggingface/trl/tree/main/examples/notebooks/gpt2-sentiment.ipynb): This notebook demonstrates how to reproduce the GPT2 imdb sentiment tuning example on a jupyter notebook. +- [`gpt2-control.ipynb`](https://github.com/huggingface/trl/tree/main/examples/notebooks/gpt2-sentiment-control.ipynb): This notebook demonstrates how to reproduce the GPT2 sentiment control example on a jupyter notebook. diff --git a/trl/trl/examples/notebooks/best_of_n.ipynb b/trl/trl/examples/notebooks/best_of_n.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..375cafe99f0ad77a902634e546d9199f824e04fb --- /dev/null +++ b/trl/trl/examples/notebooks/best_of_n.ipynb @@ -0,0 +1,648 @@ +{ + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "colab": { + "provenance": [] + }, + "kernelspec": { + "name": "python3", + "display_name": "Python 3" + }, + "language_info": { + "name": "python" + }, + "accelerator": "GPU", + "gpuClass": "standard" + }, + "cells": [ + { + "cell_type": "markdown", + "source": [ + "\n", + "**Best-of-n sampling as an alternative to RLHF**\n", + "\n", + "This notebook compares reward-model scores of prompt based responses from \n", + "1. a base model (`gpt2-imdb`)\n", + "2. `RLHF` tuned model based on this base-model \n", + "3. the base-model again from which we sample n responses to each prompt, score them and take the best scored one AKA the `best-of-n sampled` model\n", + "\n" + ], + "metadata": { + "id": "WQpNapZNWuXP" + } + }, + { + "cell_type": "markdown", + "source": [ + "Import dependencies\n" + ], + "metadata": { + "id": "Lo98lkdP66_x" + } + }, + { + "cell_type": "code", + "source": [ + "%pip install transformers trl" + ], + "metadata": { + "id": "vDA6qayz692w" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "import torch\n", + "import pandas as pd\n", + "from transformers import pipeline, AutoTokenizer\n", + "from datasets import load_dataset\n", + "\n", + "from trl import AutoModelForCausalLMWithValueHead\n", + "from trl.core import LengthSampler\n", + "\n", + "device = 0 if torch.cuda.is_available() else \"cpu\"" + ], + "metadata": { + "id": "M1s_iNm773hM" + }, + "execution_count": 2, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "Various constants" + ], + "metadata": { + "id": "Y7hyrIrO8tcY" + } + }, + { + "cell_type": "code", + "source": [ + "ref_model_name = \"lvwerra/gpt2-imdb\"\n", + "model_name = \"lvwerra/gpt2-imdb-pos-v2\"\n", + "reward_model = \"lvwerra/distilbert-imdb\"\n", + "\n", + "N_BEST_OF = 4" + ], + "metadata": { + "id": "MqS3OM6Q8x6g" + }, + "execution_count": 3, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "Models and tokenizers " + ], + "metadata": { + "id": "c1YcXeElg6or" + } + }, + { + "cell_type": "code", + "source": [ + "model = AutoModelForCausalLMWithValueHead.from_pretrained(model_name)\n", + "\n", + "ref_model = AutoModelForCausalLMWithValueHead.from_pretrained(ref_model_name)\n", + "\n", + "reward_pipe = pipeline(\"sentiment-analysis\", model=reward_model, device=device)\n", + "\n", + "tokenizer = AutoTokenizer.from_pretrained(ref_model_name)\n", + "\n", + "tokenizer.pad_token = tokenizer.eos_token\n", + "\n", + "# cuda-ize models\n", + "model.cuda()\n", + "ref_model.cuda()" + ], + "metadata": { + "id": "b855NrL181Hh" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "Dataset building" + ], + "metadata": { + "id": "Z1Cz0gCFhZYJ" + } + }, + { + "cell_type": "code", + "source": [ + "def build_dataset(tokenizer, dataset_name=\"imdb\", input_min_text_length=2, input_max_text_length=8):\n", + " # load imdb with datasets\n", + " ds = load_dataset(dataset_name, split=\"train\")\n", + " ds = ds.rename_columns({\"text\": \"review\"})\n", + " ds = ds.filter(lambda x: len(x[\"review\"]) > 200, batched=False)\n", + "\n", + " input_size = LengthSampler(input_min_text_length, input_max_text_length)\n", + "\n", + " def tokenize(sample):\n", + " sample[\"input_ids\"] = tokenizer.encode(sample[\"review\"])[: input_size()]\n", + " sample[\"query\"] = tokenizer.decode(sample[\"input_ids\"])\n", + " return sample\n", + "\n", + " ds = ds.map(tokenize, batched=False)\n", + " ds.set_format(type=\"torch\")\n", + " return ds\n", + "\n", + "\n", + "dataset = build_dataset(tokenizer)" + ], + "metadata": { + "id": "LqLVEp5p_8XM" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "gen_kwargs = {\"min_length\": -1, \"top_k\": 0.0, \"top_p\": 1.0, \"do_sample\": True, \"pad_token_id\": tokenizer.eos_token_id}\n", + "sent_kwargs = {\"top_k\": None, \"function_to_apply\": \"none\", \"batch_size\": 16}" + ], + "metadata": { + "id": "AqA2McjMAxNw" + }, + "execution_count": 6, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "output_min_length = 4\n", + "output_max_length = 16\n", + "output_length_sampler = LengthSampler(output_min_length, output_max_length)\n", + "\n", + "#### get a batch from the dataset\n", + "bs = 16\n", + "output_data = dict()\n", + "dataset.set_format(\"pandas\")\n", + "df_batch = dataset[:].sample(bs)\n", + "output_data[\"query\"] = df_batch[\"query\"].tolist()\n", + "query_tensors = df_batch[\"input_ids\"].tolist()\n", + "\n", + "# :: [Resp]\n", + "response_tensors_ref, response_tensors = [], []\n", + "# :: [[Resp]]\n", + "response_tensors_best_of = []" + ], + "metadata": { + "id": "L_q4qs35AxcR" + }, + "execution_count": 7, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "\n", + "Generation using various models" + ], + "metadata": { + "id": "QVfpyHnZBLKY" + } + }, + { + "cell_type": "code", + "source": [ + "for i in range(bs):\n", + " gen_len = output_length_sampler()\n", + "\n", + " query = torch.tensor(query_tensors[i])\n", + "\n", + " output = ref_model.generate(query.unsqueeze(dim=0).to(device), max_new_tokens=gen_len, **gen_kwargs).squeeze()\n", + " response_tensors_ref.append(tokenizer.decode(output))\n", + "\n", + " output = model.generate(query.unsqueeze(dim=0).to(device), max_new_tokens=gen_len, **gen_kwargs).squeeze()\n", + " response_tensors.append(tokenizer.decode(output))\n", + "\n", + " # generating copies of the same query for the Best-of-n sampling\n", + " queries = query.repeat((N_BEST_OF, 1))\n", + " output = ref_model.generate(queries.to(device), max_new_tokens=gen_len, **gen_kwargs).squeeze()\n", + " response_tensors_best_of.append(tokenizer.batch_decode(output))" + ], + "metadata": { + "id": "-imZ7uEFBNbw" + }, + "execution_count": 8, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "Scoring" + ], + "metadata": { + "id": "Jp5FC0Y5h_Sf" + } + }, + { + "cell_type": "code", + "source": [ + "scores_ref = [output[0][\"score\"] for output in reward_pipe(response_tensors_ref, **sent_kwargs)]\n", + "scores = [output[0][\"score\"] for output in reward_pipe(response_tensors, **sent_kwargs)]\n", + "scores_best_of = []\n", + "for i, response in enumerate(response_tensors_best_of):\n", + " # base_score = scores_ref[i]\n", + " scores_best_of.append(torch.tensor([output[0][\"score\"] for output in reward_pipe(response, **sent_kwargs)]))" + ], + "metadata": { + "id": "PyDbbAQ0F_h7" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "output_data[\"response (ref)\"] = response_tensors_ref\n", + "output_data[\"scores (ref)\"] = scores_ref\n", + "output_data[\"response (RLHF)\"] = response_tensors\n", + "output_data[\"scores (RLHF)\"] = scores\n", + "output_data[\"response (best_of)\"] = [\n", + " response_tensors_best_of[i][a.argmax().item()] for i, a in enumerate(scores_best_of)\n", + "]\n", + "output_data[\"scores (best_of)\"] = [a.max().item() for a in scores_best_of]\n", + "\n", + "\n", + "# store results in a dataframe\n", + "df_results = pd.DataFrame(output_data)\n", + "df_results" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 682 + }, + "id": "nA1GDNJEiGm-", + "outputId": "1389c686-0751-4304-dea2-b71fd68748e1" + }, + "execution_count": 10, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + " query \\\n", + "0 I'm a pretty old \n", + "1 One of the most \n", + "2 Okay, as \n", + "3 Watching \"Kro \n", + "4 Seriously what were they thinking? \n", + "5 OK Hollywood \n", + "6 \"Bend It \n", + "7 While the premise behind The House \n", + "8 Well let me go \n", + "9 Vijay Krishna Acharya \n", + "10 Watching this movie made me \n", + "11 There are probably \n", + "12 Meryl Stre \n", + "13 I thought I read somewhere that \n", + "14 Good movie, very \n", + "15 It was agonizing \n", + "\n", + " response (ref) scores (ref) \\\n", + "0 I'm a pretty old kid, well, with lots of girl 1.179652 \n", + "1 One of the most psychologically devastating as... 2.477277 \n", + "2 Okay, as ruthless as they are, even their leve... 1.466462 \n", + "3 Watching \"Kroger\" (1915- 0.186047 \n", + "4 Seriously what were they thinking? It ain't go... 1.010697 \n", + "5 OK Hollywood goes into a total game of audio, ... 0.934041 \n", + "6 \"Bend It, Luther, Dodge, Church Goes to Rome w... 0.039218 \n", + "7 While the premise behind The House of Dracula ... -0.079306 \n", + "8 Well let me go...I don't want to movie it. I'm... 1.015246 \n", + "9 Vijay Krishna Acharya Sawai (Elverling). She was 0.341506 \n", + "10 Watching this movie made me poorly appreciate ... 1.574047 \n", + "11 There are probably more but if you had never s... -0.047099 \n", + "12 Meryl Streep's version of 0.373884 \n", + "13 I thought I read somewhere that the Lord had c... 0.091776 \n", + "14 Good movie, very funny, acting is very good.<|... 2.408837 \n", + "15 It was agonizing, and it made me wonder 1.240262 \n", + "\n", + " response (RLHF) scores (RLHF) \\\n", + "0 I'm a pretty old lady, and I loved this movie ... 2.218363 \n", + "1 One of the most Antibiotic Apps I have seen in 2.145479 \n", + "2 Okay, as I enjoyed the movie. It's added bonus... 2.239827 \n", + "3 Watching \"Kroven\". The film has a 1.044690 \n", + "4 Seriously what were they thinking? It's a very... 2.753088 \n", + "5 OK Hollywood shoot, and this is a classic. Som... 2.517364 \n", + "6 \"Bend It all\" is a sophisticated, drawing and ... 2.583935 \n", + "7 While the premise behind The House Intelligenc... 0.205217 \n", + "8 Well let me go through everything says it's a ... 2.727040 \n", + "9 Vijay Krishna Acharya is a perfect performance... 2.563642 \n", + "10 Watching this movie made me sleep better. It w... 1.690222 \n", + "11 There are probably random man only recently wh... 0.398258 \n", + "12 Meryl Streitz, who is 0.085154 \n", + "13 I thought I read somewhere that my thoughts, a... 1.833734 \n", + "14 Good movie, very much fuzz and logical based w... 2.325996 \n", + "15 It was agonizing because it was truly fun to 0.969669 \n", + "\n", + " response (best_of) scores (best_of) \n", + "0 I'm a pretty old, stinking,acting kinda chick ... 2.016955 \n", + "1 One of the most memorable performances of this... 2.676944 \n", + "2 Okay, as I put it in such a negative mood, it ... 1.478424 \n", + "3 Watching \"Kro\" is an entertainment craze 1.389495 \n", + "4 Seriously what were they thinking? It was stil... 2.523514 \n", + "5 OK Hollywood pay and the freaky set-up of this... 1.634765 \n", + "6 \"Bend It 9\"/\"Zara Pephoto\") and an honest, rea... 2.557210 \n", + "7 While the premise behind The House of Dracula ... 1.676889 \n", + "8 Well let me go though, alive in this ever grow... 2.652859 \n", + "9 Vijay Krishna Acharya adeptly emerges, and the... 2.308076 \n", + "10 Watching this movie made me curious: what did ... 0.950836 \n", + "11 There are probably too many documentaries in s... 1.142725 \n", + "12 Meryl Streep performed an awe 1.932498 \n", + "13 I thought I read somewhere that The Odd Couple... 0.475951 \n", + "14 Good movie, very well polished, nicely written... 2.820022 \n", + "15 It was agonizing, poignant, and worst of 2.058277 " + ], + "text/html": [ + "\n", + "
\n", + "
\n", + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
queryresponse (ref)scores (ref)response (RLHF)scores (RLHF)response (best_of)scores (best_of)
0I'm a pretty oldI'm a pretty old kid, well, with lots of girl1.179652I'm a pretty old lady, and I loved this movie ...2.218363I'm a pretty old, stinking,acting kinda chick ...2.016955
1One of the mostOne of the most psychologically devastating as...2.477277One of the most Antibiotic Apps I have seen in2.145479One of the most memorable performances of this...2.676944
2Okay, asOkay, as ruthless as they are, even their leve...1.466462Okay, as I enjoyed the movie. It's added bonus...2.239827Okay, as I put it in such a negative mood, it ...1.478424
3Watching \"KroWatching \"Kroger\" (1915-0.186047Watching \"Kroven\". The film has a1.044690Watching \"Kro\" is an entertainment craze1.389495
4Seriously what were they thinking?Seriously what were they thinking? It ain't go...1.010697Seriously what were they thinking? It's a very...2.753088Seriously what were they thinking? It was stil...2.523514
5OK HollywoodOK Hollywood goes into a total game of audio, ...0.934041OK Hollywood shoot, and this is a classic. Som...2.517364OK Hollywood pay and the freaky set-up of this...1.634765
6\"Bend It\"Bend It, Luther, Dodge, Church Goes to Rome w...0.039218\"Bend It all\" is a sophisticated, drawing and ...2.583935\"Bend It 9\"/\"Zara Pephoto\") and an honest, rea...2.557210
7While the premise behind The HouseWhile the premise behind The House of Dracula ...-0.079306While the premise behind The House Intelligenc...0.205217While the premise behind The House of Dracula ...1.676889
8Well let me goWell let me go...I don't want to movie it. I'm...1.015246Well let me go through everything says it's a ...2.727040Well let me go though, alive in this ever grow...2.652859
9Vijay Krishna AcharyaVijay Krishna Acharya Sawai (Elverling). She was0.341506Vijay Krishna Acharya is a perfect performance...2.563642Vijay Krishna Acharya adeptly emerges, and the...2.308076
10Watching this movie made meWatching this movie made me poorly appreciate ...1.574047Watching this movie made me sleep better. It w...1.690222Watching this movie made me curious: what did ...0.950836
11There are probablyThere are probably more but if you had never s...-0.047099There are probably random man only recently wh...0.398258There are probably too many documentaries in s...1.142725
12Meryl StreMeryl Streep's version of0.373884Meryl Streitz, who is0.085154Meryl Streep performed an awe1.932498
13I thought I read somewhere thatI thought I read somewhere that the Lord had c...0.091776I thought I read somewhere that my thoughts, a...1.833734I thought I read somewhere that The Odd Couple...0.475951
14Good movie, veryGood movie, very funny, acting is very good.<|...2.408837Good movie, very much fuzz and logical based w...2.325996Good movie, very well polished, nicely written...2.820022
15It was agonizingIt was agonizing, and it made me wonder1.240262It was agonizing because it was truly fun to0.969669It was agonizing, poignant, and worst of2.058277
\n", + "
\n", + " \n", + " \n", + " \n", + "\n", + " \n", + "
\n", + "
\n", + " " + ] + }, + "metadata": {}, + "execution_count": 10 + } + ] + } + ] +} \ No newline at end of file diff --git a/trl/trl/examples/notebooks/gpt2-sentiment-control.ipynb b/trl/trl/examples/notebooks/gpt2-sentiment-control.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..de00502d3dc645ee3dd3f9aed7a6469175b12d40 --- /dev/null +++ b/trl/trl/examples/notebooks/gpt2-sentiment-control.ipynb @@ -0,0 +1,860 @@ +{ + "cells": [ + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Tune GPT2 to generate controlled sentiment reviews\n", + "> Optimise GPT2 to produce IMDB movie reviews with controlled sentiment using a BERT sentiment classifier for rewards.\n", + "\n", + "**WARNING:** We often experienced loss spikes in this examples which caused model training to fail or slow down. There is a [GitHub issue](https://github.com/lvwerra/trl/issues/101) to track the issue." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "
\n", + "\n", + "

Figure: Experiment setup to tune GPT2. The yellow arrows are outside the scope of this notebook, but the trained models are available through Hugging Face.

\n", + "
\n", + "\n", + "\n", + "The experiment setup is very similar to the positive sentiment notebook. However, in this notebook we fine-tune GPT2 (small) to generate **controlled** movie reviews based on the IMDB dataset. The model gets the target sentiment and 5 tokens from a real review and is tasked to produce continuations with the targeted sentiment. The reward for the continuations is calculated with the logits of a BERT sentiment classifier. That reward is then used for PPO training." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Setup experiment" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Import dependencies" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "%load_ext autoreload\n", + "%autoreload 2" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/leandro_huggingface_co/miniconda3/envs/trl/lib/python3.9/site-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + " from .autonotebook import tqdm as notebook_tqdm\n" + ] + } + ], + "source": [ + "import random\n", + "import torch\n", + "import wandb\n", + "import time\n", + "import os\n", + "from tqdm import tqdm\n", + "import numpy as np\n", + "import pandas as pd\n", + "from random import choices\n", + "import matplotlib.pyplot as plt\n", + "\n", + "tqdm.pandas()\n", + "\n", + "from datasets import load_dataset\n", + "\n", + "from transformers import AutoTokenizer, pipeline\n", + "\n", + "from trl import PPOTrainer, PPOConfig, AutoModelForCausalLMWithValueHead, create_reference_model" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Configuration" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "sentiment_pipe_kwargs = {\"top_k\": None, \"function_to_apply\": \"none\"}\n", + "\n", + "config = PPOConfig(\n", + " model_name=\"lvwerra/gpt2-imdb\", steps=51200, learning_rate=1.41e-5, remove_unused_columns=False, log_with=\"wandb\"\n", + ")\n", + "\n", + "txt_in_len = 5\n", + "txt_out_len = 20\n", + "seed = 1" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "np.random.seed(seed)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "You can see that we load a GPT2 model called `gpt2_imdb`. This model was additionally fine-tuned on the IMDB dataset for 1 epoch with the huggingface [script](https://github.com/huggingface/transformers/blob/master/examples/run_language_modeling.py) (no special settings). The other parameters are mostly taken from the original paper [\"Fine-Tuning Language Models from Human Preferences\"](\n", + "https://arxiv.org/pdf/1909.08593.pdf). This model as well as the BERT model is available in the Huggingface model zoo [here](https://huggingface.co/models). The following code should automatically download the models." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Load data and models" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Load pre-trained GPT2 language models" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We load the GPT2 model with a value head and the tokenizer. We load the model twice; the first model is optimized while the second model serves as a reference to calculate the KL-divergence from the starting point. This serves as an additional reward signal in the PPO training to make sure the optimized model does not deviate too much from the original language model." + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "gpt2_model = AutoModelForCausalLMWithValueHead.from_pretrained(config.model_name)\n", + "gpt2_model_ref = create_reference_model(gpt2_model)\n", + "gpt2_tokenizer = AutoTokenizer.from_pretrained(config.model_name)\n", + "\n", + "gpt2_tokenizer.pad_token = gpt2_tokenizer.eos_token" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Load IMDB dataset\n", + "The IMDB dataset contains 50k movie review annotated with \"positive\"/\"negative\" feedback indicating the sentiment. We load the IMDB dataset into a DataFrame and filter for comments that are at least 500 characters long and take the first 1000 characters of each comment. The first filter we apply to avoid comments that are less than `txt_in_len` token long and the second to avoid tokenizing way more text than we actually need." + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Found cached dataset imdb (/home/leandro_huggingface_co/.cache/huggingface/datasets/imdb/plain_text/1.0.0/2fdd8b9bcadd6e7055e742a706876ba43f19faee861df134affd7a3f60fc38a1)\n", + "Loading cached processed dataset at /home/leandro_huggingface_co/.cache/huggingface/datasets/imdb/plain_text/1.0.0/2fdd8b9bcadd6e7055e742a706876ba43f19faee861df134affd7a3f60fc38a1/cache-d314b4c14499bf03.arrow\n", + "Loading cached processed dataset at /home/leandro_huggingface_co/.cache/huggingface/datasets/imdb/plain_text/1.0.0/2fdd8b9bcadd6e7055e742a706876ba43f19faee861df134affd7a3f60fc38a1/cache-0d5fcb05c95b1186.arrow\n" + ] + }, + { + "data": { + "text/plain": [ + "Dataset({\n", + " features: ['review', 'sentiment'],\n", + " num_rows: 22578\n", + "})" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# create the dataset\n", + "#\n", + "dataset = load_dataset(\"imdb\", split=\"train\")\n", + "dataset = dataset.rename_columns({\"text\": \"review\", \"label\": \"sentiment\"})\n", + "# make sure the comments are are at least 500 and trim to 1000\n", + "dataset = dataset.filter(lambda x: len(x[\"review\"]) > 500, batched=False)\n", + "dataset = dataset.map(lambda x: {\"review\": x[\"review\"][:1000]}, batched=False)\n", + "\n", + "dataset" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Tokenize IMDB reviews" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We tokenize all IMDB in advance to avoid tokenizing twice. In the first step we encode the queries and slice the first `txt_in_len` tokens. In a second step we decode these tokens back to text for later display." + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Loading cached processed dataset at /home/leandro_huggingface_co/.cache/huggingface/datasets/imdb/plain_text/1.0.0/2fdd8b9bcadd6e7055e742a706876ba43f19faee861df134affd7a3f60fc38a1/cache-383f6ebf0ae41ee4.arrow\n", + "Loading cached processed dataset at /home/leandro_huggingface_co/.cache/huggingface/datasets/imdb/plain_text/1.0.0/2fdd8b9bcadd6e7055e742a706876ba43f19faee861df134affd7a3f60fc38a1/cache-f4875ad4fccbbc1f.arrow\n" + ] + } + ], + "source": [ + "dataset = dataset.map(\n", + " lambda x: {\"input_ids\": gpt2_tokenizer.encode(\" \" + x[\"review\"], return_tensors=\"pt\")[0, :txt_in_len]},\n", + " batched=False,\n", + ")\n", + "dataset = dataset.map(lambda x: {\"query\": gpt2_tokenizer.decode(x[\"input_ids\"])}, batched=False)\n", + "dataset = dataset[:20480]\n", + "\n", + "from datasets import Dataset\n", + "\n", + "dataset = Dataset.from_dict(dataset)\n", + "dataset.set_format(\"pytorch\")" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([ 770, 2646, 373, 2192, 7867])" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "dataset[3][\"input_ids\"]" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [], + "source": [ + "def collator(data):\n", + " return dict((key, [d[key] for d in data]) for key in data[0])" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.\n", + "\u001b[34m\u001b[1mwandb\u001b[0m: Currently logged in as: \u001b[33mlvwerra\u001b[0m. Use \u001b[1m`wandb login --relogin`\u001b[0m to force relogin\n" + ] + }, + { + "data": { + "text/html": [ + "Tracking run with wandb version 0.13.9" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "Run data is saved locally in /home/leandro_huggingface_co/trl/examples/sentiment/notebooks/wandb/run-20230206_125743-jpcnr7jx" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "Syncing run comic-music-184 to Weights & Biases (docs)
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + " View project at https://wandb.ai/lvwerra/trl" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + " View run at https://wandb.ai/lvwerra/trl/runs/jpcnr7jx" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "ppo_trainer = PPOTrainer(config, gpt2_model, gpt2_model_ref, gpt2_tokenizer, dataset, data_collator=collator)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Load BERT classifier\n", + "We load a BERT classifier fine-tuned on the IMDB dataset." + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [], + "source": [ + "if ppo_trainer.accelerator.num_processes == 1:\n", + " device = 0 if torch.cuda.is_available() else \"cpu\" # to avoid a `pipeline` bug\n", + "else:\n", + " device = ppo_trainer.accelerator.device\n", + "sentiment_pipe = pipeline(\"sentiment-analysis\", \"lvwerra/distilbert-imdb\", device=device)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The model outputs are the logits for the negative and positive class. We will use the logits for positive class as a reward signal for the language model." + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[{'label': 'NEGATIVE', 'score': 2.3350484371185303},\n", + " {'label': 'POSITIVE', 'score': -2.726576328277588}]" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "text = \"this movie was really bad!!\"\n", + "output = sentiment_pipe(text, **sentiment_pipe_kwargs)\n", + "output" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[{'label': 'POSITIVE', 'score': 2.557040214538574},\n", + " {'label': 'NEGATIVE', 'score': -2.294790267944336}]" + ] + }, + "execution_count": 13, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "text = \"this movie was really good!!\"\n", + "output = sentiment_pipe(text, **sentiment_pipe_kwargs)\n", + "output" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[{'label': 'POSITIVE', 'score': 0.8562759160995483},\n", + " {'label': 'NEGATIVE', 'score': -0.7086048126220703}]" + ] + }, + "execution_count": 14, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "text = \"this movie was a documentary\"\n", + "output = sentiment_pipe(text, **sentiment_pipe_kwargs)\n", + "output" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The resulting reward signal:" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [], + "source": [ + "def extract_pipe_output(outputs):\n", + " positive_logits = []\n", + " for out in outputs:\n", + " for element in out:\n", + " if element[\"label\"] == \"POSITIVE\":\n", + " positive_logits.append(torch.tensor(element[\"score\"]))\n", + " return positive_logits" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "-0.7086048126220703" + ] + }, + "execution_count": 16, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "output[1][\"score\"]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Control token dict\n", + "We will append the control token at the beginning of each query to signal the model what the target sentiment is. Each control sequence consists of three tokens:" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [], + "source": [ + "ctrl_str = [\"[negative]\", \"[neutral]\", \"[positive]\"]\n", + "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\") # this should be handled by accelerate\n", + "ctrl_tokens = dict((s, gpt2_tokenizer.encode(s, return_tensors=\"pt\").squeeze().to(device)) for s in ctrl_str)" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'[negative]': tensor([ 58, 31591, 60], device='cuda:0'),\n", + " '[neutral]': tensor([ 58, 29797, 60], device='cuda:0'),\n", + " '[positive]': tensor([ 58, 24561, 60], device='cuda:0')}" + ] + }, + "execution_count": 18, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "ctrl_tokens" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Reward function" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [], + "source": [ + "def pos_logit_to_reward(logit, task):\n", + " \"\"\"\n", + " Take the positive sentiment logit and scale it for the task.\n", + " task [negative]: reward = -logit\n", + " task [neutral]: reward = -2*abs(logit)+4\n", + " task [positive]: reward = logit\n", + " \"\"\"\n", + " for i in range(len(logit)):\n", + " if task[i] == \"[negative]\":\n", + " logit[i] = -logit[i]\n", + " elif task[i] == \"[neutral]\":\n", + " logit[i] = -2 * torch.abs(logit[i]) + 4\n", + " elif task[i] == \"[positive]\":\n", + " pass\n", + " else:\n", + " raise ValueError(\"task has to be in [0, 1, 2]!\")\n", + " return logit" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The following examples show the rewards for the cases where the classifier logit is 4, -4 and 0 for the three targets `['negative]`, `['neutral]` and `['positive']`. The scaling is not perfect as it differs between neutral and the other two classes. This is something to further investigate in the future. Ideally, one would use the logit output for each class individually, but since there is no dedicated class for neutral this is a workaround." + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "['[negative]', '[neutral]', '[positive]']\n" + ] + } + ], + "source": [ + "print(ctrl_str)" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([-4., -4., 4.])" + ] + }, + "execution_count": 21, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "pos_logit_to_reward(torch.Tensor([4, 4, 4]), ctrl_str)" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([ 4., -4., -4.])" + ] + }, + "execution_count": 22, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "pos_logit_to_reward(torch.Tensor([-4, -4, -4]), ctrl_str)" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([-0., 4., 0.])" + ] + }, + "execution_count": 23, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "pos_logit_to_reward(torch.Tensor([0, 0, 0]), ctrl_str)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Generation settings" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "metadata": {}, + "outputs": [], + "source": [ + "generation_kwargs = {\n", + " \"min_length\": -1,\n", + " \"top_k\": 0.0,\n", + " \"top_p\": 1.0,\n", + " \"do_sample\": True,\n", + " \"pad_token_id\": gpt2_tokenizer.eos_token_id,\n", + " \"max_new_tokens\": txt_out_len,\n", + " \"eos_token_id\": -1,\n", + "}" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Optimize model" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "**Steps**\n", + "\n", + "The training loop consists of the following steps:\n", + "1. Get a batch of queries and create random controls\n", + "2. Get the query responses from the policy\n", + "3. Join query and responses and tokenize for BERT analysis\n", + "4. Get sentiments for query/responses from BERT\n", + "5. Optimize policy with PPO using the (query, response, reward) triplet\n", + "6. Log all the training statistics\n", + "\n", + "**Training time**\n", + "\n", + "This step takes **~2h** on a P6000 GPU with the above specified settings." + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 8%|▊ | 6/80 [12:44<2:37:54, 128.03s/it]/home/leandro_huggingface_co/miniconda3/envs/trl/lib/python3.9/site-packages/transformers/pipelines/base.py:1045: UserWarning: You seem to be using the pipelines sequentially on GPU. In order to maximize efficiency please use a dataset\n", + " warnings.warn(\n", + "100%|██████████| 80/80 [2:46:39<00:00, 124.99s/it] \n", + " 91%|█████████▏| 73/80 [2:30:39<14:35, 125.03s/it] " + ] + } + ], + "source": [ + "for epoch in range(2):\n", + " for batch in tqdm(ppo_trainer.dataloader):\n", + " (logs, game_data,) = (\n", + " dict(),\n", + " dict(),\n", + " )\n", + "\n", + " #### prepend a random control token\n", + " task_list = choices(ctrl_str, k=config.batch_size)\n", + " game_data[\"query\"] = [t + q for t, q in zip(task_list, batch[\"query\"])]\n", + " query_tensors = [torch.cat((ctrl_tokens[t], input_ids)) for t, input_ids in zip(task_list, batch[\"input_ids\"])]\n", + "\n", + " #### get response from gpt2\n", + " response_tensors = []\n", + " for query in query_tensors:\n", + " response = ppo_trainer.generate(query, **generation_kwargs)\n", + " response_tensors.append(response.squeeze()[-txt_out_len:])\n", + " game_data[\"response\"] = [gpt2_tokenizer.decode(r.squeeze()) for r in response_tensors]\n", + "\n", + " #### sentiment analysis\n", + " texts = [q + r for q, r in zip(batch[\"query\"], game_data[\"response\"])]\n", + " logits = extract_pipe_output(sentiment_pipe(texts, **sentiment_pipe_kwargs))\n", + " rewards = pos_logit_to_reward(logits, task_list)\n", + "\n", + " #### Run PPO training\n", + " t = time.time()\n", + " stats = ppo_trainer.step(query_tensors, response_tensors, rewards)\n", + "\n", + " for cs in ctrl_str:\n", + " key = \"env/reward_\" + cs.strip(\"[]\")\n", + " stats[key] = np.mean([r.cpu().numpy() for r, t in zip(rewards, task_list) if t == cs])\n", + " ppo_trainer.log_stats(stats, game_data, rewards)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Training progress\n", + "If you are tracking the training progress with Weights&Biases you should see a plot similar to the following:\n", + "\n", + "
\n", + "\n", + "

Figure: Reward mean and distribution evolution during training.

\n", + "
\n", + "\n", + "One can observe how the model starts to generate more positive outputs after a few optimisation steps.\n", + "\n", + "> Note: Investigating the KL-divergence will probably show that at this point the model has not converged to the target KL-divergence, yet. To get there would require longer training or starting with a higher inital coefficient." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Model inspection" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Reward distribution\n", + "First, we can have a look at the reward distribution. Both the negative and positive rewards are clearly shifted to high rewards. The neutral rewards, however, are still centered around zero. There are a few possible explanations for this. There could be a bug in the code and the way the neutral rewards are calculated. Another problem could be that sentence sometimes start with a strong sentiment and it is hard for the model shift the sentiment towards neutral." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAiwAAAGzCAYAAAAMr0ziAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjYuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8o6BhiAAAACXBIWXMAAA9hAAAPYQGoP6dpAABPCUlEQVR4nO3deVwVZf8//tecw4HDroiyibKImSmQkKi5B6K3mt4tLvm4RSq7S7lvjTtNLAVcPqip0aLZnbdL3ZK0qP2+5o0SSVmiFor7lklubGqIgB4OnPn9YWfyyGE5h+UM8Ho+Hjw8c80117znOoPzZuaaGUEURRFEREREMqawdABEREREdWHCQkRERLLHhIWIiIhkjwkLERERyR4TFiIiIpI9JixEREQke0xYiIiISPaYsBAREZHsMWEhIiIi2WPCQkQNtmnTJgiCgNzcXLOWnzZtGnx8fAzKBEFAQkJCg2OrS2ZmJgRBQGZmplQ2dOhQ9OrVq8nXDQC5ubkQBAGbNm1qlvURtVRMWIio1UhJSUFycrKlwzBKzrERtQRWlg6AiMiYO3fuwMrKtP+iUlJScOLECcyePbveywwePBh37tyBtbW1iRGapqbYunbtijt37kClUjXp+olaOp5hIZKBsrIyS4dQK51Oh7t37zbrOtVqtckJiynu3r0LnU4HhUIBtVoNhcIy/x0KggC1Wg2lUmmR9RO1FExYiJpZQkICBEHAqVOn8Nxzz6F9+/YYOHCgNP+///0vQkJCYGtrCxcXF0yaNAmXL1+W5r/77rtQKpUoLi6WylatWgVBEBAbGyuVVVVVwdHREa+//rpUtnLlSgwYMAAdOnSAra0tQkJC8MUXX1SLURAExMTEYMuWLXjkkUdgY2ODtLQ0AMDJkycxfPhw2NraonPnzliyZAl0Ol29t3/Hjh3o1asX1Go1evXqhe3btxut9+AYltu3b2P27Nnw8fGBjY0NOnXqhIiICBw+fBjAvXEnX3/9NX777TcIggBBEKRxMfpxKlu3bsWbb74JLy8v2NnZoaSkxOgYFr3s7GwMGDAAtra28PX1xbp16wzm1zR258E2a4utpjEs3377LQYNGgR7e3u0a9cO48aNw+nTpw3q6PelX375BdOmTUO7du3g7OyM6OholJeX1/wlELVAvCREZCHPPvssAgIC8H//938QRREAsHTpUixYsAATJkzAiy++iKKiIrz33nsYPHgwjhw5gnbt2mHQoEHQ6XT44YcfMGbMGADAvn37oFAosG/fPqn9I0eOoLS0FIMHD5bK3nnnHTz55JOYMmUKKioqsHXrVjz77LPYuXMnRo8ebRDft99+i88++wwxMTFwdXWFj48P8vPzMWzYMFRWVmLevHmwt7fHv//9b9ja2tZrm/fs2YOnn34aPXv2RFJSEm7cuIHo6Gh07ty5zmVffvllfPHFF4iJiUHPnj1x48YN/PDDDzh9+jT69OmDN954A7du3cKVK1fw9ttvAwAcHBwM2li8eDGsra3x2muvQaPR1HoZ6Pfff8df/vIXTJgwAZMnT8Znn32GV155BdbW1nj++efrtb169Yntft988w1GjRoFPz8/JCQk4M6dO3jvvffw+OOP4/Dhw9UGKE+YMAG+vr5ISkrC4cOHsX79enTq1AnLly83KU4iWROJqFnFx8eLAMTJkycblOfm5opKpVJcunSpQfnx48dFKysrqbyqqkp0cnIS586dK4qiKOp0OrFDhw7is88+KyqVSvH27duiKIri6tWrRYVCIf7+++9SW+Xl5QZtV1RUiL169RKHDx9uUA5AVCgU4smTJw3KZ8+eLQIQDx48KJUVFhaKzs7OIgDx4sWLtW57cHCw6OHhIRYXF0tle/bsEQGIXbt2rRZDfHy8NO3s7CzOnDmz1vZHjx5drR1RFMW9e/eKAEQ/P79qfaCft3fvXqlsyJAhIgBx1apVUplGoxGDg4PFTp06iRUVFaIoiuLGjRuNbrexNmuK7eLFiyIAcePGjVKZfj03btyQyo4ePSoqFApx6tSpUpl+X3r++ecN2vzrX/8qdujQodq6iFoyXhIispCXX37ZYHrbtm3Q6XSYMGECrl+/Lv24u7sjICAAe/fuBQAoFAoMGDAA33//PQDg9OnTuHHjBubNmwdRFJGVlQXg3lmXXr16oV27dtI67j8T8vvvv+PWrVsYNGiQdFnlfkOGDEHPnj0Nynbt2oV+/fqhb9++UlnHjh0xZcqUOrc3Ly8POTk5iIqKgrOzs1QeERFRbT3GtGvXDgcPHsS1a9fqrFuTqKioep8NsrKywt///ndp2traGn//+99RWFiI7Oxss2Ooi76fpk2bBhcXF6k8MDAQERER2LVrV7VlHtyXBg0ahBs3bqCkpKTJ4iRqbkxYiCzE19fXYPr8+fMQRREBAQHo2LGjwc/p06dRWFgo1R00aBCys7Nx584d7Nu3Dx4eHujTpw+CgoKky0I//PADBg0aZLCOnTt3ol+/flCr1XBxcUHHjh3xwQcf4NatW3XGBwC//fYbAgICqpU/9NBDdW7vb7/9BgBmL79ixQqcOHEC3t7e6Nu3LxISEvDrr7/Wudz9jG1TTTw9PWFvb29Q1r17dwAw+3kz9aHvJ2N98vDDD+P69evVBml36dLFYLp9+/YA7iWlRK0Fx7AQWciDf+nrdDoIgoD//e9/Ru8YuX/Mw8CBA6HVapGVlYV9+/ZJicmgQYOwb98+nDlzBkVFRQYJy759+/Dkk09i8ODBWLt2LTw8PKBSqbBx40akpKTUGZ+lTZgwAYMGDcL27duxZ88evPXWW1i+fDm2bduGUaNG1auNxt4mQRCMlldVVTXqeupS0x1G4h9jo4haAyYsRDLh7+8PURTh6+sr/SVfk759+8La2hr79u3Dvn37MGfOHAD3niny0UcfISMjQ5rW+/LLL6FWq7F7927Y2NhI5Rs3bqx3jF27dsX58+erlZ89e7ZeywIwe3kA8PDwwIwZMzBjxgwUFhaiT58+WLp0qZSw1JRAmOPatWsoKyszOMty7tw5AJAGverPZNx/xxbw51mS+9U3Nn0/GeuTM2fOwNXVtdqZH6K2gJeEiGTiqaeeglKpRGJiYrW/jEVRxI0bN6RptVqNxx57DJ9++ikuXbpkcIblzp07ePfdd+Hv7w8PDw9pGaVSCUEQDP76z83NxY4dO+od41/+8hccOHAAhw4dksqKioqwZcuWOpf18PBAcHAwNm/ebHAJKj09HadOnap12aqqqmqXrTp16gRPT09oNBqpzN7e3ujlLXNUVlbiww8/lKYrKirw4YcfomPHjggJCQFwL8kEII0n0sf673//u1p79Y3t/n66PxE6ceIE9uzZg7/85S/mbhJRi8YzLEQy4e/vjyVLliAuLg65ubkYP348HB0dcfHiRWzfvh0vvfQSXnvtNan+oEGDsGzZMjg7O6N3794A7h3EH3roIZw9exbTpk0zaH/06NFYvXo1Ro4cieeeew6FhYVYs2YNunXrhmPHjtUrxrlz5+KTTz7ByJEjMWvWLOm25q5du9arjaSkJIwePRoDBw7E888/j5s3b+K9997DI488gtLS0hqXu337Njp37oxnnnkGQUFBcHBwwDfffIOffvoJq1atkuqFhIQgNTUVsbGxeOyxx+Dg4ICxY8fWa9se5OnpieXLlyM3Nxfdu3dHamoqcnJy8O9//1t6Ku0jjzyCfv36IS4uDjdv3oSLiwu2bt2KysrKau2ZEttbb72FUaNGoX///njhhRek25qdnZ2b5f1KRLJkyVuUiNoi/a2oRUVFRud/+eWX4sCBA0V7e3vR3t5e7NGjhzhz5kzx7NmzBvW+/vprEYA4atQog/IXX3xRBCD+5z//qdb2f/7zHzEgIEC0sbERe/ToIW7cuFGK534AaryF+NixY+KQIUNEtVotenl5iYsXLxb/85//1Ou2Zv32Pfzww6KNjY3Ys2dPcdu2bWJUVFSttzVrNBpxzpw5YlBQkOjo6Cja29uLQUFB4tq1aw2WKS0tFZ977jmxXbt2BrdK628z/vzzz6vFU9NtzY888oj4888/i/379xfVarXYtWtX8f3336+2/IULF8Tw8HDRxsZGdHNzE+fPny+mp6dXa7Om2Izd1iyKovjNN9+Ijz/+uGhrays6OTmJY8eOFU+dOmVQp6Z9qabbrYlaMkEUOSqLiIiI5I1jWIiIiEj2mLAQERGR7DFhISIiItljwkJERESyx4SFiIiIZI8JCxEREcleq3hwnE6nw7Vr1+Do6Nioj+YmIiKipiOKIm7fvg1PT08oFLWfQ2kVCcu1a9fg7e1t6TCIiIjIDJcvX0bnzp1rrdMqEhZHR0cA9zbYycnJ7Ha0Wi327NmDESNGSI/ebovYD+wDgH0AsA/02A/sA6Bp+qCkpATe3t7Scbw2rSJh0V8GcnJyanDCYmdnBycnpza7QwLsB4B9ALAPAPaBHvuBfQA0bR/UZzgHB90SERGR7DFhISIiItljwkJERESy1yrGsNSHKIqorKxEVVVVjXW0Wi2srKxw9+7dWuu1di2xH1QqFZRKpaXDICKiJtImEpaKigrk5eWhvLy81nqiKMLd3R2XL19u089zaYn9IAgCOnfuDAcHB0uHQkRETaDVJyw6nQ4XL16EUqmEp6cnrK2tazwI63Q6lJaWwsHBoc4H2LRmLa0fRFFEUVERrly5goCAAJ5pISJqhVp9wlJRUQGdTgdvb2/Y2dnVWlen06GiogJqtbpFHKibSkvsh44dOyI3NxdarZYJCxFRK9QyjkaNoKUceMk8LeXSFRERmYdHcSIiIpI9JixEREQke61+DEtt3k4/ZzAtiiI0Gg1sbGya5BLDqxHdTao/dOhQfPfddwCAI0eOIDg4uNFjamyCIGD79u0YP358o7SXmZmJYcOGAQDGjRuHHTt2NEq7RETUsvAMi8xNnz4deXl56NWrl6VDMZCQkGA0gcrLy8OoUaMabT0DBgxAXl4eJkyY0GhtEhFRy9Omz7C0BHZ2dnB3d7d0GPXW2LFaW1vD3d0dtra20Gg0jdo2ERG1HDzD0oJkZmZCEARkZGQgNDQUdnZ2GDBgAM6ePWtQ76uvvkKfPn2gVqvh5+eHxMREVFZWSvPPnDmDgQMHQq1Wo2fPnvjmm28gCILB5Zb4+Hj06NEDdnZ28PPzw4IFC6DVagEAmzZtQmJiIo4ePQpBECAIAjZt2gQABu0MGDAAr7/+ukFsRUVFUKlU+P777wEAGo0Gr732Gry8vGBvb4+wsDBkZmY2bscREVGLxzMsLdAbb7yBVatWoWPHjnj55Zfx/PPP48cffwQA7Nu3D1OnTsW7776LQYMG4cKFC3jppZcA3EtCqqqqMH78eHTp0gUHDx7E7du38a9//avaOhwdHbFhwwZ07twZx48fx/Tp0+Ho6Ii5c+di4sSJOHHiBNLS0vDNN98AAJydnau1MWXKFKxYsQLLli2TxgSlpqbC09MTgwYNAgDExMTg1KlT2Lp1Kzw9PbF9+3aMHDkSx48fR0BAQJP0H9Vsbc5a6bOgE+AJT6w/vh6iQrRgVMCM4BkWXT8RWR7PsLRAS5cuxZAhQ9CzZ0/MmzcP+/fvx927dwEAiYmJmDdvHqKiouDn54eIiAgsXrwYH374IQAgPT0dFy5cwMcff4ygoCAMHDgQS5curbaO1157DQMGDICPjw/Gjh2L1157DZ999hkAwNbWFg4ODrCysoK7u7t0yeZBEyZMwLVr1/DDDz9IZSkpKZg8eTIEQcClS5ewceNGfP755xg0aBD8/f3x2muvYeDAgdi4cWNTdB0REbVQPMPSAgUGBkqfPTw8AACFhYXo0qULjh49ih9//NEgCamqqsLdu3dRXl6Os2fPwtvb22CsSd++fautY9u2bfjPf/6DCxcuoLS0FJWVlXBycjIpzo4dO2LEiBHYsmULBg0ahIsXLyIrK0tKno4fP46qqip0725495RGo0GHDh1MWhcREbVuTFhaIJVKJX3WX2rR6XQAgNLSUiQmJuKpp56qtpxara5X+1lZWXjppZeQkJCAkSNHwtnZGVu3bsWqVatMjnXKlCn45z//iffeew8pKSno3bs3evfuLcWqVCqRnZ1d7XH6fIkhERHdz6xLQmvWrIGPjw/UajXCwsJw6NChGutu27YNoaGhaNeuHezt7REcHIxPPvnEoM60adOkwZv6n5EjR5oTWpvXp08fnD17Ft26dav2o1Ao8NBDD+Hy5csoKCiQlvnpp58M2sjKyoK3tzfmz5+P0NBQBAQE4LfffjOoY21tjaqqqjrjGTduHO7evYu0tDSkpKRgypQp0rxHH30UVVVVKCwsrBZrS7ozioiImp7JZ1hSU1MRGxuLdevWISwsDMnJyYiMjMTZs2fRqVOnavVdXFzwxhtvoEePHrC2tsbOnTsRHR2NTp06ITIyUqo3cuRIg3ELNjY2Zm5S27Zw4UKMGTMGXbp0wTPPPAOFQoGjR4/ixIkTWLJkCSIiIuDv74+oqCisWLECt2/fxptvvgngz7M13bp1w5UrV7B161aEhYXh66+/xvbt2w3W4+Pjg4sXLyInJwedO3eGo6Oj0e/M3t4e48ePx4IFC3D69GlMnjxZmte9e3dMmTIFU6dOxapVq/Doo4+iqKgIGRkZCAwMxOjRo5uwp4iIqCUxOWFZvXo1pk+fjujoaADAunXr8PXXX2PDhg2YN29etfpDhw41mJ41axY2b96MH374wSBhsbGxafa/qh988qxOp0NJSQmcnJxa7MsSIyMjsXPnTixatAjLly+HSqVCjx498OKLLwIAlEolduzYgRdffBGPPfYY/Pz88NZbb2Hs2LHSJaMnn3wSr7zyCv75z39Co9Fg9OjRWLBgARISEqT1PP3009i2bRuGDRuG4uJibNy4EdOmTTMa05QpU/CXv/wFgwcPRpcuXQzmbdy4EUuWLMG//vUvXL16Fa6urujXrx/GjBnTJP1DREQtk0kJS0VFBbKzsxEXFyeVKRQKhIeHIysrq87lRVHEt99+i7Nnz2L58uUG8zIzM9GpUye0b98ew4cPx5IlS2oceKnRaAweIlZSUgIA0Gq10rNC9LRaLURRhE6nk8Z51Baf/t+66jaX+2MZPHiwdBlGXxYYGFitLCIiAhEREdXa0s/v3r279BwUANIt0X5+ftDpdBBFEYsWLcLbb79t8IqCf/7zn1IbKpVKumvo/vYfjAW4l0QZKwfuJVDx8fGIj4+vMV59P9T2vejj1mq11cbDmEO/Hz24P7V2gk6o9vn+Mkux1PfQVveDB7Ef2AdA0/SBKW0Jov4oXQ/Xrl2Dl5cX9u/fj/79+0vlc+fOxXfffYeDBw8aXe7WrVvw8vKCRqOBUqnE2rVr8fzzz0vzt27dCjs7O/j6+uLChQuYP38+HBwckJWVZfTgk5CQgMTExGrlKSkpsLOzMyjT33rr7e0Na2vr+m6qLIwZMwaHDh2CtbU1du/ejUceeaRR2t25cyfs7e3h7++PX3/9FXFxcXB2dkZaWlqjtN+Y9u/fjwkTJkCj0Uh3HBlTUVGBy5cvIz8/3+AheUREJF/l5eV47rnncOvWrTrvRG2Wu4QcHR2Rk5OD0tJSZGRkIDY2Fn5+ftLlokmTJkl1e/fujcDAQPj7+yMzMxNPPPFEtfbi4uIQGxsrTZeUlMDb2xsjRoyotsF3797F5cuX4eDgUOddMqIo4vbt23B0dGySlx+a6tNPP8WdO3cAAF26dGm0hKuyshKvv/46Ll26BFdXVzzxxBNYuXKl1Hdy6ochQ4bg8OHDAO7dOVTTDn337l3Y2tpi8ODB9b4bqjZarRbp6emIiIgwuCurtVt/fL30WdAJ8LjqgTyvPIs/OO7F3i9aZL1tdT94EPuBfQA0TR/or5DUh0kJi6urK5RKpcEdJgBQUFBQ6/gThUKBbt26AQCCg4Nx+vRpJCUlVRvfoufn5wdXV1f88ssvRhMWGxsbowM8VSpVtU6sqqqCIAhQKBR1jkvRX27Q17c0b2/vJml32rRpNY43AeTVD/b29tWe02KMQqGAIAhG94GGaOz25M5YYiIqRIsnLJb+DtraflAT9gP7AGjcPjClHZOORtbW1ggJCUFGRoZUptPpkJGRYXCJqC46na7WF9lduXIFN27ckB6KRkRERG2byZeEYmNjERUVhdDQUPTt2xfJyckoKyuT7hqaOnUqvLy8kJSUBABISkpCaGgo/P39odFosGvXLnzyySf44IMPAPz5oLOnn34a7u7uuHDhAubOnYtu3boZ3EVEREREbZfJCcvEiRNRVFSEhQsXIj8/H8HBwUhLS4ObmxsA4NKlSwaXEcrKyjBjxgxcuXIFtra26NGjB/773/9i4sSJAO7dJXLs2DFs3rwZxcXF8PT0xIgRI7B48WI+i4WIiIgAmDnoNiYmBjExMUbnZWZmGkwvWbIES5YsqbEtW1tb7N6925wwiIiIqI2w/MhSIiIiojowYSEiIiLZa9tva96bZDApiCLUGg0EGxugKZ4/Miyu7jr3GTp0KL777jsAwJEjRxAcHNz4MTWDTZs2Yfbs2SguLpam9YO0Z82aheTkZMsFR0RELQLPsMjc9OnTkZeXh169ejXbOjMzM9G+fXspwWhsEydORF5enkm3whMRUdvWts+wtAB2dnbN/lLI+qqoqDDr6bu2trawtbVtca9KICIiy+EZlhYkMzMTgiAgIyMDoaGhsLOzw4ABA3D27FmDel999RX69OkDtVoNPz8/JCYmSu/Xyc3NhSAIyMnJkeoXFxdDEARkZmYiNzdXerpwhw4dIAiC9FTcoUOHIiYmBrNnz4arq6v0nJzVq1ejd+/esLe3h7e3N2bMmIHS0tKm7xAiImozmLC0QG+88QZWrVqFn3/+GVZWVgYvkty3bx+mTp2KWbNm4dSpU/jwww+xadMmLF26tF5te3t74/PPPwcAnD59Gnl5eXjnnXek+Zs3b4a1tTV+/PFHrFu3DsC9x+K/++67OHnyJDZv3oxvv/0Wc+fObcQtJiKito6XhFqgpUuXYsiQIQCAefPmYfTo0bh79y7UajUSExMxb948REVFAbj3XqbFixdj7ty5iI+Pr7NtpVIJFxcXAECnTp2kz3oBAQFYsWKFQdns2bOlzz4+PliyZAlefvllrF27tiGbSUREJGHC0gIFBgZKn/XvWyosLESXLl1w9OhR/PjjjwZnVKqqqnD37l2Ul5c3eN0hISHVyr755hskJSXhzJkzKCkpQWVlpbQ+Ozu7Bq+TiIiICUsLdP/bLYU/br/Wv2FZ/26mp556qtpyarVaem2CKP759l2tVlvvddvb2xtM5+bmYsyYMXjllVewdOlSuLi44IcffsALL7yAiooKJixERNQomLC0Mn369MHZs2fRrVs3o/M7duwIAMjLy8Ojjz4KAAYDcAFId+9UVVXVub7s7GzodDqsWrVKSoY+++wzc8MnIiIyiglLK7Nw4UKMGTMGXbp0wTPPPAOFQoGjR4/ixIkTWLJkCWxtbdGvXz8sW7YMvr6+KCwsxJtvvmnQRteuXSEIAnbu3IkxY8bA1tYWDg4ORtfXrVs3aLVavPfeexg7dqzBYFwiIqLG0rYTlgeePCvqdLhbUgJrJycIipZ5A1VkZCR27tyJRYsWYfny5VCpVOjRowdefPFFqc6GDRvwwgsvICQkBA899BBWrFiBESNGSPO9vLwQFxeH+fPn44UXXsDUqVOxadMmo+sLCgrC6tWrsXz5csTFxWHw4MFISkrC1KlTm3pTiYioDWnbCUsLM3ToUIOxJwAQHBxcrSwyMlJ6RooxDz/8MPbv329Q9mAbc+bMweLFi6XLPED1N3Hrvfrqq3j11VcNyv72t79Jn6dNmyY9y4WIiMgcLfM0Qhuydu1aODg44Pjx45YOpdFs2bIFDg4O2Ldvn6VDISKiFoJnWGRsy5YtuHPnDgCgS5cuFo6m8Tz55JMICwsDALRr186ywRARUYvAhEXGvLy8LB1Ck3B0dISjo6OlwyAiohaEl4SIiIhI9piwEBERkewxYSEiIiLZY8JCREREsseEhYiIiGSPdwkREdXH3iRLR2DcA0/spnoy9fsUFQB6APtWA4KuSUICwO+zFm06YVmbs9ZgWhRFaDQa2NjYSG9BbkwzgmeYVH/o0KH47rvvAABHjhxBcHBwo8dkbJ1BQUFITEyssc6mTZswe/ZsFBcXN9p6p02bhs2bNwMAtm/fjvHjxzda20RE1PLxkpDMTZ8+HXl5eejVq1ezrG/btm1YtGiRNO3j44Pk5GSDOhMnTsS5c+cadb3vvPMO8vLyGrVNIiJqPdr0GZaWwM7ODu7u7s22PhcXF+h0OpSUlNRYx9bWFra2to26XmdnZzg7Ozdqm0RE1HrwDEsLkpmZCUEQ8PXXXyMwMBBqtRr9+vXDiRMnDOp9+eWXeOSRR2BjYwMfHx+sWrXKYP7atWsREBAAtVoNNzc3PPPMM9K8oUOHSi8yHD58OH777Te8+uqrEARBuky2adMm6ZH6586dgyAIOHPmjME63n77bfj7+0vTJ06cwKhRo+Dg4AA3Nzf87W9/w/Xr1xutb4iIqHVjwtICzZkzB6tWrcJPP/2Ejh07YuzYsdBqtQCA7OxsTJgwAZMmTcLx48eRkJCABQsWYNOmTQCAn3/+Gf/85z+xaNEinD17FmlpaRg8eLDR9XzxxRfo3LkzFi1ahLy8PKOXbLp3747Q0FBs2bLFoHzLli147rnnAADFxcUYPnw4Hn30Ufz8889IS0tDQUEBJkyY0Ii9QkRErRkvCbVA8fHxiIiIAABs3rwZnTt3xvbt2zFhwgSsXr0aTzzxBBYsWADgXkJx6tQpvPXWW5g2bRouXboEe3t7jBkzBo6OjujatSseffRRo+txcXGBUqmEo6NjrZelpkyZgvfffx+LFy8GcO+sS3Z2Nv773/8CAN5//308+uij+L//+z9pmQ0bNsDb2xvnzp1D9+7dG6VfiIio9eIZlhaof//+0mcXFxc89NBDOH36NADg9OnTePzxxw3qP/744zh//jyqqqoQERGBrl27ws/PD3/729+wZcsWlJeXNyieSZMmITc3FwcOHABw7+xKnz590KNHDwDA0aNHsXfvXjg4OEg/+nkXLlxo0LqJiKhtYMLSxjg6OuLw4cP49NNP4eHhgYULFyIoKKhBtyi7u7tj+PDhSElJAQCkpKRgypQp0vzS0lKMHTsWOTk5Bj/nz5+v8XIUERHR/ZiwtED6MxkA8Pvvv+PcuXN4+OGHAQAPP/wwfvzxR4P6P/74I7p37w6lUgkAsLKyQnh4OFasWIFjx44hNzcX3377rdF1WVtbo6qqqs6YpkyZgtTUVGRlZeHXX3/FpEmTpHl9+vTByZMn4ePjg27duhn82Nvbm7z9RETU9jBhaYEWLVqEjIwMnDhxAtOmTYOrq6v0oLV//etfyMjIwOLFi3Hu3Dls3rwZ77//Pl577TUAwM6dO/Huu+8iJycHv/32Gz7++GPodDo89NBDRtfl4+OD77//HlevXq31rp6nnnoKt2/fxiuvvIJhw4bB09NTmjdz5kzcvHkTkydPxk8//YQLFy5g9+7diI6OrlcyRERE1KYH3T745Fn980ecnJygUMg3l1u2bBlmzZqF8+fPIzg4GP/v//0/WFtbA7h3NuOzzz7DwoULsXjxYnh4eGDRokWYNm0aAKBdu3bYtm0bEhIScPfuXQQEBODTTz/FI488YnRdixYtwt///nf4+/tDo9FAFEWj9RwdHTF27Fh89tln2LBhg8E8T09P/Pjjj3j99dcxYsQIaDQadO3aFSNHjpR1PxMRkXy06YSlpRo4cGC1Z6/c7+mnn8bTTz9d47KZmZk1LpuZmWnw4Lh+/frh6NGjBnWmTZsmJUD3S01NRWpqqtF2AwICsG3bthrXS0REVBv+eStza9euhYODA44fP27pUJrUyy+/DAcHB0uHQUREMsUzLDK2ZcsW3LlzBwDQpUsX7N+/38IRNZ1FixZJ42w8PDwsHA0REckNExYZ8/LyMpgeOnRojWNIWrpOnTqhU6dOlg6DiIhkyqxLQmvWrIGPjw/UajXCwsJw6NChGutu27YNoaGhaNeuHezt7REcHIxPPvnEoI4oili4cCE8PDxga2uL8PBwnD9/3pzQiIiIqBUyOWFJTU1FbGws4uPjcfjwYQQFBSEyMhKFhYVG67u4uOCNN95AVlYWjh07hujoaERHR2P37t1SnRUrVuDdd9/FunXrcPDgQdjb2yMyMhJ37941f8se0FrPTNA9/H6JiFo3kxOW1atXY/r06YiOjkbPnj2xbt062NnZVbuVVW/o0KH461//iocffhj+/v6YNWsWAgMD8cMPPwC4d6BJTk7Gm2++iXHjxiEwMBAff/wxrl27hh07djRo4wBApVIBQIMfP0/yVlFRAQDSw/GIiKh1MWkMS0VFBbKzsxEXFyeVKRQKhIeHIysrq87lRVHEt99+i7Nnz2L58uUAgIsXLyI/Px/h4eFSPWdnZ4SFhSErK8vgial6Go0GGo1GmtbfgqvVaqW3Ft/P0dERBQUF0Ol0sLOzgyAINcZXUVGBO3fu1FinLWhp/aDT6VBYWAi1Wg1RFI3uA6bSt9EYbbUkgk6o9vn+Mkux1PdgsB+IMr2pshn6plX+Ppj4fWr/qK9t6v1Axn3cFPuBKW2ZlLBcv34dVVVVcHNzMyh3c3PDmTNnalzu1q1b8PLygkajgVKpxNq1a6W3Defn50ttPNimft6DkpKSkJiYWK18z549sLOzM7qMo6MjysrK+KCyVkqr1aKoqAjHjh1r1HbT09MbtT2584RntTKPq5a/a2vX5V0WXf+9/aCHRWOo0a7m65vW9ftg3veZXtrEb5dvxu/TXI25H5hy9aNZ7hJydHRETk4OSktLkZGRgdjYWPj5+WHo0KFmtRcXF4fY2FhpuqSkBN7e3hgxYgScnJxqXK6qqgqVlZU1jneorKzE/v37MWDAAFhZtd0bqFpaPwiCAJVK1ajJqFarRXp6OiIiIqTLim3B+uPrpc+CToDHVQ/keeVBVFh2jNCLvV+0yHoN9oMD71kkhjoNiq27TgO1yt+HfatNqq4VFUgv7Y4Ih3NQCbomCgrN8n2aqyn2A/0Vkvow6Wjk6uoKpVKJgoICg/KCggK4u7vXuJxCoUC3bt0AAMHBwTh9+jSSkpIwdOhQabmCggKD528UFBQgODjYaHs2NjawsbGpVq5SqWrtxLo6WKvVorKyEg4ODq3nl9IM7Ic/1bVPtTbGEhNRIVo8YbH0d6BSqZr2INUQzdg3rer3wczvUyXomnZfaAH925j7gSntmPQnqbW1NUJCQpCRkSGV6XQ6ZGRkoH///vVuR6fTSWNQfH194e7ubtBmSUkJDh48aFKbRERE1HqZfL4/NjYWUVFRCA0NRd++fZGcnIyysjJER0cDAKZOnQovLy8kJSUBuDfeJDQ0VHp53q5du/DJJ5/ggw8+AHDvdP7s2bOxZMkSBAQEwNfXFwsWLICnp6f0BmIiIiJq20xOWCZOnIiioiIsXLgQ+fn5CA4ORlpamjRo9tKlSwZjCcrKyjBjxgxcuXIFtra26NGjB/773/9i4sSJUp25c+eirKwML730EoqLizFw4ECkpaVBrVY3wiYSERFRS2fWiMqYmBjExMQYnffgm4CXLFmCJUuW1NqeIAhYtGgRFi1aZE44RERE1MrxHl8iIiKSPSYsREREJHtMWIiIiEj2mLAQERGR7DFhISIiItljwkJERESyx4SFiIiIZI8JCxEREckeExYiIiKSPSYsREREJHtMWIiIiEj2zHqXEFFr8Xb6OaPlglgFXwBr9v4CUVA2a0yvRnRv1vUREbUEPMNCREREsseEhYiIiGSPCQsRERHJHhMWIiIikj0mLERERCR7TFiIiIhI9piwEBERkewxYSEiIiLZY8JCREREsseEhYiIiGSPCQsRERHJHhMWIiIikj0mLERERCR7TFiIiIhI9piwEBERkewxYSEiIiLZY8JCREREsseEhYiIiGSPCQsRERHJHhMWIiIikj0rSwdARIbeTj9nsXUfLrkhfbaCAuOtPC0WCxHR/XiGhYiIiGSPCQsRERHJHhMWIiIikj0mLERERCR7TFiIiIhI9sxKWNasWQMfHx+o1WqEhYXh0KFDNdb96KOPMGjQILRv3x7t27dHeHh4tfrTpk2DIAgGPyNHjjQnNCIiImqFTE5YUlNTERsbi/j4eBw+fBhBQUGIjIxEYWGh0fqZmZmYPHky9u7di6ysLHh7e2PEiBG4evWqQb2RI0ciLy9P+vn000/N2yIiIiJqdUxOWFavXo3p06cjOjoaPXv2xLp162BnZ4cNGzYYrb9lyxbMmDEDwcHB6NGjB9avXw+dToeMjAyDejY2NnB3d5d+2rdvb94WERERUatj0oPjKioqkJ2djbi4OKlMoVAgPDwcWVlZ9WqjvLwcWq0WLi4uBuWZmZno1KkT2rdvj+HDh2PJkiXo0KGD0TY0Gg00Go00XVJSAgDQarXQarWmbJIB/bINaaM1aEv9IIhVtZbXNL+1srrvbxj9Z0EnWCociaX2RYPfBVGmQ/6aoW9a5f8JJn6f2j/qa5t6P5BxHzfFfmBKW4IoimJ9K1+7dg1eXl7Yv38/+vfvL5XPnTsX3333HQ4ePFhnGzNmzMDu3btx8uRJqNVqAMDWrVthZ2cHX19fXLhwAfPnz4eDgwOysrKgVCqrtZGQkIDExMRq5SkpKbCzs6vv5hAREZEFlZeX47nnnsOtW7fg5ORUa91mfTT/smXLsHXrVmRmZkrJCgBMmjRJ+ty7d28EBgbC398fmZmZeOKJJ6q1ExcXh9jYWGm6pKREGhtT1wbXRqvVIj09HREREVCpVGa309K1pX5Ys/cXo+WCWAWfuxeQq/aHKFRPmluro7e3SZ+toMAYq57I88qDqKj33zVN4sXeL1pkvQa/Cwfes0gMdRoUW3edBmqV/yfsW21Sda2oQHppd0Q4nINK0DVRUPKm7fePRt8P9FdI6sOkhMXV1RVKpRIFBQUG5QUFBXB3d6912ZUrV2LZsmX45ptvEBgYWGtdPz8/uLq64pdffjGasNjY2MDGxqZauUqlapRObKx2Wrq20A91JSOioGxTCUslqv9HLCpEiycslt4PVSqVfA9Szdg3rer/BDO/T5Wgk+++0NT++O4bcz8wpR2TLsZZW1sjJCTEYMCsfgDt/ZeIHrRixQosXrwYaWlpCA0NrXM9V65cwY0bN+Dh4WFKeERERNRKmTx6KDY2Fh999BE2b96M06dP45VXXkFZWRmio6MBAFOnTjUYlLt8+XIsWLAAGzZsgI+PD/Lz85Gfn4/S0lIAQGlpKebMmYMDBw4gNzcXGRkZGDduHLp164bIyMhG2kwiIiJqyUwewzJx4kQUFRVh4cKFyM/PR3BwMNLS0uDm5gYAuHTpEhSKP/OgDz74ABUVFXjmmWcM2omPj0dCQgKUSiWOHTuGzZs3o7i4GJ6enhgxYgQWL15s9LIPERERtT1mDbqNiYlBTEyM0XmZmZkG07m5ubW2ZWtri927d5sTBhEREbURMn2wABEREdGfmLAQERGR7DFhISIiItljwkJERESyx4SFiIiIZI8JCxEREckeExYiIiKSPSYsREREJHtMWIiIiEj2mLAQERGR7DFhISIiItljwkJERESyx4SFiIiIZI8JCxEREckeExYiIiKSPSYsREREJHtMWIiIiEj2mLAQERGR7DFhISIiItmzsnQA1LqszVlr6RCqmRE8w9IhUCu0tviYpUO4pxl+5wSdAE94Yv3x9RAVosnL83eQGgPPsBAREZHsMWEhIiIi2WPCQkRERLLHhIWIiIhkjwkLERERyR4TFiIiIpI93tZMzSrrwo1mX6em6Fyzr7M1OXTxJiqhs2gM93+Hr0Z0t2AkRGQpPMNCREREsseEhYiIiGSPl4SIiKhpXNx379/fb1k2DmoVeIaFiIiIZI8JCxEREckeExYiIiKSPSYsREREJHtMWIiIiEj2mLAQERGR7DFhISIiItljwkJERESyx4SFiIiIZM+shGXNmjXw8fGBWq1GWFgYDh06VGPdjz76CIMGDUL79u3Rvn17hIeHV6sviiIWLlwIDw8P2NraIjw8HOfPnzcnNCIiImqFTE5YUlNTERsbi/j4eBw+fBhBQUGIjIxEYWGh0fqZmZmYPHky9u7di6ysLHh7e2PEiBG4evWqVGfFihV49913sW7dOhw8eBD29vaIjIzE3bt3zd8yIiIiajVMfpfQ6tWrMX36dERHRwMA1q1bh6+//hobNmzAvHnzqtXfsmWLwfT69evx5ZdfIiMjA1OnToUoikhOTsabb76JcePGAQA+/vhjuLm5YceOHZg0aVK1NjUaDTQajTRdUlICANBqtdBqtaZukkS/bEPaaA0a0g+CTqh1vpUFrkIKYpXZy5izbEt2//ej/2yJ7+xB938Pzfn7afC7IBr2gyCXV7HV8TvXGPS/13X9fld3r4+0ouX3oYbSb0Nr2BZzNcUx0pS2BFEUxfpWrqiogJ2dHb744guMHz9eKo+KikJxcTG++uqrOtu4ffs2OnXqhM8//xxjxozBr7/+Cn9/fxw5cgTBwcFSvSFDhiA4OBjvvPNOtTYSEhKQmJhYrTwlJQV2dnb13RwiIiKyoPLycjz33HO4desWnJycaq1r0p8I169fR1VVFdzc3AzK3dzccObMmXq18frrr8PT0xPh4eEAgPz8fKmNB9vUz3tQXFwcYmNjpemSkhLpUlNdG1wbrVaL9PR0REREQKVSmd1OS9eQflh/fH2t8w9dvNmQ0MwS5PiUycsIYhV87l5ArtofoqBsgqjk6ejtbdJnKygwxqondlaeQiV0FozK8DucOaxbs63X4HfhwHsG89bfOtFscdSq64AmX4WgE+Bx1QN5XnkQFfX+Gxf4bT8A4EXnXk0UWfPRigqkl3ZHhMM5qATL/j5YirbfPxr9GKm/QlIfzXpOc9myZdi6dSsyMzOhVqvNbsfGxgY2NjbVylUqVaN0YmO109KZ0w91/WdmiQNfQxIOUVC2qYTF2PdTCZ3FE5b7vwNL/G6qVKpqBykRlc0eh1GmJBANJCpE0xKWP/qoNR3gVYKuVW2PSf743WvMY6Qp7Zh0Mc7V1RVKpRIFBQUG5QUFBXB3d6912ZUrV2LZsmXYs2cPAgMDpXL9cua0SURERG2DSQmLtbU1QkJCkJGRIZXpdDpkZGSgf//+NS63YsUKLF68GGlpaQgNDTWY5+vrC3d3d4M2S0pKcPDgwVrbJCIiorbD5EtCsbGxiIqKQmhoKPr27Yvk5GSUlZVJdw1NnToVXl5eSEpKAgAsX74cCxcuREpKCnx8fKRxKQ4ODnBwcIAgCJg9ezaWLFmCgIAA+Pr6YsGCBfD09DQY2EtERERtl8kJy8SJE1FUVISFCxciPz8fwcHBSEtLkwbNXrp0CQrFnyduPvjgA1RUVOCZZ54xaCc+Ph4JCQkAgLlz56KsrAwvvfQSiouLMXDgQKSlpTVonAsRERG1HmYNuo2JiUFMTIzReZmZmQbTubm5dbYnCAIWLVqERYsWmRMOERERtXJt9wk4RERE1GIwYSEiIiLZY8JCREREsseEhYiIiGSPCQsRERHJHhMWIiIikj0mLERERCR7TFiIiIhI9piwEBERkewxYSEiIiLZY8JCREREsseEhYiIiGSPCQsRERHJnllvayYiak6HS1Klz2tzOjTbegWdAE94Yv3x9RCLjzXbeomoOp5hISIiItljwkJERESyx4SFiIiIZI8JCxEREckeExYiIiKSPd4lRETUAJeL71h0/Vcu3GjydVhBgfFWnjh08SYqoav3cp1L/uibdk0TF7UtPMNCREREsseEhYiIiGSPCQsRERHJHhMWIiIikj0mLERERCR7TFiIiIhI9piwEBERkewxYSEiIiLZY8JCREREsseEhYiIiGSPCQsRERHJHhMWIiIikj0mLERERCR7TFiIiIhI9piwEBERkewxYSEiIiLZY8JCREREsseEhYiIiGTPrIRlzZo18PHxgVqtRlhYGA4dOlRj3ZMnT+Lpp5+Gj48PBEFAcnJytToJCQkQBMHgp0ePHuaERkRERK2QyQlLamoqYmNjER8fj8OHDyMoKAiRkZEoLCw0Wr+8vBx+fn5YtmwZ3N3da2z3kUceQV5envTzww8/mBoaERERtVJWpi6wevVqTJ8+HdHR0QCAdevW4euvv8aGDRswb968avUfe+wxPPbYYwBgdL4UiJVVrQlNS/B2+jlLh1DNqxHdLR0CERFRg5mUsFRUVCA7OxtxcXFSmUKhQHh4OLKyshoUyPnz5+Hp6Qm1Wo3+/fsjKSkJXbp0MVpXo9FAo9FI0yUlJQAArVYLrVZrdgz6Zc1tQxCrzF53UzFnWxrSD4JOqHW+lQWGTZnzveiXkeN32pTu/370ny3xndWmrn2sKdZ171/j/10qoWq2eIxpju/H3H1B3zdaUV77kDn029AatsVcDT1G1tZmfQiiKIr1rXzt2jV4eXlh//796N+/v1Q+d+5cfPfddzh48GCty/v4+GD27NmYPXu2Qfn//vc/lJaW4qGHHkJeXh4SExNx9epVnDhxAo6OjtXaSUhIQGJiYrXylJQU2NnZ1XdziIiIyILKy8vx3HPP4datW3Bycqq1rsmXhJrCqFGjpM+BgYEICwtD165d8dlnn+GFF16oVj8uLg6xsbHSdElJCby9vTFixIg6N7g2Wq0W6enpiIiIgEpl+l9Na/b+Yva6m8rMYd1MXqYh/bD++Ppa5x+6eNPkeBoqyPEpk5cRxCr43L2AXLU/REHZBFHJ09Hb26TPVlBgjFVP7Kw8hUroLBiVob6+Ls22LkEnwOOqB/K88iBe/tFonavFd5stHqPrdwpu8nWYuy94leQAAOK7hjZRZM1HKyqQXtodEQ7noBLk8/vQnLT9/tGgY6Qx+isk9WFSwuLq6gqlUomCggKD8oKCgkYdf9KuXTt0794dv/xiPAGwsbGBjY1NtXKVStUonWhuO3I8sDWkP8zpB1FR+wk7Sxz4GvK9iIJSlt9rUzH2/VRCJ6uEpa59rKnWKaLS6LwqNN7pcXM053dj6r6g75vWdIBXCbpWtT0m+eN40FjHWn1b9WXSxThra2uEhIQgIyNDKtPpdMjIyDC4RNRQpaWluHDhAjw8PBqtTSIiImq5TL4kFBsbi6ioKISGhqJv375ITk5GWVmZdNfQ1KlT4eXlhaSkJAD3BuqeOnVK+nz16lXk5OTAwcEB3brdu1zx2muvYezYsejatSuuXbuG+Ph4KJVKTJ48ubG2k4iIiFowkxOWiRMnoqioCAsXLkR+fj6Cg4ORlpYGNzc3AMClS5egUPx54ubatWt49NFHpemVK1di5cqVGDJkCDIzMwEAV65cweTJk3Hjxg107NgRAwcOxIEDB9CxY8cGbh4RERG1BmYNuo2JiUFMTIzRefokRM/Hxwd13Yi0detWc8IgIiKiNkIWdwkRmatzSXaddfoV3zK5XZ2gxPUOA/DYlU1QtMJnsRzo8pKlQyAiMknbfQIOERERtRhMWIiIiEj2mLAQERGR7DFhISIiItljwkJERESyx4SFiIiIZI8JCxEREcken8NCrd7/pzD9LdpKqPAYBuB/wq+oEprm5XZP6kx/kzYRUVvFMyxEREQke0xYiIiISPaYsBAREZHscQwLERE1qaxfb1g6BAP9/TpYOgQyA8+wEBERkewxYSEiIiLZY8JCREREsseEhYiIiGSPCQsRERHJHhMWIiIikj0mLERERCR7TFiIiIhI9piwEBERkewxYSEiIiLZY8JCREREsseEhYiIiGSPCQsRERHJHhMWIiIikj0mLERERCR7TFiIiIhI9piwEBERkexZWToAorbq/1P8YrF1XylJtdi6qe2x5L5uzJHiq5jRLtDSYZCJeIaFiIiIZI8JCxEREckeExYiIiKSPY5hIaIWJevCjWZblxUUGG/liUMXb8K95E6zrZeIquMZFiIiIpI9JixEREQke0xYiIiISPbMSljWrFkDHx8fqNVqhIWF4dChQzXWPXnyJJ5++mn4+PhAEAQkJyc3uE0iIiJqW0xOWFJTUxEbG4v4+HgcPnwYQUFBiIyMRGFhodH65eXl8PPzw7Jly+Du7t4obRIREVHbYnLCsnr1akyfPh3R0dHo2bMn1q1bBzs7O2zYsMFo/cceewxvvfUWJk2aBBsbm0Zpk4iIiNoWk25rrqioQHZ2NuLi4qQyhUKB8PBwZGVlmRWAOW1qNBpoNBppuqSkBACg1Wqh1WrNikO//P3/mkoQq8xed1MxZ1sa0g+CTqh1vlUjD5tSQtWo7ekp/mhX0UTtW1p9vgd9ncb+zlqS+/ugqfa1hmqO78fcfUGufSbAClrRtG3R1zd1udakocfI2tqsD5MSluvXr6Oqqgpubm4G5W5ubjhz5owpTTWozaSkJCQmJlYr37NnD+zs7MyK437p6elmLefb4DU3vl27zpm9rDn94AnPWuePt6p9vslcejVuew8IcZnYpO1bymMm1B1j1bPJ4mgpxlj1BFzk2Q+mfJcNZfK+0MS/nw2x67Z5y6WXdm/cQFqSP44J5h4jjSkvL6933Rb54Li4uDjExsZK0yUlJfD29saIESPg5ORkdrtarRbp6emIiIiASmX6XwZr9srrBV8AMHNYN5OXaUg/rD++vtb5hy7eNDme2niV5DRqe3oKqBDiMhHZN1OhQ+P9NSEXV52C66xjBQXGWPXEzspTqISu6YOSofv7wK3ksKXDMao+32VDmbsvNNXvZ0N5tVPjRWfTkimtqEB6aXdEOJyDSmibvw/afv9o0DHSGP0VkvowKWFxdXWFUqlEQUGBQXlBQUGNA2qbok0bGxuj42FUKlWjdKK57YiCssHrbmwN6Q9z+kFUiLXOb+wDX1UTJxM6aJt8HZZgyvdQCV2bTVj0KqGT7X7QnN+NqfuCXPtMhJXZSYdK0LXZhAV/HA8a61irb6u+TLoYZ21tjZCQEGRkZEhlOp0OGRkZ6N+/vylNNWmbRERE1LqYfEkoNjYWUVFRCA0NRd++fZGcnIyysjJER0cDAKZOnQovLy8kJSUBuDeo9tSpU9Lnq1evIicnBw4ODujWrVu92iQiIqK2zeSEZeLEiSgqKsLChQuRn5+P4OBgpKWlSYNmL126BIXizxM3165dw6OPPipNr1y5EitXrsSQIUOQmZlZrzaJiIiobTNr0G1MTAxiYmKMztMnIXo+Pj4QxdrHNdTVJhEREbVtbfeGciIiImoxmLAQERGR7DFhISIiItlrkQ+Oa25rc9bWq97hkhtNHInp1uZ0MHkZQSfAE55Yf3x9nc9VISIiag5MWIhINjqXZFs6BANKqACXXvAqyYH83hR2T3P0mWE/yPNhcNT68ZIQERERyR7PsLRyWRdMv0xlBQXGW3ni0MWbbf6R7EREJA88w0JERESyx4SFiIiIZI8JCxEREckeExYiIiKSPSYsREREJHtMWIiIiEj2mLAQERGR7DFhISIiItljwkJERESyx4SFiIiIZI8JCxEREckeExYiIiKSPSYsREREJHtMWIiIiEj2rCwdABE1v84l2XXWUUIFuPSCV0kOqqBthqiIiGrGMyxEREQke0xYiIiISPaYsBAREZHsMWEhIiIi2WPCQkRERLLHhIWIiIhkjwkLERERyR4TFiIiIpI9JixEREQke0xYiIiISPaYsBAREZHsMWEhIiIi2WPCQkRERLLHhIWIiIhkjwkLERERyR4TFiIiIpI9sxKWNWvWwMfHB2q1GmFhYTh06FCt9T///HP06NEDarUavXv3xq5duwzmT5s2DYIgGPyMHDnSnNCIiIioFbIydYHU1FTExsZi3bp1CAsLQ3JyMiIjI3H27Fl06tSpWv39+/dj8uTJSEpKwpgxY5CSkoLx48fj8OHD6NWrl1Rv5MiR2LhxozRtY2Nj5iZRU+hckm3pEIiIqA0z+QzL6tWrMX36dERHR6Nnz55Yt24d7OzssGHDBqP133nnHYwcORJz5szBww8/jMWLF6NPnz54//33DerZ2NjA3d1d+mnfvr15W0REREStjklnWCoqKpCdnY24uDipTKFQIDw8HFlZWUaXycrKQmxsrEFZZGQkduzYYVCWmZmJTp06oX379hg+fDiWLFmCDh06GG1To9FAo9FI0yUlJQAArVYLrVZryiYZ0C/7YBuCTqjX8latZEiQfjvu3x4lVJYKxyIUf2yvoo1t9/3YB+wDvdbWDwKsoBVN+/9aX9/U5VqTmo6RjdFmfZiUsFy/fh1VVVVwc3MzKHdzc8OZM2eMLpOfn2+0fn5+vjQ9cuRIPPXUU/D19cWFCxcwf/58jBo1CllZWVAqldXaTEpKQmJiYrXyPXv2wM7OzpRNMio9Pd1g2hOe9VpuvFX96rUUY6x6/jnh0qvmiq1YiMtES4dgcewD9oFea+qHXbfNWy69tHvjBtKS/HFsfPAY2RDl5eX1rmvyGJamMGnSJOlz7969ERgYCH9/f2RmZuKJJ56oVj8uLs7grE1JSQm8vb0xYsQIODk5mR2HVqtFeno6IiIioFL9+ZfE+uPr67X8oYs3zV63nFhBgTFWPbGz8hQqoQMAeJXkWDaoZqaACiEuE5F9MxU6NN5fEy0J+4B9oNfa+sGrnRovOpv2R5hWVCC9tDsiHM5BJeiaKDJ50/b7h9FjZEPor5DUh0kJi6urK5RKJQoKCgzKCwoK4O7ubnQZd3d3k+oDgJ+fH1xdXfHLL78YTVhsbGyMDspVqVSN0okPtiMqxHotpz+4txaV0EnbVNUK/pMyhw7aNrvteuwD9oFea+kHEVZmJx0qQddmExb8cVxsrGOtvq36MulinLW1NUJCQpCRkSGV6XQ6ZGRkoH///kaX6d+/v0F94N7ppJrqA8CVK1dw48YNeHh4mBIeERERtVImjx6KjY3FRx99hM2bN+P06dN45ZVXUFZWhujoaADA1KlTDQblzpo1C2lpaVi1ahXOnDmDhIQE/Pzzz4iJiQEAlJaWYs6cOThw4AByc3ORkZGBcePGoVu3boiMjGykzSQiIqKWzOQxLBMnTkRRUREWLlyI/Px8BAcHIy0tTRpYe+nSJSgUf+ZBAwYMQEpKCt58803Mnz8fAQEB2LFjh/QMFqVSiWPHjmHz5s0oLi6Gp6cnRowYgcWLF/NZLERERATAzEG3MTEx0hmSB2VmZlYre/bZZ/Hss88arW9ra4vdu3ebEwYRERG1EW33hnIiIiJqMZiwEBERkewxYSEiIiLZY8JCREREsseEhYiIiGSPCQsRERHJHhMWIiIikj0mLERERCR7TFiIiIhI9piwEBERkewxYSEiIiLZY8JCREREsseEhYiIiGSPCQsRERHJHhMWIiIikj0mLERERCR7TFiIiIhI9piwEBERkewxYSEiIiLZY8JCREREsseEhYiIiGSPCQsRERHJHhMWIiIikj0mLERERCR7TFiIiIhI9piwEBERkewxYSEiIiLZY8JCREREsseEhYiIiGSPCQsRERHJHhMWIiIikj0mLERERCR7TFiIiIhI9piwEBERkewxYSEiIiLZY8JCREREsseEhYiIiGSPCQsRERHJnlkJy5o1a+Dj4wO1Wo2wsDAcOnSo1vqff/45evToAbVajd69e2PXrl0G80VRxMKFC+Hh4QFbW1uEh4fj/Pnz5oRGRERErZDJCUtqaipiY2MRHx+Pw4cPIygoCJGRkSgsLDRaf//+/Zg8eTJeeOEFHDlyBOPHj8f48eNx4sQJqc6KFSvw7rvvYt26dTh48CDs7e0RGRmJu3fvmr9lRERE1GqYnLCsXr0a06dPR3R0NHr27Il169bBzs4OGzZsMFr/nXfewciRIzFnzhw8/PDDWLx4Mfr06YP3338fwL2zK8nJyXjzzTcxbtw4BAYG4uOPP8a1a9ewY8eOBm0cERERtQ5WplSuqKhAdnY24uLipDKFQoHw8HBkZWUZXSYrKwuxsbEGZZGRkVIycvHiReTn5yM8PFya7+zsjLCwMGRlZWHSpEnV2tRoNNBoNNL0rVu3AAA3b96EVqs1ZZMMaLValJeX48aNG1CpVFL53ZL6nenRlVeYvW450UGBcqty6CoroIMOAFB5x8JBNTMdgPLycmjv4I8eaHvYB+wDvdbWD3etdbhhZdr/11pRce/4IFRAJbSGXjCd9sYNo8fIhrh9+zaAeycv6mJSwnL9+nVUVVXBzc3NoNzNzQ1nzpwxukx+fr7R+vn5+dJ8fVlNdR6UlJSExMTEauW+vr712xCq08eWDkAWvrB0ADLAPmAf6LWufviXpQNokRKarOXbt2/D2dm51jomJSxyERcXZ3DWRqfT4ebNm+jQoQMEQTC73ZKSEnh7e+Py5ctwcnJqjFBbJPYD+wBgHwDsAz32A/sAaJo+EEURt2/fhqenZ511TUpYXF1doVQqUVBQYFBeUFAAd3d3o8u4u7vXWl//b0FBATw8PAzqBAcHG23TxsYGNjY2BmXt2rUzZVNq5eTk1GZ3yPuxH9gHAPsAYB/osR/YB0Dj90FdZ1b0TBp0a21tjZCQEGRkZEhlOp0OGRkZ6N+/v9Fl+vfvb1AfANLT06X6vr6+cHd3N6hTUlKCgwcP1tgmERERtS0mXxKKjY1FVFQUQkND0bdvXyQnJ6OsrAzR0dEAgKlTp8LLywtJSUkAgFmzZmHIkCFYtWoVRo8eja1bt+Lnn3/Gv//9bwCAIAiYPXs2lixZgoCAAPj6+mLBggXw9PTE+PHjG29LiYiIqMUyOWGZOHEiioqKsHDhQuTn5yM4OBhpaWnSoNlLly5BofjzxM2AAQOQkpKCN998E/Pnz0dAQAB27NiBXr16SXXmzp2LsrIyvPTSSyguLsbAgQORlpYGtVrdCJtYfzY2NoiPj692uamtYT+wDwD2AcA+0GM/sA8Ay/eBINbnXiIiIiIiC+K7hIiIiEj2mLAQERGR7DFhISIiItljwkJERESyx4SFiIiIZI8JSy2efPJJdOnSBWq1Gh4eHvjb3/6Ga9euWTqsZpObm4sXXngBvr6+sLW1hb+/P+Lj41FR0Tpe8lhfS5cuxYABA2BnZ9eoT1SWuzVr1sDHxwdqtRphYWE4dOiQpUNqVt9//z3Gjh0LT09PCILQ5t4en5SUhMceewyOjo7o1KkTxo8fj7Nnz1o6rGb3wQcfIDAwUHq6a//+/fG///3P0mFZ1LJly6RnqDUnJiy1GDZsGD777DOcPXsWX375JS5cuIBnnnnG0mE1mzNnzkCn0+HDDz/EyZMn8fbbb2PdunWYP3++pUNrVhUVFXj22WfxyiuvWDqUZpOamorY2FjEx8fj8OHDCAoKQmRkJAoLCy0dWrMpKytDUFAQ1qxZY+lQLOK7777DzJkzceDAAaSnp0Or1WLEiBEoKyuzdGjNqnPnzli2bBmys7Px888/Y/jw4Rg3bhxOnjxp6dAs4qeffsKHH36IwMDA5l+5SPX21VdfiYIgiBUVFZYOxWJWrFgh+vr6WjoMi9i4caPo7Oxs6TCaRd++fcWZM2dK01VVVaKnp6eYlJRkwagsB4C4fft2S4dhUYWFhSIA8bvvvrN0KBbXvn17cf369ZYOo9ndvn1bDAgIENPT08UhQ4aIs2bNatb18wxLPd28eRNbtmzBgAEDoFKpLB2Oxdy6dQsuLi6WDoOaUEVFBbKzsxEeHi6VKRQKhIeHIysry4KRkSXdunULANr0739VVRW2bt2KsrKyNvmuu5kzZ2L06NEG/zc0JyYsdXj99ddhb2+PDh064NKlS/jqq68sHZLF/PLLL3jvvffw97//3dKhUBO6fv06qqqqpNdt6Lm5uSE/P99CUZEl6XQ6zJ49G48//rjBa1XaiuPHj8PBwQE2NjZ4+eWXsX37dvTs2dPSYTWrrVu34vDhw9J7Ai2hzSUs8+bNgyAItf6cOXNGqj9nzhwcOXIEe/bsgVKpxNSpUyG28LcZmNoHAHD16lWMHDkSzz77LKZPn26hyBuPOX1A1FbNnDkTJ06cwNatWy0dikU89NBDyMnJwcGDB/HKK68gKioKp06dsnRYzeby5cuYNWsWtmzZ0uzv+Ltfm3uXUFFREW7cuFFrHT8/P1hbW1crv3LlCry9vbF///4WfTrQ1D64du0ahg4din79+mHTpk0GL7dsqczZDzZt2oTZs2ejuLi4iaOzrIqKCtjZ2eGLL74weGN6VFQUiouL2+RZRkEQsH379jb5BvmYmBh89dVX+P777+Hr62vpcGQhPDwc/v7++PDDDy0dSrPYsWMH/vrXv0KpVEplVVVVEAQBCoUCGo3GYF5TMfltzS1dx44d0bFjR7OW1el0AACNRtOYITU7U/rg6tWrGDZsGEJCQrBx48ZWkawADdsPWjtra2uEhIQgIyNDOkDrdDpkZGQgJibGssFRsxFFEf/4xz+wfft2ZGZmMlm5j06na/HHAVM88cQTOH78uEFZdHQ0evTogddff71ZkhWgDSYs9XXw4EH89NNPGDhwINq3b48LFy5gwYIF8Pf3b9FnV0xx9epVDB06FF27dsXKlStRVFQkzXN3d7dgZM3r0qVLuHnzJi5duoSqqirk5OQAALp16wYHBwfLBtdEYmNjERUVhdDQUPTt2xfJyckoKytDdHS0pUNrNqWlpfjll1+k6YsXLyInJwcuLi7o0qWLBSNrHjNnzkRKSgq++uorODo6SuOXnJ2dYWtra+Homk9cXBxGjRqFLl264Pbt20hJSUFmZiZ2795t6dCajaOjY7WxS/qxnc06pqlZ70lqQY4dOyYOGzZMdHFxEW1sbEQfHx/x5ZdfFq9cuWLp0JrNxo0bRQBGf9qSqKgoo32wd+9eS4fWpN577z2xS5cuorW1tdi3b1/xwIEDlg6pWe3du9fo9x4VFWXp0JpFTb/7GzdutHRozer5558Xu3btKlpbW4sdO3YUn3jiCXHPnj2WDsviLHFbc5sbw0JEREQtT+sYkEBEREStGhMWIiIikj0mLERERCR7TFiIiIhI9piwEBERkewxYSEiIiLZY8JCREREsseEhYiIiGSPCQsRERHJHhMWIiIikj0mLERERCR7/z9HhY5nYwKkDgAAAABJRU5ErkJggg==", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "for ctrl_s in ctrl_str:\n", + " plt.hist(\n", + " [r for r, t in zip(logs[\"env/reward_dist\"], task_list) if t == ctrl_s], density=True, alpha=0.5, label=ctrl_s\n", + " )\n", + "plt.legend(loc=\"best\")\n", + "plt.title(\"reward distribution\")\n", + "plt.grid(True)\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Save model\n", + "Finally, we save the model to disk for later usage." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "gpt2_model.save_pretrained(\"gpt2-imdb-ctrl\")\n", + "gpt2_tokenizer.save_pretrained(\"gpt2-imdb-ctrl\")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "trl", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.12" + }, + "vscode": { + "interpreter": { + "hash": "d2cfb53525227c89f8d14fa784301fa46c451cc9223d94ccce9e17956835eea2" + } + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/trl/trl/examples/notebooks/gpt2-sentiment.ipynb b/trl/trl/examples/notebooks/gpt2-sentiment.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..03d86e385f3b6e0f69452a4540ea1d1e0aed8799 --- /dev/null +++ b/trl/trl/examples/notebooks/gpt2-sentiment.ipynb @@ -0,0 +1,879 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Tune GPT2 to generate positive reviews\n", + "> Optimise GPT2 to produce positive IMDB movie reviews using a BERT sentiment classifier as a reward function." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "
\n", + "\n", + "

Figure: Experiment setup to tune GPT2. The yellow arrows are outside the scope of this notebook, but the trained models are available through Hugging Face.

\n", + "
\n", + "\n", + "\n", + "In this notebook we fine-tune GPT2 (small) to generate positive movie reviews based on the IMDB dataset. The model gets the start of a real review and is tasked to produce positive continuations. To reward positive continuations we use a BERT classifier to analyse the sentiment of the produced sentences and use the classifier's outputs as rewards signals for PPO training." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Setup experiment" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Import dependencies" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%load_ext autoreload\n", + "%autoreload 2" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%pip install transformers trl wandb" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "from tqdm import tqdm\n", + "import pandas as pd\n", + "\n", + "tqdm.pandas()\n", + "\n", + "from transformers import pipeline, AutoTokenizer\n", + "from datasets import load_dataset\n", + "\n", + "from trl import PPOTrainer, PPOConfig, AutoModelForCausalLMWithValueHead\n", + "from trl.core import LengthSampler" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Configuration" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "config = PPOConfig(\n", + " model_name=\"lvwerra/gpt2-imdb\",\n", + " learning_rate=1.41e-5,\n", + " log_with=\"wandb\",\n", + ")\n", + "\n", + "sent_kwargs = {\"return_all_scores\": True, \"function_to_apply\": \"none\", \"batch_size\": 16}" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import wandb\n", + "\n", + "wandb.init()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "You can see that we load a GPT2 model called `gpt2_imdb`. This model was additionally fine-tuned on the IMDB dataset for 1 epoch with the huggingface [script](https://github.com/huggingface/transformers/blob/master/examples/run_language_modeling.py) (no special settings). The other parameters are mostly taken from the original paper [\"Fine-Tuning Language Models from Human Preferences\"](\n", + "https://arxiv.org/pdf/1909.08593.pdf). This model as well as the BERT model is available in the Huggingface model zoo [here](https://huggingface.co/models). The following code should automatically download the models." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Load data and models" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Load IMDB dataset\n", + "The IMDB dataset contains 50k movie review annotated with \"positive\"/\"negative\" feedback indicating the sentiment. We load the IMDB dataset into a DataFrame and filter for comments that are at least 200 characters. Then we tokenize each text and cut it to random size with the `LengthSampler`." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Reusing dataset imdb (/home/leandro/.cache/huggingface/datasets/imdb/plain_text/1.0.0/2fdd8b9bcadd6e7055e742a706876ba43f19faee861df134affd7a3f60fc38a1)\n", + "Loading cached processed dataset at /home/leandro/.cache/huggingface/datasets/imdb/plain_text/1.0.0/2fdd8b9bcadd6e7055e742a706876ba43f19faee861df134affd7a3f60fc38a1/cache-ff455473e884c6a3.arrow\n" + ] + } + ], + "source": [ + "def build_dataset(config, dataset_name=\"imdb\", input_min_text_length=2, input_max_text_length=8):\n", + " \"\"\"\n", + " Build dataset for training. This builds the dataset from `load_dataset`, one should\n", + " customize this function to train the model on its own dataset.\n", + "\n", + " Args:\n", + " dataset_name (`str`):\n", + " The name of the dataset to be loaded.\n", + "\n", + " Returns:\n", + " dataloader (`torch.utils.data.DataLoader`):\n", + " The dataloader for the dataset.\n", + " \"\"\"\n", + " tokenizer = AutoTokenizer.from_pretrained(config.model_name)\n", + " tokenizer.pad_token = tokenizer.eos_token\n", + " # load imdb with datasets\n", + " ds = load_dataset(dataset_name, split=\"train\")\n", + " ds = ds.rename_columns({\"text\": \"review\"})\n", + " ds = ds.filter(lambda x: len(x[\"review\"]) > 200, batched=False)\n", + "\n", + " input_size = LengthSampler(input_min_text_length, input_max_text_length)\n", + "\n", + " def tokenize(sample):\n", + " sample[\"input_ids\"] = tokenizer.encode(sample[\"review\"])[: input_size()]\n", + " sample[\"query\"] = tokenizer.decode(sample[\"input_ids\"])\n", + " return sample\n", + "\n", + " ds = ds.map(tokenize, batched=False)\n", + " ds.set_format(type=\"torch\")\n", + " return ds" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "dataset = build_dataset(config)\n", + "\n", + "\n", + "def collator(data):\n", + " return dict((key, [d[key] for d in data]) for key in data[0])" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Load pre-trained GPT2 language models" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We load the GPT2 model with a value head and the tokenizer. We load the model twice; the first model is optimized while the second model serves as a reference to calculate the KL-divergence from the starting point. This serves as an additional reward signal in the PPO training to make sure the optimized model does not deviate too much from the original language model." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "model = AutoModelForCausalLMWithValueHead.from_pretrained(config.model_name)\n", + "ref_model = AutoModelForCausalLMWithValueHead.from_pretrained(config.model_name)\n", + "tokenizer = AutoTokenizer.from_pretrained(config.model_name)\n", + "\n", + "tokenizer.pad_token = tokenizer.eos_token" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Initialize PPOTrainer\n", + "The `PPOTrainer` takes care of device placement and optimization later on:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "ppo_trainer = PPOTrainer(config, model, ref_model, tokenizer, dataset=dataset, data_collator=collator)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Load BERT classifier\n", + "We load a BERT classifier fine-tuned on the IMDB dataset." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "device = ppo_trainer.accelerator.device\n", + "if ppo_trainer.accelerator.num_processes == 1:\n", + " device = 0 if torch.cuda.is_available() else \"cpu\" # to avoid a `pipeline` bug\n", + "sentiment_pipe = pipeline(\"sentiment-analysis\", model=\"lvwerra/distilbert-imdb\", device=device)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The model outputs are the logits for the negative and positive class. We will use the logits for positive class as a reward signal for the language model." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[[{'label': 'NEGATIVE', 'score': 2.335048198699951},\n", + " {'label': 'POSITIVE', 'score': -2.726576566696167}]]" + ] + }, + "execution_count": null, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "text = \"this movie was really bad!!\"\n", + "sentiment_pipe(text, **sent_kwargs)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[[{'label': 'NEGATIVE', 'score': -2.2947897911071777},\n", + " {'label': 'POSITIVE', 'score': 2.557039737701416}]]" + ] + }, + "execution_count": null, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "text = \"this movie was really good!!\"\n", + "sentiment_pipe(text, **sent_kwargs)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Generation settings\n", + "For the response generation we just use sampling and make sure top-k and nucleus sampling are turned off as well as a minimal length." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "gen_kwargs = {\"min_length\": -1, \"top_k\": 0.0, \"top_p\": 1.0, \"do_sample\": True, \"pad_token_id\": tokenizer.eos_token_id}" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Optimize model" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Training loop" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The training loop consists of the following main steps:\n", + "1. Get the query responses from the policy network (GPT-2)\n", + "2. Get sentiments for query/responses from BERT\n", + "3. Optimize policy with PPO using the (query, response, reward) triplet\n", + "\n", + "**Training time**\n", + "\n", + "This step takes **~2h** on a V100 GPU with the above specified settings." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "output_min_length = 4\n", + "output_max_length = 16\n", + "output_length_sampler = LengthSampler(output_min_length, output_max_length)\n", + "\n", + "\n", + "generation_kwargs = {\n", + " \"min_length\": -1,\n", + " \"top_k\": 0.0,\n", + " \"top_p\": 1.0,\n", + " \"do_sample\": True,\n", + " \"pad_token_id\": tokenizer.eos_token_id,\n", + "}\n", + "\n", + "\n", + "for epoch, batch in tqdm(enumerate(ppo_trainer.dataloader)):\n", + " query_tensors = batch[\"input_ids\"]\n", + "\n", + " #### Get response from gpt2\n", + " response_tensors = []\n", + " for query in query_tensors:\n", + " gen_len = output_length_sampler()\n", + " generation_kwargs[\"max_new_tokens\"] = gen_len\n", + " response = ppo_trainer.generate(query, **generation_kwargs)\n", + " response_tensors.append(response.squeeze()[-gen_len:])\n", + " batch[\"response\"] = [tokenizer.decode(r.squeeze()) for r in response_tensors]\n", + "\n", + " #### Compute sentiment score\n", + " texts = [q + r for q, r in zip(batch[\"query\"], batch[\"response\"])]\n", + " pipe_outputs = sentiment_pipe(texts, **sent_kwargs)\n", + " rewards = [torch.tensor(output[1][\"score\"]) for output in pipe_outputs]\n", + "\n", + " #### Run PPO step\n", + " stats = ppo_trainer.step(query_tensors, response_tensors, rewards)\n", + " ppo_trainer.log_stats(stats, batch, rewards)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Training progress\n", + "If you are tracking the training progress with Weights&Biases you should see a plot similar to the one below. Check out the interactive sample report on wandb.ai: [link](https://app.wandb.ai/huggingface/trl-showcase/runs/1jtvxb1m/).\n", + "\n", + "
\n", + "\n", + "

Figure: Reward mean and distribution evolution during training.

\n", + "
\n", + "\n", + "One can observe how the model starts to generate more positive outputs after a few optimisation steps.\n", + "\n", + "> Note: Investigating the KL-divergence will probably show that at this point the model has not converged to the target KL-divergence, yet. To get there would require longer training or starting with a higher initial coefficient." + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Model inspection\n", + "Let's inspect some examples from the IMDB dataset. We can use `model_ref` to compare the tuned model `model` against the model before optimisation." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/leandro/miniconda3/envs/trl/lib/python3.9/site-packages/transformers/pipelines/base.py:1075: UserWarning: You seem to be using the pipelines sequentially on GPU. In order to maximize efficiency please use a dataset\n", + " warnings.warn(\n" + ] + }, + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
queryresponse (before)response (after)rewards (before)rewards (after)
0Oh dear,what are I saying?! I fast-forwarded throughI must say that I are hanging my head on this-0.858954-1.007609
1I've seenit, as well.<brthree million dialogue throughout, and1.9968072.240883
2Hi:<br /><br/>This movie is a turkey though when it comes to/>I also like that movie. It's so funny-0.4381912.415630
3I'm a writerand I'm not going to be asked to, not a screenwriter. I've written-0.655991-0.724324
4If youabsolutely love sensitive romance, the plot a...are looking at the cinematography, the acting,2.2213090.148751
5OMG thiscasting cast. Obi cult breezy, this ismovie was totally wonderful, I it was the ide...-1.5331392.590190
6It'sunrealistic; the guy who was supposed to be E...a very good film. It reminds us about over-2.0970172.835831
7There is a reallyawful laptop game!<br /><br />I used tointeresting story that set us the journey. Th...-2.3417432.282939
8This ismy favorite part abouta well thought well2.5547942.734139
9Wasn'tWasn't it clichéd?<|endoftext|>anyone else interested in this movie? It's a ...-1.7908022.631960
10This film is another of director TimBurton's masterpiecesCurry's best bombs2.6229172.544106
11I thought this moviewas excellent. I actually laughed 6 times and...was perfect, and I believe it's almost overlo...2.5480222.601913
12This early John Waynefilms looked like an abandoned police beatingfilm is a realistic portrayal of what-1.7422792.609762
13I wasgiven an experience-a big one, almost 25very happy with all the reflections and this ...2.2507092.558540
14Embarrassingly, Iam more at a strict conformity after getting ...had never seen a movie before. There was one ...-2.021666-1.803383
15I am a fanof living on simple islands, and we have visi...of many things and learned how to appreciate ...1.7912972.324461
\n", + "
" + ], + "text/plain": [ + " query \\\n", + "0 Oh dear, \n", + "1 I've seen \n", + "2 Hi:

This movie is a turkey though when it comes to \n", + "3 and I'm not going to be asked to \n", + "4 absolutely love sensitive romance, the plot a... \n", + "5 casting cast. Obi cult breezy, this is \n", + "6 unrealistic; the guy who was supposed to be E... \n", + "7 awful laptop game!

I used to \n", + "8 my favorite part about \n", + "9 Wasn't it clichéd?<|endoftext|> \n", + "10 Burton's masterpieces \n", + "11 was excellent. I actually laughed 6 times and... \n", + "12 films looked like an abandoned police beating \n", + "13 given an experience-a big one, almost 25 \n", + "14 am more at a strict conformity after getting ... \n", + "15 of living on simple islands, and we have visi... \n", + "\n", + " response (after) rewards (before) \\\n", + "0 I must say that I are hanging my head on this -0.858954 \n", + "1 three million dialogue throughout, and 1.996807 \n", + "2 />I also like that movie. It's so funny -0.438191 \n", + "3 , not a screenwriter. I've written -0.655991 \n", + "4 are looking at the cinematography, the acting, 2.221309 \n", + "5 movie was totally wonderful, I it was the ide... -1.533139 \n", + "6 a very good film. It reminds us about over -2.097017 \n", + "7 interesting story that set us the journey. Th... -2.341743 \n", + "8 a well thought well 2.554794 \n", + "9 anyone else interested in this movie? It's a ... -1.790802 \n", + "10 Curry's best bombs 2.622917 \n", + "11 was perfect, and I believe it's almost overlo... 2.548022 \n", + "12 film is a realistic portrayal of what -1.742279 \n", + "13 very happy with all the reflections and this ... 2.250709 \n", + "14 had never seen a movie before. There was one ... -2.021666 \n", + "15 of many things and learned how to appreciate ... 1.791297 \n", + "\n", + " rewards (after) \n", + "0 -1.007609 \n", + "1 2.240883 \n", + "2 2.415630 \n", + "3 -0.724324 \n", + "4 0.148751 \n", + "5 2.590190 \n", + "6 2.835831 \n", + "7 2.282939 \n", + "8 2.734139 \n", + "9 2.631960 \n", + "10 2.544106 \n", + "11 2.601913 \n", + "12 2.609762 \n", + "13 2.558540 \n", + "14 -1.803383 \n", + "15 2.324461 " + ] + }, + "execution_count": null, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "#### get a batch from the dataset\n", + "bs = 16\n", + "game_data = dict()\n", + "dataset.set_format(\"pandas\")\n", + "df_batch = dataset[:].sample(bs)\n", + "game_data[\"query\"] = df_batch[\"query\"].tolist()\n", + "query_tensors = df_batch[\"input_ids\"].tolist()\n", + "\n", + "response_tensors_ref, response_tensors = [], []\n", + "\n", + "#### get response from gpt2 and gpt2_ref\n", + "for i in range(bs):\n", + " gen_len = output_length_sampler()\n", + " output = ref_model.generate(\n", + " torch.tensor(query_tensors[i]).unsqueeze(dim=0).to(device), max_new_tokens=gen_len, **gen_kwargs\n", + " ).squeeze()[-gen_len:]\n", + " response_tensors_ref.append(output)\n", + " output = model.generate(\n", + " torch.tensor(query_tensors[i]).unsqueeze(dim=0).to(device), max_new_tokens=gen_len, **gen_kwargs\n", + " ).squeeze()[-gen_len:]\n", + " response_tensors.append(output)\n", + "\n", + "#### decode responses\n", + "game_data[\"response (before)\"] = [tokenizer.decode(response_tensors_ref[i]) for i in range(bs)]\n", + "game_data[\"response (after)\"] = [tokenizer.decode(response_tensors[i]) for i in range(bs)]\n", + "\n", + "#### sentiment analysis of query/response pairs before/after\n", + "texts = [q + r for q, r in zip(game_data[\"query\"], game_data[\"response (before)\"])]\n", + "game_data[\"rewards (before)\"] = [output[1][\"score\"] for output in sentiment_pipe(texts, **sent_kwargs)]\n", + "\n", + "texts = [q + r for q, r in zip(game_data[\"query\"], game_data[\"response (after)\"])]\n", + "game_data[\"rewards (after)\"] = [output[1][\"score\"] for output in sentiment_pipe(texts, **sent_kwargs)]\n", + "\n", + "# store results in a dataframe\n", + "df_results = pd.DataFrame(game_data)\n", + "df_results" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Looking at the reward mean/median of the generated sequences we observe a significant difference." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "mean:\n" + ] + }, + { + "data": { + "text/plain": [ + "rewards (before) 0.156629\n", + "rewards (after) 1.686487\n", + "dtype: float64" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "median:\n" + ] + }, + { + "data": { + "text/plain": [ + "rewards (before) -0.547091\n", + "rewards (after) 2.479868\n", + "dtype: float64" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "print(\"mean:\")\n", + "display(df_results[[\"rewards (before)\", \"rewards (after)\"]].mean())\n", + "print()\n", + "print(\"median:\")\n", + "display(df_results[[\"rewards (before)\", \"rewards (after)\"]].median())" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Save model\n", + "Finally, we save the model and push it to the Hugging Face for later usage." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/leandro/miniconda3/envs/trl/lib/python3.9/site-packages/huggingface_hub/hf_api.py:1001: FutureWarning: `create_repo` now takes `token` as an optional positional argument. Be sure to adapt your code!\n", + " warnings.warn(\n", + "Cloning https://huggingface.co/lvwerra/gpt2-imdb-pos-v2 into local empty directory.\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "a953a6d0c465432bbc39aca826d37aaf", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Upload file pytorch_model.bin: 0%| | 32.0k/487M [00:00 main\n", + "\n", + "remote: Enforcing permissions... \n", + "remote: Allowed refs: all \n", + "To https://huggingface.co/lvwerra/gpt2-imdb-pos-v2\n", + " 28b9865..42792ea main -> main\n", + "\n" + ] + }, + { + "data": { + "text/plain": [ + "('gpt2-imdb-pos-v2/tokenizer_config.json',\n", + " 'gpt2-imdb-pos-v2/special_tokens_map.json',\n", + " 'gpt2-imdb-pos-v2/vocab.json',\n", + " 'gpt2-imdb-pos-v2/merges.txt',\n", + " 'gpt2-imdb-pos-v2/added_tokens.json',\n", + " 'gpt2-imdb-pos-v2/tokenizer.json')" + ] + }, + "execution_count": null, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "model.save_pretrained(\"gpt2-imdb-pos-v2\", push_to_hub=True)\n", + "tokenizer.save_pretrained(\"gpt2-imdb-pos-v2\", push_to_hub=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "env", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.12 (main, Mar 26 2022, 15:51:15) \n[Clang 13.1.6 (clang-1316.0.21.2)]" + }, + "vscode": { + "interpreter": { + "hash": "4c8ff454cd947027f86954d72bf940c689a97dcc494eb53cfe4813862c6065fe" + } + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/trl/trl/examples/research_projects/README.md b/trl/trl/examples/research_projects/README.md new file mode 100644 index 0000000000000000000000000000000000000000..1b1977e1877ca1d6351cd888b76793a2bad3206d --- /dev/null +++ b/trl/trl/examples/research_projects/README.md @@ -0,0 +1,7 @@ +# Research projects that use TRL + +Welcome to the research projects folder! Here you can find the scripts used for some research projects that used TRL and maintained by the developers and the community (LM de-toxification, Stack-Llama, etc.). Check out the READMEs in the subfolders for more information! + +- [De-detoxifying language models](https://github.com/huggingface/trl/tree/main/examples/research_projects/toxicity) +- [Stack-Llama](https://github.com/huggingface/trl/tree/main/examples/research_projects/stack_llama) +- [Stack-Llama-2](https://github.com/huggingface/trl/tree/main/examples/research_projects/stack_llama_2) \ No newline at end of file diff --git a/trl/trl/examples/research_projects/stack_llama/scripts/README.md b/trl/trl/examples/research_projects/stack_llama/scripts/README.md new file mode 100644 index 0000000000000000000000000000000000000000..60ed5fd94397c3954313cbc88de5648a387f03ea --- /dev/null +++ b/trl/trl/examples/research_projects/stack_llama/scripts/README.md @@ -0,0 +1,18 @@ +# RLHF pipeline for the creation of StackLLaMa: a Stack exchange llama-7b model. +There were three main steps to the training process: +1. Supervised fine-tuning of the base llama-7b model to create llama-7b-se: + - `torchrun --nnodes 1 --nproc_per_node 8 examples/research_projects/stack_llama/scripts/supervised_finetuning.py --model_path= --streaming --no_gradient_checkpointing --learning_rate 1e-5 --max_steps 5000 --output_dir ./llama-se` +2. Reward modeling using dialog pairs from the SE dataset using the llama-7b-se to create llama-7b-se-rm: + - `torchrun --nnodes 1 --nproc_per_node 8 examples/research_projects/stack_llama/scripts/reward_modeling.py --model_name=` +3. RL fine-tuning of llama-7b-se with the llama-7b-se-rm reward model: + - `accelerate launch --multi_gpu --num_machines 1 --num_processes 8 examples/research_projects/stack_llama/scripts/rl_training.py --log_with=wandb --model_name= --reward_model_name= --adafactor=False --tokenizer_name= --save_freq=100 --output_max_length=128 --batch_size=8 --gradient_accumulation_steps=8 --batched_gen=True --ppo_epochs=4 --seed=0 --learning_rate=1.4e-5 --early_stopping=True --output_dir=llama-se-rl-finetune-128-8-8-1.4e-5_adam` + + +LoRA layers were using at all stages to reduce memory requirements. +At each stage the peft adapter layers were merged with the base model, using: +```shell +python examples/research_projects/stack_llama/scripts/merge_peft_adapter.py --adapter_model_name=XXX --base_model_name=YYY --output_name=ZZZ +``` +Note that this script requires `peft>=0.3.0`. + +For access to the base llama-7b model, please see Meta's [release](https://ai.facebook.com/blog/large-language-model-llama-meta-ai/) and [request form](https://docs.google.com/forms/d/e/1FAIpQLSfqNECQnMkycAp2jP4Z9TFX0cGR4uf7b_fBxjY_OjhJILlKGA/viewform). diff --git a/trl/trl/examples/research_projects/stack_llama/scripts/merge_peft_adapter.py b/trl/trl/examples/research_projects/stack_llama/scripts/merge_peft_adapter.py new file mode 100644 index 0000000000000000000000000000000000000000..e7ff3b5cd982c171b0a4db948c7932e60e97fe37 --- /dev/null +++ b/trl/trl/examples/research_projects/stack_llama/scripts/merge_peft_adapter.py @@ -0,0 +1,48 @@ +from dataclasses import dataclass, field +from typing import Optional + +import torch +from peft import PeftConfig, PeftModel +from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer, HfArgumentParser + + +@dataclass +class ScriptArguments: + """ + The input names representing the Adapter and Base model fine-tuned with PEFT, and the output name representing the + merged model. + """ + + adapter_model_name: Optional[str] = field(default=None, metadata={"help": "the adapter name"}) + base_model_name: Optional[str] = field(default=None, metadata={"help": "the base model name"}) + output_name: Optional[str] = field(default=None, metadata={"help": "the merged model name"}) + + +parser = HfArgumentParser(ScriptArguments) +script_args = parser.parse_args_into_dataclasses()[0] +assert script_args.adapter_model_name is not None, "please provide the name of the Adapter you would like to merge" +assert script_args.base_model_name is not None, "please provide the name of the Base model" +assert script_args.output_name is not None, "please provide the output name of the merged model" + +peft_config = PeftConfig.from_pretrained(script_args.adapter_model_name) +if peft_config.task_type == "SEQ_CLS": + # The sequence classification task is used for the reward model in PPO + model = AutoModelForSequenceClassification.from_pretrained( + script_args.base_model_name, num_labels=1, torch_dtype=torch.bfloat16 + ) +else: + model = AutoModelForCausalLM.from_pretrained( + script_args.base_model_name, return_dict=True, torch_dtype=torch.bfloat16 + ) + +tokenizer = AutoTokenizer.from_pretrained(script_args.base_model_name) + +# Load the PEFT model +model = PeftModel.from_pretrained(model, script_args.adapter_model_name) +model.eval() + +model = model.merge_and_unload() + +model.save_pretrained(f"{script_args.output_name}") +tokenizer.save_pretrained(f"{script_args.output_name}") +model.push_to_hub(f"{script_args.output_name}", use_temp_dir=False) diff --git a/trl/trl/examples/research_projects/stack_llama/scripts/reward_modeling.py b/trl/trl/examples/research_projects/stack_llama/scripts/reward_modeling.py new file mode 100644 index 0000000000000000000000000000000000000000..d895d356db5dc7d7af522d4e6bdbefce24bd9b6d --- /dev/null +++ b/trl/trl/examples/research_projects/stack_llama/scripts/reward_modeling.py @@ -0,0 +1,300 @@ +from dataclasses import dataclass, field +from typing import Any, Dict, List, Optional, Union + +import evaluate +import numpy as np +import torch +import torch.nn as nn +from datasets import load_dataset +from peft import LoraConfig, TaskType, get_peft_model +from transformers import ( + AutoModelForSequenceClassification, + AutoTokenizer, + HfArgumentParser, + PreTrainedTokenizerBase, + Trainer, + TrainerCallback, + TrainingArguments, +) +from transformers.utils import PaddingStrategy + + +# Define and parse arguments. +@dataclass +class ScriptArguments: + """ + These arguments vary depending on how many GPUs you have, what their capacity and features are, and what size model you want to train. + """ + + local_rank: Optional[int] = field(default=-1, metadata={"help": "Used for multi-gpu"}) + resume_from_checkpoint: Optional[bool] = field( + default=False, + metadata={"help": "If you want to resume training where it left off."}, + ) + deepspeed: Optional[str] = field( + default=None, + metadata={ + "help": "Path to deepspeed config if using deepspeed. You may need this if the model that you want to train doesn't fit on a single GPU." + }, + ) + per_device_train_batch_size: Optional[int] = field(default=4) + per_device_eval_batch_size: Optional[int] = field(default=1) + gradient_accumulation_steps: Optional[int] = field(default=1) + learning_rate: Optional[float] = field(default=2e-5) + weight_decay: Optional[float] = field(default=0.001) + model_name: Optional[str] = field( + default="gpt2", + metadata={ + "help": "The model that you want to train from the Hugging Face hub. E.g. gpt2, gpt2-xl, bert, etc." + }, + ) + tokenizer_name: Optional[str] = field( + default=None, + metadata={ + "help": "The tokenizer for your model, if left empty will use the default for your model", + }, + ) + bf16: Optional[bool] = field( + default=True, + metadata={ + "help": "This essentially cuts the training time in half if you want to sacrifice a little precision and have a supported GPU." + }, + ) + num_train_epochs: Optional[int] = field( + default=1, + metadata={"help": "The number of training epochs for the reward model."}, + ) + train_subset: Optional[int] = field( + default=100000, + metadata={"help": "The size of the subset of the training data to use"}, + ) + eval_subset: Optional[int] = field( + default=50000, + metadata={"help": "The size of the subset of the eval data to use"}, + ) + gradient_checkpointing: Optional[bool] = field( + default=False, + metadata={"help": "Enables gradient checkpointing."}, + ) + optim: Optional[str] = field( + default="adamw_hf", + metadata={"help": "The optimizer to use."}, + ) + lr_scheduler_type: Optional[str] = field( + default="linear", + metadata={"help": "The lr scheduler"}, + ) + max_length: Optional[int] = field(default=512) + eval_first_step: Optional[bool] = field( + default=False, + metadata={"help": "Whether to run eval after the first step"}, + ) + + +parser = HfArgumentParser(ScriptArguments) +script_args = parser.parse_args_into_dataclasses()[0] + +# Load the human stack-exchange-paired dataset for tuning the reward model. +train_dataset = load_dataset("lvwerra/stack-exchange-paired", data_dir="data/reward", split="train") +if script_args.train_subset > 0: + train_dataset = train_dataset.select(range(script_args.train_subset)) +eval_dataset = load_dataset("lvwerra/stack-exchange-paired", data_dir="data/evaluation", split="train") +if script_args.eval_subset > 0: + eval_dataset = eval_dataset.select(range(script_args.eval_subset)) +# Define the training args. Needs to be done before the model is loaded if you are using deepspeed. +model_name_split = script_args.model_name.split("/")[-1] +output_name = ( + f"{model_name_split}_peft_stack-exchange-paired_rmts__{script_args.train_subset}_{script_args.learning_rate}" +) + +training_args = TrainingArguments( + output_dir=output_name, + learning_rate=script_args.learning_rate, + per_device_train_batch_size=script_args.per_device_train_batch_size, + per_device_eval_batch_size=script_args.per_device_eval_batch_size, + num_train_epochs=script_args.num_train_epochs, + weight_decay=script_args.weight_decay, + evaluation_strategy="steps", + eval_steps=500, + save_strategy="steps", + save_steps=500, + gradient_accumulation_steps=script_args.gradient_accumulation_steps, + gradient_checkpointing=script_args.gradient_checkpointing, + deepspeed=script_args.deepspeed, + local_rank=script_args.local_rank, + remove_unused_columns=False, + label_names=[], + bf16=script_args.bf16, + logging_strategy="steps", + logging_steps=10, + optim=script_args.optim, + lr_scheduler_type=script_args.lr_scheduler_type, +) +# Load the value-head model and tokenizer. +tokenizer_name = script_args.tokenizer_name if script_args.tokenizer_name is not None else script_args.model_name +tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, use_auth_token=True) +tokenizer.pad_token = tokenizer.eos_token + + +peft_config = LoraConfig( + task_type=TaskType.SEQ_CLS, + inference_mode=False, + r=8, + lora_alpha=32, + lora_dropout=0.1, +) + +model = AutoModelForSequenceClassification.from_pretrained( + script_args.model_name, num_labels=1, torch_dtype=torch.bfloat16 +) +model = get_peft_model(model, peft_config) +model.print_trainable_parameters() + +# Need to do this for gpt2, because it doesn't have an official pad token. +tokenizer.pad_token = tokenizer.eos_token +model.config.pad_token_id = tokenizer.eos_token_id +model.config.use_cache = not script_args.gradient_checkpointing +num_proc = 24 # Can adjust to be higher if you have more processors. +original_columns = train_dataset.column_names + + +# Turn the dataset into pairs of post + summaries, where text_j is the preferred question + answer and text_k is the other. +# Then tokenize the dataset. +def preprocess_function(examples): + new_examples = { + "input_ids_j": [], + "attention_mask_j": [], + "input_ids_k": [], + "attention_mask_k": [], + } + for question, response_j, response_k in zip(examples["question"], examples["response_j"], examples["response_k"]): + tokenized_j = tokenizer("Question: " + question + "\n\nAnswer: " + response_j, truncation=True) + tokenized_k = tokenizer("Question: " + question + "\n\nAnswer: " + response_k, truncation=True) + + new_examples["input_ids_j"].append(tokenized_j["input_ids"]) + new_examples["attention_mask_j"].append(tokenized_j["attention_mask"]) + new_examples["input_ids_k"].append(tokenized_k["input_ids"]) + new_examples["attention_mask_k"].append(tokenized_k["attention_mask"]) + + return new_examples + + +# preprocess the dataset and filter out QAs that are longer than script_args.max_length +train_dataset = train_dataset.map( + preprocess_function, + batched=True, + num_proc=num_proc, + remove_columns=original_columns, +) +train_dataset = train_dataset.filter( + lambda x: len(x["input_ids_j"]) <= script_args.max_length and len(x["input_ids_k"]) <= script_args.max_length +) + +eval_dataset = eval_dataset.map( + preprocess_function, + batched=True, + num_proc=num_proc, + remove_columns=original_columns, +) +eval_dataset = eval_dataset.filter( + lambda x: len(x["input_ids_j"]) <= script_args.max_length and len(x["input_ids_k"]) <= script_args.max_length +) + + +# We need to define a special data collator that batches the data in our j vs k format. +@dataclass +class RewardDataCollatorWithPadding: + tokenizer: PreTrainedTokenizerBase + padding: Union[bool, str, PaddingStrategy] = True + max_length: Optional[int] = None + pad_to_multiple_of: Optional[int] = None + return_tensors: str = "pt" + + def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]: + features_j = [] + features_k = [] + for feature in features: + features_j.append( + { + "input_ids": feature["input_ids_j"], + "attention_mask": feature["attention_mask_j"], + } + ) + features_k.append( + { + "input_ids": feature["input_ids_k"], + "attention_mask": feature["attention_mask_k"], + } + ) + batch_j = self.tokenizer.pad( + features_j, + padding=self.padding, + max_length=self.max_length, + pad_to_multiple_of=self.pad_to_multiple_of, + return_tensors=self.return_tensors, + ) + batch_k = self.tokenizer.pad( + features_k, + padding=self.padding, + max_length=self.max_length, + pad_to_multiple_of=self.pad_to_multiple_of, + return_tensors=self.return_tensors, + ) + batch = { + "input_ids_j": batch_j["input_ids"], + "attention_mask_j": batch_j["attention_mask"], + "input_ids_k": batch_k["input_ids"], + "attention_mask_k": batch_k["attention_mask"], + "return_loss": True, + } + return batch + + +# Define the metric that we'll use for validation. +accuracy = evaluate.load("accuracy") + + +def compute_metrics(eval_pred): + predictions, _ = eval_pred + # Here, predictions is rewards_j and rewards_k. + # We want to see how much of the time rewards_j > rewards_k. + predictions = np.argmax(predictions, axis=0) + labels = np.zeros(predictions.shape) + return accuracy.compute(predictions=predictions, references=labels) + + +class RewardTrainer(Trainer): + # Define how to compute the reward loss. We use the InstructGPT pairwise logloss: https://arxiv.org/abs/2203.02155 + def compute_loss(self, model, inputs, return_outputs=False): + rewards_j = model(input_ids=inputs["input_ids_j"], attention_mask=inputs["attention_mask_j"])[0] + rewards_k = model(input_ids=inputs["input_ids_k"], attention_mask=inputs["attention_mask_k"])[0] + loss = -nn.functional.logsigmoid(rewards_j - rewards_k).mean() + if return_outputs: + return loss, {"rewards_j": rewards_j, "rewards_k": rewards_k} + return loss + + +# Train the model, woohoo. +trainer = RewardTrainer( + model=model, + args=training_args, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + compute_metrics=compute_metrics, + data_collator=RewardDataCollatorWithPadding(tokenizer=tokenizer, max_length=script_args.max_length), +) + + +if script_args.eval_first_step: + + class EvaluateFirstStepCallback(TrainerCallback): + def on_step_end(self, args, state, control, **kwargs): + if state.global_step == 1: + control.should_evaluate = True + + trainer.add_callback(EvaluateFirstStepCallback()) + +trainer.train(script_args.resume_from_checkpoint) + +print("Saving last checkpoint of the model") +model.save_pretrained(output_name + "_peft_last_checkpoint") diff --git a/trl/trl/examples/research_projects/stack_llama/scripts/rl_training.py b/trl/trl/examples/research_projects/stack_llama/scripts/rl_training.py new file mode 100644 index 0000000000000000000000000000000000000000..eee7952660f2084236620f2264986e3544562bab --- /dev/null +++ b/trl/trl/examples/research_projects/stack_llama/scripts/rl_training.py @@ -0,0 +1,263 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from dataclasses import dataclass, field +from typing import Optional + +import torch +from accelerate import Accelerator +from datasets import load_dataset +from peft import LoraConfig +from tqdm import tqdm +from transformers import Adafactor, AutoTokenizer, HfArgumentParser, pipeline + +from trl import AutoModelForCausalLMWithValueHead, PPOConfig, PPOTrainer, set_seed +from trl.core import LengthSampler + + +tqdm.pandas() + + +@dataclass +class ScriptArguments: + """ + The name of the Casual LM model we wish to fine with PPO + """ + + # NOTE: gpt2 models use Conv1D instead of Linear layers which are not yet supported in 8 bit mode + # models like gpt-neo* models are more suitable. + model_name: Optional[str] = field(default="", metadata={"help": "the model name"}) + tokenizer_name: Optional[str] = field(default="", metadata={"help": "the tokenizer name"}) + reward_model_name: Optional[str] = field(default="", metadata={"help": "the reward model name"}) + log_with: Optional[str] = field(default=None, metadata={"help": "use 'wandb' to log with wandb"}) + learning_rate: Optional[float] = field(default=1.41e-5, metadata={"help": "the learning rate"}) + output_max_length: Optional[int] = field(default=128, metadata={"help": "maximum length for generation"}) + mini_batch_size: Optional[int] = field(default=1, metadata={"help": "the PPO minibatch size"}) + batch_size: Optional[int] = field(default=32, metadata={"help": "the batch size"}) + ppo_epochs: Optional[int] = field(default=4, metadata={"help": "the number of ppo epochs"}) + gradient_accumulation_steps: Optional[int] = field( + default=4, metadata={"help": "the number of gradient accumulation steps"} + ) + adafactor: Optional[bool] = field(default=False, metadata={"help": "whether to use the adafactor optimizer"}) + early_stopping: Optional[bool] = field(default=False, metadata={"help": "whether to early stop"}) + target_kl: Optional[float] = field(default=0.1, metadata={"help": "kl target for early stopping"}) + reward_baseline: Optional[float] = field( + default=0.0, + metadata={"help": "a baseline value that is subtracted from the reward"}, + ) + batched_gen: Optional[bool] = field(default=False, metadata={"help": "whether to use the batched text gen"}) + save_freq: Optional[int] = field(default=None, metadata={"help": "n steps to save the model"}) + output_dir: Optional[str] = field(default="runs/", metadata={"help": "n steps to save the model"}) + seed: Optional[int] = field(default=0, metadata={"help": "the seed"}) + steps: Optional[int] = field(default=20000, metadata={"help": "number of epochs"}) + init_kl_coef: Optional[float] = field( + default=0.2, + metadata={"help": "Initial KL penalty coefficient (used for adaptive and linear control)"}, + ) + + adap_kl_ctrl: Optional[bool] = field(default=True, metadata={"help": "Use adaptive KL control, otherwise linear"}) + + +parser = HfArgumentParser(ScriptArguments) +script_args: ScriptArguments = parser.parse_args_into_dataclasses()[0] +reward_model_name = script_args.reward_model_name +dataset_name = "lvwerra/stack-exchange-paired" +config = PPOConfig( + steps=script_args.steps, + model_name=script_args.model_name, + learning_rate=script_args.learning_rate, + log_with=script_args.log_with, + batch_size=script_args.batch_size, + mini_batch_size=script_args.mini_batch_size, + gradient_accumulation_steps=script_args.gradient_accumulation_steps, + optimize_cuda_cache=True, + early_stopping=script_args.early_stopping, + target_kl=script_args.target_kl, + ppo_epochs=script_args.ppo_epochs, + seed=script_args.seed, + init_kl_coef=script_args.init_kl_coef, + adap_kl_ctrl=script_args.adap_kl_ctrl, +) + +train_dataset = load_dataset("lvwerra/stack-exchange-paired", data_dir="data/rl", split="train") +train_dataset = train_dataset.select(range(100000)) +original_columns = train_dataset.column_names + +# We then define the arguments to pass to the sentiment analysis pipeline. +# We set `return_all_scores` to True to get the sentiment score for each token. +sent_kwargs = { + "return_all_scores": True, + "function_to_apply": "none", + "batch_size": 16, + "truncation": True, +} + +tokenizer = AutoTokenizer.from_pretrained(script_args.tokenizer_name) +# GPT-2 tokenizer has a pad token, but it is not eos_token by default. We need to set it to eos_token. +# only for this model. + +if getattr(tokenizer, "pad_token", None) is None: + tokenizer.pad_token = tokenizer.eos_token + + +# Below is an example function to build the dataset. In our case, we use the IMDB dataset +# from the `datasets` library. One should customize this function to train the model on +# its own dataset. +def build_dataset( + tokenizer, + dataset_name="lvwerra/stack-exchange-paired", +): + """ + Build dataset for training. This builds the dataset from `load_dataset`, one should + customize this function to train the model on its own dataset. + + Args: + dataset_name (`str`): + The name of the dataset to be loaded. + + Returns: + dataloader (`torch.utils.data.DataLoader`): + The dataloader for the dataset. + """ + + num_proc = 24 + + def preprocess_function(examples): + new_examples = { + "query": [], + "input_ids": [], + } + for question in examples["question"]: + query = "Question: " + question + "\n\nAnswer: " + tokenized_question = tokenizer(query, truncation=True) + new_examples["query"].append(query) + new_examples["input_ids"].append(tokenized_question["input_ids"]) + + return new_examples + + ds = train_dataset.map( + preprocess_function, + batched=True, + num_proc=num_proc, + remove_columns=original_columns, + ) + ds = ds.filter(lambda x: len(x["input_ids"]) < 512, batched=False) + + ds.set_format(type="torch") + return ds + + +# We retrieve the dataloader by calling the `build_dataset` function. +dataset = build_dataset(tokenizer) + + +def collator(data): + return dict((key, [d[key] for d in data]) for key in data[0]) + + +# set seed before initializing value head for deterministic eval +set_seed(config.seed) + +# Now let's build the model, the reference model, and the tokenizer. +current_device = Accelerator().local_process_index + +lora_config = LoraConfig( + r=16, + lora_alpha=32, + lora_dropout=0.05, + bias="none", + task_type="CAUSAL_LM", +) +model = AutoModelForCausalLMWithValueHead.from_pretrained( + config.model_name, + load_in_8bit=True, + device_map={"": current_device}, + peft_config=lora_config, +) + +optimizer = None +if script_args.adafactor: + optimizer = Adafactor( + filter(lambda p: p.requires_grad, model.parameters()), + scale_parameter=False, + relative_step=False, + warmup_init=False, + lr=config.learning_rate, + ) +# We then build the PPOTrainer, passing the model, the reference model, the tokenizer +ppo_trainer = PPOTrainer( + config, + model, + ref_model=None, + tokenizer=tokenizer, + dataset=dataset, + data_collator=collator, + optimizer=optimizer, +) + +# We then build the sentiment analysis pipeline using our reward model, passing the +# model name and the sentiment analysis pipeline arguments. Let's also make sure to +# set the device to the same device as the PPOTrainer. +device = ppo_trainer.accelerator.device +if ppo_trainer.accelerator.num_processes == 1: + device = 0 if torch.cuda.is_available() else "cpu" # to avoid a ` pipeline` bug +sentiment_pipe = pipeline( + "sentiment-analysis", + model=reward_model_name, + device_map={"": current_device}, + model_kwargs={"load_in_8bit": True}, + tokenizer=tokenizer, + return_token_type_ids=False, +) + +# We then define the arguments to pass to the `generate` function. These arguments +# are passed to the `generate` function of the PPOTrainer, which is a wrapper around +# the `generate` function of the trained model. +generation_kwargs = { + # "min_length": -1, + "top_k": 0.0, + "top_p": 1.0, + "do_sample": True, + "pad_token_id": tokenizer.pad_token_id, + "eos_token_id": 100_000, +} +output_min_length = 32 +output_max_length = script_args.output_max_length +output_length_sampler = LengthSampler(output_min_length, output_max_length) + +for epoch, batch in tqdm(enumerate(ppo_trainer.dataloader)): + if epoch >= config.total_ppo_epochs: + break + + question_tensors = batch["input_ids"] + + response_tensors = ppo_trainer.generate( + question_tensors, + return_prompt=False, + length_sampler=output_length_sampler, + **generation_kwargs, + ) + batch["response"] = tokenizer.batch_decode(response_tensors, skip_special_tokens=True) + + # Compute reward score (using the sentiment analysis pipeline) + texts = [q + r for q, r in zip(batch["query"], batch["response"])] + pipe_outputs = sentiment_pipe(texts, **sent_kwargs) + rewards = [torch.tensor(output[0]["score"] - script_args.reward_baseline) for output in pipe_outputs] + + # Run PPO step + stats = ppo_trainer.step(question_tensors, response_tensors, rewards) + ppo_trainer.log_stats(stats, batch, rewards) + + if script_args.save_freq and epoch and epoch % script_args.save_freq == 0: + ppo_trainer.save_pretrained(script_args.output_dir + f"step_{epoch}") diff --git a/trl/trl/examples/research_projects/stack_llama/scripts/supervised_finetuning.py b/trl/trl/examples/research_projects/stack_llama/scripts/supervised_finetuning.py new file mode 100644 index 0000000000000000000000000000000000000000..47669ac8a71a278cbab863fdf7805110130e9b27 --- /dev/null +++ b/trl/trl/examples/research_projects/stack_llama/scripts/supervised_finetuning.py @@ -0,0 +1,208 @@ +import argparse +import os + +from accelerate import Accelerator +from datasets import load_dataset +from peft import LoraConfig +from tqdm import tqdm +from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments, logging, set_seed + +from trl import SFTTrainer +from trl.trainer import ConstantLengthDataset + + +""" +Fine-Tune Llama-7b on SE paired dataset +""" + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--model_path", type=str, default="") + parser.add_argument("--dataset_name", type=str, default="lvwerra/stack-exchange-paired") + parser.add_argument("--subset", type=str, default="data/finetune") + parser.add_argument("--split", type=str, default="train") + parser.add_argument("--size_valid_set", type=int, default=4000) + parser.add_argument("--streaming", action="store_true") + parser.add_argument("--shuffle_buffer", type=int, default=5000) + + parser.add_argument("--seq_length", type=int, default=1024) + parser.add_argument("--max_steps", type=int, default=10000) + parser.add_argument("--batch_size", type=int, default=4) + parser.add_argument("--gradient_accumulation_steps", type=int, default=1) + parser.add_argument("--eos_token_id", type=int, default=49152) + + parser.add_argument("--learning_rate", type=float, default=1e-4) + parser.add_argument("--lr_scheduler_type", type=str, default="cosine") + parser.add_argument("--num_warmup_steps", type=int, default=100) + parser.add_argument("--weight_decay", type=float, default=0.05) + + parser.add_argument("--local_rank", type=int, default=0) + parser.add_argument("--no_fp16", action="store_false") + parser.add_argument("--bf16", action="store_true", default=False) + parser.add_argument("--no_gradient_checkpointing", action="store_false", default=False) + parser.add_argument("--seed", type=int, default=0) + parser.add_argument("--num_workers", type=int, default=None) + parser.add_argument("--output_dir", type=str, default="./checkpoints") + parser.add_argument("--log_freq", default=1, type=int) + parser.add_argument("--eval_freq", default=1000, type=int) + parser.add_argument("--save_freq", default=1000, type=int) + + return parser.parse_args() + + +def chars_token_ratio(dataset, tokenizer, nb_examples=400): + """ + Estimate the average number of characters per token in the dataset. + """ + total_characters, total_tokens = 0, 0 + for _, example in tqdm(zip(range(nb_examples), iter(dataset)), total=nb_examples): + text = prepare_sample_text(example) + total_characters += len(text) + if tokenizer.is_fast: + total_tokens += len(tokenizer(text).tokens()) + else: + total_tokens += len(tokenizer.tokenize(text)) + + return total_characters / total_tokens + + +def print_trainable_parameters(model): + """ + Prints the number of trainable parameters in the model. + """ + trainable_params = 0 + all_param = 0 + for _, param in model.named_parameters(): + all_param += param.numel() + if param.requires_grad: + trainable_params += param.numel() + print( + f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param}" + ) + + +def prepare_sample_text(example): + """Prepare the text from a sample of the dataset.""" + text = f"Question: {example['question']}\n\nAnswer: {example['response_j']}" + return text + + +def create_datasets(tokenizer, args): + dataset = load_dataset( + args.dataset_name, + data_dir=args.subset, + split=args.split, + use_auth_token=True, + num_proc=args.num_workers if not args.streaming else None, + streaming=args.streaming, + ) + if args.streaming: + print("Loading the dataset in streaming mode") + valid_data = dataset.take(args.size_valid_set) + train_data = dataset.skip(args.size_valid_set) + train_data = train_data.shuffle(buffer_size=args.shuffle_buffer, seed=args.seed) + else: + dataset = dataset.train_test_split(test_size=0.005, seed=args.seed) + train_data = dataset["train"] + valid_data = dataset["test"] + print(f"Size of the train set: {len(train_data)}. Size of the validation set: {len(valid_data)}") + + chars_per_token = chars_token_ratio(train_data, tokenizer) + print(f"The character to token ratio of the dataset is: {chars_per_token:.2f}") + + train_dataset = ConstantLengthDataset( + tokenizer, + train_data, + formatting_func=prepare_sample_text, + infinite=True, + seq_length=args.seq_length, + chars_per_token=chars_per_token, + ) + valid_dataset = ConstantLengthDataset( + tokenizer, + valid_data, + formatting_func=prepare_sample_text, + infinite=False, + seq_length=args.seq_length, + chars_per_token=chars_per_token, + ) + return train_dataset, valid_dataset + + +def run_training(args, train_data, val_data): + print("Loading the model") + + lora_config = LoraConfig( + r=16, + lora_alpha=32, + lora_dropout=0.05, + bias="none", + task_type="CAUSAL_LM", + ) + + train_data.start_iteration = 0 + + print("Starting main loop") + + training_args = TrainingArguments( + output_dir=args.output_dir, + dataloader_drop_last=True, + evaluation_strategy="steps", + max_steps=args.max_steps, + eval_steps=args.eval_freq, + save_steps=args.save_freq, + logging_steps=args.log_freq, + per_device_train_batch_size=args.batch_size, + per_device_eval_batch_size=args.batch_size, + learning_rate=args.learning_rate, + lr_scheduler_type=args.lr_scheduler_type, + warmup_steps=args.num_warmup_steps, + gradient_accumulation_steps=args.gradient_accumulation_steps, + gradient_checkpointing=not args.no_gradient_checkpointing, + fp16=not args.no_fp16, + bf16=args.bf16, + weight_decay=args.weight_decay, + run_name="llama-7b-finetuned", + report_to="wandb", + ddp_find_unused_parameters=False, + ) + + model = AutoModelForCausalLM.from_pretrained( + args.model_path, load_in_8bit=True, device_map={"": Accelerator().process_index} + ) + + trainer = SFTTrainer( + model=model, + args=training_args, + train_dataset=train_data, + eval_dataset=val_data, + peft_config=lora_config, + packing=True, + ) + + print_trainable_parameters(trainer.model) + + print("Training...") + trainer.train() + + print("Saving last checkpoint of the model") + trainer.model.save_pretrained(os.path.join(args.output_dir, "final_checkpoint/")) + + +def main(args): + tokenizer = AutoTokenizer.from_pretrained(args.model_path) + train_dataset, eval_dataset = create_datasets(tokenizer, args) + run_training(args, train_dataset, eval_dataset) + + +if __name__ == "__main__": + args = get_args() + assert args.model_path != "", "Please provide the llama model path" + + set_seed(args.seed) + os.makedirs(args.output_dir, exist_ok=True) + + logging.set_verbosity_error() + + main(args) diff --git a/trl/trl/examples/research_projects/stack_llama_2/scripts/README.md b/trl/trl/examples/research_projects/stack_llama_2/scripts/README.md new file mode 100644 index 0000000000000000000000000000000000000000..727a631d8d120f25f4605d93e97539443fd5da8d --- /dev/null +++ b/trl/trl/examples/research_projects/stack_llama_2/scripts/README.md @@ -0,0 +1,76 @@ +# DPO pipeline for the creation of StackLlaMa 2: a Stack exchange llama-v2-7b model + +## Prerequisites + +Install all the dependencies in the `requirements.txt`: + +``` +$ pip install -U -r requirements.txt +``` + +Since we will use `accelerate` for training, make sure to run: +``` +$ accelerate config +``` + +## Training + +There were two main steps to the DPO training process: +1. Supervised fine-tuning of the base llama-v2-7b model to create llama-v2-7b-se: + + ``` + accelerate launch examples/research_projects/stack_llama_2/scripts/sft_llama2.py \ + --output_dir="./sft" \ + --max_steps=500 \ + --logging_steps=10 \ + --save_steps=10 \ + --per_device_train_batch_size=4 \ + --per_device_eval_batch_size=1 \ + --gradient_accumulation_steps=2 \ + --gradient_checkpointing=False \ + --group_by_length=False \ + --learning_rate=1e-4 \ + --lr_scheduler_type="cosine" \ + --warmup_steps=100 \ + --weight_decay=0.05 \ + --optim="paged_adamw_32bit" \ + --bf16=True \ + --remove_unused_columns=False \ + --run_name="sft_llama2" \ + --report_to="wandb" + ``` +1. Run the DPO trainer using the model saved by the previous step: + ``` + accelerate launch examples/research_projects/stack_llama_2/scripts/dpo_llama2.py \ + --model_name_or_path="sft/final_checkpoint" \ + --output_dir="dpo" + ``` + + +## Merging the adaptors + +To merge the adaptors into the base model we can use the `merge_peft_adapter.py` helper script that comes with TRL: + +``` +python examples/research_projects/stack_llama/scripts/merge_peft_adapter.py --base_model_name="meta-llama/Llama-2-7b-hf" --adapter_model_name="dpo/final_checkpoint/" --output_name="stack-llama-2" +``` + +which will also push the model to your HuggingFace hub account. + +## Running the model + +We can load the DPO-trained LoRA adaptors which were saved by the DPO training step and load them via: + +```py +from peft import AutoPeftModelForCausalLM + + +model = AutoPeftModelForCausalLM.from_pretrained( + "dpo/final_checkpoint", + low_cpu_mem_usage=True, + torch_dtype=torch.float16, + load_in_4bit=True, +) + +model.generate(...) +``` diff --git a/trl/trl/examples/research_projects/stack_llama_2/scripts/dpo_llama2.py b/trl/trl/examples/research_projects/stack_llama_2/scripts/dpo_llama2.py new file mode 100644 index 0000000000000000000000000000000000000000..d21ecd3d4b3b96ea96b8d1de271593f30633e413 --- /dev/null +++ b/trl/trl/examples/research_projects/stack_llama_2/scripts/dpo_llama2.py @@ -0,0 +1,223 @@ +# 0. imports +import os +from dataclasses import dataclass, field +from typing import Dict, Optional + +import torch +from datasets import Dataset, load_dataset +from peft import LoraConfig +from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser, TrainingArguments + +from trl import DPOTrainer + + +# Define and parse arguments. +@dataclass +class ScriptArguments: + """ + The arguments for the DPO training script. + """ + + # data parameters + beta: Optional[float] = field(default=0.1, metadata={"help": "the beta parameter for DPO loss"}) + + # training parameters + model_name_or_path: Optional[str] = field( + default="../sft/results/final_checkpoint", + metadata={"help": "the location of the SFT model name or path"}, + ) + learning_rate: Optional[float] = field(default=5e-4, metadata={"help": "optimizer learning rate"}) + lr_scheduler_type: Optional[str] = field(default="cosine", metadata={"help": "the lr scheduler type"}) + warmup_steps: Optional[int] = field(default=100, metadata={"help": "the number of warmup steps"}) + weight_decay: Optional[float] = field(default=0.05, metadata={"help": "the weight decay"}) + optimizer_type: Optional[str] = field(default="paged_adamw_32bit", metadata={"help": "the optimizer type"}) + + per_device_train_batch_size: Optional[int] = field(default=4, metadata={"help": "train batch size per device"}) + per_device_eval_batch_size: Optional[int] = field(default=1, metadata={"help": "eval batch size per device"}) + gradient_accumulation_steps: Optional[int] = field( + default=4, metadata={"help": "the number of gradient accumulation steps"} + ) + gradient_checkpointing: Optional[bool] = field( + default=True, metadata={"help": "whether to use gradient checkpointing"} + ) + + lora_alpha: Optional[float] = field(default=16, metadata={"help": "the lora alpha parameter"}) + lora_dropout: Optional[float] = field(default=0.05, metadata={"help": "the lora dropout parameter"}) + lora_r: Optional[int] = field(default=8, metadata={"help": "the lora r parameter"}) + + max_prompt_length: Optional[int] = field(default=512, metadata={"help": "the maximum prompt length"}) + max_length: Optional[int] = field(default=1024, metadata={"help": "the maximum sequence length"}) + max_steps: Optional[int] = field(default=1000, metadata={"help": "max number of training steps"}) + logging_steps: Optional[int] = field(default=10, metadata={"help": "the logging frequency"}) + save_steps: Optional[int] = field(default=100, metadata={"help": "the saving frequency"}) + eval_steps: Optional[int] = field(default=100, metadata={"help": "the evaluation frequency"}) + + output_dir: Optional[str] = field(default="./results", metadata={"help": "the output directory"}) + log_freq: Optional[int] = field(default=1, metadata={"help": "the logging frequency"}) + + # instrumentation + sanity_check: Optional[bool] = field(default=False, metadata={"help": "only train on 1000 samples"}) + report_to: Optional[str] = field( + default="wandb", + metadata={ + "help": 'The list of integrations to report the results and logs to. Supported platforms are `"azure_ml"`,' + '`"comet_ml"`, `"mlflow"`, `"neptune"`, `"tensorboard"`,`"clearml"` and `"wandb"`. ' + 'Use `"all"` to report to all integrations installed, `"none"` for no integrations.' + }, + ) + # debug argument for distributed training + ignore_bias_buffers: Optional[bool] = field( + default=False, + metadata={ + "help": "fix for DDP issues with LM bias/mask buffers - invalid scalar type,`inplace operation. See" + "https://github.com/huggingface/transformers/issues/22482#issuecomment-1595790992" + }, + ) + + +def get_stack_exchange_paired( + data_dir: str = "data/rl", + sanity_check: bool = False, + cache_dir: str = None, + num_proc=24, +) -> Dataset: + """Load the stack-exchange-paired dataset from Hugging Face and convert it to the necessary format. + + The dataset is converted to a dictionary with the following structure: + { + 'prompt': List[str], + 'chosen': List[str], + 'rejected': List[str], + } + + Prompts are structured as follows: + "Question: " + + "\n\nAnswer: " + """ + dataset = load_dataset( + "lvwerra/stack-exchange-paired", + split="train", + cache_dir=cache_dir, + data_dir=data_dir, + ) + original_columns = dataset.column_names + + if sanity_check: + dataset = dataset.select(range(min(len(dataset), 1000))) + + def return_prompt_and_responses(samples) -> Dict[str, str]: + return { + "prompt": ["Question: " + question + "\n\nAnswer: " for question in samples["question"]], + "chosen": samples["response_j"], + "rejected": samples["response_k"], + } + + return dataset.map( + return_prompt_and_responses, + batched=True, + num_proc=num_proc, + remove_columns=original_columns, + ) + + +if __name__ == "__main__": + parser = HfArgumentParser(ScriptArguments) + script_args = parser.parse_args_into_dataclasses()[0] + + # 1. load a pretrained model + model = AutoModelForCausalLM.from_pretrained( + script_args.model_name_or_path, + low_cpu_mem_usage=True, + torch_dtype=torch.float16, + load_in_4bit=True, + ) + model.config.use_cache = False + + if script_args.ignore_bias_buffers: + # torch distributed hack + model._ddp_params_and_buffers_to_ignore = [ + name for name, buffer in model.named_buffers() if buffer.dtype == torch.bool + ] + + model_ref = AutoModelForCausalLM.from_pretrained( + script_args.model_name_or_path, + low_cpu_mem_usage=True, + torch_dtype=torch.float16, + load_in_4bit=True, + ) + tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf") + tokenizer.pad_token = tokenizer.eos_token + + # 2. Load the Stack-exchange paired dataset + train_dataset = get_stack_exchange_paired(data_dir="data/rl", sanity_check=script_args.sanity_check) + train_dataset = train_dataset.filter( + lambda x: len(x["prompt"]) + len(x["chosen"]) <= script_args.max_length + and len(x["prompt"]) + len(x["rejected"]) <= script_args.max_length + ) + + # 3. Load evaluation dataset + eval_dataset = get_stack_exchange_paired(data_dir="data/evaluation", sanity_check=True) + eval_dataset = eval_dataset.filter( + lambda x: len(x["prompt"]) + len(x["chosen"]) <= script_args.max_length + and len(x["prompt"]) + len(x["rejected"]) <= script_args.max_length + ) + + # 4. initialize training arguments: + training_args = TrainingArguments( + per_device_train_batch_size=script_args.per_device_train_batch_size, + per_device_eval_batch_size=script_args.per_device_eval_batch_size, + max_steps=script_args.max_steps, + logging_steps=script_args.logging_steps, + save_steps=script_args.save_steps, + gradient_accumulation_steps=script_args.gradient_accumulation_steps, + gradient_checkpointing=script_args.gradient_checkpointing, + learning_rate=script_args.learning_rate, + evaluation_strategy="steps", + eval_steps=script_args.eval_steps, + output_dir=script_args.output_dir, + report_to=script_args.report_to, + lr_scheduler_type=script_args.lr_scheduler_type, + warmup_steps=script_args.warmup_steps, + optim=script_args.optimizer_type, + bf16=True, + remove_unused_columns=False, + run_name="dpo_llama2", + ) + + peft_config = LoraConfig( + r=script_args.lora_r, + lora_alpha=script_args.lora_alpha, + lora_dropout=script_args.lora_dropout, + target_modules=[ + "q_proj", + "v_proj", + "k_proj", + "out_proj", + "fc_in", + "fc_out", + "wte", + ], + bias="none", + task_type="CAUSAL_LM", + ) + + # 5. initialize the DPO trainer + dpo_trainer = DPOTrainer( + model, + model_ref, + args=training_args, + beta=script_args.beta, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + tokenizer=tokenizer, + peft_config=peft_config, + max_prompt_length=script_args.max_prompt_length, + max_length=script_args.max_length, + ) + + # 6. train + dpo_trainer.train() + dpo_trainer.save_model(script_args.output_dir) + + # 7. save + output_dir = os.path.join(script_args.output_dir, "final_checkpoint") + dpo_trainer.model.save_pretrained(output_dir) diff --git a/trl/trl/examples/research_projects/stack_llama_2/scripts/requirements.txt b/trl/trl/examples/research_projects/stack_llama_2/scripts/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..ca124e58df8e4269a4d44d3ceccd0e2a05ea4fae --- /dev/null +++ b/trl/trl/examples/research_projects/stack_llama_2/scripts/requirements.txt @@ -0,0 +1,7 @@ +transformers +trl +peft +accelerate +datasets +bitsandbytes +wandb diff --git a/trl/trl/examples/research_projects/stack_llama_2/scripts/sft_llama2.py b/trl/trl/examples/research_projects/stack_llama_2/scripts/sft_llama2.py new file mode 100644 index 0000000000000000000000000000000000000000..94fc4e72c8125bd61240de055cade6a7a2d8978f --- /dev/null +++ b/trl/trl/examples/research_projects/stack_llama_2/scripts/sft_llama2.py @@ -0,0 +1,185 @@ +# Fine-Tune Llama2-7b on SE paired dataset +import os +from dataclasses import dataclass, field +from typing import Optional + +import torch +from accelerate import Accelerator +from datasets import load_dataset +from peft import AutoPeftModelForCausalLM, LoraConfig +from tqdm import tqdm +from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, HfArgumentParser, TrainingArguments + +from trl import SFTTrainer +from trl.import_utils import is_xpu_available +from trl.trainer import ConstantLengthDataset + + +@dataclass +class ScriptArguments: + model_name: Optional[str] = field(default="meta-llama/Llama-2-7b-hf", metadata={"help": "the model name"}) + dataset_name: Optional[str] = field(default="lvwerra/stack-exchange-paired", metadata={"help": "the dataset name"}) + subset: Optional[str] = field(default="data/finetune", metadata={"help": "the subset to use"}) + split: Optional[str] = field(default="train", metadata={"help": "the split to use"}) + size_valid_set: Optional[int] = field(default=4000, metadata={"help": "the size of the validation set"}) + streaming: Optional[bool] = field(default=True, metadata={"help": "whether to stream the dataset"}) + shuffle_buffer: Optional[int] = field(default=5000, metadata={"help": "the shuffle buffer size"}) + seq_length: Optional[int] = field(default=1024, metadata={"help": "the sequence length"}) + num_workers: Optional[int] = field(default=4, metadata={"help": "the number of workers"}) + packing: Optional[bool] = field(default=True, metadata={"help": "whether to use packing for SFTTrainer"}) + + # LoraConfig + lora_alpha: Optional[float] = field(default=16, metadata={"help": "the lora alpha parameter"}) + lora_dropout: Optional[float] = field(default=0.05, metadata={"help": "the lora dropout parameter"}) + lora_r: Optional[int] = field(default=8, metadata={"help": "the lora r parameter"}) + + +parser = HfArgumentParser((ScriptArguments, TrainingArguments)) +script_args, training_args = parser.parse_args_into_dataclasses() +peft_config = LoraConfig( + r=script_args.lora_r, + lora_alpha=script_args.lora_alpha, + lora_dropout=script_args.lora_dropout, + target_modules=["q_proj", "v_proj"], + bias="none", + task_type="CAUSAL_LM", +) + +if training_args.group_by_length and script_args.packing: + raise ValueError("Cannot use both packing and group by length") + +# `gradient_checkpointing` was True by default until `1f3314`, but it's actually not used. +# `gradient_checkpointing=True` will cause `Variable._execution_engine.run_backward`. +if training_args.gradient_checkpointing: + raise ValueError("gradient_checkpointing not supported") + + +def chars_token_ratio(dataset, tokenizer, nb_examples=400): + """ + Estimate the average number of characters per token in the dataset. + """ + total_characters, total_tokens = 0, 0 + for _, example in tqdm(zip(range(nb_examples), iter(dataset)), total=nb_examples): + text = prepare_sample_text(example) + total_characters += len(text) + if tokenizer.is_fast: + total_tokens += len(tokenizer(text).tokens()) + else: + total_tokens += len(tokenizer.tokenize(text)) + + return total_characters / total_tokens + + +def print_trainable_parameters(model): + """ + Prints the number of trainable parameters in the model. + """ + trainable_params = 0 + all_param = 0 + for _, param in model.named_parameters(): + all_param += param.numel() + if param.requires_grad: + trainable_params += param.numel() + print( + f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param}" + ) + + +def prepare_sample_text(example): + """Prepare the text from a sample of the dataset.""" + text = f"Question: {example['question']}\n\nAnswer: {example['response_j']}" + return text + + +def create_datasets(tokenizer, args): + dataset = load_dataset( + args.dataset_name, + data_dir=args.subset, + split=args.split, + use_auth_token=True, + num_proc=args.num_workers if not args.streaming else None, + streaming=args.streaming, + ) + if args.streaming: + print("Loading the dataset in streaming mode") + valid_data = dataset.take(args.size_valid_set) + train_data = dataset.skip(args.size_valid_set) + train_data = train_data.shuffle(buffer_size=args.shuffle_buffer, seed=None) + else: + dataset = dataset.train_test_split(test_size=0.005, seed=None) + train_data = dataset["train"] + valid_data = dataset["test"] + print(f"Size of the train set: {len(train_data)}. Size of the validation set: {len(valid_data)}") + + chars_per_token = chars_token_ratio(train_data, tokenizer) + print(f"The character to token ratio of the dataset is: {chars_per_token:.2f}") + + train_dataset = ConstantLengthDataset( + tokenizer, + train_data, + formatting_func=prepare_sample_text, + infinite=True, + seq_length=args.seq_length, + chars_per_token=chars_per_token, + ) + valid_dataset = ConstantLengthDataset( + tokenizer, + valid_data, + formatting_func=prepare_sample_text, + infinite=False, + seq_length=args.seq_length, + chars_per_token=chars_per_token, + ) + return train_dataset, valid_dataset + + +bnb_config = BitsAndBytesConfig( + load_in_4bit=True, + bnb_4bit_quant_type="nf4", + bnb_4bit_compute_dtype=torch.bfloat16, +) + +base_model = AutoModelForCausalLM.from_pretrained( + script_args.model_name, + quantization_config=bnb_config, + device_map={"": Accelerator().local_process_index}, + trust_remote_code=True, + use_auth_token=True, +) +base_model.config.use_cache = False + + +tokenizer = AutoTokenizer.from_pretrained(script_args.model_name, trust_remote_code=True) +tokenizer.pad_token = tokenizer.eos_token +tokenizer.padding_side = "right" # Fix weird overflow issue with fp16 training + +train_dataset, eval_dataset = create_datasets(tokenizer, script_args) + +trainer = SFTTrainer( + model=base_model, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + peft_config=peft_config, + packing=script_args.packing, + max_seq_length=None, + tokenizer=tokenizer, + args=training_args, +) +trainer.train() +trainer.save_model(training_args.output_dir) + +output_dir = os.path.join(training_args.output_dir, "final_checkpoint") +trainer.model.save_pretrained(output_dir) + +# Free memory for merging weights +del base_model +if is_xpu_available(): + torch.xpu.empty_cache() +else: + torch.cuda.empty_cache() + +model = AutoPeftModelForCausalLM.from_pretrained(output_dir, device_map="auto", torch_dtype=torch.bfloat16) +model = model.merge_and_unload() + +output_merged_dir = os.path.join(training_args.output_dir, "final_merged_checkpoint") +model.save_pretrained(output_merged_dir, safe_serialization=True) diff --git a/trl/trl/examples/research_projects/tools/calculator.py b/trl/trl/examples/research_projects/tools/calculator.py new file mode 100644 index 0000000000000000000000000000000000000000..76779695fe741eda8f293be5d24c1a521b13efc3 --- /dev/null +++ b/trl/trl/examples/research_projects/tools/calculator.py @@ -0,0 +1,119 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import re + +import numpy as np +import torch +from transformers import AutoTokenizer, load_tool + +from trl import AutoModelForCausalLMWithValueHead, PPOConfig, PPOTrainer, TextEnvironment + + +def generate_data(n): + """Generate random arithmetic tasks and answers.""" + tasks, answers = [], [] + for _ in range(n): + a = np.random.randint(0, 50) + b = np.random.randint(0, 50) + op = np.random.choice(["-", "+", "*"]) + tasks.append(f"\n\nWhat is {a} {op} {b}?") + if op == "-": + answers.append(a - b) + elif op == "+": + answers.append(a + b) + else: + answers.append(a * b) + return tasks, answers + + +def exact_match_reward(responses, answers=None): + """Reward if generated response contains correct answer.""" + rewards = [] + pattern = r"Result\s*=\s*(-?\d+(?:\.\d+)?)\s*" # generated by chatGPT + for response, answer in zip(responses, answers): + reward = 0.0 + predicted_number = None + match_pattern = re.findall(pattern, response) + if match_pattern: + predicted_number = float(match_pattern[0]) + if predicted_number is not None: + if np.abs(predicted_number - answer) < 0.01: + reward += 1.0 + rewards.append(torch.tensor(reward)) + return rewards + + +# set up models +model_id = "gpt2" +model = AutoModelForCausalLMWithValueHead.from_pretrained(model_id) +model_ref = AutoModelForCausalLMWithValueHead.from_pretrained(model_id) +tokenizer = AutoTokenizer.from_pretrained(model_id) +tokenizer.pad_token = tokenizer.eos_token + +# system prompt +prompt = """\ +What is 13-3? + +13-310.0 + +Result=10 + +What is 4*3? + +4*312.0 + +Result=12""" + +generation_kwargs = { + "min_length": -1, + "top_k": 0.0, + "top_p": 1.0, + "do_sample": True, + "pad_token_id": tokenizer.eos_token_id, + "eos_token_id": -1, + "max_new_tokens": 32, +} + +# trainer +ppo_config = PPOConfig( + batch_size=256, + learning_rate=1.41e-5, + mini_batch_size=64, + log_with="wandb", +) +ppo_trainer = PPOTrainer(ppo_config, model, model_ref, tokenizer) + +# text env +text_env = TextEnvironment( + model, + tokenizer, + {"SimpleCalculatorTool": load_tool("ybelkada/simple-calculator")}, + exact_match_reward, + prompt, + generation_kwargs=generation_kwargs, +) + +# main training loop +for step in range(100): + tasks, answers = generate_data(ppo_config.batch_size) + queries, responses, masks, rewards, histories = text_env.run(tasks, answers=answers) + train_stats = ppo_trainer.step(queries, responses, rewards, masks) + + response_texts = [tokenizer.decode(response) for response in responses] + query_texts = [tokenizer.decode(query) for query in queries] + texts = {"query": [qt.split("")[-1].strip() for qt in query_texts], "response": response_texts} + ppo_trainer.log_stats(train_stats, texts, rewards, columns_to_log=["query", "response", "answer"]) +ppo_trainer.save_pretrained(model_id + "-calculator") diff --git a/trl/trl/examples/research_projects/tools/python_interpreter.py b/trl/trl/examples/research_projects/tools/python_interpreter.py new file mode 100644 index 0000000000000000000000000000000000000000..b7b69806ef31922a4d49b9dc823e78dd2d9b49c4 --- /dev/null +++ b/trl/trl/examples/research_projects/tools/python_interpreter.py @@ -0,0 +1,194 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import re +from dataclasses import dataclass, field +from typing import Optional + +import numpy as np +import torch +from datasets import load_dataset +from peft import LoraConfig +from transformers import AutoTokenizer, HfArgumentParser, load_tool + +from trl import AutoModelForCausalLMWithValueHead, PPOConfig, PPOTrainer, TextEnvironment + + +os.environ["HF_ALLOW_CODE_EVAL"] = "1" +os.environ["TOKENIZERS_PARALLELISM"] = "false" + + +@dataclass +class ScriptArguments: + model_name: Optional[str] = field(default="bigcode/starcoderbase", metadata={"help": "the model name"}) + learning_rate: Optional[float] = field(default=1e-5, metadata={"help": "the learning rate"}) + mini_batch_size: Optional[int] = field(default=1, metadata={"help": "the PPO minibatch size"}) + batch_size: Optional[int] = field(default=32, metadata={"help": "the batch size"}) + gradient_accumulation_steps: Optional[int] = field( + default=16, metadata={"help": "the number of gradient accumulation steps"} + ) + max_new_tokens: Optional[int] = field(default=256, metadata={"help": "max number of generated tokens per turn"}) + ppo_epochs: Optional[int] = field(default=1, metadata={"help": "max number of ppo epochs"}) + n_epochs: Optional[int] = field(default=32, metadata={"help": "max number of ppo epochs"}) + + +parser = HfArgumentParser(ScriptArguments) +args = parser.parse_args_into_dataclasses()[0] + + +def exact_match_reward(responses, answers=None): + """Reward if generated response contains correct answer.""" + rewards = [] + pattern = r"Result\s*=\s*(-?\d+(?:\.\d+)?)\s*" # generated by chatGPT + for response, answer in zip(responses, answers): + reward = 0.0 + try: + predicted_number = None + match_pattern = re.findall(pattern, response) + if match_pattern: + predicted_number = float(match_pattern[0]) + if predicted_number is not None: + if np.abs((predicted_number - float(answer))) < 0.1: + reward += 1.0 + except: # noqa + pass + rewards.append(torch.tensor(reward)) + return rewards + + +def evaluate(test_dataloader, text_env, ppo_trainer): + test_rewards = [] + for test_batch in test_dataloader: + _, _, _, rewards, _ = text_env.run(test_batch["query"], answers=test_batch["answer"]) + test_rewards.extend(rewards) + test_rewards = ppo_trainer.accelerator.gather_for_metrics( + torch.stack(test_rewards).to(ppo_trainer.accelerator.device) + ) + return test_rewards.mean() + + +lora_config = LoraConfig( + r=16, + lora_alpha=32, + lora_dropout=0.05, + bias="none", + task_type="CAUSAL_LM", + target_modules=["c_proj", "c_attn", "q_attn"], +) + +# set up models +model = AutoModelForCausalLMWithValueHead.from_pretrained( + args.model_name, + use_auth_token=True, + load_in_4bit=True, + peft_config=lora_config, +) +tokenizer = AutoTokenizer.from_pretrained(args.model_name, use_auth_token=True) +tokenizer.pad_token = tokenizer.eos_token + +ds = load_dataset("gsm8k", "main", split="train") +ds = ds.rename_columns({"question": "query"}) +ds = ds.map(lambda x: {"answer": x["answer"].split("#### ")[1]}) +ds = ds.select(range(1, len(ds))) # skip the first sample which is used in prompt + +ds_test = load_dataset("gsm8k", "main", split="test") +ds_test = ds_test.rename_columns({"question": "query"}) +ds_test = ds_test.map(lambda x: {"answer": x["answer"].split("#### ")[1]}) + +test_dataloader = torch.utils.data.DataLoader(ds_test, batch_size=args.batch_size) + +# prompt +prompt = """\ +Example of using a Python API to solve math questions. + +Q: Olivia has $23. She bought five bagels for $3 each. How much money does she have left? + + +def solution(): + money_initial = 23 + bagels = 5 + bagel_cost = 3 + money_spent = bagels * bagel_cost + money_left = money_initial - money_spent + result = money_left + return result +print(solution()) +72 + +Result = 72 + +Q: """ + +generation_kwargs = { + "min_length": -1, + "top_k": 0.0, + "top_p": 1.0, + "do_sample": True, + "pad_token_id": tokenizer.eos_token_id, + "eos_token_id": -1, + "max_new_tokens": args.max_new_tokens, +} + +# trainer +ppo_config = PPOConfig( + batch_size=args.batch_size, + learning_rate=args.learning_rate, + mini_batch_size=args.mini_batch_size, + ppo_epochs=args.ppo_epochs, + gradient_accumulation_steps=args.gradient_accumulation_steps, + log_with="wandb", + tracker_project_name="trl-gsm8k", + remove_unused_columns=False, + optimize_cuda_cache=True, +) + +ppo_trainer = PPOTrainer(config=ppo_config, model=model, tokenizer=tokenizer, dataset=ds) +test_dataloader = ppo_trainer.accelerator.prepare(test_dataloader) + +# text env +text_env = TextEnvironment( + model, + tokenizer, + [load_tool("lvwerra/python-interpreter")], + exact_match_reward, + prompt, + max_turns=2, + generation_kwargs=generation_kwargs, +) + +# main training loop +for epoch in range(args.n_epochs): + for step, batch in enumerate(ppo_trainer.dataloader): + if (step == 0) and (epoch % 4 == 0): # evaluate every 4 epochs + reward_mean_test = evaluate(test_dataloader, text_env, ppo_trainer) + else: + reward_mean_test = None + + queries, responses, masks, rewards, histories = text_env.run(batch["query"], answers=batch["answer"]) + train_stats = ppo_trainer.step(queries, responses, rewards, masks) + + # logging + if reward_mean_test is not None: + train_stats["env/reward_mean_test"] = reward_mean_test + texts = { + "query": batch["query"], + "response": [tokenizer.decode(response) for response in responses], + "answer": batch["answer"], + } + ppo_trainer.log_stats(train_stats, texts, rewards, columns_to_log=["query", "response", "answer"]) + +reward_mean_test = evaluate(test_dataloader, text_env, ppo_trainer) +ppo_trainer.save_pretrained(f"model/{args.model_name}-gsm8k") diff --git a/trl/trl/examples/research_projects/tools/triviaqa.py b/trl/trl/examples/research_projects/tools/triviaqa.py new file mode 100644 index 0000000000000000000000000000000000000000..bd3bd9016618a2d9625831c1cd8c970b15bea646 --- /dev/null +++ b/trl/trl/examples/research_projects/tools/triviaqa.py @@ -0,0 +1,191 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +from dataclasses import dataclass, field +from typing import Optional + +import torch +from datasets import load_dataset +from peft import LoraConfig +from transformers import AutoTokenizer, HfArgumentParser, load_tool + +from trl import AutoModelForCausalLMWithValueHead, PPOConfig, PPOTrainer, TextEnvironment + + +os.environ["HF_ALLOW_CODE_EVAL"] = "1" +os.environ["TOKENIZERS_PARALLELISM"] = "false" + + +@dataclass +class ScriptArguments: + model_name: Optional[str] = field(default="bigcode/starcoderbase", metadata={"help": "the model name"}) + log_with: Optional[str] = field(default=None, metadata={"help": "use 'wandb' to log with wandb"}) + learning_rate: Optional[float] = field(default=1e-5, metadata={"help": "the learning rate"}) + mini_batch_size: Optional[int] = field(default=1, metadata={"help": "the PPO minibatch size"}) + batch_size: Optional[int] = field(default=32, metadata={"help": "the batch size"}) + gradient_accumulation_steps: Optional[int] = field( + default=16, metadata={"help": "the number of gradient accumulation steps"} + ) + max_new_tokens: Optional[int] = field(default=256, metadata={"help": "max number of generated tokens per turn"}) + ppo_epochs: Optional[int] = field(default=1, metadata={"help": "max number of ppo epochs"}) + iterations: Optional[int] = field(default=1000, metadata={"help": "the number of iterations"}) + seed: Optional[int] = field(default=0, metadata={"help": "the random seed"}) + + +parser = HfArgumentParser(ScriptArguments) +args = parser.parse_args_into_dataclasses()[0] + +lora_config = LoraConfig( + r=16, + lora_alpha=32, + lora_dropout=0.05, + bias="none", + task_type="CAUSAL_LM", + target_modules=["c_proj", "c_attn", "q_attn"], +) + +# set up models +model = AutoModelForCausalLMWithValueHead.from_pretrained( + args.model_name, + use_auth_token=True, + trust_remote_code=True, + load_in_4bit=True, + peft_config=lora_config, +) +tokenizer = AutoTokenizer.from_pretrained(args.model_name, use_auth_token=True) +tokenizer.pad_token = tokenizer.eos_token + +# system prompt +prompt = """\ +Answer the following question: + +Q: In which branch of the arts is Patricia Neary famous? +A: Ballets +A2: Patricia NearyPatricia Neary (born October 27, 1942) is an American ballerina, choreographer and ballet director, who has been particularly active in Switzerland. She has also been a highly successful ambassador for the Balanchine Trust, bringing George Balanchine's ballets to 60 cities around the globe. +Result=Ballets + +Q: Who won Super Bowl XX? +A: Chicago Bears +A2: Super Bowl XXSuper Bowl XX was an American football game between the National Football Conference (NFC) champion Chicago Bears and the American Football Conference (AFC) champion New England Patriots to decide the National Football League (NFL) champion for the 1985 season. The Bears defeated the Patriots by the score of 46–10, capturing their first NFL championship (and Chicago's first overall sports victory) since 1963, three years prior to the birth of the Super Bowl. Super Bowl XX was played on January 26, 1986 at the Louisiana Superdome in New Orleans. +Result=Chicago Bears + +Q: """ + +generation_kwargs = { + "min_length": -1, + "top_k": 0.0, + "top_p": 1.0, + "do_sample": True, + "pad_token_id": tokenizer.eos_token_id, + "eos_token_id": -1, + "max_new_tokens": args.max_new_tokens, +} + +# trainer +config = PPOConfig( + batch_size=args.batch_size, + model_name=args.model_name, + learning_rate=args.learning_rate, + log_with=args.log_with, + mini_batch_size=args.mini_batch_size, + ppo_epochs=args.ppo_epochs, + gradient_accumulation_steps=args.gradient_accumulation_steps, + seed=args.seed, + optimize_cuda_cache=True, +) +ppo_trainer = PPOTrainer(config=config, model=model, tokenizer=tokenizer) +dataset = load_dataset("trivia_qa", "rc", split="train") +local_seed = args.seed + ppo_trainer.accelerator.process_index * 100003 # Prime +dataset = dataset.shuffle(local_seed) + + +def data_generator(): + for i in range(len(dataset)): + yield dataset[i]["question"], [item for item in dataset[i]["answer"]["normalized_aliases"]] + + +gen = data_generator() +gen = iter(gen) + + +def generate_data(n): + tasks, answers = [], [] + for i in range(n): + q, a = next(gen) + tasks.append(q) + answers.append(a) + return tasks, answers + + +def exact_match_reward(responses, answers=None): + """Reward if generated response contains correct answer.""" + rewards = [] + for response, answer in zip(responses, answers): + reward = 0.0 + for a in answer: + if a.lower() in response.lower(): + reward += 1.0 + break + rewards.append(torch.tensor(reward)) + return rewards + + +# text env +tool = load_tool("vwxyzjn/pyserini-wikipedia-kilt-doc") +# limit the amount if tokens +tool_fn = lambda x: tool(x).split("\n")[1][:600] # noqa +text_env = TextEnvironment( + model, + tokenizer, + {"Wiki": tool_fn}, + exact_match_reward, + prompt, + generation_kwargs=generation_kwargs, + max_tool_reponse=400, +) + + +def print_trainable_parameters(model): + trainable_params = 0 + all_param = 0 + for _, param in model.named_parameters(): + all_param += param.numel() + if param.requires_grad: + trainable_params += param.numel() + print( + f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param}" + ) + + +print_trainable_parameters(model) +# main training loop +for i in range(args.iterations): + tasks, answers = generate_data(config.batch_size) + queries, responses, masks, rewards, histories = text_env.run(tasks, answers=answers) + train_stats = ppo_trainer.step(queries, responses, rewards, masks) + response_texts = [tokenizer.decode(response) for response in responses] + query_texts = [tokenizer.decode(query) for query in queries] + texts = { + "query": [qt.split("")[-1].strip() for qt in query_texts], + "response": response_texts, + "answer": [", ".join(item) for item in answers], + } + all_rewards = ppo_trainer.accelerator.gather(torch.tensor(rewards, device=ppo_trainer.accelerator.device)) + ppo_trainer.log_stats( + train_stats, texts, [item for item in all_rewards], columns_to_log=["query", "response", "answer"] + ) + if i % 100 == 0: + ppo_trainer.save_pretrained(f"models/{args.model_name}_{args.seed}_{i}_triviaqa") diff --git a/trl/trl/examples/research_projects/toxicity/README.md b/trl/trl/examples/research_projects/toxicity/README.md new file mode 100644 index 0000000000000000000000000000000000000000..85967ab57ec5eeb10ea9eb6e372a62a0522e4d7e --- /dev/null +++ b/trl/trl/examples/research_projects/toxicity/README.md @@ -0,0 +1,7 @@ +# De-detoxifying language models + +To run this code, do the following: + +```shell +ACCELERATE_LOG_LEVEL=info accelerate launch --config_file {CONFIG} examples/research_projects/toxicity/scripts/gpt-j-6b-toxicity.py --log_with wandb +``` diff --git a/trl/trl/examples/research_projects/toxicity/scripts/evaluate-toxicity.py b/trl/trl/examples/research_projects/toxicity/scripts/evaluate-toxicity.py new file mode 100644 index 0000000000000000000000000000000000000000..c400641967544d96b768bd43f84536c393fe7684 --- /dev/null +++ b/trl/trl/examples/research_projects/toxicity/scripts/evaluate-toxicity.py @@ -0,0 +1,130 @@ +import argparse +import csv + +import evaluate +import numpy as np +import torch +from datasets import load_dataset +from tqdm import tqdm +from transformers import AutoModelForCausalLM, AutoTokenizer + +from trl.import_utils import is_xpu_available + + +toxicity = evaluate.load("ybelkada/toxicity", "DaNLP/da-electra-hatespeech-detection", module_type="measurement") +ds = load_dataset("OxAISH-AL-LLM/wiki_toxic", split="test") + +parser = argparse.ArgumentParser(description="Evaluate de-toxified models") +parser.add_argument("--model_type", default="all", type=str, help="Relative path to the source model folder") +parser.add_argument("--output_file", default="toxicity.csv", type=str, help="Relative path to the source model folder") +parser.add_argument("--batch_size", default=64, type=int, help="Batch size") +parser.add_argument("--num_samples", default=400, type=int, help="Number of samples") +parser.add_argument("--context_length", default=2000, type=int, help="Number of samples") +parser.add_argument("--max_new_tokens", default=30, type=int, help="Max new tokens for generation") +args = parser.parse_args() + + +if args.model_type == "all": + MODELS_TO_TEST = [ + "ybelkada/gpt-neo-125m-detox", + "EleutherAI/gpt-neo-125M", + "EleutherAI/gpt-neo-2.7B", + "ybelkada/gpt-neo-2.7B-detox", + "ybelkada/gpt-j-6b-sharded-bf16", + "ybelkada/gpt-j-6b-detoxs", + ] +elif args.model_type == "gpt-neo": + MODELS_TO_TEST = [ + "ybelkada/gpt-neo-125m-detox", + "EleutherAI/gpt-neo-125M", + "EleutherAI/gpt-neo-2.7B", + "ybelkada/gpt-neo-2.7B-detox", + ] +elif args.model_type == "gpt-j": + MODELS_TO_TEST = [ + "ybelkada/gpt-j-6b-sharded-bf16", + "ybelkada/gpt-j-6b-detox", + ] +else: + MODELS_TO_TEST = [args.model_type] +NUM_SAMPLES = args.num_samples +BATCH_SIZE = args.batch_size +output_file = args.output_file +max_new_tokens = args.max_new_tokens +context_length = args.context_length +if is_xpu_available(): + device = torch.xpu.current_device() +else: + device = torch.cuda.current_device() if torch.cuda.is_available() else "cpu" + +# consider only toxic prompts +ds = ds.filter(lambda x: x["label"] == 1) + +toxicities = {} + +# open a csv file +file = open(f"{output_file}", "w", newline="") +writer = csv.writer(file) +# add first rows +writer.writerow(["model_id", "mean_toxicity", "std_toxicity"]) + + +for model_id in tqdm(MODELS_TO_TEST): + model = AutoModelForCausalLM.from_pretrained(model_id, device_map={"": device}, torch_dtype=torch.bfloat16) + tokenizer = AutoTokenizer.from_pretrained(model_id) + tokenizer.pad_token = tokenizer.eos_token + tokenizer.padding_side = "left" + input_texts = [] + + for i, example in enumerate(ds): + # set seed + torch.manual_seed(42) + + input_text = example["comment_text"] + input_texts.append(input_text[:2000]) + + if i > NUM_SAMPLES: + break + + if (i + 1) % BATCH_SIZE == 0: + inputs = tokenizer(input_texts, return_tensors="pt", padding=True).to(device) + inputs.input_ids = inputs.input_ids[:context_length] + inputs.attention_mask = inputs.attention_mask[:context_length] + outputs = model.generate(**inputs, do_sample=True, max_new_tokens=max_new_tokens, use_cache=True) + generated_texts = tokenizer.batch_decode(outputs, skip_special_tokens=True) + generated_texts = [ + generated_text.replace(input_texts[i], "") for i, generated_text in enumerate(generated_texts) + ] + toxicity_score = toxicity.compute(predictions=generated_texts) + input_texts = [] + + if model_id not in toxicities: + toxicities[model_id] = [] + toxicities[model_id].extend(toxicity_score["toxicity"]) + + # last batch + inputs = tokenizer(input_texts, return_tensors="pt", padding=True).to(device) + outputs = model.generate(**inputs, do_sample=True, max_new_tokens=30) + generated_texts = tokenizer.batch_decode(outputs, skip_special_tokens=True) + generated_texts = [generated_text.replace(input_texts[i], "") for i, generated_text in enumerate(generated_texts)] + toxicity_score = toxicity.compute(predictions=generated_texts) + toxicities[model_id].extend(toxicity_score["toxicity"]) + + # compute mean & std using np + mean = np.mean(toxicities[model_id]) + std = np.std(toxicities[model_id]) + + # save to file + writer.writerow([model_id, mean, std]) + + # print + print(f"Model: {model_id} - Mean: {mean} - Std: {std}") + + model = None + if is_xpu_available(): + torch.xpu.empty_cache() + else: + torch.cuda.empty_cache() + +# close file +file.close() diff --git a/trl/trl/examples/research_projects/toxicity/scripts/gpt-j-6b-toxicity.py b/trl/trl/examples/research_projects/toxicity/scripts/gpt-j-6b-toxicity.py new file mode 100644 index 0000000000000000000000000000000000000000..a4fc18534b25d7dd564816675d47caa82957896a --- /dev/null +++ b/trl/trl/examples/research_projects/toxicity/scripts/gpt-j-6b-toxicity.py @@ -0,0 +1,238 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from dataclasses import dataclass, field +from typing import Optional + +import torch +from datasets import load_dataset +from torch.optim import Adam +from tqdm import tqdm +from transformers import ( + AutoModelForCausalLM, + AutoTokenizer, + HfArgumentParser, + RobertaForSequenceClassification, + RobertaTokenizer, +) + +from trl import AutoModelForCausalLMWithValueHead, PPOConfig, PPOTrainer, create_reference_model, set_seed +from trl.core import LengthSampler + + +tqdm.pandas() + +######################################################################## +# This is a fully working simple example to use trl with accelerate. +# +# This example fine-tunes a GPTJ model to generate less toxic contents +# by using allenai/real-toxicity-prompts dataset. We use PPO +# (proximal policy optimization) to optimize the model. +# in any of the following settings (with the same script): +# - single CPU or single GPU +# - multi GPUS (using PyTorch distributed mode) +# - multi GPUS (using DeepSpeed ZeRO-Offload stages 1 & 2) +# - fp16 (mixed-precision) or fp32 (normal precision) +# +# To run it in each of these various modes, first initialize the accelerate +# configuration with `accelerate config` +# +######################################################################## + + +# We first define the configuration of the experiment, defining the model, the dataset, +# the training parameters, and the PPO parameters. +# Check the default arguments in the `PPOConfig` class for more details. +# If you want to log with tensorboard, add the kwarg +# `project_kwargs={"logging_dir": PATH_TO_LOGS}` to the PPOConfig. +@dataclass +class ScriptArguments: + """ + The name of the Casual LM model we wish to fine with PPO + """ + + # NOTE: gpt2 models use Conv1D instead of Linear layers which are not yet supported in 8 bit mode + # models like gpt-neo* models are more suitable. + model_name: Optional[str] = field(default="ybelkada/gpt-j-6b-sharded-bf16", metadata={"help": "the model name"}) + log_with: Optional[str] = field(default=None, metadata={"help": "use 'wandb' to log with wandb"}) + learning_rate: Optional[float] = field(default=(1.47e-5) * 2, metadata={"help": "the learning rate"}) + mini_batch_size: Optional[int] = field(default=4, metadata={"help": "the PPO minibatch size"}) + batch_size: Optional[int] = field(default=16, metadata={"help": "the batch size"}) + gradient_accumulation_steps: Optional[int] = field( + default=1, metadata={"help": "the number of gradient accumulation steps"} + ) + model_save_path: Optional[str] = field( + default="./gpt-j-6B-detoxified-long-context-26-shl-1e4-final", + metadata={"help": "the path to save the model"}, + ) + + +parser = HfArgumentParser(ScriptArguments) +script_args = parser.parse_args_into_dataclasses()[0] + +config = PPOConfig( + model_name=script_args.model_name, + learning_rate=script_args.learning_rate, + log_with=script_args.log_with, + ppo_epochs=100, + mini_batch_size=script_args.mini_batch_size, + batch_size=script_args.batch_size, + gradient_accumulation_steps=script_args.gradient_accumulation_steps, +) + + +# Below is an example function to build the dataset. In our case, we use the IMDB dataset +# from the `datasets` library. One should customize this function to train the model on +# its own dataset. +def build_dataset( + config, dataset_name="allenai/real-toxicity-prompts", input_min_text_length=5, input_max_text_length=10 +): + """ + Build dataset for training. This builds the dataset from `load_dataset`, one should + customize this function to train the model on its own dataset. + + Args: + dataset_name (`str`): + The name of the dataset to be loaded. + + Returns: + dataloader (`torch.utils.data.DataLoader`): + The dataloader for the dataset. + """ + tokenizer = AutoTokenizer.from_pretrained(config.model_name) + tokenizer.pad_token = tokenizer.eos_token + + ds = load_dataset(dataset_name, split="train") + + def filter_fn(sample): + toxicity = sample["prompt"]["toxicity"] + return toxicity is not None and toxicity > 0.3 + + ds = ds.filter(filter_fn, batched=False) + + input_size = LengthSampler(input_min_text_length, input_max_text_length) + + def tokenize(sample): + prompt = sample["prompt"]["text"] + continuation = sample["continuation"]["text"] + + sample["input_ids"] = tokenizer.encode(prompt + continuation)[: input_size()] + sample["query"] = tokenizer.decode(sample["input_ids"]) + return sample + + ds = ds.map(tokenize, batched=False) + ds.set_format(type="torch") + + ds = ds.train_test_split(test_size=0.2, shuffle=False)["train"] + + return ds + + +# We retrieve the dataloader by calling the `build_dataset` function. +min_input_length = 30 +max_input_length = 40 +dataset = build_dataset(config, input_min_text_length=min_input_length, input_max_text_length=max_input_length) + + +def collator(data): + return dict((key, [d[key] for d in data]) for key in data[0]) + + +# set seed before initializing value head for deterministic eval +set_seed(config.seed) + +# Now let's build the model, the reference model, and the tokenizer. We first load the model +# in bfloat16 to save memory using `transformers`. +model = AutoModelForCausalLM.from_pretrained(config.model_name, torch_dtype=torch.bfloat16) +# And then we pass the loaded model to `AutoModelForCausalLMWithValueHead`. +model = AutoModelForCausalLMWithValueHead.from_pretrained(model) + +# We create a reference model by sharing 20 layers +ref_model = create_reference_model(model, num_shared_layers=20) + +# We make sure to use `Adam` optimizer on the model parameters that require gradients. +optimizer = Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=config.learning_rate) + +# GPT-2 / GPT-J tokenizer has a pad token, but it is not eos_token by default. We need to set it to eos_token. +# only for this model. +tokenizer = AutoTokenizer.from_pretrained(config.model_name) +tokenizer.pad_token = tokenizer.eos_token + +# We then build the PPOTrainer, passing the model, the reference model, the tokenizer +ppo_trainer = PPOTrainer( + config, + model, + ref_model=ref_model, + tokenizer=tokenizer, + dataset=dataset, + data_collator=collator, + optimizer=optimizer, +) + +# We then build the reward pipeline, we will use the toxicity model to compute the reward. +# We first load the toxicity model and tokenizer. +toxicity_model_id = "facebook/roberta-hate-speech-dynabench-r4-target" +toxicity_tokenizer = RobertaTokenizer.from_pretrained(toxicity_model_id) +# We load the toxicity model in fp16 to save memory. +toxicity_model = RobertaForSequenceClassification.from_pretrained(toxicity_model_id, torch_dtype=torch.float16).to( + ppo_trainer.accelerator.device +) + + +# We then define the arguments to pass to the `generate` function. These arguments +# are passed to the `generate` function of the PPOTrainer, which is a wrapper around +# the `generate` function of the trained model. +generation_kwargs = { + "min_length": -1, + "top_k": 0.0, + "top_p": 1.0, + "do_sample": True, + "pad_token_id": tokenizer.eos_token_id, +} +output_min_length = 20 +output_max_length = 30 +output_length_sampler = LengthSampler(output_min_length, output_max_length) + +model_save_path = script_args.model_save_path + +for epoch, batch in tqdm(enumerate(ppo_trainer.dataloader)): + query_tensors = batch["input_ids"] + + # Get response from the policy model + response_tensors = [] + for query in query_tensors: + gen_len = output_length_sampler() + generation_kwargs["max_new_tokens"] = gen_len + response = ppo_trainer.generate(query, **generation_kwargs) + response_tensors.append(response.squeeze()[-gen_len:]) + batch["response"] = [tokenizer.decode(r.squeeze()) for r in response_tensors] + + # Compute sentiment score # noqa + texts = batch["response"] + toxicity_inputs = toxicity_tokenizer(texts, padding=True, truncation=True, return_tensors="pt").to( + ppo_trainer.accelerator.device + ) + logits = toxicity_model(**toxicity_inputs).logits.float() + toxicity_labels = (logits[:, 0]).tolist() + + rewards = [torch.tensor(output) for output in toxicity_labels] + + # Run PPO step + stats = ppo_trainer.step(query_tensors, response_tensors, rewards) + ppo_trainer.log_stats(stats, batch, rewards) + + # Save model every 100 epochs + if epoch % 100 == 0: + if ppo_trainer.accelerator.is_main_process: + ppo_trainer.save_pretrained(model_save_path) diff --git a/trl/trl/examples/scripts/ddpo.py b/trl/trl/examples/scripts/ddpo.py new file mode 100644 index 0000000000000000000000000000000000000000..d42145e4d5aff761826c0d6bfefda6712c92bd22 --- /dev/null +++ b/trl/trl/examples/scripts/ddpo.py @@ -0,0 +1,204 @@ +# Copyright 2023 metric-space, The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import os +from dataclasses import dataclass, field + +import numpy as np +import torch +import torch.nn as nn +import tyro +from huggingface_hub import hf_hub_download +from huggingface_hub.utils import EntryNotFoundError +from transformers import CLIPModel, CLIPProcessor + +from trl import DDPOConfig, DDPOTrainer, DefaultDDPOStableDiffusionPipeline +from trl.import_utils import is_xpu_available + + +@dataclass +class ScriptArguments: + hf_user_access_token: str + pretrained_model: str = "runwayml/stable-diffusion-v1-5" + """the pretrained model to use""" + pretrained_revision: str = "main" + """the pretrained model revision to use""" + hf_hub_model_id: str = "ddpo-finetuned-stable-diffusion" + """HuggingFace repo to save model weights to""" + hf_hub_aesthetic_model_id: str = "trl-lib/ddpo-aesthetic-predictor" + """HuggingFace model ID for aesthetic scorer model weights""" + hf_hub_aesthetic_model_filename: str = "aesthetic-model.pth" + """HuggingFace model filename for aesthetic scorer model weights""" + + ddpo_config: DDPOConfig = field( + default_factory=lambda: DDPOConfig( + num_epochs=200, + train_gradient_accumulation_steps=1, + sample_num_steps=50, + sample_batch_size=6, + train_batch_size=3, + sample_num_batches_per_epoch=4, + per_prompt_stat_tracking=True, + per_prompt_stat_tracking_buffer_size=32, + tracker_project_name="stable_diffusion_training", + log_with="wandb", + project_kwargs={ + "logging_dir": "./logs", + "automatic_checkpoint_naming": True, + "total_limit": 5, + "project_dir": "./save", + }, + ) + ) + + +class MLP(nn.Module): + def __init__(self): + super().__init__() + self.layers = nn.Sequential( + nn.Linear(768, 1024), + nn.Dropout(0.2), + nn.Linear(1024, 128), + nn.Dropout(0.2), + nn.Linear(128, 64), + nn.Dropout(0.1), + nn.Linear(64, 16), + nn.Linear(16, 1), + ) + + @torch.no_grad() + def forward(self, embed): + return self.layers(embed) + + +class AestheticScorer(torch.nn.Module): + """ + This model attempts to predict the aesthetic score of an image. The aesthetic score + is a numerical approximation of how much a specific image is liked by humans on average. + This is from https://github.com/christophschuhmann/improved-aesthetic-predictor + """ + + def __init__(self, *, dtype, model_id, model_filename): + super().__init__() + self.clip = CLIPModel.from_pretrained("openai/clip-vit-large-patch14") + self.processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14") + self.mlp = MLP() + try: + cached_path = hf_hub_download(model_id, model_filename) + except EntryNotFoundError: + cached_path = os.path.join(model_id, model_filename) + state_dict = torch.load(cached_path) + self.mlp.load_state_dict(state_dict) + self.dtype = dtype + self.eval() + + @torch.no_grad() + def __call__(self, images): + device = next(self.parameters()).device + inputs = self.processor(images=images, return_tensors="pt") + inputs = {k: v.to(self.dtype).to(device) for k, v in inputs.items()} + embed = self.clip.get_image_features(**inputs) + # normalize embedding + embed = embed / torch.linalg.vector_norm(embed, dim=-1, keepdim=True) + return self.mlp(embed).squeeze(1) + + +def aesthetic_scorer(hub_model_id, model_filename): + scorer = AestheticScorer( + model_id=hub_model_id, + model_filename=model_filename, + dtype=torch.float32, + ) + scorer = scorer.xpu() if is_xpu_available() else scorer.cuda() + + def _fn(images, prompts, metadata): + images = (images * 255).round().clamp(0, 255).to(torch.uint8) + scores = scorer(images) + return scores, {} + + return _fn + + +# list of example prompts to feed stable diffusion +animals = [ + "cat", + "dog", + "horse", + "monkey", + "rabbit", + "zebra", + "spider", + "bird", + "sheep", + "deer", + "cow", + "goat", + "lion", + "frog", + "chicken", + "duck", + "goose", + "bee", + "pig", + "turkey", + "fly", + "llama", + "camel", + "bat", + "gorilla", + "hedgehog", + "kangaroo", +] + + +def prompt_fn(): + return np.random.choice(animals), {} + + +def image_outputs_logger(image_data, global_step, accelerate_logger): + # For the sake of this example, we will only log the last batch of images + # and associated data + result = {} + images, prompts, _, rewards, _ = image_data[-1] + + for i, image in enumerate(images): + prompt = prompts[i] + reward = rewards[i].item() + result[f"{prompt:.25} | {reward:.2f}"] = image.unsqueeze(0) + + accelerate_logger.log_images( + result, + step=global_step, + ) + + +if __name__ == "__main__": + args = tyro.cli(ScriptArguments) + + pipeline = DefaultDDPOStableDiffusionPipeline( + args.pretrained_model, pretrained_model_revision=args.pretrained_revision, use_lora=True + ) + + trainer = DDPOTrainer( + args.ddpo_config, + aesthetic_scorer(args.hf_hub_aesthetic_model_id, args.hf_hub_aesthetic_model_filename), + prompt_fn, + pipeline, + image_samples_hook=image_outputs_logger, + ) + + trainer.train() + + trainer.push_to_hub(args.hf_hub_model_id, token=args.hf_user_access_token) diff --git a/trl/trl/examples/scripts/dpo.py b/trl/trl/examples/scripts/dpo.py new file mode 100644 index 0000000000000000000000000000000000000000..337c99af6acc2c4bd013cbcddbea2257ceea7b07 --- /dev/null +++ b/trl/trl/examples/scripts/dpo.py @@ -0,0 +1,182 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Note: you need to install transformers from main to run this script. See https://huggingface.co/docs/transformers/installation#install-from-source +# TODO: bump transformers version in requirements at next release. + +# 0. imports +from dataclasses import dataclass, field +from typing import Dict, Optional + +import torch +from datasets import Dataset, load_dataset +from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser, TrainingArguments + +from trl import DPOTrainer + + +# Define and parse arguments. +@dataclass +class ScriptArguments: + """ + The arguments for the DPO training script. + """ + + # data parameters + beta: Optional[float] = field(default=0.1, metadata={"help": "the beta parameter for DPO loss"}) + + # training parameters + model_name_or_path: Optional[str] = field(default="gpt2", metadata={"help": "the model name"}) + learning_rate: Optional[float] = field(default=1e-3, metadata={"help": "optimizer learning rate"}) + per_device_train_batch_size: Optional[int] = field(default=4, metadata={"help": "batch size per device"}) + gradient_accumulation_steps: Optional[int] = field( + default=1, metadata={"help": "the number of gradient accumulation steps"} + ) + max_length: Optional[int] = field(default=512, metadata={"help": "max length of each sample"}) + max_prompt_length: Optional[int] = field(default=128, metadata={"help": "max length of each sample's prompt"}) + max_target_length: Optional[int] = field( + default=128, metadata={"help": "Only used for encoder decoder model. Max target of each sample's prompt"} + ) + label_pad_token_id: Optional[int] = field(default=-100, metadata={"help": "label for non response tokens"}) + max_steps: Optional[int] = field(default=1000, metadata={"help": "max number of training steps"}) + # instrumentation + sanity_check: Optional[bool] = field(default=True, metadata={"help": "only train on 1000 samples"}) + report_to: Optional[str] = field( + default=None, + metadata={ + "help": 'The list of integrations to report the results and logs to. Supported platforms are `"azure_ml"`,' + '`"comet_ml"`, `"mlflow"`, `"neptune"`, `"tensorboard"`,`"clearml"` and `"wandb"`. ' + 'Use `"all"` to report to all integrations installed, `"none"` for no integrations.' + }, + ) + # debug argument for distributed training + ignore_bias_buffers: Optional[bool] = field( + default=False, + metadata={ + "help": "fix for DDP issues with LM bias/mask buffers - invalid scalar type,`inplace operation. See" + "https://github.com/huggingface/transformers/issues/22482#issuecomment-1595790992" + }, + ) + gradient_checkpointing: Optional[bool] = field( + default=False, metadata={"help": "Whether to use gradient checkpointing or no"} + ) + gradient_checkpointing_kwargs: Optional[dict] = field( + default=None, + metadata={ + "help": "key word arguments to be passed along `torch.utils.checkpoint.checkpoint` method - e.g. `use_reentrant=False`" + }, + ) + + +def extract_anthropic_prompt(prompt_and_response): + """Extract the anthropic prompt from a prompt and response pair.""" + search_term = "\n\nAssistant:" + search_term_idx = prompt_and_response.rfind(search_term) + assert search_term_idx != -1, f"Prompt and response does not contain '{search_term}'" + return prompt_and_response[: search_term_idx + len(search_term)] + + +def get_hh(split: str, sanity_check: bool = False, silent: bool = False, cache_dir: str = None) -> Dataset: + """Load the Anthropic Helpful-Harmless dataset from Hugging Face and convert it to the necessary format. + + The dataset is converted to a dictionary with the following structure: + { + 'prompt': List[str], + 'chosen': List[str], + 'rejected': List[str], + } + + Prompts should be structured as follows: + \n\nHuman: \n\nAssistant: + Multiple turns are allowed, but the prompt should always start with \n\nHuman: and end with \n\nAssistant:. + """ + dataset = load_dataset("Anthropic/hh-rlhf", split=split, cache_dir=cache_dir) + if sanity_check: + dataset = dataset.select(range(min(len(dataset), 1000))) + + def split_prompt_and_responses(sample) -> Dict[str, str]: + prompt = extract_anthropic_prompt(sample["chosen"]) + return { + "prompt": prompt, + "chosen": sample["chosen"][len(prompt) :], + "rejected": sample["rejected"][len(prompt) :], + } + + return dataset.map(split_prompt_and_responses) + + +if __name__ == "__main__": + parser = HfArgumentParser(ScriptArguments) + script_args = parser.parse_args_into_dataclasses()[0] + + # 1. load a pretrained model + model = AutoModelForCausalLM.from_pretrained(script_args.model_name_or_path) + + if script_args.ignore_bias_buffers: + # torch distributed hack + model._ddp_params_and_buffers_to_ignore = [ + name for name, buffer in model.named_buffers() if buffer.dtype == torch.bool + ] + + model_ref = AutoModelForCausalLM.from_pretrained(script_args.model_name_or_path) + + tokenizer = AutoTokenizer.from_pretrained(script_args.model_name_or_path) + if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token + + # 2. Load the Anthropic Helpful-Harmless dataset + train_dataset = get_hh("train", sanity_check=script_args.sanity_check) + + # 3. Load evaluation dataset + eval_dataset = get_hh("test", sanity_check=script_args.sanity_check) + + # 4. initialize training arguments: + training_args = TrainingArguments( + per_device_train_batch_size=script_args.per_device_train_batch_size, + max_steps=script_args.max_steps, + remove_unused_columns=False, + gradient_accumulation_steps=script_args.gradient_accumulation_steps, + learning_rate=script_args.learning_rate, + evaluation_strategy="steps", + logging_first_step=True, + logging_steps=10, # match results in blog post + eval_steps=500, + output_dir="./test", + optim="rmsprop", + warmup_steps=150, + report_to=script_args.report_to, + bf16=True, + gradient_checkpointing=script_args.gradient_checkpointing, + # TODO: uncomment that on the next transformers release + # gradient_checkpointing_kwargs=script_args.gradient_checkpointing_kwargs, + ) + + # 5. initialize the DPO trainer + dpo_trainer = DPOTrainer( + model, + model_ref, + args=training_args, + beta=script_args.beta, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + tokenizer=tokenizer, + max_length=script_args.max_length, + max_target_length=script_args.max_target_length, + max_prompt_length=script_args.max_prompt_length, + generate_during_eval=True, + ) + + # 6. train + dpo_trainer.train() diff --git a/trl/trl/examples/scripts/ppo.py b/trl/trl/examples/scripts/ppo.py new file mode 100644 index 0000000000000000000000000000000000000000..d84c1083820a5f6b9145f82f4506f1108117c269 --- /dev/null +++ b/trl/trl/examples/scripts/ppo.py @@ -0,0 +1,210 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from dataclasses import dataclass, field +from typing import Optional + +import torch +import tyro +from accelerate import Accelerator +from datasets import load_dataset +from peft import LoraConfig +from tqdm import tqdm +from transformers import AutoTokenizer, pipeline + +from trl import AutoModelForCausalLMWithValueHead, AutoModelForSeq2SeqLMWithValueHead, PPOConfig, PPOTrainer, set_seed +from trl.core import LengthSampler +from trl.import_utils import is_xpu_available + + +tqdm.pandas() + + +@dataclass +class ScriptArguments: + ppo_config: PPOConfig = field( + default_factory=lambda: PPOConfig( + model_name="lvwerra/gpt2-imdb", + query_dataset="imdb", + reward_model="sentiment-analysis:lvwerra/distilbert-imdb", + learning_rate=1.41e-5, + log_with=None, + mini_batch_size=128, + batch_size=128, + gradient_accumulation_steps=1, + early_stopping=False, + target_kl=6.0, + kl_penalty="kl", + seed=0, + use_score_scaling=False, + use_score_norm=False, + score_clip=None, + ) + ) + use_seq2seq: bool = False + """whether to use seq2seq models""" + use_peft: bool = False + """whether to use peft""" + peft_config: Optional[LoraConfig] = field( + default_factory=lambda: LoraConfig( + r=16, + lora_alpha=16, + bias="none", + task_type="CAUSAL_LM", + ), + ) + trust_remote_code: bool = field(default=False, metadata={"help": "Enable `trust_remote_code`"}) + + +args = tyro.cli(ScriptArguments) + + +# We then define the arguments to pass to the sentiment analysis pipeline. +# We set `return_all_scores` to True to get the sentiment score for each token. +sent_kwargs = {"return_all_scores": True, "function_to_apply": "none", "batch_size": 16} + +trl_model_class = AutoModelForCausalLMWithValueHead if not args.use_seq2seq else AutoModelForSeq2SeqLMWithValueHead + + +# Below is an example function to build the dataset. In our case, we use the IMDB dataset +# from the `datasets` library. One should customize this function to train the model on +# its own dataset. +def build_dataset(config, query_dataset, input_min_text_length=2, input_max_text_length=8): + """ + Build dataset for training. This builds the dataset from `load_dataset`, one should + customize this function to train the model on its own dataset. + + Args: + query_dataset (`str`): + The name of the dataset to be loaded. + + Returns: + dataloader (`torch.utils.data.DataLoader`): + The dataloader for the dataset. + """ + tokenizer = AutoTokenizer.from_pretrained(config.model_name) + tokenizer.pad_token = tokenizer.eos_token + # load imdb with datasets + ds = load_dataset(query_dataset, split="train") + ds = ds.rename_columns({"text": "review"}) + ds = ds.filter(lambda x: len(x["review"]) > 200, batched=False) + + input_size = LengthSampler(input_min_text_length, input_max_text_length) + + def tokenize(sample): + sample["input_ids"] = tokenizer.encode(sample["review"])[: input_size()] + sample["query"] = tokenizer.decode(sample["input_ids"]) + return sample + + ds = ds.map(tokenize, batched=False) + ds.set_format(type="torch") + return ds + + +# We retrieve the dataloader by calling the `build_dataset` function. +dataset = build_dataset(args.ppo_config, args.ppo_config.query_dataset) + + +def collator(data): + return dict((key, [d[key] for d in data]) for key in data[0]) + + +# set seed before initializing value head for deterministic eval +set_seed(args.ppo_config.seed) + +# Now let's build the model, the reference model, and the tokenizer. +if not args.use_peft: + ref_model = trl_model_class.from_pretrained(args.ppo_config.model_name, trust_remote_code=args.trust_remote_code) + device_map = None + peft_config = None +else: + peft_config = args.peft_config + ref_model = None + # Copy the model to each device + device_map = {"": Accelerator().local_process_index} + +model = trl_model_class.from_pretrained( + args.ppo_config.model_name, + trust_remote_code=args.trust_remote_code, + device_map=device_map, + peft_config=peft_config, +) + + +tokenizer = AutoTokenizer.from_pretrained(args.ppo_config.model_name) + +# Some tokenizers like GPT-2's don't have a padding token by default, so we set one here. +tokenizer.pad_token_id = tokenizer.eos_token_id + +# We then build the PPOTrainer, passing the model, the reference model, the tokenizer +ppo_trainer = PPOTrainer(args.ppo_config, model, ref_model, tokenizer, dataset=dataset, data_collator=collator) + +# We then build the sentiment analysis pipeline, passing the model name and the +# sentiment analysis pipeline arguments. Let's also make sure to set the device +# to the same device as the PPOTrainer. +device = ppo_trainer.accelerator.device +if ppo_trainer.accelerator.num_processes == 1: + if is_xpu_available(): + device = "xpu:0" + else: + device = 0 if torch.cuda.is_available() else "cpu" # to avoid a `pipeline` bug +ds_plugin = ppo_trainer.accelerator.state.deepspeed_plugin +task, model_name = args.ppo_config.reward_model.split(":") +if ds_plugin is not None and ds_plugin.is_zero3_init_enabled(): + with ds_plugin.zero3_init_context_manager(enable=False): + sentiment_pipe = pipeline(task, model=model_name, device=device) +else: + sentiment_pipe = pipeline(task, model=model_name, device=device) + +# Some tokenizers like GPT-2's don't have a padding token by default, so we set one here. +if sentiment_pipe.tokenizer.pad_token_id is None: + sentiment_pipe.tokenizer.pad_token_id = tokenizer.pad_token_id + +if sentiment_pipe.model.config.pad_token_id is None: + sentiment_pipe.model.config.pad_token_id = tokenizer.pad_token_id + +# We then define the arguments to pass to the `generate` function. These arguments +# are passed to the `generate` function of the PPOTrainer, which is a wrapper around +# the `generate` function of the trained model. +generation_kwargs = { + "min_length": -1, + "top_k": 0.0, + "top_p": 1.0, + "do_sample": True, + "pad_token_id": tokenizer.eos_token_id, + "max_new_tokens": 32, +} + +for epoch, batch in tqdm(enumerate(ppo_trainer.dataloader)): + query_tensors = batch["input_ids"] + + # Get response from gpt2 + response_tensors, ref_response_tensors = ppo_trainer.generate( + query_tensors, return_prompt=False, generate_ref_response=True, **generation_kwargs + ) + batch["response"] = tokenizer.batch_decode(response_tensors) + batch["ref_response"] = tokenizer.batch_decode(ref_response_tensors) + + # Compute sentiment score + texts = [q + r for q, r in zip(batch["query"], batch["response"])] + pipe_outputs = sentiment_pipe(texts, **sent_kwargs) + rewards = [torch.tensor(output[1]["score"]) for output in pipe_outputs] + ref_texts = [q + r for q, r in zip(batch["query"], batch["ref_response"])] + ref_pipe_outputs = sentiment_pipe(ref_texts, **sent_kwargs) + ref_rewards = [torch.tensor(output[1]["score"]) for output in ref_pipe_outputs] + batch["ref_rewards"] = ref_rewards + + # Run PPO step + stats = ppo_trainer.step(query_tensors, response_tensors, rewards) + ppo_trainer.log_stats(stats, batch, rewards, columns_to_log=["query", "response", "ref_response", "ref_rewards"]) diff --git a/trl/trl/examples/scripts/ppo_multi_adapter.py b/trl/trl/examples/scripts/ppo_multi_adapter.py new file mode 100644 index 0000000000000000000000000000000000000000..044a2419de8b0d3859d1951a34d85824caa58480 --- /dev/null +++ b/trl/trl/examples/scripts/ppo_multi_adapter.py @@ -0,0 +1,151 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from dataclasses import dataclass, field +from typing import Optional + +import torch +from datasets import load_dataset +from peft import LoraConfig +from tqdm import tqdm +from transformers import AutoTokenizer, BitsAndBytesConfig, HfArgumentParser + +from trl import AutoModelForCausalLMWithValueHead, PPOConfig, PPOTrainer, is_xpu_available +from trl.core import LengthSampler + + +input_min_text_length = 6 +input_max_text_length = 12 + + +@dataclass +class ScriptArguments: + """ + The name of the Casual LM model we wish to fine with PPO + """ + + model_name: Optional[str] = field(default="huggyllama/llama-7b", metadata={"help": "the model name"}) + dataset_name: Optional[str] = field(default="Anthropic/hh-rlhf", metadata={"help": "the dataset name"}) + rm_adapter: Optional[str] = field( + default="trl-lib/llama-7b-hh-rm-adapter", metadata={"help": "the rm adapter name"} + ) + log_with: Optional[str] = field(default=None, metadata={"help": "use 'wandb' to log with wandb"}) + use_safetensors: Optional[bool] = field(default=False, metadata={"help": "Use safetensors"}) + seed: Optional[int] = field(default=0, metadata={"help": "the random seed"}) + use_score_scaling: Optional[bool] = field(default=False, metadata={"help": "Use score scaling"}) + use_score_norm: Optional[bool] = field( + default=False, metadata={"help": "Use score normalization. Only applicable if use_score_scaling is True"} + ) + score_clip: Optional[float] = field(default=None, metadata={"help": "Score clipping"}) + + +parser = HfArgumentParser(ScriptArguments) +script_args = parser.parse_args_into_dataclasses()[0] + + +def create_and_prepare_dataset(tokenizer): + dataset = load_dataset(script_args.dataset_name, split="train[:1%]") + + input_size = LengthSampler(input_min_text_length, input_max_text_length) + + def tokenize(example): + text_size = input_size() + example["input_ids"] = tokenizer.encode(example["chosen"])[:text_size] + example["query"] = tokenizer.decode(example["input_ids"]) + return example + + dataset = dataset.map(tokenize, batched=False) + dataset.set_format("torch") + return dataset + + +lora_config = LoraConfig( + r=16, + lora_alpha=32, + lora_dropout=0.05, + bias="none", + task_type="CAUSAL_LM", +) +nf4_config = BitsAndBytesConfig( + load_in_4bit=True, bnb_4bit_quant_type="nf4", bnb_4bit_use_double_quant=True, bnb_4bit_compute_dtype=torch.bfloat16 +) +model = AutoModelForCausalLMWithValueHead.from_pretrained( + script_args.model_name, + device_map={"": "xpu:0"} if is_xpu_available() else {"": 0}, + peft_config=lora_config, + quantization_config=nf4_config, + reward_adapter=script_args.rm_adapter, + use_safetensors=script_args.use_safetensors, +) +tokenizer = AutoTokenizer.from_pretrained(script_args.model_name) + +tokenizer.pad_token = tokenizer.eos_token + +dataset = create_and_prepare_dataset(tokenizer) + + +def collator(data): + return dict((key, [d[key] for d in data]) for key in data[0]) + + +config = PPOConfig( + model_name=script_args.model_name, + log_with=script_args.log_with, + learning_rate=1e-5, + batch_size=8, + mini_batch_size=2, + gradient_accumulation_steps=2, + optimize_cuda_cache=True, + seed=script_args.seed, + use_score_scaling=script_args.use_score_scaling, + use_score_norm=script_args.use_score_norm, + score_clip=script_args.score_clip, +) + +ppo_trainer = PPOTrainer( + config, + model, + ref_model=None, + tokenizer=tokenizer, + dataset=dataset, + data_collator=collator, +) + +generation_kwargs = { + "top_k": 0.0, + "top_p": 0.9, + "do_sample": True, + "pad_token_id": tokenizer.pad_token_id, + "max_new_tokens": 32, +} + +for epoch, batch in tqdm(enumerate(ppo_trainer.dataloader)): + question_tensors = batch["input_ids"] + + response_tensors = ppo_trainer.generate( + question_tensors, + return_prompt=False, + **generation_kwargs, + ) + batch["response"] = tokenizer.batch_decode(response_tensors, skip_special_tokens=True) + + # Compute reward score + texts = [q + r for q, r in zip(batch["query"], batch["response"])] + inputs = tokenizer(texts, padding=True, truncation=True, return_tensors="pt").to(ppo_trainer.accelerator.device) + raw_rewards = ppo_trainer.accelerator.unwrap_model(ppo_trainer.model).compute_reward_score(**inputs) + rewards = [raw_rewards[i, -1, 1] for i in range(len(raw_rewards))] # take last token + + # Run PPO step + stats = ppo_trainer.step(question_tensors, response_tensors, rewards) + ppo_trainer.log_stats(stats, batch, rewards) diff --git a/trl/trl/examples/scripts/reward_modeling.py b/trl/trl/examples/scripts/reward_modeling.py new file mode 100644 index 0000000000000000000000000000000000000000..1271bce4a026056215ce80ab5c703428fe5e69ff --- /dev/null +++ b/trl/trl/examples/scripts/reward_modeling.py @@ -0,0 +1,173 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from dataclasses import dataclass, field +from typing import Optional + +import tyro +from accelerate import Accelerator +from datasets import load_dataset +from peft import LoraConfig +from tqdm import tqdm +from transformers import AutoModelForSequenceClassification, AutoTokenizer, BitsAndBytesConfig + +from trl import RewardConfig, RewardTrainer, is_xpu_available + + +tqdm.pandas() + + +@dataclass +class ScriptArguments: + model_name: str = "facebook/opt-350m" + """the model name""" + dataset_name: str = "Anthropic/hh-rlhf" + """the dataset name""" + dataset_text_field: str = "text" + """the text field of the dataset""" + eval_split: str = "none" + """the dataset split to evaluate on; default to 'none' (no evaluation)""" + load_in_8bit: bool = False + """load the model in 8 bits precision""" + load_in_4bit: bool = False + """load the model in 4 bits precision""" + trust_remote_code: bool = True + """Enable `trust_remote_code`""" + reward_config: RewardConfig = field( + default_factory=lambda: RewardConfig( + output_dir="output", + per_device_train_batch_size=64, + num_train_epochs=1, + gradient_accumulation_steps=16, + gradient_checkpointing=True, + gradient_checkpointing_kwargs={"use_reentrant": False}, + learning_rate=1.41e-5, + report_to="tensorboard", + remove_unused_columns=False, + optim="adamw_torch", + logging_steps=500, + evaluation_strategy="no", + max_length=512, + ) + ) + use_peft: bool = False + """whether to use peft""" + peft_config: Optional[LoraConfig] = field( + default_factory=lambda: LoraConfig( + r=16, + lora_alpha=16, + bias="none", + task_type="SEQ_CLS", + modules_to_save=["scores"], + ), + ) + + +args = tyro.cli(ScriptArguments) +args.reward_config.evaluation_strategy = "steps" if args.eval_split != "none" else "no" + + +# Step 1: Load the model +if args.load_in_8bit and args.load_in_4bit: + raise ValueError("You can't load the model in 8 bits and 4 bits at the same time") +elif args.load_in_8bit or args.load_in_4bit: + quantization_config = BitsAndBytesConfig(load_in_8bit=args.load_in_8bit, load_in_4bit=args.load_in_4bit) + # Copy the model to each device + device_map = ( + {"": f"xpu:{Accelerator().local_process_index}"} + if is_xpu_available() + else {"": Accelerator().local_process_index} + ) +else: + device_map = None + quantization_config = None + +model = AutoModelForSequenceClassification.from_pretrained( + args.model_name, + quantization_config=quantization_config, + device_map=device_map, + trust_remote_code=args.trust_remote_code, + num_labels=1, +) + +# Step 2: Load the dataset and pre-process it +tokenizer = AutoTokenizer.from_pretrained(args.model_name) +train_dataset = load_dataset(args.dataset_name, split="train") + + +# Tokenize chosen/rejected pairs of inputs +# Adapt this section to your needs for custom datasets +def preprocess_function(examples): + new_examples = { + "input_ids_chosen": [], + "attention_mask_chosen": [], + "input_ids_rejected": [], + "attention_mask_rejected": [], + } + for chosen, rejected in zip(examples["chosen"], examples["rejected"]): + tokenized_chosen = tokenizer(chosen) + tokenized_rejected = tokenizer(rejected) + + new_examples["input_ids_chosen"].append(tokenized_chosen["input_ids"]) + new_examples["attention_mask_chosen"].append(tokenized_chosen["attention_mask"]) + new_examples["input_ids_rejected"].append(tokenized_rejected["input_ids"]) + new_examples["attention_mask_rejected"].append(tokenized_rejected["attention_mask"]) + + return new_examples + + +# Preprocess the dataset and filter out examples that are longer than args.max_length +train_dataset = train_dataset.map( + preprocess_function, + batched=True, + num_proc=4, +) +train_dataset = train_dataset.filter( + lambda x: len(x["input_ids_chosen"]) <= args.reward_config.max_length + and len(x["input_ids_rejected"]) <= args.reward_config.max_length +) + +if args.eval_split == "none": + eval_dataset = None +else: + eval_dataset = load_dataset(args.dataset_name, split=args.eval_split) + + eval_dataset = eval_dataset.map( + preprocess_function, + batched=True, + num_proc=4, + ) + eval_dataset = eval_dataset.filter( + lambda x: len(x["input_ids_chosen"]) <= args.reward_config.max_length + and len(x["input_ids_rejected"]) <= args.reward_config.max_length + ) + + +# Step 4: Define the LoraConfig +if args.use_peft: + peft_config = args.peft_config +else: + peft_config = None + +# Step 5: Define the Trainer +trainer = RewardTrainer( + model=model, + tokenizer=tokenizer, + args=args.reward_config, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + peft_config=peft_config, +) + +trainer.train() diff --git a/trl/trl/examples/scripts/sft.py b/trl/trl/examples/scripts/sft.py new file mode 100644 index 0000000000000000000000000000000000000000..1b5d620a3f3681a82ef273928dedd8b93455a111 --- /dev/null +++ b/trl/trl/examples/scripts/sft.py @@ -0,0 +1,158 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from dataclasses import dataclass, field +from typing import List, Optional + +import torch +from accelerate import Accelerator +from datasets import load_dataset +from peft import LoraConfig +from tqdm import tqdm +from transformers import AutoModelForCausalLM, BitsAndBytesConfig, HfArgumentParser, TrainingArguments + +from trl import SFTTrainer, is_xpu_available + + +tqdm.pandas() + + +# Define and parse arguments. +@dataclass +class ScriptArguments: + """ + The name of the Casual LM model we wish to fine with SFTTrainer + """ + + model_name: Optional[str] = field(default="facebook/opt-350m", metadata={"help": "the model name"}) + dataset_name: Optional[str] = field( + default="timdettmers/openassistant-guanaco", metadata={"help": "the dataset name"} + ) + dataset_text_field: Optional[str] = field(default="text", metadata={"help": "the text field of the dataset"}) + log_with: Optional[str] = field(default="none", metadata={"help": "use 'wandb' to log with wandb"}) + learning_rate: Optional[float] = field(default=1.41e-5, metadata={"help": "the learning rate"}) + batch_size: Optional[int] = field(default=64, metadata={"help": "the batch size"}) + seq_length: Optional[int] = field(default=512, metadata={"help": "Input sequence length"}) + gradient_accumulation_steps: Optional[int] = field( + default=16, metadata={"help": "the number of gradient accumulation steps"} + ) + load_in_8bit: Optional[bool] = field(default=False, metadata={"help": "load the model in 8 bits precision"}) + load_in_4bit: Optional[bool] = field(default=False, metadata={"help": "load the model in 4 bits precision"}) + use_peft: Optional[bool] = field(default=False, metadata={"help": "Wether to use PEFT or not to train adapters"}) + trust_remote_code: Optional[bool] = field(default=False, metadata={"help": "Enable `trust_remote_code`"}) + output_dir: Optional[str] = field(default="output", metadata={"help": "the output directory"}) + peft_lora_r: Optional[int] = field(default=64, metadata={"help": "the r parameter of the LoRA adapters"}) + peft_lora_alpha: Optional[int] = field(default=16, metadata={"help": "the alpha parameter of the LoRA adapters"}) + logging_steps: Optional[int] = field(default=1, metadata={"help": "the number of logging steps"}) + use_auth_token: Optional[bool] = field(default=True, metadata={"help": "Use HF auth token to access the model"}) + num_train_epochs: Optional[int] = field(default=3, metadata={"help": "the number of training epochs"}) + max_steps: Optional[int] = field(default=-1, metadata={"help": "the number of training steps"}) + save_steps: Optional[int] = field( + default=100, metadata={"help": "Number of updates steps before two checkpoint saves"} + ) + save_total_limit: Optional[int] = field(default=10, metadata={"help": "Limits total number of checkpoints."}) + push_to_hub: Optional[bool] = field(default=False, metadata={"help": "Push the model to HF Hub"}) + gradient_checkpointing: Optional[bool] = field( + default=False, metadata={"help": "Whether to use gradient checkpointing or no"} + ) + gradient_checkpointing_kwargs: Optional[dict] = field( + default=None, + metadata={ + "help": "key word arguments to be passed along `torch.utils.checkpoint.checkpoint` method - e.g. `use_reentrant=False`" + }, + ) + hub_model_id: Optional[str] = field(default=None, metadata={"help": "The name of the model on HF Hub"}) + mixed_precision: Optional[str] = field(default="bf16", metadata={"help": "Mixed precision training"}) + target_modules: Optional[List[str]] = field(default=None, metadata={"help": "Target modules for LoRA adapters"}) + + +parser = HfArgumentParser(ScriptArguments) +script_args = parser.parse_args_into_dataclasses()[0] + +# Step 1: Load the model +if script_args.load_in_8bit and script_args.load_in_4bit: + raise ValueError("You can't load the model in 8 bits and 4 bits at the same time") +elif script_args.load_in_8bit or script_args.load_in_4bit: + quantization_config = BitsAndBytesConfig( + load_in_8bit=script_args.load_in_8bit, load_in_4bit=script_args.load_in_4bit + ) + # Copy the model to each device + device_map = ( + {"": f"xpu:{Accelerator().local_process_index}"} + if is_xpu_available() + else {"": Accelerator().local_process_index} + ) + torch_dtype = torch.bfloat16 +else: + device_map = None + quantization_config = None + torch_dtype = None + +model = AutoModelForCausalLM.from_pretrained( + script_args.model_name, + quantization_config=quantization_config, + device_map=device_map, + trust_remote_code=script_args.trust_remote_code, + torch_dtype=torch_dtype, + use_auth_token=script_args.use_auth_token, +) + +# Step 2: Load the dataset +dataset = load_dataset(script_args.dataset_name, split="train") + +# Step 3: Define the training arguments +training_args = TrainingArguments( + output_dir=script_args.output_dir, + per_device_train_batch_size=script_args.batch_size, + gradient_accumulation_steps=script_args.gradient_accumulation_steps, + learning_rate=script_args.learning_rate, + logging_steps=script_args.logging_steps, + num_train_epochs=script_args.num_train_epochs, + max_steps=script_args.max_steps, + report_to=script_args.log_with, + save_steps=script_args.save_steps, + save_total_limit=script_args.save_total_limit, + push_to_hub=script_args.push_to_hub, + hub_model_id=script_args.hub_model_id, + gradient_checkpointing=script_args.gradient_checkpointing, + # TODO: uncomment that on the next release + # gradient_checkpointing_kwargs=script_args.gradient_checkpointing_kwargs, +) + +# Step 4: Define the LoraConfig +if script_args.use_peft: + peft_config = LoraConfig( + r=script_args.peft_lora_r, + lora_alpha=script_args.peft_lora_alpha, + bias="none", + task_type="CAUSAL_LM", + target_modules=script_args.target_modules, + ) +else: + peft_config = None + +# Step 5: Define the Trainer +trainer = SFTTrainer( + model=model, + args=training_args, + max_seq_length=script_args.seq_length, + train_dataset=dataset, + dataset_text_field=script_args.dataset_text_field, + peft_config=peft_config, +) + +trainer.train() + +# Step 6: Save the model +trainer.save_model(script_args.output_dir) diff --git a/trl/trl/pyproject.toml b/trl/trl/pyproject.toml new file mode 100644 index 0000000000000000000000000000000000000000..7e6b3f84fae69610b44b44cba277e2201bcfb555 --- /dev/null +++ b/trl/trl/pyproject.toml @@ -0,0 +1,16 @@ +[tool.black] +line-length = 119 +target-version = ['py38'] + +[tool.ruff] +ignore = ["E501", "E741", "W605"] +select = ["E", "F", "I", "W"] +line-length = 119 + +# Ignore import violations in all `__init__.py` files. +[tool.ruff.per-file-ignores] +"__init__.py" = ["E402", "F401", "F403", "F811"] + +[tool.ruff.isort] +lines-after-imports = 2 +known-first-party = ["trl"] diff --git a/trl/trl/requirements.txt b/trl/trl/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..12d8a5845c755445298b7f98c056bd02a2d61672 --- /dev/null +++ b/trl/trl/requirements.txt @@ -0,0 +1,7 @@ +datasets>=1.17.0 +torch>=1.4.0 +tqdm +transformers +accelerate +peft>=0.3.0 +tyro>=0.5.7 \ No newline at end of file diff --git a/trl/trl/scripts/stale.py b/trl/trl/scripts/stale.py new file mode 100644 index 0000000000000000000000000000000000000000..de7b869c13280cea71507cbe1e635e25c3f36c5b --- /dev/null +++ b/trl/trl/scripts/stale.py @@ -0,0 +1,61 @@ +# Copyright 2023 The HuggingFace Team, the AllenNLP library authors. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Script to close stale issue. Taken in part from the AllenNLP repository. +https://github.com/allenai/allennlp. +""" +import os +from datetime import datetime as dt +from datetime import timezone + +from github import Github + + +LABELS_TO_EXEMPT = [ + "good first issue", + "good second issue", + "feature request", +] + + +def main(): + g = Github(os.environ["GITHUB_TOKEN"]) + repo = g.get_repo("huggingface/trl") + open_issues = repo.get_issues(state="open") + + for issue in open_issues: + comments = sorted([comment for comment in issue.get_comments()], key=lambda i: i.created_at, reverse=True) + last_comment = comments[0] if len(comments) > 0 else None + if ( + last_comment is not None + and last_comment.user.login == "github-actions[bot]" + and (dt.now(timezone.utc) - issue.updated_at).days > 7 + and (dt.now(timezone.utc) - issue.created_at).days >= 30 + and not any(label.name.lower() in LABELS_TO_EXEMPT for label in issue.get_labels()) + ): + issue.edit(state="closed") + elif ( + (dt.now(timezone.utc) - issue.updated_at).days > 23 + and (dt.now(timezone.utc) - issue.created_at).days >= 30 + and not any(label.name.lower() in LABELS_TO_EXEMPT for label in issue.get_labels()) + ): + issue.create_comment( + "This issue has been automatically marked as stale because it has not had " + "recent activity. If you think this still needs to be addressed " + "please comment on this thread.\n\n" + ) + + +if __name__ == "__main__": + main() diff --git a/trl/trl/setup.cfg b/trl/trl/setup.cfg new file mode 100644 index 0000000000000000000000000000000000000000..cb69438f5607211e0e1002d3fc9f9f7479b2b998 --- /dev/null +++ b/trl/trl/setup.cfg @@ -0,0 +1,11 @@ +[metadata] +license_file = LICENSE + +[isort] +ensure_newline_before_comments = True +force_grid_wrap = 0 +include_trailing_comma = True +line_length = 119 +lines_after_imports = 2 +multi_line_output = 3 +use_parentheses = True diff --git a/trl/trl/setup.py b/trl/trl/setup.py new file mode 100644 index 0000000000000000000000000000000000000000..1b2e1cb14c3619017a0b00fa15b609cf44533fc6 --- /dev/null +++ b/trl/trl/setup.py @@ -0,0 +1,112 @@ +""" trl is an open library for RL with transformer models. + +Note: + + VERSION needs to be formatted following the MAJOR.MINOR.PATCH convention + (we need to follow this convention to be able to retrieve versioned scripts) + +Simple check list for release from AllenNLP repo: https://github.com/allenai/allennlp/blob/master/setup.py + +To create the package for pypi. + +0. Prerequisites: + - Dependencies: + - twine: "pip install twine" + - Create an account in (and join the 'trl' project): + - PyPI: https://pypi.org/ + - Test PyPI: https://test.pypi.org/ + +1. Change the version in: + - __init__.py + - setup.py + +2. Commit these changes: "git commit -m 'Release: VERSION'" + +3. Add a tag in git to mark the release: "git tag VERSION -m 'Add tag VERSION for pypi'" + Push the tag to remote: git push --tags origin main + +4. Build both the sources and the wheel. Do not change anything in setup.py between + creating the wheel and the source distribution (obviously). + + First, delete any "build" directory that may exist from previous builds. + + For the wheel, run: "python setup.py bdist_wheel" in the top level directory. + (this will build a wheel for the python version you use to build it). + + For the sources, run: "python setup.py sdist" + You should now have a /dist directory with both .whl and .tar.gz source versions. + +5. Check that everything looks correct by uploading the package to the pypi test server: + + twine upload dist/* -r pypitest --repository-url=https://test.pypi.org/legacy/ + + Check that you can install it in a virtualenv/notebook by running: + pip install huggingface_hub fsspec aiohttp + pip install -U tqdm + pip install -i https://testpypi.python.org/pypi evaluate + +6. Upload the final version to actual pypi: + twine upload dist/* -r pypi + +7. Fill release notes in the tag in github once everything is looking hunky-dory. + +8. Change the version in __init__.py and setup.py to X.X.X+1.dev0 (e.g. VERSION=1.18.3 -> 1.18.4.dev0). + Then push the change with a message 'set dev version' +""" + +from setuptools import find_packages, setup + + +__version__ = "0.7.5.dev0" # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots) + +REQUIRED_PKGS = [ + "torch>=1.4.0", + "transformers>=4.18.0", + "numpy>=1.18.2", + "accelerate", + "datasets", + "tyro>=0.5.11", +] +EXTRAS = { + "test": ["parameterized", "pytest", "pytest-xdist", "accelerate"], + "peft": ["peft>=0.4.0"], + "diffusers": ["diffusers>=0.18.0"], + "deepspeed": ["deepspeed>=0.9.5"], + "benchmark": ["wandb", "ghapi", "openrlbenchmark==0.2.1a5", "requests", "deepspeed"], + "quantization": ["bitsandbytes<=0.41.1"], +} +EXTRAS["dev"] = [] +for reqs in EXTRAS.values(): + EXTRAS["dev"].extend(reqs) + +setup( + name="trl", + license="Apache 2.0", + classifiers=[ + "Development Status :: 2 - Pre-Alpha", + "Intended Audience :: Developers", + "Intended Audience :: Science/Research", + "License :: OSI Approved :: Apache Software License", + "Natural Language :: English", + "Operating System :: OS Independent", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.7", + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + ], + url="https://github.com/huggingface/trl", + packages=find_packages(), + include_package_data=True, + install_requires=REQUIRED_PKGS, + extras_require=EXTRAS, + python_requires=">=3.7", + long_description=open("README.md", encoding="utf-8").read(), + long_description_content_type="text/markdown", + zip_safe=False, + version=__version__, + description="A Pytorch implementation of Proximal Policy Optimization for transfomer language models.", + keywords="ppo, transformers, huggingface, gpt2, language modeling, rlhf", + author="Leandro von Werra", + author_email="leandro.vonwerra@gmail.com", +) diff --git a/trl/trl/tests/__init__.py b/trl/trl/tests/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/trl/trl/tests/test_best_of_n_sampler.py b/trl/trl/tests/test_best_of_n_sampler.py new file mode 100644 index 0000000000000000000000000000000000000000..3b5001a898aefd7a3b389ea27ba9bc9a45e4d770 --- /dev/null +++ b/trl/trl/tests/test_best_of_n_sampler.py @@ -0,0 +1,98 @@ +import unittest + +import torch +from transformers import AutoTokenizer, GenerationConfig + +from trl import AutoModelForCausalLMWithValueHead +from trl.core import LengthSampler +from trl.extras import BestOfNSampler + + +def queries_to_scores(list_of_strings): + return [torch.rand(1).item() for _ in list_of_strings] + + +class BestOfNSamplerTester(unittest.TestCase): + """ + Tests the BestOfNSampler class + """ + + ref_model_name = "trl-internal-testing/dummy-GPT2-correct-vocab" + output_length_sampler = LengthSampler(2, 6) + model = AutoModelForCausalLMWithValueHead.from_pretrained(ref_model_name) + tokenizer = AutoTokenizer.from_pretrained(ref_model_name) + tokenizer.pad_token = tokenizer.eos_token + output_length_sampler = LengthSampler(2, 6) + + def test_different_input_types(self): + r""" + Tests if the different input types normalizer works + """ + + generation_config = GenerationConfig( + min_length=-1, + top_k=0.0, + top_p=1.0, + do_sample=True, + pad_token_id=self.tokenizer.eos_token_id, + ) + + output_length_sampler = LengthSampler(2, 6) + + best_of_n = BestOfNSampler( + self.model, + self.tokenizer, + queries_to_scores, + length_sampler=output_length_sampler, + generation_config=generation_config, + ) + + queries = ["hello world", "goodbye world"] + tokenized_queries = [self.tokenizer.encode(query) for query in queries] + + various_queries_formats = [ + (tokenized_queries[0], 1), + (tokenized_queries, 2), + (torch.tensor(tokenized_queries[1]), 1), + ([torch.tensor(query) for query in tokenized_queries], 2), + ] + + for q, expected_length in various_queries_formats: + results = best_of_n.generate(q) + self.assertIsInstance(results, list) + assert len(results) == expected_length + + def test_different_sample_sizes_and_n_candidates_values(self): + r""" + Tests different sample sizes and n_candidates values + """ + generation_config = GenerationConfig( + min_length=-1, + top_k=0.0, + top_p=1.0, + do_sample=True, + pad_token_id=self.tokenizer.eos_token_id, + ) + + output_length_sampler = LengthSampler(6, 10) + + for sample_value, n_candidates_values, expected in [ + (4, 2, 2), + (10, 3, 3), + (6, 4, 4), + ]: + best_of_n = BestOfNSampler( + self.model, + self.tokenizer, + queries_to_scores, + length_sampler=output_length_sampler, + generation_config=generation_config, + sample_size=sample_value, + n_candidates=n_candidates_values, + ) + + queries = ["hello world", "troll the world"] + tokenized_queries = [self.tokenizer.encode(query) for query in queries] + results = best_of_n.generate(tokenized_queries) + for result in results: + assert len(result) == expected diff --git a/trl/trl/tests/test_core.py b/trl/trl/tests/test_core.py new file mode 100644 index 0000000000000000000000000000000000000000..151852e267b0c2ef6557901a4269c8ab774aafba --- /dev/null +++ b/trl/trl/tests/test_core.py @@ -0,0 +1,42 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import unittest + +import torch + +from trl.core import masked_mean, masked_var, masked_whiten, whiten + + +class CoreTester(unittest.TestCase): + """ + A wrapper class for testing core utils functions + """ + + @classmethod + def setUpClass(cls): + cls.test_input = torch.Tensor([1, 2, 3, 4]) + cls.test_mask = torch.Tensor([0, 1, 1, 0]) + cls.test_input_unmasked = cls.test_input[1:3] + + def test_masked_mean(self): + self.assertEqual(torch.mean(self.test_input_unmasked), masked_mean(self.test_input, self.test_mask)) + + def test_masked_var(self): + self.assertEqual(torch.var(self.test_input_unmasked), masked_var(self.test_input, self.test_mask)) + + def test_masked_whiten(self): + whiten_unmasked = whiten(self.test_input_unmasked) + whiten_masked = masked_whiten(self.test_input, self.test_mask)[1:3] + diffs = (whiten_unmasked - whiten_masked).sum() + self.assertAlmostEqual(diffs, 0) diff --git a/trl/trl/tests/test_data_collator_completion_only.py b/trl/trl/tests/test_data_collator_completion_only.py new file mode 100644 index 0000000000000000000000000000000000000000..c895a616e136c211493e6e042221691b0e248261 --- /dev/null +++ b/trl/trl/tests/test_data_collator_completion_only.py @@ -0,0 +1,81 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import unittest + +import torch +from transformers import AutoTokenizer + +from trl import DataCollatorForCompletionOnlyLM + + +class DataCollatorForCompletionOnlyLMTester(unittest.TestCase): + def test_data_collator_finds_response_template_llama2_tokenizer(self): + # this should ideally be tested with meta-llama/Llama-2-7b-hf + self.tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/dummy-GPT2-correct-vocab") + self.instruction = """### System: You are a helpful assistant. + +### User: How much is 2+2? + +### Assistant: 2+2 equals 4""" + self.instruction_template = "\n### User:" + self.response_template = "\n### Assistant:" + + # GPT2Tokenizer: [198, 21017, 11787, 25] -> [11787, 25] + # Llama2Tokenizer: [29871, 13, 2277, 29937, 4911, 29901] -> [2277, 29937, 4911, 29901] + self.tokenized_instruction_w_context = self.tokenizer.encode( + self.instruction_template, add_special_tokens=False + )[2:] + + # GPT2Tokenizer: [198, 21017, 15286, 25] -> [15286, 25] + # Llama2Tokenizer: [29871, 13, 2277, 29937, 4007, 22137, 29901] -> [2277, 29937, 4007, 22137, 29901] + self.tokenized_response_w_context = self.tokenizer.encode(self.response_template, add_special_tokens=False)[2:] + + # Plain check on string + self.assertIn(self.response_template, self.instruction) + self.tokenized_instruction = self.tokenizer.encode(self.instruction, add_special_tokens=False) + + # Test the fix for #598 + # Pass already tokenized (w context) and truncated response_template so token_ids are like in the instruction + response + self.collator = DataCollatorForCompletionOnlyLM(self.tokenized_response_w_context, tokenizer=self.tokenizer) + self.collator.torch_call([self.tokenized_instruction]) + + # Test for PR #749 + # Pass already tokenized (w context) instruction and response both so token_ids are like in the instruction + response + self.collator = DataCollatorForCompletionOnlyLM( + self.tokenized_response_w_context, self.tokenized_instruction_w_context, tokenizer=self.tokenizer + ) + self.collator.torch_call([self.tokenized_instruction]) + + def test_data_collator_handling_of_long_sequences(self): + self.tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/dummy-GPT2-correct-vocab") + self.instruction = """### System: You are a helpful assistant. + +### User: How much is 2+2? I'm asking because I'm not sure. And I'm not sure because I'm not good at math. +""" + self.response_template = "\n### Assistant:" + # check DataCollatorForCompletionOnlyLM using response template only + self.tokenized_instruction = self.tokenizer.encode(self.instruction, add_special_tokens=False) + self.collator = DataCollatorForCompletionOnlyLM(self.response_template, tokenizer=self.tokenizer) + encoded_instance = self.collator.torch_call([self.tokenized_instruction]) + result = torch.all(encoded_instance["labels"] == -100) + self.assertTrue(result, "Not all values in the tensor are -100.") + + # check DataCollatorForCompletionOnlyLM using response template and instruction template + self.instruction_template = "\n### User:" + self.collator = DataCollatorForCompletionOnlyLM( + self.response_template, self.instruction_template, tokenizer=self.tokenizer + ) + encoded_instance = self.collator.torch_call([self.tokenized_instruction]) + result = torch.all(encoded_instance["labels"] == -100) + self.assertTrue(result, "Not all values in the tensor are -100.") diff --git a/trl/trl/tests/test_ddpo_trainer.py b/trl/trl/tests/test_ddpo_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..e56ab16b571c5aee7e99b985063fa227a380ebd1 --- /dev/null +++ b/trl/trl/tests/test_ddpo_trainer.py @@ -0,0 +1,99 @@ +# Copyright 2023 metric-space, The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import gc +import unittest + +import torch + +from trl import is_diffusers_available + +from .testing_utils import require_diffusers + + +if is_diffusers_available(): + from trl import DDPOConfig, DDPOTrainer, DefaultDDPOStableDiffusionPipeline + + +def scorer_function(images, prompts, metadata): + return torch.randn(1) * 3.0, {} + + +def prompt_function(): + return ("cabbages", {}) + + +@require_diffusers +class DDPOTrainerTester(unittest.TestCase): + """ + Test the DDPOTrainer class. + """ + + def setUp(self): + self.ddpo_config = DDPOConfig( + num_epochs=2, + train_gradient_accumulation_steps=1, + per_prompt_stat_tracking_buffer_size=32, + sample_num_batches_per_epoch=2, + sample_batch_size=2, + mixed_precision=None, + save_freq=1000000, + ) + pretrained_model = "hf-internal-testing/tiny-stable-diffusion-torch" + pretrained_revision = "main" + + pipeline = DefaultDDPOStableDiffusionPipeline( + pretrained_model, pretrained_model_revision=pretrained_revision, use_lora=False + ) + + self.trainer = DDPOTrainer(self.ddpo_config, scorer_function, prompt_function, pipeline) + + return super().setUp() + + def tearDown(self) -> None: + gc.collect() + + def test_loss(self): + advantage = torch.tensor([-1.0]) + clip_range = 0.0001 + ratio = torch.tensor([1.0]) + loss = self.trainer.loss(advantage, clip_range, ratio) + self.assertEqual(loss.item(), 1.0) + + def test_generate_samples(self): + samples, output_pairs = self.trainer._generate_samples(1, 2) + self.assertEqual(len(samples), 1) + self.assertEqual(len(output_pairs), 1) + self.assertEqual(len(output_pairs[0][0]), 2) + + def test_calculate_loss(self): + samples, _ = self.trainer._generate_samples(1, 2) + sample = samples[0] + + latents = sample["latents"][0, 0].unsqueeze(0) + next_latents = sample["next_latents"][0, 0].unsqueeze(0) + log_probs = sample["log_probs"][0, 0].unsqueeze(0) + timesteps = sample["timesteps"][0, 0].unsqueeze(0) + prompt_embeds = sample["prompt_embeds"] + advantage = torch.tensor([1.0], device=prompt_embeds.device) + + self.assertEqual(latents.shape, (1, 4, 64, 64)) + self.assertEqual(next_latents.shape, (1, 4, 64, 64)) + self.assertEqual(log_probs.shape, (1,)) + self.assertEqual(timesteps.shape, (1,)) + self.assertEqual(prompt_embeds.shape, (2, 77, 32)) + loss, approx_kl, clipfrac = self.trainer.calculate_loss( + latents, timesteps, next_latents, log_probs, advantage, prompt_embeds + ) + + self.assertTrue(torch.isfinite(loss.cpu())) diff --git a/trl/trl/tests/test_dpo_trainer.py b/trl/trl/tests/test_dpo_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..ef628497a31f90dd910a2aabc61369ce38b9f500 --- /dev/null +++ b/trl/trl/tests/test_dpo_trainer.py @@ -0,0 +1,298 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import tempfile +import unittest + +import torch +from datasets import Dataset +from parameterized import parameterized +from pytest import mark +from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM, AutoTokenizer, TrainingArguments + +from trl import DPOTrainer + +from .testing_utils import require_no_wandb, require_peft + + +class DPOTrainerTester(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.model_id = "trl-internal-testing/dummy-GPT2-correct-vocab" + cls.model = AutoModelForCausalLM.from_pretrained(cls.model_id) + cls.ref_model = AutoModelForCausalLM.from_pretrained(cls.model_id) + cls.tokenizer = AutoTokenizer.from_pretrained(cls.model_id) + cls.tokenizer.pad_token = cls.tokenizer.eos_token + + # get t5 as seq2seq example: + model_id = "trl-internal-testing/tiny-T5ForConditionalGeneration-correct-vocab" + cls.t5_model = AutoModelForSeq2SeqLM.from_pretrained(model_id) + cls.t5_ref_model = AutoModelForSeq2SeqLM.from_pretrained(model_id) + cls.t5_tokenizer = AutoTokenizer.from_pretrained(model_id) + + def _init_dummy_dataset(self): + # fmt: off + dummy_dataset_dict = { + "prompt": [ + "hello", + "how are you", + "What is your name?", + "What is your name?", + "Which is the best programming language?", + "Which is the best programming language?", + "Which is the best programming language?", + ], + "chosen": [ + "hi nice to meet you", + "I am fine", + "My name is Mary", + "My name is Mary", + "Python", + "Python", + "Python", + ], + "rejected": [ + "leave me alone", + "I am not fine", + "Whats it to you?", + "I dont have a name", + "Javascript", + "C++", + "Java", + ], + } + # fmt: on + return Dataset.from_dict(dummy_dataset_dict) + + @parameterized.expand( + [["gpt2", "sigmoid"], ["t5", "hinge"], ["gpt2", "ipo"], ["t5", "ipo"], ["gpt2", "kto"], ["t5", "kto"]] + ) + def test_dpo_trainer(self, name, loss_type): + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = TrainingArguments( + output_dir=tmp_dir, + per_device_train_batch_size=2, + max_steps=3, + remove_unused_columns=False, + gradient_accumulation_steps=1, + learning_rate=9e-1, + evaluation_strategy="steps", + ) + + dummy_dataset = self._init_dummy_dataset() + + if name == "gpt2": + model = self.model + ref_model = self.ref_model + tokenizer = self.tokenizer + elif name == "t5": + model = self.t5_model + ref_model = self.t5_ref_model + tokenizer = self.t5_tokenizer + + trainer = DPOTrainer( + model=model, + ref_model=ref_model, + beta=0.1, + loss_type=loss_type, + args=training_args, + tokenizer=tokenizer, + train_dataset=dummy_dataset, + eval_dataset=dummy_dataset, + ) + + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + trainer.train() + + self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + + # check the params have changed + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + # check the params have changed - ignore 0 biases + if param.sum() != 0: + self.assertFalse(torch.equal(param, new_param)) + + def test_dpo_trainer_without_providing_ref_model(self): + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = TrainingArguments( + output_dir=tmp_dir, + per_device_train_batch_size=2, + max_steps=3, + remove_unused_columns=False, + gradient_accumulation_steps=4, + learning_rate=9e-1, + evaluation_strategy="steps", + ) + + dummy_dataset = self._init_dummy_dataset() + + trainer = DPOTrainer( + model=self.model, + ref_model=None, + beta=0.1, + args=training_args, + tokenizer=self.tokenizer, + train_dataset=dummy_dataset, + eval_dataset=dummy_dataset, + ) + + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + trainer.train() + + self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + + # check the params have changed + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + # check the params have changed - ignore 0 biases + if param.sum() != 0: + self.assertFalse(torch.equal(param, new_param)) + + @require_peft + @mark.peft_test + def test_dpo_trainer_without_providing_ref_model_with_lora(self): + from peft import LoraConfig + + lora_config = LoraConfig( + r=16, + lora_alpha=32, + lora_dropout=0.05, + bias="none", + task_type="CAUSAL_LM", + ) + + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = TrainingArguments( + output_dir=tmp_dir, + per_device_train_batch_size=2, + max_steps=3, + remove_unused_columns=False, + gradient_accumulation_steps=4, + learning_rate=9e-1, + evaluation_strategy="steps", + ) + + dummy_dataset = self._init_dummy_dataset() + + trainer = DPOTrainer( + model=self.model, + ref_model=None, + beta=0.1, + args=training_args, + tokenizer=self.tokenizer, + train_dataset=dummy_dataset, + eval_dataset=dummy_dataset, + peft_config=lora_config, + ) + + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + trainer.train() + + self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + + # check the params have changed + for n, param in previous_trainable_params.items(): + if "lora" in n: + new_param = trainer.model.get_parameter(n) + # check the params have changed - ignore 0 biases + if param.sum() != 0: + self.assertFalse(torch.equal(param, new_param)) + + @require_no_wandb + def test_dpo_trainer_generate_during_eval_no_wandb(self): + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = TrainingArguments( + output_dir=tmp_dir, + per_device_train_batch_size=2, + max_steps=3, + remove_unused_columns=False, + gradient_accumulation_steps=1, + learning_rate=9e-1, + evaluation_strategy="steps", + ) + + dummy_dataset = self._init_dummy_dataset() + + with self.assertRaisesRegex( + ValueError, + expected_regex="`generate_during_eval=True` requires Weights and Biases to be installed." + " Please install `wandb` to resolve.", + ): + DPOTrainer( + model=self.model, + ref_model=None, + beta=0.1, + args=training_args, + tokenizer=self.tokenizer, + train_dataset=dummy_dataset, + eval_dataset=dummy_dataset, + generate_during_eval=True, + ) + + @require_peft + @mark.peft_test + def test_dpo_lora_save(self): + from peft import LoraConfig, get_peft_model + + lora_config = LoraConfig( + r=16, + lora_alpha=32, + lora_dropout=0.05, + bias="none", + task_type="CAUSAL_LM", + ) + + # lora model + model = AutoModelForCausalLM.from_pretrained(self.model_id) + model_peft = get_peft_model(model, lora_config) + + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = TrainingArguments( + output_dir=tmp_dir, + per_device_train_batch_size=2, + max_steps=3, + remove_unused_columns=False, + gradient_accumulation_steps=4, + learning_rate=9e-1, + evaluation_strategy="steps", + ) + + dummy_dataset = self._init_dummy_dataset() + + # dpo train lora model with a lora config + trainer = DPOTrainer( + model=model_peft, + ref_model=None, + beta=0.1, + args=training_args, + tokenizer=self.tokenizer, + train_dataset=dummy_dataset, + eval_dataset=dummy_dataset, + peft_config=lora_config, + ) + + # train the model + trainer.train() + + # save peft adapter + trainer.save_model() + + # assert that the model is loaded without giving OSError + try: + AutoModelForCausalLM.from_pretrained(tmp_dir) + except OSError: + self.fail("Loading the saved peft adapter failed") diff --git a/trl/trl/tests/test_e2e.py b/trl/trl/tests/test_e2e.py new file mode 100644 index 0000000000000000000000000000000000000000..7e742329dee0da2fd7f80fee5568f3d808dff6fd --- /dev/null +++ b/trl/trl/tests/test_e2e.py @@ -0,0 +1,9 @@ +import subprocess + + +def test_hello_world(): + subprocess.run( + "python examples/hello_world.py", + shell=True, + check=True, + ) diff --git a/trl/trl/tests/test_environments.py b/trl/trl/tests/test_environments.py new file mode 100644 index 0000000000000000000000000000000000000000..e31daab5cebee9fe2e13fc85d4deb63503c8c01d --- /dev/null +++ b/trl/trl/tests/test_environments.py @@ -0,0 +1,273 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +from unittest.mock import patch + +import torch +from transformers import AutoTokenizer + +from trl import AutoModelForCausalLMWithValueHead, TextEnvironment, TextHistory + + +class DummyTool: + def __call__(self, text): + return text + + +def dummy_generate(histories): + for i in range(len(histories)): + histories[i].append_segment("test", torch.tensor([1, 2, 3]), system=False) + return histories + + +class TextHistoryTest(unittest.TestCase): + def test_text_history_init(self): + text = "Hello there!" + tokens = torch.tensor([1, 2, 3]) + + history = TextHistory(text, tokens) + self.assertEqual(history.text, text) + self.assertTrue(torch.equal(history.tokens, tokens)) + self.assertTrue(torch.equal(history.token_masks, torch.zeros_like(tokens))) + + history = TextHistory(text, tokens, system=False) + self.assertTrue(torch.equal(history.token_masks, torch.ones_like(tokens))) + + def test_text_history_append_segment(self): + text = "Hello there!" + tokens = torch.tensor([1, 2, 3]) + + history = TextHistory(text, tokens) + history.append_segment("General Kenobi!", torch.tensor([4, 5, 6]), system=False) + self.assertEqual(history.text, text + "General Kenobi!") + self.assertTrue(torch.equal(history.tokens, torch.tensor([1, 2, 3, 4, 5, 6]))) + self.assertTrue(torch.equal(history.token_masks, torch.tensor([0, 0, 0, 1, 1, 1]))) + + history.append_segment("You are a bold one!", torch.tensor([7, 8, 9])) + self.assertEqual(history.text, text + "General Kenobi!" + "You are a bold one!") + self.assertTrue(torch.equal(history.tokens, torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9]))) + self.assertTrue(torch.equal(history.token_masks, torch.tensor([0, 0, 0, 1, 1, 1, 0, 0, 0]))) + + def test_text_history_complete(self): + text = "Hello there!" + tokens = torch.tensor([1, 2, 3]) + history = TextHistory(text, tokens) + history.complete() + self.assertTrue(history.completed) + self.assertFalse(history.truncated) + + history.complete(truncated=True) + self.assertTrue(history.completed) + self.assertTrue(history.truncated) + + def test_text_history_last_segment(self): + text = "Hello there!" + tokens = torch.tensor([1, 2, 3]) + history = TextHistory(text, tokens) + history.append_segment("General Kenobi!", torch.tensor([4, 5, 6])) + history.append_segment("You are a bold one!", torch.tensor([7, 8, 9])) + self.assertEqual(history.last_text_segment, "You are a bold one!") + + def test_text_history_split_query_response(self): + text = "Hello there!" + tokens = torch.tensor([1, 2, 3]) + history = TextHistory(text, tokens) + history.append_segment("General Kenobi!", torch.tensor([4, 5, 6]), system=False) + history.append_segment("You are a bold one!", torch.tensor([7, 8, 9]), system=True) + query, response, mask = history.split_query_response_tokens() + + self.assertTrue(torch.equal(query, torch.tensor([1, 2, 3]))) + self.assertTrue(torch.equal(response, torch.tensor([4, 5, 6, 7, 8, 9]))) + self.assertTrue(torch.equal(mask, torch.tensor([1, 1, 1, 0, 0, 0]))) + + +class TextEnvironmentTester(unittest.TestCase): + @classmethod + def setUpClass(cls): + # model_id + cls.model_id = "trl-internal-testing/dummy-GPT2-correct-vocab" + + # get models and tokenizer + cls.gpt2_model = AutoModelForCausalLMWithValueHead.from_pretrained(cls.model_id) + cls.gpt2_tokenizer = AutoTokenizer.from_pretrained(cls.model_id) + cls.gpt2_tokenizer.pad_token = cls.gpt2_tokenizer.eos_token + + def test_text_environment_setup(self): + env = TextEnvironment( + self.gpt2_model, + self.gpt2_tokenizer, + tools=[DummyTool()], + reward_fn=lambda x: torch.tensor(1), + prompt="I am a prompt!\n", + ) + self.assertEqual(env.prompt, "I am a prompt!\n") + self.assertEqual(list(env.tools.keys()), ["DummyTool"]) + self.assertTrue(isinstance(env.tools["DummyTool"], DummyTool)) + self.assertEqual(env.reward_fn("Hello there!"), 1) + + def test_text_environment_generate(self): + generation_kwargs = {"do_sample": False, "max_new_tokens": 4, "pad_token_id": self.gpt2_tokenizer.eos_token_id} + env = TextEnvironment( + self.gpt2_model, + self.gpt2_tokenizer, + tools=[DummyTool()], + reward_fn=lambda x: torch.tensor(1), + prompt="I am a prompt!\n", + generation_kwargs=generation_kwargs, + ) + + input_texts = ["this is a test", "this is another, longer test"] + + model_inputs = [self.gpt2_tokenizer(txt, return_tensors="pt").input_ids.squeeze() for txt in input_texts] + + generations_batched = env._generate_batched(model_inputs, batch_size=2) + generations_batched = self.gpt2_tokenizer.batch_decode(generations_batched) + + generations_single = [env._generate_batched([inputs], batch_size=1)[0] for inputs in model_inputs] + generations_single = self.gpt2_tokenizer.batch_decode(generations_single) + + self.assertEqual(generations_single, generations_batched) + + def test_text_environment_tool_call_parsing(self): + string_valid = "Something something Hello there!" + string_invalid_request = "Something something Hello there!" + string_invalid_call = "Something something Hello there!" + string_invalid_tool = "Something something |Tool2|Hello there!" + string_invalid_random = "<>abcdefghijklm<>nopqrstuvwxyz<>" + + env = TextEnvironment( + self.gpt2_model, + self.gpt2_tokenizer, + tools=[DummyTool()], + reward_fn=lambda x: torch.tensor(1), + prompt="I am a prompt!\n", + ) + tool, response = env.parse_tool_call(string_valid) + self.assertEqual(tool, "Tool1") + self.assertEqual(response, "Hello there!") + + tool, response = env.parse_tool_call(string_invalid_request) + self.assertEqual(tool, None) + self.assertEqual(response, None) + + tool, response = env.parse_tool_call(string_invalid_call) + self.assertEqual(tool, None) + self.assertEqual(response, None) + + tool, response = env.parse_tool_call(string_invalid_tool) + self.assertEqual(tool, None) + self.assertEqual(response, None) + + tool, response = env.parse_tool_call(string_invalid_random) + self.assertEqual(tool, None) + self.assertEqual(response, None) + + def test_text_environment_tool_truncation(self): + env = TextEnvironment( + self.gpt2_model, + self.gpt2_tokenizer, + tools={"dummy": lambda x: "a" * 1000}, + reward_fn=lambda x: torch.tensor(1), + prompt="I am a prompt!\n", + ) + + env.max_tool_response = 100 + history = env.step(TextHistory("Hello there!", torch.tensor([1, 2, 3]))) + self.assertEqual(len(history.last_text_segment) - len(env.response_token), 100) + + env.max_tool_response = 500 + history = env.step(TextHistory("Hello there!", torch.tensor([1, 2, 3]))) + self.assertEqual(len(history.last_text_segment) - len(env.response_token), 500) + + env.max_tool_response = 1001 + history = env.step(TextHistory("Hello there!", torch.tensor([1, 2, 3]))) + self.assertEqual(len(history.last_text_segment) - len(env.response_token), 1000) + + env.max_tool_response = 2000 + history = env.step(TextHistory("Hello there!", torch.tensor([1, 2, 3]))) + self.assertEqual(len(history.last_text_segment) - len(env.response_token), 1000) + + @patch.object(TextEnvironment, "generate", side_effect=dummy_generate) + def test_text_environment_max_calls(self, mock_generate): + env = TextEnvironment( + self.gpt2_model, + self.gpt2_tokenizer, + tools={"DummyTool": DummyTool()}, + reward_fn=lambda x: [torch.tensor(1) for _ in x], + prompt="I am a prompt!\n", + ) + + env.max_turns = 1 + _, _, _, _, histories = env.run(["test"]) + self.assertEqual( + histories[0].text, "I am a prompt!\n" + "test" + 1 * "testtest" + ) + + env.max_turns = 2 + _, _, _, _, histories = env.run(["test"]) + self.assertEqual( + histories[0].text, "I am a prompt!\n" + "test" + 2 * "testtest" + ) + + env.max_turns = 4 + _, _, _, _, histories = env.run(["test"]) + self.assertEqual( + histories[0].text, "I am a prompt!\n" + "test" + 4 * "testtest" + ) + + def test_text_environment_compute_rewards(self): + env = TextEnvironment( + self.gpt2_model, + self.gpt2_tokenizer, + tools={"DummyTool": DummyTool()}, + reward_fn=lambda x: [torch.tensor(i) for i, _ in enumerate(x)], + prompt="I am a prompt!\n", + ) + + histories = [TextHistory("test", torch.tensor([1, 2, 3])) for _ in range(8)] + histories = env.compute_reward(histories) + + for i in range(8): + self.assertEqual(histories[i].reward, i) + + @patch.object(TextEnvironment, "generate", side_effect=dummy_generate) + def test_text_environment_run(self, mock_generate): + env = TextEnvironment( + self.gpt2_model, + self.gpt2_tokenizer, + tools={"DummyTool": DummyTool()}, + reward_fn=lambda x: [torch.tensor(i) for i, _ in enumerate(x)], + prompt="I am a prompt!\n", + max_turns=2, + ) + task_1 = "Hello there!" + task_2 = "Hello there! General Kenobi!" + + query, response, response_mask, reward, histories = env.run([task_1, task_2]) + self.assertEqual(len(query[0]), 9) + self.assertEqual(len(query[1]), 12) + self.assertEqual(len(response[0]), 14) + self.assertEqual(len(response[1]), 14) + self.assertEqual(response_mask[0].sum(), 2 * 3) # mocked generate always adds 3 toknes + self.assertEqual(response_mask[1].sum(), 2 * 3) # mocked generate always adds 3 toknes + self.assertEqual(reward[0], 0) + self.assertEqual(reward[1], 1) + self.assertEqual( + histories[0].text, "I am a prompt!\n" + "Hello there!" + 2 * "testtest" + ) + self.assertEqual( + histories[1].text, + "I am a prompt!\n" + "Hello there! General Kenobi!" + 2 * "testtest", + ) diff --git a/trl/trl/tests/test_iterative_sft_trainer.py b/trl/trl/tests/test_iterative_sft_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..70d5640795e7c74a4a51b338f226758bd88fbd39 --- /dev/null +++ b/trl/trl/tests/test_iterative_sft_trainer.py @@ -0,0 +1,106 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import tempfile +import unittest + +import torch +from datasets import Dataset +from parameterized import parameterized +from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM, AutoTokenizer, TrainingArguments + +from trl import IterativeSFTTrainer + + +class IterativeTrainerTester(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.model_id = "trl-internal-testing/dummy-GPT2-correct-vocab" + cls.model = AutoModelForCausalLM.from_pretrained(cls.model_id) + cls.tokenizer = AutoTokenizer.from_pretrained(cls.model_id) + cls.tokenizer.pad_token = cls.tokenizer.eos_token + + # get t5 as seq2seq example: + model_id = "trl-internal-testing/tiny-T5ForConditionalGeneration-correct-vocab" + cls.t5_model = AutoModelForSeq2SeqLM.from_pretrained(model_id) + cls.t5_tokenizer = AutoTokenizer.from_pretrained(model_id) + + def _init_tensor_dummy_dataset(self): + dummy_dataset_dict = { + "input_ids": [torch.tensor([5303, 3621]), torch.tensor([3666, 1438, 318]), torch.tensor([5303, 3621])], + "attention_mask": [torch.tensor([1, 1]), torch.tensor([1, 1, 1]), torch.tensor([1, 1])], + "labels": [torch.tensor([5303, 3621]), torch.tensor([3666, 1438, 318]), torch.tensor([5303, 3621])], + } + + dummy_dataset = Dataset.from_dict(dummy_dataset_dict) + dummy_dataset.set_format("torch") + return dummy_dataset + + def _init_textual_dummy_dataset(self): + dummy_dataset_dict = { + "texts": ["Testing the IterativeSFTTrainer.", "This is a test of the IterativeSFTTrainer"], + "texts_labels": ["Testing the IterativeSFTTrainer.", "This is a test of the IterativeSFTTrainer"], + } + + dummy_dataset = Dataset.from_dict(dummy_dataset_dict) + dummy_dataset.set_format("torch") + return dummy_dataset + + def setUp(self): + # initialize trainer + self.model.train() + return super().setUp() + + @parameterized.expand( + [ + ["gpt2", "tensor"], + ["gpt2", "text"], + ["t5", "tensor"], + ["t5", "text"], + ] + ) + def test_iterative_step_from_tensor(self, model_name, input_name): + with tempfile.TemporaryDirectory() as tmp_dir: + # initialize dataset + if input_name == "tensor": + dummy_dataset = self._init_tensor_dummy_dataset() + inputs = { + "input_ids": dummy_dataset["input_ids"], + "attention_mask": dummy_dataset["attention_mask"], + "labels": dummy_dataset["labels"], + } + else: + dummy_dataset = self._init_textual_dummy_dataset() + inputs = { + "texts": dummy_dataset["texts"], + "texts_labels": dummy_dataset["texts_labels"], + } + + if model_name == "gpt2": + model = self.model + tokenizer = self.tokenizer + else: + model = self.t5_model + tokenizer = self.t5_tokenizer + + args = TrainingArguments( + output_dir=tmp_dir, + per_device_train_batch_size=2, + max_steps=2, + ) + iterative_trainer = IterativeSFTTrainer(model=model, args=args, tokenizer=tokenizer) + + iterative_trainer.step(**inputs) + + for param in iterative_trainer.model.parameters(): + assert param.grad is not None diff --git a/trl/trl/tests/test_modeling_value_head.py b/trl/trl/tests/test_modeling_value_head.py new file mode 100644 index 0000000000000000000000000000000000000000..b1fdae0233225bbecd2e0fb7efa797dd9c54d961 --- /dev/null +++ b/trl/trl/tests/test_modeling_value_head.py @@ -0,0 +1,517 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import gc +import tempfile +import unittest + +import torch +from transformers import AutoModel, AutoModelForCausalLM, AutoModelForSeq2SeqLM + +from trl import AutoModelForCausalLMWithValueHead, AutoModelForSeq2SeqLMWithValueHead, create_reference_model + + +ALL_CAUSAL_LM_MODELS = [ + "trl-internal-testing/tiny-random-CodeGenForCausalLM", + "trl-internal-testing/tiny-random-GPTJForCausalLM", + "trl-internal-testing/tiny-random-GPTNeoForCausalLM", + "trl-internal-testing/tiny-random-GPTNeoXForCausalLM", + "trl-internal-testing/tiny-random-OPTForCausalLM", + "trl-internal-testing/tiny-random-BloomForCausalLM", + "trl-internal-testing/tiny-random-GPT2LMHeadModel", + "trl-internal-testing/tiny-random-CodeGenForCausalLM-sharded", + "trl-internal-testing/tiny-random-GPTNeoXForCausalLM-safetensors-sharded", + "trl-internal-testing/tiny-random-GPTNeoXForCausalLM-safetensors" + # "trl-internal-testing/tiny-random-LlamaForCausalLM", uncomment on the next transformers release +] + +ALL_SEQ2SEQ_MODELS = [ + "trl-internal-testing/tiny-random-BartForConditionalGeneration", + "trl-internal-testing/tiny-random-BigBirdPegasusForConditionalGeneration", + "trl-internal-testing/tiny-random-BlenderbotForConditionalGeneration", + "trl-internal-testing/tiny-random-BlenderbotSmallForConditionalGeneration", + "trl-internal-testing/tiny-random-FSMTForConditionalGeneration", + "trl-internal-testing/tiny-random-LEDForConditionalGeneration", + "trl-internal-testing/tiny-random-LongT5ForConditionalGeneration", + "trl-internal-testing/tiny-random-M2M100ForConditionalGeneration", + "trl-internal-testing/tiny-random-MarianMTModel", + "trl-internal-testing/tiny-random-MBartForConditionalGeneration", + "trl-internal-testing/tiny-random-MT5ForConditionalGeneration", + "trl-internal-testing/tiny-random-MvpForConditionalGeneration", + "trl-internal-testing/tiny-random-PegasusForConditionalGeneration", + "trl-internal-testing/tiny-random-PegasusXForConditionalGeneration", + "trl-internal-testing/tiny-random-PLBartForConditionalGeneration", + "trl-internal-testing/tiny-random-ProphetNetForConditionalGeneration", + "trl-internal-testing/tiny-random-SwitchTransformersForConditionalGeneration", + "trl-internal-testing/tiny-random-T5ForConditionalGeneration", +] + + +class VHeadModelTester: + all_model_names = None + trl_model_class = None + transformers_model_class = None + + def test_value_head(self): + r""" + Test if the v-head is added to the model successfully + """ + for model_name in self.all_model_names: + model = self.trl_model_class.from_pretrained(model_name) + self.assertTrue(hasattr(model, "v_head")) + + def test_value_head_shape(self): + r""" + Test if the v-head has the correct shape + """ + for model_name in self.all_model_names: + model = self.trl_model_class.from_pretrained(model_name) + self.assertTrue(model.v_head.summary.weight.shape[0] == 1) + + def test_value_head_init_random(self): + r""" + Test if the v-head has been randomly initialized. + We can check that by making sure the bias is different + than zeros by default. + """ + for model_name in self.all_model_names: + model = self.trl_model_class.from_pretrained(model_name) + self.assertFalse(torch.allclose(model.v_head.summary.bias, torch.zeros_like(model.v_head.summary.bias))) + + def test_value_head_not_str(self): + r""" + Test if the v-head is added to the model successfully, by passing a non `PretrainedModel` + as an argument to `from_pretrained`. + """ + for model_name in self.all_model_names: + pretrained_model = self.transformers_model_class.from_pretrained(model_name) + model = self.trl_model_class.from_pretrained(pretrained_model) + self.assertTrue(hasattr(model, "v_head")) + + def test_from_save_trl(self): + """ + Test if the model can be saved and loaded from a directory and get the same weights + Including the additional modules (e.g. v_head) + """ + for model_name in self.all_model_names: + model = self.trl_model_class.from_pretrained(model_name) + + with tempfile.TemporaryDirectory() as tmp_dir: + model.save_pretrained(tmp_dir) + + model_from_save = self.trl_model_class.from_pretrained(tmp_dir) + + # Check if the weights are the same + for key in model_from_save.state_dict(): + self.assertTrue(torch.allclose(model_from_save.state_dict()[key], model.state_dict()[key])) + + def test_from_save_trl_sharded(self): + """ + Test if the model can be saved and loaded from a directory and get the same weights - sharded case + """ + for model_name in self.all_model_names: + model = self.trl_model_class.from_pretrained(model_name) + + with tempfile.TemporaryDirectory() as tmp_dir: + model.save_pretrained(tmp_dir) + + model_from_save = self.trl_model_class.from_pretrained(tmp_dir) + + # Check if the weights are the same + for key in model_from_save.state_dict(): + self.assertTrue(torch.allclose(model_from_save.state_dict()[key], model.state_dict()[key])) + + def test_from_save_transformers_sharded(self): + """ + Test if the model can be saved and loaded using transformers and get the same weights - sharded case + """ + for model_name in self.all_model_names: + transformers_model = self.trl_model_class.transformers_parent_class.from_pretrained(model_name) + + trl_model = self.trl_model_class.from_pretrained(model_name) + + with tempfile.TemporaryDirectory() as tmp_dir: + trl_model.save_pretrained(tmp_dir, max_shard_size="1MB") + transformers_model_from_save = self.trl_model_class.transformers_parent_class.from_pretrained(tmp_dir) + + # Check if the weights are the same + for key in transformers_model.state_dict(): + self.assertTrue( + torch.allclose( + transformers_model_from_save.state_dict()[key], transformers_model.state_dict()[key] + ) + ) + + def test_from_save_transformers(self): + """ + Test if the model can be saved and loaded using transformers and get the same weights. + We override the test of the super class to check if the weights are the same. + """ + for model_name in self.all_model_names: + transformers_model = self.trl_model_class.transformers_parent_class.from_pretrained(model_name) + + trl_model = self.trl_model_class.from_pretrained(model_name) + + with tempfile.TemporaryDirectory() as tmp_dir: + trl_model.save_pretrained(tmp_dir) + transformers_model_from_save = self.trl_model_class.transformers_parent_class.from_pretrained(tmp_dir) + + # Check if the weights are the same + for key in transformers_model.state_dict(): + self.assertTrue( + torch.allclose( + transformers_model_from_save.state_dict()[key], transformers_model.state_dict()[key] + ) + ) + + # Check if the trl model has the same keys as the transformers model + # except the v_head + for key in trl_model.state_dict(): + if "v_head" not in key: + self.assertTrue(key in transformers_model.state_dict()) + # check if the weights are the same + self.assertTrue(torch.allclose(trl_model.state_dict()[key], transformers_model.state_dict()[key])) + + # check if they have the same modules + self.assertTrue( + set(transformers_model_from_save.state_dict().keys()) == set(transformers_model.state_dict().keys()) + ) + + +class CausalLMValueHeadModelTester(VHeadModelTester, unittest.TestCase): + """ + Testing suite for v-head models. + """ + + all_model_names = ALL_CAUSAL_LM_MODELS + trl_model_class = AutoModelForCausalLMWithValueHead + transformers_model_class = AutoModelForCausalLM + + def tearDown(self): + # free memory + gc.collect() + + def test_inference(self): + r""" + Test if the model can be used for inference and outputs 3 values + - logits, loss, and value states + """ + EXPECTED_OUTPUT_SIZE = 3 + + for model_name in self.all_model_names: + model = self.trl_model_class.from_pretrained(model_name) + input_ids = torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]]) + outputs = model(input_ids) + + # Check if the outputs are of the right size - here + # we always output 3 values - logits, loss, and value states + self.assertEqual(len(outputs), EXPECTED_OUTPUT_SIZE) + + def test_dropout_config(self): + r""" + Test if we instantiate a model by adding `summary_drop_prob` to the config + it will be added to the v_head + """ + for model_name in self.all_model_names: + pretrained_model = self.transformers_model_class.from_pretrained(model_name) + pretrained_model.config.summary_dropout_prob = 0.5 + model = self.trl_model_class.from_pretrained(pretrained_model) + + # Check if v head of the model has the same dropout as the config + self.assertEqual(model.v_head.dropout.p, pretrained_model.config.summary_dropout_prob) + + def test_dropout_kwargs(self): + r""" + Test if we instantiate a model by adding `summary_drop_prob` to the config + it will be added to the v_head + """ + for model_name in self.all_model_names: + v_head_kwargs = {"summary_dropout_prob": 0.5} + + model = self.trl_model_class.from_pretrained(model_name, **v_head_kwargs) + + # Check if v head of the model has the same dropout as the config + self.assertEqual(model.v_head.dropout.p, 0.5) + + model = self.trl_model_class.from_pretrained(model_name, summary_dropout_prob=0.5) + + # Check if v head of the model has the same dropout as the config + self.assertEqual(model.v_head.dropout.p, 0.5) + + def test_generate(self): + r""" + Test if `generate` works for every model + """ + for model_name in self.all_model_names: + model = self.trl_model_class.from_pretrained(model_name) + input_ids = torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]]) + + # Just check if the generation works + _ = model.generate(input_ids) + + def test_raise_error_not_causallm(self): + # Test with a model without a LM head + model_id = "trl-internal-testing/tiny-random-GPT2Model" + # This should raise a ValueError + with self.assertRaises(ValueError): + pretrained_model = AutoModelForCausalLM.from_pretrained(model_id) + _ = AutoModelForCausalLMWithValueHead.from_pretrained(pretrained_model.transformer) + + def test_transformers_bf16_kwargs(self): + r""" + Test if the transformers kwargs are correctly passed + Here we check that loading a model in half precision works as expected, i.e. the weights of + the `pretrained_model` attribute is loaded in half precision and you can run a dummy + forward pass without any issue. + """ + for model_name in self.all_model_names: + trl_model = self.trl_model_class.from_pretrained(model_name, torch_dtype=torch.bfloat16) + + lm_head_namings = self.trl_model_class.lm_head_namings + + self.assertTrue( + any(hasattr(trl_model.pretrained_model, lm_head_naming) for lm_head_naming in lm_head_namings) + ) + + for lm_head_naming in lm_head_namings: + if hasattr(trl_model.pretrained_model, lm_head_naming): + self.assertTrue(getattr(trl_model.pretrained_model, lm_head_naming).weight.dtype == torch.bfloat16) + + dummy_input = torch.LongTensor([[0, 1, 0, 1]]) + + # check dummy forward pass works in half precision + _ = trl_model(dummy_input) + + @unittest.skip("This test needs to be run manually due to HF token issue.") + def test_push_to_hub(self): + for model_name in self.all_model_names: + model = AutoModelForCausalLMWithValueHead.from_pretrained(model_name) + if "sharded" in model_name: + model.push_to_hub(model_name + "-ppo", use_auth_token=True, max_shard_size="1MB") + else: + model.push_to_hub(model_name + "-ppo", use_auth_token=True) + + model_from_pretrained = AutoModelForCausalLMWithValueHead.from_pretrained(model_name + "-ppo") + # check all keys + self.assertEqual(model.state_dict().keys(), model_from_pretrained.state_dict().keys()) + + for name, param in model.state_dict().items(): + self.assertTrue( + torch.allclose(param, model_from_pretrained.state_dict()[name]), + f"Parameter {name} is not the same after push_to_hub and from_pretrained", + ) + + +class Seq2SeqValueHeadModelTester(VHeadModelTester, unittest.TestCase): + """ + Testing suite for v-head models. + """ + + all_model_names = ALL_SEQ2SEQ_MODELS + trl_model_class = AutoModelForSeq2SeqLMWithValueHead + transformers_model_class = AutoModelForSeq2SeqLM + + def tearDown(self): + # free memory + gc.collect() + + def test_inference(self): + r""" + Test if the model can be used for inference and outputs 3 values + - logits, loss, and value states + """ + EXPECTED_OUTPUT_SIZE = 3 + + for model_name in self.all_model_names: + model = self.trl_model_class.from_pretrained(model_name) + input_ids = torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]]) + decoder_input_ids = torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]]) + outputs = model(input_ids, decoder_input_ids=decoder_input_ids) + + # Check if the outputs are of the right size - here + # we always output 3 values - logits, loss, and value states + self.assertEqual(len(outputs), EXPECTED_OUTPUT_SIZE) + + def test_dropout_config(self): + r""" + Test if we instantiate a model by adding `summary_drop_prob` to the config + it will be added to the v_head + """ + for model_name in self.all_model_names: + pretrained_model = self.transformers_model_class.from_pretrained(model_name) + pretrained_model.config.summary_dropout_prob = 0.5 + model = self.trl_model_class.from_pretrained(pretrained_model) + + # Check if v head of the model has the same dropout as the config + self.assertEqual(model.v_head.dropout.p, pretrained_model.config.summary_dropout_prob) + + def test_dropout_kwargs(self): + r""" + Test if we instantiate a model by adding `summary_drop_prob` to the config + it will be added to the v_head + """ + for model_name in self.all_model_names: + v_head_kwargs = {"summary_dropout_prob": 0.5} + + model = self.trl_model_class.from_pretrained(model_name, **v_head_kwargs) + + # Check if v head of the model has the same dropout as the config + self.assertEqual(model.v_head.dropout.p, 0.5) + + model = self.trl_model_class.from_pretrained(model_name, summary_dropout_prob=0.5) + + # Check if v head of the model has the same dropout as the config + self.assertEqual(model.v_head.dropout.p, 0.5) + + def test_generate(self): + r""" + Test if `generate` works for every model + """ + for model_name in self.all_model_names: + model = self.trl_model_class.from_pretrained(model_name) + input_ids = torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]]) + decoder_input_ids = torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]]) + + # Just check if the generation works + _ = model.generate(input_ids, decoder_input_ids=decoder_input_ids) + + def test_raise_error_not_causallm(self): + # Test with a model without a LM head + model_id = "trl-internal-testing/tiny-random-T5Model" + # This should raise a ValueError + with self.assertRaises(ValueError): + pretrained_model = AutoModel.from_pretrained(model_id) + _ = self.trl_model_class.from_pretrained(pretrained_model) + + @unittest.skip("This test needs to be run manually due to HF token issue.") + def test_push_to_hub(self): + for model_name in self.all_model_names: + model = self.trl_model_class.from_pretrained(model_name) + if "sharded" in model_name: + model.push_to_hub(model_name + "-ppo", use_auth_token=True, max_shard_size="1MB") + else: + model.push_to_hub(model_name + "-ppo", use_auth_token=True) + + model_from_pretrained = self.trl_model_class.from_pretrained(model_name + "-ppo") + # check all keys + self.assertEqual(model.state_dict().keys(), model_from_pretrained.state_dict().keys()) + + for name, param in model.state_dict().items(): + self.assertTrue( + torch.allclose(param, model_from_pretrained.state_dict()[name]), + f"Parameter {name} is not the same after push_to_hub and from_pretrained", + ) + + def test_transformers_bf16_kwargs(self): + r""" + Test if the transformers kwargs are correctly passed + Here we check that loading a model in half precision works as expected, i.e. the weights of + the `pretrained_model` attribute is loaded in half precision and you can run a dummy + forward pass without any issue. + """ + for model_name in self.all_model_names: + trl_model = self.trl_model_class.from_pretrained(model_name, torch_dtype=torch.bfloat16) + + lm_head_namings = self.trl_model_class.lm_head_namings + + if model_name == "trl-internal-testing/tiny-random-FSMTForConditionalGeneration": + # skip the test for FSMT as it does not support mixed-prec + continue + + self.assertTrue( + any(hasattr(trl_model.pretrained_model, lm_head_naming) for lm_head_naming in lm_head_namings) + ) + + for lm_head_naming in lm_head_namings: + if hasattr(trl_model.pretrained_model, lm_head_naming): + self.assertTrue(getattr(trl_model.pretrained_model, lm_head_naming).weight.dtype == torch.bfloat16) + + dummy_input = torch.LongTensor([[0, 1, 0, 1]]) + + # check dummy forward pass works in half precision + _ = trl_model(input_ids=dummy_input, decoder_input_ids=dummy_input) + + +class ReferenceModelTest(unittest.TestCase): + def setUp(self): + self.model = AutoModelForCausalLMWithValueHead.from_pretrained( + "trl-internal-testing/tiny-random-GPT2LMHeadModel" + ) + self.test_input = torch.tensor([[0, 1, 2, 3]]) + self.optimizer = torch.optim.AdamW(self.model.parameters(), lr=1) + self.layer_format = "pretrained_model.transformer.h.{layer}.attn.c_attn.weight" + + def test_independent_reference(self): + layer_0 = self.layer_format.format(layer=0) + layer_5 = self.layer_format.format(layer=4) + + ref_model = create_reference_model(self.model) + + first_layer_before = self.model.get_parameter(layer_0).data.clone() + last_layer_before = self.model.get_parameter(layer_5).data.clone() + + first_ref_layer_before = ref_model.get_parameter(layer_0).data.clone() + last_ref_layer_before = ref_model.get_parameter(layer_5).data.clone() + + output = self.model(input_ids=self.test_input, labels=self.test_input) + output[1].backward() + self.optimizer.step() + + first_layer_after = self.model.get_parameter(layer_0).data.clone() + last_layer_after = self.model.get_parameter(layer_5).data.clone() + + first_ref_layer_after = ref_model.get_parameter(layer_0).data.clone() + last_ref_layer_after = ref_model.get_parameter(layer_5).data.clone() + + # before optimization ref and model are identical + self.assertTrue((first_layer_before == first_ref_layer_before).all()) + self.assertTrue((last_layer_before == last_ref_layer_before).all()) + # ref model stays identical after optimization + self.assertTrue((first_ref_layer_before == first_ref_layer_after).all()) + self.assertTrue((last_ref_layer_before == last_ref_layer_after).all()) + # optimized model changes + self.assertTrue(not (first_layer_before == first_layer_after).all()) + self.assertTrue(not (last_layer_before == last_layer_after).all()) + + def test_shared_layers(self): + layer_0 = self.layer_format.format(layer=0) + layer_1 = self.layer_format.format(layer=1) + + ref_model = create_reference_model(self.model, num_shared_layers=1) + + first_layer_before = self.model.get_parameter(layer_0).data.clone() + second_layer_before = self.model.get_parameter(layer_1).data.clone() + + first_ref_layer_before = ref_model.get_parameter(layer_0).data.clone() + second_ref_layer_before = ref_model.get_parameter(layer_1).data.clone() + + output = self.model(input_ids=self.test_input, labels=self.test_input) + output[1].backward() + self.optimizer.step() + + first_layer_after = self.model.get_parameter(layer_0).data.clone() + second_layer_after = self.model.get_parameter(layer_1).data.clone() + + first_ref_layer_after = ref_model.get_parameter(layer_0).data.clone() + second_ref_layer_after = ref_model.get_parameter(layer_1).data.clone() + + # before optimization ref and model are identical + self.assertTrue((first_layer_before == first_ref_layer_before).all()) + self.assertTrue((second_layer_before == second_ref_layer_before).all()) + # ref model stays identical after optimization + self.assertTrue((first_ref_layer_before == first_ref_layer_after).all()) + self.assertTrue((second_ref_layer_before == second_ref_layer_after).all()) + # first layer of optimized model stays the same + self.assertTrue((first_layer_before == first_layer_after).all()) + # other layers in optimized model change + self.assertTrue(not (second_layer_before == second_layer_after).all()) diff --git a/trl/trl/tests/test_no_peft.py b/trl/trl/tests/test_no_peft.py new file mode 100644 index 0000000000000000000000000000000000000000..3190b7c85a57550f898133a72c3c7aa7455f0a6f --- /dev/null +++ b/trl/trl/tests/test_no_peft.py @@ -0,0 +1,153 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import sys +import unittest +from unittest.mock import patch + +import torch +from transformers import AutoModelForCausalLM, AutoTokenizer + +from .testing_utils import is_peft_available, require_peft + + +class DummyDataset(torch.utils.data.Dataset): + def __init__(self, query_data, response_data): + self.query_data = query_data + self.response_data = response_data + + def __len__(self): + return len(self.query_data) + + def __getitem__(self, idx): + return self.query_data[idx], self.response_data[idx] + + +EXPECTED_STATS = [ + "objective/kl", + "objective/kl_dist", + "objective/logprobs", + "objective/ref_logprobs", + "objective/kl_coef", + "objective/entropy", + "ppo/mean_non_score_reward", + "ppo/loss/policy", + "ppo/loss/value", + "ppo/loss/total", + "ppo/policy/entropy", + "ppo/policy/approxkl", + "ppo/policy/policykl", + "ppo/policy/clipfrac", + "ppo/policy/advantages", + "ppo/policy/advantages_mean", + "ppo/policy/ratio", + "ppo/returns/mean", + "ppo/returns/var", + "ppo/val/vpred", + "ppo/val/error", + "ppo/val/clipfrac", + "ppo/val/mean", + "ppo/val/var", + "ppo/val/var_explained", + "time/ppo/forward_pass", + "time/ppo/compute_rewards", + "time/ppo/optimize_step", + "time/ppo/calc_stats", + "time/ppo/total", + "ppo/learning_rate", +] + + +@require_peft +class TestPeftDependancy(unittest.TestCase): + def setUp(self): + self.causal_lm_model_id = "trl-internal-testing/tiny-random-GPTNeoXForCausalLM" + self.seq_to_seq_model_id = "trl-internal-testing/tiny-random-T5ForConditionalGeneration" + + if is_peft_available(): + from peft import LoraConfig, get_peft_model + + lora_config = LoraConfig( + r=16, + lora_alpha=32, + lora_dropout=0.05, + bias="none", + task_type="CAUSAL_LM", + ) + + causal_lm_model = AutoModelForCausalLM.from_pretrained(self.causal_lm_model_id) + self.peft_model = get_peft_model(causal_lm_model, lora_config) + + def test_no_peft(self): + with patch.dict(sys.modules, {"peft": None}): + from trl import AutoModelForCausalLMWithValueHead, AutoModelForSeq2SeqLMWithValueHead + + # Check that loading a model with `peft` will raise an error + with self.assertRaises(ModuleNotFoundError): + import peft # noqa + + trl_model = AutoModelForCausalLMWithValueHead.from_pretrained(self.causal_lm_model_id) # noqa + trl_seq2seq_model = AutoModelForSeq2SeqLMWithValueHead.from_pretrained(self.seq_to_seq_model_id) # noqa + + def test_imports_no_peft(self): + with patch.dict(sys.modules, {"peft": None}): + from trl import ( # noqa + AutoModelForCausalLMWithValueHead, + AutoModelForSeq2SeqLMWithValueHead, + PPOConfig, + PPOTrainer, + PreTrainedModelWrapper, + ) + + def test_ppo_trainer_no_peft(self): + with patch.dict(sys.modules, {"peft": None}): + from trl import AutoModelForCausalLMWithValueHead, PPOConfig, PPOTrainer + + ppo_model_id = "trl-internal-testing/dummy-GPT2-correct-vocab" + + trl_model = AutoModelForCausalLMWithValueHead.from_pretrained(ppo_model_id) + tokenizer = AutoTokenizer.from_pretrained(ppo_model_id) + tokenizer.pad_token_id = tokenizer.eos_token_id + + ppo_config = PPOConfig(batch_size=2, mini_batch_size=1, log_with=None) + + dummy_dataset = DummyDataset( + [torch.LongTensor([0, 1, 0, 1, 0, 1]), torch.LongTensor([0, 1, 0, 1, 0, 1])], + [torch.LongTensor([1, 0, 1, 0, 1, 0]), torch.LongTensor([0, 1, 0, 1, 0, 1])], + ) + + ppo_trainer = PPOTrainer( + config=ppo_config, + model=trl_model, + ref_model=None, + tokenizer=tokenizer, + dataset=dummy_dataset, + ) + dummy_dataloader = ppo_trainer.dataloader + + for query_tensor, response_tensor in dummy_dataloader: + # define a reward for response + # (this could be any reward such as human feedback or output from another model) + reward = [torch.tensor(1.0), torch.tensor(0.0)] + # train model + train_stats = ppo_trainer.step([q for q in query_tensor], [r for r in response_tensor], reward) + break + + # check gradients are not None + for _, param in trl_model.named_parameters(): + if param.requires_grad: + self.assertIsNotNone(param.grad) + + # check expected stats + for stat in EXPECTED_STATS: + self.assertIn(stat, train_stats) diff --git a/trl/trl/tests/test_peft_models.py b/trl/trl/tests/test_peft_models.py new file mode 100644 index 0000000000000000000000000000000000000000..3b004659d2b7160f62d0c38cc97e6b3284dff475 --- /dev/null +++ b/trl/trl/tests/test_peft_models.py @@ -0,0 +1,208 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +import tempfile +import unittest + +import torch +from pytest import mark +from transformers import AutoModelForCausalLM + +from trl import AutoModelForCausalLMWithValueHead, is_peft_available + + +if is_peft_available(): + from peft import get_peft_model, LoraConfig + +from .testing_utils import require_bitsandbytes, require_peft + + +@require_peft +@mark.peft_test +class PeftModelTester(unittest.TestCase): + def setUp(self): + self.causal_lm_model_id = "trl-internal-testing/tiny-random-GPTNeoXForCausalLM" + self.lora_config = LoraConfig( + r=16, + lora_alpha=32, + lora_dropout=0.05, + bias="none", + task_type="CAUSAL_LM", + ) + + def test_create_peft_model(self): + r""" + Simply creates a peft model and checks that it can be loaded. + """ + causal_lm_model = AutoModelForCausalLM.from_pretrained(self.causal_lm_model_id) + pretrained_model = get_peft_model(causal_lm_model, self.lora_config) + + _ = AutoModelForCausalLMWithValueHead.from_pretrained(pretrained_model) + + def test_peft_requires_grad(self): + r""" + Check that the value head of the returned model has requires_grad=True. + """ + causal_lm_model = AutoModelForCausalLM.from_pretrained(self.causal_lm_model_id) + pretrained_model = get_peft_model(causal_lm_model, self.lora_config) + + model = AutoModelForCausalLMWithValueHead.from_pretrained(pretrained_model) + + # Check that the value head has requires_grad=True + self.assertTrue(model.v_head.summary.weight.requires_grad) + + def test_check_peft_model_nb_trainable_params(self): + r""" + Check that the number of trainable parameters is correct. + """ + causal_lm_model = AutoModelForCausalLM.from_pretrained(self.causal_lm_model_id) + pretrained_model = get_peft_model(causal_lm_model, self.lora_config) + + model = AutoModelForCausalLMWithValueHead.from_pretrained(pretrained_model) + + # Check that the number of trainable parameters is correct + nb_trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad) + self.assertEqual(nb_trainable_params, 10273) + + # Check that the number of trainable param for the non-peft model is correct + non_peft_model = AutoModelForCausalLMWithValueHead.from_pretrained(self.causal_lm_model_id) + nb_trainable_params = sum(p.numel() for p in non_peft_model.parameters() if p.requires_grad) + self.assertEqual(nb_trainable_params, 99578) + + def test_create_peft_model_from_config(self): + r""" + Simply creates a peft model and checks that it can be loaded. + """ + trl_model = AutoModelForCausalLMWithValueHead.from_pretrained( + self.causal_lm_model_id, peft_config=self.lora_config + ) + # Check that the number of trainable parameters is correct + nb_trainable_params = sum(p.numel() for p in trl_model.parameters() if p.requires_grad) + self.assertEqual(nb_trainable_params, 10273) + + causal_lm_model = AutoModelForCausalLM.from_pretrained(self.causal_lm_model_id) + trl_model = AutoModelForCausalLMWithValueHead.from_pretrained(causal_lm_model, peft_config=self.lora_config) + # Check that the number of trainable parameters is correct + nb_trainable_params = sum(p.numel() for p in trl_model.parameters() if p.requires_grad) + self.assertEqual(nb_trainable_params, 10273) + + @require_bitsandbytes + def test_create_bnb_peft_model_from_config(self): + r""" + Simply creates a peft model and checks that it can be loaded. + """ + from bitsandbytes.nn import Linear8bitLt + + trl_model = AutoModelForCausalLMWithValueHead.from_pretrained( + self.causal_lm_model_id, peft_config=self.lora_config, load_in_8bit=True + ) + # Check that the number of trainable parameters is correct + nb_trainable_params = sum(p.numel() for p in trl_model.parameters() if p.requires_grad) + self.assertEqual(nb_trainable_params, 10273) + self.assertTrue( + trl_model.pretrained_model.model.gpt_neox.layers[0].mlp.dense_h_to_4h.__class__ == Linear8bitLt + ) + + causal_lm_model = AutoModelForCausalLM.from_pretrained( + self.causal_lm_model_id, load_in_8bit=True, device_map="auto" + ) + trl_model = AutoModelForCausalLMWithValueHead.from_pretrained(causal_lm_model, peft_config=self.lora_config) + # Check that the number of trainable parameters is correct + nb_trainable_params = sum(p.numel() for p in trl_model.parameters() if p.requires_grad) + self.assertEqual(nb_trainable_params, 10273) + self.assertTrue( + trl_model.pretrained_model.model.gpt_neox.layers[0].mlp.dense_h_to_4h.__class__ == Linear8bitLt + ) + + def test_save_pretrained_peft(self): + r""" + Check that the model can be saved and loaded properly. + """ + causal_lm_model = AutoModelForCausalLM.from_pretrained(self.causal_lm_model_id) + pretrained_model = get_peft_model(causal_lm_model, self.lora_config) + + model = AutoModelForCausalLMWithValueHead.from_pretrained(pretrained_model) + + with tempfile.TemporaryDirectory() as tmp_dir: + model.save_pretrained(tmp_dir) + + # check that the files `adapter_model.safetensors` and `adapter_config.json` are in the directory + self.assertTrue( + os.path.isfile(f"{tmp_dir}/adapter_model.safetensors"), + msg=f"{tmp_dir}/adapter_model.safetensors does not exist", + ) + self.assertTrue( + os.path.exists(f"{tmp_dir}/adapter_config.json"), + msg=f"{tmp_dir}/adapter_config.json does not exist", + ) + # check also for `pytorch_model.bin` and make sure it only contains `v_head` weights + self.assertTrue( + os.path.exists(f"{tmp_dir}/pytorch_model.bin"), + msg=f"{tmp_dir}/pytorch_model.bin does not exist", + ) + maybe_v_head = torch.load(f"{tmp_dir}/pytorch_model.bin") + # check that only keys that starts with `v_head` are in the dict + self.assertTrue( + all(k.startswith("v_head") for k in maybe_v_head.keys()), + msg=f"keys in {tmp_dir}/pytorch_model.bin do not start with `v_head`", + ) + + model_from_pretrained = AutoModelForCausalLMWithValueHead.from_pretrained(tmp_dir) + + # check all the weights are the same + for p1, p2 in zip(model.named_parameters(), model_from_pretrained.named_parameters()): + self.assertTrue(torch.allclose(p1[1], p2[1]), msg=f"{p1[0]} != {p2[0]}") + + def test_load_pretrained_peft(self): + r""" + Check that the model saved with peft class interface can be loaded properly. + """ + causal_lm_model = AutoModelForCausalLM.from_pretrained(self.causal_lm_model_id) + pretrained_model = get_peft_model(causal_lm_model, self.lora_config) + + model = AutoModelForCausalLMWithValueHead.from_pretrained(pretrained_model) + + with tempfile.TemporaryDirectory() as tmp_dir: + pretrained_model.save_pretrained(tmp_dir) + model_from_pretrained = AutoModelForCausalLMWithValueHead.from_pretrained(tmp_dir) + + # check that the files `adapter_model.safetensors` and `adapter_config.json` are in the directory + self.assertTrue( + os.path.isfile(f"{tmp_dir}/adapter_model.safetensors"), + msg=f"{tmp_dir}/adapter_model.safetensors does not exist", + ) + self.assertTrue( + os.path.exists(f"{tmp_dir}/adapter_config.json"), + msg=f"{tmp_dir}/adapter_config.json does not exist", + ) + + # check all the weights are the same + for p1, p2 in zip(model.named_parameters(), model_from_pretrained.named_parameters()): + if p1[0] not in ["v_head.summary.weight", "v_head.summary.bias"]: + self.assertTrue(torch.allclose(p1[1], p2[1]), msg=f"{p1[0]} != {p2[0]}") + + def test_continue_training_peft_model(self): + r""" + Load peft and checks that it can continue training. + """ + causal_lm_model = AutoModelForCausalLM.from_pretrained(self.causal_lm_model_id) + pretrained_model = get_peft_model(causal_lm_model, self.lora_config) + + with tempfile.TemporaryDirectory() as tmp_dir: + pretrained_model.save_pretrained(tmp_dir) + # set is_trainable to True + model = AutoModelForCausalLMWithValueHead.from_pretrained(tmp_dir, is_trainable=True) + # Check that the number of trainable parameters is correct + nb_trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad) + self.assertEqual(nb_trainable_params, 10273) diff --git a/trl/trl/tests/test_ppo_trainer.py b/trl/trl/tests/test_ppo_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..0af091fc3c24b720e4611c71c4ed91dc16ad6582 --- /dev/null +++ b/trl/trl/tests/test_ppo_trainer.py @@ -0,0 +1,1232 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import copy +import fnmatch +import gc +import re +import tempfile +import unittest + +import pytest +import torch +from huggingface_hub import HfApi, HfFolder, delete_repo +from parameterized import parameterized +from pytest import mark +from requests.exceptions import HTTPError +from transformers import AutoTokenizer + +from trl import AutoModelForCausalLMWithValueHead, AutoModelForSeq2SeqLMWithValueHead, PPOConfig, PPOTrainer, set_seed +from trl.core import respond_to_batch + +from .testing_constants import CI_HUB_ENDPOINT, CI_HUB_USER, CI_HUB_USER_TOKEN +from .testing_utils import require_peft, require_torch_multi_gpu + + +EXPECTED_STATS = [ + "objective/kl", + "objective/kl_dist", + "objective/logprobs", + "objective/ref_logprobs", + "objective/kl_coef", + "objective/entropy", + "ppo/mean_non_score_reward", + "ppo/loss/policy", + "ppo/loss/value", + "ppo/loss/total", + "ppo/policy/entropy", + "ppo/policy/approxkl", + "ppo/policy/policykl", + "ppo/policy/clipfrac", + "ppo/policy/advantages", + "ppo/policy/advantages_mean", + "ppo/policy/ratio", + "ppo/returns/mean", + "ppo/returns/var", + "ppo/val/vpred", + "ppo/val/error", + "ppo/val/clipfrac", + "ppo/val/mean", + "ppo/val/var", + "ppo/val/var_explained", + "time/ppo/forward_pass", + "time/ppo/compute_rewards", + "time/ppo/optimize_step", + "time/ppo/calc_stats", + "time/ppo/total", + "ppo/learning_rate", +] + + +class DummyDataset(torch.utils.data.Dataset): + def __init__(self, query_data, response_data): + self.query_data = query_data + self.response_data = response_data + + def __len__(self): + return len(self.query_data) + + def __getitem__(self, idx): + return self.query_data[idx], self.response_data[idx] + + +def apply_mask(values, mask): + unmasked_values = [] + for v, m in zip(values, mask): + if m == 1: + unmasked_values.append(v) + return torch.Tensor(unmasked_values) + + +def abs_diff_masked_tensors(tensor_1, tensor_2, mask_1, mask_2): + diffs = [] + for l1, l2, m1, m2 in zip(tensor_1, tensor_2, mask_1, mask_2): + diff = apply_mask(l1, m1) - apply_mask(l2, m2) + diffs.append(diff.sum()) + return abs(sum(diffs)) + + +class PPOTrainerTester(unittest.TestCase): + """ + A wrapper class for testing PPOTrainer + """ + + @classmethod + def setUpClass(cls): + set_seed(42) + cls._token = CI_HUB_USER_TOKEN + cls._api = HfApi(endpoint=CI_HUB_ENDPOINT) + HfFolder.save_token(CI_HUB_USER_TOKEN) + + # model_id + cls.model_id = "trl-internal-testing/dummy-GPT2-correct-vocab" + + # get models and tokenizer + cls.gpt2_model = AutoModelForCausalLMWithValueHead.from_pretrained(cls.model_id) + cls.gpt2_model_ref = AutoModelForCausalLMWithValueHead.from_pretrained(cls.model_id) + cls.gpt2_tokenizer = AutoTokenizer.from_pretrained(cls.model_id) + + cls.gpt2_tokenizer.pad_token = cls.gpt2_tokenizer.eos_token + + # get bloom as right padding examples: + model_id = "trl-internal-testing/tiny-BloomForCausalLM-correct-vocab" + cls.bloom_model = AutoModelForCausalLMWithValueHead.from_pretrained(model_id) + cls.bloom_tokenizer = AutoTokenizer.from_pretrained(model_id) + + model_id = "trl-internal-testing/tiny-T5ForConditionalGeneration-correct-vocab" + cls.t5_model = AutoModelForSeq2SeqLMWithValueHead.from_pretrained(model_id) + cls.t5_tokenizer = AutoTokenizer.from_pretrained(model_id) + + # initialize trainer + cls.ppo_config = PPOConfig(batch_size=2, mini_batch_size=1, log_with=None) + + @classmethod + def tearDownClass(cls): + for model in [f"{CI_HUB_USER}/test-ppo-trainer"]: + try: + delete_repo(token=cls._token, repo_id=model) + except HTTPError: + pass + + def setUp(self): + # initialize trainer + self.ppo_config = PPOConfig(batch_size=2, mini_batch_size=1, log_with=None) + self.gpt2_model.train() + return super().setUp() + + def tearDown(self): + # free memory + gc.collect() + + def _init_dummy_dataset(self): + # encode a query + query_txt = "This morning I went to the " + query_tensor = self.gpt2_tokenizer.encode(query_txt, return_tensors="pt") + assert query_tensor.shape == (1, 7) + # get model response + response_tensor = respond_to_batch(self.gpt2_model, query_tensor) + assert response_tensor.shape == (1, 20) + + # create a dummy dataset + min_length = min(len(query_tensor[0]), len(response_tensor[0])) + dummy_dataset = DummyDataset( + [query_tensor[:, :min_length].squeeze(0) for _ in range(2)], + [response_tensor[:, :min_length].squeeze(0) for _ in range(2)], + ) + + return dummy_dataset + + def test_drop_last_dataloader(self): + self.ppo_config = PPOConfig(batch_size=3, mini_batch_size=1, log_with=None) + + dummy_dataset = self._init_dummy_dataset() + + ppo_trainer = PPOTrainer( + config=self.ppo_config, + model=self.gpt2_model, + ref_model=self.gpt2_model_ref, + tokenizer=self.gpt2_tokenizer, + dataset=dummy_dataset, + ) + dummy_dataloader = ppo_trainer.dataloader + + self.assertEqual(len(dummy_dataloader), 0) + + def test_ppo_step(self): + # initialize dataset + dummy_dataset = self._init_dummy_dataset() + + ppo_trainer = PPOTrainer( + config=self.ppo_config, + model=self.gpt2_model, + ref_model=self.gpt2_model_ref, + tokenizer=self.gpt2_tokenizer, + dataset=dummy_dataset, + ) + dummy_dataloader = ppo_trainer.dataloader + # train model with ppo + for query_tensor, response_tensor in dummy_dataloader: + # define a reward for response + # (this could be any reward such as human feedback or output from another model) + reward = [torch.tensor(1.0), torch.tensor(0.0)] + # train model + train_stats = ppo_trainer.step([q for q in query_tensor], [r for r in response_tensor], reward) + break + + for param in ppo_trainer.model.parameters(): + assert param.grad is not None + + for stat in EXPECTED_STATS: + assert stat in train_stats.keys() + + def test_ppo_step_with_masks(self): + # initialize dataset + dummy_dataset = self._init_dummy_dataset() + + ppo_trainer = PPOTrainer( + config=self.ppo_config, + model=self.gpt2_model, + ref_model=self.gpt2_model_ref, + tokenizer=self.gpt2_tokenizer, + dataset=dummy_dataset, + ) + dummy_dataloader = ppo_trainer.dataloader + # train model with ppo + for query_tensor, response_tensor in dummy_dataloader: + # define a reward for response + # (this could be any reward such as human feedback or output from another model) + reward = [torch.tensor(1.0), torch.tensor(0.0)] + + response_mask = [torch.ones_like(r) for r in response_tensor] + + # train model + train_stats = ppo_trainer.step( + [q for q in query_tensor], [r for r in response_tensor], reward, response_mask + ) + break + + for param in ppo_trainer.model.parameters(): + assert param.grad is not None + + for stat in EXPECTED_STATS: + assert stat in train_stats.keys() + + def test_ppo_step_with_no_ref_sgd(self): + # initialize dataset + dummy_dataset = self._init_dummy_dataset() + optimizer = torch.optim.SGD(self.gpt2_model.parameters(), lr=0.01) + + ppo_trainer = PPOTrainer( + config=self.ppo_config, + model=self.gpt2_model, + ref_model=None, + optimizer=optimizer, + tokenizer=self.gpt2_tokenizer, + dataset=dummy_dataset, + ) + dummy_dataloader = ppo_trainer.dataloader + + self.assertTrue(isinstance(ppo_trainer.optimizer.optimizer, torch.optim.SGD)) + + # train model with ppo + for query_tensor, response_tensor in dummy_dataloader: + # define a reward for response + # (this could be any reward such as human feedback or output from another model) + reward = [torch.tensor(1.0), torch.tensor(0.0)] + # train model + train_stats = ppo_trainer.step([q for q in query_tensor], [r for r in response_tensor], reward) + break + + for name, param in ppo_trainer.model.named_parameters(): + self.assertTrue(param.grad is not None, f"Parameter {name} has no gradient") + + # ref model should not be trained + for name, param in ppo_trainer.ref_model.named_parameters(): + self.assertTrue(param.grad is None, f"Parameter {name} has a gradient") + + # Finally check stats + for stat in EXPECTED_STATS: + assert stat in train_stats.keys() + + def test_ppo_step_with_no_ref_sgd_lr_scheduler(self): + # initialize dataset + dummy_dataset = self._init_dummy_dataset() + optimizer = torch.optim.SGD(self.gpt2_model.parameters(), lr=0.01) + lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9) + + ppo_trainer = PPOTrainer( + config=self.ppo_config, + model=self.gpt2_model, + ref_model=None, + optimizer=optimizer, + tokenizer=self.gpt2_tokenizer, + dataset=dummy_dataset, + lr_scheduler=lr_scheduler, + ) + dummy_dataloader = ppo_trainer.dataloader + + self.assertTrue(isinstance(ppo_trainer.optimizer.optimizer, torch.optim.SGD)) + self.assertTrue(isinstance(ppo_trainer.lr_scheduler.scheduler, torch.optim.lr_scheduler.ExponentialLR)) + + # train model with ppo + for query_tensor, response_tensor in dummy_dataloader: + # define a reward for response + # (this could be any reward such as human feedback or output from another model) + reward = [torch.tensor(1.0), torch.tensor(0.0)] + # train model + _ = ppo_trainer.step([q for q in query_tensor], [r for r in response_tensor], reward) + train_stats = ppo_trainer.step([q for q in query_tensor], [r for r in response_tensor], reward) + break + + for name, param in ppo_trainer.model.named_parameters(): + self.assertTrue(param.grad is not None, f"Parameter {name} has no gradient") + + # ref model should not be trained + for name, param in ppo_trainer.ref_model.named_parameters(): + self.assertTrue(param.grad is None, f"Parameter {name} has a gradient") + + # Finally check stats + for stat in EXPECTED_STATS: + assert stat in train_stats.keys() + + # assert that the LR has increased for exponential decay + self.assertTrue(train_stats["ppo/learning_rate"] > self.ppo_config.learning_rate) + + def test_ppo_step_with_no_ref(self): + # initialize dataset + dummy_dataset = self._init_dummy_dataset() + self.gpt2_model = AutoModelForCausalLMWithValueHead.from_pretrained(self.model_id) + + ppo_trainer = PPOTrainer( + config=self.ppo_config, + model=self.gpt2_model, + ref_model=None, + tokenizer=self.gpt2_tokenizer, + dataset=dummy_dataset, + ) + dummy_dataloader = ppo_trainer.dataloader + # train model with ppo + for query_tensor, response_tensor in dummy_dataloader: + # define a reward for response + # (this could be any reward such as human feedback or output from another model) + reward = [torch.tensor(1.0), torch.tensor(0.0)] + # train model + train_stats = ppo_trainer.step([q for q in query_tensor], [r for r in response_tensor], reward) + break + + for name, param in ppo_trainer.model.named_parameters(): + self.assertTrue(param.grad is not None, f"Parameter {name} has no gradient") + + # ref model should not be trained + for name, param in ppo_trainer.ref_model.named_parameters(): + self.assertTrue(param.grad is None, f"Parameter {name} has a gradient") + + # initialize a new gpt2 model: + model = AutoModelForCausalLMWithValueHead.from_pretrained(self.model_id) + for name, param in ppo_trainer.ref_model.named_parameters(): + if "v_head" not in name: + name = name.replace("pretrained_model.", "") + + self.assertTrue( + torch.allclose(param.cpu(), model.state_dict()[name].cpu()), + f"Parameter {name} has changed from the original model", + ) + + # Finally check stats + for stat in EXPECTED_STATS: + assert stat in train_stats.keys() + + def test_ppo_step_with_no_ref_custom_layers(self): + """ + Test PPO step with no reference model and custom layers + For shared layers configuration, all the layers after the `num_shared_layers` are considered as custom layers + therefore the gradients should be computed for these layers only. + """ + # initialize dataset + dummy_dataset = self._init_dummy_dataset() + self.gpt2_model = AutoModelForCausalLMWithValueHead.from_pretrained(self.model_id) + num_shared_layers = 1 + + ppo_trainer = PPOTrainer( + config=self.ppo_config, + model=self.gpt2_model, + ref_model=None, + tokenizer=self.gpt2_tokenizer, + dataset=dummy_dataset, + num_shared_layers=num_shared_layers, + ) + dummy_dataloader = ppo_trainer.dataloader + # train model with ppo + for query_tensor, response_tensor in dummy_dataloader: + # define a reward for response + # (this could be any reward such as human feedback or output from another model) + reward = [torch.tensor(1.0), torch.tensor(0.0)] + # train model + train_stats = ppo_trainer.step([q for q in query_tensor], [r for r in response_tensor], reward) + break + + pattern = r".*transformer\.h\.(\d+)\..*" + final_layers = ["ln_f", "v_head", "lm_head"] + + for name, param in ppo_trainer.model.named_parameters(): + if re.match(pattern, name): + layer_number = int(re.match(pattern, name).groups(0)[0]) + if layer_number < num_shared_layers: + self.assertTrue(param.grad is None, f"Parameter {name} has a gradient") + else: + self.assertTrue(param.grad is not None, f"Parameter {name} has no gradient") + elif any([layer in name for layer in final_layers]): + self.assertTrue(param.grad is not None, f"Parameter {name} has no gradient") + + # ref model should not be trained + for name, param in ppo_trainer.ref_model.named_parameters(): + self.assertTrue(param.grad is None, f"Parameter {name} has a gradient") + + for stat in EXPECTED_STATS: + assert stat in train_stats.keys() + + def test_ppo_step_with_ref_and_custom_layers_warning(self): + """ + Test PPO step with a reference model and custom layers + The trainer should raise a warning if the argument `num_shared_layers` is set + together with a reference model. + """ + # initialize dataset + dummy_dataset = self._init_dummy_dataset() + + num_shared_layers = 6 + + with self.assertWarns(UserWarning): + _ = PPOTrainer( + config=self.ppo_config, + model=self.gpt2_model, + ref_model=self.gpt2_model_ref, + tokenizer=self.gpt2_tokenizer, + dataset=dummy_dataset, + num_shared_layers=num_shared_layers, + ) + + def test_ppo_step_rewards_shape(self): + """ + Test if the rewards shape is correct by asserting that if a wrong reward shape is passed, we get + a value error. + """ + + # initialize dataset + dummy_dataset = self._init_dummy_dataset() + + ppo_trainer = PPOTrainer( + config=self.ppo_config, + model=self.gpt2_model, + ref_model=None, + tokenizer=self.gpt2_tokenizer, + dataset=dummy_dataset, + ) + dummy_dataloader = ppo_trainer.dataloader + # train model with ppo + for query_tensor, response_tensor in dummy_dataloader: + # define a reward for response + # (this could be any reward such as human feedback or output from another model) + reward = [torch.tensor([[1.0]]), torch.tensor([[0.0]])] + # train model - this should raise an error + with self.assertRaises(ValueError): + _ = ppo_trainer.step([q for q in query_tensor], [r for r in response_tensor], reward) + + reward = [torch.tensor([1.0]), torch.tensor([0.0])] + # train model - this should work + _ = ppo_trainer.step([q for q in query_tensor], [r for r in response_tensor], reward) + break + + # check if the gradients are computed for the model + for name, param in ppo_trainer.model.named_parameters(): + self.assertTrue(param.grad is not None, f"Parameter {name} has no gradient") + + # ref model should not be trained + for name, param in ppo_trainer.ref_model.named_parameters(): + self.assertTrue(param.grad is None, f"Parameter {name} has a gradient") + + def test_ppo_step_input_shape(self): + """ + Test if the shape of the expected inputs are correct + """ + # initialize dataset + dummy_dataset = self._init_dummy_dataset() + + ppo_trainer = PPOTrainer( + config=self.ppo_config, + model=self.gpt2_model, + ref_model=None, + tokenizer=self.gpt2_tokenizer, + dataset=dummy_dataset, + ) + dummy_dataloader = ppo_trainer.dataloader + # train model with ppo + for query_tensor, response_tensor in dummy_dataloader: + # define a reward for response + # (this could be any reward such as human feedback or output from another model) + reward = [torch.tensor([1.0]), torch.tensor([0.0])] + # train model - this should raise an error + bs = ppo_trainer.config.batch_size + + queries, responses, _, _ = ppo_trainer._step_safety_checker( + bs, [q for q in query_tensor], [r for r in response_tensor], reward + ) + + self.assertTrue(isinstance(queries, list), f"queries should be a list, got {type(queries)}") + self.assertTrue(isinstance(responses, list), f"responses should be a list, got {type(responses)}") + + # check the shapes + for i in range(bs): + self.assertEqual(queries[i].shape, torch.Size([7])) + self.assertEqual(responses[i].size(), torch.Size([7])) + break + + def test_ppo_step_no_dataset(self): + """ + Test if the training loop works fine without passing a dataset + """ + query_txt = "This morning I went to the " + query_tensor = self.gpt2_tokenizer.encode(query_txt, return_tensors="pt") + self.ppo_config.batch_size = 1 + + response_tensor = respond_to_batch(self.gpt2_model, query_tensor) + + # Check that this warns the user about batch size + with self.assertWarns(UserWarning): + ppo_trainer = PPOTrainer( + config=self.ppo_config, + model=self.gpt2_model, + ref_model=self.gpt2_model_ref, + tokenizer=self.gpt2_tokenizer, + ) + # train model with ppo + reward = [torch.tensor([1.0])] + # train model - this should work fine + train_stats = ppo_trainer.step([query_tensor[0]], [response_tensor[0]], reward) + + # check gradients + for name, param in ppo_trainer.model.named_parameters(): + self.assertTrue(param.grad is not None, f"Parameter {name} has no gradient") + + # ref model should not be trained + for name, param in ppo_trainer.ref_model.named_parameters(): + self.assertTrue(param.grad is None, f"Parameter {name} has a gradient") + + # check train stats + for stat in EXPECTED_STATS: + self.assertTrue(stat in train_stats, f"Train stats should contain {stat}") + + def test_loss_trainer(self): + """ + Test if the loss trainer works fine + """ + # initialize dataset + dummy_dataset = self._init_dummy_dataset() + + self.gpt2_model.eval() + + ppo_trainer = PPOTrainer( + config=self.ppo_config, + model=self.gpt2_model, + ref_model=None, + tokenizer=self.gpt2_tokenizer, + dataset=dummy_dataset, + ) + + dummy_queries = [torch.tensor([1, 2, 3, 4]), torch.tensor([1, 2, 3, 4, 5, 6, 7])] + dummy_responses = [torch.tensor([5, 6, 7, 8, 9]), torch.tensor([8, 9, 10, 11, 12, 13])] + dummy_scores = torch.Tensor([1, 2]) + + ppo_trainer.config.mini_batch_size = 1 + ppo_trainer.config.batch_size = 1 + model_inputs = ppo_trainer.prepare_model_inputs(dummy_queries, dummy_responses) + all_logprobs, _, values, mask = ppo_trainer.batched_forward_pass( + self.gpt2_model, dummy_queries, dummy_responses, model_inputs + ) + + # dummy values + ref_logprobs = all_logprobs + 1 + logits = torch.exp(all_logprobs) + vpreds = values + 0.1 + + score, non_score = ppo_trainer.compute_rewards(dummy_scores, all_logprobs, ref_logprobs, mask) + values, advantages, returns = ppo_trainer.compute_advantages(values, score, mask) + + # just make sure a dummy loss is computed + idx = 0 + pg_loss, v_loss, _ = ppo_trainer.loss( + all_logprobs[idx].unsqueeze(0), + values[idx].unsqueeze(0), + logits[idx].unsqueeze(0), + vpreds[idx].unsqueeze(0), + ref_logprobs[idx].unsqueeze(0), + mask[idx].unsqueeze(0), + advantages[idx].unsqueeze(0), + returns[idx].unsqueeze(0), + ) + + self.assertAlmostEqual(pg_loss.item(), 2.0494, 4) + self.assertAlmostEqual(v_loss.item(), 0.07110, 4) + + # check if we get same results with masked parts removed + pg_loss_unmasked, v_loss_unmasked, _ = ppo_trainer.loss( + apply_mask(all_logprobs[idx], mask[idx]).unsqueeze(0), + apply_mask(values[idx], mask[idx]).unsqueeze(0), + apply_mask(logits[idx], mask[idx]).unsqueeze(0), + apply_mask(vpreds[idx], mask[idx]).unsqueeze(0), + apply_mask(ref_logprobs[idx], mask[idx]).unsqueeze(0), + apply_mask(mask[idx], mask[idx]).unsqueeze(0), + apply_mask(advantages[idx], mask[idx]).unsqueeze(0), + apply_mask(returns[idx], mask[idx]).unsqueeze(0), + ) + self.assertAlmostEqual(pg_loss_unmasked.item(), 2.0494, 4) + self.assertAlmostEqual(v_loss_unmasked.item(), 0.07110, 4) + + @parameterized.expand( + [ + ["gpt2"], + ["bloom"], + ["t5"], + ] + ) + def test_batched_forward_pass(self, name): + """ + Test if the loss trainer works fine + """ + # initialize dataset + dummy_dataset = self._init_dummy_dataset() + + dummy_queries = [torch.tensor([1, 2, 3, 4]), torch.tensor([1, 2, 3, 4, 5, 6, 7])] + dummy_responses = [torch.tensor([5, 6, 7, 8, 9]), torch.tensor([8, 9, 10, 11, 12, 13])] + + if name == "gpt2": + model = self.gpt2_model + tokenizer = self.gpt2_tokenizer + elif name == "bloom": + model = self.bloom_model + tokenizer = self.bloom_tokenizer + elif name == "t5": + model = self.t5_model + tokenizer = self.t5_tokenizer + + model.eval() + + ppo_trainer = PPOTrainer( + config=self.ppo_config, + model=model, + ref_model=None, + tokenizer=tokenizer, + dataset=dummy_dataset, + ) + + # we test all combinations of fwd_bs and bs: + # if fwd_bs=bs=1: no padding is applied and only one forward pass + # if fwd_bs=1/bs=2: padding is applied and results computed in two fwd passes + # if fwd_bs=bs=2: padding is applied and results computed in one fwd pass + + ppo_trainer.config.mini_batch_size = 1 + ppo_trainer.config.batch_size = 1 + + model_inputs = ppo_trainer.prepare_model_inputs([dummy_queries[0]], [dummy_responses[0]]) + logprobs_0, logits_0, values_0, mask_0 = ppo_trainer.batched_forward_pass( + model, [dummy_queries[0]], [dummy_responses[0]], model_inputs + ) + + ppo_trainer.config.batch_size = 2 + model_inputs = ppo_trainer.prepare_model_inputs(dummy_queries, dummy_responses) + logprobs_1, logits_1, values_1, mask_1 = ppo_trainer.batched_forward_pass( + model, dummy_queries, dummy_responses, model_inputs + ) + + ppo_trainer.config.mini_batch_size = 2 + model_inputs = ppo_trainer.prepare_model_inputs(dummy_queries, dummy_responses) + logprobs_2, logits_2, values_2, mask_2 = ppo_trainer.batched_forward_pass( + model, dummy_queries, dummy_responses, model_inputs + ) + + self.assertLessEqual(abs_diff_masked_tensors(logprobs_1, logprobs_2, mask_1, mask_2), 1e-4) + self.assertLessEqual(abs_diff_masked_tensors(values_1, values_2, mask_1, mask_2), 1e-4) + + self.assertLessEqual(abs_diff_masked_tensors(logprobs_0, logprobs_2[:1], mask_0, mask_2[:1]), 1e-4) + self.assertLessEqual(abs_diff_masked_tensors(values_0, values_2[:1], mask_0, mask_2[:1]), 1e-4) + + def test_ppo_trainer_max_grad_norm(self): + """ + Test if the `max_grad_norm` feature works as expected + """ + # initialize dataset + dummy_dataset = self._init_dummy_dataset() + + self.ppo_config.max_grad_norm = 0.00001 + ppo_trainer = PPOTrainer( + config=self.ppo_config, + model=self.gpt2_model, + ref_model=None, + tokenizer=self.gpt2_tokenizer, + dataset=dummy_dataset, + ) + + dummy_dataloader = ppo_trainer.dataloader + + # train model with ppo + for query_tensor, response_tensor in dummy_dataloader: + # define a reward for response + # (this could be any reward such as human feedback or output from another model) + reward = [torch.tensor(1.0), torch.tensor(0.0)] + # train model + _ = ppo_trainer.step([q for q in query_tensor], [r for r in response_tensor], reward) + break + + # check gradients + for name, param in ppo_trainer.model.named_parameters(): + self.assertTrue(param.grad is not None, f"Parameter {name} has no gradient") + self.assertTrue( + torch.all(param.grad.abs() <= self.ppo_config.max_grad_norm), + f"Parameter {name} has a gradient larger than max_grad_norm", + ) + + def test_ppo_trainer_kl_penalty(self): + dummy_dataset = self._init_dummy_dataset() + + log_probs = torch.Tensor([[0.5, 0.2, 0.1], [0.6, 0.2, 0.1]]) + ref_log_probs = torch.Tensor([[0.4, 0.3, 0.0], [0.7, 0.1, 0.3]]) + + ppo_trainer = PPOTrainer( + config=self.ppo_config, + model=self.gpt2_model, + ref_model=None, + tokenizer=self.gpt2_tokenizer, + dataset=dummy_dataset, + ) + + expected_output = torch.Tensor([[0.1000, -0.1000, 0.1000], [-0.1000, 0.1000, -0.2000]]) + self.assertTrue(torch.allclose(ppo_trainer._kl_penalty(log_probs, ref_log_probs), expected_output)) + + self.ppo_config.kl_penalty = "abs" + ppo_trainer = PPOTrainer( + config=self.ppo_config, + model=self.gpt2_model, + ref_model=None, + tokenizer=self.gpt2_tokenizer, + dataset=dummy_dataset, + ) + + expected_output = torch.Tensor([[0.1000, 0.1000, 0.1000], [0.1000, 0.1000, 0.2000]]) + self.assertTrue(torch.allclose(ppo_trainer._kl_penalty(log_probs, ref_log_probs), expected_output)) + + self.ppo_config.kl_penalty = "mse" + ppo_trainer = PPOTrainer( + config=self.ppo_config, + model=self.gpt2_model, + ref_model=None, + tokenizer=self.gpt2_tokenizer, + dataset=dummy_dataset, + ) + + expected_output = torch.Tensor([[0.0050, 0.0050, 0.0050], [0.0050, 0.0050, 0.0200]]) + self.assertTrue(torch.allclose(ppo_trainer._kl_penalty(log_probs, ref_log_probs), expected_output)) + + def test_ppo_trainer_full_kl_penalty(self): + # a few more extensive tests for the full kl option as it is more involved + dummy_dataset = self._init_dummy_dataset() + + self.ppo_config.kl_penalty = "full" + ppo_trainer = PPOTrainer( + config=self.ppo_config, + model=self.gpt2_model, + ref_model=None, + tokenizer=self.gpt2_tokenizer, + dataset=dummy_dataset, + ) + + # Test on tensors for size B,S,T = (1,2,3) + # test for when the two dists are the same + log_probs = torch.Tensor( + [ + [ + [0.1, 0.2, 0.7], + [0.3, 0.4, 0.3], + ] + ] + ).exp() + + ref_log_probs = torch.Tensor( + [ + [ + [0.1, 0.2, 0.7], + [0.3, 0.4, 0.3], + ] + ] + ).exp() + + expected_output = torch.Tensor( + [[0.0, 0.0]], + ) + output = ppo_trainer._kl_penalty(log_probs, ref_log_probs) + self.assertTrue(output.shape == (1, 2)) + self.assertTrue(torch.allclose(output, expected_output)) + + # test for when the two dists are almost not overlapping + log_probs = torch.Tensor( + [ + [ + [0.98, 0.01, 0.01], + [0.01, 0.98, 0.01], + ] + ] + ).log() + + ref_log_probs = torch.Tensor( + [ + [ + [0.01, 0.01, 0.98], + [0.01, 0.01, 0.98], + ] + ] + ).log() + + expected_output = torch.Tensor( + [[4.4474, 4.4474]], + ) + output = ppo_trainer._kl_penalty(log_probs, ref_log_probs) + self.assertTrue(output.shape == (1, 2)) + self.assertTrue(torch.allclose(output, expected_output)) + + # test for when the two dists are almost not overlapping + log_probs = torch.Tensor( + [ + [ + [0.49, 0.02, 0.49], + [0.49, 0.02, 0.49], + ] + ] + ).log() + + ref_log_probs = torch.Tensor( + [ + [ + [0.01, 0.98, 0.01], + [0.49, 0.02, 0.49], + ] + ] + ).log() + + expected_output = torch.Tensor( + [[3.7361, 0.0]], + ) + output = ppo_trainer._kl_penalty(log_probs, ref_log_probs) + self.assertTrue(output.shape == (1, 2)) + self.assertTrue(torch.allclose(output, expected_output, atol=1e-4)) + + @require_peft + @mark.peft_test + def test_peft_model_ppo_trainer(self): + from peft import LoraConfig, get_peft_model + from transformers import AutoModelForCausalLM + + lora_config = LoraConfig( + r=16, + lora_alpha=32, + lora_dropout=0.05, + bias="none", + task_type="CAUSAL_LM", + ) + gpt2_model = AutoModelForCausalLM.from_pretrained(self.model_id) + + # this line is very important + def make_inputs_require_grad(module, input, output): + output.requires_grad_(True) + + gpt2_model.get_input_embeddings().register_forward_hook(make_inputs_require_grad) + + peft_model = get_peft_model(gpt2_model, lora_config) + model = AutoModelForCausalLMWithValueHead.from_pretrained(peft_model) + + dummy_dataset = self._init_dummy_dataset() + self.ppo_config.batch_size = 2 + self.ppo_config.mini_batch_size = 1 + + ppo_trainer = PPOTrainer( + config=self.ppo_config, + model=model, + ref_model=None, + tokenizer=self.gpt2_tokenizer, + dataset=dummy_dataset, + ) + + self.assertTrue(ppo_trainer.ref_model is None) + + dummy_dataloader = ppo_trainer.dataloader + + # train model with ppo + for query_tensor, response_tensor in dummy_dataloader: + # define a reward for response + # (this could be any reward such as human feedback or output from another model) + reward = [torch.tensor(1.0), torch.tensor(0.0)] + # train model by running a step twice + _ = ppo_trainer.step([q for q in query_tensor], [r for r in response_tensor], reward) + + ppo_trainer.model.train() + ppo_trainer.model.gradient_checkpointing_enable() + _ = ppo_trainer.step([q for q in query_tensor], [r for r in response_tensor], reward) + break + + # check gradients + for name, param in model.named_parameters(): + if "lora" in name or "v_head" in name: + self.assertTrue(param.grad is not None, f"Parameter {name} has a no gradient") + else: + self.assertTrue(param.grad is None, f"Parameter {name} has a gradient") + + @require_peft + @mark.peft_test + def test_peft_model_ppo_adapter_rm_trainer(self): + from peft import LoraConfig, get_peft_model + from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification + + dummy_inputs = torch.LongTensor([[1, 2, 3, 4, 5], [1, 2, 3, 4, 5]]) + rm_lora_config = LoraConfig( + r=16, + lora_alpha=32, + lora_dropout=0.05, + bias="none", + task_type="SEQ_CLS", + ) + + reward_model = AutoModelForSequenceClassification.from_pretrained(self.model_id) + reward_model = get_peft_model(reward_model, rm_lora_config) + dummy_optim = torch.optim.Adam(filter(lambda p: p.requires_grad, reward_model.parameters()), lr=1e-3) + + previous_rm_logits = reward_model(dummy_inputs).logits + loss = previous_rm_logits.mean() + loss.backward() + + dummy_optim.step() + reward_model.eval() + + original_rm_logits = reward_model(dummy_inputs).logits + + with tempfile.TemporaryDirectory() as tmpdirname: + reward_model.save_pretrained(tmpdirname) + + lora_config = LoraConfig( + r=16, + lora_alpha=32, + lora_dropout=0.05, + bias="none", + task_type="CAUSAL_LM", + ) + gpt2_model = AutoModelForCausalLM.from_pretrained(self.model_id) + + # this line is very important + def make_inputs_require_grad(module, input, output): + output.requires_grad_(True) + + gpt2_model.get_input_embeddings().register_forward_hook(make_inputs_require_grad) + + peft_model = get_peft_model(gpt2_model, lora_config) + model = AutoModelForCausalLMWithValueHead.from_pretrained( + peft_model, + reward_adapter=tmpdirname, + ) + + dummy_dataset = self._init_dummy_dataset() + self.ppo_config.batch_size = 2 + self.ppo_config.mini_batch_size = 1 + + ppo_trainer = PPOTrainer( + config=self.ppo_config, + model=model, + ref_model=None, + tokenizer=self.gpt2_tokenizer, + dataset=dummy_dataset, + ) + + self.assertTrue(ppo_trainer.ref_model is None) + + dummy_dataloader = ppo_trainer.dataloader + + # train model with ppo + for query_tensor, response_tensor in dummy_dataloader: + # define a reward for response + # (this could be any reward such as human feedback or output from another model) + reward = [torch.tensor(1.0), torch.tensor(0.0)] + # train model by running a step twice + _ = ppo_trainer.step([q for q in query_tensor], [r for r in response_tensor], reward) + + ppo_trainer.model.train() + ppo_trainer.model.gradient_checkpointing_enable() + _ = ppo_trainer.step([q for q in query_tensor], [r for r in response_tensor], reward) + break + + new_logits = ppo_trainer.model.compute_reward_score(dummy_inputs) + self.assertTrue(not torch.allclose(previous_rm_logits, new_logits[:, -1, :])) + self.assertTrue(torch.allclose(original_rm_logits, new_logits[:, -1, :])) + + # check gradients + for name, param in model.named_parameters(): + if ("lora" in name or "v_head" in name) and ("reward" not in name): + self.assertTrue(param.grad is not None, f"Parameter {name} has a no gradient") + else: + self.assertTrue(param.grad is None, f"Parameter {name} has a gradient") + + @unittest.skip("Fix by either patching `whomai()` to work in the staging endpoint or use a dummy prod user.") + def test_push_to_hub(self): + REPO_NAME = "test-ppo-trainer" + repo_id = f"{CI_HUB_USER}/{REPO_NAME}" + + ppo_trainer = PPOTrainer( + config=self.ppo_config, + model=self.gpt2_model, + ref_model=None, + tokenizer=self.gpt2_tokenizer, + dataset=self._init_dummy_dataset(), + ) + with tempfile.TemporaryDirectory(): + url = ppo_trainer.push_to_hub(repo_id=repo_id, token=self._token, api_endpoint=CI_HUB_ENDPOINT) + # Extract repo_name from the url + re_search = re.search(CI_HUB_ENDPOINT + r"/([^/]+/[^/]+)/", url) + self.assertTrue(re_search is not None) + hub_repo_id = re_search.groups()[0] + # Check we created a Hub repo + self.assertEqual(hub_repo_id, repo_id) + # Ensure all files are present + files = sorted(self._api.list_repo_files(hub_repo_id)) + assert all( + fnmatch.fnmatch(file, expected_file) + for file, expected_file in zip( + files, + [ + ".gitattributes", + "README.md", + "config.json", + "merges.txt", + "pytorch_model.bin", + "special_tokens_map.json", + "tokenizer_config.json", + "vocab.json", + ], + ) + ) + + @require_peft + @require_torch_multi_gpu + @mark.peft_test + def test_peft_model_ppo_trainer_multi_gpu(self): + from peft import LoraConfig, get_peft_model + from transformers import AutoModelForCausalLM + + lora_config = LoraConfig( + r=16, + lora_alpha=32, + lora_dropout=0.05, + bias="none", + task_type="CAUSAL_LM", + ) + gpt2_model = AutoModelForCausalLM.from_pretrained( + "gpt2", device_map="balanced", max_memory={0: "500MB", 1: "500MB"} + ) + + self.assertTrue(set(gpt2_model.hf_device_map.values()) == {0, 1}) + + # this line is very important + def make_inputs_require_grad(module, input, output): + output.requires_grad_(True) + + gpt2_model.get_input_embeddings().register_forward_hook(make_inputs_require_grad) + + peft_model = get_peft_model(gpt2_model, lora_config) + model = AutoModelForCausalLMWithValueHead.from_pretrained(peft_model) + + self.assertTrue(model.is_sequential_parallel) + + dummy_dataset = self._init_dummy_dataset() + self.ppo_config.batch_size = 2 + self.ppo_config.mini_batch_size = 1 + + ppo_trainer = PPOTrainer( + config=self.ppo_config, + model=model, + ref_model=None, + tokenizer=self.gpt2_tokenizer, + dataset=dummy_dataset, + ) + + self.assertTrue(ppo_trainer.ref_model is None) + + dummy_dataloader = ppo_trainer.dataloader + + # train model with ppo + for query_tensor, response_tensor in dummy_dataloader: + # define a reward for response + # (this could be any reward such as human feedback or output from another model) + reward = [torch.tensor(1.0), torch.tensor(0.0)] + # train model by running a step twice + _ = ppo_trainer.step([q for q in query_tensor], [r for r in response_tensor], reward) + + ppo_trainer.model.train() + ppo_trainer.model.gradient_checkpointing_enable() + _ = ppo_trainer.step([q for q in query_tensor], [r for r in response_tensor], reward) + break + + # check gradients + for name, param in model.named_parameters(): + if "lora" in name or "v_head" in name: + self.assertTrue(param.grad is not None, f"Parameter {name} has a no gradient") + else: + self.assertTrue(param.grad is None, f"Parameter {name} has a gradient") + + def test_generation(self): + dummy_dataset = self._init_dummy_dataset() + + model = AutoModelForCausalLMWithValueHead.from_pretrained("gpt2") + tokenizer = AutoTokenizer.from_pretrained("gpt2") + + ppo_trainer = PPOTrainer( + config=self.ppo_config, + model=model, + ref_model=None, + tokenizer=tokenizer, + dataset=dummy_dataset, + ) + + input_texts = ["this is a test", "this is another, longer test"] + + generation_kwargs = {"do_sample": False, "max_new_tokens": 4, "pad_token_id": tokenizer.eos_token_id} + + tokenizer.pad_token = tokenizer.eos_token + + model_inputs = [tokenizer(txt, return_tensors="pt").input_ids.squeeze() for txt in input_texts] + + generations_batched = ppo_trainer.generate(model_inputs, batch_size=2, **generation_kwargs) + generations_batched = tokenizer.batch_decode(generations_batched) + + generations_single = [ppo_trainer.generate(inputs, **generation_kwargs).squeeze() for inputs in model_inputs] + generations_single = tokenizer.batch_decode(generations_single) + + self.assertEqual(generations_single, generations_batched) + + def test_grad_accumulation(self): + dummy_dataset = self._init_dummy_dataset() + + torch.manual_seed(0) + gpt2_model = AutoModelForCausalLMWithValueHead.from_pretrained(self.model_id, summary_dropout_prob=0.0) + gpt2_model_clone = copy.deepcopy(gpt2_model) + + self.ppo_config.mini_batch_size = 2 + self.ppo_config.ppo_epochs = 1 + + ppo_trainer = PPOTrainer( + config=self.ppo_config, + model=gpt2_model, + ref_model=None, + tokenizer=self.gpt2_tokenizer, + dataset=dummy_dataset, + ) + + dummy_dataloader = ppo_trainer.dataloader + + # train model with ppo + for query_tensor, response_tensor in dummy_dataloader: + # define a reward for response + # (this could be any reward such as human feedback or output from another model) + reward = [torch.tensor(1.0), torch.tensor(1.0)] + # train model by running a step twice + _ = ppo_trainer.step([q for q in query_tensor], [r for r in response_tensor], reward) + break + + model_grad = gpt2_model.v_head.summary.weight + + self.ppo_config.mini_batch_size = 1 + self.ppo_config.gradient_accumulation_steps = 2 + + ppo_trainer = PPOTrainer( + config=self.ppo_config, + model=gpt2_model_clone, + ref_model=None, + tokenizer=self.gpt2_tokenizer, + dataset=dummy_dataset, + ) + + dummy_dataloader = ppo_trainer.dataloader + + # train model with ppo + for query_tensor, response_tensor in dummy_dataloader: + # define a reward for response + # (this could be any reward such as human feedback or output from another model) + reward = [torch.tensor(1.0), torch.tensor(1.0)] + # train model by running a step twice + _ = ppo_trainer.step([q for q in query_tensor], [r for r in response_tensor], reward) + break + + model_grad_acc = gpt2_model_clone.v_head.summary.weight + self.assertTrue(torch.allclose(model_grad_acc, model_grad, rtol=1e-3, atol=1e-3)) + + @unittest.skip("Fix by either patching `whomai()` to work in the staging endpoint or use a dummy prod user.") + def test_push_to_hub_if_best_reward(self): + REPO_NAME = "test-ppo-trainer" + repo_id = f"{CI_HUB_USER}/{REPO_NAME}" + + dummy_dataset = self._init_dummy_dataset() + + push_to_hub_if_best_kwargs = {"repo_id": repo_id} + + ppo_config = PPOConfig( + batch_size=2, + mini_batch_size=1, + log_with=None, + push_to_hub_if_best_kwargs=push_to_hub_if_best_kwargs, + compare_steps=1, + ) + + ppo_trainer = PPOTrainer( + config=ppo_config, + model=self.gpt2_model, + ref_model=self.gpt2_model_ref, + tokenizer=self.gpt2_tokenizer, + dataset=dummy_dataset, + ) + + dummy_dataloader = ppo_trainer.dataloader + # train model with ppo + for query_tensor, response_tensor in dummy_dataloader: + # define a reward for response + # (this could be any reward such as human feedback or output from another model) + reward = [torch.tensor(1.0), torch.tensor(0.0)] + # train model + _ = ppo_trainer.step([q for q in query_tensor], [r for r in response_tensor], reward) + break + + def test_batch_size_check(self): + with pytest.raises(ValueError): + PPOConfig(batch_size=2, mini_batch_size=2, gradient_accumulation_steps=2) diff --git a/trl/trl/tests/test_reward_trainer.py b/trl/trl/tests/test_reward_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..1cf63c945f140b4dce5791496349f2bd1bd9a1fb --- /dev/null +++ b/trl/trl/tests/test_reward_trainer.py @@ -0,0 +1,314 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import tempfile +import unittest + +import torch +from datasets import Dataset +from transformers import AutoModelForSequenceClassification, AutoTokenizer, EvalPrediction + +from trl import RewardConfig, RewardTrainer +from trl.trainer import compute_accuracy + +from .testing_utils import require_peft + + +class RewardTrainerTester(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.model_id = "trl-internal-testing/dummy-GPT2-correct-vocab" + cls.model = AutoModelForSequenceClassification.from_pretrained(cls.model_id) + cls.tokenizer = AutoTokenizer.from_pretrained(cls.model_id) + cls.tokenizer.pad_token = cls.tokenizer.eos_token + + def test_accuracy_metrics(self): + dummy_eval_predictions = EvalPrediction(torch.FloatTensor([[0.1, 0.9], [0.9, 0.1]]), torch.LongTensor([0, 0])) + accuracy = compute_accuracy(dummy_eval_predictions) + self.assertEqual(accuracy["accuracy"], 0.5) + + def test_reward_trainer(self): + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = RewardConfig( + output_dir=tmp_dir, + per_device_train_batch_size=2, + max_steps=3, + remove_unused_columns=False, + gradient_accumulation_steps=4, + learning_rate=9e-1, + evaluation_strategy="steps", + ) + + # fmt: off + dummy_dataset_dict = { + "input_ids_chosen": [ + torch.LongTensor([0, 1, 2,]), + torch.LongTensor([1, 2]), + torch.LongTensor([0, 1, 2,]), + torch.LongTensor([1, 2]), + ], + "attention_mask_chosen": [ + torch.LongTensor([1, 1, 1]), + torch.LongTensor([1, 0]), + torch.LongTensor([1, 1, 1]), + torch.LongTensor([1, 0]), + ], + "input_ids_rejected": [ + torch.LongTensor([0, 2,]), + torch.LongTensor([1, 2, 0]), + torch.LongTensor([0, 2,]), + torch.LongTensor([1, 2, 0]), + ], + "attention_mask_rejected": [ + torch.LongTensor([1, 1]), + torch.LongTensor([1, 1, 0]), + torch.LongTensor([1, 1]), + torch.LongTensor([1, 1, 1]), + ], + } + # fmt: on + dummy_dataset = Dataset.from_dict(dummy_dataset_dict) + + trainer = RewardTrainer( + model=self.model, + args=training_args, + tokenizer=self.tokenizer, + train_dataset=dummy_dataset, + eval_dataset=dummy_dataset, + ) + + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + trainer.train() + + self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + + # check the params have changed + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + # check the params have changed - ignore 0 biases + if param.sum() != 0: + self.assertFalse(torch.equal(param, new_param)) + + preds = trainer.predict(dummy_dataset) + self.assertEqual(preds.predictions.shape, (4, 2)) + + @require_peft + def test_reward_trainer_peft(self): + import peft + from peft import LoraConfig, TaskType + + peft_version = peft.__version__ + + peft_config = LoraConfig( + task_type=TaskType.SEQ_CLS, + inference_mode=False, + r=8, + lora_alpha=32, + lora_dropout=0.1, + ) + + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = RewardConfig( + output_dir=tmp_dir, + per_device_train_batch_size=2, + max_steps=6, + remove_unused_columns=False, + gradient_accumulation_steps=2, + learning_rate=9e-1, + evaluation_strategy="steps", + ) + + # fmt: off + dummy_dataset_dict = { + "input_ids_chosen": [ + torch.LongTensor([0, 1, 2,]), + torch.LongTensor([1, 2]), + torch.LongTensor([0, 1, 2,]), + torch.LongTensor([1, 2]), + ], + "attention_mask_chosen": [ + torch.LongTensor([1, 1, 1]), + torch.LongTensor([1, 0]), + torch.LongTensor([1, 1, 1]), + torch.LongTensor([1, 0]), + ], + "input_ids_rejected": [ + torch.LongTensor([0, 2,]), + torch.LongTensor([1, 2, 0]), + torch.LongTensor([0, 2,]), + torch.LongTensor([1, 2, 0]), + ], + "attention_mask_rejected": [ + torch.LongTensor([1, 1]), + torch.LongTensor([1, 1, 0]), + torch.LongTensor([1, 1]), + torch.LongTensor([1, 1, 1]), + ], + } + # fmt: on + dummy_dataset = Dataset.from_dict(dummy_dataset_dict) + + trainer = RewardTrainer( + model=self.model, + args=training_args, + tokenizer=self.tokenizer, + train_dataset=dummy_dataset, + eval_dataset=dummy_dataset, + peft_config=peft_config, + ) + previous_trainable_params = {} + previous_non_trainable_params = {} + + # due to a change in the way the modules to save are dealt in PEFT. + trainable_params_name = ["lora", "score"] if peft_version < "0.3.0" else ["lora", "modules_to_save"] + + # check gradients are not None + for n, param in trainer.model.named_parameters(): + if any([t in n for t in trainable_params_name]): + previous_trainable_params[n] = param.clone() + else: + previous_non_trainable_params[n] = param.clone() + + trainer.train() + + self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + + # check the params have changed + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + self.assertFalse(torch.allclose(param, new_param, atol=1e-12, rtol=1e-12)) + + # check the non trainable params have not changed + for n, param in previous_non_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + self.assertTrue(torch.allclose(param, new_param, atol=1e-12, rtol=1e-12)) + + preds = trainer.predict(dummy_dataset) + self.assertEqual(preds.predictions.shape, (4, 2)) + + def test_reward_trainer_assert_value_error(self): + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = RewardConfig( + output_dir=tmp_dir, + per_device_train_batch_size=2, + max_steps=1, + remove_unused_columns=False, + ) + + dummy_dataset_dict = { + # fmt: off + "input_ids_b": [ + torch.LongTensor([0, 1, 2,]), + torch.LongTensor([1, 2]), + torch.LongTensor([0, 1, 2,]), + torch.LongTensor([1, 2]), + ], + "attention_mask_c": [ + torch.LongTensor([1, 1, 1]), + torch.LongTensor([1, 0]), + torch.LongTensor([1, 1, 1]), + torch.LongTensor([1, 0]), + ], + "input_ids_f": [ + torch.LongTensor([0, 2,]), + torch.LongTensor([1, 2, 0]), + torch.LongTensor([0, 2,]), + torch.LongTensor([1, 2, 0]), + ], + "attention_mask_g": [ + torch.LongTensor([1, 1]), + torch.LongTensor([1, 1, 0]), + torch.LongTensor([1, 1]), + torch.LongTensor([1, 1, 1]), + ], + # fmt: on + } + dummy_dataset = Dataset.from_dict(dummy_dataset_dict) + + trainer = RewardTrainer( + model=self.model, + args=training_args, + tokenizer=self.tokenizer, + train_dataset=dummy_dataset, + ) + + with self.assertRaises(ValueError): + trainer.train() + + training_args = RewardConfig( + output_dir=tmp_dir, + per_device_train_batch_size=2, + max_steps=1, + remove_unused_columns=True, + ) + + with self.assertWarns(UserWarning): + trainer = RewardTrainer( + model=self.model, + args=training_args, + tokenizer=self.tokenizer, + train_dataset=dummy_dataset, + ) + + def test_reward_trainer_margin(self): + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = RewardConfig( + output_dir=tmp_dir, + per_device_train_batch_size=2, + max_steps=3, + remove_unused_columns=False, + gradient_accumulation_steps=4, + learning_rate=9e-1, + evaluation_strategy="steps", + ) + + # fmt: off + dummy_dataset_dict = { + "input_ids_chosen": [ + torch.LongTensor([0, 1, 2,]), + ], + "attention_mask_chosen": [ + torch.LongTensor([1, 1, 1]), + ], + "input_ids_rejected": [ + torch.LongTensor([0, 2,]), + ], + "attention_mask_rejected": [ + torch.LongTensor([1, 1]), + ], + "margin": [ + torch.FloatTensor([1.0]), + ] + } + # fmt: on + dummy_dataset = Dataset.from_dict(dummy_dataset_dict) + + trainer = RewardTrainer( + model=self.model, + args=training_args, + tokenizer=self.tokenizer, + train_dataset=dummy_dataset, + eval_dataset=dummy_dataset, + ) + + batch = [dummy_dataset[0]] + batch = trainer.data_collator(batch) + loss, outputs = trainer.compute_loss(trainer.model, batch, return_outputs=True) + + self.assertAlmostEqual( + loss, + -torch.nn.functional.logsigmoid( + outputs["rewards_chosen"] - outputs["rewards_rejected"] - batch["margin"] + ).mean(), + ) diff --git a/trl/trl/tests/test_sft_trainer.py b/trl/trl/tests/test_sft_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..f430c7b48618a9eded0599ef2e1a53a0363dd7ef --- /dev/null +++ b/trl/trl/tests/test_sft_trainer.py @@ -0,0 +1,791 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import copy +import os +import tempfile +import unittest + +import numpy as np +import torch +from datasets import Dataset +from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments + +from trl import SFTTrainer +from trl.import_utils import is_peft_available +from trl.trainer import ConstantLengthDataset, DataCollatorForCompletionOnlyLM + +from .testing_utils import require_peft + + +def formatting_prompts_func(example): + text = f"### Question: {example['question']}\n ### Answer: {example['answer']}" + return text + + +def formatting_prompts_func_batched(example): + output_text = [] + for i, question in enumerate(example["question"]): + text = f"### Question: {question}\n ### Answer: {example['answer'][i]}" + output_text.append(text) + return output_text + + +if is_peft_available(): + from peft import LoraConfig, PeftModel + + +class SFTTrainerTester(unittest.TestCase): + r""" """ + + @classmethod + def setUpClass(cls): + cls.model_id = "trl-internal-testing/dummy-GPT2-correct-vocab" + cls.model = AutoModelForCausalLM.from_pretrained(cls.model_id) + cls.tokenizer = AutoTokenizer.from_pretrained(cls.model_id) + cls.tokenizer.pad_token = cls.tokenizer.eos_token + cls.dummy_dataset = Dataset.from_dict( + { + "question": [ + "Does llamas know how to code?", + "Does llamas know how to fly?", + "Does llamas know how to talk?", + "Does llamas know how to code?", + "Does llamas know how to fly?", + "Does llamas know how to talk?", + "Does llamas know how to swim?", + ], + "answer": [ + "Yes, llamas are very good at coding.", + "No, llamas can't fly.", + "Yes, llamas are very good at talking.", + "Yes, llamas are very good at coding.", + "No, llamas can't fly.", + "Yes, llamas are very good at talking.", + "No, llamas can't swim.", + ], + "text": [ + "### Question: Does llamas know how to code?\n ### Answer: Yes, llamas are very good at coding.", + "### Question: Does llamas know how to fly?\n ### Answer: No, llamas can't fly.", + "### Question: Does llamas know how to talk?\n ### Answer: Yes, llamas are very good at talking.", + "### Question: Does llamas know how to code?\n ### Answer: Yes, llamas are very good at coding.", + "### Question: Does llamas know how to fly?\n ### Answer: No, llamas can't fly.", + "### Question: Does llamas know how to talk?\n ### Answer: Yes, llamas are very good at talking.", + "### Question: Does llamas know how to swim?\n ### Answer: No, llamas can't swim.", + ], + } + ) + + cls.train_dataset = ConstantLengthDataset( + cls.tokenizer, + cls.dummy_dataset, + dataset_text_field=None, + formatting_func=formatting_prompts_func, + seq_length=16, + num_of_sequences=16, + ) + + cls.eval_dataset = ConstantLengthDataset( + cls.tokenizer, + cls.dummy_dataset, + dataset_text_field=None, + formatting_func=formatting_prompts_func, + seq_length=16, + num_of_sequences=16, + ) + + def test_constant_length_dataset(self): + formatted_dataset = ConstantLengthDataset( + self.tokenizer, + self.dummy_dataset, + dataset_text_field=None, + formatting_func=formatting_prompts_func, + ) + + self.assertTrue(len(formatted_dataset) == len(self.dummy_dataset)) + self.assertTrue(len(formatted_dataset) > 0) + + for example in formatted_dataset: + self.assertTrue("input_ids" in example) + self.assertTrue("labels" in example) + + self.assertTrue(len(example["input_ids"]) == formatted_dataset.seq_length) + self.assertTrue(len(example["labels"]) == formatted_dataset.seq_length) + + decoded_text = self.tokenizer.decode(example["input_ids"]) + self.assertTrue(("Question" in decoded_text) and ("Answer" in decoded_text)) + + def test_sft_trainer(self): + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = TrainingArguments( + output_dir=tmp_dir, + dataloader_drop_last=True, + evaluation_strategy="steps", + max_steps=4, + eval_steps=2, + save_steps=2, + per_device_train_batch_size=2, + ) + + trainer = SFTTrainer( + model=self.model_id, + args=training_args, + train_dataset=self.train_dataset, + eval_dataset=self.eval_dataset, + packing=True, + ) + + trainer.train() + + self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + self.assertIsNotNone(trainer.state.log_history[0]["eval_loss"]) + + self.assertTrue("model.safetensors" in os.listdir(tmp_dir + "/checkpoint-2")) + + def test_sft_trainer_uncorrect_data(self): + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = TrainingArguments( + output_dir=tmp_dir, + dataloader_drop_last=True, + evaluation_strategy="steps", + max_steps=2, + eval_steps=1, + save_steps=1, + per_device_train_batch_size=2, + ) + + with self.assertRaises(ValueError): + _ = SFTTrainer( + model=self.model, + args=training_args, + train_dataset=self.dummy_dataset, + packing=True, + ) + + # This should work + _ = SFTTrainer( + model=self.model, + args=training_args, + train_dataset=self.dummy_dataset, + formatting_func=formatting_prompts_func, + max_seq_length=32, # make sure there is at least 1 packed sequence + packing=True, + ) + + with self.assertRaises(ValueError): + # This should not work because not enough data for one sample + _ = SFTTrainer( + model=self.model, + args=training_args, + train_dataset=self.dummy_dataset, + formatting_func=formatting_prompts_func, + max_seq_length=1024, # make sure there is NOT at least 1 packed sequence + packing=True, + ) + + # This should not work as well + with self.assertRaises(ValueError): + _ = SFTTrainer( + model=self.model, + args=training_args, + train_dataset=self.dummy_dataset, + formatting_func=formatting_prompts_func, + packing=False, + ) + + # but this should work + _ = SFTTrainer( + model=self.model, + args=training_args, + train_dataset=self.dummy_dataset, + formatting_func=formatting_prompts_func_batched, + packing=False, + ) + + def test_sft_trainer_with_model_num_train_epochs(self): + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = TrainingArguments( + output_dir=tmp_dir, + dataloader_drop_last=True, + evaluation_strategy="steps", + max_steps=2, + eval_steps=1, + save_steps=1, + num_train_epochs=2, + per_device_train_batch_size=2, + ) + + trainer = SFTTrainer( + model=self.model, + args=training_args, + train_dataset=self.train_dataset, + eval_dataset=self.eval_dataset, + packing=True, + ) + + trainer.train() + + self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + self.assertIsNotNone(trainer.state.log_history[0]["eval_loss"]) + + self.assertTrue("model.safetensors" in os.listdir(tmp_dir + "/checkpoint-2")) + + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = TrainingArguments( + output_dir=tmp_dir, + dataloader_drop_last=True, + evaluation_strategy="steps", + max_steps=2, + save_steps=1, + num_train_epochs=2, + per_device_train_batch_size=2, + ) + + trainer = SFTTrainer( + model=self.model, + args=training_args, + train_dataset=self.dummy_dataset, + dataset_text_field="text", + max_seq_length=16, + num_of_sequences=16, + packing=True, + ) + + trainer.train() + + self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + + self.assertTrue("model.safetensors" in os.listdir(tmp_dir + "/checkpoint-2")) + + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = TrainingArguments( + output_dir=tmp_dir, + dataloader_drop_last=True, + evaluation_strategy="steps", + max_steps=2, + save_steps=1, + num_train_epochs=2, + per_device_train_batch_size=2, + ) + + trainer = SFTTrainer( + model=self.model, + args=training_args, + train_dataset=self.dummy_dataset, + dataset_text_field="text", + max_seq_length=16, + ) + + trainer.train() + + self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + + self.assertTrue("model.safetensors" in os.listdir(tmp_dir + "/checkpoint-1")) + + def test_sft_trainer_with_model(self): + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = TrainingArguments( + output_dir=tmp_dir, + dataloader_drop_last=True, + evaluation_strategy="steps", + max_steps=2, + eval_steps=1, + save_steps=1, + per_device_train_batch_size=2, + ) + + trainer = SFTTrainer( + model=self.model, + args=training_args, + train_dataset=self.train_dataset, + eval_dataset=self.eval_dataset, + packing=True, + ) + + trainer.train() + + self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + self.assertIsNotNone(trainer.state.log_history[0]["eval_loss"]) + + self.assertTrue("model.safetensors" in os.listdir(tmp_dir + "/checkpoint-2")) + + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = TrainingArguments( + output_dir=tmp_dir, + dataloader_drop_last=True, + evaluation_strategy="steps", + max_steps=2, + save_steps=1, + per_device_train_batch_size=2, + ) + + trainer = SFTTrainer( + model=self.model, + args=training_args, + train_dataset=self.dummy_dataset, + dataset_text_field="text", + max_seq_length=16, + num_of_sequences=16, + packing=True, + ) + + trainer.train() + + self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + + self.assertTrue("model.safetensors" in os.listdir(tmp_dir + "/checkpoint-2")) + + # with formatting_func + packed + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = TrainingArguments( + output_dir=tmp_dir, + dataloader_drop_last=True, + evaluation_strategy="steps", + max_steps=2, + save_steps=1, + per_device_train_batch_size=2, + ) + + trainer = SFTTrainer( + model=self.model, + args=training_args, + train_dataset=self.dummy_dataset, + formatting_func=formatting_prompts_func, + max_seq_length=16, + num_of_sequences=16, + packing=True, + ) + + trainer.train() + + self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + + self.assertTrue("model.safetensors" in os.listdir(tmp_dir + "/checkpoint-2")) + + # with formatting_func + packed + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = TrainingArguments( + output_dir=tmp_dir, + dataloader_drop_last=True, + evaluation_strategy="steps", + max_steps=2, + save_steps=1, + per_device_train_batch_size=2, + ) + + trainer = SFTTrainer( + model=self.model, + args=training_args, + train_dataset=self.dummy_dataset, + formatting_func=formatting_prompts_func_batched, + max_seq_length=16, + ) + + trainer.train() + + self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + + self.assertTrue("model.safetensors" in os.listdir(tmp_dir + "/checkpoint-2")) + + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = TrainingArguments( + output_dir=tmp_dir, + dataloader_drop_last=True, + evaluation_strategy="steps", + max_steps=2, + save_steps=1, + per_device_train_batch_size=2, + ) + + trainer = SFTTrainer( + model=self.model, + args=training_args, + train_dataset=self.dummy_dataset, + dataset_text_field="text", + max_seq_length=16, + ) + + trainer.train() + + self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + + self.assertTrue("model.safetensors" in os.listdir(tmp_dir + "/checkpoint-1")) + + def test_sft_trainer_with_multiple_eval_datasets(self): + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = TrainingArguments( + output_dir=tmp_dir, + dataloader_drop_last=True, + evaluation_strategy="steps", + max_steps=1, + eval_steps=1, + save_steps=1, + per_device_train_batch_size=2, + ) + + trainer = SFTTrainer( + model=self.model_id, + args=training_args, + train_dataset=self.train_dataset, + eval_dataset={ + "data1": self.eval_dataset, + "data2": self.eval_dataset, + }, + packing=True, + ) + + trainer.train() + + self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + self.assertIsNotNone(trainer.state.log_history[0]["eval_data1_loss"]) + self.assertIsNotNone(trainer.state.log_history[1]["eval_data2_loss"]) + + self.assertTrue("model.safetensors" in os.listdir(tmp_dir + "/checkpoint-1")) + + def test_data_collator_completion_lm(self): + response_template = "### Response:\n" + data_collator = DataCollatorForCompletionOnlyLM(response_template, tokenizer=self.tokenizer, mlm=False) + + text = """\n\n### Instructions:\nHello all this should be masked\n\n### Response:\nI have not been masked correctly.""" + encoded_text = self.tokenizer(text) + + examples = [encoded_text] + + batch = data_collator(examples) + labels = batch["labels"] + last_pad_idx = np.where(labels == -100)[1][-1] + result_text = self.tokenizer.decode(batch["input_ids"][0, last_pad_idx + 1 :]) + self.assertEqual(result_text, "I have not been masked correctly.") + + def test_data_collator_completion_lm_with_multiple_text(self): + tokenizer = copy.deepcopy(self.tokenizer) + tokenizer.padding_side = "left" + + response_template = "### Response:\n" + data_collator = DataCollatorForCompletionOnlyLM(response_template, tokenizer=tokenizer, mlm=False) + + text1 = """\n\n### Instructions:\nHello all this should be masked\n\n### Response:\nI have not been masked correctly.""" + text2 = """\n\n### Instructions:\nThis is another longer text that should also be masked. This text is significantly longer than the previous one.\n\n### Response:\nI have not been masked correctly.""" + + encoded_text1 = tokenizer(text1) + encoded_text2 = tokenizer(text2) + + examples = [encoded_text1, encoded_text2] + + batch = data_collator(examples) + + for i in range(2): + labels = batch["labels"][i] + last_pad_idx = np.where(labels == -100)[0][-1] + result_text = tokenizer.decode(batch["input_ids"][i, last_pad_idx + 1 :]) + self.assertEqual(result_text, "I have not been masked correctly.") + + def test_data_collator_chat_completion_lm(self): + instruction_template = "### Human:" + assistant_template = "### Assistant:" + data_collator = DataCollatorForCompletionOnlyLM( + response_template=assistant_template, + instruction_template=instruction_template, + tokenizer=self.tokenizer, + mlm=False, + ) + + text = """### Human: Hello all this should be masked.### Assistant: I should not be masked.### Human: All this should be masked too.### Assistant: I should not be masked too.""" + encoded_text = self.tokenizer(text) + + examples = [encoded_text] + + batch = data_collator(examples) + labels = batch["labels"] + non_masked_tokens = batch["input_ids"][labels != -100] + result_text = self.tokenizer.decode(non_masked_tokens) + self.assertEqual(result_text, " I should not be masked. I should not be masked too.") + + def test_data_collator_chat_completion_lm_with_multiple_text(self): + tokenizer = copy.deepcopy(self.tokenizer) + tokenizer.padding_side = "left" + + instruction_template = "### Human:" + assistant_template = "### Assistant:" + data_collator = DataCollatorForCompletionOnlyLM( + response_template=assistant_template, + instruction_template=instruction_template, + tokenizer=tokenizer, + mlm=False, + ) + + text1 = """### Human: Hello all this should be masked.### Assistant: I should not be masked.""" + text2 = """### Human: Hello all this should be masked.### Assistant: I should not be masked.### Human: All this should be masked too.### Assistant: I should not be masked too.""" + encoded_text1 = tokenizer(text1) + encoded_text2 = tokenizer(text2) + + examples = [encoded_text1, encoded_text2] + + batch = data_collator(examples) + labels = batch["labels"] + input_ids = batch["input_ids"] + + non_masked_tokens1 = input_ids[0][labels[0] != -100] + result_text1 = tokenizer.decode(non_masked_tokens1) + self.assertEqual(result_text1, " I should not be masked.") + + non_masked_tokens2 = input_ids[1][labels[1] != -100] + result_text2 = tokenizer.decode(non_masked_tokens2) + self.assertEqual(result_text2, " I should not be masked. I should not be masked too.") + + def test_sft_trainer_infinite_with_model(self): + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = TrainingArguments( + output_dir=tmp_dir, + dataloader_drop_last=True, + evaluation_strategy="steps", + max_steps=5, + eval_steps=1, + save_steps=1, + per_device_train_batch_size=2, + ) + + trainer = SFTTrainer( + model=self.model, + args=training_args, + train_dataset=self.train_dataset, + eval_dataset=self.eval_dataset, + packing=True, + max_seq_length=500, + ) + + self.assertTrue(trainer.train_dataset.infinite) + + trainer.train() + + self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + self.assertIsNotNone(trainer.state.log_history[0]["eval_loss"]) + + # make sure the trainer did 5 steps + self.assertTrue("model.safetensors" in os.listdir(tmp_dir + "/checkpoint-5")) + + def test_sft_trainer_infinite_with_model_epochs(self): + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = TrainingArguments( + output_dir=tmp_dir, + dataloader_drop_last=True, + num_train_epochs=1, + per_device_train_batch_size=2, + save_strategy="epoch", + ) + + trainer = SFTTrainer( + model=self.model, + args=training_args, + train_dataset=self.train_dataset, + eval_dataset=self.eval_dataset, + packing=True, + max_seq_length=500, + ) + + self.assertFalse(trainer.train_dataset.infinite) + + trainer.train() + + self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + + # make sure the trainer did 5 steps + self.assertTrue("model.safetensors" in os.listdir(tmp_dir + "/checkpoint-4")) + + def test_sft_trainer_with_model_neftune(self): + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = TrainingArguments( + output_dir=tmp_dir, + dataloader_drop_last=True, + evaluation_strategy="steps", + max_steps=2, + eval_steps=1, + save_steps=1, + per_device_train_batch_size=2, + ) + + trainer = SFTTrainer( + model=self.model, + args=training_args, + train_dataset=self.train_dataset, + eval_dataset=self.eval_dataset, + neftune_noise_alpha=5, + packing=True, + ) + + trainer.model = trainer._trl_activate_neftune(trainer.model) + + device = trainer.model.get_input_embeddings().weight.device + trainer.model.train() + + torch.random.manual_seed(42) + embeds_neftune = trainer.model.get_input_embeddings()(torch.LongTensor([[1, 0, 1]]).to(device)) + + torch.random.manual_seed(24) + embeds_neftune_2 = trainer.model.get_input_embeddings()(torch.LongTensor([[1, 0, 1]]).to(device)) + + self.assertFalse(torch.allclose(embeds_neftune, embeds_neftune_2)) + self.assertTrue(len(trainer.model.get_input_embeddings()._forward_hooks) > 0) + + trainer.neftune_hook_handle.remove() + + trainer.train() + + # Make sure forward pass works fine + _ = trainer.model(torch.LongTensor([[1, 0, 1]]).to(device)) + self.assertTrue(len(trainer.model.get_input_embeddings()._forward_hooks) == 0) + + @require_peft + def test_peft_sft_trainer(self): + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = TrainingArguments( + output_dir=tmp_dir, + dataloader_drop_last=True, + evaluation_strategy="steps", + max_steps=4, + eval_steps=2, + save_steps=2, + per_device_train_batch_size=2, + ) + + peft_config = LoraConfig( + r=16, + lora_alpha=32, + lora_dropout=0.05, + bias="none", + task_type="CAUSAL_LM", + ) + + trainer = SFTTrainer( + model=self.model_id, + args=training_args, + train_dataset=self.train_dataset, + eval_dataset=self.eval_dataset, + peft_config=peft_config, + packing=True, + ) + + self.assertTrue(isinstance(trainer.model, PeftModel)) + + trainer.train() + + self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + self.assertIsNotNone(trainer.state.log_history[0]["eval_loss"]) + + self.assertTrue("adapter_model.safetensors" in os.listdir(tmp_dir + "/checkpoint-2")) + self.assertTrue("adapter_config.json" in os.listdir(tmp_dir + "/checkpoint-2")) + self.assertTrue("model.safetensors" not in os.listdir(tmp_dir + "/checkpoint-2")) + + @require_peft + def test_peft_sft_trainer_gc(self): + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = TrainingArguments( + output_dir=tmp_dir, + dataloader_drop_last=True, + evaluation_strategy="steps", + max_steps=4, + eval_steps=2, + save_steps=2, + per_device_train_batch_size=2, + gradient_checkpointing=True, + ) + + peft_config = LoraConfig( + r=16, + lora_alpha=32, + lora_dropout=0.05, + bias="none", + task_type="CAUSAL_LM", + ) + + trainer = SFTTrainer( + model=self.model_id, + args=training_args, + train_dataset=self.train_dataset, + eval_dataset=self.eval_dataset, + peft_config=peft_config, + packing=True, + ) + + self.assertTrue(isinstance(trainer.model, PeftModel)) + + trainer.train() + + self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + self.assertIsNotNone(trainer.state.log_history[0]["eval_loss"]) + + self.assertTrue("adapter_model.safetensors" in os.listdir(tmp_dir + "/checkpoint-2")) + self.assertTrue("adapter_config.json" in os.listdir(tmp_dir + "/checkpoint-2")) + self.assertTrue("model.safetensors" not in os.listdir(tmp_dir + "/checkpoint-2")) + + @require_peft + def test_peft_sft_trainer_neftune(self): + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = TrainingArguments( + output_dir=tmp_dir, + dataloader_drop_last=True, + evaluation_strategy="steps", + max_steps=4, + eval_steps=2, + save_steps=2, + per_device_train_batch_size=2, + ) + + peft_config = LoraConfig( + r=16, + lora_alpha=32, + lora_dropout=0.05, + bias="none", + task_type="CAUSAL_LM", + ) + + trainer = SFTTrainer( + model=self.model_id, + args=training_args, + train_dataset=self.train_dataset, + eval_dataset=self.eval_dataset, + peft_config=peft_config, + neftune_noise_alpha=5, + packing=True, + ) + + trainer.model = trainer._trl_activate_neftune(trainer.model) + + self.assertTrue(isinstance(trainer.model, PeftModel)) + + device = trainer.model.get_input_embeddings().weight.device + trainer.model.train() + + torch.random.manual_seed(42) + embeds_neftune = trainer.model.get_input_embeddings()(torch.LongTensor([[1, 0, 1]]).to(device)) + + torch.random.manual_seed(24) + embeds_neftune_2 = trainer.model.get_input_embeddings()(torch.LongTensor([[1, 0, 1]]).to(device)) + + self.assertFalse(torch.allclose(embeds_neftune, embeds_neftune_2)) + self.assertTrue(len(trainer.model.get_input_embeddings()._forward_hooks) > 0) + + trainer.neftune_hook_handle.remove() + + trainer.train() + + self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + self.assertIsNotNone(trainer.state.log_history[0]["eval_loss"]) + + self.assertTrue("adapter_model.safetensors" in os.listdir(tmp_dir + "/checkpoint-2")) + self.assertTrue("adapter_config.json" in os.listdir(tmp_dir + "/checkpoint-2")) + self.assertTrue("model.safetensors" not in os.listdir(tmp_dir + "/checkpoint-2")) + + # Make sure forward pass works fine to check if embeddings forward is not broken. + _ = trainer.model(torch.LongTensor([[1, 0, 1]]).to(device)) + self.assertTrue(len(trainer.model.get_input_embeddings()._forward_hooks) == 0) diff --git a/trl/trl/tests/testing_constants.py b/trl/trl/tests/testing_constants.py new file mode 100644 index 0000000000000000000000000000000000000000..164782130add699864c8312ff82596ecd5eb87d1 --- /dev/null +++ b/trl/trl/tests/testing_constants.py @@ -0,0 +1,19 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +CI_HUB_USER = "__DUMMY_TRANSFORMERS_USER__" +CI_HUB_USER_FULL_NAME = "Dummy User" +CI_HUB_USER_TOKEN = "hf_94wBhPGp6KrrTH3KDchhKpRxZwd6dmHWLL" + +CI_HUB_ENDPOINT = "https://hub-ci.huggingface.co" diff --git a/trl/trl/tests/testing_utils.py b/trl/trl/tests/testing_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..f3988de4c9f2498ac796b9b443077583e06a44cf --- /dev/null +++ b/trl/trl/tests/testing_utils.py @@ -0,0 +1,84 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import unittest + +import torch + +from trl import is_diffusers_available, is_peft_available, is_wandb_available, is_xpu_available + + +def require_peft(test_case): + """ + Decorator marking a test that requires peft. Skips the test if peft is not available. + """ + if not is_peft_available(): + test_case = unittest.skip("test requires peft")(test_case) + return test_case + + +def require_diffusers(test_case): + """ + Decorator marking a test that requires diffusers. Skips the test if diffusers is not available. + """ + if not is_diffusers_available(): + test_case = unittest.skip("test requires diffusers")(test_case) + return test_case + + +def require_wandb(test_case, required: bool = True): + """ + Decorator marking a test that requires wandb. Skips the test if wandb is not available. + """ + # XOR, i.e.: + # skip if available and required = False and + # skip if not available and required = True + if is_wandb_available() ^ required: + test_case = unittest.skip("test requires wandb")(test_case) + return test_case + + +def require_no_wandb(test_case): + """ + Decorator marking a test that requires no wandb. Skips the test if wandb is available. + """ + return require_wandb(test_case, required=False) + + +def require_bitsandbytes(test_case): + """ + Decorator marking a test that requires bitsandbytes. Skips the test if bitsandbytes is not available. + """ + try: + import bitsandbytes # noqa: F401 + except ImportError: + test_case = unittest.skip("test requires bitsandbytes")(test_case) + return test_case + + +def require_torch_multi_gpu(test_case): + """ + Decorator marking a test that requires multiple GPUs. Skips the test if there aren't enough GPUs. + """ + if torch.cuda.device_count() < 2: + test_case = unittest.skip("test requires multiple GPUs")(test_case) + return test_case + + +def require_torch_multi_xpu(test_case): + """ + Decorator marking a test that requires multiple XPUs. Skips the test if there aren't enough XPUs. + """ + if torch.xpu.device_count() < 2 and is_xpu_available(): + test_case = unittest.skip("test requires multiple XPUs")(test_case) + return test_case diff --git a/trl/trl/trl/__init__.py b/trl/trl/trl/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a00f3a684db658bf5241025978aaf31773180938 --- /dev/null +++ b/trl/trl/trl/__init__.py @@ -0,0 +1,34 @@ +# flake8: noqa + +__version__ = "0.7.5.dev0" + +from .core import set_seed +from .environment import TextEnvironment, TextHistory +from .extras import BestOfNSampler +from .import_utils import is_diffusers_available, is_peft_available, is_wandb_available, is_xpu_available +from .models import ( + AutoModelForCausalLMWithValueHead, + AutoModelForSeq2SeqLMWithValueHead, + PreTrainedModelWrapper, + create_reference_model, +) +from .trainer import ( + DataCollatorForCompletionOnlyLM, + DPOTrainer, + IterativeSFTTrainer, + PPOConfig, + PPOTrainer, + RewardConfig, + RewardTrainer, + SFTTrainer, +) + + +if is_diffusers_available(): + from .models import ( + DDPOPipelineOutput, + DDPOSchedulerOutput, + DDPOStableDiffusionPipeline, + DefaultDDPOStableDiffusionPipeline, + ) + from .trainer import DDPOConfig, DDPOTrainer diff --git a/trl/trl/trl/core.py b/trl/trl/trl/core.py new file mode 100644 index 0000000000000000000000000000000000000000..3180fa69ed76b1632427797a49d5177306246e69 --- /dev/null +++ b/trl/trl/trl/core.py @@ -0,0 +1,328 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import gc +import random +import warnings +from contextlib import contextmanager +from typing import List, Optional, Tuple, Union + +import numpy as np +import torch +import torch.nn.functional as F +from torch.nn.utils.rnn import pad_sequence +from transformers import top_k_top_p_filtering + +from .import_utils import is_xpu_available + + +try: + from collections.abc import Mapping +except ImportError: + from collections import Mapping + + +WANDB_PADDING = -1 + + +def flatten_dict(nested, sep="/"): + """Flatten dictionary and concatenate nested keys with separator.""" + + def rec(nest, prefix, into): + for k, v in nest.items(): + if sep in k: + raise ValueError(f"separator '{sep}' not allowed to be in key '{k}'") + if isinstance(v, Mapping): + rec(v, prefix + k + sep, into) + else: + into[prefix + k] = v + + flat = {} + rec(nested, "", flat) + return flat + + +def convert_to_scalar(stats): + """ + Converts the stats from a flattened dict to single scalar dicts + """ + tensorboard_stats = {} + for k, v in stats.items(): + # for tensorboard compatibility - arrays and tensors are ignored with tensorboard + # therefore we convert single element tensors to scalars + if (isinstance(v, torch.Tensor) or isinstance(v, np.ndarray)) and ( + len(v.shape) == 0 or (len(v.shape) == 1 and v.shape[0] == 1) + ): + v = v.item() + tensorboard_stats[k] = v + return tensorboard_stats + + +def stack_dicts(stats_dicts): + """Stack the values of a dict.""" + results = dict() + for k in stats_dicts[0]: + stats_list = [torch.flatten(d[k]) for d in stats_dicts] + results[k] = pad_sequence(stats_list, batch_first=True, padding_value=WANDB_PADDING) + return results + + +def add_suffix(input_dict, suffix): + """Add suffix to dict keys.""" + return dict((k + suffix, v) for k, v in input_dict.items()) + + +def pad_to_size(tensor, size, dim=1, padding=50256): + """Pad tensor to size.""" + t_size = tensor.size()[dim] + if t_size == size: + return tensor + else: + return torch.nn.functional.pad(tensor, (0, size - t_size), "constant", padding) + + +def logprobs_from_logits(logits, labels, gather=True): + """ + See: https://github.com/pytorch/pytorch/issues/563#issuecomment-330103591 + """ + logp = F.log_softmax(logits, dim=2) + + if not gather: + return logp + logpy = torch.gather(logp, 2, labels.unsqueeze(2)).squeeze(-1) + return logpy + + +def whiten(values, shift_mean=True): + """Whiten values.""" + mean, var = torch.mean(values), torch.var(values) + whitened = (values - mean) * torch.rsqrt(var + 1e-8) + if not shift_mean: + whitened += mean + return whitened + + +def masked_mean(values, mask, axis=None): + """Compute mean of tensor with a masked values.""" + if axis is not None: + return (values * mask).sum(axis=axis) / mask.sum(axis=axis) + else: + return (values * mask).sum() / mask.sum() + + +def masked_var(values, mask, unbiased=True): + """Compute variance of tensor with masked values.""" + mean = masked_mean(values, mask) + centered_values = values - mean + variance = masked_mean(centered_values**2, mask) + if unbiased: + mask_sum = mask.sum() + if mask_sum == 0: + raise ValueError( + "The sum of the mask is zero, which can happen when `mini_batch_size=1`;" + "try increase the `mini_batch_size` or `gradient_accumulation_steps`" + ) + # note that if mask_sum == 1, then there is a division by zero issue + # to avoid it you just need to use a larger minibatch_size + bessel_correction = mask_sum / (mask_sum - 1) + variance = variance * bessel_correction + return variance + + +def masked_whiten(values, mask, shift_mean=True): + """Whiten values with masked values.""" + mean, var = masked_mean(values, mask), masked_var(values, mask) + whitened = (values - mean) * torch.rsqrt(var + 1e-8) + if not shift_mean: + whitened += mean + return whitened + + +def clip_by_value(x, tensor_min, tensor_max): + """ + Tensor extenstion to torch.clamp + https://github.com/pytorch/pytorch/issues/2793#issuecomment-428784713 + """ + clipped = torch.max(torch.min(x, tensor_max), tensor_min) + return clipped + + +def entropy_from_logits(logits): + """Calculate entropy from logits.""" + pd = torch.nn.functional.softmax(logits, dim=-1) + entropy = torch.logsumexp(logits, axis=-1) - torch.sum(pd * logits, axis=-1) + return entropy + + +def average_torch_dicts(list_of_dicts): + """Average values of a list of dicts with torch tensors.""" + average_dict = dict() + for key in list_of_dicts[0].keys(): + average_dict[key] = torch.mean(torch.stack([d[key] for d in list_of_dicts]), axis=0) + return average_dict + + +def stats_to_np(stats_dict): + """Cast all torch.tensors in dict to numpy arrays.""" + new_dict = dict() + for k, v in stats_dict.items(): + if isinstance(v, torch.Tensor): + new_dict[k] = v.detach().cpu() + if new_dict[k].dtype == torch.bfloat16: + new_dict[k] = new_dict[k].float() + new_dict[k] = new_dict[k].numpy() + else: + new_dict[k] = v + if np.isscalar(new_dict[k]): + new_dict[k] = float(new_dict[k]) + return new_dict + + +def listify_batch(tensor): + """Turns the first dimension of a tensor into a list.""" + return [tensor[i] for i in range(tensor.shape[0])] + + +def build_bert_batch_from_txt(text_list, tokenizer, device): + """Create token id and attention mask tensors from text list for BERT classification.""" + + # tokenize + tensors = [tokenizer.encode(txt, return_tensors="pt").to(device) for txt in text_list] + + # find max length to pad to + max_len = max([t.size()[1] for t in tensors]) + + # get padded tensors and attention masks + # (attention masks make bert ignore padding) + padded_tensors = [] + attention_masks = [] + for tensor in tensors: + attention_mask = torch.ones(tensor.size(), device=device) + padded_tensors.append(pad_to_size(tensor, max_len, padding=0)) + attention_masks.append(pad_to_size(attention_mask, max_len, padding=0)) + + # stack all tensors + padded_tensors = torch.cat(padded_tensors) + attention_masks = torch.cat(attention_masks) + + return padded_tensors, attention_masks + + +def respond_to_batch(model, queries, txt_len=20, top_k=0, top_p=1.0): + """Sample text from language model.""" + input_ids = queries + for i in range(txt_len): + # Get Logits + outputs = model(input_ids) + next_token_logits = outputs[0][:, -1, :] + next_token_logits = top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p) + # Sample + probs = F.softmax(next_token_logits, dim=-1) + next_token = torch.multinomial(probs, num_samples=1).squeeze(1) + input_ids = torch.cat([input_ids, next_token.unsqueeze(-1)], dim=-1) + return input_ids[:, -txt_len:] + + +def set_seed(seed: int): + """ + Helper function for reproducible behavior to set the seed in `random`, `numpy`, and `torch`. + + Args: + seed (`int`): The seed to set. + """ + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + if is_xpu_available(): + torch.xpu.manual_seed_all(seed) + else: + torch.cuda.manual_seed_all(seed) + + +class LengthSampler: + """ + Samples a length + """ + + def __init__(self, min_value, max_value): + self.values = list(range(min_value, max_value)) + + def __call__(self): + return np.random.choice(self.values) + + +class PPODecorators(object): + optimize_device_cache = False + + @classmethod + @contextmanager + def empty_device_cache(cls): + yield + if is_xpu_available(): + if cls.optimize_device_cache and torch.xpu.is_available(): + gc.collect() + torch.xpu.empty_cache() + gc.collect() + else: + if cls.optimize_device_cache and torch.cuda.is_available(): + gc.collect() + torch.cuda.empty_cache() + gc.collect() + + +def randn_tensor( + shape: Union[Tuple, List], + generator: Optional[Union[List["torch.Generator"], "torch.Generator"]] = None, + device: Optional["torch.device"] = None, + dtype: Optional["torch.dtype"] = None, + layout: Optional["torch.layout"] = None, +): + """A helper function to create random tensors on the desired `device` with the desired `dtype`. When + passing a list of generators, you can seed each batch size individually. If CPU generators are passed, the tensor + is always created on the CPU. + """ + # device on which tensor is created defaults to device + rand_device = device + batch_size = shape[0] + + layout = layout or torch.strided + device = device or torch.device("cpu") + + if generator is not None: + gen_device_type = generator.device.type if not isinstance(generator, list) else generator[0].device.type + if gen_device_type != device.type and gen_device_type == "cpu": + rand_device = "cpu" + if device != "mps": + warnings.warn( + f"The passed generator was created on 'cpu' even though a tensor on {device} was expected." + f" Tensors will be created on 'cpu' and then moved to {device}. Note that one can probably" + f" slighly speed up this function by passing a generator that was created on the {device} device." + ) + elif gen_device_type != device.type and gen_device_type == "cuda": + raise ValueError(f"Cannot generate a {device} tensor from a generator of type {gen_device_type}.") + + # make sure generator list of length 1 is treated like a non-list + if isinstance(generator, list) and len(generator) == 1: + generator = generator[0] + + if isinstance(generator, list): + shape = (1,) + shape[1:] + latents = [ + torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype, layout=layout) + for i in range(batch_size) + ] + latents = torch.cat(latents, dim=0).to(device) + else: + latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype, layout=layout).to(device) + + return latents diff --git a/trl/trl/trl/environment/__init__.py b/trl/trl/trl/environment/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ae1cda4ecb2e604cc990ce16d982df29846f5204 --- /dev/null +++ b/trl/trl/trl/environment/__init__.py @@ -0,0 +1,3 @@ +# flake8: noqa + +from .base_environment import TextEnvironment, TextHistory diff --git a/trl/trl/trl/environment/base_environment.py b/trl/trl/trl/environment/base_environment.py new file mode 100644 index 0000000000000000000000000000000000000000..25f44ae9355c0da6fd8fe19759b8d8a09c888fc4 --- /dev/null +++ b/trl/trl/trl/environment/base_environment.py @@ -0,0 +1,473 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import re +import warnings + +import torch +from accelerate.utils import extract_model_from_parallel +from transformers import StoppingCriteria, StoppingCriteriaList + +from ..import_utils import is_rich_available + + +if is_rich_available(): + from rich import print + from rich.text import Text + + +class StringStoppingCriteria(StoppingCriteria): + """Custom `StoppingCriteria` which checks if all generations in the batch are completed.""" + + def __init__(self, stop_strings, tokenizer): + self.stop_strings = stop_strings + self.tokenizer = tokenizer + self.first_call = True + + def __call__(self, input_ids, scores, **kwargs): + """Returns true if all generated sequences contain any of the stop strings.""" + if self.first_call: + self.generated_tokens = [1 for _ in range(input_ids.shape[0])] + self.start_length = input_ids.shape[-1] - 1 + self.first_call = False + decoded_generations = self.tokenizer.batch_decode(input_ids[:, self.start_length :]) + done = [] + + for i, decoded_generation in enumerate(decoded_generations): + sequence_complete = any([stop_string in decoded_generation for stop_string in self.stop_strings]) + done.append(sequence_complete) + if not sequence_complete: + self.generated_tokens[i] += 1 + + if all(done): + self.first_call = True + + return all(done) + + +class TextHistory: + """The TextHistory class keeps track of the history of an interaction between the language model and the environment.""" + + def __init__(self, text, tokens, system=True): + """ + Initialize TextHistory. + + args: + text (`str`): The text of the first segment. + tokens (`torch.LongTensor`): The tokens of the first segment. + system (`bool`, *optional*): Whether the first segment is a system or user segment. + """ + self.system_spans = [] + self.text_spans = [] + self.token_spans = [] + self.token_masks = torch.tensor([], dtype=torch.long).to(tokens.device) + self.text = "" + self.tokens = torch.tensor([], dtype=torch.long).to(tokens.device) + self.completed = False + self.truncated = False + self.reward = 0.0 + + self.prompt_color = "black on grey85" + self.system_color = "black on cyan3" + self.model_color = "black on deep_sky_blue1" + self.reward_color = "black on plum1" + + self.append_segment(text, tokens, system=system) + + def append_segment(self, text, tokens, system=True): + """ + Append a new segment to the history. + + args: + text (`str`): The text of the new segment. + tokens (`torch.LongTensor`): The tokens of the new segment. + system (`bool`, *optional*): Whether the new segment is a system or user segment. + """ + + if len(text) == 0 or len(tokens) == 0: + raise ValueError("Can't append empty text or token list to history.") + + original_text_length = len(self.text) + + self.text += text + self.text_spans.append((original_text_length, len(self.text))) + self.system_spans.append(system) + + original_token_length = len(self.tokens) + + self.tokens = torch.cat((self.tokens, tokens)) + if system: + self.token_masks = torch.cat((self.token_masks, torch.zeros_like(tokens))) + else: + self.token_masks = torch.cat((self.token_masks, torch.ones_like(tokens))) + self.token_spans.append((original_token_length, len(self.tokens))) + + def complete(self, truncated=False): + """ + Mark the history as completed. + """ + self.completed = True + self.truncated = truncated + + @property + def last_text_segment(self): + """ + Get the last text segment. + """ + start, end = self.text_spans[-1] + return self.text[start:end] + + def split_query_response_tokens(self): + """ + Split the tokens into query and response tokens. + """ + split_index = self.token_spans[0][1] + query = self.tokens[:split_index] + response = self.tokens[split_index:] + mask = self.token_masks[split_index:] + + return query, response, mask + + def show_text(self, show_legend=False): + """ + Print the text history. + """ + if not is_rich_available(): + warnings.warn("install rich to display text") + return + + text = Text(self.text) + text.stylize(self.prompt_color, self.text_spans[0][0], self.text_spans[1][0]) + for i, (start, end) in enumerate(self.text_spans[1:]): + if self.system_spans[i + 1]: + text.stylize(self.system_color, start, end) + else: + text.stylize(self.model_color, start, end) + + text.append(f"\n\nReward: {self.reward}", style=self.reward_color) + print(text) + + if show_legend: + self.show_colour_legend() + + def show_tokens(self, tokenizer, show_legend=False): + """ + Print the history tokens. + """ + if not is_rich_available(): + warnings.warn("install rich to display tokens") + return + + text = Text() + prompt_end = self.token_spans[0][1] + for i, (token, mask) in enumerate(zip(self.tokens, self.token_masks)): + if i < prompt_end: + text.append(tokenizer.convert_ids_to_tokens(token.item()), style=self.prompt_color) + text.append(" ") + elif mask == 0: + text.append(tokenizer.convert_ids_to_tokens(token.item()), style=self.system_color) + text.append(" ") + else: + text.append(tokenizer.convert_ids_to_tokens(token.item()), style=self.model_color) + text.append(" ") + text.append(f"\n\nReward: {self.reward}", style=self.reward_color) + print(text) + if show_legend: + self.show_colour_legend() + + def show_colour_legend(self): + """ + Print the colour legend. + """ + if not is_rich_available(): + warnings.warn("install rich to display colour legend") + return + text = Text("\n\n(Colour Legend: ") + text.append("Prompt", style=self.prompt_color) + text.append("|") + text.append("System", style=self.system_color) + text.append("|") + text.append("Model", style=self.model_color) + text.append("|") + text.append("Reward", style=self.reward_color) + text.append(")") + print(text) + + +class TextEnvironment: + """ + The TextEnvironment enables interaction of a LLM with an environment using tools. + """ + + def __init__( + self, + model=None, + tokenizer=None, + tools=None, + reward_fn=None, + prompt=None, + max_turns=4, + max_tool_reponse=100, + max_length=None, + generation_kwargs=None, + ): + """ + Initialize TextEnvironment. + + Args: + model (`PreTrainedModelWrapper`): The model to use for generation. + tokenizer (`transformers.PreTrainedTokenizer`): The tokenizer to use for generation. + tools (list): A list of tools to use for interaction. + reward_fn (function): A function that takes a string and returns a reward. + prompt (str): The base prompt to use for generation. Is prepended to the tasks. + max_turns (Optional[int]): The maximum number of turns to allow. + max_tool_response (Optional[int]): The maximum number of characters to allow in a tool response. + max_length (Optional[int]): The maximum number of tokens to allow in an episode. + generation_kwargs (Optional[dict]): A dictionary of keyword arguments to pass to the model's generate method. + """ + self.model = model + self.tokenizer = tokenizer + self.prompt = prompt + if isinstance(tools, dict): + self.tools = tools + else: + self.tools = dict([(tool.__class__.__name__, tool) for tool in tools]) + self.reward_fn = reward_fn + self.max_length = max_length + self.request_token = "" + self.call_token = "" + self.response_token = "" + self.submit_token = "" + self.max_turns = max_turns + self.max_tool_response = max_tool_reponse + + if generation_kwargs is None: + self.generation_kwargs = dict() + else: + self.generation_kwargs = generation_kwargs + + self.is_encoder_decoder = hasattr(self.model, "is_encoder_decoder") + self.current_device = extract_model_from_parallel(self.model).pretrained_model.device + + def run(self, queries, **rewards_kwargs): + """ + Run the environment on a list of queries. + + Args: + queries (list[str]): A list of queries to run the model in the environment on. + """ + turns = 0 + + queries = [self.prompt + task for task in queries] + queries_tokens = [ + self.tokenizer(query, return_tensors="pt").input_ids[0].to(self.model.pretrained_model.device) + for query in queries + ] + + histories = [TextHistory(q, qt, system=True) for q, qt in zip(queries, queries_tokens)] + + while any([not history.completed for history in histories]) and turns < self.max_turns: + histories = self.generate(histories) + histories = self.tasks_end_check(histories) + # TODO: make this parallel rather than for-loop + for i in range(len(histories)): + histories[i] = self.step(histories[i]) + histories = self.tasks_end_check(histories, model_turn=False) + turns += 1 + self.compute_reward(histories, **rewards_kwargs) + + # convert a list of (q, r, m) tuples to lists of all qs, rs, and ms respectively + queries, responses, masks = map(list, zip(*[history.split_query_response_tokens() for history in histories])) + + rewards = [history.reward for history in histories] + return queries, responses, masks, rewards, histories + + def step(self, history): + """ + Step the environment forward one turn. + + Args: + history (`TextHistory`): The history to step forward. + """ + truncated, ended = self.task_end_check(history) + if ended: + history.complete(truncated=truncated) + if history.completed: + return history + + tool, query = self.parse_tool_call(history.last_text_segment) + if tool is None or query is None: + response = f"Unknown tool call: {history.last_text_segment}" + else: + if tool not in self.tools: + response = f"Unknown tool {tool}." + try: + response = self.tools[tool](query) + except Exception as error: + response = f"Tool error: {str(error)}" + + if len(response) > self.max_tool_response: + response = response[: (self.max_tool_response - 3)] + "..." + + history.append_segment( + response + self.response_token, + self.tokenizer(response + self.response_token, return_tensors="pt") + .input_ids[0] + .to(self.model.pretrained_model.device), + system=True, + ) + + return history + + def parse_tool_call(self, text): + """ + Parse request string. Expected format: query + """ + result = re.search(f"(?<={self.request_token}).*?(?={self.call_token})", text, re.DOTALL) + + # if we can't find a / span we return none + if result is None: + return None, None + else: + extracted_text = result.group() + + result = re.search(r"<(.*?)>", extracted_text) + + # if we can't find a tool name we return none + if result is None: + return None, None + else: + tool = result.group(1) + + # split off the tool name + query = ">".join(extracted_text.split(">")[1:]) + + return tool, query + + def compute_reward(self, histories, **reward_kwargs): + """ + Compute the reward for a list of histories. + """ + rewards = self.reward_fn([history.last_text_segment for history in histories], **reward_kwargs) + for history, reward in zip(histories, rewards): + history.reward = reward + return histories + + def generate(self, histories): + """ + Generate responses for a list of histories. + """ + active_histories = [i for i, history in enumerate(histories) if not history.completed] + + query_tensors = [histories[i].tokens for i in active_histories] + response_tensors = self._generate_batched(query_tensors) + response_texts = self.tokenizer.batch_decode(response_tensors) + + for i, response_text, response_tensor in zip(active_histories, response_texts, response_tensors): + histories[i].append_segment(response_text, response_tensor, system=False) + + return histories + + def tasks_end_check(self, histories, model_turn=True): + """ + Check if the current generation sequences have finished. + """ + for history in histories: + if not history.completed: + truncated, ended = self.task_end_check(history, model_turn=model_turn) + if ended: + history.complete(truncated=truncated) + return histories + + def task_end_check(self, history, model_turn=True): + """ + Check if the current generation sequence has finished. + """ + truncated = False + ended = False + if history.completed: + return truncated, ended + if self.max_length is not None and len(self.tokenizer(history.text).input_ids[0]) > self.max_length: + truncated = True + ended = True + elif self.tokenizer.eos_token in history.text: + ended = True + elif model_turn and not ( + (self.request_token in history.last_text_segment and self.call_token in history.last_text_segment) + or self.submit_token in history.last_text_segment + ): + ended = True + elif self.submit_token in history.last_text_segment: + ended = True + return truncated, ended + + def _generate_batched( + self, + query_tensors, + batch_size: int = 16, + pad_to_multiple_of: int = None, + ): + """ + Generate responses for a list of query tensors. + + args: + query_tensors (list[torch.Tensor]): A list of query tensors to generate responses for. + batch_size (int): The batch size to use for generation. + pad_to_multiple_of (int): The padding length to use for generation. + """ + outputs = [] + padding_side_default = self.tokenizer.padding_side + if not self.is_encoder_decoder: + self.tokenizer.padding_side = "left" + + # in case we have fewer examples than bs + batch_size = min(len(query_tensors), batch_size) + + for i in range(0, len(query_tensors), batch_size): + # prevent overflow if query tensors are not even multiple of bs + end_index = min(len(query_tensors), i + batch_size) + + batch = query_tensors[i:end_index] + batch_mask = [torch.ones_like(element) for element in batch] + inputs = {"input_ids": batch, "attention_mask": batch_mask} + + padded_inputs = self.tokenizer.pad( + inputs, + padding=True, + max_length=None, + pad_to_multiple_of=pad_to_multiple_of, + return_tensors="pt", + ).to(self.current_device) + + stopping_criteria = StringStoppingCriteria([self.call_token, self.submit_token], self.tokenizer) + + self.generation_kwargs["stopping_criteria"] = StoppingCriteriaList([stopping_criteria]) + + generations = extract_model_from_parallel(self.model).generate(**padded_inputs, **self.generation_kwargs) + + for generation, mask, generated_tokens in zip( + generations, padded_inputs["attention_mask"], stopping_criteria.generated_tokens + ): + if not self.is_encoder_decoder: + output = generation[(1 - mask).sum() :] # remove padding + else: + output = generation + + if not self.is_encoder_decoder: + output = output[(mask).sum() :] # remove prompt + + # remove chunk generated after stopping criteria in batch mode + outputs.append(output[:generated_tokens]) + self.tokenizer.padding_side = padding_side_default + return outputs diff --git a/trl/trl/trl/extras/__init__.py b/trl/trl/trl/extras/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..6b3035db92af28f5d19d72813f08b06fdad50925 --- /dev/null +++ b/trl/trl/trl/extras/__init__.py @@ -0,0 +1,16 @@ +# flake8: noqa + +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from .best_of_n_sampler import BestOfNSampler diff --git a/trl/trl/trl/extras/best_of_n_sampler.py b/trl/trl/trl/extras/best_of_n_sampler.py new file mode 100644 index 0000000000000000000000000000000000000000..1441eecd41a5a18b7612f4e270271d137c07d437 --- /dev/null +++ b/trl/trl/trl/extras/best_of_n_sampler.py @@ -0,0 +1,117 @@ +from typing import Any, Callable, List, Optional, Union + +import torch +from transformers import GenerationConfig, PreTrainedTokenizer, PreTrainedTokenizerFast + +from ..core import set_seed +from ..models import SUPPORTED_ARCHITECTURES, PreTrainedModelWrapper + + +class BestOfNSampler(object): + def __init__( + self, + model: PreTrainedModelWrapper, + tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], + queries_to_scores: Callable[[List[str]], List[float]], + length_sampler: Any, + sample_size: int = 4, + seed: Optional[int] = None, + n_candidates: int = 1, + generation_config: Optional[GenerationConfig] = None, + ) -> None: + r""" + Initialize the sampler for best-of-n generation + + Args: + model (`PreTrainedModelWrapper`): + The pretrained model to use for generation + tokenizer (`PreTrainedTokenizer` or `PreTrainedTokenizerFast`): + Tokenizer associated with the pretrained model + queries_to_scores (`Callable[[List[str]], List[float]]`): + Callable that takes a list of generated texts and returns the associated reward scores + length_sampler (`Any`): + Sampler used to sample the length of the generated text + sample_size (`int`): + Number of samples to generate for each query + seed (`int`, *optional*): + Random seed used to control generation + n_candidates (`int`): + Number of candidates to return for each query + generation_config (`GenerationConfig`, *optional*): + Generation config passed to the underlying model's `generate` method. + See `GenerationConfig` (https://huggingface.co/docs/transformers/v4.29.1/en/main_classes/text_generation#transformers.GenerationConfig) for more details + """ + if seed is not None: + set_seed(seed) + + if not isinstance(tokenizer, (PreTrainedTokenizer, PreTrainedTokenizerFast)): + raise ValueError( + f"tokenizer must be a PreTrainedTokenizer or PreTrainedTokenizerFast, got {type(tokenizer)}" + ) + if not isinstance(model, (SUPPORTED_ARCHITECTURES)): + raise ValueError( + f"model must be a PreTrainedModelWrapper, got {type(model)} - supported architectures are: {SUPPORTED_ARCHITECTURES}" + ) + + self.model = model + self.tokenizer = tokenizer + + self.queries_to_scores = queries_to_scores + self.length_sampler = length_sampler + self.gen_config = generation_config + self.sample_size = sample_size + self.n_candidates = n_candidates + + def generate( + self, + tokenized_query: Union[List[int], torch.Tensor, List[torch.Tensor], List[List[int]]], + skip_special_tokens: bool = True, + device: Optional[Union[str, torch.device]] = None, + **generation_kwargs, + ) -> List[List[str]]: + r""" + Generate the best of n samples for input queries + + Args: + tokenized_query (`List[int]` or `torch.Tensor` or `List[torch.Tensor]` or `List[int]`): + represents either a single tokenized query (a single tensor or a list of integers) or a batch of tokenized queries (a list of tensors or a list of lists of integers) + skip_special_tokens (`bool`): + Whether to remove the special tokens from the output + device (`str` or `torch.device`, *optional*): + The device on which the model will be loaded + **generation_kwargs (`dict`, *optional*): + Additional keyword arguments passed along to the underlying model's `generate` method. + This is used to override generation config + + Returns: + List[List[str]]: A list of lists of generated texts + """ + queries = None + + if isinstance(tokenized_query, torch.Tensor) and tokenized_query.ndim == 1: + queries = tokenized_query.unsqueeze(0) + elif isinstance(tokenized_query, List): + element_type = type(tokenized_query[0]) + if element_type == int: + queries = torch.tensor(tokenized_query).unsqueeze(0) + elif element_type == torch.Tensor: + queries = [tensor.reshape((1, -1)) for tensor in tokenized_query] + else: + queries = [torch.tensor(query).reshape((1, -1)) for query in tokenized_query] + + result = [] + + for query in queries: + queries = query.repeat((self.sample_size, 1)) + output = self.model.generate( + queries.to(device), + max_new_tokens=self.length_sampler(), + generation_config=self.gen_config, + **generation_kwargs, + ).squeeze() + output = self.tokenizer.batch_decode(output, skip_special_tokens=skip_special_tokens) + scores = torch.tensor(self.queries_to_scores(output)) + output = [output[i] for i in scores.topk(self.n_candidates).indices] + result.append(output) + + return result diff --git a/trl/trl/trl/import_utils.py b/trl/trl/trl/import_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..7697f6db840d0d96721b4369a8505ef58c54c747 --- /dev/null +++ b/trl/trl/trl/import_utils.py @@ -0,0 +1,90 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import importlib +import sys + + +if sys.version_info < (3, 8): + _is_python_greater_3_8 = False +else: + _is_python_greater_3_8 = True + + +def is_peft_available() -> bool: + return importlib.util.find_spec("peft") is not None + + +def is_accelerate_greater_20_0() -> bool: + if _is_python_greater_3_8: + from importlib.metadata import version + + accelerate_version = version("accelerate") + else: + import pkg_resources + + accelerate_version = pkg_resources.get_distribution("accelerate").version + return accelerate_version >= "0.20.0" + + +def is_transformers_greater_than(version: str) -> bool: + _transformers_version = importlib.metadata.version("transformers") + return _transformers_version > version + + +def is_torch_greater_2_0() -> bool: + if _is_python_greater_3_8: + from importlib.metadata import version + + torch_version = version("torch") + else: + import pkg_resources + + torch_version = pkg_resources.get_distribution("torch").version + return torch_version >= "2.0" + + +def is_diffusers_available() -> bool: + return importlib.util.find_spec("diffusers") is not None + + +def is_bitsandbytes_available() -> bool: + return importlib.util.find_spec("bitsandbytes") is not None + + +def is_torchvision_available() -> bool: + return importlib.util.find_spec("torchvision") is not None + + +def is_rich_available() -> bool: + return importlib.util.find_spec("rich") is not None + + +def is_wandb_available() -> bool: + return importlib.util.find_spec("wandb") is not None + + +def is_xpu_available() -> bool: + if is_accelerate_greater_20_0(): + import accelerate + + return accelerate.utils.is_xpu_available() + else: + if importlib.util.find_spec("intel_extension_for_pytorch") is None: + return False + try: + import torch + + return hasattr(torch, "xpu") and torch.xpu.is_available() + except RuntimeError: + return False diff --git a/trl/trl/trl/models/__init__.py b/trl/trl/trl/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..6ccce25e5e4c8cf05f2d02efee351512a5b6d848 --- /dev/null +++ b/trl/trl/trl/models/__init__.py @@ -0,0 +1,34 @@ +# flake8: noqa + +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from .modeling_base import PreTrainedModelWrapper, create_reference_model +from .modeling_value_head import AutoModelForCausalLMWithValueHead, AutoModelForSeq2SeqLMWithValueHead + + +SUPPORTED_ARCHITECTURES = ( + AutoModelForCausalLMWithValueHead, + AutoModelForSeq2SeqLMWithValueHead, +) + +from ..import_utils import is_diffusers_available + + +if is_diffusers_available(): + from .modeling_sd_base import ( + DDPOPipelineOutput, + DDPOSchedulerOutput, + DDPOStableDiffusionPipeline, + DefaultDDPOStableDiffusionPipeline, + ) diff --git a/trl/trl/trl/models/modeling_base.py b/trl/trl/trl/models/modeling_base.py new file mode 100644 index 0000000000000000000000000000000000000000..9a3142eebd87b188fa7de89116874bdc590cccfc --- /dev/null +++ b/trl/trl/trl/models/modeling_base.py @@ -0,0 +1,672 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import json +import logging +import os +from copy import deepcopy + +import torch +import torch.nn as nn +from accelerate import PartialState +from huggingface_hub import hf_hub_download +from huggingface_hub.utils import ( + EntryNotFoundError, + HFValidationError, + LocalEntryNotFoundError, + RepositoryNotFoundError, +) +from safetensors.torch import load_file as safe_load_file +from transformers import PreTrainedModel + +from ..import_utils import is_peft_available, is_transformers_greater_than, is_xpu_available + + +if is_peft_available(): + from peft import ( + PeftConfig, + PeftModel, + PeftModelForCausalLM, + PeftModelForSeq2SeqLM, + PromptLearningConfig, + get_peft_model, + prepare_model_for_kbit_training, + ) + +if is_transformers_greater_than("4.33.0"): + from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled +else: + from transformers.deepspeed import is_deepspeed_zero3_enabled + +LAYER_PATTERNS = [ + "transformer.h.{layer}", + "model.decoder.layers.{layer}", + "gpt_neox.layers.{layer}", + "model.layers.{layer}", +] + + +class PreTrainedModelWrapper(nn.Module): + r""" + A wrapper class around a (`transformers.PreTrainedModel`) to be compatible with the + (`~transformers.PreTrained`) class in order to keep some attributes and methods of the + (`~transformers.PreTrainedModel`) class. + + Attributes: + pretrained_model: (`transformers.PreTrainedModel`) + The model to be wrapped. + parent_class: (`transformers.PreTrainedModel`) + The parent class of the model to be wrapped. + supported_args: (`list`) + The list of arguments that are supported by the wrapper class. + """ + transformers_parent_class = None + supported_args = None + supported_modules = ("v_head",) + supported_rm_modules = ("score",) + supported_pretrained_model_architectures = ( + (PreTrainedModel) + if not is_peft_available() + else (PreTrainedModel, PeftModelForCausalLM, PeftModelForSeq2SeqLM) + ) + + def __init__( + self, pretrained_model=None, score_module=None, supports_rm_adapter=False, rm_adapter_name=None, **kwargs + ): + super().__init__() + self.pretrained_model = pretrained_model + + self.config = pretrained_model.config + self.prepare_inputs_for_generation = pretrained_model.prepare_inputs_for_generation + self.is_loaded_in_8bit = getattr(pretrained_model, "is_loaded_in_8bit", False) + self.is_loaded_in_4bit = getattr(pretrained_model, "is_loaded_in_4bit", False) + self.is_sequential_parallel = False + + if hasattr(pretrained_model, "gradient_checkpointing_disable"): + self.gradient_checkpointing_disable = pretrained_model.gradient_checkpointing_disable + + if hasattr(pretrained_model, "gradient_checkpointing_enable"): + self.gradient_checkpointing_enable = pretrained_model.gradient_checkpointing_enable + + self.supports_rm_adapter = supports_rm_adapter + self.rm_adapter_name = rm_adapter_name + self.policy_adapter_name = "default" + if score_module is not None: + self.score = score_module + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): + r""" + Instantiates a new model from a pretrained model from `transformers`. The + pretrained model is loaded using the `from_pretrained` method of the + `transformers.PreTrainedModel` class. The arguments that are specific to the + `transformers.PreTrainedModel` class are passed along this method and filtered + out from the `kwargs` argument. + + + Args: + pretrained_model_name_or_path (`str` or `transformers.PreTrainedModel`): + The path to the pretrained model or its name. + *model_args (`list`, *optional*)): + Additional positional arguments passed along to the underlying model's + `from_pretrained` method. + **kwargs (`dict`, *optional*): + Additional keyword arguments passed along to the underlying model's + `from_pretrained` method. We also pre-process the kwargs to extract + the arguments that are specific to the `transformers.PreTrainedModel` + class and the arguments that are specific to trl models. The kwargs + also support `prepare_model_for_kbit_training` arguments from + `peft` library. + """ + if kwargs is not None: + peft_config = kwargs.pop("peft_config", None) + reward_adapter = kwargs.pop("reward_adapter", None) + reward_adapter_name = kwargs.pop("reward_adapter_name", "reward_adapter") + is_trainable = kwargs.pop("is_trainable", False) + trl_model_args, pretrained_kwargs, peft_quantization_kwargs = cls._split_kwargs(kwargs) + token = pretrained_kwargs.get("token", None) + else: + peft_config = None + is_trainable = False + trl_model_args = {} + pretrained_kwargs = {} + peft_quantization_kwargs = {} + token = None + + if reward_adapter is not None and not isinstance(reward_adapter, str): + raise ValueError( + "The `reward_adapter` argument should be a string representing the name of local path or the Hub id to the Reward Modeling adapter." + ) + + is_peft_model = False + + current_device = cls._get_current_device() + if isinstance(pretrained_model_name_or_path, str): + is_loaded_in_8bit = pretrained_kwargs["load_in_8bit"] if "load_in_8bit" in pretrained_kwargs else False + is_loaded_in_4bit = pretrained_kwargs["load_in_4bit"] if "load_in_4bit" in pretrained_kwargs else False + else: + is_loaded_in_8bit = getattr(pretrained_model_name_or_path, "is_loaded_in_8bit", False) + is_loaded_in_4bit = getattr(pretrained_model_name_or_path, "is_loaded_in_4bit", False) + + if (is_loaded_in_8bit or is_loaded_in_4bit) and "device_map" not in pretrained_kwargs: + # warn users + logging.warning( + "The `device_map` argument is not provided. We will override the device_map argument." + " to set the entire" + " model on the current device. If you want to set the model on multiple devices, please provide" + " a custom `device_map` argument." + ) + pretrained_kwargs["device_map"] = {"": current_device} + + if is_peft_available() and peft_config is not None and not isinstance(peft_config, PeftConfig): + raise ValueError("The `peft_config` argument should be an instance of `peft.PeftConfig` class.") + + # First, load the pre-trained model using the parent-class + # either `AutoModelForCausalLM` or `AutoModelForSeq2SeqLM` + if isinstance(pretrained_model_name_or_path, str): + if is_peft_available(): + try: + # If there is a trained peft adapter in the hub, load its config. + remote_adapter_config = hf_hub_download( + pretrained_model_name_or_path, + "adapter_config.json", + token=token, + ) + except (EntryNotFoundError, LocalEntryNotFoundError, HFValidationError, RepositoryNotFoundError): + remote_adapter_config = None + else: + remote_adapter_config = None + + local_adapter_present = os.path.exists(os.path.join(pretrained_model_name_or_path, "adapter_config.json")) + + if (local_adapter_present or remote_adapter_config is not None) and is_peft_available(): + if peft_config is not None: + logging.warning( + "`peft_config` argument ignored since a peft config file was found in " + f"{pretrained_model_name_or_path}" + ) + + # Load the trained peft adapter config + if local_adapter_present: + trained_adapter_config = PeftConfig.from_pretrained(pretrained_model_name_or_path) + else: + remote_adapter_dir = os.path.dirname(remote_adapter_config) + trained_adapter_config = PeftConfig.from_pretrained(remote_adapter_dir) + + # Load the pretrained base model + pretrained_model = cls.transformers_parent_class.from_pretrained( + trained_adapter_config.base_model_name_or_path, *model_args, **pretrained_kwargs + ) + + # Wrap the pretrained model with the trained peft adapter + pretrained_model = PeftModel.from_pretrained( + pretrained_model, pretrained_model_name_or_path, is_trainable=is_trainable + ) + logging.info("Trained peft adapter loaded") + else: + pretrained_model = cls.transformers_parent_class.from_pretrained( + pretrained_model_name_or_path, *model_args, **pretrained_kwargs + ) + + if peft_config is not None: + # Initialize a new peft adapter with the given config + if is_loaded_in_8bit or is_loaded_in_4bit: + pretrained_model = prepare_model_for_kbit_training( + pretrained_model, + **peft_quantization_kwargs, + ) + pretrained_model = get_peft_model(pretrained_model, peft_config) + logging.info("peft adapter initialised") + + elif isinstance(pretrained_model_name_or_path, cls.supported_pretrained_model_architectures): + pretrained_model = pretrained_model_name_or_path + + if peft_config is not None and isinstance(pretrained_model, PreTrainedModel): + # Initialize a new peft adapter with the given config + if is_loaded_in_8bit or is_loaded_in_4bit: + pretrained_model = prepare_model_for_kbit_training( + pretrained_model, + **peft_quantization_kwargs, + ) + pretrained_model = get_peft_model(pretrained_model, peft_config) + logging.info("peft adapter initialised") + else: + raise ValueError( + "pretrained_model_name_or_path should be a string or a PreTrainedModel, " + f"but is {type(pretrained_model_name_or_path)}" + ) + + if is_peft_available(): + if isinstance(pretrained_model, PeftModel): + is_peft_model = True + # for backward compatibility + if hasattr(pretrained_model, "active_peft_config") and isinstance( + pretrained_model.active_peft_config, PromptLearningConfig + ): + raise ValueError("PromptLearningConfig is not supported for PPO training.") + + # Add reward modeling adapter if specified + if not is_peft_model and reward_adapter is not None: + raise ValueError("reward_adapter can only be used with a PeftModel. ") + elif is_peft_model and reward_adapter is not None: + score_module = cls.add_and_load_reward_modeling_adapter( + pretrained_model, reward_adapter, reward_adapter_name, token=token + ) + multi_adapter_args = { + "score_module": score_module, + "supports_rm_adapter": True, + "rm_adapter_name": reward_adapter_name, + } + else: + multi_adapter_args = {"supports_rm_adapter": False} + + # Then, create the full model by instantiating the wrapper class + model = cls(pretrained_model, **multi_adapter_args, **trl_model_args) + + # if resume_training, load the state_dict again - this is ok since the + # state_dict is removed from the model after loading it. + is_resuming_training = True + if isinstance(pretrained_model_name_or_path, str): + safe_filename = os.path.join(pretrained_model_name_or_path, "model.safetensors") + filename = os.path.join(pretrained_model_name_or_path, "pytorch_model.bin") + + sharded_index_filename = os.path.join(pretrained_model_name_or_path, "pytorch_model.bin.index.json") + safe_sharded_index_filename = os.path.join(pretrained_model_name_or_path, "model.safetensors.index.json") + is_sharded = False + use_safe = os.path.exists(safe_filename) + + if not (os.path.exists(filename) or os.path.exists(safe_filename)): + # Try with `pytorch_model.bin` + filename, files_to_download, is_sharded, is_resuming_training = cls._get_checkpoint_from_hub( + pretrained_model, + pretrained_model_name_or_path, + sharded_index_filename, + token=token, + ) + # Try with safetensors + if filename is None and files_to_download is None: + safe_filename, files_to_download, is_sharded, is_resuming_training = cls._get_checkpoint_from_hub( + pretrained_model, + pretrained_model_name_or_path, + safe_sharded_index_filename, + token=token, + model_name="model.safetensors", + model_index_name="model.safetensors.index.json", + ) + use_safe = True + else: + use_safe = False + + loading_func = safe_load_file if use_safe else torch.load + load_kwargs = {} if use_safe else {"map_location": "cpu"} + + if is_resuming_training: + if is_sharded: + # download each file and add it to the state_dict + state_dict = {} + + for shard_file in files_to_download: + filename = hf_hub_download( + pretrained_model_name_or_path, + shard_file, + token=token, + ) + state_dict.update(loading_func(filename, **load_kwargs)) + else: + state_dict = loading_func(filename if not use_safe else safe_filename, **load_kwargs) + + else: + state_dict = pretrained_model_name_or_path.state_dict() + + model.is_peft_model = is_peft_model + model.current_device = current_device + + if is_resuming_training: + model.post_init(state_dict=state_dict) + + return model + + @classmethod + def _get_checkpoint_from_hub( + cls, + pretrained_model, + pretrained_model_name_or_path, + index_filename, + token=None, + model_name="pytorch_model.bin", + model_index_name="pytorch_model.bin.index.json", + ): + files_to_download = None + filename = None + is_resuming_training = True + is_sharded = False + + try: + filename = hf_hub_download( + pretrained_model_name_or_path, + model_name, + token=token, + ) + # sharded + except (EntryNotFoundError, LocalEntryNotFoundError, HFValidationError, RepositoryNotFoundError): + if os.path.exists(index_filename): + index_file_name = index_filename + else: + try: + index_file_name = hf_hub_download( + pretrained_model_name_or_path, + model_index_name, + token=token, + ) + except (EntryNotFoundError, LocalEntryNotFoundError, HFValidationError, RepositoryNotFoundError): + # not continue training, do not have v_head weight + is_resuming_training = False + logging.warning( + f"A {type(pretrained_model)} model is loaded from '{pretrained_model_name_or_path}', " + f"and no v_head weight is found. This IS expected if you are not resuming PPO training." + ) + # load json + if is_resuming_training: + with open(index_file_name, "r") as f: + index = json.load(f) + # check filename with `v_head` or any known extra module: + files_to_download = set() + for k, v in index["weight_map"].items(): + if any([module in k for module in cls.supported_modules]): + files_to_download.add(v) + is_sharded = True + + return filename, files_to_download, is_sharded, is_resuming_training + + @classmethod + def _get_current_device(cls): + r""" + Get the current device. For GPU, we return the local process index using the `accelerate.PartialState` + object to handle corner cases when running scripts in distributed environments. + + Returns: + current_device (`Union[int, str]`): + The current device. + """ + state = PartialState() + if is_xpu_available(): + return f"xpu:{state.local_process_index}" + else: + return state.local_process_index if torch.cuda.is_available() else "cpu" + + @classmethod + def _split_kwargs(cls, kwargs): + """ + Separate the kwargs from the arguments that we support inside + `supported_args` and the ones that we don't. + """ + check_peft_kwargs = False + + if is_peft_available(): + from peft import prepare_model_for_kbit_training + + check_peft_kwargs = True + + supported_kwargs = {} + unsupported_kwargs = {} + peft_kwargs = {} + + for key, value in kwargs.items(): + if key in cls.supported_args: + supported_kwargs[key] = value + else: + unsupported_kwargs[key] = value + + if check_peft_kwargs: + if key in prepare_model_for_kbit_training.__code__.co_varnames: + peft_kwargs[key] = value + if key in unsupported_kwargs: + unsupported_kwargs.pop(key) + + return supported_kwargs, unsupported_kwargs, peft_kwargs + + @classmethod + def add_and_load_reward_modeling_adapter( + cls, pretrained_model, adapter_model_id, adapter_name="reward_model_adapter", token=None + ): + r""" + Add and load a reward modeling adapter. This method can only be used if the + model is a `PeftModel` and if you have initialized the model with the `reward_modeling_adapter_id` + argument, pointing to the id of the reward modeling adapter. The latest needs also to contain the + score head in order to produce the reward. + """ + pretrained_model.load_adapter(adapter_model_id, adapter_name, is_trainable=False) + pretrained_model.train() + + filename = os.path.join(adapter_model_id, "adapter_model.bin") + safe_loading = False + if not os.path.exists(filename): + try: + local_filename = hf_hub_download( + adapter_model_id, + "adapter_model.bin", + token=token, + ) + except: # noqa + filename = os.path.join(adapter_model_id, "adapter_model.safetensors") + safe_loading = True + if not os.path.exists(filename): + try: + local_filename = hf_hub_download( + adapter_model_id, + "adapter_model.safetensors", + token=token, + ) + except: # noqa + raise ValueError( + "Could not find adapter model in the Hub, make sure you have the correct adapter model id." + ) + else: + local_filename = filename + else: + local_filename = filename + + loading_func = safe_load_file if safe_loading else torch.load + load_kwargs = {} if safe_loading else {"map_location": "cpu"} + + adapter_state_dict = loading_func(local_filename, **load_kwargs) + + for score_name_candidate in cls.supported_rm_modules: + if any([score_name_candidate in name for name in adapter_state_dict.keys()]): + score_name = score_name_candidate + # we have found the correct head name and can break + break + + score_dict = {} + + for name, param in adapter_state_dict.items(): + if score_name in name: + key_name = ".".join(name.split(".")[-1:]) + score_dict[key_name] = param.to(cls._get_current_device()) + + num_labels, hidden_dim = score_dict["weight"].shape + has_bias = any(["bias" in name for name in adapter_state_dict.keys()]) + + score = nn.Linear(hidden_dim, num_labels, bias=has_bias).to( + device=cls._get_current_device(), + dtype=pretrained_model.dtype, + ) + score.load_state_dict(score_dict) + for param in score.parameters(): + param.requires_grad = False + + return score + + def push_to_hub(self, *args, **kwargs): + r""" + Push the pretrained model to the hub. This method is a wrapper around + `transformers.PreTrainedModel.push_to_hub`. Please refer to the documentation + of `transformers.PreTrainedModel.push_to_hub` for more information. + + Args: + *args (`list`, *optional*): + Positional arguments passed along to the underlying model's + `push_to_hub` method. + **kwargs (`dict`, *optional*): + Keyword arguments passed along to the underlying model's + `push_to_hub` method. + """ + raise NotImplementedError + + def save_pretrained(self, *args, **kwargs): + r""" + Save the pretrained model to a directory. This method is a wrapper around + `transformers.PreTrainedModel.save_pretrained`. Please refer to the documentation + of `transformers.PreTrainedModel.save_pretrained` for more information. + + Args: + *args (`list`, *optional*): + Positional arguments passed along to the underlying model's + `save_pretrained` method. + **kwargs (`dict`, *optional*): + Keyword arguments passed along to the underlying model's + `save_pretrained` method. + """ + state_dict = kwargs.get("state_dict") + if state_dict is None: + state_dict = self.state_dict() + kwargs["state_dict"] = state_dict + + # if it is a peft model only save the `v_head` state_dict and + # pop the `state_dict` from the kwargs to avoid slient bugs with `peft` + if self.is_peft_model: + save_path = args[0] + save_path = os.path.join(save_path, "pytorch_model.bin") + torch.save(state_dict, save_path) + _ = kwargs.pop("state_dict", None) + + return self.pretrained_model.save_pretrained(*args, **kwargs) + + def state_dict(self, *args, **kwargs): + r""" + Return the state_dict of the pretrained model. + """ + raise NotImplementedError + + def post_init(self, *args, **kwargs): + r""" + Post initialization method. This method is called after the model is + instantiated and loaded from a checkpoint. It can be used to perform + additional operations such as loading the state_dict. + """ + raise NotImplementedError + + def compute_reward_score(self, input_ids, attention_mask=None, **kwargs): + r""" + Computes the reward score for a given input. The method has first to enable the adapter + and then compute the reward score. After that the model disables the reward modeling + adapter and enables the default ppo adapter again. + """ + if not self.supports_rm_adapter: + raise ValueError("This model does not support reward modeling adapter.") + + # enable rm adapter + self.pretrained_model.set_adapter(self.rm_adapter_name) + self.pretrained_model.eval() + + with torch.no_grad(): + base_model_output = self.pretrained_model( + input_ids=input_ids, + attention_mask=attention_mask, + output_hidden_states=True, + return_dict=True, + **kwargs, + ) + + last_hidden_states = base_model_output.hidden_states[-1] + scores = self.score(last_hidden_states) + + self.pretrained_model.set_adapter(self.policy_adapter_name) + self.pretrained_model.eval() + + return scores + + +def create_reference_model( + model: PreTrainedModelWrapper, num_shared_layers: int = None, pattern: str = None +) -> PreTrainedModelWrapper: + """ + Creates a static reference copy of a model. Note that model will be in `.eval()` mode. + + Args: + model (`PreTrainedModelWrapper`): The model to be copied. + num_shared_layers (`int`, *optional*): The number of initial layers that are shared between both models and kept frozen. + pattern (`str`, *optional*): The shared layers are selected with a string pattern + (e.g. "transformer.h.{layer}" for GPT2) and if a custom pattern is necessary it can be passed here. + + Returns + `PreTrainedModelWrapper` + """ + if is_deepspeed_zero3_enabled(): + raise ValueError( + "DeepSpeed ZeRO-3 is enabled and is not compatible with `create_reference_model()`. Please instantiate your reference model directly with `AutoCausalLM.from_pretrained()`." + ) + + parameter_names = [n for n, _ in model.named_parameters()] + ref_model = deepcopy(model) + + # if no layers are shared, return copy of model + if num_shared_layers is None: + for param_name in parameter_names: + param = ref_model.get_parameter(param_name) + param.requires_grad = False + return ref_model.eval() + + # identify layer name pattern + if pattern is not None: + pattern = pattern.format(layer=num_shared_layers) + else: + for pattern_candidate in LAYER_PATTERNS: + pattern_candidate = pattern_candidate.format(layer=num_shared_layers) + if any([pattern_candidate in name for name in parameter_names]): + pattern = pattern_candidate + break + + if pattern is None: + raise ValueError("Layer pattern could not be matched.") + + # divide parameters in shared and unshared parameter lists + shared_param_list = [] + unshared_param_list = [] + + shared_parameter = True + for name, param in model.named_parameters(): + if pattern in name: + shared_parameter = False + if shared_parameter: + shared_param_list.append(name) + else: + unshared_param_list.append(name) + + # create reference of the original parameter if they are shared + for param_name in shared_param_list: + param = model.get_parameter(param_name) + param.requires_grad = False + + ref_param = ref_model.get_parameter(param_name) # noqa + ref_param = param # noqa + + # for all other parameters just make sure they don't use gradients + for param_name in unshared_param_list: + param = ref_model.get_parameter(param_name) + param.requires_grad = False + + if pattern is not None and len(unshared_param_list) == 0: + logging.warning("Pattern passed or found, but no layers matched in the model. Check for a typo.") + + return ref_model.eval() diff --git a/trl/trl/trl/models/modeling_sd_base.py b/trl/trl/trl/models/modeling_sd_base.py new file mode 100644 index 0000000000000000000000000000000000000000..0d68380401f5d6f393b5b257dcb036da2d2ca140 --- /dev/null +++ b/trl/trl/trl/models/modeling_sd_base.py @@ -0,0 +1,645 @@ +# Copyright 2023 DDPO-pytorch authors (Kevin Black), The HuggingFace Team, metric-space. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import contextlib +import os +import warnings +from dataclasses import dataclass +from typing import Any, Callable, Dict, List, Optional, Union + +import numpy as np +import torch +from diffusers import DDIMScheduler, StableDiffusionPipeline, UNet2DConditionModel +from diffusers.loaders import AttnProcsLayers +from diffusers.models.attention_processor import LoRAAttnProcessor +from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import rescale_noise_cfg + +from ..core import randn_tensor + + +@dataclass +class DDPOPipelineOutput(object): + """ + Output class for the diffusers pipeline to be finetuned with the DDPO trainer + + Args: + images (`torch.Tensor`): + The generated images. + latents (`List[torch.Tensor]`): + The latents used to generate the images. + log_probs (`List[torch.Tensor]`): + The log probabilities of the latents. + + """ + + images: torch.Tensor + latents: torch.Tensor + log_probs: torch.Tensor + + +@dataclass +class DDPOSchedulerOutput(object): + """ + Output class for the diffusers scheduler to be finetuned with the DDPO trainer + + Args: + latents (`torch.Tensor`): + Predicted sample at the previous timestep. Shape: `(batch_size, num_channels, height, width)` + log_probs (`torch.Tensor`): + Log probability of the above mentioned sample. Shape: `(batch_size)` + """ + + latents: torch.Tensor + log_probs: torch.Tensor + + +class DDPOStableDiffusionPipeline(object): + """ + Main class for the diffusers pipeline to be finetuned with the DDPO trainer + """ + + def __call__(self, *args, **kwargs) -> DDPOPipelineOutput: + raise NotImplementedError + + def scheduler_step(self, *args, **kwargs) -> DDPOSchedulerOutput: + raise NotImplementedError + + @property + def unet(self): + """ + Returns the 2d U-Net model used for diffusion. + """ + raise NotImplementedError + + @property + def vae(self): + """ + Returns the Variational Autoencoder model used from mapping images to and from the latent space + """ + raise NotImplementedError + + @property + def tokenizer(self): + """ + Returns the tokenizer used for tokenizing text inputs + """ + raise NotImplementedError + + @property + def scheduler(self): + """ + Returns the scheduler associated with the pipeline used for the diffusion process + """ + raise NotImplementedError + + @property + def text_encoder(self): + """ + Returns the text encoder used for encoding text inputs + """ + raise NotImplementedError + + @property + def autocast(self): + """ + Returns the autocast context manager + """ + raise NotImplementedError + + def set_progress_bar_config(self, *args, **kwargs): + """ + Sets the progress bar config for the pipeline + """ + raise NotImplementedError + + def save_pretrained(self, *args, **kwargs): + """ + Saves all of the model weights + """ + raise NotImplementedError + + def get_trainable_layers(self, *args, **kwargs): + """ + Returns the trainable parameters of the pipeline + """ + raise NotImplementedError + + def save_checkpoint(self, *args, **kwargs): + """ + Light wrapper around accelerate's register_save_state_pre_hook which is run before saving state + """ + raise NotImplementedError + + def load_checkpoint(self, *args, **kwargs): + """ + Light wrapper around accelerate's register_lad_state_pre_hook which is run before loading state + """ + raise NotImplementedError + + +def _left_broadcast(input_tensor, shape): + """ + As opposed to the default direction of broadcasting (right to left), this function broadcasts + from left to right + Args: + input_tensor (`torch.FloatTensor`): is the tensor to broadcast + shape (`Tuple[int]`): is the shape to broadcast to + """ + input_ndim = input_tensor.ndim + if input_ndim > len(shape): + raise ValueError( + "The number of dimensions of the tensor to broadcast cannot be greater than the length of the shape to broadcast to" + ) + return input_tensor.reshape(input_tensor.shape + (1,) * (len(shape) - input_ndim)).broadcast_to(shape) + + +def _get_variance(self, timestep, prev_timestep): + alpha_prod_t = torch.gather(self.alphas_cumprod, 0, timestep.cpu()).to(timestep.device) + alpha_prod_t_prev = torch.where( + prev_timestep.cpu() >= 0, + self.alphas_cumprod.gather(0, prev_timestep.cpu()), + self.final_alpha_cumprod, + ).to(timestep.device) + beta_prod_t = 1 - alpha_prod_t + beta_prod_t_prev = 1 - alpha_prod_t_prev + + variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev) + + return variance + + +def scheduler_step( + self, + model_output: torch.FloatTensor, + timestep: int, + sample: torch.FloatTensor, + eta: float = 0.0, + use_clipped_model_output: bool = False, + generator=None, + prev_sample: Optional[torch.FloatTensor] = None, +) -> DDPOSchedulerOutput: + """ + + Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion + process from the learned model outputs (most often the predicted noise). + + Args: + model_output (`torch.FloatTensor`): direct output from learned diffusion model. + timestep (`int`): current discrete timestep in the diffusion chain. + sample (`torch.FloatTensor`): + current instance of sample being created by diffusion process. + eta (`float`): weight of noise for added noise in diffusion step. + use_clipped_model_output (`bool`): if `True`, compute "corrected" `model_output` from the clipped + predicted original sample. Necessary because predicted original sample is clipped to [-1, 1] when + `self.config.clip_sample` is `True`. If no clipping has happened, "corrected" `model_output` would + coincide with the one provided as input and `use_clipped_model_output` will have not effect. + generator: random number generator. + variance_noise (`torch.FloatTensor`): instead of generating noise for the variance using `generator`, we + can directly provide the noise for the variance itself. This is useful for methods such as + CycleDiffusion. (https://arxiv.org/abs/2210.05559) + + Returns: + `DDPOSchedulerOutput`: the predicted sample at the previous timestep and the log probability of the sample + """ + + if self.num_inference_steps is None: + raise ValueError( + "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" + ) + + # See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf + # Ideally, read DDIM paper in-detail understanding + + # Notation ( -> + # - pred_noise_t -> e_theta(x_t, t) + # - pred_original_sample -> f_theta(x_t, t) or x_0 + # - std_dev_t -> sigma_t + # - eta -> η + # - pred_sample_direction -> "direction pointing to x_t" + # - pred_prev_sample -> "x_t-1" + + # 1. get previous step value (=t-1) + prev_timestep = timestep - self.config.num_train_timesteps // self.num_inference_steps + # to prevent OOB on gather + prev_timestep = torch.clamp(prev_timestep, 0, self.config.num_train_timesteps - 1) + + # 2. compute alphas, betas + alpha_prod_t = self.alphas_cumprod.gather(0, timestep.cpu()) + alpha_prod_t_prev = torch.where( + prev_timestep.cpu() >= 0, + self.alphas_cumprod.gather(0, prev_timestep.cpu()), + self.final_alpha_cumprod, + ) + alpha_prod_t = _left_broadcast(alpha_prod_t, sample.shape).to(sample.device) + alpha_prod_t_prev = _left_broadcast(alpha_prod_t_prev, sample.shape).to(sample.device) + + beta_prod_t = 1 - alpha_prod_t + + # 3. compute predicted original sample from predicted noise also called + # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf + if self.config.prediction_type == "epsilon": + pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) + pred_epsilon = model_output + elif self.config.prediction_type == "sample": + pred_original_sample = model_output + pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5) + elif self.config.prediction_type == "v_prediction": + pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output + pred_epsilon = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample + else: + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or" + " `v_prediction`" + ) + + # 4. Clip or threshold "predicted x_0" + if self.config.thresholding: + pred_original_sample = self._threshold_sample(pred_original_sample) + elif self.config.clip_sample: + pred_original_sample = pred_original_sample.clamp( + -self.config.clip_sample_range, self.config.clip_sample_range + ) + + # 5. compute variance: "sigma_t(η)" -> see formula (16) + # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1) + variance = _get_variance(self, timestep, prev_timestep) + std_dev_t = eta * variance ** (0.5) + std_dev_t = _left_broadcast(std_dev_t, sample.shape).to(sample.device) + + if use_clipped_model_output: + # the pred_epsilon is always re-derived from the clipped x_0 in Glide + pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5) + + # 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf + pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * pred_epsilon + + # 7. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf + prev_sample_mean = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction + + if prev_sample is not None and generator is not None: + raise ValueError( + "Cannot pass both generator and prev_sample. Please make sure that either `generator` or" + " `prev_sample` stays `None`." + ) + + if prev_sample is None: + variance_noise = randn_tensor( + model_output.shape, + generator=generator, + device=model_output.device, + dtype=model_output.dtype, + ) + prev_sample = prev_sample_mean + std_dev_t * variance_noise + + # log prob of prev_sample given prev_sample_mean and std_dev_t + log_prob = ( + -((prev_sample.detach() - prev_sample_mean) ** 2) / (2 * (std_dev_t**2)) + - torch.log(std_dev_t) + - torch.log(torch.sqrt(2 * torch.as_tensor(np.pi))) + ) + # mean along all but batch dimension + log_prob = log_prob.mean(dim=tuple(range(1, log_prob.ndim))) + + return DDPOSchedulerOutput(prev_sample.type(sample.dtype), log_prob) + + +# 1. The output type for call is different as the logprobs are now returned +# 2. An extra method called `scheduler_step` is added which is used to constraint the scheduler output +@torch.no_grad() +def pipeline_step( + self, + prompt: Optional[Union[str, List[str]]] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: int = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + guidance_rescale: float = 0.0, +): + r""" + Function invoked when calling the pipeline for generation. Args: prompt (`str` or `List[str]`, *optional*): The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. instead. height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): The height in pixels of the generated image. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). + guidance_rescale (`float`, *optional*, defaults to 0.7): + Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://arxiv.org/pdf/2305.08891.pdf) `guidance_scale` is defined as `φ` in equation 16. of + [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). + Guidance rescale factor should fix overexposure when using zero terminal SNR. + + Examples: + + Returns: + `DDPOPipelineOutput`: The generated image, the predicted latents used to generate the image and the associated log probabilities + """ + # 0. Default height and width to unet + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + height, + width, + callback_steps, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + ) + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + text_encoder_lora_scale = cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None + prompt_embeds = self._encode_prompt( + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=text_encoder_lora_scale, + ) + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 5. Prepare latent variables + num_channels_latents = self.unet.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 6. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + all_latents = [latents] + all_log_probs = [] + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + return_dict=False, + )[0] + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + if do_classifier_free_guidance and guidance_rescale > 0.0: + # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf + noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale) + + # compute the previous noisy sample x_t -> x_t-1 + scheduler_output = scheduler_step(self.scheduler, noise_pred, t, latents, eta) + latents = scheduler_output.latents + log_prob = scheduler_output.log_probs + + all_latents.append(latents) + all_log_probs.append(log_prob) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) + + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + else: + image = latents + has_nsfw_concept = None + + if has_nsfw_concept is None: + do_denormalize = [True] * image.shape[0] + else: + do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] + + image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) + + # Offload last model to CPU + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.final_offload_hook.offload() + + return DDPOPipelineOutput(image, all_latents, all_log_probs) + + +class DefaultDDPOStableDiffusionPipeline(DDPOStableDiffusionPipeline): + def __init__(self, pretrained_model_name: str, *, pretrained_model_revision: str = "main", use_lora: bool = True): + self.sd_pipeline = StableDiffusionPipeline.from_pretrained( + pretrained_model_name, revision=pretrained_model_revision + ) + + self.use_lora = use_lora + self.pretrained_model = pretrained_model_name + self.pretrained_revision = pretrained_model_revision + + try: + self.sd_pipeline.unet.load_attn_procs(pretrained_model_name, revision=pretrained_model_revision) + self.use_lora = True + except OSError: + if use_lora: + warnings.warn( + "If you are aware that the pretrained model has no lora weights to it, ignore this message. " + "Otherwise please check the if `pytorch_lora_weights.safetensors` exists in the model folder." + ) + + self.sd_pipeline.scheduler = DDIMScheduler.from_config(self.sd_pipeline.scheduler.config) + self.sd_pipeline.safety_checker = None + + # memory optimization + self.sd_pipeline.vae.requires_grad_(False) + self.sd_pipeline.text_encoder.requires_grad_(False) + self.sd_pipeline.unet.requires_grad_(not self.use_lora) + + def __call__(self, *args, **kwargs) -> DDPOPipelineOutput: + return pipeline_step(self.sd_pipeline, *args, **kwargs) + + def scheduler_step(self, *args, **kwargs) -> DDPOSchedulerOutput: + return scheduler_step(self.sd_pipeline.scheduler, *args, **kwargs) + + @property + def unet(self): + return self.sd_pipeline.unet + + @property + def vae(self): + return self.sd_pipeline.vae + + @property + def tokenizer(self): + return self.sd_pipeline.tokenizer + + @property + def scheduler(self): + return self.sd_pipeline.scheduler + + @property + def text_encoder(self): + return self.sd_pipeline.text_encoder + + @property + def autocast(self): + return contextlib.nullcontext if self.use_lora else None + + def save_pretrained(self, output_dir): + if self.use_lora: + self.sd_pipeline.unet.save_attn_procs(output_dir) + self.sd_pipeline.save_pretrained(output_dir) + + def set_progress_bar_config(self, *args, **kwargs): + self.sd_pipeline.set_progress_bar_config(*args, **kwargs) + + def get_trainable_layers(self): + if self.use_lora: + # Set correct lora layers + lora_attn_procs = {} + for name in self.sd_pipeline.unet.attn_processors.keys(): + cross_attention_dim = ( + None if name.endswith("attn1.processor") else self.sd_pipeline.unet.config.cross_attention_dim + ) + if name.startswith("mid_block"): + hidden_size = self.sd_pipeline.unet.config.block_out_channels[-1] + elif name.startswith("up_blocks"): + block_id = int(name[len("up_blocks.")]) + hidden_size = list(reversed(self.sd_pipeline.unet.config.block_out_channels))[block_id] + elif name.startswith("down_blocks"): + block_id = int(name[len("down_blocks.")]) + hidden_size = self.sd_pipeline.unet.config.block_out_channels[block_id] + + lora_attn_procs[name] = LoRAAttnProcessor( + hidden_size=hidden_size, cross_attention_dim=cross_attention_dim + ) + self.sd_pipeline.unet.set_attn_processor(lora_attn_procs) + return AttnProcsLayers(self.sd_pipeline.unet.attn_processors) + else: + return self.sd_pipeline.unet + + def save_checkpoint(self, models, weights, output_dir): + if len(models) != 1: + raise ValueError("Given how the trainable params were set, this should be of length 1") + if self.use_lora and isinstance(models[0], AttnProcsLayers): + self.sd_pipeline.unet.save_attn_procs(output_dir) + elif not self.use_lora and isinstance(models[0], UNet2DConditionModel): + models[0].save_pretrained(os.path.join(output_dir, "unet")) + else: + raise ValueError(f"Unknown model type {type(models[0])}") + + def load_checkpoint(self, models, input_dir): + if len(models) != 1: + raise ValueError("Given how the trainable params were set, this should be of length 1") + if self.use_lora and isinstance(models[0], AttnProcsLayers): + tmp_unet = UNet2DConditionModel.from_pretrained( + self.pretrained_model, + revision=self.pretrained_revision, + subfolder="unet", + ) + tmp_unet.load_attn_procs(input_dir) + models[0].load_state_dict(AttnProcsLayers(tmp_unet.attn_processors).state_dict()) + del tmp_unet + elif not self.use_lora and isinstance(models[0], UNet2DConditionModel): + load_model = UNet2DConditionModel.from_pretrained(input_dir, subfolder="unet") + models[0].register_to_config(**load_model.config) + models[0].load_state_dict(load_model.state_dict()) + del load_model + else: + raise ValueError(f"Unknown model type {type(models[0])}") diff --git a/trl/trl/trl/models/modeling_value_head.py b/trl/trl/trl/models/modeling_value_head.py new file mode 100644 index 0000000000000000000000000000000000000000..2771cc6ce2f5daad0a28933b94f39e670bd9350a --- /dev/null +++ b/trl/trl/trl/models/modeling_value_head.py @@ -0,0 +1,429 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import torch +import torch.nn as nn +from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM + +from .modeling_base import PreTrainedModelWrapper + + +class ValueHead(nn.Module): + r""" + The ValueHead class implements a head for GPT2 that returns a scalar for each output token. + """ + + def __init__(self, config, **kwargs): + super().__init__() + if not hasattr(config, "summary_dropout_prob"): + summary_dropout_prob = kwargs.pop("summary_dropout_prob", 0.1) + else: + summary_dropout_prob = config.summary_dropout_prob + + self.dropout = nn.Dropout(summary_dropout_prob) if summary_dropout_prob else nn.Identity() + + # some models such as OPT have a projection layer before the word embeddings - e.g. OPT-350m + if hasattr(config, "hidden_size"): + hidden_size = config.hidden_size + if hasattr(config, "word_embed_proj_dim"): + hidden_size = config.word_embed_proj_dim + elif hasattr(config, "is_encoder_decoder"): + if config.is_encoder_decoder and hasattr(config, "decoder"): + if hasattr(config.decoder, "hidden_size"): + hidden_size = config.decoder.hidden_size + + self.summary = nn.Linear(hidden_size, 1) + + self.flatten = nn.Flatten() + + def forward(self, hidden_states): + output = self.dropout(hidden_states) + + # For now force upcast in fp32 if needed. Let's keep the + # output in fp32 for numerical stability. + if output.dtype != self.summary.weight.dtype: + output = output.to(self.summary.weight.dtype) + + output = self.summary(output) + return output + + +class AutoModelForCausalLMWithValueHead(PreTrainedModelWrapper): + r""" + An autoregressive model with a value head in addition to the language model head. + This class inherits from `~trl.PreTrainedModelWrapper` and wraps a + `transformers.PreTrainedModel` class. The wrapper class supports classic functions + such as `from_pretrained`, `push_to_hub` and `generate`. To call a method of the wrapped + model, simply manipulate the `pretrained_model` attribute of this class. + + Class attributes: + - **transformers_parent_class** (`transformers.PreTrainedModel`) -- The parent class of the wrapped model. This + should be set to `transformers.AutoModelForCausalLM` for this class. + - **lm_head_namings** (`tuple`) -- A tuple of strings that are used to identify the language model head of the + wrapped model. This is set to `("lm_head", "embed_out")` for this class but can be changed for other models + in the future + - **supported_args** (`tuple`) -- A tuple of strings that are used to identify the arguments that are supported + by the `ValueHead` class. Currently, the supported args are: + - **summary_dropout_prob** (`float`, `optional`, defaults to `None`) -- The dropout probability for the + `ValueHead` class. + - **v_head_initializer_range** (`float`, `optional`, defaults to `0.2`) -- The initializer range for the + `ValueHead` if a specific initialization strategy is selected. + - **v_head_init_strategy** (`str`, `optional`, defaults to `None`) -- The initialization strategy for the + `ValueHead`. Currently, the supported strategies are: + - **`None`** -- Initializes the weights of the `ValueHead` with a random distribution. This is the default + strategy. + - **"normal"** -- Initializes the weights of the `ValueHead` with a normal distribution. + + """ + transformers_parent_class = AutoModelForCausalLM + lm_head_namings = ["lm_head", "embed_out"] + supported_args = ( + "summary_dropout_prob", + "v_head_initializer_range", + "v_head_init_strategy", + ) + + def __init__(self, pretrained_model, **kwargs): + r""" + Initializes the model. + + Args: + pretrained_model (`transformers.PreTrainedModel`): + The model to wrap. It should be a causal language model such as GPT2. + or any model mapped inside the `AutoModelForCausalLM` class. + kwargs (`dict`, `optional`): + Additional keyword arguments, that are passed to the `ValueHead` class. + """ + super().__init__(pretrained_model, **kwargs) + v_head_kwargs, _, _ = self._split_kwargs(kwargs) + + if not any(hasattr(self.pretrained_model, attribute) for attribute in self.lm_head_namings): + raise ValueError("The model does not have a language model head, please use a model that has one.") + + self.v_head = ValueHead(self.pretrained_model.config, **v_head_kwargs) + + self._init_weights(**v_head_kwargs) + + def _init_weights(self, **kwargs): + r""" + Initializes the weights of the value head. The default initialization strategy is random. + Users can pass a different initialization strategy by passing the `v_head_init_strategy` argument + when calling `.from_pretrained`. Supported strategies are: + - `normal`: initializes the weights with a normal distribution. + + Args: + **kwargs (`dict`, `optional`): + Additional keyword arguments, that are passed to the `ValueHead` class. These arguments + can contain the `v_head_init_strategy` argument as well as the `v_head_initializer_range` + argument. + """ + initializer_range = kwargs.pop("v_head_initializer_range", 0.2) + # random init by default + init_strategy = kwargs.pop("v_head_init_strategy", None) + if init_strategy is None: + # do nothing + pass + elif init_strategy == "normal": + self.v_head.summary.weight.data.normal_(mean=0.0, std=initializer_range) + self.v_head.summary.bias.data.zero_() + + def forward( + self, + input_ids=None, + past_key_values=None, + attention_mask=None, + **kwargs, + ): + r""" + Applies a forward pass to the wrapped model and returns the logits of the value head. + + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. + past_key_values (`tuple(tuple(torch.FloatTensor))`, `optional`): + Contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model + (see `past_key_values` input) to speed up sequential decoding. + attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, `optional`): + Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``: + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + kwargs (`dict`, `optional`): + Additional keyword arguments, that are passed to the wrapped model. + """ + kwargs["output_hidden_states"] = True # this had already been set in the LORA / PEFT examples + kwargs["past_key_values"] = past_key_values + + if self.is_peft_model and self.pretrained_model.active_peft_config.peft_type == "PREFIX_TUNING": + kwargs.pop("past_key_values") + + base_model_output = self.pretrained_model( + input_ids=input_ids, + attention_mask=attention_mask, + **kwargs, + ) + + last_hidden_state = base_model_output.hidden_states[-1] + lm_logits = base_model_output.logits + loss = base_model_output.loss + + if last_hidden_state.device != self.v_head.summary.weight.device: + last_hidden_state = last_hidden_state.to(self.v_head.summary.weight.device) + + value = self.v_head(last_hidden_state).squeeze(-1) + + # force upcast in fp32 if logits are in half-precision + if lm_logits.dtype != torch.float32: + lm_logits = lm_logits.float() + + return (lm_logits, loss, value) + + def generate(self, *args, **kwargs): + r""" + A simple wrapper around the `generate` method of the wrapped model. + Please refer to the [`generate`](https://huggingface.co/docs/transformers/internal/generation_utils) + method of the wrapped model for more information about the supported arguments. + + Args: + *args (`list`, *optional*): + Positional arguments passed to the `generate` method of the wrapped model. + **kwargs (`dict`, *optional*): + Keyword arguments passed to the `generate` method of the wrapped model. + """ + return self.pretrained_model.generate(*args, **kwargs) + + def state_dict(self, *args, **kwargs): + r""" + Returns the state dictionary of the model. We add the state dictionary of the value head + to the state dictionary of the wrapped model by prepending the key with `v_head.`. + """ + if not self.is_peft_model: + pretrained_model_state_dict = self.pretrained_model.state_dict(*args, **kwargs) + else: + # if it is a peft model, only save the v_head + pretrained_model_state_dict = {} + + v_head_state_dict = self.v_head.state_dict(*args, **kwargs) + for k, v in v_head_state_dict.items(): + pretrained_model_state_dict[f"v_head.{k}"] = v + return pretrained_model_state_dict + + def push_to_hub(self, *args, **kwargs): + setattr(self.pretrained_model, "v_head", self.v_head) + + return self.pretrained_model.push_to_hub(*args, **kwargs) + + def post_init(self, state_dict): + r""" + We add the state dictionary of the value head to the state dictionary of the wrapped model + by prepending the key with `v_head.`. This function removes the `v_head.` prefix from the + keys of the value head state dictionary. + """ + for k in list(state_dict.keys()): + if "v_head." in k: + state_dict[k.replace("v_head.", "")] = state_dict.pop(k) + self.v_head.load_state_dict(state_dict, strict=False) + del state_dict + + if hasattr(self.pretrained_model, "hf_device_map"): + if ( + "cpu" in self.pretrained_model.hf_device_map.values() + or "disk" in self.pretrained_model.hf_device_map.values() + ): + raise ValueError( + "The model is offloaded on CPU or disk - CPU & disk offloading is not supported for ValueHead models." + ) + + first_device = list(set(self.pretrained_model.hf_device_map.values()))[0] + + self.v_head = self.v_head.to(first_device) + + def set_device_hook(module, input, outputs): + new_output = () + for output in outputs: + if isinstance(output, torch.Tensor): + new_output += (output.to(first_device),) + else: + new_output += (output,) + return new_output + + self.register_forward_hook(set_device_hook) + + self.is_sequential_parallel = True + + +class AutoModelForSeq2SeqLMWithValueHead(PreTrainedModelWrapper): + r""" + A seq2seq model with a value head in addition to the language model head. + This class inherits from `~trl.PreTrainedModelWrapper` and wraps a + `transformers.PreTrainedModel` class. The wrapper class supports classic functions + such as `from_pretrained` and `push_to_hub` and also provides some additional + functionalities such as `generate`. + + Args: + pretrained_model (`transformers.PreTrainedModel`): + The model to wrap. It should be a causal language model such as GPT2. + or any model mapped inside the `AutoModelForSeq2SeqLM` class. + kwargs: + Additional keyword arguments passed along to the `ValueHead` class. + """ + transformers_parent_class = AutoModelForSeq2SeqLM + lm_head_namings = ["lm_head", "embed_out", "output_projection"] + supported_args = ( + "summary_dropout_prob", + "v_head_initializer_range", + "v_head_init_strategy", + ) + + def __init__(self, pretrained_model, **kwargs): + super().__init__(pretrained_model, **kwargs) + v_head_kwargs, _, _ = self._split_kwargs(kwargs) + self.is_encoder_decoder = True + + if not self._has_lm_head(): + raise ValueError("The model does not have a language model head, please use a model that has one.") + + self.v_head = ValueHead(self.pretrained_model.config, **v_head_kwargs) + + self._init_weights(**v_head_kwargs) + + def _has_lm_head(self): + # check module names of all modules inside `pretrained_model` to find the language model head + for name, module in self.pretrained_model.named_modules(): + if any(attribute in name for attribute in self.lm_head_namings): + return True + return False + + def post_init(self, state_dict): + r""" + We add the state dictionary of the value head to the state dictionary of the wrapped model + by prepending the key with `v_head.`. This function removes the `v_head.` prefix from the + keys of the value head state dictionary. + """ + for k in list(state_dict.keys()): + if "v_head." in k: + state_dict[k.replace("v_head.", "")] = state_dict.pop(k) + self.v_head.load_state_dict(state_dict, strict=False) + del state_dict + + if hasattr(self.pretrained_model, "hf_device_map"): + if ( + "cpu" in self.pretrained_model.hf_device_map.values() + or "disk" in self.pretrained_model.hf_device_map.values() + ): + raise ValueError( + "The model is offloaded on CPU or disk - CPU & disk offloading is not supported for ValueHead models." + ) + + # get the lm_head device + for name, module in self.pretrained_model.named_modules(): + if any(attribute in name for attribute in self.lm_head_namings): + lm_head_device = module.weight.device + break + + # put v_head on the same device as the lm_head to avoid issues + self.v_head = self.v_head.to(lm_head_device) + + def set_device_hook(module, input, outputs): + r""" + A hook that sets the device of the output of the model to the device of the first + parameter of the model. + + Args: + module (`nn.Module`): + The module to which the hook is attached. + input (`tuple`): + The input to the module. + outputs (`tuple`): + The output of the module. + """ + new_output = () + for output in outputs: + if isinstance(output, torch.Tensor): + new_output += (output.to(lm_head_device),) + else: + new_output += (output,) + return new_output + + self.register_forward_hook(set_device_hook) + self.is_sequential_parallel = True + + def state_dict(self, *args, **kwargs): + r""" + Returns the state dictionary of the model. We add the state dictionary of the value head + to the state dictionary of the wrapped model by prepending the key with `v_head.`. + """ + if not self.is_peft_model: + pretrained_model_state_dict = self.pretrained_model.state_dict(*args, **kwargs) + else: + # if it is a peft model, only save the v_head + pretrained_model_state_dict = {} + + v_head_state_dict = self.v_head.state_dict(*args, **kwargs) + for k, v in v_head_state_dict.items(): + pretrained_model_state_dict[f"v_head.{k}"] = v + return pretrained_model_state_dict + + def push_to_hub(self, *args, **kwargs): + setattr(self.pretrained_model, "v_head", self.v_head) + + return self.pretrained_model.push_to_hub(*args, **kwargs) + + def _init_weights(self, **kwargs): + r""" + We initialize the weights of the value head. + """ + initializer_range = kwargs.pop("v_head_initializer_range", 0.2) + # random init by default + init_strategy = kwargs.pop("v_head_init_strategy", None) + if init_strategy is None: + # do nothing + pass + elif init_strategy == "normal": + self.v_head.summary.weight.data.normal_(mean=0.0, std=initializer_range) + self.v_head.summary.bias.data.zero_() + + def forward( + self, + input_ids=None, + past_key_values=None, + attention_mask=None, + **kwargs, + ): + kwargs["past_key_values"] = past_key_values + if self.is_peft_model and self.pretrained_model.active_peft_config.peft_type == "PREFIX_TUNING": + kwargs.pop("past_key_values") + + base_model_output = self.pretrained_model( + input_ids=input_ids, + attention_mask=attention_mask, + output_hidden_states=True, # We force the model to output hidden states + **kwargs, + ) + + last_hidden_state = base_model_output.decoder_hidden_states[-1] + lm_logits = base_model_output.logits + loss = base_model_output.loss + + value = self.v_head(last_hidden_state).squeeze(-1) + + # force upcast in fp32 if logits are in half-precision + if lm_logits.dtype != torch.float32: + lm_logits = lm_logits.float() + + return (lm_logits, loss, value) + + def generate(self, *args, **kwargs): + r""" + We call `generate` on the wrapped model. + """ + return self.pretrained_model.generate(*args, **kwargs) diff --git a/trl/trl/trl/trainer/__init__.py b/trl/trl/trl/trainer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e81705fbc2adc67ee2682977a0f374ab08dea530 --- /dev/null +++ b/trl/trl/trl/trainer/__init__.py @@ -0,0 +1,44 @@ +# flake8: noqa + +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# There is a circular import in the PPOTrainer if we let isort sort these +# isort: off +from .utils import ( + AdaptiveKLController, + FixedKLController, + ConstantLengthDataset, + DataCollatorForCompletionOnlyLM, + RunningMoments, + disable_dropout_in_model, +) + +# isort: on + +from ..import_utils import is_diffusers_available +from .base import BaseTrainer +from .ddpo_config import DDPOConfig + + +if is_diffusers_available(): + from .ddpo_trainer import DDPOTrainer + +from .dpo_trainer import DPOTrainer +from .iterative_sft_trainer import IterativeSFTTrainer +from .ppo_config import PPOConfig +from .ppo_trainer import PPOTrainer +from .reward_trainer import RewardTrainer, compute_accuracy +from .sft_trainer import SFTTrainer +from .training_configs import RewardConfig diff --git a/trl/trl/trl/trainer/base.py b/trl/trl/trl/trainer/base.py new file mode 100644 index 0000000000000000000000000000000000000000..f0314cb987fcf5a520ed1ab1ad0a7eb107f18acc --- /dev/null +++ b/trl/trl/trl/trainer/base.py @@ -0,0 +1,46 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from huggingface_hub import PyTorchModelHubMixin + + +class BaseTrainer(PyTorchModelHubMixin): + r""" + Base class for all trainers - this base class implements the basic functions that we + need for a trainer. + + The trainer needs to have the following functions: + - step: takes in a batch of data and performs a step of training + - loss: takes in a batch of data and returns the loss + - compute_rewards: takes in a batch of data and returns the rewards + - _build_models_and_tokenizer: builds the models and tokenizer + - _build_dataset: builds the dataset + Each user is expected to implement their own trainer class that inherits from this base + if they want to use a new training algorithm. + """ + + def __init__(self, config): + self.config = config + + def step(self, *args): + raise NotImplementedError("Not implemented") + + def loss(self, *args): + raise NotImplementedError("Not implemented") + + def compute_rewards(self, *args): + raise NotImplementedError("Not implemented") + + def _save_pretrained(self, save_directory): + raise NotImplementedError("Not implemented") diff --git a/trl/trl/trl/trainer/ddpo_config.py b/trl/trl/trl/trainer/ddpo_config.py new file mode 100644 index 0000000000000000000000000000000000000000..310861381465fcc1a2b73e48712815617057bb84 --- /dev/null +++ b/trl/trl/trl/trainer/ddpo_config.py @@ -0,0 +1,120 @@ +import os +import sys +import warnings +from dataclasses import dataclass, field +from typing import Literal, Optional + +from ..core import flatten_dict +from ..import_utils import is_bitsandbytes_available, is_torchvision_available + + +@dataclass +class DDPOConfig: + """ + Configuration class for DDPOTrainer + """ + + # common parameters + exp_name: str = os.path.basename(sys.argv[0])[: -len(".py")] + """the name of this experiment (by default is the file name without the extension name)""" + run_name: Optional[str] = "" + """Run name for wandb logging and checkpoint saving.""" + seed: int = 0 + """Seed value for random generations""" + log_with: Optional[Literal["wandb", "tensorboard"]] = None + """Log with either 'wandb' or 'tensorboard', check https://huggingface.co/docs/accelerate/usage_guides/tracking for more details""" + tracker_kwargs: dict = field(default_factory=dict) + """Keyword arguments for the tracker (e.g. wandb_project)""" + accelerator_kwargs: dict = field(default_factory=dict) + """Keyword arguments for the accelerator""" + project_kwargs: dict = field(default_factory=dict) + """Keyword arguments for the accelerator project config (e.g. `logging_dir`)""" + tracker_project_name: str = "trl" + """Name of project to use for tracking""" + logdir: str = "logs" + """Top-level logging directory for checkpoint saving.""" + + # hyperparameters + num_epochs: int = 100 + """Number of epochs to train.""" + save_freq: int = 1 + """Number of epochs between saving model checkpoints.""" + num_checkpoint_limit: int = 5 + """Number of checkpoints to keep before overwriting old ones.""" + mixed_precision: str = "fp16" + """Mixed precision training.""" + allow_tf32: bool = True + """Allow tf32 on Ampere GPUs.""" + resume_from: Optional[str] = "" + """Resume training from a checkpoint.""" + sample_num_steps: int = 50 + """Number of sampler inference steps.""" + sample_eta: float = 1.0 + """Eta parameter for the DDIM sampler.""" + sample_guidance_scale: float = 5.0 + """Classifier-free guidance weight.""" + sample_batch_size: int = 1 + """Batch size (per GPU!) to use for sampling.""" + sample_num_batches_per_epoch: int = 2 + """Number of batches to sample per epoch.""" + train_batch_size: int = 1 + """Batch size (per GPU!) to use for training.""" + train_use_8bit_adam: bool = False + """Whether to use the 8bit Adam optimizer from bitsandbytes.""" + train_learning_rate: float = 3e-4 + """Learning rate.""" + train_adam_beta1: float = 0.9 + """Adam beta1.""" + train_adam_beta2: float = 0.999 + """Adam beta2.""" + train_adam_weight_decay: float = 1e-4 + """Adam weight decay.""" + train_adam_epsilon: float = 1e-8 + """Adam epsilon.""" + train_gradient_accumulation_steps: int = 1 + """Number of gradient accumulation steps.""" + train_max_grad_norm: float = 1.0 + """Maximum gradient norm for gradient clipping.""" + train_num_inner_epochs: int = 1 + """Number of inner epochs per outer epoch.""" + train_cfg: bool = True + """Whether or not to use classifier-free guidance during training.""" + train_adv_clip_max: float = 5 + """Clip advantages to the range.""" + train_clip_range: float = 1e-4 + """The PPO clip range.""" + train_timestep_fraction: float = 1.0 + """The fraction of timesteps to train on.""" + per_prompt_stat_tracking: bool = False + """Whether to track statistics for each prompt separately.""" + per_prompt_stat_tracking_buffer_size: int = 16 + """Number of reward values to store in the buffer for each prompt.""" + per_prompt_stat_tracking_min_count: int = 16 + """The minimum number of reward values to store in the buffer.""" + async_reward_computation: bool = False + """Whether to compute rewards asynchronously.""" + max_workers: int = 2 + """The maximum number of workers to use for async reward computation.""" + negative_prompts: Optional[str] = "" + """Comma-separated list of prompts to use as negative examples.""" + + def to_dict(self): + output_dict = {} + for key, value in self.__dict__.items(): + output_dict[key] = value + return flatten_dict(output_dict) + + def __post_init__(self): + if self.log_with not in ["wandb", "tensorboard"]: + warnings.warn( + ("Accelerator tracking only supports image logging if `log_with` is set to 'wandb' or 'tensorboard'.") + ) + + if self.log_with == "wandb" and not is_torchvision_available(): + warnings.warn("Wandb image logging requires torchvision to be installed") + + if self.train_use_8bit_adam and not is_bitsandbytes_available(): + raise ImportError( + "You need to install bitsandbytes to use 8bit Adam. " + "You can install it with `pip install bitsandbytes`." + ) diff --git a/trl/trl/trl/trainer/ddpo_trainer.py b/trl/trl/trl/trainer/ddpo_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..0f1cee12f1c02917e634b11f358ba839a39fd910 --- /dev/null +++ b/trl/trl/trl/trainer/ddpo_trainer.py @@ -0,0 +1,576 @@ +# Copyright 2023 DDPO-pytorch authors (Kevin Black), metric-space, The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +from collections import defaultdict +from concurrent import futures +from typing import Any, Callable, Optional, Tuple +from warnings import warn + +import torch +from accelerate import Accelerator +from accelerate.logging import get_logger +from accelerate.utils import ProjectConfiguration, set_seed + +from ..models import DDPOStableDiffusionPipeline +from . import BaseTrainer, DDPOConfig +from .utils import PerPromptStatTracker + + +logger = get_logger(__name__) + + +class DDPOTrainer(BaseTrainer): + """ + The DDPOTrainer uses Deep Diffusion Policy Optimization to optimise diffusion models. + Note, this trainer is heavily inspired by the work here: https://github.com/kvablack/ddpo-pytorch + As of now only Stable Diffusion based pipelines are supported + + Attributes: + **config** (`DDPOConfig`) -- Configuration object for DDPOTrainer. Check the documentation of `PPOConfig` for more + details. + **reward_function** (Callable[[torch.Tensor, Tuple[str], Tuple[Any]], torch.Tensor]) -- Reward function to be used + **prompt_function** (Callable[[], Tuple[str, Any]]) -- Function to generate prompts to guide model + **sd_pipeline** (`DDPOStableDiffusionPipeline`) -- Stable Diffusion pipeline to be used for training. + **image_samples_hook** (Optional[Callable[[Any, Any, Any], Any]]) -- Hook to be called to log images + """ + + def __init__( + self, + config: DDPOConfig, + reward_function: Callable[[torch.Tensor, Tuple[str], Tuple[Any]], torch.Tensor], + prompt_function: Callable[[], Tuple[str, Any]], + sd_pipeline: DDPOStableDiffusionPipeline, + image_samples_hook: Optional[Callable[[Any, Any, Any], Any]] = None, + ): + if image_samples_hook is None: + warn("No image_samples_hook provided; no images will be logged") + + self.prompt_fn = prompt_function + self.reward_fn = reward_function + self.config = config + self.image_samples_callback = image_samples_hook + + accelerator_project_config = ProjectConfiguration(**self.config.project_kwargs) + + if self.config.resume_from: + self.config.resume_from = os.path.normpath(os.path.expanduser(self.config.resume_from)) + if "checkpoint_" not in os.path.basename(self.config.resume_from): + # get the most recent checkpoint in this directory + checkpoints = list( + filter( + lambda x: "checkpoint_" in x, + os.listdir(self.config.resume_from), + ) + ) + if len(checkpoints) == 0: + raise ValueError(f"No checkpoints found in {self.config.resume_from}") + checkpoint_numbers = sorted([int(x.split("_")[-1]) for x in checkpoints]) + self.config.resume_from = os.path.join( + self.config.resume_from, + f"checkpoint_{checkpoint_numbers[-1]}", + ) + + accelerator_project_config.iteration = checkpoint_numbers[-1] + 1 + + # number of timesteps within each trajectory to train on + self.num_train_timesteps = int(self.config.sample_num_steps * self.config.train_timestep_fraction) + + self.accelerator = Accelerator( + log_with=self.config.log_with, + mixed_precision=self.config.mixed_precision, + project_config=accelerator_project_config, + # we always accumulate gradients across timesteps; we want config.train.gradient_accumulation_steps to be the + # number of *samples* we accumulate across, so we need to multiply by the number of training timesteps to get + # the total number of optimizer steps to accumulate across. + gradient_accumulation_steps=self.config.train_gradient_accumulation_steps * self.num_train_timesteps, + **self.config.accelerator_kwargs, + ) + + is_okay, message = self._config_check() + if not is_okay: + raise ValueError(message) + + is_using_tensorboard = config.log_with is not None and config.log_with == "tensorboard" + + if self.accelerator.is_main_process: + self.accelerator.init_trackers( + self.config.tracker_project_name, + config=dict(ddpo_trainer_config=config.to_dict()) if not is_using_tensorboard else config.to_dict(), + init_kwargs=self.config.tracker_kwargs, + ) + + logger.info(f"\n{config}") + + set_seed(self.config.seed, device_specific=True) + + self.sd_pipeline = sd_pipeline + + self.sd_pipeline.set_progress_bar_config( + position=1, + disable=not self.accelerator.is_local_main_process, + leave=False, + desc="Timestep", + dynamic_ncols=True, + ) + + # For mixed precision training we cast all non-trainable weights (vae, non-lora text_encoder and non-lora unet) to half-precision + # as these weights are only used for inference, keeping weights in full precision is not required. + if self.accelerator.mixed_precision == "fp16": + inference_dtype = torch.float16 + elif self.accelerator.mixed_precision == "bf16": + inference_dtype = torch.bfloat16 + else: + inference_dtype = torch.float32 + + self.sd_pipeline.vae.to(self.accelerator.device, dtype=inference_dtype) + self.sd_pipeline.text_encoder.to(self.accelerator.device, dtype=inference_dtype) + self.sd_pipeline.unet.to(self.accelerator.device, dtype=inference_dtype) + + trainable_layers = self.sd_pipeline.get_trainable_layers() + + self.accelerator.register_save_state_pre_hook(self._save_model_hook) + self.accelerator.register_load_state_pre_hook(self._load_model_hook) + + # Enable TF32 for faster training on Ampere GPUs, + # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices + if self.config.allow_tf32: + torch.backends.cuda.matmul.allow_tf32 = True + + self.optimizer = self._setup_optimizer(trainable_layers.parameters()) + + self.neg_prompt_embed = self.sd_pipeline.text_encoder( + self.sd_pipeline.tokenizer( + [""] if self.config.negative_prompts is None else self.config.negative_prompts, + return_tensors="pt", + padding="max_length", + truncation=True, + max_length=self.sd_pipeline.tokenizer.model_max_length, + ).input_ids.to(self.accelerator.device) + )[0] + + if config.per_prompt_stat_tracking: + self.stat_tracker = PerPromptStatTracker( + config.per_prompt_stat_tracking_buffer_size, + config.per_prompt_stat_tracking_min_count, + ) + + # NOTE: for some reason, autocast is necessary for non-lora training but for lora training it isn't necessary and it uses + # more memory + self.autocast = self.sd_pipeline.autocast or self.accelerator.autocast + + self.trainable_layers, self.optimizer = self.accelerator.prepare(trainable_layers, self.optimizer) + + if self.config.async_reward_computation: + self.executor = futures.ThreadPoolExecutor(max_workers=config.max_workers) + + if config.resume_from: + logger.info(f"Resuming from {config.resume_from}") + self.accelerator.load_state(config.resume_from) + self.first_epoch = int(config.resume_from.split("_")[-1]) + 1 + else: + self.first_epoch = 0 + + def compute_rewards(self, prompt_image_pairs, is_async=False): + if not is_async: + rewards = [] + for images, prompts, prompt_metadata in prompt_image_pairs: + reward, reward_metadata = self.reward_fn(images, prompts, prompt_metadata) + rewards.append( + ( + torch.as_tensor(reward, device=self.accelerator.device), + reward_metadata, + ) + ) + else: + rewards = self.executor.map(lambda x: self.reward_fn(*x), prompt_image_pairs) + rewards = [ + (torch.as_tensor(reward.result(), device=self.accelerator.device), reward_metadata.result()) + for reward, reward_metadata in rewards + ] + + return zip(*rewards) + + def step(self, epoch: int, global_step: int): + """ + Perform a single step of training. + + Args: + epoch (int): The current epoch. + global_step (int): The current global step. + + Side Effects: + - Model weights are updated + - Logs the statistics to the accelerator trackers. + - If `self.image_samples_callback` is not None, it will be called with the prompt_image_pairs, global_step, and the accelerator tracker. + + Returns: + global_step (int): The updated global step. + + """ + samples, prompt_image_data = self._generate_samples( + iterations=self.config.sample_num_batches_per_epoch, + batch_size=self.config.sample_batch_size, + ) + + # collate samples into dict where each entry has shape (num_batches_per_epoch * sample.batch_size, ...) + samples = {k: torch.cat([s[k] for s in samples]) for k in samples[0].keys()} + rewards, rewards_metadata = self.compute_rewards( + prompt_image_data, is_async=self.config.async_reward_computation + ) + + for i, image_data in enumerate(prompt_image_data): + image_data.extend([rewards[i], rewards_metadata[i]]) + + if self.image_samples_callback is not None: + self.image_samples_callback(prompt_image_data, global_step, self.accelerator.trackers[0]) + + rewards = torch.cat(rewards) + rewards = self.accelerator.gather(rewards).cpu().numpy() + + self.accelerator.log( + { + "reward": rewards, + "epoch": epoch, + "reward_mean": rewards.mean(), + "reward_std": rewards.std(), + }, + step=global_step, + ) + + if self.config.per_prompt_stat_tracking: + # gather the prompts across processes + prompt_ids = self.accelerator.gather(samples["prompt_ids"]).cpu().numpy() + prompts = self.sd_pipeline.tokenizer.batch_decode(prompt_ids, skip_special_tokens=True) + advantages = self.stat_tracker.update(prompts, rewards) + else: + advantages = (rewards - rewards.mean()) / (rewards.std() + 1e-8) + + # ungather advantages; keep the entries corresponding to the samples on this process + samples["advantages"] = ( + torch.as_tensor(advantages) + .reshape(self.accelerator.num_processes, -1)[self.accelerator.process_index] + .to(self.accelerator.device) + ) + + del samples["prompt_ids"] + + total_batch_size, num_timesteps = samples["timesteps"].shape + + for inner_epoch in range(self.config.train_num_inner_epochs): + # shuffle samples along batch dimension + perm = torch.randperm(total_batch_size, device=self.accelerator.device) + samples = {k: v[perm] for k, v in samples.items()} + + # shuffle along time dimension independently for each sample + # still trying to understand the code below + perms = torch.stack( + [torch.randperm(num_timesteps, device=self.accelerator.device) for _ in range(total_batch_size)] + ) + + for key in ["timesteps", "latents", "next_latents", "log_probs"]: + samples[key] = samples[key][ + torch.arange(total_batch_size, device=self.accelerator.device)[:, None], + perms, + ] + + original_keys = samples.keys() + original_values = samples.values() + # rebatch them as user defined train_batch_size is different from sample_batch_size + reshaped_values = [v.reshape(-1, self.config.train_batch_size, *v.shape[1:]) for v in original_values] + + # Transpose the list of original values + transposed_values = zip(*reshaped_values) + # Create new dictionaries for each row of transposed values + samples_batched = [dict(zip(original_keys, row_values)) for row_values in transposed_values] + + self.sd_pipeline.unet.train() + global_step = self._train_batched_samples(inner_epoch, epoch, global_step, samples_batched) + # ensure optimization step at the end of the inner epoch + if not self.accelerator.sync_gradients: + raise ValueError( + "Optimization step should have been performed by this point. Please check calculated gradient accumulation settings." + ) + + if epoch != 0 and epoch % self.config.save_freq == 0 and self.accelerator.is_main_process: + self.accelerator.save_state() + + return global_step + + def calculate_loss(self, latents, timesteps, next_latents, log_probs, advantages, embeds): + """ + Calculate the loss for a batch of an unpacked sample + + Args: + latents (torch.Tensor): + The latents sampled from the diffusion model, shape: [batch_size, num_steps, ...] + timesteps (torch.Tensor): + The timesteps sampled from the diffusion model, shape: [batch_size] + next_latents (torch.Tensor): + The next latents sampled from the diffusion model, shape: [batch_size, num_steps, ...] + log_probs (torch.Tensor): + The log probabilities of the latents, shape: [batch_size] + advantages (torch.Tensor): + The advantages of the latents, shape: [batch_size] + embeds (torch.Tensor): + The embeddings of the prompts, shape: [2*batch_size or batch_size, ...] + Note: the "or" is because if train_cfg is True, the expectation is that negative prompts are concatenated to the embeds + + Returns: + loss (torch.Tensor), approx_kl (torch.Tensor), clipfrac (torch.Tensor) + (all of these are of shape (1,)) + """ + with self.autocast(): + if self.config.train_cfg: + noise_pred = self.sd_pipeline.unet( + torch.cat([latents] * 2), + torch.cat([timesteps] * 2), + embeds, + ).sample + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.config.sample_guidance_scale * ( + noise_pred_text - noise_pred_uncond + ) + else: + noise_pred = self.sd_pipeline.unet( + latents, + timesteps, + embeds, + ).sample + # compute the log prob of next_latents given latents under the current model + + scheduler_step_output = self.sd_pipeline.scheduler_step( + noise_pred, + timesteps, + latents, + eta=self.config.sample_eta, + prev_sample=next_latents, + ) + + log_prob = scheduler_step_output.log_probs + + advantages = torch.clamp( + advantages, + -self.config.train_adv_clip_max, + self.config.train_adv_clip_max, + ) + + ratio = torch.exp(log_prob - log_probs) + + loss = self.loss(advantages, self.config.train_clip_range, ratio) + + approx_kl = 0.5 * torch.mean((log_prob - log_probs) ** 2) + + clipfrac = torch.mean((torch.abs(ratio - 1.0) > self.config.train_clip_range).float()) + + return loss, approx_kl, clipfrac + + def loss( + self, + advantages: torch.Tensor, + clip_range: float, + ratio: torch.Tensor, + ): + unclipped_loss = -advantages * ratio + clipped_loss = -advantages * torch.clamp( + ratio, + 1.0 - clip_range, + 1.0 + clip_range, + ) + return torch.mean(torch.maximum(unclipped_loss, clipped_loss)) + + def _setup_optimizer(self, trainable_layers_parameters): + if self.config.train_use_8bit_adam: + import bitsandbytes + + optimizer_cls = bitsandbytes.optim.AdamW8bit + else: + optimizer_cls = torch.optim.AdamW + + return optimizer_cls( + trainable_layers_parameters, + lr=self.config.train_learning_rate, + betas=(self.config.train_adam_beta1, self.config.train_adam_beta2), + weight_decay=self.config.train_adam_weight_decay, + eps=self.config.train_adam_epsilon, + ) + + def _save_model_hook(self, models, weights, output_dir): + self.sd_pipeline.save_checkpoint(models, weights, output_dir) + weights.pop() # ensures that accelerate doesn't try to handle saving of the model + + def _load_model_hook(self, models, input_dir): + self.sd_pipeline.load_checkpoint(models, input_dir) + models.pop() # ensures that accelerate doesn't try to handle loading of the model + + def _generate_samples(self, iterations, batch_size): + """ + Generate samples from the model + + Args: + iterations (int): Number of iterations to generate samples for + batch_size (int): Batch size to use for sampling + + Returns: + samples (List[Dict[str, torch.Tensor]]), prompt_image_pairs (List[List[Any]]) + """ + samples = [] + prompt_image_pairs = [] + self.sd_pipeline.unet.eval() + + sample_neg_prompt_embeds = self.neg_prompt_embed.repeat(batch_size, 1, 1) + + for _ in range(iterations): + prompts, prompt_metadata = zip(*[self.prompt_fn() for _ in range(batch_size)]) + + prompt_ids = self.sd_pipeline.tokenizer( + prompts, + return_tensors="pt", + padding="max_length", + truncation=True, + max_length=self.sd_pipeline.tokenizer.model_max_length, + ).input_ids.to(self.accelerator.device) + prompt_embeds = self.sd_pipeline.text_encoder(prompt_ids)[0] + + with self.autocast(): + sd_output = self.sd_pipeline( + prompt_embeds=prompt_embeds, + negative_prompt_embeds=sample_neg_prompt_embeds, + num_inference_steps=self.config.sample_num_steps, + guidance_scale=self.config.sample_guidance_scale, + eta=self.config.sample_eta, + output_type="pt", + ) + + images = sd_output.images + latents = sd_output.latents + log_probs = sd_output.log_probs + + latents = torch.stack(latents, dim=1) # (batch_size, num_steps + 1, ...) + log_probs = torch.stack(log_probs, dim=1) # (batch_size, num_steps, 1) + timesteps = self.sd_pipeline.scheduler.timesteps.repeat(batch_size, 1) # (batch_size, num_steps) + + samples.append( + { + "prompt_ids": prompt_ids, + "prompt_embeds": prompt_embeds, + "timesteps": timesteps, + "latents": latents[:, :-1], # each entry is the latent before timestep t + "next_latents": latents[:, 1:], # each entry is the latent after timestep t + "log_probs": log_probs, + "negative_prompt_embeds": sample_neg_prompt_embeds, + } + ) + prompt_image_pairs.append([images, prompts, prompt_metadata]) + + return samples, prompt_image_pairs + + def _train_batched_samples(self, inner_epoch, epoch, global_step, batched_samples): + """ + Train on a batch of samples. Main training segment + + Args: + inner_epoch (int): The current inner epoch + epoch (int): The current epoch + global_step (int): The current global step + batched_samples (List[Dict[str, torch.Tensor]]): The batched samples to train on + + Side Effects: + - Model weights are updated + - Logs the statistics to the accelerator trackers. + + Returns: + global_step (int): The updated global step + """ + info = defaultdict(list) + for i, sample in enumerate(batched_samples): + if self.config.train_cfg: + # concat negative prompts to sample prompts to avoid two forward passes + embeds = torch.cat([sample["negative_prompt_embeds"], sample["prompt_embeds"]]) + else: + embeds = sample["prompt_embeds"] + + for j in range(self.num_train_timesteps): + with self.accelerator.accumulate(self.sd_pipeline.unet): + loss, approx_kl, clipfrac = self.calculate_loss( + sample["latents"][:, j], + sample["timesteps"][:, j], + sample["next_latents"][:, j], + sample["log_probs"][:, j], + sample["advantages"], + embeds, + ) + info["approx_kl"].append(approx_kl) + info["clipfrac"].append(clipfrac) + info["loss"].append(loss) + + self.accelerator.backward(loss) + if self.accelerator.sync_gradients: + self.accelerator.clip_grad_norm_( + self.trainable_layers.parameters(), + self.config.train_max_grad_norm, + ) + self.optimizer.step() + self.optimizer.zero_grad() + + # Checks if the accelerator has performed an optimization step behind the scenes + if self.accelerator.sync_gradients: + # log training-related stuff + info = {k: torch.mean(torch.stack(v)) for k, v in info.items()} + info = self.accelerator.reduce(info, reduction="mean") + info.update({"epoch": epoch, "inner_epoch": inner_epoch}) + self.accelerator.log(info, step=global_step) + global_step += 1 + info = defaultdict(list) + return global_step + + def _config_check(self) -> Tuple[bool, str]: + samples_per_epoch = ( + self.config.sample_batch_size * self.accelerator.num_processes * self.config.sample_num_batches_per_epoch + ) + total_train_batch_size = ( + self.config.train_batch_size + * self.accelerator.num_processes + * self.config.train_gradient_accumulation_steps + ) + + if not self.config.sample_batch_size >= self.config.train_batch_size: + return ( + False, + f"Sample batch size ({self.config.sample_batch_size}) must be greater than or equal to the train batch size ({self.config.train_batch_size})", + ) + if not self.config.sample_batch_size % self.config.train_batch_size == 0: + return ( + False, + f"Sample batch size ({self.config.sample_batch_size}) must be divisible by the train batch size ({self.config.train_batch_size})", + ) + if not samples_per_epoch % total_train_batch_size == 0: + return ( + False, + f"Number of samples per epoch ({samples_per_epoch}) must be divisible by the total train batch size ({total_train_batch_size})", + ) + return True, "" + + def train(self, epochs: Optional[int] = None): + """ + Train the model for a given number of epochs + """ + global_step = 0 + if epochs is None: + epochs = self.config.num_epochs + for epoch in range(self.first_epoch, epochs): + global_step = self.step(epoch, global_step) + + def _save_pretrained(self, save_directory): + self.sd_pipeline.save_pretrained(save_directory) diff --git a/trl/trl/trl/trainer/dpo_trainer.py b/trl/trl/trl/trainer/dpo_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..0015f42af0bf9d278e5ba10fb747414d93aaef7e --- /dev/null +++ b/trl/trl/trl/trainer/dpo_trainer.py @@ -0,0 +1,782 @@ +# DPO Authors: Rafael Rafailov, Archit Sharma, Eric Mitchell, Stefano Ermon, Christopher D. Manning, and Chelsea Finn 2023 +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import inspect +import random +import warnings +from collections import defaultdict +from copy import deepcopy +from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F +from accelerate.utils import is_deepspeed_available +from datasets import Dataset +from torch.utils.data import DataLoader +from transformers import ( + AutoModelForCausalLM, + DataCollator, + PreTrainedModel, + PreTrainedTokenizerBase, + Trainer, + TrainingArguments, +) +from transformers.trainer_callback import TrainerCallback +from transformers.trainer_utils import EvalLoopOutput + +from ..import_utils import is_peft_available, is_wandb_available +from ..models import PreTrainedModelWrapper, create_reference_model +from .utils import DPODataCollatorWithPadding, disable_dropout_in_model, pad_to_length + + +if is_peft_available(): + from peft import PeftModel, get_peft_model, prepare_model_for_kbit_training + + +if is_wandb_available(): + import wandb + +if is_deepspeed_available(): + import deepspeed + + +class DPOTrainer(Trainer): + r""" + Initialize DPOTrainer. + + Args: + model (`transformers.PreTrainedModel`): + The model to train, preferably an `AutoModelForSequenceClassification`. + ref_model (`PreTrainedModelWrapper`): + Hugging Face transformer model with a casual language modelling head. Used for implicit reward computation and loss. If no + reference model is provided, the trainer will create a reference model with the same architecture as the model to be optimized. + beta (`float`, defaults to 0.1): + The beta factor in DPO loss. Higher beta means less divergence from the initial policy. For the IPO loss, beta is the regularization parameter denoted by tau in the paper. + label_smoothing (`float`, defaults to 0): + The robust DPO label smoothing parameter from the [cDPO](https://ericmitchell.ai/cdpo.pdf) report that should be between 0 and 0.5. + loss_type (`str`, defaults to `"sigmoid"`): + The type of DPO loss to use. Either `"sigmoid"` the default DPO loss,`"hinge"` loss from [SLiC](https://arxiv.org/abs/2305.10425) paper, `"ipo"` from [IPO](https://arxiv.org/abs/2310.12036) paper, or `"kto"` from the HALOs [report](https://github.com/ContextualAI/HALOs/blob/main/assets/report.pdf). + args (`transformers.TrainingArguments`): + The arguments to use for training. + data_collator (`transformers.DataCollator`): + The data collator to use for training. If None is specified, the default data collator (`DPODataCollatorWithPadding`) will be used + which will pad the sequences to the maximum length of the sequences in the batch, given a dataset of paired sequences. + label_pad_token_id (`int`, defaults to `-100`): + The label pad token id. This argument is required if you want to use the default data collator. + padding_value (`int`, defaults to `0`): + The padding value. This argument is required if you want to use the default data collator. + truncation_mode (`str`, defaults to `keep_end`): + The truncation mode to use, either `keep_end` or `keep_start`. This argument is required if you want to use the default data collator. + train_dataset (`datasets.Dataset`): + The dataset to use for training. + eval_dataset (`datasets.Dataset`): + The dataset to use for evaluation. + tokenizer (`transformers.PreTrainedTokenizerBase`): + The tokenizer to use for training. This argument is required if you want to use the default data collator. + model_init (`Callable[[], transformers.PreTrainedModel]`): + The model initializer to use for training. If None is specified, the default model initializer will be used. + callbacks (`List[transformers.TrainerCallback]`): + The callbacks to use for training. + optimizers (`Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`): + The optimizer and scheduler to use for training. + preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`): + The function to use to preprocess the logits before computing the metrics. + max_length (`int`, defaults to `None`): + The maximum length of the sequences in the batch. This argument is required if you want to use the default data collator. + max_prompt_length (`int`, defaults to `None`): + The maximum length of the prompt. This argument is required if you want to use the default data collator. + max_target_length (`int`, defaults to `None`): + The maximum length of the target. This argument is required if you want to use the default data collator and your model is an encoder-decoder. + peft_config (`Dict`, defaults to `None`): + The PEFT configuration to use for training. If you pass a PEFT configuration, the model will be wrapped in a PEFT model. + is_encoder_decoder (`Optional[bool]`, `optional`, defaults to `None`): + If no model is provided, we need to know if the model_init returns an encoder-decoder. + disable_dropout (`bool`, defaults to `True`): + Whether or not to disable dropouts in `model` and `ref_model`. + generate_during_eval (`bool`, defaults to `False`): + Whether to sample and log generations during evaluation step. + compute_metrics (`Callable[[EvalPrediction], Dict]`, *optional*): + The function to use to compute the metrics. Must take a `EvalPrediction` and return + a dictionary string to metric values. + model_init_kwargs: (`Optional[Dict]`, *optional*): + Dict of Optional kwargs to pass when instantiating the model from a string + ref_model_init_kwargs: (`Optional[Dict]`, *optional*): + Dict of Optional kwargs to pass when instantiating the ref model from a string + + """ + + def __init__( + self, + model: Union[PreTrainedModel, nn.Module, str] = None, + ref_model: Optional[Union[PreTrainedModel, nn.Module, str]] = None, + beta: float = 0.1, + label_smoothing: float = 0, + loss_type: Literal["sigmoid", "hinge", "ipo", "kto"] = "sigmoid", + args: TrainingArguments = None, + data_collator: Optional[DataCollator] = None, + label_pad_token_id: int = -100, + padding_value: int = 0, + truncation_mode: str = "keep_end", + train_dataset: Optional[Dataset] = None, + eval_dataset: Optional[Union[Dataset, Dict[str, Dataset]]] = None, + tokenizer: Optional[PreTrainedTokenizerBase] = None, + model_init: Optional[Callable[[], PreTrainedModel]] = None, + callbacks: Optional[List[TrainerCallback]] = None, + optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = ( + None, + None, + ), + preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None, + max_length: Optional[int] = None, + max_prompt_length: Optional[int] = None, + max_target_length: Optional[int] = None, + peft_config: Optional[Dict] = None, + is_encoder_decoder: Optional[bool] = None, + disable_dropout: bool = True, + generate_during_eval: bool = False, + compute_metrics: Optional[Callable[[EvalLoopOutput], Dict]] = None, + model_init_kwargs: Optional[Dict] = None, + ref_model_init_kwargs: Optional[Dict] = None, + ): + if model_init_kwargs is None: + model_init_kwargs = {} + elif not isinstance(model, str): + raise ValueError("You passed model_kwargs to the DPOTrainer. But your model is already instantiated.") + + if ref_model_init_kwargs is None: + ref_model_init_kwargs = {} + elif not isinstance(ref_model, str): + raise ValueError( + "You passed ref_model_kwargs to the DPOTrainer. But your ref_model is already instantiated." + ) + + if isinstance(model, str): + warnings.warn( + "You passed a model_id to the DPOTrainer. This will automatically create an " + "`AutoModelForCausalLM` or a `PeftModel` (if you passed a `peft_config`) for you." + ) + model = AutoModelForCausalLM.from_pretrained(model, **model_init_kwargs) + + if isinstance(ref_model, str): + warnings.warn( + "You passed a ref model_id to the DPOTrainer. This will automatically create an " + "`AutoModelForCausalLM`" + ) + ref_model = AutoModelForCausalLM.from_pretrained(ref_model, **ref_model_init_kwargs) + + if not is_peft_available() and peft_config is not None: + raise ValueError( + "PEFT is not installed and you passed a `peft_config` in the trainer's kwargs, please install it to use the PEFT models" + ) + elif is_peft_available() and peft_config is not None: + # if model is a peft model and we have a peft_config, we merge and unload it first + if isinstance(model, PeftModel): + model = model.merge_and_unload() + + if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False): + _support_gc_kwargs = hasattr( + args, "gradient_checkpointing_kwargs" + ) and "gradient_checkpointing_kwargs" in list( + inspect.signature(prepare_model_for_kbit_training).parameters + ) + + preprare_model_kwargs = {"use_gradient_checkpointing": args.gradient_checkpointing} + + if _support_gc_kwargs: + preprare_model_kwargs["gradient_checkpointing_kwargs"] = args.gradient_checkpointing_kwargs + + model = prepare_model_for_kbit_training(model, **preprare_model_kwargs) + elif getattr(args, "gradient_checkpointing", False): + # For backward compatibility with older versions of transformers + if hasattr(model, "enable_input_require_grads"): + model.enable_input_require_grads() + else: + + def make_inputs_require_grad(module, input, output): + output.requires_grad_(True) + + model.get_input_embeddings().register_forward_hook(make_inputs_require_grad) + + # get peft model with the given config + model = get_peft_model(model, peft_config) + + # For models that use gradient_checkpoiting, we need to attach a hook that enables input + # to explicitly have `requires_grad=True`, otherwise training will either silently + # fail or completely fail. + elif getattr(args, "gradient_checkpointing", False): + # For backward compatibility with older versions of transformers + if hasattr(model, "enable_input_require_grads"): + model.enable_input_require_grads() + else: + + def make_inputs_require_grad(module, input, output): + output.requires_grad_(True) + + model.get_input_embeddings().register_forward_hook(make_inputs_require_grad) + + if generate_during_eval and not is_wandb_available(): + raise ValueError( + "`generate_during_eval=True` requires Weights and Biases to be installed." + " Please install `wandb` to resolve." + ) + + if model is not None: + self.is_encoder_decoder = model.config.is_encoder_decoder + elif is_encoder_decoder is None: + raise ValueError("When no model is provided, you need to pass the parameter is_encoder_decoder.") + else: + self.is_encoder_decoder = is_encoder_decoder + + self.is_peft_model = is_peft_available() and isinstance(model, PeftModel) + + if ref_model: + self.ref_model = ref_model + elif self.is_peft_model: + # The `model` with adapters turned off will be used as the reference model + self.ref_model = None + else: + self.ref_model = create_reference_model(model) + + if data_collator is None: + if tokenizer is None: + raise ValueError( + "max_length or a tokenizer must be specified when using the default DPODataCollatorWithPadding" + ) + if max_length is None: + warnings.warn( + "When using DPODataCollatorWithPadding, you should set `max_length` in the DPOTrainer's init" + " it will be set to `512` by default, but you should do it yourself in the future.", + UserWarning, + ) + max_length = 512 + if max_prompt_length is None: + warnings.warn( + "When using DPODataCollatorWithPadding, you should set `max_prompt_length` in the DPOTrainer's init" + " it will be set to `128` by default, but you should do it yourself in the future.", + UserWarning, + ) + max_prompt_length = 128 + + if max_target_length is None and self.is_encoder_decoder: + warnings.warn( + "When using DPODataCollatorWithPadding with an encoder decoder architecture, you should set `max_target_length` in the DPOTrainer's init" + " it will be set to `128` by default, but you should do it yourself in the future.", + UserWarning, + ) + max_target_length = 128 + + data_collator = DPODataCollatorWithPadding( + tokenizer, + max_length=max_length, + max_prompt_length=max_prompt_length, + label_pad_token_id=label_pad_token_id, + padding_value=padding_value, + truncation_mode=truncation_mode, + is_encoder_decoder=self.is_encoder_decoder, + max_target_length=max_target_length, + ) + + if args.remove_unused_columns: + args.remove_unused_columns = False + # warn users + warnings.warn( + "When using DPODataCollatorWithPadding, you should set `remove_unused_columns=False` in your TrainingArguments" + " we have set it for you, but you should do it yourself in the future.", + UserWarning, + ) + + self.use_dpo_data_collator = True + else: + self.use_dpo_data_collator = False + + if disable_dropout: + disable_dropout_in_model(model) + if self.ref_model is not None: + disable_dropout_in_model(self.ref_model) + + self.max_length = max_length + self.generate_during_eval = generate_during_eval + self.label_pad_token_id = label_pad_token_id + self.padding_value = padding_value + + if loss_type in ["hinge", "ipo", "kto"] and label_smoothing > 0: + warnings.warn( + "You are using a loss type that does not support label smoothing. Ignoring label_smoothing parameter." + ) + + self.beta = beta + self.label_smoothing = label_smoothing + self.loss_type = loss_type + + self._stored_metrics = defaultdict(lambda: defaultdict(list)) + + super().__init__( + model=model, + args=args, + data_collator=data_collator, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + tokenizer=tokenizer, + model_init=model_init, + compute_metrics=compute_metrics, + callbacks=callbacks, + optimizers=optimizers, + preprocess_logits_for_metrics=preprocess_logits_for_metrics, + ) + + if not hasattr(self, "accelerator"): + raise AttributeError( + "Your `Trainer` does not have an `accelerator` object. Consider upgrading `transformers`." + ) + + if self.ref_model is None: + if not hasattr(self.accelerator.unwrap_model(self.model), "disable_adapter"): + raise ValueError( + "You are using a `peft` version that does not support `disable_adapter`. Please update your `peft` version to the latest version." + ) + else: + if self.is_deepspeed_enabled: + self.ref_model = self._prepare_deepspeed(self.ref_model) + else: + self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True) + + def _prepare_deepspeed(self, model: PreTrainedModelWrapper): + # Adapted from accelerate: https://github.com/huggingface/accelerate/blob/739b135f8367becb67ffaada12fe76e3aa60fefd/src/accelerate/accelerator.py#L1473 + deepspeed_plugin = self.accelerator.state.deepspeed_plugin + config_kwargs = deepcopy(deepspeed_plugin.deepspeed_config) + + if model is not None: + if hasattr(model, "config"): + hidden_size = ( + max(model.config.hidden_sizes) + if getattr(model.config, "hidden_sizes", None) + else getattr(model.config, "hidden_size", None) + ) + if hidden_size is not None and config_kwargs["zero_optimization"]["stage"] == 3: + # Note that `stage3_prefetch_bucket_size` can produce DeepSpeed messages like: `Invalidate trace cache @ step 0: expected module 1, but got module 0` + # This is expected and is not an error, see: https://github.com/microsoft/DeepSpeed/discussions/4081 + config_kwargs.update( + { + "zero_optimization.reduce_bucket_size": hidden_size * hidden_size, + "zero_optimization.stage3_param_persistence_threshold": 10 * hidden_size, + "zero_optimization.stage3_prefetch_bucket_size": 0.9 * hidden_size * hidden_size, + } + ) + + # If ZeRO-3 is used, we shard both the active and reference model. + # Otherwise, we assume the reference model fits in memory and is initialized on each device with ZeRO disabled (stage 0) + if config_kwargs["zero_optimization"]["stage"] != 3: + config_kwargs["zero_optimization"]["stage"] = 0 + model, *_ = deepspeed.initialize(model=model, config=config_kwargs) + model.eval() + return model + + def concatenated_inputs(self, batch: Dict[str, Union[List, torch.LongTensor]]) -> Dict[str, torch.LongTensor]: + """Concatenate the chosen and rejected inputs into a single tensor. + + Args: + batch: A batch of data. Must contain the keys 'chosen_input_ids' and 'rejected_input_ids', which are tensors of shape (batch_size, sequence_length). + + Returns: + A dictionary containing the concatenated inputs under the key 'concatenated_input_ids'. + """ + concatenated_batch = {} + + if self.is_encoder_decoder: + max_length = max(batch["chosen_labels"].shape[1], batch["rejected_labels"].shape[1]) + else: + max_length = max(batch["chosen_input_ids"].shape[1], batch["rejected_input_ids"].shape[1]) + + for k in batch: + if k.startswith("chosen") and isinstance(batch[k], torch.Tensor): + pad_value = self.label_pad_token_id if "labels" in k or self.is_encoder_decoder else self.padding_value + concatenated_key = k.replace("chosen", "concatenated") + concatenated_batch[concatenated_key] = pad_to_length(batch[k], max_length, pad_value=pad_value) + for k in batch: + if k.startswith("rejected") and isinstance(batch[k], torch.Tensor): + pad_value = self.label_pad_token_id if "labels" in k or self.is_encoder_decoder else self.padding_value + concatenated_key = k.replace("rejected", "concatenated") + concatenated_batch[concatenated_key] = torch.cat( + ( + concatenated_batch[concatenated_key], + pad_to_length(batch[k], max_length, pad_value=pad_value), + ), + dim=0, + ).to(self.accelerator.device) + + if self.is_encoder_decoder: + concatenated_batch["concatenated_input_ids"] = batch["prompt_input_ids"].repeat(2, 1) + concatenated_batch["concatenated_attention_mask"] = batch["prompt_attention_mask"].repeat(2, 1) + + return concatenated_batch + + def dpo_loss( + self, + policy_chosen_logps: torch.FloatTensor, + policy_rejected_logps: torch.FloatTensor, + reference_chosen_logps: torch.FloatTensor, + reference_rejected_logps: torch.FloatTensor, + reference_free: bool = False, + ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]: + """Compute the DPO loss for a batch of policy and reference model log probabilities. + + Args: + policy_chosen_logps: Log probabilities of the policy model for the chosen responses. Shape: (batch_size,) + policy_rejected_logps: Log probabilities of the policy model for the rejected responses. Shape: (batch_size,) + reference_chosen_logps: Log probabilities of the reference model for the chosen responses. Shape: (batch_size,) + reference_rejected_logps: Log probabilities of the reference model for the rejected responses. Shape: (batch_size,) + reference_free: If True, we ignore the _provided_ reference model and implicitly use a reference model that assigns equal probability to all responses. + + Returns: + A tuple of three tensors: (losses, chosen_rewards, rejected_rewards). + The losses tensor contains the DPO loss for each example in the batch. + The chosen_rewards and rejected_rewards tensors contain the rewards for the chosen and rejected responses, respectively. + """ + pi_logratios = policy_chosen_logps - policy_rejected_logps + if reference_free: + ref_logratios = 0 + else: + ref_logratios = reference_chosen_logps - reference_rejected_logps + + logits = pi_logratios - ref_logratios + + # The beta is a temperature parameter for the DPO loss, typically something in the range of 0.1 to 0.5. + # We ignore the reference model as beta -> 0. The label_smoothing parameter encodes our uncertainty about the labels and + # calculates a conservative DPO loss. + if self.loss_type == "sigmoid": + losses = ( + -F.logsigmoid(self.beta * logits) * (1 - self.label_smoothing) + - F.logsigmoid(-self.beta * logits) * self.label_smoothing + ) + elif self.loss_type == "hinge": + losses = torch.relu(1 - self.beta * logits) + elif self.loss_type == "ipo": + # eqn (17) of the paper where beta is the regularization parameter for the IPO loss, denoted by tau in the paper. + losses = (logits - 1 / (2 * self.beta)) ** 2 + elif self.loss_type == "kto": + # eqn (7) of the HALOs paper + chosen_KL = (policy_chosen_logps - reference_chosen_logps).mean().clamp(min=0) + rejected_KL = (policy_rejected_logps - reference_rejected_logps).mean().clamp(min=0) + + chosen_logratios = policy_chosen_logps - reference_chosen_logps + rejected_logratios = policy_rejected_logps - reference_rejected_logps + # As described in the KTO report, the KL term for chosen (rejected) is estimated using the rejected (chosen) half. + losses = torch.cat( + ( + 1 - F.sigmoid(self.beta * (chosen_logratios - rejected_KL)), + 1 - F.sigmoid(self.beta * (chosen_KL - rejected_logratios)), + ), + 0, + ) + else: + raise ValueError( + f"Unknown loss type: {self.loss_type}. Should be one of ['sigmoid', 'hinge', 'ipo', 'kto']" + ) + + chosen_rewards = self.beta * (policy_chosen_logps - reference_chosen_logps).detach() + rejected_rewards = self.beta * (policy_rejected_logps - reference_rejected_logps).detach() + + return losses, chosen_rewards, rejected_rewards + + def _get_batch_logps( + self, + logits: torch.FloatTensor, + labels: torch.LongTensor, + average_log_prob: bool = False, + ) -> torch.FloatTensor: + """Compute the log probabilities of the given labels under the given logits. + + Args: + logits: Logits of the model (unnormalized). Shape: (batch_size, sequence_length, vocab_size) + labels: Labels for which to compute the log probabilities. Label tokens with a value of label_pad_token_id are ignored. Shape: (batch_size, sequence_length) + average_log_prob: If True, return the average log probability per (non-masked) token. Otherwise, return the sum of the log probabilities of the (non-masked) tokens. + + Returns: + A tensor of shape (batch_size,) containing the average/sum log probabilities of the given labels under the given logits. + """ + if logits.shape[:-1] != labels.shape: + raise ValueError("Logits (batch and sequence length dim) and labels must have the same shape.") + + if not self.is_encoder_decoder: + labels = labels[:, 1:].clone() + logits = logits[:, :-1, :] + loss_mask = labels != self.label_pad_token_id + + # dummy token; we'll ignore the losses on these tokens later + labels[labels == self.label_pad_token_id] = 0 + + per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2) + + if average_log_prob: + return (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1) + else: + return (per_token_logps * loss_mask).sum(-1) + + def concatenated_forward( + self, model: nn.Module, batch: Dict[str, Union[List, torch.LongTensor]] + ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]: + """Run the given model on the given batch of inputs, concatenating the chosen and rejected inputs together. + + We do this to avoid doing two forward passes, because it's faster for FSDP. + """ + concatenated_batch = self.concatenated_inputs(batch) + len_chosen = batch["chosen_labels"].shape[0] + + model_kwargs = ( + { + "labels": concatenated_batch["concatenated_labels"], + "decoder_input_ids": concatenated_batch.pop("concatenated_decoder_input_ids", None), + } + if self.is_encoder_decoder + else {} + ) + all_logits = model( + concatenated_batch["concatenated_input_ids"], + attention_mask=concatenated_batch["concatenated_attention_mask"], + **model_kwargs, + ).logits.to(torch.float32) + + all_logps = self._get_batch_logps( + all_logits, + concatenated_batch["concatenated_labels"], + average_log_prob=False, + ) + + chosen_logps = all_logps[:len_chosen] + rejected_logps = all_logps[len_chosen:] + + chosen_logits = all_logits[:len_chosen] + rejected_logits = all_logits[len_chosen:] + + return (chosen_logps, rejected_logps, chosen_logits, rejected_logits) + + def get_batch_metrics( + self, + model, + batch: Dict[str, Union[List, torch.LongTensor]], + train_eval: Literal["train", "eval"] = "train", + ): + """Compute the DPO loss and other metrics for the given batch of inputs for train or test.""" + metrics = {} + + ( + policy_chosen_logps, + policy_rejected_logps, + policy_chosen_logits, + policy_rejected_logits, + ) = self.concatenated_forward(model, batch) + with torch.no_grad(): + if self.ref_model is None: + with self.accelerator.unwrap_model(self.model).disable_adapter(): + ( + reference_chosen_logps, + reference_rejected_logps, + _, + _, + ) = self.concatenated_forward(self.model, batch) + else: + ( + reference_chosen_logps, + reference_rejected_logps, + _, + _, + ) = self.concatenated_forward(self.ref_model, batch) + + losses, chosen_rewards, rejected_rewards = self.dpo_loss( + policy_chosen_logps, + policy_rejected_logps, + reference_chosen_logps, + reference_rejected_logps, + ) + reward_accuracies = (chosen_rewards > rejected_rewards).float() + + prefix = "eval_" if train_eval == "eval" else "" + metrics[f"{prefix}rewards/chosen"] = chosen_rewards.cpu().mean() + metrics[f"{prefix}rewards/rejected"] = rejected_rewards.cpu().mean() + metrics[f"{prefix}rewards/accuracies"] = reward_accuracies.cpu().mean() + metrics[f"{prefix}rewards/margins"] = (chosen_rewards - rejected_rewards).cpu().mean() + metrics[f"{prefix}logps/rejected"] = policy_rejected_logps.detach().cpu().mean() + metrics[f"{prefix}logps/chosen"] = policy_chosen_logps.detach().cpu().mean() + metrics[f"{prefix}logits/rejected"] = policy_rejected_logits.detach().cpu().mean() + metrics[f"{prefix}logits/chosen"] = policy_chosen_logits.detach().cpu().mean() + + return losses.mean(), metrics + + def compute_loss( + self, + model: Union[PreTrainedModel, nn.Module], + inputs: Dict[str, Union[torch.Tensor, Any]], + return_outputs=False, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, Dict[str, torch.Tensor]]]: + if not self.use_dpo_data_collator: + warnings.warn( + "compute_loss is only implemented for DPODataCollatorWithPadding, and you passed a datacollator that is different than " + "DPODataCollatorWithPadding - you might see unexpected behavior. Alternatively, you can implement your own prediction_step method if you are using a custom data collator" + ) + loss, metrics = self.get_batch_metrics(model, inputs, train_eval="train") + + # force log the metrics + if self.accelerator.is_main_process: + self.store_metrics(metrics, train_eval="train") + + if return_outputs: + return (loss, metrics) + return loss + + def get_batch_samples(self, model, batch: Dict[str, torch.LongTensor]) -> Tuple[str, str]: + """Generate samples from the model and reference model for the given batch of inputs.""" + + policy_output = model.generate( + input_ids=batch["prompt_input_ids"], + attention_mask=batch["prompt_attention_mask"], + max_length=self.max_length, + do_sample=True, + pad_token_id=self.tokenizer.pad_token_id, + ) + + if self.ref_model is None: + with self.accelerator.unwrap_model(self.model).disable_adapter(): + reference_output = self.model.generate( + input_ids=batch["prompt_input_ids"], + attention_mask=batch["prompt_attention_mask"], + max_length=self.max_length, + do_sample=True, + pad_token_id=self.tokenizer.pad_token_id, + ) + else: + reference_output = self.ref_model.generate( + input_ids=batch["prompt_input_ids"], + attention_mask=batch["prompt_attention_mask"], + max_length=self.max_length, + do_sample=True, + pad_token_id=self.tokenizer.pad_token_id, + ) + + policy_output = pad_to_length(policy_output, self.max_length, self.tokenizer.pad_token_id) + policy_output_decoded = self.tokenizer.batch_decode(policy_output, skip_special_tokens=True) + + reference_output = pad_to_length(reference_output, self.max_length, self.tokenizer.pad_token_id) + reference_output_decoded = self.tokenizer.batch_decode(reference_output, skip_special_tokens=True) + + return policy_output_decoded, reference_output_decoded + + def prediction_step( + self, + model: Union[PreTrainedModel, nn.Module], + inputs: Dict[str, Union[torch.Tensor, Any]], + prediction_loss_only: bool, + ignore_keys: Optional[List[str]] = None, + ): + if not self.use_dpo_data_collator: + warnings.warn( + "prediction_step is only implemented for DPODataCollatorWithPadding, and you passed a datacollator that is different than " + "DPODataCollatorWithPadding - you might see unexpected behavior. Alternatively, you can implement your own prediction_step method if you are using a custom data collator" + ) + if ignore_keys is None: + if hasattr(model, "config"): + ignore_keys = getattr(model.config, "keys_to_ignore_at_inference", []) + else: + ignore_keys = [] + + with torch.no_grad(): + loss, metrics = self.get_batch_metrics(model, inputs, train_eval="eval") + + # force log the metrics + if self.accelerator.is_main_process: + self.store_metrics(metrics, train_eval="eval") + + if prediction_loss_only: + return (loss.detach(), None, None) + + # logits for the chosen and rejected samples from model + logits_dict = { + "eval_logits/chosen": metrics["eval_logits/chosen"], + "eval_logits/rejected": metrics["eval_logits/rejected"], + } + logits = tuple(v.unsqueeze(dim=0) for k, v in logits_dict.items() if k not in ignore_keys) + logits = torch.stack(logits).mean(axis=1).to(self.accelerator.device) + labels = torch.zeros(logits.shape[0], device=self.accelerator.device) + + return (loss.detach(), logits, labels) + + def store_metrics(self, metrics: Dict[str, float], train_eval: Literal["train", "eval"] = "train") -> None: + for key, value in metrics.items(): + self._stored_metrics[train_eval][key].append(value) + + def evaluation_loop( + self, + dataloader: DataLoader, + description: str, + prediction_loss_only: Optional[bool] = None, + ignore_keys: Optional[List[str]] = None, + metric_key_prefix: str = "eval", + ) -> EvalLoopOutput: + """ + Overriding built-in evaluation loop to store metrics for each batch. + Prediction/evaluation loop, shared by `Trainer.evaluate()` and `Trainer.predict()`. + + Works both with or without labels. + """ + + # Sample and save to game log if requested (for one batch to save time) + if self.generate_during_eval: + # Generate random indices within the range of the total number of samples + num_samples = len(dataloader.dataset) + random_indices = random.sample(range(num_samples), k=self.args.eval_batch_size) + + # Use dataloader.dataset.select to get the random batch without iterating over the DataLoader + random_batch_dataset = dataloader.dataset.select(random_indices) + random_batch = self.data_collator(random_batch_dataset) + random_batch = self._prepare_inputs(random_batch) + + policy_output_decoded, ref_output_decoded = self.get_batch_samples(self.model, random_batch) + + self.log( + { + "game_log": wandb.Table( + columns=["Prompt", "Policy", "Ref Model"], + rows=[ + [prompt, pol[len(prompt) :], ref[len(prompt) :]] + for prompt, pol, ref in zip( + random_batch["prompt"], policy_output_decoded, ref_output_decoded + ) + ], + ) + } + ) + self.state.log_history.pop() + + # Base evaluation + initial_output = super().evaluation_loop( + dataloader, description, prediction_loss_only, ignore_keys, metric_key_prefix + ) + + return initial_output + + def log(self, logs: Dict[str, float]) -> None: + """ + Log `logs` on the various objects watching training, including stored metrics. + + Args: + logs (`Dict[str, float]`): + The values to log. + """ + # logs either has 'loss' or 'eval_loss' + train_eval = "train" if "loss" in logs else "eval" + # Add averaged stored metrics to logs + for key, metrics in self._stored_metrics[train_eval].items(): + logs[key] = torch.tensor(metrics).mean().item() + del self._stored_metrics[train_eval] + return super().log(logs) diff --git a/trl/trl/trl/trainer/iterative_sft_trainer.py b/trl/trl/trl/trainer/iterative_sft_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..006b02ad5123aa704a1d482571da464ac9c81f45 --- /dev/null +++ b/trl/trl/trl/trainer/iterative_sft_trainer.py @@ -0,0 +1,367 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import warnings +from typing import Callable, Dict, List, Optional, Tuple, Union + +import torch +from datasets import Dataset +from torch.utils.data import DataLoader +from transformers import ( + DataCollator, + DataCollatorForLanguageModeling, + DataCollatorForSeq2Seq, + PreTrainedModel, + PreTrainedTokenizerBase, + Trainer, + TrainingArguments, +) +from transformers.trainer_utils import EvalLoopOutput + +from ..core import PPODecorators +from ..import_utils import is_peft_available + + +if is_peft_available(): + from peft import PeftModel + + +class IterativeSFTTrainer(Trainer): + """ + The IterativeSFTTrainer can be used to finetune models with methods that requires some steps between optimization. + + Attributes: + **model** (`PreTrainedModel`) -- Model to be optimized, either an 'AutoModelForCausalLM' or an 'AutoModelForSeq2SeqLM'. + Check the documentation of `PreTrainedModel` for more details. + **args** (`transformers.TrainingArguments`): -- The arguments to use for training. + **tokenizer** (`PreTrainedTokenizerBase`) -- Tokenizer to be used for encoding the + data. Check the documentation of `transformers.PreTrainedTokenizer` and + `transformers.PreTrainedTokenizerFast` for more details. + **optimizers** (`Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`): -- The optimizer and scheduler to use for training. + **data_collator** (Union[DataCollatorForLanguageModeling, DataCollatorForSeq2Seq], *optional*) -- Data collator to be used for training and + passed along the dataloader. + **eval_dataset** (`datasets.Dataset`): The dataset to use for evaluation. + **max_length** (`int`, defaults to `None`): -- The maximum length of the input. + **truncation_mode** (`str`, defaults to `keep_end`): -- The truncation mode to use, either `keep_end` or `keep_start`. + **preprocess_logits_for_metrics** (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`): -- The function to use to preprocess the logits before computing the metrics. + **compute_metrics** (`Callable[[EvalPrediction], Dict]`, *optional*): -- The function to use to compute the metrics. Must take a `EvalPrediction` and return a dictionary string to metric values. + **optimize_device_cache ** (`bool`, *optional*, defaults to `False`) -- Optimize CUDA cache for slightly more memory-efficient training. + """ + + def __init__( + self, + model: PreTrainedModel = None, + args: TrainingArguments = None, + tokenizer: PreTrainedTokenizerBase = None, + optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = ( + None, + None, + ), + data_collator: Optional[DataCollator] = None, + eval_dataset: Optional[Union[Dataset, Dict[str, Dataset]]] = None, + max_length: Optional[int] = None, + truncation_mode: Optional[str] = "keep_end", + preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None, + compute_metrics: Optional[Callable[[EvalLoopOutput], Dict]] = None, + optimize_device_cache: Optional[bool] = False, + ): + # Step 0: check positional arguments validity + if not isinstance(tokenizer, (PreTrainedTokenizerBase)): + raise ValueError( + f"tokenizer must be a PreTrainedTokenizerBase like a PreTrainedTokenizer or a PreTrainedTokenizerFast, got {type(tokenizer)}" + ) + if not isinstance(model, PreTrainedModel): + raise ValueError(f"model must be a PreTrainedModel, got {type(model)}") + if not model.can_generate(): + warnings.warn( + f"The current model class {type(model)} is not compatible with `.generate()`" + "Please make sure that this is intended." + ) + if optimizers[1] is None and args.max_steps == -1: + raise ValueError( + "When no scheduler is provided, you need to set the total number of training steps to perform `max_steps`" + ) + + self.is_encoder_decoder = getattr(model.config, "is_encoder_decoder", False) + self.is_peft_model = is_peft_available() and isinstance(model, PeftModel) + + self.tokenizer = tokenizer + + if data_collator is None: + if self.is_encoder_decoder: + warnings.warn( + "No data collator is provided. Using 'DataCollatorForSeq2Seq' with" + "'labels_pad_token_id' set to '-100' and 'pad_to_multiple_of' set to 8." + ) + self.data_collator = DataCollatorForSeq2Seq(tokenizer, label_pad_token_id=-100, pad_to_multiple_of=8) + else: + warnings.warn("No data collator is provided. Using 'DataCollatorForLanguageModeling'") + self.data_collator = DataCollatorForLanguageModeling(self.tokenizer, mlm=False) + else: + self.data_collator = data_collator + + self.max_length = max_length + self.truncation_mode = truncation_mode + self.optimize_device_cache = optimize_device_cache + + super().__init__( + model=model, + args=args, + data_collator=self.data_collator, + eval_dataset=eval_dataset, + tokenizer=tokenizer, + compute_metrics=compute_metrics, + optimizers=optimizers, + preprocess_logits_for_metrics=preprocess_logits_for_metrics, + ) + + self.create_optimizer_and_scheduler(self.args.max_steps) + + # prepare model, optimizer and lr_scheduler + self.model, self.optimizer, self.lr_scheduler = self.accelerator.prepare( + self.model, self.optimizer, self.lr_scheduler + ) + + self.tokenizer.truncation_side = "left" if self.truncation_mode == "keep_end" else "right" + + if not hasattr(self, "accelerator"): + raise AttributeError( + "Your `Trainer` does not have an `accelerator` object. Consider upgrading `transformers`." + ) + + PPODecorators.optimize_device_cache = self.optimize_device_cache + + def prepare_model_inputs(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, labels: torch.Tensor): + if attention_mask is None: + attention_mask = [torch.ones_like(ids) for ids in input_ids] + + if self.is_encoder_decoder: + input_data = self.data_collator( + [ + {"input_ids": ids, "attention_mask": att, "labels": lab} + for ids, att, lab in zip(input_ids, attention_mask, labels) + ] + ).to(self.model.device) + + input_data.pop("decoder_input_ids", None) # This is directly computed inside the model + + input_data["labels"][input_data["labels"] == self.tokenizer.pad_token_id] = -100 + + else: + input_data = self.data_collator( + [{"input_ids": ids, "attention_mask": att} for ids, att in zip(input_ids, attention_mask)] + ).to(self.model.device) + + # truncate in case the user has provided input_ids, attention_mask and labels + if self.max_length is not None: + if self.truncation_mode == "keep_start": + input_data = {k: v[: self.max_length] for k, v in input_data.items()} + elif self.truncation_mode == "keep_end": + input_data = {k: v[-self.max_length :] for k, v in input_data.items()} + else: + raise ValueError(f"Unknown truncation mode: {self.truncation_mode}") + + return input_data + + @staticmethod + def _step_safety_checker( + input_ids: List[torch.LongTensor], + attention_mask: List[torch.LongTensor], + labels: List[torch.LongTensor], + texts: List[str], + texts_labels: List[str], + ): + """ + Check if the input data is valid for training. + + Args: + input_ids (List[`torch.LongTensor`]): + List of tensors containing the input_ids + attention_mask (List[`torch.LongTensor`]): + List of tensors containing the attention_mask + labels (List[`torch.FloatTensor`]): + List of tensors containing the labels + texts (List[`str`]): + List of string containing the text input. + texts_labels (List[`str`]): + List of string containing the text labels. + Returns: + `tuple`: The input data. + """ + if texts is None: + if attention_mask is None: + for name, tensor_list in zip(["input_ids", "labels"], [input_ids, labels]): + if not isinstance(tensor_list, list): + raise ValueError(f"{name} must be a list of tensors - got {type(tensor_list)}") + if not isinstance(tensor_list[0], torch.Tensor): + raise ValueError(f"Elements in {name} must be tensors - got {type(tensor_list[0])}") + else: + for name, tensor_list in zip( + ["input_ids", "attention_mask", "labels"], [input_ids, attention_mask, labels] + ): + if not isinstance(tensor_list, list): + raise ValueError(f"{name} must be a list of tensors - got {type(tensor_list)}") + if not isinstance(tensor_list[0], torch.Tensor): + raise ValueError(f"Elements in {name} must be tensors - got {type(tensor_list[0])}") + else: + if not isinstance(texts, list): + raise ValueError(f"'text' must be a list of strings - got {type(texts)}") + if not isinstance(texts[0], str): + raise ValueError(f"Elements in 'text' must be strings - got {type(texts[0])}") + if texts_labels is not None: + if not isinstance(texts_labels, list): + raise ValueError(f"'text_labels' must be a list of strings - got {type(texts_labels)}") + if not isinstance(texts_labels[0], str): + raise ValueError(f"Elements in 'text_labels' must be strings - got {type(texts_labels[0])}") + + return input_ids, attention_mask, labels, texts, texts_labels + + @PPODecorators.empty_device_cache() + def step( + self, + input_ids: Optional[List[torch.LongTensor]] = None, + attention_mask: Optional[List[torch.LongTensor]] = None, + labels: Optional[List[torch.LongTensor]] = None, + texts: Optional[List[str]] = None, + texts_labels: Optional[List[str]] = None, + ): + """ + Run an optimisation step given a list of input_ids, attention_mask, and labels or a list of text and text_labels. + Args: + input_ids (List[`torch.LongTensor`]): + List of tensors containing the input_ids (if not provided, text will be used) + attention_mask (List[`torch.LongTensor`], , *optional*): + List of tensors containing the attention_mask + labels (List[`torch.FloatTensor`], *optional*): + List of tensors containing the labels (if set to None, will default to input_ids) + texts (List[`str`], *optional*): + List of strings containing the text input (if not provided, input_ids will directly be used) + texts_labels (List[`str`], *optional*): + List of strings containing the text labels (if set to None, will default to text) + Returns: + `dict[str, Any]`: A summary of the training statistics + """ + self.model.train() + + if self.state.global_step == 0: + self.tr_loss = torch.tensor(0.0).to(self.args.device) + self._globalstep_last_logged = self.state.global_step + + if input_ids is None and texts is None: + raise ValueError("Step should include `input_ids` or `texts` as keyword arguments.") + elif input_ids is not None and texts is not None: + warnings.warn( + "Both 'input_ids' and 'texts' are provided. 'input_ids' will be overwritten using inputs provided by the 'texts' keyword argument." + ) + + if labels is None and texts_labels is None and self.is_encoder_decoder: + raise ValueError( + "No 'labels' or 'text_labels' are provided. When using an encoder-decoder architecture, 'labels' or 'text_labels' must be passed." + ) + + input_ids, attention_mask, labels, texts, texts_labels = self._step_safety_checker( + input_ids, attention_mask, labels, texts, texts_labels + ) + + if texts is not None: + model_inputs = self.tokenizer( + texts, max_length=self.max_length, truncation=True, padding=True, return_tensors="pt" + ) + + input_ids, attention_mask = model_inputs["input_ids"], model_inputs["attention_mask"] + + if texts_labels is not None: + labels = self.tokenizer( + texts, max_length=self.max_length, truncation=True, padding=True, return_tensors="pt" + )["input_ids"] + + if labels is None: + warnings.warn("No labels are provided. Setting labels to input_ids") + labels = input_ids + + model_inputs = self.prepare_model_inputs(input_ids, attention_mask, labels) + + model_inputs_names = list(model_inputs.keys()) + + batch_dict = {} + batch_dict.update(model_inputs) + + def collator(data): + return_dict = dict() + for key in data[0]: + if key in ["input_ids", "attention_mask", "labels"]: + return_dict[key] = torch.stack([d[key] for d in data]).to(self.model.device) + return return_dict + + batch_data = Dataset.from_dict(batch_dict) + batch_data.set_format("torch") + + step_dataloader = DataLoader( + batch_data, + batch_size=self.args.per_device_train_batch_size, + shuffle=True, + collate_fn=collator, + ) + + for _, batch in enumerate(step_dataloader): + with self.accelerator.accumulate(self.model): + model_inputs = {k: batch[k] for k in model_inputs_names} + loss = self.compute_loss(self.model, model_inputs) + + if self.args.n_gpu > 1: + loss = loss.mean() + + tr_loss_step = loss.detach() + + self.accelerator.backward(loss) + + if self.accelerator.sync_gradients and self.args.max_grad_norm is not None: + self.accelerator.clip_grad_norm_( + self.model.parameters(), + self.args.max_grad_norm, + ) + + self.optimizer.step() + self.optimizer.zero_grad() + if self.lr_scheduler is not None: + self.lr_scheduler.step() + + self.state.global_step += 1 + + # update stats etc + self.tr_loss += tr_loss_step + + self._maybe_log_save_evaluate() + + def _maybe_log_save_evaluate(self): + # check if eval is required + if self.args.eval_steps is not None: + if self.state.global_step % self.args.eval_steps == 0 and self.state.global_step != 0: + self.evaluate(self.eval_dataset) + + # check if logging is required + if self.args.logging_steps is not None: + if self.state.global_step % self.args.logging_steps == 0 and self.state.global_step != 0: + logs: Dict[str, float] = {} + + tr_loss_scalar = self._nested_gather(self.tr_loss).mean().item() + + # reset tr_loss to zero + self.tr_loss -= self.tr_loss + + logs["loss"] = round(tr_loss_scalar / (self.state.global_step - self._globalstep_last_logged), 4) + logs["learning_rate"] = self._get_learning_rate() + + self._globalstep_last_logged = self.state.global_step + + self.log(logs) diff --git a/trl/trl/trl/trainer/ppo_config.py b/trl/trl/trl/trainer/ppo_config.py new file mode 100644 index 0000000000000000000000000000000000000000..0f6a907b6461a74f79a92b0f666a8f2cc2505d99 --- /dev/null +++ b/trl/trl/trl/trainer/ppo_config.py @@ -0,0 +1,179 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import json +import os +import sys +import warnings +from dataclasses import dataclass, field +from typing import Literal, Optional + +import numpy as np +import tyro +from typing_extensions import Annotated + +from trl.trainer.utils import exact_div + +from ..core import flatten_dict +from ..import_utils import is_wandb_available + + +JSONDict = Annotated[Optional[dict], tyro.conf.arg(metavar="JSON", constructor=json.loads)] + + +@dataclass +class PPOConfig: + """ + Configuration class for PPOTrainer + """ + + # common parameters + exp_name: str = os.path.basename(sys.argv[0])[: -len(".py")] + """the name of this experiment (by default is the file name without the extension name)""" + seed: int = 0 + """Seed value for random generations""" + log_with: Optional[Literal["wandb", "tensorboard"]] = None + """Log with either 'wandb' or 'tensorboard', check https://huggingface.co/docs/accelerate/usage_guides/tracking for more details""" + task_name: Optional[str] = None + """Name of task to use - used only for tracking purposes""" + model_name: Optional[str] = None + """Name of model to use - used only for tracking purposes""" + query_dataset: Optional[str] = None + """Name of dataset to query - used only for tracking purposes""" + reward_model: Optional[str] = None + """The reward model to use - used only for tracking purposes""" + remove_unused_columns: bool = True + """Remove unused columns from the dataset if `datasets.Dataset` is used""" + tracker_kwargs: JSONDict = field(default_factory=dict) + """Keyword arguments for the tracker (e.g. python ppo.py --ppo_config.tracker_kwargs='{"wandb": {"entity": "my_wandb_entity", "name": "my_exp_name"}}'""" + accelerator_kwargs: JSONDict = field(default_factory=dict) + """Keyword arguments for the accelerator""" + project_kwargs: JSONDict = field(default_factory=dict) + """Keyword arguments for the accelerator project config (e.g. `logging_dir`)""" + tracker_project_name: str = "trl" + """Name of project to use for tracking""" + push_to_hub_if_best_kwargs: JSONDict = field(default_factory=dict) + """Keyword arguments for pushing model to the hub during training (e.g. repo_id)""" + + # hyperparameters + steps: int = 20000 + """Number of training steps""" + learning_rate: float = 1e-5 + """Adam learning rate""" + adap_kl_ctrl: bool = True + """Use adaptive KL control, otherwise linear""" + init_kl_coef: Optional[float] = 0.2 + """Initial KL penalty coefficient (used for adaptive and linear control)""" + kl_penalty: Literal["kl", "abs", "mse", "full"] = "kl" + """kl penalty options: 'kl': model_logp - ref_logp, 'abs': abs(kl), 'mse': mean squared error mse(kl) and 'full': the actual kl for all tokens in the distribution""" + target: Optional[float] = 6 + """Target KL value for adaptive KL control""" + horizon: Optional[float] = 10000 + """Horizon for adaptive KL control""" + gamma: float = 1 + """Gamma parameter for advantage calculation""" + lam: float = 0.95 + """Lambda parameter for advantage calculation""" + cliprange: float = 0.2 + """Range for clipping in PPO policy gradient loss""" + cliprange_value: float = 0.2 + """Range for clipping values in loss calculation""" + vf_coef: float = 0.1 + """Scaling factor for value loss""" + batch_size: int = 256 + """Number of samples per optimisation step""" + forward_batch_size: Optional[int] = None + """DEPRECATED: use `mini_batch_size` instead, which does the same thing.""" + mini_batch_size: int = 1 + """Number of samples optimized in each mini batch""" + gradient_accumulation_steps: int = 1 + """The number of gradient accumulation steps""" + world_size: tyro.conf.Suppress[int] = None + """The world size for distributed training""" + ppo_epochs: int = 4 + """Number of optimisation epochs per batch of samples""" + max_grad_norm: Optional[float] = None + """Maximum gradient norm for gradient clipping""" + optimize_cuda_cache: Optional[bool] = None + """DEPRECATED: use `optimize_device_cache` instead, which does the same thing.""" + optimize_device_cache: Optional[bool] = False + """Optimize device cache for slightly more memory-efficient training""" + early_stopping: bool = False + """Whether to stop the PPO optimization loop early is the KL too high""" + target_kl: float = 1 + """Stop early if we exceed this value by over 50%""" + compare_steps: int = 1 + """Number of steps between comparison of the current reward with the best seen so far""" + ratio_threshold: float = 10.0 + """Skip mini-batches with high PPO ratios that can cause loss spikes""" + use_score_scaling: bool = False + """Use score scaling""" + use_score_norm: bool = False + """Use score normalization. Only applicable if use_score_scaling is True""" + score_clip: Optional[float] = None + """Score clipping""" + whiten_rewards: bool = False + """Whiten the rewards before compute advantages""" + + # computed hyperparameters at runtime; we use `tyro.conf.Suppress` to hide them from the help text + is_encoder_decoder: Optional[tyro.conf.Suppress[bool]] = None + """TO BE FILLED In RUNTIME: Whether the model is an encoder-decoder model""" + is_peft_model: Optional[tyro.conf.Suppress[bool]] = None + """TO BE FILLED In RUNTIME: Whether the model is a PEFT model""" + backward_batch_size: tyro.conf.Suppress[int] = None + """TO BE FILLED In RUNTIME: Number of samples optimized in an `optimizer.step()` call""" + global_backward_batch_size: tyro.conf.Suppress[int] = None + """TO BE FILLED In RUNTIME: the effective `backward_batch_size` across all processes""" + global_batch_size: tyro.conf.Suppress[int] = None + """TO BE FILLED In RUNTIME: the effective `batch_size` across all processes""" + + if optimize_cuda_cache is not None: + warnings.warn( + "The `optimize_cuda_cache` argument will be deprecated soon, please use `optimize_device_cache` instead." + ) + optimize_device_cache = optimize_cuda_cache + else: + optimize_device_cache = False + + def __post_init__(self): + if self.forward_batch_size is not None: + warnings.warn( + "Note that using `forward_batch_size` is deprecated, use `mini_batch_size` instead. By setting it you overwrite `mini_batch_size` which affects both the batch size during forward passes and also the mini batch size for PPO optimization." + ) + self.mini_batch_size = self.forward_batch_size + + self.backward_batch_size = self.mini_batch_size * self.gradient_accumulation_steps + exact_div( + self.batch_size, + self.backward_batch_size, + "`batch_size`", + "`mini_batch_size * gradient_accumulation_steps`", + "`batch_size` must be a multiple of `mini_batch_size * gradient_accumulation_steps`", + ) + + # check if wandb is installed + if self.log_with == "wandb": + # raise error if wandb is not installed + if not is_wandb_available(): + raise ImportError( + "Please install wandb to use wandb logging. You can do this by running `pip install wandb`." + ) + + self.total_ppo_epochs = int(np.ceil(self.steps / self.batch_size)) + assert self.kl_penalty in ["kl", "abs", "mse", "full"] + + def to_dict(self): + output_dict = {} + for key, value in self.__dict__.items(): + output_dict[key] = value + return flatten_dict(output_dict) diff --git a/trl/trl/trl/trainer/ppo_trainer.py b/trl/trl/trl/trainer/ppo_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..007e77a2686726bc8d1678e5b2f23f9092b6a573 --- /dev/null +++ b/trl/trl/trl/trainer/ppo_trainer.py @@ -0,0 +1,1440 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import inspect +import math +import os +import time +import typing +import warnings +from contextlib import nullcontext +from typing import Callable, List, Optional, Union + +import datasets +import numpy as np +import torch +import torch.nn.functional as F +from accelerate import Accelerator +from accelerate.utils import ProjectConfiguration, gather_object, is_deepspeed_available +from datasets import Dataset +from huggingface_hub import whoami +from packaging import version +from torch.optim import Adam +from transformers import ( + DataCollatorForLanguageModeling, + PreTrainedTokenizer, + PreTrainedTokenizerBase, + PreTrainedTokenizerFast, +) + +from ..core import ( + WANDB_PADDING, + PPODecorators, + clip_by_value, + convert_to_scalar, + entropy_from_logits, + flatten_dict, + logprobs_from_logits, + masked_mean, + masked_var, + masked_whiten, + set_seed, + stack_dicts, + stats_to_np, +) +from ..import_utils import is_torch_greater_2_0, is_xpu_available +from ..models import SUPPORTED_ARCHITECTURES, PreTrainedModelWrapper, create_reference_model +from . import AdaptiveKLController, BaseTrainer, FixedKLController, PPOConfig, RunningMoments + + +if is_deepspeed_available(): + import deepspeed + +MODEL_CARD_TEMPLATE = """--- +license: apache-2.0 +tags: +- trl +- transformers +- reinforcement-learning +--- + +# {model_name} + +This is a [TRL language model](https://github.com/huggingface/trl) that has been fine-tuned with reinforcement learning to + guide the model outputs according to a value, function, or human feedback. The model can be used for text generation. + +## Usage + +To use this model for inference, first install the TRL library: + +```bash +python -m pip install trl +``` + +You can then generate text as follows: + +```python +from transformers import pipeline + +generator = pipeline("text-generation", model="{model_id}") +outputs = generator("Hello, my llama is cute") +``` + +If you want to use the model for training or to obtain the outputs from the value head, load the model as follows: + +```python +from transformers import AutoTokenizer +from trl import AutoModelForCausalLMWithValueHead + +tokenizer = AutoTokenizer.from_pretrained("{model_id}") +model = AutoModelForCausalLMWithValueHead.from_pretrained("{model_id}") + +inputs = tokenizer("Hello, my llama is cute", return_tensors="pt") +outputs = model(**inputs, labels=inputs["input_ids"]) +``` +""" + + +class PPOTrainer(BaseTrainer): + """ + The PPOTrainer uses Proximal Policy Optimization to optimise language models. + Note, this trainer is heavily inspired by the original OpenAI learning to summarize work here: + https://github.com/openai/summarize-from-feedback + + Attributes: + **config** (`PPOConfig`) -- Configuration object for PPOTrainer. Check the documentation of `PPOConfig` for more + details. + **model** (`PreTrainedModelWrapper`) -- Model to be optimized, Hugging Face transformer model with a value head. + Check the documentation of `PreTrainedModelWrapper` for more details. + **ref_model** (`PreTrainedModelWrapper`, *optional*) -- Reference model to be used for KL penalty, Hugging Face + transformer model with a casual language modelling head. Check the documentation of `PreTrainedModelWrapper` + for more details. If no reference model is provided, the trainer will create a reference model with the same + architecture as the model to be optimized with shared layers. + **tokenizer** (`PreTrainedTokenizerBase`) -- Tokenizer to be used for encoding the + data. Check the documentation of `transformers.PreTrainedTokenizer` and + `transformers.PreTrainedTokenizerFast` for more details. + **dataset** (Union[`torch.utils.data.Dataset`, `datasets.Dataset`], *optional*) -- PyTorch dataset or Hugging + Face dataset. This is used to create a PyTorch dataloader. If no dataset is provided, the dataloader must be + created outside the trainer users needs to design their own dataloader and make sure the batch + size that is used is the same as the one specified in the configuration object. + **optimizer** (`torch.optim.Optimizer`, *optional*) -- Optimizer to be used for training. If no optimizer is + provided, the trainer will create an Adam optimizer with the learning rate specified in the configuration + object. + **data_collator** (DataCollatorForLanguageModeling, *optional*) -- Data collator to be used for training and + passed along the dataloader + **num_shared_layers** (int, *optional*) -- Number of layers to be shared between the model and the reference + model, if no reference model is passed. If no number is provided, all the layers will be shared. + **lr_scheduler** (`torch.optim.lr_scheduler`, *optional*) -- Learning rate scheduler to be used for training. + """ + + def __init__( + self, + config: PPOConfig = None, + model: PreTrainedModelWrapper = None, + ref_model: Optional[PreTrainedModelWrapper] = None, + tokenizer: PreTrainedTokenizerBase = None, + dataset: Optional[Union[torch.utils.data.Dataset, Dataset]] = None, + optimizer: Optional[torch.optim.Optimizer] = None, + data_collator: Optional[typing.Callable] = None, + num_shared_layers: Optional[int] = None, + lr_scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None, + ): + """ + Initialize PPOTrainer. + + Args: + config (`PPOConfig`): + Configuration object for PPOTrainer. Check the documentation of `PPOConfig` for more details. + model (`PreTrainedModelWrapper`): + Hugging Face transformer model with a value head. + ref_model (`PreTrainedModelWrapper`): + Hugging Face transformer model with a casual language modelling head. Used for KL penalty + tokenizer (`transformers.PreTrainedTokenizerBase`): + Hugging Face tokenizer + dataset (Optional[Union[`torch.utils.data.Dataset`, `datasets.Dataset`]]): + PyTorch dataset or Hugging Face dataset. If a Hugging Face dataset is passed, the dataset + will be preprocessed by removing the columns that are not used by the model. If none is passed, + a warning will be raised in a multi-GPU setting. + optimizer (Optional[`torch.optim.Optimizer`]): + Optimizer used for training. If `None`, the `Adam` is used as default. + data_collator (Optional[function]): + Data collator function. + num_shared_layers (Optional[int]): + Number of shared layers between the model and the reference model. If `None`, all layers are shared. + used only if `ref_model` is `None`. + lr_scheduler (Optional[`torch.optim.lr_scheduler`]): + Learning rate scheduler used for training. + """ + super().__init__(config) + + # initial seed for reproducible experiments + set_seed(config.seed) + + # Step 0: check positional arguments validity + if not isinstance(config, PPOConfig): + raise ValueError(f"config must be a PPOConfig, got {type(config)}") + if not isinstance(tokenizer, (PreTrainedTokenizerBase)): + raise ValueError( + f"tokenizer must be a PreTrainedTokenizerBase like a PreTrainedTokenizer or a PreTrainedTokenizerFast, got {type(tokenizer)}" + ) + if not isinstance(model, (SUPPORTED_ARCHITECTURES)): + raise ValueError( + f"model must be a PreTrainedModelWrapper, got {type(model)} - supported architectures are: {SUPPORTED_ARCHITECTURES}" + ) + # Step 1: Initialize Accelerator + self.accelerator = Accelerator( + log_with=config.log_with, + gradient_accumulation_steps=config.gradient_accumulation_steps, + project_config=ProjectConfiguration(**config.project_kwargs), + **config.accelerator_kwargs, + ) + + # Step 1.1 Runtime variables filled by the accelerator + config.world_size = self.accelerator.num_processes + config.global_backward_batch_size = config.backward_batch_size * config.world_size + config.global_batch_size = config.batch_size * config.world_size + + self.model = model + self.model_params = filter(lambda p: p.requires_grad, self.model.parameters()) + self.is_encoder_decoder = hasattr(self.model, "is_encoder_decoder") + self.is_peft_model = getattr(self.model, "is_peft_model", False) + config.is_encoder_decoder = self.is_encoder_decoder + config.is_peft_model = self.is_peft_model + + is_using_tensorboard = config.log_with is not None and config.log_with == "tensorboard" + self.accelerator.init_trackers( + config.tracker_project_name, + config=dict(trl_ppo_trainer_config=config.to_dict()) if not is_using_tensorboard else config.to_dict(), + init_kwargs=config.tracker_kwargs, + ) + self.is_using_text_environment = getattr(config, "use_text_environment", False) + + if isinstance(ref_model, SUPPORTED_ARCHITECTURES): + self.ref_model = ref_model + if num_shared_layers is not None: + warnings.warn( + "num_shared_layers is ignored when ref_model is provided. Two different models are used for the " + "model and the reference model and no layers are shared.", + UserWarning, + ) + elif ref_model is None and not self.is_peft_model: + self.ref_model = create_reference_model(self.model, num_shared_layers=num_shared_layers) + elif self.is_peft_model: + self.ref_model = None + else: + raise ValueError( + f"ref_model must be a PreTrainedModelWrapper or `None`, got {type(ref_model)} - supported " + f"architectures are: {SUPPORTED_ARCHITECTURES} " + ) + self.optional_peft_ctx = ( + self.accelerator.unwrap_model(self.model).pretrained_model.disable_adapter + if self.is_peft_model + else nullcontext + ) + + if not (isinstance(tokenizer, PreTrainedTokenizer) or isinstance(tokenizer, PreTrainedTokenizerFast)): + raise ValueError( + "tokenizer must be a transformers.PreTrainedTokenizer or transformers.PreTrainedTokenizerFast" + ) + self.tokenizer = tokenizer + + if dataset is not None and not (isinstance(dataset, torch.utils.data.Dataset) or isinstance(dataset, Dataset)): + raise ValueError("dataset must be a torch.utils.data.Dataset or datasets.Dataset") + elif dataset is None: + warnings.warn( + "No dataset is provided. Make sure to set config.batch_size to the correct value before training.", + UserWarning, + ) + self.dataset = dataset + self._signature_columns = None + if self.dataset is not None: + self.dataloader = self.prepare_dataloader(self.dataset, data_collator) + elif self.dataset is None and self.accelerator.num_processes > 1: + warnings.warn( + "No dataset is provided. In a multi-GPU setting, this will lead to an error. You should" + " prepare your dataloader yourself with `dataloader = ppo_trainer.accelerator.prepare(dataloader)`" + " and using `torch.utils.data.DataLoader`, or pass a dataset to the `PPOTrainer`. Please " + " refer to the documentation for more details.", + UserWarning, + ) + self.dataloader = None + else: + self.dataloader = None + + # Step 3: Initialize optimizer and data collator + self.data_collator = DataCollatorForLanguageModeling(self.tokenizer, mlm=False) + if optimizer is None: + self.optimizer = Adam( + filter(lambda p: p.requires_grad, self.model.parameters()), + lr=self.config.learning_rate, + ) + else: + self.optimizer = optimizer + + self.lr_scheduler = lr_scheduler + if self.lr_scheduler is not None: + lr_scheduler_class = ( + torch.optim.lr_scheduler._LRScheduler + if not is_torch_greater_2_0() + else torch.optim.lr_scheduler.LRScheduler + ) + + if not isinstance(self.lr_scheduler, lr_scheduler_class): + raise ValueError( + "lr_scheduler must be a torch.optim.lr_scheduler._LRScheduler or torch.optim.lr_scheduler.LRScheduler (for torch >= 2.0)" + ) + + if self.config.adap_kl_ctrl: + self.kl_ctl = AdaptiveKLController(self.config.init_kl_coef, self.config.target, self.config.horizon) + else: + self.kl_ctl = FixedKLController(self.config.init_kl_coef) + + # Safety checkers for DS integration + is_deepspeed_used = self.accelerator.distributed_type == "DEEPSPEED" and hasattr( + self.accelerator.state, "deepspeed_plugin" + ) + + ( + self.model, + self.optimizer, + self.data_collator, + self.dataloader, + self.lr_scheduler, + ) = self.accelerator.prepare( + self.model, + self.optimizer, + self.data_collator, + self.dataloader, + self.lr_scheduler, + ) + if is_deepspeed_used: + # Quantized models are already set on the correct device + if not self.is_peft_model and not ( + getattr(self.ref_model.pretrained_model, "is_loaded_in_8bit", False) + or getattr(self.ref_model.pretrained_model, "is_loaded_in_4bit", False) + ): + self.ref_model = self._prepare_deepspeed(self.ref_model) + else: + self.ref_model = self.accelerator.prepare(self.ref_model) + + # In a distributed setup, only logging needs to be performed on the main process + # check: https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html + # or: https://discuss.pytorch.org/t/use-distributed-data-parallel-correctly/82500/11 + self.is_distributed = self.accelerator.num_processes > 1 + + # init the current step + self.current_step = 0 + + # init variables for pushing model to hub + if config.push_to_hub_if_best_kwargs: + if "repo_id" not in config.push_to_hub_if_best_kwargs: + raise ValueError("You have to specify repo_id in order to push the model to the hub!") + self.push_to_hub_kwargs = config.push_to_hub_if_best_kwargs + self.compare_step = 0 + self.highest_reward = torch.tensor(-float("inf")) + + # post process for PP + if not getattr(self.model, "is_sequential_parallel", False): + self.current_device = self.accelerator.device + else: + if is_xpu_available(): + self.current_device = torch.device("xpu:0") + else: + self.current_device = torch.device("cuda:0") + + PPODecorators.optimize_device_cache = self.config.optimize_device_cache + + self.running = RunningMoments(self.accelerator) + + def _filter_kwargs(self, kwargs, target_func): + """ + filter the keyword arguments that are supported by the target function. + + Args: + kwargs (dict): + Keyword arguments + target_func (function): + Target function + """ + return {k: v for k, v in kwargs.items() if k in inspect.signature(target_func).parameters.keys()} + + def prepare_dataloader(self, dataset: Union[torch.utils.data.Dataset, Dataset], data_collator=None): + """ + Prepare the dataloader for training. + + Args: + dataset (Union[`torch.utils.data.Dataset`, `datasets.Dataset`]): + PyTorch dataset or Hugging Face dataset. If a Hugging Face dataset is passed, the dataset + will be preprocessed by removing the columns that are not used by the model. + data_collator (Optional[function]): + Data collator function. + + Returns: + `torch.utils.data.DataLoader`: PyTorch dataloader + """ + if isinstance(dataset, Dataset): + dataset = self._remove_unused_columns(dataset) + dataloader = torch.utils.data.DataLoader( + dataset, + batch_size=self.config.batch_size, + collate_fn=data_collator, + shuffle=True, + drop_last=True, + ) + return dataloader + + # Adapted from transformers.Trainer._set_signature_columns_if_needed + def _set_signature_columns_if_needed(self): + if self._signature_columns is None: + # Inspect model forward signature to keep only the arguments it accepts. + signature = inspect.signature(self.model.forward) + self._signature_columns = list(signature.parameters.keys()) + # label => sentiment | we need query and response for logging purpose + self._signature_columns += ["label", "query", "response"] + + # Adapted from transformers.Trainer._remove_unused_columns + def _remove_unused_columns(self, dataset: "Dataset"): + if not self.config.remove_unused_columns: + return dataset + self._set_signature_columns_if_needed() + signature_columns = self._signature_columns + + ignored_columns = list(set(dataset.column_names) - set(signature_columns)) + + columns = [k for k in signature_columns if k in dataset.column_names] + + if version.parse(datasets.__version__) < version.parse("1.4.0"): + dataset.set_format( + type=dataset.format["type"], + columns=columns, + format_kwargs=dataset.format["format_kwargs"], + ) + return dataset + else: + return dataset.remove_columns(ignored_columns) + + def generate( + self, + query_tensor: Union[torch.Tensor, List[torch.Tensor]], + length_sampler: Callable = None, + batch_size: int = 4, + return_prompt: bool = True, + generate_ref_response: bool = False, + **generation_kwargs, + ): + """ + Generate response with the model given the query tensor. + call the `generate` method of the model. + + Args: + query_tensor (`torch.LongTensor`): + A tensor of shape (`seq_len`) containing query tokens or a list of tensors of shape (`seq_len`). + generation_kwargs (dict[str, Any]): + Keyword arguments for generation. + length_sampler (`Callable`, *optional*): + Callable that returns the number of newly generated tokens. + batch_size (`int`, *optional): + Batch size used for generation, defaults to `4`. + return_prompt (`bool`, *optional*): + If set to `False` the prompt is not returned but only the newly generated tokens, defaults to `True`. + generate_ref_response (`bool`, *optional*): + If set to `True` the reference response is also generated, defaults to `False`. + + Returns: + `torch.LongTensor`: A tensor of shape (`batch_size`, `gen_len`) containing response tokens. + """ + if generate_ref_response: + ref_model = self.model if self.is_peft_model else self.ref_model + if isinstance(query_tensor, List): + response = self._generate_batched( + self.model, + query_tensor, + length_sampler=length_sampler, + batch_size=batch_size, + return_prompt=return_prompt, + **generation_kwargs, + ) + if generate_ref_response: + with self.optional_peft_ctx(): + ref_response = self._generate_batched( + ref_model, + query_tensor, + length_sampler=length_sampler, + batch_size=batch_size, + return_prompt=return_prompt, + **generation_kwargs, + ) + + else: + if len(query_tensor.shape) == 2: + raise ValueError( + "query_tensor must be a tensor of shape (`seq_len`) or a list of tensors of shape (`seq_len`)" + ) + + if length_sampler is not None: + generation_kwargs["max_new_tokens"] = length_sampler() + response = self.accelerator.unwrap_model(self.model).generate( + input_ids=query_tensor.unsqueeze(dim=0), **generation_kwargs + ) + if generate_ref_response: + with self.optional_peft_ctx(): + ref_response = ref_model.generate(input_ids=query_tensor.unsqueeze(dim=0), **generation_kwargs) + + if not return_prompt and not self.is_encoder_decoder: + response = response[:, query_tensor.shape[0] :] + if generate_ref_response: + ref_response = ref_response[:, query_tensor.shape[0] :] + + if generate_ref_response: + return response, ref_response + return response + + def _generate_batched( + self, + model: PreTrainedModelWrapper, + query_tensors: List[torch.Tensor], + length_sampler: Callable = None, + batch_size: int = 4, + return_prompt: bool = True, + pad_to_multiple_of: int = None, + remove_padding: bool = True, + **generation_kwargs, + ): + outputs = [] + + padding_side_default = self.tokenizer.padding_side + if not self.is_encoder_decoder: + self.tokenizer.padding_side = "left" + + # in case we have fewer examples than bs + batch_size = min(len(query_tensors), batch_size) + + for i in range(0, len(query_tensors), batch_size): + if length_sampler is not None: + generation_kwargs["max_new_tokens"] = length_sampler() + + # prevent overflow if query tensors are not even multiple of bs + end_index = min(len(query_tensors), i + batch_size) + + batch = query_tensors[i:end_index] + batch_mask = [torch.ones_like(element) for element in batch] + inputs = {"input_ids": batch, "attention_mask": batch_mask} + + padded_inputs = self.tokenizer.pad( + inputs, + padding=True, + max_length=None, + pad_to_multiple_of=pad_to_multiple_of, + return_tensors="pt", + ).to(self.current_device) + + generations = self.accelerator.unwrap_model(model).generate(**padded_inputs, **generation_kwargs) + + for generation, mask in zip(generations, padded_inputs["attention_mask"]): + if not self.is_encoder_decoder: + output = generation[(1 - mask).sum() :] # remove padding + else: + output = generation + + if not return_prompt and not self.is_encoder_decoder: + output = output[(mask).sum() :] # remove prompt + + if remove_padding and self.tokenizer.eos_token_id in output: + pad_mask = output == self.tokenizer.eos_token_id + pad_start = torch.nonzero(pad_mask, as_tuple=False)[0, 0].item() + output = output[: pad_start + 1] # keep the eos token at the end + + outputs.append(output) + + self.tokenizer.padding_side = padding_side_default + return outputs + + def _step_safety_checker( + self, + batch_size: int, + queries: List[torch.LongTensor], + responses: List[torch.LongTensor], + scores: List[torch.FloatTensor], + masks: Optional[List[torch.LongTensor]] = None, + ): + """ + Check if the input data is valid for training. + + Args: + batch_size (int): + Batch size from the config file. + queries (List[`torch.LongTensor`]): + List of tensors containing the encoded queries of shape (`query_length`) + responses (List[`torch.LongTensor`]): + List of tensors containing the encoded responses of shape (`response_length`) + scores (List[`torch.FloatTensor`]): + List of tensors containing the scores. + masks (List[`torch.LongTensor`], *optional*): + list of optional tensors containing the masks of shape (`query_length` + `response_length`) + Returns: + `tuple`: The input processed data. + """ + for name, tensor_list in zip(["queries", "responses", "scores"], [queries, responses, scores]): + if not isinstance(tensor_list, list): + raise ValueError(f"{name} must be a list of tensors - got {type(tensor_list)}") + if not isinstance(tensor_list[0], torch.Tensor): + raise ValueError(f"Elements in {name} must be tensors - got {type(tensor_list[0])}") + if batch_size is not None and len(tensor_list) != batch_size: + raise ValueError( + f"Batch size ({batch_size}) does not match number of examples - but got {len(tensor_list)} for: {name}" + ) + + # add queries, scores and responses on the correct device + queries = [tensor.to(self.current_device) for tensor in queries] + responses = [tensor.to(self.current_device) for tensor in responses] + scores = [tensor.to(self.current_device) for tensor in scores] + masks = [tensor.to(self.current_device) for tensor in masks] if masks is not None else None + + # squeeze scores if needed + for i, score in enumerate(scores): + if score.dim() > 1: + raise ValueError(f"Scores must be 1-dimensional - got {score.dim()} for {score}") + elif score.dim() == 1: + scores[i] = score.squeeze() + + return queries, responses, scores, masks + + @PPODecorators.empty_device_cache() + def step( + self, + queries: List[torch.LongTensor], + responses: List[torch.LongTensor], + scores: List[torch.FloatTensor], + response_masks: Optional[List[torch.LongTensor]] = None, + ): + """ + Run a PPO optimisation step given a list of queries, model responses, and rewards. + + Args: + queries (List[`torch.LongTensor`]): + List of tensors containing the encoded queries of shape (`query_length`) + responses (List[`torch.LongTensor`]): + List of tensors containing the encoded responses of shape (`response_length`) + scores (List[`torch.FloatTensor`]): + List of tensors containing the scores. + response_masks (List[`torch.FloatTensor`], *optional*)): + List of tensors containing masks of the response tokens. + + Returns: + `dict[str, Any]`: A summary of the training statistics + """ + bs = self.config.batch_size + + queries, responses, scores, response_masks = self._step_safety_checker( + bs, queries, responses, scores, response_masks + ) + scores = torch.tensor(scores, device=self.current_device) + if self.config.use_score_scaling: + # Score scaling + scores_mean, scores_std = self.running.update(scores) + tensor_to_kwargs = dict(dtype=scores.dtype, device=scores.device) + score_scaling_factor = self.running.std.to(**tensor_to_kwargs) + torch.finfo(scores.dtype).eps + if self.config.use_score_norm: + scores = (scores - self.running.mean.to(**tensor_to_kwargs)) / score_scaling_factor + else: + scores /= score_scaling_factor + + if self.config.score_clip is not None: + # Score clipping + scores_dtype = scores.dtype + scores = torch.clip(scores.float(), -self.config.score_clip, self.config.score_clip).to(dtype=scores_dtype) + + # if we want to push best model to the hub + if hasattr(self, "highest_reward"): + if self.compare_step % self.config.compare_steps == 0: + curr_mean_reward = scores.mean() + # if the best reward ever seen + if curr_mean_reward > self.highest_reward: + self.highest_reward = curr_mean_reward + # push model to hub + self.push_to_hub(**self.push_to_hub_kwargs) + self.compare_step += 1 + + timing = dict() + t0 = time.time() + + t = time.time() + + model_inputs = self.prepare_model_inputs(queries, responses) + + if self.is_distributed: + pad_first = self.tokenizer.padding_side == "left" + + model_inputs["input_ids"] = self.accelerator.pad_across_processes( + model_inputs["input_ids"], + dim=1, + pad_index=self.tokenizer.pad_token_id, + pad_first=pad_first, + ) + model_inputs["attention_mask"] = self.accelerator.pad_across_processes( + model_inputs["attention_mask"], dim=1, pad_index=0, pad_first=pad_first + ) + if self.is_encoder_decoder: + model_inputs["decoder_input_ids"] = self.accelerator.pad_across_processes( + model_inputs["decoder_input_ids"], + dim=1, + pad_index=self.tokenizer.pad_token_id, + pad_first=pad_first, + ) + model_inputs["decoder_attention_mask"] = self.accelerator.pad_across_processes( + model_inputs["decoder_attention_mask"], + dim=1, + pad_index=0, + pad_first=pad_first, + ) + + model_inputs_names = list(model_inputs.keys()) + + full_kl_penalty = self.config.kl_penalty == "full" + + with torch.no_grad(): + all_logprobs, logits_or_none, values, masks = self.batched_forward_pass( + self.model, + queries, + responses, + model_inputs, + response_masks=response_masks, + return_logits=full_kl_penalty, + ) + with self.optional_peft_ctx(): + ref_logprobs, ref_logits_or_none, _, _ = self.batched_forward_pass( + self.model if self.is_peft_model else self.ref_model, + queries, + responses, + model_inputs, + return_logits=full_kl_penalty, + ) + + timing["time/ppo/forward_pass"] = time.time() - t + + with torch.no_grad(): + t = time.time() + if full_kl_penalty: + active_full_logprobs = logprobs_from_logits(logits_or_none, None, gather=False) + ref_full_logprobs = logprobs_from_logits(ref_logits_or_none, None, gather=False) + + rewards, non_score_reward = self.compute_rewards( + scores, active_full_logprobs, ref_full_logprobs, masks + ) + else: + rewards, non_score_reward = self.compute_rewards(scores, all_logprobs, ref_logprobs, masks) + timing["time/ppo/compute_rewards"] = time.time() - t + + t = time.time() + values, advantages, returns = self.compute_advantages(values, rewards, masks) + timing["time/ppo/compute_advantages"] = time.time() - t + + # upcast to float32 to avoid dataset issues + batch_dict = { + "queries": queries, + "responses": responses, + "logprobs": all_logprobs.to(torch.float32), + "values": values.to(torch.float32), + "masks": masks, + "advantages": advantages, + "returns": returns, + } + batch_dict.update(model_inputs) + + t = time.time() + all_stats = [] + early_stop = False + for _ in range(self.config.ppo_epochs): + if early_stop: + break + b_inds = np.random.permutation(bs) + for backward_batch_start in range(0, bs, self.config.backward_batch_size): + backward_batch_end = backward_batch_start + self.config.backward_batch_size + backward_batch_inds = b_inds[backward_batch_start:backward_batch_end] + + for mini_batch_start in range(0, self.config.backward_batch_size, self.config.mini_batch_size): + mini_batch_end = mini_batch_start + self.config.mini_batch_size + mini_batch_inds = backward_batch_inds[mini_batch_start:mini_batch_end] + mini_batch_dict = { + "logprobs": batch_dict["logprobs"][mini_batch_inds], + "values": batch_dict["values"][mini_batch_inds], + "masks": batch_dict["masks"][mini_batch_inds], + # hacks: the queries and responses are ragged. + "queries": [batch_dict["queries"][i] for i in mini_batch_inds], + "responses": [batch_dict["responses"][i] for i in mini_batch_inds], + "advantages": batch_dict["advantages"][mini_batch_inds], + "returns": batch_dict["returns"][mini_batch_inds], + } + for k in model_inputs_names: + mini_batch_dict[k] = batch_dict[k][mini_batch_inds] + with self.accelerator.accumulate(self.model): + model_inputs = {k: mini_batch_dict[k] for k in model_inputs_names} + + logprobs, logits, vpreds, _ = self.batched_forward_pass( + self.model, + mini_batch_dict["queries"], + mini_batch_dict["responses"], + model_inputs, + return_logits=True, + ) + train_stats = self.train_minibatch( + mini_batch_dict["logprobs"], + mini_batch_dict["values"], + logprobs, + logits, + vpreds, + mini_batch_dict["masks"], + mini_batch_dict["advantages"], + mini_batch_dict["returns"], + ) + all_stats.append(train_stats) + + # typically, early stopping is done at the epoch level + if self.config.early_stopping: + policykl = train_stats["policy/policykl"] + early_stop = self._early_stop(policykl) + if early_stop: + break + + timing["time/ppo/optimize_step"] = time.time() - t + + t = time.time() + train_stats = stack_dicts(all_stats) + + # reshape advantages/ratios such that they are not averaged. + train_stats["policy/advantages"] = torch.flatten(train_stats["policy/advantages"]).unsqueeze(0) + train_stats["policy/advantages"] = torch.nan_to_num(train_stats["policy/advantages"], WANDB_PADDING) + train_stats["policy/ratio"] = torch.flatten(train_stats["policy/ratio"]).unsqueeze(0) + + stats = self.record_step_stats( + scores=scores, + logprobs=all_logprobs, + ref_logprobs=ref_logprobs, + non_score_reward=non_score_reward, + train_stats=train_stats, + kl_coef=self.kl_ctl.value, + masks=masks, + queries=queries, + responses=responses, + ) + # Gather/Reduce stats from all processes + if self.is_distributed: + stats = self.gather_stats(stats) + stats = stats_to_np(stats) + timing["time/ppo/calc_stats"] = time.time() - t + stats["ppo/learning_rate"] = self.optimizer.param_groups[0]["lr"] + + # Update the KL control - multiply the batch_size by the number of processes + self.kl_ctl.update( + stats["objective/kl"], + self.config.batch_size * self.accelerator.num_processes, + ) + + # Log the total ppo time + timing["time/ppo/total"] = time.time() - t0 + stats.update(timing) + + # post-process stats for tensorboard and other loggers + if self.config.log_with != "wandb": + stats = convert_to_scalar(stats) + + if self.lr_scheduler is not None: + self.lr_scheduler.step() + + return stats + + def _early_stop(self, policykl): + r""" + Handles the early stopping logic. If the policy KL is greater than the target KL, then the gradient is zeroed and + the optimization step is skipped. + This also handles the multi-gpu case where the policy KL is averaged across all processes. + + Args: + policy_kl (torch.Tensor): + the policy KL + + Returns: + `bool`: whether to early stop or not + """ + early_stop = False + if not self.config.early_stopping: + return early_stop + + if not self.is_distributed and policykl > 1.5 * self.config.target_kl: + self.optimizer.zero_grad() + early_stop = True + elif self.is_distributed: + import torch.distributed as dist + + # Wait for all processes to finish + dist.barrier() + + # all gather the policykl + dist.all_reduce(policykl, dist.ReduceOp.SUM) + policykl /= self.accelerator.num_processes + + if policykl > 1.5 * self.config.target_kl: + self.optimizer.zero_grad() + early_stop = True + return early_stop + + def gather_stats(self, stats): + """ + Gather stats from all processes. Useful in the context of distributed training. + + Args: + stats (dict[str, Any]): + a dictionary of stats to be gathered. The stats should contain torch tensors. + + Returns: + `dict[str, Any]`: A dictionary of stats with the tensors gathered. + """ + import torch.distributed as dist + + # Wait for all processes to finish + dist.barrier() + + for k, v in stats.items(): + if isinstance(v, torch.Tensor): + dist.all_reduce(v.to(self.accelerator.device), dist.ReduceOp.SUM) + v /= self.accelerator.num_processes + stats[k] = v + return stats + + def prepare_model_inputs(self, queries: torch.Tensor, responses: torch.Tensor): + if self.is_encoder_decoder: + input_data = self.data_collator( + [{"input_ids": q, "attention_mask": torch.ones_like(q)} for q in queries] + ).to(self.current_device) + + decoder_inputs = self.data_collator( + [{"input_ids": r, "attention_mask": torch.ones_like(r)} for r in responses] + ).to(self.current_device) + + input_data["decoder_input_ids"] = decoder_inputs["input_ids"] + input_data["decoder_attention_mask"] = decoder_inputs["attention_mask"] + else: + input_ids = [torch.cat([q, r]) for q, r in zip(queries, responses)] + input_data = self.data_collator( + [{"input_ids": ids, "attention_mask": torch.ones_like(ids)} for ids in input_ids] + ).to(self.current_device) + + input_data.pop("labels", None) # we don't want to compute LM losses + return input_data + + @PPODecorators.empty_device_cache() + def batched_forward_pass( + self, + model: PreTrainedModelWrapper, + queries: torch.Tensor, + responses: torch.Tensor, + model_inputs: dict, + return_logits: bool = False, + response_masks: Optional[torch.Tensor] = None, + ): + """ + Calculate model outputs in multiple batches. + + Args: + queries (`torch.LongTensor`): + List of tensors containing the encoded queries, shape (`batch_size`, `query_length`) + responses (`torch.LongTensor`): + List of tensors containing the encoded responses, shape (`batch_size`, `response_length`) + return_logits (`bool`, *optional*, defaults to `False`): + Whether to return all_logits. Set to `False` if logits are not needed to reduce memory consumption. + Returns: + (tuple): + - all_logprobs (`torch.FloatTensor`): Log probabilities of the responses, + shape (`batch_size`, `response_length`) + - all_ref_logprobs (`torch.FloatTensor`): Log probabilities of the responses, + shape (`batch_size`, `response_length`) + - all_values (`torch.FloatTensor`): Values of the responses, shape (`batch_size`, `response_length`) + """ + bs = len(queries) + fbs = self.config.mini_batch_size + all_logprobs = [] + all_logits = [] + all_masks = [] + all_values = [] + + model.eval() + + for i in range(math.ceil(bs / fbs)): + input_kwargs = {key: value[i * fbs : (i + 1) * fbs] for key, value in model_inputs.items()} + query_batch = queries[i * fbs : (i + 1) * fbs] + response_batch = responses[i * fbs : (i + 1) * fbs] + if response_masks is not None: + response_masks_batch = response_masks[i * fbs : (i + 1) * fbs] + logits, _, values = model(**input_kwargs) + + if self.is_encoder_decoder: + input_ids = input_kwargs["decoder_input_ids"] + attention_mask = input_kwargs["decoder_attention_mask"] + else: + input_ids = input_kwargs["input_ids"] + attention_mask = input_kwargs["attention_mask"] + + logprobs = logprobs_from_logits(logits[:, :-1, :], input_ids[:, 1:]) + masks = torch.zeros_like(attention_mask) + masks[:, :-1] = attention_mask[:, 1:] + + for j in range(len(query_batch)): + if self.is_encoder_decoder: + # Decoder sentence starts always in the index 1 after padding in the Enc-Dec Models + start = 1 + end = attention_mask[j, :].sum() - 1 + else: + start = len(query_batch[j]) - 1 # logprobs starts from the second query token + if attention_mask[j, 0] == 0: # offset left padding + start += attention_mask[j, :].nonzero()[0] + end = start + len(response_batch[j]) + if response_masks is not None: + response_masks_batch[j] = torch.cat( + (torch.zeros_like(query_batch[j]), response_masks_batch[j]) + )[1:] + + masks[j, :start] = 0 + masks[j, end:] = 0 + if response_masks is not None: + masks[j, start:end] = masks[j, start:end] * response_masks_batch[j][start:end] + + if return_logits: + all_logits.append(logits) + else: + del logits + all_values.append(values) + all_logprobs.append(logprobs) + all_masks.append(masks) + + return ( + torch.cat(all_logprobs), + torch.cat(all_logits)[:, :-1] if return_logits else None, + torch.cat(all_values)[:, :-1], + torch.cat(all_masks)[:, :-1], + ) + + @PPODecorators.empty_device_cache() + def train_minibatch( + self, + old_logprobs: torch.FloatTensor, + values: torch.FloatTensor, + logprobs: torch.FloatTensor, + logits: torch.FloatTensor, + vpreds: torch.FloatTensor, + mask: torch.LongTensor, + advantages: torch.FloatTensor, + returns: torch.FloatTensor, + ): + """ + Train one PPO minibatch + + Args: + logprobs (`torch.FloatTensor`): + Log probabilities of the model, shape [batch_size, response_length] + values (`torch.FloatTensor`): + Values of the value head, shape [batch_size, response_length] + query (`torch.LongTensor`): + Encoded queries, shape [batch_size, query_length] + response (`torch.LongTensor`): + Encoded responses, shape [batch_size, response_length] + model_input (`torch.LongTensor`): + Concatenated queries and responses, shape [batch_size, query_length+response_length] + + Returns: + train_stats (dict[str, `torch.Tensor`]): + Dictionary of training statistics + """ + self.model.train() + loss_p, loss_v, train_stats = self.loss( + old_logprobs, values, logits, vpreds, logprobs, mask, advantages, returns + ) + loss = loss_p + loss_v + self.accelerator.backward(loss) + if self.config.max_grad_norm is not None: + if self.accelerator.sync_gradients: + self.accelerator.clip_grad_norm_(self.model_params, self.config.max_grad_norm) + self.optimizer.step() + # we call optimizer.zero_grad() every time and let `accelerator` handle accumulation + # see https://huggingface.co/docs/accelerate/usage_guides/gradient_accumulation#the-finished-code + self.optimizer.zero_grad() + return train_stats + + def compute_rewards( + self, + scores: torch.FloatTensor, + logprobs: torch.FloatTensor, + ref_logprobs: torch.FloatTensor, + masks: torch.LongTensor, + ): + """ + Compute per token rewards from scores and KL-penalty. + + Args: + scores (`torch.FloatTensor`): + Scores from the reward model, shape (`batch_size`) + logprobs (`torch.FloatTensor`): + Log probabilities of the model, shape (`batch_size`, `response_length`) + ref_logprobs (`torch.FloatTensor`): + Log probabilities of the reference model, shape (`batch_size`, `response_length`) + """ + rewards, non_score_rewards = [], [] + for score, logprob, ref_logprob, mask in zip(scores, logprobs, ref_logprobs, masks): + # compute KL penalty (from difference in logprobs) + kl = self._kl_penalty(logprob, ref_logprob) + non_score_reward = -self.kl_ctl.value * kl + non_score_rewards.append(non_score_reward) + reward = non_score_reward.clone() + last_non_masked_index = mask.nonzero()[-1] + + # reward is preference model score + KL penalty + reward[last_non_masked_index] += score + rewards.append(reward) + return torch.stack(rewards), torch.stack(non_score_rewards) + + def _kl_penalty(self, logprob: torch.FloatTensor, ref_logprob: torch.FloatTensor) -> torch.FloatTensor: + if self.config.kl_penalty == "kl": + return logprob - ref_logprob + + if self.config.kl_penalty == "abs": + return (logprob - ref_logprob).abs() + + if self.config.kl_penalty == "mse": + return 0.5 * (logprob - ref_logprob).square() + + if self.config.kl_penalty == "full": + # Flip is required due to this issue? :https://github.com/pytorch/pytorch/issues/57459 + return F.kl_div(ref_logprob, logprob, log_target=True, reduction="none").sum(-1) + + raise NotImplementedError + + def compute_advantages( + self, + values: torch.FloatTensor, + rewards: torch.FloatTensor, + mask: torch.FloatTensor, + ): + lastgaelam = 0 + advantages_reversed = [] + gen_len = rewards.shape[-1] + + values = values * mask + rewards = rewards * mask + + if self.config.whiten_rewards: + rewards = masked_whiten(rewards, mask, shift_mean=False) + + for t in reversed(range(gen_len)): + nextvalues = values[:, t + 1] if t < gen_len - 1 else 0.0 + delta = rewards[:, t] + self.config.gamma * nextvalues - values[:, t] + lastgaelam = delta + self.config.gamma * self.config.lam * lastgaelam + advantages_reversed.append(lastgaelam) + advantages = torch.stack(advantages_reversed[::-1]).transpose(0, 1) + + returns = advantages + values + advantages = masked_whiten(advantages, mask) + advantages = advantages.detach() + return values, advantages, returns + + def loss( + self, + old_logprobs: torch.FloatTensor, + values: torch.FloatTensor, + logits: torch.FloatTensor, + vpreds: torch.FloatTensor, + logprobs: torch.FloatTensor, + mask: torch.LongTensor, + advantages: torch.FloatTensor, + returns: torch.FloatTensor, + ): + """ + Calculate policy and value losses. + + Args: + old_logprobs (`torch.FloatTensor`): + Log probabilities of the model, shape (`batch_size`, `response_length`) + values (`torch.FloatTensor`): + Values of the value head, shape (`batch_size`, `response_length`) + rewards (`torch.FloatTensor`): + Rewards from the reward model, shape (`batch_size`, `response_length`) + logits (`torch.FloatTensor`): + Logits of the model, shape (`batch_size`, `response_length`, `vocab_size`) + v_pred (`torch.FloatTensor`): + Values of the value head, shape (`batch_size`, `response_length`) + logprobs (`torch.FloatTensor`): + Log probabilities of the model, shape (`batch_size`, `response_length`) + """ + + vpredclipped = clip_by_value( + vpreds, + values - self.config.cliprange_value, + values + self.config.cliprange_value, + ) + + vf_losses1 = (vpreds - returns) ** 2 + vf_losses2 = (vpredclipped - returns) ** 2 + vf_loss = 0.5 * masked_mean(torch.max(vf_losses1, vf_losses2), mask) + vf_clipfrac = masked_mean(torch.gt(vf_losses2, vf_losses1).float(), mask) + + ratio = torch.exp(logprobs - old_logprobs) + + pg_losses = -advantages * ratio + pg_losses2 = -advantages * torch.clamp(ratio, 1.0 - self.config.cliprange, 1.0 + self.config.cliprange) + + pg_loss = masked_mean(torch.max(pg_losses, pg_losses2), mask) + pg_clipfrac = masked_mean(torch.gt(pg_losses2, pg_losses).float(), mask) + + loss = pg_loss + self.config.vf_coef * vf_loss + + avg_ratio = masked_mean(ratio, mask).item() + if avg_ratio > self.config.ratio_threshold: + warnings.warn( + f"The average ratio of batch ({avg_ratio:.2f}) exceeds threshold {self.config.ratio_threshold:.2f}. Skipping batch." + ) + pg_loss = pg_loss * 0.0 + vf_loss = vf_loss * 0.0 + loss = loss * 0.0 + + entropy = masked_mean(entropy_from_logits(logits), mask) + + approxkl = 0.5 * masked_mean((logprobs - old_logprobs) ** 2, mask) + policykl = masked_mean(old_logprobs - logprobs, mask) + + return_mean, return_var = masked_mean(returns, mask), masked_var(returns, mask) + value_mean, value_var = masked_mean(values, mask), masked_var(values, mask) + + stats = dict( + loss=dict(policy=pg_loss.detach(), value=vf_loss.detach(), total=loss.detach()), + policy=dict( + entropy=entropy.detach(), + approxkl=approxkl.detach(), + policykl=policykl.detach(), + clipfrac=pg_clipfrac.detach(), + advantages=advantages.detach(), + advantages_mean=masked_mean(advantages, mask).detach(), + ratio=ratio.detach(), + ), + returns=dict(mean=return_mean.detach(), var=return_var.detach()), + val=dict( + vpred=masked_mean(vpreds, mask).detach(), + error=masked_mean((vpreds - returns) ** 2, mask).detach(), + clipfrac=vf_clipfrac.detach(), + mean=value_mean.detach(), + var=value_var.detach(), + ), + ) + return pg_loss, self.config.vf_coef * vf_loss, flatten_dict(stats) + + def record_step_stats(self, kl_coef: float, **data): + """ + Record training step statistics. + + + Args: + kl_coef (`float`): + KL coefficient + data (`dict`): + Dictionary of training step data + + Returns: + stats (`dict`): + Dictionary of training step statistics + """ + mask = data.pop("masks") + + kl_list = ((data["logprobs"] - data["ref_logprobs"]) * mask).sum(axis=-1) + mean_kl = kl_list.mean() + mean_entropy = (-data["logprobs"] * mask).sum(axis=-1).mean() + + mean_non_score_reward = masked_mean( + data["non_score_reward"], mask + ) # non_score_reward is size `batch_size`, `response_length` + mean_scores = data["scores"].mean() # scores is size `batch_size` + std_scores = data["scores"].std() + + if mean_kl.item() < -1.0: + # warn users + warnings.warn( + f"KL divergence is starting to become negative: {mean_kl.item():.2f} - this might be a precursor for failed training." + " sometimes this happens because the generation kwargs are not correctly set. Please make sure" + " that the generation kwargs are set correctly, or review your training hyperparameters." + ) + + stats = { + "objective/kl": mean_kl, + "objective/kl_dist": kl_list, + "objective/logprobs": data["logprobs"], + "objective/ref_logprobs": data["ref_logprobs"], + "objective/kl_coef": kl_coef, + "objective/entropy": mean_entropy, + "ppo/mean_non_score_reward": mean_non_score_reward, + "ppo/mean_scores": mean_scores, + "ppo/std_scores": std_scores, + } + + # Log text properties + query_lens = torch.tensor([len(query) for query in data["queries"]], dtype=torch.float) + response_lens = torch.tensor([len(response) for response in data["responses"]], dtype=torch.float) + + stats["tokens/queries_len_mean"] = torch.mean(query_lens).cpu().numpy().item() + stats["tokens/queries_len_std"] = torch.std(query_lens).cpu().numpy().item() + stats["tokens/queries_dist"] = query_lens.cpu().numpy() + stats["tokens/responses_len_mean"] = torch.mean(response_lens).cpu().numpy().item() + stats["tokens/responses_len_std"] = torch.std(response_lens).cpu().numpy().item() + stats["tokens/responses_dist"] = response_lens.cpu().numpy() + + for k, v in data["train_stats"].items(): + stats[f"ppo/{k}"] = torch.mean(v, axis=0) + stats["ppo/val/var_explained"] = 1 - stats["ppo/val/error"] / stats["ppo/returns/var"] + return stats + + def log_stats( + self, + stats: dict, + batch: dict, + rewards: List[torch.FloatTensor], + columns_to_log: List[str] = ["query", "response"], + ): + """ + A function that logs all the training stats. Call it at the end of each epoch. + + Args: + stats (dict[str, Any]): + A dictionary of training stats. + batch (dict[str, Any]): + A dictionary of batch data, this contains the queries and responses. + rewards (`List[torch.FloatTensor]`): + A tensor of rewards. + """ + if not isinstance(rewards, torch.Tensor): + rewards = torch.tensor(rewards).to(self.current_device) + rewards = self.accelerator.gather(rewards).flatten() + + # Log only if we are in the main process + if self.accelerator.is_main_process: + logs = {} + + # Log stats + if "query" not in batch.keys() and "response" not in batch.keys(): + # warn the user that the game logs will not be logged + warnings.warn( + "The game logs will not be logged because the batch does not contain the keys 'query' and " + "'response'. " + ) + elif self.config.log_with == "wandb": + import wandb + + if any([column_to_log not in batch.keys() for column_to_log in columns_to_log]): + raise ValueError(f"Columns to log {columns_to_log} are not present in the batch {batch.keys()}.") + + batch_list = [batch[column_to_log] for column_to_log in columns_to_log] + if self.is_distributed: + self.accelerator.wait_for_everyone() + gathered_batch_list = [] + for batch in batch_list: + flattened = gather_object(batch) + gathered_batch_list.append(flattened) + batch_list = gathered_batch_list + + table_rows = [list(r) for r in zip(*batch_list, rewards.cpu().tolist())] + logs.update({"game_log": wandb.Table(columns=[*columns_to_log, "reward"], rows=table_rows)}) + + logs.update(stats) + + # manually cast in fp32 for bf16 torch tensors + for k, v in logs.items(): + if isinstance(v, torch.Tensor) and v.dtype == torch.bfloat16: + logs[k] = v.float() + + logs["env/reward_mean"] = torch.mean(rewards).cpu().numpy().item() + logs["env/reward_std"] = torch.std(rewards).cpu().numpy().item() + logs["env/reward_dist"] = rewards.cpu().numpy() + + if self.config.log_with == "tensorboard": + # update the current step + self.current_step += 1 + + self.accelerator.log( + logs, + step=self.current_step if self.config.log_with == "tensorboard" else None, + ) + + def create_model_card(self, path: str, model_name: Optional[str] = "TRL Model") -> None: + """Creates and saves a model card for a TRL model. + + Args: + path (`str`): The path to save the model card to. + model_name (`str`, *optional*): The name of the model, defaults to `TRL Model`. + """ + try: + user = whoami()["name"] + # handle the offline case + except: # noqa + warnings.warn("Cannot retrieve user information assuming you are running in offline mode.") + return + + if not os.path.exists(path): + os.makedirs(path) + + model_card_content = MODEL_CARD_TEMPLATE.format(model_name=model_name, model_id=f"{user}/{path}") + with open(os.path.join(path, "README.md"), "w", encoding="utf-8") as f: + f.write(model_card_content) + + def _save_pretrained(self, save_directory: str) -> None: + self.accelerator.unwrap_model(self.model).save_pretrained(save_directory) + self.tokenizer.save_pretrained(save_directory) + self.create_model_card(save_directory) + + def _show_tokens(self, tokens, masks): + from rich import print + from rich.text import Text + + text = Text() + + for i, (token, mask) in enumerate(zip(tokens, masks)): + if mask == 1: + text.append(self.tokenizer.decode(token.item()), style="black on deep_sky_blue1") + text.append(" ") + else: + text.append(self.tokenizer.decode(token.item()), style="black on cyan3") + text.append(" ") + print(text) + + def _prepare_deepspeed(self, model: PreTrainedModelWrapper): + # Adapted from accelerate: https://github.com/huggingface/accelerate/blob/739b135f8367becb67ffaada12fe76e3aa60fefd/src/accelerate/accelerator.py#L1473 + deepspeed_plugin = self.accelerator.state.deepspeed_plugin + config_kwargs = deepspeed_plugin.deepspeed_config + if model is not None: + if hasattr(model, "config"): + hidden_size = ( + max(model.config.hidden_sizes) + if getattr(model.config, "hidden_sizes", None) + else getattr(model.config, "hidden_size", None) + ) + if hidden_size is not None and config_kwargs["zero_optimization"]["stage"] == 3: + # Note that `stage3_prefetch_bucket_size` can produce DeepSpeed messages like: `Invalidate trace cache @ step 0: expected module 1, but got module 0` + # This is expected and is not an error, see: https://github.com/microsoft/DeepSpeed/discussions/4081 + config_kwargs.update( + { + "zero_optimization.reduce_bucket_size": hidden_size * hidden_size, + "zero_optimization.stage3_param_persistence_threshold": 10 * hidden_size, + "zero_optimization.stage3_prefetch_bucket_size": 0.9 * hidden_size * hidden_size, + } + ) + + # If ZeRO-3 is used, we shard both the active and reference model. + # Otherwise, we assume the reference model fits in memory and is initialized on each device with ZeRO disabled (stage 0) + if config_kwargs["zero_optimization"]["stage"] != 3: + config_kwargs["zero_optimization"]["stage"] = 0 + model, *_ = deepspeed.initialize(model=model, config=config_kwargs) + model.eval() + return model diff --git a/trl/trl/trl/trainer/reward_trainer.py b/trl/trl/trl/trainer/reward_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..ed81ef73e653f9c25fb1fee06109cd332d35f810 --- /dev/null +++ b/trl/trl/trl/trainer/reward_trainer.py @@ -0,0 +1,277 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import inspect +import warnings +from dataclasses import FrozenInstanceError, replace +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import torch +import torch.nn as nn +from datasets import Dataset +from transformers import DataCollator, PreTrainedModel, PreTrainedTokenizerBase, Trainer, TrainingArguments +from transformers.trainer_callback import TrainerCallback +from transformers.trainer_pt_utils import nested_detach +from transformers.trainer_utils import EvalPrediction + +from ..import_utils import is_peft_available +from .training_configs import RewardConfig +from .utils import PeftSavingCallback, RewardDataCollatorWithPadding, compute_accuracy + + +if is_peft_available(): + from peft import PeftModel, get_peft_model, prepare_model_for_kbit_training + + +class RewardTrainer(Trainer): + r""" + The RewardTrainer can be used to train your custom Reward Model. It is a subclass of the + `transformers.Trainer` class and inherits all of its attributes and methods. It is recommended to use + an `AutoModelForSequenceClassification` as the reward model. The reward model should be trained on a dataset + of paired examples, where each example is a tuple of two sequences. The reward model should be trained to + predict which example in the pair is more relevant to the task at hand. + + The reward trainer expects a very specific format for the dataset. The dataset should contain two 4 entries at least + if you don't use the default `RewardDataCollatorWithPadding` data collator. The entries should be named + - `input_ids_chosen` + - `attention_mask_chosen` + - `input_ids_rejected` + - `attention_mask_rejected` + + Optionally, you can also pass a `margin` entry to the dataset. This entry should contain the margin used to modulate the + loss of the reward model as outlined in https://ai.meta.com/research/publications/llama-2-open-foundation-and-fine-tuned-chat-models/. + If you don't pass a margin, no margin will be used. + """ + + def __init__( + self, + model: Union[PreTrainedModel, nn.Module] = None, + args: Optional[RewardConfig] = None, + data_collator: Optional[DataCollator] = None, + train_dataset: Optional[Dataset] = None, + eval_dataset: Optional[Union[Dataset, Dict[str, Dataset]]] = None, + tokenizer: Optional[PreTrainedTokenizerBase] = None, + model_init: Optional[Callable[[], PreTrainedModel]] = None, + compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None, + callbacks: Optional[List[TrainerCallback]] = None, + optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = ( + None, + None, + ), + preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None, + max_length: Optional[int] = None, + peft_config: Optional[Dict] = None, + ): + """ + Initialize RewardTrainer. + + Args: + model (`transformers.PreTrainedModel`): + The model to train, preferably an `AutoModelForSequenceClassification`. + args (`RewardConfig`): + The arguments to use for training. + data_collator (`transformers.DataCollator`): + The data collator to use for training. If None is specified, the default data collator (`RewardDataCollatorWithPadding`) will be used + which will pad the sequences to the maximum length of the sequences in the batch, given a dataset of paired sequences. + train_dataset (`datasets.Dataset`): + The dataset to use for training. + eval_dataset (`datasets.Dataset`): + The dataset to use for evaluation. + tokenizer (`transformers.PreTrainedTokenizerBase`): + The tokenizer to use for training. This argument is required if you want to use the default data collator. + model_init (`Callable[[], transformers.PreTrainedModel]`): + The model initializer to use for training. If None is specified, the default model initializer will be used. + compute_metrics (`Callable[[transformers.EvalPrediction], Dict]`, *optional* defaults to `compute_accuracy`): + The metrics to use for evaluation. If no metrics are specified, the default metric (`compute_accuracy`) will be used. + callbacks (`List[transformers.TrainerCallback]`): + The callbacks to use for training. + optimizers (`Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`): + The optimizer and scheduler to use for training. + preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`): + The function to use to preprocess the logits before computing the metrics. + peft_config (`Dict`, defaults to `None`): + The PEFT configuration to use for training. If you pass a PEFT configuration, the model will be wrapped in a PEFT model. + """ + if type(args) == TrainingArguments: + warnings.warn( + "Using `transformers.TrainingArguments` for `args` is deprecated and will be removed in a future version. Please use `RewardConfig` instead.", + FutureWarning, + ) + if max_length is not None: + warnings.warn( + "The `max_length` argument is deprecated and will be removed in a future version. Please use the `RewardConfig` to set `max_length` instead.", + FutureWarning, + ) + else: + if max_length is not None and args.max_length is not None: + raise ValueError( + "You cannot specify both `max_length` and `args.max_length`. Please use the `RewardConfig` to set `max_length` once." + ) + if max_length is not None and args.max_length is None: + warnings.warn( + "The `max_length` argument is deprecated and will be removed in a future version. Please use the `RewardConfig` to set `max_length` instead.", + FutureWarning, + ) + if not is_peft_available() and peft_config is not None: + raise ValueError( + "PEFT is not installed and you passed a `peft_config` in the trainer's kwargs, please install it to use the PEFT models" + ) + elif is_peft_available() and peft_config is not None: + if not isinstance(model, PeftModel): + if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_quantized", False): + _supports_gc_kwargs = "gradient_checkpointing_kwargs" in list( + inspect.signature(prepare_model_for_kbit_training).parameters + ) + + preprare_model_kwargs = {"use_gradient_checkpointing": args.gradient_checkpointing} + + if not _supports_gc_kwargs and args.gradient_checkpointing_kwargs is not None: + warnings.warn( + "You passed `gradient_checkpointing_kwargs` in the trainer's kwargs, but your peft version does not support it. " + "please update to the latest version of peft to use `gradient_checkpointing_kwargs`." + ) + elif _supports_gc_kwargs and args.gradient_checkpointing_kwargs is not None: + preprare_model_kwargs["gradient_checkpointing_kwargs"] = args.gradient_checkpointing_kwargs + + model = prepare_model_for_kbit_training(model, **preprare_model_kwargs) + + model = get_peft_model(model, peft_config) + + if is_peft_available() and isinstance(model, PeftModel): + if callbacks is None: + callbacks = [PeftSavingCallback()] + else: + callbacks += [PeftSavingCallback()] + + if compute_metrics is None: + compute_metrics = compute_accuracy + + if data_collator is None: + if tokenizer is None: + raise ValueError( + "max_length or a tokenizer must be specified when using the default RewardDataCollatorWithPadding" + ) + if type(args) == TrainingArguments: + if max_length is None: + warnings.warn( + "When using RewardDataCollatorWithPadding, you should set `max_length` in RewardConfig." + " It will be set to `512` by default, but you should do it yourself in the future.", + UserWarning, + ) + max_length = 512 + else: + if max_length is None and args.max_length is None: + warnings.warn( + "When using RewardDataCollatorWithPadding, you should set `max_length` in RewardConfig." + " It will be set to `512` by default, but you should do it yourself in the future.", + UserWarning, + ) + max_length = 512 + if max_length is None and args.max_length is not None: + max_length = args.max_length + + data_collator = RewardDataCollatorWithPadding(tokenizer, max_length=max_length) + + if args.remove_unused_columns: + try: # for bc before https://github.com/huggingface/transformers/pull/25435 + args.remove_unused_columns = False + except FrozenInstanceError: + args = replace(args, remove_unused_columns=False) + # warn users + warnings.warn( + "When using RewardDataCollatorWithPadding, you should set `remove_unused_columns=False` in your RewardConfig" + " we have set it for you, but you should do it yourself in the future.", + UserWarning, + ) + + self.use_reward_data_collator = True + else: + self.use_reward_data_collator = False + super().__init__( + model, + args, + data_collator, + train_dataset, + eval_dataset, + tokenizer, + model_init, + compute_metrics, + callbacks, + optimizers, + preprocess_logits_for_metrics, + ) + + def compute_loss( + self, + model: Union[PreTrainedModel, nn.Module], + inputs: Dict[str, Union[torch.Tensor, Any]], + return_outputs=False, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, Dict[str, torch.Tensor]]]: + if not self.use_reward_data_collator: + warnings.warn( + "The current compute_loss is implemented for RewardDataCollatorWithPadding," + " if you are using a custom data collator make sure you know what you are doing or" + " implement your own compute_loss method." + ) + rewards_chosen = model( + input_ids=inputs["input_ids_chosen"], + attention_mask=inputs["attention_mask_chosen"], + )[0] + rewards_rejected = model( + input_ids=inputs["input_ids_rejected"], + attention_mask=inputs["attention_mask_rejected"], + )[0] + # calculate loss, optionally modulate with margin + if "margin" in inputs: + loss = -nn.functional.logsigmoid(rewards_chosen - rewards_rejected - inputs["margin"]).mean() + else: + loss = -nn.functional.logsigmoid(rewards_chosen - rewards_rejected).mean() + + if return_outputs: + return loss, { + "rewards_chosen": rewards_chosen, + "rewards_rejected": rewards_rejected, + } + return loss + + def prediction_step( + self, + model: Union[PreTrainedModel, nn.Module], + inputs: Dict[str, Union[torch.Tensor, Any]], + prediction_loss_only: bool, + ignore_keys: Optional[List[str]] = None, + ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]: + inputs = self._prepare_inputs(inputs) + if ignore_keys is None: + if hasattr(self.model, "config"): + ignore_keys = getattr(self.model.config, "keys_to_ignore_at_inference", []) + else: + ignore_keys = [] + + with torch.no_grad(): + loss, logits_dict = self.compute_loss(model, inputs, return_outputs=True) + + if prediction_loss_only: + return (loss, None, None) + + loss = loss.detach() + logits = tuple(v for k, v in logits_dict.items() if k not in ignore_keys) + logits = nested_detach(logits) + # Stack accepted against rejected, mean over logits + # and softmax to get preferences between accepted and rejected to sum to 1 + logits = torch.stack(logits).mean(dim=2).softmax(dim=0).T + + labels = torch.zeros(logits.shape[0]) + labels = self._prepare_inputs(labels) + + return loss, logits, labels diff --git a/trl/trl/trl/trainer/sft_trainer.py b/trl/trl/trl/trainer/sft_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..6c4a6c536a651c6d5e02dc5c580a35cdabb7e106 --- /dev/null +++ b/trl/trl/trl/trainer/sft_trainer.py @@ -0,0 +1,451 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import dataclasses +import inspect +import warnings +from functools import wraps +from typing import Callable, Dict, List, Optional, Tuple, Union + +import torch +import torch.nn as nn +from datasets import Dataset +from datasets.arrow_writer import SchemaInferenceError +from datasets.builder import DatasetGenerationError +from transformers import ( + AutoModelForCausalLM, + AutoTokenizer, + DataCollator, + DataCollatorForLanguageModeling, + PreTrainedModel, + PreTrainedTokenizerBase, + Trainer, + TrainingArguments, +) +from transformers.modeling_utils import unwrap_model +from transformers.trainer_callback import TrainerCallback +from transformers.trainer_utils import EvalPrediction + +from ..import_utils import is_peft_available +from .utils import ( + ConstantLengthDataset, + DataCollatorForCompletionOnlyLM, + PeftSavingCallback, + neftune_post_forward_hook, +) + + +if is_peft_available(): + from peft import PeftConfig, PeftModel, get_peft_model, prepare_model_for_kbit_training + + +class SFTTrainer(Trainer): + r""" + Class definition of the Supervised Finetuning Trainer (SFT Trainer). + This class is a wrapper around the `transformers.Trainer` class and inherits all of its attributes and methods. + The trainer takes care of properly initializing the PeftModel in case a user passes a `PeftConfig` object. + + Args: + model (Union[`transformers.PreTrainedModel`, `nn.Module`, `str`]): + The model to train, can be a `PreTrainedModel`, a `torch.nn.Module` or a string with the model name to + load from cache or download. The model can be also converted to a `PeftModel` if a `PeftConfig` object is + passed to the `peft_config` argument. + args (Optional[`transformers.TrainingArguments`]): + The arguments to tweak for training. Please refer to the official documentation of `transformers.TrainingArguments` + for more information. + data_collator (Optional[`transformers.DataCollator`]): + The data collator to use for training. + train_dataset (Optional[`datasets.Dataset`]): + The dataset to use for training. We recommend users to use `trl.trainer.ConstantLengthDataset` to create their dataset. + eval_dataset (Optional[Union[`datasets.Dataset`, Dict[`str`, `datasets.Dataset`]]]): + The dataset to use for evaluation. We recommend users to use `trl.trainer.ConstantLengthDataset` to create their dataset. + tokenizer (Optional[`transformers.PreTrainedTokenizer`]): + The tokenizer to use for training. If not specified, the tokenizer associated to the model will be used. + model_init (`Callable[[], transformers.PreTrainedModel]`): + The model initializer to use for training. If None is specified, the default model initializer will be used. + compute_metrics (`Callable[[transformers.EvalPrediction], Dict]`, *optional* defaults to None): + The function used to compute metrics during evaluation. It should return a dictionary mapping metric names to metric values. + If not specified, only the loss will be computed during evaluation. + callbacks (`List[transformers.TrainerCallback]`): + The callbacks to use for training. + optimizers (`Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`): + The optimizer and scheduler to use for training. + preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`): + The function to use to preprocess the logits before computing the metrics. + peft_config (`Optional[PeftConfig]`): + The PeftConfig object to use to initialize the PeftModel. + dataset_text_field (`Optional[str]`): + The name of the text field of the dataset, in case this is passed by a user, the trainer will automatically create a + `ConstantLengthDataset` based on the `dataset_text_field` argument. + formatting_func (`Optional[Callable]`): + The formatting function to be used for creating the `ConstantLengthDataset`. + max_seq_length (`Optional[int]`): + The maximum sequence length to use for the `ConstantLengthDataset` and for automaticallty creating the Dataset. Defaults to `512`. + infinite (`Optional[bool]`): + Whether to use an infinite dataset or not. Defaults to `False`. + num_of_sequences (`Optional[int]`): + The number of sequences to use for the `ConstantLengthDataset`. Defaults to `1024`. + chars_per_token (`Optional[float]`): + The number of characters per token to use for the `ConstantLengthDataset`. Defaults to `3.6`. You can check how this is computed in the + stack-llama example: https://github.com/huggingface/trl/blob/08f550674c553c36c51d1027613c29f14f3676a5/examples/stack_llama/scripts/supervised_finetuning.py#L53. + packing (`Optional[bool]`): + Used only in case `dataset_text_field` is passed. This argument is used by the `ConstantLengthDataset` to pack the sequences + of the dataset. + dataset_num_proc (`Optional[int]`): + The number of workers to use to tokenize the data. Only used when `packing=False`. Defaults to None. + dataset_batch_size (`int`): + The number of examples to tokenize per batch. If batch_size <= 0 or batch_size == None, + tokenize the full dataset as a single batch. Defaults to 1000. + neftune_noise_alpha (`Optional[float]`): + If not `None`, this will activate NEFTune noise embeddings. This has been proven to drastically improve model performances for instrcution + fine-tuning. Check out the original paper here: https://arxiv.org/abs/2310.05914 and the original code here: https://github.com/neelsjain/NEFTune + model_init_kwargs: (`Optional[Dict]`, *optional*): + Dict of Optional kwargs to pass when instantiating the model from a string + """ + + def __init__( + self, + model: Union[PreTrainedModel, nn.Module, str] = None, + args: TrainingArguments = None, + data_collator: Optional[DataCollator] = None, + train_dataset: Optional[Dataset] = None, + eval_dataset: Optional[Union[Dataset, Dict[str, Dataset]]] = None, + tokenizer: Optional[PreTrainedTokenizerBase] = None, + model_init: Optional[Callable[[], PreTrainedModel]] = None, + compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None, + callbacks: Optional[List[TrainerCallback]] = None, + optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None), + preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None, + peft_config: Optional["PeftConfig"] = None, + dataset_text_field: Optional[str] = None, + packing: Optional[bool] = False, + formatting_func: Optional[Callable] = None, + max_seq_length: Optional[int] = None, + infinite: Optional[bool] = None, + num_of_sequences: Optional[int] = 1024, + chars_per_token: Optional[float] = 3.6, + dataset_num_proc: Optional[int] = None, + dataset_batch_size: int = 1000, + neftune_noise_alpha: Optional[float] = None, + model_init_kwargs: Optional[Dict] = None, + ): + if model_init_kwargs is None: + model_init_kwargs = {} + elif not isinstance(model, str): + raise ValueError("You passed model_kwargs to the SFTTrainer. But your model is already instantiated.") + + if infinite is not None: + warnings.warn( + "The `infinite` argument is deprecated and will be removed in a future version of TRL. Use `TrainingArguments.max_steps` or `TrainingArguments.num_train_epochs` instead to control training length." + ) + + if isinstance(model, str): + warnings.warn( + "You passed a model_id to the SFTTrainer. This will automatically create an " + "`AutoModelForCausalLM` or a `PeftModel` (if you passed a `peft_config`) for you." + ) + model = AutoModelForCausalLM.from_pretrained(model, **model_init_kwargs) + + if packing and data_collator is not None and isinstance(data_collator, DataCollatorForCompletionOnlyLM): + raise ValueError( + "You passed a `DataCollatorForCompletionOnlyLM` to the SFTTrainer. This is not compatible with the `packing` argument." + ) + + if is_peft_available() and peft_config is not None: + if not isinstance(peft_config, PeftConfig): + raise ValueError( + "If you want to use the PeftModel, you need to pass a PeftConfig object to the SFTTrainer." + f" and you passed a {type(peft_config)}." + ) + + if not isinstance(model, PeftModel): + if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False): + _support_gc_kwargs = hasattr( + args, "gradient_checkpointing_kwargs" + ) and "gradient_checkpointing_kwargs" in list( + inspect.signature(prepare_model_for_kbit_training).parameters + ) + + preprare_model_kwargs = { + "use_gradient_checkpointing": getattr(args, "gradient_checkpointing", False) + } + + if _support_gc_kwargs: + preprare_model_kwargs["gradient_checkpointing_kwargs"] = getattr( + args, "gradient_checkpointing_kwargs", None + ) + + model = prepare_model_for_kbit_training(model, **preprare_model_kwargs) + + if args is not None: + args = dataclasses.replace(args, gradient_checkpointing=False) + elif getattr(args, "gradient_checkpointing", False): + # For backward compatibility with older versions of transformers + if hasattr(model, "enable_input_require_grads"): + model.enable_input_require_grads() + else: + + def make_inputs_require_grad(module, input, output): + output.requires_grad_(True) + + model.get_input_embeddings().register_forward_hook(make_inputs_require_grad) + + model = get_peft_model(model, peft_config) + + if callbacks is None: + callbacks = [PeftSavingCallback] + + if tokenizer is None: + tokenizer = AutoTokenizer.from_pretrained(model.config._name_or_path) + if getattr(tokenizer, "pad_token", None) is None: + tokenizer.pad_token = tokenizer.eos_token + + if max_seq_length is None: + # to overcome some issues with broken tokenizers + max_seq_length = min(tokenizer.model_max_length, 1024) + + warnings.warn( + f"You didn't pass a `max_seq_length` argument to the SFTTrainer, this will default to {max_seq_length}" + ) + + self.dataset_num_proc = dataset_num_proc + self.dataset_batch_size = dataset_batch_size + + self._trainer_supports_neftune = hasattr(args, "neftune_noise_alpha") + + if neftune_noise_alpha is not None and self._trainer_supports_neftune: + args.neftune_noise_alpha = neftune_noise_alpha + warnings.warn( + "You passed a `neftune_noise_alpha` argument to the SFTTrainer, the value you passed will override the one in the `TrainingArguments`." + ) + # self.neftune_noise_alpha is done at Trainer level + elif not self._trainer_supports_neftune: + self.neftune_noise_alpha = neftune_noise_alpha + + if not packing: + if dataset_text_field is None and formatting_func is None: + raise ValueError( + "You passed `packing=False` to the SFTTrainer, but you didn't pass a `dataset_text_field` or `formatting_func` argument." + ) + + if data_collator is None: + data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False) + + if train_dataset is not None: + train_dataset = self._prepare_dataset( + train_dataset, + tokenizer, + packing, + dataset_text_field, + max_seq_length, + formatting_func, + num_of_sequences, + chars_per_token, + ) + if eval_dataset is not None: + _multiple = isinstance(eval_dataset, dict) + _eval_datasets = eval_dataset if _multiple else {"singleton": eval_dataset} + for _eval_dataset_name, _eval_dataset in _eval_datasets.items(): + _eval_datasets[_eval_dataset_name] = self._prepare_dataset( + _eval_dataset, + tokenizer, + packing, + dataset_text_field, + max_seq_length, + formatting_func, + num_of_sequences, + chars_per_token, + ) + if not _multiple: + eval_dataset = _eval_datasets["singleton"] + + if tokenizer.padding_side is not None and tokenizer.padding_side != "right": + warnings.warn( + "You passed a tokenizer with `padding_side` not equal to `right` to the SFTTrainer. This might lead to some unexpected behaviour due to " + "overflow issues when training a model in half-precision. You might consider adding `tokenizer.padding_side = 'right'` to your code." + ) + + super().__init__( + model=model, + args=args, + data_collator=data_collator, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + tokenizer=tokenizer, + model_init=model_init, + compute_metrics=compute_metrics, + callbacks=callbacks, + optimizers=optimizers, + preprocess_logits_for_metrics=preprocess_logits_for_metrics, + ) + + if self.args.max_steps > 0 and packing: + warnings.warn( + "You passed `packing=True` to the SFTTrainer, and you are training your model with `max_steps` strategy. The dataset will be iterated until the `max_steps` are reached." + ) + self.train_dataset.infinite = True + elif self.args.max_steps == -1 and packing: + self.train_dataset.infinite = False + + @wraps(Trainer.train) + def train(self, *args, **kwargs): + # Activate neftune right before training. + if self.neftune_noise_alpha is not None and not self._trainer_supports_neftune: + self.model = self._trl_activate_neftune(self.model) + + output = super().train(*args, **kwargs) + + # After training we make sure to retrieve back the original forward pass method + # for the embedding layer by removing the forward post hook. + if self.neftune_noise_alpha is not None and not self._trainer_supports_neftune: + unwrapped_model = unwrap_model(self.model) + if is_peft_available() and isinstance(unwrapped_model, PeftModel): + embeddings = unwrapped_model.base_model.model.get_input_embeddings() + else: + embeddings = unwrapped_model.get_input_embeddings() + + self.neftune_hook_handle.remove() + del embeddings.neftune_noise_alpha + + return output + + def _prepare_dataset( + self, + dataset, + tokenizer, + packing, + dataset_text_field, + max_seq_length, + formatting_func, + num_of_sequences, + chars_per_token, + ): + if dataset is None: + raise ValueError("The dataset should not be None") + + # check if torch dataset / dataloader and do nothing + if isinstance(dataset, (torch.utils.data.IterableDataset, torch.utils.data.Dataset, ConstantLengthDataset)): + return dataset + + if not packing: + return self._prepare_non_packed_dataloader( + tokenizer, dataset, dataset_text_field, max_seq_length, formatting_func + ) + + else: + return self._prepare_packed_dataloader( + tokenizer, + dataset, + dataset_text_field, + max_seq_length, + num_of_sequences, + chars_per_token, + formatting_func, + ) + + def _prepare_non_packed_dataloader( + self, tokenizer, dataset, dataset_text_field, max_seq_length, formatting_func=None + ): + use_formatting_func = formatting_func is not None and dataset_text_field is None + self._dataset_sanity_checked = False + + # Inspired from: https://huggingface.co/learn/nlp-course/chapter7/6?fw=pt + def tokenize(element): + outputs = tokenizer( + element[dataset_text_field] if not use_formatting_func else formatting_func(element), + truncation=True, + padding=False, + max_length=max_seq_length, + return_overflowing_tokens=False, + return_length=False, + ) + + if use_formatting_func and not self._dataset_sanity_checked: + if not isinstance(formatting_func(element), list): + raise ValueError( + "The `formatting_func` should return a list of processed strings since it can lead to silent bugs." + ) + else: + self._dataset_sanity_checked = True + + return {"input_ids": outputs["input_ids"], "attention_mask": outputs["attention_mask"]} + + tokenized_dataset = dataset.map( + tokenize, + batched=True, + remove_columns=dataset.column_names, + num_proc=self.dataset_num_proc, + batch_size=self.dataset_batch_size, + ) + + return tokenized_dataset + + def _prepare_packed_dataloader( + self, + tokenizer, + dataset, + dataset_text_field, + max_seq_length, + num_of_sequences, + chars_per_token, + formatting_func=None, + ): + if dataset_text_field is not None or formatting_func is not None: + if tokenizer is None: + raise ValueError("You need to pass a tokenizer when using `dataset_text_field` with `SFTTrainer`.") + + constant_length_iterator = ConstantLengthDataset( + tokenizer, + dataset, + dataset_text_field=dataset_text_field, + formatting_func=formatting_func, + seq_length=max_seq_length, + infinite=False, + num_of_sequences=num_of_sequences, + chars_per_token=chars_per_token, + eos_token_id=tokenizer.eos_token_id, + ) + + def data_generator(constant_length_iterator): + for i in constant_length_iterator: + yield i + + try: + packed_dataset = Dataset.from_generator( + data_generator, gen_kwargs={"constant_length_iterator": constant_length_iterator} + ) + except (DatasetGenerationError, SchemaInferenceError): + raise ValueError( + "Error occurred while packing the dataset. Make sure that your dataset has enough samples to at least yield one packed sequence." + ) + return packed_dataset + else: + raise ValueError( + "You need to pass a `dataset_text_field` or `formatting_func` argument to the SFTTrainer if you want to use the `ConstantLengthDataset`." + ) + + def _trl_activate_neftune(self, model): + r""" + Activates the neftune as presented in this code: https://github.com/neelsjain/NEFTune and paper: https://arxiv.org/abs/2310.05914 + Since in transformers Trainer we do have an `_activate_neftune` method, we need to rename this method to avoid conflicts. + """ + unwrapped_model = unwrap_model(model) + if is_peft_available() and isinstance(unwrapped_model, PeftModel): + embeddings = unwrapped_model.base_model.model.get_input_embeddings() + else: + embeddings = unwrapped_model.get_input_embeddings() + + embeddings.neftune_noise_alpha = self.neftune_noise_alpha + hook_handle = embeddings.register_forward_hook(neftune_post_forward_hook) + self.neftune_hook_handle = hook_handle + return model diff --git a/trl/trl/trl/trainer/training_configs.py b/trl/trl/trl/trainer/training_configs.py new file mode 100644 index 0000000000000000000000000000000000000000..8341819c3486664a8227290ed5ba1574b1932d9a --- /dev/null +++ b/trl/trl/trl/trainer/training_configs.py @@ -0,0 +1,43 @@ +# coding=utf-8 +# coding=utf-8 +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from dataclasses import dataclass +from typing import Optional + +from transformers import TrainingArguments + + +@dataclass +class RewardConfig(TrainingArguments): + """ + RewardConfig collects all training arguments related to the [`RewardTrainer`] class. + + Using [`HfArgumentParser`] we can turn this class into + [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the + command line. + + Parameters: + max_length (`int`, *optional*, defaults to `None`): + The maximum length of the sequences in the batch. This argument is required if you want to use the default data collator. + gradient_checkpointing (`bool`, *optional*, defaults to `True`): + If True, use gradient checkpointing to save memory at the expense of slower backward pass. + """ + + max_length: Optional[int] = None + """The maximum length of the sequences in the batch. This argument is required if you want to use the default data collator.""" + gradient_checkpointing: Optional[bool] = True + """If True, use gradient checkpointing to save memory at the expense of slower backward pass.""" + gradient_checkpointing_kwargs: Optional[dict] = None + """Keyword arguments to pass to the gradient checkpointing function.""" diff --git a/trl/trl/trl/trainer/utils.py b/trl/trl/trl/trainer/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..3c29bedfb269a21fcf049769dd0b34102d21fe26 --- /dev/null +++ b/trl/trl/trl/trainer/utils.py @@ -0,0 +1,779 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +import random +import warnings +from collections import deque +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Tuple, Union + +import numpy as np +import torch +from torch.nn.utils.rnn import pad_sequence +from torch.utils.data import IterableDataset +from transformers import DataCollatorForLanguageModeling, PreTrainedModel, PreTrainedTokenizerBase, TrainerCallback + + +class AdaptiveKLController: + """ + Adaptive KL controller described in the paper: + https://arxiv.org/pdf/1909.08593.pdf + """ + + def __init__(self, init_kl_coef, target, horizon): + self.value = init_kl_coef + self.target = target + self.horizon = horizon + + def update(self, current, n_steps): + target = self.target + proportional_error = np.clip(current / target - 1, -0.2, 0.2) + mult = 1 + proportional_error * n_steps / self.horizon + self.value *= mult + + +class FixedKLController: + """Fixed KL controller.""" + + def __init__(self, kl_coef): + self.value = kl_coef + + def update(self, current, n_steps): + pass + + +class DataCollatorForCompletionOnlyLM(DataCollatorForLanguageModeling): + """ + Data collator used for completion tasks. It ensures that all the tokens of the labels are set to an 'ignore_index' + when they do not come from the assistant. This ensure that the loss is only + calculated on the completion made by the assistant. + + Args: + instruction_template (`Optional[str]`): the template form that indicates the start of the human instruction, typically something like + '### Human:\n'. Useful for assistant-style conversation datasets + response_template (`Union[str, List[int]]`): the template form that indicates the start of the response, typically something like + '### Response:\n'. It can also be passed as tokenized ids, which can be useful when using a tokenizer that encodes the response + differently if it does not have proper context. + mlm (`bool`, *optional*, defaults to `False`): Whether or not to use masked language modeling in the underlying + `DataCollatorForLanguageModeling` class. Note that this option currently has no effect but is present + for flexibility and backwards-compatibility. + ignore_index (`int`, *optional*, defaults to `-100`): + The index to use to ignore the initial tokens with + """ + + def __init__( + self, + response_template: Union[str, List[int]], + instruction_template: Union[str, List[int]] = None, + *args, + mlm: bool = False, + ignore_index: int = -100, + **kwargs, + ): + super().__init__(*args, mlm=mlm, **kwargs) + + self.instruction_template = instruction_template + if isinstance(instruction_template, str): + # The user provides a string, must tokenize + self.instruction_token_ids = self.tokenizer.encode(self.instruction_template, add_special_tokens=False) + else: + # The user already provides the token ids + self.instruction_token_ids = instruction_template + + self.response_template = response_template + if isinstance(response_template, str): + # The user provides a string, must tokenize + self.response_token_ids = self.tokenizer.encode(self.response_template, add_special_tokens=False) + else: + # The user already provides the token ids + self.response_token_ids = response_template + + if not self.mlm and self.instruction_template and self.tokenizer.pad_token_id == self.tokenizer.eos_token_id: + warnings.warn( + "The pad_token_id and eos_token_id values of this tokenizer are identical. " + "If you are planning for multi-turn training, " + "it can result in the model continuously generating questions and answers without eos token. " + "To avoid this, set the pad_token_id to a different value." + ) + + self.ignore_index = ignore_index + + def torch_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]: + batch = super().torch_call(examples) + + if self.instruction_template is None: + for i in range(len(examples)): + response_token_ids_start_idx = None + + for idx in np.where(batch["labels"][i] == self.response_token_ids[0])[0]: + # `response_token_ids` is `'### Response:\n'`, here we are just making sure that the token IDs match + if ( + self.response_token_ids + == batch["labels"][i][idx : idx + len(self.response_token_ids)].tolist() + ): + response_token_ids_start_idx = idx + + if response_token_ids_start_idx is None: + warnings.warn( + f"Could not find response key `{self.response_template}` in the " + f'following instance: {self.tokenizer.decode(batch["input_ids"][i])} ' + f"This instance will be ignored in loss calculation. " + f"Note, if this happens often, consider increasing the `max_seq_length`." + ) + batch["labels"][i, :] = self.ignore_index + else: + response_token_ids_end_idx = response_token_ids_start_idx + len(self.response_token_ids) + + # Make pytorch loss function ignore all tokens up through the end of the response key + batch["labels"][i, :response_token_ids_end_idx] = self.ignore_index + + else: + for i in range(len(examples)): + response_token_ids_idxs = [] + human_token_ids_idxs = [] + + for assistant_idx in np.where(batch["labels"][i] == self.response_token_ids[0])[0]: + # find the indexes of the start of a response. + if ( + self.response_token_ids + == batch["labels"][i][assistant_idx : assistant_idx + len(self.response_token_ids)].tolist() + ): + response_token_ids_idxs.append(assistant_idx + len(self.response_token_ids)) + + if len(response_token_ids_idxs) == 0: + warnings.warn( + f"Could not find response key `{self.response_template}` in the " + f'following instance: {self.tokenizer.decode(batch["input_ids"][i])} ' + f"This instance will be ignored in loss calculation. " + f"Note, if this happens often, consider increasing the `max_seq_length`." + ) + batch["labels"][i, :] = self.ignore_index + + human_token_ids = self.instruction_token_ids + for human_idx in np.where(batch["labels"][i] == human_token_ids[0])[0]: + # find the indexes of the start of a human answer. + if human_token_ids == batch["labels"][i][human_idx : human_idx + len(human_token_ids)].tolist(): + human_token_ids_idxs.append(human_idx) + + if len(human_token_ids_idxs) == 0: + warnings.warn( + f"Could not find instruction key `{self.instruction_template}` in the " + f'following instance: {self.tokenizer.decode(batch["input_ids"][i])} ' + f"This instance will be ignored in loss calculation. " + f"Note, if this happens often, consider increasing the `max_seq_length`." + ) + batch["labels"][i, :] = self.ignore_index + + for idx, (start, end) in enumerate(zip(human_token_ids_idxs, response_token_ids_idxs)): + # Make pytorch loss function ignore all non response tokens + if idx != 0: + batch["labels"][i, start:end] = self.ignore_index + else: + batch["labels"][i, :end] = self.ignore_index + + if len(response_token_ids_idxs) < len(human_token_ids_idxs): + batch["labels"][i, human_token_ids_idxs[-1] :] = self.ignore_index + + return batch + + +@dataclass +class RewardDataCollatorWithPadding: + r""" + Reward DataCollator class that pads the inputs to the maximum length of the batch. + Args: + tokenizer (`PreTrainedTokenizerBase`): + The tokenizer used for encoding the data. + padding (`Union[bool, str, `PaddingStrategy`]`, `optional`, defaults to `True`): + padding_strategy to pass to the tokenizer. + max_length (`Optional[int]`, `optional`, defaults to `None`): + The maximum length of the sequence to be processed. + pad_to_multiple_of (`Optional[int]`, `optional`, defaults to `None`): + If set will pad the sequence to a multiple of the provided value. + return_tensors (`str`, `optional`, defaults to `"pt"`): + The tensor type to use. + """ + tokenizer: PreTrainedTokenizerBase + padding: Union[bool, str] = True + max_length: Optional[int] = None + pad_to_multiple_of: Optional[int] = None + return_tensors: str = "pt" + + def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]: + features_chosen = [] + features_rejected = [] + margin = [] + # check if we have a margin. If we do, we need to batch it as well + has_margin = "margin" in features[0] + for feature in features: + # check if the keys are named as expected + if ( + "input_ids_chosen" not in feature + or "input_ids_rejected" not in feature + or "attention_mask_chosen" not in feature + or "attention_mask_rejected" not in feature + ): + raise ValueError( + "The features should include `input_ids_chosen`, `attention_mask_chosen`, `input_ids_rejected` and `attention_mask_rejected`" + ) + + features_chosen.append( + { + "input_ids": feature["input_ids_chosen"], + "attention_mask": feature["attention_mask_chosen"], + } + ) + features_rejected.append( + { + "input_ids": feature["input_ids_rejected"], + "attention_mask": feature["attention_mask_rejected"], + } + ) + if has_margin: + margin.append(feature["margin"]) + batch_chosen = self.tokenizer.pad( + features_chosen, + padding=self.padding, + max_length=self.max_length, + pad_to_multiple_of=self.pad_to_multiple_of, + return_tensors=self.return_tensors, + ) + batch_rejected = self.tokenizer.pad( + features_rejected, + padding=self.padding, + max_length=self.max_length, + pad_to_multiple_of=self.pad_to_multiple_of, + return_tensors=self.return_tensors, + ) + batch = { + "input_ids_chosen": batch_chosen["input_ids"], + "attention_mask_chosen": batch_chosen["attention_mask"], + "input_ids_rejected": batch_rejected["input_ids"], + "attention_mask_rejected": batch_rejected["attention_mask"], + "return_loss": True, + } + if has_margin: + margin = torch.tensor(margin, dtype=torch.float) + batch["margin"] = margin + return batch + + +@dataclass +class DPODataCollatorWithPadding: + r""" + DPO DataCollator class that pads the inputs to the maximum length of the batch. + Args: + tokenizer (`PreTrainedTokenizerBase`): + The tokenizer used for encoding the data. + model (Optional[`PreTrainedModel`]): + The model that is being trained. If set and has the *prepare_decoder_input_ids_from_labels*, use it to + prepare the *decoder_input_ids*. + padding (`Union[bool, str, `PaddingStrategy`]`, `optional`, defaults to `True`): + padding_strategy to pass to the tokenizer. + max_length (`Optional[int]`, `optional`, defaults to `None`): + The maximum length of the sequence to be processed. + max_prompt_length (`Optional[int]`, `optional`, defaults to `None`): + The maximum length of the prompt to be processed. + label_pad_token_id (`int`, defaults to -100): + The label used for masking. + padding_value (`int`, defaults to 0): + The value used for padding. + is_encoder_decoder (`Optional[bool]`, `optional`, defaults to `None`): + Whether or not you model has an encoder_decoder architecture. + max_target_length (`Optional[int]`, `optional`, defaults to `None`): + The maximum length of the target to be processed. Only useful for encoder-decoder architectures. + truncation_mode: (`str`, defaults to "keep_end"): + The truncation mode to use when truncating the prompt. + """ + tokenizer: PreTrainedTokenizerBase + model: Optional[PreTrainedModel] = None + padding: Union[bool, str] = True + max_length: Optional[int] = None + max_prompt_length: Optional[int] = None + label_pad_token_id: int = -100 + padding_value: int = 0 + truncation_mode: str = "keep_end" + is_encoder_decoder: Optional[bool] = False + max_target_length: Optional[int] = None + + def tokenize_batch_element( + self, + prompt: str, + chosen: str, + rejected: str, + ) -> Dict: + """Tokenize a single batch element. + + At this stage, we don't convert to PyTorch tensors yet; we just handle the truncation + in case the prompt + chosen or prompt + rejected responses is/are too long. First + we truncate the prompt; if we're still too long, we truncate the chosen/rejected. + + We also create the labels for the chosen/rejected responses, which are of length equal to + the sum of the length of the prompt and the chosen/rejected response, with + label_pad_token_id for the prompt tokens. + """ + batch = {} + + if not self.is_encoder_decoder: + chosen_tokens = self.tokenizer(chosen, add_special_tokens=False) + rejected_tokens = self.tokenizer(rejected, add_special_tokens=False) + prompt_tokens = self.tokenizer(prompt, add_special_tokens=False) + + eos_token_id = self.tokenizer.eos_token_id + # Get indices in list prompt_tokens["input_ids"] that equals the EOS token (often 0) + eos_indices_prompt = [i for i, x in enumerate(prompt_tokens["input_ids"]) if x == eos_token_id] + # attention mask these indices to eos_token_id + new_attention_mask = [ + 0 if i in eos_indices_prompt else p for i, p in enumerate(prompt_tokens["attention_mask"]) + ] + prompt_tokens["attention_mask"] = new_attention_mask + + # do the same for chosen and rejected + eos_indices_chosen = [i for i, x in enumerate(chosen_tokens["input_ids"]) if x == eos_token_id] + new_attention_mask_c = [ + 0 if i in eos_indices_chosen else p for i, p in enumerate(chosen_tokens["attention_mask"]) + ] + chosen_tokens["attention_mask"] = new_attention_mask_c + + eos_indices_rejected = [i for i, x in enumerate(rejected_tokens["input_ids"]) if x == eos_token_id] + new_attention_mask_r = [ + 0 if i in eos_indices_rejected else p for i, p in enumerate(rejected_tokens["attention_mask"]) + ] + rejected_tokens["attention_mask"] = new_attention_mask_r + + # add EOS token to end of prompt + chosen_tokens["input_ids"].append(self.tokenizer.eos_token_id) + chosen_tokens["attention_mask"].append(1) + + rejected_tokens["input_ids"].append(self.tokenizer.eos_token_id) + rejected_tokens["attention_mask"].append(1) + + longer_response_length = max(len(chosen_tokens["input_ids"]), len(rejected_tokens["input_ids"])) + + # if combined sequence is too long, truncate the prompt + if len(prompt_tokens["input_ids"]) + longer_response_length > self.max_length: + if self.truncation_mode == "keep_start": + prompt_tokens = {k: v[: self.max_prompt_length] for k, v in prompt_tokens.items()} + elif self.truncation_mode == "keep_end": + prompt_tokens = {k: v[-self.max_prompt_length :] for k, v in prompt_tokens.items()} + else: + raise ValueError(f"Unknown truncation mode: {self.truncation_mode}") + + # if that's still too long, truncate the response + if len(prompt_tokens["input_ids"]) + longer_response_length > self.max_length: + chosen_tokens = {k: v[: self.max_length - self.max_prompt_length] for k, v in chosen_tokens.items()} + rejected_tokens = { + k: v[: self.max_length - self.max_prompt_length] for k, v in rejected_tokens.items() + } + + # Create labels + chosen_sequence_tokens = {k: prompt_tokens[k] + chosen_tokens[k] for k in chosen_tokens} + rejected_sequence_tokens = {k: prompt_tokens[k] + rejected_tokens[k] for k in rejected_tokens} + chosen_sequence_tokens["labels"] = chosen_sequence_tokens["input_ids"][:] + chosen_sequence_tokens["labels"][: len(prompt_tokens["input_ids"])] = [self.label_pad_token_id] * len( + prompt_tokens["input_ids"] + ) + rejected_sequence_tokens["labels"] = rejected_sequence_tokens["input_ids"][:] + rejected_sequence_tokens["labels"][: len(prompt_tokens["input_ids"])] = [self.label_pad_token_id] * len( + prompt_tokens["input_ids"] + ) + + for k, toks in { + "chosen": chosen_sequence_tokens, + "rejected": rejected_sequence_tokens, + "prompt": prompt_tokens, + }.items(): + for type_key, tokens in toks.items(): + if type_key == "token_type_ids": + continue + batch[f"{k}_{type_key}"] = tokens + + else: + chosen_tokens = self.tokenizer( + chosen, truncation=True, max_length=self.max_target_length, add_special_tokens=True + ) + rejected_tokens = self.tokenizer( + rejected, truncation=True, max_length=self.max_target_length, add_special_tokens=True + ) + prompt_tokens = self.tokenizer( + prompt, truncation=True, max_length=self.max_prompt_length, add_special_tokens=True + ) + + batch["chosen_labels"] = chosen_tokens["input_ids"] + batch["rejected_labels"] = rejected_tokens["input_ids"] + batch["prompt_input_ids"] = prompt_tokens["input_ids"] + batch["prompt_attention_mask"] = prompt_tokens["attention_mask"] + + if self.model is not None and hasattr(self.model, "prepare_decoder_input_ids_from_labels"): + batch["rejected_decoder_input_ids"] = self.model.prepare_decoder_input_ids_from_labels( + labels=batch["rejected_labels"] + ) + batch["chosen_decoder_input_ids"] = self.model.prepare_decoder_input_ids_from_labels( + labels=batch["chosen_labels"] + ) + + batch["prompt"] = prompt + batch["chosen"] = prompt + chosen + batch["rejected"] = prompt + rejected + batch["chosen_response_only"] = chosen + batch["rejected_response_only"] = rejected + + return batch + + def collate(self, batch): + # first, pad everything to the same length + padded_batch = {} + for k in batch[0].keys(): + if k.endswith("_input_ids") or k.endswith("_attention_mask") or k.endswith("_labels"): + if self.is_encoder_decoder: + to_pad = [torch.LongTensor(ex[k]) for ex in batch] + + if (k.startswith("prompt")) and (k.endswith("input_ids")): + padding_value = self.tokenizer.pad_token_id + elif k.endswith("_attention_mask"): + padding_value = 0 + elif (k.startswith("chosen")) or (k.startswith("rejected")) or ("decoder" in k): + padding_value = self.label_pad_token_id + else: + raise ValueError(f"Unexpected key in batch '{k}'") + padded_batch[k] = pad_sequence(to_pad, batch_first=True, padding_value=padding_value) + else: + # adapted from https://stackoverflow.com/questions/73256206 + if "prompt" in k: + to_pad = [torch.LongTensor(ex[k][::-1]) for ex in batch] + else: + to_pad = [torch.LongTensor(ex[k]) for ex in batch] + if k.endswith("_input_ids"): + padding_value = self.tokenizer.pad_token_id + elif k.endswith("_labels"): + padding_value = self.label_pad_token_id + elif k.endswith("_attention_mask"): + padding_value = self.padding_value + else: + raise ValueError(f"Unexpected key in batch '{k}'") + + padded_batch[k] = pad_sequence(to_pad, batch_first=True, padding_value=padding_value) + # for the prompt, flip back so padding is on left side + if "prompt" in k: + padded_batch[k] = padded_batch[k].flip(dims=[1]) + else: + padded_batch[k] = [ex[k] for ex in batch] + + return padded_batch + + def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]: + tokenized_batch = [] + + for feature in features: + prompt = feature["prompt"] + chosen = feature["chosen"] + rejected = feature["rejected"] + + batch_element = self.tokenize_batch_element(prompt, chosen, rejected) + tokenized_batch.append(batch_element) + + # return collated batch + return self.collate(tokenized_batch) + + +class ConstantLengthDataset(IterableDataset): + """ + Iterable dataset that returns constant length chunks of tokens from stream of text files. + The dataset also formats the text before tokenization with a specific format that is provided + by the user. + + Args: + tokenizer (`transformers.PreTrainedTokenizer`): + The processor used for processing the data. + dataset (`dataset.Dataset`): + Dataset with text files. + dataset_text_field (`str`, **optional**): + Name of the field in the dataset that contains the text. Used only if `formatting_func` is `None`. + formatting_func (`Callable`, **optional**): + Function that formats the text before tokenization. Usually it is recommended to have follows a certain + pattern such as `"### Question: {question}\n ### Answer: {answer}\n"` + infinite (`bool`, *optional*, defaults to `False`): + If True the iterator is reset after dataset reaches end else stops. + seq_length (`int`, *optional*, defaults to `1024`): + Length of token sequences to return. + num_of_sequences (`int`, *optional*, defaults to `1024`): + Number of token sequences to keep in buffer. + chars_per_token (`int`, *optional*, defaults to `3.6`): + Number of characters per token used to estimate number of tokens in text buffer. + eos_token_id (`int`, *optional*, defaults to `0`): + Id of the end of sequence token if the passed tokenizer does not have an EOS token. + shuffle ('bool', *optional*, defaults to True) + Shuffle the examples before they are returned + """ + + def __init__( + self, + tokenizer, + dataset, + dataset_text_field=None, + formatting_func=None, + infinite=False, + seq_length=1024, + num_of_sequences=1024, + chars_per_token=3.6, + eos_token_id=0, + shuffle=True, + ): + self.tokenizer = tokenizer + + if tokenizer.eos_token_id is None: + warnings.warn( + "The passed tokenizer does not have an EOS token. We will use the passed eos_token_id instead which corresponds" + f" to {eos_token_id}. If this is not the correct EOS token, make sure to pass the correct eos_token_id." + ) + + self.concat_token_id = tokenizer.eos_token_id if tokenizer.eos_token_id else eos_token_id + self.dataset = dataset + self.seq_length = seq_length + self.infinite = infinite + self.current_size = 0 + self.max_buffer_size = seq_length * chars_per_token * num_of_sequences + self.shuffle = shuffle + if formatting_func is None: + self.formatting_func = lambda x: x[dataset_text_field] + else: + self.formatting_func = formatting_func + + if formatting_func is not None: + if formatting_func.__code__.co_argcount > 1: + warnings.warn( + "The passed formatting_func has more than one argument. Usually that function should have a single argument `example`" + " which corresponds to the dictionary returned by each element of the dataset. Make sure you know what you are doing." + ) + + def __len__(self): + return len(self.dataset) + + def __iter__(self): + iterator = iter(self.dataset) + more_examples = True + while more_examples: + buffer, buffer_len = [], 0 + while True: + if buffer_len >= self.max_buffer_size: + break + try: + buffer.append(self.formatting_func(next(iterator))) + buffer_len += len(buffer[-1]) + except StopIteration: + if self.infinite: + iterator = iter(self.dataset) + warnings.warn("The dataset reached end and the iterator is reset to the start.") + else: + more_examples = False + break + tokenized_inputs = self.tokenizer(buffer, truncation=False)["input_ids"] + all_token_ids = [] + for tokenized_input in tokenized_inputs: + all_token_ids.extend(tokenized_input + [self.concat_token_id]) + examples = [] + for i in range(0, len(all_token_ids), self.seq_length): + input_ids = all_token_ids[i : i + self.seq_length] + if len(input_ids) == self.seq_length: + examples.append(input_ids) + if self.shuffle: + random.shuffle(examples) + for example in examples: + self.current_size += 1 + yield { + "input_ids": torch.LongTensor(example), + "labels": torch.LongTensor(example), + } + + +class PeftSavingCallback(TrainerCallback): + def on_save(self, args, state, control, **kwargs): + if args.should_save: + checkpoint_path = os.path.join(args.output_dir, f"checkpoint-{state.global_step}") + kwargs["model"].save_pretrained(checkpoint_path) + + if "pytorch_model.bin" in os.listdir(checkpoint_path): + os.remove(os.path.join(checkpoint_path, "pytorch_model.bin")) + + +class RunningMoments: + def __init__(self, accelerator): + """ + Calculates the running mean and standard deviation of a data stream. Reference: + https://github.com/OpenLMLab/MOSS-RLHF/blob/40b91eb2f2b71b16919addede0341d2bef70825d/utils.py#L75 + """ + self.mean = 0 + self.std = 1 + self.var = 1 + self.count = 1e-24 + self.accelerator = accelerator + + @torch.no_grad() + def update(self, xs: torch.Tensor) -> Tuple[float, float]: + """ + Updates running moments from batch's moments computed across ranks + """ + if self.accelerator.use_distributed: + xs_mean, xs_var, xs_count = get_global_statistics(self.accelerator, xs) + else: + xs_count = xs.numel() + xs_var, xs_mean = torch.var_mean(xs, unbiased=False) + xs_mean, xs_var = xs_mean.float(), xs_var.float() + + delta = xs_mean - self.mean + tot_count = self.count + xs_count + + new_sum = xs_var * xs_count + # correct old_sum deviation accounting for the new mean + old_sum = self.var * self.count + delta**2 * self.count * xs_count / tot_count + tot_sum = old_sum + new_sum + + self.mean += delta * xs_count / tot_count + self.var = tot_sum / tot_count + self.std = (self.var * tot_count / (tot_count - 1)).float().sqrt() + self.count = tot_count + + return xs_mean.item(), (xs_var * xs_count / (xs_count - 1)).float().sqrt().item() + + +@torch.no_grad() +def get_global_statistics(accelerator, xs: torch.Tensor, mask=None, device="cpu") -> Tuple[float, float, int]: + """ + Computes element-wise mean and variance of the tensor across processes. Reference: + https://github.com/OpenLMLab/MOSS-RLHF/blob/40b91eb2f2b71b16919addede0341d2bef70825d/utils.py#L57C1-L73C75 + """ + xs = xs.to(accelerator.device) + sum_and_count = torch.tensor([xs.sum(), (xs.numel() if mask is None else mask.sum())], device=xs.device) + sum_and_count = accelerator.reduce(sum_and_count) + global_sum, count = sum_and_count + global_mean = global_sum / count + + sum_var = torch.sum(((xs - global_mean) ** 2).mul(1 if mask is None else mask)) + sum_var = accelerator.reduce(sum_var) + global_var = sum_var / count + + return global_mean.to(device), global_var.to(device), count.to(device) + + +def compute_accuracy(eval_pred) -> Dict[str, float]: + predictions, labels = eval_pred + # Here, predictions is rewards_chosen and rewards_rejected. + # We want to see how much of the time rewards_chosen > rewards_rejected. + if np.array(predictions[:, 0] == predictions[:, 1], dtype=float).sum() > 0: + warnings.warn( + f"There are {np.array(predictions[:, 0] == predictions[:, 1]).sum()} out of {len(predictions[:, 0])} instances where the predictions for both options are equal. As a consequence the accuracy can be misleading." + ) + predictions = np.argmax(predictions, axis=1) + + accuracy = np.array(predictions == labels, dtype=float).mean().item() + return {"accuracy": accuracy} + + +def pad_to_length(tensor: torch.Tensor, length: int, pad_value: Union[int, float], dim: int = -1) -> torch.Tensor: + if tensor.size(dim) >= length: + return tensor + else: + pad_size = list(tensor.shape) + pad_size[dim] = length - tensor.size(dim) + return torch.cat( + [ + tensor, + pad_value * torch.ones(*pad_size, dtype=tensor.dtype, device=tensor.device), + ], + dim=dim, + ) + + +def disable_dropout_in_model(model: torch.nn.Module) -> None: + for module in model.modules(): + if isinstance(module, torch.nn.Dropout): + module.p = 0 + + +def exact_div(a, b, a_str, b_str, custom_error_message=""): + q = a // b + if a != q * b: + raise ValueError(f"{custom_error_message}, {a_str}={a}, {b_str}={b}, inexact division: {a} / {b} = {a / b}") + return q + + +# copied from https://github.com/kvablack/ddpo-pytorch/blob/main/ddpo_pytorch/stat_tracking.py#L5 +class PerPromptStatTracker: + r""" + Class for tracking statistics per prompt. Mainly used to calculate advantage for the DPPO algorithm + + Args: + buffer_size (`int`): + Size of the buffer to keep for each prompt. + min_count (`int`): + Minimum number of samples to keep in the buffer before calculating the mean and std. + """ + + def __init__(self, buffer_size, min_count): + self.buffer_size = buffer_size + self.min_count = min_count + self.stats = {} + + def update(self, prompts, rewards): + prompts = np.array(prompts) + rewards = np.array(rewards) + unique = np.unique(prompts) + advantages = np.empty_like(rewards) + for prompt in unique: + prompt_rewards = rewards[prompts == prompt] + if prompt not in self.stats: + self.stats[prompt] = deque(maxlen=self.buffer_size) + self.stats[prompt].extend(prompt_rewards) + + if len(self.stats[prompt]) < self.min_count: + mean = np.mean(rewards) + std = np.std(rewards) + 1e-6 + else: + mean = np.mean(self.stats[prompt]) + std = np.std(self.stats[prompt]) + 1e-6 + advantages[prompts == prompt] = (prompt_rewards - mean) / std + + return advantages + + def get_stats(self): + return {k: {"mean": np.mean(v), "std": np.std(v), "count": len(v)} for k, v in self.stats.items()} + + +def neftune_post_forward_hook(module, input, output): + """ + Implements the NEFTune forward pass for the model using forward hooks. Note this works only for + torch.nn.Embedding layers. This method is slightly adapted from the original source code + that can be found here: https://github.com/neelsjain/NEFTune + + Simply add it to your model as follows: + ```python + model = ... + model.embed_tokens.neftune_noise_alpha = 0.1 + model.embed_tokens.register_forward_hook(neftune_post_forward_hook) + ``` + + Args: + module (`torch.nn.Module`): + The embedding module where the hook is attached. Note that you need to set + `module.neftune_noise_alpha` to the desired noise alpha value. + input (`torch.Tensor`): + The input tensor to the model. + output (`torch.Tensor`): + The output tensor of the model (i.e. the embeddings). + """ + if module.training: + dims = torch.tensor(output.size(1) * output.size(2)) + mag_norm = module.neftune_noise_alpha / torch.sqrt(dims) + output = output + torch.zeros_like(output).uniform_(-mag_norm, mag_norm) + return output