tuandunghcmut committed (verified)
Commit e9cd0c7 · Parent(s): 0d2c90e

Add files using upload-large-folder tool

This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. Ovis/docs/license/GEMMA_LICENSE.txt +77 -0
  2. Ovis/docs/license/LLAMA3_LICENSE +84 -0
  3. Ovis/ovis/__pycache__/__init__.cpython-310.pyc +0 -0
  4. Ovis/ovis/__pycache__/__init__.cpython-311.pyc +0 -0
  5. Ovis/ovis/model/__pycache__/__init__.cpython-310.pyc +0 -0
  6. Ovis/ovis/model/__pycache__/__init__.cpython-311.pyc +0 -0
  7. Ovis/ovis/model/__pycache__/configuration_ovis.cpython-311.pyc +0 -0
  8. Ovis/ovis/model/__pycache__/modeling_ovis.cpython-311.pyc +0 -0
  9. Ovis/ovis/model/configuration_ovis.py +41 -0
  10. Ovis/ovis/model/conversation_formatter.py +233 -0
  11. Ovis/ovis/model/visual_tokenizer/__pycache__/base_visual_tokenizer.cpython-310.pyc +0 -0
  12. Ovis/ovis/model/visual_tokenizer/__pycache__/base_visual_tokenizer.cpython-311.pyc +0 -0
  13. Ovis/ovis/model/visual_tokenizer/__pycache__/clip_visual_tokenizer.cpython-310.pyc +0 -0
  14. Ovis/ovis/model/visual_tokenizer/__pycache__/clip_visual_tokenizer.cpython-311.pyc +0 -0
  15. Ovis/ovis/model/visual_tokenizer/__pycache__/siglip_visual_tokenizer.cpython-310.pyc +0 -0
  16. Ovis/ovis/model/visual_tokenizer/__pycache__/siglip_visual_tokenizer.cpython-311.pyc +0 -0
  17. Ovis/ovis/serve/runner.py +105 -0
  18. Ovis/ovis/serve/server.py +41 -0
  19. Ovis/ovis/train/__init__.py +0 -0
  20. Ovis/ovis/train/arguments.py +48 -0
  21. Ovis/ovis/train/callback.py +37 -0
  22. Ovis/ovis/train/train.py +206 -0
  23. Ovis/ovis/util/constants.py +11 -0
  24. Ovis/ovis/util/utils.py +26 -0
  25. llm2vec/docs/.gitignore +9 -0
  26. llm2vec/docs/Gemfile +18 -0
  27. llm2vec/docs/README.md +104 -0
  28. llm2vec/docs/_config.yml +110 -0
  29. llm2vec/docs/_data/navigation.yml +17 -0
  30. llm2vec/docs/_includes/head/custom.html +48 -0
  31. llm2vec/docs/_sass/custom/header-footer.scss +19 -0
  32. llm2vec/docs/_sass/custom/no-sidebar.scss +9 -0
  33. llm2vec/docs/_sass/custom/splash.scss +5 -0
  34. llm2vec/docs/_sass/skins/dark.scss +30 -0
  35. llm2vec/docs/_sass/skins/light.scss +12 -0
  36. llm2vec/docs/assets/images/logo/favicon.png +0 -0
  37. llm2vec/docs/assets/images/logo/logo.png +0 -0
  38. llm2vec/docs/assets/images/logo/logo.svg +0 -0
  39. llm2vec/examples/classification.py +62 -0
  40. llm2vec/examples/clustering.py +58 -0
  41. llm2vec/examples/retrieval.py +177 -0
  42. llm2vec/examples/sts.py +57 -0
  43. llm2vec/experiments/mteb_eval.py +31 -0
  44. llm2vec/experiments/mteb_eval_custom.py +98 -0
  45. llm2vec/experiments/run_mntp.py +997 -0
  46. llm2vec/experiments/run_simcse.py +388 -0
  47. llm2vec/experiments/run_supervised.py +482 -0
  48. llm2vec/experiments/run_word_task.py +905 -0
  49. llm2vec/experiments/test_word_task.py +393 -0
  50. llm2vec/images/sample_efficient.png +0 -0
Ovis/docs/license/GEMMA_LICENSE.txt ADDED
@@ -0,0 +1,77 @@
1
+ Gemma Terms of Use
2
+
3
+ Last modified: April 1, 2024
4
+
5
+ By using, reproducing, modifying, distributing, performing or displaying any portion or element of Gemma, Model Derivatives including via any Hosted Service, (each as defined below) (collectively, the "Gemma Services") or otherwise accepting the terms of this Agreement, you agree to be bound by this Agreement.
6
+
7
+ Section 1: DEFINITIONS
8
+ 1.1 Definitions
9
+ (a) "Agreement" or "Gemma Terms of Use" means these terms and conditions that govern the use, reproduction, Distribution or modification of the Gemma Services and any terms and conditions incorporated by reference.
10
+
11
+ (b) "Distribution" or "Distribute" means any transmission, publication, or other sharing of Gemma or Model Derivatives to a third party, including by providing or making Gemma or its functionality available as a hosted service via API, web access, or any other electronic or remote means ("Hosted Service").
12
+
13
+ (c) "Gemma" means the set of machine learning language models, trained model weights and parameters identified at ai.google.dev/gemma, regardless of the source that you obtained it from.
14
+
15
+ (d) "Google" means Google LLC.
16
+
17
+ (e) "Model Derivatives" means all (i) modifications to Gemma, (ii) works based on Gemma, or (iii) any other machine learning model which is created by transfer of patterns of the weights, parameters, operations, or Output of Gemma, to that model in order to cause that model to perform similarly to Gemma, including distillation methods that use intermediate data representations or methods based on the generation of synthetic data Outputs by Gemma for training that model. For clarity, Outputs are not deemed Model Derivatives.
18
+
19
+ (f) "Output" means the information content output of Gemma or a Model Derivative that results from operating or otherwise using Gemma or the Model Derivative, including via a Hosted Service.
20
+
21
+ 1.2
22
+ As used in this Agreement, "including" means "including without limitation".
23
+
24
+ Section 2: ELIGIBILITY AND USAGE
25
+ 2.1 Eligibility
26
+ You represent and warrant that you have the legal capacity to enter into this Agreement (including being of sufficient age of consent). If you are accessing or using any of the Gemma Services for or on behalf of a legal entity, (a) you are entering into this Agreement on behalf of yourself and that legal entity, (b) you represent and warrant that you have the authority to act on behalf of and bind that entity to this Agreement and (c) references to "you" or "your" in the remainder of this Agreement refers to both you (as an individual) and that entity.
27
+
28
+ 2.2 Use
29
+ You may use, reproduce, modify, Distribute, perform or display any of the Gemma Services only in accordance with the terms of this Agreement, and must not violate (or encourage or permit anyone else to violate) any term of this Agreement.
30
+
31
+ Section 3: DISTRIBUTION AND RESTRICTIONS
32
+ 3.1 Distribution and Redistribution
33
+ You may reproduce or Distribute copies of Gemma or Model Derivatives if you meet all of the following conditions:
34
+
35
+ You must include the use restrictions referenced in Section 3.2 as an enforceable provision in any agreement (e.g., license agreement, terms of use, etc.) governing the use and/or distribution of Gemma or Model Derivatives and you must provide notice to subsequent users you Distribute to that Gemma or Model Derivatives are subject to the use restrictions in Section 3.2.
36
+ You must provide all third party recipients of Gemma or Model Derivatives a copy of this Agreement.
37
+ You must cause any modified files to carry prominent notices stating that you modified the files.
38
+ All Distributions (other than through a Hosted Service) must be accompanied by a "Notice" text file that contains the following notice: "Gemma is provided under and subject to the Gemma Terms of Use found at ai.google.dev/gemma/terms".
39
+ You may add your own intellectual property statement to your modifications and, except as set forth in this Section, may provide additional or different terms and conditions for use, reproduction, or Distribution of your modifications, or for any such Model Derivatives as a whole, provided your use, reproduction, modification, Distribution, performance, and display of Gemma otherwise complies with the terms and conditions of this Agreement. Any additional or different terms and conditions you impose must not conflict with the terms of this Agreement.
40
+
41
+ 3.2 Use Restrictions
42
+ You must not use any of the Gemma Services:
43
+
44
+ for the restricted uses set forth in the Gemma Prohibited Use Policy at ai.google.dev/gemma/prohibited_use_policy ("Prohibited Use Policy"), which is hereby incorporated by reference into this Agreement; or
45
+ in violation of applicable laws and regulations.
46
+ To the maximum extent permitted by law, Google reserves the right to restrict (remotely or otherwise) usage of any of the Gemma Services that Google reasonably believes are in violation of this Agreement.
47
+
48
+ 3.3 Generated Output
49
+ Google claims no rights in Outputs you generate using Gemma. You and your users are solely responsible for Outputs and their subsequent uses.
50
+
51
+ Section 4: ADDITIONAL PROVISIONS
52
+ 4.1 Updates
53
+ Google may update Gemma from time to time.
54
+
55
+ 4.2 Trademarks
56
+ Nothing in this Agreement grants you any rights to use Google's trademarks, trade names, logos or to otherwise suggest endorsement or misrepresent the relationship between you and Google. Google reserves any rights not expressly granted herein.
57
+
58
+ 4.3 DISCLAIMER OF WARRANTY
59
+ UNLESS REQUIRED BY APPLICABLE LAW, THE GEMMA SERVICES, AND OUTPUTS, ARE PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING ANY WARRANTIES OR CONDITIONS OF TITLE, NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. YOU ARE SOLELY RESPONSIBLE FOR DETERMINING THE APPROPRIATENESS OF USING, REPRODUCING, MODIFYING, PERFORMING, DISPLAYING OR DISTRIBUTING ANY OF THE GEMMA SERVICES OR OUTPUTS AND ASSUME ANY AND ALL RISKS ASSOCIATED WITH YOUR USE OR DISTRIBUTION OF ANY OF THE GEMMA SERVICES OR OUTPUTS AND YOUR EXERCISE OF RIGHTS AND PERMISSIONS UNDER THIS AGREEMENT.
60
+
61
+ 4.4 LIMITATION OF LIABILITY
62
+ TO THE FULLEST EXTENT PERMITTED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL THEORY, WHETHER IN TORT (INCLUDING NEGLIGENCE), PRODUCT LIABILITY, CONTRACT, OR OTHERWISE, UNLESS REQUIRED BY APPLICABLE LAW, SHALL GOOGLE OR ITS AFFILIATES BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY DIRECT, INDIRECT, SPECIAL, INCIDENTAL, EXEMPLARY, CONSEQUENTIAL, OR PUNITIVE DAMAGES, OR LOST PROFITS OF ANY KIND ARISING FROM THIS AGREEMENT OR RELATED TO, ANY OF THE GEMMA SERVICES OR OUTPUTS EVEN IF GOOGLE OR ITS AFFILIATES HAVE BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGES.
63
+
64
+ 4.5 Term, Termination, and Survival
65
+ The term of this Agreement will commence upon your acceptance of this Agreement (including acceptance by your use, modification, or Distribution, reproduction, performance or display of any portion or element of the Gemma Services) and will continue in full force and effect until terminated in accordance with the terms of this Agreement. Google may terminate this Agreement if you are in breach of any term of this Agreement. Upon termination of this Agreement, you must delete and cease use and Distribution of all copies of Gemma and Model Derivatives in your possession or control. Sections 1, 2.1, 3.3, 4.2 to 4.9 shall survive the termination of this Agreement.
66
+
67
+ 4.6 Governing Law and Jurisdiction
68
+ This Agreement will be governed by the laws of the State of California without regard to choice of law principles. The UN Convention on Contracts for the International Sale of Goods does not apply to this Agreement. The state and federal courts of Santa Clara County, California shall have exclusive jurisdiction of any dispute arising out of this Agreement.
69
+
70
+ 4.7 Severability
71
+ If any provision of this Agreement is held to be invalid, illegal or unenforceable, the remaining provisions shall be unaffected thereby and remain valid as if such provision had not been set forth herein.
72
+
73
+ 4.8 Entire Agreement
74
+ This Agreement states all the terms agreed between the parties and supersedes all other agreements between the parties as of the date of acceptance relating to its subject matter.
75
+
76
+ 4.9 No Waiver
77
+ Google will not be treated as having waived any rights by not exercising (or delaying the exercise of) any rights under this Agreement.
Ovis/docs/license/LLAMA3_LICENSE ADDED
@@ -0,0 +1,84 @@
1
+ META LLAMA 3 COMMUNITY LICENSE AGREEMENT
2
+
3
+ Meta Llama 3 Version Release Date: April 18, 2024
4
+ “Agreement” means the terms and conditions for use, reproduction, distribution and modification of the Llama Materials set forth herein.
5
+
6
+ “Documentation” means the specifications, manuals and documentation accompanying Meta Llama 3 distributed by Meta at https://llama.meta.com/get-started/.
7
+
8
+ “Licensee” or “you” means you, or your employer or any other person or entity (if you are entering into this Agreement on such person or entity’s behalf), of the age required under applicable laws, rules or regulations to provide legal consent and that has legal authority to bind your employer or such other person or entity if you are entering in this Agreement on their behalf.
9
+
10
+ “Meta Llama 3” means the foundational large language models and software and algorithms, including machine-learning model code, trained model weights, inference-enabling code, training-enabling code, fine-tuning enabling code and other elements of the foregoing distributed by Meta at https://llama.meta.com/llama-downloads.
11
+
12
+ “Llama Materials” means, collectively, Meta’s proprietary Meta Llama 3 and Documentation (and any portion thereof) made available under this Agreement.
13
+
14
+ “Meta” or “we” means Meta Platforms Ireland Limited (if you are located in or, if you are an entity, your principal place of business is in the EEA or Switzerland) and Meta Platforms, Inc. (if you are located outside of the EEA or Switzerland).
15
+
16
+ By clicking “I Accept” below or by using or distributing any portion or element of the Llama Materials, you agree to be bound by this Agreement.
17
+
18
+ 1. License Rights and Redistribution.
19
+
20
+ a. Grant of Rights. You are granted a non-exclusive, worldwide, non-transferable and royalty-free limited license under Meta’s intellectual property or other rights owned by Meta embodied in the Llama Materials to use, reproduce, distribute, copy, create derivative works of, and make modifications to the Llama Materials.
21
+ b. Redistribution and Use.
22
+ i. If you distribute or make available the Llama Materials (or any derivative works thereof), or a product or service that uses any of them, including another AI model, you shall (A) provide a copy of this Agreement with any such Llama Materials; and (B) prominently display “Built with Meta Llama 3” on a related website, user interface, blogpost, about page, or product documentation. If you use the Llama Materials to create, train, fine tune, or otherwise improve an AI model, which is distributed or made available, you shall also include “Llama 3” at the beginning of any such AI model name.
23
+ ii. If you receive Llama Materials, or any derivative works thereof, from a Licensee as part of an integrated end user product, then Section 2 of this Agreement will not apply to you.
24
+ iii. You must retain in all copies of the Llama Materials that you distribute the following attribution notice within a “Notice” text file distributed as a part of such copies: “Meta Llama 3 is licensed under the Meta Llama 3 Community License, Copyright © Meta Platforms, Inc. All Rights Reserved.”
25
+ iv. Your use of the Llama Materials must comply with applicable laws and regulations (including trade compliance laws and regulations) and adhere to the Acceptable Use Policy for the Llama Materials (available at https://llama.meta.com/llama3/use-policy), which is hereby incorporated by reference into this Agreement.
26
+ v. You will not use the Llama Materials or any output or results of the Llama Materials to improve any other large language model (excluding Meta Llama 3 or derivative works thereof).
27
+
28
+ 2. Additional Commercial Terms. If, on the Meta Llama 3 version release date, the monthly active users of the products or services made available by or for Licensee, or Licensee’s affiliates, is greater than 700 million monthly active users in the preceding calendar month, you must request a license from Meta, which Meta may grant to you in its sole discretion, and you are not authorized to exercise any of the rights under this Agreement unless or until Meta otherwise expressly grants you such rights.
29
+
30
+ 3. Disclaimer of Warranty. UNLESS REQUIRED BY APPLICABLE LAW, THE LLAMA MATERIALS AND ANY OUTPUT AND RESULTS THEREFROM ARE PROVIDED ON AN “AS IS” BASIS, WITHOUT WARRANTIES OF ANY KIND, AND META DISCLAIMS ALL WARRANTIES OF ANY KIND, BOTH EXPRESS AND IMPLIED, INCLUDING, WITHOUT LIMITATION, ANY WARRANTIES OF TITLE, NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. YOU ARE SOLELY RESPONSIBLE FOR DETERMINING THE APPROPRIATENESS OF USING OR REDISTRIBUTING THE LLAMA MATERIALS AND ASSUME ANY RISKS ASSOCIATED WITH YOUR USE OF THE LLAMA MATERIALS AND ANY OUTPUT AND RESULTS.
31
+
32
+ 4. Limitation of Liability. IN NO EVENT WILL META OR ITS AFFILIATES BE LIABLE UNDER ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, TORT, NEGLIGENCE, PRODUCTS LIABILITY, OR OTHERWISE, ARISING OUT OF THIS AGREEMENT, FOR ANY LOST PROFITS OR ANY INDIRECT, SPECIAL, CONSEQUENTIAL, INCIDENTAL, EXEMPLARY OR PUNITIVE DAMAGES, EVEN IF META OR ITS AFFILIATES HAVE BEEN ADVISED OF THE POSSIBILITY OF ANY OF THE FOREGOING.
33
+
34
+ 5. Intellectual Property.
35
+ a. No trademark licenses are granted under this Agreement, and in connection with the Llama Materials, neither Meta nor Licensee may use any name or mark owned by or associated with the other or any of its affiliates, except as required for reasonable and customary use in describing and redistributing the Llama Materials or as set forth in this Section 5(a). Meta hereby grants you a license to use “Llama 3” (the “Mark”) solely as required to comply with the last sentence of Section 1.b.i. You will comply with Meta’s brand guidelines (currently accessible at https://about.meta.com/brand/resources/meta/company-brand/ ). All goodwill arising out of your use of the Mark will inure to the benefit of Meta.
36
+ b. Subject to Meta’s ownership of Llama Materials and derivatives made by or for Meta, with respect to any derivative works and modifications of the Llama Materials that are made by you, as between you and Meta, you are and will be the owner of such derivative works and modifications.
37
+ c. If you institute litigation or other proceedings against Meta or any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Llama Materials or Meta Llama 3 outputs or results, or any portion of any of the foregoing, constitutes infringement of intellectual property or other rights owned or licensable by you, then any licenses granted to you under this Agreement shall terminate as of the date such litigation or claim is filed or instituted. You will indemnify and hold harmless Meta from and against any claim by any third party arising out of or related to your use or distribution of the Llama Materials.
38
+
39
+ 6. Term and Termination. The term of this Agreement will commence upon your acceptance of this Agreement or access to the Llama Materials and will continue in full force and effect until terminated in accordance with the terms and conditions herein. Meta may terminate this Agreement if you are in breach of any term or condition of this Agreement. Upon termination of this Agreement, you shall delete and cease use of the Llama Materials. Sections 3, 4 and 7 shall survive the termination of this Agreement.
40
+
41
+ 7. Governing Law and Jurisdiction. This Agreement will be governed and construed under the laws of the State of California without regard to choice of law principles, and the UN Convention on Contracts for the International Sale of Goods does not apply to this Agreement. The courts of California shall have exclusive jurisdiction of any dispute arising out of this Agreement.
42
+
43
+
44
+ Meta Llama 3 Acceptable Use Policy
45
+ Meta is committed to promoting safe and fair use of its tools and features, including Meta Llama 3. If you access or use Meta Llama 3, you agree to this Acceptable Use Policy (“Policy”). The most recent copy of this policy can be found at https://llama.meta.com/llama3/use-policy
46
+ Prohibited Uses
47
+ We want everyone to use Meta Llama 3 safely and responsibly. You agree you will not use, or allow others to use, Meta Llama 3 to:
48
+ 1. Violate the law or others’ rights, including to:
49
+ a. Engage in, promote, generate, contribute to, encourage, plan, incite, or further illegal or unlawful activity or content, such as:
50
+ i. Violence or terrorism
51
+ ii. Exploitation or harm to children, including the solicitation, creation, acquisition, or dissemination of child exploitative content or failure to report Child Sexual Abuse Material
52
+ iii. Human trafficking, exploitation, and sexual violence
53
+ iv. The illegal distribution of information or materials to minors, including obscene materials, or failure to employ legally required age-gating in connection with such information or materials.
54
+ v. Sexual solicitation
55
+ vi. Any other criminal activity
56
+ b. Engage in, promote, incite, or facilitate the harassment, abuse, threatening, or bullying of individuals or groups of individuals
57
+ c. Engage in, promote, incite, or facilitate discrimination or other unlawful or harmful conduct in the provision of employment, employment benefits, credit, housing, other economic benefits, or other essential goods and services
58
+ d. Engage in the unauthorized or unlicensed practice of any profession including, but not limited to, financial, legal, medical/health, or related professional practices
59
+ e. Collect, process, disclose, generate, or infer health, demographic, or other sensitive personal or private information about individuals without rights and consents required by applicable laws
60
+ f. Engage in or facilitate any action or generate any content that infringes, misappropriates, or otherwise violates any third-party rights, including the outputs or results of any products or services using the Llama Materials
61
+ g. Create, generate, or facilitate the creation of malicious code, malware, computer viruses or do anything else that could disable, overburden, interfere with or impair the proper working, integrity, operation or appearance of a website or computer system
62
+
63
+ 2. Engage in, promote, incite, facilitate, or assist in the planning or development of activities that present a risk of death or bodily harm to individuals, including use of Meta Llama 3 related to the following:
64
+ a. Military, warfare, nuclear industries or applications, espionage, use for materials or activities that are subject to the International Traffic Arms Regulations (ITAR) maintained by the United States Department of State
65
+ b. Guns and illegal weapons (including weapon development)
66
+ c. Illegal drugs and regulated/controlled substances
67
+ d. Operation of critical infrastructure, transportation technologies, or heavy machinery
68
+ e. Self-harm or harm to others, including suicide, cutting, and eating disorders
69
+ f. Any content intended to incite or promote violence, abuse, or any infliction of bodily harm to an individual
70
+
71
+ 3. Intentionally deceive or mislead others, including use of Meta Llama 3 related to the following:
72
+ a. Generating, promoting, or furthering fraud or the creation or promotion of disinformation
73
+ b. Generating, promoting, or furthering defamatory content, including the creation of defamatory statements, images, or other content
74
+ c. Generating, promoting, or further distributing spam
75
+ d. Impersonating another individual without consent, authorization, or legal right
76
+ e. Representing that the use of Meta Llama 3 or outputs are human-generated
77
+ f. Generating or facilitating false online engagement, including fake reviews and other means of fake online engagement
78
+ g. Fail to appropriately disclose to end users any known dangers of your AI system
79
+
80
+ Please report any violation of this Policy, software “bug,” or other problems that could lead to a violation of this Policy through one of the following means:
81
+ * Reporting issues with the model: https://github.com/meta-llama/llama3
82
+ * Reporting risky content generated by the model: developers.facebook.com/llama_output_feedback
83
+ * Reporting bugs and security concerns: facebook.com/whitehat/info
84
+ * Reporting violations of the Acceptable Use Policy or unlicensed uses of Meta Llama 3: LlamaUseReport@meta.com
Ovis/ovis/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (206 Bytes).
 
Ovis/ovis/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (248 Bytes).
 
Ovis/ovis/model/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (388 Bytes).
 
Ovis/ovis/model/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (446 Bytes).
 
Ovis/ovis/model/__pycache__/configuration_ovis.cpython-311.pyc ADDED
Binary file (2.53 kB).
 
Ovis/ovis/model/__pycache__/modeling_ovis.cpython-311.pyc ADDED
Binary file (29.4 kB).
 
Ovis/ovis/model/configuration_ovis.py ADDED
@@ -0,0 +1,41 @@
+ from typing import Union, Optional
+
+ from transformers import PretrainedConfig, AutoConfig
+
+
+ class OvisConfig(PretrainedConfig):
+     model_type = "ovis"
+
+     def __init__(
+         self,
+         llm_config: Optional[Union[PretrainedConfig, dict]] = None,
+         visual_tokenizer_config: Optional[Union[PretrainedConfig, dict]] = None,
+         multimodal_max_length=8192,
+         hidden_size=None,
+         conversation_formatter_class=None,
+         llm_attn_implementation=None,
+         disable_tie_weight=False,
+         **kwargs
+     ):
+         super().__init__(**kwargs)
+         if llm_config is not None:
+             assert isinstance(llm_config, (PretrainedConfig, dict)), \
+                 f"expect `llm_config` to be instance of PretrainedConfig or dict, but got {type(llm_config)} type"
+             if not isinstance(llm_config, PretrainedConfig):
+                 model_type = llm_config['model_type']
+                 llm_config.pop('model_type')
+                 llm_config = AutoConfig.for_model(model_type, **llm_config)
+         self.llm_config = llm_config
+         if visual_tokenizer_config is not None:
+             assert isinstance(visual_tokenizer_config, (PretrainedConfig, dict)), \
+                 f"expect `visual_tokenizer_config` to be instance of PretrainedConfig or dict, but got {type(visual_tokenizer_config)} type"
+             if not isinstance(visual_tokenizer_config, PretrainedConfig):
+                 model_type = visual_tokenizer_config['model_type']
+                 visual_tokenizer_config.pop('model_type')
+                 visual_tokenizer_config = AutoConfig.for_model(model_type, **visual_tokenizer_config)
+         self.visual_tokenizer_config = visual_tokenizer_config
+         self.multimodal_max_length = multimodal_max_length
+         self.hidden_size = hidden_size
+         self.conversation_formatter_class = conversation_formatter_class
+         self.llm_attn_implementation = llm_attn_implementation
+         self.disable_tie_weight = disable_tie_weight
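
For reference, a minimal sketch of how this configuration class converts a plain dict into a nested `PretrainedConfig` via `AutoConfig.for_model`; the dict values below are illustrative assumptions, not part of this commit:

```python
# Sketch only: assumes the `ovis` package is importable.
from ovis.model.configuration_ovis import OvisConfig

config = OvisConfig(
    # a dict with a `model_type` key is resolved to the matching config class
    llm_config={"model_type": "llama", "num_hidden_layers": 2},  # illustrative values
    multimodal_max_length=8192,
    conversation_formatter_class="Llama3ConversationFormatter",
)
print(type(config.llm_config))  # a LlamaConfig instance, no longer a dict
# `visual_tokenizer_config` follows the same dict-to-config path.
```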
Ovis/ovis/model/conversation_formatter.py ADDED
@@ -0,0 +1,233 @@
+ from abc import ABC, abstractmethod
+ from typing import List, Dict
+
+ from ovis.util.constants import IMAGE_TOKEN_ID, IGNORE_ID, IMAGE_TOKEN
+
+
+ class ConversationFormatter(ABC):
+     support_tokenizer_types = None
+
+     def __init__(self, tokenizer):
+         tokenizer_type = type(tokenizer).__name__
+         assert tokenizer_type in self.support_tokenizer_types, \
+             f'Invalid tokenizer type, expected one from `{self.support_tokenizer_types}`, but got `{tokenizer_type}`'
+         self.tokenizer = tokenizer
+         self.image_token = IMAGE_TOKEN
+         self.image_token_id = IMAGE_TOKEN_ID
+         self.ignore_id = IGNORE_ID
+
+     def _tokenize_with_image_symbol(self, text):
+         text_chunks = [self.tokenizer(chunk, add_special_tokens=False).input_ids for chunk in
+                        text.split(self.image_token)]
+         token_ids = []
+         num_chunks = len(text_chunks)
+         for i, chunk in enumerate(text_chunks):
+             token_ids.extend(chunk)
+             if i < num_chunks - 1:
+                 token_ids.append(self.image_token_id)
+         return token_ids
+
+     @abstractmethod
+     def format(self, conversations: List[Dict], generation_preface=None):
+         pass
+
+     @abstractmethod
+     def format_query(self, query, generation_preface=""):
+         pass
+
+
+ class QwenConversationFormatter(ConversationFormatter):
+     support_tokenizer_types = ['QWenTokenizer', 'Qwen2TokenizerFast']
+
+     def __init__(self, tokenizer):
+         super().__init__(tokenizer)
+         self.from2role = {
+             "system": "<|im_start|>system\n",
+             "human": "<|im_start|>user\n",
+             "gpt": "<|im_start|>assistant\n",
+         }
+         self.gpt_token_num = None
+         self.im_end = "<|im_end|>\n"
+         self.default_system_prompt = "You are a helpful assistant."
+
+     def format(self, conversations: List[Dict], generation_preface=None):
+         if self.gpt_token_num is None:
+             self.gpt_token_num = len(self.tokenizer(self.from2role["gpt"], add_special_tokens=False).input_ids)
+
+         if conversations[0]["from"] != "system":
+             conversations.insert(0, {
+                 "from": "system",
+                 "value": self.default_system_prompt
+             })
+
+         if generation_preface is not None:
+             conversations.append({
+                 "from": "gpt",
+                 "value": generation_preface
+             })
+
+         prompt = ""
+         input_ids = []
+         labels = []
+         num_conversation = len(conversations)
+         for i, conversation in enumerate(conversations):
+             frm = conversation["from"]
+             role = self.from2role[frm]
+             message = conversation["value"]
+             text = role + message
+             if i < num_conversation - 1 or generation_preface is None:
+                 text += self.im_end
+             prompt += text
+             token_ids = self._tokenize_with_image_symbol(text)
+             input_ids.extend(token_ids)
+             label_ids = [self.ignore_id] * len(token_ids)
+             if frm == "gpt" and generation_preface is None:
+                 # learning `\n` following `im_end` is meaningless, so the last `\n` token is ignored in label
+                 label_ids[self.gpt_token_num:-1] = token_ids[self.gpt_token_num:-1]
+             labels.extend(label_ids)
+
+         assert self._tokenize_with_image_symbol(prompt) == input_ids
+         assert len(input_ids) == len(labels)
+
+         return prompt, input_ids, labels
+
+     def format_query(self, query, generation_preface=""):
+         prompt, input_ids, _ = self.format([{
+             "from": "human",
+             "value": query
+         }], generation_preface=generation_preface)
+
+         return prompt, input_ids
+
+
+ class Llama3ConversationFormatter(ConversationFormatter):
+     support_tokenizer_types = ['PreTrainedTokenizerFast']
+
+     def __init__(self, tokenizer):
+         super().__init__(tokenizer)
+         self.from2role = {
+             "system": "<|start_header_id|>system<|end_header_id|>\n\n",
+             "human": "<|start_header_id|>user<|end_header_id|>\n\n",
+             "gpt": "<|start_header_id|>assistant<|end_header_id|>\n\n",
+         }
+         self.gpt_token_num = None
+         self.im_end = "<|eot_id|>"
+         self.default_system_prompt = "You are a helpful and honest multimodal assistant."
+         self.bos_token = "<|begin_of_text|>"
+         self.bos_token_ids = None
+
+     def format(self, conversations: List[Dict], generation_preface=None):
+         if self.gpt_token_num is None:
+             self.gpt_token_num = len(self.tokenizer(self.from2role["gpt"], add_special_tokens=False).input_ids)
+
+         if self.bos_token_ids is None:
+             self.bos_token_ids = self.tokenizer(self.bos_token, add_special_tokens=False).input_ids
+
+         if conversations[0]["from"] != "system":
+             conversations.insert(0, {
+                 "from": "system",
+                 "value": self.default_system_prompt
+             })
+
+         if generation_preface is not None:
+             conversations.append({
+                 "from": "gpt",
+                 "value": generation_preface
+             })
+
+         prompt = "" + self.bos_token
+         input_ids = [] + self.bos_token_ids
+         labels = [] + [IGNORE_ID] * len(input_ids)
+         num_conversation = len(conversations)
+         for i, conversation in enumerate(conversations):
+             frm = conversation["from"]
+             role = self.from2role[frm]
+             message = conversation["value"].strip()
+             text = role + message
+             if i < num_conversation - 1 or generation_preface is None:
+                 text += self.im_end
+             prompt += text
+             token_ids = self._tokenize_with_image_symbol(text)
+             input_ids.extend(token_ids)
+             label_ids = [self.ignore_id] * len(token_ids)
+             if frm == "gpt":
+                 label_ids[self.gpt_token_num:] = token_ids[self.gpt_token_num:]
+             labels.extend(label_ids)
+
+         assert self._tokenize_with_image_symbol(prompt) == input_ids
+         assert len(input_ids) == len(labels)
+
+         return prompt, input_ids, labels
+
+     def format_query(self, query, generation_preface=""):
+         prompt, input_ids, _ = self.format([{
+             "from": "human",
+             "value": query
+         }], generation_preface=generation_preface)
+
+         return prompt, input_ids
+
+
+ class GemmaConversationFormatter(ConversationFormatter):
+     support_tokenizer_types = ['GemmaTokenizer', 'GemmaTokenizerFast']
+
+     def __init__(self, tokenizer):
+         super().__init__(tokenizer)
+         # Gemma does not support system prompt
+         self.from2role = {
+             "human": "<start_of_turn>user\n",
+             "gpt": "<start_of_turn>model\n",
+         }
+         self.gpt_token_num = None
+         self.im_end = "<end_of_turn>\n"
+         self.bos_token = "<bos>"
+         self.bos_token_ids = None
+
+     def format(self, conversations: List[Dict], generation_preface=None):
+         if self.gpt_token_num is None:
+             self.gpt_token_num = len(self.tokenizer(self.from2role["gpt"], add_special_tokens=False).input_ids)
+
+         if self.bos_token_ids is None:
+             self.bos_token_ids = self.tokenizer(self.bos_token, add_special_tokens=False).input_ids
+
+         if conversations[0]["from"] == "system":
+             raise ValueError("Gemma does not support system prompt")
+
+         if generation_preface is not None:
+             conversations.append({
+                 "from": "gpt",
+                 "value": generation_preface
+             })
+
+         prompt = "" + self.bos_token
+         input_ids = [] + self.bos_token_ids
+         labels = [] + [IGNORE_ID] * len(input_ids)
+         num_conversation = len(conversations)
+         for i, conversation in enumerate(conversations):
+             frm = conversation["from"]
+             role = self.from2role[frm]
+             message = conversation["value"].strip()
+             text = role + message
+             if i < num_conversation - 1 or generation_preface is None:
+                 text += self.im_end
+             prompt += text
+             token_ids = self._tokenize_with_image_symbol(text)
+             input_ids.extend(token_ids)
+             label_ids = [self.ignore_id] * len(token_ids)
+             if frm == "gpt":
+                 # learning `\n` following `im_end` is meaningless, so the last `\n` token is ignored in label
+                 label_ids[self.gpt_token_num:-1] = token_ids[self.gpt_token_num:-1]
+             labels.extend(label_ids)
+
+         assert self._tokenize_with_image_symbol(prompt) == input_ids
+         assert len(input_ids) == len(labels)
+
+         return prompt, input_ids, labels
+
+     def format_query(self, query, generation_preface=""):
+         prompt, input_ids, _ = self.format([{
+             "from": "human",
+             "value": query
+         }], generation_preface=generation_preface)
+
+         return prompt, input_ids
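
A hedged usage sketch for the formatters above: given a tokenizer whose class name appears in `support_tokenizer_types`, `format` returns the rendered prompt plus aligned `input_ids`/`labels`, with non-assistant tokens masked to `IGNORE_ID` and each `<image>` mapped to `IMAGE_TOKEN_ID`. The checkpoint name is an illustrative assumption:

```python
# Sketch, not part of the commit: assumes a Qwen2 chat tokenizer is available locally or on the Hub.
from transformers import AutoTokenizer
from ovis.model.conversation_formatter import QwenConversationFormatter

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen1.5-7B-Chat")  # loads a Qwen2TokenizerFast
formatter = QwenConversationFormatter(tokenizer)

conversations = [
    {"from": "human", "value": "<image>\nDescribe the image."},
    {"from": "gpt", "value": "A cat sitting on a windowsill."},
]
prompt, input_ids, labels = formatter.format(conversations)
# input_ids contains IMAGE_TOKEN_ID (-200) where <image> appeared;
# labels are IGNORE_ID (-100) everywhere except the assistant reply tokens.

# For inference, format_query appends an (empty) assistant turn so generation can continue:
prompt, input_ids = formatter.format_query("<image>\nWhat is shown here?")
```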
Ovis/ovis/model/visual_tokenizer/__pycache__/base_visual_tokenizer.cpython-310.pyc ADDED
Binary file (10.1 kB).
 
Ovis/ovis/model/visual_tokenizer/__pycache__/base_visual_tokenizer.cpython-311.pyc ADDED
Binary file (18.3 kB).
 
Ovis/ovis/model/visual_tokenizer/__pycache__/clip_visual_tokenizer.cpython-310.pyc ADDED
Binary file (2.03 kB).
 
Ovis/ovis/model/visual_tokenizer/__pycache__/clip_visual_tokenizer.cpython-311.pyc ADDED
Binary file (3.03 kB).
 
Ovis/ovis/model/visual_tokenizer/__pycache__/siglip_visual_tokenizer.cpython-310.pyc ADDED
Binary file (2.05 kB).
 
Ovis/ovis/model/visual_tokenizer/__pycache__/siglip_visual_tokenizer.cpython-311.pyc ADDED
Binary file (3.06 kB).
 
Ovis/ovis/serve/runner.py ADDED
@@ -0,0 +1,105 @@
+ from dataclasses import field, dataclass
+ from typing import Optional, Union, List
+
+ import torch
+ from PIL import Image
+
+ from ovis.model.modeling_ovis import Ovis
+ from ovis.util.constants import IMAGE_TOKEN
+
+
+ @dataclass
+ class RunnerArguments:
+     model_path: str
+     max_new_tokens: int = field(default=512)
+     do_sample: bool = field(default=False)
+     top_p: Optional[float] = field(default=None)
+     top_k: Optional[int] = field(default=None)
+     temperature: Optional[float] = field(default=None)
+     max_partition: int = field(default=9)
+
+
+ class OvisRunner:
+     def __init__(self, args: RunnerArguments):
+         self.model_path = args.model_path
+         self.device = torch.cuda.current_device()
+         self.dtype = torch.bfloat16
+         self.model = Ovis.from_pretrained(self.model_path, torch_dtype=self.dtype, multimodal_max_length=8192)
+         self.model = self.model.eval().to(device=self.device)
+         self.eos_token_id = self.model.generation_config.eos_token_id
+         self.text_tokenizer = self.model.get_text_tokenizer()
+         self.pad_token_id = self.text_tokenizer.pad_token_id
+         self.visual_tokenizer = self.model.get_visual_tokenizer()
+         self.conversation_formatter = self.model.get_conversation_formatter()
+         self.image_placeholder = IMAGE_TOKEN
+         self.max_partition = args.max_partition
+         self.gen_kwargs = dict(
+             max_new_tokens=args.max_new_tokens,
+             do_sample=args.do_sample,
+             top_p=args.top_p,
+             top_k=args.top_k,
+             temperature=args.temperature,
+             repetition_penalty=None,
+             eos_token_id=self.eos_token_id,
+             pad_token_id=self.pad_token_id,
+             use_cache=True
+         )
+
+     def preprocess(self, inputs: List[Union[Image.Image, str]]):
+         # for single image and single text inputs, ensure image ahead
+         if len(inputs) == 2 and isinstance(inputs[0], str) and isinstance(inputs[1], Image.Image):
+             inputs = reversed(inputs)
+
+         # build query
+         query = ''
+         images = []
+         for data in inputs:
+             if isinstance(data, Image.Image):
+                 query += self.image_placeholder + '\n'
+                 images.append(data)
+             elif isinstance(data, str):
+                 query += data.replace(self.image_placeholder, '')
+             elif data is not None:
+                 raise RuntimeError(f'Invalid input type, expected `PIL.Image.Image` or `str`, but got {type(data)}')
+
+         # format conversation
+         prompt, input_ids, pixel_values = self.model.preprocess_inputs(
+             query, images, max_partition=self.max_partition)
+         attention_mask = torch.ne(input_ids, self.text_tokenizer.pad_token_id)
+         input_ids = input_ids.unsqueeze(0).to(device=self.device)
+         attention_mask = attention_mask.unsqueeze(0).to(device=self.device)
+         if pixel_values is not None:
+             pixel_values = [pixel_values.to(device=self.device, dtype=self.dtype)]
+         else:
+             pixel_values = [None]
+
+         return prompt, input_ids, attention_mask, pixel_values
+
+     def run(self, inputs: List[Union[Image.Image, str]]):
+         prompt, input_ids, attention_mask, pixel_values = self.preprocess(inputs)
+         output_ids = self.model.generate(
+             input_ids,
+             pixel_values=pixel_values,
+             attention_mask=attention_mask,
+             **self.gen_kwargs
+         )
+         output = self.text_tokenizer.decode(output_ids[0], skip_special_tokens=True)
+         input_token_len = input_ids.shape[1]
+         output_token_len = output_ids.shape[1]
+         response = dict(
+             prompt=prompt,
+             output=output,
+             prompt_tokens=input_token_len,
+             total_tokens=input_token_len + output_token_len
+         )
+         return response
+
+
+ if __name__ == '__main__':
+     runner_args = RunnerArguments(model_path='<model_path>')
+     runner = OvisRunner(runner_args)
+     image = Image.open('<image_path>')
+     text = '<prompt>'
+     response = runner.run([image, text])
+     print(response['output'])
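
Beyond the `__main__` block above, the fields returned by `run` can be inspected directly; a small sketch with placeholder paths in the same style as the file:

```python
# Sketch: the response dict carries the rendered prompt, the decoded answer, and token counts.
from PIL import Image
from ovis.serve.runner import RunnerArguments, OvisRunner

runner = OvisRunner(RunnerArguments(model_path='<model_path>', max_new_tokens=256))
response = runner.run([Image.open('<image_path>'), 'What objects are visible in this image?'])
print(response['output'])                                     # decoded answer text
print(response['prompt_tokens'], response['total_tokens'])    # prompt tokens vs. prompt + generated
```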
Ovis/ovis/serve/server.py ADDED
@@ -0,0 +1,41 @@
+ import argparse
+ import os.path
+
+ import gradio as gr
+ from gradio.components import Textbox, Image
+
+ from ovis.serve.runner import RunnerArguments, OvisRunner
+
+
+ class Server:
+     def __init__(self, runner: OvisRunner):
+         self.runner = runner
+
+     def __call__(self, image, text):
+         response = self.runner.run([image, text])
+         output = response["output"]
+         return output
+
+
+ if __name__ == '__main__':
+     parser = argparse.ArgumentParser(description='Ovis Server')
+     parser.add_argument('--model_path', type=str, required=True)
+     parser.add_argument('--flagging_dir', type=str, default=os.path.expanduser('~/ovis-flagged'))
+     parser.add_argument('--max_partition', type=int, default=9)
+     parser.add_argument('--port', type=int, required=True)
+     args = parser.parse_args()
+
+     os.makedirs(args.flagging_dir, exist_ok=True)
+     runner_args = RunnerArguments(
+         model_path=args.model_path,
+         max_partition=args.max_partition
+     )
+     demo = gr.Interface(
+         fn=Server(OvisRunner(runner_args)),
+         inputs=[Image(type='pil', label='image'),
+                 Textbox(placeholder='Enter your text here...', label='prompt')],
+         outputs=gr.Markdown(),
+         title=args.model_path.split('/')[-1],
+         flagging_dir=args.flagging_dir
+     )
+     demo.launch(server_port=args.port)
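
The Gradio app is launched from the command line with the flags defined above (e.g. `python ovis/serve/server.py --model_path <model_path> --port <port>`). The `Server` callable can also be exercised without Gradio, as in this sketch (placeholder paths as in the repo's own examples):

```python
# Sketch: drive the Server callable directly, reusing OvisRunner defined earlier in this commit.
from PIL import Image
from ovis.serve.runner import RunnerArguments, OvisRunner
from ovis.serve.server import Server

server = Server(OvisRunner(RunnerArguments(model_path='<model_path>')))
answer = server(Image.open('<image_path>'), 'Describe this image.')
print(answer)
```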
Ovis/ovis/train/__init__.py ADDED
File without changes
Ovis/ovis/train/arguments.py ADDED
@@ -0,0 +1,48 @@
+ from dataclasses import dataclass, field
+ from typing import Optional
+
+ import transformers
+
+
+ @dataclass
+ class ModelArguments:
+     llm_name_or_path: Optional[str] = field(default=None)
+     visual_tokenizer_type: str = field(default=None)
+     visual_vocab_size: int = field(default=8192)
+     visual_drop_cls_token: bool = field(default=False)
+     visual_tokenize_function: str = field(default='softmax')
+     visual_tau: float = field(default=1.0)
+     visual_depths: Optional[str] = field(default=None)
+     visual_hidden_stride: int = field(default=1)
+     multimodal_max_length: int = field(default=2048)
+     conversation_formatter_class: str = field(default=None)
+     pad_token_id: Optional[int] = field(default=None)
+     llm_attn_implementation: Optional[str] = field(default=None)
+     disable_tie_weight: bool = field(default=False)
+
+
+ @dataclass
+ class TrainingArguments(transformers.TrainingArguments):
+     dataset_names: Optional[str] = field(default=None)  # a|b|c
+     dataset_info: Optional[str] = field(default='dataset_info_v1_6')
+     ovis_pretrained_path: Optional[str] = field(default=None)
+     visual_tokenizer_pretrained_path: Optional[str] = field(default=None)
+     caption_template: Optional[str] = field(default=None)
+     stage: Optional[int] = field(default=None)
+     train_modules: Optional[str] = field(default=None)
+     cache_dir: Optional[str] = field(default=None)
+     optim: str = field(default="adamw_torch")
+     visual_max_tau: float = field(default=5.0)
+     visual_min_tau: float = field(default=0.05)
+     save_safetensors: bool = field(default=True)
+     monitor_step: int = field(default=100)
+     vte_re_init: bool = field(default=False)
+     text_max_length: int = field(default=1024)
+     max_partitions: str = field(default="9|1|1")
+
+     def __post_init__(self):
+         if self.gradient_checkpointing:
+             self.gradient_checkpointing_kwargs = {"use_reentrant": False}
+         if self.stage < 3:
+             self.save_safetensors = False
+         super().__post_init__()
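
These dataclasses are consumed by `HfArgumentParser` in `train.py` below; a minimal sketch of that parsing flow, with illustrative flag values (the dataset names are placeholders):

```python
# Sketch: parse the training dataclasses from a synthetic argv list.
import transformers
from ovis.train.arguments import ModelArguments, TrainingArguments

parser = transformers.HfArgumentParser((ModelArguments, TrainingArguments))
model_args, training_args = parser.parse_args_into_dataclasses(args=[
    "--llm_name_or_path", "meta-llama/Meta-Llama-3-8B-Instruct",
    "--visual_tokenizer_type", "siglip",
    "--conversation_formatter_class", "Llama3ConversationFormatter",
    "--output_dir", "./checkpoints/stage3",
    "--stage", "3",
    "--train_modules", "all",
    "--dataset_names", "caption_a|conversation_b",   # '|'-separated, per the comment above
])
print(model_args.visual_vocab_size)   # 8192 (default)
print(training_args.max_partitions)   # "9|1|1" (default)
```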
Ovis/ovis/train/callback.py ADDED
@@ -0,0 +1,37 @@
+ import deepspeed
+ import torch
+ from transformers import TrainerCallback, TrainingArguments, TrainerState, TrainerControl
+
+ from ovis.util.constants import END_LINE, BEGIN_LINE
+ from ovis.util.utils import rank0_print
+
+
+ class TuneTauCallback(TrainerCallback):
+     def on_step_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
+         visual_tokenizer = kwargs['model'].get_visual_tokenizer()
+         current_step = state.global_step
+         max_step = state.max_steps
+         ratio = current_step / max_step
+         visual_tokenizer.config.tau = args.visual_max_tau - (args.visual_max_tau - args.visual_min_tau) * ratio
+
+
+ class MonitorCallback(TrainerCallback):
+     def _monitoring(self, model, step):
+         with torch.no_grad():
+             with deepspeed.zero.GatheredParameters(model.get_monitor_tensors().values()):
+                 for k, v in model.get_monitor_tensors().items():
+                     rank0_print(BEGIN_LINE)
+                     rank0_print(f'{k} @ step {step} with sum: {v.sum().item()} and content: ')
+                     rank0_print(v)
+                     rank0_print(END_LINE)
+
+     def on_step_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
+         model = kwargs['model']
+         step = state.global_step
+         if step % args.monitor_step == 0 or step == 10:  # monitor at step 10 for fast check
+             self._monitoring(model, step)
+
+     def on_epoch_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
+         model = kwargs['model']
+         step = state.global_step
+         self._monitoring(model, step)
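
`TuneTauCallback` linearly anneals the Gumbel temperature from `visual_max_tau` down to `visual_min_tau` over training; the schedule reduces to this small function (a restatement for clarity, using the defaults from `arguments.py`):

```python
# The tau schedule implemented by TuneTauCallback.on_step_begin, restated.
def scheduled_tau(step: int, max_steps: int, max_tau: float = 5.0, min_tau: float = 0.05) -> float:
    ratio = step / max_steps
    return max_tau - (max_tau - min_tau) * ratio

assert scheduled_tau(0, 1000) == 5.0                  # start of training
assert abs(scheduled_tau(1000, 1000) - 0.05) < 1e-9   # end of training
```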
Ovis/ovis/train/train.py ADDED
@@ -0,0 +1,206 @@
+ import json
+ import os
+ import pathlib
+
+ import deepspeed
+ import torch
+ import transformers
+ from deepspeed import get_accelerator
+ from torch.utils.data import ConcatDataset
+ from transformers import AutoModelForCausalLM, AutoTokenizer, AutoModel, AutoConfig
+ from transformers import Trainer
+ from transformers.integrations.deepspeed import unset_hf_deepspeed_config, set_hf_deepspeed_config
+
+ from ovis.train.callback import TuneTauCallback, MonitorCallback
+ from ovis.model.configuration_ovis import OvisConfig
+ from ovis.model.modeling_ovis import Ovis
+ from ovis.train.arguments import ModelArguments, TrainingArguments
+ from ovis.train.dataset.caption_dataset import CaptionDataset
+ from ovis.train.dataset.conversation_dataset import ConversationDataset
+ from ovis.train.dataset.multimodal_dataset import DataCollatorForMultimodalDataset
+ from ovis.util.constants import BEGIN_LINE, END_LINE
+ from ovis.util.utils import smart_unit, rank0_print
+
+
+ def train():
+     # parse args
+     parser = transformers.HfArgumentParser(
+         (ModelArguments, TrainingArguments))
+     model_args, training_args = parser.parse_args_into_dataclasses()
+
+     # save args to checkpoint dir
+     with training_args.main_process_first(local=False):
+         if training_args.process_index == 0:
+             def args2dict(args):
+                 return {k: str(v) for k, v in args.__dict__.items()}
+
+             args_log = json.dumps(dict(
+                 model_args=args2dict(model_args),
+                 training_args=args2dict(training_args)
+             ), ensure_ascii=False, indent=2)
+             print(args_log)
+             os.makedirs(training_args.output_dir, exist_ok=True)
+             with open(os.path.join(training_args.output_dir, 'model_training_args.json'), 'w',
+                       encoding='utf-8') as f:
+                 f.write(args_log + '\n')
+
+     # construct or load ovis model
+     if not training_args.ovis_pretrained_path:  # construct model (S1)
+         # 1. construct ovis config
+         ovis_config = OvisConfig(
+             multimodal_max_length=model_args.multimodal_max_length,
+             conversation_formatter_class=model_args.conversation_formatter_class,
+             llm_attn_implementation=model_args.llm_attn_implementation
+         )
+         # 2. load pretrained llm and text tokenizer
+         attn_kwargs = dict()
+         if model_args.llm_attn_implementation:
+             attn_kwargs['attn_implementation'] = model_args.llm_attn_implementation
+         llm = AutoModelForCausalLM.from_pretrained(model_args.llm_name_or_path, **attn_kwargs)
+         text_tokenizer = AutoTokenizer.from_pretrained(model_args.llm_name_or_path)
+         if text_tokenizer.pad_token_id is None and model_args.pad_token_id is not None:
+             text_tokenizer.pad_token_id = model_args.pad_token_id
+         # 3. construct visual tokenizer
+         # deepspeed zero.Init with bfloat16 fails for visual_tokenizer, so temporarily disable zero.Init here
+         unset_hf_deepspeed_config()
+         if training_args.visual_tokenizer_pretrained_path is not None:
+             visual_tokenizer = AutoModel.from_pretrained(
+                 training_args.visual_tokenizer_pretrained_path,
+                 image_processor_name_or_path=training_args.visual_tokenizer_pretrained_path
+             )
+         else:
+             visual_tokenizer_config = AutoConfig.for_model(
+                 model_type=model_args.visual_tokenizer_type + "_visual_tokenizer",
+                 vocab_size=model_args.visual_vocab_size,
+                 tokenize_function=model_args.visual_tokenize_function,
+                 tau=model_args.visual_tau,
+                 depths=model_args.visual_depths,
+                 drop_cls_token=model_args.visual_drop_cls_token,
+                 hidden_stride=model_args.visual_hidden_stride,
+             )
+             visual_tokenizer = AutoModel.from_config(visual_tokenizer_config, train_from_scratch=True)
+         visual_tokenizer = visual_tokenizer.to(
+             device=torch.device(get_accelerator().device_name(os.getenv("LOCAL_RANK"))))
+         if getattr(training_args, 'hf_deepspeed_config', None) is not None:
+             set_hf_deepspeed_config(training_args.hf_deepspeed_config)
+         # 4. construct ovis model
+         model = Ovis(ovis_config, llm=llm, text_tokenizer=text_tokenizer, visual_tokenizer=visual_tokenizer,
+                      train_from_scratch=True)
+     else:  # load pretrained ovis model
+         model, loading_info = Ovis.from_pretrained(training_args.ovis_pretrained_path,
+                                                    multimodal_max_length=model_args.multimodal_max_length,
+                                                    output_loading_info=True)
+         rank0_print(BEGIN_LINE)
+         rank0_print(f'Loading info of Ovis:\n{loading_info}')
+         rank0_print(END_LINE)
+         training_args.vte_re_init = False
+
+     model.get_llm().config.use_cache = False
+     model.config.use_cache = False
+     text_tokenizer = model.get_text_tokenizer()
+
+     rank0_print(BEGIN_LINE)
+     rank0_print(f'model.config:\n{model.config}')
+     rank0_print(END_LINE)
+
+     # maybe re-init vte
+     if training_args.vte_re_init:
+         with deepspeed.zero.GatheredParameters([model.get_wte().weight]):
+             mean = model.get_wte().weight.mean().item()
+             std = model.get_wte().weight.std().item()
+             rank0_print(f'Statistics of embedding table of LLM: {mean=}, {std=}')
+             model.re_init_vte(mean, std)
+
+     # select train modules
+     model.requires_grad_(False)
+     for module in training_args.train_modules.split('|'):
+         if module == 'all':
+             model.requires_grad_(True)
+         elif module == 'llm':
+             model.get_llm().requires_grad_(True)
+         elif module == 'visual_tokenizer':
+             model.get_visual_tokenizer().requires_grad_(True)
+         elif module == 'visual_tokenizer.backbone':
+             model.get_visual_tokenizer().get_backbone().requires_grad_(True)
+         elif module.startswith('visual_tokenizer.backbone.layer.'):
+             layer_index = int(module[len('visual_tokenizer.backbone.layer.'):])
+             layer = model.get_visual_tokenizer().get_backbone_layer(layer_index)
+             layer.requires_grad_(True)
+         elif module == 'visual_tokenizer.head':
+             model.get_visual_tokenizer().get_head().requires_grad_(True)
+         elif module == 'vte':
+             model.get_vte().requires_grad_(True)
+         else:
+             raise ValueError(f'Invalid train module name: {module}')
+
+     rank0_print(BEGIN_LINE)
+     rank0_print('Parameters to train:')
+     for name, param in model.named_parameters():
+         if param.requires_grad:
+             rank0_print(name)
+     rank0_print(f'LLM\'s attn implementation: {model.get_llm().config._attn_implementation}')
+     rank0_print(END_LINE)
+
+     # construct data module
+     datasets = []
+     dataset_info_path = os.path.join(os.path.dirname(os.path.abspath(__file__)),
+                                      f'dataset/{training_args.dataset_info}.json')
+     with open(dataset_info_path, 'r', encoding='utf-8') as f:
+         dataset_info = json.load(f)
+     for name in training_args.dataset_names.split('|'):
+         info = dataset_info[name]
+         data_format = info['data_format']
+         if data_format == 'caption':
+             dataset = CaptionDataset(name, info, model, training_args)
+         elif data_format == 'conversation':
+             dataset = ConversationDataset(name, info, model, training_args)
+         else:
+             raise ValueError(f'Invalid data format `{data_format}` for dataset `{name}`')
+         datasets.append(dataset)
+     data_module = dict(
+         train_dataset=ConcatDataset(datasets),
+         data_collator=DataCollatorForMultimodalDataset(text_tokenizer)
+     )
+
+     # train
+     train_callbacks = [MonitorCallback]
+     if model_args.visual_tokenize_function == 'gumbel_argmax':
+         train_callbacks.append(TuneTauCallback)
+     trainer = Trainer(
+         model=model,
+         args=training_args,
+         callbacks=train_callbacks,
+         **data_module
+     )
+     rank0_print(BEGIN_LINE)
+     rank0_print('Dataset sample tensor:')
+     rank0_print(data_module['train_dataset'][0])
+     rank0_print(END_LINE)
+     rank0_print(BEGIN_LINE)
+     rank0_print('Dataset sample input_ids decoding:')
+     rank0_print(text_tokenizer.decode([x for x in data_module['train_dataset'][0]['input_ids'] if x >= 0]))
+     rank0_print(END_LINE)
+     rank0_print(BEGIN_LINE)
+     rank0_print('Dataset sample labels decoding:')
+     rank0_print(text_tokenizer.decode([x for x in data_module['train_dataset'][0]['labels'] if x >= 0]))
+     rank0_print(END_LINE)
+     rank0_print(BEGIN_LINE)
+     rank0_print(f'#param of model: {smart_unit(model.num_parameters())}')
+     rank0_print(f'#param of llm: {smart_unit(model.get_llm().num_parameters())}')
+     rank0_print(f'#param of visual_tokenizer: {smart_unit(model.get_visual_tokenizer().num_parameters())}')
+     rank0_print(f'#param of vte: {smart_unit(model.get_vte().weight.numel())}')
+     rank0_print(END_LINE)
+     if list(pathlib.Path(training_args.output_dir).glob("checkpoint-*")):
+         trainer.train(resume_from_checkpoint=True)
+     else:
+         trainer.train()
+     trainer.save_state()
+
+     # save model
+     model.get_llm().config.use_cache = True
+     model.config.use_cache = True
+     trainer.save_model()
+
+
+ if __name__ == '__main__':
+     train()
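
The dataset registry read above (`dataset/<dataset_info>.json`) maps each name in `--dataset_names` to an entry whose `data_format` selects `CaptionDataset` or `ConversationDataset`. A hedged sketch of what such an entry might look like; only the `data_format` key is actually required by this script, and the remaining field names are assumptions, since the dataset classes are not part of this diff:

```python
# Hypothetical structure of entries in dataset/dataset_info_v1_6.json (expressed as a Python dict).
dataset_info = {
    "coco_caption_example": {
        "data_format": "caption",              # routed to CaptionDataset
        "meta_file": "/data/coco/meta.json",   # assumed field
        "image_dir": "/data/coco/images",      # assumed field
    },
    "llava_conversation_example": {
        "data_format": "conversation",         # routed to ConversationDataset
        "meta_file": "/data/llava/meta.json",  # assumed field
        "image_dir": "/data/llava/images",     # assumed field
    },
}
```

Similarly, `--train_modules` takes a `|`-separated list drawn from the branches above, e.g. `visual_tokenizer.head|vte` unfreezes only the tokenizer head and the visual token embedding.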
Ovis/ovis/util/constants.py ADDED
@@ -0,0 +1,11 @@
+ # Model Constants
+ IGNORE_ID = -100
+ IMAGE_TOKEN_ID = -200
+ IMAGE_TOKEN = "<image>"
+
+ IMAGE_ATOM_ID = -300
+ IMAGE_INDICATOR_IDS = [-301, -302, -303, -304, -305]
+
+ # Log & Print
+ BEGIN_LINE = '========================************========================'
+ END_LINE = '------------------------------------------------------------'
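
All of these placeholder ids are negative, so they can never collide with real tokenizer vocabulary ids; `train.py` above relies on this when it decodes samples by filtering `x >= 0`. A one-line illustration (the positive ids are arbitrary example values):

```python
input_ids = [128000, 9906, -200, 11, 1917]        # token ids with an <image> placeholder (-200) mixed in
text_only = [x for x in input_ids if x >= 0]      # drops every placeholder, leaving a decodable sequence
```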
Ovis/ovis/util/utils.py ADDED
@@ -0,0 +1,26 @@
1
+ import os
2
+ from importlib import import_module
3
+
4
+
5
+ def rank0_print(*args):
6
+ if int(os.getenv("LOCAL_PROCESS_RANK", os.getenv("LOCAL_RANK", 0))) == 0:
7
+ print(*args)
8
+
9
+
10
+ def smart_unit(num):
11
+ if num / 1.0e9 >= 1:
12
+ return f'{num / 1.0e9:.2f}B'
13
+ else:
14
+ return f'{num / 1.0e6:.2f}M'
15
+
16
+
17
+ def import_class_from_string(full_class_string):
18
+ # Split the path to get separate module and class names
19
+ module_path, _, class_name = full_class_string.rpartition('.')
20
+
21
+ # Import the module using the module path
22
+ module = import_module(module_path)
23
+
24
+ # Get the class from the imported module
25
+ cls = getattr(module, class_name)
26
+ return cls
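As a quick illustration of these helpers (a sketch, not part of the repository), `smart_unit` formats parameter counts the way the training log above prints them, and `import_class_from_string` resolves a dotted path to a class object; this assumes the Ovis package is on the Python path:

```python
from ovis.util.utils import smart_unit, import_class_from_string

print(smart_unit(7_241_000_000))   # "7.24B" -- billions when >= 1e9
print(smart_unit(86_000_000))      # "86.00M" -- otherwise millions

# Resolve a fully qualified class name at runtime; any importable dotted path works.
OrderedDict = import_class_from_string("collections.OrderedDict")
assert OrderedDict().__class__.__name__ == "OrderedDict"
```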
llm2vec/docs/.gitignore ADDED
@@ -0,0 +1,9 @@
+ # Jekyll
+ _site
+ .sass-cache
+ .jekyll-metadata
+ Gemfile.lock
+ vendor/*
+ .bundle
+
+ *.vscode
llm2vec/docs/Gemfile ADDED
@@ -0,0 +1,18 @@
+ source "https://rubygems.org"
+
+ gem "webrick"
+ gem "github-pages", group: :jekyll_plugins
+
+ gem "tzinfo-data"
+ gem "wdm", "~> 0.1.0" if Gem.win_platform?
+
+ # If you have any plugins, put them here!
+ group :jekyll_plugins do
+   gem "jekyll-paginate"
+   gem "jekyll-sitemap"
+   gem "jekyll-gist"
+   gem "jekyll-feed"
+   gem "jemoji"
+   gem "jekyll-include-cache"
+   gem "jekyll-algolia"
+ end
llm2vec/docs/README.md ADDED
@@ -0,0 +1,104 @@
1
+ # Project Page Template
2
+
3
+ This is a project page template for McGill-NLP projects.
4
+
5
+ ![Demo of project page](images/demo.jpg)
6
+
7
+ ## Getting started
8
+
9
+ You can follow one of the following ways to get started.
10
+
11
+ ### Copy from template
12
+
13
+ If you have not yet created a repo for your project, you can copy the template from the following link. Simply click on the "Use this template" button, or [click here](https://github.com/McGill-NLP/project-page-template/generate).
14
+
15
+ ### Cloning with git
16
+
17
+ If you have already created a repo for your project, you can clone the template and copy the files to your project. Let's assume you are in the root of your project directory, e.g. `my-project`.
18
+
19
+ ```bash
20
+ cd .. # Go to parent directory
21
+ git clone https://github.com/McGill-NLP/project-page-template
22
+ cp -r project-page-template/docs my-project/
23
+ cp project-page-template/README.md my-project/docs/
24
+ cd my-project/
25
+ ```
26
+
27
+ ## Activate GitHub Pages
28
+
29
+ Once the template is copied to your project, you need to activate GitHub Pages for your project.
30
+
31
+ 1. Click on the "Settings" button in the top right corner of the page.
32
+ 2. Click on the "Pages" tab.
33
+ 3. In "Source", select branch to be "main" and folder to "/docs".
34
+ 4. Click on "Save"
35
+ 5. Go to the "Actions" tab (on the right of "Pull requests" tab) and wait for the action to finish.
36
+ 6. Visit your project page at mcgill-nlp.github.io/<your-project-name>
37
+
38
+
39
+ ### Why `docs/`?
40
+
41
+ You might be wondering why all the files for the webpage are in `/docs`. The page is not really "docs" per se; it's because `github-pages` only allows us to use either the root folder or `/docs`, so we are forced to use the latter in order to clearly separate this page from the rest of the project. Maybe in the future GitHub will allow other names like `/page`, but until then there's nothing we can do...
42
+
43
+ > Well technically, you *can* push everything to a separate branch, but the original author of this repo is in the school of thought that branches are meant for alternate or historical versions of the `main` branch, or serve as a way to create pull requests. The original author will also vehemently defend this philosophy if one tries to argue otherwise :)
44
+
45
+ But what if you actually want to write some docs for your library? You can just edit `docs/_pages/docs.md`, which is the real page for documentations.
46
+
47
+
48
+ ## Navigation bar
49
+
50
+ You can already find links to different pages in the navigation bar. To add, remove, or modify links, you can edit the [`docs/_data/navigation.yml`](docs/_data/navigation.yml) file. The `title` corresponds to the text that appears on the navbar, and the `url` corresponds to the relative URL of the page. It is not recommended to include an external URL, as that should be on the `/home` page.
51
+
52
+ ## Modifying a page
53
+
54
+ The files are located in [`docs/_pages/`](docs/_pages/). For example, if you want to modify the `/home` page, you would edit `docs/_pages/home.md`.
55
+
56
+ All of the pages are markdown files with something called a [front matter](https://jekyllrb.com/docs/front-matter/) at the top, which uses YAML syntax. Generally, all you need to worry about is the title and the permalink (the latter is the relative URL of the page). In the case of `/home`, you also need to specify external links (`header.actions`) and author names (`excerpt`). However, a template is already provided for you; you only need to modify the content.
57
+
58
+ ## Adding and removing pages
59
+
60
+ To add a page:
61
+ 1. Create a new file in `docs/_pages/`, with the desired `permalink` to be your relative URL. So for example, `/contact/` links to `mcgill-nlp.github.io/my-project/contact`.
62
+ 2. In `docs/_data/navigation.yml`, add a new entry with the `title` and `url` from the previous step.
63
+
64
+ To remove a page:
65
+ 1. Delete the file in `docs/_pages/`.
66
+ 2. In `docs/_data/navigation.yml`, remove the entry with the same `url` as the deleted file.
67
+
68
+
69
+ ## Documentations and API for your project
70
+
71
+ Note that there's a tab that says `docs`, and you can see that it links to other pages. This is a standalone doc page inside your webpage. Note also that, due to the GitHub Pages caveat mentioned above, we were forced to put the webpage in `/docs`, while the actual docs live in `/docs/_docs`. Now that that's cleared up, you can head to [`/docs/_docs/README.md`](/docs/_docs/README.md) to read the instructions.
72
+
73
+ > Do you feel writing documentation is too complicated or time-consuming, and you'd like something more straightforward? Check out the [template for using MkDocs](https://github.com/McGill-NLP/mkdocs-template) instead. However, the simplicity comes at the cost of more repositories and different frameworks to maintain.
74
+
75
+ ## Advanced
76
+
77
+ For any advanced modification, it is recommended to look in the advanced section of the readme of the [group website](https://github.com/McGill-NLP/mcgill-nlp.github.io). Below are a few extra tips, included for convenience.
78
+
79
+ ### Setup
80
+
81
+ Please refer to setup instructions in the readme of the [group website](https://github.com/McGill-NLP/mcgill-nlp.github.io).
82
+
83
+ ### Running locally
84
+
85
+ ```bash
86
+ cd docs/
87
+ bundle exec jekyll serve
88
+ ```
89
+
90
+ ### Removing dark mode
91
+
92
+ To remove dark mode, inside [`docs/_config.yml`](docs/_config.yml) file, remove `dark_theme_css`. The dark mode should automatically turn off.
93
+
94
+ ### Updating footer
95
+
96
+ Inside [`docs/_config.yml`](docs/_config.yml) file, you can modify the footer.
97
+
98
+ ### Modify `excerpt` in a splash page (`/home`)
99
+
100
+ If you want to modify the excerpt in the `/home` page, you can do so in [`docs/_sass_/splash.scss`](docs/_sass_/splash.scss). Note that `splash.scss` was added specifically for this template, not for the group website.
101
+
102
+ ### Modify or remove icons in splash page buttons
103
+
104
+ This is handled in [`docs/_includes/page__hero.html`](docs/_includes/page__hero.html). That file was added specifically for this template, not for the group website. You can modify that file to add, modify or remove icons.
llm2vec/docs/_config.yml ADDED
@@ -0,0 +1,110 @@
1
+ # Welcome to Jekyll!
2
+ #
3
+ # This config file is meant for settings that affect your whole blog, values
4
+ # which you are expected to set up once and rarely edit after that. If you find
5
+ # yourself editing this file very often, consider using Jekyll's data files
6
+ # feature for the data you need to update frequently.
7
+ #
8
+ # For technical reasons, this file is *NOT* reloaded automatically when you use
9
+ # 'bundle exec jekyll serve'. If you change this file, please restart the server process.
10
+
11
+ # Site settings
12
+ # These are used to personalize your new site. If you look in the HTML files,
13
+ # you will see them accessed via {{ site.title }}, {{ site.email }}, and so on.
14
+ # You can create any custom variable you would like, and they will be accessible
15
+ # in the templates via {{ site.myvariable }}.
16
+ title: McGill NLP
17
+ email:
18
+ description: >- # this means to ignore newlines until "baseurl:"
19
+ McGill NLP is a research group within McGill University and Mila focusing on various topics of natural language processing.
20
+ twitter_username: McGill_NLP
21
+ github_username: McGill-NLP
22
+ logo: "/assets/images/logo/logo.png"
23
+ dark_theme_css: "/assets/css/main-dark.css"
24
+ future: true
25
+
26
+ # Build settings
27
+ markdown: kramdown
28
+ remote_theme: mmistakes/minimal-mistakes@4.24.0
29
+ # Outputting
30
+ permalink: /:categories/:title/
31
+ timezone: America/Montreal
32
+
33
+ include:
34
+ - _pages
35
+ - _docs
36
+
37
+ # Exclude from processing.
38
+ # The following items will not be processed, by default. Create a custom list
39
+ # to override the default setting.
40
+ # exclude:
41
+ # - Gemfile
42
+ # - Gemfile.lock
43
+ # - node_modules
44
+ # - vendor/bundle/
45
+ # - vendor/cache/
46
+ # - vendor/gems/
47
+ # - vendor/ruby/
48
+
49
+ # Plugins (previously gems:)
50
+ plugins:
51
+ - jekyll-sitemap
52
+ - jekyll-gist
53
+ - jemoji
54
+ - jekyll-include-cache
55
+
56
+ author:
57
+ name : "McGill NLP Member(s)"
58
+ avatar : "/assets/images/bio/default.jpg"
59
+ bio : "Current or former lab member(s) worked on this."
60
+ links:
61
+ - label: "Website"
62
+ icon: "fas fa-fw fa-link"
63
+ url: "https://mcgill-nlp.github.io"
64
+ - label: "GitHub"
65
+ icon: "fab fa-fw fa-github"
66
+ url: "https://github.com/McGill-NLP"
67
+ - label: "Twitter"
68
+ icon: "fab fa-fw fa-twitter-square"
69
+ url: "https://twitter.com/McGill_NLP"
70
+
71
+ analytics:
72
+ provider: "google-gtag"
73
+ google:
74
+ tracking_id: "G-MEDG9XN4VP"
75
+ anonymize_ip: false # default
76
+
77
+ atom_feed:
78
+ hide: true
79
+
80
+ footer:
81
+ links:
82
+ - label: "GitHub"
83
+ icon: "fab fa-fw fa-github"
84
+ url: "https://github.com/McGill-NLP"
85
+ - label: "Twitter"
86
+ icon: "fab fa-fw fa-twitter-square"
87
+ url: "https://twitter.com/McGill_NLP"
88
+
89
+ defaults:
90
+ # /docs/_pages
91
+ - scope:
92
+ path: "_pages"
93
+ type: pages
94
+ values:
95
+ layout: single
96
+ classes:
97
+ - no-sidebar
98
+ - wide
99
+ author_profile: false
100
+ # /docs/_docs
101
+ - scope:
102
+ path: "_docs"
103
+ type: pages
104
+ values:
105
+ layout: single
106
+ sidebar:
107
+ title: "Doc Pages"
108
+ nav: sidebar-docs # See /docs/_data/navigation.yml
109
+ toc: true
110
+ toc_label: "Table of Contents"
llm2vec/docs/_data/navigation.yml ADDED
@@ -0,0 +1,17 @@
+ main:
+   - title: "Home"
+     url: /
+   - title: "Leaderboard"
+     url: /leaderboard/
+   - title: "Docs"
+     url: /docs/
+   - title: "Contact"
+     url: /contact/
+
+ sidebar-docs: # See "include" in /_config.yml and /docs/_docs
+   - title: "Home"
+     url: /docs/
+   - title: "API"
+     url: /docs/api
+   - title: "Training"
+     url: /docs/training
llm2vec/docs/_includes/head/custom.html ADDED
@@ -0,0 +1,48 @@
1
+ <!-- Add favicon -->
2
+ <link rel="icon" type="image/png" href="{{ site.baseurl }}/assets/images/logo/favicon.png">
3
+
4
+ {% if site.dark_theme_css %}
5
+ <!-- Dark Mode -->
6
+ <link rel="stylesheet" href="{{ '/assets/css/main.css' | relative_url }}" id="theme-css">
7
+ <link rel="stylesheet alternate" href="{{ site.dark_theme_css | relative_url }}" id="theme-css-dark">
8
+
9
+ <script type="text/javascript">
10
+ const updateNodesRel = theme => {
11
+ const node_light = document.getElementById('theme-css');
12
+ const node_dark = document.getElementById('theme-css-dark');
13
+
14
+ if (theme === "dark") {
15
+ node_light.setAttribute('rel', 'stylesheet alternate');
16
+ node_dark.setAttribute('rel', 'stylesheet');
17
+ }
18
+ else if (theme === "light") {
19
+ node_light.setAttribute('rel', 'stylesheet');
20
+ node_dark.setAttribute('rel', 'stylesheet alternate');
21
+ }
22
+ }
23
+
24
+ const changeTheme = () => {
25
+ let theme = sessionStorage.getItem('theme');
26
+
27
+ // Change the theme to the other option
28
+ if (theme === "light") {
29
+ theme = "dark";
30
+ } else {
31
+ theme = "light";
32
+ }
33
+
34
+ // Update the stored session and the nodes' rel attribute
35
+ sessionStorage.setItem('theme', theme);
36
+ updateNodesRel(theme);
37
+
38
+ return false;
39
+ }
40
+
41
+ if (sessionStorage.getItem('theme') === null) {
42
+ sessionStorage.setItem('theme', "light");
43
+ }
44
+
45
+ const theme = sessionStorage.getItem('theme');
46
+ updateNodesRel(theme);
47
+ </script>
48
+ {% endif %}
llm2vec/docs/_sass/custom/header-footer.scss ADDED
@@ -0,0 +1,19 @@
+ a.site-title {
+   @media (min-width: 601px) {
+     font-size: xx-large;
+   }
+   @media (max-width: 600px) {
+     font-size: large;
+   }
+   color: $primary-color;
+
+   &:hover {
+     color: mix($background-color, $primary-color, 25%);
+   }
+ }
+
+ .theme-toggle {
+   @media (min-width: 601px) {
+     margin: 0px;
+   }
+ }
llm2vec/docs/_sass/custom/no-sidebar.scss ADDED
@@ -0,0 +1,9 @@
+ .no-sidebar article.page {
+   float: left;
+   width: 100%;
+ }
+
+ .no-sidebar .archive {
+   float: left;
+   width: 100%;
+ }
llm2vec/docs/_sass/custom/splash.scss ADDED
@@ -0,0 +1,5 @@
+ // This contains the styles for the "excerpt" on a splash page
+ div.wrapper > p.page__lead {
+   font-size: x-large;
+   max-width: 100%;
+ }
llm2vec/docs/_sass/skins/dark.scss ADDED
@@ -0,0 +1,30 @@
+ /* ==========================================================================
+    Dark skin
+    Imported in /assets/css/main-light.scss
+    ========================================================================== */
+
+ /* Colors */
+ $background-color: #000000 !default;
+ $text-color: #eaeaea !default;
+ $primary-color: #ED1B2F !default;
+ $border-color: mix(#fff, $background-color, 20%) !default;
+ $code-background-color: mix(#000, $background-color, 15%) !default;
+ $code-background-color-dark: mix(#000, $background-color, 20%) !default;
+ $form-background-color: mix(#000, $background-color, 15%) !default;
+ $footer-background-color: mix($text-color, $background-color, 5%) !default;
+ $link-color: mix($primary-color, $text-color, 100%) !default;
+ $link-color-hover: mix($background-color, $link-color, 15%) !default;
+ $link-color-visited: mix(#000, $link-color, 0%) !default;
+ $masthead-link-color: $text-color !default;
+ $masthead-link-color-hover: mix(#000, $text-color, 20%) !default;
+
+ .author__urls.social-icons i,
+ .author__urls.social-icons .svg-inline--fa,
+ .page__footer-follow .social-icons i,
+ .page__footer-follow .social-icons .svg-inline--fa {
+   color: inherit;
+ }
+
+ .ais-search-box .ais-search-box--input {
+   background-color: $form-background-color;
+ }
llm2vec/docs/_sass/skins/light.scss ADDED
@@ -0,0 +1,12 @@
+ /*
+   Imported in /assets/css/main-light.scss
+ */
+ $background-color: #fff !default;
+ $text-color: #000 !default;
+ $primary-color: #ED1B2F !default;
+ // $footer-background-color: mix($primary-color, $background-color, 100%) !default;
+ $link-color: #ED1B2F !default;
+ $link-color-hover: mix(#fff, $link-color, 25%) !default;
+ $link-color-visited: mix(#000, $link-color, 10%) !default;
+ $masthead-link-color: $text-color !default;
+ $masthead-link-color-hover: mix($background-color, $text-color, 25%) !default;
llm2vec/docs/assets/images/logo/favicon.png ADDED
llm2vec/docs/assets/images/logo/logo.png ADDED
llm2vec/docs/assets/images/logo/logo.svg ADDED
llm2vec/examples/classification.py ADDED
@@ -0,0 +1,62 @@
1
+ from sklearn.metrics import accuracy_score, f1_score
2
+ from sklearn.linear_model import LogisticRegression
3
+ import datasets
4
+ import numpy as np
5
+
6
+ import torch
7
+ from llm2vec import LLM2Vec
8
+
9
+ dataset = "mteb/amazon_counterfactual"
10
+ instruction = "Classify a given Amazon customer review text as either counterfactual or notcounterfactual: "
11
+
12
+ dataset = datasets.load_dataset(dataset, "en")
13
+
14
+ sentences_train, y_train = dataset["train"]["text"], dataset["train"]["label"]
15
+ sentences_test, y_test = dataset["test"]["text"], dataset["test"]["label"]
16
+ max_iter = 100
17
+ batch_size = 8
18
+
19
+ scores = {}
20
+ clf = LogisticRegression(
21
+ random_state=42,
22
+ n_jobs=1,
23
+ max_iter=max_iter,
24
+ verbose=0,
25
+ )
26
+
27
+ print("Loading model...")
28
+ model = LLM2Vec.from_pretrained(
29
+ "McGill-NLP/LLM2Vec-Mistral-7B-Instruct-v2-mntp",
30
+ peft_model_name_or_path="McGill-NLP/LLM2Vec-Mistral-7B-Instruct-v2-mntp-supervised",
31
+ device_map="cuda" if torch.cuda.is_available() else "cpu",
32
+ torch_dtype=torch.bfloat16,
33
+ )
34
+
35
+
36
+ def append_instruction(instruction, sentences):
37
+ new_sentences = []
38
+ for s in sentences:
39
+ new_sentences.append([instruction, s, 0])
40
+ return new_sentences
41
+
42
+
43
+ print(f"Encoding {len(sentences_train)} training sentences...")
44
+ sentences_train = append_instruction(instruction, sentences_train)
45
+ X_train = np.asarray(model.encode(sentences_train, batch_size=batch_size))
46
+
47
+ print(f"Encoding {len(sentences_test)} test sentences...")
48
+ sentences_test = append_instruction(instruction, sentences_test)
49
+ X_test = np.asarray(model.encode(sentences_test, batch_size=batch_size))
50
+
51
+ print("Fitting logistic regression classifier...")
52
+ clf.fit(X_train, y_train)
53
+ print("Evaluating...")
54
+ y_pred = clf.predict(X_test)
55
+
56
+ accuracy = accuracy_score(y_test, y_pred)
57
+ scores["accuracy"] = accuracy
58
+ f1 = f1_score(y_test, y_pred, average="macro")
59
+ scores["f1"] = f1
60
+
61
+ print(scores)
62
+ # {'accuracy': 0.891044776119403, 'f1': 0.8283106625713033}
llm2vec/examples/clustering.py ADDED
@@ -0,0 +1,58 @@
1
+ import sklearn
2
+ import sklearn.cluster
3
+
4
+ import datasets
5
+ import tqdm
6
+ import numpy as np
7
+
8
+ import torch
9
+ from llm2vec import LLM2Vec
10
+
11
+ dataset = "mteb/twentynewsgroups-clustering"
12
+ instruction = "Identify the topic or theme of the given news articles: "
13
+
14
+ dataset = datasets.load_dataset(dataset)
15
+ batch_size = 32
16
+
17
+ print("Loading model...")
18
+ model = LLM2Vec.from_pretrained(
19
+ "McGill-NLP/LLM2Vec-Mistral-7B-Instruct-v2-mntp",
20
+ peft_model_name_or_path="McGill-NLP/LLM2Vec-Mistral-7B-Instruct-v2-mntp-supervised",
21
+ device_map="cuda" if torch.cuda.is_available() else "cpu",
22
+ torch_dtype=torch.bfloat16,
23
+ )
24
+
25
+
26
+ def append_instruction(instruction, sentences):
27
+ new_sentences = []
28
+ for s in sentences:
29
+ new_sentences.append([instruction, s, 0])
30
+ return new_sentences
31
+
32
+
33
+ v_measures = []
34
+ for cluster_set in tqdm.tqdm(dataset["test"], desc="Clustering"):
35
+ sentences = cluster_set["sentences"]
36
+ labels = cluster_set["labels"]
37
+ clustering_batch_size = 500
38
+
39
+ print(f"Encoding {len(sentences)} sentences...")
40
+ new_sentences = append_instruction(instruction, sentences)
41
+ corpus_embeddings = np.asarray(model.encode(new_sentences, batch_size=batch_size))
42
+
43
+ print("Fitting Mini-Batch K-Means model...")
44
+ clustering_model = sklearn.cluster.MiniBatchKMeans(
45
+ n_clusters=len(set(labels)), batch_size=clustering_batch_size
46
+ )
47
+ clustering_model.fit(corpus_embeddings)
48
+ cluster_assignment = clustering_model.labels_
49
+
50
+ print("Evaluating...")
51
+ v_measure = sklearn.metrics.cluster.v_measure_score(labels, cluster_assignment)
52
+ v_measures.append(v_measure)
53
+
54
+ v_mean = np.mean(v_measures)
55
+ v_std = np.std(v_measures)
56
+
57
+ print(v_mean)
58
+ # 0.5137461051538426
llm2vec/examples/retrieval.py ADDED
@@ -0,0 +1,177 @@
1
+ import datasets
2
+ import torch
3
+ from llm2vec import LLM2Vec
4
+ from beir import util
5
+ from beir.datasets.data_loader import GenericDataLoader as BeirDataLoader
6
+ import os
7
+ from typing import Dict, List
8
+
9
+ from beir.retrieval.evaluation import EvaluateRetrieval
10
+
11
+ dataset = "arguana"
12
+ instruction = "Given a claim, find documents that refute the claim: "
13
+
14
+ print("Loading dataset...")
15
+ url = (
16
+ f"https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/{dataset}.zip"
17
+ )
18
+ download_path = os.path.join(datasets.config.HF_DATASETS_CACHE, "BeIR")
19
+ data_path = util.download_and_unzip(url, download_path)
20
+ corpus, queries, relevant_docs = BeirDataLoader(data_folder=data_path).load(
21
+ split="test"
22
+ )
23
+ batch_size = 8
24
+
25
+ print("Loading model...")
26
+ model = LLM2Vec.from_pretrained(
27
+ "McGill-NLP/LLM2Vec-Mistral-7B-Instruct-v2-mntp",
28
+ peft_model_name_or_path="McGill-NLP/LLM2Vec-Mistral-7B-Instruct-v2-mntp-supervised",
29
+ device_map="cuda" if torch.cuda.is_available() else "cpu",
30
+ torch_dtype=torch.bfloat16,
31
+ )
32
+
33
+
34
+ def append_instruction(instruction, sentences):
35
+ new_sentences = []
36
+ for s in sentences:
37
+ new_sentences.append([instruction, s, 0])
38
+ return new_sentences
39
+
40
+
41
+ def cos_sim(a: torch.Tensor, b: torch.Tensor):
42
+ if not isinstance(a, torch.Tensor):
43
+ a = torch.tensor(a)
44
+
45
+ if not isinstance(b, torch.Tensor):
46
+ b = torch.tensor(b)
47
+
48
+ if len(a.shape) == 1:
49
+ a = a.unsqueeze(0)
50
+
51
+ if len(b.shape) == 1:
52
+ b = b.unsqueeze(0)
53
+
54
+ a_norm = torch.nn.functional.normalize(a, p=2, dim=1)
55
+ b_norm = torch.nn.functional.normalize(b, p=2, dim=1)
56
+ return torch.mm(a_norm, b_norm.transpose(0, 1))
57
+
58
+
59
+ def encode_queries(queries: List[str], batch_size: int, **kwargs):
60
+ new_sentences = append_instruction(instruction, queries)
61
+
62
+ kwargs["show_progress_bar"] = False
63
+ return model.encode(new_sentences, batch_size=batch_size, **kwargs)
64
+
65
+
66
+ def encode_corpus(corpus: List[Dict[str, str]], batch_size: int, **kwargs):
67
+ if type(corpus) is dict:
68
+ sentences = [
69
+ (
70
+ (corpus["title"][i] + " " + corpus["text"][i]).strip()
71
+ if "title" in corpus
72
+ else corpus["text"][i].strip()
73
+ )
74
+ for i in range(len(corpus["text"]))
75
+ ]
76
+ else:
77
+ sentences = [
78
+ (
79
+ (doc["title"] + " " + doc["text"]).strip()
80
+ if "title" in doc
81
+ else doc["text"].strip()
82
+ )
83
+ for doc in corpus
84
+ ]
85
+ new_sentences = append_instruction("", sentences)
86
+ return model.encode(new_sentences, batch_size=batch_size, **kwargs)
87
+
88
+
89
+ print("Encoding Queries...")
90
+ query_ids = list(queries.keys())
91
+ results = {qid: {} for qid in query_ids}
92
+ queries = [queries[qid] for qid in queries]
93
+ query_embeddings = encode_queries(
94
+ queries, batch_size=batch_size, show_progress_bar=True, convert_to_tensor=True
95
+ )
96
+
97
+ print("Sorting Corpus by document length (Longest first)...")
98
+ corpus_ids = sorted(
99
+ corpus,
100
+ key=lambda k: len(corpus[k].get("title", "") + corpus[k].get("text", "")),
101
+ reverse=True,
102
+ )
103
+ corpus = [corpus[cid] for cid in corpus_ids]
104
+
105
+ print("Encoding Corpus ... Warning: This might take a while!")
106
+ corpus_embeddings = encode_corpus(
107
+ corpus, batch_size=batch_size, show_progress_bar=True, convert_to_tensor=True
108
+ )
109
+
110
+ print("Scoring Function: {} ({})".format("Cosine Similarity", "cos_sim"))
111
+ cos_scores = cos_sim(query_embeddings, corpus_embeddings)
112
+ cos_scores[torch.isnan(cos_scores)] = -1
113
+
114
+ # Get top-k values
115
+ top_k = 1000
116
+ cos_scores_top_k_values, cos_scores_top_k_idx = torch.topk(
117
+ cos_scores, min(top_k + 1, len(cos_scores[0])), dim=1, largest=True, sorted=False
118
+ )
119
+ cos_scores_top_k_values = cos_scores_top_k_values.cpu().tolist()
120
+ cos_scores_top_k_idx = cos_scores_top_k_idx.cpu().tolist()
121
+
122
+ for query_itr in range(len(query_embeddings)):
123
+ query_id = query_ids[query_itr]
124
+ for sub_corpus_id, score in zip(
125
+ cos_scores_top_k_idx[query_itr], cos_scores_top_k_values[query_itr]
126
+ ):
127
+ corpus_id = corpus_ids[sub_corpus_id]
128
+ if corpus_id != query_id:
129
+ results[query_id][corpus_id] = score
130
+
131
+ retriever = EvaluateRetrieval(model, score_function="cos_sim")
132
+ ndcg, _map, recall, precision = retriever.evaluate(
133
+ relevant_docs, results, retriever.k_values
134
+ )
135
+ mrr = retriever.evaluate_custom(relevant_docs, results, retriever.k_values, "mrr")
136
+
137
+ scores = {
138
+ **{f"ndcg_at_{k.split('@')[1]}": v for (k, v) in ndcg.items()},
139
+ **{f"map_at_{k.split('@')[1]}": v for (k, v) in _map.items()},
140
+ **{f"recall_at_{k.split('@')[1]}": v for (k, v) in recall.items()},
141
+ **{f"precision_at_{k.split('@')[1]}": v for (k, v) in precision.items()},
142
+ **{f"mrr_at_{k.split('@')[1]}": v for (k, v) in mrr.items()},
143
+ }
144
+ print(scores)
145
+ """
146
+ {
147
+ 'ndcg_at_1': 0.32788,
148
+ 'ndcg_at_3': 0.47534,
149
+ 'ndcg_at_5': 0.52296,
150
+ 'ndcg_at_10': 0.57505,
151
+ 'ndcg_at_100': 0.6076,
152
+ 'ndcg_at_1000': 0.60801,
153
+ 'map_at_1': 0.32788,
154
+ 'map_at_3': 0.43883,
155
+ 'map_at_5': 0.46518,
156
+ 'map_at_10': 0.48675,
157
+ 'map_at_100': 0.49506,
158
+ 'map_at_1000': 0.49509,
159
+ 'recall_at_1': 0.32788,
160
+ 'recall_at_3': 0.58108,
161
+ 'recall_at_5': 0.69701,
162
+ 'recall_at_10': 0.85775,
163
+ 'recall_at_100': 0.9936,
164
+ 'recall_at_1000': 0.99644,
165
+ 'precision_at_1': 0.32788,
166
+ 'precision_at_3': 0.19369,
167
+ 'precision_at_5': 0.1394,
168
+ 'precision_at_10': 0.08578,
169
+ 'precision_at_100': 0.00994,
170
+ 'precision_at_1000': 0.001,
171
+ 'mrr_at_1': 0.33357,
172
+ 'mrr_at_3': 0.44085,
173
+ 'mrr_at_5': 0.46745,
174
+ 'mrr_at_10': 0.4888,
175
+ 'mrr_at_100': 0.49718,
176
+ 'mrr_at_1000': 0.49721}
177
+ """
llm2vec/examples/sts.py ADDED
@@ -0,0 +1,57 @@
1
+ import datasets
2
+ import numpy as np
3
+ from sklearn.metrics.pairwise import paired_cosine_distances
4
+ from scipy.stats import spearmanr
5
+
6
+ import torch
7
+ from llm2vec import LLM2Vec
8
+
9
+
10
+ dataset = "mteb/sts17-crosslingual-sts"
11
+ instruction = "Retrieve semantically similar text: "
12
+
13
+ dataset = datasets.load_dataset(dataset, "en-en")
14
+
15
+ min_score, max_score = 0, 5
16
+ normalize = lambda x: (x - min_score) / (max_score - min_score)
17
+ normalized_scores = list(map(normalize, dataset["test"]["score"]))
18
+ batch_size = 8
19
+
20
+ sentences1, sentences2 = dataset["test"]["sentence1"], dataset["test"]["sentence2"]
21
+
22
+ print("Loading model...")
23
+ model = LLM2Vec.from_pretrained(
24
+ "McGill-NLP/LLM2Vec-Mistral-7B-Instruct-v2-mntp",
25
+ peft_model_name_or_path="McGill-NLP/LLM2Vec-Mistral-7B-Instruct-v2-mntp-supervised",
26
+ device_map="cuda" if torch.cuda.is_available() else "cpu",
27
+ torch_dtype=torch.bfloat16,
28
+ )
29
+
30
+
31
+ def append_instruction(instruction, sentences):
32
+ new_sentences = []
33
+ for s in sentences:
34
+ new_sentences.append([instruction, s, 0])
35
+ return new_sentences
36
+
37
+
38
+ print(f"Encoding {len(sentences1)} sentences1...")
39
+ sentences1 = append_instruction(instruction, sentences1)
40
+ embeddings1 = np.asarray(model.encode(sentences1, batch_size=batch_size))
41
+
42
+ print(f"Encoding {len(sentences2)} sentences2...")
43
+ sentences2 = append_instruction(instruction, sentences2)
44
+ embeddings2 = np.asarray(model.encode(sentences2, batch_size=batch_size))
45
+
46
+ print("Evaluating...")
47
+ cosine_scores = 1 - (paired_cosine_distances(embeddings1, embeddings2))
48
+ cosine_spearman, _ = spearmanr(normalized_scores, cosine_scores)
49
+
50
+ results = {
51
+ "cos_sim": {
52
+ "spearman": cosine_spearman,
53
+ }
54
+ }
55
+
56
+ print(results)
57
+ # {'cos_sim': {'spearman': 0.9021906216635642}}
llm2vec/experiments/mteb_eval.py ADDED
@@ -0,0 +1,31 @@
+ import argparse
+ import mteb
+ import json
+
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser()
+     parser.add_argument(
+         "--model_name",
+         type=str,
+         default="McGill-NLP/LLM2Vec-Meta-Llama-3-8B-Instruct-mntp-supervised",
+     )
+     parser.add_argument("--task_name", type=str, default="STS16")
+     parser.add_argument(
+         "--task_to_instructions_fp",
+         type=str,
+         default="test_configs/mteb/task_to_instructions.json",
+     )
+     parser.add_argument("--output_dir", type=str, default="results")
+
+     args = parser.parse_args()
+
+     model_kwargs = {}
+     if args.task_to_instructions_fp is not None:
+         with open(args.task_to_instructions_fp, "r") as f:
+             task_to_instructions = json.load(f)
+         model_kwargs["task_to_instructions"] = task_to_instructions
+
+     model = mteb.get_model(args.model_name, **model_kwargs)
+     tasks = mteb.get_tasks(tasks=[args.task_name])
+     evaluation = mteb.MTEB(tasks=tasks)
+     results = evaluation.run(model, output_folder=args.output_dir)
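The `--task_to_instructions_fp` flag points at a JSON object mapping MTEB task names to instruction strings, which the script forwards to `mteb.get_model` as `task_to_instructions`. A condensed sketch of the same calls with the mapping inlined; the instruction text here is a hypothetical placeholder, the real mapping lives in `test_configs/mteb/task_to_instructions.json`:

```python
import mteb

# Hypothetical in-memory equivalent of task_to_instructions.json.
task_to_instructions = {
    "STS16": "Retrieve semantically similar text: ",
}

model = mteb.get_model(
    "McGill-NLP/LLM2Vec-Meta-Llama-3-8B-Instruct-mntp-supervised",
    task_to_instructions=task_to_instructions,
)
evaluation = mteb.MTEB(tasks=mteb.get_tasks(tasks=["STS16"]))
results = evaluation.run(model, output_folder="results")
```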
llm2vec/experiments/mteb_eval_custom.py ADDED
@@ -0,0 +1,98 @@
1
+ import argparse
2
+ from typing import Any
3
+ import mteb
4
+ import json
5
+ import torch
6
+
7
+ import numpy as np
8
+ from mteb.models.instructions import task_to_instruction
9
+ from mteb.models.text_formatting_utils import corpus_to_texts
10
+
11
+ from llm2vec import LLM2Vec
12
+
13
+ def llm2vec_instruction(instruction):
14
+ if len(instruction) > 0 and instruction[-1] != ":":
15
+ instruction = instruction.strip(".") + ":"
16
+ return instruction
17
+
18
+
19
+ class LLM2VecWrapper:
20
+ def __init__(self, model=None, task_to_instructions=None):
21
+
22
+ self.task_to_instructions = task_to_instructions
23
+ self.model = model
24
+
25
+ def encode(
26
+ self,
27
+ sentences: list[str],
28
+ *,
29
+ prompt_name: str = None,
30
+ **kwargs: Any, # noqa
31
+ ) -> np.ndarray:
32
+ if prompt_name is not None:
33
+ instruction = (
34
+ self.task_to_instructions[prompt_name]
35
+ if self.task_to_instructions
36
+ and prompt_name in self.task_to_instructions
37
+ else llm2vec_instruction(task_to_instruction(prompt_name))
38
+ )
39
+ else:
40
+ instruction = ""
41
+
42
+ sentences = [[instruction, sentence] for sentence in sentences]
43
+ return self.model.encode(sentences, **kwargs)
44
+
45
+ def encode_corpus(
46
+ self,
47
+ corpus: list[dict[str, str]] | dict[str, list[str]] | list[str],
48
+ prompt_name: str = None,
49
+ **kwargs: Any,
50
+ ) -> np.ndarray:
51
+ sentences = corpus_to_texts(corpus, sep=" ")
52
+ sentences = [["", sentence] for sentence in sentences]
53
+ if "request_qid" in kwargs:
54
+ kwargs.pop("request_qid")
55
+ return self.model.encode(sentences, **kwargs)
56
+
57
+ def encode_queries(self, queries: list[str], **kwargs: Any) -> np.ndarray:
58
+ return self.encode(queries, **kwargs)
59
+
60
+
61
+ if __name__ == "__main__":
62
+ parser = argparse.ArgumentParser()
63
+ parser.add_argument(
64
+ "--base_model_name_or_path",
65
+ type=str,
66
+ default="McGill-NLP/LLM2Vec-Meta-Llama-3-8B-Instruct-mntp",
67
+ )
68
+ parser.add_argument(
69
+ "--peft_model_name_or_path",
70
+ type=str,
71
+ default="McGill-NLP/LLM2Vec-Meta-Llama-3-8B-Instruct-mntp-supervised",
72
+ )
73
+ parser.add_argument("--task_name", type=str, default="STS16")
74
+ parser.add_argument(
75
+ "--task_to_instructions_fp",
76
+ type=str,
77
+ default="test_configs/mteb/task_to_instructions.json",
78
+ )
79
+ parser.add_argument("--output_dir", type=str, default="results")
80
+
81
+ args = parser.parse_args()
82
+
83
+ task_to_instructions = None
84
+ if args.task_to_instructions_fp is not None:
85
+ with open(args.task_to_instructions_fp, "r") as f:
86
+ task_to_instructions = json.load(f)
87
+
88
+ l2v_model = LLM2Vec.from_pretrained(
89
+ args.base_model_name_or_path,
90
+ peft_model_name_or_path=args.peft_model_name_or_path,
91
+ device_map="cuda" if torch.cuda.is_available() else "cpu",
92
+ torch_dtype=torch.bfloat16,
93
+ )
94
+
95
+ model = LLM2VecWrapper(model=l2v_model, task_to_instructions=task_to_instructions)
96
+ tasks = mteb.get_tasks(tasks=[args.task_name])
97
+ evaluation = mteb.MTEB(tasks=tasks)
98
+ results = evaluation.run(model, output_folder=args.output_dir)
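The wrapper above prepends an instruction to every sentence before calling `LLM2Vec.encode`, and `llm2vec_instruction` normalizes instructions so they end with a colon. A small standalone check of that normalization (the example strings are illustrative only):

```python
def llm2vec_instruction(instruction):
    # Same normalization as in mteb_eval_custom.py: a trailing period becomes a colon.
    if len(instruction) > 0 and instruction[-1] != ":":
        instruction = instruction.strip(".") + ":"
    return instruction

print(llm2vec_instruction("Retrieve semantically similar text."))
# -> "Retrieve semantically similar text:"
print(llm2vec_instruction(""))
# -> ""  (empty instructions are left untouched)
```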
llm2vec/experiments/run_mntp.py ADDED
@@ -0,0 +1,997 @@
1
+ #!/usr/bin/env python
2
+ # coding=utf-8
3
+ # Copyright 2020 The HuggingFace Team All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """
17
+ The script is adapted from https://github.com/huggingface/transformers/blob/51bcadc10a569847b93a30dbe3a077037ae63bad/examples/pytorch/language-modeling/run_mlm.py
18
+ """
19
+
20
+ import logging
21
+ import math
22
+ import os
23
+ import sys
24
+ import warnings
25
+ from dataclasses import dataclass, field
26
+ from itertools import chain
27
+ from typing import Optional, Any, Tuple, List
28
+ import numpy as np
29
+
30
+ import datasets
31
+ import evaluate
32
+ from datasets import load_dataset
33
+
34
+ import torch
35
+ import transformers
36
+ from transformers import (
37
+ CONFIG_MAPPING,
38
+ MODEL_FOR_MASKED_LM_MAPPING,
39
+ AutoConfig,
40
+ AutoTokenizer,
41
+ DataCollatorForLanguageModeling,
42
+ HfArgumentParser,
43
+ Trainer,
44
+ TrainingArguments,
45
+ TrainerCallback,
46
+ is_torch_tpu_available,
47
+ set_seed,
48
+ )
49
+ from transformers.trainer_utils import get_last_checkpoint
50
+ from transformers.utils import send_example_telemetry
51
+ from transformers.utils.versions import require_version
52
+
53
+ from peft import LoraConfig, get_peft_model
54
+
55
+ from llm2vec.models import (
56
+ MistralBiForMNTP,
57
+ LlamaBiForMNTP,
58
+ GemmaBiForMNTP,
59
+ Qwen2BiForMNTP,
60
+ )
61
+
62
+ # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
63
+ # check_min_version("4.38.0.dev0")
64
+
65
+ require_version(
66
+ "datasets>=1.8.0",
67
+ "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt",
68
+ )
69
+
70
+ logger = logging.getLogger(__name__)
71
+ MODEL_CONFIG_CLASSES = list(MODEL_FOR_MASKED_LM_MAPPING.keys())
72
+ MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
73
+
74
+
75
+ def get_model_class(config):
76
+ config_class_name = config.__class__.__name__
77
+ if config_class_name == "MistralConfig":
78
+ return MistralBiForMNTP
79
+ elif config_class_name == "LlamaConfig":
80
+ return LlamaBiForMNTP
81
+ elif config_class_name == "GemmaConfig":
82
+ return GemmaBiForMNTP
83
+ elif config_class_name == "Qwen2Config":
84
+ return Qwen2BiForMNTP
85
+ else:
86
+ raise ValueError(f"Model class {config_class_name} not supported.")
87
+
88
+
89
+ def initialize_peft(
90
+ model,
91
+ lora_r: int = 8,
92
+ lora_alpha: int = 16,
93
+ lora_dropout: float = 0.05,
94
+ lora_modules: Optional[List[str]] = None,
95
+ ):
96
+ if lora_modules is None and model.config.__class__.__name__ in [
97
+ "LlamaConfig",
98
+ "MistralConfig",
99
+ "GemmaConfig",
100
+ "Qwen2Config",
101
+ ]:
102
+ lora_modules = [
103
+ "q_proj",
104
+ "v_proj",
105
+ "k_proj",
106
+ "o_proj",
107
+ "gate_proj",
108
+ "up_proj",
109
+ "down_proj",
110
+ ]
111
+ elif lora_modules is None:
112
+ raise ValueError("lora_modules must be specified for this model.")
113
+
114
+ config = LoraConfig(
115
+ r=lora_r,
116
+ lora_alpha=lora_alpha,
117
+ target_modules=lora_modules,
118
+ lora_dropout=lora_dropout,
119
+ bias="none",
120
+ task_type=None,
121
+ )
122
+
123
+ model = get_peft_model(model, config)
124
+ print(f"Model's Lora trainable parameters:")
125
+ model.print_trainable_parameters()
126
+ return model
127
+
128
+
129
+ @dataclass
130
+ class ModelArguments:
131
+ """
132
+ Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch.
133
+ """
134
+
135
+ model_name_or_path: Optional[str] = field(
136
+ default=None,
137
+ metadata={
138
+ "help": (
139
+ "The model checkpoint for weights initialization. Don't set if you want to train a model from scratch."
140
+ )
141
+ },
142
+ )
143
+ model_type: Optional[str] = field(
144
+ default=None,
145
+ metadata={
146
+ "help": "If training from scratch, pass a model type from the list: "
147
+ + ", ".join(MODEL_TYPES)
148
+ },
149
+ )
150
+ config_overrides: Optional[str] = field(
151
+ default=None,
152
+ metadata={
153
+ "help": (
154
+ "Override some existing default config settings when a model is trained from scratch. Example: "
155
+ "n_embd=10,resid_pdrop=0.2,scale_attn_weights=false,summary_type=cls_index"
156
+ )
157
+ },
158
+ )
159
+ config_name: Optional[str] = field(
160
+ default=None,
161
+ metadata={
162
+ "help": "Pretrained config name or path if not the same as model_name"
163
+ },
164
+ )
165
+ tokenizer_name: Optional[str] = field(
166
+ default=None,
167
+ metadata={
168
+ "help": "Pretrained tokenizer name or path if not the same as model_name"
169
+ },
170
+ )
171
+ cache_dir: Optional[str] = field(
172
+ default=None,
173
+ metadata={
174
+ "help": "Where do you want to store the pretrained models downloaded from huggingface.co"
175
+ },
176
+ )
177
+ use_fast_tokenizer: bool = field(
178
+ default=True,
179
+ metadata={
180
+ "help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."
181
+ },
182
+ )
183
+ model_revision: str = field(
184
+ default="main",
185
+ metadata={
186
+ "help": "The specific model version to use (can be a branch name, tag name or commit id)."
187
+ },
188
+ )
189
+ token: str = field(
190
+ default=None,
191
+ metadata={
192
+ "help": (
193
+ "The token to use as HTTP bearer authorization for remote files. If not specified, will use the token "
194
+ "generated when running `huggingface-cli login` (stored in `~/.huggingface`)."
195
+ )
196
+ },
197
+ )
198
+ use_auth_token: bool = field(
199
+ default=None,
200
+ metadata={
201
+ "help": "The `use_auth_token` argument is deprecated and will be removed in v4.34. Please use `token` instead."
202
+ },
203
+ )
204
+ trust_remote_code: bool = field(
205
+ default=False,
206
+ metadata={
207
+ "help": (
208
+ "Whether or not to allow for custom models defined on the Hub in their own modeling files. This option "
209
+ "should only be set to `True` for repositories you trust and in which you have read the code, as it will "
210
+ "execute code present on the Hub on your local machine."
211
+ )
212
+ },
213
+ )
214
+ torch_dtype: Optional[str] = field(
215
+ default=None,
216
+ metadata={
217
+ "help": (
218
+ "Override the default `torch.dtype` and load the model under this dtype. If `auto` is passed, the "
219
+ "dtype will be automatically derived from the model's weights."
220
+ ),
221
+ "choices": ["auto", "bfloat16", "float16", "float32"],
222
+ },
223
+ )
224
+ attn_implementation: Optional[str] = field(
225
+ default="sdpa",
226
+ metadata={
227
+ "help": ("The attention implementation to use in the model."),
228
+ "choices": ["eager", "sdpa", "flash_attention_2"],
229
+ },
230
+ )
231
+ low_cpu_mem_usage: bool = field(
232
+ default=False,
233
+ metadata={
234
+ "help": (
235
+ "It is an option to create the model as an empty shell, then only materialize its parameters when the pretrained weights are loaded. "
236
+ "set True will benefit LLM loading time and RAM consumption."
237
+ )
238
+ },
239
+ )
240
+
241
+ def __post_init__(self):
242
+ if self.config_overrides is not None and (
243
+ self.config_name is not None or self.model_name_or_path is not None
244
+ ):
245
+ raise ValueError(
246
+ "--config_overrides can't be used in combination with --config_name or --model_name_or_path"
247
+ )
248
+
249
+
250
+ @dataclass
251
+ class DataTrainingArguments:
252
+ """
253
+ Arguments pertaining to what data we are going to input our model for training and eval.
254
+ """
255
+
256
+ dataset_name: Optional[str] = field(
257
+ default=None,
258
+ metadata={"help": "The name of the dataset to use (via the datasets library)."},
259
+ )
260
+ dataset_config_name: Optional[str] = field(
261
+ default=None,
262
+ metadata={
263
+ "help": "The configuration name of the dataset to use (via the datasets library)."
264
+ },
265
+ )
266
+ train_file: Optional[str] = field(
267
+ default=None, metadata={"help": "The input training data file (a text file)."}
268
+ )
269
+ validation_file: Optional[str] = field(
270
+ default=None,
271
+ metadata={
272
+ "help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."
273
+ },
274
+ )
275
+ overwrite_cache: bool = field(
276
+ default=True,
277
+ metadata={"help": "Overwrite the cached training and evaluation sets"},
278
+ )
279
+ validation_split_percentage: Optional[int] = field(
280
+ default=5,
281
+ metadata={
282
+ "help": "The percentage of the train set used as validation set in case there's no validation split"
283
+ },
284
+ )
285
+ max_seq_length: Optional[int] = field(
286
+ default=None,
287
+ metadata={
288
+ "help": (
289
+ "The maximum total input sequence length after tokenization. Sequences longer "
290
+ "than this will be truncated."
291
+ )
292
+ },
293
+ )
294
+ preprocessing_num_workers: Optional[int] = field(
295
+ default=None,
296
+ metadata={"help": "The number of processes to use for the preprocessing."},
297
+ )
298
+ mlm_probability: float = field(
299
+ default=0.15,
300
+ metadata={"help": "Ratio of tokens to mask for masked language modeling loss"},
301
+ )
302
+ line_by_line: bool = field(
303
+ default=False,
304
+ metadata={
305
+ "help": "Whether distinct lines of text in the dataset are to be handled as distinct sequences."
306
+ },
307
+ )
308
+ pad_to_max_length: bool = field(
309
+ default=False,
310
+ metadata={
311
+ "help": (
312
+ "Whether to pad all samples to `max_seq_length`. "
313
+ "If False, will pad the samples dynamically when batching to the maximum length in the batch."
314
+ )
315
+ },
316
+ )
317
+ max_train_samples: Optional[int] = field(
318
+ default=None,
319
+ metadata={
320
+ "help": (
321
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
322
+ "value if set."
323
+ )
324
+ },
325
+ )
326
+ max_eval_samples: Optional[int] = field(
327
+ default=None,
328
+ metadata={
329
+ "help": (
330
+ "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
331
+ "value if set."
332
+ )
333
+ },
334
+ )
335
+ streaming: bool = field(default=False, metadata={"help": "Enable streaming mode"})
336
+
337
+ def __post_init__(self):
338
+ if self.streaming:
339
+ require_version(
340
+ "datasets>=2.0.0", "The streaming feature requires `datasets>=2.0.0`"
341
+ )
342
+
343
+ if (
344
+ self.dataset_name is None
345
+ and self.train_file is None
346
+ and self.validation_file is None
347
+ ):
348
+ raise ValueError(
349
+ "Need either a dataset name or a training/validation file."
350
+ )
351
+ else:
352
+ if self.train_file is not None:
353
+ extension = self.train_file.split(".")[-1]
354
+ if extension not in ["csv", "json", "txt"]:
355
+ raise ValueError(
356
+ "`train_file` should be a csv, a json or a txt file."
357
+ )
358
+ if self.validation_file is not None:
359
+ extension = self.validation_file.split(".")[-1]
360
+ if extension not in ["csv", "json", "txt"]:
361
+ raise ValueError(
362
+ "`validation_file` should be a csv, a json or a txt file."
363
+ )
364
+
365
+
366
+ # add more arguments
367
+ @dataclass
368
+ class CustomArguments:
369
+ """
370
+ Custom arguments for the script
371
+ """
372
+
373
+ lora_dropout: float = field(
374
+ default=0.05, metadata={"help": "The dropout rate for lora"}
375
+ )
376
+
377
+ lora_r: int = field(default=8, metadata={"help": "The r value for lora"})
378
+
379
+ mask_token_type: str = field(
380
+ default="blank",
381
+ metadata={"help": "The type of mask token. Options: blank, eos, mask"},
382
+ )
383
+
384
+ stop_after_n_steps: int = field(
385
+ default=10000, metadata={"help": "Stop training after n steps"}
386
+ )
387
+
388
+ data_collator_type: str = field(
389
+ default="default",
390
+ metadata={"help": "The type of data collator. Options: default, all_mask"},
391
+ )
392
+
393
+
394
+ class DataCollatorForLanguageModelingWithFullMasking(DataCollatorForLanguageModeling):
395
+ def torch_mask_tokens(
396
+ self,
397
+ inputs: Any,
398
+ special_tokens_mask: Optional[Any] = None,
399
+ ) -> Tuple[Any, Any]:
400
+ """
401
+ Prepare masked tokens inputs/labels for masked language modeling: 100% MASK, 0% random, 0% original.
402
+ """
403
+ import torch
404
+
405
+ labels = inputs.clone()
406
+ # We sample a few tokens in each sequence for MLM training (with probability `self.mlm_probability`)
407
+ probability_matrix = torch.full(labels.shape, self.mlm_probability)
408
+ if special_tokens_mask is None:
409
+ special_tokens_mask = [
410
+ self.tokenizer.get_special_tokens_mask(
411
+ val, already_has_special_tokens=True
412
+ )
413
+ for val in labels.tolist()
414
+ ]
415
+ special_tokens_mask = torch.tensor(special_tokens_mask, dtype=torch.bool)
416
+ else:
417
+ special_tokens_mask = special_tokens_mask.bool()
418
+
419
+ probability_matrix.masked_fill_(special_tokens_mask, value=0.0)
420
+ masked_indices = torch.bernoulli(probability_matrix).bool()
421
+ labels[~masked_indices] = -100 # We only compute loss on masked tokens
422
+
423
+ # 100% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
424
+ inputs[masked_indices] = self.tokenizer.convert_tokens_to_ids(
425
+ self.tokenizer.mask_token
426
+ )
427
+
428
+ return inputs, labels
429
+
430
+
431
+ class StopTrainingCallback(TrainerCallback):
432
+ def __init__(self, stop_after_n_steps: int):
433
+ self.stop_after_n_steps = stop_after_n_steps
434
+
435
+ def on_step_end(self, args, state, control, **kwargs):
436
+ if state.global_step >= self.stop_after_n_steps:
437
+ control.should_training_stop = True
438
+
439
+
440
+ class MNTPTrainer(Trainer):
441
+ def __init__(self, *args, **kwargs):
442
+ super().__init__(*args, **kwargs)
443
+ self.label_names = ["labels"]
444
+
445
+ def _remove_unused_columns(
446
+ self, dataset: "datasets.Dataset", description: Optional[str] = None
447
+ ):
448
+ return dataset
449
+
450
+ # We need a custom save function as we have to save the inner model
451
+ def _save(self, output_dir: Optional[str] = None, state_dict=None):
452
+ # If we are executing this function, we are the process zero, so we don't check for that.
453
+ output_dir = output_dir if output_dir is not None else self.args.output_dir
454
+ os.makedirs(output_dir, exist_ok=True)
455
+ logger.info(f"Saving model checkpoint to {output_dir}")
456
+
457
+ # model organization is MODEL_TYPEBiForMNTP.model -> MODEL_TYPELBiModel, we have to save the inner model, handled by save_peft_model function of the outer model
458
+ self.model.save_peft_model(output_dir)
459
+ self.tokenizer.save_pretrained(output_dir)
460
+
461
+ # Good practice: save your training arguments together with the trained model
462
+ torch.save(self.args, os.path.join(output_dir, "training_args.bin"))
463
+
464
+
465
+ def main():
466
+ # See all possible arguments in src/transformers/training_args.py
467
+ # or by passing the --help flag to this script.
468
+ # We now keep distinct sets of args, for a cleaner separation of concerns.
469
+
470
+ parser = HfArgumentParser(
471
+ (ModelArguments, DataTrainingArguments, TrainingArguments, CustomArguments)
472
+ )
473
+ if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
474
+ # If we pass only one argument to the script and it's the path to a json file,
475
+ # let's parse it to get our arguments.
476
+ model_args, data_args, training_args, custom_args = parser.parse_json_file(
477
+ json_file=os.path.abspath(sys.argv[1])
478
+ )
479
+ else:
480
+ (
481
+ model_args,
482
+ data_args,
483
+ training_args,
484
+ custom_args,
485
+ ) = parser.parse_args_into_dataclasses()
486
+
487
+ if training_args.gradient_checkpointing:
488
+ training_args.gradient_checkpointing_kwargs = {"use_reentrant": False}
489
+
490
+ if model_args.use_auth_token is not None:
491
+ warnings.warn(
492
+ "The `use_auth_token` argument is deprecated and will be removed in v4.34. Please use `token` instead.",
493
+ FutureWarning,
494
+ )
495
+ if model_args.token is not None:
496
+ raise ValueError(
497
+ "`token` and `use_auth_token` are both specified. Please set only the argument `token`."
498
+ )
499
+ model_args.token = model_args.use_auth_token
500
+
501
+ # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
502
+ # information sent is the one passed as arguments along with your Python/PyTorch versions.
503
+ send_example_telemetry("run_mlm", model_args, data_args)
504
+
505
+ # Setup logging
506
+ logging.basicConfig(
507
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
508
+ datefmt="%m/%d/%Y %H:%M:%S",
509
+ handlers=[logging.StreamHandler(sys.stdout)],
510
+ )
511
+
512
+ if training_args.should_log:
513
+ # The default of training_args.log_level is passive, so we set log level at info here to have that default.
514
+ transformers.utils.logging.set_verbosity_info()
515
+
516
+ log_level = training_args.get_process_log_level()
517
+ logger.setLevel(log_level)
518
+ datasets.utils.logging.set_verbosity(log_level)
519
+ transformers.utils.logging.set_verbosity(log_level)
520
+ transformers.utils.logging.enable_default_handler()
521
+ transformers.utils.logging.enable_explicit_format()
522
+
523
+ # Log on each process the small summary:
524
+ logger.warning(
525
+ f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}, "
526
+ + f"distributed training: {training_args.parallel_mode.value == 'distributed'}, 16-bits training: {training_args.fp16}"
527
+ )
528
+ # Set the verbosity to info of the Transformers logger (on main process only):
529
+ logger.info(f"Training/evaluation parameters {training_args}")
530
+
531
+ # Detecting last checkpoint.
532
+ last_checkpoint = None
533
+ if (
534
+ os.path.isdir(training_args.output_dir)
535
+ and training_args.do_train
536
+ and not training_args.overwrite_output_dir
537
+ ):
538
+ last_checkpoint = get_last_checkpoint(training_args.output_dir)
539
+ if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
540
+ raise ValueError(
541
+ f"Output directory ({training_args.output_dir}) already exists and is not empty. "
542
+ "Use --overwrite_output_dir to overcome."
543
+ )
544
+ elif (
545
+ last_checkpoint is not None and training_args.resume_from_checkpoint is None
546
+ ):
547
+ logger.info(
548
+ f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
549
+ "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
550
+ )
551
+
552
+ # Set seed before initializing model.
553
+ set_seed(training_args.seed)
554
+
555
+ # Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below)
556
+ # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
557
+ # (the dataset will be downloaded automatically from the datasets Hub
558
+ #
559
+ # For CSV/JSON files, this script will use the column called 'text' or the first column. You can easily tweak this
560
+ # behavior (see below)
561
+ #
562
+ # In distributed training, the load_dataset function guarantee that only one local process can concurrently
563
+ # download the dataset.
564
+ if data_args.dataset_name is not None:
565
+ # Downloading and loading a dataset from the hub.
566
+ raw_datasets = load_dataset(
567
+ data_args.dataset_name,
568
+ data_args.dataset_config_name,
569
+ cache_dir=model_args.cache_dir,
570
+ token=model_args.token,
571
+ streaming=data_args.streaming,
572
+ )
573
+ if "validation" not in raw_datasets.keys():
574
+ raw_datasets["validation"] = load_dataset(
575
+ data_args.dataset_name,
576
+ data_args.dataset_config_name,
577
+ split=f"train[:{data_args.validation_split_percentage}%]",
578
+ cache_dir=model_args.cache_dir,
579
+ token=model_args.token,
580
+ streaming=data_args.streaming,
581
+ )
582
+ raw_datasets["train"] = load_dataset(
583
+ data_args.dataset_name,
584
+ data_args.dataset_config_name,
585
+ split=f"train[{data_args.validation_split_percentage}%:]",
586
+ cache_dir=model_args.cache_dir,
587
+ token=model_args.token,
588
+ streaming=data_args.streaming,
589
+ )
590
+ else:
591
+ data_files = {}
592
+ if data_args.train_file is not None:
593
+ data_files["train"] = data_args.train_file
594
+ extension = data_args.train_file.split(".")[-1]
595
+ if data_args.validation_file is not None:
596
+ data_files["validation"] = data_args.validation_file
597
+ extension = data_args.validation_file.split(".")[-1]
598
+ if extension == "txt":
599
+ extension = "text"
600
+ raw_datasets = load_dataset(
601
+ extension,
602
+ data_files=data_files,
603
+ cache_dir=model_args.cache_dir,
604
+ token=model_args.token,
605
+ )
606
+
607
+ # If no validation data is there, validation_split_percentage will be used to divide the dataset.
608
+ if "validation" not in raw_datasets.keys():
609
+ raw_datasets["validation"] = load_dataset(
610
+ extension,
611
+ data_files=data_files,
612
+ split=f"train[:{data_args.validation_split_percentage}%]",
613
+ cache_dir=model_args.cache_dir,
614
+ token=model_args.token,
615
+ )
616
+ raw_datasets["train"] = load_dataset(
617
+ extension,
618
+ data_files=data_files,
619
+ split=f"train[{data_args.validation_split_percentage}%:]",
620
+ cache_dir=model_args.cache_dir,
621
+ token=model_args.token,
622
+ )
623
+
624
+ # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
625
+ # https://huggingface.co/docs/datasets/loading_datasets.
626
+
627
+ # Load pretrained model and tokenizer
628
+ #
629
+ # Distributed training:
630
+ # The .from_pretrained methods guarantee that only one local process can concurrently
631
+ # download model & vocab.
632
+ config_kwargs = {
633
+ "cache_dir": model_args.cache_dir,
634
+ "revision": model_args.model_revision,
635
+ "token": model_args.token,
636
+ "trust_remote_code": model_args.trust_remote_code,
637
+ }
638
+ if model_args.config_name:
639
+ config = AutoConfig.from_pretrained(model_args.config_name, **config_kwargs)
640
+ elif model_args.model_name_or_path:
641
+ config = AutoConfig.from_pretrained(
642
+ model_args.model_name_or_path, **config_kwargs
643
+ )
644
+ else:
645
+ config = CONFIG_MAPPING[model_args.model_type]()
646
+ logger.warning("You are instantiating a new config instance from scratch.")
647
+ if model_args.config_overrides is not None:
648
+ logger.info(f"Overriding config: {model_args.config_overrides}")
649
+ config.update_from_string(model_args.config_overrides)
650
+ logger.info(f"New config: {config}")
651
+
652
+ tokenizer_kwargs = {
653
+ "cache_dir": model_args.cache_dir,
654
+ "use_fast": model_args.use_fast_tokenizer,
655
+ "revision": model_args.model_revision,
656
+ "token": model_args.token,
657
+ "trust_remote_code": model_args.trust_remote_code,
658
+ }
659
+ if model_args.tokenizer_name:
660
+ tokenizer = AutoTokenizer.from_pretrained(
661
+ model_args.tokenizer_name, **tokenizer_kwargs
662
+ )
663
+ elif model_args.model_name_or_path:
664
+ tokenizer = AutoTokenizer.from_pretrained(
665
+ model_args.model_name_or_path, **tokenizer_kwargs
666
+ )
667
+ else:
668
+ raise ValueError(
669
+ "You are instantiating a new tokenizer from scratch. This is not supported by this script. "
670
+ "You can do it from another script, save it, and load it from here, using --tokenizer_name."
671
+ )
672
+
673
+ # If the tokenizer has no mask token, pick one according to --mask_token_type: a blank ("_"), the EOS token, or a newly added <mask> token.
674
+ if tokenizer.mask_token is None:
675
+ if custom_args.mask_token_type == "blank":
676
+ tokenizer.mask_token = "_"
677
+ elif custom_args.mask_token_type == "eos":
678
+ tokenizer.mask_token = tokenizer.eos_token
679
+ elif custom_args.mask_token_type == "mask":
680
+ tokenizer.add_tokens(["<mask>"])
681
+ tokenizer.mask_token = "<mask>"
682
+ else:
683
+ raise ValueError(
684
+ f"mask_token_type {custom_args.mask_token_type} is not supported."
685
+ )
686
+
687
+ if tokenizer.pad_token is None:
688
+ tokenizer.pad_token = tokenizer.eos_token
689
+
690
+ # Loading bidirectional model using LLM2Vec package
691
+ model_class = get_model_class(config)
692
+ torch_dtype = (
693
+ model_args.torch_dtype
694
+ if model_args.torch_dtype in ["auto", None]
695
+ else getattr(torch, model_args.torch_dtype)
696
+ )
697
+ model = model_class.from_pretrained(
698
+ model_args.model_name_or_path,
699
+ from_tf=bool(".ckpt" in model_args.model_name_or_path),
700
+ config=config,
701
+ cache_dir=model_args.cache_dir,
702
+ revision=model_args.model_revision,
703
+ token=model_args.token,
704
+ trust_remote_code=model_args.trust_remote_code,
705
+ torch_dtype=torch_dtype,
706
+ low_cpu_mem_usage=model_args.low_cpu_mem_usage,
707
+ attn_implementation=model_args.attn_implementation,
708
+ )
709
+
710
+ # model organization is MODEL_TYPEBiForMNTP.model -> MODEL_TYPEBiModel, we have to apply PEFT to the inner model
711
+ model.model = initialize_peft(
712
+ model.model,
713
+ lora_r=custom_args.lora_r,
714
+ lora_alpha=2 * custom_args.lora_r,
715
+ lora_dropout=custom_args.lora_dropout,
716
+ )
717
+
718
+ # We resize the embeddings only when necessary to avoid index errors. If you are creating a model from scratch
719
+ # on a small vocab and want a smaller embedding size, remove this test.
720
+ embedding_size = model.get_input_embeddings().weight.shape[0]
721
+ if len(tokenizer) > embedding_size:
722
+ model.resize_token_embeddings(len(tokenizer))
723
+
724
+ # Preprocessing the datasets.
725
+ # First we tokenize all the texts.
726
+ if training_args.do_train:
727
+ column_names = list(raw_datasets["train"].features)
728
+ else:
729
+ column_names = list(raw_datasets["validation"].features)
730
+ text_column_name = "text" if "text" in column_names else column_names[0]
731
+
732
+ if data_args.max_seq_length is None:
733
+ max_seq_length = tokenizer.model_max_length
734
+ if max_seq_length > 1024:
735
+ logger.warning(
736
+ "The chosen tokenizer supports a `model_max_length` that is longer than the default `block_size` value"
737
+ " of 1024. If you would like to use a longer `block_size` up to `tokenizer.model_max_length` you can"
738
+ " override this default with `--block_size xxx`."
739
+ )
740
+ max_seq_length = 1024
741
+ else:
742
+ if data_args.max_seq_length > tokenizer.model_max_length:
743
+ logger.warning(
744
+ f"The max_seq_length passed ({data_args.max_seq_length}) is larger than the maximum length for the "
745
+ f"model ({tokenizer.model_max_length}). Using max_seq_length={tokenizer.model_max_length}."
746
+ )
747
+ max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length)
748
+
749
+ if data_args.line_by_line:
750
+ # When using line_by_line, we just tokenize each nonempty line.
751
+ padding = "max_length" if data_args.pad_to_max_length else False
752
+
753
+ def tokenize_function(examples):
754
+ # Remove empty lines
755
+ examples[text_column_name] = [
756
+ line
757
+ for line in examples[text_column_name]
758
+ if len(line) > 0 and not line.isspace()
759
+ ]
760
+ return tokenizer(
761
+ examples[text_column_name],
762
+ padding=padding,
763
+ truncation=True,
764
+ max_length=max_seq_length,
765
+ # We use this option because DataCollatorForLanguageModeling (see below) is more efficient when it
766
+ # receives the `special_tokens_mask`.
767
+ return_special_tokens_mask=True,
768
+ )
769
+
770
+ with training_args.main_process_first(desc="dataset map tokenization"):
771
+ if not data_args.streaming:
772
+ tokenized_datasets = raw_datasets.map(
773
+ tokenize_function,
774
+ batched=True,
775
+ num_proc=data_args.preprocessing_num_workers,
776
+ remove_columns=[text_column_name],
777
+ load_from_cache_file=not data_args.overwrite_cache,
778
+ desc="Running tokenizer on dataset line_by_line",
779
+ )
780
+ else:
781
+ tokenized_datasets = raw_datasets.map(
782
+ tokenize_function,
783
+ batched=True,
784
+ remove_columns=[text_column_name],
785
+ )
786
+ else:
787
+ # Otherwise, we tokenize every text, then concatenate them together before splitting them in smaller parts.
788
+ # We use `return_special_tokens_mask=True` because DataCollatorForLanguageModeling (see below) is more
789
+ # efficient when it receives the `special_tokens_mask`.
790
+ def tokenize_function(examples):
791
+ return tokenizer(
792
+ examples[text_column_name], return_special_tokens_mask=True
793
+ )
794
+
795
+ with training_args.main_process_first(desc="dataset map tokenization"):
796
+ if not data_args.streaming:
797
+ tokenized_datasets = raw_datasets.map(
798
+ tokenize_function,
799
+ batched=True,
800
+ num_proc=data_args.preprocessing_num_workers,
801
+ remove_columns=column_names,
802
+ load_from_cache_file=not data_args.overwrite_cache,
803
+ desc="Running tokenizer on every text in dataset",
804
+ )
805
+ else:
806
+ tokenized_datasets = raw_datasets.map(
807
+ tokenize_function,
808
+ batched=True,
809
+ remove_columns=column_names,
810
+ )
811
+
812
+ # Main data processing function that will concatenate all texts from our dataset and generate chunks of
813
+ # max_seq_length.
814
+ def group_texts(examples):
815
+ # Concatenate all texts.
816
+ concatenated_examples = {
817
+ k: list(chain(*examples[k])) for k in examples.keys()
818
+ }
819
+ total_length = len(concatenated_examples[list(examples.keys())[0]])
820
+ # We drop the small remainder, and if the total_length < max_seq_length we exclude this batch and return an empty dict.
821
+ # We could add padding if the model supported it instead of this drop, you can customize this part to your needs.
822
+ total_length = (total_length // max_seq_length) * max_seq_length
823
+ # Split by chunks of max_len.
824
+ result = {
825
+ k: [
826
+ t[i : i + max_seq_length]
827
+ for i in range(0, total_length, max_seq_length)
828
+ ]
829
+ for k, t in concatenated_examples.items()
830
+ }
831
+ return result
832
+
833
+ # Note that with `batched=True`, this map processes 1,000 texts together, so group_texts throws away a
834
+ # remainder for each of those groups of 1,000 texts. You can adjust that batch_size here but a higher value
835
+ # might be slower to preprocess.
836
+ #
837
+ # To speed up this part, we use multiprocessing. See the documentation of the map method for more information:
838
+ # https://huggingface.co/docs/datasets/process#map
839
+
840
+ with training_args.main_process_first(desc="grouping texts together"):
841
+ if not data_args.streaming:
842
+ tokenized_datasets = tokenized_datasets.map(
843
+ group_texts,
844
+ batched=True,
845
+ num_proc=data_args.preprocessing_num_workers,
846
+ load_from_cache_file=not data_args.overwrite_cache,
847
+ desc=f"Grouping texts in chunks of {max_seq_length}",
848
+ )
849
+ else:
850
+ tokenized_datasets = tokenized_datasets.map(
851
+ group_texts,
852
+ batched=True,
853
+ )
854
+
855
+ if training_args.do_train:
856
+ if "train" not in tokenized_datasets:
857
+ raise ValueError("--do_train requires a train dataset")
858
+ train_dataset = tokenized_datasets["train"]
859
+ if data_args.max_train_samples is not None:
860
+ max_train_samples = min(len(train_dataset), data_args.max_train_samples)
861
+ train_dataset = train_dataset.select(range(max_train_samples))
862
+
863
+ if training_args.do_eval:
864
+ if "validation" not in tokenized_datasets:
865
+ raise ValueError("--do_eval requires a validation dataset")
866
+ eval_dataset = tokenized_datasets["validation"]
867
+ if data_args.max_eval_samples is not None:
868
+ max_eval_samples = min(len(eval_dataset), data_args.max_eval_samples)
869
+ eval_dataset = eval_dataset.select(range(max_eval_samples))
870
+
871
+ def preprocess_logits_for_metrics(logits, labels):
872
+ if isinstance(logits, tuple):
873
+ # Depending on the model and config, logits may contain extra tensors,
874
+ # like past_key_values, but logits always come first
875
+ logits = logits[0]
876
+ return logits.argmax(dim=-1)
877
+
878
+ metric = evaluate.load("accuracy", cache_dir=model_args.cache_dir)
879
+
880
+ def compute_metrics(eval_preds):
881
+ preds, labels = eval_preds
882
+ preds = preds[:, :-1]
883
+ labels = labels[:, 1:]
884
+ # preds have the same shape as the labels, after the argmax(-1) has been calculated
885
+ # by preprocess_logits_for_metrics
886
+ labels = labels.reshape(-1)
887
+ preds = preds.reshape(-1)
888
+ mask = labels != -100
889
+ labels = labels[mask]
890
+ preds = preds[mask]
891
+ return metric.compute(predictions=preds, references=labels)
892
+
893
+ # Data collator
894
+ # This one will take care of randomly masking the tokens.
895
+ pad_to_multiple_of_8 = (
896
+ data_args.line_by_line
897
+ and training_args.fp16
898
+ and not data_args.pad_to_max_length
899
+ )
900
+ data_collator_cls = None
901
+ if custom_args.data_collator_type == "all_mask":
902
+ data_collator_cls = DataCollatorForLanguageModelingWithFullMasking
903
+ elif custom_args.data_collator_type == "default":
904
+ data_collator_cls = DataCollatorForLanguageModeling
905
+ else:
906
+ raise ValueError(
907
+ f"data_collator_type {custom_args.data_collator_type} is not supported."
908
+ )
909
+
910
+ data_collator = data_collator_cls(
911
+ tokenizer=tokenizer,
912
+ mlm_probability=data_args.mlm_probability,
913
+ pad_to_multiple_of=8 if pad_to_multiple_of_8 else None,
914
+ )
915
+
916
+ # Initialize our Trainer
917
+ trainer = MNTPTrainer(
918
+ model=model,
919
+ args=training_args,
920
+ train_dataset=train_dataset if training_args.do_train else None,
921
+ eval_dataset=eval_dataset if training_args.do_eval else None,
922
+ tokenizer=tokenizer,
923
+ data_collator=data_collator,
924
+ compute_metrics=(
925
+ compute_metrics
926
+ if training_args.do_eval and not is_torch_tpu_available()
927
+ else None
928
+ ),
929
+ preprocess_logits_for_metrics=(
930
+ preprocess_logits_for_metrics
931
+ if training_args.do_eval and not is_torch_tpu_available()
932
+ else None
933
+ ),
934
+ )
935
+
936
+ trainer.add_callback(StopTrainingCallback(custom_args.stop_after_n_steps))
937
+
938
+ # Training
939
+ if training_args.do_train:
940
+ checkpoint = None
941
+ if training_args.resume_from_checkpoint is not None:
942
+ checkpoint = training_args.resume_from_checkpoint
943
+ elif last_checkpoint is not None:
944
+ checkpoint = last_checkpoint
945
+ train_result = trainer.train(resume_from_checkpoint=checkpoint)
946
+ trainer.save_model() # Saves the tokenizer too for easy upload
947
+ metrics = train_result.metrics
948
+
949
+ max_train_samples = (
950
+ data_args.max_train_samples
951
+ if data_args.max_train_samples is not None
952
+ else len(train_dataset)
953
+ )
954
+ metrics["train_samples"] = min(max_train_samples, len(train_dataset))
955
+
956
+ trainer.log_metrics("train", metrics)
957
+ trainer.save_metrics("train", metrics)
958
+ trainer.save_state()
959
+
960
+ # Evaluation
961
+ if training_args.do_eval:
962
+ logger.info("*** Evaluate ***")
963
+
964
+ metrics = trainer.evaluate()
965
+
966
+ max_eval_samples = (
967
+ data_args.max_eval_samples
968
+ if data_args.max_eval_samples is not None
969
+ else len(eval_dataset)
970
+ )
971
+ metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset))
972
+ try:
973
+ perplexity = math.exp(metrics["eval_loss"])
974
+ except OverflowError:
975
+ perplexity = float("inf")
976
+ metrics["perplexity"] = perplexity
977
+
978
+ trainer.log_metrics("eval", metrics)
979
+ trainer.save_metrics("eval", metrics)
980
+
981
+ kwargs = {"finetuned_from": model_args.model_name_or_path, "tasks": "fill-mask"}
982
+ if data_args.dataset_name is not None:
983
+ kwargs["dataset_tags"] = data_args.dataset_name
984
+ if data_args.dataset_config_name is not None:
985
+ kwargs["dataset_args"] = data_args.dataset_config_name
986
+ kwargs["dataset"] = (
987
+ f"{data_args.dataset_name} {data_args.dataset_config_name}"
988
+ )
989
+ else:
990
+ kwargs["dataset"] = data_args.dataset_name
991
+
992
+ if training_args.push_to_hub:
993
+ trainer.push_to_hub(**kwargs)
994
+
995
+
996
+ if __name__ == "__main__":
997
+ main()
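
A quick illustration of the evaluation alignment used in compute_metrics above: because the model predicts a masked token from the position just before it, the last prediction column and the first label column are dropped before the -100 padding labels are masked out. The toy values below are invented purely for illustration.

import torch

# Toy batch: one sequence of 5 positions; -100 marks positions that were not masked.
preds = torch.tensor([[12, 99, 99, 15, 99]])          # argmax of the logits at each position
labels = torch.tensor([[-100, 12, -100, -100, 15]])   # only positions 1 and 4 were masked

# Same alignment as compute_metrics: the token at position i is predicted
# from the logits at position i - 1.
preds = preds[:, :-1].reshape(-1)
labels = labels[:, 1:].reshape(-1)
mask = labels != -100
print((preds[mask] == labels[mask]).float().mean())   # tensor(1.) for this toy example
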
llm2vec/experiments/run_simcse.py ADDED
@@ -0,0 +1,388 @@
1
+ import logging
2
+ from dataclasses import dataclass, field
3
+ import os
4
+ import sys
5
+ from typing import Any, Dict, List, Optional, Tuple, Union
6
+
7
+ import torch
8
+ from torch import nn
9
+
10
+ from accelerate import Accelerator, DistributedDataParallelKwargs
11
+ from accelerate.logging import get_logger
12
+
13
+ import transformers
14
+ from transformers import (
15
+ MODEL_FOR_MASKED_LM_MAPPING,
16
+ HfArgumentParser,
17
+ TrainingArguments,
18
+ Trainer,
19
+ TrainerCallback,
20
+ set_seed,
21
+ )
22
+ from transformers.trainer_utils import seed_worker
23
+
24
+ from peft import LoraConfig, get_peft_model
25
+
26
+ from llm2vec import LLM2Vec
27
+ from llm2vec.dataset.utils import load_dataset
28
+ from llm2vec.loss.utils import load_loss
29
+
30
+ from tqdm import tqdm
31
+
32
+ transformers.logging.set_verbosity_error()
33
+
34
+ logging.basicConfig(
35
+ format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
36
+ datefmt="%Y-%m-%d %H:%M:%S",
37
+ level=logging.INFO,
38
+ )
39
+ logger = get_logger(__name__, log_level="INFO")
40
+ MODEL_CONFIG_CLASSES = list(MODEL_FOR_MASKED_LM_MAPPING.keys())
41
+ MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
42
+
43
+
44
+ def initialize_peft(
45
+ model,
46
+ lora_r: int = 8,
47
+ lora_alpha: int = 16,
48
+ lora_dropout: float = 0.05,
49
+ lora_modules: Optional[List[str]] = None,
50
+ ):
51
+ if lora_modules is None and model.config.__class__.__name__ in [
52
+ "LlamaConfig",
53
+ "MistralConfig",
54
+ "GemmaConfig",
55
+ "Qwen2Config",
56
+ ]:
57
+ lora_modules = [
58
+ "q_proj",
59
+ "v_proj",
60
+ "k_proj",
61
+ "o_proj",
62
+ "gate_proj",
63
+ "up_proj",
64
+ "down_proj",
65
+ ]
66
+ elif lora_modules is None:
67
+ raise ValueError("lora_modules must be specified for this model.")
68
+
69
+ config = LoraConfig(
70
+ r=lora_r,
71
+ lora_alpha=lora_alpha,
72
+ target_modules=lora_modules,
73
+ lora_dropout=lora_dropout,
74
+ bias="none",
75
+ task_type=None,
76
+ )
77
+
78
+ model = get_peft_model(model, config)
79
+ print(f"Model's Lora trainable parameters:")
80
+ model.print_trainable_parameters()
81
+ return model
82
+
83
+
84
+ @dataclass
85
+ class ModelArguments:
86
+ """
87
+ Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch.
88
+ """
89
+
90
+ model_name_or_path: Optional[str] = field(
91
+ default=None,
92
+ metadata={
93
+ "help": (
94
+ "The base model checkpoint for weights initialization. Don't set if you want to train a model from scratch."
95
+ )
96
+ },
97
+ )
98
+ peft_model_name_or_path: Optional[str] = field(
99
+ default=None,
100
+ metadata={"help": ("The PEFT model checkpoint to add on top of base model.")},
101
+ )
102
+ bidirectional: Optional[bool] = field(
103
+ default=False,
104
+ metadata={
105
+ "help": (
106
+ "Whether to enable bidirectional attention in the model. If set to False, the model will use unidirectional attention."
107
+ )
108
+ },
109
+ )
110
+ max_seq_length: Optional[int] = field(
111
+ default=None,
112
+ metadata={
113
+ "help": (
114
+ "The maximum total input sequence length after tokenization. Sequences longer "
115
+ "than this will be truncated."
116
+ )
117
+ },
118
+ )
119
+ torch_dtype: Optional[str] = field(
120
+ default=None,
121
+ metadata={
122
+ "help": (
123
+ "Override the default `torch.dtype` and load the model under this dtype. If `auto` is passed, the "
124
+ "dtype will be automatically derived from the model's weights."
125
+ ),
126
+ "choices": ["auto", "bfloat16", "float16", "float32"],
127
+ },
128
+ )
129
+ attn_implementation: Optional[str] = field(
130
+ default="sdpa",
131
+ metadata={
132
+ "help": ("The attention implementation to use in the model."),
133
+ "choices": ["eager", "sdpa", "flash_attention_2"],
134
+ },
135
+ )
136
+ pooling_mode: Optional[str] = field(
137
+ default="mean",
138
+ metadata={
139
+ "help": ("The pooling mode to use in the model."),
140
+ "choices": ["mean", "weighted_mean", "eos_token"],
141
+ },
142
+ )
143
+
144
+
145
+ @dataclass
146
+ class DataTrainingArguments:
147
+ """
148
+ Arguments pertaining to what data we are going to input our model for training and eval.
149
+ """
150
+
151
+ dataset_name: Optional[str] = field(
152
+ default=None,
153
+ metadata={"help": "The name of the dataset to use. Options: E5"},
154
+ )
155
+ dataset_file_path: Optional[str] = field(
156
+ default=None, metadata={"help": "The input training data file or folder."}
157
+ )
158
+ # TODO: implement this
159
+ max_train_samples: Optional[int] = field(
160
+ default=None,
161
+ metadata={
162
+ "help": (
163
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
164
+ "value if set."
165
+ )
166
+ },
167
+ )
168
+
169
+
170
+ @dataclass
171
+ class CustomArguments:
172
+ """
173
+ Custom arguments for the script
174
+ """
175
+
176
+ simcse_dropout: float = field(
177
+ default=0.1, metadata={"help": "The SimCSE dropout rate for the model"}
178
+ )
179
+
180
+ lora_dropout: float = field(
181
+ default=0.05, metadata={"help": "The dropout rate for lora"}
182
+ )
183
+
184
+ lora_r: int = field(default=8, metadata={"help": "The r value for lora"})
185
+
186
+ stop_after_n_steps: int = field(
187
+ default=10000, metadata={"help": "Stop training after n steps"}
188
+ )
189
+
190
+ experiment_id: Optional[str] = field(
191
+ default=None, metadata={"help": "The experiment id"}
192
+ )
193
+
194
+ loss_class: Optional[str] = field(
195
+ default="HardNegativeNLLLoss",
196
+ metadata={
197
+ "help": "The loss class to use for training. Options: HardNegativeNLLLoss"
198
+ },
199
+ )
200
+
201
+ loss_scale: float = field(
202
+ default=50.0, metadata={"help": "The loss scale for the loss function"}
203
+ )
204
+
205
+
206
+ @dataclass
207
+ class DefaultCollator:
208
+ model: LLM2Vec
209
+
210
+ def __init__(self, model: LLM2Vec) -> None:
211
+ self.model = model
212
+
213
+ def __call__(self, features: List[Dict[str, Any]]) -> Tuple[List[Dict[str, torch.Tensor]], torch.Tensor]:
214
+ batch = features
215
+ num_texts = len(batch[0].texts)
216
+ texts = [[] for _ in range(num_texts)]
217
+ labels = []
218
+
219
+ for example in batch:
220
+ for idx, text in enumerate(example.texts):
221
+ # TODO: Add prepare_for_tokenization here similar to supervised training and see if it impacts performance
222
+ texts[idx].append(text)
223
+ labels.append(example.label)
224
+ labels = torch.tensor(labels)
225
+
226
+ sentence_features = []
227
+ for idx in range(num_texts):
228
+ tokenized = self.model.tokenize(texts[idx])
229
+ sentence_features.append(tokenized)
230
+
231
+ return sentence_features, labels
232
+
233
+
234
+ class StopTrainingCallback(TrainerCallback):
235
+ def __init__(self, stop_after_n_steps: int):
236
+ self.stop_after_n_steps = stop_after_n_steps
237
+
238
+ def on_step_end(self, args, state, control, **kwargs):
239
+ if state.global_step >= self.stop_after_n_steps:
240
+ control.should_training_stop = True
241
+
242
+
243
+ class SimCSETrainer(Trainer):
244
+ def __init__(
245
+ self,
246
+ *args,
247
+ loss_function=None,
248
+ **kwargs,
249
+ ) -> None:
250
+ super().__init__(*args, **kwargs)
251
+ self.loss_function = loss_function
252
+
253
+ def compute_loss(
254
+ self,
255
+ model: nn.Module,
256
+ inputs: Dict[str, Union[torch.Tensor, Any]],
257
+ return_outputs: bool = False,
258
+ ) -> Union[Tuple[torch.Tensor, torch.Tensor], torch.Tensor]:
259
+ features, labels = inputs
260
+ q_reps = self.model(features[0])
261
+ d_reps = self.model(features[1])
262
+
263
+ d_reps_neg = None
264
+ if len(features) > 2:
265
+ d_reps_neg = self.model(features[2])
266
+
267
+ loss = self.loss_function(q_reps, d_reps, d_reps_neg)
268
+
269
+ if return_outputs:
270
+ output = torch.cat(
271
+ [model(row)["sentence_embedding"][:, None] for row in features], dim=1
272
+ )
273
+ return loss, output
274
+
275
+ return loss
276
+
277
+ def _save(self, output_dir: Optional[str] = None, state_dict=None):
278
+ # If we are executing this function, we are the process zero, so we don't check for that.
279
+ output_dir = output_dir if output_dir is not None else self.args.output_dir
280
+ os.makedirs(output_dir, exist_ok=True)
281
+ logger.info(f"Saving model checkpoint to {output_dir}")
282
+
283
+ self.model.save(output_dir)
284
+
285
+ # Good practice: save your training arguments together with the trained model
286
+ torch.save(self.args, os.path.join(output_dir, "training_args.bin"))
287
+
288
+
289
+ def main():
290
+ parser = HfArgumentParser(
291
+ (ModelArguments, DataTrainingArguments, TrainingArguments, CustomArguments)
292
+ )
293
+ if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
294
+ # If we pass only one argument to the script and it's the path to a json file,
295
+ # let's parse it to get our arguments.
296
+ model_args, data_args, training_args, custom_args = parser.parse_json_file(
297
+ json_file=os.path.abspath(sys.argv[1])
298
+ )
299
+ else:
300
+ (
301
+ model_args,
302
+ data_args,
303
+ training_args,
304
+ custom_args,
305
+ ) = parser.parse_args_into_dataclasses()
306
+ if training_args.ddp_find_unused_parameters:
307
+ kwargs = [
308
+ DistributedDataParallelKwargs(
309
+ dim=0,
310
+ broadcast_buffers=True,
311
+ bucket_cap_mb=25,
312
+ find_unused_parameters=True,
313
+ check_reduction=False,
314
+ gradient_as_bucket_view=False,
315
+ )
316
+ ]
317
+ else:
318
+ kwargs = []
319
+ accelerator = Accelerator(kwargs_handlers=kwargs)
320
+
321
+ set_seed(training_args.seed)
322
+
323
+ if training_args.gradient_checkpointing:
324
+ training_args.gradient_checkpointing_kwargs = {"use_reentrant": False}
325
+
326
+ train_dataset = load_dataset(
327
+ data_args.dataset_name,
328
+ split="train",
329
+ file_path=data_args.dataset_file_path,
330
+ )
331
+
332
+ train_examples = [
333
+ train_dataset[i]
334
+ for i in tqdm(
335
+ range(len(train_dataset)),
336
+ desc="Loading train examples...",
337
+ disable=not accelerator.is_main_process,
338
+ )
339
+ ]
340
+
341
+ torch_dtype = (
342
+ model_args.torch_dtype
343
+ if model_args.torch_dtype in ["auto", None]
344
+ else getattr(torch, model_args.torch_dtype)
345
+ )
346
+ model = LLM2Vec.from_pretrained(
347
+ base_model_name_or_path=model_args.model_name_or_path,
348
+ enable_bidirectional=model_args.bidirectional,
349
+ peft_model_name_or_path=model_args.peft_model_name_or_path,
350
+ merge_peft=True,
351
+ pooling_mode=model_args.pooling_mode,
352
+ max_length=model_args.max_seq_length,
353
+ torch_dtype=torch_dtype,
354
+ attn_implementation=model_args.attn_implementation,
355
+ attention_dropout=custom_args.simcse_dropout,
356
+ )
357
+
358
+ # model organization is LLM2VecModel.model -> HF Model, we have to apply PEFT to the inner model
359
+ model.model = initialize_peft(
360
+ model.model,
361
+ lora_r=custom_args.lora_r,
362
+ lora_alpha=2 * custom_args.lora_r,
363
+ lora_dropout=custom_args.lora_dropout,
364
+ )
365
+
366
+ tokenizer = model.tokenizer
367
+
368
+ train_loss = load_loss(custom_args.loss_class, scale=custom_args.loss_scale)
369
+
370
+ data_collator = DefaultCollator(model)
371
+
372
+ trainer = SimCSETrainer(
373
+ model=model,
374
+ args=training_args,
375
+ train_dataset=train_examples,
376
+ data_collator=data_collator,
377
+ tokenizer=tokenizer,
378
+ loss_function=train_loss,
379
+ )
380
+
381
+ if custom_args.stop_after_n_steps is not None:
382
+ trainer.add_callback(StopTrainingCallback(custom_args.stop_after_n_steps))
383
+
384
+ trainer.train()
385
+
386
+
387
+ if __name__ == "__main__":
388
+ main()
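
For orientation, compute_loss above passes the query, positive, and optional hard-negative embeddings to the loss loaded by load_loss(custom_args.loss_class, scale=custom_args.loss_scale). The snippet below is only a rough sketch of the usual in-batch-negatives contrastive formulation that a HardNegativeNLLLoss-style objective follows, not the llm2vec implementation itself; the scale plays the role of an inverse temperature.

import torch
import torch.nn.functional as F

def contrastive_loss_sketch(q_reps, d_reps, d_reps_neg=None, scale=50.0):
    # Candidates are all positives in the batch plus any hard negatives.
    candidates = d_reps if d_reps_neg is None else torch.cat([d_reps, d_reps_neg], dim=0)
    # Scaled cosine-similarity scores; the positive for query i is candidate i.
    scores = scale * F.normalize(q_reps, dim=-1) @ F.normalize(candidates, dim=-1).T
    labels = torch.arange(q_reps.size(0), device=q_reps.device)
    return F.cross_entropy(scores, labels)

# Toy usage with random embeddings (batch of 4, hidden size 8).
print(contrastive_loss_sketch(torch.randn(4, 8), torch.randn(4, 8), torch.randn(4, 8)))
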
llm2vec/experiments/run_supervised.py ADDED
@@ -0,0 +1,482 @@
1
+ import logging
2
+ from dataclasses import dataclass, field
3
+ import os
4
+ import sys
5
+ from typing import Any, Dict, List, Optional, Tuple, Union
6
+
7
+ import torch
8
+ from torch import nn
9
+ from torch.utils.data import DataLoader, SequentialSampler
10
+
11
+ from accelerate import Accelerator, DistributedDataParallelKwargs
12
+ from accelerate.logging import get_logger
13
+
14
+ import transformers
15
+ from transformers import (
16
+ MODEL_FOR_MASKED_LM_MAPPING,
17
+ HfArgumentParser,
18
+ TrainingArguments,
19
+ Trainer,
20
+ TrainerCallback,
21
+ LlamaConfig,
22
+ MistralConfig,
23
+ GemmaConfig,
24
+ Qwen2Config,
25
+ set_seed,
26
+ )
27
+ from transformers.trainer_utils import seed_worker
28
+
29
+ from peft import LoraConfig, get_peft_model
30
+
31
+ from llm2vec import LLM2Vec
32
+ from llm2vec.dataset.utils import load_dataset
33
+ from llm2vec.loss.utils import load_loss
34
+ from llm2vec.experiment_utils import generate_experiment_id
35
+
36
+ from tqdm import tqdm
37
+
38
+ transformers.logging.set_verbosity_error()
39
+
40
+ logging.basicConfig(
41
+ format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
42
+ datefmt="%Y-%m-%d %H:%M:%S",
43
+ level=logging.INFO,
44
+ )
45
+ logger = get_logger(__name__, log_level="INFO")
46
+ MODEL_CONFIG_CLASSES = list(MODEL_FOR_MASKED_LM_MAPPING.keys())
47
+ MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
48
+
49
+
50
+ def prepare_for_tokenization(model, text, pooling_mode="mean"):
51
+ if model.config._name_or_path == "meta-llama/Meta-Llama-3-8B-Instruct":
52
+ text = (
53
+ "<|start_header_id|>user<|end_header_id|>\n\n" + text.strip() + "<|eot_id|>"
54
+ )
55
+ return text
56
+ if model.config._name_or_path in [
57
+ "mistralai/Mistral-7B-Instruct-v0.2",
58
+ "meta-llama/Llama-2-7b-chat-hf",
59
+ ]:
60
+ text = "[INST] " + text.strip() + " [/INST]"
61
+ if model.config._name_or_path in [
62
+ "google/gemma-2-9b-it",
63
+ ]:
64
+ text = "<bos><start_of_turn>user\n" + text.strip() + "<end_of_turn>"
65
+ if model.config._name_or_path in [
66
+ "Qwen/Qwen2-1.5B-Instruct",
67
+ "Qwen/Qwen2-7B-Instruct",
68
+ ]:
69
+ text = "<|im_start|>user\n" + text.strip() + "<|im_end|>"
70
+ if pooling_mode == "eos_token":
71
+ if model.config._name_or_path == "meta-llama/Meta-Llama-3-8B":
72
+ text = text.strip() + "<|end_of_text|>"
73
+ elif isinstance(model.config, LlamaConfig) or isinstance(
74
+ model.config, MistralConfig
75
+ ):
76
+ text = text.strip() + " </s>"
77
+ elif isinstance(model.config, GemmaConfig):
78
+ text = text.strip() + "<eos>"
79
+ elif isinstance(model.config, Qwen2Config):
80
+ text = text.strip() + "<|endoftext|>"
81
+ return text
82
+
83
+
84
+ def initialize_peft(
85
+ model,
86
+ lora_r: int = 8,
87
+ lora_alpha: int = 16,
88
+ lora_dropout: float = 0.05,
89
+ lora_modules: Optional[List[str]] = None,
90
+ ):
91
+ if lora_modules is None and model.config.__class__.__name__ in [
92
+ "LlamaConfig",
93
+ "MistralConfig",
94
+ "GemmaConfig",
95
+ "Qwen2Config",
96
+ ]:
97
+ lora_modules = [
98
+ "q_proj",
99
+ "v_proj",
100
+ "k_proj",
101
+ "o_proj",
102
+ "gate_proj",
103
+ "up_proj",
104
+ "down_proj",
105
+ ]
106
+ elif lora_modules is None:
107
+ raise ValueError("lora_modules must be specified for this model.")
108
+
109
+ config = LoraConfig(
110
+ r=lora_r,
111
+ lora_alpha=lora_alpha,
112
+ target_modules=lora_modules,
113
+ lora_dropout=lora_dropout,
114
+ bias="none",
115
+ task_type=None,
116
+ )
117
+
118
+ model = get_peft_model(model, config)
119
+ print(f"Model's Lora trainable parameters:")
120
+ model.print_trainable_parameters()
121
+ return model
122
+
123
+
124
+ @dataclass
125
+ class ModelArguments:
126
+ """
127
+ Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch.
128
+ """
129
+
130
+ model_name_or_path: Optional[str] = field(
131
+ default=None,
132
+ metadata={
133
+ "help": (
134
+ "The base model checkpoint for weights initialization. Don't set if you want to train a model from scratch."
135
+ )
136
+ },
137
+ )
138
+ peft_model_name_or_path: Optional[str] = field(
139
+ default=None,
140
+ metadata={"help": ("The PEFT model checkpoint to add on top of base model.")},
141
+ )
142
+ bidirectional: Optional[bool] = field(
143
+ default=False,
144
+ metadata={
145
+ "help": (
146
+ "Whether to enable bidirectional attention in the model. If set to False, the model will use unidirectional attention."
147
+ )
148
+ },
149
+ )
150
+ max_seq_length: Optional[int] = field(
151
+ default=None,
152
+ metadata={
153
+ "help": (
154
+ "The maximum total input sequence length after tokenization. Sequences longer "
155
+ "than this will be truncated."
156
+ )
157
+ },
158
+ )
159
+ torch_dtype: Optional[str] = field(
160
+ default=None,
161
+ metadata={
162
+ "help": (
163
+ "Override the default `torch.dtype` and load the model under this dtype. If `auto` is passed, the "
164
+ "dtype will be automatically derived from the model's weights."
165
+ ),
166
+ "choices": ["auto", "bfloat16", "float16", "float32"],
167
+ },
168
+ )
169
+ attn_implementation: Optional[str] = field(
170
+ default="sdpa",
171
+ metadata={
172
+ "help": ("The attention implementation to use in the model."),
173
+ "choices": ["eager", "sdpa", "flash_attention_2"],
174
+ },
175
+ )
176
+ pooling_mode: Optional[str] = field(
177
+ default="mean",
178
+ metadata={
179
+ "help": ("The pooling mode to use in the model."),
180
+ "choices": ["mean", "weighted_mean", "eos_token"],
181
+ },
182
+ )
183
+
184
+
185
+ @dataclass
186
+ class DataTrainingArguments:
187
+ """
188
+ Arguments pertaining to what data we are going to input our model for training and eval.
189
+ """
190
+
191
+ dataset_name: Optional[str] = field(
192
+ default=None,
193
+ metadata={"help": "The name of the dataset to use. Options: E5"},
194
+ )
195
+ dataset_file_path: Optional[str] = field(
196
+ default=None, metadata={"help": "The input training data file or folder."}
197
+ )
198
+ # TODO: implement this
199
+ max_train_samples: Optional[int] = field(
200
+ default=None,
201
+ metadata={
202
+ "help": (
203
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
204
+ "value if set."
205
+ )
206
+ },
207
+ )
208
+
209
+
210
+ @dataclass
211
+ class CustomArguments:
212
+ """
213
+ Custom arguments for the script
214
+ """
215
+
216
+ lora_dropout: float = field(
217
+ default=0.05, metadata={"help": "The dropout rate for lora"}
218
+ )
219
+
220
+ lora_r: int = field(default=8, metadata={"help": "The r value for lora"})
221
+
222
+ stop_after_n_steps: int = field(
223
+ default=10000, metadata={"help": "Stop training after n steps"}
224
+ )
225
+
226
+ experiment_id: Optional[str] = field(
227
+ default=None, metadata={"help": "The experiment id"}
228
+ )
229
+
230
+ loss_class: Optional[str] = field(
231
+ default="HardNegativeNLLLoss",
232
+ metadata={
233
+ "help": "The loss class to use for training. Options: HardNegativeNLLLoss"
234
+ },
235
+ )
236
+
237
+ loss_scale: float = field(
238
+ default=50.0, metadata={"help": "The loss scale for the loss function"}
239
+ )
240
+
241
+
242
+ @dataclass
243
+ class DefaultCollator:
244
+ model: LLM2Vec
245
+
246
+ def __init__(self, model: LLM2Vec) -> None:
247
+ self.model = model
248
+
249
+ def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
250
+ batch = features
251
+ num_texts = len(batch[0].texts)
252
+ texts = [[] for _ in range(num_texts)]
253
+ labels = []
254
+
255
+ for example in batch:
256
+ for idx, text in enumerate(example.texts):
257
+ text = prepare_for_tokenization(
258
+ self.model, text, pooling_mode=self.model.pooling_mode
259
+ )
260
+ texts[idx].append(text)
261
+ labels.append(example.label)
262
+ labels = torch.tensor(labels)
263
+
264
+ sentence_features = []
265
+ for idx in range(num_texts):
266
+ tokenized = self.model.tokenize(texts[idx])
267
+ sentence_features.append(tokenized)
268
+
269
+ return sentence_features, labels
270
+
271
+
272
+ class StopTrainingCallback(TrainerCallback):
273
+ def __init__(self, stop_after_n_steps: int):
274
+ self.stop_after_n_steps = stop_after_n_steps
275
+
276
+ def on_step_end(self, args, state, control, **kwargs):
277
+ if state.global_step >= self.stop_after_n_steps:
278
+ control.should_training_stop = True
279
+
280
+
281
+ class LLM2VecSupervisedTrainer(Trainer):
282
+ def __init__(
283
+ self,
284
+ *args,
285
+ loss_function=None,
286
+ **kwargs,
287
+ ) -> None:
288
+ super().__init__(*args, **kwargs)
289
+ self.loss_function = loss_function
290
+
291
+ def compute_loss(
292
+ self,
293
+ model: nn.Module,
294
+ inputs: Dict[str, Union[torch.Tensor, Any]],
295
+ return_outputs: bool = False,
296
+ ) -> Union[Tuple[torch.Tensor, torch.Tensor], torch.Tensor]:
297
+ features, labels = inputs
298
+ q_reps = self.model(features[0])
299
+ d_reps = self.model(features[1])
300
+
301
+ d_reps_neg = None
302
+ if len(features) > 2:
303
+ d_reps_neg = self.model(features[2])
304
+
305
+ loss = self.loss_function(q_reps, d_reps, d_reps_neg)
306
+
307
+ if return_outputs:
308
+ output = torch.cat(
309
+ [model(row)["sentence_embedding"][:, None] for row in features], dim=1
310
+ )
311
+ return loss, output
312
+
313
+ return loss
314
+
315
+ def get_train_dataloader(self) -> DataLoader:
316
+ # Copying most of the code from the parent class, changing the sampler to SequentialSampler
317
+ if self.train_dataset is None:
318
+ raise ValueError("Trainer: training requires a train_dataset.")
319
+
320
+ train_dataset = self.train_dataset
321
+ data_collator = self.data_collator
322
+
323
+ data_collator = self._get_collator_with_removed_columns(
324
+ data_collator, description="training"
325
+ )
326
+
327
+ dataloader_params = {
328
+ "batch_size": self._train_batch_size,
329
+ "collate_fn": data_collator,
330
+ "num_workers": self.args.dataloader_num_workers,
331
+ "pin_memory": self.args.dataloader_pin_memory,
332
+ "persistent_workers": self.args.dataloader_persistent_workers,
333
+ }
334
+
335
+ if not isinstance(train_dataset, torch.utils.data.IterableDataset):
336
+ # Changing from random sampler to sequential sampler
337
+ dataloader_params["sampler"] = SequentialSampler(train_dataset)
338
+ dataloader_params["drop_last"] = self.args.dataloader_drop_last
339
+ dataloader_params["worker_init_fn"] = seed_worker
340
+
341
+ return self.accelerator.prepare(DataLoader(train_dataset, **dataloader_params))
342
+
343
+ def _save(self, output_dir: Optional[str] = None, state_dict=None):
344
+ # If we are executing this function, we are the process zero, so we don't check for that.
345
+ output_dir = output_dir if output_dir is not None else self.args.output_dir
346
+ os.makedirs(output_dir, exist_ok=True)
347
+ logger.info(f"Saving model checkpoint to {output_dir}")
348
+
349
+ self.model.save(output_dir)
350
+
351
+ # Good practice: save your training arguments together with the trained model
352
+ torch.save(self.args, os.path.join(output_dir, "training_args.bin"))
353
+
354
+
355
+ def main():
356
+ parser = HfArgumentParser(
357
+ (ModelArguments, DataTrainingArguments, TrainingArguments, CustomArguments)
358
+ )
359
+ if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
360
+ # If we pass only one argument to the script and it's the path to a json file,
361
+ # let's parse it to get our arguments.
362
+ model_args, data_args, training_args, custom_args = parser.parse_json_file(
363
+ json_file=os.path.abspath(sys.argv[1])
364
+ )
365
+ else:
366
+ (
367
+ model_args,
368
+ data_args,
369
+ training_args,
370
+ custom_args,
371
+ ) = parser.parse_args_into_dataclasses()
372
+ if training_args.ddp_find_unused_parameters:
373
+ kwargs = [
374
+ DistributedDataParallelKwargs(
375
+ dim=0,
376
+ broadcast_buffers=True,
377
+ bucket_cap_mb=25,
378
+ find_unused_parameters=True,
379
+ check_reduction=False,
380
+ gradient_as_bucket_view=False,
381
+ )
382
+ ]
383
+ else:
384
+ kwargs = []
385
+ accelerator = Accelerator(kwargs_handlers=kwargs)
386
+
387
+ set_seed(training_args.seed)
388
+
389
+ if training_args.gradient_checkpointing:
390
+ training_args.gradient_checkpointing_kwargs = {"use_reentrant": False}
391
+
392
+ if custom_args.experiment_id is not None:
393
+ experiment_id = custom_args.experiment_id
394
+ else:
395
+ experiment_id = generate_experiment_id(
396
+ name=data_args.dataset_name,
397
+ split="train",
398
+ model_name=(
399
+ model_args.model_name_or_path
400
+ if "/" not in model_args.model_name_or_path
401
+ else model_args.model_name_or_path.split("/")[-1]
402
+ ),
403
+ pooling_mode=model_args.pooling_mode,
404
+ train_batch_size=training_args.per_device_train_batch_size
405
+ * accelerator.num_processes
406
+ * training_args.gradient_accumulation_steps,
407
+ max_seq_length=model_args.max_seq_length,
408
+ bidirectional=model_args.bidirectional,
409
+ epochs=training_args.num_train_epochs,
410
+ seed=training_args.seed,
411
+ warmup_steps=training_args.warmup_steps,
412
+ lr=training_args.learning_rate,
413
+ lora_r=custom_args.lora_r,
414
+ )
415
+
416
+ training_args.output_dir = f"{training_args.output_dir}/{experiment_id}"
417
+
418
+ # TODO: can also pass separator arg here
419
+ train_dataset = load_dataset(
420
+ data_args.dataset_name,
421
+ split="train",
422
+ file_path=data_args.dataset_file_path,
423
+ effective_batch_size=training_args.per_device_train_batch_size
424
+ * accelerator.num_processes,
425
+ )
426
+
427
+ train_examples = [
428
+ train_dataset[i]
429
+ for i in tqdm(
430
+ range(len(train_dataset)),
431
+ desc="Loading train examples...",
432
+ disable=not accelerator.is_main_process,
433
+ )
434
+ ]
435
+
436
+ torch_dtype = (
437
+ model_args.torch_dtype
438
+ if model_args.torch_dtype in ["auto", None]
439
+ else getattr(torch, model_args.torch_dtype)
440
+ )
441
+ model = LLM2Vec.from_pretrained(
442
+ base_model_name_or_path=model_args.model_name_or_path,
443
+ enable_bidirectional=model_args.bidirectional,
444
+ peft_model_name_or_path=model_args.peft_model_name_or_path,
445
+ merge_peft=True,
446
+ pooling_mode=model_args.pooling_mode,
447
+ max_length=model_args.max_seq_length,
448
+ torch_dtype=torch_dtype,
449
+ attn_implementation=model_args.attn_implementation,
450
+ )
451
+
452
+ # model organization is LLM2VecModel.model -> HF Model, we have to apply PEFT to the inner model
453
+ model.model = initialize_peft(
454
+ model.model,
455
+ lora_r=custom_args.lora_r,
456
+ lora_alpha=2 * custom_args.lora_r,
457
+ lora_dropout=custom_args.lora_dropout,
458
+ )
459
+
460
+ tokenizer = model.tokenizer
461
+
462
+ train_loss = load_loss(custom_args.loss_class, scale=custom_args.loss_scale)
463
+
464
+ data_collator = DefaultCollator(model)
465
+
466
+ trainer = LLM2VecSupervisedTrainer(
467
+ model=model,
468
+ args=training_args,
469
+ train_dataset=train_examples,
470
+ data_collator=data_collator,
471
+ tokenizer=tokenizer,
472
+ loss_function=train_loss,
473
+ )
474
+
475
+ if custom_args.stop_after_n_steps is not None:
476
+ trainer.add_callback(StopTrainingCallback(custom_args.stop_after_n_steps))
477
+
478
+ trainer.train()
479
+
480
+
481
+ if __name__ == "__main__":
482
+ main()
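
Both this script and run_simcse.py accept a single *.json argument, which HfArgumentParser.parse_json_file maps onto the dataclasses defined above. A minimal launch sketch follows; every path and value is a placeholder, and the keys simply mirror the argument fields.

import json, subprocess, sys

# Placeholder configuration; keys correspond to ModelArguments, DataTrainingArguments,
# TrainingArguments and CustomArguments fields of run_supervised.py.
config = {
    "model_name_or_path": "mistralai/Mistral-7B-Instruct-v0.2",   # placeholder base model
    "peft_model_name_or_path": "path/to/mntp-lora-checkpoint",    # placeholder path
    "bidirectional": True,
    "pooling_mode": "mean",
    "dataset_name": "E5",
    "dataset_file_path": "path/to/training-data",                 # placeholder path
    "output_dir": "output/supervised",
    "per_device_train_batch_size": 32,
    "learning_rate": 2e-4,
    "num_train_epochs": 1,
    "lora_r": 16,
    "stop_after_n_steps": 1000,
    "seed": 42,
}
with open("supervised_config.json", "w") as f:
    json.dump(config, f, indent=2)

# Equivalent to: python experiments/run_supervised.py supervised_config.json
subprocess.run([sys.executable, "experiments/run_supervised.py", "supervised_config.json"], check=True)
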
llm2vec/experiments/run_word_task.py ADDED
@@ -0,0 +1,905 @@
1
+ """
2
+ The script is adapted from https://huggingface.co/docs/transformers/en/tasks/token_classification
3
+ """
4
+
5
+ import logging
6
+ import os
7
+ import sys
8
+ import warnings
9
+ from dataclasses import dataclass, field
10
+ import numpy as np
11
+ from typing import List, Optional, Tuple, Union
12
+
13
+ import datasets
14
+ import evaluate
15
+ from datasets import load_dataset
16
+
17
+ import torch
18
+ from torch import nn
19
+ from torch.nn import CrossEntropyLoss
20
+ import transformers
21
+ from transformers import (
22
+ PreTrainedModel,
23
+ MODEL_FOR_MASKED_LM_MAPPING,
24
+ AutoConfig,
25
+ AutoTokenizer,
26
+ HfArgumentParser,
27
+ Trainer,
28
+ TrainingArguments,
29
+ TrainerCallback,
30
+ set_seed,
31
+ AutoModelForTokenClassification,
32
+ DataCollatorForTokenClassification,
33
+ )
34
+
35
+ from transformers.modeling_outputs import TokenClassifierOutput
36
+ from transformers.utils import send_example_telemetry
37
+ from transformers.utils.versions import require_version
38
+
39
+ from llm2vec import LLM2Vec
40
+
41
+ require_version(
42
+ "datasets>=1.8.0",
43
+ "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt",
44
+ )
45
+
46
+
47
+ class ModelForWordTask(PreTrainedModel):
48
+ def __init__(self, config, model, merge_subwords=False, **model_args):
49
+ PreTrainedModel.__init__(self, config)
50
+ self.model = model
51
+ self.merge_subwords = merge_subwords
52
+
53
+ if (
54
+ hasattr(config, "classifier_dropout")
55
+ and config.classifier_dropout is not None
56
+ ):
57
+ classifier_dropout = config.classifier_dropout
58
+ elif hasattr(config, "hidden_dropout") and config.hidden_dropout is not None:
59
+ classifier_dropout = config.hidden_dropout
60
+ else:
61
+ classifier_dropout = 0.1
62
+
63
+ self.dropout = nn.Dropout(classifier_dropout)
64
+ self.num_labels = config.num_labels
65
+ self.classifier = nn.Linear(config.hidden_size, config.num_labels).to(
66
+ model_args.get("torch_dtype")
67
+ )
68
+
69
+ # Initialize weights and apply final processing
70
+ self.post_init()
71
+
72
+ def _merge_subwords(self, hidden_states, token_type_ids, attention_mask):
73
+ new_hidden_states = hidden_states.clone()
74
+ for b in range(hidden_states.shape[0]):
75
+ for w in torch.arange(0, token_type_ids[b].max() + 1):
76
+ words_w = (token_type_ids[b] == w) * (attention_mask[b] > 0)
77
+ new_hidden_states[b][words_w] = torch.mean(
78
+ hidden_states[b][words_w], dim=0
79
+ ).repeat(sum(words_w), 1)
80
+ return new_hidden_states
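
As a toy illustration of the subword merging above (all numbers invented): token_type_ids is assumed to carry a word index for every subword, and every subword of a word is overwritten with that word's mean hidden state.

import torch

hidden = torch.tensor([[[1.0, 1.0], [3.0, 3.0], [5.0, 5.0]]])   # 1 sequence, 3 subwords, dim 2
token_type_ids = torch.tensor([[0, 0, 1]])                       # subwords 0-1 form word 0
attention_mask = torch.ones(1, 3, dtype=torch.long)

merged = hidden.clone()
for b in range(hidden.shape[0]):
    for w in range(int(token_type_ids[b].max()) + 1):
        words_w = (token_type_ids[b] == w) * (attention_mask[b] > 0)
        merged[b][words_w] = torch.mean(hidden[b][words_w], dim=0).repeat(int(words_w.sum()), 1)
print(merged)   # word 0's two subwords both become [2., 2.]; word 1 keeps [5., 5.]
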
81
+
82
+ def forward(
83
+ self,
84
+ input_ids: torch.LongTensor = None,
85
+ attention_mask: Optional[torch.Tensor] = None,
86
+ position_ids: Optional[torch.LongTensor] = None,
87
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
88
+ inputs_embeds: Optional[torch.FloatTensor] = None,
89
+ use_cache: Optional[bool] = None,
90
+ output_attentions: Optional[bool] = None,
91
+ output_hidden_states: Optional[bool] = None,
92
+ return_dict: Optional[bool] = None,
93
+ token_type_ids: Optional[torch.LongTensor] = None,
94
+ head_mask: Optional[torch.FloatTensor] = None,
95
+ labels: Optional[torch.LongTensor] = None,
96
+ ) -> Union[Tuple, TokenClassifierOutput]:
97
+ output_attentions = (
98
+ output_attentions
99
+ if output_attentions is not None
100
+ else self.config.output_attentions
101
+ )
102
+ output_hidden_states = (
103
+ output_hidden_states
104
+ if output_hidden_states is not None
105
+ else self.config.output_hidden_states
106
+ )
107
+
108
+ return_dict = (
109
+ return_dict if return_dict is not None else self.config.use_return_dict
110
+ )
111
+
112
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
113
+ outputs = self.model(
114
+ input_ids=input_ids,
115
+ attention_mask=attention_mask,
116
+ position_ids=position_ids,
117
+ past_key_values=past_key_values,
118
+ inputs_embeds=inputs_embeds,
119
+ use_cache=use_cache,
120
+ output_attentions=output_attentions,
121
+ output_hidden_states=output_hidden_states,
122
+ return_dict=return_dict,
123
+ )
124
+
125
+ hidden_states = outputs[0]
126
+
127
+ if self.merge_subwords:
128
+ hidden_states = self._merge_subwords(
129
+ hidden_states, token_type_ids, attention_mask
130
+ )
131
+
132
+ hidden_states = self.dropout(hidden_states)
133
+ logits = self.classifier(hidden_states)
134
+
135
+ loss = None
136
+ if labels is not None:
137
+ labels = labels.to(logits.device)
138
+ loss_fct = CrossEntropyLoss()
139
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
140
+
141
+ if not return_dict:
142
+ output = (logits,) + outputs.hidden_states
143
+ return ((loss,) + output) if loss is not None else output
144
+
145
+ return TokenClassifierOutput(
146
+ loss=loss,
147
+ logits=logits,
148
+ hidden_states=hidden_states,
149
+ attentions=outputs.attentions,
150
+ )
151
+
152
+
153
+ logger = logging.getLogger(__name__)
154
+ MODEL_CONFIG_CLASSES = list(MODEL_FOR_MASKED_LM_MAPPING.keys())
155
+ MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
156
+ LABELS = {
157
+ "conll2003": {
158
+ "pos_tags": {
159
+ '"': 0,
160
+ "''": 1,
161
+ "#": 2,
162
+ "$": 3,
163
+ "(": 4,
164
+ ")": 5,
165
+ ",": 6,
166
+ ".": 7,
167
+ ":": 8,
168
+ "``": 9,
169
+ "CC": 10,
170
+ "CD": 11,
171
+ "DT": 12,
172
+ "EX": 13,
173
+ "FW": 14,
174
+ "IN": 15,
175
+ "JJ": 16,
176
+ "JJR": 17,
177
+ "JJS": 18,
178
+ "LS": 19,
179
+ "MD": 20,
180
+ "NN": 21,
181
+ "NNP": 22,
182
+ "NNPS": 23,
183
+ "NNS": 24,
184
+ "NN|SYM": 25,
185
+ "PDT": 26,
186
+ "POS": 27,
187
+ "PRP": 28,
188
+ "PRP$": 29,
189
+ "RB": 30,
190
+ "RBR": 31,
191
+ "RBS": 32,
192
+ "RP": 33,
193
+ "SYM": 34,
194
+ "TO": 35,
195
+ "UH": 36,
196
+ "VB": 37,
197
+ "VBD": 38,
198
+ "VBG": 39,
199
+ "VBN": 40,
200
+ "VBP": 41,
201
+ "VBZ": 42,
202
+ "WDT": 43,
203
+ "WP": 44,
204
+ "WP$": 45,
205
+ "WRB": 46,
206
+ },
207
+ "chunk_tags": {
208
+ "O": 0,
209
+ "B-ADJP": 1,
210
+ "I-ADJP": 2,
211
+ "B-ADVP": 3,
212
+ "I-ADVP": 4,
213
+ "B-CONJP": 5,
214
+ "I-CONJP": 6,
215
+ "B-INTJ": 7,
216
+ "I-INTJ": 8,
217
+ "B-LST": 9,
218
+ "I-LST": 10,
219
+ "B-NP": 11,
220
+ "I-NP": 12,
221
+ "B-PP": 13,
222
+ "I-PP": 14,
223
+ "B-PRT": 15,
224
+ "I-PRT": 16,
225
+ "B-SBAR": 17,
226
+ "I-SBAR": 18,
227
+ "B-UCP": 19,
228
+ "I-UCP": 20,
229
+ "B-VP": 21,
230
+ "I-VP": 22,
231
+ },
232
+ "ner_tags": {
233
+ "O": 0,
234
+ "B-PER": 1,
235
+ "I-PER": 2,
236
+ "B-ORG": 3,
237
+ "I-ORG": 4,
238
+ "B-LOC": 5,
239
+ "I-LOC": 6,
240
+ "B-MISC": 7,
241
+ "I-MISC": 8,
242
+ },
243
+ }
244
+ }
245
+
246
+
247
+ @dataclass
248
+ class ModelArguments:
249
+ """
250
+ Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch.
251
+ """
252
+
253
+ model_name_or_path: Optional[str] = field(
254
+ default=None,
255
+ metadata={},
256
+ )
257
+ config_overrides: Optional[str] = field(
258
+ default=None,
259
+ metadata={
260
+ "help": (
261
+ "Override some existing default config settings when a model is trained from scratch. Example: "
262
+ "n_embd=10,resid_pdrop=0.2,scale_attn_weights=false,summary_type=cls_index"
263
+ )
264
+ },
265
+ )
266
+ config_name: Optional[str] = field(
267
+ default=None,
268
+ metadata={
269
+ "help": "Pretrained config name or path if not the same as model_name"
270
+ },
271
+ )
272
+ tokenizer_name: Optional[str] = field(
273
+ default=None,
274
+ metadata={
275
+ "help": "Pretrained tokenizer name or path if not the same as model_name"
276
+ },
277
+ )
278
+ cache_dir: Optional[str] = field(
279
+ default=None,
280
+ metadata={
281
+ "help": "Where do you want to store the pretrained models downloaded from huggingface.co"
282
+ },
283
+ )
284
+ use_fast_tokenizer: bool = field(
285
+ default=True,
286
+ metadata={
287
+ "help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."
288
+ },
289
+ )
290
+ model_revision: str = field(
291
+ default="main",
292
+ metadata={
293
+ "help": "The specific model version to use (can be a branch name, tag name or commit id)."
294
+ },
295
+ )
296
+ token: str = field(
297
+ default=None,
298
+ metadata={
299
+ "help": (
300
+ "The token to use as HTTP bearer authorization for remote files. If not specified, will use the token "
301
+ "generated when running `huggingface-cli login` (stored in `~/.huggingface`)."
302
+ )
303
+ },
304
+ )
305
+ use_auth_token: bool = field(
306
+ default=None,
307
+ metadata={
308
+ "help": "The `use_auth_token` argument is deprecated and will be removed in v4.34. Please use `token` instead."
309
+ },
310
+ )
311
+ trust_remote_code: bool = field(
312
+ default=False,
313
+ metadata={
314
+ "help": (
315
+ "Whether or not to allow for custom models defined on the Hub in their own modeling files. This option "
316
+ "should only be set to `True` for repositories you trust and in which you have read the code, as it will "
317
+ "execute code present on the Hub on your local machine."
318
+ )
319
+ },
320
+ )
321
+ low_cpu_mem_usage: bool = field(
322
+ default=False,
323
+ metadata={
324
+ "help": (
325
+ "It is an option to create the model as an empty shell, then only materialize its parameters when the pretrained weights are loaded. "
326
+ "set True will benefit LLM loading time and RAM consumption."
327
+ )
328
+ },
329
+ )
330
+ torch_dtype: Optional[str] = field(
331
+ default=None,
332
+ metadata={
333
+ "help": (
334
+ "Override the default `torch.dtype` and load the model under this dtype. If `auto` is passed, the "
335
+ "dtype will be automatically derived from the model's weights."
336
+ ),
337
+ "choices": ["auto", "bfloat16", "float16", "float32"],
338
+ },
339
+ )
340
+ attn_implementation: Optional[str] = field(
341
+ default="sdpa",
342
+ metadata={
343
+ "help": ("The attention implementation to use in the model."),
344
+ "choices": ["eager", "sdpa", "flash_attention_2"],
345
+ },
346
+ )
347
+ classifier_dropout: Optional[float] = field(
348
+ default=0.1, metadata={"help": "The dropout rate for models"}
349
+ )
350
+ peft_addr: Optional[str] = field(
351
+ default=None, metadata={"help": "addr of lora adapter weights"}
352
+ )
353
+ model_class: str = field(
354
+ default="custom",
355
+ metadata={
356
+ "help": "One of the items 'custom' or 'auto'. 'custom' for LLM2Vec models and 'auto' for pretrained encoders such as BERT.",
357
+ "choices": ["custom", "auto"],
358
+ },
359
+ )
360
+ merge_subwords: bool = field(
361
+ default=True,
362
+ metadata={"help": "Whether the representations of the subtokens get averaged."},
363
+ )
364
+ bidirectional: bool = field(
365
+ default=True, metadata={"help": "Whether to use bidirectional attention."}
366
+ )
367
+
368
+ def __post_init__(self):
369
+ if self.config_overrides is not None and (
370
+ self.config_name is not None or self.model_name_or_path is not None
371
+ ):
372
+ raise ValueError(
373
+ "--config_overrides can't be used in combination with --config_name or --model_name_or_path"
374
+ )
375
+
376
+
377
+ @dataclass
378
+ class DataTrainingArguments:
379
+ """
380
+ Arguments pertaining to what data we are going to input our model for training and eval.
381
+ """
382
+
383
+ dataset_name: Optional[str] = field(
384
+ default=None,
385
+ metadata={"help": "The name of the dataset to use (via the datasets library)."},
386
+ )
387
+ dataset_config_name: Optional[str] = field(
388
+ default=None,
389
+ metadata={
390
+ "help": "The configuration name of the dataset to use (via the datasets library)."
391
+ },
392
+ )
393
+ train_file: Optional[str] = field(
394
+ default=None, metadata={"help": "The input training data file (a text file)."}
395
+ )
396
+ validation_file: Optional[str] = field(
397
+ default=None,
398
+ metadata={
399
+ "help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."
400
+ },
401
+ )
402
+ overwrite_cache: bool = field(
403
+ default=True,
404
+ metadata={"help": "Overwrite the cached training and evaluation sets"},
405
+ )
406
+ validation_split_percentage: Optional[int] = field(
407
+ default=5,
408
+ metadata={
409
+ "help": "The percentage of the train set used as validation set in case there's no validation split"
410
+ },
411
+ )
412
+ max_seq_length: Optional[int] = field(
413
+ default=None,
414
+ metadata={
415
+ "help": (
416
+ "The maximum total input sequence length after tokenization. Sequences longer "
417
+ "than this will be truncated."
418
+ )
419
+ },
420
+ )
421
+ preprocessing_num_workers: Optional[int] = field(
422
+ default=None,
423
+ metadata={"help": "The number of processes to use for the preprocessing."},
424
+ )
425
+ mlm_probability: float = field(
426
+ default=0.15,
427
+ metadata={"help": "Ratio of tokens to mask for masked language modeling loss"},
428
+ )
429
+ line_by_line: bool = field(
430
+ default=False,
431
+ metadata={
432
+ "help": "Whether distinct lines of text in the dataset are to be handled as distinct sequences."
433
+ },
434
+ )
435
+ pad_to_max_length: bool = field(
436
+ default=False,
437
+ metadata={
438
+ "help": (
439
+ "Whether to pad all samples to `max_seq_length`. "
440
+ "If False, will pad the samples dynamically when batching to the maximum length in the batch."
441
+ )
442
+ },
443
+ )
444
+ max_train_samples: Optional[int] = field(
445
+ default=None,
446
+ metadata={
447
+ "help": (
448
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
449
+ "value if set."
450
+ )
451
+ },
452
+ )
453
+ max_eval_samples: Optional[int] = field(
454
+ default=None,
455
+ metadata={
456
+ "help": (
457
+ "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
458
+ "value if set."
459
+ )
460
+ },
461
+ )
462
+ streaming: bool = field(default=False, metadata={"help": "Enable streaming mode"})
463
+
464
+ def __post_init__(self):
465
+ if self.streaming:
466
+ require_version(
467
+ "datasets>=2.0.0", "The streaming feature requires `datasets>=2.0.0`"
468
+ )
469
+
470
+ if (
471
+ self.dataset_name is None
472
+ and self.train_file is None
473
+ and self.validation_file is None
474
+ ):
475
+ raise ValueError(
476
+ "Need either a dataset name or a training/validation file."
477
+ )
478
+ else:
479
+ if self.train_file is not None:
480
+ extension = self.train_file.split(".")[-1]
481
+ if extension not in ["csv", "json", "txt"]:
482
+ raise ValueError(
483
+ "`train_file` should be a csv, a json or a txt file."
484
+ )
485
+ if self.validation_file is not None:
486
+ extension = self.validation_file.split(".")[-1]
487
+ if extension not in ["csv", "json", "txt"]:
488
+ raise ValueError(
489
+ "`validation_file` should be a csv, a json or a txt file."
490
+ )
491
+
492
+
493
+ # Additional script-specific arguments
494
+ @dataclass
495
+ class CustomArguments:
496
+ """
497
+ Custom arguments for the script
498
+ """
499
+
500
+ stop_after_n_steps: int = field(
501
+ default=10000, metadata={"help": "Stop training after n steps"}
502
+ )
503
+ data_collator_type: str = field(
504
+ default="custom",
505
+ metadata={
506
+ "help": "The type of data collator. Options: custom, default, custom_no_random"
507
+ },
508
+ )
509
+ task: Optional[str] = field(
510
+ default="pos_tags",
511
+ metadata={
512
+ "help": "One of the 'pos_tags', 'chunk_tags', and 'ner_tags' choices",
513
+ "choices": ["pos_tags", "ner_tags", "chunk_tags"],
514
+ },
515
+ )
516
+ retroactive_labels: str = field(
517
+ default="next_token",
518
+ metadata={
519
+ "help": "Whether a token's representation is used to predict its own label or the next token's label. Options: same_token, next_token.",
520
+ "choices": ["next_token", "same_token"],
521
+ },
522
+ )
523
+
524
+
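+ # Callback that stops training once global_step reaches stop_after_n_steps, regardless of num_train_epochs.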
525
+ class StopTrainingCallback(TrainerCallback):
526
+ def __init__(self, stop_after_n_steps: int):
527
+ self.stop_after_n_steps = stop_after_n_steps
528
+
529
+ def on_step_end(self, args, state, control, **kwargs):
530
+ if state.global_step >= self.stop_after_n_steps:
531
+ control.should_training_stop = True
532
+
533
+
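+ # Trainer variant whose checkpoints contain only the classifier head, the tokenizer, and the training args; the frozen backbone weights are not saved.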
534
+ class WordTaskTrainer(Trainer):
535
+ def _save(self, output_dir: Optional[str] = None, state_dict=None):
536
+ # If we are executing this function, we are the process zero, so we don't check for that.
537
+ output_dir = output_dir if output_dir is not None else self.args.output_dir
538
+ os.makedirs(output_dir, exist_ok=True)
539
+ logger.info(f"Saving model checkpoint to {output_dir}")
540
+
541
+ torch.save(self.model.classifier, os.path.join(output_dir, "classifier.pt"))
542
+ self.tokenizer.save_pretrained(output_dir)
543
+
544
+ # Good practice: save your training arguments together with the trained model
545
+ torch.save(self.args, os.path.join(output_dir, "training_args.bin"))
546
+
547
+
548
+ def main():
549
+ parser = HfArgumentParser(
550
+ (ModelArguments, DataTrainingArguments, TrainingArguments, CustomArguments)
551
+ )
552
+ # model_args, data_args, training_args, custom_args = parser.parse_args_into_dataclasses()
553
+ if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
554
+ # If we pass only one argument to the script and it's the path to a json file,
555
+ # let's parse it to get our arguments.
556
+ model_args, data_args, training_args, custom_args = parser.parse_json_file(
557
+ json_file=os.path.abspath(sys.argv[1])
558
+ )
559
+ else:
560
+ (
561
+ model_args,
562
+ data_args,
563
+ training_args,
564
+ custom_args,
565
+ ) = parser.parse_args_into_dataclasses()
566
+
567
+ if training_args.gradient_checkpointing:
568
+ training_args.gradient_checkpointing_kwargs = {"use_reentrant": False}
569
+
570
+ if model_args.use_auth_token is not None:
571
+ warnings.warn(
572
+ "The `use_auth_token` argument is deprecated and will be removed in v4.34. Please use `token` instead.",
573
+ FutureWarning,
574
+ )
575
+ if model_args.token is not None:
576
+ raise ValueError(
577
+ "`token` and `use_auth_token` are both specified. Please set only the argument `token`."
578
+ )
579
+ model_args.token = model_args.use_auth_token
580
+
581
+ # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
582
+ # information sent is the one passed as arguments along with your Python/PyTorch versions.
583
+ send_example_telemetry("run_word_task", model_args, data_args)
584
+
585
+ # Setup logging
586
+ logging.basicConfig(
587
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
588
+ datefmt="%m/%d/%Y %H:%M:%S",
589
+ handlers=[logging.StreamHandler(sys.stdout)],
590
+ )
591
+
592
+ if training_args.should_log:
593
+ # The default of training_args.log_level is passive, so we set log level at info here to have that default.
594
+ transformers.utils.logging.set_verbosity_info()
595
+
596
+ log_level = training_args.get_process_log_level()
597
+ logger.setLevel(log_level)
598
+ datasets.utils.logging.set_verbosity(log_level)
599
+ transformers.utils.logging.set_verbosity(log_level)
600
+ transformers.utils.logging.enable_default_handler()
601
+ transformers.utils.logging.enable_explicit_format()
602
+
603
+ # Log on each process the small summary:
604
+ logger.warning(
605
+ f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}, "
606
+ + f"distributed training: {training_args.parallel_mode.value == 'distributed'}, 16-bits training: {training_args.fp16}"
607
+ )
608
+ # Set the verbosity to info of the Transformers logger (on main process only):
609
+ logger.info(f"Training/evaluation parameters {training_args}")
610
+
611
+ # Set seed before initializing model.
612
+ set_seed(training_args.seed)
613
+
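+ # Load the data either from the Hub (dataset_name) or from local csv/json/txt files; when no validation split exists, a validation_split_percentage slice of the train split is carved out instead.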
614
+ if data_args.dataset_name is not None:
615
+ # Downloading and loading a dataset from the hub.
616
+ raw_datasets = load_dataset(
617
+ data_args.dataset_name,
618
+ data_args.dataset_config_name,
619
+ cache_dir=model_args.cache_dir,
620
+ token=model_args.token,
621
+ streaming=data_args.streaming,
622
+ )
623
+ if "validation" not in raw_datasets.keys():
624
+ raw_datasets["validation"] = load_dataset(
625
+ data_args.dataset_name,
626
+ data_args.dataset_config_name,
627
+ split=f"train[:{data_args.validation_split_percentage}%]",
628
+ cache_dir=model_args.cache_dir,
629
+ token=model_args.token,
630
+ streaming=data_args.streaming,
631
+ )
632
+ raw_datasets["train"] = load_dataset(
633
+ data_args.dataset_name,
634
+ data_args.dataset_config_name,
635
+ split=f"train[{data_args.validation_split_percentage}%:]",
636
+ cache_dir=model_args.cache_dir,
637
+ token=model_args.token,
638
+ streaming=data_args.streaming,
639
+ )
640
+ else:
641
+ data_files = {}
642
+ if data_args.train_file is not None:
643
+ data_files["train"] = data_args.train_file
644
+ extension = data_args.train_file.split(".")[-1]
645
+ if data_args.validation_file is not None:
646
+ data_files["validation"] = data_args.validation_file
647
+ extension = data_args.validation_file.split(".")[-1]
648
+ if extension == "txt":
649
+ extension = "text"
650
+ raw_datasets = load_dataset(
651
+ extension,
652
+ data_files=data_files,
653
+ cache_dir=model_args.cache_dir,
654
+ token=model_args.token,
655
+ )
656
+
657
+ # If no validation data is there, validation_split_percentage will be used to divide the dataset.
658
+ if "validation" not in raw_datasets.keys():
659
+ raw_datasets["validation"] = load_dataset(
660
+ extension,
661
+ data_files=data_files,
662
+ split=f"train[:{data_args.validation_split_percentage}%]",
663
+ cache_dir=model_args.cache_dir,
664
+ token=model_args.token,
665
+ )
666
+ raw_datasets["train"] = load_dataset(
667
+ extension,
668
+ data_files=data_files,
669
+ split=f"train[{data_args.validation_split_percentage}%:]",
670
+ cache_dir=model_args.cache_dir,
671
+ token=model_args.token,
672
+ )
673
+
674
+ assert (
675
+ data_args.dataset_name in LABELS
676
+ and custom_args.task in LABELS[data_args.dataset_name]
677
+ ), f"LABELS[{data_args.dataset_name}][{custom_args.task}] is not defined."
678
+
679
+ config_kwargs = {
680
+ "num_labels": len(LABELS[data_args.dataset_name][custom_args.task]),
681
+ "id2label": {
682
+ i: lab
683
+ for (lab, i) in LABELS[data_args.dataset_name][custom_args.task].items()
684
+ },
685
+ "label2id": LABELS[data_args.dataset_name][custom_args.task],
686
+ "classifier_dropout": model_args.classifier_dropout,
687
+ }
688
+
689
+ tokenizer_kwargs = {
690
+ "cache_dir": model_args.cache_dir,
691
+ "use_fast": model_args.use_fast_tokenizer,
692
+ "revision": model_args.model_revision,
693
+ "token": model_args.token,
694
+ "trust_remote_code": model_args.trust_remote_code,
695
+ }
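+ # GPT-style BPE tokenizers need add_prefix_space=True to accept pre-tokenized input (is_split_into_words=True).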
696
+ if model_args.tokenizer_name:
697
+ if "gpt" in model_args.tokenizer_name:
698
+ tokenizer_kwargs["add_prefix_space"] = True
699
+ tokenizer = AutoTokenizer.from_pretrained(
700
+ model_args.tokenizer_name, **tokenizer_kwargs
701
+ )
702
+ elif model_args.model_name_or_path:
703
+ if "gpt" in model_args.model_name_or_path:
704
+ tokenizer_kwargs["add_prefix_space"] = True
705
+ tokenizer = AutoTokenizer.from_pretrained(
706
+ model_args.model_name_or_path, **tokenizer_kwargs
707
+ )
708
+ else:
709
+ raise ValueError(
710
+ "You are instantiating a new tokenizer from scratch. This is not supported by this script. "
711
+ "You can do it from another script, save it, and load it from here, using --tokenizer_name."
712
+ )
713
+
714
+ if tokenizer.pad_token is None:
715
+ tokenizer.pad_token = tokenizer.eos_token
716
+ if model_args.model_class == "custom":
717
+ tokenizer.model_input_names.append("token_type_ids")
718
+ if model_args.model_class == "auto":
719
+ assert not model_args.merge_subwords
720
+
721
+ if model_args.model_class == "custom":
722
+ if model_args.config_name:
723
+ config = AutoConfig.from_pretrained(model_args.config_name, **config_kwargs)
724
+ elif model_args.model_name_or_path:
725
+ config = AutoConfig.from_pretrained(
726
+ model_args.model_name_or_path, **config_kwargs
727
+ )
728
+ else:
729
+ raise ValueError("Invalid config loading")
730
+
731
+ for k, v in config_kwargs.items():
732
+ config.__setattr__(k, v)
733
+
734
+ torch_dtype = (
735
+ model_args.torch_dtype
736
+ if model_args.torch_dtype in ["auto", None]
737
+ else getattr(torch, model_args.torch_dtype)
738
+ )
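+ # Load the LLM2Vec backbone with bidirectional attention controlled by --bidirectional and, if given, the LoRA adapter from peft_addr attached without merging it into the base weights.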
739
+ l2v = LLM2Vec.from_pretrained(
740
+ base_model_name_or_path=model_args.model_name_or_path,
741
+ enable_bidirectional=model_args.bidirectional,
742
+ peft_model_name_or_path=model_args.peft_addr,
743
+ merge_peft=False,
744
+ torch_dtype=torch_dtype,
745
+ attn_implementation=model_args.attn_implementation,
746
+ )
747
+
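+ # Wrap the backbone with a token-classification head; merge_subwords controls whether subtoken representations are averaged into word-level representations.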
748
+ model = ModelForWordTask(
749
+ model=l2v.model,
750
+ merge_subwords=model_args.merge_subwords,
751
+ config=config,
752
+ torch_dtype=torch_dtype,
753
+ )
754
+
755
+ MyTrainer = WordTaskTrainer
756
+
757
+ elif model_args.model_class == "auto":
758
+ model = AutoModelForTokenClassification.from_pretrained(
759
+ model_args.model_name_or_path,
760
+ num_labels=config_kwargs["num_labels"],
761
+ id2label=config_kwargs["id2label"],
762
+ label2id=config_kwargs["label2id"],
763
+ )
764
+ MyTrainer = Trainer
765
+
766
+ else:
767
+ raise ValueError(
768
+ f"{model_args.model_class} is not implemented. Only 'auto' and 'custom' model_class options are valid."
769
+ )
770
+
771
+ # only train classifier
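+ # (the LLM2Vec backbone, including any LoRA adapters, stays frozen; only the linear head receives gradients)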
772
+ for n, p in list(model.named_parameters()):
773
+ if "classifier" in n:
774
+ p.requires_grad = True
775
+ else:
776
+ p.requires_grad = False
777
+
778
+ if data_args.max_seq_length is None:
779
+ max_seq_length = tokenizer.model_max_length
780
+ if max_seq_length > 1024:
781
+ logger.warning(
782
+ "The chosen tokenizer supports a `model_max_length` that is longer than the default `block_size` value"
783
+ " of 1024. If you would like to use a longer `block_size` up to `tokenizer.model_max_length` you can"
784
+ " override this default with `--block_size xxx`."
785
+ )
786
+ max_seq_length = 1024
787
+ else:
788
+ if data_args.max_seq_length > tokenizer.model_max_length:
789
+ logger.warning(
790
+ f"The max_seq_length passed ({data_args.max_seq_length}) is larger than the maximum length for the "
791
+ f"model ({tokenizer.model_max_length}). Using max_seq_length={tokenizer.model_max_length}."
792
+ )
793
+ max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length)
794
+
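+ # Align word-level tags with subword tokens: only the first subword of each word keeps the word's label, every other position gets -100 and is ignored by the loss.
+ # With retroactive_labels="next_token", labels and word ids are shifted one position to the left, matching a causal next-token prediction setup.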
795
+ def tokenize_and_align_labels(examples):
796
+ task = custom_args.task
797
+ padding = "max_length" if data_args.pad_to_max_length else False
798
+ tokenized_inputs = tokenizer(
799
+ examples["tokens"],
800
+ truncation=True,
801
+ is_split_into_words=True,
802
+ padding=padding,
803
+ max_length=max_seq_length,
804
+ )
805
+
806
+ labels = []
807
+ words = []
808
+ for i, label in enumerate(examples[task]):
809
+ if custom_args.retroactive_labels in ["same_token"]:
810
+ word_ids = tokenized_inputs.word_ids(batch_index=i)
811
+ previous_word_idx = None
812
+ label_ids = []
813
+ for word_idx in word_ids:
814
+ if word_idx is None:
815
+ label_ids.append(-100)
816
+ elif word_idx != previous_word_idx:
817
+ label_ids.append(label[word_idx])
818
+ else:
819
+ label_ids.append(-100)
820
+ previous_word_idx = word_idx
821
+ labels.append(label_ids)
822
+ word_ids = [-1 if w is None else w for w in word_ids]
823
+ words.append(word_ids)
824
+ elif custom_args.retroactive_labels == "next_token":
825
+ word_ids = tokenized_inputs.word_ids(batch_index=i)
826
+ previous_word_idx = None
827
+ label_ids = []
828
+ for word_idx in word_ids:
829
+ if word_idx is None:
830
+ label_ids.append(-100)
831
+ elif word_idx != previous_word_idx:
832
+ label_ids.append(label[word_idx])
833
+ else:
834
+ label_ids.append(-100)
835
+ previous_word_idx = word_idx
836
+ label_ids.append(-100)
837
+ labels.append(label_ids[1:])
838
+ word_ids = word_ids[1:] + [None]
839
+ word_ids = [-1 if w is None else w for w in word_ids]
840
+ words.append(word_ids)
841
+ else:
842
+ raise ValueError(
843
+ f"retroactive_labels {custom_args.retroactive_labels} is not implemented."
844
+ )
845
+
846
+ tokenized_inputs["labels"] = labels
847
+ if model_args.model_class == "custom":
848
+ tokenized_inputs["token_type_ids"] = words
849
+ return tokenized_inputs
850
+
851
+ tokenized_dataset = raw_datasets.map(
852
+ tokenize_and_align_labels,
853
+ batched=True,
854
+ remove_columns=list(LABELS[data_args.dataset_name].keys()) + ["tokens", "id"],
855
+ load_from_cache_file=not data_args.overwrite_cache,
856
+ )
857
+ data_collator = DataCollatorForTokenClassification(tokenizer=tokenizer)
858
+ seqeval = evaluate.load("seqeval")
859
+
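+ # The first element of p.predictions holds the logits; argmax over the label dimension, map ids back to tag strings while dropping -100 positions, then score with seqeval.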
860
+ def compute_metrics(p):
861
+ predictions, labels = p
862
+ predictions = predictions[0]
863
+ predictions = np.argmax(predictions, axis=2)
864
+
865
+ true_predictions = [
866
+ [
867
+ config_kwargs["id2label"][p]
868
+ for (p, l) in zip(prediction, label)
869
+ if l != -100
870
+ ]
871
+ for prediction, label in zip(predictions, labels)
872
+ ]
873
+ true_labels = [
874
+ [
875
+ config_kwargs["id2label"][l]
876
+ for (p, l) in zip(prediction, label)
877
+ if l != -100
878
+ ]
879
+ for prediction, label in zip(predictions, labels)
880
+ ]
881
+
882
+ results = seqeval.compute(predictions=true_predictions, references=true_labels)
883
+ return {
884
+ "precision": results["overall_precision"],
885
+ "recall": results["overall_recall"],
886
+ "f1": results["overall_f1"],
887
+ "accuracy": results["overall_accuracy"],
888
+ }
889
+
890
+ trainer = MyTrainer(
891
+ model=model,
892
+ args=training_args,
893
+ train_dataset=tokenized_dataset["train"],
894
+ eval_dataset=tokenized_dataset["validation"],
895
+ tokenizer=tokenizer,
896
+ data_collator=data_collator,
897
+ compute_metrics=compute_metrics,
898
+ )
899
+ trainer.add_callback(StopTrainingCallback(custom_args.stop_after_n_steps))
900
+
901
+ trainer.train()
902
+
903
+
904
+ if __name__ == "__main__":
905
+ main()
llm2vec/experiments/test_word_task.py ADDED
@@ -0,0 +1,393 @@
1
+ import os
2
+ import sys
3
+ import logging
4
+ import argparse
5
+ from transformers import (
6
+ AutoTokenizer,
7
+ AutoConfig,
8
+ AutoModelForTokenClassification,
9
+ set_seed,
10
+ HfArgumentParser,
11
+ )
12
+ import torch
13
+ from datasets import load_dataset
14
+ import evaluate
15
+ import json
16
+ from tqdm import tqdm
17
+ from run_word_task import ModelForWordTask
18
+ from llm2vec import LLM2Vec
19
+
20
+
21
+ LABELS = {
22
+ "conll2003": {
23
+ "pos_tags": {
24
+ '"': 0,
25
+ "''": 1,
26
+ "#": 2,
27
+ "$": 3,
28
+ "(": 4,
29
+ ")": 5,
30
+ ",": 6,
31
+ ".": 7,
32
+ ":": 8,
33
+ "``": 9,
34
+ "CC": 10,
35
+ "CD": 11,
36
+ "DT": 12,
37
+ "EX": 13,
38
+ "FW": 14,
39
+ "IN": 15,
40
+ "JJ": 16,
41
+ "JJR": 17,
42
+ "JJS": 18,
43
+ "LS": 19,
44
+ "MD": 20,
45
+ "NN": 21,
46
+ "NNP": 22,
47
+ "NNPS": 23,
48
+ "NNS": 24,
49
+ "NN|SYM": 25,
50
+ "PDT": 26,
51
+ "POS": 27,
52
+ "PRP": 28,
53
+ "PRP$": 29,
54
+ "RB": 30,
55
+ "RBR": 31,
56
+ "RBS": 32,
57
+ "RP": 33,
58
+ "SYM": 34,
59
+ "TO": 35,
60
+ "UH": 36,
61
+ "VB": 37,
62
+ "VBD": 38,
63
+ "VBG": 39,
64
+ "VBN": 40,
65
+ "VBP": 41,
66
+ "VBZ": 42,
67
+ "WDT": 43,
68
+ "WP": 44,
69
+ "WP$": 45,
70
+ "WRB": 46,
71
+ },
72
+ "chunk_tags": {
73
+ "O": 0,
74
+ "B-ADJP": 1,
75
+ "I-ADJP": 2,
76
+ "B-ADVP": 3,
77
+ "I-ADVP": 4,
78
+ "B-CONJP": 5,
79
+ "I-CONJP": 6,
80
+ "B-INTJ": 7,
81
+ "I-INTJ": 8,
82
+ "B-LST": 9,
83
+ "I-LST": 10,
84
+ "B-NP": 11,
85
+ "I-NP": 12,
86
+ "B-PP": 13,
87
+ "I-PP": 14,
88
+ "B-PRT": 15,
89
+ "I-PRT": 16,
90
+ "B-SBAR": 17,
91
+ "I-SBAR": 18,
92
+ "B-UCP": 19,
93
+ "I-UCP": 20,
94
+ "B-VP": 21,
95
+ "I-VP": 22,
96
+ },
97
+ "ner_tags": {
98
+ "O": 0,
99
+ "B-PER": 1,
100
+ "I-PER": 2,
101
+ "B-ORG": 3,
102
+ "I-ORG": 4,
103
+ "B-LOC": 5,
104
+ "I-LOC": 6,
105
+ "B-MISC": 7,
106
+ "I-MISC": 8,
107
+ },
108
+ }
109
+ }
110
+
111
+
112
+ def str2bool(v):
113
+ if isinstance(v, bool):
114
+ return v
115
+ if v.lower() in ("yes", "true", "t", "y", "1"):
116
+ return True
117
+ elif v.lower() in ("no", "false", "f", "n", "0"):
118
+ return False
119
+ else:
120
+ raise argparse.ArgumentTypeError("Boolean value expected.")
121
+
122
+
123
+ if __name__ == "__main__":
124
+ logging.basicConfig(level=logging.INFO)
125
+ parser = argparse.ArgumentParser()
126
+ parser.add_argument("--model_class", default="custom", type=str)
127
+ parser.add_argument("--model_name_or_path", default=None, type=str)
128
+ parser.add_argument(
129
+ "--peft_addr",
130
+ default=None,
131
+ type=str,
132
+ help="The dir address where adapter_model.bin is saved.",
133
+ )
134
+ parser.add_argument(
135
+ "--cls_addr",
136
+ default=None,
137
+ type=str,
138
+ help="The dir address where classifier is saved.",
139
+ )
140
+ parser.add_argument("--bidirectional", default=True, type=str2bool)
141
+ parser.add_argument("--merge_subwords", default=True, type=str2bool)
142
+ parser.add_argument("--output_dir", default=None, type=str)
143
+ parser.add_argument("--classifier_dropout", default=0.1, type=float)
144
+ parser.add_argument(
145
+ "--attn_implementation",
146
+ default="sdpa",
147
+ type=str,
148
+ choices=["sdpa", "eager", "flash_attention_2"],
149
+ )
150
+ parser.add_argument(
151
+ "--torch_dtype",
152
+ default=None,
153
+ type=str,
154
+ choices=["auto", "bfloat16", "float16", "float32"],
155
+ )
156
+
157
+ parser.add_argument(
158
+ "--retroactive_labels",
159
+ default="next_token",
160
+ type=str,
161
+ choices=["next_token", "same_token"],
162
+ )
163
+ parser.add_argument("--dataset_name", default=None, type=str)
164
+ parser.add_argument(
165
+ "--task", default=None, type=str, choices=["pos_tags", "chunk_tags", "ner_tags"]
166
+ )
167
+ parser.add_argument("--max_seq_length", default=1024, type=int)
168
+ parser.add_argument("--batch_size", default=32, type=int)
169
+ parser.add_argument("--seed", default=32, type=int)
170
+
171
+ parser.add_argument("--config_file", default=None, type=str)
172
+
173
+ args = parser.parse_args()
174
+
175
+ if args.config_file is not None:
176
+ # If --config_file points to a JSON file, load it and let its values
177
+ # override the parsed command-line arguments.
178
+ from pathlib import Path
179
+ import json
180
+
181
+ json_text = json.load(open(os.path.abspath(args.config_file)))
182
+ argparse_dict = vars(args)
183
+ argparse_dict.update(json_text)
184
+ # args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
185
+ else:
186
+ args = parser.parse_args()
187
+
188
+ path_to_check = args.peft_addr if args.peft_addr else args.model_name_or_path
189
+ assert (
190
+ args.output_dir is not None
191
+ ), "If you want to evaluate a model, you have to provide the output_dir"
192
+ os.makedirs(args.output_dir, exist_ok=True)
193
+
194
+ set_seed(args.seed)
195
+
196
+ tokenizer_kwargs = {}
197
+ if "gpt" in args.model_name_or_path:
198
+ tokenizer_kwargs["add_prefix_space"] = True
199
+ tokenizer = AutoTokenizer.from_pretrained(
200
+ args.model_name_or_path, **tokenizer_kwargs
201
+ )
202
+ if tokenizer.pad_token_id is None:
203
+ tokenizer.pad_token = tokenizer.eos_token
204
+
205
+ if args.model_class == "custom":
206
+ tokenizer.model_input_names.append("token_type_ids")
207
+
208
+ if args.model_class == "auto":
209
+ assert not args.merge_subwords
210
+
211
+ assert (
212
+ args.dataset_name in LABELS and args.task in LABELS[args.dataset_name]
213
+ ), f"LABELS[{args.dataset_name}][{args.task}] is not defined."
214
+
215
+ config_kwargs = {
216
+ "num_labels": len(LABELS[args.dataset_name][args.task]),
217
+ "id2label": {
218
+ i: lab for (lab, i) in LABELS[args.dataset_name][args.task].items()
219
+ },
220
+ "label2id": LABELS[args.dataset_name][args.task],
221
+ "classifier_dropout": args.classifier_dropout,
222
+ }
223
+
224
+ if args.model_class == "custom":
225
+ if args.model_name_or_path:
226
+ config = AutoConfig.from_pretrained(
227
+ args.model_name_or_path, **config_kwargs
228
+ )
229
+ else:
230
+ raise ValueError("Invalid config loading")
231
+
232
+ for k, v in config_kwargs.items():
233
+ config.__setattr__(k, v)
234
+
235
+ torch_dtype = (
236
+ args.torch_dtype
237
+ if args.torch_dtype in ["auto", None]
238
+ else getattr(torch, args.torch_dtype)
239
+ )
240
+ l2v = LLM2Vec.from_pretrained(
241
+ base_model_name_or_path=args.model_name_or_path,
242
+ enable_bidirectional=args.bidirectional,
243
+ peft_model_name_or_path=args.peft_addr,
244
+ merge_peft=False,
245
+ torch_dtype=torch_dtype,
246
+ attn_implementation=args.attn_implementation,
247
+ )
248
+ model = ModelForWordTask(
249
+ model=l2v.model,
250
+ merge_subwords=args.merge_subwords,
251
+ config=config,
252
+ torch_dtype=torch_dtype,
253
+ )
254
+
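+ # WordTaskTrainer._save stores the trained head as classifier.pt; load it on top of the frozen backbone.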
255
+ classifier_path = os.path.join(args.cls_addr, "classifier.pt")
256
+ if os.path.exists(classifier_path):
257
+ print(f"Loading classifier from {classifier_path}")
258
+ model.classifier = torch.load(classifier_path)
259
+ else:
260
+ raise ValueError(f"classifier does not exist in {classifier_path}")
261
+
262
+ elif args.model_class == "auto":
263
+ model = AutoModelForTokenClassification.from_pretrained(
264
+ args.model_name_or_path,
265
+ num_labels=len(LABELS[args.dataset_name][args.task]),
266
+ id2label={
267
+ i: lab for (lab, i) in LABELS[args.dataset_name][args.task].items()
268
+ },
269
+ label2id=LABELS[args.dataset_name][args.task],
270
+ )
271
+ else:
272
+ raise ValueError(
273
+ f"{args.model_class} is not implemented. Only auto and custom model_class options are valid."
274
+ )
275
+
276
+ model = model.cuda()
277
+
278
+ raw_datasets = load_dataset(args.dataset_name, split="test")
279
+
280
+ def tokenize_and_align_labels(examples):
281
+ task = args.task
282
+ tokenized_inputs = tokenizer(
283
+ examples["tokens"],
284
+ truncation=True,
285
+ is_split_into_words=True,
286
+ padding="max_length",
287
+ max_length=args.max_seq_length,
288
+ return_tensors="pt",
289
+ )
290
+
291
+ labels = []
292
+ words = []
293
+ for i, label in enumerate(examples[task]):
294
+ if args.retroactive_labels in ["same_token"]:
295
+ # if args.retroactive_labels == "next_word":
296
+ # label = label[1:] + [-100]
297
+ word_ids = tokenized_inputs.word_ids(batch_index=i)
298
+ previous_word_idx = None
299
+ label_ids = []
300
+ for word_idx in word_ids:
301
+ if word_idx is None:
302
+ label_ids.append(-100)
303
+ elif word_idx != previous_word_idx:
304
+ label_ids.append(label[word_idx])
305
+ else:
306
+ label_ids.append(-100)
307
+ previous_word_idx = word_idx
308
+ labels.append(label_ids)
309
+ word_ids = [-1 if w is None else w for w in word_ids]
310
+ words.append(word_ids)
311
+ elif args.retroactive_labels == "next_token":
312
+ word_ids = tokenized_inputs.word_ids(batch_index=i)
313
+ previous_word_idx = None
314
+ label_ids = []
315
+ for word_idx in word_ids:
316
+ if word_idx is None:
317
+ label_ids.append(-100)
318
+ elif word_idx != previous_word_idx:
319
+ label_ids.append(label[word_idx])
320
+ else:
321
+ label_ids.append(-100)
322
+ previous_word_idx = word_idx
323
+ label_ids.append(-100)
324
+ labels.append(label_ids[1:])
325
+ word_ids = word_ids[1:] + [None]
326
+ word_ids = [-1 if w is None else w for w in word_ids]
327
+ words.append(word_ids)
328
+ else:
329
+ raise ValueError(
330
+ f"retroactive_labels {args.retroactive_labels} is not implemented."
331
+ )
332
+
333
+ tokenized_inputs["labels"] = torch.tensor(labels)
334
+ if args.model_class == "custom":
335
+ tokenized_inputs["token_type_ids"] = words
336
+ return tokenized_inputs
337
+
338
+ tokenized_dataset = raw_datasets.map(
339
+ tokenize_and_align_labels,
340
+ batched=True,
341
+ remove_columns=list(LABELS[args.dataset_name].keys()) + ["tokens", "id"],
342
+ )
343
+ with torch.no_grad():
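+ # Manual batched inference over the test split: argmax the per-token logits and accumulate predictions and labels; only positions whose label is not -100 are scored below.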
344
+ predictions = None
345
+ labels = None
346
+ for batch_begin in tqdm(
347
+ torch.arange(0, len(tokenized_dataset), args.batch_size)
348
+ ):
349
+ features = {
350
+ "input_ids": torch.tensor(
351
+ tokenized_dataset[batch_begin : batch_begin + args.batch_size][
352
+ "input_ids"
353
+ ]
354
+ ).to(model.device),
355
+ "attention_mask": torch.tensor(
356
+ tokenized_dataset[batch_begin : batch_begin + args.batch_size][
357
+ "attention_mask"
358
+ ]
359
+ ).to(model.device),
360
+ }
361
+ if (
362
+ "token_type_ids"
363
+ in tokenized_dataset[batch_begin : batch_begin + args.batch_size]
364
+ ):
365
+ features["token_type_ids"] = torch.tensor(
366
+ tokenized_dataset[batch_begin : batch_begin + args.batch_size][
367
+ "token_type_ids"
368
+ ]
369
+ ).to(model.device)
370
+
371
+ labs = torch.tensor(
372
+ tokenized_dataset[batch_begin : batch_begin + args.batch_size]["labels"]
373
+ )
374
+
375
+ logits = model(**features).logits
376
+ preds = torch.argmax(logits, dim=-1)
377
+ if predictions is None:
378
+ predictions = preds
379
+ labels = labs
380
+ else:
381
+ predictions = torch.concatenate((predictions, preds))
382
+ labels = torch.concatenate((labels, labs))
383
+
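+ # Micro-averaged precision over all non-ignored positions; for single-label token classification this is equivalent to token-level accuracy.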
384
+ precision_metric = evaluate.load("precision")
385
+ metrics = precision_metric.compute(
386
+ references=labels[labels != -100],
387
+ predictions=predictions[labels != -100],
388
+ average="micro",
389
+ )
390
+
391
+ with open(os.path.join(args.output_dir, "result_summary.json"), "w") as f:
392
+ json.dump(metrics, f)
393
+ print(metrics)
llm2vec/images/sample_efficient.png ADDED