Add files using upload-large-folder tool
Browse filesThis view is limited to 50 files because it contains too many changes. See raw diff
- Ovis/docs/license/GEMMA_LICENSE.txt +77 -0
- Ovis/docs/license/LLAMA3_LICENSE +84 -0
- Ovis/ovis/__pycache__/__init__.cpython-310.pyc +0 -0
- Ovis/ovis/__pycache__/__init__.cpython-311.pyc +0 -0
- Ovis/ovis/model/__pycache__/__init__.cpython-310.pyc +0 -0
- Ovis/ovis/model/__pycache__/__init__.cpython-311.pyc +0 -0
- Ovis/ovis/model/__pycache__/configuration_ovis.cpython-311.pyc +0 -0
- Ovis/ovis/model/__pycache__/modeling_ovis.cpython-311.pyc +0 -0
- Ovis/ovis/model/configuration_ovis.py +41 -0
- Ovis/ovis/model/conversation_formatter.py +233 -0
- Ovis/ovis/model/visual_tokenizer/__pycache__/base_visual_tokenizer.cpython-310.pyc +0 -0
- Ovis/ovis/model/visual_tokenizer/__pycache__/base_visual_tokenizer.cpython-311.pyc +0 -0
- Ovis/ovis/model/visual_tokenizer/__pycache__/clip_visual_tokenizer.cpython-310.pyc +0 -0
- Ovis/ovis/model/visual_tokenizer/__pycache__/clip_visual_tokenizer.cpython-311.pyc +0 -0
- Ovis/ovis/model/visual_tokenizer/__pycache__/siglip_visual_tokenizer.cpython-310.pyc +0 -0
- Ovis/ovis/model/visual_tokenizer/__pycache__/siglip_visual_tokenizer.cpython-311.pyc +0 -0
- Ovis/ovis/serve/runner.py +105 -0
- Ovis/ovis/serve/server.py +41 -0
- Ovis/ovis/train/__init__.py +0 -0
- Ovis/ovis/train/arguments.py +48 -0
- Ovis/ovis/train/callback.py +37 -0
- Ovis/ovis/train/train.py +206 -0
- Ovis/ovis/util/constants.py +11 -0
- Ovis/ovis/util/utils.py +26 -0
- llm2vec/docs/.gitignore +9 -0
- llm2vec/docs/Gemfile +18 -0
- llm2vec/docs/README.md +104 -0
- llm2vec/docs/_config.yml +110 -0
- llm2vec/docs/_data/navigation.yml +17 -0
- llm2vec/docs/_includes/head/custom.html +48 -0
- llm2vec/docs/_sass/custom/header-footer.scss +19 -0
- llm2vec/docs/_sass/custom/no-sidebar.scss +9 -0
- llm2vec/docs/_sass/custom/splash.scss +5 -0
- llm2vec/docs/_sass/skins/dark.scss +30 -0
- llm2vec/docs/_sass/skins/light.scss +12 -0
- llm2vec/docs/assets/images/logo/favicon.png +0 -0
- llm2vec/docs/assets/images/logo/logo.png +0 -0
- llm2vec/docs/assets/images/logo/logo.svg +0 -0
- llm2vec/examples/classification.py +62 -0
- llm2vec/examples/clustering.py +58 -0
- llm2vec/examples/retrieval.py +177 -0
- llm2vec/examples/sts.py +57 -0
- llm2vec/experiments/mteb_eval.py +31 -0
- llm2vec/experiments/mteb_eval_custom.py +98 -0
- llm2vec/experiments/run_mntp.py +997 -0
- llm2vec/experiments/run_simcse.py +388 -0
- llm2vec/experiments/run_supervised.py +482 -0
- llm2vec/experiments/run_word_task.py +905 -0
- llm2vec/experiments/test_word_task.py +393 -0
- llm2vec/images/sample_efficient.png +0 -0
Ovis/docs/license/GEMMA_LICENSE.txt
ADDED
|
@@ -0,0 +1,77 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Gemma Terms of Use
|
| 2 |
+
|
| 3 |
+
Last modified: April 1, 2024
|
| 4 |
+
|
| 5 |
+
By using, reproducing, modifying, distributing, performing or displaying any portion or element of Gemma, Model Derivatives including via any Hosted Service, (each as defined below) (collectively, the "Gemma Services") or otherwise accepting the terms of this Agreement, you agree to be bound by this Agreement.
|
| 6 |
+
|
| 7 |
+
Section 1: DEFINITIONS
|
| 8 |
+
1.1 Definitions
|
| 9 |
+
(a) "Agreement" or "Gemma Terms of Use" means these terms and conditions that govern the use, reproduction, Distribution or modification of the Gemma Services and any terms and conditions incorporated by reference.
|
| 10 |
+
|
| 11 |
+
(b) "Distribution" or "Distribute" means any transmission, publication, or other sharing of Gemma or Model Derivatives to a third party, including by providing or making Gemma or its functionality available as a hosted service via API, web access, or any other electronic or remote means ("Hosted Service").
|
| 12 |
+
|
| 13 |
+
(c) "Gemma" means the set of machine learning language models, trained model weights and parameters identified at ai.google.dev/gemma, regardless of the source that you obtained it from.
|
| 14 |
+
|
| 15 |
+
(d) "Google" means Google LLC.
|
| 16 |
+
|
| 17 |
+
(e) "Model Derivatives" means all (i) modifications to Gemma, (ii) works based on Gemma, or (iii) any other machine learning model which is created by transfer of patterns of the weights, parameters, operations, or Output of Gemma, to that model in order to cause that model to perform similarly to Gemma, including distillation methods that use intermediate data representations or methods based on the generation of synthetic data Outputs by Gemma for training that model. For clarity, Outputs are not deemed Model Derivatives.
|
| 18 |
+
|
| 19 |
+
(f) "Output" means the information content output of Gemma or a Model Derivative that results from operating or otherwise using Gemma or the Model Derivative, including via a Hosted Service.
|
| 20 |
+
|
| 21 |
+
1.2
|
| 22 |
+
As used in this Agreement, "including" means "including without limitation".
|
| 23 |
+
|
| 24 |
+
Section 2: ELIGIBILITY AND USAGE
|
| 25 |
+
2.1 Eligibility
|
| 26 |
+
You represent and warrant that you have the legal capacity to enter into this Agreement (including being of sufficient age of consent). If you are accessing or using any of the Gemma Services for or on behalf of a legal entity, (a) you are entering into this Agreement on behalf of yourself and that legal entity, (b) you represent and warrant that you have the authority to act on behalf of and bind that entity to this Agreement and (c) references to "you" or "your" in the remainder of this Agreement refers to both you (as an individual) and that entity.
|
| 27 |
+
|
| 28 |
+
2.2 Use
|
| 29 |
+
You may use, reproduce, modify, Distribute, perform or display any of the Gemma Services only in accordance with the terms of this Agreement, and must not violate (or encourage or permit anyone else to violate) any term of this Agreement.
|
| 30 |
+
|
| 31 |
+
Section 3: DISTRIBUTION AND RESTRICTIONS
|
| 32 |
+
3.1 Distribution and Redistribution
|
| 33 |
+
You may reproduce or Distribute copies of Gemma or Model Derivatives if you meet all of the following conditions:
|
| 34 |
+
|
| 35 |
+
You must include the use restrictions referenced in Section 3.2 as an enforceable provision in any agreement (e.g., license agreement, terms of use, etc.) governing the use and/or distribution of Gemma or Model Derivatives and you must provide notice to subsequent users you Distribute to that Gemma or Model Derivatives are subject to the use restrictions in Section 3.2.
|
| 36 |
+
You must provide all third party recipients of Gemma or Model Derivatives a copy of this Agreement.
|
| 37 |
+
You must cause any modified files to carry prominent notices stating that you modified the files.
|
| 38 |
+
All Distributions (other than through a Hosted Service) must be accompanied by a "Notice" text file that contains the following notice: "Gemma is provided under and subject to the Gemma Terms of Use found at ai.google.dev/gemma/terms".
|
| 39 |
+
You may add your own intellectual property statement to your modifications and, except as set forth in this Section, may provide additional or different terms and conditions for use, reproduction, or Distribution of your modifications, or for any such Model Derivatives as a whole, provided your use, reproduction, modification, Distribution, performance, and display of Gemma otherwise complies with the terms and conditions of this Agreement. Any additional or different terms and conditions you impose must not conflict with the terms of this Agreement.
|
| 40 |
+
|
| 41 |
+
3.2 Use Restrictions
|
| 42 |
+
You must not use any of the Gemma Services:
|
| 43 |
+
|
| 44 |
+
for the restricted uses set forth in the Gemma Prohibited Use Policy at ai.google.dev/gemma/prohibited_use_policy ("Prohibited Use Policy"), which is hereby incorporated by reference into this Agreement; or
|
| 45 |
+
in violation of applicable laws and regulations.
|
| 46 |
+
To the maximum extent permitted by law, Google reserves the right to restrict (remotely or otherwise) usage of any of the Gemma Services that Google reasonably believes are in violation of this Agreement.
|
| 47 |
+
|
| 48 |
+
3.3 Generated Output
|
| 49 |
+
Google claims no rights in Outputs you generate using Gemma. You and your users are solely responsible for Outputs and their subsequent uses.
|
| 50 |
+
|
| 51 |
+
Section 4: ADDITIONAL PROVISIONS
|
| 52 |
+
4.1 Updates
|
| 53 |
+
Google may update Gemma from time to time.
|
| 54 |
+
|
| 55 |
+
4.2 Trademarks
|
| 56 |
+
Nothing in this Agreement grants you any rights to use Google's trademarks, trade names, logos or to otherwise suggest endorsement or misrepresent the relationship between you and Google. Google reserves any rights not expressly granted herein.
|
| 57 |
+
|
| 58 |
+
4.3 DISCLAIMER OF WARRANTY
|
| 59 |
+
UNLESS REQUIRED BY APPLICABLE LAW, THE GEMMA SERVICES, AND OUTPUTS, ARE PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING ANY WARRANTIES OR CONDITIONS OF TITLE, NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. YOU ARE SOLELY RESPONSIBLE FOR DETERMINING THE APPROPRIATENESS OF USING, REPRODUCING, MODIFYING, PERFORMING, DISPLAYING OR DISTRIBUTING ANY OF THE GEMMA SERVICES OR OUTPUTS AND ASSUME ANY AND ALL RISKS ASSOCIATED WITH YOUR USE OR DISTRIBUTION OF ANY OF THE GEMMA SERVICES OR OUTPUTS AND YOUR EXERCISE OF RIGHTS AND PERMISSIONS UNDER THIS AGREEMENT.
|
| 60 |
+
|
| 61 |
+
4.4 LIMITATION OF LIABILITY
|
| 62 |
+
TO THE FULLEST EXTENT PERMITTED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL THEORY, WHETHER IN TORT (INCLUDING NEGLIGENCE), PRODUCT LIABILITY, CONTRACT, OR OTHERWISE, UNLESS REQUIRED BY APPLICABLE LAW, SHALL GOOGLE OR ITS AFFILIATES BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY DIRECT, INDIRECT, SPECIAL, INCIDENTAL, EXEMPLARY, CONSEQUENTIAL, OR PUNITIVE DAMAGES, OR LOST PROFITS OF ANY KIND ARISING FROM THIS AGREEMENT OR RELATED TO, ANY OF THE GEMMA SERVICES OR OUTPUTS EVEN IF GOOGLE OR ITS AFFILIATES HAVE BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGES.
|
| 63 |
+
|
| 64 |
+
4.5 Term, Termination, and Survival
|
| 65 |
+
The term of this Agreement will commence upon your acceptance of this Agreement (including acceptance by your use, modification, or Distribution, reproduction, performance or display of any portion or element of the Gemma Services) and will continue in full force and effect until terminated in accordance with the terms of this Agreement. Google may terminate this Agreement if you are in breach of any term of this Agreement. Upon termination of this Agreement, you must delete and cease use and Distribution of all copies of Gemma and Model Derivatives in your possession or control. Sections 1, 2.1, 3.3, 4.2 to 4.9 shall survive the termination of this Agreement.
|
| 66 |
+
|
| 67 |
+
4.6 Governing Law and Jurisdiction
|
| 68 |
+
This Agreement will be governed by the laws of the State of California without regard to choice of law principles. The UN Convention on Contracts for the International Sale of Goods does not apply to this Agreement. The state and federal courts of Santa Clara County, California shall have exclusive jurisdiction of any dispute arising out of this Agreement.
|
| 69 |
+
|
| 70 |
+
4.7 Severability
|
| 71 |
+
If any provision of this Agreement is held to be invalid, illegal or unenforceable, the remaining provisions shall be unaffected thereby and remain valid as if such provision had not been set forth herein.
|
| 72 |
+
|
| 73 |
+
4.8 Entire Agreement
|
| 74 |
+
This Agreement states all the terms agreed between the parties and supersedes all other agreements between the parties as of the date of acceptance relating to its subject matter.
|
| 75 |
+
|
| 76 |
+
4.9 No Waiver
|
| 77 |
+
Google will not be treated as having waived any rights by not exercising (or delaying the exercise of) any rights under this Agreement.
|
Ovis/docs/license/LLAMA3_LICENSE
ADDED
|
@@ -0,0 +1,84 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
META LLAMA 3 COMMUNITY LICENSE AGREEMENT
|
| 2 |
+
|
| 3 |
+
Meta Llama 3 Version Release Date: April 18, 2024
|
| 4 |
+
“Agreement” means the terms and conditions for use, reproduction, distribution and modification of the Llama Materials set forth herein.
|
| 5 |
+
|
| 6 |
+
“Documentation” means the specifications, manuals and documentation accompanying Meta Llama 3 distributed by Meta at https://llama.meta.com/get-started/.
|
| 7 |
+
|
| 8 |
+
“Licensee” or “you” means you, or your employer or any other person or entity (if you are entering into this Agreement on such person or entity’s behalf), of the age required under applicable laws, rules or regulations to provide legal consent and that has legal authority to bind your employer or such other person or entity if you are entering in this Agreement on their behalf.
|
| 9 |
+
|
| 10 |
+
“Meta Llama 3” means the foundational large language models and software and algorithms, including machine-learning model code, trained model weights, inference-enabling code, training-enabling code, fine-tuning enabling code and other elements of the foregoing distributed by Meta at https://llama.meta.com/llama-downloads.
|
| 11 |
+
|
| 12 |
+
“Llama Materials” means, collectively, Meta’s proprietary Meta Llama 3 and Documentation (and any portion thereof) made available under this Agreement.
|
| 13 |
+
|
| 14 |
+
“Meta” or “we” means Meta Platforms Ireland Limited (if you are located in or, if you are an entity, your principal place of business is in the EEA or Switzerland) and Meta Platforms, Inc. (if you are located outside of the EEA or Switzerland).
|
| 15 |
+
|
| 16 |
+
By clicking “I Accept” below or by using or distributing any portion or element of the Llama Materials, you agree to be bound by this Agreement.
|
| 17 |
+
|
| 18 |
+
1. License Rights and Redistribution.
|
| 19 |
+
|
| 20 |
+
a. Grant of Rights. You are granted a non-exclusive, worldwide, non-transferable and royalty-free limited license under Meta’s intellectual property or other rights owned by Meta embodied in the Llama Materials to use, reproduce, distribute, copy, create derivative works of, and make modifications to the Llama Materials.
|
| 21 |
+
b. Redistribution and Use.
|
| 22 |
+
i. If you distribute or make available the Llama Materials (or any derivative works thereof), or a product or service that uses any of them, including another AI model, you shall (A) provide a copy of this Agreement with any such Llama Materials; and (B) prominently display “Built with Meta Llama 3” on a related website, user interface, blogpost, about page, or product documentation. If you use the Llama Materials to create, train, fine tune, or otherwise improve an AI model, which is distributed or made available, you shall also include “Llama 3” at the beginning of any such AI model name.
|
| 23 |
+
ii. If you receive Llama Materials, or any derivative works thereof, from a Licensee as part of an integrated end user product, then Section 2 of this Agreement will not apply to you.
|
| 24 |
+
iii. You must retain in all copies of the Llama Materials that you distribute the following attribution notice within a “Notice” text file distributed as a part of such copies: “Meta Llama 3 is licensed under the Meta Llama 3 Community License, Copyright © Meta Platforms, Inc. All Rights Reserved.”
|
| 25 |
+
iv. Your use of the Llama Materials must comply with applicable laws and regulations (including trade compliance laws and regulations) and adhere to the Acceptable Use Policy for the Llama Materials (available at https://llama.meta.com/llama3/use-policy), which is hereby incorporated by reference into this Agreement.
|
| 26 |
+
v. You will not use the Llama Materials or any output or results of the Llama Materials to improve any other large language model (excluding Meta Llama 3 or derivative works thereof).
|
| 27 |
+
|
| 28 |
+
2. Additional Commercial Terms. If, on the Meta Llama 3 version release date, the monthly active users of the products or services made available by or for Licensee, or Licensee’s affiliates, is greater than 700 million monthly active users in the preceding calendar month, you must request a license from Meta, which Meta may grant to you in its sole discretion, and you are not authorized to exercise any of the rights under this Agreement unless or until Meta otherwise expressly grants you such rights.
|
| 29 |
+
|
| 30 |
+
3. Disclaimer of Warranty. UNLESS REQUIRED BY APPLICABLE LAW, THE LLAMA MATERIALS AND ANY OUTPUT AND RESULTS THEREFROM ARE PROVIDED ON AN “AS IS” BASIS, WITHOUT WARRANTIES OF ANY KIND, AND META DISCLAIMS ALL WARRANTIES OF ANY KIND, BOTH EXPRESS AND IMPLIED, INCLUDING, WITHOUT LIMITATION, ANY WARRANTIES OF TITLE, NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. YOU ARE SOLELY RESPONSIBLE FOR DETERMINING THE APPROPRIATENESS OF USING OR REDISTRIBUTING THE LLAMA MATERIALS AND ASSUME ANY RISKS ASSOCIATED WITH YOUR USE OF THE LLAMA MATERIALS AND ANY OUTPUT AND RESULTS.
|
| 31 |
+
|
| 32 |
+
4. Limitation of Liability. IN NO EVENT WILL META OR ITS AFFILIATES BE LIABLE UNDER ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, TORT, NEGLIGENCE, PRODUCTS LIABILITY, OR OTHERWISE, ARISING OUT OF THIS AGREEMENT, FOR ANY LOST PROFITS OR ANY INDIRECT, SPECIAL, CONSEQUENTIAL, INCIDENTAL, EXEMPLARY OR PUNITIVE DAMAGES, EVEN IF META OR ITS AFFILIATES HAVE BEEN ADVISED OF THE POSSIBILITY OF ANY OF THE FOREGOING.
|
| 33 |
+
|
| 34 |
+
5. Intellectual Property.
|
| 35 |
+
a. No trademark licenses are granted under this Agreement, and in connection with the Llama Materials, neither Meta nor Licensee may use any name or mark owned by or associated with the other or any of its affiliates, except as required for reasonable and customary use in describing and redistributing the Llama Materials or as set forth in this Section 5(a). Meta hereby grants you a license to use “Llama 3” (the “Mark”) solely as required to comply with the last sentence of Section 1.b.i. You will comply with Meta’s brand guidelines (currently accessible at https://about.meta.com/brand/resources/meta/company-brand/ ). All goodwill arising out of your use of the Mark will inure to the benefit of Meta.
|
| 36 |
+
b. Subject to Meta’s ownership of Llama Materials and derivatives made by or for Meta, with respect to any derivative works and modifications of the Llama Materials that are made by you, as between you and Meta, you are and will be the owner of such derivative works and modifications.
|
| 37 |
+
c. If you institute litigation or other proceedings against Meta or any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Llama Materials or Meta Llama 3 outputs or results, or any portion of any of the foregoing, constitutes infringement of intellectual property or other rights owned or licensable by you, then any licenses granted to you under this Agreement shall terminate as of the date such litigation or claim is filed or instituted. You will indemnify and hold harmless Meta from and against any claim by any third party arising out of or related to your use or distribution of the Llama Materials.
|
| 38 |
+
|
| 39 |
+
6. Term and Termination. The term of this Agreement will commence upon your acceptance of this Agreement or access to the Llama Materials and will continue in full force and effect until terminated in accordance with the terms and conditions herein. Meta may terminate this Agreement if you are in breach of any term or condition of this Agreement. Upon termination of this Agreement, you shall delete and cease use of the Llama Materials. Sections 3, 4 and 7 shall survive the termination of this Agreement.
|
| 40 |
+
|
| 41 |
+
7. Governing Law and Jurisdiction. This Agreement will be governed and construed under the laws of the State of California without regard to choice of law principles, and the UN Convention on Contracts for the International Sale of Goods does not apply to this Agreement. The courts of California shall have exclusive jurisdiction of any dispute arising out of this Agreement.
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
Meta Llama 3 Acceptable Use Policy
|
| 45 |
+
Meta is committed to promoting safe and fair use of its tools and features, including Meta Llama 3. If you access or use Meta Llama 3, you agree to this Acceptable Use Policy (“Policy”). The most recent copy of this policy can be found at https://llama.meta.com/llama3/use-policy
|
| 46 |
+
Prohibited Uses
|
| 47 |
+
We want everyone to use Meta Llama 3 safely and responsibly. You agree you will not use, or allow others to use, Meta Llama 3 to:
|
| 48 |
+
1. Violate the law or others’ rights, including to:
|
| 49 |
+
a. Engage in, promote, generate, contribute to, encourage, plan, incite, or further illegal or unlawful activity or content, such as:
|
| 50 |
+
i. Violence or terrorism
|
| 51 |
+
ii. Exploitation or harm to children, including the solicitation, creation, acquisition, or dissemination of child exploitative content or failure to report Child Sexual Abuse Material
|
| 52 |
+
iii. Human trafficking, exploitation, and sexual violence
|
| 53 |
+
iv. The illegal distribution of information or materials to minors, including obscene materials, or failure to employ legally required age-gating in connection with such information or materials.
|
| 54 |
+
v. Sexual solicitation
|
| 55 |
+
vi. Any other criminal activity
|
| 56 |
+
b. Engage in, promote, incite, or facilitate the harassment, abuse, threatening, or bullying of individuals or groups of individuals
|
| 57 |
+
c. Engage in, promote, incite, or facilitate discrimination or other unlawful or harmful conduct in the provision of employment, employment benefits, credit, housing, other economic benefits, or other essential goods and services
|
| 58 |
+
d. Engage in the unauthorized or unlicensed practice of any profession including, but not limited to, financial, legal, medical/health, or related professional practices
|
| 59 |
+
e. Collect, process, disclose, generate, or infer health, demographic, or other sensitive personal or private information about individuals without rights and consents required by applicable laws
|
| 60 |
+
f. Engage in or facilitate any action or generate any content that infringes, misappropriates, or otherwise violates any third-party rights, including the outputs or results of any products or services using the Llama Materials
|
| 61 |
+
g. Create, generate, or facilitate the creation of malicious code, malware, computer viruses or do anything else that could disable, overburden, interfere with or impair the proper working, integrity, operation or appearance of a website or computer system
|
| 62 |
+
|
| 63 |
+
2. Engage in, promote, incite, facilitate, or assist in the planning or development of activities that present a risk of death or bodily harm to individuals, including use of Meta Llama 3 related to the following:
|
| 64 |
+
a. Military, warfare, nuclear industries or applications, espionage, use for materials or activities that are subject to the International Traffic Arms Regulations (ITAR) maintained by the United States Department of State
|
| 65 |
+
b. Guns and illegal weapons (including weapon development)
|
| 66 |
+
c. Illegal drugs and regulated/controlled substances
|
| 67 |
+
d. Operation of critical infrastructure, transportation technologies, or heavy machinery
|
| 68 |
+
e. Self-harm or harm to others, including suicide, cutting, and eating disorders
|
| 69 |
+
f. Any content intended to incite or promote violence, abuse, or any infliction of bodily harm to an individual
|
| 70 |
+
|
| 71 |
+
3. Intentionally deceive or mislead others, including use of Meta Llama 3 related to the following:
|
| 72 |
+
a. Generating, promoting, or furthering fraud or the creation or promotion of disinformation
|
| 73 |
+
b. Generating, promoting, or furthering defamatory content, including the creation of defamatory statements, images, or other content
|
| 74 |
+
c. Generating, promoting, or further distributing spam
|
| 75 |
+
d. Impersonating another individual without consent, authorization, or legal right
|
| 76 |
+
e. Representing that the use of Meta Llama 3 or outputs are human-generated
|
| 77 |
+
f. Generating or facilitating false online engagement, including fake reviews and other means of fake online engagement
|
| 78 |
+
g. Fail to appropriately disclose to end users any known dangers of your AI system
|
| 79 |
+
|
| 80 |
+
Please report any violation of this Policy, software “bug,” or other problems that could lead to a violation of this Policy through one of the following means:
|
| 81 |
+
* Reporting issues with the model: https://github.com/meta-llama/llama3
|
| 82 |
+
* Reporting risky content generated by the model: developers.facebook.com/llama_output_feedback
|
| 83 |
+
* Reporting bugs and security concerns: facebook.com/whitehat/info
|
| 84 |
+
* Reporting violations of the Acceptable Use Policy or unlicensed uses of Meta Llama 3: LlamaUseReport@meta.com
|
Ovis/ovis/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (206 Bytes). View file
|
|
|
Ovis/ovis/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (248 Bytes). View file
|
|
|
Ovis/ovis/model/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (388 Bytes). View file
|
|
|
Ovis/ovis/model/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (446 Bytes). View file
|
|
|
Ovis/ovis/model/__pycache__/configuration_ovis.cpython-311.pyc
ADDED
|
Binary file (2.53 kB). View file
|
|
|
Ovis/ovis/model/__pycache__/modeling_ovis.cpython-311.pyc
ADDED
|
Binary file (29.4 kB). View file
|
|
|
Ovis/ovis/model/configuration_ovis.py
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Union, Optional
|
| 2 |
+
|
| 3 |
+
from transformers import PretrainedConfig, AutoConfig
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class OvisConfig(PretrainedConfig):
|
| 7 |
+
model_type = "ovis"
|
| 8 |
+
|
| 9 |
+
def __init__(
|
| 10 |
+
self,
|
| 11 |
+
llm_config: Optional[Union[PretrainedConfig, dict]] = None,
|
| 12 |
+
visual_tokenizer_config: Optional[Union[PretrainedConfig, dict]] = None,
|
| 13 |
+
multimodal_max_length=8192,
|
| 14 |
+
hidden_size=None,
|
| 15 |
+
conversation_formatter_class=None,
|
| 16 |
+
llm_attn_implementation=None,
|
| 17 |
+
disable_tie_weight=False,
|
| 18 |
+
**kwargs
|
| 19 |
+
):
|
| 20 |
+
super().__init__(**kwargs)
|
| 21 |
+
if llm_config is not None:
|
| 22 |
+
assert isinstance(llm_config, (PretrainedConfig, dict)), \
|
| 23 |
+
f"expect `llm_config` to be instance of PretrainedConfig or dict, but got {type(llm_config)} type"
|
| 24 |
+
if not isinstance(llm_config, PretrainedConfig):
|
| 25 |
+
model_type = llm_config['model_type']
|
| 26 |
+
llm_config.pop('model_type')
|
| 27 |
+
llm_config = AutoConfig.for_model(model_type, **llm_config)
|
| 28 |
+
self.llm_config = llm_config
|
| 29 |
+
if visual_tokenizer_config is not None:
|
| 30 |
+
assert isinstance(visual_tokenizer_config, (PretrainedConfig, dict)), \
|
| 31 |
+
f"expect `visual_tokenizer_config` to be instance of PretrainedConfig or dict, but got {type(visual_tokenizer_config)} type"
|
| 32 |
+
if not isinstance(visual_tokenizer_config, PretrainedConfig):
|
| 33 |
+
model_type = visual_tokenizer_config['model_type']
|
| 34 |
+
visual_tokenizer_config.pop('model_type')
|
| 35 |
+
visual_tokenizer_config = AutoConfig.for_model(model_type, **visual_tokenizer_config)
|
| 36 |
+
self.visual_tokenizer_config = visual_tokenizer_config
|
| 37 |
+
self.multimodal_max_length = multimodal_max_length
|
| 38 |
+
self.hidden_size = hidden_size
|
| 39 |
+
self.conversation_formatter_class = conversation_formatter_class
|
| 40 |
+
self.llm_attn_implementation = llm_attn_implementation
|
| 41 |
+
self.disable_tie_weight = disable_tie_weight
|
Ovis/ovis/model/conversation_formatter.py
ADDED
|
@@ -0,0 +1,233 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from abc import ABC, abstractmethod
|
| 2 |
+
from typing import List, Dict
|
| 3 |
+
|
| 4 |
+
from ovis.util.constants import IMAGE_TOKEN_ID, IGNORE_ID, IMAGE_TOKEN
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class ConversationFormatter(ABC):
|
| 8 |
+
support_tokenizer_types = None
|
| 9 |
+
|
| 10 |
+
def __init__(self, tokenizer):
|
| 11 |
+
tokenizer_type = type(tokenizer).__name__
|
| 12 |
+
assert tokenizer_type in self.support_tokenizer_types, \
|
| 13 |
+
f'Invalid tokenizer type, expected one from `{self.support_tokenizer_types}`, but got `{tokenizer_type}`'
|
| 14 |
+
self.tokenizer = tokenizer
|
| 15 |
+
self.image_token = IMAGE_TOKEN
|
| 16 |
+
self.image_token_id = IMAGE_TOKEN_ID
|
| 17 |
+
self.ignore_id = IGNORE_ID
|
| 18 |
+
|
| 19 |
+
def _tokenize_with_image_symbol(self, text):
|
| 20 |
+
text_chunks = [self.tokenizer(chunk, add_special_tokens=False).input_ids for chunk in
|
| 21 |
+
text.split(self.image_token)]
|
| 22 |
+
token_ids = []
|
| 23 |
+
num_chuck = len(text_chunks)
|
| 24 |
+
for i, chunk in enumerate(text_chunks):
|
| 25 |
+
token_ids.extend(chunk)
|
| 26 |
+
if i < num_chuck - 1:
|
| 27 |
+
token_ids.append(self.image_token_id)
|
| 28 |
+
return token_ids
|
| 29 |
+
|
| 30 |
+
@abstractmethod
|
| 31 |
+
def format(self, conversations: List[Dict], generation_preface=None):
|
| 32 |
+
pass
|
| 33 |
+
|
| 34 |
+
@abstractmethod
|
| 35 |
+
def format_query(self, query, generation_preface=""):
|
| 36 |
+
pass
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
class QwenConversationFormatter(ConversationFormatter):
|
| 40 |
+
support_tokenizer_types = ['QWenTokenizer', 'Qwen2TokenizerFast']
|
| 41 |
+
|
| 42 |
+
def __init__(self, tokenizer):
|
| 43 |
+
super().__init__(tokenizer)
|
| 44 |
+
self.from2role = {
|
| 45 |
+
"system": "<|im_start|>system\n",
|
| 46 |
+
"human": "<|im_start|>user\n",
|
| 47 |
+
"gpt": "<|im_start|>assistant\n",
|
| 48 |
+
}
|
| 49 |
+
self.gpt_token_num = None
|
| 50 |
+
self.im_end = "<|im_end|>\n"
|
| 51 |
+
self.default_system_prompt = "You are a helpful assistant."
|
| 52 |
+
|
| 53 |
+
def format(self, conversations: List[Dict], generation_preface=None):
|
| 54 |
+
if self.gpt_token_num is None:
|
| 55 |
+
self.gpt_token_num = len(self.tokenizer(self.from2role["gpt"], add_special_tokens=False).input_ids)
|
| 56 |
+
|
| 57 |
+
if conversations[0]["from"] != "system":
|
| 58 |
+
conversations.insert(0, {
|
| 59 |
+
"from": "system",
|
| 60 |
+
"value": self.default_system_prompt
|
| 61 |
+
})
|
| 62 |
+
|
| 63 |
+
if generation_preface is not None:
|
| 64 |
+
conversations.append({
|
| 65 |
+
"from": "gpt",
|
| 66 |
+
"value": generation_preface
|
| 67 |
+
})
|
| 68 |
+
|
| 69 |
+
prompt = ""
|
| 70 |
+
input_ids = []
|
| 71 |
+
labels = []
|
| 72 |
+
num_conversation = len(conversations)
|
| 73 |
+
for i, conversation in enumerate(conversations):
|
| 74 |
+
frm = conversation["from"]
|
| 75 |
+
role = self.from2role[frm]
|
| 76 |
+
message = conversation["value"]
|
| 77 |
+
text = role + message
|
| 78 |
+
if i < num_conversation - 1 or generation_preface is None:
|
| 79 |
+
text += self.im_end
|
| 80 |
+
prompt += text
|
| 81 |
+
token_ids = self._tokenize_with_image_symbol(text)
|
| 82 |
+
input_ids.extend(token_ids)
|
| 83 |
+
label_ids = [self.ignore_id] * len(token_ids)
|
| 84 |
+
if frm == "gpt" and generation_preface is None:
|
| 85 |
+
# learning `\n` following `im_end` is meaningless, so the last `\n` token is ignored in label
|
| 86 |
+
label_ids[self.gpt_token_num:-1] = token_ids[self.gpt_token_num:-1]
|
| 87 |
+
labels.extend(label_ids)
|
| 88 |
+
|
| 89 |
+
assert self._tokenize_with_image_symbol(prompt) == input_ids
|
| 90 |
+
assert len(input_ids) == len(labels)
|
| 91 |
+
|
| 92 |
+
return prompt, input_ids, labels
|
| 93 |
+
|
| 94 |
+
def format_query(self, query, generation_preface=""):
|
| 95 |
+
prompt, input_ids, _ = self.format([{
|
| 96 |
+
"from": "human",
|
| 97 |
+
"value": query
|
| 98 |
+
}], generation_preface=generation_preface)
|
| 99 |
+
|
| 100 |
+
return prompt, input_ids
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
class Llama3ConversationFormatter(ConversationFormatter):
|
| 104 |
+
support_tokenizer_types = ['PreTrainedTokenizerFast']
|
| 105 |
+
|
| 106 |
+
def __init__(self, tokenizer):
|
| 107 |
+
super().__init__(tokenizer)
|
| 108 |
+
self.from2role = {
|
| 109 |
+
"system": "<|start_header_id|>system<|end_header_id|>\n\n",
|
| 110 |
+
"human": "<|start_header_id|>user<|end_header_id|>\n\n",
|
| 111 |
+
"gpt": "<|start_header_id|>assistant<|end_header_id|>\n\n",
|
| 112 |
+
}
|
| 113 |
+
self.gpt_token_num = None
|
| 114 |
+
self.im_end = "<|eot_id|>"
|
| 115 |
+
self.default_system_prompt = "You are a helpful and honest multimodal assistant."
|
| 116 |
+
self.bos_token = "<|begin_of_text|>"
|
| 117 |
+
self.bos_token_ids = None
|
| 118 |
+
|
| 119 |
+
def format(self, conversations: List[Dict], generation_preface=None):
|
| 120 |
+
if self.gpt_token_num is None:
|
| 121 |
+
self.gpt_token_num = len(self.tokenizer(self.from2role["gpt"], add_special_tokens=False).input_ids)
|
| 122 |
+
|
| 123 |
+
if self.bos_token_ids is None:
|
| 124 |
+
self.bos_token_ids = self.tokenizer(self.bos_token, add_special_tokens=False).input_ids
|
| 125 |
+
|
| 126 |
+
if conversations[0]["from"] != "system":
|
| 127 |
+
conversations.insert(0, {
|
| 128 |
+
"from": "system",
|
| 129 |
+
"value": self.default_system_prompt
|
| 130 |
+
})
|
| 131 |
+
|
| 132 |
+
if generation_preface is not None:
|
| 133 |
+
conversations.append({
|
| 134 |
+
"from": "gpt",
|
| 135 |
+
"value": generation_preface
|
| 136 |
+
})
|
| 137 |
+
|
| 138 |
+
prompt = "" + self.bos_token
|
| 139 |
+
input_ids = [] + self.bos_token_ids
|
| 140 |
+
labels = [] + [IGNORE_ID] * len(input_ids)
|
| 141 |
+
num_conversation = len(conversations)
|
| 142 |
+
for i, conversation in enumerate(conversations):
|
| 143 |
+
frm = conversation["from"]
|
| 144 |
+
role = self.from2role[frm]
|
| 145 |
+
message = conversation["value"].strip()
|
| 146 |
+
text = role + message
|
| 147 |
+
if i < num_conversation - 1 or generation_preface is None:
|
| 148 |
+
text += self.im_end
|
| 149 |
+
prompt += text
|
| 150 |
+
token_ids = self._tokenize_with_image_symbol(text)
|
| 151 |
+
input_ids.extend(token_ids)
|
| 152 |
+
label_ids = [self.ignore_id] * len(token_ids)
|
| 153 |
+
if frm == "gpt":
|
| 154 |
+
label_ids[self.gpt_token_num:] = token_ids[self.gpt_token_num:]
|
| 155 |
+
labels.extend(label_ids)
|
| 156 |
+
|
| 157 |
+
assert self._tokenize_with_image_symbol(prompt) == input_ids
|
| 158 |
+
assert len(input_ids) == len(labels)
|
| 159 |
+
|
| 160 |
+
return prompt, input_ids, labels
|
| 161 |
+
|
| 162 |
+
def format_query(self, query, generation_preface=""):
|
| 163 |
+
prompt, input_ids, _ = self.format([{
|
| 164 |
+
"from": "human",
|
| 165 |
+
"value": query
|
| 166 |
+
}], generation_preface=generation_preface)
|
| 167 |
+
|
| 168 |
+
return prompt, input_ids
|
| 169 |
+
|
| 170 |
+
|
| 171 |
+
class GemmaConversationFormatter(ConversationFormatter):
|
| 172 |
+
support_tokenizer_types = ['GemmaTokenizer', 'GemmaTokenizerFast']
|
| 173 |
+
|
| 174 |
+
def __init__(self, tokenizer):
|
| 175 |
+
super().__init__(tokenizer)
|
| 176 |
+
# Gemma does not support system prompt
|
| 177 |
+
self.from2role = {
|
| 178 |
+
"human": "<start_of_turn>user\n",
|
| 179 |
+
"gpt": "<start_of_turn>model\n",
|
| 180 |
+
}
|
| 181 |
+
self.gpt_token_num = None
|
| 182 |
+
self.im_end = "<end_of_turn>\n"
|
| 183 |
+
self.bos_token = "<bos>"
|
| 184 |
+
self.bos_token_ids = None
|
| 185 |
+
|
| 186 |
+
def format(self, conversations: List[Dict], generation_preface=None):
|
| 187 |
+
if self.gpt_token_num is None:
|
| 188 |
+
self.gpt_token_num = len(self.tokenizer(self.from2role["gpt"], add_special_tokens=False).input_ids)
|
| 189 |
+
|
| 190 |
+
if self.bos_token_ids is None:
|
| 191 |
+
self.bos_token_ids = self.tokenizer(self.bos_token, add_special_tokens=False).input_ids
|
| 192 |
+
|
| 193 |
+
if conversations[0]["from"] == "system":
|
| 194 |
+
raise ValueError("Gemma does not support system prompt")
|
| 195 |
+
|
| 196 |
+
if generation_preface is not None:
|
| 197 |
+
conversations.append({
|
| 198 |
+
"from": "gpt",
|
| 199 |
+
"value": generation_preface
|
| 200 |
+
})
|
| 201 |
+
|
| 202 |
+
prompt = "" + self.bos_token
|
| 203 |
+
input_ids = [] + self.bos_token_ids
|
| 204 |
+
labels = [] + [IGNORE_ID] * len(input_ids)
|
| 205 |
+
num_conversation = len(conversations)
|
| 206 |
+
for i, conversation in enumerate(conversations):
|
| 207 |
+
frm = conversation["from"]
|
| 208 |
+
role = self.from2role[frm]
|
| 209 |
+
message = conversation["value"].strip()
|
| 210 |
+
text = role + message
|
| 211 |
+
if i < num_conversation - 1 or generation_preface is None:
|
| 212 |
+
text += self.im_end
|
| 213 |
+
prompt += text
|
| 214 |
+
token_ids = self._tokenize_with_image_symbol(text)
|
| 215 |
+
input_ids.extend(token_ids)
|
| 216 |
+
label_ids = [self.ignore_id] * len(token_ids)
|
| 217 |
+
if frm == "gpt":
|
| 218 |
+
# learning `\n` following `im_end` is meaningless, so the last `\n` token is ignored in label
|
| 219 |
+
label_ids[self.gpt_token_num:-1] = token_ids[self.gpt_token_num:-1]
|
| 220 |
+
labels.extend(label_ids)
|
| 221 |
+
|
| 222 |
+
assert self._tokenize_with_image_symbol(prompt) == input_ids
|
| 223 |
+
assert len(input_ids) == len(labels)
|
| 224 |
+
|
| 225 |
+
return prompt, input_ids, labels
|
| 226 |
+
|
| 227 |
+
def format_query(self, query, generation_preface=""):
|
| 228 |
+
prompt, input_ids, _ = self.format([{
|
| 229 |
+
"from": "human",
|
| 230 |
+
"value": query
|
| 231 |
+
}], generation_preface=generation_preface)
|
| 232 |
+
|
| 233 |
+
return prompt, input_ids
|
Ovis/ovis/model/visual_tokenizer/__pycache__/base_visual_tokenizer.cpython-310.pyc
ADDED
|
Binary file (10.1 kB). View file
|
|
|
Ovis/ovis/model/visual_tokenizer/__pycache__/base_visual_tokenizer.cpython-311.pyc
ADDED
|
Binary file (18.3 kB). View file
|
|
|
Ovis/ovis/model/visual_tokenizer/__pycache__/clip_visual_tokenizer.cpython-310.pyc
ADDED
|
Binary file (2.03 kB). View file
|
|
|
Ovis/ovis/model/visual_tokenizer/__pycache__/clip_visual_tokenizer.cpython-311.pyc
ADDED
|
Binary file (3.03 kB). View file
|
|
|
Ovis/ovis/model/visual_tokenizer/__pycache__/siglip_visual_tokenizer.cpython-310.pyc
ADDED
|
Binary file (2.05 kB). View file
|
|
|
Ovis/ovis/model/visual_tokenizer/__pycache__/siglip_visual_tokenizer.cpython-311.pyc
ADDED
|
Binary file (3.06 kB). View file
|
|
|
Ovis/ovis/serve/runner.py
ADDED
|
@@ -0,0 +1,105 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from dataclasses import field, dataclass
|
| 2 |
+
from typing import Optional, Union, List
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
from PIL import Image
|
| 6 |
+
|
| 7 |
+
from ovis.model.modeling_ovis import Ovis
|
| 8 |
+
from ovis.util.constants import IMAGE_TOKEN
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
@dataclass
|
| 12 |
+
class RunnerArguments:
|
| 13 |
+
model_path: str
|
| 14 |
+
max_new_tokens: int = field(default=512)
|
| 15 |
+
do_sample: bool = field(default=False)
|
| 16 |
+
top_p: Optional[float] = field(default=None)
|
| 17 |
+
top_k: Optional[int] = field(default=None)
|
| 18 |
+
temperature: Optional[float] = field(default=None)
|
| 19 |
+
max_partition: int = field(default=9)
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
class OvisRunner:
|
| 23 |
+
def __init__(self, args: RunnerArguments):
|
| 24 |
+
self.model_path = args.model_path
|
| 25 |
+
self.dtype = torch.bfloat16
|
| 26 |
+
self.device = torch.cuda.current_device()
|
| 27 |
+
self.dtype = torch.bfloat16
|
| 28 |
+
self.model = Ovis.from_pretrained(self.model_path, torch_dtype=self.dtype, multimodal_max_length=8192)
|
| 29 |
+
self.model = self.model.eval().to(device=self.device)
|
| 30 |
+
self.eos_token_id = self.model.generation_config.eos_token_id
|
| 31 |
+
self.text_tokenizer = self.model.get_text_tokenizer()
|
| 32 |
+
self.pad_token_id = self.text_tokenizer.pad_token_id
|
| 33 |
+
self.visual_tokenizer = self.model.get_visual_tokenizer()
|
| 34 |
+
self.conversation_formatter = self.model.get_conversation_formatter()
|
| 35 |
+
self.image_placeholder = IMAGE_TOKEN
|
| 36 |
+
self.max_partition = args.max_partition
|
| 37 |
+
self.gen_kwargs = dict(
|
| 38 |
+
max_new_tokens=args.max_new_tokens,
|
| 39 |
+
do_sample=args.do_sample,
|
| 40 |
+
top_p=args.top_p,
|
| 41 |
+
top_k=args.top_k,
|
| 42 |
+
temperature=args.temperature,
|
| 43 |
+
repetition_penalty=None,
|
| 44 |
+
eos_token_id=self.eos_token_id,
|
| 45 |
+
pad_token_id=self.pad_token_id,
|
| 46 |
+
use_cache=True
|
| 47 |
+
)
|
| 48 |
+
|
| 49 |
+
def preprocess(self, inputs: List[Union[Image.Image, str]]):
|
| 50 |
+
# for single image and single text inputs, ensure image ahead
|
| 51 |
+
if len(inputs) == 2 and isinstance(inputs[0], str) and isinstance(inputs[1], Image.Image):
|
| 52 |
+
inputs = reversed(inputs)
|
| 53 |
+
|
| 54 |
+
# build query
|
| 55 |
+
query = ''
|
| 56 |
+
images = []
|
| 57 |
+
for data in inputs:
|
| 58 |
+
if isinstance(data, Image.Image):
|
| 59 |
+
query += self.image_placeholder + '\n'
|
| 60 |
+
images.append(data)
|
| 61 |
+
elif isinstance(data, str):
|
| 62 |
+
query += data.replace(self.image_placeholder, '')
|
| 63 |
+
elif data is not None:
|
| 64 |
+
raise RuntimeError(f'Invalid input type, expected `PIL.Image.Image` or `str`, but got {type(data)}')
|
| 65 |
+
|
| 66 |
+
# format conversation
|
| 67 |
+
prompt, input_ids, pixel_values = self.model.preprocess_inputs(
|
| 68 |
+
query, images, max_partition=self.max_partition)
|
| 69 |
+
attention_mask = torch.ne(input_ids, self.text_tokenizer.pad_token_id)
|
| 70 |
+
input_ids = input_ids.unsqueeze(0).to(device=self.device)
|
| 71 |
+
attention_mask = attention_mask.unsqueeze(0).to(device=self.device)
|
| 72 |
+
if pixel_values is not None:
|
| 73 |
+
pixel_values = [pixel_values.to(device=self.device, dtype=self.dtype)]
|
| 74 |
+
else:
|
| 75 |
+
pixel_values = [None]
|
| 76 |
+
|
| 77 |
+
return prompt, input_ids, attention_mask, pixel_values
|
| 78 |
+
|
| 79 |
+
def run(self, inputs: List[Union[Image.Image, str]]):
|
| 80 |
+
prompt, input_ids, attention_mask, pixel_values = self.preprocess(inputs)
|
| 81 |
+
output_ids = self.model.generate(
|
| 82 |
+
input_ids,
|
| 83 |
+
pixel_values=pixel_values,
|
| 84 |
+
attention_mask=attention_mask,
|
| 85 |
+
**self.gen_kwargs
|
| 86 |
+
)
|
| 87 |
+
output = self.text_tokenizer.decode(output_ids[0], skip_special_tokens=True)
|
| 88 |
+
input_token_len = input_ids.shape[1]
|
| 89 |
+
output_token_len = output_ids.shape[1]
|
| 90 |
+
response = dict(
|
| 91 |
+
prompt=prompt,
|
| 92 |
+
output=output,
|
| 93 |
+
prompt_tokens=input_token_len,
|
| 94 |
+
total_tokens=input_token_len + output_token_len
|
| 95 |
+
)
|
| 96 |
+
return response
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
if __name__ == '__main__':
|
| 100 |
+
runner_args = RunnerArguments(model_path='<model_path>')
|
| 101 |
+
runner = OvisRunner(runner_args)
|
| 102 |
+
image = Image.open('<image_path>')
|
| 103 |
+
text = '<prompt>'
|
| 104 |
+
response = runner.run([image, text])
|
| 105 |
+
print(response['output'])
|
Ovis/ovis/serve/server.py
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import os.path
|
| 3 |
+
|
| 4 |
+
import gradio as gr
|
| 5 |
+
from gradio.components import Textbox, Image
|
| 6 |
+
|
| 7 |
+
from ovis.serve.runner import RunnerArguments, OvisRunner
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class Server:
|
| 11 |
+
def __init__(self, runner: OvisRunner):
|
| 12 |
+
self.runner = runner
|
| 13 |
+
|
| 14 |
+
def __call__(self, image, text):
|
| 15 |
+
response = self.runner.run([image, text])
|
| 16 |
+
output = response["output"]
|
| 17 |
+
return output
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
if __name__ == '__main__':
|
| 21 |
+
parser = argparse.ArgumentParser(description='Ovis Server')
|
| 22 |
+
parser.add_argument('--model_path', type=str, required=True)
|
| 23 |
+
parser.add_argument('--flagging_dir', type=str, default=os.path.expanduser('~/ovis-flagged'))
|
| 24 |
+
parser.add_argument('--max_partition', type=int, default=9)
|
| 25 |
+
parser.add_argument('--port', type=int, required=True)
|
| 26 |
+
args = parser.parse_args()
|
| 27 |
+
|
| 28 |
+
os.makedirs(args.flagging_dir, exist_ok=True)
|
| 29 |
+
runner_args = RunnerArguments(
|
| 30 |
+
model_path=args.model_path,
|
| 31 |
+
max_partition=args.max_partition
|
| 32 |
+
)
|
| 33 |
+
demo = gr.Interface(
|
| 34 |
+
fn=Server(OvisRunner(runner_args)),
|
| 35 |
+
inputs=[Image(type='pil', label='image'),
|
| 36 |
+
Textbox(placeholder='Enter your text here...', label='prompt')],
|
| 37 |
+
outputs=gr.Markdown(),
|
| 38 |
+
title=args.model_path.split('/')[-1],
|
| 39 |
+
flagging_dir=args.flagging_dir
|
| 40 |
+
)
|
| 41 |
+
demo.launch(server_port=args.port)
|
Ovis/ovis/train/__init__.py
ADDED
|
File without changes
|
Ovis/ovis/train/arguments.py
ADDED
|
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from dataclasses import dataclass, field
|
| 2 |
+
from typing import Optional
|
| 3 |
+
|
| 4 |
+
import transformers
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
@dataclass
|
| 8 |
+
class ModelArguments:
|
| 9 |
+
llm_name_or_path: Optional[str] = field(default=None)
|
| 10 |
+
visual_tokenizer_type: str = field(default=None)
|
| 11 |
+
visual_vocab_size: int = field(default=8192)
|
| 12 |
+
visual_drop_cls_token: bool = field(default=False)
|
| 13 |
+
visual_tokenize_function: str = field(default='softmax')
|
| 14 |
+
visual_tau: float = field(default=1.0)
|
| 15 |
+
visual_depths: Optional[str] = field(default=None)
|
| 16 |
+
visual_hidden_stride: int = field(default=1)
|
| 17 |
+
multimodal_max_length: int = field(default=2048)
|
| 18 |
+
conversation_formatter_class: str = field(default=None)
|
| 19 |
+
pad_token_id: Optional[int] = field(default=None)
|
| 20 |
+
llm_attn_implementation: Optional[str] = field(default=None)
|
| 21 |
+
disable_tie_weight: bool = field(default=False)
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
@dataclass
|
| 25 |
+
class TrainingArguments(transformers.TrainingArguments):
|
| 26 |
+
dataset_names: Optional[str] = field(default=None) # a|b|c
|
| 27 |
+
dataset_info: Optional[str] = field(default='dataset_info_v1_6')
|
| 28 |
+
ovis_pretrained_path: Optional[str] = field(default=None)
|
| 29 |
+
visual_tokenizer_pretrained_path: Optional[str] = field(default=None)
|
| 30 |
+
caption_template: Optional[str] = field(default=None)
|
| 31 |
+
stage: Optional[int] = field(default=None)
|
| 32 |
+
train_modules: Optional[str] = field(default=None)
|
| 33 |
+
cache_dir: Optional[str] = field(default=None)
|
| 34 |
+
optim: str = field(default="adamw_torch")
|
| 35 |
+
visual_max_tau: float = field(default=5.0)
|
| 36 |
+
visual_min_tau: float = field(default=0.05)
|
| 37 |
+
save_safetensors: bool = field(default=True)
|
| 38 |
+
monitor_step: int = field(default=100)
|
| 39 |
+
vte_re_init: bool = field(default=False)
|
| 40 |
+
text_max_length: int = field(default=1024)
|
| 41 |
+
max_partitions: str = field(default="9|1|1")
|
| 42 |
+
|
| 43 |
+
def __post_init__(self):
|
| 44 |
+
if self.gradient_checkpointing:
|
| 45 |
+
self.gradient_checkpointing_kwargs = {"use_reentrant": False}
|
| 46 |
+
if self.stage < 3:
|
| 47 |
+
self.save_safetensors = False
|
| 48 |
+
super().__post_init__()
|
Ovis/ovis/train/callback.py
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import deepspeed
|
| 2 |
+
import torch
|
| 3 |
+
from transformers import TrainerCallback, TrainingArguments, TrainerState, TrainerControl
|
| 4 |
+
|
| 5 |
+
from ovis.util.constants import END_LINE, BEGIN_LINE
|
| 6 |
+
from ovis.util.utils import rank0_print
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class TuneTauCallback(TrainerCallback):
|
| 10 |
+
def on_step_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
|
| 11 |
+
visual_tokenizer = kwargs['model'].get_visual_tokenizer()
|
| 12 |
+
current_step = state.global_step
|
| 13 |
+
max_step = state.max_steps
|
| 14 |
+
ratio = current_step / max_step
|
| 15 |
+
visual_tokenizer.config.tau = args.visual_max_tau - (args.visual_max_tau - args.visual_min_tau) * ratio
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class MonitorCallback(TrainerCallback):
|
| 19 |
+
def _monitoring(self, model, step):
|
| 20 |
+
with torch.no_grad():
|
| 21 |
+
with deepspeed.zero.GatheredParameters(model.get_monitor_tensors().values()):
|
| 22 |
+
for k, v in model.get_monitor_tensors().items():
|
| 23 |
+
rank0_print(BEGIN_LINE)
|
| 24 |
+
rank0_print(f'{k} @ step {step} with sum: {v.sum().item()} and content: ')
|
| 25 |
+
rank0_print(v)
|
| 26 |
+
rank0_print(END_LINE)
|
| 27 |
+
|
| 28 |
+
def on_step_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
|
| 29 |
+
model = kwargs['model']
|
| 30 |
+
step = state.global_step
|
| 31 |
+
if step % args.monitor_step == 0 or step == 10: # monitor at step 10 for fast check
|
| 32 |
+
self._monitoring(model, step)
|
| 33 |
+
|
| 34 |
+
def on_epoch_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
|
| 35 |
+
model = kwargs['model']
|
| 36 |
+
step = state.global_step
|
| 37 |
+
self._monitoring(model, step)
|
Ovis/ovis/train/train.py
ADDED
|
@@ -0,0 +1,206 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import os
|
| 3 |
+
import pathlib
|
| 4 |
+
|
| 5 |
+
import deepspeed
|
| 6 |
+
import torch
|
| 7 |
+
import transformers
|
| 8 |
+
from deepspeed import get_accelerator
|
| 9 |
+
from torch.utils.data import ConcatDataset
|
| 10 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoModel, AutoConfig
|
| 11 |
+
from transformers import Trainer
|
| 12 |
+
from transformers.integrations.deepspeed import unset_hf_deepspeed_config, set_hf_deepspeed_config
|
| 13 |
+
|
| 14 |
+
from callback import TuneTauCallback, MonitorCallback
|
| 15 |
+
from ovis.model.configuration_ovis import OvisConfig
|
| 16 |
+
from ovis.model.modeling_ovis import Ovis
|
| 17 |
+
from ovis.train.arguments import ModelArguments, TrainingArguments
|
| 18 |
+
from ovis.train.dataset.caption_dataset import CaptionDataset
|
| 19 |
+
from ovis.train.dataset.conversation_dataset import ConversationDataset
|
| 20 |
+
from ovis.train.dataset.multimodal_dataset import DataCollatorForMultimodalDataset
|
| 21 |
+
from ovis.util.constants import BEGIN_LINE, END_LINE
|
| 22 |
+
from ovis.util.utils import smart_unit, rank0_print
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def train():
|
| 26 |
+
# parse args
|
| 27 |
+
parser = transformers.HfArgumentParser(
|
| 28 |
+
(ModelArguments, TrainingArguments))
|
| 29 |
+
model_args, training_args = parser.parse_args_into_dataclasses()
|
| 30 |
+
|
| 31 |
+
# save args to checkpoint dir
|
| 32 |
+
with training_args.main_process_first(local=False):
|
| 33 |
+
if training_args.process_index == 0:
|
| 34 |
+
def args2dict(args):
|
| 35 |
+
return {k: str(v) for k, v in args.__dict__.items()}
|
| 36 |
+
|
| 37 |
+
args_log = json.dumps(dict(
|
| 38 |
+
model_args=args2dict(model_args),
|
| 39 |
+
training_args=args2dict(training_args)
|
| 40 |
+
), ensure_ascii=False, indent=2)
|
| 41 |
+
print(args_log)
|
| 42 |
+
os.makedirs(training_args.output_dir, exist_ok=True)
|
| 43 |
+
with open(os.path.join(training_args.output_dir, 'model_training_args.json'), 'w',
|
| 44 |
+
encoding='utf-8') as f:
|
| 45 |
+
f.write(args_log + '\n')
|
| 46 |
+
|
| 47 |
+
# construct or load ovis model
|
| 48 |
+
if not training_args.ovis_pretrained_path: # construct model (S1)
|
| 49 |
+
# 1. construct ovis config
|
| 50 |
+
ovis_config = OvisConfig(
|
| 51 |
+
multimodal_max_length=model_args.multimodal_max_length,
|
| 52 |
+
conversation_formatter_class=model_args.conversation_formatter_class,
|
| 53 |
+
llm_attn_implementation=model_args.llm_attn_implementation
|
| 54 |
+
)
|
| 55 |
+
# 2. load pretrained llm and text tokenizer
|
| 56 |
+
attn_kwargs = dict()
|
| 57 |
+
if model_args.llm_attn_implementation:
|
| 58 |
+
attn_kwargs['attn_implementation'] = model_args.llm_attn_implementation
|
| 59 |
+
llm = AutoModelForCausalLM.from_pretrained(model_args.llm_name_or_path, **attn_kwargs)
|
| 60 |
+
text_tokenizer = AutoTokenizer.from_pretrained(model_args.llm_name_or_path)
|
| 61 |
+
if text_tokenizer.pad_token_id is None and model_args.pad_token_id is not None:
|
| 62 |
+
text_tokenizer.pad_token_id = model_args.pad_token_id
|
| 63 |
+
# 3. construct visual tokenizer
|
| 64 |
+
# deepspeed zero.Init with bfloat16 fail for visual_tokenizer, so temporarily disable zero.Init here
|
| 65 |
+
unset_hf_deepspeed_config()
|
| 66 |
+
if training_args.visual_tokenizer_pretrained_path is not None:
|
| 67 |
+
visual_tokenizer = AutoModel.from_pretrained(
|
| 68 |
+
training_args.visual_tokenizer_pretrained_path,
|
| 69 |
+
image_processor_name_or_path=training_args.visual_tokenizer_pretrained_path
|
| 70 |
+
)
|
| 71 |
+
else:
|
| 72 |
+
visual_tokenizer_config = AutoConfig.for_model(
|
| 73 |
+
model_type=model_args.visual_tokenizer_type + "_visual_tokenizer",
|
| 74 |
+
vocab_size=model_args.visual_vocab_size,
|
| 75 |
+
tokenize_function=model_args.visual_tokenize_function,
|
| 76 |
+
tau=model_args.visual_tau,
|
| 77 |
+
depths=model_args.visual_depths,
|
| 78 |
+
drop_cls_token=model_args.visual_drop_cls_token,
|
| 79 |
+
hidden_stride=model_args.visual_hidden_stride,
|
| 80 |
+
)
|
| 81 |
+
visual_tokenizer = AutoModel.from_config(visual_tokenizer_config, train_from_scratch=True)
|
| 82 |
+
visual_tokenizer = visual_tokenizer.to(
|
| 83 |
+
device=torch.device(get_accelerator().device_name(os.getenv("LOCAL_RANK"))))
|
| 84 |
+
if getattr(training_args, 'hf_deepspeed_config', None) is not None:
|
| 85 |
+
set_hf_deepspeed_config(training_args.hf_deepspeed_config)
|
| 86 |
+
# 4. construct ovis model
|
| 87 |
+
model = Ovis(ovis_config, llm=llm, text_tokenizer=text_tokenizer, visual_tokenizer=visual_tokenizer,
|
| 88 |
+
train_from_scratch=True)
|
| 89 |
+
else: # load pretrained ovis model
|
| 90 |
+
model, loading_info = Ovis.from_pretrained(training_args.ovis_pretrained_path,
|
| 91 |
+
multimodal_max_length=model_args.multimodal_max_length,
|
| 92 |
+
output_loading_info=True)
|
| 93 |
+
rank0_print(BEGIN_LINE)
|
| 94 |
+
rank0_print(f'Loading info of Ovis:\n{loading_info}')
|
| 95 |
+
rank0_print(END_LINE)
|
| 96 |
+
training_args.vte_re_init = False
|
| 97 |
+
|
| 98 |
+
model.get_llm().config.use_cache = False
|
| 99 |
+
model.config.use_cache = False
|
| 100 |
+
text_tokenizer = model.get_text_tokenizer()
|
| 101 |
+
|
| 102 |
+
rank0_print(BEGIN_LINE)
|
| 103 |
+
rank0_print(f'model.config:\n{model.config}')
|
| 104 |
+
rank0_print(END_LINE)
|
| 105 |
+
|
| 106 |
+
# maybe re-init vte
|
| 107 |
+
if training_args.vte_re_init:
|
| 108 |
+
with deepspeed.zero.GatheredParameters([model.get_wte().weight]):
|
| 109 |
+
mean = model.get_wte().weight.mean().item()
|
| 110 |
+
std = model.get_wte().weight.std().item()
|
| 111 |
+
rank0_print(f'Statistics of embedding table of LLM: {mean=}, {std=}')
|
| 112 |
+
model.re_init_vte(mean, std)
|
| 113 |
+
|
| 114 |
+
# select train modules
|
| 115 |
+
model.requires_grad_(False)
|
| 116 |
+
for module in training_args.train_modules.split('|'):
|
| 117 |
+
if module == 'all':
|
| 118 |
+
model.requires_grad_(True)
|
| 119 |
+
elif module == 'llm':
|
| 120 |
+
model.get_llm().requires_grad_(True)
|
| 121 |
+
elif module == 'visual_tokenizer':
|
| 122 |
+
model.get_visual_tokenizer().requires_grad_(True)
|
| 123 |
+
elif module == 'visual_tokenizer.backbone':
|
| 124 |
+
model.get_visual_tokenizer().get_backbone().requires_grad_(True)
|
| 125 |
+
elif module.startswith('visual_tokenizer.backbone.layer.'):
|
| 126 |
+
layer_index = int(module[len('visual_tokenizer.backbone.layer.'):])
|
| 127 |
+
layer = model.get_visual_tokenizer().get_backbone_layer(layer_index)
|
| 128 |
+
layer.requires_grad_(True)
|
| 129 |
+
elif module == 'visual_tokenizer.head':
|
| 130 |
+
model.get_visual_tokenizer().get_head().requires_grad_(True)
|
| 131 |
+
elif module == 'vte':
|
| 132 |
+
model.get_vte().requires_grad_(True)
|
| 133 |
+
else:
|
| 134 |
+
raise ValueError(f'Invalid train module name: {module}')
|
| 135 |
+
|
| 136 |
+
rank0_print(BEGIN_LINE)
|
| 137 |
+
rank0_print('Parameters to train:')
|
| 138 |
+
for name, param in model.named_parameters():
|
| 139 |
+
if param.requires_grad:
|
| 140 |
+
rank0_print(name)
|
| 141 |
+
rank0_print(f'LLM\'s attn implementation: {model.get_llm().config._attn_implementation}')
|
| 142 |
+
rank0_print(END_LINE)
|
| 143 |
+
|
| 144 |
+
# construct data module
|
| 145 |
+
datasets = []
|
| 146 |
+
dataset_info_path = os.path.join(os.path.dirname(os.path.abspath(__file__)),
|
| 147 |
+
f'dataset/{training_args.dataset_info}.json')
|
| 148 |
+
with open(dataset_info_path, 'r', encoding='utf-8') as f:
|
| 149 |
+
dataset_info = json.load(f)
|
| 150 |
+
for name in training_args.dataset_names.split('|'):
|
| 151 |
+
info = dataset_info[name]
|
| 152 |
+
data_format = info['data_format']
|
| 153 |
+
if data_format == 'caption':
|
| 154 |
+
dataset = CaptionDataset(name, info, model, training_args)
|
| 155 |
+
elif data_format == 'conversation':
|
| 156 |
+
dataset = ConversationDataset(name, info, model, training_args)
|
| 157 |
+
else:
|
| 158 |
+
raise ValueError(f'Invalid data format `{data_format}` for dataset `{name}`')
|
| 159 |
+
datasets.append(dataset)
|
| 160 |
+
data_module = dict(
|
| 161 |
+
train_dataset=ConcatDataset(datasets),
|
| 162 |
+
data_collator=DataCollatorForMultimodalDataset(text_tokenizer)
|
| 163 |
+
)
|
| 164 |
+
|
| 165 |
+
# train
|
| 166 |
+
train_callbacks = [MonitorCallback]
|
| 167 |
+
if model_args.visual_tokenize_function == 'gumbel_argmax':
|
| 168 |
+
train_callbacks.append(TuneTauCallback)
|
| 169 |
+
trainer = Trainer(
|
| 170 |
+
model=model,
|
| 171 |
+
args=training_args,
|
| 172 |
+
callbacks=train_callbacks,
|
| 173 |
+
**data_module
|
| 174 |
+
)
|
| 175 |
+
rank0_print(BEGIN_LINE)
|
| 176 |
+
rank0_print('Dataset sample tensor:')
|
| 177 |
+
rank0_print(data_module['train_dataset'][0])
|
| 178 |
+
rank0_print(END_LINE)
|
| 179 |
+
rank0_print(BEGIN_LINE)
|
| 180 |
+
rank0_print('Dataset sample input_ids decoding:')
|
| 181 |
+
rank0_print(text_tokenizer.decode([x for x in data_module['train_dataset'][0]['input_ids'] if x >= 0]))
|
| 182 |
+
rank0_print(END_LINE)
|
| 183 |
+
rank0_print(BEGIN_LINE)
|
| 184 |
+
rank0_print('Dataset sample labels decoding:')
|
| 185 |
+
rank0_print(text_tokenizer.decode([x for x in data_module['train_dataset'][0]['labels'] if x >= 0]))
|
| 186 |
+
rank0_print(END_LINE)
|
| 187 |
+
rank0_print(BEGIN_LINE)
|
| 188 |
+
rank0_print(f'#param of model: {smart_unit(model.num_parameters())}')
|
| 189 |
+
rank0_print(f'#param of llm: {smart_unit(model.get_llm().num_parameters())}')
|
| 190 |
+
rank0_print(f'#param of visual_tokenizer: {smart_unit(model.get_visual_tokenizer().num_parameters())}')
|
| 191 |
+
rank0_print(f'#param of vte: {smart_unit(model.get_vte().weight.numel())}')
|
| 192 |
+
rank0_print(END_LINE)
|
| 193 |
+
if list(pathlib.Path(training_args.output_dir).glob("checkpoint-*")):
|
| 194 |
+
trainer.train(resume_from_checkpoint=True)
|
| 195 |
+
else:
|
| 196 |
+
trainer.train()
|
| 197 |
+
trainer.save_state()
|
| 198 |
+
|
| 199 |
+
# save model
|
| 200 |
+
model.get_llm().config.use_cache = True
|
| 201 |
+
model.config.use_cache = True
|
| 202 |
+
trainer.save_model()
|
| 203 |
+
|
| 204 |
+
|
| 205 |
+
if __name__ == '__main__':
|
| 206 |
+
train()
|
Ovis/ovis/util/constants.py
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Model Constants
|
| 2 |
+
IGNORE_ID = -100
|
| 3 |
+
IMAGE_TOKEN_ID = -200
|
| 4 |
+
IMAGE_TOKEN = "<image>"
|
| 5 |
+
|
| 6 |
+
IMAGE_ATOM_ID = -300
|
| 7 |
+
IMAGE_INDICATOR_IDS = [-301, -302, -303, -304, -305]
|
| 8 |
+
|
| 9 |
+
# Log & Print
|
| 10 |
+
BEGIN_LINE = '========================************========================'
|
| 11 |
+
END_LINE = '------------------------------------------------------------'
|
Ovis/ovis/util/utils.py
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from importlib import import_module
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
def rank0_print(*args):
|
| 6 |
+
if int(os.getenv("LOCAL_PROCESS_RANK", os.getenv("LOCAL_RANK", 0))) == 0:
|
| 7 |
+
print(*args)
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def smart_unit(num):
|
| 11 |
+
if num / 1.0e9 >= 1:
|
| 12 |
+
return f'{num / 1.0e9:.2f}B'
|
| 13 |
+
else:
|
| 14 |
+
return f'{num / 1.0e6:.2f}M'
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def import_class_from_string(full_class_string):
|
| 18 |
+
# Split the path to get separate module and class names
|
| 19 |
+
module_path, _, class_name = full_class_string.rpartition('.')
|
| 20 |
+
|
| 21 |
+
# Import the module using the module path
|
| 22 |
+
module = import_module(module_path)
|
| 23 |
+
|
| 24 |
+
# Get the class from the imported module
|
| 25 |
+
cls = getattr(module, class_name)
|
| 26 |
+
return cls
|
llm2vec/docs/.gitignore
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Jekyll
|
| 2 |
+
_site
|
| 3 |
+
.sass-cache
|
| 4 |
+
.jekyll-metadata
|
| 5 |
+
Gemfile.lock
|
| 6 |
+
vendor/*
|
| 7 |
+
.bundle
|
| 8 |
+
|
| 9 |
+
*.vscode
|
llm2vec/docs/Gemfile
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
source "https://rubygems.org"
|
| 2 |
+
|
| 3 |
+
gem "webrick"
|
| 4 |
+
gem "github-pages", group: :jekyll_plugins
|
| 5 |
+
|
| 6 |
+
gem "tzinfo-data"
|
| 7 |
+
gem "wdm", "~> 0.1.0" if Gem.win_platform?
|
| 8 |
+
|
| 9 |
+
# If you have any plugins, put them here!
|
| 10 |
+
group :jekyll_plugins do
|
| 11 |
+
gem "jekyll-paginate"
|
| 12 |
+
gem "jekyll-sitemap"
|
| 13 |
+
gem "jekyll-gist"
|
| 14 |
+
gem "jekyll-feed"
|
| 15 |
+
gem "jemoji"
|
| 16 |
+
gem "jekyll-include-cache"
|
| 17 |
+
gem "jekyll-algolia"
|
| 18 |
+
end
|
llm2vec/docs/README.md
ADDED
|
@@ -0,0 +1,104 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Project Page Template
|
| 2 |
+
|
| 3 |
+
This is a project page template for McGill-NLP projects.
|
| 4 |
+
|
| 5 |
+

|
| 6 |
+
|
| 7 |
+
## Getting started
|
| 8 |
+
|
| 9 |
+
You can follow one of the following ways to get started.
|
| 10 |
+
|
| 11 |
+
### Copy from template
|
| 12 |
+
|
| 13 |
+
If you have not yet created a repo for your project, you can copy the template from the following link. Simply click on the "Use this template" button, or [click here](https://github.com/McGill-NLP/project-page-template/generate).
|
| 14 |
+
|
| 15 |
+
### Cloning with git
|
| 16 |
+
|
| 17 |
+
If you have already created a repo for your project, you can clone the template and copy the files to your project. Let's assume you are in the root of your project directory, e.g. `my-project`.
|
| 18 |
+
|
| 19 |
+
```bash
|
| 20 |
+
cd .. # Go to parent directory
|
| 21 |
+
git clone https://github.com/McGill-NLP/project-page-template
|
| 22 |
+
cp -r project-page-template/docs my-project/
|
| 23 |
+
cp project-page-template/README.md my-project/docs/
|
| 24 |
+
cd my-project/
|
| 25 |
+
```
|
| 26 |
+
|
| 27 |
+
## Activate GitHub Pages
|
| 28 |
+
|
| 29 |
+
Once the template is copied to your project, you need to activate GitHub Pages for your project.
|
| 30 |
+
|
| 31 |
+
1. Click on the "Settings" button in the top right corner of the page.
|
| 32 |
+
2. Click on the "Pages" tab.
|
| 33 |
+
3. In "Source", select branch to be "main" and folder to "/docs".
|
| 34 |
+
4. Click on "Save"
|
| 35 |
+
5. Go to the "Actions" tab (on the right of "Pull requests" tab) and wait for the action to finish.
|
| 36 |
+
6. Visit your project page at mcgill-nlp.github.io/<your-project-name>
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
### Why `docs/`?
|
| 40 |
+
|
| 41 |
+
You might be wondering why all the files for the webpage is in `/docs`. Well the page is not really "docs" per se, it's because `github-pages` only allows us to either use the root folder or `/docs`. So we are forced to use the latter in order to clearly separate this page from the rest of the project. Maybe in the future GitHub will allow other names like `/page`, but before then there's nothing we can do...
|
| 42 |
+
|
| 43 |
+
> Well technically, you *can* push everything to a separate branch, but the original author of this repo is in the school of thought that branches are meant for alternate or historical versions of the `main` branch, or serve as a way to create pull requests. The original author will also vehemently defend this philosophy if one tries to argue otherwise :)
|
| 44 |
+
|
| 45 |
+
But what if you actually want to write some docs for your library? You can just edit `docs/_pages/docs.md`, which is the real page for documentations.
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
## Navigation bar
|
| 49 |
+
|
| 50 |
+
You can already find links to different pages in the navigation bar. To add, remove, or modify links, you can edit [`docs/_data/navigation.yml`](docs/_data/navigation.yml) file. The `title` corresponds to the text that appear on the navbar, and the `url` corresponds to the relative URL of the page. It is not recommended to include an external URL, as that should be in `/home` page.
|
| 51 |
+
|
| 52 |
+
## Modifying a page
|
| 53 |
+
|
| 54 |
+
The files are located in [`docs/_pages/`](docs/_pages/). For example, if you want to modify the `/home` page, you would edit `docs/_pages/home.md`.
|
| 55 |
+
|
| 56 |
+
All of the pages are markdown files with something called a [front matter](https://jekyllrb.com/docs/front-matter/) at the top, which uses the YAML syntax. Generally, all you need to worry about is the title and the permalink (the latter is the relative URL of the page). In the case of `/home`, you also need to specify external links (`header.actions`) and author names (`excerpt`). However, a template is already provided for you, you only need to modify the content.
|
| 57 |
+
|
| 58 |
+
## Adding and removing pages
|
| 59 |
+
|
| 60 |
+
To add a page:
|
| 61 |
+
1. Create a new file in `docs/_pages/`, with the desired `permalink` to be your relative URL. So for example, `/contact/` links to `mcgill-nlp.github.io/my-project/contact`.
|
| 62 |
+
2. In `docs/_data/navigation.yml`, add a new entry with the `title` and `url` from the previous step.
|
| 63 |
+
|
| 64 |
+
To remove a page:
|
| 65 |
+
1. Delete the file in `docs/_pages/`.
|
| 66 |
+
2. In `docs/_data/navigation.yml`, remove the entry with the same `url` as the deleted file.
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
## Documentations and API for your project
|
| 70 |
+
|
| 71 |
+
Note that there's a tab that says `docs`, and you can see that it links to other pages. So this is a standalone doc page inside your webpage. Note also that, due to Github pages' caveat, we were forced to put the webpage in `/docs`, but the actual docs are in `/docs/_docs`. Now that's cleared up, you can head to [`/docs/_docs/README.md`](/docs/_docs/README.md) to read the instructions.
|
| 72 |
+
|
| 73 |
+
> Do you feel writing documentation is too complicated or time-consuming, and you'd like something more straightforward? Check out the [template for using MkDocs](https://github.com/McGill-NLP/mkdocs-template) instead. However, the simplicity comes at the cost of more repositories and different frameworks to maintain.
|
| 74 |
+
|
| 75 |
+
## Advanced
|
| 76 |
+
|
| 77 |
+
For any advanced modification, it is recommended to look in the advanced section of the readme of the [group website](https://github.com/McGill-NLP/mcgill-nlp.github.io). Below are a extra tips included for convenience.
|
| 78 |
+
|
| 79 |
+
### Setup
|
| 80 |
+
|
| 81 |
+
Please refer to setup instructions in the readme of the [group website](https://github.com/McGill-NLP/mcgill-nlp.github.io).
|
| 82 |
+
|
| 83 |
+
### Running locally
|
| 84 |
+
|
| 85 |
+
```bash
|
| 86 |
+
cd docs/
|
| 87 |
+
bundle exec jekyll serve
|
| 88 |
+
```
|
| 89 |
+
|
| 90 |
+
### Removing dark mode
|
| 91 |
+
|
| 92 |
+
To remove dark mode, inside [`docs/_config.yml`](docs/_config.yml) file, remove `dark_theme_css`. The dark mode should automatically turn off.
|
| 93 |
+
|
| 94 |
+
### Updating footer
|
| 95 |
+
|
| 96 |
+
Inside [`docs/_config.yml`](docs/_config.yml) file, you can modify the footer.
|
| 97 |
+
|
| 98 |
+
### Modify `excerpt` in a splash page (`/home`)
|
| 99 |
+
|
| 100 |
+
If you want to modify the excerpt in the `/home` page, you can do so in [`docs/_sass_/splash.scss`](docs/_sass_/splash.scss). Note that `splash.scss` was added specifically for this template, not for the group website.
|
| 101 |
+
|
| 102 |
+
### Modify or remove icons in splash page buttons
|
| 103 |
+
|
| 104 |
+
This is handled in [`docs/_includes/page__hero.html`](docs/_includes/page__hero.html). That file was added specifically for this template, not for the group website. You can modify that file to add, modify or remove icons.
|
llm2vec/docs/_config.yml
ADDED
|
@@ -0,0 +1,110 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Welcome to Jekyll!
|
| 2 |
+
#
|
| 3 |
+
# This config file is meant for settings that affect your whole blog, values
|
| 4 |
+
# which you are expected to set up once and rarely edit after that. If you find
|
| 5 |
+
# yourself editing this file very often, consider using Jekyll's data files
|
| 6 |
+
# feature for the data you need to update frequently.
|
| 7 |
+
#
|
| 8 |
+
# For technical reasons, this file is *NOT* reloaded automatically when you use
|
| 9 |
+
# 'bundle exec jekyll serve'. If you change this file, please restart the server process.
|
| 10 |
+
|
| 11 |
+
# Site settings
|
| 12 |
+
# These are used to personalize your new site. If you look in the HTML files,
|
| 13 |
+
# you will see them accessed via {{ site.title }}, {{ site.email }}, and so on.
|
| 14 |
+
# You can create any custom variable you would like, and they will be accessible
|
| 15 |
+
# in the templates via {{ site.myvariable }}.
|
| 16 |
+
title: McGill NLP
|
| 17 |
+
email:
|
| 18 |
+
description: >- # this means to ignore newlines until "baseurl:"
|
| 19 |
+
McGill NLP is a research group within McGill University and Mila focusing on various topics of natural language processing.
|
| 20 |
+
twitter_username: McGill_NLP
|
| 21 |
+
github_username: McGill-NLP
|
| 22 |
+
logo: "/assets/images/logo/logo.png"
|
| 23 |
+
dark_theme_css: "/assets/css/main-dark.css"
|
| 24 |
+
future: true
|
| 25 |
+
|
| 26 |
+
# Build settings
|
| 27 |
+
markdown: kramdown
|
| 28 |
+
remote_theme: mmistakes/minimal-mistakes@4.24.0
|
| 29 |
+
# Outputting
|
| 30 |
+
permalink: /:categories/:title/
|
| 31 |
+
timezone: America/Montreal
|
| 32 |
+
|
| 33 |
+
include:
|
| 34 |
+
- _pages
|
| 35 |
+
- _docs
|
| 36 |
+
|
| 37 |
+
# Exclude from processing.
|
| 38 |
+
# The following items will not be processed, by default. Create a custom list
|
| 39 |
+
# to override the default setting.
|
| 40 |
+
# exclude:
|
| 41 |
+
# - Gemfile
|
| 42 |
+
# - Gemfile.lock
|
| 43 |
+
# - node_modules
|
| 44 |
+
# - vendor/bundle/
|
| 45 |
+
# - vendor/cache/
|
| 46 |
+
# - vendor/gems/
|
| 47 |
+
# - vendor/ruby/
|
| 48 |
+
|
| 49 |
+
# Plugins (previously gems:)
|
| 50 |
+
plugins:
|
| 51 |
+
- jekyll-sitemap
|
| 52 |
+
- jekyll-gist
|
| 53 |
+
- jemoji
|
| 54 |
+
- jekyll-include-cache
|
| 55 |
+
|
| 56 |
+
author:
|
| 57 |
+
name : "McGill NLP Member(s)"
|
| 58 |
+
avatar : "/assets/images/bio/default.jpg"
|
| 59 |
+
bio : "Current or former lab member(s) worked on this."
|
| 60 |
+
links:
|
| 61 |
+
- label: "Website"
|
| 62 |
+
icon: "fas fa-fw fa-link"
|
| 63 |
+
url: "https://mcgill-nlp.github.io"
|
| 64 |
+
- label: "GitHub"
|
| 65 |
+
icon: "fab fa-fw fa-github"
|
| 66 |
+
url: "https://github.com/McGill-NLP"
|
| 67 |
+
- label: "Twitter"
|
| 68 |
+
icon: "fab fa-fw fa-twitter-square"
|
| 69 |
+
url: "https://twitter.com/McGill_NLP"
|
| 70 |
+
|
| 71 |
+
analytics:
|
| 72 |
+
provider: "google-gtag"
|
| 73 |
+
google:
|
| 74 |
+
tracking_id: "G-MEDG9XN4VP"
|
| 75 |
+
anonymize_ip: false # default
|
| 76 |
+
|
| 77 |
+
atom_feed:
|
| 78 |
+
hide: true
|
| 79 |
+
|
| 80 |
+
footer:
|
| 81 |
+
links:
|
| 82 |
+
- label: "GitHub"
|
| 83 |
+
icon: "fab fa-fw fa-github"
|
| 84 |
+
url: "https://github.com/McGill-NLP"
|
| 85 |
+
- label: "Twitter"
|
| 86 |
+
icon: "fab fa-fw fa-twitter-square"
|
| 87 |
+
url: "https://twitter.com/McGill_NLP"
|
| 88 |
+
|
| 89 |
+
defaults:
|
| 90 |
+
# /docs/_pages
|
| 91 |
+
- scope:
|
| 92 |
+
path: "_pages"
|
| 93 |
+
type: pages
|
| 94 |
+
values:
|
| 95 |
+
layout: single
|
| 96 |
+
classes:
|
| 97 |
+
- no-sidebar
|
| 98 |
+
- wide
|
| 99 |
+
author_profile: false
|
| 100 |
+
# /docs/_docs
|
| 101 |
+
- scope:
|
| 102 |
+
path: "_docs"
|
| 103 |
+
type: pages
|
| 104 |
+
values:
|
| 105 |
+
layout: single
|
| 106 |
+
sidebar:
|
| 107 |
+
title: "Doc Pages"
|
| 108 |
+
nav: sidebar-docs # See /docs/_data/navigation.yml
|
| 109 |
+
toc: true
|
| 110 |
+
toc_label: "Table of Contents"
|
llm2vec/docs/_data/navigation.yml
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
main:
|
| 2 |
+
- title: "Home"
|
| 3 |
+
url: /
|
| 4 |
+
- title: "Leaderboard"
|
| 5 |
+
url: /leaderboard/
|
| 6 |
+
- title: "Docs"
|
| 7 |
+
url: /docs/
|
| 8 |
+
- title: "Contact"
|
| 9 |
+
url: /contact/
|
| 10 |
+
|
| 11 |
+
sidebar-docs: # See "include" in /_config.yml and /docs/_docs
|
| 12 |
+
- title: "Home"
|
| 13 |
+
url: /docs/
|
| 14 |
+
- title: "API"
|
| 15 |
+
url: /docs/api
|
| 16 |
+
- title: "Training"
|
| 17 |
+
url: /docs/training
|
llm2vec/docs/_includes/head/custom.html
ADDED
|
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<!-- Add favicon -->
|
| 2 |
+
<link rel="icon" type="image/png" href="{{ site.baseurl }}/assets/images/logo/favicon.png">
|
| 3 |
+
|
| 4 |
+
{% if site.dark_theme_css %}
|
| 5 |
+
<!-- Dark Mode -->
|
| 6 |
+
<link rel="stylesheet" href="{{ '/assets/css/main.css' | relative_url }}" id="theme-css">
|
| 7 |
+
<link rel="stylesheet alternate" href="{{ site.dark_theme_css | relative_url }}" id="theme-css-dark">
|
| 8 |
+
|
| 9 |
+
<script type="text/javascript">
|
| 10 |
+
const updateNodesRel = theme => {
|
| 11 |
+
const node_light = document.getElementById('theme-css');
|
| 12 |
+
const node_dark = document.getElementById('theme-css-dark');
|
| 13 |
+
|
| 14 |
+
if (theme === "dark") {
|
| 15 |
+
node_light.setAttribute('rel', 'stylesheet alternate');
|
| 16 |
+
node_dark.setAttribute('rel', 'stylesheet');
|
| 17 |
+
}
|
| 18 |
+
else if (theme === "light") {
|
| 19 |
+
node_light.setAttribute('rel', 'stylesheet');
|
| 20 |
+
node_dark.setAttribute('rel', 'stylesheet alternate');
|
| 21 |
+
}
|
| 22 |
+
}
|
| 23 |
+
|
| 24 |
+
const changeTheme = () => {
|
| 25 |
+
let theme = sessionStorage.getItem('theme');
|
| 26 |
+
|
| 27 |
+
// Change the theme to the other option
|
| 28 |
+
if (theme === "light") {
|
| 29 |
+
theme = "dark";
|
| 30 |
+
} else {
|
| 31 |
+
theme = "light";
|
| 32 |
+
}
|
| 33 |
+
|
| 34 |
+
// Update the stored session and the nodes' rel attribute
|
| 35 |
+
sessionStorage.setItem('theme', theme);
|
| 36 |
+
updateNodesRel(theme);
|
| 37 |
+
|
| 38 |
+
return false;
|
| 39 |
+
}
|
| 40 |
+
|
| 41 |
+
if (sessionStorage.getItem('theme') === null) {
|
| 42 |
+
sessionStorage.setItem('theme', "light");
|
| 43 |
+
}
|
| 44 |
+
|
| 45 |
+
const theme = sessionStorage.getItem('theme');
|
| 46 |
+
updateNodesRel(theme);
|
| 47 |
+
</script>
|
| 48 |
+
{% endif %}
|
llm2vec/docs/_sass/custom/header-footer.scss
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
a.site-title {
|
| 2 |
+
@media (min-width: 601px) {
|
| 3 |
+
font-size: xx-large;
|
| 4 |
+
}
|
| 5 |
+
@media (max-width: 600px) {
|
| 6 |
+
font-size: large;
|
| 7 |
+
}
|
| 8 |
+
color: $primary-color;
|
| 9 |
+
|
| 10 |
+
&:hover {
|
| 11 |
+
color: mix($background-color, $primary-color, 25%);
|
| 12 |
+
}
|
| 13 |
+
}
|
| 14 |
+
|
| 15 |
+
.theme-toggle {
|
| 16 |
+
@media (min-width: 601px) {
|
| 17 |
+
margin: 0px;
|
| 18 |
+
}
|
| 19 |
+
}
|
llm2vec/docs/_sass/custom/no-sidebar.scss
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
.no-sidebar article.page {
|
| 2 |
+
float: left;
|
| 3 |
+
width: 100%;
|
| 4 |
+
}
|
| 5 |
+
|
| 6 |
+
.no-sidebar .archive {
|
| 7 |
+
float: left;
|
| 8 |
+
width: 100%;
|
| 9 |
+
}
|
llm2vec/docs/_sass/custom/splash.scss
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// This contains the styles for the "excerpt" on a splash page
|
| 2 |
+
div.wrapper > p.page__lead {
|
| 3 |
+
font-size: x-large;
|
| 4 |
+
max-width: 100%;
|
| 5 |
+
}
|
llm2vec/docs/_sass/skins/dark.scss
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/* ==========================================================================
|
| 2 |
+
Dark skin
|
| 3 |
+
Imported in /assets/css/main-light.scss
|
| 4 |
+
========================================================================== */
|
| 5 |
+
|
| 6 |
+
/* Colors */
|
| 7 |
+
$background-color: #000000 !default;
|
| 8 |
+
$text-color: #eaeaea !default;
|
| 9 |
+
$primary-color: #ED1B2F !default;
|
| 10 |
+
$border-color: mix(#fff, $background-color, 20%) !default;
|
| 11 |
+
$code-background-color: mix(#000, $background-color, 15%) !default;
|
| 12 |
+
$code-background-color-dark: mix(#000, $background-color, 20%) !default;
|
| 13 |
+
$form-background-color: mix(#000, $background-color, 15%) !default;
|
| 14 |
+
$footer-background-color: mix($text-color, $background-color, 5%) !default;
|
| 15 |
+
$link-color: mix($primary-color, $text-color, 100%) !default;
|
| 16 |
+
$link-color-hover: mix($background-color, $link-color, 15%) !default;
|
| 17 |
+
$link-color-visited: mix(#000, $link-color, 0%) !default;
|
| 18 |
+
$masthead-link-color: $text-color !default;
|
| 19 |
+
$masthead-link-color-hover: mix(#000, $text-color, 20%) !default;
|
| 20 |
+
|
| 21 |
+
.author__urls.social-icons i,
|
| 22 |
+
.author__urls.social-icons .svg-inline--fa,
|
| 23 |
+
.page__footer-follow .social-icons i,
|
| 24 |
+
.page__footer-follow .social-icons .svg-inline--fa {
|
| 25 |
+
color: inherit;
|
| 26 |
+
}
|
| 27 |
+
|
| 28 |
+
.ais-search-box .ais-search-box--input {
|
| 29 |
+
background-color: $form-background-color;
|
| 30 |
+
}
|
llm2vec/docs/_sass/skins/light.scss
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/*
|
| 2 |
+
Imported in /assets/css/main-light.scss
|
| 3 |
+
*/
|
| 4 |
+
$background-color: #fff !default;
|
| 5 |
+
$text-color: #000 !default;
|
| 6 |
+
$primary-color: #ED1B2F !default;
|
| 7 |
+
// $footer-background-color: mix($primary-color, $background-color, 100%) !default;
|
| 8 |
+
$link-color: #ED1B2F !default;
|
| 9 |
+
$link-color-hover: mix(#fff, $link-color, 25%) !default;
|
| 10 |
+
$link-color-visited: mix(#000, $link-color, 10%) !default;
|
| 11 |
+
$masthead-link-color: $text-color !default;
|
| 12 |
+
$masthead-link-color-hover: mix($background-color, $text-color, 25%) !default;
|
llm2vec/docs/assets/images/logo/favicon.png
ADDED
|
|
llm2vec/docs/assets/images/logo/logo.png
ADDED
|
llm2vec/docs/assets/images/logo/logo.svg
ADDED
|
|
llm2vec/examples/classification.py
ADDED
|
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from sklearn.metrics import accuracy_score, f1_score
|
| 2 |
+
from sklearn.linear_model import LogisticRegression
|
| 3 |
+
import datasets
|
| 4 |
+
import numpy as np
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
from llm2vec import LLM2Vec
|
| 8 |
+
|
| 9 |
+
dataset = "mteb/amazon_counterfactual"
|
| 10 |
+
instruction = "Classify a given Amazon customer review text as either counterfactual or notcounterfactual: "
|
| 11 |
+
|
| 12 |
+
dataset = datasets.load_dataset(dataset, "en")
|
| 13 |
+
|
| 14 |
+
sentences_train, y_train = dataset["train"]["text"], dataset["train"]["label"]
|
| 15 |
+
sentences_test, y_test = dataset["test"]["text"], dataset["test"]["label"]
|
| 16 |
+
max_iter = 100
|
| 17 |
+
batch_size = 8
|
| 18 |
+
|
| 19 |
+
scores = {}
|
| 20 |
+
clf = LogisticRegression(
|
| 21 |
+
random_state=42,
|
| 22 |
+
n_jobs=1,
|
| 23 |
+
max_iter=max_iter,
|
| 24 |
+
verbose=0,
|
| 25 |
+
)
|
| 26 |
+
|
| 27 |
+
print("Loading model...")
|
| 28 |
+
model = LLM2Vec.from_pretrained(
|
| 29 |
+
"McGill-NLP/LLM2Vec-Mistral-7B-Instruct-v2-mntp",
|
| 30 |
+
peft_model_name_or_path="McGill-NLP/LLM2Vec-Mistral-7B-Instruct-v2-mntp-supervised",
|
| 31 |
+
device_map="cuda" if torch.cuda.is_available() else "cpu",
|
| 32 |
+
torch_dtype=torch.bfloat16,
|
| 33 |
+
)
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def append_instruction(instruction, sentences):
|
| 37 |
+
new_sentences = []
|
| 38 |
+
for s in sentences:
|
| 39 |
+
new_sentences.append([instruction, s, 0])
|
| 40 |
+
return new_sentences
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
print(f"Encoding {len(sentences_train)} training sentences...")
|
| 44 |
+
sentences_train = append_instruction(instruction, sentences_train)
|
| 45 |
+
X_train = np.asarray(model.encode(sentences_train, batch_size=batch_size))
|
| 46 |
+
|
| 47 |
+
print(f"Encoding {len(sentences_test)} test sentences...")
|
| 48 |
+
sentences_test = append_instruction(instruction, sentences_test)
|
| 49 |
+
X_test = np.asarray(model.encode(sentences_test, batch_size=batch_size))
|
| 50 |
+
|
| 51 |
+
print("Fitting logistic regression classifier...")
|
| 52 |
+
clf.fit(X_train, y_train)
|
| 53 |
+
print("Evaluating...")
|
| 54 |
+
y_pred = clf.predict(X_test)
|
| 55 |
+
|
| 56 |
+
accuracy = accuracy_score(y_test, y_pred)
|
| 57 |
+
scores["accuracy"] = accuracy
|
| 58 |
+
f1 = f1_score(y_test, y_pred, average="macro")
|
| 59 |
+
scores["f1"] = f1
|
| 60 |
+
|
| 61 |
+
print(scores)
|
| 62 |
+
# {'accuracy': 0.891044776119403, 'f1': 0.8283106625713033}
|
llm2vec/examples/clustering.py
ADDED
|
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import sklearn
|
| 2 |
+
import sklearn.cluster
|
| 3 |
+
|
| 4 |
+
import datasets
|
| 5 |
+
import tqdm
|
| 6 |
+
import numpy as np
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
from llm2vec import LLM2Vec
|
| 10 |
+
|
| 11 |
+
dataset = "mteb/twentynewsgroups-clustering"
|
| 12 |
+
instruction = "Identify the topic or theme of the given news articles: "
|
| 13 |
+
|
| 14 |
+
dataset = datasets.load_dataset(dataset)
|
| 15 |
+
batch_size = 32
|
| 16 |
+
|
| 17 |
+
print("Loading model...")
|
| 18 |
+
model = LLM2Vec.from_pretrained(
|
| 19 |
+
"McGill-NLP/LLM2Vec-Mistral-7B-Instruct-v2-mntp",
|
| 20 |
+
peft_model_name_or_path="McGill-NLP/LLM2Vec-Mistral-7B-Instruct-v2-mntp-supervised",
|
| 21 |
+
device_map="cuda" if torch.cuda.is_available() else "cpu",
|
| 22 |
+
torch_dtype=torch.bfloat16,
|
| 23 |
+
)
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def append_instruction(instruction, sentences):
|
| 27 |
+
new_sentences = []
|
| 28 |
+
for s in sentences:
|
| 29 |
+
new_sentences.append([instruction, s, 0])
|
| 30 |
+
return new_sentences
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
v_measures = []
|
| 34 |
+
for cluster_set in tqdm.tqdm(dataset["test"], desc="Clustering"):
|
| 35 |
+
sentences = cluster_set["sentences"]
|
| 36 |
+
labels = cluster_set["labels"]
|
| 37 |
+
clustering_batch_size = 500
|
| 38 |
+
|
| 39 |
+
print(f"Encoding {len(sentences)} sentences...")
|
| 40 |
+
new_sentences = append_instruction(instruction, sentences)
|
| 41 |
+
corpus_embeddings = np.asarray(model.encode(new_sentences, batch_size=batch_size))
|
| 42 |
+
|
| 43 |
+
print("Fitting Mini-Batch K-Means model...")
|
| 44 |
+
clustering_model = sklearn.cluster.MiniBatchKMeans(
|
| 45 |
+
n_clusters=len(set(labels)), batch_size=clustering_batch_size
|
| 46 |
+
)
|
| 47 |
+
clustering_model.fit(corpus_embeddings)
|
| 48 |
+
cluster_assignment = clustering_model.labels_
|
| 49 |
+
|
| 50 |
+
print("Evaluating...")
|
| 51 |
+
v_measure = sklearn.metrics.cluster.v_measure_score(labels, cluster_assignment)
|
| 52 |
+
v_measures.append(v_measure)
|
| 53 |
+
|
| 54 |
+
v_mean = np.mean(v_measures)
|
| 55 |
+
v_std = np.std(v_measures)
|
| 56 |
+
|
| 57 |
+
print(v_mean)
|
| 58 |
+
# 0.5137461051538426
|
llm2vec/examples/retrieval.py
ADDED
|
@@ -0,0 +1,177 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import datasets
|
| 2 |
+
import torch
|
| 3 |
+
from llm2vec import LLM2Vec
|
| 4 |
+
from beir import util
|
| 5 |
+
from beir.datasets.data_loader import GenericDataLoader as BeirDataLoader
|
| 6 |
+
import os
|
| 7 |
+
from typing import Dict, List
|
| 8 |
+
|
| 9 |
+
from beir.retrieval.evaluation import EvaluateRetrieval
|
| 10 |
+
|
| 11 |
+
dataset = "arguana"
|
| 12 |
+
instruction = "Given a claim, find documents that refute the claim: "
|
| 13 |
+
|
| 14 |
+
print("Loading dataset...")
|
| 15 |
+
url = (
|
| 16 |
+
f"https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/{dataset}.zip"
|
| 17 |
+
)
|
| 18 |
+
download_path = os.path.join(datasets.config.HF_DATASETS_CACHE, "BeIR")
|
| 19 |
+
data_path = util.download_and_unzip(url, download_path)
|
| 20 |
+
corpus, queries, relevant_docs = BeirDataLoader(data_folder=data_path).load(
|
| 21 |
+
split="test"
|
| 22 |
+
)
|
| 23 |
+
batch_size = 8
|
| 24 |
+
|
| 25 |
+
print("Loading model...")
|
| 26 |
+
model = LLM2Vec.from_pretrained(
|
| 27 |
+
"McGill-NLP/LLM2Vec-Mistral-7B-Instruct-v2-mntp",
|
| 28 |
+
peft_model_name_or_path="McGill-NLP/LLM2Vec-Mistral-7B-Instruct-v2-mntp-supervised",
|
| 29 |
+
device_map="cuda" if torch.cuda.is_available() else "cpu",
|
| 30 |
+
torch_dtype=torch.bfloat16,
|
| 31 |
+
)
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def append_instruction(instruction, sentences):
|
| 35 |
+
new_sentences = []
|
| 36 |
+
for s in sentences:
|
| 37 |
+
new_sentences.append([instruction, s, 0])
|
| 38 |
+
return new_sentences
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def cos_sim(a: torch.Tensor, b: torch.Tensor):
|
| 42 |
+
if not isinstance(a, torch.Tensor):
|
| 43 |
+
a = torch.tensor(a)
|
| 44 |
+
|
| 45 |
+
if not isinstance(b, torch.Tensor):
|
| 46 |
+
b = torch.tensor(b)
|
| 47 |
+
|
| 48 |
+
if len(a.shape) == 1:
|
| 49 |
+
a = a.unsqueeze(0)
|
| 50 |
+
|
| 51 |
+
if len(b.shape) == 1:
|
| 52 |
+
b = b.unsqueeze(0)
|
| 53 |
+
|
| 54 |
+
a_norm = torch.nn.functional.normalize(a, p=2, dim=1)
|
| 55 |
+
b_norm = torch.nn.functional.normalize(b, p=2, dim=1)
|
| 56 |
+
return torch.mm(a_norm, b_norm.transpose(0, 1))
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def encode_queries(queries: List[str], batch_size: int, **kwargs):
|
| 60 |
+
new_sentences = append_instruction(instruction, queries)
|
| 61 |
+
|
| 62 |
+
kwargs["show_progress_bar"] = False
|
| 63 |
+
return model.encode(new_sentences, batch_size=batch_size, **kwargs)
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def encode_corpus(corpus: List[Dict[str, str]], batch_size: int, **kwargs):
|
| 67 |
+
if type(corpus) is dict:
|
| 68 |
+
sentences = [
|
| 69 |
+
(
|
| 70 |
+
(corpus["title"][i] + " " + corpus["text"][i]).strip()
|
| 71 |
+
if "title" in corpus
|
| 72 |
+
else corpus["text"][i].strip()
|
| 73 |
+
)
|
| 74 |
+
for i in range(len(corpus["text"]))
|
| 75 |
+
]
|
| 76 |
+
else:
|
| 77 |
+
sentences = [
|
| 78 |
+
(
|
| 79 |
+
(doc["title"] + " " + doc["text"]).strip()
|
| 80 |
+
if "title" in doc
|
| 81 |
+
else doc["text"].strip()
|
| 82 |
+
)
|
| 83 |
+
for doc in corpus
|
| 84 |
+
]
|
| 85 |
+
new_sentences = append_instruction("", sentences)
|
| 86 |
+
return model.encode(new_sentences, batch_size=batch_size, **kwargs)
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
print("Encoding Queries...")
|
| 90 |
+
query_ids = list(queries.keys())
|
| 91 |
+
results = {qid: {} for qid in query_ids}
|
| 92 |
+
queries = [queries[qid] for qid in queries]
|
| 93 |
+
query_embeddings = encode_queries(
|
| 94 |
+
queries, batch_size=batch_size, show_progress_bar=True, convert_to_tensor=True
|
| 95 |
+
)
|
| 96 |
+
|
| 97 |
+
print("Sorting Corpus by document length (Longest first)...")
|
| 98 |
+
corpus_ids = sorted(
|
| 99 |
+
corpus,
|
| 100 |
+
key=lambda k: len(corpus[k].get("title", "") + corpus[k].get("text", "")),
|
| 101 |
+
reverse=True,
|
| 102 |
+
)
|
| 103 |
+
corpus = [corpus[cid] for cid in corpus_ids]
|
| 104 |
+
|
| 105 |
+
print("Encoding Corpus ... Warning: This might take a while!")
|
| 106 |
+
corpus_embeddings = encode_corpus(
|
| 107 |
+
corpus, batch_size=batch_size, show_progress_bar=True, convert_to_tensor=True
|
| 108 |
+
)
|
| 109 |
+
|
| 110 |
+
print("Scoring Function: {} ({})".format("Cosine Similarity", "cos_sim"))
|
| 111 |
+
cos_scores = cos_sim(query_embeddings, corpus_embeddings)
|
| 112 |
+
cos_scores[torch.isnan(cos_scores)] = -1
|
| 113 |
+
|
| 114 |
+
# Get top-k values
|
| 115 |
+
top_k = 1000
|
| 116 |
+
cos_scores_top_k_values, cos_scores_top_k_idx = torch.topk(
|
| 117 |
+
cos_scores, min(top_k + 1, len(cos_scores[0])), dim=1, largest=True, sorted=False
|
| 118 |
+
)
|
| 119 |
+
cos_scores_top_k_values = cos_scores_top_k_values.cpu().tolist()
|
| 120 |
+
cos_scores_top_k_idx = cos_scores_top_k_idx.cpu().tolist()
|
| 121 |
+
|
| 122 |
+
for query_itr in range(len(query_embeddings)):
|
| 123 |
+
query_id = query_ids[query_itr]
|
| 124 |
+
for sub_corpus_id, score in zip(
|
| 125 |
+
cos_scores_top_k_idx[query_itr], cos_scores_top_k_values[query_itr]
|
| 126 |
+
):
|
| 127 |
+
corpus_id = corpus_ids[sub_corpus_id]
|
| 128 |
+
if corpus_id != query_id:
|
| 129 |
+
results[query_id][corpus_id] = score
|
| 130 |
+
|
| 131 |
+
retriever = EvaluateRetrieval(model, score_function="cos_sim")
|
| 132 |
+
ndcg, _map, recall, precision = retriever.evaluate(
|
| 133 |
+
relevant_docs, results, retriever.k_values
|
| 134 |
+
)
|
| 135 |
+
mrr = retriever.evaluate_custom(relevant_docs, results, retriever.k_values, "mrr")
|
| 136 |
+
|
| 137 |
+
scores = {
|
| 138 |
+
**{f"ndcg_at_{k.split('@')[1]}": v for (k, v) in ndcg.items()},
|
| 139 |
+
**{f"map_at_{k.split('@')[1]}": v for (k, v) in _map.items()},
|
| 140 |
+
**{f"recall_at_{k.split('@')[1]}": v for (k, v) in recall.items()},
|
| 141 |
+
**{f"precision_at_{k.split('@')[1]}": v for (k, v) in precision.items()},
|
| 142 |
+
**{f"mrr_at_{k.split('@')[1]}": v for (k, v) in mrr.items()},
|
| 143 |
+
}
|
| 144 |
+
print(scores)
|
| 145 |
+
"""
|
| 146 |
+
{
|
| 147 |
+
'ndcg_at_1': 0.32788,
|
| 148 |
+
'ndcg_at_3': 0.47534,
|
| 149 |
+
'ndcg_at_5': 0.52296,
|
| 150 |
+
'ndcg_at_10': 0.57505,
|
| 151 |
+
'ndcg_at_100': 0.6076,
|
| 152 |
+
'ndcg_at_1000': 0.60801,
|
| 153 |
+
'map_at_1': 0.32788,
|
| 154 |
+
'map_at_3': 0.43883,
|
| 155 |
+
'map_at_5': 0.46518,
|
| 156 |
+
'map_at_10': 0.48675,
|
| 157 |
+
'map_at_100': 0.49506,
|
| 158 |
+
'map_at_1000': 0.49509,
|
| 159 |
+
'recall_at_1': 0.32788,
|
| 160 |
+
'recall_at_3': 0.58108,
|
| 161 |
+
'recall_at_5': 0.69701,
|
| 162 |
+
'recall_at_10': 0.85775,
|
| 163 |
+
'recall_at_100': 0.9936,
|
| 164 |
+
'recall_at_1000': 0.99644,
|
| 165 |
+
'precision_at_1': 0.32788,
|
| 166 |
+
'precision_at_3': 0.19369,
|
| 167 |
+
'precision_at_5': 0.1394,
|
| 168 |
+
'precision_at_10': 0.08578,
|
| 169 |
+
'precision_at_100': 0.00994,
|
| 170 |
+
'precision_at_1000': 0.001,
|
| 171 |
+
'mrr_at_1': 0.33357,
|
| 172 |
+
'mrr_at_3': 0.44085,
|
| 173 |
+
'mrr_at_5': 0.46745,
|
| 174 |
+
'mrr_at_10': 0.4888,
|
| 175 |
+
'mrr_at_100': 0.49718,
|
| 176 |
+
'mrr_at_1000': 0.49721}
|
| 177 |
+
"""
|
llm2vec/examples/sts.py
ADDED
|
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import datasets
|
| 2 |
+
import numpy as np
|
| 3 |
+
from sklearn.metrics.pairwise import paired_cosine_distances
|
| 4 |
+
from scipy.stats import spearmanr
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
from llm2vec import LLM2Vec
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
dataset = "mteb/sts17-crosslingual-sts"
|
| 11 |
+
instruction = "Retrieve semantically similar text: "
|
| 12 |
+
|
| 13 |
+
dataset = datasets.load_dataset(dataset, "en-en")
|
| 14 |
+
|
| 15 |
+
min_score, max_score = 0, 5
|
| 16 |
+
normalize = lambda x: (x - min_score) / (max_score - min_score)
|
| 17 |
+
normalized_scores = list(map(normalize, dataset["test"]["score"]))
|
| 18 |
+
batch_size = 8
|
| 19 |
+
|
| 20 |
+
sentences1, sentences2 = dataset["test"]["sentence1"], dataset["test"]["sentence2"]
|
| 21 |
+
|
| 22 |
+
print("Loading model...")
|
| 23 |
+
model = LLM2Vec.from_pretrained(
|
| 24 |
+
"McGill-NLP/LLM2Vec-Mistral-7B-Instruct-v2-mntp",
|
| 25 |
+
peft_model_name_or_path="McGill-NLP/LLM2Vec-Mistral-7B-Instruct-v2-mntp-supervised",
|
| 26 |
+
device_map="cuda" if torch.cuda.is_available() else "cpu",
|
| 27 |
+
torch_dtype=torch.bfloat16,
|
| 28 |
+
)
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def append_instruction(instruction, sentences):
|
| 32 |
+
new_sentences = []
|
| 33 |
+
for s in sentences:
|
| 34 |
+
new_sentences.append([instruction, s, 0])
|
| 35 |
+
return new_sentences
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
print(f"Encoding {len(sentences1)} sentences1...")
|
| 39 |
+
sentences1 = append_instruction(instruction, sentences1)
|
| 40 |
+
embeddings1 = np.asarray(model.encode(sentences1, batch_size=batch_size))
|
| 41 |
+
|
| 42 |
+
print(f"Encoding {len(sentences2)} sentences2...")
|
| 43 |
+
sentences2 = append_instruction(instruction, sentences2)
|
| 44 |
+
embeddings2 = np.asarray(model.encode(sentences2, batch_size=batch_size))
|
| 45 |
+
|
| 46 |
+
print("Evaluating...")
|
| 47 |
+
cosine_scores = 1 - (paired_cosine_distances(embeddings1, embeddings2))
|
| 48 |
+
cosine_spearman, _ = spearmanr(normalized_scores, cosine_scores)
|
| 49 |
+
|
| 50 |
+
results = {
|
| 51 |
+
"cos_sim": {
|
| 52 |
+
"spearman": cosine_spearman,
|
| 53 |
+
}
|
| 54 |
+
}
|
| 55 |
+
|
| 56 |
+
print(results)
|
| 57 |
+
# {'cos_sim': {'spearman': 0.9021906216635642}}
|
llm2vec/experiments/mteb_eval.py
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import mteb
|
| 3 |
+
import json
|
| 4 |
+
|
| 5 |
+
if __name__ == "__main__":
|
| 6 |
+
parser = argparse.ArgumentParser()
|
| 7 |
+
parser.add_argument(
|
| 8 |
+
"--model_name",
|
| 9 |
+
type=str,
|
| 10 |
+
default="McGill-NLP/LLM2Vec-Meta-Llama-3-8B-Instruct-mntp-supervised",
|
| 11 |
+
)
|
| 12 |
+
parser.add_argument("--task_name", type=str, default="STS16")
|
| 13 |
+
parser.add_argument(
|
| 14 |
+
"--task_to_instructions_fp",
|
| 15 |
+
type=str,
|
| 16 |
+
default="test_configs/mteb/task_to_instructions.json",
|
| 17 |
+
)
|
| 18 |
+
parser.add_argument("--output_dir", type=str, default="results")
|
| 19 |
+
|
| 20 |
+
args = parser.parse_args()
|
| 21 |
+
|
| 22 |
+
model_kwargs = {}
|
| 23 |
+
if args.task_to_instructions_fp is not None:
|
| 24 |
+
with open(args.task_to_instructions_fp, "r") as f:
|
| 25 |
+
task_to_instructions = json.load(f)
|
| 26 |
+
model_kwargs["task_to_instructions"] = task_to_instructions
|
| 27 |
+
|
| 28 |
+
model = mteb.get_model(args.model_name, **model_kwargs)
|
| 29 |
+
tasks = mteb.get_tasks(tasks=[args.task_name])
|
| 30 |
+
evaluation = mteb.MTEB(tasks=tasks)
|
| 31 |
+
results = evaluation.run(model, output_folder=args.output_dir)
|
llm2vec/experiments/mteb_eval_custom.py
ADDED
|
@@ -0,0 +1,98 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
from typing import Any
|
| 3 |
+
import mteb
|
| 4 |
+
import json
|
| 5 |
+
import torch
|
| 6 |
+
|
| 7 |
+
import numpy as np
|
| 8 |
+
from mteb.models.instructions import task_to_instruction
|
| 9 |
+
from mteb.models.text_formatting_utils import corpus_to_texts
|
| 10 |
+
|
| 11 |
+
from llm2vec import LLM2Vec
|
| 12 |
+
|
| 13 |
+
def llm2vec_instruction(instruction):
|
| 14 |
+
if len(instruction) > 0 and instruction[-1] != ":":
|
| 15 |
+
instruction = instruction.strip(".") + ":"
|
| 16 |
+
return instruction
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class LLM2VecWrapper:
|
| 20 |
+
def __init__(self, model=None, task_to_instructions=None):
|
| 21 |
+
|
| 22 |
+
self.task_to_instructions = task_to_instructions
|
| 23 |
+
self.model = model
|
| 24 |
+
|
| 25 |
+
def encode(
|
| 26 |
+
self,
|
| 27 |
+
sentences: list[str],
|
| 28 |
+
*,
|
| 29 |
+
prompt_name: str = None,
|
| 30 |
+
**kwargs: Any, # noqa
|
| 31 |
+
) -> np.ndarray:
|
| 32 |
+
if prompt_name is not None:
|
| 33 |
+
instruction = (
|
| 34 |
+
self.task_to_instructions[prompt_name]
|
| 35 |
+
if self.task_to_instructions
|
| 36 |
+
and prompt_name in self.task_to_instructions
|
| 37 |
+
else llm2vec_instruction(task_to_instruction(prompt_name))
|
| 38 |
+
)
|
| 39 |
+
else:
|
| 40 |
+
instruction = ""
|
| 41 |
+
|
| 42 |
+
sentences = [[instruction, sentence] for sentence in sentences]
|
| 43 |
+
return self.model.encode(sentences, **kwargs)
|
| 44 |
+
|
| 45 |
+
def encode_corpus(
|
| 46 |
+
self,
|
| 47 |
+
corpus: list[dict[str, str]] | dict[str, list[str]] | list[str],
|
| 48 |
+
prompt_name: str = None,
|
| 49 |
+
**kwargs: Any,
|
| 50 |
+
) -> np.ndarray:
|
| 51 |
+
sentences = corpus_to_texts(corpus, sep=" ")
|
| 52 |
+
sentences = [["", sentence] for sentence in sentences]
|
| 53 |
+
if "request_qid" in kwargs:
|
| 54 |
+
kwargs.pop("request_qid")
|
| 55 |
+
return self.model.encode(sentences, **kwargs)
|
| 56 |
+
|
| 57 |
+
def encode_queries(self, queries: list[str], **kwargs: Any) -> np.ndarray:
|
| 58 |
+
return self.encode(queries, **kwargs)
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
if __name__ == "__main__":
|
| 62 |
+
parser = argparse.ArgumentParser()
|
| 63 |
+
parser.add_argument(
|
| 64 |
+
"--base_model_name_or_path",
|
| 65 |
+
type=str,
|
| 66 |
+
default="McGill-NLP/LLM2Vec-Meta-Llama-3-8B-Instruct-mntp",
|
| 67 |
+
)
|
| 68 |
+
parser.add_argument(
|
| 69 |
+
"--peft_model_name_or_path",
|
| 70 |
+
type=str,
|
| 71 |
+
default="McGill-NLP/LLM2Vec-Meta-Llama-3-8B-Instruct-mntp-supervised",
|
| 72 |
+
)
|
| 73 |
+
parser.add_argument("--task_name", type=str, default="STS16")
|
| 74 |
+
parser.add_argument(
|
| 75 |
+
"--task_to_instructions_fp",
|
| 76 |
+
type=str,
|
| 77 |
+
default="test_configs/mteb/task_to_instructions.json",
|
| 78 |
+
)
|
| 79 |
+
parser.add_argument("--output_dir", type=str, default="results")
|
| 80 |
+
|
| 81 |
+
args = parser.parse_args()
|
| 82 |
+
|
| 83 |
+
task_to_instructions = None
|
| 84 |
+
if args.task_to_instructions_fp is not None:
|
| 85 |
+
with open(args.task_to_instructions_fp, "r") as f:
|
| 86 |
+
task_to_instructions = json.load(f)
|
| 87 |
+
|
| 88 |
+
l2v_model = LLM2Vec.from_pretrained(
|
| 89 |
+
args.base_model_name_or_path,
|
| 90 |
+
peft_model_name_or_path=args.peft_model_name_or_path,
|
| 91 |
+
device_map="cuda" if torch.cuda.is_available() else "cpu",
|
| 92 |
+
torch_dtype=torch.bfloat16,
|
| 93 |
+
)
|
| 94 |
+
|
| 95 |
+
model = LLM2VecWrapper(model=l2v_model, task_to_instructions=task_to_instructions)
|
| 96 |
+
tasks = mteb.get_tasks(tasks=[args.task_name])
|
| 97 |
+
evaluation = mteb.MTEB(tasks=tasks)
|
| 98 |
+
results = evaluation.run(model, output_folder=args.output_dir)
|
llm2vec/experiments/run_mntp.py
ADDED
|
@@ -0,0 +1,997 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
# coding=utf-8
|
| 3 |
+
# Copyright 2020 The HuggingFace Team All rights reserved.
|
| 4 |
+
#
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
# you may not use this file except in compliance with the License.
|
| 7 |
+
# You may obtain a copy of the License at
|
| 8 |
+
#
|
| 9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
#
|
| 11 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
+
# See the License for the specific language governing permissions and
|
| 15 |
+
# limitations under the License.
|
| 16 |
+
"""
|
| 17 |
+
The script is adapted from https://github.com/huggingface/transformers/blob/51bcadc10a569847b93a30dbe3a077037ae63bad/examples/pytorch/language-modeling/run_mlm.py
|
| 18 |
+
"""
|
| 19 |
+
|
| 20 |
+
import logging
|
| 21 |
+
import math
|
| 22 |
+
import os
|
| 23 |
+
import sys
|
| 24 |
+
import warnings
|
| 25 |
+
from dataclasses import dataclass, field
|
| 26 |
+
from itertools import chain
|
| 27 |
+
from typing import Optional, Any, Tuple, List
|
| 28 |
+
import numpy as np
|
| 29 |
+
|
| 30 |
+
import datasets
|
| 31 |
+
import evaluate
|
| 32 |
+
from datasets import load_dataset
|
| 33 |
+
|
| 34 |
+
import torch
|
| 35 |
+
import transformers
|
| 36 |
+
from transformers import (
|
| 37 |
+
CONFIG_MAPPING,
|
| 38 |
+
MODEL_FOR_MASKED_LM_MAPPING,
|
| 39 |
+
AutoConfig,
|
| 40 |
+
AutoTokenizer,
|
| 41 |
+
DataCollatorForLanguageModeling,
|
| 42 |
+
HfArgumentParser,
|
| 43 |
+
Trainer,
|
| 44 |
+
TrainingArguments,
|
| 45 |
+
TrainerCallback,
|
| 46 |
+
is_torch_tpu_available,
|
| 47 |
+
set_seed,
|
| 48 |
+
)
|
| 49 |
+
from transformers.trainer_utils import get_last_checkpoint
|
| 50 |
+
from transformers.utils import send_example_telemetry
|
| 51 |
+
from transformers.utils.versions import require_version
|
| 52 |
+
|
| 53 |
+
from peft import LoraConfig, get_peft_model
|
| 54 |
+
|
| 55 |
+
from llm2vec.models import (
|
| 56 |
+
MistralBiForMNTP,
|
| 57 |
+
LlamaBiForMNTP,
|
| 58 |
+
GemmaBiForMNTP,
|
| 59 |
+
Qwen2BiForMNTP,
|
| 60 |
+
)
|
| 61 |
+
|
| 62 |
+
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
| 63 |
+
# check_min_version("4.38.0.dev0")
|
| 64 |
+
|
| 65 |
+
require_version(
|
| 66 |
+
"datasets>=1.8.0",
|
| 67 |
+
"To fix: pip install -r examples/pytorch/language-modeling/requirements.txt",
|
| 68 |
+
)
|
| 69 |
+
|
| 70 |
+
logger = logging.getLogger(__name__)
|
| 71 |
+
MODEL_CONFIG_CLASSES = list(MODEL_FOR_MASKED_LM_MAPPING.keys())
|
| 72 |
+
MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
def get_model_class(config):
|
| 76 |
+
config_class_name = config.__class__.__name__
|
| 77 |
+
if config_class_name == "MistralConfig":
|
| 78 |
+
return MistralBiForMNTP
|
| 79 |
+
elif config_class_name == "LlamaConfig":
|
| 80 |
+
return LlamaBiForMNTP
|
| 81 |
+
elif config_class_name == "GemmaConfig":
|
| 82 |
+
return GemmaBiForMNTP
|
| 83 |
+
elif config_class_name == "Qwen2Config":
|
| 84 |
+
return Qwen2BiForMNTP
|
| 85 |
+
else:
|
| 86 |
+
raise ValueError(f"Model class {config_class_name} not supported.")
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
def initialize_peft(
|
| 90 |
+
model,
|
| 91 |
+
lora_r: int = 8,
|
| 92 |
+
lora_alpha: int = 16,
|
| 93 |
+
lora_dropout: float = 0.05,
|
| 94 |
+
lora_modules: Optional[List[str]] = None,
|
| 95 |
+
):
|
| 96 |
+
if lora_modules is None and model.config.__class__.__name__ in [
|
| 97 |
+
"LlamaConfig",
|
| 98 |
+
"MistralConfig",
|
| 99 |
+
"GemmaConfig",
|
| 100 |
+
"Qwen2Config",
|
| 101 |
+
]:
|
| 102 |
+
lora_modules = [
|
| 103 |
+
"q_proj",
|
| 104 |
+
"v_proj",
|
| 105 |
+
"k_proj",
|
| 106 |
+
"o_proj",
|
| 107 |
+
"gate_proj",
|
| 108 |
+
"up_proj",
|
| 109 |
+
"down_proj",
|
| 110 |
+
]
|
| 111 |
+
elif lora_modules is None:
|
| 112 |
+
raise ValueError("lora_modules must be specified for this model.")
|
| 113 |
+
|
| 114 |
+
config = LoraConfig(
|
| 115 |
+
r=lora_r,
|
| 116 |
+
lora_alpha=lora_alpha,
|
| 117 |
+
target_modules=lora_modules,
|
| 118 |
+
lora_dropout=lora_dropout,
|
| 119 |
+
bias="none",
|
| 120 |
+
task_type=None,
|
| 121 |
+
)
|
| 122 |
+
|
| 123 |
+
model = get_peft_model(model, config)
|
| 124 |
+
print(f"Model's Lora trainable parameters:")
|
| 125 |
+
model.print_trainable_parameters()
|
| 126 |
+
return model
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
@dataclass
|
| 130 |
+
class ModelArguments:
|
| 131 |
+
"""
|
| 132 |
+
Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch.
|
| 133 |
+
"""
|
| 134 |
+
|
| 135 |
+
model_name_or_path: Optional[str] = field(
|
| 136 |
+
default=None,
|
| 137 |
+
metadata={
|
| 138 |
+
"help": (
|
| 139 |
+
"The model checkpoint for weights initialization. Don't set if you want to train a model from scratch."
|
| 140 |
+
)
|
| 141 |
+
},
|
| 142 |
+
)
|
| 143 |
+
model_type: Optional[str] = field(
|
| 144 |
+
default=None,
|
| 145 |
+
metadata={
|
| 146 |
+
"help": "If training from scratch, pass a model type from the list: "
|
| 147 |
+
+ ", ".join(MODEL_TYPES)
|
| 148 |
+
},
|
| 149 |
+
)
|
| 150 |
+
config_overrides: Optional[str] = field(
|
| 151 |
+
default=None,
|
| 152 |
+
metadata={
|
| 153 |
+
"help": (
|
| 154 |
+
"Override some existing default config settings when a model is trained from scratch. Example: "
|
| 155 |
+
"n_embd=10,resid_pdrop=0.2,scale_attn_weights=false,summary_type=cls_index"
|
| 156 |
+
)
|
| 157 |
+
},
|
| 158 |
+
)
|
| 159 |
+
config_name: Optional[str] = field(
|
| 160 |
+
default=None,
|
| 161 |
+
metadata={
|
| 162 |
+
"help": "Pretrained config name or path if not the same as model_name"
|
| 163 |
+
},
|
| 164 |
+
)
|
| 165 |
+
tokenizer_name: Optional[str] = field(
|
| 166 |
+
default=None,
|
| 167 |
+
metadata={
|
| 168 |
+
"help": "Pretrained tokenizer name or path if not the same as model_name"
|
| 169 |
+
},
|
| 170 |
+
)
|
| 171 |
+
cache_dir: Optional[str] = field(
|
| 172 |
+
default=None,
|
| 173 |
+
metadata={
|
| 174 |
+
"help": "Where do you want to store the pretrained models downloaded from huggingface.co"
|
| 175 |
+
},
|
| 176 |
+
)
|
| 177 |
+
use_fast_tokenizer: bool = field(
|
| 178 |
+
default=True,
|
| 179 |
+
metadata={
|
| 180 |
+
"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."
|
| 181 |
+
},
|
| 182 |
+
)
|
| 183 |
+
model_revision: str = field(
|
| 184 |
+
default="main",
|
| 185 |
+
metadata={
|
| 186 |
+
"help": "The specific model version to use (can be a branch name, tag name or commit id)."
|
| 187 |
+
},
|
| 188 |
+
)
|
| 189 |
+
token: str = field(
|
| 190 |
+
default=None,
|
| 191 |
+
metadata={
|
| 192 |
+
"help": (
|
| 193 |
+
"The token to use as HTTP bearer authorization for remote files. If not specified, will use the token "
|
| 194 |
+
"generated when running `huggingface-cli login` (stored in `~/.huggingface`)."
|
| 195 |
+
)
|
| 196 |
+
},
|
| 197 |
+
)
|
| 198 |
+
use_auth_token: bool = field(
|
| 199 |
+
default=None,
|
| 200 |
+
metadata={
|
| 201 |
+
"help": "The `use_auth_token` argument is deprecated and will be removed in v4.34. Please use `token` instead."
|
| 202 |
+
},
|
| 203 |
+
)
|
| 204 |
+
trust_remote_code: bool = field(
|
| 205 |
+
default=False,
|
| 206 |
+
metadata={
|
| 207 |
+
"help": (
|
| 208 |
+
"Whether or not to allow for custom models defined on the Hub in their own modeling files. This option "
|
| 209 |
+
"should only be set to `True` for repositories you trust and in which you have read the code, as it will "
|
| 210 |
+
"execute code present on the Hub on your local machine."
|
| 211 |
+
)
|
| 212 |
+
},
|
| 213 |
+
)
|
| 214 |
+
torch_dtype: Optional[str] = field(
|
| 215 |
+
default=None,
|
| 216 |
+
metadata={
|
| 217 |
+
"help": (
|
| 218 |
+
"Override the default `torch.dtype` and load the model under this dtype. If `auto` is passed, the "
|
| 219 |
+
"dtype will be automatically derived from the model's weights."
|
| 220 |
+
),
|
| 221 |
+
"choices": ["auto", "bfloat16", "float16", "float32"],
|
| 222 |
+
},
|
| 223 |
+
)
|
| 224 |
+
attn_implementation: Optional[str] = field(
|
| 225 |
+
default="sdpa",
|
| 226 |
+
metadata={
|
| 227 |
+
"help": ("The attention implementation to use in the model."),
|
| 228 |
+
"choices": ["eager", "sdpa", "flash_attention_2"],
|
| 229 |
+
},
|
| 230 |
+
)
|
| 231 |
+
low_cpu_mem_usage: bool = field(
|
| 232 |
+
default=False,
|
| 233 |
+
metadata={
|
| 234 |
+
"help": (
|
| 235 |
+
"It is an option to create the model as an empty shell, then only materialize its parameters when the pretrained weights are loaded. "
|
| 236 |
+
"set True will benefit LLM loading time and RAM consumption."
|
| 237 |
+
)
|
| 238 |
+
},
|
| 239 |
+
)
|
| 240 |
+
|
| 241 |
+
def __post_init__(self):
|
| 242 |
+
if self.config_overrides is not None and (
|
| 243 |
+
self.config_name is not None or self.model_name_or_path is not None
|
| 244 |
+
):
|
| 245 |
+
raise ValueError(
|
| 246 |
+
"--config_overrides can't be used in combination with --config_name or --model_name_or_path"
|
| 247 |
+
)
|
| 248 |
+
|
| 249 |
+
|
| 250 |
+
@dataclass
|
| 251 |
+
class DataTrainingArguments:
|
| 252 |
+
"""
|
| 253 |
+
Arguments pertaining to what data we are going to input our model for training and eval.
|
| 254 |
+
"""
|
| 255 |
+
|
| 256 |
+
dataset_name: Optional[str] = field(
|
| 257 |
+
default=None,
|
| 258 |
+
metadata={"help": "The name of the dataset to use (via the datasets library)."},
|
| 259 |
+
)
|
| 260 |
+
dataset_config_name: Optional[str] = field(
|
| 261 |
+
default=None,
|
| 262 |
+
metadata={
|
| 263 |
+
"help": "The configuration name of the dataset to use (via the datasets library)."
|
| 264 |
+
},
|
| 265 |
+
)
|
| 266 |
+
train_file: Optional[str] = field(
|
| 267 |
+
default=None, metadata={"help": "The input training data file (a text file)."}
|
| 268 |
+
)
|
| 269 |
+
validation_file: Optional[str] = field(
|
| 270 |
+
default=None,
|
| 271 |
+
metadata={
|
| 272 |
+
"help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."
|
| 273 |
+
},
|
| 274 |
+
)
|
| 275 |
+
overwrite_cache: bool = field(
|
| 276 |
+
default=True,
|
| 277 |
+
metadata={"help": "Overwrite the cached training and evaluation sets"},
|
| 278 |
+
)
|
| 279 |
+
validation_split_percentage: Optional[int] = field(
|
| 280 |
+
default=5,
|
| 281 |
+
metadata={
|
| 282 |
+
"help": "The percentage of the train set used as validation set in case there's no validation split"
|
| 283 |
+
},
|
| 284 |
+
)
|
| 285 |
+
max_seq_length: Optional[int] = field(
|
| 286 |
+
default=None,
|
| 287 |
+
metadata={
|
| 288 |
+
"help": (
|
| 289 |
+
"The maximum total input sequence length after tokenization. Sequences longer "
|
| 290 |
+
"than this will be truncated."
|
| 291 |
+
)
|
| 292 |
+
},
|
| 293 |
+
)
|
| 294 |
+
preprocessing_num_workers: Optional[int] = field(
|
| 295 |
+
default=None,
|
| 296 |
+
metadata={"help": "The number of processes to use for the preprocessing."},
|
| 297 |
+
)
|
| 298 |
+
mlm_probability: float = field(
|
| 299 |
+
default=0.15,
|
| 300 |
+
metadata={"help": "Ratio of tokens to mask for masked language modeling loss"},
|
| 301 |
+
)
|
| 302 |
+
line_by_line: bool = field(
|
| 303 |
+
default=False,
|
| 304 |
+
metadata={
|
| 305 |
+
"help": "Whether distinct lines of text in the dataset are to be handled as distinct sequences."
|
| 306 |
+
},
|
| 307 |
+
)
|
| 308 |
+
pad_to_max_length: bool = field(
|
| 309 |
+
default=False,
|
| 310 |
+
metadata={
|
| 311 |
+
"help": (
|
| 312 |
+
"Whether to pad all samples to `max_seq_length`. "
|
| 313 |
+
"If False, will pad the samples dynamically when batching to the maximum length in the batch."
|
| 314 |
+
)
|
| 315 |
+
},
|
| 316 |
+
)
|
| 317 |
+
max_train_samples: Optional[int] = field(
|
| 318 |
+
default=None,
|
| 319 |
+
metadata={
|
| 320 |
+
"help": (
|
| 321 |
+
"For debugging purposes or quicker training, truncate the number of training examples to this "
|
| 322 |
+
"value if set."
|
| 323 |
+
)
|
| 324 |
+
},
|
| 325 |
+
)
|
| 326 |
+
max_eval_samples: Optional[int] = field(
|
| 327 |
+
default=None,
|
| 328 |
+
metadata={
|
| 329 |
+
"help": (
|
| 330 |
+
"For debugging purposes or quicker training, truncate the number of evaluation examples to this "
|
| 331 |
+
"value if set."
|
| 332 |
+
)
|
| 333 |
+
},
|
| 334 |
+
)
|
| 335 |
+
streaming: bool = field(default=False, metadata={"help": "Enable streaming mode"})
|
| 336 |
+
|
| 337 |
+
def __post_init__(self):
|
| 338 |
+
if self.streaming:
|
| 339 |
+
require_version(
|
| 340 |
+
"datasets>=2.0.0", "The streaming feature requires `datasets>=2.0.0`"
|
| 341 |
+
)
|
| 342 |
+
|
| 343 |
+
if (
|
| 344 |
+
self.dataset_name is None
|
| 345 |
+
and self.train_file is None
|
| 346 |
+
and self.validation_file is None
|
| 347 |
+
):
|
| 348 |
+
raise ValueError(
|
| 349 |
+
"Need either a dataset name or a training/validation file."
|
| 350 |
+
)
|
| 351 |
+
else:
|
| 352 |
+
if self.train_file is not None:
|
| 353 |
+
extension = self.train_file.split(".")[-1]
|
| 354 |
+
if extension not in ["csv", "json", "txt"]:
|
| 355 |
+
raise ValueError(
|
| 356 |
+
"`train_file` should be a csv, a json or a txt file."
|
| 357 |
+
)
|
| 358 |
+
if self.validation_file is not None:
|
| 359 |
+
extension = self.validation_file.split(".")[-1]
|
| 360 |
+
if extension not in ["csv", "json", "txt"]:
|
| 361 |
+
raise ValueError(
|
| 362 |
+
"`validation_file` should be a csv, a json or a txt file."
|
| 363 |
+
)
|
| 364 |
+
|
| 365 |
+
|
| 366 |
+
# add more arguments
|
| 367 |
+
@dataclass
|
| 368 |
+
class CustomArguments:
|
| 369 |
+
"""
|
| 370 |
+
Custom arguments for the script
|
| 371 |
+
"""
|
| 372 |
+
|
| 373 |
+
lora_dropout: float = field(
|
| 374 |
+
default=0.05, metadata={"help": "The dropout rate for lora"}
|
| 375 |
+
)
|
| 376 |
+
|
| 377 |
+
lora_r: int = field(default=8, metadata={"help": "The r value for lora"})
|
| 378 |
+
|
| 379 |
+
mask_token_type: str = field(
|
| 380 |
+
default="blank",
|
| 381 |
+
metadata={"help": "The type of mask token. Options: blank, eos, mask"},
|
| 382 |
+
)
|
| 383 |
+
|
| 384 |
+
stop_after_n_steps: int = field(
|
| 385 |
+
default=10000, metadata={"help": "Stop training after n steps"}
|
| 386 |
+
)
|
| 387 |
+
|
| 388 |
+
data_collator_type: str = field(
|
| 389 |
+
default="default",
|
| 390 |
+
metadata={"help": "The type of data collator. Options: default, all_mask"},
|
| 391 |
+
)
|
| 392 |
+
|
| 393 |
+
|
| 394 |
+
class DataCollatorForLanguageModelingWithFullMasking(DataCollatorForLanguageModeling):
|
| 395 |
+
def torch_mask_tokens(
|
| 396 |
+
self,
|
| 397 |
+
inputs: Any,
|
| 398 |
+
special_tokens_mask: Optional[Any] = None,
|
| 399 |
+
) -> Tuple[Any, Any]:
|
| 400 |
+
"""
|
| 401 |
+
Prepare masked tokens inputs/labels for masked language modeling: 100% MASK, 0% random, 0% original.
|
| 402 |
+
"""
|
| 403 |
+
import torch
|
| 404 |
+
|
| 405 |
+
labels = inputs.clone()
|
| 406 |
+
# We sample a few tokens in each sequence for MLM training (with probability `self.mlm_probability`)
|
| 407 |
+
probability_matrix = torch.full(labels.shape, self.mlm_probability)
|
| 408 |
+
if special_tokens_mask is None:
|
| 409 |
+
special_tokens_mask = [
|
| 410 |
+
self.tokenizer.get_special_tokens_mask(
|
| 411 |
+
val, already_has_special_tokens=True
|
| 412 |
+
)
|
| 413 |
+
for val in labels.tolist()
|
| 414 |
+
]
|
| 415 |
+
special_tokens_mask = torch.tensor(special_tokens_mask, dtype=torch.bool)
|
| 416 |
+
else:
|
| 417 |
+
special_tokens_mask = special_tokens_mask.bool()
|
| 418 |
+
|
| 419 |
+
probability_matrix.masked_fill_(special_tokens_mask, value=0.0)
|
| 420 |
+
masked_indices = torch.bernoulli(probability_matrix).bool()
|
| 421 |
+
labels[~masked_indices] = -100 # We only compute loss on masked tokens
|
| 422 |
+
|
| 423 |
+
# 100% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
|
| 424 |
+
inputs[masked_indices] = self.tokenizer.convert_tokens_to_ids(
|
| 425 |
+
self.tokenizer.mask_token
|
| 426 |
+
)
|
| 427 |
+
|
| 428 |
+
return inputs, labels
|
| 429 |
+
|
| 430 |
+
|
| 431 |
+
class StopTrainingCallback(TrainerCallback):
|
| 432 |
+
def __init__(self, stop_after_n_steps: int):
|
| 433 |
+
self.stop_after_n_steps = stop_after_n_steps
|
| 434 |
+
|
| 435 |
+
def on_step_end(self, args, state, control, **kwargs):
|
| 436 |
+
if state.global_step >= self.stop_after_n_steps:
|
| 437 |
+
control.should_training_stop = True
|
| 438 |
+
|
| 439 |
+
|
| 440 |
+
class MNTPTrainer(Trainer):
|
| 441 |
+
def __init__(self, *args, **kwargs):
|
| 442 |
+
super().__init__(*args, **kwargs)
|
| 443 |
+
self.label_names = ["labels"]
|
| 444 |
+
|
| 445 |
+
def _remove_unused_columns(
|
| 446 |
+
self, dataset: "datasets.Dataset", description: Optional[str] = None
|
| 447 |
+
):
|
| 448 |
+
return dataset
|
| 449 |
+
|
| 450 |
+
# We need a custom save function as we have to save the inner model
|
| 451 |
+
def _save(self, output_dir: Optional[str] = None, state_dict=None):
|
| 452 |
+
# If we are executing this function, we are the process zero, so we don't check for that.
|
| 453 |
+
output_dir = output_dir if output_dir is not None else self.args.output_dir
|
| 454 |
+
os.makedirs(output_dir, exist_ok=True)
|
| 455 |
+
logger.info(f"Saving model checkpoint to {output_dir}")
|
| 456 |
+
|
| 457 |
+
# model organization is MODEL_TYPEBiForMNTP.model -> MODEL_TYPELBiModel, we have to save the inner model, handled by save_peft_model function of the outer model
|
| 458 |
+
self.model.save_peft_model(output_dir)
|
| 459 |
+
self.tokenizer.save_pretrained(output_dir)
|
| 460 |
+
|
| 461 |
+
# Good practice: save your training arguments together with the trained model
|
| 462 |
+
torch.save(self.args, os.path.join(output_dir, "training_args.bin"))
|
| 463 |
+
|
| 464 |
+
|
| 465 |
+
def main():
|
| 466 |
+
# See all possible arguments in src/transformers/training_args.py
|
| 467 |
+
# or by passing the --help flag to this script.
|
| 468 |
+
# We now keep distinct sets of args, for a cleaner separation of concerns.
|
| 469 |
+
|
| 470 |
+
parser = HfArgumentParser(
|
| 471 |
+
(ModelArguments, DataTrainingArguments, TrainingArguments, CustomArguments)
|
| 472 |
+
)
|
| 473 |
+
if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
|
| 474 |
+
# If we pass only one argument to the script and it's the path to a json file,
|
| 475 |
+
# let's parse it to get our arguments.
|
| 476 |
+
model_args, data_args, training_args, custom_args = parser.parse_json_file(
|
| 477 |
+
json_file=os.path.abspath(sys.argv[1])
|
| 478 |
+
)
|
| 479 |
+
else:
|
| 480 |
+
(
|
| 481 |
+
model_args,
|
| 482 |
+
data_args,
|
| 483 |
+
training_args,
|
| 484 |
+
custom_args,
|
| 485 |
+
) = parser.parse_args_into_dataclasses()
|
| 486 |
+
|
| 487 |
+
if training_args.gradient_checkpointing:
|
| 488 |
+
training_args.gradient_checkpointing_kwargs = {"use_reentrant": False}
|
| 489 |
+
|
| 490 |
+
if model_args.use_auth_token is not None:
|
| 491 |
+
warnings.warn(
|
| 492 |
+
"The `use_auth_token` argument is deprecated and will be removed in v4.34. Please use `token` instead.",
|
| 493 |
+
FutureWarning,
|
| 494 |
+
)
|
| 495 |
+
if model_args.token is not None:
|
| 496 |
+
raise ValueError(
|
| 497 |
+
"`token` and `use_auth_token` are both specified. Please set only the argument `token`."
|
| 498 |
+
)
|
| 499 |
+
model_args.token = model_args.use_auth_token
|
| 500 |
+
|
| 501 |
+
# Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
|
| 502 |
+
# information sent is the one passed as arguments along with your Python/PyTorch versions.
|
| 503 |
+
send_example_telemetry("run_mlm", model_args, data_args)
|
| 504 |
+
|
| 505 |
+
# Setup logging
|
| 506 |
+
logging.basicConfig(
|
| 507 |
+
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
| 508 |
+
datefmt="%m/%d/%Y %H:%M:%S",
|
| 509 |
+
handlers=[logging.StreamHandler(sys.stdout)],
|
| 510 |
+
)
|
| 511 |
+
|
| 512 |
+
if training_args.should_log:
|
| 513 |
+
# The default of training_args.log_level is passive, so we set log level at info here to have that default.
|
| 514 |
+
transformers.utils.logging.set_verbosity_info()
|
| 515 |
+
|
| 516 |
+
log_level = training_args.get_process_log_level()
|
| 517 |
+
logger.setLevel(log_level)
|
| 518 |
+
datasets.utils.logging.set_verbosity(log_level)
|
| 519 |
+
transformers.utils.logging.set_verbosity(log_level)
|
| 520 |
+
transformers.utils.logging.enable_default_handler()
|
| 521 |
+
transformers.utils.logging.enable_explicit_format()
|
| 522 |
+
|
| 523 |
+
# Log on each process the small summary:
|
| 524 |
+
logger.warning(
|
| 525 |
+
f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}, "
|
| 526 |
+
+ f"distributed training: {training_args.parallel_mode.value == 'distributed'}, 16-bits training: {training_args.fp16}"
|
| 527 |
+
)
|
| 528 |
+
# Set the verbosity to info of the Transformers logger (on main process only):
|
| 529 |
+
logger.info(f"Training/evaluation parameters {training_args}")
|
| 530 |
+
|
| 531 |
+
# Detecting last checkpoint.
|
| 532 |
+
last_checkpoint = None
|
| 533 |
+
if (
|
| 534 |
+
os.path.isdir(training_args.output_dir)
|
| 535 |
+
and training_args.do_train
|
| 536 |
+
and not training_args.overwrite_output_dir
|
| 537 |
+
):
|
| 538 |
+
last_checkpoint = get_last_checkpoint(training_args.output_dir)
|
| 539 |
+
if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
|
| 540 |
+
raise ValueError(
|
| 541 |
+
f"Output directory ({training_args.output_dir}) already exists and is not empty. "
|
| 542 |
+
"Use --overwrite_output_dir to overcome."
|
| 543 |
+
)
|
| 544 |
+
elif (
|
| 545 |
+
last_checkpoint is not None and training_args.resume_from_checkpoint is None
|
| 546 |
+
):
|
| 547 |
+
logger.info(
|
| 548 |
+
f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
|
| 549 |
+
"the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
|
| 550 |
+
)
|
| 551 |
+
|
| 552 |
+
# Set seed before initializing model.
|
| 553 |
+
set_seed(training_args.seed)
|
| 554 |
+
|
| 555 |
+
# Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below)
|
| 556 |
+
# or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
|
| 557 |
+
# (the dataset will be downloaded automatically from the datasets Hub
|
| 558 |
+
#
|
| 559 |
+
# For CSV/JSON files, this script will use the column called 'text' or the first column. You can easily tweak this
|
| 560 |
+
# behavior (see below)
|
| 561 |
+
#
|
| 562 |
+
# In distributed training, the load_dataset function guarantee that only one local process can concurrently
|
| 563 |
+
# download the dataset.
|
| 564 |
+
if data_args.dataset_name is not None:
|
| 565 |
+
# Downloading and loading a dataset from the hub.
|
| 566 |
+
raw_datasets = load_dataset(
|
| 567 |
+
data_args.dataset_name,
|
| 568 |
+
data_args.dataset_config_name,
|
| 569 |
+
cache_dir=model_args.cache_dir,
|
| 570 |
+
token=model_args.token,
|
| 571 |
+
streaming=data_args.streaming,
|
| 572 |
+
)
|
| 573 |
+
if "validation" not in raw_datasets.keys():
|
| 574 |
+
raw_datasets["validation"] = load_dataset(
|
| 575 |
+
data_args.dataset_name,
|
| 576 |
+
data_args.dataset_config_name,
|
| 577 |
+
split=f"train[:{data_args.validation_split_percentage}%]",
|
| 578 |
+
cache_dir=model_args.cache_dir,
|
| 579 |
+
token=model_args.token,
|
| 580 |
+
streaming=data_args.streaming,
|
| 581 |
+
)
|
| 582 |
+
raw_datasets["train"] = load_dataset(
|
| 583 |
+
data_args.dataset_name,
|
| 584 |
+
data_args.dataset_config_name,
|
| 585 |
+
split=f"train[{data_args.validation_split_percentage}%:]",
|
| 586 |
+
cache_dir=model_args.cache_dir,
|
| 587 |
+
token=model_args.token,
|
| 588 |
+
streaming=data_args.streaming,
|
| 589 |
+
)
|
| 590 |
+
else:
|
| 591 |
+
data_files = {}
|
| 592 |
+
if data_args.train_file is not None:
|
| 593 |
+
data_files["train"] = data_args.train_file
|
| 594 |
+
extension = data_args.train_file.split(".")[-1]
|
| 595 |
+
if data_args.validation_file is not None:
|
| 596 |
+
data_files["validation"] = data_args.validation_file
|
| 597 |
+
extension = data_args.validation_file.split(".")[-1]
|
| 598 |
+
if extension == "txt":
|
| 599 |
+
extension = "text"
|
| 600 |
+
raw_datasets = load_dataset(
|
| 601 |
+
extension,
|
| 602 |
+
data_files=data_files,
|
| 603 |
+
cache_dir=model_args.cache_dir,
|
| 604 |
+
token=model_args.token,
|
| 605 |
+
)
|
| 606 |
+
|
| 607 |
+
# If no validation data is there, validation_split_percentage will be used to divide the dataset.
|
| 608 |
+
if "validation" not in raw_datasets.keys():
|
| 609 |
+
raw_datasets["validation"] = load_dataset(
|
| 610 |
+
extension,
|
| 611 |
+
data_files=data_files,
|
| 612 |
+
split=f"train[:{data_args.validation_split_percentage}%]",
|
| 613 |
+
cache_dir=model_args.cache_dir,
|
| 614 |
+
token=model_args.token,
|
| 615 |
+
)
|
| 616 |
+
raw_datasets["train"] = load_dataset(
|
| 617 |
+
extension,
|
| 618 |
+
data_files=data_files,
|
| 619 |
+
split=f"train[{data_args.validation_split_percentage}%:]",
|
| 620 |
+
cache_dir=model_args.cache_dir,
|
| 621 |
+
token=model_args.token,
|
| 622 |
+
)
|
| 623 |
+
|
| 624 |
+
# See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
|
| 625 |
+
# https://huggingface.co/docs/datasets/loading_datasets.
|
| 626 |
+
|
| 627 |
+
# Load pretrained model and tokenizer
|
| 628 |
+
#
|
| 629 |
+
# Distributed training:
|
| 630 |
+
# The .from_pretrained methods guarantee that only one local process can concurrently
|
| 631 |
+
# download model & vocab.
|
| 632 |
+
config_kwargs = {
|
| 633 |
+
"cache_dir": model_args.cache_dir,
|
| 634 |
+
"revision": model_args.model_revision,
|
| 635 |
+
"token": model_args.token,
|
| 636 |
+
"trust_remote_code": model_args.trust_remote_code,
|
| 637 |
+
}
|
| 638 |
+
if model_args.config_name:
|
| 639 |
+
config = AutoConfig.from_pretrained(model_args.config_name, **config_kwargs)
|
| 640 |
+
elif model_args.model_name_or_path:
|
| 641 |
+
config = AutoConfig.from_pretrained(
|
| 642 |
+
model_args.model_name_or_path, **config_kwargs
|
| 643 |
+
)
|
| 644 |
+
else:
|
| 645 |
+
config = CONFIG_MAPPING[model_args.model_type]()
|
| 646 |
+
logger.warning("You are instantiating a new config instance from scratch.")
|
| 647 |
+
if model_args.config_overrides is not None:
|
| 648 |
+
logger.info(f"Overriding config: {model_args.config_overrides}")
|
| 649 |
+
config.update_from_string(model_args.config_overrides)
|
| 650 |
+
logger.info(f"New config: {config}")
|
| 651 |
+
|
| 652 |
+
tokenizer_kwargs = {
|
| 653 |
+
"cache_dir": model_args.cache_dir,
|
| 654 |
+
"use_fast": model_args.use_fast_tokenizer,
|
| 655 |
+
"revision": model_args.model_revision,
|
| 656 |
+
"token": model_args.token,
|
| 657 |
+
"trust_remote_code": model_args.trust_remote_code,
|
| 658 |
+
}
|
| 659 |
+
if model_args.tokenizer_name:
|
| 660 |
+
tokenizer = AutoTokenizer.from_pretrained(
|
| 661 |
+
model_args.tokenizer_name, **tokenizer_kwargs
|
| 662 |
+
)
|
| 663 |
+
elif model_args.model_name_or_path:
|
| 664 |
+
tokenizer = AutoTokenizer.from_pretrained(
|
| 665 |
+
model_args.model_name_or_path, **tokenizer_kwargs
|
| 666 |
+
)
|
| 667 |
+
else:
|
| 668 |
+
raise ValueError(
|
| 669 |
+
"You are instantiating a new tokenizer from scratch. This is not supported by this script. "
|
| 670 |
+
"You can do it from another script, save it, and load it from here, using --tokenizer_name."
|
| 671 |
+
)
|
| 672 |
+
|
| 673 |
+
# blank, eos, mask
|
| 674 |
+
if tokenizer.mask_token is None:
|
| 675 |
+
if custom_args.mask_token_type == "blank":
|
| 676 |
+
tokenizer.mask_token = "_"
|
| 677 |
+
elif custom_args.mask_token_type == "eos":
|
| 678 |
+
tokenizer.mask_token = tokenizer.eos_token
|
| 679 |
+
elif custom_args.mask_token_type == "mask":
|
| 680 |
+
tokenizer.add_tokens(["<mask>"])
|
| 681 |
+
tokenizer.mask_token = "<mask>"
|
| 682 |
+
else:
|
| 683 |
+
raise ValueError(
|
| 684 |
+
f"mask_token_type {custom_args.mask_token_type} is not supported."
|
| 685 |
+
)
|
| 686 |
+
|
| 687 |
+
if tokenizer.pad_token is None:
|
| 688 |
+
tokenizer.pad_token = tokenizer.eos_token
|
| 689 |
+
|
| 690 |
+
# Loading bidirectional model using LLM2Vec package
|
| 691 |
+
model_class = get_model_class(config)
|
| 692 |
+
torch_dtype = (
|
| 693 |
+
model_args.torch_dtype
|
| 694 |
+
if model_args.torch_dtype in ["auto", None]
|
| 695 |
+
else getattr(torch, model_args.torch_dtype)
|
| 696 |
+
)
|
| 697 |
+
model = model_class.from_pretrained(
|
| 698 |
+
model_args.model_name_or_path,
|
| 699 |
+
from_tf=bool(".ckpt" in model_args.model_name_or_path),
|
| 700 |
+
config=config,
|
| 701 |
+
cache_dir=model_args.cache_dir,
|
| 702 |
+
revision=model_args.model_revision,
|
| 703 |
+
token=model_args.token,
|
| 704 |
+
trust_remote_code=model_args.trust_remote_code,
|
| 705 |
+
torch_dtype=torch_dtype,
|
| 706 |
+
low_cpu_mem_usage=model_args.low_cpu_mem_usage,
|
| 707 |
+
attn_implementation=model_args.attn_implementation,
|
| 708 |
+
)
|
| 709 |
+
|
| 710 |
+
# model organization is MODEL_TYPEBiForMNTP.model -> MODEL_TYPELBiModel, we have to apply PEFT to the inner model
|
| 711 |
+
model.model = initialize_peft(
|
| 712 |
+
model.model,
|
| 713 |
+
lora_r=custom_args.lora_r,
|
| 714 |
+
lora_alpha=2 * custom_args.lora_r,
|
| 715 |
+
lora_dropout=custom_args.lora_dropout,
|
| 716 |
+
)
|
| 717 |
+
|
| 718 |
+
# We resize the embeddings only when necessary to avoid index errors. If you are creating a model from scratch
|
| 719 |
+
# on a small vocab and want a smaller embedding size, remove this test.
|
| 720 |
+
embedding_size = model.get_input_embeddings().weight.shape[0]
|
| 721 |
+
if len(tokenizer) > embedding_size:
|
| 722 |
+
model.resize_token_embeddings(len(tokenizer))
|
| 723 |
+
|
| 724 |
+
# Preprocessing the datasets.
|
| 725 |
+
# First we tokenize all the texts.
|
| 726 |
+
if training_args.do_train:
|
| 727 |
+
column_names = list(raw_datasets["train"].features)
|
| 728 |
+
else:
|
| 729 |
+
column_names = list(raw_datasets["validation"].features)
|
| 730 |
+
text_column_name = "text" if "text" in column_names else column_names[0]
|
| 731 |
+
|
| 732 |
+
if data_args.max_seq_length is None:
|
| 733 |
+
max_seq_length = tokenizer.model_max_length
|
| 734 |
+
if max_seq_length > 1024:
|
| 735 |
+
logger.warning(
|
| 736 |
+
"The chosen tokenizer supports a `model_max_length` that is longer than the default `block_size` value"
|
| 737 |
+
" of 1024. If you would like to use a longer `block_size` up to `tokenizer.model_max_length` you can"
|
| 738 |
+
" override this default with `--block_size xxx`."
|
| 739 |
+
)
|
| 740 |
+
max_seq_length = 1024
|
| 741 |
+
else:
|
| 742 |
+
if data_args.max_seq_length > tokenizer.model_max_length:
|
| 743 |
+
logger.warning(
|
| 744 |
+
f"The max_seq_length passed ({data_args.max_seq_length}) is larger than the maximum length for the "
|
| 745 |
+
f"model ({tokenizer.model_max_length}). Using max_seq_length={tokenizer.model_max_length}."
|
| 746 |
+
)
|
| 747 |
+
max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length)
|
| 748 |
+
|
| 749 |
+
if data_args.line_by_line:
|
| 750 |
+
# When using line_by_line, we just tokenize each nonempty line.
|
| 751 |
+
padding = "max_length" if data_args.pad_to_max_length else False
|
| 752 |
+
|
| 753 |
+
def tokenize_function(examples):
|
| 754 |
+
# Remove empty lines
|
| 755 |
+
examples[text_column_name] = [
|
| 756 |
+
line
|
| 757 |
+
for line in examples[text_column_name]
|
| 758 |
+
if len(line) > 0 and not line.isspace()
|
| 759 |
+
]
|
| 760 |
+
return tokenizer(
|
| 761 |
+
examples[text_column_name],
|
| 762 |
+
padding=padding,
|
| 763 |
+
truncation=True,
|
| 764 |
+
max_length=max_seq_length,
|
| 765 |
+
# We use this option because DataCollatorForLanguageModeling (see below) is more efficient when it
|
| 766 |
+
# receives the `special_tokens_mask`.
|
| 767 |
+
return_special_tokens_mask=True,
|
| 768 |
+
)
|
| 769 |
+
|
| 770 |
+
with training_args.main_process_first(desc="dataset map tokenization"):
|
| 771 |
+
if not data_args.streaming:
|
| 772 |
+
tokenized_datasets = raw_datasets.map(
|
| 773 |
+
tokenize_function,
|
| 774 |
+
batched=True,
|
| 775 |
+
num_proc=data_args.preprocessing_num_workers,
|
| 776 |
+
remove_columns=[text_column_name],
|
| 777 |
+
load_from_cache_file=not data_args.overwrite_cache,
|
| 778 |
+
desc="Running tokenizer on dataset line_by_line",
|
| 779 |
+
)
|
| 780 |
+
else:
|
| 781 |
+
tokenized_datasets = raw_datasets.map(
|
| 782 |
+
tokenize_function,
|
| 783 |
+
batched=True,
|
| 784 |
+
remove_columns=[text_column_name],
|
| 785 |
+
)
|
| 786 |
+
else:
|
| 787 |
+
# Otherwise, we tokenize every text, then concatenate them together before splitting them in smaller parts.
|
| 788 |
+
# We use `return_special_tokens_mask=True` because DataCollatorForLanguageModeling (see below) is more
|
| 789 |
+
# efficient when it receives the `special_tokens_mask`.
|
| 790 |
+
def tokenize_function(examples):
|
| 791 |
+
return tokenizer(
|
| 792 |
+
examples[text_column_name], return_special_tokens_mask=True
|
| 793 |
+
)
|
| 794 |
+
|
| 795 |
+
with training_args.main_process_first(desc="dataset map tokenization"):
|
| 796 |
+
if not data_args.streaming:
|
| 797 |
+
tokenized_datasets = raw_datasets.map(
|
| 798 |
+
tokenize_function,
|
| 799 |
+
batched=True,
|
| 800 |
+
num_proc=data_args.preprocessing_num_workers,
|
| 801 |
+
remove_columns=column_names,
|
| 802 |
+
load_from_cache_file=not data_args.overwrite_cache,
|
| 803 |
+
desc="Running tokenizer on every text in dataset",
|
| 804 |
+
)
|
| 805 |
+
else:
|
| 806 |
+
tokenized_datasets = raw_datasets.map(
|
| 807 |
+
tokenize_function,
|
| 808 |
+
batched=True,
|
| 809 |
+
remove_columns=column_names,
|
| 810 |
+
)
|
| 811 |
+
|
| 812 |
+
# Main data processing function that will concatenate all texts from our dataset and generate chunks of
|
| 813 |
+
# max_seq_length.
|
| 814 |
+
def group_texts(examples):
|
| 815 |
+
# Concatenate all texts.
|
| 816 |
+
concatenated_examples = {
|
| 817 |
+
k: list(chain(*examples[k])) for k in examples.keys()
|
| 818 |
+
}
|
| 819 |
+
total_length = len(concatenated_examples[list(examples.keys())[0]])
|
| 820 |
+
# We drop the small remainder, and if the total_length < max_seq_length we exclude this batch and return an empty dict.
|
| 821 |
+
# We could add padding if the model supported it instead of this drop, you can customize this part to your needs.
|
| 822 |
+
total_length = (total_length // max_seq_length) * max_seq_length
|
| 823 |
+
# Split by chunks of max_len.
|
| 824 |
+
result = {
|
| 825 |
+
k: [
|
| 826 |
+
t[i : i + max_seq_length]
|
| 827 |
+
for i in range(0, total_length, max_seq_length)
|
| 828 |
+
]
|
| 829 |
+
for k, t in concatenated_examples.items()
|
| 830 |
+
}
|
| 831 |
+
return result
|
| 832 |
+
|
| 833 |
+
# Note that with `batched=True`, this map processes 1,000 texts together, so group_texts throws away a
|
| 834 |
+
# remainder for each of those groups of 1,000 texts. You can adjust that batch_size here but a higher value
|
| 835 |
+
# might be slower to preprocess.
|
| 836 |
+
#
|
| 837 |
+
# To speed up this part, we use multiprocessing. See the documentation of the map method for more information:
|
| 838 |
+
# https://huggingface.co/docs/datasets/process#map
|
| 839 |
+
|
| 840 |
+
with training_args.main_process_first(desc="grouping texts together"):
|
| 841 |
+
if not data_args.streaming:
|
| 842 |
+
tokenized_datasets = tokenized_datasets.map(
|
| 843 |
+
group_texts,
|
| 844 |
+
batched=True,
|
| 845 |
+
num_proc=data_args.preprocessing_num_workers,
|
| 846 |
+
load_from_cache_file=not data_args.overwrite_cache,
|
| 847 |
+
desc=f"Grouping texts in chunks of {max_seq_length}",
|
| 848 |
+
)
|
| 849 |
+
else:
|
| 850 |
+
tokenized_datasets = tokenized_datasets.map(
|
| 851 |
+
group_texts,
|
| 852 |
+
batched=True,
|
| 853 |
+
)
|
| 854 |
+
|
| 855 |
+
if training_args.do_train:
|
| 856 |
+
if "train" not in tokenized_datasets:
|
| 857 |
+
raise ValueError("--do_train requires a train dataset")
|
| 858 |
+
train_dataset = tokenized_datasets["train"]
|
| 859 |
+
if data_args.max_train_samples is not None:
|
| 860 |
+
max_train_samples = min(len(train_dataset), data_args.max_train_samples)
|
| 861 |
+
train_dataset = train_dataset.select(range(max_train_samples))
|
| 862 |
+
|
| 863 |
+
if training_args.do_eval:
|
| 864 |
+
if "validation" not in tokenized_datasets:
|
| 865 |
+
raise ValueError("--do_eval requires a validation dataset")
|
| 866 |
+
eval_dataset = tokenized_datasets["validation"]
|
| 867 |
+
if data_args.max_eval_samples is not None:
|
| 868 |
+
max_eval_samples = min(len(eval_dataset), data_args.max_eval_samples)
|
| 869 |
+
eval_dataset = eval_dataset.select(range(max_eval_samples))
|
| 870 |
+
|
| 871 |
+
def preprocess_logits_for_metrics(logits, labels):
|
| 872 |
+
if isinstance(logits, tuple):
|
| 873 |
+
# Depending on the model and config, logits may contain extra tensors,
|
| 874 |
+
# like past_key_values, but logits always come first
|
| 875 |
+
logits = logits[0]
|
| 876 |
+
return logits.argmax(dim=-1)
|
| 877 |
+
|
| 878 |
+
metric = evaluate.load("accuracy", cache_dir=model_args.cache_dir)
|
| 879 |
+
|
| 880 |
+
def compute_metrics(eval_preds):
|
| 881 |
+
preds, labels = eval_preds
|
| 882 |
+
preds = preds[:, :-1]
|
| 883 |
+
labels = labels[:, 1:]
|
| 884 |
+
# preds have the same shape as the labels, after the argmax(-1) has been calculated
|
| 885 |
+
# by preprocess_logits_for_metrics
|
| 886 |
+
labels = labels.reshape(-1)
|
| 887 |
+
preds = preds.reshape(-1)
|
| 888 |
+
mask = labels != -100
|
| 889 |
+
labels = labels[mask]
|
| 890 |
+
preds = preds[mask]
|
| 891 |
+
return metric.compute(predictions=preds, references=labels)
|
| 892 |
+
|
| 893 |
+
# Data collator
|
| 894 |
+
# This one will take care of randomly masking the tokens.
|
| 895 |
+
pad_to_multiple_of_8 = (
|
| 896 |
+
data_args.line_by_line
|
| 897 |
+
and training_args.fp16
|
| 898 |
+
and not data_args.pad_to_max_length
|
| 899 |
+
)
|
| 900 |
+
data_collator_cls = None
|
| 901 |
+
if custom_args.data_collator_type == "all_mask":
|
| 902 |
+
data_collator_cls = DataCollatorForLanguageModelingWithFullMasking
|
| 903 |
+
elif custom_args.data_collator_type == "default":
|
| 904 |
+
data_collator_cls = DataCollatorForLanguageModeling
|
| 905 |
+
else:
|
| 906 |
+
raise ValueError(
|
| 907 |
+
f"data_collator_type {custom_args.data_collator_type} is not supported."
|
| 908 |
+
)
|
| 909 |
+
|
| 910 |
+
data_collator = data_collator_cls(
|
| 911 |
+
tokenizer=tokenizer,
|
| 912 |
+
mlm_probability=data_args.mlm_probability,
|
| 913 |
+
pad_to_multiple_of=8 if pad_to_multiple_of_8 else None,
|
| 914 |
+
)
|
| 915 |
+
|
| 916 |
+
# Initialize our Trainer
|
| 917 |
+
trainer = MNTPTrainer(
|
| 918 |
+
model=model,
|
| 919 |
+
args=training_args,
|
| 920 |
+
train_dataset=train_dataset if training_args.do_train else None,
|
| 921 |
+
eval_dataset=eval_dataset if training_args.do_eval else None,
|
| 922 |
+
tokenizer=tokenizer,
|
| 923 |
+
data_collator=data_collator,
|
| 924 |
+
compute_metrics=(
|
| 925 |
+
compute_metrics
|
| 926 |
+
if training_args.do_eval and not is_torch_tpu_available()
|
| 927 |
+
else None
|
| 928 |
+
),
|
| 929 |
+
preprocess_logits_for_metrics=(
|
| 930 |
+
preprocess_logits_for_metrics
|
| 931 |
+
if training_args.do_eval and not is_torch_tpu_available()
|
| 932 |
+
else None
|
| 933 |
+
),
|
| 934 |
+
)
|
| 935 |
+
|
| 936 |
+
trainer.add_callback(StopTrainingCallback(custom_args.stop_after_n_steps))
|
| 937 |
+
|
| 938 |
+
# Training
|
| 939 |
+
if training_args.do_train:
|
| 940 |
+
checkpoint = None
|
| 941 |
+
if training_args.resume_from_checkpoint is not None:
|
| 942 |
+
checkpoint = training_args.resume_from_checkpoint
|
| 943 |
+
elif last_checkpoint is not None:
|
| 944 |
+
checkpoint = last_checkpoint
|
| 945 |
+
train_result = trainer.train(resume_from_checkpoint=checkpoint)
|
| 946 |
+
trainer.save_model() # Saves the tokenizer too for easy upload
|
| 947 |
+
metrics = train_result.metrics
|
| 948 |
+
|
| 949 |
+
max_train_samples = (
|
| 950 |
+
data_args.max_train_samples
|
| 951 |
+
if data_args.max_train_samples is not None
|
| 952 |
+
else len(train_dataset)
|
| 953 |
+
)
|
| 954 |
+
metrics["train_samples"] = min(max_train_samples, len(train_dataset))
|
| 955 |
+
|
| 956 |
+
trainer.log_metrics("train", metrics)
|
| 957 |
+
trainer.save_metrics("train", metrics)
|
| 958 |
+
trainer.save_state()
|
| 959 |
+
|
| 960 |
+
# Evaluation
|
| 961 |
+
if training_args.do_eval:
|
| 962 |
+
logger.info("*** Evaluate ***")
|
| 963 |
+
|
| 964 |
+
metrics = trainer.evaluate()
|
| 965 |
+
|
| 966 |
+
max_eval_samples = (
|
| 967 |
+
data_args.max_eval_samples
|
| 968 |
+
if data_args.max_eval_samples is not None
|
| 969 |
+
else len(eval_dataset)
|
| 970 |
+
)
|
| 971 |
+
metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset))
|
| 972 |
+
try:
|
| 973 |
+
perplexity = math.exp(metrics["eval_loss"])
|
| 974 |
+
except OverflowError:
|
| 975 |
+
perplexity = float("inf")
|
| 976 |
+
metrics["perplexity"] = perplexity
|
| 977 |
+
|
| 978 |
+
trainer.log_metrics("eval", metrics)
|
| 979 |
+
trainer.save_metrics("eval", metrics)
|
| 980 |
+
|
| 981 |
+
kwargs = {"finetuned_from": model_args.model_name_or_path, "tasks": "fill-mask"}
|
| 982 |
+
if data_args.dataset_name is not None:
|
| 983 |
+
kwargs["dataset_tags"] = data_args.dataset_name
|
| 984 |
+
if data_args.dataset_config_name is not None:
|
| 985 |
+
kwargs["dataset_args"] = data_args.dataset_config_name
|
| 986 |
+
kwargs["dataset"] = (
|
| 987 |
+
f"{data_args.dataset_name} {data_args.dataset_config_name}"
|
| 988 |
+
)
|
| 989 |
+
else:
|
| 990 |
+
kwargs["dataset"] = data_args.dataset_name
|
| 991 |
+
|
| 992 |
+
if training_args.push_to_hub:
|
| 993 |
+
trainer.push_to_hub(**kwargs)
|
| 994 |
+
|
| 995 |
+
|
| 996 |
+
if __name__ == "__main__":
|
| 997 |
+
main()
|
llm2vec/experiments/run_simcse.py
ADDED
|
@@ -0,0 +1,388 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
from dataclasses import dataclass, field
|
| 3 |
+
import os
|
| 4 |
+
import sys
|
| 5 |
+
from typing import Any, Dict, List, Optional, Tuple, Union
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
from torch import nn
|
| 9 |
+
|
| 10 |
+
from accelerate import Accelerator, DistributedDataParallelKwargs
|
| 11 |
+
from accelerate.logging import get_logger
|
| 12 |
+
|
| 13 |
+
import transformers
|
| 14 |
+
from transformers import (
|
| 15 |
+
MODEL_FOR_MASKED_LM_MAPPING,
|
| 16 |
+
HfArgumentParser,
|
| 17 |
+
TrainingArguments,
|
| 18 |
+
Trainer,
|
| 19 |
+
TrainerCallback,
|
| 20 |
+
set_seed,
|
| 21 |
+
)
|
| 22 |
+
from transformers.trainer_utils import seed_worker
|
| 23 |
+
|
| 24 |
+
from peft import LoraConfig, get_peft_model
|
| 25 |
+
|
| 26 |
+
from llm2vec import LLM2Vec
|
| 27 |
+
from llm2vec.dataset.utils import load_dataset
|
| 28 |
+
from llm2vec.loss.utils import load_loss
|
| 29 |
+
|
| 30 |
+
from tqdm import tqdm
|
| 31 |
+
|
| 32 |
+
transformers.logging.set_verbosity_error()
|
| 33 |
+
|
| 34 |
+
logging.basicConfig(
|
| 35 |
+
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
|
| 36 |
+
datefmt="%Y-%m-%d %H:%M:%S",
|
| 37 |
+
level=logging.INFO,
|
| 38 |
+
)
|
| 39 |
+
logger = get_logger(__name__, log_level="INFO")
|
| 40 |
+
MODEL_CONFIG_CLASSES = list(MODEL_FOR_MASKED_LM_MAPPING.keys())
|
| 41 |
+
MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def initialize_peft(
|
| 45 |
+
model,
|
| 46 |
+
lora_r: int = 8,
|
| 47 |
+
lora_alpha: int = 16,
|
| 48 |
+
lora_dropout: float = 0.05,
|
| 49 |
+
lora_modules: Optional[List[str]] = None,
|
| 50 |
+
):
|
| 51 |
+
if lora_modules is None and model.config.__class__.__name__ in [
|
| 52 |
+
"LlamaConfig",
|
| 53 |
+
"MistralConfig",
|
| 54 |
+
"GemmaConfig",
|
| 55 |
+
"Qwen2Config",
|
| 56 |
+
]:
|
| 57 |
+
lora_modules = [
|
| 58 |
+
"q_proj",
|
| 59 |
+
"v_proj",
|
| 60 |
+
"k_proj",
|
| 61 |
+
"o_proj",
|
| 62 |
+
"gate_proj",
|
| 63 |
+
"up_proj",
|
| 64 |
+
"down_proj",
|
| 65 |
+
]
|
| 66 |
+
elif lora_modules is None:
|
| 67 |
+
raise ValueError("lora_modules must be specified for this model.")
|
| 68 |
+
|
| 69 |
+
config = LoraConfig(
|
| 70 |
+
r=lora_r,
|
| 71 |
+
lora_alpha=lora_alpha,
|
| 72 |
+
target_modules=lora_modules,
|
| 73 |
+
lora_dropout=lora_dropout,
|
| 74 |
+
bias="none",
|
| 75 |
+
task_type=None,
|
| 76 |
+
)
|
| 77 |
+
|
| 78 |
+
model = get_peft_model(model, config)
|
| 79 |
+
print(f"Model's Lora trainable parameters:")
|
| 80 |
+
model.print_trainable_parameters()
|
| 81 |
+
return model
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
@dataclass
|
| 85 |
+
class ModelArguments:
|
| 86 |
+
"""
|
| 87 |
+
Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch.
|
| 88 |
+
"""
|
| 89 |
+
|
| 90 |
+
model_name_or_path: Optional[str] = field(
|
| 91 |
+
default=None,
|
| 92 |
+
metadata={
|
| 93 |
+
"help": (
|
| 94 |
+
"The base model checkpoint for weights initialization. Don't set if you want to train a model from scratch."
|
| 95 |
+
)
|
| 96 |
+
},
|
| 97 |
+
)
|
| 98 |
+
peft_model_name_or_path: Optional[str] = field(
|
| 99 |
+
default=None,
|
| 100 |
+
metadata={"help": ("The PEFT model checkpoint to add on top of base model.")},
|
| 101 |
+
)
|
| 102 |
+
bidirectional: Optional[bool] = field(
|
| 103 |
+
default=False,
|
| 104 |
+
metadata={
|
| 105 |
+
"help": (
|
| 106 |
+
"Whether to enable bidirectional attention in the model. If set to False, the model will use unidirectional attention."
|
| 107 |
+
)
|
| 108 |
+
},
|
| 109 |
+
)
|
| 110 |
+
max_seq_length: Optional[int] = field(
|
| 111 |
+
default=None,
|
| 112 |
+
metadata={
|
| 113 |
+
"help": (
|
| 114 |
+
"The maximum total input sequence length after tokenization. Sequences longer "
|
| 115 |
+
"than this will be truncated."
|
| 116 |
+
)
|
| 117 |
+
},
|
| 118 |
+
)
|
| 119 |
+
torch_dtype: Optional[str] = field(
|
| 120 |
+
default=None,
|
| 121 |
+
metadata={
|
| 122 |
+
"help": (
|
| 123 |
+
"Override the default `torch.dtype` and load the model under this dtype. If `auto` is passed, the "
|
| 124 |
+
"dtype will be automatically derived from the model's weights."
|
| 125 |
+
),
|
| 126 |
+
"choices": ["auto", "bfloat16", "float16", "float32"],
|
| 127 |
+
},
|
| 128 |
+
)
|
| 129 |
+
attn_implementation: Optional[str] = field(
|
| 130 |
+
default="sdpa",
|
| 131 |
+
metadata={
|
| 132 |
+
"help": ("The attention implementation to use in the model."),
|
| 133 |
+
"choices": ["eager", "sdpa", "flash_attention_2"],
|
| 134 |
+
},
|
| 135 |
+
)
|
| 136 |
+
pooling_mode: Optional[str] = field(
|
| 137 |
+
default="mean",
|
| 138 |
+
metadata={
|
| 139 |
+
"help": ("The pooling mode to use in the model."),
|
| 140 |
+
"choices": ["mean", "weighted_mean", "eos_token"],
|
| 141 |
+
},
|
| 142 |
+
)
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
@dataclass
|
| 146 |
+
class DataTrainingArguments:
|
| 147 |
+
"""
|
| 148 |
+
Arguments pertaining to what data we are going to input our model for training and eval.
|
| 149 |
+
"""
|
| 150 |
+
|
| 151 |
+
dataset_name: Optional[str] = field(
|
| 152 |
+
default=None,
|
| 153 |
+
metadata={"help": "The name of the dataset to use. Options: E5"},
|
| 154 |
+
)
|
| 155 |
+
dataset_file_path: Optional[str] = field(
|
| 156 |
+
default=None, metadata={"help": "The input training data file or folder."}
|
| 157 |
+
)
|
| 158 |
+
# TODO: implement this
|
| 159 |
+
max_train_samples: Optional[int] = field(
|
| 160 |
+
default=None,
|
| 161 |
+
metadata={
|
| 162 |
+
"help": (
|
| 163 |
+
"For debugging purposes or quicker training, truncate the number of training examples to this "
|
| 164 |
+
"value if set."
|
| 165 |
+
)
|
| 166 |
+
},
|
| 167 |
+
)
|
| 168 |
+
|
| 169 |
+
|
| 170 |
+
@dataclass
|
| 171 |
+
class CustomArguments:
|
| 172 |
+
"""
|
| 173 |
+
Custom arguments for the script
|
| 174 |
+
"""
|
| 175 |
+
|
| 176 |
+
simcse_dropout: float = field(
|
| 177 |
+
default=0.1, metadata={"help": "The SimCSE dropout rate for the model"}
|
| 178 |
+
)
|
| 179 |
+
|
| 180 |
+
lora_dropout: float = field(
|
| 181 |
+
default=0.05, metadata={"help": "The dropout rate for lora"}
|
| 182 |
+
)
|
| 183 |
+
|
| 184 |
+
lora_r: int = field(default=8, metadata={"help": "The r value for lora"})
|
| 185 |
+
|
| 186 |
+
stop_after_n_steps: int = field(
|
| 187 |
+
default=10000, metadata={"help": "Stop training after n steps"}
|
| 188 |
+
)
|
| 189 |
+
|
| 190 |
+
experiment_id: Optional[str] = field(
|
| 191 |
+
default=None, metadata={"help": "The experiment id"}
|
| 192 |
+
)
|
| 193 |
+
|
| 194 |
+
loss_class: Optional[str] = field(
|
| 195 |
+
default="HardNegativeNLLLoss",
|
| 196 |
+
metadata={
|
| 197 |
+
"help": "The loss class to use for training. Options: HardNegativeNLLLoss"
|
| 198 |
+
},
|
| 199 |
+
)
|
| 200 |
+
|
| 201 |
+
loss_scale: float = field(
|
| 202 |
+
default=50.0, metadata={"help": "The loss scale for the loss function"}
|
| 203 |
+
)
|
| 204 |
+
|
| 205 |
+
|
| 206 |
+
@dataclass
|
| 207 |
+
class DefaultCollator:
|
| 208 |
+
model: LLM2Vec
|
| 209 |
+
|
| 210 |
+
def __init__(self, model: LLM2Vec) -> None:
|
| 211 |
+
self.model = model
|
| 212 |
+
|
| 213 |
+
def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
|
| 214 |
+
batch = features
|
| 215 |
+
num_texts = len(batch[0].texts)
|
| 216 |
+
texts = [[] for _ in range(num_texts)]
|
| 217 |
+
labels = []
|
| 218 |
+
|
| 219 |
+
for example in batch:
|
| 220 |
+
for idx, text in enumerate(example.texts):
|
| 221 |
+
# TODO: Add prepare_for_tokenization here similar to supervised training and see if it impacts performance
|
| 222 |
+
texts[idx].append(text)
|
| 223 |
+
labels.append(example.label)
|
| 224 |
+
labels = torch.tensor(labels)
|
| 225 |
+
|
| 226 |
+
sentence_features = []
|
| 227 |
+
for idx in range(num_texts):
|
| 228 |
+
tokenized = self.model.tokenize(texts[idx])
|
| 229 |
+
sentence_features.append(tokenized)
|
| 230 |
+
|
| 231 |
+
return sentence_features, labels
|
| 232 |
+
|
| 233 |
+
|
| 234 |
+
class StopTrainingCallback(TrainerCallback):
|
| 235 |
+
def __init__(self, stop_after_n_steps: int):
|
| 236 |
+
self.stop_after_n_steps = stop_after_n_steps
|
| 237 |
+
|
| 238 |
+
def on_step_end(self, args, state, control, **kwargs):
|
| 239 |
+
if state.global_step >= self.stop_after_n_steps:
|
| 240 |
+
control.should_training_stop = True
|
| 241 |
+
|
| 242 |
+
|
| 243 |
+
class SimCSETrainer(Trainer):
|
| 244 |
+
def __init__(
|
| 245 |
+
self,
|
| 246 |
+
*args,
|
| 247 |
+
loss_function=None,
|
| 248 |
+
**kwargs,
|
| 249 |
+
) -> None:
|
| 250 |
+
super().__init__(*args, **kwargs)
|
| 251 |
+
self.loss_function = loss_function
|
| 252 |
+
|
| 253 |
+
def compute_loss(
|
| 254 |
+
self,
|
| 255 |
+
model: nn.Module,
|
| 256 |
+
inputs: Dict[str, Union[torch.Tensor, Any]],
|
| 257 |
+
return_outputs: bool = False,
|
| 258 |
+
) -> Union[Tuple[torch.Tensor, torch.Tensor], torch.Tensor]:
|
| 259 |
+
features, labels = inputs
|
| 260 |
+
q_reps = self.model(features[0])
|
| 261 |
+
d_reps = self.model(features[1])
|
| 262 |
+
|
| 263 |
+
d_reps_neg = None
|
| 264 |
+
if len(features) > 2:
|
| 265 |
+
d_reps_neg = self.model(features[2])
|
| 266 |
+
|
| 267 |
+
loss = self.loss_function(q_reps, d_reps, d_reps_neg)
|
| 268 |
+
|
| 269 |
+
if return_outputs:
|
| 270 |
+
output = torch.cat(
|
| 271 |
+
[model(row)["sentence_embedding"][:, None] for row in features], dim=1
|
| 272 |
+
)
|
| 273 |
+
return loss, output
|
| 274 |
+
|
| 275 |
+
return loss
|
| 276 |
+
|
| 277 |
+
def _save(self, output_dir: Optional[str] = None, state_dict=None):
|
| 278 |
+
# If we are executing this function, we are the process zero, so we don't check for that.
|
| 279 |
+
output_dir = output_dir if output_dir is not None else self.args.output_dir
|
| 280 |
+
os.makedirs(output_dir, exist_ok=True)
|
| 281 |
+
logger.info(f"Saving model checkpoint to {output_dir}")
|
| 282 |
+
|
| 283 |
+
self.model.save(output_dir)
|
| 284 |
+
|
| 285 |
+
# Good practice: save your training arguments together with the trained model
|
| 286 |
+
torch.save(self.args, os.path.join(output_dir, "training_args.bin"))
|
| 287 |
+
|
| 288 |
+
|
| 289 |
+
def main():
|
| 290 |
+
parser = HfArgumentParser(
|
| 291 |
+
(ModelArguments, DataTrainingArguments, TrainingArguments, CustomArguments)
|
| 292 |
+
)
|
| 293 |
+
if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
|
| 294 |
+
# If we pass only one argument to the script and it's the path to a json file,
|
| 295 |
+
# let's parse it to get our arguments.
|
| 296 |
+
model_args, data_args, training_args, custom_args = parser.parse_json_file(
|
| 297 |
+
json_file=os.path.abspath(sys.argv[1])
|
| 298 |
+
)
|
| 299 |
+
else:
|
| 300 |
+
(
|
| 301 |
+
model_args,
|
| 302 |
+
data_args,
|
| 303 |
+
training_args,
|
| 304 |
+
custom_args,
|
| 305 |
+
) = parser.parse_args_into_dataclasses()
|
| 306 |
+
if training_args.ddp_find_unused_parameters:
|
| 307 |
+
kwargs = [
|
| 308 |
+
DistributedDataParallelKwargs(
|
| 309 |
+
dim=0,
|
| 310 |
+
broadcast_buffers=True,
|
| 311 |
+
bucket_cap_mb=25,
|
| 312 |
+
find_unused_parameters=True,
|
| 313 |
+
check_reduction=False,
|
| 314 |
+
gradient_as_bucket_view=False,
|
| 315 |
+
)
|
| 316 |
+
]
|
| 317 |
+
else:
|
| 318 |
+
kwargs = []
|
| 319 |
+
accelerator = Accelerator(kwargs_handlers=kwargs)
|
| 320 |
+
|
| 321 |
+
set_seed(training_args.seed)
|
| 322 |
+
|
| 323 |
+
if training_args.gradient_checkpointing:
|
| 324 |
+
training_args.gradient_checkpointing_kwargs = {"use_reentrant": False}
|
| 325 |
+
|
| 326 |
+
train_dataset = load_dataset(
|
| 327 |
+
data_args.dataset_name,
|
| 328 |
+
split="train",
|
| 329 |
+
file_path=data_args.dataset_file_path,
|
| 330 |
+
)
|
| 331 |
+
|
| 332 |
+
train_examples = [
|
| 333 |
+
train_dataset[i]
|
| 334 |
+
for i in tqdm(
|
| 335 |
+
range(len(train_dataset)),
|
| 336 |
+
desc="Loading train examples...",
|
| 337 |
+
disable=not accelerator.is_main_process,
|
| 338 |
+
)
|
| 339 |
+
]
|
| 340 |
+
|
| 341 |
+
torch_dtype = (
|
| 342 |
+
model_args.torch_dtype
|
| 343 |
+
if model_args.torch_dtype in ["auto", None]
|
| 344 |
+
else getattr(torch, model_args.torch_dtype)
|
| 345 |
+
)
|
| 346 |
+
model = LLM2Vec.from_pretrained(
|
| 347 |
+
base_model_name_or_path=model_args.model_name_or_path,
|
| 348 |
+
enable_bidirectional=model_args.bidirectional,
|
| 349 |
+
peft_model_name_or_path=model_args.peft_model_name_or_path,
|
| 350 |
+
merge_peft=True,
|
| 351 |
+
pooling_mode=model_args.pooling_mode,
|
| 352 |
+
max_length=model_args.max_seq_length,
|
| 353 |
+
torch_dtype=torch_dtype,
|
| 354 |
+
attn_implementation=model_args.attn_implementation,
|
| 355 |
+
attention_dropout=custom_args.simcse_dropout,
|
| 356 |
+
)
|
| 357 |
+
|
| 358 |
+
# model organization is LLM2VecModel.model -> HF Model, we have to apply PEFT to the inner model
|
| 359 |
+
model.model = initialize_peft(
|
| 360 |
+
model.model,
|
| 361 |
+
lora_r=custom_args.lora_r,
|
| 362 |
+
lora_alpha=2 * custom_args.lora_r,
|
| 363 |
+
lora_dropout=custom_args.lora_dropout,
|
| 364 |
+
)
|
| 365 |
+
|
| 366 |
+
tokenizer = model.tokenizer
|
| 367 |
+
|
| 368 |
+
train_loss = load_loss(custom_args.loss_class, scale=custom_args.loss_scale)
|
| 369 |
+
|
| 370 |
+
data_collator = DefaultCollator(model)
|
| 371 |
+
|
| 372 |
+
trainer = SimCSETrainer(
|
| 373 |
+
model=model,
|
| 374 |
+
args=training_args,
|
| 375 |
+
train_dataset=train_examples,
|
| 376 |
+
data_collator=data_collator,
|
| 377 |
+
tokenizer=tokenizer,
|
| 378 |
+
loss_function=train_loss,
|
| 379 |
+
)
|
| 380 |
+
|
| 381 |
+
if custom_args.stop_after_n_steps is not None:
|
| 382 |
+
trainer.add_callback(StopTrainingCallback(custom_args.stop_after_n_steps))
|
| 383 |
+
|
| 384 |
+
trainer.train()
|
| 385 |
+
|
| 386 |
+
|
| 387 |
+
if __name__ == "__main__":
|
| 388 |
+
main()
|
llm2vec/experiments/run_supervised.py
ADDED
|
@@ -0,0 +1,482 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
from dataclasses import dataclass, field
|
| 3 |
+
import os
|
| 4 |
+
import sys
|
| 5 |
+
from typing import Any, Dict, List, Optional, Tuple, Union
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
from torch import nn
|
| 9 |
+
from torch.utils.data import DataLoader, SequentialSampler
|
| 10 |
+
|
| 11 |
+
from accelerate import Accelerator, DistributedDataParallelKwargs
|
| 12 |
+
from accelerate.logging import get_logger
|
| 13 |
+
|
| 14 |
+
import transformers
|
| 15 |
+
from transformers import (
|
| 16 |
+
MODEL_FOR_MASKED_LM_MAPPING,
|
| 17 |
+
HfArgumentParser,
|
| 18 |
+
TrainingArguments,
|
| 19 |
+
Trainer,
|
| 20 |
+
TrainerCallback,
|
| 21 |
+
LlamaConfig,
|
| 22 |
+
MistralConfig,
|
| 23 |
+
GemmaConfig,
|
| 24 |
+
Qwen2Config,
|
| 25 |
+
set_seed,
|
| 26 |
+
)
|
| 27 |
+
from transformers.trainer_utils import seed_worker
|
| 28 |
+
|
| 29 |
+
from peft import LoraConfig, get_peft_model
|
| 30 |
+
|
| 31 |
+
from llm2vec import LLM2Vec
|
| 32 |
+
from llm2vec.dataset.utils import load_dataset
|
| 33 |
+
from llm2vec.loss.utils import load_loss
|
| 34 |
+
from llm2vec.experiment_utils import generate_experiment_id
|
| 35 |
+
|
| 36 |
+
from tqdm import tqdm
|
| 37 |
+
|
| 38 |
+
transformers.logging.set_verbosity_error()
|
| 39 |
+
|
| 40 |
+
logging.basicConfig(
|
| 41 |
+
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
|
| 42 |
+
datefmt="%Y-%m-%d %H:%M:%S",
|
| 43 |
+
level=logging.INFO,
|
| 44 |
+
)
|
| 45 |
+
logger = get_logger(__name__, log_level="INFO")
|
| 46 |
+
MODEL_CONFIG_CLASSES = list(MODEL_FOR_MASKED_LM_MAPPING.keys())
|
| 47 |
+
MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def prepare_for_tokenization(model, text, pooling_mode="mean"):
|
| 51 |
+
if model.config._name_or_path == "meta-llama/Meta-Llama-3-8B-Instruct":
|
| 52 |
+
text = (
|
| 53 |
+
"<|start_header_id|>user<|end_header_id|>\n\n" + text.strip() + "<|eot_id|>"
|
| 54 |
+
)
|
| 55 |
+
return text
|
| 56 |
+
if model.config._name_or_path in [
|
| 57 |
+
"mistralai/Mistral-7B-Instruct-v0.2",
|
| 58 |
+
"meta-llama/Llama-2-7b-chat-hf",
|
| 59 |
+
]:
|
| 60 |
+
text = "[INST] " + text.strip() + " [/INST]"
|
| 61 |
+
if model.config._name_or_path in [
|
| 62 |
+
"google/gemma-2-9b-it",
|
| 63 |
+
]:
|
| 64 |
+
text = "<bos><start_of_turn>user\n" + text.strip() + "<end_of_turn>"
|
| 65 |
+
if model.config._name_or_path in [
|
| 66 |
+
"Qwen/Qwen2-1.5B-Instruct",
|
| 67 |
+
"Qwen/Qwen2-7B-Instruct",
|
| 68 |
+
]:
|
| 69 |
+
text = "<|im_start|>user\n" + text.strip() + "<|im_end|>"
|
| 70 |
+
if pooling_mode == "eos_token":
|
| 71 |
+
if model.config._name_or_path == "meta-llama/Meta-Llama-3-8B":
|
| 72 |
+
text = text.strip() + "<|end_of_text|>"
|
| 73 |
+
elif isinstance(model.config, LlamaConfig) or isinstance(
|
| 74 |
+
model.config, MistralConfig
|
| 75 |
+
):
|
| 76 |
+
text = text.strip() + " </s>"
|
| 77 |
+
elif isinstance(model.config, GemmaConfig):
|
| 78 |
+
text = text.strip() + "<eos>"
|
| 79 |
+
elif isinstance(model.config, Qwen2Config):
|
| 80 |
+
text = text.strip() + "<|endoftext|>"
|
| 81 |
+
return text
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
def initialize_peft(
|
| 85 |
+
model,
|
| 86 |
+
lora_r: int = 8,
|
| 87 |
+
lora_alpha: int = 16,
|
| 88 |
+
lora_dropout: float = 0.05,
|
| 89 |
+
lora_modules: Optional[List[str]] = None,
|
| 90 |
+
):
|
| 91 |
+
if lora_modules is None and model.config.__class__.__name__ in [
|
| 92 |
+
"LlamaConfig",
|
| 93 |
+
"MistralConfig",
|
| 94 |
+
"GemmaConfig",
|
| 95 |
+
"Qwen2Config",
|
| 96 |
+
]:
|
| 97 |
+
lora_modules = [
|
| 98 |
+
"q_proj",
|
| 99 |
+
"v_proj",
|
| 100 |
+
"k_proj",
|
| 101 |
+
"o_proj",
|
| 102 |
+
"gate_proj",
|
| 103 |
+
"up_proj",
|
| 104 |
+
"down_proj",
|
| 105 |
+
]
|
| 106 |
+
elif lora_modules is None:
|
| 107 |
+
raise ValueError("lora_modules must be specified for this model.")
|
| 108 |
+
|
| 109 |
+
config = LoraConfig(
|
| 110 |
+
r=lora_r,
|
| 111 |
+
lora_alpha=lora_alpha,
|
| 112 |
+
target_modules=lora_modules,
|
| 113 |
+
lora_dropout=lora_dropout,
|
| 114 |
+
bias="none",
|
| 115 |
+
task_type=None,
|
| 116 |
+
)
|
| 117 |
+
|
| 118 |
+
model = get_peft_model(model, config)
|
| 119 |
+
print(f"Model's Lora trainable parameters:")
|
| 120 |
+
model.print_trainable_parameters()
|
| 121 |
+
return model
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
@dataclass
|
| 125 |
+
class ModelArguments:
|
| 126 |
+
"""
|
| 127 |
+
Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch.
|
| 128 |
+
"""
|
| 129 |
+
|
| 130 |
+
model_name_or_path: Optional[str] = field(
|
| 131 |
+
default=None,
|
| 132 |
+
metadata={
|
| 133 |
+
"help": (
|
| 134 |
+
"The base model checkpoint for weights initialization. Don't set if you want to train a model from scratch."
|
| 135 |
+
)
|
| 136 |
+
},
|
| 137 |
+
)
|
| 138 |
+
peft_model_name_or_path: Optional[str] = field(
|
| 139 |
+
default=None,
|
| 140 |
+
metadata={"help": ("The PEFT model checkpoint to add on top of base model.")},
|
| 141 |
+
)
|
| 142 |
+
bidirectional: Optional[bool] = field(
|
| 143 |
+
default=False,
|
| 144 |
+
metadata={
|
| 145 |
+
"help": (
|
| 146 |
+
"Whether to enable bidirectional attention in the model. If set to False, the model will use unidirectional attention."
|
| 147 |
+
)
|
| 148 |
+
},
|
| 149 |
+
)
|
| 150 |
+
max_seq_length: Optional[int] = field(
|
| 151 |
+
default=None,
|
| 152 |
+
metadata={
|
| 153 |
+
"help": (
|
| 154 |
+
"The maximum total input sequence length after tokenization. Sequences longer "
|
| 155 |
+
"than this will be truncated."
|
| 156 |
+
)
|
| 157 |
+
},
|
| 158 |
+
)
|
| 159 |
+
torch_dtype: Optional[str] = field(
|
| 160 |
+
default=None,
|
| 161 |
+
metadata={
|
| 162 |
+
"help": (
|
| 163 |
+
"Override the default `torch.dtype` and load the model under this dtype. If `auto` is passed, the "
|
| 164 |
+
"dtype will be automatically derived from the model's weights."
|
| 165 |
+
),
|
| 166 |
+
"choices": ["auto", "bfloat16", "float16", "float32"],
|
| 167 |
+
},
|
| 168 |
+
)
|
| 169 |
+
attn_implementation: Optional[str] = field(
|
| 170 |
+
default="sdpa",
|
| 171 |
+
metadata={
|
| 172 |
+
"help": ("The attention implementation to use in the model."),
|
| 173 |
+
"choices": ["eager", "sdpa", "flash_attention_2"],
|
| 174 |
+
},
|
| 175 |
+
)
|
| 176 |
+
pooling_mode: Optional[str] = field(
|
| 177 |
+
default="mean",
|
| 178 |
+
metadata={
|
| 179 |
+
"help": ("The pooling mode to use in the model."),
|
| 180 |
+
"choices": ["mean", "weighted_mean", "eos_token"],
|
| 181 |
+
},
|
| 182 |
+
)
|
| 183 |
+
|
| 184 |
+
|
| 185 |
+
@dataclass
|
| 186 |
+
class DataTrainingArguments:
|
| 187 |
+
"""
|
| 188 |
+
Arguments pertaining to what data we are going to input our model for training and eval.
|
| 189 |
+
"""
|
| 190 |
+
|
| 191 |
+
dataset_name: Optional[str] = field(
|
| 192 |
+
default=None,
|
| 193 |
+
metadata={"help": "The name of the dataset to use. Options: E5"},
|
| 194 |
+
)
|
| 195 |
+
dataset_file_path: Optional[str] = field(
|
| 196 |
+
default=None, metadata={"help": "The input training data file or folder."}
|
| 197 |
+
)
|
| 198 |
+
# TODO: implement this
|
| 199 |
+
max_train_samples: Optional[int] = field(
|
| 200 |
+
default=None,
|
| 201 |
+
metadata={
|
| 202 |
+
"help": (
|
| 203 |
+
"For debugging purposes or quicker training, truncate the number of training examples to this "
|
| 204 |
+
"value if set."
|
| 205 |
+
)
|
| 206 |
+
},
|
| 207 |
+
)
|
| 208 |
+
|
| 209 |
+
|
| 210 |
+
@dataclass
|
| 211 |
+
class CustomArguments:
|
| 212 |
+
"""
|
| 213 |
+
Custom arguments for the script
|
| 214 |
+
"""
|
| 215 |
+
|
| 216 |
+
lora_dropout: float = field(
|
| 217 |
+
default=0.05, metadata={"help": "The dropout rate for lora"}
|
| 218 |
+
)
|
| 219 |
+
|
| 220 |
+
lora_r: int = field(default=8, metadata={"help": "The r value for lora"})
|
| 221 |
+
|
| 222 |
+
stop_after_n_steps: int = field(
|
| 223 |
+
default=10000, metadata={"help": "Stop training after n steps"}
|
| 224 |
+
)
|
| 225 |
+
|
| 226 |
+
experiment_id: Optional[str] = field(
|
| 227 |
+
default=None, metadata={"help": "The experiment id"}
|
| 228 |
+
)
|
| 229 |
+
|
| 230 |
+
loss_class: Optional[str] = field(
|
| 231 |
+
default="HardNegativeNLLLoss",
|
| 232 |
+
metadata={
|
| 233 |
+
"help": "The loss class to use for training. Options: HardNegativeNLLLoss"
|
| 234 |
+
},
|
| 235 |
+
)
|
| 236 |
+
|
| 237 |
+
loss_scale: float = field(
|
| 238 |
+
default=50.0, metadata={"help": "The loss scale for the loss function"}
|
| 239 |
+
)
|
| 240 |
+
|
| 241 |
+
|
| 242 |
+
@dataclass
|
| 243 |
+
class DefaultCollator:
|
| 244 |
+
model: LLM2Vec
|
| 245 |
+
|
| 246 |
+
def __init__(self, model: LLM2Vec) -> None:
|
| 247 |
+
self.model = model
|
| 248 |
+
|
| 249 |
+
def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
|
| 250 |
+
batch = features
|
| 251 |
+
num_texts = len(batch[0].texts)
|
| 252 |
+
texts = [[] for _ in range(num_texts)]
|
| 253 |
+
labels = []
|
| 254 |
+
|
| 255 |
+
for example in batch:
|
| 256 |
+
for idx, text in enumerate(example.texts):
|
| 257 |
+
text = prepare_for_tokenization(
|
| 258 |
+
self.model, text, pooling_mode=self.model.pooling_mode
|
| 259 |
+
)
|
| 260 |
+
texts[idx].append(text)
|
| 261 |
+
labels.append(example.label)
|
| 262 |
+
labels = torch.tensor(labels)
|
| 263 |
+
|
| 264 |
+
sentence_features = []
|
| 265 |
+
for idx in range(num_texts):
|
| 266 |
+
tokenized = self.model.tokenize(texts[idx])
|
| 267 |
+
sentence_features.append(tokenized)
|
| 268 |
+
|
| 269 |
+
return sentence_features, labels
|
| 270 |
+
|
| 271 |
+
|
| 272 |
+
class StopTrainingCallback(TrainerCallback):
|
| 273 |
+
def __init__(self, stop_after_n_steps: int):
|
| 274 |
+
self.stop_after_n_steps = stop_after_n_steps
|
| 275 |
+
|
| 276 |
+
def on_step_end(self, args, state, control, **kwargs):
|
| 277 |
+
if state.global_step >= self.stop_after_n_steps:
|
| 278 |
+
control.should_training_stop = True
|
| 279 |
+
|
| 280 |
+
|
| 281 |
+
class LLM2VecSupervisedTrainer(Trainer):
|
| 282 |
+
def __init__(
|
| 283 |
+
self,
|
| 284 |
+
*args,
|
| 285 |
+
loss_function=None,
|
| 286 |
+
**kwargs,
|
| 287 |
+
) -> None:
|
| 288 |
+
super().__init__(*args, **kwargs)
|
| 289 |
+
self.loss_function = loss_function
|
| 290 |
+
|
| 291 |
+
def compute_loss(
|
| 292 |
+
self,
|
| 293 |
+
model: nn.Module,
|
| 294 |
+
inputs: Dict[str, Union[torch.Tensor, Any]],
|
| 295 |
+
return_outputs: bool = False,
|
| 296 |
+
) -> Union[Tuple[torch.Tensor, torch.Tensor], torch.Tensor]:
|
| 297 |
+
features, labels = inputs
|
| 298 |
+
q_reps = self.model(features[0])
|
| 299 |
+
d_reps = self.model(features[1])
|
| 300 |
+
|
| 301 |
+
d_reps_neg = None
|
| 302 |
+
if len(features) > 2:
|
| 303 |
+
d_reps_neg = self.model(features[2])
|
| 304 |
+
|
| 305 |
+
loss = self.loss_function(q_reps, d_reps, d_reps_neg)
|
| 306 |
+
|
| 307 |
+
if return_outputs:
|
| 308 |
+
output = torch.cat(
|
| 309 |
+
[model(row)["sentence_embedding"][:, None] for row in features], dim=1
|
| 310 |
+
)
|
| 311 |
+
return loss, output
|
| 312 |
+
|
| 313 |
+
return loss
|
| 314 |
+
|
| 315 |
+
def get_train_dataloader(self) -> DataLoader:
|
| 316 |
+
# Copying most of the code from the parent class, changing the sampler to SequentialSampler
|
| 317 |
+
if self.train_dataset is None:
|
| 318 |
+
raise ValueError("Trainer: training requires a train_dataset.")
|
| 319 |
+
|
| 320 |
+
train_dataset = self.train_dataset
|
| 321 |
+
data_collator = self.data_collator
|
| 322 |
+
|
| 323 |
+
data_collator = self._get_collator_with_removed_columns(
|
| 324 |
+
data_collator, description="training"
|
| 325 |
+
)
|
| 326 |
+
|
| 327 |
+
dataloader_params = {
|
| 328 |
+
"batch_size": self._train_batch_size,
|
| 329 |
+
"collate_fn": data_collator,
|
| 330 |
+
"num_workers": self.args.dataloader_num_workers,
|
| 331 |
+
"pin_memory": self.args.dataloader_pin_memory,
|
| 332 |
+
"persistent_workers": self.args.dataloader_persistent_workers,
|
| 333 |
+
}
|
| 334 |
+
|
| 335 |
+
if not isinstance(train_dataset, torch.utils.data.IterableDataset):
|
| 336 |
+
# Changing from random sampler to sequential sampler
|
| 337 |
+
dataloader_params["sampler"] = SequentialSampler(train_dataset)
|
| 338 |
+
dataloader_params["drop_last"] = self.args.dataloader_drop_last
|
| 339 |
+
dataloader_params["worker_init_fn"] = seed_worker
|
| 340 |
+
|
| 341 |
+
return self.accelerator.prepare(DataLoader(train_dataset, **dataloader_params))
|
| 342 |
+
|
| 343 |
+
def _save(self, output_dir: Optional[str] = None, state_dict=None):
|
| 344 |
+
# If we are executing this function, we are the process zero, so we don't check for that.
|
| 345 |
+
output_dir = output_dir if output_dir is not None else self.args.output_dir
|
| 346 |
+
os.makedirs(output_dir, exist_ok=True)
|
| 347 |
+
logger.info(f"Saving model checkpoint to {output_dir}")
|
| 348 |
+
|
| 349 |
+
self.model.save(output_dir)
|
| 350 |
+
|
| 351 |
+
# Good practice: save your training arguments together with the trained model
|
| 352 |
+
torch.save(self.args, os.path.join(output_dir, "training_args.bin"))
|
| 353 |
+
|
| 354 |
+
|
| 355 |
+
def main():
|
| 356 |
+
parser = HfArgumentParser(
|
| 357 |
+
(ModelArguments, DataTrainingArguments, TrainingArguments, CustomArguments)
|
| 358 |
+
)
|
| 359 |
+
if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
|
| 360 |
+
# If we pass only one argument to the script and it's the path to a json file,
|
| 361 |
+
# let's parse it to get our arguments.
|
| 362 |
+
model_args, data_args, training_args, custom_args = parser.parse_json_file(
|
| 363 |
+
json_file=os.path.abspath(sys.argv[1])
|
| 364 |
+
)
|
| 365 |
+
else:
|
| 366 |
+
(
|
| 367 |
+
model_args,
|
| 368 |
+
data_args,
|
| 369 |
+
training_args,
|
| 370 |
+
custom_args,
|
| 371 |
+
) = parser.parse_args_into_dataclasses()
|
| 372 |
+
if training_args.ddp_find_unused_parameters:
|
| 373 |
+
kwargs = [
|
| 374 |
+
DistributedDataParallelKwargs(
|
| 375 |
+
dim=0,
|
| 376 |
+
broadcast_buffers=True,
|
| 377 |
+
bucket_cap_mb=25,
|
| 378 |
+
find_unused_parameters=True,
|
| 379 |
+
check_reduction=False,
|
| 380 |
+
gradient_as_bucket_view=False,
|
| 381 |
+
)
|
| 382 |
+
]
|
| 383 |
+
else:
|
| 384 |
+
kwargs = []
|
| 385 |
+
accelerator = Accelerator(kwargs_handlers=kwargs)
|
| 386 |
+
|
| 387 |
+
set_seed(training_args.seed)
|
| 388 |
+
|
| 389 |
+
if training_args.gradient_checkpointing:
|
| 390 |
+
training_args.gradient_checkpointing_kwargs = {"use_reentrant": False}
|
| 391 |
+
|
| 392 |
+
if custom_args.experiment_id is not None:
|
| 393 |
+
experiment_id = custom_args.experiment_id
|
| 394 |
+
else:
|
| 395 |
+
experiment_id = generate_experiment_id(
|
| 396 |
+
name=data_args.dataset_name,
|
| 397 |
+
split="train",
|
| 398 |
+
model_name=(
|
| 399 |
+
model_args.model_name_or_path
|
| 400 |
+
if "/" not in model_args.model_name_or_path
|
| 401 |
+
else model_args.model_name_or_path.split("/")[-1]
|
| 402 |
+
),
|
| 403 |
+
pooling_mode=model_args.pooling_mode,
|
| 404 |
+
train_batch_size=training_args.per_device_train_batch_size
|
| 405 |
+
* accelerator.num_processes
|
| 406 |
+
* training_args.gradient_accumulation_steps,
|
| 407 |
+
max_seq_length=model_args.max_seq_length,
|
| 408 |
+
bidirectional=model_args.bidirectional,
|
| 409 |
+
epochs=training_args.num_train_epochs,
|
| 410 |
+
seed=training_args.seed,
|
| 411 |
+
warmup_steps=training_args.warmup_steps,
|
| 412 |
+
lr=training_args.learning_rate,
|
| 413 |
+
lora_r=custom_args.lora_r,
|
| 414 |
+
)
|
| 415 |
+
|
| 416 |
+
training_args.output_dir = f"{training_args.output_dir}/{experiment_id}"
|
| 417 |
+
|
| 418 |
+
# TODO: can also pass separator arg here
|
| 419 |
+
train_dataset = load_dataset(
|
| 420 |
+
data_args.dataset_name,
|
| 421 |
+
split="train",
|
| 422 |
+
file_path=data_args.dataset_file_path,
|
| 423 |
+
effective_batch_size=training_args.per_device_train_batch_size
|
| 424 |
+
* accelerator.num_processes,
|
| 425 |
+
)
|
| 426 |
+
|
| 427 |
+
train_examples = [
|
| 428 |
+
train_dataset[i]
|
| 429 |
+
for i in tqdm(
|
| 430 |
+
range(len(train_dataset)),
|
| 431 |
+
desc="Loading train examples...",
|
| 432 |
+
disable=not accelerator.is_main_process,
|
| 433 |
+
)
|
| 434 |
+
]
|
| 435 |
+
|
| 436 |
+
torch_dtype = (
|
| 437 |
+
model_args.torch_dtype
|
| 438 |
+
if model_args.torch_dtype in ["auto", None]
|
| 439 |
+
else getattr(torch, model_args.torch_dtype)
|
| 440 |
+
)
|
| 441 |
+
model = LLM2Vec.from_pretrained(
|
| 442 |
+
base_model_name_or_path=model_args.model_name_or_path,
|
| 443 |
+
enable_bidirectional=model_args.bidirectional,
|
| 444 |
+
peft_model_name_or_path=model_args.peft_model_name_or_path,
|
| 445 |
+
merge_peft=True,
|
| 446 |
+
pooling_mode=model_args.pooling_mode,
|
| 447 |
+
max_length=model_args.max_seq_length,
|
| 448 |
+
torch_dtype=torch_dtype,
|
| 449 |
+
attn_implementation=model_args.attn_implementation,
|
| 450 |
+
)
|
| 451 |
+
|
| 452 |
+
# model organization is LLM2VecModel.model -> HF Model, we have to apply PEFT to the inner model
|
| 453 |
+
model.model = initialize_peft(
|
| 454 |
+
model.model,
|
| 455 |
+
lora_r=custom_args.lora_r,
|
| 456 |
+
lora_alpha=2 * custom_args.lora_r,
|
| 457 |
+
lora_dropout=custom_args.lora_dropout,
|
| 458 |
+
)
|
| 459 |
+
|
| 460 |
+
tokenizer = model.tokenizer
|
| 461 |
+
|
| 462 |
+
train_loss = load_loss(custom_args.loss_class, scale=custom_args.loss_scale)
|
| 463 |
+
|
| 464 |
+
data_collator = DefaultCollator(model)
|
| 465 |
+
|
| 466 |
+
trainer = LLM2VecSupervisedTrainer(
|
| 467 |
+
model=model,
|
| 468 |
+
args=training_args,
|
| 469 |
+
train_dataset=train_examples,
|
| 470 |
+
data_collator=data_collator,
|
| 471 |
+
tokenizer=tokenizer,
|
| 472 |
+
loss_function=train_loss,
|
| 473 |
+
)
|
| 474 |
+
|
| 475 |
+
if custom_args.stop_after_n_steps is not None:
|
| 476 |
+
trainer.add_callback(StopTrainingCallback(custom_args.stop_after_n_steps))
|
| 477 |
+
|
| 478 |
+
trainer.train()
|
| 479 |
+
|
| 480 |
+
|
| 481 |
+
if __name__ == "__main__":
|
| 482 |
+
main()
|
llm2vec/experiments/run_word_task.py
ADDED
|
@@ -0,0 +1,905 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
The script is adapted from https://huggingface.co/docs/transformers/en/tasks/token_classification
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import logging
|
| 6 |
+
import os
|
| 7 |
+
import sys
|
| 8 |
+
import warnings
|
| 9 |
+
from dataclasses import dataclass, field
|
| 10 |
+
import numpy as np
|
| 11 |
+
from typing import List, Optional, Tuple, Union
|
| 12 |
+
|
| 13 |
+
import datasets
|
| 14 |
+
import evaluate
|
| 15 |
+
from datasets import load_dataset
|
| 16 |
+
|
| 17 |
+
import torch
|
| 18 |
+
from torch import nn
|
| 19 |
+
from torch.nn import CrossEntropyLoss
|
| 20 |
+
import transformers
|
| 21 |
+
from transformers import (
|
| 22 |
+
PreTrainedModel,
|
| 23 |
+
MODEL_FOR_MASKED_LM_MAPPING,
|
| 24 |
+
AutoConfig,
|
| 25 |
+
AutoTokenizer,
|
| 26 |
+
HfArgumentParser,
|
| 27 |
+
Trainer,
|
| 28 |
+
TrainingArguments,
|
| 29 |
+
TrainerCallback,
|
| 30 |
+
set_seed,
|
| 31 |
+
AutoModelForTokenClassification,
|
| 32 |
+
DataCollatorForTokenClassification,
|
| 33 |
+
)
|
| 34 |
+
|
| 35 |
+
from transformers.modeling_outputs import TokenClassifierOutput
|
| 36 |
+
from transformers.utils import send_example_telemetry
|
| 37 |
+
from transformers.utils.versions import require_version
|
| 38 |
+
|
| 39 |
+
from llm2vec import LLM2Vec
|
| 40 |
+
|
| 41 |
+
require_version(
|
| 42 |
+
"datasets>=1.8.0",
|
| 43 |
+
"To fix: pip install -r examples/pytorch/language-modeling/requirements.txt",
|
| 44 |
+
)
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
class ModelForWordTask(PreTrainedModel):
|
| 48 |
+
def __init__(self, config, model, merge_subwords=False, **model_args):
|
| 49 |
+
PreTrainedModel.__init__(self, config)
|
| 50 |
+
self.model = model
|
| 51 |
+
self.merge_subwords = merge_subwords
|
| 52 |
+
|
| 53 |
+
if (
|
| 54 |
+
hasattr(config, "classifier_dropout")
|
| 55 |
+
and config.classifier_dropout is not None
|
| 56 |
+
):
|
| 57 |
+
classifier_dropout = config.classifier_dropout
|
| 58 |
+
elif hasattr(config, "hidden_dropout") and config.hidden_dropout is not None:
|
| 59 |
+
classifier_dropout = config.hidden_dropout
|
| 60 |
+
else:
|
| 61 |
+
classifier_dropout = 0.1
|
| 62 |
+
|
| 63 |
+
self.dropout = nn.Dropout(classifier_dropout)
|
| 64 |
+
self.num_labels = config.num_labels
|
| 65 |
+
self.classifier = nn.Linear(config.hidden_size, config.num_labels).to(
|
| 66 |
+
model_args.get("torch_dtype")
|
| 67 |
+
)
|
| 68 |
+
|
| 69 |
+
# Initialize weights and apply final processing
|
| 70 |
+
self.post_init()
|
| 71 |
+
|
| 72 |
+
def _merge_subwords(self, hidden_states, token_type_ids, attention_mask):
|
| 73 |
+
new_hidden_states = hidden_states.clone()
|
| 74 |
+
for b in range(hidden_states.shape[0]):
|
| 75 |
+
for w in torch.arange(0, token_type_ids[b].max() + 1):
|
| 76 |
+
words_w = (token_type_ids[b] == w) * (attention_mask[b] > 0)
|
| 77 |
+
new_hidden_states[b][words_w] = torch.mean(
|
| 78 |
+
hidden_states[b][words_w], dim=0
|
| 79 |
+
).repeat(sum(words_w), 1)
|
| 80 |
+
return new_hidden_states
|
| 81 |
+
|
| 82 |
+
def forward(
|
| 83 |
+
self,
|
| 84 |
+
input_ids: torch.LongTensor = None,
|
| 85 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 86 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 87 |
+
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
| 88 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
| 89 |
+
use_cache: Optional[bool] = None,
|
| 90 |
+
output_attentions: Optional[bool] = None,
|
| 91 |
+
output_hidden_states: Optional[bool] = None,
|
| 92 |
+
return_dict: Optional[bool] = None,
|
| 93 |
+
token_type_ids: Optional[torch.LongTensor] = None,
|
| 94 |
+
head_mask: Optional[torch.FloatTensor] = None,
|
| 95 |
+
labels: Optional[torch.LongTensor] = None,
|
| 96 |
+
) -> Union[Tuple, TokenClassifierOutput]:
|
| 97 |
+
output_attentions = (
|
| 98 |
+
output_attentions
|
| 99 |
+
if output_attentions is not None
|
| 100 |
+
else self.config.output_attentions
|
| 101 |
+
)
|
| 102 |
+
output_hidden_states = (
|
| 103 |
+
output_hidden_states
|
| 104 |
+
if output_hidden_states is not None
|
| 105 |
+
else self.config.output_hidden_states
|
| 106 |
+
)
|
| 107 |
+
|
| 108 |
+
return_dict = (
|
| 109 |
+
return_dict if return_dict is not None else self.config.use_return_dict
|
| 110 |
+
)
|
| 111 |
+
|
| 112 |
+
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
| 113 |
+
outputs = self.model(
|
| 114 |
+
input_ids=input_ids,
|
| 115 |
+
attention_mask=attention_mask,
|
| 116 |
+
position_ids=position_ids,
|
| 117 |
+
past_key_values=past_key_values,
|
| 118 |
+
inputs_embeds=inputs_embeds,
|
| 119 |
+
use_cache=use_cache,
|
| 120 |
+
output_attentions=output_attentions,
|
| 121 |
+
output_hidden_states=output_hidden_states,
|
| 122 |
+
return_dict=return_dict,
|
| 123 |
+
)
|
| 124 |
+
|
| 125 |
+
hidden_states = outputs[0]
|
| 126 |
+
|
| 127 |
+
if self.merge_subwords:
|
| 128 |
+
hidden_states = self._merge_subwords(
|
| 129 |
+
hidden_states, token_type_ids, attention_mask
|
| 130 |
+
)
|
| 131 |
+
|
| 132 |
+
hidden_states = self.dropout(hidden_states)
|
| 133 |
+
logits = self.classifier(hidden_states)
|
| 134 |
+
|
| 135 |
+
loss = None
|
| 136 |
+
if labels is not None:
|
| 137 |
+
labels = labels.to(logits.device)
|
| 138 |
+
loss_fct = CrossEntropyLoss()
|
| 139 |
+
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
| 140 |
+
|
| 141 |
+
if not return_dict:
|
| 142 |
+
output = (logits,) + outputs.hidden_states
|
| 143 |
+
return ((loss,) + output) if loss is not None else output
|
| 144 |
+
|
| 145 |
+
return TokenClassifierOutput(
|
| 146 |
+
loss=loss,
|
| 147 |
+
logits=logits,
|
| 148 |
+
hidden_states=hidden_states,
|
| 149 |
+
attentions=outputs.attentions,
|
| 150 |
+
)
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
logger = logging.getLogger(__name__)
|
| 154 |
+
MODEL_CONFIG_CLASSES = list(MODEL_FOR_MASKED_LM_MAPPING.keys())
|
| 155 |
+
MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
|
| 156 |
+
LABELS = {
|
| 157 |
+
"conll2003": {
|
| 158 |
+
"pos_tags": {
|
| 159 |
+
'"': 0,
|
| 160 |
+
"''": 1,
|
| 161 |
+
"#": 2,
|
| 162 |
+
"$": 3,
|
| 163 |
+
"(": 4,
|
| 164 |
+
")": 5,
|
| 165 |
+
",": 6,
|
| 166 |
+
".": 7,
|
| 167 |
+
":": 8,
|
| 168 |
+
"``": 9,
|
| 169 |
+
"CC": 10,
|
| 170 |
+
"CD": 11,
|
| 171 |
+
"DT": 12,
|
| 172 |
+
"EX": 13,
|
| 173 |
+
"FW": 14,
|
| 174 |
+
"IN": 15,
|
| 175 |
+
"JJ": 16,
|
| 176 |
+
"JJR": 17,
|
| 177 |
+
"JJS": 18,
|
| 178 |
+
"LS": 19,
|
| 179 |
+
"MD": 20,
|
| 180 |
+
"NN": 21,
|
| 181 |
+
"NNP": 22,
|
| 182 |
+
"NNPS": 23,
|
| 183 |
+
"NNS": 24,
|
| 184 |
+
"NN|SYM": 25,
|
| 185 |
+
"PDT": 26,
|
| 186 |
+
"POS": 27,
|
| 187 |
+
"PRP": 28,
|
| 188 |
+
"PRP$": 29,
|
| 189 |
+
"RB": 30,
|
| 190 |
+
"RBR": 31,
|
| 191 |
+
"RBS": 32,
|
| 192 |
+
"RP": 33,
|
| 193 |
+
"SYM": 34,
|
| 194 |
+
"TO": 35,
|
| 195 |
+
"UH": 36,
|
| 196 |
+
"VB": 37,
|
| 197 |
+
"VBD": 38,
|
| 198 |
+
"VBG": 39,
|
| 199 |
+
"VBN": 40,
|
| 200 |
+
"VBP": 41,
|
| 201 |
+
"VBZ": 42,
|
| 202 |
+
"WDT": 43,
|
| 203 |
+
"WP": 44,
|
| 204 |
+
"WP$": 45,
|
| 205 |
+
"WRB": 46,
|
| 206 |
+
},
|
| 207 |
+
"chunk_tags": {
|
| 208 |
+
"O": 0,
|
| 209 |
+
"B-ADJP": 1,
|
| 210 |
+
"I-ADJP": 2,
|
| 211 |
+
"B-ADVP": 3,
|
| 212 |
+
"I-ADVP": 4,
|
| 213 |
+
"B-CONJP": 5,
|
| 214 |
+
"I-CONJP": 6,
|
| 215 |
+
"B-INTJ": 7,
|
| 216 |
+
"I-INTJ": 8,
|
| 217 |
+
"B-LST": 9,
|
| 218 |
+
"I-LST": 10,
|
| 219 |
+
"B-NP": 11,
|
| 220 |
+
"I-NP": 12,
|
| 221 |
+
"B-PP": 13,
|
| 222 |
+
"I-PP": 14,
|
| 223 |
+
"B-PRT": 15,
|
| 224 |
+
"I-PRT": 16,
|
| 225 |
+
"B-SBAR": 17,
|
| 226 |
+
"I-SBAR": 18,
|
| 227 |
+
"B-UCP": 19,
|
| 228 |
+
"I-UCP": 20,
|
| 229 |
+
"B-VP": 21,
|
| 230 |
+
"I-VP": 22,
|
| 231 |
+
},
|
| 232 |
+
"ner_tags": {
|
| 233 |
+
"O": 0,
|
| 234 |
+
"B-PER": 1,
|
| 235 |
+
"I-PER": 2,
|
| 236 |
+
"B-ORG": 3,
|
| 237 |
+
"I-ORG": 4,
|
| 238 |
+
"B-LOC": 5,
|
| 239 |
+
"I-LOC": 6,
|
| 240 |
+
"B-MISC": 7,
|
| 241 |
+
"I-MISC": 8,
|
| 242 |
+
},
|
| 243 |
+
}
|
| 244 |
+
}
|
| 245 |
+
|
| 246 |
+
|
| 247 |
+
@dataclass
|
| 248 |
+
class ModelArguments:
|
| 249 |
+
"""
|
| 250 |
+
Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch.
|
| 251 |
+
"""
|
| 252 |
+
|
| 253 |
+
model_name_or_path: Optional[str] = field(
|
| 254 |
+
default=None,
|
| 255 |
+
metadata={},
|
| 256 |
+
)
|
| 257 |
+
config_overrides: Optional[str] = field(
|
| 258 |
+
default=None,
|
| 259 |
+
metadata={
|
| 260 |
+
"help": (
|
| 261 |
+
"Override some existing default config settings when a model is trained from scratch. Example: "
|
| 262 |
+
"n_embd=10,resid_pdrop=0.2,scale_attn_weights=false,summary_type=cls_index"
|
| 263 |
+
)
|
| 264 |
+
},
|
| 265 |
+
)
|
| 266 |
+
config_name: Optional[str] = field(
|
| 267 |
+
default=None,
|
| 268 |
+
metadata={
|
| 269 |
+
"help": "Pretrained config name or path if not the same as model_name"
|
| 270 |
+
},
|
| 271 |
+
)
|
| 272 |
+
tokenizer_name: Optional[str] = field(
|
| 273 |
+
default=None,
|
| 274 |
+
metadata={
|
| 275 |
+
"help": "Pretrained tokenizer name or path if not the same as model_name"
|
| 276 |
+
},
|
| 277 |
+
)
|
| 278 |
+
cache_dir: Optional[str] = field(
|
| 279 |
+
default=None,
|
| 280 |
+
metadata={
|
| 281 |
+
"help": "Where do you want to store the pretrained models downloaded from huggingface.co"
|
| 282 |
+
},
|
| 283 |
+
)
|
| 284 |
+
use_fast_tokenizer: bool = field(
|
| 285 |
+
default=True,
|
| 286 |
+
metadata={
|
| 287 |
+
"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."
|
| 288 |
+
},
|
| 289 |
+
)
|
| 290 |
+
model_revision: str = field(
|
| 291 |
+
default="main",
|
| 292 |
+
metadata={
|
| 293 |
+
"help": "The specific model version to use (can be a branch name, tag name or commit id)."
|
| 294 |
+
},
|
| 295 |
+
)
|
| 296 |
+
token: str = field(
|
| 297 |
+
default=None,
|
| 298 |
+
metadata={
|
| 299 |
+
"help": (
|
| 300 |
+
"The token to use as HTTP bearer authorization for remote files. If not specified, will use the token "
|
| 301 |
+
"generated when running `huggingface-cli login` (stored in `~/.huggingface`)."
|
| 302 |
+
)
|
| 303 |
+
},
|
| 304 |
+
)
|
| 305 |
+
use_auth_token: bool = field(
|
| 306 |
+
default=None,
|
| 307 |
+
metadata={
|
| 308 |
+
"help": "The `use_auth_token` argument is deprecated and will be removed in v4.34. Please use `token` instead."
|
| 309 |
+
},
|
| 310 |
+
)
|
| 311 |
+
trust_remote_code: bool = field(
|
| 312 |
+
default=False,
|
| 313 |
+
metadata={
|
| 314 |
+
"help": (
|
| 315 |
+
"Whether or not to allow for custom models defined on the Hub in their own modeling files. This option "
|
| 316 |
+
"should only be set to `True` for repositories you trust and in which you have read the code, as it will "
|
| 317 |
+
"execute code present on the Hub on your local machine."
|
| 318 |
+
)
|
| 319 |
+
},
|
| 320 |
+
)
|
| 321 |
+
low_cpu_mem_usage: bool = field(
|
| 322 |
+
default=False,
|
| 323 |
+
metadata={
|
| 324 |
+
"help": (
|
| 325 |
+
"It is an option to create the model as an empty shell, then only materialize its parameters when the pretrained weights are loaded. "
|
| 326 |
+
"set True will benefit LLM loading time and RAM consumption."
|
| 327 |
+
)
|
| 328 |
+
},
|
| 329 |
+
)
|
| 330 |
+
torch_dtype: Optional[str] = field(
|
| 331 |
+
default=None,
|
| 332 |
+
metadata={
|
| 333 |
+
"help": (
|
| 334 |
+
"Override the default `torch.dtype` and load the model under this dtype. If `auto` is passed, the "
|
| 335 |
+
"dtype will be automatically derived from the model's weights."
|
| 336 |
+
),
|
| 337 |
+
"choices": ["auto", "bfloat16", "float16", "float32"],
|
| 338 |
+
},
|
| 339 |
+
)
|
| 340 |
+
attn_implementation: Optional[str] = field(
|
| 341 |
+
default="sdpa",
|
| 342 |
+
metadata={
|
| 343 |
+
"help": ("The attention implementation to use in the model."),
|
| 344 |
+
"choices": ["eager", "sdpa", "flash_attention_2"],
|
| 345 |
+
},
|
| 346 |
+
)
|
| 347 |
+
classifier_dropout: Optional[float] = field(
|
| 348 |
+
default=0.1, metadata={"help": "The dropout rate for models"}
|
| 349 |
+
)
|
| 350 |
+
peft_addr: Optional[str] = field(
|
| 351 |
+
default=None, metadata={"help": "addr of lora adapter weights"}
|
| 352 |
+
)
|
| 353 |
+
model_class: str = field(
|
| 354 |
+
default="custom",
|
| 355 |
+
metadata={
|
| 356 |
+
"help": "One of the items 'custom' or 'auto'. 'custom' for LLM2Vec models and 'auto' for pretrained encoders such as BERT.",
|
| 357 |
+
"choices": ["custom", "auto"],
|
| 358 |
+
},
|
| 359 |
+
)
|
| 360 |
+
merge_subwords: bool = field(
|
| 361 |
+
default=True,
|
| 362 |
+
metadata={"help": "Whether the representations of the subtokens get averaged."},
|
| 363 |
+
)
|
| 364 |
+
bidirectional: bool = field(
|
| 365 |
+
default=True, metadata={"help": "Whether to use bidirectional attention."}
|
| 366 |
+
)
|
| 367 |
+
|
| 368 |
+
def __post_init__(self):
|
| 369 |
+
if self.config_overrides is not None and (
|
| 370 |
+
self.config_name is not None or self.model_name_or_path is not None
|
| 371 |
+
):
|
| 372 |
+
raise ValueError(
|
| 373 |
+
"--config_overrides can't be used in combination with --config_name or --model_name_or_path"
|
| 374 |
+
)
|
| 375 |
+
|
| 376 |
+
|
| 377 |
+
@dataclass
|
| 378 |
+
class DataTrainingArguments:
|
| 379 |
+
"""
|
| 380 |
+
Arguments pertaining to what data we are going to input our model for training and eval.
|
| 381 |
+
"""
|
| 382 |
+
|
| 383 |
+
dataset_name: Optional[str] = field(
|
| 384 |
+
default=None,
|
| 385 |
+
metadata={"help": "The name of the dataset to use (via the datasets library)."},
|
| 386 |
+
)
|
| 387 |
+
dataset_config_name: Optional[str] = field(
|
| 388 |
+
default=None,
|
| 389 |
+
metadata={
|
| 390 |
+
"help": "The configuration name of the dataset to use (via the datasets library)."
|
| 391 |
+
},
|
| 392 |
+
)
|
| 393 |
+
train_file: Optional[str] = field(
|
| 394 |
+
default=None, metadata={"help": "The input training data file (a text file)."}
|
| 395 |
+
)
|
| 396 |
+
validation_file: Optional[str] = field(
|
| 397 |
+
default=None,
|
| 398 |
+
metadata={
|
| 399 |
+
"help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."
|
| 400 |
+
},
|
| 401 |
+
)
|
| 402 |
+
overwrite_cache: bool = field(
|
| 403 |
+
default=True,
|
| 404 |
+
metadata={"help": "Overwrite the cached training and evaluation sets"},
|
| 405 |
+
)
|
| 406 |
+
validation_split_percentage: Optional[int] = field(
|
| 407 |
+
default=5,
|
| 408 |
+
metadata={
|
| 409 |
+
"help": "The percentage of the train set used as validation set in case there's no validation split"
|
| 410 |
+
},
|
| 411 |
+
)
|
| 412 |
+
max_seq_length: Optional[int] = field(
|
| 413 |
+
default=None,
|
| 414 |
+
metadata={
|
| 415 |
+
"help": (
|
| 416 |
+
"The maximum total input sequence length after tokenization. Sequences longer "
|
| 417 |
+
"than this will be truncated."
|
| 418 |
+
)
|
| 419 |
+
},
|
| 420 |
+
)
|
| 421 |
+
preprocessing_num_workers: Optional[int] = field(
|
| 422 |
+
default=None,
|
| 423 |
+
metadata={"help": "The number of processes to use for the preprocessing."},
|
| 424 |
+
)
|
| 425 |
+
mlm_probability: float = field(
|
| 426 |
+
default=0.15,
|
| 427 |
+
metadata={"help": "Ratio of tokens to mask for masked language modeling loss"},
|
| 428 |
+
)
|
| 429 |
+
line_by_line: bool = field(
|
| 430 |
+
default=False,
|
| 431 |
+
metadata={
|
| 432 |
+
"help": "Whether distinct lines of text in the dataset are to be handled as distinct sequences."
|
| 433 |
+
},
|
| 434 |
+
)
|
| 435 |
+
pad_to_max_length: bool = field(
|
| 436 |
+
default=False,
|
| 437 |
+
metadata={
|
| 438 |
+
"help": (
|
| 439 |
+
"Whether to pad all samples to `max_seq_length`. "
|
| 440 |
+
"If False, will pad the samples dynamically when batching to the maximum length in the batch."
|
| 441 |
+
)
|
| 442 |
+
},
|
| 443 |
+
)
|
| 444 |
+
max_train_samples: Optional[int] = field(
|
| 445 |
+
default=None,
|
| 446 |
+
metadata={
|
| 447 |
+
"help": (
|
| 448 |
+
"For debugging purposes or quicker training, truncate the number of training examples to this "
|
| 449 |
+
"value if set."
|
| 450 |
+
)
|
| 451 |
+
},
|
| 452 |
+
)
|
| 453 |
+
max_eval_samples: Optional[int] = field(
|
| 454 |
+
default=None,
|
| 455 |
+
metadata={
|
| 456 |
+
"help": (
|
| 457 |
+
"For debugging purposes or quicker training, truncate the number of evaluation examples to this "
|
| 458 |
+
"value if set."
|
| 459 |
+
)
|
| 460 |
+
},
|
| 461 |
+
)
|
| 462 |
+
streaming: bool = field(default=False, metadata={"help": "Enable streaming mode"})
|
| 463 |
+
|
| 464 |
+
def __post_init__(self):
|
| 465 |
+
if self.streaming:
|
| 466 |
+
require_version(
|
| 467 |
+
"datasets>=2.0.0", "The streaming feature requires `datasets>=2.0.0`"
|
| 468 |
+
)
|
| 469 |
+
|
| 470 |
+
if (
|
| 471 |
+
self.dataset_name is None
|
| 472 |
+
and self.train_file is None
|
| 473 |
+
and self.validation_file is None
|
| 474 |
+
):
|
| 475 |
+
raise ValueError(
|
| 476 |
+
"Need either a dataset name or a training/validation file."
|
| 477 |
+
)
|
| 478 |
+
else:
|
| 479 |
+
if self.train_file is not None:
|
| 480 |
+
extension = self.train_file.split(".")[-1]
|
| 481 |
+
if extension not in ["csv", "json", "txt"]:
|
| 482 |
+
raise ValueError(
|
| 483 |
+
"`train_file` should be a csv, a json or a txt file."
|
| 484 |
+
)
|
| 485 |
+
if self.validation_file is not None:
|
| 486 |
+
extension = self.validation_file.split(".")[-1]
|
| 487 |
+
if extension not in ["csv", "json", "txt"]:
|
| 488 |
+
raise ValueError(
|
| 489 |
+
"`validation_file` should be a csv, a json or a txt file."
|
| 490 |
+
)
|
| 491 |
+
|
| 492 |
+
|
| 493 |
+
# add more arguments
|
| 494 |
+
@dataclass
|
| 495 |
+
class CustomArguments:
|
| 496 |
+
"""
|
| 497 |
+
Custom arguments for the script
|
| 498 |
+
"""
|
| 499 |
+
|
| 500 |
+
stop_after_n_steps: int = field(
|
| 501 |
+
default=10000, metadata={"help": "Stop training after n steps"}
|
| 502 |
+
)
|
| 503 |
+
data_collator_type: str = field(
|
| 504 |
+
default="custom",
|
| 505 |
+
metadata={
|
| 506 |
+
"help": "The type of data collator. Options: custom, default, custom_no_random"
|
| 507 |
+
},
|
| 508 |
+
)
|
| 509 |
+
task: Optional[str] = field(
|
| 510 |
+
default="pos_tags",
|
| 511 |
+
metadata={
|
| 512 |
+
"help": "One of the 'pos_tags', 'chunk_tags', and 'ner_tags' choices",
|
| 513 |
+
"choices": ["pos_tags", "ner_tags", "chunk_tags"],
|
| 514 |
+
},
|
| 515 |
+
)
|
| 516 |
+
retroactive_labels: str = field(
|
| 517 |
+
default="next_token",
|
| 518 |
+
metadata={
|
| 519 |
+
"help": "Whether the tokens representations are used to predict the next token's labels. Options: same_token, next_word, next_token.",
|
| 520 |
+
"choices": ["next_token", "same_token"],
|
| 521 |
+
},
|
| 522 |
+
)
|
| 523 |
+
|
| 524 |
+
|
| 525 |
+
class StopTrainingCallback(TrainerCallback):
|
| 526 |
+
def __init__(self, stop_after_n_steps: int):
|
| 527 |
+
self.stop_after_n_steps = stop_after_n_steps
|
| 528 |
+
|
| 529 |
+
def on_step_end(self, args, state, control, **kwargs):
|
| 530 |
+
if state.global_step >= self.stop_after_n_steps:
|
| 531 |
+
control.should_training_stop = True
|
| 532 |
+
|
| 533 |
+
|
| 534 |
+
class WordTaskTrainer(Trainer):
|
| 535 |
+
def _save(self, output_dir: Optional[str] = None, state_dict=None):
|
| 536 |
+
# If we are executing this function, we are the process zero, so we don't check for that.
|
| 537 |
+
output_dir = output_dir if output_dir is not None else self.args.output_dir
|
| 538 |
+
os.makedirs(output_dir, exist_ok=True)
|
| 539 |
+
logger.info(f"Saving model checkpoint to {output_dir}")
|
| 540 |
+
|
| 541 |
+
torch.save(self.model.classifier, os.path.join(output_dir, "classifier.pt"))
|
| 542 |
+
self.tokenizer.save_pretrained(output_dir)
|
| 543 |
+
|
| 544 |
+
# Good practice: save your training arguments together with the trained model
|
| 545 |
+
torch.save(self.args, os.path.join(output_dir, "training_args.bin"))
|
| 546 |
+
|
| 547 |
+
|
| 548 |
+
def main():
|
| 549 |
+
parser = HfArgumentParser(
|
| 550 |
+
(ModelArguments, DataTrainingArguments, TrainingArguments, CustomArguments)
|
| 551 |
+
)
|
| 552 |
+
# model_args, data_args, training_args, custom_args = parser.parse_args_into_dataclasses()
|
| 553 |
+
if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
|
| 554 |
+
# If we pass only one argument to the script and it's the path to a json file,
|
| 555 |
+
# let's parse it to get our arguments.
|
| 556 |
+
model_args, data_args, training_args, custom_args = parser.parse_json_file(
|
| 557 |
+
json_file=os.path.abspath(sys.argv[1])
|
| 558 |
+
)
|
| 559 |
+
else:
|
| 560 |
+
(
|
| 561 |
+
model_args,
|
| 562 |
+
data_args,
|
| 563 |
+
training_args,
|
| 564 |
+
custom_args,
|
| 565 |
+
) = parser.parse_args_into_dataclasses()
|
| 566 |
+
|
| 567 |
+
if training_args.gradient_checkpointing:
|
| 568 |
+
training_args.gradient_checkpointing_kwargs = {"use_reentrant": False}
|
| 569 |
+
|
| 570 |
+
if model_args.use_auth_token is not None:
|
| 571 |
+
warnings.warn(
|
| 572 |
+
"The `use_auth_token` argument is deprecated and will be removed in v4.34. Please use `token` instead.",
|
| 573 |
+
FutureWarning,
|
| 574 |
+
)
|
| 575 |
+
if model_args.token is not None:
|
| 576 |
+
raise ValueError(
|
| 577 |
+
"`token` and `use_auth_token` are both specified. Please set only the argument `token`."
|
| 578 |
+
)
|
| 579 |
+
model_args.token = model_args.use_auth_token
|
| 580 |
+
|
| 581 |
+
# Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
|
| 582 |
+
# information sent is the one passed as arguments along with your Python/PyTorch versions.
|
| 583 |
+
send_example_telemetry("run_word_task", model_args, data_args)
|
| 584 |
+
|
| 585 |
+
# Setup logging
|
| 586 |
+
logging.basicConfig(
|
| 587 |
+
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
| 588 |
+
datefmt="%m/%d/%Y %H:%M:%S",
|
| 589 |
+
handlers=[logging.StreamHandler(sys.stdout)],
|
| 590 |
+
)
|
| 591 |
+
|
| 592 |
+
if training_args.should_log:
|
| 593 |
+
# The default of training_args.log_level is passive, so we set log level at info here to have that default.
|
| 594 |
+
transformers.utils.logging.set_verbosity_info()
|
| 595 |
+
|
| 596 |
+
log_level = training_args.get_process_log_level()
|
| 597 |
+
logger.setLevel(log_level)
|
| 598 |
+
datasets.utils.logging.set_verbosity(log_level)
|
| 599 |
+
transformers.utils.logging.set_verbosity(log_level)
|
| 600 |
+
transformers.utils.logging.enable_default_handler()
|
| 601 |
+
transformers.utils.logging.enable_explicit_format()
|
| 602 |
+
|
| 603 |
+
# Log on each process the small summary:
|
| 604 |
+
logger.warning(
|
| 605 |
+
f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}, "
|
| 606 |
+
+ f"distributed training: {training_args.parallel_mode.value == 'distributed'}, 16-bits training: {training_args.fp16}"
|
| 607 |
+
)
|
| 608 |
+
# Set the verbosity to info of the Transformers logger (on main process only):
|
| 609 |
+
logger.info(f"Training/evaluation parameters {training_args}")
|
| 610 |
+
|
| 611 |
+
# Set seed before initializing model.
|
| 612 |
+
set_seed(training_args.seed)
|
| 613 |
+
|
| 614 |
+
if data_args.dataset_name is not None:
|
| 615 |
+
# Downloading and loading a dataset from the hub.
|
| 616 |
+
raw_datasets = load_dataset(
|
| 617 |
+
data_args.dataset_name,
|
| 618 |
+
data_args.dataset_config_name,
|
| 619 |
+
cache_dir=model_args.cache_dir,
|
| 620 |
+
token=model_args.token,
|
| 621 |
+
streaming=data_args.streaming,
|
| 622 |
+
)
|
| 623 |
+
if "validation" not in raw_datasets.keys():
|
| 624 |
+
raw_datasets["validation"] = load_dataset(
|
| 625 |
+
data_args.dataset_name,
|
| 626 |
+
data_args.dataset_config_name,
|
| 627 |
+
split=f"train[:{data_args.validation_split_percentage}%]",
|
| 628 |
+
cache_dir=model_args.cache_dir,
|
| 629 |
+
token=model_args.token,
|
| 630 |
+
streaming=data_args.streaming,
|
| 631 |
+
)
|
| 632 |
+
raw_datasets["train"] = load_dataset(
|
| 633 |
+
data_args.dataset_name,
|
| 634 |
+
data_args.dataset_config_name,
|
| 635 |
+
split=f"train[{data_args.validation_split_percentage}%:]",
|
| 636 |
+
cache_dir=model_args.cache_dir,
|
| 637 |
+
token=model_args.token,
|
| 638 |
+
streaming=data_args.streaming,
|
| 639 |
+
)
|
| 640 |
+
else:
|
| 641 |
+
data_files = {}
|
| 642 |
+
if data_args.train_file is not None:
|
| 643 |
+
data_files["train"] = data_args.train_file
|
| 644 |
+
extension = data_args.train_file.split(".")[-1]
|
| 645 |
+
if data_args.validation_file is not None:
|
| 646 |
+
data_files["validation"] = data_args.validation_file
|
| 647 |
+
extension = data_args.validation_file.split(".")[-1]
|
| 648 |
+
if extension == "txt":
|
| 649 |
+
extension = "text"
|
| 650 |
+
raw_datasets = load_dataset(
|
| 651 |
+
extension,
|
| 652 |
+
data_files=data_files,
|
| 653 |
+
cache_dir=model_args.cache_dir,
|
| 654 |
+
token=model_args.token,
|
| 655 |
+
)
|
| 656 |
+
|
| 657 |
+
# If no validation data is there, validation_split_percentage will be used to divide the dataset.
|
| 658 |
+
if "validation" not in raw_datasets.keys():
|
| 659 |
+
raw_datasets["validation"] = load_dataset(
|
| 660 |
+
extension,
|
| 661 |
+
data_files=data_files,
|
| 662 |
+
split=f"train[:{data_args.validation_split_percentage}%]",
|
| 663 |
+
cache_dir=model_args.cache_dir,
|
| 664 |
+
token=model_args.token,
|
| 665 |
+
)
|
| 666 |
+
raw_datasets["train"] = load_dataset(
|
| 667 |
+
extension,
|
| 668 |
+
data_files=data_files,
|
| 669 |
+
split=f"train[{data_args.validation_split_percentage}%:]",
|
| 670 |
+
cache_dir=model_args.cache_dir,
|
| 671 |
+
token=model_args.token,
|
| 672 |
+
)
|
| 673 |
+
|
| 674 |
+
assert (
|
| 675 |
+
data_args.dataset_name in LABELS
|
| 676 |
+
and custom_args.task in LABELS[data_args.dataset_name]
|
| 677 |
+
), f"LABELS[{data_args.dataset_name}][{custom_args.task}] is not defined."
|
| 678 |
+
|
| 679 |
+
config_kwargs = {
|
| 680 |
+
"num_labels": len(LABELS[data_args.dataset_name][custom_args.task]),
|
| 681 |
+
"id2label": {
|
| 682 |
+
i: lab
|
| 683 |
+
for (lab, i) in LABELS[data_args.dataset_name][custom_args.task].items()
|
| 684 |
+
},
|
| 685 |
+
"label2id": LABELS[data_args.dataset_name][custom_args.task],
|
| 686 |
+
"classifier_dropout": model_args.classifier_dropout,
|
| 687 |
+
}
|
| 688 |
+
|
| 689 |
+
tokenizer_kwargs = {
|
| 690 |
+
"cache_dir": model_args.cache_dir,
|
| 691 |
+
"use_fast": model_args.use_fast_tokenizer,
|
| 692 |
+
"revision": model_args.model_revision,
|
| 693 |
+
"token": model_args.token,
|
| 694 |
+
"trust_remote_code": model_args.trust_remote_code,
|
| 695 |
+
}
|
| 696 |
+
if model_args.tokenizer_name:
|
| 697 |
+
if "gpt" in model_args.tokenizer_name:
|
| 698 |
+
tokenizer_kwargs["add_prefix_space"] = True
|
| 699 |
+
tokenizer = AutoTokenizer.from_pretrained(
|
| 700 |
+
model_args.tokenizer_name, **tokenizer_kwargs
|
| 701 |
+
)
|
| 702 |
+
elif model_args.model_name_or_path:
|
| 703 |
+
if "gpt" in model_args.model_name_or_path:
|
| 704 |
+
tokenizer_kwargs["add_prefix_space"] = True
|
| 705 |
+
tokenizer = AutoTokenizer.from_pretrained(
|
| 706 |
+
model_args.model_name_or_path, **tokenizer_kwargs
|
| 707 |
+
)
|
| 708 |
+
else:
|
| 709 |
+
raise ValueError(
|
| 710 |
+
"You are instantiating a new tokenizer from scratch. This is not supported by this script. "
|
| 711 |
+
"You can do it from another script, save it, and load it from here, using --tokenizer_name."
|
| 712 |
+
)
|
| 713 |
+
|
| 714 |
+
if tokenizer.pad_token is None:
|
| 715 |
+
tokenizer.pad_token = tokenizer.eos_token
|
| 716 |
+
if model_args.model_class == "custom":
|
| 717 |
+
tokenizer.model_input_names.append("token_type_ids")
|
| 718 |
+
if model_args.model_class == "auto":
|
| 719 |
+
assert not model_args.merge_subwords
|
| 720 |
+
|
| 721 |
+
if model_args.model_class == "custom":
|
| 722 |
+
if model_args.config_name:
|
| 723 |
+
config = AutoConfig.from_pretrained(model_args.config_name, **config_kwargs)
|
| 724 |
+
elif model_args.model_name_or_path:
|
| 725 |
+
config = AutoConfig.from_pretrained(
|
| 726 |
+
model_args.model_name_or_path, **config_kwargs
|
| 727 |
+
)
|
| 728 |
+
else:
|
| 729 |
+
raise ValueError("Invalid config loading")
|
| 730 |
+
|
| 731 |
+
for k, v in config_kwargs.items():
|
| 732 |
+
config.__setattr__(k, v)
|
| 733 |
+
|
| 734 |
+
torch_dtype = (
|
| 735 |
+
model_args.torch_dtype
|
| 736 |
+
if model_args.torch_dtype in ["auto", None]
|
| 737 |
+
else getattr(torch, model_args.torch_dtype)
|
| 738 |
+
)
|
| 739 |
+
l2v = LLM2Vec.from_pretrained(
|
| 740 |
+
base_model_name_or_path=model_args.model_name_or_path,
|
| 741 |
+
enable_bidirectional=model_args.bidirectional,
|
| 742 |
+
peft_model_name_or_path=model_args.peft_addr,
|
| 743 |
+
merge_peft=False,
|
| 744 |
+
torch_dtype=torch_dtype,
|
| 745 |
+
attn_implementation=model_args.attn_implementation,
|
| 746 |
+
)
|
| 747 |
+
|
| 748 |
+
model = ModelForWordTask(
|
| 749 |
+
model=l2v.model,
|
| 750 |
+
merge_subwords=model_args.merge_subwords,
|
| 751 |
+
config=config,
|
| 752 |
+
torch_dtype=torch_dtype,
|
| 753 |
+
)
|
| 754 |
+
|
| 755 |
+
MyTrainer = WordTaskTrainer
|
| 756 |
+
|
| 757 |
+
elif model_args.model_class == "auto":
|
| 758 |
+
model = AutoModelForTokenClassification.from_pretrained(
|
| 759 |
+
model_args.model_name_or_path,
|
| 760 |
+
num_labels=config_kwargs["num_labels"],
|
| 761 |
+
id2label=config_kwargs["id2label"],
|
| 762 |
+
label2id=config_kwargs["label2id"],
|
| 763 |
+
)
|
| 764 |
+
MyTrainer = Trainer
|
| 765 |
+
|
| 766 |
+
else:
|
| 767 |
+
raise ValueError(
|
| 768 |
+
f"{model_args.model_class} is not implemented. Only 'auto' and 'custom' model_class options are valid."
|
| 769 |
+
)
|
| 770 |
+
|
| 771 |
+
# only train classifier
|
| 772 |
+
for n, p in list(model.named_parameters()):
|
| 773 |
+
if "classifier" in n:
|
| 774 |
+
p.requires_grad = True
|
| 775 |
+
else:
|
| 776 |
+
p.requires_grad = False
|
| 777 |
+
|
| 778 |
+
if data_args.max_seq_length is None:
|
| 779 |
+
max_seq_length = tokenizer.model_max_length
|
| 780 |
+
if max_seq_length > 1024:
|
| 781 |
+
logger.warning(
|
| 782 |
+
"The chosen tokenizer supports a `model_max_length` that is longer than the default `block_size` value"
|
| 783 |
+
" of 1024. If you would like to use a longer `block_size` up to `tokenizer.model_max_length` you can"
|
| 784 |
+
" override this default with `--block_size xxx`."
|
| 785 |
+
)
|
| 786 |
+
max_seq_length = 1024
|
| 787 |
+
else:
|
| 788 |
+
if data_args.max_seq_length > tokenizer.model_max_length:
|
| 789 |
+
logger.warning(
|
| 790 |
+
f"The max_seq_length passed ({data_args.max_seq_length}) is larger than the maximum length for the "
|
| 791 |
+
f"model ({tokenizer.model_max_length}). Using max_seq_length={tokenizer.model_max_length}."
|
| 792 |
+
)
|
| 793 |
+
max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length)
|
| 794 |
+
|
| 795 |
+
def tokenize_and_align_labels(examples):
|
| 796 |
+
task = custom_args.task
|
| 797 |
+
padding = "max_length" if data_args.pad_to_max_length else False
|
| 798 |
+
tokenized_inputs = tokenizer(
|
| 799 |
+
examples["tokens"],
|
| 800 |
+
truncation=True,
|
| 801 |
+
is_split_into_words=True,
|
| 802 |
+
padding=padding,
|
| 803 |
+
max_length=max_seq_length,
|
| 804 |
+
)
|
| 805 |
+
|
| 806 |
+
labels = []
|
| 807 |
+
words = []
|
| 808 |
+
for i, label in enumerate(examples[task]):
|
| 809 |
+
if custom_args.retroactive_labels in ["same_token"]:
|
| 810 |
+
word_ids = tokenized_inputs.word_ids(batch_index=i)
|
| 811 |
+
previous_word_idx = None
|
| 812 |
+
label_ids = []
|
| 813 |
+
for word_idx in word_ids:
|
| 814 |
+
if word_idx is None:
|
| 815 |
+
label_ids.append(-100)
|
| 816 |
+
elif word_idx != previous_word_idx:
|
| 817 |
+
label_ids.append(label[word_idx])
|
| 818 |
+
else:
|
| 819 |
+
label_ids.append(-100)
|
| 820 |
+
previous_word_idx = word_idx
|
| 821 |
+
labels.append(label_ids)
|
| 822 |
+
word_ids = [-1 if w is None else w for w in word_ids]
|
| 823 |
+
words.append(word_ids)
|
| 824 |
+
elif custom_args.retroactive_labels == "next_token":
|
| 825 |
+
word_ids = tokenized_inputs.word_ids(batch_index=i)
|
| 826 |
+
previous_word_idx = None
|
| 827 |
+
label_ids = []
|
| 828 |
+
for word_idx in word_ids:
|
| 829 |
+
if word_idx is None:
|
| 830 |
+
label_ids.append(-100)
|
| 831 |
+
elif word_idx != previous_word_idx:
|
| 832 |
+
label_ids.append(label[word_idx])
|
| 833 |
+
else:
|
| 834 |
+
label_ids.append(-100)
|
| 835 |
+
previous_word_idx = word_idx
|
| 836 |
+
label_ids.append(-100)
|
| 837 |
+
labels.append(label_ids[1:])
|
| 838 |
+
word_ids = word_ids[1:] + [None]
|
| 839 |
+
word_ids = [-1 if w is None else w for w in word_ids]
|
| 840 |
+
words.append(word_ids)
|
| 841 |
+
else:
|
| 842 |
+
raise ValueError(
|
| 843 |
+
f"retroactive_labels {custom_args.retroactive_labels} is not implemented."
|
| 844 |
+
)
|
| 845 |
+
|
| 846 |
+
tokenized_inputs["labels"] = labels
|
| 847 |
+
if model_args.model_class == "custom":
|
| 848 |
+
tokenized_inputs["token_type_ids"] = words
|
| 849 |
+
return tokenized_inputs
|
| 850 |
+
|
| 851 |
+
tokenized_dataset = raw_datasets.map(
|
| 852 |
+
tokenize_and_align_labels,
|
| 853 |
+
batched=True,
|
| 854 |
+
remove_columns=list(LABELS[data_args.dataset_name].keys()) + ["tokens", "id"],
|
| 855 |
+
load_from_cache_file=not data_args.overwrite_cache,
|
| 856 |
+
)
|
| 857 |
+
data_collator = DataCollatorForTokenClassification(tokenizer=tokenizer)
|
| 858 |
+
seqeval = evaluate.load("seqeval")
|
| 859 |
+
|
| 860 |
+
def compute_metrics(p):
|
| 861 |
+
predictions, labels = p
|
| 862 |
+
predictions = predictions[0]
|
| 863 |
+
predictions = np.argmax(predictions, axis=2)
|
| 864 |
+
|
| 865 |
+
true_predictions = [
|
| 866 |
+
[
|
| 867 |
+
config_kwargs["id2label"][p]
|
| 868 |
+
for (p, l) in zip(prediction, label)
|
| 869 |
+
if l != -100
|
| 870 |
+
]
|
| 871 |
+
for prediction, label in zip(predictions, labels)
|
| 872 |
+
]
|
| 873 |
+
true_labels = [
|
| 874 |
+
[
|
| 875 |
+
config_kwargs["id2label"][l]
|
| 876 |
+
for (p, l) in zip(prediction, label)
|
| 877 |
+
if l != -100
|
| 878 |
+
]
|
| 879 |
+
for prediction, label in zip(predictions, labels)
|
| 880 |
+
]
|
| 881 |
+
|
| 882 |
+
results = seqeval.compute(predictions=true_predictions, references=true_labels)
|
| 883 |
+
return {
|
| 884 |
+
"precision": results["overall_precision"],
|
| 885 |
+
"recall": results["overall_recall"],
|
| 886 |
+
"f1": results["overall_f1"],
|
| 887 |
+
"accuracy": results["overall_accuracy"],
|
| 888 |
+
}
|
| 889 |
+
|
| 890 |
+
trainer = MyTrainer(
|
| 891 |
+
model=model,
|
| 892 |
+
args=training_args,
|
| 893 |
+
train_dataset=tokenized_dataset["train"],
|
| 894 |
+
eval_dataset=tokenized_dataset["validation"],
|
| 895 |
+
tokenizer=tokenizer,
|
| 896 |
+
data_collator=data_collator,
|
| 897 |
+
compute_metrics=compute_metrics,
|
| 898 |
+
)
|
| 899 |
+
trainer.add_callback(StopTrainingCallback(custom_args.stop_after_n_steps))
|
| 900 |
+
|
| 901 |
+
trainer.train()
|
| 902 |
+
|
| 903 |
+
|
| 904 |
+
if __name__ == "__main__":
|
| 905 |
+
main()
|
llm2vec/experiments/test_word_task.py
ADDED
|
@@ -0,0 +1,393 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import sys
|
| 3 |
+
import logging
|
| 4 |
+
import argparse
|
| 5 |
+
from transformers import (
|
| 6 |
+
AutoTokenizer,
|
| 7 |
+
AutoConfig,
|
| 8 |
+
AutoModelForTokenClassification,
|
| 9 |
+
set_seed,
|
| 10 |
+
HfArgumentParser,
|
| 11 |
+
)
|
| 12 |
+
import torch
|
| 13 |
+
from datasets import load_dataset
|
| 14 |
+
import evaluate
|
| 15 |
+
import json
|
| 16 |
+
from tqdm import tqdm
|
| 17 |
+
from run_word_task import ModelForWordTask
|
| 18 |
+
from llm2vec import LLM2Vec
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
LABELS = {
|
| 22 |
+
"conll2003": {
|
| 23 |
+
"pos_tags": {
|
| 24 |
+
'"': 0,
|
| 25 |
+
"''": 1,
|
| 26 |
+
"#": 2,
|
| 27 |
+
"$": 3,
|
| 28 |
+
"(": 4,
|
| 29 |
+
")": 5,
|
| 30 |
+
",": 6,
|
| 31 |
+
".": 7,
|
| 32 |
+
":": 8,
|
| 33 |
+
"``": 9,
|
| 34 |
+
"CC": 10,
|
| 35 |
+
"CD": 11,
|
| 36 |
+
"DT": 12,
|
| 37 |
+
"EX": 13,
|
| 38 |
+
"FW": 14,
|
| 39 |
+
"IN": 15,
|
| 40 |
+
"JJ": 16,
|
| 41 |
+
"JJR": 17,
|
| 42 |
+
"JJS": 18,
|
| 43 |
+
"LS": 19,
|
| 44 |
+
"MD": 20,
|
| 45 |
+
"NN": 21,
|
| 46 |
+
"NNP": 22,
|
| 47 |
+
"NNPS": 23,
|
| 48 |
+
"NNS": 24,
|
| 49 |
+
"NN|SYM": 25,
|
| 50 |
+
"PDT": 26,
|
| 51 |
+
"POS": 27,
|
| 52 |
+
"PRP": 28,
|
| 53 |
+
"PRP$": 29,
|
| 54 |
+
"RB": 30,
|
| 55 |
+
"RBR": 31,
|
| 56 |
+
"RBS": 32,
|
| 57 |
+
"RP": 33,
|
| 58 |
+
"SYM": 34,
|
| 59 |
+
"TO": 35,
|
| 60 |
+
"UH": 36,
|
| 61 |
+
"VB": 37,
|
| 62 |
+
"VBD": 38,
|
| 63 |
+
"VBG": 39,
|
| 64 |
+
"VBN": 40,
|
| 65 |
+
"VBP": 41,
|
| 66 |
+
"VBZ": 42,
|
| 67 |
+
"WDT": 43,
|
| 68 |
+
"WP": 44,
|
| 69 |
+
"WP$": 45,
|
| 70 |
+
"WRB": 46,
|
| 71 |
+
},
|
| 72 |
+
"chunk_tags": {
|
| 73 |
+
"O": 0,
|
| 74 |
+
"B-ADJP": 1,
|
| 75 |
+
"I-ADJP": 2,
|
| 76 |
+
"B-ADVP": 3,
|
| 77 |
+
"I-ADVP": 4,
|
| 78 |
+
"B-CONJP": 5,
|
| 79 |
+
"I-CONJP": 6,
|
| 80 |
+
"B-INTJ": 7,
|
| 81 |
+
"I-INTJ": 8,
|
| 82 |
+
"B-LST": 9,
|
| 83 |
+
"I-LST": 10,
|
| 84 |
+
"B-NP": 11,
|
| 85 |
+
"I-NP": 12,
|
| 86 |
+
"B-PP": 13,
|
| 87 |
+
"I-PP": 14,
|
| 88 |
+
"B-PRT": 15,
|
| 89 |
+
"I-PRT": 16,
|
| 90 |
+
"B-SBAR": 17,
|
| 91 |
+
"I-SBAR": 18,
|
| 92 |
+
"B-UCP": 19,
|
| 93 |
+
"I-UCP": 20,
|
| 94 |
+
"B-VP": 21,
|
| 95 |
+
"I-VP": 22,
|
| 96 |
+
},
|
| 97 |
+
"ner_tags": {
|
| 98 |
+
"O": 0,
|
| 99 |
+
"B-PER": 1,
|
| 100 |
+
"I-PER": 2,
|
| 101 |
+
"B-ORG": 3,
|
| 102 |
+
"I-ORG": 4,
|
| 103 |
+
"B-LOC": 5,
|
| 104 |
+
"I-LOC": 6,
|
| 105 |
+
"B-MISC": 7,
|
| 106 |
+
"I-MISC": 8,
|
| 107 |
+
},
|
| 108 |
+
}
|
| 109 |
+
}
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
def str2bool(v):
|
| 113 |
+
if isinstance(v, bool):
|
| 114 |
+
return v
|
| 115 |
+
if v.lower() in ("yes", "true", "t", "y", "1"):
|
| 116 |
+
return True
|
| 117 |
+
elif v.lower() in ("no", "false", "f", "n", "0"):
|
| 118 |
+
return False
|
| 119 |
+
else:
|
| 120 |
+
raise argparse.ArgumentTypeError("Boolean value expected.")
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
if __name__ == "__main__":
|
| 124 |
+
logging.basicConfig(level=logging.INFO)
|
| 125 |
+
parser = argparse.ArgumentParser()
|
| 126 |
+
parser.add_argument("--model_class", default="custom", type=str)
|
| 127 |
+
parser.add_argument("--model_name_or_path", default=None, type=str)
|
| 128 |
+
parser.add_argument(
|
| 129 |
+
"--peft_addr",
|
| 130 |
+
default=None,
|
| 131 |
+
type=str,
|
| 132 |
+
help="The dir address where adapter_model.bin is saved.",
|
| 133 |
+
)
|
| 134 |
+
parser.add_argument(
|
| 135 |
+
"--cls_addr",
|
| 136 |
+
default=None,
|
| 137 |
+
type=str,
|
| 138 |
+
help="The dir address where classifier is saved.",
|
| 139 |
+
)
|
| 140 |
+
parser.add_argument("--bidirectional", default=True, type=str2bool)
|
| 141 |
+
parser.add_argument("--merge_subwords", default=True, type=str2bool)
|
| 142 |
+
parser.add_argument("--output_dir", default=None, type=str)
|
| 143 |
+
parser.add_argument("--classifier_dropout", default=0.1, type=float)
|
| 144 |
+
parser.add_argument(
|
| 145 |
+
"--attn_implementation",
|
| 146 |
+
default="sdpa",
|
| 147 |
+
type=str,
|
| 148 |
+
choices=["sdpa", "eager", "flash_attention_2"],
|
| 149 |
+
)
|
| 150 |
+
parser.add_argument(
|
| 151 |
+
"--torch_dtype",
|
| 152 |
+
default=None,
|
| 153 |
+
type=str,
|
| 154 |
+
choices=["auto", "bfloat16", "float16", "float32"],
|
| 155 |
+
)
|
| 156 |
+
|
| 157 |
+
parser.add_argument(
|
| 158 |
+
"--retroactive_labels",
|
| 159 |
+
default="next_token",
|
| 160 |
+
type=str,
|
| 161 |
+
choices=["next_token", "same_token"],
|
| 162 |
+
)
|
| 163 |
+
parser.add_argument("--dataset_name", default=None, type=str)
|
| 164 |
+
parser.add_argument(
|
| 165 |
+
"--task", default=None, type=str, choices=["pos_tags", "chunk_tags", "ner_tags"]
|
| 166 |
+
)
|
| 167 |
+
parser.add_argument("--max_seq_length", default=1024, type=int)
|
| 168 |
+
parser.add_argument("--batch_size", default=32, type=int)
|
| 169 |
+
parser.add_argument("--seed", default=32, type=int)
|
| 170 |
+
|
| 171 |
+
parser.add_argument("--config_file", default=None, type=str)
|
| 172 |
+
|
| 173 |
+
args = parser.parse_args()
|
| 174 |
+
|
| 175 |
+
if args.config_file is not None:
|
| 176 |
+
# If we pass only one argument to the script and it's the path to a json file,
|
| 177 |
+
# let's parse it to get our arguments.
|
| 178 |
+
from pathlib import Path
|
| 179 |
+
import json
|
| 180 |
+
|
| 181 |
+
json_text = json.load(open(os.path.abspath(args.config_file)))
|
| 182 |
+
argparse_dict = vars(args)
|
| 183 |
+
argparse_dict.update(json_text)
|
| 184 |
+
# args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
|
| 185 |
+
else:
|
| 186 |
+
args = parser.parse_args()
|
| 187 |
+
|
| 188 |
+
path_to_check = args.peft_addr if args.peft_addr else args.model_name_or_path
|
| 189 |
+
assert (
|
| 190 |
+
args.output_dir is not None
|
| 191 |
+
), "If you want to evaluate a model, you have to provide the output_dir"
|
| 192 |
+
os.makedirs(args.output_dir, exist_ok=True)
|
| 193 |
+
|
| 194 |
+
set_seed(args.seed)
|
| 195 |
+
|
| 196 |
+
tokenizer_kwargs = {}
|
| 197 |
+
if "gpt" in args.model_name_or_path:
|
| 198 |
+
tokenizer_kwargs["add_prefix_space"] = True
|
| 199 |
+
tokenizer = AutoTokenizer.from_pretrained(
|
| 200 |
+
args.model_name_or_path, **tokenizer_kwargs
|
| 201 |
+
)
|
| 202 |
+
if tokenizer.pad_token_id is None:
|
| 203 |
+
tokenizer.pad_token = tokenizer.eos_token
|
| 204 |
+
|
| 205 |
+
if args.model_class == "custom":
|
| 206 |
+
tokenizer.model_input_names.append("token_type_ids")
|
| 207 |
+
|
| 208 |
+
if args.model_class == "auto":
|
| 209 |
+
assert not args.merge_subwords
|
| 210 |
+
|
| 211 |
+
assert (
|
| 212 |
+
args.dataset_name in LABELS and args.task in LABELS[args.dataset_name]
|
| 213 |
+
), f"LABELS[{args.dataset_name}][{args.task}] is not defined."
|
| 214 |
+
|
| 215 |
+
config_kwargs = {
|
| 216 |
+
"num_labels": len(LABELS[args.dataset_name][args.task]),
|
| 217 |
+
"id2label": {
|
| 218 |
+
i: lab for (lab, i) in LABELS[args.dataset_name][args.task].items()
|
| 219 |
+
},
|
| 220 |
+
"label2id": LABELS[args.dataset_name][args.task],
|
| 221 |
+
"classifier_dropout": args.classifier_dropout,
|
| 222 |
+
}
|
| 223 |
+
|
| 224 |
+
if args.model_class == "custom":
|
| 225 |
+
if args.model_name_or_path:
|
| 226 |
+
config = AutoConfig.from_pretrained(
|
| 227 |
+
args.model_name_or_path, **config_kwargs
|
| 228 |
+
)
|
| 229 |
+
else:
|
| 230 |
+
raise ValueError("Invalid config loading")
|
| 231 |
+
|
| 232 |
+
for k, v in config_kwargs.items():
|
| 233 |
+
config.__setattr__(k, v)
|
| 234 |
+
|
| 235 |
+
torch_dtype = (
|
| 236 |
+
args.torch_dtype
|
| 237 |
+
if args.torch_dtype in ["auto", None]
|
| 238 |
+
else getattr(torch, args.torch_dtype)
|
| 239 |
+
)
|
| 240 |
+
l2v = LLM2Vec.from_pretrained(
|
| 241 |
+
base_model_name_or_path=args.model_name_or_path,
|
| 242 |
+
enable_bidirectional=args.bidirectional,
|
| 243 |
+
peft_model_name_or_path=args.peft_addr,
|
| 244 |
+
merge_peft=False,
|
| 245 |
+
torch_dtype=torch_dtype,
|
| 246 |
+
attn_implementation=args.attn_implementation,
|
| 247 |
+
)
|
| 248 |
+
model = ModelForWordTask(
|
| 249 |
+
model=l2v.model,
|
| 250 |
+
merge_subwords=args.merge_subwords,
|
| 251 |
+
config=config,
|
| 252 |
+
torch_dtype=torch_dtype,
|
| 253 |
+
)
|
| 254 |
+
|
| 255 |
+
classifier_path = os.path.join(args.cls_addr, "classifier.pt")
|
| 256 |
+
if os.path.exists(classifier_path):
|
| 257 |
+
print(f"Loading classifier from {classifier_path}")
|
| 258 |
+
model.classifier = torch.load(classifier_path)
|
| 259 |
+
else:
|
| 260 |
+
raise ValueError("classifier does not exist in", classifier_path)
|
| 261 |
+
|
| 262 |
+
elif args.model_class == "auto":
|
| 263 |
+
model = AutoModelForTokenClassification.from_pretrained(
|
| 264 |
+
args.model_name_or_path,
|
| 265 |
+
num_labels=len(LABELS[args.dataset_name][args.task]),
|
| 266 |
+
id2label={
|
| 267 |
+
i: lab for (lab, i) in LABELS[args.dataset_name][args.task].items()
|
| 268 |
+
},
|
| 269 |
+
label2id=LABELS[args.dataset_name][args.task],
|
| 270 |
+
)
|
| 271 |
+
else:
|
| 272 |
+
raise ValueError(
|
| 273 |
+
f"{args.model_class} is not implemented. Only auto and custom model_class options are valid."
|
| 274 |
+
)
|
| 275 |
+
|
| 276 |
+
model = model.cuda()
|
| 277 |
+
|
| 278 |
+
raw_datasets = load_dataset(args.dataset_name, split="test")
|
| 279 |
+
|
| 280 |
+
def tokenize_and_align_labels(examples):
|
| 281 |
+
task = args.task
|
| 282 |
+
tokenized_inputs = tokenizer(
|
| 283 |
+
examples["tokens"],
|
| 284 |
+
truncation=True,
|
| 285 |
+
is_split_into_words=True,
|
| 286 |
+
padding="max_length",
|
| 287 |
+
max_length=args.max_seq_length,
|
| 288 |
+
return_tensors="pt",
|
| 289 |
+
)
|
| 290 |
+
|
| 291 |
+
labels = []
|
| 292 |
+
words = []
|
| 293 |
+
for i, label in enumerate(examples[task]):
|
| 294 |
+
if args.retroactive_labels in ["same_token"]:
|
| 295 |
+
# if args.retroactive_labels == "next_word":
|
| 296 |
+
# label = label[1:] + [-100]
|
| 297 |
+
word_ids = tokenized_inputs.word_ids(batch_index=i)
|
| 298 |
+
previous_word_idx = None
|
| 299 |
+
label_ids = []
|
| 300 |
+
for word_idx in word_ids:
|
| 301 |
+
if word_idx is None:
|
| 302 |
+
label_ids.append(-100)
|
| 303 |
+
elif word_idx != previous_word_idx:
|
| 304 |
+
label_ids.append(label[word_idx])
|
| 305 |
+
else:
|
| 306 |
+
label_ids.append(-100)
|
| 307 |
+
previous_word_idx = word_idx
|
| 308 |
+
labels.append(label_ids)
|
| 309 |
+
word_ids = [-1 if w is None else w for w in word_ids]
|
| 310 |
+
words.append(word_ids)
|
| 311 |
+
elif args.retroactive_labels == "next_token":
|
| 312 |
+
word_ids = tokenized_inputs.word_ids(batch_index=i)
|
| 313 |
+
previous_word_idx = None
|
| 314 |
+
label_ids = []
|
| 315 |
+
for word_idx in word_ids:
|
| 316 |
+
if word_idx is None:
|
| 317 |
+
label_ids.append(-100)
|
| 318 |
+
elif word_idx != previous_word_idx:
|
| 319 |
+
label_ids.append(label[word_idx])
|
| 320 |
+
else:
|
| 321 |
+
label_ids.append(-100)
|
| 322 |
+
previous_word_idx = word_idx
|
| 323 |
+
label_ids.append(-100)
|
| 324 |
+
labels.append(label_ids[1:])
|
| 325 |
+
word_ids = word_ids[1:] + [None]
|
| 326 |
+
word_ids = [-1 if w is None else w for w in word_ids]
|
| 327 |
+
words.append(word_ids)
|
| 328 |
+
else:
|
| 329 |
+
raise ValueError(
|
| 330 |
+
f"retroactive_labels {args.retroactive_labels} is not implemented."
|
| 331 |
+
)
|
| 332 |
+
|
| 333 |
+
tokenized_inputs["labels"] = torch.tensor(labels)
|
| 334 |
+
if args.model_class == "custom":
|
| 335 |
+
tokenized_inputs["token_type_ids"] = words
|
| 336 |
+
return tokenized_inputs
|
| 337 |
+
|
| 338 |
+
tokenized_dataset = raw_datasets.map(
|
| 339 |
+
tokenize_and_align_labels,
|
| 340 |
+
batched=True,
|
| 341 |
+
remove_columns=list(LABELS[args.dataset_name].keys()) + ["tokens", "id"],
|
| 342 |
+
)
|
| 343 |
+
with torch.no_grad():
|
| 344 |
+
predictions = None
|
| 345 |
+
labels = None
|
| 346 |
+
for batch_begin in tqdm(
|
| 347 |
+
torch.arange(0, len(tokenized_dataset), args.batch_size)
|
| 348 |
+
):
|
| 349 |
+
features = {
|
| 350 |
+
"input_ids": torch.tensor(
|
| 351 |
+
tokenized_dataset[batch_begin : batch_begin + args.batch_size][
|
| 352 |
+
"input_ids"
|
| 353 |
+
]
|
| 354 |
+
).to(model.device),
|
| 355 |
+
"attention_mask": torch.tensor(
|
| 356 |
+
tokenized_dataset[batch_begin : batch_begin + args.batch_size][
|
| 357 |
+
"attention_mask"
|
| 358 |
+
]
|
| 359 |
+
).to(model.device),
|
| 360 |
+
}
|
| 361 |
+
if (
|
| 362 |
+
"token_type_ids"
|
| 363 |
+
in tokenized_dataset[batch_begin : batch_begin + args.batch_size]
|
| 364 |
+
):
|
| 365 |
+
features["token_type_ids"] = torch.tensor(
|
| 366 |
+
tokenized_dataset[batch_begin : batch_begin + args.batch_size][
|
| 367 |
+
"token_type_ids"
|
| 368 |
+
]
|
| 369 |
+
).to(model.device)
|
| 370 |
+
|
| 371 |
+
labs = torch.tensor(
|
| 372 |
+
tokenized_dataset[batch_begin : batch_begin + args.batch_size]["labels"]
|
| 373 |
+
)
|
| 374 |
+
|
| 375 |
+
logits = model(**features).logits
|
| 376 |
+
preds = torch.argmax(logits, dim=-1)
|
| 377 |
+
if predictions is None:
|
| 378 |
+
predictions = preds
|
| 379 |
+
labels = labs
|
| 380 |
+
else:
|
| 381 |
+
predictions = torch.concatenate((predictions, preds))
|
| 382 |
+
labels = torch.concatenate((labels, labs))
|
| 383 |
+
|
| 384 |
+
precision_metric = evaluate.load("precision")
|
| 385 |
+
metrics = precision_metric.compute(
|
| 386 |
+
references=labels[labels != -100],
|
| 387 |
+
predictions=predictions[labels != -100],
|
| 388 |
+
average="micro",
|
| 389 |
+
)
|
| 390 |
+
|
| 391 |
+
with open(os.path.join(args.output_dir, "result_summary.json"), "w") as f:
|
| 392 |
+
json.dump(metrics, f)
|
| 393 |
+
print(metrics)
|
llm2vec/images/sample_efficient.png
ADDED
|