Upload 13 files
Browse files- CODE_OF_CONDUCT.md +80 -0
- CONTRIBUTING.md +31 -0
- LICENSE +126 -0
- download.sh +58 -0
- example_chat_completion.py +73 -0
- example_text_completion.py +55 -0
- llama-2-7b/checklist.chk +2 -0
- llama-2-7b/consolidated.00.pth +3 -0
- llama-2-7b/params.json +1 -0
- llama/__init__.py +6 -0
- llama/generation.py +303 -0
- llama/model.py +288 -0
- llama/tokenizer.py +41 -0
CODE_OF_CONDUCT.md
ADDED
|
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Code of Conduct
|
| 2 |
+
|
| 3 |
+
## Our Pledge
|
| 4 |
+
|
| 5 |
+
In the interest of fostering an open and welcoming environment, we as
|
| 6 |
+
contributors and maintainers pledge to make participation in our project and
|
| 7 |
+
our community a harassment-free experience for everyone, regardless of age, body
|
| 8 |
+
size, disability, ethnicity, sex characteristics, gender identity and expression,
|
| 9 |
+
level of experience, education, socio-economic status, nationality, personal
|
| 10 |
+
appearance, race, religion, or sexual identity and orientation.
|
| 11 |
+
|
| 12 |
+
## Our Standards
|
| 13 |
+
|
| 14 |
+
Examples of behavior that contributes to creating a positive environment
|
| 15 |
+
include:
|
| 16 |
+
|
| 17 |
+
* Using welcoming and inclusive language
|
| 18 |
+
* Being respectful of differing viewpoints and experiences
|
| 19 |
+
* Gracefully accepting constructive criticism
|
| 20 |
+
* Focusing on what is best for the community
|
| 21 |
+
* Showing empathy towards other community members
|
| 22 |
+
|
| 23 |
+
Examples of unacceptable behavior by participants include:
|
| 24 |
+
|
| 25 |
+
* The use of sexualized language or imagery and unwelcome sexual attention or
|
| 26 |
+
advances
|
| 27 |
+
* Trolling, insulting/derogatory comments, and personal or political attacks
|
| 28 |
+
* Public or private harassment
|
| 29 |
+
* Publishing others' private information, such as a physical or electronic
|
| 30 |
+
address, without explicit permission
|
| 31 |
+
* Other conduct which could reasonably be considered inappropriate in a
|
| 32 |
+
professional setting
|
| 33 |
+
|
| 34 |
+
## Our Responsibilities
|
| 35 |
+
|
| 36 |
+
Project maintainers are responsible for clarifying the standards of acceptable
|
| 37 |
+
behavior and are expected to take appropriate and fair corrective action in
|
| 38 |
+
response to any instances of unacceptable behavior.
|
| 39 |
+
|
| 40 |
+
Project maintainers have the right and responsibility to remove, edit, or
|
| 41 |
+
reject comments, commits, code, wiki edits, issues, and other contributions
|
| 42 |
+
that are not aligned to this Code of Conduct, or to ban temporarily or
|
| 43 |
+
permanently any contributor for other behaviors that they deem inappropriate,
|
| 44 |
+
threatening, offensive, or harmful.
|
| 45 |
+
|
| 46 |
+
## Scope
|
| 47 |
+
|
| 48 |
+
This Code of Conduct applies within all project spaces, and it also applies when
|
| 49 |
+
an individual is representing the project or its community in public spaces.
|
| 50 |
+
Examples of representing a project or community include using an official
|
| 51 |
+
project e-mail address, posting via an official social media account, or acting
|
| 52 |
+
as an appointed representative at an online or offline event. Representation of
|
| 53 |
+
a project may be further defined and clarified by project maintainers.
|
| 54 |
+
|
| 55 |
+
This Code of Conduct also applies outside the project spaces when there is a
|
| 56 |
+
reasonable belief that an individual's behavior may have a negative impact on
|
| 57 |
+
the project or its community.
|
| 58 |
+
|
| 59 |
+
## Enforcement
|
| 60 |
+
|
| 61 |
+
Instances of abusive, harassing, or otherwise unacceptable behavior may be
|
| 62 |
+
reported by contacting the project team at <opensource-conduct@meta.com>. All
|
| 63 |
+
complaints will be reviewed and investigated and will result in a response that
|
| 64 |
+
is deemed necessary and appropriate to the circumstances. The project team is
|
| 65 |
+
obligated to maintain confidentiality with regard to the reporter of an incident.
|
| 66 |
+
Further details of specific enforcement policies may be posted separately.
|
| 67 |
+
|
| 68 |
+
Project maintainers who do not follow or enforce the Code of Conduct in good
|
| 69 |
+
faith may face temporary or permanent repercussions as determined by other
|
| 70 |
+
members of the project's leadership.
|
| 71 |
+
|
| 72 |
+
## Attribution
|
| 73 |
+
|
| 74 |
+
This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4,
|
| 75 |
+
available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html
|
| 76 |
+
|
| 77 |
+
[homepage]: https://www.contributor-covenant.org
|
| 78 |
+
|
| 79 |
+
For answers to common questions about this code of conduct, see
|
| 80 |
+
https://www.contributor-covenant.org/faq
|
CONTRIBUTING.md
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Contributing to Llama
|
| 2 |
+
We want to make contributing to this project as easy and transparent as
|
| 3 |
+
possible.
|
| 4 |
+
|
| 5 |
+
## Pull Requests
|
| 6 |
+
We actively welcome your pull requests.
|
| 7 |
+
|
| 8 |
+
1. Fork the repo and create your branch from `main`.
|
| 9 |
+
2. If you've added code that should be tested, add tests.
|
| 10 |
+
3. If you've changed APIs, update the documentation.
|
| 11 |
+
4. Ensure the test suite passes.
|
| 12 |
+
5. Make sure your code lints.
|
| 13 |
+
6. If you haven't already, complete the Contributor License Agreement ("CLA").
|
| 14 |
+
|
| 15 |
+
## Contributor License Agreement ("CLA")
|
| 16 |
+
In order to accept your pull request, we need you to submit a CLA. You only need
|
| 17 |
+
to do this once to work on any of Meta's open source projects.
|
| 18 |
+
|
| 19 |
+
Complete your CLA here: <https://code.facebook.com/cla>
|
| 20 |
+
|
| 21 |
+
## Issues
|
| 22 |
+
We use GitHub issues to track public bugs. Please ensure your description is
|
| 23 |
+
clear and has sufficient instructions to be able to reproduce the issue.
|
| 24 |
+
|
| 25 |
+
Meta has a [bounty program](https://www.facebook.com/whitehat/) for the safe
|
| 26 |
+
disclosure of security bugs. In those cases, please go through the process
|
| 27 |
+
outlined on that page and do not file a public issue.
|
| 28 |
+
|
| 29 |
+
## License
|
| 30 |
+
By contributing to Llama, you agree that your contributions will be licensed
|
| 31 |
+
under the LICENSE file in the root directory of this source tree.
|
LICENSE
ADDED
|
@@ -0,0 +1,126 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
LLAMA 2 COMMUNITY LICENSE AGREEMENT
|
| 2 |
+
Llama 2 Version Release Date: July 18, 2023
|
| 3 |
+
|
| 4 |
+
"Agreement" means the terms and conditions for use, reproduction, distribution and
|
| 5 |
+
modification of the Llama Materials set forth herein.
|
| 6 |
+
|
| 7 |
+
"Documentation" means the specifications, manuals and documentation
|
| 8 |
+
accompanying Llama 2 distributed by Meta at ai.meta.com/resources/models-and-
|
| 9 |
+
libraries/llama-downloads/.
|
| 10 |
+
|
| 11 |
+
"Licensee" or "you" means you, or your employer or any other person or entity (if
|
| 12 |
+
you are entering into this Agreement on such person or entity's behalf), of the age
|
| 13 |
+
required under applicable laws, rules or regulations to provide legal consent and that
|
| 14 |
+
has legal authority to bind your employer or such other person or entity if you are
|
| 15 |
+
entering in this Agreement on their behalf.
|
| 16 |
+
|
| 17 |
+
"Llama 2" means the foundational large language models and software and
|
| 18 |
+
algorithms, including machine-learning model code, trained model weights,
|
| 19 |
+
inference-enabling code, training-enabling code, fine-tuning enabling code and other
|
| 20 |
+
elements of the foregoing distributed by Meta at ai.meta.com/resources/models-and-
|
| 21 |
+
libraries/llama-downloads/.
|
| 22 |
+
|
| 23 |
+
"Llama Materials" means, collectively, Meta's proprietary Llama 2 and
|
| 24 |
+
Documentation (and any portion thereof) made available under this Agreement.
|
| 25 |
+
|
| 26 |
+
"Meta" or "we" means Meta Platforms Ireland Limited (if you are located in or, if you
|
| 27 |
+
are an entity, your principal place of business is in the EEA or Switzerland) and Meta
|
| 28 |
+
Platforms, Inc. (if you are located outside of the EEA or Switzerland).
|
| 29 |
+
|
| 30 |
+
By clicking "I Accept" below or by using or distributing any portion or element of the
|
| 31 |
+
Llama Materials, you agree to be bound by this Agreement.
|
| 32 |
+
|
| 33 |
+
1. License Rights and Redistribution.
|
| 34 |
+
|
| 35 |
+
a. Grant of Rights. You are granted a non-exclusive, worldwide, non-
|
| 36 |
+
transferable and royalty-free limited license under Meta's intellectual property or
|
| 37 |
+
other rights owned by Meta embodied in the Llama Materials to use, reproduce,
|
| 38 |
+
distribute, copy, create derivative works of, and make modifications to the Llama
|
| 39 |
+
Materials.
|
| 40 |
+
|
| 41 |
+
b. Redistribution and Use.
|
| 42 |
+
|
| 43 |
+
i. If you distribute or make the Llama Materials, or any derivative works
|
| 44 |
+
thereof, available to a third party, you shall provide a copy of this Agreement to such
|
| 45 |
+
third party.
|
| 46 |
+
ii. If you receive Llama Materials, or any derivative works thereof, from
|
| 47 |
+
a Licensee as part of an integrated end user product, then Section 2 of this
|
| 48 |
+
Agreement will not apply to you.
|
| 49 |
+
|
| 50 |
+
iii. You must retain in all copies of the Llama Materials that you
|
| 51 |
+
distribute the following attribution notice within a "Notice" text file distributed as a
|
| 52 |
+
part of such copies: "Llama 2 is licensed under the LLAMA 2 Community License,
|
| 53 |
+
Copyright (c) Meta Platforms, Inc. All Rights Reserved."
|
| 54 |
+
|
| 55 |
+
iv. Your use of the Llama Materials must comply with applicable laws
|
| 56 |
+
and regulations (including trade compliance laws and regulations) and adhere to the
|
| 57 |
+
Acceptable Use Policy for the Llama Materials (available at
|
| 58 |
+
https://ai.meta.com/llama/use-policy), which is hereby incorporated by reference into
|
| 59 |
+
this Agreement.
|
| 60 |
+
|
| 61 |
+
v. You will not use the Llama Materials or any output or results of the
|
| 62 |
+
Llama Materials to improve any other large language model (excluding Llama 2 or
|
| 63 |
+
derivative works thereof).
|
| 64 |
+
|
| 65 |
+
2. Additional Commercial Terms. If, on the Llama 2 version release date, the
|
| 66 |
+
monthly active users of the products or services made available by or for Licensee,
|
| 67 |
+
or Licensee's affiliates, is greater than 700 million monthly active users in the
|
| 68 |
+
preceding calendar month, you must request a license from Meta, which Meta may
|
| 69 |
+
grant to you in its sole discretion, and you are not authorized to exercise any of the
|
| 70 |
+
rights under this Agreement unless or until Meta otherwise expressly grants you
|
| 71 |
+
such rights.
|
| 72 |
+
|
| 73 |
+
3. Disclaimer of Warranty. UNLESS REQUIRED BY APPLICABLE LAW, THE
|
| 74 |
+
LLAMA MATERIALS AND ANY OUTPUT AND RESULTS THEREFROM ARE
|
| 75 |
+
PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
|
| 76 |
+
EITHER EXPRESS OR IMPLIED, INCLUDING, WITHOUT LIMITATION, ANY
|
| 77 |
+
WARRANTIES OF TITLE, NON-INFRINGEMENT, MERCHANTABILITY, OR
|
| 78 |
+
FITNESS FOR A PARTICULAR PURPOSE. YOU ARE SOLELY RESPONSIBLE
|
| 79 |
+
FOR DETERMINING THE APPROPRIATENESS OF USING OR REDISTRIBUTING
|
| 80 |
+
THE LLAMA MATERIALS AND ASSUME ANY RISKS ASSOCIATED WITH YOUR
|
| 81 |
+
USE OF THE LLAMA MATERIALS AND ANY OUTPUT AND RESULTS.
|
| 82 |
+
|
| 83 |
+
4. Limitation of Liability. IN NO EVENT WILL META OR ITS AFFILIATES BE
|
| 84 |
+
LIABLE UNDER ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, TORT,
|
| 85 |
+
NEGLIGENCE, PRODUCTS LIABILITY, OR OTHERWISE, ARISING OUT OF THIS
|
| 86 |
+
AGREEMENT, FOR ANY LOST PROFITS OR ANY INDIRECT, SPECIAL,
|
| 87 |
+
CONSEQUENTIAL, INCIDENTAL, EXEMPLARY OR PUNITIVE DAMAGES, EVEN
|
| 88 |
+
IF META OR ITS AFFILIATES HAVE BEEN ADVISED OF THE POSSIBILITY OF
|
| 89 |
+
ANY OF THE FOREGOING.
|
| 90 |
+
|
| 91 |
+
5. Intellectual Property.
|
| 92 |
+
|
| 93 |
+
a. No trademark licenses are granted under this Agreement, and in
|
| 94 |
+
connection with the Llama Materials, neither Meta nor Licensee may use any name
|
| 95 |
+
or mark owned by or associated with the other or any of its affiliates, except as
|
| 96 |
+
required for reasonable and customary use in describing and redistributing the
|
| 97 |
+
Llama Materials.
|
| 98 |
+
|
| 99 |
+
b. Subject to Meta's ownership of Llama Materials and derivatives made by or
|
| 100 |
+
for Meta, with respect to any derivative works and modifications of the Llama
|
| 101 |
+
Materials that are made by you, as between you and Meta, you are and will be the
|
| 102 |
+
owner of such derivative works and modifications.
|
| 103 |
+
|
| 104 |
+
c. If you institute litigation or other proceedings against Meta or any entity
|
| 105 |
+
(including a cross-claim or counterclaim in a lawsuit) alleging that the Llama
|
| 106 |
+
Materials or Llama 2 outputs or results, or any portion of any of the foregoing,
|
| 107 |
+
constitutes infringement of intellectual property or other rights owned or licensable
|
| 108 |
+
by you, then any licenses granted to you under this Agreement shall terminate as of
|
| 109 |
+
the date such litigation or claim is filed or instituted. You will indemnify and hold
|
| 110 |
+
harmless Meta from and against any claim by any third party arising out of or related
|
| 111 |
+
to your use or distribution of the Llama Materials.
|
| 112 |
+
|
| 113 |
+
6. Term and Termination. The term of this Agreement will commence upon your
|
| 114 |
+
acceptance of this Agreement or access to the Llama Materials and will continue in
|
| 115 |
+
full force and effect until terminated in accordance with the terms and conditions
|
| 116 |
+
herein. Meta may terminate this Agreement if you are in breach of any term or
|
| 117 |
+
condition of this Agreement. Upon termination of this Agreement, you shall delete
|
| 118 |
+
and cease use of the Llama Materials. Sections 3, 4 and 7 shall survive the
|
| 119 |
+
termination of this Agreement.
|
| 120 |
+
|
| 121 |
+
7. Governing Law and Jurisdiction. This Agreement will be governed and
|
| 122 |
+
construed under the laws of the State of California without regard to choice of law
|
| 123 |
+
principles, and the UN Convention on Contracts for the International Sale of Goods
|
| 124 |
+
does not apply to this Agreement. The courts of California shall have exclusive
|
| 125 |
+
jurisdiction of any dispute arising out of this Agreement.
|
| 126 |
+
|
download.sh
ADDED
|
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
|
| 3 |
+
|
| 4 |
+
read -p "Enter the URL from email: " PRESIGNED_URL
|
| 5 |
+
echo ""
|
| 6 |
+
read -p "Enter the list of models to download without spaces (7B,13B,70B,7B-chat,13B-chat,70B-chat), or press Enter for all: " MODEL_SIZE
|
| 7 |
+
TARGET_FOLDER="." # where all files should end up
|
| 8 |
+
mkdir -p ${TARGET_FOLDER}
|
| 9 |
+
|
| 10 |
+
if [[ $MODEL_SIZE == "" ]]; then
|
| 11 |
+
MODEL_SIZE="7B,13B,70B,7B-chat,13B-chat,70B-chat"
|
| 12 |
+
fi
|
| 13 |
+
|
| 14 |
+
echo "Downloading LICENSE and Acceptable Usage Policy"
|
| 15 |
+
wget ${PRESIGNED_URL/'*'/"LICENSE"} -O ${TARGET_FOLDER}"/LICENSE"
|
| 16 |
+
wget ${PRESIGNED_URL/'*'/"USE_POLICY.md"} -O ${TARGET_FOLDER}"/USE_POLICY.md"
|
| 17 |
+
|
| 18 |
+
echo "Downloading tokenizer"
|
| 19 |
+
wget ${PRESIGNED_URL/'*'/"tokenizer.model"} -O ${TARGET_FOLDER}"/tokenizer.model"
|
| 20 |
+
wget ${PRESIGNED_URL/'*'/"tokenizer_checklist.chk"} -O ${TARGET_FOLDER}"/tokenizer_checklist.chk"
|
| 21 |
+
(cd ${TARGET_FOLDER} && md5sum -c tokenizer_checklist.chk)
|
| 22 |
+
|
| 23 |
+
for m in ${MODEL_SIZE//,/ }
|
| 24 |
+
do
|
| 25 |
+
if [[ $m == "7B" ]]; then
|
| 26 |
+
SHARD=0
|
| 27 |
+
MODEL_PATH="llama-2-7b"
|
| 28 |
+
elif [[ $m == "7B-chat" ]]; then
|
| 29 |
+
SHARD=0
|
| 30 |
+
MODEL_PATH="llama-2-7b-chat"
|
| 31 |
+
elif [[ $m == "13B" ]]; then
|
| 32 |
+
SHARD=1
|
| 33 |
+
MODEL_PATH="llama-2-13b"
|
| 34 |
+
elif [[ $m == "13B-chat" ]]; then
|
| 35 |
+
SHARD=1
|
| 36 |
+
MODEL_PATH="llama-2-13b-chat"
|
| 37 |
+
elif [[ $m == "70B" ]]; then
|
| 38 |
+
SHARD=7
|
| 39 |
+
MODEL_PATH="llama-2-70b"
|
| 40 |
+
elif [[ $m == "70B-chat" ]]; then
|
| 41 |
+
SHARD=7
|
| 42 |
+
MODEL_PATH="llama-2-70b-chat"
|
| 43 |
+
fi
|
| 44 |
+
|
| 45 |
+
echo "Downloading ${MODEL_PATH}"
|
| 46 |
+
mkdir -p ${TARGET_FOLDER}"/${MODEL_PATH}"
|
| 47 |
+
|
| 48 |
+
for s in $(seq -f "0%g" 0 ${SHARD})
|
| 49 |
+
do
|
| 50 |
+
wget ${PRESIGNED_URL/'*'/"${MODEL_PATH}/consolidated.${s}.pth"} -O ${TARGET_FOLDER}"/${MODEL_PATH}/consolidated.${s}.pth"
|
| 51 |
+
done
|
| 52 |
+
|
| 53 |
+
wget ${PRESIGNED_URL/'*'/"${MODEL_PATH}/params.json"} -O ${TARGET_FOLDER}"/${MODEL_PATH}/params.json"
|
| 54 |
+
wget ${PRESIGNED_URL/'*'/"${MODEL_PATH}/checklist.chk"} -O ${TARGET_FOLDER}"/${MODEL_PATH}/checklist.chk"
|
| 55 |
+
echo "Checking checksums"
|
| 56 |
+
(cd ${TARGET_FOLDER}"/${MODEL_PATH}" && md5sum -c checklist.chk)
|
| 57 |
+
done
|
| 58 |
+
|
example_chat_completion.py
ADDED
|
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
|
| 3 |
+
|
| 4 |
+
from typing import Optional
|
| 5 |
+
|
| 6 |
+
import fire
|
| 7 |
+
|
| 8 |
+
from llama import Llama
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def main(
|
| 12 |
+
ckpt_dir: str,
|
| 13 |
+
tokenizer_path: str,
|
| 14 |
+
temperature: float = 0.6,
|
| 15 |
+
top_p: float = 0.9,
|
| 16 |
+
max_seq_len: int = 512,
|
| 17 |
+
max_batch_size: int = 4,
|
| 18 |
+
max_gen_len: Optional[int] = None,
|
| 19 |
+
):
|
| 20 |
+
generator = Llama.build(
|
| 21 |
+
ckpt_dir=ckpt_dir,
|
| 22 |
+
tokenizer_path=tokenizer_path,
|
| 23 |
+
max_seq_len=max_seq_len,
|
| 24 |
+
max_batch_size=max_batch_size,
|
| 25 |
+
)
|
| 26 |
+
|
| 27 |
+
dialogs = [
|
| 28 |
+
[{"role": "user", "content": "what is the recipe of mayonnaise?"}],
|
| 29 |
+
[
|
| 30 |
+
{"role": "user", "content": "I am going to Paris, what should I see?"},
|
| 31 |
+
{
|
| 32 |
+
"role": "assistant",
|
| 33 |
+
"content": """\
|
| 34 |
+
Paris, the capital of France, is known for its stunning architecture, art museums, historical landmarks, and romantic atmosphere. Here are some of the top attractions to see in Paris:
|
| 35 |
+
|
| 36 |
+
1. The Eiffel Tower: The iconic Eiffel Tower is one of the most recognizable landmarks in the world and offers breathtaking views of the city.
|
| 37 |
+
2. The Louvre Museum: The Louvre is one of the world's largest and most famous museums, housing an impressive collection of art and artifacts, including the Mona Lisa.
|
| 38 |
+
3. Notre-Dame Cathedral: This beautiful cathedral is one of the most famous landmarks in Paris and is known for its Gothic architecture and stunning stained glass windows.
|
| 39 |
+
|
| 40 |
+
These are just a few of the many attractions that Paris has to offer. With so much to see and do, it's no wonder that Paris is one of the most popular tourist destinations in the world.""",
|
| 41 |
+
},
|
| 42 |
+
{"role": "user", "content": "What is so great about #1?"},
|
| 43 |
+
],
|
| 44 |
+
[
|
| 45 |
+
{"role": "system", "content": "Always answer with Haiku"},
|
| 46 |
+
{"role": "user", "content": "I am going to Paris, what should I see?"},
|
| 47 |
+
],
|
| 48 |
+
[
|
| 49 |
+
{
|
| 50 |
+
"role": "system",
|
| 51 |
+
"content": "Always answer with emojis",
|
| 52 |
+
},
|
| 53 |
+
{"role": "user", "content": "How to go from Beijing to NY?"},
|
| 54 |
+
],
|
| 55 |
+
]
|
| 56 |
+
results = generator.chat_completion(
|
| 57 |
+
dialogs, # type: ignore
|
| 58 |
+
max_gen_len=max_gen_len,
|
| 59 |
+
temperature=temperature,
|
| 60 |
+
top_p=top_p,
|
| 61 |
+
)
|
| 62 |
+
|
| 63 |
+
for dialog, result in zip(dialogs, results):
|
| 64 |
+
for msg in dialog:
|
| 65 |
+
print(f"{msg['role'].capitalize()}: {msg['content']}\n")
|
| 66 |
+
print(
|
| 67 |
+
f"> {result['generation']['role'].capitalize()}: {result['generation']['content']}"
|
| 68 |
+
)
|
| 69 |
+
print("\n==================================\n")
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
if __name__ == "__main__":
|
| 73 |
+
fire.Fire(main)
|
example_text_completion.py
ADDED
|
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
|
| 3 |
+
|
| 4 |
+
import fire
|
| 5 |
+
|
| 6 |
+
from llama import Llama
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def main(
|
| 10 |
+
ckpt_dir: str,
|
| 11 |
+
tokenizer_path: str,
|
| 12 |
+
temperature: float = 0.6,
|
| 13 |
+
top_p: float = 0.9,
|
| 14 |
+
max_seq_len: int = 128,
|
| 15 |
+
max_gen_len: int = 64,
|
| 16 |
+
max_batch_size: int = 4,
|
| 17 |
+
):
|
| 18 |
+
generator = Llama.build(
|
| 19 |
+
ckpt_dir=ckpt_dir,
|
| 20 |
+
tokenizer_path=tokenizer_path,
|
| 21 |
+
max_seq_len=max_seq_len,
|
| 22 |
+
max_batch_size=max_batch_size,
|
| 23 |
+
)
|
| 24 |
+
|
| 25 |
+
prompts = [
|
| 26 |
+
# For these prompts, the expected answer is the natural continuation of the prompt
|
| 27 |
+
"I believe the meaning of life is",
|
| 28 |
+
"Simply put, the theory of relativity states that ",
|
| 29 |
+
"""A brief message congratulating the team on the launch:
|
| 30 |
+
|
| 31 |
+
Hi everyone,
|
| 32 |
+
|
| 33 |
+
I just """,
|
| 34 |
+
# Few shot prompt (providing a few examples before asking model to complete more);
|
| 35 |
+
"""Translate English to French:
|
| 36 |
+
|
| 37 |
+
sea otter => loutre de mer
|
| 38 |
+
peppermint => menthe poivrée
|
| 39 |
+
plush girafe => girafe peluche
|
| 40 |
+
cheese =>""",
|
| 41 |
+
]
|
| 42 |
+
results = generator.text_completion(
|
| 43 |
+
prompts,
|
| 44 |
+
max_gen_len=max_gen_len,
|
| 45 |
+
temperature=temperature,
|
| 46 |
+
top_p=top_p,
|
| 47 |
+
)
|
| 48 |
+
for prompt, result in zip(prompts, results):
|
| 49 |
+
print(prompt)
|
| 50 |
+
print(f"> {result['generation']}")
|
| 51 |
+
print("\n==================================\n")
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
if __name__ == "__main__":
|
| 55 |
+
fire.Fire(main)
|
llama-2-7b/checklist.chk
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
daa8e3109935070df7fe8fc42d34525e consolidated.00.pth
|
| 2 |
+
9a3757de7196d1840b551a85b82efbc8 params.json
|
llama-2-7b/consolidated.00.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:d67a91807d5879d193a694da57f28ff85092e92dc9fbef4888bd05e22b15ab75
|
| 3 |
+
size 13476925163
|
llama-2-7b/params.json
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
{"dim": 4096, "multiple_of": 256, "n_heads": 32, "n_layers": 32, "norm_eps": 1e-05, "vocab_size": -1}
|
llama/__init__.py
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
|
| 3 |
+
|
| 4 |
+
from .generation import Llama
|
| 5 |
+
from .model import ModelArgs, Transformer
|
| 6 |
+
from .tokenizer import Tokenizer
|
llama/generation.py
ADDED
|
@@ -0,0 +1,303 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
|
| 3 |
+
|
| 4 |
+
import json
|
| 5 |
+
import os
|
| 6 |
+
import sys
|
| 7 |
+
import time
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
from typing import List, Literal, Optional, Tuple, TypedDict
|
| 10 |
+
|
| 11 |
+
import torch
|
| 12 |
+
import torch.nn.functional as F
|
| 13 |
+
from fairscale.nn.model_parallel.initialize import (
|
| 14 |
+
get_model_parallel_rank,
|
| 15 |
+
initialize_model_parallel,
|
| 16 |
+
model_parallel_is_initialized,
|
| 17 |
+
)
|
| 18 |
+
|
| 19 |
+
from llama.model import ModelArgs, Transformer
|
| 20 |
+
from llama.tokenizer import Tokenizer
|
| 21 |
+
|
| 22 |
+
Role = Literal["system", "user", "assistant"]
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
class Message(TypedDict):
|
| 26 |
+
role: Role
|
| 27 |
+
content: str
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
class CompletionPrediction(TypedDict, total=False):
|
| 31 |
+
generation: str
|
| 32 |
+
tokens: List[str] # not required
|
| 33 |
+
logprobs: List[float] # not required
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
class ChatPrediction(TypedDict, total=False):
|
| 37 |
+
generation: Message
|
| 38 |
+
tokens: List[str] # not required
|
| 39 |
+
logprobs: List[float] # not required
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
Dialog = List[Message]
|
| 43 |
+
|
| 44 |
+
B_INST, E_INST = "[INST]", "[/INST]"
|
| 45 |
+
B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
|
| 46 |
+
DEFAULT_SYSTEM_PROMPT = """\
|
| 47 |
+
You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.
|
| 48 |
+
|
| 49 |
+
If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information."""
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
class Llama:
|
| 53 |
+
@staticmethod
|
| 54 |
+
def build(
|
| 55 |
+
ckpt_dir: str,
|
| 56 |
+
tokenizer_path: str,
|
| 57 |
+
max_seq_len: int,
|
| 58 |
+
max_batch_size: int,
|
| 59 |
+
model_parallel_size: Optional[int] = None,
|
| 60 |
+
) -> "Llama":
|
| 61 |
+
if not torch.distributed.is_initialized():
|
| 62 |
+
torch.distributed.init_process_group("nccl")
|
| 63 |
+
if not model_parallel_is_initialized():
|
| 64 |
+
if model_parallel_size is None:
|
| 65 |
+
model_parallel_size = int(os.environ.get("WORLD_SIZE", 1))
|
| 66 |
+
initialize_model_parallel(model_parallel_size)
|
| 67 |
+
|
| 68 |
+
local_rank = int(os.environ.get("LOCAL_RANK", 0))
|
| 69 |
+
torch.cuda.set_device(local_rank)
|
| 70 |
+
|
| 71 |
+
# seed must be the same in all processes
|
| 72 |
+
torch.manual_seed(1)
|
| 73 |
+
|
| 74 |
+
if local_rank > 0:
|
| 75 |
+
sys.stdout = open(os.devnull, "w")
|
| 76 |
+
|
| 77 |
+
start_time = time.time()
|
| 78 |
+
checkpoints = sorted(Path(ckpt_dir).glob("*.pth"))
|
| 79 |
+
assert len(checkpoints) > 0, f"no checkpoint files found in {ckpt_dir}"
|
| 80 |
+
assert model_parallel_size == len(
|
| 81 |
+
checkpoints
|
| 82 |
+
), f"Loading a checkpoint for MP={len(checkpoints)} but world size is {model_parallel_size}"
|
| 83 |
+
ckpt_path = checkpoints[get_model_parallel_rank()]
|
| 84 |
+
checkpoint = torch.load(ckpt_path, map_location="cpu")
|
| 85 |
+
with open(Path(ckpt_dir) / "params.json", "r") as f:
|
| 86 |
+
params = json.loads(f.read())
|
| 87 |
+
|
| 88 |
+
model_args: ModelArgs = ModelArgs(
|
| 89 |
+
max_seq_len=max_seq_len,
|
| 90 |
+
max_batch_size=max_batch_size,
|
| 91 |
+
**params,
|
| 92 |
+
)
|
| 93 |
+
tokenizer = Tokenizer(model_path=tokenizer_path)
|
| 94 |
+
model_args.vocab_size = tokenizer.n_words
|
| 95 |
+
torch.set_default_tensor_type(torch.cuda.HalfTensor)
|
| 96 |
+
model = Transformer(model_args)
|
| 97 |
+
model.load_state_dict(checkpoint, strict=False)
|
| 98 |
+
print(f"Loaded in {time.time() - start_time:.2f} seconds")
|
| 99 |
+
|
| 100 |
+
return Llama(model, tokenizer)
|
| 101 |
+
|
| 102 |
+
def __init__(self, model: Transformer, tokenizer: Tokenizer):
|
| 103 |
+
self.model = model
|
| 104 |
+
self.tokenizer = tokenizer
|
| 105 |
+
|
| 106 |
+
@torch.inference_mode()
|
| 107 |
+
def generate(
|
| 108 |
+
self,
|
| 109 |
+
prompt_tokens: List[List[int]],
|
| 110 |
+
max_gen_len: int,
|
| 111 |
+
temperature: float = 0.6,
|
| 112 |
+
top_p: float = 0.9,
|
| 113 |
+
logprobs: bool = False,
|
| 114 |
+
echo: bool = False,
|
| 115 |
+
) -> Tuple[List[List[int]], Optional[List[List[float]]]]:
|
| 116 |
+
params = self.model.params
|
| 117 |
+
bsz = len(prompt_tokens)
|
| 118 |
+
assert bsz <= params.max_batch_size, (bsz, params.max_batch_size)
|
| 119 |
+
|
| 120 |
+
min_prompt_len = min(len(t) for t in prompt_tokens)
|
| 121 |
+
max_prompt_len = max(len(t) for t in prompt_tokens)
|
| 122 |
+
assert max_prompt_len <= params.max_seq_len
|
| 123 |
+
total_len = min(params.max_seq_len, max_gen_len + max_prompt_len)
|
| 124 |
+
|
| 125 |
+
pad_id = self.tokenizer.pad_id
|
| 126 |
+
tokens = torch.full((bsz, total_len), pad_id, dtype=torch.long, device="cuda")
|
| 127 |
+
for k, t in enumerate(prompt_tokens):
|
| 128 |
+
tokens[k, : len(t)] = torch.tensor(t, dtype=torch.long, device="cuda")
|
| 129 |
+
if logprobs:
|
| 130 |
+
token_logprobs = torch.zeros_like(tokens, dtype=torch.float)
|
| 131 |
+
|
| 132 |
+
prev_pos = 0
|
| 133 |
+
eos_reached = torch.tensor([False] * bsz, device="cuda")
|
| 134 |
+
input_text_mask = tokens != pad_id
|
| 135 |
+
for cur_pos in range(min_prompt_len, total_len):
|
| 136 |
+
logits = self.model.forward(tokens[:, prev_pos:cur_pos], prev_pos)
|
| 137 |
+
if logprobs:
|
| 138 |
+
token_logprobs[:, prev_pos + 1 : cur_pos + 1] = -F.cross_entropy(
|
| 139 |
+
input=logits.transpose(1, 2),
|
| 140 |
+
target=tokens[:, prev_pos + 1 : cur_pos + 1],
|
| 141 |
+
reduction="none",
|
| 142 |
+
ignore_index=pad_id,
|
| 143 |
+
)
|
| 144 |
+
if temperature > 0:
|
| 145 |
+
probs = torch.softmax(logits[:, -1] / temperature, dim=-1)
|
| 146 |
+
next_token = sample_top_p(probs, top_p)
|
| 147 |
+
else:
|
| 148 |
+
next_token = torch.argmax(logits[:, -1], dim=-1)
|
| 149 |
+
|
| 150 |
+
next_token = next_token.reshape(-1)
|
| 151 |
+
# only replace token if prompt has already been generated
|
| 152 |
+
next_token = torch.where(
|
| 153 |
+
input_text_mask[:, cur_pos], tokens[:, cur_pos], next_token
|
| 154 |
+
)
|
| 155 |
+
tokens[:, cur_pos] = next_token
|
| 156 |
+
eos_reached |= (~input_text_mask[:, cur_pos]) & (
|
| 157 |
+
next_token == self.tokenizer.eos_id
|
| 158 |
+
)
|
| 159 |
+
prev_pos = cur_pos
|
| 160 |
+
if all(eos_reached):
|
| 161 |
+
break
|
| 162 |
+
|
| 163 |
+
if logprobs:
|
| 164 |
+
token_logprobs = token_logprobs.tolist()
|
| 165 |
+
out_tokens, out_logprobs = [], []
|
| 166 |
+
for i, toks in enumerate(tokens.tolist()):
|
| 167 |
+
# cut to max gen len
|
| 168 |
+
start = 0 if echo else len(prompt_tokens[i])
|
| 169 |
+
toks = toks[start : len(prompt_tokens[i]) + max_gen_len]
|
| 170 |
+
if logprobs:
|
| 171 |
+
probs = token_logprobs[i][start : len(prompt_tokens[i]) + max_gen_len]
|
| 172 |
+
# cut to eos tok if any
|
| 173 |
+
if self.tokenizer.eos_id in toks:
|
| 174 |
+
eos_idx = toks.index(self.tokenizer.eos_id)
|
| 175 |
+
toks = toks[:eos_idx]
|
| 176 |
+
probs = probs[:eos_idx] if logprobs else None
|
| 177 |
+
out_tokens.append(toks)
|
| 178 |
+
out_logprobs.append(probs)
|
| 179 |
+
return (out_tokens, out_logprobs if logprobs else None)
|
| 180 |
+
|
| 181 |
+
def text_completion(
|
| 182 |
+
self,
|
| 183 |
+
prompts: List[str],
|
| 184 |
+
temperature: float = 0.6,
|
| 185 |
+
top_p: float = 0.9,
|
| 186 |
+
max_gen_len: Optional[int] = None,
|
| 187 |
+
logprobs: bool = False,
|
| 188 |
+
echo: bool = False,
|
| 189 |
+
) -> List[CompletionPrediction]:
|
| 190 |
+
if max_gen_len is None:
|
| 191 |
+
max_gen_len = self.model.params.max_seq_len - 1
|
| 192 |
+
prompt_tokens = [self.tokenizer.encode(x, bos=True, eos=False) for x in prompts]
|
| 193 |
+
generation_tokens, generation_logprobs = self.generate(
|
| 194 |
+
prompt_tokens=prompt_tokens,
|
| 195 |
+
max_gen_len=max_gen_len,
|
| 196 |
+
temperature=temperature,
|
| 197 |
+
top_p=top_p,
|
| 198 |
+
logprobs=logprobs,
|
| 199 |
+
echo=echo,
|
| 200 |
+
)
|
| 201 |
+
if logprobs:
|
| 202 |
+
return [
|
| 203 |
+
{
|
| 204 |
+
"generation": self.tokenizer.decode(t),
|
| 205 |
+
"tokens": [self.tokenizer.decode(x) for x in t],
|
| 206 |
+
"logprobs": logprobs_i,
|
| 207 |
+
}
|
| 208 |
+
for t, logprobs_i in zip(generation_tokens, generation_logprobs)
|
| 209 |
+
]
|
| 210 |
+
return [{"generation": self.tokenizer.decode(t)} for t in generation_tokens]
|
| 211 |
+
|
| 212 |
+
def chat_completion(
|
| 213 |
+
self,
|
| 214 |
+
dialogs: List[Dialog],
|
| 215 |
+
temperature: float = 0.6,
|
| 216 |
+
top_p: float = 0.9,
|
| 217 |
+
max_gen_len: Optional[int] = None,
|
| 218 |
+
logprobs: bool = False,
|
| 219 |
+
) -> List[ChatPrediction]:
|
| 220 |
+
if max_gen_len is None:
|
| 221 |
+
max_gen_len = self.model.params.max_seq_len - 1
|
| 222 |
+
prompt_tokens = []
|
| 223 |
+
for dialog in dialogs:
|
| 224 |
+
if dialog[0]["role"] != "system":
|
| 225 |
+
dialog = [
|
| 226 |
+
{
|
| 227 |
+
"role": "system",
|
| 228 |
+
"content": DEFAULT_SYSTEM_PROMPT,
|
| 229 |
+
}
|
| 230 |
+
] + dialog
|
| 231 |
+
dialog = [
|
| 232 |
+
{
|
| 233 |
+
"role": dialog[1]["role"],
|
| 234 |
+
"content": B_SYS
|
| 235 |
+
+ dialog[0]["content"]
|
| 236 |
+
+ E_SYS
|
| 237 |
+
+ dialog[1]["content"],
|
| 238 |
+
}
|
| 239 |
+
] + dialog[2:]
|
| 240 |
+
assert all([msg["role"] == "user" for msg in dialog[::2]]) and all(
|
| 241 |
+
[msg["role"] == "assistant" for msg in dialog[1::2]]
|
| 242 |
+
), (
|
| 243 |
+
"model only supports 'system', 'user' and 'assistant' roles, "
|
| 244 |
+
"starting with 'system', then 'user' and alternating (u/a/u/a/u...)"
|
| 245 |
+
)
|
| 246 |
+
dialog_tokens: List[int] = sum(
|
| 247 |
+
[
|
| 248 |
+
self.tokenizer.encode(
|
| 249 |
+
f"{B_INST} {(prompt['content']).strip()} {E_INST} {(answer['content']).strip()} ",
|
| 250 |
+
bos=True,
|
| 251 |
+
eos=True,
|
| 252 |
+
)
|
| 253 |
+
for prompt, answer in zip(
|
| 254 |
+
dialog[::2],
|
| 255 |
+
dialog[1::2],
|
| 256 |
+
)
|
| 257 |
+
],
|
| 258 |
+
[],
|
| 259 |
+
)
|
| 260 |
+
assert (
|
| 261 |
+
dialog[-1]["role"] == "user"
|
| 262 |
+
), f"Last message must be from user, got {dialog[-1]['role']}"
|
| 263 |
+
dialog_tokens += self.tokenizer.encode(
|
| 264 |
+
f"{B_INST} {(dialog[-1]['content']).strip()} {E_INST}",
|
| 265 |
+
bos=True,
|
| 266 |
+
eos=False,
|
| 267 |
+
)
|
| 268 |
+
prompt_tokens.append(dialog_tokens)
|
| 269 |
+
|
| 270 |
+
generation_tokens, generation_logprobs = self.generate(
|
| 271 |
+
prompt_tokens=prompt_tokens,
|
| 272 |
+
max_gen_len=max_gen_len,
|
| 273 |
+
temperature=temperature,
|
| 274 |
+
top_p=top_p,
|
| 275 |
+
logprobs=logprobs,
|
| 276 |
+
)
|
| 277 |
+
if logprobs:
|
| 278 |
+
return [
|
| 279 |
+
{
|
| 280 |
+
"generation": {
|
| 281 |
+
"role": "assistant",
|
| 282 |
+
"content": self.tokenizer.decode(t),
|
| 283 |
+
},
|
| 284 |
+
"tokens": [self.tokenizer.decode(x) for x in t],
|
| 285 |
+
"logprobs": logprobs_i,
|
| 286 |
+
}
|
| 287 |
+
for t, logprobs_i in zip(generation_tokens, generation_logprobs)
|
| 288 |
+
]
|
| 289 |
+
return [
|
| 290 |
+
{"generation": {"role": "assistant", "content": self.tokenizer.decode(t)}}
|
| 291 |
+
for t in generation_tokens
|
| 292 |
+
]
|
| 293 |
+
|
| 294 |
+
|
| 295 |
+
def sample_top_p(probs, p):
|
| 296 |
+
probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
|
| 297 |
+
probs_sum = torch.cumsum(probs_sort, dim=-1)
|
| 298 |
+
mask = probs_sum - probs_sort > p
|
| 299 |
+
probs_sort[mask] = 0.0
|
| 300 |
+
probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
|
| 301 |
+
next_token = torch.multinomial(probs_sort, num_samples=1)
|
| 302 |
+
next_token = torch.gather(probs_idx, -1, next_token)
|
| 303 |
+
return next_token
|
llama/model.py
ADDED
|
@@ -0,0 +1,288 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
|
| 3 |
+
|
| 4 |
+
import math
|
| 5 |
+
from dataclasses import dataclass
|
| 6 |
+
from typing import Any, Optional, Tuple
|
| 7 |
+
|
| 8 |
+
import fairscale.nn.model_parallel.initialize as fs_init
|
| 9 |
+
import torch
|
| 10 |
+
import torch.nn.functional as F
|
| 11 |
+
from fairscale.nn.model_parallel.layers import (
|
| 12 |
+
ColumnParallelLinear,
|
| 13 |
+
ParallelEmbedding,
|
| 14 |
+
RowParallelLinear,
|
| 15 |
+
)
|
| 16 |
+
from torch import nn
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
@dataclass
|
| 20 |
+
class ModelArgs:
|
| 21 |
+
dim: int = 4096
|
| 22 |
+
n_layers: int = 32
|
| 23 |
+
n_heads: int = 32
|
| 24 |
+
n_kv_heads: Optional[int] = None
|
| 25 |
+
vocab_size: int = -1 # defined later by tokenizer
|
| 26 |
+
multiple_of: int = 256 # make SwiGLU hidden layer size multiple of large power of 2
|
| 27 |
+
ffn_dim_multiplier: Optional[float] = None
|
| 28 |
+
norm_eps: float = 1e-5
|
| 29 |
+
|
| 30 |
+
max_batch_size: int = 32
|
| 31 |
+
max_seq_len: int = 2048
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
class RMSNorm(torch.nn.Module):
|
| 35 |
+
def __init__(self, dim: int, eps: float = 1e-6):
|
| 36 |
+
super().__init__()
|
| 37 |
+
self.eps = eps
|
| 38 |
+
self.weight = nn.Parameter(torch.ones(dim))
|
| 39 |
+
|
| 40 |
+
def _norm(self, x):
|
| 41 |
+
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
|
| 42 |
+
|
| 43 |
+
def forward(self, x):
|
| 44 |
+
output = self._norm(x.float()).type_as(x)
|
| 45 |
+
return output * self.weight
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
|
| 49 |
+
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
|
| 50 |
+
t = torch.arange(end, device=freqs.device) # type: ignore
|
| 51 |
+
freqs = torch.outer(t, freqs).float() # type: ignore
|
| 52 |
+
freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64
|
| 53 |
+
return freqs_cis
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
|
| 57 |
+
ndim = x.ndim
|
| 58 |
+
assert 0 <= 1 < ndim
|
| 59 |
+
assert freqs_cis.shape == (x.shape[1], x.shape[-1])
|
| 60 |
+
shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
|
| 61 |
+
return freqs_cis.view(*shape)
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
def apply_rotary_emb(
|
| 65 |
+
xq: torch.Tensor,
|
| 66 |
+
xk: torch.Tensor,
|
| 67 |
+
freqs_cis: torch.Tensor,
|
| 68 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 69 |
+
xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
|
| 70 |
+
xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
|
| 71 |
+
freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
|
| 72 |
+
xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
|
| 73 |
+
xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
|
| 74 |
+
return xq_out.type_as(xq), xk_out.type_as(xk)
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
|
| 78 |
+
"""torch.repeat_interleave(x, dim=2, repeats=n_rep)"""
|
| 79 |
+
bs, slen, n_kv_heads, head_dim = x.shape
|
| 80 |
+
if n_rep == 1:
|
| 81 |
+
return x
|
| 82 |
+
return (
|
| 83 |
+
x[:, :, :, None, :]
|
| 84 |
+
.expand(bs, slen, n_kv_heads, n_rep, head_dim)
|
| 85 |
+
.reshape(bs, slen, n_kv_heads * n_rep, head_dim)
|
| 86 |
+
)
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
class Attention(nn.Module):
|
| 90 |
+
def __init__(self, args: ModelArgs):
|
| 91 |
+
super().__init__()
|
| 92 |
+
self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
|
| 93 |
+
model_parallel_size = fs_init.get_model_parallel_world_size()
|
| 94 |
+
self.n_local_heads = args.n_heads // model_parallel_size
|
| 95 |
+
self.n_local_kv_heads = self.n_kv_heads // model_parallel_size
|
| 96 |
+
self.n_rep = self.n_local_heads // self.n_local_kv_heads
|
| 97 |
+
self.head_dim = args.dim // args.n_heads
|
| 98 |
+
|
| 99 |
+
self.wq = ColumnParallelLinear(
|
| 100 |
+
args.dim,
|
| 101 |
+
args.n_heads * self.head_dim,
|
| 102 |
+
bias=False,
|
| 103 |
+
gather_output=False,
|
| 104 |
+
init_method=lambda x: x,
|
| 105 |
+
)
|
| 106 |
+
self.wk = ColumnParallelLinear(
|
| 107 |
+
args.dim,
|
| 108 |
+
self.n_kv_heads * self.head_dim,
|
| 109 |
+
bias=False,
|
| 110 |
+
gather_output=False,
|
| 111 |
+
init_method=lambda x: x,
|
| 112 |
+
)
|
| 113 |
+
self.wv = ColumnParallelLinear(
|
| 114 |
+
args.dim,
|
| 115 |
+
self.n_kv_heads * self.head_dim,
|
| 116 |
+
bias=False,
|
| 117 |
+
gather_output=False,
|
| 118 |
+
init_method=lambda x: x,
|
| 119 |
+
)
|
| 120 |
+
self.wo = RowParallelLinear(
|
| 121 |
+
args.n_heads * self.head_dim,
|
| 122 |
+
args.dim,
|
| 123 |
+
bias=False,
|
| 124 |
+
input_is_parallel=True,
|
| 125 |
+
init_method=lambda x: x,
|
| 126 |
+
)
|
| 127 |
+
|
| 128 |
+
self.cache_k = torch.zeros(
|
| 129 |
+
(
|
| 130 |
+
args.max_batch_size,
|
| 131 |
+
args.max_seq_len,
|
| 132 |
+
self.n_local_kv_heads,
|
| 133 |
+
self.head_dim,
|
| 134 |
+
)
|
| 135 |
+
).cuda()
|
| 136 |
+
self.cache_v = torch.zeros(
|
| 137 |
+
(
|
| 138 |
+
args.max_batch_size,
|
| 139 |
+
args.max_seq_len,
|
| 140 |
+
self.n_local_kv_heads,
|
| 141 |
+
self.head_dim,
|
| 142 |
+
)
|
| 143 |
+
).cuda()
|
| 144 |
+
|
| 145 |
+
def forward(
|
| 146 |
+
self,
|
| 147 |
+
x: torch.Tensor,
|
| 148 |
+
start_pos: int,
|
| 149 |
+
freqs_cis: torch.Tensor,
|
| 150 |
+
mask: Optional[torch.Tensor],
|
| 151 |
+
):
|
| 152 |
+
bsz, seqlen, _ = x.shape
|
| 153 |
+
xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
|
| 154 |
+
|
| 155 |
+
xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
|
| 156 |
+
xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
|
| 157 |
+
xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
|
| 158 |
+
|
| 159 |
+
xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)
|
| 160 |
+
|
| 161 |
+
self.cache_k = self.cache_k.to(xq)
|
| 162 |
+
self.cache_v = self.cache_v.to(xq)
|
| 163 |
+
|
| 164 |
+
self.cache_k[:bsz, start_pos : start_pos + seqlen] = xk
|
| 165 |
+
self.cache_v[:bsz, start_pos : start_pos + seqlen] = xv
|
| 166 |
+
|
| 167 |
+
keys = self.cache_k[:bsz, : start_pos + seqlen]
|
| 168 |
+
values = self.cache_v[:bsz, : start_pos + seqlen]
|
| 169 |
+
|
| 170 |
+
# repeat k/v heads if n_kv_heads < n_heads
|
| 171 |
+
keys = repeat_kv(keys, self.n_rep) # (bs, seqlen, n_local_heads, head_dim)
|
| 172 |
+
values = repeat_kv(values, self.n_rep) # (bs, seqlen, n_local_heads, head_dim)
|
| 173 |
+
|
| 174 |
+
xq = xq.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim)
|
| 175 |
+
keys = keys.transpose(1, 2)
|
| 176 |
+
values = values.transpose(1, 2)
|
| 177 |
+
scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim)
|
| 178 |
+
if mask is not None:
|
| 179 |
+
scores = scores + mask # (bs, n_local_heads, seqlen, cache_len + seqlen)
|
| 180 |
+
scores = F.softmax(scores.float(), dim=-1).type_as(xq)
|
| 181 |
+
output = torch.matmul(scores, values) # (bs, n_local_heads, seqlen, head_dim)
|
| 182 |
+
output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
|
| 183 |
+
return self.wo(output)
|
| 184 |
+
|
| 185 |
+
|
| 186 |
+
class FeedForward(nn.Module):
|
| 187 |
+
def __init__(
|
| 188 |
+
self,
|
| 189 |
+
dim: int,
|
| 190 |
+
hidden_dim: int,
|
| 191 |
+
multiple_of: int,
|
| 192 |
+
ffn_dim_multiplier: Optional[float],
|
| 193 |
+
):
|
| 194 |
+
super().__init__()
|
| 195 |
+
hidden_dim = int(2 * hidden_dim / 3)
|
| 196 |
+
# custom dim factor multiplier
|
| 197 |
+
if ffn_dim_multiplier is not None:
|
| 198 |
+
hidden_dim = int(ffn_dim_multiplier * hidden_dim)
|
| 199 |
+
hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
|
| 200 |
+
|
| 201 |
+
self.w1 = ColumnParallelLinear(
|
| 202 |
+
dim, hidden_dim, bias=False, gather_output=False, init_method=lambda x: x
|
| 203 |
+
)
|
| 204 |
+
self.w2 = RowParallelLinear(
|
| 205 |
+
hidden_dim, dim, bias=False, input_is_parallel=True, init_method=lambda x: x
|
| 206 |
+
)
|
| 207 |
+
self.w3 = ColumnParallelLinear(
|
| 208 |
+
dim, hidden_dim, bias=False, gather_output=False, init_method=lambda x: x
|
| 209 |
+
)
|
| 210 |
+
|
| 211 |
+
def forward(self, x):
|
| 212 |
+
return self.w2(F.silu(self.w1(x)) * self.w3(x))
|
| 213 |
+
|
| 214 |
+
|
| 215 |
+
class TransformerBlock(nn.Module):
|
| 216 |
+
def __init__(self, layer_id: int, args: ModelArgs):
|
| 217 |
+
super().__init__()
|
| 218 |
+
self.n_heads = args.n_heads
|
| 219 |
+
self.dim = args.dim
|
| 220 |
+
self.head_dim = args.dim // args.n_heads
|
| 221 |
+
self.attention = Attention(args)
|
| 222 |
+
self.feed_forward = FeedForward(
|
| 223 |
+
dim=args.dim,
|
| 224 |
+
hidden_dim=4 * args.dim,
|
| 225 |
+
multiple_of=args.multiple_of,
|
| 226 |
+
ffn_dim_multiplier=args.ffn_dim_multiplier,
|
| 227 |
+
)
|
| 228 |
+
self.layer_id = layer_id
|
| 229 |
+
self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
|
| 230 |
+
self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)
|
| 231 |
+
|
| 232 |
+
def forward(
|
| 233 |
+
self,
|
| 234 |
+
x: torch.Tensor,
|
| 235 |
+
start_pos: int,
|
| 236 |
+
freqs_cis: torch.Tensor,
|
| 237 |
+
mask: Optional[torch.Tensor],
|
| 238 |
+
):
|
| 239 |
+
h = x + self.attention.forward(
|
| 240 |
+
self.attention_norm(x), start_pos, freqs_cis, mask
|
| 241 |
+
)
|
| 242 |
+
out = h + self.feed_forward.forward(self.ffn_norm(h))
|
| 243 |
+
return out
|
| 244 |
+
|
| 245 |
+
|
| 246 |
+
class Transformer(nn.Module):
|
| 247 |
+
def __init__(self, params: ModelArgs):
|
| 248 |
+
super().__init__()
|
| 249 |
+
self.params = params
|
| 250 |
+
self.vocab_size = params.vocab_size
|
| 251 |
+
self.n_layers = params.n_layers
|
| 252 |
+
|
| 253 |
+
self.tok_embeddings = ParallelEmbedding(
|
| 254 |
+
params.vocab_size, params.dim, init_method=lambda x: x
|
| 255 |
+
)
|
| 256 |
+
|
| 257 |
+
self.layers = torch.nn.ModuleList()
|
| 258 |
+
for layer_id in range(params.n_layers):
|
| 259 |
+
self.layers.append(TransformerBlock(layer_id, params))
|
| 260 |
+
|
| 261 |
+
self.norm = RMSNorm(params.dim, eps=params.norm_eps)
|
| 262 |
+
self.output = ColumnParallelLinear(
|
| 263 |
+
params.dim, params.vocab_size, bias=False, init_method=lambda x: x
|
| 264 |
+
)
|
| 265 |
+
|
| 266 |
+
self.freqs_cis = precompute_freqs_cis(
|
| 267 |
+
self.params.dim // self.params.n_heads, self.params.max_seq_len * 2
|
| 268 |
+
)
|
| 269 |
+
|
| 270 |
+
@torch.inference_mode()
|
| 271 |
+
def forward(self, tokens: torch.Tensor, start_pos: int):
|
| 272 |
+
_bsz, seqlen = tokens.shape
|
| 273 |
+
h = self.tok_embeddings(tokens)
|
| 274 |
+
self.freqs_cis = self.freqs_cis.to(h.device)
|
| 275 |
+
freqs_cis = self.freqs_cis[start_pos : start_pos + seqlen]
|
| 276 |
+
|
| 277 |
+
mask = None
|
| 278 |
+
if seqlen > 1:
|
| 279 |
+
mask = torch.full(
|
| 280 |
+
(1, 1, seqlen, seqlen), float("-inf"), device=tokens.device
|
| 281 |
+
)
|
| 282 |
+
mask = torch.triu(mask, diagonal=start_pos + 1).type_as(h)
|
| 283 |
+
|
| 284 |
+
for layer in self.layers:
|
| 285 |
+
h = layer(h, start_pos, freqs_cis, mask)
|
| 286 |
+
h = self.norm(h)
|
| 287 |
+
output = self.output(h).float()
|
| 288 |
+
return output
|
llama/tokenizer.py
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
|
| 3 |
+
|
| 4 |
+
import os
|
| 5 |
+
from logging import getLogger
|
| 6 |
+
from typing import List
|
| 7 |
+
|
| 8 |
+
from sentencepiece import SentencePieceProcessor
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
logger = getLogger()
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class Tokenizer:
|
| 15 |
+
def __init__(self, model_path: str):
|
| 16 |
+
# reload tokenizer
|
| 17 |
+
assert os.path.isfile(model_path), model_path
|
| 18 |
+
self.sp_model = SentencePieceProcessor(model_file=model_path)
|
| 19 |
+
logger.info(f"Reloaded SentencePiece model from {model_path}")
|
| 20 |
+
|
| 21 |
+
# BOS / EOS token IDs
|
| 22 |
+
self.n_words: int = self.sp_model.vocab_size()
|
| 23 |
+
self.bos_id: int = self.sp_model.bos_id()
|
| 24 |
+
self.eos_id: int = self.sp_model.eos_id()
|
| 25 |
+
self.pad_id: int = self.sp_model.pad_id()
|
| 26 |
+
logger.info(
|
| 27 |
+
f"#words: {self.n_words} - BOS ID: {self.bos_id} - EOS ID: {self.eos_id}"
|
| 28 |
+
)
|
| 29 |
+
assert self.sp_model.vocab_size() == self.sp_model.get_piece_size()
|
| 30 |
+
|
| 31 |
+
def encode(self, s: str, bos: bool, eos: bool) -> List[int]:
|
| 32 |
+
assert type(s) is str
|
| 33 |
+
t = self.sp_model.encode(s)
|
| 34 |
+
if bos:
|
| 35 |
+
t = [self.bos_id] + t
|
| 36 |
+
if eos:
|
| 37 |
+
t = t + [self.eos_id]
|
| 38 |
+
return t
|
| 39 |
+
|
| 40 |
+
def decode(self, t: List[int]) -> str:
|
| 41 |
+
return self.sp_model.decode(t)
|