diff --git a/code/RL_model/verl/Search-R1/dataset/data_prep.py b/code/RL_model/verl/Search-R1/dataset/data_prep.py new file mode 100644 index 0000000000000000000000000000000000000000..2be1042640a754eab0eeb916616e656bffd67db8 --- /dev/null +++ b/code/RL_model/verl/Search-R1/dataset/data_prep.py @@ -0,0 +1,88 @@ +import os +import json +import datasets +import argparse +from verl.utils.hdfs_io import copy, makedirs + +# 1. Define the exact Prompt Template from your requirements +# /home/mshahidul/readctrl/code/RL_model/verl/verl_train/dataset/prompt +with open("/home/mshahidul/readctrl/code/RL_model/verl/verl_train/dataset/prompt", 'r') as f: + PROMPT_TEMPLATE = f.read() + +def make_map_fn(split, data_source): + def process_fn(example, idx): + # Extract fields from your specific JSON keys: ['id', 'fulltext', 'summary'] + full_text = example.pop('fulltext') + gold_summary = example.pop('summary') + + # Format the prompt using your template + # Note: Added 'English' as default source lang based on filename + prompt_content = PROMPT_TEMPLATE.format( + source_lang="English", + gold_summary=gold_summary, + full_text=full_text + ) + + return { + "data_source": data_source, + "prompt": [{ + "role": "user", + "content": prompt_content + }], + "ability": "summarization", + "reward_model": { + "style": "rule", + "ground_truth": gold_summary + }, + "extra_info": { + "split": split, + "index": idx, + "original_id": example.get('id', idx) + } + } + return process_fn + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + # Path to your input JSON + parser.add_argument('--input_path', default='/home/mshahidul/readctrl/data/processed_test_raw_data/multiclinsum_test_en.json') + # Updated destination as requested + parser.add_argument('--local_dir', default='/home/mshahidul/readctrl/code/RL_model/verl/Search-R1/dataset') + args = parser.parse_args() + + data_source = 'multiclinsum' + + # Load your local JSON file + with open(args.input_path, 'r') as f: + raw_data = json.load(f) + + # Convert to HuggingFace Dataset + dataset = datasets.Dataset.from_list(raw_data) + + # Split into train/test (95% train, 5% test) + split_dataset = dataset.train_test_split(test_size=0.05, seed=42) + + # Apply the mapping transformation for each split + processed_train = split_dataset["train"].map( + function=make_map_fn('train', data_source), + with_indices=True + ) + processed_test = split_dataset["test"].map( + function=make_map_fn('test', data_source), + with_indices=True + ) + + # Create the directory if it doesn't exist + os.makedirs(args.local_dir, exist_ok=True) + + # Save to Parquet in the specified location + train_output_path = os.path.join(args.local_dir, 'train.parquet') + test_output_path = os.path.join(args.local_dir, 'test.parquet') + processed_train.to_parquet(train_output_path) + processed_test.to_parquet(test_output_path) + + print(f"--- Dataset Preparation Complete ---") + print(f"Train file saved to: {train_output_path}") + print(f"Test file saved to: {test_output_path}") + print(f"Total train records: {len(processed_train)}") + print(f"Total test records: {len(processed_test)}") \ No newline at end of file diff --git a/code/RL_model/verl/Search-R1/dataset/prompt b/code/RL_model/verl/Search-R1/dataset/prompt new file mode 100644 index 0000000000000000000000000000000000000000..084bb706dafafee7913a406ccb6fbffa524be840 --- /dev/null +++ b/code/RL_model/verl/Search-R1/dataset/prompt @@ -0,0 +1,58 @@ +**System Role:** + +You are an expert medical editor and Health Literacy specialist. Your task is to transform complex medical text into three distinct versions based on the reader's health literacy level. You must maintain the source language of the input while adjusting the linguistic complexity. Use the provided Gold Summary as the factual anchor to ensure the simplified versions remain accurate and focused on the most important information. + +**User Prompt:** + +Please process the following medical Source Text and its corresponding Gold Summary to generate three versions tailored to different health literacy levels. +### Instructions for Each Level: + +1. Level: Low Health Literacy (High Readability) + +Target: Individuals needing the simplest terms for immediate action. + +Linguistic Goal: Use "living room" language. Replace all medical jargon with functional descriptions (e.g., "renal" becomes "kidney"). + +Information Density: Focus strictly on the "need-to-know" info found in the Gold Summary. + +Strategy: High paraphrasing using analogies. One idea per sentence. + +Faithfulness: Must align perfectly with the Gold Summary. + +2. Level: Intermediate Health Literacy (Medium Readability) + +Target: The general public (news-reading level). + +Linguistic Goal: Standard vocabulary. Common medical terms are okay, but technical "doctor-speak" must be simplified. + +Information Density: Balanced. Use the Gold Summary as the lead, supplemented by necessary context from the Source Text. + +Strategy: Moderate paraphrasing. Remove minor technical details to avoid information overload. + +Faithfulness: Maintains the main narrative of the Gold Summary. + +3. Level: Proficient Health Literacy (Low Readability) + +Target: Researchers, clinicians, or highly informed patients. + +Linguistic Goal: Technical and academic language. Prioritize clinical nuance and medical accuracy. + +Information Density: High. Use the Full Source Text to include data, physiological mechanisms, and statistics. + +Strategy: Minimal paraphrasing. Retain all original technical terminology. + +Faithfulness: Adhere to the Source Text; you may add related subclaims that provide deeper scientific context. + + +I will provide the following information: + +- Input Language: <<>> +- Gold Summary (the anchor reference summary): <<>> +- Source Text (detailed content): <<>> + +**Output Format (JSON only):** + {{ + "low_health_literacy": "...", + "intermediate_health_literacy": "...", + "proficient_health_literacy": "..." + }} \ No newline at end of file diff --git a/code/RL_model/verl/Search-R1/outputs/2026-02-01/20-26-44/main_ppo.log b/code/RL_model/verl/Search-R1/outputs/2026-02-01/20-26-44/main_ppo.log new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/code/RL_model/verl/Search-R1/search_r1/__init__.py b/code/RL_model/verl/Search-R1/search_r1/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/code/RL_model/verl/Search-R1/verl.egg-info/PKG-INFO b/code/RL_model/verl/Search-R1/verl.egg-info/PKG-INFO new file mode 100644 index 0000000000000000000000000000000000000000..515534aa3a3ae1a8c5031bac6340cb3f5a6d08f4 --- /dev/null +++ b/code/RL_model/verl/Search-R1/verl.egg-info/PKG-INFO @@ -0,0 +1,507 @@ +Metadata-Version: 2.4 +Name: verl +Version: 0.1 +Summary: veRL: Volcano Engine Reinforcement Learning for LLM +Home-page: https://github.com/volcengine/verl +Author: Bytedance - Seed - MLSys +Author-email: Bytedance - Seed - MLSys , Bytedance - Seed - MLSys +License: + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + +Project-URL: Homepage, https://github.com/volcengine/verl +Requires-Python: >=3.8 +Description-Content-Type: text/markdown +License-File: LICENSE +Requires-Dist: accelerate +Requires-Dist: codetiming +Requires-Dist: datasets +Requires-Dist: dill +Requires-Dist: hydra-core +Requires-Dist: numpy +Requires-Dist: pybind11 +Requires-Dist: ray +Requires-Dist: tensordict +Requires-Dist: transformers<4.48 +Requires-Dist: vllm<=0.6.3 +Provides-Extra: test +Requires-Dist: pytest; extra == "test" +Requires-Dist: yapf; extra == "test" +Dynamic: author +Dynamic: home-page +Dynamic: license-file + +# Search-R1: Train your LLMs to reason and call a search engine with reinforcement learning + +
+ logo +
+ +

+ + Button1 + + + Button2 + + + Button3 + + + Button4 + + + Button5 + +

+ + + + +**Search-R1** is a reinforcement learning framework designed for training **reasoning-and-searching interleaved LLMs**—language models that learn to reason and make tool calls (e.g., to search engines) in a coordinated manner. + + +Built upon [veRL](https://github.com/volcengine/verl), Search-R1 extends the ideas of **DeepSeek-R1(-Zero)** by incorporating interleaved search engine access and provides a fully open-source RL training pipeline. It serves as an alternative and open solution to **OpenAI DeepResearch**, enabling research and development in tool-augmented LLM reasoning. + + + +We support different RL methods (e.g., PPO, GRPO, reinforce), different LLMs (e.g., llama3, Qwen2.5, etc) and different search engines (e.g., local sparse/dense retrievers and online search engines). + +Paper: [link1](https://arxiv.org/pdf/2503.09516), [link2](https://arxiv.org/abs/2505.15117); Model and data: [link](https://huggingface.co/collections/PeterJinGo/search-r1-67d1a021202731cb065740f5); Twitter thread: [link](https://x.com/BowenJin13/status/1895544294473109889); Full experiment log: [prelim](https://wandb.ai/peterjin/Search-R1-open); [v0.1](https://wandb.ai/peterjin/Search-R1-nq_hotpotqa_train); [v0.2](https://wandb.ai/peterjin/Search-R1-v0.2); [v0.3](https://wandb.ai/peterjin/Search-R1-v0.3). Details about these logs and methods can be find [here](https://github.com/PeterGriffinJin/Search-R1/blob/main/docs/experiment_log.md). + + +![single-turn](public/main.png) + +## News + +- [2025.10] Search-R1 is featured by Thinking Machines Lab's first product [Tinker](https://github.com/thinking-machines-lab/tinker-cookbook)! Details: [Document](https://github.com/thinking-machines-lab/tinker-cookbook/tree/main/tinker_cookbook/recipes/tool_use/search). +- [2025.7] Search-R1 is supported by [SkyRL](https://github.com/NovaSky-AI/SkyRL)! Detailed instructions: [code](https://github.com/NovaSky-AI/SkyRL/tree/main/skyrl-train/examples/search), [Document](https://novasky-ai.notion.site/skyrl-searchr1). +- [2025.6] Search-R1 is now integrated into the latest version of veRL and can take advantage of its most up-to-date features! Detailed instructions: [veRL](https://verl.readthedocs.io/en/latest/sglang_multiturn/search_tool_example.html), [English Document](https://github.com/zhaochenyang20/Awesome-ML-SYS-Tutorial/blob/main/rlhf/verl/multi-turn/tool_examples/verl-multiturn-searchR1-like.md), [Chinese Document](https://github.com/zhaochenyang20/Awesome-ML-SYS-Tutorial/blob/main/rlhf/verl/multi-turn/tool_examples/verl-multiturn-searchR1-like_ZH.md). +- [2025.5] The second [paper](https://arxiv.org/abs/2505.15117) conducting detailed empirical studies is published with logs: [v0.3](https://wandb.ai/peterjin/Search-R1-v0.3). +- [2025.4] We support [multinode](https://github.com/PeterGriffinJin/Search-R1/blob/main/docs/multinode.md) training for 30B+ LLMs! +- [2025.4] We support [different search engines](https://github.com/PeterGriffinJin/Search-R1/blob/main/docs/retriever.md) including sparse local retriever, dense local retriever with ANN indexing and online search engines! +- [2025.3] The first Search-R1 [paper](https://arxiv.org/pdf/2503.09516) is published with the logs: [v0.1](https://wandb.ai/peterjin/Search-R1-nq_hotpotqa_train); [v0.2](https://wandb.ai/peterjin/Search-R1-v0.2). +- [2025.2] We opensource Search-R1 codebase with [preliminary results](https://wandb.ai/peterjin/Search-R1-open). + +## Links + +- [Installation](#installation) +- [Quick start](#quick-start) +- [Preliminary results](#preliminary-results) +- [Inference](#inference) +- [Use your own dataset](#use-your-own-dataset) +- [Use your own search engine](#use-your-own-search-engine) +- [Features](#features) +- [Ackowledge](#acknowledge) +- [Citations](#citations) + +## Installation + +### Search-r1 environment +```bash +conda create -n searchr1 python=3.9 +conda activate searchr1 +# install torch [or you can skip this step and let vllm to install the correct version for you] +pip install torch==2.4.0 --index-url https://download.pytorch.org/whl/cu121 +# install vllm +pip3 install vllm==0.6.3 # or you can install 0.5.4, 0.4.2 and 0.3.1 + +# verl +pip install -e . + +# flash attention 2 +pip3 install flash-attn --no-build-isolation +pip install wandb +``` + +### Retriever environment (optional) +If you would like to call a local retriever as the search engine, you can install the environment as follows. (We recommend using a seperate environment.) +```bash +conda create -n retriever python=3.10 +conda activate retriever + +# we recommend installing torch with conda for faiss-gpu +conda install pytorch==2.4.0 torchvision==0.19.0 torchaudio==2.4.0 pytorch-cuda=12.1 -c pytorch -c nvidia +pip install transformers datasets pyserini + +## install the gpu version faiss to guarantee efficient RL rollout +conda install -c pytorch -c nvidia faiss-gpu=1.8.0 + +## API function +pip install uvicorn fastapi +``` + + +## Quick start + +Train a reasoning + search LLM on NQ dataset with e5 as the retriever and wikipedia as the corpus. + +(1) Download the indexing and corpus. +```bash +save_path=/the/path/to/save +python scripts/download.py --save_path $save_path +cat $save_path/part_* > $save_path/e5_Flat.index +gzip -d $save_path/wiki-18.jsonl.gz +``` + +(2) Process the NQ dataset. +```bash +python scripts/data_process/nq_search.py +``` + +(3) Launch a local retrieval server. +```bash +conda activate retriever +bash retrieval_launch.sh +``` + +(4) Run RL training (PPO) with Llama-3.2-3b-base. +```bash +conda activate searchr1 +bash train_ppo.sh +``` + +## Preliminary results + +(1) The base model (llama3.2-3b-base) learns to call the search engine and obtain improved performance. + +![llama-3b](public/llama32-3b.png) + + +(2) The base model (Qwen2.5-7b-base) can learn to conduct multi-turn search engine calling and reasoning with RL. + +![multi-turn](public/multi-turn.png) + +## Inference +#### You can play with the trained Search-R1 model with your own question. +(1) Launch a local retrieval server. +```bash +conda activate retriever +bash retrieval_launch.sh +``` + +(2) Run inference. +```bash +conda activate searchr1 +python infer.py +``` +You can modify the ```question``` on line 7 to something you're interested in. + +## Use your own dataset + +### QA data +For each question-answer sample, it should be a dictionary containing the desired content as below: + +``` +data = { + "data_source": data_source, + "prompt": [{ + "role": "user", + "content": question, + }], + "ability": "fact-reasoning", + "reward_model": { + "style": "rule", + "ground_truth": solution + }, + "extra_info": { + 'split': split, + 'index': idx, + } + } +``` + +You can refer to ```scripts/data_process/nq_search.py``` for a concrete data processing example. + +### Corpora + +It is recommended to make your corpus a jsonl file, where each line (a dictionary with "id" key and "contents" key) corresponds to one passage. You can refer to ```example/corpus.jsonl``` for an example. + +The "id" key corresponds to the passage id, while the "contents" key corresponds to the passage content ('"' + title + '"\n' + text). +For example: +``` +{"id": "0", "contents": "Evan Morris Evan L. Morris (January 26, 1977 \u2013 July 9, 2015) was a lobbyist for Genentech and its parent corporation Roche in Washington."} +... +{"id": "100", "contents": "Three years later, when the United States Exploring Expedition to little-known portions of the globe was organised under Charles Wilkes, Hale was recommended, while yet an undergraduate."} +... +``` + +**Index your corpora (optional).** +If you would like to use a local retriever as the search engine, you can index your own corpus by: +``` +bash search_r1/search/build_index.sh +``` +You can change ```retriever_name``` and ```retriever_model``` to your interested off-the-shelf retriever. + +## Use your own search engine + +Our codebase supports local sparse retriever (e.g., BM25), local dense retriever (both flat indexing with GPUs and ANN indexing with CPUs) and online search engine (e.g., Google, Bing, etc). More details can be found [here](https://github.com/PeterGriffinJin/Search-R1/tree/main/docs/retriever.md). + +The main philosophy is to launch a local or remote search engine server separately from the main RL training pipeline. + +The LLM can call the search engine by calling the search API (e.g., "http://127.0.0.1:8000/retrieve"). + +You can refer to ```search_r1/search/retriever_server.py``` for an example of launching a local retriever server. + +## Features +- Support local sparse retrievers (e.g., BM25). ✔️ +- Support local dense retrievers (both flat indexing and ANN indexing) ✔️ +- Support google search / bing search / brave search API and others. ✔️ +- Support off-the-shelf neural rerankers. ✔️ +- Support different RL methods (e.g., PPO, GRPO, reinforce). ✔️ +- Support different LLMs (e.g., llama3, Qwen2.5, etc). ✔️ + +## Acknowledge + +The concept of Search-R1 is inspired by [Deepseek-R1](https://github.com/deepseek-ai/DeepSeek-R1) and [TinyZero](https://github.com/Jiayi-Pan/TinyZero/tree/main). +Its implementation is built upon [veRL](https://github.com/volcengine/verl) and [RAGEN](https://github.com/ZihanWang314/RAGEN/tree/main). +We sincerely appreciate the efforts of these teams for their contributions to open-source research and development. + +## Awesome work powered or inspired by Search-R1 + +- [DeepResearcher](https://github.com/GAIR-NLP/DeepResearcher): Scaling Deep Research via Reinforcement Learning in Real-world Environments. [![[code]](https://img.shields.io/github/stars/GAIR-NLP/DeepResearcher)](https://github.com/GAIR-NLP/DeepResearcher) +- [Multimodal-Search-R1](https://github.com/EvolvingLMMs-Lab/multimodal-search-r1): Incentivizing LMMs to Search. [![[code]](https://img.shields.io/github/stars/EvolvingLMMs-Lab/multimodal-search-r1)](https://github.com/EvolvingLMMs-Lab/multimodal-search-r1) +- [OTC](https://arxiv.org/pdf/2504.14870): Optimal Tool Calls via Reinforcement Learning. +- [ZeroSearch](https://github.com/Alibaba-NLP/ZeroSearch): Incentivize the Search Capability of LLMs without Searching. [![[code]](https://img.shields.io/github/stars/Alibaba-NLP/ZeroSearch)](https://github.com/Alibaba-NLP/ZeroSearch) +- [IKEA](https://github.com/hzy312/knowledge-r1): Reinforced Internal-External Knowledge Synergistic Reasoning for Efficient Adaptive Search Agent. [![[code]](https://img.shields.io/github/stars/hzy312/knowledge-r1)](https://github.com/hzy312/knowledge-r1) +- [Scent of Knowledge](https://arxiv.org/abs/2505.09316): Optimizing Search-Enhanced Reasoning with Information Foraging. +- [AutoRefine](https://www.arxiv.org/pdf/2505.11277): Search and Refine During Think. [![[code]](https://img.shields.io/github/stars/syr-cn/AutoRefine)](https://github.com/syr-cn/AutoRefine) +- [O^2-Searcher](https://arxiv.org/pdf/2505.16582): A Searching-based Agent Model for Open-Domain Open-Ended Question Answering. [![[code]](https://img.shields.io/github/stars/Acade-Mate/O2-Searcher)](https://github.com/Acade-Mate/O2-Searcher) +- [MaskSearch](https://arxiv.org/pdf/2505.20285): A Universal Pre-Training Framework to Enhance Agentic Search Capability. [![[code]](https://img.shields.io/github/stars/Alibaba-NLP/MaskSearch)](https://github.com/Alibaba-NLP/MaskSearch) +- [VRAG-RL](https://arxiv.org/abs/2505.22019): Vision-Perception-Based RAG for Visually Rich Information Understanding. [![[code]](https://img.shields.io/github/stars/Alibaba-NLP/VRAG)](https://github.com/Alibaba-NLP/VRAG) +- [R1-Code-Interpreter](https://arxiv.org/abs/2505.21668): Training LLMs to Reason with Code via SFT and RL. [![[code]](https://img.shields.io/github/stars/yongchao98/R1-Code-Interpreter)](https://github.com/yongchao98/R1-Code-Interpreter) +- [R-Search](https://arxiv.org/abs/2506.04185): Empowering LLM Reasoning with Search via Multi-Reward Reinforcement Learning. [![[code]](https://img.shields.io/github/stars/QingFei1/R-Search)](https://github.com/QingFei1/R-Search) +- [StepSearch](https://arxiv.org/pdf/2505.15107): Igniting LLMs Search Ability via Step-Wise Proximal Policy Optimization. [![[code]](https://img.shields.io/github/stars/Zillwang/StepSearch)](https://github.com/Zillwang/StepSearch) +- [SimpleTIR](https://simpletir.notion.site/report): Stable End-to-End Reinforcement Learning for Multi-Turn Tool-Integrated Reasoning. [![[code]](https://img.shields.io/github/stars/ltzheng/SimpleTIR)](https://github.com/ltzheng/SimpleTIR) +- [Router-R1](https://arxiv.org/pdf/2506.09033): Teaching LLMs Multi-Round Routing and Aggregation via Reinforcement Learning. [![[code]](https://img.shields.io/github/stars/ulab-uiuc/Router-R1)](https://github.com/ulab-uiuc/Router-R1) +- [SkyRL](https://skyrl.readthedocs.io/en/latest/): A Modular Full-stack RL Library for LLMs. [![[code]](https://img.shields.io/github/stars/NovaSky-AI/SkyRL)](https://github.com/NovaSky-AI/SkyRL) +- [ASearcher](https://arxiv.org/abs/2508.07976): Large-Scale RL for Search Agents. [![[code]](https://img.shields.io/github/stars/inclusionAI/ASearcher)](https://github.com/inclusionAI/ASearcher) +- [ParallelSearch](https://www.arxiv.org/abs/2508.09303): Decompose Query and Search Sub-queries in Parallel with RL. [![[code]](https://img.shields.io/github/stars/Tree-Shu-Zhao/ParallelSearch)](https://github.com/Tree-Shu-Zhao/ParallelSearch) +- [AutoTIR](https://arxiv.org/pdf/2507.21836): Autonomous Tools Integrated Reasoning via Reinforcement Learning. [![[code]](https://img.shields.io/github/stars/weiyifan1023/AutoTIR)](https://github.com/weiyifan1023/AutoTIR) +- [verl-tool](https://arxiv.org/pdf/2509.01055): A version of verl to support diverse tool use. [![[code]](https://img.shields.io/github/stars/TIGER-AI-Lab/verl-tool)](https://github.com/TIGER-AI-Lab/verl-tool) +- [Tree-GRPO](https://arxiv.org/abs/2509.21240): Tree Search for LLM Agent Reinforcement Learning. [![[code]](https://img.shields.io/github/stars/AMAP-ML/Tree-GRPO)](https://github.com/AMAP-ML/Tree-GRPO) +- [EviNote-RAG](https://arxiv.org/abs/2509.00877): Enhancing RAG Models via Answer-Supportive Evidence Notes. [![[code]](https://img.shields.io/github/stars/Da1yuqin/EviNoteRAG)](https://github.com/Da1yuqin/EviNoteRAG) +- [GlobalRAG](https://arxiv.org/pdf/2510.20548v1): GlobalRAG: Enhancing Global Reasoning in Multi-hop Question Answering via Reinforcement Learning. [![[code]](https://img.shields.io/github/stars/CarnegieBin/GlobalRAG)](https://github.com/CarnegieBin/GlobalRAG) + + + + + +## Citations + +```bibtex +@article{jin2025search, + title={Search-r1: Training llms to reason and leverage search engines with reinforcement learning}, + author={Jin, Bowen and Zeng, Hansi and Yue, Zhenrui and Yoon, Jinsung and Arik, Sercan and Wang, Dong and Zamani, Hamed and Han, Jiawei}, + journal={arXiv preprint arXiv:2503.09516}, + year={2025} +} +``` + +```bibtex +@article{jin2025empirical, + title={An Empirical Study on Reinforcement Learning for Reasoning-Search Interleaved LLM Agents}, + author={Jin, Bowen and Yoon, Jinsung and Kargupta, Priyanka and Arik, Sercan O and Han, Jiawei}, + journal={arXiv preprint arXiv:2505.15117}, + year={2025} +} +``` diff --git a/code/RL_model/verl/Search-R1/verl.egg-info/dependency_links.txt b/code/RL_model/verl/Search-R1/verl.egg-info/dependency_links.txt new file mode 100644 index 0000000000000000000000000000000000000000..8b137891791fe96927ad78e64b0aad7bded08bdc --- /dev/null +++ b/code/RL_model/verl/Search-R1/verl.egg-info/dependency_links.txt @@ -0,0 +1 @@ + diff --git a/code/RL_model/verl/Search-R1/verl.egg-info/requires.txt b/code/RL_model/verl/Search-R1/verl.egg-info/requires.txt new file mode 100644 index 0000000000000000000000000000000000000000..e32a32857dde8a4998ffbea539487754c7822f2c --- /dev/null +++ b/code/RL_model/verl/Search-R1/verl.egg-info/requires.txt @@ -0,0 +1,15 @@ +accelerate +codetiming +datasets +dill +hydra-core +numpy +pybind11 +ray +tensordict +transformers<4.48 +vllm<=0.6.3 + +[test] +pytest +yapf diff --git a/code/RL_model/verl/Search-R1/verl.egg-info/top_level.txt b/code/RL_model/verl/Search-R1/verl.egg-info/top_level.txt new file mode 100644 index 0000000000000000000000000000000000000000..1f2557330403c1b03fc8119ed28560c788040326 --- /dev/null +++ b/code/RL_model/verl/Search-R1/verl.egg-info/top_level.txt @@ -0,0 +1,2 @@ +search_r1 +verl diff --git a/code/RL_model/verl/Search-R1/verl/__init__.py b/code/RL_model/verl/Search-R1/verl/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f068717761543cde8dd59ad08b42465160893bb3 --- /dev/null +++ b/code/RL_model/verl/Search-R1/verl/__init__.py @@ -0,0 +1,27 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +version_folder = os.path.dirname(os.path.join(os.path.abspath(__file__))) + +with open(os.path.join(version_folder, 'version/version')) as f: + __version__ = f.read().strip() + +from .protocol import DataProto + +from .utils.logging_utils import set_basic_config +import logging + +set_basic_config(level=logging.WARNING) diff --git a/code/RL_model/verl/Search-R1/verl/protocol.py b/code/RL_model/verl/Search-R1/verl/protocol.py new file mode 100644 index 0000000000000000000000000000000000000000..803da36643a70a69f08541d74e2782ad72db32a9 --- /dev/null +++ b/code/RL_model/verl/Search-R1/verl/protocol.py @@ -0,0 +1,639 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Implement base data transfer protocol between any two functions, modules. +We can subclass Protocol to define more detailed batch info with specific keys +""" + +import pickle +import numpy as np +import copy +from dataclasses import dataclass, field +from typing import Callable, Dict, List, Union + +import torch +import tensordict +from tensordict import TensorDict +from torch.utils.data import DataLoader, Dataset + +from verl.utils.py_functional import union_two_dict + +__all__ = ['DataProto', 'union_tensor_dict'] + +try: + tensordict.set_lazy_legacy(False).set() +except: + pass + + +def pad_dataproto_to_divisor(data: 'DataProto', size_divisor: int): + """Pad a DataProto to size divisible by size_divisor + + Args: + size_divisor (int): size divisor + + Returns: + data: (DataProto): the padded DataProto + pad_size (int) + """ + assert isinstance(data, DataProto), 'data must be a DataProto' + if len(data) % size_divisor != 0: + pad_size = size_divisor - len(data) % size_divisor + data_padded = DataProto.concat([data, data[:pad_size]]) + else: + pad_size = 0 + data_padded = data + return data_padded, pad_size + + +def unpad_dataproto(data: 'DataProto', pad_size): + if pad_size != 0: + data = data[:-pad_size] + return data + + +def union_tensor_dict(tensor_dict1: TensorDict, tensor_dict2: TensorDict) -> TensorDict: + """Union two tensordicts.""" + assert tensor_dict1.batch_size == tensor_dict2.batch_size, \ + f'Two tensor dict must have identical batch size. Got {tensor_dict1.batch_size} and {tensor_dict2.batch_size}' + for key in tensor_dict2.keys(): + if key not in tensor_dict1.keys(): + tensor_dict1[key] = tensor_dict2[key] + else: + assert tensor_dict1[key].equal(tensor_dict2[key]), \ + f'{key} in tensor_dict1 and tensor_dict2 are not the same object' + + return tensor_dict1 + + +def union_numpy_dict(tensor_dict1: dict[np.ndarray], tensor_dict2: dict[np.ndarray]) -> dict[np.ndarray]: + for key, val in tensor_dict2.items(): + if key in tensor_dict1: + assert isinstance(tensor_dict2[key], np.ndarray) + assert isinstance(tensor_dict1[key], np.ndarray) + assert np.all(tensor_dict2[key] == tensor_dict1[key]), \ + f'{key} in tensor_dict1 and tensor_dict2 are not the same object' + tensor_dict1[key] = val + + return tensor_dict1 + + +def list_of_dict_to_dict_of_list(list_of_dict: list[dict]): + if len(list_of_dict) == 0: + return {} + keys = list_of_dict[0].keys() + output = {key: [] for key in keys} + for data in list_of_dict: + for key, item in data.items(): + assert key in output + output[key].append(item) + return output + + +def fold_batch_dim(data: 'DataProto', new_batch_size): + """ + Fold a batch dim from [bsz, xxx] into [new_bsz, bsz // new_bsz, xxx] + """ + batch_size = data.batch.batch_size[0] + + assert batch_size % new_batch_size == 0 + + tensor: TensorDict = data.batch + non_tensor = data.non_tensor_batch + + tensor = tensor.view(new_batch_size, -1) + tensor.auto_batch_size_(batch_dims=1) + + for key, val in non_tensor.items(): + non_tensor[key] = np.reshape(val, newshape=(new_batch_size, -1, *val.shape[1:])) + + return DataProto(batch=tensor, non_tensor_batch=non_tensor, meta_info=data.meta_info) + + +def unfold_batch_dim(data: 'DataProto', batch_dims=2): + """ + Unfold the first n dims as new batch dim + """ + tensor: TensorDict = data.batch + non_tensor = data.non_tensor_batch + tensor.auto_batch_size_(batch_dims=batch_dims) + tensor = tensor.view(-1) + + batch_size = tensor.batch_size[0] + + non_tensor_new = {} + + for key, val in non_tensor.items(): + non_tensor_new[key] = np.reshape(val, newshape=(batch_size, *val.shape[batch_dims:])) + + return DataProto(batch=tensor, non_tensor_batch=non_tensor_new, meta_info=data.meta_info) + + +def collate_fn(x: list['DataProtoItem']): + batch = [] + non_tensor_batch = [] + for data in x: + batch.append(data.batch) + non_tensor_batch.append(data.non_tensor_batch) + batch = torch.stack(batch).contiguous() + non_tensor_batch = list_of_dict_to_dict_of_list(non_tensor_batch) + for key, val in non_tensor_batch.items(): + non_tensor_batch[key] = np.array(val, dtype=object) + return DataProto(batch=batch, non_tensor_batch=non_tensor_batch) + + +@dataclass +class DataProtoItem: + # TODO(zhangchi.usc1992) add consistency check + batch: TensorDict = None + non_tensor_batch: Dict = field(default_factory=dict) + meta_info: Dict = field(default_factory=dict) + + +@dataclass +class DataProto: + """ + A DataProto is a data structure that aims to provide a standard protocol for data exchange between functions. + It contains a batch (TensorDict) and a meta_info (Dict). The batch is a TensorDict https://pytorch.org/tensordict/. + TensorDict allows you to manipulate a dictionary of Tensors like a single Tensor. Ideally, the tensors with the + same batch size should be put inside batch. + """ + batch: TensorDict = None + non_tensor_batch: Dict = field(default_factory=dict) + meta_info: Dict = field(default_factory=dict) + + def __post_init__(self): + # perform necessary checking + self.check_consistency() + + def __len__(self): + if self.batch is not None: + return self.batch.batch_size[0] + elif self.non_tensor_batch is not None and len(self.non_tensor_batch) > 0: + random_key = list(self.non_tensor_batch.keys())[0] + return self.non_tensor_batch[random_key].shape[0] + else: + return 0 + + def __getitem__(self, item): + tensor_data = self.batch[item] + non_tensor_data = {key: val[item] for key, val in self.non_tensor_batch.items()} + return DataProtoItem(batch=tensor_data, non_tensor_batch=non_tensor_data, meta_info=self.meta_info) + + def __getstate__(self): + import io + buffer = io.BytesIO() + if tensordict.__version__ >= '0.5.0' and self.batch is not None: + self.batch = self.batch.contiguous() + self.batch = self.batch.consolidate() + torch.save(self.batch, buffer) + buffer_bytes = buffer.getvalue() + return buffer_bytes, self.non_tensor_batch, self.meta_info + + def __setstate__(self, data): + import io + batch_deserialized_bytes, non_tensor_batch, meta_info = data + batch_deserialized = io.BytesIO(initial_bytes=batch_deserialized_bytes) + batch = torch.load(batch_deserialized, + weights_only=False, + map_location='cpu' if not torch.cuda.is_available() else None) + self.batch = batch + self.non_tensor_batch = non_tensor_batch + self.meta_info = meta_info + + def save_to_disk(self, filepath): + with open(filepath, 'wb') as f: + pickle.dump(self, f) + + @staticmethod + def load_from_disk(filepath) -> 'DataProto': + with open(filepath, 'rb') as f: + data = pickle.load(f) + return data + + def print_size(self, prefix=""): + size_of_tensordict = 0 + for key, tensor in self.batch.items(): + size_of_tensordict += tensor.element_size() * tensor.numel() + size_of_numpy_array = 0 + for key, numpy_array in self.non_tensor_batch.items(): + size_of_numpy_array += numpy_array.nbytes + + size_of_numpy_array /= 1024**3 + size_of_tensordict /= 1024**3 + + message = f'Size of tensordict: {size_of_tensordict} GB, size of non_tensor_batch: {size_of_numpy_array} GB' + + if prefix: + message = f'{prefix}, ' + message + print(message) + + def check_consistency(self): + """Check the consistency of the DataProto. Mainly for batch and non_tensor_batch + We expose this function as a public one so that user can call themselves directly + """ + if self.batch is not None: + assert len(self.batch.batch_size) == 1, 'only support num_batch_dims=1' + + if self.non_tensor_batch is not None: + for key, val in self.non_tensor_batch.items(): + assert isinstance(val, np.ndarray) + + if self.batch is not None and len(self.non_tensor_batch) != 0: + # TODO: we can actually lift this restriction if needed + assert len(self.batch.batch_size) == 1, 'only support num_batch_dims=1 when non_tensor_batch is not empty.' + + batch_size = self.batch.batch_size[0] + for key, val in self.non_tensor_batch.items(): + assert isinstance( + val, np.ndarray + ) and val.dtype == object, 'data in the non_tensor_batch must be a numpy.array with dtype=object' + assert val.shape[ + 0] == batch_size, f'key {key} length {len(val)} is not equal to batch size {batch_size}' + + @classmethod + def from_single_dict(cls, data: Dict[str, Union[torch.Tensor, np.ndarray]], meta_info=None): + tensors = {} + non_tensors = {} + + for key, val in data.items(): + if isinstance(val, torch.Tensor): + tensors[key] = val + elif isinstance(val, np.ndarray): + non_tensors[key] = val + else: + raise ValueError(f'Unsupported type in data {type(val)}') + + return DataProto.from_dict(tensors=tensors, non_tensors=non_tensors, meta_info=meta_info) + + @classmethod + def from_dict(cls, tensors: Dict[str, torch.Tensor], non_tensors=None, meta_info=None, num_batch_dims=1): + """Create a DataProto from a dict of tensors. This assumes that + 1. All the tensor in tensors have the same dim0 + 2. Only dim0 is the batch dim + """ + assert len(tensors) > 0, 'tensors must not be empty' + assert num_batch_dims > 0, 'num_batch_dims must be greater than zero' + if non_tensors is not None: + assert num_batch_dims == 1, 'only support num_batch_dims=1 when non_tensors is not None.' + + if meta_info is None: + meta_info = {} + if non_tensors is None: + non_tensors = {} + + assert isinstance(non_tensors, dict) + + # get and check batch size + batch_size = None + pivot_key = None + for key, tensor in tensors.items(): + if batch_size is None: + batch_size = tensor.shape[:num_batch_dims] + pivot_key = key + else: + current_batch = tensor.shape[:num_batch_dims] + assert batch_size == current_batch, \ + f'Not all the tensor in tensors have the same batch size with batch_dims={num_batch_dims}. Got {pivot_key} has {batch_size}, {key} has {current_batch}' + + for key, val in non_tensors.items(): + non_tensors[key] = np.array(val, dtype=object) + + tensor_dict = TensorDict(source=tensors, batch_size=batch_size) + return cls(batch=tensor_dict, non_tensor_batch=non_tensors, meta_info=meta_info) + + def to(self, device) -> 'DataProto': + """move the batch to device + + Args: + device (torch.device, str): torch device + + Returns: + DataProto: the current DataProto + + """ + if self.batch is not None: + self.batch = self.batch.to(device) + return self + + def select(self, batch_keys=None, non_tensor_batch_keys=None, meta_info_keys=None, deepcopy=False) -> 'DataProto': + """Select a subset of the DataProto via batch_keys and meta_info_keys + + Args: + batch_keys (list, optional): a list of strings indicating the keys in batch to select + meta_info_keys (list, optional): a list of keys indicating the meta info to select + + Returns: + DataProto: the DataProto with the selected batch_keys and meta_info_keys + """ + # TODO (zhangchi.usc1992) whether to copy + if batch_keys is not None: + batch_keys = tuple(batch_keys) + sub_batch = self.batch.select(*batch_keys) + else: + sub_batch = self.batch + + if non_tensor_batch_keys is not None: + non_tensor_batch = {key: val for key, val in self.non_tensor_batch.items() if key in non_tensor_batch_keys} + else: + non_tensor_batch = self.non_tensor_batch + + if deepcopy: + non_tensor_batch = copy.deepcopy(non_tensor_batch) + + if meta_info_keys is not None: + sub_meta_info = {key: val for key, val in self.meta_info.items() if key in meta_info_keys} + else: + sub_meta_info = self.meta_info + + if deepcopy: + sub_meta_info = copy.deepcopy(sub_meta_info) + + return DataProto(batch=sub_batch, non_tensor_batch=non_tensor_batch, meta_info=sub_meta_info) + + def pop(self, batch_keys=None, non_tensor_batch_keys=None, meta_info_keys=None) -> 'DataProto': + """Pop a subset of the DataProto via `batch_keys` and `meta_info_keys` + + Args: + batch_keys (list, optional): a list of strings indicating the keys in batch to pop + meta_info_keys (list, optional): a list of keys indicating the meta info to pop + + Returns: + DataProto: the DataProto with the poped batch_keys and meta_info_keys + """ + assert batch_keys is not None + if meta_info_keys is None: + meta_info_keys = [] + if non_tensor_batch_keys is None: + non_tensor_batch_keys = [] + + tensors = {} + # tensor batch + for key in batch_keys: + assert key in self.batch.keys() + tensors[key] = self.batch.pop(key) + non_tensors = {} + # non tensor batch + for key in non_tensor_batch_keys: + assert key in self.non_tensor_batch.keys() + non_tensors[key] = self.non_tensor_batch.pop(key) + meta_info = {} + for key in meta_info_keys: + assert key in self.meta_info.keys() + meta_info[key] = self.meta_info.pop(key) + return DataProto.from_dict(tensors=tensors, non_tensors=non_tensors, meta_info=meta_info) + + def rename(self, old_keys=None, new_keys=None) -> 'DataProto': + """ + Note that this function only rename the key in the batch + """ + + def validate_input(keys): + if keys is not None: + if isinstance(keys, str): + keys = [keys] + elif isinstance(keys, list): + pass + else: + raise TypeError(f'keys must be a list or a string, but got {type(keys)}') + return keys + + old_keys = validate_input(old_keys) + new_keys = validate_input(new_keys) + + if len(new_keys) != len(old_keys): + raise ValueError( + f'new_keys and old_keys must have the same length, but got {len(new_keys)} and {len(old_keys)}') + + self.batch.rename_key_(tuple(old_keys), tuple(new_keys)) + + return self + + def union(self, other: 'DataProto') -> 'DataProto': + """Union with another DataProto. Union batch and meta_info separately. + Throw an error if + - there are conflict keys in batch and they are not equal + - the batch size of two data batch is not the same + - there are conflict keys in meta_info and they are not the same. + + Args: + other (DataProto): another DataProto to union + + Returns: + DataProto: the DataProto after union + """ + self.batch = union_tensor_dict(self.batch, other.batch) + self.non_tensor_batch = union_numpy_dict(self.non_tensor_batch, other.non_tensor_batch) + self.meta_info = union_two_dict(self.meta_info, other.meta_info) + return self + + def make_iterator(self, mini_batch_size, epochs, seed=None, dataloader_kwargs=None): + """Make an iterator from the DataProto. This is built upon that TensorDict can be used as a normal Pytorch + dataset. See https://pytorch.org/tensordict/tutorials/data_fashion for more details. + + Args: + mini_batch_size (int): mini-batch size when iterating the dataset. We require that + ``batch.batch_size[0] % mini_batch_size == 0`` + epochs (int): number of epochs when iterating the dataset. + dataloader_kwargs: internally, it returns a DataLoader over the batch. + The dataloader_kwargs is the kwargs passed to the DataLoader + + Returns: + Iterator: an iterator that yields a mini-batch data at a time. The total number of iteration steps is + ``self.batch.batch_size * epochs // mini_batch_size`` + """ + assert self.batch.batch_size[0] % mini_batch_size == 0, f"{self.batch.batch_size[0]} % {mini_batch_size} != 0" + # we can directly create a dataloader from TensorDict + if dataloader_kwargs is None: + dataloader_kwargs = {} + + if seed is not None: + generator = torch.Generator() + generator.manual_seed(seed) + else: + generator = None + + assert isinstance(dataloader_kwargs, Dict) + train_dataloader = DataLoader(dataset=self, + batch_size=mini_batch_size, + collate_fn=collate_fn, + generator=generator, + **dataloader_kwargs) + + def get_data(): + for _ in range(epochs): + for d in train_dataloader: + d.meta_info = self.meta_info + yield d + + return iter(get_data()) + + def chunk(self, chunks: int) -> List['DataProto']: + """Split the batch among dim=0 into chunks. The meta_info is passed to each DataProto after split. + + Args: + chunks (int): the number of chunks to split on dim=0 + + Returns: + List[DataProto]: a list of DataProto after splitting + """ + assert len( + self) % chunks == 0, f'only support equal chunk. Got size of DataProto {len(self)} and chunk {chunks}.' + + if self.batch is not None: + batch_lst = self.batch.chunk(chunks=chunks, dim=0) + else: + batch_lst = [None for _ in range(chunks)] + + non_tensor_batch_lst = [{} for _ in range(chunks)] + for key, val in self.non_tensor_batch.items(): + assert isinstance(val, np.ndarray) + non_tensor_lst = np.array_split(val, chunks) + assert len(non_tensor_lst) == chunks + for i in range(chunks): + non_tensor_batch_lst[i][key] = non_tensor_lst[i] + + output = [] + for i in range(chunks): + output.append( + DataProto(batch=batch_lst[i], non_tensor_batch=non_tensor_batch_lst[i], meta_info=self.meta_info)) + + return output + + @staticmethod + def concat(data: List['DataProto']) -> 'DataProto': + """Concat a list of DataProto. The batch is concatenated among dim=0. + The meta_info is assumed to be identical and will use the first one. + + Args: + data (List[DataProto]): list of DataProto + + Returns: + DataProto: concatenated DataProto + """ + batch_lst = [] + for batch in data: + batch_lst.append(batch.batch) + if batch_lst[0] is not None: + new_batch = torch.cat(batch_lst, dim=0) + else: + new_batch = None + + non_tensor_batch = list_of_dict_to_dict_of_list(list_of_dict=[d.non_tensor_batch for d in data]) + for key, val in non_tensor_batch.items(): + non_tensor_batch[key] = np.concatenate(val, axis=0) + + return DataProto(batch=new_batch, non_tensor_batch=non_tensor_batch, meta_info=data[0].meta_info) + + def reorder(self, indices): + """ + Note that this operation is in-place + """ + indices_np = indices.detach().numpy() + self.batch = self.batch[indices] + self.non_tensor_batch = {key: val[indices_np] for key, val in self.non_tensor_batch.items()} + + def repeat(self, repeat_times=2, interleave=True): + """ + Repeat the batch data a specified number of times. + + Args: + repeat_times (int): Number of times to repeat the data. + interleave (bool): Whether to interleave the repeated data. + + Returns: + DataProto: A new DataProto with repeated data. + """ + if self.batch is not None: + if interleave: + # Interleave the data + repeated_tensors = { + key: tensor.repeat_interleave(repeat_times, dim=0) for key, tensor in self.batch.items() + } + else: + # Stack the data + repeated_tensors = { + key: tensor.unsqueeze(0).expand(repeat_times, *tensor.shape).reshape(-1, *tensor.shape[1:]) + for key, tensor in self.batch.items() + } + + repeated_batch = TensorDict( + source=repeated_tensors, + batch_size=(self.batch.batch_size[0] * repeat_times,), + ) + else: + repeated_batch = None + + repeated_non_tensor_batch = {} + for key, val in self.non_tensor_batch.items(): + if interleave: + repeated_non_tensor_batch[key] = np.repeat(val, repeat_times, axis=0) + else: + repeated_non_tensor_batch[key] = np.tile(val, (repeat_times,) + (1,) * (val.ndim - 1)) + + return DataProto( + batch=repeated_batch, + non_tensor_batch=repeated_non_tensor_batch, + meta_info=self.meta_info, + ) + + +import ray + + +@dataclass +class DataProtoFuture: + """ + DataProtoFuture aims to eliminate actual data fetching on driver. By doing so, the driver doesn't have to wait + for data so that asynchronous execution becomes possible. + DataProtoFuture contains a list of futures from another WorkerGroup of size world_size. + - collect_fn is a Callable that reduces the list of futures to a DataProto + - dispatch_fn is a Callable that partitions the DataProto into a list of DataProto of size world_size and then select + + Potential issue: we can optimize dispatch_fn(collect_fn) such that only needed data is fetched on destination + - DataProtoFuture only supports directly passing from the output of a method to another input. You can't perform any + operation on the DataProtoFuture in driver. + """ + collect_fn: Callable + futures: List[ray.ObjectRef] + dispatch_fn: Callable = None + + @staticmethod + def concat(data: List[ray.ObjectRef]) -> 'DataProtoFuture': + output = DataProtoFuture(collect_fn=DataProto.concat, futures=data) + return output + + def chunk(self, chunks: int) -> List['DataProtoFuture']: + from functools import partial + + arg_future_lst = [] + for i in range(chunks): + # note that we can't directly pass i and chunks + def dispatch_fn(x, i, chunks): + return x.chunk(chunks=chunks)[i] + + arg_future = DataProtoFuture(collect_fn=self.collect_fn, + dispatch_fn=partial(dispatch_fn, i=i, chunks=chunks), + futures=self.futures) + arg_future_lst.append(arg_future) + return arg_future_lst + + def get(self): + output = ray.get(self.futures) # dp_size. + for o in output: + assert isinstance(o, DataProto) + output = self.collect_fn(output) # select dp, concat + if self.dispatch_fn is not None: + output = self.dispatch_fn(output) # split in batch dim, select using dp + return output diff --git a/code/RL_model/verl/Search-R1/wandb/debug-internal.log b/code/RL_model/verl/Search-R1/wandb/debug-internal.log new file mode 100644 index 0000000000000000000000000000000000000000..303163b4a3ed9a27addcc89f564458d66d92cea4 --- /dev/null +++ b/code/RL_model/verl/Search-R1/wandb/debug-internal.log @@ -0,0 +1,6 @@ +{"time":"2026-02-01T20:27:26.269116545-05:00","level":"INFO","msg":"stream: starting","core version":"0.23.1"} +{"time":"2026-02-01T20:27:27.692526697-05:00","level":"INFO","msg":"stream: created new stream","id":"lly0j9zs"} +{"time":"2026-02-01T20:27:27.692680073-05:00","level":"INFO","msg":"handler: started","stream_id":"lly0j9zs"} +{"time":"2026-02-01T20:27:27.695494454-05:00","level":"INFO","msg":"stream: started","id":"lly0j9zs"} +{"time":"2026-02-01T20:27:27.69557747-05:00","level":"INFO","msg":"writer: started","stream_id":"lly0j9zs"} +{"time":"2026-02-01T20:27:27.695701035-05:00","level":"INFO","msg":"sender: started","stream_id":"lly0j9zs"} diff --git a/code/RL_model/verl/Search-R1/wandb/debug.log b/code/RL_model/verl/Search-R1/wandb/debug.log new file mode 100644 index 0000000000000000000000000000000000000000..8df0f98e5da5d448c64dbafc1ef3703811880cd5 --- /dev/null +++ b/code/RL_model/verl/Search-R1/wandb/debug.log @@ -0,0 +1,21 @@ +2026-02-01 20:27:25,874 INFO MainThread:1578907 [wandb_setup.py:_flush():80] Current SDK version is 0.23.1 +2026-02-01 20:27:25,874 INFO MainThread:1578907 [wandb_setup.py:_flush():80] Configure stats pid to 1578907 +2026-02-01 20:27:25,875 INFO MainThread:1578907 [wandb_setup.py:_flush():80] Loading settings from /home/mshahidul/.config/wandb/settings +2026-02-01 20:27:25,875 INFO MainThread:1578907 [wandb_setup.py:_flush():80] Loading settings from /data/home_beta/mshahidul/readctrl/code/RL_model/verl/Search-R1/wandb/settings +2026-02-01 20:27:25,875 INFO MainThread:1578907 [wandb_setup.py:_flush():80] Loading settings from environment variables +2026-02-01 20:27:25,875 INFO MainThread:1578907 [wandb_init.py:setup_run_log_directory():714] Logging user logs to /data/home_beta/mshahidul/readctrl/code/RL_model/verl/Search-R1/wandb/run-20260201_202725-lly0j9zs/logs/debug.log +2026-02-01 20:27:25,875 INFO MainThread:1578907 [wandb_init.py:setup_run_log_directory():715] Logging internal logs to /data/home_beta/mshahidul/readctrl/code/RL_model/verl/Search-R1/wandb/run-20260201_202725-lly0j9zs/logs/debug-internal.log +2026-02-01 20:27:25,876 INFO MainThread:1578907 [wandb_init.py:init():841] calling init triggers +2026-02-01 20:27:25,876 INFO MainThread:1578907 [wandb_init.py:init():846] wandb.init called with sweep_config: {} +config: {'data': {'tokenizer': None, 'train_files': '/home/mshahidul/readctrl/code/RL_model/verl/Search-R1/dataset/train.parquet', 'val_files': '/home/mshahidul/readctrl/code/RL_model/verl/Search-R1/dataset/test.parquet', 'train_data_num': None, 'val_data_num': None, 'prompt_key': 'prompt', 'max_prompt_length': 4096, 'max_response_length': 1024, 'max_start_length': 256, 'max_obs_length': 512, 'train_batch_size': 128, 'val_batch_size': 64, 'return_raw_input_ids': False, 'return_raw_chat': False, 'shuffle_train_dataloader': True}, 'actor_rollout_ref': {'hybrid_engine': True, 'model': {'path': 'Qwen/Qwen3-4B-Instruct-2507', 'external_lib': None, 'override_config': {}, 'enable_gradient_checkpointing': True, 'use_remove_padding': False}, 'actor': {'strategy': 'fsdp', 'ppo_mini_batch_size': 64, 'ppo_micro_batch_size': 64, 'use_dynamic_bsz': False, 'ppo_max_token_len_per_gpu': 16384, 'grad_clip': 1.0, 'state_masking': False, 'clip_ratio': 0.2, 'entropy_coeff': 0.001, 'use_kl_loss': False, 'kl_loss_coef': 0.001, 'kl_loss_type': 'low_var_kl', 'ppo_epochs': 1, 'shuffle': False, 'ulysses_sequence_parallel_size': 1, 'optim': {'lr': 1e-06, 'lr_warmup_steps_ratio': 0.0, 'min_lr_ratio': None, 'warmup_style': 'constant', 'total_training_steps': 1005}, 'fsdp_config': {'wrap_policy': {'min_num_params': 0}, 'param_offload': True, 'grad_offload': False, 'optimizer_offload': True, 'fsdp_size': -1}, 'ppo_micro_batch_size_per_gpu': 16}, 'ref': {'fsdp_config': {'param_offload': True, 'wrap_policy': {'min_num_params': 0}, 'fsdp_size': -1}, 'log_prob_micro_batch_size': 64, 'log_prob_use_dynamic_bsz': False, 'log_prob_max_token_len_per_gpu': 16384, 'ulysses_sequence_parallel_size': 1}, 'rollout': {'name': 'vllm', 'temperature': 1.0, 'top_k': -1, 'top_p': 0.95, 'prompt_length': 4096, 'response_length': 1024, 'dtype': 'bfloat16', 'gpu_memory_utilization': 0.4, 'ignore_eos': False, 'enforce_eager': True, 'free_cache_engine': True, 'load_format': 'dummy_dtensor', 'tensor_model_parallel_size': 1, 'max_num_batched_tokens': 8192, 'max_num_seqs': 1024, 'log_prob_micro_batch_size': 64, 'log_prob_use_dynamic_bsz': False, 'log_prob_max_token_len_per_gpu': 16384, 'do_sample': True, 'n': 1, 'n_agent': 1}}, 'critic': {'strategy': 'fsdp', 'optim': {'lr': 1e-05, 'lr_warmup_steps_ratio': 0.0, 'min_lr_ratio': None, 'warmup_style': 'constant', 'total_training_steps': 1005}, 'model': {'path': '~/models/deepseek-llm-7b-chat', 'tokenizer_path': 'Qwen/Qwen3-4B-Instruct-2507', 'override_config': {}, 'external_lib': None, 'enable_gradient_checkpointing': False, 'use_remove_padding': False, 'fsdp_config': {'param_offload': False, 'grad_offload': False, 'optimizer_offload': False, 'wrap_policy': {'min_num_params': 0}, 'fsdp_size': -1}}, 'ppo_mini_batch_size': 64, 'ppo_micro_batch_size': 64, 'forward_micro_batch_size': 64, 'use_dynamic_bsz': False, 'ppo_max_token_len_per_gpu': 32768, 'forward_max_token_len_per_gpu': 32768, 'ulysses_sequence_parallel_size': 1, 'ppo_epochs': 1, 'shuffle': False, 'grad_clip': 1.0, 'cliprange_value': 0.5}, 'reward_model': {'enable': False, 'strategy': 'fsdp', 'model': {'input_tokenizer': 'Qwen/Qwen3-4B-Instruct-2507', 'path': '~/models/FsfairX-LLaMA3-RM-v0.1', 'external_lib': None, 'use_remove_padding': False, 'fsdp_config': {'min_num_params': 0, 'param_offload': False}}, 'micro_batch_size': 64, 'max_length': None, 'ulysses_sequence_parallel_size': 1, 'use_dynamic_bsz': False, 'forward_max_token_len_per_gpu': 32768, 'structure_format_score': 0, 'final_format_score': 0, 'retrieval_score': 0}, 'retriever': {'url': 'http://127.0.0.1:8000/retrieve', 'topk': 3}, 'algorithm': {'gamma': 1.0, 'lam': 1.0, 'adv_estimator': 'grpo', 'no_think_rl': False, 'kl_penalty': 'kl', 'kl_ctrl': {'type': 'fixed', 'kl_coef': 0.001}, 'state_masking': {'start_state_marker': '', 'end_state_marker': ''}}, 'trainer': {'total_epochs': 15, 'total_training_steps': 1005, 'project_name': '', 'experiment_name': 'llm_guard_3B_10k_v2', 'logger': ['wandb'], 'nnodes': 1, 'n_gpus_per_node': 2, 'save_freq': 100, 'test_freq': 50, 'critic_warmup': 0, 'default_hdfs_dir': '~/experiments/gsm8k/ppo/llm_guard_3B_10k_v2', 'default_local_dir': 'verl_checkpoints/llm_guard_3B_10k_v2'}, 'max_turns': 1, 'do_search': False, '_wandb': {}} +2026-02-01 20:27:25,876 INFO MainThread:1578907 [wandb_init.py:init():889] starting backend +2026-02-01 20:27:26,251 INFO MainThread:1578907 [wandb_init.py:init():892] sending inform_init request +2026-02-01 20:27:26,261 INFO MainThread:1578907 [wandb_init.py:init():900] backend started and connected +2026-02-01 20:27:26,270 INFO MainThread:1578907 [wandb_init.py:init():970] updated telemetry +2026-02-01 20:27:26,293 INFO MainThread:1578907 [wandb_init.py:init():994] communicating run to backend with 90.0 second timeout +2026-02-01 20:27:27,908 INFO MainThread:1578907 [wandb_init.py:init():1041] starting run threads in backend +2026-02-01 20:27:28,715 INFO MainThread:1578907 [wandb_run.py:_console_start():2521] atexit reg +2026-02-01 20:27:28,716 INFO MainThread:1578907 [wandb_run.py:_redirect():2369] redirect: wrap_raw +2026-02-01 20:27:28,716 INFO MainThread:1578907 [wandb_run.py:_redirect():2438] Wrapping output streams. +2026-02-01 20:27:28,716 INFO MainThread:1578907 [wandb_run.py:_redirect():2461] Redirects installed. +2026-02-01 20:27:28,726 INFO MainThread:1578907 [wandb_init.py:init():1081] run started, returning control to user process diff --git a/code/RL_model/verl/verl_train/tests/experimental/agent_loop/agent_utils.py b/code/RL_model/verl/verl_train/tests/experimental/agent_loop/agent_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..6222a29738b8b30de58a5cef6780493bd08c38ec --- /dev/null +++ b/code/RL_model/verl/verl_train/tests/experimental/agent_loop/agent_utils.py @@ -0,0 +1,92 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import ray +from omegaconf import DictConfig + +from verl.experimental.agent_loop import AgentLoopManager +from verl.single_controller.ray import RayClassWithInitArgs, RayWorkerGroup +from verl.single_controller.ray.base import create_colocated_worker_cls +from verl.trainer.ppo.ray_trainer import ResourcePoolManager, Role +from verl.workers.fsdp_workers import ActorRolloutRefWorker, AsyncActorRolloutRefWorker, RewardModelWorker + + +def init_agent_loop_manager(config: DictConfig) -> AgentLoopManager | RayWorkerGroup: + # =========================== 1. Create hybrid ActorRollout workers =========================== + actor_rollout_cls = ( + AsyncActorRolloutRefWorker if config.actor_rollout_ref.rollout.mode == "async" else ActorRolloutRefWorker + ) + role_worker_mapping = { + Role.ActorRollout: ray.remote(actor_rollout_cls), + } + if config.reward_model.enable: + role_worker_mapping[Role.RewardModel] = ray.remote(RewardModelWorker) + + global_pool_id = "global_pool" + resource_pool_spec = { + global_pool_id: [config.trainer.n_gpus_per_node] * config.trainer.nnodes, + } + mapping = { + Role.ActorRollout: global_pool_id, + } + if config.reward_model.enable_resource_pool: + mapping[Role.RewardModel] = "reward_pool" + if config.reward_model.n_gpus_per_node <= 0: + raise ValueError("config.reward_model.n_gpus_per_node must be greater than 0") + if config.reward_model.nnodes <= 0: + raise ValueError("config.reward_model.nnodes must be greater than 0") + + reward_pool = [config.reward_model.n_gpus_per_node] * config.reward_model.nnodes + resource_pool_spec["reward_pool"] = reward_pool + resource_pool_manager = ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=mapping) + resource_pool_manager.create_resource_pool() + resource_pool_to_cls = {pool: {} for pool in resource_pool_manager.resource_pool_dict.values()} + + # create actor and rollout + resource_pool = resource_pool_manager.get_resource_pool(Role.ActorRollout) + actor_rollout_cls = RayClassWithInitArgs( + cls=role_worker_mapping[Role.ActorRollout], config=config.actor_rollout_ref, role="actor_rollout" + ) + resource_pool_to_cls[resource_pool]["actor_rollout"] = actor_rollout_cls + + if config.reward_model.enable: + # we create a RM here + resource_pool = resource_pool_manager.get_resource_pool(Role.RewardModel) + rm_cls = RayClassWithInitArgs(role_worker_mapping[Role.RewardModel], config=config.reward_model) + resource_pool_to_cls[resource_pool]["rm"] = rm_cls + + all_wg = {} + for resource_pool, class_dict in resource_pool_to_cls.items(): + worker_dict_cls = create_colocated_worker_cls(class_dict=class_dict) + wg_dict = RayWorkerGroup(resource_pool=resource_pool, ray_cls_with_init=worker_dict_cls) + spawn_wg = wg_dict.spawn(prefix_set=class_dict.keys()) + all_wg.update(spawn_wg) + actor_rollout_wg = all_wg["actor_rollout"] + actor_rollout_wg.init_model() + + if config.actor_rollout_ref.rollout.mode == "sync": + raise ValueError("Agent loop tests require async rollout mode. Please set rollout.mode=async.") + + if config.reward_model.enable_resource_pool and config.reward_model.enable: + rm_resource_pool = resource_pool_manager.get_resource_pool(Role.RewardModel) + else: + rm_resource_pool = None + # =========================== 2. Create AgentLoopManager =========================== + agent_loop_manager = AgentLoopManager( + config=config, + worker_group=actor_rollout_wg, + rm_resource_pool=rm_resource_pool, + ) + + return agent_loop_manager diff --git a/code/RL_model/verl/verl_train/tests/experimental/agent_loop/qwen_vl_tool_chat_template.jinja2 b/code/RL_model/verl/verl_train/tests/experimental/agent_loop/qwen_vl_tool_chat_template.jinja2 new file mode 100644 index 0000000000000000000000000000000000000000..9fea57ff86b54917ff806a28b3617bb79517c494 --- /dev/null +++ b/code/RL_model/verl/verl_train/tests/experimental/agent_loop/qwen_vl_tool_chat_template.jinja2 @@ -0,0 +1,150 @@ +{% set image_count = namespace(value=0) %} +{% set video_count = namespace(value=0) %} +{%- if tools %} +{{- '<|im_start|>system\n' }} +{%- if messages[0]['role'] == 'system' %} +{%- if messages[0]['content'] is string %} +{{- messages[0]['content'] }} +{%- else %} +{{- messages[0]['content'][0]['text'] }} +{%- endif %} +{%- else %} +{{- 'You are a helpful assistant.' }} +{%- endif %} +{{- "\n\n# Tools\n\nYou may call one or more functions to assist with the user query.\n\nYou are provided with function signatures within XML tags:\n" }} +{%- for tool in tools %} +{{- "\n" }} +{{- tool | tojson }} +{%- endfor %} +{{- "\n\n\nFor each function call, return a json object with function name and arguments within XML tags:\n\n{\"name\": , \"arguments\": }\n<|im_end|>\n" }} +{% for message in messages %} +{% if message['role'] != 'system' or loop.first == false %} +{%- if (message.role == "user") or (message.role == "system" and not loop.first) or (message.role == "assistant" and not message.tool_calls) %} +<|im_start|>{{ message['role'] }} +{% if message['content'] is string %} +{{ message['content'] }}<|im_end|> +{% else %} +{% for content in message['content'] %} +{% if content['type'] == 'image' or 'image' in content or 'image_url' in content %} +{% set image_count.value = image_count.value + 1 %} +{% if add_vision_id %}Picture {{ image_count.value }}: {% endif %}<|vision_start|><|image_pad|><|vision_end|> +{% elif content['type'] == 'video' or 'video' in content %} +{% set video_count.value = video_count.value + 1 %} +{% if add_vision_id %}Video {{ video_count.value }}: {% endif %}<|vision_start|><|video_pad|><|vision_end|> +{% elif 'text' in content %} +{{ content['text'] }} +{% endif %} +{% endfor %}<|im_end|> +{% endif %} +{%- elif message.role == "assistant" %} +{{- '<|im_start|>' + message.role }} +{%- if message.content %} +{{- '\n' + message.content }} +{%- endif %} +{%- for tool_call in message.tool_calls %} +{%- if tool_call.function is defined %} +{%- set tool_call = tool_call.function %} +{%- endif %} +{{- '\n\n{"name": "' }} +{{- tool_call.name }} +{{- '", "arguments": ' }} +{{- tool_call.arguments | tojson }} +{{- '}\n' }} +{%- endfor %} +{{- '<|im_end|>\n' }} +{%- elif message.role == "tool" %} +{%- if (loop.index0 == 0) or (messages[loop.index0 - 1].role != "tool") %} +{{- '<|im_start|>user' }} +{%- endif %} +{{- '\n\n' }} +{% if message['content'] is string %} +{{ message.content }} +{% else %} +{% for content in message['content'] %} +{% if content['type'] == 'image' or 'image' in content or 'image_url' in content %} +{% set image_count.value = image_count.value + 1 %} +{% if add_vision_id %}Picture {{ image_count.value }}: {% endif %}<|vision_start|><|image_pad|><|vision_end|> +{% elif content['type'] == 'video' or 'video' in content %} +{% set video_count.value = video_count.value + 1 %} +{% if add_vision_id %}Video {{ video_count.value }}: {% endif %}<|vision_start|><|video_pad|><|vision_end|> +{% elif content['type'] == 'text' or 'text' in content %} +{{ content['text'] }} +{% endif %} +{% endfor %} +{% endif %} +{{- '\n' }} +{%- if loop.last or (messages[loop.index0 + 1].role != "tool") %} +{{- '<|im_end|>\n' }} +{%- endif %} +{%- endif %} +{% endif %} +{% endfor %} +{%- else %} +{% for message in messages %} +{% if loop.first and message['role'] != 'system' %} +<|im_start|>system +You are a helpful assistant.<|im_end|> +{% endif %} +{%- if (message.role == "user") or (message.role == "system" and not loop.first) or (message.role == "assistant" and not message.tool_calls) %} +<|im_start|>{{ message['role'] }} +{% if message['content'] is string %} +{{ message['content'] }}<|im_end|> +{% else %} +{% for content in message['content'] %} +{% if content['type'] == 'image' or 'image' in content or 'image_url' in content %} +{% set image_count.value = image_count.value + 1 %} +{% if add_vision_id %}Picture {{ image_count.value }}: {% endif %}<|vision_start|><|image_pad|><|vision_end|> +{% elif content['type'] == 'video' or 'video' in content %} +{% set video_count.value = video_count.value + 1 %} +{% if add_vision_id %}Video {{ video_count.value }}: {% endif %}<|vision_start|><|video_pad|><|vision_end|> +{% elif 'text' in content %} +{{ content['text'] }} +{% endif %} +{% endfor %}<|im_end|> +{% endif %} +{%- elif message.role == "assistant" %} +{{- '<|im_start|>' + message.role }} +{%- if message.content %} +{{- '\n' + message.content }} +{%- endif %} +{%- for tool_call in message.tool_calls %} +{%- if tool_call.function is defined %} +{%- set tool_call = tool_call.function %} +{%- endif %} +{{- '\n\n{"name": "' }} +{{- tool_call.name }} +{{- '", "arguments": ' }} +{{- tool_call.arguments | tojson }} +{{- '}\n' }} +{%- endfor %} +{{- '<|im_end|>\n' }} +{%- elif message.role == "tool" %} +{%- if (loop.index0 == 0) or (messages[loop.index0 - 1].role != "tool") %} +{{- '<|im_start|>user' }} +{%- endif %} +{{- '\n\n' }} +{% if message['content'] is string %} +{{ message.content }} +{% else %} +{% for content in message['content'] %} +{% if content['type'] == 'image' or 'image' in content or 'image_url' in content %} +{% set image_count.value = image_count.value + 1 %} +{% if add_vision_id %}Picture {{ image_count.value }}: {% endif %}<|vision_start|><|image_pad|><|vision_end|> +{% elif content['type'] == 'video' or 'video' in content %} +{% set video_count.value = video_count.value + 1 %} +{% if add_vision_id %}Video {{ video_count.value }}: {% endif %}<|vision_start|><|video_pad|><|vision_end|> +{% elif content['type'] == 'text' or 'text' in content %} +{{ content['text'] }} +{% endif %} +{% endfor %} +{% endif %} +{{- '\n' }} +{%- if loop.last or (messages[loop.index0 + 1].role != "tool") %} +{{- '<|im_end|>\n' }} +{%- endif %} +{%- endif %} +{% endfor %} +{%- endif %} +{% if add_generation_prompt %} +<|im_start|>assistant +{% endif %} \ No newline at end of file diff --git a/code/RL_model/verl/verl_train/tests/experimental/agent_loop/test_basic_agent_loop.py b/code/RL_model/verl/verl_train/tests/experimental/agent_loop/test_basic_agent_loop.py new file mode 100644 index 0000000000000000000000000000000000000000..7cb55bde48f1d58451b9f29a0999150f74922ca7 --- /dev/null +++ b/code/RL_model/verl/verl_train/tests/experimental/agent_loop/test_basic_agent_loop.py @@ -0,0 +1,454 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import json +import os +from typing import Any + +import numpy as np +import pytest +import ray +from omegaconf import DictConfig +from transformers.utils import get_json_schema + +from tests.experimental.agent_loop.agent_utils import init_agent_loop_manager +from verl.checkpoint_engine import CheckpointEngineManager +from verl.experimental.agent_loop import AgentLoopManager +from verl.experimental.agent_loop.agent_loop import get_trajectory_info +from verl.protocol import DataProto +from verl.tools.base_tool import BaseTool, OpenAIFunctionToolSchema +from verl.tools.schemas import ToolResponse +from verl.trainer.ppo.reward import compute_reward, load_reward_manager +from verl.utils import hf_tokenizer + + +@pytest.fixture +def init_config() -> DictConfig: + from hydra import compose, initialize_config_dir + + with initialize_config_dir(config_dir=os.path.abspath("verl/trainer/config")): + config = compose( + config_name="ppo_trainer", + overrides=[ + "actor_rollout_ref.actor.use_dynamic_bsz=true", + # test sleep/wake_up with fsdp offload + "actor_rollout_ref.actor.fsdp_config.param_offload=True", + "actor_rollout_ref.actor.fsdp_config.optimizer_offload=True", + "reward_model.reward_manager=dapo", + "+reward_model.reward_kwargs.overlong_buffer_cfg.enable=False", + "+reward_model.reward_kwargs.overlong_buffer_cfg.len=3072", + "+reward_model.reward_kwargs.max_resp_len=4096", + ], + ) + + model_path = os.path.expanduser("~/models/Qwen/Qwen2.5-1.5B-Instruct") + config.actor_rollout_ref.model.path = model_path + config.actor_rollout_ref.rollout.name = os.environ["ROLLOUT_NAME"] + config.actor_rollout_ref.rollout.mode = "async" + config.actor_rollout_ref.rollout.enforce_eager = True + config.actor_rollout_ref.rollout.prompt_length = 4096 + config.actor_rollout_ref.rollout.response_length = 4096 + config.actor_rollout_ref.rollout.n = 4 + config.actor_rollout_ref.rollout.agent.num_workers = 2 + config.actor_rollout_ref.rollout.skip_tokenizer_init = True + + return config + + +def test_single_turn(init_config): + ray.init( + runtime_env={ + "env_vars": { + "TOKENIZERS_PARALLELISM": "true", + "NCCL_DEBUG": "WARN", + "VLLM_LOGGING_LEVEL": "INFO", + "VLLM_USE_V1": "1", + } + } + ) + + agent_loop_manager = AgentLoopManager(init_config) + tokenizer = hf_tokenizer(init_config.actor_rollout_ref.model.path) + reward_fn = load_reward_manager( + init_config, tokenizer, num_examine=0, **init_config.reward_model.get("reward_kwargs", {}) + ) + + raw_prompts = [ + [ + { + "role": "user", + "content": "Let's play a role playing game. Your name is Alice, your favorite color is blue.", + } + ], + [{"role": "user", "content": "Let's play a role playing game. Your name is Bob, your favorite color is red."}], + ] + batch = DataProto( + non_tensor_batch={ + "raw_prompt": np.array(raw_prompts), + "agent_name": np.array(["single_turn_agent"] * len(raw_prompts)), + "data_source": np.array(["openai/gsm8k"] * len(raw_prompts)), + "reward_model": np.array([{"style": "rule", "ground_truth": "1.0"}] * len(raw_prompts)), + }, + ) + n = init_config.actor_rollout_ref.rollout.n + batch = batch.repeat(n) + result = agent_loop_manager.generate_sequences(prompts=batch) + assert len(result) == len(raw_prompts) * n + + # check result + seq_len = result.batch["prompts"].size(1) + result.batch["responses"].size(1) + assert result.batch["input_ids"].size(1) == seq_len + assert result.batch["attention_mask"].size(1) == seq_len + assert result.batch["position_ids"].size(1) == seq_len + + if init_config.actor_rollout_ref.rollout.calculate_log_probs: + assert result.batch["rollout_log_probs"].size(1) == result.batch["responses"].size(1) + + # check compute score + assert result.batch["rm_scores"].shape == result.batch["responses"].shape + reward_tensor, reward_extra_info = compute_reward(result, reward_fn) + assert reward_tensor.shape == result.batch["responses"].shape + assert "acc" in reward_extra_info, f"reward_extra_info {reward_extra_info} should contain 'acc'" + assert reward_extra_info["acc"].shape == (len(result),), f"invalid acc: {reward_extra_info['acc']}" + + # check turns + num_turns = result.non_tensor_batch["__num_turns__"] + assert np.all(num_turns == 2) + + print("Test passed!") + ray.shutdown() + + +class WeatherTool(BaseTool): + def get_current_temperature(self, location: str, unit: str = "celsius"): + """Get current temperature at a location. + + Args: + location: The location to get the temperature for, in the format "City, State, Country". + unit: The unit to return the temperature in. Defaults to "celsius". (choices: ["celsius", "fahrenheit"]) + + Returns: + the temperature, the location, and the unit in a dict + """ + print(f"[DEBUG] get_current_temperature: {location}, {unit}") + return { + "temperature": 26.1, + "location": location, + "unit": unit, + } + + def get_openai_tool_schema(self) -> OpenAIFunctionToolSchema: + schema = get_json_schema(self.get_current_temperature) + return OpenAIFunctionToolSchema(**schema) + + async def execute(self, instance_id: str, parameters: dict[str, Any], **kwargs) -> tuple[ToolResponse, float, dict]: + try: + result = self.get_current_temperature(**parameters) + return ToolResponse(text=json.dumps(result)), 0, {} + except Exception as e: + return ToolResponse(text=str(e)), 0, {} + + +class WeatherToolWithData(BaseTool): + def get_openai_tool_schema(self) -> OpenAIFunctionToolSchema: + schema = get_json_schema(self.get_temperature_date) + return OpenAIFunctionToolSchema(**schema) + + def get_temperature_date(self, location: str, date: str, unit: str = "celsius"): + """Get temperature at a location and date. + + Args: + location: The location to get the temperature for, in the format "City, State, Country". + date: The date to get the temperature for, in the format "Year-Month-Day". + unit: The unit to return the temperature in. Defaults to "celsius". (choices: ["celsius", "fahrenheit"]) + + Returns: + the temperature, the location, the date and the unit in a dict + """ + print(f"[DEBUG] get_temperature_date: {location}, {date}, {unit}") + return { + "temperature": 25.9, + "location": location, + "date": date, + "unit": unit, + } + + async def execute(self, instance_id: str, parameters: dict[str, Any], **kwargs) -> tuple[ToolResponse, float, dict]: + try: + result = self.get_temperature_date(**parameters) + return ToolResponse(text=json.dumps(result)), 0, {} + except Exception as e: + return ToolResponse(text=str(e)), 0, {} + + +def test_tool_agent(init_config): + ray.init( + runtime_env={ + "env_vars": { + "TOKENIZERS_PARALLELISM": "true", + "NCCL_DEBUG": "WARN", + "VLLM_LOGGING_LEVEL": "INFO", + "VLLM_USE_V1": "1", + } + }, + ignore_reinit_error=True, + ) + + # =========================== 1. Init rollout manager =========================== + tool_config = { + "tools": [ + { + "class_name": "tests.experimental.agent_loop.test_basic_agent_loop.WeatherTool", + "config": {"type": "native"}, + }, + { + "class_name": "tests.experimental.agent_loop.test_basic_agent_loop.WeatherToolWithData", + "config": {"type": "native"}, + }, + ] + } + tool_config_path = "/tmp/tool_config.json" + with open(tool_config_path, "w") as f: + json.dump(tool_config, f) + + n = 2 + init_config.actor_rollout_ref.rollout.n = n + init_config.actor_rollout_ref.rollout.multi_turn.tool_config_path = tool_config_path + init_config.actor_rollout_ref.rollout.multi_turn.max_parallel_calls = 2 + init_config.actor_rollout_ref.rollout.calculate_log_probs = True + agent_loop_manager = AgentLoopManager(init_config) + + # =========================== 2. Generate sequences =========================== + raw_prompts = [ + [ + {"role": "user", "content": "How are you?"}, + ], + [ + {"role": "user", "content": "What's the temperature in Los Angeles now?"}, + ], + [ + {"role": "user", "content": "What's the temperature in New York now?"}, + ], + [ + { + "role": "system", + "content": "You are Qwen, created by Alibaba Cloud. You are a helpful assistant.\n\n" + "Current Date: 2024-09-30", + }, + {"role": "user", "content": "What's the temperature in San Francisco now? How about tomorrow?"}, + ], + ] + batch = DataProto( + non_tensor_batch={ + "raw_prompt": np.array([np.array(prompt) for prompt in raw_prompts], dtype=object), + "agent_name": np.array(["tool_agent"] * len(raw_prompts)), + "data_source": np.array(["openai/gsm8k"] * len(raw_prompts)), + "reward_model": np.array([{"style": "rule", "ground_truth": "1.0"}] * len(raw_prompts)), + }, + ) + batch = batch.repeat(n) + result = agent_loop_manager.generate_sequences(prompts=batch) + assert len(result) == len(raw_prompts) * n + + # Check turns + num_turns = result.non_tensor_batch["__num_turns__"] + print(f"num_turns: {num_turns}") + for i in range(len(num_turns)): + if i // n == 0: + # [user, assistant] + assert num_turns[i] == 2 + else: + # [user, assistant, tool, assistant] + assert num_turns[i] == 4 + + # Check response_mask + tokenizer = hf_tokenizer(init_config.actor_rollout_ref.model.path) + responses = result.batch["responses"] + response_mask = result.batch["response_mask"] + attention_mask = result.batch["attention_mask"] + assert result.batch["rm_scores"].size(1) == responses.size(1) + assert responses.size() == response_mask.size(), f"{responses.size()} != {response_mask.size()}" + assert result.batch["rollout_log_probs"].size(1) == result.batch["responses"].size(1) + + response_length = response_mask.size(1) + for i in range(len(responses)): + # response with tool response + valid_tokens = responses[i][attention_mask[i][-response_length:].bool()] + response_with_obs = tokenizer.decode(valid_tokens) + + # response without tool response + valid_tokens = responses[i][response_mask[i].bool()] + response_without_obs = tokenizer.decode(valid_tokens) + + assert "" not in response_without_obs, ( + f"found in response: {response_without_obs}" + ) + assert "" not in response_without_obs, ( + f"found in response: {response_without_obs}" + ) + print("=========================") + print(response_with_obs) + print("---") + print(response_without_obs) + + print("Test passed!") + ray.shutdown() + + +def test_tool_agent_with_interaction(init_config): + ray.init( + runtime_env={ + "env_vars": { + "TOKENIZERS_PARALLELISM": "true", + "NCCL_DEBUG": "WARN", + "VLLM_LOGGING_LEVEL": "INFO", + "VLLM_USE_V1": "1", + } + } + ) + + # =========================== 1. Init rollout manager =========================== + tool_config = { + "tools": [ + { + "class_name": "tests.experimental.agent_loop.test_basic_agent_loop.WeatherTool", + "config": {"type": "native"}, + }, + { + "class_name": "tests.experimental.agent_loop.test_basic_agent_loop.WeatherToolWithData", + "config": {"type": "native"}, + }, + ] + } + tool_config_path = "/tmp/tool_config.json" + with open(tool_config_path, "w") as f: + json.dump(tool_config, f) + + interaction_config = { + "interaction": [ + {"name": "weather", "class_name": "verl.interactions.weather_interaction.WeatherInteraction", "config": {}} + ] + } + interaction_config_path = "/tmp/interaction_config.json" + with open(interaction_config_path, "w") as f: + json.dump(interaction_config, f) + + n = 2 + init_config.actor_rollout_ref.rollout.n = n + init_config.actor_rollout_ref.rollout.multi_turn.tool_config_path = tool_config_path + init_config.actor_rollout_ref.rollout.multi_turn.interaction_config_path = interaction_config_path + init_config.actor_rollout_ref.rollout.multi_turn.max_parallel_calls = 2 + agent_loop_manager = init_agent_loop_manager(init_config) + checkpoint_manager = CheckpointEngineManager( + backend=init_config.actor_rollout_ref.rollout.checkpoint_engine.backend, + trainer=agent_loop_manager.worker_group, + replicas=agent_loop_manager.rollout_replicas, + ) + checkpoint_manager.sleep_replicas() + checkpoint_manager.update_weights() + + # =========================== 2. Generate sequences =========================== + raw_prompts = [ + [ + {"role": "user", "content": "How are you?"}, + ], + [ + {"role": "user", "content": "What's the temperature in Los Angeles now?"}, + ], + [ + {"role": "user", "content": "What's the temperature in New York now?"}, + ], + [ + { + "role": "system", + "content": "You are Qwen, created by Alibaba Cloud. You are a helpful assistant.\n\n" + "Current Date: 2024-09-30", + }, + {"role": "user", "content": "What's the temperature in San Francisco now? How about tomorrow?"}, + ], + ] + batch = DataProto( + non_tensor_batch={ + "raw_prompt": np.array([np.array(prompt) for prompt in raw_prompts], dtype=object), + "agent_name": np.array(["tool_agent"] * len(raw_prompts)), + "data_source": np.array(["openai/gsm8k"] * len(raw_prompts)), + "reward_model": np.array([{"style": "rule", "ground_truth": "1.0"}] * len(raw_prompts)), + "extra_info": np.array( + [ + {"interaction_kwargs": {"name": "weather"}}, + {"interaction_kwargs": {"name": "weather"}}, + {"interaction_kwargs": {"name": "weather"}}, + {"interaction_kwargs": {"name": "weather"}}, + ] + ), + }, + ) + batch = batch.repeat(n) + result = agent_loop_manager.generate_sequences(prompts=batch) + assert len(result) == len(raw_prompts) * n + + # Check turns + num_turns = result.non_tensor_batch["__num_turns__"] + print(f"num_turns: {num_turns}") + for i in range(len(num_turns)): + if i // n == 0: + # [user, assistant, user] + assert num_turns[i] == 3 + else: + # [user, assistant, tool, assistant, user] + assert num_turns[i] == 5 + + # Check response_mask + tokenizer = hf_tokenizer(init_config.actor_rollout_ref.model.path) + responses = result.batch["responses"] + response_mask = result.batch["response_mask"] + attention_mask = result.batch["attention_mask"] + assert responses.size() == response_mask.size(), f"{responses.size()} != {response_mask.size()}" + response_length = response_mask.size(1) + + for i in range(len(responses)): + # response with tool response + valid_tokens = responses[i][attention_mask[i][-response_length:].bool()] + response_with_obs = tokenizer.decode(valid_tokens) + + # response without tool response + valid_tokens = responses[i][response_mask[i].bool()] + response_without_obs = tokenizer.decode(valid_tokens) + + assert "\udb82\udc89" not in response_without_obs, f"found \udb82\udc89 in response: {response_without_obs}" + assert "\udb82\udc8a" not in response_without_obs, f"found \udb82\udc8a in response: {response_without_obs}" + print("=========================") + print(response_with_obs) + print("---") + print(response_without_obs) + + print("Test passed!") + ray.shutdown() + + +@pytest.mark.asyncio +async def test_get_trajectory_info(): + """Tests the get_trajectory_info method.""" + # Initialize the class to set up class-level attributes + step = 10 + index = [1, 1, 3, 3] + expected_info = [ + {"step": step, "sample_index": 1, "rollout_n": 0, "validate": False}, + {"step": step, "sample_index": 1, "rollout_n": 1, "validate": False}, + {"step": step, "sample_index": 3, "rollout_n": 0, "validate": False}, + {"step": step, "sample_index": 3, "rollout_n": 1, "validate": False}, + ] + + trajectory_info = await get_trajectory_info(step, index, validate=False) + + assert trajectory_info == expected_info diff --git a/code/RL_model/verl/verl_train/tests/experimental/agent_loop/test_gpt_oss_tool_parser.py b/code/RL_model/verl/verl_train/tests/experimental/agent_loop/test_gpt_oss_tool_parser.py new file mode 100644 index 0000000000000000000000000000000000000000..a58c977a1b0d4eac2cbd542aab1fd0b8b691f1df --- /dev/null +++ b/code/RL_model/verl/verl_train/tests/experimental/agent_loop/test_gpt_oss_tool_parser.py @@ -0,0 +1,34 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import pytest +from transformers import AutoTokenizer + +from verl.experimental.agent_loop.tool_parser import GptOssToolParser + + +@pytest.mark.asyncio +@pytest.mark.skip(reason="local test only") +async def test_gpt_oss_tool_parser(): + example_text = """ +<|start|>assistant<|channel|>commentary to=functions.get_current_weather \ +<|constrain|>json<|message|>{"location": "Tokyo"}<|call|> +<|start|>functions.get_current_weather to=assistant<|channel|>commentary<|message|>\ +{ "temperature": 20, "sunny": true }<|end|>""" + tokenizer = AutoTokenizer.from_pretrained("openai/gpt-oss-20b") + response_ids = tokenizer.encode(example_text) + tool_parser = GptOssToolParser(tokenizer) + _, function_calls = await tool_parser.extract_tool_calls(response_ids) + assert len(function_calls) == 1 + assert function_calls[0].name == "get_current_weather" + assert function_calls[0].arguments == '{"location": "Tokyo"}' diff --git a/code/RL_model/verl/verl_train/tests/experimental/agent_loop/test_multi_modal.py b/code/RL_model/verl/verl_train/tests/experimental/agent_loop/test_multi_modal.py new file mode 100644 index 0000000000000000000000000000000000000000..7810c7a4599c3016581a210158b59acc11a86748 --- /dev/null +++ b/code/RL_model/verl/verl_train/tests/experimental/agent_loop/test_multi_modal.py @@ -0,0 +1,570 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import json +import os +from typing import Any + +import numpy as np +import pytest +import ray +from omegaconf import DictConfig +from PIL import Image +from transformers.utils import get_json_schema + +from verl.experimental.agent_loop import AgentLoopManager +from verl.protocol import DataProto +from verl.tools.base_tool import BaseTool, OpenAIFunctionToolSchema +from verl.tools.schemas import ToolResponse +from verl.utils import hf_tokenizer + + +def parse_multi_modal_type(messages: list[dict]) -> str: + message = messages[-1] + if isinstance(message["content"], str): + return "text" + + for content in message["content"]: + if content["type"] == "image": + return "image" + elif content["type"] == "video": + return "video" + + return "text" + + +@pytest.fixture +def init_config() -> DictConfig: + from hydra import compose, initialize_config_dir + + with initialize_config_dir(config_dir=os.path.abspath("verl/trainer/config")): + config = compose( + config_name="ppo_trainer", + overrides=[ + "actor_rollout_ref.actor.use_dynamic_bsz=true", + # test sleep/wake_up with fsdp offload + "actor_rollout_ref.actor.fsdp_config.param_offload=True", + "actor_rollout_ref.actor.fsdp_config.optimizer_offload=True", + ], + ) + + model_path = os.path.expanduser("~/models/Qwen/Qwen2.5-VL-3B-Instruct") + config.actor_rollout_ref.model.path = model_path + config.actor_rollout_ref.rollout.name = os.environ["ROLLOUT_NAME"] + config.actor_rollout_ref.rollout.mode = "async" + config.actor_rollout_ref.rollout.enforce_eager = True + config.actor_rollout_ref.rollout.prompt_length = 10240 + config.actor_rollout_ref.rollout.response_length = 4096 + config.actor_rollout_ref.rollout.n = 4 + config.actor_rollout_ref.rollout.agent.num_workers = 2 + config.actor_rollout_ref.rollout.skip_tokenizer_init = True + + return config + + +class ImageGeneratorTool(BaseTool): + def generate_image(self, description: str, size: str = "256x256"): + """Generate a simple image based on description. + + Args: + description: The description of the image to generate. + size: The size of the image. Defaults to "256x256". (choices: ["256x256", "512x512"]) + + Returns: + A generated image + """ + print(f"[DEBUG] generate_image: {description}, {size}") + # Create a simple colored image for testing + width, height = map(int, size.split("x")) + + # Create different colors based on description + if "red" in description.lower(): + color = (255, 0, 0) + elif "blue" in description.lower(): + color = (0, 0, 255) + elif "green" in description.lower(): + color = (0, 255, 0) + else: + color = (128, 128, 128) # gray + + # Create image + image = Image.new("RGB", (width, height), color) + + # Add some pattern to make it more interesting + for i in range(0, width, 50): + for j in range(0, height, 50): + # Add white squares in a grid pattern + for x in range(i, min(i + 20, width)): + for y in range(j, min(j + 20, height)): + image.putpixel((x, y), (255, 255, 255)) + + return image + + def get_openai_tool_schema(self) -> OpenAIFunctionToolSchema: + schema = get_json_schema(self.generate_image) + return OpenAIFunctionToolSchema(**schema) + + async def execute(self, instance_id: str, parameters: dict[str, Any], **kwargs) -> tuple[ToolResponse, float, dict]: + try: + image = self.generate_image(**parameters) + # Return the PIL Image directly - the framework should handle the conversion + return ToolResponse(image=[image]), 0, {} + except Exception as e: + return ToolResponse(text=str(e)), 0, {} + + +@pytest.mark.flaky(reruns=3) +def test_multimodal_tool_agent(init_config): + """Test agent loop with multimodal tool that returns images using Qwen VL model.""" + ray.shutdown() + ray.init( + runtime_env={ + "env_vars": { + "TOKENIZERS_PARALLELISM": "true", + "NCCL_DEBUG": "WARN", + "VLLM_LOGGING_LEVEL": "INFO", + "VLLM_USE_V1": "1", + } + }, + ignore_reinit_error=True, + ) + + # Add custom chat template to enable tool calling support (same as recipe/deepeyes) + template_path = os.path.join(os.path.dirname(__file__), "qwen_vl_tool_chat_template.jinja2") + with open(template_path, encoding="utf-8") as f: + custom_chat_template = f.read() + + init_config.actor_rollout_ref.model.custom_chat_template = custom_chat_template + + # =========================== 1. Init rollout manager with image tool =========================== + tool_config = { + "tools": [ + { + "class_name": "tests.experimental.agent_loop.test_multi_modal.ImageGeneratorTool", + "config": {"type": "native"}, + }, + ] + } + tool_config_path = "/tmp/multimodal_tool_config.json" + with open(tool_config_path, "w") as f: + json.dump(tool_config, f) + + n = 2 + init_config.actor_rollout_ref.rollout.n = n + init_config.actor_rollout_ref.rollout.multi_turn.tool_config_path = tool_config_path + init_config.actor_rollout_ref.rollout.multi_turn.max_parallel_calls = 1 + init_config.actor_rollout_ref.rollout.multi_turn.max_user_turns = 1 + agent_loop_manager = AgentLoopManager(init_config) + + # =========================== 2. Generate sequences with multimodal prompts =========================== + raw_prompts = [ + [ + {"role": "user", "content": "How are you?"}, + ], + [ + { + "role": "user", + "content": [ + { + "type": "video", + "video": os.path.expanduser("~/models/hf_data/test-videos/space_woaudio.mp4"), + "min_pixels": 4 * 32 * 32, + "max_pixels": 256 * 32 * 32, + "total_pixels": 4096 * 32 * 32, + }, + { + "type": "text", + "text": "Describe this video. Then you must call the " + "image generator tool to generate a green image for me.", + }, + ], + }, + ], + [ + {"role": "user", "content": "Please generate a red image for me."}, + ], + [ + {"role": "user", "content": "Can you create a blue picture with size 512x512?"}, + ], + [ + { + "role": "system", + "content": ( + "You are Qwen VL, created by Alibaba Cloud. You are a helpful " + "assistant that can generate and analyze images." + ), + }, + {"role": "user", "content": "Generate a green landscape image and describe what you see in it."}, + ], + ] + + batch = DataProto( + non_tensor_batch={ + "raw_prompt": np.array([np.array(prompt) for prompt in raw_prompts], dtype=object), + "agent_name": np.array(["tool_agent"] * len(raw_prompts)), + "data_source": np.array(["openai/gsm8k"] * len(raw_prompts)), + "reward_model": np.array([{"style": "rule", "ground_truth": "1.0"}] * len(raw_prompts)), + }, + ) + batch = batch.repeat(n) + result = agent_loop_manager.generate_sequences(prompts=batch) + assert len(result) == len(raw_prompts) * n + + # Check turns + num_turns = result.non_tensor_batch["__num_turns__"] + multi_modal_inputs = result.non_tensor_batch["multi_modal_inputs"] + print(f"num_turns: {num_turns}") + for i in range(len(num_turns)): + multi_modal_type = parse_multi_modal_type(raw_prompts[i // n]) + if multi_modal_type == "video": + assert "pixel_values_videos" in multi_modal_inputs[i], f"Sample {i} should have pixel_values_videos" + assert "video_grid_thw" in multi_modal_inputs[i], f"Sample {i} should have video_grid_thw" + + if i // n <= 1: + # TODO: prompt with video not generate tool call as expected + # First prompt: "How are you?" - should have 2 turns [user, assistant] + assert num_turns[i] == 2, f"Expected 2 turns but got {num_turns[i]} for sample {i}" + else: + # Tool-calling prompts should have 4 turns [user, assistant, tool, assistant] + assert num_turns[i] == 4, f"Expected 4 turns but got {num_turns[i]} for sample {i}" + assert "pixel_values" in multi_modal_inputs[i], f"Sample {i} should have pixel_values" + assert "image_grid_thw" in multi_modal_inputs[i], f"Sample {i} should have image_grid_thw" + + # Check that images were properly returned in the tool responses + tokenizer = hf_tokenizer(init_config.actor_rollout_ref.model.path) + responses = result.batch["responses"] + response_mask = result.batch["response_mask"] + attention_mask = result.batch["attention_mask"] + assert responses.size() == response_mask.size(), f"{responses.size()} != {response_mask.size()}" + response_length = response_mask.size(1) + + image_found_count = 0 + for i in range(len(responses)): + # response with tool response (including images) + valid_tokens = responses[i][attention_mask[i][-response_length:].bool()] + response_with_obs = tokenizer.decode(valid_tokens) + + # response without tool response + valid_tokens = responses[i][response_mask[i].bool()] + response_without_obs = tokenizer.decode(valid_tokens) + + # Check that tool responses were properly masked out from training + assert "" not in response_without_obs, ( + f"found in response: {response_without_obs}" + ) + assert "" not in response_without_obs, ( + f"found in response: {response_without_obs}" + ) + + # Check that images were included in the full response + if "" in response_with_obs or "image" in response_with_obs.lower(): + image_found_count += 1 + + print("=========================") + print("Response with tool observations:") + print(response_with_obs) + print("---") + print("Response without tool observations:") + print(response_without_obs) + + # Verify that tool-calling responses contained image-related content + print(f"Found {image_found_count} responses with image content out of {len(responses)}") + # We should have at least some image content from the tool-calling prompts + # Note: First prompt might not use tools, so we don't expect 100% image content + expected_tool_calls = sum(1 for i in range(len(num_turns)) if num_turns[i] == 4) + assert image_found_count >= 0, ( + f"No image-related content found, but expected at least some from {expected_tool_calls} tool calls" + ) + + print("Multimodal tool test passed!") + ray.shutdown() + + +def test_multimodal_single_turn_agent(init_config): + """Test single turn agent loop with multimodal inputs using Qwen VL model.""" + ray.init( + runtime_env={ + "env_vars": { + "TOKENIZERS_PARALLELISM": "true", + "NCCL_DEBUG": "WARN", + "VLLM_LOGGING_LEVEL": "INFO", + "VLLM_USE_V1": "1", + } + }, + ignore_reinit_error=True, + ) + + # =========================== 1. Init rollout manager =========================== + n = 2 + init_config.actor_rollout_ref.rollout.n = n + init_config.actor_rollout_ref.rollout.multi_turn.max_parallel_calls = 1 + init_config.actor_rollout_ref.rollout.multi_turn.max_user_turns = 1 + agent_loop_manager = AgentLoopManager(init_config) + + # =========================== 2. Generate sequences with multimodal prompts =========================== + # Create a simple test image + test_image = Image.new("RGB", (256, 256), (100, 150, 200)) + test_image2 = Image.new("RGB", (512, 512), (100, 150, 200)) + + raw_prompts = [ + # text + [ + {"role": "user", "content": "Hello, how are you?"}, + ], + # image + [ + { + "role": "user", + "content": [ + {"type": "image", "image": test_image}, + {"type": "text", "text": "What color is this image?"}, + ], + }, + ], + # system + image + [ + { + "role": "system", + "content": "You are Qwen VL, created by Alibaba Cloud. You are a helpful assistant.", + }, + { + "role": "user", + "content": [ + {"type": "image", "image": test_image2}, + {"type": "text", "text": "Describe this image in detail."}, + ], + }, + ], + # video + [ + { + "role": "user", + "content": [ + { + "type": "video", + "video": os.path.expanduser("~/models/hf_data/test-videos/space_woaudio.mp4"), + "min_pixels": 4 * 32 * 32, + "max_pixels": 256 * 32 * 32, + "total_pixels": 4096 * 32 * 32, + }, + {"type": "text", "text": "Describe this video."}, + ], + }, + ], + ] + + batch = DataProto( + non_tensor_batch={ + "raw_prompt": np.array([np.array(prompt) for prompt in raw_prompts], dtype=object), + "agent_name": np.array(["single_turn_agent"] * len(raw_prompts)), + "data_source": np.array(["openai/gsm8k"] * len(raw_prompts)), + "reward_model": np.array([{"style": "rule", "ground_truth": "1.0"}] * len(raw_prompts)), + }, + ) + + batch = batch.repeat(n) + result = agent_loop_manager.generate_sequences(prompts=batch) + assert len(result) == len(raw_prompts) * n + + # Check turns - all should be single turn (2: user + assistant) + num_turns = result.non_tensor_batch["__num_turns__"] + print(f"num_turns: {num_turns}") + for i in range(len(num_turns)): + assert num_turns[i] == 2, f"Expected 2 turns but got {num_turns[i]} for sample {i}" + + # Verify responses + tokenizer = hf_tokenizer(init_config.actor_rollout_ref.model.path) + prompts = result.batch["prompts"] + responses = result.batch["responses"] + response_mask = result.batch["response_mask"] + input_ids = result.batch["input_ids"] + position_ids = result.batch["position_ids"] + multi_modal_inputs = result.non_tensor_batch["multi_modal_inputs"] + assert responses.size() == response_mask.size(), f"{responses.size()} != {response_mask.size()}" + assert position_ids.size() == (input_ids.size(0), 4, input_ids.size(1)) # (batch_size, 4, seq_len) + + # Check for image pads in prompts + image_pad_count = 0 + for i in range(len(prompts)): + prompt_ids = prompts[i][prompts[i] != tokenizer.pad_token_id].tolist() + prompt_text = tokenizer.decode(prompt_ids) + + # Check if this sample should have image pads (samples with index 1 and 2 in each repeat have images) + sample_idx = i // n + has_image_pad = "<|image_pad|>" in prompt_text or "<|vision_start|>" in prompt_text + + print("=========================") + print(f"Sample {i} (original prompt index: {sample_idx}):") + print(f"Prompt length: {len(prompt_ids)} tokens") + print(f"Has image_pad: {has_image_pad}") + + # Check multi-modal type + multi_modal_type = parse_multi_modal_type(raw_prompts[sample_idx]) + + if multi_modal_type == "text": + assert len(multi_modal_inputs[i]) == 0, f"Sample {i} should not have multi-modal inputs" + elif multi_modal_type == "image": + assert "pixel_values" in multi_modal_inputs[i], f"Sample {i} should have pixel_values" + assert "image_grid_thw" in multi_modal_inputs[i], f"Sample {i} should have image_grid_thw" + else: + assert "pixel_values_videos" in multi_modal_inputs[i], f"Sample {i} should have pixel_values_videos" + assert "video_grid_thw" in multi_modal_inputs[i], f"Sample {i} should have video_grid_thw" + + # Show first 200 chars of prompt + print(f"Prompt text (first 200 chars): {prompt_text[:200]}...") + + for i in range(len(responses)): + valid_tokens = responses[i][response_mask[i].bool()] + response_text = tokenizer.decode(valid_tokens) + print(f"Sample {i} response: {response_text[:100]}...") + + # Verify that we found image pads in multimodal samples + expected_multimodal_samples = 2 * n # 2 prompts with images, repeated n times + print(f"\nFound {image_pad_count} samples with image_pad out of {expected_multimodal_samples} expected") + + print("Single turn multimodal test passed!") + ray.shutdown() + + +def test_multimodal_partial_single_turn_agent(init_config): + """Test partial single turn agent loop with multimodal inputs using Qwen VL model.""" + + # TODO(baiyan): + # see verl/recipe/fully_async_policy/agent_loop/partial_single_turn_agent_loop.py for more details. + # if use_correct_processor=True, the test will pass but the async training will hang, so I disable this test + # for now + + return + + ray.init( + runtime_env={ + "env_vars": { + "TOKENIZERS_PARALLELISM": "true", + "NCCL_DEBUG": "WARN", + "VLLM_LOGGING_LEVEL": "INFO", + "VLLM_USE_V1": "1", + } + }, + ignore_reinit_error=True, + ) + from verl.experimental.fully_async_policy.agent_loop import FullyAsyncAgentLoopManager + + # =========================== 1. Init rollout manager =========================== + n = 2 + init_config.actor_rollout_ref.rollout.n = n + init_config.actor_rollout_ref.rollout.multi_turn.max_parallel_calls = 1 + init_config.actor_rollout_ref.rollout.multi_turn.max_user_turns = 1 + import asyncio + + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + agent_loop_manager = loop.run_until_complete(FullyAsyncAgentLoopManager.create(init_config)) + + # =========================== 2. Generate sequences with multimodal prompts =========================== + # Create a simple test image + test_image = Image.new("RGB", (256, 256), (200, 100, 50)) + test_image2 = Image.new("RGB", (512, 512), (100, 150, 200)) + + raw_prompts = [ + [ + {"role": "user", "content": "What is the capital of France?"}, + ], + [ + { + "role": "user", + "content": [ + {"type": "image", "image": test_image}, + {"type": "text", "text": "What do you see in this image?"}, + ], + }, + ], + [ + { + "role": "system", + "content": "You are Qwen VL, a helpful multimodal assistant.", + }, + { + "role": "user", + "content": [ + {"type": "image", "image": test_image2}, + {"type": "text", "text": "Analyze the colors in this image."}, + ], + }, + ], + ] + + batch = DataProto( + non_tensor_batch={ + "raw_prompt": np.array([np.array(prompt) for prompt in raw_prompts], dtype=object), + "agent_name": np.array(["partial_single_turn_agent"] * len(raw_prompts)), + "data_source": np.array(["openai/gsm8k"] * len(raw_prompts)), + "reward_model": np.array([{"style": "rule", "ground_truth": "1.0"}] * len(raw_prompts)), + }, + ) + + batch = batch.repeat(n) + result = agent_loop_manager.generate_sequences(prompts=batch) + assert len(result) == len(raw_prompts) * n + + # Check turns - all should be single turn (2: user + assistant) + num_turns = result.non_tensor_batch["__num_turns__"] + print(f"num_turns: {num_turns}") + for i in range(len(num_turns)): + assert num_turns[i] == 2, f"Expected 2 turns but got {num_turns[i]} for sample {i}" + + # Verify responses + tokenizer = hf_tokenizer(init_config.actor_rollout_ref.model.path) + prompts = result.batch["prompts"] + responses = result.batch["responses"] + response_mask = result.batch["response_mask"] + assert responses.size() == response_mask.size(), f"{responses.size()} != {response_mask.size()}" + + # Check for image pads in prompts + image_pad_count = 0 + for i in range(len(prompts)): + prompt_ids = prompts[i][prompts[i] != tokenizer.pad_token_id].tolist() + prompt_text = tokenizer.decode(prompt_ids) + + # Check if this sample should have image pads (samples with index 1 and 2 in each repeat have images) + sample_idx = i // n + has_image_pad = "<|image_pad|>" in prompt_text or "<|vision_start|>" in prompt_text + + print("=========================") + print(f"Sample {i} (original prompt index: {sample_idx}):") + print(f"Prompt length: {len(prompt_ids)} tokens") + print(f"Has image_pad: {has_image_pad}") + + if sample_idx != 0: # Samples 1 and 2 should have images + if has_image_pad: + image_pad_count += 1 + # Count the number of image_pad tokens + num_image_pads = prompt_text.count("<|image_pad|>") + print(f"Number of <|image_pad|> tokens: {num_image_pads}") + else: + print("WARNING: Expected image_pad but not found!") + + # Show first 200 chars of prompt + print(f"Prompt text (first 200 chars): {prompt_text[:200]}...") + + for i in range(len(responses)): + valid_tokens = responses[i][response_mask[i].bool()] + response_text = tokenizer.decode(valid_tokens) + print(f"Sample {i} response: {response_text[:100]}...") + + # Verify that we found image pads in multimodal samples + expected_multimodal_samples = 2 * n # 2 prompts with images, repeated n times + print(f"\nFound {image_pad_count} samples with image_pad out of {expected_multimodal_samples} expected") + assert image_pad_count > 0, "No image_pad tokens found in multimodal samples!" + + print("Partial single turn multimodal test passed!") + ray.shutdown() diff --git a/code/RL_model/verl/verl_train/tests/experimental/agent_loop/test_standalone_rollout.py b/code/RL_model/verl/verl_train/tests/experimental/agent_loop/test_standalone_rollout.py new file mode 100644 index 0000000000000000000000000000000000000000..96b7912045ba37bbd18b554841fe899e05c807e1 --- /dev/null +++ b/code/RL_model/verl/verl_train/tests/experimental/agent_loop/test_standalone_rollout.py @@ -0,0 +1,157 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import asyncio +import os + +import pytest +import ray +from omegaconf import DictConfig +from openai import AsyncOpenAI, OpenAI + +from tests.experimental.agent_loop.agent_utils import init_agent_loop_manager +from verl.checkpoint_engine import CheckpointEngineManager +from verl.workers.rollout.replica import get_rollout_replica_class + + +@pytest.fixture +def init_config() -> DictConfig: + from hydra import compose, initialize_config_dir + + with initialize_config_dir(config_dir=os.path.abspath("verl/trainer/config")): + config = compose(config_name="ppo_trainer") + + config.trainer.n_gpus_per_node = 4 + config.trainer.nnodes = 2 + config.actor_rollout_ref.actor.use_dynamic_bsz = True + config.actor_rollout_ref.model.path = os.path.expanduser("~/models/Qwen/Qwen2.5-1.5B-Instruct") + config.actor_rollout_ref.rollout.name = os.environ["ROLLOUT_NAME"] + config.actor_rollout_ref.rollout.mode = "async" + config.actor_rollout_ref.rollout.skip_tokenizer_init = False + + return config + + +@pytest.mark.asyncio +@pytest.mark.parametrize("tp_size", [2, 4]) +async def test_standalone_rollout(init_config, tp_size): + """Test standalone rollout single node and multi nodes.""" + ray.init( + runtime_env={ + "env_vars": { + "TOKENIZERS_PARALLELISM": "true", + "NCCL_DEBUG": "WARN", + "VLLM_LOGGING_LEVEL": "INFO", + "VLLM_USE_V1": "1", + "NCCL_P2P_DISABLE": "1", # disable p2p in L20 + } + } + ) + + init_config.actor_rollout_ref.rollout.tensor_model_parallel_size = tp_size + num_replicas = (init_config.trainer.n_gpus_per_node * init_config.trainer.nnodes) // tp_size + rollout_config = init_config.actor_rollout_ref.rollout + model_config = init_config.actor_rollout_ref.model + + # create standalone rollout server + rollout_server_class = get_rollout_replica_class(init_config.actor_rollout_ref.rollout.name) + rollout_servers = [ + rollout_server_class( + replica_rank=replica_rank, config=rollout_config, model_config=model_config, gpus_per_node=2 + ) + for replica_rank in range(num_replicas) + ] + await asyncio.gather(*[server.init_standalone() for server in rollout_servers]) + + server_handles = [server._server_handle for server in rollout_servers] + server_addresses = [server._server_address for server in rollout_servers] + assert len(server_handles) == num_replicas + assert len(server_addresses) == num_replicas + + os.environ.pop("HTTPS_PROXY", None) + os.environ.pop("HTTP_PROXY", None) + os.environ.pop("NO_PROXY", None) + + client = AsyncOpenAI( + api_key="123-abc", + base_url=f"http://{server_addresses[0]}/v1", + ) + + completion = await client.chat.completions.create( + model=init_config.actor_rollout_ref.model.path, + messages=[{"role": "user", "content": "What can you do?"}], + ) + print(completion.choices[0].message.content) + + ray.shutdown() + + +@pytest.mark.skip(reason="local test only") +def test_hybrid_rollout_with_ep(init_config): + """Test hybrid rollout with expert parallelism, DP=2, TP=4, EP=8.""" + ray.init( + runtime_env={ + "env_vars": { + "TOKENIZERS_PARALLELISM": "true", + "NCCL_DEBUG": "WARN", + "VLLM_LOGGING_LEVEL": "INFO", + "VLLM_USE_V1": "1", + } + } + ) + + model_path = os.path.expanduser("~/models/Qwen/Qwen3-30B-A3B-Instruct-2507") + init_config.actor_rollout_ref.model.path = model_path + + # parallelism config + init_config.actor_rollout_ref.rollout.tensor_model_parallel_size = 2 + init_config.actor_rollout_ref.rollout.data_parallel_size = 4 + init_config.actor_rollout_ref.rollout.expert_parallel_size = 8 + + # 1. init hybrid worker: FSDP+rollout + # - build FSDP model and optimizer + # - offload FSDP model and optimizer, build rollout + # - sleep rollout and load FSDP model and optimizer + agent_loop_manager = init_agent_loop_manager(init_config) + checkpoint_manager = CheckpointEngineManager( + backend=init_config.actor_rollout_ref.rollout.checkpoint_engine.backend, + trainer=agent_loop_manager.worker_group, + replicas=agent_loop_manager.rollout_replicas, + ) + checkpoint_manager.sleep_replicas() + checkpoint_manager.update_weights() + + # 3. test async openai call + server_address = agent_loop_manager.server_addresses[0] + client = OpenAI( + api_key="123-abc", + base_url=f"http://{server_address}/v1", + ) + + smapling_params = { + "temperature": 1.0, + "top_p": 1.0, + "max_tokens": 512, + } + + response = client.chat.completions.create( + model=model_path, + messages=[{"role": "user", "content": "What can you do?"}], + **smapling_params, + ) + + completion = response.choices[0].message.content + print(f"response: {completion}") + + print("Test passed!") + ray.shutdown() diff --git a/code/RL_model/verl/verl_train/tests/experimental/reward_loop/test_agent_loop_reward_manager.py b/code/RL_model/verl/verl_train/tests/experimental/reward_loop/test_agent_loop_reward_manager.py new file mode 100644 index 0000000000000000000000000000000000000000..ef8e6a3da7ca102ff8f64852809cb8a92dc47e45 --- /dev/null +++ b/code/RL_model/verl/verl_train/tests/experimental/reward_loop/test_agent_loop_reward_manager.py @@ -0,0 +1,111 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os + +import ray +from hydra import compose, initialize_config_dir +from torchdata.stateful_dataloader import StatefulDataLoader +from transformers import AutoTokenizer + +from verl.experimental.agent_loop import AgentLoopManager +from verl.protocol import DataProto +from verl.trainer.main_ppo import create_rl_sampler +from verl.utils.dataset.rl_dataset import RLHFDataset, collate_fn + + +def test_agent_loop_reward_manager(): + ray.init( + runtime_env={ + "env_vars": { + "TOKENIZERS_PARALLELISM": "true", + "NCCL_DEBUG": "WARN", + "VLLM_LOGGING_LEVEL": "INFO", + "VLLM_USE_V1": "1", + } + } + ) + with initialize_config_dir(config_dir=os.path.abspath("verl/trainer/config")): + config = compose(config_name="ppo_trainer") + + rollout_model_path = os.path.expanduser("~/models/Qwen/Qwen2.5-0.5B-Instruct") + reward_model_path = os.path.expanduser("~/models/Qwen/Qwen2.5-1.5B-Instruct") + + # actor_rollout_ref config + config.data.return_raw_chat = True + config.data.max_prompt_length = 1024 + config.data.max_response_length = 4096 + config.actor_rollout_ref.model.path = rollout_model_path + config.actor_rollout_ref.actor.use_dynamic_bsz = True + config.actor_rollout_ref.rollout.name = os.getenv("ROLLOUT_NAME", "vllm") + config.actor_rollout_ref.rollout.mode = "async" + config.actor_rollout_ref.rollout.tensor_model_parallel_size = 2 + config.actor_rollout_ref.rollout.gpu_memory_utilization = 0.9 + config.actor_rollout_ref.rollout.enforce_eager = True + config.actor_rollout_ref.rollout.prompt_length = 1024 + config.actor_rollout_ref.rollout.response_length = 4096 + config.actor_rollout_ref.rollout.skip_tokenizer_init = True + config.trainer.n_gpus_per_node = 4 + config.trainer.nnodes = 1 + + config.reward_model.reward_manager = "dapo" + config.reward_model.enable = True + config.reward_model.enable_resource_pool = True + config.reward_model.n_gpus_per_node = 4 + config.reward_model.nnodes = 1 + config.reward_model.model.path = reward_model_path + config.reward_model.rollout.name = os.getenv("ROLLOUT_NAME", "vllm") + config.reward_model.rollout.gpu_memory_utilization = 0.9 + config.reward_model.rollout.tensor_model_parallel_size = 2 + config.reward_model.rollout.skip_tokenizer_init = False + config.reward_model.rollout.prompt_length = 5120 + config.reward_model.rollout.response_length = 4096 + config.custom_reward_function.path = "tests/experimental/reward_loop/reward_fn.py" + config.custom_reward_function.name = "compute_score_gsm8k" + + # 1. init reward model manager + agent_loop_manager = AgentLoopManager(config) + + # 2. init test data + local_folder = os.path.expanduser("~/data/gsm8k/") + data_files = [os.path.join(local_folder, "train.parquet")] + tokenizer = AutoTokenizer.from_pretrained(rollout_model_path) + + dataset = RLHFDataset( + data_files=data_files, + tokenizer=tokenizer, + config=config.data, + processor=None, + ) + + batch_size = 64 + sampler = create_rl_sampler(config.data, dataset) + dataloader = StatefulDataLoader( + dataset=dataset, + batch_size=batch_size, + num_workers=config.data.dataloader_num_workers, + drop_last=True, + collate_fn=collate_fn, + sampler=sampler, + ) + + # 3. generate responses + batch_dict = next(iter(dataloader)) + batch = DataProto.from_single_dict(batch_dict) + gen_batch = agent_loop_manager.generate_sequences(prompts=batch) + + rm_scores = gen_batch.batch["rm_scores"] + sample_scores = rm_scores.sum(dim=1) + print(sample_scores) + + ray.shutdown() diff --git a/code/RL_model/verl/verl_train/tests/experimental/reward_loop/test_agent_reward_loop_colocate.py b/code/RL_model/verl/verl_train/tests/experimental/reward_loop/test_agent_reward_loop_colocate.py new file mode 100644 index 0000000000000000000000000000000000000000..638c224da707c817907eb2b0fd05f5823e5b58a9 --- /dev/null +++ b/code/RL_model/verl/verl_train/tests/experimental/reward_loop/test_agent_reward_loop_colocate.py @@ -0,0 +1,168 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os + +import ray +from hydra import compose, initialize_config_dir +from torchdata.stateful_dataloader import StatefulDataLoader +from transformers import AutoTokenizer + +from verl.checkpoint_engine import CheckpointEngineManager +from verl.experimental.agent_loop import AgentLoopManager +from verl.experimental.reward_loop import RewardLoopManager +from verl.protocol import DataProto +from verl.single_controller.ray import RayClassWithInitArgs, RayWorkerGroup +from verl.trainer.main_ppo import create_rl_sampler +from verl.trainer.ppo.ray_trainer import ResourcePoolManager +from verl.utils.dataset.rl_dataset import RLHFDataset, collate_fn +from verl.utils.device import get_device_name +from verl.workers.fsdp_workers import ActorRolloutRefWorker, AsyncActorRolloutRefWorker + + +def test_agent_loop_reward_manager(): + ray.init( + runtime_env={ + "env_vars": { + "TOKENIZERS_PARALLELISM": "true", + "NCCL_DEBUG": "WARN", + "VLLM_LOGGING_LEVEL": "INFO", + "VLLM_USE_V1": "1", + } + } + ) + with initialize_config_dir(config_dir=os.path.abspath("verl/trainer/config")): + config = compose(config_name="ppo_trainer") + + rollout_model_path = os.path.expanduser("~/models/Qwen/Qwen2.5-0.5B-Instruct") + reward_model_path = os.path.expanduser("~/models/Qwen/Qwen2.5-1.5B-Instruct") + + # actor_rollout_ref config + config.data.return_raw_chat = True + config.data.max_prompt_length = 1024 + config.data.max_response_length = 4096 + config.actor_rollout_ref.model.path = rollout_model_path + config.actor_rollout_ref.actor.use_dynamic_bsz = True + config.actor_rollout_ref.rollout.name = os.getenv("ROLLOUT_NAME", "vllm") + config.actor_rollout_ref.rollout.mode = "async" + config.actor_rollout_ref.rollout.tensor_model_parallel_size = 2 + config.actor_rollout_ref.rollout.gpu_memory_utilization = 0.8 + config.actor_rollout_ref.rollout.enforce_eager = True + config.actor_rollout_ref.rollout.prompt_length = 1024 + config.actor_rollout_ref.rollout.response_length = 4096 + config.actor_rollout_ref.rollout.skip_tokenizer_init = True + config.trainer.n_gpus_per_node = 8 + config.trainer.nnodes = 1 + + config.reward_model.reward_manager = "dapo" + config.reward_model.enable = True + config.reward_model.enable_resource_pool = False + config.reward_model.n_gpus_per_node = 8 + config.reward_model.model.path = reward_model_path + config.reward_model.rollout.name = os.getenv("ROLLOUT_NAME", "vllm") + config.reward_model.rollout.gpu_memory_utilization = 0.8 + config.reward_model.rollout.tensor_model_parallel_size = 2 + config.reward_model.rollout.skip_tokenizer_init = False + config.reward_model.rollout.prompt_length = 5120 + config.reward_model.rollout.response_length = 4096 + config.custom_reward_function.path = "tests/experimental/reward_loop/reward_fn.py" + config.custom_reward_function.name = "compute_score_gsm8k" + + # 1. init reward model manager + actor_rollout_cls = ( + AsyncActorRolloutRefWorker if config.actor_rollout_ref.rollout.mode == "async" else ActorRolloutRefWorker + ) + global_pool_id = "global_pool" + resource_pool_spec = { + global_pool_id: [config.trainer.n_gpus_per_node] * config.trainer.nnodes, + } + resource_pool_manager = ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=None) + resource_pool_manager.create_resource_pool() + resource_pool = resource_pool_manager.resource_pool_dict[global_pool_id] + actor_rollout_cls = RayClassWithInitArgs( + cls=ray.remote(actor_rollout_cls), config=config.actor_rollout_ref, role="actor_rollout" + ) + actor_rollout_wg = RayWorkerGroup( + resource_pool=resource_pool, ray_cls_with_init=actor_rollout_cls, device_name=get_device_name() + ) + actor_rollout_wg.init_model() + + agent_loop_manager = AgentLoopManager(config, worker_group=actor_rollout_wg) + # sleep rollout replicas + checkpoint_manager = CheckpointEngineManager( + backend=config.actor_rollout_ref.rollout.checkpoint_engine.backend, + trainer=actor_rollout_wg, + replicas=agent_loop_manager.rollout_replicas, + ) + checkpoint_manager.sleep_replicas() + reward_loop_manager = RewardLoopManager(config, rm_resource_pool=resource_pool) + + # 2. init test data + local_folder = os.path.expanduser("~/data/gsm8k/") + + data_files = [os.path.join(local_folder, "train.parquet")] + tokenizer = AutoTokenizer.from_pretrained(rollout_model_path) + + dataset = RLHFDataset( + data_files=data_files, + tokenizer=tokenizer, + config=config.data, + processor=None, + ) + + batch_size = 64 + sampler = create_rl_sampler(config.data, dataset) + dataloader = StatefulDataLoader( + dataset=dataset, + batch_size=batch_size, + num_workers=config.data.dataloader_num_workers, + drop_last=True, + collate_fn=collate_fn, + sampler=sampler, + ) + + # 3. generate responses + batch_dict = next(iter(dataloader)) + batch = DataProto.from_single_dict(batch_dict) + + def _get_gen_batch(batch: DataProto) -> DataProto: + reward_model_keys = set({"data_source", "reward_model", "extra_info", "uid"}) & batch.non_tensor_batch.keys() + + # pop those keys for generation + batch_keys_to_pop = [] + non_tensor_batch_keys_to_pop = set(batch.non_tensor_batch.keys()) - reward_model_keys + gen_batch = batch.pop( + batch_keys=batch_keys_to_pop, + non_tensor_batch_keys=list(non_tensor_batch_keys_to_pop), + ) + + # For agent loop, we need reward model keys to compute score. + gen_batch.non_tensor_batch.update(batch.non_tensor_batch) + + return gen_batch + + # wake up rollout replicas via update_weight + checkpoint_manager.update_weights() + gen_batch = _get_gen_batch(batch) + gen_batch = agent_loop_manager.generate_sequences(gen_batch) + checkpoint_manager.sleep_replicas() + + batch = batch.union(gen_batch) + rm_outputs = reward_loop_manager.compute_rm_score(batch) + + for output in rm_outputs[:5]: + print(output.non_tensor_batch) + + print("done") + + ray.shutdown() diff --git a/code/RL_model/verl/verl_train/tests/experimental/reward_loop/test_async_token_bucket_on_cpu.py b/code/RL_model/verl/verl_train/tests/experimental/reward_loop/test_async_token_bucket_on_cpu.py new file mode 100644 index 0000000000000000000000000000000000000000..70906fb51bd3848aa9e925261f2f5c4f71718e17 --- /dev/null +++ b/code/RL_model/verl/verl_train/tests/experimental/reward_loop/test_async_token_bucket_on_cpu.py @@ -0,0 +1,267 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio +import time + +import pytest + +from verl.experimental.reward_loop.reward_manager.limited import AsyncTokenBucket + + +class TestAsyncTokenBucket: + """Unit tests for AsyncTokenBucket rate limiter.""" + + @pytest.mark.asyncio + async def test_basic_acquire(self): + """Test basic token acquisition.""" + bucket = AsyncTokenBucket(rate_limit=10.0, max_tokens=10.0) + + # Should be able to acquire tokens immediately when bucket is full + start = time.time() + await bucket.acquire(5.0) + elapsed = time.time() - start + + assert elapsed < 0.1, "Initial acquire should be immediate" + assert bucket.tokens == pytest.approx(5.0, abs=0.1) + + @pytest.mark.asyncio + async def test_refill_mechanism(self): + """Test that tokens refill over time.""" + bucket = AsyncTokenBucket(rate_limit=10.0, max_tokens=10.0) + + # Consume all tokens + await bucket.acquire(10.0) + assert bucket.tokens == pytest.approx(0.0, abs=0.1) + + # Wait for refill (should get ~5 tokens in 0.5 seconds at 10 tokens/sec) + await asyncio.sleep(0.5) + + # Try to acquire 4 tokens (should succeed without waiting) + start = time.time() + await bucket.acquire(4.0) + elapsed = time.time() - start + + assert elapsed < 0.1, "Acquire should be quick after refill" + + @pytest.mark.asyncio + async def test_waiting_for_tokens(self): + """Test that acquire waits when insufficient tokens available.""" + bucket = AsyncTokenBucket(rate_limit=10.0, max_tokens=10.0) + + # Consume all tokens + await bucket.acquire(10.0) + + # Try to acquire more tokens (should wait ~0.5 seconds for 5 tokens) + start = time.time() + await bucket.acquire(5.0) + elapsed = time.time() - start + + # Should wait approximately 0.5 seconds (5 tokens / 10 tokens per second) + assert 0.4 < elapsed < 0.7, f"Expected ~0.5s wait, got {elapsed:.3f}s" + + @pytest.mark.asyncio + async def test_max_tokens_cap(self): + """Test that tokens don't exceed max_tokens capacity.""" + bucket = AsyncTokenBucket(rate_limit=10.0, max_tokens=5.0) + + # Wait for potential overflow + await asyncio.sleep(1.0) + + # Tokens should be capped at max_tokens + await bucket.acquire(1.0) + + # After 1 second at 10 tokens/sec, should have max_tokens (5.0) + # After acquiring 1, should have 4.0 remaining + assert bucket.tokens <= 5.0, "Tokens should not exceed max_tokens" + + @pytest.mark.asyncio + async def test_fractional_tokens(self): + """Test acquiring fractional tokens.""" + bucket = AsyncTokenBucket(rate_limit=100.0, max_tokens=100.0) + + # Acquire fractional amounts + await bucket.acquire(0.5) + await bucket.acquire(1.5) + await bucket.acquire(2.3) + + assert bucket.tokens == pytest.approx(100.0 - 0.5 - 1.5 - 2.3, abs=0.1) + + @pytest.mark.asyncio + async def test_concurrent_acquires(self): + """Test multiple concurrent acquire operations.""" + bucket = AsyncTokenBucket(rate_limit=10.0, max_tokens=10.0) + + async def acquire_task(num_tokens: float, task_id: int): + await bucket.acquire(num_tokens) + return task_id + + # Launch 5 concurrent tasks, each acquiring 3 tokens (15 total) + # Bucket only has 10, so some will need to wait + start = time.time() + tasks = [acquire_task(3.0, i) for i in range(5)] + results = await asyncio.gather(*tasks) + elapsed = time.time() - start + + # Should take at least 0.5 seconds to refill 5 tokens + # (15 needed - 10 available) / 10 tokens per second = 0.5 seconds + assert elapsed >= 0.4, f"Expected >=0.4s for concurrent acquires, got {elapsed:.3f}s" + assert len(results) == 5, "All tasks should complete" + + @pytest.mark.asyncio + async def test_high_rate_limit(self): + """Test with high rate limit (simulating high-throughput scenarios).""" + bucket = AsyncTokenBucket(rate_limit=1000.0, max_tokens=1000.0) + + # Rapidly acquire tokens + start = time.time() + for _ in range(100): + await bucket.acquire(10.0) # 1000 tokens total + elapsed = time.time() - start + + # Should complete in approximately 1 second + assert elapsed < 1.5, f"High rate limit test took too long: {elapsed:.3f}s" + + @pytest.mark.asyncio + async def test_zero_initial_state(self): + """Test that bucket starts with full tokens.""" + bucket = AsyncTokenBucket(rate_limit=10.0, max_tokens=10.0) + + assert bucket.tokens == 10.0, "Bucket should start full" + assert bucket.last_update is None, "last_update should be None initially" + + # After first acquire, last_update should be set + await bucket.acquire(1.0) + assert bucket.last_update is not None, "last_update should be set after acquire" + + @pytest.mark.asyncio + async def test_rate_limit_accuracy(self): + """Test rate limit accuracy over time.""" + rate = 50.0 # 50 tokens per second + bucket = AsyncTokenBucket(rate_limit=rate, max_tokens=rate) + + # Consume all tokens and measure refill time for 25 tokens + await bucket.acquire(50.0) + + start = time.time() + await bucket.acquire(25.0) + elapsed = time.time() - start + + expected_time = 25.0 / rate # 0.5 seconds + # Allow 20% margin for timing inaccuracy + assert abs(elapsed - expected_time) < expected_time * 0.2, f"Expected ~{expected_time:.3f}s, got {elapsed:.3f}s" + + @pytest.mark.asyncio + async def test_sequential_acquires(self): + """Test sequential acquire operations.""" + bucket = AsyncTokenBucket(rate_limit=20.0, max_tokens=20.0) + + # Sequential acquires without waiting + await bucket.acquire(5.0) + await bucket.acquire(5.0) + await bucket.acquire(5.0) + await bucket.acquire(5.0) + + # Bucket should be empty + assert bucket.tokens == pytest.approx(0.0, abs=0.1) + + # Next acquire should wait + start = time.time() + await bucket.acquire(10.0) + elapsed = time.time() - start + + assert elapsed >= 0.4, "Should wait for token refill" + + @pytest.mark.asyncio + async def test_default_max_tokens(self): + """Test that max_tokens defaults to rate_limit.""" + bucket = AsyncTokenBucket(rate_limit=15.0) + + assert bucket.max_tokens == 15.0, "max_tokens should default to rate_limit" + assert bucket.tokens == 15.0, "Initial tokens should equal max_tokens" + + @pytest.mark.asyncio + async def test_single_token_acquire(self): + """Test default acquire of 1 token.""" + bucket = AsyncTokenBucket(rate_limit=10.0, max_tokens=10.0) + + await bucket.acquire() # Default num_tokens=1.0 + + assert bucket.tokens == pytest.approx(9.0, abs=0.1) + + @pytest.mark.asyncio + async def test_large_token_acquire(self): + """Test acquiring more tokens than bucket capacity.""" + bucket = AsyncTokenBucket(rate_limit=10.0, max_tokens=10.0) + + # Try to acquire 50 tokens (5x capacity) + start = time.time() + await bucket.acquire(50.0) + elapsed = time.time() - start + + # Should wait for: (50 - 10) / 10 = 4 seconds + assert 3.5 < elapsed < 5.0, f"Expected ~4s wait for large acquire, got {elapsed:.3f}s" + + @pytest.mark.asyncio + async def test_thread_safety_with_lock(self): + """Test that lock prevents race conditions.""" + bucket = AsyncTokenBucket(rate_limit=100.0, max_tokens=100.0) + results = [] + + async def acquire_and_record(): + await bucket.acquire(10.0) + results.append(1) + + # Launch many concurrent tasks + tasks = [acquire_and_record() for _ in range(10)] + await asyncio.gather(*tasks) + + # All tasks should complete + assert len(results) == 10, "All tasks should complete successfully" + + # Bucket should have consumed exactly 100 tokens + assert bucket.tokens == pytest.approx(0.0, abs=0.5) + + @pytest.mark.asyncio + async def test_multiple_wait_cycles(self): + """Test multiple wait cycles in the acquire loop.""" + bucket = AsyncTokenBucket(rate_limit=10.0, max_tokens=10.0) + + # Consume all tokens + await bucket.acquire(10.0) + + # Acquire tokens that require multiple refill cycles + start = time.time() + await bucket.acquire(15.0) + elapsed = time.time() - start + + # Should wait for 15 tokens / 10 tokens per second = 1.5 seconds + assert 1.3 < elapsed < 1.8, f"Expected ~1.5s for multiple refill cycles, got {elapsed:.3f}s" + + @pytest.mark.asyncio + async def test_rapid_small_acquires(self): + """Test many rapid small acquisitions.""" + bucket = AsyncTokenBucket(rate_limit=100.0, max_tokens=100.0) + + start = time.time() + for _ in range(50): + await bucket.acquire(2.0) # 100 tokens total + elapsed = time.time() - start + + # Should complete quickly since we're within capacity + assert elapsed < 0.5, f"Rapid small acquires took too long: {elapsed:.3f}s" + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/code/RL_model/verl/verl_train/tests/experimental/reward_loop/test_math_verify.py b/code/RL_model/verl/verl_train/tests/experimental/reward_loop/test_math_verify.py new file mode 100644 index 0000000000000000000000000000000000000000..c40a0296340521f57ac87917aa0fc6aebeef7b46 --- /dev/null +++ b/code/RL_model/verl/verl_train/tests/experimental/reward_loop/test_math_verify.py @@ -0,0 +1,100 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os + +import ray +from hydra import compose, initialize_config_dir +from torchdata.stateful_dataloader import StatefulDataLoader +from transformers import AutoTokenizer + +from verl.experimental.agent_loop import AgentLoopManager +from verl.protocol import DataProto +from verl.trainer.main_ppo import create_rl_sampler +from verl.utils.dataset.rl_dataset import RLHFDataset, collate_fn + + +def test_agent_loop_reward_manager(): + ray.init( + runtime_env={ + "env_vars": { + "TOKENIZERS_PARALLELISM": "true", + "NCCL_DEBUG": "WARN", + "VLLM_LOGGING_LEVEL": "INFO", + "VLLM_USE_V1": "1", + } + } + ) + with initialize_config_dir(config_dir=os.path.abspath("verl/trainer/config")): + config = compose(config_name="ppo_trainer") + + rollout_model_path = os.path.expanduser("~/models/Qwen/Qwen2.5-3B-Instruct") + + # actor_rollout_ref config + config.data.return_raw_chat = True + config.data.max_prompt_length = 1024 + config.data.max_response_length = 4096 + config.actor_rollout_ref.model.path = rollout_model_path + config.actor_rollout_ref.actor.use_dynamic_bsz = True + config.actor_rollout_ref.rollout.name = os.getenv("ROLLOUT_NAME", "vllm") + config.actor_rollout_ref.rollout.mode = "async" + config.actor_rollout_ref.rollout.tensor_model_parallel_size = 2 + config.actor_rollout_ref.rollout.gpu_memory_utilization = 0.9 + config.actor_rollout_ref.rollout.enforce_eager = True + config.actor_rollout_ref.rollout.prompt_length = 2048 + config.actor_rollout_ref.rollout.response_length = 4096 + config.actor_rollout_ref.rollout.skip_tokenizer_init = True + config.trainer.n_gpus_per_node = 8 + config.trainer.nnodes = 1 + + config.reward_model.reward_manager = "remote" + config.reward_model.num_workers = 2 + config.custom_reward_function.path = "tests/experimental/reward_loop/reward_fn.py" + config.custom_reward_function.name = "compute_score_math_verify" + + # 1. init reward model manager + agent_loop_manager = AgentLoopManager(config) + + # 2. init test data + local_folder = os.path.expanduser("~/data/math/") + data_files = [os.path.join(local_folder, "train.parquet")] + tokenizer = AutoTokenizer.from_pretrained(rollout_model_path) + + dataset = RLHFDataset( + data_files=data_files, + tokenizer=tokenizer, + config=config.data, + processor=None, + ) + + batch_size = 64 + sampler = create_rl_sampler(config.data, dataset) + dataloader = StatefulDataLoader( + dataset=dataset, + batch_size=batch_size, + num_workers=config.data.dataloader_num_workers, + drop_last=True, + collate_fn=collate_fn, + sampler=sampler, + ) + + # 3. generate responses + batch_dict = next(iter(dataloader)) + batch = DataProto.from_single_dict(batch_dict) + gen_batch = agent_loop_manager.generate_sequences(prompts=batch) + + rm_scores = gen_batch.batch["rm_scores"] + accuracy = rm_scores.sum(dim=-1).mean() + print(accuracy) + + ray.shutdown() diff --git a/code/RL_model/verl/verl_train/tests/experimental/reward_loop/test_rate_limited_reward_manager_on_cpu.py b/code/RL_model/verl/verl_train/tests/experimental/reward_loop/test_rate_limited_reward_manager_on_cpu.py new file mode 100644 index 0000000000000000000000000000000000000000..dfeca215327c8dd4aadab4ee2b4f10a7ce6e5f53 --- /dev/null +++ b/code/RL_model/verl/verl_train/tests/experimental/reward_loop/test_rate_limited_reward_manager_on_cpu.py @@ -0,0 +1,528 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio +import logging +import os.path +import time + +import pytest +import torch +from omegaconf import DictConfig +from transformers import AutoTokenizer + +from verl import DataProto +from verl.experimental.reward_loop.reward_manager.limited import RateLimitedRewardManager + + +# Mock API reward functions for testing +class MockAPICounter: + """Shared counter to track API calls across tests.""" + + def __init__(self): + self.call_count = 0 + self.call_times = [] + self.lock = asyncio.Lock() + + async def record_call(self): + async with self.lock: + self.call_count += 1 + self.call_times.append(time.time()) + + def reset(self): + self.call_count = 0 + self.call_times.clear() + + def get_rate_per_second(self, window_start: float = None): + """Calculate API call rate over a time window.""" + if window_start is None: + if not self.call_times: + return 0.0 + window_start = self.call_times[0] + + if not self.call_times: + return 0.0 + + window_end = self.call_times[-1] + duration = window_end - window_start + + if duration <= 0: + return 0.0 + + calls_in_window = sum(1 for t in self.call_times if t >= window_start) + return calls_in_window / duration + + +# Global counter instance +api_counter = MockAPICounter() + + +def mock_sync_reward_function( + data_source: str, solution_str: str, ground_truth: str, extra_info: dict, **kwargs +) -> float: + """Synchronous mock reward function that simulates API call.""" + # Simulate API processing time + time.sleep(0.01) + + # Simple scoring logic + score = 1.0 if solution_str.strip() == ground_truth.strip() else 0.0 + return score + + +async def mock_async_reward_function( + data_source: str, solution_str: str, ground_truth: str, extra_info: dict, **kwargs +) -> float: + """Asynchronous mock reward function that simulates API call.""" + # Record API call for rate tracking + await api_counter.record_call() + + # Simulate async API call (e.g., HTTP request) + await asyncio.sleep(0.01) + + # Simple scoring logic + score = 1.0 if solution_str.strip() == ground_truth.strip() else 0.0 + return score + + +async def mock_slow_api_function( + data_source: str, solution_str: str, ground_truth: str, extra_info: dict, **kwargs +) -> float: + """Slow mock API function for timeout testing.""" + await asyncio.sleep(2.0) # Simulate slow API + return 0.5 + + +async def mock_failing_api_function( + data_source: str, solution_str: str, ground_truth: str, extra_info: dict, **kwargs +) -> float: + """Mock API function that raises an exception.""" + await api_counter.record_call() + raise ValueError("Simulated API error") + + +async def mock_dict_result_function( + data_source: str, solution_str: str, ground_truth: str, extra_info: dict, **kwargs +) -> dict: + """Mock API function that returns dict result.""" + await api_counter.record_call() + await asyncio.sleep(0.01) + + correct = solution_str.strip() == ground_truth.strip() + return {"score": 1.0 if correct else 0.0, "correct": correct, "reasoning": "Mock reasoning"} + + +def create_test_data_proto(tokenizer, response_text: str, ground_truth: str, data_source: str = "test"): + """Helper to create DataProto for testing.""" + response_ids = tokenizer.encode(response_text, add_special_tokens=False) + response_tensor = torch.tensor([response_ids], dtype=torch.long) + attention_mask = torch.ones_like(response_tensor) + + data = DataProto.from_dict( + { + "responses": response_tensor, + "attention_mask": attention_mask, + } + ) + + # Wrap non-tensor values in lists to match batch dimension + data.non_tensor_batch = {"data_source": [data_source], "reward_model": [{"ground_truth": ground_truth}]} + + return data + + +class TestRateLimitedRewardManager: + """Integration tests for RateLimitedRewardManager with mock API functions.""" + + @pytest.fixture(autouse=True) + def setup_and_teardown(self): + """Reset global state before each test.""" + api_counter.reset() + # Reset class state + RateLimitedRewardManager._class_initialized = False + RateLimitedRewardManager._semaphore = None + RateLimitedRewardManager._rpm_limiter = None + RateLimitedRewardManager._tpm_limiter = None + yield + # Cleanup + api_counter.reset() + + @pytest.fixture + def tokenizer(self): + """Load a simple tokenizer for testing.""" + return AutoTokenizer.from_pretrained(os.path.expanduser("~/models/Qwen/Qwen2.5-0.5B-Instruct")) + + @pytest.mark.asyncio + async def test_basic_reward_computation(self, tokenizer): + """Test basic reward computation without rate limiting.""" + config = DictConfig({"reward_model": {"max_concurrent": 10, "timeout": 10.0}}) + + RateLimitedRewardManager.init_class(config, tokenizer) + manager = RateLimitedRewardManager(config=config, tokenizer=tokenizer, compute_score=mock_async_reward_function) + + # Create test data + data = create_test_data_proto(tokenizer, "correct answer", "correct answer") + + # Compute reward + result = await manager.run_single(data) + + assert "reward_score" in result + assert result["reward_score"] == 1.0 + assert api_counter.call_count == 1 + + @pytest.mark.asyncio + async def test_rpm_rate_limiting(self, tokenizer): + """Test request per minute (RPM) rate limiting.""" + # Set RPM limit to 60 (1 request per second) + config = DictConfig( + { + "reward_model": { + "max_concurrent": 10, + "max_rpm": 60, # 1 request per second + "timeout": 10.0, + } + } + ) + + RateLimitedRewardManager.init_class(config, tokenizer) + manager = RateLimitedRewardManager(config=config, tokenizer=tokenizer, compute_score=mock_async_reward_function) + + # Create test data + data = create_test_data_proto(tokenizer, "answer", "answer") + + # Make 3 requests - should be rate limited + start_time = time.time() + + results = [] + for _ in range(3): + result = await manager.run_single(data) + results.append(result) + + elapsed = time.time() - start_time + + # Should take at least ~2 seconds for 3 requests at 1 req/sec + assert elapsed >= 1.8, f"RPM limiting failed: {elapsed:.3f}s for 3 requests" + assert all(r["reward_score"] == 1.0 for r in results) + assert api_counter.call_count == 3 + + @pytest.mark.asyncio + async def test_tpm_rate_limiting(self, tokenizer): + """Test tokens per minute (TPM) rate limiting.""" + # Set TPM limit to 6000 (100 tokens per second) + # With 2000 tokens per request, that's 0.05 req/sec or 20 seconds per request + config = DictConfig( + { + "reward_model": { + "max_concurrent": 10, + "max_tpm": 6000, # 100 tokens per second + "estimated_tokens_per_request": 2000, # Each request = 2000 tokens + "timeout": 30.0, + } + } + ) + + RateLimitedRewardManager.init_class(config, tokenizer) + manager = RateLimitedRewardManager(config=config, tokenizer=tokenizer, compute_score=mock_async_reward_function) + + data = create_test_data_proto(tokenizer, "answer", "answer") + + # Make 2 requests + start_time = time.time() + + result1 = await manager.run_single(data) + result2 = await manager.run_single(data) + + elapsed = time.time() - start_time + + # First request: consumes 2000 tokens (immediate) + # Second request: needs 2000 tokens, waits for refill + # Wait time: 2000 tokens / 100 tokens per second = 20 seconds + assert elapsed >= 18.0, f"TPM limiting failed: {elapsed:.3f}s for 2 requests" + assert result1["reward_score"] == 1.0 + assert result2["reward_score"] == 1.0 + + @pytest.mark.asyncio + async def test_concurrency_limiting(self, tokenizer): + """Test concurrent request limiting.""" + config = DictConfig( + { + "reward_model": { + "max_concurrent": 2, # Only 2 concurrent requests + "timeout": 10.0, + } + } + ) + + RateLimitedRewardManager.init_class(config, tokenizer) + manager = RateLimitedRewardManager(config=config, tokenizer=tokenizer, compute_score=mock_async_reward_function) + + data = create_test_data_proto(tokenizer, "answer", "answer") + + # Launch 5 concurrent requests + start_time = time.time() + + tasks = [manager.run_single(data) for _ in range(5)] + results = await asyncio.gather(*tasks) + + elapsed = time.time() - start_time + + # All should succeed + assert len(results) == 5 + assert all(r["reward_score"] == 1.0 for r in results) + + # With concurrency=2 and 0.01s per request, should take at least 0.03s + # (3 batches: 2+2+1) + assert elapsed >= 0.02, f"Concurrency limiting may not be working: {elapsed:.3f}s" + + @pytest.mark.asyncio + async def test_timeout_handling(self, tokenizer): + """Test timeout handling for slow API.""" + config = DictConfig( + { + "reward_model": { + "max_concurrent": 10, + "timeout": 0.5, # 500ms timeout + } + } + ) + + RateLimitedRewardManager.init_class(config, tokenizer) + manager = RateLimitedRewardManager(config=config, tokenizer=tokenizer, compute_score=mock_slow_api_function) + + data = create_test_data_proto(tokenizer, "answer", "answer") + + # Should timeout and return 0.0 + result = await manager.run_single(data) + + assert result["reward_score"] == 0.0 + assert result["reward_extra_info"].get("timeout") is True + assert result["reward_extra_info"].get("acc") == 0.0 + + @pytest.mark.asyncio + async def test_error_handling(self, tokenizer): + """Test error handling for failing API.""" + config = DictConfig({"reward_model": {"max_concurrent": 10, "timeout": 10.0}}) + + RateLimitedRewardManager.init_class(config, tokenizer) + manager = RateLimitedRewardManager(config=config, tokenizer=tokenizer, compute_score=mock_failing_api_function) + + data = create_test_data_proto(tokenizer, "answer", "answer") + + # Should catch exception and return 0.0 + result = await manager.run_single(data) + + assert result["reward_score"] == 0.0 + assert "error" in result["reward_extra_info"] + assert "Simulated API error" in result["reward_extra_info"]["error"] + assert result["reward_extra_info"].get("acc") == 0.0 + assert api_counter.call_count == 1 + + @pytest.mark.asyncio + async def test_dict_result_format(self, tokenizer): + """Test handling of dict return format from reward function.""" + config = DictConfig({"reward_model": {"max_concurrent": 10, "timeout": 10.0}}) + + RateLimitedRewardManager.init_class(config, tokenizer) + manager = RateLimitedRewardManager(config=config, tokenizer=tokenizer, compute_score=mock_dict_result_function) + + data = create_test_data_proto(tokenizer, "correct", "correct") + + result = await manager.run_single(data) + + assert result["reward_score"] == 1.0 + assert result["reward_extra_info"]["score"] == 1.0 + assert result["reward_extra_info"]["correct"] is True + assert result["reward_extra_info"]["reasoning"] == "Mock reasoning" + + @pytest.mark.asyncio + async def test_sync_reward_function(self, tokenizer): + """Test that synchronous reward functions work correctly.""" + config = DictConfig({"reward_model": {"max_concurrent": 10, "timeout": 10.0}}) + + RateLimitedRewardManager.init_class(config, tokenizer) + manager = RateLimitedRewardManager(config=config, tokenizer=tokenizer, compute_score=mock_sync_reward_function) + + data = create_test_data_proto(tokenizer, "answer", "answer") + + result = await manager.run_single(data) + + assert result["reward_score"] == 1.0 + assert manager.is_async_reward_score is False + + @pytest.mark.asyncio + async def test_combined_rate_limits(self, tokenizer): + """Test all three rate limiting layers together.""" + config = DictConfig( + { + "reward_model": { + "max_concurrent": 2, + "max_rpm": 120, # 2 requests per second + "max_tpm": 12000, # 200 tokens per second + "estimated_tokens_per_request": 100, # 0.5 seconds per request + "timeout": 10.0, + } + } + ) + + RateLimitedRewardManager.init_class(config, tokenizer) + manager = RateLimitedRewardManager(config=config, tokenizer=tokenizer, compute_score=mock_async_reward_function) + + data = create_test_data_proto(tokenizer, "answer", "answer") + + # Make 6 requests to exceed burst capacity (RPM bucket starts with 2 tokens) + start_time = time.time() + + tasks = [manager.run_single(data) for _ in range(6)] + results = await asyncio.gather(*tasks) + + elapsed = time.time() - start_time + + # Bucket starts with 2 RPM tokens and 200 TPM tokens + # First 2 requests: use burst capacity (2 RPM tokens, 200 TPM tokens) + # Next 4 requests: need 4 RPM tokens (wait 2 seconds) and 400 TPM tokens (wait 2 seconds) + # Limiting factor: RPM at 2 seconds + assert elapsed >= 1.8, f"Combined rate limiting: {elapsed:.3f}s" + assert all(r["reward_score"] == 1.0 for r in results) + assert api_counter.call_count == 6 + + @pytest.mark.asyncio + async def test_correct_vs_incorrect_answers(self, tokenizer): + """Test scoring of correct vs incorrect answers.""" + config = DictConfig({"reward_model": {"max_concurrent": 10, "timeout": 10.0}}) + + RateLimitedRewardManager.init_class(config, tokenizer) + manager = RateLimitedRewardManager(config=config, tokenizer=tokenizer, compute_score=mock_async_reward_function) + + # Test correct answer + data_correct = create_test_data_proto(tokenizer, "right answer", "right answer") + result_correct = await manager.run_single(data_correct) + + # Test incorrect answer + data_incorrect = create_test_data_proto(tokenizer, "wrong answer", "right answer") + result_incorrect = await manager.run_single(data_incorrect) + + assert result_correct["reward_score"] == 1.0 + assert result_incorrect["reward_score"] == 0.0 + + @pytest.mark.asyncio + async def test_high_throughput(self, tokenizer): + """Test high throughput with many concurrent requests.""" + config = DictConfig( + { + "reward_model": { + "max_concurrent": 20, + "max_rpm": 6000, # 100 requests per second + "timeout": 10.0, + } + } + ) + + RateLimitedRewardManager.init_class(config, tokenizer) + manager = RateLimitedRewardManager(config=config, tokenizer=tokenizer, compute_score=mock_async_reward_function) + + data = create_test_data_proto(tokenizer, "answer", "answer") + + # Launch 200 concurrent requests (more than burst capacity of 100) + start_time = time.time() + + tasks = [manager.run_single(data) for _ in range(200)] + results = await asyncio.gather(*tasks) + + elapsed = time.time() - start_time + + assert len(results) == 200 + assert all(r["reward_score"] == 1.0 for r in results) + + # Bucket starts with 100 tokens (burst capacity) + # First 100 requests: use burst capacity instantly + # Next 100 requests: need to wait for refill at 100 tokens/sec = 1 second minimum + # Total time should be at least 1 second + assert elapsed >= 0.9, f"Should take at least 0.9s for rate limiting, took {elapsed:.3f}s" + + # Calculate actual rate over the time window + actual_rate = api_counter.call_count / elapsed + + # Average rate should not significantly exceed 100 req/sec + # Allow some burst overhead due to initial capacity + assert actual_rate <= 200, f"Rate limiting failed: {actual_rate:.1f} req/sec (max 200)" + + @pytest.mark.asyncio + async def test_class_initialization_once(self, tokenizer): + """Test that class initialization only happens once.""" + config = DictConfig({"reward_model": {"max_concurrent": 5, "timeout": 10.0}}) + + # Initialize multiple times + RateLimitedRewardManager.init_class(config, tokenizer) + first_semaphore = RateLimitedRewardManager._semaphore + + RateLimitedRewardManager.init_class(config, tokenizer) + second_semaphore = RateLimitedRewardManager._semaphore + + # Should be the same object + assert first_semaphore is second_semaphore + + def test_warn_when_rate_limits_are_ignored_due_to_prior_init(self, tokenizer, caplog): + """Warn when a new config attempts to change global RPM/TPM after the class has been initialized.""" + caplog.set_level(logging.WARNING) + + # First instantiation without a config (legacy signature) initializes global limiters with defaults. + _ = RateLimitedRewardManager( + tokenizer=tokenizer, + compute_score=mock_async_reward_function, + num_examine=0, + reward_fn_key="data_source", + ) + + # Second instantiation attempts to set RPM limits, but will be ignored due to global initialization. + config = DictConfig({"reward_model": {"max_concurrent": 10, "max_rpm": 60, "timeout": 10.0}}) + _ = RateLimitedRewardManager( + config=config, + tokenizer=tokenizer, + compute_score=mock_async_reward_function, + ) + + assert any( + "RateLimitedRewardManager has already been initialized" in record.getMessage() + and "ignored" in record.getMessage() + for record in caplog.records + ), "Expected a warning when attempting to change global rate limits after initialization." + + @pytest.mark.asyncio + async def test_extra_info_handling(self, tokenizer): + """Test that extra_info is properly passed to reward function.""" + received_extra_info = {} + + async def mock_reward_with_extra_info( + data_source: str, solution_str: str, ground_truth: str, extra_info: dict, **kwargs + ): + received_extra_info.update(extra_info) + return 1.0 + + config = DictConfig({"reward_model": {"max_concurrent": 10, "timeout": 10.0}}) + + RateLimitedRewardManager.init_class(config, tokenizer) + manager = RateLimitedRewardManager( + config=config, tokenizer=tokenizer, compute_score=mock_reward_with_extra_info + ) + + data = create_test_data_proto(tokenizer, "answer", "answer") + data.non_tensor_batch["extra_info"] = [{"custom_field": "test_value"}] + + await manager.run_single(data) + + assert "custom_field" in received_extra_info + assert received_extra_info["custom_field"] == "test_value" + + +if __name__ == "__main__": + pytest.main([__file__, "-v", "-s"]) diff --git a/code/RL_model/verl/verl_train/tests/experimental/reward_loop/test_reward_model_disrm.py b/code/RL_model/verl/verl_train/tests/experimental/reward_loop/test_reward_model_disrm.py new file mode 100644 index 0000000000000000000000000000000000000000..194d499e567b5894051e2473798b96c83b4716ec --- /dev/null +++ b/code/RL_model/verl/verl_train/tests/experimental/reward_loop/test_reward_model_disrm.py @@ -0,0 +1,153 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os + +import ray +import torch +from hydra import compose, initialize_config_dir + +from verl.experimental.reward_loop import RewardLoopManager +from verl.protocol import DataProto +from verl.utils import hf_tokenizer +from verl.utils.model import compute_position_id_with_mask + + +def create_data_samples(tokenizer) -> DataProto: + convs = [ + [ + { + "role": "user", + "content": "What is the range of the numeric output of a sigmoid node in a neural network?", + }, + {"role": "assistant", "content": "Between -1 and 1."}, + ], + [ + { + "role": "user", + "content": "What is the range of the numeric output of a sigmoid node in a neural network?", + }, + {"role": "assistant", "content": "Between 0 and 1."}, + ], + [ + {"role": "user", "content": "What is the capital of Australia?"}, + { + "role": "assistant", + "content": "Canberra is the capital city of Australia.", + }, + ], + [ + {"role": "user", "content": "What is the capital of Australia?"}, + { + "role": "assistant", + "content": "Sydney is the capital of Australia.", + }, + ], + ] + raw_prompt = [conv[:1] for conv in convs] + data_source = ["gsm8k"] * len(convs) + reward_info = [{"ground_truth": "Not Used"}] * len(convs) + extra_info = [{"question": conv[0]["content"]} for conv in convs] + + prompt_length, response_length = 1024, 4096 + pad_token_id = tokenizer.pad_token_id + prompts, responses, input_ids, attention_masks = [], [], [], [] + for conv in convs: + prompt_tokens = tokenizer.apply_chat_template(conv[:1], tokenize=True) + response_tokens = tokenizer.apply_chat_template(conv, tokenize=True)[len(prompt_tokens) :] + + padded_prompt = [pad_token_id] * (prompt_length - len(prompt_tokens)) + prompt_tokens + padded_response = response_tokens + [pad_token_id] * (response_length - len(response_tokens)) + attention_mask = ( + [0] * (prompt_length - len(prompt_tokens)) + + [1] * len(prompt_tokens) + + [1] * len(response_tokens) + + [0] * (response_length - len(response_tokens)) + ) + prompts.append(torch.tensor(padded_prompt)) + responses.append(torch.tensor(padded_response)) + input_ids.append(torch.tensor(padded_prompt + padded_response)) + attention_masks.append(torch.tensor(attention_mask)) + + prompts = torch.stack(prompts) + responses = torch.stack(responses) + input_ids = torch.stack(input_ids) + attention_masks = torch.stack(attention_masks) + position_ids = compute_position_id_with_mask(attention_masks) + + data = DataProto.from_dict( + tensors={ + "prompts": prompts, + "responses": responses, + "input_ids": input_ids, + "attention_mask": attention_masks, + "position_ids": position_ids, + }, + non_tensors={ + "data_source": data_source, + "reward_model": reward_info, + "raw_prompt": raw_prompt, + "extra_info": extra_info, + }, + ) + return data, convs + + +def test_reward_model_manager(): + ray.init( + runtime_env={ + "env_vars": { + "TOKENIZERS_PARALLELISM": "true", + "NCCL_DEBUG": "WARN", + "VLLM_LOGGING_LEVEL": "INFO", + "VLLM_USE_V1": "1", + } + } + ) + with initialize_config_dir(config_dir=os.path.abspath("verl/trainer/config")): + config = compose(config_name="ppo_trainer") + + rollout_model_name = os.path.expanduser("~/models/Qwen/Qwen2.5-1.5B-Instruct") + reward_model_name = os.path.expanduser("~/models/Skywork/Skywork-Reward-V2-Llama-3.2-1B") + + config.actor_rollout_ref.model.path = rollout_model_name + config.reward_model.reward_manager = "dapo" + config.reward_model.enable = True + config.reward_model.enable_resource_pool = True + config.reward_model.n_gpus_per_node = 8 + config.reward_model.nnodes = 1 + config.reward_model.model.path = reward_model_name + config.reward_model.rollout.name = os.getenv("ROLLOUT_NAME", "vllm") + config.reward_model.rollout.gpu_memory_utilization = 0.9 + config.reward_model.rollout.tensor_model_parallel_size = 2 + config.reward_model.rollout.skip_tokenizer_init = False + config.reward_model.rollout.prompt_length = 2048 + config.reward_model.rollout.response_length = 4096 + + # 1. init reward model manager + reward_loop_manager = RewardLoopManager(config) + + # 2. init test data + rollout_tokenizer = hf_tokenizer(rollout_model_name) + data, convs = create_data_samples(rollout_tokenizer) + + # 3. generate responses + outputs = reward_loop_manager.compute_rm_score(data) + + for idx, (conv, output) in enumerate(zip(convs, outputs, strict=True)): + print(f"Problem {idx}:\n{conv[0]['content']}\n") + print(f"AI Solution {idx}:\n{conv[1]['content']}\n") + print(f"DisRM Score {idx}:\n{output.batch['rm_scores'].sum(dim=-1).item()}\n") + print("=" * 50 + "\n") + + ray.shutdown() diff --git a/code/RL_model/verl/verl_train/tests/experimental/vla/test_sim_envs.py b/code/RL_model/verl/verl_train/tests/experimental/vla/test_sim_envs.py new file mode 100644 index 0000000000000000000000000000000000000000..adb2723498ed854b33c7b81610cb47b17e471477 --- /dev/null +++ b/code/RL_model/verl/verl_train/tests/experimental/vla/test_sim_envs.py @@ -0,0 +1,101 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import unittest + +import numpy as np +import pytest +from omegaconf import OmegaConf + + +# @pytest.mark.parametrize("simulator_type", ["libero", "isaac"]) +@pytest.mark.parametrize("simulator_type", ["isaac"]) +def test_sim_env_creation_and_step(simulator_type): + num_envs = 8 + actions = np.array( + [ + [5.59112417e-01, 8.06460073e-02, 1.36817226e-02, -4.64279854e-04, -1.72158767e-02, -6.57548380e-04, -1], + [2.12711899e-03, -3.13366604e-01, 3.41386353e-04, -4.64279854e-04, -8.76528812e-03, -6.57548380e-04, -1], + [7.38182960e-02, -4.64548351e-02, -6.63602950e-02, -4.64279854e-04, -2.32520114e-02, -6.57548380e-04, -1], + [7.38182960e-02, -1.60845593e-01, 3.41386353e-04, -4.64279854e-04, 1.05503430e-02, -6.57548380e-04, -1], + [7.38182960e-02, -3.95982152e-01, -7.97006313e-02, -5.10713711e-03, 3.22804279e-02, -6.57548380e-04, -1], + [2.41859427e-02, -3.64206941e-01, -6.63602950e-02, -4.64279854e-04, 1.05503430e-02, -6.57548380e-04, -1], + [4.62447664e-02, -5.16727952e-01, -7.97006313e-02, -4.64279854e-04, 1.05503430e-02, 8.73740975e-03, -1], + [4.62447664e-02, -5.73923331e-01, 3.41386353e-04, -4.64279854e-04, 6.92866212e-03, -6.57548380e-04, -1], + ] + ) + cfg = OmegaConf.create( + { + "max_episode_steps": 512, + "only_eval": False, + "reward_coef": 1.0, + "init_params": { + "camera_names": ["agentview"], + }, + "video_cfg": { + "save_video": True, + "video_base_dir": "/tmp/test_sim_env_creation_and_step", + }, + "task_suite_name": "libero_10", + "num_envs": num_envs, + "num_group": 1, + "group_size": num_envs, + "seed": 0, + }, + ) + + sim_env = None + if simulator_type == "isaac": + from verl.experimental.vla.envs.isaac_env.isaac_env import IsaacEnv + + sim_env = IsaacEnv(cfg, rank=0, world_size=1) + elif simulator_type == "libero": + from verl.experimental.vla.envs.libero_env.libero_env import LiberoEnv + + sim_env = LiberoEnv(cfg, rank=0, world_size=1) + else: + raise ValueError(f"simulator_type {simulator_type} is not supported") + + video_count = 0 + for i in [0]: + # The first call to step with actions=None will reset the environment + step = 0 + sim_env.reset_envs_to_state_ids([0] * num_envs, [i] * num_envs) + for action in actions: + obs_venv, reward_venv, terminated_venv, truncated_venv, info_venv = sim_env.step( + np.array([action] * num_envs) + ) + + assert isinstance(obs_venv, dict) + assert reward_venv.shape == (num_envs,) + assert terminated_venv.shape == (num_envs,) + assert truncated_venv.shape == (num_envs,) + assert isinstance(info_venv, dict) + + if terminated_venv.any() or truncated_venv.any(): + break + step += 1 + + sim_env.flush_video(video_sub_dir=f"task_{i}") + assert os.path.exists(os.path.join(cfg.video_cfg.video_base_dir, f"rank_0/task_{i}/{video_count}.mp4")) + os.remove(os.path.join(cfg.video_cfg.video_base_dir, f"rank_0/task_{i}/{video_count}.mp4")) + video_count += 1 + + print("test passed") + sim_env.close() + + +if __name__ == "__main__": + unittest.main() diff --git a/code/RL_model/verl/verl_train/tests/single_controller/base/test_decorator.py b/code/RL_model/verl/verl_train/tests/single_controller/base/test_decorator.py new file mode 100644 index 0000000000000000000000000000000000000000..5447d65ce0ecfad235d63c3c8ca02d88c4c7a9e7 --- /dev/null +++ b/code/RL_model/verl/verl_train/tests/single_controller/base/test_decorator.py @@ -0,0 +1,76 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest + +import verl.single_controller.base.decorator as decorator_module +from verl.single_controller.base.decorator import ( + DISPATCH_MODE_FN_REGISTRY, + Dispatch, + _check_dispatch_mode, + get_predefined_dispatch_fn, + register_dispatch_mode, + update_dispatch_mode, +) + + +@pytest.fixture +def reset_dispatch_registry(): + # Store original state + original_registry = DISPATCH_MODE_FN_REGISTRY.copy() + yield + # Reset registry after test + decorator_module.DISPATCH_MODE_FN_REGISTRY.clear() + decorator_module.DISPATCH_MODE_FN_REGISTRY.update(original_registry) + + +def test_register_new_dispatch_mode(reset_dispatch_registry): + # Test registration + def dummy_dispatch(worker_group, *args, **kwargs): + return args, kwargs + + def dummy_collect(worker_group, output): + return output + + register_dispatch_mode("TEST_MODE", dummy_dispatch, dummy_collect) + + # Verify enum extension + _check_dispatch_mode(Dispatch.TEST_MODE) + + # Verify registry update + assert get_predefined_dispatch_fn(Dispatch.TEST_MODE) == { + "dispatch_fn": dummy_dispatch, + "collect_fn": dummy_collect, + } + # Clean up + Dispatch.remove("TEST_MODE") + + +def test_update_existing_dispatch_mode(reset_dispatch_registry): + # Store original implementation + original_mode = Dispatch.ONE_TO_ALL + + # New implementations + def new_dispatch(worker_group, *args, **kwargs): + return args, kwargs + + def new_collect(worker_group, output): + return output + + # Test update= + update_dispatch_mode(original_mode, new_dispatch, new_collect) + + # Verify update + assert get_predefined_dispatch_fn(original_mode)["dispatch_fn"] == new_dispatch + assert get_predefined_dispatch_fn(original_mode)["collect_fn"] == new_collect diff --git a/code/RL_model/verl/verl_train/tests/single_controller/check_worker_alive/main.py b/code/RL_model/verl/verl_train/tests/single_controller/check_worker_alive/main.py new file mode 100644 index 0000000000000000000000000000000000000000..cbdee9a8d6cf98544efc8abeb9555a66a2fd70ee --- /dev/null +++ b/code/RL_model/verl/verl_train/tests/single_controller/check_worker_alive/main.py @@ -0,0 +1,64 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import sys +import time + +import ray + +from verl.single_controller.base.decorator import Dispatch, register +from verl.single_controller.base.worker import Worker +from verl.single_controller.ray.base import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup + + +@ray.remote +class TestActor(Worker): + def __init__(self) -> None: + super().__init__() + + @register(dispatch_mode=Dispatch.ONE_TO_ALL, blocking=False) + def foo(self, wait_time): + time.sleep(wait_time) + sys.exit(1) + + +if __name__ == "__main__": + wait_time = int(os.getenv("WAIT_TIME", "10")) + + ray.init() + + # test single-node-no-partition + print("test single-node-no-partition") + resource_pool = RayResourcePool([2], use_gpu=False) + class_with_args = RayClassWithInitArgs(cls=TestActor) + + print("create worker group") + wg = RayWorkerGroup(resource_pool, class_with_args, name_prefix="test") + + wg.start_worker_aliveness_check(1) + time.sleep(1) + + print(time.time(), "start foo") + + _ = wg.foo(wait_time) + print("foo started") + + print( + time.time(), + f"wait 6x wait time {wait_time * 6} to let signal returned to process but still not exceed process wait time", + ) + time.sleep(wait_time * 6) + + ray.shutdown() diff --git a/code/RL_model/verl/verl_train/tests/single_controller/detached_worker/README.md b/code/RL_model/verl/verl_train/tests/single_controller/detached_worker/README.md new file mode 100644 index 0000000000000000000000000000000000000000..b06c4c6143e01d071458f7416033872d41d71031 --- /dev/null +++ b/code/RL_model/verl/verl_train/tests/single_controller/detached_worker/README.md @@ -0,0 +1,14 @@ +# Detached Worker +## How to run (Only on a single node) +- Start a local ray cluster: +```bash +ray start --head --port=6379 +``` +- Run the server +```bash +python3 server.py +``` +- On another terminal, Run the client +```bash +python3 client.py +``` diff --git a/code/RL_model/verl/verl_train/tests/single_controller/detached_worker/client.py b/code/RL_model/verl/verl_train/tests/single_controller/detached_worker/client.py new file mode 100644 index 0000000000000000000000000000000000000000..8c78aaf5d37f6ca5aced3ba5a42b64218cb950e1 --- /dev/null +++ b/code/RL_model/verl/verl_train/tests/single_controller/detached_worker/client.py @@ -0,0 +1,56 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +In client, we can get the server handler and send RPC request +""" + +import ray +import torch +from server import Trainer +from tensordict import TensorDict + +from verl import DataProto +from verl.single_controller.ray import RayClassWithInitArgs, RayWorkerGroup + + +def compute_position_id_with_mask(mask): + return torch.clip(torch.cumsum(mask, dim=-1) - 1, min=0, max=None) + + +if __name__ == "__main__": + ray.init(address="auto", namespace="verl") + # get the worker group using names + worker_names = ["trainerTrainer_0:0", "trainerTrainer_0:1"] + cls_with_init_args = RayClassWithInitArgs(cls=Trainer) + worker_group = RayWorkerGroup.from_detached(worker_names=worker_names, ray_cls_with_init=cls_with_init_args) + + batch_size = 16 + sequence_length = 1024 + + # give Trainer some data to train + input_ids = torch.randint(low=0, high=256, size=(batch_size, sequence_length), dtype=torch.int64, device="cuda") + attention_mask = torch.ones_like(input_ids) + position_ids = compute_position_id_with_mask(attention_mask) + + data = DataProto( + batch=TensorDict( + {"input_ids": input_ids, "attention_mask": attention_mask, "position_ids": position_ids}, + batch_size=batch_size, + ), + meta_info={}, + ) + + output = worker_group.train_model(data) + + print(output) diff --git a/code/RL_model/verl/verl_train/tests/single_controller/detached_worker/run.sh b/code/RL_model/verl/verl_train/tests/single_controller/detached_worker/run.sh new file mode 100644 index 0000000000000000000000000000000000000000..a3c6387933262694bf3534066b4310fda0a9fea3 --- /dev/null +++ b/code/RL_model/verl/verl_train/tests/single_controller/detached_worker/run.sh @@ -0,0 +1,5 @@ +#!/bin/bash +ray start --head --port=6379 +python3 server.py +python3 client.py +ray stop --force \ No newline at end of file diff --git a/code/RL_model/verl/verl_train/tests/single_controller/detached_worker/server.py b/code/RL_model/verl/verl_train/tests/single_controller/detached_worker/server.py new file mode 100644 index 0000000000000000000000000000000000000000..f8a7f014d2317b2d1918b2fe9fd5d6b177e09317 --- /dev/null +++ b/code/RL_model/verl/verl_train/tests/single_controller/detached_worker/server.py @@ -0,0 +1,152 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Server starts a Trainer. Client sends data to the server to train. +""" + +import os + +os.environ["MEGATRON_USE_CUDA_TIMER"] = "0" +os.environ["MEGATRON_START_PROCESS_TIMER"] = "False" +os.environ["NCCL_DEBUG"] = "WARN" + +import ray +import torch +from megatron.core import parallel_state as mpu +from megatron.core import tensor_parallel +from megatron.core.models.gpt.gpt_model import ModelType +from omegaconf import OmegaConf +from tensordict import TensorDict +from torch import nn +from transformers import LlamaConfig + +from verl import DataProto +from verl.models.llama.megatron import ParallelLlamaForCausalLMRmPadPP +from verl.single_controller.base import Worker +from verl.single_controller.base.decorator import Dispatch, make_nd_compute_dataproto_dispatch_fn, register +from verl.single_controller.ray import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup +from verl.utils.megatron.optimizer import get_megatron_optimizer, init_megatron_optim_config +from verl.utils.megatron_utils import get_model, mcore_model_parallel_config + + +@ray.remote +class Trainer(Worker): + def __init__(self): + super().__init__() + + if not torch.distributed.is_initialized(): + rank = int(os.environ["LOCAL_RANK"]) + torch.distributed.init_process_group(backend="nccl") + torch.cuda.set_device(rank) + + mpu.initialize_model_parallel( + tensor_model_parallel_size=2, + pipeline_model_parallel_size=1, + virtual_pipeline_model_parallel_size=None, + use_sharp=False, + context_parallel_size=1, + expert_model_parallel_size=1, + nccl_communicator_config_path=None, + ) + tensor_parallel.model_parallel_cuda_manual_seed(10) + + is_collect = ( + mpu.get_tensor_model_parallel_rank() == 0 + and mpu.get_pipeline_model_parallel_rank() == mpu.get_pipeline_model_parallel_world_size() - 1 + and mpu.get_context_parallel_rank() == 0 + ) + self._register_dispatch_collect_info( + mesh_name="train", dp_rank=mpu.get_data_parallel_rank(), is_collect=is_collect + ) + + @register(dispatch_mode=Dispatch.ONE_TO_ALL) + def init_model(self): + actor_model_config = LlamaConfig( + vocab_size=256, + hidden_size=2048, + intermediate_size=5504, + num_hidden_layers=24, + num_attention_heads=16, + num_key_value_heads=16, + ) + + megatron_config = mcore_model_parallel_config(sequence_parallel=True, params_dtype=torch.bfloat16) + self.megatron_config = megatron_config + + def megatron_actor_model_provider(pre_process, post_process): + # vpp is not supported yet because it will hang for some reason. Need debugging + # this_megatron_config = copy.deepcopy(megatron_config) + # this_megatron_config.virtual_pipeline_model_parallel_rank = vpp_rank + parallel_model = ParallelLlamaForCausalLMRmPadPP( + config=actor_model_config, + megatron_config=megatron_config, + pre_process=pre_process, + post_process=post_process, + ) + parallel_model.cuda() + return parallel_model + + actor_module = get_model( + model_provider_func=megatron_actor_model_provider, + model_type=ModelType.encoder_or_decoder, + wrap_with_ddp=True, + ) + actor_module = nn.ModuleList(actor_module) + + optim_config = OmegaConf.create({"lr": 1e-6, "clip_grad": 1.0}) + + optim_config = init_megatron_optim_config(optim_config) + self.optimizer_config = optim_config + actor_optimizer = get_megatron_optimizer(model=actor_module, config=optim_config) + + self.model = actor_module[0] + self.optimizer = actor_optimizer + + @register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name="train")) + def train_model(self, data: DataProto) -> DataProto: + input_ids = data.batch["input_ids"] + attention_mask = data.batch["attention_mask"] + position_ids = data.batch["position_ids"] + + self.optimizer.zero_grad() + self.model.zero_grad_buffer( + zero_buffer=(not self.optimizer_config.use_distributed_optimizer) + ) # use use_contiguous_buffers_in_local_ddp and no overlap_dp_param_comm + # update for 1 iteration + output = self.model(input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids).logits + output.mean().backward() + + update_successful, grad_norm, num_zeros_in_grad = self.optimizer.step( + self.megatron_config, self.megatron_config.timers + ) + + return DataProto(batch=TensorDict({"loss": output.detach()}, batch_size=output.shape[0])) + + +if __name__ == "__main__": + ray.init(address="auto", namespace="verl") + + resource_pool = RayResourcePool(process_on_nodes=[2], detached=True) + cls_with_init_args = RayClassWithInitArgs(cls=Trainer) + worker_group = RayWorkerGroup( + resource_pool=resource_pool, + ray_cls_with_init=cls_with_init_args, + name_prefix="trainer", + detached=True, + ) + + worker_group.init_model() + + worker_names = worker_group.worker_names + print(worker_names) diff --git a/code/RL_model/verl/verl_train/tests/special_e2e/envs/__init__.py b/code/RL_model/verl/verl_train/tests/special_e2e/envs/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..eb85e22f361e4af4635bda991ff12a1ed4911eec --- /dev/null +++ b/code/RL_model/verl/verl_train/tests/special_e2e/envs/__init__.py @@ -0,0 +1,17 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .digit_completion import DigitCompletion + +__all__ = ["DigitCompletion"] diff --git a/code/RL_model/verl/verl_train/tests/special_e2e/envs/digit_completion/__init__.py b/code/RL_model/verl/verl_train/tests/special_e2e/envs/digit_completion/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..80893ae41d6669f4f7265ce76d7ac28579b30b6f --- /dev/null +++ b/code/RL_model/verl/verl_train/tests/special_e2e/envs/digit_completion/__init__.py @@ -0,0 +1,22 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from transformers import AutoTokenizer, LlamaConfig + +from .task import DigitCompletion, generate_ground_truth_response +from .tokenizer import CharTokenizer + +AutoTokenizer.register(LlamaConfig, CharTokenizer, exist_ok=True) + +__all__ = ["DigitCompletion", "generate_ground_truth_response", "CharTokenizer"] diff --git a/code/RL_model/verl/verl_train/tests/special_e2e/envs/digit_completion/task.py b/code/RL_model/verl/verl_train/tests/special_e2e/envs/digit_completion/task.py new file mode 100644 index 0000000000000000000000000000000000000000..c3643a86b867b440352ed55dc0f978135ac79bcf --- /dev/null +++ b/code/RL_model/verl/verl_train/tests/special_e2e/envs/digit_completion/task.py @@ -0,0 +1,179 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Task and environment definition for digit completion.""" + +import numpy as np + + +class DigitCompletion: + """ + The implementation of a simple digit completion task. + The prompt is a sequence of numbers with fixed difference. The task is to complete the next N numbers. + If the max number is reached, the next number should be modulo with max number. + + For example, + - prompt = [1, 2, 3] + - N = 5 + - max_number = 6 + + the response should be [4, 5, 6, 7%6, 8%6] = [4, 5, 6, 0, 1] + + Note that the tokenizer is char-level to increase the difficulty. + """ + + def __init__(self, max_number: int, max_diff: int, max_num_in_response: int, seed=0): + """ + + Args: + max_number: the maximum number allowed in the arithmetic sequence + max_diff: the maximum diff. The actual common diff will be sampled from [0, max_diff] + max_num_in_response: the maximum number in the response + """ + super().__init__() + self.max_number = max_number + self.max_diff = max_diff + self.max_num_in_response = max_num_in_response + assert self.max_num_in_response < 10 + assert self.max_number > 0 + assert self.max_diff > 0 + self.max_number_length = len(str(max_number)) + # {num1},{num2}:{max_num_in_response},{max_number} + self._prompt_length = self.max_number_length * 2 + 4 + self.max_number_length # no negative is allowed + + self.np_rng = np.random.default_rng(seed=seed) + + def __str__(self): + return ( + f"Prompt length: {self.prompt_length}. Response length: {self.response_length}, " + f"Max number: {self.max_number}. Max diff: {self.max_diff}, " + f"Max number in response: {self.max_num_in_response}" + ) + + def get_state(self): + return {"rng": self.np_rng} + + def set_state(self, state): + assert "rng" in state, "rng must be inside state" + self.np_rng = state["rng"] + + @property + def prompt_length(self): + return self._prompt_length + + @property + def response_length(self): + # number length + comma length + [EOS] + # The actual number times 1.5 to allow 'U' + return (self.max_num_in_response * self.max_number_length + (self.max_num_in_response - 1) + 1) * 2 + + def add(self, a, b): + return (a + b) % self.max_number + + def get_all_prompts(self): + all_prompts = [] + for first_num in range(self.max_number + 1): + for diff in range(0, self.max_diff + 1): + second_num = self.add(first_num, diff) + for num_to_complete in range(self.max_num_in_response + 1): + prompt = str(first_num) + "," + str(second_num) + f":{self.max_number},{num_to_complete}" + all_prompts.append(prompt) + return all_prompts + + def sample_str_prompts(self): + # step 1: sample initial numbers + first_num = self.np_rng.integers(self.max_number + 1) + diff = self.np_rng.integers(self.max_diff + 1) + second_num = self.add(first_num, diff) + num_to_complete = self.np_rng.integers(self.max_num_in_response + 1) + prompt = str(first_num) + "," + str(second_num) + f":{self.max_number},{num_to_complete}" + return prompt + + def sample_batch_str_prompts(self, batch_size): + str_prompts = [] + for _ in range(batch_size): + str_prompts.append(self.sample_str_prompts()) + return str_prompts + + +def compute_attention_mask(prompts, pad_token_id): + mask = np.ones_like(prompts) + mask[prompts == pad_token_id] = 0 + return mask + + +def compute_position_id_with_mask(mask): + return np.clip(np.cumsum(mask, axis=-1) - 1, a_min=0, a_max=None) + + +def generate_ground_truth_response(prompt: str): + """Generate ground truth response given a prompt.""" + num, info = prompt.split(":") + num1, num2 = num.split(",") + max_number, num_to_gen = info.split(",") + num1 = int(num1) + num2 = int(num2) + max_number = int(max_number) + num_to_gen = int(num_to_gen) + diff = (num2 - num1) % max_number + results = [] + last_num = num2 + for _ in range(num_to_gen): + curr = (last_num + diff) % max_number + results.append(str(curr)) + last_num = curr + response = ",".join(results) + return response + + +def compute_reward(prompt: str, response: str, sequence_reward=1.0): + """We compute dense reward here so that we can directly train RL without SFT""" + response_length = len(response) + ground_truth_response = generate_ground_truth_response(prompt) + per_token_reward = sequence_reward / (len(ground_truth_response) + 1) # including [EOS] + + # pad + reward = np.zeros(response_length, dtype=np.float32) # this assumes that each char is a token + # assign reward until mismatches + ground_truth_idx = 0 + for i in range(response_length): + if ground_truth_idx == len(ground_truth_response): + break + + ground_truth_response_token = ground_truth_response[ground_truth_idx] + response_token = response[i] + if ground_truth_response_token == response_token: + reward[i] = per_token_reward + ground_truth_idx += 1 + else: + # no matches + break + + return reward, {"ground_truth_response": ground_truth_response} + + +if __name__ == "__main__": + task = DigitCompletion(max_number=20, max_diff=3, max_num_in_response=5) + print(task.sample_str_prompts()) + + prompt = "7,8:20,0" + response = "" + print(compute_reward(prompt, response)) + + prompt = "7,8:20,0" + response = "E000" + print(compute_reward(prompt, response)) + + prompt = "9,10:20,2" + response = "11,12,13" + print(compute_reward(prompt, response)) diff --git a/code/RL_model/verl/verl_train/tests/special_e2e/envs/digit_completion/tokenizer.py b/code/RL_model/verl/verl_train/tests/special_e2e/envs/digit_completion/tokenizer.py new file mode 100644 index 0000000000000000000000000000000000000000..6ff471938937dc55ab528cb883e4ba2e03b35416 --- /dev/null +++ b/code/RL_model/verl/verl_train/tests/special_e2e/envs/digit_completion/tokenizer.py @@ -0,0 +1,155 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Copied from https://github.com/dariush-bahrami/character-tokenizer/blob/master/charactertokenizer/core.py + +CharacterTokenzier for Hugging Face Transformers. + +This is heavily inspired from CanineTokenizer in transformers package. +""" + +import json +import os +from pathlib import Path +from typing import Optional, Sequence + +from transformers.tokenization_utils import AddedToken, PreTrainedTokenizer + + +class CharTokenizer(PreTrainedTokenizer): + def __init__(self, characters: Sequence[str], model_max_length: int, chat_template, **kwargs): + """Character tokenizer for Hugging Face transformers. + + Args: + characters (Sequence[str]): List of desired characters. Any character which + is not included in this list will be replaced by a special token called + [UNK] with id=6. Following are list of all of the special tokens with + their corresponding ids: + "[CLS]": 0 + "[SEP]": 1 + "[BOS]": 2 + "[MASK]": 3 + "[PAD]": 4 + "[RESERVED]": 5 + "[UNK]": 6 + an id (starting at 7) will be assigned to each character. + + model_max_length (int): Model maximum sequence length. + """ + eos_token_str = "E" + sep_token_str = "S" + pad_token_str = "P" + unk_token_str = "U" + + self.characters = characters + self.model_max_length = model_max_length + eos_token = AddedToken(eos_token_str, lstrip=False, rstrip=False) + sep_token = AddedToken(sep_token_str, lstrip=False, rstrip=False) + pad_token = AddedToken(pad_token_str, lstrip=False, rstrip=False) + unk_token = AddedToken(unk_token_str, lstrip=False, rstrip=False) + + self._vocab_str_to_int = { + sep_token_str: 0, + eos_token_str: 1, + pad_token_str: 2, + unk_token_str: 3, + **{ch: i + 4 for i, ch in enumerate(characters)}, + } + self._vocab_int_to_str = {v: k for k, v in self._vocab_str_to_int.items()} + + super().__init__( + eos_token=eos_token, + sep_token=sep_token, + pad_token=pad_token, + unk_token=unk_token, + add_prefix_space=False, + model_max_length=model_max_length, + **kwargs, + ) + + self.chat_template = chat_template + + @property + def vocab_size(self) -> int: + return len(self._vocab_str_to_int) + + def get_vocab(self): + return self._vocab_str_to_int + + def _tokenize(self, text: str) -> list[str]: + return list(text) + + def _convert_token_to_id(self, token: str) -> int: + return self._vocab_str_to_int.get(token, self._vocab_str_to_int["U"]) + + def _convert_id_to_token(self, index: int) -> str: + return self._vocab_int_to_str[index] + + def convert_tokens_to_string(self, tokens): + return "".join(tokens) + + def build_inputs_with_special_tokens( + self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None + ) -> list[int]: + sep = [self.sep_token_id] + cls = [self.cls_token_id] + result = cls + token_ids_0 + sep + if token_ids_1 is not None: + result += token_ids_1 + sep + return result + + def get_special_tokens_mask( + self, + token_ids_0: list[int], + token_ids_1: Optional[list[int]] = None, + already_has_special_tokens: bool = False, + ) -> list[int]: + if already_has_special_tokens: + return super().get_special_tokens_mask( + token_ids_0=token_ids_0, + token_ids_1=token_ids_1, + already_has_special_tokens=True, + ) + + result = [1] + ([0] * len(token_ids_0)) + [1] + if token_ids_1 is not None: + result += ([0] * len(token_ids_1)) + [1] + return result + + def get_config(self) -> dict: + return { + "char_ords": [ord(ch) for ch in self.characters], + "model_max_length": self.model_max_length, + "chat_template": self.chat_template, + } + + @classmethod + def from_config(cls, config: dict): + cfg = {} + cfg["characters"] = [chr(i) for i in config["char_ords"]] + cfg["model_max_length"] = config["model_max_length"] + cfg["chat_template"] = config["chat_template"] + return cls(**cfg) + + def save_pretrained(self, save_directory: str | os.PathLike, **kwargs): + cfg_file = Path(save_directory) / "tokenizer_config.json" + cfg = self.get_config() + with open(cfg_file, "w") as f: + json.dump(cfg, f, indent=4) + + @classmethod + def from_pretrained(cls, save_directory: str | os.PathLike, **kwargs): + cfg_file = Path(save_directory) / "tokenizer_config.json" + with open(cfg_file) as f: + cfg = json.load(f) + return cls.from_config(cfg) diff --git a/code/RL_model/verl/verl_train/tests/special_e2e/generation/run_gen_qwen05.sh b/code/RL_model/verl/verl_train/tests/special_e2e/generation/run_gen_qwen05.sh new file mode 100644 index 0000000000000000000000000000000000000000..61c55b157cdaa06b9fa0b977c733397f37c1ec61 --- /dev/null +++ b/code/RL_model/verl/verl_train/tests/special_e2e/generation/run_gen_qwen05.sh @@ -0,0 +1,26 @@ +#!/usr/bin/env bash +# Tested with 1 & 4 GPUs +set -xeuo pipefail + +MODEL_ID=${MODEL_ID:-Qwen/Qwen2.5-0.5B-Instruct} + +NGPUS_PER_NODE=${NGPUS_PER_NODE:-4} +OUTPUT_PATH=${OUTPUT_PATH:-$HOME/data/gen/qwen_05_gen_test.parquet} +GEN_TP=${GEN_TP:-2} # Default tensor parallel size to 2 + +python3 -m verl.trainer.main_generation \ + trainer.nnodes=1 \ + trainer.n_gpus_per_node="${NGPUS_PER_NODE}" \ + data.path="${HOME}/data/gsm8k/test.parquet" \ + data.prompt_key=prompt \ + data.n_samples=1 \ + data.output_path="${OUTPUT_PATH}" \ + model.path="${MODEL_ID}" \ + +model.trust_remote_code=True \ + rollout.temperature=1.0 \ + rollout.top_k=50 \ + rollout.top_p=0.7 \ + rollout.prompt_length=2048 \ + rollout.response_length=1024 \ + rollout.tensor_model_parallel_size="${GEN_TP}" \ + rollout.gpu_memory_utilization=0.8 diff --git a/code/RL_model/verl/verl_train/tests/special_e2e/generation/run_gen_qwen05_server.sh b/code/RL_model/verl/verl_train/tests/special_e2e/generation/run_gen_qwen05_server.sh new file mode 100644 index 0000000000000000000000000000000000000000..0d55b167de6a7153ac29978aee3e52b35680b974 --- /dev/null +++ b/code/RL_model/verl/verl_train/tests/special_e2e/generation/run_gen_qwen05_server.sh @@ -0,0 +1,26 @@ +#!/usr/bin/env bash +# Tested with 1 & 4 GPUs +set -xeuo pipefail + +MODEL_ID=${MODEL_ID:-$HOME/models/Qwen/Qwen2.5-0.5B-Instruct} +NGPUS_PER_NODE=${NGPUS_PER_NODE:-8} +OUTPUT_PATH=${OUTPUT_PATH:-$HOME/data/gen/qwen_05_gen_test.parquet} +GEN_TP=${GEN_TP:-2} # Default tensor parallel size to 2 + +python3 -m verl.trainer.main_generation_server \ + trainer.nnodes=1 \ + trainer.n_gpus_per_node="${NGPUS_PER_NODE}" \ + actor_rollout_ref.model.path="${MODEL_ID}" \ + actor_rollout_ref.model.trust_remote_code=True \ + actor_rollout_ref.rollout.temperature=1.0 \ + actor_rollout_ref.rollout.top_k=50 \ + actor_rollout_ref.rollout.top_p=0.7 \ + actor_rollout_ref.rollout.prompt_length=2048 \ + actor_rollout_ref.rollout.response_length=1024 \ + actor_rollout_ref.rollout.tensor_model_parallel_size="${GEN_TP}" \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.9 \ + actor_rollout_ref.rollout.name=vllm \ + actor_rollout_ref.rollout.n=4 \ + data.train_files="${HOME}/data/gsm8k/test.parquet" \ + data.prompt_key=prompt \ + +data.output_path="${OUTPUT_PATH}" \ diff --git a/code/RL_model/verl/verl_train/tests/special_e2e/ppo_trainer/expert_parallel/qwen2moe_minimal.json b/code/RL_model/verl/verl_train/tests/special_e2e/ppo_trainer/expert_parallel/qwen2moe_minimal.json new file mode 100644 index 0000000000000000000000000000000000000000..c215fa4f7ccb777035e4be513045fb6ddb204b8f --- /dev/null +++ b/code/RL_model/verl/verl_train/tests/special_e2e/ppo_trainer/expert_parallel/qwen2moe_minimal.json @@ -0,0 +1,4 @@ +{ + "num_hidden_layers": 2, + "max_window_layers": 2 +} \ No newline at end of file diff --git a/code/RL_model/verl/verl_train/tests/special_e2e/ppo_trainer/expert_parallel/qwen3moe_minimal.json b/code/RL_model/verl/verl_train/tests/special_e2e/ppo_trainer/expert_parallel/qwen3moe_minimal.json new file mode 100644 index 0000000000000000000000000000000000000000..c215fa4f7ccb777035e4be513045fb6ddb204b8f --- /dev/null +++ b/code/RL_model/verl/verl_train/tests/special_e2e/ppo_trainer/expert_parallel/qwen3moe_minimal.json @@ -0,0 +1,4 @@ +{ + "num_hidden_layers": 2, + "max_window_layers": 2 +} \ No newline at end of file diff --git a/code/RL_model/verl/verl_train/tests/special_e2e/ppo_trainer/run_function_reward.sh b/code/RL_model/verl/verl_train/tests/special_e2e/ppo_trainer/run_function_reward.sh new file mode 100644 index 0000000000000000000000000000000000000000..3607af94df22d361519ab9ca0df4ba548c30993c --- /dev/null +++ b/code/RL_model/verl/verl_train/tests/special_e2e/ppo_trainer/run_function_reward.sh @@ -0,0 +1,165 @@ +#!/usr/bin/env bash +set -xeuo pipefail + +NUM_GPUS=${NUM_GPUS:-8} + +MODEL_ID=${MODEL_ID:-Qwen/Qwen2.5-0.5B} +MODEL_PATH=${MODEL_PATH:-${HOME}/models/${MODEL_ID}} +#hf download "${MODEL_ID}" --local-dir "${MODEL_PATH}" + +TRAIN_FILES=${TRAIN_FILES:-$HOME/data/gsm8k/train.parquet} +VAL_FILES=${VAL_FILES:-$HOME/data/gsm8k/test.parquet} +MAX_PROMPT_LEN=${MAX_PROMPT_LEN:-512} +MAX_RESPONSE_LEN=${MAX_RESPONSE_LEN:-512} + +ENGINE=${ENGINE:-vllm} +if [ "$ENGINE" = "vllm" ]; then + export VLLM_USE_V1=1 +fi +ROLLOUT_MODE="async" + +RETURN_RAW_CHAT="True" +SKIP_TOKENIZER_INIT="True" + +GPU_MEMORY_UTILIZATION=${GPU_MEMORY_UTILIZATION:-0.7} +ACTOR_FSDP_PARAM_OFFLOAD=${ACTOR_FSDP_PARAM_OFFLOAD:-False} +ACTOR_FSDP_OPTIMIZER_OFFLOAD=${ACTOR_FSDP_OPTIMIZER_OFFLOAD:-False} +REF_FSDP_PARAM_OFFLOAD=${REF_FSDP_PARAM_OFFLOAD:-True} +RM_PAD=${RM_PAD:-True} +FUSED_KERNELS=${FUSED_KERNELS:-False} +FUSED_KERNEL_BACKEND=${FUSED_KERNEL_BACKEND:-torch} # or 'triton' for triton backend +ADV_ESTIMATOR=${ADV_ESTIMATOR:-gae} +LOSS_MODE=${LOSS_MODE:-vanilla} +USE_KL=${USE_KL:-False} +CUSTOM_REWARD_FN=${CUSTOM_REWARD_FN:-False} +ENABLE_CHUNKED_PREFILL=${ENABLE_CHUNKED_PREFILL:-True} # For vLLM VLM placeholder issue: https://github.com/vllm-project/vllm/issues/15185 +STRATEGY=${STRATEGY:-fsdp} +# LoRA config +LORA_RANK=${LORA_RANK:-0} +LORA_ALPHA=${LORA_ALPHA:-${LORA_RANK}} +LORA_TARGET=${LORA_TARGET:-"all-linear"} +LORA_EXCLUDE=${LORA_EXCLUDE:-"DONT_EXCLUDE"} +USE_SHM=${USE_SHM:-False} +LOAD_FORMAT=${LOAD_FORMAT:-dummy} +LAYERED_SUMMON=${LAYERED_SUMMON:-False} +# Validation +VAL_BEFORE_TRAIN=${VAL_BEFORE_TRAIN:-False} +TEST_FREQ=${TEST_FREQ:--1} +# Save & Resume +RESUME_MODE=${RESUME_MODE:-disable} +SAVE_FREQ=${SAVE_FREQ:--1} +TOTAL_TRAIN_STEPS=${TOTAL_TRAIN_STEPS:-1} + +# whether to save hf_model +SAVE_HF_MODEL=${SAVE_HF_MODEL:-False} +FSDP_SIZE=${FSDP_SIZE:--1} +SP_SIZE=${SP_SIZE:-1} + +if [ "${SAVE_HF_MODEL}" = "True" ]; then + CHECKPOINT_CONTENTS="['model','hf_model','optimizer','extra']" +else + CHECKPOINT_CONTENTS="['model','optimizer','extra']" +fi + +train_traj_micro_bsz_per_gpu=2 # b +n_resp_per_prompt=4 # g + +train_traj_micro_bsz=$((train_traj_micro_bsz_per_gpu * NUM_GPUS)) # b * n +train_traj_mini_bsz=$((train_traj_micro_bsz * 2)) # 2 * b * n +train_prompt_mini_bsz=$((train_traj_mini_bsz * n_resp_per_prompt)) # 2 * b * n / g +train_prompt_bsz=$((train_prompt_mini_bsz * 2)) # 4 * b * n / g + +reward_fn_name=null +reward_fn_file_path=null +output_file="$(pwd)/output.txt" +if [ "${CUSTOM_REWARD_FN}" = "True" ]; then + reward_fn_name="my_reward_function" + reward_fn_file_path="$(pwd)/my_reward_function.py" + rm -rf "${reward_fn_file_path}" + cat < "$reward_fn_file_path" +def ${reward_fn_name}(data_source, solution_str, ground_truth, extra_info=None): + print(f"Congratulations!!! You have called ${reward_fn_name} successfully!!!") + return 0.1 +EOF + + rm -rf "${output_file}" +fi + +exp_name="${VERL_EXP_NAME:-$(basename "${MODEL_ID,,}")-function-reward-minimal}" + +python3 -m verl.trainer.main_ppo \ + algorithm.adv_estimator="${ADV_ESTIMATOR}" \ + data.train_files="${TRAIN_FILES}" \ + data.val_files="${VAL_FILES}" \ + data.train_batch_size="${train_prompt_bsz}" \ + data.max_prompt_length="${MAX_PROMPT_LEN}" \ + data.max_response_length="${MAX_RESPONSE_LEN}" \ + data.return_raw_chat=${RETURN_RAW_CHAT} \ + actor_rollout_ref.model.path="${MODEL_PATH}" \ + actor_rollout_ref.model.use_shm=${USE_SHM} \ + actor_rollout_ref.model.lora_rank=${LORA_RANK} \ + actor_rollout_ref.model.lora_alpha=${LORA_ALPHA} \ + actor_rollout_ref.model.target_modules=${LORA_TARGET} \ + actor_rollout_ref.model.exclude_modules=${LORA_EXCLUDE} \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + actor_rollout_ref.model.use_remove_padding="${RM_PAD}" \ + actor_rollout_ref.model.use_fused_kernels=${FUSED_KERNELS} \ + actor_rollout_ref.model.fused_kernel_options.impl_backend=${FUSED_KERNEL_BACKEND} \ + actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \ + actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=${train_traj_micro_bsz_per_gpu} \ + actor_rollout_ref.actor.strategy=${STRATEGY} \ + actor_rollout_ref.actor.fsdp_config.param_offload=${ACTOR_FSDP_PARAM_OFFLOAD} \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=${ACTOR_FSDP_OPTIMIZER_OFFLOAD} \ + actor_rollout_ref.actor.fsdp_config.fsdp_size=${FSDP_SIZE} \ + actor_rollout_ref.actor.ulysses_sequence_parallel_size="${SP_SIZE}" \ + actor_rollout_ref.actor.checkpoint.save_contents=${CHECKPOINT_CONTENTS} \ + actor_rollout_ref.actor.use_kl_loss="${USE_KL}" \ + actor_rollout_ref.actor.policy_loss.loss_mode="${LOSS_MODE}" \ + actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=${train_traj_micro_bsz_per_gpu} \ + actor_rollout_ref.rollout.tensor_model_parallel_size=2 \ + actor_rollout_ref.rollout.name="${ENGINE}" \ + actor_rollout_ref.rollout.mode="${ROLLOUT_MODE}" \ + actor_rollout_ref.rollout.load_format=${LOAD_FORMAT} \ + actor_rollout_ref.rollout.layered_summon=${LAYERED_SUMMON} \ + actor_rollout_ref.rollout.skip_tokenizer_init="${SKIP_TOKENIZER_INIT}" \ + actor_rollout_ref.rollout.gpu_memory_utilization="${GPU_MEMORY_UTILIZATION}" \ + actor_rollout_ref.rollout.enable_chunked_prefill="${ENABLE_CHUNKED_PREFILL}" \ + actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=${train_traj_micro_bsz_per_gpu} \ + actor_rollout_ref.ref.fsdp_config.param_offload="${REF_FSDP_PARAM_OFFLOAD}" \ + critic.optim.lr=1e-5 \ + critic.model.use_remove_padding="${RM_PAD}" \ + critic.model.path="${MODEL_PATH}" \ + critic.model.enable_gradient_checkpointing=False \ + critic.ppo_micro_batch_size_per_gpu=${train_traj_micro_bsz_per_gpu} \ + critic.model.fsdp_config.param_offload=False \ + critic.model.fsdp_config.optimizer_offload=False \ + custom_reward_function.path="${reward_fn_file_path}"\ + custom_reward_function.name="${reward_fn_name}"\ + algorithm.use_kl_in_reward="${USE_KL}" \ + algorithm.kl_penalty=kl \ + algorithm.kl_ctrl.kl_coef=0.001 \ + trainer.critic_warmup=0 \ + trainer.logger=console \ + trainer.project_name='verl-test' \ + trainer.experiment_name="${exp_name}" \ + trainer.nnodes=1 \ + trainer.n_gpus_per_node="${NUM_GPUS}" \ + trainer.val_before_train="${VAL_BEFORE_TRAIN}" \ + trainer.test_freq="${TEST_FREQ}" \ + trainer.save_freq="${SAVE_FREQ}" \ + trainer.resume_mode="${RESUME_MODE}" \ + trainer.total_epochs=2 \ + trainer.device=cuda \ + trainer.total_training_steps="${TOTAL_TRAIN_STEPS}" $@ \ + | tee "${output_file}" + +if [ "${CUSTOM_REWARD_FN}" = "True" ]; then + python3 tests/special_e2e/check_custom_rwd_fn.py --output_file="${output_file}" + check_exit_code=$? + rm -rf "${reward_fn_file_path}" + rm -rf "${output_file}" + # Return the exit code of check_custom_rwd_fn.py if it fails + if [ $check_exit_code -ne 0 ]; then + exit $check_exit_code + fi +fi diff --git a/code/RL_model/verl/verl_train/tests/special_e2e/ppo_trainer/run_model_reward.sh b/code/RL_model/verl/verl_train/tests/special_e2e/ppo_trainer/run_model_reward.sh new file mode 100644 index 0000000000000000000000000000000000000000..68eb4171f8e5f1d5b5933ead68b50a67de93da34 --- /dev/null +++ b/code/RL_model/verl/verl_train/tests/special_e2e/ppo_trainer/run_model_reward.sh @@ -0,0 +1,101 @@ +#!/usr/bin/env bash +set -xeuo pipefail + +NUM_GPUS=${NUM_GPUS:-8} + +MODEL_ID=${MODEL_ID:-Qwen/Qwen2.5-0.5B} +MODEL_PATH=${MODEL_PATH:-${HOME}/models/${MODEL_ID}} +#hf download "${MODEL_ID}" --local-dir "${MODEL_PATH}" + +TRAIN_FILES=${TRAIN_FILES:-$HOME/data/gsm8k/train.parquet} +VAL_FILES=${VAL_FILES:-$HOME/data/gsm8k/test.parquet} + +RM_PAD=${RM_PAD:-True} +FUSED_KERNELS=${FUSED_KERNELS:-False} +FUSED_KERNEL_BACKEND=${FUSED_KERNEL_BACKEND:-torch} # or 'triton' for triton backend +SP_SIZE=${SP_SIZE:-1} +SEQ_BALANCE=${SEQ_BALANCE:-False} +LIGER=${LIGER:-False} +# Validation +VAL_BEFORE_TRAIN=${VAL_BEFORE_TRAIN:-False} +TEST_FREQ=${TEST_FREQ:--1} +# Save & Resume +RESUME_MODE=${RESUME_MODE:-disable} +SAVE_FREQ=${SAVE_FREQ:--1} +TOTAL_TRAIN_STEPS=${TOTAL_TRAIN_STEPS:-1} + +train_traj_micro_bsz_per_gpu=2 # b +n_resp_per_prompt=4 # g + +train_traj_micro_bsz=$((train_traj_micro_bsz_per_gpu * NUM_GPUS)) # b * n +train_traj_mini_bsz=$((train_traj_micro_bsz * 2)) # 2 * b * n +train_prompt_mini_bsz=$((train_traj_mini_bsz * n_resp_per_prompt)) # 2 * b * n / g +train_prompt_bsz=$((train_prompt_mini_bsz * 2)) # 4 * b * n / g + +train_max_token_num_per_gpu=32768 +infer_max_token_num_per_gpu=32768 + +exp_name="$(basename "${MODEL_ID,,}")-model-reward-minimal" + +python3 -m verl.trainer.main_ppo \ + algorithm.adv_estimator=gae \ + data.train_files="${TRAIN_FILES}" \ + data.val_files="${VAL_FILES}" \ + data.train_batch_size=${train_prompt_bsz} \ + data.max_prompt_length=512 \ + data.max_response_length=512 \ + data.return_raw_chat=True \ + actor_rollout_ref.model.path="${MODEL_PATH}" \ + actor_rollout_ref.model.use_liger="${LIGER}" \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + actor_rollout_ref.model.use_remove_padding="${RM_PAD}" \ + actor_rollout_ref.model.use_fused_kernels=${FUSED_KERNELS} \ + actor_rollout_ref.model.fused_kernel_options.impl_backend=${FUSED_KERNEL_BACKEND} \ + actor_rollout_ref.actor.optim.lr_warmup_steps_ratio=0.1 \ + actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \ + actor_rollout_ref.actor.use_dynamic_bsz="${SEQ_BALANCE}" \ + actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${train_max_token_num_per_gpu} \ + actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=${train_traj_micro_bsz_per_gpu} \ + actor_rollout_ref.actor.ulysses_sequence_parallel_size="${SP_SIZE}" \ + actor_rollout_ref.actor.fsdp_config.param_offload=False \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \ + actor_rollout_ref.actor.use_kl_loss=False \ + actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_max_token_num_per_gpu} \ + actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=${train_traj_micro_bsz_per_gpu} \ + actor_rollout_ref.rollout.tensor_model_parallel_size=2 \ + actor_rollout_ref.rollout.name=vllm \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.8 \ + actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_max_token_num_per_gpu} \ + actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=${train_traj_micro_bsz_per_gpu} \ + critic.optim.lr=1e-5 \ + critic.ulysses_sequence_parallel_size="${SP_SIZE}" \ + critic.model.use_remove_padding="${RM_PAD}" \ + critic.optim.lr_warmup_steps_ratio=0.05 \ + critic.model.path="${MODEL_PATH}" \ + critic.model.enable_gradient_checkpointing=False \ + critic.use_dynamic_bsz="${SEQ_BALANCE}" \ + critic.ppo_max_token_len_per_gpu=${train_max_token_num_per_gpu} \ + critic.ppo_micro_batch_size_per_gpu=${train_traj_micro_bsz_per_gpu} \ + critic.model.fsdp_config.param_offload=False \ + critic.model.fsdp_config.optimizer_offload=False \ + reward_model.enable=True \ + reward_model.model.path="${MODEL_PATH}" \ + reward_model.use_reward_loop=True \ + reward_model.rollout.gpu_memory_utilization=0.8 \ + reward_model.rollout.tensor_model_parallel_size=1 \ + reward_model.rollout.prompt_length=1024 \ + reward_model.rollout.response_length=512 \ + reward_model.num_workers=8 \ + algorithm.use_kl_in_reward=False \ + trainer.critic_warmup=0 \ + trainer.logger=console \ + trainer.project_name='verl-test' \ + trainer.experiment_name="${exp_name}" \ + trainer.nnodes=1 \ + trainer.n_gpus_per_node="${NUM_GPUS}" \ + trainer.val_before_train="${VAL_BEFORE_TRAIN}" \ + trainer.test_freq="${VAL_BEFORE_TRAIN}" \ + trainer.save_freq="${SAVE_FREQ}" \ + trainer.resume_mode="${RESUME_MODE}" \ + trainer.total_epochs=2 \ + trainer.total_training_steps="${TOTAL_TRAIN_STEPS}" $@ diff --git a/code/RL_model/verl/verl_train/tests/special_e2e/ppo_trainer/run_single_gpu.sh b/code/RL_model/verl/verl_train/tests/special_e2e/ppo_trainer/run_single_gpu.sh new file mode 100644 index 0000000000000000000000000000000000000000..7e8615a24fbaad4b01993ddaa755e2ddb79bfde1 --- /dev/null +++ b/code/RL_model/verl/verl_train/tests/special_e2e/ppo_trainer/run_single_gpu.sh @@ -0,0 +1,24 @@ +PYTHONUNBUFFERED=1 python3 -m verl.trainer.main_ppo \ + data.train_files=$HOME/data/gsm8k/train.parquet \ + data.val_files=$HOME/data/gsm8k/test.parquet \ + data.train_batch_size=256 \ + data.max_prompt_length=512 \ + data.max_response_length=256 \ + actor_rollout_ref.model.path=Qwen/Qwen2.5-0.5B-Instruct \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + actor_rollout_ref.actor.ppo_mini_batch_size=64 \ + actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=4 \ + actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=8 \ + actor_rollout_ref.rollout.tensor_model_parallel_size=1 \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.4 \ + actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=4 \ + critic.optim.lr=1e-5 \ + critic.model.path=Qwen/Qwen2.5-0.5B-Instruct \ + critic.ppo_micro_batch_size_per_gpu=4 \ + algorithm.kl_ctrl.kl_coef=0.001 \ + trainer.logger=console \ + trainer.val_before_train=False \ + trainer.n_gpus_per_node=1 \ + trainer.nnodes=1 \ + actor_rollout_ref.rollout.name=hf \ + trainer.total_training_steps=2 \ No newline at end of file diff --git a/code/RL_model/verl/verl_train/tests/special_e2e/ppo_trainer/run_single_gpu_with_engine.sh b/code/RL_model/verl/verl_train/tests/special_e2e/ppo_trainer/run_single_gpu_with_engine.sh new file mode 100644 index 0000000000000000000000000000000000000000..9f36a9dc8605e37bf70ab3acdf22acd84cdcb0d5 --- /dev/null +++ b/code/RL_model/verl/verl_train/tests/special_e2e/ppo_trainer/run_single_gpu_with_engine.sh @@ -0,0 +1,25 @@ +PYTHONUNBUFFERED=1 python3 -m verl.trainer.main_ppo \ + data.train_files=$HOME/data/gsm8k/train.parquet \ + data.val_files=$HOME/data/gsm8k/test.parquet \ + data.train_batch_size=256 \ + data.max_prompt_length=512 \ + data.max_response_length=256 \ + actor_rollout_ref.model.path=Qwen/Qwen2.5-0.5B-Instruct \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + actor_rollout_ref.actor.ppo_mini_batch_size=64 \ + actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=4 \ + actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=8 \ + actor_rollout_ref.rollout.tensor_model_parallel_size=1 \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.4 \ + actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=4 \ + critic.optim.lr=1e-5 \ + critic.model.path=Qwen/Qwen2.5-0.5B-Instruct \ + critic.ppo_micro_batch_size_per_gpu=4 \ + algorithm.kl_ctrl.kl_coef=0.001 \ + trainer.logger=['console'] \ + trainer.val_before_train=False \ + trainer.n_gpus_per_node=1 \ + trainer.nnodes=1 \ + actor_rollout_ref.rollout.name=hf \ + trainer.use_legacy_worker_impl=disable \ + trainer.total_training_steps=2 \ No newline at end of file diff --git a/code/RL_model/verl/verl_train/tests/special_e2e/sft/compare_sft_engine_results.py b/code/RL_model/verl/verl_train/tests/special_e2e/sft/compare_sft_engine_results.py new file mode 100644 index 0000000000000000000000000000000000000000..322f5353c06e7fd8463b9236a69b8fe078f9adb9 --- /dev/null +++ b/code/RL_model/verl/verl_train/tests/special_e2e/sft/compare_sft_engine_results.py @@ -0,0 +1,58 @@ +# Copyright 2025 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import os + +import torch + + +def get_result(file): + file = os.path.expanduser(file) + result = [] + with open(file) as f: + lines = f.readlines() + for line in lines: + result.append(json.loads(line)) + return result + + +def compare_results(golden_results, other_result): + golden_loss = golden_results[0]["data"]["train/loss"] + golden_grad_norm = golden_results[0]["data"]["train/grad_norm"] + + loss = other_result[0]["data"]["train/loss"] + grad_norm = other_result[0]["data"]["train/grad_norm"] + + torch.testing.assert_close(golden_loss, loss, atol=1e-2, rtol=1e-2) + torch.testing.assert_close(golden_grad_norm, grad_norm, atol=1e-4, rtol=3e-2) + + +if __name__ == "__main__": + golden_results = get_result("~/verl/test/log/golden.jsonl") + + # get all other results + other_results = {} + # walk through all files in ~/verl/test/log + for file in os.listdir(os.path.expanduser("~/verl/test/log/verl_sft_test")): + if file.endswith(".jsonl"): + other_results[file] = get_result(os.path.join(os.path.expanduser("~/verl/test/log/verl_sft_test"), file)) + + # # compare results + for file, other_result in other_results.items(): + print(f"compare results {file}") + compare_results(golden_results, other_result) + print(f"compare results {file} done") + + print("All results are close to golden results") diff --git a/code/RL_model/verl/verl_train/tests/special_e2e/sft/run_sft.sh b/code/RL_model/verl/verl_train/tests/special_e2e/sft/run_sft.sh new file mode 100644 index 0000000000000000000000000000000000000000..4cef7c680824a1df920fe5276863f432f24ebb0c --- /dev/null +++ b/code/RL_model/verl/verl_train/tests/special_e2e/sft/run_sft.sh @@ -0,0 +1,63 @@ +#!/usr/bin/env bash +set -xeuo pipefail + +ENTRYPOINT=${ENTRYPOINT:-"-m verl.trainer.fsdp_sft_trainer"} + +NUM_GPUS=${NUM_GPUS:-8} + +MODEL_ID=${MODEL_ID:-Qwen/Qwen2.5-0.5B-Instruct} +MODEL_PATH=${MODEL_PATH:-${HOME}/models/${MODEL_ID}} +#hf download "${MODEL_ID}" --local-dir "${MODEL_PATH}" + +TRAIN_FILES=${TRAIN_FILES:-$HOME/data/gsm8k/train.parquet} +VAL_FILES=${VAL_FILES:-$HOME/data/gsm8k/test.parquet} + +SP_SIZE=${SP_SIZE:-1} +LIGER=${LIGER:-False} +MULTITURN=${MULTITURN:-False} +LORA_RANK=${LORA_RANK:-0} +RM_PAD=${RM_PAD:-True} + +TOTAL_TRAIN_STEP=${TOTAL_TRAIN_STEP:-1} +RESUME_MODE=${RESUME_MODE:-disable} +SAVE_FREQ=${SAVE_FREQ:-1} + +micro_bsz=2 +NUM_GPUS=8 + +project_name="verl-test" +exp_name="$(basename "${MODEL_ID,,}")-sft-minimal" +ckpts_home=${ckpts_home:-$HOME/${project_name}/${exp_name}} + +mkdir -p "${ckpts_home}" + +torchrun --standalone --nnodes=1 --nproc_per_node=${NUM_GPUS} ${ENTRYPOINT} \ + data.train_files="${TRAIN_FILES}" \ + data.val_files="${VAL_FILES}" \ + data.prompt_key=extra_info \ + data.response_key=extra_info \ + data.prompt_dict_keys=['question'] \ + data.response_dict_keys=['answer'] \ + data.multiturn.enable="${MULTITURN}" \ + data.multiturn.messages_key=messages \ + optim.lr=1e-4 \ + data.micro_batch_size_per_gpu=${micro_bsz} \ + model.strategy=fsdp \ + model.partial_pretrain="${MODEL_PATH}" \ + model.lora_rank="${LORA_RANK}" \ + model.lora_alpha=16 \ + model.target_modules=all-linear \ + model.use_liger="${LIGER}" \ + ulysses_sequence_parallel_size="${SP_SIZE}" \ + use_remove_padding="${RM_PAD}" \ + trainer.default_local_dir="${ckpts_home}" \ + trainer.project_name="${project_name}" \ + trainer.experiment_name="${exp_name}" \ + trainer.total_training_steps=${TOTAL_TRAIN_STEP} \ + trainer.save_freq=${SAVE_FREQ} \ + trainer.checkpoint.save_contents=[model,optimizer,extra,hf_model] \ + trainer.max_ckpt_to_keep=1 \ + trainer.resume_mode=${RESUME_MODE} \ + trainer.logger=['console'] $@ + +rm -rf "${ckpts_home:?}/*" \ No newline at end of file diff --git a/code/RL_model/verl/verl_train/tests/special_e2e/sft/run_sft_engine.sh b/code/RL_model/verl/verl_train/tests/special_e2e/sft/run_sft_engine.sh new file mode 100644 index 0000000000000000000000000000000000000000..f3657ae6d9469d2f21b519d9c10580b24be9dab8 --- /dev/null +++ b/code/RL_model/verl/verl_train/tests/special_e2e/sft/run_sft_engine.sh @@ -0,0 +1,134 @@ +#!/usr/bin/env bash +set -xeuo pipefail + +NUM_GPUS=${NUM_GPUS:-1} + +mode=${mode:-spmd} + +if [ "$mode" = "spmd" ]; then + ENTRYPOINT=${ENTRYPOINT:-"-m verl.trainer.sft_trainer"} + COMMAND="torchrun --standalone --nnodes=${NNODES:-1} --nproc-per-node=${NUM_GPUS:-1} ${ENTRYPOINT}" +else + ENTRYPOINT=${ENTRYPOINT:-"-m verl.trainer.sft_trainer_ray"} + COMMAND="python ${ENTRYPOINT} trainer.nnodes=${NNODES:-1} trainer.n_gpus_per_node=${NUM_GPUS:-1}" +fi + +DATASET_DIR=${DATASET_DIR:-~/data/gsm8k_sft} +TRAIN_FILES=${DATASET_DIR}/train.parquet +VAL_FILES=${DATASET_DIR}/test.parquet + +backend=${BACKEND:-fsdp} + +project_name=verl_sft_test + +RESUME_MODE=disable + +ckpts_home=${ckpts_home:-~/verl/test/gsm8k-sft-${backend}} + +MODEL_ID=${MODEL_ID:-Qwen/Qwen2.5-0.5B} +MODEL_PATH=${MODEL_PATH:-${HOME}/models/${MODEL_ID}} +#hf download "${MODEL_ID}" --local-dir "${MODEL_PATH}" + +SP_SIZE=${SP_SIZE:-1} +FSDP_SIZE=${FSDP_SIZE:-${NUM_GPUS}} +FSDP_STRATEGY=${FSDP_STRATEGY:-"fsdp"} + +TP_SIZE=${TP_SIZE:-1} +PP_SIZE=${PP_SIZE:-1} +VPP_SIZE=${VPP_SIZE:-null} +CP_SIZE=${CP_SIZE:-1} + +PAD_MODE=${PAD_MODE:-no_padding} + +USE_REMOVE_PADDING=${USE_REMOVE_PADDING:-True} + +FSDP_ENGINE_CONFIG="\ + engine=${backend} \ + optim=${backend} \ + optim.lr=1e-5 \ + optim.lr_warmup_steps_ratio=0.2 \ + optim.weight_decay=0.1 \ + optim.betas="[0.9,0.95]" \ + optim.clip_grad=1.0 \ + optim.min_lr_ratio=0.1 \ + optim.lr_scheduler_type=cosine \ + engine.ulysses_sequence_parallel_size=${SP_SIZE} \ + engine.strategy=${FSDP_STRATEGY} \ + engine.fsdp_size=${FSDP_SIZE}" + +VEOMNI_ENGINE_CONFIG="\ + engine=${backend} \ + optim=${backend} \ + optim.lr=1e-5 \ + optim.lr_warmup_steps_ratio=0.2 \ + optim.weight_decay=0.1 \ + optim.betas="[0.9,0.95]" \ + optim.clip_grad=1.0 \ + optim.lr_min=1e-6 \ + optim.lr_scheduler_type=cosine \ + engine.ulysses_parallel_size=${SP_SIZE} \ + engine.data_parallel_mode=${FSDP_STRATEGY} \ + engine.data_parallel_size=${FSDP_SIZE}" + + +MEGATRON_ENGINE_CONFIG="\ + engine=${backend} \ + optim=${backend} \ + optim.lr=1e-5 \ + optim.lr_warmup_steps_ratio=0.2 \ + optim.weight_decay=0.1 \ + optim.betas="[0.9,0.95]" \ + optim.clip_grad=1.0 \ + optim.lr_warmup_init=0 \ + optim.lr_decay_style=cosine \ + optim.min_lr=1e-6 \ + engine.tensor_model_parallel_size=${TP_SIZE} \ + engine.pipeline_model_parallel_size=${PP_SIZE} \ + engine.virtual_pipeline_model_parallel_size=${VPP_SIZE} \ + engine.context_parallel_size=${CP_SIZE} \ + +engine.override_transformer_config.context_parallel_size=${CP_SIZE} \ + engine.use_mbridge=True" + +if [ "$backend" = "fsdp" ]; then + ENGINE_CONFIG="$FSDP_ENGINE_CONFIG" + echo "Using fsdp engine" + exp_name=gsm8k-${backend}-${FSDP_STRATEGY}-sp${SP_SIZE}-fsdp${FSDP_SIZE}-pad-${PAD_MODE}-use_remove_padding-${USE_REMOVE_PADDING}-mode-${mode} +elif [ "$backend" = "veomni" ]; then + ENGINE_CONFIG="$VEOMNI_ENGINE_CONFIG" + echo "Using veomni engine" + exp_name=gsm8k-${backend}-sp${SP_SIZE}-fsdp${FSDP_SIZE}-pad-${PAD_MODE}-use_remove_padding-${USE_REMOVE_PADDING}-mode-${mode} +else + ENGINE_CONFIG="$MEGATRON_ENGINE_CONFIG" + echo "Using megatron engine" + exp_name=gsm8k-${backend}-tp${TP_SIZE}-pp${PP_SIZE}-vpp${VPP_SIZE}-cp${CP_SIZE}-pad-${PAD_MODE}-use_remove_padding-${USE_REMOVE_PADDING}-mode-${mode} +fi + +mkdir -p "${ckpts_home}" + +$COMMAND \ + data.train_files="${TRAIN_FILES}" \ + data.val_files="${VAL_FILES}" \ + data.train_batch_size=128 \ + data.pad_mode=${PAD_MODE} \ + data.truncation=error \ + data.use_dynamic_bsz=True \ + data.max_token_len_per_gpu=2048 \ + data.messages_key=messages \ + model.path=$MODEL_PATH \ + model.use_remove_padding=${USE_REMOVE_PADDING} \ + ${ENGINE_CONFIG} \ + trainer.test_freq=after_each_epoch \ + trainer.save_freq=-1 \ + trainer.logger=['console','file'] \ + trainer.project_name="${project_name}" \ + trainer.experiment_name="${exp_name}" \ + trainer.total_epochs=2 \ + trainer.total_training_steps=2 \ + trainer.default_local_dir="${ckpts_home}" \ + trainer.resume_mode=${RESUME_MODE} \ + + # trainer.total_training_steps=${TOTAL_TRAIN_STEP} \ + # trainer.checkpoint.save_contents=[model,optimizer,extra,hf_model] \ + # trainer.max_ckpt_to_keep=1 \ + +rm -rf "${ckpts_home:?}/*" \ No newline at end of file diff --git a/code/RL_model/verl/verl_train/tests/special_e2e/sft/test_sft_engine_all.sh b/code/RL_model/verl/verl_train/tests/special_e2e/sft/test_sft_engine_all.sh new file mode 100644 index 0000000000000000000000000000000000000000..96f5f1956920a9d97f531f61373620a5c07d3df8 --- /dev/null +++ b/code/RL_model/verl/verl_train/tests/special_e2e/sft/test_sft_engine_all.sh @@ -0,0 +1,42 @@ +#!/usr/bin/env bash +set -xeuo pipefail + +rm -rf ~/verl/test/log +mkdir -p ~/verl/test/log + +export VERL_FILE_LOGGER_ROOT=~/verl/test/log +VPP_SIZE=${VPP_SIZE:-2} + +# test with single gpu as golden +echo "run with single gpu as golden" +BACKEND=fsdp SP_SIZE=1 FSDP_SIZE=1 NUM_GPUS=1 FSDP_STRATEGY=fsdp VERL_FILE_LOGGER_PATH=~/verl/test/log/golden.jsonl bash tests/special_e2e/sft/run_sft_engine.sh + +# test with fsdp 1 +echo "run with sp2 fsdp_size2 num_gpus8 fsdp_strategy fsdp pad_mode no_padding" +BACKEND=fsdp SP_SIZE=2 FSDP_SIZE=2 NUM_GPUS=8 FSDP_STRATEGY=fsdp PAD_MODE=no_padding bash tests/special_e2e/sft/run_sft_engine.sh + +# test with fsdp 1 use_remove_padding and pad_mode no_padding +echo "run with sp4 fsdp_size4 num_gpus8 fsdp_strategy fsdp pad_mode no_padding use_remove_padding False" +BACKEND=fsdp SP_SIZE=1 FSDP_SIZE=-1 NUM_GPUS=8 FSDP_STRATEGY=fsdp PAD_MODE=no_padding USE_REMOVE_PADDING=False bash tests/special_e2e/sft/run_sft_engine.sh + + +# test with fsdp 2 +echo "run with sp2 fsdp_size2 num_gpus8 fsdp_strategy fsdp2" +BACKEND=fsdp SP_SIZE=2 FSDP_SIZE=2 NUM_GPUS=8 FSDP_STRATEGY=fsdp2 bash tests/special_e2e/sft/run_sft_engine.sh + +# test with veomni +echo "run with sp2 fsdp_size4 num_gpus8 fsdp_strategy fsdp2" +BACKEND=veomni SP_SIZE=2 FSDP_SIZE=4 NUM_GPUS=8 FSDP_STRATEGY=fsdp2 bash tests/special_e2e/sft/run_sft_engine.sh + + +# test with megatron +echo "run with tp2 pp2 vpp2 cp2 num_gpus8" +BACKEND=megatron TP_SIZE=2 PP_SIZE=2 VPP_SIZE=${VPP_SIZE} CP_SIZE=2 NUM_GPUS=8 bash tests/special_e2e/sft/run_sft_engine.sh + +# test with cp in ray +echo "run with tp2 pp2 vpp2 cp2 num_gpus8 mode=ray" +BACKEND=megatron TP_SIZE=2 PP_SIZE=2 VPP_SIZE=${VPP_SIZE} CP_SIZE=2 NUM_GPUS=8 mode=ray bash tests/special_e2e/sft/run_sft_engine.sh + +python3 tests/special_e2e/sft/compare_sft_engine_results.py + +rm -rf ~/verl/test/log diff --git a/code/RL_model/verl/verl_train/tests/special_e2e/sft/test_sp_loss_match.py b/code/RL_model/verl/verl_train/tests/special_e2e/sft/test_sp_loss_match.py new file mode 100644 index 0000000000000000000000000000000000000000..5d8e59e721d9359b6030b8fe1a80d09c1e0540e3 --- /dev/null +++ b/code/RL_model/verl/verl_train/tests/special_e2e/sft/test_sp_loss_match.py @@ -0,0 +1,150 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import torch.distributed +from tensordict import TensorDict +from torch.distributed.device_mesh import init_device_mesh + +from verl.trainer.fsdp_sft_trainer import FSDPSFTTrainer +from verl.utils.distributed import initialize_global_process_group + + +def test_trainer_forward_consistency(trainer: FSDPSFTTrainer, total_steps: int = 4): + """Test consistency between original forward pass and SP+rmpad forward passes. + + Args: + trainer: The FSDPSFTTrainer instance to test + total_steps: Number of steps to test (default: 4) + """ + if trainer.device_mesh.get_rank() == 0: + print("\nStarting debug comparison between original and SP+rmpad forward passes...") + print(f"Sequence parallel size: {trainer.config.ulysses_sequence_parallel_size}") + print(f"Remove padding: {trainer.use_remove_padding}\n") + + steps_remaining = total_steps + + for epoch in range(1): # Just one epoch for testing + trainer.train_sampler.set_epoch(epoch=epoch) + for data in trainer.train_dataloader: + data = TensorDict(data, batch_size=trainer.config.data.train_batch_size).cuda() + trainer.fsdp_model.train() + micro_batches = data.split(trainer.config.data.micro_batch_size_per_gpu) + + for idx, micro_batch in enumerate(micro_batches): + if trainer.device_mesh.get_rank() == 0: + print(f"\nProcessing micro batch {idx + 1}/{len(micro_batches)}") + + # Compute losses using both methods + # Disable SP and rmpad + trainer.use_remove_padding = False + old_sp = trainer.config.ulysses_sequence_parallel_size + trainer.config.ulysses_sequence_parallel_size = 1 + loss_ref = trainer._compute_loss_and_backward(micro_batch.copy(), do_backward=False) + + # Do SP and rmpad + trainer.config.ulysses_sequence_parallel_size = old_sp + trainer.use_remove_padding = True + loss_sp = trainer._compute_loss_and_backward(micro_batch.copy(), do_backward=False) + + # Collect losses across all ranks + loss_ref_all = loss_ref.clone() + loss_sp_all = loss_sp.clone() + torch.distributed.all_reduce(loss_ref_all, op=torch.distributed.ReduceOp.AVG) + torch.distributed.all_reduce(loss_sp_all, op=torch.distributed.ReduceOp.AVG) + + # Calculate relative difference of averaged losses + rel_diff = torch.abs(loss_ref_all - loss_sp_all) / (torch.abs(loss_ref_all) + 1e-8) + + if trainer.device_mesh.get_rank() == 0: + print("\nComparison Results (Averaged across ranks):") + print(f"Reference Loss: {loss_ref_all.item():.6f}") + print(f"SP+rmpad Loss: {loss_sp_all.item():.6f}") + print(f"Relative Difference: {rel_diff.item():.6f}") + + assert rel_diff.item() < 1e-2, "Significant difference detected between averaged losses!" + print("Loss difference is within the acceptable range.") + + steps_remaining -= 1 + if steps_remaining == 0: + break + if steps_remaining == 0: + break + break + + if trainer.device_mesh.get_rank() == 0: + print("\nDebug comparison completed successfully.") + + +def create_trainer(config): + """Create and initialize a trainer instance with the given config. + + Args: + config: Configuration object with training parameters + + Returns: + FSDPSFTTrainer: Initialized trainer instance + """ + local_rank, rank, world_size = initialize_global_process_group() + + device_mesh = init_device_mesh(device_type="cuda", mesh_shape=(world_size,), mesh_dim_names=("fsdp",)) + + dp_size = world_size // config.ulysses_sequence_parallel_size + ulysses_device_mesh = init_device_mesh( + device_type="cuda", mesh_shape=(dp_size, config.ulysses_sequence_parallel_size), mesh_dim_names=("dp", "sp") + ) + + # build tokenizer and datasets first + from verl.trainer.fsdp_sft_trainer import create_sft_dataset + from verl.utils import hf_tokenizer + from verl.utils.fs import copy_to_local + + local_model_path = copy_to_local(src=config.model.partial_pretrain, verbose=True) + tokenizer = hf_tokenizer(local_model_path, trust_remote_code=config.model.trust_remote_code) + train_dataset = create_sft_dataset( + config.data.train_files, config.data, tokenizer, max_samples=config.data.get("train_max_samples", -1) + ) + val_dataset = create_sft_dataset( + config.data.val_files, config.data, tokenizer, max_samples=config.data.get("val_max_samples", -1) + ) + + return FSDPSFTTrainer( + config=config, + device_mesh=device_mesh, + ulysses_device_mesh=ulysses_device_mesh, + tokenizer=tokenizer, + train_dataset=train_dataset, + val_dataset=val_dataset, + ) + + +def main(config): + """Main function to run trainer tests. + + Args: + config: Configuration object with training parameters + """ + trainer = create_trainer(config) + test_trainer_forward_consistency(trainer) + + +if __name__ == "__main__": + import hydra + from omegaconf import DictConfig + + @hydra.main(config_path="../../../verl/trainer/config", config_name="sft_trainer") + def hydra_entry(cfg: DictConfig) -> None: + main(cfg) + + hydra_entry() diff --git a/code/RL_model/verl/verl_train/tests/trainer/config/__init__.py b/code/RL_model/verl/verl_train/tests/trainer/config/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..1ce90c5eb352d85c59105c0dc85b5f1dd576f095 --- /dev/null +++ b/code/RL_model/verl/verl_train/tests/trainer/config/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/code/RL_model/verl/verl_train/tests/utils/ckpt/test_checkpoint_cleanup_on_cpu.py b/code/RL_model/verl/verl_train/tests/utils/ckpt/test_checkpoint_cleanup_on_cpu.py new file mode 100644 index 0000000000000000000000000000000000000000..166208a4fc3493342818d6b530686b7e84015816 --- /dev/null +++ b/code/RL_model/verl/verl_train/tests/utils/ckpt/test_checkpoint_cleanup_on_cpu.py @@ -0,0 +1,139 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import shutil +import tempfile + +import pytest + + +class TestCheckpointCleanupLogic: + """Tests for checkpoint cleanup methods in BaseCheckpointManager.""" + + @pytest.fixture(autouse=True) + def setup(self): + """Set up test fixtures.""" + self.test_dir = tempfile.mkdtemp() + yield + shutil.rmtree(self.test_dir, ignore_errors=True) + + @pytest.fixture + def manager(self, monkeypatch): + """Create a minimal BaseCheckpointManager for testing.""" + import torch.distributed + + monkeypatch.setattr(torch.distributed, "get_rank", lambda: 0) + monkeypatch.setattr(torch.distributed, "get_world_size", lambda: 1) + + from verl.utils.checkpoint.checkpoint_manager import BaseCheckpointManager + + class MockModel: + pass + + class MockOptimizer: + pass + + return BaseCheckpointManager( + model=MockModel(), + optimizer=MockOptimizer(), + lr_scheduler=None, + processing_class=None, + checkpoint_config=None, + ) + + def _create_checkpoint_dir(self, step: int) -> str: + """Create a mock checkpoint directory.""" + path = os.path.join(self.test_dir, f"global_step_{step}") + os.makedirs(path, exist_ok=True) + with open(os.path.join(path, "checkpoint.txt"), "w") as f: + f.write(f"step={step}") + return path + + def test_max_ckpt_1_preserves_existing_before_save(self, manager): + """ + Regression test: max_ckpt_to_keep=1 must NOT delete existing checkpoint before save. + """ + ckpt_100 = self._create_checkpoint_dir(100) + manager.previous_saved_paths = [ckpt_100] + + manager.ensure_checkpoint_capacity(max_ckpt_to_keep=1) + + assert os.path.exists(ckpt_100), "Bug: checkpoint deleted before save!" + assert manager.previous_saved_paths == [ckpt_100] + + def test_max_ckpt_1_deletes_old_after_save(self, manager): + """After save succeeds, old checkpoint should be deleted.""" + ckpt_100 = self._create_checkpoint_dir(100) + manager.previous_saved_paths = [ckpt_100] + + ckpt_200 = self._create_checkpoint_dir(200) + manager.register_checkpoint(ckpt_200, max_ckpt_to_keep=1) + + assert not os.path.exists(ckpt_100) + assert os.path.exists(ckpt_200) + assert manager.previous_saved_paths == [ckpt_200] + + def test_max_ckpt_2_keeps_one_before_save(self, manager): + """With max_ckpt_to_keep=2, pre-save cleanup keeps 1 checkpoint.""" + ckpt_100 = self._create_checkpoint_dir(100) + ckpt_200 = self._create_checkpoint_dir(200) + manager.previous_saved_paths = [ckpt_100, ckpt_200] + + manager.ensure_checkpoint_capacity(max_ckpt_to_keep=2) + + assert not os.path.exists(ckpt_100) + assert os.path.exists(ckpt_200) + assert len(manager.previous_saved_paths) == 1 + + def test_max_ckpt_0_keeps_all(self, manager): + """max_ckpt_to_keep=0 means unlimited - no deletions.""" + ckpt_100 = self._create_checkpoint_dir(100) + ckpt_200 = self._create_checkpoint_dir(200) + manager.previous_saved_paths = [ckpt_100, ckpt_200] + + manager.ensure_checkpoint_capacity(max_ckpt_to_keep=0) + ckpt_300 = self._create_checkpoint_dir(300) + manager.register_checkpoint(ckpt_300, max_ckpt_to_keep=0) + + assert os.path.exists(ckpt_100) + assert os.path.exists(ckpt_200) + assert os.path.exists(ckpt_300) + assert len(manager.previous_saved_paths) == 3 + + def test_full_save_cycle_max_ckpt_1(self, manager): + """Simulate multiple save cycles with max_ckpt_to_keep=1.""" + # First save + manager.ensure_checkpoint_capacity(1) + ckpt_100 = self._create_checkpoint_dir(100) + manager.register_checkpoint(ckpt_100, 1) + assert manager.previous_saved_paths == [ckpt_100] + + # Second save - existing checkpoint must survive pre-save + manager.ensure_checkpoint_capacity(1) + assert os.path.exists(ckpt_100), "Bug: checkpoint deleted before save!" + + ckpt_200 = self._create_checkpoint_dir(200) + manager.register_checkpoint(ckpt_200, 1) + assert not os.path.exists(ckpt_100) + assert manager.previous_saved_paths == [ckpt_200] + + # Third save + manager.ensure_checkpoint_capacity(1) + assert os.path.exists(ckpt_200), "Bug: checkpoint deleted before save!" + + ckpt_300 = self._create_checkpoint_dir(300) + manager.register_checkpoint(ckpt_300, 1) + assert not os.path.exists(ckpt_200) + assert manager.previous_saved_paths == [ckpt_300] diff --git a/code/RL_model/verl/verl_train/tests/utils/ckpt/test_esi_save_ckpt_on_cpu.py b/code/RL_model/verl/verl_train/tests/utils/ckpt/test_esi_save_ckpt_on_cpu.py new file mode 100644 index 0000000000000000000000000000000000000000..203494bd90bd9676fd615f5db5576e94c0219ee9 --- /dev/null +++ b/code/RL_model/verl/verl_train/tests/utils/ckpt/test_esi_save_ckpt_on_cpu.py @@ -0,0 +1,70 @@ +# Copyright 2025 Bytedance Ltd. and/or its affiliates + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +import time +from datetime import datetime, timedelta +from unittest import TestCase + +from verl.utils.checkpoint.checkpoint_manager import should_save_ckpt_esi + + +class TestShouldSaveCkptEsi(TestCase): + def test_no_expiration_timestamp(self): + """Test case when no expiration timestamp is set""" + os.environ.pop("MLP_CURRENT_CAPACITY_BLOCK_EXPIRATION_TIMESTAMP", None) + os.environ.pop("SAGEMAKER_CURRENT_CAPACITY_BLOCK_EXPIRATION_TIMESTAMP", None) + self.assertFalse(should_save_ckpt_esi(100)) + + def test_mlp_expiration_valid(self): + """Test valid MLP expiration timestamp requiring save""" + current_time = time.time() + os.environ["MLP_CURRENT_CAPACITY_BLOCK_EXPIRATION_TIMESTAMP"] = str(current_time + 90) + self.assertTrue(should_save_ckpt_esi(30)) # max_steps_duration=30 seconds + + def test_mlp_expiration_passed(self): + """Test expired MLP timestamp""" + current_time = time.time() + os.environ["MLP_CURRENT_CAPACITY_BLOCK_EXPIRATION_TIMESTAMP"] = str(current_time - 10) + self.assertFalse(should_save_ckpt_esi(30)) + + def test_mlp_invalid_timestamp(self): + """Test invalid MLP timestamp format""" + os.environ["MLP_CURRENT_CAPACITY_BLOCK_EXPIRATION_TIMESTAMP"] = "invalid" + self.assertFalse(should_save_ckpt_esi(30)) + + def test_mlp_expiration_not_reached(self): + """Test MLP expiration timestamp with insufficient remaining time""" + current_time = time.time() + os.environ["MLP_CURRENT_CAPACITY_BLOCK_EXPIRATION_TIMESTAMP"] = str(current_time + 200) + self.assertFalse(should_save_ckpt_esi(30)) # max_steps_duration=30 + + def test_aws_expiration_not_reached(self): + """Test AWS expiration timestamp with sufficient remaining time""" + now = datetime.now() + expiration = now + timedelta(minutes=100) # Exceeds 90-minute threshold + os.environ["SAGEMAKER_CURRENT_CAPACITY_BLOCK_EXPIRATION_TIMESTAMP"] = str(int(expiration.timestamp())) + self.assertFalse(should_save_ckpt_esi(30 * 60)) + + def test_redundant_time(self): + """Test redundant_time parameter effect""" + current_time = time.time() + # Total required: 60+30+30=120 seconds + os.environ["MLP_CURRENT_CAPACITY_BLOCK_EXPIRATION_TIMESTAMP"] = str(current_time + 120) + self.assertTrue(should_save_ckpt_esi(30, redundant_time=30)) + + def test_zero_max_steps_duration(self): + """Test zero max_steps_duration""" + current_time = time.time() + os.environ["MLP_CURRENT_CAPACITY_BLOCK_EXPIRATION_TIMESTAMP"] = str(current_time + 60) + self.assertFalse(should_save_ckpt_esi(0)) diff --git a/code/RL_model/verl/verl_train/tests/utils/dataset/test_create_rl_sampler_on_cpu.py b/code/RL_model/verl/verl_train/tests/utils/dataset/test_create_rl_sampler_on_cpu.py new file mode 100644 index 0000000000000000000000000000000000000000..35bf5a3ab5bd32544b2eec487e96ef61312766b9 --- /dev/null +++ b/code/RL_model/verl/verl_train/tests/utils/dataset/test_create_rl_sampler_on_cpu.py @@ -0,0 +1,108 @@ +# Copyright 2025 Amazon.com Inc and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +test create_rl_sampler +""" + +from collections.abc import Sized + +import pytest +import torch +from omegaconf import DictConfig, OmegaConf +from torch.utils.data import Dataset, RandomSampler + +from verl.experimental.dataset.sampler import AbstractCurriculumSampler +from verl.trainer.main_ppo import create_rl_sampler + + +class RandomCurriculumSampler(AbstractCurriculumSampler): + def __init__( + self, + data_source: Sized, + data_config: DictConfig, + ): + train_dataloader_generator = torch.Generator() + train_dataloader_generator.manual_seed(1) + sampler = RandomSampler(data_source=data_source) + self.sampler = sampler + + def __iter__(self): + return self.sampler.__iter__() + + def __len__(self) -> int: + return len(self.sampler) + + def update(self, batch) -> None: + return + + +class MockIncorrectSampler: + """A fake sampler class that does not adhere to the AbstractCurriculumSampler interface.""" + + def __init__(self, data_source, data_config): + pass + + +class MockChatDataset(Dataset): + def __init__(self): + self.data = [ + {"prompt": "What's your name?", "response": "My name is Assistant."}, + {"prompt": "How are you?", "response": "I'm doing well, thank you."}, + {"prompt": "What is the capital of France?", "response": "Paris."}, + { + "prompt": "Tell me a joke.", + "response": "Why did the chicken cross the road? To get to the other side!", + }, + {"prompt": "What is 2+2?", "response": "4"}, + ] + + def __getitem__(self, index): + return self.data[index] + + def __len__(self): + return len(self.data) + + +def test_create_custom_curriculum_samper(): + data_config = OmegaConf.create( + { + "dataloader_num_workers": 0, + "sampler": { + "class_path": "pkg://tests.utils.dataset.test_create_rl_sampler_on_cpu", + "class_name": "RandomCurriculumSampler", + }, + } + ) + + dataset = MockChatDataset() + + # doesn't raise + create_rl_sampler(data_config, dataset) + + +def test_create_custom_curriculum_samper_wrong_class(): + data_config = OmegaConf.create( + { + "sampler": { + "class_path": "pkg://tests.utils.dataset.test_create_rl_sampler_on_cpu", + "class_name": "MockIncorrectSampler", + } + } + ) + + dataset = MockChatDataset() + + # MockIncorrectSampler is not an instance of AbstractCurriculumSampler, so raises + with pytest.raises(AssertionError): + create_rl_sampler(data_config, dataset) diff --git a/code/RL_model/verl/verl_train/tests/utils/dataset/test_multiturn_sft_dataset_on_cpu.py b/code/RL_model/verl/verl_train/tests/utils/dataset/test_multiturn_sft_dataset_on_cpu.py new file mode 100644 index 0000000000000000000000000000000000000000..a55417ce839446ffd52291e500cb6f182ba01c5f --- /dev/null +++ b/code/RL_model/verl/verl_train/tests/utils/dataset/test_multiturn_sft_dataset_on_cpu.py @@ -0,0 +1,445 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Test the MultiTurnSFTDataset implementation +""" + +import os +from io import BytesIO +from pathlib import Path + +import pandas as pd +import pytest +import torch +from PIL import Image +from tensordict import TensorDict +from torch.utils.data import DistributedSampler +from torchdata.stateful_dataloader import StatefulDataLoader +from transformers import AutoProcessor, AutoTokenizer +from transformers.utils import get_json_schema + +from verl.utils.dataset.dataset_utils import DatasetPadMode, SFTTensorCollator +from verl.utils.dataset.multiturn_sft_dataset import MultiTurnSFTDataset +from verl.utils.model import extract_multi_modal_inputs + +custom_model_prefix = Path("~/models").expanduser().resolve() + + +@pytest.mark.parametrize( + "model_path", + [ + f"{custom_model_prefix}/Qwen/Qwen2.5-0.5B", + f"{custom_model_prefix}/Qwen/Qwen2.5-Coder-7B-Instruct", + f"{custom_model_prefix}/Qwen/Qwen3-30B-A3B-Instruct-2507", + # "Qwen/Qwen3-30B-A3B-Thinking-2507" # Thinking series models add tags to last turn. + ], +) +@pytest.mark.parametrize("enable_thinking", [False, True]) +def test_multiturn_sft_dataset(model_path: str, enable_thinking: bool): + print(f"Starting test... model_path={model_path}, enable_thinking={enable_thinking}") + # Create a temporary parquet file with test data + test_data = { + "messages": [ + [ + {"role": "user", "content": "What is 2+2?"}, + {"role": "assistant", "content": "2+2 equals 4."}, + {"role": "user", "content": "And what is 4+4?"}, + {"role": "assistant", "content": "4+4 equals 8."}, + ], + [ + {"role": "system", "content": "You are a powerful assistant."}, + {"role": "user", "content": "Tell me a joke."}, + {"role": "assistant", "content": "Why did the chicken cross the road?"}, + {"role": "user", "content": "Why?"}, + {"role": "assistant", "content": "To get to the other side!"}, + ], + ] + } + + # Create test directory if it doesn't exist + os.makedirs("test_data", exist_ok=True) + test_file = "test_data/test.parquet" + + # Save test data to parquet + df = pd.DataFrame(test_data) + df.to_parquet(test_file) + + # Initialize tokenizer and dataset + tokenizer = AutoTokenizer.from_pretrained(model_path) + config = { + "max_length": 512, + "truncation": "error", + "multiturn": {"messages_key": "messages"}, + "apply_chat_template_kwargs": {"enable_thinking": enable_thinking}, + } + dataset = MultiTurnSFTDataset(parquet_files=test_file, tokenizer=tokenizer, config=config) + + # Test 1: Dataset Length + assert len(dataset) == 2, f"Expected dataset length 2, got {len(dataset)}" + + # Get items for testing + item0 = dataset[0] # Math conversation + item1 = dataset[1] # Joke conversation + + # Test 2: Required Keys and Types + required_keys = ["input_ids", "attention_mask", "position_ids", "loss_mask"] + for key in required_keys: + assert key in item0, f"Missing key {key} in dataset item" + assert isinstance(item0[key], torch.Tensor), f"Expected torch.Tensor for {key}" + assert item0[key].dtype == torch.long, f"Expected torch.long for {key}, got {item0[key].dtype}" + + # Test 3: Shape Consistency + assert item0["loss_mask"].shape == item0["input_ids"].shape, "Loss mask shape doesn't match input_ids shape" + assert item0["attention_mask"].shape == item0["input_ids"].shape, ( + "Attention mask shape doesn't match input_ids shape" + ) + assert item0["position_ids"].shape == item0["input_ids"].shape, "Position IDs shape doesn't match input_ids shape" + + # Test 4: Loss Mask Pattern - Math Conversation + loss_mask0 = item0["loss_mask"] + input_ids0 = item0["input_ids"] + + # Find assistant response positions + assistant_positions0 = torch.where(loss_mask0 == 1)[0] + assert len(assistant_positions0) > 0, "No assistant positions found in loss mask" + + # Decode and verify assistant responses + assistant_text0 = tokenizer.decode(input_ids0[loss_mask0 == 1]) + print(f"Math conversation assistant text: {assistant_text0}") + assert "2+2 equals 4" in assistant_text0, "First assistant response not found" + assert "4+4 equals 8" in assistant_text0, "Second assistant response not found" + + # Test 5: Loss Mask Pattern - Joke Conversation + loss_mask1 = item1["loss_mask"] + input_ids1 = item1["input_ids"] + + # Find assistant response positions + assistant_positions1 = torch.where(loss_mask1 == 1)[0] + assert len(assistant_positions1) > 0, "No assistant positions found in loss mask" + + # Decode and verify assistant responses + assistant_text1 = tokenizer.decode(input_ids1[loss_mask1 == 1]) + print(f"Joke conversation assistant text: {assistant_text1}") + assert "chicken cross the road" in assistant_text1, "First assistant response not found" + assert "other side" in assistant_text1, "Second assistant response not found" + + # Test 6: Attention Mask Pattern + attention_mask0 = item0["attention_mask"] + sequence_length = torch.sum(attention_mask0) + assert sequence_length > 0, "No tokens marked as attended in attention mask" + assert torch.all(attention_mask0[:sequence_length] == 1), "Incorrect attention mask pattern" + if sequence_length < len(attention_mask0): + assert torch.all(attention_mask0[sequence_length:] == 0), "Padding not properly masked" + + # Test 7: Position IDs Pattern + position_ids0 = item0["position_ids"] + assert torch.equal(position_ids0[:sequence_length], torch.arange(sequence_length)), ( + "Position IDs not sequential for non-padded tokens" + ) + if sequence_length < len(position_ids0): + assert torch.all(position_ids0[sequence_length:] == 0), "Padding position IDs not zero" + + # Test 8: Verify loss mask for assistant responses + # Get the full conversation text + full_text = tokenizer.decode(input_ids0) + print(f"\nFull conversation text:\n{full_text}") + + # Get the assistant responses + assistant_text = tokenizer.decode(input_ids0[loss_mask0 == 1]) + print(f"\nAssistant responses (from loss mask):\n{assistant_text}") + + # Verify that loss mask is set for all assistant responses + for msg in test_data["messages"][0]: # First conversation + if msg["role"] == "assistant": + # The content should appear in the masked text + assert msg["content"] in assistant_text, f"Assistant message '{msg['content']}' not found in masked text" + + # The content should NOT appear in the non-masked text + non_assistant_text = tokenizer.decode(input_ids0[loss_mask0 == 0]) + assert msg["content"] not in non_assistant_text, ( + f"Assistant message '{msg['content']}' found in non-assistant text" + ) + + # Test 9: Verify non-assistant parts have loss_mask=0 + # Get non-assistant text + non_assistant_text = tokenizer.decode(input_ids0[loss_mask0 == 0]) + print(f"\nNon-assistant text (from loss mask):\n{non_assistant_text}") + + # Verify that system and user messages are in the non-assistant text + for msg in test_data["messages"][0]: # First conversation + if msg["role"] in ["system", "user"]: + assert msg["content"] in non_assistant_text, ( + f"{msg['role'].title()} message '{msg['content']}' not found in non-assistant text" + ) + + # And verify they're NOT in the assistant text + assert msg["content"] not in assistant_text, ( + f"{msg['role'].title()} message '{msg['content']}' found in assistant text" + ) + + # Test 10: Verify padding behavior + padding_config = {"max_length": 1024, "truncation": "error", "multiturn": {"messages_key": "messages"}} + small_dataset = MultiTurnSFTDataset(parquet_files=test_file, tokenizer=tokenizer, config=padding_config) + padded_item = small_dataset[0] + + # Get actual sequence length (before padding) + actual_length = torch.sum(padded_item["attention_mask"]) + + # Verify padding tokens + assert torch.all(padded_item["input_ids"][actual_length:] == tokenizer.pad_token_id), ( + "Padding tokens not set correctly" + ) + assert torch.all(padded_item["attention_mask"][actual_length:] == 0), "Attention mask not set correctly for padding" + assert torch.all(padded_item["loss_mask"][actual_length:] == 0), "Loss mask not set correctly for padding" + + # test no-padding + config = { + "max_length": 512, + "truncation": "error", + "multiturn": {"messages_key": "messages"}, + "pad_mode": "no_padding", + } + dataset = MultiTurnSFTDataset(parquet_files=test_file, tokenizer=tokenizer, config=config) + + item0 = dataset[0] + + # Verify that the output contains expected keys for no-padding mode + required_keys = ["input_ids", "position_ids", "loss_mask"] + for key in required_keys: + assert key in item0, f"Missing key {key} in no-padding mode dataset item" + assert isinstance(item0[key], torch.Tensor), f"Expected torch.Tensor for {key} in no-padding mode" + + # make sure assistant_text matches with expected + assistant_text = tokenizer.decode(item0["input_ids"][item0["loss_mask"] == 1]) + assert assistant_text == "2+2 equals 4.<|im_end|>\n4+4 equals 8.<|im_end|>\n" + + print("All tests passed!") + print("Starting test...") + + +def generate_image(description: str, size: str = "256x256"): + """Generate a simple image based on description. + + Args: + description: The description of the image to generate. + size: The size of the image. Defaults to "256x256". (choices: ["256x256", "512x512"]) + + Returns: + A generated image + """ + ... + + +@pytest.fixture +def vlm_data_file(): + test_data = [ + # sample 0: single turn with image input + { + "messages": [ + { + "role": "user", + "content": "Describe this image.", + }, + { + "role": "assistant", + "content": "The image is a red square.", + }, + ], + "images": [Image.new("RGB", (300, 300), color="red")], + "tools": [], + }, + # sample 1: single turn with multiple images input + { + "messages": [ + { + "role": "user", + "content": "Compare these images.", + }, + { + "role": "assistant", + "content": "The first image is a red square and the second image is a green square.", + }, + ], + "images": [Image.new("RGB", (100, 100), color="red"), Image.new("RGB", (100, 300), color="green")], + "tools": [], + }, + # sample 2: multi turn with image input and tool generated image + { + "messages": [ + { + "role": "user", + "content": "Describe this image.", + }, + { + "role": "assistant", + "content": "Let's generate a zoom-in image.", + "tool_calls": [ + { + "function": {"arguments": '{"bbox_2d": "[0, 1, 2, 4]"}', "name": "image_zoom_in_tool"}, + "type": "function", + } + ], + }, + { + "role": "tool", + "content": "Generated image.", + }, + {"role": "assistant", "content": "The zoom-in image is a red square."}, + ], + "images": [Image.new("RGB", (300, 500), color="red"), Image.new("RGB", (100, 100), color="red")], + "tools": [get_json_schema(generate_image)], + }, + # sample 3: single turn without image input + { + "messages": [ + {"role": "user", "content": "How is the weather today?"}, + {"role": "assistant", "content": "The weather is sunny."}, + ], + "images": [], + "tools": [], + }, + ] + + # Create test directory if it doesn't exist + os.makedirs("test_data", exist_ok=True) + test_file = "test_data/test_vlm.parquet" + + # Save test data to parquet + df = pd.DataFrame(test_data) + + def serialize_image(img): + if isinstance(img, Image.Image): + img_byte_arr = BytesIO() + img.save(img_byte_arr, format="PNG") + return {"bytes": img_byte_arr.getvalue()} + return img + + df["images"] = df["images"].apply(lambda x: [serialize_image(img) for img in x]) + + df.to_parquet(test_file) + return test_file + + +def test_multiturn_sft_vlm_dataset_on_cpu(vlm_data_file): + df = pd.read_parquet(vlm_data_file) + model_path = f"{custom_model_prefix}/Qwen/Qwen3-VL-2B-Instruct" + tokenizer = AutoTokenizer.from_pretrained(model_path) + processor = AutoProcessor.from_pretrained(model_path) + config = {"max_length": 512, "pad_mode": "no_padding", "truncation": "error", "messages_key": "messages"} + dataset = MultiTurnSFTDataset(parquet_files=vlm_data_file, tokenizer=tokenizer, config=config, processor=processor) + assert dataset.pad_mode == DatasetPadMode.NO_PADDING + + for i in range(len(dataset)): + item = dataset[i] + input_ids = item["input_ids"] + loss_mask = item["loss_mask"] + position_ids = item["position_ids"] + pixel_values = item.get("multi_modal_inputs", {}).get("pixel_values") + image_grid_thw = item.get("multi_modal_inputs", {}).get("image_grid_thw") + + assert input_ids.shape == loss_mask.shape, "Shapes of input_ids and loss_mask must be equal" + assert position_ids.dim() == 2, "position_ids must be 2-dimensional" + assert position_ids.shape[0] == 4, f"position_ids[0] should be 4: {position_ids[0]}" + assert position_ids.shape[1] == input_ids.shape[0] + + # 1. verify input_ids without assistant text + text = tokenizer.decode(input_ids[loss_mask == 0], skip_special_tokens=True) + print(f"Text without assistant: {repr(text)}") + for message in df["messages"][i]: + if message["role"] != "assistant": + content = message["content"].replace("", "") + assert content in text, f"user/tool text should be in the input_ids: {text}" + + # 2. verify input_ids with assistant text + text = tokenizer.decode(input_ids[loss_mask == 1], skip_special_tokens=True) + print(f"Text with assistant: {repr(text)}") + for message in df["messages"][i]: + if message["role"] == "assistant": + assert message["content"] in text, f"Assistant text should be in the input_ids: {text}" + assert "assistant" not in text, f"Assistant token should not be in the input_ids: {text}" + + # 3. verify image token match with image_grid_thw + if len(df["images"][i]) > 0: + patch_size = processor.image_processor.patch_size + temporal_patch_size = processor.image_processor.temporal_patch_size + merge_size = processor.image_processor.merge_size + num_patches = image_grid_thw.prod(dim=1).sum() + assert image_grid_thw.shape == (len(df["images"][i]), 3), ( + f"image_grid_thw: {image_grid_thw.shape} should have shape ({len(df['images'][i])}, 3)" + ) + assert pixel_values.shape == (num_patches, 3 * temporal_patch_size * patch_size * patch_size), ( + f"pixel_values: {pixel_values.shape} should have shape ({num_patches}, {3 * patch_size * patch_size})" + ) + assert (input_ids == processor.image_token_id).sum() == num_patches // (merge_size**2) + else: + assert pixel_values is None, "pixel_values should be None when no image is provided" + assert image_grid_thw is None, "image_grid_thw should be None when no image is provided" + + +def test_multiturn_sft_vlm_dataloader_on_cpu(vlm_data_file): + df = pd.read_parquet(vlm_data_file) + model_path = f"{custom_model_prefix}/Qwen/Qwen3-VL-2B-Instruct" + tokenizer = AutoTokenizer.from_pretrained(model_path) + processor = AutoProcessor.from_pretrained(model_path) + config = {"max_length": 512, "pad_mode": "no_padding", "truncation": "error", "messages_key": "messages"} + dataset = MultiTurnSFTDataset(parquet_files=vlm_data_file, tokenizer=tokenizer, config=config, processor=processor) + assert dataset.pad_mode == DatasetPadMode.NO_PADDING + + collate_fn = SFTTensorCollator(DatasetPadMode.NO_PADDING) + sampler = DistributedSampler(dataset, shuffle=False, num_replicas=1, rank=0, drop_last=True) + batch_size = 2 + dataloader = StatefulDataLoader( + dataset=dataset, + batch_size=batch_size, + sampler=sampler, + collate_fn=collate_fn, + num_workers=0, + pin_memory=False, + drop_last=True, + ) + + for i, batch in enumerate(dataloader): + # 1. verify input_ids, loss_mask + input_ids = batch["input_ids"] + loss_mask = batch["loss_mask"] + assert input_ids.is_nested, "input_ids should be a nested tensor" + assert loss_mask.is_nested, "loss_mask should be a nested tensor" + assert input_ids.shape[0] == loss_mask.shape[0] == batch_size, "Shapes of input_ids, loss_mask must be equal" + + # 2. verify position_ids: (bs, 4, seq_len) + position_ids = batch["position_ids"] + assert position_ids.is_nested, "position_ids should be a nested tensor" + assert position_ids.dim() == 3, "position_ids must be 3-dimensional" + assert position_ids.shape[0] == batch_size + assert position_ids.shape[1] == 4 + values = position_ids.values() + assert values.shape == (4, len(input_ids.values())) + + # 3. verify multi-modal data + td = TensorDict(**batch, batch_size=batch_size) + multi_modal_inputs = extract_multi_modal_inputs(td["multi_modal_inputs"]) + pixel_values = multi_modal_inputs["pixel_values"] + image_grid_thw = multi_modal_inputs["image_grid_thw"] + + num_images = sum([len(images) for images in df["images"][i * batch_size : (i + 1) * batch_size]]) + assert image_grid_thw.shape == (num_images, 3), ( + f"image_grid_thw: {image_grid_thw.shape} should have shape ({num_images}, 3)" + ) + patch_size = processor.image_processor.patch_size + temporal_patch_size = processor.image_processor.temporal_patch_size + num_patches = image_grid_thw.prod(dim=1).sum() + assert pixel_values.shape[0] == num_patches, ( + f"pixel_values: {pixel_values.shape} should have shape " + f"({num_patches}, 3 * {temporal_patch_size} * {patch_size} * {patch_size})" + ) diff --git a/code/RL_model/verl/verl_train/tests/utils/dataset/test_rl_collate_fn_on_cpu.py b/code/RL_model/verl/verl_train/tests/utils/dataset/test_rl_collate_fn_on_cpu.py new file mode 100644 index 0000000000000000000000000000000000000000..415595295e7fde5d4de648284091bc87c53b4a10 --- /dev/null +++ b/code/RL_model/verl/verl_train/tests/utils/dataset/test_rl_collate_fn_on_cpu.py @@ -0,0 +1,72 @@ +# Copyright 2025 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import torch + + +def test_rl_collate_fn(): + from verl.utils.dataset.rl_dataset import collate_fn + + max_prompt_length = 5 + + test_data = [ + { + # test tensor + "input_ids": torch.randint(0, 10, (max_prompt_length,)), + # test fixed length (1) list within a batch + "messages": [{"role": "user", "content": "Hi."}], + # test variable length list within a batch + "raw_prompt_ids": [1, 2, 3, 4], + # test string + "ability": "math", + # test dict + "reward_model": {"ground_truth": 5, "style": "rule"}, + # test empty dict + "tools_kwargs": {}, + }, + { + "input_ids": torch.randint(0, 10, (max_prompt_length,)), + "messages": [{"role": "user", "content": "Hello."}], + "raw_prompt_ids": [1, 2, 3], + "ability": "toolcall", + "reward_model": { + "ground_truth": '[{"name": "rgb_to_cmyk", "arguments": {"r": 0, "g": 0, "b": 255}}]', + "style": "rule", + }, + "tools_kwargs": {}, + }, + ] + + batch_size = len(test_data) + batch = collate_fn(test_data) + + # Tensor part + assert batch["input_ids"].shape == (batch_size, max_prompt_length) + assert isinstance(batch["input_ids"], torch.Tensor) + + # Non-tensor parts + expected_types = { + "messages": list, + "raw_prompt_ids": list, + "ability": str, + "reward_model": dict, + "tools_kwargs": dict, + } + + for key, dtype in expected_types.items(): + assert batch[key].shape == (batch_size,), ( + f"Expected shape {(batch_size,)} for '{key}', but got {batch[key].shape}" + ) + assert isinstance(batch[key][0], dtype), ( + f"'{key}' should contain elements of type {dtype}, but got {type(batch[key][0])}" + ) diff --git a/code/RL_model/verl/verl_train/tests/utils/dataset/test_rl_dataset_on_cpu.py b/code/RL_model/verl/verl_train/tests/utils/dataset/test_rl_dataset_on_cpu.py new file mode 100644 index 0000000000000000000000000000000000000000..05ebdbab98ce217da3c84aaf450c5dd2fe1b5abf --- /dev/null +++ b/code/RL_model/verl/verl_train/tests/utils/dataset/test_rl_dataset_on_cpu.py @@ -0,0 +1,197 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import json +import os + +import pytest +import torch +from omegaconf import OmegaConf +from PIL import Image +from torch.utils.data import DataLoader + +from verl import DataProto +from verl.utils import hf_processor, hf_tokenizer +from verl.utils.dataset.rl_dataset import RLHFDataset, collate_fn + + +def get_gsm8k_data(): + # prepare test dataset + local_folder = os.path.expanduser("~/data/gsm8k/") + local_path = os.path.join(local_folder, "train.parquet") + os.makedirs(local_folder, exist_ok=True) + return local_path + + +def test_rl_dataset(): + tokenizer = hf_tokenizer(os.path.expanduser("~/models/deepseek-ai/deepseek-coder-1.3b-instruct")) + local_path = get_gsm8k_data() + config = OmegaConf.create( + { + "prompt_key": "prompt", + "max_prompt_length": 256, + "filter_overlong_prompts": True, + "filter_overlong_prompts_workers": 2, + } + ) + dataset = RLHFDataset(data_files=local_path, tokenizer=tokenizer, config=config) + + dataloader = DataLoader(dataset=dataset, batch_size=16, shuffle=True, drop_last=True, collate_fn=collate_fn) + + a = next(iter(dataloader)) + + tensors = {} + non_tensors = {} + + for key, val in a.items(): + if isinstance(val, torch.Tensor): + tensors[key] = val + else: + non_tensors[key] = val + + data_proto = DataProto.from_dict(tensors=tensors, non_tensors=non_tensors) + assert len(data_proto) == 16 + assert "raw_prompt" in data_proto.non_tensor_batch + + +def test_rl_dataset_with_max_samples(): + tokenizer = hf_tokenizer(os.path.expanduser("~/models/deepseek-ai/deepseek-coder-1.3b-instruct")) + local_path = get_gsm8k_data() + config = OmegaConf.create( + { + "prompt_key": "prompt", + "max_prompt_length": 256, + "filter_overlong_prompts": True, + "filter_overlong_prompts_workers": 2, + "max_samples": 5, + } + ) + dataset = RLHFDataset(data_files=local_path, tokenizer=tokenizer, config=config, max_samples=5) + assert len(dataset) == 5 + + +def test_image_rl_data(): + tokenizer = hf_tokenizer(os.path.expanduser("~/models/Qwen/Qwen2-VL-2B-Instruct")) + processor = hf_processor(os.path.expanduser("~/models/Qwen/Qwen2-VL-2B-Instruct")) + config = OmegaConf.create( + { + "prompt_key": "prompt", + "max_prompt_length": 1024, + "filter_overlong_prompts": True, + "filter_overlong_prompts_workers": None, # num_workers=1 hang in ci + } + ) + dataset = RLHFDataset( + data_files=os.path.expanduser("~/data/geo3k/train.parquet"), + tokenizer=tokenizer, + config=config, + processor=processor, + ) + + dataloader = DataLoader(dataset=dataset, batch_size=16, shuffle=True, drop_last=True, collate_fn=collate_fn) + + a = next(iter(dataloader)) + + tensors = {} + non_tensors = {} + + for key, val in a.items(): + if isinstance(val, torch.Tensor): + tensors[key] = val + else: + non_tensors[key] = val + + data_proto = DataProto.from_dict(tensors=tensors, non_tensors=non_tensors) + assert len(data_proto) == 16 + assert "images" not in data_proto.non_tensor_batch + + for prompt in data_proto.non_tensor_batch["raw_prompt"]: + assert len(prompt) == 1 + prompt = prompt[0] + role, content = prompt["role"], prompt["content"] + assert role == "user" + assert len(content) == 2 + assert content[0]["type"] == "image" and isinstance(content[0]["image"], Image.Image) + assert content[1]["type"] == "text" and isinstance(content[1]["text"], str) + + print("raw_prompt", data_proto.non_tensor_batch["raw_prompt"][0]) + + +@pytest.fixture +def video_data_file(): + data = [ + { + "problem_id": 17, + "problem": "How does the crowd's excitement change as the match progresses?", + "data_type": "video", + "prompt": [ + { + "role": "user", + "content": [ + {"type": "video", "video": "LLaVA-Video-178K/academic_source/activitynet/v_2g9GrshWQrU.mp4"}, + { + "type": "text", + "text": "How does the crowd's excitement change as the match progresses? " + "A. It fluctuates; B. It decreases; C. It builds up; D. It remains the same. " + "Put your answer in ", + }, + ], + } + ], + "problem_type": "multiple choice", + "solution": "C", + "data_source": "LLaVA-Video-178K/2_3_m_academic_v0_1", + } + ] * 30 + + # Create test directory if it doesn't exist + os.makedirs("test_data", exist_ok=True) + test_file = "test_data/test_video.json" + with open(test_file, "w") as f: + json.dump(data, f, indent=2) + + return test_file + + +def test_video_rl_data(video_data_file): + tokenizer = hf_tokenizer(os.path.expanduser("~/models/Qwen/Qwen2-VL-2B-Instruct")) + processor = hf_processor(os.path.expanduser("~/models/Qwen/Qwen2-VL-2B-Instruct")) + config = OmegaConf.create( + { + "prompt_key": "prompt", + "max_prompt_length": 1024, + "filter_overlong_prompts": False, + } + ) + dataset = RLHFDataset( + data_files=video_data_file, + tokenizer=tokenizer, + config=config, + processor=processor, + ) + + dataloader = DataLoader(dataset=dataset, batch_size=16, shuffle=True, drop_last=True, collate_fn=collate_fn) + batch = next(iter(dataloader)) + tensors = {} + non_tensors = {} + for key, val in batch.items(): + if isinstance(val, torch.Tensor): + tensors[key] = val + else: + non_tensors[key] = val + + data_proto = DataProto.from_dict(tensors=tensors, non_tensors=non_tensors) + assert len(data_proto) == 16 + assert "images" not in data_proto.non_tensor_batch + + print("raw_prompt", data_proto.non_tensor_batch["raw_prompt"][0]) diff --git a/code/RL_model/verl/verl_train/tests/utils/dataset/test_sft_dataset_on_cpu.py b/code/RL_model/verl/verl_train/tests/utils/dataset/test_sft_dataset_on_cpu.py new file mode 100644 index 0000000000000000000000000000000000000000..be91b598091727b18462dae7cc46d580bdf9660e --- /dev/null +++ b/code/RL_model/verl/verl_train/tests/utils/dataset/test_sft_dataset_on_cpu.py @@ -0,0 +1,97 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os + +from verl.utils import hf_tokenizer +from verl.utils.dataset.sft_dataset import SFTDataset + + +def get_gsm8k_data(): + # prepare test dataset + local_folder = os.path.expanduser("~/data/gsm8k/") + local_path = os.path.join(local_folder, "train.parquet") + return local_path + + +def test_sft_cot_dataset(): + tokenizer = hf_tokenizer(os.path.expanduser("~/models/deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct")) + local_path = get_gsm8k_data() + from omegaconf import OmegaConf + + dataset = SFTDataset( + parquet_files=local_path, + tokenizer=tokenizer, + config=OmegaConf.create( + { + "prompt_key": "prompt", + "prompt_dict_keys": ["content"], + "response_key": "extra_info", + "response_dict_keys": ["answer"], + "max_length": 512, + } + ), + ) + + data = dataset[0]["input_ids"] + output = tokenizer.batch_decode([data])[0] + assert len(output) > 1 + assert isinstance(output, str) + + +def test_sft_dataset(): + tokenizer = hf_tokenizer(os.path.expanduser("~/models/deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct")) + local_path = get_gsm8k_data() + from omegaconf import OmegaConf + + dataset = SFTDataset( + parquet_files=local_path, + tokenizer=tokenizer, + config=OmegaConf.create( + { + "prompt_key": "extra_info", + "prompt_dict_keys": ["question"], + "response_key": "extra_info", + "response_dict_keys": ["answer"], + "max_length": 512, + } + ), + ) + + data = dataset[0]["input_ids"] + output = tokenizer.batch_decode([data])[0] + assert len(output) > 1 + assert isinstance(output, str) + + +def test_sft_dataset_with_max_samples(): + tokenizer = hf_tokenizer(os.path.expanduser("~/models/deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct")) + local_path = get_gsm8k_data() + from omegaconf import OmegaConf + + dataset = SFTDataset( + parquet_files=local_path, + tokenizer=tokenizer, + config=OmegaConf.create( + { + "prompt_key": "extra_info", + "prompt_dict_keys": ["question"], + "response_key": "extra_info", + "response_dict_keys": ["answer"], + "max_length": 512, + } + ), + max_samples=5, + ) + + assert len(dataset) == 5 diff --git a/code/RL_model/verl/verl_train/tests/utils/debug/test_metrics.py b/code/RL_model/verl/verl_train/tests/utils/debug/test_metrics.py new file mode 100644 index 0000000000000000000000000000000000000000..1b2f7f8faa17dd024df4b92c3a3b1b81d48923e0 --- /dev/null +++ b/code/RL_model/verl/verl_train/tests/utils/debug/test_metrics.py @@ -0,0 +1,48 @@ +# Copyright 2025 Individual Contributor: TomQunChaoA +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import torch + +from verl.protocol import DataProto +from verl.utils.debug.metrics import calculate_debug_metrics + + +class TestMetrics(unittest.TestCase): + def test_calculate_debug_metrics(self): + data = DataProto.from_dict( + { + "rollout_log_probs": torch.tensor( + [ + [-1.5085, -0.1200, -0.6650, -0.4823, -0.1426, -1.5557, -2.8532, -0.3919, -0.4294, -0.4700], + [-0.0585, -0.0573, -0.4681, -0.5187, -0.7451, -1.2737, -0.0682, -0.4284, -0.5754, -0.0611], + ] + ), + "old_log_probs": torch.tensor( + [ + [-1.8636, -0.7863, -0.2136, -0.4376, -2.0257, -0.2579, -1.1547, -0.5203, -0.3802, -0.9872], + [-0.3507, -0.5426, -0.2725, -0.4637, -0.3577, -0.3733, -1.7560, -1.9542, -0.4229, -1.3098], + ] + ), + "loss_mask": torch.tensor([[1, 0, 0, 0, 1, 1, 0, 1, 1, 0], [1, 0, 1, 0, 1, 1, 1, 0, 1, 1]]), + "responses": torch.zeros((2, 10)), + } + ) + metrics = calculate_debug_metrics(data) + print(metrics) + assert metrics["training/rollout_probs_diff_valid"] == 1 + + +if __name__ == "__main__": + unittest.main() diff --git a/code/RL_model/verl/verl_train/tests/utils/megatron/test_pipeline_parallel.py b/code/RL_model/verl/verl_train/tests/utils/megatron/test_pipeline_parallel.py new file mode 100644 index 0000000000000000000000000000000000000000..24a416987dae68089a3d26d18f34d5defbd14245 --- /dev/null +++ b/code/RL_model/verl/verl_train/tests/utils/megatron/test_pipeline_parallel.py @@ -0,0 +1,70 @@ +# Copyright 2025 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest + +from verl.model_merger.megatron_model_merger import get_dynamic_pipeline_shards +from verl.utils.megatron.pipeline_parallel import make_batch_generator + + +def test_make_batch_generator_no_vpp(): + batches = [1, 2, 3] + vpp_size = 1 + generator = make_batch_generator(batches, vpp_size) + assert list(generator) == batches + + +def test_make_batch_generator_with_vpp(): + batches = [{"data": 1}, {"data": 2}] + vpp_size = 2 + generators = make_batch_generator(batches, vpp_size) + assert isinstance(generators, list) + assert len(generators) == vpp_size + + # Check each generator yields the original batches + for gen in generators: + assert list(gen) == batches + + +def test_make_batch_generator_empty(): + batches = [] + vpp_size = 1 + generator = make_batch_generator(batches, vpp_size) + assert list(generator) == [] + + vpp_size = 3 + generators = make_batch_generator(batches, vpp_size) + assert len(generators) == vpp_size + for gen in generators: + assert list(gen) == [] + + +@pytest.mark.parametrize( + "layer_num,pp_size,gt", + [ + (61, 8, [6, 8, 8, 8, 8, 8, 8, 7]), + (61, 7, [8, 9, 9, 9, 9, 9, 8]), + (61, 1, [61]), + (61, 0, ValueError), + (10, 16, ValueError), + ], +) +def test_get_dynamic_pipeline_shards(layer_num, pp_size, gt): + if isinstance(gt, list): + shards = get_dynamic_pipeline_shards(layer_num, pp_size) + assert len(shards) == len(gt) == pp_size, f"Expected {pp_size} shards, got {len(shards)}" + assert all([shard == gt[i] for i, shard in enumerate(shards)]), f"Expected shards {gt}, got {shards}" + elif issubclass(gt, Exception): + with pytest.raises(gt): + shards = get_dynamic_pipeline_shards(layer_num, pp_size) diff --git a/code/RL_model/verl/verl_train/tests/utils/reward_score/reward_score/test_sandbox_fusion_on_cpu.py b/code/RL_model/verl/verl_train/tests/utils/reward_score/reward_score/test_sandbox_fusion_on_cpu.py new file mode 100644 index 0000000000000000000000000000000000000000..83aed24d054ddce33bc8fd311de2705fcca24776 --- /dev/null +++ b/code/RL_model/verl/verl_train/tests/utils/reward_score/reward_score/test_sandbox_fusion_on_cpu.py @@ -0,0 +1,747 @@ +# Copyright 2025 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import multiprocessing +import os +import time +from concurrent.futures import ProcessPoolExecutor +from unittest.mock import patch + +import pytest + +# Import the function to be tested +from verl.utils.reward_score.sandbox_fusion.utils import check_correctness + +# Get SANDBOX_URL from environment variable +SANDBOX_URL = os.environ.get("SANDBOX_FUSION_URL") +# Define skip condition and reason +skip_reason = "SANDBOX_FUSION_URL environment variable not set" +skip_condition = not SANDBOX_URL + +# --- Test code (for real API calls) --- +CODE_SUCCESS = """ +import sys +data = sys.stdin.read() +if data == 'input1': + print('output1\\n', end='') +elif data == 'input2': + print('output2\\n', end='') +else: + print('unexpected input', end='') +""" + +CODE_WRONG_OUTPUT = """ +print('wrong_output\\n', end='') +""" + +CODE_COMPILE_ERROR = """ +a=b +""" + +CODE_RUNTIME_ERROR = """ +import sys +print("About to raise error", file=sys.stderr) +raise ValueError("This is a runtime error") +""" + +CODE_TIMEOUT = """ +import time +import sys +print("Sleeping...", file=sys.stderr) +time.sleep(10) # Sleep time should be longer than the timeout set in the test +print("Finished sleeping", file=sys.stderr) +""" + +# --- Test input/output data --- +INPUT_OUTPUT_VALID = {"inputs": ["input1", "input2"], "outputs": ["output1\n", "output2\n"]} + +INPUT_OUTPUT_SINGLE = {"inputs": ["input1"], "outputs": ["output1\n"]} + +INPUT_OUTPUT_MISMATCH = {"inputs": ["input1"], "outputs": ["output1\n", "output2\n"]} + +INPUT_OUTPUT_INVALID_MISSING_KEY = {"inputs": ["input1"]} + +# --- Integration test cases (calling real API) --- + + +@pytest.mark.skipif(skip_condition, reason=skip_reason) +def test_integration_success_correct(): + """Integration test: Code is correct, output is correct""" + results, metadata_list = check_correctness(SANDBOX_URL, INPUT_OUTPUT_VALID, CODE_SUCCESS) + assert results == [True, True] + assert metadata_list[0]["status"] == "success" + assert metadata_list[0]["stdout"] == "output1\n" + assert metadata_list[1]["status"] == "success" + assert metadata_list[1]["stdout"] == "output2\n" + + +@pytest.mark.skipif(skip_condition, reason=skip_reason) +def test_integration_success_wrong_output(): + """Integration test: Code runs successfully, but output is wrong""" + results, metadata_list = check_correctness(SANDBOX_URL, INPUT_OUTPUT_VALID, CODE_WRONG_OUTPUT) + assert results == [False, False] + assert metadata_list[0]["status"] == "wrong_answer" + assert metadata_list[0]["stdout"] == "wrong_output\n" + assert metadata_list[1]["status"] == "wrong_answer" + + +@pytest.mark.skipif(skip_condition, reason=skip_reason) +def test_integration_compile_error(): + """Integration test: Code causes compile error""" + results, metadata_list = check_correctness(SANDBOX_URL, INPUT_OUTPUT_VALID, CODE_COMPILE_ERROR, language="cpp") + assert results == [-4, -4] + assert metadata_list[0]["status"] == "compile_error" + assert metadata_list[1]["status"] == "compile_error" + + +@pytest.mark.skipif(skip_condition, reason=skip_reason) +def test_integration_runtime_error(): + """Integration test: Code causes runtime error""" + results, metadata_list = check_correctness(SANDBOX_URL, INPUT_OUTPUT_SINGLE, CODE_RUNTIME_ERROR) + assert results == [-2] + assert metadata_list[0]["status"] == "runtime_error" + # More assertions can be added based on the actual API response, e.g., exit_code, stderr + + +@pytest.mark.skipif(skip_condition, reason=skip_reason) +def test_integration_runtime_timeout(): + """Integration test: Code causes runtime timeout""" + test_timeout = 5 # Set a timeout shorter than the sleep time in CODE_TIMEOUT + results, metadata_list = check_correctness(SANDBOX_URL, INPUT_OUTPUT_SINGLE, CODE_TIMEOUT, timeout=test_timeout) + assert results == [-3] + assert metadata_list[0]["status"] == "timeout" + # More assertions can be added based on the actual API response, e.g., run_status + + +@pytest.mark.skipif(skip_condition, reason=skip_reason) +def test_integration_concurrency_high_load(): + """Integration test: High concurrency (100 cases) against real API with mixed results (success, wrong + answer, timeout)""" + concurrency_level = 100 + # Indices for different expected outcomes + wrong_answer_indices = {10, 25, 50} + timeout_indices = {5, 30, 60, 90} # Indices where we expect a timeout + + # Generate 100 input/output pairs and code + high_load_inputs = [] + high_load_outputs = [] + expected_results_map = {} # Store expected result for each index + + for i in range(concurrency_level): + if i in timeout_indices: + # Use a special input to trigger timeout in the code + high_load_inputs.append(f"input_timeout_{i}") + # Output doesn't matter for timeout, but keep it consistent + high_load_outputs.append(f"output_{i}\n") + expected_results_map[i] = -3 # Expect timeout + elif i in wrong_answer_indices: + high_load_inputs.append(f"input_{i}") + # Intentionally set wrong expected output + high_load_outputs.append(f"wrong_output_{i}\n") + expected_results_map[i] = False # Expect wrong answer + else: + high_load_inputs.append(f"input_{i}") + # Correct expected output + high_load_outputs.append(f"output_{i}\n") + expected_results_map[i] = True # Expect success + + high_load_in_outs = {"inputs": high_load_inputs, "outputs": high_load_outputs} + + # Code that handles normal inputs, and sleeps on specific "timeout" inputs + code_mixed_concurrent = """ +import sys +import time +data = sys.stdin.read() +if data.startswith('input_timeout_'): + time.sleep(20) # Sleep longer than the test timeout + print(f"output_{data.split('_')[-1]}\\n", end='') # Still print something in case it finishes early +elif data.startswith('input_'): + print(f"output_{data.split('_')[-1]}\\n", end='') +else: + print("unknown_input\\n", end='') +""" + # Set a reasonable timeout per case (must be less than the sleep time in the code) + test_timeout = 15 # Allow slightly more time due to potential API load, but less than 20s sleep + + start_time = time.time() + results, metadata_list = check_correctness( + SANDBOX_URL, + high_load_in_outs, + code_mixed_concurrent, # Use the new code + timeout=test_timeout, + ) + end_time = time.time() + duration = end_time - start_time + print( + f"\nHigh concurrency test ({concurrency_level} cases with {len(wrong_answer_indices)} wrong answers, " + f"{len(timeout_indices)} timeouts) duration: {duration:.2f} seconds" + ) + + # Verify results against the expected map + assert len(results) == concurrency_level, f"Expected {concurrency_level} results, got {len(results)}" + + correct_count = 0 + wrong_count = 0 + timeout_count = 0 + unexpected_results = [] + for i, r in enumerate(results): + expected = expected_results_map[i] + if r == expected: + if expected is True: + correct_count += 1 + elif expected is False: + wrong_count += 1 + elif expected == -3: + timeout_count += 1 + else: + unexpected_results.append((i, r, f"Expected {expected}")) + + print( + f"Correct results (True): {correct_count}/" + f"{concurrency_level - len(wrong_answer_indices) - len(timeout_indices)}" + ) + print(f"Expected wrong answers (False, correctly identified): {wrong_count}/{len(wrong_answer_indices)}") + print(f"Expected timeouts (-3, correctly identified): {timeout_count}/{len(timeout_indices)}") + + if unexpected_results: + print("Unexpected results found:") + for idx, res, expected_str in unexpected_results[:10]: # Print first 10 unexpected + print(f" Index {idx}: Got {res}, {expected_str}. Metadata: {metadata_list[idx]}") + raise AssertionError(f"Found {len(unexpected_results)} unexpected results.") + + assert correct_count == concurrency_level - len(wrong_answer_indices) - len(timeout_indices), ( + "Incorrect number of successful results" + ) + assert wrong_count == len(wrong_answer_indices), "Incorrect number of identified wrong answers" + assert timeout_count == len(timeout_indices), "Incorrect number of identified timeouts" + + # Verify metadata count and basic status of one of each type + assert len(metadata_list) == concurrency_level + # Find the first correct index + first_correct_index = next( + i for i in range(concurrency_level) if i not in wrong_answer_indices and i not in timeout_indices + ) + assert metadata_list[first_correct_index]["status"] == "success" + assert metadata_list[first_correct_index]["stdout"] == f"output_{first_correct_index}\n" + + # Check the status of the first intentionally wrong case + first_wrong_index = min(wrong_answer_indices) + assert metadata_list[first_wrong_index]["status"] == "wrong_answer" + assert metadata_list[first_wrong_index]["stdout"] == f"output_{first_wrong_index}\n" + assert metadata_list[first_wrong_index]["expected_output"] == f"wrong_output_{first_wrong_index}\n" + + # Check the status of the first intentionally timeout case + first_timeout_index = min(timeout_indices) + assert metadata_list[first_timeout_index]["status"] == "timeout" + # For timeout, stdout might be None or empty depending on when the timeout occurred + # assert metadata_list[first_timeout_index]["stdout"] is None or metadata_list[first_timeout_index]["stdout"] == "" + + +# --- Unit test cases (using mock) --- + + +@patch("verl.utils.reward_score.sandbox_fusion.utils.call_sandbox_api") +def test_unit_concurrency_order(mock_call_sandbox_api): + sandbox_url = "mock_url" + generation = "print(input())" + language = "python" + timeout = 5 + in_outs = {"inputs": ["input1", "input2", "input3"], "outputs": ["output1", "output2", "output3"]} + + def side_effect(*args, **kwargs): + stdin = kwargs.get("stdin") + if stdin == "input1": + return ( + {"status": "Success", "run_result": {"status": "Finished", "stdout": "output1", "return_code": 0}}, + None, + ) + elif stdin == "input2": + time.sleep(0.1) + return ( + {"status": "Success", "run_result": {"status": "Finished", "stdout": "output2", "return_code": 0}}, + None, + ) + elif stdin == "input3": + return ( + {"status": "Success", "run_result": {"status": "Finished", "stdout": "output3", "return_code": 0}}, + None, + ) + else: + return (None, "Unknown input in mock") + + mock_call_sandbox_api.side_effect = side_effect + + results, metadata_list = check_correctness(sandbox_url, in_outs, generation, timeout, language) + + assert results == [True, True, True] + assert len(metadata_list) == 3 + assert metadata_list[0]["case_index"] == 0 + assert metadata_list[0]["status"] == "success" + assert metadata_list[1]["case_index"] == 1 + assert metadata_list[1]["status"] == "success" + assert metadata_list[2]["case_index"] == 2 + assert metadata_list[2]["status"] == "success" + assert mock_call_sandbox_api.call_count == 3 + + +@patch("verl.utils.reward_score.sandbox_fusion.utils.call_sandbox_api") +def test_unit_api_timeout_error_concurrent(mock_call_sandbox_api): + sandbox_url = "mock_url" + generation = "print(input())" + language = "python" + timeout = 5 + in_outs = {"inputs": ["input1", "input2_timeout", "input3"], "outputs": ["output1", "output2", "output3"]} + + api_error_message = "API Call Failed: Gateway Timeout (504) on attempt 3/3" + + def side_effect(*args, **kwargs): + stdin = kwargs.get("stdin") + if stdin == "input1": + return ( + {"status": "Success", "run_result": {"status": "Finished", "stdout": "output1", "return_code": 0}}, + None, + ) + elif stdin == "input2_timeout": + return (None, api_error_message) + elif stdin == "input3": + return ( + {"status": "Success", "run_result": {"status": "Finished", "stdout": "output3", "return_code": 0}}, + None, + ) + else: + return (None, "Unknown input in mock") + + mock_call_sandbox_api.side_effect = side_effect + + results, metadata_list = check_correctness(sandbox_url, in_outs, generation, timeout, language) + + assert results == [True, -1, True] + assert len(metadata_list) == 3 + assert metadata_list[0]["status"] == "success" + assert metadata_list[1]["status"] == "api_error" + assert metadata_list[1]["api_request_error"] == api_error_message + assert metadata_list[2]["status"] == "success" + assert mock_call_sandbox_api.call_count == 3 + + +# --- Constants for the new concurrency test --- +# Define a low global concurrency limit to test the semaphore's effect +MAX_GLOBAL_CONCURRENCY_LIMIT_TEST = 5 +# Define the number of processes used in the test +NUM_PROCESSES_TEST = 4 +# Define the number of tasks processed by check_correctness in each process (i.e., internal +# ThreadPoolExecutor's concurrency potential) +NUM_TASKS_PER_PROCESS_TEST = 3 +# Simulate API call duration to ensure calls can overlap +SIMULATED_API_CALL_DURATION_TEST = 0.2 # seconds + + +# --- Mock API call function for concurrency tracking --- +# This function will replace the real call_sandbox_api and use shared variables to track concurrency +def _mock_api_call_for_concurrency_tracking( + active_calls_counter, # multiprocessing.Value + max_calls_tracker, # multiprocessing.Value + call_lock, # multiprocessing.Lock + # Standard call_sandbox_api parameters + sandbox_fusion_url, + code, + stdin, + compile_timeout, + run_timeout, + memory_limit_mb, + language, +): + # entry_time = time.time() # For detailed logging + with call_lock: + active_calls_counter.value += 1 + if active_calls_counter.value > max_calls_tracker.value: + max_calls_tracker.value = active_calls_counter.value + # Optional debug log: + # print(f"[PID:{os.getpid()}-TID:{threading.get_ident()}] API Call Start. Active: " + # f"{active_calls_counter.value}, Max Observed: {max_calls_tracker.value}, Input: {stdin}") + + time.sleep(SIMULATED_API_CALL_DURATION_TEST) # Simulate actual work duration + + # exit_time = time.time() # For detailed logging + with call_lock: + active_calls_counter.value -= 1 + # Optional debug log: + # print(f"[PID:{os.getpid()}-TID:{threading.get_ident()}] API Call End. Active: " + # f"{active_calls_counter.value}, Input: {stdin}, Duration: {exit_time - entry_time:.2f}s") + + # Return a simulated successful API response + return { + "status": "Success", + "run_result": {"status": "Finished", "stdout": f"mock_output_for_{stdin}", "return_code": 0}, + }, None + + +# --- Worker function for ProcessPoolExecutor --- +# This function runs in each child process of ProcessPoolExecutor +def _process_pool_worker_for_concurrency_test( + sandbox_url, + in_outs, + generation, + memory_limit_mb, + language, + timeout, + mp_semaphore_for_check_correctness, + active_calls_counter, + max_calls_tracker, + call_lock, +): + # Corrected lambda to accept keyword arguments matching call_sandbox_api's usage + curried_mock_api_call = ( + lambda sandbox_fusion_url, code, stdin, compile_timeout, run_timeout, memory_limit_mb, language: ( + _mock_api_call_for_concurrency_tracking( + active_calls_counter, + max_calls_tracker, + call_lock, + sandbox_fusion_url, + code, + stdin, + compile_timeout, + run_timeout, + memory_limit_mb, + language, + ) + ) + ) + + # ---- START DEBUG PRINTS ---- + import os + + import verl.utils.reward_score.sandbox_fusion.utils + + print( + f"[Worker PID:{os.getpid()}] Original call_sandbox_api: " + f"{verl.utils.reward_score.sandbox_fusion.utils.call_sandbox_api}", + flush=True, + ) + # ---- END DEBUG PRINTS ---- + + with patch( + "verl.utils.reward_score.sandbox_fusion.utils.call_sandbox_api", side_effect=curried_mock_api_call + ) as mock_obj: + # ---- START DEBUG PRINTS ---- + print( + f"[Worker PID:{os.getpid()}] Patched call_sandbox_api: " + f"{verl.utils.reward_score.sandbox_fusion.utils.call_sandbox_api}", + flush=True, + ) + print(f"[Worker PID:{os.getpid()}] Mock object: {mock_obj}", flush=True) + # ---- END DEBUG PRINTS ---- + results, metadata_list = check_correctness( + sandbox_fusion_url=sandbox_url, + in_outs=in_outs, + generation=generation, + timeout=timeout, + memory_limit_mb=memory_limit_mb, + language=language, + concurrent_semaphore=mp_semaphore_for_check_correctness, # Pass multiprocessing.Semaphore + ) + # print(f"Process {os.getpid()} finished check_correctness. Processed {len(results)} tasks.") + return len(results) # Return the number of processed tasks for basic validation + + +# --- The actual test case for multiprocess concurrency control --- +def test_multiprocess_global_concurrency_limit_with_semaphore(): + """ + Tests that the global concurrent_semaphore (multiprocessing.Semaphore) + correctly limits the number of concurrent calls to call_sandbox_api + across multiple processes, each potentially running multiple threads + via check_correctness's internal ThreadPoolExecutor. + """ + manager = multiprocessing.Manager() + active_calls_counter = manager.Value("i", 0) # Current active mock API calls + max_calls_tracker = manager.Value("i", 0) # Observed maximum concurrent mock API calls + call_lock = manager.Lock() # Lock to protect counters + + # Create a multiprocessing.Semaphore instance, this is the global semaphore we are testing. + # It will be passed to check_correctness and used by _process_single_case to limit calls to call_sandbox_api. + global_mp_semaphore = manager.Semaphore(MAX_GLOBAL_CONCURRENCY_LIMIT_TEST) + + mock_sandbox_url = "mock_url_for_concurrency_test" + mock_generation = "pass" # Specific code content is not important as API call is mocked + mock_memory_limit_mb = 1024 + mock_language = "python" + mock_timeout = 5 # Timeout setting, not critical for mock calls + + # Input/output data for each process + # NUM_TASKS_PER_PROCESS_TEST tasks will be handled by check_correctness's internal ThreadPoolExecutor + process_in_outs = { + "inputs": [f"task_input_{i}" for i in range(NUM_TASKS_PER_PROCESS_TEST)], + "outputs": [f"task_output_{i}" for i in range(NUM_TASKS_PER_PROCESS_TEST)], + } + + futures = [] + total_tasks_expected_to_run = NUM_PROCESSES_TEST * NUM_TASKS_PER_PROCESS_TEST + + test_start_time = time.time() + + with ProcessPoolExecutor(max_workers=NUM_PROCESSES_TEST) as executor: + for i in range(NUM_PROCESSES_TEST): + future = executor.submit( + _process_pool_worker_for_concurrency_test, # Worker function + mock_sandbox_url, + process_in_outs, + mock_generation, + mock_memory_limit_mb, + mock_language, + mock_timeout, + global_mp_semaphore, # Global semaphore to test + active_calls_counter, # Shared variables for tracking + max_calls_tracker, + call_lock, + ) + futures.append(future) + + # Wait for all processes to complete and collect results + num_tasks_processed_per_worker = [f.result() for f in futures] + test_end_time = time.time() + total_execution_time = test_end_time - test_start_time + + # Print some test statistics for debugging and validation + print("\n--- Global Concurrency Test Stats ---") + print(f"Semaphore Limit (MAX_GLOBAL_CONCURRENCY_LIMIT_TEST): {MAX_GLOBAL_CONCURRENCY_LIMIT_TEST}") + print(f"Number of Processes (NUM_PROCESSES_TEST): {NUM_PROCESSES_TEST}") + print(f"Tasks per Process (NUM_TASKS_PER_PROCESS_TEST): {NUM_TASKS_PER_PROCESS_TEST}") + print(f"Total Tasks Submitted: {total_tasks_expected_to_run}") + print(f"Simulated API Call Duration: {SIMULATED_API_CALL_DURATION_TEST}s") + print(f"Total Test Execution Time: {total_execution_time:.2f}s") + print(f"Max Concurrent Mock API Calls Observed: {max_calls_tracker.value}") + # print(f"Tasks processed per worker: {num_tasks_processed_per_worker}") + + # Verify that all submitted tasks have been processed + assert sum(num_tasks_processed_per_worker) == total_tasks_expected_to_run, ( + "Mismatch in the number of tasks processed." + ) + + # Verify that the mock API was called at least once + assert max_calls_tracker.value > 0, "The mocked API call_sandbox_api was not called." + + # Core assertion: Observed maximum concurrent calls should not exceed the semaphore's limit + assert max_calls_tracker.value <= MAX_GLOBAL_CONCURRENCY_LIMIT_TEST, ( + f"Observed concurrency ({max_calls_tracker.value}) exceeded semaphore limit " + f"({MAX_GLOBAL_CONCURRENCY_LIMIT_TEST})." + ) + + # Optional: Rough check on execution time to verify semaphore is working to limit concurrency + # Theoretical minimum execution time = (Total tasks / Concurrency limit) * Single task duration + # Actual time will be longer due to various overheads + min_expected_duration = ( + total_tasks_expected_to_run * SIMULATED_API_CALL_DURATION_TEST + ) / MAX_GLOBAL_CONCURRENCY_LIMIT_TEST + # print(f"Minimum Expected Execution Time (approx): {min_expected_duration:.2f}s") + # Allow some margin, e.g., 80% of theoretical minimum time + assert total_execution_time >= min_expected_duration * 0.8, ( + f"Total execution time ({total_execution_time:.2f}s) was unexpectedly short, suggesting the " + f"semaphore might not be effectively limiting concurrency as expected " + f"(min expected: {min_expected_duration * 0.8:.2f}s)." + ) + + +# Ensure there is no more code after this point if these were the last functions. +# If there was other code, it would follow here. +def test_unit_invalid_input_format(): + """Unit test: Invalid in_outs format passed""" + results, metadata_list = check_correctness(SANDBOX_URL, None, CODE_SUCCESS) + assert results == [-1] + assert metadata_list[0]["error"] == "Invalid input/output data" + + results, metadata_list = check_correctness(SANDBOX_URL, {}, CODE_SUCCESS) + assert results == [-1] + assert metadata_list[0]["error"] == "Invalid input/output data" + + results, metadata_list = check_correctness(SANDBOX_URL, INPUT_OUTPUT_INVALID_MISSING_KEY, CODE_SUCCESS) + assert results == [-1] + assert metadata_list[0]["error"] == "Invalid input/output data" + + +@pytest.mark.skipif(skip_condition, reason=skip_reason) +def test_unit_input_output_mismatch(): + """Unit test: Mismatch between the number of inputs and outputs""" + results, metadata_list = check_correctness(SANDBOX_URL, INPUT_OUTPUT_MISMATCH, CODE_SUCCESS) + assert results == [-1] + assert len(metadata_list) == 1 + assert metadata_list[0]["error"] == "Input/output count mismatch" + + +@pytest.mark.skipif(skip_condition, reason=skip_reason) +def test_integration_concurrency_all_timeout(): + """Integration test: High concurrency (100 cases) against real API, all causing timeout""" + concurrency_level = 100 + code_infinite_loop = """ +def knight_moves(X, Y): + MOD = 10**9 + 7 + dp = [[0] * (Y + 1) for _ in range(X + 1)] + dp[0][0] = 1 + for i in range(1, X + 1): + for j in range(1, Y + 1): + dp[i][j] = (dp[i - 1][j] + dp[i][j - 1]) % MOD + return dp[X][Y] + +def solve(): + X, Y = map(int, input().split()) + print(knight_moves(X, Y)) + +if __name__ == "__main__": + solve() + """ + + # Generate 100 simple input/output pairs (content doesn't matter) + timeout_inputs = ["324 384429" for i in range(concurrency_level)] + timeout_outputs = [f"output_{i}\n" for i in range(concurrency_level)] + timeout_in_outs = {"inputs": timeout_inputs, "outputs": timeout_outputs} + + # Set a timeout for the test cases + test_timeout = 10 # Set a timeout value + + start_time = time.time() + results, metadata_list = check_correctness(SANDBOX_URL, timeout_in_outs, code_infinite_loop, timeout=test_timeout) + end_time = time.time() + duration = end_time - start_time + print(f"\nHigh concurrency all timeout test ({concurrency_level} cases) duration: {duration:.2f} seconds") + + # Verify all results are -3 (timeout) + assert len(results) == concurrency_level, f"Expected {concurrency_level} results, got {len(results)}" + all_timed_out = all(r == -3 for r in results) + if not all_timed_out: + non_timeout_indices = [i for i, r in enumerate(results) if r != -3] + print(f"Indices that did not time out: {non_timeout_indices}") + # Print metadata for the first few non-timeout cases for debugging + for i in non_timeout_indices[:5]: + print(f"Metadata for non-timeout case {i}: {metadata_list[i]}") + assert all_timed_out, f"Not all {concurrency_level} concurrent tests resulted in timeout (-3). Results: {results}" + + # Verify metadata count and status of the first case + assert len(metadata_list) == concurrency_level + assert metadata_list[0]["status"] == "timeout" + + +@pytest.mark.skipif(skip_condition, reason=skip_reason) +def test_fn_name_success_single_case(): + """Tests successful execution for a single test case with fn_name. + from livecodebench/code_generation_lite test 510 + """ + generation_code = """ +class Solution: + def occurrencesOfElement(self, nums: List[int], queries: List[int], x: int) -> List[int]: + positions = defaultdict(list) + for idx, num in enumerate(nums): + positions[num].append(idx) + + x_positions = positions[x] + answer = [] + for k in queries: + if k > len(x_positions): + answer.append(-1) + else: + answer.append(x_positions[k-1]) + return answer +""" + in_outs = { + "fn_name": "occurrencesOfElement", + "inputs": ["[1, 3, 1, 7]\n[1, 3, 2, 4]\n1", "[1, 2, 3]\n[10]\n5"], + "outputs": ["[0, -1, 2, -1]", "[-1]"], + } + + # Use a short timeout for fast tests + results, metadata_list = check_correctness(SANDBOX_URL, in_outs, generation_code, timeout=5) + # from verl.utils.reward_score.prime_code import apps_check_correctness + # results, metadata_list = apps_check_correctness(in_outs=in_outs, generation=generation_code, + # timeout=50000, debug=True) + + assert results == [True, True] + assert "error" not in metadata_list[0] + assert metadata_list[0].get("status") != "compile_error" + assert metadata_list[0].get("status") != "runtime_error" + + +@pytest.mark.skipif(skip_condition, reason=skip_reason) +def test_none_and_empty_stdin_passed_correctly(): + """ + Tests that when stdin data is set to an empty string or None, it is still + is passed correctly to Sandbox Fusion as an empty string. + """ + echo_code = """ +import sys +print(f"You said '{sys.stdin.readline().strip()}'") +""" + in_outs = { + "inputs": [None, "", "hello"], + "outputs": ["You said ''", "You said ''", "You said 'hello'"], + } + + # Use a short timeout for fast tests + results, metadata_list = check_correctness(SANDBOX_URL, in_outs, echo_code, timeout=5) + + assert results == [True, True, True] + assert "error" not in metadata_list[0] + assert metadata_list[0].get("status") != "compile_error" + assert metadata_list[0].get("status") != "runtime_error" + + +@pytest.mark.skipif(skip_condition, reason=skip_reason) +def test_assert_case_success(): + """Tests successful execution for assert case. + from KodCode + """ + generation_code = """ +from typing import List, Tuple + +def merge_intervals(intervals: List[Tuple[int, int]]) -> List[Tuple[int, int]]: + if not intervals: + return [] + + # Sort intervals by the start time + intervals.sort(key=lambda x: x[0]) + + merged = [intervals[0]] + + for current in intervals[1:]: + last = merged[-1] + # If intervals overlap, merge them + if current[0] <= last[1]: + merged[-1] = (last[0], max(last[1], current[1])) + else: + merged.append(current) + + return merged +""" + test_cases = { + "fn_name": "merge_intervals", + "assert_case": [ + "assert merge_intervals([(0, 1), (3, 5), (4, 7), (6, 8), (10, 12)," + " (12, 14)]) == [(0, 1), (3, 8), (10, 14)]", + "assert merge_intervals([(1, 2), (2, 3), (3, 4)]) == [(1, 4)]", + "assert merge_intervals([(1, 2), (3, 4), (5, 6)]) == [(1, 2), (3, 4), (5, 5)]", + ], + } + + assert_cases = test_cases.get("assert_case") + test_cases.setdefault("inputs", ["" for _ in assert_cases]) + test_cases.setdefault("outputs", [None for _ in assert_cases]) + + # Use a short timeout for fast tests + results, metadata_list = check_correctness(SANDBOX_URL, test_cases, generation_code, timeout=5) + assert results == [True, True, -2] + for i in range(2): + assert "error" not in metadata_list[i] + assert metadata_list[i].get("status") == "success" + assert metadata_list[i].get("expected_output") is None + assert metadata_list[i].get("status") != "runtime_error" + assert "error" not in metadata_list[2] + assert metadata_list[2].get("status") != "success" + assert metadata_list[2].get("expected_output") is None + assert metadata_list[2].get("status") == "runtime_error" diff --git a/code/RL_model/verl/verl_train/tests/utils/reward_score/test_sandbox_on_cpu.py b/code/RL_model/verl/verl_train/tests/utils/reward_score/test_sandbox_on_cpu.py new file mode 100644 index 0000000000000000000000000000000000000000..ff8508de255bc24d9c9be6e13ede0eecb81e1459 --- /dev/null +++ b/code/RL_model/verl/verl_train/tests/utils/reward_score/test_sandbox_on_cpu.py @@ -0,0 +1,190 @@ +# Copyright 2024 PRIME team and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio +import json +import os + +import pytest + +from verl.utils.reward_score import default_compute_score, sandbox_fusion +from verl.workers.reward_manager.prime import parallel_compute_score_async + +prime_math_answers = [ + """\\begin{bmatrix}\n -7 & 6 & -8 \\\\\n 11 & -9 & 12 \\\\\n 15 & -16 & 19 \n \\end{bmatrix}""", + """\\frac{\\sqrt{505}}{7}""", + """x^2 + y^2 + 4x - 6y + 13""", +] +prime_math_gts = [ + """\\begin{pmatrix}\n -7 & 6 & -8 \\\\\n 11 & -9 & 12 \\\\\n 15 & -16 & 19\n \\end{pmatrix}""", # mat test + """\\frac{\\sqrt{505}}{7}""", # frac test + """(x + 2)^2 + (y - 3)^2 """, # symbolic test +] + +prime_code_answers = [ + """import sys +from collections import deque + +def main(): + data = sys.stdin.read().split() + it = iter(data) + + # Read start and target positions + x0, y0, x1, y1 = int(next(it)), int(next(it)), int(next(it)), int(next(it)) + + n = int(next(it)) + allowed = set() + # The total number of allowed cells is at most 10^5. + for _ in range(n): + r = int(next(it)) + a = int(next(it)) + b = int(next(it)) + for c in range(a, b + 1): + allowed.add((r, c)) + + # Directions for the king (8 neighboring cells) + directions = [(-1, -1), (-1, 0), (-1, 1), + (0, -1), (0, 1), + (1, -1), (1, 0), (1, 1)] + + start = (x0, y0) + target = (x1, y1) + + # BFS initialization + queue = deque() + queue.append((x0, y0, 0)) + # Mark the starting cell as visited by removing it from allowed set. + allowed.discard(start) + + while queue: + x, y, moves = queue.popleft() + if (x, y) == target: + print(moves) + return + for dx, dy in directions: + nx, ny = x + dx, y + dy + if (nx, ny) in allowed: + allowed.remove((nx, ny)) + queue.append((nx, ny, moves + 1)) + + print(-1) + +if __name__ == '__main__': + main() +""" +] * 2 +prime_code_gts = [ + """{\n \"inputs\": [\n \"5 7 6 11\\n3\\n5 3 8\\n6 7 11\\n5 2 5\\n\",\n \"3 4 3 10\\n3\\n3 1 4\\n4 5 9\\n3 10 10\\n\",\n \"1 1 2 10\\n2\\n1 1 3\\n2 6 10\\n\",\n \"9 8 7 8\\n9\\n10 6 6\\n10 6 6\\n7 7 8\\n9 5 6\\n8 9 9\\n9 5 5\\n9 8 8\\n8 5 6\\n9 10 10\\n\",\n \"6 15 7 15\\n9\\n6 15 15\\n7 14 14\\n6 15 15\\n9 14 14\\n7 14 16\\n6 15 15\\n6 15 15\\n7 14 14\\n8 15 15\\n\",\n \"13 16 20 10\\n18\\n13 16 16\\n20 10 10\\n19 10 10\\n12 15 15\\n20 10 10\\n18 11 11\\n19 10 10\\n19 10 10\\n20 10 10\\n19 10 10\\n20 10 10\\n20 10 10\\n19 10 10\\n18 11 11\\n13 16 16\\n12 15 15\\n19 10 10\\n19 10 10\\n\",\n \"89 29 88 30\\n16\\n87 31 31\\n14 95 95\\n98 88 89\\n96 88 88\\n14 97 97\\n13 97 98\\n100 88 88\\n88 32 32\\n99 88 89\\n90 29 29\\n87 31 31\\n15 94 96\\n89 29 29\\n88 32 32\\n97 89 89\\n88 29 30\\n\",\n \"30 14 39 19\\n31\\n35 7 11\\n37 11 12\\n32 13 13\\n37 5 6\\n46 13 13\\n37 14 14\\n31 13 13\\n43 13 19\\n45 15 19\\n46 13 13\\n32 17 17\\n41 14 19\\n30 14 14\\n43 13 17\\n34 16 18\\n44 11 19\\n38 13 13\\n40 12 20\\n37 16 18\\n46 16 18\\n34 10 14\\n36 9 10\\n36 15 19\\n38 15 19\\n42 13 19\\n33 14 15\\n35 15 19\\n33 17 18\\n39 12 20\\n36 5 7\\n45 12 12\\n\",\n \"2 1 1 1\\n2\\n1 1 2\\n2 1 2\\n\",\n \"1 1 1 2\\n5\\n1000000000 1 10000\\n19920401 1188 5566\\n1000000000 1 10000\\n1 1 10000\\n5 100 200\\n\",\n \"1 1 1000000000 2\\n5\\n1000000000 1 10000\\n19920401 1188 5566\\n1000000000 1 10000\\n1 1 10000\\n5 100 200\\n\"\n ],\n \"outputs\": [\n \"4\\n\",\n \"6\\n\",\n \"-1\\n\",\n \"2\\n\",\n \"1\\n\",\n \"-1\\n\",\n \"1\\n\",\n \"9\\n\",\n \"1\\n\",\n \"1\\n\",\n \"-1\\n\"\n ]\n}""", # A correct sample # noqa: E501 + """{\n \"inputs\": [\n \"5 7 6 11\\n3\\n5 3 8\\n6 7 11\\n5 2 5\\n\",\n \"3 4 3 10\\n3\\n3 1 4\\n4 5 9\\n3 10 10\\n\",\n \"1 1 2 10\\n2\\n1 1 3\\n2 6 10\\n\",\n \"9 8 7 8\\n9\\n10 6 6\\n10 6 6\\n7 7 8\\n9 5 6\\n8 9 9\\n9 5 5\\n9 8 8\\n8 5 6\\n9 10 10\\n\",\n \"6 15 7 15\\n9\\n6 15 15\\n7 14 14\\n6 15 15\\n9 14 14\\n7 14 16\\n6 15 15\\n6 15 15\\n7 14 14\\n8 15 15\\n\",\n \"13 16 20 10\\n18\\n13 16 16\\n20 10 10\\n19 10 10\\n12 15 15\\n20 10 10\\n18 11 11\\n19 10 10\\n19 10 10\\n20 10 10\\n19 10 10\\n20 10 10\\n20 10 10\\n19 10 10\\n18 11 11\\n13 16 16\\n12 15 15\\n19 10 10\\n19 10 10\\n\",\n \"89 29 88 30\\n16\\n87 31 31\\n14 95 95\\n98 88 89\\n96 88 88\\n14 97 97\\n13 97 98\\n100 88 88\\n88 32 32\\n99 88 89\\n90 29 29\\n87 31 31\\n15 94 96\\n89 29 29\\n88 32 32\\n97 89 89\\n88 29 30\\n\",\n \"30 14 39 19\\n31\\n35 7 11\\n37 11 12\\n32 13 13\\n37 5 6\\n46 13 13\\n37 14 14\\n31 13 13\\n43 13 19\\n45 15 19\\n46 13 13\\n32 17 17\\n41 14 19\\n30 14 14\\n43 13 17\\n34 16 18\\n44 11 19\\n38 13 13\\n40 12 20\\n37 16 18\\n46 16 18\\n34 10 14\\n36 9 10\\n36 15 19\\n38 15 19\\n42 13 19\\n33 14 15\\n35 15 19\\n33 17 18\\n39 12 20\\n36 5 7\\n45 12 12\\n\",\n \"2 1 1 1\\n2\\n1 1 2\\n2 1 2\\n\",\n \"1 1 1 2\\n5\\n1000000000 1 10000\\n19920401 1188 5566\\n1000000000 1 10000\\n1 1 10000\\n5 100 200\\n\",\n \"1 1 1000000000 2\\n5\\n1000000000 1 10000\\n19920401 1188 5566\\n1000000000 1 10000\\n1 1 10000\\n5 100 200\\n\"\n ],\n \"outputs\": [\n \"4\\n\",\n \"6\\n\",\n \"-1\\n\",\n \"-1\\n\",\n \"1\\n\",\n \"-1\\n\",\n \"1\\n\",\n \"9\\n\",\n \"1\\n\",\n \"1\\n\",\n \"-1\\n\"\n ]\n}""", # noqa: E501 +] # A failed sample with first several in-out passed + +prime_code_scores = [1.0, 0.9] + + +def test_parallelism(): + """ + Test if process pool works properly + """ + sequences_str = [] + ground_truth = [] + data_sources = [] + while len(sequences_str) < 32: + sequences_str.extend(prime_code_answers) + ground_truth.extend(prime_code_gts) + data_sources.extend(["codecontests"] * len(prime_code_answers)) + + sequences_str.extend(prime_math_answers) + ground_truth.extend(prime_math_gts) + data_sources.extend(["numina_aops_forum"] * len(prime_math_answers)) + + scores = asyncio.run( + parallel_compute_score_async(default_compute_score, sequences_str, ground_truth, data_sources, num_processes=16) + ) + print(scores) + + +@pytest.mark.skip("pyext not compatible with python 3.12") +def test_prime_code(): + """ + Test PRIME code sandbox. + """ + data_source = "codecontests" + for completion, ground_truth, score_ in zip(prime_code_answers, prime_code_gts, prime_code_scores, strict=True): + score = default_compute_score(data_source, completion, ground_truth) + assert float(score) == score_ + + +# Use the pytest.mark.skipif decorator to skip the test +@pytest.mark.skipif(not os.environ.get("SANDBOX_FUSION_URL"), reason="SANDBOX_FUSION_URL environment variable not set") +def test_prime_code_sandbox_fusion(): + """ + Test PRIME code on sandbox fusion. Skips if SANDBOX_FUSION_URL is not set. + """ + data_source = "codecontests" + # Get the URL from the environment variable, as skipif ensures it is set at this point + sandbox_fusion_url = os.environ.get("SANDBOX_FUSION_URL") + # Removed the previous 'if not sandbox_url' check block + + for completion, ground_truth, score_ in zip(prime_code_answers, prime_code_gts, prime_code_scores, strict=True): + score = default_compute_score( + data_source, completion, ground_truth, extra_info={"sandbox_fusion_url": sandbox_fusion_url} + ) # <-- Use the URL obtained from the environment variable + assert float(score) == score_ + + +@pytest.mark.skipif(not os.environ.get("SANDBOX_FUSION_URL"), reason="SANDBOX_FUSION_URL environment variable not set") +def test_continuous_score_consistency(): + """ + Verify that continuous score calculation is consistent between prime_code and sandbox_fusion. + Uses a test case where the first 9 out of 11 sub-cases pass (expected score 0.9). + """ + from verl.utils.reward_score import prime_code + + completion = prime_code_answers[1] # Use the second sample + ground_truth = prime_code_gts[1] # Use the second sample (9/11 pass, first 9 pass) + expected_continuous_score = 0.9 + + # 1. Calculate score using prime_code (default) with continuous=True + prime_score, _ = sandbox_fusion.compute_score( + os.environ.get("SANDBOX_FUSION_URL"), None, completion, ground_truth, continuous=True + ) + + # 2. Calculate score using sandbox_fusion with continuous=True + # Ensure the extra_info key triggers the sandbox_fusion path in default_compute_score + fusion_score, _ = prime_code.compute_score(completion, ground_truth, continuous=True) + + # 3. Assert scores are equal (using pytest.approx for float comparison) + assert float(prime_score) == pytest.approx(expected_continuous_score) + assert float(fusion_score) == pytest.approx(expected_continuous_score) + assert float(prime_score) == pytest.approx(float(fusion_score)) + print(f"Continuous Score (Prime Code): {prime_score}") + print(f"Continuous Score (Sandbox Fusion): {fusion_score}") + + +@pytest.mark.skip("pyext not compatible with python 3.12") +def test_check_correctness(): + from verl.utils.reward_score.prime_code import apps_check_correctness + + completion = prime_code_answers[0] + ground_truth = json.loads(prime_code_gts[0]) + ground_truth_single = {"inputs": ground_truth["inputs"][:1], "outputs": ground_truth["outputs"][:1]} + res, meta = apps_check_correctness(in_outs=ground_truth_single, generation=completion, timeout=5, debug=False) + print(res, meta) + + +def test_prime_math(): + data_source = "numina_aops_forum" + for completion, ground_truth in zip(prime_math_answers, prime_math_gts, strict=True): + score = default_compute_score(data_source, completion, ground_truth) + assert float(score) == 1.0 diff --git a/code/RL_model/verl/verl_train/tests/workers/rollout/resource/tool_configs/mcp_server.json b/code/RL_model/verl/verl_train/tests/workers/rollout/resource/tool_configs/mcp_server.json new file mode 100644 index 0000000000000000000000000000000000000000..9ed41f10bc00784d6c4935bf882900aee748723f --- /dev/null +++ b/code/RL_model/verl/verl_train/tests/workers/rollout/resource/tool_configs/mcp_server.json @@ -0,0 +1,8 @@ +{ + "mcpServers": { + "Tavily Expert": { + "url": "https://tavily.api.tadata.com/mcp/tavily/your_expert", + "auth_token": "your_tavily_token" + } + } +} \ No newline at end of file diff --git a/code/RL_model/verl/verl_train/tests/workers/rollout/resource/tool_configs/mcp_tool_config b/code/RL_model/verl/verl_train/tests/workers/rollout/resource/tool_configs/mcp_tool_config new file mode 100644 index 0000000000000000000000000000000000000000..a9a45bd0bc2fdc7b0805f7af2fa56521a1544a47 --- /dev/null +++ b/code/RL_model/verl/verl_train/tests/workers/rollout/resource/tool_configs/mcp_tool_config @@ -0,0 +1,11 @@ +tools: + - class_name: verl.tools.mcp_search_tool.MCPSearchTool + config: + rate_limit: 120 + timeout: 120 + type: mcp + mcp: + mcp_servers_config_path: ./resource/tool_configs/mcp_server.json + # optional + tool_selected_list: + - tavily_search_tool \ No newline at end of file diff --git a/code/RL_model/verl/verl_train/verl/experimental/reward_loop/reward_manager/base.py b/code/RL_model/verl/verl_train/verl/experimental/reward_loop/reward_manager/base.py new file mode 100644 index 0000000000000000000000000000000000000000..1c26e77ad7fd873c43f7861bb0904757d0d9acc0 --- /dev/null +++ b/code/RL_model/verl/verl_train/verl/experimental/reward_loop/reward_manager/base.py @@ -0,0 +1,53 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import os +from abc import ABC, abstractmethod + +from omegaconf import DictConfig +from transformers import AutoTokenizer + +from verl import DataProto +from verl.utils.ray_utils import get_event_loop + +logger = logging.getLogger(__file__) +logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN")) + + +class RewardManagerBase(ABC): + _class_initialized = False + + def __init__(self, config: DictConfig, tokenizer: AutoTokenizer): + """Initialize reward manager. + + Args: + config (DictConfig): YAML config. + tokenizer (AutoTokenizer): Tokenizer for tokenize messages. + """ + self.config = config + self.tokenizer = tokenizer + self.loop = get_event_loop() + self.init_class(config, tokenizer) + + @classmethod + def init_class(cls, config: DictConfig, tokenizer: AutoTokenizer): + """Initialize class state shared across all instances.""" + if cls._class_initialized: + return + cls._class_initialized = True + + @abstractmethod + async def run_single(self, data: DataProto): + raise NotImplementedError diff --git a/code/RL_model/verl/verl_train/verl/experimental/reward_loop/reward_manager/dapo.py b/code/RL_model/verl/verl_train/verl/experimental/reward_loop/reward_manager/dapo.py new file mode 100644 index 0000000000000000000000000000000000000000..ad06494c85b008ae095a648d62d529dd36e8a230 --- /dev/null +++ b/code/RL_model/verl/verl_train/verl/experimental/reward_loop/reward_manager/dapo.py @@ -0,0 +1,114 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect + +from verl import DataProto +from verl.experimental.reward_loop.reward_manager import register +from verl.experimental.reward_loop.reward_manager.base import RewardManagerBase +from verl.utils.reward_score import default_compute_score + + +@register("dapo") +class DAPORewardManager(RewardManagerBase): + """DAPO Reward Manager.""" + + def __init__(self, config, tokenizer, compute_score=None, reward_router_address=None, reward_model_tokenizer=None): + super().__init__(config, tokenizer) + self.compute_score = compute_score or default_compute_score + self.is_async_reward_score = inspect.iscoroutinefunction(self.compute_score) + + # DAPO Reward Config + overlong_buffer_cfg = config.reward_model.get("reward_kwargs", {}).get("overlong_buffer_cfg", None) + self.overlong_buffer_cfg = overlong_buffer_cfg + self.max_resp_len = config.reward_model.get("reward_kwargs", {}).get("max_resp_len", None) + self.reward_router_address = reward_router_address + self.reward_model_tokenizer = reward_model_tokenizer + + if self.overlong_buffer_cfg is not None: + assert self.max_resp_len is not None, ( + f"max_resp_len must be provided if {overlong_buffer_cfg=}, but got None" + ) + assert self.max_resp_len >= self.overlong_buffer_cfg.len, ( + "max_resp_len must be larger than overlong_buffer.len" + ) + + async def run_single(self, data: DataProto) -> dict: + assert len(data) == 1, "Only support single data item" + data_item = data[0] + response_ids = data_item.batch["responses"] + response_length = response_ids.shape[-1] + valid_response_length = data_item.batch["attention_mask"][-response_length:].sum() + valid_response_ids = response_ids[:valid_response_length] + + data_source = data_item.non_tensor_batch["data_source"] + ground_truth = data_item.non_tensor_batch["reward_model"]["ground_truth"] + extra_info = data_item.non_tensor_batch.get("extra_info", {}) + + response_str = await self.loop.run_in_executor( + None, lambda: self.tokenizer.decode(valid_response_ids, skip_special_tokens=True) + ) + extra_reward_kwargs = ( + { + "reward_router_address": self.reward_router_address, + "reward_model_tokenizer": self.reward_model_tokenizer, + } + if self.reward_router_address is not None + else {} + ) + if self.is_async_reward_score: + result = await self.compute_score( + data_source=data_source, + solution_str=response_str, + ground_truth=ground_truth, + extra_info=extra_info, + **extra_reward_kwargs, + ) + else: + result = await self.loop.run_in_executor( + None, + lambda: self.compute_score( + data_source=data_source, + solution_str=response_str, + ground_truth=ground_truth, + extra_info=extra_info, + **extra_reward_kwargs, + ), + ) + + reward_extra_info = {} + + score: float + if isinstance(result, dict): + score = result["score"] + for key, value in result.items(): + reward_extra_info[key] = value + else: + score = result + reward_extra_info["acc"] = score + + reward = score + + if self.overlong_buffer_cfg is not None and self.overlong_buffer_cfg.enable: + overlong_buffer_len = self.overlong_buffer_cfg.len + expected_len = self.max_resp_len - overlong_buffer_len + exceed_len = valid_response_length - expected_len + overlong_penalty_factor = self.overlong_buffer_cfg.penalty_factor + overlong_reward = min(-exceed_len / overlong_buffer_len * overlong_penalty_factor, 0) + reward += overlong_reward + if self.overlong_buffer_cfg.log: + reward_extra_info["overlong_reward"] = overlong_reward + reward_extra_info["overlong"] = overlong_reward < 0 + + return {"reward_score": reward, "reward_extra_info": reward_extra_info} diff --git a/code/RL_model/verl/verl_train/verl/experimental/reward_loop/reward_manager/limited.py b/code/RL_model/verl/verl_train/verl/experimental/reward_loop/reward_manager/limited.py new file mode 100644 index 0000000000000000000000000000000000000000..e4cb047a81016f847bccf7eae1b2825cc64a02b7 --- /dev/null +++ b/code/RL_model/verl/verl_train/verl/experimental/reward_loop/reward_manager/limited.py @@ -0,0 +1,540 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio +import inspect +import logging + +from omegaconf import DictConfig +from transformers import AutoTokenizer + +from verl import DataProto +from verl.experimental.reward_loop.reward_manager import register as register_manager +from verl.experimental.reward_loop.reward_manager.base import RewardManagerBase +from verl.utils.ray_utils import get_event_loop +from verl.utils.reward_score import default_compute_score +from verl.workers.reward_manager import register as register_manager_legacy + +logger = logging.getLogger(__file__) + + +class AsyncTokenBucket: + """Async token bucket for rate limiting with variable token consumption. + + The token bucket algorithm is a classic rate limiting technique that allows + for burst traffic while maintaining an average rate limit. This implementation + is async-first and thread-safe, designed for use in concurrent environments. + + The bucket starts full and refills at a constant rate (rate_limit tokens/second). + When tokens are acquired, they are consumed from the bucket. If insufficient + tokens are available, the acquire() method will sleep until enough tokens + have been refilled. + + This implementation supports variable token consumption, making it suitable + for rate limiting based on request size (e.g., API token usage). + + Args: + rate_limit (float): The rate at which tokens are added to the bucket, + in tokens per second. For example, rate_limit=10.0 means 10 tokens + are added per second (or 600 per minute). + max_tokens (float, optional): The maximum capacity of the token bucket. + Defaults to rate_limit if not specified. This value determines the + maximum burst size allowed. + + Attributes: + rate_limit (float): Tokens added per second. + max_tokens (float): Maximum bucket capacity. + tokens (float): Current number of available tokens. + last_update (float | None): Timestamp of last token update (from event loop). + lock (asyncio.Lock): Async lock for thread-safe token operations. + + Example: + >>> # Limit to 60 requests per minute (1 request per second) + >>> rpm_limiter = AsyncTokenBucket(rate_limit=1.0, max_tokens=1.0) + >>> await rpm_limiter.acquire(1.0) # Consumes 1 token + >>> + >>> # Limit to 10000 tokens per minute (~166.67 tokens per second) + >>> tpm_limiter = AsyncTokenBucket(rate_limit=166.67, max_tokens=166.67) + >>> await tpm_limiter.acquire(100.0) # Consumes 100 tokens + + Thread Safety: + All operations are protected by an asyncio.Lock, making this class safe + for concurrent use across multiple coroutines. + + Algorithm Details: + 1. On each acquire(), calculate elapsed time since last update + 2. Refill tokens: tokens += elapsed * rate_limit (capped at max_tokens) + 3. If tokens >= num_tokens: consume tokens and return + 4. Otherwise: calculate wait_time = tokens_needed / rate_limit, then sleep + 5. Retry after sleep (loop back to step 1) + """ + + def __init__(self, rate_limit: float, max_tokens: float = None): + self.rate_limit = rate_limit + self.max_tokens = max_tokens or rate_limit + self.tokens = self.max_tokens + self.last_update = None + self.lock = asyncio.Lock() + + async def acquire(self, num_tokens: float = 1.0) -> None: + """Acquire tokens from the bucket, waiting if necessary. + + This method will block (using asyncio.sleep) until sufficient tokens + are available. It automatically refills tokens based on elapsed time + and the configured rate_limit. + + For requests exceeding max_tokens, the method will wait for enough time + to accumulate the required tokens at the configured rate_limit, allowing + tokens to temporarily go negative. + + Args: + num_tokens (float): Number of tokens to consume. Defaults to 1.0. + Can be fractional for fine-grained rate limiting. + + Returns: + None: Returns when tokens have been successfully acquired. + + Raises: + No exceptions are raised. This method will wait indefinitely until + tokens become available. + + Example: + >>> bucket = AsyncTokenBucket(rate_limit=10.0) + >>> await bucket.acquire(5.0) # Acquire 5 tokens + >>> await bucket.acquire(1.0) # Acquire 1 more token + + Implementation Notes: + - Uses event loop's time() for high-precision timestamps + - Lock is released during sleep to allow other coroutines to proceed + - Tokens are refilled continuously based on elapsed time + - For requests > max_tokens, allows temporary negative balance + """ + # Handle requests larger than max_tokens separately + if num_tokens > self.max_tokens: + wait_time = 0.0 + async with self.lock: + loop = get_event_loop() + now = loop.time() + if self.last_update is None: + self.last_update = now + + elapsed = now - self.last_update + new_tokens = elapsed * self.rate_limit + self.tokens = min(self.max_tokens, self.tokens + new_tokens) + + tokens_needed = num_tokens - self.tokens + if tokens_needed > 0: + wait_time = tokens_needed / self.rate_limit + + self.tokens -= num_tokens + self.last_update = now + + if wait_time > 0: + await asyncio.sleep(wait_time) + return + + # Standard case: request <= max_tokens + while True: + wait_time = 0.0 + async with self.lock: + loop = get_event_loop() + now = loop.time() + if self.last_update is None: + self.last_update = now + + elapsed = now - self.last_update + new_tokens = elapsed * self.rate_limit + self.tokens = min(self.max_tokens, self.tokens + new_tokens) + self.last_update = now + + if self.tokens >= num_tokens: + self.tokens -= num_tokens + return + + tokens_needed = num_tokens - self.tokens + wait_time = tokens_needed / self.rate_limit + + if wait_time > 0: + await asyncio.sleep(wait_time) + + +@register_manager("rate_limited") +@register_manager_legacy("rate_limited") +class RateLimitedRewardManager(RewardManagerBase): + """Reward manager with rate limiting for API-based reward functions. + + This manager implements a sophisticated three-layer rate limiting system + designed for LLM-as-judge scenarios where reward computation involves + external API calls (e.g., OpenAI, Anthropic, Claude) that have rate limits. + + The three layers of rate limiting are: + 1. **Concurrency limiting** (max_concurrent): Limits the number of + simultaneous API requests using asyncio.Semaphore. This prevents + overwhelming the API with too many parallel connections. + + 2. **Request rate limiting** (max_rpm): Limits requests per minute + using AsyncTokenBucket. Each request consumes 1 token. Useful for + APIs with per-minute request quotas. + + 3. **Token rate limiting** (max_tpm): Limits tokens per minute using + AsyncTokenBucket. Each request consumes estimated_tokens_per_request + tokens. Essential for APIs that bill or limit based on token usage + (e.g., GPT-4 API). + + All rate limiters are **global class-level resources**, meaning they are + shared across all instances of this manager. This ensures that rate limits + are enforced consistently across multiple workers in distributed training. + + Rate Limiting Flow: + When processing a reward request, the manager: + 1. Acquires RPM token (if rpm_limiter enabled) + 2. Acquires TPM tokens (if tpm_limiter enabled) + 3. Acquires concurrency semaphore + 4. Executes reward computation with timeout + 5. Releases concurrency semaphore + 6. Tokens are automatically refilled by the token buckets + + Args: + config (DictConfig): Configuration object containing reward_model settings: + - max_concurrent (int): Max parallel requests. Default: 1 + - max_rpm (int | None): Max requests per minute. Default: None (unlimited) + - max_tpm (int | None): Max tokens per minute. Default: None (unlimited) + - estimated_tokens_per_request (int): Estimated tokens per request for + TPM limiting. Default: 2000 + - timeout (float): Timeout for reward computation in seconds. Default: 300 + tokenizer (AutoTokenizer): HuggingFace tokenizer for decoding responses. + compute_score (callable, optional): Custom reward scoring function. Can be + sync or async. Defaults to default_compute_score. + reward_router_address (str | None): Address for reward router service. + reward_model_tokenizer (AutoTokenizer | None): Optional tokenizer for reward model. + + Class Attributes (Global State): + _semaphore (asyncio.Semaphore): Global concurrency limiter + _max_concurrent (int): Max concurrent requests + _rpm_limiter (AsyncTokenBucket | None): Request rate limiter + _max_rpm (int | None): Max requests per minute + _tpm_limiter (AsyncTokenBucket | None): Token rate limiter + _max_tpm (int | None): Max tokens per minute + _estimated_tokens_per_request (int): Estimated tokens per request + _class_initialized (bool): Whether class has been initialized + + Example Configuration: + >>> config = DictConfig({ + ... "reward_model": { + ... "max_concurrent": 10, # 10 parallel requests + ... "max_rpm": 500, # 500 requests/minute + ... "max_tpm": 100000, # 100k tokens/minute + ... "estimated_tokens_per_request": 2000, + ... "timeout": 60.0, + ... } + ... }) + >>> manager = RateLimitedRewardManager(config, tokenizer) + + Thread Safety: + This class is designed for concurrent use. All rate limiting resources + are protected by asyncio primitives (Lock, Semaphore). + + See Also: + - AsyncTokenBucket: Token bucket implementation for rate limiting + - RewardManagerBase: Base class for reward managers + - verl.utils.reward_score.default_compute_score: Default scoring function + """ + + # Class-level state for global rate limiting + _semaphore = None + _max_concurrent = None + _rpm_limiter = None + _max_rpm = None + _tpm_limiter = None + _max_tpm = None + _estimated_tokens_per_request = None + _class_initialized = False + + @classmethod + def init_class(cls, config: DictConfig, tokenizer: AutoTokenizer): + """Initialize class state shared across all instances.""" + # Check if already initialized before calling parent. + # + # NOTE: This class owns a *global*, class-level set of rate limiters. Once the class has been + # initialized, subsequent instantiations cannot change the shared limiters. This is by design, + # but it can be surprising (and dangerous) when the first initialization happens with default + # values (often "unlimited") and later code tries to apply limits. + if cls._class_initialized: + rm_cfg = config.get("reward_model") or {} + incoming_max_rpm = rm_cfg.get("max_rpm", None) + incoming_max_tpm = rm_cfg.get("max_tpm", None) + + # Warn when a caller is trying to change the global RPM/TPM limits after initialization. + # This commonly happens if the first instance was created without a config (legacy signature), + # which initializes the global limiters to their defaults and locks them in. + if (incoming_max_rpm != cls._max_rpm) or (incoming_max_tpm != cls._max_tpm): + if ( + incoming_max_rpm is not None + or incoming_max_tpm is not None + or cls._max_rpm is not None + or cls._max_tpm is not None + ): + logger.warning( + "RateLimitedRewardManager has already been initialized and its rate limiters are shared " + "globally across instances. The incoming (max_rpm/max_tpm) settings will be ignored. " + "This can lead to unexpected behavior (e.g., exceeding API rate limits) if the first " + "initialization used defaults (often unlimited). " + f"Existing: max_rpm={cls._max_rpm}, max_tpm={cls._max_tpm}. " + f"Incoming: max_rpm={incoming_max_rpm}, max_tpm={incoming_max_tpm}. " + "To apply different limits, ensure the first RateLimitedRewardManager created in this " + "process uses the desired configuration (or restart/reset the process)." + ) + return + + super().init_class(config, tokenizer) + + rm_cfg = config.get("reward_model") or {} + + # Concurrency limiter + cls._max_concurrent = rm_cfg.get("max_concurrent", 1) + cls._semaphore = asyncio.Semaphore(cls._max_concurrent) + + # Request rate limiter (RPM) + cls._max_rpm = rm_cfg.get("max_rpm", None) + if cls._max_rpm is not None: + requests_per_second = cls._max_rpm / 60.0 + cls._rpm_limiter = AsyncTokenBucket(rate_limit=requests_per_second, max_tokens=requests_per_second) + else: + cls._rpm_limiter = None + + # Token rate limiter (TPM) + cls._max_tpm = rm_cfg.get("max_tpm", None) + cls._estimated_tokens_per_request = rm_cfg.get("estimated_tokens_per_request", 2000) + if cls._max_tpm is not None: + tokens_per_second = cls._max_tpm / 60.0 + cls._tpm_limiter = AsyncTokenBucket(rate_limit=tokens_per_second, max_tokens=tokens_per_second) + else: + cls._tpm_limiter = None + + log_msg = "Rate limiting configuration:\n" + log_msg += f" - Concurrency limit: {cls._max_concurrent}\n" + if cls._max_rpm is not None: + log_msg += f" - Request rate limit: {cls._max_rpm} RPM ({cls._max_rpm / 60.0:.2f} RPS)\n" + else: + log_msg += " - Request rate limit: unlimited\n" + if cls._max_tpm is not None: + log_msg += f" - Token rate limit: {cls._max_tpm} TPM ({cls._max_tpm / 60.0:.2f} TPS)\n" + log_msg += f" - Estimated tokens per request: {cls._estimated_tokens_per_request}\n" + else: + log_msg += " - Token rate limit: unlimited\n" + log_msg += "All limiters are shared globally across all workers." + logger.info(log_msg) + + cls._class_initialized = True + + def __init__( + self, + config: DictConfig | None = None, + tokenizer: AutoTokenizer | None = None, + compute_score=None, + reward_router_address=None, + reward_model_tokenizer=None, + # Legacy (AbstractRewardManager) kwargs for compatibility. Not used. + num_examine: int | None = None, + reward_fn_key: str | None = None, + **kwargs, + ): + # When called via the legacy AbstractRewardManager signature, `config` may be absent. + # In that case we fall back to an empty config so training can proceed. + if config is None: + config = DictConfig({"reward_model": {}}) + if tokenizer is None: + raise TypeError("RateLimitedRewardManager requires `tokenizer`.") + + super().__init__(config, tokenizer) + self.compute_score = compute_score or default_compute_score + self.is_async_reward_score = inspect.iscoroutinefunction(self.compute_score) + self.reward_router_address = reward_router_address + self.reward_model_tokenizer = reward_model_tokenizer + self.timeout = config.reward_model.get("timeout", 300.0) + + async def _compute_reward( + self, data_source: str, solution_str: str, ground_truth: str, extra_info: dict + ) -> dict | float: + extra_reward_kwargs = ( + { + "reward_router_address": self.reward_router_address, + "reward_model_tokenizer": self.reward_model_tokenizer, + } + if self.reward_router_address is not None + else {} + ) + if self.is_async_reward_score: + return await self.compute_score( + data_source=data_source, + solution_str=solution_str, + ground_truth=ground_truth, + extra_info=extra_info, + **extra_reward_kwargs, + ) + else: + return await self.loop.run_in_executor( + None, + lambda: self.compute_score( + data_source=data_source, + solution_str=solution_str, + ground_truth=ground_truth, + extra_info=extra_info, + **extra_reward_kwargs, + ), + ) + + async def run_single(self, data: DataProto) -> dict: + assert len(data) == 1, "Only support single data item" + data_item = data[0] + + response_ids = data_item.batch["responses"] + response_length = response_ids.shape[-1] + valid_response_length = data_item.batch["attention_mask"][-response_length:].sum() + valid_response_ids = response_ids[:valid_response_length] + + data_source = data_item.non_tensor_batch["data_source"] + ground_truth = data_item.non_tensor_batch["reward_model"]["ground_truth"] + extra_info = data_item.non_tensor_batch.get("extra_info", {}) + tool_extra_fields = data_item.non_tensor_batch.get("tool_extra_fields", None) + if tool_extra_fields is not None: + extra_info.update(tool_extra_fields.items()) + + response_str = await self.loop.run_in_executor( + None, lambda: self.tokenizer.decode(valid_response_ids, skip_special_tokens=True) + ) + + reward_extra_info = {} + + # Apply rate limiting layers + if self._rpm_limiter is not None: + await self._rpm_limiter.acquire(1.0) + + if self._tpm_limiter is not None: + estimated_tokens = self._estimated_tokens_per_request + await self._tpm_limiter.acquire(estimated_tokens) + + async with self._semaphore: + try: + result = await asyncio.wait_for( + self._compute_reward( + data_source=data_source, + solution_str=response_str, + ground_truth=ground_truth, + extra_info=extra_info, + ), + timeout=self.timeout, + ) + + score: float + if isinstance(result, dict): + score = result["score"] + for key, value in result.items(): + reward_extra_info[key] = value + else: + score = result + reward_extra_info["acc"] = score + + reward = score + + except asyncio.TimeoutError: + logger.warning( + f"Reward computation timed out after {self.timeout}s for data_source={data_source}. " + f"Response preview: {response_str[:100]}..." + ) + reward = 0.0 + reward_extra_info["timeout"] = True + reward_extra_info["acc"] = 0.0 + + except Exception as e: + logger.error( + f"Reward computation failed for data_source={data_source}: {e}. " + f"Response preview: {response_str[:100]}..." + ) + reward = 0.0 + reward_extra_info["error"] = str(e) + reward_extra_info["acc"] = 0.0 + + return {"reward_score": reward, "reward_extra_info": reward_extra_info} + + def __call__(self, data: DataProto, return_dict: bool = False): + """Make the manager callable like traditional reward managers. + + This method provides compatibility with the existing reward manager interface + by wrapping the async run_single method in a synchronous call. + + Args: + data (DataProto): Input data containing prompts and responses. + return_dict (bool): If True, return a dict with reward_tensor and reward_extra_info. + If False, return only the reward_tensor. Defaults to False. + + Returns: + torch.Tensor | dict: If return_dict is False, returns a tensor of shape [batch_size, response_length] + with rewards. If return_dict is True, returns a dict with: + - reward_tensor: The reward tensor + - reward_extra_info: Dict containing extra information about rewards + """ + from collections import defaultdict + + import torch + + # If there are pre-computed rm_scores, return them directly + if "rm_scores" in data.batch.keys(): + if return_dict: + reward_extra_keys = data.meta_info.get("reward_extra_keys", []) + reward_extra_info = {key: data.non_tensor_batch[key] for key in reward_extra_keys} + return {"reward_tensor": data.batch["rm_scores"], "reward_extra_info": reward_extra_info} + else: + return data.batch["rm_scores"] + + # Initialize reward tensor + reward_tensor = torch.zeros_like(data.batch["responses"], dtype=torch.float32) + reward_extra_info = defaultdict(list) + + # Process each data item through the async event loop + async def process_batch(): + tasks = [] + for i in range(len(data)): + data_item = data[i : i + 1] # Get single item as DataProto slice + tasks.append(self.run_single(data_item)) + + results = await asyncio.gather(*tasks) + return results + + # Run the async processing using self.loop property which lazily gets/creates event loop + # This ensures rate limiters and semaphores work correctly by using the same loop + results = self.loop.run_until_complete(process_batch()) + + # Aggregate results into reward tensor and extra info + for i, result in enumerate(results): + data_item = data[i] + response_ids = data_item.batch["responses"] + response_length = response_ids.shape[-1] + valid_response_length = data_item.batch["attention_mask"][-response_length:].sum() + + reward = result["reward_score"] + reward_tensor[i, valid_response_length - 1] = reward + + # Collect extra info + if "reward_extra_info" in result: + for key, value in result["reward_extra_info"].items(): + reward_extra_info[key].append(value) + + if return_dict: + return { + "reward_tensor": reward_tensor, + "reward_extra_info": reward_extra_info, + } + else: + return reward_tensor diff --git a/code/RL_model/verl/verl_train/verl/experimental/reward_loop/reward_manager/naive.py b/code/RL_model/verl/verl_train/verl/experimental/reward_loop/reward_manager/naive.py new file mode 100644 index 0000000000000000000000000000000000000000..a7255603da7e3d97531d02f59e94a9c41475e713 --- /dev/null +++ b/code/RL_model/verl/verl_train/verl/experimental/reward_loop/reward_manager/naive.py @@ -0,0 +1,99 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect + +from verl import DataProto +from verl.experimental.reward_loop.reward_manager import register +from verl.experimental.reward_loop.reward_manager.base import RewardManagerBase +from verl.utils.reward_score import default_compute_score + + +@register("naive") +class NaiveRewardManager(RewardManagerBase): + """The reward manager.""" + + def __init__(self, config, tokenizer, compute_score=None, reward_router_address=None, reward_model_tokenizer=None): + super().__init__(config, tokenizer) + self.compute_score = compute_score or default_compute_score + self.is_async_reward_score = inspect.iscoroutinefunction(self.compute_score) + self.reward_router_address = reward_router_address + self.reward_model_tokenizer = reward_model_tokenizer + + async def run_single(self, data: DataProto) -> dict: + assert len(data) == 1, "Only support single data item" + data_item = data[0] + response_ids = data_item.batch["responses"] + response_length = response_ids.shape[-1] + valid_response_length = data_item.batch["attention_mask"][-response_length:].sum() + valid_response_ids = response_ids[:valid_response_length] + + data_source = data_item.non_tensor_batch["data_source"] + ground_truth = data_item.non_tensor_batch["reward_model"]["ground_truth"] + extra_info = data_item.non_tensor_batch.get("extra_info", {}) + tool_extra_fields = data_item.non_tensor_batch.get("tool_extra_fields", None) + if tool_extra_fields is not None: + extra_info.update(tool_extra_fields.items()) + + num_turns = data_item.non_tensor_batch.get("__num_turns__", None) + rollout_reward_scores = data_item.non_tensor_batch.get("reward_scores", {}) + extra_info["num_turns"] = num_turns + extra_info["rollout_reward_scores"] = rollout_reward_scores + + response_str = await self.loop.run_in_executor( + None, lambda: self.tokenizer.decode(valid_response_ids, skip_special_tokens=True) + ) + + extra_reward_kwargs = ( + { + "reward_router_address": self.reward_router_address, + "reward_model_tokenizer": self.reward_model_tokenizer, + } + if self.reward_router_address is not None + else {} + ) + if self.is_async_reward_score: + result = await self.compute_score( + data_source=data_source, + solution_str=response_str, + ground_truth=ground_truth, + extra_info=extra_info, + **extra_reward_kwargs, + ) + else: + result = await self.loop.run_in_executor( + None, + lambda: self.compute_score( + data_source=data_source, + solution_str=response_str, + ground_truth=ground_truth, + extra_info=extra_info, + **extra_reward_kwargs, + ), + ) + + reward_extra_info = {} + + score: float + if isinstance(result, dict): + score = result["score"] + for key, value in result.items(): + reward_extra_info[key] = value + else: + score = result + reward_extra_info["acc"] = score + + reward = score + + return {"reward_score": reward, "reward_extra_info": reward_extra_info} diff --git a/code/RL_model/verl/verl_train/verl/experimental/reward_loop/reward_manager/registry.py b/code/RL_model/verl/verl_train/verl/experimental/reward_loop/reward_manager/registry.py new file mode 100644 index 0000000000000000000000000000000000000000..c2da59c419f289643b50476774c6c68fc04b1826 --- /dev/null +++ b/code/RL_model/verl/verl_train/verl/experimental/reward_loop/reward_manager/registry.py @@ -0,0 +1,55 @@ +# Copyright 2025 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Callable + +from verl.experimental.reward_loop.reward_manager.base import RewardManagerBase + +__all__ = ["register", "get_reward_manager_cls"] + +REWARD_LOOP_MANAGER_REGISTRY: dict[str, type[RewardManagerBase]] = {} + + +def register(name: str) -> Callable[[type[RewardManagerBase]], type[RewardManagerBase]]: + """Decorator to register a reward manager class with a given name. + + Args: + name: `(str)` + The name of the reward manager. + """ + + def decorator(cls: type[RewardManagerBase]) -> type[RewardManagerBase]: + if name in REWARD_LOOP_MANAGER_REGISTRY and REWARD_LOOP_MANAGER_REGISTRY[name] != cls: + raise ValueError( + f"reward manager {name} has already been registered: {REWARD_LOOP_MANAGER_REGISTRY[name]} vs {cls}" + ) + REWARD_LOOP_MANAGER_REGISTRY[name] = cls + return cls + + return decorator + + +def get_reward_manager_cls(name: str) -> type[RewardManagerBase]: + """Get the reward manager class with a given name. + + Args: + name: `(str)` + The name of the reward manager. + + Returns: + `(type)`: The reward manager class. + """ + if name not in REWARD_LOOP_MANAGER_REGISTRY: + raise ValueError(f"Unknown reward manager: {name}") + return REWARD_LOOP_MANAGER_REGISTRY[name] diff --git a/code/RL_model/verl/verl_train/verl/experimental/reward_loop/reward_manager/remote.py b/code/RL_model/verl/verl_train/verl/experimental/reward_loop/reward_manager/remote.py new file mode 100644 index 0000000000000000000000000000000000000000..be841e78c734fa9b5732c084b1ea2db11e1ea733 --- /dev/null +++ b/code/RL_model/verl/verl_train/verl/experimental/reward_loop/reward_manager/remote.py @@ -0,0 +1,130 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +import itertools + +import ray + +from verl import DataProto +from verl.experimental.reward_loop.reward_manager import register +from verl.experimental.reward_loop.reward_manager.base import RewardManagerBase +from verl.utils.reward_score import default_compute_score + + +@ray.remote(num_cpus=1) +class RewardComputeWorker: + """ + WARNING: This class cannot have async methods. + """ + + def __init__(self, compute_score_fn): + # since the reward function may not be pickleable, we need to init it in the worker + self.compute_score_fn = compute_score_fn + + def compute_score(self, **kwargs) -> dict: + return self.compute_score_fn(**kwargs) + + +@register("remote") +class RemoteRewardManager(RewardManagerBase): + """ + The reward manager. + Some errors exist when using default thread pool to compute reward score, e.g., math-verify. + https://github.com/volcengine/verl/issues/3407 + To avoid the above issues, we use a separate process to compute reward score. + Moreover, process may be more suitable for cpu-intensive requests. + """ + + def __init__(self, config, tokenizer, compute_score=None, reward_router_address=None, reward_model_tokenizer=None): + super().__init__(config, tokenizer) + self.compute_score = compute_score or default_compute_score + self.is_async_reward_score = inspect.iscoroutinefunction(self.compute_score) + assert not self.is_async_reward_score, "Async reward score is not supported in remote reward manager. " + self.reward_router_address = reward_router_address + self.reward_model_tokenizer = reward_model_tokenizer + num_reward_workers = config.reward_model.num_workers + # in the rollout & reward parallel mode + # the sum of final reward workers will be agent_loop_workers * num_reward_workers + self.reward_worker = [ + # register the reward worker in the same node + RewardComputeWorker.options( + scheduling_strategy=ray.util.scheduling_strategies.NodeAffinitySchedulingStrategy( + node_id=ray.get_runtime_context().get_node_id(), + soft=True, + ), + ).remote(self.compute_score) + for _ in range(num_reward_workers) + ] + self.reward_worker_pool = itertools.cycle(self.reward_worker) + + def choose_reward_worker(self): + return next(self.reward_worker_pool) + + async def run_single(self, data: DataProto) -> dict: + assert len(data) == 1, "Only support single data item" + data_item = data[0] + response_ids = data_item.batch["responses"] + response_length = response_ids.shape[-1] + valid_response_length = data_item.batch["attention_mask"][-response_length:].sum() + valid_response_ids = response_ids[:valid_response_length] + + data_source = data_item.non_tensor_batch["data_source"] + ground_truth = data_item.non_tensor_batch["reward_model"]["ground_truth"] + extra_info = data_item.non_tensor_batch.get("extra_info", {}) + tool_extra_fields = data_item.non_tensor_batch.get("tool_extra_fields", None) + if tool_extra_fields is not None: + extra_info.update(tool_extra_fields.items()) + + num_turns = data_item.non_tensor_batch.get("__num_turns__", None) + rollout_reward_scores = data_item.non_tensor_batch.get("reward_scores", {}) + extra_info["num_turns"] = num_turns + extra_info["rollout_reward_scores"] = rollout_reward_scores + + response_str = await self.loop.run_in_executor( + None, lambda: self.tokenizer.decode(valid_response_ids, skip_special_tokens=True) + ) + + extra_reward_kwargs = ( + { + "reward_router_address": self.reward_router_address, + "reward_model_tokenizer": self.reward_model_tokenizer, + } + if self.reward_router_address is not None + else {} + ) + + reward_worker = self.choose_reward_worker() + result = await reward_worker.compute_score.remote( + data_source=data_source, + solution_str=response_str, + ground_truth=ground_truth, + extra_info=extra_info, + **extra_reward_kwargs, + ) + + reward_extra_info = {} + + score: float + if isinstance(result, dict): + score = result["score"] + for key, value in result.items(): + reward_extra_info[key] = value + else: + score = result + reward_extra_info["acc"] = score + + reward = score + + return {"reward_score": reward, "reward_extra_info": reward_extra_info} diff --git a/code/RL_model/verl/verl_train/verl/experimental/reward_loop/router/inner_sglang_router.py b/code/RL_model/verl/verl_train/verl/experimental/reward_loop/router/inner_sglang_router.py new file mode 100644 index 0000000000000000000000000000000000000000..e05b17c89fc9f9e7fa8a4e1d1331e3e8e0f11412 --- /dev/null +++ b/code/RL_model/verl/verl_train/verl/experimental/reward_loop/router/inner_sglang_router.py @@ -0,0 +1,73 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import multiprocessing +import os +import time + +import ray +import requests +from sglang_router.launch_server import RouterArgs, launch_router + +from verl.utils.net_utils import get_free_port, is_valid_ipv6_address + +logger = logging.getLogger(__name__) +logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN")) + + +def launch_router_process( + worker_urls: list[str], + request_timeout: int = 180, + max_wait_time: int = 300, + timeout: int = 30, +) -> str: + router_ip = ray.util.get_node_ip_address().strip("[]") + router_port, _ = get_free_port(router_ip) + router_address = ( + f"[{router_ip}]:{router_port}" if is_valid_ipv6_address(router_ip) else f"{router_ip}:{router_port}" + ) + router_args = RouterArgs( + host=router_ip, + port=router_port, + worker_urls=worker_urls, + balance_abs_threshold=0, + log_level="warn", + request_timeout_secs=request_timeout, + ) + router_process = multiprocessing.Process(target=launch_router, args=(router_args,)) + router_process.daemon = True + router_process.start() + time.sleep(3) + assert router_process.is_alive() + + # health check + start_time = time.time() + url = f"http://{router_address}/health" + with requests.Session() as session: + while time.time() - start_time < max_wait_time: + try: + response = session.get(url, timeout=timeout) + if response.status_code == 200: + break + except requests.RequestException as e: + logger.debug(f"Health check failed: {e}") + + time.sleep(2) + else: + router_process.terminate() + raise RuntimeError(f"Router health check failed after {max_wait_time} seconds.") + + logger.info(f"Router is running on {router_address}") + return router_address, router_process diff --git a/code/RL_model/verl/verl_train/verl/experimental/reward_loop/router/naive_router.py b/code/RL_model/verl/verl_train/verl/experimental/reward_loop/router/naive_router.py new file mode 100644 index 0000000000000000000000000000000000000000..a495c0592e3cf882b521584c1dcf1824a7cee18f --- /dev/null +++ b/code/RL_model/verl/verl_train/verl/experimental/reward_loop/router/naive_router.py @@ -0,0 +1,183 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio +import logging +import multiprocessing +import os +import time +from typing import Any + +import aiohttp +import ray +import uvicorn +from fastapi import FastAPI, Request +from fastapi.responses import JSONResponse + +from verl.utils.net_utils import get_free_port, is_valid_ipv6_address + +logger = logging.getLogger(__name__) +logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN")) + + +async def _read_async_response(resp: aiohttp.ClientResponse) -> dict[str, Any]: + if resp.status == 204 or (resp.content_length == 0): + return {} + + try: + return await resp.json(content_type=None) + except Exception: + try: + text = await resp.text() + except Exception: + return {} + return { + "content_type": (resp.headers.get("Content-Type") or ""), + "text": text, + } + + +def launch_router_process( + worker_urls: list[str], +): + router_ip = ray.util.get_node_ip_address().strip("[]") + router_port, _ = get_free_port(router_ip) + router_address = ( + f"[{router_ip}]:{router_port}" if is_valid_ipv6_address(router_ip) else f"{router_ip}:{router_port}" + ) + + router_process = multiprocessing.Process( + target=run_router, + args=( + router_ip, + router_port, + worker_urls, + ), + ) + router_process.daemon = True + router_process.start() + time.sleep(3) + assert router_process.is_alive() + + logger.info(f"Router is running on {router_address}") + return router_address, router_process + + +def run_router(router_ip: str, router_port: int, worker_urls: list[str]): + router = NaiveRouter(worker_urls=worker_urls, verbose=False) + uvicorn.run(router.app, host=router_ip, port=router_port, log_level="warning") + + +class NaiveRouter: + def __init__( + self, + worker_urls: list[str], + max_connections: int = 1024, + timeout: int = 60, + max_attempts: int = 3, + retry_delay: float = 2.0, + verbose: bool = False, + ) -> None: + """A minimal async load-balancing router.""" + self.verbose = verbose + self.app = FastAPI() + self.worker_urls = worker_urls + self.request_counts = {url: 0 for url in worker_urls} + + self.max_connections = max_connections + self.timeout = timeout + self.max_attempts = max_attempts + self.retry_delay = retry_delay + + self.app = FastAPI() + + # Register startup / shutdown hooks + self.app.on_event("startup")(self._on_startup) + self.app.on_event("shutdown")(self._on_shutdown) + + # Catch-all proxy route + self.app.api_route("/{endpoint:path}", methods=["GET", "POST"])(self._make_async_request) + + # Placeholder for aiohttp client + self.client = None + + async def _on_startup(self): + """Initialize aiohttp client safely inside the event loop""" + connector = aiohttp.TCPConnector( + limit=self.max_connections, + limit_per_host=self.max_connections // 4, + ttl_dns_cache=300, + use_dns_cache=True, + ) + timeout = aiohttp.ClientTimeout(total=None) + self.client = aiohttp.ClientSession(connector=connector, timeout=timeout) + if self.verbose: + logger.info(f"[router] aiohttp client initialized with max_connections={self.max_connections}") + + async def _on_shutdown(self): + """Gracefully close aiohttp client""" + if self.client and not self.client.closed: + await self.client.close() + if self.verbose: + logger.info("[router] aiohttp client closed") + + async def _make_async_request(self, request: Request, endpoint: str): + """Proxy single request to a worker URL.""" + if not self.worker_urls: + return JSONResponse(status_code=503, content={"error": "No available workers"}) + + worker_url = self._select_worker() + target_url = f"{worker_url}/{endpoint}" + + if self.verbose: + logger.debug(f"[router] Forwarding request → {target_url}") + + # Copy request data + body = await request.body() + headers = dict(request.headers) + + for attempt in range(self.max_attempts): + # Send request to worker + try: + async with self.client.request(request.method, target_url, data=body, headers=headers) as response: + response.raise_for_status() + output = await _read_async_response(response) + self._release_worker(worker_url) + return output + except asyncio.TimeoutError: + logger.warning(f"Async request to {endpoint} timed out (attempt {attempt + 1})") + except aiohttp.ClientConnectorError: + logger.warning(f"Connection error for {endpoint} (attempt {attempt + 1})") + except aiohttp.ClientResponseError as e: + logger.error(f"HTTP error for {endpoint}: {e}") + raise + except Exception as e: + logger.error(f"Unexpected error for {endpoint}: {e}") + if attempt == self.max_attempts - 1: + raise + + if attempt < self.max_attempts - 1: + await asyncio.sleep(self.retry_delay * (2**attempt)) + + raise RuntimeError(f"Failed to complete async request to {endpoint} after {self.max_attempts} attempts") + + def _select_worker(self) -> str: + """Select the least-loaded worker (simple round-robin by request count).""" + url = min(self.request_counts, key=self.request_counts.get) + self.request_counts[url] += 1 + return url + + def _release_worker(self, url: str) -> None: + """Mark worker as free after request completes.""" + self.request_counts[url] = max(0, self.request_counts[url] - 1) diff --git a/code/RL_model/verl/verl_train/verl/models/llama/__init__.py b/code/RL_model/verl/verl_train/verl/models/llama/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..1ce90c5eb352d85c59105c0dc85b5f1dd576f095 --- /dev/null +++ b/code/RL_model/verl/verl_train/verl/models/llama/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/code/RL_model/verl/verl_train/verl/models/llama/megatron/__init__.py b/code/RL_model/verl/verl_train/verl/models/llama/megatron/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..fc851ea435ff43ad31eff24dc729df0e78cf8bee --- /dev/null +++ b/code/RL_model/verl/verl_train/verl/models/llama/megatron/__init__.py @@ -0,0 +1,34 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .modeling_llama_megatron import ( + ParallelLlamaForCausalLM, + # rmpad with megatron + ParallelLlamaForCausalLMRmPad, + # rmpad with megatron and pipeline parallelism + ParallelLlamaForCausalLMRmPadPP, + ParallelLlamaForValueRmPad, + ParallelLlamaForValueRmPadPP, + # original model with megatron + ParallelLlamaModel, +) + +__all__ = [ + "ParallelLlamaForCausalLM", + "ParallelLlamaForCausalLMRmPad", + "ParallelLlamaForCausalLMRmPadPP", + "ParallelLlamaForValueRmPad", + "ParallelLlamaForValueRmPadPP", + "ParallelLlamaModel", +] diff --git a/code/RL_model/verl/verl_train/verl/models/llama/megatron/checkpoint_utils/__init__.py b/code/RL_model/verl/verl_train/verl/models/llama/megatron/checkpoint_utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..1ce90c5eb352d85c59105c0dc85b5f1dd576f095 --- /dev/null +++ b/code/RL_model/verl/verl_train/verl/models/llama/megatron/checkpoint_utils/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/code/RL_model/verl/verl_train/verl/models/llama/megatron/checkpoint_utils/llama_loader.py b/code/RL_model/verl/verl_train/verl/models/llama/megatron/checkpoint_utils/llama_loader.py new file mode 100644 index 0000000000000000000000000000000000000000..dafecfdf084e81d2e72df9151fb3c593770127ac --- /dev/null +++ b/code/RL_model/verl/verl_train/verl/models/llama/megatron/checkpoint_utils/llama_loader.py @@ -0,0 +1,317 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import time + +import torch +import torch.distributed as dist + +from verl.utils.device import get_device_id, get_torch_device + + +def _megatron_calc_layer_map(config): + """Calculate the mapping of global layer_idx to local layer_idx + Returns: + layer_map (Dict: int -> tuple(int, int, int)): + mapping from the global layer index to + a tuple of (pp_rank, virtual_pp_rank, layer_idx inside model) + """ + from megatron.core import mpu + + print(f"get megatron data parallel size: {mpu.get_data_parallel_world_size()}") + + pp_size = mpu.get_pipeline_model_parallel_world_size() + virtual_pp_size = mpu.get_virtual_pipeline_model_parallel_world_size() or 1 + + layer_map = dict() + num_layers_per_model = config.num_hidden_layers // pp_size // virtual_pp_size + assert num_layers_per_model * pp_size * virtual_pp_size == config.num_hidden_layers + + for pp_rank_idx in range(pp_size): + for virtual_pp_rank_idx in range(virtual_pp_size): + layer_offset = ( + virtual_pp_rank_idx * (config.num_hidden_layers // virtual_pp_size) + pp_rank_idx * num_layers_per_model + ) + for layer_idx in range(num_layers_per_model): + layer_map[layer_offset + layer_idx] = ( + pp_rank_idx, + virtual_pp_rank_idx, + layer_idx, + ) + return layer_map + + +def load_state_dict_to_megatron_llama( + state_dict, wrapped_models, config, params_dtype, is_value_model=False, tie_word_embeddings=False +): + """Load merged state_dict to sharded Megatron module in training.""" + from megatron.core import DistributedDataParallel as LocalDDP + from megatron.core import mpu + from megatron.core.transformer.module import Float16Module + from torch.nn.parallel import DistributedDataParallel as torchDDP + + from verl.utils.logger import print_rank_0 + from verl.utils.megatron_utils import unwrap_model + + start_time = time.time() + + def _get_gpt_model(model): + return model + + def fetch_params(module): + for param in module.parameters(): + torch.distributed.fetch( + param.data, src=mpu.get_data_parallel_src_rank(), group=mpu.get_data_parallel_group() + ) + + dp_rank = mpu.get_data_parallel_rank() + pp_rank = mpu.get_pipeline_model_parallel_rank() + pp_size = mpu.get_pipeline_model_parallel_world_size() + virtual_pp_size = mpu.get_virtual_pipeline_model_parallel_world_size() or 1 + mp_group = mpu.get_model_parallel_group() + + if torch.distributed.get_rank() == 0: + assert mp_group.rank() == 0, f"mp_rank:[{mp_group.rank}] != 0 on rank #0" + assert pp_rank == 0, f"pp_rank:[{pp_rank}] != 0 on rank #0" + assert dp_rank == 0, f"dp_rank:[{dp_rank}] != 0 on rank #0" + + if not isinstance(wrapped_models, list | tuple): + wrapped_models = list(wrapped_models) + + assert len(wrapped_models) == virtual_pp_size + num_layers_per_model = config.num_hidden_layers // pp_size // virtual_pp_size + assert num_layers_per_model * pp_size * virtual_pp_size == config.num_hidden_layers, ( + f"num_layers_per_model: {num_layers_per_model} * pp_size: {pp_size} * virtual_pp_size " + f"{virtual_pp_size} != config.num_hidden_layers: {config.num_hidden_layers}" + ) + + models = [None] * len(wrapped_models) + + for i, wrapped_model in enumerate(wrapped_models): + models[i] = unwrap_model(wrapped_model, (torchDDP, LocalDDP, Float16Module)) + gpt_model_module = _get_gpt_model(models[i]) + assert len(gpt_model_module.model.layers) == num_layers_per_model + + def _fetch_tensor(tensor, name) -> torch.Tensor: + """fetch tensor""" + nonlocal state_dict + if tensor is not None: + tensor.data.copy_(state_dict[name]) + + def _fetch_tp_shard_tensor_vocab(tensor, name, chunk_dim=0, mutate_func=None) -> torch.Tensor: + """fetch tensor in tp shards""" + nonlocal state_dict + tp_rank = mpu.get_tensor_model_parallel_rank() + tp_size = mpu.get_tensor_model_parallel_world_size() + if name in state_dict: + full_weight = state_dict[name] + + if mutate_func is not None: + full_weight = mutate_func(full_weight) + tensor_chunk = torch.chunk(full_weight, tp_size, dim=chunk_dim) + if tensor is not None: + tensor.data.copy_(tensor_chunk[tp_rank]) + else: + print(f"tp_shard tensor:[{name}] not in state_dict, skip loading") + + def _fetch_tp_shard_tensor(tensor, name, chunk_dim=0, mutate_func=None) -> torch.Tensor: + """fetch tensor in tp shards""" + nonlocal state_dict + tp_rank = mpu.get_tensor_model_parallel_rank() + tp_size = mpu.get_tensor_model_parallel_world_size() + if name in state_dict: + full_weight = state_dict[name] + + if mutate_func is not None: + full_weight = mutate_func(full_weight) + tensor_chunk = torch.chunk(full_weight, tp_size, dim=chunk_dim) + if tensor is not None: + tensor.data.copy_(tensor_chunk[tp_rank]) + else: + print(f"tp_shard tensor:[{name}] not in state_dict, skip loading") + + def _fetch_tp_shard_tensor_gate_up(tensor, gate_name, up_name) -> torch.Tensor: + """fetch gate_up tensor in tp shards""" + nonlocal state_dict + nonlocal mp_group + tp_rank = mpu.get_tensor_model_parallel_rank() + tp_size = mpu.get_tensor_model_parallel_world_size() + if gate_name in state_dict and up_name in state_dict: + gate_weight = state_dict[gate_name] + up_weight = state_dict[up_name] + new_gate_up_weight = torch.empty( + config.intermediate_size * 2, config.hidden_size, dtype=params_dtype, device=get_device_id() + ) + for i in range(tp_size): + intermediate_size_tp = config.intermediate_size // tp_size + gate_weight_tp = gate_weight[i * intermediate_size_tp : (i + 1) * intermediate_size_tp] + up_weight_tp = up_weight[i * intermediate_size_tp : (i + 1) * intermediate_size_tp] + new_gate_up_weight[intermediate_size_tp * 2 * i : intermediate_size_tp * 2 * (i + 1)].copy_( + torch.cat([gate_weight_tp, up_weight_tp], dim=0) + ) + + tensor_chunk = torch.chunk(new_gate_up_weight, tp_size, dim=0) + if tensor is not None: + tensor.data.copy_(tensor_chunk[tp_rank]) + else: + print(f"tp_shard tensor:[{gate_name}, {up_name}] not in state_dict, skip loading") + + def _fetch_tp_shard_tensor_qkv(tensor, q_name, k_name, v_name) -> torch.Tensor: + """fetch tensor in tp shards across mp_group""" + nonlocal state_dict + nonlocal mp_group + tp_rank = mpu.get_tensor_model_parallel_rank() + tp_size = mpu.get_tensor_model_parallel_world_size() + assert q_name in state_dict and k_name in state_dict and v_name in state_dict + full_weight_q = state_dict[q_name] + full_weight_k = state_dict[k_name] + full_weight_v = state_dict[v_name] + + hidden_size_per_head = config.hidden_size // config.num_attention_heads + + if config.num_key_value_heads >= tp_size: + q_size_tp = config.hidden_size // tp_size + kv_size_tp = hidden_size_per_head * config.num_key_value_heads // tp_size + total_size = q_size_tp + 2 * kv_size_tp + new_weight_qkv = torch.empty( + total_size * tp_size, config.hidden_size, dtype=params_dtype, device=get_device_id() + ) + for i in range(tp_size): + q_part = full_weight_q[i * q_size_tp : (i + 1) * q_size_tp] + k_part = full_weight_k[i * kv_size_tp : (i + 1) * kv_size_tp] + v_part = full_weight_v[i * kv_size_tp : (i + 1) * kv_size_tp] + new_weight_qkv[i * total_size : (i + 1) * total_size].copy_(torch.cat([q_part, k_part, v_part], dim=0)) + + else: + q_size_tp = config.hidden_size // tp_size + kv_size_tp = hidden_size_per_head + total_size = q_size_tp + 2 * kv_size_tp + new_weight_qkv = torch.empty( + total_size * tp_size, config.hidden_size, dtype=params_dtype, device=get_device_id() + ) + for i in range(tp_size): + q_part = full_weight_q[i * q_size_tp : (i + 1) * q_size_tp] + start_idx = i * config.num_key_value_heads // tp_size * hidden_size_per_head + end_idx = (i * config.num_key_value_heads // tp_size + 1) * hidden_size_per_head + k_part = full_weight_k[start_idx:end_idx] + v_part = full_weight_v[start_idx:end_idx] + new_weight_qkv[i * total_size : (i + 1) * total_size].copy_(torch.cat([q_part, k_part, v_part], dim=0)) + + tensor_chunk = torch.chunk(new_weight_qkv, tp_size, dim=0) + if tensor is not None: + tensor.data.copy_(tensor_chunk[tp_rank]) + + # Embeddings + # ------------------- + print_rank_0("loading embeddings...") + gpt_model_module = _get_gpt_model(models[0]) + embed_tokens_weight = None + if pp_rank == 0: + embed_tokens_weight = gpt_model_module.model.embed_tokens.weight + _fetch_tp_shard_tensor_vocab(embed_tokens_weight, "model.embed_tokens.weight") + + # Transformer layers + # ------------------- + layer_map = _megatron_calc_layer_map(config) + + pp_rank = mpu.get_pipeline_model_parallel_rank() + pp_size = mpu.get_pipeline_model_parallel_world_size() + num_layer_per_pp = config.num_hidden_layers // pp_size + vpp_size = mpu.get_virtual_pipeline_model_parallel_world_size() + + layer_list = [] + if vpp_size is not None: + for vpp_rank in range(vpp_size): + num_layer_vpp_chunk = num_layer_per_pp // vpp_size + num_layer_this_model = num_layer_vpp_chunk + offset = vpp_rank * (config.num_hidden_layers // mpu.get_virtual_pipeline_model_parallel_world_size()) + ( + mpu.get_pipeline_model_parallel_rank() * num_layer_vpp_chunk + ) + layer_list.extend(list(range(offset, offset + num_layer_this_model))) + else: + num_layer_this_model = num_layer_per_pp + offset = pp_rank * num_layer_per_pp + layer_list.extend(list(range(offset, offset + num_layer_this_model))) + + for layer in layer_list: + print_rank_0(f"loading layer #{layer}...") + layer_name = f"model.layers.{layer}" + dst_pp_rank, dst_virtual_pp_rank, dst_layer_idx = layer_map[layer] + + gpt_model_module = _get_gpt_model(models[dst_virtual_pp_rank]) + sync_layer = gpt_model_module.model.layers[dst_layer_idx] + + _fetch_tensor( + sync_layer.input_layernorm.weight if dst_pp_rank == pp_rank else None, + f"{layer_name}.input_layernorm.weight", + ) + + _fetch_tp_shard_tensor_qkv( + sync_layer.self_attn.qkv_proj.weight if dst_pp_rank == pp_rank else None, + f"{layer_name}.self_attn.q_proj.weight", + f"{layer_name}.self_attn.k_proj.weight", + f"{layer_name}.self_attn.v_proj.weight", + ) + + _fetch_tp_shard_tensor( + sync_layer.self_attn.o_proj.weight if dst_pp_rank == pp_rank else None, + f"{layer_name}.self_attn.o_proj.weight", + chunk_dim=1, + ) + + _fetch_tensor( + sync_layer.post_attention_layernorm.weight if dst_pp_rank == pp_rank else None, + f"{layer_name}.post_attention_layernorm.weight", + ) + + _fetch_tp_shard_tensor_gate_up( + sync_layer.mlp.gate_up_proj.weight if dst_pp_rank == pp_rank else None, + f"{layer_name}.mlp.gate_proj.weight", + f"{layer_name}.mlp.up_proj.weight", + ) + + _fetch_tp_shard_tensor( + sync_layer.mlp.down_proj.weight if dst_pp_rank == pp_rank else None, + f"{layer_name}.mlp.down_proj.weight", + chunk_dim=1, + ) + # Final Layernorm + # ------------------- + print_rank_0("loading final layernorm...") + gpt_model_module = _get_gpt_model(models[-1]) + _fetch_tensor( + getattr(gpt_model_module.model.norm, "weight", None), + "model.norm.weight", + ) + + print_rank_0("loading lm_head...") + if pp_rank + 1 == pp_size: + lm_head_weight = gpt_model_module.lm_head.weight + + if is_value_model: + if "lm_head.weight" in state_dict and state_dict["lm_head.weight"].shape[0] == 1: + _fetch_tensor(lm_head_weight, "lm_head.weight") + print_rank_0("load lm_head weight") + elif "reward_head.weight" in state_dict and state_dict["reward_head.weight"].shape[0] == 1: + _fetch_tensor(lm_head_weight, "reward_head.weight") + print_rank_0("load lm_head from value_head weight") + else: + _fetch_tensor(None, "lm_head.weight") + print_rank_0("fail to match lm_head in value_model") + else: + _fetch_tp_shard_tensor(lm_head_weight, "lm_head.weight") + + dist.barrier() + get_torch_device().empty_cache() + print_rank_0(f"loading megatron ckpt done, time elapsed {time.time() - start_time}s") diff --git a/code/RL_model/verl/verl_train/verl/models/llama/megatron/checkpoint_utils/llama_loader_depracated.py b/code/RL_model/verl/verl_train/verl/models/llama/megatron/checkpoint_utils/llama_loader_depracated.py new file mode 100644 index 0000000000000000000000000000000000000000..2f65bc6b1701bdb79cf1ed282de0212bd6396fdc --- /dev/null +++ b/code/RL_model/verl/verl_train/verl/models/llama/megatron/checkpoint_utils/llama_loader_depracated.py @@ -0,0 +1,458 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import time + +import torch +import torch.distributed as dist + +from verl.utils.device import get_device_id, get_torch_device + + +def _megatron_calc_layer_map(config): + """Calculate the mapping of global layer_idx to local layer_idx + Returns: + layer_map (Dict: int -> tuple(int, int, int)): + mapping from the global layer index to + a tuple of (pp_rank, virtual_pp_rank, layer_idx inside model) + """ + from megatron.core import mpu + + print(f"get megatron data parallel size: {mpu.get_data_parallel_world_size()}") + + pp_size = mpu.get_pipeline_model_parallel_world_size() + virtual_pp_size = mpu.get_virtual_pipeline_model_parallel_world_size() or 1 + + layer_map = dict() + num_layers_per_model = config.num_hidden_layers // pp_size // virtual_pp_size + assert num_layers_per_model * pp_size * virtual_pp_size == config.num_hidden_layers + + for pp_rank_idx in range(pp_size): + for virtual_pp_rank_idx in range(virtual_pp_size): + layer_offset = ( + virtual_pp_rank_idx * (config.num_hidden_layers // virtual_pp_size) + pp_rank_idx * num_layers_per_model + ) + for layer_idx in range(num_layers_per_model): + layer_map[layer_offset + layer_idx] = ( + pp_rank_idx, + virtual_pp_rank_idx, + layer_idx, + ) + return layer_map + + +def load_state_dict_to_megatron_llama( + state_dict, wrapped_models, config, params_dtype, is_value_model=False, tie_word_embeddings=False +): + """Load merged state_dict to sharded Megatron module in training.""" + from megatron.core import DistributedDataParallel as LocalDDP + from megatron.core import mpu + from megatron.core.transformer.module import Float16Module + from torch.nn.parallel import DistributedDataParallel as torchDDP + + from verl.utils.logger import print_rank_0 + from verl.utils.megatron_utils import unwrap_model + + start_time = time.time() + + def _get_gpt_model(model): + return model + + def broadcast_params(module): + for param in module.parameters(): + torch.distributed.broadcast( + param.data, src=mpu.get_data_parallel_src_rank(), group=mpu.get_data_parallel_group() + ) + + dp_rank = mpu.get_data_parallel_rank() + pp_rank = mpu.get_pipeline_model_parallel_rank() + pp_size = mpu.get_pipeline_model_parallel_world_size() + virtual_pp_size = mpu.get_virtual_pipeline_model_parallel_world_size() or 1 + mp_group = mpu.get_model_parallel_group() + + if torch.distributed.get_rank() == 0: + assert mp_group.rank() == 0, f"mp_rank:[{mp_group.rank}] != 0 on rank #0" + assert pp_rank == 0, f"pp_rank:[{pp_rank}] != 0 on rank #0" + assert dp_rank == 0, f"dp_rank:[{dp_rank}] != 0 on rank #0" + + if not isinstance(wrapped_models, list | tuple): + wrapped_models = list(wrapped_models) + + assert len(wrapped_models) == virtual_pp_size + num_layers_per_model = config.num_hidden_layers // pp_size // virtual_pp_size + assert num_layers_per_model * pp_size * virtual_pp_size == config.num_hidden_layers, ( + f"num_layers_per_model: {num_layers_per_model} * pp_size: {pp_size} * virtual_pp_size " + f"{virtual_pp_size} != config.num_hidden_layers: {config.num_hidden_layers}" + ) + + models = [None] * len(wrapped_models) + + for i, wrapped_model in enumerate(wrapped_models): + models[i] = unwrap_model(wrapped_model, (torchDDP, LocalDDP, Float16Module)) + gpt_model_module = _get_gpt_model(models[i]) + assert len(gpt_model_module.model.layers) == num_layers_per_model + + def _broadcast_tensor(tensor, name) -> torch.Tensor: + """broadcast tensor from rank0 across mp_group""" + nonlocal state_dict + nonlocal mp_group + if torch.distributed.get_rank() == 0: + if name in state_dict: + weight = state_dict[name] + tensor_shape = weight.shape + else: + tensor_shape = None + else: + weight = None + tensor_shape = None + + obj_list = [tensor_shape] + dist.broadcast_object_list(obj_list, src=0, group=mp_group) + tensor_shape = obj_list[0] + + if tensor_shape is None: + # all or none ranks in the mp_group should reach here + print_rank_0(f"tensor:[{name}] not in state_dict, skip load") + return + + if tensor is None: + tensor = torch.empty( + tensor_shape, + dtype=params_dtype, + device=get_device_id(), + requires_grad=False, + ) + if torch.distributed.get_rank() == 0: + tensor.data.copy_(weight) + dist.broadcast(tensor, src=0, group=mp_group) + + def _broadcast_tp_shard_tensor_vocab(tensor, name, chunk_dim=0, mutate_func=None) -> torch.Tensor: + """broadcast tensor in tp shards across mp_group""" + nonlocal state_dict + nonlocal mp_group + tp_rank = mpu.get_tensor_model_parallel_rank() + tp_size = mpu.get_tensor_model_parallel_world_size() + + if torch.distributed.get_rank() == 0: + if name in state_dict: + full_weight = state_dict[name] + + if mutate_func is not None: + full_weight = mutate_func(full_weight) + tensor_chunk = torch.chunk(full_weight, tp_size, dim=chunk_dim) + chunk_shape = tensor_chunk[0].shape + else: + chunk_shape = None + else: + chunk_shape = None + + obj_list = [chunk_shape] + dist.broadcast_object_list(obj_list, src=0, group=mp_group) + chunk_shape = obj_list[0] + if chunk_shape is None: + # all or none ranks in the mp_group should reach here + print_rank_0(f"tp_shard tensor:[{name}] not in state_dict, skip loading") + return + + if tensor is None: + sync_tensor = torch.empty( + chunk_shape, + dtype=params_dtype, + device=get_device_id(), + requires_grad=False, + ) + else: + assert tensor.shape == chunk_shape, ( + f"rank #{torch.distributed.get_rank()} tensor {name} shape {tensor.shape} != {chunk_shape}" + ) + sync_tensor = torch.empty_like(tensor, device=get_device_id(), requires_grad=False) + + for i in range(tp_size): + if torch.distributed.get_rank() == 0: + sync_tensor.data.copy_(tensor_chunk[i]) + dist.broadcast(sync_tensor, src=0, group=mp_group) + if (i == tp_rank) and (tensor is not None): + tensor.data.copy_(sync_tensor) + + def _broadcast_tp_shard_tensor(tensor, name, chunk_dim=0, mutate_func=None) -> torch.Tensor: + """broadcast tensor in tp shards across mp_group""" + nonlocal state_dict + nonlocal mp_group + tp_rank = mpu.get_tensor_model_parallel_rank() + tp_size = mpu.get_tensor_model_parallel_world_size() + + if torch.distributed.get_rank() == 0: + if name in state_dict: + full_weight = state_dict[name] + if mutate_func is not None: + full_weight = mutate_func(full_weight) + tensor_chunk = torch.chunk(full_weight, tp_size, dim=chunk_dim) + chunk_shape = tensor_chunk[0].shape + else: + chunk_shape = None + else: + chunk_shape = None + + obj_list = [chunk_shape] + dist.broadcast_object_list(obj_list, src=0, group=mp_group) + chunk_shape = obj_list[0] + if chunk_shape is None: + # all or none ranks in the mp_group should reach here + print_rank_0(f"tp_shard tensor:[{name}] not in state_dict, skip loading") + return + + if tensor is None: + sync_tensor = torch.empty( + chunk_shape, + dtype=params_dtype, + device=get_device_id(), + requires_grad=False, + ) + else: + assert tensor.shape == chunk_shape, ( + f"rank #{torch.distributed.get_rank()} tensor {name} shape {tensor.shape} != {chunk_shape}" + ) + sync_tensor = torch.empty_like(tensor, device=get_device_id(), requires_grad=False) + + for i in range(tp_size): + if torch.distributed.get_rank() == 0: + sync_tensor.data.copy_(tensor_chunk[i]) + dist.broadcast(sync_tensor, src=0, group=mp_group) + if (i == tp_rank) and (tensor is not None): + tensor.data.copy_(sync_tensor) + + def _broadcast_tp_shard_tensor_gate_up(tensor, gate_name, up_name) -> torch.Tensor: + """broadcast tensor in tp shards across mp_group""" + nonlocal state_dict + nonlocal mp_group + tp_rank = mpu.get_tensor_model_parallel_rank() + tp_size = mpu.get_tensor_model_parallel_world_size() + + if torch.distributed.get_rank() == 0: + gate_weight = state_dict[gate_name] + up_weight = state_dict[up_name] + new_gate_up_weight = torch.empty( + config.intermediate_size * 2, config.hidden_size, dtype=params_dtype, device=get_device_id() + ) + for i in range(tp_size): + intermediate_size_tp = config.intermediate_size // tp_size + gate_weight_tp = gate_weight[i * intermediate_size_tp : (i + 1) * intermediate_size_tp] + up_weight_tp = up_weight[i * intermediate_size_tp : (i + 1) * intermediate_size_tp] + new_gate_up_weight[intermediate_size_tp * 2 * i : intermediate_size_tp * 2 * (i + 1)].copy_( + torch.cat([gate_weight_tp, up_weight_tp], dim=0) + ) + + tensor_chunk = torch.chunk(new_gate_up_weight, tp_size, dim=0) + chunk_shape = tensor_chunk[0].shape + else: + chunk_shape = None + + obj_list = [chunk_shape] + dist.broadcast_object_list(obj_list, src=0, group=mp_group) + chunk_shape = obj_list[0] + if chunk_shape is None: + # all or none ranks in the mp_group should reach here + print_rank_0(f"tp_shard tensor:[{gate_name, up_name}] not in state_dict, skip loading") + return + + if tensor is None: + sync_tensor = torch.empty( + chunk_shape, + dtype=params_dtype, + device=get_device_id(), + requires_grad=False, + ) + else: + assert tensor.shape == chunk_shape, ( + f"rank #{torch.distributed.get_rank() == 0:} tensor {gate_name, up_name} shape " + f"{tensor.shape} != {chunk_shape}" + ) + sync_tensor = torch.empty_like(tensor, device=get_device_id(), requires_grad=False) + + for i in range(tp_size): + if torch.distributed.get_rank() == 0: + sync_tensor.data.copy_(tensor_chunk[i]) + dist.broadcast(sync_tensor, src=0, group=mp_group) + if (i == tp_rank) and (tensor is not None): + tensor.data.copy_(sync_tensor) + + def _broadcast_tp_shard_tensor_qkv(tensor, q_name, k_name, v_name) -> torch.Tensor: + """broadcast tensor in tp shards across mp_group""" + nonlocal state_dict + nonlocal mp_group + tp_rank = mpu.get_tensor_model_parallel_rank() + tp_size = mpu.get_tensor_model_parallel_world_size() + + if torch.distributed.get_rank() == 0: + assert q_name in state_dict and k_name in state_dict and v_name in state_dict + full_weight_q = state_dict[q_name] + full_weight_k = state_dict[k_name] + full_weight_v = state_dict[v_name] + + hidden_size_per_head = config.hidden_size // config.num_attention_heads + + if config.num_key_value_heads >= tp_size: + q_size_tp = config.hidden_size // tp_size + kv_size_tp = hidden_size_per_head * config.num_key_value_heads // tp_size + total_size = q_size_tp + 2 * kv_size_tp + new_weight_qkv = torch.empty( + total_size * tp_size, config.hidden_size, dtype=params_dtype, device=get_device_id() + ) + for i in range(tp_size): + q_part = full_weight_q[i * q_size_tp : (i + 1) * q_size_tp] + k_part = full_weight_k[i * kv_size_tp : (i + 1) * kv_size_tp] + v_part = full_weight_v[i * kv_size_tp : (i + 1) * kv_size_tp] + new_weight_qkv[i * total_size : (i + 1) * total_size].copy_( + torch.cat([q_part, k_part, v_part], dim=0) + ) + + else: + q_size_tp = config.hidden_size // tp_size + kv_size_tp = hidden_size_per_head + total_size = q_size_tp + 2 * kv_size_tp + new_weight_qkv = torch.empty( + total_size * tp_size, config.hidden_size, dtype=params_dtype, device=get_device_id() + ) + for i in range(tp_size): + q_part = full_weight_q[i * q_size_tp : (i + 1) * q_size_tp] + start_idx = i * config.num_key_value_heads // tp_size * hidden_size_per_head + end_idx = (i * config.num_key_value_heads // tp_size + 1) * hidden_size_per_head + k_part = full_weight_k[start_idx:end_idx] + v_part = full_weight_v[start_idx:end_idx] + new_weight_qkv[i * total_size : (i + 1) * total_size].copy_( + torch.cat([q_part, k_part, v_part], dim=0) + ) + + tensor_chunk = torch.chunk(new_weight_qkv, tp_size, dim=0) + chunk_shape = tensor_chunk[0].shape + else: + chunk_shape = None + + obj_list = [chunk_shape] + dist.broadcast_object_list(obj_list, src=0, group=mp_group) + chunk_shape = obj_list[0] + if chunk_shape is None: + # all or none ranks in the mp_group should reach here + print_rank_0(f"tp_shard tensor:[{q_name, k_name, v_name}] not in state_dict, skip loading") + return + + if tensor is None: + sync_tensor = torch.empty( + chunk_shape, + dtype=params_dtype, + device=get_device_id(), + requires_grad=False, + ) + else: + assert tensor.shape == chunk_shape, ( + f"rank #{torch.distributed.get_rank()} tensor {q_name} shape {tensor.shape} != {chunk_shape}" + ) + sync_tensor = torch.empty_like(tensor, device=get_device_id(), requires_grad=False) + + for i in range(tp_size): + if torch.distributed.get_rank() == 0: + sync_tensor.data.copy_(tensor_chunk[i]) + dist.broadcast(sync_tensor, src=0, group=mp_group) + if (i == tp_rank) and (tensor is not None): + tensor.data.copy_(sync_tensor) + + if dp_rank == 0: + # Embeddings + # ------------------- + print_rank_0("loading embeddings...") + gpt_model_module = _get_gpt_model(models[0]) + embed_tokens_weight = None + if pp_rank == 0: + embed_tokens_weight = gpt_model_module.model.embed_tokens.weight + _broadcast_tp_shard_tensor_vocab(embed_tokens_weight, "model.embed_tokens.weight") + + # Transformer layers + # ------------------- + layer_map = _megatron_calc_layer_map(config) + + for layer in range(config.num_hidden_layers): + print_rank_0(f"loading layer #{layer}...") + layer_name = f"model.layers.{layer}" + dst_pp_rank, dst_virtual_pp_rank, dst_layer_idx = layer_map[layer] + + gpt_model_module = _get_gpt_model(models[dst_virtual_pp_rank]) + sync_layer = gpt_model_module.model.layers[dst_layer_idx] + + _broadcast_tensor( + sync_layer.input_layernorm.weight if dst_pp_rank == pp_rank else None, + f"{layer_name}.input_layernorm.weight", + ) + + _broadcast_tp_shard_tensor_qkv( + sync_layer.self_attn.qkv_proj.weight if dst_pp_rank == pp_rank else None, + f"{layer_name}.self_attn.q_proj.weight", + f"{layer_name}.self_attn.k_proj.weight", + f"{layer_name}.self_attn.v_proj.weight", + ) + + _broadcast_tp_shard_tensor( + sync_layer.self_attn.o_proj.weight if dst_pp_rank == pp_rank else None, + f"{layer_name}.self_attn.o_proj.weight", + chunk_dim=1, + ) + + _broadcast_tensor( + sync_layer.post_attention_layernorm.weight if dst_pp_rank == pp_rank else None, + f"{layer_name}.post_attention_layernorm.weight", + ) + + _broadcast_tp_shard_tensor_gate_up( + sync_layer.mlp.gate_up_proj.weight if dst_pp_rank == pp_rank else None, + f"{layer_name}.mlp.gate_proj.weight", + f"{layer_name}.mlp.up_proj.weight", + ) + + _broadcast_tp_shard_tensor( + sync_layer.mlp.down_proj.weight if dst_pp_rank == pp_rank else None, + f"{layer_name}.mlp.down_proj.weight", + chunk_dim=1, + ) + # Final Layernorm + # ------------------- + print_rank_0("loading final layernorm...") + gpt_model_module = _get_gpt_model(models[-1]) + _broadcast_tensor( + getattr(gpt_model_module.model.norm, "weight", None), + "model.norm.weight", + ) + + print_rank_0("loading lm_head...") + lm_head_weight = None + if pp_rank + 1 == pp_size: + lm_head_weight = gpt_model_module.lm_head.weight + + if is_value_model: + if "lm_head.weight" in state_dict and state_dict["lm_head.weight"].shape[0] == 1: + _broadcast_tensor(lm_head_weight, "lm_head.weight") + print_rank_0("load lm_head weight") + elif "reward_head.weight" in state_dict and state_dict["reward_head.weight"].shape[0] == 1: + _broadcast_tensor(lm_head_weight, "reward_head.weight") + print_rank_0("load lm_head from value_head weight") + else: + _broadcast_tensor(None, "lm_head.weight") + print_rank_0("fail to match lm_head in value_model") + else: + _broadcast_tp_shard_tensor(lm_head_weight, "lm_head.weight") + dist.barrier() + # Broadcast weights inside data parallel groups + for wrapped_model in wrapped_models: + broadcast_params(wrapped_model) + + get_torch_device().empty_cache() + print_rank_0(f"loading megatron ckpt done, time elapsed {time.time() - start_time}s") diff --git a/code/RL_model/verl/verl_train/verl/models/llama/megatron/checkpoint_utils/llama_saver.py b/code/RL_model/verl/verl_train/verl/models/llama/megatron/checkpoint_utils/llama_saver.py new file mode 100644 index 0000000000000000000000000000000000000000..595efcde376ea498ee65bc39310060a046b83d1b --- /dev/null +++ b/code/RL_model/verl/verl_train/verl/models/llama/megatron/checkpoint_utils/llama_saver.py @@ -0,0 +1,442 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import time + +import torch +import torch.distributed as dist +from megatron.core import mpu +from megatron.core.distributed import DistributedDataParallel as LocalDDP +from megatron.core.transformer.module import Float16Module +from torch.nn.parallel import DistributedDataParallel as torchDDP + +from verl.utils.device import get_device_id, get_torch_device +from verl.utils.logger import print_rank_0 +from verl.utils.megatron_utils import unwrap_model + + +def _megatron_calc_global_rank(tp_rank: int = 0, dp_rank: int = 0, pp_rank: int = 0): + """given TP,DP,PP rank to get the global rank.""" + + tp_size = mpu.get_tensor_model_parallel_world_size() + dp_size = mpu.get_data_parallel_world_size() + pp_size = mpu.get_pipeline_model_parallel_world_size() + assert tp_size * dp_size * pp_size == torch.distributed.get_world_size(), ( + f"{tp_size} x {dp_size} x {pp_size} != {torch.distributed.get_world_size()}" + ) + # We only support TP-DP-PP grouping, for correctness when resharding + return (pp_rank * dp_size + dp_rank) * tp_size + tp_rank + + +def _megatron_calc_layer_map(config): + """Calculate the mapping of global layer_idx to local layer_idx + Returns: + layer_map (Dict: int -> tuple(int, int, int)): + mapping from the global layer index to + a tuple of (pp_rank, virtual_pp_rank, layer_idx inside model) + """ + from megatron.core import mpu + + pp_size = mpu.get_pipeline_model_parallel_world_size() + virtual_pp_size = mpu.get_virtual_pipeline_model_parallel_world_size() or 1 + + layer_map = dict() + num_layers_per_model = config.num_hidden_layers // pp_size // virtual_pp_size + assert num_layers_per_model * pp_size * virtual_pp_size == config.num_hidden_layers + + for pp_rank_idx in range(pp_size): + for virtual_pp_rank_idx in range(virtual_pp_size): + layer_offset = ( + virtual_pp_rank_idx * (config.num_hidden_layers // virtual_pp_size) + pp_rank_idx * num_layers_per_model + ) + for layer_idx in range(num_layers_per_model): + layer_map[layer_offset + layer_idx] = ( + pp_rank_idx, + virtual_pp_rank_idx, + layer_idx, + ) + return layer_map + + +def merge_megatron_ckpt_llama(wrapped_models, config, dtype, is_value_model=False, tie_word_embeddings=False): + """Merge sharded parameters of a Megatron module into a merged checkpoint. + + Args: + wrapped_models (list of megatron.core.distributed.DistributedDataParallel): + The local DDP wrapped megatron modules. + config (str or None): + HF config for model + dtype: model params type + is_value_model: if model is value model + tie_word_embeddings: tie_word_embeddings, not used in llama, only to keep same interface with qwen2 + Returns: + state_dict (dict): + The merged state_dict in rank 0, and an empty dictionary in other ranks. + """ + start_time = time.time() + + def _get_gpt_model(model): + return model + + dp_rank = mpu.get_data_parallel_rank() + pp_size = mpu.get_pipeline_model_parallel_world_size() + pp_rank = mpu.get_pipeline_model_parallel_rank() + virtual_pp_size = mpu.get_virtual_pipeline_model_parallel_world_size() or 1 + mp_group = mpu.get_model_parallel_group() + + if dist.get_rank() == 0: + assert mp_group.rank() == 0, f"mp_rank:[{mp_group.rank}] != 0 on rank #0" + assert pp_rank == 0, f"pp_rank:[{pp_rank}] != 0 on rank #0" + assert dp_rank == 0, f"dp_rank:[{dp_rank}] != 0 on rank #0" + + if not isinstance(wrapped_models, list | tuple): + wrapped_models = list(wrapped_models) + + assert len(wrapped_models) == virtual_pp_size + num_layers_per_model = config.num_hidden_layers // pp_size // virtual_pp_size + assert num_layers_per_model * pp_size * virtual_pp_size == config.num_hidden_layers + + models = [None] * len(wrapped_models) + + for i, wrapped_model in enumerate(wrapped_models): + models[i] = unwrap_model(wrapped_model, (torchDDP, LocalDDP, Float16Module)) + assert len(models[i].model.layers) == num_layers_per_model, ( + "len model layers {} not equal to num_layers_per_model {}".format( + len(models[i].model.layers), num_layers_per_model + ) + ) + + state_dict = dict() + + def _get_cpu_tensor(tensor: torch.Tensor): + if tensor is None: + return None + if tensor.device == torch.device("cpu"): + return tensor.detach().clone() + return tensor.detach().cpu() + + def _broadcast_tensor(tensor, name, src_pp_rank) -> torch.Tensor: + """broadcast tensor across mp_group""" + nonlocal state_dict + nonlocal mp_group + src_rank = _megatron_calc_global_rank(tp_rank=0, dp_rank=0, pp_rank=src_pp_rank) + + if torch.distributed.get_rank() == src_rank: + if tensor is None: + weight = None + tensor_shape = None + else: + weight = tensor + tensor_shape = weight.shape + else: + weight = None + tensor_shape = None + + obj_list = [tensor_shape] + dist.broadcast_object_list(obj_list, src=src_rank, group=mp_group) + tensor_shape = obj_list[0] + + if tensor_shape is None: + # all or none ranks in the mp_group should reach here + print_rank_0(f"tensor:[{name}] not exist, skip collect") + return + + if weight is None: + weight = torch.empty( + tensor_shape, + dtype=dtype, + device=get_device_id(), + requires_grad=False, + ) + + dist.broadcast(weight, src=src_rank, group=mp_group) + + if torch.distributed.get_rank() == 0: + state_dict[name] = _get_cpu_tensor(weight) + + def _broadcast_tp_shard_tensor(tensor, name, src_pp_rank, concat_dim=0, mutate_func=None) -> torch.Tensor: + """broadcast tensor in tp shards across mp_group""" + nonlocal state_dict + nonlocal mp_group + tp_size = mpu.get_tensor_model_parallel_world_size() + src_rank = _megatron_calc_global_rank(tp_rank=0, dp_rank=0, pp_rank=src_pp_rank) + + chunk_shape = tensor.shape if torch.distributed.get_rank() == src_rank else None + + obj_list = [chunk_shape] + dist.broadcast_object_list(obj_list, src=src_rank, group=mp_group) + chunk_shape = obj_list[0] + if chunk_shape is None: + # all or none ranks in the mp_group should reach here + print_rank_0(f"tp_shard tensor:[{name}] not exist, skip collecting") + return + + buffer_tensor = torch.empty( + chunk_shape, + dtype=dtype, + device=get_device_id(), + requires_grad=False, + ) + + chunk_tensors = [None] * tp_size + + for i in range(tp_size): + cur_src_rank = _megatron_calc_global_rank(tp_rank=i, dp_rank=0, pp_rank=src_pp_rank) + sync_tensor = tensor if torch.distributed.get_rank() == cur_src_rank else buffer_tensor + dist.broadcast(sync_tensor, src=cur_src_rank, group=mp_group) + + if torch.distributed.get_rank() == 0: + chunk_tensors[i] = _get_cpu_tensor(sync_tensor) + + if torch.distributed.get_rank() == 0: + full_tensor = torch.concat(chunk_tensors, dim=concat_dim) + if mutate_func is not None: + full_tensor = mutate_func(full_tensor) + state_dict[name] = full_tensor + + def _broadcast_tp_shard_tensor_gate_up(tensor, gate_name, up_name, src_pp_rank) -> torch.Tensor: + """broadcast tensor in tp shards across mp_group""" + nonlocal state_dict + nonlocal mp_group + tp_size = mpu.get_tensor_model_parallel_world_size() + src_rank = _megatron_calc_global_rank(tp_rank=0, dp_rank=0, pp_rank=src_pp_rank) + + chunk_shape = tensor.shape if torch.distributed.get_rank() == src_rank else None + + obj_list = [chunk_shape] + dist.broadcast_object_list(obj_list, src=src_rank, group=mp_group) + chunk_shape = obj_list[0] + if chunk_shape is None: + # all or none ranks in the mp_group should reach here + print_rank_0(f"tp_shard tensor:[{gate_name, up_name}] not exist, skip collecting") + return + + buffer_tensor = torch.empty( + chunk_shape, + dtype=dtype, + device=get_device_id(), + requires_grad=False, + ) + + chunk_tensors = [None] * tp_size + + for i in range(tp_size): + cur_src_rank = _megatron_calc_global_rank(tp_rank=i, dp_rank=0, pp_rank=src_pp_rank) + sync_tensor = tensor if torch.distributed.get_rank() == cur_src_rank else buffer_tensor + dist.broadcast(sync_tensor, src=cur_src_rank, group=mp_group) + + if torch.distributed.get_rank() == 0: + chunk_tensors[i] = _get_cpu_tensor(sync_tensor) + + if torch.distributed.get_rank() == 0: + full_tensor = torch.concat(chunk_tensors, dim=0) + intermediate_size_tp = config.intermediate_size // tp_size + gate_weight_list = [] + up_weight_list = [] + for i in range(tp_size): + gate_up_weight_tp = full_tensor[intermediate_size_tp * 2 * i : intermediate_size_tp * 2 * (i + 1)] + gate_weight_tp = gate_up_weight_tp[:intermediate_size_tp] + up_weight_tp = gate_up_weight_tp[intermediate_size_tp:] + gate_weight_list.append(gate_weight_tp) + up_weight_list.append(up_weight_tp) + + state_dict[gate_name] = torch.cat(gate_weight_list, dim=0) + state_dict[up_name] = torch.cat(up_weight_list, dim=0) + + def _broadcast_tp_shard_tensor_qkv(tensor, q_name, k_name, v_name, src_pp_rank): + """broadcast tensor in tp shards across mp_group""" + nonlocal state_dict + nonlocal mp_group + tp_size = mpu.get_tensor_model_parallel_world_size() + src_rank = _megatron_calc_global_rank(tp_rank=0, dp_rank=0, pp_rank=src_pp_rank) + + chunk_shape = tensor.shape if torch.distributed.get_rank() == src_rank else None + + obj_list = [chunk_shape] + dist.broadcast_object_list(obj_list, src=src_rank, group=mp_group) + chunk_shape = obj_list[0] + if chunk_shape is None: + # all or none ranks in the mp_group should reach here + print_rank_0(f"tp_shard tensor:[{q_name}] not exist, skip collecting") + return + + buffer_tensor = torch.empty( + chunk_shape, + dtype=dtype, + device=get_device_id(), + requires_grad=False, + ) + + chunk_tensors = [None] * tp_size + + for i in range(tp_size): + cur_src_rank = _megatron_calc_global_rank(tp_rank=i, dp_rank=0, pp_rank=src_pp_rank) + sync_tensor = tensor if torch.distributed.get_rank() == cur_src_rank else buffer_tensor + dist.broadcast(sync_tensor, src=cur_src_rank, group=mp_group) + + if torch.distributed.get_rank() == 0: + chunk_tensors[i] = _get_cpu_tensor(sync_tensor) + + if torch.distributed.get_rank() == 0: + full_tensor = torch.concat(chunk_tensors, dim=0) + q_weight_list = [] + k_weight_list = [] + v_weight_list = [] + hidden_size_per_head = config.hidden_size // config.num_attention_heads + + if config.num_key_value_heads >= tp_size: + q_size_tp = config.hidden_size // tp_size + kv_size_tp = hidden_size_per_head * config.num_key_value_heads // tp_size + total_size = q_size_tp + 2 * kv_size_tp + for i in range(tp_size): + qkv_part = full_tensor[i * total_size : (i + 1) * total_size] + q_part = qkv_part[:q_size_tp] + k_part = qkv_part[q_size_tp : q_size_tp + kv_size_tp] + v_part = qkv_part[q_size_tp + kv_size_tp : total_size] + q_weight_list.append(q_part) + k_weight_list.append(k_part) + v_weight_list.append(v_part) + else: + q_size_tp = config.hidden_size // tp_size + kv_size_tp = hidden_size_per_head + total_size = q_size_tp + 2 * kv_size_tp + for i in range(tp_size): + qkv_part = full_tensor[i * total_size : (i + 1) * total_size] + q_part = qkv_part[:q_size_tp] + k_part = qkv_part[q_size_tp : q_size_tp + kv_size_tp] + v_part = qkv_part[q_size_tp + kv_size_tp : total_size] + q_weight_list.append(q_part) + if i * config.num_key_value_heads % tp_size == 0: + k_weight_list.append(k_part) + v_weight_list.append(v_part) + + state_dict[q_name] = torch.cat(q_weight_list, dim=0) + state_dict[k_name] = torch.cat(k_weight_list, dim=0) + state_dict[v_name] = torch.cat(v_weight_list, dim=0) + + # empty cache before collecting weights + get_torch_device().empty_cache() + # Embeddings + # ------------------- + if dp_rank == 0: + # Embeddings + # ------------------- + print_rank_0("collecting embeddings...") + gpt_model_module = _get_gpt_model(models[0]) + _broadcast_tp_shard_tensor( + gpt_model_module.model.embed_tokens.weight if pp_rank == 0 else None, + "model.embed_tokens.weight", + src_pp_rank=0, + ) + + # Transformer layers + # ------------------- + layer_map = _megatron_calc_layer_map(config) + for layer in range(config.num_hidden_layers): + print_rank_0(f"collecting layer #{layer}...") + layer_name = f"model.layers.{layer}" + src_pp_rank, src_virtual_pp_rank, src_layer_idx = layer_map[layer] + + gpt_model_module = _get_gpt_model(models[src_virtual_pp_rank]) + sync_layer = gpt_model_module.model.layers[src_layer_idx] + + _broadcast_tensor( + sync_layer.input_layernorm.weight, + f"{layer_name}.input_layernorm.weight", + src_pp_rank=src_pp_rank, + ) + + _broadcast_tp_shard_tensor_qkv( + sync_layer.self_attn.qkv_proj.weight, + f"{layer_name}.self_attn.q_proj.weight", + f"{layer_name}.self_attn.k_proj.weight", + f"{layer_name}.self_attn.v_proj.weight", + src_pp_rank=src_pp_rank, + ) + + _broadcast_tp_shard_tensor( + sync_layer.self_attn.o_proj.weight, + f"{layer_name}.self_attn.o_proj.weight", + concat_dim=1, + src_pp_rank=src_pp_rank, + ) + + _broadcast_tensor( + sync_layer.post_attention_layernorm.weight, + f"{layer_name}.post_attention_layernorm.weight", + src_pp_rank=src_pp_rank, + ) + + _broadcast_tp_shard_tensor_gate_up( + sync_layer.mlp.gate_up_proj.weight, + f"{layer_name}.mlp.gate_proj.weight", + f"{layer_name}.mlp.up_proj.weight", + src_pp_rank=src_pp_rank, + ) + + _broadcast_tp_shard_tensor( + sync_layer.mlp.down_proj.weight, + f"{layer_name}.mlp.down_proj.weight", + concat_dim=1, + src_pp_rank=src_pp_rank, + ) + + # Final Layernorm + # ------------------- + print_rank_0("collecting final layernorm...") + gpt_model_module = _get_gpt_model(models[-1]) + _broadcast_tensor( + getattr(gpt_model_module.model.norm, "weight", None), + "model.norm.weight", + src_pp_rank=pp_size - 1, + ) + + print_rank_0("collecting lm_head...") + + if is_value_model: + if pp_rank == pp_size - 1: + print(f"gpt_model_module.lm_head.weight: {gpt_model_module.lm_head.weight.shape}") + _broadcast_tensor( + gpt_model_module.lm_head.weight if pp_rank == pp_size - 1 else None, + "lm_head.weight", + src_pp_rank=pp_size - 1, + ) + _broadcast_tensor( + gpt_model_module.reward_head.weight + if pp_rank == pp_size - 1 and getattr(gpt_model_module, "reward_weight", None) is not None + else None, + "reward_head.weight", + src_pp_rank=pp_size - 1, + ) + + else: + _broadcast_tp_shard_tensor( + getattr(gpt_model_module.lm_head, "weight", None) if pp_rank == pp_size - 1 else None, + "lm_head.weight", + src_pp_rank=pp_size - 1, + ) + + dist.barrier() + + get_torch_device().empty_cache() + if torch.distributed.get_rank() == 0: + if dtype not in [torch.float16, torch.bfloat16, torch.float32]: + print(f'Unknown/unsupported dtype to save: {dtype}"') + exit(1) + for k, v in state_dict.items(): + if dtype != v.dtype: + state_dict[k] = v.to(dtype) + + print_rank_0(f"merge megatron ckpt done, time elapsed {time.time() - start_time}s") + return state_dict diff --git a/code/RL_model/verl/verl_train/verl/models/llama/megatron/layers/__init__.py b/code/RL_model/verl/verl_train/verl/models/llama/megatron/layers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..352bc56086dcf1e7e2a6534f0e6e506796a1fb6d --- /dev/null +++ b/code/RL_model/verl/verl_train/verl/models/llama/megatron/layers/__init__.py @@ -0,0 +1,34 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .parallel_attention import ParallelLlamaAttention +from .parallel_decoder import ParallelLlamaDecoderLayer, ParallelLlamaDecoderLayerRmPad +from .parallel_linear import ( + LinearForLastLayer, + MergedColumnParallelLinear, + QKVParallelLinear, +) +from .parallel_mlp import ParallelLlamaMLP +from .parallel_rmsnorm import ParallelLlamaRMSNorm + +__all__ = [ + "LinearForLastLayer", + "MergedColumnParallelLinear", + "QKVParallelLinear", + "ParallelLlamaAttention", + "ParallelLlamaDecoderLayer", + "ParallelLlamaDecoderLayerRmPad", + "ParallelLlamaMLP", + "ParallelLlamaRMSNorm", +] diff --git a/code/RL_model/verl/verl_train/verl/models/llama/megatron/layers/parallel_attention.py b/code/RL_model/verl/verl_train/verl/models/llama/megatron/layers/parallel_attention.py new file mode 100644 index 0000000000000000000000000000000000000000..4f76b991abda8038db299d09cc230b6051479d47 --- /dev/null +++ b/code/RL_model/verl/verl_train/verl/models/llama/megatron/layers/parallel_attention.py @@ -0,0 +1,460 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from typing import Optional + +import torch +import torch.nn.functional as F +from einops import rearrange +from flash_attn.layers.rotary import apply_rotary_emb +from megatron.core import ModelParallelConfig, tensor_parallel +from megatron.core import parallel_state as mpu +from torch import nn +from transformers import LlamaConfig +from transformers.utils import is_flash_attn_2_available + +from verl.models.llama.megatron.layers.parallel_linear import QKVParallelLinear +from verl.utils.megatron import tensor_parallel as tp_utils + + +class LlamaRotaryEmbedding(nn.Module): + def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): + super().__init__() + + self.dim = dim + self.max_position_embeddings = max_position_embeddings + self.base = base + inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + + # Build here to make `torch.jit.trace` work. + self._set_cos_sin_cache( + seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype() + ) + + def _set_cos_sin_cache(self, seq_len, device, dtype): + self.max_seq_len_cached = seq_len + t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) + + freqs = torch.einsum("i,j->ij", t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) + self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) + + def forward(self, x, seq_len=None): + # x: [bs, num_attention_heads, seq_len, head_size] + if seq_len > self.max_seq_len_cached: + self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) + + return ( + self.cos_cached[:seq_len].to(dtype=x.dtype), + self.sin_cached[:seq_len].to(dtype=x.dtype), + ) + + +class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding): + """LlamaRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev""" + + def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0): + self.scaling_factor = scaling_factor + super().__init__(dim, max_position_embeddings, base, device) + + def _set_cos_sin_cache(self, seq_len, device, dtype): + self.max_seq_len_cached = seq_len + t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) + t = t / self.scaling_factor + + freqs = torch.einsum("i,j->ij", t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) + self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) + + +class LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding): + """LlamaRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla""" + + def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0): + self.scaling_factor = scaling_factor + super().__init__(dim, max_position_embeddings, base, device) + + def _set_cos_sin_cache(self, seq_len, device, dtype): + self.max_seq_len_cached = seq_len + + if seq_len > self.max_position_embeddings: + base = self.base * ( + (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1) + ) ** (self.dim / (self.dim - 2)) + inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + + t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) + + freqs = torch.einsum("i,j->ij", t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) + self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) + + +class LlamaLlama3ScalingRotaryEmbedding(LlamaRotaryEmbedding): + def __init__(self, dim, config, max_position_embeddings=2048, base=10000, device=None): + super().__init__(dim, max_position_embeddings, base, device) + + self.factor = config.rope_scaling["factor"] # `8` in the original implementation + self.high_freq_factor = config.rope_scaling["high_freq_factor"] # `1` in the original implementation + self.low_freq_factor = config.rope_scaling["low_freq_factor"] # `4` in the original implementation + self.old_context_len = config.rope_scaling[ + "original_max_position_embeddings" + ] # `8192` in the original implementation + + low_freq_wavelen = self.old_context_len / self.low_freq_factor + high_freq_wavelen = self.old_context_len / self.high_freq_factor + + wavelen = 2 * math.pi / self.inv_freq + # wavelen < high_freq_wavelen: do nothing; wavelen > low_freq_wavelen: divide by factor + inv_freq_llama = torch.where(wavelen > low_freq_wavelen, self.inv_freq / self.factor, self.inv_freq) + # otherwise: interpolate between the two, using a smooth factor + smooth_factor = (self.old_context_len / wavelen - self.low_freq_factor) / ( + self.high_freq_factor - self.low_freq_factor + ) + smoothed_inv_freq = (1 - smooth_factor) * inv_freq_llama / self.factor + smooth_factor * inv_freq_llama + is_medium_freq = ~(wavelen < high_freq_wavelen) * ~(wavelen > low_freq_wavelen) + inv_freq = torch.where(is_medium_freq, smoothed_inv_freq, inv_freq_llama) + + self.register_buffer("inv_freq", inv_freq, persistent=False) + + # Build here to make `torch.jit.trace` work. + self._set_cos_sin_cache( + seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype() + ) + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids): + cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] + sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +class ParallelLlamaAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: LlamaConfig, megatron_config: ModelParallelConfig): + super().__init__() + self.config = config + self.megatron_config = megatron_config + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.hidden_size // self.num_heads + self.num_key_value_heads = config.num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.max_position_embeddings = config.max_position_embeddings + self.rope_theta = config.rope_theta + + # assign values after tp + tp_size = mpu.get_tensor_model_parallel_world_size() + assert self.num_heads % tp_size == 0, ( + f"num_head must be divisible by tp_size. Got num_head={self.num_heads}, tp_size={tp_size}" + ) + assert self.num_key_value_heads % tp_size == 0, ( + f"num_key_value_heads must be divisible by tp_size. Got num_key_value_heads=" + f"{self.num_key_value_heads}, tp_size={tp_size}" + ) + + self.num_heads_per_tp = self.num_heads // tp_size + self.num_key_value_heads_per_tp = self.num_key_value_heads // tp_size + self.hidden_size_per_tp = self.hidden_size // tp_size + + if (self.head_dim * self.num_heads) != self.hidden_size: + raise ValueError( + f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size} and " + f"`num_heads`: {self.num_heads})." + ) + + column_kwargs = tp_utils.get_default_kwargs_for_column_parallel_linear() + row_kwargs = tp_utils.get_default_kwargs_for_row_parallel_linear() + + if megatron_config is not None: + assert column_kwargs.get("config", False), "must have ModelParallelConfig" + assert row_kwargs.get("config", False), "must have ModelParallelConfig" + tp_utils.update_kwargs_with_config(column_kwargs, megatron_config) + tp_utils.update_kwargs_with_config(row_kwargs, megatron_config) + + # [self.q_size, self.k_size, self.v_size] + self.qkv_proj = QKVParallelLinear( + input_size=self.hidden_size, + num_heads=self.num_heads, + num_key_value_heads=self.num_key_value_heads, + head_dim=self.head_dim, + bias=config.attention_bias, + gather_output=False, + skip_bias_add=False, + **column_kwargs, + ) + + self.q_size = self.num_heads_per_tp * self.head_dim + self.k_size = self.num_key_value_heads_per_tp * self.head_dim + self.v_size = self.num_key_value_heads_per_tp * self.head_dim + + self.o_proj = tensor_parallel.RowParallelLinear( + input_size=self.num_heads * self.head_dim, + output_size=self.hidden_size, + bias=config.attention_bias, + input_is_parallel=True, + skip_bias_add=False, + **row_kwargs, + ) + + self._init_rope() + + def _init_rope(self): + if self.config.rope_scaling is None: + self.rotary_emb = LlamaRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + base=self.rope_theta, + ) + else: + rope_type_key = "type" if "type" in self.config.rope_scaling else "rope_type" + scaling_type = self.config.rope_scaling[rope_type_key] + scaling_factor = self.config.rope_scaling["factor"] + if scaling_type == "linear": + self.rotary_emb = LlamaLinearScalingRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + scaling_factor=scaling_factor, + base=self.rope_theta, + ) + elif scaling_type == "dynamic": + self.rotary_emb = LlamaDynamicNTKScalingRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + scaling_factor=scaling_factor, + base=self.rope_theta, + ) + elif scaling_type == "llama3": + self.rotary_emb = LlamaLlama3ScalingRotaryEmbedding( + self.head_dim, + self.config, + max_position_embeddings=self.max_position_embeddings, + base=self.rope_theta, + ) + else: + raise ValueError(f"Unknown RoPE scaling type {scaling_type}") + + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + qkv = self.qkv_proj(hidden_states)[0] + query_states, key_states, value_states = qkv.split([self.q_size, self.k_size, self.v_size], dim=-1) + + query_states = query_states.view(bsz, q_len, self.num_heads_per_tp, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads_per_tp, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads_per_tp, self.head_dim).transpose(1, 2) + + kv_seq_len = key_states.shape[-2] + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + + if attn_weights.size() != (bsz, self.num_heads_per_tp, q_len, kv_seq_len): + raise ValueError( + f"Attention weights should be of size {(bsz, self.num_heads_per_tp, q_len, kv_seq_len)}, " + f"but is {attn_weights.size()}" + ) + + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" + ) + attn_weights = attn_weights + attention_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (bsz, self.num_heads_per_tp, q_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads_per_tp, q_len, self.head_dim)}, " + f"but is {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size_per_tp) + attn_output = self.o_proj(attn_output)[0] + return attn_output + + +""" +Remove padding Attention +- Using Flash-attn 2 +- Compatible with sequence parallel +""" + + +if is_flash_attn_2_available(): + from flash_attn import flash_attn_varlen_func + from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa: F401 + + +def apply_rotary_pos_emb_rmpad(q, k, cos, sin, position_ids, indices, sequence_length): + batch_size = position_ids.shape[0] + + q = pad_input(q, indices, batch_size, sequence_length) # (batch_size, seqlen, num_head, head_dim) + k = pad_input(k, indices, batch_size, sequence_length) + cos = cos[position_ids].unsqueeze(2) # [bs, seq_len, 1, dim] + sin = sin[position_ids].unsqueeze(2) # [bs, seq_len, 1, dim] + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + + q_embed = index_first_axis(rearrange(q_embed, "b s ... -> (b s) ..."), indices) + k_embed = index_first_axis(rearrange(k_embed, "b s ... -> (b s) ..."), indices) + + return q_embed, k_embed + + +# use flash-attn rotary embeddings with rmpad +# cos/sin shoudl be: (seq_length, rotary_dim / 2) +def apply_rotary_pos_emb_rmpad_flash(q, k, cos, sin, cu_seqlens, max_seqlen): + q_embed = apply_rotary_emb( + q, cos, sin, interleaved=False, inplace=False, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen + ) + k_embed = apply_rotary_emb( + k, cos, sin, interleaved=False, inplace=False, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen + ) + return q_embed, k_embed + + +class ParallelLlamaAttentionRmPad(ParallelLlamaAttention): + def forward( + self, + hidden_states: torch.Tensor, + position_ids: Optional[torch.LongTensor] = None, + sequence_length: int = None, + indices: torch.Tensor = None, + cu_seqlens: torch.Tensor = None, + max_seqlen_in_batch: int = None, + ): + total_nnz, _, _ = hidden_states.size() # This is the total_nnz padded after sequence parallel + + if self.megatron_config.sequence_parallel: + total_nnz = total_nnz * mpu.get_tensor_model_parallel_world_size() + + qkv = self.qkv_proj(hidden_states)[0] + query_states, key_states, value_states = qkv.split( + [self.q_size, self.k_size, self.v_size], dim=-1 + ) # (total_nnz, 1, hidden_size) + + if self.megatron_config.sequence_parallel: + sequence_parallel_pad = total_nnz - cu_seqlens[-1] + total_nnz = cu_seqlens[-1] # total_nnz before sp padding + query_states = query_states[:total_nnz] + key_states = key_states[:total_nnz] + value_states = value_states[:total_nnz] + + # Flash attention requires the input to have the shape + # batch_size x seq_length x head_dime x hidden_dim + # therefore we just need to keep the original shape + query_states = query_states.view(total_nnz, self.num_heads_per_tp, self.head_dim) + key_states = key_states.view(total_nnz, self.num_key_value_heads_per_tp, self.head_dim) + value_states = value_states.view(total_nnz, self.num_key_value_heads_per_tp, self.head_dim) + + cos, sin = self.rotary_emb(value_states, seq_len=sequence_length) + cos, sin = cos[:, : cos.shape[1] // 2], sin[:, : sin.shape[1] // 2] # flash attn only needs half + query_states, key_states = apply_rotary_pos_emb_rmpad_flash( + query_states, key_states, cos, sin, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen_in_batch + ) + # query_states, key_states = apply_rotary_pos_emb_rmpad(query_states, key_states, cos, sin, + # position_ids, indices, + + # TODO: llama does not have dropout in the config?? + # It is recommended to use dropout with FA according to the docs + # when training. + dropout_rate = 0.0 # if not self.training else self.attn_dropout + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in float16 just to be sure everything works as expected. + # This might slowdown training & inference so it is recommended to not cast the LayerNorms + # in fp32. (LlamaRMSNorm handles it correctly) + input_dtype = query_states.dtype + if input_dtype == torch.float32: + query_states = query_states.to(torch.float16) + key_states = key_states.to(torch.float16) + value_states = value_states.to(torch.float16) + + attn_output_unpad = flash_attn_varlen_func( + query_states, + key_states, + value_states, + cu_seqlens_q=cu_seqlens, + cu_seqlens_k=cu_seqlens, + max_seqlen_q=max_seqlen_in_batch, + max_seqlen_k=max_seqlen_in_batch, + dropout_p=dropout_rate, + softmax_scale=None, + causal=True, + ) + + attn_output_unpad = attn_output_unpad.to(input_dtype) + attn_output_unpad = attn_output_unpad.reshape(total_nnz, 1, self.hidden_size_per_tp).contiguous() + + # sequence parallel reduce_scatter is performed inside RowColumnParallel if enabled + # Here we need to repad + if self.megatron_config.sequence_parallel: + attn_output_unpad = F.pad(attn_output_unpad, pad=(0, 0, 0, 0, 0, sequence_parallel_pad)) + + attn_output_unpad = self.o_proj(attn_output_unpad)[0] + return attn_output_unpad diff --git a/code/RL_model/verl/verl_train/verl/models/llama/megatron/layers/parallel_decoder.py b/code/RL_model/verl/verl_train/verl/models/llama/megatron/layers/parallel_decoder.py new file mode 100644 index 0000000000000000000000000000000000000000..f46e9457c793ccc4a9dc72f6d471d58ef48e8bfe --- /dev/null +++ b/code/RL_model/verl/verl_train/verl/models/llama/megatron/layers/parallel_decoder.py @@ -0,0 +1,150 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional + +import torch +from megatron.core import ModelParallelConfig +from torch import nn +from transformers import LlamaConfig + +from verl.utils.megatron_utils import TransformerConfig, convert_config + +from .parallel_attention import ParallelLlamaAttention, ParallelLlamaAttentionRmPad +from .parallel_mlp import ParallelLlamaMLP +from .parallel_rmsnorm import ParallelLlamaRMSNorm + + +class ParallelLlamaDecoderLayer(nn.Module): + def __init__(self, config: LlamaConfig, megatron_config: ModelParallelConfig, layer_idx: int): + super().__init__() + self.config: TransformerConfig = convert_config(config, megatron_config) + self.layer_idx = layer_idx + self.hidden_size = config.hidden_size + self.self_attn = ParallelLlamaAttention(config=config, megatron_config=megatron_config) + + self.mlp = ParallelLlamaMLP(config, megatron_config=megatron_config) + self.input_layernorm = ParallelLlamaRMSNorm(config, megatron_config) + self.post_attention_layernorm = ParallelLlamaRMSNorm(config, megatron_config) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + """ + + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Note: sequence parallel is hidden inside ColumnParallelLinear + # reduce scatter is hidden inside RowParallelLinear + + # Self Attention + hidden_states = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + ) + + # TODO: add sequence parallel operator reduce_scatter here + + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + + # TODO: add sequence parallel operator all_gather here + + hidden_states = self.mlp(hidden_states) + + # TODO: add sequence parallel operator reduce_scatter here + + hidden_states = residual + hidden_states + + outputs = hidden_states + + return outputs + + +class ParallelLlamaDecoderLayerRmPad(nn.Module): + def __init__(self, config: LlamaConfig, megatron_config: ModelParallelConfig, layer_idx: int): + super().__init__() + self.config: TransformerConfig = convert_config(config, megatron_config) + self.layer_idx = layer_idx + self.hidden_size = config.hidden_size + self.self_attn = ParallelLlamaAttentionRmPad(config=config, megatron_config=megatron_config) + + self.mlp = ParallelLlamaMLP(config, megatron_config=megatron_config) + self.input_layernorm = ParallelLlamaRMSNorm(config, megatron_config) + self.post_attention_layernorm = ParallelLlamaRMSNorm(config, megatron_config) + + def forward( + self, + hidden_states: torch.Tensor, + position_ids: Optional[torch.LongTensor] = None, + sequence_length: int = None, + indices: torch.Tensor = None, + cu_seqlens: int = None, + max_seqlen_in_batch: int = None, + ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]: + residual = hidden_states # (total_nnz // sp, 1, hidden_size) + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + # (total_nnz // sp, 1, hidden_size) -> all-gather (total_nnz, 1, hidden_size) + # -> col + row -> reduce-scatter -> (total_nnz // sp, 1, hidden_size) + hidden_states = self.self_attn( + hidden_states=hidden_states, + position_ids=position_ids, + sequence_length=sequence_length, + indices=indices, + cu_seqlens=cu_seqlens, + max_seqlen_in_batch=max_seqlen_in_batch, + ) + + hidden_states = residual + hidden_states + + # Fully Connected + # shape changes same as attn + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = hidden_states + + return outputs diff --git a/code/RL_model/verl/verl_train/verl/models/llama/megatron/layers/parallel_linear.py b/code/RL_model/verl/verl_train/verl/models/llama/megatron/layers/parallel_linear.py new file mode 100644 index 0000000000000000000000000000000000000000..043726c46c3705cf1bfa8ae10ab77d2b930e19d2 --- /dev/null +++ b/code/RL_model/verl/verl_train/verl/models/llama/megatron/layers/parallel_linear.py @@ -0,0 +1,106 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# Copyright 2023 The vLLM team. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/linear.py + +import torch +from megatron.core import tensor_parallel + + +class QKVParallelLinear(tensor_parallel.ColumnParallelLinear): + def __init__( + self, + input_size, + num_heads, + num_key_value_heads, + head_dim, + *, + bias=True, + gather_output=True, + skip_bias_add=False, + **kwargs, + ): + # Keep input parameters, and already restrict the head numbers + self.input_size = input_size + self.q_output_size = num_heads * head_dim + self.kv_output_size = num_key_value_heads * head_dim + self.head_dim = head_dim + self.gather_output = gather_output + self.skip_bias_add = skip_bias_add + + input_size = self.input_size + output_size = (num_heads + 2 * num_key_value_heads) * self.head_dim + + super().__init__( + input_size=input_size, + output_size=output_size, + bias=bias, + gather_output=gather_output, + skip_bias_add=skip_bias_add, + **kwargs, + ) + + +class MergedColumnParallelLinear(tensor_parallel.ColumnParallelLinear): + def __init__( + self, + input_size, + gate_ouput_size, + up_output_size, + *, + bias=True, + gather_output=True, + skip_bias_add=False, + **kwargs, + ): + # Keep input parameters, and already restrict the head numbers + self.input_size = input_size + self.output_size = gate_ouput_size + up_output_size + self.gather_output = gather_output + self.skip_bias_add = skip_bias_add + + super().__init__( + input_size=self.input_size, + output_size=self.output_size, + bias=bias, + gather_output=gather_output, + skip_bias_add=skip_bias_add, + **kwargs, + ) + + +class LinearForLastLayer(torch.nn.Linear): + def __init__( + self, + input_size, + output_size, + *, + config, + bias=True, + ): + super().__init__(in_features=input_size, out_features=output_size, bias=bias) + self.sequence_parallel = config.sequence_parallel + if self.sequence_parallel: + self.weight.sequence_parallel = True + + def forward( + self, + input_, + weight=None, + runtime_gather_output=None, + ): + logits = super().forward(input_) + logits = logits.float() + if self.sequence_parallel: + logits = tensor_parallel.gather_from_sequence_parallel_region(logits, tensor_parallel_output_grad=False) + return logits, None diff --git a/code/RL_model/verl/verl_train/verl/models/llama/megatron/layers/parallel_mlp.py b/code/RL_model/verl/verl_train/verl/models/llama/megatron/layers/parallel_mlp.py new file mode 100644 index 0000000000000000000000000000000000000000..583a317eb6aedadeb26d82cef54b815d2b9d22e6 --- /dev/null +++ b/code/RL_model/verl/verl_train/verl/models/llama/megatron/layers/parallel_mlp.py @@ -0,0 +1,74 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from megatron.core import ModelParallelConfig, tensor_parallel +from megatron.core import parallel_state as mpu +from torch import nn +from transformers.activations import ACT2FN + +from verl.models.llama.megatron.layers.parallel_linear import MergedColumnParallelLinear +from verl.utils.megatron import tensor_parallel as tp_utils + + +class ParallelLlamaMLP(nn.Module): + def __init__(self, config, megatron_config: ModelParallelConfig = None) -> None: + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + # The weight is only [hidden_size, intermediate_size // model_parallel_world_size] + + column_kwargs = tp_utils.get_default_kwargs_for_column_parallel_linear() + row_kwargs = tp_utils.get_default_kwargs_for_row_parallel_linear() + + if megatron_config is not None: + assert column_kwargs.get("config", False), "must have ModelParallelConfig" + assert row_kwargs.get("config", False), "must have ModelParallelConfig" + tp_utils.update_kwargs_with_config(row_kwargs, megatron_config) + tp_utils.update_kwargs_with_config(column_kwargs, megatron_config) + + tp_size = mpu.get_tensor_model_parallel_world_size() + + self.gate_up_proj = MergedColumnParallelLinear( + input_size=self.hidden_size, + gate_ouput_size=self.intermediate_size, + up_output_size=self.intermediate_size, + bias=False, + gather_output=False, + skip_bias_add=False, + **column_kwargs, + ) + self.gate_size = self.intermediate_size // tp_size + + self.down_proj = tensor_parallel.RowParallelLinear( + input_size=self.intermediate_size, + output_size=self.hidden_size, + bias=False, + input_is_parallel=True, + skip_bias_add=False, + **row_kwargs, + ) + + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, x): + gate_up = self.gate_up_proj(x)[0] + gate, up = gate_up.split(self.gate_size, dim=-1) + return self.down_proj(self.act_fn(gate) * up)[0] diff --git a/code/RL_model/verl/verl_train/verl/models/llama/megatron/layers/parallel_rmsnorm.py b/code/RL_model/verl/verl_train/verl/models/llama/megatron/layers/parallel_rmsnorm.py new file mode 100644 index 0000000000000000000000000000000000000000..23a4a847ff875b2410f5c76b7386b806d86a5735 --- /dev/null +++ b/code/RL_model/verl/verl_train/verl/models/llama/megatron/layers/parallel_rmsnorm.py @@ -0,0 +1,49 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numbers + +import torch +from megatron.core import ModelParallelConfig +from torch import nn +from transformers import LlamaConfig + +from verl.utils.megatron import sequence_parallel as sp_utils + + +class ParallelLlamaRMSNorm(nn.Module): + def __init__(self, config: LlamaConfig, megatron_config: ModelParallelConfig): + """ + LlamaRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + if isinstance(config.hidden_size, numbers.Integral): + normalized_shape = (config.hidden_size,) + self.normalized_shape = torch.Size(normalized_shape) + self.weight = nn.Parameter(torch.ones(self.normalized_shape)) + self.variance_epsilon = config.rms_norm_eps + + if megatron_config.sequence_parallel: + sp_utils.mark_parameter_as_sequence_parallel(self.weight) + + def forward(self, hidden_states): + from apex.normalization.fused_layer_norm import fused_rms_norm_affine + + return fused_rms_norm_affine( + input=hidden_states, + weight=self.weight, + normalized_shape=self.normalized_shape, + eps=self.variance_epsilon, + memory_efficient=True, + ) diff --git a/code/RL_model/verl/verl_train/verl/models/llama/megatron/modeling_llama_megatron.py b/code/RL_model/verl/verl_train/verl/models/llama/megatron/modeling_llama_megatron.py new file mode 100644 index 0000000000000000000000000000000000000000..e8a7e2440e643fb48f87093f235a4834b4a23e48 --- /dev/null +++ b/code/RL_model/verl/verl_train/verl/models/llama/megatron/modeling_llama_megatron.py @@ -0,0 +1,688 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch LLaMA model with Megatron-style acceleration.""" + +from typing import Optional + +import torch +import torch.utils.checkpoint +from megatron.core import ModelParallelConfig, mpu, tensor_parallel +from torch import nn +from transformers.modeling_outputs import BaseModelOutputWithPast +from transformers.models.llama.configuration_llama import LlamaConfig +from transformers.models.llama.modeling_llama import CausalLMOutputWithPast + +from verl.utils.megatron import sequence_parallel as sp_utils +from verl.utils.megatron import tensor_parallel as tp_utils +from verl.utils.megatron_utils import TransformerConfig, convert_config + +from .layers import ParallelLlamaDecoderLayer, ParallelLlamaDecoderLayerRmPad, ParallelLlamaRMSNorm + +""" +TODO: +1. Add weight initialization. Here we need to be careful on TP weight init. +2. Add sequence parallel +3. Load checkpoint from meta LLama pretrained checkpoint +""" + + +# Copied from transformers.models.bart.modeling_bart._make_causal_mask +def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device): + """ + Make causal mask used for bi-directional self-attention. + """ + bsz, tgt_len = input_ids_shape + mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device) + mask_cond = torch.arange(mask.size(-1), device=device) + mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) + mask = mask.to(dtype) + return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len) + + +# Copied from transformers.models.bart.modeling_bart._expand_mask +def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): + """ + Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. + """ + bsz, src_len = mask.size() + tgt_len = tgt_len if tgt_len is not None else src_len + + expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) + + inverted_mask = 1.0 - expanded_mask + + return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) + + +class ParallelLlamaModel(nn.Module): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`] + + Args: + config: LlamaConfig + """ + + def __init__(self, config: LlamaConfig, megatron_config: ModelParallelConfig): + super().__init__() + self.config: TransformerConfig = convert_config(config, megatron_config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + embedding_kwargs = tp_utils.get_default_kwargs_for_parallel_embedding() + if megatron_config is not None: + assert embedding_kwargs.get("config", False), "must have ModelParallelConfig" + tp_utils.update_kwargs_with_config(embedding_kwargs, self.megatron_config) + self.embed_tokens = tensor_parallel.VocabParallelEmbedding( + num_embeddings=config.vocab_size, embedding_dim=config.hidden_size, **embedding_kwargs + ) + + self.layers = nn.ModuleList( + [ParallelLlamaDecoderLayer(config, megatron_config) for _ in range(config.num_hidden_layers)] + ) + self.norm = ParallelLlamaRMSNorm(config, megatron_config) + + # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask + def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds): + # create causal mask + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + combined_attention_mask = None + if input_shape[-1] > 1: + combined_attention_mask = _make_causal_mask( + input_shape, + inputs_embeds.dtype, + device=inputs_embeds.device, + ) + + if attention_mask is not None: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to( + inputs_embeds.device + ) + combined_attention_mask = ( + expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask + ) + + return combined_attention_mask + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + ) -> tuple | BaseModelOutputWithPast: + """ + + Args: + input_ids: input ids. shape (batch_size, seq_length) + attention_mask: attention_mask. shape (batch_size, seq_length) + position_ids: position ids. shape (batch_size, seq_length) + + Returns: + + """ + batch_size, seq_length = input_ids.shape + inputs_embeds = self.embed_tokens(input_ids) + # embed positions + + attention_mask = self._prepare_decoder_attention_mask(attention_mask, (batch_size, seq_length), inputs_embeds) + + hidden_states = inputs_embeds + + for idx, decoder_layer in enumerate(self.layers): + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + ) + + hidden_states = layer_outputs + + hidden_states = self.norm(hidden_states) + + return hidden_states + + +class ParallelLlamaForCausalLM(nn.Module): + def __init__(self, config: LlamaConfig, megatron_config: ModelParallelConfig): + super().__init__() + self.config: TransformerConfig = convert_config(config, megatron_config) + self.model = ParallelLlamaModel(config, megatron_config=megatron_config) + self.vocab_size = config.vocab_size + + column_kwargs = tp_utils.get_default_kwargs_for_column_parallel_linear() + if megatron_config is not None: + assert column_kwargs.get("config", False), "must have ModelParallelConfig" + tp_utils.update_kwargs_with_config(column_kwargs, self.megatron_config) + + self.lm_head = tensor_parallel.ColumnParallelLinear( + input_size=config.hidden_size, + output_size=config.vocab_size, + bias=False, + gather_output=False, + skip_bias_add=False, + **column_kwargs, + ) + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + ) -> tuple | CausalLMOutputWithPast: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Returns: + ```""" + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + ) + + hidden_states = outputs + logits = self.lm_head(hidden_states)[0] + + logits = tensor_parallel.gather_from_tensor_model_parallel_region(logits) + + logits = logits.float() + return CausalLMOutputWithPast( + loss=None, + logits=logits, + past_key_values=None, + hidden_states=None, + attentions=None, + ) + + +from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa: F401, E402 + + +class ParallelLlamaModelRmPad(nn.Module): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`] + + Args: + config: LlamaConfig + """ + + def __init__(self, config: LlamaConfig, megatron_config: ModelParallelConfig): + super().__init__() + self.config: TransformerConfig = convert_config(config, megatron_config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + embedding_kwargs = tp_utils.get_default_kwargs_for_parallel_embedding() + self.megatron_config = megatron_config + if megatron_config is not None: + assert embedding_kwargs.get("config", False), "must have ModelParallelConfig" + tp_utils.update_kwargs_with_config(embedding_kwargs, self.megatron_config) + self.embed_tokens = tensor_parallel.VocabParallelEmbedding( + num_embeddings=config.vocab_size, embedding_dim=config.hidden_size, **embedding_kwargs + ) + + self.layers = nn.ModuleList( + [ParallelLlamaDecoderLayerRmPad(config, megatron_config) for _ in range(config.num_hidden_layers)] + ) + self.norm = ParallelLlamaRMSNorm(config, megatron_config) + + def forward( + self, + input_ids: torch.Tensor, + position_ids: Optional[torch.LongTensor] = None, + sequence_length: int = None, + indices: torch.Tensor = None, + cu_seqlens: int = None, + max_seqlen_in_batch: int = None, + ) -> tuple | BaseModelOutputWithPast: + """ + + Args: + input_ids: input ids. shape (1, totol_nnz) + position_ids: position ids. shape (batch_size, seq_length) + + Returns: + + """ + inputs_embeds = self.embed_tokens(input_ids) # (1, total_nnz) -> (1, total_nnz, hidden_size) + + # (1, total_nnz, hidden_size) -> (total_nnz, 1, hidden_size) -> (total_nnz // sp, 1, hidden_size) + inputs_embeds = inputs_embeds.transpose(0, 1) + if self.megatron_config.sequence_parallel: + inputs_embeds = tensor_parallel.scatter_to_sequence_parallel_region(inputs_embeds) + + hidden_states = inputs_embeds + for idx, decoder_layer in enumerate(self.layers): + layer_outputs = decoder_layer( + hidden_states, + position_ids=position_ids, + sequence_length=sequence_length, + indices=indices, + cu_seqlens=cu_seqlens, + max_seqlen_in_batch=max_seqlen_in_batch, + ) + + hidden_states = layer_outputs + + hidden_states = self.norm(hidden_states) + + return hidden_states + + +class ParallelLlamaForCausalLMRmPad(nn.Module): + def __init__(self, config: LlamaConfig, megatron_config: ModelParallelConfig): + super().__init__() + self.config: TransformerConfig = convert_config(config, megatron_config) + self.megatron_config = megatron_config + self.model = ParallelLlamaModelRmPad(config, megatron_config=megatron_config) + self.vocab_size = config.vocab_size + self._init_head(config) + + def _init_head(self, config): + column_kwargs = tp_utils.get_default_kwargs_for_column_parallel_linear() + if self.megatron_config is not None: + assert column_kwargs.get("config", False), "must have ModelParallelConfig" + tp_utils.update_kwargs_with_config(column_kwargs, self.megatron_config) + self.lm_head = tensor_parallel.ColumnParallelLinear( + input_size=config.hidden_size, + output_size=config.vocab_size, + bias=False, + gather_output=False, + skip_bias_add=False, + **column_kwargs, + ) + + def _forward_head(self, hidden_states): + # all_gather from sequence parallel region is performed inside lm_head + logits = self.lm_head(hidden_states)[0] + logits = logits.float() # (total_nnz_padded, 1, vocab_size // tp) + logits = tensor_parallel.gather_from_tensor_model_parallel_region(logits) # (total_nnz_padded, 1, vocab_size) + return logits + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + ) -> tuple | CausalLMOutputWithPast: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Returns: + ```""" + batch_size, sequence_length = input_ids.shape + + # remove padding here + input_ids, indices, cu_seqlens, max_seqlen_in_batch, *_ = unpad_input( + input_ids.unsqueeze(dim=-1), attention_mask + ) # (total_nnz, 1) + + # pad input_ids to multiple of tp for all tp ranks + # TODO: for better performance, the sp padding should be removed at each layer. Not sure the performance gap + if self.megatron_config.sequence_parallel: + input_ids = sp_utils.pad_to_sequence_parallel(input_ids) + + input_ids = input_ids.transpose(0, 1) # (1, total_nnz+pad) + + outputs = self.model( + input_ids=input_ids, + position_ids=position_ids, + sequence_length=sequence_length, + indices=indices, + cu_seqlens=cu_seqlens, + max_seqlen_in_batch=max_seqlen_in_batch, + ) + + hidden_states = outputs + + logits = self._forward_head(hidden_states) + + # remove padding from sequence parallel + if self.megatron_config.sequence_parallel: + totol_nnz = cu_seqlens[-1] + logits = logits[:totol_nnz] # (total_nnz_padded) + + logits = torch.squeeze(logits, dim=1) # remove the artificial batch dimension + # add removed padding back + logits = pad_input( + logits, indices, batch_size, seqlen=sequence_length + ) # (batch_size, sequence_length, vocab_size) + + return CausalLMOutputWithPast( + loss=None, + logits=logits, + past_key_values=None, + hidden_states=None, + attentions=None, + ) + + +class ParallelLlamaForValueRmPad(ParallelLlamaForCausalLMRmPad): + def _init_head(self, config): + column_kwargs = tp_utils.get_default_kwargs_for_column_parallel_linear() + if self.megatron_config is not None: + assert column_kwargs.get("config", False), "must have ModelParallelConfig" + tp_utils.update_kwargs_with_config(column_kwargs, self.megatron_config) + self.lm_head = nn.Linear(in_features=config.hidden_size, out_features=1, bias=False) + # lm_head is effectively the same as sequence parallel + sp_utils.mark_parameter_as_sequence_parallel(self.lm_head.weight) + + def _forward_head(self, hidden_states): + logits = self.lm_head(hidden_states) # (total_nnz_padded // tp, 1, 1) + logits = logits.float() + if self.megatron_config.sequence_parallel: + logits = tensor_parallel.gather_from_sequence_parallel_region(logits, tensor_parallel_output_grad=False) + return logits + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + ) -> tuple | CausalLMOutputWithPast: + output = super().forward(input_ids, attention_mask, position_ids) + output.logits = torch.squeeze(output.logits, dim=-1) + return output + + +""" +Support pipeline parallelism +""" + + +class ParallelLlamaModelRmPadPP(nn.Module): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`] + This model definition supports pipeline parallelism. To support pp and vpp, + - This model only contains layer in this pp stage and vpp chunk + - When calling get_model in Megatron, this rank will instantiate all the vpp chunks in this pp. + Args: + config: LlamaConfig + """ + + def __init__(self, config: LlamaConfig, megatron_config: ModelParallelConfig, pre_process, post_process): + super().__init__() + self.config: TransformerConfig = convert_config(config, megatron_config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + self.pre_process = pre_process + self.post_process = post_process + self.megatron_config = megatron_config + embedding_kwargs = tp_utils.get_default_kwargs_for_parallel_embedding() + if megatron_config is not None: + assert embedding_kwargs.get("config", False), "must have ModelParallelConfig" + tp_utils.update_kwargs_with_config(embedding_kwargs, self.megatron_config) + if pre_process: + self.embed_tokens = tensor_parallel.VocabParallelEmbedding( + num_embeddings=config.vocab_size, embedding_dim=config.hidden_size, **embedding_kwargs + ) + else: + self.embed_tokens = None + + pp_rank = mpu.get_pipeline_model_parallel_rank() + pp_size = megatron_config.pipeline_model_parallel_size + self.num_layer_per_pp = config.num_hidden_layers // pp_size + vpp_size = megatron_config.virtual_pipeline_model_parallel_size + vpp_rank = mpu.get_virtual_pipeline_model_parallel_rank() + + if vpp_size is not None: + self.layers = nn.ModuleList() + self.num_layer_vpp_chunk = self.num_layer_per_pp // vpp_size + self.num_layer_this_model = self.num_layer_vpp_chunk + offset = vpp_rank * (config.num_hidden_layers // vpp_size) + (pp_rank * self.num_layer_vpp_chunk) + else: + self.num_layer_this_model = self.num_layer_per_pp + offset = pp_rank * self.num_layer_per_pp + + self.layers = nn.ModuleList() + for i in range(self.num_layer_this_model): + layer = ParallelLlamaDecoderLayerRmPad(config, megatron_config, layer_idx=offset + i) + self.layers.add_module(f"{i}", layer) + + if post_process: + self.norm = ParallelLlamaRMSNorm(config, megatron_config) + else: + self.norm = None + + def set_input_tensor(self, input_tensor): + """Set input tensor to be used instead of forward()'s input. + + When doing pipeline parallelism the input from the previous + stage comes from communication, not from the input, so the + model's forward_step_func won't have it. This function is thus + used by internal code to bypass the input provided by the + forward_step_func""" + self.input_tensor = input_tensor + + def forward( + self, + input_ids: torch.Tensor, + position_ids: Optional[torch.LongTensor] = None, + sequence_length: int = None, + indices: torch.Tensor = None, + cu_seqlens: int = None, + max_seqlen_in_batch: int = None, + ) -> tuple | BaseModelOutputWithPast: + """ + + Args: + input_ids: input ids. shape (1, totol_nnz) + position_ids: position ids. shape (batch_size, seq_length) + + Returns: + + """ + if self.pre_process: + inputs_embeds = self.embed_tokens(input_ids) # (1, total_nnz) -> (1, total_nnz, hidden_size) + + # vocab parallel embedding will not do sequence parallel reduce-scatter in open source megatron + # so need to deal with it by handle here: + # (1, total_nnz, hidden_size) -> (total_nnz, 1, hidden_size) -> (total_nnz // sp, 1, hidden_size) + inputs_embeds = inputs_embeds.transpose(0, 1) + if self.megatron_config.sequence_parallel: + inputs_embeds = tensor_parallel.scatter_to_sequence_parallel_region(inputs_embeds) + + hidden_states = inputs_embeds + else: + # self.hidden_states should be passed by Megatron + hidden_states = self.input_tensor + + for idx, decoder_layer in enumerate(self.layers): + layer_outputs = decoder_layer( + hidden_states, + position_ids=position_ids, + sequence_length=sequence_length, + indices=indices, + cu_seqlens=cu_seqlens, + max_seqlen_in_batch=max_seqlen_in_batch, + ) + + hidden_states = layer_outputs + + if self.post_process: + hidden_states = self.norm(hidden_states) + + return hidden_states + + +class ParallelLlamaForCausalLMRmPadPP(nn.Module): + def __init__( + self, + config: LlamaConfig, + megatron_config: ModelParallelConfig, + pre_process, + post_process, + share_embeddings_and_output_weights=False, + ): + super().__init__() + self.config: TransformerConfig = convert_config(config, megatron_config) + self.megatron_config = megatron_config + self.model = ParallelLlamaModelRmPadPP( + config, megatron_config=megatron_config, pre_process=pre_process, post_process=post_process + ) + assert share_embeddings_and_output_weights is False, ( + "Llama Model not supports sharing embedding and output weights" + ) + self.share_embeddings_and_output_weights = share_embeddings_and_output_weights + self.vocab_size = config.vocab_size + self.pre_process = pre_process + self.post_process = post_process + if post_process: + self._init_head(config) + + def set_input_tensor(self, input_tensor): + """Set input tensor to be used instead of forward()'s input. + + When doing pipeline parallelism the input from the previous + stage comes from communication, not from the input, so the + model's forward_step_func won't have it. This function is thus + used by internal code to bypass the input provided by the + forward_step_func""" + assert len(input_tensor) == 1 + self.model.set_input_tensor(input_tensor[0]) + + def _init_head(self, config): + column_kwargs = tp_utils.get_default_kwargs_for_column_parallel_linear() + if self.megatron_config is not None: + assert column_kwargs.get("config", False), "must have ModelParallelConfig" + tp_utils.update_kwargs_with_config(column_kwargs, self.megatron_config) + self.lm_head = tensor_parallel.ColumnParallelLinear( + input_size=config.hidden_size, + output_size=config.vocab_size, + bias=False, + gather_output=False, + skip_bias_add=False, + **column_kwargs, + ) + + def _forward_head(self, hidden_states): + # all_gather from sequence parallel region is performed inside lm_head + # logits shape before forward_head hidden_states.shape: [4, 32, 4096] + logits = self.lm_head(hidden_states)[0] + # logits shape after forward_head logits.shape: [8, 32, 8] + logits = logits.float() # (total_nnz_padded, 1, vocab_size // tp) + return logits + + def forward( + self, + # original input + *, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + ) -> tuple | CausalLMOutputWithPast: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Returns: + ```""" + + # Note that input_ids, attention_mask and position_ids should be passed to every pp layer. + # In the first pp, input_ids will be used, in other pp layers hidden_states will be used inside self.model + batch_size, sequence_length = input_ids.shape + # remove padding here + input_ids_rmpad, indices, cu_seqlens, max_seqlen_in_batch, *_ = unpad_input( + input_ids.unsqueeze(dim=-1), attention_mask + ) # (total_nnz, 1) + + # pad input_ids to multiple of tp for all tp ranks + # TODO: for better performance, the sp padding should be removed at each layer. Not sure the performance gap + if self.megatron_config.sequence_parallel: + input_ids_rmpad = sp_utils.pad_to_sequence_parallel(input_ids_rmpad) + + input_ids_rmpad = input_ids_rmpad.transpose(0, 1) # (1, total_nnz+pad) + + outputs = self.model( + input_ids=input_ids_rmpad, + position_ids=position_ids, + sequence_length=sequence_length, + indices=indices, + cu_seqlens=cu_seqlens, + max_seqlen_in_batch=max_seqlen_in_batch, + ) + + if self.post_process: + hidden_states = outputs + # print(f'hidden_states.shape = {hidden_states.shape}') # torch.Size([4, 32, 4096]) + logits = self._forward_head(hidden_states) + logits = torch.squeeze(logits, dim=1) # remove the artificial batch dimension # torch.Size([8, 32, 16]) + + # remove padding from sequence parallel + if self.megatron_config.sequence_parallel: + totol_nnz = cu_seqlens[-1] + logits = logits[:totol_nnz] # (total_nnz_padded) + # add removed padding back. If input is already rmpad, we let the caller pad_input + logits = pad_input( + logits, indices, batch_size, seqlen=sequence_length + ) # (batch_size, sequence_length, vocab_size) + + return CausalLMOutputWithPast( + loss=None, + logits=logits, + past_key_values=None, + hidden_states=None, + attentions=None, + ) + else: + return outputs + + +class ParallelLlamaForValueRmPadPP(ParallelLlamaForCausalLMRmPadPP): + def _init_head(self, config): + column_kwargs = tp_utils.get_default_kwargs_for_column_parallel_linear() + if self.megatron_config is not None: + assert column_kwargs.get("config", False), "must have ModelParallelConfig" + tp_utils.update_kwargs_with_config(column_kwargs, self.megatron_config) + self.lm_head = nn.Linear(in_features=config.hidden_size, out_features=1, bias=False) + # lm_head is effectively the same as sequence parallel + sp_utils.mark_parameter_as_sequence_parallel(self.lm_head.weight) + + def _forward_head(self, hidden_states): + logits = self.lm_head(hidden_states) # (total_nnz_padded // tp, 1, 1) + logits = logits.float() + if self.megatron_config.sequence_parallel: + logits = tensor_parallel.gather_from_sequence_parallel_region(logits, tensor_parallel_output_grad=False) + return logits + + def forward( + self, + *, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + ) -> tuple | CausalLMOutputWithPast: + output = super().forward(input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids) + if self.post_process: + output.logits = torch.squeeze(output.logits, dim=-1) + return output + else: + return output diff --git a/code/RL_model/verl/verl_train/verl/models/mcore/__init__.py b/code/RL_model/verl/verl_train/verl/models/mcore/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a0f6e76f3f8b0c238fd9085942f6df1b90d4a974 --- /dev/null +++ b/code/RL_model/verl/verl_train/verl/models/mcore/__init__.py @@ -0,0 +1,32 @@ +# Copyright 2025 Bytedance Ltd. and/or its affiliates +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .registry import ( + get_mcore_forward_fn, + get_mcore_forward_fused_fn, + get_mcore_forward_no_padding_fn, + get_mcore_weight_converter, + hf_to_mcore_config, + init_mcore_model, +) + +__all__ = [ + "hf_to_mcore_config", + "init_mcore_model", + "get_mcore_forward_fn", + "get_mcore_weight_converter", + "get_mcore_forward_fused_fn", + "get_mcore_forward_no_padding_fn", +] diff --git a/code/RL_model/verl/verl_train/verl/models/mcore/bridge.py b/code/RL_model/verl/verl_train/verl/models/mcore/bridge.py new file mode 100644 index 0000000000000000000000000000000000000000..dffb661b7b098a6d24352ae9583551da8048b055 --- /dev/null +++ b/code/RL_model/verl/verl_train/verl/models/mcore/bridge.py @@ -0,0 +1,178 @@ +# Copyright 2025 Bytedance Ltd. and/or its affiliates +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +try: + from megatron.bridge import AutoBridge + from megatron.bridge.models.conversion.param_mapping import AutoMapping + from megatron.bridge.peft.canonical_lora import CanonicalLoRA + from megatron.bridge.peft.dora import DoRA + from megatron.bridge.peft.lora import LoRA, VLMLoRA +except ImportError: + # `pip install verl[mcore]` or + print("Megatron-Bridge package not found. Please install Megatron-Bridge with `pip install megatron-bridge`") + raise + +import torch +from megatron.core import tensor_parallel + + +def _ensure_model_list(model): + return model if isinstance(model, list) else [model] + + +class LinearForLastLayer(torch.nn.Linear): + """ + A custom linear layer implementation for the last layer of a model. + + This layer extends PyTorch's Linear module with functionality specifically designed + for handling the final layer in transformer models with sequence parallelism. + + Attributes: + sequence_parallel: Boolean indicating whether sequence parallelism is enabled + """ + + def __init__( + self, + input_size, + output_size, + *, + sequence_parallel: bool, + ): + """ + Initializes the LinearForLastLayer. + + Args: + input_size: The size of the input features + output_size: The size of the output features + sequence_parallel (bool): Whether sequence parallelism is enabled + """ + super().__init__(in_features=input_size, out_features=output_size, bias=False) + self.sequence_parallel = sequence_parallel + if self.sequence_parallel: + self.weight.sequence_parallel = True + + def forward( + self, + input_, + weight=None, + runtime_gather_output=None, + ): + """ + Forward pass for the linear layer. + + This method computes the linear transformation and handles sequence parallelism + if enabled, gathering outputs from different sequence parallel regions. + + Args: + input_: Input tensor + weight: Placeholder for compatibility + runtime_gather_output: Placeholder for compatibility + + Returns: + tuple: (logits, None) where logits is the output of the linear transformation + """ + logits = super().forward(input_) + logits = logits.float() + if self.sequence_parallel: + logits = tensor_parallel.gather_from_sequence_parallel_region(logits, tensor_parallel_output_grad=False) + return logits, None + + +# Make Megatron-Bridge AutoMapping treats the custom last layer as replicated. +AutoMapping.register_module_type("LinearForLastLayer", "replicated") + + +def make_value_model(hidden_size, sequence_parallel): + """Creates a pre-wrap hook that replace the output layer with a value head. + + Args: + hidden_size (int): The hidden size of the model's transformer layers. + sequence_parallel (bool): Whether sequence parallelism is enabled. + + Returns: + A hook function that can be used as a `pre_wrap_hook` in Megatron-Bridge. + The hook itself takes the model as input and prepares it for value head activation. + """ + + from megatron.core import parallel_state + + def hook(model): + model_post_process = [] + if ( + parallel_state.get_pipeline_model_parallel_world_size() > 1 + and parallel_state.get_virtual_pipeline_model_parallel_world_size() is not None + ): + for i in range(parallel_state.get_virtual_pipeline_model_parallel_world_size()): + model_post_process.append(parallel_state.is_pipeline_last_stage(ignore_virtual=False, vp_stage=i)) + else: + model_post_process.append(parallel_state.is_pipeline_last_stage()) + + model_list = _ensure_model_list(model) + assert len(model_post_process) == len(model_list), "Model list length and post process list length must match." + + for index, model_chunk in enumerate(model_list): + if not model_post_process[index]: + continue + + model_chunk.output_layer = LinearForLastLayer( + input_size=hidden_size, + output_size=1, + sequence_parallel=sequence_parallel, + ) + + return hook + + +def freeze_moe_router(model): + """Pre-wrap hook to freeze MoE router parameters. + + Args: + model: List of MegatronModule instances or single module + + Returns: + The model with frozen router parameters + """ + for model_chunk in _ensure_model_list(model): + if hasattr(model_chunk, "decoder") and hasattr(model_chunk.decoder, "layers"): + for layer in model_chunk.decoder.layers: + if hasattr(layer.mlp, "router"): + if hasattr(layer.mlp.router, "weight"): + layer.mlp.router.weight.requires_grad = False + if hasattr(layer.mlp.router, "bias"): + layer.mlp.router.bias.requires_grad = False + if hasattr(layer.mlp, "shared_experts"): + if ( + hasattr(layer.mlp.shared_experts, "gate_weight") + and layer.mlp.shared_experts.gate_weight is not None + ): + layer.mlp.shared_experts.gate_weight.requires_grad = False + if ( + hasattr(layer.mlp.shared_experts, "gate_bias") + and layer.mlp.shared_experts.gate_bias is not None + ): + layer.mlp.shared_experts.gate_bias.requires_grad = False + + return model + + +__all__ = [ + "AutoBridge", + "make_value_model", + "freeze_moe_router", + "LoRA", + "VLMLoRA", + "DoRA", + "CanonicalLoRA", +] diff --git a/code/RL_model/verl/verl_train/verl/models/mcore/config_converter.py b/code/RL_model/verl/verl_train/verl/models/mcore/config_converter.py new file mode 100644 index 0000000000000000000000000000000000000000..c4df938286146ccf9d50bed3d0938d49d7f03875 --- /dev/null +++ b/code/RL_model/verl/verl_train/verl/models/mcore/config_converter.py @@ -0,0 +1,399 @@ +# Copyright 2025 Bytedance Ltd. and/or its affiliates +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# convert huggingface config to mcore transformer config + + +import warnings +from typing import TypeVar + +import torch +import torch.nn.functional as F +from megatron.core import parallel_state as mpu +from megatron.core.transformer import MLATransformerConfig, TransformerConfig +from transformers import PretrainedConfig + +T = TypeVar("T", bound=TransformerConfig) + + +def _get_base_transformer_config( + hf_config: PretrainedConfig, dtype: torch.dtype, **override_transformer_config_kwargs +) -> dict: + """ + Create a base TransformerConfig with common parameters across different model architectures. + TODO: (ycl) use dataclass or converter config? + + Args: + hf_config: HuggingFace model configuration + dtype: Data type for the model + override_transformer_config_kwargs: Additional parameters to override defaults + + Returns: + TransformerConfig with common parameters + """ + + # Common parallel state parameters + overlap_p2p_comm = ( + mpu.get_virtual_pipeline_model_parallel_world_size() is not None + and mpu.get_virtual_pipeline_model_parallel_world_size() > 1 + ) + batch_p2p_comm = False + + # Base configuration with common parameters + base_config = { + # Model architecture parameters + "num_layers": hf_config.num_hidden_layers, + "hidden_size": hf_config.hidden_size, + "num_attention_heads": hf_config.num_attention_heads, + "num_query_groups": hf_config.num_key_value_heads, + "ffn_hidden_size": hf_config.intermediate_size, + "attention_dropout": hf_config.attention_dropout, + "hidden_dropout": getattr(hf_config, "hidden_dropout", 0.0), + "kv_channels": getattr(hf_config, "head_dim", None), + "layernorm_epsilon": hf_config.rms_norm_eps, + "add_bias_linear": True, + # Activation and normalization + "activation_func": F.silu, + "normalization": "RMSNorm", + "gated_linear_unit": True, + # Data types + "pipeline_dtype": dtype, + "params_dtype": dtype, + "bf16": dtype is torch.bfloat16, + # Parallel configuration + "tensor_model_parallel_size": mpu.get_tensor_model_parallel_world_size(), + "pipeline_model_parallel_size": mpu.get_pipeline_model_parallel_world_size(), + "expert_model_parallel_size": mpu.get_expert_model_parallel_world_size(), + "expert_tensor_parallel_size": mpu.get_expert_tensor_parallel_world_size(), + "virtual_pipeline_model_parallel_size": mpu.get_virtual_pipeline_model_parallel_world_size(), + "context_parallel_size": mpu.get_context_parallel_world_size(), + "overlap_p2p_comm": overlap_p2p_comm, + "batch_p2p_comm": batch_p2p_comm, + "sequence_parallel": mpu.get_tensor_model_parallel_world_size() > 1, + # Common settings + "variable_seq_lengths": True, + "masked_softmax_fusion": True, + "moe_token_dispatcher_type": "alltoall", + } + + # Update with any provided overrides + # override_transformer_config_kwargs as kwargs shall never be none + base_config.update(override_transformer_config_kwargs) + + return base_config + + +def _get_mla_transformer_config( + hf_config: PretrainedConfig, mla_rope_config: dict, dtype: torch.dtype, **override_transformer_config_kwargs +) -> dict: + """ + Create a MLATransformerConfig with common parameters across different model architectures. + This is specifically for MLA models like DeepseekV3. + + Args: + hf_config: HuggingFace model configuration + mla_rope_config: MLA specific RoPE configuration + dtype: Data type for the model + override_transformer_config_kwargs: Additional parameters to override defaults + + Returns: + MLATransformerConfig with common parameters + """ + base_config = _get_base_transformer_config(hf_config=hf_config, dtype=dtype, **override_transformer_config_kwargs) + mla_config = { + # MLA specific parameters + "q_lora_rank": hf_config.q_lora_rank, + "kv_lora_rank": hf_config.kv_lora_rank, + "qk_head_dim": hf_config.qk_nope_head_dim, + "qk_pos_emb_head_dim": hf_config.qk_rope_head_dim, + "v_head_dim": hf_config.v_head_dim, + "rotary_base": hf_config.rope_theta, + "rotary_scaling_factor": mla_rope_config["factor"], + "rope_type": mla_rope_config["type"], + "max_position_embeddings": mla_rope_config["original_max_position_embeddings"], + "beta_fast": mla_rope_config["beta_fast"], + "beta_slow": mla_rope_config["beta_slow"], + "mscale": mla_rope_config["mscale"], + "mscale_all_dim": mla_rope_config["mscale_all_dim"], + } + + base_config.update(mla_config) + return base_config + + +def check_and_construct_configs(original_config: dict, cls: type[T]) -> T: + """ + Check and disable incompatible configurations for older Megatron version. + + Args: + original_config (dict): The original model configuration. + + Returns: + dict: The updated model configuration with incompatible settings disabled. + """ + removed_keys = [] + for key in original_config.keys(): + if not hasattr(cls, key): + removed_keys.append(key) + if removed_keys: + warnings.warn( + f"The following keys are not supported in the current Megatron version and will be removed: {removed_keys}", + stacklevel=2, + ) + for key in removed_keys: + original_config.pop(key) + + original_config = mapping_string_to_attn_backend(original_config) + if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0: + print(f"Overridden {cls.__name__} init config: {original_config}") + return cls(**original_config) + + +def hf_to_mcore_config_dense( + hf_config: PretrainedConfig, dtype: torch.dtype, **override_transformer_config_kwargs +) -> TransformerConfig: + # for LlamaForCausalLM or Qwen2ForCausalLM + qkv_bias = True if "Qwen2" in hf_config.architectures[0] else getattr(hf_config, "attention_bias", False) + qk_layernorm = True if "Qwen3" in hf_config.architectures[0] else False + + args: dict = _get_base_transformer_config( + hf_config=hf_config, + dtype=dtype, + use_cpu_initialization=False, + add_bias_linear=False, + add_qkv_bias=qkv_bias, + qk_layernorm=qk_layernorm, + ) + # override_transformer_config_kwargs as kwargs shall never be none + args.update(override_transformer_config_kwargs) + return check_and_construct_configs(args, TransformerConfig) + + +def hf_to_mcore_config_qwen2moe( + hf_config: PretrainedConfig, dtype: torch.dtype, **override_transformer_config_kwargs +) -> TransformerConfig: + args: dict = _get_base_transformer_config( + hf_config=hf_config, + dtype=dtype, + use_cpu_initialization=False, + add_bias_linear=False, + layernorm_epsilon=hf_config.rms_norm_eps, + # MoE specific + moe_ffn_hidden_size=hf_config.moe_intermediate_size, + moe_router_bias_update_rate=0.001, + moe_router_topk=hf_config.num_experts_per_tok, + num_moe_experts=hf_config.num_experts, + moe_shared_expert_intermediate_size=hf_config.shared_expert_intermediate_size, + moe_aux_loss_coeff=hf_config.router_aux_loss_coef, + # moe_aux_loss_coeff=0.0, + moe_router_load_balancing_type="none", # turn off aux_loss as it hurts perf in RL + moe_shared_expert_overlap=True, + moe_grouped_gemm=True, + moe_router_score_function="softmax", + # Other optimizations + persist_layer_norm=True, + bias_activation_fusion=True, + bias_dropout_fusion=True, + # Qwen specific + moe_router_pre_softmax=True, + add_qkv_bias=True, + ) + # override_transformer_config_kwargs as kwargs shall never be none + args.update(override_transformer_config_kwargs) + return check_and_construct_configs(args, TransformerConfig) + + +def hf_to_mcore_config_mixtral( + hf_config: PretrainedConfig, dtype: torch.dtype, **override_transformer_config_kwargs +) -> TransformerConfig: + args: dict = _get_base_transformer_config( + hf_config=hf_config, + dtype=dtype, + use_cpu_initialization=False, + add_bias_linear=False, + layernorm_epsilon=hf_config.rms_norm_eps, + # MoE specific + num_moe_experts=hf_config.num_local_experts, + moe_aux_loss_coeff=hf_config.router_aux_loss_coef, + moe_router_topk=hf_config.num_experts_per_tok, + moe_router_pre_softmax=True, + moe_router_load_balancing_type="none", # turn off aux_loss as it hurts perf in RL + moe_router_score_function="softmax", + moe_shared_expert_intermediate_size=None, # mixtral has no shared expert + moe_shared_expert_overlap=False, # mixtral has no shared expert + moe_ffn_hidden_size=hf_config.intermediate_size, + moe_router_bias_update_rate=0.001, + # moe_permute_fusion=True, # need TE 2.1+ + moe_grouped_gemm=True, + # Other optimizations + persist_layer_norm=True, + apply_rope_fusion=True, + bias_activation_fusion=True, + bias_dropout_fusion=True, + ) + # override_transformer_config_kwargs as kwargs shall never be none + args.update(override_transformer_config_kwargs) + return check_and_construct_configs(args, TransformerConfig) + + +def hf_to_mcore_config_qwen3moe( + hf_config: PretrainedConfig, dtype: torch.dtype, **override_transformer_config_kwargs +) -> TransformerConfig: + args: dict = _get_base_transformer_config( + hf_config=hf_config, + dtype=dtype, + use_cpu_initialization=False, + add_bias_linear=False, + layernorm_epsilon=hf_config.rms_norm_eps, + # MoE specific + moe_ffn_hidden_size=hf_config.moe_intermediate_size, + moe_router_bias_update_rate=0.001, + moe_router_topk=hf_config.num_experts_per_tok, + num_moe_experts=hf_config.num_experts, + moe_aux_loss_coeff=hf_config.router_aux_loss_coef, + # moe_aux_loss_coeff=0.0, + moe_router_load_balancing_type="none", # turn off aux_loss as it hurts perf in RL + moe_grouped_gemm=True, + moe_router_score_function="softmax", + # Other optimizations + persist_layer_norm=True, + bias_activation_fusion=True, + bias_dropout_fusion=True, + # Qwen specific + moe_router_pre_softmax=False, + qk_layernorm=True, + ) + # override_transformer_config_kwargs as kwargs shall never be none + args.update(override_transformer_config_kwargs) + return check_and_construct_configs(args, TransformerConfig) + + +def hf_to_mcore_config_dpskv3( + hf_config: PretrainedConfig, dtype: torch.dtype, **override_transformer_config_kwargs +) -> MLATransformerConfig: + # DeepseekV3ForCausalLM + from megatron.core.config import set_experimental_flag + from megatron.core.transformer.enums import AttnBackend + + set_experimental_flag(True) + + from .patch import apply_patch + + apply_patch() + + mla_rope_config = { + "beta_fast": 32, + "beta_slow": 1, + "factor": 1, + "mscale": 1.0, + "mscale_all_dim": 1.0, + "original_max_position_embeddings": 4096, + "type": "rope", + } + if "rope_scaling" in hf_config and hf_config.rope_scaling is not None: + mla_rope_config.update(hf_config.rope_scaling) + moe_layer_freq = [1] * hf_config.num_hidden_layers + for i in range(min(hf_config.first_k_dense_replace, hf_config.num_hidden_layers)): + moe_layer_freq[i] = 0 + + # disable MTP and quantization for now + if "num_nextn_predict_layers" in hf_config: + assert hf_config.num_nextn_predict_layers == 0, ( + "MTP is not supported for now, please modify the config.json to set num_nextn_predict_layers to 0" + ) + assert "quantization_config" not in hf_config or not hf_config.quantization_config, ( + "quantization is not supported for now, please modify the config.json to remove quantization_config" + ) + + args: dict = _get_mla_transformer_config( + hf_config=hf_config, + mla_rope_config=mla_rope_config, + dtype=dtype, + # Additional parameters + use_cpu_initialization=False, + add_bias_linear=False, + attention_backend=AttnBackend.fused, + qk_layernorm=True, + # Standard MoE parameters + moe_ffn_hidden_size=hf_config.moe_intermediate_size, + moe_token_dispatcher_type="alltoall", + moe_router_bias_update_rate=0.001, + moe_router_enable_expert_bias=True, + moe_router_topk=hf_config.num_experts_per_tok, + num_moe_experts=hf_config.n_routed_experts, + moe_shared_expert_intermediate_size=hf_config.moe_intermediate_size * hf_config.n_shared_experts, + moe_aux_loss_coeff=getattr(hf_config, "aux_loss_alpha", 0.001), + moe_router_load_balancing_type="seq_aux_loss", + moe_shared_expert_overlap=True, + # moe_permute_fusion=True, # need TE 2.1+ + moe_grouped_gemm=True, + moe_router_score_function="sigmoid", + moe_router_pre_softmax=True, + moe_router_topk_scaling_factor=hf_config.routed_scaling_factor, + moe_layer_freq=moe_layer_freq, + # mcore 0.12 moe + moe_router_dtype="fp64", + disable_bf16_reduced_precision_matmul=True, + # Other optimizations + # deallocate_pipeline_outputs=True, + # gradient_accumulation_fusion=True, + persist_layer_norm=True, + bias_activation_fusion=True, + bias_dropout_fusion=True, + ) + # override_transformer_config_kwargs as kwargs shall never be none + args.update(override_transformer_config_kwargs) + transformer_config = check_and_construct_configs(args, MLATransformerConfig) + # MTP + if "num_nextn_predict_layers" in hf_config: + transformer_config.mtp_num_layers = hf_config.num_nextn_predict_layers + transformer_config.mtp_loss_scaling_factor = 0.1 + + return transformer_config + + +def hf_to_mcore_config_qwen2_5_vl( + hf_config: PretrainedConfig, dtype: torch.dtype, **override_transformer_config_kwargs +) -> TransformerConfig: + # Qwen2_5_VLForConditionalGeneration + + args = _get_base_transformer_config( + hf_config=hf_config, + dtype=dtype, + add_bias_linear=False, + # qwen specific + add_qkv_bias=True, + mrope_section=hf_config.rope_scaling["mrope_section"], + ) + # override_transformer_config_kwargs as kwargs shall never be none + args.update(override_transformer_config_kwargs) + args = mapping_string_to_attn_backend(args) + return TransformerConfig(**args) + + +def hf_to_mcore_config_llama4( + hf_config: PretrainedConfig, dtype: torch.dtype, **override_transformer_config_kwargs +) -> TransformerConfig: + # Llama4ForConditionalGeneration + raise NotImplementedError("Llama4ForConditionalGeneration is not supported yet") + + +def mapping_string_to_attn_backend(args: dict) -> dict: + if "attention_backend" in args and isinstance(args["attention_backend"], str): + from megatron.core.transformer.enums import AttnBackend + + args["attention_backend"] = AttnBackend[args["attention_backend"]] + return args diff --git a/code/RL_model/verl/verl_train/verl/models/mcore/loader.py b/code/RL_model/verl/verl_train/verl/models/mcore/loader.py new file mode 100644 index 0000000000000000000000000000000000000000..577ffc5ecf4f138ab4183d9ee4bef445d6f8142c --- /dev/null +++ b/code/RL_model/verl/verl_train/verl/models/mcore/loader.py @@ -0,0 +1,495 @@ +# Copyright 2025 Bytedance Ltd. and/or its affiliates +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import time + +import torch +import torch.distributed as dist + +from verl.utils.device import get_device_id, get_torch_device + +from .saver import _megatron_calc_global_rank + + +def _megatron_calc_layer_map(config): + """Calculate the mapping of global layer_idx to local layer_idx + Returns: + layer_map (Dict: int -> tuple(int, int, int)): + mapping from the global layer index to + a tuple of (pp_rank, virtual_pp_rank, layer_idx inside model) + """ + from megatron.core import mpu + + pp_size = mpu.get_pipeline_model_parallel_world_size() + virtual_pp_size = mpu.get_virtual_pipeline_model_parallel_world_size() or 1 + + layer_map = dict() + num_layers_per_model = config.num_hidden_layers // pp_size // virtual_pp_size + assert num_layers_per_model * pp_size * virtual_pp_size == config.num_hidden_layers + + for pp_rank_idx in range(pp_size): + for virtual_pp_rank_idx in range(virtual_pp_size): + layer_offset = ( + virtual_pp_rank_idx * (config.num_hidden_layers // virtual_pp_size) + pp_rank_idx * num_layers_per_model + ) + for layer_idx in range(num_layers_per_model): + layer_map[layer_offset + layer_idx] = ( + pp_rank_idx, + virtual_pp_rank_idx, + layer_idx, + ) + return layer_map + + +def load_state_dict_to_megatron_gptmodel(state_dict, wrapped_models, config, params_dtype, is_value_model=False): + """Load merged state_dict to sharded Megatron module in training.""" + from megatron.core import DistributedDataParallel as LocalDDP + from megatron.core import mpu + from megatron.core.transformer.module import Float16Module + from torch.nn.parallel import DistributedDataParallel as torchDDP + + from verl.utils.logger import print_rank_0 + from verl.utils.megatron_utils import unwrap_model + + start_time = time.time() + + def _get_gpt_model(model): + return model + + def broadcast_params(module): + for param in module.parameters(): + torch.distributed.broadcast( + param.data, src=mpu.get_data_parallel_src_rank(), group=mpu.get_data_parallel_group() + ) + + dp_rank = mpu.get_data_parallel_rank() + pp_rank = mpu.get_pipeline_model_parallel_rank() + cp_rank = mpu.get_context_parallel_rank() + src_rank = _megatron_calc_global_rank(tp_rank=0, dp_rank=0, pp_rank=0, cp_rank=cp_rank) + pp_size = mpu.get_pipeline_model_parallel_world_size() + virtual_pp_size = mpu.get_virtual_pipeline_model_parallel_world_size() or 1 + mp_group = mpu.get_model_parallel_group() + + if torch.distributed.get_rank() == src_rank: + assert mp_group.rank() == 0, f"mp_rank:[{mp_group.rank}] != 0 on rank #0" + assert pp_rank == 0, f"pp_rank:[{pp_rank}] != 0 on rank #0" + assert dp_rank == 0, f"dp_rank:[{dp_rank}] != 0 on rank #0" + + if not isinstance(wrapped_models, list | tuple): + wrapped_models = list(wrapped_models) + + assert len(wrapped_models) == virtual_pp_size + num_layers_per_model = config.num_hidden_layers // pp_size // virtual_pp_size + assert num_layers_per_model * pp_size * virtual_pp_size == config.num_hidden_layers + + models = [None] * len(wrapped_models) + + for i, wrapped_model in enumerate(wrapped_models): + models[i] = unwrap_model(wrapped_model, (torchDDP, LocalDDP, Float16Module)) + gpt_model_module = _get_gpt_model(models[i]) + assert len(gpt_model_module.decoder.layers) == num_layers_per_model + + def _broadcast_tensor(tensor, name) -> torch.Tensor: + """broadcast tensor from rank0 across mp_group""" + nonlocal state_dict + nonlocal mp_group + if torch.distributed.get_rank() == src_rank: + if name in state_dict: + weight = state_dict[name] + tensor_shape = weight.shape + else: + tensor_shape = None + else: + weight = None + tensor_shape = None + + obj_list = [tensor_shape] + dist.broadcast_object_list(obj_list, src=src_rank, group=mp_group) + tensor_shape = obj_list[0] + + if tensor_shape is None: + # all or none ranks in the mp_group should reach here + print_rank_0(f"tensor:[{name}] not in state_dict, skip load") + return + + if tensor is None: + tensor = torch.empty( + tensor_shape, + dtype=params_dtype, + device=get_device_id(), + requires_grad=False, + ) + if torch.distributed.get_rank() == src_rank: + tensor.data.copy_(weight) + dist.broadcast(tensor, src=src_rank, group=mp_group) + + def _broadcast_tp_shard_tensor_vocab(tensor, name, chunk_dim=0, mutate_func=None) -> torch.Tensor: + """broadcast tensor in tp shards across mp_group""" + nonlocal state_dict + nonlocal mp_group + tp_rank = mpu.get_tensor_model_parallel_rank() + tp_size = mpu.get_tensor_model_parallel_world_size() + + if torch.distributed.get_rank() == src_rank: + if name in state_dict: + full_weight = state_dict[name] + + if mutate_func is not None: + full_weight = mutate_func(full_weight) + tensor_chunk = torch.chunk(full_weight, tp_size, dim=chunk_dim) + chunk_shape = tensor_chunk[0].shape + else: + chunk_shape = None + else: + chunk_shape = None + + obj_list = [chunk_shape] + dist.broadcast_object_list(obj_list, src=src_rank, group=mp_group) + chunk_shape = obj_list[0] + if chunk_shape is None: + # all or none ranks in the mp_group should reach here + print_rank_0(f"tp_shard tensor:[{name}] not in state_dict, skip loading") + return + + if tensor is None: + sync_tensor = torch.empty( + chunk_shape, + dtype=params_dtype, + device=get_device_id(), + requires_grad=False, + ) + else: + assert tensor.shape == chunk_shape, ( + f"rank #{torch.distributed.get_rank()} tensor {name} shape {tensor.shape} != {chunk_shape}" + ) + sync_tensor = torch.empty_like(tensor, device=get_device_id(), requires_grad=False) + + for i in range(tp_size): + if torch.distributed.get_rank() == src_rank: + sync_tensor.data.copy_(tensor_chunk[i]) + dist.broadcast(sync_tensor, src=src_rank, group=mp_group) + if (i == tp_rank) and (tensor is not None): + tensor.data.copy_(sync_tensor) + + def _broadcast_tp_shard_tensor(tensor, name, chunk_dim=0, mutate_func=None) -> torch.Tensor: + """broadcast tensor in tp shards across mp_group""" + nonlocal state_dict + nonlocal mp_group + tp_rank = mpu.get_tensor_model_parallel_rank() + tp_size = mpu.get_tensor_model_parallel_world_size() + + if torch.distributed.get_rank() == src_rank: + if name in state_dict: + full_weight = state_dict[name] + if mutate_func is not None: + full_weight = mutate_func(full_weight) + tensor_chunk = torch.chunk(full_weight, tp_size, dim=chunk_dim) + chunk_shape = tensor_chunk[0].shape + else: + chunk_shape = None + else: + chunk_shape = None + + obj_list = [chunk_shape] + dist.broadcast_object_list(obj_list, src=src_rank, group=mp_group) + chunk_shape = obj_list[0] + if chunk_shape is None: + # all or none ranks in the mp_group should reach here + print_rank_0(f"tp_shard tensor:[{name}] not in state_dict, skip loading") + return + + if tensor is None: + sync_tensor = torch.empty( + chunk_shape, + dtype=params_dtype, + device=get_device_id(), + requires_grad=False, + ) + else: + assert tensor.shape == chunk_shape, ( + f"rank #{torch.distributed.get_rank()} tensor {name} shape {tensor.shape} != {chunk_shape}" + ) + sync_tensor = torch.empty_like(tensor, device=get_device_id(), requires_grad=False) + + for i in range(tp_size): + if torch.distributed.get_rank() == src_rank: + sync_tensor.data.copy_(tensor_chunk[i]) + dist.broadcast(sync_tensor, src=src_rank, group=mp_group) + if (i == tp_rank) and (tensor is not None): + tensor.data.copy_(sync_tensor) + + def _broadcast_tp_shard_tensor_gate_up(tensor, gate_name, up_name) -> torch.Tensor: + """broadcast tensor in tp shards across mp_group""" + nonlocal state_dict + nonlocal mp_group + tp_rank = mpu.get_tensor_model_parallel_rank() + tp_size = mpu.get_tensor_model_parallel_world_size() + + if torch.distributed.get_rank() == src_rank: + gate_weight = state_dict[gate_name] + up_weight = state_dict[up_name] + new_gate_up_weight = torch.empty( + config.intermediate_size * 2, config.hidden_size, dtype=params_dtype, device=get_device_id() + ) + for i in range(tp_size): + intermediate_size_tp = config.intermediate_size // tp_size + gate_weight_tp = gate_weight[i * intermediate_size_tp : (i + 1) * intermediate_size_tp] + up_weight_tp = up_weight[i * intermediate_size_tp : (i + 1) * intermediate_size_tp] + new_gate_up_weight[intermediate_size_tp * 2 * i : intermediate_size_tp * 2 * (i + 1)].copy_( + torch.cat([gate_weight_tp, up_weight_tp], dim=0) + ) + + tensor_chunk = torch.chunk(new_gate_up_weight, tp_size, dim=0) + chunk_shape = tensor_chunk[0].shape + else: + chunk_shape = None + + obj_list = [chunk_shape] + dist.broadcast_object_list(obj_list, src=src_rank, group=mp_group) + chunk_shape = obj_list[0] + if chunk_shape is None: + # all or none ranks in the mp_group should reach here + print_rank_0(f"tp_shard tensor:[{gate_name, up_name}] not in state_dict, skip loading") + return + + if tensor is None: + sync_tensor = torch.empty( + chunk_shape, + dtype=params_dtype, + device=get_device_id(), + requires_grad=False, + ) + else: + assert tensor.shape == chunk_shape, ( + f"rank #{torch.distributed.get_rank() == src_rank:} tensor {gate_name, up_name} shape " + f"{tensor.shape} != {chunk_shape}" + ) + sync_tensor = torch.empty_like(tensor, device=get_device_id(), requires_grad=False) + + for i in range(tp_size): + if torch.distributed.get_rank() == src_rank: + sync_tensor.data.copy_(tensor_chunk[i]) + dist.broadcast(sync_tensor, src=src_rank, group=mp_group) + if (i == tp_rank) and (tensor is not None): + tensor.data.copy_(sync_tensor) + + def _broadcast_tp_shard_tensor_qkv(tensor, q_name, k_name, v_name, bias=False) -> torch.Tensor: + """broadcast tensor in tp shards across mp_group""" + nonlocal state_dict + nonlocal mp_group + tp_rank = mpu.get_tensor_model_parallel_rank() + tp_size = mpu.get_tensor_model_parallel_world_size() + + if torch.distributed.get_rank() == src_rank: + assert q_name in state_dict and k_name in state_dict and v_name in state_dict + full_weight_q = state_dict[q_name] + full_weight_k = state_dict[k_name] + full_weight_v = state_dict[v_name] + + hidden_size_per_head = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + + if config.num_key_value_heads >= tp_size: + q_size_tp = hidden_size_per_head * config.num_attention_heads // tp_size + kv_size_tp = hidden_size_per_head * config.num_key_value_heads // tp_size + total_size = q_size_tp + 2 * kv_size_tp + sizes = [total_size * tp_size] + if not bias: + sizes.append(config.hidden_size) + new_weight_qkv = torch.empty(*sizes, dtype=params_dtype, device=get_device_id()) + for i in range(tp_size): + q_part = full_weight_q[i * q_size_tp : (i + 1) * q_size_tp] + k_part = full_weight_k[i * kv_size_tp : (i + 1) * kv_size_tp] + v_part = full_weight_v[i * kv_size_tp : (i + 1) * kv_size_tp] + num_query_groups_per_partition = models[0].config.num_query_groups // tp_size + new_weight_qkv_this_tp = new_weight_qkv[i * total_size : (i + 1) * total_size] + q_part_per_head = torch.chunk(q_part, num_query_groups_per_partition, dim=0) + k_part_per_head = torch.chunk(k_part, num_query_groups_per_partition, dim=0) + v_part_per_head = torch.chunk(v_part, num_query_groups_per_partition, dim=0) + total_size_per_head = total_size // num_query_groups_per_partition + for j in range(num_query_groups_per_partition): + new_weight_qkv_this_tp[j * total_size_per_head : (j + 1) * total_size_per_head].copy_( + torch.cat([q_part_per_head[j], k_part_per_head[j], v_part_per_head[j]], dim=0) + ) + + else: + q_size_tp = hidden_size_per_head * config.num_attention_heads // tp_size + kv_size_tp = hidden_size_per_head + total_size = q_size_tp + 2 * kv_size_tp + sizes = [total_size * tp_size] + if not bias: + sizes.append(config.hidden_size) + new_weight_qkv = torch.empty(*sizes, dtype=params_dtype, device=get_device_id()) + for i in range(tp_size): + q_part = full_weight_q[i * q_size_tp : (i + 1) * q_size_tp] + start_idx = i * config.num_key_value_heads // tp_size * hidden_size_per_head + end_idx = (i * config.num_key_value_heads // tp_size + 1) * hidden_size_per_head + k_part = full_weight_k[start_idx:end_idx] + v_part = full_weight_v[start_idx:end_idx] + new_weight_qkv_this_tp = new_weight_qkv[i * total_size : (i + 1) * total_size] + q_part_per_head = torch.chunk(q_part, config.num_attention_heads, dim=0) + k_part_per_head = torch.chunk(k_part, config.num_attention_heads, dim=0) + v_part_per_head = torch.chunk(v_part, config.num_attention_heads, dim=0) + total_size_per_head = total_size // config.num_attention_heads + for j in range(config.num_attention_heads): + new_weight_qkv_this_tp[j * total_size_per_head : (j + 1) * total_size_per_head].copy_( + torch.cat([q_part_per_head[j], k_part_per_head[j], v_part_per_head[j]], dim=0) + ) + + tensor_chunk = torch.chunk(new_weight_qkv, tp_size, dim=0) + chunk_shape = tensor_chunk[0].shape + else: + chunk_shape = None + + obj_list = [chunk_shape] + dist.broadcast_object_list(obj_list, src=src_rank, group=mp_group) + chunk_shape = obj_list[0] + if chunk_shape is None: + # all or none ranks in the mp_group should reach here + print_rank_0(f"tp_shard tensor:[{q_name, k_name, v_name}] not in state_dict, skip loading") + return + + if tensor is None: + sync_tensor = torch.empty( + chunk_shape, + dtype=params_dtype, + device=get_device_id(), + requires_grad=False, + ) + else: + assert tensor.shape == chunk_shape, ( + f"rank #{torch.distributed.get_rank()} tensor {q_name} shape {tensor.shape} != {chunk_shape}" + ) + sync_tensor = torch.empty_like(tensor, device=get_device_id(), requires_grad=False) + + for i in range(tp_size): + if torch.distributed.get_rank() == src_rank: + sync_tensor.data.copy_(tensor_chunk[i]) + dist.broadcast(sync_tensor, src=src_rank, group=mp_group) + if (i == tp_rank) and (tensor is not None): + tensor.data.copy_(sync_tensor) + + if dp_rank == 0: + # Embeddings + # ------------------- + print_rank_0("loading embeddings...") + gpt_model_module = _get_gpt_model(models[0]) + embed_tokens_weight = None + if pp_rank == 0: + embed_tokens_weight = gpt_model_module.embedding.word_embeddings.weight + _broadcast_tp_shard_tensor_vocab(embed_tokens_weight, "model.embed_tokens.weight") + + # Transformer layers + # ------------------- + layer_map = _megatron_calc_layer_map(config) + + for layer in range(config.num_hidden_layers): + layer_name = f"model.layers.{layer}" + print_rank_0(f"loading layer #{layer}, with layer_name model.layers.{layer}...") + dst_pp_rank, dst_virtual_pp_rank, dst_layer_idx = layer_map[layer] + + gpt_model_module = _get_gpt_model(models[dst_virtual_pp_rank]) + sync_layer = gpt_model_module.decoder.layers[dst_layer_idx] + + _broadcast_tensor( + sync_layer.self_attention.linear_qkv.layer_norm_weight if dst_pp_rank == pp_rank else None, + f"{layer_name}.input_layernorm.weight", + ) + + if f"{layer_name}.self_attn.q_norm.weight" in state_dict: + _broadcast_tensor( + sync_layer.self_attention.q_layernorm.weight if dst_pp_rank == pp_rank else None, + f"{layer_name}.self_attn.q_norm.weight", + ) + _broadcast_tensor( + sync_layer.self_attention.k_layernorm.weight if dst_pp_rank == pp_rank else None, + f"{layer_name}.self_attn.k_norm.weight", + ) + + _broadcast_tp_shard_tensor_qkv( + sync_layer.self_attention.linear_qkv.weight if dst_pp_rank == pp_rank else None, + f"{layer_name}.self_attn.q_proj.weight", + f"{layer_name}.self_attn.k_proj.weight", + f"{layer_name}.self_attn.v_proj.weight", + ) + if f"{layer_name}.self_attn.q_proj.bias" in state_dict: + _broadcast_tp_shard_tensor_qkv( + sync_layer.self_attention.linear_qkv.bias if dst_pp_rank == pp_rank else None, + f"{layer_name}.self_attn.q_proj.bias", + f"{layer_name}.self_attn.k_proj.bias", + f"{layer_name}.self_attn.v_proj.bias", + bias=True, + ) + + _broadcast_tp_shard_tensor( + sync_layer.self_attention.linear_proj.weight if dst_pp_rank == pp_rank else None, + f"{layer_name}.self_attn.o_proj.weight", + chunk_dim=1, + ) + _broadcast_tensor( + sync_layer.mlp.linear_fc1.layer_norm_weight if dst_pp_rank == pp_rank else None, + f"{layer_name}.post_attention_layernorm.weight", + ) + + _broadcast_tp_shard_tensor_gate_up( + sync_layer.mlp.linear_fc1.weight if dst_pp_rank == pp_rank else None, + f"{layer_name}.mlp.gate_proj.weight", + f"{layer_name}.mlp.up_proj.weight", + ) + + _broadcast_tp_shard_tensor( + sync_layer.mlp.linear_fc2.weight if dst_pp_rank == pp_rank else None, + f"{layer_name}.mlp.down_proj.weight", + chunk_dim=1, + ) + # Final Layernorm + # ------------------- + print_rank_0("loading final layernorm...") + gpt_model_module = _get_gpt_model(models[-1]) + _broadcast_tensor( + getattr(gpt_model_module.decoder.final_layernorm, "weight", None), + "model.norm.weight", + ) + + print_rank_0("loading lm_head...") + lm_head_weight = None + if pp_rank + 1 == pp_size: + lm_head_weight = gpt_model_module.output_layer.weight + + if is_value_model: + # if torch.distributed.get_rank() == src_rank: + if "lm_head.weight" in state_dict and state_dict["lm_head.weight"].shape[0] == 1: + _broadcast_tensor(lm_head_weight, "lm_head.weight") + elif "reward_head.weight" in state_dict and state_dict["reward_head.weight"].shape[0] == 1: + _broadcast_tensor(lm_head_weight, "reward_head.weight") + print_rank_0("load lm_head from value_head weight") + elif "score.weight" in state_dict and state_dict["score.weight"].shape[0] == 1: + _broadcast_tensor(lm_head_weight, "score.weight") + print_rank_0("load lm_head from score weight") + else: + _broadcast_tensor(None, "lm_head.weight") + print_rank_0("fail to match lm_head in value_model") + # else: + + # _broadcast_tensor(lm_head_weight, "lm_head.weight") + + else: + _broadcast_tp_shard_tensor(lm_head_weight, "lm_head.weight") + dist.barrier() + # Broadcast weights inside data parallel groups + for wrapped_model in wrapped_models: + broadcast_params(wrapped_model) + pass + get_torch_device().empty_cache() + print_rank_0(f"loading megatron ckpt done, time elapsed {time.time() - start_time}s") diff --git a/code/RL_model/verl/verl_train/verl/models/mcore/mbridge.py b/code/RL_model/verl/verl_train/verl/models/mcore/mbridge.py new file mode 100644 index 0000000000000000000000000000000000000000..9c6d5036e3f300720f98cc8ddee3df4f06335bb1 --- /dev/null +++ b/code/RL_model/verl/verl_train/verl/models/mcore/mbridge.py @@ -0,0 +1,27 @@ +# Copyright 2025 Bytedance Ltd. and/or its affiliates +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# VANILLA_MBRIDGE +try: + from verl.models.mcore.patch import apply_patch_mbridge + + apply_patch_mbridge() + from mbridge import AutoBridge + from mbridge.utils.post_creation_callbacks import freeze_moe_router, make_value_model +except ImportError: + print("mbridge package not found. Please install mbridge with `pip install verl[mcore]` or `pip install mbridge`") + raise + +__all__ = ["AutoBridge", "make_value_model", "freeze_moe_router"] diff --git a/code/RL_model/verl/verl_train/verl/models/mcore/model_forward.py b/code/RL_model/verl/verl_train/verl/models/mcore/model_forward.py new file mode 100644 index 0000000000000000000000000000000000000000..10d3a1bf35e973faa66fb0408c6fc7b780205f7e --- /dev/null +++ b/code/RL_model/verl/verl_train/verl/models/mcore/model_forward.py @@ -0,0 +1,282 @@ +# Copyright 2025 Bytedance Ltd. and/or its affiliates +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch + +from verl.utils.megatron_utils import unwrap_model +from verl.workers.config import MtpConfig + +from .util import ( + postprocess_bshd, + postprocess_bshd_no_padding, + postprocess_packed_seqs, + postprocess_thd_no_padding, + preprocess_bshd, + preprocess_bshd_no_padding, + preprocess_packed_seqs, + preprocess_thd_no_padding, +) + + +def model_forward_gen(vision_model: bool = False): + def model_forward( + model, + input_ids, + attention_mask, + position_ids, + multi_modal_inputs: dict, + logits_processor=None, + logits_processor_args: dict = None, + value_model=False, + data_format: str = "thd", + mtp_config: MtpConfig = None, + ): + """Forward pass for models with sequence packing.""" + assert data_format in ["thd", "bshd"], "data_format must be 'thd' or 'bshd'" + pre_process = ( + unwrap_model(model).pre_process if not vision_model else False + ) # vision model does not need pre_process, because we pack the input_ids to thd in the forward function + post_process = unwrap_model(model).post_process + sp = unwrap_model(model).config.sequence_parallel + fp8 = unwrap_model(model).config.fp8 + use_fp8_padding = fp8 in ["e4m3", "hybrid"] + + model_kwargs = {} + if "pixel_values" in multi_modal_inputs: + model_kwargs["pixel_values"] = multi_modal_inputs["pixel_values"].to(input_ids.device) + if "image_grid_thw" in multi_modal_inputs: + model_kwargs["image_grid_thw"] = multi_modal_inputs["image_grid_thw"].to(input_ids.device) + if "pixel_values_videos" in multi_modal_inputs: + model_kwargs["pixel_values_videos"] = multi_modal_inputs["pixel_values_videos"].to(input_ids.device) + if "video_grid_thw" in multi_modal_inputs: + model_kwargs["video_grid_thw"] = multi_modal_inputs["video_grid_thw"].to(input_ids.device) + + batch_size, seq_len = attention_mask.shape[:2] + if data_format == "thd": + input_ids_rmpad, packed_seq_params = preprocess_packed_seqs( + input_ids, attention_mask, pre_process=pre_process or post_process, use_fp8_padding=use_fp8_padding + ) + input_ids_rmpad = input_ids_rmpad.contiguous() + + # when pp > 1 and processor is not None, we need to pass the labels and loss_mask to the model + if mtp_config and mtp_config.enable_train and post_process: + args = { + k: preprocess_packed_seqs(v, attention_mask, pre_process=True, use_fp8_padding=use_fp8_padding)[0] + for k, v in logits_processor_args.items() + } + model_kwargs["labels"] = args["label"].contiguous() + model_kwargs["loss_mask"] = args["label_mask"].contiguous() + + input_args = dict( + input_ids=input_ids_rmpad, + attention_mask=None, + position_ids=position_ids if not vision_model else None, # vision models will calculate position_ids + packed_seq_params=packed_seq_params, + **model_kwargs, + ) + + if vision_model: + # workaround for supporting sequence packing with context parallelism + # cp split with sequence packing will make model lose vision token information, so we need to keep + # the original input_ids and pack them after vision embedding is calculated, + # cooporate with mbridge + input_args["input_ids"] = input_ids + input_args["attention_mask"] = attention_mask + + output_orig = model(**input_args) + + if post_process and logits_processor is not None: + args = { + k: preprocess_packed_seqs(v, attention_mask, pre_process=True, use_fp8_padding=use_fp8_padding)[0] + for k, v in logits_processor_args.items() + } + output_dict = logits_processor(output_orig, **args) + output = { + k: postprocess_packed_seqs( + v, packed_seq_params, attention_mask, batch_size, seq_len, post_process=post_process + ) + for k, v in output_dict.items() + } + else: + output = postprocess_packed_seqs( + output_orig, packed_seq_params, attention_mask, batch_size, seq_len, post_process=post_process + ) + elif data_format == "bshd": + """ + data_format: "thd" or "bshd", default is "thd", + why we need this? + for some new models, GPT-OSS, the thd format is not supported, so we need to use the bshd format. + When using the bshd format, we have to add paddings to the input_ids to meet the longest sequence length, + so it is recommended to disable dynamic batch size and set batch size to 1 + """ + assert not vision_model, "vision model does not support bshd format" + assert fp8 is None, "fp8 is not supported for bshd format yet" + + batch_size, sequence_length = attention_mask.shape[:2] + new_input_ids, new_attention_mask, new_position_ids = preprocess_bshd( + input_ids, attention_mask, position_ids, sequence_parallel=sp, pre_process=pre_process + ) + output_orig = model( + input_ids=new_input_ids, + position_ids=new_position_ids, + attention_mask=new_attention_mask, + **model_kwargs, + ) + if post_process and logits_processor is not None: + args = { + k: preprocess_bshd(v, attention_mask, position_ids, sequence_parallel=sp, pre_process=True)[0] + for k, v in logits_processor_args.items() + } + output_dict = logits_processor(output_orig, **args) + output = { + k: postprocess_bshd( + v, new_attention_mask, attention_mask, sequence_length, post_process=post_process + ) + for k, v in output_dict.items() + } + else: + output = postprocess_bshd( + output_orig, new_attention_mask, attention_mask, sequence_length, post_process=post_process + ) + if value_model and post_process: + output = output[..., 0] + return output + + return model_forward + + +def gptmodel_forward_no_padding( + model, + input_ids, + multi_modal_inputs: dict, + logits_processor=None, + logits_processor_args: dict = None, + value_model=False, + vision_model=False, + pad_token_id=None, + data_format: str = "thd", + enable_mtp: bool = False, +): + """Default forward pass for GPT models with optional sequence packing.""" + + assert data_format in ["thd", "bshd"], "data_format must be 'thd' or 'bshd'" + pre_process = unwrap_model(model).pre_process + post_process = unwrap_model(model).post_process + + model_kwargs = {} + if "pixel_values" in multi_modal_inputs: + model_kwargs["pixel_values"] = multi_modal_inputs["pixel_values"].to(input_ids.device) + if "image_grid_thw" in multi_modal_inputs: + model_kwargs["image_grid_thw"] = multi_modal_inputs["image_grid_thw"].to(input_ids.device) + if "pixel_values_videos" in multi_modal_inputs: + model_kwargs["pixel_values_videos"] = multi_modal_inputs["pixel_values_videos"].to(input_ids.device) + if "video_grid_thw" in multi_modal_inputs: + model_kwargs["video_grid_thw"] = multi_modal_inputs["video_grid_thw"].to(input_ids.device) + + batch_size = input_ids.shape[0] + if data_format == "thd": + input_ids_rmpad, packed_seq_params = preprocess_thd_no_padding(input_ids, pre_process=pre_process) + input_ids_rmpad = input_ids_rmpad.contiguous() + + if enable_mtp and post_process: + args = { + k: preprocess_thd_no_padding(v, pre_process=True, need_roll=(k == "label" or k == "loss_mask"))[0] + for k, v in logits_processor_args.items() + } + model_kwargs["labels"] = args["label"].contiguous() + model_kwargs["loss_mask"] = args["loss_mask"].contiguous() + logits_processor_args.pop("loss_mask") + + # For VLM model, need to pass bshd format `input_ids` and `attention_mask`. + attention_mask = None + if vision_model: + input_ids_rmpad = input_ids.to_padded_tensor(pad_token_id) + seqlens_in_batch = input_ids.offsets().diff() + attention_mask = torch.zeros_like(input_ids_rmpad, dtype=torch.bool) + for i, seqlen in enumerate(seqlens_in_batch): + attention_mask[i, :seqlen] = True + + output_orig = model( + input_ids=input_ids_rmpad, + attention_mask=attention_mask, + position_ids=None, + packed_seq_params=packed_seq_params, + **model_kwargs, + ) + + if post_process and logits_processor is not None: + args = { + k: preprocess_thd_no_padding(v, pre_process=True, need_roll=(k == "label"))[0] + for k, v in logits_processor_args.items() + } + output_dict = logits_processor(output_orig, **args) + output = { + k: postprocess_thd_no_padding(v, packed_seq_params, input_ids, batch_size, post_process=post_process) + for k, v in output_dict.items() + } + else: + output = postprocess_thd_no_padding( + output_orig, packed_seq_params, input_ids, batch_size, post_process=post_process + ) + else: + """ + data_format: "thd" or "bshd", default is "thd", + why we need this? + for some new models, GPT-OSS, the thd format is not supported, so we need to use the bshd format. + When using the bshd format, we have to add paddings to the input_ids to meet the longest sequence length, + so it is recommended to disable dynamic batch size and set batch size to 1 + """ + + input_ids_bshd, attention_mask_bshd, position_ids_bshd = preprocess_bshd_no_padding( + input_ids, pre_process=pre_process + ) + + if enable_mtp and post_process: + args = { + k: preprocess_bshd_no_padding(v, pre_process=True, need_roll=(k == "label" or k == "loss_mask"))[0] + for k, v in logits_processor_args.items() + } + model_kwargs["labels"] = args["label"].contiguous() + model_kwargs["loss_mask"] = args["loss_mask"].contiguous() + logits_processor_args.pop("loss_mask") + + output_orig = model( + input_ids=input_ids_bshd, + attention_mask=attention_mask_bshd, + position_ids=position_ids_bshd, + **model_kwargs, + ) + if post_process and logits_processor is not None: + args = { + k: preprocess_bshd_no_padding(v, pre_process=True, need_roll=(k == "label"))[0] + for k, v in logits_processor_args.items() + } + output_dict = logits_processor(output_orig, **args) + output = { + k: postprocess_bshd_no_padding(v, attention_mask_bshd, post_process=post_process) + for k, v in output_dict.items() + } + else: + output = postprocess_bshd_no_padding(output_orig, attention_mask_bshd, post_process=post_process) + + if value_model and post_process: + # output = output[..., 0] + # while using nested tensor, the advanced indexing operation above will result in an error at backward, i.e. + # ValueError: NestedTensor _nested_select_backward_default(grad_output: t, self: jt_all, dim: any, index: any) + # so we use `squeeze` to remove the last dimension + output = output.squeeze(-1) + + return output diff --git a/code/RL_model/verl/verl_train/verl/models/mcore/model_forward_1f1b_overlap.py b/code/RL_model/verl/verl_train/verl/models/mcore/model_forward_1f1b_overlap.py new file mode 100644 index 0000000000000000000000000000000000000000..b8786e01f884e78fda4b37dc902a136ad0c1b5dd --- /dev/null +++ b/code/RL_model/verl/verl_train/verl/models/mcore/model_forward_1f1b_overlap.py @@ -0,0 +1,252 @@ +# Copyright 2025 Bytedance Ltd. and/or its affiliates +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Callable, Optional + +import torch +from megatron.core.models.common.model_chunk_schedule_plan import TransformerModelChunkSchedulePlan +from megatron.core.models.gpt.gpt_model import GPTModel +from megatron.core.utils import make_viewless_tensor +from torch import Tensor + +from verl.models.mcore.util import preprocess_packed_seqs +from verl.utils.kernel.linear_cross_entropy import linear_cross_entropy +from verl.utils.megatron_utils import unwrap_model +from verl.utils.model import CausalLMOutputForPPO + +from .util import postprocess_packed_seqs, postprocess_packed_seqs_for_dict_output + + +def gptmodel_forward_1f1b_overlap( + model: GPTModel, + input_ids: Tensor, + position_ids: Tensor, + attention_mask: Tensor, + labels: Tensor = None, + labels_mask: Tensor = None, + multi_modal_inputs: Optional[dict] = None, + logits_processor: Optional[Callable] = None, + logits_processor_args: Optional[dict] = None, + temperature: float = 1.0, +) -> TransformerModelChunkSchedulePlan: + pre_process: bool = unwrap_model(model).pre_process + post_process: bool = unwrap_model(model).post_process + assert logits_processor is None, "only support fused kernel" + batch_size, seq_len = attention_mask.shape[:2] + input_ids_rmpad, packed_seq_params = preprocess_packed_seqs(input_ids, attention_mask, pre_process=pre_process) + input_ids_rmpad = input_ids_rmpad.contiguous() + + schedule_plan = model.build_schedule_plan( + input_ids=input_ids_rmpad, + attention_mask=attention_mask, + labels=labels, + position_ids=position_ids, + packed_seq_params=packed_seq_params, + ) + if post_process: + attention_mask_out = attention_mask + + def _postprocess( + self, + hidden_states, + input_ids, + position_ids, + labels, + rotary_pos_emb, + rotary_pos_cos, + rotary_pos_sin, + mtp_in_postprocess=None, + loss_mask=None, + decoder_input=None, + attention_mask=None, + inference_params=None, + packed_seq_params=None, + sequence_len_offset=None, + runtime_gather_output=None, + extra_block_kwargs=None, + inference_context=None, + ): + """patched from https://github.com/NVIDIA/Megatron-LM/blob/core_r0.14.0/megatron/core/models/gpt/gpt_model.py#L412""" + """Postprocesses decoder hidden states to generate logits or compute loss. + + Applies Multi-Token Prediction if enabled, generates output logits through + the output layer, and computes language model loss when labels are provided. + """ + from megatron.core import parallel_state + from megatron.core.tensor_parallel import gather_from_sequence_parallel_region + + in_inference_mode = inference_context is not None and not self.training + if in_inference_mode: + assert runtime_gather_output, "Inference must always gather TP logits" + + # logits and loss + output_weight = None + if self.share_embeddings_and_output_weights: + output_weight = self.shared_embedding_or_output_weight() + + if mtp_in_postprocess: + hidden_states = self.mtp( + input_ids=input_ids, + position_ids=position_ids, + hidden_states=hidden_states, + attention_mask=attention_mask, + inference_params=inference_params, + rotary_pos_emb=rotary_pos_emb, + rotary_pos_cos=rotary_pos_cos, + rotary_pos_sin=rotary_pos_sin, + packed_seq_params=packed_seq_params, + sequence_len_offset=sequence_len_offset, + embedding=self.embedding, + **(extra_block_kwargs or {}), + ) + + if not self.post_process: + return hidden_states + + if self.mtp_process: + from megatron.core.transformer.multi_token_prediction import ( + MTPLossAutoScaler, + MTPLossLoggingHelper, + roll_tensor, + ) + + mtp_labels = labels.clone() + hidden_states_list = torch.chunk(hidden_states, 1 + self.config.mtp_num_layers, dim=0) + hidden_states = hidden_states_list[0] + if loss_mask is None: + # if loss_mask is not provided, use all ones as loss_mask + loss_mask = torch.ones_like(mtp_labels) + for mtp_layer_number in range(self.config.mtp_num_layers): + # output + mtp_logits, _ = self.output_layer( + hidden_states_list[mtp_layer_number + 1], + weight=output_weight, + runtime_gather_output=runtime_gather_output, + ) + # Calc loss for the current Multi-Token Prediction (MTP) layers. + mtp_labels, _ = roll_tensor(mtp_labels, shifts=-1, dims=-1, cp_group=self.cp_group) + loss_mask, num_tokens = roll_tensor(loss_mask, shifts=-1, dims=-1, cp_group=self.cp_group) + mtp_loss = self.compute_language_model_loss(mtp_labels, mtp_logits) + mtp_loss = loss_mask * mtp_loss + if self.training: + # TODO(shifangx): remove the use of parallel_state here + # after moving loss logging to loss_func in pretrain_gpt.py + MTPLossLoggingHelper.save_loss_to_tracker( + torch.sum(mtp_loss) / num_tokens, + mtp_layer_number, + self.config.mtp_num_layers, + avg_group=parallel_state.get_data_parallel_group(with_context_parallel=True), + ) + mtp_loss_scale = self.config.mtp_loss_scaling_factor / self.config.mtp_num_layers + if self.config.calculate_per_token_loss: + hidden_states = MTPLossAutoScaler.apply(hidden_states, mtp_loss_scale * mtp_loss) + else: + hidden_states = MTPLossAutoScaler.apply(hidden_states, mtp_loss_scale * mtp_loss / num_tokens) + + if logits_processor is not None: + logits, _ = self.output_layer( + hidden_states, weight=output_weight, runtime_gather_output=runtime_gather_output + ) + output_orig = logits.transpose(0, 1).contiguous() + args = { + k: preprocess_packed_seqs(v, attention_mask_out, pre_process=True)[0] + for k, v in logits_processor_args.items() + } + output_dict = logits_processor(output_orig, **args) + output = { + k: postprocess_packed_seqs( + v, packed_seq_params, attention_mask_out, batch_size, seq_len, post_process=post_process + ) + for k, v in output_dict.items() + } + else: + # fused kernel + + labels_rmpad, _ = preprocess_packed_seqs(labels, attention_mask, pre_process=True) + labels_mask_rmpad, _ = preprocess_packed_seqs(labels_mask, attention_mask, pre_process=True) + labels_rmpad = labels_rmpad.contiguous() + labels_mask_rmpad = labels_mask_rmpad.contiguous() + + output = CausalLMOutputForPPO( + loss=None, + logits=None, + past_key_values=None, + hidden_states=hidden_states, + attentions=None, + ) + if self.config.sequence_parallel: + hidden_states = gather_from_sequence_parallel_region(hidden_states) + logprobs, entropy = linear_cross_entropy( + hidden_states, + self.output_layer.weight, + labels_rmpad, + temperature, + "none", + parallel_state.get_tensor_model_parallel_group(), + ) + output.entropy = entropy + output.log_probs = logprobs + + output = postprocess_packed_seqs_for_dict_output( + labels_mask_rmpad, + output, + packed_seq_params, + attention_mask, + batch_size, + seq_len, + post_process=post_process, + ) + output_ = [output["log_probs"]] + # TODO NOW 1f1b overlap only support one tensor output + # if "entropy" in output: + # output_.append(output["entropy"]) + output_ = tuple(output_) + return output_ + + def _custom_post_process_node_forward_impl(self, hidden_states): + if self.gpt_model.decoder.final_layernorm and not self.gpt_model.mtp_process: + hidden_states = self.gpt_model.decoder.final_layernorm(hidden_states) + # TENorm produces a "viewed" tensor. This will result in schedule.py's + # deallocate_output_tensor() throwing an error, so a viewless tensor is + # created to prevent this. + hidden_states = make_viewless_tensor(inp=hidden_states, requires_grad=True, keep_graph=True) + + # Run GPTModel._postprocess + output = self.gpt_model._postprocess( + hidden_states=hidden_states, + input_ids=self.chunk_state.input_ids, + position_ids=self.chunk_state.position_ids, + labels=self.chunk_state.labels, + decoder_input=self.chunk_state.decoder_input, + rotary_pos_emb=self.chunk_state.rotary_pos_emb, + rotary_pos_cos=self.chunk_state.rotary_pos_cos, + rotary_pos_sin=self.chunk_state.rotary_pos_sin, + mtp_in_postprocess=False, + loss_mask=self.chunk_state.loss_mask, + attention_mask=self.chunk_state.attention_mask, + packed_seq_params=self.chunk_state.packed_seq_params, + sequence_len_offset=self.chunk_state.sequence_len_offset, + runtime_gather_output=self.chunk_state.runtime_gather_output, + extra_block_kwargs=self.chunk_state.extra_block_kwargs, + ) + return output + + schedule_plan.post_process.forward_impl = _custom_post_process_node_forward_impl.__get__( + schedule_plan.post_process, schedule_plan.post_process.__class__ + ) + unwrap_model(model)._postprocess = _postprocess.__get__(unwrap_model(model), unwrap_model(model).__class__) + + return schedule_plan diff --git a/code/RL_model/verl/verl_train/verl/models/mcore/model_forward_fused.py b/code/RL_model/verl/verl_train/verl/models/mcore/model_forward_fused.py new file mode 100644 index 0000000000000000000000000000000000000000..0826caa9c72d158d68b5830417e631e361a7e6df --- /dev/null +++ b/code/RL_model/verl/verl_train/verl/models/mcore/model_forward_fused.py @@ -0,0 +1,237 @@ +# Copyright 2025 Bytedance Ltd. and/or its affiliates +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from collections import OrderedDict +from typing import Optional + +import megatron.core as mcore +import torch +from megatron.core import parallel_state +from megatron.core.config_logger import has_config_logger_enabled, log_config_to_disk +from megatron.core.inference.contexts import BaseInferenceContext +from megatron.core.models.gpt.gpt_model import GPTModel +from megatron.core.packed_seq_params import PackedSeqParams +from megatron.core.tensor_parallel.mappings import gather_from_sequence_parallel_region +from megatron.core.utils import deprecate_inference_params +from packaging import version +from torch import Tensor + +from verl.models.mcore.util import preprocess_packed_seqs +from verl.utils.kernel.linear_cross_entropy import linear_cross_entropy +from verl.utils.megatron_utils import unwrap_model +from verl.utils.model import CausalLMOutputForPPO + +from .util import postprocess_packed_seqs_for_dict_output + + +def _get_patching_model(model: torch.nn.Module): + model = unwrap_model(model) + if isinstance(model, GPTModel): + return model + + if not (hasattr(model, "language_model") and isinstance(model.language_model, GPTModel)): + print(f"Model {model.__class__.__name__} is not a supported for fused forward") + return None + + return model.language_model + + +def patch_fused_forward(model: torch.nn.Module): + assert version.parse(mcore.__version__) >= version.parse("0.13.0"), ( + "Fused forward patching requires mecore >= 0.13.0" + ) + model = _get_patching_model(model) + if model is not None: + model.forward_backup = model.forward + model.forward = _fused_GPTModel_forward.__get__(model, model.__class__) + + +def unpatch_fused_forward(model: torch.nn.Module): + model = _get_patching_model(model) + if model is not None: + model.forward = model.forward_backup + + +def fused_forward_model_gen(vision_model: bool = False): + def fused_forward_model( + model, + input_ids: Tensor, + position_ids: Tensor, + attention_mask: Tensor, + labels: Tensor, + labels_mask: Tensor, + temperature: float, + multi_modal_inputs: dict, + ): + pre_process: bool = ( + unwrap_model(model).pre_process if not vision_model else False + ) # vision model does not need pre_process, because we pack the input_ids to thd in the forward function + post_process: bool = unwrap_model(model).post_process + + model_kwargs = {} + if "pixel_values" in multi_modal_inputs: + model_kwargs["pixel_values"] = multi_modal_inputs["pixel_values"].to(input_ids.device) + if "image_grid_thw" in multi_modal_inputs: + model_kwargs["image_grid_thw"] = multi_modal_inputs["image_grid_thw"].to(input_ids.device) + if "pixel_values_videos" in multi_modal_inputs: + model_kwargs["pixel_values_videos"] = multi_modal_inputs["pixel_values_videos"].to(input_ids.device) + if "video_grid_thw" in multi_modal_inputs: + model_kwargs["video_grid_thw"] = multi_modal_inputs["video_grid_thw"].to(input_ids.device) + + batch_size, seq_len = attention_mask.shape[:2] + input_ids_rmpad, packed_seq_params = preprocess_packed_seqs(input_ids, attention_mask, pre_process=pre_process) + input_ids_rmpad = input_ids_rmpad.contiguous() + labels_rmpad, _ = preprocess_packed_seqs(labels, attention_mask, pre_process=True) + labels_mask_rmpad, _ = preprocess_packed_seqs(labels_mask, attention_mask, pre_process=True) + labels_rmpad = labels_rmpad.contiguous() + labels_mask_rmpad = labels_mask_rmpad.contiguous() + + input_args = dict( + input_ids=input_ids_rmpad, + attention_mask=None, + position_ids=position_ids if not vision_model else None, # vision models will calculate position_ids + packed_seq_params=packed_seq_params, + labels=labels_rmpad, + temperature=temperature, + **model_kwargs, + ) + + if vision_model: + # workaround for supporting sequence packing with context parallelism + # cp split with sequence packing will make model lose vision token information, so we need to keep + # the original input_ids and pack them after vision embedding is calculated, + # cooporate with mbridge + input_args["input_ids"] = input_ids + input_args["attention_mask"] = attention_mask + + output_orig: CausalLMOutputForPPO = model(**input_args) + + if post_process: + # output_orig is in type of CausalLMOutputForPPO + output = postprocess_packed_seqs_for_dict_output( + labels_mask_rmpad, + output_orig, + packed_seq_params, + attention_mask, + batch_size, + seq_len, + post_process=post_process, + ) + else: + output = output_orig + return output + + return fused_forward_model + + +def _fused_GPTModel_forward( + model, + input_ids: Tensor, + position_ids: Tensor, + attention_mask: Tensor, + decoder_input: Tensor = None, + labels: Tensor = None, + inference_context: BaseInferenceContext = None, + packed_seq_params: PackedSeqParams = None, + extra_block_kwargs: dict = None, + runtime_gather_output: Optional[bool] = None, + *, + inference_params: Optional[BaseInferenceContext] = None, + loss_mask: Optional[Tensor] = None, + temperature: float = 1.0, + **kwargs, +) -> CausalLMOutputForPPO: + """ + Patch self._postprocess in forward for GPT models to enable fused kernel support. + https://github.com/NVIDIA/Megatron-LM/blob/core_v0.13.0/megatron/core/models/gpt/gpt_model.py + + TODO: Currently we still need to patch `forward` because we need to pass `temperature` + explicitly to `self._postprocess` when calling, maybe there can be a better way to handle this? + """ + + inference_context = deprecate_inference_params(inference_context, inference_params) + + preproc_output = model._preprocess( + input_ids=input_ids, + position_ids=position_ids, + decoder_input=decoder_input, + inference_context=inference_context, + packed_seq_params=packed_seq_params, + ) + + (decoder_input, rotary_pos_emb, rotary_pos_cos, rotary_pos_sin, sequence_len_offset) = preproc_output[:5] + + # Run decoder. + hidden_states = model.decoder( + hidden_states=decoder_input, + attention_mask=attention_mask, + inference_context=inference_context, + rotary_pos_emb=rotary_pos_emb, + rotary_pos_cos=rotary_pos_cos, + rotary_pos_sin=rotary_pos_sin, + packed_seq_params=packed_seq_params, + sequence_len_offset=sequence_len_offset, + **(extra_block_kwargs or {}), + **kwargs, + ) + + if not model.post_process: + return hidden_states + + output = CausalLMOutputForPPO( + loss=None, + logits=None, + past_key_values=None, + hidden_states=hidden_states, + attentions=None, + ) + + if model.config.sequence_parallel: + hidden_states = gather_from_sequence_parallel_region(hidden_states) + + # Get the output weight - use embedding weight if output_layer is None or weight is shared + if hasattr(model, "output_layer") and model.output_layer is not None and model.output_layer.weight is not None: + output_weight = model.output_layer.weight + else: + # When embeddings are tied, use the embedding weight + output_weight = model.embedding.word_embeddings.weight + + logprobs, entropy = linear_cross_entropy( + hidden_states, + output_weight, + labels, + temperature, + "none", + parallel_state.get_tensor_model_parallel_group(), + ) + + if has_config_logger_enabled(model.config): + payload = OrderedDict( + { + "input_ids": input_ids, + "position_ids": position_ids, + "attention_mask": attention_mask, + "decoder_input": decoder_input, + "logprobs": logprobs, + "entropy": entropy, + } + ) + log_config_to_disk(model.config, payload, prefix="input_and_logits") + + output.entropy = entropy + output.log_probs = logprobs + + return output diff --git a/code/RL_model/verl/verl_train/verl/models/mcore/model_initializer.py b/code/RL_model/verl/verl_train/verl/models/mcore/model_initializer.py new file mode 100644 index 0000000000000000000000000000000000000000..49a30bc9e2c982fa4e1182d6da745cdd34251dd5 --- /dev/null +++ b/code/RL_model/verl/verl_train/verl/models/mcore/model_initializer.py @@ -0,0 +1,276 @@ +# Copyright 2025 Bytedance Ltd. and/or its affiliates +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# use mcore transformer config to initialize the model +import inspect +from abc import ABC, abstractmethod + +from megatron.core.models.gpt.gpt_layer_specs import get_gpt_decoder_block_spec, get_gpt_mtp_block_spec +from megatron.core.models.gpt.gpt_model import GPTModel + +from .config_converter import PretrainedConfig, TransformerConfig + + +class BaseModelInitializer(ABC): + """Base class for model initializers.""" + + def __init__(self, tfconfig: TransformerConfig, hf_config: PretrainedConfig): + self.tfconfig = tfconfig + self.hf_config = hf_config + self.has_vp_stage = inspect.signature(get_gpt_decoder_block_spec).parameters.get("vp_stage", None) is not None + + @abstractmethod + def get_transformer_layer_spec(self, vp_stage=None): + """Get the transformer layer specification. + https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/models/gpt/gpt_layer_specs.py""" + pass + + def get_rope_scaling_args(self) -> dict: + """Get rope scaling args.""" + rope_scaling_args = {} + if "rope_scaling" in self.hf_config: + if self.hf_config.rope_scaling is not None: + # assert self.hf_config.rope_scaling["type"] == "linear", "only linear scaling is supported for now" + rope_scaling_args["seq_len_interpolation_factor"] = self.hf_config.rope_scaling["factor"] + return rope_scaling_args + + def initialize( + self, + pre_process: bool = True, + post_process: bool = True, + share_embeddings_and_output_weights: bool = False, + value: bool = False, + **extra_kwargs, + ) -> GPTModel: + """Initialize a GPT model with the given configuration. + https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/models/gpt/gpt_model.py + + Args: + pre_process (bool): include embedding layer. + post_process (bool): including an output layer. + share_embeddings_and_output_weights (bool): input embeddings and output logit weights are shared. + value (bool): add an extra linear layer for classification or regression. + + Returns: + GPTModel: An initialized GPT model instance + """ + vp_stage = extra_kwargs.get("vp_stage", None) + transformer_layer_spec = self.get_transformer_layer_spec(vp_stage=vp_stage) + rope_scaling_args = self.get_rope_scaling_args() + mtp_block_spec = extra_kwargs.get("mtp_block_spec", None) + model = GPTModel( + config=self.tfconfig, + transformer_layer_spec=transformer_layer_spec, + vocab_size=self.hf_config.vocab_size, + max_sequence_length=self.hf_config.max_position_embeddings, + pre_process=pre_process, + post_process=post_process, + share_embeddings_and_output_weights=share_embeddings_and_output_weights, + position_embedding_type="rope", + rotary_base=self.hf_config.rope_theta, + **rope_scaling_args, + mtp_block_spec=mtp_block_spec, + **({} if not self.has_vp_stage else {"vp_stage": vp_stage}), + ) + + if post_process and value: + from verl.models.llama.megatron.layers.parallel_linear import LinearForLastLayer + + model.output_layer = LinearForLastLayer( + input_size=self.tfconfig.hidden_size, output_size=1, config=self.tfconfig + ) + + return model + + +class DenseModel(BaseModelInitializer): + """Initializer for dense models like Llama and Qwen2.""" + + def get_transformer_layer_spec(self, vp_stage=None): + assert self.tfconfig.normalization == "RMSNorm", "only RMSNorm is supported for now" + extra_kwargs = {} if not self.has_vp_stage else {"vp_stage": vp_stage} + return get_gpt_decoder_block_spec(self.tfconfig, use_transformer_engine=True, **extra_kwargs) + + +class Qwen2MoEModel(BaseModelInitializer): + """Initializer for Qwen2 MoE models.""" + + def get_transformer_layer_spec(self, vp_stage=None): + assert self.tfconfig.normalization == "RMSNorm", "only RMSNorm is supported for now" + extra_kwargs = {} if not self.has_vp_stage else {"vp_stage": vp_stage} + transformer_layer_spec = get_gpt_decoder_block_spec(self.tfconfig, use_transformer_engine=True, **extra_kwargs) + + # Patch layer spec for shared experts + for i in range(len(transformer_layer_spec.layer_specs)): + transformer_layer_spec.layer_specs[i].submodules.mlp.submodules.shared_experts.params["gate"] = True + + return transformer_layer_spec + + def initialize(self, **kwargs): + # Qwen default freeze_moe_router: true + model = super().initialize(**kwargs) + freeze_moe_router = kwargs.get("freeze_moe_router", True) + if freeze_moe_router: + for layer in model.decoder.layers: + layer.mlp.router.weight.requires_grad = False + return model + + +class MixtralModel(BaseModelInitializer): + """Initializer for Mixtral models.""" + + def get_transformer_layer_spec(self, vp_stage=None): + assert self.tfconfig.normalization == "RMSNorm", "only RMSNorm is supported for now" + extra_kwargs = {} if not self.has_vp_stage else {"vp_stage": vp_stage} + transformer_layer_spec = get_gpt_decoder_block_spec(self.tfconfig, use_transformer_engine=True, **extra_kwargs) + return transformer_layer_spec + + def initialize(self, **kwargs): + model = super().initialize(**kwargs) + freeze_moe_router = kwargs.get("freeze_moe_router", False) + if freeze_moe_router: + for layer in model.decoder.layers: + layer.mlp.router.weight.requires_grad = False + return model + + +class Qwen3MoEModel(BaseModelInitializer): + """Initializer for Qwen3 MoE models.""" + + def get_transformer_layer_spec(self, vp_stage=None): + assert self.tfconfig.normalization == "RMSNorm", "only RMSNorm is supported for now" + extra_kwargs = {} if not self.has_vp_stage else {"vp_stage": vp_stage} + transformer_layer_spec = get_gpt_decoder_block_spec(self.tfconfig, use_transformer_engine=True, **extra_kwargs) + return transformer_layer_spec + + def initialize(self, **kwargs): + # Qwen default freeze_moe_router: true + model = super().initialize(**kwargs) + freeze_moe_router = kwargs.get("freeze_moe_router", True) + if freeze_moe_router: + for layer in model.decoder.layers: + layer.mlp.router.weight.requires_grad = False + return model + + +class DeepseekV3Model(BaseModelInitializer): + """Initializer for DeepseekV3 models.""" + + def get_transformer_layer_spec(self, vp_stage=None): + extra_kwargs = {} if not self.has_vp_stage else {"vp_stage": vp_stage} + transformer_layer_spec = get_gpt_decoder_block_spec(self.tfconfig, use_transformer_engine=True, **extra_kwargs) + return transformer_layer_spec + + def get_rope_scaling_args(self) -> dict: + """Get rope scaling args.""" + rope_scaling_args = {} + return rope_scaling_args + + def initialize( + self, + **kwargs, + ): + vp_stage = kwargs.get("vp_stage", None) + freeze_moe_router = kwargs.get("freeze_moe_router", True) + if freeze_moe_router: + self.tfconfig.moe_router_load_balancing_type = "none" + # MTP + if self.tfconfig.mtp_num_layers is not None and self.tfconfig.mtp_num_layers > 0: + transformer_layer_spec = self.get_transformer_layer_spec(vp_stage=vp_stage) + mtp_block_spec = get_gpt_mtp_block_spec( + self.tfconfig, transformer_layer_spec, use_transformer_engine=True, vp_stage=vp_stage + ) + kwargs["mtp_block_spec"] = mtp_block_spec + + model = super().initialize(**kwargs) + if freeze_moe_router: + for layer in model.decoder.layers: + if hasattr(layer.mlp, "router"): + layer.mlp.router.weight.requires_grad = False + return model + + +class Qwen25VLModel(BaseModelInitializer): + """Initializer for Qwen2.5 VL models.""" + + def get_transformer_layer_spec(self, vp_stage=None): + extra_kwargs = {} if not self.has_vp_stage else {"vp_stage": vp_stage} + transformer_layer_spec = get_gpt_decoder_block_spec(self.tfconfig, use_transformer_engine=True, **extra_kwargs) + return transformer_layer_spec + + def initialize( + self, + pre_process=None, + post_process=None, + share_embeddings_and_output_weights=False, + value=False, + **extra_kwargs, + ): + tfconfig = self.tfconfig + hf_config = self.hf_config + # Qwen2_5_VLForConditionalGeneration + from copy import deepcopy + + transformer_layer_spec = self.get_transformer_layer_spec() + + from megatron.core.extensions.transformer_engine import TEColumnParallelLinear, TERowParallelLinear + from megatron.core.models.gpt.moe_module_specs import MLPSubmodules + from megatron.core.models.vision.vit_layer_specs import get_vit_layer_with_transformer_engine_spec + + from .qwen2_5_vl import Qwen2_5VLModel, get_vision_model_config, get_vision_projection_config + + vision_transformer_config = get_vision_model_config(deepcopy(tfconfig)) + vision_transformer_config.pipeline_model_parallel_size = 1 + vision_transformer_config.first_pipeline_num_layers = None + + vision_projection_config = get_vision_projection_config( + deepcopy(tfconfig), + vision_transformer_config.hidden_size, + spatial_merge_size=hf_config.vision_config.spatial_merge_size, + ) + vision_projection_layer_spec = MLPSubmodules( + linear_fc1=TEColumnParallelLinear, + linear_fc2=TERowParallelLinear, + ) + vision_transformer_layer_spec = get_vit_layer_with_transformer_engine_spec() + + qwen25_vl_model = Qwen2_5VLModel( + language_transformer_config=tfconfig, + language_transformer_layer_spec=transformer_layer_spec, + language_vocab_size=hf_config.vocab_size, + language_max_sequence_length=hf_config.max_position_embeddings, + vision_transformer_config=vision_transformer_config, + vision_transformer_layer_spec=vision_transformer_layer_spec, + vision_projection_config=vision_projection_config, + vision_projection_layer_spec=vision_projection_layer_spec, + vision_projection_type="mlp", + language_rotary_base=hf_config.rope_theta, + pre_process=pre_process, + post_process=post_process, + add_decoder=True, + add_encoder=True, + parallel_output=True, + language_share_embeddings_and_output_weights=share_embeddings_and_output_weights, + ) + + if post_process and value: + from verl.models.llama.megatron.layers.parallel_linear import LinearForLastLayer + + qwen25_vl_model.language_model.output_layer = LinearForLastLayer( + input_size=tfconfig.hidden_size, output_size=1, config=tfconfig + ) + + return qwen25_vl_model diff --git a/code/RL_model/verl/verl_train/verl/models/mcore/mtp_patch.py b/code/RL_model/verl/verl_train/verl/models/mcore/mtp_patch.py new file mode 100644 index 0000000000000000000000000000000000000000..117b6e3f28c72e33855f74dcd2decec2cba4d461 --- /dev/null +++ b/code/RL_model/verl/verl_train/verl/models/mcore/mtp_patch.py @@ -0,0 +1,295 @@ +# Copyright 2025 Bytedance Ltd. and/or its affiliates +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# Copyright 2025 Meituan Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Callable + +import torch +from megatron.core import parallel_state +from megatron.core.models.gpt.gpt_model import GPTModel +from megatron.core.transformer.multi_token_prediction import ( + MTPLossAutoScaler, + MTPLossLoggingHelper, + roll_tensor, +) + +try: + from megatron.core.utils import unwrap_model +except ImportError: + from verl.utils.megatron_utils import unwrap_model + + +def _get_patching_model(model: torch.nn.Module): + model = unwrap_model(model) + if isinstance(model, GPTModel): + return model + + if not (hasattr(model, "language_model") and isinstance(model.language_model, GPTModel)): + print(f"Model {model.__class__.__name__} is not a supported for fused forward") + return None + + return model.language_model + + +def patch_postprocess(model: torch.nn.Module): + model = _get_patching_model(model) + if model is not None: + model._postprocess_backup = model._postprocess + model._postprocess = _megatron_gptmodel_postprocess.__get__(model, model.__class__) + + +def unpatch_postprocess(model: torch.nn.Module): + model = _get_patching_model(model) + if model is not None: + model._postprocess = model._postprocess_backup + + +# copy from https://github.com/NVIDIA/Megatron-LM/blob/23e092f41ec8bc659020e401ddac9576c1cfed7e/megatron/core/models/gpt/gpt_model.py +# patch the postprocess method of GPTModel to support advanced features like MTP, 1f1b overlap, etc. +def _megatron_gptmodel_postprocess( + self, + hidden_states, + input_ids, + position_ids, + labels, + rotary_pos_emb, + rotary_pos_cos, + rotary_pos_sin, + mtp_in_postprocess=None, + loss_mask=None, + decoder_input=None, + attention_mask=None, + inference_params=None, + packed_seq_params=None, + sequence_len_offset=None, + runtime_gather_output=None, + extra_block_kwargs=None, + inference_context=None, +): + """Postprocesses decoder hidden states to generate logits or compute loss. + + Applies Multi-Token Prediction if enabled, generates output logits through + the output layer, and computes language model loss when labels are provided. + """ + + # logits and loss + output_weight = None + if self.share_embeddings_and_output_weights: + output_weight = self.shared_embedding_or_output_weight() + + if mtp_in_postprocess and labels is not None: + hidden_states = self.mtp( + input_ids=input_ids, + position_ids=position_ids, + hidden_states=hidden_states, + attention_mask=attention_mask, + inference_params=inference_params, + rotary_pos_emb=rotary_pos_emb, + rotary_pos_cos=rotary_pos_cos, + rotary_pos_sin=rotary_pos_sin, + packed_seq_params=packed_seq_params, + sequence_len_offset=sequence_len_offset, + embedding=self.embedding, + **(extra_block_kwargs or {}), + ) + + if not self.post_process: + return hidden_states + + # Skip when mtp_num_layers is None or 0 + if self.config.mtp_num_layers and labels is not None: + mtp_labels = labels.clone() + + hidden_states_list = torch.chunk(hidden_states, 1 + self.config.mtp_num_layers, dim=0) + hidden_states = hidden_states_list[0] + if loss_mask is None: + # if loss_mask is not provided, use all ones as loss_mask + loss_mask = torch.ones_like(mtp_labels) + for mtp_layer_number in range(self.config.mtp_num_layers): + # Calc loss for the current Multi-Token Prediction (MTP) layers. + mtp_labels, _ = roll_tensor( + mtp_labels, + shifts=-1, + dims=-1, + cp_group=self.cp_group, + packed_seq_params=packed_seq_params, + ) + loss_mask, num_tokens = roll_tensor( + loss_mask, + shifts=-1, + dims=-1, + cp_group=self.cp_group, + packed_seq_params=packed_seq_params, + ) + + # Compute mtp loss without storing logits to save memory. + mtp_loss = self.compute_output_layer_and_language_model_loss( + hidden_states_list[mtp_layer_number + 1], + labels=mtp_labels, + weight=self.shared_embedding_or_output_weight(), + sequence_parallel_enabled=self.output_layer.sequence_parallel, + column_parallel_linear=self.output_layer, + col_linear_kwargs={ + "weight": output_weight, + "runtime_gather_output": runtime_gather_output, + }, + ) + + mtp_loss = loss_mask * mtp_loss + if self.training: + # TODO(shifangx): remove the use of parallel_state here + # after moving loss logging to loss_func in pretrain_gpt.py + MTPLossLoggingHelper.save_loss_to_tracker( + torch.sum(mtp_loss) / num_tokens, + mtp_layer_number, + self.config.mtp_num_layers, + avg_group=parallel_state.get_data_parallel_group(with_context_parallel=True), + ) + mtp_loss_scale = self.config.mtp_loss_scaling_factor / self.config.mtp_num_layers + if self.config.calculate_per_token_loss: + hidden_states = MTPLossAutoScaler.apply(hidden_states, mtp_loss_scale * mtp_loss) + else: + hidden_states = MTPLossAutoScaler.apply(hidden_states, mtp_loss_scale * mtp_loss / num_tokens) + + logits, _ = self.output_layer(hidden_states, weight=output_weight, runtime_gather_output=runtime_gather_output) + # [s b h] => [b s h] + return logits.transpose(0, 1).contiguous() + + +def patch_mtp_layer_get_embeddings(model: torch.nn.Module): + """Patch the _get_embeddings method of MultiTokenPredictionLayer""" + from megatron.core.models.gpt.gpt_model import GPTModel + from megatron.core.transformer.multi_token_prediction import MultiTokenPredictionLayer + + # Unwrap each model in the actor_module to get the actual GPTModel + model = _get_patching_model(model) + # Collect all MultiTokenPredictionLayer instances + target_layers = [] + + if isinstance(model, GPTModel): + # Check if GPTModel has MTP and find the layers + if hasattr(model, "mtp") and hasattr(model.mtp, "layers"): + for layer in model.mtp.layers: + if isinstance(layer, MultiTokenPredictionLayer): + target_layers.append(layer) + elif hasattr(model, "layers"): + # Check if any layer in the model is MultiTokenPredictionLayer + for layer in model.layers: + if isinstance(layer, MultiTokenPredictionLayer): + target_layers.append(layer) + + if target_layers: + for layer in target_layers: + layer._get_embeddings_backup = layer._get_embeddings + layer._get_embeddings = _patched_get_embeddings_for_detach.__get__(layer, layer.__class__) + print(f"Found and patched {len(target_layers)} MTP layer(s) in any of the actor modules") + return True + else: + print("No MTP layers found to patch in any of the actor modules") + return False + + +def unpatch_mtp_layer_get_embeddings(model: torch.nn.Module): + """Unpatch the _get_embeddings method of MultiTokenPredictionLayer""" + from megatron.core.models.gpt.gpt_model import GPTModel + from megatron.core.transformer.multi_token_prediction import MultiTokenPredictionLayer + + # Unwrap each model in the actor_module to get the actual GPTModel + model = _get_patching_model(model) + + # Collect all MultiTokenPredictionLayer instances + target_layers = [] + + if isinstance(model, GPTModel): + # Check if GPTModel has MTP and find the layers + if hasattr(model, "mtp") and hasattr(model.mtp, "layers"): + for layer in model.mtp.layers: + if isinstance(layer, MultiTokenPredictionLayer): + target_layers.append(layer) + elif hasattr(model, "layers"): + # Check if any layer in the model is MultiTokenPredictionLayer + for layer in model.layers: + if isinstance(layer, MultiTokenPredictionLayer): + target_layers.append(layer) + + unpatched_count = 0 + for layer in target_layers: + if hasattr(layer, "_get_embeddings_backup"): + layer._get_embeddings = layer._get_embeddings_backup + delattr(layer, "_get_embeddings_backup") + unpatched_count += 1 + + if unpatched_count > 0: + print(f"Unpatched {unpatched_count} MTP layer(s)") + return True + return False + + +def _patched_get_embeddings_for_detach( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + embedding: Callable, + hidden_states: torch.Tensor, + packed_seq_params=None, +): + """ + Patched version of _get_embeddings method for MultiTokenPredictionLayer. + + This is a modified version that you can customize according to your needs. + The original implementation is preserved below with modifications. + """ + + # You can modify the logic here as needed + # For example, you could: + # - Change the shift amount in roll_tensor + # - Apply custom transformations to input_ids or position_ids + # - Add debugging information + # - Modify the embedding computation + + # Original logic with custom modifications + from megatron.core.transformer.multi_token_prediction import roll_tensor + from megatron.core.utils import make_viewless_tensor + + # Calc logits for the current Multi-Token Prediction (MTP) layers. + input_ids, _ = roll_tensor( + input_ids, + shifts=-1, # You can modify this shift value + dims=-1, + cp_group=self.cp_group, + packed_seq_params=packed_seq_params, + ) + position_ids, _ = roll_tensor( + position_ids, + shifts=-1, # You can modify this shift value + dims=-1, + cp_group=self.cp_group, + packed_seq_params=packed_seq_params, + ) + + # embedding computation - you can modify this part + decoder_input = embedding(input_ids=input_ids, position_ids=position_ids) + + # Apply custom transformations if needed + # For example: decoder_input = some_custom_function(decoder_input) + + hidden_states = make_viewless_tensor(inp=hidden_states, requires_grad=True, keep_graph=True) + + # detach decoder_input and hidden_states + decoder_input = decoder_input.detach() + hidden_states = hidden_states.detach() + + return input_ids, position_ids, decoder_input, hidden_states diff --git a/code/RL_model/verl/verl_train/verl/models/mcore/patch.py b/code/RL_model/verl/verl_train/verl/models/mcore/patch.py new file mode 100644 index 0000000000000000000000000000000000000000..9b26e8e0f5b03b1c01456bb84b9d31f8b6797931 --- /dev/null +++ b/code/RL_model/verl/verl_train/verl/models/mcore/patch.py @@ -0,0 +1,364 @@ +# Copyright 2025 Bytedance Ltd. and/or its affiliates +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# there is some bug in mcore 0.12, so we need to patch it +# 1. `get_query_key_value_tensors` in `multi_latent_attention.py` works wrong when packed_seq_params is not None + + +def apply_patch(): + import megatron.core + import torch + import torch.nn.functional as F + from megatron.core import parallel_state, tensor_parallel + from megatron.core.transformer.multi_latent_attention import ( + MLASelfAttention, + MultiLatentAttention, + apply_rotary_pos_emb, + deprecate_inference_params, + gather_from_sequence_parallel_region, + gather_from_tensor_model_parallel_region, + scatter_to_sequence_parallel_region, + ) + from packaging import version + + mcore_013 = version.parse(megatron.core.__version__) >= version.parse("0.13.0rc0") + + def patch_get_query_key_value_tensors( + self, + hidden_states, + key_value_states=None, + position_ids=None, + packed_seq_params=None, + inference_context=None, + *, + inference_params=None, + ): + """ + Derives `query`, `key` and `value` tensors from `hidden_states`. + """ + # s = sequence length, b = batch size, h = hidden size, n = num attention heads + # Attention heads [s, b, n*h] + assert hidden_states.ndim == 3, f"hidden_states should be 3D, [s, b, n*h], got {hidden_states.ndim}D" + + inference_context = deprecate_inference_params(inference_context, inference_params) + + # ========================================= + # Prepare RoPE and seqlen related params + # ========================================= + rotary_seq_len = self.rotary_pos_emb.get_rotary_seq_len( + inference_context, None, hidden_states, self.config, packed_seq_params + ) + + # rotary_pos_emb:[s, b, 1, 64] + mscale = 1.0 + if self.config.rope_type == "rope": + packed_seq = packed_seq_params is not None and packed_seq_params.qkv_format == "thd" + try: + # In case of TypeError: RotaryEmbedding.forward() got an unexpected keyword argument 'packed_seq' + rotary_pos_emb = self.rotary_pos_emb(rotary_seq_len, packed_seq=packed_seq) + except TypeError: + rotary_pos_emb = self.rotary_pos_emb(rotary_seq_len) + else: + rotary_pos_emb, mscale = self.rotary_pos_emb(rotary_seq_len) + + # ========================================= + # QKV down projection and layernorm + # ========================================= + if self.config.q_lora_rank is not None: + # if linear_q_down_proj is ColumnParallelLinear: + # q_compressed: [s, b, q_lora_rank / TP] + # elif linear_q_down_proj is Linear: + # q_compressed: [s / TP, b, q_lora_rank] + q_compressed, _ = self.linear_q_down_proj(hidden_states) + + # When output is sharded (ColumnParallelLinear), two things are needed to be + # identical to a normal Linear. + # 1. Manually gather output to restore output dim q_lora_rank; + # 2. Scatter sequence back to s / TP if sequence-parallel since it was + # gathered by ColumnParallelLinear. + if q_compressed.size(-1) != self.config.q_lora_rank: + q_compressed = gather_from_tensor_model_parallel_region(q_compressed) + if self.config.sequence_parallel: + q_compressed = scatter_to_sequence_parallel_region(q_compressed) + + q_compressed = self.q_layernorm(q_compressed) + else: + q_compressed = hidden_states + + # if linear_kv_down_proj is ColumnParallelLinear: + # kv_combined: [s, b, (kv_lora_rank + qk_pos_emb_head_dim) / TP] + # elif linear_kv_down_proj is Linear: + # kv_combined: [s / TP, b, (kv_lora_rank + qk_pos_emb_head_dim)] + kv_combined, _ = self.linear_kv_down_proj(hidden_states) + if kv_combined.size(-1) != self.config.kv_lora_rank + self.config.qk_pos_emb_head_dim: + # kv_combined: [s, b, (kv_lora_rank + qk_pos_emb_head_dim)] + kv_combined = gather_from_tensor_model_parallel_region(kv_combined) + # kv_compressed:[s, b, kv_lora_rank], k_pos_emb: [s, b, qk_pos_emb_head_dim] + kv_compressed, k_pos_emb = torch.split( + kv_combined, [self.config.kv_lora_rank, self.config.qk_pos_emb_head_dim], dim=-1 + ) + if self.config.sequence_parallel: + # kv_compressed:[s / TP, b, kv_lora_rank] + kv_compressed = scatter_to_sequence_parallel_region(kv_compressed) + else: + # kv_compressed:[s / TP, b, kv_lora_rank], k_pos_emb: [s / TP, b, qk_pos_emb_head_dim] + kv_compressed, k_pos_emb = torch.split( + kv_combined, [self.config.kv_lora_rank, self.config.qk_pos_emb_head_dim], dim=-1 + ) + if parallel_state.get_tensor_model_parallel_world_size() > 1: + # k_pos_emb: [s, b, qk_pos_emb_head_dim] + k_pos_emb = gather_from_sequence_parallel_region(k_pos_emb) + + kv_compressed = self.kv_layernorm(kv_compressed) + + # ========================================= + # QKV up projection and RoPE apply + # ========================================= + def qkv_up_proj_and_rope_apply(q_compressed, kv_compressed, k_pos_emb, rotary_pos_emb): + if self.config.q_lora_rank is not None: + q, _ = self.linear_q_up_proj(q_compressed) + else: + # hidden_states:[s, b, 2048], q: [s, b, n * 192] + q, _ = self.linear_q_proj(q_compressed) + + q_len, bsz, _ = q.size() + + # q: [s, b, n, 192] + q = q.view(q_len, bsz, self.num_attention_heads_per_partition, self.q_head_dim) + + # kv: [s, b, 2048] + kv, _ = self.linear_kv_up_proj(kv_compressed) + + # kv: [s, b, n, 256] + kv = kv.view( + q_len, + bsz, + self.num_attention_heads_per_partition, + self.config.qk_head_dim + self.config.v_head_dim, + ) + + cp_size = parallel_state.get_context_parallel_world_size() + if inference_context is not None: + # add offset to the sequence start for inference + sequence_start = inference_context.sequence_len_offset + sequence_end = sequence_start + q_len + rotary_pos_emb = rotary_pos_emb[sequence_start:sequence_end] + elif packed_seq_params is None or cp_size == 1: + # Shorten rotary_pos_emb to the sequence length when inference_params + # is not provided. This makes sure we can run forward directly with + # any sequence length. During training, the sequence length is always + # the full rotary_pos_emb length, except for sequence packing + CP. + # When sequence packing and context parallel are both enabled, the + # position embedding will not split rotary_pos_emb, so it may exceed + # the sequence length on this CP rank, but we need the full rotary_pos_emb + # to cover the full sequence, so we do not shorten it here. + rotary_pos_emb = rotary_pos_emb[0:q_len] + + # [s, b, 64] -> [s, b, 1, 64] + k_pos_emb = torch.unsqueeze(k_pos_emb, 2) + + # q: [s, b, n, 128], q_pos_emb: [s, b, n, 64] + q_no_pe, q_pos_emb = torch.split(q, [self.config.qk_head_dim, self.config.qk_pos_emb_head_dim], dim=-1) + + # k_no_pe: [s, b, n, 128], value: [s, b, n, 128] + k_no_pe, value = torch.split(kv, [self.config.qk_head_dim, self.config.v_head_dim], dim=-1) + + if packed_seq_params is not None: + cu_seqlens_q = packed_seq_params.cu_seqlens_q + cu_seqlens_kv = packed_seq_params.cu_seqlens_kv + q_pos_emb = q_pos_emb.squeeze(1) + k_pos_emb = k_pos_emb.squeeze(1) + q_no_pe = q_no_pe.squeeze(1) + k_no_pe = k_no_pe.squeeze(1) + value = value.squeeze(1) + else: + cu_seqlens_q = cu_seqlens_kv = None + + # q_pos_emb: [s, b, n, 64], k_pos_emb:[s, b, 1, 64] + q_pos_emb = apply_rotary_pos_emb( + q_pos_emb, + rotary_pos_emb, + config=self.config, + cu_seqlens=cu_seqlens_q, + mscale=mscale, + ) + k_pos_emb = apply_rotary_pos_emb( + k_pos_emb, + rotary_pos_emb, + config=self.config, + cu_seqlens=cu_seqlens_kv, + mscale=mscale, + ) + + # query: [s, b, n, 192] + query = torch.cat([q_no_pe, q_pos_emb], dim=-1) + if packed_seq_params is not None: + k_pos_emb = k_pos_emb.expand(-1, self.num_attention_heads_per_partition, -1) + key = torch.cat([k_no_pe, k_pos_emb], dim=-1) + else: + # key: [s, b, n, 192] + k_pos_emb = k_pos_emb.expand(-1, -1, self.num_attention_heads_per_partition, -1) + key = torch.cat([k_no_pe, k_pos_emb], dim=-1) + + query = query.contiguous() + key = key.contiguous() + value = value.contiguous() + return query, key, value + + if self.recompute_up_proj: + self.qkv_up_checkpoint = tensor_parallel.CheckpointWithoutOutput() + query, key, value = self.qkv_up_checkpoint.checkpoint( + qkv_up_proj_and_rope_apply, q_compressed, kv_compressed, k_pos_emb, rotary_pos_emb + ) + else: + query, key, value = qkv_up_proj_and_rope_apply(q_compressed, kv_compressed, k_pos_emb, rotary_pos_emb) + + return query, key, value + + def patch_forward( + self, + hidden_states, + attention_mask, + key_value_states=None, + inference_context=None, + rotary_pos_emb=None, + rotary_pos_cos=None, + rotary_pos_sin=None, + attention_bias=None, + packed_seq_params=None, + position_ids=None, + sequence_len_offset=None, + *, + inference_params=None, + **kwargs, + ): + """Forward pass for multi-latent attention""" + assert attention_bias is None, "Attention bias should not be passed into MLA." + assert rotary_pos_cos is None and rotary_pos_sin is None, "MLA does not support Flash Decoding" + + # hidden_states: [sq, b, h] + + inference_context = deprecate_inference_params(inference_context, inference_params) + + # ===================== + # Query, Key, and Value + # ===================== + # Get the query, key and value tensors based on the type of attention - + # self or cross attn. + # query: [96, 1, 16, 128], key:[96, 1, 16, 128], value:[96, 1, 16, 128] + query, key, value = self.get_query_key_value_tensors( + hidden_states, + key_value_states, + position_ids, + packed_seq_params, + inference_context=inference_context, + ) + + # =================================================== + # Adjust key, value for inference + # =================================================== + # rotary_pos_emb = None + if mcore_013: + query, key, value, _, attn_mask_type, _ = self._adjust_key_value_for_inference( + inference_context, query, key, value, rotary_pos_emb=None + ) + else: + query, key, value, _, attn_mask_type = self._adjust_key_value_for_inference( + inference_context, query, key, value, rotary_pos_emb=None + ) + + # TODO: Currently, TE can only accept contiguous tensors for MLA + query = query.contiguous() + key = key.contiguous() + value = value.contiguous() + + # ================================== + # core attention computation + # ================================== + # Need corresponding TE change + thd_qkv_format = packed_seq_params and packed_seq_params.qkv_format == "thd" + v_dim = value.shape[-1] + if thd_qkv_format and query.shape[-1] != v_dim: + value = F.pad(value, [0, query.shape[-1] - v_dim]) + self.core_attention.hidden_size_per_attention_head_v = value.shape[-1] + if self.checkpoint_core_attention and self.training: + core_attn_out = self._checkpointed_attention_forward( + query, key, value, attention_mask, packed_seq_params=packed_seq_params + ) + else: + core_attn_out = self.core_attention( + query, + key, + value, + attention_mask, + packed_seq_params=packed_seq_params, + attn_mask_type=attn_mask_type, + ) + if thd_qkv_format: + if core_attn_out.ndim == 2: + core_attn_out = core_attn_out.reshape(*core_attn_out.shape[:-1], -1, value.shape[-1]) + if query.shape[-1] != v_dim: + core_attn_out = core_attn_out[..., :v_dim] + # reshape to same output shape as unpacked case + # (t, np, hn) -> (t, b=1, h=np*hn) + # t is the pack size = sum (sq_i) + # note that batch is a dummy dimension in the packed case + core_attn_out = core_attn_out.reshape(core_attn_out.size(0), 1, -1) + + if self.recompute_up_proj: + assert self.qkv_up_checkpoint is not None + self.qkv_up_checkpoint.discard_output_and_register_recompute(core_attn_out) + self.qkv_up_checkpoint = None + + # ================= + # Output. [sq, b, h] + # ================= + output, bias = self.linear_proj(core_attn_out) + + return output, bias + + MLASelfAttention.get_query_key_value_tensors = patch_get_query_key_value_tensors + + MultiLatentAttention.forward = patch_forward + + +def apply_patch_mbridge(): + try: + from megatron.core.utils import get_tensor_model_parallel_group_if_none + except ImportError: + import warnings + + import megatron.core.utils + import torch + from megatron.core import parallel_state + + def get_tensor_model_parallel_group_if_none(tp_group, is_expert=False, check_initialized=True): + """Issue a deprecation warning if tp_group is None and return the default tp group.""" + if not torch.distributed.is_initialized(): + return None + if tp_group is None: + if torch.distributed.is_initialized() and torch.distributed.get_rank() == 0: + warnings.warn( + "Warning: tp_group is None, using default tp group. Passing tp_group will be mandatory soon", + DeprecationWarning, + stacklevel=2, + ) + if is_expert: + tp_group = parallel_state.get_expert_tensor_parallel_group(check_initialized=check_initialized) + else: + tp_group = parallel_state.get_tensor_model_parallel_group(check_initialized=check_initialized) + return tp_group + + megatron.core.utils.get_tensor_model_parallel_group_if_none = get_tensor_model_parallel_group_if_none diff --git a/code/RL_model/verl/verl_train/verl/models/mcore/qwen2_5_vl/__init__.py b/code/RL_model/verl/verl_train/verl/models/mcore/qwen2_5_vl/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8842d0249e1fa5397734bb0929e65d20978f815f --- /dev/null +++ b/code/RL_model/verl/verl_train/verl/models/mcore/qwen2_5_vl/__init__.py @@ -0,0 +1,21 @@ +# Copyright 2025 Bytedance Ltd. and/or its affiliates +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2024 Alibaba PAI Team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from .model import Qwen2_5VLModel +from .vision_config import get_vision_model_config, get_vision_projection_config + +__all__ = ["Qwen2_5VLModel", "get_vision_model_config", "get_vision_projection_config"] diff --git a/code/RL_model/verl/verl_train/verl/models/mcore/qwen2_5_vl/attention.py b/code/RL_model/verl/verl_train/verl/models/mcore/qwen2_5_vl/attention.py new file mode 100644 index 0000000000000000000000000000000000000000..2a87a053c59f9ad464ced527ec017498917d92d3 --- /dev/null +++ b/code/RL_model/verl/verl_train/verl/models/mcore/qwen2_5_vl/attention.py @@ -0,0 +1,225 @@ +# Copyright 2025 Bytedance Ltd. and/or its affiliates +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2024 Alibaba PAI Team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from megatron.core.transformer.attention import * + +from .rope_utils import apply_rotary_pos_emb_absolute + + +class Qwen2_5VLSelfAttention(SelfAttention): + """ + Overrides the SelfAttention class, the difference is that qwen2_5_vl uses apply_rotary_pos_emb_absolute + instead of apply_rotary_pos_emb + """ + + def forward( + self, + hidden_states: Tensor, + attention_mask: Tensor, + key_value_states: Optional[Tensor] = None, + inference_context: Optional[BaseInferenceContext] = None, + rotary_pos_emb: Optional[Union[Tensor, Tuple[Tensor, Tensor]]] = None, + rotary_pos_cos: Optional[Tensor] = None, + rotary_pos_sin: Optional[Tensor] = None, + attention_bias: Optional[Tensor] = None, + packed_seq_params: Optional[PackedSeqParams] = None, + sequence_len_offset: Optional[int] = None, + *, + inference_params: Optional[BaseInferenceContext] = None, + rotary_pos_cos_sin: Optional[Tensor] = None, + ) -> Tuple[Tensor, Tensor]: + """ + Perform a forward pass through the attention module. + + Args: + hidden_states (Tensor): Hidden states. + attention_mask (Tensor): Attention mask. + key_value_states (Optional[Tensor]): Key/value states (for cross attention). + inference_context (Optional[BaseInferenceContext]): Inference context that manages + KV cache. + rotary_pos_emb (Optional[Union[Tensor, Tuple[Tensor, Tensor]]]): Rotary + embedding tensor(s). + rotary_pos_cos (Optional[Tensor]): Rotary embedding cosine. + rotary_pos_sin (Optional[Tensor]): Rotary embedding sine. + attention_bias (Optional[Tensor]): Attention bias. + packed_seq_params (Optional[PackedSeqparams]): Parameters used for THD format. + sequence_len_offset (Optional[int]): Sequence length offset used for + inference CUDA graphs. + + Return: + (Tuple[Tensor, Tensor]) Attention output and bias. + + """ + + inference_context = deprecate_inference_params(inference_context, inference_params) + + if inference_context and inference_context.is_dynamic_batching(): + assert flash_decode_and_prefill_kernel is not None, ( + "Internal use only: install package `nvidia_chunked_flash_attn`." + ) + + # hidden_states: [sq, b, h] + if self.config.flash_decode and not self.training and inference_context is not None: + rotary_pos_emb = None + else: + assert rotary_pos_cos is None and rotary_pos_sin is None + + # For self attention we just duplicate the rotary_pos_emb if it isn't already + if rotary_pos_emb is not None and not isinstance(rotary_pos_emb, tuple): + rotary_pos_emb = (rotary_pos_emb,) * 2 + + # ===================== + # Query, Key, and Value + # ===================== + # Get the query, key and value tensors based on the type of attention - + # self or cross attn. + query, key, value = self.get_query_key_value_tensors(hidden_states, key_value_states) + + # =================================================== + # Adjust key, value, and rotary_pos_emb for inference + # =================================================== + + # This branch only runs in the decode phase of flash decoding and returns after the linear + # projection. This conditional is not used in the prefill phase or non-flash-decoding cases. + if ( + self.config.flash_decode + and inference_context is not None + and inference_context.is_decode_only() + and not self.training + and rotary_pos_cos is not None + ): + assert self.layer_number in inference_context.key_value_memory_dict + assert inference_context.sequence_len_offset is not None + inference_key_memory, inference_value_memory = inference_context.key_value_memory_dict[self.layer_number] + output = self.flash_decode( + sequence_len_offset=sequence_len_offset, + query_layer=query, + key_layer=key, + value_layer=value, + inference_key_memory=inference_key_memory, + inference_value_memory=inference_value_memory, + rotary_cos=rotary_pos_cos, + rotary_sin=rotary_pos_sin, + ) + out = output.transpose(0, 1).contiguous() + context_layer = out.view(out.size(0), out.size(1), -1) + output, bias = self.linear_proj(context_layer) + return output, bias + + # Use latest mcore 0.13 API and forward-compatible with previous versions. + outputs = self._adjust_key_value_for_inference( + inference_context, + query, + key, + value, + rotary_pos_emb, + rotary_pos_cos, + rotary_pos_sin, + sequence_len_offset, + ) + + query, key, value, rotary_pos_emb, attn_mask_type = outputs[:5] + + if packed_seq_params is not None: + query = query.squeeze(1) + key = key.squeeze(1) + value = value.squeeze(1) + + # ================================================ + # relative positional embedding (rotary embedding) + # ================================================ + if rotary_pos_emb is not None and not self.config.flash_decode: + q_pos_emb, k_pos_emb = rotary_pos_emb + + if packed_seq_params is not None: + if packed_seq_params.cu_seqlens_q_padded is not None: + cu_seqlens_q = packed_seq_params.cu_seqlens_q_padded + else: + cu_seqlens_q = packed_seq_params.cu_seqlens_q + if packed_seq_params.cu_seqlens_kv_padded is not None: + cu_seqlens_kv = packed_seq_params.cu_seqlens_kv_padded + else: + cu_seqlens_kv = packed_seq_params.cu_seqlens_kv + else: + cu_seqlens_q = cu_seqlens_kv = None + + if q_pos_emb is not None: + # TODO VIJAY: simplify + if inference_context is None or inference_context.is_static_batching(): + query = apply_rotary_pos_emb_absolute(query, q_pos_emb, config=self.config, cu_seqlens=cu_seqlens_q) + else: + query = inference_context.apply_rotary_emb_query(query, q_pos_emb, self.config, cu_seqlens_q) + if k_pos_emb is not None: + key = apply_rotary_pos_emb_absolute(key, k_pos_emb, config=self.config, cu_seqlens=cu_seqlens_kv) + + # TODO, can apply positional embedding to value_layer so it has + # absolute positional embedding. + # otherwise, only relative positional embedding takes effect + # value_layer = apply_rotary_pos_emb(value_layer, k_pos_emb) + + # ================================== + # core attention computation + # ================================== + + if self.checkpoint_core_attention and self.training: + core_attn_out = self._checkpointed_attention_forward( + query, + key, + value, + attention_mask, + attn_mask_type=attn_mask_type, + attention_bias=attention_bias, + packed_seq_params=packed_seq_params, + ) + else: + if inference_context is None or inference_context.is_static_batching(): + # Static batching attention kernel. + core_attn_out = self.core_attention( + query, + key, + value, + attention_mask, + attn_mask_type=attn_mask_type, + attention_bias=attention_bias, + packed_seq_params=packed_seq_params, + ) + + else: + # Dynamic batching attention kernel. + q, k, v = (query, key, value) + cu_query_lengths, max_seqlen_q = inference_context.cu_query_lengths() + cu_kv_lengths, max_seqlen_k = inference_context.cu_kv_lengths() + + core_attn_out = self.flash_decode_and_prefill( + q, k, v, max_seqlen_q, max_seqlen_k, cu_query_lengths, cu_kv_lengths + ) + core_attn_out = core_attn_out.squeeze(0).unsqueeze(1) + core_attn_out = rearrange(core_attn_out, "s b h d -> s b (h d)") + + if packed_seq_params is not None and packed_seq_params.qkv_format == "thd": + # reshape to same output shape as unpacked case + # (t, np, hn) -> (t, b=1, h=np*hn) + # t is the pack size = sum (sq_i) + # note that batch is a dummy dimension in the packed case + core_attn_out = core_attn_out.reshape(core_attn_out.size(0), 1, -1) + + # ================= + # Output. [sq, b, h] + # ================= + + output, bias = self.linear_proj(core_attn_out) + + return output, bias diff --git a/code/RL_model/verl/verl_train/verl/models/mcore/qwen2_5_vl/model.py b/code/RL_model/verl/verl_train/verl/models/mcore/qwen2_5_vl/model.py new file mode 100644 index 0000000000000000000000000000000000000000..91118edfb6c4d96107249e1be921b720c0498fa0 --- /dev/null +++ b/code/RL_model/verl/verl_train/verl/models/mcore/qwen2_5_vl/model.py @@ -0,0 +1,372 @@ +# Copyright 2025 Bytedance Ltd. and/or its affiliates +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2024 Alibaba PAI Team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import logging + +import torch +from megatron.core import InferenceParams, mpu, tensor_parallel +from megatron.core.models.gpt.gpt_model import GPTModel + +# from .transformer_config import Qwen2VLTransformerConfig +from megatron.core.packed_seq_params import PackedSeqParams +from megatron.core.transformer import MegatronModule +from megatron.core.transformer.spec_utils import ModuleSpec +from megatron.core.transformer.transformer_config import TransformerConfig + +from verl.models.mcore.util import preprocess_packed_seqs + +from .attention import Qwen2_5VLSelfAttention +from .vision_model import Qwen2_5VisionModel + + +# Note: This is under development and may be missing features. +class Qwen2_5VLModel(MegatronModule): + """Qwen2.5VL multi-modal model. + + Args: + language_transformer_config (TransformerConfig): Transformer config for the language model. + language_transformer_layer_spec (ModuleSpec): Specifies module to use for transformer layers of the + language model. + language_vocab_size (int): Language model vocabulary size. + language_max_sequence_length (int): Language model maximum sequence length. This is used for + positional embedding. + vision_transformer_config (TransformerConfig): Transformer config for the vision model. + vision_transformer_layer_spec (ModuleSpec): Specifies module to use for transformer layers of the + vision model. + vision_projection_config (TransformerConfig): Config for the projection from vision model outputs to + language model inputs. + vision_projection_layer_spec (ModuleSpec): Specifies the module to use for the vision + projection. + vision_projection_type (str): Type of the vision projection to use. Default is a 2-layer MLP. + parallel_output (bool): Do not gather the outputs, keep them split across tensor parallel ranks. This + is typically True for training and False for inference. + language_rotary_percent (float): Percent of rotary dimension to use for rotary position embeddings + in the language model. Defaults to 1.0. + pre_process (bool): Include the embedding layer in the gpt decoder (used with pipeline parallelism). + Defaults to True. + post_process (bool): Include an output layer and a layernorm in the gpt decoder (used with pipeline + parallelism). Defaults to True. + add_encoder (bool): Construct the encoder module (used with pipeline parallelism). Defaults to True. + When we use pipelining, the encoder + will live on only a subset of the pipeline stages (specifically, only the first stage). + add_decoder (bool): Construct the decoder module (used with pipeline parallelism). Defaults to True. + When we use pipelining, the decoder + will live on only a subset of the pipeline stages (specifically, every stage after the first one). + img_h (int): The height of each image that the ViT will see. + img_w (int): The width of each image that the ViT will see. + patch_dim (int): The size of each patch side. + img_embedding_idx (int): Index in the language_embeddings tensor where image_embeddings should be + inserted. Defaults to 0. + """ + + def __init__( + self, + language_transformer_config: TransformerConfig, + language_transformer_layer_spec: ModuleSpec, + language_vocab_size: int, + language_max_sequence_length: int, + vision_transformer_config: TransformerConfig, + vision_transformer_layer_spec: ModuleSpec, + vision_projection_config: TransformerConfig, + vision_projection_layer_spec: ModuleSpec, + vision_projection_type: str = "mlp", + parallel_output: bool = True, + language_rotary_percent: float = 1.0, + pre_process: bool = True, + post_process: bool = True, + add_encoder: bool = True, + add_decoder: bool = True, + language_rotary_base: int = 10000, + fp16_lm_cross_entropy: bool = False, + language_share_embeddings_and_output_weights: bool = False, + image_token_id: int = 151655, + video_token_id: int = 151656, + ) -> None: + super().__init__(config=language_transformer_config) + + # patch self_attention to use qwen2_5_vl attention + vision_transformer_layer_spec.submodules.self_attention.module = Qwen2_5VLSelfAttention + for layer_spec in language_transformer_layer_spec.layer_specs: + layer_spec.submodules.self_attention.module = Qwen2_5VLSelfAttention + + logging.getLogger(__name__).warning("Qwen2VL model is under development and may be missing features.") + + self.pre_process = pre_process + self.post_process = post_process + self.add_encoder = add_encoder + self.add_decoder = add_decoder + + self.encoder_hidden_state = None + self.vision_model = None + self.vision_projection = None + self.language_model = None + self.image_token_id = image_token_id + self.video_token_id = video_token_id + + self.square_merge_size = vision_projection_config.ffn_hidden_size // vision_transformer_config.hidden_size + + # This attribute is needed to check if an all-reduce is required + # on the word embeddings inside `finalize_model_grads._allreduce_word_embedding_grads`. + self.share_embeddings_and_output_weights = False + if self.pre_process: + self.vision_model = Qwen2_5VisionModel( + vision_transformer_config, + vision_transformer_layer_spec, + vision_projection_config, + vision_projection_layer_spec, + projection_type=vision_projection_type, + pre_process=True, + post_process=True, + ) + + self.language_model = GPTModel( + config=language_transformer_config, + transformer_layer_spec=language_transformer_layer_spec, + vocab_size=language_vocab_size, + max_sequence_length=language_max_sequence_length, + parallel_output=parallel_output, + position_embedding_type="mrope", + rotary_percent=language_rotary_percent, + pre_process=self.pre_process, + post_process=self.post_process, + rotary_base=language_rotary_base, + fp16_lm_cross_entropy=fp16_lm_cross_entropy, + share_embeddings_and_output_weights=language_share_embeddings_and_output_weights, + scatter_embedding_sequence_parallel=False, + ) + assert mpu.get_context_parallel_world_size() <= 1, "please use mbridge for qwen2_5_vl with context parallelism" + self.share_embeddings_and_output_weights = self.language_model.share_embeddings_and_output_weights + + def shared_embedding_or_output_weight(self): + """This is a convenience method to surface the language model's word embeddings, which is + necessary for `finalize_model_grads._allreduce_word_embedding_grads`.""" + if self.add_decoder: + return self.language_model.shared_embedding_or_output_weight() + return None + + def set_input_tensor(self, input_tensor) -> None: + # This is usually handled in schedules.py but some inference code still + # gives us non-lists or None + if not isinstance(input_tensor, list): + input_tensor = [input_tensor] + assert len(input_tensor) == 1, "input_tensor should only be length 1 for Qwen2VL" + + if self.pre_process: + self.encoder_hidden_state = input_tensor[0] + else: + self.language_model.set_input_tensor(input_tensor[0]) + + def freeze(self, freeze_language_model: bool, freeze_vision_model: bool, freeze_vision_projection: bool): + """Freeze model modules. + + Make specific modules non-trainable by setting requires_grad to False for the module's parameters. + + Args: + freeze_language_model (bool): Freeze the language model module. + freeze_vision_model (bool): Freeze the vision model module. + freeze_vision_projection (bool): Freeze the vision projection module. + """ + modules = [] + if freeze_language_model and self.language_model is not None: + modules.append(self.language_model) + if freeze_vision_model and self.vision_model is not None: + modules.append(self.vision_model) + if freeze_vision_projection and self.vision_projection is not None: + modules.append(self.vision_projection) + + for module in modules: + for param in module.parameters(): + param.requires_grad = False + + def forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + attention_mask: torch.Tensor = None, + labels: torch.Tensor = None, + inference_params: InferenceParams = None, + packed_seq_params: PackedSeqParams = None, + extra_block_kwargs: dict = None, + pixel_values: torch.Tensor = None, + pixel_values_videos: torch.Tensor = None, + image_grid_thw: torch.Tensor = None, + video_grid_thw: torch.Tensor = None, + **kwargs, + ) -> torch.Tensor: + """Forward function of the Qwen2VL model. + ### there is a workaround for supporting sequence packing with context parallelism + # cp split with sequence packing will make model lose vision token information, so we need to keep + # the original input_ids and pack them after vision embedding is calculated, + # cooporate with verl's models/mcore/model_forward.py + # pack the combined_embeddings to thd here, we check if packed_seq_params is None to determine if + # we need to pack the combined_embeddings to thd + # this function needs the position_ids and attention_mask in BSHD format, no matter use packed_seq or not + + Args: + image_data (torch.Tensor): input image of shape [total_thw_size, n_features]. + input_ids (torch.Tensor): input text ids [batch, text_seq_len]. + position_ids (torch.Tensor): input text position ids [batch, text_seq_len]. + attention_mask (torch.Tensor): attention mask for the language model [batch, 1, combined_seq_len, + combined_seq_len]. + labels (torch.Tensor): Optional target text labels [batch, combined_seq_len]. + inference_params (InferenceParams): Inference-time parameters including KV cache. + + video_start_index: + 0 -- all video + len(video_seq) -- all image + others -- mixture + *_input_mask: should not be None in the first PP stage + Returns: + output (torch.Tensor): Loss of shape [b, s] if labels are provided, otherwise logits of shape + [b, s, vocab_size]. + """ + video_start_index = 0 + vision_grid_thw = None + vision_data = None + if image_grid_thw is not None: + image_mask = input_ids == self.image_token_id + vision_grid_thw = image_grid_thw + vision_data = pixel_values + video_start_index = image_mask.sum().item() + if video_grid_thw is not None: + video_mask = input_ids == self.video_token_id + if vision_grid_thw is not None: + vision_grid_thw = torch.cat([vision_grid_thw, video_grid_thw], dim=0) + vision_data = torch.cat([vision_data, pixel_values_videos], dim=0) + else: + vision_grid_thw = video_grid_thw + vision_data = pixel_values_videos + use_inference_kv_cache = ( + inference_params is not None and "image_tokens_count" in inference_params.key_value_memory_dict + ) + if use_inference_kv_cache: + raise NotImplementedError() + + if self.pre_process: + vision_embeds = None + if vision_grid_thw is not None and vision_grid_thw.shape[0] > 0: + vision_embeds = self.vision_model( + vision_data=vision_data, # If None, vision model should use intermediate outputs (EPP > 1) + grid_thw=vision_grid_thw, # should provided in each EPP stage + ) + + # If running inference, the language model KV cache will be updated for image token positions. + # Here we store the image tokens sequence length, which can be used as an offset to the KV cache later. + if inference_params is not None: + raise NotImplementedError() + # inference_params.key_value_memory_dict["image_tokens_count"] = ( + # vision_embeddings.shape[0] + # ) + + # If running inference, we can skip image token computation if they were computed already earlier + # for this sample. + if use_inference_kv_cache: + language_embeddings: torch.Tensor = self.language_model.embedding( + input_ids=input_ids, + position_ids=None, # NOTE: disable + ) # [text_seq_len, b, h_language] + # NOTE: why not cat here? is it the combined embeddings useless? + combined_embeddings = language_embeddings + elif vision_embeds is not None: + if video_start_index == 0: + image_embeds = None + video_embeds = vision_embeds + elif video_start_index == vision_embeds.shape[0]: + image_embeds = vision_embeds + video_embeds = None + elif 0 < video_start_index < vision_embeds.shape[0]: + image_embeds = vision_embeds[:video_start_index] + video_embeds = vision_embeds[video_start_index:] + else: + raise ValueError( + f"Expect video token start index in range [0, {vision_embeds.shape[0]}], but got " + f"{video_start_index}" + ) + + combined_embeddings = self.language_model.embedding( + input_ids=input_ids, + position_ids=None, # NOTE: disable + ) # [text_seq_len, b, h_language] + + if image_embeds is not None or video_embeds is not None: + combined_embeddings = combined_embeddings.transpose(0, 1).contiguous() + if image_embeds is not None: + image_mask = (input_ids == self.image_token_id).contiguous() + if image_mask.sum() > 0: + combined_embeddings = combined_embeddings.clone() + combined_embeddings[image_mask] = image_embeds.to( + dtype=combined_embeddings.dtype, device=combined_embeddings.device + ) + if video_embeds is not None: + video_mask = (input_ids == self.video_token_id).contiguous() + if video_mask.sum() > 0: + combined_embeddings = combined_embeddings.clone() + combined_embeddings[video_mask] = video_embeds.to( + dtype=combined_embeddings.dtype, device=combined_embeddings.device + ) + combined_embeddings = combined_embeddings.transpose(0, 1).contiguous() + + else: + combined_embeddings = self.language_model.embedding( + input_ids=input_ids, + position_ids=None, # NOTE: disable + ) # [text_seq_len, b, h_language] + + if packed_seq_params is not None: + combined_embeddings = ( + preprocess_packed_seqs( + combined_embeddings.transpose(0, 1).contiguous(), attention_mask, pre_process=True + )[0] + .transpose(0, 1) + .contiguous() + ) + if self.config.sequence_parallel: + combined_embeddings = tensor_parallel.scatter_to_sequence_parallel_region(combined_embeddings) + combined_embeddings = combined_embeddings.contiguous() + else: + combined_embeddings = None + from .rope_utils import get_rope_index + + # BSHD + position_ids, _ = get_rope_index( + input_ids, + image_grid_thw=image_grid_thw, + video_grid_thw=video_grid_thw, + attention_mask=attention_mask, + ) + # THD + if packed_seq_params is not None: + position_ids = ( + preprocess_packed_seqs(position_ids.permute(1, 2, 0), attention_mask, pre_process=True)[0] + .permute(2, 0, 1) + .contiguous() + ) + attention_mask = None + + output = self.language_model( + input_ids=None, + position_ids=position_ids, # None in encoder + attention_mask=attention_mask, # None in encoder + decoder_input=combined_embeddings, # only not None in the first decoder PP stage + labels=labels, # only not None in the last decoder PP stage + # inference_params=inference_params, # currently always None + packed_seq_params=packed_seq_params, # currently always None + **(extra_block_kwargs or {}), + **kwargs, + ) + + return output diff --git a/code/RL_model/verl/verl_train/verl/models/mcore/qwen2_5_vl/rope_utils.py b/code/RL_model/verl/verl_train/verl/models/mcore/qwen2_5_vl/rope_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..fadc74daabe852f9e4561fe9981534815e5a148d --- /dev/null +++ b/code/RL_model/verl/verl_train/verl/models/mcore/qwen2_5_vl/rope_utils.py @@ -0,0 +1,266 @@ +# Copyright 2025 Bytedance Ltd. and/or its affiliates +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2024 Alibaba PAI Team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from __future__ import annotations + +import logging +from typing import Optional + +import torch +from megatron.core.models.common.embeddings.rope_utils import * +from megatron.core.models.common.embeddings.rope_utils import _apply_rotary_pos_emb_bshd +from torch import Tensor + +logger = logging.getLogger(__name__) + + +# Slightly modified from Qwen2VLForConditionalGeneration.get_rope_index +def get_rope_index( + input_ids: Optional[torch.LongTensor] = None, + image_grid_thw: Optional[torch.LongTensor] = None, + video_grid_thw: Optional[torch.LongTensor] = None, + second_per_grid_ts: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, +): + """ + Calculate the 3D rope index based on image and video's temporal, height and width in LLM. + + Explanation: + + Each embedding sequence contains vision embedding and text embedding or just contains text embedding. + + For pure text embedding sequence, the rotary position embedding has no difference with modern LLMs. + + Examples: + + input_ids: [T T T T T], here T is for text. + temporal position_ids: [0, 1, 2, 3, 4] + height position_ids: [0, 1, 2, 3, 4] + width position_ids: [0, 1, 2, 3, 4] + + For vision and text embedding sequence, we calculate 3D rotary position embedding for vision part + and 1D rotary position embedding for text part. + + Examples: + + Temporal (Time): 3 patches, representing different segments of the video in time. + Height: 2 patches, dividing each frame vertically. + Width: 2 patches, dividing each frame horizontally. + We also have some important parameters: + fps (Frames Per Second): The video's frame rate, set to 1. This means one frame is processed each + second. + tokens_per_second: This is a crucial parameter. It dictates how many "time-steps" or "temporal + tokens" are conceptually packed into a one-second interval of the video. + In this case, we have 25 tokens per second. So each second of the video will be + represented with 25 separate time points. It essentially defines the temporal + granularity. + temporal_patch_size: The number of frames that compose one temporal patch. Here, it's 2 frames. + interval: The step size for the temporal position IDs, calculated as tokens_per_second * + temporal_patch_size / fps. In this case, 25 * 2 / 1 = 50. This means that each temporal patch will be + have a difference of 50 in the temporal position IDs. + input_ids: [V V V V V V V V V V V V T T T T T], here V is for vision. + vision temporal position_ids: [0, 0, 0, 0, 50, 50, 50, 50, 100, 100, 100, 100] + vision height position_ids: [0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1] + vision width position_ids: [0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1] + text temporal position_ids: [101, 102, 103, 104, 105] + text height position_ids: [101, 102, 103, 104, 105] + text width position_ids: [101, 102, 103, 104, 105] + Here we calculate the text start position_ids as the max vision position_ids plus 1. + + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): + The temporal, height and width of feature shape of each image in LLM. + video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): + The temporal, height and width of feature shape of each video in LLM. + second_per_grid_ts (`torch.Tensor` of shape `(num_videos)`, *optional*): + The time interval (in seconds) for each grid along the temporal dimension in the 3D position IDs. + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + Returns: + position_ids (`torch.LongTensor` of shape `(3, batch_size, sequence_length)`) + mrope_position_deltas (`torch.Tensor` of shape `(batch_size)`) + """ + spatial_merge_size = 2 + tokens_per_second = 2 + image_token_id = 151655 + video_token_id = 151656 + vision_start_token_id = 151652 + mrope_position_deltas = [] + if input_ids is not None and (image_grid_thw is not None or video_grid_thw is not None): + total_input_ids = input_ids + if attention_mask is None: + attention_mask = torch.ones_like(total_input_ids) + position_ids = torch.ones( + 3, + input_ids.shape[0], + input_ids.shape[1], + dtype=input_ids.dtype, + device=input_ids.device, + ) + image_index, video_index = 0, 0 + attention_mask = attention_mask.to(total_input_ids.device) + for i, input_ids in enumerate(total_input_ids): + input_ids = input_ids[attention_mask[i] == 1] + image_nums, video_nums = 0, 0 + vision_start_indices = torch.argwhere(input_ids == vision_start_token_id).squeeze(1) + vision_tokens = input_ids[vision_start_indices + 1] + image_nums = (vision_tokens == image_token_id).sum() + video_nums = (vision_tokens == video_token_id).sum() + input_tokens = input_ids.tolist() + llm_pos_ids_list: list = [] + st = 0 + remain_images, remain_videos = image_nums, video_nums + for _ in range(image_nums + video_nums): + if image_token_id in input_tokens and remain_images > 0: + ed_image = input_tokens.index(image_token_id, st) + else: + ed_image = len(input_tokens) + 1 + if video_token_id in input_tokens and remain_videos > 0: + ed_video = input_tokens.index(video_token_id, st) + else: + ed_video = len(input_tokens) + 1 + if ed_image < ed_video: + t, h, w = ( + image_grid_thw[image_index][0], + image_grid_thw[image_index][1], + image_grid_thw[image_index][2], + ) + second_per_grid_t = 0 + image_index += 1 + remain_images -= 1 + ed = ed_image + + else: + t, h, w = ( + video_grid_thw[video_index][0], + video_grid_thw[video_index][1], + video_grid_thw[video_index][2], + ) + if second_per_grid_ts is not None: + second_per_grid_t = second_per_grid_ts[video_index] + else: + second_per_grid_t = 1.0 + video_index += 1 + remain_videos -= 1 + ed = ed_video + llm_grid_t, llm_grid_h, llm_grid_w = ( + t.item(), + h.item() // spatial_merge_size, + w.item() // spatial_merge_size, + ) + text_len = ed - st + + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx) + + range_tensor = torch.arange(llm_grid_t).view(-1, 1) + expanded_range = range_tensor.expand(-1, llm_grid_h * llm_grid_w) + + time_tensor = expanded_range * second_per_grid_t * tokens_per_second + + time_tensor_long = time_tensor.long() + t_index = time_tensor_long.flatten() + + h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(llm_grid_t, -1, llm_grid_w).flatten() + w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(llm_grid_t, llm_grid_h, -1).flatten() + llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + text_len + st_idx) + st = ed + llm_grid_t * llm_grid_h * llm_grid_w + + if st < len(input_tokens): + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + text_len = len(input_tokens) - st + llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx) + + llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1) + position_ids[..., i, attention_mask[i] == 1] = llm_positions.to(position_ids.device) + mrope_position_deltas.append(llm_positions.max() + 1 - len(total_input_ids[i])) + mrope_position_deltas = torch.tensor(mrope_position_deltas, device=input_ids.device).unsqueeze(1) + return position_ids, mrope_position_deltas + else: + if attention_mask is not None: + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + position_ids = position_ids.unsqueeze(0).expand(3, -1, -1).to(attention_mask.device) + max_position_ids = position_ids.max(0, keepdim=False)[0].max(-1, keepdim=True)[0] + mrope_position_deltas = max_position_ids + 1 - attention_mask.shape[-1] + else: + position_ids = ( + torch.arange(input_ids.shape[1], device=input_ids.device) + .view(1, 1, -1) + .expand(3, input_ids.shape[0], -1) + ) + mrope_position_deltas = torch.zeros( + [input_ids.shape[0], 1], + device=input_ids.device, + dtype=input_ids.dtype, + ) + + return position_ids, mrope_position_deltas + + +def apply_rotary_pos_emb_thd_absolute( + t: Tensor, cu_seqlens: Tensor, freqs: Tensor, rotary_interleaved: bool = False +) -> Tensor: + """A baseline implementation of applying RoPE for `thd` format. + + Args: + t (Tensor): Input tensor T is of shape [t, h, d] + cu_seqlens(Tensor): Cumulative sum of sequence lengths in a batch for `t`, + with shape [b + 1] and dtype torch.int32. + freqs (Tensor): Rotary Positional embedding tensor freq is of shape [max_s, 1, 1, d] + + Returns: + Tensor: Shape [t, h, d]. The input tensor after applying RoPE. + """ + return _apply_rotary_pos_emb_bshd(t[:, None], freqs, rotary_interleaved=rotary_interleaved).squeeze(1) + + +def apply_rotary_pos_emb_absolute( + t: Tensor, + freqs: Tensor, + config: TransformerConfig, + cu_seqlens: Optional[Tensor] = None, +): + """ + Reroute to the appropriate apply_rotary_pos_emb function depending on + bshd (conventional) / thd (packed seq) format + + In Qwen2-VL, the shape of freqs is (seq_length, bs, 1, 2 * dim) instead of [max_seqlen, 1, 1, 2 * dim] + """ + + if config.apply_rope_fusion: + if cu_seqlens is None: + # NOTE: TE backends do not support mRoPE in bshd format when bs > 1 + if freqs.shape[1] > 1: + return _apply_rotary_pos_emb_bshd(t, freqs, rotary_interleaved=config.rotary_interleaved) + else: + return fused_apply_rotary_pos_emb(t, freqs) + else: + # NOTE: as expected, thd format can use bshd + return fused_apply_rotary_pos_emb(t[:, None], freqs).squeeze(1) + else: + if cu_seqlens is None: + return _apply_rotary_pos_emb_bshd(t, freqs, rotary_interleaved=config.rotary_interleaved) + else: + return apply_rotary_pos_emb_thd_absolute(t, cu_seqlens, freqs, rotary_interleaved=config.rotary_interleaved) diff --git a/code/RL_model/verl/verl_train/verl/models/mcore/qwen2_5_vl/vision_config.py b/code/RL_model/verl/verl_train/verl/models/mcore/qwen2_5_vl/vision_config.py new file mode 100644 index 0000000000000000000000000000000000000000..0631c90f61605f2ed0d659c8836f01c451e694a6 --- /dev/null +++ b/code/RL_model/verl/verl_train/verl/models/mcore/qwen2_5_vl/vision_config.py @@ -0,0 +1,85 @@ +# Copyright 2025 Bytedance Ltd. and/or its affiliates +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2024 Alibaba PAI Team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +from megatron.core import parallel_state +from megatron.core.transformer import TransformerConfig + + +def get_vision_model_config(config: TransformerConfig) -> TransformerConfig: + # Given a Transformer Config from decoder, build vision encoder config + # diff: out_hidden_size & intermediate_size + + # mlp: hidden_size -> intermediate_size -> embed_dim, silu + # NOTE: here we provide a workaround to solve the wrong layer amount when VPP of decoder is on + if config.num_layers in [28, 36]: + config.ffn_hidden_size = 3420 + else: + config.ffn_hidden_size = 3456 + + if parallel_state.get_virtual_pipeline_model_parallel_world_size() is not None: + config.num_layers = 32 * parallel_state.get_virtual_pipeline_model_parallel_world_size() # depth + else: + config.num_layers = 32 # depth + config.num_attention_heads = 16 # num_heads + config.add_bias_linear = True # all nn.Linear has bias (MLP, attn) + config.add_qkv_bias = True # qkv_proj in attn has bias + config.hidden_size = 1280 # hidden_size + config.hidden_dropout = 0.0 + config.attention_dropout = 0.0 + + # config.gated_linear_unit = False # no gated + # config.activation_func = quick_gelu # hidden_act + config.kv_channels = config.hidden_size // config.num_attention_heads + config.num_query_groups = config.num_attention_heads # no GQA + config.layernorm_zero_centered_gamma = False # False + config.apply_query_key_layer_scaling = False # factor=math.sqrt(head_dim) + config.bias_activation_fusion = False # no swiglu, set false + config.bias_dropout_fusion = False # no dropout, set false + config.attention_softmax_in_fp32 = True # use True + # config.normalization = 'LayerNorm' # use RMSNorm + config.seq_length = 1 + + config.tp_comm_overlap = False + config.sequence_parallel = False + config.temporal_patch_size = 2 + config.patch_size = 14 + config.in_channels = 3 + config.spatial_merge_size = 2 + + config.fullatt_block_indexes = [7, 15, 23, 31] + config._qwen2_5_vl_window_size = 112 + return config + + +def get_vision_projection_config( + config: TransformerConfig, embed_dim: int, spatial_merge_size: int +) -> TransformerConfig: + # merger: + # context_dim = hidden_size * merge_size**2 + # out_hidden_size = hidden_size + # context_dim -> context_dim -> out_hidden_size + # MLP: + # input_size -> ffn_hidden_size -> hidden_size + # spec: LN -> Linear(bias=True) -> GELU -> Linear(bias=True) + config.gated_linear_unit = False + config.bias_activation_fusion = False + config.add_bias_linear = True + config.ffn_hidden_size = embed_dim * (spatial_merge_size**2) + config.activation_func = torch.nn.functional.gelu + config.tp_comm_overlap = False + config.sequence_parallel = False + return config diff --git a/code/RL_model/verl/verl_train/verl/models/mcore/qwen2_5_vl/vision_model.py b/code/RL_model/verl/verl_train/verl/models/mcore/qwen2_5_vl/vision_model.py new file mode 100644 index 0000000000000000000000000000000000000000..06b4fd328064a1f50b32a7009aec8ecef573656e --- /dev/null +++ b/code/RL_model/verl/verl_train/verl/models/mcore/qwen2_5_vl/vision_model.py @@ -0,0 +1,309 @@ +# Copyright 2025 Bytedance Ltd. and/or its affiliates +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2024 Alibaba PAI Team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional + +import torch +from megatron.core import InferenceParams +from megatron.core.models.common.vision_module.vision_module import VisionModule +from megatron.core.models.vision.multimodal_projector import MultimodalProjector +from megatron.core.packed_seq_params import PackedSeqParams +from megatron.core.transformer.enums import ModelType +from megatron.core.transformer.spec_utils import ModuleSpec +from megatron.core.transformer.transformer_config import TransformerConfig +from torch import nn +from torch.nn import functional as F + +from .vision_transformer_block import Qwen2_5VisionTransformerBlock as TransformerBlock + + +# copied from https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py +class PatchEmbed(nn.Module): + def __init__( + self, + patch_size: int = 14, + temporal_patch_size: int = 2, + in_channels: int = 3, + embed_dim: int = 1152, + ) -> None: + super().__init__() + self.patch_size = patch_size + self.temporal_patch_size = temporal_patch_size + self.in_channels = in_channels + self.embed_dim = embed_dim + + kernel_size = [temporal_patch_size, patch_size, patch_size] + self.proj = nn.Conv3d(in_channels, embed_dim, kernel_size=kernel_size, stride=kernel_size, bias=False) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + target_dtype = self.proj.weight.dtype + hidden_states = hidden_states.view( + -1, self.in_channels, self.temporal_patch_size, self.patch_size, self.patch_size + ) + hidden_states = self.proj(hidden_states.to(dtype=target_dtype)).view(-1, self.embed_dim) + return hidden_states + + +# copied from https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py +class VisionRotaryEmbedding(nn.Module): + def __init__(self, dim: int, theta: float = 10000.0) -> None: + super().__init__() + inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + + def forward(self, seqlen: int) -> torch.Tensor: + seq = torch.arange(seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype) + freqs = torch.outer(seq, self.inv_freq) + return freqs.float() + + +class Qwen2_5VisionModel(VisionModule): + """Qwen2.5 ViT vision model. + + Args: + transformer_config (TransformerConfig): Transformer config. + transformer_layer_spec (ModuleSpec): Specifies module to use for transformer layers. + ln_pre_impl (ModuleSpec or type): Specifies the layer norm type to use for ln_pre. + add_class_token (bool, optional): Include a class token. Defaults to True. + class_token_len (int): Class token length. Defaults to 1 but 8 may be faster. + patch_dim (int): Image patch size. + img_h (int): Input image height. + img_w (int): Input image width. + """ + + def __init__( + self, + transformer_config: TransformerConfig, + transformer_layer_spec: ModuleSpec, + projection_config: TransformerConfig, + projection_layer_spec: ModuleSpec, + projection_type: str = "mlp", + pre_process: bool = True, + post_process: bool = False, + ) -> None: + super().__init__(config=transformer_config) + + self.spatial_merge_size = transformer_config.spatial_merge_size + + embed_dim = transformer_config.hidden_size + num_heads = transformer_config.num_attention_heads + temporal_patch_size = transformer_config.temporal_patch_size + patch_size = transformer_config.patch_size + in_channels = transformer_config.in_channels + + self.patch_size = transformer_config.patch_size + self.fullatt_block_indexes = transformer_config.fullatt_block_indexes + self.window_size = transformer_config._qwen2_5_vl_window_size + self.spatial_merge_unit = self.spatial_merge_size * self.spatial_merge_size + + self.max_sequence_length = transformer_config.seq_length + self.patch_embed = PatchEmbed( + patch_size=patch_size, + temporal_patch_size=temporal_patch_size, + in_channels=in_channels, + embed_dim=embed_dim, + ) + + head_dim = embed_dim // num_heads + self.rotary_pos_emb = VisionRotaryEmbedding(head_dim // 2) + + self.model_type = ModelType.encoder_or_decoder + self.pre_process = pre_process + self.post_process = post_process + + # Transformer layers. + # TODO: Follow-up changes will make pre and post_process configurable. They are needed for supporting + # pipeline parallelism. + # NOTE: a final layer norm and/or linear layer present in some implementations are omitted here. + self.decoder = TransformerBlock( + config=transformer_config, + spec=transformer_layer_spec, + pre_process=self.pre_process, + post_process=self.post_process, + post_layer_norm=True, + ) + + self.merge_hidden_size = projection_config.ffn_hidden_size + self.square_merge_size = self.merge_hidden_size // embed_dim + + if self.post_process: + self.projection = MultimodalProjector( + projection_config, projection_layer_spec, projection_type, projection_config.ffn_hidden_size + ) + else: + self.projection = None + + self.input_tensor = None + + def set_input_tensor(self, input_tensor: torch.Tensor) -> None: + """Sets input tensor to the model. + + Args: + input_tensor (Tensor): Sets the input tensor for the model. + """ + if self.pre_process: # always True + self.input_tensor = input_tensor + else: + raise NotImplementedError() + + def rot_pos_emb(self, grid_thw): + pos_ids = [] + for t, h, w in grid_thw: + hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w) + hpos_ids = hpos_ids.reshape( + h // self.spatial_merge_size, + self.spatial_merge_size, + w // self.spatial_merge_size, + self.spatial_merge_size, + ) + hpos_ids = hpos_ids.permute(0, 2, 1, 3) + hpos_ids = hpos_ids.flatten() + + wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1) + wpos_ids = wpos_ids.reshape( + h // self.spatial_merge_size, + self.spatial_merge_size, + w // self.spatial_merge_size, + self.spatial_merge_size, + ) + wpos_ids = wpos_ids.permute(0, 2, 1, 3) + wpos_ids = wpos_ids.flatten() + pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1)) + pos_ids = torch.cat(pos_ids, dim=0).to(grid_thw.device) + max_grid_size = grid_thw[:, 1:].max() + rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size).to(grid_thw.device) + rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1) + return rotary_pos_emb + + def get_window_index(self, grid_thw): + window_index: list = [] + cu_window_seqlens: list = [0] + window_index_id = 0 + vit_merger_window_size = self.window_size // self.spatial_merge_size // self.patch_size + + for grid_t, grid_h, grid_w in grid_thw: + llm_grid_h, llm_grid_w = ( + grid_h // self.spatial_merge_size, + grid_w // self.spatial_merge_size, + ) + index = torch.arange(grid_t * llm_grid_h * llm_grid_w).reshape(grid_t, llm_grid_h, llm_grid_w) + pad_h = vit_merger_window_size - llm_grid_h % vit_merger_window_size + pad_w = vit_merger_window_size - llm_grid_w % vit_merger_window_size + num_windows_h = (llm_grid_h + pad_h) // vit_merger_window_size + num_windows_w = (llm_grid_w + pad_w) // vit_merger_window_size + index_padded = F.pad(index, (0, pad_w, 0, pad_h), "constant", -100) + index_padded = index_padded.reshape( + grid_t, + num_windows_h, + vit_merger_window_size, + num_windows_w, + vit_merger_window_size, + ) + index_padded = index_padded.permute(0, 1, 3, 2, 4).reshape( + grid_t, + num_windows_h * num_windows_w, + vit_merger_window_size, + vit_merger_window_size, + ) + seqlens = (index_padded != -100).sum([2, 3]).reshape(-1) + index_padded = index_padded.reshape(-1) + index_new = index_padded[index_padded != -100] + window_index.append(index_new + window_index_id) + cu_seqlens_tmp = seqlens.cumsum(0) * self.spatial_merge_unit + cu_window_seqlens[-1] + cu_window_seqlens.extend(cu_seqlens_tmp.tolist()) + window_index_id += (grid_t * llm_grid_h * llm_grid_w).item() + window_index = torch.cat(window_index, dim=0) + + return window_index, cu_window_seqlens + + def forward( + self, + vision_data: Optional[torch.Tensor], + grid_thw: torch.Tensor, + inference_params: Optional[InferenceParams] = None, + extra_block_kwargs: dict = None, + ) -> torch.Tensor: + """Forward function of the Qwen2 Vision Model. This function passes the input tensors + through the embedding layer and then the transformer. + + Args: + x (torch.Tensor): input image/video data of shape [n_tokens, n_dims] + grid_thw (torch.Tensor): the size tensor indicates grid size of each image/frame + packed_seq_params (PackedSeqParams): parameters to build attention mask in the backend + + Returns: + x (torch.Tensor): output after final transformer block of shape [b, s, h]. + """ + assert grid_thw is not None + assert self.input_tensor is None + assert inference_params is None + + # Rotary positional embeddings (embedding is None for PP intermediate devices) + vision_data = self.patch_embed(vision_data) + window_index, cu_window_seqlens = self.get_window_index(grid_thw) + cu_window_seqlens = torch.tensor( + cu_window_seqlens, + device=vision_data.device, + dtype=torch.int32, + ) + cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens) + + seq_len, _ = vision_data.size() + vision_data = vision_data.reshape(seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1) + vision_data = vision_data[window_index, :, :] + vision_data = vision_data.reshape(seq_len, 1, -1) + + rotary_pos_emb = self.rot_pos_emb(grid_thw) + rotary_pos_emb = rotary_pos_emb.reshape(seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1) + rotary_pos_emb = rotary_pos_emb[window_index, :, :] + rotary_pos_emb = rotary_pos_emb.reshape(seq_len, 1, 1, -1).repeat(1, 1, 1, 2) + + hidden_states = self.decoder( + hidden_states=vision_data, + attention_mask=None, + inference_params=inference_params, + rotary_pos_emb=rotary_pos_emb, + packed_seq_params=self.build_packed_seq_params(None, cu_window_seqlens), + packed_seq_params_full=self.build_packed_seq_params(grid_thw), + fullatt_block_indexes=self.fullatt_block_indexes, + **(extra_block_kwargs or {}), + ) + + hidden_states = self.projection(hidden_states.view(-1, self.merge_hidden_size)) + reverse_indices = torch.argsort(window_index) + return hidden_states[reverse_indices, :] + + def build_packed_seq_params( + self, + grid_thw: Optional[torch.Tensor], + cu_seqlens: Optional[torch.Tensor] = None, + ) -> PackedSeqParams: + # NOTE: each frame is a sequence (rather than each grid) + if grid_thw is not None: + seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]) + cu_seqlens = seqlens.cumsum(dim=0) + cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0).int() + else: + seqlens = cu_seqlens[1:] - cu_seqlens[:-1] + + max_seqlen_q = seqlens.max() + return PackedSeqParams( + cu_seqlens_q=cu_seqlens, + cu_seqlens_kv=cu_seqlens, + qkv_format="thd", + max_seqlen_q=max_seqlen_q, + max_seqlen_kv=max_seqlen_q, + ) diff --git a/code/RL_model/verl/verl_train/verl/models/mcore/qwen2_5_vl/vision_transformer_block.py b/code/RL_model/verl/verl_train/verl/models/mcore/qwen2_5_vl/vision_transformer_block.py new file mode 100644 index 0000000000000000000000000000000000000000..8f765a0ff632f65771d1b1d19a4b0f052ee6ec37 --- /dev/null +++ b/code/RL_model/verl/verl_train/verl/models/mcore/qwen2_5_vl/vision_transformer_block.py @@ -0,0 +1,265 @@ +# Copyright 2025 Bytedance Ltd. and/or its affiliates +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2024 Alibaba PAI Team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from megatron.core.transformer.transformer_block import * + + +class Qwen2_5VisionTransformerBlock(TransformerBlock): + def _checkpointed_forward( + self, + hidden_states: Tensor, + attention_mask: Tensor, + context: Tensor, + context_mask: Tensor, + rotary_pos_emb: Tensor, + attention_bias: Tensor, + packed_seq_params: PackedSeqParams, + packed_seq_params_full: PackedSeqParams, + fullatt_block_indexes, + ): + """Forward method with activation checkpointing.""" + + def custom(start: int, end: int): + def custom_forward(hidden_states, attention_mask, context, context_mask, rotary_pos_emb): + for index in range(start, end): + if index in fullatt_block_indexes: + packed_seq_params_now = packed_seq_params_full + else: + packed_seq_params_now = packed_seq_params + layer = self._get_layer(index) + hidden_states, context = layer( + hidden_states=hidden_states, + attention_mask=attention_mask, + context=context, + context_mask=context_mask, + rotary_pos_emb=rotary_pos_emb, + attention_bias=attention_bias, + inference_context=None, + packed_seq_params=packed_seq_params_now, + ) + return hidden_states, context + + return custom_forward + + def checkpoint_handler(forward_func): + """Determines whether to use the `te_checkpoint` or `tensor_parallel.checkpoint`""" + if self.config.fp8: + return te_checkpoint( + forward_func, + self.config.distribute_saved_activations, + tensor_parallel.random.get_cuda_rng_tracker, + parallel_state.get_tensor_model_parallel_group(), + hidden_states, + attention_mask, + context, + context_mask, + rotary_pos_emb, + ) + else: + return tensor_parallel.checkpoint( + forward_func, + self.config.distribute_saved_activations, + hidden_states, + attention_mask, + context, + context_mask, + rotary_pos_emb, + ) + + if self.config.recompute_method == "uniform": + # Uniformly divide the total number of Transformer layers and checkpoint + # the input activation of each divided chunk. + # A method to further reduce memory usage reducing checkpoints. + layer_idx = 0 + while layer_idx < self.num_layers_per_pipeline_rank: + hidden_states, context = checkpoint_handler( + custom(layer_idx, layer_idx + self.config.recompute_num_layers) + ) + + layer_idx += self.config.recompute_num_layers + + elif self.config.recompute_method == "block": + # Checkpoint the input activation of only a set number of individual + # Transformer layers and skip the rest. + # A method fully use the device memory removing redundant re-computation. + recompute_skip_num_layers = 0 + for layer_idx in range(self.num_layers_per_pipeline_rank): + # Skip recomputation when input grad computation is not needed. + # Need to have at least one input tensor with gradient computation + # for re-enterant autograd engine. + if self.config.fp8 and not hidden_states.requires_grad: + recompute_skip_num_layers += 1 + if ( + layer_idx >= recompute_skip_num_layers + and layer_idx < self.config.recompute_num_layers + recompute_skip_num_layers + ): + hidden_states, context = checkpoint_handler(custom(layer_idx, layer_idx + 1)) + else: + hidden_states, context = custom(layer_idx, layer_idx + 1)( + hidden_states, attention_mask, context, context_mask, rotary_pos_emb + ) + else: + raise ValueError("Invalid activation recompute method.") + + return hidden_states + + def forward( + self, + hidden_states: Union[Tensor, WrappedTensor], + attention_mask: Optional[Tensor], + context: Optional[Tensor] = None, + context_mask: Optional[Tensor] = None, + rotary_pos_emb: Optional[Tensor] = None, + rotary_pos_cos: Optional[Tensor] = None, + rotary_pos_sin: Optional[Tensor] = None, + attention_bias: Optional[Tensor] = None, + inference_context: Optional[BaseInferenceContext] = None, + packed_seq_params: Optional[PackedSeqParams] = None, + sequence_len_offset: Optional[Tensor] = None, + packed_seq_params_full: PackedSeqParams = None, + fullatt_block_indexes=None, + *, + inference_params: Optional[BaseInferenceContext] = None, + ): + """ + Perform the forward pass through the transformer block. + + This method handles the core computation of the transformer, including + self-attention, optional cross-attention, and feed-forward operations. + + Args: + hidden_states (Union[Tensor, WrappedTensor]): Input tensor of shape [s, b, h] + where s is the sequence length, b is the batch size, and h is the hidden size. + Can be passed as a WrappedTensor during inference to avoid an obsolete + reference in the calling function. + attention_mask (Tensor): Boolean tensor of shape [1, 1, s, s] for masking + self-attention. + context (Tensor, optional): Context tensor for cross-attention. + context_mask (Tensor, optional): Mask for cross-attention context + rotary_pos_emb (Tensor, optional): Rotary positional embeddings. + attention_bias (Tensor): Bias tensor for Q * K.T of shape in shape broadcastable + to [b, num_head, sq, skv], e.g. [1, 1, sq, skv]. + Used as an alternative to apply attention mask for TE cuDNN attention. + inference_context (BaseInferenceContext, optional): Parameters for inference-time + optimizations. + packed_seq_params (PackedSeqParams, optional): Parameters for packed sequence + processing. + + Returns: + Union[Tensor, Tuple[Tensor, Tensor]]: The output hidden states tensor of shape + [s, b, h], and optionally the updated context tensor if cross-attention is used. + """ + + inference_context = deprecate_inference_params(inference_context, inference_params) + + # Delete the obsolete reference to the initial input tensor if necessary + if isinstance(hidden_states, WrappedTensor): + hidden_states = hidden_states.unwrap() + + if not self.pre_process: + # See set_input_tensor() + hidden_states = self.input_tensor + + # Update the inference parameters with the current batch size in case it is variable + if inference_context and not self.training: + inference_context.current_batch_size = hidden_states.size(1) + + # Viewless tensor. + # - We only need to create a viewless tensor in the case of micro batch + # size (mbs) == 1, since in this case, 'hidden_states.transpose()' + # above creates a view tensor, and '.contiguous()' is a pass-through. + # For mbs >= 2, '.contiguous()' creates a new tensor, eliminating + # the need to make it viewless. + # + # However, we don't explicitly check mbs == 1 here because + # make_viewless_tensor() has negligible overhead when its input + # is already viewless. + # + # - For the 'else' case above, calling make_viewless_tensor() here is + # likely redundant, since p2p_communication.py (likely originator) + # already creates viewless tensors. That said, make_viewless_tensor() + # is called here to be future-proof and corner-case-proof. + hidden_states = make_viewless_tensor(inp=hidden_states, requires_grad=True, keep_graph=True) + + if self.config.sequence_parallel: + rng_context = tensor_parallel.get_cuda_rng_tracker().fork() + else: + rng_context = nullcontext() + + # If fp8_recipe is delayed, wrap the entire pass with get_fp8_context(), + # otherwise do nothing extra at the outer level + # if we are using other fp8 recipes, then the context manager enter&exit are free + # we can wrap fp8_context within the for loop over layers, so that we can fine-grained + # control which layer will be fp8 or bf16 + use_outer_fp8_context = self.config.fp8 and self.config.fp8_recipe == Fp8Recipe.delayed + use_inner_fp8_context = self.config.fp8 and self.config.fp8_recipe != Fp8Recipe.delayed + outer_fp8_context = get_fp8_context(self.config) if use_outer_fp8_context else nullcontext() + + with rng_context, outer_fp8_context: + # Forward pass. + if self.config.recompute_granularity == "full" and self.training: + hidden_states = self._checkpointed_forward( + hidden_states=hidden_states, + attention_mask=attention_mask, + context=context, + context_mask=context_mask, + rotary_pos_emb=rotary_pos_emb, + attention_bias=attention_bias, + packed_seq_params=packed_seq_params, + packed_seq_params_full=packed_seq_params_full, + fullatt_block_indexes=fullatt_block_indexes, + ) + else: + for l_no, layer in enumerate(self.layers): + inner_fp8_context = ( + get_fp8_context(self.config, layer.layer_number - 1) if use_inner_fp8_context else nullcontext() + ) + if l_no in fullatt_block_indexes: + packed_seq_params_now = packed_seq_params_full + else: + packed_seq_params_now = packed_seq_params + with self.offload_context, inner_fp8_context: + hidden_states, context = layer( + hidden_states=hidden_states, + attention_mask=attention_mask, + context=context, + context_mask=context_mask, + rotary_pos_emb=rotary_pos_emb, + rotary_pos_cos=rotary_pos_cos, + rotary_pos_sin=rotary_pos_sin, + attention_bias=attention_bias, + inference_context=inference_context, + packed_seq_params=packed_seq_params_now, + sequence_len_offset=sequence_len_offset, + ) + + if ( + torch.is_grad_enabled() + and self.config.cpu_offloading + and self.group_prefetch_offload_commit_async is not None + ): + hidden_states = self.group_prefetch_offload_commit_async(hidden_states) + + # Final layer norm. + if self.final_layernorm is not None: + hidden_states = self.final_layernorm(hidden_states) + # TENorm produces a "viewed" tensor. This will result in schedule.py's + # deallocate_output_tensor() throwing an error, so a viewless tensor is + # created to prevent this. + hidden_states = make_viewless_tensor(inp=hidden_states, requires_grad=True, keep_graph=True) + + return hidden_states diff --git a/code/RL_model/verl/verl_train/verl/models/mcore/readme.md b/code/RL_model/verl/verl_train/verl/models/mcore/readme.md new file mode 100644 index 0000000000000000000000000000000000000000..0807dbf50f71ae908ff6d28d4a1456f02dc27e29 --- /dev/null +++ b/code/RL_model/verl/verl_train/verl/models/mcore/readme.md @@ -0,0 +1,141 @@ +updated 20251222 + +# The ways verl integrates megatron-core +There has been 3 ways that verl integrates megatron-core as it training backend: +1. the codes inside this directory, which defines the conversion for new models one by one. (deprecated now) +2. through [mbridge](https://github.com/ISEEKYAN/mbridge) (will be deprecated at about v0.8) +3. through [megatron-bridge](https://github.com/NVIDIA-NeMo/Megatron-Bridge) (the official way for further development) + +There is a configure option of `megatron.use_mbridge` to choose way#1 (false) or way#2 (true), and after the megatron-bridge is integrated we have a new option `megatron.vanilla_mbridge` to choose way#2 (true) or way#3 (false) + +Now since we deprecated the way#1, the option `use_mbridge` will be asserted to be true and will be removed after v0.7. The default `vanilla_mbridge` is true for now and will be false one the megatron-bridge backend turns default. + +With the bridge way(#2 or #3), we can directly load and save the megatron model weight through HuggingFace format, and we can use any megatron version >= 0.13 to adopt new megatron optimization feature as handy as possible by directly add overrided megatron configs such as `+actor_rollout_ref.actor.megatron.override_transformer_config.recompute_method=uniform`. + +# How to support new models +1. Make sure the model is supported by your inference engine (vLLM or SGLang or TensorRT-LLM) with correct version. +2. Make sure the model is supported by the bridge + - If it is a model of new architecture, open an issue to `megatron-bridge` or contribute your implementation to `megatron-bridge`. Be cautious to have a matched version of `Megatron` and `TransformerEngine` + - If it is a private model, implement your private model with `mbridge` or `megatron-bridge`(prefered). + +3. Now the model is supported, just change the model path of your scripts and run the scritps. + + + + + +# #Below are deprecated since 2025.12# +# verl Megatron-Core Models +Now we use [mbridge](https://github.com/iseekyan/mbridge) to support megatron models. And we will migrate to [megatron-bridge](https://github.com/NVIDIA-NeMo/Megatron-Bridge) in the future. + +With the mbridge, we can use allmost all the Megatron-Core features to support new models with little effort. And no offline weights conversion is needed, all the weights conversion is done online. We can directly save the mcore model to huggingface format during training. + +Also, we can easily upgrade the mcore version to the latest version. In most cases, the upgrade is seamless. (except when the mcore API changes and we need to update the verl code accordingly) + +## How to support new models +1. make sure the model is supported by vLLM +2. Support the model in [mbridge](https://github.com/iseekyan/mbridge), see its currently supported models for example. + - we will migrate to [megatron-bridge](https://github.com/NVIDIA-NeMo/Megatron-Bridge) in the future. +3. Register the model forward function in verl, see the example in `verl/verl/models/mcore/registry.py`. + + + +# #Below are deprecated since 2025.10# +The earlier versions of verl use `Megatron-LM` 0.4 and workaround huggingface model classes. To better use the latest features and speedup of modern Megatron, we are migrating to `Megatron-Core`(mcore), and use the recommended `GPTModel` class for all language models. With mcore `GPTModel`, we can use the latest features like `context parallel`, `expert parallel`, `dist_checkpointing`, etc. and we can update mcore with little effort in the future for new features. + +The migration has been successful with the help of the mcore team and the community. What we have done is: +1. update `Megatron` version to `0.14.0` +2. migrate `LlamaForCausalLM` and `Qwen2ForCausalLM` to mcore `GPTModel` +3. support sequence packing/thd format. +4. support `tensor parallel`, `pipeline parallel`, `sequence parallel`, `virtual pipeline parallel`, `context parallel`. +5. support the mcore `dist_checkpointing` feature and a basic offline weighs conversion script from huggingface to mcore `dist_checkpointing` format. + +We are working on the following features: +- support `Qwen2MoeForCausalLM` +- support `MixtralForCausalLM` +- support `DeepseekV3ForCausalLM` +- support `expert parallel` + +Features we invite the community to contribute: +- better scripts for offline weights conversion from huggingface to mcore `dist_checkpointing` format. + - conversion of large models with multiple GPUs + - conversion of large models with single GPU +- refactor the `megatron_checkpoint_manager.py` by `dist_checkpointing` format. +- support llama4 +- support qwen2.5-vl + +To track the progress of verl mcore integration, please refer to the [mcore integration issue](https://github.com/volcengine/verl/issues/1033). + +## How things work now +To engage the community in contributing, here are the key steps in our mcore integration process and features under development. + +The huggingface `transformers` is the de facto standard of model zoo while mcore is good at computation efficiency. The main challenge is conversion between the two. +main steps: +1. modelling the huggingface model with mcore `GPTModel` + - a. convert the huggingface config to mcore `TransformerConfig` + - b. init the mcore `GPTModel` with the converted config + - c. load the huggingface model weights to the `GPTModel` +2. online weight conversion from mcore to huggingface (due to the rollout engine `vLLM` is using huggingface format) + - a. bridge the gap between mcore and huggingface weights format and name mapping + - b. online resharding the mcore weights to rollout engine + - this part is very complicated with multiple parallel strategies composition between mcore and rollout engine +3. support the mcore features in verl + - a. support `tensor parallel`, `pipeline parallel`, `sequence parallel`, `virtual pipeline parallel`, `context parallel` + - b. support recompute and other mcore speed up features + +4. checkpointing + - a. support recovering the verl training. + - b. support exporting the mcore checkpoint to huggingface format, for downstream inference. + +### Modelling the huggingface model with mcore `GPTModel` +The first step is to convert huggingface config to mcore `TransformerConfig` and init the mcore `GPTModel` with the converted config. See code in `verl/models/mcore/config_converter.py` and `verl/verl/models/mcore/models/model_initializer.py`. The corresponding model forward code is in `verl/verl/models/mcore/models/model_forward.py`. + +There are two ways of loading the huggingface model weights to the `GPTModel` +1. Runtime loading + - every rank loads the entire huggingface model weights and then shard and convert to mcore weights. + - speed is slow and memory consumption is high. + - this way is deprecated and will not support new models. +2. Offline loading + - use offline script to convert the huggingface model weights to mcore weights and save with mcore `dist_checkpointing` format. + - online loading and sharding is automatically done by mcore `dist_checkpointing` format. The speed is fast and memory consumption is low. + - the offline script is in `verl/scripts/converter_hf_to_mcore.py`. + +### online weight conversion from mcore to huggingface +See function `convert_megatron_model_to_transformers_model` in `verl/utils/megatron_utils.py` for the details. + +It should be refatored for extensibility and better performance. + +### support the mcore features in verl +Most of the features of `GPTModel` is out-of-the-box supported in verl through changing the `TransformerConfig`, except those about parallel strategies, such as `expert parallel`. +Features about parallel strategies should be supported with changes about the online weights conversion(especially the resharding part) and verl work dispatching. + +### checkpointing +The existing checkpointing code is in `verl/utils/checkpoint/megatron_checkpoint_manager.py`. And the script to convert checkpoint to huggingface format is in `verl/scripts/model_merger`. + +The existing checkpoint format simply saves every rank's weights and optimizer states. It should be refactored by `dist_checkpointing` format. + + +## How to support new models +1. make sure the model is supported by vLLM +2. modelling the huggingface model with mcore `GPTModel` (The [Pai-Megatron-Path](https://github.com/alibaba/Pai-Megatron-Patch/tree/main) is a good reference) + - a. convert the huggingface config to mcore `TransformerConfig` + - b. init the mcore `GPTModel` with the converted config + - c. load the huggingface model weights to the `GPTModel` + - d. for VLM the interface might be different, it is ok to add a new model class with GPTModel as its module. +3. offline weights conversion from huggingface to mcore `dist_checkpointing` format +4. support online weights conversion from mcore to huggingface + - it is recommended to initialize a vLLM model with the converted mcore weights, and then test if the generating sequence is correct. + + +## How to scale up to larger models like deepseek-v3 or other 100B+ models +The greatest challenge for scaling up to larger models is the memory consumption. + +The necessary features under development for scaling up are +1. Training engine part + - expert parallel +2. Rollout engine part + - pipeline parallel + - expert parallel + - more efficient and general weight resharding and loading +3. Offline weights conversion + - support weights larger than single GPU memory diff --git a/code/RL_model/verl/verl_train/verl/models/mcore/registry.py b/code/RL_model/verl/verl_train/verl/models/mcore/registry.py new file mode 100644 index 0000000000000000000000000000000000000000..bc4679666a0866db7dcdc7715a00c1be121541e4 --- /dev/null +++ b/code/RL_model/verl/verl_train/verl/models/mcore/registry.py @@ -0,0 +1,301 @@ +# Copyright 2025 Bytedance Ltd. and/or its affiliates +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Registry module for model architecture components. +""" + +from enum import Enum +from typing import Callable + +import torch +import torch.nn as nn + +from .model_forward import gptmodel_forward_no_padding, model_forward_gen +from .model_forward_fused import fused_forward_model_gen + + +class SupportedVLM(Enum): + QWEN2_5_VL = "Qwen2_5_VLForConditionalGeneration" + QWEN3_MOE_VL = "Qwen3VLMoeForConditionalGeneration" + QWEN3_VL = "Qwen3VLForConditionalGeneration" + + +supported_vlm = [member.value for member in SupportedVLM] + + +def get_mcore_forward_fn(hf_config) -> Callable: + """ + Get the forward function for given model architecture. + """ + assert len(hf_config.architectures) == 1, "Only one architecture is supported for now" + if hf_config.architectures[0] in supported_vlm: + return model_forward_gen(True) + else: + # default to language model + return model_forward_gen(False) + + +def get_mcore_forward_no_padding_fn(hf_config) -> Callable: + """ + Get the forward function for given model architecture. + """ + assert len(hf_config.architectures) == 1, "Only one architecture is supported for now" + return gptmodel_forward_no_padding + + +def get_mcore_forward_fused_fn(hf_config) -> Callable: + """ + Get the forward function for given model architecture. + """ + assert len(hf_config.architectures) == 1, "Only one architecture is supported for now" + if hf_config.architectures[0] in supported_vlm: + return fused_forward_model_gen(True) + else: + # default to language model + return fused_forward_model_gen(False) + + +# ruff: noqa + +######################################################## +# below is the deprecated code +######################################################## + +from .config_converter import ( + PretrainedConfig, + TransformerConfig, + hf_to_mcore_config_dense, + hf_to_mcore_config_dpskv3, + hf_to_mcore_config_llama4, + hf_to_mcore_config_mixtral, + hf_to_mcore_config_qwen2_5_vl, + hf_to_mcore_config_qwen2moe, + hf_to_mcore_config_qwen3moe, +) +from .model_initializer import ( + BaseModelInitializer, + DeepseekV3Model, + DenseModel, + MixtralModel, + Qwen2MoEModel, + Qwen3MoEModel, + Qwen25VLModel, +) +from .weight_converter import ( + McoreToHFWeightConverterDense, + McoreToHFWeightConverterDpskv3, + McoreToHFWeightConverterMixtral, + McoreToHFWeightConverterQwen2_5_VL, + McoreToHFWeightConverterQwen2Moe, + McoreToHFWeightConverterQwen3Moe, +) + + +class SupportedModel(Enum): + LLAMA = "LlamaForCausalLM" # tested + QWEN2 = "Qwen2ForCausalLM" # tested + QWEN2_MOE = "Qwen2MoeForCausalLM" # pending + DEEPSEEK_V3 = "DeepseekV3ForCausalLM" # not tested + MIXTRAL = "MixtralForCausalLM" # tested + QWEN2_5_VL = "Qwen2_5_VLForConditionalGeneration" # not supported + LLAMA4 = "Llama4ForConditionalGeneration" # not tested + QWEN3 = "Qwen3ForCausalLM" # tested + QWEN3_MOE = "Qwen3MoeForCausalLM" # tested + GLM4_MOE = "Glm4MoeForCausalLM" + QWEN3_TOKEN_CLASSIFICATION = "Qwen3ForTokenClassification" + LLAMA_TOKEN_CLASSIFICATION = "LlamaForTokenClassification" + QWEN3_MOE_VL = "Qwen3VLMoeForConditionalGeneration" + QWEN3_VL = "Qwen3VLForConditionalGeneration" + GPT_OSS = "GptOssForCausalLM" + MiMO = "MiMoForCausalLM" + + +# Registry for model configuration converters +MODEL_CONFIG_CONVERTER_REGISTRY: dict[SupportedModel, Callable[[PretrainedConfig, torch.dtype], TransformerConfig]] = { + SupportedModel.LLAMA: hf_to_mcore_config_dense, + SupportedModel.QWEN2: hf_to_mcore_config_dense, + SupportedModel.QWEN2_MOE: hf_to_mcore_config_qwen2moe, + SupportedModel.DEEPSEEK_V3: hf_to_mcore_config_dpskv3, + SupportedModel.MIXTRAL: hf_to_mcore_config_mixtral, + SupportedModel.QWEN2_5_VL: hf_to_mcore_config_qwen2_5_vl, + SupportedModel.LLAMA4: hf_to_mcore_config_llama4, + SupportedModel.QWEN3: hf_to_mcore_config_dense, + SupportedModel.QWEN3_MOE: hf_to_mcore_config_qwen3moe, + SupportedModel.QWEN3_TOKEN_CLASSIFICATION: hf_to_mcore_config_dense, + SupportedModel.LLAMA_TOKEN_CLASSIFICATION: hf_to_mcore_config_dense, +} + +# Registry for model initializers +MODEL_INITIALIZER_REGISTRY: dict[SupportedModel, type[BaseModelInitializer]] = { + SupportedModel.LLAMA: DenseModel, + SupportedModel.QWEN2: DenseModel, + SupportedModel.QWEN2_MOE: Qwen2MoEModel, + SupportedModel.MIXTRAL: MixtralModel, + SupportedModel.DEEPSEEK_V3: DeepseekV3Model, + SupportedModel.QWEN2_5_VL: Qwen25VLModel, + SupportedModel.LLAMA4: DenseModel, + SupportedModel.QWEN3: DenseModel, + SupportedModel.QWEN3_MOE: Qwen3MoEModel, + SupportedModel.QWEN3_TOKEN_CLASSIFICATION: DenseModel, + SupportedModel.LLAMA_TOKEN_CLASSIFICATION: DenseModel, +} + +# Registry for model forward functions +MODEL_FORWARD_REGISTRY: dict[SupportedModel, Callable] = { + SupportedModel.LLAMA: model_forward_gen(), + SupportedModel.QWEN2: model_forward_gen(), + SupportedModel.QWEN2_MOE: model_forward_gen(), + SupportedModel.MIXTRAL: model_forward_gen(), + SupportedModel.DEEPSEEK_V3: model_forward_gen(), + SupportedModel.LLAMA4: model_forward_gen(), + SupportedModel.QWEN3: model_forward_gen(), + SupportedModel.QWEN3_MOE: model_forward_gen(), + SupportedModel.QWEN2_5_VL: model_forward_gen(True), + SupportedModel.QWEN3_MOE_VL: model_forward_gen(True), + SupportedModel.QWEN3_VL: model_forward_gen(True), + SupportedModel.GLM4_MOE: model_forward_gen(), + SupportedModel.QWEN3_TOKEN_CLASSIFICATION: model_forward_gen(), + SupportedModel.LLAMA_TOKEN_CLASSIFICATION: model_forward_gen(), + SupportedModel.GPT_OSS: model_forward_gen(), + SupportedModel.MiMO: model_forward_gen(), +} + +# Registry for model forward functions +MODEL_FORWARD_NOPAD_REGISTRY: dict[SupportedModel, Callable] = { + SupportedModel.LLAMA: gptmodel_forward_no_padding, + SupportedModel.QWEN2: gptmodel_forward_no_padding, + SupportedModel.QWEN2_MOE: gptmodel_forward_no_padding, + SupportedModel.MIXTRAL: gptmodel_forward_no_padding, + SupportedModel.DEEPSEEK_V3: gptmodel_forward_no_padding, + SupportedModel.QWEN2_5_VL: gptmodel_forward_no_padding, + SupportedModel.QWEN3_MOE_VL: gptmodel_forward_no_padding, + SupportedModel.QWEN3_VL: gptmodel_forward_no_padding, + SupportedModel.LLAMA4: gptmodel_forward_no_padding, + SupportedModel.QWEN3: gptmodel_forward_no_padding, + SupportedModel.QWEN3_MOE: gptmodel_forward_no_padding, + SupportedModel.GLM4_MOE: gptmodel_forward_no_padding, + SupportedModel.QWEN3_TOKEN_CLASSIFICATION: gptmodel_forward_no_padding, + SupportedModel.LLAMA_TOKEN_CLASSIFICATION: gptmodel_forward_no_padding, + SupportedModel.GPT_OSS: gptmodel_forward_no_padding, + SupportedModel.MiMO: gptmodel_forward_no_padding, +} + +# Registry for model forward functions +MODEL_FORWARD_FUSED_REGISTRY: dict[SupportedModel, Callable] = { + SupportedModel.LLAMA: fused_forward_model_gen(), + SupportedModel.QWEN2: fused_forward_model_gen(), + SupportedModel.QWEN2_MOE: fused_forward_model_gen(), + SupportedModel.MIXTRAL: fused_forward_model_gen(), + SupportedModel.QWEN2_5_VL: fused_forward_model_gen(True), + SupportedModel.QWEN3_MOE_VL: fused_forward_model_gen(True), + SupportedModel.QWEN3_VL: fused_forward_model_gen(True), + SupportedModel.LLAMA4: fused_forward_model_gen(), + SupportedModel.QWEN3: fused_forward_model_gen(), + SupportedModel.QWEN3_MOE: fused_forward_model_gen(), + SupportedModel.DEEPSEEK_V3: fused_forward_model_gen(), + SupportedModel.GLM4_MOE: fused_forward_model_gen(), + SupportedModel.GPT_OSS: fused_forward_model_gen(), + SupportedModel.MiMO: fused_forward_model_gen(), +} + +# Registry for model weight converters +MODEL_WEIGHT_CONVERTER_REGISTRY: dict[SupportedModel, type] = { + SupportedModel.LLAMA: McoreToHFWeightConverterDense, + SupportedModel.QWEN2: McoreToHFWeightConverterDense, + SupportedModel.QWEN2_MOE: McoreToHFWeightConverterQwen2Moe, + SupportedModel.MIXTRAL: McoreToHFWeightConverterMixtral, + SupportedModel.DEEPSEEK_V3: McoreToHFWeightConverterDpskv3, + SupportedModel.QWEN3: McoreToHFWeightConverterDense, + SupportedModel.QWEN3_MOE: McoreToHFWeightConverterQwen3Moe, + SupportedModel.QWEN2_5_VL: McoreToHFWeightConverterQwen2_5_VL, + SupportedModel.QWEN3_TOKEN_CLASSIFICATION: McoreToHFWeightConverterDense, + SupportedModel.LLAMA_TOKEN_CLASSIFICATION: McoreToHFWeightConverterDense, +} + + +def get_supported_model(model_type: str) -> SupportedModel: + try: + return SupportedModel(model_type) + except ValueError as err: + supported_models = [e.value for e in SupportedModel] + raise NotImplementedError( + f"Model Type: {model_type} not supported. Supported models: {supported_models}" + ) from err + + +def hf_to_mcore_config( + hf_config: PretrainedConfig, dtype: torch.dtype, **override_transformer_config_kwargs +) -> TransformerConfig: + """Convert huggingface PretrainedConfig to mcore TransformerConfig. + + Args: + hf_config: The huggingface PretrainedConfig. + dtype: The dtype of the model. + **override_transformer_config_kwargs: The kwargs to override the transformer config. + + Returns: + The mcore TransformerConfig. + """ + assert len(hf_config.architectures) == 1, "Only one architecture is supported for now" + model = get_supported_model(hf_config.architectures[0]) + return MODEL_CONFIG_CONVERTER_REGISTRY[model](hf_config, dtype, **override_transformer_config_kwargs) + + +def init_mcore_model( + tfconfig: TransformerConfig, + hf_config: PretrainedConfig, + pre_process: bool = True, + post_process: bool = None, + *, + share_embeddings_and_output_weights: bool = False, + value: bool = False, + **extra_kwargs, # may be used for vlm and moe +) -> nn.Module: + """ + Initialize a Mcore model. + + Args: + tfconfig: The transformer config. + hf_config: The HuggingFace config. + pre_process: Optional pre-processing function. + post_process: Optional post-processing function. + share_embeddings_and_output_weights: Whether to share embeddings and output weights. + value: Whether to use value. + **extra_kwargs: Additional keyword arguments. + + Returns: + The initialized model. + """ + assert len(hf_config.architectures) == 1, "Only one architecture is supported for now" + model = get_supported_model(hf_config.architectures[0]) + initializer_cls = MODEL_INITIALIZER_REGISTRY[model] + initializer = initializer_cls(tfconfig, hf_config) + return initializer.initialize( + pre_process=pre_process, + post_process=post_process, + share_embeddings_and_output_weights=share_embeddings_and_output_weights, + value=value, + **extra_kwargs, + ) + + +def get_mcore_weight_converter(hf_config: PretrainedConfig, dtype: torch.dtype) -> Callable: + """ + Get the weight converter for given model architecture. + """ + assert len(hf_config.architectures) == 1, "Only one architecture is supported for now" + model = get_supported_model(hf_config.architectures[0]) + tfconfig = hf_to_mcore_config(hf_config, dtype) + return MODEL_WEIGHT_CONVERTER_REGISTRY[model](hf_config, tfconfig) diff --git a/code/RL_model/verl/verl_train/verl/models/mcore/saver.py b/code/RL_model/verl/verl_train/verl/models/mcore/saver.py new file mode 100644 index 0000000000000000000000000000000000000000..2a954b2417cd5b8d09e88b9935e52eeb6ef5273a --- /dev/null +++ b/code/RL_model/verl/verl_train/verl/models/mcore/saver.py @@ -0,0 +1,497 @@ +# Copyright 2025 Bytedance Ltd. and/or its affiliates +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import time + +import torch +import torch.distributed as dist +from megatron.core import mpu +from megatron.core.distributed import DistributedDataParallel as LocalDDP +from megatron.core.transformer.module import Float16Module +from torch.nn.parallel import DistributedDataParallel as torchDDP + +from verl.utils.device import get_device_id, get_torch_device +from verl.utils.logger import print_rank_0 +from verl.utils.megatron_utils import unwrap_model + + +def _megatron_calc_global_rank( + tp_rank: int = 0, dp_rank: int = 0, pp_rank: int = 0, cp_rank: int = 0, ep_rank: int = 0 +): + """Calculate global rank with support for CP/EP parallelism""" + + # Get parallel sizes for each dimension + tp_size = mpu.get_tensor_model_parallel_world_size() + dp_size = mpu.get_data_parallel_world_size() + pp_size = mpu.get_pipeline_model_parallel_world_size() + cp_size = mpu.get_context_parallel_world_size() + # ep_size = mpu.get_expert_model_parallel_world_size() + + # Verify total GPU count matches (must be consistent with parallel_state.py) + total_size = tp_size * dp_size * pp_size * cp_size + assert total_size == torch.distributed.get_world_size(), ( + f"{tp_size}x{dp_size}x{pp_size}x{cp_size} != {torch.distributed.get_world_size()}" + ) + + # Core calculation logic (corresponds to RankGenerator order parameter) + # Assumes default order is "tp-cp-ep-dp-pp" + return ((pp_rank * dp_size + dp_rank) * cp_size + cp_rank) * tp_size + tp_rank + + +def _megatron_calc_layer_map(config): + """Calculate the mapping of global layer_idx to local layer_idx + Returns: + layer_map (Dict: int -> tuple(int, int, int)): + mapping from the global layer index to + a tuple of (pp_rank, virtual_pp_rank, layer_idx inside model) + """ + from megatron.core import mpu + + pp_size = mpu.get_pipeline_model_parallel_world_size() + virtual_pp_size = mpu.get_virtual_pipeline_model_parallel_world_size() or 1 + + layer_map = dict() + num_layers_per_model = config.num_hidden_layers // pp_size // virtual_pp_size + assert num_layers_per_model * pp_size * virtual_pp_size == config.num_hidden_layers + + for pp_rank_idx in range(pp_size): + for virtual_pp_rank_idx in range(virtual_pp_size): + layer_offset = ( + virtual_pp_rank_idx * (config.num_hidden_layers // virtual_pp_size) + pp_rank_idx * num_layers_per_model + ) + for layer_idx in range(num_layers_per_model): + layer_map[layer_offset + layer_idx] = ( + pp_rank_idx, + virtual_pp_rank_idx, + layer_idx, + ) + return layer_map + + +def merge_megatron_ckpt_gptmodel(wrapped_models, config, dtype, is_value_model=False, tie_word_embeddings=False): + """Merge sharded parameters of a Megatron module into a merged checkpoint. + + Args: + wrapped_models (list of megatron.core.distributed.DistributedDataParallel): + The local DDP wrapped megatron modules. + config (str or None): + HF config for model + dtype: model params type + is_value_model: if model is value model + tie_word_embeddings: tie_word_embeddings + Returns: + state_dict (dict): + The merged state_dict in rank 0, and an empty dictionary in other ranks. + """ + start_time = time.time() + + def _get_gpt_model(model): + return model + + dp_rank = mpu.get_data_parallel_rank() + pp_size = mpu.get_pipeline_model_parallel_world_size() + pp_rank = mpu.get_pipeline_model_parallel_rank() + cp_rank = mpu.get_context_parallel_rank() + virtual_pp_size = mpu.get_virtual_pipeline_model_parallel_world_size() or 1 + mp_group = mpu.get_model_parallel_group() + + if dist.get_rank() == 0: + assert mp_group.rank() == 0, f"mp_rank:[{mp_group.rank}] != 0 on rank #0" + assert pp_rank == 0, f"pp_rank:[{pp_rank}] != 0 on rank #0" + assert dp_rank == 0, f"dp_rank:[{dp_rank}] != 0 on rank #0" + + if not isinstance(wrapped_models, list | tuple): + wrapped_models = list(wrapped_models) + + assert len(wrapped_models) == virtual_pp_size + num_layers_per_model = config.num_hidden_layers // pp_size // virtual_pp_size + assert num_layers_per_model * pp_size * virtual_pp_size == config.num_hidden_layers + + models = [None] * len(wrapped_models) + + for i, wrapped_model in enumerate(wrapped_models): + models[i] = unwrap_model(wrapped_model, (torchDDP, LocalDDP, Float16Module)) + assert len(models[i].decoder.layers) == num_layers_per_model, ( + "len model layers {} not equal to num_layers_per_model {}".format( + len(models[i].decoder.layers), num_layers_per_model + ) + ) + + state_dict = dict() + + def _get_cpu_tensor(tensor: torch.Tensor): + if tensor is None: + return None + if tensor.device == torch.device("cpu"): + return tensor.detach().clone() + return tensor.detach().cpu() + + def _broadcast_tensor(tensor, name, src_pp_rank) -> torch.Tensor: + """broadcast tensor across mp_group""" + nonlocal state_dict + nonlocal mp_group + src_rank = _megatron_calc_global_rank(tp_rank=0, dp_rank=0, pp_rank=src_pp_rank, cp_rank=cp_rank) + + if torch.distributed.get_rank() == src_rank: + if tensor is None: + weight = None + tensor_shape = None + else: + weight = tensor + tensor_shape = weight.shape + else: + weight = None + tensor_shape = None + + obj_list = [tensor_shape] + dist.broadcast_object_list(obj_list, src=src_rank, group=mp_group) + tensor_shape = obj_list[0] + + if tensor_shape is None: + # all or none ranks in the mp_group should reach here + print_rank_0(f"tensor:[{name}] not exist, skip collect") + return + + if weight is None: + weight = torch.empty( + tensor_shape, + dtype=dtype, + device=get_device_id(), + requires_grad=False, + ) + + dist.broadcast(weight, src=src_rank, group=mp_group) + + if torch.distributed.get_rank() == 0: + state_dict[name] = _get_cpu_tensor(weight) + + def _broadcast_tp_shard_tensor(tensor, name, src_pp_rank, concat_dim=0, mutate_func=None) -> torch.Tensor: + """broadcast tensor in tp shards across mp_group""" + nonlocal state_dict + nonlocal mp_group + # tp_rank = mpu.get_tensor_model_parallel_rank() + tp_size = mpu.get_tensor_model_parallel_world_size() + src_rank = _megatron_calc_global_rank(tp_rank=0, dp_rank=0, pp_rank=src_pp_rank, cp_rank=cp_rank) + + chunk_shape = tensor.shape if torch.distributed.get_rank() == src_rank else None + + obj_list = [chunk_shape] + dist.broadcast_object_list(obj_list, src=src_rank, group=mp_group) + chunk_shape = obj_list[0] + if chunk_shape is None: + # all or none ranks in the mp_group should reach here + print_rank_0(f"tp_shard tensor:[{name}] not exist, skip collecting") + return + + buffer_tensor = torch.empty( + chunk_shape, + dtype=dtype, + device=get_device_id(), + requires_grad=False, + ) + + chunk_tensors = [None] * tp_size + + for i in range(tp_size): + cur_src_rank = _megatron_calc_global_rank(tp_rank=i, dp_rank=0, pp_rank=src_pp_rank, cp_rank=cp_rank) + sync_tensor = tensor if torch.distributed.get_rank() == cur_src_rank else buffer_tensor + dist.broadcast(sync_tensor, src=cur_src_rank, group=mp_group) + + if torch.distributed.get_rank() == 0: + chunk_tensors[i] = _get_cpu_tensor(sync_tensor) + + if torch.distributed.get_rank() == 0: + full_tensor = torch.concat(chunk_tensors, dim=concat_dim) + if mutate_func is not None: + full_tensor = mutate_func(full_tensor) + state_dict[name] = full_tensor + + def _broadcast_tp_shard_tensor_gate_up(tensor, gate_name, up_name, src_pp_rank) -> torch.Tensor: + """broadcast tensor in tp shards across mp_group""" + nonlocal state_dict + nonlocal mp_group + # tp_rank = mpu.get_tensor_model_parallel_rank() + tp_size = mpu.get_tensor_model_parallel_world_size() + src_rank = _megatron_calc_global_rank(tp_rank=0, dp_rank=0, pp_rank=src_pp_rank, cp_rank=cp_rank) + + chunk_shape = tensor.shape if torch.distributed.get_rank() == src_rank else None + + obj_list = [chunk_shape] + dist.broadcast_object_list(obj_list, src=src_rank, group=mp_group) + chunk_shape = obj_list[0] + if chunk_shape is None: + # all or none ranks in the mp_group should reach here + print_rank_0(f"tp_shard tensor:[{gate_name, up_name}] not exist, skip collecting") + return + + buffer_tensor = torch.empty( + chunk_shape, + dtype=dtype, + device=get_device_id(), + requires_grad=False, + ) + + chunk_tensors = [None] * tp_size + + for i in range(tp_size): + cur_src_rank = _megatron_calc_global_rank(tp_rank=i, dp_rank=0, pp_rank=src_pp_rank, cp_rank=cp_rank) + sync_tensor = tensor if torch.distributed.get_rank() == cur_src_rank else buffer_tensor + dist.broadcast(sync_tensor, src=cur_src_rank, group=mp_group) + + if torch.distributed.get_rank() == 0: + chunk_tensors[i] = _get_cpu_tensor(sync_tensor) + + if torch.distributed.get_rank() == 0: + full_tensor = torch.concat(chunk_tensors, dim=0) + intermediate_size_tp = config.intermediate_size // tp_size + gate_weight_list = [] + up_weight_list = [] + for i in range(tp_size): + gate_up_weight_tp = full_tensor[intermediate_size_tp * 2 * i : intermediate_size_tp * 2 * (i + 1)] + gate_weight_tp = gate_up_weight_tp[:intermediate_size_tp] + up_weight_tp = gate_up_weight_tp[intermediate_size_tp:] + gate_weight_list.append(gate_weight_tp) + up_weight_list.append(up_weight_tp) + + state_dict[gate_name] = torch.cat(gate_weight_list, dim=0) + state_dict[up_name] = torch.cat(up_weight_list, dim=0) + + def _broadcast_tp_shard_tensor_qkv(tensor, q_name, k_name, v_name, src_pp_rank): + """broadcast tensor in tp shards across mp_group""" + nonlocal state_dict + nonlocal mp_group + # tp_rank = mpu.get_tensor_model_parallel_rank() + tp_size = mpu.get_tensor_model_parallel_world_size() + src_rank = _megatron_calc_global_rank(tp_rank=0, dp_rank=0, pp_rank=src_pp_rank, cp_rank=cp_rank) + + chunk_shape = tensor.shape if torch.distributed.get_rank() == src_rank else None + + obj_list = [chunk_shape] + dist.broadcast_object_list(obj_list, src=src_rank, group=mp_group) + chunk_shape = obj_list[0] + if chunk_shape is None: + # all or none ranks in the mp_group should reach here + print_rank_0(f"tp_shard tensor:[{q_name}] not exist, skip collecting") + return + + buffer_tensor = torch.empty( + chunk_shape, + dtype=dtype, + device=get_device_id(), + requires_grad=False, + ) + + chunk_tensors = [None] * tp_size + + for i in range(tp_size): + cur_src_rank = _megatron_calc_global_rank(tp_rank=i, dp_rank=0, pp_rank=src_pp_rank, cp_rank=cp_rank) + sync_tensor = tensor if torch.distributed.get_rank() == cur_src_rank else buffer_tensor + dist.broadcast(sync_tensor, src=cur_src_rank, group=mp_group) + + if torch.distributed.get_rank() == 0: + chunk_tensors[i] = _get_cpu_tensor(sync_tensor) + + if torch.distributed.get_rank() == 0: + full_tensor = torch.concat(chunk_tensors, dim=0) + q_weight_list = [] + k_weight_list = [] + v_weight_list = [] + hidden_size_per_head = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + + if config.num_key_value_heads >= tp_size: + q_size_tp = hidden_size_per_head * config.num_attention_heads // tp_size + kv_size_tp = hidden_size_per_head * config.num_key_value_heads // tp_size + total_size = q_size_tp + 2 * kv_size_tp + for i in range(tp_size): + num_query_groups_per_partition = wrapped_models[0].config.num_query_groups // tp_size + qkv_part = full_tensor[i * total_size : (i + 1) * total_size] + q_size_chunk = q_size_tp // num_query_groups_per_partition + kv_size_chunk = kv_size_tp // num_query_groups_per_partition + for qkv_part_chunk in qkv_part.chunk(num_query_groups_per_partition): + q_part = qkv_part_chunk[:q_size_chunk] + k_part = qkv_part_chunk[q_size_chunk : q_size_chunk + kv_size_chunk] + v_part = qkv_part_chunk[q_size_chunk + kv_size_chunk :] + q_weight_list.append(q_part) + k_weight_list.append(k_part) + v_weight_list.append(v_part) + else: + q_size_tp = hidden_size_per_head * config.num_attention_heads // tp_size + kv_size_tp = hidden_size_per_head + total_size = q_size_tp + 2 * kv_size_tp + for i in range(tp_size): + num_query_groups_per_partition = wrapped_models[0].config.num_query_groups // tp_size + qkv_part = full_tensor[i * total_size : (i + 1) * total_size] + q_size_chunk = q_size_tp // num_query_groups_per_partition + kv_size_chunk = kv_size_tp // num_query_groups_per_partition + for qkv_part_chunk in qkv_part.chunk(num_query_groups_per_partition): + q_part = qkv_part_chunk[:q_size_chunk] + k_part = qkv_part_chunk[q_size_chunk : q_size_chunk + kv_size_chunk] + v_part = qkv_part_chunk[q_size_chunk + kv_size_chunk :] + q_weight_list.append(q_part) + if i * config.num_key_value_heads % tp_size == 0: + k_weight_list.append(k_part) + v_weight_list.append(v_part) + + state_dict[q_name] = torch.cat(q_weight_list, dim=0) + state_dict[k_name] = torch.cat(k_weight_list, dim=0) + state_dict[v_name] = torch.cat(v_weight_list, dim=0) + + # empty cache before collecting weights + get_torch_device().empty_cache() + # Embeddings + # ------------------- + if dp_rank == 0 and cp_rank == 0: # models are identical across cp ranks + # Embeddings + # ------------------- + print_rank_0("collecting embeddings...") + gpt_model_module = _get_gpt_model(models[0]) + _broadcast_tp_shard_tensor( + gpt_model_module.embedding.word_embeddings.weight if pp_rank == 0 else None, + "model.embed_tokens.weight", + src_pp_rank=0, + ) + + # Transformer layers + # ------------------- + layer_map = _megatron_calc_layer_map(config) + for layer in range(config.num_hidden_layers): + print_rank_0(f"collecting layer #{layer}...") + layer_name = f"model.layers.{layer}" + src_pp_rank, src_virtual_pp_rank, src_layer_idx = layer_map[layer] + + gpt_model_module = _get_gpt_model(models[src_virtual_pp_rank]) + sync_layer = gpt_model_module.decoder.layers[src_layer_idx] + + _broadcast_tensor( + sync_layer.self_attention.linear_qkv.layer_norm_weight, + f"{layer_name}.input_layernorm.weight", + src_pp_rank=src_pp_rank, + ) + + if gpt_model_module.config.qk_layernorm: + _broadcast_tensor( + sync_layer.self_attention.q_layernorm.weight, + f"{layer_name}.self_attn.q_norm.weight", + src_pp_rank=src_pp_rank, + ) + _broadcast_tensor( + sync_layer.self_attention.k_layernorm.weight, + f"{layer_name}.self_attn.k_norm.weight", + src_pp_rank=src_pp_rank, + ) + + _broadcast_tp_shard_tensor_qkv( + sync_layer.self_attention.linear_qkv.weight, + f"{layer_name}.self_attn.q_proj.weight", + f"{layer_name}.self_attn.k_proj.weight", + f"{layer_name}.self_attn.v_proj.weight", + src_pp_rank=src_pp_rank, + ) + + if gpt_model_module.config.add_qkv_bias: + _broadcast_tp_shard_tensor_qkv( + sync_layer.self_attention.linear_qkv.bias, + f"{layer_name}.self_attn.q_proj.bias", + f"{layer_name}.self_attn.k_proj.bias", + f"{layer_name}.self_attn.v_proj.bias", + src_pp_rank=src_pp_rank, + ) + + _broadcast_tp_shard_tensor( + sync_layer.self_attention.linear_proj.weight, + f"{layer_name}.self_attn.o_proj.weight", + concat_dim=1, + src_pp_rank=src_pp_rank, + ) + + _broadcast_tensor( + sync_layer.mlp.linear_fc1.layer_norm_weight, + f"{layer_name}.post_attention_layernorm.weight", + src_pp_rank=src_pp_rank, + ) + + _broadcast_tp_shard_tensor_gate_up( + sync_layer.mlp.linear_fc1.weight, + f"{layer_name}.mlp.gate_proj.weight", + f"{layer_name}.mlp.up_proj.weight", + src_pp_rank=src_pp_rank, + ) + + _broadcast_tp_shard_tensor( + sync_layer.mlp.linear_fc2.weight, + f"{layer_name}.mlp.down_proj.weight", + concat_dim=1, + src_pp_rank=src_pp_rank, + ) + + # Final Layernorm + # ------------------- + print_rank_0("collecting final layernorm...") + gpt_model_module = _get_gpt_model(models[-1]) + _broadcast_tensor( + getattr(gpt_model_module.decoder.final_layernorm, "weight", None), + "model.norm.weight", + src_pp_rank=pp_size - 1, + ) + + if tie_word_embeddings: + print_rank_0("tie word embedding skip load lm_head...") + else: + print_rank_0("collecting lm_head...") + + if is_value_model: + lm_head_weight = None + if pp_rank == pp_size - 1: + lm_head_weight = getattr(gpt_model_module.output_layer, "weight", None) + _broadcast_tensor(lm_head_weight, "lm_head.weight", src_pp_rank=pp_size - 1) + + else: + _broadcast_tp_shard_tensor( + getattr(gpt_model_module.output_layer, "weight", None) if pp_rank == pp_size - 1 else None, + "lm_head.weight", + src_pp_rank=pp_size - 1, + ) + + dist.barrier() + get_torch_device().empty_cache() + if torch.distributed.get_rank() == 0: + for k, v in state_dict.items(): + if dtype != v.dtype: + state_dict[k] = v.to(dtype) + + print_rank_0(f"merge megatron ckpt done, time elapsed {time.time() - start_time}s") + return state_dict + + +def merge_megatron_ckpt_gptmodel_qwen_moe( + wrapped_models, config, dtype, is_value_model=False, tie_word_embeddings=False +): + raise NotImplementedError("merge_megatron_ckpt_gptmodel_qwen_moe is not implemented") + + +def merge_megatron_ckpt_gptmodel_qwen2_5_vl( + wrapped_models, config, dtype, is_value_model=False, tie_word_embeddings=False +): + raise NotImplementedError("merge_megatron_ckpt_gptmodel_qwen2_5_vl is not implemented") + + +def merge_megatron_ckpt_gptmodel_dpskv3(wrapped_models, config, dtype, is_value_model=False, tie_word_embeddings=False): + raise NotImplementedError("merge_megatron_ckpt_gptmodel_dpskv3 is not implemented") + + +def merge_megatron_ckpt_gptmodel_mixtral( + wrapped_models, config, dtype, is_value_model=False, tie_word_embeddings=False +): + raise NotImplementedError("merge_megatron_ckpt_gptmodel_mixtral is not implemented") diff --git a/code/RL_model/verl/verl_train/verl/models/mcore/util.py b/code/RL_model/verl/verl_train/verl/models/mcore/util.py new file mode 100644 index 0000000000000000000000000000000000000000..aefb798aa0bceecb82ea6cbc5397c1e070118017 --- /dev/null +++ b/code/RL_model/verl/verl_train/verl/models/mcore/util.py @@ -0,0 +1,493 @@ +# Copyright 2025 Bytedance Ltd. and/or its affiliates +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math + +import torch +from megatron.core import parallel_state as mpu +from megatron.core.packed_seq_params import PackedSeqParams + +from verl.utils.model import CausalLMOutputForPPO + + +def preprocess_packed_seqs( + input_ids: torch.Tensor, attention_mask: torch.Tensor, pre_process: bool = True, use_fp8_padding=False +) -> tuple[torch.Tensor, PackedSeqParams]: + """ + Preprocess packed sequences + CP splits sequence into CP*2 chunks, and each GPU gets 2 chunks (GPU0 gets first and last chunks, GPU1 + gets second and second last chunks, and so on), this is for load balancing with causal masking. + See https://github.com/NVIDIA/TransformerEngine/issues/1368 + """ + batch_size = input_ids.shape[0] + + seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) + tp_size = mpu.get_tensor_model_parallel_world_size() + cp_size = mpu.get_context_parallel_world_size() + cp_rank = mpu.get_context_parallel_rank() + align_size = tp_size * cp_size * 2 if cp_size > 1 else tp_size + if use_fp8_padding: + # if fp8 is enabled, ensure the sequence is padded to multiples of 16 for better performance + original_align_size = align_size + align_size = math.lcm(16, align_size) + + pad_size = (align_size - seqlens_in_batch % align_size) % align_size + seqlens_in_batch_padded = seqlens_in_batch + pad_size + + cu_seqlens = torch.zeros(batch_size + 1, dtype=torch.int32, device=input_ids.device) + cu_seqlens[1:] = torch.cumsum(seqlens_in_batch, dim=0) + cu_seqlens_padded = torch.zeros(batch_size + 1, dtype=torch.int32, device=input_ids.device) + cu_seqlens_padded[1:] = torch.cumsum(seqlens_in_batch_padded, dim=0) + + if use_fp8_padding: + # make sure all the sequences are padded to multiples of 128 for TE compatibility + align_size_last = original_align_size * 128 + pad_size_last = (align_size_last - cu_seqlens_padded[-1] % align_size_last) % align_size_last + cu_seqlens_padded[-1] += pad_size_last + seqlens_in_batch_padded[-1] += pad_size_last + + # ---------------------------------------------------------------------------- + # Move the index information needed in the subsequent loop to the CPU at once, + # to avoid frequent .item() calls in the loop that cause D2H synchronization + # ---------------------------------------------------------------------------- + seqlens_in_batch_cpu: list[int] = seqlens_in_batch.tolist() # original valid lengths + seqlens_in_batch_padded_cpu: list[int] = seqlens_in_batch_padded.tolist() # lengths after padding + cu_seqlens_padded_cpu: list[int] = cu_seqlens_padded.tolist() # start positions (after padding) + + # Pure Python int calculation to avoid further synchronization + max_seqlen_in_batch = max(seqlens_in_batch_padded_cpu) + + shape = list(input_ids.shape[1:]) + shape[0] = sum(seqlens_in_batch_padded_cpu) // cp_size + if pre_process: + input_ids_rmpad = torch.zeros(shape, dtype=input_ids.dtype, device=input_ids.device) + for i in range(batch_size): + # Use Python int, so no GPU→CPU sync in the loop + if cp_size <= 1: + seqlen = seqlens_in_batch_cpu[i] + start_idx = cu_seqlens_padded_cpu[i] + input_ids_rmpad[start_idx : start_idx + seqlen] = input_ids[i, attention_mask[i]] + continue + + seqlen_padded_i = seqlens_in_batch_padded_cpu[i] + seqlen = seqlen_padded_i // cp_size + half_seqlen = seqlen // 2 + start_idx = cu_seqlens_padded_cpu[i] // cp_size + # split to 2 chunks + d = input_ids[i, attention_mask[i]] + input_ids_rmpad[start_idx : start_idx + half_seqlen] = d[ + half_seqlen * cp_rank : half_seqlen * (cp_rank + 1) + ] + + remain_start = seqlen_padded_i - half_seqlen * (cp_rank + 1) + remain_end = seqlen_padded_i - half_seqlen * cp_rank + remain_end = min(remain_end, d.shape[0]) + remain_len = remain_end - remain_start + if remain_len > 0: + input_ids_rmpad[start_idx + half_seqlen : start_idx + half_seqlen + remain_len] = d[ + remain_start:remain_end + ] + + packed_seq_params = PackedSeqParams( + qkv_format="thd", + cu_seqlens_q=cu_seqlens_padded, + max_seqlen_q=max_seqlen_in_batch, + cu_seqlens_kv=cu_seqlens_padded, + max_seqlen_kv=max_seqlen_in_batch, + cu_seqlens_q_padded=cu_seqlens_padded, + cu_seqlens_kv_padded=cu_seqlens_padded, + ) + if pre_process: + return input_ids_rmpad.unsqueeze(0), packed_seq_params + else: + return input_ids, packed_seq_params + + +def postprocess_packed_seqs( + output: torch.Tensor, + packed_seq_params: PackedSeqParams, + attention_mask: torch.Tensor, + batch_size: int, + seq_len: int, + post_process: bool = True, +) -> torch.Tensor: + """ + Postprocess packed sequences + """ + if not post_process: + return output + + # ------------------------------------------------------------------------- + # Move the lengths and offsets needed for subsequent Python-level indexing to the CPU in advance, + # to avoid a large number of .item() calls in the loop + # ------------------------------------------------------------------------- + cu_padded_cpu: list[int] = packed_seq_params.cu_seqlens_q_padded.tolist() + seq_lens_cpu: list[int] = attention_mask.sum(dim=1, dtype=torch.int32).cpu().tolist() + + shape = [batch_size, seq_len] + list(output.shape[2:]) # 1,packed, dim -> batch_size, seq_len, dim + output_new = torch.zeros(shape, dtype=output.dtype, device=output.device) + + cp_size = mpu.get_context_parallel_world_size() + # all gather output across context parallel group + if cp_size > 1: + # output shape: [1, packed_len, hidden_dim] + # need to gather across cp group and concatenate in sequence dimension + output_list = [torch.empty_like(output, dtype=output.dtype) for _ in range(cp_size)] + torch.distributed.all_gather(output_list, output.detach(), group=mpu.get_context_parallel_group()) + output_list[mpu.get_context_parallel_rank()] = output + else: + output_list = [output] + for i in range(batch_size): + if cp_size <= 1: + s = seq_lens_cpu[i] + start_idx = cu_padded_cpu[i] + output_new[i, attention_mask[i]] = output[0][start_idx : start_idx + s] + continue + s_len_padded_chunk = (cu_padded_cpu[i + 1] - cu_padded_cpu[i]) // cp_size + half_seqlen = s_len_padded_chunk // 2 + s_len = seq_lens_cpu[i] + s_len_padded = s_len_padded_chunk * cp_size + tmp = torch.empty(s_len_padded, *output.shape[2:], device=output.device, dtype=output.dtype) + for j in range(cp_size): + o = output_list[j][0] + # split to 2 chunks + packed_start_idx = cu_padded_cpu[i] // cp_size + o0, o1 = ( + o[packed_start_idx : packed_start_idx + half_seqlen], + o[packed_start_idx + half_seqlen : packed_start_idx + s_len_padded_chunk], + ) + tmp[j * half_seqlen : (j + 1) * half_seqlen] = o0 + tmp[s_len_padded - (j + 1) * half_seqlen : s_len_padded - j * half_seqlen] = o1 + output_new[i, attention_mask[i]] = tmp[:s_len] + + return output_new + + +def preprocess_bshd( + input_ids: torch.Tensor, + attention_mask: torch.Tensor, + position_ids: torch.Tensor, + sequence_parallel: bool = False, + pre_process: bool = True, +): + """ + Remove left padding from input_ids, attention_mask and position_ids + return new_input_ids, new_attention_mask, new_position_ids + """ + assert attention_mask.ndim == 2 + assert position_ids.ndim == 2 + cp_size = mpu.get_context_parallel_world_size() + assert cp_size == 1, "Context parallel size without seq_pack is not supported" + batch_size = input_ids.shape[0] + shape = list(input_ids.shape) # batch_size, seq_len,... + seq_lens = attention_mask.sum(dim=1) + seq_len = seq_lens.max().item() + if sequence_parallel: + sp_world_size = mpu.get_tensor_model_parallel_world_size() + pad_size = (sp_world_size - seq_len % sp_world_size) % sp_world_size + seq_len = seq_len + pad_size + shape[1] = seq_len + if pre_process: + new_input_ids = torch.zeros(dtype=input_ids.dtype, device=input_ids.device, size=shape) + new_attention_mask = torch.zeros( + dtype=attention_mask.dtype, device=attention_mask.device, size=(batch_size, seq_len) + ) + new_position_ids = torch.zeros(dtype=position_ids.dtype, device=position_ids.device, size=(batch_size, seq_len)) + for i in range(batch_size): + if pre_process: + new_input_ids[i, : seq_lens[i]] = input_ids[i, attention_mask[i]] + new_attention_mask[i, : seq_lens[i]] = attention_mask[i, attention_mask[i]] + new_position_ids[i, : seq_lens[i]] = position_ids[i, attention_mask[i]] + if pre_process: + return new_input_ids, new_attention_mask, new_position_ids + else: + return input_ids, new_attention_mask, new_position_ids + + +def postprocess_bshd( + result, + attention_mask: torch.Tensor, + original_attention_mask: torch.Tensor, + origin_seqlen: int, + post_process: bool = True, +): + """ + Recover left padding from result + return result + """ + if not post_process: + return result + shape = list(result.shape) + batch_size = shape[0] + shape[1] = origin_seqlen + new_result = torch.zeros(dtype=result.dtype, device=result.device, size=shape) + for i in range(batch_size): + new_result[i, original_attention_mask[i]] = result[i, attention_mask[i]] + return new_result + + +def postprocess_packed_seqs_for_dict_output( + labels_mask: torch.Tensor, + output: CausalLMOutputForPPO, + packed_seq_params: PackedSeqParams, + attention_mask: torch.Tensor, + batch_size: int, + seq_len: int, + post_process: bool = True, +) -> dict[str, torch.Tensor]: + """_summary_ + For fused kernels, the output is a dictionary with keys like 'log_probs', 'entropy', etc. + This function post-processes each tensor in the output dictionary. + Args: + output (CausalLMOutputForPPO): _description_ + packed_seq_params (PackedSeqParams): _description_ + attention_mask (torch.Tensor): _description_ + batch_size (int): _description_ + seq_len (int): _description_ + post_process (bool, optional): _description_. Defaults to True. + Returns: + CausalLMOutputForPPO: _description_ + """ + ret = {} + output.entropy = output.entropy.view(1, -1) + output.log_probs = output.log_probs.view(1, -1) + output.log_probs = output.log_probs.masked_fill(~labels_mask, 0.0) + ret["entropy"] = postprocess_packed_seqs( + output.entropy, packed_seq_params, attention_mask, batch_size, seq_len, post_process=post_process + ) + ret["log_probs"] = postprocess_packed_seqs( + output.log_probs, packed_seq_params, attention_mask, batch_size, seq_len, post_process=post_process + ) + return ret + + +### No padding versions for model engine +### inputs are nested tensors + + +def preprocess_thd_no_padding( + input_ids: torch.Tensor, pre_process: bool = True, need_roll: bool = False +) -> tuple[torch.Tensor, PackedSeqParams]: + """ + Preprocess packed sequences + CP splits sequence into CP*2 chunks, and each GPU gets 2 chunks (GPU0 gets first and last chunks, GPU1 + gets second and second last chunks, and so on), this is for load balancing with causal masking. + See https://github.com/NVIDIA/TransformerEngine/issues/1368 + """ + batch_size = input_ids.shape[0] + + tp_size = mpu.get_tensor_model_parallel_world_size() + cp_size = mpu.get_context_parallel_world_size() + cp_rank = mpu.get_context_parallel_rank() + align_size = tp_size * cp_size * 2 if cp_size > 1 else tp_size + seqlens_in_batch = input_ids.offsets().diff() + + pad_size = (align_size - seqlens_in_batch % align_size) % align_size + seqlens_in_batch_padded = seqlens_in_batch + pad_size + + cu_seqlens = torch.zeros(batch_size + 1, dtype=torch.int32, device=input_ids.device) + cu_seqlens[1:] = torch.cumsum(seqlens_in_batch, dim=0) + cu_seqlens_padded = torch.zeros(batch_size + 1, dtype=torch.int32, device=input_ids.device) + cu_seqlens_padded[1:] = torch.cumsum(seqlens_in_batch_padded, dim=0) + + # ---------------------------------------------------------------------------- + # Move the index information needed in the subsequent loop to the CPU at once, + # to avoid frequent .item() calls in the loop that cause D2H synchronization + # ---------------------------------------------------------------------------- + seqlens_in_batch_cpu: list[int] = seqlens_in_batch.tolist() # original valid lengths + seqlens_in_batch_padded_cpu: list[int] = seqlens_in_batch_padded.tolist() # lengths after padding + cu_seqlens_padded_cpu: list[int] = cu_seqlens_padded.tolist() # start positions (after padding) + + # Pure Python int calculation to avoid further synchronization + max_seqlen_in_batch = max(seqlens_in_batch_padded_cpu) + + shape = list(input_ids.shape[1:]) + shape[0] = sum(seqlens_in_batch_padded_cpu) // cp_size + if pre_process: + input_ids_rmpad = torch.zeros(shape, dtype=input_ids.dtype, device=input_ids.device) + if need_roll: + saved_roll_dict = {} + for i in range(batch_size): + # Use Python int, so no GPU→CPU sync in the loop + if cp_size <= 1: + seqlen = seqlens_in_batch_cpu[i] + start_idx = cu_seqlens_padded_cpu[i] + input_ids_rmpad[start_idx : start_idx + seqlen] = input_ids[i] + continue + + seqlen_padded_i = seqlens_in_batch_padded_cpu[i] + seqlen = seqlen_padded_i // cp_size + half_seqlen = seqlen // 2 + start_idx = cu_seqlens_padded_cpu[i] // cp_size + # split to 2 chunks + d = input_ids[i] + input_ids_rmpad[start_idx : start_idx + half_seqlen] = d[ + half_seqlen * cp_rank : half_seqlen * (cp_rank + 1) + ] + + remain_start = seqlen_padded_i - half_seqlen * (cp_rank + 1) + remain_end = seqlen_padded_i - half_seqlen * cp_rank + remain_end = min(remain_end, d.shape[0]) + remain_len = remain_end - remain_start + if remain_len > 0: + input_ids_rmpad[start_idx + half_seqlen : start_idx + half_seqlen + remain_len] = d[ + remain_start:remain_end + ] + + if need_roll: + # Handle roll for cp_size > 1 case + saved_roll_dict[start_idx + half_seqlen - 1] = d[(cp_rank + 1) * half_seqlen] + if remain_len > 0: + if remain_end == d.shape[0]: + saved_roll_dict[start_idx + half_seqlen + remain_len - 1] = d[0] + else: + saved_roll_dict[start_idx + half_seqlen + remain_len - 1] = d[remain_end] + + if need_roll: + input_ids_rmpad = torch.roll(input_ids_rmpad, shifts=-1, dims=0) + if len(saved_roll_dict) > 0: + for k, v in saved_roll_dict.items(): + input_ids_rmpad[k] = v + + packed_seq_params = PackedSeqParams( + qkv_format="thd", + cu_seqlens_q=cu_seqlens_padded, + max_seqlen_q=max_seqlen_in_batch, + cu_seqlens_kv=cu_seqlens_padded, + max_seqlen_kv=max_seqlen_in_batch, + cu_seqlens_q_padded=cu_seqlens_padded, + cu_seqlens_kv_padded=cu_seqlens_padded, + ) + if pre_process: + return input_ids_rmpad.unsqueeze(0), packed_seq_params + else: + return input_ids, packed_seq_params + + +def postprocess_thd_no_padding( + output: torch.Tensor, + packed_seq_params: PackedSeqParams, + input_ids: torch.Tensor, + batch_size: int, + post_process: bool = True, +) -> torch.Tensor: + """ + Postprocess packed sequences + """ + if not post_process: + return output + + # ------------------------------------------------------------------------- + # Move the lengths and offsets needed for subsequent Python-level indexing to the CPU in advance, + # to avoid a large number of .item() calls in the loop + # ------------------------------------------------------------------------- + cu_padded_cpu: list[int] = packed_seq_params.cu_seqlens_q_padded.tolist() + # The reason why we use input_ids.offsets() instead of packed_seq_params.cu_seqlens_q.diff() + # is that the latter one is the padded length, while the former one is the original length. + cu_seqlens = input_ids.offsets() + seq_lens_cpu: list[int] = cu_seqlens.diff().tolist() + + output_new = [] + + cp_size = mpu.get_context_parallel_world_size() + # all gather output across context parallel group + if cp_size > 1: + # output shape: [1, packed_len, hidden_dim] + # need to gather across cp group and concatenate in sequence dimension + output_list = [torch.empty_like(output) for _ in range(cp_size)] + torch.distributed.all_gather(output_list, output.detach(), group=mpu.get_context_parallel_group()) + output_list[mpu.get_context_parallel_rank()] = output + else: + output_list = [output] + + for i in range(batch_size): + if cp_size <= 1: + s = seq_lens_cpu[i] + start_idx = cu_padded_cpu[i] + output_new.append(output[0][start_idx : start_idx + s]) + continue + s_len_padded_chunk = (cu_padded_cpu[i + 1] - cu_padded_cpu[i]) // cp_size + half_seqlen = s_len_padded_chunk // 2 + s_len = seq_lens_cpu[i] + s_len_padded = s_len_padded_chunk * cp_size + tmp = torch.empty(s_len_padded, *output.shape[2:], device=output.device) + for j in range(cp_size): + o = output_list[j][0] + # split to 2 chunks + packed_start_idx = cu_padded_cpu[i] // cp_size + o0, o1 = ( + o[packed_start_idx : packed_start_idx + half_seqlen], + o[packed_start_idx + half_seqlen : packed_start_idx + s_len_padded_chunk], + ) + tmp[j * half_seqlen : (j + 1) * half_seqlen] = o0 + tmp[s_len_padded - (j + 1) * half_seqlen : s_len_padded - j * half_seqlen] = o1 + output_new.append(tmp[:s_len]) + + output_new_tensor = torch.nested.as_nested_tensor(output_new, layout=torch.jagged) + + return output_new_tensor + + +def preprocess_bshd_no_padding(input_ids: torch.Tensor, pre_process: bool = True, need_roll: bool = False): + """ + Preprocess bshd sequences + return "input_ids, attention_mask, position_ids" + """ + cp_size = mpu.get_context_parallel_world_size() + # TODO: support context parallel size > 1 + assert cp_size == 1, "Context parallel size without bshd is not supported yet" + + batch_size = input_ids.shape[0] + seqlens_in_batch = input_ids.offsets().diff() + max_seqlen = seqlens_in_batch.max().item() + if mpu.get_tensor_model_parallel_world_size() > 1: + sp_world_size = mpu.get_tensor_model_parallel_world_size() + pad_size = (sp_world_size - max_seqlen % sp_world_size) % sp_world_size + max_seqlen = max_seqlen + pad_size + + attention_mask = torch.zeros(batch_size, max_seqlen, dtype=torch.bool, device=input_ids.device) + input_ids_bshd = torch.zeros(batch_size, max_seqlen, dtype=input_ids.dtype, device=input_ids.device) + for i in range(batch_size): + attention_mask[i, : seqlens_in_batch[i]] = True + input_ids_bshd[i, : seqlens_in_batch[i]] = input_ids[i] + position_ids = torch.arange(max_seqlen, dtype=torch.long, device=input_ids.device) + position_ids = position_ids.unsqueeze(0).expand_as(input_ids_bshd) + if need_roll: + input_ids_bshd = torch.roll(input_ids_bshd, shifts=-1, dims=1) + + return input_ids_bshd, attention_mask, position_ids + + +def postprocess_bshd_no_padding( + output: torch.Tensor, + attention_mask: torch.Tensor, + post_process: bool = True, +) -> torch.Tensor: + """ + Postprocess bshd sequences + """ + if not post_process: + return output + + batch_size = output.shape[0] + output_new = [] + + for i in range(batch_size): + mask = attention_mask[i].bool() + output_new.append(output[i][mask]) + + output_new_tensor = torch.nested.as_nested_tensor(output_new, layout=torch.jagged) + + return output_new_tensor diff --git a/code/RL_model/verl/verl_train/verl/models/mcore/weight_converter.py b/code/RL_model/verl/verl_train/verl/models/mcore/weight_converter.py new file mode 100644 index 0000000000000000000000000000000000000000..791513f32d1b7ab1e220d2c7f1abb5a2c8abeba3 --- /dev/null +++ b/code/RL_model/verl/verl_train/verl/models/mcore/weight_converter.py @@ -0,0 +1,479 @@ +# Copyright 2025 Bytedance Ltd. and/or its affiliates +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# online convert mcore weight to pure huggingface weight, no any fusion +# including format conversion and name mapping +# not including resharding +import torch +from megatron.core.transformer import TransformerConfig +from transformers import PretrainedConfig + + +class McoreToHFWeightConverterBase: + def __init__(self, hf_config: PretrainedConfig, mcore_config: TransformerConfig): + self.hf_config = hf_config + self.mcore_config = mcore_config + + def convert_param(self, name: str, params_one_group: list[torch.Tensor]) -> torch.Tensor: + raise NotImplementedError + + +class McoreToHFWeightConverterDense(McoreToHFWeightConverterBase): + def _convert_attention_param(self, name: str, params: list[torch.Tensor]) -> tuple[list[str], list[torch.Tensor]]: + # 'decoder.layers.0.self_attention.linear_proj.weight' + # 'decoder.layers.0.self_attention.linear_qkv.layer_norm_weight' + # 'decoder.layers.0.self_attention.linear_qkv.weight' + # 'decoder.layers.0.self_attention.linear_qkv.bias' + layer_number = name.split(".")[2] + convert_names = [] + if "self_attention.linear_qkv.bias" in name or "self_attention.linear_qkv.weight" in name: + param_type = name.split(".")[-1] + assert param_type == "bias" or param_type == "weight" + convert_names.append(f"model.layers.{layer_number}.self_attn.q_proj.{param_type}") + convert_names.append(f"model.layers.{layer_number}.self_attn.k_proj.{param_type}") + convert_names.append(f"model.layers.{layer_number}.self_attn.v_proj.{param_type}") + assert len(params) == 3 + elif "self_attention.linear_proj.weight" in name: + convert_names.append(f"model.layers.{layer_number}.self_attn.o_proj.weight") + assert len(params) == 1 + elif "self_attention.linear_qkv.layer_norm_weight" in name: + convert_names.append(f"model.layers.{layer_number}.input_layernorm.weight") + assert len(params) == 1 + elif "self_attention.q_layernorm.weight" in name: + convert_names.append(f"model.layers.{layer_number}.self_attn.q_norm.weight") + assert len(params) == 1 + elif "self_attention.k_layernorm.weight" in name: + convert_names.append(f"model.layers.{layer_number}.self_attn.k_norm.weight") + assert len(params) == 1 + else: + raise NotImplementedError(f"Unsupported parameter name: {name}") + return convert_names, params + + def _convert_mlp_param(self, name: str, params: list[torch.Tensor]) -> tuple[list[str], list[torch.Tensor]]: + # 'decoder.layers.0.mlp.linear_fc1.layer_norm_weight' + # 'decoder.layers.0.mlp.linear_fc1.weight' + # 'decoder.layers.0.mlp.linear_fc2.weight' + layer_number = name.split(".")[2] + convert_names = [] + if "mlp.linear_fc1.weight" in name: + # split gate_proj and up_proj + convert_names.append(f"model.layers.{layer_number}.mlp.gate_proj.weight") + convert_names.append(f"model.layers.{layer_number}.mlp.up_proj.weight") + assert len(params) == 2 + elif "mlp.linear_fc1.layer_norm_weight" in name: + convert_names.append(f"model.layers.{layer_number}.post_attention_layernorm.weight") + assert len(params) == 1 + elif "mlp.linear_fc2.weight" in name: + convert_names.append(f"model.layers.{layer_number}.mlp.down_proj.weight") + assert len(params) == 1 + else: + raise NotImplementedError(f"Unsupported parameter name: {name}") + return convert_names, params + + def convert_param(self, name: str, params_one_group: list[torch.Tensor]) -> tuple[list[str], list[torch.Tensor]]: + direct_name_mapping = { + "embedding.word_embeddings.weight": "model.embed_tokens.weight", + "decoder.final_layernorm.weight": "model.norm.weight", + "output_layer.weight": "lm_head.weight", + } + if name in direct_name_mapping: + return [direct_name_mapping[name]], [params_one_group[0]] + + if "self_attention" in name: + return self._convert_attention_param(name, params_one_group) + elif "mlp" in name: + return self._convert_mlp_param(name, params_one_group) + else: + raise NotImplementedError(f"Unsupported parameter name: {name}") + + +class McoreToHFWeightConverterQwen2Moe(McoreToHFWeightConverterDense): + def _convert_mlp_param(self, name: str, params: list[torch.Tensor]) -> tuple[list[str], list[torch.Tensor]]: + # 'decoder.layers.0.pre_mlp_layernorm.weight', + # 'decoder.layers.0.mlp.router.weight', + # 'decoder.layers.0.mlp.shared_experts.gate_weight', + # 'decoder.layers.0.mlp.shared_experts.linear_fc1.weight', + # 'decoder.layers.0.mlp.shared_experts.linear_fc2.weight' + # moe1 + # 'decoder.layers.0.mlp.experts.linear_fc1.weight0', + # 'decoder.layers.0.mlp.experts.linear_fc1.weight1', + # 'decoder.layers.0.mlp.experts.linear_fc1.weight2', + # 'decoder.layers.0.mlp.experts.linear_fc1.weight3', + # moe2 + # 'decoder.layers.0.mlp.experts.linear_fc2.weight0', + # 'decoder.layers.0.mlp.experts.linear_fc2.weight1', + layer_number = name.split(".")[2] + convert_names = [] + if "pre_mlp_layernorm" in name: + convert_names.append(f"model.layers.{layer_number}.post_attention_layernorm.weight") + assert len(params) == 1 + elif "mlp.router.weight" in name: + convert_names.append(f"model.layers.{layer_number}.mlp.gate.weight") + assert len(params) == 1 + elif "shared_experts.gate_weight" in name: + convert_names.append(f"model.layers.{layer_number}.mlp.shared_expert_gate.weight") + assert len(params) == 1 + elif "shared_experts.linear_fc1.weight" in name: # split gate_proj and up_proj + convert_names.append(f"model.layers.{layer_number}.mlp.shared_expert.gate_proj.weight") + convert_names.append(f"model.layers.{layer_number}.mlp.shared_expert.up_proj.weight") + assert len(params) == 2 + elif "shared_experts.linear_fc2.weight" in name: + convert_names.append(f"model.layers.{layer_number}.mlp.shared_expert.down_proj.weight") + assert len(params) == 1 + elif "mlp.experts.linear_fc1" in name: # split gate_proj and up_proj + expert_id = name.split("weight")[-1] + convert_names.append(f"model.layers.{layer_number}.mlp.experts.{expert_id}.gate_proj.weight") + convert_names.append(f"model.layers.{layer_number}.mlp.experts.{expert_id}.up_proj.weight") + assert len(params) == 2 + elif "mlp.experts.linear_fc2" in name: + expert_id = name.split("weight")[-1] + convert_names.append(f"model.layers.{layer_number}.mlp.experts.{expert_id}.down_proj.weight") + assert len(params) == 1 + else: + raise NotImplementedError(f"Unsupported parameter name: {name}") + return convert_names, params + + +class McoreToHFWeightConverterQwen2_5_VL(McoreToHFWeightConverterDense): + def convert_param(self, name: str, params_one_group: list[torch.Tensor]) -> tuple[list[str], list[torch.Tensor]]: + direct_name_mapping = { + "language_model.embedding.word_embeddings.weight": "model.embed_tokens.weight", + "language_model.decoder.final_layernorm.weight": "model.norm.weight", + "language_model.output_layer.weight": "lm_head.weight", + "vision_model.patch_embed.proj.weight": "visual.patch_embed.proj.weight", + "vision_model.decoder.final_layernorm.weight": "visual.merger.ln_q.weight", + "vision_model.projection.encoder.linear_fc1.weight": "visual.merger.mlp.0.weight", + "vision_model.projection.encoder.linear_fc1.bias": "visual.merger.mlp.0.bias", + "vision_model.projection.encoder.linear_fc2.weight": "visual.merger.mlp.2.weight", + "vision_model.projection.encoder.linear_fc2.bias": "visual.merger.mlp.2.bias", + } + if name in direct_name_mapping: + return [direct_name_mapping[name]], [params_one_group[0]] + + if "self_attention" in name: + return self._convert_attention_param(name, params_one_group) + elif "mlp" in name: + return self._convert_mlp_param(name, params_one_group) + else: + raise NotImplementedError(f"Unsupported parameter name: {name}") + + def _convert_attention_param(self, name: str, params: list[torch.Tensor]) -> tuple[list[str], list[torch.Tensor]]: + model_type, _, _, layer_number = name.split(".")[:4] + + convert_names = [] + if model_type == "language_model": + name_map_after_layer = { + "self_attention.linear_qkv.bias": [ + "self_attn.q_proj.bias", + "self_attn.k_proj.bias", + "self_attn.v_proj.bias", + ], + "self_attention.linear_qkv.weight": [ + "self_attn.q_proj.weight", + "self_attn.k_proj.weight", + "self_attn.v_proj.weight", + ], + "self_attention.linear_proj.weight": "self_attn.o_proj.weight", + "self_attention.linear_qkv.layer_norm_weight": "input_layernorm.weight", + } + name_after_layer = ".".join(name.split(".")[-3:]) + mapped_name = name_map_after_layer.get(name_after_layer) + if isinstance(mapped_name, list): + assert len(params) == len(mapped_name) + for one in mapped_name: + convert_names.append(f"model.layers.{layer_number}.{one}") + else: + assert len(params) == 1 + convert_names.append(f"model.layers.{layer_number}.{mapped_name}") + elif model_type == "vision_model": + name_map_after_layer = { + "self_attention.linear_proj.weight": "attn.proj.weight", + "self_attention.linear_proj.bias": "attn.proj.bias", + "self_attention.linear_qkv.layer_norm_weight": "norm1.weight", + } + name_after_layer = ".".join(name.split(".")[-3:]) + mapped_name = name_map_after_layer.get(name_after_layer, None) + if mapped_name is None: + assert "linear_qkv" in name_after_layer + assert len(params) == 3 + new_param = torch.cat(params, dim=0) + params = [new_param] + if "bias" in name_after_layer: + convert_names.append(f"visual.blocks.{layer_number}.attn.qkv.bias") + else: + convert_names.append(f"visual.blocks.{layer_number}.attn.qkv.weight") + else: + assert len(params) == 1 + convert_names.append(f"visual.blocks.{layer_number}.{mapped_name}") + else: + raise NotImplementedError(f"Unsupported model type: {model_type}") + return convert_names, params + + def _convert_mlp_param(self, name: str, params: list[torch.Tensor]) -> tuple[list[str], list[torch.Tensor]]: + model_type, _, _, layer_number = name.split(".")[:4] + + convert_names = [] + if model_type == "language_model": + name_map_after_layer = { + "mlp.linear_fc1.weight": ["mlp.gate_proj.weight", "mlp.up_proj.weight"], + "mlp.linear_fc1.bias": ["mlp.gate_proj.bias", "mlp.up_proj.bias"], + "mlp.linear_fc2.weight": "mlp.down_proj.weight", + "mlp.linear_fc2.bias": "mlp.down_proj.bias", + "mlp.linear_fc1.layer_norm_weight": "post_attention_layernorm.weight", + } + name_after_layer = ".".join(name.split(".")[-3:]) + mapped_name = name_map_after_layer.get(name_after_layer) + if isinstance(mapped_name, list): + assert len(params) == len(mapped_name) + for one in mapped_name: + convert_names.append(f"model.layers.{layer_number}.{one}") + else: + assert len(params) == 1 + convert_names.append(f"model.layers.{layer_number}.{mapped_name}") + + elif model_type == "vision_model": + name_map_after_layer = { + "mlp.linear_fc1.weight": ["mlp.gate_proj.weight", "mlp.up_proj.weight"], + "mlp.linear_fc1.bias": ["mlp.gate_proj.bias", "mlp.up_proj.bias"], + "mlp.linear_fc2.weight": "mlp.down_proj.weight", + "mlp.linear_fc2.bias": "mlp.down_proj.bias", + "mlp.linear_fc1.layer_norm_weight": "norm2.weight", + } + name_after_layer = ".".join(name.split(".")[-3:]) + mapped_name = name_map_after_layer.get(name_after_layer) + if isinstance(mapped_name, list): + assert len(params) == len(mapped_name) + for one in mapped_name: + convert_names.append(f"visual.blocks.{layer_number}.{one}") + else: + assert len(params) == 1 + convert_names.append(f"visual.blocks.{layer_number}.{mapped_name}") + else: + raise NotImplementedError(f"Unsupported model type: {model_type}") + return convert_names, params + + +class McoreToHFWeightConverterDpskv3(McoreToHFWeightConverterBase): + def _convert_attention_param(self, name: str, params: list[torch.Tensor]) -> tuple[list[str], list[torch.Tensor]]: + # mcore + # 'decoder.layers.0.input_layernorm.weight' + # 'decoder.layers.0.self_attention.linear_proj.weight' + # 'decoder.layers.0.self_attention.linear_q_proj.weight' + # 'decoder.layers.0.self_attention.linear_kv_down_proj.weight' + # 'decoder.layers.0.self_attention.linear_kv_up_proj.layer_norm_weight' + # 'decoder.layers.0.self_attention.linear_kv_up_proj.weight' + # 'decoder.layers.0.self_attention.linear_q_down_proj.weight' + # 'decoder.layers.0.self_attention.linear_q_up_proj.weight' + # 'decoder.layers.0.self_attention.linear_q_up_proj.layer_norm_weight' + # hf + # 'model.layers.0.input_layernorm.weight' + # 'model.layers.0.self_attn.o_proj.weight' + # 'model.layers.0.self_attn.q_proj.weight' + # 'model.layers.0.self_attn.kv_a_proj_with_mqa.weight' + # 'model.layers.0.self_attn.kv_a_layernorm.weight' + # 'model.layers.0.self_attn.kv_b_proj.weight' + # 'model.layers.0.self_attn.q_a_proj.weight' + # 'model.layers.0.self_attn.q_b_proj.weight' + # 'model.layers.0.self_attn.q_a_layernorm.weight' + name_map_after_layer = { + "input_layernorm.weight": "input_layernorm.weight", + "self_attention.linear_proj.weight": "self_attn.o_proj.weight", + "self_attention.linear_q_proj.weight": "self_attn.q_proj.weight", + "self_attention.linear_kv_down_proj.weight": "self_attn.kv_a_proj_with_mqa.weight", + "self_attention.linear_kv_up_proj.layer_norm_weight": "self_attn.kv_a_layernorm.weight", + "self_attention.linear_kv_up_proj.weight": "self_attn.kv_b_proj.weight", + "self_attention.linear_q_down_proj.weight": "self_attn.q_a_proj.weight", + "self_attention.linear_q_up_proj.weight": "self_attn.q_b_proj.weight", + "self_attention.linear_q_up_proj.layer_norm_weight": "self_attn.q_a_layernorm.weight", + } + assert len(params) == 1 + convert_names = [] + layer_number = name.split(".")[2] + name_after_layer = name.split(f".{layer_number}.")[1] + convert_names.append(f"model.layers.{layer_number}.{name_map_after_layer[name_after_layer]}") + return convert_names, params + + def _convert_mlp_param(self, name: str, params: list[torch.Tensor]) -> tuple[list[str], list[torch.Tensor]]: + # mcore dense + # 'decoder.layers.0.mlp.linear_fc1.layer_norm_weight' + # 'decoder.layers.0.mlp.linear_fc2.weight' + # 'decoder.layers.0.mlp.linear_fc1.weight' + # --- + # 'decoder.layers.1.mlp.shared_experts.linear_fc1.weight' + # --- + # 'decoder.layers.1.mlp.shared_experts.linear_fc2.weight' + # hf dense + # 'model.layers.0.post_attention_layernorm.weight' + # 'model.layers.0.mlp.down_proj.weight' + # 'model.layers.0.mlp.gate_proj.weight' + # 'model.layers.0.mlp.up_proj.weight' + # 'model.layers.1.mlp.shared_experts.gate_proj.weight' + # 'model.layers.1.mlp.shared_experts.up_proj.weight' + # 'model.layers.1.mlp.shared_experts.down_proj.weight' + + # mcore moe + # 'decoder.layers.1.pre_mlp_layernorm.weight' + # 'decoder.layers.1.mlp.router.weight' + # 'decoder.layers.1.mlp.router.expert_bias' + # 'decoder.layers.1.mlp.experts.linear_fc1.weight0' + # --- + # 'decoder.layers.1.mlp.experts.linear_fc2.weight0' + # hf moe + # 'model.layers.1.post_attention_layernorm.weight' + # 'model.layers.1.mlp.gate.weight' + # 'model.layers.1.mlp.gate.e_score_correction_bias' + # 'model.layers.1.mlp.experts.0.gate_proj.weight' + # 'model.layers.1.mlp.experts.0.up_proj.weight' + # 'model.layers.1.mlp.experts.0.down_proj.weight' + + name_map_after_layer = { + "mlp.linear_fc1.layer_norm_weight": "post_attention_layernorm.weight", + "mlp.linear_fc2.weight": "mlp.down_proj.weight", + "mlp.shared_experts.linear_fc2.weight": "mlp.shared_experts.down_proj.weight", + "mlp.linear_fc1.weight": ["mlp.gate_proj.weight", "mlp.up_proj.weight"], + "mlp.shared_experts.linear_fc1.weight": [ + "mlp.shared_experts.gate_proj.weight", + "mlp.shared_experts.up_proj.weight", + ], + "pre_mlp_layernorm.weight": "post_attention_layernorm.weight", + "mlp.router.weight": "mlp.gate.weight", + "mlp.router.expert_bias": "mlp.gate.e_score_correction_bias", + } + convert_names = [] + layer_number = name.split(".")[2] + name_after_layer = name.split(f".{layer_number}.")[1] + if name_after_layer in name_map_after_layer: + mapped_name = name_map_after_layer[name_after_layer] + if isinstance(mapped_name, list): + assert len(params) == len(mapped_name) + for one in mapped_name: + convert_names.append(f"model.layers.{layer_number}.{one}") + else: + assert len(params) == 1 + convert_names.append(f"model.layers.{layer_number}.{mapped_name}") + else: + if "mlp.experts.linear_fc1.weight" in name: + expert_id = name.split("weight")[-1] + convert_names.append(f"model.layers.{layer_number}.mlp.experts.{expert_id}.gate_proj.weight") + convert_names.append(f"model.layers.{layer_number}.mlp.experts.{expert_id}.up_proj.weight") + assert len(params) == 2 + elif "mlp.experts.linear_fc2.weight" in name: + expert_id = name.split("weight")[-1] + convert_names.append(f"model.layers.{layer_number}.mlp.experts.{expert_id}.down_proj.weight") + assert len(params) == 1 + else: + raise NotImplementedError(f"Unsupported parameter name: {name}") + + return convert_names, params + + def _convert_mtp_param(self, name: str, params: list[torch.Tensor]) -> tuple[list[str], list[torch.Tensor]]: + assert self.mcore_config.mtp_num_layers == 1, "only support one mtp layer for now" + assert self.mcore_config.num_layers == 61, "only support 61 layers for now" + direct_name_mapping = { + "mtp.layers.0.enorm.weight": "model.layers.61.enorm.weight", + "mtp.layers.0.hnorm.weight": "model.layers.61.hnorm.weight", + "mtp.layers.0.eh_proj.weight": "model.layers.61.eh_proj.weight", + "mtp.layers.0.final_layernorm.weight": "model.layers.61.shared_head.norm.weight", + } + if name in direct_name_mapping: + return [direct_name_mapping[name]], [params[0]] + assert "mtp.layers.0.transformer_layer" in name, "only support transformer layer for now" + # use proxy name to convert + proxy_name = name.replace("mtp.layers.0.transformer_layer", "decoder.layers.61") + if "self_attention" in proxy_name or "input_layernorm.weight" in proxy_name: + convert_names, params = self._convert_attention_param(proxy_name, params) + elif "mlp" in proxy_name: + convert_names, params = self._convert_mlp_param(proxy_name, params) + else: + raise NotImplementedError(f"Unsupported parameter name: {name}") + return convert_names, params + + def convert_param(self, name: str, params_one_group: list[torch.Tensor]) -> tuple[list[str], list[torch.Tensor]]: + direct_name_mapping = { + "embedding.word_embeddings.weight": "model.embed_tokens.weight", + "decoder.final_layernorm.weight": "model.norm.weight", + "output_layer.weight": "lm_head.weight", + } + if name in direct_name_mapping: + return [direct_name_mapping[name]], [params_one_group[0]] + if "mtp" in name: + return self._convert_mtp_param(name, params_one_group) + elif "self_attention" in name or "input_layernorm.weight" in name: + return self._convert_attention_param(name, params_one_group) + elif "mlp" in name: + return self._convert_mlp_param(name, params_one_group) + else: + raise NotImplementedError(f"Unsupported parameter name: {name}") + + +class McoreToHFWeightConverterMixtral(McoreToHFWeightConverterDense): + def _convert_mlp_param(self, name: str, params: list[torch.Tensor]) -> tuple[list[str], list[torch.Tensor]]: + # decoder.layers.0.mlp.router.weight + # decoder.layers.0.mlp.experts.linear_fc1.weight0 - weight7 + # decoder.layers.0.mlp.experts.linear_fc2.weight0 - weight7 + + layer_number = name.split(".")[2] + convert_names = [] + if "pre_mlp_layernorm" in name: + convert_names.append(f"model.layers.{layer_number}.post_attention_layernorm.weight") + elif "mlp.router.weight" in name: + convert_names.append(f"model.layers.{layer_number}.block_sparse_moe.gate.weight") + elif "mlp.experts.linear_fc1.weight" in name: + expert_id = name.split("weight")[-1] + convert_names.append(f"model.layers.{layer_number}.block_sparse_moe.experts.{expert_id}.w1.weight") + convert_names.append(f"model.layers.{layer_number}.block_sparse_moe.experts.{expert_id}.w3.weight") + elif "mlp.experts.linear_fc2.weight" in name: + expert_id = name.split("weight")[-1] + convert_names.append(f"model.layers.{layer_number}.block_sparse_moe.experts.{expert_id}.w2.weight") + else: + raise NotImplementedError(f"Unsupported parameter name: {name}") + return convert_names, params + + +class McoreToHFWeightConverterQwen3Moe(McoreToHFWeightConverterDense): + def _convert_mlp_param(self, name: str, params: list[torch.Tensor]) -> tuple[list[str], list[torch.Tensor]]: + # qwen3 moe no share expert + + # 'decoder.layers.0.pre_mlp_layernorm.weight', + # 'decoder.layers.0.mlp.router.weight', + # moe1 + # 'decoder.layers.0.mlp.experts.linear_fc1.weight0', + # 'decoder.layers.0.mlp.experts.linear_fc1.weight1', + # 'decoder.layers.0.mlp.experts.linear_fc1.weight2', + # 'decoder.layers.0.mlp.experts.linear_fc1.weight3', + # moe2 + # 'decoder.layers.0.mlp.experts.linear_fc2.weight0', + # 'decoder.layers.0.mlp.experts.linear_fc2.weight1', + layer_number = name.split(".")[2] + convert_names = [] + if "pre_mlp_layernorm" in name: + convert_names.append(f"model.layers.{layer_number}.post_attention_layernorm.weight") + assert len(params) == 1 + elif "mlp.router.weight" in name: + convert_names.append(f"model.layers.{layer_number}.mlp.gate.weight") + assert len(params) == 1 + elif "mlp.experts.linear_fc1" in name: # split gate_proj and up_proj + expert_id = name.split("weight")[-1] + convert_names.append(f"model.layers.{layer_number}.mlp.experts.{expert_id}.gate_proj.weight") + convert_names.append(f"model.layers.{layer_number}.mlp.experts.{expert_id}.up_proj.weight") + assert len(params) == 2 + elif "mlp.experts.linear_fc2" in name: + expert_id = name.split("weight")[-1] + convert_names.append(f"model.layers.{layer_number}.mlp.experts.{expert_id}.down_proj.weight") + assert len(params) == 1 + else: + raise NotImplementedError(f"Unsupported parameter name: {name}") + return convert_names, params diff --git a/code/RL_model/verl/verl_train/verl/models/qwen2/__init__.py b/code/RL_model/verl/verl_train/verl/models/qwen2/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..1ce90c5eb352d85c59105c0dc85b5f1dd576f095 --- /dev/null +++ b/code/RL_model/verl/verl_train/verl/models/qwen2/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/code/RL_model/verl/verl_train/verl/models/qwen2/megatron/__init__.py b/code/RL_model/verl/verl_train/verl/models/qwen2/megatron/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..57e33ee9e905a64eb92df812d2f0bc6126066042 --- /dev/null +++ b/code/RL_model/verl/verl_train/verl/models/qwen2/megatron/__init__.py @@ -0,0 +1,34 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .modeling_qwen2_megatron import ( + ParallelQwen2ForCausalLM, + # rmpad with megatron + ParallelQwen2ForCausalLMRmPad, + # rmpad with megatron and pipeline parallelism + ParallelQwen2ForCausalLMRmPadPP, + ParallelQwen2ForValueRmPad, + ParallelQwen2ForValueRmPadPP, + # original model with megatron + ParallelQwen2Model, +) + +__all__ = [ + "ParallelQwen2ForCausalLM", + "ParallelQwen2ForCausalLMRmPad", + "ParallelQwen2ForCausalLMRmPadPP", + "ParallelQwen2ForValueRmPad", + "ParallelQwen2ForValueRmPadPP", + "ParallelQwen2Model", +] diff --git a/code/RL_model/verl/verl_train/verl/models/qwen2/megatron/checkpoint_utils/__init__.py b/code/RL_model/verl/verl_train/verl/models/qwen2/megatron/checkpoint_utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..1ce90c5eb352d85c59105c0dc85b5f1dd576f095 --- /dev/null +++ b/code/RL_model/verl/verl_train/verl/models/qwen2/megatron/checkpoint_utils/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/code/RL_model/verl/verl_train/verl/models/qwen2/megatron/checkpoint_utils/qwen2_loader.py b/code/RL_model/verl/verl_train/verl/models/qwen2/megatron/checkpoint_utils/qwen2_loader.py new file mode 100644 index 0000000000000000000000000000000000000000..3168635c7fe7b5b0e35a8e99b189057acbb8a5cb --- /dev/null +++ b/code/RL_model/verl/verl_train/verl/models/qwen2/megatron/checkpoint_utils/qwen2_loader.py @@ -0,0 +1,337 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import time + +import torch +import torch.distributed as dist + +from verl.utils.device import get_device_id, get_torch_device + + +def _megatron_calc_layer_map(config): + """Calculate the mapping of global layer_idx to local layer_idx + Returns: + layer_map (Dict: int -> tuple(int, int, int)): + mapping from the global layer index to + a tuple of (pp_rank, virtual_pp_rank, layer_idx inside model) + """ + from megatron.core import mpu + + pp_size = mpu.get_pipeline_model_parallel_world_size() + virtual_pp_size = mpu.get_virtual_pipeline_model_parallel_world_size() or 1 + + layer_map = dict() + num_layers_per_model = config.num_hidden_layers // pp_size // virtual_pp_size + assert num_layers_per_model * pp_size * virtual_pp_size == config.num_hidden_layers + + for pp_rank_idx in range(pp_size): + for virtual_pp_rank_idx in range(virtual_pp_size): + layer_offset = ( + virtual_pp_rank_idx * (config.num_hidden_layers // virtual_pp_size) + pp_rank_idx * num_layers_per_model + ) + for layer_idx in range(num_layers_per_model): + layer_map[layer_offset + layer_idx] = ( + pp_rank_idx, + virtual_pp_rank_idx, + layer_idx, + ) + return layer_map + + +def load_state_dict_to_megatron_qwen2( + state_dict, wrapped_models, config, params_dtype, is_value_model=False, tie_word_embeddings=False +): + """Load merged state_dict to sharded Megatron module in training.""" + from megatron.core import DistributedDataParallel as LocalDDP + from megatron.core import mpu + from megatron.core.transformer.module import Float16Module + from torch.nn.parallel import DistributedDataParallel as torchDDP + + from verl.utils.logger import print_rank_0 + from verl.utils.megatron_utils import unwrap_model + + start_time = time.time() + + def _get_gpt_model(model): + return model + + def fetch_params(module): + for param in module.parameters(): + torch.distributed.fetch( + param.data, src=mpu.get_data_parallel_src_rank(), group=mpu.get_data_parallel_group() + ) + + dp_rank = mpu.get_data_parallel_rank() + pp_rank = mpu.get_pipeline_model_parallel_rank() + pp_size = mpu.get_pipeline_model_parallel_world_size() + virtual_pp_size = mpu.get_virtual_pipeline_model_parallel_world_size() or 1 + mp_group = mpu.get_model_parallel_group() + + if torch.distributed.get_rank() == 0: + assert mp_group.rank() == 0, f"mp_rank:[{mp_group.rank}] != 0 on rank #0" + assert pp_rank == 0, f"pp_rank:[{pp_rank}] != 0 on rank #0" + assert dp_rank == 0, f"dp_rank:[{dp_rank}] != 0 on rank #0" + + if not isinstance(wrapped_models, list | tuple): + wrapped_models = list(wrapped_models) + + assert len(wrapped_models) == virtual_pp_size + num_layers_per_model = config.num_hidden_layers // pp_size // virtual_pp_size + assert num_layers_per_model * pp_size * virtual_pp_size == config.num_hidden_layers, ( + f"num_layers_per_model: {num_layers_per_model} * pp_size: {pp_size} * virtual_pp_size: " + f"{virtual_pp_size} != config.num_hidden_layers: {config.num_hidden_layers}" + ) + + models = [None] * len(wrapped_models) + + for i, wrapped_model in enumerate(wrapped_models): + models[i] = unwrap_model(wrapped_model, (torchDDP, LocalDDP, Float16Module)) + gpt_model_module = _get_gpt_model(models[i]) + assert len(gpt_model_module.model.layers) == num_layers_per_model + + def _fetch_tensor(tensor, name) -> torch.Tensor: + """fetch tensor""" + nonlocal state_dict + if tensor is not None: + tensor = tensor.data.copy_(state_dict[name], non_blocking=True) + + def _fetch_tp_shard_tensor_vocab(tensor, name, chunk_dim=0, mutate_func=None) -> torch.Tensor: + """fetch tensor in tp shards""" + nonlocal state_dict + tp_rank = mpu.get_tensor_model_parallel_rank() + tp_size = mpu.get_tensor_model_parallel_world_size() + if name in state_dict: + full_weight = state_dict[name] + + if mutate_func is not None: + full_weight = mutate_func(full_weight) + tensor_chunk = torch.chunk(full_weight, tp_size, dim=chunk_dim) + if tensor is not None: + tensor = tensor.data.copy_(tensor_chunk[tp_rank], non_blocking=True) + else: + print(f"tp_shard tensor:[{name}] not in state_dict, skip loading") + + def _fetch_tp_shard_tensor(tensor, name, chunk_dim=0, mutate_func=None) -> torch.Tensor: + """fetch tensor in tp shards""" + nonlocal state_dict + tp_rank = mpu.get_tensor_model_parallel_rank() + tp_size = mpu.get_tensor_model_parallel_world_size() + if name in state_dict: + full_weight = state_dict[name] + + if mutate_func is not None: + full_weight = mutate_func(full_weight) + tensor_chunk = torch.chunk(full_weight, tp_size, dim=chunk_dim) + if tensor is not None: + tensor = tensor.data.copy_(tensor_chunk[tp_rank], non_blocking=True) + else: + print(f"tp_shard tensor:[{name}] not in state_dict, skip loading") + + def _fetch_tp_shard_tensor_gate_up(tensor, gate_name, up_name) -> torch.Tensor: + """fetch gate_up tensor in tp shards""" + nonlocal state_dict + nonlocal mp_group + tp_rank = mpu.get_tensor_model_parallel_rank() + tp_size = mpu.get_tensor_model_parallel_world_size() + if gate_name in state_dict and up_name in state_dict: + gate_weight = state_dict[gate_name] + up_weight = state_dict[up_name] + new_gate_up_weight = torch.empty( + config.intermediate_size * 2, config.hidden_size, dtype=params_dtype, device=get_device_id() + ) + for i in range(tp_size): + intermediate_size_tp = config.intermediate_size // tp_size + gate_weight_tp = gate_weight[i * intermediate_size_tp : (i + 1) * intermediate_size_tp] + up_weight_tp = up_weight[i * intermediate_size_tp : (i + 1) * intermediate_size_tp] + new_gate_up_weight[intermediate_size_tp * 2 * i : intermediate_size_tp * 2 * (i + 1)].copy_( + torch.cat([gate_weight_tp, up_weight_tp], dim=0) + ) + + tensor_chunk = torch.chunk(new_gate_up_weight, tp_size, dim=0) + if tensor is not None: + tensor = tensor.data.copy_(tensor_chunk[tp_rank], non_blocking=True) + else: + print(f"tp_shard tensor:[{gate_name}, {up_name}] not in state_dict, skip loading") + + def _fetch_tp_shard_tensor_qkv(tensor, q_name, k_name, v_name, bias=False) -> torch.Tensor: + """fetch tensor in tp shards across mp_group""" + nonlocal state_dict + nonlocal mp_group + tp_rank = mpu.get_tensor_model_parallel_rank() + tp_size = mpu.get_tensor_model_parallel_world_size() + assert q_name in state_dict and k_name in state_dict and v_name in state_dict + full_weight_q = state_dict[q_name] + full_weight_k = state_dict[k_name] + full_weight_v = state_dict[v_name] + + hidden_size_per_head = config.hidden_size // config.num_attention_heads + + if config.num_key_value_heads >= tp_size: + q_size_tp = config.hidden_size // tp_size + kv_size_tp = hidden_size_per_head * config.num_key_value_heads // tp_size + total_size = q_size_tp + 2 * kv_size_tp + if not bias: + new_weight_qkv = torch.empty( + total_size * tp_size, config.hidden_size, dtype=params_dtype, device=get_device_id() + ) + else: + new_weight_qkv = torch.empty(total_size * tp_size, dtype=params_dtype, device=get_device_id()) + for i in range(tp_size): + q_part = full_weight_q[i * q_size_tp : (i + 1) * q_size_tp] + k_part = full_weight_k[i * kv_size_tp : (i + 1) * kv_size_tp] + v_part = full_weight_v[i * kv_size_tp : (i + 1) * kv_size_tp] + new_weight_qkv[i * total_size : (i + 1) * total_size].copy_(torch.cat([q_part, k_part, v_part], dim=0)) + + else: + q_size_tp = config.hidden_size // tp_size + kv_size_tp = hidden_size_per_head + total_size = q_size_tp + 2 * kv_size_tp + if not bias: + new_weight_qkv = torch.empty( + total_size * tp_size, config.hidden_size, dtype=params_dtype, device=get_device_id() + ) + else: + new_weight_qkv = torch.empty(total_size * tp_size, dtype=params_dtype, device=get_device_id()) + for i in range(tp_size): + q_part = full_weight_q[i * q_size_tp : (i + 1) * q_size_tp] + start_idx = i * config.num_key_value_heads // tp_size * hidden_size_per_head + end_idx = (i * config.num_key_value_heads // tp_size + 1) * hidden_size_per_head + k_part = full_weight_k[start_idx:end_idx] + v_part = full_weight_v[start_idx:end_idx] + new_weight_qkv[i * total_size : (i + 1) * total_size].copy_(torch.cat([q_part, k_part, v_part], dim=0)) + + tensor_chunk = torch.chunk(new_weight_qkv, tp_size, dim=0) + if tensor is not None: + tensor = tensor.data.copy_(tensor_chunk[tp_rank], non_blocking=True) + + # Embeddings + # ------------------- + print_rank_0("loading embeddings...") + gpt_model_module = _get_gpt_model(models[0]) + if pp_rank == 0: + embed_tokens_weight = gpt_model_module.model.embed_tokens.weight + _fetch_tp_shard_tensor_vocab(embed_tokens_weight, "model.embed_tokens.weight") + + # Transformer layers + # ------------------- + layer_map = _megatron_calc_layer_map(config) + + pp_rank = mpu.get_pipeline_model_parallel_rank() + pp_size = mpu.get_pipeline_model_parallel_world_size() + num_layer_per_pp = config.num_hidden_layers // pp_size + vpp_size = mpu.get_virtual_pipeline_model_parallel_world_size() + + layer_list = [] + if vpp_size is not None: + for vpp_rank in range(vpp_size): + num_layer_vpp_chunk = num_layer_per_pp // vpp_size + num_layer_this_model = num_layer_vpp_chunk + offset = vpp_rank * (config.num_hidden_layers // mpu.get_virtual_pipeline_model_parallel_world_size()) + ( + mpu.get_pipeline_model_parallel_rank() * num_layer_vpp_chunk + ) + layer_list.extend(list(range(offset, offset + num_layer_this_model))) + else: + num_layer_this_model = num_layer_per_pp + offset = pp_rank * num_layer_per_pp + layer_list.extend(list(range(offset, offset + num_layer_this_model))) + + for layer in layer_list: + print(f"{torch.distributed.get_rank()} loading layer #{layer}...") + layer_name = f"model.layers.{layer}" + dst_pp_rank, dst_virtual_pp_rank, dst_layer_idx = layer_map[layer] + + print( + f"{torch.distributed.get_rank()} offset: {offset}, num_layer_this_model: {num_layer_this_model}, " + f"layer_name: {layer_name}, layer_map[layer]: {layer_map[layer]}" + ) + + gpt_model_module = _get_gpt_model(models[dst_virtual_pp_rank]) + sync_layer = gpt_model_module.model.layers[dst_layer_idx] + + _fetch_tensor( + sync_layer.input_layernorm.weight if dst_pp_rank == pp_rank else None, + f"{layer_name}.input_layernorm.weight", + ) + + _fetch_tp_shard_tensor_qkv( + sync_layer.self_attn.qkv_proj.weight if dst_pp_rank == pp_rank else None, + f"{layer_name}.self_attn.q_proj.weight", + f"{layer_name}.self_attn.k_proj.weight", + f"{layer_name}.self_attn.v_proj.weight", + ) + + _fetch_tp_shard_tensor_qkv( + sync_layer.self_attn.qkv_proj.bias if dst_pp_rank == pp_rank else None, + f"{layer_name}.self_attn.q_proj.bias", + f"{layer_name}.self_attn.k_proj.bias", + f"{layer_name}.self_attn.v_proj.bias", + bias=True, + ) + + _fetch_tp_shard_tensor( + sync_layer.self_attn.o_proj.weight if dst_pp_rank == pp_rank else None, + f"{layer_name}.self_attn.o_proj.weight", + chunk_dim=1, + ) + + _fetch_tensor( + sync_layer.post_attention_layernorm.weight if dst_pp_rank == pp_rank else None, + f"{layer_name}.post_attention_layernorm.weight", + ) + + _fetch_tp_shard_tensor_gate_up( + sync_layer.mlp.gate_up_proj.weight if dst_pp_rank == pp_rank else None, + f"{layer_name}.mlp.gate_proj.weight", + f"{layer_name}.mlp.up_proj.weight", + ) + + _fetch_tp_shard_tensor( + sync_layer.mlp.down_proj.weight if dst_pp_rank == pp_rank else None, + f"{layer_name}.mlp.down_proj.weight", + chunk_dim=1, + ) + # Final Layernorm + # ------------------- + print_rank_0("loading final layernorm...") + gpt_model_module = _get_gpt_model(models[-1]) + _fetch_tensor( + getattr(gpt_model_module.model.norm, "weight", None), + "model.norm.weight", + ) + + if tie_word_embeddings: + print_rank_0("tie_word_embeddings skip load lm_head") + else: + print_rank_0("loading lm_head...") + if pp_rank + 1 == pp_size: + lm_head_weight = gpt_model_module.lm_head.weight + + if is_value_model: + if "lm_head.weight" in state_dict and state_dict["lm_head.weight"].shape[0] == 1: + _fetch_tensor(lm_head_weight, "lm_head.weight") + print_rank_0("load lm_head from value_head weight") + elif "reward_head.weight" in state_dict and state_dict["reward_head.weight"].shape[0] == 1: + _fetch_tensor(lm_head_weight, "reward_head.weight") + print_rank_0("load lm_head from value_head weight") + else: + _fetch_tensor(None, "lm_head.weight") + print_rank_0("fail to match lm_head in value_model") + + else: + _fetch_tp_shard_tensor(lm_head_weight, "lm_head.weight") + + dist.barrier() + get_torch_device().empty_cache() + print_rank_0(f"loading megatron ckpt done, time elapsed {time.time() - start_time}s") diff --git a/code/RL_model/verl/verl_train/verl/models/qwen2/megatron/checkpoint_utils/qwen2_loader_depracated.py b/code/RL_model/verl/verl_train/verl/models/qwen2/megatron/checkpoint_utils/qwen2_loader_depracated.py new file mode 100644 index 0000000000000000000000000000000000000000..770e3653366321159ec079c42009052aeaf26510 --- /dev/null +++ b/code/RL_model/verl/verl_train/verl/models/qwen2/megatron/checkpoint_utils/qwen2_loader_depracated.py @@ -0,0 +1,475 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import time + +import torch +import torch.distributed as dist + +from verl.utils.device import get_device_id, get_torch_device + + +def _megatron_calc_layer_map(config): + """Calculate the mapping of global layer_idx to local layer_idx + Returns: + layer_map (Dict: int -> tuple(int, int, int)): + mapping from the global layer index to + a tuple of (pp_rank, virtual_pp_rank, layer_idx inside model) + """ + from megatron.core import mpu + + pp_size = mpu.get_pipeline_model_parallel_world_size() + virtual_pp_size = mpu.get_virtual_pipeline_model_parallel_world_size() or 1 + + layer_map = dict() + num_layers_per_model = config.num_hidden_layers // pp_size // virtual_pp_size + assert num_layers_per_model * pp_size * virtual_pp_size == config.num_hidden_layers + + for pp_rank_idx in range(pp_size): + for virtual_pp_rank_idx in range(virtual_pp_size): + layer_offset = ( + virtual_pp_rank_idx * (config.num_hidden_layers // virtual_pp_size) + pp_rank_idx * num_layers_per_model + ) + for layer_idx in range(num_layers_per_model): + layer_map[layer_offset + layer_idx] = ( + pp_rank_idx, + virtual_pp_rank_idx, + layer_idx, + ) + return layer_map + + +def load_state_dict_to_megatron_qwen2( + state_dict, wrapped_models, config, params_dtype, is_value_model=False, tie_word_embeddings=False +): + """Load merged state_dict to sharded Megatron module in training.""" + from megatron.core import DistributedDataParallel as LocalDDP + from megatron.core import mpu + from megatron.core.transformer.module import Float16Module + from torch.nn.parallel import DistributedDataParallel as torchDDP + + from verl.utils.logger import print_rank_0 + from verl.utils.megatron_utils import unwrap_model + + start_time = time.time() + + def _get_gpt_model(model): + return model + + def broadcast_params(module): + for param in module.parameters(): + torch.distributed.broadcast( + param.data, src=mpu.get_data_parallel_src_rank(), group=mpu.get_data_parallel_group() + ) + + dp_rank = mpu.get_data_parallel_rank() + pp_rank = mpu.get_pipeline_model_parallel_rank() + pp_size = mpu.get_pipeline_model_parallel_world_size() + virtual_pp_size = mpu.get_virtual_pipeline_model_parallel_world_size() or 1 + mp_group = mpu.get_model_parallel_group() + + if torch.distributed.get_rank() == 0: + assert mp_group.rank() == 0, f"mp_rank:[{mp_group.rank}] != 0 on rank #0" + assert pp_rank == 0, f"pp_rank:[{pp_rank}] != 0 on rank #0" + assert dp_rank == 0, f"dp_rank:[{dp_rank}] != 0 on rank #0" + + if not isinstance(wrapped_models, list | tuple): + wrapped_models = list(wrapped_models) + + assert len(wrapped_models) == virtual_pp_size + num_layers_per_model = config.num_hidden_layers // pp_size // virtual_pp_size + assert num_layers_per_model * pp_size * virtual_pp_size == config.num_hidden_layers, ( + f"num_layers_per_model: {num_layers_per_model} * pp_size: {pp_size} * virtual_pp_size: " + f"{virtual_pp_size} != config.num_hidden_layers: {config.num_hidden_layers}" + ) + + models = [None] * len(wrapped_models) + + for i, wrapped_model in enumerate(wrapped_models): + models[i] = unwrap_model(wrapped_model, (torchDDP, LocalDDP, Float16Module)) + gpt_model_module = _get_gpt_model(models[i]) + assert len(gpt_model_module.model.layers) == num_layers_per_model + + def _broadcast_tensor(tensor, name) -> torch.Tensor: + """broadcast tensor from rank0 across mp_group""" + nonlocal state_dict + nonlocal mp_group + if torch.distributed.get_rank() == 0: + if name in state_dict: + weight = state_dict[name] + tensor_shape = weight.shape + else: + tensor_shape = None + else: + weight = None + tensor_shape = None + + obj_list = [tensor_shape] + dist.broadcast_object_list(obj_list, src=0, group=mp_group) + tensor_shape = obj_list[0] + + if tensor_shape is None: + # all or none ranks in the mp_group should reach here + print_rank_0(f"tensor:[{name}] not in state_dict, skip load") + return + + if tensor is None: + tensor = torch.empty( + tensor_shape, + dtype=params_dtype, + device=get_device_id(), + requires_grad=False, + ) + if torch.distributed.get_rank() == 0: + tensor.data.copy_(weight) + dist.broadcast(tensor, src=0, group=mp_group) + + def _broadcast_tp_shard_tensor_vocab(tensor, name, chunk_dim=0, mutate_func=None) -> torch.Tensor: + """broadcast tensor in tp shards across mp_group""" + nonlocal state_dict + nonlocal mp_group + tp_rank = mpu.get_tensor_model_parallel_rank() + tp_size = mpu.get_tensor_model_parallel_world_size() + + if torch.distributed.get_rank() == 0: + if name in state_dict: + full_weight = state_dict[name] + + if mutate_func is not None: + full_weight = mutate_func(full_weight) + tensor_chunk = torch.chunk(full_weight, tp_size, dim=chunk_dim) + chunk_shape = tensor_chunk[0].shape + else: + chunk_shape = None + else: + chunk_shape = None + + obj_list = [chunk_shape] + dist.broadcast_object_list(obj_list, src=0, group=mp_group) + chunk_shape = obj_list[0] + if chunk_shape is None: + # all or none ranks in the mp_group should reach here + print_rank_0(f"tp_shard tensor:[{name}] not in state_dict, skip loading") + return + + if tensor is None: + sync_tensor = torch.empty( + chunk_shape, + dtype=params_dtype, + device=get_device_id(), + requires_grad=False, + ) + else: + assert tensor.shape == chunk_shape, ( + f"rank #{torch.distributed.get_rank()} tensor {name} shape {tensor.shape} != {chunk_shape}" + ) + sync_tensor = torch.empty_like(tensor, device=get_device_id(), requires_grad=False) + + for i in range(tp_size): + if torch.distributed.get_rank() == 0: + sync_tensor.data.copy_(tensor_chunk[i]) + dist.broadcast(sync_tensor, src=0, group=mp_group) + if (i == tp_rank) and (tensor is not None): + tensor.data.copy_(sync_tensor) + + def _broadcast_tp_shard_tensor(tensor, name, chunk_dim=0, mutate_func=None) -> torch.Tensor: + """broadcast tensor in tp shards across mp_group""" + nonlocal state_dict + nonlocal mp_group + tp_rank = mpu.get_tensor_model_parallel_rank() + tp_size = mpu.get_tensor_model_parallel_world_size() + + if torch.distributed.get_rank() == 0: + if name in state_dict: + full_weight = state_dict[name] + if mutate_func is not None: + full_weight = mutate_func(full_weight) + tensor_chunk = torch.chunk(full_weight, tp_size, dim=chunk_dim) + chunk_shape = tensor_chunk[0].shape + else: + chunk_shape = None + else: + chunk_shape = None + + obj_list = [chunk_shape] + dist.broadcast_object_list(obj_list, src=0, group=mp_group) + chunk_shape = obj_list[0] + if chunk_shape is None: + # all or none ranks in the mp_group should reach here + print_rank_0(f"tp_shard tensor:[{name}] not in state_dict, skip loading") + return + + if tensor is None: + sync_tensor = torch.empty( + chunk_shape, + dtype=params_dtype, + device=get_device_id(), + requires_grad=False, + ) + else: + assert tensor.shape == chunk_shape, ( + f"rank #{torch.distributed.get_rank()} tensor {name} shape {tensor.shape} != {chunk_shape}" + ) + sync_tensor = torch.empty_like(tensor, device=get_device_id(), requires_grad=False) + + for i in range(tp_size): + if torch.distributed.get_rank() == 0: + sync_tensor.data.copy_(tensor_chunk[i]) + dist.broadcast(sync_tensor, src=0, group=mp_group) + if (i == tp_rank) and (tensor is not None): + tensor.data.copy_(sync_tensor) + + def _broadcast_tp_shard_tensor_gate_up(tensor, gate_name, up_name) -> torch.Tensor: + """broadcast tensor in tp shards across mp_group""" + nonlocal state_dict + nonlocal mp_group + tp_rank = mpu.get_tensor_model_parallel_rank() + tp_size = mpu.get_tensor_model_parallel_world_size() + + if torch.distributed.get_rank() == 0: + gate_weight = state_dict[gate_name] + up_weight = state_dict[up_name] + new_gate_up_weight = torch.empty( + config.intermediate_size * 2, config.hidden_size, dtype=params_dtype, device=get_device_id() + ) + for i in range(tp_size): + intermediate_size_tp = config.intermediate_size // tp_size + gate_weight_tp = gate_weight[i * intermediate_size_tp : (i + 1) * intermediate_size_tp] + up_weight_tp = up_weight[i * intermediate_size_tp : (i + 1) * intermediate_size_tp] + new_gate_up_weight[intermediate_size_tp * 2 * i : intermediate_size_tp * 2 * (i + 1)].copy_( + torch.cat([gate_weight_tp, up_weight_tp], dim=0) + ) + + tensor_chunk = torch.chunk(new_gate_up_weight, tp_size, dim=0) + chunk_shape = tensor_chunk[0].shape + else: + chunk_shape = None + + obj_list = [chunk_shape] + dist.broadcast_object_list(obj_list, src=0, group=mp_group) + chunk_shape = obj_list[0] + if chunk_shape is None: + # all or none ranks in the mp_group should reach here + print_rank_0(f"tp_shard tensor:[{gate_name, up_name}] not in state_dict, skip loading") + return + + if tensor is None: + sync_tensor = torch.empty( + chunk_shape, + dtype=params_dtype, + device=get_device_id(), + requires_grad=False, + ) + else: + assert tensor.shape == chunk_shape, ( + f"rank #{torch.distributed.get_rank() == 0:} tensor {gate_name, up_name} shape " + f"{tensor.shape} != {chunk_shape}" + ) + sync_tensor = torch.empty_like(tensor, device=get_device_id(), requires_grad=False) + + for i in range(tp_size): + if torch.distributed.get_rank() == 0: + sync_tensor.data.copy_(tensor_chunk[i]) + dist.broadcast(sync_tensor, src=0, group=mp_group) + if (i == tp_rank) and (tensor is not None): + tensor.data.copy_(sync_tensor) + + def _broadcast_tp_shard_tensor_qkv(tensor, q_name, k_name, v_name, bias=False) -> torch.Tensor: + """broadcast tensor in tp shards across mp_group""" + nonlocal state_dict + nonlocal mp_group + tp_rank = mpu.get_tensor_model_parallel_rank() + tp_size = mpu.get_tensor_model_parallel_world_size() + + if torch.distributed.get_rank() == 0: + assert q_name in state_dict and k_name in state_dict and v_name in state_dict + full_weight_q = state_dict[q_name] + full_weight_k = state_dict[k_name] + full_weight_v = state_dict[v_name] + + hidden_size_per_head = config.hidden_size // config.num_attention_heads + + if config.num_key_value_heads >= tp_size: + q_size_tp = config.hidden_size // tp_size + kv_size_tp = hidden_size_per_head * config.num_key_value_heads // tp_size + total_size = q_size_tp + 2 * kv_size_tp + if not bias: + new_weight_qkv = torch.empty( + total_size * tp_size, config.hidden_size, dtype=params_dtype, device=get_device_id() + ) + else: + new_weight_qkv = torch.empty(total_size * tp_size, dtype=params_dtype, device=get_device_id()) + for i in range(tp_size): + q_part = full_weight_q[i * q_size_tp : (i + 1) * q_size_tp] + k_part = full_weight_k[i * kv_size_tp : (i + 1) * kv_size_tp] + v_part = full_weight_v[i * kv_size_tp : (i + 1) * kv_size_tp] + new_weight_qkv[i * total_size : (i + 1) * total_size].copy_( + torch.cat([q_part, k_part, v_part], dim=0) + ) + + else: + q_size_tp = config.hidden_size // tp_size + kv_size_tp = hidden_size_per_head + total_size = q_size_tp + 2 * kv_size_tp + if not bias: + new_weight_qkv = torch.empty( + total_size * tp_size, config.hidden_size, dtype=params_dtype, device=get_device_id() + ) + else: + new_weight_qkv = torch.empty(total_size * tp_size, dtype=params_dtype, device=get_device_id()) + for i in range(tp_size): + q_part = full_weight_q[i * q_size_tp : (i + 1) * q_size_tp] + start_idx = i * config.num_key_value_heads // tp_size * hidden_size_per_head + end_idx = (i * config.num_key_value_heads // tp_size + 1) * hidden_size_per_head + k_part = full_weight_k[start_idx:end_idx] + v_part = full_weight_v[start_idx:end_idx] + new_weight_qkv[i * total_size : (i + 1) * total_size].copy_( + torch.cat([q_part, k_part, v_part], dim=0) + ) + + tensor_chunk = torch.chunk(new_weight_qkv, tp_size, dim=0) + chunk_shape = tensor_chunk[0].shape + else: + chunk_shape = None + + obj_list = [chunk_shape] + dist.broadcast_object_list(obj_list, src=0, group=mp_group) + chunk_shape = obj_list[0] + if chunk_shape is None: + # all or none ranks in the mp_group should reach here + print_rank_0(f"tp_shard tensor:[{q_name, k_name, v_name}] not in state_dict, skip loading") + return + + if tensor is None: + sync_tensor = torch.empty( + chunk_shape, + dtype=params_dtype, + device=get_device_id(), + requires_grad=False, + ) + else: + assert tensor.shape == chunk_shape, ( + f"rank #{torch.distributed.get_rank()} tensor {q_name} shape {tensor.shape} != {chunk_shape}" + ) + sync_tensor = torch.empty_like(tensor, device=get_device_id(), requires_grad=False) + + for i in range(tp_size): + if torch.distributed.get_rank() == 0: + sync_tensor.data.copy_(tensor_chunk[i]) + dist.broadcast(sync_tensor, src=0, group=mp_group) + if (i == tp_rank) and (tensor is not None): + tensor.data.copy_(sync_tensor) + + if dp_rank == 0: + # Embeddings + # ------------------- + print_rank_0("loading embeddings...") + gpt_model_module = _get_gpt_model(models[0]) + embed_tokens_weight = None + if pp_rank == 0: + embed_tokens_weight = gpt_model_module.model.embed_tokens.weight + _broadcast_tp_shard_tensor_vocab(embed_tokens_weight, "model.embed_tokens.weight") + + # Transformer layers + # ------------------- + layer_map = _megatron_calc_layer_map(config) + + for layer in range(config.num_hidden_layers): + print_rank_0(f"loading layer #{layer}...") + layer_name = f"model.layers.{layer}" + dst_pp_rank, dst_virtual_pp_rank, dst_layer_idx = layer_map[layer] + + gpt_model_module = _get_gpt_model(models[dst_virtual_pp_rank]) + sync_layer = gpt_model_module.model.layers[dst_layer_idx] + + _broadcast_tensor( + sync_layer.input_layernorm.weight if dst_pp_rank == pp_rank else None, + f"{layer_name}.input_layernorm.weight", + ) + + _broadcast_tp_shard_tensor_qkv( + sync_layer.self_attn.qkv_proj.weight if dst_pp_rank == pp_rank else None, + f"{layer_name}.self_attn.q_proj.weight", + f"{layer_name}.self_attn.k_proj.weight", + f"{layer_name}.self_attn.v_proj.weight", + ) + + _broadcast_tp_shard_tensor_qkv( + sync_layer.self_attn.qkv_proj.bias if dst_pp_rank == pp_rank else None, + f"{layer_name}.self_attn.q_proj.bias", + f"{layer_name}.self_attn.k_proj.bias", + f"{layer_name}.self_attn.v_proj.bias", + bias=True, + ) + + _broadcast_tp_shard_tensor( + sync_layer.self_attn.o_proj.weight if dst_pp_rank == pp_rank else None, + f"{layer_name}.self_attn.o_proj.weight", + chunk_dim=1, + ) + + _broadcast_tensor( + sync_layer.post_attention_layernorm.weight if dst_pp_rank == pp_rank else None, + f"{layer_name}.post_attention_layernorm.weight", + ) + + _broadcast_tp_shard_tensor_gate_up( + sync_layer.mlp.gate_up_proj.weight if dst_pp_rank == pp_rank else None, + f"{layer_name}.mlp.gate_proj.weight", + f"{layer_name}.mlp.up_proj.weight", + ) + + _broadcast_tp_shard_tensor( + sync_layer.mlp.down_proj.weight if dst_pp_rank == pp_rank else None, + f"{layer_name}.mlp.down_proj.weight", + chunk_dim=1, + ) + # Final Layernorm + # ------------------- + print_rank_0("loading final layernorm...") + gpt_model_module = _get_gpt_model(models[-1]) + _broadcast_tensor( + getattr(gpt_model_module.model.norm, "weight", None), + "model.norm.weight", + ) + + if tie_word_embeddings: + print_rank_0("tie_word_embeddings skip load lm_head") + else: + print_rank_0("loading lm_head...") + lm_head_weight = None + if pp_rank + 1 == pp_size: + lm_head_weight = gpt_model_module.lm_head.weight + + if is_value_model: + if "lm_head.weight" in state_dict and state_dict["lm_head.weight"].shape[0] == 1: + _broadcast_tensor(lm_head_weight, "lm_head.weight") + print_rank_0("load lm_head from value_head weight") + elif "reward_head.weight" in state_dict and state_dict["reward_head.weight"].shape[0] == 1: + _broadcast_tensor(lm_head_weight, "reward_head.weight") + print_rank_0("load lm_head from value_head weight") + else: + _broadcast_tensor(None, "lm_head.weight") + print_rank_0("fail to match lm_head in value_model") + + else: + _broadcast_tp_shard_tensor(lm_head_weight, "lm_head.weight") + + dist.barrier() + # Broadcast weights inside data parallel groups + for wrapped_model in wrapped_models: + broadcast_params(wrapped_model) + + get_torch_device().empty_cache() + print_rank_0(f"loading megatron ckpt done, time elapsed {time.time() - start_time}s") diff --git a/code/RL_model/verl/verl_train/verl/models/qwen2/megatron/checkpoint_utils/qwen2_saver.py b/code/RL_model/verl/verl_train/verl/models/qwen2/megatron/checkpoint_utils/qwen2_saver.py new file mode 100644 index 0000000000000000000000000000000000000000..737f73b4c6163ee674d97466b4fb37b71df2534b --- /dev/null +++ b/code/RL_model/verl/verl_train/verl/models/qwen2/megatron/checkpoint_utils/qwen2_saver.py @@ -0,0 +1,448 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import time + +import torch +import torch.distributed as dist +from megatron.core import mpu +from megatron.core.distributed import DistributedDataParallel as LocalDDP +from megatron.core.transformer.module import Float16Module +from torch.nn.parallel import DistributedDataParallel as torchDDP + +from verl.utils.device import get_device_id, get_torch_device +from verl.utils.logger import print_rank_0 +from verl.utils.megatron_utils import unwrap_model + + +def _megatron_calc_global_rank(tp_rank: int = 0, dp_rank: int = 0, pp_rank: int = 0): + """given TP,DP,PP rank to get the global rank.""" + + tp_size = mpu.get_tensor_model_parallel_world_size() + dp_size = mpu.get_data_parallel_world_size() + pp_size = mpu.get_pipeline_model_parallel_world_size() + assert tp_size * dp_size * pp_size == torch.distributed.get_world_size(), ( + f"{tp_size} x {dp_size} x {pp_size} != {torch.distributed.get_world_size()}" + ) + # We only support TP-DP-PP grouping, for correctness when resharding + return (pp_rank * dp_size + dp_rank) * tp_size + tp_rank + + +def _megatron_calc_layer_map(config): + """Calculate the mapping of global layer_idx to local layer_idx + Returns: + layer_map (Dict: int -> tuple(int, int, int)): + mapping from the global layer index to + a tuple of (pp_rank, virtual_pp_rank, layer_idx inside model) + """ + from megatron.core import mpu + + pp_size = mpu.get_pipeline_model_parallel_world_size() + virtual_pp_size = mpu.get_virtual_pipeline_model_parallel_world_size() or 1 + + layer_map = dict() + num_layers_per_model = config.num_hidden_layers // pp_size // virtual_pp_size + assert num_layers_per_model * pp_size * virtual_pp_size == config.num_hidden_layers + + for pp_rank_idx in range(pp_size): + for virtual_pp_rank_idx in range(virtual_pp_size): + layer_offset = ( + virtual_pp_rank_idx * (config.num_hidden_layers // virtual_pp_size) + pp_rank_idx * num_layers_per_model + ) + for layer_idx in range(num_layers_per_model): + layer_map[layer_offset + layer_idx] = ( + pp_rank_idx, + virtual_pp_rank_idx, + layer_idx, + ) + return layer_map + + +def merge_megatron_ckpt_qwen2(wrapped_models, config, dtype, is_value_model=False, tie_word_embeddings=False): + """Merge sharded parameters of a Megatron module into a merged checkpoint. + + Args: + wrapped_models (list of megatron.core.distributed.DistributedDataParallel): + The local DDP wrapped megatron modules. + config (str or None): + HF config for model + dtype: model params type + is_value_model: if model is value model + tie_word_embeddings: tie_word_embeddings + Returns: + state_dict (dict): + The merged state_dict in rank 0, and an empty dictionary in other ranks. + """ + start_time = time.time() + + def _get_gpt_model(model): + return model + + dp_rank = mpu.get_data_parallel_rank() + pp_size = mpu.get_pipeline_model_parallel_world_size() + pp_rank = mpu.get_pipeline_model_parallel_rank() + virtual_pp_size = mpu.get_virtual_pipeline_model_parallel_world_size() or 1 + mp_group = mpu.get_model_parallel_group() + + if dist.get_rank() == 0: + assert mp_group.rank() == 0, f"mp_rank:[{mp_group.rank}] != 0 on rank #0" + assert pp_rank == 0, f"pp_rank:[{pp_rank}] != 0 on rank #0" + assert dp_rank == 0, f"dp_rank:[{dp_rank}] != 0 on rank #0" + + if not isinstance(wrapped_models, list | tuple): + wrapped_models = list(wrapped_models) + + assert len(wrapped_models) == virtual_pp_size + num_layers_per_model = config.num_hidden_layers // pp_size // virtual_pp_size + assert num_layers_per_model * pp_size * virtual_pp_size == config.num_hidden_layers + + models = [None] * len(wrapped_models) + + for i, wrapped_model in enumerate(wrapped_models): + models[i] = unwrap_model(wrapped_model, (torchDDP, LocalDDP, Float16Module)) + assert len(models[i].model.layers) == num_layers_per_model, ( + "len model layers {} not equal to num_layers_per_model {}".format( + len(models[i].model.layers), num_layers_per_model + ) + ) + + state_dict = dict() + + def _get_cpu_tensor(tensor: torch.Tensor): + if tensor is None: + return None + if tensor.device == torch.device("cpu"): + return tensor.detach().clone() + return tensor.detach().cpu() + + def _broadcast_tensor(tensor, name, src_pp_rank) -> torch.Tensor: + """broadcast tensor across mp_group""" + nonlocal state_dict + nonlocal mp_group + src_rank = _megatron_calc_global_rank(tp_rank=0, dp_rank=0, pp_rank=src_pp_rank) + + if torch.distributed.get_rank() == src_rank: + if tensor is None: + weight = None + tensor_shape = None + else: + weight = tensor + tensor_shape = weight.shape + else: + weight = None + tensor_shape = None + + obj_list = [tensor_shape] + dist.broadcast_object_list(obj_list, src=src_rank, group=mp_group) + tensor_shape = obj_list[0] + + if tensor_shape is None: + # all or none ranks in the mp_group should reach here + print_rank_0(f"tensor:[{name}] not exist, skip collect") + return + + if weight is None: + weight = torch.empty( + tensor_shape, + dtype=dtype, + device=get_device_id(), + requires_grad=False, + ) + + dist.broadcast(weight, src=src_rank, group=mp_group) + + if torch.distributed.get_rank() == 0: + state_dict[name] = _get_cpu_tensor(weight) + + def _broadcast_tp_shard_tensor(tensor, name, src_pp_rank, concat_dim=0, mutate_func=None) -> torch.Tensor: + """broadcast tensor in tp shards across mp_group""" + nonlocal state_dict + nonlocal mp_group + tp_size = mpu.get_tensor_model_parallel_world_size() + src_rank = _megatron_calc_global_rank(tp_rank=0, dp_rank=0, pp_rank=src_pp_rank) + + chunk_shape = tensor.shape if torch.distributed.get_rank() == src_rank else None + + obj_list = [chunk_shape] + dist.broadcast_object_list(obj_list, src=src_rank, group=mp_group) + chunk_shape = obj_list[0] + if chunk_shape is None: + # all or none ranks in the mp_group should reach here + print_rank_0(f"tp_shard tensor:[{name}] not exist, skip collecting") + return + + buffer_tensor = torch.empty( + chunk_shape, + dtype=dtype, + device=get_device_id(), + requires_grad=False, + ) + + chunk_tensors = [None] * tp_size + + for i in range(tp_size): + cur_src_rank = _megatron_calc_global_rank(tp_rank=i, dp_rank=0, pp_rank=src_pp_rank) + sync_tensor = tensor if torch.distributed.get_rank() == cur_src_rank else buffer_tensor + dist.broadcast(sync_tensor, src=cur_src_rank, group=mp_group) + + if torch.distributed.get_rank() == 0: + chunk_tensors[i] = _get_cpu_tensor(sync_tensor) + + if torch.distributed.get_rank() == 0: + full_tensor = torch.concat(chunk_tensors, dim=concat_dim) + if mutate_func is not None: + full_tensor = mutate_func(full_tensor) + state_dict[name] = full_tensor + + def _broadcast_tp_shard_tensor_gate_up(tensor, gate_name, up_name, src_pp_rank) -> torch.Tensor: + """broadcast tensor in tp shards across mp_group""" + nonlocal state_dict + nonlocal mp_group + tp_size = mpu.get_tensor_model_parallel_world_size() + src_rank = _megatron_calc_global_rank(tp_rank=0, dp_rank=0, pp_rank=src_pp_rank) + + chunk_shape = tensor.shape if torch.distributed.get_rank() == src_rank else None + + obj_list = [chunk_shape] + dist.broadcast_object_list(obj_list, src=src_rank, group=mp_group) + chunk_shape = obj_list[0] + if chunk_shape is None: + # all or none ranks in the mp_group should reach here + print_rank_0(f"tp_shard tensor:[{gate_name, up_name}] not exist, skip collecting") + return + + buffer_tensor = torch.empty( + chunk_shape, + dtype=dtype, + device=get_device_id(), + requires_grad=False, + ) + + chunk_tensors = [None] * tp_size + + for i in range(tp_size): + cur_src_rank = _megatron_calc_global_rank(tp_rank=i, dp_rank=0, pp_rank=src_pp_rank) + sync_tensor = tensor if torch.distributed.get_rank() == cur_src_rank else buffer_tensor + dist.broadcast(sync_tensor, src=cur_src_rank, group=mp_group) + + if torch.distributed.get_rank() == 0: + chunk_tensors[i] = _get_cpu_tensor(sync_tensor) + + if torch.distributed.get_rank() == 0: + full_tensor = torch.concat(chunk_tensors, dim=0) + intermediate_size_tp = config.intermediate_size // tp_size + gate_weight_list = [] + up_weight_list = [] + for i in range(tp_size): + gate_up_weight_tp = full_tensor[intermediate_size_tp * 2 * i : intermediate_size_tp * 2 * (i + 1)] + gate_weight_tp = gate_up_weight_tp[:intermediate_size_tp] + up_weight_tp = gate_up_weight_tp[intermediate_size_tp:] + gate_weight_list.append(gate_weight_tp) + up_weight_list.append(up_weight_tp) + + state_dict[gate_name] = torch.cat(gate_weight_list, dim=0) + state_dict[up_name] = torch.cat(up_weight_list, dim=0) + + def _broadcast_tp_shard_tensor_qkv(tensor, q_name, k_name, v_name, src_pp_rank): + """broadcast tensor in tp shards across mp_group""" + nonlocal state_dict + nonlocal mp_group + tp_size = mpu.get_tensor_model_parallel_world_size() + src_rank = _megatron_calc_global_rank(tp_rank=0, dp_rank=0, pp_rank=src_pp_rank) + + chunk_shape = tensor.shape if torch.distributed.get_rank() == src_rank else None + + obj_list = [chunk_shape] + dist.broadcast_object_list(obj_list, src=src_rank, group=mp_group) + chunk_shape = obj_list[0] + if chunk_shape is None: + # all or none ranks in the mp_group should reach here + print_rank_0(f"tp_shard tensor:[{q_name}] not exist, skip collecting") + return + + buffer_tensor = torch.empty( + chunk_shape, + dtype=dtype, + device=get_device_id(), + requires_grad=False, + ) + + chunk_tensors = [None] * tp_size + + for i in range(tp_size): + cur_src_rank = _megatron_calc_global_rank(tp_rank=i, dp_rank=0, pp_rank=src_pp_rank) + sync_tensor = tensor if torch.distributed.get_rank() == cur_src_rank else buffer_tensor + dist.broadcast(sync_tensor, src=cur_src_rank, group=mp_group) + + if torch.distributed.get_rank() == 0: + chunk_tensors[i] = _get_cpu_tensor(sync_tensor) + + if torch.distributed.get_rank() == 0: + full_tensor = torch.concat(chunk_tensors, dim=0) + q_weight_list = [] + k_weight_list = [] + v_weight_list = [] + hidden_size_per_head = config.hidden_size // config.num_attention_heads + + if config.num_key_value_heads >= tp_size: + q_size_tp = config.hidden_size // tp_size + kv_size_tp = hidden_size_per_head * config.num_key_value_heads // tp_size + total_size = q_size_tp + 2 * kv_size_tp + for i in range(tp_size): + qkv_part = full_tensor[i * total_size : (i + 1) * total_size] + q_part = qkv_part[:q_size_tp] + k_part = qkv_part[q_size_tp : q_size_tp + kv_size_tp] + v_part = qkv_part[q_size_tp + kv_size_tp : total_size] + q_weight_list.append(q_part) + k_weight_list.append(k_part) + v_weight_list.append(v_part) + else: + q_size_tp = config.hidden_size // tp_size + kv_size_tp = hidden_size_per_head + total_size = q_size_tp + 2 * kv_size_tp + for i in range(tp_size): + qkv_part = full_tensor[i * total_size : (i + 1) * total_size] + q_part = qkv_part[:q_size_tp] + k_part = qkv_part[q_size_tp : q_size_tp + kv_size_tp] + v_part = qkv_part[q_size_tp + kv_size_tp : total_size] + q_weight_list.append(q_part) + if i * config.num_key_value_heads % tp_size == 0: + k_weight_list.append(k_part) + v_weight_list.append(v_part) + + state_dict[q_name] = torch.cat(q_weight_list, dim=0) + state_dict[k_name] = torch.cat(k_weight_list, dim=0) + state_dict[v_name] = torch.cat(v_weight_list, dim=0) + + # empty cache before collecting weights + get_torch_device().empty_cache() + # Embeddings + # ------------------- + if dp_rank == 0: + # Embeddings + # ------------------- + print_rank_0("collecting embeddings...") + gpt_model_module = _get_gpt_model(models[0]) + _broadcast_tp_shard_tensor( + gpt_model_module.model.embed_tokens.weight if pp_rank == 0 else None, + "model.embed_tokens.weight", + src_pp_rank=0, + ) + + # Transformer layers + # ------------------- + layer_map = _megatron_calc_layer_map(config) + for layer in range(config.num_hidden_layers): + print_rank_0(f"collecting layer #{layer}...") + layer_name = f"model.layers.{layer}" + src_pp_rank, src_virtual_pp_rank, src_layer_idx = layer_map[layer] + + gpt_model_module = _get_gpt_model(models[src_virtual_pp_rank]) + sync_layer = gpt_model_module.model.layers[src_layer_idx] + + _broadcast_tensor( + sync_layer.input_layernorm.weight, + f"{layer_name}.input_layernorm.weight", + src_pp_rank=src_pp_rank, + ) + + _broadcast_tp_shard_tensor_qkv( + sync_layer.self_attn.qkv_proj.weight, + f"{layer_name}.self_attn.q_proj.weight", + f"{layer_name}.self_attn.k_proj.weight", + f"{layer_name}.self_attn.v_proj.weight", + src_pp_rank=src_pp_rank, + ) + + _broadcast_tp_shard_tensor_qkv( + sync_layer.self_attn.qkv_proj.bias, + f"{layer_name}.self_attn.q_proj.bias", + f"{layer_name}.self_attn.k_proj.bias", + f"{layer_name}.self_attn.v_proj.bias", + src_pp_rank=src_pp_rank, + ) + + _broadcast_tp_shard_tensor( + sync_layer.self_attn.o_proj.weight, + f"{layer_name}.self_attn.o_proj.weight", + concat_dim=1, + src_pp_rank=src_pp_rank, + ) + + _broadcast_tensor( + sync_layer.post_attention_layernorm.weight, + f"{layer_name}.post_attention_layernorm.weight", + src_pp_rank=src_pp_rank, + ) + + _broadcast_tp_shard_tensor_gate_up( + sync_layer.mlp.gate_up_proj.weight, + f"{layer_name}.mlp.gate_proj.weight", + f"{layer_name}.mlp.up_proj.weight", + src_pp_rank=src_pp_rank, + ) + + _broadcast_tp_shard_tensor( + sync_layer.mlp.down_proj.weight, + f"{layer_name}.mlp.down_proj.weight", + concat_dim=1, + src_pp_rank=src_pp_rank, + ) + + # Final Layernorm + # ------------------- + print_rank_0("collecting final layernorm...") + gpt_model_module = _get_gpt_model(models[-1]) + _broadcast_tensor( + getattr(gpt_model_module.model.norm, "weight", None), + "model.norm.weight", + src_pp_rank=pp_size - 1, + ) + + if tie_word_embeddings: + print_rank_0("tie word embedding skip load lm_head...") + else: + print_rank_0("collecting lm_head...") + + if is_value_model: + _broadcast_tensor( + gpt_model_module.lm_head.weight if pp_rank == pp_size - 1 else None, + "lm_head.weight", + src_pp_rank=pp_size - 1, + ) + _broadcast_tensor( + gpt_model_module.reward_head.weight + if pp_rank == pp_size - 1 and getattr(gpt_model_module, "reward_weight", None) is not None + else None, + "reward_head.weight", + src_pp_rank=pp_size - 1, + ) + + else: + _broadcast_tp_shard_tensor( + getattr(gpt_model_module.lm_head, "weight", None) if pp_rank == pp_size - 1 else None, + "lm_head.weight", + src_pp_rank=pp_size - 1, + ) + + dist.barrier() + + get_torch_device().empty_cache() + if torch.distributed.get_rank() == 0: + for k, v in state_dict.items(): + if dtype != v.dtype: + state_dict[k] = v.to(dtype) + + print_rank_0(f"merge megatron ckpt done, time elapsed {time.time() - start_time}s") + return state_dict diff --git a/code/RL_model/verl/verl_train/verl/models/qwen2/megatron/layers/__init__.py b/code/RL_model/verl/verl_train/verl/models/qwen2/megatron/layers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..263ea596fa758fdef2201e9e99e4a5c7d435e434 --- /dev/null +++ b/code/RL_model/verl/verl_train/verl/models/qwen2/megatron/layers/__init__.py @@ -0,0 +1,26 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .parallel_attention import ParallelQwen2Attention +from .parallel_decoder import ParallelQwen2DecoderLayer, ParallelQwen2DecoderLayerRmPad +from .parallel_mlp import ParallelQwen2MLP +from .parallel_rmsnorm import ParallelQwen2RMSNorm + +__all__ = [ + "ParallelQwen2Attention", + "ParallelQwen2DecoderLayer", + "ParallelQwen2DecoderLayerRmPad", + "ParallelQwen2MLP", + "ParallelQwen2RMSNorm", +] diff --git a/code/RL_model/verl/verl_train/verl/models/qwen2/megatron/layers/parallel_attention.py b/code/RL_model/verl/verl_train/verl/models/qwen2/megatron/layers/parallel_attention.py new file mode 100644 index 0000000000000000000000000000000000000000..4e4f59101511e39a67e10d31f4a001c79f366ce5 --- /dev/null +++ b/code/RL_model/verl/verl_train/verl/models/qwen2/megatron/layers/parallel_attention.py @@ -0,0 +1,400 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from typing import Optional + +import torch.nn.functional as F +from einops import rearrange +from transformers.utils import is_flash_attn_2_available + +if is_flash_attn_2_available(): + from flash_attn import flash_attn_varlen_func + from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa: F401 + +import torch +from flash_attn.layers.rotary import apply_rotary_emb +from megatron.core import ModelParallelConfig, tensor_parallel +from megatron.core import parallel_state as mpu +from torch import nn +from transformers import Qwen2Config + +from verl.models.qwen2.megatron.layers.parallel_linear import QKVParallelLinear +from verl.utils.megatron import tensor_parallel as tp_utils + + +class Qwen2RotaryEmbedding(nn.Module): + def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): + super().__init__() + + self.dim = dim + self.max_position_embeddings = max_position_embeddings + self.base = base + inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + + # Build here to make `torch.jit.trace` work. + self._set_cos_sin_cache( + seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype() + ) + + def _set_cos_sin_cache(self, seq_len, device, dtype): + self.max_seq_len_cached = seq_len + t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) + + freqs = torch.einsum("i,j->ij", t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) + self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) + + def forward(self, x, seq_len=None): + # x: [bs, num_attention_heads, seq_len, head_size] + if seq_len > self.max_seq_len_cached: + self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) + + return ( + self.cos_cached[:seq_len].to(dtype=x.dtype), + self.sin_cached[:seq_len].to(dtype=x.dtype), + ) + + +class Qwen2LinearScalingRotaryEmbedding(Qwen2RotaryEmbedding): + """Qwen2RotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev""" + + def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0): + self.scaling_factor = scaling_factor + super().__init__(dim, max_position_embeddings, base, device) + + def _set_cos_sin_cache(self, seq_len, device, dtype): + self.max_seq_len_cached = seq_len + t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) + t = t / self.scaling_factor + + freqs = torch.einsum("i,j->ij", t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) + self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) + + +class Qwen2DynamicNTKScalingRotaryEmbedding(Qwen2RotaryEmbedding): + """Qwen2RotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla""" + + def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0): + self.scaling_factor = scaling_factor + super().__init__(dim, max_position_embeddings, base, device) + + def _set_cos_sin_cache(self, seq_len, device, dtype): + self.max_seq_len_cached = seq_len + + if seq_len > self.max_position_embeddings: + base = self.base * ( + (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1) + ) ** (self.dim / (self.dim - 2)) + inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + + t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) + + freqs = torch.einsum("i,j->ij", t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) + self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids): + cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] + sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +class ParallelQwen2Attention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: Qwen2Config, megatron_config: ModelParallelConfig): + super().__init__() + self.config = config + self.megatron_config = megatron_config + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.hidden_size // self.num_heads + self.num_key_value_heads = config.num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.max_position_embeddings = config.max_position_embeddings + self.rope_theta = config.rope_theta + + # assign values after tp + tp_size = mpu.get_tensor_model_parallel_world_size() + assert self.num_heads % tp_size == 0, ( + f"num_head must be divisible by tp_size. Got num_head={self.num_heads}, tp_size={tp_size}" + ) + assert self.num_key_value_heads % tp_size == 0, ( + f"num_key_value_heads must be divisible by tp_size. Got num_key_value_heads=" + f"{self.num_key_value_heads}, tp_size={tp_size}" + ) + + self.num_heads_per_tp = self.num_heads // tp_size + self.num_key_value_heads_per_tp = self.num_key_value_heads // tp_size + self.hidden_size_per_tp = self.hidden_size // tp_size + + if (self.head_dim * self.num_heads) != self.hidden_size: + raise ValueError( + f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size} and " + f"`num_heads`: {self.num_heads})." + ) + + column_kwargs = tp_utils.get_default_kwargs_for_column_parallel_linear() + row_kwargs = tp_utils.get_default_kwargs_for_row_parallel_linear() + + if megatron_config is not None: + assert column_kwargs.get("config", False), "must have ModelParallelConfig" + assert row_kwargs.get("config", False), "must have ModelParallelConfig" + tp_utils.update_kwargs_with_config(column_kwargs, megatron_config) + tp_utils.update_kwargs_with_config(row_kwargs, megatron_config) + + # [self.q_size, self.k_size, self.v_size] + self.qkv_proj = QKVParallelLinear( + input_size=self.hidden_size, + num_heads=self.num_heads, + num_key_value_heads=self.num_key_value_heads, + head_dim=self.head_dim, + # bias=config.attention_bias, + bias=True, + gather_output=False, + skip_bias_add=False, + **column_kwargs, + ) + + self.q_size = self.num_heads_per_tp * self.head_dim + self.k_size = self.num_key_value_heads_per_tp * self.head_dim + self.v_size = self.num_key_value_heads_per_tp * self.head_dim + + self.o_proj = tensor_parallel.RowParallelLinear( + input_size=self.num_heads * self.head_dim, + output_size=self.hidden_size, + # bias=config.attention_bias, + bias=False, + input_is_parallel=True, + skip_bias_add=False, + **row_kwargs, + ) + + self._init_rope() + + def _init_rope(self): + self.rotary_emb = Qwen2RotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + base=self.rope_theta, + ) + + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + qkv = self.qkv_proj(hidden_states)[0] + query_states, key_states, value_states = qkv.split([self.q_size, self.k_size, self.v_size], dim=-1) + + query_states = query_states.view(bsz, q_len, self.num_heads_per_tp, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads_per_tp, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads_per_tp, self.head_dim).transpose(1, 2) + + kv_seq_len = key_states.shape[-2] + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + + if attn_weights.size() != (bsz, self.num_heads_per_tp, q_len, kv_seq_len): + raise ValueError( + f"Attention weights should be of size {(bsz, self.num_heads_per_tp, q_len, kv_seq_len)}, " + f"but is {attn_weights.size()}" + ) + + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" + ) + attn_weights = attn_weights + attention_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (bsz, self.num_heads_per_tp, q_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads_per_tp, q_len, self.head_dim)}, " + f"but is {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size_per_tp) + attn_output = self.o_proj(attn_output)[0] + return attn_output + + +""" +Remove padding Attention +- Using Flash-attn 2 +- Compatible with sequence parallel +""" + + +def apply_rotary_pos_emb_rmpad(q, k, cos, sin, position_ids, indices, sequence_length): + batch_size = position_ids.shape[0] + + q = pad_input(q, indices, batch_size, sequence_length) # (batch_size, seqlen, num_head, head_dim) + k = pad_input(k, indices, batch_size, sequence_length) + cos = cos[position_ids].unsqueeze(2) # [bs, seq_len, 1, dim] + sin = sin[position_ids].unsqueeze(2) # [bs, seq_len, 1, dim] + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + + q_embed = index_first_axis(rearrange(q_embed, "b s ... -> (b s) ..."), indices) + k_embed = index_first_axis(rearrange(k_embed, "b s ... -> (b s) ..."), indices) + + return q_embed, k_embed + + +# use flash-attn rotary embeddings with rmpad +# cos/sin shoudl be: (seq_length, rotary_dim / 2) +def apply_rotary_pos_emb_rmpad_flash(q, k, cos, sin, cu_seqlens, max_seqlen): + q_embed = apply_rotary_emb( + q, cos, sin, interleaved=False, inplace=False, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen + ) + k_embed = apply_rotary_emb( + k, cos, sin, interleaved=False, inplace=False, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen + ) + return q_embed, k_embed + + +class ParallelQwen2AttentionRmPad(ParallelQwen2Attention): + def forward( + self, + hidden_states: torch.Tensor, + position_ids: Optional[torch.LongTensor] = None, + sequence_length: int = None, + indices: torch.Tensor = None, + cu_seqlens: torch.Tensor = None, + max_seqlen_in_batch: int = None, + ): + total_nnz, _, _ = hidden_states.size() # This is the total_nnz padded after sequence parallel + + if self.megatron_config.sequence_parallel: + total_nnz = total_nnz * mpu.get_tensor_model_parallel_world_size() + + qkv = self.qkv_proj(hidden_states)[0] + query_states, key_states, value_states = qkv.split( + [self.q_size, self.k_size, self.v_size], dim=-1 + ) # (total_nnz, 1, hidden_size) + + if self.megatron_config.sequence_parallel: + sequence_parallel_pad = total_nnz - cu_seqlens[-1] + total_nnz = cu_seqlens[-1] # total_nnz before sp padding + query_states = query_states[:total_nnz] + key_states = key_states[:total_nnz] + value_states = value_states[:total_nnz] + + # Flash attention requires the input to have the shape + # batch_size x seq_length x head_dime x hidden_dim + # therefore we just need to keep the original shape + query_states = query_states.view(total_nnz, self.num_heads_per_tp, self.head_dim) + key_states = key_states.view(total_nnz, self.num_key_value_heads_per_tp, self.head_dim) + value_states = value_states.view(total_nnz, self.num_key_value_heads_per_tp, self.head_dim) + + cos, sin = self.rotary_emb(value_states, seq_len=sequence_length) + cos, sin = cos[:, : cos.shape[1] // 2], sin[:, : sin.shape[1] // 2] # flash attn only needs half + query_states, key_states = apply_rotary_pos_emb_rmpad_flash( + query_states, key_states, cos, sin, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen_in_batch + ) + # query_states, key_states = apply_rotary_pos_emb_rmpad(query_states, key_states, cos, sin, + # position_ids, indices, + + # It is recommended to use dropout with FA according to the docs + # when training. + dropout_rate = 0.0 # if not self.training else self.attn_dropout + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in float16 just to be sure everything works as expected. + # This might slowdown training & inference so it is recommended to not cast the LayerNorms + # in fp32. (Qwen2RMSNorm handles it correctly) + input_dtype = query_states.dtype + if input_dtype == torch.float32: + query_states = query_states.to(torch.float16) + key_states = key_states.to(torch.float16) + value_states = value_states.to(torch.float16) + + attn_output_unpad = flash_attn_varlen_func( + query_states, + key_states, + value_states, + cu_seqlens_q=cu_seqlens, + cu_seqlens_k=cu_seqlens, + max_seqlen_q=max_seqlen_in_batch, + max_seqlen_k=max_seqlen_in_batch, + dropout_p=dropout_rate, + softmax_scale=None, + causal=True, + ) + + attn_output_unpad = attn_output_unpad.to(input_dtype) + attn_output_unpad = attn_output_unpad.reshape(total_nnz, 1, self.hidden_size_per_tp).contiguous() + + # sequence parallel reduce_scatter is performed inside RowColumnParallel if enabled + # Here we need to repad + if self.megatron_config.sequence_parallel: + attn_output_unpad = F.pad(attn_output_unpad, pad=(0, 0, 0, 0, 0, sequence_parallel_pad)) + + attn_output_unpad = self.o_proj(attn_output_unpad)[0] + return attn_output_unpad diff --git a/code/RL_model/verl/verl_train/verl/models/qwen2/megatron/layers/parallel_decoder.py b/code/RL_model/verl/verl_train/verl/models/qwen2/megatron/layers/parallel_decoder.py new file mode 100644 index 0000000000000000000000000000000000000000..3c8a2a6ee946eb014658006a2da6d2d602c51063 --- /dev/null +++ b/code/RL_model/verl/verl_train/verl/models/qwen2/megatron/layers/parallel_decoder.py @@ -0,0 +1,150 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional + +import torch +from megatron.core import ModelParallelConfig +from torch import nn +from transformers import Qwen2Config + +from verl.utils.megatron_utils import TransformerConfig, convert_config + +from .parallel_attention import ParallelQwen2Attention, ParallelQwen2AttentionRmPad +from .parallel_mlp import ParallelQwen2MLP +from .parallel_rmsnorm import ParallelQwen2RMSNorm + + +class ParallelQwen2DecoderLayer(nn.Module): + def __init__(self, config: Qwen2Config, megatron_config: ModelParallelConfig, layer_idx: int): + super().__init__() + self.config: TransformerConfig = convert_config(config, megatron_config) + self.layer_idx = layer_idx + self.hidden_size = config.hidden_size + self.self_attn = ParallelQwen2Attention(config=config, megatron_config=megatron_config) + + self.mlp = ParallelQwen2MLP(config, megatron_config=megatron_config) + self.input_layernorm = ParallelQwen2RMSNorm(config, megatron_config) + self.post_attention_layernorm = ParallelQwen2RMSNorm(config, megatron_config) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + """ + + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Note: sequence parallel is hidden inside ColumnParallelLinear + # reduce scatter is hidden inside RowParallelLinear + + # Self Attention + hidden_states = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + ) + + # TODO: add sequence parallel operator reduce_scatter here + + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + + # TODO: add sequence parallel operator all_gather here + + hidden_states = self.mlp(hidden_states) + + # TODO: add sequence parallel operator reduce_scatter here + + hidden_states = residual + hidden_states + + outputs = hidden_states + + return outputs + + +class ParallelQwen2DecoderLayerRmPad(nn.Module): + def __init__(self, config: Qwen2Config, megatron_config: ModelParallelConfig, layer_idx: int): + super().__init__() + self.config: TransformerConfig = convert_config(config, megatron_config) + self.hidden_size = config.hidden_size + self.layer_idx = layer_idx + self.self_attn = ParallelQwen2AttentionRmPad(config=config, megatron_config=megatron_config) + + self.mlp = ParallelQwen2MLP(config, megatron_config=megatron_config) + self.input_layernorm = ParallelQwen2RMSNorm(config, megatron_config) + self.post_attention_layernorm = ParallelQwen2RMSNorm(config, megatron_config) + + def forward( + self, + hidden_states: torch.Tensor, + position_ids: Optional[torch.LongTensor] = None, + sequence_length: int = None, + indices: torch.Tensor = None, + cu_seqlens: int = None, + max_seqlen_in_batch: int = None, + ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]: + residual = hidden_states # (total_nnz // sp, 1, hidden_size) + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + # (total_nnz // sp, 1, hidden_size) -> all-gather (total_nnz, 1, hidden_size) + # -> col + row -> reduce-scatter -> (total_nnz // sp, 1, hidden_size) + hidden_states = self.self_attn( + hidden_states=hidden_states, + position_ids=position_ids, + sequence_length=sequence_length, + indices=indices, + cu_seqlens=cu_seqlens, + max_seqlen_in_batch=max_seqlen_in_batch, + ) + + hidden_states = residual + hidden_states + + # Fully Connected + # shape changes same as attn + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = hidden_states + + return outputs diff --git a/code/RL_model/verl/verl_train/verl/models/qwen2/megatron/layers/parallel_linear.py b/code/RL_model/verl/verl_train/verl/models/qwen2/megatron/layers/parallel_linear.py new file mode 100644 index 0000000000000000000000000000000000000000..e6d4a09f43013ed75feb03fdb427bc8ad86db093 --- /dev/null +++ b/code/RL_model/verl/verl_train/verl/models/qwen2/megatron/layers/parallel_linear.py @@ -0,0 +1,79 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# Copyright 2023 The vLLM team. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/linear.py + + +from megatron.core import tensor_parallel + + +class QKVParallelLinear(tensor_parallel.ColumnParallelLinear): + def __init__( + self, + input_size, + num_heads, + num_key_value_heads, + head_dim, + *, + bias=True, + gather_output=True, + skip_bias_add=False, + **kwargs, + ): + # Keep input parameters, and already restrict the head numbers + self.input_size = input_size + self.q_output_size = num_heads * head_dim + self.kv_output_size = num_key_value_heads * head_dim + self.head_dim = head_dim + self.gather_output = gather_output + self.skip_bias_add = skip_bias_add + + input_size = self.input_size + output_size = (num_heads + 2 * num_key_value_heads) * self.head_dim + + super().__init__( + input_size=input_size, + output_size=output_size, + bias=bias, + gather_output=gather_output, + skip_bias_add=skip_bias_add, + **kwargs, + ) + + +class MergedColumnParallelLinear(tensor_parallel.ColumnParallelLinear): + def __init__( + self, + input_size, + gate_ouput_size, + up_output_size, + *, + bias=True, + gather_output=True, + skip_bias_add=False, + **kwargs, + ): + # Keep input parameters, and already restrict the head numbers + self.input_size = input_size + self.output_size = gate_ouput_size + up_output_size + self.gather_output = gather_output + self.skip_bias_add = skip_bias_add + + super().__init__( + input_size=self.input_size, + output_size=self.output_size, + bias=bias, + gather_output=gather_output, + skip_bias_add=skip_bias_add, + **kwargs, + ) diff --git a/code/RL_model/verl/verl_train/verl/models/qwen2/megatron/layers/parallel_mlp.py b/code/RL_model/verl/verl_train/verl/models/qwen2/megatron/layers/parallel_mlp.py new file mode 100644 index 0000000000000000000000000000000000000000..672908a21ae8c8e69c0536eda7fadd0431cba5fe --- /dev/null +++ b/code/RL_model/verl/verl_train/verl/models/qwen2/megatron/layers/parallel_mlp.py @@ -0,0 +1,74 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from megatron.core import ModelParallelConfig, tensor_parallel +from megatron.core import parallel_state as mpu +from torch import nn +from transformers.activations import ACT2FN + +from verl.models.qwen2.megatron.layers.parallel_linear import MergedColumnParallelLinear +from verl.utils.megatron import tensor_parallel as tp_utils + + +class ParallelQwen2MLP(nn.Module): + def __init__(self, config, megatron_config: ModelParallelConfig = None) -> None: + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + # The weight is only [hidden_size, intermediate_size // model_parallel_world_size] + + column_kwargs = tp_utils.get_default_kwargs_for_column_parallel_linear() + row_kwargs = tp_utils.get_default_kwargs_for_row_parallel_linear() + + if megatron_config is not None: + assert column_kwargs.get("config", False), "must have ModelParallelConfig" + assert row_kwargs.get("config", False), "must have ModelParallelConfig" + tp_utils.update_kwargs_with_config(row_kwargs, megatron_config) + tp_utils.update_kwargs_with_config(column_kwargs, megatron_config) + + tp_size = mpu.get_tensor_model_parallel_world_size() + + self.gate_up_proj = MergedColumnParallelLinear( + input_size=self.hidden_size, + gate_ouput_size=self.intermediate_size, + up_output_size=self.intermediate_size, + bias=False, + gather_output=False, + skip_bias_add=False, + **column_kwargs, + ) + self.gate_size = self.intermediate_size // tp_size + + self.down_proj = tensor_parallel.RowParallelLinear( + input_size=self.intermediate_size, + output_size=self.hidden_size, + bias=False, + input_is_parallel=True, + skip_bias_add=False, + **row_kwargs, + ) + + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, x): + gate_up = self.gate_up_proj(x)[0] + gate, up = gate_up.split(self.gate_size, dim=-1) + return self.down_proj(self.act_fn(gate) * up)[0] diff --git a/code/RL_model/verl/verl_train/verl/models/qwen2/megatron/layers/parallel_rmsnorm.py b/code/RL_model/verl/verl_train/verl/models/qwen2/megatron/layers/parallel_rmsnorm.py new file mode 100644 index 0000000000000000000000000000000000000000..2f4c90dd44e2b72f1116e3c097e52efca5567129 --- /dev/null +++ b/code/RL_model/verl/verl_train/verl/models/qwen2/megatron/layers/parallel_rmsnorm.py @@ -0,0 +1,48 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numbers + +import torch +from apex.normalization.fused_layer_norm import fused_rms_norm_affine +from megatron.core import ModelParallelConfig +from torch import nn +from transformers import Qwen2Config + +from verl.utils.megatron import sequence_parallel as sp_utils + + +class ParallelQwen2RMSNorm(nn.Module): + def __init__(self, config: Qwen2Config, megatron_config: ModelParallelConfig): + """ + Qwen2RMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + if isinstance(config.hidden_size, numbers.Integral): + normalized_shape = (config.hidden_size,) + self.normalized_shape = torch.Size(normalized_shape) + self.weight = nn.Parameter(torch.ones(self.normalized_shape)) + self.variance_epsilon = config.rms_norm_eps + + if megatron_config.sequence_parallel: + sp_utils.mark_parameter_as_sequence_parallel(self.weight) + + def forward(self, hidden_states): + return fused_rms_norm_affine( + input=hidden_states, + weight=self.weight, + normalized_shape=self.normalized_shape, + eps=self.variance_epsilon, + memory_efficient=True, + ) diff --git a/code/RL_model/verl/verl_train/verl/models/qwen2/megatron/modeling_qwen2_megatron.py b/code/RL_model/verl/verl_train/verl/models/qwen2/megatron/modeling_qwen2_megatron.py new file mode 100644 index 0000000000000000000000000000000000000000..b3512f8afa5dc6bf2b786e753acc22cac8d75784 --- /dev/null +++ b/code/RL_model/verl/verl_train/verl/models/qwen2/megatron/modeling_qwen2_megatron.py @@ -0,0 +1,737 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch Qwen2 model.""" + +from typing import Optional + +import torch +import torch.utils.checkpoint +from megatron.core import ModelParallelConfig, mpu, parallel_state, tensor_parallel +from torch import nn +from transformers.modeling_outputs import BaseModelOutputWithPast +from transformers.models.qwen2.configuration_qwen2 import Qwen2Config +from transformers.models.qwen2.modeling_qwen2 import CausalLMOutputWithPast + +from verl.utils.device import get_device_name +from verl.utils.megatron import sequence_parallel as sp_utils +from verl.utils.megatron import tensor_parallel as tp_utils +from verl.utils.megatron_utils import TransformerConfig, convert_config + +from .layers import ParallelQwen2DecoderLayer, ParallelQwen2DecoderLayerRmPad, ParallelQwen2RMSNorm + +""" +TODO: +1. Add weight initialization. Here we need to be careful on TP weight init. +2. Add sequence parallel +3. Load checkpoint from Qwen2 pretrained checkpoint +""" + + +# Copied from transformers.models.bart.modeling_bart._make_causal_mask +def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device): + """ + Make causal mask used for bi-directional self-attention. + """ + bsz, tgt_len = input_ids_shape + mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device) + mask_cond = torch.arange(mask.size(-1), device=device) + mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) + mask = mask.to(dtype) + return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len) + + +# Copied from transformers.models.bart.modeling_bart._expand_mask +def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): + """ + Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. + """ + bsz, src_len = mask.size() + tgt_len = tgt_len if tgt_len is not None else src_len + + expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) + + inverted_mask = 1.0 - expanded_mask + + return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) + + +class ParallelQwen2Model(nn.Module): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`Qwen2DecoderLayer`] + + Args: + config: Qwen2Config + """ + + def __init__(self, config: Qwen2Config, megatron_config: ModelParallelConfig): + super().__init__() + self.config: TransformerConfig = convert_config(config, megatron_config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + embedding_kwargs = tp_utils.get_default_kwargs_for_parallel_embedding() + if megatron_config is not None: + assert embedding_kwargs.get("config", False), "must have ModelParallelConfig" + tp_utils.update_kwargs_with_config(embedding_kwargs, megatron_config) + self.embed_tokens = tensor_parallel.VocabParallelEmbedding( + num_embeddings=config.vocab_size, embedding_dim=config.hidden_size, **embedding_kwargs + ) + + self.layers = nn.ModuleList( + [ParallelQwen2DecoderLayer(config, megatron_config) for _ in range(config.num_hidden_layers)] + ) + self.norm = ParallelQwen2RMSNorm(config, megatron_config) + + # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask + def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds): + # create causal mask + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + combined_attention_mask = None + if input_shape[-1] > 1: + combined_attention_mask = _make_causal_mask( + input_shape, + inputs_embeds.dtype, + device=inputs_embeds.device, + ) + + if attention_mask is not None: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to( + inputs_embeds.device + ) + combined_attention_mask = ( + expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask + ) + + return combined_attention_mask + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + ) -> tuple | BaseModelOutputWithPast: + """ + + Args: + input_ids: input ids. shape (batch_size, seq_length) + attention_mask: attention_mask. shape (batch_size, seq_length) + position_ids: position ids. shape (batch_size, seq_length) + + Returns: + + """ + batch_size, seq_length = input_ids.shape + inputs_embeds = self.embed_tokens(input_ids) + # embed positions + + attention_mask = self._prepare_decoder_attention_mask(attention_mask, (batch_size, seq_length), inputs_embeds) + + hidden_states = inputs_embeds + + for idx, decoder_layer in enumerate(self.layers): + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + ) + + hidden_states = layer_outputs + + hidden_states = self.norm(hidden_states) + + return hidden_states + + +class ParallelQwen2ForCausalLM(nn.Module): + def __init__(self, config: Qwen2Config, megatron_config: ModelParallelConfig): + super().__init__() + self.config: TransformerConfig = convert_config(config, megatron_config) + self.model = ParallelQwen2Model(config, megatron_config=megatron_config) + self.vocab_size = config.vocab_size + + column_kwargs = tp_utils.get_default_kwargs_for_column_parallel_linear() + if megatron_config is not None: + assert column_kwargs.get("config", False), "must have ModelParallelConfig" + tp_utils.update_kwargs_with_config(column_kwargs, self.megatron_config) + + self.lm_head = tensor_parallel.ColumnParallelLinear( + input_size=config.hidden_size, + output_size=config.vocab_size, + bias=False, + gather_output=False, + skip_bias_add=False, + **column_kwargs, + ) + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + ) -> tuple | CausalLMOutputWithPast: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Returns: + ```""" + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + ) + + hidden_states = outputs + logits = self.lm_head(hidden_states)[0] + + logits = tensor_parallel.gather_from_tensor_model_parallel_region(logits) + + logits = logits.float() + return CausalLMOutputWithPast( + loss=None, + logits=logits, + past_key_values=None, + hidden_states=None, + attentions=None, + ) + + +from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa: F401, E402 + + +class ParallelQwen2ModelRmPad(nn.Module): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`Qwen2DecoderLayer`] + + Args: + config: Qwen2Config + """ + + def __init__(self, config: Qwen2Config, megatron_config: ModelParallelConfig): + super().__init__() + self.config: TransformerConfig = convert_config(config, megatron_config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + embedding_kwargs = tp_utils.get_default_kwargs_for_parallel_embedding() + self.megatron_config = megatron_config + if megatron_config is not None: + assert embedding_kwargs.get("config", False), "must have ModelParallelConfig" + tp_utils.update_kwargs_with_config(embedding_kwargs, self.megatron_config) + self.embed_tokens = tensor_parallel.VocabParallelEmbedding( + num_embeddings=config.vocab_size, embedding_dim=config.hidden_size, **embedding_kwargs + ) + + self.layers = nn.ModuleList( + [ParallelQwen2DecoderLayerRmPad(config, megatron_config) for _ in range(config.num_hidden_layers)] + ) + self.norm = ParallelQwen2RMSNorm(config, megatron_config) + + def forward( + self, + input_ids: torch.Tensor, + position_ids: Optional[torch.LongTensor] = None, + sequence_length: int = None, + indices: torch.Tensor = None, + cu_seqlens: int = None, + max_seqlen_in_batch: int = None, + ) -> tuple | BaseModelOutputWithPast: + """ + + Args: + input_ids: input ids. shape (1, totol_nnz) + position_ids: position ids. shape (batch_size, seq_length) + + Returns: + + """ + inputs_embeds = self.embed_tokens(input_ids) # (1, total_nnz) -> (1, total_nnz, hidden_size) + + # (1, total_nnz, hidden_size) -> (total_nnz, 1, hidden_size) -> (total_nnz // sp, 1, hidden_size) + inputs_embeds = inputs_embeds.transpose(0, 1) + if self.megatron_config.sequence_parallel: + inputs_embeds = tensor_parallel.scatter_to_sequence_parallel_region(inputs_embeds) + + hidden_states = inputs_embeds + for idx, decoder_layer in enumerate(self.layers): + layer_outputs = decoder_layer( + hidden_states, + position_ids=position_ids, + sequence_length=sequence_length, + indices=indices, + cu_seqlens=cu_seqlens, + max_seqlen_in_batch=max_seqlen_in_batch, + ) + + hidden_states = layer_outputs + + hidden_states = self.norm(hidden_states) + + return hidden_states + + +class ParallelQwen2ForCausalLMRmPad(nn.Module): + def __init__(self, config: Qwen2Config, megatron_config: ModelParallelConfig): + super().__init__() + self.config: TransformerConfig = convert_config(config, megatron_config) + self.megatron_config = megatron_config + self.model = ParallelQwen2ModelRmPad(config, megatron_config=megatron_config) + self.vocab_size = config.vocab_size + self._init_head(config) + + def _init_head(self, config: Qwen2Config): + column_kwargs = tp_utils.get_default_kwargs_for_column_parallel_linear() + if self.megatron_config is not None: + assert column_kwargs.get("config", False), "must have ModelParallelConfig" + tp_utils.update_kwargs_with_config(column_kwargs, self.megatron_config) + self.lm_head = tensor_parallel.ColumnParallelLinear( + input_size=config.hidden_size, + output_size=config.vocab_size, + bias=False, + gather_output=False, + skip_bias_add=False, + **column_kwargs, + ) + + def _forward_head(self, hidden_states): + # all_gather from sequence parallel region is performed inside lm_head + logits = self.lm_head(hidden_states)[0] + logits = logits.float() # (total_nnz_padded, 1, vocab_size // tp) + logits = tensor_parallel.gather_from_tensor_model_parallel_region(logits) # (total_nnz_padded, 1, vocab_size) + return logits + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + ) -> tuple | CausalLMOutputWithPast: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Returns: + ```""" + batch_size, sequence_length = input_ids.shape + + # remove padding here + input_ids, indices, cu_seqlens, max_seqlen_in_batch, *_ = unpad_input( + input_ids.unsqueeze(dim=-1), attention_mask + ) # (total_nnz, 1) + + # pad input_ids to multiple of tp for all tp ranks + # TODO: for better performance, the sp padding should be removed at each layer. Not sure the performance gap + if self.megatron_config.sequence_parallel: + input_ids = sp_utils.pad_to_sequence_parallel(input_ids) + + input_ids = input_ids.transpose(0, 1) # (1, total_nnz+pad) + + outputs = self.model( + input_ids=input_ids, + position_ids=position_ids, + sequence_length=sequence_length, + indices=indices, + cu_seqlens=cu_seqlens, + max_seqlen_in_batch=max_seqlen_in_batch, + ) + + hidden_states = outputs + + logits = self._forward_head(hidden_states) + + # remove padding from sequence parallel + if self.megatron_config.sequence_parallel: + totol_nnz = cu_seqlens[-1] + logits = logits[:totol_nnz] # (total_nnz_padded) + + logits = torch.squeeze(logits, dim=1) # remove the artificial batch dimension + # add removed padding back + logits = pad_input( + logits, indices, batch_size, seqlen=sequence_length + ) # (batch_size, sequence_length, vocab_size) + + return CausalLMOutputWithPast( + loss=None, + logits=logits, + past_key_values=None, + hidden_states=None, + attentions=None, + ) + + +class ParallelQwen2ForValueRmPad(ParallelQwen2ForCausalLMRmPad): + def _init_head(self, config): + column_kwargs = tp_utils.get_default_kwargs_for_column_parallel_linear() + if self.megatron_config is not None: + assert column_kwargs.get("config", False), "must have ModelParallelConfig" + tp_utils.update_kwargs_with_config(column_kwargs, self.megatron_config) + self.lm_head = nn.Linear(in_features=config.hidden_size, out_features=1, bias=False) + # lm_head is effectively the same as sequence parallel + sp_utils.mark_parameter_as_sequence_parallel(self.lm_head.weight) + + def _forward_head(self, hidden_states): + logits = self.lm_head(hidden_states) # (total_nnz_padded // tp, 1, 1) + logits = logits.float() + if self.megatron_config.sequence_parallel: + logits = tensor_parallel.gather_from_sequence_parallel_region(logits, tensor_parallel_output_grad=False) + return logits + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + ) -> tuple | CausalLMOutputWithPast: + output = super().forward(input_ids, attention_mask, position_ids) + output.logits = torch.squeeze(output.logits, dim=-1) + return output + + +""" +Support pipeline parallelism +""" + + +class ParallelQwen2ModelRmPadPP(nn.Module): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`Qwen2DecoderLayer`] + This model definition supports pipeline parallelism. To support pp and vpp, + - This model only contains layer in this pp stage and vpp chunk + - When calling get_model in Megatron, this rank will instantiate all the vpp chunks in this pp. + Args: + config: Qwen2Config + """ + + def __init__(self, config: Qwen2Config, megatron_config: ModelParallelConfig, pre_process, post_process): + super().__init__() + self.config: TransformerConfig = convert_config(config, megatron_config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + self.pre_process = pre_process + self.post_process = post_process + self.megatron_config = megatron_config + embedding_kwargs = tp_utils.get_default_kwargs_for_parallel_embedding() + if megatron_config is not None: + assert embedding_kwargs.get("config", False), "must have ModelParallelConfig" + tp_utils.update_kwargs_with_config(embedding_kwargs, self.megatron_config) + if pre_process: + self.embed_tokens = tensor_parallel.VocabParallelEmbedding( + num_embeddings=config.vocab_size, embedding_dim=config.hidden_size, **embedding_kwargs + ) + else: + self.embed_tokens = None + + pp_rank = mpu.get_pipeline_model_parallel_rank() + pp_size = megatron_config.pipeline_model_parallel_size + self.num_layer_per_pp = config.num_hidden_layers // pp_size + vpp_size = megatron_config.virtual_pipeline_model_parallel_size + vpp_rank = mpu.get_virtual_pipeline_model_parallel_rank() + + if vpp_size is not None: + self.num_layer_vpp_chunk = self.num_layer_per_pp // vpp_size + self.num_layer_this_model = self.num_layer_vpp_chunk + offset = vpp_rank * (config.num_hidden_layers // vpp_size) + (pp_rank * self.num_layer_vpp_chunk) + else: + self.num_layer_this_model = self.num_layer_per_pp + offset = pp_rank * self.num_layer_per_pp + + self.layers = nn.ModuleList() + for i in range(self.num_layer_this_model): + layer = ParallelQwen2DecoderLayerRmPad(config, megatron_config, layer_idx=i + offset) + self.layers.add_module(f"{i}", layer) + + if post_process: + self.norm = ParallelQwen2RMSNorm(config, megatron_config) + else: + self.norm = None + + def set_input_tensor(self, input_tensor): + """Set input tensor to be used instead of forward()'s input. + + When doing pipeline parallelism the input from the previous + stage comes from communication, not from the input, so the + model's forward_step_func won't have it. This function is thus + used by internal code to bypass the input provided by the + forward_step_func""" + self.input_tensor = input_tensor + + def forward( + self, + input_ids: torch.Tensor, + position_ids: Optional[torch.LongTensor] = None, + sequence_length: int = None, + indices: torch.Tensor = None, + cu_seqlens: int = None, + max_seqlen_in_batch: int = None, + ) -> tuple | BaseModelOutputWithPast: + """ + + Args: + input_ids: input ids. shape (1, totol_nnz) + position_ids: position ids. shape (batch_size, seq_length) + + Returns: + + """ + if self.pre_process: + inputs_embeds = self.embed_tokens(input_ids) # (1, total_nnz) -> (1, total_nnz, hidden_size) + + # vocab parallel embedding will not do sequence parallel reduce-scatter in open source megatron + # so need to deal with it by handle here: + # (1, total_nnz, hidden_size) -> (total_nnz, 1, hidden_size) -> (total_nnz // sp, 1, hidden_size) + inputs_embeds = inputs_embeds.transpose(0, 1) + if self.megatron_config.sequence_parallel: + inputs_embeds = tensor_parallel.scatter_to_sequence_parallel_region(inputs_embeds) + + hidden_states = inputs_embeds + else: + # self.hidden_states should be passed by Megatron + hidden_states = self.input_tensor + + for idx, decoder_layer in enumerate(self.layers): + layer_outputs = decoder_layer( + hidden_states, + position_ids=position_ids, + sequence_length=sequence_length, + indices=indices, + cu_seqlens=cu_seqlens, + max_seqlen_in_batch=max_seqlen_in_batch, + ) + + hidden_states = layer_outputs + + if self.post_process: + hidden_states = self.norm(hidden_states) + + return hidden_states + + +class ParallelQwen2ForCausalLMRmPadPP(nn.Module): + def __init__( + self, + config: Qwen2Config, + megatron_config: ModelParallelConfig, + pre_process, + post_process, + share_embeddings_and_output_weights, + ): + super().__init__() + self.config: TransformerConfig = convert_config(config, megatron_config) + self.megatron_config = megatron_config + self.model = ParallelQwen2ModelRmPadPP( + config, megatron_config=megatron_config, pre_process=pre_process, post_process=post_process + ) + self.share_embeddings_and_output_weights = share_embeddings_and_output_weights + self.vocab_size = config.vocab_size + self.pre_process = pre_process + self.post_process = post_process + if post_process: + self._init_head(config) + if pre_process or post_process: + self.setup_embeddings_and_output_layer() + + def set_input_tensor(self, input_tensor): + """Set input tensor to be used instead of forward()'s input. + + When doing pipeline parallelism the input from the previous + stage comes from communication, not from the input, so the + model's forward_step_func won't have it. This function is thus + used by internal code to bypass the input provided by the + forward_step_func""" + assert len(input_tensor) == 1 + self.model.set_input_tensor(input_tensor[0]) + + def _init_head(self, config): + column_kwargs = tp_utils.get_default_kwargs_for_column_parallel_linear() + if self.megatron_config is not None: + assert column_kwargs.get("config", False), "must have ModelParallelConfig" + tp_utils.update_kwargs_with_config(column_kwargs, self.megatron_config) + self.lm_head = tensor_parallel.ColumnParallelLinear( + input_size=config.hidden_size, + output_size=config.vocab_size, + bias=False, + gather_output=False, + skip_bias_add=False, + skip_weight_param_allocation=self.pre_process and self.share_embeddings_and_output_weights, + **column_kwargs, + ) + + def setup_embeddings_and_output_layer(self) -> None: + """Sets up embedding layer in first stage and output layer in last stage. + + This function initializes word embeddings in the final stage when we are + using pipeline parallelism and sharing word embeddings, and sets up param + attributes on the embedding and output layers. + """ + # Set `is_embedding_or_output_parameter` attribute. + if self.pre_process: + self.model.embed_tokens.weight.is_embedding_or_output_parameter = True + if self.post_process and self.lm_head.weight is not None: + self.lm_head.weight.is_embedding_or_output_parameter = True + + if not self.share_embeddings_and_output_weights: + return + + if parallel_state.get_pipeline_model_parallel_world_size() == 1: + # Zero out wgrad if sharing embeddings between two layers on same + # pipeline stage to make sure grad accumulation into main_grad is + # correct and does not include garbage values (e.g., from torch.empty). + self.shared_embedding_or_output_weight().zero_out_wgrad = True + return + + if parallel_state.is_pipeline_first_stage() and self.pre_process and not self.post_process: + self.shared_embedding_or_output_weight().shared_embedding = True + + if self.post_process and not self.pre_process: + assert not parallel_state.is_pipeline_first_stage() + # set word_embeddings weights to 0 here, then copy first + # stage's weights using all_reduce below. + self.lm_head.weight.data.fill_(0) + self.lm_head.weight.shared = True + self.lm_head.weight.shared_embedding = True + + if torch.distributed.is_initialized() and parallel_state.is_rank_in_embedding_group(): + weight = self.shared_embedding_or_output_weight() + weight.data = weight.data.to(get_device_name()) + torch.distributed.all_reduce(weight.data, group=parallel_state.get_embedding_group()) + + def shared_embedding_or_output_weight(self) -> torch.Tensor: + if self.pre_process: + return self.model.embed_tokens.weight + elif self.post_process: + return self.lm_head.weight + return None + + def _forward_head(self, hidden_states): + # all_gather from sequence parallel region is performed inside lm_head + # print(f'logits shape before forward_head: {hidden_states.shape}, vocab_size = ' + # f'{self.config.vocab_size}') # [4, 32, 4096] + output_weight = None + if self.share_embeddings_and_output_weights: + output_weight = self.shared_embedding_or_output_weight() + logits = self.lm_head(hidden_states, weight=output_weight)[0] + # print(f'logits shape after forward_head: {logits.shape}') # [8, 32, 8] + logits = logits.float() # (total_nnz_padded, 1, vocab_size // tp) + return logits + + def forward( + self, + # original input + *, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + ) -> tuple | CausalLMOutputWithPast: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Returns: + ```""" + + # Note that input_ids, attention_mask and position_ids should be passed to every pp layer. + # In the first pp, input_ids will be used, in other pp layers hidden_states will be used inside self.model + batch_size, sequence_length = input_ids.shape + # remove padding here + input_ids_rmpad, indices, cu_seqlens, max_seqlen_in_batch, *_ = unpad_input( + input_ids.unsqueeze(dim=-1), attention_mask + ) # (total_nnz, 1) + + # pad input_ids to multiple of tp for all tp ranks + # TODO: for better performance, the sp padding should be removed at each layer. Not sure the performance gap + if self.megatron_config.sequence_parallel: + input_ids_rmpad = sp_utils.pad_to_sequence_parallel(input_ids_rmpad) + + input_ids_rmpad = input_ids_rmpad.transpose(0, 1) # (1, total_nnz+pad) + + outputs = self.model( + input_ids=input_ids_rmpad, + position_ids=position_ids, + sequence_length=sequence_length, + indices=indices, + cu_seqlens=cu_seqlens, + max_seqlen_in_batch=max_seqlen_in_batch, + ) + + if self.post_process: + hidden_states = outputs + logits = self._forward_head(hidden_states) + logits = torch.squeeze(logits, dim=1) # remove the artificial batch dimension # torch.Size([8, 32, 16]) + + # remove padding from sequence parallel + if self.megatron_config.sequence_parallel: + totol_nnz = cu_seqlens[-1] + logits = logits[:totol_nnz] # (total_nnz_padded) + # add removed padding back. If input is already rmpad, we let the caller pad_input + logits = pad_input( + logits, indices, batch_size, seqlen=sequence_length + ) # (batch_size, sequence_length, vocab_size) + + return CausalLMOutputWithPast( + loss=None, + logits=logits, + past_key_values=None, + hidden_states=None, + attentions=None, + ) + else: + return outputs + + +class ParallelQwen2ForValueRmPadPP(ParallelQwen2ForCausalLMRmPadPP): + def _init_head(self, config): + column_kwargs = tp_utils.get_default_kwargs_for_column_parallel_linear() + if self.megatron_config is not None: + assert column_kwargs.get("config", False), "must have ModelParallelConfig" + tp_utils.update_kwargs_with_config(column_kwargs, self.megatron_config) + self.lm_head = nn.Linear(in_features=config.hidden_size, out_features=1, bias=False) + # lm_head is effectively the same as sequence parallel + sp_utils.mark_parameter_as_sequence_parallel(self.lm_head.weight) + + def _forward_head(self, hidden_states): + logits = self.lm_head(hidden_states) # (total_nnz_padded // tp, 1, 1) + logits = logits.float() + if self.megatron_config.sequence_parallel: + logits = tensor_parallel.gather_from_sequence_parallel_region(logits, tensor_parallel_output_grad=False) + return logits + + def forward( + self, + *, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + ) -> tuple | CausalLMOutputWithPast: + output = super().forward(input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids) + if self.post_process: + output.logits = torch.squeeze(output.logits, dim=-1) + return output + else: + return output diff --git a/code/RL_model/verl/verl_train/verl/models/transformers/__init__.py b/code/RL_model/verl/verl_train/verl/models/transformers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d992168f109c6a97d6749b1ab39c915b48330e19 --- /dev/null +++ b/code/RL_model/verl/verl_train/verl/models/transformers/__init__.py @@ -0,0 +1,21 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from verl.models.transformers.monkey_patch import apply_monkey_patch +from verl.models.transformers.tiled_mlp import apply_tiled_mlp_monkey_patch + +__all__ = [ + "apply_monkey_patch", + "apply_tiled_mlp_monkey_patch", +] diff --git a/code/RL_model/verl/verl_train/verl/models/transformers/apertus.py b/code/RL_model/verl/verl_train/verl/models/transformers/apertus.py new file mode 100644 index 0000000000000000000000000000000000000000..a42f50957b62e3ae3800b8aadf54793a2c97f2fc --- /dev/null +++ b/code/RL_model/verl/verl_train/verl/models/transformers/apertus.py @@ -0,0 +1,118 @@ +# Copyright 2025 The SwissAI Initiative +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import sys +from typing import Callable, Optional + +import torch + +if sys.version_info >= (3, 11): + pass +else: + pass + +from transformers.cache_utils import Cache +from transformers.models.apertus.modeling_apertus import apply_rotary_pos_emb +from transformers.utils import logging + +# Import compatibility wrapper for flash_attn_supports_top_left_mask +from verl.utils.ulysses import ( + gather_heads_scatter_seq, + gather_seq_scatter_heads, + get_ulysses_sequence_parallel_world_size, + validate_ulysses_config, +) + +logger = logging.get_logger(__name__) + + +def apertus_attn_forward( + self, + hidden_states: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], + past_key_value: Optional[Cache] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, +) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: + """ + Adapted from transformers 4.49.0 to support Ulysses sequence parallelism for transformers >= 4.48.0. + + Key differences from Llama attention: + - QK normalization applied after Q/K projections + + NOTE: This function has been tested only on transformers versions between 4.48.0 and 4.50.0. + """ + from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS + from transformers.models.apertus.modeling_apertus import eager_attention_forward + + bsz, q_len, _ = hidden_states.shape + + query_states = self.q_proj(hidden_states).view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + key_states = self.k_proj(hidden_states).view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + + query_states = self.q_norm(query_states) + key_states = self.k_norm(key_states) + + ########## AlltoAll for Ulysses ########## + ulysses_sp_size = get_ulysses_sequence_parallel_world_size() + + if ulysses_sp_size > 1: + validate_ulysses_config(self.config.num_attention_heads, ulysses_sp_size) + + query_states = gather_seq_scatter_heads(query_states, seq_dim=2, head_dim=1) + key_states = gather_seq_scatter_heads(key_states, seq_dim=2, head_dim=1) + value_states = gather_seq_scatter_heads(value_states, seq_dim=2, head_dim=1) + + full_q_len = query_states.size(2) + + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): + logger.warning_once( + "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. " + "Falling back to eager attention. This warning can be removed using the argument " + '`attn_implementation="eager"` when loading the model.' + ) + else: + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + **kwargs, + ) + + attn_output = attn_output.reshape(bsz, full_q_len, -1, self.head_dim).contiguous() + ########## AlltoAll for Ulysses ########## + if ulysses_sp_size > 1: + attn_output = gather_heads_scatter_seq(attn_output, seq_dim=1, head_dim=2) + attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights diff --git a/code/RL_model/verl/verl_train/verl/models/transformers/dense_common.py b/code/RL_model/verl/verl_train/verl/models/transformers/dense_common.py new file mode 100644 index 0000000000000000000000000000000000000000..56fe293f5cbec4f9efa2a6a77a3374d09e358e56 --- /dev/null +++ b/code/RL_model/verl/verl_train/verl/models/transformers/dense_common.py @@ -0,0 +1,193 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +from typing import Optional, Union + +import torch +from transformers.cache_utils import Cache +from transformers.modeling_outputs import CausalLMOutputWithPast + + +@dataclass +class CausalLMOutputForPPO(CausalLMOutputWithPast): + log_probs: Optional[torch.FloatTensor] = None + entropy: Optional[torch.FloatTensor] = None + + +def forward_base_model( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, +) -> CausalLMOutputWithPast: + r""" + Copy paste LLaMa's forward + https://github.com/linkedin/Liger-Kernel/blob/main/src/liger_kernel/transformers/model/llama.py + + This function should be generic enough for all pure text models. + ```""" + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + ) + + return outputs + + +def forward_with_torch_backend( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union["Cache", list[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: int | torch.Tensor = 0, + temperature: float = 1.0, + **loss_kwargs, +) -> tuple | CausalLMOutputForPPO: + from verl.utils.experimental.torch_functional import FusedLinearForPPO + + outputs = forward_base_model( + self, + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + cache_position=cache_position, + ) + + hidden_states = outputs[0] + + if not return_dict: + raise NotImplementedError("forward_with_torch_backend has to return_dict") + + # Loss calculations + if labels is not None: + rolled_labels = torch.roll(labels, shifts=-1, dims=-1) + elif input_ids is not None: + rolled_labels = torch.roll(input_ids, shifts=-1, dims=-1) + else: + raise RuntimeError("To use forward_with_torch_backend, either labels or input_ids must be provided.") + + fused_linear_for_ppo = FusedLinearForPPO() + log_probs, entropy = fused_linear_for_ppo.forward( + hidden_states=hidden_states, + vocab_weights=self.lm_head.weight, + input_ids=rolled_labels, + temperature=temperature, + ) + + return CausalLMOutputForPPO( + log_probs=log_probs, + entropy=entropy, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +def forward_with_triton_backend( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union["Cache", list[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: int | torch.Tensor = 0, + temperature: float = 1.0, + **loss_kwargs, +) -> tuple | CausalLMOutputForPPO: + from verl.utils.kernel.linear_cross_entropy import linear_cross_entropy + + outputs = forward_base_model( + self, + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + ) + + hidden_states = outputs[0] + + if not return_dict: + raise NotImplementedError("forward_with_triton_backend has to return_dict") + + # Loss calculations + if labels is not None: + rolled_labels = torch.roll(labels, shifts=-1, dims=-1) + elif input_ids is not None: + rolled_labels = torch.roll(input_ids, shifts=-1, dims=-1) + else: + raise RuntimeError("To use forward_with_triton_backend, either labels or input_ids must be provided.") + + log_probs, entropy = linear_cross_entropy( + hidden_states, + self.lm_head.weight, + rolled_labels, + temperature, + "none", + ) + + return CausalLMOutputForPPO( + log_probs=log_probs, + entropy=entropy, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) diff --git a/code/RL_model/verl/verl_train/verl/models/transformers/glm4v.py b/code/RL_model/verl/verl_train/verl/models/transformers/glm4v.py new file mode 100644 index 0000000000000000000000000000000000000000..b2efe369a262155c62bca1d3bb026d101f2a46dc --- /dev/null +++ b/code/RL_model/verl/verl_train/verl/models/transformers/glm4v.py @@ -0,0 +1,533 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +import itertools +import logging +import os +from dataclasses import dataclass +from typing import Optional + +import torch +import torch.distributed as dist +from transformers.modeling_flash_attention_utils import _flash_attention_forward, fa_peft_integration_check +from transformers.models.glm4v.modeling_glm4v import ( + Glm4vCausalLMOutputWithPast, + Glm4vForConditionalGeneration, + Glm4vTextAttention, +) +from transformers.utils import is_flash_attn_2_available, is_flash_attn_greater_or_equal_2_10 + +from verl.utils.device import is_npu_available +from verl.utils.ulysses import ( + gather_heads_scatter_seq, + gather_seq_scatter_heads, + get_ulysses_sequence_parallel_group, + get_ulysses_sequence_parallel_world_size, + validate_ulysses_config, +) + +logger = logging.getLogger(__file__) +logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN")) + + +if is_flash_attn_2_available(): + from flash_attn import flash_attn_func, flash_attn_varlen_func + + _flash_supports_window_size = "window_size" in inspect.signature(flash_attn_func).parameters + _flash_supports_deterministic = "deterministic" in inspect.signature(flash_attn_func).parameters + _flash_use_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + +if is_npu_available: + from transformers.integrations.npu_flash_attention import npu_flash_attn_func as flash_attn_func + from transformers.integrations.npu_flash_attention import npu_flash_attn_varlen_func as flash_attn_varlen_func + from transformers.modeling_flash_attention_utils import flash_attn_supports_top_left_mask + + _flash_supports_window_size = "window_size" in inspect.signature(flash_attn_func).parameters + _flash_supports_deterministic = "deterministic" in inspect.signature(flash_attn_func).parameters + _flash_use_top_left_mask = flash_attn_supports_top_left_mask() + +_flash_deterministic_enabled = os.getenv("FLASH_ATTENTION_DETERMINISTIC", "0") == "1" + + +def get_rope_index( + processor, + input_ids: torch.Tensor, + image_grid_thw: Optional[torch.LongTensor] = None, + video_grid_thw: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, +) -> torch.Tensor: + """ + Gets the position ids for GLM4V in padding-free format. + The batch dim has been removed and the input_ids should be a 1D tensor representing a single example. + """ + spatial_merge_size = processor.image_processor.merge_size + image_token_id = processor.tokenizer.convert_tokens_to_ids("<|image|>") + video_start_token_id = processor.tokenizer.convert_tokens_to_ids("<|begin_of_video|>") + video_end_token_id = processor.tokenizer.convert_tokens_to_ids("<|end_of_video|>") + + if input_ids is not None and (image_grid_thw is not None or video_grid_thw is not None): + if attention_mask is None: + attention_mask = torch.ones_like(input_ids) + + position_ids = torch.ones(3, input_ids.size(0), dtype=input_ids.dtype, device=input_ids.device) # (3, seqlen) + image_index, video_index = 0, 0 + video_group_index = 0 + + input_ids_filtered = input_ids[attention_mask == 1] + input_tokens = input_ids_filtered.tolist() + + input_token_type = [] + video_check_flg = False + for token in input_tokens: + if token == video_start_token_id: + video_check_flg = True + elif token == video_end_token_id: + video_check_flg = False + + if token == image_token_id and not video_check_flg: + input_token_type.append("image") + elif token == image_token_id and video_check_flg: + input_token_type.append("video") + else: + input_token_type.append("text") + + input_type_group = [] + for key, group in itertools.groupby(enumerate(input_token_type), lambda x: x[1]): + group = list(group) + start_index = group[0][0] + end_index = group[-1][0] + 1 + input_type_group.append((key, start_index, end_index)) + + llm_pos_ids_list = [] + video_frame_num = 1 + + for modality_type, start_idx, end_idx in input_type_group: + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + + if modality_type == "image": + t, h, w = ( + image_grid_thw[image_index][0], + image_grid_thw[image_index][1], + image_grid_thw[image_index][2], + ) + llm_grid_t, llm_grid_h, llm_grid_w = ( + t.item(), + h.item() // spatial_merge_size, + w.item() // spatial_merge_size, + ) + + t_index = torch.arange(llm_grid_t).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w).flatten() + h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(llm_grid_t, -1, llm_grid_w).flatten() + w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(llm_grid_t, llm_grid_h, -1).flatten() + llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + st_idx) + + image_index += 1 + video_frame_num = 1 + + elif modality_type == "video": + t, h, w = ( + video_frame_num, + video_grid_thw[video_index][1], + video_grid_thw[video_index][2], + ) + + llm_grid_t, llm_grid_h, llm_grid_w = ( + t, + h.item() // spatial_merge_size, + w.item() // spatial_merge_size, + ) + + for t_idx in range(llm_grid_t): + t_index = torch.tensor(t_idx).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w).flatten() + h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(1, -1, llm_grid_w).flatten() + w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(1, llm_grid_h, -1).flatten() + llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + st_idx) + + video_group_index += 1 + + if video_group_index >= video_grid_thw[video_index][0]: + video_index += 1 + video_group_index = 0 + + video_frame_num += 1 + + else: + text_len = end_idx - start_idx + llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx) + video_frame_num = 1 + + llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1) + position_ids[..., attention_mask == 1] = llm_positions.to(position_ids.device) + else: + if attention_mask is not None: + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + position_ids = position_ids.unsqueeze(0).expand(3, -1).to(input_ids.device) + else: + position_ids = torch.arange(input_ids.shape[0], device=input_ids.device).view(1, -1).expand(3, -1) + + return position_ids + + +def prepare_fa2_from_position_ids( + query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, position_ids: torch.Tensor +): + assert position_ids.ndim == 2 # (batch_size, seq_length) + query = query.contiguous().view(-1, query.size(-2), query.size(-1)) + key = key.contiguous().view(-1, key.size(-2), key.size(-1)) + value = value.contiguous().view(-1, value.size(-2), value.size(-1)) + position_ids = position_ids.view(-1) + cu_seqlens = torch.cat( + ( + (position_ids == 0).nonzero().view(-1).to(torch.int32), + torch.tensor(position_ids.size(), device=position_ids.device, dtype=torch.int32), + ) + ) + max_length = cu_seqlens.diff().max() # use cu_seqlens to infer max_length for qwen2vl mrope + return (query, key, value, (cu_seqlens, cu_seqlens), (max_length, max_length)) + + +def _custom_flash_attention_forward( + query_states: torch.Tensor, + key_states: torch.Tensor, + value_states: torch.Tensor, + attention_mask: Optional[torch.Tensor], + query_length: int, + is_causal: bool = True, + position_ids: Optional[torch.Tensor] = None, + use_top_left_mask: bool = False, + deterministic: Optional[bool] = None, + **kwargs, +): + """ + Patches flash attention forward to handle 3D position ids in mrope. (3, batch_size, seq_length) + """ + # Assuming 4D tensors, key_states.shape[1] is the key/value sequence length (source length). + flash_kwargs = {} + + if _flash_supports_deterministic: + flash_kwargs["deterministic"] = deterministic if deterministic is not None else _flash_deterministic_enabled + + if kwargs.get("softcap") is not None: + flash_kwargs["softcap"] = kwargs.pop("softcap") + + query_states, key_states, value_states = fa_peft_integration_check( + query_states, key_states, value_states, target_dtype=torch.bfloat16 + ) + + if position_ids is not None: + assert position_ids.ndim == 2 # (batch_size, seq_length / sp_size) + + sp_size = get_ulysses_sequence_parallel_world_size() + if sp_size > 1: + # qkv: (batch_size, seq_length / sp_size, num_head, head_size) + validate_ulysses_config(query_states.size(2), sp_size) + query_states = gather_seq_scatter_heads(query_states, seq_dim=1, head_dim=2) + key_states = gather_seq_scatter_heads(key_states, seq_dim=1, head_dim=2) + value_states = gather_seq_scatter_heads(value_states, seq_dim=1, head_dim=2) + position_ids_lst = [torch.empty_like(position_ids) for _ in range(sp_size)] + position_ids = dist.all_gather(position_ids_lst, position_ids, group=get_ulysses_sequence_parallel_group()) + position_ids = torch.cat(position_ids_lst, dim=-1) # (batch_size, seq_length) + + if position_ids is not None and query_length != 1 and not (torch.diff(position_ids, dim=-1) >= 0).all(): + batch_size = query_states.size(0) + q, k, v, (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = prepare_fa2_from_position_ids( + query_states, key_states, value_states, position_ids + ) + attn_output = flash_attn_varlen_func( + q=q, + k=k, + v=v, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_q, + max_seqlen_k=max_seqlen_k, + dropout_p=kwargs.pop("dropout", 0.0), + softmax_scale=kwargs.pop("softmax_scale", None), + causal=is_causal, + **flash_kwargs, + ) + attn_output = attn_output.view(batch_size, -1, attn_output.size(-2), attn_output.size(-1)) + else: + attn_output = _flash_attention_forward( + query_states, + key_states, + value_states, + attention_mask, + query_length, + is_causal=is_causal, + use_top_left_mask=use_top_left_mask, + deterministic=deterministic, + **kwargs, + ) # do not pass position_ids to old flash_attention_forward + + if sp_size > 1: + # (batch_size, seq_length, num_head, head_size) + attn_output = gather_heads_scatter_seq(attn_output, head_dim=2, seq_dim=1) + + return attn_output + + +def glm4v_attn_forward( + self: "Glm4vTextAttention", + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 + **kwargs, +) -> tuple[torch.Tensor, None, None]: + from transformers.models.glm4v.modeling_glm4v import apply_multimodal_rotary_pos_emb, repeat_kv + + bsz, q_len, _ = hidden_states.size() # q_len = seq_length / sp_size + query_states = self.q_proj(hidden_states) # (batch_size, seq_length / sp_size, num_heads * head_size) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + # Because the input can be padded, the absolute sequence length depends on the max position id. + cos, sin = position_embeddings + query_states, key_states = apply_multimodal_rotary_pos_emb( + query_states, key_states, cos, sin, self.rope_scaling["mrope_section"] + ) + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + dropout_rate = 0.0 if not self.training else self.attention_dropout + + # This is before the transpose + q_len = query_states.shape[2] + + # FA2 uses non-transposed inputs + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + attn_output = _custom_flash_attention_forward( + query_states, + key_states, + value_states, + attention_mask, + query_length=q_len, + is_causal=getattr(self, "is_causal", True), + dropout=dropout_rate, + use_top_left_mask=_flash_use_top_left_mask, + position_ids=position_ids, # important: pass position ids + ) # (batch_size, seq_length / sp_size, num_head, head_size) + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output, None + + +def _get_input_embeds( + model: "Glm4vForConditionalGeneration", + input_ids: torch.LongTensor, + attention_mask: Optional[torch.Tensor] = None, + pixel_values: Optional[torch.FloatTensor] = None, + pixel_values_videos: Optional[torch.FloatTensor] = None, + image_grid_thw: Optional[torch.LongTensor] = None, + video_grid_thw: Optional[torch.LongTensor] = None, +): + inputs_embeds = model.get_input_embeddings()(input_ids) + if pixel_values is not None: + pixel_values = pixel_values.type(model.visual.dtype) + image_embeds = model.visual(pixel_values, grid_thw=image_grid_thw) + n_image_tokens = (input_ids == model.config.image_token_id).sum().item() + n_image_features = image_embeds.shape[0] + if n_image_tokens != n_image_features: + raise ValueError( + f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" + ) + + mask = input_ids == model.config.image_token_id + mask_unsqueezed = mask.unsqueeze(-1) + mask_expanded = mask_unsqueezed.expand_as(inputs_embeds) + image_mask = mask_expanded.to(inputs_embeds.device) + + image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype) + inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds) + + if pixel_values_videos is not None: + pixel_values_videos = pixel_values_videos.type(model.visual.dtype) + video_embeds = model.visual(pixel_values_videos, grid_thw=video_grid_thw) + n_video_tokens = (input_ids == model.config.video_token_id).sum().item() + n_video_features = video_embeds.shape[0] + if n_video_tokens != n_video_features: + raise ValueError( + f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}" + ) + + mask = input_ids == model.config.video_token_id + mask_unsqueezed = mask.unsqueeze(-1) + mask_expanded = mask_unsqueezed.expand_as(inputs_embeds) + video_mask = mask_expanded.to(inputs_embeds.device) + + video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype) + inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds) + + if pixel_values is None and pixel_values_videos is None: # handle mixed text-image data + pixel_values = torch.zeros((16, 1176), dtype=inputs_embeds.dtype, device=inputs_embeds.device) + image_grid_thw = torch.tensor([[1, 4, 4]], dtype=torch.long, device=inputs_embeds.device) + image_embeds = model.visual(pixel_values, grid_thw=image_grid_thw) + inputs_embeds += 0.0 * image_embeds.mean() + + if attention_mask is not None: + attention_mask = attention_mask.to(inputs_embeds.device) + + return inputs_embeds, attention_mask + + +def process_position_ids(position_ids: torch.Tensor) -> torch.Tensor: + if position_ids.ndim != 3 or position_ids.size(0) != 4: + # we concat the text position ids with the 3D vision position ids by default + # see https://github.com/huggingface/transformers/pull/39447 + raise ValueError("position_ids should be a 3D tensor of shape (4, batch_size, seq_length).") + + return position_ids + + +@dataclass +class Glm4vCausalLMOutputForPPO(Glm4vCausalLMOutputWithPast): + log_probs: Optional[torch.FloatTensor] = None + entropy: Optional[torch.FloatTensor] = None + + +def glm4v_base_forward( + self: "Glm4vForConditionalGeneration", + input_ids: torch.LongTensor, + attention_mask: Optional[torch.Tensor] = None, + labels: Optional[torch.LongTensor] = None, + pixel_values: Optional[torch.FloatTensor] = None, + pixel_values_videos: Optional[torch.FloatTensor] = None, + image_grid_thw: Optional[torch.LongTensor] = None, + video_grid_thw: Optional[torch.LongTensor] = None, + **kwargs, +): + kwargs["inputs_embeds"], kwargs["attention_mask"] = _get_input_embeds( + self, input_ids, attention_mask, pixel_values, pixel_values_videos, image_grid_thw, video_grid_thw + ) # avoid lora module having multiple keyword arguments + return self.language_model( + input_ids=None, + **kwargs, + ) + + +def glm4v_forward( + self: "Glm4vForConditionalGeneration", + input_ids: torch.LongTensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + pixel_values: Optional[torch.FloatTensor] = None, + pixel_values_videos: Optional[torch.FloatTensor] = None, + image_grid_thw: Optional[torch.LongTensor] = None, + video_grid_thw: Optional[torch.LongTensor] = None, + **kwargs, +): + return self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=process_position_ids(position_ids), + pixel_values=pixel_values, + pixel_values_videos=pixel_values_videos, + image_grid_thw=image_grid_thw, + video_grid_thw=video_grid_thw, + **kwargs, + ) + + +def forward_with_normal_backend( + self: Glm4vForConditionalGeneration, + input_ids: torch.LongTensor = None, + labels: Optional[torch.LongTensor] = None, + temperature: float = 1.0, + **kwargs, +) -> "Glm4vCausalLMOutputWithPast": + outputs = glm4v_forward(self, input_ids, **kwargs) + hidden_states = outputs[0] + logits = self.lm_head(hidden_states) + + return Glm4vCausalLMOutputWithPast( + logits=logits, + hidden_states=outputs.hidden_states, + ) + + +def forward_with_torch_backend( + self: Glm4vForConditionalGeneration, + input_ids: torch.LongTensor = None, + labels: Optional[torch.LongTensor] = None, + temperature: float = 1.0, + **kwargs, +) -> tuple | Glm4vCausalLMOutputForPPO: + from verl.utils.experimental.torch_functional import FusedLinearForPPO + + outputs = glm4v_forward(self, input_ids, **kwargs) + hidden_states = outputs[0] + + # Loss calculations + if labels is not None: + rolled_labels = torch.roll(labels, shifts=-1, dims=-1) + elif input_ids is not None: + rolled_labels = torch.roll(input_ids, shifts=-1, dims=-1) + else: + raise RuntimeError("To use forward_with_torch_backend, either labels or input_ids must be provided.") + + fused_linear_for_ppo = FusedLinearForPPO() + log_probs, entropy = fused_linear_for_ppo.forward( + hidden_states=hidden_states, + vocab_weights=self.lm_head.weight, + input_ids=rolled_labels, + temperature=temperature, + ) + return Glm4vCausalLMOutputForPPO( + log_probs=log_probs, + entropy=entropy, + hidden_states=outputs.hidden_states, + ) + + +def forward_with_triton_backend( + self: Glm4vForConditionalGeneration, + input_ids: torch.LongTensor = None, + labels: Optional[torch.LongTensor] = None, + temperature: float = 1.0, + **kwargs, +) -> tuple | Glm4vCausalLMOutputForPPO: + from verl.utils.kernel.linear_cross_entropy import linear_cross_entropy + + outputs = glm4v_forward(self, input_ids, **kwargs) + hidden_states = outputs[0] + + # Loss calculations + if labels is not None: + rolled_labels = torch.roll(labels, shifts=-1, dims=-1) + elif input_ids is not None: + rolled_labels = torch.roll(input_ids, shifts=-1, dims=-1) + else: + raise RuntimeError("To use forward_with_triton_backend, either labels or input_ids must be provided.") + + log_probs, entropy = linear_cross_entropy( + hidden_states, + self.lm_head.weight, + rolled_labels, + temperature, + "none", + ) + return Glm4vCausalLMOutputForPPO( + log_probs=log_probs, + entropy=entropy, + hidden_states=outputs.hidden_states, + ) diff --git a/code/RL_model/verl/verl_train/verl/models/transformers/kimi_vl.py b/code/RL_model/verl/verl_train/verl/models/transformers/kimi_vl.py new file mode 100644 index 0000000000000000000000000000000000000000..cabb518f4a113fc52f421700d9f216b4ec3bd627 --- /dev/null +++ b/code/RL_model/verl/verl_train/verl/models/transformers/kimi_vl.py @@ -0,0 +1,192 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional + +import torch +import torch.nn.functional as F +from transformers.cache_utils import Cache +from transformers.modeling_flash_attention_utils import _flash_attention_forward + +from verl.models.transformers.monkey_patch import is_transformers_version_in_range + +# Import compatibility wrapper for flash_attn_supports_top_left_mask +from verl.utils.transformers_compat import flash_attn_supports_top_left_mask +from verl.utils.ulysses import ( + gather_heads_scatter_seq, + gather_seq_scatter_heads, + get_ulysses_sequence_parallel_world_size, + validate_ulysses_config, +) + + +# Copied from transformers.models.llama.modeling_llama.rotate_half +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb +def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`): + The position indices of the tokens corresponding to the query and key tensors. For example, this can be + used to pass offsetted position ids when working with a KV-cache. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos[position_ids].unsqueeze(unsqueeze_dim) + sin = sin[position_ids].unsqueeze(unsqueeze_dim) + + b, h, s, d = q.shape + q = q.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d) + + b, h, s, d = k.shape + k = k.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d) + + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +# Copied from transformers.models.llama.modeling_llama.repeat_kv +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +def _ulysses_flash_attn_forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + **kwargs, +) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + + if self.q_lora_rank is None: + q = self.q_proj(hidden_states) + else: + q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states))) + q = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2) + + # Flash attention requires the input to have the shape + # batch_size x seq_length x head_dim x hidden_dim + # therefore we just need to keep the original shape + compressed_kv = self.kv_a_proj_with_mqa(hidden_states) + compressed_kv, k_pe = torch.split(compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) + k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim).transpose(1, 2) + kv = ( + self.kv_b_proj(self.kv_a_layernorm(compressed_kv)) + .view(bsz, q_len, self.num_heads, self.qk_nope_head_dim + self.v_head_dim) + .transpose(1, 2) + ) + + k_nope, value_states = torch.split(kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1) + + # patch + ulysses_sp_size = get_ulysses_sequence_parallel_world_size() + if ulysses_sp_size > 1: + validate_ulysses_config(self.num_heads, ulysses_sp_size) + + num_key_value_groups = self.config.num_attention_heads // self.config.num_key_value_heads + k_pe = repeat_kv(k_pe, ulysses_sp_size) # to keep heads=1 after a2a + k_nope = repeat_kv(k_nope, num_key_value_groups) + value_states = repeat_kv(value_states, num_key_value_groups) + q = gather_seq_scatter_heads(q, seq_dim=2, head_dim=1) + k_pe = gather_seq_scatter_heads(k_pe, seq_dim=2, head_dim=1) + k_nope = gather_seq_scatter_heads(k_nope, seq_dim=2, head_dim=1) + value_states = gather_seq_scatter_heads(value_states, seq_dim=2, head_dim=1) + # (batch_size, num_head / sp_size, seq_length, head_size) + full_q_len = q.size(2) # full_q_len = seq_length + + else: + full_q_len = q_len + + q_nope, q_pe = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) + cos, sin = self.rotary_emb(value_states, seq_len=full_q_len) + q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids) + + query_states = k_pe.new_empty(bsz, self.num_heads // ulysses_sp_size, full_q_len, self.q_head_dim) + query_states[:, :, :, : self.qk_nope_head_dim] = q_nope + query_states[:, :, :, self.qk_nope_head_dim :] = q_pe + + key_states = k_pe.new_empty(bsz, self.num_heads // ulysses_sp_size, full_q_len, self.q_head_dim) + key_states[:, :, :, : self.qk_nope_head_dim] = k_nope + key_states[:, :, :, self.qk_nope_head_dim :] = k_pe + + if self.q_head_dim != self.v_head_dim: + value_states = F.pad(value_states, [0, self.q_head_dim - self.v_head_dim]) + + # TODO: These transpose are quite inefficient but Flash Attention requires the layout + # [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache + # to be able to avoid many of these transpose/reshape/view. + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + dropout_rate = self.attention_dropout if self.training else 0.0 + + attn_output = _flash_attention_forward( + query_states, + key_states, + value_states, + attention_mask, + full_q_len, + dropout=dropout_rate, + sliding_window=None, + is_causal=self.is_causal, + use_top_left_mask=flash_attn_supports_top_left_mask(), + position_ids=position_ids, # important: pass position ids + softmax_scale=self.softmax_scale, + ) + + if ulysses_sp_size > 1: + attn_output = gather_heads_scatter_seq(attn_output, head_dim=2, seq_dim=1) + + if self.q_head_dim != self.v_head_dim: + attn_output = attn_output[:, :, :, : self.v_head_dim] + + attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.v_head_dim).contiguous() + attn_output = self.o_proj(attn_output) + + if is_transformers_version_in_range(min_version="4.53.0"): + return attn_output, None + else: + return attn_output, None, None diff --git a/code/RL_model/verl/verl_train/verl/models/transformers/llama.py b/code/RL_model/verl/verl_train/verl/models/transformers/llama.py new file mode 100644 index 0000000000000000000000000000000000000000..b3efb8646d55808bf647bb9d490ab69b80dc6fe1 --- /dev/null +++ b/code/RL_model/verl/verl_train/verl/models/transformers/llama.py @@ -0,0 +1,241 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import sys +from typing import Callable, Optional + +import torch + +if sys.version_info >= (3, 11): + pass +else: + pass + +from transformers.cache_utils import Cache +from transformers.modeling_flash_attention_utils import _flash_attention_forward +from transformers.models.llama.modeling_llama import apply_rotary_pos_emb +from transformers.utils import logging + +# Import compatibility wrapper for flash_attn_supports_top_left_mask +from verl.utils.transformers_compat import flash_attn_supports_top_left_mask +from verl.utils.ulysses import ( + gather_heads_scatter_seq, + gather_seq_scatter_heads, + get_ulysses_sequence_parallel_world_size, + validate_ulysses_config, +) + +logger = logging.get_logger(__name__) + + +def llama_flash_attn_forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 + **kwargs, +) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: + """ + Adapted from transformers 4.47.1 to support Ulysses sequence parallelism. + + NOTE: This function is used for transformers versions in the range [4.45.0, 4.47.1]. + """ + output_attentions = False + + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + # Flash attention requires the input to have the shape + # batch_size x seq_length x head_dim x hidden_dim + # therefore we just need to keep the original shape + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + # trade off: repeat first and then all to all + # key_states = repeat_kv(key_states, self.num_key_value_groups) + # value_states = repeat_kv(value_states, self.num_key_value_groups) + + ########## AlltoAll for Ulysses ########## + ulysses_sp_size = get_ulysses_sequence_parallel_world_size() + + if ulysses_sp_size > 1: + validate_ulysses_config(self.num_heads, ulysses_sp_size) + + # (bsz, n_head, seq_len/n, head_dim) -> (bsz, n_head/n, seq_len, head_dim) + query_states = gather_seq_scatter_heads(query_states, seq_dim=2, head_dim=1) + key_states = gather_seq_scatter_heads(key_states, seq_dim=2, head_dim=1) + value_states = gather_seq_scatter_heads(value_states, seq_dim=2, head_dim=1) + + full_q_len = query_states.size(2) # full seq length + + if position_embeddings is None: + logger.warning_once( + "The attention layers in this model are transitioning from computing the RoPE embeddings internally " + "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " + "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be " + "removed and `position_embeddings` will be mandatory." + ) + cos, sin = self.rotary_emb(value_states, position_ids) + else: + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + # TODO: These transpose are quite inefficient but Flash Attention requires the layout + # [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache + # to be able to avoid many of these transpose/reshape/view. + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + dropout_rate = self.attention_dropout if self.training else 0.0 + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in the correct dtype just to be sure everything works as expected. + # This might slowdown training & inference so it is recommended to not cast the LayerNorms + # in fp32. (LlamaRMSNorm handles it correctly) + + input_dtype = query_states.dtype + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to " + f"the fact you have upcasted embedding or layer norm layers in float32. We will cast back the " + f"input in {target_dtype}." + ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + attn_output = _flash_attention_forward( + query_states, + key_states, + value_states, + attention_mask, + full_q_len, + position_ids=position_ids, + dropout=dropout_rate, + sliding_window=getattr(self, "sliding_window", None), + use_top_left_mask=flash_attn_supports_top_left_mask(), + is_causal=self.is_causal, + **kwargs, + ) + + attn_output = attn_output.reshape(bsz, full_q_len, -1, self.head_dim).contiguous() + ########## AlltoAll for Ulysses ########## + if ulysses_sp_size > 1: + attn_output = gather_heads_scatter_seq(attn_output, seq_dim=1, head_dim=2) + attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +def llama_attn_forward( + self, + hidden_states: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], + past_key_value: Optional[Cache] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, +) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: + """ + Adapted from transformers 4.49.0 to support Ulysses sequence parallelism for transformers >= 4.48.0. + + NOTE: This function has been tested only on transformers versions between 4.48.0 and 4.50.0. + """ + from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS + from transformers.models.llama.modeling_llama import eager_attention_forward + + bsz, q_len, _ = hidden_states.shape + + query_states = self.q_proj(hidden_states).view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + key_states = self.k_proj(hidden_states).view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + + ########## AlltoAll for Ulysses ########## + ulysses_sp_size = get_ulysses_sequence_parallel_world_size() + + if ulysses_sp_size > 1: + validate_ulysses_config(self.config.num_attention_heads, ulysses_sp_size) + + query_states = gather_seq_scatter_heads(query_states, seq_dim=2, head_dim=1) + key_states = gather_seq_scatter_heads(key_states, seq_dim=2, head_dim=1) + value_states = gather_seq_scatter_heads(value_states, seq_dim=2, head_dim=1) + + full_q_len = query_states.size(2) + + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): + logger.warning_once( + "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. " + "Falling back to eager attention. This warning can be removed using the argument " + '`attn_implementation="eager"` when loading the model.' + ) + else: + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + **kwargs, + ) + + attn_output = attn_output.reshape(bsz, full_q_len, -1, self.head_dim).contiguous() + ########## AlltoAll for Ulysses ########## + if ulysses_sp_size > 1: + attn_output = gather_heads_scatter_seq(attn_output, seq_dim=1, head_dim=2) + attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights diff --git a/code/RL_model/verl/verl_train/verl/models/transformers/monkey_patch.py b/code/RL_model/verl/verl_train/verl/models/transformers/monkey_patch.py new file mode 100644 index 0000000000000000000000000000000000000000..bb26dac2da9486d102509ccf55cfda94694a656d --- /dev/null +++ b/code/RL_model/verl/verl_train/verl/models/transformers/monkey_patch.py @@ -0,0 +1,493 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Apply monkey-patch function to models +""" + +import sys +from types import SimpleNamespace +from typing import Optional + +import torch +from transformers.modeling_flash_attention_utils import _flash_attention_forward +from transformers.modeling_utils import PreTrainedModel + +from verl.utils.import_utils import is_trl_available +from verl.utils.transformers_compat import is_transformers_version_in_range +from verl.utils.ulysses import ( + gather_heads_scatter_seq, + gather_seq_scatter_heads, + get_ulysses_sequence_parallel_group, + get_ulysses_sequence_parallel_world_size, + slice_input_tensor, +) + +_PREFIX_GROUPER_PATCHED = False +_PREFIX_GROUPER_SUPPORTED_ATTENTIONS = {"flash_attention_2", "flash_attention_3", "sdpa", "flex_attention", "eager"} + + +def _create_prefix_grouper_wrapper(original_fn): + """Wrap attention function to support prefix_grouper in kwargs.""" + + def wrapped(module, query, key, value, attention_mask, *args, **kwargs): + prefix_grouper = kwargs.pop("prefix_grouper", None) + if prefix_grouper is None: + return original_fn(module, query, key, value, attention_mask, *args, **kwargs) + + def attn_func(q, k, v, attn_mask, *inner_args, **inner_kwargs): + out, _ = original_fn(module, q, k, v, attn_mask, *inner_args, **inner_kwargs) + return out + + return prefix_grouper.forward(attn_func, query, key, value, *args, **kwargs), None + + return wrapped + + +def apply_prefix_grouper_patch(): + """Patch ALL_ATTENTION_FUNCTIONS to support prefix_grouper parameter.""" + global _PREFIX_GROUPER_PATCHED + if _PREFIX_GROUPER_PATCHED: + return + + from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS + + patched = [] + for name in list(ALL_ATTENTION_FUNCTIONS.keys()): + if name in _PREFIX_GROUPER_SUPPORTED_ATTENTIONS: + ALL_ATTENTION_FUNCTIONS[name] = _create_prefix_grouper_wrapper(ALL_ATTENTION_FUNCTIONS[name]) + patched.append(name) + + _PREFIX_GROUPER_PATCHED = True + print(f"[PrefixGrouper] Patched: {patched}") + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=2, repeats=n_rep). The hidden states go from (batch, + seqlen, num_key_value_heads, head_dim) to (batch, seqlen, num_attention_heads, head_dim) + """ + batch, slen, num_key_value_heads, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, :, None, :].expand(batch, slen, num_key_value_heads, n_rep, head_dim) + return hidden_states.reshape(batch, slen, num_key_value_heads * n_rep, head_dim) + + +def _ulysses_flash_attention_forward( + query_states: torch.Tensor, + key_states: torch.Tensor, + value_states: torch.Tensor, + attention_mask: Optional[torch.Tensor], + query_length: int, + *args, + position_ids: Optional[torch.Tensor] = None, + **kwargs, +): + """Insert all-to-all before and after flash attention. + DeepSpeed-Ulysses: https://arxiv.org/pdf/2309.14509 + + For transformers>=4.55, the flash attention api has changed, + we need to pass the query_length after doing ulysses all2all. + See https://github.com/huggingface/transformers/issues/40399 + + Args: + query_states (torch.Tensor): (batch_size, seqlen/sp_size, nheads, head_dim) + key_states (torch.Tensor): (batch_size, seqlen/sp_size, nheads_k, head_dim) + value_states (torch.Tensor): (batch_size, seqlen/sp_size, nheads_k, head_dim) + position_ids (torch.Tensor, optional): (batch_size, seqlen/sp_size) + + Returns: + torch.Tensor: (batch_size, seqlen/sp_size, nheads, head_dim) + + """ + ulysses_sp_size = get_ulysses_sequence_parallel_world_size() + + ########## AlltoAll for Ulysses ########## + # TODO: Disable sp for ViT, there's no elegent way to determine whether it's ViT or not. + # Use `position_ids` as condition since ViT doesn't pass it to flash attention. + if ulysses_sp_size > 1 and position_ids is not None: + # NOTE: repeat kv heads to be divided by sequence parallel. Instead of repeating nheads_q//nheads_k, + # we choose to repeat sp_size//nheads_k, since flash_attention supports MQA/GQA. + # For example: + # - nheads_k=4, sp=8, repeats=2 + # - nheads_k=8, sp=8, repeats=1 + # - nheads_k=16, sp=8, repeats=1 + repeats = max(ulysses_sp_size // key_states.size(2), 1) + key_states = repeat_kv(key_states, repeats) + value_states = repeat_kv(value_states, repeats) + + # (bsz, seq_len/n, n_head, head_dim) -> (bsz, seq_len, n_head/n, head_dim) + query_states = gather_seq_scatter_heads(query_states, seq_dim=1, head_dim=2) + key_states = gather_seq_scatter_heads(key_states, seq_dim=1, head_dim=2) + value_states = gather_seq_scatter_heads(value_states, seq_dim=1, head_dim=2) + + # TODO: all_gather position_ids because `prepare_fa2_from_position_ids` needs it, we can eliminate + # this all_gather by passing cu_seq_lens_q, cu_seq_lens_k, max_length_k, max_length_q explicitly. + # https://github.com/huggingface/transformers/pull/33932 + + # (bsz, seq_len/n) -> (bsz, seq_len) + position_ids_list = [torch.empty_like(position_ids) for _ in range(ulysses_sp_size)] + torch.distributed.all_gather(position_ids_list, position_ids, group=get_ulysses_sequence_parallel_group()) + position_ids = torch.concat(position_ids_list, dim=-1) + + # (bsz, seq_len, n_head/n, head_dim) + query_length = query_states.size(1) + attn_output = _flash_attention_forward( + query_states, key_states, value_states, attention_mask, query_length, *args, position_ids=position_ids, **kwargs + ) + + ########## AlltoAll for Ulysses ########## + if ulysses_sp_size > 1 and position_ids is not None: + # (bsz, seq_len, n_head/n, head_dim) -> (bsz, seq_len/n, n_head, head_dim) + attn_output = gather_heads_scatter_seq(attn_output, seq_dim=1, head_dim=2) + + return attn_output + + +def patch_vlm_for_ulysses_input_slicing(model_class: type): + """ + Applies a monkey patch to the forward method of a given model class + to enable Ulysses sequence parallelism input slicing. + """ + + def _create_ulysses_wrapped_decoder_forward(original_forward): + def ulysses_wrapped_decoder_forward(self, *args, **kwargs): + inputs_embeds = kwargs.get("inputs_embeds") + position_ids = kwargs.get("position_ids") + visual_pos_masks = kwargs.get("visual_pos_masks") + deepstack_visual_embeds = kwargs.get("deepstack_visual_embeds") + call_kwargs = kwargs.copy() + + current_ulysses_sp_size = get_ulysses_sequence_parallel_world_size() + + slice_now = ( + inputs_embeds is not None + and current_ulysses_sp_size > 1 + and getattr(self, "_needs_initial_slice", True) + ) + if slice_now: + call_kwargs["inputs_embeds"] = slice_input_tensor(inputs_embeds, dim=1, padding=False) + call_kwargs["position_ids"] = slice_input_tensor(position_ids, dim=-1, padding=False) + # Also slice visual_pos_masks and deepstack_visual_embeds for Qwen3 VL models + if visual_pos_masks is not None: + original_visual_mask = visual_pos_masks + sliced_visual_mask = slice_input_tensor(visual_pos_masks, dim=1, padding=False) + call_kwargs["visual_pos_masks"] = sliced_visual_mask + + if deepstack_visual_embeds is not None: + sliced_embeds = [] + + num_visual_before = original_visual_mask.sum().item() + num_visual_in_shard = sliced_visual_mask.sum().item() + + if num_visual_in_shard > 0 and num_visual_before > 0: + # Calculate which visual embeddings belong to this shard + # We need to find the offset of visual tokens in this shard + from verl.utils.ulysses import get_ulysses_sequence_parallel_rank + + rank = get_ulysses_sequence_parallel_rank() + seq_len = original_visual_mask.shape[1] + local_seq_len = seq_len // current_ulysses_sp_size + start_idx = rank * local_seq_len + end_idx = start_idx + local_seq_len + + # Get total visual tokens before and up to the end of the shard's sequence slice + # This correctly handles batches by summing across all samples + visual_start = original_visual_mask[:, :start_idx].sum().item() if start_idx > 0 else 0 + visual_end = original_visual_mask[:, :end_idx].sum().item() + + # Slice each tensor in deepstack_visual_embeds + for embed in deepstack_visual_embeds: + sliced_embeds.append(embed[visual_start:visual_end]) + else: + # No visual tokens in this shard, create empty tensors to maintain gradient flow + for embed in deepstack_visual_embeds: + sliced_embeds.append(embed[:0]) + call_kwargs["deepstack_visual_embeds"] = sliced_embeds + + self._needs_initial_slice = False + try: + return original_forward(self, *args, **call_kwargs) + finally: + if slice_now: + self._needs_initial_slice = True + + return ulysses_wrapped_decoder_forward + + original_forward = model_class.forward + wrapped_forward = _create_ulysses_wrapped_decoder_forward(original_forward) + model_class.forward = wrapped_forward + print(f"Monkey patch {model_class.__name__}.forward for Ulysses SP input slicing.") + + +def patch_forward_with_backends( + model: PreTrainedModel, + use_fused_kernels: bool = False, + fused_kernels_backend: str = None, +): + """ + Choose the forward function based on the model and backend. + Args: + model (PreTrainedModel): The model to apply the monkey patch. + use_fused_kernels (bool): Whether to use fused kernels. + fused_kernels_backend (str): The backend to use for fused kernels. + """ + if not use_fused_kernels or fused_kernels_backend not in ["triton", "torch"]: + print( + f"Skipping monkey patch for {model.__class__.__name__} as use_fused_kernels is " + f"{use_fused_kernels} or fused_kernels_backend is {fused_kernels_backend}" + ) + return + + forward_with_torch_backend_function = model.__class__.forward + forward_with_triton_backend_function = model.__class__.forward + if model.config.model_type in ["qwen2_5_vl", "qwen2_vl"]: + from verl.models.transformers.qwen2_vl import forward_with_torch_backend, forward_with_triton_backend + + forward_with_torch_backend_function = forward_with_torch_backend + forward_with_triton_backend_function = forward_with_triton_backend + elif model.config.model_type in ["qwen3_vl", "qwen3_vl_moe"]: + from verl.models.transformers.qwen3_vl import forward_with_torch_backend, forward_with_triton_backend + + forward_with_torch_backend_function = forward_with_torch_backend + forward_with_triton_backend_function = forward_with_triton_backend + elif model.config.model_type == "glm4v": + from verl.models.transformers.glm4v import forward_with_torch_backend, forward_with_triton_backend + + forward_with_torch_backend_function = forward_with_torch_backend + forward_with_triton_backend_function = forward_with_triton_backend + else: + from verl.models.transformers.dense_common import forward_with_torch_backend, forward_with_triton_backend + + forward_with_torch_backend_function = forward_with_torch_backend + forward_with_triton_backend_function = forward_with_triton_backend + + if fused_kernels_backend == "triton": + model.__class__.forward = forward_with_triton_backend_function + print(f"Using Triton backend for fused kernels in {model.__class__.__name__}") + elif fused_kernels_backend == "torch": + model.__class__.forward = forward_with_torch_backend_function + print(f"Using Torch backend for fused kernels in {model.__class__.__name__}") + else: + raise ValueError(f"Unsupported fused_kernels_backend: {fused_kernels_backend}. Choose 'triton' or 'torch'.") + + +def apply_monkey_patch( + model: PreTrainedModel, + ulysses_sp_size: int = 1, + use_remove_padding: bool = True, + use_fused_kernels: bool = False, + fused_kernels_backend: str = None, + use_prefix_grouper: bool = False, + use_tiled_mlp: bool = False, + tiled_mlp_shards: int = 4, +): + """ + Apply monkey patch to the models for ulysses sequence parallel, fused kernel, tiled MLP and prefix grouper. + + In the end of this function forward function of the model is patched for fused kernel. + If the model is not supported with fused kernel, please return after patch. + + Args: + model: The model to apply the monkey patch. + ulysses_sp_size: The size of ulysses sequence parallel. + use_remove_padding: Whether to use remove padding. + use_fused_kernels: Whether to use fused kernels. + fused_kernels_backend: The backend to use for fused kernels. + use_tiled_mlp: Whether to use TiledMLP for memory-efficient MLP computation. + tiled_mlp_shards: Number of shards for TiledMLP (higher = lower memory, slightly slower). + """ + + # Apply TiledMLP monkey patch for memory-efficient MLP computation + if use_tiled_mlp: + from verl.models.transformers.tiled_mlp import apply_tiled_mlp_monkey_patch + + model_type = getattr(model.config, "model_type", None) + apply_tiled_mlp_monkey_patch(num_shards=tiled_mlp_shards, model_type=model_type) + # Apply PrefixGrouper patch if enabled + if use_prefix_grouper: + apply_prefix_grouper_patch() + + """Replace _flash_attention_forward to _ulysses_flash_attention_forward""" + module = sys.modules[model.__module__] + + try: + num_attention_heads, num_key_value_heads = model.config.num_attention_heads, model.config.num_key_value_heads + except AttributeError: + num_attention_heads, num_key_value_heads = ( + model.config.text_config.num_attention_heads, + model.config.text_config.num_key_value_heads, + ) + + assert num_attention_heads % ulysses_sp_size == 0, ( + f"num_attention_heads {num_attention_heads} must be divisible by ulysses_sp_size {ulysses_sp_size}" + ) + assert num_key_value_heads % ulysses_sp_size == 0 or ulysses_sp_size % num_key_value_heads == 0, ( + f"num_key_value_heads {num_key_value_heads} must be divisible by ulysses_sp_size " + f"{ulysses_sp_size}or vise versa. Upon ulysses_sp_size % num_key_value_heads == 0," + f"kv heads are repeated to ensure correctness." + ) + + if is_trl_available(): + from trl import AutoModelForCausalLMWithValueHead # type: ignore + + def state_dict(self, *args, **kwargs): + return torch.nn.Module.state_dict(self, *args, **kwargs) + + AutoModelForCausalLMWithValueHead.state_dict = state_dict + print("Monkey patch state_dict in AutoModelForCausalLMWithValueHead. ") + + # TODO: VLM models only, unify monkey patch to LLM models. + if model.config.model_type in ["qwen2_5_vl", "qwen2_vl"]: + # Step 1: patch model to support image-text mixed data + if is_transformers_version_in_range(min_version="4.52.0"): + from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import ( + Qwen2_5_VLForConditionalGeneration, + Qwen2_5_VLModel, + Qwen2_5_VLTextModel, + ) + from transformers.models.qwen2_vl.modeling_qwen2_vl import ( + Qwen2VLForConditionalGeneration, + Qwen2VLModel, + Qwen2VLTextModel, + ) + else: + from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLForConditionalGeneration + from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLModel as Qwen2_5_VLTextModel + from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLForConditionalGeneration + from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLModel as Qwen2VLTextModel + + Qwen2_5_VLModel = SimpleNamespace(forward=None) + Qwen2VLModel = SimpleNamespace(forward=None) + + from verl.models.transformers.qwen2_vl import forward_with_normal_backend, qwen2_vl_base_forward + + Qwen2_5_VLModel.forward = qwen2_vl_base_forward + Qwen2VLModel.forward = qwen2_vl_base_forward + Qwen2_5_VLForConditionalGeneration.forward = forward_with_normal_backend + Qwen2VLForConditionalGeneration.forward = forward_with_normal_backend + print(f"Monkey patch {model.__class__.__name__} model forward") + + # Step 2: patch attention to support ulysses parallelism + if is_transformers_version_in_range(min_version="4.54.0"): + from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLAttention + from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLAttention + elif is_transformers_version_in_range(min_version="4.53.0"): + raise RuntimeError("Transformers 4.53.* is bugged. Use transformers 4.54.0 or later.") + else: + from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import ( + Qwen2_5_VLFlashAttention2 as Qwen2_5_VLAttention, + ) + from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLFlashAttention2 as Qwen2VLAttention + + if use_remove_padding or ulysses_sp_size > 1: + from verl.models.transformers.qwen2_vl import qwen2_vl_attn_forward + + Qwen2_5_VLAttention.forward = qwen2_vl_attn_forward + Qwen2VLAttention.forward = qwen2_vl_attn_forward + print(f"Monkey patch {model.__class__.__name__} attention layer") + + # Step 3: patch input for multimodal sequence parallelism + if ulysses_sp_size > 1: + patch_vlm_for_ulysses_input_slicing(Qwen2_5_VLTextModel) + patch_vlm_for_ulysses_input_slicing(Qwen2VLTextModel) + + elif model.config.model_type in ["qwen3_vl", "qwen3_vl_moe"]: + # Step 1: patch model to support image-text mixed data + from transformers.models.qwen3_vl.modeling_qwen3_vl import ( + Qwen3VLForConditionalGeneration, + Qwen3VLModel, + Qwen3VLTextModel, + ) + from transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import ( + Qwen3VLMoeForConditionalGeneration, + Qwen3VLMoeModel, + Qwen3VLMoeTextModel, + ) + + from verl.models.transformers.qwen3_vl import ( + forward_with_normal_backend, + patch_qwen3_vl_moe_sparse_moe_block_forward, + qwen3_vl_base_forward, + ) + + Qwen3VLModel.forward = qwen3_vl_base_forward + Qwen3VLMoeModel.forward = qwen3_vl_base_forward + Qwen3VLForConditionalGeneration.forward = forward_with_normal_backend + Qwen3VLMoeForConditionalGeneration.forward = forward_with_normal_backend + print(f"Monkey patch {model.__class__.__name__} model forward") + + # Step 1.5: patch Qwen3VLMoeTextSparseMoeBlock to fix transformers 4.57.3 bug + if model.config.model_type == "qwen3_vl_moe" and is_transformers_version_in_range(max_version="4.57.3"): + patch_qwen3_vl_moe_sparse_moe_block_forward() + + # Step 2: patch input for multimodal sequence parallelism + if ulysses_sp_size > 1: + patch_vlm_for_ulysses_input_slicing(Qwen3VLTextModel) + patch_vlm_for_ulysses_input_slicing(Qwen3VLMoeTextModel) + + elif model.config.model_type == "glm4v": + # Step 1: patch model to support image-text mixed data + + from transformers.models.glm4v.modeling_glm4v import ( + Glm4vForConditionalGeneration, + Glm4vModel, + Glm4vTextAttention, + Glm4vTextModel, + ) + + from verl.models.transformers.glm4v import forward_with_normal_backend, glm4v_base_forward + + Glm4vModel.forward = glm4v_base_forward + Glm4vForConditionalGeneration.forward = forward_with_normal_backend + print(f"Monkey patch {model.__class__.__name__} model forward") + + # Step 2: patch attention to support ulysses parallelism + if use_remove_padding or ulysses_sp_size > 1: + from verl.models.transformers.glm4v import glm4v_attn_forward + + Glm4vTextAttention.forward = glm4v_attn_forward + print(f"Monkey patch {model.__class__.__name__} attention layer") + + # Step 3: patch input for multimodal sequence parallelism + if ulysses_sp_size > 1: + patch_vlm_for_ulysses_input_slicing(Glm4vTextModel) + + elif model.config.model_type == "kimi_vl": + if use_remove_padding or ulysses_sp_size > 1: + # TODO: Changes need to be made when transformers are adapted. + from verl.models.transformers.kimi_vl import _ulysses_flash_attn_forward + + module.DeepseekV3FlashAttention2.forward = _ulysses_flash_attn_forward + print("Monkey patch FlashAttention2.forward in KimiVL") + + if ulysses_sp_size > 1: + patch_vlm_for_ulysses_input_slicing(module.DeepseekV3ForCausalLM) + + if use_fused_kernels: + print("Not support fused kernels for KimiVL") + + return + + if use_remove_padding or ulysses_sp_size > 1: + if hasattr(module, "_flash_attention_forward"): # transformers <= 4.47.1 or legacy models + module._flash_attention_forward = _ulysses_flash_attention_forward + print(f"Monkey patch _flash_attention_forward in {model.__module__}") + else: + from transformers.integrations import flash_attention + + flash_attention._flash_attention_forward = _ulysses_flash_attention_forward + print(f"Monkey patch _flash_attention_forward in {flash_attention.__name__}") + + patch_forward_with_backends(model, use_fused_kernels=use_fused_kernels, fused_kernels_backend=fused_kernels_backend) diff --git a/code/RL_model/verl/verl_train/verl/models/transformers/npu_patch.py b/code/RL_model/verl/verl_train/verl/models/transformers/npu_patch.py new file mode 100644 index 0000000000000000000000000000000000000000..ba25fe6e6ba52dff49f796236bdb84c6c0380a8b --- /dev/null +++ b/code/RL_model/verl/verl_train/verl/models/transformers/npu_patch.py @@ -0,0 +1,261 @@ +# Copyright 2025 Bytedance Ltd. and/or its affiliates +# +# Copyright 2025 The Qwen Team and The HuggingFace Inc. team +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import torch +import torch.nn.functional as F +import torch_npu +from torch import nn +from transformers.activations import ACT2FN +from transformers.models.qwen2 import modeling_qwen2 +from transformers.models.qwen2_5_vl import modeling_qwen2_5_vl +from transformers.models.qwen3 import modeling_qwen3 +from transformers.models.qwen3_moe import modeling_qwen3_moe +from transformers.models.qwen3_vl import modeling_qwen3_vl +from transformers.models.qwen3_vl_moe import modeling_qwen3_vl_moe +from transformers.utils import logging + +logger = logging.get_logger(__name__) + + +def rms_norm_forward_npu(self, x): + """NPU optimized implementation for RMSNorm.""" + if x.dtype != self.weight.dtype: + x = x.to(self.weight.dtype) + return torch_npu.npu_rms_norm(x, self.weight, epsilon=self.variance_epsilon)[0] + + +def silu_forward_npu(self, hidden_state): + """NPU optimized implementation for SiLU in `forward` func in MLP layer.""" + gate_up = torch.cat((self.gate_proj(hidden_state), self.up_proj(hidden_state)), dim=-1) + return self.down_proj(torch_npu.npu_swiglu(gate_up, dim=-1)) + + +def apply_rotary_pos_emb_npu(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """NPU optimized implementation for RoPE.""" + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = torch_npu.npu_rotary_mul(q, cos, sin) + k_embed = torch_npu.npu_rotary_mul(k, cos, sin) + return q_embed.to(q.dtype), k_embed.to(k.dtype) + + +class NPUGmmFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, x, weight, group_list, group_list_type=1): + """ + Grouped Matmul(GMM) for Ascend NPU. + + Args: + x (torch.Tensor): Input tensor, shape (tokens_num * top_k, hidden_size) + weight (torch.Tensor): Expert weights, shape (n_experts, hidden_size, intermediate_size) + group_list (torch.Tensor): Expert token counts, shape (n_experts,) + - type 0: cumsum of tokens per expert + - type 1: direct tokens per expert (default) + """ + ctx.save_for_backward(x, weight) + ctx.group_list = group_list + ctx.group_list_type = group_list_type + + output = torch_npu.npu_grouped_matmul( + [x], [weight], bias=None, group_list=group_list, split_item=2, group_type=0, group_list_type=group_list_type + )[0] + + return output + + @staticmethod + def backward(ctx, grad_output): + x, weight = ctx.saved_tensors + group_list = ctx.group_list + group_list_type = ctx.group_list_type + + dx = torch_npu.npu_grouped_matmul( + [grad_output], + [weight.transpose(1, 2)], + bias=None, + group_list=group_list, + split_item=2, + group_type=0, + group_list_type=group_list_type, + )[0] + + dw = torch_npu.npu_grouped_matmul( + [x.transpose(0, 1)], + [grad_output], + bias=None, + group_list=group_list, + split_item=3, + group_type=2, + group_list_type=group_list_type, + )[0] + + return dx, dw, None, None + + +def qwen3_moe_sparse_moe_block_forward_npu(self, hidden_states: torch.Tensor) -> torch.Tensor: + """NPU optimized implementation for `forward` in Qwen3MoeSparseMoeBlock.""" + # hidden_states: (batch_size, sequence_length, hidden_size) + hidden_dim = hidden_states.shape[-1] + hidden_states = hidden_states.view(-1, hidden_dim) + # router_logits: (batch * sequence_length, n_experts) + router_logits = self.gate(hidden_states) + + routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float) + routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1) + if self.norm_topk_prob: # only diff with mixtral sparse moe block! + routing_weights /= routing_weights.sum(dim=-1, keepdim=True) + # we cast back to the input dtype + routing_weights = routing_weights.to(hidden_states.dtype) + + # Loop over all available experts in the model and perform the computation on each expert + # Concat all weights + input_dtype = hidden_states.dtype + up_weight_list = [e.up_proj.weight for e in self.experts] + gate_weight_list = [e.gate_proj.weight for e in self.experts] + down_weight_list = [e.down_proj.weight for e in self.experts] + w1 = torch.stack(up_weight_list).transpose(1, 2).to(input_dtype) + w2 = torch.stack(gate_weight_list).transpose(1, 2).to(input_dtype) + w3 = torch.stack(down_weight_list).transpose(1, 2).to(input_dtype) + + permuted_tokens, row_ids_map = torch_npu.npu_moe_token_permute(hidden_states, selected_experts.to(torch.int32)) + tokens_per_expert = torch.histc(selected_experts, bins=self.num_experts, min=0, max=self.num_experts) + + up_res = NPUGmmFunction.apply(permuted_tokens, w1, tokens_per_expert) + gate_res = NPUGmmFunction.apply(permuted_tokens, w2, tokens_per_expert) + act_res = torch_npu.npu_swiglu(torch.cat([gate_res, up_res], dim=-1)) + down_res = NPUGmmFunction.apply(act_res, w3, tokens_per_expert) + + final_hidden_states = torch_npu.npu_moe_token_unpermute(down_res, row_ids_map, probs=routing_weights) + + return final_hidden_states, router_logits + + +class NPUQwen3VLMoeTextExperts(nn.Module): + """NPU optimized implementation for Qwen3VLMoeTextExperts.""" + + def __init__(self, config): + super().__init__() + self.num_experts = config.num_experts + self.intermediate_size = config.moe_intermediate_size + self.hidden_size = config.hidden_size + self.expert_dim = self.intermediate_size + self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_size, 2 * self.expert_dim)) + self.down_proj = nn.Parameter(torch.empty((self.num_experts, self.expert_dim, self.hidden_size))) + self.act_fn = ACT2FN[config.hidden_act] + + def forward( + self, hidden_states: torch.Tensor, routing_weights: torch.Tensor, router_indices: torch.Tensor + ) -> torch.Tensor: + """ + When training it is more efficient to just loop over the experts and compute the output for each expert + as otherwise the memory would explode. + + For inference we can sacrifice some memory and compute the output for all experts at once. + By repeating the inputs. + + Args: + hidden_states (torch.Tensor): (batch_size * token_num, hidden_size) + routing_weights (torch.Tensor): (batch_size * token_num, num_experts) + router_indices (torch.Tensor): (batch_size * token_num, top_k) + Returns: + torch.Tensor + """ + batch_size = hidden_states.shape[0] + hidden_states = hidden_states.reshape(-1, self.hidden_size) # (num_tokens, hidden_size) + if self.training: + permuted_hidden_states, row_ids_map = torch_npu.npu_moe_token_permute( + hidden_states, router_indices.to(torch.int32) + ) + tokens_per_expert = torch.histc(router_indices, bins=self.num_experts, min=0, max=self.num_experts) + intermediate_hidden_states = NPUGmmFunction.apply( + permuted_hidden_states, self.gate_up_proj, tokens_per_expert + ) + intermediate_activations = torch_npu.npu_swiglu(intermediate_hidden_states, dim=-1) + output = NPUGmmFunction.apply(intermediate_activations, self.down_proj, tokens_per_expert) + num_tokens = hidden_states.shape[0] + top_k = router_indices.shape[1] + batch_idx = torch.arange(num_tokens, device=routing_weights.device) + batch_idx = batch_idx.unsqueeze(1).expand(-1, top_k) + selected_probs = routing_weights[batch_idx, router_indices] + next_states = torch_npu.npu_moe_token_unpermute(output, row_ids_map, probs=selected_probs) + next_states = next_states.view(batch_size, -1, self.hidden_size) + else: + hidden_states = hidden_states.repeat(self.num_experts, 1) + hidden_states = hidden_states.view(self.num_experts, -1, self.hidden_size) + gate_up = torch.bmm(hidden_states, self.gate_up_proj) + gate, up = gate_up.chunk(2, dim=-1) # not supported for DTensors + next_states = torch.bmm((up * self.act_fn(gate)), self.down_proj) + next_states = next_states.reshape(self.num_experts, batch_size, -1, self.hidden_size) + next_states = ( + next_states * routing_weights.transpose(0, 1).view(self.num_experts, batch_size, -1)[..., None] + ) + next_states = next_states.sum(dim=0) + return next_states + + +class NPUQwen3VLMoeTextSparseMoeBlock(nn.Module): + """NPU optimized implementation for Qwen3VLMoeTextSparseMoeBlock.""" + + def __init__(self, config): + super().__init__() + self.hidden_size = config.hidden_size + self.num_experts = config.num_experts + self.top_k = config.num_experts_per_tok + self.gate = nn.Linear(config.hidden_size, config.num_experts, bias=False) + self.experts = NPUQwen3VLMoeTextExperts(config) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + batch_size = hidden_states.shape[0] + hidden_states = hidden_states.reshape(-1, self.hidden_size) + router_logits = self.gate(hidden_states) + routing_weights = torch.nn.functional.softmax(router_logits, dim=-1, dtype=torch.float) + routing_weights, router_indices = torch.topk(routing_weights, self.top_k, dim=-1) + routing_weights = routing_weights / routing_weights.sum(dim=-1, keepdim=True) + routing_weights = routing_weights.to(router_logits.dtype) + hidden_states = hidden_states.reshape(batch_size, -1, self.hidden_size) + if not self.training: + routing_weights = torch.zeros_like(router_logits).scatter_(1, router_indices, routing_weights) + routed_out = self.experts(hidden_states, routing_weights, router_indices) + return routed_out + + +# Patches for Qwen2 Model +modeling_qwen2.Qwen2RMSNorm.forward = rms_norm_forward_npu +modeling_qwen2.Qwen2MLP.forward = silu_forward_npu +modeling_qwen2.apply_rotary_pos_emb = apply_rotary_pos_emb_npu + +# Patches for Qwen2.5-VL Model +modeling_qwen2_5_vl.Qwen2RMSNorm.forward = rms_norm_forward_npu +modeling_qwen2_5_vl.Qwen2_5_VLMLP.forward = silu_forward_npu + +# Patches for Qwen3 Model +modeling_qwen3.Qwen3RMSNorm.forward = rms_norm_forward_npu +modeling_qwen3.Qwen3MLP.forward = silu_forward_npu +modeling_qwen3.apply_rotary_pos_emb = apply_rotary_pos_emb_npu + +# Patches for Qwen3 MoE Model +modeling_qwen3_moe.Qwen3MoeRMSNorm.forward = rms_norm_forward_npu +modeling_qwen3_moe.Qwen3MoeSparseMoeBlock.forward = qwen3_moe_sparse_moe_block_forward_npu +modeling_qwen3_moe.apply_rotary_pos_emb = apply_rotary_pos_emb_npu + +# Patches for Qwen3 VL Model +modeling_qwen3_vl.Qwen3VLTextRMSNorm.forward = rms_norm_forward_npu +modeling_qwen3_vl.Qwen3VLTextMLP.forward = silu_forward_npu + +# Patches for Qwen3-VL MoE Model +modeling_qwen3_vl_moe.Qwen3VLMoeTextSparseMoeBlock = NPUQwen3VLMoeTextSparseMoeBlock +modeling_qwen3_vl_moe.Qwen3VLMoeTextRMSNorm.forward = rms_norm_forward_npu +modeling_qwen3_vl_moe.apply_rotary_pos_emb = apply_rotary_pos_emb_npu diff --git a/code/RL_model/verl/verl_train/verl/models/transformers/qwen2.py b/code/RL_model/verl/verl_train/verl/models/transformers/qwen2.py new file mode 100644 index 0000000000000000000000000000000000000000..3bac76e9a142530e86a32c3ad4228e6964afc19a --- /dev/null +++ b/code/RL_model/verl/verl_train/verl/models/transformers/qwen2.py @@ -0,0 +1,243 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Callable, Optional + +import torch +from transformers.cache_utils import Cache +from transformers.modeling_flash_attention_utils import _flash_attention_forward +from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv +from transformers.utils import logging + +# Import compatibility wrapper for flash_attn_supports_top_left_mask +from verl.utils.transformers_compat import flash_attn_supports_top_left_mask +from verl.utils.ulysses import ( + gather_heads_scatter_seq, + gather_seq_scatter_heads, + get_ulysses_sequence_parallel_world_size, + validate_ulysses_config, +) + +logger = logging.get_logger(__name__) + + +def qwen2_flash_attn_forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 +): + """ + Adapted from transformers 4.47.1 to support Ulysses sequence parallelism. + + NOTE: This function is only tested on transformers versions between 4.45.0 and 4.47.1. + """ + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + + ########## AlltoAll for Ulysses ########## + ulysses_sp_size = get_ulysses_sequence_parallel_world_size() + + if ulysses_sp_size > 1: + validate_ulysses_config(self.num_heads, ulysses_sp_size) + + # (bsz, n_head, seq_len/n, head_dim) -> (bsz, n_head/n, seq_len, head_dim) + query_states = gather_seq_scatter_heads(query_states, seq_dim=2, head_dim=1) + key_states = gather_seq_scatter_heads(key_states, seq_dim=2, head_dim=1) + value_states = gather_seq_scatter_heads(value_states, seq_dim=2, head_dim=1) + + full_q_len = query_states.size(2) # full seq length + + if position_embeddings is None: + logger.warning_once( + "The attention layers in this model are transitioning from computing the RoPE embeddings internally " + "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " + "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be " + "removed and `position_embeddings` will be mandatory." + ) + cos, sin = self.rotary_emb(value_states, position_ids) + else: + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + # repeat k/v heads if n_kv_heads < n_heads + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + dropout_rate = 0.0 if not self.training else self.attention_dropout + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in float16 just to be sure everything works as expected. + input_dtype = query_states.dtype + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to " + f"the fact you have upcasted embedding or layer norm layers in float32. We will cast back the " + f"input in {target_dtype}." + ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + # Reashape to the expected shape for Flash Attention + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + if ( + self.config.use_sliding_window + and getattr(self.config, "sliding_window", None) is not None + and self.layer_idx >= self.config.max_window_layers + ): + sliding_window = self.config.sliding_window + else: + sliding_window = None + + attn_output = _flash_attention_forward( + query_states, + key_states, + value_states, + attention_mask, + full_q_len, + position_ids=position_ids, + dropout=dropout_rate, + sliding_window=sliding_window, + is_causal=self.is_causal, + use_top_left_mask=flash_attn_supports_top_left_mask(), + ) + + # use full_q_len to reshape + attn_output = attn_output.reshape(bsz, full_q_len, -1, self.head_dim).contiguous() + ########## AlltoAll for Ulysses ########## + if ulysses_sp_size > 1: + attn_output = gather_heads_scatter_seq(attn_output, seq_dim=1, head_dim=2) + attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +def qwen2_attn_forward( + self, + hidden_states: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], + past_key_value: Optional[Cache] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, +) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: + """ + Adapted from transformers 4.49.0 to support Ulysses sequence parallelism for transformers >= 4.48.0. + + NOTE: This function has been tested only on transformers versions between 4.48.0 and 4.50.0. + """ + from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS + + bsz, q_len, _ = hidden_states.shape + hidden_shape = (bsz, q_len, -1, self.head_dim) + + query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) + key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) + + ########## AlltoAll for Ulysses ########## + ulysses_sp_size = get_ulysses_sequence_parallel_world_size() + + if ulysses_sp_size > 1: + validate_ulysses_config(self.config.num_attention_heads, ulysses_sp_size) + + # (bsz, n_head, seq_len/n, head_dim) -> (bsz, n_head/n, seq_len, head_dim) + query_states = gather_seq_scatter_heads(query_states, seq_dim=2, head_dim=1) + key_states = gather_seq_scatter_heads(key_states, seq_dim=2, head_dim=1) + value_states = gather_seq_scatter_heads(value_states, seq_dim=2, head_dim=1) + + full_q_len = query_states.size(2) + + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + sliding_window = None + if ( + self.config.use_sliding_window + and getattr(self.config, "sliding_window", None) is not None + and self.layer_idx >= self.config.max_window_layers + ): + sliding_window = self.config.sliding_window + + from transformers.models.qwen2.modeling_qwen2 import eager_attention_forward + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): + logger.warning_once( + "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. " + "Falling back to eager attention. This warning can be removed using the argument " + '`attn_implementation="eager"` when loading the model.' + ) + else: + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + sliding_window=sliding_window, # main diff with Llama + **kwargs, + ) + + attn_output = attn_output.reshape(bsz, full_q_len, -1, self.head_dim).contiguous() + ########## AlltoAll for Ulysses ########## + if ulysses_sp_size > 1: + # (bsz, seq_len, n_head/n, head_dim) -> (bsz, seq_len/n, n_head, head_dim) + attn_output = gather_heads_scatter_seq(attn_output, seq_dim=1, head_dim=2) + attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights diff --git a/code/RL_model/verl/verl_train/verl/models/transformers/qwen2_vl.py b/code/RL_model/verl/verl_train/verl/models/transformers/qwen2_vl.py new file mode 100644 index 0000000000000000000000000000000000000000..5e82fdd4dd4bd3211350e46b05dfb38e7ed5ca30 --- /dev/null +++ b/code/RL_model/verl/verl_train/verl/models/transformers/qwen2_vl.py @@ -0,0 +1,548 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +import logging +import os +from dataclasses import dataclass +from typing import Optional + +import torch +import torch.distributed as dist +from transformers.modeling_flash_attention_utils import _flash_attention_forward, fa_peft_integration_check +from transformers.models.qwen2_vl.modeling_qwen2_vl import ( + Qwen2VLAttention, + Qwen2VLCausalLMOutputWithPast, + Qwen2VLForConditionalGeneration, +) +from transformers.utils import is_flash_attn_2_available, is_flash_attn_greater_or_equal_2_10 + +from verl.utils.device import is_npu_available +from verl.utils.transformers_compat import is_transformers_version_in_range +from verl.utils.ulysses import ( + gather_heads_scatter_seq, + gather_seq_scatter_heads, + get_ulysses_sequence_parallel_group, + get_ulysses_sequence_parallel_world_size, + validate_ulysses_config, +) + +logger = logging.getLogger(__file__) +logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN")) + + +if is_flash_attn_2_available(): + from flash_attn import flash_attn_func, flash_attn_varlen_func + + _flash_supports_window_size = "window_size" in inspect.signature(flash_attn_func).parameters + _flash_supports_deterministic = "deterministic" in inspect.signature(flash_attn_func).parameters + _flash_use_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + +if is_npu_available: + from transformers.integrations.npu_flash_attention import npu_flash_attn_func as flash_attn_func + from transformers.integrations.npu_flash_attention import npu_flash_attn_varlen_func as flash_attn_varlen_func + from transformers.modeling_flash_attention_utils import flash_attn_supports_top_left_mask + + _flash_supports_window_size = "window_size" in inspect.signature(flash_attn_func).parameters + _flash_supports_deterministic = "deterministic" in inspect.signature(flash_attn_func).parameters + _flash_use_top_left_mask = flash_attn_supports_top_left_mask() + +_flash_deterministic_enabled = os.getenv("FLASH_ATTENTION_DETERMINISTIC", "0") == "1" + + +def get_rope_index( + processor, + input_ids: torch.Tensor, + image_grid_thw: Optional[torch.Tensor] = None, + video_grid_thw: Optional[torch.Tensor] = None, + second_per_grid_ts: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, +) -> torch.Tensor: + """ + Gets the position ids for Qwen2-VL, it should be generated before sharding the sequence. + The batch dim has been removed and the input_ids should be a 1D tensor representing a single example. + https://github.com/huggingface/transformers/blob/v4.52.4/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py#L1405 + """ + spatial_merge_size = processor.image_processor.merge_size + tokens_per_second = 2 + image_token_id = processor.tokenizer.convert_tokens_to_ids("<|image_pad|>") + video_token_id = processor.tokenizer.convert_tokens_to_ids("<|video_pad|>") + vision_start_token_id = processor.tokenizer.convert_tokens_to_ids("<|vision_start|>") + if input_ids is not None and (image_grid_thw is not None or video_grid_thw is not None): + if attention_mask is None: + attention_mask = torch.ones_like(input_ids) + + position_ids = torch.ones(3, input_ids.size(0), dtype=input_ids.dtype, device=input_ids.device) # (3, seqlen) + image_index, video_index = 0, 0 + input_ids = input_ids[attention_mask == 1] + image_nums, video_nums = 0, 0 + vision_start_indices = torch.argwhere(input_ids == vision_start_token_id) + vision_tokens = input_ids[vision_start_indices + 1] + image_nums = (vision_tokens == image_token_id).sum() + video_nums = (vision_tokens == video_token_id).sum() + input_tokens = input_ids.tolist() + llm_pos_ids_list: list = [] + st = 0 + remain_images, remain_videos = image_nums, video_nums + for _ in range(image_nums + video_nums): + if image_token_id in input_tokens and remain_images > 0: + ed_image = input_tokens.index(image_token_id, st) + else: + ed_image = len(input_tokens) + 1 + if video_token_id in input_tokens and remain_videos > 0: + ed_video = input_tokens.index(video_token_id, st) + else: + ed_video = len(input_tokens) + 1 + if ed_image < ed_video: + t, h, w = ( + image_grid_thw[image_index][0], + image_grid_thw[image_index][1], + image_grid_thw[image_index][2], + ) + second_per_grid_t = 0 + image_index += 1 + remain_images -= 1 + ed = ed_image + else: + t, h, w = ( + video_grid_thw[video_index][0], + video_grid_thw[video_index][1], + video_grid_thw[video_index][2], + ) + second_per_grid_t = second_per_grid_ts[video_index] if second_per_grid_ts is not None else 1.0 + + video_index += 1 + remain_videos -= 1 + ed = ed_video + + llm_grid_t, llm_grid_h, llm_grid_w = ( + t.item(), + h.item() // spatial_merge_size, + w.item() // spatial_merge_size, + ) + text_len = ed - st + + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx) + + t_index = torch.arange(llm_grid_t).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w) + t_index = (t_index * second_per_grid_t * tokens_per_second).long().flatten() + h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(llm_grid_t, -1, llm_grid_w).flatten() + w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(llm_grid_t, llm_grid_h, -1).flatten() + llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + text_len + st_idx) + st = ed + llm_grid_t * llm_grid_h * llm_grid_w + + if st < len(input_tokens): + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + text_len = len(input_tokens) - st + llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx) + + llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1) + position_ids[..., attention_mask == 1] = llm_positions.to(position_ids.device) + else: + if attention_mask is not None: + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + position_ids = position_ids.unsqueeze(0).expand(3, -1).to(input_ids.device) + else: + position_ids = torch.arange(input_ids.shape[1], device=input_ids.device).view(1, -1).expand(3, -1) + + return position_ids + + +def prepare_fa2_from_position_ids( + query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, position_ids: torch.Tensor +): + assert position_ids.ndim == 2 # (batch_size, seq_length) + query = query.contiguous().view(-1, query.size(-2), query.size(-1)) + key = key.contiguous().view(-1, key.size(-2), key.size(-1)) + value = value.contiguous().view(-1, value.size(-2), value.size(-1)) + position_ids = position_ids.view(-1) + cu_seqlens = torch.cat( + ( + (position_ids == 0).nonzero().view(-1).to(torch.int32), + torch.tensor(position_ids.size(), device=position_ids.device, dtype=torch.int32), + ) + ) + max_length = cu_seqlens.diff().max() # use cu_seqlens to infer max_length for qwen2vl mrope + return (query, key, value, (cu_seqlens, cu_seqlens), (max_length, max_length)) + + +def _custom_flash_attention_forward( + query_states: torch.Tensor, + key_states: torch.Tensor, + value_states: torch.Tensor, + attention_mask: Optional[torch.Tensor], + query_length: int, + is_causal: bool = True, + position_ids: Optional[torch.Tensor] = None, + sliding_window: Optional[int] = None, + use_top_left_mask: bool = False, + deterministic: Optional[bool] = None, + **kwargs, +): + """ + Patches flash attention forward to handle 3D position ids in mrope. (3, batch_size, seq_length) + """ + # Assuming 4D tensors, key_states.shape[1] is the key/value sequence length (source length). + use_sliding_windows = ( + _flash_supports_window_size and sliding_window is not None and key_states.shape[1] > sliding_window + ) + flash_kwargs = {"window_size": (sliding_window, sliding_window)} if use_sliding_windows else {} + + if _flash_supports_deterministic: + flash_kwargs["deterministic"] = deterministic if deterministic is not None else _flash_deterministic_enabled + + if kwargs.get("softcap") is not None: + flash_kwargs["softcap"] = kwargs.pop("softcap") + + query_states, key_states, value_states = fa_peft_integration_check( + query_states, key_states, value_states, target_dtype=torch.bfloat16 + ) + + if position_ids is not None: + assert position_ids.ndim == 2 # (batch_size, seq_length / sp_size) + + sp_size = get_ulysses_sequence_parallel_world_size() + if sp_size > 1: + # qkv: (batch_size, seq_length / sp_size, num_head, head_size) + validate_ulysses_config(query_states.size(2), sp_size) + query_states = gather_seq_scatter_heads(query_states, seq_dim=1, head_dim=2) + key_states = gather_seq_scatter_heads(key_states, seq_dim=1, head_dim=2) + value_states = gather_seq_scatter_heads(value_states, seq_dim=1, head_dim=2) + position_ids_lst = [torch.empty_like(position_ids) for _ in range(sp_size)] + position_ids = dist.all_gather(position_ids_lst, position_ids, group=get_ulysses_sequence_parallel_group()) + position_ids = torch.cat(position_ids_lst, dim=-1) # (batch_size, seq_length) + + if position_ids is not None and query_length != 1 and not (torch.diff(position_ids, dim=-1) >= 0).all(): + batch_size = query_states.size(0) + q, k, v, (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = prepare_fa2_from_position_ids( + query_states, key_states, value_states, position_ids + ) + attn_output = flash_attn_varlen_func( + q=q, + k=k, + v=v, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_q, + max_seqlen_k=max_seqlen_k, + dropout_p=kwargs.pop("dropout", 0.0), + softmax_scale=kwargs.pop("softmax_scale", None), + causal=is_causal, + **flash_kwargs, + ) + attn_output = attn_output.view(batch_size, -1, attn_output.size(-2), attn_output.size(-1)) + else: + attn_output = _flash_attention_forward( + query_states, + key_states, + value_states, + attention_mask, + query_length, + is_causal=is_causal, + sliding_window=sliding_window, + use_top_left_mask=use_top_left_mask, + deterministic=deterministic, + **kwargs, + ) # do not pass position_ids to old flash_attention_forward + + if sp_size > 1: + # (batch_size, seq_length, num_head, head_size) + attn_output = gather_heads_scatter_seq(attn_output, head_dim=2, seq_dim=1) + + return attn_output + + +def qwen2_vl_attn_forward( + self: "Qwen2VLAttention", + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 + **kwargs, +) -> tuple[torch.Tensor, None, None]: + from transformers.models.qwen2_vl.modeling_qwen2_vl import apply_multimodal_rotary_pos_emb, repeat_kv + + bsz, q_len, _ = hidden_states.size() # q_len = seq_length / sp_size + query_states = self.q_proj(hidden_states) # (batch_size, seq_length / sp_size, num_heads * head_size) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + # Because the input can be padded, the absolute sequence length depends on the max position id. + cos, sin = position_embeddings + query_states, key_states = apply_multimodal_rotary_pos_emb( + query_states, key_states, cos, sin, self.rope_scaling["mrope_section"] + ) + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + dropout_rate = 0.0 if not self.training else self.attention_dropout + + sliding_window = None + if ( + self.config.use_sliding_window + and getattr(self.config, "sliding_window", None) is not None + and self.layer_idx >= self.config.max_window_layers + ): + sliding_window = self.config.sliding_window + + # This is before the transpose + q_len = query_states.shape[2] + + # FA2 uses non-transposed inputs + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + if position_ids.ndim == 3: + position_ids = position_ids[0] + + attn_output = _custom_flash_attention_forward( + query_states, + key_states, + value_states, + attention_mask, + query_length=q_len, + is_causal=getattr(self, "is_causal", True), + dropout=dropout_rate, + sliding_window=sliding_window, + use_top_left_mask=_flash_use_top_left_mask, + position_ids=position_ids, # important: pass position ids + ) # (batch_size, seq_length / sp_size, num_head, head_size) + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous() + attn_output = self.o_proj(attn_output) + if is_transformers_version_in_range(min_version="4.54.0"): + return attn_output, None + else: + return attn_output, None, None + + +def _get_input_embeds( + model: "Qwen2VLForConditionalGeneration", + input_ids: torch.LongTensor, + attention_mask: Optional[torch.Tensor] = None, + pixel_values: Optional[torch.FloatTensor] = None, + pixel_values_videos: Optional[torch.FloatTensor] = None, + image_grid_thw: Optional[torch.LongTensor] = None, + video_grid_thw: Optional[torch.LongTensor] = None, +): + inputs_embeds = model.get_input_embeddings()(input_ids) + if pixel_values is not None: + pixel_values = pixel_values.type(model.visual.dtype) + image_embeds = model.visual(pixel_values, grid_thw=image_grid_thw) + n_image_tokens = (input_ids == model.config.image_token_id).sum().item() + n_image_features = image_embeds.shape[0] + if n_image_tokens != n_image_features: + raise ValueError( + f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" + ) + + mask = input_ids == model.config.image_token_id + mask_unsqueezed = mask.unsqueeze(-1) + mask_expanded = mask_unsqueezed.expand_as(inputs_embeds) + image_mask = mask_expanded.to(inputs_embeds.device) + + image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype) + inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds) + + if pixel_values_videos is not None: + pixel_values_videos = pixel_values_videos.type(model.visual.dtype) + video_embeds = model.visual(pixel_values_videos, grid_thw=video_grid_thw) + n_video_tokens = (input_ids == model.config.video_token_id).sum().item() + n_video_features = video_embeds.shape[0] + if n_video_tokens != n_video_features: + raise ValueError( + f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}" + ) + + mask = input_ids == model.config.video_token_id + mask_unsqueezed = mask.unsqueeze(-1) + mask_expanded = mask_unsqueezed.expand_as(inputs_embeds) + video_mask = mask_expanded.to(inputs_embeds.device) + + video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype) + inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds) + + if pixel_values is None and pixel_values_videos is None: # handle mixed text-image data + config = model.config.vision_config + patch_dim = config.in_channels * config.temporal_patch_size * config.patch_size**2 + pixel_values = torch.zeros((16, patch_dim), dtype=inputs_embeds.dtype, device=inputs_embeds.device) + image_grid_thw = torch.tensor([[1, 4, 4]], dtype=torch.long, device=inputs_embeds.device) + image_embeds = model.visual(pixel_values, grid_thw=image_grid_thw) + inputs_embeds += 0.0 * image_embeds.mean() + + if attention_mask is not None: + attention_mask = attention_mask.to(inputs_embeds.device) + + return inputs_embeds, attention_mask + + +def process_position_ids(position_ids: torch.Tensor) -> torch.Tensor: + if position_ids.ndim != 3 or position_ids.size(0) != 4: + # we concat the text position ids with the 3D vision position ids by default + # see https://github.com/huggingface/transformers/pull/39447 + raise ValueError("position_ids should be a 3D tensor of shape (4, batch_size, seq_length).") + + if is_transformers_version_in_range(max_version="4.53.3"): + # transformers < 4.54.0 only accepts vision position ids, so we discard the text position ids here + position_ids = position_ids[1:] + + return position_ids + + +@dataclass +class Qwen2VLCausalLMOutputForPPO(Qwen2VLCausalLMOutputWithPast): + log_probs: Optional[torch.FloatTensor] = None + entropy: Optional[torch.FloatTensor] = None + + +def qwen2_vl_base_forward( + self: "Qwen2VLForConditionalGeneration", + input_ids: torch.LongTensor, + attention_mask: Optional[torch.Tensor] = None, + labels: Optional[torch.LongTensor] = None, + pixel_values: Optional[torch.FloatTensor] = None, + pixel_values_videos: Optional[torch.FloatTensor] = None, + image_grid_thw: Optional[torch.LongTensor] = None, + video_grid_thw: Optional[torch.LongTensor] = None, + **kwargs, +): + kwargs["inputs_embeds"], kwargs["attention_mask"] = _get_input_embeds( + self, input_ids, attention_mask, pixel_values, pixel_values_videos, image_grid_thw, video_grid_thw + ) # avoid lora module having multiple keyword arguments + return self.language_model(input_ids=None, **kwargs) + + +def qwen2_vl_forward( + self: "Qwen2VLForConditionalGeneration", + input_ids: torch.LongTensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + pixel_values: Optional[torch.FloatTensor] = None, + pixel_values_videos: Optional[torch.FloatTensor] = None, + image_grid_thw: Optional[torch.LongTensor] = None, + video_grid_thw: Optional[torch.LongTensor] = None, + **kwargs, +): + if is_transformers_version_in_range(min_version="4.52.0"): + return self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=process_position_ids(position_ids), + pixel_values=pixel_values, + pixel_values_videos=pixel_values_videos, + image_grid_thw=image_grid_thw, + video_grid_thw=video_grid_thw, + **kwargs, + ) + else: + inputs_embeds, attention_mask = _get_input_embeds( + self, input_ids, attention_mask, pixel_values, pixel_values_videos, image_grid_thw, video_grid_thw + ) + return self.model( + input_ids=None, + attention_mask=attention_mask, + position_ids=process_position_ids(position_ids), + inputs_embeds=inputs_embeds, + **kwargs, + ) + + +def forward_with_normal_backend( + self: Qwen2VLForConditionalGeneration, + input_ids: torch.LongTensor = None, + labels: Optional[torch.LongTensor] = None, + temperature: float = 1.0, + **kwargs, +) -> "Qwen2VLCausalLMOutputWithPast": + outputs = qwen2_vl_forward(self, input_ids, **kwargs) + hidden_states = outputs[0] + logits = self.lm_head(hidden_states) + + return Qwen2VLCausalLMOutputWithPast( + logits=logits, + hidden_states=outputs.hidden_states, + ) + + +def forward_with_torch_backend( + self: Qwen2VLForConditionalGeneration, + input_ids: torch.LongTensor = None, + labels: Optional[torch.LongTensor] = None, + temperature: float = 1.0, + **kwargs, +) -> tuple | Qwen2VLCausalLMOutputForPPO: + from verl.utils.experimental.torch_functional import FusedLinearForPPO + + outputs = qwen2_vl_forward(self, input_ids, **kwargs) + hidden_states = outputs[0] + + # Loss calculations + if labels is not None: + rolled_labels = torch.roll(labels, shifts=-1, dims=-1) + elif input_ids is not None: + rolled_labels = torch.roll(input_ids, shifts=-1, dims=-1) + else: + raise RuntimeError("To use forward_with_torch_backend, either labels or input_ids must be provided.") + + fused_linear_for_ppo = FusedLinearForPPO() + log_probs, entropy = fused_linear_for_ppo.forward( + hidden_states=hidden_states, + vocab_weights=self.lm_head.weight, + input_ids=rolled_labels, + temperature=temperature, + ) + return Qwen2VLCausalLMOutputForPPO( + log_probs=log_probs, + entropy=entropy, + hidden_states=outputs.hidden_states, + ) + + +def forward_with_triton_backend( + self: Qwen2VLForConditionalGeneration, + input_ids: torch.LongTensor = None, + labels: Optional[torch.LongTensor] = None, + temperature: float = 1.0, + **kwargs, +) -> tuple | Qwen2VLCausalLMOutputForPPO: + from verl.utils.kernel.linear_cross_entropy import linear_cross_entropy + + outputs = qwen2_vl_forward(self, input_ids, **kwargs) + hidden_states = outputs[0] + + # Loss calculations + if labels is not None: + rolled_labels = torch.roll(labels, shifts=-1, dims=-1) + elif input_ids is not None: + rolled_labels = torch.roll(input_ids, shifts=-1, dims=-1) + else: + raise RuntimeError("To use forward_with_triton_backend, either labels or input_ids must be provided.") + + log_probs, entropy = linear_cross_entropy( + hidden_states, + self.lm_head.weight, + rolled_labels, + temperature, + "none", + ) + return Qwen2VLCausalLMOutputForPPO( + log_probs=log_probs, + entropy=entropy, + hidden_states=outputs.hidden_states, + ) diff --git a/code/RL_model/verl/verl_train/verl/models/transformers/qwen3_vl.py b/code/RL_model/verl/verl_train/verl/models/transformers/qwen3_vl.py new file mode 100644 index 0000000000000000000000000000000000000000..972848a1a083b1c01525806a088acfd3229e6e83 --- /dev/null +++ b/code/RL_model/verl/verl_train/verl/models/transformers/qwen3_vl.py @@ -0,0 +1,375 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import functools +import logging +import os +from dataclasses import dataclass +from typing import Optional + +import torch +from transformers.models.qwen3_vl.modeling_qwen3_vl import ( + Qwen3VLCausalLMOutputWithPast, + Qwen3VLForConditionalGeneration, +) + +logger = logging.getLogger(__file__) +logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN")) + + +def get_rope_index( + processor, + input_ids: torch.Tensor, + image_grid_thw: Optional[torch.Tensor] = None, + video_grid_thw: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + **kwargs, +) -> torch.Tensor: + """ + Gets the position ids for Qwen3-VL, it should be generated before sharding the sequence. + The batch dim has been removed and the input_ids should be a 1D tensor representing a single example. + https://github.com/huggingface/transformers/blob/v4.57.0/src/transformers/models/qwen3_vl/modeling_qwen3_vl.py#L916 + """ + spatial_merge_size = processor.image_processor.merge_size + image_token_id = processor.image_token_id + video_token_id = processor.video_token_id + vision_start_token_id = processor.vision_start_token_id + + # Since we use timestamps to separate videos, + # like , + # the video_grid_thw should also be split + if video_grid_thw is not None: + video_grid_thw = torch.repeat_interleave(video_grid_thw, video_grid_thw[:, 0], dim=0) + video_grid_thw[:, 0] = 1 + + if input_ids is not None and (image_grid_thw is not None or video_grid_thw is not None): + if attention_mask is None: + attention_mask = torch.ones_like(input_ids) + + position_ids = torch.ones(3, input_ids.shape[0], dtype=input_ids.dtype, device=input_ids.device) + image_index, video_index = 0, 0 + attention_mask = attention_mask.to(input_ids.device) + input_ids = input_ids[attention_mask == 1] + image_nums, video_nums = 0, 0 + vision_start_indices = torch.argwhere(input_ids == vision_start_token_id) + vision_tokens = input_ids[vision_start_indices + 1] + image_nums = (vision_tokens == image_token_id).sum() + video_nums = (vision_tokens == video_token_id).sum() + input_tokens = input_ids.tolist() + llm_pos_ids_list: list = [] + st = 0 + remain_images, remain_videos = image_nums, video_nums + for _ in range(image_nums + video_nums): + if image_token_id in input_tokens and remain_images > 0: + ed_image = input_tokens.index(image_token_id, st) + else: + ed_image = len(input_tokens) + 1 + if video_token_id in input_tokens and remain_videos > 0: + ed_video = input_tokens.index(video_token_id, st) + else: + ed_video = len(input_tokens) + 1 + if ed_image < ed_video: + t, h, w = ( + image_grid_thw[image_index][0], + image_grid_thw[image_index][1], + image_grid_thw[image_index][2], + ) + image_index += 1 + remain_images -= 1 + ed = ed_image + else: + t, h, w = ( + video_grid_thw[video_index][0], + video_grid_thw[video_index][1], + video_grid_thw[video_index][2], + ) + video_index += 1 + remain_videos -= 1 + ed = ed_video + + llm_grid_t, llm_grid_h, llm_grid_w = ( + t.item(), + h.item() // spatial_merge_size, + w.item() // spatial_merge_size, + ) + text_len = ed - st + + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx) + + # t_index is always 0 because llm_grid_t is always 1 + # (we use timestamps to encode the temporal information for videos) + t_index = torch.arange(llm_grid_t).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w).flatten() + h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(llm_grid_t, -1, llm_grid_w).flatten() + w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(llm_grid_t, llm_grid_h, -1).flatten() + llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + text_len + st_idx) + st = ed + llm_grid_t * llm_grid_h * llm_grid_w + + if st < len(input_tokens): + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + text_len = len(input_tokens) - st + llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx) + + llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1) + position_ids[..., attention_mask == 1] = llm_positions.to(position_ids.device) + else: + if attention_mask is not None: + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + position_ids = position_ids.unsqueeze(0).expand(3, -1).to(attention_mask.device) + else: + position_ids = torch.arange(input_ids.shape[1], device=input_ids.device).view(1, -1).expand(3, -1) + + return position_ids + + +def _get_input_embeds( + model: "Qwen3VLForConditionalGeneration", + input_ids: torch.LongTensor, + attention_mask: Optional[torch.Tensor] = None, + pixel_values: Optional[torch.FloatTensor] = None, + pixel_values_videos: Optional[torch.FloatTensor] = None, + image_grid_thw: Optional[torch.LongTensor] = None, + video_grid_thw: Optional[torch.LongTensor] = None, +): + inputs_embeds = model.get_input_embeddings()(input_ids) + image_mask, video_mask = None, None + if pixel_values is not None: + pixel_values = pixel_values.type(model.visual.dtype) + image_embeds, deepstack_image_embeds = model.visual(pixel_values, grid_thw=image_grid_thw) + n_image_tokens = (input_ids == model.config.image_token_id).sum().item() + n_image_features = image_embeds.shape[0] + if n_image_tokens != n_image_features: + raise ValueError( + f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" + ) + + mask = input_ids == model.config.image_token_id + mask_unsqueezed = mask.unsqueeze(-1) + mask_expanded = mask_unsqueezed.expand_as(inputs_embeds) + image_mask = mask_expanded.to(inputs_embeds.device) + + image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype) + inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds) + + if pixel_values_videos is not None: + pixel_values_videos = pixel_values_videos.type(model.visual.dtype) + video_embeds, deepstack_video_embeds = model.visual(pixel_values_videos, grid_thw=video_grid_thw) + n_video_tokens = (input_ids == model.config.video_token_id).sum().item() + n_video_features = video_embeds.shape[0] + if n_video_tokens != n_video_features: + raise ValueError( + f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}" + ) + + mask = input_ids == model.config.video_token_id + mask_unsqueezed = mask.unsqueeze(-1) + mask_expanded = mask_unsqueezed.expand_as(inputs_embeds) + video_mask = mask_expanded.to(inputs_embeds.device) + + video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype) + inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds) + + visual_pos_masks = None + deepstack_visual_embeds = None + if image_mask is not None and video_mask is not None: + # aggregate visual_pos_masks and deepstack_visual_embeds + image_mask = image_mask[..., 0] + video_mask = video_mask[..., 0] + visual_pos_masks = image_mask | video_mask + deepstack_visual_embeds = [] + image_mask_joint = image_mask[visual_pos_masks] + video_mask_joint = video_mask[visual_pos_masks] + for img_embed, vid_embed in zip(deepstack_image_embeds, deepstack_video_embeds, strict=False): + embed_joint = img_embed.new_zeros(visual_pos_masks.sum(), img_embed.shape[-1]).to(img_embed.device) + embed_joint[image_mask_joint, :] = img_embed + embed_joint[video_mask_joint, :] = vid_embed + deepstack_visual_embeds.append(embed_joint) + elif image_mask is not None: + image_mask = image_mask[..., 0] + visual_pos_masks = image_mask + deepstack_visual_embeds = deepstack_image_embeds + elif video_mask is not None: + video_mask = video_mask[..., 0] + visual_pos_masks = video_mask + deepstack_visual_embeds = deepstack_video_embeds + + if pixel_values is None and pixel_values_videos is None: + config = model.config.vision_config + patch_dim = config.in_channels * config.temporal_patch_size * config.patch_size**2 + pixel_values = torch.zeros((16, patch_dim), dtype=inputs_embeds.dtype, device=inputs_embeds.device) + image_grid_thw = torch.tensor([[1, 4, 4]], dtype=torch.long, device=inputs_embeds.device) + image_embeds, dummy_deepstack_image_embeds = model.visual(pixel_values, grid_thw=image_grid_thw) + inputs_embeds += 0.0 * image_embeds.mean() + for emb in dummy_deepstack_image_embeds or []: + inputs_embeds += 0.0 * emb.mean() + + if attention_mask is not None: + attention_mask = attention_mask.to(inputs_embeds.device) + + return { + "inputs_embeds": inputs_embeds, + "attention_mask": attention_mask, + "visual_pos_masks": visual_pos_masks, + "deepstack_visual_embeds": deepstack_visual_embeds, + } + + +@dataclass +class Qwen3VLCausalLMOutputForPPO(Qwen3VLCausalLMOutputWithPast): + log_probs: Optional[torch.FloatTensor] = None + entropy: Optional[torch.FloatTensor] = None + + +def qwen3_vl_base_forward( + self: "Qwen3VLForConditionalGeneration", + input_ids: torch.LongTensor, + attention_mask: Optional[torch.Tensor] = None, + pixel_values: Optional[torch.FloatTensor] = None, + pixel_values_videos: Optional[torch.FloatTensor] = None, + image_grid_thw: Optional[torch.LongTensor] = None, + video_grid_thw: Optional[torch.LongTensor] = None, + **kwargs, +): + input_kwargs = _get_input_embeds( + self, input_ids, attention_mask, pixel_values, pixel_values_videos, image_grid_thw, video_grid_thw + ) # avoid lora module having multiple keyword arguments + kwargs.update(input_kwargs) + return self.language_model( + input_ids=None, + **kwargs, + ) + + +def forward_with_normal_backend( + self: "Qwen3VLForConditionalGeneration", + input_ids: torch.LongTensor = None, + labels: Optional[torch.LongTensor] = None, + temperature: float = 1.0, + **kwargs, +) -> "Qwen3VLCausalLMOutputForPPO": + outputs = self.model(input_ids, **kwargs) + hidden_states = outputs[0] + logits = self.lm_head(hidden_states) + + return Qwen3VLCausalLMOutputForPPO( + logits=logits, + hidden_states=outputs.hidden_states, + ) + + +def forward_with_torch_backend( + self: "Qwen3VLForConditionalGeneration", + input_ids: torch.LongTensor = None, + labels: Optional[torch.LongTensor] = None, + temperature: float = 1.0, + **kwargs, +) -> "Qwen3VLCausalLMOutputForPPO": + from verl.utils.experimental.torch_functional import FusedLinearForPPO + + outputs = self.model(input_ids, **kwargs) + hidden_states = outputs[0] + + # Loss calculations + if labels is not None: + rolled_labels = torch.roll(labels, shifts=-1, dims=-1) + elif input_ids is not None: + rolled_labels = torch.roll(input_ids, shifts=-1, dims=-1) + else: + raise RuntimeError("To use forward_with_torch_backend, either labels or input_ids must be provided.") + + fused_linear_for_ppo = FusedLinearForPPO() + log_probs, entropy = fused_linear_for_ppo.forward( + hidden_states=hidden_states, + vocab_weights=self.lm_head.weight, + input_ids=rolled_labels, + temperature=temperature, + ) + return Qwen3VLCausalLMOutputForPPO( + log_probs=log_probs, + entropy=entropy, + hidden_states=outputs.hidden_states, + ) + + +def forward_with_triton_backend( + self: "Qwen3VLForConditionalGeneration", + input_ids: torch.LongTensor = None, + labels: Optional[torch.LongTensor] = None, + temperature: float = 1.0, + **kwargs, +) -> "Qwen3VLCausalLMOutputForPPO": + from verl.utils.kernel.linear_cross_entropy import linear_cross_entropy + + outputs = self.model(input_ids, **kwargs) + hidden_states = outputs[0] + + # Loss calculations + if labels is not None: + rolled_labels = torch.roll(labels, shifts=-1, dims=-1) + elif input_ids is not None: + rolled_labels = torch.roll(input_ids, shifts=-1, dims=-1) + else: + raise RuntimeError("To use forward_with_triton_backend, either labels or input_ids must be provided.") + + log_probs, entropy = linear_cross_entropy( + hidden_states, + self.lm_head.weight, + rolled_labels, + temperature, + "none", + ) + return Qwen3VLCausalLMOutputForPPO( + log_probs=log_probs, + entropy=entropy, + hidden_states=outputs.hidden_states, + ) + + +def patch_qwen3_vl_moe_sparse_moe_block_forward(): + """ + Monkey patch to fix a bug in transformers 4.57.3 where Qwen3VLMoeTextSparseMoeBlock.forward + incorrectly uses torch.zeros_like(hidden_states) instead of torch.zeros_like(router_logits) + when creating router_weights (line 148 in modeling_qwen3_vl_moe.py). + + This is a minimal fix that only changes the problematic line while keeping the rest of the + original implementation intact. + """ + try: + from transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import Qwen3VLMoeTextSparseMoeBlock + except ImportError: + # Model not available, skip patching + return + + # Store the original forward method for reference + original_forward = Qwen3VLMoeTextSparseMoeBlock.forward + + @functools.wraps(original_forward) + def patched_forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + batch_size = hidden_states.shape[0] + hidden_states = hidden_states.reshape(-1, self.hidden_size) + router_logits = self.gate(hidden_states) + routing_weights = torch.nn.functional.softmax(router_logits, dim=-1, dtype=torch.float) + routing_weights, router_indices = torch.topk(routing_weights, self.top_k, dim=-1) + routing_weights = routing_weights / routing_weights.sum(dim=-1, keepdim=True) + # BUG FIX: Original code incorrectly uses hidden_states here, should use router_logits + routing_weights = routing_weights.to(router_logits.dtype) + router_weights = torch.zeros_like(router_logits).scatter_(1, router_indices, routing_weights) + hidden_states = hidden_states.reshape(batch_size, -1, self.hidden_size) + routed_out = self.experts(hidden_states, router_weights, router_indices) + return routed_out + + # Apply the patch + Qwen3VLMoeTextSparseMoeBlock.forward = patched_forward + logger.info("Monkey patched Qwen3VLMoeTextSparseMoeBlock.forward to fix router_weights bug") diff --git a/code/RL_model/verl/verl_train/verl/models/transformers/tiled_mlp.py b/code/RL_model/verl/verl_train/verl/models/transformers/tiled_mlp.py new file mode 100644 index 0000000000000000000000000000000000000000..b43fa6f4ab259888e02833f49d4b7fb7e1eba49f --- /dev/null +++ b/code/RL_model/verl/verl_train/verl/models/transformers/tiled_mlp.py @@ -0,0 +1,236 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +FSDP2-compatible TiledMLP implementation for memory-efficient MLP computation. + +This module provides a tiled MLP implementation that reduces peak memory usage +by processing the MLP forward/backward pass in chunks (tiles). This is particularly +useful for large models with FSDP2 training. +""" + +import threading +from typing import Optional + +import torch +import torch.nn as nn + + +class GradientAccumulator: + """Gradient accumulator for TiledMLP (FSDP compatible). + + This class manages gradient accumulation across multiple shards during + the backward pass of TiledMLP. It ensures correct gradient computation + when processing input in chunks. + """ + + def __init__(self, params: list[torch.nn.Parameter], total_shards: int, dtype: torch.dtype = None): + self.params = params + self.total_shards = total_shards + self.grad_accumulation_dtype = dtype or torch.float32 + self.accumulated_grads = {} + self.hooks = [] + self.lock = threading.Lock() + + for param in self.params: + if param.grad is not None: + self.accumulated_grads[param] = param.grad.to(self.grad_accumulation_dtype) + param.grad = None + else: + self.accumulated_grads[param] = torch.zeros_like(param, dtype=self.grad_accumulation_dtype) + + def install_hooks(self, is_last_shard: bool): + """Install gradient hooks for the current shard.""" + self._remove_hooks() + + def create_hook(param): + def hook(grad): + with self.lock: + grad_to_accum_dtype = grad.to(self.grad_accumulation_dtype) + self.accumulated_grads[param] += grad_to_accum_dtype + + if is_last_shard: + param.grad = None # Critical: prevent double accumulation + final_grad = self.accumulated_grads[param].to(param.dtype) + return final_grad + return None + + return hook + + for param in self.params: + if param.requires_grad: + hook = param.register_hook(create_hook(param)) + self.hooks.append(hook) + + def _remove_hooks(self): + """Remove all registered hooks.""" + for hook in self.hooks: + hook.remove() + self.hooks.clear() + + def cleanup(self): + """Cleanup hooks and resources.""" + self._remove_hooks() + + +class TiledMLP(torch.autograd.Function): + """TiledMLP implementation for memory-efficient MLP computation. + + This autograd function processes MLP forward/backward in tiles (chunks) + to reduce peak memory usage. Compatible with FSDP2. + """ + + @staticmethod + def forward(ctx, fn, module, x, shards, compute_params): + ctx.fn = fn + ctx.module = module + ctx.shards = shards + ctx.compute_params = [p for p in compute_params if p.requires_grad] + ctx.save_for_backward(x) + + # Split on dim=-2 (seqlen dimension) following Liger Kernel style + x_shards = list(torch.chunk(x, chunks=shards, dim=-2)) + with torch.no_grad(): + output_shards = [fn(module, x_shard) for x_shard in x_shards] + output_unsharded = torch.cat(output_shards, dim=-2) + return output_unsharded + + @staticmethod + def backward(ctx, *grads): + fn = ctx.fn + (x,) = ctx.saved_tensors + module = ctx.module + shards = ctx.shards + compute_params = ctx.compute_params + + x_requires_grad = x.requires_grad + x = x.detach() + x.requires_grad_(x_requires_grad) + + # Flatten to [bs*seqlen, hidden_size] + hidden_size = x.shape[-1] + x_shape_orig = x.shape + x = x.view(-1, hidden_size) + incoming_grad = grads[0].view(-1, hidden_size) + + # Pre-allocate input gradient + x_grad = torch.zeros_like(x) + + # Split on dim=0 + x_shards = list(torch.chunk(x, chunks=shards, dim=0)) + + grad_accumulator = GradientAccumulator(compute_params, shards, dtype=x.dtype) + + for i, x_shard in enumerate(x_shards): + x_shard.requires_grad_(x_requires_grad) + + shard_step = x_shards[i].shape[0] + shard_offset = i * x_shards[0].shape[0] + + # narrow(0, ...) creates a contiguous view that can receive gradients + x_shard.grad = x_grad.narrow(0, shard_offset, shard_step) + incoming_grad_shard = incoming_grad.narrow(0, shard_offset, shard_step) + + is_last_shard = i + 1 == shards + grad_accumulator.install_hooks(is_last_shard) + + with torch.enable_grad(): + output = fn(module, x_shard) + torch.autograd.backward(output, incoming_grad_shard) + + grad_accumulator.cleanup() + del grad_accumulator + + # Restore original shape + x_grad = x_grad.view(x_shape_orig) if x_requires_grad else None + return (None, None, x_grad, None, None) + + +def _mlp_forward_fn(module, x): + """Forward function for LlamaMLP / Qwen2MLP / Qwen3MLP style.""" + return module.down_proj(module.act_fn(module.gate_proj(x)) * module.up_proj(x)) + + +# ============================================================================ +# Monkey Patch Functions +# ============================================================================ + +# Model type to MLP class mapping +_MODEL_TYPE_TO_MLP_CLASS = { + "llama": ("transformers.models.llama.modeling_llama", "LlamaMLP"), + "qwen2": ("transformers.models.qwen2.modeling_qwen2", "Qwen2MLP"), + "qwen2_5": ("transformers.models.qwen2.modeling_qwen2", "Qwen2MLP"), # Qwen2.5 uses Qwen2 MLP + "qwen3": ("transformers.models.qwen3.modeling_qwen3", "Qwen3MLP"), +} + + +def apply_tiled_mlp_monkey_patch( + num_shards: int = 4, + model_type: Optional[str] = None, +): + """Apply TiledMLP monkey patch based on model_type. + + This function MUST be called BEFORE model instantiation to take effect. + It patches the MLP classes in transformers library to use TiledMLP for + memory-efficient computation during training. + + Args: + num_shards: Number of shards to split the input into. Higher values + reduce peak memory but may slightly impact performance. + model_type: The model type string (e.g., "llama", "qwen2", "qwen3"). + If None, patches all supported model types. + + Returns: + List of patched class names. + """ + if model_type is None: + types_to_patch = list(_MODEL_TYPE_TO_MLP_CLASS.keys()) + elif model_type in _MODEL_TYPE_TO_MLP_CLASS: + types_to_patch = [model_type] + else: + raise ValueError( + f"TiledMLP does not support model_type='{model_type}'. " + f"Supported types: {list(_MODEL_TYPE_TO_MLP_CLASS.keys())}. " + f"For SwiGLU-style MLPs, you can add support by extending _MODEL_TYPE_TO_MLP_CLASS " + f"in verl/models/transformers/tiled_mlp.py" + ) + + patched_classes = [] + + for mtype in types_to_patch: + module_path, class_name = _MODEL_TYPE_TO_MLP_CLASS[mtype] + try: + import importlib + + module = importlib.import_module(module_path) + mlp_class = getattr(module, class_name) + _patch_mlp_class(mlp_class, _mlp_forward_fn, num_shards) + if class_name not in patched_classes: + patched_classes.append(class_name) + except (ImportError, AttributeError) as e: + print(f"Warning: Could not patch {mtype} MLP: {e}") + + if patched_classes: + print(f"TiledMLP monkey patch applied to: {', '.join(patched_classes)} (shards={num_shards})") + + return patched_classes + + +def _patch_mlp_class(mlp_class: type[nn.Module], forward_fn, num_shards: int): + """Patch a single MLP class to use TiledMLP.""" + + def tiled_forward(self, x): + compute_params = [p for p in self.parameters() if p.requires_grad] + return TiledMLP.apply(forward_fn, self, x, num_shards, compute_params) + + mlp_class.forward = tiled_forward diff --git a/code/RL_model/verl/verl_train/verl/third_party/torch/__init__.py b/code/RL_model/verl/verl_train/verl/third_party/torch/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..7664279b7411a806f615b52b2405fd2c40672517 --- /dev/null +++ b/code/RL_model/verl/verl_train/verl/third_party/torch/__init__.py @@ -0,0 +1,87 @@ +# official torch 2.6.0 set_model_state_dict API leads to OOM +# this is a copy of torch/distributed/checkpoint from torch 2.7.0 + +# From PyTorch: + +# Copyright (c) 2016- Facebook, Inc (Adam Paszke) +# Copyright (c) 2014- Facebook, Inc (Soumith Chintala) +# Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert) +# Copyright (c) 2012-2014 Deepmind Technologies (Koray Kavukcuoglu) +# Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu) +# Copyright (c) 2011-2013 NYU (Clement Farabet) +# Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston) +# Copyright (c) 2006 Idiap Research Institute (Samy Bengio) +# Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz) + +# From Caffe2: + +# Copyright (c) 2016-present, Facebook Inc. All rights reserved. + +# All contributions by Facebook: +# Copyright (c) 2016 Facebook Inc. + +# All contributions by Google: +# Copyright (c) 2015 Google Inc. +# All rights reserved. + +# All contributions by Yangqing Jia: +# Copyright (c) 2015 Yangqing Jia +# All rights reserved. + +# All contributions by Kakao Brain: +# Copyright 2019-2020 Kakao Brain + +# All contributions by Cruise LLC: +# Copyright (c) 2022 Cruise LLC. +# All rights reserved. + +# All contributions by Tri Dao: +# Copyright (c) 2024 Tri Dao. +# All rights reserved. + +# All contributions by Arm: +# Copyright (c) 2021, 2023-2024 Arm Limited and/or its affiliates + +# All contributions from Caffe: +# Copyright(c) 2013, 2014, 2015, the respective contributors +# All rights reserved. + +# All other contributions: +# Copyright(c) 2015, 2016 the respective contributors +# All rights reserved. + +# Caffe2 uses a copyright model similar to Caffe: each contributor holds +# copyright over their contributions to Caffe2. The project versioning records +# all such contribution and copyright details. If a contributor wants to further +# mark their specific copyright on a particular contribution, they should +# indicate their copyright solely in the commit message of the change when it is +# committed. + +# All rights reserved. + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: + +# 1. Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. + +# 2. Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. + +# 3. Neither the names of Facebook, Deepmind Technologies, NYU, NEC Laboratories America +# and IDIAP Research Institute nor the names of its contributors may be +# used to endorse or promote products derived from this software without +# specific prior written permission. + +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE +# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +# POSSIBILITY OF SUCH DAMAGE. diff --git a/code/RL_model/verl/verl_train/verl/third_party/torch/distributed/__init__.py b/code/RL_model/verl/verl_train/verl/third_party/torch/distributed/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..7664279b7411a806f615b52b2405fd2c40672517 --- /dev/null +++ b/code/RL_model/verl/verl_train/verl/third_party/torch/distributed/__init__.py @@ -0,0 +1,87 @@ +# official torch 2.6.0 set_model_state_dict API leads to OOM +# this is a copy of torch/distributed/checkpoint from torch 2.7.0 + +# From PyTorch: + +# Copyright (c) 2016- Facebook, Inc (Adam Paszke) +# Copyright (c) 2014- Facebook, Inc (Soumith Chintala) +# Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert) +# Copyright (c) 2012-2014 Deepmind Technologies (Koray Kavukcuoglu) +# Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu) +# Copyright (c) 2011-2013 NYU (Clement Farabet) +# Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston) +# Copyright (c) 2006 Idiap Research Institute (Samy Bengio) +# Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz) + +# From Caffe2: + +# Copyright (c) 2016-present, Facebook Inc. All rights reserved. + +# All contributions by Facebook: +# Copyright (c) 2016 Facebook Inc. + +# All contributions by Google: +# Copyright (c) 2015 Google Inc. +# All rights reserved. + +# All contributions by Yangqing Jia: +# Copyright (c) 2015 Yangqing Jia +# All rights reserved. + +# All contributions by Kakao Brain: +# Copyright 2019-2020 Kakao Brain + +# All contributions by Cruise LLC: +# Copyright (c) 2022 Cruise LLC. +# All rights reserved. + +# All contributions by Tri Dao: +# Copyright (c) 2024 Tri Dao. +# All rights reserved. + +# All contributions by Arm: +# Copyright (c) 2021, 2023-2024 Arm Limited and/or its affiliates + +# All contributions from Caffe: +# Copyright(c) 2013, 2014, 2015, the respective contributors +# All rights reserved. + +# All other contributions: +# Copyright(c) 2015, 2016 the respective contributors +# All rights reserved. + +# Caffe2 uses a copyright model similar to Caffe: each contributor holds +# copyright over their contributions to Caffe2. The project versioning records +# all such contribution and copyright details. If a contributor wants to further +# mark their specific copyright on a particular contribution, they should +# indicate their copyright solely in the commit message of the change when it is +# committed. + +# All rights reserved. + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: + +# 1. Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. + +# 2. Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. + +# 3. Neither the names of Facebook, Deepmind Technologies, NYU, NEC Laboratories America +# and IDIAP Research Institute nor the names of its contributors may be +# used to endorse or promote products derived from this software without +# specific prior written permission. + +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE +# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +# POSSIBILITY OF SUCH DAMAGE. diff --git a/code/RL_model/verl/verl_train/verl/third_party/torch/distributed/_state_dict_utils.py b/code/RL_model/verl/verl_train/verl/third_party/torch/distributed/_state_dict_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..d308449f7104e0c42afd48e38ed1696d2bf3072f --- /dev/null +++ b/code/RL_model/verl/verl_train/verl/third_party/torch/distributed/_state_dict_utils.py @@ -0,0 +1,840 @@ +# official torch 2.6.0 set_model_state_dict API leads to OOM +# this is a copy of torch/distributed/checkpoint from torch 2.7.0 + +# From PyTorch: + +# Copyright (c) 2016- Facebook, Inc (Adam Paszke) +# Copyright (c) 2014- Facebook, Inc (Soumith Chintala) +# Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert) +# Copyright (c) 2012-2014 Deepmind Technologies (Koray Kavukcuoglu) +# Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu) +# Copyright (c) 2011-2013 NYU (Clement Farabet) +# Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston) +# Copyright (c) 2006 Idiap Research Institute (Samy Bengio) +# Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz) + +# From Caffe2: + +# Copyright (c) 2016-present, Facebook Inc. All rights reserved. + +# All contributions by Facebook: +# Copyright (c) 2016 Facebook Inc. + +# All contributions by Google: +# Copyright (c) 2015 Google Inc. +# All rights reserved. + +# All contributions by Yangqing Jia: +# Copyright (c) 2015 Yangqing Jia +# All rights reserved. + +# All contributions by Kakao Brain: +# Copyright 2019-2020 Kakao Brain + +# All contributions by Cruise LLC: +# Copyright (c) 2022 Cruise LLC. +# All rights reserved. + +# All contributions by Tri Dao: +# Copyright (c) 2024 Tri Dao. +# All rights reserved. + +# All contributions by Arm: +# Copyright (c) 2021, 2023-2024 Arm Limited and/or its affiliates + +# All contributions from Caffe: +# Copyright(c) 2013, 2014, 2015, the respective contributors +# All rights reserved. + +# All other contributions: +# Copyright(c) 2015, 2016 the respective contributors +# All rights reserved. + +# Caffe2 uses a copyright model similar to Caffe: each contributor holds +# copyright over their contributions to Caffe2. The project versioning records +# all such contribution and copyright details. If a contributor wants to further +# mark their specific copyright on a particular contribution, they should +# indicate their copyright solely in the commit message of the change when it is +# committed. + +# All rights reserved. + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: + +# 1. Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. + +# 2. Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. + +# 3. Neither the names of Facebook, Deepmind Technologies, NYU, NEC Laboratories America +# and IDIAP Research Institute nor the names of its contributors may be +# used to endorse or promote products derived from this software without +# specific prior written permission. + +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE +# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +# POSSIBILITY OF SUCH DAMAGE. + + +# ruff: noqa: B028, UP038, UP007, E721, E501 +# mypy: allow-untyped-defs +import copy +import io +import math +import weakref +from collections.abc import Mapping, MutableMapping +from typing import TYPE_CHECKING, Any, Callable, NamedTuple, Optional, Union, cast + +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch.distributed._functional_collectives import AsyncCollectiveTensor + +if dist.is_available() or TYPE_CHECKING: + from torch.distributed import distributed_c10d + from torch.distributed._shard.sharded_tensor import ShardedTensor + from torch.distributed.tensor import DTensor, Replicate, distribute_tensor + from torch.distributed.tensor._utils import compute_local_shape_and_global_offset + + +def _identity_func( + obj: torch.Tensor, + pg: Optional[dist.ProcessGroup], + device: Optional[torch.device], + companion_obj: Any, +) -> torch.Tensor: + return obj + + +def _all_gather_sharded_tensor( + sharded_tensor: "ShardedTensor", + pg: Optional[dist.ProcessGroup] = None, + device: Optional[torch.device] = None, +) -> torch.Tensor: + if pg is None: + pg = distributed_c10d._get_default_group() + world_size = dist.get_world_size(pg) + shards = sharded_tensor.local_shards() + dim_0_size = sharded_tensor.size()[0] # type: ignore[index] + tensor_numel = sharded_tensor.size().numel() # type: ignore[union-attr] + chunk_size = math.ceil(dim_0_size / world_size) * tensor_numel // dim_0_size + pg_device = distributed_c10d._get_pg_default_device(pg) if device is None else device + if shards: + local_tensor = shards[0].tensor.flatten() + if local_tensor.device.type != pg_device.type: + local_tensor = local_tensor.to(pg_device) + num_padding = chunk_size - local_tensor.numel() + if num_padding > 0: + local_tensor = F.pad(local_tensor, [0, num_padding]) + else: + local_tensor = torch.zeros(chunk_size, dtype=sharded_tensor.dtype, device=pg_device) + + tensor = torch.empty( + chunk_size * world_size, + dtype=local_tensor.dtype, + device=pg_device, + ) + dist.all_gather_into_tensor(tensor, local_tensor, group=pg) + + tensor = tensor.narrow(0, 0, tensor_numel).reshape(sharded_tensor.size()) + return tensor + + +class CompanionMismatch(Exception): + pass + + +def _iterate_state_dict( + iter_object: Any, + sharded_tensor_func: Callable, + dtensor_func: Callable, + tensor_func: Callable, + *, + pg: Optional[dist.ProcessGroup] = None, + device: Optional[torch.device] = None, + cpu_offload: bool = False, + companion_obj: Any = None, + ranks_only: tuple[int, ...] = (), + type_check: bool = True, + non_blocking: bool = True, +) -> dict[str, Any]: + """Iterate through the state dict, applying the given functions to each tensor type. + + Args: + iter_object (Any): the target state_dict. + sharded_tensor_func (Callable): the function to apply to ShardedTensor + dtensor_func (Callable): the function to apply to DTensor + tensor_func (Callable): the function to apply to Tensor + pg (Optional[dist.ProcessGroup]): process group passed to tensor functions + device (Optional[torch.device]): device passed to tensor functions + cpu_offload (bool): whether to offload the tensors to CPU memory. This option is ignored + if a companion_obj is supplied. + companion_obj (Any): A companion object to the state dict. If this object + is supplied, we attempt to copy the tensor to the companion object. + ranks_only (Tuple[int, ...]): if this tuple is empty, all ranks will + have the same state_dicts. Otherwise only ranks that in ``ranks_only`` + have the same state_dicts. Other ranks will get empty state_dicts. + type_check (bool): check if the instance data type is a supported type + that can be saved by DCP. The current supported data types are + torch.Tensor, DTensor, int, float, str, list, dict, None. + non_blocking (bool): whether to use non-blocking copy when copying to the companion object. + """ + # TODO: should we use pytree? + cpu_device = torch.device("cpu") + if isinstance(iter_object, ShardedTensor): + ret = sharded_tensor_func(iter_object, pg, device, companion_obj) + elif isinstance(iter_object, DTensor): + ret = dtensor_func(iter_object, pg, device, companion_obj) + elif isinstance(iter_object, torch.Tensor): + ret = tensor_func(iter_object, pg, device, companion_obj) + elif isinstance(iter_object, (int, float, str, bytes, io.BytesIO)) or iter_object is None: + ret = iter_object + elif isinstance(iter_object, dict): + if companion_obj is not None and ( + not isinstance(companion_obj, dict) or set(companion_obj.keys()) != set(iter_object.keys()) + ): + msg = "" if isinstance(companion_obj, dict) else f"{set(companion_obj.keys())=} {set(iter_object.keys())=}" + raise CompanionMismatch(msg) + + ret = { + key: _iterate_state_dict( + value, + sharded_tensor_func, + dtensor_func, + tensor_func, + pg=pg, + device=device, + cpu_offload=cpu_offload, + companion_obj=companion_obj[key] if companion_obj is not None else None, + ranks_only=ranks_only, + type_check=type_check, + non_blocking=non_blocking, + ) + for key, value in iter_object.items() + } + elif isinstance(iter_object, (list, tuple)): + if companion_obj is not None and ( + not isinstance(companion_obj, (list, tuple)) or len(companion_obj) != len(iter_object) + ): + raise CompanionMismatch + + ret = [ + _iterate_state_dict( + v, + sharded_tensor_func, + dtensor_func, + tensor_func, + pg=pg, + device=device, + cpu_offload=cpu_offload, + companion_obj=companion_obj[idx] if companion_obj is not None else None, + ranks_only=ranks_only, + type_check=type_check, + non_blocking=non_blocking, + ) + for idx, v in enumerate(iter_object) + ] + if isinstance(iter_object, tuple): + ret = tuple(ret) + elif not type_check: + ret = copy.deepcopy(iter_object) + else: + raise ValueError(f"Unexpected value type {type(iter_object)}") + + if not ranks_only or dist.get_rank(pg) in ranks_only: + if isinstance(ret, torch.Tensor): + if cpu_offload and companion_obj is None: + ret = ret.to(cpu_device) + + if companion_obj is not None: + if isinstance(companion_obj, DTensor): + assert isinstance(ret, DTensor) + companion_obj._local_tensor.copy_(ret._local_tensor, non_blocking=non_blocking) + else: + companion_obj.copy_(ret, non_blocking=non_blocking) + ret = companion_obj + else: + ret = {} if isinstance(ret, dict) else None + + return ret + + +def _gather_state_dict( + state_dict: dict[str, Any], + *, + pg: Optional[dist.ProcessGroup] = None, + device: Optional[torch.device] = None, + cpu_offload: bool = False, + ranks_only: tuple[int, ...] = (), + type_check: bool = True, +) -> dict[str, Any]: + """ + Given a state_dict, this API gathers all the ShardedTensors or DTensors in + the state_dict. + + + Args: + state_dict (Dict[str, Any]): the target sharded state_dict. + pg (Optional[dist.ProcessGroup]): the process group that is used to + gather ShardedTensor. Note that gathering a DTensor will use + the DeviceMesh. So this argument will be ignored when gathering a + DTensor. + device: (Optional[torch.device]): the device that is used to + perform allgather for ShardedTensor. Note that gathering a DTensor + will use the DeviceMesh. So this argument will be ignored when + gathering a DTensor. + cpu_offload (bool): whether to offload the tensors to CPU memory. The + default value is False. + ranks_only: (Tuple[int, ...]): if this tuple is empty, all ranks will + have the same state_dicts. Otherwise only ranks that in ``ranks_only`` + have the same state_dicts. Other ranks will get empty state_dicts. + type_check: (bool): check if the instance data type is a supported type + that can be saved by DCP. The current supported data types are + torch.Tensor, DTensor, int, float, str, list, dict, None. + + Returns: + The gathered state dictionary. + """ + + def sharded_tensor_func(value, pg, device, companion_obj): + # ShardedTensor does not seem to record the original device type. + # So if the tensor is moved to CPU, we won't know the original type. + # As a result, we have to rely on the user to tell us the correct one. + cpu_device = torch.device("cpu") + output_tensor = _all_gather_sharded_tensor(value, pg, device) + local_shard_device = value.local_shards()[0].tensor.device if value.local_shards() else cpu_device + if output_tensor.device != local_shard_device: + value = output_tensor.to(local_shard_device) + else: + value = output_tensor + return value + + def dtensor_func(value, pg, device, companion_obj): + if value.device != value.device_mesh.device_type: + value = value.to(value.device_mesh.device_type) + # FSDP all_gather: [Shard(0)] -> [Replicate()] + # HSDP all_gather: [Replicate(), Shard(0)] -> [Replicate(), Replicate()] + # 2D FSDP + TP all_gather: + # - [Shard(0), Shard(n)] -> [Replicate(), Replicate()] + # - [Shard(0), Replicate()] -> [Replicate(), Replicate()] + placements = [Replicate() for _ in value.placements] + value = value.redistribute( + device_mesh=value.device_mesh, + placements=placements, + ) + # Call `wait()` to force the tensor to be synchronous with respect + # to the main stream. + # See the discussion in https://github.com/pytorch/pytorch/pull/117799. + value = value.to_local() + if isinstance(value, AsyncCollectiveTensor): + value = value.wait() + return value + + return _iterate_state_dict( + state_dict, + sharded_tensor_func, + dtensor_func, + _identity_func, + pg=pg, + device=device, + cpu_offload=cpu_offload, + ranks_only=ranks_only, + type_check=type_check, + ) + + +def _offload_state_dict_to_cpu( + state_dict: dict[str, Any], + *, + ranks_only: tuple[int, ...] = (), + type_check: bool = True, +) -> dict[str, Any]: + """ + Given a state_dict, this API offload all the tensors to CPU memory. + + Args: + state_dict (Dict[str, Any]): the target state_dict. + pg (Optional[dist.ProcessGroup]): the process group that is used to + gather ShardedTensor. Note that gathering a DTensor will use + the DeviceMesh. So this argument will be ignored when gathering a + DTensor. + ranks_only: (Tuple[int, ...]): if this tuple is empty, all ranks will + have the same state_dicts. Otherwise only ranks that in ``ranks_only`` + have the same state_dicts. Other ranks will get empty state_dicts. + type_check: (bool): check if the instance data type is a supported type + that can be saved by DCP. The current supported data types are + torch.Tensor, DTensor, int, float, str, list, dict, None. + + Returns: + The gathered state dictionary. + """ + + ret = _iterate_state_dict( + state_dict, + _identity_func, + _identity_func, + _identity_func, + pg=None, + device=None, + cpu_offload=True, + ranks_only=ranks_only, + type_check=type_check, + ) + return ret + + +@torch.no_grad() +def _copy_state_dict( + state_dict: dict[str, Any], + copy_state_dict: dict[str, Any], + non_blocking: bool = False, + type_check: bool = True, +) -> dict[str, Any]: + """ + Copies all tensors in a given state dict into a different state_dict with the + same structure. Additionally, a copied state dict with the same value references + is returned. Editing the keys on this state dict will not affect the + passed in copy_state_dict (but the value references are the same). + + .. warning:: + It is expected by this function that state_dict and copy_state_dict share + the same structure and data types. + + .. warning:: + The current supported data types are + torch.Tensor, DTensor, int, float, str, list, dict, None. + + Args: + state_dict (Dict[str, Any]): the target state_dict. + copy_state_dict (Dict[str, Any]): + The state dict we are copying into. This state_dict must have exactly + the same structure as the source `state_dict`. + non_blocking: (bool): Whether copy ops should be performed asynchronously + type_check (bool): check if the instance data type is a supported type + that can be saved by DCP. The current supported data types are + torch.Tensor, DTensor, int, float, str, list, dict, None. + + Returns: + State Dict copy + """ + + return _iterate_state_dict( + state_dict, + _identity_func, + _identity_func, + _identity_func, + pg=None, + device=None, + cpu_offload=False, + ranks_only=(), + companion_obj=copy_state_dict, + type_check=type_check, + non_blocking=non_blocking, + ) + + +@torch.no_grad() +def _create_cpu_state_dict( + state_dict: dict[str, Any], pin_memory: bool = False, share_memory: bool = False +) -> dict[str, Any]: + """ + Given a state_dict, create another state_dict with the same structure and elements. + However, all tensors in the returned state_dict are new tensors on CPU. These + tensors can be placed on pin_memory or share_memory based on the provided arguments. + + .. warning:: + Setting both `pin_memory` and `share_memory` to True significantly increases the + latency of this method because of the nuances which require us to register memory + as pinned directly as opposed to relying on the pin_memory cache allocator. This + option should only be used for long lived tensors which are required to be shared. + This is not the case as long as at least one of `pin_memory` or `share_memory` is + set to False. + + """ + + def tensor_func( + obj: torch.Tensor, + pg: Optional[dist.ProcessGroup], + device: Optional[torch.device], + _: Any, + ) -> torch.Tensor: + if len(obj.size()) == 0: + return torch.tensor(0, dtype=obj.dtype) + + if share_memory: + t = torch.empty(*tuple(obj.size()), dtype=obj.dtype) + t = t.share_memory_() + if pin_memory: + + def unpin_memory(t): + succ = int(torch.cuda.cudart().cudaHostUnregister(t.data_ptr())) + assert succ == 0, f"Unpinning shared memory failed with error-code: {succ}" + + weakref.finalize(t, unpin_memory, t) + succ = int( + torch.cuda.cudart().cudaHostRegister( + t.data_ptr(), + t.numel() * t.element_size(), + 1, # lines up with 'cudaHostRegisterPortable' + ) + ) + assert succ == 0, f"Pinning shared memory failed with error-code: {succ}" + return t + elif pin_memory: + return torch.empty(*tuple(obj.size()), dtype=obj.dtype).pin_memory() + else: + return torch.empty(*tuple(obj.size()), dtype=obj.dtype) + + def dtensor_func( + obj: DTensor, + pg: Optional[dist.ProcessGroup], + device: Optional[torch.device], + _: Any, + ) -> DTensor: + if len(obj.size()) == 0: + return obj + + if obj.device != torch.device("cpu"): + ret = cast(DTensor, obj.to(device="cpu")) + else: + ret = copy.deepcopy(obj) + ret._local_tensor = tensor_func(ret._local_tensor, pg, device, None) + return ret + + ret = _iterate_state_dict( + state_dict, + _identity_func, + dtensor_func, + tensor_func, + pg=None, + device=None, + cpu_offload=False, + ranks_only=(), + type_check=False, + ) + return ret + + +def _check_state_dict_similarity( + state_dict: dict[str, Any], + compared_state_dict: dict[str, Any], +) -> bool: + """ + Given two state_dicts, check if the structures are the same. And + if a [key, tensor] pair exist in one state_dict there must be + the a corresponding pait, [key, other_tensor], in the other state_dict, + where tensor and other_tensor have the same size and dtype. + + Return the check result. + """ + + def tensor_func( + obj: torch.Tensor, + pg: Optional[dist.ProcessGroup], + device: Optional[torch.device], + companion_obj: Any, + ) -> torch.Tensor: + if companion_obj.dtype != obj.dtype or companion_obj.size() != obj.size(): + raise CompanionMismatch + return obj + + try: + _iterate_state_dict( + state_dict, + _identity_func, + _identity_func, + tensor_func, + pg=None, + device=None, + cpu_offload=False, + ranks_only=(), + companion_obj=compared_state_dict, + type_check=False, + ) + except CompanionMismatch: + return False + + return True + + +class _TensorInfo(NamedTuple): + size: torch.Size + dtype: torch.dtype + + +def _broadcast_tensors( + full_state_dict: dict[str, Any], + local_state_dict: dict[str, Any], + keys: list[str], + device: torch.device, + pg: Optional[dist.ProcessGroup] = None, +) -> None: + tensors = [] + for key in keys: + if dist.get_rank() == 0: + full_state = full_state_dict[key] + assert isinstance(full_state, torch.Tensor) + full_tensor = full_state.detach().to(device) + else: + tensor_info = full_state_dict[key] + full_tensor = torch.empty( + size=tensor_info.size, + device=device, + dtype=tensor_info.dtype, + ) + tensors.append(full_tensor) + local_state = local_state_dict.get(key, None) + if local_state is None: + continue + elif isinstance(local_state, DTensor): + local_state_dict[key] = (local_state, full_tensor) + else: + local_state_dict[key] = full_tensor + + if pg is None: + pg = dist.distributed_c10d._get_default_group() + + if len(tensors) > 1: + dist._broadcast_coalesced(pg, tensors, 500, 0) + else: + dist.broadcast(tensors[0], src=0, group=pg) + + _distribute_tensors(local_state_dict, keys, device, pg) + + +def _distribute_tensors( + local_state_dict: dict[str, Any], + keys: list[str], + device: torch.device, + pg: Optional[dist.ProcessGroup] = None, +) -> None: + if pg is None: + pg = dist.distributed_c10d._get_default_group() + for key in keys: + _local_state = local_state_dict.get(key, None) + if _local_state is None or torch.is_tensor(_local_state): + continue + + local_state = _local_state[0] + full_tensor = _local_state[1] + + shape, offset = compute_local_shape_and_global_offset( + full_tensor.shape, local_state.device_mesh, local_state.placements + ) + slices = [ + slice(cur_offset, cur_offset + cur_shape) for cur_shape, cur_offset in zip(shape, offset, strict=False) + ] + if local_state.is_meta: + # Use .clone() here rather than view to clone and return only the sliced portion, minimizing memory access and cost. + local_tensor = full_tensor[slices].detach().clone() + # TODO: currently, we cannot handle strided sharding if the dp dimension is not even. For example, + # one of the case that is not yet supported is when placements = (Shard(0), _StridedShard(0, sf=2)). + ret = DTensor.from_local( + local_tensor, + local_state.device_mesh, + local_state.placements, + shape=local_state.shape, + stride=local_state.stride(), + ) + else: + ret = local_state + # Copy full_tensor[slices] into local_state.to_local() to reduce memory footprint. + ret.to_local().copy_(full_tensor[slices]) + local_state_dict[key] = ret + + +def _broadcast_state_dict( + full_state_dict: dict[str, Any], + local_state_dict: dict[str, Any], + device: torch.device, + pg: Optional[dist.ProcessGroup] = None, + strict: bool = False, + cpu_offload: bool = False, +) -> None: + # Broadcast from rank0's `full_state_dict` to all ranks' `local_state_dict`. + # If strict is True, any keys in `local_state_dict` but not in `full_state_dict` + # will be removed from `local_state_dict`. + ret = {} + if dist.get_rank() == 0: + for key, value in full_state_dict.items(): + if not torch.is_tensor(value): + ret[key] = value + elif value.dim() == 0: + ret[key] = value.cpu() + else: + ret[key] = _TensorInfo(value.size(), value.dtype) + + broadcast_list = [ret] + dist.broadcast_object_list(broadcast_list, src=0, group=pg) + ret = broadcast_list[0] + # Gather values + keys = [] + local_state_dict_keys = set(local_state_dict.keys()) + global_keys = set() + for key, value in ret.items(): + global_keys.add(key) + if not isinstance(value, _TensorInfo): + if key in local_state_dict: + local_state_dict[key] = value + continue + + if dist.get_rank() == 0: + ret[key] = full_state_dict[key] + + keys.append(key) + # Broadcast every tensor to avoid OOM for now. + if len(keys) >= 1: + _broadcast_tensors(ret, local_state_dict, keys, device, pg) + if cpu_offload: + for key in keys: + local_state_dict[key] = local_state_dict[key].cpu() + keys.clear() + + if strict: + if missing_keys := (local_state_dict_keys - global_keys): + for key in missing_keys: + local_state_dict.pop(key) + + if keys: + _broadcast_tensors(ret, local_state_dict, keys, device, pg) + if cpu_offload: + for key in keys: + local_state_dict[key] = local_state_dict[key].cpu() + + +def _distribute_state_dict( + full_state_dict: dict[str, Any], + local_state_dict: dict[str, Any], + device: torch.device, + pg: Optional[dist.ProcessGroup] = None, +) -> None: + # Full_state_dict = True, broadcast_from_rank0 = False here. Each rank has + # full_state_dict. Skip the broadcast in ``_broadcast_state_dict`` and + # distribute tensors in each rank + for key, value in full_state_dict.items(): + if key not in full_state_dict: + continue + if not torch.is_tensor(value): + local_state_dict[key] = value + elif value.dim() == 0: + local_state_dict[key] = value.cpu() + else: + assert isinstance(value, torch.Tensor) + local_state = local_state_dict.get(key, None) + if local_state is None: + continue + elif isinstance(local_state, DTensor): + local_state_dict[key] = distribute_tensor( + value.detach().to(device), + local_state.device_mesh, + local_state.placements, + ) + else: + local_state_dict[key] = value.detach().to(device) + + +# These APIs are from torch.distributed.checkpoint. +# TODO: We should consolidate the code here as some not all modules can depend on +# DCP. +PATH_ITEM = Union[str, int] +OBJ_PATH = tuple[PATH_ITEM, ...] +FLATTEN_MAPPING = dict[str, OBJ_PATH] +STATE_DICT_TYPE = dict[str, Any] +CONTAINER_TYPE = MutableMapping[PATH_ITEM, Any] + + +def _traverse_state_dict( + state_dict: STATE_DICT_TYPE, + visitor: Callable[[OBJ_PATH, Any], None], +) -> None: + """ + Invoke ``visitor`` for each value recursively in ``state_dict``. + Mapping, list, and tuple will be flattened and other value types are treated + as the terminal values and will invoke ``visitor``. + """ + + def _traverse_obj(path: OBJ_PATH, value: Any) -> None: + if isinstance(value, Mapping): + for k, v in value.items(): + _traverse_obj(path + (str(k),), v) + elif isinstance(value, (list, tuple)): + for i, v in enumerate(value): + _traverse_obj(path + (i,), v) + else: + visitor(path, value) + + for key, value in state_dict.items(): + _traverse_obj((str(key),), value) + + +def _flatten_state_dict( + state_dict: STATE_DICT_TYPE, +) -> tuple[STATE_DICT_TYPE, FLATTEN_MAPPING]: + """ + Flatten ``state_dict`` made of nested dicts and lists into a top level dictionary. + + Use ``unflatten_state_dict`` to revert this process. + Returns: + A tuple with the flatten state_dict and a mapping from original to new state_dict. + N.B. The new keys are derived from the object paths, joined by dot. + For example: ``{ 'a': {'b':...}}`` results in the key `a.b`. + """ + flattened: STATE_DICT_TYPE = {} + mappings: FLATTEN_MAPPING = {} + + def flat_copy(path: OBJ_PATH, value: Any) -> None: + new_fqn = ".".join(map(str, path)) + if new_fqn in flattened: + raise ValueError(f"duplicated flatten key {new_fqn}") + flattened[new_fqn] = value + mappings[new_fqn] = path + + _traverse_state_dict(state_dict, flat_copy) + return flattened, mappings + + +def _set_element(root_dict: STATE_DICT_TYPE, path: OBJ_PATH, value: Any) -> None: + """Set ``value`` in ``root_dict`` along the ``path`` object path.""" + cur_container = cast(CONTAINER_TYPE, root_dict) + + def extend_list(lst: list[Any], idx: int) -> None: + while len(lst) <= idx: + lst.append(None) + + for i in range(1, len(path)): + prev_key = path[i - 1] + key = path[i] + def_val: CONTAINER_TYPE | list[Any] = {} if type(key) == str else [] + + if isinstance(cur_container, Mapping): + cur_container = cast(CONTAINER_TYPE, cur_container.setdefault(prev_key, def_val)) + else: + extend_list(cur_container, prev_key) + if cur_container[prev_key] is None: + cur_container[prev_key] = def_val + cur_container = cur_container[prev_key] + + key = path[-1] + if type(key) == int: + extend_list(cast(list[Any], cur_container), key) + + cur_container[key] = value + + +def _unflatten_state_dict(state_dict: STATE_DICT_TYPE, mapping: FLATTEN_MAPPING) -> STATE_DICT_TYPE: + """Restore the original nested state_dict according to ``mapping`` and the flattened ``state_dict``.""" + nested: STATE_DICT_TYPE = {} + for key, value in state_dict.items(): + _set_element(nested, mapping[key], value) + return nested diff --git a/code/RL_model/verl/verl_train/verl/third_party/torch/distributed/checkpoint/__init__.py b/code/RL_model/verl/verl_train/verl/third_party/torch/distributed/checkpoint/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..7664279b7411a806f615b52b2405fd2c40672517 --- /dev/null +++ b/code/RL_model/verl/verl_train/verl/third_party/torch/distributed/checkpoint/__init__.py @@ -0,0 +1,87 @@ +# official torch 2.6.0 set_model_state_dict API leads to OOM +# this is a copy of torch/distributed/checkpoint from torch 2.7.0 + +# From PyTorch: + +# Copyright (c) 2016- Facebook, Inc (Adam Paszke) +# Copyright (c) 2014- Facebook, Inc (Soumith Chintala) +# Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert) +# Copyright (c) 2012-2014 Deepmind Technologies (Koray Kavukcuoglu) +# Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu) +# Copyright (c) 2011-2013 NYU (Clement Farabet) +# Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston) +# Copyright (c) 2006 Idiap Research Institute (Samy Bengio) +# Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz) + +# From Caffe2: + +# Copyright (c) 2016-present, Facebook Inc. All rights reserved. + +# All contributions by Facebook: +# Copyright (c) 2016 Facebook Inc. + +# All contributions by Google: +# Copyright (c) 2015 Google Inc. +# All rights reserved. + +# All contributions by Yangqing Jia: +# Copyright (c) 2015 Yangqing Jia +# All rights reserved. + +# All contributions by Kakao Brain: +# Copyright 2019-2020 Kakao Brain + +# All contributions by Cruise LLC: +# Copyright (c) 2022 Cruise LLC. +# All rights reserved. + +# All contributions by Tri Dao: +# Copyright (c) 2024 Tri Dao. +# All rights reserved. + +# All contributions by Arm: +# Copyright (c) 2021, 2023-2024 Arm Limited and/or its affiliates + +# All contributions from Caffe: +# Copyright(c) 2013, 2014, 2015, the respective contributors +# All rights reserved. + +# All other contributions: +# Copyright(c) 2015, 2016 the respective contributors +# All rights reserved. + +# Caffe2 uses a copyright model similar to Caffe: each contributor holds +# copyright over their contributions to Caffe2. The project versioning records +# all such contribution and copyright details. If a contributor wants to further +# mark their specific copyright on a particular contribution, they should +# indicate their copyright solely in the commit message of the change when it is +# committed. + +# All rights reserved. + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: + +# 1. Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. + +# 2. Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. + +# 3. Neither the names of Facebook, Deepmind Technologies, NYU, NEC Laboratories America +# and IDIAP Research Institute nor the names of its contributors may be +# used to endorse or promote products derived from this software without +# specific prior written permission. + +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE +# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +# POSSIBILITY OF SUCH DAMAGE. diff --git a/code/RL_model/verl/verl_train/verl/third_party/torch/distributed/checkpoint/state_dict.py b/code/RL_model/verl/verl_train/verl/third_party/torch/distributed/checkpoint/state_dict.py new file mode 100644 index 0000000000000000000000000000000000000000..e4555802aed8c4b5963892a688b1ff41ae97fb56 --- /dev/null +++ b/code/RL_model/verl/verl_train/verl/third_party/torch/distributed/checkpoint/state_dict.py @@ -0,0 +1,1493 @@ +# official torch 2.6.0 set_model_state_dict API leads to OOM +# this is a copy of torch/distributed/checkpoint from torch 2.7.0 + +# From PyTorch: + +# Copyright (c) 2016- Facebook, Inc (Adam Paszke) +# Copyright (c) 2014- Facebook, Inc (Soumith Chintala) +# Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert) +# Copyright (c) 2012-2014 Deepmind Technologies (Koray Kavukcuoglu) +# Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu) +# Copyright (c) 2011-2013 NYU (Clement Farabet) +# Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston) +# Copyright (c) 2006 Idiap Research Institute (Samy Bengio) +# Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz) + +# From Caffe2: + +# Copyright (c) 2016-present, Facebook Inc. All rights reserved. + +# All contributions by Facebook: +# Copyright (c) 2016 Facebook Inc. + +# All contributions by Google: +# Copyright (c) 2015 Google Inc. +# All rights reserved. + +# All contributions by Yangqing Jia: +# Copyright (c) 2015 Yangqing Jia +# All rights reserved. + +# All contributions by Kakao Brain: +# Copyright 2019-2020 Kakao Brain + +# All contributions by Cruise LLC: +# Copyright (c) 2022 Cruise LLC. +# All rights reserved. + +# All contributions by Tri Dao: +# Copyright (c) 2024 Tri Dao. +# All rights reserved. + +# All contributions by Arm: +# Copyright (c) 2021, 2023-2024 Arm Limited and/or its affiliates + +# All contributions from Caffe: +# Copyright(c) 2013, 2014, 2015, the respective contributors +# All rights reserved. + +# All other contributions: +# Copyright(c) 2015, 2016 the respective contributors +# All rights reserved. + +# Caffe2 uses a copyright model similar to Caffe: each contributor holds +# copyright over their contributions to Caffe2. The project versioning records +# all such contribution and copyright details. If a contributor wants to further +# mark their specific copyright on a particular contribution, they should +# indicate their copyright solely in the commit message of the change when it is +# committed. + +# All rights reserved. + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: + +# 1. Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. + +# 2. Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. + +# 3. Neither the names of Facebook, Deepmind Technologies, NYU, NEC Laboratories America +# and IDIAP Research Institute nor the names of its contributors may be +# used to endorse or promote products derived from this software without +# specific prior written permission. + +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE +# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +# POSSIBILITY OF SUCH DAMAGE. + +# ruff: noqa: B028, UP038, UP007, E721 +# mypy: allow-untyped-defs +import contextlib +import functools +import gc +import warnings +from collections.abc import Generator, Iterable +from dataclasses import asdict, dataclass, field +from itertools import chain +from typing import Any, Callable, Optional, Union, cast, no_type_check + +import torch +import torch.distributed as dist +import torch.nn as nn +from torch.distributed._shard.sharded_tensor import ShardedTensor +from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( + _CHECKPOINT_PREFIX, +) +from torch.distributed.fsdp import ( + FullOptimStateDictConfig, + FullStateDictConfig, + OptimStateDictConfig, + ShardedOptimStateDictConfig, + ShardedStateDictConfig, + StateDictConfig, + StateDictType, +) +from torch.distributed.fsdp import ( + FullyShardedDataParallel as FSDP, +) +from torch.distributed.fsdp._common_utils import ( + FSDP_WRAPPED_MODULE, + _get_module_fsdp_state_if_fully_sharded_module, +) +from torch.distributed.tensor import DTensor +from torch.nn.modules.module import _IncompatibleKeys +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils._pytree import tree_map_only + +from verl.third_party.torch.distributed._state_dict_utils import ( + _broadcast_state_dict, + _distribute_state_dict, + _flatten_state_dict, + _gather_state_dict, + _offload_state_dict_to_cpu, + _unflatten_state_dict, +) + +__all__ = [ + "FQNS_T", + "PrimitiveType", + "ValueType", + "DictValueType", + "ListDictValueType", + "OptimizerStateType", + "StateDictOptions", + "get_model_state_dict", + "get_optimizer_state_dict", + "get_state_dict", + "set_model_state_dict", + "set_optimizer_state_dict", + "set_state_dict", +] + + +_FLAT_PARAM = "_flat_param" +_PG = "param_groups" +_PARAMS = "params" +_STATE = "state" + +FQNS_T = set[str] +PrimitiveType = Union[DTensor, ShardedTensor, torch.Tensor, int, float, str] +ValueType = Union[PrimitiveType, list[PrimitiveType], tuple[PrimitiveType], dict[str, "ValueType"]] +DictValueType = dict[str, ValueType] +ListDictValueType = list[DictValueType] +OptimizerStateType = dict[str, DictValueType | ListDictValueType] + + +_patched_state_dict: set[Callable] = set() + + +@contextlib.contextmanager +def _gc_context(): + is_enabled = gc.isenabled() + gc.disable() + try: + yield + finally: + if is_enabled: + gc.enable() + + +@dataclass +class StateDictOptions: + """ + This dataclass specifies how get_state_dict/set_state_dict will work. + + - ``full_state_dict``: if this is set to True, all the tensors in the + returned state_dict will be gathered. No ShardedTensor and DTensor + will be in the returned state_dict. + + - ``cpu_offload``: offload all the tensors to cpu. To prevent CPU OOM, if + ``full_state_dict`` is also true, then only the rank0 will get the + state_dict and all other ranks will get empty state_dict. + + - ``ignore_frozen_params``: if the value is True, the returned state_dict + won't contain any frozen parameters -- the ``requires_grad`` is False. + The default value is False. + + - ``keep_submodule_prefixes`` (deprecated): when ``submodules`` is not None, this option + indicates whether to keep the submodule prefixes from the state_dict keys. + or example, if the submodule is ``module.pretrain`` and the full FQN of + the parameter is ``pretrain.layer1.weight`` of the param. When this option + is True, the parameter's key in the returned state_dict will be + ``pretrain.layer1.weight``. If the options is False, the key will be + ``layer1.weight``. + Note that if ``keep_submodule_prefixes`` is False, there may be conflicted + FQNs, hence there should be only one submodule in ``submodules``. + + - ``strict``: the ``strict`` option when ``set_state_dict`` calls + model.load_state_dict(). + + - ``broadcast_from_rank0``: when the option is True, rank0 should receive a + full state_dict and will broadcast the tensors in the state_dict/ + optim_state_dict one by one to other ranks. Other ranks will receive + the tensors and shard according to the local shards in the model and + optimizer. ``full_state_dict`` must be set to True when using this option. + This option currently only supports DTensor, not the legacy ShardedTensor. + """ + + full_state_dict: bool = False + cpu_offload: bool = False + ignore_frozen_params: bool = False + keep_submodule_prefixes: bool = True + strict: bool = True + broadcast_from_rank0: bool = False + flatten_optimizer_state_dict: bool = False + dsd_fqn_modifiers: str = "_fqn_modifiers" + + +@dataclass +class _StateDictInfo(StateDictOptions): + fqn_param_mapping: dict[ + str | torch.Tensor, + FQNS_T | torch.Tensor, + ] = field(default_factory=dict) + shared_params_mapping: dict[ + str | torch.Tensor, + FQNS_T | torch.Tensor, + ] = field(default_factory=dict) + submodule_prefixes: set[str] = field(default_factory=set) + handle_model: bool = True + handle_optim: bool = True + fsdp_context: Callable = contextlib.nullcontext + fsdp_modules: list[nn.Module] = field(default_factory=list) + + +@functools.cache +def _get_fqns( + model: nn.Module, + name: str, + dsd_fqn_modifiers: str = "_fqn_modifiers", + skip_ddp_prefix: bool = True, + skip_compiler_prefix: bool = True, +) -> FQNS_T: + """ + This API is used to convert the name of a parameter to the FQNs. For FSDP + without `use_orig_params`, the name of FlatParameter can be mapped to + multiple original parameters. As a result, the return type of this function + is `set[str]`. + + Args: + module (nn.Module): the root model. + name (str): the name + skip_ddp_prefix (bool): whether to skip DDP's `module` prefix + + Returns: + The canonical FQNs based on the model traversal. + """ + + # Remove the checkpoint prefix, if it exists. + name = name.replace(_CHECKPOINT_PREFIX, "") + if "." not in name: + return {name} + + obj_names = name.split(".") + fqn_obj_names = [] + curr_obj = model + for i, curr_obj_name in enumerate(obj_names): + if isinstance(curr_obj, DDP): + assert curr_obj_name == "module" + curr_obj = curr_obj.module + if not skip_ddp_prefix: + fqn_obj_names.append(curr_obj_name) + elif isinstance(curr_obj, FSDP): + if i < len(obj_names) - 1 and obj_names[i + 1] == _FLAT_PARAM: + prefix = ".".join(fqn_obj_names) + flat_param = getattr(curr_obj, _FLAT_PARAM) + if prefix: + prefix = f"{prefix}." + return {f"{prefix}{fqn}" for fqn in flat_param._fqns} + curr_obj = getattr(curr_obj, FSDP_WRAPPED_MODULE) + if curr_obj_name != FSDP_WRAPPED_MODULE: + fqn_obj_names.append(curr_obj_name) + curr_obj = getattr(curr_obj, curr_obj_name) + elif isinstance(curr_obj, torch._dynamo.eval_frame.OptimizedModule): + assert curr_obj_name == "_orig_mod" + curr_obj = curr_obj._orig_mod + if not skip_compiler_prefix: + fqn_obj_names.append(curr_obj_name) + else: + # In some modeuls, _fqn_modifiers would not shown in the state_dict keys, + # skip them in the fqn to ensure load stat dict successfully for them. + if hasattr(curr_obj, dsd_fqn_modifiers): + if removed_fqn := getattr(curr_obj, dsd_fqn_modifiers)().get(curr_obj_name): + if hasattr(curr_obj, removed_fqn): + curr_obj = getattr(curr_obj, removed_fqn) + fqn_obj_names.append(curr_obj_name) + if curr_obj_name == nn.modules.module._EXTRA_STATE_KEY_SUFFIX: + if i != len(obj_names) - 1: + raise RuntimeError("Expect `_extra_state` to be the last obj name") + else: + curr_obj = getattr(curr_obj, curr_obj_name) + + return {".".join(fqn_obj_names).replace(_CHECKPOINT_PREFIX, "")} + + +class _EXTRA_STATE: + pass + + +def _iterate_valid_model_state(model, dsd_fqn_modifiers="_fqn_modifiers"): + visited_modules: set[nn.Module] = set() + + def recurse(module: nn.Module, curr_fqn: str) -> Generator: + visited_modules.add(module) + + curr_fqn = f"{curr_fqn}." if curr_fqn else "" + for name, submodule in module.named_children(): + if submodule in visited_modules: + continue + # if user have state_dict_hooks in their model, they can add the state_dict key changes + # at dsd_fqn_modifiers in input to align with the function of state_dict_hook + if hasattr(module, dsd_fqn_modifiers) and name in getattr(module, dsd_fqn_modifiers)().values(): + # skip _fqn_modifiers here thus remove the last `.` added + new_fqn = curr_fqn[:-1] + else: + new_fqn = f"{curr_fqn}{name}" + yield from recurse(submodule, new_fqn) + + for name, obj in chain(module.named_buffers(recurse=False), module.named_parameters(recurse=False)): + if name in module._non_persistent_buffers_set: + continue + new_fqn = f"{curr_fqn}{name}" + yield new_fqn, obj + + if getattr(module.__class__, "get_extra_state", nn.Module.get_extra_state) != nn.Module.get_extra_state: + new_fqn = f"{curr_fqn}{nn.modules.module._EXTRA_STATE_KEY_SUFFIX}" + yield new_fqn, _EXTRA_STATE() + + yield from recurse(model, "") + + +def _verify_options( + model: nn.Module, + optims: tuple[torch.optim.Optimizer, ...], + optim_only: bool, + *, + submodules: Optional[set[nn.Module]] = None, + options: Optional[StateDictOptions] = None, +) -> _StateDictInfo: + """ + Verify the model and options passed by the user and generates _StateDictInfo. + """ + if submodules: + warnings.warn( + "Getting submodules only model/optim state_dict is deprecated and " + "will be removed in 2.5. This feature can be achieved by manually " + "filtering out the state_dict returned from get_state_dict.", + FutureWarning, + ) + if optim_only and not optims: + raise RuntimeError("Optimizers are not passed in but optim_only is set to True.") + + options = options or StateDictOptions() + + fqn_param_mapping: dict[str | torch.Tensor, set[str] | torch.Tensor] = {} + shared_params_mapping: dict[str | torch.Tensor, set[str] | torch.Tensor] = {} + for name, param in _iterate_valid_model_state(model): + if isinstance(param, _EXTRA_STATE): + continue + + fqns = _get_fqns(model, name) + fqn = fqn_param_mapping.get(param, None) + if fqn is not None: + cast(set[str], fqn_param_mapping[param]).update(fqns) + shared_params_mapping[param] = fqn_param_mapping[param] + else: + # We need to do copy as _get_fqns is lru_cached + fqn_param_mapping[param] = fqns.copy() + for fqn in fqns: + if not isinstance(param, _EXTRA_STATE): + fqn_param_mapping[fqn] = param + + for param_, fqns_ in list(shared_params_mapping.items()): + for fqn in fqns_: + shared_params_mapping[fqn] = cast(torch.Tensor, param_) + + submodule_prefixes: set[str] = set() + if submodules: + submodules = set(submodules) + for name, module in model.named_modules(): + if module not in submodules: + continue + fqns = _get_fqns(model, name) + assert len(fqns) == 1, "Submodule FQN should only have 1 instance" + submodule_prefixes.update(f"{fqn}." for fqn in fqns) + + if options.broadcast_from_rank0 and not options.full_state_dict: + raise ValueError("full_state_dict must be True when broadcast_from_rank0 is True.") + fsdp_modules = FSDP.fsdp_modules(model) + state_dict_config: StateDictConfig + optim_state_dict_config: OptimStateDictConfig + fsdp_context: Callable + if fsdp_modules: + # FSDP API only work if at least one FSDP instance exists. + if options.full_state_dict: + state_dict_config = FullStateDictConfig(offload_to_cpu=options.cpu_offload, rank0_only=options.cpu_offload) + optim_state_dict_config = FullOptimStateDictConfig( + offload_to_cpu=options.cpu_offload, + rank0_only=(options.cpu_offload or options.broadcast_from_rank0), + ) + state_dict_type = StateDictType.FULL_STATE_DICT + else: + state_dict_config = ShardedStateDictConfig( + offload_to_cpu=options.cpu_offload, + ) + optim_state_dict_config = ShardedOptimStateDictConfig( + offload_to_cpu=options.cpu_offload, + ) + state_dict_type = StateDictType.SHARDED_STATE_DICT + + @contextlib.contextmanager + def fsdp_state_dict_type_without_warning( + module, + state_dict_type, + state_dict_config, + optim_state_dict_config, + ): + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", message="FSDP.state_dict_type", category=FutureWarning) + with FSDP.state_dict_type( + module=module, + state_dict_type=state_dict_type, + state_dict_config=state_dict_config, + optim_state_dict_config=optim_state_dict_config, + ): + yield + + fsdp_context = functools.partial( + fsdp_state_dict_type_without_warning, + module=model, + state_dict_type=state_dict_type, + state_dict_config=state_dict_config, + optim_state_dict_config=optim_state_dict_config, + ) + else: + fsdp_context = contextlib.nullcontext + + return _StateDictInfo( + **asdict(options), + fqn_param_mapping=fqn_param_mapping, + shared_params_mapping=shared_params_mapping, + submodule_prefixes=submodule_prefixes, + fsdp_context=fsdp_context, + fsdp_modules=cast(list[nn.Module], fsdp_modules), + handle_model=not optim_only, + handle_optim=(len(optims) > 0), + ) + + +def _verify_state_dict( + model_state_dict: dict[str, ValueType], + optim_state_dict: OptimizerStateType, + info: _StateDictInfo, +) -> None: + for module in info.fsdp_modules: + fsdp_state = _get_module_fsdp_state_if_fully_sharded_module(module) + assert fsdp_state is not None, "Expected a fsdp_state with a fsdp module." + + # Verify if the model_state_dict and optim_state_dict are valid. This API + # should give the users an explicit error message to debug or report. + if ( + info.handle_model + and not model_state_dict + and not info.submodule_prefixes + and not info.ignore_frozen_params + and not (info.cpu_offload and info.full_state_dict) + and info.strict + and not info.broadcast_from_rank0 + ): + raise RuntimeError( + "The option indicates that model state_dict is required to save " + "or load, but model state_dict is empty." + f"rank = {dist.get_rank()=}." + ) + + if info.handle_optim: + if not optim_state_dict and not (info.cpu_offload and info.full_state_dict) and (not info.broadcast_from_rank0): + raise RuntimeError( + "The option indicates that model state_dict is required to save, " + f"or load but optim state_dict is empty. {optim_state_dict}" + ) + + for key in model_state_dict.keys(): + if _FLAT_PARAM in key: + raise RuntimeError(f"{key} contains {_FLAT_PARAM}. This can happen if the model is not the root module.") + + +def _state_dict_fn(obj: nn.Module | torch.optim.Optimizer, api: str) -> Callable: + call = getattr(obj, api) + if call in _patched_state_dict: + call = functools.partial(getattr(obj.__class__, api), self=obj) + return call + + +def _maybe_full_or_cpu_state_dict(state_dict: dict[str, Any], info: _StateDictInfo) -> dict[str, Any]: + if info.full_state_dict: + ranks_only = () if (not info.cpu_offload or not torch.distributed.is_initialized()) else (0,) + return _gather_state_dict(state_dict, cpu_offload=info.cpu_offload, ranks_only=ranks_only) + elif info.cpu_offload: + return _offload_state_dict_to_cpu(state_dict) + else: + return state_dict + + +@torch.no_grad() +def _get_model_state_dict(model: nn.Module, info: _StateDictInfo) -> dict[str, ValueType]: + if not info.handle_model: + return {} + + with info.fsdp_context(): + state_dict = _state_dict_fn(model, "state_dict")() + + for key in list(state_dict.keys()): + fqns = _get_fqns(model, key) + assert len(fqns) == 1, (key, fqns) + fqn = next(iter(fqns)) + if fqn != key: + # As we only support FSDP, DDP, and TP, the only cases are + # wrapper-based DDP and compiler. Verify if the assumption + # is correct. + def verify(key, fqn) -> bool: + if len(fqn) >= len(key): + return False + fqn_split = fqn.split(".") + key_split = key.split(".") + fqn_idx = 0 + for key_idx, key_name in enumerate(key_split): + if key_name == fqn_split[fqn_idx]: + fqn_idx += 1 + if fqn_idx == len(fqn_split): + return key_idx == len(key_split) - 1 + elif key_name in ("module", "_orig_mod"): + continue + else: + return False + return True + + if not verify(key, fqn): + raise RuntimeError(f"An unexpected key, {key}, exists. FQN is {fqn}") + state_dict[fqn] = state_dict.pop(key) + + if info.submodule_prefixes: + new_state_dict: dict[str, ValueType] = {} + # TODO: make this faster. + for fqn in state_dict.keys(): + for prefix in info.submodule_prefixes: + if not fqn.startswith(prefix): + continue + if info.keep_submodule_prefixes: + new_state_dict[fqn] = state_dict[fqn] + else: + new_fqn = fqn[len(prefix) :] + new_state_dict[new_fqn] = state_dict[fqn] + state_dict = new_state_dict + + if info.ignore_frozen_params: + for key, param in model.named_parameters(): + if param.requires_grad: + continue + fqns = _get_fqns(model, key) + for fqn in fqns: + state_dict.pop(fqn) + + for key, p in list(state_dict.items()): + if torch.is_tensor(p) and p.is_meta: + state_dict.pop(key) + + return _maybe_full_or_cpu_state_dict(state_dict, info) + + +@torch.no_grad() +def _load_model_state_dict( + model: nn.Module, + state_dict: dict[str, ValueType], + info: _StateDictInfo, +) -> _IncompatibleKeys: + if not info.handle_model or (not state_dict and not info.broadcast_from_rank0): + return _IncompatibleKeys({}, {}) + + local_state_dict = {} + for key, value in _iterate_valid_model_state(model, info.dsd_fqn_modifiers): + fqns = _get_fqns(model, key, info.dsd_fqn_modifiers) + fqns_with_prefix = _get_fqns( + model, + key, + info.dsd_fqn_modifiers, + skip_ddp_prefix=False, + skip_compiler_prefix=False, + ) + + for fqn, fqn_with_prefix in zip(fqns, fqns_with_prefix, strict=False): + if (not info.broadcast_from_rank0 or dist.get_rank() == 0) and fqn != fqn_with_prefix: + load_value = state_dict.pop(fqn, None) + if load_value is None: + if info.strict: + raise RuntimeError(f"Missing key: {fqn}.") + else: + state_dict[fqn_with_prefix] = load_value + local_state_dict[fqn_with_prefix] = value + + assign = False + if info.broadcast_from_rank0 or info.full_state_dict: + devices = set() + for key, value in local_state_dict.items(): + if torch.is_tensor(value) and value.dim() > 0: + devices.add(value.device) + # In lora state_dict, there could be multiple devices, with meta device inside. + # Take the other device in the broadcast/distribtue, and set assign to True + if torch.device("meta") in devices: + devices.remove(torch.device("meta")) + assign = True + if len(devices) == 0: + devices.add(dist.distributed_c10d._get_pg_default_device()) + elif len(devices) > 1: + raise ValueError("Multiple devices found") + + if info.broadcast_from_rank0: + _broadcast_state_dict( + state_dict, + local_state_dict, + device=devices.pop(), + strict=info.strict, + cpu_offload=info.cpu_offload, + ) + elif info.full_state_dict: + _distribute_state_dict(state_dict, local_state_dict, device=devices.pop()) + for fqn, local_state in local_state_dict.items(): + state_dict[fqn] = local_state + + with info.fsdp_context(): + return cast( + _IncompatibleKeys, + _state_dict_fn(model, "load_state_dict")(state_dict=state_dict, strict=info.strict, assign=assign), + ) + + +def _init_optim_state(optim: torch.optim.Optimizer) -> None: + """ + Initialize optim states by calling the step() with zero grads. + """ + if optim.state: + # The optimizer state is initialized. + return + + # There are some stateless optimizers like SGD. These optimizer will + # not return in the above condition. So if gradients exist, we should also + # return. If gradients do not exist, the following initialization should + # not disturb SGD because the gradients and lr are both zero. + for param_group in optim.param_groups: + for param in param_group[_PARAMS]: + if param.grad is not None: + return + + for param_group in optim.param_groups: + for param in param_group[_PARAMS]: + if param.requires_grad: + param.grad = torch.zeros_like(param) + + # Some optimizers will update parameters regardless of grads due to lr, so + # make lr to zero when calling `step()`. + lrs = [] + for param_group in optim.param_groups: + if "lr" in param_group: + lrs.append(param_group["lr"]) + param_group["lr"] = torch.tensor(0.0) if isinstance(param_group["lr"], torch.Tensor) else 0.0 + optim.step(closure=None) + # Whether to recover the "lr" should not matter too much as we will + # restore checkpointing later. + for param_group in optim.param_groups: + if "lr" in param_group: + param_group["lr"] = lrs.pop(0) + optim.zero_grad(set_to_none=True) + + +def _flatten_optim_state_dict(state_dict: OptimizerStateType) -> dict[str, ValueType]: + """ + This API flattens the optimizer state_dict to support optimizer resharding for + MPMD, e.g., pipeline parallelism. + + Without the API, the original optimizer state_dict looks like: + { + "state": { + "layer1.weight": { + "step": 10, "exp_avg": SomeTensor, "exp_avg_sq": SomeTensor + }, + "layer2.weight": { + "step": 10, "exp_avg": SomeTensor, "exp_avg_sq": SomeTensor + }, + }, + "param_group": [ + { + "lr": 0.0, + "betas": (0.9, 0.95), ..., + "params": ["layer1.weight", "layer2.weight"] + } + ] + } + + With this API, the optimizer state_dict looks like: + { + "state.layer1.weight.step": 10, + "state.layer2.weight.step": 10, + "state.layer1.weight.exp_avg": SomeTensor, + "state.layer2.weight.exp_avg": SomeTensor, + "state.layer1.weight.exp_avg_sq": SomeTensor, + "state.layer2.weight.exp_avg_sq": SomeTensor, + "param_group.layer1.weight.lr" : 0.1, + "param_group.layer2.weight.lr" : 0.1, + "param_group.layer1.weight.betas" : (0.9, 0.95), + "param_group.layer2.weight.betas" : (0.9, 0.95), + } + + Note that if any of the value is a container, like the betas in the example, + this API won't flattent it. + """ + + def _raise_if_type_not_supported(v): + if not isinstance(v, (torch.Tensor, int, float)): + raise NotImplementedError( + f"Flattening optimizer state_dict only supports tensor, int, float states now. Type is {type(v)}." + ) + + ret: dict[str, ValueType] = {} + for fqn, state in cast(DictValueType, state_dict[_STATE]).items(): + for k, v in cast(DictValueType, state).items(): + _raise_if_type_not_supported(v) + ret[f"{_STATE}.{fqn}.{k}"] = v + + for param_group in cast(ListDictValueType, state_dict[_PG]): + fqns = param_group.pop(_PARAMS) + for fqn in cast(list[str], fqns): + for k, v in param_group.items(): + ret[f"{_PG}.{fqn}.{k}"] = v + return ret + + +def _unflatten_optim_state_dict( + optim: torch.optim.Optimizer, + state_dict: dict[str, ValueType], + info: _StateDictInfo, +) -> OptimizerStateType: + """ + This API unflattens the state_dict generated by _flatten_optim_state_dict(). + See the docstring of _flatten_optim_state_dict() for more detail. + """ + state: DictValueType = {} + pg_state: ListDictValueType = [] + return_osd: OptimizerStateType = {_STATE: state, _PG: pg_state} + + for param_group in optim.param_groups: + pg_state.append({_PARAMS: []}) + for param in param_group[_PARAMS]: + for fqn in info.fqn_param_mapping[param]: + # If a parameter is shared, only one of the FQN will be used. + # So we need to verify which if this fqn is actually used in + # the state_dict. + if fqn in info.shared_params_mapping: + in_params = False + for k in param_group.keys(): + if k == _PARAMS: + continue + flatten_key = f"{_PG}.{fqn}.{k}" + if flatten_key in state_dict: + in_params = True + break + else: + in_params = True + + if not in_params: + continue + + params = pg_state[-1][_PARAMS] + assert isinstance(params, list) # typing + params.append(fqn) + if not param.requires_grad: + continue + state[fqn] = {} + for state_name in optim.state[param].keys(): + cast(DictValueType, state[fqn])[state_name] = state_dict[f"{_STATE}.{fqn}.{state_name}"] + + first_param_fqn = cast(list[str], pg_state[-1][_PARAMS])[0] + for k in param_group.keys(): + if k == _PARAMS: + continue + value = state_dict[f"{_PG}.{first_param_fqn}.{k}"] + if k not in pg_state[-1]: + pg_state[-1][k] = value + elif pg_state[-1][k] != value: + raise RuntimeError( + "All the parameters in the same parameter group should have " + f"the same saved param_group value. But {first_param_fqn}.{k} " + f"is {value} while other(s) is {pg_state[-1][k]}." + ) + + return return_osd + + +@torch.no_grad() +def _get_optim_state_dict( + model: nn.Module, + optimizers: tuple[torch.optim.Optimizer, ...], + info: _StateDictInfo, +) -> OptimizerStateType: + if not info.handle_optim: + return {} + + optim_state_dict: OptimizerStateType = {_STATE: {}, _PG: []} + for optim in optimizers: + _init_optim_state(optim) + osd = _state_dict_fn(optim, "state_dict")() + if info.fsdp_modules: + with info.fsdp_context(): + osd = FSDP.optim_state_dict(model, optim, osd) + + # We need to specially handle FlatParameter FSDP as + # FlatParameter FSDP converts the FQNs. + # There are no easy ways to do this conversion systematically. + # We can only use a string replacment without correctness check. + if not osd: + continue + for k in list(osd[_STATE].keys()): + if "_orig_mod" in k: + osd[_STATE][k.replace("_orig_mod.", "")] = osd[_STATE].pop(k) + for g in osd[_PG]: + params = [k.replace("_orig_mod.", "") for k in g[_PARAMS]] + g[_PARAMS] = params + else: + params = list(chain.from_iterable(g[_PARAMS] for g in optim.param_groups)) + param_pid_mapping = dict(zip(params, range(len(params)), strict=False)) + fqn_pid_mapping = {} + for key, param in model.named_parameters(): + fqns = _get_fqns(model, key) + assert len(fqns) == 1 + fqn = next(iter(fqns)) + if param not in param_pid_mapping: + continue + pid = param_pid_mapping[param] + fqn_pid_mapping[fqn] = pid + fqn_pid_mapping[pid] = fqn + + for key in list(osd[_STATE].keys()): + fqn = fqn_pid_mapping[key] + osd[_STATE][fqn] = osd[_STATE].pop(key) + + for group in osd[_PG]: + group[_PARAMS] = [fqn_pid_mapping[pid] for pid in group[_PARAMS]] + + if not osd: + continue + + cast(DictValueType, optim_state_dict[_STATE]).update(osd[_STATE]) + cast(ListDictValueType, optim_state_dict[_PG]).extend(osd[_PG]) + + if info.flatten_optimizer_state_dict: + optim_state_dict = cast(OptimizerStateType, _flatten_optim_state_dict(optim_state_dict)) + + return _maybe_full_or_cpu_state_dict(optim_state_dict, info) + + +def _split_optim_state_dict( + model: nn.Module, + optim: torch.optim.Optimizer, + optim_state_dict: OptimizerStateType, + info: _StateDictInfo, +) -> OptimizerStateType: + """ + Extract the corresponding optim state_dict from ``optim_state_dict`` for + ``optim`` and return the result optim state_dict. + + Args: + model (nn.Module): the root model. + optim (torch.optim.Optimizer): the optimizer. + optim_state_dict (Dict[str, ValueType]): the superset optim state_dict that + contains the optim state_dict of ``optim``. + info (_StateDictInfo): state dict information. + + Returns: + The optim state_dict of ``optim``. + """ + + state: DictValueType = {} + pg_state: ListDictValueType = [] + return_osd: OptimizerStateType = {_STATE: state, _PG: pg_state} + pg_mapping: dict[int, int] = {} + + if all(isinstance(k, int) for k in cast(DictValueType, optim_state_dict[_STATE]).keys()): + return optim_state_dict + + for param_group in optim.param_groups: + pg_state.append({_PARAMS: []}) + for param in param_group[_PARAMS]: + for fqn in info.fqn_param_mapping[param]: + if fqn in info.shared_params_mapping: + in_params = False + for loaded_param_group in cast(ListDictValueType, optim_state_dict[_PG]): + if fqn in cast(list[str], loaded_param_group[_PARAMS]): + in_params = True + break + else: + in_params = True + if not in_params: + continue + + params = pg_state[-1][_PARAMS] + assert isinstance(params, list) + params.append(fqn) + if param.requires_grad: + state[fqn] = cast(DictValueType, optim_state_dict[_STATE])[fqn] + for loaded_param_group in cast(ListDictValueType, optim_state_dict[_PG]): + if fqn in cast(list[str], loaded_param_group[_PARAMS]): + pg_mapping[id(loaded_param_group)] = len(return_osd[_PG]) - 1 + + if len(param_group[_PARAMS]) == 0: + # Param_group with empty params. + ret = [] + for loaded_param_group in cast(ListDictValueType, optim_state_dict[_PG]): + if len(cast(list[str], loaded_param_group[_PARAMS])) == 0: + ret.append(loaded_param_group) + if len(ret) != 1: + raise ValueError( + "There are param groups that have zero parameters. " + "In such a case, DSD only support exactly one param group " + "with zero parameters." + "But the loaded state_dict has zero or more than one param groups " + "that have zero parameters." + ) + if len(optim_state_dict[_PG]) != len(optim.param_groups): + raise ValueError( + "When there is a parameter group that has zero parameters, multiple optimizers are not supported." + ) + pg_mapping[id(loaded_param_group)] = len(return_osd[_PG]) - 1 + + for param_group in cast(ListDictValueType, optim_state_dict[_PG]): + pg_idx = pg_mapping.get(id(param_group), -1) + if pg_idx == -1: + continue + + for key, value in param_group.items(): + if key == _PARAMS: + continue + # TODO: check if value is the same if exists. + pg_state[pg_idx][key] = value + + return return_osd + + +@torch.no_grad() +def _load_optim_state_dict( + model: nn.Module, + optimizers: tuple[torch.optim.Optimizer, ...], + state_dict: OptimizerStateType, + info: _StateDictInfo, +) -> None: + if not info.handle_optim: + return + + for optim in optimizers: + _init_optim_state(optim) + if state_dict: + if _STATE in state_dict: + optim_state_dict = _split_optim_state_dict(model, optim, state_dict, info) + else: + optim_state_dict = _unflatten_optim_state_dict(optim, cast(dict[str, ValueType], state_dict), info) + else: + optim_state_dict = {} + if info.fsdp_modules: + # We need to specially handle FlatParameter FSDP as + # FlatParameter FSDP converts the FQNs. + for original_fqn, _ in model.named_parameters(): + fqns = _get_fqns(model, original_fqn) + fqns_with_compiler = _get_fqns(model, original_fqn, skip_compiler_prefix=False) + if fqns == fqns_with_compiler: + continue + + assert len(fqns) == 1 + fqn = fqns.pop() + fqn_with_compiler = fqns_with_compiler.pop() + for g in optim_state_dict[_PG]: + val = cast(dict[str, Any], g) + params = [key.replace(fqn, fqn_with_compiler) for key in val[_PARAMS]] + val[_PARAMS] = params + osd_state = cast(DictValueType, optim_state_dict[_STATE]) + for k in list(osd_state.keys()): + if fqn in k: + osd_state[k.replace(fqn, fqn_with_compiler)] = osd_state.pop(k) + + with info.fsdp_context(): + optim_state_dict = FSDP.optim_state_dict_to_load(model, optim, optim_state_dict) + elif info.full_state_dict: + info.full_state_dict = False + local_state_dict = _get_optim_state_dict(model, (optim,), info) + info.full_state_dict = True + device = None + + def _device(t): + if t.dim() > 0: + nonlocal device + if device is None: + device = t.device + elif device != t.device: + raise ValueError("Device mismatch") + return t + + _ = tree_map_only(torch.Tensor, _device, local_state_dict) + assert device is not None + flatten_osd, osd_mapping = _flatten_state_dict(optim_state_dict) + flatten_local_osd, local_osd_mapping = _flatten_state_dict(local_state_dict) + if info.broadcast_from_rank0: + _broadcast_state_dict(flatten_osd, flatten_local_osd, device=device) + else: + _distribute_state_dict(flatten_osd, flatten_local_osd, device=device) + # The modifications listed seek to address the problem where optim might possess + # dissimilar parameters in comparison to optim_state_dict. This is achieved by + # incorporating differential parameters within local, which may result in optim + # having additional parameters ultimately. + for optim_key in flatten_osd.keys(): + if optim_key not in flatten_local_osd: + assert optim_key in osd_mapping + flatten_local_osd[optim_key] = flatten_osd[optim_key] + local_osd_mapping[optim_key] = osd_mapping[optim_key] + optim_state_dict = _unflatten_state_dict(flatten_local_osd, local_osd_mapping) + for pg in optim_state_dict[_PG]: + if _PARAMS not in pg: + cast(dict[str, ValueType], pg)[_PARAMS] = [] + + # Note that we do not have to convert the FQN back to param id here if + # order in optim.param_groups[idx][_PARAMS] is the same as the one in + # optim_state_dict[_PG][idx][_PARAMS]. + _state_dict_fn(optim, "load_state_dict")(state_dict=optim_state_dict) + + +def get_model_state_dict( + model: nn.Module, + *, + submodules: Optional[set[nn.Module]] = None, + options: Optional[StateDictOptions] = None, +) -> dict[str, ValueType]: + """ + Return the model state_dict of ``model``. + + See ``get_state_dict`` for the detail usage. + + Args: + model (nn.Module): the nn.Module to the model. + submodules (deprecated): Optional[set[nn.Module]]: only return the model parameters + that belong to the submodules. + options (StateDictOptions): the options to control how + model state_dict and optimizer state_dict should be returned. See + `StateDictOptions` for the details. + + Returns: + The state_dict for ``model``. + + :rtype: typing.Dict[str, ValueType] + """ + with _gc_context(): + info = _verify_options( + model, + (), + optim_only=False, + submodules=submodules, + options=options, + ) + model_state_dict = _get_model_state_dict(model, info) + _verify_state_dict(model_state_dict, {}, info) + return model_state_dict + + +def get_optimizer_state_dict( + model: nn.Module, + optimizers: torch.optim.Optimizer | Iterable[torch.optim.Optimizer], + *, + submodules: Optional[set[nn.Module]] = None, + options: Optional[StateDictOptions] = None, +) -> OptimizerStateType: + """ + Return the combined state_dict for optimizers. + + See ``get_state_dict`` for the detail usage. + + Args: + model (nn.Module): the nn.Module to the model. + optimizers (Union[None, Optimizer, Iterable[Optimizer]]): + The optimizers that are used to optimize ``model``. + submodules (deprecated): Optional[set[nn.Module]]: only return the model parameters + that belong to the submodules. + options (StateDictOptions): the options to control how + model state_dict and optimizer state_dict should be returned. See + `StateDictOptions` for the details. + + Returns: + The state_dict for ``optimizers``. + + :rtype: OptimizerStateType + """ + with _gc_context(): + optimizers = (optimizers,) if isinstance(optimizers, torch.optim.Optimizer) else tuple(optimizers) + info = _verify_options( + model, + optimizers, + optim_only=True, + submodules=submodules, + options=options, + ) + optim_state_dict = _get_optim_state_dict(model, optimizers, info) + _verify_state_dict({}, optim_state_dict, info) + return optim_state_dict + + +def get_state_dict( + model: nn.Module, + optimizers: torch.optim.Optimizer | Iterable[torch.optim.Optimizer], + *, + submodules: Optional[set[nn.Module]] = None, + options: Optional[StateDictOptions] = None, +) -> tuple[dict[str, ValueType], OptimizerStateType]: + """ + Return the model state_dict and optimizers state_dict. + + ``get_state_dict`` can process any module that is parallelized by PyTorch + FSDP/fully_shard, DDP/replicate, tensor_parallel/parallelize_module, and any + combination of these parallelisms. The main functions of ``get_state_dict`` + are: 1.) returning a model and optimizer state_dict that can be resharded + with a different number of trainers and/or different parallelisms. + 2.) hiding the parallelism-specific state_dict APIs. Users don't have to call + these APIs. + 3.) sanity checking the result state_dict. + + The keys of the result state dictionary are the canonical FQNs (Fully + Qualified Names). A canonical FQN refers to the FQN based on a parameter's + position in an nn.Module hierarchy. More specifically, a canonical FQN to a + parameter is the FQN returned by ``module.named_parameters()`` or + ``module.named_buffers()`` when the module is not distributed by any + parallelisms. Since the optimizer internally uses parameter IDs to represent + a parameter, there will be a conversion from the parameter IDs to the + canonical FQNs when calling this API. + + ``get_state_dict`` can also process a module that is not parallelized. In + such a case, ``get_state_dict`` only performs one function -- converting the + optimizer parameter IDs to the canonical FQNs. + + Example: + >>> # xdoctest: +SKIP + >>> import torch + >>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP + >>> from torch.nn.parallel import DistributedDataParallel as DDP + >>> from torch.distributed.checkpoint.state_dict import get_state_dict + + >>> fsdp_model = FSDP(copy.deepcopy(model)) + >>> fsdp_optim = torch.optim.Adam(model.parameters(), lr=1e-3) + >>> ddp_model = DDP(copy.deepcopy(model)) + >>> ddp_optim = torch.optim.Adam(model.parameters(), lr=1e-3) + + + >>> ddp_state_dict, ddp_optim_state_dict = get_state_dict(ddp_model, ddp_optim) + >>> fsdp_state_dict, fsdp_optim_state_dict = get_state_dict( + ... fsdp_model, fsdp_optim + ... ) + + >>> # if we simply call ddp_model.state_dict() and fsdp_model.state_dict(), + >>> # the asserts will fail. + >>> assert ddp_state_dict == fsdp_state_dict + >>> assert ddp_optim_state == fsdp_optim_state_dict + + + Args: + model (nn.Module): the nn.Module to the model. + optimizers (Union[None, Optimizer, Iterable[Optimizer]]): + The optimizers that are used to optimize ``model``. + submodules (deprecated): Optional[set[nn.Module]]: only return the model parameters + that belong to the submodules. + options (StateDictOptions): the options to control how + model state_dict and optimizer state_dict should be returned. See + `StateDictOptions` for the details. + + Returns: + ``Tuple`` that contain model state_dict and optimizer state_dict. + + :rtype: typing.Tuple[typing.Dict[str, ValueType], OptimizerStateType] + """ + + with _gc_context(): + optimizers = (optimizers,) if isinstance(optimizers, torch.optim.Optimizer) else tuple(optimizers) + info = _verify_options( + model, + optimizers, + optim_only=False, + submodules=submodules, + options=options, + ) + model_state_dict = _get_model_state_dict(model, info) + optim_state_dict = _get_optim_state_dict(model, optimizers, info) + _verify_state_dict(model_state_dict, optim_state_dict, info) + return model_state_dict, optim_state_dict + + +def _unflatten_model_state_dict( + model: nn.Module, + state_dict: dict[nn.Module, dict[str, ValueType]] | dict[str, ValueType], +) -> dict[str, ValueType]: + if not state_dict: + return {} + + if isinstance(next(iter(state_dict.keys())), nn.Module): + warnings.warn( + "Passing model_state_dict as a ``Dict[nn.Module, Dict[str, Any]]``" + "is deprecated and will be removed in 2.5. If you need this " + "feature, please preprocessing the model_state_dict to achieve the " + "same functionality.", + FutureWarning, + ) + cast_state_dict = cast(dict[nn.Module, dict[str, ValueType]], state_dict) + new_state_dict: dict[str, ValueType] = {} + for submodule, sub_state_dict in cast_state_dict.items(): + for name, m in model.named_modules(): + if m != submodule: + continue + + fqns = _get_fqns(model, name) + assert len(fqns) == 1, "FQNs for a submodule should only have 1 element" + prefix = f"{next(iter(fqns))}." + new_state_dict.update({prefix + subfqn: value for subfqn, value in sub_state_dict.items()}) + return new_state_dict + else: + return cast(dict[str, ValueType], state_dict) + + +def set_model_state_dict( + model: nn.Module, + model_state_dict: dict[str, ValueType], + *, + options: Optional[StateDictOptions] = None, +) -> _IncompatibleKeys: + """Load the model state_dict. + + The counterpart of ``get_model_state_dict`` to set the state_dict to the + model. See ``set_state_dict`` for the detail usage. + + Args: + model (nn.Module): the nn.Module to the model. + model_state_dict: (Dict[str, ValueType]): + the model state_dict to load. If the key of the ``model_state_dict`` + is nn.Module, the key is a submodule of ``model`` and the value should + be the state_dict of the submodule. When loading the state_dict, + the prefix of the submodule will be append to the state_dict. + options (StateDictOptions): the options to control how + model state_dict and optimizer state_dict should be loaded. See + `StateDictOptions` for the details. + + Returns: + ``NamedTuple`` with ``missing_keys`` and ``unexpected_keys`` fields: + * **missing_keys** is a list of str containing the missing keys + * **unexpected_keys** is a list of str containing the unexpected keys + + :type model_state_dict: typing.Dict[str, ValueType] + """ + model_state_dict: dict[str, ValueType] = _unflatten_model_state_dict(model, model_state_dict) + with _gc_context(): + info = _verify_options(model, (), optim_only=False, options=options) + + _verify_state_dict(model_state_dict, {}, info) + return _load_model_state_dict(model, model_state_dict, info) + + +def set_optimizer_state_dict( + model: nn.Module, + optimizers: torch.optim.Optimizer | Iterable[torch.optim.Optimizer], + optim_state_dict: OptimizerStateType, + *, + options: Optional[StateDictOptions] = None, +) -> None: + """Load the optimizers state_dict. + + The counterpart of ``get_optimizer_state_dict`` to set the state_dict to the + optimizers. See ``set_state_dict`` for the detail usage. + + WARN: ``set_optimizer_state_dict`` can only be called before ``backward()`` or after + ``step()`` is called on the optimizers. Otherwise, the optimizer states won't be + initialized correctly. + + Args: + model (nn.Module): the nn.Module to the model. + optimizers (Union[Optimizer, Iterable[Optimizer]]): + The optimizers that are used to optimize ``model``. + optim_state_dict: OptimizerStateType: + the optimizer state_dict to load. + options (StateDictOptions): the options to control how + model state_dict and optimizer state_dict should be loaded. See + `StateDictOptions` for the details. + + Returns: + None + + :type optim_state_dict: typing.OptimizerStateType + """ + with _gc_context(): + optimizers = (optimizers,) if isinstance(optimizers, torch.optim.Optimizer) else tuple(optimizers) + info = _verify_options(model, optimizers, optim_only=True, options=options) + + _verify_state_dict({}, optim_state_dict, info) + _load_optim_state_dict(model, optimizers, optim_state_dict, info) + + +def set_state_dict( + model: nn.Module, + optimizers: torch.optim.Optimizer | Iterable[torch.optim.Optimizer], + *, + model_state_dict: dict[str, ValueType], + optim_state_dict: OptimizerStateType, + options: Optional[StateDictOptions] = None, +) -> _IncompatibleKeys: + """Load the model state_dict and optimizers state_dict. + + The counterpart of ``get_state_dict`` to set the state_dict to the model and + optimizers. The given ``model_state_dict`` and ``optim_state_dict`` do not + have to be returned by ``get_state_dict`` but must meet the following + requirements: 1) all FQNs are canonical FQNs as defined in ``get_state_dict``, + 2) if a tensor is sharded, it must be either a ShardedTensor or DTensor, + 3) optimizer state_dict cannot contain the parameter IDs; the keys should be + the canonical FQNs. + + WARN: ``set_state_dict`` can only be called before ``backward()`` or after ``step()`` + is called on the optimizers. Otherwise, the optimizer states won't be initialized + correctly. + + Args: + model (nn.Module): the nn.Module to the model. + optimizers (Union[Optimizer, Iterable[Optimizer]]): + The optimizers that are used to optimize ``model``. + model_state_dict: (Union[Dict[nn.Module, Dict[str, ValueType]], Dict[str, ValueType]]): + the model state_dict to load. If the key of the ``model_state_dict`` + is nn.Module, the key is a submodule of ``model`` and the value should + be the state_dict of the submodule. When loading the state_dict, + the prefix of the submodule will be append to the state_dict. + optim_state_dict: OptimizerStateType: + the optimizer state_dict to load. + options (StateDictOptions): the options to control how + model state_dict and optimizer state_dict should be loaded. See + `StateDictOptions` for the details. + + Returns: + ``NamedTuple`` with ``missing_keys`` and ``unexpected_keys`` fields: + * **missing_keys** is a list of str containing the missing keys of the model state_dict. + * **unexpected_keys** is a list of str containing the unexpected keys of the model state_dict. + + :type model_state_dict: typing.Dict[str, ValueType] + :type optim_state_dict: typing.OptimizerStateType + """ + + model_state_dict: dict[str, ValueType] = _unflatten_model_state_dict(model, model_state_dict) + with _gc_context(): + optimizers = (optimizers,) if isinstance(optimizers, torch.optim.Optimizer) else tuple(optimizers) + info = _verify_options(model, optimizers, optim_only=not model_state_dict, options=options) + + _verify_state_dict(model_state_dict, optim_state_dict, info) + _load_optim_state_dict(model, optimizers, optim_state_dict, info) + return _load_model_state_dict(model, model_state_dict, info) + + +# TODO: correct the state_dict function signature. +# TODO: this API is not yet fully tested. Make it private +@no_type_check +def _patch_model_state_dict( + model: nn.Module, + *, + options: Optional[StateDictOptions] = None, +) -> None: + """Patch the ``state_dict`` and ``load_state_dict`` attributes of ``model``. + + Patch the ``state_dict`` and ``load_state_dict`` attributes of ``model`` to + be a partial function to call ``get_state_dict`` and ``set_state_dict``. + + Example: + from torch.distributed.fsdp import FullyShardedDataParallel as FSDP + from torch.distributed.checkpoint.state_dict import patch_model_state_dict + + model = fsdp(model) + patch_model_state_dict(model) + + Args: + model (nn.Module): the nn.Module to the model. + options (StateDictOptions): the options to control how + model state_dict and optimizer state_dict should be loaded. See + `StateDictOptions` for the details. + Returns: + None + """ + + _state_dict_call = functools.partial( + get_model_state_dict, + model=model, + options=options, + ) + + def state_dict_call(): + return _state_dict_call() + + model.state_dict = state_dict_call + + _load_state_dict_call = functools.partial( + set_model_state_dict, + model=model, + options=options, + ) + + def load_state_dict_call(state_dict: dict[str, Any]): + _load_state_dict_call(model_state_dict=state_dict) + + model.load_state_dict = load_state_dict_call + + _patched_state_dict.add(state_dict_call) + _patched_state_dict.add(load_state_dict_call) + + +# TODO: correct the load_state_dict function signature. +# TODO: this API is not yet fully tested. Make it private +@no_type_check +def _patch_optimizer_state_dict( + model: nn.Module, + *, + optimizers: tuple[torch.optim.Optimizer, ...], + options: Optional[StateDictOptions] = None, +) -> None: + """Patch the ``state_dict`` and ``load_state_dict`` attributes of ``optimizers``. + + Patch the ``state_dict`` and ``load_state_dict`` attributes of ``optimizers`` to + be a partial function to call ``get_state_dict`` and ``set_state_dict``. + + Note that if there are multiple optimizers, all of the optimizers will be patched. + So users only need to call one of the state_dict() to get the full result. + + Example: + from torch.distributed.fsdp import FullyShardedDataParallel as FSDP + from torch.distributed.checkpoint.state_dict import patch_model_state_dict + + model = fsdp(model) + patch_model_state_dict(model) + + Args: + model (nn.Module): the nn.Module to the model. + options (StateDictOptions): the options to control how + model state_dict and optimizer state_dict should be loaded. See + `StateDictOptions` for the details. + Returns: + None + """ + + _state_dict_call = functools.partial( + get_optimizer_state_dict, + model=model, + optimizers=optimizers, + options=options, + ) + + def state_dict_call(): + return _state_dict_call() + + _load_state_dict_call = functools.partial( + set_optimizer_state_dict, + model=model, + optimizers=optimizers, + options=options, + ) + + def load_state_dict_call(state_dict: dict[str, Any]): + _load_state_dict_call(optim_state_dict=state_dict) + + _patched_state_dict.add(state_dict_call) + _patched_state_dict.add(load_state_dict_call) + optimizers = (optimizers,) if isinstance(optimizers, torch.optim.Optimizer) else tuple(optimizers) + for optim in optimizers: + optim.state_dict = state_dict_call + optim.load_state_dict = load_state_dict_call diff --git a/code/RL_model/verl/verl_train/verl/third_party/vllm/__init__.py b/code/RL_model/verl/verl_train/verl/third_party/vllm/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2f6646f3b6939851190bc9ecf6b6e0b1cb8e63d5 --- /dev/null +++ b/code/RL_model/verl/verl_train/verl/third_party/vllm/__init__.py @@ -0,0 +1,64 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from importlib.metadata import PackageNotFoundError, version + +from packaging import version as vs + +from verl.utils.device import is_npu_available +from verl.utils.import_utils import is_sglang_available + + +def get_version(pkg): + try: + return version(pkg) + except PackageNotFoundError: + return None + + +package_name = "vllm" +package_version = get_version(package_name) +vllm_version = None +VLLM_SLEEP_LEVEL = 1 + +if package_version is None: + if not is_sglang_available(): + raise ValueError( + f"vllm version {package_version} not supported and SGLang also not Found. Currently supported " + f"vllm versions are 0.7.0+" + ) +elif is_npu_available: + # sleep_mode=2 is not supported on vllm-ascend for now, will remove this restriction when this ability is ready. + VLLM_SLEEP_LEVEL = 1 + from vllm import LLM + from vllm.distributed import parallel_state +elif vs.parse(package_version) >= vs.parse("0.7.0"): + vllm_version = package_version + if vs.parse(package_version) >= vs.parse("0.8.5"): + VLLM_SLEEP_LEVEL = 2 + from vllm import LLM + from vllm.distributed import parallel_state +else: + if vs.parse(package_version) in [vs.parse("0.5.4"), vs.parse("0.6.3")]: + raise ValueError( + f"vLLM version {package_version} support has been removed. vLLM 0.5.4 and 0.6.3 are no longer " + f"supported. Please use vLLM 0.7.0 or later." + ) + if not is_sglang_available(): + raise ValueError( + f"vllm version {package_version} not supported and SGLang also not Found. Currently supported " + f"vllm versions are 0.7.0+" + ) + +__all__ = ["LLM", "parallel_state"] diff --git a/code/RL_model/verl/verl_train/verl/trainer/config/__init__.py b/code/RL_model/verl/verl_train/verl/trainer/config/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..402475c3f0bac4aaea8ec15a9c2b24bf07fdf0e4 --- /dev/null +++ b/code/RL_model/verl/verl_train/verl/trainer/config/__init__.py @@ -0,0 +1,19 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from . import algorithm, config +from .algorithm import * # noqa: F401 +from .config import * # noqa: F401 + +__all__ = config.__all__ + algorithm.__all__ diff --git a/code/RL_model/verl/verl_train/verl/trainer/config/_generated_ppo_megatron_trainer.yaml b/code/RL_model/verl/verl_train/verl/trainer/config/_generated_ppo_megatron_trainer.yaml new file mode 100644 index 0000000000000000000000000000000000000000..bc418f0b1fa9297f5c91ae6ae7d3d085e48570ea --- /dev/null +++ b/code/RL_model/verl/verl_train/verl/trainer/config/_generated_ppo_megatron_trainer.yaml @@ -0,0 +1,719 @@ +# This reference configration yaml is automatically generated via 'scripts/generate_trainer_config.sh' +# in which it invokes 'python3 scripts/print_cfg.py --cfg job --config-name=ppo_megatron_trainer.yaml' to flatten the 'verl/trainer/config/ppo_megatron_trainer.yaml' config fields into a single file. +# Do not modify this file directly. +# The file is usually only for reference and never used. + +actor_rollout_ref: + actor: + optim: + _target_: verl.workers.config.McoreOptimizerConfig + lr: 1.0e-06 + lr_warmup_steps_ratio: 0.0 + total_training_steps: -1 + weight_decay: 0.01 + lr_warmup_steps: -1 + betas: + - 0.9 + - 0.999 + clip_grad: 1.0 + optimizer: adam + lr_warmup_init: 0.0 + lr_decay_steps: null + lr_decay_style: constant + min_lr: 0.0 + weight_decay_incr_style: constant + lr_wsd_decay_style: exponential + lr_wsd_decay_steps: null + use_checkpoint_opt_param_scheduler: false + override_optimizer_config: {} + megatron: + _target_: verl.workers.config.McoreEngineConfig + param_offload: false + grad_offload: false + optimizer_offload: false + tensor_model_parallel_size: 1 + expert_model_parallel_size: 1 + expert_tensor_parallel_size: null + pipeline_model_parallel_size: 1 + virtual_pipeline_model_parallel_size: null + context_parallel_size: 1 + sequence_parallel: true + use_distributed_optimizer: true + use_dist_checkpointing: false + dist_checkpointing_path: null + dist_checkpointing_prefix: '' + seed: 42 + override_ddp_config: {} + override_transformer_config: + recompute_granularity: null + recompute_modules: + - core_attn + recompute_method: null + recompute_num_layers: null + attention_backend: flash + override_mcore_model_config: {} + use_mbridge: true + vanilla_mbridge: true + use_remove_padding: true + forward_only: false + dtype: bfloat16 + _target_: verl.workers.config.McoreActorConfig + rollout_n: ${oc.select:actor_rollout_ref.rollout.n,1} + strategy: megatron + ppo_mini_batch_size: 256 + ppo_micro_batch_size: null + ppo_micro_batch_size_per_gpu: null + use_dynamic_bsz: false + ppo_max_token_len_per_gpu: 16384 + clip_ratio: 0.2 + clip_ratio_low: 0.2 + clip_ratio_high: 0.2 + tau_pos: 1.0 + tau_neg: 1.05 + freeze_vision_tower: false + policy_loss: + _target_: verl.workers.config.PolicyLossConfig + loss_mode: vanilla + clip_cov_ratio: 0.0002 + clip_cov_lb: 1.0 + clip_cov_ub: 5.0 + kl_cov_ratio: 0.0002 + ppo_kl_coef: 0.1 + clip_ratio_c: 3.0 + loss_agg_mode: token-mean + loss_scale_factor: null + entropy_coeff: 0 + calculate_entropy: false + use_kl_loss: false + use_prefix_grouper: false + use_torch_compile: true + kl_loss_coef: 0.001 + kl_loss_type: low_var_kl + ppo_epochs: 1 + shuffle: false + data_loader_seed: 42 + checkpoint: + _target_: verl.trainer.config.CheckpointConfig + save_contents: + - model + - optimizer + - extra + load_contents: ${.save_contents} + async_save: false + use_fused_kernels: ${oc.select:actor_rollout_ref.model.use_fused_kernels,false} + profiler: + _target_: verl.utils.profiler.ProfilerConfig + tool: ${oc.select:global_profiler.tool,null} + enable: false + all_ranks: false + ranks: [] + save_path: ${oc.select:global_profiler.save_path,null} + tool_config: + nsys: + _target_: verl.utils.profiler.config.NsightToolConfig + discrete: ${oc.select:global_profiler.global_tool_config.nsys.discrete} + npu: + _target_: verl.utils.profiler.config.NPUToolConfig + contents: [] + level: level0 + analysis: true + discrete: false + torch: + _target_: verl.utils.profiler.config.TorchProfilerToolConfig + contents: [] + discrete: false + torch_memory: + _target_: verl.utils.profiler.config.TorchMemoryToolConfig + trace_alloc_max_entries: ${oc.select:global_profiler.global_tool_config.torch_memory.trace_alloc_max_entries,100000} + stack_depth: ${oc.select:global_profiler.global_tool_config.torch_memory.stack_depth,32} + router_replay: + _target_: verl.workers.config.RouterReplayConfig + mode: disabled + record_file: null + replay_file: null + load_weight: true + ref: + rollout_n: ${oc.select:actor_rollout_ref.rollout.n,1} + strategy: megatron + use_torch_compile: ${oc.select:actor_rollout_ref.actor.use_torch_compile,true} + log_prob_micro_batch_size: null + log_prob_micro_batch_size_per_gpu: null + log_prob_use_dynamic_bsz: ${oc.select:actor_rollout_ref.actor.use_dynamic_bsz,false} + log_prob_max_token_len_per_gpu: ${oc.select:actor_rollout_ref.actor.ppo_max_token_len_per_gpu,16384} + profiler: + _target_: verl.utils.profiler.ProfilerConfig + tool: ${oc.select:global_profiler.tool,null} + enable: false + all_ranks: false + ranks: [] + save_path: ${oc.select:global_profiler.save_path,null} + tool_config: + nsys: + _target_: verl.utils.profiler.config.NsightToolConfig + discrete: ${oc.select:global_profiler.global_tool_config.nsys.discrete} + npu: + _target_: verl.utils.profiler.config.NPUToolConfig + contents: [] + level: level0 + analysis: true + discrete: false + torch: + _target_: verl.utils.profiler.config.TorchProfilerToolConfig + contents: [] + discrete: false + torch_memory: + _target_: verl.utils.profiler.config.TorchMemoryToolConfig + trace_alloc_max_entries: ${oc.select:global_profiler.global_tool_config.torch_memory.trace_alloc_max_entries,100000} + stack_depth: ${oc.select:global_profiler.global_tool_config.torch_memory.stack_depth,32} + router_replay: + _target_: verl.workers.config.RouterReplayConfig + mode: disabled + record_file: null + replay_file: null + megatron: + _target_: verl.workers.config.McoreEngineConfig + param_offload: ${oc.select:actor_rollout_ref.actor.megatron.param_offload,False} + grad_offload: false + optimizer_offload: false + tensor_model_parallel_size: ${oc.select:actor_rollout_ref.actor.megatron.tensor_model_parallel_size,1} + expert_model_parallel_size: ${oc.select:actor_rollout_ref.actor.megatron.expert_model_parallel_size,1} + expert_tensor_parallel_size: ${oc.select:actor_rollout_ref.actor.megatron.expert_tensor_parallel_size,null} + pipeline_model_parallel_size: ${oc.select:actor_rollout_ref.actor.megatron.pipeline_model_parallel_size,1} + virtual_pipeline_model_parallel_size: ${oc.select:actor_rollout_ref.actor.megatron.virtual_pipeline_model_parallel_size,null} + context_parallel_size: ${oc.select:actor_rollout_ref.actor.megatron.context_parallel_size,1} + sequence_parallel: true + use_distributed_optimizer: true + use_dist_checkpointing: false + dist_checkpointing_path: null + dist_checkpointing_prefix: '' + seed: ${oc.select:actor_rollout_ref.actor.megatron.seed,42} + override_ddp_config: {} + override_transformer_config: ${oc.select:actor_rollout_ref.actor.megatron.override_transformer_config,{}} + override_mcore_model_config: {} + use_mbridge: ${oc.select:actor_rollout_ref.actor.megatron.use_mbridge,False} + vanilla_mbridge: ${oc.select:actor_rollout_ref.actor.megatron.vanilla_mbridge,True} + use_remove_padding: ${oc.select:actor_rollout_ref.actor.megatron.use_remove_padding,True} + forward_only: true + dtype: bfloat16 + _target_: verl.workers.config.McoreActorConfig + load_weight: true + rollout: + _target_: verl.workers.config.RolloutConfig + name: ??? + mode: async + temperature: 1.0 + top_k: -1 + top_p: 1 + prompt_length: ${oc.select:data.max_prompt_length,512} + response_length: ${oc.select:data.max_response_length,512} + dtype: bfloat16 + gpu_memory_utilization: 0.5 + ignore_eos: false + enforce_eager: false + cudagraph_capture_sizes: null + free_cache_engine: true + tensor_model_parallel_size: 2 + data_parallel_size: 1 + expert_parallel_size: 1 + pipeline_model_parallel_size: 1 + max_num_batched_tokens: 8192 + max_model_len: null + max_num_seqs: 1024 + enable_chunked_prefill: true + enable_prefix_caching: true + logprobs_mode: processed_logprobs + scheduling_policy: fcfs + load_format: dummy + log_prob_micro_batch_size: null + log_prob_micro_batch_size_per_gpu: null + log_prob_use_dynamic_bsz: ${oc.select:actor_rollout_ref.actor.use_dynamic_bsz,false} + log_prob_max_token_len_per_gpu: ${oc.select:actor_rollout_ref.actor.ppo_max_token_len_per_gpu,16384} + disable_log_stats: true + do_sample: true + 'n': 1 + over_sample_rate: 0 + multi_stage_wake_up: false + engine_kwargs: + vllm: {} + sglang: {} + trtllm: {} + val_kwargs: + _target_: verl.workers.config.SamplingConfig + top_k: -1 + top_p: 1.0 + temperature: 0 + 'n': 1 + do_sample: false + multi_turn: + _target_: verl.workers.config.MultiTurnConfig + enable: false + max_assistant_turns: null + tool_config_path: null + max_user_turns: null + max_parallel_calls: 1 + max_tool_response_length: 256 + tool_response_truncate_side: middle + interaction_config_path: null + use_inference_chat_template: false + tokenization_sanity_check_mode: strict + format: hermes + num_repeat_rollouts: null + calculate_log_probs: false + agent: + _target_: verl.workers.config.AgentLoopConfig + num_workers: 8 + default_agent_loop: single_turn_agent + agent_loop_config_path: null + custom_async_server: + _target_: verl.workers.config.CustomAsyncServerConfig + path: null + name: null + checkpoint_engine: + _target_: verl.workers.config.CheckpointEngineConfig + backend: naive + update_weights_bucket_megabytes: 2048 + engine_kwargs: {} + trace: + _target_: verl.workers.config.TraceConfig + backend: null + token2text: false + max_samples_per_step_per_worker: null + skip_rollout: false + skip_dump_dir: /tmp/rollout_dump + skip_tokenizer_init: true + enable_rollout_routing_replay: false + profiler: + _target_: verl.utils.profiler.ProfilerConfig + tool: ${oc.select:global_profiler.tool,null} + enable: ${oc.select:actor_rollout_ref.actor.profiler.enable,false} + all_ranks: ${oc.select:actor_rollout_ref.actor.profiler.all_ranks,false} + ranks: ${oc.select:actor_rollout_ref.actor.profiler.ranks,[]} + save_path: ${oc.select:global_profiler.save_path,null} + tool_config: ${oc.select:actor_rollout_ref.actor.profiler.tool_config,null} + prometheus: + _target_: verl.workers.config.PrometheusConfig + enable: false + port: 9090 + file: /tmp/ray/session_latest/metrics/prometheus/prometheus.yml + served_model_name: ${oc.select:actor_rollout_ref.model.path,null} + quantization: null + quantization_config_file: null + mtp: ${oc.select:actor_rollout_ref.model.mtp, null} + layer_name_map: + qkv_layer_name: qkv + gate_proj_layer_name: gate_up + model: + _target_: verl.workers.config.HFModelConfig + path: ~/models/deepseek-llm-7b-chat + hf_config_path: null + tokenizer_path: null + use_shm: false + trust_remote_code: false + custom_chat_template: null + external_lib: null + override_config: + model_config: {} + moe_config: + freeze_moe_router: false + enable_gradient_checkpointing: true + enable_activation_offload: false + use_remove_padding: false + lora_rank: 0 + lora_alpha: 16 + target_modules: all-linear + exclude_modules: null + lora_adapter_path: null + use_liger: false + use_fused_kernels: false + fused_kernel_options: + impl_backend: torch + tiled_mlp: + enabled: false + num_shards: 4 + mtp: + _target_: verl.workers.config.MtpConfig + enable: false + enable_train: false + enable_rollout: false + detach_encoder: false + mtp_loss_scaling_factor: 0.1 + speculative_algorithm: EAGLE + speculative_num_steps: 3 + speculative_eagle_topk: 1 + speculative_num_draft_tokens: 4 + method: mtp + num_speculative_tokens: 1 + lora: + type: lora + merge: false + rank: 0 + alpha: 32 + dropout: 0.0 + target_modules: + - linear_qkv + - linear_proj + - linear_fc1 + - linear_fc2 + exclude_modules: [] + dropout_position: pre + lora_A_init_method: xavier + lora_B_init_method: zero + a2a_experimental: false + dtype: null + adapter_path: null + freeze_vision_model: true + freeze_vision_projection: true + freeze_language_model: true + hybrid_engine: true + nccl_timeout: 600 +data: + tokenizer: null + use_shm: false + train_files: ~/data/rlhf/gsm8k/train.parquet + val_files: ~/data/rlhf/gsm8k/test.parquet + train_max_samples: -1 + val_max_samples: -1 + prompt_key: prompt + reward_fn_key: data_source + max_prompt_length: 512 + max_response_length: 512 + train_batch_size: 1024 + val_batch_size: null + tool_config_path: ${oc.select:actor_rollout_ref.rollout.multi_turn.tool_config_path, + null} + return_raw_input_ids: false + return_raw_chat: true + return_full_prompt: false + shuffle: true + seed: null + dataloader_num_workers: 8 + image_patch_size: 14 + validation_shuffle: false + filter_overlong_prompts: false + filter_overlong_prompts_workers: 1 + truncation: error + image_key: images + video_key: videos + trust_remote_code: false + custom_cls: + path: null + name: null + return_multi_modal_inputs: true + sampler: + class_path: null + class_name: null + datagen: + path: null + name: null + apply_chat_template_kwargs: {} +reward_manager: + _target_: verl.trainer.config.config.RewardManagerConfig + source: register + name: ${oc.select:reward_model.reward_manager,naive} + module: + _target_: verl.trainer.config.config.ModuleConfig + path: null + name: custom_reward_manager +critic: + optim: + _target_: verl.workers.config.McoreOptimizerConfig + lr: 1.0e-05 + lr_warmup_steps_ratio: 0.0 + total_training_steps: -1 + weight_decay: 0.01 + lr_warmup_steps: -1 + betas: + - 0.9 + - 0.999 + clip_grad: 1.0 + optimizer: adam + lr_warmup_init: 0.0 + lr_decay_steps: null + lr_decay_style: constant + min_lr: 0.0 + weight_decay_incr_style: constant + lr_wsd_decay_style: exponential + lr_wsd_decay_steps: null + use_checkpoint_opt_param_scheduler: false + override_optimizer_config: {} + megatron: + _target_: verl.workers.config.McoreEngineConfig + param_offload: false + grad_offload: false + optimizer_offload: false + tensor_model_parallel_size: 1 + expert_model_parallel_size: 1 + expert_tensor_parallel_size: null + pipeline_model_parallel_size: 1 + virtual_pipeline_model_parallel_size: null + context_parallel_size: 1 + sequence_parallel: true + use_distributed_optimizer: true + use_dist_checkpointing: false + dist_checkpointing_path: null + dist_checkpointing_prefix: '' + seed: 42 + override_ddp_config: {} + override_transformer_config: + recompute_granularity: null + recompute_modules: + - core_attn + recompute_method: null + recompute_num_layers: null + attention_backend: flash + override_mcore_model_config: {} + use_mbridge: true + vanilla_mbridge: true + use_remove_padding: true + forward_only: false + dtype: bfloat16 + _target_: verl.workers.config.McoreCriticConfig + rollout_n: ${oc.select:actor_rollout_ref.rollout.n,1} + strategy: megatron + enable: null + model: + path: ~/models/deepseek-llm-7b-chat + tokenizer_path: ${oc.select:actor_rollout_ref.model.path,"~/models/deepseek-llm-7b-chat"} + override_config: + model_config: {} + moe_config: + freeze_moe_router: false + external_lib: ${oc.select:actor_rollout_ref.model.external_lib,null} + trust_remote_code: ${oc.select:actor_rollout_ref.model.trust_remote_code,false} + _target_: verl.trainer.config.BaseModelConfig + lora: + type: lora + rank: 0 + alpha: 32 + dropout: 0.0 + target_modules: + - linear_qkv + - linear_proj + - linear_fc1 + - linear_fc2 + exclude_modules: [] + dropout_position: pre + lora_A_init_method: xavier + lora_B_init_method: zero + a2a_experimental: false + dtype: null + adapter_path: null + freeze_vision_model: true + freeze_vision_projection: true + freeze_language_model: true + ppo_mini_batch_size: ${oc.select:actor_rollout_ref.actor.ppo_mini_batch_size,256} + ppo_micro_batch_size: null + ppo_micro_batch_size_per_gpu: ${oc.select:.ppo_micro_batch_size,null} + use_dynamic_bsz: ${oc.select:actor_rollout_ref.actor.use_dynamic_bsz,false} + ppo_max_token_len_per_gpu: 32768 + forward_max_token_len_per_gpu: ${.ppo_max_token_len_per_gpu} + ppo_epochs: ${oc.select:actor_rollout_ref.actor.ppo_epochs,1} + shuffle: ${oc.select:actor_rollout_ref.actor.shuffle,false} + data_loader_seed: ${oc.select:actor_rollout_ref.actor.data_loader_seed,null} + cliprange_value: 0.5 + loss_agg_mode: ${oc.select:actor_rollout_ref.actor.loss_agg_mode,token-mean} + checkpoint: + _target_: verl.trainer.config.CheckpointConfig + save_contents: + - model + - optimizer + - extra + load_contents: ${.save_contents} + async_save: false + profiler: + _target_: verl.utils.profiler.ProfilerConfig + tool: ${oc.select:global_profiler.tool,null} + enable: false + all_ranks: false + ranks: [] + save_path: ${oc.select:global_profiler.save_path,null} + tool_config: + nsys: + _target_: verl.utils.profiler.config.NsightToolConfig + discrete: ${oc.select:global_profiler.global_tool_config.nsys.discrete} + npu: + _target_: verl.utils.profiler.config.NPUToolConfig + contents: [] + level: level0 + analysis: true + discrete: false + torch: + _target_: verl.utils.profiler.config.TorchProfilerToolConfig + contents: [] + discrete: false + torch_memory: + _target_: verl.utils.profiler.config.TorchMemoryToolConfig + trace_alloc_max_entries: ${oc.select:global_profiler.global_tool_config.torch_memory.trace_alloc_max_entries,100000} + stack_depth: ${oc.select:global_profiler.global_tool_config.torch_memory.stack_depth,32} + nccl_timeout: 600 + load_weight: true +reward_model: + enable: false + enable_resource_pool: false + n_gpus_per_node: 8 + nnodes: 0 + strategy: megatron + model: + input_tokenizer: ${actor_rollout_ref.model.path} + path: ~/models/FsfairX-LLaMA3-RM-v0.1 + external_lib: ${actor_rollout_ref.model.external_lib} + trust_remote_code: false + override_config: {} + micro_batch_size: null + micro_batch_size_per_gpu: null + max_length: null + use_dynamic_bsz: ${critic.use_dynamic_bsz} + forward_max_token_len_per_gpu: ${critic.forward_max_token_len_per_gpu} + reward_manager: naive + reward_loop_source: register + reward_loop_module_path: null + reward_loop_class_name: null + launch_reward_fn_async: false + sandbox_fusion: + url: null + max_concurrent: 64 + memory_limit_mb: 1024 + profiler: + _target_: verl.utils.profiler.ProfilerConfig + tool: ${oc.select:global_profiler.tool,null} + enable: false + all_ranks: false + ranks: [] + save_path: ${oc.select:global_profiler.save_path,null} + tool_config: ${oc.select:actor_rollout_ref.actor.profiler.tool_config,null} + nccl_timeout: 600 + megatron: + _target_: verl.workers.config.MegatronEngineConfig + param_offload: false + tensor_model_parallel_size: 1 + expert_model_parallel_size: 1 + expert_tensor_parallel_size: null + pipeline_model_parallel_size: 1 + virtual_pipeline_model_parallel_size: null + context_parallel_size: 1 + sequence_parallel: true + use_distributed_optimizer: false + use_dist_checkpointing: false + dist_checkpointing_path: null + dist_checkpointing_prefix: '' + seed: ${oc.select:actor_rollout_ref.actor.megatron.seed,42} + override_transformer_config: ${oc.select:actor_rollout_ref.actor.megatron.override_transformer_config,{}} + use_mbridge: ${oc.select:actor_rollout_ref.actor.megatron.use_mbridge,False} + vanilla_mbridge: ${oc.select:actor_rollout_ref.actor.megatron.vanilla_mbridge,True} + use_remove_padding: ${oc.select:actor_rollout_ref.actor.megatron.use_remove_padding,True} + dtype: bfloat16 + load_weight: true + use_reward_loop: true + num_workers: 1 + rollout: + _target_: verl.workers.config.RolloutConfig + name: ??? + dtype: bfloat16 + gpu_memory_utilization: 0.5 + enforce_eager: true + cudagraph_capture_sizes: null + free_cache_engine: true + data_parallel_size: 1 + expert_parallel_size: 1 + tensor_model_parallel_size: 2 + max_num_batched_tokens: 8192 + max_model_len: null + max_num_seqs: 1024 + load_format: auto + engine_kwargs: {} + limit_images: null + enable_chunked_prefill: true + enable_prefix_caching: true + disable_log_stats: true + skip_tokenizer_init: false + prompt_length: 2048 + response_length: 2048 +algorithm: + rollout_correction: + rollout_is: null + rollout_is_threshold: 2.0 + rollout_rs: null + rollout_rs_threshold: null + bypass_mode: false + loss_type: ppo_clip + rollout_is_batch_normalize: false + _target_: verl.trainer.config.AlgoConfig + gamma: 1.0 + lam: 1.0 + adv_estimator: gae + norm_adv_by_std_in_grpo: true + use_kl_in_reward: false + kl_penalty: kl + kl_ctrl: + _target_: verl.trainer.config.KLControlConfig + type: fixed + kl_coef: 0.001 + horizon: 10000 + target_kl: 0.1 + use_pf_ppo: false + pf_ppo: + reweight_method: pow + weight_pow: 2.0 +custom_reward_function: + path: null + name: compute_score +trainer: + balance_batch: true + total_epochs: 30 + total_training_steps: null + project_name: verl_examples + experiment_name: gsm8k + logger: + - console + - wandb + log_val_generations: 0 + nnodes: 1 + n_gpus_per_node: 8 + save_freq: -1 + esi_redundant_time: 0 + resume_mode: auto + resume_from_path: null + del_local_ckpt_after_load: false + val_before_train: true + test_freq: -1 + critic_warmup: 0 + default_hdfs_dir: null + default_local_dir: checkpoints/${trainer.project_name}/${trainer.experiment_name} + max_actor_ckpt_to_keep: null + max_critic_ckpt_to_keep: null + ray_wait_register_center_timeout: 300 + device: cuda + rollout_data_dir: null + use_legacy_worker_impl: auto +global_profiler: + _target_: verl.utils.profiler.ProfilerConfig + tool: null + steps: null + profile_continuous_steps: false + save_path: outputs/profile + global_tool_config: + nsys: + discrete: false + controller_nsight_options: + trace: cuda,nvtx,cublas,ucx + cuda-memory-usage: 'true' + cuda-graph-trace: graph + worker_nsight_options: + trace: cuda,nvtx,cublas,ucx + cuda-memory-usage: 'true' + cuda-graph-trace: graph + capture-range: cudaProfilerApi + capture-range-end: null + kill: none + torch_memory: + trace_alloc_max_entries: 100000 + stack_depth: 32 + context: all + stacks: all + kw_args: {} +transfer_queue: + enable: false +ray_kwargs: + ray_init: + num_cpus: null + timeline_json_file: null diff --git a/code/RL_model/verl/verl_train/verl/trainer/config/_generated_ppo_trainer.yaml b/code/RL_model/verl/verl_train/verl/trainer/config/_generated_ppo_trainer.yaml new file mode 100644 index 0000000000000000000000000000000000000000..a3baaf52af3e16782f4cff8eaf2651021b9e0060 --- /dev/null +++ b/code/RL_model/verl/verl_train/verl/trainer/config/_generated_ppo_trainer.yaml @@ -0,0 +1,653 @@ +# This reference configration yaml is automatically generated via 'scripts/generate_trainer_config.sh' +# in which it invokes 'python3 scripts/print_cfg.py --cfg job ' to flatten the 'verl/trainer/config/ppo_trainer.yaml' config fields into a single file. +# Do not modify this file directly. +# The file is usually only for reference and never used. + +actor_rollout_ref: + actor: + optim: + _target_: verl.workers.config.FSDPOptimizerConfig + optimizer: AdamW + optimizer_impl: torch.optim + lr: 1.0e-06 + lr_warmup_steps_ratio: 0.0 + total_training_steps: -1 + weight_decay: 0.01 + lr_warmup_steps: -1 + betas: + - 0.9 + - 0.999 + clip_grad: 1.0 + min_lr_ratio: 0.0 + num_cycles: 0.5 + lr_scheduler_type: constant + warmup_style: null + override_optimizer_config: null + fsdp_config: + _target_: verl.workers.config.FSDPEngineConfig + wrap_policy: + min_num_params: 0 + param_offload: false + optimizer_offload: false + offload_policy: false + reshard_after_forward: true + fsdp_size: -1 + forward_prefetch: false + model_dtype: fp32 + use_orig_params: false + seed: 42 + full_determinism: false + ulysses_sequence_parallel_size: 1 + entropy_from_logits_with_chunking: false + use_torch_compile: true + entropy_checkpointing: false + forward_only: false + strategy: fsdp + dtype: bfloat16 + _target_: verl.workers.config.FSDPActorConfig + rollout_n: ${oc.select:actor_rollout_ref.rollout.n,1} + strategy: fsdp + ppo_mini_batch_size: 256 + ppo_micro_batch_size: null + ppo_micro_batch_size_per_gpu: null + use_dynamic_bsz: false + ppo_max_token_len_per_gpu: 16384 + clip_ratio: 0.2 + clip_ratio_low: 0.2 + clip_ratio_high: 0.2 + tau_pos: 1.0 + tau_neg: 1.05 + freeze_vision_tower: false + policy_loss: + _target_: verl.workers.config.PolicyLossConfig + loss_mode: vanilla + clip_cov_ratio: 0.0002 + clip_cov_lb: 1.0 + clip_cov_ub: 5.0 + kl_cov_ratio: 0.0002 + ppo_kl_coef: 0.1 + clip_ratio_c: 3.0 + loss_agg_mode: token-mean + loss_scale_factor: null + entropy_coeff: 0 + calculate_entropy: false + use_kl_loss: false + use_prefix_grouper: false + use_torch_compile: true + kl_loss_coef: 0.001 + kl_loss_type: low_var_kl + ppo_epochs: 1 + shuffle: false + data_loader_seed: 42 + checkpoint: + _target_: verl.trainer.config.CheckpointConfig + save_contents: + - model + - optimizer + - extra + load_contents: ${.save_contents} + async_save: false + use_fused_kernels: ${oc.select:actor_rollout_ref.model.use_fused_kernels,false} + profiler: + _target_: verl.utils.profiler.ProfilerConfig + tool: ${oc.select:global_profiler.tool,null} + enable: false + all_ranks: false + ranks: [] + save_path: ${oc.select:global_profiler.save_path,null} + tool_config: + nsys: + _target_: verl.utils.profiler.config.NsightToolConfig + discrete: ${oc.select:global_profiler.global_tool_config.nsys.discrete} + npu: + _target_: verl.utils.profiler.config.NPUToolConfig + contents: [] + level: level0 + analysis: true + discrete: false + torch: + _target_: verl.utils.profiler.config.TorchProfilerToolConfig + contents: [] + discrete: false + torch_memory: + _target_: verl.utils.profiler.config.TorchMemoryToolConfig + trace_alloc_max_entries: ${oc.select:global_profiler.global_tool_config.torch_memory.trace_alloc_max_entries,100000} + stack_depth: ${oc.select:global_profiler.global_tool_config.torch_memory.stack_depth,32} + router_replay: + _target_: verl.workers.config.RouterReplayConfig + mode: disabled + record_file: null + replay_file: null + grad_clip: 1.0 + ulysses_sequence_parallel_size: 1 + entropy_from_logits_with_chunking: false + entropy_checkpointing: false + use_remove_padding: ${oc.select:actor_rollout_ref.model.use_remove_padding,false} + calculate_sum_pi_squared: false + sum_pi_squared_checkpointing: false + ref: + rollout_n: ${oc.select:actor_rollout_ref.rollout.n,1} + strategy: ${actor_rollout_ref.actor.strategy} + use_torch_compile: ${oc.select:actor_rollout_ref.actor.use_torch_compile,true} + log_prob_micro_batch_size: null + log_prob_micro_batch_size_per_gpu: null + log_prob_use_dynamic_bsz: ${oc.select:actor_rollout_ref.actor.use_dynamic_bsz,false} + log_prob_max_token_len_per_gpu: ${oc.select:actor_rollout_ref.actor.ppo_max_token_len_per_gpu,16384} + profiler: + _target_: verl.utils.profiler.ProfilerConfig + tool: ${oc.select:global_profiler.tool,null} + enable: false + all_ranks: false + ranks: [] + save_path: ${oc.select:global_profiler.save_path,null} + tool_config: + nsys: + _target_: verl.utils.profiler.config.NsightToolConfig + discrete: ${oc.select:global_profiler.global_tool_config.nsys.discrete} + npu: + _target_: verl.utils.profiler.config.NPUToolConfig + contents: [] + level: level0 + analysis: true + discrete: false + torch: + _target_: verl.utils.profiler.config.TorchProfilerToolConfig + contents: [] + discrete: false + torch_memory: + _target_: verl.utils.profiler.config.TorchMemoryToolConfig + trace_alloc_max_entries: ${oc.select:global_profiler.global_tool_config.torch_memory.trace_alloc_max_entries,100000} + stack_depth: ${oc.select:global_profiler.global_tool_config.torch_memory.stack_depth,32} + router_replay: + _target_: verl.workers.config.RouterReplayConfig + mode: disabled + record_file: null + replay_file: null + fsdp_config: + _target_: verl.workers.config.FSDPEngineConfig + wrap_policy: + min_num_params: 0 + param_offload: false + optimizer_offload: false + offload_policy: false + reshard_after_forward: true + fsdp_size: -1 + forward_prefetch: false + model_dtype: fp32 + use_orig_params: false + seed: 42 + full_determinism: false + ulysses_sequence_parallel_size: 1 + entropy_from_logits_with_chunking: false + use_torch_compile: true + entropy_checkpointing: false + forward_only: true + strategy: fsdp + dtype: bfloat16 + _target_: verl.workers.config.FSDPActorConfig + ulysses_sequence_parallel_size: ${oc.select:actor_rollout_ref.actor.ulysses_sequence_parallel_size,1} + entropy_from_logits_with_chunking: false + entropy_checkpointing: false + rollout: + _target_: verl.workers.config.RolloutConfig + name: ??? + mode: async + temperature: 1.0 + top_k: -1 + top_p: 1 + prompt_length: ${oc.select:data.max_prompt_length,512} + response_length: ${oc.select:data.max_response_length,512} + dtype: bfloat16 + gpu_memory_utilization: 0.5 + ignore_eos: false + enforce_eager: false + cudagraph_capture_sizes: null + free_cache_engine: true + tensor_model_parallel_size: 2 + data_parallel_size: 1 + expert_parallel_size: 1 + pipeline_model_parallel_size: 1 + max_num_batched_tokens: 8192 + max_model_len: null + max_num_seqs: 1024 + enable_chunked_prefill: true + enable_prefix_caching: true + logprobs_mode: processed_logprobs + scheduling_policy: fcfs + load_format: dummy + log_prob_micro_batch_size: null + log_prob_micro_batch_size_per_gpu: null + log_prob_use_dynamic_bsz: ${oc.select:actor_rollout_ref.actor.use_dynamic_bsz,false} + log_prob_max_token_len_per_gpu: ${oc.select:actor_rollout_ref.actor.ppo_max_token_len_per_gpu,16384} + disable_log_stats: true + do_sample: true + 'n': 1 + over_sample_rate: 0 + multi_stage_wake_up: false + engine_kwargs: + vllm: {} + sglang: {} + trtllm: {} + val_kwargs: + _target_: verl.workers.config.SamplingConfig + top_k: -1 + top_p: 1.0 + temperature: 0 + 'n': 1 + do_sample: false + multi_turn: + _target_: verl.workers.config.MultiTurnConfig + enable: false + max_assistant_turns: null + tool_config_path: null + max_user_turns: null + max_parallel_calls: 1 + max_tool_response_length: 256 + tool_response_truncate_side: middle + interaction_config_path: null + use_inference_chat_template: false + tokenization_sanity_check_mode: strict + format: hermes + num_repeat_rollouts: null + calculate_log_probs: false + agent: + _target_: verl.workers.config.AgentLoopConfig + num_workers: 8 + default_agent_loop: single_turn_agent + agent_loop_config_path: null + custom_async_server: + _target_: verl.workers.config.CustomAsyncServerConfig + path: null + name: null + checkpoint_engine: + _target_: verl.workers.config.CheckpointEngineConfig + backend: naive + update_weights_bucket_megabytes: 2048 + engine_kwargs: {} + trace: + _target_: verl.workers.config.TraceConfig + backend: null + token2text: false + max_samples_per_step_per_worker: null + skip_rollout: false + skip_dump_dir: /tmp/rollout_dump + skip_tokenizer_init: true + enable_rollout_routing_replay: false + profiler: + _target_: verl.utils.profiler.ProfilerConfig + tool: ${oc.select:global_profiler.tool,null} + enable: ${oc.select:actor_rollout_ref.actor.profiler.enable,false} + all_ranks: ${oc.select:actor_rollout_ref.actor.profiler.all_ranks,false} + ranks: ${oc.select:actor_rollout_ref.actor.profiler.ranks,[]} + save_path: ${oc.select:global_profiler.save_path,null} + tool_config: ${oc.select:actor_rollout_ref.actor.profiler.tool_config,null} + prometheus: + _target_: verl.workers.config.PrometheusConfig + enable: false + port: 9090 + file: /tmp/ray/session_latest/metrics/prometheus/prometheus.yml + served_model_name: ${oc.select:actor_rollout_ref.model.path,null} + quantization: null + quantization_config_file: null + mtp: ${oc.select:actor_rollout_ref.model.mtp, null} + layered_summon: false + model: + _target_: verl.workers.config.HFModelConfig + path: ~/models/deepseek-llm-7b-chat + hf_config_path: null + tokenizer_path: null + use_shm: false + trust_remote_code: false + custom_chat_template: null + external_lib: null + override_config: {} + enable_gradient_checkpointing: true + enable_activation_offload: false + use_remove_padding: true + lora_rank: 0 + lora_alpha: 16 + target_modules: all-linear + exclude_modules: null + lora_adapter_path: null + use_liger: false + use_fused_kernels: false + fused_kernel_options: + impl_backend: torch + tiled_mlp: + enabled: false + num_shards: 4 + mtp: + _target_: verl.workers.config.MtpConfig + enable: false + enable_train: false + enable_rollout: false + detach_encoder: false + mtp_loss_scaling_factor: 0.1 + speculative_algorithm: EAGLE + speculative_num_steps: 3 + speculative_eagle_topk: 1 + speculative_num_draft_tokens: 4 + method: mtp + num_speculative_tokens: 1 + hybrid_engine: true + nccl_timeout: 600 +data: + tokenizer: null + use_shm: false + train_files: ~/data/rlhf/gsm8k/train.parquet + val_files: ~/data/rlhf/gsm8k/test.parquet + train_max_samples: -1 + val_max_samples: -1 + prompt_key: prompt + reward_fn_key: data_source + max_prompt_length: 512 + max_response_length: 512 + train_batch_size: 1024 + val_batch_size: null + tool_config_path: ${oc.select:actor_rollout_ref.rollout.multi_turn.tool_config_path, + null} + return_raw_input_ids: false + return_raw_chat: true + return_full_prompt: false + shuffle: true + seed: null + dataloader_num_workers: 8 + image_patch_size: 14 + validation_shuffle: false + filter_overlong_prompts: false + filter_overlong_prompts_workers: 1 + truncation: error + image_key: images + video_key: videos + trust_remote_code: false + custom_cls: + path: null + name: null + return_multi_modal_inputs: true + sampler: + class_path: null + class_name: null + datagen: + path: null + name: null + apply_chat_template_kwargs: {} +reward_manager: + _target_: verl.trainer.config.config.RewardManagerConfig + source: register + name: ${oc.select:reward_model.reward_manager,naive} + module: + _target_: verl.trainer.config.config.ModuleConfig + path: null + name: custom_reward_manager +critic: + optim: + _target_: verl.workers.config.FSDPOptimizerConfig + optimizer: AdamW + optimizer_impl: torch.optim + lr: 1.0e-05 + lr_warmup_steps_ratio: 0.0 + total_training_steps: -1 + weight_decay: 0.01 + lr_warmup_steps: -1 + betas: + - 0.9 + - 0.999 + clip_grad: 1.0 + min_lr_ratio: 0.0 + num_cycles: 0.5 + lr_scheduler_type: constant + warmup_style: null + override_optimizer_config: null + model: + fsdp_config: + _target_: verl.workers.config.FSDPEngineConfig + wrap_policy: + min_num_params: 0 + param_offload: false + optimizer_offload: false + offload_policy: false + reshard_after_forward: true + fsdp_size: -1 + forward_prefetch: false + model_dtype: fp32 + use_orig_params: false + seed: 42 + full_determinism: false + ulysses_sequence_parallel_size: 1 + entropy_from_logits_with_chunking: false + use_torch_compile: true + entropy_checkpointing: false + forward_only: false + strategy: fsdp + dtype: bfloat16 + path: ~/models/deepseek-llm-7b-chat + tokenizer_path: ${oc.select:actor_rollout_ref.model.path,"~/models/deepseek-llm-7b-chat"} + override_config: {} + external_lib: ${oc.select:actor_rollout_ref.model.external_lib,null} + trust_remote_code: ${oc.select:actor_rollout_ref.model.trust_remote_code,false} + _target_: verl.workers.config.FSDPCriticModelCfg + use_shm: false + enable_gradient_checkpointing: true + enable_activation_offload: false + use_remove_padding: false + lora_rank: 0 + lora_alpha: 16 + target_modules: all-linear + tiled_mlp: + enabled: false + num_shards: 4 + _target_: verl.workers.config.FSDPCriticConfig + rollout_n: ${oc.select:actor_rollout_ref.rollout.n,1} + strategy: fsdp + enable: null + ppo_mini_batch_size: ${oc.select:actor_rollout_ref.actor.ppo_mini_batch_size,256} + ppo_micro_batch_size: null + ppo_micro_batch_size_per_gpu: ${oc.select:.ppo_micro_batch_size,null} + use_dynamic_bsz: ${oc.select:actor_rollout_ref.actor.use_dynamic_bsz,false} + ppo_max_token_len_per_gpu: 32768 + forward_max_token_len_per_gpu: ${.ppo_max_token_len_per_gpu} + ppo_epochs: ${oc.select:actor_rollout_ref.actor.ppo_epochs,1} + shuffle: ${oc.select:actor_rollout_ref.actor.shuffle,false} + data_loader_seed: 42 + cliprange_value: 0.5 + loss_agg_mode: ${oc.select:actor_rollout_ref.actor.loss_agg_mode,token-mean} + checkpoint: + _target_: verl.trainer.config.CheckpointConfig + save_contents: + - model + - optimizer + - extra + load_contents: ${.save_contents} + async_save: false + profiler: + _target_: verl.utils.profiler.ProfilerConfig + tool: ${oc.select:global_profiler.tool,null} + enable: false + all_ranks: false + ranks: [] + save_path: ${oc.select:global_profiler.save_path,null} + tool_config: + nsys: + _target_: verl.utils.profiler.config.NsightToolConfig + discrete: ${oc.select:global_profiler.global_tool_config.nsys.discrete} + npu: + _target_: verl.utils.profiler.config.NPUToolConfig + contents: [] + level: level0 + analysis: true + discrete: false + torch: + _target_: verl.utils.profiler.config.TorchProfilerToolConfig + contents: [] + discrete: false + torch_memory: + _target_: verl.utils.profiler.config.TorchMemoryToolConfig + trace_alloc_max_entries: ${oc.select:global_profiler.global_tool_config.torch_memory.trace_alloc_max_entries,100000} + stack_depth: ${oc.select:global_profiler.global_tool_config.torch_memory.stack_depth,32} + forward_micro_batch_size: ${oc.select:.ppo_micro_batch_size,null} + forward_micro_batch_size_per_gpu: ${oc.select:.ppo_micro_batch_size_per_gpu,null} + ulysses_sequence_parallel_size: 1 + grad_clip: 1.0 +reward_model: + enable: false + enable_resource_pool: false + n_gpus_per_node: 8 + nnodes: 0 + strategy: fsdp + model: + input_tokenizer: ${actor_rollout_ref.model.path} + path: ~/models/FsfairX-LLaMA3-RM-v0.1 + external_lib: ${actor_rollout_ref.model.external_lib} + trust_remote_code: false + override_config: {} + use_shm: false + use_remove_padding: false + use_fused_kernels: ${actor_rollout_ref.model.use_fused_kernels} + fsdp_config: + _target_: verl.workers.config.FSDPEngineConfig + wrap_policy: + min_num_params: 0 + param_offload: false + reshard_after_forward: true + fsdp_size: -1 + forward_prefetch: false + micro_batch_size: null + micro_batch_size_per_gpu: null + max_length: null + use_dynamic_bsz: ${critic.use_dynamic_bsz} + forward_max_token_len_per_gpu: ${critic.forward_max_token_len_per_gpu} + reward_manager: naive + reward_loop_source: register + reward_loop_module_path: null + reward_loop_class_name: null + launch_reward_fn_async: false + sandbox_fusion: + url: null + max_concurrent: 64 + memory_limit_mb: 1024 + profiler: + _target_: verl.utils.profiler.ProfilerConfig + tool: ${oc.select:global_profiler.tool,null} + enable: false + all_ranks: false + ranks: [] + save_path: ${oc.select:global_profiler.save_path,null} + tool_config: ${oc.select:actor_rollout_ref.actor.profiler.tool_config,null} + ulysses_sequence_parallel_size: 1 + use_reward_loop: true + num_workers: 1 + rollout: + _target_: verl.workers.config.RolloutConfig + name: ??? + dtype: bfloat16 + gpu_memory_utilization: 0.5 + enforce_eager: true + cudagraph_capture_sizes: null + free_cache_engine: true + data_parallel_size: 1 + expert_parallel_size: 1 + tensor_model_parallel_size: 2 + max_num_batched_tokens: 8192 + max_model_len: null + max_num_seqs: 1024 + load_format: auto + engine_kwargs: {} + limit_images: null + enable_chunked_prefill: true + enable_prefix_caching: true + disable_log_stats: true + skip_tokenizer_init: false + prompt_length: 2048 + response_length: 2048 +algorithm: + rollout_correction: + rollout_is: null + rollout_is_threshold: 2.0 + rollout_rs: null + rollout_rs_threshold: null + bypass_mode: false + loss_type: ppo_clip + rollout_is_batch_normalize: false + _target_: verl.trainer.config.AlgoConfig + gamma: 1.0 + lam: 1.0 + adv_estimator: gae + norm_adv_by_std_in_grpo: true + use_kl_in_reward: false + kl_penalty: kl + kl_ctrl: + _target_: verl.trainer.config.KLControlConfig + type: fixed + kl_coef: 0.001 + horizon: 10000 + target_kl: 0.1 + use_pf_ppo: false + pf_ppo: + reweight_method: pow + weight_pow: 2.0 +custom_reward_function: + path: null + name: compute_score +trainer: + balance_batch: true + total_epochs: 30 + total_training_steps: null + project_name: verl_examples + experiment_name: gsm8k + logger: + - console + - wandb + log_val_generations: 0 + rollout_data_dir: null + validation_data_dir: null + nnodes: 1 + n_gpus_per_node: 8 + save_freq: -1 + esi_redundant_time: 0 + resume_mode: auto + resume_from_path: null + val_before_train: true + val_only: false + test_freq: -1 + critic_warmup: 0 + default_hdfs_dir: null + del_local_ckpt_after_load: false + default_local_dir: checkpoints/${trainer.project_name}/${trainer.experiment_name} + max_actor_ckpt_to_keep: null + max_critic_ckpt_to_keep: null + ray_wait_register_center_timeout: 300 + device: cuda + use_legacy_worker_impl: auto +global_profiler: + _target_: verl.utils.profiler.ProfilerConfig + tool: null + steps: null + profile_continuous_steps: false + save_path: outputs/profile + global_tool_config: + nsys: + _target_: verl.utils.profiler.config.NsightToolConfig + discrete: false + controller_nsight_options: + trace: cuda,nvtx,cublas,ucx + cuda-memory-usage: 'true' + cuda-graph-trace: graph + worker_nsight_options: + trace: cuda,nvtx,cublas,ucx + cuda-memory-usage: 'true' + cuda-graph-trace: graph + capture-range: cudaProfilerApi + capture-range-end: null + kill: none + torch_memory: + trace_alloc_max_entries: 100000 + stack_depth: 32 + context: all + stacks: all + kw_args: {} +transfer_queue: + enable: false +ray_kwargs: + ray_init: + num_cpus: null + timeline_json_file: null diff --git a/code/RL_model/verl/verl_train/verl/trainer/config/actor/actor.yaml b/code/RL_model/verl/verl_train/verl/trainer/config/actor/actor.yaml new file mode 100644 index 0000000000000000000000000000000000000000..7680013228c26cfe88d19c8a7604209df4548772 --- /dev/null +++ b/code/RL_model/verl/verl_train/verl/trainer/config/actor/actor.yaml @@ -0,0 +1,254 @@ +# Format checks enforced on CI: +# 1. Comments must appear above each field. +# 2. There must be a blank line between each field. +# 3. Inline comments (after a field on the same line) are not allowed. +# 4. Indentation level is respected for nested fields. + +# Target class for this configuration +_target_: verl.workers.config.ActorConfig + +# Number of rollouts per update (mirrors actor rollout_n) +rollout_n: ${oc.select:actor_rollout_ref.rollout.n,1} + +# the abstract actor configs +# fsdp, fsdp2 or megatron. must be set. +strategy: ??? + +# Split each sample into sub-batches of this size for PPO +ppo_mini_batch_size: 256 + +# [Deprecated] Global micro batch size +ppo_micro_batch_size: null + +# Local per-GPU micro batch size +ppo_micro_batch_size_per_gpu: null + +# Whether to automatically adjust batch size at runtime +# oc.select: the default val for ref.log_prob_use_dynamic_bsz +use_dynamic_bsz: false + +# Max tokens per GPU in one PPO batch; affects gradient accumulation +# Typically it should be: n * ${data.max_prompt_length} + ${data.max_response_length} +# oc.select: the default val for ref.log_prob_max_token_len_per_gpu +ppo_max_token_len_per_gpu: 16384 + +# PPO clip ratio +clip_ratio: 0.2 + +# Lower bound for asymmetric clipping (used in dual-clip PPO) +clip_ratio_low: 0.2 + +# Upper bound for asymmetric clipping (used in dual-clip PPO) +clip_ratio_high: 0.2 + +# Positive and negative tau for smoothing function in SAPO (https://arxiv.org/pdf/2511.20347) +# default values used in the paper with Qwen3-30B-A3B-Base +tau_pos: 1.0 + +# negative tau for smoothing function in SAPO +tau_neg: 1.05 + +# Whether to freeze vision model, if set true, it will be freeze vision model +freeze_vision_tower: false + +# policy loss config +policy_loss: + + # # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs + _target_: verl.workers.config.PolicyLossConfig + + # Loss function mode: vanilla / clip-cov / kl-cov /gpg from https://arxiv.org/abs/2505.22617 + loss_mode: "vanilla" + + # Ratio of tokens to be clipped for clip-cov loss + clip_cov_ratio: 0.0002 + + # Lower bound for clip-cov loss + clip_cov_lb: 1.0 + + # Upper bound for clip-cov loss + clip_cov_ub: 5.0 + + # Ratio of tokens to be applied kl penalty for kl-cov loss + kl_cov_ratio: 0.0002 + + # KL divergence penalty coefficient + ppo_kl_coef: 0.1 + +# Constant C in Dual-clip PPO; clips when advantage < 0 and ratio > C +clip_ratio_c: 3.0 + +# Loss aggregation mode: "token-mean", "seq-mean-token-sum", "seq-mean-token-mean", or "seq-mean-token-sum-norm" +loss_agg_mode: token-mean + +# Scale factor for "seq-mean-token-sum-norm" loss aggregation mode. +# If null, uses response_length. Set to a constant to ensure consistent normalization. +loss_scale_factor: null + +# Entropy regularization coefficient in PPO loss +entropy_coeff: 0 + +# When true, the actor forward will request entropy from the model +calculate_entropy: false + +# Whether to use KL loss instead of KL reward penalty. True for GRPO +use_kl_loss: false + +# Whether to enable PrefixGrouper shared-prefix forward +use_prefix_grouper: false + +# Whether to use torch.compile() +# oc.select: the default val for ref.use_torch_compile +use_torch_compile: true + +# KL loss coefficient when use_kl_loss is enabled. For GRPO +kl_loss_coef: 0.001 + +# Type of KL divergence loss. Options: "kl"(k1), "abs", "mse"(k2), "low_var_kl"(k3), "full" +kl_loss_type: low_var_kl + +# Number of PPO epochs per batch +ppo_epochs: 1 + +# Shuffle training data across PPO epochs +shuffle: false + +# The seed used to construct mini-batch +data_loader_seed: 42 + +# checkpoint configs +checkpoint: + + # Target dataclass for this configuration + _target_: verl.trainer.config.CheckpointConfig + + # What to include in saved checkpoints + # with 'hf_model' you can save whole model as hf format, now only use sharded model checkpoint to save space + save_contents: ['model', 'optimizer', 'extra'] + + # For more flexibility, you can specify the contents to load from the checkpoint. + # .xxx refers to the local variable xxx from the same level of hierarchy similar to python pkg + load_contents: ${.save_contents} + + # Whether to save checkpoints asynchronously. Only effective for Megatron as of now. + async_save: False + +# optimizer configs +optim: + + # Learning rate + lr: 1e-6 + + # Warmup steps ratio (used if lr_warmup_steps is 0 or negative) + lr_warmup_steps_ratio: 0.0 + + # Total training steps (must be overridden at runtime) + total_training_steps: -1 + + # Weight decay + weight_decay: 0.01 + + # Prioritized. None, 0 or Negative values mean delegating to lr_warmup_steps_ratio. + lr_warmup_steps: -1 + + +# Whether to use custom fused kernels (e.g., FlashAttention, fused MLP) +use_fused_kernels: ${oc.select:actor_rollout_ref.model.use_fused_kernels,false} + +# profile the actor model in `update_policy` +profiler: + + # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs + _target_: verl.utils.profiler.ProfilerConfig + + # profiler tool, default same as profiler.tool in global config + # choices: nsys, npu, torch + tool: ${oc.select:global_profiler.tool,null} + + # whether enable profile on Actor + enable: False + + # Whether to profile all ranks. + all_ranks: False + + # The ranks that will be profiled. [] or [0,1,...] + ranks: [] + + # profile results saving path + save_path: ${oc.select:global_profiler.save_path,null} + + # specific tool config which only related to the role + tool_config: + + # nsys tool config + nsys: + + # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs + _target_: verl.utils.profiler.config.NsightToolConfig + + # True for each task has its own database, False for all tasks in one training step share one database. + discrete: ${oc.select:global_profiler.global_tool_config.nsys.discrete} + + # npu config + npu: + + # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs + _target_: verl.utils.profiler.config.NPUToolConfig + + # Contents to profile, can be empty + # options: npu, cpu, memory, shapes, module, stack + contents: [] + + # Collection level, optional values: level_none, level0, level1, level2. + level: "level0" + + # Whether to automatically parse the data. + analysis: True + + # True for each task has its own database, False for all tasks in one training step share one database. + discrete: False + + # torch profiler config + torch: + + # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs + _target_: verl.utils.profiler.config.TorchProfilerToolConfig + + # Contents to profile, can be empty + # options: cuda, cpu, memory, shapes, stack + contents: [] + + # True for each task has its own database, False for all tasks in one training step share one database. + discrete: false + + # torch memory profiler config + torch_memory: + + # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs + _target_: verl.utils.profiler.config.TorchMemoryToolConfig + + # Maximum number of memory allocation entries to track + trace_alloc_max_entries: ${oc.select:global_profiler.global_tool_config.torch_memory.trace_alloc_max_entries,100000} + + # Stack trace depth for memory allocations + stack_depth: ${oc.select:global_profiler.global_tool_config.torch_memory.stack_depth,32} + +# Router replay configuration for MoE models +router_replay: + + # Target dataclass for this configuration + _target_: verl.workers.config.RouterReplayConfig + + # Router replay mode: disabled, R2, R3 + # - R2: Use R2 routing strategy (record mode) + # - R3: Use R3 routing strategy (record mode) + mode: disabled + + # File path to save recorded routing decisions + # Required when mode is 'record', 'R2', or 'R3' + record_file: null + + # File path to load recorded routing decisions for replay + # Required when mode is 'replay' + replay_file: null + diff --git a/code/RL_model/verl/verl_train/verl/trainer/config/actor/dp_actor.yaml b/code/RL_model/verl/verl_train/verl/trainer/config/actor/dp_actor.yaml new file mode 100644 index 0000000000000000000000000000000000000000..fc0a16be6098380ac22acafec9f14efe34c7f9d2 --- /dev/null +++ b/code/RL_model/verl/verl_train/verl/trainer/config/actor/dp_actor.yaml @@ -0,0 +1,50 @@ +# Format checks enforced on CI: +# 1. Comments must appear above each field. +# 2. There must be a blank line between each field. +# 3. Inline comments (after a field on the same line) are not allowed. +# 4. Indentation level is respected for nested fields. + +# defaults specify the default config from each component +defaults: + + # fsdp optimizer config + - ../optim@optim: fsdp + + # fsdp engine config + - ../engine@fsdp_config: fsdp + + # dp actor config, inheriting from trainer/config/actor/actor.yaml + - actor + + # load the reference default config, then apply the fields in the current yaml + - _self_ + +# Target class for this configuration +_target_: verl.workers.config.FSDPActorConfig + +# TODO(haibin.lin): switch to fsdp2 +strategy: fsdp + +# Gradient clipping for actor updates, specific to the strategy. +grad_clip: 1.0 + +# Sequence parallelism size for Ulysses-style model parallelism +# oc.select: the default val for ref.ulysses_sequence_parallel_size +# [DEPRECATED] use fsdp_config.ulysses_sequence_parallel_size instead +ulysses_sequence_parallel_size: 1 + +# calculate entropy with chunking to reduce memory peak +entropy_from_logits_with_chunking: False + +# recompute entropy +entropy_checkpointing: False + +# Whether to remove padding tokens in inputs during training +use_remove_padding: ${oc.select:actor_rollout_ref.model.use_remove_padding,false} + +# This computes Σπ² needed for the Logit-Gradient Norm proxy W(τ) = Σ_t[1 - 2π_t + Σπ²] +# c.f. https://yingru.notion.site/The-Optimal-Token-Baseline-399211a558b782cfa936014c0d42dfb8 +calculate_sum_pi_squared: False + +# Enable gradient checkpointing for sum_pi_squared computation (saves memory) +sum_pi_squared_checkpointing: False diff --git a/code/RL_model/verl/verl_train/verl/trainer/config/actor/megatron_actor.yaml b/code/RL_model/verl/verl_train/verl/trainer/config/actor/megatron_actor.yaml new file mode 100644 index 0000000000000000000000000000000000000000..fde70c363c4cd8b6f6c524998b06d79b8b821453 --- /dev/null +++ b/code/RL_model/verl/verl_train/verl/trainer/config/actor/megatron_actor.yaml @@ -0,0 +1,18 @@ +# megatron actor config, inheriting from trainer/config/actor/actor.yaml +defaults: + # megatron optimizer config + - ../optim@optim: megatron + + # megatron engine config + - ../engine@megatron: megatron + + - actor + + # load the reference default config, then apply the fields in the current yaml + - _self_ + +_target_: verl.workers.config.McoreActorConfig + +strategy: megatron + +load_weight: True diff --git a/code/RL_model/verl/verl_train/verl/trainer/config/algorithm.py b/code/RL_model/verl/verl_train/verl/trainer/config/algorithm.py new file mode 100644 index 0000000000000000000000000000000000000000..5aa650d7bf99520e306a45d30d639e1db1e68788 --- /dev/null +++ b/code/RL_model/verl/verl_train/verl/trainer/config/algorithm.py @@ -0,0 +1,614 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass, field +from typing import Any, Optional + +from verl.base_config import BaseConfig + +__all__ = ["AlgoConfig", "FilterGroupsConfig", "KLControlConfig", "RolloutCorrectionConfig"] + + +@dataclass +class KLControlConfig(BaseConfig): + """Configuration for KL control. + + The inheritance from BaseConfig provides omegaconf.DictConfig-like interface for a dataclass config. + + Args: + type (str): Type of KL control. Can be "fixed" or "adaptive". + kl_coef (float): Initial coefficient for KL penalty. + horizon (int): Horizon value for adaptive controller. + target_kl (float): Target KL divergence for adaptive controller. + """ + + type: str = "fixed" + kl_coef: float = 0.001 + horizon: int = 10000 + target_kl: float = 0.1 + + +@dataclass +class FilterGroupsConfig(BaseConfig): + """Configuration for filter groups (used in DAPO and Entropy). + + The inheritance from BaseConfig provides omegaconf.DictConfig-like interface for a dataclass config. + + Args: + enable (bool): Whether to enable filter groups. + metric (Optional[str]): Metric to use for filtering: "acc", "score", "seq_reward", "seq_final_reward", etc. + max_num_gen_batches (int): Non-positive values mean no upper limit. + """ + + enable: bool = False + metric: Optional[str] = None + max_num_gen_batches: int = 0 + + +@dataclass +class RolloutCorrectionConfig(BaseConfig): + """Configuration for Rollout Correction (addresses off-policy issues in RL training). + + The inheritance from BaseConfig provides omegaconf.DictConfig-like interface for a dataclass config. + + Rollout Correction handles off-policiness from multiple sources: + 1. Policy mismatch: Rollout policy (e.g., vLLM BF16) vs Training policy (e.g., FSDP FP32) + 2. Model update staleness: Rollout data collected from older policy checkpoints + 3. General off-policy scenarios: Any distribution shift between data collection and training + + For more details, see: + "When Speed Kills Stability: Demystifying RL Collapse from the Training-Inference Mismatch" + https://richardli.xyz/rl-collapse + + This typed config replaces the old dict-based approach and provides: + - Type safety and validation + - Clear documentation of all parameters + - Named factory methods for common presets (TIS, MIS, etc.) + - Sensible defaults + + Args: + rollout_is (Optional[str]): IS weight aggregation level. + - None: No IS weights (metrics only) + - "token": Per-token IS weights (low variance, biased) + - "sequence": Per-sequence IS weights (unbiased, high variance) + Default: "sequence" + + rollout_is_threshold (float): Upper threshold for IS weight truncation/rejection. + Typical range: 1.5-5.0 for token level, 2.0-10.0 for sequence level. + Default: 2.0 + + rollout_is_batch_normalize (bool): Apply batch normalization to IS weights. + - True: Normalize IS weights to have mean=1.0 within each batch + - False: Use raw (truncated) IS weights (standard) + - Reduces variance by ensuring average weight is 1.0 per batch + - Only affects IS weight values, not rejection sampling + Default: False (no batch normalization) + + rollout_rs (Optional[str]): Rejection sampling aggregation modes. + Accepts a comma-delimited list (duplicates removed) of canonical options implemented in + ``rollout_corr_helper``: + - "token_k1": Token-level rejection with ``-log r`` (ratio thresholds supplied via + ``rollout_rs_threshold`` as ``lower_upper``) + - "token_k2": Token-level rejection with ``0.5 * (log r)^2`` (upper bound only) + - "token_k3": Token-level rejection with ``exp(log r) - 1 - log r`` (upper bound only) + - "seq_sum_k1": Sequence sum of ``-log r`` (ratio bounds) + - "seq_sum_k2": Sequence sum of rejection with ``0.5 * (log r)^2`` (upper bound only) + - "seq_sum_k3": Sequence sum of rejection with ``exp(log r) - 1 - log r`` (upper bound only) + - "seq_mean_k1": Sequence mean of ``-log r`` (ratio bounds) + - "seq_mean_k2": Sequence mean of rejection with ``0.5 * (log r)^2`` (upper bound only) + - "seq_mean_k3": Sequence mean of rejection with ``exp(log r) - 1 - log r`` (upper bound only) + - "seq_max_k2": Sequence max of rejection with ``0.5 * (log r)^2`` (upper bound only) + - "seq_max_k3": Sequence max of rejection with ``exp(log r) - 1 - log r`` (upper bound only) + names automatically. Default: None + + rollout_rs_threshold (Optional[Union[str, float]]): Threshold specification for rejection sampling. + Provide one value per option (single entry is broadcast when multiple options are supplied). + Ratio-based modes (``*k1``) expect ``lower_upper`` strings; supplying a single float implies + only the upper ratio bound, with the lower bound inferred as its reciprocal. Divergence modes + (k2/k3) expect positive upper bounds (float or string). Default: None + + bypass_mode (bool): Operating mode - bypass or decoupled. + - True: Bypass mode - reuse rollout_log_prob as old_log_prob (2 policies) + Uses compute_policy_loss_bypass_mode() with loss_type selection + - False: Decoupled mode - compute old_log_prob separately (3 policies) + Uses standard PPO loss with IS weight correction + Default: False (decoupled mode) + + loss_type (str): Loss function type in bypass mode (bypass_mode=True). + - "reinforce": REINFORCE-style policy gradient with explicit IS weights + L = -E[w * log π(a|s) * A] where w = π_current / π_rollout + - "ppo_clip": PPO clipped objective (IS handled by ratio, no explicit weights) + L = -E[min(r*A, clip(r)*A)] where r = π_current / π_rollout + Default: "ppo_clip" + + Example: + # Create with defaults + config = RolloutCorrectionConfig() + + # Decoupled PPO mode presets (3 policies: π_rollout, π_old, π_θ) + # IS weights correct for gap between π_old and π_rollout + config = RolloutCorrectionConfig.decoupled_token_is() # Token-TIS + config = RolloutCorrectionConfig.decoupled_seq_is() # Seq-TIS + config = RolloutCorrectionConfig.decoupled_seq_is_rs() # Seq-MIS + config = RolloutCorrectionConfig.decoupled_geo_rs() # Geo-RS (ratio mode) + + # Bypass mode presets (2 policies: π_rollout = π_old, π_θ) + # loss_type controls the loss function + # PPO-clip presets (ratio handles IS, so no separate IS weights needed): + config = RolloutCorrectionConfig.bypass_ppo_clip() # PPO-clip only + config = RolloutCorrectionConfig.bypass_ppo_clip_geo_rs() # PPO-clip + Geo-RS + config = RolloutCorrectionConfig.bypass_ppo_clip_k3_rs() # PPO-clip + K3-RS + # REINFORCE presets (explicit IS weights): + config = RolloutCorrectionConfig.bypass_pg_is() # REINFORCE + Seq-TIS + config = RolloutCorrectionConfig.bypass_pg_geo_rs() # REINFORCE + Geo-RS + config = RolloutCorrectionConfig.bypass_pg_geo_rs_seq_tis() # REINFORCE + Geo-RS + Seq-TIS + config = RolloutCorrectionConfig.bypass_pg_geo_rs_token_tis() # REINFORCE + Geo-RS + Token-TIS + + # Decoupled Geometric ratio presets (length-normalized IS ratio) + config = RolloutCorrectionConfig.decoupled_geo_rs_seq_tis() # Decoupled Geo-RS + Seq-TIS + config = RolloutCorrectionConfig.decoupled_geo_rs_token_tis() # Decoupled Geo-RS + Token-TIS + + # Decoupled K3 KL Estimator presets (more stable for small KL values) + config = RolloutCorrectionConfig.decoupled_k3_rs() # Decoupled K3-RS + config = RolloutCorrectionConfig.decoupled_k3_rs_seq_tis() # Decoupled K3-RS + Seq-TIS + config = RolloutCorrectionConfig.decoupled_k3_rs_token_tis() # Decoupled K3-RS + Token-TIS + + Reference: + Liu, Li, Fu, Wang, Liu, Shen (2025) + "When Speed Kills Stability: Demystifying RL Collapse from the Training-Inference Mismatch" + https://richardli.xyz/rl-collapse + """ + + rollout_is: Optional[str] = "sequence" + rollout_is_threshold: float = 2.0 + rollout_is_batch_normalize: bool = False + rollout_rs: Optional[str] = None + rollout_rs_threshold: Optional[str | float] = None + bypass_mode: bool = False + loss_type: str = "ppo_clip" + + @classmethod + def decoupled_token_is(cls, threshold: float = 2.0) -> "RolloutCorrectionConfig": + """Decoupled Mode with Token-level Importance Sampling. + + IS weight correction at token level in decoupled mode (three policies). + + Args: + threshold (float): Upper threshold for IS weights. Default: 2.0 + + Returns: + RolloutCorrectionConfig configured for decoupled mode with token-level IS + """ + return cls(rollout_is="token", rollout_is_threshold=threshold, rollout_rs=None) + + @classmethod + def decoupled_seq_is(cls, threshold: float = 2.0) -> "RolloutCorrectionConfig": + """Decoupled Mode with Sequence-level Importance Sampling. + + IS weight correction at sequence level in decoupled mode (three policies). + + Args: + threshold (float): Upper threshold for IS weights. Default: 2.0 + + Returns: + RolloutCorrectionConfig configured for decoupled mode with sequence-level IS + """ + return cls(rollout_is="sequence", rollout_is_threshold=threshold, rollout_rs=None) + + @classmethod + def decoupled_seq_is_rs( + cls, + is_threshold: float = 2.0, + rs_threshold: Optional[str | float] = "0.5_2.0", + ) -> "RolloutCorrectionConfig": + """Decoupled Mode with Sequence-level IS + Rejection Sampling. + + Sequence-level IS with sequence-level rejection sampling in decoupled mode. + Rejects entire sequences based on sequence-level IS weight. + + Args: + is_threshold (float): Upper threshold for IS weights. Default: 2.0 + rs_threshold (Optional[Union[str, float]]): Upper threshold for rejection sampling. Default: 0.5_2.0 + + Returns: + RolloutCorrectionConfig configured for decoupled mode with sequence IS + RS + """ + return cls( + rollout_is="sequence", + rollout_is_threshold=is_threshold, + rollout_rs="seq_sum_k1", + rollout_rs_threshold=rs_threshold, + ) + + @classmethod + def decoupled_geo_rs( + cls, + rs_threshold: Optional[str | float] = "0.999_1.001", + ) -> "RolloutCorrectionConfig": + """Decoupled Mode with Geometric Mean Rejection Sampling (ratio-based). + + Uses geometric mean IS ratio E[log(r)] for rejection sampling at sequence level. + This is a ratio-based mode (ideal = 0.0) with [lower, upper] threshold bounds. + Length-normalized but still uses IS ratio semantics. + + Args: + rs_threshold (Optional[Union[str, float]]): Geometric RS threshold (upper). Default: 0.999_1.001 (±0.1%) + + Returns: + RolloutCorrectionConfig configured for decoupled mode with Geo-RS + """ + return cls( + rollout_is=None, + rollout_rs="seq_mean_k1", + rollout_rs_threshold=rs_threshold, + ) + + @classmethod + def bypass_ppo_clip(cls) -> "RolloutCorrectionConfig": + """Bypass mode with PPO-clip loss. + + PPO clipped objective in bypass mode. The PPO ratio = π_θ/π_rollout + already handles IS correction, so no explicit IS weights are applied. + + Skips old_log_prob computation for faster execution (2 policies instead of 3). + + Returns: + RolloutCorrectionConfig configured for bypass mode with PPO-clip + """ + return cls( + rollout_is=None, + rollout_rs=None, + bypass_mode=True, + loss_type="ppo_clip", + ) + + @classmethod + def bypass_ppo_clip_geo_rs( + cls, + rs_threshold: Optional[str | float] = "0.999_1.001", + ) -> "RolloutCorrectionConfig": + """Bypass mode with PPO-clip loss and Geometric Mean RS (ratio-based). + + PPO clipped objective in bypass mode with geometric mean IS ratio RS. + Uses E[log(r)] (ideal = 0.0) with [lower, upper] threshold bounds. + + Args: + rs_threshold (Optional[Union[str, float]]): Geometric RS threshold (upper). Default: 0.999_1.001 (±0.1%) + + Returns: + RolloutCorrectionConfig configured for bypass mode with PPO-clip + Geo-RS + """ + return cls( + rollout_is=None, + rollout_rs="seq_mean_k1", + rollout_rs_threshold=rs_threshold, + bypass_mode=True, + loss_type="ppo_clip", + ) + + @classmethod + def bypass_ppo_clip_k3_rs( + cls, + rs_threshold: float = 0.01, + ) -> "RolloutCorrectionConfig": + """Bypass mode with PPO-clip loss and K3 Rejection Sampling. + + PPO clipped objective in bypass mode with K3 KL estimator RS to mask outliers. + K3 is more stable than K1 for small KL values. + The PPO ratio = π_θ/π_rollout already handles IS correction. + + Args: + rs_threshold (float): Max allowed K3 divergence. Default: 0.01 + + Returns: + RolloutCorrectionConfig configured for bypass mode with PPO-clip + K3-RS + """ + return cls( + rollout_is=None, + rollout_rs="seq_mean_k3", + rollout_rs_threshold=rs_threshold, + bypass_mode=True, + loss_type="ppo_clip", + ) + + @classmethod + def bypass_pg_is(cls, threshold: float = 2.0) -> "RolloutCorrectionConfig": + """Bypass mode with REINFORCE loss and IS Correction. + + Uses REINFORCE loss with explicit IS correction in bypass mode. + No PPO clipping. + + Args: + threshold (float): Upper threshold for IS weights. Default: 2.0 + + Returns: + RolloutCorrectionConfig configured for bypass mode with REINFORCE + IS + """ + return cls( + rollout_is="sequence", + rollout_is_threshold=threshold, + rollout_rs=None, + bypass_mode=True, + loss_type="reinforce", + ) + + @classmethod + def bypass_pg_geo_rs( + cls, + rs_threshold: Optional[str | float] = "0.999_1.001", + ) -> "RolloutCorrectionConfig": + """Bypass mode with REINFORCE loss and Geometric Mean RS (ratio-based). + + REINFORCE with geometric mean IS ratio rejection sampling in bypass mode. + Uses E[log(r)] (ideal = 0.0) with [lower, upper] threshold bounds. + + Args: + rs_threshold (Optional[Union[str, float]]): Geometric RS threshold (upper). Default: 0.999_1.001 (±0.1%) + + Returns: + RolloutCorrectionConfig configured for bypass mode with REINFORCE + Geo-RS + """ + return cls( + rollout_is=None, + rollout_rs="seq_mean_k1", + rollout_rs_threshold=rs_threshold, + bypass_mode=True, + loss_type="reinforce", + ) + + @classmethod + def decoupled_geo_rs_seq_tis( + cls, + is_threshold: float = 2.0, + rs_threshold: Optional[str | float] = "0.999_1.001", + ) -> "RolloutCorrectionConfig": + """Decoupled mode with Geometric Mean RS and Sequence-level Truncated IS (ratio-based). + + Combines the Geometric Mean Filter (ratio-based validity check) with + Clipped Sequence Weight (debiasing). Uses E[log(r)] (ideal = 0.0). + + Args: + is_threshold (float): Upper threshold for sequence IS weights. Default: 2.0 + rs_threshold (Optional[Union[str, float]]): Geometric RS threshold (upper). Default: 0.999_1.001 (±0.1%) + + Returns: + RolloutCorrectionConfig configured for Geo-RS-Seq-TIS + """ + return cls( + rollout_is="sequence", + rollout_is_threshold=is_threshold, + rollout_rs="seq_mean_k1", + rollout_rs_threshold=rs_threshold, + ) + + @classmethod + def decoupled_geo_rs_token_tis( + cls, + is_threshold: float = 2.0, + rs_threshold: Optional[str | float] = "0.999_1.001", + ) -> "RolloutCorrectionConfig": + """Decoupled mode with Geometric Mean RS and Token-level Truncated IS (ratio-based). + + Combines the Geometric Mean Filter (ratio-based validity check) with + Token-level IS weights. Uses E[log(r)] (ideal = 0.0). + + Args: + is_threshold (float): Upper threshold for token IS weights. Default: 2.0 + rs_threshold (Optional[Union[str, float]]): Geometric RS threshold (upper). Default: 0.999_1.001 (±0.1%) + + Returns: + RolloutCorrectionConfig configured for Geo-RS-Token-TIS + """ + return cls( + rollout_is="token", + rollout_is_threshold=is_threshold, + rollout_rs="seq_mean_k1", + rollout_rs_threshold=rs_threshold, + ) + + @classmethod + def bypass_pg_geo_rs_seq_tis( + cls, + is_threshold: float = 2.0, + rs_threshold: Optional[str | float] = "0.999_1.001", + ) -> "RolloutCorrectionConfig": + """Bypass mode with REINFORCE loss, Geo-RS, and Sequence-level IS. + + Combines geometric mean IS ratio rejection with sequence-level IS + in bypass mode with REINFORCE loss (no PPO clipping). + Uses E[log(r)] (ideal = 0.0) with [lower, upper] threshold bounds. + + Args: + is_threshold (float): Upper threshold for sequence IS weights. Default: 2.0 + rs_threshold (Optional[Union[str, float]]): Geometric RS threshold (upper). Default: 0.999_1.001 (±0.1%) + + Returns: + RolloutCorrectionConfig configured for bypass mode with REINFORCE + Geo-RS + Seq-TIS + """ + return cls( + rollout_is="sequence", + rollout_is_threshold=is_threshold, + rollout_rs="seq_mean_k1", + rollout_rs_threshold=rs_threshold, + bypass_mode=True, + loss_type="reinforce", + ) + + @classmethod + def bypass_pg_geo_rs_token_tis( + cls, + is_threshold: float = 2.0, + rs_threshold: Optional[str | float] = "0.999_1.001", + ) -> "RolloutCorrectionConfig": + """Bypass mode with REINFORCE loss, Geo-RS, and Token-level IS. + + Combines geometric mean IS ratio rejection with token-level IS weights + in bypass mode with REINFORCE loss (no PPO clipping). + Uses E[log(r)] (ideal = 0.0) with [lower, upper] threshold bounds. + + Token-level IS has lower variance but introduces bias. + + Args: + is_threshold (float): Upper threshold for token IS weights. Default: 2.0 + rs_threshold (Optional[Union[str, float]]): Geometric RS threshold (upper). Default: 0.999_1.001 (±0.1%) + + Returns: + RolloutCorrectionConfig configured for bypass mode with REINFORCE + Geo-RS + Token-TIS + """ + return cls( + rollout_is="token", + rollout_is_threshold=is_threshold, + rollout_rs="seq_mean_k1", + rollout_rs_threshold=rs_threshold, + bypass_mode=True, + loss_type="reinforce", + ) + + @classmethod + def decoupled_k3_rs( + cls, + rs_threshold: float = 0.01, + ) -> "RolloutCorrectionConfig": + """Decoupled mode with K3 KL Estimator Rejection Sampling. + + Uses K3 KL estimator at sequence level for rejection sampling. + K3 = E[r - log(r) - 1] where r = π_train/π_rollout. + More stable than geometric mean for small KL values. + + K3 >= 0 always (equals 0 when policies match exactly). + + Args: + rs_threshold (float): Max allowed K3 divergence. Default: 0.01 + Typical range: 0.001-0.1 + + Returns: + RolloutCorrectionConfig configured for K3 RS + """ + return cls( + rollout_is=None, + rollout_rs="seq_mean_k3", + rollout_rs_threshold=rs_threshold, + ) + + @classmethod + def decoupled_k3_rs_seq_tis( + cls, + is_threshold: float = 2.0, + rs_threshold: float = 0.01, + ) -> "RolloutCorrectionConfig": + """Decoupled mode with K3 RS and Sequence-level Truncated IS. + + Combines K3 KL estimator rejection with sequence-level IS weights. + K3 provides more stable outlier detection than geometric mean. + + Args: + is_threshold (float): Upper threshold for sequence IS weights. Default: 2.0 + rs_threshold (float): Max allowed K3 divergence. Default: 0.01 + + Returns: + RolloutCorrectionConfig configured for K3-RS-Seq-TIS + """ + return cls( + rollout_is="sequence", + rollout_is_threshold=is_threshold, + rollout_rs="seq_mean_k3", + rollout_rs_threshold=rs_threshold, + ) + + @classmethod + def decoupled_k3_rs_token_tis( + cls, + is_threshold: float = 2.0, + rs_threshold: float = 0.01, + ) -> "RolloutCorrectionConfig": + """Decoupled mode with K3 RS and Token-level Truncated IS. + + Combines K3 KL estimator rejection with token-level IS weights. + K3 provides more stable outlier detection than geometric mean. + Token-level IS has lower variance but introduces bias. + + Args: + is_threshold (float): Upper threshold for token IS weights. Default: 2.0 + rs_threshold (float): Max allowed K3 divergence. Default: 0.01 + + Returns: + RolloutCorrectionConfig configured for K3-RS-Token-TIS + """ + return cls( + rollout_is="token", + rollout_is_threshold=is_threshold, + rollout_rs="seq_mean_k3", + rollout_rs_threshold=rs_threshold, + ) + + @classmethod + def disabled(cls) -> "RolloutCorrectionConfig": + """Disabled - Metrics Only Mode. + + Computes and logs off-policy metrics without applying correction. + + Returns: + RolloutCorrectionConfig with all correction disabled + """ + return cls(rollout_is=None, rollout_rs=None) + + +@dataclass +class AlgoConfig(BaseConfig): + """Configuration for the algorithm. + + The inheritance from BaseConfig provides omegaconf.DictConfig-like interface for a dataclass config. + + Args: + gamma (float): Discount factor for future rewards. + lam (float): Trade-off between bias and variance in the GAE estimator. + adv_estimator (str): Advantage estimator type: "gae", "grpo", "reinforce_plus_plus", etc. + norm_adv_by_std_in_grpo (bool): Whether to normalize advantages by std (specific to GRPO). + use_kl_in_reward (bool): Whether to enable in-reward KL penalty. + kl_penalty (str): How to estimate KL divergence: "kl", "abs", "mse", "low_var_kl", or "full". + kl_ctrl (KLControlConfig): KL control configuration. + use_pf_ppo (bool): Whether to enable preference feedback PPO. + pf_ppo (dict[str, Any]): Preference feedback PPO settings. + filter_groups (Optional[FilterGroupsConfig]): Filter groups configuration, used in DAPO and Entropy + rollout_correction (Optional[RolloutCorrectionConfig]): Rollout Correction configuration. + Addresses off-policy issues from policy mismatch, model staleness, and general distribution shifts. + + Set to None to disable entirely. Use factory methods for common presets: + - RolloutCorrectionConfig.decoupled_token_is() - Decoupled mode with token-level IS + - RolloutCorrectionConfig.decoupled_seq_is() - Decoupled mode with sequence-level IS + - RolloutCorrectionConfig.decoupled_seq_is_rs() - Decoupled mode with sequence IS + RS + - RolloutCorrectionConfig.decoupled_k1_rs() - Decoupled mode with K1-RS (divergence) + - RolloutCorrectionConfig.decoupled_geo_rs() - Decoupled mode with Geo-RS (ratio) + - RolloutCorrectionConfig.bypass_ppo_clip() - Bypass mode with PPO-clip + - RolloutCorrectionConfig.bypass_ppo_clip_k1_rs() - Bypass mode with PPO-clip + K1-RS + - RolloutCorrectionConfig.bypass_pg_is() - Bypass mode with REINFORCE + IS + - RolloutCorrectionConfig.bypass_pg_k1_rs() - Bypass mode with REINFORCE + K1-RS + + For backward compatibility, you can still pass a dict, which will be converted to + RolloutCorrectionConfig automatically. + """ + + gamma: float = 1.0 + lam: float = 1.0 + adv_estimator: str = "gae" + norm_adv_by_std_in_grpo: bool = True + use_kl_in_reward: bool = False + kl_penalty: str = "kl" + kl_ctrl: KLControlConfig = field(default_factory=KLControlConfig) + use_pf_ppo: bool = False + pf_ppo: dict[str, Any] = field(default_factory=dict) + filter_groups: Optional[FilterGroupsConfig] = None + # Rollout Correction: corrects off-policy issues (policy mismatch, model staleness, distribution shifts) + # Set to None to disable, use RolloutCorrectionConfig presets (e.g., .tis(), .mis()), or pass dict + rollout_correction: Optional[RolloutCorrectionConfig] = None diff --git a/code/RL_model/verl/verl_train/verl/trainer/config/algorithm/rollout_correction.yaml b/code/RL_model/verl/verl_train/verl/trainer/config/algorithm/rollout_correction.yaml new file mode 100644 index 0000000000000000000000000000000000000000..2fd953184530df87b740f48b20ec5c98981321fa --- /dev/null +++ b/code/RL_model/verl/verl_train/verl/trainer/config/algorithm/rollout_correction.yaml @@ -0,0 +1,26 @@ +# Rollout Correction: corrects off-policy distribution shifts +# See documentation: docs/algo/rollout_corr.md +# Use presets: RolloutCorrectionConfig.decoupled_seq_is(), .bypass_pg_is(), etc. + +# IS aggregation level: null (disabled), "token" (per-token), "sequence" (per-sequence) +rollout_is: null + +# Upper threshold for IS weight truncation (typical: 2.0-5.0) +rollout_is_threshold: 2.0 + +# RS aggregation level: null (disabled), e.g. "token_k1", "seq_sum_k1", "seq_mean_k3" +rollout_rs: null + +# Threshold for rejection sampling (string or float; see code docs) +rollout_rs_threshold: null + +# Operating mode: false = Decoupled (3 policies), true = Bypass (2 policies) +bypass_mode: false + +# Loss type in bypass mode (bypass_mode=true): +# - "ppo_clip": PPO clipped objective (IS handled by ratio, default) +# - "reinforce": REINFORCE with explicit IS weights (no PPO clipping) +loss_type: ppo_clip + +# Batch normalize IS weights: false = raw weights, true = normalize to mean=1.0 +rollout_is_batch_normalize: false diff --git a/code/RL_model/verl/verl_train/verl/trainer/config/config.py b/code/RL_model/verl/verl_train/verl/trainer/config/config.py new file mode 100644 index 0000000000000000000000000000000000000000..bd323d09d0f624dc4330cd2085aced4165e33579 --- /dev/null +++ b/code/RL_model/verl/verl_train/verl/trainer/config/config.py @@ -0,0 +1,129 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass, field +from typing import Any, Optional + +from verl.base_config import BaseConfig + +__all__ = ["CheckpointConfig", "ProfileConfig", "BaseModelConfig"] + + +@dataclass +class CheckpointConfig(BaseConfig): + """Configuration for model checkpointing. + + The inheritance from BaseConfig provides omegaconf.DictConfig-like interface for a dataclass config. + + Args: + save_contents (list[str]): What to include in saved checkpoints. + Options: 'model', 'optimizer', 'extra', 'hf_model'. + load_contents (list[str]): Contents to load from checkpoint. Defaults to same as save_contents. + async_save (bool): Whether to save checkpoints asynchronously. Only implemented for Megatron as of now. + """ + + save_contents: list[str] = field(default_factory=lambda: ["model", "optimizer", "extra"]) + load_contents: list[str] = field(default_factory=lambda: ["model", "optimizer", "extra"]) + async_save: bool = False + + +@dataclass +class ProfileConfig(BaseConfig): + """Configuration for profiling. + + The inheritance from BaseConfig provides omegaconf.DictConfig-like interface for a dataclass config. + + Args: + profile_ranks (Optional[list[int]]): List of ranks to profile. None means all ranks. + step_start (int): Starting step for profiling. + step_end (int): Ending step for profiling. + save_path (Optional[str]): Path to save profiling results. + """ + + profile_ranks: Optional[list[int]] = None + step_start: int = -1 + step_end: int = -1 + save_path: Optional[str] = None + + +@dataclass +class BaseModelConfig(BaseConfig): + """Base configuration for a model. + Contains core settings for loading and initializing a pretrained model checkpoint. + + Args: + path (str): Path to pretrained model weights. + tokenizer_path (Optional[str]): Tokenizer path (defaults to actor's model path if not set). + override_config (dict): Hugging Face config override. + external_lib (Optional[str]): External model implementation (optional). + trust_remote_code (bool): Whether to trust remote code from Hugging Face models. + lora (dict[str, Any]): LoRA configuration dictionary. + """ + + path: str = "~/models/deepseek-llm-7b-chat" + tokenizer_path: Optional[str] = None + override_config: dict[str, Any] = field(default_factory=dict) + external_lib: Optional[str] = None + trust_remote_code: bool = False + lora: dict[str, Any] = field(default_factory=dict) + + +@dataclass +class ModuleConfig(BaseConfig): + """Configuration for external Python module, which can be loaded, executed (and optionally, ``import``ed). + + Args: + path (str, optional): Path to the module file to load and execute. + name (str, optional): Name of the module to ``import``. Format: ``"import.path.to.module"``. + If ``None``, the module will be loaded with a hased name and + will not be added to ``sys.modules``, thus can not be ``import``ed as ``name``. + """ + + path: Optional[str] = None + name: Optional[str] = None + + +@dataclass +class RewardManagerConfig(BaseConfig): + """Configuration for reward manager. + + A reward manager defines the mechanism of computing rule-based reward and handling different reward sources. + + Args: + source (str): Source of the reward manager. Options: ``"register"``, ``"importlib"``. Default: ``"register"``. + name (str, optional): + - When ``source`` is ``"register"``, the name is used in `get_reward_manager_cls(name)``. + See ``verl/experimental/reward/reward_manager.py`` for options. Default: ``"naive"``. + - When ``source`` is ``"importlib"``, the name is used in ``getattr(module, name)``, + e.g., ``"DAPORewardManager"``. + module (ModuleConfig, optional): Optional configuration for the external module defining the reward manager, + """ + + source: str = "register" + name: str = "naive" + module: Optional[ModuleConfig] = field(default_factory=ModuleConfig) + + def __post_init__(self): + super().__post_init__() + if self.source == "register": + from verl.workers.reward_manager.registry import REWARD_MANAGER_REGISTRY + + assert self.name in REWARD_MANAGER_REGISTRY, ( + f"Reward manager is not registered: {self.name=} ,{REWARD_MANAGER_REGISTRY.keys()=}" + ) + elif self.source == "importlib": + # NOTE: The existence is not checked since it depends on which machine the config is initialized on. + assert self.module is not None and self.module.path is not None, ( + "When source is importlib, module.path should be set." + ) diff --git a/code/RL_model/verl/verl_train/verl/trainer/config/critic/critic.yaml b/code/RL_model/verl/verl_train/verl/trainer/config/critic/critic.yaml new file mode 100644 index 0000000000000000000000000000000000000000..b0e52b12b752e9692290a27762c7fcfb7cf4a5c9 --- /dev/null +++ b/code/RL_model/verl/verl_train/verl/trainer/config/critic/critic.yaml @@ -0,0 +1,178 @@ +# Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs +_target_: verl.workers.config.CriticConfig + +# Number of rollouts per update (mirrors actor rollout_n) +rollout_n: ${oc.select:actor_rollout_ref.rollout.n,1} + +# fsdp or fsdp2 strategy used for critic model training +strategy: ??? + +# whether to enable the critic worker. +# by default it is only enabled if advantage estimator is gae +# set it to True manually if you always want to enable critic worker +enable: null + +# optimizer configs +optim: + + # Learning rate + lr: 1e-5 + + # Warmup steps ratio; total steps will be injected at runtime + lr_warmup_steps_ratio: 0.0 + + # Total training steps (must be overridden at runtime) + total_training_steps: -1 + + # Weight decay + weight_decay: 0.01 + + # Prioritized. None, 0 or Negative values mean delegating to lr_warmup_steps_ratio. + lr_warmup_steps: -1 + + +# model config for the critic +model: + + # Path to pretrained model weights + path: ~/models/deepseek-llm-7b-chat + + # Tokenizer path (defaults to actor's model path) + tokenizer_path: ${oc.select:actor_rollout_ref.model.path,"~/models/deepseek-llm-7b-chat"} + + # Hugging Face config override + override_config: {} + + # External model implementation (optional) + external_lib: ${oc.select:actor_rollout_ref.model.external_lib,null} + + # Whether to trust remote code from Hugging Face models + trust_remote_code: ${oc.select:actor_rollout_ref.model.trust_remote_code,false} + +# PPO mini-batch size per update +ppo_mini_batch_size: ${oc.select:actor_rollout_ref.actor.ppo_mini_batch_size,256} + +# [Deprecated] Global micro batch size +ppo_micro_batch_size: null + +# Local per-GPU micro batch size +ppo_micro_batch_size_per_gpu: ${oc.select:.ppo_micro_batch_size,null} + +# Whether to automatically adjust batch size at runtime +use_dynamic_bsz: ${oc.select:actor_rollout_ref.actor.use_dynamic_bsz,false} + +# Max tokens per GPU in one PPO batch (doubled for critic) +ppo_max_token_len_per_gpu: 32768 + +# Max token length per GPU in forward pass +forward_max_token_len_per_gpu: ${.ppo_max_token_len_per_gpu} + +# Number of PPO epochs per batch +ppo_epochs: ${oc.select:actor_rollout_ref.actor.ppo_epochs,1} + +# Shuffle training data across PPO epochs +shuffle: ${oc.select:actor_rollout_ref.actor.shuffle,false} + +# The seed used to construct mini-batch +data_loader_seed: 42 + +# PPO value function clipping range +cliprange_value: 0.5 + +# Loss aggregation mode: "token-mean", "seq-mean-token-sum", or "seq-mean-token-mean" +loss_agg_mode: ${oc.select:actor_rollout_ref.actor.loss_agg_mode,token-mean} + +# checkpoint configs +checkpoint: + + # Target dataclass for this configuration + _target_: verl.trainer.config.CheckpointConfig + + # What to include in saved checkpoints + # with 'hf_model' you can save whole model as hf format, now only use sharded model checkpoint to save space + save_contents: ['model', 'optimizer', 'extra'] + + # What to include when loading checkpoints + load_contents: ${.save_contents} + + # Whether to save checkpoints asynchronously. Only effective for Megatron as of now. + async_save: False + +# profile the critic model in `update_critic` +profiler: + + # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs + _target_: verl.utils.profiler.ProfilerConfig + + # profiler tool, default same as profiler.tool in global config + # choices: nsys, npu, torch, torch_memory + tool: ${oc.select:global_profiler.tool,null} + + # whether enable profile on Critic + enable: False + + # Whether to profile all ranks. + all_ranks: False + + # The ranks that will be profiled. [] or [0,1,...] + ranks: [] + + # profile results saving path + save_path: ${oc.select:global_profiler.save_path,null} + + # specific tool config which only related to the role + tool_config: + + # nsys tool config + nsys: + + # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs + _target_: verl.utils.profiler.config.NsightToolConfig + + # True for each task has its own database, False for all tasks in one training step share one database. + discrete: ${oc.select:global_profiler.global_tool_config.nsys.discrete} + + # npu config + npu: + + # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs + _target_: verl.utils.profiler.config.NPUToolConfig + + # Contents to profile, can be empty + # options: npu, cpu, memory, shapes, module, stack + contents: [] + + # Collection level, optional values: level_none, level0, level1, level2. + level: "level0" + + # Whether to automatically parse the data. + analysis: True + + # True for each task has its own database, False for all tasks in one training step share one database. + discrete: False + + # torch profiler config + torch: + + # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs + _target_: verl.utils.profiler.config.TorchProfilerToolConfig + + # Contents to profile, can be empty + # options: cuda, cpu, memory, shapes, stack + contents: [] + + # True for each task has its own database, False for all tasks in one training step share one database. + discrete: false + + # torch memory profiler config + torch_memory: + + # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs + _target_: verl.utils.profiler.config.TorchMemoryToolConfig + + # Maximum number of memory allocation entries to track + trace_alloc_max_entries: ${oc.select:global_profiler.global_tool_config.torch_memory.trace_alloc_max_entries,100000} + + # Stack trace depth for memory allocations + stack_depth: ${oc.select:global_profiler.global_tool_config.torch_memory.stack_depth,32} + \ No newline at end of file diff --git a/code/RL_model/verl/verl_train/verl/trainer/config/critic/dp_critic.yaml b/code/RL_model/verl/verl_train/verl/trainer/config/critic/dp_critic.yaml new file mode 100644 index 0000000000000000000000000000000000000000..1cbaf03444a30aa9da87c6786a6bb48f9fc84f9d --- /dev/null +++ b/code/RL_model/verl/verl_train/verl/trainer/config/critic/dp_critic.yaml @@ -0,0 +1,75 @@ +# Format checks enforced on CI: +# 1. Comments must appear above each field. +# 2. There must be a blank line between each field. +# 3. Inline comments (after a field on the same line) are not allowed. +# 4. Indentation level is respected for nested fields. + +# defaults specify the default config from each component +defaults: + + # fsdp optimizer config + - ../optim@optim: fsdp + + # fsdp engine config + - ../engine@model.fsdp_config: fsdp + + # dp actor config, inheriting from trainer/config/critic/critic.yaml + - critic + + # load the reference default config, then apply the fields in the current yaml + - _self_ + +# Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs +_target_: verl.workers.config.FSDPCriticConfig + +# distribution strategy. Options: fsdp (deprecating), fsdp2 +strategy: fsdp + +# model config for the critic +model: + + # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs + _target_: verl.workers.config.FSDPCriticModelCfg + + # Whether to use shared memory for loading the model + use_shm: False + + # Enable gradient checkpointing to save memory + enable_gradient_checkpointing: True + + # Offload activations to CPU to reduce GPU memory usage + enable_activation_offload: False + + # Use remove padding optimization (saves compute) + use_remove_padding: False + + # Set to positive value to enable LoRA (e.g., 32) + lora_rank: 0 + + # LoRA scaling factor + lora_alpha: 16 + + # LoRA target modules: "all-linear" or list of linear projection layers + target_modules: all-linear + + # TiledMLP configuration for memory-efficient MLP computation. + tiled_mlp: + + # whether to enable TiledMLP + enabled: False + + # number of shards to split the input + num_shards: 4 + +# Forward-only batch size during inference (global) +forward_micro_batch_size: ${oc.select:.ppo_micro_batch_size,null} + +# Forward-only batch size during inference (per GPU) +forward_micro_batch_size_per_gpu: ${oc.select:.ppo_micro_batch_size_per_gpu,null} + +# Sequence parallelism size for Ulysses-style model parallelism +# [DEPRECATED] use fsdp_config.ulysses_sequence_parallel_size instead +ulysses_sequence_parallel_size: 1 + +# Gradient clipping for critic updates +grad_clip: 1.0 diff --git a/code/RL_model/verl/verl_train/verl/trainer/config/critic/megatron_critic.yaml b/code/RL_model/verl/verl_train/verl/trainer/config/critic/megatron_critic.yaml new file mode 100644 index 0000000000000000000000000000000000000000..3f170575cdc63a28a804f26f42901ba79a1fc898 --- /dev/null +++ b/code/RL_model/verl/verl_train/verl/trainer/config/critic/megatron_critic.yaml @@ -0,0 +1,106 @@ +# defaults specify the default config from each component +defaults: + + # megatron optimizer config + - ../optim@optim: megatron + + # megatron engine config + - ../engine@megatron: megatron + + # dp actor config, inheriting from trainer/config/critic/critic.yaml + - critic + + # load the reference default config, then apply the fields in the current yaml + - _self_ + +# Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs +_target_: verl.workers.config.McoreCriticConfig + +strategy: megatron + +# seconds, default is 10 minutes for torch, you can set it to a larger value if you have long-running operations like 32B or 72B model using megatron +nccl_timeout: 600 + +# model config for the critic +model: + + # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs + _target_: verl.trainer.config.BaseModelConfig + + # override default empty mapping + override_config: + + model_config: {} + + moe_config: + + freeze_moe_router: False + + # LoRA (Low-Rank Adaptation) configuration for parameter-efficient fine-tuning + lora: + # LoRA type: "lora", "vlm_lora", "canonical_lora", or "dora" + type: lora + + # LoRA rank (Dimension of the low-rank projection space.). Set to 0 to disable LoRA + rank: 0 # typical values: 8, 16, 32, 64 + + # Weighting factor for the low-rank projection. Defaults to 32 + alpha: 32 + + # Dropout rate for the low-rank projection. Defaults to 0.0 + dropout: 0.0 + + # A list of module names to apply LoRA to. + # For fused LoRA, Defaults to all linear layers ['linear_qkv', 'linear_proj', 'linear_fc1', 'linear_fc2']. + # For canonical LoRA: ["linear_q", "linear_k", "linear_v", "linear_proj", "linear_fc1_up", "linear_fc1_gate", "linear_fc2"] + # - 'linear_qkv': Apply LoRA to the fused linear layer used for query, key, and value projections in self-attention + # - 'linear_proj': Apply LoRA to the linear layer used for projecting the output of self-attention + # - 'linear_fc1': Apply LoRA to the first fully-connected layer in MLP + # - 'linear_fc2': Apply LoRA to the second fully-connected layer in MLP + # Target modules can also contain wildcards. For example, you can specify + # target_modules=['*.layers.0.*.linear_qkv', '*.layers.1.*.linear_qkv'] to add LoRA to only linear_qkv on the first two layers + # + # Note: + # For MLA (e.g., DeepSeek), you should use ["linear_kv_down_proj","linear_kv_up_proj","linear_q_down_proj","linear_q_up_proj","linear_q_proj"] + # Instead of "linear_qkv" or ["linear_q","linear_k","linear_v"] + # By default, MoE routers are excluded from LoRA adaptation, and you will need to specify "router" in target_modules to include them. + target_modules: + - linear_qkv + - linear_proj + - linear_fc1 + - linear_fc2 + + # A list of module names not to apply LoRa to. It will match all nn.Linear & nn.Linear-adjacent modules whose name + # does not match any string in exclude_modules. If used, will require target_modules to be empty list or null + exclude_modules: [] + + # Position for applying dropout, can be 'pre' (before the low-rank projection) or 'post' (after). Defaults to 'pre' + dropout_position: pre + + # Initialization method for the low-rank matrix A. Defaults to "xavier". + lora_A_init_method: xavier + + # Initialization method for the low-rank matrix B. Defaults to "zero". + lora_B_init_method: zero + + # Enables the experimental All-to-All (A2A) communication strategy. Defaults to False + a2a_experimental: False + + # Parameter data type for LoRA weights. Default to null, which will use model's dtype. + dtype: null + + # Path to pre-trained LoRA adapter weights (null to train from scratch) + adapter_path: null + + # VLMLoRA additionally allows the user to specify whether the language or vision models should be frozen. + # For example, a common finetuning workload for multimodal models is to apply adapters to language model and fully + # finetune the vision model. + freeze_vision_model: True + freeze_vision_projection: True + freeze_language_model: True + +# Whether to load initial weights +load_weight: True + +# seed for data loader +data_loader_seed: ${oc.select:actor_rollout_ref.actor.data_loader_seed,null} diff --git a/code/RL_model/verl/verl_train/verl/trainer/config/data/legacy_data.yaml b/code/RL_model/verl/verl_train/verl/trainer/config/data/legacy_data.yaml new file mode 100644 index 0000000000000000000000000000000000000000..60818f9e198e86266f51c5ac6c997fe73fe38300 --- /dev/null +++ b/code/RL_model/verl/verl_train/verl/trainer/config/data/legacy_data.yaml @@ -0,0 +1,131 @@ +# Tokenizer class or path. If null, it will be inferred from the model. +tokenizer: null + +# Whether to use shared memory for data loading. +use_shm: False + +# Training set parquet. Can be a list or a single file. +# The program will read all files into memory, so it can't be too large (< 100GB). +# The path can be either a local path or an HDFS path. +# For HDFS path, we provide utils to download it to DRAM and convert it to a local path. +train_files: ~/data/rlhf/gsm8k/train.parquet + +# Validation parquet. Can be a list or a single file. +val_files: ~/data/rlhf/gsm8k/test.parquet + +# Maximum sample length to be used. +# Set to -1 to use full dataset, otherwise, randomly +# select the specified number of samples from train dataset +train_max_samples: -1 + +# Maximum sample length to be used. +# Set to -1 to use full dataset, otherwise, randomly +# select the specified number of samples from val dataset +val_max_samples: -1 + +# The field in the dataset where the prompt is located. Default is 'prompt'. +prompt_key: prompt + +# The field used to select the reward function (if using different ones per example). +reward_fn_key: data_source + +# Maximum prompt length. All prompts will be left-padded to this length. +# An error will be reported if the length is too long. +# oc.select: default val for rollout.prompt_length +max_prompt_length: 512 + +# Maximum response length. Rollout in RL algorithms (e.g. PPO) generates up to this length. +# oc.select: default val for rollout.response_length +max_response_length: 512 + +# Batch size sampled for one training iteration of different RL algorithms. +train_batch_size: 1024 + +# Batch size used during validation. Can be null. +val_batch_size: null + +# use tool config to calculate true prompt length +tool_config_path: ${oc.select:actor_rollout_ref.rollout.multi_turn.tool_config_path, null} + +# Whether to return the original input_ids without adding chat template. +# This is used when the reward model's chat template differs from the policy. +# If using a model-based RM with different templates, this should be True. +return_raw_input_ids: False + +# Whether to return the original chat (prompt) without applying chat template. +return_raw_chat: True + +# Whether to return the full prompt with chat template. +return_full_prompt: False + +# Whether to shuffle the data in the dataloader. +shuffle: True + +# Seed to use when shuffling the data +seed: null + +# num dataloader workers +dataloader_num_workers: 8 + +# image patch size +image_patch_size: 14 + +# Whether to shuffle the validation set. +validation_shuffle: False + +# Whether to filter overlong prompts. +filter_overlong_prompts: False + +# Number of workers for filtering overlong prompts. +# For large-scale datasets, filtering can be time-consuming. +# Use multiprocessing to speed up. Default is 1. +filter_overlong_prompts_workers: 1 + +# Truncate the input_ids or prompt if they exceed max_prompt_length. +# Options: 'error', 'left', 'right', 'middle'. Default is 'error'. +truncation: error + +# The field in the multi-modal dataset where the image is located. Default is 'images'. +image_key: images + +# The field in the multi-modal dataset where the video is located. +video_key: videos + +# If the remote tokenizer has a Python file, this flag determines whether to allow using it. +trust_remote_code: False + +# Optional: specify a custom dataset class path and name if overriding default loading behavior. +custom_cls: + + # The path to the file containing your customized dataset class. If not specified, pre-implemented dataset will be used. + path: null + + # The name of the dataset class within the specified file. + name: null + +# Whether to return multi-modal inputs in the dataset. Set to False if rollout generates new multi-modal inputs. +return_multi_modal_inputs: True + +# settings related to data sampler +sampler: + + # the path to the module containing a curriculum class which implements the + # AbstractSampler interface + class_path: null + + # the name of the curriculum class like `MySampler` + class_name: null + +# Data generation configuration for augmenting the dataset. +datagen: + + # The path to the file containing your customized data generation class. + # E.g. 'pkg://verl.experimental.dynamic_dataset.dynamicgen_dataset' + path: null + + # The class name of the data generation class within the specified file. + # E.g. 'MockDataGenerator' + name: null + +# Additional kwargs when calling tokenizer.apply_chat_template +apply_chat_template_kwargs: {} diff --git a/code/RL_model/verl/verl_train/verl/trainer/config/engine/fsdp.yaml b/code/RL_model/verl/verl_train/verl/trainer/config/engine/fsdp.yaml new file mode 100644 index 0000000000000000000000000000000000000000..81d17e06add64db1f570566adee95639b6f10273 --- /dev/null +++ b/code/RL_model/verl/verl_train/verl/trainer/config/engine/fsdp.yaml @@ -0,0 +1,63 @@ +# Target class for this configuration +_target_: verl.workers.config.FSDPEngineConfig + +# policy for wrapping the model +wrap_policy: + + # Minimum number of parameters to trigger wrapping a layer with FSDP + min_num_params: 0 + +# Whether to offload model parameters to CPU (trades speed for memory) +# Note that this differs from the offload_policy in FSDP +param_offload: false + +# Whether to offload optimizer state to CPU +# Note that this differs from the offload_policy in FSDP +optimizer_offload: false + +# Only for FSDP2: offload param/grad/optimizer during train +offload_policy: false + +# Reshard after forward pass to reduce memory footprint +# For FSDP1, `false` enables `ShardingStrategy.SHARD_GRAD_OP` +reshard_after_forward: true + +# Number of GPUs in each FSDP shard group; -1 means auto +fsdp_size: -1 + +# Only for FSDP1: FSDP1 configuration, prefetch the next forward-pass all-gather +# before the current forward computation. +forward_prefetch: False + +# model dtype of fsdp +model_dtype: fp32 + +# Whether to use original parameters in fsdp. Only avaiable in fsdp1 +use_orig_params: false + +# Random seed for reproducibility. +seed: 42 + +# Whether to enable full determinism for distributed training, only for debugging. +full_determinism: false + +# ulysses sequence parallel size +ulysses_sequence_parallel_size: 1 + +# Whether to use entropy_from_logits_with_chunking in fsdp. +entropy_from_logits_with_chunking: false + +# Whether to use torch compile in fsdp. +use_torch_compile: true + +# Whether to use entropy checkpointing in fsdp. +entropy_checkpointing: false + +# Whether to use forward only in fsdp. +forward_only: false + +# fsdp or fsdp2 +strategy: fsdp + +# Mixed precision training param dtype +dtype: bfloat16 # ["bfloat16", "float16"] diff --git a/code/RL_model/verl/verl_train/verl/trainer/config/engine/megatron.yaml b/code/RL_model/verl/verl_train/verl/trainer/config/engine/megatron.yaml new file mode 100644 index 0000000000000000000000000000000000000000..b588a96c1b3993f85de13179da6c4c84f66c795f --- /dev/null +++ b/code/RL_model/verl/verl_train/verl/trainer/config/engine/megatron.yaml @@ -0,0 +1,90 @@ +# Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs +_target_: verl.workers.config.McoreEngineConfig + +# Whether to offload model parameters to CPU +param_offload: False + +# Whether to offload gradients to CPU +grad_offload: False + +# Whether to offload optimizer state to CPU +optimizer_offload: False + +# tensor model parallel size +tensor_model_parallel_size: 1 + +# expert model parallel size +expert_model_parallel_size: 1 + +# expert tensor parallel size (null to be same as TP) +expert_tensor_parallel_size: null + +# pipeline model parallel size +pipeline_model_parallel_size: 1 + +# virtual pipeline model parallel size +virtual_pipeline_model_parallel_size: null + +# context parallel size +context_parallel_size: 1 + +# sequence parallel +sequence_parallel: True + +# Whether to use distributed optimizer +use_distributed_optimizer: True + +# Whether to use distributed checkpointing +use_dist_checkpointing: False + +# distributed checkpointing path +dist_checkpointing_path: null + +# distributed checkpointing prefix, e.g. Nemo2 will append prefix 'module.' to the state dict keys +dist_checkpointing_prefix: '' + +# oc.select: default val for ref.megatron.seed +seed: 42 + +# Allow to override Distributed Data Parallel (DDP) config +override_ddp_config: {} + +# additional transformer config like: num_layers_in_first(/last)_pipeline_stage +# oc.select: default val for ref.megatron.override_transformer_config +override_transformer_config: + # Recompute configuration, same as in megatron.training.arguments + # default use minimal performance-interference recompute methods + # Recompute granualarity, choices: ["full", "selective"] + recompute_granularity: null + + # Recompute modules, multiple choices: ["core_attn", "moe_act", "layernorm", "mla_up_proj", "mlp", "moe"] + # Please use correct module in matched model + recompute_modules: ["core_attn"] + + # 'uniform', 'block' + # 'uniform' divides the total number of transformer layers and checkpoints the input activation of each chunk + # 'block' checkpoints the specified number of layers per pipeline stage at the specified granularity + recompute_method: null + + # 'full' will checkpoint the entire transformer layer and 'selective' only checkpoints memory intensive part of attention + recompute_num_layers: null + + # Attention backend to use (flash,fused,unfused,local,auto). Defaults to auto in mcore, flash in verl + attention_backend: flash + +override_mcore_model_config: {} + +# oc.select: default val for ref.megatron.use_mbridge +use_mbridge: True + +# oc.select: default val for ref.megatron.vanilla_mbridge +vanilla_mbridge: True + +# whether to use thd format (sequence packing), if not, use bshd format, padding the input_ids to the longest sequence length +use_remove_padding: True + +# whether to use forward only +forward_only: False + +# Mixed precision training param dtype +dtype: bfloat16 # ["bfloat16", "float16"] diff --git a/code/RL_model/verl/verl_train/verl/trainer/config/engine/veomni.yaml b/code/RL_model/verl/verl_train/verl/trainer/config/engine/veomni.yaml new file mode 100644 index 0000000000000000000000000000000000000000..da70cfabe51aeec48498fa6894d14e4ceba7cf0d --- /dev/null +++ b/code/RL_model/verl/verl_train/verl/trainer/config/engine/veomni.yaml @@ -0,0 +1,68 @@ +# Target class for this configuration +_target_: verl.workers.config.VeOmniEngineConfig + +# Whether to offload model parameters to CPU +param_offload: False + +# Whether to offload optimizer state to CPU +optimizer_offload: False + +# fsdp or fsdp2 +data_parallel_mode: fsdp2 + +data_parallel_size: 1 + +data_parallel_replicate_size: 1 + +data_parallel_shard_size: 1 + +tensor_parallel_size: 1 + +expert_parallel_size: 1 + +pipeline_parallel_size: 1 + +context_parallel_size: 1 + +ulysses_parallel_size: 1 + +mixed_precision: true + +# Random seed for reproducibility. +seed: 42 + +# Whether to enable full determinism for distributed training, only for debugging. +full_determinism: false + +init_device: meta + +enable_full_shard: true + +ckpt_manager: dcp + +# Only for FSDP1: FSDP1 configuration, prefetch the next forward-pass all-gather +# before the current forward computation. +forward_prefetch: true + +strategy: veomni + +# Whether to use torch compile in fsdp. +use_torch_compile: false + +# Whether to use forward only in fsdp. +forward_only: false + +enable_fsdp_offload: false + +enable_reentrant: false + +# support eager, sdpa, flash_attention_2, flash_attention_3, veomni_flash_attention_2_with_sp, +# veomni_flash_attention_3_with_sp and native-sparse +attn_implementation: flash_attention_2 + +# eager or fused +moe_implementation: fused + +force_use_huggingface: false + +activation_gpu_limit: 0.0 diff --git a/code/RL_model/verl/verl_train/verl/trainer/config/evaluation.yaml b/code/RL_model/verl/verl_train/verl/trainer/config/evaluation.yaml new file mode 100644 index 0000000000000000000000000000000000000000..6a88d77f1e73b6c3cce1972f639fcafb412669fa --- /dev/null +++ b/code/RL_model/verl/verl_train/verl/trainer/config/evaluation.yaml @@ -0,0 +1,15 @@ +data: + path: /tmp/math_Qwen2-7B-Instruct.parquet + prompt_key: prompt + response_key: responses + data_source_key: data_source + reward_model_key: reward_model + +custom_reward_function: + path: null + name: compute_score + +ray_kwargs: + ray_init: + num_cpus: null # `None` means using all CPUs, which might cause hang if limited in systems like SLURM. Please set to a number allowed then. + timeline_json_file: null diff --git a/code/RL_model/verl/verl_train/verl/trainer/config/generation.yaml b/code/RL_model/verl/verl_train/verl/trainer/config/generation.yaml new file mode 100644 index 0000000000000000000000000000000000000000..478733339ceabaf1ec5b71f381895ccc7d24ebea --- /dev/null +++ b/code/RL_model/verl/verl_train/verl/trainer/config/generation.yaml @@ -0,0 +1,62 @@ +trainer: + nnodes: 1 + n_gpus_per_node: 8 + device: cuda + +data: + path: ~/data/rlhf/math/test.parquet + prompt_key: prompt + n_samples: 5 + output_path: /opt/tiger/math_Qwen2-7B-Instruct.parquet + batch_size: 128 + +model: + path: ~/models/Qwen2-7B-Instruct + external_lib: null +rollout: + _target_: verl.workers.config.RolloutConfig + name: vllm + # NOTE: 'sync' mode was removed in PR #4411. Only 'async' mode is supported. + # WARNING: The main_generation.py workflow is currently broken for vLLM async rollout + # as it requires synchronous generate_sequences() which vLLMAsyncRollout doesn't support. + # See issue #4682 for discussion and workarounds. + mode: async + temperature: 1.0 + top_k: 50 # 0 for hf rollout, -1 for vllm rollout + top_p: 0.7 + prompt_length: 1536 + response_length: 512 + # for vllm rollout + dtype: bfloat16 # should align with FSDP + gpu_memory_utilization: 0.5 + ignore_eos: False + enforce_eager: True + free_cache_engine: True + load_format: auto + tensor_model_parallel_size: 1 + data_parallel_size: 1 + max_num_batched_tokens: 8192 + max_model_len: null + max_num_seqs: 1024 + log_prob_micro_batch_size: null # will be deprecated, use log_prob_micro_batch_size_per_gpu + log_prob_micro_batch_size_per_gpu: 8 + # for hf rollout + do_sample: True + disable_log_stats: True + enable_chunked_prefill: True + n: 1 + # support logging rollout prob for debugging purpose + calculate_log_probs: False +actor: + strategy: fsdp # This is for backward-compatibility + ulysses_sequence_parallel_size: 1 # sp size + entropy_from_logits_with_chunking: False # calculate entropy with chunking to reduce memory peak + entropy_checkpointing: False # recompute entropy + fsdp_config: + fsdp_size: -1 + forward_prefetch: False # FSDP1 forward_prefetch configuration + +ray_kwargs: + ray_init: + num_cpus: null # `None` means using all CPUs, which might cause hang if limited in systems like SLURM. Please set to a number allowed then. + timeline_json_file: null diff --git a/code/RL_model/verl/verl_train/verl/trainer/config/model/hf_model.yaml b/code/RL_model/verl/verl_train/verl/trainer/config/model/hf_model.yaml new file mode 100644 index 0000000000000000000000000000000000000000..4002a7f68c239824510b53bd80e38c960bae9df6 --- /dev/null +++ b/code/RL_model/verl/verl_train/verl/trainer/config/model/hf_model.yaml @@ -0,0 +1,97 @@ +# Format checks enforced on CI: +# 1. Comments must appear above each field. +# 2. There must be a blank line between each field. +# 3. Inline comments (after a field on the same line) are not allowed. +# 4. Indentation level is respected for nested fields. + +_target_: verl.workers.config.HFModelConfig + +# path to the huggingface model +path: ~/models/deepseek-llm-7b-chat + +# config to the huggingface config. In case it is not the same as path +hf_config_path: null + +# path to the huggingface tokenizer. In case it is not the same as path +tokenizer_path: null + +# whether to use shared memory for model loading +use_shm: False + +# whether to trust remote code. +trust_remote_code: False + +# custom chat template for the model +custom_chat_template: null + +# whether to use external libs for the model +external_lib: null + +# override hf config +override_config: {} + +# whether to enable gradient checkpointing. Only valid when we use hf model definition +enable_gradient_checkpointing: True + +# whether to enable activation offload. Only valid when we use hf model definition +enable_activation_offload: False + +# whether to use remove padding. Only valid when we use hf model definition +use_remove_padding: True + +# Set to positive value to enable LoRA (e.g., 32) +lora_rank: 0 + +# LoRA scaling factor +lora_alpha: 16 + +# Target modules for LoRA adaptation +target_modules: all-linear + +# Exclude modules from LoRA adaptation +exclude_modules: null + +# Path to pre-trained LoRA adapter to load for continued training +lora_adapter_path: null + +# whether to use liger. Only valid when we use hf model definition +use_liger: False + +# whether to use fused kernels. +use_fused_kernels: False + +# fused kernel options. +fused_kernel_options: + + # the implementation backend for fused kernels. + impl_backend: torch + +# TiledMLP configuration for memory-efficient MLP computation. +# Reduces peak memory by processing MLP forward/backward in tiles. +tiled_mlp: + + # whether to enable TiledMLP + enabled: False + + # number of shards to split the input. Higher values reduce peak memory but may slightly impact performance. + num_shards: 4 + +# MTP +mtp: + + _target_: verl.workers.config.MtpConfig + + enable: False + enable_train: False + enable_rollout: False + + detach_encoder: False + mtp_loss_scaling_factor: 0.1 + + speculative_algorithm: EAGLE + speculative_num_steps: 3 + speculative_eagle_topk: 1 + speculative_num_draft_tokens: 4 + + method: mtp + num_speculative_tokens: 1 diff --git a/code/RL_model/verl/verl_train/verl/trainer/config/npu_profile/npu_profile.yaml b/code/RL_model/verl/verl_train/verl/trainer/config/npu_profile/npu_profile.yaml new file mode 100644 index 0000000000000000000000000000000000000000..bb34dc7cf5988cde5e03b1544020388d9dda1ec7 --- /dev/null +++ b/code/RL_model/verl/verl_train/verl/trainer/config/npu_profile/npu_profile.yaml @@ -0,0 +1,34 @@ +# Options for the npu profiler +options: + + # Storage path of collected data. + save_path: ./profiler_data + + # The roles that will be profiled. Only takes effect in discrete mode. + # optional values: all, rollout_generate, actor_compute_log_prob, actor_update and ref_compute_log_prob. + # "all" means all roles will be profiled. + roles: ["all"] + + # Collection level, optional values: level_none, level0, level1, level2. + level: level0 + + # Whether to enable memory analysis. + with_memory: False + + # Whether to record tensor shape. + record_shapes: False + + # Whether to record Device-side performance data. + with_npu: True + + # Whether to record Host-side performance data. + with_cpu: True + + # Whether to record Python call stack information. + with_module: False + + # Whether to record operator call stack information. + with_stack: False + + # Whether to automatically parse the data. + analysis: True \ No newline at end of file diff --git a/code/RL_model/verl/verl_train/verl/trainer/config/optim/fsdp.yaml b/code/RL_model/verl/verl_train/verl/trainer/config/optim/fsdp.yaml new file mode 100644 index 0000000000000000000000000000000000000000..a7dd99b1ee2a3c724dd2b45b4db75b86dadcffa0 --- /dev/null +++ b/code/RL_model/verl/verl_train/verl/trainer/config/optim/fsdp.yaml @@ -0,0 +1,50 @@ +# Target class for this configuration +_target_: verl.workers.config.FSDPOptimizerConfig + +# Optimizer class name (e.g., "AdamW", "AdamW8bit", "_AdamW", "Adam") +optimizer: AdamW + +# Module path to import optimizer +# Examples: "torch.optim", "torchao.optim", "bitsandbytes.optim" +optimizer_impl: torch.optim + +# Learning rate +lr: 1e-3 + +# LR warmup steps ratio +lr_warmup_steps_ratio: 0.0 + +# Total training steps +total_training_steps: -1 + +# Weight decay +weight_decay: 0.01 + +# LR warmup steps +lr_warmup_steps: -1 + +# Betas for Adam optimizer +betas: [0.9, 0.999] + +# Clip gradient +clip_grad: 1.0 + +# Minimum LR ratio for cosine schedule +min_lr_ratio: 0.0 + +# Number of cosine cycles in LR schedule +num_cycles: 0.5 + +# LR scheduler type: "constant" or "cosine" +lr_scheduler_type: constant + +# deprecated +warmup_style: null + +# Additional optimizer-specific keyword arguments +# Example for torchao with bf16 stochastic rounding: +# optimizer_impl: torchao.optim +# optimizer: _AdamW +# override_optimizer_config: +# bf16_stochastic_round: true +override_optimizer_config: null diff --git a/code/RL_model/verl/verl_train/verl/trainer/config/optim/megatron.yaml b/code/RL_model/verl/verl_train/verl/trainer/config/optim/megatron.yaml new file mode 100644 index 0000000000000000000000000000000000000000..c3e49b7df8e59d33f51b50b943d9353af66d296c --- /dev/null +++ b/code/RL_model/verl/verl_train/verl/trainer/config/optim/megatron.yaml @@ -0,0 +1,49 @@ +_target_: verl.workers.config.McoreOptimizerConfig + +# Learning rate +lr: 1e-3 + +# LR warmup steps ratio +lr_warmup_steps_ratio: 0.0 + +# Total training steps +total_training_steps: -1 + +# Weight decay +weight_decay: 0.01 + +# LR warmup steps +lr_warmup_steps: -1 + +# Betas for Adam optimizer +betas: [0.9, 0.999] + +# Clip gradient +clip_grad: 1.0 + +# optimizer type +optimizer: adam + +# initial learning rate for warmup, default to 0.0 +lr_warmup_init: 0.0 + +lr_decay_steps: null + +# select from constant/linear/cosine/inverse_square_root +lr_decay_style: constant + +# minimum learning rate, default to 0.0 +min_lr: 0.0 + +# select from constant/linear/cosine +weight_decay_incr_style: constant + +# select from constant/exponential/cosine +lr_wsd_decay_style: exponential + +lr_wsd_decay_steps: null + +# use checkpoint optimizer parameter scheduler +use_checkpoint_opt_param_scheduler: False + +override_optimizer_config: {} diff --git a/code/RL_model/verl/verl_train/verl/trainer/config/optim/veomni.yaml b/code/RL_model/verl/verl_train/verl/trainer/config/optim/veomni.yaml new file mode 100644 index 0000000000000000000000000000000000000000..ed9c69deb97a17902902f2cabd28a3c5ebe13377 --- /dev/null +++ b/code/RL_model/verl/verl_train/verl/trainer/config/optim/veomni.yaml @@ -0,0 +1,39 @@ +# Target class for this configuration +_target_: verl.workers.config.VeOmniOptimizerConfig + +optimizer: adamw + +# Learning rate +lr: 1e-3 + +# Minimum learning rate +lr_min: 0.0 + +# Starting learning rate for warmup +lr_start: 0.0 + +# LR warmup steps ratio +lr_warmup_steps_ratio: 0.0 + +# LR decay steps ratio +lr_decay_ratio: 1.0 + +# Total training steps +total_training_steps: -1 + +# Weight decay +weight_decay: 0.01 + +# LR warmup steps +lr_warmup_steps: -1 + +# Betas for Adam optimizer +betas: [0.9, 0.999] + +# Clip gradient +clip_grad: 1.0 + +# LR scheduler type: "constant" or "cosine" +lr_scheduler_type: cosine + +override_optimizer_config: {} diff --git a/code/RL_model/verl/verl_train/verl/trainer/config/ppo_megatron_trainer.yaml b/code/RL_model/verl/verl_train/verl/trainer/config/ppo_megatron_trainer.yaml new file mode 100644 index 0000000000000000000000000000000000000000..76ba4c5757512c44e2bab9e06a2c82ad66870872 --- /dev/null +++ b/code/RL_model/verl/verl_train/verl/trainer/config/ppo_megatron_trainer.yaml @@ -0,0 +1,248 @@ +# specify the default per-component configs +defaults: + # @.: + # actor_rollout_ref.actor: trainer/config/actor/megatron_actor.yaml + - actor@actor_rollout_ref.actor: megatron_actor + # data: trainer/config/data/legacy_data.yaml + - data@data: legacy_data + # (Rule-based) Reward manager config. + - reward_manager@reward_manager + # load the reference default config, then apply the fields in the current yaml + # Reference model config. + # Reference model will be enabled when actor.use_kl_loss or/and algorithm.use_kl_in_reward is/are True. + - ref@actor_rollout_ref.ref: megatron_ref + # Rollout model config. + - rollout@actor_rollout_ref.rollout: rollout + # Model config. + - model@actor_rollout_ref.model: hf_model + # Critic model config. + - critic@critic: megatron_critic + # Reward model config. + - reward_model@reward_model: megatron_reward_loop + # Rollout correction config. + - algorithm@algorithm.rollout_correction: rollout_correction + - _self_ + +actor_rollout_ref: + hybrid_engine: True + + nccl_timeout: 600 # seconds, default is 10 minutes for torch, you can set it to a larger value if you have long-running operations like 32B or 72B model using megatron + + model: + override_config: + model_config: {} + moe_config: + freeze_moe_router: False + + use_fused_kernels: False # Whether to use custom fused kernels (PostProcessing, for memory efficiency) + + trust_remote_code: False + + # Whether to remove padding tokens in inputs during training + use_remove_padding: false + + # LoRA (Low-Rank Adaptation) configuration for parameter-efficient fine-tuning + lora: + # LoRA type: "lora", "vlm_lora", "canonical_lora", or "dora" + type: lora + + # whether to sync weights / refit by either merging LoRA adapters into the base model weights before transferring to vLLM (for better inference speed but more refit time and potential precision loss). If this is False, it will load separate adapters. + merge: False + + # LoRA rank (Dimension of the low-rank projection space.). Set to 0 to disable LoRA + rank: 0 # typical values: 8, 16, 32, 64 + + # Weighting factor for the low-rank projection. Defaults to 32 + alpha: 32 + + # Dropout rate for the low-rank projection. Defaults to 0.0 + dropout: 0.0 + + # A list of module names to apply LoRA to. + # For fused LoRA, Defaults to all linear layers ['linear_qkv', 'linear_proj', 'linear_fc1', 'linear_fc2']. + # For canonical LoRA: ["linear_q", "linear_k", "linear_v", "linear_proj", "linear_fc1_up", "linear_fc1_gate", "linear_fc2"] + # - 'linear_qkv': Apply LoRA to the fused linear layer used for query, key, and value projections in self-attention + # - 'linear_proj': Apply LoRA to the linear layer used for projecting the output of self-attention + # - 'linear_fc1': Apply LoRA to the first fully-connected layer in MLP + # - 'linear_fc2': Apply LoRA to the second fully-connected layer in MLP + # Target modules can also contain wildcards. For example, you can specify + # target_modules=['*.layers.0.*.linear_qkv', '*.layers.1.*.linear_qkv'] to add LoRA to only linear_qkv on the first two layers + # + # Note: + # For MLA (e.g., DeepSeek), you should use ["linear_kv_down_proj","linear_kv_up_proj","linear_q_down_proj","linear_q_up_proj","linear_q_proj"] + # Instead of "linear_qkv" or ["linear_q","linear_k","linear_v"] + # By default, MoE routers are excluded from LoRA adaptation, and you will need to specify "router" in target_modules to include them. + target_modules: + - linear_qkv + - linear_proj + - linear_fc1 + - linear_fc2 + + # A list of module names not to apply LoRa to. It will match all nn.Linear & nn.Linear-adjacent modules whose name + # does not match any string in exclude_modules. If used, will require target_modules to be empty list or None + exclude_modules: [] + + # Position for applying dropout, can be 'pre' (before the low-rank projection) or 'post' (after). Defaults to 'pre' + dropout_position: pre + + # Initialization method for the low-rank matrix A. Defaults to "xavier". + lora_A_init_method: xavier + + # Initialization method for the low-rank matrix B. Defaults to "zero". + lora_B_init_method: zero + + # Enables the experimental All-to-All (A2A) communication strategy. Defaults to False + a2a_experimental: False + + # Parameter data type for LoRA weights. Default to null, which will use model's dtype. + dtype: null + + # Path to pre-trained LoRA adapter weights (null to train from scratch) + adapter_path: null + + # VLMLoRA additionally allows the user to specify whether the language or vision models should be frozen. + # For example, a common finetuning workload for multimodal models is to apply adapters to language model and fully + # finetune the vision model. + freeze_vision_model: True + freeze_vision_projection: True + freeze_language_model: True + + rollout: + quantization: null + + layer_name_map: + qkv_layer_name: qkv + gate_proj_layer_name: gate_up + +custom_reward_function: + path: null + name: compute_score + +algorithm: + # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs + _target_: verl.trainer.config.AlgoConfig + gamma: 1.0 + lam: 1.0 + adv_estimator: gae + norm_adv_by_std_in_grpo: True + use_kl_in_reward: False + kl_penalty: kl # how to estimate kl divergence + kl_ctrl: + # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs + _target_: verl.trainer.config.KLControlConfig + type: fixed + kl_coef: 0.001 + horizon: 10000 + target_kl: 0.1 + use_pf_ppo: False + pf_ppo: + reweight_method: pow # ["pow", "max_min", "max_random"] + weight_pow: 2.0 + +trainer: + balance_batch: True + total_epochs: 30 + total_training_steps: null + project_name: verl_examples + experiment_name: gsm8k + logger: ["console", "wandb"] + log_val_generations: 0 + nnodes: 1 + n_gpus_per_node: 8 + save_freq: -1 + esi_redundant_time: 0 + + # auto: find the last ckpt to resume. If can't find, start from scratch + resume_mode: auto # or disable or resume_path if resume_from_path is set + resume_from_path: null + del_local_ckpt_after_load: False + val_before_train: True + test_freq: -1 + critic_warmup: 0 + default_hdfs_dir: null + default_local_dir: checkpoints/${trainer.project_name}/${trainer.experiment_name} + max_actor_ckpt_to_keep: null + max_critic_ckpt_to_keep: null + # The timeout for ray worker group to wait for the register center to be ready + ray_wait_register_center_timeout: 300 + device: cuda + # Directory for logging rollout data; no dump if null + rollout_data_dir: null + + # whether to use legacy worker implementation + # mode: "auto", "enable", or "disable" + use_legacy_worker_impl: auto + +global_profiler: + _target_: verl.utils.profiler.ProfilerConfig + tool: null # choose between nsys, npu, torch, torch_memory + steps: null # profile steps + profile_continuous_steps: False + save_path: "outputs/profile" # profiler saving path + # Specific tool configs, can use +profiler.tool_config.[tool].xxx to config + global_tool_config: + # nsys config + nsys: + # True for each task has its own database, False for all tasks in one training step share one database. + discrete: False + + # controller Nvidia Nsight Systems Options. Must set when profile_steps is not None. + ## reference https://docs.nvidia.com/nsight-systems/UserGuide/index.html + ## reference https://docs.ray.io/en/latest/ray-observability/user-guides/profiling.html + controller_nsight_options: + # Select the API(s) to be traced. + trace: "cuda,nvtx,cublas,ucx" + + # Track the GPU memory usage by CUDA kernels. Must be string type "true" or "false". + cuda-memory-usage: "true" + + # CUDA graphs will be traced as a whole + cuda-graph-trace: "graph" + + # worker Nvidia Nsight Systems Options. Must set when profile_steps is not None. + worker_nsight_options: + # Select the API(s) to be traced. + trace: "cuda,nvtx,cublas,ucx" + + # Track the GPU memory usage by CUDA kernels. Must be string type "true" or "false". + cuda-memory-usage: "true" + + # CUDA graphs will be traced as a whole + cuda-graph-trace: "graph" + + # Profiling only in a range of torch.cuda.profiler.start and stop. Do not change this config. + capture-range: "cudaProfilerApi" + + # Specify the desired behavior when a capture range ends. + # In verl we need the torch.cuda.profiler.start/stop pair to repeats n times. + # valid values are "repeat-shutdown:n" or null. + # For normal whole step profiling, n = len(profile_steps); + # but for discrete profiling, n = len(profile_steps) * Number(subtasks). + # Or you can just leave it null and the program will use n = len(profile_steps) * 6; + capture-range-end: null + + # Send signal to the target application's process group. We let the program to exit by itself. + kill: none + + # enable memory visualization for debugging memory usage + torch_memory: + # Maximum number of allocation entries to record + trace_alloc_max_entries: 100_000 + # The depth of the call stack to capture for each allocation + stack_depth: 32 + # 'alloc': records only allocation events || 'state': records memory state changes || 'all': records both. + context: "all" + # 'python': records Python stacks || 'cpp': records C++ stacks (available in some versions) || 'all': records both. + stacks: "all" + # devices, record_context etc. + kw_args: {} + +# configs for TransferQueue +transfer_queue: + # Whether to enable transfer queue + enable: False + +ray_kwargs: + ray_init: + num_cpus: null # `None` means using all CPUs, which might cause hang if limited in systems like SLURM. Please set to a number allowed then. + timeline_json_file: null diff --git a/code/RL_model/verl/verl_train/verl/trainer/config/ppo_trainer.yaml b/code/RL_model/verl/verl_train/verl/trainer/config/ppo_trainer.yaml new file mode 100644 index 0000000000000000000000000000000000000000..7489b522fa22de75528cbc47ec768d1bb13fb92c --- /dev/null +++ b/code/RL_model/verl/verl_train/verl/trainer/config/ppo_trainer.yaml @@ -0,0 +1,320 @@ +# Format checks enforced on CI: +# 1. Comments must appear above each field. +# 2. There must be a blank line between each field. +# 3. Inline comments (after a field on the same line) are not allowed. +# 4. Indentation level is respected for nested fields. + +# specify the default per-component configs +defaults: + + # @.: + # actor_rollout_ref.actor: trainer/config/actor/dp_actor.yaml + - actor@actor_rollout_ref.actor: dp_actor + + # data: trainer/config/data/legacy_data.yaml + - data@data: legacy_data + + # (Rule-based) Reward manager config. + - reward_manager@reward_manager + + # Reference model config. + # Reference model will be enabled when actor.use_kl_loss or/and algorithm.use_kl_in_reward is/are True. + - ref@actor_rollout_ref.ref: dp_ref + + # Rollout model config. + - rollout@actor_rollout_ref.rollout: rollout + + # Model config. + - model@actor_rollout_ref.model: hf_model + + # Critic model config. + - critic@critic: dp_critic + + # Reward model config. + - reward_model@reward_model: dp_reward_loop + + # Rollout correction config. + - algorithm@algorithm.rollout_correction: rollout_correction + + # load the reference default config, then apply the fields in the current yaml + # self config override anything above + - _self_ + +# config for actor, rollout and reference model +actor_rollout_ref: + + # Whether it's a hybrid engine, currently only supports hybrid engine + hybrid_engine: true + + # Timeout for operations executed against the process group + nccl_timeout: 600 + + # Rollout model config. + rollout: + + # for huge model, layered summon can save memory (prevent OOM) but make it slower + layered_summon: False + +# custom reward function definition +custom_reward_function: + + # The path to the file containing your customized reward function. + # If not specified, pre-implemented reward functions will be used. + path: null + + # The name of the reward function within the specified file. Default is 'compute_score'. + name: compute_score + +# config for the algorithm +algorithm: + + # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs + _target_: verl.trainer.config.AlgoConfig + + # Discount factor for future rewards + gamma: 1.0 + + # Trade-off between bias and variance in the GAE estimator + lam: 1.0 + + # Advantage estimator type: "gae", "grpo", "reinforce_plus_plus", etc. + adv_estimator: gae + + # Whether to normalize advantages by std (specific to GRPO) + norm_adv_by_std_in_grpo: True + + # Whether to enable in-reward KL penalty + use_kl_in_reward: False + + # How to estimate KL divergence: "kl", "abs", "mse", "low_var_kl", or "full" + kl_penalty: kl + + # KL control configuration + kl_ctrl: + + # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs + _target_: verl.trainer.config.KLControlConfig + + # KL control type: "fixed" or "adaptive" + type: fixed + + # Initial coefficient for KL penalty + kl_coef: 0.001 + + # Horizon value for adaptive controller (if enabled) + horizon: 10000 + + # Target KL divergence (used for adaptive controller) + target_kl: 0.1 + + # Whether to enable preference feedback PPO + use_pf_ppo: False + + # Preference feedback PPO settings + pf_ppo: + + # Method for reweighting samples: "pow", "max_min", or "max_random" + reweight_method: pow + + # Power used for weight scaling in "pow" method + weight_pow: 2.0 + +# config for the trainer +trainer: + + # Whether to balance batch sizes across distributed workers + balance_batch: True + + # Number of epochs in training + total_epochs: 30 + + # Total training steps (can be set explicitly or derived from epochs) + total_training_steps: null + + # Project name for experiment tracking (e.g., wandb) + project_name: verl_examples + + # Experiment name for run identification in tracking tools + experiment_name: gsm8k + + # Logging backends to use: "console", "wandb", etc. + logger: ["console", "wandb"] + + # Number of generations to log during validation + log_val_generations: 0 + + # Directory for logging rollout data; no dump if null + rollout_data_dir: null + + # Directory for logging validation data; no dump if null + validation_data_dir: null + + # Number of nodes used in the training + nnodes: 1 + + # Number of GPUs per node + n_gpus_per_node: 8 + + # Save frequency (by iteration) for model checkpoints + save_freq: -1 + + # ESI refers to the elastic server instance used during training, similar to the training plan. For example, + # if you purchase 10 hours of computing power, the ESI will automatically shut down after 10 hours of training. + # To ensure a checkpoint is saved before ESI shuts down, the system will start saving a checkpoint in advance. + # The advance time is calculated as: Advance Time = Longest historical step duration + Checkpoint save duration + esi_redundant_time. + # Here, esi_redundant_time is a user-defined value that further extends the advance time for added safety. + esi_redundant_time: 0 + + # Resume mode: "auto", "disable", or "resume_path" + # "auto": resume from last checkpoint if available + # "disable": start from scratch + # "resume_path": resume from a user-defined path + resume_mode: auto + + # Path to resume training from (only used when resume_mode is "resume_path") + resume_from_path: null + + # Whether to run validation before training begins + val_before_train: True + + # Whether to run validation only + val_only: False + + # Validation frequency (in training iterations) + test_freq: -1 + + # Number of iterations to warm up the critic before updating policy + critic_warmup: 0 + + # Default path to distributed filesystem for saving checkpoints + default_hdfs_dir: null + + # Whether to delete local checkpoints after loading + del_local_ckpt_after_load: False + + # Default local directory for saving checkpoints + default_local_dir: checkpoints/${trainer.project_name}/${trainer.experiment_name} + + # Maximum number of actor checkpoints to keep + max_actor_ckpt_to_keep: null + + # Maximum number of critic checkpoints to keep + max_critic_ckpt_to_keep: null + + # Timeout (in seconds) for Ray worker to wait for registration + ray_wait_register_center_timeout: 300 + + # Device to run training on (e.g., "cuda", "cpu") + device: cuda + + # whether to use legacy worker implementation + # mode: "auto", "enable", or "disable" + use_legacy_worker_impl: auto + +# profiler configs +global_profiler: + + # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs + _target_: verl.utils.profiler.ProfilerConfig + + # Profiling tool: choose between nsys, npu, torch, torch_memory + tool: null + + # profile steps + steps: null + + # Whether to combine continuous steps into one database. + ## If True, worker.profiler.discrete must be False, [1,2] in one, [5] in another. + ## If False, [1] in one, [2] in another, [5] in another. + profile_continuous_steps: False + + # Path to save profiling contents + save_path: "outputs/profile" + + # Specific tool configs, can use +profiler.tool_config.[tool].xxx to config + global_tool_config: + + # nsys config + nsys: + + # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs + _target_: verl.utils.profiler.config.NsightToolConfig + + # True for each task has its own database, False for all tasks in one training step share one database. + discrete: False + + # controller Nvidia Nsight Systems Options. Must set when profile_steps is not None. + ## reference https://docs.nvidia.com/nsight-systems/UserGuide/index.html + ## reference https://docs.ray.io/en/latest/ray-observability/user-guides/profiling.html + controller_nsight_options: + + # Select the API(s) to be traced. + trace: "cuda,nvtx,cublas,ucx" + + # Track the GPU memory usage by CUDA kernels. Must be string type "true" or "false". + cuda-memory-usage: "true" + + # CUDA graphs will be traced as a whole + cuda-graph-trace: "graph" + + # worker Nvidia Nsight Systems Options. Must set when profile_steps is not None. + worker_nsight_options: + + # Select the API(s) to be traced. + trace: "cuda,nvtx,cublas,ucx" + + # Track the GPU memory usage by CUDA kernels. Must be string type "true" or "false". + cuda-memory-usage: "true" + + # CUDA graphs will be traced as a whole + cuda-graph-trace: "graph" + + # Profiling only in a range of torch.cuda.profiler.start and stop. Do not change this config. + capture-range: "cudaProfilerApi" + + # Specify the desired behavior when a capture range ends. + # In verl we need the torch.cuda.profiler.start/stop pair to repeats n times. + # valid values are "repeat-shutdown:n" or null. + # For normal whole step profiling, n = len(profile_steps); + # but for discrete profiling, n = len(profile_steps) * Number(subtasks). + # Or you can just leave it null and the program will use n = len(profile_steps) * 6; + capture-range-end: null + + # Send signal to the target application's process group. We let the program to exit by itself. + kill: none + + # enable memory visualization for debugging memory usage + torch_memory: + + # Maximum number of allocation entries to record + trace_alloc_max_entries: 100_000 + + # The depth of the call stack to capture for each allocation + stack_depth: 32 + + # 'alloc': records only allocation events || 'state': records memory state changes || 'all': records both. + context: "all" + + # 'python': records Python stacks || 'cpp': records C++ stacks (available in some versions) || 'all': records both. + stacks: "all" + + # devices, record_context etc. + kw_args: {} + +# configs for TransferQueue +transfer_queue: + + # Whether to enable transfer queue + enable: False + +# configs related to ray +ray_kwargs: + + # configs related to ray initialization + ray_init: + + # Number of CPUs for Ray. Use a fixed number instead of null when using SLURM. + num_cpus: null + + # Path to save Ray timeline JSON for performance profiling + timeline_json_file: null diff --git a/code/RL_model/verl/verl_train/verl/trainer/config/profiler/profiler.yaml b/code/RL_model/verl/verl_train/verl/trainer/config/profiler/profiler.yaml new file mode 100644 index 0000000000000000000000000000000000000000..2004ba3f5f00d0f79c449991b55860f670d6d8ae --- /dev/null +++ b/code/RL_model/verl/verl_train/verl/trainer/config/profiler/profiler.yaml @@ -0,0 +1,73 @@ +# Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs +_target_: verl.utils.profiler.ProfilerConfig + +# profiler tool, default same as profiler.tool in global config +# choices: nsys, npu, torch +tool: torch + +# whether enable profile on Actor +enable: False + +# Whether to profile all ranks. +all_ranks: False + +# The ranks that will be profiled. [] or [0,1,...] +ranks: [] + +# profile results saving path +save_path: "outputs/profile" + +tool_config: + npu: + # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs + _target_: verl.utils.profiler.config.NPUToolConfig + + # Contents to profile, can be empty + # options: npu, cpu, memory, shapes, module, stack + contents: [ ] + + # Collection level, optional values: level_none, level0, level1, level2. + level: "level0" + + # Whether to automatically parse the data. + analysis: True + + # True for each task has its own database, False for all tasks in one training step share one database. + discrete: False + + name: npu + + + nsys: + # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs + _target_: verl.utils.profiler.config.NsightToolConfig + + # True for each task has its own database, False for all tasks in one training step share one database. + discrete: ${oc.select:global_profiler.global_tool_config.nsys.discrete} + + name: nsight + + torch: + # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs + _target_: verl.utils.profiler.config.TorchProfilerToolConfig + + # Contents to profile, can be empty + # options: cuda, cpu, memory, shapes, stack + contents: [] + + # True for each task has its own database, False for all tasks in one training step share one database. + discrete: false + + name: torch + + torch_memory: + # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs + _target_: verl.utils.profiler.config.TorchMemoryToolConfig + + # Maximum number of memory allocation entries to track + trace_alloc_max_entries: ${oc.select:global_profiler.global_tool_config.torch_memory.trace_alloc_max_entries,100000} + + # Stack trace depth for memory allocations + stack_depth: ${oc.select:global_profiler.global_tool_config.torch_memory.stack_depth,32} + + name: torch_memory \ No newline at end of file diff --git a/code/RL_model/verl/verl_train/verl/trainer/config/ref/dp_ref.yaml b/code/RL_model/verl/verl_train/verl/trainer/config/ref/dp_ref.yaml new file mode 100644 index 0000000000000000000000000000000000000000..64b7d2abbc0fe920f7ad3bf3424f9198865e9811 --- /dev/null +++ b/code/RL_model/verl/verl_train/verl/trainer/config/ref/dp_ref.yaml @@ -0,0 +1,30 @@ +# defaults specify the default config from each component +defaults: + + # dp ref config, inheriting from trainer/config/ref/ref.yaml + - ref + + # fsdp engine config + - ../engine@fsdp_config: fsdp + + # load the reference default config, then apply the fields in the current yaml + - _self_ + +# Target class for this configuration +_target_: verl.workers.config.FSDPActorConfig + +# fsdp config +fsdp_config: + + # ref model is forward only + forward_only: True + +# sequence parallel size +# same as actor_rollout_ref.actor.ulysses_sequence_parallel_size if it exists, otherwise 1 +ulysses_sequence_parallel_size: ${oc.select:actor_rollout_ref.actor.ulysses_sequence_parallel_size,1} + +# calculate entropy with chunking to reduce memory peak +entropy_from_logits_with_chunking: False + +# recompute entropy +entropy_checkpointing: False diff --git a/code/RL_model/verl/verl_train/verl/trainer/config/ref/megatron_ref.yaml b/code/RL_model/verl/verl_train/verl/trainer/config/ref/megatron_ref.yaml new file mode 100644 index 0000000000000000000000000000000000000000..ca1fbb3c0739ef9286fac15c7829a8f8869766ea --- /dev/null +++ b/code/RL_model/verl/verl_train/verl/trainer/config/ref/megatron_ref.yaml @@ -0,0 +1,30 @@ +# megatron ref config, inheriting from trainer/config/ref/ref.yaml +defaults: + - ref + + # megatron engine config + - ../engine@megatron: megatron + + # load the reference default config, then apply the fields in the current yaml + - _self_ + +_target_: verl.workers.config.McoreActorConfig + +strategy: megatron + +megatron: + seed: ${oc.select:actor_rollout_ref.actor.megatron.seed,42} + override_transformer_config: ${oc.select:actor_rollout_ref.actor.megatron.override_transformer_config,{}} + use_mbridge: ${oc.select:actor_rollout_ref.actor.megatron.use_mbridge,False} + vanilla_mbridge: ${oc.select:actor_rollout_ref.actor.megatron.vanilla_mbridge,True} + use_remove_padding: ${oc.select:actor_rollout_ref.actor.megatron.use_remove_padding,True} + tensor_model_parallel_size: ${oc.select:actor_rollout_ref.actor.megatron.tensor_model_parallel_size,1} + pipeline_model_parallel_size: ${oc.select:actor_rollout_ref.actor.megatron.pipeline_model_parallel_size,1} + virtual_pipeline_model_parallel_size: ${oc.select:actor_rollout_ref.actor.megatron.virtual_pipeline_model_parallel_size,null} + context_parallel_size: ${oc.select:actor_rollout_ref.actor.megatron.context_parallel_size,1} + expert_model_parallel_size: ${oc.select:actor_rollout_ref.actor.megatron.expert_model_parallel_size,1} + expert_tensor_parallel_size: ${oc.select:actor_rollout_ref.actor.megatron.expert_tensor_parallel_size,null} + param_offload: ${oc.select:actor_rollout_ref.actor.megatron.param_offload,False} + forward_only: True + +load_weight: True diff --git a/code/RL_model/verl/verl_train/verl/trainer/config/ref/ref.yaml b/code/RL_model/verl/verl_train/verl/trainer/config/ref/ref.yaml new file mode 100644 index 0000000000000000000000000000000000000000..9034aa3e652ac4aa6ed9df7e42b85aed8dcd2d65 --- /dev/null +++ b/code/RL_model/verl/verl_train/verl/trainer/config/ref/ref.yaml @@ -0,0 +1,120 @@ +# Number of rollouts per update (mirrors actor rollout_n) +rollout_n: ${oc.select:actor_rollout_ref.rollout.n,1} + +# actor_rollout_ref.ref: FSDP config same as actor. For models larger than 7B, it’s recommended to turn on offload for ref by default +strategy: ${actor_rollout_ref.actor.strategy} + +# whether to enable torch.compile +# same as actor_rollout_ref.actor.use_torch_compile if it exists, otherwise 1 +use_torch_compile: ${oc.select:actor_rollout_ref.actor.use_torch_compile,true} + +# [Will be deprecated, use log_prob_micro_batch_size_per_gpu] +# The batch size for one forward pass in the computation of log_prob. Global batch size. +log_prob_micro_batch_size: null + +# The batch size for one forward pass in the computation of log_prob. Local batch size per GPU. +log_prob_micro_batch_size_per_gpu: null + +# enable dynamic batch size (sequence packing) for log_prob computation +# same as actor_rollout_ref.actor.use_dynamic_bsz if it exists, otherwise false +log_prob_use_dynamic_bsz: ${oc.select:actor_rollout_ref.actor.use_dynamic_bsz,false} + +# the max token length per GPU +# same as actor_rollout_ref.actor.ppo_max_token_len_per_gpu if it exists, otherwise 16384 +log_prob_max_token_len_per_gpu: ${oc.select:actor_rollout_ref.actor.ppo_max_token_len_per_gpu,16384} + +# profile the ref model in `compute_log_prob` +profiler: + + # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs + _target_: verl.utils.profiler.ProfilerConfig + + # choices: nsys, npu, torch, torch_memory + tool: ${oc.select:global_profiler.tool,null} + + # whether enable profile on Ref + enable: False + + # Whether to profile all ranks. + all_ranks: False + + # The ranks that will be profiled. [] or [0,1,...] + ranks: [] + + # profile results saving path + save_path: ${oc.select:global_profiler.save_path,null} + + # specific tool config which only related to the role + tool_config: + + # nsys tool config + nsys: + + # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs + _target_: verl.utils.profiler.config.NsightToolConfig + + # True for each task has its own database, False for all tasks in one training step share one database. + discrete: ${oc.select:global_profiler.global_tool_config.nsys.discrete} + + # npu config + npu: + + # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs + _target_: verl.utils.profiler.config.NPUToolConfig + + # Contents to profile, can be empty + # options: npu, cpu, memory, shapes, module, stack + contents: [] + + # Collection level, optional values: level_none, level0, level1, level2. + level: "level0" + + # Whether to automatically parse the data. + analysis: True + + # True for each task has its own database, False for all tasks in one training step share one database. + discrete: False + + # torch profiler config + torch: + + # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs + _target_: verl.utils.profiler.config.TorchProfilerToolConfig + + # Contents to profile, can be empty + # options: cuda, cpu, memory, shapes, stack + contents: [] + + # True for each task has its own database, False for all tasks in one training step share one database. + discrete: false + + # torch memory profiler config + torch_memory: + + # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs + _target_: verl.utils.profiler.config.TorchMemoryToolConfig + + # Maximum number of memory allocation entries to track + trace_alloc_max_entries: ${oc.select:global_profiler.global_tool_config.torch_memory.trace_alloc_max_entries,100000} + + # Stack trace depth for memory allocations + stack_depth: ${oc.select:global_profiler.global_tool_config.torch_memory.stack_depth,32} + +# Router replay configuration for MoE models +router_replay: + + # Target dataclass for this configuration + _target_: verl.workers.config.RouterReplayConfig + + # Router replay mode: disabled, R2, R3 + # - R2: Use R2 routing strategy (record mode) + # - R3: Use R3 routing strategy (record mode) + mode: disabled + + # File path to save recorded routing decisions + # Required when mode is 'record', 'R2', or 'R3' + record_file: null + + # File path to load recorded routing decisions for replay + # Required when mode is 'replay' + replay_file: null \ No newline at end of file diff --git a/code/RL_model/verl/verl_train/verl/trainer/config/reward_manager.yaml b/code/RL_model/verl/verl_train/verl/trainer/config/reward_manager.yaml new file mode 100644 index 0000000000000000000000000000000000000000..3e55a1dafc52b3b1da97f219875cd8a7fbdf2662 --- /dev/null +++ b/code/RL_model/verl/verl_train/verl/trainer/config/reward_manager.yaml @@ -0,0 +1,8 @@ +# See `verl/trainer/config/config.py:RewardManagerConfig` for more details. +_target_: verl.trainer.config.config.RewardManagerConfig +source: register +name: ${oc.select:reward_model.reward_manager,naive} +module: + _target_: verl.trainer.config.config.ModuleConfig + path: null + name: custom_reward_manager diff --git a/code/RL_model/verl/verl_train/verl/trainer/config/reward_model/dp_reward_loop.yaml b/code/RL_model/verl/verl_train/verl/trainer/config/reward_model/dp_reward_loop.yaml new file mode 100644 index 0000000000000000000000000000000000000000..04fb106df1cc54fa6de1739f3be816138a5e0937 --- /dev/null +++ b/code/RL_model/verl/verl_train/verl/trainer/config/reward_model/dp_reward_loop.yaml @@ -0,0 +1,43 @@ +defaults: + - dp_reward_model + - _self_ + +use_reward_loop: True +reward_manager: naive +enable: False + +# Whether to deploy the model to a separate resource pool. +enable_resource_pool: False +n_gpus_per_node: 8 +num_workers: 1 +nnodes: 0 + +model: + path: ~/models/FsfairX-LLaMA3-RM-v0.1 + external_lib: ${actor_rollout_ref.model.external_lib} + trust_remote_code: False + +rollout: + _target_: verl.workers.config.RolloutConfig + name: ??? + dtype: bfloat16 + gpu_memory_utilization: 0.5 + enforce_eager: true + cudagraph_capture_sizes: null + free_cache_engine: true + data_parallel_size: 1 + expert_parallel_size: 1 + tensor_model_parallel_size: 2 + max_num_batched_tokens: 8192 + max_model_len: null + max_num_seqs: 1024 + load_format: auto + engine_kwargs: {} + limit_images: null + enable_chunked_prefill: true + enable_prefix_caching: true + disable_log_stats: true + skip_tokenizer_init: false + + prompt_length: 2048 + response_length: 2048 \ No newline at end of file diff --git a/code/RL_model/verl/verl_train/verl/trainer/config/reward_model/dp_reward_model.yaml b/code/RL_model/verl/verl_train/verl/trainer/config/reward_model/dp_reward_model.yaml new file mode 100644 index 0000000000000000000000000000000000000000..fff1f9f1f1d32100e77357781ee29a5728ef298c --- /dev/null +++ b/code/RL_model/verl/verl_train/verl/trainer/config/reward_model/dp_reward_model.yaml @@ -0,0 +1,55 @@ +# Format checks enforced on CI: +# 1. Comments must appear above each field. +# 2. There must be a blank line between each field. +# 3. Inline comments (after a field on the same line) are not allowed. +# 4. Indentation level is respected for nested fields. + +# defaults specify the default config from each component +defaults: + + # dp actor config, inheriting from trainer/config/reward_model/reward_model.yaml + - reward_model + + # load the reference default config, then apply the fields in the current yaml + - _self_ + +strategy: fsdp + +model: + + # Whether to use shared memory for loading the model + use_shm: False + + # Use remove padding optimization (saves compute) + use_remove_padding: False + + # Whether to use fused reward kernels for speedup + use_fused_kernels: ${actor_rollout_ref.model.use_fused_kernels} + + # FSDP-specific config + fsdp_config: + + # Target configuration dataclass + _target_: verl.workers.config.FSDPEngineConfig + + # Policy for wrapping layers with FSDP + wrap_policy: + + # Minimum number of parameters to trigger wrapping + min_num_params: 0 + + # Whether to offload model parameters to CPU + param_offload: False + + # Only for FSDP2: Reshard after forward pass to reduce memory footprint + reshard_after_forward: True + + # Number of GPUs in each FSDP shard group; -1 means auto + fsdp_size: -1 + + # Only for FSDP1: FSDP1 configuration, prefetch the next forward-pass all-gather + # before the current forward computation. + forward_prefetch: False + +# Sequence parallelism size for Ulysses-style model parallelism +ulysses_sequence_parallel_size: 1 \ No newline at end of file diff --git a/code/RL_model/verl/verl_train/verl/trainer/config/reward_model/megatron_reward_loop.yaml b/code/RL_model/verl/verl_train/verl/trainer/config/reward_model/megatron_reward_loop.yaml new file mode 100644 index 0000000000000000000000000000000000000000..f99b94abcc4917b08363cc6c01039a319592483c --- /dev/null +++ b/code/RL_model/verl/verl_train/verl/trainer/config/reward_model/megatron_reward_loop.yaml @@ -0,0 +1,43 @@ +defaults: + - megatron_reward_model + - _self_ + +use_reward_loop: True +reward_manager: naive +enable: False + +# Whether to deploy the model to a separate resource pool. +enable_resource_pool: False +n_gpus_per_node: 8 +num_workers: 1 +nnodes: 0 + +model: + path: ~/models/FsfairX-LLaMA3-RM-v0.1 + external_lib: ${actor_rollout_ref.model.external_lib} + trust_remote_code: False + +rollout: + _target_: verl.workers.config.RolloutConfig + name: ??? + dtype: bfloat16 + gpu_memory_utilization: 0.5 + enforce_eager: true + cudagraph_capture_sizes: null + free_cache_engine: true + data_parallel_size: 1 + expert_parallel_size: 1 + tensor_model_parallel_size: 2 + max_num_batched_tokens: 8192 + max_model_len: null + max_num_seqs: 1024 + load_format: auto + engine_kwargs: {} + limit_images: null + enable_chunked_prefill: true + enable_prefix_caching: true + disable_log_stats: true + skip_tokenizer_init: false + + prompt_length: 2048 + response_length: 2048 \ No newline at end of file diff --git a/code/RL_model/verl/verl_train/verl/trainer/config/reward_model/megatron_reward_model.yaml b/code/RL_model/verl/verl_train/verl/trainer/config/reward_model/megatron_reward_model.yaml new file mode 100644 index 0000000000000000000000000000000000000000..ea585075e57c9116ef4be4e9026062ab6ad40c61 --- /dev/null +++ b/code/RL_model/verl/verl_train/verl/trainer/config/reward_model/megatron_reward_model.yaml @@ -0,0 +1,76 @@ +# defaults specify the default config from each component +defaults: + + # dp actor config, inheriting from trainer/config/reward_model/reward_model.yaml + - reward_model + + # load the reference default config, then apply the fields in the current yaml + - _self_ + +strategy: megatron + +# seconds, default is 10 minutes for torch, you can set it to a larger value +# if you have long-running operations like 32B or 72B model using megatron +nccl_timeout: 600 + +# Megatron parallelism & checkpointing config +megatron: + + # Target configuration dataclass + _target_: verl.workers.config.MegatronEngineConfig + + # Whether to offload model parameters to CPU + param_offload: False + + # Number of GPUs in tensor model parallel group + tensor_model_parallel_size: 1 + + # Number of GPUs in expert model parallel group + expert_model_parallel_size: 1 + + # Expert tensor parallel size (null to be same as TP) + expert_tensor_parallel_size: null + + # Number of pipeline model parallel stages + pipeline_model_parallel_size: 1 + + # change VPP interface for parallelism tests + virtual_pipeline_model_parallel_size: null + + # Context parallel size + context_parallel_size: 1 + + # Whether to use sequence parallelism + sequence_parallel: True + + # Whether to use distributed optimizer + use_distributed_optimizer: False + + # Whether to enable distributed checkpointing + use_dist_checkpointing: False + + # Path for distributed checkpoints + dist_checkpointing_path: null + + # distributed checkpointing prefix, e.g. Nemo2 will append prefix 'module.' to the state dict keys + dist_checkpointing_prefix: '' + + # RNG seed for megatron + seed: ${oc.select:actor_rollout_ref.actor.megatron.seed,42} + + # Any overrides to transformer config + override_transformer_config: ${oc.select:actor_rollout_ref.actor.megatron.override_transformer_config,{}} + + # Whether to use mbridge for faster comms + use_mbridge: ${oc.select:actor_rollout_ref.actor.megatron.use_mbridge,False} + + # Whether to use mbridge instead of Megatron-Bridge + vanilla_mbridge: ${oc.select:actor_rollout_ref.actor.megatron.vanilla_mbridge,True} + + # Whether to use thd format (sequence packing), if not, use bshd format, padding the input_ids to the longest sequence length + use_remove_padding: ${oc.select:actor_rollout_ref.actor.megatron.use_remove_padding,True} + + dtype: bfloat16 + +# Whether to load weights (default True) +load_weight: True \ No newline at end of file diff --git a/code/RL_model/verl/verl_train/verl/trainer/config/reward_model/reward_model.yaml b/code/RL_model/verl/verl_train/verl/trainer/config/reward_model/reward_model.yaml new file mode 100644 index 0000000000000000000000000000000000000000..36f3a2e4381e6eb31d035975ecca7ef9d5d02c9d --- /dev/null +++ b/code/RL_model/verl/verl_train/verl/trainer/config/reward_model/reward_model.yaml @@ -0,0 +1,109 @@ +# configs for the reward model + +# Whether to enable reward model. If False, we compute the reward only with the user-defined reward functions. +# In GSM8K and Math examples, we disable reward model. +# For RLHF alignment example using full_hh_rlhf, we utilize reward model to assess the responses. +# If False, the following parameters are not effective +enable: False + +# Whether to deploy the model to a separate resource pool. +# If true, n_gpus_per_node & nnodes will be used to determine the resource node. +enable_resource_pool: False +n_gpus_per_node: 0 +nnodes: 0 + +# FSDP strategy: "fsdp" or "fsdp2" +strategy: ??? + +# model config for reward scoring +model: + + # Input tokenizer. If the reward model's chat template is inconsistent with the policy, + # we need to first decode to plaintext, then apply the rm's chat_template. + # Then score with RM. If chat_templates are consistent, it can be set to null. + # set this to null if the chat template is identical + input_tokenizer: ${actor_rollout_ref.model.path} + + # RM’s HDFS path or local path. Note that RM only supports AutoModelForSequenceClassification. + # Other model types need to define their own RewardModelWorker and pass it from the code. + path: ~/models/FsfairX-LLaMA3-RM-v0.1 + + # External model implementation (optional) + external_lib: ${actor_rollout_ref.model.external_lib} + + # Whether to enable loading a remote code model, default to False + trust_remote_code: False + + # override hf config + override_config: {} + +# [Deprecated] Global micro batch size +# will be deprecated, use micro_batch_size_per_gpu +micro_batch_size: null + +# Local per-GPU micro batch size +micro_batch_size_per_gpu: null + +# Maximum sequence length to process for scoring +max_length: null + +# Whether to dynamically adjust batch size at runtime +use_dynamic_bsz: ${critic.use_dynamic_bsz} + +# Maximum number of tokens per GPU in one forward pass +forward_max_token_len_per_gpu: ${critic.forward_max_token_len_per_gpu} + +# Deprecated. Use `reward_manager.name` instead. See `verl/trainer/config/reward_manager.yaml` for details. +# Kept for backward compatibility. +reward_manager: naive + +# Reward Loop Loading Configuration (for experimental reward system) +# Source for loading reward loop manager: "register" (default) or "importlib" +reward_loop_source: register + +# Module path when using importlib (e.g., "hytuner/reward/reward_loop/xxx_reward_loop.py") +reward_loop_module_path: null + +# Class name when using importlib (e.g., "XXXRewardManager") +reward_loop_class_name: null + +# Whether to launch custom reward function asynchronously during log_prob +# custom reward function executed async on CPU, during log_prob +launch_reward_fn_async: False + +# Cloud/local sandbox fusion configuration for custom reward logic +sandbox_fusion: + + # Cloud /local function URL for sandbox execution + url: null + + # Max concurrent requests allowed to sandbox + max_concurrent: 64 + + # Max memory limit for each sandbox process in MB + memory_limit_mb: 1024 + +# profile the reward model in `compute_reward` +profiler: + + # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs + _target_: verl.utils.profiler.ProfilerConfig + + # profiler tool, default same as profiler.tool in global config + # choices: nsys, npu, torch + tool: ${oc.select:global_profiler.tool,null} + + # whether enable profile on ref + enable: False + + # Whether to profile all ranks. + all_ranks: False + + # The ranks that will be profiled. [] or [0,1,...] + ranks: [] + + # profile results saving path + save_path: ${oc.select:global_profiler.save_path,null} + + # specific tool config + tool_config: ${oc.select:actor_rollout_ref.actor.profiler.tool_config,null} \ No newline at end of file diff --git a/code/RL_model/verl/verl_train/verl/trainer/config/rollout/rollout.yaml b/code/RL_model/verl/verl_train/verl/trainer/config/rollout/rollout.yaml new file mode 100644 index 0000000000000000000000000000000000000000..8d4a337125986471ac7094a3b1c76dad63080220 --- /dev/null +++ b/code/RL_model/verl/verl_train/verl/trainer/config/rollout/rollout.yaml @@ -0,0 +1,356 @@ +# Target class for this configuration +_target_: verl.workers.config.RolloutConfig + +# actor_rollout_ref.rollout.name: hf/vllm/sglang/trtllm. The default value will be removed in the future +name: ??? + +# sync: LLM, async: AsyncLLM +mode: async + +# Sampling temperature for rollout. +temperature: 1.0 + +# Top-k sampling parameter. -1 for vLLM rollout, 0 for HF rollout. +top_k: -1 + +# Top-p sampling parameter. Default 1.0. +top_p: 1 + +# typically the same as data max prompt length +# same as data.max_prompt_length if it exists +prompt_length: ${oc.select:data.max_prompt_length,512} + +# typically the same as data max response length +# same as data.max_response_length if it exists +response_length: ${oc.select:data.max_response_length,512} + +# for vllm rollout +# Rollout model parameters type. Align with actor model's FSDP/Megatron type. +dtype: bfloat16 + +# Fraction of GPU memory used by vLLM/SGLang/TRTLLM for KV cache. +gpu_memory_utilization: 0.5 + +# Whether to ignore EOS and continue generating after EOS is hit. +ignore_eos: False + +# Whether to disable CUDA graph. Default False to best performance. +enforce_eager: False + +# batch size of cudagraph to capture. Require enforce_eager: False to use this option +# Since cudagraph in inference engine can not be offloaded during update policy, +# you can use smaller batch size to save memory used in cuda graph, eg: [1 ,2, 4, 8, 16, 32] +# supported engines: vllm +cudagraph_capture_sizes: null + +# Whether to free engine KVCache after generation. +free_cache_engine: True + +# TP size for rollout. Not effective for hf +tensor_model_parallel_size: 2 + +# DP size for rollout +data_parallel_size: 1 + +# EP size for rollout +expert_parallel_size: 1 + +# PP size for rollout. +pipeline_model_parallel_size: 1 + +# max number of tokens in a batch +max_num_batched_tokens: 8192 + +# max length for rollout +max_model_len: null + +# max length of sequences +max_num_seqs: 1024 + +# may get higher throughput when set to True. When activated, Please increase max_num_batched_tokens or decrease max_model_len. +enable_chunked_prefill: True + +# Prefix caching kv-cache blocks is a popular optimization in LLM inference to avoid redundant prompt computations. +enable_prefix_caching: True + +# logprobs mode for rollout logprobs +logprobs_mode: processed_logprobs + +# scheduling policy for vllm rollout +scheduling_policy: fcfs + +# Which loader to use for rollout model weights: dummy, hf, megatron, etc. +# safetensors (for huge model, and set use_shm=True); dummy: randomly init model weight +load_format: dummy + +# [Will be deprecated, use log_prob_micro_batch_size_per_gpu] The batch size for one forward pass in the computation of log_prob. Global batch size. +log_prob_micro_batch_size: null + +# The batch size for one forward pass in the computation of log_prob. Local batch size per GPU. +log_prob_micro_batch_size_per_gpu: null + +# enable dynamic batch size (sequence packing) for log_prob computation +# same as actor_rollout_ref.actor.use_dynamic_bsz if it exists, otherwise false +log_prob_use_dynamic_bsz: ${oc.select:actor_rollout_ref.actor.use_dynamic_bsz,false} + +# max token length for log_prob computation +# same as actor_rollout_ref.actor.ppo_max_token_len_per_gpu if it exists, otherwise 16384 +log_prob_max_token_len_per_gpu: ${oc.select:actor_rollout_ref.actor.ppo_max_token_len_per_gpu,16384} + +# disable logging statistics +disable_log_stats: True + +# for hf rollout +# Whether to sample during training rollout. False uses greedy sampling. +do_sample: True + +# number of responses (i.e. num sample times). > 1 for grpo +n: 1 + +# The over_sample_rate parameter controls the early termination threshold for training rollouts, +# where the system will abort remaining requests when (1 - over_sample_rate) * total_requests completions are reached. +over_sample_rate: 0 + +# Whether to wake up inference engine in multi-stage for SGLang +# to reduce peak memory during training-rollout transition. +# This is only effective for SGLang rollout. +multi_stage_wake_up: false + +# Extra inference engine arguments (vllm, sglang, trtllm), please refer vllm/sglang/trtllm official doc for detail +engine_kwargs: + + # vllm engine config + vllm: {} + + # sglang engine config + sglang: {} + + # trtllm engine config + trtllm: {} + +# Sampling parameters used during validation. +val_kwargs: + + # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs + _target_: verl.workers.config.SamplingConfig + + # sampling parameters for validation + # Top-k sampling parameter. -1 for vLLM rollout, 0 for HF rollout. + top_k: -1 + + # Top-p sampling parameter. Default 1.0. + top_p: 1.0 + + # Sampling temperature for rollout. + temperature: 0 + + # whether to repeat n times for validation + n: 1 + + # Whether to sample during training rollout. False uses greedy sampling. + do_sample: False + +# Multi-turn interaction config for tools or chat. +multi_turn: + + # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs + _target_: verl.workers.config.MultiTurnConfig + + # set to True for multi-turn tool interaction tasks; should set rollout.name to sglang as well + enable: False + + # null for no limit (default max_length // 3) + max_assistant_turns: null + + # null for no tool + tool_config_path: null + + # null for no limit (default max_length // 3) + max_user_turns: null + + # max parallel call for tools in single turn + max_parallel_calls: 1 + + # max length of tool response + max_tool_response_length: 256 + + # truncate side of tool response: left, middle, right + tool_response_truncate_side: middle + + # null for no interaction + interaction_config_path: null + + # - When set to True, the model's default chat template is used for multi-turn rollout, which typically matches production behavior. + # - When set to False, the token ids recorded for training are used instead; unlike the default chat template, these always include the model's full output, + # which may contain additional content such as reasoning content. This maintains the consistency between training and rollout, but it will lead to longer prompts. + use_inference_chat_template: False + + # Tokenization is performed turn by turn and the resulting token ids are concatenated to form the full conversation. + # To ensure this matches the result of tokenizing the entire conversation at once, a sanity check is run at the end of each multi-turn rollout to compare the two sets of token ids. + # Some models are known to produce different tokenization results when tokenizing turn by turn vs. all at once. aThis behavior has already been validated for them. + # To reduce excessive warnings, you can turn off the sanity check for these models if you are using their default chat template: + # Qwen/QwQ-32B, Qwen/Qwen3-xxB + # - disable: disable tokenization sanity check + # - strict: enable strict tokenization sanity check (default) + # - ignore_strippable: ignore strippable tokens when checking tokenization sanity + tokenization_sanity_check_mode: strict + + # Format of the multi-turn interaction. Options: hermes, llama3_json, ... + format: hermes + + # Number of repeat rollouts for each interaction + num_repeat_rollouts: null + +# support logging rollout prob for debugging purpose +# "Truncated importance sampling" requires rollout log probs, set to True when turning on Truncated importance sampling +calculate_log_probs: False + +# [Experimental] agent loop based rollout configs +agent: + + # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs + _target_: verl.workers.config.AgentLoopConfig + + # Number of agent loop workers + num_workers: 8 + + # default agent loop to use if `agent_name` not set in RL dataset + default_agent_loop: single_turn_agent + + # custom agent loop config path, which should contain list of configs to initialize AgentLoop instances. + # https://hydra.cc/docs/advanced/instantiate_objects/overview/ + # + # - name: react_agent + # _target_: recipe.langgraph_agent.react_agent_loop.ReactAgentLoop + # tools: ["get_current_temperature"] + # - name: math_expression + # _target_: recipe.langgraph_agent.example.math_expression.MathExpressionReactAgentLoop + # min_terms: 2 + # max_terms: 6 + agent_loop_config_path: null + + # custom async server configs + custom_async_server: + + # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs + _target_: verl.workers.config.CustomAsyncServerConfig + + # Path to the custom async server implementation + path: null + + # Class name of the custom async server class (e.g. AsyncvLLMServer) + name: null + +# Checkpoint Engine config for update weights from trainer to rollout +checkpoint_engine: + + # Target class for checkpoint engine config + _target_: verl.workers.config.CheckpointEngineConfig + + # Backend for checkpoint engine: naive, nccl, nixl, hccl + backend: naive + + # Specifies the tensor bucket size (in megabytes) for batch weight updates during rollout operations. + # This parameter controls the maximum payload size for a single weight update request. + # Reference: https://github.com/volcengine/verl/pull/2418 + # Currently only supported in SGLang rollout implementations + # Larger values may improve throughput but increase memory overhead + # Detailed performance comparison: + # https://github.com/zhaochenyang20/Awesome-ML-SYS-Tutorial/issues/169#issuecomment-3070686720 + # Default value (512MB) is optimized for typical GPU memory configurations + # For the best performance of `rebuild_cuda_tensor`, it is recommended to: + # 1. Enable `RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES` + # 2. Manually set `CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7` + # when using Tensor Parallelism (TP) >= 8. + update_weights_bucket_megabytes: 2048 + + # Additional keyword arguments to pass to the checkpoint engine constructor + engine_kwargs: {} + +# trace rollout data +trace: + + # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs + _target_: verl.workers.config.TraceConfig + + # trace backend, support mlflow, weave + backend: null + + # whether translate token id to text in output + token2text: False + + # Maximum number of unique samples to trace per agent worker per training step. + # If null, all samples are traced. If set to N, each agent loop worker will randomly + # select N unique samples to trace (including all their rollouts for GRPO). + # Total traces per step = max_samples_per_step_per_worker * num_workers * n_rollouts_per_sample + max_samples_per_step_per_worker: null + +# When enabled (True), the trainer will attempt to load previously generated rollout data from the specified directory instead of computing new rollouts. +# If no cached data is found or loading fails, new rollouts will be generated and automatically saved. +# This feature is useful for debugging or when you want to reuse computation results across multiple runs. +skip_rollout: False + +# Specifies the filesystem path where rollout data should be cached when skip_rollout is enabled. +# Note: Giving path under /tmp/ray/session* is not recommended as these are temporary Ray cluster directories. +skip_dump_dir: /tmp/rollout_dump + +# Whether to skip tokenizer initialization for rollout engine +# When enabled (True), the rollout assume token in token out for generation +skip_tokenizer_init: True + +# Whether to enable rollout routing replay for MoE models +# When enabled (True), the rollout will record the routing decisions. +enable_rollout_routing_replay: False + + +# profile the rollout model in `generate_sequence` +profiler: + + # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs + _target_: verl.utils.profiler.ProfilerConfig + + # profiler tool, default same as profiler.tool in global config + # choices: nsys, npu, torch + tool: ${oc.select:global_profiler.tool,null} + + # whether enable profile on ref + enable: ${oc.select:actor_rollout_ref.actor.profiler.enable,false} + + # Whether to profile all ranks. + all_ranks: ${oc.select:actor_rollout_ref.actor.profiler.all_ranks,false} + + # The ranks that will be profiled. [] or [0,1,...] + ranks: ${oc.select:actor_rollout_ref.actor.profiler.ranks,[]} + + # profile results saving path + save_path: ${oc.select:global_profiler.save_path,null} + + # specific tool config + tool_config: ${oc.select:actor_rollout_ref.actor.profiler.tool_config,null} + +# prometheus configuration for vllm/sglang server mode +prometheus: + + # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs + _target_: verl.workers.config.PrometheusConfig + + # whether enable prometheus on server mode rollout + enable: false + + # Port number that Prometheus listens on, default is 9090 + port: 9090 + + # Path to Prometheus configuration file + file: /tmp/ray/session_latest/metrics/prometheus/prometheus.yml + + # Specify served_model_name to avoid displaying overly long model paths in Grafana + served_model_name: ${oc.select:actor_rollout_ref.model.path,null} + +# type of quantization in vllm, currently support fp8 and torchao +quantization: null + +# extra quantization information serialized in a config file, e.g. torchao_config.json +quantization_config_file: null + +# MTP configuration, reuse model configuration +mtp: ${oc.select:actor_rollout_ref.model.mtp, null} diff --git a/code/RL_model/verl/verl_train/verl/trainer/config/sft_trainer.yaml b/code/RL_model/verl/verl_train/verl/trainer/config/sft_trainer.yaml new file mode 100644 index 0000000000000000000000000000000000000000..b2308e39e44fdb1c0cca318133e145d42a222b90 --- /dev/null +++ b/code/RL_model/verl/verl_train/verl/trainer/config/sft_trainer.yaml @@ -0,0 +1,91 @@ +defaults: + - optim: fsdp + - _self_ + +data: + train_batch_size: 256 + micro_batch_size: null # will be deprecated, use micro_batch_size_per_gpu + micro_batch_size_per_gpu: 4 # this is also val batch size + train_files: ~/data/gsm8k/train.parquet + val_files: ~/data/gsm8k/test.parquet + train_max_samples: -1 # set to -1 to use full dataset + val_max_samples: -1 # set to -1 to use full dataset + # Single-turn settings + prompt_key: question + response_key: answer + prompt_dict_keys: null + response_dict_keys: null + # Multi-turn settings + multiturn: + enable: false # Set to true to use multi-turn dataset + messages_key: messages # Key for messages list in multi-turn mode + tools_key: tools # Key for tools list in multi-turn mode + enable_thinking_key: enable_thinking # Whether to enable thinking in multi-turn mode + max_length: 1024 + truncation: error + balance_dp_token: False + chat_template: null + custom_cls: + path: null + name: null + use_shm: False + apply_chat_template_kwargs: {} +model: + partial_pretrain: ~/models/gemma-1.1-7b-it + use_shm: False + fsdp_config: + model_dtype: fp32 + wrap_policy: + min_num_params: 0 + cpu_offload: False + offload_params: False + external_lib: null + enable_gradient_checkpointing: True + trust_remote_code: False + lora_rank: 0 # Set to positive value to enable LoRA (e.g., 32) + lora_alpha: 16 # LoRA scaling factor + target_modules: all-linear # Target modules for LoRA adaptation + use_liger: False + strategy: fsdp2 +optim: + lr: 1e-5 + betas: [0.9, 0.95] + weight_decay: 0.01 + lr_warmup_steps_ratio: 0.1 + clip_grad: 1.0 + lr_scheduler: cosine +ulysses_sequence_parallel_size: 1 +use_remove_padding: False +trainer: + default_local_dir: checkpoints/${trainer.project_name}/${trainer.experiment_name} + default_hdfs_dir: null + project_name: gsm8k-sft + experiment_name: test + total_epochs: 4 + total_training_steps: null + logger: [ 'console', 'wandb' ] + seed: 1 + save_freq: -1 + test_freq: -1 + nnodes: 1 + n_gpus_per_node: 8 + max_ckpt_to_keep: null # Maximum number of checkpoints to keep, set to null to keep all + + # Resume mode: "auto", "disable", or "resume_path" + # "auto": resume from last checkpoint if available + # "disable": start from scratch + # "resume_path": resume from a user-defined path + resume_mode: auto + + # Path to resume training from (used when resume_mode is "resume_path" or "auto") + resume_from_path: null + + # Checkpoint configuration + checkpoint: + # What to include in saved checkpoints + # with 'hf_model' you can save whole model as hf format, now only use sharded model checkpoint to save space + save_contents: ["model", "optimizer", "extra"] + + # For more flexibility, you can specify the contents to load from the checkpoint. + load_contents: ${trainer.checkpoint.save_contents} + device: cuda diff --git a/code/RL_model/verl/verl_train/verl/trainer/config/sft_trainer_engine.yaml b/code/RL_model/verl/verl_train/verl/trainer/config/sft_trainer_engine.yaml new file mode 100644 index 0000000000000000000000000000000000000000..134dbd6005d64b4f50247c7af611086e6ac9a748 --- /dev/null +++ b/code/RL_model/verl/verl_train/verl/trainer/config/sft_trainer_engine.yaml @@ -0,0 +1,85 @@ +# Format checks enforced on CI: +# 1. Comments must appear above each field. +# 2. There must be a blank line between each field. +# 3. Inline comments (after a field on the same line) are not allowed. +# 4. Indentation level is respected for nested fields. + +# @.: + +defaults: + - model@model: hf_model + - engine@engine: fsdp + - optim@optim: fsdp + - profiler@profiler: profiler + - _self_ + +data: + train_batch_size: 256 # global batch size + micro_batch_size_per_gpu: 4 # this is also val batch size + max_token_len_per_gpu: 8192 + use_dynamic_bsz: True + train_files: ~/data/gsm8k/train.parquet + val_files: null + train_max_samples: -1 # set to -1 to use full dataset + val_max_samples: -1 # set to -1 to use full dataset + # Multi-turn settings + messages_key: messages # Key for messages list in multi-turn mode + tools_key: tools # Key for tools list in multi-turn mode + enable_thinking_key: enable_thinking # Whether to enable thinking in multi-turn mode + enable_thinking_default: none # The default value when enable_thinking_key is not present in the dataset + pad_mode: no_padding + # for right padding + max_length: 1024 + truncation: error + balance_dp_token: False # to be implement + custom_cls: + path: null + name: null + use_shm: False + apply_chat_template_kwargs: {} + num_workers: 8 + + # MultiTurnSFTDataset apply_chat_template to each turn separately and concat `input_ids` + # as a whole sequence, which may not equal to apply_chat_template to whole messages at once. + # For example, Qwen Thinking series models add tags to last turn, please check + # your tokenizer chat template settings. + # Set to True to ignore input_ids mismatch and use the concatenated input_ids as the final input_ids. + ignore_input_ids_mismatch: False + +# Checkpoint configuration +checkpoint: + _target_: verl.trainer.config.CheckpointConfig + # What to include in saved checkpoints + # with 'hf_model' you can save whole model as hf format, now only use sharded model checkpoint to save space + save_contents: ["model", "optimizer", "extra"] + + # For more flexibility, you can specify the contents to load from the checkpoint. + load_contents: ${checkpoint.save_contents} + +trainer: + default_local_dir: checkpoints/${trainer.project_name}/${trainer.experiment_name} + default_hdfs_dir: null + project_name: gsm8k-sft + experiment_name: test + total_epochs: 4 + total_training_steps: null + logger: [ 'console', 'wandb' ] + seed: 1 + save_freq: -1 + test_freq: -1 + max_ckpt_to_keep: null # Maximum number of checkpoints to keep, set to null to keep all + + # Resume mode: "auto", "disable", or "resume_path" + # "auto": resume from last checkpoint if available + # "disable": start from scratch + # "resume_path": resume from a user-defined path + resume_mode: auto + + # Path to resume training from (used when resume_mode is "resume_path" or "auto") + resume_from_path: null + device: cuda + + nnodes: 1 + n_gpus_per_node: 1 + + profile_interval: [-1, -1] diff --git a/code/RL_model/verl/verl_train/verl/trainer/ppo/__init__.py b/code/RL_model/verl/verl_train/verl/trainer/ppo/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..1ce90c5eb352d85c59105c0dc85b5f1dd576f095 --- /dev/null +++ b/code/RL_model/verl/verl_train/verl/trainer/ppo/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/code/RL_model/verl/verl_train/verl/trainer/ppo/core_algos.py b/code/RL_model/verl/verl_train/verl/trainer/ppo/core_algos.py new file mode 100644 index 0000000000000000000000000000000000000000..2039fe56f62f52190846fbf8b8b31dc0df160929 --- /dev/null +++ b/code/RL_model/verl/verl_train/verl/trainer/ppo/core_algos.py @@ -0,0 +1,2200 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Core functions to implement PPO algorithms. +The function implemented in this file should be used by trainer with different distributed strategies to +implement PPO-like algorithms. +""" + +__all__ = ["register_adv_est", "get_adv_estimator_fn", "AdvantageEstimator"] + +from collections import defaultdict +from enum import Enum +from typing import Any, Callable, Optional + +import numpy as np +import torch +from omegaconf import DictConfig + +import verl.utils.torch_functional as verl_F +from verl.trainer.config import AlgoConfig +from verl.utils import as_torch_index, group_mean_std +from verl.utils.import_utils import deprecated +from verl.workers.config import ActorConfig + +PolicyLossFn = Callable[ + [ + torch.Tensor, # old_log_prob + torch.Tensor, # log_prob + torch.Tensor, # advantages + torch.Tensor, # response_mask + str, # loss_agg_mode + Optional[DictConfig | ActorConfig], # config + torch.Tensor | None, # rollout_log_probs + ], + tuple[torch.Tensor, dict[str, Any]], +] + +POLICY_LOSS_REGISTRY: dict[str, PolicyLossFn] = {} + + +def register_policy_loss(name: str) -> Callable[[PolicyLossFn], PolicyLossFn]: + """Register a policy loss function with the given name. + + Args: + name (str): The name to register the policy loss function under. + + Returns: + function: Decorator function that registers the policy loss function. + """ + + def decorator(func: PolicyLossFn) -> PolicyLossFn: + POLICY_LOSS_REGISTRY[name] = func + return func + + return decorator + + +def get_policy_loss_fn(name): + """Get the policy loss with a given name. + + Args: + name: `(str)` + The name of the policy loss. + + Returns: + `(callable)`: The policy loss function. + """ + loss_name = name + if loss_name not in POLICY_LOSS_REGISTRY: + raise ValueError( + f"Unsupported loss mode: {loss_name}. Supported modes are: {list(POLICY_LOSS_REGISTRY.keys())}" + ) + return POLICY_LOSS_REGISTRY[loss_name] + + +class AdvantageEstimator(str, Enum): + """Using an enumeration class to avoid spelling errors in adv_estimator. + + Note(haibin.lin): this enum class is immutable after creation. Extending this + enum for new estimators may not be necessary since users can always just call + `verl.trainer.ppo.core_algos.register` with string name for a custom advantage + estimator instead. + """ + + GAE = "gae" + GRPO = "grpo" + REINFORCE_PLUS_PLUS = "reinforce_plus_plus" + REINFORCE_PLUS_PLUS_BASELINE = "reinforce_plus_plus_baseline" + REMAX = "remax" + RLOO = "rloo" + OPO = "opo" + GRPO_PASSK = "grpo_passk" + GPG = "gpg" + RLOO_VECTORIZED = "rloo_vectorized" + GRPO_VECTORIZED = "grpo_vectorized" + OPTIMAL_TOKEN_BASELINE = "optimal_token_baseline" + TIR_OPTIMAL_TOKEN_BASELINE = "tir_optimal_token_baseline" + + +ADV_ESTIMATOR_REGISTRY: dict[str, Any] = {} + + +def register_adv_est(name_or_enum: str | AdvantageEstimator) -> Any: + """Decorator to register a advantage estimator function with a given name. + + Args: + name_or_enum: `(str)` or `(AdvantageEstimator)` + The name or enum of the advantage estimator. + + """ + + def decorator(fn): + name = name_or_enum.value if isinstance(name_or_enum, Enum) else name_or_enum + if name in ADV_ESTIMATOR_REGISTRY and ADV_ESTIMATOR_REGISTRY[name] != fn: + raise ValueError( + f"Adv estimator {name} has already been registered: {ADV_ESTIMATOR_REGISTRY[name]} vs {fn}" + ) + ADV_ESTIMATOR_REGISTRY[name] = fn + return fn + + return decorator + + +def get_adv_estimator_fn(name_or_enum): + """Get the advantage estimator function with a given name. + + Args: + name_or_enum: `(str)` or `(AdvantageEstimator)` + The name or enum of the advantage estimator. + + Returns: + `(callable)`: The advantage estimator function. + """ + name = name_or_enum.value if isinstance(name_or_enum, Enum) else name_or_enum + if name not in ADV_ESTIMATOR_REGISTRY: + raise ValueError(f"Unknown advantage estimator simply: {name}") + return ADV_ESTIMATOR_REGISTRY[name] + + +class AdaptiveKLController: + """ + Adaptive KL controller described in the paper: + https://arxiv.org/pdf/1909.08593.pdf + """ + + def __init__(self, init_kl_coef, target_kl, horizon): + self.value = init_kl_coef + self.target = target_kl + self.horizon = horizon + + def update(self, current_kl, n_steps): + """Update the KL coefficient based on current KL divergence. + + Args: + current_kl (float): Current KL divergence value. + n_steps (int): Number of steps taken. + """ + target = self.target + proportional_error = np.clip(current_kl / target - 1, -0.2, 0.2) + mult = 1 + proportional_error * n_steps / self.horizon + self.value *= mult + + +class FixedKLController: + """Fixed KL controller.""" + + def __init__(self, kl_coef): + self.value = kl_coef + + def update(self, current_kl, n_steps): + """Update method for fixed KL controller (no-op). + + Args: + current_kl (float): Current KL divergence value (unused). + n_steps (int): Number of steps taken (unused). + """ + pass + + +def get_kl_controller(kl_ctrl): + """Factory function to create appropriate KL controller based on configuration. + + Args: + kl_ctrl: Configuration object containing KL controller settings. + + Returns: + KL controller instance (FixedKLController or AdaptiveKLController). + + Raises: + NotImplementedError: If controller type is not supported. + AssertionError: If adaptive controller horizon is not positive. + """ + if kl_ctrl.type == "fixed": + return FixedKLController(kl_coef=kl_ctrl.kl_coef) + elif kl_ctrl.type == "adaptive": + assert kl_ctrl.horizon > 0, f"horizon must be larger than 0. Got {kl_ctrl.horizon}" + return AdaptiveKLController(init_kl_coef=kl_ctrl.kl_coef, target_kl=kl_ctrl.target_kl, horizon=kl_ctrl.horizon) + else: + raise NotImplementedError + + +@register_adv_est(AdvantageEstimator.GAE) # or simply: @register_adv_est("gae") +def compute_gae_advantage_return( + token_level_rewards: torch.Tensor, + values: torch.Tensor, + response_mask: torch.Tensor, + gamma: torch.Tensor, + lam: torch.Tensor, +): + """Adapted from https://github.com/huggingface/trl/blob/main/trl/trainer/ppo_trainer.py + + Args: + token_level_rewards: `(torch.Tensor)` + shape is (bs, response_length) + values: `(torch.Tensor)` + shape is (bs, response_length) + response_mask: `(torch.Tensor)` + shape is (bs, response_length). [EOS] mask. The token after [EOS] have mask zero. + gamma is `(float)` + discounted factor used in RL + lam: `(float)` + lambda value when computing Generalized Advantage Estimation (https://arxiv.org/abs/1506.02438) + + Returns: + advantages: `(torch.Tensor)` + shape: (bs, response_length) + Returns: `(torch.Tensor)` + shape: (bs, response_length) + + """ + with torch.no_grad(): + nextvalues = 0 + lastgaelam = 0 + advantages_reversed = [] + gen_len = token_level_rewards.shape[-1] + + for t in reversed(range(gen_len)): + delta = token_level_rewards[:, t] + gamma * nextvalues - values[:, t] + lastgaelam_ = delta + gamma * lam * lastgaelam + + # skip values and TD-error on observation tokens + nextvalues = values[:, t] * response_mask[:, t] + (1 - response_mask[:, t]) * nextvalues + lastgaelam = lastgaelam_ * response_mask[:, t] + (1 - response_mask[:, t]) * lastgaelam + + advantages_reversed.append(lastgaelam) + advantages = torch.stack(advantages_reversed[::-1], dim=1) + + returns = advantages + values + advantages = verl_F.masked_whiten(advantages, response_mask) + return advantages, returns + + +# NOTE(sgm): this implementation only consider outcome supervision, where the reward is a scalar. +@register_adv_est(AdvantageEstimator.GRPO) # or simply: @register_adv_est("grpo") +def compute_grpo_outcome_advantage( + token_level_rewards: torch.Tensor, + response_mask: torch.Tensor, + index: np.ndarray, + epsilon: float = 1e-6, + norm_adv_by_std_in_grpo: bool = True, + config: Optional[AlgoConfig] = None, +) -> tuple[torch.Tensor, torch.Tensor]: + """ + Compute advantage for GRPO, operating only on Outcome reward + (with only one scalar reward for each response). + + Args: + token_level_rewards: `(torch.Tensor)` + shape is (bs, response_length) + response_mask: `(torch.Tensor)` + shape is (bs, response_length) + index: `(np.ndarray)` + index array for grouping + epsilon: `(float)` + small value to avoid division by zero + norm_adv_by_std_in_grpo: `(bool)` + whether to scale the GRPO advantage + config: `(Optional[AlgoConfig])` + algorithm configuration object + + Note: + If norm_adv_by_std_in_grpo is True, the advantage is scaled by the std, as in the original GRPO. + If False, the advantage is not scaled, as in Dr.GRPO (https://arxiv.org/abs/2503.20783). + + Returns: + advantages: `(torch.Tensor)` + shape is (bs, response_length) + Returns: `(torch.Tensor)` + shape is (bs, response_length) + """ + scores = token_level_rewards.sum(dim=-1) + + id2score = defaultdict(list) + id2mean = {} + id2std = {} + + with torch.no_grad(): + bsz = scores.shape[0] + for i in range(bsz): + id2score[index[i]].append(scores[i]) + for idx in id2score: + if len(id2score[idx]) == 1: + id2mean[idx] = torch.tensor(0.0) + id2std[idx] = torch.tensor(1.0) + elif len(id2score[idx]) > 1: + scores_tensor = torch.stack(id2score[idx]) + id2mean[idx] = torch.mean(scores_tensor) + id2std[idx] = torch.std(scores_tensor) + else: + raise ValueError(f"no score in prompt index: {idx}") + for i in range(bsz): + if norm_adv_by_std_in_grpo: + scores[i] = (scores[i] - id2mean[index[i]]) / (id2std[index[i]] + epsilon) + else: + scores[i] = scores[i] - id2mean[index[i]] + scores = scores.unsqueeze(-1) * response_mask + + return scores, scores + + +@register_adv_est(AdvantageEstimator.GRPO_VECTORIZED) +def compute_grpo_vectorized_outcome_advantage( + token_level_rewards: torch.Tensor, + response_mask: torch.Tensor, + index: np.ndarray, + epsilon: float = 1e-6, + norm_adv_by_std_in_grpo: bool = True, + config: Optional[AlgoConfig] = None, +) -> tuple[torch.Tensor, torch.Tensor]: + """ + Vectorized GRPO(outcome-only): + For each group g: + a_i = \\frac{r_i - \\mu_g}{\\sigma_g} (or without dividing by \\sigma_g), + then broadcast the scalar across the token dimension (multiplied by response_mask).。 + """ + with torch.no_grad(): + scores = token_level_rewards.sum(dim=-1) + g = as_torch_index(index, device=scores.device) + mean_g, std_g, _ = group_mean_std(scores, g, eps=epsilon, device=scores.device) + if norm_adv_by_std_in_grpo: + scalars = (scores - mean_g[g]) / (std_g[g] + epsilon) + else: + scalars = scores - mean_g[g] + advantages = scalars.unsqueeze(-1) * response_mask + return advantages, advantages + + +@register_adv_est(AdvantageEstimator.GRPO_PASSK) # or simply: @register_adv_est("grpo_passk") +def compute_grpo_passk_outcome_advantage( + token_level_rewards: torch.Tensor, + response_mask: torch.Tensor, + index: np.ndarray, + epsilon: float = 1e-6, + norm_adv_by_std_in_grpo: bool = True, + config: Optional[AlgoConfig] = None, + **kwargs, +) -> tuple[torch.Tensor, torch.Tensor]: + """ + Compute advantage for Pass@k using a GRPO-style outcome reward formulation. + Only the best response per group gets a non-zero advantage: r_max - r_second_max. + + Implemented as described in https://arxiv.org/abs/2503.19595. + + Args: + token_level_rewards: (bs, response_length) + response_mask: (bs, response_length) + index: (bs,) → group ID per sample + epsilon: float for numerical stability + config: (AlgoConfig) algorithm settings, which contains "norm_adv_by_std_in_grpo" + + Returns: + advantages: (bs, response_length) + returns: (bs, response_length) + """ + assert config is not None + # if True, normalize advantage by std within group + norm_adv_by_std_in_grpo = config.get("norm_adv_by_std_in_grpo", True) + scores = token_level_rewards.sum(dim=-1) # (bs,) + advantages = torch.zeros_like(scores) + + id2scores = defaultdict(list) + id2indices = defaultdict(list) + + with torch.no_grad(): + bsz = scores.shape[0] + for i in range(bsz): + idx = index[i] + id2scores[idx].append(scores[i]) + id2indices[idx].append(i) + + for idx in id2scores: + rewards = torch.stack(id2scores[idx]) # (k,) + if rewards.numel() < 2: + raise ValueError( + f"Pass@k requires at least 2 samples per group. Got {rewards.numel()} for group {idx}." + ) + topk, topk_idx = torch.topk(rewards, 2) + r_max, r_second_max = topk[0], topk[1] + i_max = id2indices[idx][topk_idx[0].item()] + advantage = r_max - r_second_max + if norm_adv_by_std_in_grpo: + std = torch.std(rewards) + advantage = advantage / (std + epsilon) + advantages[i_max] = advantage + + advantages = advantages.unsqueeze(-1) * response_mask + return advantages, advantages + + +@register_adv_est( + AdvantageEstimator.REINFORCE_PLUS_PLUS_BASELINE +) # or simply: @register_adv_est("reinforce_plus_plus_baseline") +def compute_reinforce_plus_plus_baseline_outcome_advantage( + token_level_rewards: torch.Tensor, + response_mask: torch.Tensor, + index: torch.Tensor, + epsilon: float = 1e-6, + config: Optional[AlgoConfig] = None, + **kwargs, +) -> tuple[torch.Tensor, torch.Tensor]: + """ + Compute advantage for RF++-baseline (https://arxiv.org/abs/2501.03262), operating only on Outcome reward + (with only one scalar reward for each response). + + Args: + token_level_rewards: `(torch.Tensor)` + shape: (bs, response_length) + response_mask: `(torch.Tensor)` + shape: (bs, response_length) + config: (AlgoConfig) algorithm config + + Returns: + advantages: `(torch.Tensor)` + shape: (bs, response_length) + Returns: `(torch.Tensor)` + shape: (bs, response_length) + """ + response_length = token_level_rewards.shape[-1] + scores = token_level_rewards.sum(dim=-1) + + id2score = defaultdict(list) + id2mean = {} + + with torch.no_grad(): + bsz = scores.shape[0] + for i in range(bsz): + id2score[index[i]].append(scores[i]) + for idx in id2score: + if len(id2score[idx]) == 1: + id2mean[idx] = torch.tensor(0.0) + elif len(id2score[idx]) > 1: + id2mean[idx] = torch.mean(torch.stack(id2score[idx])) + else: + raise ValueError(f"no score in prompt index: {idx}") + for i in range(bsz): + scores[i] = scores[i] - id2mean[index[i]] + + scores = scores.unsqueeze(-1).tile([1, response_length]) * response_mask + scores = verl_F.masked_whiten(scores, response_mask) * response_mask + + return scores, scores + + +@register_adv_est(AdvantageEstimator.RLOO) # or simply: @register_adv_est("rloo") +def compute_rloo_outcome_advantage( + token_level_rewards: torch.Tensor, + response_mask: torch.Tensor, + index: np.ndarray, + epsilon: float = 1e-6, + config: Optional[AlgoConfig] = None, + **kwargs, +) -> tuple[torch.Tensor, torch.Tensor]: + """ + Compute advantage for RLOO based on https://arxiv.org/abs/2402.14740 + + Args: + token_level_rewards: `(torch.Tensor)` + shape: (bs, response_length) + response_mask: `(torch.Tensor)` + shape: (bs, response_length) + config: (AlgoConfig) algorithm config + + Returns: + advantages: `(torch.Tensor)` + shape: (bs, response_length) + Returns: `(torch.Tensor)` + shape: (bs, response_length) + """ + scores = token_level_rewards.sum(dim=-1) + + id2score = defaultdict(list) + id2mean = {} + + with torch.no_grad(): + bsz = scores.shape[0] + for i in range(bsz): + id2score[index[i]].append(scores[i]) + for idx in id2score: + if len(id2score[idx]) == 1: + id2mean[idx] = torch.tensor(0.0) + elif len(id2score[idx]) > 1: + id2mean[idx] = torch.mean(torch.stack(id2score[idx])) + else: + raise ValueError(f"no score in prompt index: {idx}") + for i in range(bsz): + response_num = len(id2score[index[i]]) + if response_num > 1: + scores[i] = scores[i] * response_num / (response_num - 1) - id2mean[index[i]] * response_num / ( + response_num - 1 + ) + scores = scores.unsqueeze(-1) * response_mask + + return scores, scores + + +@register_adv_est(AdvantageEstimator.OPO) # or simply: @register_adv_est("opo") +def compute_opo_outcome_advantage( + token_level_rewards: torch.Tensor, + response_mask: torch.Tensor, + index: np.ndarray, + epsilon: float = 1e-6, + config: Optional[AlgoConfig] = None, + **kwargs, +) -> tuple[torch.Tensor, torch.Tensor]: + """ + Compute advantage for OPO based on https://arxiv.org/pdf/2505.23585 + + Args: + token_level_rewards: `(torch.Tensor)` + shape: (bs, response_length) + response_mask: `(torch.Tensor)` + shape: (bs, response_length) + config: (AlgoConfig) algorithm config + + Returns: + advantages: `(torch.Tensor)` + shape: (bs, response_length) + Returns: `(torch.Tensor)` + shape: (bs, response_length) + """ + response_length = response_mask.sum(dim=-1) + scores = token_level_rewards.sum(dim=-1) + + id2score = defaultdict(list) + id2len = defaultdict(list) + id2bsl = {} + + with torch.no_grad(): + bsz = scores.shape[0] + for i in range(bsz): + id2score[index[i]].append(scores[i]) + id2len[index[i]].append(response_length[i]) + + for idx in id2score: + if len(id2score[idx]) == 1: + id2bsl[idx] = torch.tensor(0.0) + elif len(id2score[idx]) > 1: + score_tensor = torch.stack(id2score[idx]) + len_tensor = torch.stack(id2len[idx]) + id2bsl[idx] = (len_tensor * score_tensor).sum() / len_tensor.sum() + else: + raise ValueError(f"no score in prompt index: {idx}") + for i in range(bsz): + scores[i] = scores[i] - id2bsl[index[i]] + scores = scores.unsqueeze(-1) * response_mask + + return scores, scores + + +@register_adv_est(AdvantageEstimator.REINFORCE_PLUS_PLUS) # or simply: @register_adv_est("reinforce_plus_plus") +def compute_reinforce_plus_plus_outcome_advantage( + token_level_rewards: torch.Tensor, response_mask: torch.Tensor, config: Optional[AlgoConfig] = None, **kwargs +) -> tuple[torch.Tensor, torch.Tensor]: + """ + Compute advantage for REINFORCE++. + This implementation is based on the paper: https://arxiv.org/abs/2501.03262 + + Args: + token_level_rewards: `(torch.Tensor)` + shape: (bs, response_length) + response_mask: `(torch.Tensor)` + shape: (bs, response_length) + config: (AlgoConfig) algorithm config + + Returns: + advantages: `(torch.Tensor)` + shape: (bs, response_length) + Returns: `(torch.Tensor)` + shape: (bs, response_length) + """ + assert config is not None + gamma = config.gamma + with torch.no_grad(): + returns = torch.zeros_like(token_level_rewards) + running_return = 0 + + for t in reversed(range(token_level_rewards.shape[1])): + running_return = token_level_rewards[:, t] + gamma * running_return + returns[:, t] = running_return + # Reset after EOS + running_return = running_return * response_mask[:, t] + + advantages = verl_F.masked_whiten(returns, response_mask) + advantages = advantages * response_mask + + return advantages, returns + + +@register_adv_est(AdvantageEstimator.REMAX) # or simply: @register_adv_est("remax") +def compute_remax_outcome_advantage( + token_level_rewards: torch.Tensor, + reward_baselines: torch.Tensor, + response_mask: torch.Tensor, + config: Optional[AlgoConfig] = None, + **kwargs, +) -> tuple[torch.Tensor, torch.Tensor]: + """ + Compute advantage for ReMax, operating only on Outcome reward + This implementation is based on the paper: https://arxiv.org/abs/2310.10505 + (with only one scalar reward for each response). + + Args: + token_level_rewards: `(torch.Tensor)` + shape: (bs, response_length) + reward_baselines: `(torch.Tensor)` + shape: (bs,) + response_mask: `(torch.Tensor)` + shape: (bs, response_length) + config: (AlgoConfig) algorithm config + + Returns: + advantages: `(torch.Tensor)` + shape: (bs, response_length) + Returns: `(torch.Tensor)` + shape: (bs, response_length) + """ + + with torch.no_grad(): + returns = (token_level_rewards * response_mask).flip(dims=[-1]).cumsum(dim=-1).flip(dims=[-1]) + advantages = returns - reward_baselines.unsqueeze(-1) * response_mask + + return advantages, returns + + +@register_adv_est(AdvantageEstimator.GPG) # or simply: @register_adv_est("gpg") +def compute_gpg_outcome_advantage( + token_level_rewards: torch.Tensor, + response_mask: torch.Tensor, + index: np.ndarray, + epsilon: float = 1e-6, + f_norm: float = 1.0, + alpha: float = 1.0, + config=None, + **kwargs, +): + """ + Compute advantage for GPG, operating only on Outcome reward + (with only one scalar reward for each response). + Args: + token_level_rewards: `(torch.Tensor)` + shape: (bs, response_length) + response_mask: `(torch.Tensor)` + shape: (bs, response_length) + index: `(np.ndarray)` + shape: (bs,) + epsilon: (float) + f_norm: (float) + alpha: (float) + config: (dict) algorithm config + + Returns: + advantages: `(torch.Tensor)` + shape: (bs, response_length) + Returns: `(torch.Tensor)` + shape: (bs, response_length) + """ + scores = token_level_rewards.sum(dim=-1) + + id2score = defaultdict(list) + id2mean = {} + id2std = {} + + with torch.no_grad(): + bsz = scores.shape[0] + m = torch.count_nonzero(scores) + alpha = bsz / m.clamp(min=1) + + for i in range(bsz): + id2score[index[i]].append(scores[i]) + + for idx in id2score: + if len(id2score[idx]) == 1: + id2mean[idx] = torch.tensor(0.0) + id2std[idx] = torch.tensor(1.0) + elif len(id2score[idx]) > 1: + scores_tensor = torch.stack(id2score[idx]) + id2mean[idx] = torch.mean(scores_tensor) + id2std[idx] = torch.std(scores_tensor) + else: + raise ValueError(f"no score in prompt index: {idx}") + for i in range(bsz): + scores[i] = alpha * (scores[i] - id2mean[index[i]]) / (f_norm) + scores = scores.unsqueeze(-1) * response_mask + + return scores, scores + + +@register_adv_est(AdvantageEstimator.RLOO_VECTORIZED) # or simply: @register_adv_est("rloo_vectorized") +def compute_rloo_vectorized_outcome_advantage( + token_level_rewards: torch.Tensor, + response_mask: torch.Tensor, + index: np.ndarray, + epsilon: float = 1e-6, + config: Optional[AlgoConfig] = None, + **kwargs, +) -> tuple[torch.Tensor, torch.Tensor]: + """ + Compute advantage for RLOO based on https://arxiv.org/abs/2402.14740 + + Args: + token_level_rewards: `(torch.Tensor)` + shape: (bs, response_length) + response_mask: `(torch.Tensor)` + shape: (bs, response_length) + config: (AlgoConfig) algorithm config + + Returns: + advantages: `(torch.Tensor)` + shape: (bs, response_length) + Returns: `(torch.Tensor)` + shape: (bs, response_length) + """ + scores = token_level_rewards.sum(dim=-1) + + with torch.no_grad(): + inv = torch.from_numpy(np.unique(index, return_inverse=True)[1]).to(scores.device) + + c = torch.bincount(inv)[inv].to(scores.dtype) + adv = ((c * scores - torch.bincount(inv, weights=scores)[inv]) / (c - 1).clamp_min(1)) * (c > 1) + + adv = adv.unsqueeze(-1) * response_mask + + return adv, adv + + +@register_adv_est(AdvantageEstimator.OPTIMAL_TOKEN_BASELINE) +def compute_optimal_token_baseline_advantage( + token_level_rewards: torch.Tensor, + response_mask: torch.Tensor, + index: np.ndarray, + old_log_probs: torch.Tensor, + sum_pi_squared: torch.Tensor, + rollout_is_weights: torch.Tensor = None, + handle_zero_tail: bool = False, + epsilon: float = 1e-8, +) -> tuple[torch.Tensor, torch.Tensor]: + """ + Compute advantages using Optimal Token Baseline (OTB). + + Unlike the group mean based baseline which uses a single baseline per trajectory, + this computes a unique baseline for each timestep using cumulative path variance. + + Theory: + For each timestep t in each prompt group: + B_t* = E[G_t × W_t] / E[W_t] + where W_t = Σ_{j=1}^t ||s_j||² (cumulative path-variance proxy) + and ||s_j||² = 1 - 2π_j + Σπ² + + The cumulative sum W_t captures the "realized energy" of trajectory has been up to timestep t, + giving higher weight to predicting rewards on high-variance paths. + + Args: + token_level_rewards: Rewards at each token position [shape: (bs, response_length)] + response_mask: Binary mask for valid tokens (1) vs padding (0) [shape: (bs, response_length)] + index: Prompt indices for grouping trajectories from same prompt [shape: (bs,)] + old_log_probs: Log probabilities from training policy during generation [shape: (bs, response_length)] + sum_pi_squared: Sum of squared probabilities over vocabulary Σπ² [shape: (bs, response_length)] + rollout_is_weights: Pre-computed IS weights for W correction [shape: (bs, response_length)], + None if not using IS + handle_zero_tail: If True, zero baselines will be set in the portion of the longest trajectory + that extends beyond the second-longest trajectory in the prompt group. + Default: False + epsilon: Small constant for numerical stability (default: 1e-8) + + Returns: + advantages: OTB advantage estimates [shape: (bs, response_length)] + returns: Cumulative rewards (returns) from each position [shape: (bs, response_length)] + + Note on Rollout Importance Sampling: + When rollout_is_weights is provided, W_t is scaled by ρ̄²(t) to minimize MSE under truncated IS: + B_t* = Σ[G_t × ρ̄²(t) × W_t] / Σ[ρ̄²(t) × W_t] + """ + with torch.no_grad(): + batch_size, seq_len = token_level_rewards.shape + device = token_level_rewards.device + + # Compute returns (reward-to-go) for each timestep + returns = (token_level_rewards * response_mask).flip(dims=[-1]).cumsum(dim=-1).flip(dims=[-1]) + + # Step 1: Compute w_per_timestep = 1 - 2π_t + Σπ²) + pi_t = torch.exp(old_log_probs) + w_per_timestep = 1 - 2 * pi_t + sum_pi_squared + + # Step 2: Apply rollout importance sampling correction (if enabled) + if rollout_is_weights is not None: + # Scale W by ρ̄² to minimize MSE under truncated IS + w_per_timestep = w_per_timestep * (rollout_is_weights**2) + + # Step 3: Compute cumulative path-variance proxy: W_t = Σ_{j=1}^t w_j + # This measures accumulated variance from the start of the trajectory up to timestep t + w_cumulative = (w_per_timestep * response_mask).cumsum(dim=-1) + + # Group trajectories by prompt + prompt_groups = defaultdict(list) + for i in range(batch_size): + prompt_groups[index[i]].append(i) + + # Initialize baselines tensor [batch_size, seq_len] + baselines = torch.zeros_like(returns) + + # Compute per-step baseline for each prompt group + for _, trajectory_indices in prompt_groups.items(): + N = len(trajectory_indices) + if N == 1: + # Single trajectory - no baseline (advantage = return) + continue + + traj_idx = torch.tensor(trajectory_indices, device=device) + + # Extract group data [N, seq_len] + returns_group = returns[traj_idx] + w_cumulative_group = w_cumulative[traj_idx] + mask_group = response_mask[traj_idx] + + # Compute per-timestep baseline: B_t = Σ[G_t × W_t] / Σ[W_t] + # where W_t = Σ_{j=1}^t ||s_j||² (cumulative path variance) + # Shape: [seq_len] + numerator = (returns_group * w_cumulative_group * mask_group).sum(dim=0) # Sum over trajectories + denominator = (w_cumulative_group * mask_group).sum(dim=0) + epsilon + + baseline_per_step = numerator / denominator # [seq_len] + + # Assign to all trajectories in this group + baselines[traj_idx] = baseline_per_step.unsqueeze(0).expand(N, -1) + + if handle_zero_tail: + # Optionally zero out the portion of the longest trajectory that extends + # beyond the second-longest trajectory in the prompt group. + response_lengths = mask_group.sum(dim=-1) + sorted_lengths, _ = torch.sort(response_lengths) + max_length = int(sorted_lengths[-1].item()) + second_max_length = int(sorted_lengths[-2].item()) + max_length_idx = (response_lengths == max_length).nonzero(as_tuple=True)[0] + if max_length_idx.numel() == 1 and max_length > second_max_length: + max_length_traj_idx = trajectory_indices[int(max_length_idx[0])] + baselines[max_length_traj_idx, second_max_length:] = 0.0 + + # Compute advantages: A_t = G_t - B_t + advantages = (returns - baselines) * response_mask + + return advantages, returns + + +@register_adv_est(AdvantageEstimator.TIR_OPTIMAL_TOKEN_BASELINE) +def compute_multi_turn_optimal_token_baseline_advantage( + token_level_rewards: torch.Tensor, + response_mask: torch.Tensor, + index: np.ndarray, + old_log_probs: torch.Tensor, + sum_pi_squared: torch.Tensor, + rollout_is_weights: torch.Tensor = None, + handle_zero_tail: bool = True, + epsilon: float = 1e-8, +) -> tuple[torch.Tensor, torch.Tensor]: + """ + Compute advantages using Optimal Token Baseline (OTB). + + Unlike the group mean based baseline which uses a single baseline per trajectory, + this computes a unique baseline for each timestep using cumulative path variance. + + Theory: + For each timestep t in each prompt group: + B_t* = E[G_t × W_t] / E[W_t] + where W_t = Σ_{j=1}^t ||s_j||² (cumulative path-variance proxy) + and ||s_j||² = 1 - 2π_j + Σπ² + + The cumulative sum W_t captures the "realized energy" of trajectory has been up to timestep t, + giving higher weight to predicting rewards on high-variance paths. + + Args: + token_level_rewards: Rewards at each token position [shape: (bs, response_length)] + response_mask: Binary mask for valid tokens (1) vs padding (0) [shape: (bs, response_length)] + index: Prompt indices for grouping trajectories from same prompt [shape: (bs,)] + old_log_probs: Log probabilities from training policy during generation [shape: (bs, response_length)] + sum_pi_squared: Sum of squared probabilities over vocabulary Σπ² [shape: (bs, response_length)] + rollout_is_weights: Pre-computed IS weights for W correction [shape: (bs, response_length)], + None if not using IS + handle_zero_tail: If True, zero baselines will be set in the portion of the longest trajectory + that extends beyond the second-longest trajectory in the prompt group. + Default: False + epsilon: Small constant for numerical stability (default: 1e-8) + + Returns: + advantages: OTB advantage estimates [shape: (bs, response_length)] + returns: Cumulative rewards (returns) from each position [shape: (bs, response_length)] + + Note on Rollout Importance Sampling: + When rollout_is_weights is provided, W_t is scaled by ρ̄²(t) to minimize MSE under truncated IS: + B_t* = Σ[G_t × ρ̄²(t) × W_t] / Σ[ρ̄²(t) × W_t] + """ + with torch.no_grad(): + # Compute returns (reward-to-go) for each timestep + token_returns = (token_level_rewards * response_mask).flip(dims=[-1]).cumsum(dim=-1).flip(dims=[-1]) + + # Step 1: Compute w_per_timestep = 1 - 2π_t + Σπ²) + pi_t = torch.exp(old_log_probs) + w_per_timestep = 1 - 2 * pi_t + sum_pi_squared + + # Step 2: Apply rollout importance sampling correction (if enabled) + if rollout_is_weights is not None: + # Scale W by ρ̄² to minimize MSE under truncated IS + w_per_timestep = w_per_timestep * (rollout_is_weights**2) + + # Step 3: Compute cumulative path-variance proxy: W_t = Σ_{j=1}^t w_j + # This measures accumulated variance from the start of the trajectory up to timestep t + w_cumulative = (w_per_timestep * response_mask).cumsum(dim=-1) + + # Step 4: Concatenate returns and w_cumulative for each trajectory + # This allows us to compute baseline per timestep for each trajectory + response_lengths = response_mask.sum(dim=-1).to(dtype=torch.long) # [shape: (bs * n, )] + max_response_length = int(response_lengths.max().item()) if response_lengths.numel() > 0 else 0 + all_w_values = w_cumulative.new_zeros( + (len(response_lengths), max_response_length) + ) # [shape: (bs * n, max_response_length)] + all_returns = torch.zeros_like(all_w_values) + for i in range(len(response_lengths)): + length = int(response_lengths[i].item()) + if length == 0: + continue + mask = response_mask[i].bool() + all_w_values[i, :length] = w_cumulative[i, mask] + all_returns[i, :length] = token_returns[i, mask] + + # Group trajectories by prompt + prompt_groups = defaultdict(list) + for i in range(len(response_lengths)): + if response_lengths[i] == 0: + continue + prompt_groups[index[i]].append(i) + + # Compute optimal baseline for each prompt group + baselines = torch.zeros_like(all_returns) + + for _, trajectory_indices in prompt_groups.items(): + N = len(trajectory_indices) + traj_idx = torch.tensor(trajectory_indices, device=all_returns.device) + + if N == 1: + # Single trajectory - no baseline (keep original reward as advantage) + baselines[traj_idx[0]] = 0.0 + continue + + # Extract group data + w_group = all_w_values[traj_idx] # [shape: (N, max_response_length)] + R_group = all_returns[traj_idx] # [shape: (N, max_response_length)] + # Direct optimal baseline - single value for all in group + b_star = (R_group * w_group).sum(dim=0) / (w_group.sum(dim=0) + epsilon) + # Convert to match baselines dtype (epsilon can cause float64 promotion) + baselines[traj_idx] = b_star.to(baselines.dtype) + + if handle_zero_tail: + # Optionally zero out the portion of the longest trajectory that extends + # beyond the second-longest trajectory in the prompt group. + response_lengths_group = response_lengths[traj_idx] + sorted_lengths, _ = torch.sort(response_lengths_group) + max_length = int(sorted_lengths[-1].item()) + second_max_length = int(sorted_lengths[-2].item()) + max_length_idx = (response_lengths_group == max_length).nonzero(as_tuple=True)[0] + if max_length_idx.numel() == 1 and max_length > second_max_length: + max_length_traj_idx = trajectory_indices[int(max_length_idx[0])] + baselines[max_length_traj_idx, second_max_length:] = 0.0 + + # Compute advantages + all_advantages = all_returns - baselines # [shape: (bs * n, max_response_length)] + + advantages = torch.zeros_like(token_returns) # [shape: (bs * n, turn * response_length)] + for i in range(len(response_lengths)): + if response_lengths[i] == 0: + continue + advantages[i, response_mask[i].bool()] = all_advantages[i, : response_lengths[i]] + + advantages = advantages * response_mask # [shape: (bs * n * turn, response_length)] + + return advantages, token_returns + + +def compute_rewards(token_level_scores, old_log_prob, ref_log_prob, kl_ratio): + """Compute token-level rewards with KL penalty. + + Args: + token_level_scores (torch.Tensor): Token-level reward scores. + old_log_prob (torch.Tensor): Log probabilities from current policy. + ref_log_prob (torch.Tensor): Log probabilities from reference policy. + kl_ratio (float): KL penalty coefficient. + + Returns: + torch.Tensor: Token-level rewards with KL penalty applied. + """ + kl = old_log_prob - ref_log_prob + return token_level_scores - kl * kl_ratio + + +def agg_loss( + loss_mat: torch.Tensor, + loss_mask: torch.Tensor, + loss_agg_mode: str, + dp_size: int = 1, + batch_num_tokens: Optional[int] = None, + global_batch_size: Optional[int] = None, + loss_scale_factor: Optional[int] = None, +): + """ + Aggregate the loss across global batch to ensure the loss is invariant to fsdp/megatron parallelism. + + NOTE: The returned loss has different behaviors for different backend: + - FSDP: the loss is directly used for backward. + - Megatron: the loss should be scaled by `num_microbatches` and `cp_size` for pp schedule. + + Args: + loss_mat: micro batch loss matrix, (bs, response_length) + loss_mask: micro batch loss mask, (bs, response_length) + loss_agg_mode: method to aggregate the loss matrix into a scalar + dp_size: data parallel size + batch_num_tokens: number of valid tokens in global batch + global_batch_size: global batch size + loss_scale_factor: scale factor for "seq-mean-token-sum-norm" mode. If None, uses loss_mask.shape[-1]. + Set this to a constant value to ensure consistent normalization throughout training. + + Returns: + loss: `a scalar torch.Tensor` + aggregated loss + """ + if loss_agg_mode == "token-mean": + if batch_num_tokens is None: + batch_num_tokens = loss_mask.sum() + loss = verl_F.masked_sum(loss_mat, loss_mask) / batch_num_tokens * dp_size + elif loss_agg_mode == "seq-mean-token-sum": + seq_losses = torch.sum(loss_mat * loss_mask, dim=-1) # token-sum + seq_mask = (torch.sum(loss_mask, dim=-1) > 0).float() # exclude fully masked sequences + if global_batch_size is None: + global_batch_size = seq_mask.sum() + loss = verl_F.masked_sum(seq_losses, seq_mask) / global_batch_size * dp_size # seq-mean + elif loss_agg_mode == "seq-mean-token-mean": + seq_mask = torch.sum(loss_mask, dim=-1) # per-sequence token count + seq_losses = torch.sum(loss_mat * loss_mask, dim=-1) / (seq_mask + 1e-8) # token-mean + seq_mask = (seq_mask > 0).float() # exclude fully masked sequences + if global_batch_size is None: + global_batch_size = seq_mask.sum() + loss = verl_F.masked_sum(seq_losses, seq_mask) / global_batch_size * dp_size # seq-mean + elif loss_agg_mode == "seq-mean-token-sum-norm": + seq_losses = torch.sum(loss_mat * loss_mask, dim=-1) + if loss_scale_factor is None: + loss_scale_factor = loss_mask.shape[-1] + loss = torch.sum(seq_losses) / loss_scale_factor + else: + raise ValueError(f"Invalid loss_agg_mode: {loss_agg_mode}") + + return loss + + +@deprecated("verl.trainer.ppo.core_algos.compute_policy_loss_vanilla") +def compute_policy_loss( + old_log_prob, + log_prob, + advantages, + response_mask, + cliprange=None, + cliprange_low=None, + cliprange_high=None, + clip_ratio_c=3.0, + loss_agg_mode: str = "token-mean", +): + """ + Compute the clipped policy objective and related metrics for PPO. + + Adapted from + https://github.com/huggingface/trl/blob/main/trl/trainer/ppo_trainer.py#L1122 + + Args: + old_log_prob (torch.Tensor): + Log-probabilities of actions under the old policy, shape (batch_size, response_length). + log_prob (torch.Tensor): + Log-probabilities of actions under the current policy, shape (batch_size, response_length). + advantages (torch.Tensor): + Advantage estimates for each action, shape (batch_size, response_length). + response_mask (torch.Tensor): + Mask indicating which tokens to include in the loss, shape (batch_size, response_length). + cliprange (float, optional): + Clipping parameter ε for standard PPO. See https://arxiv.org/abs/1707.06347. + Defaults to None (must be provided). + cliprange_low (float, optional): + Lower clip range for dual-clip PPO. Defaults to same as `cliprange`. + cliprange_high (float, optional): + Upper clip range for dual-clip PPO. Defaults to same as `cliprange`. + clip_ratio_c (float, optional): + Lower bound of the ratio for dual-clip PPO. See https://arxiv.org/pdf/1912.09729. + Defaults to 3.0. + loss_agg_mode (str, optional): + Aggregation mode for `agg_loss`. Defaults to "token-mean". + """ + assert clip_ratio_c > 1.0, ( + "The lower bound of the clip_ratio_c for dual-clip PPO should be greater than 1.0," + + f" but get the value: {clip_ratio_c}." + ) + + negative_approx_kl = log_prob - old_log_prob + # Clamp negative_approx_kl for stability + negative_approx_kl = torch.clamp(negative_approx_kl, min=-20.0, max=20.0) + ratio = torch.exp(negative_approx_kl) + ppo_kl = verl_F.masked_mean(-negative_approx_kl, response_mask) + + pg_losses1 = -advantages * ratio + if cliprange_low is None: + cliprange_low = cliprange + if cliprange_high is None: + cliprange_high = cliprange + pg_losses2 = -advantages * torch.clamp( + ratio, 1 - cliprange_low, 1 + cliprange_high + ) # - clip(ratio, 1-cliprange, 1+cliprange) * A + clip_pg_losses1 = torch.maximum( + pg_losses1, pg_losses2 + ) # max(-ratio * A, -clip(ratio, 1-cliprange, 1+cliprange) * A) + pg_clipfrac = verl_F.masked_mean(torch.gt(pg_losses2, pg_losses1).float(), response_mask) + + pg_losses3 = -advantages * clip_ratio_c + clip_pg_losses2 = torch.min(pg_losses3, clip_pg_losses1) + pg_clipfrac_lower = verl_F.masked_mean( + torch.gt(clip_pg_losses1, pg_losses3) * (advantages < 0).float(), response_mask + ) + + pg_losses = torch.where(advantages < 0, clip_pg_losses2, clip_pg_losses1) + pg_loss = agg_loss(loss_mat=pg_losses, loss_mask=response_mask, loss_agg_mode=loss_agg_mode) + + return pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower + + +@register_policy_loss("vanilla") # type: ignore[arg-type] +def compute_policy_loss_vanilla( + old_log_prob: torch.Tensor, + log_prob: torch.Tensor, + advantages: torch.Tensor, + response_mask: torch.Tensor, + loss_agg_mode: str = "token-mean", + config: Optional[ActorConfig] = None, + rollout_is_weights: torch.Tensor | None = None, +) -> tuple[torch.Tensor, dict[str, Any]]: + """ + Compute the clipped policy objective and related metrics for PPO. + + Adapted from + https://github.com/huggingface/trl/blob/main/trl/trainer/ppo_trainer.py#L1122 + + Args: + old_log_prob (torch.Tensor): + Log-probabilities of actions under the old policy, shape (batch_size, response_length). + log_prob (torch.Tensor): + Log-probabilities of actions under the current policy, shape (batch_size, response_length). + advantages (torch.Tensor): + Advantage estimates for each action, shape (batch_size, response_length). + response_mask (torch.Tensor): + Mask indicating which tokens to include in the loss, shape (batch_size, response_length). + loss_agg_mode (str, optional): + Aggregation mode for `agg_loss`. Defaults to "token-mean". + config: `(verl.trainer.config.ActorConfig)`: + config for the actor. + rollout_log_probs: `(torch.Tensor)`: + log probabilities of actions under the rollout policy, shape (batch_size, response_length). + """ + + assert config is not None + assert not isinstance(config, AlgoConfig) + clip_ratio = config.clip_ratio # Clipping parameter ε for standard PPO. See https://arxiv.org/abs/1707.06347. + clip_ratio_low = config.clip_ratio_low if config.clip_ratio_low is not None else clip_ratio + clip_ratio_high = config.clip_ratio_high if config.clip_ratio_high is not None else clip_ratio + clip_ratio_c = config.get( # Lower bound of the ratio for dual-clip PPO. See https://arxiv.org/pdf/1912.09729. + "clip_ratio_c", 3.0 + ) + + cliprange = clip_ratio + cliprange_low = clip_ratio_low + cliprange_high = clip_ratio_high + + assert clip_ratio_c > 1.0, ( + "The lower bound of the clip_ratio_c for dual-clip PPO should be greater than 1.0," + + f" but get the value: {clip_ratio_c}." + ) + + negative_approx_kl = log_prob - old_log_prob + # Clamp negative_approx_kl for stability + negative_approx_kl = torch.clamp(negative_approx_kl, min=-20.0, max=20.0) + ratio = torch.exp(negative_approx_kl) + ppo_kl = verl_F.masked_mean(-negative_approx_kl, response_mask) + + pg_losses1 = -advantages * ratio + if cliprange_low is None: + cliprange_low = cliprange + if cliprange_high is None: + cliprange_high = cliprange + pg_losses2 = -advantages * torch.clamp( + ratio, 1 - cliprange_low, 1 + cliprange_high + ) # - clip(ratio, 1-cliprange, 1+cliprange) * A + clip_pg_losses1 = torch.maximum( + pg_losses1, pg_losses2 + ) # max(-ratio * A, -clip(ratio, 1-cliprange, 1+cliprange) * A) + pg_clipfrac = verl_F.masked_mean(torch.gt(pg_losses2, pg_losses1).float(), response_mask) + + pg_losses3 = -advantages * clip_ratio_c + clip_pg_losses2 = torch.min(pg_losses3, clip_pg_losses1) + pg_clipfrac_lower = verl_F.masked_mean( + torch.gt(clip_pg_losses1, pg_losses3) * (advantages < 0).float(), response_mask + ) + + pg_losses = torch.where(advantages < 0, clip_pg_losses2, clip_pg_losses1) + + # Apply rollout correction weights if provided + if rollout_is_weights is not None: + pg_losses = pg_losses * rollout_is_weights + + pg_loss = agg_loss( + loss_mat=pg_losses, loss_mask=response_mask, loss_agg_mode=loss_agg_mode, **config.global_batch_info + ) + + pg_metrics = { + "actor/pg_clipfrac": pg_clipfrac.detach().item(), + "actor/ppo_kl": ppo_kl.detach().item(), + "actor/pg_clipfrac_lower": pg_clipfrac_lower.detach().item(), + } + return pg_loss, pg_metrics + + +@register_policy_loss("gspo") +def compute_policy_loss_gspo( + old_log_prob: torch.Tensor, + log_prob: torch.Tensor, + advantages: torch.Tensor, + response_mask: torch.Tensor, + loss_agg_mode: str = "seq-mean-token-mean", + config: Optional[ActorConfig] = None, + rollout_is_weights: torch.Tensor | None = None, +) -> tuple[torch.Tensor, dict[str, Any]]: + """ + Compute the clipped policy objective and related metrics for GSPO. + + See https://arxiv.org/pdf/2507.18071 for more details. + + Args: + old_log_prob (torch.Tensor): + Log-probabilities of actions under the old policy, shape (batch_size, response_length). + log_prob (torch.Tensor): + Log-probabilities of actions under the current policy, shape (batch_size, response_length). + advantages (torch.Tensor): + Advantage estimates for each action, shape (batch_size, response_length). + response_mask (torch.Tensor): + Mask indicating which tokens to include in the loss, shape (batch_size, response_length). + loss_agg_mode (str, optional): + Aggregation mode for `agg_loss`. For GSPO, it is recommended to use "seq-mean-token-mean". + """ + + assert config is not None + assert isinstance(config, ActorConfig) + clip_ratio_low = config.clip_ratio_low if config.clip_ratio_low is not None else config.clip_ratio + clip_ratio_high = config.clip_ratio_high if config.clip_ratio_high is not None else config.clip_ratio + + negative_approx_kl = log_prob - old_log_prob + + # compute sequence-level importance ratio: + # si(θ) = (π_θ(yi|x)/π_θold(yi|x))^(1/|yi|) = + # exp [(1/|y_i|) * Σ_t log(π_θ(y_i,t|x,y_i, tuple[torch.Tensor, dict[str, Any]]: + """ + Compute the smoothed policy objective and related metrics for SAPO. + + See https://arxiv.org/pdf/2511.20347 for more details. + + Args: + old_log_prob (torch.Tensor): + Log-probabilities of actions under the old policy, shape (batch_size, response_length). + log_prob (torch.Tensor): + Log-probabilities of actions under the current policy, shape (batch_size, response_length). + advantages (torch.Tensor): + Advantage estimates for each action, shape (batch_size, response_length). + response_mask (torch.Tensor): + Mask indicating which tokens to include in the loss, shape (batch_size, response_length). + loss_agg_mode (str, optional): + Aggregation mode for `agg_loss`. For SAPO, it is recommended to use "seq-mean-token-mean". + """ + + assert config is not None + assert isinstance(config, ActorConfig) + + # temperature for positive and negative token updates + tau_pos = torch.as_tensor(config.tau_pos, dtype=advantages.dtype, device=advantages.device) + tau_neg = torch.as_tensor(config.tau_neg, dtype=advantages.dtype, device=advantages.device) + + def gate_function(x, tau): + """The gating function used in SAPO""" + return torch.sigmoid(tau * (x - 1.0)) * (4.0 / tau) + + # compute IS at token level: + # r_{i,t}(θ) = π_θ(y_{i,t}|x, y_{i, 0 else tau_neg + taus = torch.where( + condition=advantages > 0, + input=tau_pos, # if A_{i,t} > 0 we set to tau_pos + other=tau_neg, # if A_{i,t} <= 0 we set to tau_neg + ) + + # compute the gates f_{i,t}(r_{i,t}(θ)) at token level + gates = gate_function(ratio, taus) + + # compute policy gradient loss + pg_losses = -gates * advantages + + # Apply rollout correction weights if provided + if rollout_is_weights is not None: + pg_losses = pg_losses * rollout_is_weights + + # for SAPO, we need to aggregate the loss at the sequence level (seq-mean-token-mean) + pg_loss = agg_loss( + loss_mat=pg_losses, loss_mask=response_mask, loss_agg_mode="seq-mean-token-mean", **config.global_batch_info + ) + + # For compatibility, return zero for both pg_clipfrac and pg_clipfrac_lower (not used in SAPO) + pg_clipfrac = torch.tensor(0.0, device=pg_loss.device) + pg_clipfrac_lower = torch.tensor(0.0, device=pg_loss.device) + # compute KL for metrics tracking + ppo_kl = verl_F.masked_mean(-negative_approx_kl, response_mask) + # return metrics dict + pg_metrics = { + "actor/pg_clipfrac": pg_clipfrac.detach().item(), + "actor/ppo_kl": ppo_kl.detach().item(), + "actor/pg_clipfrac_lower": pg_clipfrac_lower.detach().item(), + } + + return pg_loss, pg_metrics + + +@register_policy_loss("gpg") +def compute_policy_loss_gpg( + old_log_prob: torch.Tensor, + log_prob: torch.Tensor, + advantages: torch.Tensor, + response_mask: torch.Tensor, + loss_agg_mode: str = "token-mean", + config: Optional[ActorConfig] = None, + rollout_is_weights: torch.Tensor | None = None, +) -> tuple[torch.Tensor, dict[str, Any]]: + """Adapted from + https://github.com/AMAP-ML/GPG/blob/main/VisualThinker-R1-Zero/src/open-r1-multimodal/src/open_r1/trainer/grpo_trainer.py#L495 + Args: + log_prob: `(torch.Tensor)` + shape: (bs, response_length) + advantages: `(torch.Tensor)` + shape: (bs, response_length) + response_mask: `(torch.Tensor)` + shape: (bs, response_length) + return: + pg_loss: `a scalar torch.Tensor` + policy gradient loss computed via GPG + """ + assert config is not None + pg_losses = -log_prob * advantages + + # Apply rollout correction weights if provided + if rollout_is_weights is not None: + pg_losses = pg_losses * rollout_is_weights + + pg_loss = agg_loss( + loss_mat=pg_losses, loss_mask=response_mask, loss_agg_mode=loss_agg_mode, **config.global_batch_info + ) + return pg_loss, {} + + +@register_policy_loss("clip_cov") +def compute_policy_loss_clip_cov( + old_log_prob: torch.Tensor, + log_prob: torch.Tensor, + advantages: torch.Tensor, + response_mask: torch.Tensor, + loss_agg_mode: str = "token-mean", + config: Optional[ActorConfig] = None, + rollout_is_weights: torch.Tensor | None = None, +) -> tuple[torch.Tensor, dict[str, Any]]: + """ + Compute the clipped policy objective and related metrics for Clip-Cov. + + Adapted from + https://github.com/PRIME-RL/Entropy-Mechanism-of-RL/blob/main/verl/trainer/ppo/core_algos.py + + Args: + old_log_prob (torch.Tensor): + Log-probabilities of actions under the old policy, shape (batch_size, response_length). + log_prob (torch.Tensor): + Log-probabilities of actions under the current policy, shape (batch_size, response_length). + advantages (torch.Tensor): + Advantage estimates for each action, shape (batch_size, response_length). + response_mask (torch.Tensor): + Mask indicating which tokens to include in the loss, shape (batch_size, response_length). + cliprange (float, optional): + Clipping parameter ε for standard PPO. See https://arxiv.org/abs/1707.06347. + Defaults to None (must be provided). + cliprange_low (float, optional): + Lower clip range for dual-clip PPO. Defaults to same as `cliprange`. + cliprange_high (float, optional): + Upper clip range for dual-clip PPO. Defaults to same as `cliprange`. + loss_agg_mode (str, optional): + Aggregation mode for `agg_loss`. Defaults to "token-mean". + clip_cvo_ratio (float, optional): + Ratio for clipping the covariance. Defaults to 0.0002. + clip_cov_lb (float, optional): + Lower bound for clipping covariance. Defaults to 1.0. + clip_cov_ub (float, optional): + Upper bound for clipping covariance. Defaults to 5.0. + """ + assert config is not None + assert not isinstance(config, AlgoConfig), "passing AlgoConfig not supported yet" + assert config.policy_loss is not None + + clip_cov_ratio = config.policy_loss.clip_cov_ratio if config.policy_loss.clip_cov_ratio is not None else 0.0002 + cliprange = config.clip_ratio + cliprange_low = config.clip_ratio_low if config.clip_ratio_low is not None else cliprange + cliprange_high = config.clip_ratio_high if config.clip_ratio_high is not None else cliprange + clip_cov_ub = config.policy_loss.clip_cov_ub if config.policy_loss.clip_cov_ub is not None else 5.0 + clip_cov_lb = config.policy_loss.clip_cov_lb if config.policy_loss.clip_cov_lb is not None else 1.0 + + assert clip_cov_ratio > 0, "clip_ratio should be larger than 0." + + negative_approx_kl = log_prob - old_log_prob + ratio = torch.exp(negative_approx_kl) + ppo_kl = verl_F.masked_mean(-negative_approx_kl, response_mask) + + pg_losses1 = -advantages * ratio + + if cliprange_low is None: + cliprange_low = cliprange + if cliprange_high is None: + cliprange_high = cliprange + + corr = torch.ones_like(advantages) + pg_losses2 = -advantages * torch.clamp(ratio, 1 - cliprange_low, 1 + cliprange_high) + clip_by_origin = (pg_losses2 > pg_losses1) & (response_mask > 0) + + cov_all = (advantages - verl_F.masked_mean(advantages, response_mask)) * ( + log_prob - verl_F.masked_mean(log_prob.detach(), response_mask) + ) + cov_all[response_mask == 0] = -torch.inf + cov_all[clip_by_origin] = -torch.inf + + clip_num = max(int(clip_cov_ratio * response_mask.sum().item()), 1) + top_k_idx = (cov_all < clip_cov_ub) & (cov_all > clip_cov_lb) & (response_mask > 0) + top_k_idx = torch.nonzero(top_k_idx) + + if len(top_k_idx) > 0: + perm = torch.randperm(len(top_k_idx)) + top_k_idx = top_k_idx[perm[: min(clip_num, len(top_k_idx))]] + else: + top_k_idx = torch.empty((0, 2), device=cov_all.device, dtype=torch.long) + + corr[top_k_idx[:, 0], top_k_idx[:, 1]] = 0 + + pg_clipfrac = verl_F.masked_mean((corr == 0).float(), response_mask) + + pg_losses = torch.maximum(pg_losses1, pg_losses2) * corr + + # Apply rollout correction weights if provided + if rollout_is_weights is not None: + pg_losses = pg_losses * rollout_is_weights + + pg_loss = agg_loss( + loss_mat=pg_losses, loss_mask=response_mask, loss_agg_mode=loss_agg_mode, **config.global_batch_info + ) + pg_metrics = { + "actor/pg_clipfrac": pg_clipfrac.detach().item(), + "actor/ppo_kl": ppo_kl.detach().item(), + } + return pg_loss, pg_metrics + + +@register_policy_loss("kl_cov") +def compute_policy_loss_kl_cov( + old_log_prob: torch.Tensor, + log_prob: torch.Tensor, + advantages: torch.Tensor, + response_mask: torch.Tensor, + loss_agg_mode: str = "token-mean", + config: Optional[ActorConfig] = None, + rollout_is_weights: torch.Tensor | None = None, +) -> tuple[torch.Tensor, dict[str, Any]]: + """ + Compute the clipped policy objective and related metrics for Clip-Cov. + + Adapted from + https://github.com/PRIME-RL/Entropy-Mechanism-of-RL/blob/main/verl/trainer/ppo/core_algos.py + + Args: + old_log_prob (torch.Tensor): + Log-probabilities of actions under the old policy, shape (batch_size, response_length). + log_prob (torch.Tensor): + Log-probabilities of actions under the current policy, shape (batch_size, response_length). + advantages (torch.Tensor): + Advantage estimates for each action, shape (batch_size, response_length). + response_mask (torch.Tensor): + Mask indicating which tokens to include in the loss, shape (batch_size, response_length). + loss_agg_mode (str, optional): + Aggregation mode for `agg_loss`. Defaults to "token-mean". + kl_cov_ratio (float, optional): + Ratio for selecting the top-k covariance values. Defaults to 0.0002. + ppo_kl_coef (float, optional): + Coefficient for the KL penalty term in the loss. Defaults to 1. + """ + assert config is not None + assert not isinstance(config, AlgoConfig), "passing AlgoConfig not supported yet" + assert config.policy_loss is not None + + kl_cov_ratio = config.policy_loss.kl_cov_ratio if config.policy_loss.kl_cov_ratio is not None else 0.0002 + ppo_kl_coef = config.policy_loss.ppo_kl_coef if config.policy_loss.ppo_kl_coef is not None else 1.0 + + assert kl_cov_ratio > 0, "kl_cov_ratio should be larger than 0." + + negative_approx_kl = log_prob - old_log_prob + abs_kl = negative_approx_kl.abs() + ratio = torch.exp(negative_approx_kl) + ppo_kl_abs = verl_F.masked_mean(negative_approx_kl.abs(), response_mask) + pg_losses1 = -advantages * ratio + pg_losses_kl = -advantages * ratio + ppo_kl_coef * abs_kl + pg_losses = pg_losses1 + + all_valid = response_mask > 0 + all_valid_idx = torch.nonzero(all_valid.reshape(-1), as_tuple=True)[0] + all_valid_adv = advantages[all_valid].detach().reshape(-1).cpu() + all_valid_logp = log_prob[all_valid].detach().reshape(-1).cpu() + + k = min(kl_cov_ratio, len(all_valid_adv)) + + if k != 0: + cov_lst_all = (all_valid_adv - all_valid_adv.mean()) * (all_valid_logp - all_valid_logp.mean()) + k_percent_nums = max(1, int(len(cov_lst_all) * kl_cov_ratio)) + large_cov_idxs = torch.topk(cov_lst_all, k_percent_nums, largest=True).indices + + if len(large_cov_idxs) != 0: + large_cov_idxs = all_valid_idx[large_cov_idxs] + pg_losses[large_cov_idxs // advantages.shape[1], large_cov_idxs % advantages.shape[1]] = pg_losses_kl[ + large_cov_idxs // advantages.shape[1], large_cov_idxs % advantages.shape[1] + ] + + # Apply rollout correction weights if provided + if rollout_is_weights is not None: + pg_losses = pg_losses * rollout_is_weights + + pg_loss = agg_loss( + loss_mat=pg_losses, loss_mask=response_mask, loss_agg_mode=loss_agg_mode, **config.global_batch_info + ) + pg_metrics = { + "actor/ppo_kl": ppo_kl_abs.detach().item(), + } + return pg_loss, pg_metrics + + +@register_policy_loss("geo_mean") +def compute_policy_loss_geo_mean( + old_log_prob: torch.Tensor, + log_prob: torch.Tensor, + advantages: torch.Tensor, + response_mask: torch.Tensor, + loss_agg_mode: str = "token-mean", + config: Optional[ActorConfig] = None, + rollout_is_weights: torch.Tensor | None = None, +) -> tuple[torch.Tensor, dict[str, Any]]: + """ + Compute the clipped policy objective and related metrics for GMPO. + + Adapted from paper https://arxiv.org/abs/2507.20673 + https://github.com/callsys/GMPO/blob/main/train_zero_math_gmpo.py + + Args: + old_log_prob (torch.Tensor): + Log-probabilities of actions under the old policy, shape (batch_size, response_length). + log_prob (torch.Tensor): + Log-probabilities of actions under the current policy, shape (batch_size, response_length). + advantages (torch.Tensor): + Advantage estimates for each action, shape (batch_size, response_length). + response_mask (torch.Tensor): + Mask indicating which tokens to include in the loss, shape (batch_size, response_length). + loss_agg_mode (str, optional): + not used + """ + + assert config is not None + assert not isinstance(config, AlgoConfig) + clip_ratio = config.clip_ratio # Clipping parameter. See https://arxiv.org/abs/1707.06347. + clip_ratio_low = config.clip_ratio_low if config.clip_ratio_low is not None else clip_ratio + clip_ratio_high = config.clip_ratio_high if config.clip_ratio_high is not None else clip_ratio + + cliprange = clip_ratio + cliprange_low = clip_ratio_low + cliprange_high = clip_ratio_high + if cliprange_low is None: + cliprange_low = cliprange + if cliprange_high is None: + cliprange_high = cliprange + + negative_approx_kl = log_prob - old_log_prob + # Clamp negative_approx_kl for stability (uncomment it if you like) + # negative_approx_kl = torch.clamp(negative_approx_kl, min=-20.0, max=20.0) + ppo_kl = verl_F.masked_mean(-negative_approx_kl, response_mask) + + # Clipping at token-level & Clipping wider + sgn_advantage = torch.sign(advantages) + negative_approx_kl_clamp = torch.clamp(negative_approx_kl, -cliprange_low, cliprange_high) + negative_approx_kl_min = torch.min(sgn_advantage * negative_approx_kl, sgn_advantage * negative_approx_kl_clamp) + negative_approx_kl_min = sgn_advantage * negative_approx_kl_min + + # Geometric-Mean Policy Optimization + response_mask_sum = response_mask.sum(dim=-1) + ratio = torch.exp((negative_approx_kl_min * response_mask).sum(dim=-1) / (response_mask_sum + 1e-8)) + # we only support sequence level advantage for now, + # otherwise, below would be not consistent with the paper + advantage = (advantages * response_mask).sum(dim=-1) / (response_mask_sum + 1e-8) + pg_losses = -advantage * ratio + + # Apply rollout correction weights if provided + # For geo_mean, IS weights are 2D (batch_size, seq_length) and need to be aggregated to sequence level + if rollout_is_weights is not None: + # Aggregate token-level weights to sequence level using geometric mean for consistency + # Note: rollout_is_weights is always 2D regardless of aggregation mode + seq_is_weights = torch.exp( + (torch.log(rollout_is_weights + 1e-10) * response_mask).sum(dim=-1) / (response_mask_sum + 1e-8) + ) + pg_losses = pg_losses * seq_is_weights + + pg_loss = torch.mean(pg_losses) + + # higher: ratio is too large that need clamp to clip_high (when adv > 0) + clipped = torch.ne(negative_approx_kl, negative_approx_kl_clamp) + pg_clipfrac = verl_F.masked_mean((clipped * (advantages > 0)).float(), response_mask) + pg_clipfrac_lower = verl_F.masked_mean((clipped * (advantages < 0)).float(), response_mask) + pg_metrics = { + "actor/pg_clipfrac": pg_clipfrac.detach().item(), + "actor/ppo_kl": ppo_kl.detach().item(), + "actor/pg_clipfrac_lower": pg_clipfrac_lower.detach().item(), + } + return pg_loss, pg_metrics + + +@register_policy_loss("cispo") +def compute_policy_loss_cispo( + old_log_prob: torch.Tensor, + log_prob: torch.Tensor, + advantages: torch.Tensor, + response_mask: torch.Tensor, + loss_agg_mode: str = "token-mean", + config: Optional[DictConfig | ActorConfig] = None, + rollout_is_weights: torch.Tensor | None = None, +) -> tuple[torch.Tensor, dict[str, Any]]: + """ + Compute the clipped policy objective and related metrics for CISPO. + + See https://arxiv.org/pdf/2506.13585 for more details. + """ + + assert config is not None + assert isinstance(config, ActorConfig) + clip_ratio_low = config.clip_ratio_low if config.clip_ratio_low is not None else config.clip_ratio + clip_ratio_high = config.clip_ratio_high if config.clip_ratio_high is not None else config.clip_ratio + + # Compute importance sampling ratio: π_θ / π_θ_old + negative_approx_kl = log_prob - old_log_prob + # Clamp for numerical stability + negative_approx_kl = torch.clamp(negative_approx_kl, min=-20.0, max=20.0) + ratio = torch.exp(negative_approx_kl) + ppo_kl = verl_F.masked_mean(-negative_approx_kl, response_mask) + + # CISPO: Clip the importance sampling weights + # KEY: Apply stop gradient to the clipped ratio + # This prevents gradients from flowing through the ratio computation and clipping + # Gradients only flow through log_prob in the final loss term + clipped_ratio = torch.clamp(ratio, 1 - clip_ratio_low, 1 + clip_ratio_high) + clipped_ratio_sg = clipped_ratio.detach() + + # CISPO objective function (to maximize): J = sg(clip(ratio)) * A * log π_θ + # Loss function (to minimize): L = -J = -sg(clip(ratio)) * A * log_prob + pg_losses = -clipped_ratio_sg * advantages * log_prob + + # Track clipping statistics + pg_clipfrac = verl_F.masked_mean((ratio != clipped_ratio).float(), response_mask) + + # Apply rollout importance sampling weights if provided + if rollout_is_weights is not None: + pg_losses = pg_losses * rollout_is_weights + + pg_loss = agg_loss( + loss_mat=pg_losses, loss_mask=response_mask, loss_agg_mode=loss_agg_mode, **config.global_batch_info + ) + + # For compatibility, return zero for pg_clipfrac_lower (not used in CISPO) + pg_clipfrac_lower = torch.tensor(0.0, device=pg_loss.device) + + pg_metrics = { + "actor/pg_clipfrac": pg_clipfrac.detach().item(), + "actor/ppo_kl": ppo_kl.detach().item(), + "actor/pg_clipfrac_lower": pg_clipfrac_lower.detach().item(), + } + return pg_loss, pg_metrics + + +def compute_entropy_loss(logits, response_mask, loss_agg_mode: str = "token-mean"): + """Compute categorical entropy loss (For backward compatibility) + + Args: + logits (torch.Tensor): shape is (bs, response_length, vocab_size) + response_mask (torch.Tensor): shape is (bs, response_length) + + Returns: + entropy: a scalar torch.Tensor + + """ + # compute entropy + token_entropy = verl_F.entropy_from_logits(logits) # (bs, response_len) + entropy_loss = agg_loss(loss_mat=token_entropy, loss_mask=response_mask, loss_agg_mode=loss_agg_mode) + return entropy_loss + + +def compute_value_loss( + vpreds: torch.Tensor, + returns: torch.Tensor, + values: torch.Tensor, + response_mask: torch.Tensor, + cliprange_value: float, + loss_agg_mode: str = "token-mean", +): + """ + Compute the clipped value-function loss for PPO. + + Copied from https://github.com/huggingface/trl/blob/main/trl/trainer/ppo_trainer.py#L1151 + + Args: + vpreds (torch.FloatTensor): + Predicted values from the value head, shape (batch_size, response_length). + values (torch.FloatTensor): + Old (baseline) values from the value head, shape (batch_size, response_length). + returns (torch.FloatTensor): + Ground-truth returns, shape (batch_size, response_length). + response_mask (torch.Tensor): + Mask indicating which tokens to include in the value loss calculation. + cliprange_value (float): + Clip range for value prediction updates. + loss_agg_mode (str, optional): + Aggregation mode for `agg_loss`. Defaults to "token-mean". + + Returns: + vf_loss (torch.FloatTensor): + A scalar tensor containing the aggregated value-function loss. + vf_clipfrac (float): + Fraction of elements where the clipped loss was used. + """ + vpredclipped = verl_F.clip_by_value(vpreds, values - cliprange_value, values + cliprange_value) + vf_losses1 = (vpreds - returns) ** 2 + vf_losses2 = (vpredclipped - returns) ** 2 + clipped_vf_losses = torch.max(vf_losses1, vf_losses2) + vf_loss = 0.5 * agg_loss(loss_mat=clipped_vf_losses, loss_mask=response_mask, loss_agg_mode=loss_agg_mode) + vf_clipfrac = verl_F.masked_mean(torch.gt(vf_losses2, vf_losses1).float(), response_mask) + return vf_loss, vf_clipfrac + + +def kl_penalty(logprob: torch.FloatTensor, ref_logprob: torch.FloatTensor, kl_penalty) -> torch.FloatTensor: + """Compute KL divergence given logprob and ref_logprob. Optionally using straight through to bind k2 on other + kl penalty compute method for unbiased KL gradient estimation. + See more description in http://joschu.net/blog/kl-approx.html + + Args: + logprob: + ref_logprob: + + Returns: + kl_estimate + """ + forward_score = kl_penalty_forward(logprob, ref_logprob, kl_penalty) + if not kl_penalty.endswith("+") or kl_penalty in ("mse", "k2"): + return forward_score + + """ + The expectation of k1 and k3 estimator is the expectaed value of KL, but the expected gradient of k1 and k3 + estimator is not the expectaed gradient of KL. On the other hand k2 estimator gives right gradient estimator, + so we use a straight through trick here if the kl_penalty method ends with '+', .e.g., k3+. + """ + backward_score = 0.5 * (logprob - ref_logprob).square() + + return backward_score - backward_score.detach() + forward_score.detach() + + +def kl_penalty_forward(logprob: torch.FloatTensor, ref_logprob: torch.FloatTensor, kl_penalty) -> torch.FloatTensor: + """Compute KL divergence given logprob and ref_logprob. + Copied from https://github.com/huggingface/trl/blob/main/trl/trainer/ppo_trainer.py#L1104 + See more description in http://joschu.net/blog/kl-approx.html + + Args: + logprob: + ref_logprob: + + Returns: + kl_estimate + """ + if kl_penalty in ("kl", "k1"): + return logprob - ref_logprob + + if kl_penalty == "abs": + return (logprob - ref_logprob).abs() + + if kl_penalty in ("mse", "k2"): + return 0.5 * (logprob - ref_logprob).square() + + # J. Schulman. Approximating kl divergence, 2020. + # # URL http://joschu.net/blog/kl-approx.html. + if kl_penalty in ("low_var_kl", "k3"): + kl = ref_logprob - logprob + # For numerical stability + kl = torch.clamp(kl, min=-20, max=20) + ratio = torch.exp(kl) + kld = (ratio - kl - 1).contiguous() + return torch.clamp(kld, min=-10, max=10) + + if kl_penalty == "full": + # so, here logprob and ref_logprob should contain the logits for every token in vocabulary + raise NotImplementedError + + raise NotImplementedError + + +def compute_pf_ppo_reweight_data( + data, + reweight_method: str = "pow", + weight_pow: float = 2.0, +): + """Reweight the data based on the token_level_scores. + + Args: + data: DataProto object, containing batch, non_tensor_batch and meta_info + reweight_method: str, choices: "pow", "max_min", "max_random" + weight_pow: float, the power of the weight + + Returns: + + """ + + @torch.no_grad() + def compute_weights(scores: torch.Tensor, reweight_method: str, weight_pow: float) -> torch.Tensor: + """Compute importance weights for resampling based on scores. + + Args: + scores (torch.Tensor): Tensor of scores to compute weights from. + reweight_method (str): Method for computing weights ('pow', 'max_min', 'max_random'). + weight_pow (float): Power exponent for 'pow' method. + + Returns: + torch.Tensor: Computed importance weights. + + Raises: + ValueError: If reweight_method is not supported. + """ + if reweight_method == "pow": + weights = torch.pow(torch.abs(scores), weight_pow) + elif reweight_method == "max_min": + max_score = torch.max(scores) + min_score = torch.min(scores) + weights = torch.where((scores == max_score) | (scores == min_score), 1.0, 0.0) + elif reweight_method == "max_random": + max_score = torch.max(scores) + weights = torch.where(scores == max_score, 0.4, 0.1) + else: + raise ValueError(f"Unsupported reweight_method: {reweight_method}") + return weights + + scores = data.batch["token_level_scores"].sum(dim=-1) + weights = compute_weights(scores, reweight_method, weight_pow) + weights = torch.clamp(weights + 1e-8, min=1e-8) + + batch_size = scores.shape[0] + sample_indices = torch.multinomial(weights, batch_size, replacement=True) + + resampled_batch = {key: tensor[sample_indices] for key, tensor in data.batch.items()} + + sample_indices_np = sample_indices.numpy() + resampled_non_tensor_batch = {} + for key, array in data.non_tensor_batch.items(): + if isinstance(array, np.ndarray): + resampled_non_tensor_batch[key] = array[sample_indices_np] + else: + resampled_non_tensor_batch[key] = [array[i] for i in sample_indices_np] + + resampled_meta_info = {} + for key, value in data.meta_info.items(): + if isinstance(value, list) and len(value) == batch_size: + resampled_meta_info[key] = [value[i] for i in sample_indices_np] + else: + resampled_meta_info[key] = value + + from copy import deepcopy + + resampled_data = deepcopy(data) + resampled_data.batch = type(data.batch)(resampled_batch) + resampled_data.batch.batch_size = data.batch.batch_size + resampled_data.non_tensor_batch = resampled_non_tensor_batch + resampled_data.meta_info = resampled_meta_info + + return resampled_data + + +def compute_policy_loss_reinforce( + rollout_log_prob: torch.Tensor, + log_prob: torch.Tensor, + advantages: torch.Tensor, + response_mask: torch.Tensor, + loss_agg_mode: str = "seq-mean-token-sum", + config: Optional[ActorConfig] = None, + rollout_is_weights: Optional[torch.Tensor] = None, +) -> tuple[torch.Tensor, dict[str, Any]]: + """Compute REINFORCE-style policy gradient loss with optional IS correction. + + This function implements policy gradient (REINFORCE) with optional importance + sampling correction for rollout-training policy mismatch. + + Mathematical formulation: + Without IS (rollout_is_weights=None): + L = -E[log π(a|s) * A(s,a)] + Gradient: ∇_θ L = -E[∇log π(a|s) * A] (standard REINFORCE) + + With IS (rollout_is_weights provided): + L = -E_π_rollout[w * log π(a|s) * A(s,a)] + where w = π_current / π_rollout (truncated IS weight) + Gradient: ∇_θ L = -E[w * ∇log π(a|s) * A] (IS-corrected policy gradient) + + Args: + rollout_log_prob: Log probabilities from rollout policy (e.g., vLLM BF16). + Shape: (batch_size, seq_length). Used for KL computation. + log_prob: Log probabilities from current training policy. + Shape: (batch_size, seq_length) + advantages: Advantage estimates for each token. + Shape: (batch_size, seq_length) + response_mask: Mask indicating valid tokens (1 for valid, 0 for padding). + Shape: (batch_size, seq_length). Should already include rejection sampling. + loss_agg_mode: Loss aggregation strategy (see agg_loss for details). + config: Actor config (required for global_batch_info). + rollout_is_weights: Pre-computed IS weights (π_current / π_rollout). + Shape: (batch_size, seq_length). None to disable IS correction. + + Returns: + Tuple of (loss, metrics): + loss: Scalar policy gradient loss + metrics: Dictionary with "actor/ppo_kl" + + Note: + Unlike PPO (compute_policy_loss_vanilla), this function: + - Does NOT use PPO clipping + - Uses log π(a|s) directly (not ratio) + - IS weights are applied as multiplicative factor + """ + assert config is not None, "ActorConfig must be provided for REINFORCE loss" + + # Compute pure policy gradient loss with optional IS correction + # Standard REINFORCE: L = -E[log π(a|s) * A] + # With IS: L = -E[w * log π(a|s) * A] where w = π_current / π_rollout + if rollout_is_weights is not None: + # IS-corrected policy gradient: L = -E[stopgrad(w) · log π · A] + pg_losses = -advantages * log_prob * rollout_is_weights + else: + # Standard REINFORCE: L = -E[log π · A] + pg_losses = -advantages * log_prob + + # Aggregate loss + pg_loss = agg_loss( + loss_mat=pg_losses, + loss_mask=response_mask, + loss_agg_mode=loss_agg_mode, + **config.global_batch_info, + ) + + # Compute KL divergence between current and rollout policy + negative_approx_kl = log_prob - rollout_log_prob + kl_divergence = verl_F.masked_mean(-negative_approx_kl, response_mask) + + pg_metrics = { + "actor/ppo_kl": kl_divergence.detach().item(), + } + + return pg_loss, pg_metrics + + +@register_policy_loss("bypass_mode") +def compute_policy_loss_bypass_mode( + old_log_prob: torch.Tensor, + log_prob: torch.Tensor, + advantages: torch.Tensor, + response_mask: torch.Tensor, + loss_agg_mode: str = "token-mean", + config: Optional[ActorConfig] = None, + rollout_is_weights: torch.Tensor | None = None, +) -> tuple[torch.Tensor, dict[str, Any]]: + """Bypass mode policy loss supporting both REINFORCE and PPO-clip. + + This function is the entry point for bypass mode, where old_log_prob = rollout_log_prob. + It computes IS weights and rejection masks, then dispatches to either REINFORCE or + PPO-clip loss based on the loss_type configuration. + + IMPORTANT - Bypass mode semantics: + In bypass mode, the trainer sets old_log_prob = rollout_log_prob. + This means: + - For REINFORCE: We use IS weights w = π_current / π_rollout explicitly + - For PPO-clip: The PPO ratio π_current / π_old = π_current / π_rollout + already incorporates the IS correction through clipping, so we do NOT + apply additional IS weights (would be double-counting) + + Loss types: + - "ppo_clip" (default): PPO clipped objective (compute_policy_loss_vanilla) + L = -E[min(r*A, clip(r)*A)] where r = π_current / π_rollout + Note: IS weights are NOT applied (clipping handles the ratio) + - "reinforce": REINFORCE-style policy gradient with IS correction + L = -E[w * log π(a|s) * A] where w = π_current / π_rollout + + Args: + old_log_prob: In bypass mode, this is actually rollout_log_prob. + Shape: (batch_size, seq_length) + log_prob: Current policy log probabilities. + Shape: (batch_size, seq_length) + advantages: Advantage estimates. + Shape: (batch_size, seq_length) + response_mask: Valid token mask (1=valid, 0=padding). + Shape: (batch_size, seq_length) + loss_agg_mode: Loss aggregation mode (passed to underlying loss function). + config: Actor config containing rollout_correction settings in policy_loss. + rollout_is_weights: Pre-computed IS weights (ignored, computed internally). + + Config options (in config.policy_loss.rollout_correction): + loss_type: "ppo_clip" (default) or "reinforce" + rollout_is: IS aggregation level ("token", "sequence", or None) + rollout_is_threshold: Upper threshold for truncating IS weights (default: 2.0) + rollout_rs: Rejection sampling level (see rollout_corr_helper for supported modes) + rollout_rs_threshold: Threshold specification for rejection sampling + rollout_is_batch_normalize: Whether to normalize IS weights to mean=1.0 + + Returns: + Tuple of (loss, metrics): + loss: Scalar policy loss + metrics: Dictionary with rollout correction metrics and actor/ppo_kl + """ + from verl.trainer.ppo.rollout_corr_helper import compute_rollout_correction_and_rejection_mask + + assert config is not None, "config is required for bypass_mode loss" + + # Extract rollout_correction config from policy_loss + rollout_corr_config = config.policy_loss.get("rollout_correction", None) if hasattr(config, "policy_loss") else None + + if rollout_corr_config is None: + raise ValueError( + "rollout_correction config not found in policy_loss. " + "When using loss_mode='bypass_mode', ensure rollout_correction config is passed." + ) + + # Extract parameters + loss_type = rollout_corr_config.get("loss_type", "ppo_clip") + rollout_is = rollout_corr_config.get("rollout_is", None) + rollout_is_threshold = rollout_corr_config.get("rollout_is_threshold", 2.0) + rollout_is_batch_normalize = rollout_corr_config.get("rollout_is_batch_normalize", False) + rollout_rs = rollout_corr_config.get("rollout_rs", None) + rollout_rs_threshold = rollout_corr_config.get("rollout_rs_threshold", None) + + # In bypass mode: old_log_prob IS rollout_log_prob + rollout_log_prob = old_log_prob + + # Compute IS weights and rejection mask + # Note: For PPO-clip, we still compute IS weights for metrics, but don't apply them + with torch.no_grad(): + rollout_is_weights_proto, modified_response_mask, rollout_metrics = ( + compute_rollout_correction_and_rejection_mask( + old_log_prob=log_prob, # Current policy (for IS ratio: π_current / π_rollout) + rollout_log_prob=rollout_log_prob, # Rollout policy + response_mask=response_mask, + rollout_is=rollout_is, + rollout_is_threshold=rollout_is_threshold, + rollout_is_batch_normalize=rollout_is_batch_normalize, + rollout_rs=rollout_rs, + rollout_rs_threshold=rollout_rs_threshold, + ) + ) + + # Extract IS weights tensor (or None if disabled) + computed_is_weights = rollout_is_weights_proto.batch["rollout_is_weights"] if rollout_is_weights_proto else None + + # Apply rejection mask (RS + veto) + effective_mask = modified_response_mask + + # Dispatch to appropriate loss function based on loss_type + if loss_type == "reinforce": + # REINFORCE: Apply IS weights explicitly + pg_loss, pg_metrics = compute_policy_loss_reinforce( + rollout_log_prob=rollout_log_prob, + log_prob=log_prob, + advantages=advantages, + response_mask=effective_mask, + loss_agg_mode=loss_agg_mode, + config=config, + rollout_is_weights=computed_is_weights, + ) + + elif loss_type == "ppo_clip": + # PPO-clip: The ratio π_current/π_old = π_current/π_rollout already handles IS + # DO NOT apply IS weights - would be double-counting! + # The clipping mechanism constrains the effective IS ratio + pg_loss, pg_metrics = compute_policy_loss_vanilla( # type: ignore[call-arg] + old_log_prob=rollout_log_prob, # = old_log_prob in bypass mode + log_prob=log_prob, + advantages=advantages, + response_mask=effective_mask, + loss_agg_mode=loss_agg_mode, + config=config, + rollout_is_weights=None, # Explicitly None - no IS weights for PPO-clip + ) + + else: + raise ValueError(f"Invalid loss_type: {loss_type}. Must be 'reinforce' or 'ppo_clip'.") + + # Merge rollout correction metrics + pg_metrics.update(rollout_metrics) + + return pg_loss, pg_metrics diff --git a/code/RL_model/verl/verl_train/verl/trainer/ppo/metric_utils.py b/code/RL_model/verl/verl_train/verl/trainer/ppo/metric_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..4dd7d2d00a5990533566eed5aad5ee56a38a50ca --- /dev/null +++ b/code/RL_model/verl/verl_train/verl/trainer/ppo/metric_utils.py @@ -0,0 +1,659 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Metrics related to the PPO trainer. +""" + +from collections import defaultdict +from functools import partial +from typing import Any, Callable + +import numpy as np +import torch + +import verl.utils.torch_functional as verl_F +from verl import DataProto +from verl.utils.import_utils import deprecated + + +@deprecated("verl.utils.metric.reduce_metrics") +def reduce_metrics(metrics: dict[str, list[Any]]) -> dict[str, Any]: + """ + Reduces a dictionary of metric lists by computing the mean of each list. + + Args: + metrics: A dictionary mapping metric names to lists of metric values. + + Returns: + A dictionary with the same keys but with each list replaced by its mean value. + + Example: + >>> metrics = {"loss": [1.0, 2.0, 3.0], "accuracy": [0.8, 0.9, 0.7]} + >>> reduce_metrics(metrics) + {"loss": 2.0, "accuracy": 0.8} + """ + from verl.utils.metric import reduce_metrics + + return reduce_metrics(metrics) + + +def _compute_response_info(batch: DataProto) -> dict[str, Any]: + """ + Computes information about prompts and responses from a batch. + + This is an internal helper function that extracts masks and lengths for prompts and responses. + + Args: + batch: A DataProto object containing batch data with responses and attention masks. + + Returns: + A dictionary containing: + - response_mask: Attention mask for the response tokens + - prompt_length: Tensor of prompt lengths for each item in the batch + - response_length: Tensor of response lengths for each item in the batch + """ + response_length = batch.batch["responses"].shape[-1] + + prompt_mask = batch.batch["attention_mask"][:, :-response_length] + response_mask = batch.batch["attention_mask"][:, -response_length:] + + prompt_length = prompt_mask.sum(-1).float() + response_length = response_mask.sum(-1).float() # (batch_size,) + + return dict( + response_mask=response_mask, + prompt_length=prompt_length, + response_length=response_length, + ) + + +def compute_data_metrics(batch: DataProto, use_critic: bool = True) -> dict[str, Any]: + """ + Computes various metrics from a batch of data for PPO training. + + This function calculates metrics related to scores, rewards, advantages, returns, values, + and sequence lengths from a batch of data. It provides statistical information (mean, max, min) + for each metric category. + + Args: + batch: A DataProto object containing batch data with token-level scores, rewards, advantages, etc. + use_critic: Whether to include critic-specific metrics. Defaults to True. + + Returns: + A dictionary of metrics including: + - critic/score/mean, max, min: Statistics about sequence scores + - critic/rewards/mean, max, min: Statistics about sequence rewards + - critic/advantages/mean, max, min: Statistics about advantages + - critic/returns/mean, max, min: Statistics about returns + - critic/values/mean, max, min: Statistics about critic values (if use_critic=True) + - critic/vf_explained_var: Explained variance of the value function (if use_critic=True) + - response_length/mean, max, min, clip_ratio: Statistics about response lengths + - prompt_length/mean, max, min, clip_ratio: Statistics about prompt lengths + - num_turns/mean, max, min: Statistics about the number of multi-turn conversations + """ + sequence_score = batch.batch["token_level_scores"].sum(-1) + sequence_reward = batch.batch["token_level_rewards"].sum(-1) + + advantages = batch.batch["advantages"] + returns = batch.batch["returns"] + + max_response_length = batch.batch["responses"].shape[-1] + + prompt_mask = batch.batch["attention_mask"][:, :-max_response_length].bool() + response_mask = batch.batch["response_mask"].bool() + + max_prompt_length = prompt_mask.size(-1) + + response_info = _compute_response_info(batch) + prompt_length = response_info["prompt_length"] + response_length = response_info["response_length"] + + aborted_mask = (response_length == 0).bool() + non_aborted_mask = ~aborted_mask + + non_aborted_sequence_score = sequence_score[non_aborted_mask] + non_aborted_sequence_reward = sequence_reward[non_aborted_mask] + + score_mean = torch.mean(non_aborted_sequence_score).detach().item() + score_max = torch.max(non_aborted_sequence_score).detach().item() + score_min = torch.min(non_aborted_sequence_score).detach().item() + + reward_mean = torch.mean(non_aborted_sequence_reward).detach().item() + reward_max = torch.max(non_aborted_sequence_reward).detach().item() + reward_min = torch.min(non_aborted_sequence_reward).detach().item() + + valid_adv = torch.masked_select(advantages, response_mask) + valid_returns = torch.masked_select(returns, response_mask) + + if use_critic: + values = batch.batch["values"] + valid_values = torch.masked_select(values, response_mask) + return_diff_var = torch.var(valid_returns - valid_values) + return_var = torch.var(valid_returns) + + # Aborted samples and non-aborted response length statistics + # response_length_non_aborted/*: statistics computed on non-aborted samples only + aborted_ratio = torch.mean(aborted_mask.float()).detach().item() + + non_aborted_response_length = response_length[non_aborted_mask] + if non_aborted_response_length.numel() > 0: + non_aborted_response_length_mean = torch.mean(non_aborted_response_length).detach().item() + non_aborted_response_length_max = torch.max(non_aborted_response_length).detach().item() + non_aborted_response_length_min = torch.min(non_aborted_response_length).detach().item() + non_aborted_response_length_clip_ratio = ( + torch.mean(torch.eq(non_aborted_response_length, max_response_length).float()).detach().item() + ) + else: + raise ValueError("All samples are aborted, this should not happen.") + + metrics = { + # score + "critic/score/mean": score_mean, + "critic/score/max": score_max, + "critic/score/min": score_min, + # reward + "critic/rewards/mean": reward_mean, + "critic/rewards/max": reward_max, + "critic/rewards/min": reward_min, + # adv + "critic/advantages/mean": torch.mean(valid_adv).detach().item(), + "critic/advantages/max": torch.max(valid_adv).detach().item(), + "critic/advantages/min": torch.min(valid_adv).detach().item(), + # returns + "critic/returns/mean": torch.mean(valid_returns).detach().item(), + "critic/returns/max": torch.max(valid_returns).detach().item(), + "critic/returns/min": torch.min(valid_returns).detach().item(), + **( + { + # values + "critic/values/mean": torch.mean(valid_values).detach().item(), + "critic/values/max": torch.max(valid_values).detach().item(), + "critic/values/min": torch.min(valid_values).detach().item(), + # vf explained var + "critic/vf_explained_var": (1.0 - return_diff_var / (return_var + 1e-5)).detach().item(), + } + if use_critic + else {} + ), + # response length + "response_length/mean": torch.mean(response_length).detach().item(), + "response_length/max": torch.max(response_length).detach().item(), + "response_length/min": torch.min(response_length).detach().item(), + "response_length/clip_ratio": torch.mean(torch.eq(response_length, max_response_length).float()) + .detach() + .item(), + # response length (non-aborted only) + # These statistics exclude aborted samples to avoid skew from zeros + "response_length_non_aborted/mean": non_aborted_response_length_mean, + "response_length_non_aborted/max": non_aborted_response_length_max, + "response_length_non_aborted/min": non_aborted_response_length_min, + "response_length_non_aborted/clip_ratio": non_aborted_response_length_clip_ratio, + # aborted ratio + # Fraction of samples whose response length is zero + "response/aborted_ratio": aborted_ratio, + # prompt length + "prompt_length/mean": torch.mean(prompt_length).detach().item(), + "prompt_length/max": torch.max(prompt_length).detach().item(), + "prompt_length/min": torch.min(prompt_length).detach().item(), + "prompt_length/clip_ratio": torch.mean(torch.eq(prompt_length, max_prompt_length).float()).detach().item(), + } + + # multi-turn conversation + if "__num_turns__" in batch.non_tensor_batch: + num_turns = batch.non_tensor_batch["__num_turns__"] + metrics["num_turns/min"] = num_turns.min() + metrics["num_turns/max"] = num_turns.max() + metrics["num_turns/mean"] = num_turns.mean() + + if "tool_call_counts" in batch.non_tensor_batch: + tool_call_counts = batch.non_tensor_batch["tool_call_counts"] + metrics["tool_call_counts/min"] = tool_call_counts.min() + metrics["tool_call_counts/max"] = tool_call_counts.max() + metrics["tool_call_counts/mean"] = tool_call_counts.mean() + + return metrics + + +def compute_timing_metrics(batch: DataProto, timing_raw: dict[str, float]) -> dict[str, Any]: + """ + Computes timing metrics for different processing stages in PPO training. + + This function calculates both raw timing metrics (in seconds) and per-token timing metrics + (in milliseconds) for various processing stages like generation, reference computation, + value computation, advantage computation, and model updates. + + Args: + batch: A DataProto object containing batch data with responses and attention masks. + timing_raw: A dictionary mapping stage names to their execution times in seconds. + + Returns: + A dictionary containing: + - timing_s/{name}: Raw timing in seconds for each stage + - timing_per_token_ms/{name}: Per-token timing in milliseconds for each stage + + Note: + Different stages use different token counts for normalization: + - "gen" uses only response tokens + - Other stages ("ref", "values", "adv", "update_critic", "update_actor") use all tokens + (prompt + response) + """ + response_info = _compute_response_info(batch) + num_prompt_tokens = torch.sum(response_info["prompt_length"]).item() + num_response_tokens = torch.sum(response_info["response_length"]).item() + num_overall_tokens = num_prompt_tokens + num_response_tokens + + num_tokens_of_section = { + "gen": num_response_tokens, + **{name: num_overall_tokens for name in ["ref", "values", "adv", "update_critic", "update_actor"]}, + } + + return { + **{f"timing_s/{name}": value for name, value in timing_raw.items()}, + **{ + f"timing_per_token_ms/{name}": timing_raw[name] * 1000 / num_tokens_of_section[name] + for name in set(num_tokens_of_section.keys()) & set(timing_raw.keys()) + }, + } + + +def compute_throughout_metrics(batch: DataProto, timing_raw: dict[str, float], n_gpus: int) -> dict[str, Any]: + """ + Computes throughput metrics for PPO training. + + This function calculates performance metrics related to token processing speed, + including the total number of tokens processed, time per step, and throughput + (tokens per second per GPU). + + Args: + batch: A DataProto object containing batch data with meta information about token counts. + timing_raw: A dictionary mapping stage names to their execution times in seconds. + Must contain a "step" key with the total step time. + n_gpus: Number of GPUs used for training. + + Returns: + A dictionary containing: + - perf/total_num_tokens: Total number of tokens processed in the batch + - perf/time_per_step: Time taken for the step in seconds + - perf/throughput: Tokens processed per second per GPU + + Note: + The throughput is calculated as total_tokens / (time * n_gpus) to normalize + across different GPU counts. + """ + total_num_tokens = sum(batch.meta_info["global_token_num"]) + time = timing_raw["step"] + # estimated_flops, promised_flops = flops_function.estimate_flops(num_tokens, time) + # f'Actual TFLOPs/s/GPU​': estimated_flops/(n_gpus), + # f'Theoretical TFLOPs/s/GPU​': promised_flops, + return { + "perf/total_num_tokens": total_num_tokens, + "perf/time_per_step": time, + "perf/throughput": total_num_tokens / (time * n_gpus), + } + + +def compute_variance_proxy_metrics(batch: DataProto, gradient_norm: float = None) -> dict[str, float]: + """ + Compute variance proxy metrics using the simplified expected squared norm approach. + + This metric provides a computationally efficient way to monitor gradient variance + during training. It works for any advantage estimator as long as sum_pi_squared + is available from the actor. + + Theory: + - Full variance: Var(g̃) = E[||g̃||²] - ||g_true||² + - Simplified proxy (when ||g_true||² ≈ 0): Var(g̃) ≈ E[||g̃||²] + - Using W-score approximation: E[||g̃||²] ≈ E[A² × W(τ)] + + Where W(τ) = Σ_t[1 - 2π_t(y_t) + Σπ²] is the score-norm proxy. + """ + metrics = {} + + # Check if we have the necessary data (sum_pi_squared is required for W-score) + if "sum_pi_squared" not in batch.batch or "old_log_probs" not in batch.batch or "advantages" not in batch.batch: + return metrics + + # Compute W(τ) = Σ_t[1 - 2π_t(y_t) + Σπ²] + pi_t = torch.exp(batch.batch["old_log_probs"]) + w_per_timestep = 1 - 2 * pi_t + batch.batch["sum_pi_squared"] + + # Get response mask to only consider valid tokens + response_mask = batch.batch["response_mask"] + + # Use pre-computed rollout IS weights from batch (for variance proxy consistency with training loss) + # IS weights are computed centrally in ray_trainer.py to avoid duplication + rollout_is_weights = None + if "rollout_is_weights" in batch.batch: + # Extract pre-computed IS weights from batch (already computed in trainer) + rollout_is_weights = batch.batch["rollout_is_weights"] + + # Scale W by (rollout IS weight)² for optimal baseline under biased estimation + w_per_timestep = w_per_timestep * (rollout_is_weights**2).detach() + + # Note: IS weight statistics and mismatch metrics are logged in ray_trainer.py + + # Get scalar advantages (mean over timesteps) + advantages = batch.batch["advantages"] + # Compute mean advantage per trajectory using masked_mean + advantages_scalar = verl_F.masked_mean(advantages, response_mask, axis=-1) + + # Compute W values (sum over timesteps) + w_values = verl_F.masked_sum(w_per_timestep, response_mask, axis=-1) + + # ====== COMPUTE VARIANCE PROXIES ====== + # Variance proxy should match the actual gradient computation: + # - If IS weights were computed/applied: use them in variance proxy calculation + # - Otherwise: compute on-policy variance proxy + + # ====== PROXY 1: Signal Strength ||ḡ||² ====== + # The squared norm of the mean gradient (provided from training loop) + proxy1_signal_strength = gradient_norm**2 if gradient_norm is not None else None + + # ====== PROXY 2: Total Power E[||ĝ_τ||²] ====== + # Measures the average of squared gradient norms (Signal + Noise) + if rollout_is_weights is not None: + # Off-policy with IS correction applied: use clamped weights consistently with actual gradient computation + rollout_is_weights_scalar = verl_F.masked_mean(rollout_is_weights, response_mask, axis=-1) + # Recover original W (before IS correction was applied in line 657) + # Clamp to avoid division by zero when IS weights are zero + w_original = verl_F.masked_sum( + w_per_timestep / torch.clamp((rollout_is_weights**2).detach(), min=1e-10), response_mask, axis=-1 + ) + # Clamp W to avoid negative values (which would cause NaN in sqrt) + w_original = torch.clamp(w_original, min=0.0) + # Proxy 2 for off-policy: E[ρ̄² × A² × W] + proxy2_total_power = ((rollout_is_weights_scalar**2) * (advantages_scalar**2) * w_original).mean() + + else: + # On-policy Proxy 2: E[A² × W] + # Clamp W to avoid negative values (which would cause NaN in sqrt) + w_values_clamped = torch.clamp(w_values, min=0.0) + proxy2_total_power = (advantages_scalar**2 * w_values_clamped).mean() + + # ====== PROXY 3: Pure Noise - Variance of Mean Vector ====== + # Requires ||ḡ||² from actual batch gradient + # Formula: (1/(N-1)) × (Proxy2 - Proxy1) + proxy3_pure_noise = None + if proxy1_signal_strength is not None: + batch_size = advantages_scalar.shape[0] + if batch_size > 1: + proxy3_pure_noise = (1.0 / (batch_size - 1)) * (proxy2_total_power - proxy1_signal_strength) + # Ensure non-negative (can be negative due to numerical errors) + proxy3_pure_noise = max( + 0.0, proxy3_pure_noise.item() if torch.is_tensor(proxy3_pure_noise) else proxy3_pure_noise + ) + + # Decompose into components for analysis + expected_a_squared = (advantages_scalar**2).mean() + expected_w = w_values.mean() + + metrics.update( + { + # Proxy 1: Signal Strength ||ḡ||² + "variance_proxy/proxy1_signal_strength": ( + proxy1_signal_strength if proxy1_signal_strength is not None else 0.0 + ), + # Proxy 2: Total Power E[||ĝ_τ||²] + "variance_proxy/proxy2_total_power": proxy2_total_power.detach().item(), + # Proxy 3: Pure Noise - Variance of Mean Vector + "variance_proxy/proxy3_pure_noise": proxy3_pure_noise if proxy3_pure_noise is not None else 0.0, + # Component metrics for debugging + "variance_proxy/expected_a_squared": expected_a_squared.detach().item(), + "variance_proxy/expected_w": expected_w.detach().item(), + } + ) + + return metrics + + +def bootstrap_metric( + data: list[Any], + subset_size: int, + reduce_fns: list[Callable[[np.ndarray], float]], + n_bootstrap: int = 1000, + seed: int = 42, +) -> list[tuple[float, float]]: + """ + Performs bootstrap resampling to estimate statistics of metrics. + + This function uses bootstrap resampling to estimate the mean and standard deviation + of metrics computed by the provided reduction functions on random subsets of the data. + + Args: + data: List of data points to bootstrap from. + subset_size: Size of each bootstrap sample. + reduce_fns: List of functions that compute a metric from a subset of data. + n_bootstrap: Number of bootstrap iterations. Defaults to 1000. + seed: Random seed for reproducibility. Defaults to 42. + + Returns: + A list of tuples, where each tuple contains (mean, std) for a metric + corresponding to each reduction function in reduce_fns. + + Example: + >>> data = [1, 2, 3, 4, 5] + >>> reduce_fns = [np.mean, np.max] + >>> bootstrap_metric(data, 3, reduce_fns) + [(3.0, 0.5), (4.5, 0.3)] # Example values + """ + np.random.seed(seed) + data_np = np.array(data, dtype=object) + n_data = len(data_np) + + # generate bootstrap indices, shape: (n_bootstrap, subset_size) + bootstrap_idxs = np.random.choice(n_data, size=(n_bootstrap, subset_size), replace=True) + + # pre-allocate result array, shape: (n_fns, n_bootstrap) + n_fns = len(reduce_fns) + metric_results = np.empty((n_fns, n_bootstrap), dtype=np.float64) + + # compute metric results for each bootstrap sample + for fn_idx, reduce_fn in enumerate(reduce_fns): + # bootstrap sample and compute metric + for boot_idx in range(n_bootstrap): + sample = data_np[bootstrap_idxs[boot_idx]] + metric_results[fn_idx, boot_idx] = reduce_fn(sample) + + # compute mean and std for each metric function + result = [ + (float(np.mean(metric_results[fn_idx])), float(np.std(metric_results[fn_idx]))) for fn_idx in range(n_fns) + ] + return result + + +def calc_maj_val(data: list[dict[str, Any]], vote_key: str, val_key: str) -> float: + """ + Calculate a value based on majority voting. + + This function identifies the most common value for a specified vote key + in the data, then returns the corresponding value for that majority vote. + + Args: + data: List of dictionaries, where each dictionary contains both vote_key and val_key. + vote_key: The key in each dictionary used for voting/counting. + val_key: The key in each dictionary whose value will be returned for the majority vote. + + Returns: + The value associated with the most common vote. + + Example: + >>> data = [ + ... {"pred": "A", "val": 0.9}, + ... {"pred": "B", "val": 0.8}, + ... {"pred": "A", "val": 0.7} + ... ] + >>> calc_maj_val(data, vote_key="pred", val_key="val") + 0.9 # Returns the first "val" for the majority vote "A" + """ + vote2vals = defaultdict(list) + for d in data: + vote2vals[d[vote_key]].append(d[val_key]) + + vote2cnt = {k: len(v) for k, v in vote2vals.items()} + maj_vote = max(vote2cnt, key=vote2cnt.get) + + maj_val = vote2vals[maj_vote][0] + + return maj_val + + +def process_validation_metrics( + data_sources: list[str], sample_uids: list[str], infos_dict: dict[str, list[Any]], seed: int = 42 +) -> dict[str, dict[str, dict[str, float]]]: + """ + Process validation metrics into a structured format with statistical analysis. + + This function organizes validation metrics by data source and prompt, then computes + various statistical measures including means, standard deviations, best/worst values, + and majority voting results. It also performs bootstrap sampling to estimate statistics + for different sample sizes. + + Args: + data_sources: List of data source identifiers for each sample. + sample_uids: List of sample uids corresponding to each sample. + infos_dict: Dictionary mapping variable names to lists of values for each sample. + seed: Random seed for bootstrap sampling. Defaults to 42. + + Returns: + A nested dictionary with the structure: + { + data_source: { + variable_name: { + metric_name: value + } + } + } + + Where metric_name includes: + - "mean@N": Mean value across N samples + - "std@N": Standard deviation across N samples + - "best@N/mean": Mean of the best values in bootstrap samples of size N + - "best@N/std": Standard deviation of the best values in bootstrap samples + - "worst@N/mean": Mean of the worst values in bootstrap samples + - "worst@N/std": Standard deviation of the worst values in bootstrap samples + - "maj@N/mean": Mean of majority voting results in bootstrap samples (if "pred" exists) + - "maj@N/std": Standard deviation of majority voting results (if "pred" exists) + + Example: + >>> data_sources = ["source1", "source1", "source2"] + >>> sample_uids = ["uid1", "uid1", "uid2"] + >>> infos_dict = {"score": [0.8, 0.9, 0.7], "pred": ["A", "A", "B"]} + >>> result = process_validation_metrics(data_sources, sample_uids, infos_dict) + >>> # result will contain statistics for each data source and variable + """ + # Group metrics by data source, prompt and variable + data_src2uid2var2vals = defaultdict(lambda: defaultdict(lambda: defaultdict(list))) + for sample_idx, data_source in enumerate(data_sources): + uid = sample_uids[sample_idx] + var2vals = data_src2uid2var2vals[data_source][uid] + for var_name, var_vals in infos_dict.items(): + var2vals[var_name].append(var_vals[sample_idx]) + + np_mean = np.mean + np_std = np.std + reduce_fns_best_worst = [np.max, np.min] + n_bootstrap = 1000 + + # 2. cache ns list + def gen_ns(n_resps: int) -> list[int]: + if n_resps <= 1: + return [] + ns = [] + n = 2 + while n < n_resps: + ns.append(n) + n *= 2 + ns.append(n_resps) + return ns + + ns_cache = {} + + # 3. cache metric results + data_src2uid2var2metric = {} + + # 4. flatten loop + for data_source, uid2var2vals in data_src2uid2var2vals.items(): + # create uid dict + uid_dict = data_src2uid2var2metric.setdefault(data_source, {}) + + for uid, var2vals in uid2var2vals.items(): + pred_vals = var2vals.get("pred") + has_pred = pred_vals is not None + var_dict = uid_dict.setdefault(uid, {}) + + for var_name, var_vals in var2vals.items(): + # skip empty or string values + if not var_vals or isinstance(var_vals[0], str): + continue + + # compute mean and std + n_resps = len(var_vals) + metric = {f"mean@{n_resps}": float(np_mean(var_vals))} + + if n_resps > 1: + metric[f"std@{n_resps}"] = float(np_std(var_vals)) + + # cache ns list + if n_resps not in ns_cache: + ns_cache[n_resps] = gen_ns(n_resps) + ns = ns_cache[n_resps] + + # compute best/worst metrics + for n in ns: + # compute best/worst metrics + (bon_mean, bon_std), (won_mean, won_std) = bootstrap_metric( + data=var_vals, + subset_size=n, + reduce_fns=reduce_fns_best_worst, + n_bootstrap=n_bootstrap, + seed=seed, + ) + metric[f"best@{n}/mean"] = bon_mean + metric[f"best@{n}/std"] = bon_std + metric[f"worst@{n}/mean"] = won_mean + metric[f"worst@{n}/std"] = won_std + + # compute maj metrics + if has_pred: + # create vote_data + vote_data = [ + {"val": val, "pred": pred} for val, pred in zip(var_vals, pred_vals, strict=True) + ] + # compute maj metrics + [(maj_n_mean, maj_n_std)] = bootstrap_metric( + data=vote_data, + subset_size=n, + reduce_fns=[partial(calc_maj_val, vote_key="pred", val_key="val")], + n_bootstrap=n_bootstrap, + seed=seed, + ) + metric[f"maj@{n}/mean"] = maj_n_mean + metric[f"maj@{n}/std"] = maj_n_std + + var_dict[var_name] = metric + + # Aggregate metrics across uids + data_src2var2metric2uid_vals = defaultdict(lambda: defaultdict(lambda: defaultdict(list))) + for data_source, uid2var2metric in data_src2uid2var2metric.items(): + for uid, var2metric in uid2var2metric.items(): + for var_name, metric in var2metric.items(): + for metric_name, metric_val in metric.items(): + data_src2var2metric2uid_vals[data_source][var_name][metric_name].append(metric_val) + + data_src2var2metric2val = defaultdict(lambda: defaultdict(lambda: defaultdict(float))) + for data_source, var2metric2uid_vals in data_src2var2metric2uid_vals.items(): + for var_name, metric2uid_vals in var2metric2uid_vals.items(): + for metric_name, uid_vals in metric2uid_vals.items(): + data_src2var2metric2val[data_source][var_name][metric_name] = np.mean(uid_vals) + return data_src2var2metric2val diff --git a/code/RL_model/verl/verl_train/verl/trainer/ppo/prefix_grouper_utils.py b/code/RL_model/verl/verl_train/verl/trainer/ppo/prefix_grouper_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..97b5f36237e53fef23e119d4042de2d0f83810b8 --- /dev/null +++ b/code/RL_model/verl/verl_train/verl/trainer/ppo/prefix_grouper_utils.py @@ -0,0 +1,235 @@ +# Copyright 2025 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import torch +from prefix_grouper import PrefixGrouper + +from verl.utils.torch_functional import logprobs_from_logits + + +def build_position_ids_for_prefix_grouper(prefix_grouper: PrefixGrouper) -> torch.Tensor: + """Build position_ids for PrefixGrouper where each response restarts from prefix_len.""" + num_samples = len(prefix_grouper.group_info) + max_len = prefix_grouper.padding_mask.size(1) + device = prefix_grouper.padding_mask.device + + position_ids = torch.zeros(num_samples, max_len, dtype=torch.long, device=device) + + for i, group in enumerate(prefix_grouper.group_info): + prefix_len = group.prefix_len + + position_ids[i, :prefix_len] = torch.arange(prefix_len, device=device) + cur_pos = prefix_len + for suffix_len in group.suffix_lens: + if suffix_len > 0: + position_ids[i, cur_pos : cur_pos + suffix_len] = torch.arange( + prefix_len, prefix_len + suffix_len, device=device + ) + cur_pos += suffix_len + + return position_ids + + +def build_pg_from_micro_batch( + micro_batch: dict, + pad_token_id: int, + padding_mode: str = "right", +): + """Build PrefixGrouper from micro_batch dict containing prompts, responses, response_mask, uid.""" + prompts = micro_batch["prompts"] + responses = micro_batch["responses"] + response_mask = micro_batch["response_mask"] + uids = micro_batch["uid"] + + bs = responses.size(0) + + group_sizes = [] + cur = 1 + for i in range(1, bs): + if uids[i] == uids[i - 1]: + cur += 1 + else: + group_sizes.append(cur) + cur = 1 + group_sizes.append(cur) + + prefix_indices = [] + cursor = 0 + for gs in group_sizes: + prefix_indices.append(cursor) + cursor += gs + prefix_indices = torch.tensor(prefix_indices, device=prompts.device) + + prefix_ids = prompts.index_select(0, prefix_indices) + prefix_mask = prefix_ids.ne(pad_token_id) + + prefix_grouper = PrefixGrouper.from_ungrouped_masks( + prefix_mask=prefix_mask, + suffix_mask=response_mask, + group_sizes=group_sizes, + padding_mode=padding_mode, + device=prompts.device, + ) + + concat_input_ids = prefix_grouper.concat_input(prefix_ids, prefix_mask, responses, response_mask) + + attention_mask = prefix_grouper.padding_mask + + position_ids = build_position_ids_for_prefix_grouper(prefix_grouper) + + return ( + prefix_grouper, + concat_input_ids, + attention_mask, + position_ids, + responses, + response_mask, + ) + + +def pg_forward( + model, + prefix_grouper, + concat_input_ids, + attention_mask, + position_ids, + completion_ids, + completion_mask, + *, + temperature=1.0, + padding_mode="right", + include_prefix_last=1, + calculate_entropy=False, + entropy_fn=None, +): + logits = model( + input_ids=concat_input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + use_cache=False, + prefix_grouper=prefix_grouper, + ).logits + + prefix_out, prefix_mask, suffix_out_raw, suffix_mask_raw = prefix_grouper.split_output( + logits, include_prefix_last=include_prefix_last + ) + + completion_ids_right = prefix_grouper.convert_padding( + completion_ids, + completion_mask, + padding_mode=padding_mode, + ) + + suffix_out = suffix_out_raw[:, :-1].float() + suffix_mask = suffix_mask_raw[:, 1:] + + suffix_out /= temperature + + log_probs = logprobs_from_logits(suffix_out, completion_ids_right) + + entropy = None + if calculate_entropy and entropy_fn is not None: + entropy = entropy_fn(suffix_out) + + return log_probs, entropy, suffix_mask + + +def forward_micro_batch_with_prefix_grouper( + micro_batch: dict, + model, + temperature: float, + calculate_entropy: bool, + device_name: str, + param_dtype, + use_chunking_entropy: bool = False, +): + """ + Forward pass using PrefixGrouper for shared-prefix optimization. + + Args: + micro_batch: Dict containing prompts, responses, response_mask, uid, etc. + model: The actor module. + temperature: Temperature for logits scaling. + calculate_entropy: Whether to compute entropy. + device_name: Device name for autocast. + param_dtype: Parameter dtype for autocast. + use_chunking_entropy: Whether to use chunking entropy function. + + Returns: + tuple: (entropy, log_probs) where entropy may be None if not calculated. + """ + import verl.utils.torch_functional as verl_F + + entropy_fn = None + if calculate_entropy: + if use_chunking_entropy: + entropy_fn = verl_F.entropy_from_logits_with_chunking + else: + entropy_fn = verl_F.entropy_from_logits + + pad_token_id = micro_batch.get("pad_token_id", 0) + + ( + prefix_grouper, + concat_input_ids, + attention_mask, + position_ids, + responses, + response_mask, + ) = build_pg_from_micro_batch( + micro_batch, + pad_token_id=pad_token_id, + padding_mode="right", + ) + + with torch.autocast(device_type=device_name, dtype=param_dtype): + log_probs, entropy, suffix_mask_from_pg = pg_forward( + model=model, + prefix_grouper=prefix_grouper, + concat_input_ids=concat_input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + completion_ids=responses, + completion_mask=response_mask, + temperature=temperature, + padding_mode="right", + include_prefix_last=1, + calculate_entropy=calculate_entropy, + entropy_fn=entropy_fn, + ) + + # Zero out padding positions + padding_mask = suffix_mask_from_pg == 0 + log_probs = log_probs.masked_fill(padding_mask, 0.0) + if entropy is not None: + entropy = entropy.masked_fill(padding_mask, 0.0) + + # Pad to target response length if needed + target_response_length = responses.size(1) + if log_probs.size(1) != target_response_length: + batch_size = log_probs.size(0) + current_len = log_probs.size(1) + + full_log_probs = log_probs.new_zeros(batch_size, target_response_length) + full_log_probs[:, :current_len] = log_probs + log_probs = full_log_probs + + if entropy is not None: + full_entropy = entropy.new_zeros(batch_size, target_response_length) + full_entropy[:, :current_len] = entropy + entropy = full_entropy + + return entropy, log_probs diff --git a/code/RL_model/verl/verl_train/verl/trainer/ppo/ray_trainer.py b/code/RL_model/verl/verl_train/verl/trainer/ppo/ray_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..92157ac509f895a73f8a22a84ffaa2832086a757 --- /dev/null +++ b/code/RL_model/verl/verl_train/verl/trainer/ppo/ray_trainer.py @@ -0,0 +1,1749 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# Copyright 2023-2024 SGLang Team +# Copyright 2025 ModelBest Inc. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +PPO Trainer with Ray-based single controller. +This trainer supports model-agonistic model initialization with huggingface +""" + +import json +import os +import shutil +import uuid +from collections import defaultdict +from copy import deepcopy +from pprint import pprint +from typing import Any, Optional + +import numpy as np +import ray +import torch +from omegaconf import OmegaConf, open_dict +from torch.utils.data import Dataset, Sampler +from torchdata.stateful_dataloader import StatefulDataLoader +from tqdm import tqdm + +from verl import DataProto +from verl.checkpoint_engine import CheckpointEngineManager +from verl.experimental.dataset.sampler import AbstractCurriculumSampler +from verl.protocol import pad_dataproto_to_divisor, unpad_dataproto +from verl.single_controller.ray import RayClassWithInitArgs, RayWorkerGroup, ResourcePoolManager +from verl.single_controller.ray.base import create_colocated_worker_cls +from verl.trainer.config import AlgoConfig +from verl.trainer.ppo import core_algos +from verl.trainer.ppo.core_algos import AdvantageEstimator, agg_loss +from verl.trainer.ppo.metric_utils import ( + compute_data_metrics, + compute_throughout_metrics, + compute_timing_metrics, + compute_variance_proxy_metrics, + process_validation_metrics, +) +from verl.trainer.ppo.reward import compute_reward, compute_reward_async +from verl.trainer.ppo.utils import Role, WorkerType, need_critic, need_reference_policy, need_reward_model +from verl.utils import tensordict_utils as tu +from verl.utils.checkpoint.checkpoint_manager import find_latest_ckpt_path, should_save_ckpt_esi +from verl.utils.config import omega_conf_to_dataclass +from verl.utils.debug import marked_timer +from verl.utils.import_utils import load_class_from_fqn +from verl.utils.metric import reduce_metrics +from verl.utils.py_functional import rename_dict +from verl.utils.rollout_skip import RolloutSkip +from verl.utils.seqlen_balancing import calculate_workload, get_seqlen_balanced_partitions, log_seqlen_unbalance +from verl.utils.torch_functional import masked_mean +from verl.utils.tracking import ValidationGenerationsLogger +from verl.workers.config import FSDPEngineConfig +from verl.workers.utils.padding import left_right_2_no_padding, no_padding_2_padding + + +def apply_kl_penalty(data: DataProto, kl_ctrl: core_algos.AdaptiveKLController, kl_penalty="kl"): + """Apply KL penalty to the token-level rewards. + + This function computes the KL divergence between the reference policy and current policy, + then applies a penalty to the token-level rewards based on this divergence. + + Args: + data (DataProto): The data containing batched model outputs and inputs. + kl_ctrl (core_algos.AdaptiveKLController): Controller for adaptive KL penalty. + kl_penalty (str, optional): Type of KL penalty to apply. Defaults to "kl". + + Returns: + tuple: A tuple containing: + - The updated data with token-level rewards adjusted by KL penalty + - A dictionary of metrics related to the KL penalty + """ + response_mask = data.batch["response_mask"] + token_level_scores = data.batch["token_level_scores"] + batch_size = data.batch.batch_size[0] + + # compute kl between ref_policy and current policy + # When apply_kl_penalty, algorithm.use_kl_in_reward=True, so the reference model has been enabled. + kld = core_algos.kl_penalty( + data.batch["old_log_probs"], data.batch["ref_log_prob"], kl_penalty=kl_penalty + ) # (batch_size, response_length) + kld = kld * response_mask + beta = kl_ctrl.value + + token_level_rewards = token_level_scores - beta * kld + + current_kl = masked_mean(kld, mask=response_mask, axis=-1) # average over sequence + current_kl = torch.mean(current_kl, dim=0).item() + + # according to https://github.com/huggingface/trl/blob/951ca1841f29114b969b57b26c7d3e80a39f75a0/trl/trainer/ppo_trainer.py#L837 + kl_ctrl.update(current_kl=current_kl, n_steps=batch_size) + data.batch["token_level_rewards"] = token_level_rewards + + metrics = {"actor/reward_kl_penalty": current_kl, "actor/reward_kl_penalty_coeff": beta} + + return data, metrics + + +def compute_response_mask(data: DataProto): + """Compute the attention mask for the response part of the sequence. + + This function extracts the portion of the attention mask that corresponds to the model's response, + which is used for masking computations that should only apply to response tokens. + + Args: + data (DataProto): The data containing batched model outputs and inputs. + + Returns: + torch.Tensor: The attention mask for the response tokens. + """ + responses = data.batch["responses"] + response_length = responses.size(1) + attention_mask = data.batch["attention_mask"] + return attention_mask[:, -response_length:] + + +def compute_advantage( + data: DataProto, + adv_estimator: AdvantageEstimator, + gamma: float = 1.0, + lam: float = 1.0, + num_repeat: int = 1, + norm_adv_by_std_in_grpo: bool = True, + config: Optional[AlgoConfig] = None, +) -> DataProto: + """Compute advantage estimates for policy optimization. + + This function computes advantage estimates using various estimators like GAE, GRPO, REINFORCE++, etc. + The advantage estimates are used to guide policy optimization in RL algorithms. + + Args: + data (DataProto): The data containing batched model outputs and inputs. + adv_estimator (AdvantageEstimator): The advantage estimator to use (e.g., GAE, GRPO, REINFORCE++). + gamma (float, optional): Discount factor for future rewards. Defaults to 1.0. + lam (float, optional): Lambda parameter for GAE. Defaults to 1.0. + num_repeat (int, optional): Number of times to repeat the computation. Defaults to 1. + norm_adv_by_std_in_grpo (bool, optional): Whether to normalize advantages by standard deviation in + GRPO. Defaults to True. + config (dict, optional): Configuration dictionary for algorithm settings. Defaults to None. + + Returns: + DataProto: The updated data with computed advantages and returns. + """ + # Back-compatible with trainers that do not compute response mask in fit + if "response_mask" not in data.batch.keys(): + data.batch["response_mask"] = compute_response_mask(data) + # prepare response group + if adv_estimator == AdvantageEstimator.GAE: + # Compute advantages and returns using Generalized Advantage Estimation (GAE) + advantages, returns = core_algos.compute_gae_advantage_return( + token_level_rewards=data.batch["token_level_rewards"], + values=data.batch["values"], + response_mask=data.batch["response_mask"], + gamma=gamma, + lam=lam, + ) + data.batch["advantages"] = advantages + data.batch["returns"] = returns + if config.get("use_pf_ppo", False): + data = core_algos.compute_pf_ppo_reweight_data( + data, + config.pf_ppo.get("reweight_method"), + config.pf_ppo.get("weight_pow"), + ) + elif adv_estimator == AdvantageEstimator.GRPO: + # Initialize the mask for GRPO calculation + grpo_calculation_mask = data.batch["response_mask"] + + # Call compute_grpo_outcome_advantage with parameters matching its definition + advantages, returns = core_algos.compute_grpo_outcome_advantage( + token_level_rewards=data.batch["token_level_rewards"], + response_mask=grpo_calculation_mask, + index=data.non_tensor_batch["uid"], + norm_adv_by_std_in_grpo=norm_adv_by_std_in_grpo, + ) + data.batch["advantages"] = advantages + data.batch["returns"] = returns + else: + # handle all other adv estimator type other than GAE and GRPO + adv_estimator_fn = core_algos.get_adv_estimator_fn(adv_estimator) + adv_kwargs = { + "token_level_rewards": data.batch["token_level_rewards"], + "response_mask": data.batch["response_mask"], + "config": config, + } + if "uid" in data.non_tensor_batch: # optional + adv_kwargs["index"] = data.non_tensor_batch["uid"] + if "reward_baselines" in data.batch: # optional + adv_kwargs["reward_baselines"] = data.batch["reward_baselines"] + # Add sum_pi_squared for Optimal Token Baseline + if adv_estimator in (AdvantageEstimator.OPTIMAL_TOKEN_BASELINE, AdvantageEstimator.TIR_OPTIMAL_TOKEN_BASELINE): + # Check if sum_pi_squared is available + assert "sum_pi_squared" in data.batch, ( + "Step-dependent optimal baseline requires sum_pi_squared from actor. " + "Please set actor.calculate_sum_pi_squared=True in config." + ) + adv_kwargs["sum_pi_squared"] = data.batch["sum_pi_squared"] + # Get pre-computed rollout IS weights if available + rollout_is_weights = data.batch.get("rollout_is_weights", None) + adv_kwargs["rollout_is_weights"] = rollout_is_weights + + # calculate advantage estimator + advantages, returns = adv_estimator_fn(**adv_kwargs) + data.batch["advantages"] = advantages + data.batch["returns"] = returns + return data + + +class RayPPOTrainer: + """Distributed PPO trainer using Ray for scalable reinforcement learning. + + This trainer orchestrates distributed PPO training across multiple nodes and GPUs, + managing actor rollouts, critic training, and reward computation with Ray backend. + Supports various model architectures including FSDP, Megatron, vLLM, and SGLang integration. + """ + + # TODO: support each role have individual ray_worker_group_cls, + # i.e., support different backend of different role + def __init__( + self, + config, + tokenizer, + role_worker_mapping: dict[Role, WorkerType], + resource_pool_manager: ResourcePoolManager, + ray_worker_group_cls: type[RayWorkerGroup] = RayWorkerGroup, + processor=None, + reward_fn=None, + val_reward_fn=None, + train_dataset: Optional[Dataset] = None, + val_dataset: Optional[Dataset] = None, + collate_fn=None, + train_sampler: Optional[Sampler] = None, + device_name=None, + ): + """ + Initialize distributed PPO trainer with Ray backend. + Note that this trainer runs on the driver process on a single CPU/GPU node. + + Args: + config: Configuration object containing training parameters. + tokenizer: Tokenizer used for encoding and decoding text. + role_worker_mapping (dict[Role, WorkerType]): Mapping from roles to worker classes. + resource_pool_manager (ResourcePoolManager): Manager for Ray resource pools. + ray_worker_group_cls (RayWorkerGroup, optional): Class for Ray worker groups. Defaults to RayWorkerGroup. + processor: Optional data processor, used for multimodal data + reward_fn: Function for computing rewards during training. + val_reward_fn: Function for computing rewards during validation. + train_dataset (Optional[Dataset], optional): Training dataset. Defaults to None. + val_dataset (Optional[Dataset], optional): Validation dataset. Defaults to None. + collate_fn: Function to collate data samples into batches. + train_sampler (Optional[Sampler], optional): Sampler for the training dataset. Defaults to None. + device_name (str, optional): Device name for training (e.g., "cuda", "cpu"). Defaults to None. + """ + + # Store the tokenizer for text processing + self.tokenizer = tokenizer + self.processor = processor + self.config = config + self.reward_fn = reward_fn + self.val_reward_fn = val_reward_fn + + self.hybrid_engine = config.actor_rollout_ref.hybrid_engine + assert self.hybrid_engine, "Currently, only support hybrid engine" + + if self.hybrid_engine: + assert Role.ActorRollout in role_worker_mapping or Role.ActorRolloutRef in role_worker_mapping, ( + f"{role_worker_mapping.keys()=}" + ) + + self.role_worker_mapping = role_worker_mapping + self.resource_pool_manager = resource_pool_manager + self.use_reference_policy = need_reference_policy(self.config) + # legacy reward model implementation + self.use_rm = need_reward_model(self.role_worker_mapping) + self.use_reward_loop = self.config.reward_model.use_reward_loop + + self.use_critic = need_critic(self.config) + self.ray_worker_group_cls = ray_worker_group_cls + self.device_name = device_name if device_name else self.config.trainer.device + self.validation_generations_logger = ValidationGenerationsLogger( + project_name=self.config.trainer.project_name, + experiment_name=self.config.trainer.experiment_name, + ) + + # if ref_in_actor is True, the reference policy will be actor without lora applied + lora_rank = config.actor_rollout_ref.model.get("lora", {}).get("rank", 0) + if lora_rank <= 0: + lora_rank = config.actor_rollout_ref.model.get("lora_rank", 0) + self.ref_in_actor = lora_rank > 0 or config.actor_rollout_ref.model.get("lora_adapter_path") is not None + + # define in-reward KL control + # kl loss control currently not suppoorted + if self.config.algorithm.use_kl_in_reward: + self.kl_ctrl_in_reward = core_algos.get_kl_controller(self.config.algorithm.kl_ctrl) + + self.use_prefix_grouper = self.config.actor_rollout_ref.actor.get("use_prefix_grouper", False) + self.use_legacy_worker_impl = config.trainer.get("use_legacy_worker_impl", "auto") + + self._create_dataloader(train_dataset, val_dataset, collate_fn, train_sampler) + + def _create_dataloader(self, train_dataset, val_dataset, collate_fn, train_sampler: Optional[Sampler]): + """ + Creates the train and validation dataloaders. + """ + # TODO: we have to make sure the batch size is divisible by the dp size + from verl.trainer.main_ppo import create_rl_dataset, create_rl_sampler + + if train_dataset is None: + train_dataset = create_rl_dataset( + self.config.data.train_files, + self.config.data, + self.tokenizer, + self.processor, + max_samples=self.config.data.get("train_max_samples", -1), + ) + if val_dataset is None: + val_dataset = create_rl_dataset( + self.config.data.val_files, + self.config.data, + self.tokenizer, + self.processor, + max_samples=self.config.data.get("val_max_samples", -1), + ) + self.train_dataset, self.val_dataset = train_dataset, val_dataset + + if train_sampler is None: + train_sampler = create_rl_sampler(self.config.data, self.train_dataset) + if collate_fn is None: + from verl.utils.dataset.rl_dataset import collate_fn as default_collate_fn + + collate_fn = default_collate_fn + + num_workers = self.config.data["dataloader_num_workers"] + + self.train_dataloader = StatefulDataLoader( + dataset=self.train_dataset, + batch_size=self.config.data.get("gen_batch_size", self.config.data.train_batch_size), + num_workers=num_workers, + drop_last=True, + collate_fn=collate_fn, + sampler=train_sampler, + ) + + val_batch_size = self.config.data.val_batch_size # Prefer config value if set + if val_batch_size is None: + val_batch_size = len(self.val_dataset) + + self.val_dataloader = StatefulDataLoader( + dataset=self.val_dataset, + batch_size=val_batch_size, + num_workers=num_workers, + shuffle=self.config.data.get("validation_shuffle", True), + drop_last=False, + collate_fn=collate_fn, + ) + + assert len(self.train_dataloader) >= 1, "Train dataloader is empty!" + assert len(self.val_dataloader) >= 1, "Validation dataloader is empty!" + + print( + f"Size of train dataloader: {len(self.train_dataloader)}, Size of val dataloader: " + f"{len(self.val_dataloader)}" + ) + + total_training_steps = len(self.train_dataloader) * self.config.trainer.total_epochs + + if self.config.trainer.total_training_steps is not None: + total_training_steps = self.config.trainer.total_training_steps + + self.total_training_steps = total_training_steps + print(f"Total training steps: {self.total_training_steps}") + + try: + OmegaConf.set_struct(self.config, True) + with open_dict(self.config): + if OmegaConf.select(self.config, "actor_rollout_ref.actor.optim"): + self.config.actor_rollout_ref.actor.optim.total_training_steps = total_training_steps + if OmegaConf.select(self.config, "critic.optim"): + self.config.critic.optim.total_training_steps = total_training_steps + except Exception as e: + print(f"Warning: Could not set total_training_steps in config. Structure missing? Error: {e}") + + def _dump_generations(self, inputs, outputs, gts, scores, reward_extra_infos_dict, dump_path): + """Dump rollout/validation samples as JSONL.""" + os.makedirs(dump_path, exist_ok=True) + filename = os.path.join(dump_path, f"{self.global_steps}.jsonl") + + n = len(inputs) + base_data = { + "input": inputs, + "output": outputs, + "gts": gts, + "score": scores, + "step": [self.global_steps] * n, + } + + for k, v in reward_extra_infos_dict.items(): + if len(v) == n: + base_data[k] = v + + lines = [] + for i in range(n): + entry = {k: v[i] for k, v in base_data.items()} + lines.append(json.dumps(entry, ensure_ascii=False)) + + with open(filename, "w") as f: + f.write("\n".join(lines) + "\n") + + print(f"Dumped generations to {filename}") + + def _log_rollout_data( + self, batch: DataProto, reward_extra_infos_dict: dict, timing_raw: dict, rollout_data_dir: str + ): + """Log rollout data to disk. + Args: + batch (DataProto): The batch containing rollout data + reward_extra_infos_dict (dict): Additional reward information to log + timing_raw (dict): Timing information for profiling + rollout_data_dir (str): Directory path to save the rollout data + """ + with marked_timer("dump_rollout_generations", timing_raw, color="green"): + inputs = self.tokenizer.batch_decode(batch.batch["prompts"], skip_special_tokens=True) + outputs = self.tokenizer.batch_decode(batch.batch["responses"], skip_special_tokens=True) + scores = batch.batch["token_level_scores"].sum(-1).cpu().tolist() + sample_gts = [item.non_tensor_batch.get("reward_model", {}).get("ground_truth", None) for item in batch] + + reward_extra_infos_to_dump = reward_extra_infos_dict.copy() + if "request_id" in batch.non_tensor_batch: + reward_extra_infos_dict.setdefault( + "request_id", + batch.non_tensor_batch["request_id"].tolist(), + ) + + self._dump_generations( + inputs=inputs, + outputs=outputs, + gts=sample_gts, + scores=scores, + reward_extra_infos_dict=reward_extra_infos_to_dump, + dump_path=rollout_data_dir, + ) + + def _maybe_log_val_generations(self, inputs, outputs, scores): + """Log a table of validation samples to the configured logger (wandb or swanlab)""" + + generations_to_log = self.config.trainer.log_val_generations + + if generations_to_log == 0: + return + + import numpy as np + + # Create tuples of (input, output, score) and sort by input text + samples = list(zip(inputs, outputs, scores, strict=True)) + samples.sort(key=lambda x: x[0]) # Sort by input text + + # Use fixed random seed for deterministic shuffling + rng = np.random.RandomState(42) + rng.shuffle(samples) + + # Take first N samples after shuffling + samples = samples[:generations_to_log] + + # Log to each configured logger + self.validation_generations_logger.log(self.config.trainer.logger, samples, self.global_steps) + + def _compute_or_extract_reward( + self, + batch: DataProto, + reward_fn=None, + reward_for_val: bool = False, + sum_reward: bool = False, + ) -> tuple[torch.Tensor, dict[str, Any]] | torch.Tensor: + """ + Compute or extract reward from batch. + + When use_reward_loop=True, rewards are already computed during generate_sequences + and stored in rm_scores. This method directly extracts them instead of calling + reward functions which would only perform format conversion. + + Args: + batch: DataProto containing the batch data + reward_fn: Reward function to use if rm_scores doesn't exist (for training/validation) + reward_for_val: Whether this is for validation + sum_reward: Whether to sum reward tensor along last dimension (for REMAX baseline) + + Returns: + If reward_for_val=False and sum_reward=True: summed reward_tensor (1D tensor) + Otherwise: tuple of (reward_tensor, reward_extra_infos_dict) + """ + # When rm_scores already exists, extract it directly (format conversion only) + if "rm_scores" in batch.batch.keys(): + reward_tensor = batch.batch["rm_scores"] + if sum_reward: + reward_tensor = reward_tensor.sum(dim=-1) + + if not reward_for_val and sum_reward: + return reward_tensor + + reward_extra_keys = batch.meta_info.get("reward_extra_keys", []) + reward_extra_infos_dict = ( + {key: batch.non_tensor_batch[key] for key in reward_extra_keys} if reward_extra_keys else {} + ) + return reward_tensor, reward_extra_infos_dict + + # Otherwise, compute reward using reward_fn + if reward_fn is None: + raise ValueError("reward_fn must be provided when rm_scores is not available.") + + if reward_for_val: + result = reward_fn(batch, return_dict=True) + reward_tensor = result["reward_tensor"] + if sum_reward: + reward_tensor = reward_tensor.sum(dim=-1) + reward_extra_infos_dict = result.get("reward_extra_info", {}) + return reward_tensor, reward_extra_infos_dict + else: + reward_tensor, reward_extra_infos_dict = compute_reward(batch, reward_fn) + if sum_reward: + reward_tensor = reward_tensor.sum(dim=-1) + return reward_tensor, reward_extra_infos_dict + + def _get_gen_batch(self, batch: DataProto) -> DataProto: + reward_model_keys = set({"data_source", "reward_model", "extra_info", "uid"}) & batch.non_tensor_batch.keys() + + # pop those keys for generation + batch_keys_to_pop = [] + non_tensor_batch_keys_to_pop = set(batch.non_tensor_batch.keys()) - reward_model_keys + gen_batch = batch.pop( + batch_keys=batch_keys_to_pop, + non_tensor_batch_keys=list(non_tensor_batch_keys_to_pop), + ) + + # For agent loop, we need reward model keys to compute score. + if self.async_rollout_mode: + gen_batch.non_tensor_batch.update(batch.non_tensor_batch) + + return gen_batch + + def _validate(self, merged: bool = False): + data_source_lst = [] + reward_extra_infos_dict: dict[str, list] = defaultdict(list) + + # Lists to collect samples for the table + sample_inputs = [] + sample_outputs = [] + sample_gts = [] + sample_scores = [] + sample_turns = [] + sample_uids = [] + + for test_data in self.val_dataloader: + test_batch = DataProto.from_single_dict(test_data) + + if "uid" not in test_batch.non_tensor_batch: + test_batch.non_tensor_batch["uid"] = np.array( + [str(uuid.uuid4()) for _ in range(len(test_batch.batch))], dtype=object + ) + + # repeat test batch + test_batch = test_batch.repeat( + repeat_times=self.config.actor_rollout_ref.rollout.val_kwargs.n, interleave=True + ) + + # we only do validation on rule-based rm + if self.config.reward_model.enable and test_batch[0].non_tensor_batch["reward_model"]["style"] == "model": + return {} + + ground_truths = [ + item.non_tensor_batch.get("reward_model", {}).get("ground_truth", None) for item in test_batch + ] + sample_gts.extend(ground_truths) + + test_gen_batch = self._get_gen_batch(test_batch) + test_gen_batch.meta_info = { + "eos_token_id": self.tokenizer.eos_token_id, + "pad_token_id": self.tokenizer.pad_token_id, + "recompute_log_prob": False, + "do_sample": self.config.actor_rollout_ref.rollout.val_kwargs.do_sample, + "validate": True, + "global_steps": self.global_steps, + } + print(f"test_gen_batch meta info: {test_gen_batch.meta_info}") + + # pad to be divisible by dp_size + size_divisor = ( + self.actor_rollout_wg.world_size + if not self.async_rollout_mode + else self.config.actor_rollout_ref.rollout.agent.num_workers + ) + test_gen_batch_padded, pad_size = pad_dataproto_to_divisor(test_gen_batch, size_divisor) + if not self.async_rollout_mode: + test_output_gen_batch_padded = self.actor_rollout_wg.generate_sequences(test_gen_batch_padded) + else: + test_output_gen_batch_padded = self.async_rollout_manager.generate_sequences(test_gen_batch_padded) + + # unpad + test_output_gen_batch = unpad_dataproto(test_output_gen_batch_padded, pad_size=pad_size) + + print("validation generation end") + + # Store generated outputs + output_ids = test_output_gen_batch.batch["responses"] + output_texts = [self.tokenizer.decode(ids, skip_special_tokens=True) for ids in output_ids] + sample_outputs.extend(output_texts) + + test_batch = test_batch.union(test_output_gen_batch) + test_batch.meta_info["validate"] = True + + # Store original inputs + input_ids = test_batch.batch["prompts"] + # TODO: Can we keep special tokens except for padding tokens? + input_texts = [self.tokenizer.decode(ids, skip_special_tokens=True) for ids in input_ids] + sample_inputs.extend(input_texts) + sample_uids.extend(test_batch.non_tensor_batch["uid"]) + + # evaluate using reward_function + reward_tensor, reward_extra_info = self._compute_or_extract_reward( + test_batch, reward_fn=self.val_reward_fn, reward_for_val=True + ) + scores = reward_tensor.sum(-1).cpu().tolist() + sample_scores.extend(scores) + + reward_extra_infos_dict["reward"].extend(scores) + for key, values in reward_extra_info.items(): + if key not in reward_extra_infos_dict: + reward_extra_infos_dict[key] = [] + if isinstance(values, np.ndarray): + reward_extra_infos_dict[key].extend(values.tolist()) + else: + reward_extra_infos_dict[key].extend(values if isinstance(values, list) else [values]) + + # collect num_turns of each prompt + if "__num_turns__" in test_batch.non_tensor_batch: + sample_turns.append(test_batch.non_tensor_batch["__num_turns__"]) + + data_source_lst.append(test_batch.non_tensor_batch.get("data_source", ["unknown"] * reward_tensor.shape[0])) + + self._maybe_log_val_generations(inputs=sample_inputs, outputs=sample_outputs, scores=sample_scores) + + # dump generations + val_data_dir = self.config.trainer.get("validation_data_dir", None) + if val_data_dir: + self._dump_generations( + inputs=sample_inputs, + outputs=sample_outputs, + gts=sample_gts, + scores=sample_scores, + reward_extra_infos_dict=reward_extra_infos_dict, + dump_path=val_data_dir, + ) + + for key_info, lst in reward_extra_infos_dict.items(): + assert len(lst) == 0 or len(lst) == len(sample_scores), f"{key_info}: {len(lst)=}, {len(sample_scores)=}" + + if merged: + print("_merge_validation_results validate result will be merged") + return { + "data_sources": data_source_lst, + "sample_uids": sample_uids, + "sample_turns": sample_turns, + "reward_extra_infos_dict": reward_extra_infos_dict, + } + data_sources = np.concatenate(data_source_lst, axis=0) + return self._val_metrics_update(data_sources, sample_uids, reward_extra_infos_dict, sample_turns) + + def _val_metrics_update(self, data_sources, sample_uids, reward_extra_infos_dict, sample_turns): + data_src2var2metric2val = process_validation_metrics(data_sources, sample_uids, reward_extra_infos_dict) + metric_dict = {} + for data_source, var2metric2val in data_src2var2metric2val.items(): + core_var = "acc" if "acc" in var2metric2val else "reward" + for var_name, metric2val in var2metric2val.items(): + n_max = max([int(name.split("@")[-1].split("/")[0]) for name in metric2val.keys()]) + for metric_name, metric_val in metric2val.items(): + if ( + (var_name == core_var) + and any(metric_name.startswith(pfx) for pfx in ["mean", "maj", "best"]) + and (f"@{n_max}" in metric_name) + ): + metric_sec = "val-core" + else: + metric_sec = "val-aux" + pfx = f"{metric_sec}/{data_source}/{var_name}/{metric_name}" + metric_dict[pfx] = metric_val + + if len(sample_turns) > 0: + sample_turns = np.concatenate(sample_turns) + metric_dict["val-aux/num_turns/min"] = sample_turns.min() + metric_dict["val-aux/num_turns/max"] = sample_turns.max() + metric_dict["val-aux/num_turns/mean"] = sample_turns.mean() + + return metric_dict + + def _merge_validation_results(self, result_a, result_b): + if result_a is None and result_b is None: + return {} + if result_a is None: + result_a = {"data_sources": [], "sample_uids": [], "sample_turns": [], "reward_extra_infos_dict": {}} + if result_b is None: + result_b = {"data_sources": [], "sample_uids": [], "sample_turns": [], "reward_extra_infos_dict": {}} + + if not result_a.get("data_sources") and not result_b.get("data_sources"): + return {} + + data_sources = np.concatenate(result_a["data_sources"] + result_b["data_sources"], axis=0) + sample_uids = result_a["sample_uids"] + result_b["sample_uids"] + sample_turns = result_a["sample_turns"] + result_b["sample_turns"] + + reward_extra_infos_dict = {} + all_keys = set(result_a["reward_extra_infos_dict"].keys()) | set(result_b["reward_extra_infos_dict"].keys()) + for key in all_keys: + list_a = result_a["reward_extra_infos_dict"].get(key, []) + list_b = result_b["reward_extra_infos_dict"].get(key, []) + reward_extra_infos_dict[key] = list_a + list_b + + return self._val_metrics_update(data_sources, sample_uids, reward_extra_infos_dict, sample_turns) + + def init_workers(self): + """Initialize distributed training workers using Ray backend. + + Creates: + 1. Ray resource pools from configuration + 2. Worker groups for each role (actor, critic, etc.) + """ + self.resource_pool_manager.create_resource_pool() + + self.resource_pool_to_cls = {pool: {} for pool in self.resource_pool_manager.resource_pool_dict.values()} + + # create actor and rollout + actor_role = Role.ActorRolloutRef if Role.ActorRolloutRef in self.role_worker_mapping else Role.ActorRollout + if self.hybrid_engine: + actor_rollout_resource_pool = self.resource_pool_manager.get_resource_pool(actor_role) + actor_rollout_cls = RayClassWithInitArgs( + cls=self.role_worker_mapping[actor_role], + config=self.config.actor_rollout_ref, + role=str(actor_role), + ) + self.resource_pool_to_cls[actor_rollout_resource_pool][str(actor_role)] = actor_rollout_cls + else: + raise NotImplementedError + + # create critic + if self.use_critic: + resource_pool = self.resource_pool_manager.get_resource_pool(Role.Critic) + + from verl.workers.config import CriticConfig + + critic_cfg: CriticConfig = omega_conf_to_dataclass(self.config.critic) + + if self.use_legacy_worker_impl == "disable": + # convert critic_cfg into TrainingWorkerConfig + from verl.workers.engine_workers import TrainingWorkerConfig + + orig_critic_cfg = critic_cfg + if orig_critic_cfg.strategy == "fsdp": + engine_config: FSDPEngineConfig = orig_critic_cfg.model.fsdp_config + engine_config.infer_max_token_len_per_gpu = critic_cfg.ppo_infer_max_token_len_per_gpu + engine_config.max_token_len_per_gpu = critic_cfg.ppo_max_token_len_per_gpu + else: + raise NotImplementedError(f"Unknown strategy {orig_critic_cfg.strategy=}") + + critic_cfg = TrainingWorkerConfig( + model_type="value_model", + model_config=orig_critic_cfg.model_config, + engine_config=engine_config, + optimizer_config=orig_critic_cfg.optim, + checkpoint_config=orig_critic_cfg.checkpoint, + ) + + critic_cls = RayClassWithInitArgs(cls=self.role_worker_mapping[Role.Critic], config=critic_cfg) + self.resource_pool_to_cls[resource_pool][str(Role.Critic)] = critic_cls + + # create reference policy if needed + if self.use_reference_policy and Role.RefPolicy in self.role_worker_mapping: + resource_pool = self.resource_pool_manager.get_resource_pool(Role.RefPolicy) + ref_policy_cls = RayClassWithInitArgs( + self.role_worker_mapping[Role.RefPolicy], + config=self.config.actor_rollout_ref, + role=str(Role.RefPolicy), + ) + self.resource_pool_to_cls[resource_pool][str(Role.RefPolicy)] = ref_policy_cls + + # create a reward model if reward_fn is None + # for legacy discriminative reward model, we create a reward model worker here + # for reward loop discriminative reward model, we create a reward loop manager here + if not self.use_reward_loop: + # legacy reward model only handle reward-model based scenario + if self.use_rm: + # we create a RM here + resource_pool = self.resource_pool_manager.get_resource_pool(Role.RewardModel) + rm_cls = RayClassWithInitArgs( + self.role_worker_mapping[Role.RewardModel], config=self.config.reward_model + ) + self.resource_pool_to_cls[resource_pool][str(Role.RewardModel)] = rm_cls + else: + # reward loop handle hybrid reward scenario (rule, disrm, genrm, ...) + # Note: mode is always "async" since sync mode is deprecated + can_reward_loop_parallelize = not self.use_rm or self.config.reward_model.enable_resource_pool + # judge if we can asynchronously parallelize reward model with actor rollout + # two condition that we can parallelize reward model with actor rollout: + # 1. reward model is not enabled (rule-based reward can parallelize) + # 2. reward model is enabled but extra resource pool is enabled + # If we cannot parallelize, we should enable synchronous mode here, and launch a reward loop manager here + # else for parallelize mode, we launch a reward worker for each rollout worker (in agent loop, not here) + if not can_reward_loop_parallelize: + from verl.experimental.reward_loop import RewardLoopManager + + self.config.reward_model.n_gpus_per_node = self.config.trainer.n_gpus_per_node + resource_pool = self.resource_pool_manager.get_resource_pool(Role.RewardModel) + self.reward_loop_manager = RewardLoopManager( + config=self.config, + rm_resource_pool=resource_pool, + ) + + # initialize WorkerGroup + # NOTE: if you want to use a different resource pool for each role, which can support different parallel size, + # you should not use `create_colocated_worker_cls`. + # Instead, directly pass different resource pool to different worker groups. + # See https://github.com/volcengine/verl/blob/master/examples/ray/tutorial.ipynb for more information. + all_wg = {} + wg_kwargs = {} # Setting up kwargs for RayWorkerGroup + if OmegaConf.select(self.config.trainer, "ray_wait_register_center_timeout") is not None: + wg_kwargs["ray_wait_register_center_timeout"] = self.config.trainer.ray_wait_register_center_timeout + if OmegaConf.select(self.config.global_profiler, "steps") is not None: + wg_kwargs["profile_steps"] = OmegaConf.select(self.config.global_profiler, "steps") + # Only require nsight worker options when tool is nsys + if OmegaConf.select(self.config.global_profiler, "tool") == "nsys": + assert ( + OmegaConf.select(self.config.global_profiler.global_tool_config.nsys, "worker_nsight_options") + is not None + ), "worker_nsight_options must be set when using nsys with profile_steps" + wg_kwargs["worker_nsight_options"] = OmegaConf.to_container( + OmegaConf.select(self.config.global_profiler.global_tool_config.nsys, "worker_nsight_options") + ) + wg_kwargs["device_name"] = self.device_name + + for resource_pool, class_dict in self.resource_pool_to_cls.items(): + worker_dict_cls = create_colocated_worker_cls(class_dict=class_dict) + wg_dict = self.ray_worker_group_cls( + resource_pool=resource_pool, + ray_cls_with_init=worker_dict_cls, + **wg_kwargs, + ) + spawn_wg = wg_dict.spawn(prefix_set=class_dict.keys()) + all_wg.update(spawn_wg) + + if self.use_critic: + self.critic_wg = all_wg[str(Role.Critic)] + if self.use_legacy_worker_impl == "disable": + self.critic_wg.reset() + # assign critic loss + from functools import partial + + from verl.workers.utils.losses import value_loss + + value_loss_ = partial(value_loss, config=orig_critic_cfg) + self.critic_wg.set_loss_fn(value_loss_) + else: + self.critic_wg.init_model() + + if self.use_reference_policy and not self.ref_in_actor: + if str(Role.RefPolicy) in all_wg: + self.ref_policy_wg = all_wg[str(Role.RefPolicy)] + self.ref_policy_wg.init_model() + else: + # Model engine: ActorRolloutRefWorker + assert str(Role.ActorRolloutRef) in all_wg, f"{all_wg.keys()=}" + self.ref_policy_wg = all_wg[str(Role.ActorRolloutRef)] + + self.rm_wg = None + # initalization of rm_wg will be deprecated in the future + if self.use_rm and not self.use_reward_loop: + self.rm_wg = all_wg[str(Role.RewardModel)] + self.rm_wg.init_model() + + # we should create rollout at the end so that vllm can have a better estimation of kv cache memory + self.actor_rollout_wg = all_wg[str(actor_role)] + self.actor_rollout_wg.init_model() + + if self.ref_in_actor: + self.ref_policy_wg = self.actor_rollout_wg + + # create async rollout manager and request scheduler + # Note: mode is always "async" since sync mode is deprecated + self.async_rollout_mode = True + + # Support custom AgentLoopManager via config + manager_class_fqn = self.config.actor_rollout_ref.rollout.get("agent", {}).get("agent_loop_manager_class") + if manager_class_fqn: + AgentLoopManager = load_class_from_fqn(manager_class_fqn, "AgentLoopManager") + else: + from verl.experimental.agent_loop import AgentLoopManager + + if self.config.reward_model.enable and self.config.reward_model.enable_resource_pool: + rm_resource_pool = self.resource_pool_manager.get_resource_pool(Role.RewardModel) + else: + rm_resource_pool = None + + self.async_rollout_manager = AgentLoopManager( + config=self.config, + worker_group=self.actor_rollout_wg, + rollout_resource_pool=actor_rollout_resource_pool, + rm_resource_pool=rm_resource_pool, + ) + + self.checkpoint_manager = CheckpointEngineManager( + backend=self.config.actor_rollout_ref.rollout.checkpoint_engine.backend, + trainer=self.actor_rollout_wg, + replicas=self.async_rollout_manager.rollout_replicas, + ) + + # sleep all replicas to load checkpoint + self.checkpoint_manager.sleep_replicas() + + def _save_checkpoint(self): + from verl.utils.fs import local_mkdir_safe + + # path: given_path + `/global_step_{global_steps}` + `/actor` + local_global_step_folder = os.path.join( + self.config.trainer.default_local_dir, f"global_step_{self.global_steps}" + ) + + print(f"local_global_step_folder: {local_global_step_folder}") + actor_local_path = os.path.join(local_global_step_folder, "actor") + + actor_remote_path = ( + None + if self.config.trainer.default_hdfs_dir is None + else os.path.join(self.config.trainer.default_hdfs_dir, f"global_step_{self.global_steps}", "actor") + ) + + remove_previous_ckpt_in_save = self.config.trainer.get("remove_previous_ckpt_in_save", False) + if remove_previous_ckpt_in_save: + print( + "Warning: remove_previous_ckpt_in_save is deprecated," + + " set max_actor_ckpt_to_keep=1 and max_critic_ckpt_to_keep=1 instead" + ) + max_actor_ckpt_to_keep = ( + self.config.trainer.get("max_actor_ckpt_to_keep", None) if not remove_previous_ckpt_in_save else 1 + ) + max_critic_ckpt_to_keep = ( + self.config.trainer.get("max_critic_ckpt_to_keep", None) if not remove_previous_ckpt_in_save else 1 + ) + + self.actor_rollout_wg.save_checkpoint( + actor_local_path, actor_remote_path, self.global_steps, max_ckpt_to_keep=max_actor_ckpt_to_keep + ) + + if self.use_critic: + critic_local_path = os.path.join(local_global_step_folder, str(Role.Critic)) + critic_remote_path = ( + None + if self.config.trainer.default_hdfs_dir is None + else os.path.join( + self.config.trainer.default_hdfs_dir, f"global_step_{self.global_steps}", str(Role.Critic) + ) + ) + self.critic_wg.save_checkpoint( + critic_local_path, critic_remote_path, self.global_steps, max_ckpt_to_keep=max_critic_ckpt_to_keep + ) + + # save dataloader + local_mkdir_safe(local_global_step_folder) + dataloader_local_path = os.path.join(local_global_step_folder, "data.pt") + dataloader_state_dict = self.train_dataloader.state_dict() + torch.save(dataloader_state_dict, dataloader_local_path) + + if remove_previous_ckpt_in_save: + self._remove_old_global_step_dirs(self.global_steps) + + # latest checkpointed iteration tracker (for atomic usage) + if ( + hasattr(self.config.actor_rollout_ref.actor.checkpoint, "async_save") + and self.config.actor_rollout_ref.actor.checkpoint.async_save + ) or ( + "async_save" in self.config.actor_rollout_ref.actor.checkpoint + and self.config.actor_rollout_ref.actor.checkpoint["async_save"] + ): + print("skip write latest_checkpointed_iteration.txt when async_save is True") + return + local_latest_checkpointed_iteration = os.path.join( + self.config.trainer.default_local_dir, "latest_checkpointed_iteration.txt" + ) + with open(local_latest_checkpointed_iteration, "w") as f: + f.write(str(self.global_steps)) + + def _remove_old_global_step_dirs(self, current_step: int) -> None: + checkpoint_root = self.config.trainer.default_local_dir + if not checkpoint_root: + return + if not os.path.isabs(checkpoint_root): + checkpoint_root = os.path.join(os.getcwd(), checkpoint_root) + if not os.path.isdir(checkpoint_root): + return + for name in os.listdir(checkpoint_root): + if not name.startswith("global_step_"): + continue + step_str = name.split("global_step_")[-1] + if not step_str.isdigit(): + continue + step = int(step_str) + if step == current_step: + continue + path = os.path.join(checkpoint_root, name) + try: + shutil.rmtree(path) + print(f"Removed old checkpoint directory: {path}") + except Exception as exc: + print(f"Warning: failed to remove old checkpoint directory {path}: {exc}") + + def _load_checkpoint(self): + if self.config.trainer.resume_mode == "disable": + return 0 + + # load from hdfs + if self.config.trainer.default_hdfs_dir is not None: + raise NotImplementedError("load from hdfs is not implemented yet") + else: + checkpoint_folder = self.config.trainer.default_local_dir # TODO: check path + if not os.path.isabs(checkpoint_folder): + working_dir = os.getcwd() + checkpoint_folder = os.path.join(working_dir, checkpoint_folder) + global_step_folder = find_latest_ckpt_path(checkpoint_folder) # None if no latest + + # find global_step_folder + if self.config.trainer.resume_mode == "auto": + if global_step_folder is None: + print("Training from scratch") + return 0 + else: + if self.config.trainer.resume_mode == "resume_path": + assert isinstance(self.config.trainer.resume_from_path, str), "resume ckpt must be str type" + assert "global_step_" in self.config.trainer.resume_from_path, ( + "resume ckpt must specify the global_steps" + ) + global_step_folder = self.config.trainer.resume_from_path + if not os.path.isabs(global_step_folder): + working_dir = os.getcwd() + global_step_folder = os.path.join(working_dir, global_step_folder) + print(f"Load from checkpoint folder: {global_step_folder}") + # set global step + self.global_steps = int(global_step_folder.split("global_step_")[-1]) + + print(f"Setting global step to {self.global_steps}") + print(f"Resuming from {global_step_folder}") + + actor_path = os.path.join(global_step_folder, "actor") + critic_path = os.path.join(global_step_folder, str(Role.Critic)) + # load actor + self.actor_rollout_wg.load_checkpoint( + actor_path, del_local_after_load=self.config.trainer.del_local_ckpt_after_load + ) + # load critic + if self.use_critic: + self.critic_wg.load_checkpoint( + critic_path, del_local_after_load=self.config.trainer.del_local_ckpt_after_load + ) + + # load dataloader, + # TODO: from remote not implemented yet + dataloader_local_path = os.path.join(global_step_folder, "data.pt") + if os.path.exists(dataloader_local_path): + dataloader_state_dict = torch.load(dataloader_local_path, weights_only=False) + self.train_dataloader.load_state_dict(dataloader_state_dict) + else: + print(f"Warning: No dataloader state found at {dataloader_local_path}, will start from scratch") + + def _start_profiling(self, do_profile: bool) -> None: + """Start profiling for all worker groups if profiling is enabled.""" + if do_profile: + self.actor_rollout_wg.start_profile(role="e2e", profile_step=self.global_steps) + if self.use_reference_policy: + self.ref_policy_wg.start_profile(profile_step=self.global_steps) + if self.use_critic: + self.critic_wg.start_profile(profile_step=self.global_steps) + if self.use_rm and not self.use_reward_loop: + self.rm_wg.start_profile(profile_step=self.global_steps) + + def _stop_profiling(self, do_profile: bool) -> None: + """Stop profiling for all worker groups if profiling is enabled.""" + if do_profile: + self.actor_rollout_wg.stop_profile() + if self.use_reference_policy: + self.ref_policy_wg.stop_profile() + if self.use_critic: + self.critic_wg.stop_profile() + if self.use_rm and not self.use_reward_loop: + self.rm_wg.stop_profile() + + def _get_dp_size(self, worker_group, role: str) -> int: + """Get data parallel size from worker group dispatch info. + + This method retrieves the data parallel size by querying the dispatch info + for the specified role. The dispatch info is cached for subsequent calls. + + Args: + worker_group: The worker group to query dispatch info from. + role: The role name (e.g., "actor", "critic") to get DP size for. + + Returns: + The data parallel size (number of DP ranks). + """ + if role not in worker_group._dispatch_info: + dp_rank_mapping = worker_group._query_dispatch_info(role) + worker_group._dispatch_info[role] = dp_rank_mapping + else: + dp_rank_mapping = worker_group._dispatch_info[role] + return max(dp_rank_mapping) + 1 + + def _balance_batch(self, batch: DataProto, metrics, logging_prefix="global_seqlen", keep_minibatch=False): + """Reorder the data on single controller such that each dp rank gets similar total tokens. + + When use_prefix_grouper is enabled, uses group-level balancing to keep samples with + the same uid together on the same rank for prefix sharing optimization. + """ + attention_mask = batch.batch["attention_mask"] + batch_size = attention_mask.shape[0] + global_seqlen_lst = batch.batch["attention_mask"].view(batch_size, -1).sum(-1) # (train_batch_size,) + workload_lst = calculate_workload(global_seqlen_lst) + # Get dp_size from dispatch info to correctly balance across data parallel ranks + # Note: world_size may include tensor/pipeline parallel dimensions, but we only want DP + dp_size = self._get_dp_size(self.actor_rollout_wg, "actor") + + # Use group-level balancing for PrefixGrouper to keep same-uid samples together + if getattr(self, "use_prefix_grouper", False) and "uid" in batch.non_tensor_batch: + from verl.utils.seqlen_balancing import get_group_balanced_partitions + + uid_list = list(batch.non_tensor_batch["uid"]) + seqlen_list = global_seqlen_lst.tolist() + + # Count number of uid groups + num_groups = len(set(uid_list)) + + if num_groups % dp_size != 0: + raise ValueError( + f"PrefixGrouper with balance_batch requires num_uid_groups ({num_groups}) " + f"% dp_size ({dp_size}) == 0. " + f"This ensures each rank gets equal number of groups. " + f"Current batch_size={batch_size}, adjust batch_size to be a multiple of " + f"dp_size * rollout.n." + ) + + global_partition_lst = get_group_balanced_partitions( + seqlen_list=seqlen_list, + uid_list=uid_list, + k_partitions=dp_size, + ) + + elif keep_minibatch: + # Decouple the DP balancing and mini-batching. + minibatch_size = self.config.actor_rollout_ref.actor.get("ppo_mini_batch_size") + minibatch_num = len(workload_lst) // minibatch_size + global_partition_lst = [[] for _ in range(dp_size)] + for i in range(minibatch_num): + rearrange_minibatch_lst = get_seqlen_balanced_partitions( + workload_lst[i * minibatch_size : (i + 1) * minibatch_size], + k_partitions=dp_size, + equal_size=True, + ) + for j, part in enumerate(rearrange_minibatch_lst): + global_partition_lst[j].extend([x + minibatch_size * i for x in part]) + else: + global_partition_lst = get_seqlen_balanced_partitions(workload_lst, k_partitions=dp_size, equal_size=True) + # Place smaller micro-batches at both ends to reduce the bubbles in pipeline parallel. + # Skip reordering within partitions for PrefixGrouper to maintain uid grouping + if not getattr(self, "use_prefix_grouper", False): + for idx, partition in enumerate(global_partition_lst): + partition.sort(key=lambda x: (workload_lst[x], x)) + ordered_partition = partition[::2] + partition[1::2][::-1] + global_partition_lst[idx] = ordered_partition + + # reorder based on index. The data will be automatically equally partitioned by dispatch function + global_idx = torch.tensor([j for partition in global_partition_lst for j in partition]) + batch.reorder(global_idx) + global_balance_stats = log_seqlen_unbalance( + seqlen_list=global_seqlen_lst.tolist(), partitions=global_partition_lst, prefix=logging_prefix + ) + metrics.update(global_balance_stats) + + def _compute_values(self, batch: DataProto) -> DataProto: + if self.use_legacy_worker_impl == "disable": + batch_td = batch.to_tensordict() + # step 2: convert from padding to nopadding + batch_td = left_right_2_no_padding(batch_td) + # step 3: add meta info + tu.assign_non_tensor(batch_td, compute_loss=False) + output = self.critic_wg.infer_batch(batch_td) + output = output.get() + values = tu.get(output, "values") + values = no_padding_2_padding(values, batch_td) + values = tu.get_tensordict({"values": values.float()}) + values = DataProto.from_tensordict(values) + else: + values = self.critic_wg.compute_values(batch) + return values + + def _compute_ref_log_prob(self, batch: DataProto) -> DataProto: + if self.use_legacy_worker_impl == "disable": + # step 1: convert dataproto to tensordict. + batch_td = batch.to_tensordict() + # step 2: convert from padding to nopadding + batch_td = left_right_2_no_padding(batch_td) + # step 3: add meta info + metadata = {"calculate_entropy": False, "compute_loss": False} + if self.ref_in_actor: + metadata["no_lora_adapter"] = True + tu.assign_non_tensor(batch_td, **metadata) + if self.ref_in_actor: + output = self.actor_rollout_wg.compute_log_prob(batch_td) + else: + output = self.ref_policy_wg.compute_ref_log_prob(batch_td) + # gather output + log_probs = tu.get(output, "log_probs") + # step 4. No padding to padding + log_probs = no_padding_2_padding(log_probs, batch_td) + # step 5: rebuild a tensordict and convert to dataproto + ref_log_prob = tu.get_tensordict({"ref_log_prob": log_probs.float()}) + ref_log_prob = DataProto.from_tensordict(ref_log_prob) + else: + ref_log_prob = self.ref_policy_wg.compute_ref_log_prob(batch) + + return ref_log_prob + + def _compute_old_log_prob(self, batch: DataProto): + if self.use_legacy_worker_impl == "disable": + # TODO: remove step 1, 2, 4 after we make the whole training tensordict and padding free + # step 1: convert dataproto to tensordict. + batch_td = batch.to_tensordict() + # step 2: convert from padding to nopadding + batch_td = left_right_2_no_padding(batch_td) + # step 3: add meta info + tu.assign_non_tensor(batch_td, calculate_entropy=True, compute_loss=False) + output = self.actor_rollout_wg.compute_log_prob(batch_td) + # gather output + entropy = tu.get(output, "entropy") + log_probs = tu.get(output, "log_probs") + old_log_prob_mfu = tu.get(output, "metrics")["mfu"] + # step 4. No padding to padding + entropy = no_padding_2_padding(entropy, batch_td) + log_probs = no_padding_2_padding(log_probs, batch_td) + # step 5: rebuild a tensordict and convert to dataproto + old_log_prob = tu.get_tensordict({"old_log_probs": log_probs.float(), "entropys": entropy.float()}) + old_log_prob = DataProto.from_tensordict(old_log_prob) + else: + old_log_prob = self.actor_rollout_wg.compute_log_prob(batch) + old_log_prob_mfu = 0 + return old_log_prob, old_log_prob_mfu + + def _update_actor(self, batch: DataProto) -> DataProto: + rollout_config = self.config.actor_rollout_ref.rollout + batch.meta_info["multi_turn"] = rollout_config.multi_turn.enable + # TODO: Make "temperature" single source of truth from generation. + batch.meta_info["temperature"] = rollout_config.temperature + # update actor + if self.use_legacy_worker_impl == "disable": + batch_td = batch.to_tensordict() + # step 2: convert from padding to no-padding + batch_td = left_right_2_no_padding(batch_td) + calculate_entropy = self.config.actor_rollout_ref.actor.entropy_coeff != 0.0 + ppo_mini_batch_size = self.config.actor_rollout_ref.actor.ppo_mini_batch_size + ppo_mini_batch_size = ppo_mini_batch_size * self.config.actor_rollout_ref.rollout.n + ppo_epochs = self.config.actor_rollout_ref.actor.ppo_epochs + seed = self.config.actor_rollout_ref.actor.data_loader_seed + shuffle = self.config.actor_rollout_ref.actor.shuffle + tu.assign_non_tensor( + batch_td, + calculate_entropy=calculate_entropy, + global_batch_size=ppo_mini_batch_size, + mini_batch_size=ppo_mini_batch_size, + epochs=ppo_epochs, + seed=seed, + dataloader_kwargs={"shuffle": shuffle}, + ) + + actor_output = self.actor_rollout_wg.update_actor(batch_td) + actor_output = tu.get(actor_output, "metrics") + actor_output = rename_dict(actor_output, "actor/") + # modify key name + actor_output["perf/mfu/actor"] = actor_output.pop("actor/mfu") + actor_output = DataProto.from_single_dict(data={}, meta_info={"metrics": actor_output}) + else: + actor_output = self.actor_rollout_wg.update_actor(batch) + + return actor_output + + def _update_critic(self, batch: DataProto) -> DataProto: + if self.use_legacy_worker_impl == "disable": + batch_td = batch.to_tensordict() + # step 2: convert from padding to no-padding + batch_td = left_right_2_no_padding(batch_td) + ppo_mini_batch_size = self.config.critic.ppo_mini_batch_size + ppo_mini_batch_size = ppo_mini_batch_size * self.config.actor_rollout_ref.rollout.n + ppo_epochs = self.config.critic.ppo_epochs + seed = self.config.critic.data_loader_seed + shuffle = self.config.critic.shuffle + tu.assign_non_tensor( + batch_td, + global_batch_size=ppo_mini_batch_size, + mini_batch_size=ppo_mini_batch_size, + epochs=ppo_epochs, + seed=seed, + dataloader_kwargs={"shuffle": shuffle}, + ) + + output = self.critic_wg.train_mini_batch(batch_td) + output = output.get() + output = tu.get(output, "metrics") + output = rename_dict(output, "critic/") + # modify key name + output["perf/mfu/critic"] = output.pop("critic/mfu") + critic_output = DataProto.from_single_dict(data={}, meta_info={"metrics": output}) + else: + critic_output = self.critic_wg.update_critic(batch) + return critic_output + + def fit(self): + """ + The training loop of PPO. + The driver process only need to call the compute functions of the worker group through RPC + to construct the PPO dataflow. + The light-weight advantage computation is done on the driver process. + """ + from omegaconf import OmegaConf + + from verl.utils.tracking import Tracking + + logger = Tracking( + project_name=self.config.trainer.project_name, + experiment_name=self.config.trainer.experiment_name, + default_backend=self.config.trainer.logger, + config=OmegaConf.to_container(self.config, resolve=True), + ) + + self.global_steps = 0 + + # load checkpoint and update weights before doing anything + self._load_checkpoint() + self.checkpoint_manager.update_weights() + + current_epoch = self.global_steps // len(self.train_dataloader) + + # perform validation before training + # currently, we only support validation using the reward_function. + if self.val_reward_fn is not None and self.config.trainer.get("val_before_train", True): + val_metrics = self._validate() + assert val_metrics, f"{val_metrics=}" + pprint(f"Initial validation metrics: {val_metrics}") + logger.log(data=val_metrics, step=self.global_steps) + if self.config.trainer.get("val_only", False): + return + + if self.config.actor_rollout_ref.rollout.get("skip_rollout", False): + rollout_skip = RolloutSkip(self.config, self.actor_rollout_wg) + rollout_skip.wrap_generate_sequences() + + # add tqdm + progress_bar = tqdm(total=self.total_training_steps, initial=self.global_steps, desc="Training Progress") + + # we start from step 1 + self.global_steps += 1 + last_val_metrics = None + self.max_steps_duration = 0 + + prev_step_profile = False + curr_step_profile = ( + self.global_steps in self.config.global_profiler.steps + if self.config.global_profiler.steps is not None + else False + ) + next_step_profile = False + + for epoch in range(current_epoch, self.config.trainer.total_epochs): + for batch_dict in self.train_dataloader: + if hasattr(self.actor_rollout_wg, "async_calls_finalize_fn_exec"): + self.actor_rollout_wg.async_calls_finalize_fn_exec(blocking=False) + metrics = {} + timing_raw = {} + + with marked_timer("start_profile", timing_raw): + self._start_profiling( + not prev_step_profile and curr_step_profile + if self.config.global_profiler.profile_continuous_steps + else curr_step_profile + ) + batch: DataProto = DataProto.from_single_dict(batch_dict) + batch.meta_info["temperature"] = self.config.actor_rollout_ref.rollout.temperature + + # add uid to batch + batch.non_tensor_batch["uid"] = np.array( + [str(uuid.uuid4()) for _ in range(len(batch.batch))], dtype=object + ) + + gen_batch = self._get_gen_batch(batch) + + # pass global_steps to trace + gen_batch.meta_info["global_steps"] = self.global_steps + gen_batch_output = gen_batch.repeat( + repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True + ) + + is_last_step = self.global_steps >= self.total_training_steps + with marked_timer("step", timing_raw): + # generate a batch + with marked_timer("gen", timing_raw, color="red"): + if not self.async_rollout_mode: + gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch_output) + else: + if curr_step_profile: + self.async_rollout_manager.start_profile(global_step=self.global_steps) + gen_batch_output = self.async_rollout_manager.generate_sequences(gen_batch_output) + self.checkpoint_manager.sleep_replicas() + if curr_step_profile: + self.async_rollout_manager.stop_profile() + + timing_raw.update(gen_batch_output.meta_info["timing"]) + gen_batch_output.meta_info.pop("timing", None) + + if self.config.algorithm.adv_estimator == AdvantageEstimator.REMAX: + if self.reward_fn is None: + raise ValueError("A reward_fn is required for REMAX advantage estimation.") + + with marked_timer("gen_max", timing_raw, color="purple"): + gen_baseline_batch = deepcopy(gen_batch) + gen_baseline_batch.meta_info["do_sample"] = False + if not self.async_rollout_mode: + gen_baseline_output = self.actor_rollout_wg.generate_sequences(gen_baseline_batch) + else: + if curr_step_profile: + self.async_rollout_manager.start_profile() + gen_baseline_output = self.async_rollout_manager.generate_sequences(gen_baseline_batch) + self.checkpoint_manager.sleep_replicas() + if curr_step_profile: + self.async_rollout_manager.stop_profile() + batch = batch.union(gen_baseline_output) + # compute reward model score on batch + rm_scores = None + if self.use_rm and "rm_scores" not in batch.batch.keys(): + if not self.use_reward_loop: + rm_scores = self.rm_wg.compute_rm_score(batch) + else: + assert self.reward_loop_manager is not None, "RewardLoopManager is None" + rm_scores = self.reward_loop_manager.compute_rm_score(batch) + batch = batch.union(rm_scores) + + # Compute or extract reward for REMAX baseline + reward_baseline_tensor = self._compute_or_extract_reward( + batch, reward_fn=self.reward_fn, sum_reward=True + ) + + keys_to_pop = set(gen_baseline_output.batch.keys()) + if rm_scores is not None: + keys_to_pop.update(rm_scores.batch.keys()) + batch.pop(batch_keys=list(keys_to_pop)) + + batch.batch["reward_baselines"] = reward_baseline_tensor + + del rm_scores, gen_baseline_batch, gen_baseline_output + # repeat to align with repeated responses in rollout + batch = batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True) + batch = batch.union(gen_batch_output) + + if "response_mask" not in batch.batch.keys(): + batch.batch["response_mask"] = compute_response_mask(batch) + # Balance the number of valid tokens across DP ranks. + # NOTE: This usually changes the order of data in the `batch`, + # which won't affect the advantage calculation (since it's based on uid), + # but might affect the loss calculation (due to the change of mini-batching). + if self.config.trainer.balance_batch: + self._balance_batch(batch, metrics=metrics) + + # compute global_valid tokens + batch.meta_info["global_token_num"] = torch.sum(batch.batch["attention_mask"], dim=-1).tolist() + # get images_seqlens + images_seqlens_all = [] + for multi_modal_input in batch.non_tensor_batch["multi_modal_inputs"]: + if "image_grid_thw" not in multi_modal_input.keys(): + continue + images_seqlens_all.extend(multi_modal_input["images_seqlens"].tolist()) + batch.meta_info["images_seqlens"] = images_seqlens_all + with marked_timer("reward", timing_raw, color="yellow"): + # compute reward model score + if self.use_rm and "rm_scores" not in batch.batch.keys(): + if not self.use_reward_loop: + reward_tensor = self.rm_wg.compute_rm_score(batch) + else: + assert self.reward_loop_manager is not None, "RewardLoopManager is None" + reward_tensor = self.reward_loop_manager.compute_rm_score(batch) + batch = batch.union(reward_tensor) + + # Compute or extract reward for training + if self.config.reward_model.launch_reward_fn_async: + future_reward = compute_reward_async.remote( + data=batch, config=self.config, tokenizer=self.tokenizer + ) + else: + reward_tensor, reward_extra_infos_dict = self._compute_or_extract_reward( + batch, reward_fn=self.reward_fn, reward_for_val=False + ) + + # Operating Mode Selection: + # - Bypass mode: Sets old_log_probs = rollout_log_probs (2 policies: π_rollout, π_θ) + # - Decoupled mode: Recomputes old_log_probs as proximal anchor (3 policies: π_rollout, π_old, π_θ) + # Note: π_old computed once per data batch, serves as stable reference during mini-batch updates + rollout_corr_config = self.config.algorithm.get("rollout_correction", None) + bypass_recomputing_logprobs = rollout_corr_config and rollout_corr_config.get("bypass_mode", False) + if bypass_recomputing_logprobs: # Use `rollout_log_probs` + from verl.trainer.ppo.rollout_corr_helper import apply_bypass_mode + + apply_bypass_mode( + batch=batch, + rollout_corr_config=rollout_corr_config, + policy_loss_config=self.config.actor_rollout_ref.actor.policy_loss, + ) + else: # Recompute old_log_probs + with marked_timer("old_log_prob", timing_raw, color="blue"): + old_log_prob, old_log_prob_mfu = self._compute_old_log_prob(batch) + entropys = old_log_prob.batch["entropys"] + response_masks = batch.batch["response_mask"] + actor_config = self.config.actor_rollout_ref.actor + entropy_agg = agg_loss( + loss_mat=entropys, + loss_mask=response_masks, + loss_agg_mode=actor_config.loss_agg_mode, + loss_scale_factor=actor_config.loss_scale_factor, + ) + old_log_prob_metrics = { + "actor/entropy": entropy_agg.detach().item(), + "perf/mfu/actor_infer": old_log_prob_mfu, + } + metrics.update(old_log_prob_metrics) + old_log_prob.batch.pop("entropys") + if "routed_experts" in batch.batch and "routed_experts" in old_log_prob.batch: + router_mode = getattr( + self.config.actor_rollout_ref.actor.router_replay, "mode", "disabled" + ) + if router_mode == "R2": + batch.batch.pop("routed_experts") + else: + old_log_prob.batch.pop("routed_experts") + batch = batch.union(old_log_prob) + if "rollout_log_probs" in batch.batch.keys(): + # TODO: we may want to add diff of probs too. + from verl.utils.debug.metrics import calculate_debug_metrics + + metrics.update(calculate_debug_metrics(batch)) + + assert "old_log_probs" in batch.batch, f'"old_log_prob" not in {batch.batch.keys()=}' + + if self.use_reference_policy: + # compute reference log_prob + with marked_timer(str(Role.RefPolicy), timing_raw, color="olive"): + ref_log_prob = self._compute_ref_log_prob(batch) + batch = batch.union(ref_log_prob) + + # compute values + if self.use_critic: + with marked_timer("values", timing_raw, color="cyan"): + values = self._compute_values(batch) + batch = batch.union(values) + + with marked_timer("adv", timing_raw, color="brown"): + # we combine with rule-based rm + reward_extra_infos_dict: dict[str, list] + if self.config.reward_model.launch_reward_fn_async: + reward_tensor, reward_extra_infos_dict = ray.get(future_reward) + batch.batch["token_level_scores"] = reward_tensor + + if reward_extra_infos_dict: + batch.non_tensor_batch.update({k: np.array(v) for k, v in reward_extra_infos_dict.items()}) + + # compute rewards. apply_kl_penalty if available + if self.config.algorithm.use_kl_in_reward: + batch, kl_metrics = apply_kl_penalty( + batch, kl_ctrl=self.kl_ctrl_in_reward, kl_penalty=self.config.algorithm.kl_penalty + ) + metrics.update(kl_metrics) + else: + batch.batch["token_level_rewards"] = batch.batch["token_level_scores"] + + # Compute rollout correction: IS weights, rejection sampling, and metrics + # Only runs in decoupled mode (computes once per batch using stable π_old) + # In bypass mode, this is skipped - actor computes metrics from evolving π_θ vs π_rollout + if ( + rollout_corr_config is not None + and "rollout_log_probs" in batch.batch + and not bypass_recomputing_logprobs # Only in decoupled mode + ): + from verl.trainer.ppo.rollout_corr_helper import compute_rollout_correction_and_add_to_batch + + # Compute IS weights, apply rejection sampling, compute metrics + batch, is_metrics = compute_rollout_correction_and_add_to_batch(batch, rollout_corr_config) + # IS and off-policy metrics already have rollout_corr/ prefix + metrics.update(is_metrics) + + # compute advantages, executed on the driver process + norm_adv_by_std_in_grpo = self.config.algorithm.get( + "norm_adv_by_std_in_grpo", True + ) # GRPO adv normalization factor + + batch = compute_advantage( + batch, + adv_estimator=self.config.algorithm.adv_estimator, + gamma=self.config.algorithm.gamma, + lam=self.config.algorithm.lam, + num_repeat=self.config.actor_rollout_ref.rollout.n, + norm_adv_by_std_in_grpo=norm_adv_by_std_in_grpo, + config=self.config.algorithm, + ) + + # update critic + if self.use_critic: + with marked_timer("update_critic", timing_raw, color="pink"): + critic_output = self._update_critic(batch) + critic_output_metrics = reduce_metrics(critic_output.meta_info["metrics"]) + metrics.update(critic_output_metrics) + + # implement critic warmup + if self.config.trainer.critic_warmup <= self.global_steps: + # update actor + with marked_timer("update_actor", timing_raw, color="red"): + actor_output = self._update_actor(batch) + + # update weights from trainer to rollout + with marked_timer("update_weights", timing_raw, color="red"): + self.checkpoint_manager.update_weights() + + actor_output_metrics = reduce_metrics(actor_output.meta_info["metrics"]) + metrics.update(actor_output_metrics) + + # Log rollout generations if enabled + rollout_data_dir = self.config.trainer.get("rollout_data_dir", None) + if rollout_data_dir: + self._log_rollout_data(batch, reward_extra_infos_dict, timing_raw, rollout_data_dir) + + # validate + if ( + self.val_reward_fn is not None + and self.config.trainer.test_freq > 0 + and (is_last_step or self.global_steps % self.config.trainer.test_freq == 0) + ): + with marked_timer("testing", timing_raw, color="green"): + val_metrics: dict = self._validate() + if is_last_step: + last_val_metrics = val_metrics + metrics.update(val_metrics) + + # Check if the ESI (Elastic Server Instance)/training plan is close to expiration. + esi_close_to_expiration = should_save_ckpt_esi( + max_steps_duration=self.max_steps_duration, + redundant_time=self.config.trainer.esi_redundant_time, + ) + # Check if the conditions for saving a checkpoint are met. + # The conditions include a mandatory condition (1) and + # one of the following optional conditions (2/3/4): + # 1. The save frequency is set to a positive value. + # 2. It's the last training step. + # 3. The current step number is a multiple of the save frequency. + # 4. The ESI(Elastic Server Instance)/training plan is close to expiration. + if self.config.trainer.save_freq > 0 and ( + is_last_step or self.global_steps % self.config.trainer.save_freq == 0 or esi_close_to_expiration + ): + if esi_close_to_expiration: + print("Force saving checkpoint: ESI instance expiration approaching.") + with marked_timer("save_checkpoint", timing_raw, color="green"): + # sleep replicas to avoid OOM during checkpoint saving + self.checkpoint_manager.sleep_replicas() + self._save_checkpoint() + # wake replicas to avoid OOM during checkpoint saving + self.checkpoint_manager.update_weights() + + with marked_timer("stop_profile", timing_raw): + next_step_profile = ( + self.global_steps + 1 in self.config.global_profiler.steps + if self.config.global_profiler.steps is not None + else False + ) + self._stop_profiling( + curr_step_profile and not next_step_profile + if self.config.global_profiler.profile_continuous_steps + else curr_step_profile + ) + prev_step_profile = curr_step_profile + curr_step_profile = next_step_profile + + steps_duration = timing_raw["step"] + self.max_steps_duration = max(self.max_steps_duration, steps_duration) + + # training metrics + metrics.update( + { + "training/global_step": self.global_steps, + "training/epoch": epoch, + } + ) + # collect metrics + metrics.update(compute_data_metrics(batch=batch, use_critic=self.use_critic)) + metrics.update(compute_timing_metrics(batch=batch, timing_raw=timing_raw)) + # TODO: implement actual tflpo and theoretical tflpo + n_gpus = self.resource_pool_manager.get_n_gpus() + metrics.update(compute_throughout_metrics(batch=batch, timing_raw=timing_raw, n_gpus=n_gpus)) + # compute variance proxy metrics + gradient_norm = metrics.get("actor/grad_norm", None) + metrics.update(compute_variance_proxy_metrics(batch=batch, gradient_norm=gradient_norm)) + # Note: mismatch metrics (KL, PPL, etc.) are collected at line 1179 after advantage computation + + # this is experimental and may be changed/removed in the future in favor of a general-purpose one + if isinstance(self.train_dataloader.sampler, AbstractCurriculumSampler): + self.train_dataloader.sampler.update(batch=batch) + + # TODO: make a canonical logger that supports various backend + logger.log(data=metrics, step=self.global_steps) + + progress_bar.update(1) + self.global_steps += 1 + + if ( + hasattr(self.config.actor_rollout_ref.actor, "profiler") + and self.config.actor_rollout_ref.actor.profiler.tool == "torch_memory" + ): + self.actor_rollout_wg.dump_memory_snapshot( + tag=f"post_update_step{self.global_steps}", sub_dir=f"step{self.global_steps}" + ) + + if is_last_step: + if hasattr(self.actor_rollout_wg, "async_calls_finalize_fn_exec"): + self.actor_rollout_wg.async_calls_finalize_fn_exec(blocking=True) + pprint(f"Final validation metrics: {last_val_metrics}") + progress_bar.close() + return + + # this is experimental and may be changed/removed in the future + # in favor of a general-purpose data buffer pool + if hasattr(self.train_dataset, "on_batch_end"): + # The dataset may be changed after each training batch + self.train_dataset.on_batch_end(batch=batch) diff --git a/code/RL_model/verl/verl_train/verl/trainer/ppo/reward.py b/code/RL_model/verl/verl_train/verl/trainer/ppo/reward.py new file mode 100644 index 0000000000000000000000000000000000000000..40c4876eb9fc75dc1543a6fd7cb89211cc60bb93 --- /dev/null +++ b/code/RL_model/verl/verl_train/verl/trainer/ppo/reward.py @@ -0,0 +1,216 @@ +# Copyright 2025 Individual Contributor: Thibaut Barroyer +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import annotations + +import inspect +import multiprocessing +import warnings +from functools import partial +from typing import TYPE_CHECKING, Any, Optional, cast + +import ray +import torch + +from verl.utils.reward_score import default_compute_score +from verl.utils.transferqueue_utils import tqbridge + +if TYPE_CHECKING: + from omegaconf import DictConfig + + from verl import DataProto + from verl.experimental.reward_loop.reward_manager.base import RewardManagerBase + from verl.trainer.config.config import ModuleConfig, RewardManagerConfig + from verl.workers.reward_manager.abstract import AbstractRewardManager, RawRewardFn +else: + try: + from verl.experimental.reward_loop.reward_manager.base import RewardManagerBase + except ImportError: + RewardManagerBase = None # type: ignore[assignment,misc] + + +def _call_with_kwargs(raw_fn, extra_kwargs, *args, **kwargs): + """Calls `raw_fn` by merging `extra_kwargs` into call-time `kwargs`, with `extra_kwargs` taking precedence. + + This function is used to merge additional keyword arguments with the original function's arguments. + """ + merged_kwargs = {**kwargs, **extra_kwargs} + return raw_fn(*args, **merged_kwargs) + + +async def _call_with_kwargs_async(raw_fn, extra_kwargs, *args, **kwargs): + """Calls `raw_fn` by merging `extra_kwargs` into call-time `kwargs`, with `extra_kwargs` taking precedence. + + This function is used to merge additional keyword arguments with the original function's arguments. + """ + merged_kwargs = {**kwargs, **extra_kwargs} + return await raw_fn(*args, **merged_kwargs) + + +def get_custom_reward_fn(config: DictConfig) -> Optional[RawRewardFn]: + """Load and return a custom reward function from external file. + + Dynamically imports a reward function from a specified file path and wraps + it with additional keyword arguments from the configuration. + + Args: + config (dict): Configuration dictionary containing custom_reward_function + settings with 'path', 'name', and 'reward_kwargs' fields. + + Returns: + callable or None: Wrapped reward function with merged kwargs, or None + if no custom reward function is configured. + + Raises: + FileNotFoundError: If the specified reward function file doesn't exist. + RuntimeError: If there's an error loading the module from file. + AttributeError: If the specified function name isn't found in the module. + """ + + reward_fn_config = config.get("custom_reward_function") or {} + module_path = reward_fn_config.get("path") + if not module_path: + return None + + fn_name = reward_fn_config.get("name") + assert fn_name is not None + + from verl.utils.import_utils import load_extern_object + + raw_fn = load_extern_object(module_path=module_path, object_name=fn_name) + + reward_kwargs = dict(reward_fn_config.get("reward_kwargs", {})) + if not inspect.iscoroutinefunction(raw_fn): + return partial(_call_with_kwargs, raw_fn, reward_kwargs) + else: + return partial(_call_with_kwargs_async, raw_fn, reward_kwargs) + + +def load_reward_manager( + config: DictConfig, tokenizer: Any, num_examine: int, **reward_kwargs: Any +) -> AbstractRewardManager: + """ + Load and initialize a reward manager based on the configuration. + + Args: + config: PPO trainer configuration object containing reward_model fields. + tokenizer: Tokenizer object used for processing text. + num_examine: Number of samples to examine. + **reward_kwargs: Additional keyword arguments for the reward manager. + + Returns: + An instance of the specified reward manager class. + """ + + # Try to get a custom reward function based on the configuration + # user defined reward manager can be registered in custom_reward_fn + compute_score = get_custom_reward_fn(config) + final_compute_score = compute_score + + reward_manager_cfg: RewardManagerConfig = config.reward_manager + reward_manager_cls: type[AbstractRewardManager] + if reward_manager_cfg.source == "register": + from verl.workers.reward_manager import get_reward_manager_cls + + reward_manager_cls = get_reward_manager_cls(reward_manager_cfg.name) + elif reward_manager_cfg.source == "importlib": + from verl.utils.import_utils import load_extern_object + + module_cfg: ModuleConfig | None = reward_manager_cfg.module + assert module_cfg is not None and module_cfg.path is not None, ( + f"Module path is required when {reward_manager_cfg.source=}, but got {module_cfg=}" + ) + reward_manager_cls_name = reward_manager_cfg.name + reward_manager_cls = cast( + "type[AbstractRewardManager]", + load_extern_object(module_path=module_cfg.path, object_name=reward_manager_cls_name), + ) + + if compute_score is None: + sandbox_config = config.reward_model.get("sandbox_fusion") + sandbox_url = sandbox_config.get("url") if sandbox_config else None + memory_limit_mb = sandbox_config.get("memory_limit_mb", 1024) if sandbox_config else 1024 + if sandbox_url: + sandbox_manager = multiprocessing.Manager() + # Create a semaphore to control concurrent access to the sandbox + _concurrent_semaphore = sandbox_manager.Semaphore(sandbox_config.get("max_concurrent", 64)) + final_compute_score = partial( + default_compute_score, + sandbox_fusion_url=sandbox_url, + concurrent_semaphore=_concurrent_semaphore, + memory_limit_mb=memory_limit_mb, + ) + else: + final_compute_score = default_compute_score + + # Instantiate and return the reward manager with the specified parameters + # RewardManagerBase subclasses (like RateLimitedRewardLoopManager) don't accept num_examine + # while AbstractRewardManager subclasses (like NaiveRewardManager) do + if RewardManagerBase is not None and issubclass(reward_manager_cls, RewardManagerBase): + # RewardManagerBase-based managers use a different signature + return reward_manager_cls( + config=config, + tokenizer=tokenizer, + compute_score=final_compute_score, + **reward_kwargs, + ) + else: + # Traditional AbstractRewardManager-based managers + return reward_manager_cls( + tokenizer=tokenizer, + num_examine=num_examine, + compute_score=final_compute_score, + reward_fn_key=config.data.reward_fn_key, + **reward_kwargs, + ) + + +@tqbridge(put_data=False) +def compute_reward(data: DataProto, reward_fn: AbstractRewardManager) -> tuple[torch.Tensor, dict[str, Any]]: + """ + Compute reward for a batch of data. + Args: + data: DataProto object containing the input data. + reward_fn: Reward function to compute the reward. + Returns: + Tuple of reward tensor and extra info dictionary. + """ + try: + reward_result = reward_fn(data, return_dict=True) + reward_tensor = reward_result["reward_tensor"] + reward_extra_infos_dict = reward_result.get("reward_extra_info", {}) + except Exception as e: + print(f"Error in reward_fn: {e}") + reward_tensor = reward_fn(data) + reward_extra_infos_dict = {} + + return reward_tensor, reward_extra_infos_dict + + +@ray.remote(num_cpus=1) +def compute_reward_async(data: DataProto, config=None, tokenizer=None, reward_fn=None): + """ + Load the reward manager and compute the reward for a batch of data. + This is meant to be run in a separate Ray worker. + """ + if reward_fn is None: + assert config is not None and tokenizer is not None, ( + "config and tokenizer must not be None when reward_fn is None" + ) + + warnings.warn("using config and tokenizer with compute_reward_async is deprecated", stacklevel=2) + reward_fn = load_reward_manager( + config, tokenizer, num_examine=0, **config.reward_model.get("reward_kwargs", {}) + ) + + return compute_reward(data, reward_fn) diff --git a/code/RL_model/verl/verl_train/verl/trainer/ppo/rollout_corr_helper.py b/code/RL_model/verl/verl_train/verl/trainer/ppo/rollout_corr_helper.py new file mode 100644 index 0000000000000000000000000000000000000000..6f770b38274d0692f32273d96edd1d2a35602a24 --- /dev/null +++ b/code/RL_model/verl/verl_train/verl/trainer/ppo/rollout_corr_helper.py @@ -0,0 +1,1074 @@ +# Copyright 2025 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Rollout Correction Helper Module + +This module provides a complete pipeline to address **off-policy issues** in RL training, +including: +1. Policy mismatch between rollout and training implementations (e.g., vLLM BFloat16 vs FSDP FP32) +2. Model update staleness (training on trajectories from older checkpoints) +3. General distribution shifts between data collection and training + +Its core capabilities include computing importance sampling (IS) weights, +filtering outlier samples via rejection sampling (RS), and +tracking metrics to diagnose and correct off-policy issues. + +## Core Capabilities +1. **Multi-Granularity Aggregation**: + - Importance Sampling (IS): + Token-level + Sequence-level + - Rejection Sampling (RS): + Divergence-based filters (token_k*, seq_sum_k*, seq_mean_k*, seq_max_k*) +2. **Memory-Efficient Design**: + - Log-space computations to avoid numerical overflow/underflow. + - Fixed safety bounds (exp(±20)) for stable exponentiation. + - Metrics calculated without large intermediate tensors (prevents CUDA OOM). +3. **Comprehensive Metrics Tracking**: + - IS/RS statistics (mean/max/min, effective sample size ESS, rejection rate). + - Off-policy diagnostics (KL divergence, perplexity PPL, log PPL difference, χ² divergence). + - Sequence-level breakdowns (deviation from ideal weights, outlier fraction). + +## Key Interfaces & Usage +- compute_rollout_correction_and_rejection_mask(): compute IS weights + rejection mask. +- compute_rollout_correction_weights(): only compute truncated IS weights (for variance + reduction, no outlier rejection). +- compute_rollout_rejection_mask(): only filter outliers (for sample cleaning, no IS weight + computation). +- compute_offpolicy_metrics(): called by core functions to calculate off-policy diagnostics + (KL/PPL/χ²) — no direct external calls needed. + +### Integration Notes +- Used in `ray_trainer.py` via `compute_rollout_correction_and_add_to_batch()` (batch training pipeline). +- Used in `dp_actor.py` for distributed worker computations (distributed training scenarios). +- All functions support batch inputs and valid token masking (via `response_mask`). + + +## References +- "When Speed Kills Stability: Demystifying RL Collapse from the Training-Inference Mismatch": https://richardli.xyz/rl-collapse +- Off-policy RL (theoretical basis for IS): https://fengyao.notion.site/off-policy-rl +""" + +import math +from typing import Any, Optional + +import torch + +import verl.utils.torch_functional as verl_F +from verl.protocol import DataProto +from verl.trainer.config.algorithm import RolloutCorrectionConfig +from verl.workers.config.actor import PolicyLossConfig + +# Safety bound to prevent numerical overflow/underflow when exponentiating +# exp(20) ≈ 485 million (upper limit for stable weights), exp(-20) ≈ 2e-9 (lower limit) +SAFETY_BOUND = 20.0 + +SUPPORTED_ROLLOUT_RS_OPTIONS: set[str] = { + "token_k1", + "token_k2", + "token_k3", + "seq_sum_k1", + "seq_sum_k2", + "seq_sum_k3", + "seq_mean_k1", + "seq_mean_k2", + "seq_mean_k3", + "seq_max_k2", + "seq_max_k3", +} +TOKEN_LEVEL_ROLLOUT_RS_OPTIONS: set[str] = {"token_k1", "token_k2", "token_k3"} + + +def _parse_rollout_rs_thresholds( + options: list[str], threshold_spec: Optional[str | float] +) -> dict[str, dict[str, Optional[float]]]: + if threshold_spec is None: + raise ValueError("rollout_rs_threshold must be provided for rejection sampling.") + + if isinstance(threshold_spec, int | float): + raw_specs: list[str] = [str(threshold_spec)] + elif isinstance(threshold_spec, str): + raw_specs = [part.strip() for part in threshold_spec.split(",") if part.strip()] + else: + raise TypeError("rollout_rs_threshold must be a string or numeric value specifying per-option thresholds.") + + if not raw_specs: + raise ValueError("rollout_rs_threshold must contain at least one threshold value.") + + if len(raw_specs) not in (1, len(options)): + raise ValueError( + f"rollout_rs_threshold expects either one threshold shared by all options or exactly " + f"{len(options)} thresholds to match the provided rollout_rs options." + ) + + if len(raw_specs) == 1 and len(options) > 1: + raw_specs = raw_specs * len(options) + + thresholds: dict[str, dict[str, Optional[float]]] = {} + for option, spec in zip(options, raw_specs, strict=False): + if option.endswith("k1"): + if "_" in spec: + lower_str, upper_str = spec.split("_", 1) + else: + upper_str = spec + lower_str = str(1.0 / float(upper_str)) + try: + lower = float(lower_str) + upper = float(upper_str) + except ValueError as exc: + raise ValueError(f"Invalid numeric threshold '{spec}' for option '{option}'.") from exc + if lower <= 0 or upper <= 0: + raise ValueError(f"Thresholds for option '{option}' must be positive, got {spec}.") + thresholds[option] = { + "lower": lower, + "upper": upper, + } + else: + if "_" in spec: + raise ValueError( + f"rollout_rs_threshold for option '{option}' must provide a single upper bound " + f"without '_'. Received '{spec}'." + ) + try: + upper = float(spec) + except ValueError as exc: + raise ValueError(f"Invalid numeric threshold '{spec}' for option '{option}'.") from exc + if upper <= 0: + raise ValueError(f"Threshold for option '{option}' must be positive, got {spec}.") + thresholds[option] = { + "lower": None, + "upper": upper, + } + return thresholds + + +def compute_rollout_rejection_mask( + log_ratio: torch.Tensor, + response_mask: torch.Tensor, + rollout_rs: str = "token_k1", + rollout_rs_threshold: Optional[str | float] = None, +) -> tuple[torch.Tensor, dict[str, float]]: + """Compute hard trust region mask using divergence estimators. + + This function enforces a hard trust region constraint by masking tokens/sequences + where the estimated divergence (between training and rollout policies) exceeds + a threshold. Unlike PPO's soft clipping, this provides a hard boundary. + + Multiple rejection criteria can be supplied via a comma separated `rollout_rs` string. + All requested options must pass for a token/sequence to remain valid. + + Supported KL divergence-based modes (ideal = 0.0 unless noted): + - "token_k{1,2,3}": Token-level divergences. + - "seq_sum_k{1,2,3}": Sum of token divergences per sequence. + - "seq_mean_k{1,2,3}": Mean of token divergences per sequence. + - "seq_max_k{2,3}": Maximum token divergence per sequence. + + Args: + log_ratio: Log ratio of training policy probability to rollout policy probability, + shape (batch_size, seq_length). + response_mask: Binary mask for valid tokens (1=valid, 0=padding), + shape (batch_size, seq_length). + rollout_rs: Comma separated rejection sampling options (e.g. "token_k1,seq_sum_k3"). + rollout_rs_threshold: Threshold specification string (required). Provide one entry per + rollout_rs option separated by commas. Each entry must be a positive number. + For K1-style options (``*k1``), specify ``lower_upper`` (e.g. ``"0.1_1.2"``) + to denote lower/upper ratio bounds; other options accept a single upper bound. + + Returns: + Tuple containing: + modified_response_mask: Response mask with trust region violations masked (0=rejected), + shape (batch_size, seq_length). + metrics: Dictionary of trust region metrics (all scalars). + """ + if rollout_rs is None or not isinstance(rollout_rs, str): + raise ValueError("rollout_rs must be a non-empty string (comma separated for multiple options).") + if rollout_rs_threshold is None: + raise ValueError("rollout_rs_threshold must be provided for rejection sampling.") + + if log_ratio.shape[0] == 0: + return response_mask, {} + + # rollout_rs supports chained criteria via comma separation (e.g. "token_k1,seq_mean_k3"). + # Every listed option must pass; combined_mask aggregates them via logical AND. + option_modes = [opt.strip() for opt in rollout_rs.split(",") if opt.strip()] + if not option_modes: + raise ValueError("rollout_rs must contain at least one valid option.") + + normalized_options: list[str] = [] + seen: set[str] = set() + for opt in option_modes: + if opt not in SUPPORTED_ROLLOUT_RS_OPTIONS: + raise ValueError( + f"Invalid rollout_rs option: {opt}. Must be one of {sorted(SUPPORTED_ROLLOUT_RS_OPTIONS)}." + ) + if opt not in seen: + normalized_options.append(opt) + seen.add(opt) + + threshold_specs = _parse_rollout_rs_thresholds(normalized_options, rollout_rs_threshold) + + log_ratio_safe: torch.Tensor = torch.clamp(log_ratio, min=-SAFETY_BOUND, max=SAFETY_BOUND) + token_k1: torch.Tensor = -log_ratio_safe + token_k2: torch.Tensor = 0.5 * log_ratio_safe**2 + token_k3: torch.Tensor = torch.exp(log_ratio_safe) - 1.0 - log_ratio_safe + + response_mask_bool: torch.Tensor = response_mask.bool() + seq_valid_mask: torch.Tensor = response_mask.sum(dim=-1) > 0 + # combined_mask accumulates per-option passes; any failure flips tokens to 0. + combined_mask: torch.Tensor = torch.ones_like(response_mask, dtype=log_ratio.dtype) + metrics: dict[str, float] = {} + + def _sequence_sum(values: torch.Tensor) -> torch.Tensor: + return verl_F.masked_sum(values, response_mask, axis=-1) + + def _sequence_mean(values: torch.Tensor) -> torch.Tensor: + return verl_F.masked_mean(values, response_mask, axis=-1) + + def _sequence_max(values: torch.Tensor) -> torch.Tensor: + mask_bool = response_mask.bool() + neg_inf = torch.tensor(float("-inf"), device=values.device, dtype=values.dtype) + masked_values = values.masked_fill(~mask_bool, neg_inf) + max_values = masked_values.max(dim=-1).values + return torch.where(max_values == neg_inf, torch.zeros_like(max_values), max_values) + + for option_name in normalized_options: + thresholds_info = threshold_specs[option_name] + is_k1_option = option_name.endswith("k1") + upper_value = thresholds_info["upper"] + lower_value = thresholds_info["lower"] + apply_lower_threshold = is_k1_option + lower_log: Optional[float] = None + upper_log: Optional[float] = None + + if is_k1_option: + if lower_value is None or upper_value is None: + raise ValueError( + f"rollout_rs_threshold for option '{option_name}' must specify both lower and upper bounds." + ) + lower_log = math.log(lower_value) + upper_log = math.log(upper_value) + else: + if upper_value is None: + raise ValueError(f"rollout_rs_threshold for option '{option_name}' must specify an upper bound.") + + level = "sequence" if option_name not in TOKEN_LEVEL_ROLLOUT_RS_OPTIONS else "token" + + per_token_stat: torch.Tensor + per_sequence_stat: Optional[torch.Tensor] = None + token_keep_bool: torch.Tensor + + if option_name == "token_k1": + if lower_log is None: + raise ValueError("Threshold specification for token_k1 must include lower and upper bounds.") + per_token_stat = token_k1 + token_keep_bool = (per_token_stat >= lower_log) & (per_token_stat <= upper_log) + elif option_name == "token_k2": + per_token_stat = token_k2 + token_keep_bool = per_token_stat <= upper_value + elif option_name == "token_k3": + per_token_stat = token_k3 + token_keep_bool = per_token_stat <= upper_value + elif option_name.startswith("seq_sum"): + if option_name.endswith("k1"): + if lower_log is None: + raise ValueError( + f"Threshold specification for option '{option_name}' must include lower and upper bounds." + ) + seq_stat = _sequence_sum(token_k1) + seq_keep_bool_direct = (seq_stat >= lower_log) & (seq_stat <= upper_log) + elif option_name.endswith("k2"): + seq_stat = _sequence_sum(token_k2) + seq_keep_bool_direct = seq_stat <= upper_value + elif option_name.endswith("k3"): + seq_stat = _sequence_sum(token_k3) + seq_keep_bool_direct = seq_stat <= upper_value + else: + raise ValueError(f"Unsupported rollout_rs option: {option_name}.") + per_sequence_stat = seq_stat + token_keep_bool = seq_keep_bool_direct.unsqueeze(-1).expand_as(response_mask_bool) + per_token_stat = seq_stat.unsqueeze(-1).expand_as(response_mask) + elif option_name.startswith("seq_mean"): + if option_name.endswith("k1"): + if lower_log is None: + raise ValueError( + f"Threshold specification for option '{option_name}' must include lower and upper bounds." + ) + seq_stat = _sequence_mean(token_k1) + seq_keep_bool_direct = (seq_stat >= lower_log) & (seq_stat <= upper_log) + elif option_name.endswith("k2"): + seq_stat = _sequence_mean(token_k2) + seq_keep_bool_direct = seq_stat <= upper_value + elif option_name.endswith("k3"): + seq_stat = _sequence_mean(token_k3) + seq_keep_bool_direct = seq_stat <= upper_value + else: + raise ValueError(f"Unsupported rollout_rs option: {option_name}.") + per_sequence_stat = seq_stat + token_keep_bool = seq_keep_bool_direct.unsqueeze(-1).expand_as(response_mask_bool) + per_token_stat = seq_stat.unsqueeze(-1).expand_as(response_mask) + elif option_name.startswith("seq_max"): + if option_name.endswith("k2"): + seq_stat = _sequence_max(token_k2) + seq_keep_bool_direct = seq_stat <= upper_value + elif option_name.endswith("k3"): + seq_stat = _sequence_max(token_k3) + seq_keep_bool_direct = seq_stat <= upper_value + else: + raise ValueError(f"Unsupported rollout_rs option: {option_name}.") + per_sequence_stat = seq_stat + token_keep_bool = seq_keep_bool_direct.unsqueeze(-1).expand_as(response_mask_bool) + per_token_stat = seq_stat.unsqueeze(-1).expand_as(response_mask) + else: + raise ValueError(f"Unsupported rollout_rs option: {option_name}.") + + metrics_upper_threshold = upper_log if is_k1_option else upper_value + metrics_lower_threshold = lower_log if (is_k1_option and lower_log is not None) else 0.0 + + token_keep_mask = token_keep_bool.to(dtype=log_ratio.dtype) + combined_mask = combined_mask * token_keep_mask + seq_keep_bool_tensor = (~((~token_keep_bool) & response_mask_bool)).all(dim=-1) + + option_metrics = compute_rs_metrics( + option_name=option_name, + rs_statistic=per_token_stat, + response_mask=response_mask, + seq_valid_mask=seq_valid_mask, + level=level, + per_sequence_values=per_sequence_stat, + rollout_rs_threshold=metrics_upper_threshold, + rollout_rs_threshold_lower=metrics_lower_threshold, + apply_lower_threshold=apply_lower_threshold, + ) + metrics.update(option_metrics) + + token_masked_fraction = verl_F.masked_mean(1 - token_keep_mask, response_mask).item() + seq_valid_float = seq_valid_mask.float() + if seq_valid_float.sum() > 0: + seq_keep_float = seq_keep_bool_tensor.to(dtype=log_ratio.dtype) + seq_masked_fraction = (((1.0 - seq_keep_float) * seq_valid_float).sum() / seq_valid_float.sum()).item() + else: + seq_masked_fraction = 0.0 + metrics[f"rollout_rs_{option_name}_masked_fraction"] = token_masked_fraction + metrics[f"rollout_rs_{option_name}_seq_masked_fraction"] = seq_masked_fraction + + final_mask = combined_mask + metrics["rollout_rs_masked_fraction"] = verl_F.masked_mean(1 - final_mask, response_mask).item() + final_keep_bool = (final_mask > 0.5) & response_mask_bool + seq_has_masked: torch.Tensor = (~final_keep_bool & response_mask_bool).any(dim=-1) + metrics["rollout_rs_seq_masked_fraction"] = seq_has_masked.float().mean().item() + + modified_response_mask: torch.Tensor = (response_mask * final_mask).to(dtype=response_mask.dtype) + return modified_response_mask, metrics + + +def compute_rs_metrics( + option_name: str, + rs_statistic: torch.Tensor, + response_mask: torch.Tensor, + seq_valid_mask: torch.Tensor, + *, + level: str, + per_sequence_values: Optional[torch.Tensor], + rollout_rs_threshold: float, + rollout_rs_threshold_lower: float, + apply_lower_threshold: bool, +) -> dict[str, float]: + """Compute metrics for hard trust region enforcement (per-option). + + Args: + option_name: Original option string supplied by the user. + rs_statistic: Trust region statistic (per token) used for thresholding. + response_mask: Binary mask for valid tokens (1=valid, 0=padding). + seq_valid_mask: Boolean mask indicating sequences with at least one valid token. + level: "token" or "sequence" describing aggregation level. + per_sequence_values: Optional per-sequence statistic (same semantics as rs_statistic). + rollout_rs_threshold: Upper threshold. + rollout_rs_threshold_lower: Lower threshold (ignored if ``apply_lower_threshold`` is False). + apply_lower_threshold: Whether to mask/log metrics for values below the lower threshold. + """ + if not response_mask.any(): + raise ValueError("response_mask must contain at least one valid token (1).") + + metrics: dict[str, float] = {} + prefix = f"rollout_rs_{option_name}" + mask_bool: torch.Tensor = response_mask.bool() + + # Compute sequence statistics (used by several metrics). + if per_sequence_values is not None: + seq_values = per_sequence_values + else: + seq_values = verl_F.masked_mean(rs_statistic, response_mask, axis=-1) + if seq_values.dim() > 1: + seq_values = seq_values.squeeze(-1) + seq_values_valid = seq_values[seq_valid_mask] + + # Mean of the statistic (always reported). + metrics[f"{prefix}_mean"] = verl_F.masked_mean(rs_statistic, response_mask).item() + + # Max/min values. + if level == "sequence" and seq_values_valid.numel() > 0: + metrics[f"{prefix}_max"] = seq_values_valid.max().item() + metrics[f"{prefix}_min"] = seq_values_valid.min().item() + else: + metrics[f"{prefix}_max"] = rs_statistic.masked_fill(~mask_bool, float("-inf")).max().item() + metrics[f"{prefix}_min"] = rs_statistic.masked_fill(~mask_bool, float("inf")).min().item() + + # Fractions above/below the thresholds. + if level == "sequence" and seq_values_valid.numel() > 0: + fraction_high = (seq_values_valid > rollout_rs_threshold).float().mean().item() + fraction_low = ( + (seq_values_valid < rollout_rs_threshold_lower).float().mean().item() if apply_lower_threshold else 0.0 + ) + else: + fraction_high = verl_F.masked_mean((rs_statistic > rollout_rs_threshold).float(), response_mask).item() + fraction_low = ( + verl_F.masked_mean((rs_statistic < rollout_rs_threshold_lower).float(), response_mask).item() + if apply_lower_threshold + else 0.0 + ) + metrics[f"{prefix}_fraction_high"] = fraction_high + metrics[f"{prefix}_fraction_low"] = fraction_low + + # Standard deviation (clamped for stability). + mask_count: torch.Tensor = response_mask.sum() + if mask_count > 1: + if apply_lower_threshold: + clamp_min = rollout_rs_threshold_lower + else: + clamp_min = 0.0 + stat_for_std: torch.Tensor = rs_statistic.clamp(min=clamp_min, max=rollout_rs_threshold) + mean_clamped: torch.Tensor = verl_F.masked_mean(stat_for_std, response_mask) + stat_var: torch.Tensor = verl_F.masked_mean(stat_for_std.square(), response_mask) - mean_clamped.square() + metrics[f"{prefix}_std"] = torch.sqrt(torch.clamp(stat_var, min=0.0)).item() + else: + metrics[f"{prefix}_std"] = 0.0 + + # Sequence-level summary metrics. + if seq_values_valid.numel() > 0: + metrics[f"{prefix}_seq_mean"] = seq_values_valid.mean().item() + metrics[f"{prefix}_seq_std"] = seq_values_valid.std().item() if seq_values_valid.numel() > 1 else 0.0 + metrics[f"{prefix}_seq_max"] = seq_values_valid.max().item() + metrics[f"{prefix}_seq_min"] = seq_values_valid.min().item() + metrics[f"{prefix}_seq_max_deviation"] = (seq_values_valid - 0.0).abs().max().item() + metrics[f"{prefix}_seq_fraction_high"] = (seq_values_valid > rollout_rs_threshold).float().mean().item() + if apply_lower_threshold: + metrics[f"{prefix}_seq_fraction_low"] = ( + (seq_values_valid < rollout_rs_threshold_lower).float().mean().item() + ) + else: + metrics[f"{prefix}_seq_mean"] = 0.0 + metrics[f"{prefix}_seq_std"] = 0.0 + metrics[f"{prefix}_seq_max"] = 0.0 + metrics[f"{prefix}_seq_min"] = 0.0 + metrics[f"{prefix}_seq_max_deviation"] = 0.0 + metrics[f"{prefix}_seq_fraction_high"] = 0.0 + metrics[f"{prefix}_seq_fraction_low"] = 0.0 + + return metrics + + +def compute_rollout_correction_weights( + log_ratio: torch.Tensor, + response_mask: torch.Tensor, + rollout_is: str = "token", + rollout_is_threshold: float = 2.0, + rollout_is_batch_normalize: bool = False, +) -> tuple[torch.Tensor, dict[str, float]]: + """Compute importance sampling weights to correct for off-policy distribution shifts. + + This function calculates IS weights (π_train / π_rollout) using log ratios for numerical stability. + It supports multiple aggregation levels and truncates extreme weights to prevent training instability. + + Key design: + - Log-space computations to avoid overflow + - Truncation of extreme weights (TIS: Truncated Importance Sampling) + - Optional batch normalization (normalize to mean=1.0) + - Metrics tracking for weight distribution analysis + + Args: + log_ratio: Log ratio of training policy probability to rollout policy probability, + shape (batch_size, seq_length). + response_mask: Binary mask for valid tokens (1=valid, 0=padding), + shape (batch_size, seq_length). + rollout_is: IS weight aggregation level, must be one of: + - "token": Per-token weights (biased, low variance) + - "sequence": Per-sequence weight (product of tokens; unbiased, high variance) + rollout_is_threshold: Upper threshold for truncating extreme weights (e.g., 2.0), + default 2.0. + rollout_is_batch_normalize: Whether to normalize IS weights to have mean=1.0 per batch, + default False. + + Returns: + Tuple containing: + rollout_is_weights: Truncated IS weights (masked to zero for padding tokens), + shape (batch_size, seq_length). If batch_normalize=True, normalized to mean=1.0. + metrics: Dictionary of IS weight metrics (all scalars), including: + - rollout_is_mean/max/min: Statistic of weights (before batch normalization) + - rollout_is_eff_sample_size: Effective sample size (ESS) + - rollout_is_seq_*: Sequence-level weight statistics + - rollout_is_batch_norm_factor: Normalization factor (only if batch_normalize=True) + """ + # Validate input parameters + valid_is_levels = {"token", "sequence"} + if rollout_is not in valid_is_levels: + raise ValueError(f"Invalid rollout_is: {rollout_is}. Must be one of {valid_is_levels}.") + if rollout_is_threshold <= 0: + raise ValueError(f"rollout_is_threshold must be positive, got {rollout_is_threshold}.") + + # Compute IS weights from log ratio (handles different aggregation levels) + if rollout_is == "token": + # Per-token IS weight: exp(log(π_train/π_rollout)) with safety clamp + log_ratio_for_metrics: torch.Tensor = log_ratio + log_ratio_safe: torch.Tensor = torch.clamp(log_ratio, min=-SAFETY_BOUND, max=SAFETY_BOUND) + rollout_is_weights: torch.Tensor = torch.exp(log_ratio_safe) + + elif rollout_is == "sequence": + # Sequence-level IS weight: product of token ratios (exp(sum(log ratios))) + log_ratio_sum: torch.Tensor = verl_F.masked_sum(log_ratio, response_mask, axis=-1).unsqueeze( + -1 + ) # Shape: (batch_size, 1) + log_ratio_for_metrics = log_ratio_sum + + log_ratio_sum_safe: torch.Tensor = torch.clamp(log_ratio_sum, min=-SAFETY_BOUND, max=SAFETY_BOUND) + rollout_is_weights = torch.exp(log_ratio_sum_safe).expand_as(log_ratio) # Broadcast to sequence length + + else: + raise ValueError(f"Unsupported rollout_is: {rollout_is}") + + # Zero out weights for padding tokens using response mask + rollout_is_weights = rollout_is_weights * response_mask + + # Compute IS weight metrics (BEFORE truncation to get accurate fraction_high/low) + metrics: dict[str, float] = compute_is_metrics( + rollout_is_weights=rollout_is_weights, + log_ratio_for_metrics=log_ratio_for_metrics, + response_mask=response_mask, + rollout_is=rollout_is, + rollout_is_threshold=rollout_is_threshold, + ) + + # Truncate extreme weights (TIS: Truncated Importance Sampling) + rollout_is_weights = rollout_is_weights.clamp(max=rollout_is_threshold) + + # Detach weights to prevent gradient flow (mathematically required by IS theory) + # IS weights change the measure, not the objective. See §3.2.2 in docs/algo/rollout_corr_math.md + rollout_is_weights = rollout_is_weights.detach() + + # Apply batch normalization if requested + if rollout_is_batch_normalize: + # Compute mean based on aggregation level + mask_float = response_mask.to(dtype=rollout_is_weights.dtype) + if rollout_is == "token": + # Token-level: normalize over all token weights + if torch.distributed.is_available() and torch.distributed.is_initialized(): + weights_mean = verl_F.distributed_masked_mean(rollout_is_weights, mask_float) + else: + weights_mean = verl_F.masked_mean(rollout_is_weights, response_mask) + elif rollout_is == "sequence": + # Sequence-level: normalize over sequence weights (one weight per sequence) + # For each sequence, compute mean over valid tokens (they all have the same weight) + # then average across sequences + seq_weights = verl_F.masked_mean(rollout_is_weights, response_mask, axis=-1) # (batch_size,) + seq_mask = (response_mask.sum(dim=-1) > 0).to(dtype=rollout_is_weights.dtype) + if torch.distributed.is_available() and torch.distributed.is_initialized(): + weights_mean = verl_F.distributed_masked_mean(seq_weights, seq_mask) + else: + weights_mean = (seq_weights * seq_mask).sum() / seq_mask.sum().clamp_min(1e-8) + else: + raise ValueError(f"Unsupported rollout_is: {rollout_is}") + + # Normalize to mean=1.0 (avoid division by zero) + if weights_mean > 1e-8: + rollout_is_weights = rollout_is_weights / weights_mean + metrics["rollout_is_batch_norm_factor"] = weights_mean.item() + else: + metrics["rollout_is_batch_norm_factor"] = 1.0 + + return rollout_is_weights, metrics + + +def compute_is_metrics( + rollout_is_weights: torch.Tensor, + log_ratio_for_metrics: torch.Tensor, + response_mask: torch.Tensor, + rollout_is: str, + rollout_is_threshold: float, +) -> dict[str, float]: + """Compute comprehensive metrics for truncated importance sampling weights. + + This function calculates statistics for truncated IS weights (TIS), using log-space + for accurate threshold checks and clamped weights for stable mean/std calculations. + + Args: + rollout_is_weights: Truncated IS weights (π_train / π_rollout), + shape (batch_size, seq_length). + log_ratio_for_metrics: Log ratio of training to rollout probabilities (unclamped), + shape varies by aggregation level. + response_mask: Binary mask for valid tokens (1=valid, 0=padding), + shape (batch_size, seq_length). + rollout_is: IS weight aggregation level (matches compute_rollout_correction_weights). + rollout_is_threshold: Upper threshold for truncated IS weights. + + Returns: + Dictionary of IS weight metrics (all scalars). + """ + if not response_mask.any(): + raise ValueError("response_mask must contain at least one valid token (1).") + + metrics: dict[str, float] = {} + device: torch.device = rollout_is_weights.device + # Default lower threshold (reciprocal of upper threshold) + rollout_is_threshold_lower: float = 1.0 / rollout_is_threshold + + # Precompute log thresholds for accurate checks + log_threshold_upper: torch.Tensor = torch.log(torch.tensor(rollout_is_threshold, device=device)) + log_threshold_lower: torch.Tensor = torch.log(torch.tensor(rollout_is_threshold_lower, device=device)) + + # Compute metrics based on aggregation level + if rollout_is == "sequence": + # Sequence-level aggregation: use log-space for unclamped stats + log_max: torch.Tensor = log_ratio_for_metrics.max() + log_min: torch.Tensor = log_ratio_for_metrics.min() + metrics["rollout_is_max"] = torch.exp(torch.clamp(log_max, max=SAFETY_BOUND)).item() + metrics["rollout_is_min"] = torch.exp(log_min).item() + + # Mean uses truncated weights to avoid overflow + metrics["rollout_is_mean"] = verl_F.masked_mean(rollout_is_weights, response_mask).item() + + # Fraction of weights exceeding thresholds (log-space for accuracy) + exceeds_upper: torch.Tensor = log_ratio_for_metrics > log_threshold_upper + below_lower: torch.Tensor = log_ratio_for_metrics < log_threshold_lower + metrics["rollout_is_ratio_fraction_high"] = exceeds_upper.float().mean().item() + metrics["rollout_is_ratio_fraction_low"] = below_lower.float().mean().item() + + else: # token-level + # Token-level aggregation: compute directly from truncated weights + metrics["rollout_is_mean"] = verl_F.masked_mean(rollout_is_weights, response_mask).item() + + # Fraction of tokens exceeding thresholds + rollout_is_above_threshold: torch.Tensor = rollout_is_weights > rollout_is_threshold + rollout_is_below_threshold: torch.Tensor = rollout_is_weights < rollout_is_threshold_lower + metrics["rollout_is_ratio_fraction_high"] = verl_F.masked_mean( + rollout_is_above_threshold.float(), response_mask + ).item() + metrics["rollout_is_ratio_fraction_low"] = verl_F.masked_mean( + rollout_is_below_threshold.float(), response_mask + ).item() + + # Max/min (mask out padding tokens) + mask_bool: torch.Tensor = response_mask.bool() + metrics["rollout_is_max"] = rollout_is_weights.masked_fill(~mask_bool, float("-inf")).max().item() + metrics["rollout_is_min"] = rollout_is_weights.masked_fill(~mask_bool, float("inf")).min().item() + + # Compute standard deviation (using clamped weights for stability) + mask_count: torch.Tensor = response_mask.sum() + if mask_count > 1: + weights_for_std: torch.Tensor = rollout_is_weights.clamp( + min=rollout_is_threshold_lower, max=rollout_is_threshold + ) + mean_clamped: torch.Tensor = verl_F.masked_mean(weights_for_std, response_mask) + rollout_is_var: torch.Tensor = ( + verl_F.masked_mean(weights_for_std.square(), response_mask) - mean_clamped.square() + ) + metrics["rollout_is_std"] = torch.sqrt(torch.clamp(rollout_is_var, min=0.0)).item() + else: + metrics["rollout_is_std"] = 0.0 + + # Compute Effective Sample Size (ESS) for truncated weights + weights_for_ess: torch.Tensor = rollout_is_weights.clamp(min=rollout_is_threshold_lower, max=rollout_is_threshold) + mean_for_ess: torch.Tensor = verl_F.masked_mean(weights_for_ess, response_mask) + is_weights_normalized: torch.Tensor = weights_for_ess / (mean_for_ess + 1e-8) # Avoid division by zero + metrics["rollout_is_eff_sample_size"] = ( + 1.0 / verl_F.masked_mean(is_weights_normalized.square(), response_mask).item() + ) + + # Add sequence-level metrics if weights have batch dimension + if rollout_is_weights.dim() > 1: + seq_mean_weights: torch.Tensor = verl_F.masked_mean(rollout_is_weights, response_mask, axis=-1) + + metrics["rollout_is_seq_mean"] = seq_mean_weights.mean().item() + metrics["rollout_is_seq_std"] = seq_mean_weights.std().item() if seq_mean_weights.numel() > 1 else 0.0 + metrics["rollout_is_seq_max"] = seq_mean_weights.max().item() + metrics["rollout_is_seq_min"] = seq_mean_weights.min().item() + + # Sequence deviation from ideal weight (1.0) + seq_deviation: torch.Tensor = (seq_mean_weights - 1.0).abs() + metrics["rollout_is_seq_max_deviation"] = seq_deviation.max().item() + + # Fraction of sequences with extreme weights + metrics["rollout_is_seq_fraction_high"] = (seq_mean_weights > rollout_is_threshold).float().mean().item() + metrics["rollout_is_seq_fraction_low"] = (seq_mean_weights < rollout_is_threshold_lower).float().mean().item() + + return metrics + + +def compute_rollout_correction_and_rejection_mask( + old_log_prob: torch.Tensor, + rollout_log_prob: torch.Tensor, + response_mask: torch.Tensor, + rollout_is: Optional[str] = None, + rollout_is_threshold: Optional[float] = 2.0, + rollout_is_batch_normalize: bool = False, + rollout_rs: Optional[str] = None, + rollout_rs_threshold: Optional[str | float] = None, +) -> tuple[Optional[DataProto], torch.Tensor, dict[str, float]]: + """Unified interface for computing IS weights and rejection masks. + + This function combines IS weight calculation (truncated) and rejection sampling (masked) + into a single pipeline. + + Key design: + - Separation of IS weights (for variance reduction) and rejection masks (for sample filtering) + - Comprehensive metrics tracking for mismatch diagnosis + + Args: + old_log_prob: Log probabilities from the training policy (e.g., FSDP FP32), + shape (batch_size, seq_length). + rollout_log_prob: Log probabilities from the rollout policy (e.g., vLLM BF16), + shape (batch_size, seq_length). + response_mask: Binary mask for valid tokens (1=valid, 0=padding), + shape (batch_size, seq_length). + rollout_is: IS weight aggregation level (see compute_rollout_correction_weights for options). + Set to None to disable IS weight computation. + rollout_is_threshold: Upper threshold for truncated IS weights (used if rollout_is is set), + default 2.0. + rollout_rs: Rejection sampling aggregation modes as a comma separated string + (see compute_rollout_rejection_mask for the full list). Set to None to disable + rejection sampling. + rollout_rs_threshold: Threshold specification string (see compute_rollout_rejection_mask for details). + Provide one threshold per option (comma separated). For K1-style options, specify + ``lower_upper`` to denote the lower/upper ratio bounds. + rollout_is_batch_normalize: Whether to normalize IS weights to have mean=1.0 per batch. + Default: False. + + Returns: + Tuple containing: + rollout_is_weights_proto: DataProto with IS weights (None if rollout_is is None), + key "rollout_is_weights", shape (batch_size, seq_length). + modified_response_mask: Response mask with rejection sampling applied, + shape (batch_size, seq_length). + metrics: Dictionary of all metrics (prefixed with "rollout_corr/"), including: + - IS weight statistics + - Rejection sampling rates + - Policy mismatch metrics (KL, PPL, etc.) + """ + # Validate input masks + if not response_mask.any(): + raise ValueError("response_mask must contain at least one valid token (1).") + if old_log_prob.shape != rollout_log_prob.shape: + raise ValueError( + f"old_log_prob shape {old_log_prob.shape} does not match rollout_log_prob shape {rollout_log_prob.shape}." + ) + if old_log_prob.shape != response_mask.shape: + raise ValueError( + f"log_prob shape {old_log_prob.shape} does not match response_mask shape {response_mask.shape}." + ) + + # Step 1: Compute log ratio (log(π_train / π_rollout)) + log_ratio: torch.Tensor = old_log_prob - rollout_log_prob + metrics: dict[str, float] = {} + + # Step 2: Compute IS weights (if enabled) + rollout_is_weights: Optional[torch.Tensor] = None + if rollout_is is not None and rollout_is_threshold is not None: + rollout_is_weights, is_metrics = compute_rollout_correction_weights( + log_ratio=log_ratio, + response_mask=response_mask, + rollout_is=rollout_is, + rollout_is_threshold=rollout_is_threshold, + rollout_is_batch_normalize=rollout_is_batch_normalize, + ) + metrics.update(is_metrics) + + # Step 3: Compute rejection mask (if enabled) + modified_response_mask: torch.Tensor = response_mask.clone() + if rollout_rs is not None: + if rollout_rs_threshold is None: + raise ValueError( + "rollout_rs_threshold must be explicitly provided when rollout_rs is enabled. " + "Set rollout_rs_threshold to the desired threshold value." + ) + modified_response_mask, rs_metrics = compute_rollout_rejection_mask( + log_ratio=log_ratio, + response_mask=response_mask, + rollout_rs=rollout_rs, + rollout_rs_threshold=rollout_rs_threshold, + ) + metrics.update(rs_metrics) + + # Step 4: Compute off-policy metrics (KL, PPL, χ², etc.) + offpolicy_metrics: dict[str, float] = compute_offpolicy_metrics( + old_log_prob=old_log_prob, + rollout_log_prob=rollout_log_prob, + response_mask=response_mask, + ) + metrics.update(offpolicy_metrics) + + # Step 6: Add "rollout_corr/" prefix to all metrics for logging consistency + metrics_scalar: dict[str, float] = {} + for key, value in metrics.items(): + if isinstance(value, torch.Tensor): + metrics_scalar[f"rollout_corr/{key}"] = value.item() + else: + metrics_scalar[f"rollout_corr/{key}"] = value + + # Step 7: Wrap IS weights in DataProto for consistency with API + rollout_is_weights_proto: Optional[DataProto] = None + if rollout_is_weights is not None: + rollout_is_weights_proto = DataProto.from_dict(tensors={"rollout_is_weights": rollout_is_weights}) + + return rollout_is_weights_proto, modified_response_mask, metrics_scalar + + +def compute_offpolicy_metrics( + old_log_prob: torch.Tensor, + rollout_log_prob: Optional[torch.Tensor], + response_mask: torch.Tensor, +) -> dict[str, Any]: + """Compute off-policy diagnostic metrics (helper function). + + This helper function operates on raw tensors and is used internally by: + - compute_rollout_correction_and_rejection_mask() in this module (automatically included) + - Tests (test_rollout_corr.py, test_rollout_corr_integration.py) + + These metrics help diagnose the off-policy gap between rollout and training policies, + which can arise from: + - Policy mismatch (e.g., vLLM BF16 vs FSDP FP32) + - Model staleness (training on trajectories from older checkpoints) + - General distribution shifts + + Key metrics: + - kl: Direct KL divergence estimator KL(π_rollout || π_training) + - k3_kl: K3 KL estimator for stability (more stable for small KL) + - training_ppl: Perplexity of training policy + - rollout_ppl: Perplexity of rollout policy + - log_ppl_diff: Difference in log perplexities + - ppl_ratio: Ratio of training PPL to rollout PPL + - chi2_token: Token-level χ² divergence E[ρ²] - 1 + - chi2_seq: Sequence-level χ² divergence E[(∏ρ_t)²] - 1 + + Args: + old_log_prob: Log probabilities from training policy, shape (batch_size, seq_length) + rollout_log_prob: Log probabilities from rollout policy, shape (batch_size, seq_length) + response_mask: Mask for valid tokens, shape (batch_size, seq_length) + + Returns: + Dictionary of off-policy metrics (without prefix) + """ + # Validate that we have at least one valid token + assert response_mask.any(), "Expected at least one valid token in response_mask" + + metrics = {} + + # 1. Training policy perplexity (always available) + # Formula: exp(-1/|T| * Σ log π_training(y_t|y_ tuple[DataProto, dict]: + """Compute rollout correction weights and apply rejection sampling. + + Computes importance sampling weights to correct for off-policy issues between + rollout and training policies. Applies rejection sampling by modifying response_mask. + Always updates response_mask; conditionally adds IS weights. + + Key behavior: + - response_mask: ALWAYS updated with rejection (RS exclusions removed from training) + - rollout_is_weights: Added to batch ONLY if rollout_is parameter is set + + This separation ensures: + - Rejection works independently of IS weight application + - Metrics can be monitored before enabling IS weight correction + + Args: + batch: DataProto with old_log_probs, rollout_log_probs, response_mask + + Returns: + Tuple of (updated_batch, metrics): + updated_batch: Batch with modified response_mask (always) and rollout_is_weights (if enabled) + metrics: Dict of IS and off-policy metrics, all with "rollout_corr/" prefix + + Note: + The implementation is copied from szrlee . + """ + # Get new API parameters directly from config + rollout_is = rollout_corr_config.get("rollout_is", None) + rollout_is_threshold = rollout_corr_config.get("rollout_is_threshold", 2.0) + rollout_is_batch_normalize = rollout_corr_config.get("rollout_is_batch_normalize", False) + rollout_rs = rollout_corr_config.get("rollout_rs", None) + rollout_rs_threshold = rollout_corr_config.get("rollout_rs_threshold", None) + + # Compute IS weights and get modified response_mask + rollout_is_weights, modified_response_mask, rollout_corr_metrics = compute_rollout_correction_and_rejection_mask( + old_log_prob=batch.batch["old_log_probs"], + rollout_log_prob=batch.batch["rollout_log_probs"], + response_mask=batch.batch["response_mask"], + rollout_is=rollout_is, + rollout_is_threshold=rollout_is_threshold, + rollout_is_batch_normalize=rollout_is_batch_normalize, + rollout_rs=rollout_rs, + rollout_rs_threshold=rollout_rs_threshold, + ) + + # ALWAYS update response_mask with rejection applied + batch.batch["response_mask"] = modified_response_mask + + # Add IS weights to batch if computed + if rollout_is_weights is not None: + batch = batch.union(rollout_is_weights) + + return batch, rollout_corr_metrics + + +def compute_rollout_corr_metrics_from_logprobs( + log_prob: torch.Tensor, + rollout_log_prob: torch.Tensor, + response_mask: torch.Tensor, +) -> dict[str, float]: + """Compute rollout correction metrics from log probabilities during training. + + This function is used in the actor to compute metrics using the CURRENT policy + log probabilities versus rollout log probabilities, allowing tracking of the + off-policy gap as training progresses. + + It computes off-policy diagnostic metrics (KL, PPL, χ²) from log probabilities. + + Args: + log_prob: Current policy log probabilities, shape (batch_size, seq_length) + rollout_log_prob: Rollout policy log probabilities, shape (batch_size, seq_length) + response_mask: Valid token mask, shape (batch_size, seq_length) + + Returns: + Dictionary of metrics with "rollout_corr/" prefix + """ + # Compute off-policy diagnostic metrics + offpolicy_metrics = compute_offpolicy_metrics( + old_log_prob=log_prob, + rollout_log_prob=rollout_log_prob, + response_mask=response_mask, + ) + + # Add rollout_corr/ prefix to all metrics + metrics_with_prefix = {} + for key, value in offpolicy_metrics.items(): + if isinstance(value, torch.Tensor): + metrics_with_prefix[f"rollout_corr/{key}"] = value.item() + else: + metrics_with_prefix[f"rollout_corr/{key}"] = value + + return metrics_with_prefix + + +def apply_bypass_mode( + batch: DataProto, + rollout_corr_config: Optional[RolloutCorrectionConfig] = None, + policy_loss_config: PolicyLossConfig = None, +) -> None: + """ + Setup bypass mode: Use rollout_log_probs as old_log_probs. + + Bypass mode skips expensive actor forward pass for old_log_prob computation + by setting old_log_probs = rollout_log_probs (2 policies instead of 3). + + Uses compute_policy_loss_bypass_mode() which supports: + - loss_type="ppo_clip" (default): PPO clipped objective (IS handled by ratio) + - loss_type="reinforce": REINFORCE with explicit IS weights + + Both loss types benefit from rejection sampling (RS) which masks out-of-distribution samples. + + Note: + The implementation is copied from szrlee . + """ + from omegaconf import open_dict + + if "rollout_log_probs" not in batch.batch: + raise ValueError( + "bypass_mode=True requires rollout_log_probs in batch. " + "Ensure rollout worker is configured to calculate_log_probs=true." + ) + + # Use rollout log probs as old log probs (zero-cost substitution) + batch.batch["old_log_probs"] = batch.batch["rollout_log_probs"] + + with open_dict(policy_loss_config): + # Pass rollout_correction config to actor for loss computation and metrics + policy_loss_config["rollout_correction"] = rollout_corr_config + # Always use bypass_mode loss function which handles both loss_types + policy_loss_config["loss_mode"] = "bypass_mode" diff --git a/code/RL_model/verl/verl_train/verl/trainer/ppo/utils.py b/code/RL_model/verl/verl_train/verl/trainer/ppo/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..82903edfe4764f09c6c4a8ef541d8423c38e468c --- /dev/null +++ b/code/RL_model/verl/verl_train/verl/trainer/ppo/utils.py @@ -0,0 +1,97 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import warnings +from enum import Enum + +from omegaconf import DictConfig + +from verl.single_controller.base import Worker +from verl.trainer.ppo.core_algos import AdvantageEstimator + +WorkerType = type[Worker] + + +class Role(Enum): + """ + To create more roles dynamically, you can subclass Role and add new members + """ + + Actor = 0 + Rollout = 1 + ActorRollout = 2 + Critic = 3 + RefPolicy = 4 + RewardModel = 5 + ActorRolloutRef = 6 + Env = 7 + + def __str__(self): + return self._get_role_string() + + def _get_role_string(self): + role_mapping = { + Role.Actor: "actor", + Role.Rollout: "rollout", + Role.ActorRollout: "actor_rollout", + Role.Critic: "critic", + Role.RefPolicy: "ref", + Role.RewardModel: "rm", + Role.ActorRolloutRef: "actor_rollout_ref", + } + return role_mapping.get(self, self.name.lower()) + + @classmethod + def from_string(cls, name: str): + string_mapping = { + "actor": cls.Actor, + "rollout": cls.Rollout, + "actor_rollout": cls.ActorRollout, + "critic": cls.Critic, + "ref": cls.RefPolicy, + "rm": cls.RewardModel, + "actor_rollout_ref": cls.ActorRolloutRef, + } + role = string_mapping.get(name.lower()) + if role is None: + raise ValueError(f"No Role found for string: {name}") + return role + + +def need_reference_policy( + config: DictConfig, +) -> bool: + """Given the config, do we need ref policy.""" + return config.algorithm.use_kl_in_reward or config.actor_rollout_ref.actor.use_kl_loss + + +def need_reward_model( + role_worker_mapping: dict[Role, WorkerType], +) -> bool: + """Given a role worker mapping, do we need reward model.""" + return Role.RewardModel in role_worker_mapping + + +def need_critic(config: DictConfig) -> bool: + """Given a config, do we need critic.""" + if config.critic.enable is not None: + return bool(config.critic.enable) + elif config.algorithm.adv_estimator == AdvantageEstimator.GAE: + return True + else: + warnings.warn( + "Disabled critic as algorithm.adv_estimator != gae. If it is not intended, please set critic.enable=True", + stacklevel=2, + ) + return False