Add files using upload-large-folder tool
Browse filesThis view is limited to 50 files because it contains too many changes. See raw diff
- code/RL_model/verl/Search-R1/dataset/data_prep.py +88 -0
- code/RL_model/verl/Search-R1/dataset/prompt +58 -0
- code/RL_model/verl/Search-R1/outputs/2026-02-01/20-26-44/main_ppo.log +0 -0
- code/RL_model/verl/Search-R1/search_r1/__init__.py +0 -0
- code/RL_model/verl/Search-R1/verl.egg-info/PKG-INFO +507 -0
- code/RL_model/verl/Search-R1/verl.egg-info/dependency_links.txt +1 -0
- code/RL_model/verl/Search-R1/verl.egg-info/requires.txt +15 -0
- code/RL_model/verl/Search-R1/verl.egg-info/top_level.txt +2 -0
- code/RL_model/verl/Search-R1/verl/__init__.py +27 -0
- code/RL_model/verl/Search-R1/verl/protocol.py +639 -0
- code/RL_model/verl/Search-R1/wandb/debug-internal.log +6 -0
- code/RL_model/verl/Search-R1/wandb/debug.log +21 -0
- code/RL_model/verl/verl_train/tests/experimental/agent_loop/agent_utils.py +92 -0
- code/RL_model/verl/verl_train/tests/experimental/agent_loop/qwen_vl_tool_chat_template.jinja2 +150 -0
- code/RL_model/verl/verl_train/tests/experimental/agent_loop/test_basic_agent_loop.py +454 -0
- code/RL_model/verl/verl_train/tests/experimental/agent_loop/test_gpt_oss_tool_parser.py +34 -0
- code/RL_model/verl/verl_train/tests/experimental/agent_loop/test_multi_modal.py +570 -0
- code/RL_model/verl/verl_train/tests/experimental/agent_loop/test_standalone_rollout.py +157 -0
- code/RL_model/verl/verl_train/tests/experimental/reward_loop/test_agent_loop_reward_manager.py +111 -0
- code/RL_model/verl/verl_train/tests/experimental/reward_loop/test_agent_reward_loop_colocate.py +168 -0
- code/RL_model/verl/verl_train/tests/experimental/reward_loop/test_async_token_bucket_on_cpu.py +267 -0
- code/RL_model/verl/verl_train/tests/experimental/reward_loop/test_math_verify.py +100 -0
- code/RL_model/verl/verl_train/tests/experimental/reward_loop/test_rate_limited_reward_manager_on_cpu.py +528 -0
- code/RL_model/verl/verl_train/tests/experimental/reward_loop/test_reward_model_disrm.py +153 -0
- code/RL_model/verl/verl_train/tests/experimental/vla/test_sim_envs.py +101 -0
- code/RL_model/verl/verl_train/tests/single_controller/base/test_decorator.py +76 -0
- code/RL_model/verl/verl_train/tests/single_controller/check_worker_alive/main.py +64 -0
- code/RL_model/verl/verl_train/tests/single_controller/detached_worker/README.md +14 -0
- code/RL_model/verl/verl_train/tests/single_controller/detached_worker/client.py +56 -0
- code/RL_model/verl/verl_train/tests/single_controller/detached_worker/run.sh +5 -0
- code/RL_model/verl/verl_train/tests/single_controller/detached_worker/server.py +152 -0
- code/RL_model/verl/verl_train/tests/special_e2e/envs/__init__.py +17 -0
- code/RL_model/verl/verl_train/tests/special_e2e/envs/digit_completion/__init__.py +22 -0
- code/RL_model/verl/verl_train/tests/special_e2e/envs/digit_completion/task.py +179 -0
- code/RL_model/verl/verl_train/tests/special_e2e/envs/digit_completion/tokenizer.py +155 -0
- code/RL_model/verl/verl_train/tests/special_e2e/generation/run_gen_qwen05.sh +26 -0
- code/RL_model/verl/verl_train/tests/special_e2e/generation/run_gen_qwen05_server.sh +26 -0
- code/RL_model/verl/verl_train/tests/special_e2e/ppo_trainer/expert_parallel/qwen2moe_minimal.json +4 -0
- code/RL_model/verl/verl_train/tests/special_e2e/ppo_trainer/expert_parallel/qwen3moe_minimal.json +4 -0
- code/RL_model/verl/verl_train/tests/special_e2e/ppo_trainer/run_function_reward.sh +165 -0
- code/RL_model/verl/verl_train/tests/special_e2e/ppo_trainer/run_model_reward.sh +101 -0
- code/RL_model/verl/verl_train/tests/special_e2e/ppo_trainer/run_single_gpu.sh +24 -0
- code/RL_model/verl/verl_train/tests/special_e2e/ppo_trainer/run_single_gpu_with_engine.sh +25 -0
- code/RL_model/verl/verl_train/tests/special_e2e/sft/compare_sft_engine_results.py +58 -0
- code/RL_model/verl/verl_train/tests/special_e2e/sft/run_sft.sh +63 -0
- code/RL_model/verl/verl_train/tests/special_e2e/sft/run_sft_engine.sh +134 -0
- code/RL_model/verl/verl_train/tests/special_e2e/sft/test_sft_engine_all.sh +42 -0
- code/RL_model/verl/verl_train/tests/special_e2e/sft/test_sp_loss_match.py +150 -0
- code/RL_model/verl/verl_train/tests/trainer/config/__init__.py +13 -0
- code/RL_model/verl/verl_train/tests/utils/ckpt/test_checkpoint_cleanup_on_cpu.py +139 -0
code/RL_model/verl/Search-R1/dataset/data_prep.py
ADDED
|
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import json
|
| 3 |
+
import datasets
|
| 4 |
+
import argparse
|
| 5 |
+
from verl.utils.hdfs_io import copy, makedirs
|
| 6 |
+
|
| 7 |
+
# 1. Define the exact Prompt Template from your requirements
|
| 8 |
+
# /home/mshahidul/readctrl/code/RL_model/verl/verl_train/dataset/prompt
|
| 9 |
+
with open("/home/mshahidul/readctrl/code/RL_model/verl/verl_train/dataset/prompt", 'r') as f:
|
| 10 |
+
PROMPT_TEMPLATE = f.read()
|
| 11 |
+
|
| 12 |
+
def make_map_fn(split, data_source):
|
| 13 |
+
def process_fn(example, idx):
|
| 14 |
+
# Extract fields from your specific JSON keys: ['id', 'fulltext', 'summary']
|
| 15 |
+
full_text = example.pop('fulltext')
|
| 16 |
+
gold_summary = example.pop('summary')
|
| 17 |
+
|
| 18 |
+
# Format the prompt using your template
|
| 19 |
+
# Note: Added 'English' as default source lang based on filename
|
| 20 |
+
prompt_content = PROMPT_TEMPLATE.format(
|
| 21 |
+
source_lang="English",
|
| 22 |
+
gold_summary=gold_summary,
|
| 23 |
+
full_text=full_text
|
| 24 |
+
)
|
| 25 |
+
|
| 26 |
+
return {
|
| 27 |
+
"data_source": data_source,
|
| 28 |
+
"prompt": [{
|
| 29 |
+
"role": "user",
|
| 30 |
+
"content": prompt_content
|
| 31 |
+
}],
|
| 32 |
+
"ability": "summarization",
|
| 33 |
+
"reward_model": {
|
| 34 |
+
"style": "rule",
|
| 35 |
+
"ground_truth": gold_summary
|
| 36 |
+
},
|
| 37 |
+
"extra_info": {
|
| 38 |
+
"split": split,
|
| 39 |
+
"index": idx,
|
| 40 |
+
"original_id": example.get('id', idx)
|
| 41 |
+
}
|
| 42 |
+
}
|
| 43 |
+
return process_fn
|
| 44 |
+
|
| 45 |
+
if __name__ == '__main__':
|
| 46 |
+
parser = argparse.ArgumentParser()
|
| 47 |
+
# Path to your input JSON
|
| 48 |
+
parser.add_argument('--input_path', default='/home/mshahidul/readctrl/data/processed_test_raw_data/multiclinsum_test_en.json')
|
| 49 |
+
# Updated destination as requested
|
| 50 |
+
parser.add_argument('--local_dir', default='/home/mshahidul/readctrl/code/RL_model/verl/Search-R1/dataset')
|
| 51 |
+
args = parser.parse_args()
|
| 52 |
+
|
| 53 |
+
data_source = 'multiclinsum'
|
| 54 |
+
|
| 55 |
+
# Load your local JSON file
|
| 56 |
+
with open(args.input_path, 'r') as f:
|
| 57 |
+
raw_data = json.load(f)
|
| 58 |
+
|
| 59 |
+
# Convert to HuggingFace Dataset
|
| 60 |
+
dataset = datasets.Dataset.from_list(raw_data)
|
| 61 |
+
|
| 62 |
+
# Split into train/test (95% train, 5% test)
|
| 63 |
+
split_dataset = dataset.train_test_split(test_size=0.05, seed=42)
|
| 64 |
+
|
| 65 |
+
# Apply the mapping transformation for each split
|
| 66 |
+
processed_train = split_dataset["train"].map(
|
| 67 |
+
function=make_map_fn('train', data_source),
|
| 68 |
+
with_indices=True
|
| 69 |
+
)
|
| 70 |
+
processed_test = split_dataset["test"].map(
|
| 71 |
+
function=make_map_fn('test', data_source),
|
| 72 |
+
with_indices=True
|
| 73 |
+
)
|
| 74 |
+
|
| 75 |
+
# Create the directory if it doesn't exist
|
| 76 |
+
os.makedirs(args.local_dir, exist_ok=True)
|
| 77 |
+
|
| 78 |
+
# Save to Parquet in the specified location
|
| 79 |
+
train_output_path = os.path.join(args.local_dir, 'train.parquet')
|
| 80 |
+
test_output_path = os.path.join(args.local_dir, 'test.parquet')
|
| 81 |
+
processed_train.to_parquet(train_output_path)
|
| 82 |
+
processed_test.to_parquet(test_output_path)
|
| 83 |
+
|
| 84 |
+
print(f"--- Dataset Preparation Complete ---")
|
| 85 |
+
print(f"Train file saved to: {train_output_path}")
|
| 86 |
+
print(f"Test file saved to: {test_output_path}")
|
| 87 |
+
print(f"Total train records: {len(processed_train)}")
|
| 88 |
+
print(f"Total test records: {len(processed_test)}")
|
code/RL_model/verl/Search-R1/dataset/prompt
ADDED
|
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
**System Role:**
|
| 2 |
+
|
| 3 |
+
You are an expert medical editor and Health Literacy specialist. Your task is to transform complex medical text into three distinct versions based on the reader's health literacy level. You must maintain the source language of the input while adjusting the linguistic complexity. Use the provided Gold Summary as the factual anchor to ensure the simplified versions remain accurate and focused on the most important information.
|
| 4 |
+
|
| 5 |
+
**User Prompt:**
|
| 6 |
+
|
| 7 |
+
Please process the following medical Source Text and its corresponding Gold Summary to generate three versions tailored to different health literacy levels.
|
| 8 |
+
### Instructions for Each Level:
|
| 9 |
+
|
| 10 |
+
1. Level: Low Health Literacy (High Readability)
|
| 11 |
+
|
| 12 |
+
Target: Individuals needing the simplest terms for immediate action.
|
| 13 |
+
|
| 14 |
+
Linguistic Goal: Use "living room" language. Replace all medical jargon with functional descriptions (e.g., "renal" becomes "kidney").
|
| 15 |
+
|
| 16 |
+
Information Density: Focus strictly on the "need-to-know" info found in the Gold Summary.
|
| 17 |
+
|
| 18 |
+
Strategy: High paraphrasing using analogies. One idea per sentence.
|
| 19 |
+
|
| 20 |
+
Faithfulness: Must align perfectly with the Gold Summary.
|
| 21 |
+
|
| 22 |
+
2. Level: Intermediate Health Literacy (Medium Readability)
|
| 23 |
+
|
| 24 |
+
Target: The general public (news-reading level).
|
| 25 |
+
|
| 26 |
+
Linguistic Goal: Standard vocabulary. Common medical terms are okay, but technical "doctor-speak" must be simplified.
|
| 27 |
+
|
| 28 |
+
Information Density: Balanced. Use the Gold Summary as the lead, supplemented by necessary context from the Source Text.
|
| 29 |
+
|
| 30 |
+
Strategy: Moderate paraphrasing. Remove minor technical details to avoid information overload.
|
| 31 |
+
|
| 32 |
+
Faithfulness: Maintains the main narrative of the Gold Summary.
|
| 33 |
+
|
| 34 |
+
3. Level: Proficient Health Literacy (Low Readability)
|
| 35 |
+
|
| 36 |
+
Target: Researchers, clinicians, or highly informed patients.
|
| 37 |
+
|
| 38 |
+
Linguistic Goal: Technical and academic language. Prioritize clinical nuance and medical accuracy.
|
| 39 |
+
|
| 40 |
+
Information Density: High. Use the Full Source Text to include data, physiological mechanisms, and statistics.
|
| 41 |
+
|
| 42 |
+
Strategy: Minimal paraphrasing. Retain all original technical terminology.
|
| 43 |
+
|
| 44 |
+
Faithfulness: Adhere to the Source Text; you may add related subclaims that provide deeper scientific context.
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
I will provide the following information:
|
| 48 |
+
|
| 49 |
+
- Input Language: <<<SOURCE_LANGUAGE>>>
|
| 50 |
+
- Gold Summary (the anchor reference summary): <<<GOLD_SUMMARY>>>
|
| 51 |
+
- Source Text (detailed content): <<<FULL_TEXT>>>
|
| 52 |
+
|
| 53 |
+
**Output Format (JSON only):**
|
| 54 |
+
{{
|
| 55 |
+
"low_health_literacy": "...",
|
| 56 |
+
"intermediate_health_literacy": "...",
|
| 57 |
+
"proficient_health_literacy": "..."
|
| 58 |
+
}}
|
code/RL_model/verl/Search-R1/outputs/2026-02-01/20-26-44/main_ppo.log
ADDED
|
File without changes
|
code/RL_model/verl/Search-R1/search_r1/__init__.py
ADDED
|
File without changes
|
code/RL_model/verl/Search-R1/verl.egg-info/PKG-INFO
ADDED
|
@@ -0,0 +1,507 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Metadata-Version: 2.4
|
| 2 |
+
Name: verl
|
| 3 |
+
Version: 0.1
|
| 4 |
+
Summary: veRL: Volcano Engine Reinforcement Learning for LLM
|
| 5 |
+
Home-page: https://github.com/volcengine/verl
|
| 6 |
+
Author: Bytedance - Seed - MLSys
|
| 7 |
+
Author-email: Bytedance - Seed - MLSys <zhangchi.usc1992@bytedance.com>, Bytedance - Seed - MLSys <gmsheng@connect.hku.hk>
|
| 8 |
+
License:
|
| 9 |
+
Apache License
|
| 10 |
+
Version 2.0, January 2004
|
| 11 |
+
http://www.apache.org/licenses/
|
| 12 |
+
|
| 13 |
+
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
| 14 |
+
|
| 15 |
+
1. Definitions.
|
| 16 |
+
|
| 17 |
+
"License" shall mean the terms and conditions for use, reproduction,
|
| 18 |
+
and distribution as defined by Sections 1 through 9 of this document.
|
| 19 |
+
|
| 20 |
+
"Licensor" shall mean the copyright owner or entity authorized by
|
| 21 |
+
the copyright owner that is granting the License.
|
| 22 |
+
|
| 23 |
+
"Legal Entity" shall mean the union of the acting entity and all
|
| 24 |
+
other entities that control, are controlled by, or are under common
|
| 25 |
+
control with that entity. For the purposes of this definition,
|
| 26 |
+
"control" means (i) the power, direct or indirect, to cause the
|
| 27 |
+
direction or management of such entity, whether by contract or
|
| 28 |
+
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
| 29 |
+
outstanding shares, or (iii) beneficial ownership of such entity.
|
| 30 |
+
|
| 31 |
+
"You" (or "Your") shall mean an individual or Legal Entity
|
| 32 |
+
exercising permissions granted by this License.
|
| 33 |
+
|
| 34 |
+
"Source" form shall mean the preferred form for making modifications,
|
| 35 |
+
including but not limited to software source code, documentation
|
| 36 |
+
source, and configuration files.
|
| 37 |
+
|
| 38 |
+
"Object" form shall mean any form resulting from mechanical
|
| 39 |
+
transformation or translation of a Source form, including but
|
| 40 |
+
not limited to compiled object code, generated documentation,
|
| 41 |
+
and conversions to other media types.
|
| 42 |
+
|
| 43 |
+
"Work" shall mean the work of authorship, whether in Source or
|
| 44 |
+
Object form, made available under the License, as indicated by a
|
| 45 |
+
copyright notice that is included in or attached to the work
|
| 46 |
+
(an example is provided in the Appendix below).
|
| 47 |
+
|
| 48 |
+
"Derivative Works" shall mean any work, whether in Source or Object
|
| 49 |
+
form, that is based on (or derived from) the Work and for which the
|
| 50 |
+
editorial revisions, annotations, elaborations, or other modifications
|
| 51 |
+
represent, as a whole, an original work of authorship. For the purposes
|
| 52 |
+
of this License, Derivative Works shall not include works that remain
|
| 53 |
+
separable from, or merely link (or bind by name) to the interfaces of,
|
| 54 |
+
the Work and Derivative Works thereof.
|
| 55 |
+
|
| 56 |
+
"Contribution" shall mean any work of authorship, including
|
| 57 |
+
the original version of the Work and any modifications or additions
|
| 58 |
+
to that Work or Derivative Works thereof, that is intentionally
|
| 59 |
+
submitted to Licensor for inclusion in the Work by the copyright owner
|
| 60 |
+
or by an individual or Legal Entity authorized to submit on behalf of
|
| 61 |
+
the copyright owner. For the purposes of this definition, "submitted"
|
| 62 |
+
means any form of electronic, verbal, or written communication sent
|
| 63 |
+
to the Licensor or its representatives, including but not limited to
|
| 64 |
+
communication on electronic mailing lists, source code control systems,
|
| 65 |
+
and issue tracking systems that are managed by, or on behalf of, the
|
| 66 |
+
Licensor for the purpose of discussing and improving the Work, but
|
| 67 |
+
excluding communication that is conspicuously marked or otherwise
|
| 68 |
+
designated in writing by the copyright owner as "Not a Contribution."
|
| 69 |
+
|
| 70 |
+
"Contributor" shall mean Licensor and any individual or Legal Entity
|
| 71 |
+
on behalf of whom a Contribution has been received by Licensor and
|
| 72 |
+
subsequently incorporated within the Work.
|
| 73 |
+
|
| 74 |
+
2. Grant of Copyright License. Subject to the terms and conditions of
|
| 75 |
+
this License, each Contributor hereby grants to You a perpetual,
|
| 76 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
| 77 |
+
copyright license to reproduce, prepare Derivative Works of,
|
| 78 |
+
publicly display, publicly perform, sublicense, and distribute the
|
| 79 |
+
Work and such Derivative Works in Source or Object form.
|
| 80 |
+
|
| 81 |
+
3. Grant of Patent License. Subject to the terms and conditions of
|
| 82 |
+
this License, each Contributor hereby grants to You a perpetual,
|
| 83 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
| 84 |
+
(except as stated in this section) patent license to make, have made,
|
| 85 |
+
use, offer to sell, sell, import, and otherwise transfer the Work,
|
| 86 |
+
where such license applies only to those patent claims licensable
|
| 87 |
+
by such Contributor that are necessarily infringed by their
|
| 88 |
+
Contribution(s) alone or by combination of their Contribution(s)
|
| 89 |
+
with the Work to which such Contribution(s) was submitted. If You
|
| 90 |
+
institute patent litigation against any entity (including a
|
| 91 |
+
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
| 92 |
+
or a Contribution incorporated within the Work constitutes direct
|
| 93 |
+
or contributory patent infringement, then any patent licenses
|
| 94 |
+
granted to You under this License for that Work shall terminate
|
| 95 |
+
as of the date such litigation is filed.
|
| 96 |
+
|
| 97 |
+
4. Redistribution. You may reproduce and distribute copies of the
|
| 98 |
+
Work or Derivative Works thereof in any medium, with or without
|
| 99 |
+
modifications, and in Source or Object form, provided that You
|
| 100 |
+
meet the following conditions:
|
| 101 |
+
|
| 102 |
+
(a) You must give any other recipients of the Work or
|
| 103 |
+
Derivative Works a copy of this License; and
|
| 104 |
+
|
| 105 |
+
(b) You must cause any modified files to carry prominent notices
|
| 106 |
+
stating that You changed the files; and
|
| 107 |
+
|
| 108 |
+
(c) You must retain, in the Source form of any Derivative Works
|
| 109 |
+
that You distribute, all copyright, patent, trademark, and
|
| 110 |
+
attribution notices from the Source form of the Work,
|
| 111 |
+
excluding those notices that do not pertain to any part of
|
| 112 |
+
the Derivative Works; and
|
| 113 |
+
|
| 114 |
+
(d) If the Work includes a "NOTICE" text file as part of its
|
| 115 |
+
distribution, then any Derivative Works that You distribute must
|
| 116 |
+
include a readable copy of the attribution notices contained
|
| 117 |
+
within such NOTICE file, excluding those notices that do not
|
| 118 |
+
pertain to any part of the Derivative Works, in at least one
|
| 119 |
+
of the following places: within a NOTICE text file distributed
|
| 120 |
+
as part of the Derivative Works; within the Source form or
|
| 121 |
+
documentation, if provided along with the Derivative Works; or,
|
| 122 |
+
within a display generated by the Derivative Works, if and
|
| 123 |
+
wherever such third-party notices normally appear. The contents
|
| 124 |
+
of the NOTICE file are for informational purposes only and
|
| 125 |
+
do not modify the License. You may add Your own attribution
|
| 126 |
+
notices within Derivative Works that You distribute, alongside
|
| 127 |
+
or as an addendum to the NOTICE text from the Work, provided
|
| 128 |
+
that such additional attribution notices cannot be construed
|
| 129 |
+
as modifying the License.
|
| 130 |
+
|
| 131 |
+
You may add Your own copyright statement to Your modifications and
|
| 132 |
+
may provide additional or different license terms and conditions
|
| 133 |
+
for use, reproduction, or distribution of Your modifications, or
|
| 134 |
+
for any such Derivative Works as a whole, provided Your use,
|
| 135 |
+
reproduction, and distribution of the Work otherwise complies with
|
| 136 |
+
the conditions stated in this License.
|
| 137 |
+
|
| 138 |
+
5. Submission of Contributions. Unless You explicitly state otherwise,
|
| 139 |
+
any Contribution intentionally submitted for inclusion in the Work
|
| 140 |
+
by You to the Licensor shall be under the terms and conditions of
|
| 141 |
+
this License, without any additional terms or conditions.
|
| 142 |
+
Notwithstanding the above, nothing herein shall supersede or modify
|
| 143 |
+
the terms of any separate license agreement you may have executed
|
| 144 |
+
with Licensor regarding such Contributions.
|
| 145 |
+
|
| 146 |
+
6. Trademarks. This License does not grant permission to use the trade
|
| 147 |
+
names, trademarks, service marks, or product names of the Licensor,
|
| 148 |
+
except as required for reasonable and customary use in describing the
|
| 149 |
+
origin of the Work and reproducing the content of the NOTICE file.
|
| 150 |
+
|
| 151 |
+
7. Disclaimer of Warranty. Unless required by applicable law or
|
| 152 |
+
agreed to in writing, Licensor provides the Work (and each
|
| 153 |
+
Contributor provides its Contributions) on an "AS IS" BASIS,
|
| 154 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
| 155 |
+
implied, including, without limitation, any warranties or conditions
|
| 156 |
+
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
| 157 |
+
PARTICULAR PURPOSE. You are solely responsible for determining the
|
| 158 |
+
appropriateness of using or redistributing the Work and assume any
|
| 159 |
+
risks associated with Your exercise of permissions under this License.
|
| 160 |
+
|
| 161 |
+
8. Limitation of Liability. In no event and under no legal theory,
|
| 162 |
+
whether in tort (including negligence), contract, or otherwise,
|
| 163 |
+
unless required by applicable law (such as deliberate and grossly
|
| 164 |
+
negligent acts) or agreed to in writing, shall any Contributor be
|
| 165 |
+
liable to You for damages, including any direct, indirect, special,
|
| 166 |
+
incidental, or consequential damages of any character arising as a
|
| 167 |
+
result of this License or out of the use or inability to use the
|
| 168 |
+
Work (including but not limited to damages for loss of goodwill,
|
| 169 |
+
work stoppage, computer failure or malfunction, or any and all
|
| 170 |
+
other commercial damages or losses), even if such Contributor
|
| 171 |
+
has been advised of the possibility of such damages.
|
| 172 |
+
|
| 173 |
+
9. Accepting Warranty or Additional Liability. While redistributing
|
| 174 |
+
the Work or Derivative Works thereof, You may choose to offer,
|
| 175 |
+
and charge a fee for, acceptance of support, warranty, indemnity,
|
| 176 |
+
or other liability obligations and/or rights consistent with this
|
| 177 |
+
License. However, in accepting such obligations, You may act only
|
| 178 |
+
on Your own behalf and on Your sole responsibility, not on behalf
|
| 179 |
+
of any other Contributor, and only if You agree to indemnify,
|
| 180 |
+
defend, and hold each Contributor harmless for any liability
|
| 181 |
+
incurred by, or claims asserted against, such Contributor by reason
|
| 182 |
+
of your accepting any such warranty or additional liability.
|
| 183 |
+
|
| 184 |
+
END OF TERMS AND CONDITIONS
|
| 185 |
+
|
| 186 |
+
APPENDIX: How to apply the Apache License to your work.
|
| 187 |
+
|
| 188 |
+
To apply the Apache License to your work, attach the following
|
| 189 |
+
boilerplate notice, with the fields enclosed by brackets "[]"
|
| 190 |
+
replaced with your own identifying information. (Don't include
|
| 191 |
+
the brackets!) The text should be enclosed in the appropriate
|
| 192 |
+
comment syntax for the file format. We also recommend that a
|
| 193 |
+
file or class name and description of purpose be included on the
|
| 194 |
+
same "printed page" as the copyright notice for easier
|
| 195 |
+
identification within third-party archives.
|
| 196 |
+
|
| 197 |
+
Copyright [yyyy] [name of copyright owner]
|
| 198 |
+
|
| 199 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
| 200 |
+
you may not use this file except in compliance with the License.
|
| 201 |
+
You may obtain a copy of the License at
|
| 202 |
+
|
| 203 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
| 204 |
+
|
| 205 |
+
Unless required by applicable law or agreed to in writing, software
|
| 206 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
| 207 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 208 |
+
See the License for the specific language governing permissions and
|
| 209 |
+
limitations under the License.
|
| 210 |
+
|
| 211 |
+
Project-URL: Homepage, https://github.com/volcengine/verl
|
| 212 |
+
Requires-Python: >=3.8
|
| 213 |
+
Description-Content-Type: text/markdown
|
| 214 |
+
License-File: LICENSE
|
| 215 |
+
Requires-Dist: accelerate
|
| 216 |
+
Requires-Dist: codetiming
|
| 217 |
+
Requires-Dist: datasets
|
| 218 |
+
Requires-Dist: dill
|
| 219 |
+
Requires-Dist: hydra-core
|
| 220 |
+
Requires-Dist: numpy
|
| 221 |
+
Requires-Dist: pybind11
|
| 222 |
+
Requires-Dist: ray
|
| 223 |
+
Requires-Dist: tensordict
|
| 224 |
+
Requires-Dist: transformers<4.48
|
| 225 |
+
Requires-Dist: vllm<=0.6.3
|
| 226 |
+
Provides-Extra: test
|
| 227 |
+
Requires-Dist: pytest; extra == "test"
|
| 228 |
+
Requires-Dist: yapf; extra == "test"
|
| 229 |
+
Dynamic: author
|
| 230 |
+
Dynamic: home-page
|
| 231 |
+
Dynamic: license-file
|
| 232 |
+
|
| 233 |
+
# Search-R1: Train your LLMs to reason and call a search engine with reinforcement learning
|
| 234 |
+
|
| 235 |
+
<div align="center">
|
| 236 |
+
<img src="https://raw.githubusercontent.com/PeterGriffinJin/Search-R1/main/public/logo.png" alt="logo" width="300"/>
|
| 237 |
+
</div>
|
| 238 |
+
|
| 239 |
+
<p align="center">
|
| 240 |
+
<a href="https://arxiv.org/abs/2503.09516">
|
| 241 |
+
<img src="https://img.shields.io/badge/Paper1-blue?style=for-the-badge" alt="Button1"/>
|
| 242 |
+
</a>
|
| 243 |
+
<a href="https://arxiv.org/abs/2505.15117">
|
| 244 |
+
<img src="https://img.shields.io/badge/Paper2-green?style=for-the-badge" alt="Button2"/>
|
| 245 |
+
</a>
|
| 246 |
+
<a href="https://huggingface.co/collections/PeterJinGo/search-r1-67d1a021202731cb065740f5">
|
| 247 |
+
<img src="https://img.shields.io/badge/Resources-orange?style=for-the-badge" alt="Button3"/>
|
| 248 |
+
</a>
|
| 249 |
+
<a href="https://x.com/BowenJin13/status/1895544294473109889">
|
| 250 |
+
<img src="https://img.shields.io/badge/Tweet-red?style=for-the-badge" alt="Button4"/>
|
| 251 |
+
</a>
|
| 252 |
+
<a href="https://wandb.ai/peterjin/Search-R1-v0.2">
|
| 253 |
+
<img src="https://img.shields.io/badge/Logs-purple?style=for-the-badge" alt="Button5"/>
|
| 254 |
+
</a>
|
| 255 |
+
</p>
|
| 256 |
+
|
| 257 |
+
|
| 258 |
+
<!-- <strong>Search-R1</strong> is a reinforcement learning framework for <em>training reasoning and searching (tool-call) interleaved LLMs</em>. -->
|
| 259 |
+
<!-- We built upon [veRL](https://github.com/volcengine/verl). -->
|
| 260 |
+
**Search-R1** is a reinforcement learning framework designed for training **reasoning-and-searching interleaved LLMs**—language models that learn to reason and make tool calls (e.g., to search engines) in a coordinated manner.
|
| 261 |
+
|
| 262 |
+
<!-- It can be seen as an extension of <strong>DeepSeek-R1(-Zero)</strong> with interleaved search engine calling and an opensource RL training-based solution for <strong>OpenAI DeepResearch</strong>. -->
|
| 263 |
+
Built upon [veRL](https://github.com/volcengine/verl), Search-R1 extends the ideas of **DeepSeek-R1(-Zero)** by incorporating interleaved search engine access and provides a fully open-source RL training pipeline. It serves as an alternative and open solution to **OpenAI DeepResearch**, enabling research and development in tool-augmented LLM reasoning.
|
| 264 |
+
|
| 265 |
+
<!-- Through RL (rule-based outcome reward), the 3B **base** LLM (both Qwen2.5-3b-base and Llama3.2-3b-base) develops reasoning and search engine calling abilities all on its own. -->
|
| 266 |
+
|
| 267 |
+
We support different RL methods (e.g., PPO, GRPO, reinforce), different LLMs (e.g., llama3, Qwen2.5, etc) and different search engines (e.g., local sparse/dense retrievers and online search engines).
|
| 268 |
+
|
| 269 |
+
Paper: [link1](https://arxiv.org/pdf/2503.09516), [link2](https://arxiv.org/abs/2505.15117); Model and data: [link](https://huggingface.co/collections/PeterJinGo/search-r1-67d1a021202731cb065740f5); Twitter thread: [link](https://x.com/BowenJin13/status/1895544294473109889); Full experiment log: [prelim](https://wandb.ai/peterjin/Search-R1-open); [v0.1](https://wandb.ai/peterjin/Search-R1-nq_hotpotqa_train); [v0.2](https://wandb.ai/peterjin/Search-R1-v0.2); [v0.3](https://wandb.ai/peterjin/Search-R1-v0.3). Details about these logs and methods can be find [here](https://github.com/PeterGriffinJin/Search-R1/blob/main/docs/experiment_log.md).
|
| 270 |
+
|
| 271 |
+
|
| 272 |
+

|
| 273 |
+
|
| 274 |
+
## News
|
| 275 |
+
|
| 276 |
+
- [2025.10] Search-R1 is featured by Thinking Machines Lab's first product [Tinker](https://github.com/thinking-machines-lab/tinker-cookbook)! Details: [Document](https://github.com/thinking-machines-lab/tinker-cookbook/tree/main/tinker_cookbook/recipes/tool_use/search).
|
| 277 |
+
- [2025.7] Search-R1 is supported by [SkyRL](https://github.com/NovaSky-AI/SkyRL)! Detailed instructions: [code](https://github.com/NovaSky-AI/SkyRL/tree/main/skyrl-train/examples/search), [Document](https://novasky-ai.notion.site/skyrl-searchr1).
|
| 278 |
+
- [2025.6] Search-R1 is now integrated into the latest version of veRL and can take advantage of its most up-to-date features! Detailed instructions: [veRL](https://verl.readthedocs.io/en/latest/sglang_multiturn/search_tool_example.html), [English Document](https://github.com/zhaochenyang20/Awesome-ML-SYS-Tutorial/blob/main/rlhf/verl/multi-turn/tool_examples/verl-multiturn-searchR1-like.md), [Chinese Document](https://github.com/zhaochenyang20/Awesome-ML-SYS-Tutorial/blob/main/rlhf/verl/multi-turn/tool_examples/verl-multiturn-searchR1-like_ZH.md).
|
| 279 |
+
- [2025.5] The second [paper](https://arxiv.org/abs/2505.15117) conducting detailed empirical studies is published with logs: [v0.3](https://wandb.ai/peterjin/Search-R1-v0.3).
|
| 280 |
+
- [2025.4] We support [multinode](https://github.com/PeterGriffinJin/Search-R1/blob/main/docs/multinode.md) training for 30B+ LLMs!
|
| 281 |
+
- [2025.4] We support [different search engines](https://github.com/PeterGriffinJin/Search-R1/blob/main/docs/retriever.md) including sparse local retriever, dense local retriever with ANN indexing and online search engines!
|
| 282 |
+
- [2025.3] The first Search-R1 [paper](https://arxiv.org/pdf/2503.09516) is published with the logs: [v0.1](https://wandb.ai/peterjin/Search-R1-nq_hotpotqa_train); [v0.2](https://wandb.ai/peterjin/Search-R1-v0.2).
|
| 283 |
+
- [2025.2] We opensource Search-R1 codebase with [preliminary results](https://wandb.ai/peterjin/Search-R1-open).
|
| 284 |
+
|
| 285 |
+
## Links
|
| 286 |
+
|
| 287 |
+
- [Installation](#installation)
|
| 288 |
+
- [Quick start](#quick-start)
|
| 289 |
+
- [Preliminary results](#preliminary-results)
|
| 290 |
+
- [Inference](#inference)
|
| 291 |
+
- [Use your own dataset](#use-your-own-dataset)
|
| 292 |
+
- [Use your own search engine](#use-your-own-search-engine)
|
| 293 |
+
- [Features](#features)
|
| 294 |
+
- [Ackowledge](#acknowledge)
|
| 295 |
+
- [Citations](#citations)
|
| 296 |
+
|
| 297 |
+
## Installation
|
| 298 |
+
|
| 299 |
+
### Search-r1 environment
|
| 300 |
+
```bash
|
| 301 |
+
conda create -n searchr1 python=3.9
|
| 302 |
+
conda activate searchr1
|
| 303 |
+
# install torch [or you can skip this step and let vllm to install the correct version for you]
|
| 304 |
+
pip install torch==2.4.0 --index-url https://download.pytorch.org/whl/cu121
|
| 305 |
+
# install vllm
|
| 306 |
+
pip3 install vllm==0.6.3 # or you can install 0.5.4, 0.4.2 and 0.3.1
|
| 307 |
+
|
| 308 |
+
# verl
|
| 309 |
+
pip install -e .
|
| 310 |
+
|
| 311 |
+
# flash attention 2
|
| 312 |
+
pip3 install flash-attn --no-build-isolation
|
| 313 |
+
pip install wandb
|
| 314 |
+
```
|
| 315 |
+
|
| 316 |
+
### Retriever environment (optional)
|
| 317 |
+
If you would like to call a local retriever as the search engine, you can install the environment as follows. (We recommend using a seperate environment.)
|
| 318 |
+
```bash
|
| 319 |
+
conda create -n retriever python=3.10
|
| 320 |
+
conda activate retriever
|
| 321 |
+
|
| 322 |
+
# we recommend installing torch with conda for faiss-gpu
|
| 323 |
+
conda install pytorch==2.4.0 torchvision==0.19.0 torchaudio==2.4.0 pytorch-cuda=12.1 -c pytorch -c nvidia
|
| 324 |
+
pip install transformers datasets pyserini
|
| 325 |
+
|
| 326 |
+
## install the gpu version faiss to guarantee efficient RL rollout
|
| 327 |
+
conda install -c pytorch -c nvidia faiss-gpu=1.8.0
|
| 328 |
+
|
| 329 |
+
## API function
|
| 330 |
+
pip install uvicorn fastapi
|
| 331 |
+
```
|
| 332 |
+
|
| 333 |
+
|
| 334 |
+
## Quick start
|
| 335 |
+
|
| 336 |
+
Train a reasoning + search LLM on NQ dataset with e5 as the retriever and wikipedia as the corpus.
|
| 337 |
+
|
| 338 |
+
(1) Download the indexing and corpus.
|
| 339 |
+
```bash
|
| 340 |
+
save_path=/the/path/to/save
|
| 341 |
+
python scripts/download.py --save_path $save_path
|
| 342 |
+
cat $save_path/part_* > $save_path/e5_Flat.index
|
| 343 |
+
gzip -d $save_path/wiki-18.jsonl.gz
|
| 344 |
+
```
|
| 345 |
+
|
| 346 |
+
(2) Process the NQ dataset.
|
| 347 |
+
```bash
|
| 348 |
+
python scripts/data_process/nq_search.py
|
| 349 |
+
```
|
| 350 |
+
|
| 351 |
+
(3) Launch a local retrieval server.
|
| 352 |
+
```bash
|
| 353 |
+
conda activate retriever
|
| 354 |
+
bash retrieval_launch.sh
|
| 355 |
+
```
|
| 356 |
+
|
| 357 |
+
(4) Run RL training (PPO) with Llama-3.2-3b-base.
|
| 358 |
+
```bash
|
| 359 |
+
conda activate searchr1
|
| 360 |
+
bash train_ppo.sh
|
| 361 |
+
```
|
| 362 |
+
|
| 363 |
+
## Preliminary results
|
| 364 |
+
|
| 365 |
+
(1) The base model (llama3.2-3b-base) learns to call the search engine and obtain improved performance.
|
| 366 |
+
|
| 367 |
+

|
| 368 |
+
|
| 369 |
+
|
| 370 |
+
(2) The base model (Qwen2.5-7b-base) can learn to conduct multi-turn search engine calling and reasoning with RL.
|
| 371 |
+
|
| 372 |
+

|
| 373 |
+
|
| 374 |
+
## Inference
|
| 375 |
+
#### You can play with the trained Search-R1 model with your own question.
|
| 376 |
+
(1) Launch a local retrieval server.
|
| 377 |
+
```bash
|
| 378 |
+
conda activate retriever
|
| 379 |
+
bash retrieval_launch.sh
|
| 380 |
+
```
|
| 381 |
+
|
| 382 |
+
(2) Run inference.
|
| 383 |
+
```bash
|
| 384 |
+
conda activate searchr1
|
| 385 |
+
python infer.py
|
| 386 |
+
```
|
| 387 |
+
You can modify the ```question``` on line 7 to something you're interested in.
|
| 388 |
+
|
| 389 |
+
## Use your own dataset
|
| 390 |
+
|
| 391 |
+
### QA data
|
| 392 |
+
For each question-answer sample, it should be a dictionary containing the desired content as below:
|
| 393 |
+
|
| 394 |
+
```
|
| 395 |
+
data = {
|
| 396 |
+
"data_source": data_source,
|
| 397 |
+
"prompt": [{
|
| 398 |
+
"role": "user",
|
| 399 |
+
"content": question,
|
| 400 |
+
}],
|
| 401 |
+
"ability": "fact-reasoning",
|
| 402 |
+
"reward_model": {
|
| 403 |
+
"style": "rule",
|
| 404 |
+
"ground_truth": solution
|
| 405 |
+
},
|
| 406 |
+
"extra_info": {
|
| 407 |
+
'split': split,
|
| 408 |
+
'index': idx,
|
| 409 |
+
}
|
| 410 |
+
}
|
| 411 |
+
```
|
| 412 |
+
|
| 413 |
+
You can refer to ```scripts/data_process/nq_search.py``` for a concrete data processing example.
|
| 414 |
+
|
| 415 |
+
### Corpora
|
| 416 |
+
|
| 417 |
+
It is recommended to make your corpus a jsonl file, where each line (a dictionary with "id" key and "contents" key) corresponds to one passage. You can refer to ```example/corpus.jsonl``` for an example.
|
| 418 |
+
|
| 419 |
+
The "id" key corresponds to the passage id, while the "contents" key corresponds to the passage content ('"' + title + '"\n' + text).
|
| 420 |
+
For example:
|
| 421 |
+
```
|
| 422 |
+
{"id": "0", "contents": "Evan Morris Evan L. Morris (January 26, 1977 \u2013 July 9, 2015) was a lobbyist for Genentech and its parent corporation Roche in Washington."}
|
| 423 |
+
...
|
| 424 |
+
{"id": "100", "contents": "Three years later, when the United States Exploring Expedition to little-known portions of the globe was organised under Charles Wilkes, Hale was recommended, while yet an undergraduate."}
|
| 425 |
+
...
|
| 426 |
+
```
|
| 427 |
+
|
| 428 |
+
**Index your corpora (optional).**
|
| 429 |
+
If you would like to use a local retriever as the search engine, you can index your own corpus by:
|
| 430 |
+
```
|
| 431 |
+
bash search_r1/search/build_index.sh
|
| 432 |
+
```
|
| 433 |
+
You can change ```retriever_name``` and ```retriever_model``` to your interested off-the-shelf retriever.
|
| 434 |
+
|
| 435 |
+
## Use your own search engine
|
| 436 |
+
|
| 437 |
+
Our codebase supports local sparse retriever (e.g., BM25), local dense retriever (both flat indexing with GPUs and ANN indexing with CPUs) and online search engine (e.g., Google, Bing, etc). More details can be found [here](https://github.com/PeterGriffinJin/Search-R1/tree/main/docs/retriever.md).
|
| 438 |
+
|
| 439 |
+
The main philosophy is to launch a local or remote search engine server separately from the main RL training pipeline.
|
| 440 |
+
|
| 441 |
+
The LLM can call the search engine by calling the search API (e.g., "http://127.0.0.1:8000/retrieve").
|
| 442 |
+
|
| 443 |
+
You can refer to ```search_r1/search/retriever_server.py``` for an example of launching a local retriever server.
|
| 444 |
+
|
| 445 |
+
## Features
|
| 446 |
+
- Support local sparse retrievers (e.g., BM25). ✔️
|
| 447 |
+
- Support local dense retrievers (both flat indexing and ANN indexing) ✔️
|
| 448 |
+
- Support google search / bing search / brave search API and others. ✔️
|
| 449 |
+
- Support off-the-shelf neural rerankers. ✔️
|
| 450 |
+
- Support different RL methods (e.g., PPO, GRPO, reinforce). ✔️
|
| 451 |
+
- Support different LLMs (e.g., llama3, Qwen2.5, etc). ✔️
|
| 452 |
+
|
| 453 |
+
## Acknowledge
|
| 454 |
+
|
| 455 |
+
The concept of Search-R1 is inspired by [Deepseek-R1](https://github.com/deepseek-ai/DeepSeek-R1) and [TinyZero](https://github.com/Jiayi-Pan/TinyZero/tree/main).
|
| 456 |
+
Its implementation is built upon [veRL](https://github.com/volcengine/verl) and [RAGEN](https://github.com/ZihanWang314/RAGEN/tree/main).
|
| 457 |
+
We sincerely appreciate the efforts of these teams for their contributions to open-source research and development.
|
| 458 |
+
|
| 459 |
+
## Awesome work powered or inspired by Search-R1
|
| 460 |
+
|
| 461 |
+
- [DeepResearcher](https://github.com/GAIR-NLP/DeepResearcher): Scaling Deep Research via Reinforcement Learning in Real-world Environments. [![[code]](https://img.shields.io/github/stars/GAIR-NLP/DeepResearcher)](https://github.com/GAIR-NLP/DeepResearcher)
|
| 462 |
+
- [Multimodal-Search-R1](https://github.com/EvolvingLMMs-Lab/multimodal-search-r1): Incentivizing LMMs to Search. [![[code]](https://img.shields.io/github/stars/EvolvingLMMs-Lab/multimodal-search-r1)](https://github.com/EvolvingLMMs-Lab/multimodal-search-r1)
|
| 463 |
+
- [OTC](https://arxiv.org/pdf/2504.14870): Optimal Tool Calls via Reinforcement Learning.
|
| 464 |
+
- [ZeroSearch](https://github.com/Alibaba-NLP/ZeroSearch): Incentivize the Search Capability of LLMs without Searching. [![[code]](https://img.shields.io/github/stars/Alibaba-NLP/ZeroSearch)](https://github.com/Alibaba-NLP/ZeroSearch)
|
| 465 |
+
- [IKEA](https://github.com/hzy312/knowledge-r1): Reinforced Internal-External Knowledge Synergistic Reasoning for Efficient Adaptive Search Agent. [![[code]](https://img.shields.io/github/stars/hzy312/knowledge-r1)](https://github.com/hzy312/knowledge-r1)
|
| 466 |
+
- [Scent of Knowledge](https://arxiv.org/abs/2505.09316): Optimizing Search-Enhanced Reasoning with Information Foraging.
|
| 467 |
+
- [AutoRefine](https://www.arxiv.org/pdf/2505.11277): Search and Refine During Think. [![[code]](https://img.shields.io/github/stars/syr-cn/AutoRefine)](https://github.com/syr-cn/AutoRefine)
|
| 468 |
+
- [O^2-Searcher](https://arxiv.org/pdf/2505.16582): A Searching-based Agent Model for Open-Domain Open-Ended Question Answering. [![[code]](https://img.shields.io/github/stars/Acade-Mate/O2-Searcher)](https://github.com/Acade-Mate/O2-Searcher)
|
| 469 |
+
- [MaskSearch](https://arxiv.org/pdf/2505.20285): A Universal Pre-Training Framework to Enhance Agentic Search Capability. [![[code]](https://img.shields.io/github/stars/Alibaba-NLP/MaskSearch)](https://github.com/Alibaba-NLP/MaskSearch)
|
| 470 |
+
- [VRAG-RL](https://arxiv.org/abs/2505.22019): Vision-Perception-Based RAG for Visually Rich Information Understanding. [![[code]](https://img.shields.io/github/stars/Alibaba-NLP/VRAG)](https://github.com/Alibaba-NLP/VRAG)
|
| 471 |
+
- [R1-Code-Interpreter](https://arxiv.org/abs/2505.21668): Training LLMs to Reason with Code via SFT and RL. [![[code]](https://img.shields.io/github/stars/yongchao98/R1-Code-Interpreter)](https://github.com/yongchao98/R1-Code-Interpreter)
|
| 472 |
+
- [R-Search](https://arxiv.org/abs/2506.04185): Empowering LLM Reasoning with Search via Multi-Reward Reinforcement Learning. [![[code]](https://img.shields.io/github/stars/QingFei1/R-Search)](https://github.com/QingFei1/R-Search)
|
| 473 |
+
- [StepSearch](https://arxiv.org/pdf/2505.15107): Igniting LLMs Search Ability via Step-Wise Proximal Policy Optimization. [![[code]](https://img.shields.io/github/stars/Zillwang/StepSearch)](https://github.com/Zillwang/StepSearch)
|
| 474 |
+
- [SimpleTIR](https://simpletir.notion.site/report): Stable End-to-End Reinforcement Learning for Multi-Turn Tool-Integrated Reasoning. [![[code]](https://img.shields.io/github/stars/ltzheng/SimpleTIR)](https://github.com/ltzheng/SimpleTIR)
|
| 475 |
+
- [Router-R1](https://arxiv.org/pdf/2506.09033): Teaching LLMs Multi-Round Routing and Aggregation via Reinforcement Learning. [![[code]](https://img.shields.io/github/stars/ulab-uiuc/Router-R1)](https://github.com/ulab-uiuc/Router-R1)
|
| 476 |
+
- [SkyRL](https://skyrl.readthedocs.io/en/latest/): A Modular Full-stack RL Library for LLMs. [![[code]](https://img.shields.io/github/stars/NovaSky-AI/SkyRL)](https://github.com/NovaSky-AI/SkyRL)
|
| 477 |
+
- [ASearcher](https://arxiv.org/abs/2508.07976): Large-Scale RL for Search Agents. [![[code]](https://img.shields.io/github/stars/inclusionAI/ASearcher)](https://github.com/inclusionAI/ASearcher)
|
| 478 |
+
- [ParallelSearch](https://www.arxiv.org/abs/2508.09303): Decompose Query and Search Sub-queries in Parallel with RL. [![[code]](https://img.shields.io/github/stars/Tree-Shu-Zhao/ParallelSearch)](https://github.com/Tree-Shu-Zhao/ParallelSearch)
|
| 479 |
+
- [AutoTIR](https://arxiv.org/pdf/2507.21836): Autonomous Tools Integrated Reasoning via Reinforcement Learning. [![[code]](https://img.shields.io/github/stars/weiyifan1023/AutoTIR)](https://github.com/weiyifan1023/AutoTIR)
|
| 480 |
+
- [verl-tool](https://arxiv.org/pdf/2509.01055): A version of verl to support diverse tool use. [![[code]](https://img.shields.io/github/stars/TIGER-AI-Lab/verl-tool)](https://github.com/TIGER-AI-Lab/verl-tool)
|
| 481 |
+
- [Tree-GRPO](https://arxiv.org/abs/2509.21240): Tree Search for LLM Agent Reinforcement Learning. [![[code]](https://img.shields.io/github/stars/AMAP-ML/Tree-GRPO)](https://github.com/AMAP-ML/Tree-GRPO)
|
| 482 |
+
- [EviNote-RAG](https://arxiv.org/abs/2509.00877): Enhancing RAG Models via Answer-Supportive Evidence Notes. [![[code]](https://img.shields.io/github/stars/Da1yuqin/EviNoteRAG)](https://github.com/Da1yuqin/EviNoteRAG)
|
| 483 |
+
- [GlobalRAG](https://arxiv.org/pdf/2510.20548v1): GlobalRAG: Enhancing Global Reasoning in Multi-hop Question Answering via Reinforcement Learning. [![[code]](https://img.shields.io/github/stars/CarnegieBin/GlobalRAG)](https://github.com/CarnegieBin/GlobalRAG)
|
| 484 |
+
|
| 485 |
+
|
| 486 |
+
|
| 487 |
+
|
| 488 |
+
|
| 489 |
+
## Citations
|
| 490 |
+
|
| 491 |
+
```bibtex
|
| 492 |
+
@article{jin2025search,
|
| 493 |
+
title={Search-r1: Training llms to reason and leverage search engines with reinforcement learning},
|
| 494 |
+
author={Jin, Bowen and Zeng, Hansi and Yue, Zhenrui and Yoon, Jinsung and Arik, Sercan and Wang, Dong and Zamani, Hamed and Han, Jiawei},
|
| 495 |
+
journal={arXiv preprint arXiv:2503.09516},
|
| 496 |
+
year={2025}
|
| 497 |
+
}
|
| 498 |
+
```
|
| 499 |
+
|
| 500 |
+
```bibtex
|
| 501 |
+
@article{jin2025empirical,
|
| 502 |
+
title={An Empirical Study on Reinforcement Learning for Reasoning-Search Interleaved LLM Agents},
|
| 503 |
+
author={Jin, Bowen and Yoon, Jinsung and Kargupta, Priyanka and Arik, Sercan O and Han, Jiawei},
|
| 504 |
+
journal={arXiv preprint arXiv:2505.15117},
|
| 505 |
+
year={2025}
|
| 506 |
+
}
|
| 507 |
+
```
|
code/RL_model/verl/Search-R1/verl.egg-info/dependency_links.txt
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
|
code/RL_model/verl/Search-R1/verl.egg-info/requires.txt
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
accelerate
|
| 2 |
+
codetiming
|
| 3 |
+
datasets
|
| 4 |
+
dill
|
| 5 |
+
hydra-core
|
| 6 |
+
numpy
|
| 7 |
+
pybind11
|
| 8 |
+
ray
|
| 9 |
+
tensordict
|
| 10 |
+
transformers<4.48
|
| 11 |
+
vllm<=0.6.3
|
| 12 |
+
|
| 13 |
+
[test]
|
| 14 |
+
pytest
|
| 15 |
+
yapf
|
code/RL_model/verl/Search-R1/verl.egg-info/top_level.txt
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
search_r1
|
| 2 |
+
verl
|
code/RL_model/verl/Search-R1/verl/__init__.py
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import os
|
| 16 |
+
|
| 17 |
+
version_folder = os.path.dirname(os.path.join(os.path.abspath(__file__)))
|
| 18 |
+
|
| 19 |
+
with open(os.path.join(version_folder, 'version/version')) as f:
|
| 20 |
+
__version__ = f.read().strip()
|
| 21 |
+
|
| 22 |
+
from .protocol import DataProto
|
| 23 |
+
|
| 24 |
+
from .utils.logging_utils import set_basic_config
|
| 25 |
+
import logging
|
| 26 |
+
|
| 27 |
+
set_basic_config(level=logging.WARNING)
|
code/RL_model/verl/Search-R1/verl/protocol.py
ADDED
|
@@ -0,0 +1,639 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
"""
|
| 15 |
+
Implement base data transfer protocol between any two functions, modules.
|
| 16 |
+
We can subclass Protocol to define more detailed batch info with specific keys
|
| 17 |
+
"""
|
| 18 |
+
|
| 19 |
+
import pickle
|
| 20 |
+
import numpy as np
|
| 21 |
+
import copy
|
| 22 |
+
from dataclasses import dataclass, field
|
| 23 |
+
from typing import Callable, Dict, List, Union
|
| 24 |
+
|
| 25 |
+
import torch
|
| 26 |
+
import tensordict
|
| 27 |
+
from tensordict import TensorDict
|
| 28 |
+
from torch.utils.data import DataLoader, Dataset
|
| 29 |
+
|
| 30 |
+
from verl.utils.py_functional import union_two_dict
|
| 31 |
+
|
| 32 |
+
__all__ = ['DataProto', 'union_tensor_dict']
|
| 33 |
+
|
| 34 |
+
try:
|
| 35 |
+
tensordict.set_lazy_legacy(False).set()
|
| 36 |
+
except:
|
| 37 |
+
pass
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def pad_dataproto_to_divisor(data: 'DataProto', size_divisor: int):
|
| 41 |
+
"""Pad a DataProto to size divisible by size_divisor
|
| 42 |
+
|
| 43 |
+
Args:
|
| 44 |
+
size_divisor (int): size divisor
|
| 45 |
+
|
| 46 |
+
Returns:
|
| 47 |
+
data: (DataProto): the padded DataProto
|
| 48 |
+
pad_size (int)
|
| 49 |
+
"""
|
| 50 |
+
assert isinstance(data, DataProto), 'data must be a DataProto'
|
| 51 |
+
if len(data) % size_divisor != 0:
|
| 52 |
+
pad_size = size_divisor - len(data) % size_divisor
|
| 53 |
+
data_padded = DataProto.concat([data, data[:pad_size]])
|
| 54 |
+
else:
|
| 55 |
+
pad_size = 0
|
| 56 |
+
data_padded = data
|
| 57 |
+
return data_padded, pad_size
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
def unpad_dataproto(data: 'DataProto', pad_size):
|
| 61 |
+
if pad_size != 0:
|
| 62 |
+
data = data[:-pad_size]
|
| 63 |
+
return data
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def union_tensor_dict(tensor_dict1: TensorDict, tensor_dict2: TensorDict) -> TensorDict:
|
| 67 |
+
"""Union two tensordicts."""
|
| 68 |
+
assert tensor_dict1.batch_size == tensor_dict2.batch_size, \
|
| 69 |
+
f'Two tensor dict must have identical batch size. Got {tensor_dict1.batch_size} and {tensor_dict2.batch_size}'
|
| 70 |
+
for key in tensor_dict2.keys():
|
| 71 |
+
if key not in tensor_dict1.keys():
|
| 72 |
+
tensor_dict1[key] = tensor_dict2[key]
|
| 73 |
+
else:
|
| 74 |
+
assert tensor_dict1[key].equal(tensor_dict2[key]), \
|
| 75 |
+
f'{key} in tensor_dict1 and tensor_dict2 are not the same object'
|
| 76 |
+
|
| 77 |
+
return tensor_dict1
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
def union_numpy_dict(tensor_dict1: dict[np.ndarray], tensor_dict2: dict[np.ndarray]) -> dict[np.ndarray]:
|
| 81 |
+
for key, val in tensor_dict2.items():
|
| 82 |
+
if key in tensor_dict1:
|
| 83 |
+
assert isinstance(tensor_dict2[key], np.ndarray)
|
| 84 |
+
assert isinstance(tensor_dict1[key], np.ndarray)
|
| 85 |
+
assert np.all(tensor_dict2[key] == tensor_dict1[key]), \
|
| 86 |
+
f'{key} in tensor_dict1 and tensor_dict2 are not the same object'
|
| 87 |
+
tensor_dict1[key] = val
|
| 88 |
+
|
| 89 |
+
return tensor_dict1
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
def list_of_dict_to_dict_of_list(list_of_dict: list[dict]):
|
| 93 |
+
if len(list_of_dict) == 0:
|
| 94 |
+
return {}
|
| 95 |
+
keys = list_of_dict[0].keys()
|
| 96 |
+
output = {key: [] for key in keys}
|
| 97 |
+
for data in list_of_dict:
|
| 98 |
+
for key, item in data.items():
|
| 99 |
+
assert key in output
|
| 100 |
+
output[key].append(item)
|
| 101 |
+
return output
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
def fold_batch_dim(data: 'DataProto', new_batch_size):
|
| 105 |
+
"""
|
| 106 |
+
Fold a batch dim from [bsz, xxx] into [new_bsz, bsz // new_bsz, xxx]
|
| 107 |
+
"""
|
| 108 |
+
batch_size = data.batch.batch_size[0]
|
| 109 |
+
|
| 110 |
+
assert batch_size % new_batch_size == 0
|
| 111 |
+
|
| 112 |
+
tensor: TensorDict = data.batch
|
| 113 |
+
non_tensor = data.non_tensor_batch
|
| 114 |
+
|
| 115 |
+
tensor = tensor.view(new_batch_size, -1)
|
| 116 |
+
tensor.auto_batch_size_(batch_dims=1)
|
| 117 |
+
|
| 118 |
+
for key, val in non_tensor.items():
|
| 119 |
+
non_tensor[key] = np.reshape(val, newshape=(new_batch_size, -1, *val.shape[1:]))
|
| 120 |
+
|
| 121 |
+
return DataProto(batch=tensor, non_tensor_batch=non_tensor, meta_info=data.meta_info)
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
def unfold_batch_dim(data: 'DataProto', batch_dims=2):
|
| 125 |
+
"""
|
| 126 |
+
Unfold the first n dims as new batch dim
|
| 127 |
+
"""
|
| 128 |
+
tensor: TensorDict = data.batch
|
| 129 |
+
non_tensor = data.non_tensor_batch
|
| 130 |
+
tensor.auto_batch_size_(batch_dims=batch_dims)
|
| 131 |
+
tensor = tensor.view(-1)
|
| 132 |
+
|
| 133 |
+
batch_size = tensor.batch_size[0]
|
| 134 |
+
|
| 135 |
+
non_tensor_new = {}
|
| 136 |
+
|
| 137 |
+
for key, val in non_tensor.items():
|
| 138 |
+
non_tensor_new[key] = np.reshape(val, newshape=(batch_size, *val.shape[batch_dims:]))
|
| 139 |
+
|
| 140 |
+
return DataProto(batch=tensor, non_tensor_batch=non_tensor_new, meta_info=data.meta_info)
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
def collate_fn(x: list['DataProtoItem']):
|
| 144 |
+
batch = []
|
| 145 |
+
non_tensor_batch = []
|
| 146 |
+
for data in x:
|
| 147 |
+
batch.append(data.batch)
|
| 148 |
+
non_tensor_batch.append(data.non_tensor_batch)
|
| 149 |
+
batch = torch.stack(batch).contiguous()
|
| 150 |
+
non_tensor_batch = list_of_dict_to_dict_of_list(non_tensor_batch)
|
| 151 |
+
for key, val in non_tensor_batch.items():
|
| 152 |
+
non_tensor_batch[key] = np.array(val, dtype=object)
|
| 153 |
+
return DataProto(batch=batch, non_tensor_batch=non_tensor_batch)
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
@dataclass
|
| 157 |
+
class DataProtoItem:
|
| 158 |
+
# TODO(zhangchi.usc1992) add consistency check
|
| 159 |
+
batch: TensorDict = None
|
| 160 |
+
non_tensor_batch: Dict = field(default_factory=dict)
|
| 161 |
+
meta_info: Dict = field(default_factory=dict)
|
| 162 |
+
|
| 163 |
+
|
| 164 |
+
@dataclass
|
| 165 |
+
class DataProto:
|
| 166 |
+
"""
|
| 167 |
+
A DataProto is a data structure that aims to provide a standard protocol for data exchange between functions.
|
| 168 |
+
It contains a batch (TensorDict) and a meta_info (Dict). The batch is a TensorDict https://pytorch.org/tensordict/.
|
| 169 |
+
TensorDict allows you to manipulate a dictionary of Tensors like a single Tensor. Ideally, the tensors with the
|
| 170 |
+
same batch size should be put inside batch.
|
| 171 |
+
"""
|
| 172 |
+
batch: TensorDict = None
|
| 173 |
+
non_tensor_batch: Dict = field(default_factory=dict)
|
| 174 |
+
meta_info: Dict = field(default_factory=dict)
|
| 175 |
+
|
| 176 |
+
def __post_init__(self):
|
| 177 |
+
# perform necessary checking
|
| 178 |
+
self.check_consistency()
|
| 179 |
+
|
| 180 |
+
def __len__(self):
|
| 181 |
+
if self.batch is not None:
|
| 182 |
+
return self.batch.batch_size[0]
|
| 183 |
+
elif self.non_tensor_batch is not None and len(self.non_tensor_batch) > 0:
|
| 184 |
+
random_key = list(self.non_tensor_batch.keys())[0]
|
| 185 |
+
return self.non_tensor_batch[random_key].shape[0]
|
| 186 |
+
else:
|
| 187 |
+
return 0
|
| 188 |
+
|
| 189 |
+
def __getitem__(self, item):
|
| 190 |
+
tensor_data = self.batch[item]
|
| 191 |
+
non_tensor_data = {key: val[item] for key, val in self.non_tensor_batch.items()}
|
| 192 |
+
return DataProtoItem(batch=tensor_data, non_tensor_batch=non_tensor_data, meta_info=self.meta_info)
|
| 193 |
+
|
| 194 |
+
def __getstate__(self):
|
| 195 |
+
import io
|
| 196 |
+
buffer = io.BytesIO()
|
| 197 |
+
if tensordict.__version__ >= '0.5.0' and self.batch is not None:
|
| 198 |
+
self.batch = self.batch.contiguous()
|
| 199 |
+
self.batch = self.batch.consolidate()
|
| 200 |
+
torch.save(self.batch, buffer)
|
| 201 |
+
buffer_bytes = buffer.getvalue()
|
| 202 |
+
return buffer_bytes, self.non_tensor_batch, self.meta_info
|
| 203 |
+
|
| 204 |
+
def __setstate__(self, data):
|
| 205 |
+
import io
|
| 206 |
+
batch_deserialized_bytes, non_tensor_batch, meta_info = data
|
| 207 |
+
batch_deserialized = io.BytesIO(initial_bytes=batch_deserialized_bytes)
|
| 208 |
+
batch = torch.load(batch_deserialized,
|
| 209 |
+
weights_only=False,
|
| 210 |
+
map_location='cpu' if not torch.cuda.is_available() else None)
|
| 211 |
+
self.batch = batch
|
| 212 |
+
self.non_tensor_batch = non_tensor_batch
|
| 213 |
+
self.meta_info = meta_info
|
| 214 |
+
|
| 215 |
+
def save_to_disk(self, filepath):
|
| 216 |
+
with open(filepath, 'wb') as f:
|
| 217 |
+
pickle.dump(self, f)
|
| 218 |
+
|
| 219 |
+
@staticmethod
|
| 220 |
+
def load_from_disk(filepath) -> 'DataProto':
|
| 221 |
+
with open(filepath, 'rb') as f:
|
| 222 |
+
data = pickle.load(f)
|
| 223 |
+
return data
|
| 224 |
+
|
| 225 |
+
def print_size(self, prefix=""):
|
| 226 |
+
size_of_tensordict = 0
|
| 227 |
+
for key, tensor in self.batch.items():
|
| 228 |
+
size_of_tensordict += tensor.element_size() * tensor.numel()
|
| 229 |
+
size_of_numpy_array = 0
|
| 230 |
+
for key, numpy_array in self.non_tensor_batch.items():
|
| 231 |
+
size_of_numpy_array += numpy_array.nbytes
|
| 232 |
+
|
| 233 |
+
size_of_numpy_array /= 1024**3
|
| 234 |
+
size_of_tensordict /= 1024**3
|
| 235 |
+
|
| 236 |
+
message = f'Size of tensordict: {size_of_tensordict} GB, size of non_tensor_batch: {size_of_numpy_array} GB'
|
| 237 |
+
|
| 238 |
+
if prefix:
|
| 239 |
+
message = f'{prefix}, ' + message
|
| 240 |
+
print(message)
|
| 241 |
+
|
| 242 |
+
def check_consistency(self):
|
| 243 |
+
"""Check the consistency of the DataProto. Mainly for batch and non_tensor_batch
|
| 244 |
+
We expose this function as a public one so that user can call themselves directly
|
| 245 |
+
"""
|
| 246 |
+
if self.batch is not None:
|
| 247 |
+
assert len(self.batch.batch_size) == 1, 'only support num_batch_dims=1'
|
| 248 |
+
|
| 249 |
+
if self.non_tensor_batch is not None:
|
| 250 |
+
for key, val in self.non_tensor_batch.items():
|
| 251 |
+
assert isinstance(val, np.ndarray)
|
| 252 |
+
|
| 253 |
+
if self.batch is not None and len(self.non_tensor_batch) != 0:
|
| 254 |
+
# TODO: we can actually lift this restriction if needed
|
| 255 |
+
assert len(self.batch.batch_size) == 1, 'only support num_batch_dims=1 when non_tensor_batch is not empty.'
|
| 256 |
+
|
| 257 |
+
batch_size = self.batch.batch_size[0]
|
| 258 |
+
for key, val in self.non_tensor_batch.items():
|
| 259 |
+
assert isinstance(
|
| 260 |
+
val, np.ndarray
|
| 261 |
+
) and val.dtype == object, 'data in the non_tensor_batch must be a numpy.array with dtype=object'
|
| 262 |
+
assert val.shape[
|
| 263 |
+
0] == batch_size, f'key {key} length {len(val)} is not equal to batch size {batch_size}'
|
| 264 |
+
|
| 265 |
+
@classmethod
|
| 266 |
+
def from_single_dict(cls, data: Dict[str, Union[torch.Tensor, np.ndarray]], meta_info=None):
|
| 267 |
+
tensors = {}
|
| 268 |
+
non_tensors = {}
|
| 269 |
+
|
| 270 |
+
for key, val in data.items():
|
| 271 |
+
if isinstance(val, torch.Tensor):
|
| 272 |
+
tensors[key] = val
|
| 273 |
+
elif isinstance(val, np.ndarray):
|
| 274 |
+
non_tensors[key] = val
|
| 275 |
+
else:
|
| 276 |
+
raise ValueError(f'Unsupported type in data {type(val)}')
|
| 277 |
+
|
| 278 |
+
return DataProto.from_dict(tensors=tensors, non_tensors=non_tensors, meta_info=meta_info)
|
| 279 |
+
|
| 280 |
+
@classmethod
|
| 281 |
+
def from_dict(cls, tensors: Dict[str, torch.Tensor], non_tensors=None, meta_info=None, num_batch_dims=1):
|
| 282 |
+
"""Create a DataProto from a dict of tensors. This assumes that
|
| 283 |
+
1. All the tensor in tensors have the same dim0
|
| 284 |
+
2. Only dim0 is the batch dim
|
| 285 |
+
"""
|
| 286 |
+
assert len(tensors) > 0, 'tensors must not be empty'
|
| 287 |
+
assert num_batch_dims > 0, 'num_batch_dims must be greater than zero'
|
| 288 |
+
if non_tensors is not None:
|
| 289 |
+
assert num_batch_dims == 1, 'only support num_batch_dims=1 when non_tensors is not None.'
|
| 290 |
+
|
| 291 |
+
if meta_info is None:
|
| 292 |
+
meta_info = {}
|
| 293 |
+
if non_tensors is None:
|
| 294 |
+
non_tensors = {}
|
| 295 |
+
|
| 296 |
+
assert isinstance(non_tensors, dict)
|
| 297 |
+
|
| 298 |
+
# get and check batch size
|
| 299 |
+
batch_size = None
|
| 300 |
+
pivot_key = None
|
| 301 |
+
for key, tensor in tensors.items():
|
| 302 |
+
if batch_size is None:
|
| 303 |
+
batch_size = tensor.shape[:num_batch_dims]
|
| 304 |
+
pivot_key = key
|
| 305 |
+
else:
|
| 306 |
+
current_batch = tensor.shape[:num_batch_dims]
|
| 307 |
+
assert batch_size == current_batch, \
|
| 308 |
+
f'Not all the tensor in tensors have the same batch size with batch_dims={num_batch_dims}. Got {pivot_key} has {batch_size}, {key} has {current_batch}'
|
| 309 |
+
|
| 310 |
+
for key, val in non_tensors.items():
|
| 311 |
+
non_tensors[key] = np.array(val, dtype=object)
|
| 312 |
+
|
| 313 |
+
tensor_dict = TensorDict(source=tensors, batch_size=batch_size)
|
| 314 |
+
return cls(batch=tensor_dict, non_tensor_batch=non_tensors, meta_info=meta_info)
|
| 315 |
+
|
| 316 |
+
def to(self, device) -> 'DataProto':
|
| 317 |
+
"""move the batch to device
|
| 318 |
+
|
| 319 |
+
Args:
|
| 320 |
+
device (torch.device, str): torch device
|
| 321 |
+
|
| 322 |
+
Returns:
|
| 323 |
+
DataProto: the current DataProto
|
| 324 |
+
|
| 325 |
+
"""
|
| 326 |
+
if self.batch is not None:
|
| 327 |
+
self.batch = self.batch.to(device)
|
| 328 |
+
return self
|
| 329 |
+
|
| 330 |
+
def select(self, batch_keys=None, non_tensor_batch_keys=None, meta_info_keys=None, deepcopy=False) -> 'DataProto':
|
| 331 |
+
"""Select a subset of the DataProto via batch_keys and meta_info_keys
|
| 332 |
+
|
| 333 |
+
Args:
|
| 334 |
+
batch_keys (list, optional): a list of strings indicating the keys in batch to select
|
| 335 |
+
meta_info_keys (list, optional): a list of keys indicating the meta info to select
|
| 336 |
+
|
| 337 |
+
Returns:
|
| 338 |
+
DataProto: the DataProto with the selected batch_keys and meta_info_keys
|
| 339 |
+
"""
|
| 340 |
+
# TODO (zhangchi.usc1992) whether to copy
|
| 341 |
+
if batch_keys is not None:
|
| 342 |
+
batch_keys = tuple(batch_keys)
|
| 343 |
+
sub_batch = self.batch.select(*batch_keys)
|
| 344 |
+
else:
|
| 345 |
+
sub_batch = self.batch
|
| 346 |
+
|
| 347 |
+
if non_tensor_batch_keys is not None:
|
| 348 |
+
non_tensor_batch = {key: val for key, val in self.non_tensor_batch.items() if key in non_tensor_batch_keys}
|
| 349 |
+
else:
|
| 350 |
+
non_tensor_batch = self.non_tensor_batch
|
| 351 |
+
|
| 352 |
+
if deepcopy:
|
| 353 |
+
non_tensor_batch = copy.deepcopy(non_tensor_batch)
|
| 354 |
+
|
| 355 |
+
if meta_info_keys is not None:
|
| 356 |
+
sub_meta_info = {key: val for key, val in self.meta_info.items() if key in meta_info_keys}
|
| 357 |
+
else:
|
| 358 |
+
sub_meta_info = self.meta_info
|
| 359 |
+
|
| 360 |
+
if deepcopy:
|
| 361 |
+
sub_meta_info = copy.deepcopy(sub_meta_info)
|
| 362 |
+
|
| 363 |
+
return DataProto(batch=sub_batch, non_tensor_batch=non_tensor_batch, meta_info=sub_meta_info)
|
| 364 |
+
|
| 365 |
+
def pop(self, batch_keys=None, non_tensor_batch_keys=None, meta_info_keys=None) -> 'DataProto':
|
| 366 |
+
"""Pop a subset of the DataProto via `batch_keys` and `meta_info_keys`
|
| 367 |
+
|
| 368 |
+
Args:
|
| 369 |
+
batch_keys (list, optional): a list of strings indicating the keys in batch to pop
|
| 370 |
+
meta_info_keys (list, optional): a list of keys indicating the meta info to pop
|
| 371 |
+
|
| 372 |
+
Returns:
|
| 373 |
+
DataProto: the DataProto with the poped batch_keys and meta_info_keys
|
| 374 |
+
"""
|
| 375 |
+
assert batch_keys is not None
|
| 376 |
+
if meta_info_keys is None:
|
| 377 |
+
meta_info_keys = []
|
| 378 |
+
if non_tensor_batch_keys is None:
|
| 379 |
+
non_tensor_batch_keys = []
|
| 380 |
+
|
| 381 |
+
tensors = {}
|
| 382 |
+
# tensor batch
|
| 383 |
+
for key in batch_keys:
|
| 384 |
+
assert key in self.batch.keys()
|
| 385 |
+
tensors[key] = self.batch.pop(key)
|
| 386 |
+
non_tensors = {}
|
| 387 |
+
# non tensor batch
|
| 388 |
+
for key in non_tensor_batch_keys:
|
| 389 |
+
assert key in self.non_tensor_batch.keys()
|
| 390 |
+
non_tensors[key] = self.non_tensor_batch.pop(key)
|
| 391 |
+
meta_info = {}
|
| 392 |
+
for key in meta_info_keys:
|
| 393 |
+
assert key in self.meta_info.keys()
|
| 394 |
+
meta_info[key] = self.meta_info.pop(key)
|
| 395 |
+
return DataProto.from_dict(tensors=tensors, non_tensors=non_tensors, meta_info=meta_info)
|
| 396 |
+
|
| 397 |
+
def rename(self, old_keys=None, new_keys=None) -> 'DataProto':
|
| 398 |
+
"""
|
| 399 |
+
Note that this function only rename the key in the batch
|
| 400 |
+
"""
|
| 401 |
+
|
| 402 |
+
def validate_input(keys):
|
| 403 |
+
if keys is not None:
|
| 404 |
+
if isinstance(keys, str):
|
| 405 |
+
keys = [keys]
|
| 406 |
+
elif isinstance(keys, list):
|
| 407 |
+
pass
|
| 408 |
+
else:
|
| 409 |
+
raise TypeError(f'keys must be a list or a string, but got {type(keys)}')
|
| 410 |
+
return keys
|
| 411 |
+
|
| 412 |
+
old_keys = validate_input(old_keys)
|
| 413 |
+
new_keys = validate_input(new_keys)
|
| 414 |
+
|
| 415 |
+
if len(new_keys) != len(old_keys):
|
| 416 |
+
raise ValueError(
|
| 417 |
+
f'new_keys and old_keys must have the same length, but got {len(new_keys)} and {len(old_keys)}')
|
| 418 |
+
|
| 419 |
+
self.batch.rename_key_(tuple(old_keys), tuple(new_keys))
|
| 420 |
+
|
| 421 |
+
return self
|
| 422 |
+
|
| 423 |
+
def union(self, other: 'DataProto') -> 'DataProto':
|
| 424 |
+
"""Union with another DataProto. Union batch and meta_info separately.
|
| 425 |
+
Throw an error if
|
| 426 |
+
- there are conflict keys in batch and they are not equal
|
| 427 |
+
- the batch size of two data batch is not the same
|
| 428 |
+
- there are conflict keys in meta_info and they are not the same.
|
| 429 |
+
|
| 430 |
+
Args:
|
| 431 |
+
other (DataProto): another DataProto to union
|
| 432 |
+
|
| 433 |
+
Returns:
|
| 434 |
+
DataProto: the DataProto after union
|
| 435 |
+
"""
|
| 436 |
+
self.batch = union_tensor_dict(self.batch, other.batch)
|
| 437 |
+
self.non_tensor_batch = union_numpy_dict(self.non_tensor_batch, other.non_tensor_batch)
|
| 438 |
+
self.meta_info = union_two_dict(self.meta_info, other.meta_info)
|
| 439 |
+
return self
|
| 440 |
+
|
| 441 |
+
def make_iterator(self, mini_batch_size, epochs, seed=None, dataloader_kwargs=None):
|
| 442 |
+
"""Make an iterator from the DataProto. This is built upon that TensorDict can be used as a normal Pytorch
|
| 443 |
+
dataset. See https://pytorch.org/tensordict/tutorials/data_fashion for more details.
|
| 444 |
+
|
| 445 |
+
Args:
|
| 446 |
+
mini_batch_size (int): mini-batch size when iterating the dataset. We require that
|
| 447 |
+
``batch.batch_size[0] % mini_batch_size == 0``
|
| 448 |
+
epochs (int): number of epochs when iterating the dataset.
|
| 449 |
+
dataloader_kwargs: internally, it returns a DataLoader over the batch.
|
| 450 |
+
The dataloader_kwargs is the kwargs passed to the DataLoader
|
| 451 |
+
|
| 452 |
+
Returns:
|
| 453 |
+
Iterator: an iterator that yields a mini-batch data at a time. The total number of iteration steps is
|
| 454 |
+
``self.batch.batch_size * epochs // mini_batch_size``
|
| 455 |
+
"""
|
| 456 |
+
assert self.batch.batch_size[0] % mini_batch_size == 0, f"{self.batch.batch_size[0]} % {mini_batch_size} != 0"
|
| 457 |
+
# we can directly create a dataloader from TensorDict
|
| 458 |
+
if dataloader_kwargs is None:
|
| 459 |
+
dataloader_kwargs = {}
|
| 460 |
+
|
| 461 |
+
if seed is not None:
|
| 462 |
+
generator = torch.Generator()
|
| 463 |
+
generator.manual_seed(seed)
|
| 464 |
+
else:
|
| 465 |
+
generator = None
|
| 466 |
+
|
| 467 |
+
assert isinstance(dataloader_kwargs, Dict)
|
| 468 |
+
train_dataloader = DataLoader(dataset=self,
|
| 469 |
+
batch_size=mini_batch_size,
|
| 470 |
+
collate_fn=collate_fn,
|
| 471 |
+
generator=generator,
|
| 472 |
+
**dataloader_kwargs)
|
| 473 |
+
|
| 474 |
+
def get_data():
|
| 475 |
+
for _ in range(epochs):
|
| 476 |
+
for d in train_dataloader:
|
| 477 |
+
d.meta_info = self.meta_info
|
| 478 |
+
yield d
|
| 479 |
+
|
| 480 |
+
return iter(get_data())
|
| 481 |
+
|
| 482 |
+
def chunk(self, chunks: int) -> List['DataProto']:
|
| 483 |
+
"""Split the batch among dim=0 into chunks. The meta_info is passed to each DataProto after split.
|
| 484 |
+
|
| 485 |
+
Args:
|
| 486 |
+
chunks (int): the number of chunks to split on dim=0
|
| 487 |
+
|
| 488 |
+
Returns:
|
| 489 |
+
List[DataProto]: a list of DataProto after splitting
|
| 490 |
+
"""
|
| 491 |
+
assert len(
|
| 492 |
+
self) % chunks == 0, f'only support equal chunk. Got size of DataProto {len(self)} and chunk {chunks}.'
|
| 493 |
+
|
| 494 |
+
if self.batch is not None:
|
| 495 |
+
batch_lst = self.batch.chunk(chunks=chunks, dim=0)
|
| 496 |
+
else:
|
| 497 |
+
batch_lst = [None for _ in range(chunks)]
|
| 498 |
+
|
| 499 |
+
non_tensor_batch_lst = [{} for _ in range(chunks)]
|
| 500 |
+
for key, val in self.non_tensor_batch.items():
|
| 501 |
+
assert isinstance(val, np.ndarray)
|
| 502 |
+
non_tensor_lst = np.array_split(val, chunks)
|
| 503 |
+
assert len(non_tensor_lst) == chunks
|
| 504 |
+
for i in range(chunks):
|
| 505 |
+
non_tensor_batch_lst[i][key] = non_tensor_lst[i]
|
| 506 |
+
|
| 507 |
+
output = []
|
| 508 |
+
for i in range(chunks):
|
| 509 |
+
output.append(
|
| 510 |
+
DataProto(batch=batch_lst[i], non_tensor_batch=non_tensor_batch_lst[i], meta_info=self.meta_info))
|
| 511 |
+
|
| 512 |
+
return output
|
| 513 |
+
|
| 514 |
+
@staticmethod
|
| 515 |
+
def concat(data: List['DataProto']) -> 'DataProto':
|
| 516 |
+
"""Concat a list of DataProto. The batch is concatenated among dim=0.
|
| 517 |
+
The meta_info is assumed to be identical and will use the first one.
|
| 518 |
+
|
| 519 |
+
Args:
|
| 520 |
+
data (List[DataProto]): list of DataProto
|
| 521 |
+
|
| 522 |
+
Returns:
|
| 523 |
+
DataProto: concatenated DataProto
|
| 524 |
+
"""
|
| 525 |
+
batch_lst = []
|
| 526 |
+
for batch in data:
|
| 527 |
+
batch_lst.append(batch.batch)
|
| 528 |
+
if batch_lst[0] is not None:
|
| 529 |
+
new_batch = torch.cat(batch_lst, dim=0)
|
| 530 |
+
else:
|
| 531 |
+
new_batch = None
|
| 532 |
+
|
| 533 |
+
non_tensor_batch = list_of_dict_to_dict_of_list(list_of_dict=[d.non_tensor_batch for d in data])
|
| 534 |
+
for key, val in non_tensor_batch.items():
|
| 535 |
+
non_tensor_batch[key] = np.concatenate(val, axis=0)
|
| 536 |
+
|
| 537 |
+
return DataProto(batch=new_batch, non_tensor_batch=non_tensor_batch, meta_info=data[0].meta_info)
|
| 538 |
+
|
| 539 |
+
def reorder(self, indices):
|
| 540 |
+
"""
|
| 541 |
+
Note that this operation is in-place
|
| 542 |
+
"""
|
| 543 |
+
indices_np = indices.detach().numpy()
|
| 544 |
+
self.batch = self.batch[indices]
|
| 545 |
+
self.non_tensor_batch = {key: val[indices_np] for key, val in self.non_tensor_batch.items()}
|
| 546 |
+
|
| 547 |
+
def repeat(self, repeat_times=2, interleave=True):
|
| 548 |
+
"""
|
| 549 |
+
Repeat the batch data a specified number of times.
|
| 550 |
+
|
| 551 |
+
Args:
|
| 552 |
+
repeat_times (int): Number of times to repeat the data.
|
| 553 |
+
interleave (bool): Whether to interleave the repeated data.
|
| 554 |
+
|
| 555 |
+
Returns:
|
| 556 |
+
DataProto: A new DataProto with repeated data.
|
| 557 |
+
"""
|
| 558 |
+
if self.batch is not None:
|
| 559 |
+
if interleave:
|
| 560 |
+
# Interleave the data
|
| 561 |
+
repeated_tensors = {
|
| 562 |
+
key: tensor.repeat_interleave(repeat_times, dim=0) for key, tensor in self.batch.items()
|
| 563 |
+
}
|
| 564 |
+
else:
|
| 565 |
+
# Stack the data
|
| 566 |
+
repeated_tensors = {
|
| 567 |
+
key: tensor.unsqueeze(0).expand(repeat_times, *tensor.shape).reshape(-1, *tensor.shape[1:])
|
| 568 |
+
for key, tensor in self.batch.items()
|
| 569 |
+
}
|
| 570 |
+
|
| 571 |
+
repeated_batch = TensorDict(
|
| 572 |
+
source=repeated_tensors,
|
| 573 |
+
batch_size=(self.batch.batch_size[0] * repeat_times,),
|
| 574 |
+
)
|
| 575 |
+
else:
|
| 576 |
+
repeated_batch = None
|
| 577 |
+
|
| 578 |
+
repeated_non_tensor_batch = {}
|
| 579 |
+
for key, val in self.non_tensor_batch.items():
|
| 580 |
+
if interleave:
|
| 581 |
+
repeated_non_tensor_batch[key] = np.repeat(val, repeat_times, axis=0)
|
| 582 |
+
else:
|
| 583 |
+
repeated_non_tensor_batch[key] = np.tile(val, (repeat_times,) + (1,) * (val.ndim - 1))
|
| 584 |
+
|
| 585 |
+
return DataProto(
|
| 586 |
+
batch=repeated_batch,
|
| 587 |
+
non_tensor_batch=repeated_non_tensor_batch,
|
| 588 |
+
meta_info=self.meta_info,
|
| 589 |
+
)
|
| 590 |
+
|
| 591 |
+
|
| 592 |
+
import ray
|
| 593 |
+
|
| 594 |
+
|
| 595 |
+
@dataclass
|
| 596 |
+
class DataProtoFuture:
|
| 597 |
+
"""
|
| 598 |
+
DataProtoFuture aims to eliminate actual data fetching on driver. By doing so, the driver doesn't have to wait
|
| 599 |
+
for data so that asynchronous execution becomes possible.
|
| 600 |
+
DataProtoFuture contains a list of futures from another WorkerGroup of size world_size.
|
| 601 |
+
- collect_fn is a Callable that reduces the list of futures to a DataProto
|
| 602 |
+
- dispatch_fn is a Callable that partitions the DataProto into a list of DataProto of size world_size and then select
|
| 603 |
+
|
| 604 |
+
Potential issue: we can optimize dispatch_fn(collect_fn) such that only needed data is fetched on destination
|
| 605 |
+
- DataProtoFuture only supports directly passing from the output of a method to another input. You can't perform any
|
| 606 |
+
operation on the DataProtoFuture in driver.
|
| 607 |
+
"""
|
| 608 |
+
collect_fn: Callable
|
| 609 |
+
futures: List[ray.ObjectRef]
|
| 610 |
+
dispatch_fn: Callable = None
|
| 611 |
+
|
| 612 |
+
@staticmethod
|
| 613 |
+
def concat(data: List[ray.ObjectRef]) -> 'DataProtoFuture':
|
| 614 |
+
output = DataProtoFuture(collect_fn=DataProto.concat, futures=data)
|
| 615 |
+
return output
|
| 616 |
+
|
| 617 |
+
def chunk(self, chunks: int) -> List['DataProtoFuture']:
|
| 618 |
+
from functools import partial
|
| 619 |
+
|
| 620 |
+
arg_future_lst = []
|
| 621 |
+
for i in range(chunks):
|
| 622 |
+
# note that we can't directly pass i and chunks
|
| 623 |
+
def dispatch_fn(x, i, chunks):
|
| 624 |
+
return x.chunk(chunks=chunks)[i]
|
| 625 |
+
|
| 626 |
+
arg_future = DataProtoFuture(collect_fn=self.collect_fn,
|
| 627 |
+
dispatch_fn=partial(dispatch_fn, i=i, chunks=chunks),
|
| 628 |
+
futures=self.futures)
|
| 629 |
+
arg_future_lst.append(arg_future)
|
| 630 |
+
return arg_future_lst
|
| 631 |
+
|
| 632 |
+
def get(self):
|
| 633 |
+
output = ray.get(self.futures) # dp_size.
|
| 634 |
+
for o in output:
|
| 635 |
+
assert isinstance(o, DataProto)
|
| 636 |
+
output = self.collect_fn(output) # select dp, concat
|
| 637 |
+
if self.dispatch_fn is not None:
|
| 638 |
+
output = self.dispatch_fn(output) # split in batch dim, select using dp
|
| 639 |
+
return output
|
code/RL_model/verl/Search-R1/wandb/debug-internal.log
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{"time":"2026-02-01T20:27:26.269116545-05:00","level":"INFO","msg":"stream: starting","core version":"0.23.1"}
|
| 2 |
+
{"time":"2026-02-01T20:27:27.692526697-05:00","level":"INFO","msg":"stream: created new stream","id":"lly0j9zs"}
|
| 3 |
+
{"time":"2026-02-01T20:27:27.692680073-05:00","level":"INFO","msg":"handler: started","stream_id":"lly0j9zs"}
|
| 4 |
+
{"time":"2026-02-01T20:27:27.695494454-05:00","level":"INFO","msg":"stream: started","id":"lly0j9zs"}
|
| 5 |
+
{"time":"2026-02-01T20:27:27.69557747-05:00","level":"INFO","msg":"writer: started","stream_id":"lly0j9zs"}
|
| 6 |
+
{"time":"2026-02-01T20:27:27.695701035-05:00","level":"INFO","msg":"sender: started","stream_id":"lly0j9zs"}
|
code/RL_model/verl/Search-R1/wandb/debug.log
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
2026-02-01 20:27:25,874 INFO MainThread:1578907 [wandb_setup.py:_flush():80] Current SDK version is 0.23.1
|
| 2 |
+
2026-02-01 20:27:25,874 INFO MainThread:1578907 [wandb_setup.py:_flush():80] Configure stats pid to 1578907
|
| 3 |
+
2026-02-01 20:27:25,875 INFO MainThread:1578907 [wandb_setup.py:_flush():80] Loading settings from /home/mshahidul/.config/wandb/settings
|
| 4 |
+
2026-02-01 20:27:25,875 INFO MainThread:1578907 [wandb_setup.py:_flush():80] Loading settings from /data/home_beta/mshahidul/readctrl/code/RL_model/verl/Search-R1/wandb/settings
|
| 5 |
+
2026-02-01 20:27:25,875 INFO MainThread:1578907 [wandb_setup.py:_flush():80] Loading settings from environment variables
|
| 6 |
+
2026-02-01 20:27:25,875 INFO MainThread:1578907 [wandb_init.py:setup_run_log_directory():714] Logging user logs to /data/home_beta/mshahidul/readctrl/code/RL_model/verl/Search-R1/wandb/run-20260201_202725-lly0j9zs/logs/debug.log
|
| 7 |
+
2026-02-01 20:27:25,875 INFO MainThread:1578907 [wandb_init.py:setup_run_log_directory():715] Logging internal logs to /data/home_beta/mshahidul/readctrl/code/RL_model/verl/Search-R1/wandb/run-20260201_202725-lly0j9zs/logs/debug-internal.log
|
| 8 |
+
2026-02-01 20:27:25,876 INFO MainThread:1578907 [wandb_init.py:init():841] calling init triggers
|
| 9 |
+
2026-02-01 20:27:25,876 INFO MainThread:1578907 [wandb_init.py:init():846] wandb.init called with sweep_config: {}
|
| 10 |
+
config: {'data': {'tokenizer': None, 'train_files': '/home/mshahidul/readctrl/code/RL_model/verl/Search-R1/dataset/train.parquet', 'val_files': '/home/mshahidul/readctrl/code/RL_model/verl/Search-R1/dataset/test.parquet', 'train_data_num': None, 'val_data_num': None, 'prompt_key': 'prompt', 'max_prompt_length': 4096, 'max_response_length': 1024, 'max_start_length': 256, 'max_obs_length': 512, 'train_batch_size': 128, 'val_batch_size': 64, 'return_raw_input_ids': False, 'return_raw_chat': False, 'shuffle_train_dataloader': True}, 'actor_rollout_ref': {'hybrid_engine': True, 'model': {'path': 'Qwen/Qwen3-4B-Instruct-2507', 'external_lib': None, 'override_config': {}, 'enable_gradient_checkpointing': True, 'use_remove_padding': False}, 'actor': {'strategy': 'fsdp', 'ppo_mini_batch_size': 64, 'ppo_micro_batch_size': 64, 'use_dynamic_bsz': False, 'ppo_max_token_len_per_gpu': 16384, 'grad_clip': 1.0, 'state_masking': False, 'clip_ratio': 0.2, 'entropy_coeff': 0.001, 'use_kl_loss': False, 'kl_loss_coef': 0.001, 'kl_loss_type': 'low_var_kl', 'ppo_epochs': 1, 'shuffle': False, 'ulysses_sequence_parallel_size': 1, 'optim': {'lr': 1e-06, 'lr_warmup_steps_ratio': 0.0, 'min_lr_ratio': None, 'warmup_style': 'constant', 'total_training_steps': 1005}, 'fsdp_config': {'wrap_policy': {'min_num_params': 0}, 'param_offload': True, 'grad_offload': False, 'optimizer_offload': True, 'fsdp_size': -1}, 'ppo_micro_batch_size_per_gpu': 16}, 'ref': {'fsdp_config': {'param_offload': True, 'wrap_policy': {'min_num_params': 0}, 'fsdp_size': -1}, 'log_prob_micro_batch_size': 64, 'log_prob_use_dynamic_bsz': False, 'log_prob_max_token_len_per_gpu': 16384, 'ulysses_sequence_parallel_size': 1}, 'rollout': {'name': 'vllm', 'temperature': 1.0, 'top_k': -1, 'top_p': 0.95, 'prompt_length': 4096, 'response_length': 1024, 'dtype': 'bfloat16', 'gpu_memory_utilization': 0.4, 'ignore_eos': False, 'enforce_eager': True, 'free_cache_engine': True, 'load_format': 'dummy_dtensor', 'tensor_model_parallel_size': 1, 'max_num_batched_tokens': 8192, 'max_num_seqs': 1024, 'log_prob_micro_batch_size': 64, 'log_prob_use_dynamic_bsz': False, 'log_prob_max_token_len_per_gpu': 16384, 'do_sample': True, 'n': 1, 'n_agent': 1}}, 'critic': {'strategy': 'fsdp', 'optim': {'lr': 1e-05, 'lr_warmup_steps_ratio': 0.0, 'min_lr_ratio': None, 'warmup_style': 'constant', 'total_training_steps': 1005}, 'model': {'path': '~/models/deepseek-llm-7b-chat', 'tokenizer_path': 'Qwen/Qwen3-4B-Instruct-2507', 'override_config': {}, 'external_lib': None, 'enable_gradient_checkpointing': False, 'use_remove_padding': False, 'fsdp_config': {'param_offload': False, 'grad_offload': False, 'optimizer_offload': False, 'wrap_policy': {'min_num_params': 0}, 'fsdp_size': -1}}, 'ppo_mini_batch_size': 64, 'ppo_micro_batch_size': 64, 'forward_micro_batch_size': 64, 'use_dynamic_bsz': False, 'ppo_max_token_len_per_gpu': 32768, 'forward_max_token_len_per_gpu': 32768, 'ulysses_sequence_parallel_size': 1, 'ppo_epochs': 1, 'shuffle': False, 'grad_clip': 1.0, 'cliprange_value': 0.5}, 'reward_model': {'enable': False, 'strategy': 'fsdp', 'model': {'input_tokenizer': 'Qwen/Qwen3-4B-Instruct-2507', 'path': '~/models/FsfairX-LLaMA3-RM-v0.1', 'external_lib': None, 'use_remove_padding': False, 'fsdp_config': {'min_num_params': 0, 'param_offload': False}}, 'micro_batch_size': 64, 'max_length': None, 'ulysses_sequence_parallel_size': 1, 'use_dynamic_bsz': False, 'forward_max_token_len_per_gpu': 32768, 'structure_format_score': 0, 'final_format_score': 0, 'retrieval_score': 0}, 'retriever': {'url': 'http://127.0.0.1:8000/retrieve', 'topk': 3}, 'algorithm': {'gamma': 1.0, 'lam': 1.0, 'adv_estimator': 'grpo', 'no_think_rl': False, 'kl_penalty': 'kl', 'kl_ctrl': {'type': 'fixed', 'kl_coef': 0.001}, 'state_masking': {'start_state_marker': '<information>', 'end_state_marker': '</information>'}}, 'trainer': {'total_epochs': 15, 'total_training_steps': 1005, 'project_name': '', 'experiment_name': 'llm_guard_3B_10k_v2', 'logger': ['wandb'], 'nnodes': 1, 'n_gpus_per_node': 2, 'save_freq': 100, 'test_freq': 50, 'critic_warmup': 0, 'default_hdfs_dir': '~/experiments/gsm8k/ppo/llm_guard_3B_10k_v2', 'default_local_dir': 'verl_checkpoints/llm_guard_3B_10k_v2'}, 'max_turns': 1, 'do_search': False, '_wandb': {}}
|
| 11 |
+
2026-02-01 20:27:25,876 INFO MainThread:1578907 [wandb_init.py:init():889] starting backend
|
| 12 |
+
2026-02-01 20:27:26,251 INFO MainThread:1578907 [wandb_init.py:init():892] sending inform_init request
|
| 13 |
+
2026-02-01 20:27:26,261 INFO MainThread:1578907 [wandb_init.py:init():900] backend started and connected
|
| 14 |
+
2026-02-01 20:27:26,270 INFO MainThread:1578907 [wandb_init.py:init():970] updated telemetry
|
| 15 |
+
2026-02-01 20:27:26,293 INFO MainThread:1578907 [wandb_init.py:init():994] communicating run to backend with 90.0 second timeout
|
| 16 |
+
2026-02-01 20:27:27,908 INFO MainThread:1578907 [wandb_init.py:init():1041] starting run threads in backend
|
| 17 |
+
2026-02-01 20:27:28,715 INFO MainThread:1578907 [wandb_run.py:_console_start():2521] atexit reg
|
| 18 |
+
2026-02-01 20:27:28,716 INFO MainThread:1578907 [wandb_run.py:_redirect():2369] redirect: wrap_raw
|
| 19 |
+
2026-02-01 20:27:28,716 INFO MainThread:1578907 [wandb_run.py:_redirect():2438] Wrapping output streams.
|
| 20 |
+
2026-02-01 20:27:28,716 INFO MainThread:1578907 [wandb_run.py:_redirect():2461] Redirects installed.
|
| 21 |
+
2026-02-01 20:27:28,726 INFO MainThread:1578907 [wandb_init.py:init():1081] run started, returning control to user process
|
code/RL_model/verl/verl_train/tests/experimental/agent_loop/agent_utils.py
ADDED
|
@@ -0,0 +1,92 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import ray
|
| 16 |
+
from omegaconf import DictConfig
|
| 17 |
+
|
| 18 |
+
from verl.experimental.agent_loop import AgentLoopManager
|
| 19 |
+
from verl.single_controller.ray import RayClassWithInitArgs, RayWorkerGroup
|
| 20 |
+
from verl.single_controller.ray.base import create_colocated_worker_cls
|
| 21 |
+
from verl.trainer.ppo.ray_trainer import ResourcePoolManager, Role
|
| 22 |
+
from verl.workers.fsdp_workers import ActorRolloutRefWorker, AsyncActorRolloutRefWorker, RewardModelWorker
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def init_agent_loop_manager(config: DictConfig) -> AgentLoopManager | RayWorkerGroup:
|
| 26 |
+
# =========================== 1. Create hybrid ActorRollout workers ===========================
|
| 27 |
+
actor_rollout_cls = (
|
| 28 |
+
AsyncActorRolloutRefWorker if config.actor_rollout_ref.rollout.mode == "async" else ActorRolloutRefWorker
|
| 29 |
+
)
|
| 30 |
+
role_worker_mapping = {
|
| 31 |
+
Role.ActorRollout: ray.remote(actor_rollout_cls),
|
| 32 |
+
}
|
| 33 |
+
if config.reward_model.enable:
|
| 34 |
+
role_worker_mapping[Role.RewardModel] = ray.remote(RewardModelWorker)
|
| 35 |
+
|
| 36 |
+
global_pool_id = "global_pool"
|
| 37 |
+
resource_pool_spec = {
|
| 38 |
+
global_pool_id: [config.trainer.n_gpus_per_node] * config.trainer.nnodes,
|
| 39 |
+
}
|
| 40 |
+
mapping = {
|
| 41 |
+
Role.ActorRollout: global_pool_id,
|
| 42 |
+
}
|
| 43 |
+
if config.reward_model.enable_resource_pool:
|
| 44 |
+
mapping[Role.RewardModel] = "reward_pool"
|
| 45 |
+
if config.reward_model.n_gpus_per_node <= 0:
|
| 46 |
+
raise ValueError("config.reward_model.n_gpus_per_node must be greater than 0")
|
| 47 |
+
if config.reward_model.nnodes <= 0:
|
| 48 |
+
raise ValueError("config.reward_model.nnodes must be greater than 0")
|
| 49 |
+
|
| 50 |
+
reward_pool = [config.reward_model.n_gpus_per_node] * config.reward_model.nnodes
|
| 51 |
+
resource_pool_spec["reward_pool"] = reward_pool
|
| 52 |
+
resource_pool_manager = ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=mapping)
|
| 53 |
+
resource_pool_manager.create_resource_pool()
|
| 54 |
+
resource_pool_to_cls = {pool: {} for pool in resource_pool_manager.resource_pool_dict.values()}
|
| 55 |
+
|
| 56 |
+
# create actor and rollout
|
| 57 |
+
resource_pool = resource_pool_manager.get_resource_pool(Role.ActorRollout)
|
| 58 |
+
actor_rollout_cls = RayClassWithInitArgs(
|
| 59 |
+
cls=role_worker_mapping[Role.ActorRollout], config=config.actor_rollout_ref, role="actor_rollout"
|
| 60 |
+
)
|
| 61 |
+
resource_pool_to_cls[resource_pool]["actor_rollout"] = actor_rollout_cls
|
| 62 |
+
|
| 63 |
+
if config.reward_model.enable:
|
| 64 |
+
# we create a RM here
|
| 65 |
+
resource_pool = resource_pool_manager.get_resource_pool(Role.RewardModel)
|
| 66 |
+
rm_cls = RayClassWithInitArgs(role_worker_mapping[Role.RewardModel], config=config.reward_model)
|
| 67 |
+
resource_pool_to_cls[resource_pool]["rm"] = rm_cls
|
| 68 |
+
|
| 69 |
+
all_wg = {}
|
| 70 |
+
for resource_pool, class_dict in resource_pool_to_cls.items():
|
| 71 |
+
worker_dict_cls = create_colocated_worker_cls(class_dict=class_dict)
|
| 72 |
+
wg_dict = RayWorkerGroup(resource_pool=resource_pool, ray_cls_with_init=worker_dict_cls)
|
| 73 |
+
spawn_wg = wg_dict.spawn(prefix_set=class_dict.keys())
|
| 74 |
+
all_wg.update(spawn_wg)
|
| 75 |
+
actor_rollout_wg = all_wg["actor_rollout"]
|
| 76 |
+
actor_rollout_wg.init_model()
|
| 77 |
+
|
| 78 |
+
if config.actor_rollout_ref.rollout.mode == "sync":
|
| 79 |
+
raise ValueError("Agent loop tests require async rollout mode. Please set rollout.mode=async.")
|
| 80 |
+
|
| 81 |
+
if config.reward_model.enable_resource_pool and config.reward_model.enable:
|
| 82 |
+
rm_resource_pool = resource_pool_manager.get_resource_pool(Role.RewardModel)
|
| 83 |
+
else:
|
| 84 |
+
rm_resource_pool = None
|
| 85 |
+
# =========================== 2. Create AgentLoopManager ===========================
|
| 86 |
+
agent_loop_manager = AgentLoopManager(
|
| 87 |
+
config=config,
|
| 88 |
+
worker_group=actor_rollout_wg,
|
| 89 |
+
rm_resource_pool=rm_resource_pool,
|
| 90 |
+
)
|
| 91 |
+
|
| 92 |
+
return agent_loop_manager
|
code/RL_model/verl/verl_train/tests/experimental/agent_loop/qwen_vl_tool_chat_template.jinja2
ADDED
|
@@ -0,0 +1,150 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{% set image_count = namespace(value=0) %}
|
| 2 |
+
{% set video_count = namespace(value=0) %}
|
| 3 |
+
{%- if tools %}
|
| 4 |
+
{{- '<|im_start|>system\n' }}
|
| 5 |
+
{%- if messages[0]['role'] == 'system' %}
|
| 6 |
+
{%- if messages[0]['content'] is string %}
|
| 7 |
+
{{- messages[0]['content'] }}
|
| 8 |
+
{%- else %}
|
| 9 |
+
{{- messages[0]['content'][0]['text'] }}
|
| 10 |
+
{%- endif %}
|
| 11 |
+
{%- else %}
|
| 12 |
+
{{- 'You are a helpful assistant.' }}
|
| 13 |
+
{%- endif %}
|
| 14 |
+
{{- "\n\n# Tools\n\nYou may call one or more functions to assist with the user query.\n\nYou are provided with function signatures within <tools></tools> XML tags:\n<tools>" }}
|
| 15 |
+
{%- for tool in tools %}
|
| 16 |
+
{{- "\n" }}
|
| 17 |
+
{{- tool | tojson }}
|
| 18 |
+
{%- endfor %}
|
| 19 |
+
{{- "\n</tools>\n\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\n<tool_call>\n{\"name\": <function-name>, \"arguments\": <args-json-object>}\n</tool_call><|im_end|>\n" }}
|
| 20 |
+
{% for message in messages %}
|
| 21 |
+
{% if message['role'] != 'system' or loop.first == false %}
|
| 22 |
+
{%- if (message.role == "user") or (message.role == "system" and not loop.first) or (message.role == "assistant" and not message.tool_calls) %}
|
| 23 |
+
<|im_start|>{{ message['role'] }}
|
| 24 |
+
{% if message['content'] is string %}
|
| 25 |
+
{{ message['content'] }}<|im_end|>
|
| 26 |
+
{% else %}
|
| 27 |
+
{% for content in message['content'] %}
|
| 28 |
+
{% if content['type'] == 'image' or 'image' in content or 'image_url' in content %}
|
| 29 |
+
{% set image_count.value = image_count.value + 1 %}
|
| 30 |
+
{% if add_vision_id %}Picture {{ image_count.value }}: {% endif %}<|vision_start|><|image_pad|><|vision_end|>
|
| 31 |
+
{% elif content['type'] == 'video' or 'video' in content %}
|
| 32 |
+
{% set video_count.value = video_count.value + 1 %}
|
| 33 |
+
{% if add_vision_id %}Video {{ video_count.value }}: {% endif %}<|vision_start|><|video_pad|><|vision_end|>
|
| 34 |
+
{% elif 'text' in content %}
|
| 35 |
+
{{ content['text'] }}
|
| 36 |
+
{% endif %}
|
| 37 |
+
{% endfor %}<|im_end|>
|
| 38 |
+
{% endif %}
|
| 39 |
+
{%- elif message.role == "assistant" %}
|
| 40 |
+
{{- '<|im_start|>' + message.role }}
|
| 41 |
+
{%- if message.content %}
|
| 42 |
+
{{- '\n' + message.content }}
|
| 43 |
+
{%- endif %}
|
| 44 |
+
{%- for tool_call in message.tool_calls %}
|
| 45 |
+
{%- if tool_call.function is defined %}
|
| 46 |
+
{%- set tool_call = tool_call.function %}
|
| 47 |
+
{%- endif %}
|
| 48 |
+
{{- '\n<tool_call>\n{"name": "' }}
|
| 49 |
+
{{- tool_call.name }}
|
| 50 |
+
{{- '", "arguments": ' }}
|
| 51 |
+
{{- tool_call.arguments | tojson }}
|
| 52 |
+
{{- '}\n</tool_call>' }}
|
| 53 |
+
{%- endfor %}
|
| 54 |
+
{{- '<|im_end|>\n' }}
|
| 55 |
+
{%- elif message.role == "tool" %}
|
| 56 |
+
{%- if (loop.index0 == 0) or (messages[loop.index0 - 1].role != "tool") %}
|
| 57 |
+
{{- '<|im_start|>user' }}
|
| 58 |
+
{%- endif %}
|
| 59 |
+
{{- '\n<tool_response>\n' }}
|
| 60 |
+
{% if message['content'] is string %}
|
| 61 |
+
{{ message.content }}
|
| 62 |
+
{% else %}
|
| 63 |
+
{% for content in message['content'] %}
|
| 64 |
+
{% if content['type'] == 'image' or 'image' in content or 'image_url' in content %}
|
| 65 |
+
{% set image_count.value = image_count.value + 1 %}
|
| 66 |
+
{% if add_vision_id %}Picture {{ image_count.value }}: {% endif %}<|vision_start|><|image_pad|><|vision_end|>
|
| 67 |
+
{% elif content['type'] == 'video' or 'video' in content %}
|
| 68 |
+
{% set video_count.value = video_count.value + 1 %}
|
| 69 |
+
{% if add_vision_id %}Video {{ video_count.value }}: {% endif %}<|vision_start|><|video_pad|><|vision_end|>
|
| 70 |
+
{% elif content['type'] == 'text' or 'text' in content %}
|
| 71 |
+
{{ content['text'] }}
|
| 72 |
+
{% endif %}
|
| 73 |
+
{% endfor %}
|
| 74 |
+
{% endif %}
|
| 75 |
+
{{- '\n</tool_response>' }}
|
| 76 |
+
{%- if loop.last or (messages[loop.index0 + 1].role != "tool") %}
|
| 77 |
+
{{- '<|im_end|>\n' }}
|
| 78 |
+
{%- endif %}
|
| 79 |
+
{%- endif %}
|
| 80 |
+
{% endif %}
|
| 81 |
+
{% endfor %}
|
| 82 |
+
{%- else %}
|
| 83 |
+
{% for message in messages %}
|
| 84 |
+
{% if loop.first and message['role'] != 'system' %}
|
| 85 |
+
<|im_start|>system
|
| 86 |
+
You are a helpful assistant.<|im_end|>
|
| 87 |
+
{% endif %}
|
| 88 |
+
{%- if (message.role == "user") or (message.role == "system" and not loop.first) or (message.role == "assistant" and not message.tool_calls) %}
|
| 89 |
+
<|im_start|>{{ message['role'] }}
|
| 90 |
+
{% if message['content'] is string %}
|
| 91 |
+
{{ message['content'] }}<|im_end|>
|
| 92 |
+
{% else %}
|
| 93 |
+
{% for content in message['content'] %}
|
| 94 |
+
{% if content['type'] == 'image' or 'image' in content or 'image_url' in content %}
|
| 95 |
+
{% set image_count.value = image_count.value + 1 %}
|
| 96 |
+
{% if add_vision_id %}Picture {{ image_count.value }}: {% endif %}<|vision_start|><|image_pad|><|vision_end|>
|
| 97 |
+
{% elif content['type'] == 'video' or 'video' in content %}
|
| 98 |
+
{% set video_count.value = video_count.value + 1 %}
|
| 99 |
+
{% if add_vision_id %}Video {{ video_count.value }}: {% endif %}<|vision_start|><|video_pad|><|vision_end|>
|
| 100 |
+
{% elif 'text' in content %}
|
| 101 |
+
{{ content['text'] }}
|
| 102 |
+
{% endif %}
|
| 103 |
+
{% endfor %}<|im_end|>
|
| 104 |
+
{% endif %}
|
| 105 |
+
{%- elif message.role == "assistant" %}
|
| 106 |
+
{{- '<|im_start|>' + message.role }}
|
| 107 |
+
{%- if message.content %}
|
| 108 |
+
{{- '\n' + message.content }}
|
| 109 |
+
{%- endif %}
|
| 110 |
+
{%- for tool_call in message.tool_calls %}
|
| 111 |
+
{%- if tool_call.function is defined %}
|
| 112 |
+
{%- set tool_call = tool_call.function %}
|
| 113 |
+
{%- endif %}
|
| 114 |
+
{{- '\n<tool_call>\n{"name": "' }}
|
| 115 |
+
{{- tool_call.name }}
|
| 116 |
+
{{- '", "arguments": ' }}
|
| 117 |
+
{{- tool_call.arguments | tojson }}
|
| 118 |
+
{{- '}\n</tool_call>' }}
|
| 119 |
+
{%- endfor %}
|
| 120 |
+
{{- '<|im_end|>\n' }}
|
| 121 |
+
{%- elif message.role == "tool" %}
|
| 122 |
+
{%- if (loop.index0 == 0) or (messages[loop.index0 - 1].role != "tool") %}
|
| 123 |
+
{{- '<|im_start|>user' }}
|
| 124 |
+
{%- endif %}
|
| 125 |
+
{{- '\n<tool_response>\n' }}
|
| 126 |
+
{% if message['content'] is string %}
|
| 127 |
+
{{ message.content }}
|
| 128 |
+
{% else %}
|
| 129 |
+
{% for content in message['content'] %}
|
| 130 |
+
{% if content['type'] == 'image' or 'image' in content or 'image_url' in content %}
|
| 131 |
+
{% set image_count.value = image_count.value + 1 %}
|
| 132 |
+
{% if add_vision_id %}Picture {{ image_count.value }}: {% endif %}<|vision_start|><|image_pad|><|vision_end|>
|
| 133 |
+
{% elif content['type'] == 'video' or 'video' in content %}
|
| 134 |
+
{% set video_count.value = video_count.value + 1 %}
|
| 135 |
+
{% if add_vision_id %}Video {{ video_count.value }}: {% endif %}<|vision_start|><|video_pad|><|vision_end|>
|
| 136 |
+
{% elif content['type'] == 'text' or 'text' in content %}
|
| 137 |
+
{{ content['text'] }}
|
| 138 |
+
{% endif %}
|
| 139 |
+
{% endfor %}
|
| 140 |
+
{% endif %}
|
| 141 |
+
{{- '\n</tool_response>' }}
|
| 142 |
+
{%- if loop.last or (messages[loop.index0 + 1].role != "tool") %}
|
| 143 |
+
{{- '<|im_end|>\n' }}
|
| 144 |
+
{%- endif %}
|
| 145 |
+
{%- endif %}
|
| 146 |
+
{% endfor %}
|
| 147 |
+
{%- endif %}
|
| 148 |
+
{% if add_generation_prompt %}
|
| 149 |
+
<|im_start|>assistant
|
| 150 |
+
{% endif %}
|
code/RL_model/verl/verl_train/tests/experimental/agent_loop/test_basic_agent_loop.py
ADDED
|
@@ -0,0 +1,454 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
import json
|
| 15 |
+
import os
|
| 16 |
+
from typing import Any
|
| 17 |
+
|
| 18 |
+
import numpy as np
|
| 19 |
+
import pytest
|
| 20 |
+
import ray
|
| 21 |
+
from omegaconf import DictConfig
|
| 22 |
+
from transformers.utils import get_json_schema
|
| 23 |
+
|
| 24 |
+
from tests.experimental.agent_loop.agent_utils import init_agent_loop_manager
|
| 25 |
+
from verl.checkpoint_engine import CheckpointEngineManager
|
| 26 |
+
from verl.experimental.agent_loop import AgentLoopManager
|
| 27 |
+
from verl.experimental.agent_loop.agent_loop import get_trajectory_info
|
| 28 |
+
from verl.protocol import DataProto
|
| 29 |
+
from verl.tools.base_tool import BaseTool, OpenAIFunctionToolSchema
|
| 30 |
+
from verl.tools.schemas import ToolResponse
|
| 31 |
+
from verl.trainer.ppo.reward import compute_reward, load_reward_manager
|
| 32 |
+
from verl.utils import hf_tokenizer
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
@pytest.fixture
|
| 36 |
+
def init_config() -> DictConfig:
|
| 37 |
+
from hydra import compose, initialize_config_dir
|
| 38 |
+
|
| 39 |
+
with initialize_config_dir(config_dir=os.path.abspath("verl/trainer/config")):
|
| 40 |
+
config = compose(
|
| 41 |
+
config_name="ppo_trainer",
|
| 42 |
+
overrides=[
|
| 43 |
+
"actor_rollout_ref.actor.use_dynamic_bsz=true",
|
| 44 |
+
# test sleep/wake_up with fsdp offload
|
| 45 |
+
"actor_rollout_ref.actor.fsdp_config.param_offload=True",
|
| 46 |
+
"actor_rollout_ref.actor.fsdp_config.optimizer_offload=True",
|
| 47 |
+
"reward_model.reward_manager=dapo",
|
| 48 |
+
"+reward_model.reward_kwargs.overlong_buffer_cfg.enable=False",
|
| 49 |
+
"+reward_model.reward_kwargs.overlong_buffer_cfg.len=3072",
|
| 50 |
+
"+reward_model.reward_kwargs.max_resp_len=4096",
|
| 51 |
+
],
|
| 52 |
+
)
|
| 53 |
+
|
| 54 |
+
model_path = os.path.expanduser("~/models/Qwen/Qwen2.5-1.5B-Instruct")
|
| 55 |
+
config.actor_rollout_ref.model.path = model_path
|
| 56 |
+
config.actor_rollout_ref.rollout.name = os.environ["ROLLOUT_NAME"]
|
| 57 |
+
config.actor_rollout_ref.rollout.mode = "async"
|
| 58 |
+
config.actor_rollout_ref.rollout.enforce_eager = True
|
| 59 |
+
config.actor_rollout_ref.rollout.prompt_length = 4096
|
| 60 |
+
config.actor_rollout_ref.rollout.response_length = 4096
|
| 61 |
+
config.actor_rollout_ref.rollout.n = 4
|
| 62 |
+
config.actor_rollout_ref.rollout.agent.num_workers = 2
|
| 63 |
+
config.actor_rollout_ref.rollout.skip_tokenizer_init = True
|
| 64 |
+
|
| 65 |
+
return config
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
def test_single_turn(init_config):
|
| 69 |
+
ray.init(
|
| 70 |
+
runtime_env={
|
| 71 |
+
"env_vars": {
|
| 72 |
+
"TOKENIZERS_PARALLELISM": "true",
|
| 73 |
+
"NCCL_DEBUG": "WARN",
|
| 74 |
+
"VLLM_LOGGING_LEVEL": "INFO",
|
| 75 |
+
"VLLM_USE_V1": "1",
|
| 76 |
+
}
|
| 77 |
+
}
|
| 78 |
+
)
|
| 79 |
+
|
| 80 |
+
agent_loop_manager = AgentLoopManager(init_config)
|
| 81 |
+
tokenizer = hf_tokenizer(init_config.actor_rollout_ref.model.path)
|
| 82 |
+
reward_fn = load_reward_manager(
|
| 83 |
+
init_config, tokenizer, num_examine=0, **init_config.reward_model.get("reward_kwargs", {})
|
| 84 |
+
)
|
| 85 |
+
|
| 86 |
+
raw_prompts = [
|
| 87 |
+
[
|
| 88 |
+
{
|
| 89 |
+
"role": "user",
|
| 90 |
+
"content": "Let's play a role playing game. Your name is Alice, your favorite color is blue.",
|
| 91 |
+
}
|
| 92 |
+
],
|
| 93 |
+
[{"role": "user", "content": "Let's play a role playing game. Your name is Bob, your favorite color is red."}],
|
| 94 |
+
]
|
| 95 |
+
batch = DataProto(
|
| 96 |
+
non_tensor_batch={
|
| 97 |
+
"raw_prompt": np.array(raw_prompts),
|
| 98 |
+
"agent_name": np.array(["single_turn_agent"] * len(raw_prompts)),
|
| 99 |
+
"data_source": np.array(["openai/gsm8k"] * len(raw_prompts)),
|
| 100 |
+
"reward_model": np.array([{"style": "rule", "ground_truth": "1.0"}] * len(raw_prompts)),
|
| 101 |
+
},
|
| 102 |
+
)
|
| 103 |
+
n = init_config.actor_rollout_ref.rollout.n
|
| 104 |
+
batch = batch.repeat(n)
|
| 105 |
+
result = agent_loop_manager.generate_sequences(prompts=batch)
|
| 106 |
+
assert len(result) == len(raw_prompts) * n
|
| 107 |
+
|
| 108 |
+
# check result
|
| 109 |
+
seq_len = result.batch["prompts"].size(1) + result.batch["responses"].size(1)
|
| 110 |
+
assert result.batch["input_ids"].size(1) == seq_len
|
| 111 |
+
assert result.batch["attention_mask"].size(1) == seq_len
|
| 112 |
+
assert result.batch["position_ids"].size(1) == seq_len
|
| 113 |
+
|
| 114 |
+
if init_config.actor_rollout_ref.rollout.calculate_log_probs:
|
| 115 |
+
assert result.batch["rollout_log_probs"].size(1) == result.batch["responses"].size(1)
|
| 116 |
+
|
| 117 |
+
# check compute score
|
| 118 |
+
assert result.batch["rm_scores"].shape == result.batch["responses"].shape
|
| 119 |
+
reward_tensor, reward_extra_info = compute_reward(result, reward_fn)
|
| 120 |
+
assert reward_tensor.shape == result.batch["responses"].shape
|
| 121 |
+
assert "acc" in reward_extra_info, f"reward_extra_info {reward_extra_info} should contain 'acc'"
|
| 122 |
+
assert reward_extra_info["acc"].shape == (len(result),), f"invalid acc: {reward_extra_info['acc']}"
|
| 123 |
+
|
| 124 |
+
# check turns
|
| 125 |
+
num_turns = result.non_tensor_batch["__num_turns__"]
|
| 126 |
+
assert np.all(num_turns == 2)
|
| 127 |
+
|
| 128 |
+
print("Test passed!")
|
| 129 |
+
ray.shutdown()
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
class WeatherTool(BaseTool):
|
| 133 |
+
def get_current_temperature(self, location: str, unit: str = "celsius"):
|
| 134 |
+
"""Get current temperature at a location.
|
| 135 |
+
|
| 136 |
+
Args:
|
| 137 |
+
location: The location to get the temperature for, in the format "City, State, Country".
|
| 138 |
+
unit: The unit to return the temperature in. Defaults to "celsius". (choices: ["celsius", "fahrenheit"])
|
| 139 |
+
|
| 140 |
+
Returns:
|
| 141 |
+
the temperature, the location, and the unit in a dict
|
| 142 |
+
"""
|
| 143 |
+
print(f"[DEBUG] get_current_temperature: {location}, {unit}")
|
| 144 |
+
return {
|
| 145 |
+
"temperature": 26.1,
|
| 146 |
+
"location": location,
|
| 147 |
+
"unit": unit,
|
| 148 |
+
}
|
| 149 |
+
|
| 150 |
+
def get_openai_tool_schema(self) -> OpenAIFunctionToolSchema:
|
| 151 |
+
schema = get_json_schema(self.get_current_temperature)
|
| 152 |
+
return OpenAIFunctionToolSchema(**schema)
|
| 153 |
+
|
| 154 |
+
async def execute(self, instance_id: str, parameters: dict[str, Any], **kwargs) -> tuple[ToolResponse, float, dict]:
|
| 155 |
+
try:
|
| 156 |
+
result = self.get_current_temperature(**parameters)
|
| 157 |
+
return ToolResponse(text=json.dumps(result)), 0, {}
|
| 158 |
+
except Exception as e:
|
| 159 |
+
return ToolResponse(text=str(e)), 0, {}
|
| 160 |
+
|
| 161 |
+
|
| 162 |
+
class WeatherToolWithData(BaseTool):
|
| 163 |
+
def get_openai_tool_schema(self) -> OpenAIFunctionToolSchema:
|
| 164 |
+
schema = get_json_schema(self.get_temperature_date)
|
| 165 |
+
return OpenAIFunctionToolSchema(**schema)
|
| 166 |
+
|
| 167 |
+
def get_temperature_date(self, location: str, date: str, unit: str = "celsius"):
|
| 168 |
+
"""Get temperature at a location and date.
|
| 169 |
+
|
| 170 |
+
Args:
|
| 171 |
+
location: The location to get the temperature for, in the format "City, State, Country".
|
| 172 |
+
date: The date to get the temperature for, in the format "Year-Month-Day".
|
| 173 |
+
unit: The unit to return the temperature in. Defaults to "celsius". (choices: ["celsius", "fahrenheit"])
|
| 174 |
+
|
| 175 |
+
Returns:
|
| 176 |
+
the temperature, the location, the date and the unit in a dict
|
| 177 |
+
"""
|
| 178 |
+
print(f"[DEBUG] get_temperature_date: {location}, {date}, {unit}")
|
| 179 |
+
return {
|
| 180 |
+
"temperature": 25.9,
|
| 181 |
+
"location": location,
|
| 182 |
+
"date": date,
|
| 183 |
+
"unit": unit,
|
| 184 |
+
}
|
| 185 |
+
|
| 186 |
+
async def execute(self, instance_id: str, parameters: dict[str, Any], **kwargs) -> tuple[ToolResponse, float, dict]:
|
| 187 |
+
try:
|
| 188 |
+
result = self.get_temperature_date(**parameters)
|
| 189 |
+
return ToolResponse(text=json.dumps(result)), 0, {}
|
| 190 |
+
except Exception as e:
|
| 191 |
+
return ToolResponse(text=str(e)), 0, {}
|
| 192 |
+
|
| 193 |
+
|
| 194 |
+
def test_tool_agent(init_config):
|
| 195 |
+
ray.init(
|
| 196 |
+
runtime_env={
|
| 197 |
+
"env_vars": {
|
| 198 |
+
"TOKENIZERS_PARALLELISM": "true",
|
| 199 |
+
"NCCL_DEBUG": "WARN",
|
| 200 |
+
"VLLM_LOGGING_LEVEL": "INFO",
|
| 201 |
+
"VLLM_USE_V1": "1",
|
| 202 |
+
}
|
| 203 |
+
},
|
| 204 |
+
ignore_reinit_error=True,
|
| 205 |
+
)
|
| 206 |
+
|
| 207 |
+
# =========================== 1. Init rollout manager ===========================
|
| 208 |
+
tool_config = {
|
| 209 |
+
"tools": [
|
| 210 |
+
{
|
| 211 |
+
"class_name": "tests.experimental.agent_loop.test_basic_agent_loop.WeatherTool",
|
| 212 |
+
"config": {"type": "native"},
|
| 213 |
+
},
|
| 214 |
+
{
|
| 215 |
+
"class_name": "tests.experimental.agent_loop.test_basic_agent_loop.WeatherToolWithData",
|
| 216 |
+
"config": {"type": "native"},
|
| 217 |
+
},
|
| 218 |
+
]
|
| 219 |
+
}
|
| 220 |
+
tool_config_path = "/tmp/tool_config.json"
|
| 221 |
+
with open(tool_config_path, "w") as f:
|
| 222 |
+
json.dump(tool_config, f)
|
| 223 |
+
|
| 224 |
+
n = 2
|
| 225 |
+
init_config.actor_rollout_ref.rollout.n = n
|
| 226 |
+
init_config.actor_rollout_ref.rollout.multi_turn.tool_config_path = tool_config_path
|
| 227 |
+
init_config.actor_rollout_ref.rollout.multi_turn.max_parallel_calls = 2
|
| 228 |
+
init_config.actor_rollout_ref.rollout.calculate_log_probs = True
|
| 229 |
+
agent_loop_manager = AgentLoopManager(init_config)
|
| 230 |
+
|
| 231 |
+
# =========================== 2. Generate sequences ===========================
|
| 232 |
+
raw_prompts = [
|
| 233 |
+
[
|
| 234 |
+
{"role": "user", "content": "How are you?"},
|
| 235 |
+
],
|
| 236 |
+
[
|
| 237 |
+
{"role": "user", "content": "What's the temperature in Los Angeles now?"},
|
| 238 |
+
],
|
| 239 |
+
[
|
| 240 |
+
{"role": "user", "content": "What's the temperature in New York now?"},
|
| 241 |
+
],
|
| 242 |
+
[
|
| 243 |
+
{
|
| 244 |
+
"role": "system",
|
| 245 |
+
"content": "You are Qwen, created by Alibaba Cloud. You are a helpful assistant.\n\n"
|
| 246 |
+
"Current Date: 2024-09-30",
|
| 247 |
+
},
|
| 248 |
+
{"role": "user", "content": "What's the temperature in San Francisco now? How about tomorrow?"},
|
| 249 |
+
],
|
| 250 |
+
]
|
| 251 |
+
batch = DataProto(
|
| 252 |
+
non_tensor_batch={
|
| 253 |
+
"raw_prompt": np.array([np.array(prompt) for prompt in raw_prompts], dtype=object),
|
| 254 |
+
"agent_name": np.array(["tool_agent"] * len(raw_prompts)),
|
| 255 |
+
"data_source": np.array(["openai/gsm8k"] * len(raw_prompts)),
|
| 256 |
+
"reward_model": np.array([{"style": "rule", "ground_truth": "1.0"}] * len(raw_prompts)),
|
| 257 |
+
},
|
| 258 |
+
)
|
| 259 |
+
batch = batch.repeat(n)
|
| 260 |
+
result = agent_loop_manager.generate_sequences(prompts=batch)
|
| 261 |
+
assert len(result) == len(raw_prompts) * n
|
| 262 |
+
|
| 263 |
+
# Check turns
|
| 264 |
+
num_turns = result.non_tensor_batch["__num_turns__"]
|
| 265 |
+
print(f"num_turns: {num_turns}")
|
| 266 |
+
for i in range(len(num_turns)):
|
| 267 |
+
if i // n == 0:
|
| 268 |
+
# [user, assistant]
|
| 269 |
+
assert num_turns[i] == 2
|
| 270 |
+
else:
|
| 271 |
+
# [user, assistant, tool, assistant]
|
| 272 |
+
assert num_turns[i] == 4
|
| 273 |
+
|
| 274 |
+
# Check response_mask
|
| 275 |
+
tokenizer = hf_tokenizer(init_config.actor_rollout_ref.model.path)
|
| 276 |
+
responses = result.batch["responses"]
|
| 277 |
+
response_mask = result.batch["response_mask"]
|
| 278 |
+
attention_mask = result.batch["attention_mask"]
|
| 279 |
+
assert result.batch["rm_scores"].size(1) == responses.size(1)
|
| 280 |
+
assert responses.size() == response_mask.size(), f"{responses.size()} != {response_mask.size()}"
|
| 281 |
+
assert result.batch["rollout_log_probs"].size(1) == result.batch["responses"].size(1)
|
| 282 |
+
|
| 283 |
+
response_length = response_mask.size(1)
|
| 284 |
+
for i in range(len(responses)):
|
| 285 |
+
# response with tool response
|
| 286 |
+
valid_tokens = responses[i][attention_mask[i][-response_length:].bool()]
|
| 287 |
+
response_with_obs = tokenizer.decode(valid_tokens)
|
| 288 |
+
|
| 289 |
+
# response without tool response
|
| 290 |
+
valid_tokens = responses[i][response_mask[i].bool()]
|
| 291 |
+
response_without_obs = tokenizer.decode(valid_tokens)
|
| 292 |
+
|
| 293 |
+
assert "<tool_response>" not in response_without_obs, (
|
| 294 |
+
f"found <tool_response> in response: {response_without_obs}"
|
| 295 |
+
)
|
| 296 |
+
assert "</tool_response>" not in response_without_obs, (
|
| 297 |
+
f"found </tool_response> in response: {response_without_obs}"
|
| 298 |
+
)
|
| 299 |
+
print("=========================")
|
| 300 |
+
print(response_with_obs)
|
| 301 |
+
print("---")
|
| 302 |
+
print(response_without_obs)
|
| 303 |
+
|
| 304 |
+
print("Test passed!")
|
| 305 |
+
ray.shutdown()
|
| 306 |
+
|
| 307 |
+
|
| 308 |
+
def test_tool_agent_with_interaction(init_config):
|
| 309 |
+
ray.init(
|
| 310 |
+
runtime_env={
|
| 311 |
+
"env_vars": {
|
| 312 |
+
"TOKENIZERS_PARALLELISM": "true",
|
| 313 |
+
"NCCL_DEBUG": "WARN",
|
| 314 |
+
"VLLM_LOGGING_LEVEL": "INFO",
|
| 315 |
+
"VLLM_USE_V1": "1",
|
| 316 |
+
}
|
| 317 |
+
}
|
| 318 |
+
)
|
| 319 |
+
|
| 320 |
+
# =========================== 1. Init rollout manager ===========================
|
| 321 |
+
tool_config = {
|
| 322 |
+
"tools": [
|
| 323 |
+
{
|
| 324 |
+
"class_name": "tests.experimental.agent_loop.test_basic_agent_loop.WeatherTool",
|
| 325 |
+
"config": {"type": "native"},
|
| 326 |
+
},
|
| 327 |
+
{
|
| 328 |
+
"class_name": "tests.experimental.agent_loop.test_basic_agent_loop.WeatherToolWithData",
|
| 329 |
+
"config": {"type": "native"},
|
| 330 |
+
},
|
| 331 |
+
]
|
| 332 |
+
}
|
| 333 |
+
tool_config_path = "/tmp/tool_config.json"
|
| 334 |
+
with open(tool_config_path, "w") as f:
|
| 335 |
+
json.dump(tool_config, f)
|
| 336 |
+
|
| 337 |
+
interaction_config = {
|
| 338 |
+
"interaction": [
|
| 339 |
+
{"name": "weather", "class_name": "verl.interactions.weather_interaction.WeatherInteraction", "config": {}}
|
| 340 |
+
]
|
| 341 |
+
}
|
| 342 |
+
interaction_config_path = "/tmp/interaction_config.json"
|
| 343 |
+
with open(interaction_config_path, "w") as f:
|
| 344 |
+
json.dump(interaction_config, f)
|
| 345 |
+
|
| 346 |
+
n = 2
|
| 347 |
+
init_config.actor_rollout_ref.rollout.n = n
|
| 348 |
+
init_config.actor_rollout_ref.rollout.multi_turn.tool_config_path = tool_config_path
|
| 349 |
+
init_config.actor_rollout_ref.rollout.multi_turn.interaction_config_path = interaction_config_path
|
| 350 |
+
init_config.actor_rollout_ref.rollout.multi_turn.max_parallel_calls = 2
|
| 351 |
+
agent_loop_manager = init_agent_loop_manager(init_config)
|
| 352 |
+
checkpoint_manager = CheckpointEngineManager(
|
| 353 |
+
backend=init_config.actor_rollout_ref.rollout.checkpoint_engine.backend,
|
| 354 |
+
trainer=agent_loop_manager.worker_group,
|
| 355 |
+
replicas=agent_loop_manager.rollout_replicas,
|
| 356 |
+
)
|
| 357 |
+
checkpoint_manager.sleep_replicas()
|
| 358 |
+
checkpoint_manager.update_weights()
|
| 359 |
+
|
| 360 |
+
# =========================== 2. Generate sequences ===========================
|
| 361 |
+
raw_prompts = [
|
| 362 |
+
[
|
| 363 |
+
{"role": "user", "content": "How are you?"},
|
| 364 |
+
],
|
| 365 |
+
[
|
| 366 |
+
{"role": "user", "content": "What's the temperature in Los Angeles now?"},
|
| 367 |
+
],
|
| 368 |
+
[
|
| 369 |
+
{"role": "user", "content": "What's the temperature in New York now?"},
|
| 370 |
+
],
|
| 371 |
+
[
|
| 372 |
+
{
|
| 373 |
+
"role": "system",
|
| 374 |
+
"content": "You are Qwen, created by Alibaba Cloud. You are a helpful assistant.\n\n"
|
| 375 |
+
"Current Date: 2024-09-30",
|
| 376 |
+
},
|
| 377 |
+
{"role": "user", "content": "What's the temperature in San Francisco now? How about tomorrow?"},
|
| 378 |
+
],
|
| 379 |
+
]
|
| 380 |
+
batch = DataProto(
|
| 381 |
+
non_tensor_batch={
|
| 382 |
+
"raw_prompt": np.array([np.array(prompt) for prompt in raw_prompts], dtype=object),
|
| 383 |
+
"agent_name": np.array(["tool_agent"] * len(raw_prompts)),
|
| 384 |
+
"data_source": np.array(["openai/gsm8k"] * len(raw_prompts)),
|
| 385 |
+
"reward_model": np.array([{"style": "rule", "ground_truth": "1.0"}] * len(raw_prompts)),
|
| 386 |
+
"extra_info": np.array(
|
| 387 |
+
[
|
| 388 |
+
{"interaction_kwargs": {"name": "weather"}},
|
| 389 |
+
{"interaction_kwargs": {"name": "weather"}},
|
| 390 |
+
{"interaction_kwargs": {"name": "weather"}},
|
| 391 |
+
{"interaction_kwargs": {"name": "weather"}},
|
| 392 |
+
]
|
| 393 |
+
),
|
| 394 |
+
},
|
| 395 |
+
)
|
| 396 |
+
batch = batch.repeat(n)
|
| 397 |
+
result = agent_loop_manager.generate_sequences(prompts=batch)
|
| 398 |
+
assert len(result) == len(raw_prompts) * n
|
| 399 |
+
|
| 400 |
+
# Check turns
|
| 401 |
+
num_turns = result.non_tensor_batch["__num_turns__"]
|
| 402 |
+
print(f"num_turns: {num_turns}")
|
| 403 |
+
for i in range(len(num_turns)):
|
| 404 |
+
if i // n == 0:
|
| 405 |
+
# [user, assistant, user]
|
| 406 |
+
assert num_turns[i] == 3
|
| 407 |
+
else:
|
| 408 |
+
# [user, assistant, tool, assistant, user]
|
| 409 |
+
assert num_turns[i] == 5
|
| 410 |
+
|
| 411 |
+
# Check response_mask
|
| 412 |
+
tokenizer = hf_tokenizer(init_config.actor_rollout_ref.model.path)
|
| 413 |
+
responses = result.batch["responses"]
|
| 414 |
+
response_mask = result.batch["response_mask"]
|
| 415 |
+
attention_mask = result.batch["attention_mask"]
|
| 416 |
+
assert responses.size() == response_mask.size(), f"{responses.size()} != {response_mask.size()}"
|
| 417 |
+
response_length = response_mask.size(1)
|
| 418 |
+
|
| 419 |
+
for i in range(len(responses)):
|
| 420 |
+
# response with tool response
|
| 421 |
+
valid_tokens = responses[i][attention_mask[i][-response_length:].bool()]
|
| 422 |
+
response_with_obs = tokenizer.decode(valid_tokens)
|
| 423 |
+
|
| 424 |
+
# response without tool response
|
| 425 |
+
valid_tokens = responses[i][response_mask[i].bool()]
|
| 426 |
+
response_without_obs = tokenizer.decode(valid_tokens)
|
| 427 |
+
|
| 428 |
+
assert "\udb82\udc89" not in response_without_obs, f"found \udb82\udc89 in response: {response_without_obs}"
|
| 429 |
+
assert "\udb82\udc8a" not in response_without_obs, f"found \udb82\udc8a in response: {response_without_obs}"
|
| 430 |
+
print("=========================")
|
| 431 |
+
print(response_with_obs)
|
| 432 |
+
print("---")
|
| 433 |
+
print(response_without_obs)
|
| 434 |
+
|
| 435 |
+
print("Test passed!")
|
| 436 |
+
ray.shutdown()
|
| 437 |
+
|
| 438 |
+
|
| 439 |
+
@pytest.mark.asyncio
|
| 440 |
+
async def test_get_trajectory_info():
|
| 441 |
+
"""Tests the get_trajectory_info method."""
|
| 442 |
+
# Initialize the class to set up class-level attributes
|
| 443 |
+
step = 10
|
| 444 |
+
index = [1, 1, 3, 3]
|
| 445 |
+
expected_info = [
|
| 446 |
+
{"step": step, "sample_index": 1, "rollout_n": 0, "validate": False},
|
| 447 |
+
{"step": step, "sample_index": 1, "rollout_n": 1, "validate": False},
|
| 448 |
+
{"step": step, "sample_index": 3, "rollout_n": 0, "validate": False},
|
| 449 |
+
{"step": step, "sample_index": 3, "rollout_n": 1, "validate": False},
|
| 450 |
+
]
|
| 451 |
+
|
| 452 |
+
trajectory_info = await get_trajectory_info(step, index, validate=False)
|
| 453 |
+
|
| 454 |
+
assert trajectory_info == expected_info
|
code/RL_model/verl/verl_train/tests/experimental/agent_loop/test_gpt_oss_tool_parser.py
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
import pytest
|
| 15 |
+
from transformers import AutoTokenizer
|
| 16 |
+
|
| 17 |
+
from verl.experimental.agent_loop.tool_parser import GptOssToolParser
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
@pytest.mark.asyncio
|
| 21 |
+
@pytest.mark.skip(reason="local test only")
|
| 22 |
+
async def test_gpt_oss_tool_parser():
|
| 23 |
+
example_text = """
|
| 24 |
+
<|start|>assistant<|channel|>commentary to=functions.get_current_weather \
|
| 25 |
+
<|constrain|>json<|message|>{"location": "Tokyo"}<|call|>
|
| 26 |
+
<|start|>functions.get_current_weather to=assistant<|channel|>commentary<|message|>\
|
| 27 |
+
{ "temperature": 20, "sunny": true }<|end|>"""
|
| 28 |
+
tokenizer = AutoTokenizer.from_pretrained("openai/gpt-oss-20b")
|
| 29 |
+
response_ids = tokenizer.encode(example_text)
|
| 30 |
+
tool_parser = GptOssToolParser(tokenizer)
|
| 31 |
+
_, function_calls = await tool_parser.extract_tool_calls(response_ids)
|
| 32 |
+
assert len(function_calls) == 1
|
| 33 |
+
assert function_calls[0].name == "get_current_weather"
|
| 34 |
+
assert function_calls[0].arguments == '{"location": "Tokyo"}'
|
code/RL_model/verl/verl_train/tests/experimental/agent_loop/test_multi_modal.py
ADDED
|
@@ -0,0 +1,570 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
import json
|
| 15 |
+
import os
|
| 16 |
+
from typing import Any
|
| 17 |
+
|
| 18 |
+
import numpy as np
|
| 19 |
+
import pytest
|
| 20 |
+
import ray
|
| 21 |
+
from omegaconf import DictConfig
|
| 22 |
+
from PIL import Image
|
| 23 |
+
from transformers.utils import get_json_schema
|
| 24 |
+
|
| 25 |
+
from verl.experimental.agent_loop import AgentLoopManager
|
| 26 |
+
from verl.protocol import DataProto
|
| 27 |
+
from verl.tools.base_tool import BaseTool, OpenAIFunctionToolSchema
|
| 28 |
+
from verl.tools.schemas import ToolResponse
|
| 29 |
+
from verl.utils import hf_tokenizer
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def parse_multi_modal_type(messages: list[dict]) -> str:
|
| 33 |
+
message = messages[-1]
|
| 34 |
+
if isinstance(message["content"], str):
|
| 35 |
+
return "text"
|
| 36 |
+
|
| 37 |
+
for content in message["content"]:
|
| 38 |
+
if content["type"] == "image":
|
| 39 |
+
return "image"
|
| 40 |
+
elif content["type"] == "video":
|
| 41 |
+
return "video"
|
| 42 |
+
|
| 43 |
+
return "text"
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
@pytest.fixture
|
| 47 |
+
def init_config() -> DictConfig:
|
| 48 |
+
from hydra import compose, initialize_config_dir
|
| 49 |
+
|
| 50 |
+
with initialize_config_dir(config_dir=os.path.abspath("verl/trainer/config")):
|
| 51 |
+
config = compose(
|
| 52 |
+
config_name="ppo_trainer",
|
| 53 |
+
overrides=[
|
| 54 |
+
"actor_rollout_ref.actor.use_dynamic_bsz=true",
|
| 55 |
+
# test sleep/wake_up with fsdp offload
|
| 56 |
+
"actor_rollout_ref.actor.fsdp_config.param_offload=True",
|
| 57 |
+
"actor_rollout_ref.actor.fsdp_config.optimizer_offload=True",
|
| 58 |
+
],
|
| 59 |
+
)
|
| 60 |
+
|
| 61 |
+
model_path = os.path.expanduser("~/models/Qwen/Qwen2.5-VL-3B-Instruct")
|
| 62 |
+
config.actor_rollout_ref.model.path = model_path
|
| 63 |
+
config.actor_rollout_ref.rollout.name = os.environ["ROLLOUT_NAME"]
|
| 64 |
+
config.actor_rollout_ref.rollout.mode = "async"
|
| 65 |
+
config.actor_rollout_ref.rollout.enforce_eager = True
|
| 66 |
+
config.actor_rollout_ref.rollout.prompt_length = 10240
|
| 67 |
+
config.actor_rollout_ref.rollout.response_length = 4096
|
| 68 |
+
config.actor_rollout_ref.rollout.n = 4
|
| 69 |
+
config.actor_rollout_ref.rollout.agent.num_workers = 2
|
| 70 |
+
config.actor_rollout_ref.rollout.skip_tokenizer_init = True
|
| 71 |
+
|
| 72 |
+
return config
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
class ImageGeneratorTool(BaseTool):
|
| 76 |
+
def generate_image(self, description: str, size: str = "256x256"):
|
| 77 |
+
"""Generate a simple image based on description.
|
| 78 |
+
|
| 79 |
+
Args:
|
| 80 |
+
description: The description of the image to generate.
|
| 81 |
+
size: The size of the image. Defaults to "256x256". (choices: ["256x256", "512x512"])
|
| 82 |
+
|
| 83 |
+
Returns:
|
| 84 |
+
A generated image
|
| 85 |
+
"""
|
| 86 |
+
print(f"[DEBUG] generate_image: {description}, {size}")
|
| 87 |
+
# Create a simple colored image for testing
|
| 88 |
+
width, height = map(int, size.split("x"))
|
| 89 |
+
|
| 90 |
+
# Create different colors based on description
|
| 91 |
+
if "red" in description.lower():
|
| 92 |
+
color = (255, 0, 0)
|
| 93 |
+
elif "blue" in description.lower():
|
| 94 |
+
color = (0, 0, 255)
|
| 95 |
+
elif "green" in description.lower():
|
| 96 |
+
color = (0, 255, 0)
|
| 97 |
+
else:
|
| 98 |
+
color = (128, 128, 128) # gray
|
| 99 |
+
|
| 100 |
+
# Create image
|
| 101 |
+
image = Image.new("RGB", (width, height), color)
|
| 102 |
+
|
| 103 |
+
# Add some pattern to make it more interesting
|
| 104 |
+
for i in range(0, width, 50):
|
| 105 |
+
for j in range(0, height, 50):
|
| 106 |
+
# Add white squares in a grid pattern
|
| 107 |
+
for x in range(i, min(i + 20, width)):
|
| 108 |
+
for y in range(j, min(j + 20, height)):
|
| 109 |
+
image.putpixel((x, y), (255, 255, 255))
|
| 110 |
+
|
| 111 |
+
return image
|
| 112 |
+
|
| 113 |
+
def get_openai_tool_schema(self) -> OpenAIFunctionToolSchema:
|
| 114 |
+
schema = get_json_schema(self.generate_image)
|
| 115 |
+
return OpenAIFunctionToolSchema(**schema)
|
| 116 |
+
|
| 117 |
+
async def execute(self, instance_id: str, parameters: dict[str, Any], **kwargs) -> tuple[ToolResponse, float, dict]:
|
| 118 |
+
try:
|
| 119 |
+
image = self.generate_image(**parameters)
|
| 120 |
+
# Return the PIL Image directly - the framework should handle the conversion
|
| 121 |
+
return ToolResponse(image=[image]), 0, {}
|
| 122 |
+
except Exception as e:
|
| 123 |
+
return ToolResponse(text=str(e)), 0, {}
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
@pytest.mark.flaky(reruns=3)
|
| 127 |
+
def test_multimodal_tool_agent(init_config):
|
| 128 |
+
"""Test agent loop with multimodal tool that returns images using Qwen VL model."""
|
| 129 |
+
ray.shutdown()
|
| 130 |
+
ray.init(
|
| 131 |
+
runtime_env={
|
| 132 |
+
"env_vars": {
|
| 133 |
+
"TOKENIZERS_PARALLELISM": "true",
|
| 134 |
+
"NCCL_DEBUG": "WARN",
|
| 135 |
+
"VLLM_LOGGING_LEVEL": "INFO",
|
| 136 |
+
"VLLM_USE_V1": "1",
|
| 137 |
+
}
|
| 138 |
+
},
|
| 139 |
+
ignore_reinit_error=True,
|
| 140 |
+
)
|
| 141 |
+
|
| 142 |
+
# Add custom chat template to enable tool calling support (same as recipe/deepeyes)
|
| 143 |
+
template_path = os.path.join(os.path.dirname(__file__), "qwen_vl_tool_chat_template.jinja2")
|
| 144 |
+
with open(template_path, encoding="utf-8") as f:
|
| 145 |
+
custom_chat_template = f.read()
|
| 146 |
+
|
| 147 |
+
init_config.actor_rollout_ref.model.custom_chat_template = custom_chat_template
|
| 148 |
+
|
| 149 |
+
# =========================== 1. Init rollout manager with image tool ===========================
|
| 150 |
+
tool_config = {
|
| 151 |
+
"tools": [
|
| 152 |
+
{
|
| 153 |
+
"class_name": "tests.experimental.agent_loop.test_multi_modal.ImageGeneratorTool",
|
| 154 |
+
"config": {"type": "native"},
|
| 155 |
+
},
|
| 156 |
+
]
|
| 157 |
+
}
|
| 158 |
+
tool_config_path = "/tmp/multimodal_tool_config.json"
|
| 159 |
+
with open(tool_config_path, "w") as f:
|
| 160 |
+
json.dump(tool_config, f)
|
| 161 |
+
|
| 162 |
+
n = 2
|
| 163 |
+
init_config.actor_rollout_ref.rollout.n = n
|
| 164 |
+
init_config.actor_rollout_ref.rollout.multi_turn.tool_config_path = tool_config_path
|
| 165 |
+
init_config.actor_rollout_ref.rollout.multi_turn.max_parallel_calls = 1
|
| 166 |
+
init_config.actor_rollout_ref.rollout.multi_turn.max_user_turns = 1
|
| 167 |
+
agent_loop_manager = AgentLoopManager(init_config)
|
| 168 |
+
|
| 169 |
+
# =========================== 2. Generate sequences with multimodal prompts ===========================
|
| 170 |
+
raw_prompts = [
|
| 171 |
+
[
|
| 172 |
+
{"role": "user", "content": "How are you?"},
|
| 173 |
+
],
|
| 174 |
+
[
|
| 175 |
+
{
|
| 176 |
+
"role": "user",
|
| 177 |
+
"content": [
|
| 178 |
+
{
|
| 179 |
+
"type": "video",
|
| 180 |
+
"video": os.path.expanduser("~/models/hf_data/test-videos/space_woaudio.mp4"),
|
| 181 |
+
"min_pixels": 4 * 32 * 32,
|
| 182 |
+
"max_pixels": 256 * 32 * 32,
|
| 183 |
+
"total_pixels": 4096 * 32 * 32,
|
| 184 |
+
},
|
| 185 |
+
{
|
| 186 |
+
"type": "text",
|
| 187 |
+
"text": "Describe this video. Then you must call the "
|
| 188 |
+
"image generator tool to generate a green image for me.",
|
| 189 |
+
},
|
| 190 |
+
],
|
| 191 |
+
},
|
| 192 |
+
],
|
| 193 |
+
[
|
| 194 |
+
{"role": "user", "content": "Please generate a red image for me."},
|
| 195 |
+
],
|
| 196 |
+
[
|
| 197 |
+
{"role": "user", "content": "Can you create a blue picture with size 512x512?"},
|
| 198 |
+
],
|
| 199 |
+
[
|
| 200 |
+
{
|
| 201 |
+
"role": "system",
|
| 202 |
+
"content": (
|
| 203 |
+
"You are Qwen VL, created by Alibaba Cloud. You are a helpful "
|
| 204 |
+
"assistant that can generate and analyze images."
|
| 205 |
+
),
|
| 206 |
+
},
|
| 207 |
+
{"role": "user", "content": "Generate a green landscape image and describe what you see in it."},
|
| 208 |
+
],
|
| 209 |
+
]
|
| 210 |
+
|
| 211 |
+
batch = DataProto(
|
| 212 |
+
non_tensor_batch={
|
| 213 |
+
"raw_prompt": np.array([np.array(prompt) for prompt in raw_prompts], dtype=object),
|
| 214 |
+
"agent_name": np.array(["tool_agent"] * len(raw_prompts)),
|
| 215 |
+
"data_source": np.array(["openai/gsm8k"] * len(raw_prompts)),
|
| 216 |
+
"reward_model": np.array([{"style": "rule", "ground_truth": "1.0"}] * len(raw_prompts)),
|
| 217 |
+
},
|
| 218 |
+
)
|
| 219 |
+
batch = batch.repeat(n)
|
| 220 |
+
result = agent_loop_manager.generate_sequences(prompts=batch)
|
| 221 |
+
assert len(result) == len(raw_prompts) * n
|
| 222 |
+
|
| 223 |
+
# Check turns
|
| 224 |
+
num_turns = result.non_tensor_batch["__num_turns__"]
|
| 225 |
+
multi_modal_inputs = result.non_tensor_batch["multi_modal_inputs"]
|
| 226 |
+
print(f"num_turns: {num_turns}")
|
| 227 |
+
for i in range(len(num_turns)):
|
| 228 |
+
multi_modal_type = parse_multi_modal_type(raw_prompts[i // n])
|
| 229 |
+
if multi_modal_type == "video":
|
| 230 |
+
assert "pixel_values_videos" in multi_modal_inputs[i], f"Sample {i} should have pixel_values_videos"
|
| 231 |
+
assert "video_grid_thw" in multi_modal_inputs[i], f"Sample {i} should have video_grid_thw"
|
| 232 |
+
|
| 233 |
+
if i // n <= 1:
|
| 234 |
+
# TODO: prompt with video not generate tool call as expected
|
| 235 |
+
# First prompt: "How are you?" - should have 2 turns [user, assistant]
|
| 236 |
+
assert num_turns[i] == 2, f"Expected 2 turns but got {num_turns[i]} for sample {i}"
|
| 237 |
+
else:
|
| 238 |
+
# Tool-calling prompts should have 4 turns [user, assistant, tool, assistant]
|
| 239 |
+
assert num_turns[i] == 4, f"Expected 4 turns but got {num_turns[i]} for sample {i}"
|
| 240 |
+
assert "pixel_values" in multi_modal_inputs[i], f"Sample {i} should have pixel_values"
|
| 241 |
+
assert "image_grid_thw" in multi_modal_inputs[i], f"Sample {i} should have image_grid_thw"
|
| 242 |
+
|
| 243 |
+
# Check that images were properly returned in the tool responses
|
| 244 |
+
tokenizer = hf_tokenizer(init_config.actor_rollout_ref.model.path)
|
| 245 |
+
responses = result.batch["responses"]
|
| 246 |
+
response_mask = result.batch["response_mask"]
|
| 247 |
+
attention_mask = result.batch["attention_mask"]
|
| 248 |
+
assert responses.size() == response_mask.size(), f"{responses.size()} != {response_mask.size()}"
|
| 249 |
+
response_length = response_mask.size(1)
|
| 250 |
+
|
| 251 |
+
image_found_count = 0
|
| 252 |
+
for i in range(len(responses)):
|
| 253 |
+
# response with tool response (including images)
|
| 254 |
+
valid_tokens = responses[i][attention_mask[i][-response_length:].bool()]
|
| 255 |
+
response_with_obs = tokenizer.decode(valid_tokens)
|
| 256 |
+
|
| 257 |
+
# response without tool response
|
| 258 |
+
valid_tokens = responses[i][response_mask[i].bool()]
|
| 259 |
+
response_without_obs = tokenizer.decode(valid_tokens)
|
| 260 |
+
|
| 261 |
+
# Check that tool responses were properly masked out from training
|
| 262 |
+
assert "<tool_response>" not in response_without_obs, (
|
| 263 |
+
f"found <tool_response> in response: {response_without_obs}"
|
| 264 |
+
)
|
| 265 |
+
assert "</tool_response>" not in response_without_obs, (
|
| 266 |
+
f"found </tool_response> in response: {response_without_obs}"
|
| 267 |
+
)
|
| 268 |
+
|
| 269 |
+
# Check that images were included in the full response
|
| 270 |
+
if "<image>" in response_with_obs or "image" in response_with_obs.lower():
|
| 271 |
+
image_found_count += 1
|
| 272 |
+
|
| 273 |
+
print("=========================")
|
| 274 |
+
print("Response with tool observations:")
|
| 275 |
+
print(response_with_obs)
|
| 276 |
+
print("---")
|
| 277 |
+
print("Response without tool observations:")
|
| 278 |
+
print(response_without_obs)
|
| 279 |
+
|
| 280 |
+
# Verify that tool-calling responses contained image-related content
|
| 281 |
+
print(f"Found {image_found_count} responses with image content out of {len(responses)}")
|
| 282 |
+
# We should have at least some image content from the tool-calling prompts
|
| 283 |
+
# Note: First prompt might not use tools, so we don't expect 100% image content
|
| 284 |
+
expected_tool_calls = sum(1 for i in range(len(num_turns)) if num_turns[i] == 4)
|
| 285 |
+
assert image_found_count >= 0, (
|
| 286 |
+
f"No image-related content found, but expected at least some from {expected_tool_calls} tool calls"
|
| 287 |
+
)
|
| 288 |
+
|
| 289 |
+
print("Multimodal tool test passed!")
|
| 290 |
+
ray.shutdown()
|
| 291 |
+
|
| 292 |
+
|
| 293 |
+
def test_multimodal_single_turn_agent(init_config):
|
| 294 |
+
"""Test single turn agent loop with multimodal inputs using Qwen VL model."""
|
| 295 |
+
ray.init(
|
| 296 |
+
runtime_env={
|
| 297 |
+
"env_vars": {
|
| 298 |
+
"TOKENIZERS_PARALLELISM": "true",
|
| 299 |
+
"NCCL_DEBUG": "WARN",
|
| 300 |
+
"VLLM_LOGGING_LEVEL": "INFO",
|
| 301 |
+
"VLLM_USE_V1": "1",
|
| 302 |
+
}
|
| 303 |
+
},
|
| 304 |
+
ignore_reinit_error=True,
|
| 305 |
+
)
|
| 306 |
+
|
| 307 |
+
# =========================== 1. Init rollout manager ===========================
|
| 308 |
+
n = 2
|
| 309 |
+
init_config.actor_rollout_ref.rollout.n = n
|
| 310 |
+
init_config.actor_rollout_ref.rollout.multi_turn.max_parallel_calls = 1
|
| 311 |
+
init_config.actor_rollout_ref.rollout.multi_turn.max_user_turns = 1
|
| 312 |
+
agent_loop_manager = AgentLoopManager(init_config)
|
| 313 |
+
|
| 314 |
+
# =========================== 2. Generate sequences with multimodal prompts ===========================
|
| 315 |
+
# Create a simple test image
|
| 316 |
+
test_image = Image.new("RGB", (256, 256), (100, 150, 200))
|
| 317 |
+
test_image2 = Image.new("RGB", (512, 512), (100, 150, 200))
|
| 318 |
+
|
| 319 |
+
raw_prompts = [
|
| 320 |
+
# text
|
| 321 |
+
[
|
| 322 |
+
{"role": "user", "content": "Hello, how are you?"},
|
| 323 |
+
],
|
| 324 |
+
# image
|
| 325 |
+
[
|
| 326 |
+
{
|
| 327 |
+
"role": "user",
|
| 328 |
+
"content": [
|
| 329 |
+
{"type": "image", "image": test_image},
|
| 330 |
+
{"type": "text", "text": "What color is this image?"},
|
| 331 |
+
],
|
| 332 |
+
},
|
| 333 |
+
],
|
| 334 |
+
# system + image
|
| 335 |
+
[
|
| 336 |
+
{
|
| 337 |
+
"role": "system",
|
| 338 |
+
"content": "You are Qwen VL, created by Alibaba Cloud. You are a helpful assistant.",
|
| 339 |
+
},
|
| 340 |
+
{
|
| 341 |
+
"role": "user",
|
| 342 |
+
"content": [
|
| 343 |
+
{"type": "image", "image": test_image2},
|
| 344 |
+
{"type": "text", "text": "Describe this image in detail."},
|
| 345 |
+
],
|
| 346 |
+
},
|
| 347 |
+
],
|
| 348 |
+
# video
|
| 349 |
+
[
|
| 350 |
+
{
|
| 351 |
+
"role": "user",
|
| 352 |
+
"content": [
|
| 353 |
+
{
|
| 354 |
+
"type": "video",
|
| 355 |
+
"video": os.path.expanduser("~/models/hf_data/test-videos/space_woaudio.mp4"),
|
| 356 |
+
"min_pixels": 4 * 32 * 32,
|
| 357 |
+
"max_pixels": 256 * 32 * 32,
|
| 358 |
+
"total_pixels": 4096 * 32 * 32,
|
| 359 |
+
},
|
| 360 |
+
{"type": "text", "text": "Describe this video."},
|
| 361 |
+
],
|
| 362 |
+
},
|
| 363 |
+
],
|
| 364 |
+
]
|
| 365 |
+
|
| 366 |
+
batch = DataProto(
|
| 367 |
+
non_tensor_batch={
|
| 368 |
+
"raw_prompt": np.array([np.array(prompt) for prompt in raw_prompts], dtype=object),
|
| 369 |
+
"agent_name": np.array(["single_turn_agent"] * len(raw_prompts)),
|
| 370 |
+
"data_source": np.array(["openai/gsm8k"] * len(raw_prompts)),
|
| 371 |
+
"reward_model": np.array([{"style": "rule", "ground_truth": "1.0"}] * len(raw_prompts)),
|
| 372 |
+
},
|
| 373 |
+
)
|
| 374 |
+
|
| 375 |
+
batch = batch.repeat(n)
|
| 376 |
+
result = agent_loop_manager.generate_sequences(prompts=batch)
|
| 377 |
+
assert len(result) == len(raw_prompts) * n
|
| 378 |
+
|
| 379 |
+
# Check turns - all should be single turn (2: user + assistant)
|
| 380 |
+
num_turns = result.non_tensor_batch["__num_turns__"]
|
| 381 |
+
print(f"num_turns: {num_turns}")
|
| 382 |
+
for i in range(len(num_turns)):
|
| 383 |
+
assert num_turns[i] == 2, f"Expected 2 turns but got {num_turns[i]} for sample {i}"
|
| 384 |
+
|
| 385 |
+
# Verify responses
|
| 386 |
+
tokenizer = hf_tokenizer(init_config.actor_rollout_ref.model.path)
|
| 387 |
+
prompts = result.batch["prompts"]
|
| 388 |
+
responses = result.batch["responses"]
|
| 389 |
+
response_mask = result.batch["response_mask"]
|
| 390 |
+
input_ids = result.batch["input_ids"]
|
| 391 |
+
position_ids = result.batch["position_ids"]
|
| 392 |
+
multi_modal_inputs = result.non_tensor_batch["multi_modal_inputs"]
|
| 393 |
+
assert responses.size() == response_mask.size(), f"{responses.size()} != {response_mask.size()}"
|
| 394 |
+
assert position_ids.size() == (input_ids.size(0), 4, input_ids.size(1)) # (batch_size, 4, seq_len)
|
| 395 |
+
|
| 396 |
+
# Check for image pads in prompts
|
| 397 |
+
image_pad_count = 0
|
| 398 |
+
for i in range(len(prompts)):
|
| 399 |
+
prompt_ids = prompts[i][prompts[i] != tokenizer.pad_token_id].tolist()
|
| 400 |
+
prompt_text = tokenizer.decode(prompt_ids)
|
| 401 |
+
|
| 402 |
+
# Check if this sample should have image pads (samples with index 1 and 2 in each repeat have images)
|
| 403 |
+
sample_idx = i // n
|
| 404 |
+
has_image_pad = "<|image_pad|>" in prompt_text or "<|vision_start|>" in prompt_text
|
| 405 |
+
|
| 406 |
+
print("=========================")
|
| 407 |
+
print(f"Sample {i} (original prompt index: {sample_idx}):")
|
| 408 |
+
print(f"Prompt length: {len(prompt_ids)} tokens")
|
| 409 |
+
print(f"Has image_pad: {has_image_pad}")
|
| 410 |
+
|
| 411 |
+
# Check multi-modal type
|
| 412 |
+
multi_modal_type = parse_multi_modal_type(raw_prompts[sample_idx])
|
| 413 |
+
|
| 414 |
+
if multi_modal_type == "text":
|
| 415 |
+
assert len(multi_modal_inputs[i]) == 0, f"Sample {i} should not have multi-modal inputs"
|
| 416 |
+
elif multi_modal_type == "image":
|
| 417 |
+
assert "pixel_values" in multi_modal_inputs[i], f"Sample {i} should have pixel_values"
|
| 418 |
+
assert "image_grid_thw" in multi_modal_inputs[i], f"Sample {i} should have image_grid_thw"
|
| 419 |
+
else:
|
| 420 |
+
assert "pixel_values_videos" in multi_modal_inputs[i], f"Sample {i} should have pixel_values_videos"
|
| 421 |
+
assert "video_grid_thw" in multi_modal_inputs[i], f"Sample {i} should have video_grid_thw"
|
| 422 |
+
|
| 423 |
+
# Show first 200 chars of prompt
|
| 424 |
+
print(f"Prompt text (first 200 chars): {prompt_text[:200]}...")
|
| 425 |
+
|
| 426 |
+
for i in range(len(responses)):
|
| 427 |
+
valid_tokens = responses[i][response_mask[i].bool()]
|
| 428 |
+
response_text = tokenizer.decode(valid_tokens)
|
| 429 |
+
print(f"Sample {i} response: {response_text[:100]}...")
|
| 430 |
+
|
| 431 |
+
# Verify that we found image pads in multimodal samples
|
| 432 |
+
expected_multimodal_samples = 2 * n # 2 prompts with images, repeated n times
|
| 433 |
+
print(f"\nFound {image_pad_count} samples with image_pad out of {expected_multimodal_samples} expected")
|
| 434 |
+
|
| 435 |
+
print("Single turn multimodal test passed!")
|
| 436 |
+
ray.shutdown()
|
| 437 |
+
|
| 438 |
+
|
| 439 |
+
def test_multimodal_partial_single_turn_agent(init_config):
|
| 440 |
+
"""Test partial single turn agent loop with multimodal inputs using Qwen VL model."""
|
| 441 |
+
|
| 442 |
+
# TODO(baiyan):
|
| 443 |
+
# see verl/recipe/fully_async_policy/agent_loop/partial_single_turn_agent_loop.py for more details.
|
| 444 |
+
# if use_correct_processor=True, the test will pass but the async training will hang, so I disable this test
|
| 445 |
+
# for now
|
| 446 |
+
|
| 447 |
+
return
|
| 448 |
+
|
| 449 |
+
ray.init(
|
| 450 |
+
runtime_env={
|
| 451 |
+
"env_vars": {
|
| 452 |
+
"TOKENIZERS_PARALLELISM": "true",
|
| 453 |
+
"NCCL_DEBUG": "WARN",
|
| 454 |
+
"VLLM_LOGGING_LEVEL": "INFO",
|
| 455 |
+
"VLLM_USE_V1": "1",
|
| 456 |
+
}
|
| 457 |
+
},
|
| 458 |
+
ignore_reinit_error=True,
|
| 459 |
+
)
|
| 460 |
+
from verl.experimental.fully_async_policy.agent_loop import FullyAsyncAgentLoopManager
|
| 461 |
+
|
| 462 |
+
# =========================== 1. Init rollout manager ===========================
|
| 463 |
+
n = 2
|
| 464 |
+
init_config.actor_rollout_ref.rollout.n = n
|
| 465 |
+
init_config.actor_rollout_ref.rollout.multi_turn.max_parallel_calls = 1
|
| 466 |
+
init_config.actor_rollout_ref.rollout.multi_turn.max_user_turns = 1
|
| 467 |
+
import asyncio
|
| 468 |
+
|
| 469 |
+
loop = asyncio.new_event_loop()
|
| 470 |
+
asyncio.set_event_loop(loop)
|
| 471 |
+
agent_loop_manager = loop.run_until_complete(FullyAsyncAgentLoopManager.create(init_config))
|
| 472 |
+
|
| 473 |
+
# =========================== 2. Generate sequences with multimodal prompts ===========================
|
| 474 |
+
# Create a simple test image
|
| 475 |
+
test_image = Image.new("RGB", (256, 256), (200, 100, 50))
|
| 476 |
+
test_image2 = Image.new("RGB", (512, 512), (100, 150, 200))
|
| 477 |
+
|
| 478 |
+
raw_prompts = [
|
| 479 |
+
[
|
| 480 |
+
{"role": "user", "content": "What is the capital of France?"},
|
| 481 |
+
],
|
| 482 |
+
[
|
| 483 |
+
{
|
| 484 |
+
"role": "user",
|
| 485 |
+
"content": [
|
| 486 |
+
{"type": "image", "image": test_image},
|
| 487 |
+
{"type": "text", "text": "What do you see in this image?"},
|
| 488 |
+
],
|
| 489 |
+
},
|
| 490 |
+
],
|
| 491 |
+
[
|
| 492 |
+
{
|
| 493 |
+
"role": "system",
|
| 494 |
+
"content": "You are Qwen VL, a helpful multimodal assistant.",
|
| 495 |
+
},
|
| 496 |
+
{
|
| 497 |
+
"role": "user",
|
| 498 |
+
"content": [
|
| 499 |
+
{"type": "image", "image": test_image2},
|
| 500 |
+
{"type": "text", "text": "Analyze the colors in this image."},
|
| 501 |
+
],
|
| 502 |
+
},
|
| 503 |
+
],
|
| 504 |
+
]
|
| 505 |
+
|
| 506 |
+
batch = DataProto(
|
| 507 |
+
non_tensor_batch={
|
| 508 |
+
"raw_prompt": np.array([np.array(prompt) for prompt in raw_prompts], dtype=object),
|
| 509 |
+
"agent_name": np.array(["partial_single_turn_agent"] * len(raw_prompts)),
|
| 510 |
+
"data_source": np.array(["openai/gsm8k"] * len(raw_prompts)),
|
| 511 |
+
"reward_model": np.array([{"style": "rule", "ground_truth": "1.0"}] * len(raw_prompts)),
|
| 512 |
+
},
|
| 513 |
+
)
|
| 514 |
+
|
| 515 |
+
batch = batch.repeat(n)
|
| 516 |
+
result = agent_loop_manager.generate_sequences(prompts=batch)
|
| 517 |
+
assert len(result) == len(raw_prompts) * n
|
| 518 |
+
|
| 519 |
+
# Check turns - all should be single turn (2: user + assistant)
|
| 520 |
+
num_turns = result.non_tensor_batch["__num_turns__"]
|
| 521 |
+
print(f"num_turns: {num_turns}")
|
| 522 |
+
for i in range(len(num_turns)):
|
| 523 |
+
assert num_turns[i] == 2, f"Expected 2 turns but got {num_turns[i]} for sample {i}"
|
| 524 |
+
|
| 525 |
+
# Verify responses
|
| 526 |
+
tokenizer = hf_tokenizer(init_config.actor_rollout_ref.model.path)
|
| 527 |
+
prompts = result.batch["prompts"]
|
| 528 |
+
responses = result.batch["responses"]
|
| 529 |
+
response_mask = result.batch["response_mask"]
|
| 530 |
+
assert responses.size() == response_mask.size(), f"{responses.size()} != {response_mask.size()}"
|
| 531 |
+
|
| 532 |
+
# Check for image pads in prompts
|
| 533 |
+
image_pad_count = 0
|
| 534 |
+
for i in range(len(prompts)):
|
| 535 |
+
prompt_ids = prompts[i][prompts[i] != tokenizer.pad_token_id].tolist()
|
| 536 |
+
prompt_text = tokenizer.decode(prompt_ids)
|
| 537 |
+
|
| 538 |
+
# Check if this sample should have image pads (samples with index 1 and 2 in each repeat have images)
|
| 539 |
+
sample_idx = i // n
|
| 540 |
+
has_image_pad = "<|image_pad|>" in prompt_text or "<|vision_start|>" in prompt_text
|
| 541 |
+
|
| 542 |
+
print("=========================")
|
| 543 |
+
print(f"Sample {i} (original prompt index: {sample_idx}):")
|
| 544 |
+
print(f"Prompt length: {len(prompt_ids)} tokens")
|
| 545 |
+
print(f"Has image_pad: {has_image_pad}")
|
| 546 |
+
|
| 547 |
+
if sample_idx != 0: # Samples 1 and 2 should have images
|
| 548 |
+
if has_image_pad:
|
| 549 |
+
image_pad_count += 1
|
| 550 |
+
# Count the number of image_pad tokens
|
| 551 |
+
num_image_pads = prompt_text.count("<|image_pad|>")
|
| 552 |
+
print(f"Number of <|image_pad|> tokens: {num_image_pads}")
|
| 553 |
+
else:
|
| 554 |
+
print("WARNING: Expected image_pad but not found!")
|
| 555 |
+
|
| 556 |
+
# Show first 200 chars of prompt
|
| 557 |
+
print(f"Prompt text (first 200 chars): {prompt_text[:200]}...")
|
| 558 |
+
|
| 559 |
+
for i in range(len(responses)):
|
| 560 |
+
valid_tokens = responses[i][response_mask[i].bool()]
|
| 561 |
+
response_text = tokenizer.decode(valid_tokens)
|
| 562 |
+
print(f"Sample {i} response: {response_text[:100]}...")
|
| 563 |
+
|
| 564 |
+
# Verify that we found image pads in multimodal samples
|
| 565 |
+
expected_multimodal_samples = 2 * n # 2 prompts with images, repeated n times
|
| 566 |
+
print(f"\nFound {image_pad_count} samples with image_pad out of {expected_multimodal_samples} expected")
|
| 567 |
+
assert image_pad_count > 0, "No image_pad tokens found in multimodal samples!"
|
| 568 |
+
|
| 569 |
+
print("Partial single turn multimodal test passed!")
|
| 570 |
+
ray.shutdown()
|
code/RL_model/verl/verl_train/tests/experimental/agent_loop/test_standalone_rollout.py
ADDED
|
@@ -0,0 +1,157 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
import asyncio
|
| 15 |
+
import os
|
| 16 |
+
|
| 17 |
+
import pytest
|
| 18 |
+
import ray
|
| 19 |
+
from omegaconf import DictConfig
|
| 20 |
+
from openai import AsyncOpenAI, OpenAI
|
| 21 |
+
|
| 22 |
+
from tests.experimental.agent_loop.agent_utils import init_agent_loop_manager
|
| 23 |
+
from verl.checkpoint_engine import CheckpointEngineManager
|
| 24 |
+
from verl.workers.rollout.replica import get_rollout_replica_class
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
@pytest.fixture
|
| 28 |
+
def init_config() -> DictConfig:
|
| 29 |
+
from hydra import compose, initialize_config_dir
|
| 30 |
+
|
| 31 |
+
with initialize_config_dir(config_dir=os.path.abspath("verl/trainer/config")):
|
| 32 |
+
config = compose(config_name="ppo_trainer")
|
| 33 |
+
|
| 34 |
+
config.trainer.n_gpus_per_node = 4
|
| 35 |
+
config.trainer.nnodes = 2
|
| 36 |
+
config.actor_rollout_ref.actor.use_dynamic_bsz = True
|
| 37 |
+
config.actor_rollout_ref.model.path = os.path.expanduser("~/models/Qwen/Qwen2.5-1.5B-Instruct")
|
| 38 |
+
config.actor_rollout_ref.rollout.name = os.environ["ROLLOUT_NAME"]
|
| 39 |
+
config.actor_rollout_ref.rollout.mode = "async"
|
| 40 |
+
config.actor_rollout_ref.rollout.skip_tokenizer_init = False
|
| 41 |
+
|
| 42 |
+
return config
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
@pytest.mark.asyncio
|
| 46 |
+
@pytest.mark.parametrize("tp_size", [2, 4])
|
| 47 |
+
async def test_standalone_rollout(init_config, tp_size):
|
| 48 |
+
"""Test standalone rollout single node and multi nodes."""
|
| 49 |
+
ray.init(
|
| 50 |
+
runtime_env={
|
| 51 |
+
"env_vars": {
|
| 52 |
+
"TOKENIZERS_PARALLELISM": "true",
|
| 53 |
+
"NCCL_DEBUG": "WARN",
|
| 54 |
+
"VLLM_LOGGING_LEVEL": "INFO",
|
| 55 |
+
"VLLM_USE_V1": "1",
|
| 56 |
+
"NCCL_P2P_DISABLE": "1", # disable p2p in L20
|
| 57 |
+
}
|
| 58 |
+
}
|
| 59 |
+
)
|
| 60 |
+
|
| 61 |
+
init_config.actor_rollout_ref.rollout.tensor_model_parallel_size = tp_size
|
| 62 |
+
num_replicas = (init_config.trainer.n_gpus_per_node * init_config.trainer.nnodes) // tp_size
|
| 63 |
+
rollout_config = init_config.actor_rollout_ref.rollout
|
| 64 |
+
model_config = init_config.actor_rollout_ref.model
|
| 65 |
+
|
| 66 |
+
# create standalone rollout server
|
| 67 |
+
rollout_server_class = get_rollout_replica_class(init_config.actor_rollout_ref.rollout.name)
|
| 68 |
+
rollout_servers = [
|
| 69 |
+
rollout_server_class(
|
| 70 |
+
replica_rank=replica_rank, config=rollout_config, model_config=model_config, gpus_per_node=2
|
| 71 |
+
)
|
| 72 |
+
for replica_rank in range(num_replicas)
|
| 73 |
+
]
|
| 74 |
+
await asyncio.gather(*[server.init_standalone() for server in rollout_servers])
|
| 75 |
+
|
| 76 |
+
server_handles = [server._server_handle for server in rollout_servers]
|
| 77 |
+
server_addresses = [server._server_address for server in rollout_servers]
|
| 78 |
+
assert len(server_handles) == num_replicas
|
| 79 |
+
assert len(server_addresses) == num_replicas
|
| 80 |
+
|
| 81 |
+
os.environ.pop("HTTPS_PROXY", None)
|
| 82 |
+
os.environ.pop("HTTP_PROXY", None)
|
| 83 |
+
os.environ.pop("NO_PROXY", None)
|
| 84 |
+
|
| 85 |
+
client = AsyncOpenAI(
|
| 86 |
+
api_key="123-abc",
|
| 87 |
+
base_url=f"http://{server_addresses[0]}/v1",
|
| 88 |
+
)
|
| 89 |
+
|
| 90 |
+
completion = await client.chat.completions.create(
|
| 91 |
+
model=init_config.actor_rollout_ref.model.path,
|
| 92 |
+
messages=[{"role": "user", "content": "What can you do?"}],
|
| 93 |
+
)
|
| 94 |
+
print(completion.choices[0].message.content)
|
| 95 |
+
|
| 96 |
+
ray.shutdown()
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
@pytest.mark.skip(reason="local test only")
|
| 100 |
+
def test_hybrid_rollout_with_ep(init_config):
|
| 101 |
+
"""Test hybrid rollout with expert parallelism, DP=2, TP=4, EP=8."""
|
| 102 |
+
ray.init(
|
| 103 |
+
runtime_env={
|
| 104 |
+
"env_vars": {
|
| 105 |
+
"TOKENIZERS_PARALLELISM": "true",
|
| 106 |
+
"NCCL_DEBUG": "WARN",
|
| 107 |
+
"VLLM_LOGGING_LEVEL": "INFO",
|
| 108 |
+
"VLLM_USE_V1": "1",
|
| 109 |
+
}
|
| 110 |
+
}
|
| 111 |
+
)
|
| 112 |
+
|
| 113 |
+
model_path = os.path.expanduser("~/models/Qwen/Qwen3-30B-A3B-Instruct-2507")
|
| 114 |
+
init_config.actor_rollout_ref.model.path = model_path
|
| 115 |
+
|
| 116 |
+
# parallelism config
|
| 117 |
+
init_config.actor_rollout_ref.rollout.tensor_model_parallel_size = 2
|
| 118 |
+
init_config.actor_rollout_ref.rollout.data_parallel_size = 4
|
| 119 |
+
init_config.actor_rollout_ref.rollout.expert_parallel_size = 8
|
| 120 |
+
|
| 121 |
+
# 1. init hybrid worker: FSDP+rollout
|
| 122 |
+
# - build FSDP model and optimizer
|
| 123 |
+
# - offload FSDP model and optimizer, build rollout
|
| 124 |
+
# - sleep rollout and load FSDP model and optimizer
|
| 125 |
+
agent_loop_manager = init_agent_loop_manager(init_config)
|
| 126 |
+
checkpoint_manager = CheckpointEngineManager(
|
| 127 |
+
backend=init_config.actor_rollout_ref.rollout.checkpoint_engine.backend,
|
| 128 |
+
trainer=agent_loop_manager.worker_group,
|
| 129 |
+
replicas=agent_loop_manager.rollout_replicas,
|
| 130 |
+
)
|
| 131 |
+
checkpoint_manager.sleep_replicas()
|
| 132 |
+
checkpoint_manager.update_weights()
|
| 133 |
+
|
| 134 |
+
# 3. test async openai call
|
| 135 |
+
server_address = agent_loop_manager.server_addresses[0]
|
| 136 |
+
client = OpenAI(
|
| 137 |
+
api_key="123-abc",
|
| 138 |
+
base_url=f"http://{server_address}/v1",
|
| 139 |
+
)
|
| 140 |
+
|
| 141 |
+
smapling_params = {
|
| 142 |
+
"temperature": 1.0,
|
| 143 |
+
"top_p": 1.0,
|
| 144 |
+
"max_tokens": 512,
|
| 145 |
+
}
|
| 146 |
+
|
| 147 |
+
response = client.chat.completions.create(
|
| 148 |
+
model=model_path,
|
| 149 |
+
messages=[{"role": "user", "content": "What can you do?"}],
|
| 150 |
+
**smapling_params,
|
| 151 |
+
)
|
| 152 |
+
|
| 153 |
+
completion = response.choices[0].message.content
|
| 154 |
+
print(f"response: {completion}")
|
| 155 |
+
|
| 156 |
+
print("Test passed!")
|
| 157 |
+
ray.shutdown()
|
code/RL_model/verl/verl_train/tests/experimental/reward_loop/test_agent_loop_reward_manager.py
ADDED
|
@@ -0,0 +1,111 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
import os
|
| 15 |
+
|
| 16 |
+
import ray
|
| 17 |
+
from hydra import compose, initialize_config_dir
|
| 18 |
+
from torchdata.stateful_dataloader import StatefulDataLoader
|
| 19 |
+
from transformers import AutoTokenizer
|
| 20 |
+
|
| 21 |
+
from verl.experimental.agent_loop import AgentLoopManager
|
| 22 |
+
from verl.protocol import DataProto
|
| 23 |
+
from verl.trainer.main_ppo import create_rl_sampler
|
| 24 |
+
from verl.utils.dataset.rl_dataset import RLHFDataset, collate_fn
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def test_agent_loop_reward_manager():
|
| 28 |
+
ray.init(
|
| 29 |
+
runtime_env={
|
| 30 |
+
"env_vars": {
|
| 31 |
+
"TOKENIZERS_PARALLELISM": "true",
|
| 32 |
+
"NCCL_DEBUG": "WARN",
|
| 33 |
+
"VLLM_LOGGING_LEVEL": "INFO",
|
| 34 |
+
"VLLM_USE_V1": "1",
|
| 35 |
+
}
|
| 36 |
+
}
|
| 37 |
+
)
|
| 38 |
+
with initialize_config_dir(config_dir=os.path.abspath("verl/trainer/config")):
|
| 39 |
+
config = compose(config_name="ppo_trainer")
|
| 40 |
+
|
| 41 |
+
rollout_model_path = os.path.expanduser("~/models/Qwen/Qwen2.5-0.5B-Instruct")
|
| 42 |
+
reward_model_path = os.path.expanduser("~/models/Qwen/Qwen2.5-1.5B-Instruct")
|
| 43 |
+
|
| 44 |
+
# actor_rollout_ref config
|
| 45 |
+
config.data.return_raw_chat = True
|
| 46 |
+
config.data.max_prompt_length = 1024
|
| 47 |
+
config.data.max_response_length = 4096
|
| 48 |
+
config.actor_rollout_ref.model.path = rollout_model_path
|
| 49 |
+
config.actor_rollout_ref.actor.use_dynamic_bsz = True
|
| 50 |
+
config.actor_rollout_ref.rollout.name = os.getenv("ROLLOUT_NAME", "vllm")
|
| 51 |
+
config.actor_rollout_ref.rollout.mode = "async"
|
| 52 |
+
config.actor_rollout_ref.rollout.tensor_model_parallel_size = 2
|
| 53 |
+
config.actor_rollout_ref.rollout.gpu_memory_utilization = 0.9
|
| 54 |
+
config.actor_rollout_ref.rollout.enforce_eager = True
|
| 55 |
+
config.actor_rollout_ref.rollout.prompt_length = 1024
|
| 56 |
+
config.actor_rollout_ref.rollout.response_length = 4096
|
| 57 |
+
config.actor_rollout_ref.rollout.skip_tokenizer_init = True
|
| 58 |
+
config.trainer.n_gpus_per_node = 4
|
| 59 |
+
config.trainer.nnodes = 1
|
| 60 |
+
|
| 61 |
+
config.reward_model.reward_manager = "dapo"
|
| 62 |
+
config.reward_model.enable = True
|
| 63 |
+
config.reward_model.enable_resource_pool = True
|
| 64 |
+
config.reward_model.n_gpus_per_node = 4
|
| 65 |
+
config.reward_model.nnodes = 1
|
| 66 |
+
config.reward_model.model.path = reward_model_path
|
| 67 |
+
config.reward_model.rollout.name = os.getenv("ROLLOUT_NAME", "vllm")
|
| 68 |
+
config.reward_model.rollout.gpu_memory_utilization = 0.9
|
| 69 |
+
config.reward_model.rollout.tensor_model_parallel_size = 2
|
| 70 |
+
config.reward_model.rollout.skip_tokenizer_init = False
|
| 71 |
+
config.reward_model.rollout.prompt_length = 5120
|
| 72 |
+
config.reward_model.rollout.response_length = 4096
|
| 73 |
+
config.custom_reward_function.path = "tests/experimental/reward_loop/reward_fn.py"
|
| 74 |
+
config.custom_reward_function.name = "compute_score_gsm8k"
|
| 75 |
+
|
| 76 |
+
# 1. init reward model manager
|
| 77 |
+
agent_loop_manager = AgentLoopManager(config)
|
| 78 |
+
|
| 79 |
+
# 2. init test data
|
| 80 |
+
local_folder = os.path.expanduser("~/data/gsm8k/")
|
| 81 |
+
data_files = [os.path.join(local_folder, "train.parquet")]
|
| 82 |
+
tokenizer = AutoTokenizer.from_pretrained(rollout_model_path)
|
| 83 |
+
|
| 84 |
+
dataset = RLHFDataset(
|
| 85 |
+
data_files=data_files,
|
| 86 |
+
tokenizer=tokenizer,
|
| 87 |
+
config=config.data,
|
| 88 |
+
processor=None,
|
| 89 |
+
)
|
| 90 |
+
|
| 91 |
+
batch_size = 64
|
| 92 |
+
sampler = create_rl_sampler(config.data, dataset)
|
| 93 |
+
dataloader = StatefulDataLoader(
|
| 94 |
+
dataset=dataset,
|
| 95 |
+
batch_size=batch_size,
|
| 96 |
+
num_workers=config.data.dataloader_num_workers,
|
| 97 |
+
drop_last=True,
|
| 98 |
+
collate_fn=collate_fn,
|
| 99 |
+
sampler=sampler,
|
| 100 |
+
)
|
| 101 |
+
|
| 102 |
+
# 3. generate responses
|
| 103 |
+
batch_dict = next(iter(dataloader))
|
| 104 |
+
batch = DataProto.from_single_dict(batch_dict)
|
| 105 |
+
gen_batch = agent_loop_manager.generate_sequences(prompts=batch)
|
| 106 |
+
|
| 107 |
+
rm_scores = gen_batch.batch["rm_scores"]
|
| 108 |
+
sample_scores = rm_scores.sum(dim=1)
|
| 109 |
+
print(sample_scores)
|
| 110 |
+
|
| 111 |
+
ray.shutdown()
|
code/RL_model/verl/verl_train/tests/experimental/reward_loop/test_agent_reward_loop_colocate.py
ADDED
|
@@ -0,0 +1,168 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
import os
|
| 15 |
+
|
| 16 |
+
import ray
|
| 17 |
+
from hydra import compose, initialize_config_dir
|
| 18 |
+
from torchdata.stateful_dataloader import StatefulDataLoader
|
| 19 |
+
from transformers import AutoTokenizer
|
| 20 |
+
|
| 21 |
+
from verl.checkpoint_engine import CheckpointEngineManager
|
| 22 |
+
from verl.experimental.agent_loop import AgentLoopManager
|
| 23 |
+
from verl.experimental.reward_loop import RewardLoopManager
|
| 24 |
+
from verl.protocol import DataProto
|
| 25 |
+
from verl.single_controller.ray import RayClassWithInitArgs, RayWorkerGroup
|
| 26 |
+
from verl.trainer.main_ppo import create_rl_sampler
|
| 27 |
+
from verl.trainer.ppo.ray_trainer import ResourcePoolManager
|
| 28 |
+
from verl.utils.dataset.rl_dataset import RLHFDataset, collate_fn
|
| 29 |
+
from verl.utils.device import get_device_name
|
| 30 |
+
from verl.workers.fsdp_workers import ActorRolloutRefWorker, AsyncActorRolloutRefWorker
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def test_agent_loop_reward_manager():
|
| 34 |
+
ray.init(
|
| 35 |
+
runtime_env={
|
| 36 |
+
"env_vars": {
|
| 37 |
+
"TOKENIZERS_PARALLELISM": "true",
|
| 38 |
+
"NCCL_DEBUG": "WARN",
|
| 39 |
+
"VLLM_LOGGING_LEVEL": "INFO",
|
| 40 |
+
"VLLM_USE_V1": "1",
|
| 41 |
+
}
|
| 42 |
+
}
|
| 43 |
+
)
|
| 44 |
+
with initialize_config_dir(config_dir=os.path.abspath("verl/trainer/config")):
|
| 45 |
+
config = compose(config_name="ppo_trainer")
|
| 46 |
+
|
| 47 |
+
rollout_model_path = os.path.expanduser("~/models/Qwen/Qwen2.5-0.5B-Instruct")
|
| 48 |
+
reward_model_path = os.path.expanduser("~/models/Qwen/Qwen2.5-1.5B-Instruct")
|
| 49 |
+
|
| 50 |
+
# actor_rollout_ref config
|
| 51 |
+
config.data.return_raw_chat = True
|
| 52 |
+
config.data.max_prompt_length = 1024
|
| 53 |
+
config.data.max_response_length = 4096
|
| 54 |
+
config.actor_rollout_ref.model.path = rollout_model_path
|
| 55 |
+
config.actor_rollout_ref.actor.use_dynamic_bsz = True
|
| 56 |
+
config.actor_rollout_ref.rollout.name = os.getenv("ROLLOUT_NAME", "vllm")
|
| 57 |
+
config.actor_rollout_ref.rollout.mode = "async"
|
| 58 |
+
config.actor_rollout_ref.rollout.tensor_model_parallel_size = 2
|
| 59 |
+
config.actor_rollout_ref.rollout.gpu_memory_utilization = 0.8
|
| 60 |
+
config.actor_rollout_ref.rollout.enforce_eager = True
|
| 61 |
+
config.actor_rollout_ref.rollout.prompt_length = 1024
|
| 62 |
+
config.actor_rollout_ref.rollout.response_length = 4096
|
| 63 |
+
config.actor_rollout_ref.rollout.skip_tokenizer_init = True
|
| 64 |
+
config.trainer.n_gpus_per_node = 8
|
| 65 |
+
config.trainer.nnodes = 1
|
| 66 |
+
|
| 67 |
+
config.reward_model.reward_manager = "dapo"
|
| 68 |
+
config.reward_model.enable = True
|
| 69 |
+
config.reward_model.enable_resource_pool = False
|
| 70 |
+
config.reward_model.n_gpus_per_node = 8
|
| 71 |
+
config.reward_model.model.path = reward_model_path
|
| 72 |
+
config.reward_model.rollout.name = os.getenv("ROLLOUT_NAME", "vllm")
|
| 73 |
+
config.reward_model.rollout.gpu_memory_utilization = 0.8
|
| 74 |
+
config.reward_model.rollout.tensor_model_parallel_size = 2
|
| 75 |
+
config.reward_model.rollout.skip_tokenizer_init = False
|
| 76 |
+
config.reward_model.rollout.prompt_length = 5120
|
| 77 |
+
config.reward_model.rollout.response_length = 4096
|
| 78 |
+
config.custom_reward_function.path = "tests/experimental/reward_loop/reward_fn.py"
|
| 79 |
+
config.custom_reward_function.name = "compute_score_gsm8k"
|
| 80 |
+
|
| 81 |
+
# 1. init reward model manager
|
| 82 |
+
actor_rollout_cls = (
|
| 83 |
+
AsyncActorRolloutRefWorker if config.actor_rollout_ref.rollout.mode == "async" else ActorRolloutRefWorker
|
| 84 |
+
)
|
| 85 |
+
global_pool_id = "global_pool"
|
| 86 |
+
resource_pool_spec = {
|
| 87 |
+
global_pool_id: [config.trainer.n_gpus_per_node] * config.trainer.nnodes,
|
| 88 |
+
}
|
| 89 |
+
resource_pool_manager = ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=None)
|
| 90 |
+
resource_pool_manager.create_resource_pool()
|
| 91 |
+
resource_pool = resource_pool_manager.resource_pool_dict[global_pool_id]
|
| 92 |
+
actor_rollout_cls = RayClassWithInitArgs(
|
| 93 |
+
cls=ray.remote(actor_rollout_cls), config=config.actor_rollout_ref, role="actor_rollout"
|
| 94 |
+
)
|
| 95 |
+
actor_rollout_wg = RayWorkerGroup(
|
| 96 |
+
resource_pool=resource_pool, ray_cls_with_init=actor_rollout_cls, device_name=get_device_name()
|
| 97 |
+
)
|
| 98 |
+
actor_rollout_wg.init_model()
|
| 99 |
+
|
| 100 |
+
agent_loop_manager = AgentLoopManager(config, worker_group=actor_rollout_wg)
|
| 101 |
+
# sleep rollout replicas
|
| 102 |
+
checkpoint_manager = CheckpointEngineManager(
|
| 103 |
+
backend=config.actor_rollout_ref.rollout.checkpoint_engine.backend,
|
| 104 |
+
trainer=actor_rollout_wg,
|
| 105 |
+
replicas=agent_loop_manager.rollout_replicas,
|
| 106 |
+
)
|
| 107 |
+
checkpoint_manager.sleep_replicas()
|
| 108 |
+
reward_loop_manager = RewardLoopManager(config, rm_resource_pool=resource_pool)
|
| 109 |
+
|
| 110 |
+
# 2. init test data
|
| 111 |
+
local_folder = os.path.expanduser("~/data/gsm8k/")
|
| 112 |
+
|
| 113 |
+
data_files = [os.path.join(local_folder, "train.parquet")]
|
| 114 |
+
tokenizer = AutoTokenizer.from_pretrained(rollout_model_path)
|
| 115 |
+
|
| 116 |
+
dataset = RLHFDataset(
|
| 117 |
+
data_files=data_files,
|
| 118 |
+
tokenizer=tokenizer,
|
| 119 |
+
config=config.data,
|
| 120 |
+
processor=None,
|
| 121 |
+
)
|
| 122 |
+
|
| 123 |
+
batch_size = 64
|
| 124 |
+
sampler = create_rl_sampler(config.data, dataset)
|
| 125 |
+
dataloader = StatefulDataLoader(
|
| 126 |
+
dataset=dataset,
|
| 127 |
+
batch_size=batch_size,
|
| 128 |
+
num_workers=config.data.dataloader_num_workers,
|
| 129 |
+
drop_last=True,
|
| 130 |
+
collate_fn=collate_fn,
|
| 131 |
+
sampler=sampler,
|
| 132 |
+
)
|
| 133 |
+
|
| 134 |
+
# 3. generate responses
|
| 135 |
+
batch_dict = next(iter(dataloader))
|
| 136 |
+
batch = DataProto.from_single_dict(batch_dict)
|
| 137 |
+
|
| 138 |
+
def _get_gen_batch(batch: DataProto) -> DataProto:
|
| 139 |
+
reward_model_keys = set({"data_source", "reward_model", "extra_info", "uid"}) & batch.non_tensor_batch.keys()
|
| 140 |
+
|
| 141 |
+
# pop those keys for generation
|
| 142 |
+
batch_keys_to_pop = []
|
| 143 |
+
non_tensor_batch_keys_to_pop = set(batch.non_tensor_batch.keys()) - reward_model_keys
|
| 144 |
+
gen_batch = batch.pop(
|
| 145 |
+
batch_keys=batch_keys_to_pop,
|
| 146 |
+
non_tensor_batch_keys=list(non_tensor_batch_keys_to_pop),
|
| 147 |
+
)
|
| 148 |
+
|
| 149 |
+
# For agent loop, we need reward model keys to compute score.
|
| 150 |
+
gen_batch.non_tensor_batch.update(batch.non_tensor_batch)
|
| 151 |
+
|
| 152 |
+
return gen_batch
|
| 153 |
+
|
| 154 |
+
# wake up rollout replicas via update_weight
|
| 155 |
+
checkpoint_manager.update_weights()
|
| 156 |
+
gen_batch = _get_gen_batch(batch)
|
| 157 |
+
gen_batch = agent_loop_manager.generate_sequences(gen_batch)
|
| 158 |
+
checkpoint_manager.sleep_replicas()
|
| 159 |
+
|
| 160 |
+
batch = batch.union(gen_batch)
|
| 161 |
+
rm_outputs = reward_loop_manager.compute_rm_score(batch)
|
| 162 |
+
|
| 163 |
+
for output in rm_outputs[:5]:
|
| 164 |
+
print(output.non_tensor_batch)
|
| 165 |
+
|
| 166 |
+
print("done")
|
| 167 |
+
|
| 168 |
+
ray.shutdown()
|
code/RL_model/verl/verl_train/tests/experimental/reward_loop/test_async_token_bucket_on_cpu.py
ADDED
|
@@ -0,0 +1,267 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import asyncio
|
| 16 |
+
import time
|
| 17 |
+
|
| 18 |
+
import pytest
|
| 19 |
+
|
| 20 |
+
from verl.experimental.reward_loop.reward_manager.limited import AsyncTokenBucket
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class TestAsyncTokenBucket:
|
| 24 |
+
"""Unit tests for AsyncTokenBucket rate limiter."""
|
| 25 |
+
|
| 26 |
+
@pytest.mark.asyncio
|
| 27 |
+
async def test_basic_acquire(self):
|
| 28 |
+
"""Test basic token acquisition."""
|
| 29 |
+
bucket = AsyncTokenBucket(rate_limit=10.0, max_tokens=10.0)
|
| 30 |
+
|
| 31 |
+
# Should be able to acquire tokens immediately when bucket is full
|
| 32 |
+
start = time.time()
|
| 33 |
+
await bucket.acquire(5.0)
|
| 34 |
+
elapsed = time.time() - start
|
| 35 |
+
|
| 36 |
+
assert elapsed < 0.1, "Initial acquire should be immediate"
|
| 37 |
+
assert bucket.tokens == pytest.approx(5.0, abs=0.1)
|
| 38 |
+
|
| 39 |
+
@pytest.mark.asyncio
|
| 40 |
+
async def test_refill_mechanism(self):
|
| 41 |
+
"""Test that tokens refill over time."""
|
| 42 |
+
bucket = AsyncTokenBucket(rate_limit=10.0, max_tokens=10.0)
|
| 43 |
+
|
| 44 |
+
# Consume all tokens
|
| 45 |
+
await bucket.acquire(10.0)
|
| 46 |
+
assert bucket.tokens == pytest.approx(0.0, abs=0.1)
|
| 47 |
+
|
| 48 |
+
# Wait for refill (should get ~5 tokens in 0.5 seconds at 10 tokens/sec)
|
| 49 |
+
await asyncio.sleep(0.5)
|
| 50 |
+
|
| 51 |
+
# Try to acquire 4 tokens (should succeed without waiting)
|
| 52 |
+
start = time.time()
|
| 53 |
+
await bucket.acquire(4.0)
|
| 54 |
+
elapsed = time.time() - start
|
| 55 |
+
|
| 56 |
+
assert elapsed < 0.1, "Acquire should be quick after refill"
|
| 57 |
+
|
| 58 |
+
@pytest.mark.asyncio
|
| 59 |
+
async def test_waiting_for_tokens(self):
|
| 60 |
+
"""Test that acquire waits when insufficient tokens available."""
|
| 61 |
+
bucket = AsyncTokenBucket(rate_limit=10.0, max_tokens=10.0)
|
| 62 |
+
|
| 63 |
+
# Consume all tokens
|
| 64 |
+
await bucket.acquire(10.0)
|
| 65 |
+
|
| 66 |
+
# Try to acquire more tokens (should wait ~0.5 seconds for 5 tokens)
|
| 67 |
+
start = time.time()
|
| 68 |
+
await bucket.acquire(5.0)
|
| 69 |
+
elapsed = time.time() - start
|
| 70 |
+
|
| 71 |
+
# Should wait approximately 0.5 seconds (5 tokens / 10 tokens per second)
|
| 72 |
+
assert 0.4 < elapsed < 0.7, f"Expected ~0.5s wait, got {elapsed:.3f}s"
|
| 73 |
+
|
| 74 |
+
@pytest.mark.asyncio
|
| 75 |
+
async def test_max_tokens_cap(self):
|
| 76 |
+
"""Test that tokens don't exceed max_tokens capacity."""
|
| 77 |
+
bucket = AsyncTokenBucket(rate_limit=10.0, max_tokens=5.0)
|
| 78 |
+
|
| 79 |
+
# Wait for potential overflow
|
| 80 |
+
await asyncio.sleep(1.0)
|
| 81 |
+
|
| 82 |
+
# Tokens should be capped at max_tokens
|
| 83 |
+
await bucket.acquire(1.0)
|
| 84 |
+
|
| 85 |
+
# After 1 second at 10 tokens/sec, should have max_tokens (5.0)
|
| 86 |
+
# After acquiring 1, should have 4.0 remaining
|
| 87 |
+
assert bucket.tokens <= 5.0, "Tokens should not exceed max_tokens"
|
| 88 |
+
|
| 89 |
+
@pytest.mark.asyncio
|
| 90 |
+
async def test_fractional_tokens(self):
|
| 91 |
+
"""Test acquiring fractional tokens."""
|
| 92 |
+
bucket = AsyncTokenBucket(rate_limit=100.0, max_tokens=100.0)
|
| 93 |
+
|
| 94 |
+
# Acquire fractional amounts
|
| 95 |
+
await bucket.acquire(0.5)
|
| 96 |
+
await bucket.acquire(1.5)
|
| 97 |
+
await bucket.acquire(2.3)
|
| 98 |
+
|
| 99 |
+
assert bucket.tokens == pytest.approx(100.0 - 0.5 - 1.5 - 2.3, abs=0.1)
|
| 100 |
+
|
| 101 |
+
@pytest.mark.asyncio
|
| 102 |
+
async def test_concurrent_acquires(self):
|
| 103 |
+
"""Test multiple concurrent acquire operations."""
|
| 104 |
+
bucket = AsyncTokenBucket(rate_limit=10.0, max_tokens=10.0)
|
| 105 |
+
|
| 106 |
+
async def acquire_task(num_tokens: float, task_id: int):
|
| 107 |
+
await bucket.acquire(num_tokens)
|
| 108 |
+
return task_id
|
| 109 |
+
|
| 110 |
+
# Launch 5 concurrent tasks, each acquiring 3 tokens (15 total)
|
| 111 |
+
# Bucket only has 10, so some will need to wait
|
| 112 |
+
start = time.time()
|
| 113 |
+
tasks = [acquire_task(3.0, i) for i in range(5)]
|
| 114 |
+
results = await asyncio.gather(*tasks)
|
| 115 |
+
elapsed = time.time() - start
|
| 116 |
+
|
| 117 |
+
# Should take at least 0.5 seconds to refill 5 tokens
|
| 118 |
+
# (15 needed - 10 available) / 10 tokens per second = 0.5 seconds
|
| 119 |
+
assert elapsed >= 0.4, f"Expected >=0.4s for concurrent acquires, got {elapsed:.3f}s"
|
| 120 |
+
assert len(results) == 5, "All tasks should complete"
|
| 121 |
+
|
| 122 |
+
@pytest.mark.asyncio
|
| 123 |
+
async def test_high_rate_limit(self):
|
| 124 |
+
"""Test with high rate limit (simulating high-throughput scenarios)."""
|
| 125 |
+
bucket = AsyncTokenBucket(rate_limit=1000.0, max_tokens=1000.0)
|
| 126 |
+
|
| 127 |
+
# Rapidly acquire tokens
|
| 128 |
+
start = time.time()
|
| 129 |
+
for _ in range(100):
|
| 130 |
+
await bucket.acquire(10.0) # 1000 tokens total
|
| 131 |
+
elapsed = time.time() - start
|
| 132 |
+
|
| 133 |
+
# Should complete in approximately 1 second
|
| 134 |
+
assert elapsed < 1.5, f"High rate limit test took too long: {elapsed:.3f}s"
|
| 135 |
+
|
| 136 |
+
@pytest.mark.asyncio
|
| 137 |
+
async def test_zero_initial_state(self):
|
| 138 |
+
"""Test that bucket starts with full tokens."""
|
| 139 |
+
bucket = AsyncTokenBucket(rate_limit=10.0, max_tokens=10.0)
|
| 140 |
+
|
| 141 |
+
assert bucket.tokens == 10.0, "Bucket should start full"
|
| 142 |
+
assert bucket.last_update is None, "last_update should be None initially"
|
| 143 |
+
|
| 144 |
+
# After first acquire, last_update should be set
|
| 145 |
+
await bucket.acquire(1.0)
|
| 146 |
+
assert bucket.last_update is not None, "last_update should be set after acquire"
|
| 147 |
+
|
| 148 |
+
@pytest.mark.asyncio
|
| 149 |
+
async def test_rate_limit_accuracy(self):
|
| 150 |
+
"""Test rate limit accuracy over time."""
|
| 151 |
+
rate = 50.0 # 50 tokens per second
|
| 152 |
+
bucket = AsyncTokenBucket(rate_limit=rate, max_tokens=rate)
|
| 153 |
+
|
| 154 |
+
# Consume all tokens and measure refill time for 25 tokens
|
| 155 |
+
await bucket.acquire(50.0)
|
| 156 |
+
|
| 157 |
+
start = time.time()
|
| 158 |
+
await bucket.acquire(25.0)
|
| 159 |
+
elapsed = time.time() - start
|
| 160 |
+
|
| 161 |
+
expected_time = 25.0 / rate # 0.5 seconds
|
| 162 |
+
# Allow 20% margin for timing inaccuracy
|
| 163 |
+
assert abs(elapsed - expected_time) < expected_time * 0.2, f"Expected ~{expected_time:.3f}s, got {elapsed:.3f}s"
|
| 164 |
+
|
| 165 |
+
@pytest.mark.asyncio
|
| 166 |
+
async def test_sequential_acquires(self):
|
| 167 |
+
"""Test sequential acquire operations."""
|
| 168 |
+
bucket = AsyncTokenBucket(rate_limit=20.0, max_tokens=20.0)
|
| 169 |
+
|
| 170 |
+
# Sequential acquires without waiting
|
| 171 |
+
await bucket.acquire(5.0)
|
| 172 |
+
await bucket.acquire(5.0)
|
| 173 |
+
await bucket.acquire(5.0)
|
| 174 |
+
await bucket.acquire(5.0)
|
| 175 |
+
|
| 176 |
+
# Bucket should be empty
|
| 177 |
+
assert bucket.tokens == pytest.approx(0.0, abs=0.1)
|
| 178 |
+
|
| 179 |
+
# Next acquire should wait
|
| 180 |
+
start = time.time()
|
| 181 |
+
await bucket.acquire(10.0)
|
| 182 |
+
elapsed = time.time() - start
|
| 183 |
+
|
| 184 |
+
assert elapsed >= 0.4, "Should wait for token refill"
|
| 185 |
+
|
| 186 |
+
@pytest.mark.asyncio
|
| 187 |
+
async def test_default_max_tokens(self):
|
| 188 |
+
"""Test that max_tokens defaults to rate_limit."""
|
| 189 |
+
bucket = AsyncTokenBucket(rate_limit=15.0)
|
| 190 |
+
|
| 191 |
+
assert bucket.max_tokens == 15.0, "max_tokens should default to rate_limit"
|
| 192 |
+
assert bucket.tokens == 15.0, "Initial tokens should equal max_tokens"
|
| 193 |
+
|
| 194 |
+
@pytest.mark.asyncio
|
| 195 |
+
async def test_single_token_acquire(self):
|
| 196 |
+
"""Test default acquire of 1 token."""
|
| 197 |
+
bucket = AsyncTokenBucket(rate_limit=10.0, max_tokens=10.0)
|
| 198 |
+
|
| 199 |
+
await bucket.acquire() # Default num_tokens=1.0
|
| 200 |
+
|
| 201 |
+
assert bucket.tokens == pytest.approx(9.0, abs=0.1)
|
| 202 |
+
|
| 203 |
+
@pytest.mark.asyncio
|
| 204 |
+
async def test_large_token_acquire(self):
|
| 205 |
+
"""Test acquiring more tokens than bucket capacity."""
|
| 206 |
+
bucket = AsyncTokenBucket(rate_limit=10.0, max_tokens=10.0)
|
| 207 |
+
|
| 208 |
+
# Try to acquire 50 tokens (5x capacity)
|
| 209 |
+
start = time.time()
|
| 210 |
+
await bucket.acquire(50.0)
|
| 211 |
+
elapsed = time.time() - start
|
| 212 |
+
|
| 213 |
+
# Should wait for: (50 - 10) / 10 = 4 seconds
|
| 214 |
+
assert 3.5 < elapsed < 5.0, f"Expected ~4s wait for large acquire, got {elapsed:.3f}s"
|
| 215 |
+
|
| 216 |
+
@pytest.mark.asyncio
|
| 217 |
+
async def test_thread_safety_with_lock(self):
|
| 218 |
+
"""Test that lock prevents race conditions."""
|
| 219 |
+
bucket = AsyncTokenBucket(rate_limit=100.0, max_tokens=100.0)
|
| 220 |
+
results = []
|
| 221 |
+
|
| 222 |
+
async def acquire_and_record():
|
| 223 |
+
await bucket.acquire(10.0)
|
| 224 |
+
results.append(1)
|
| 225 |
+
|
| 226 |
+
# Launch many concurrent tasks
|
| 227 |
+
tasks = [acquire_and_record() for _ in range(10)]
|
| 228 |
+
await asyncio.gather(*tasks)
|
| 229 |
+
|
| 230 |
+
# All tasks should complete
|
| 231 |
+
assert len(results) == 10, "All tasks should complete successfully"
|
| 232 |
+
|
| 233 |
+
# Bucket should have consumed exactly 100 tokens
|
| 234 |
+
assert bucket.tokens == pytest.approx(0.0, abs=0.5)
|
| 235 |
+
|
| 236 |
+
@pytest.mark.asyncio
|
| 237 |
+
async def test_multiple_wait_cycles(self):
|
| 238 |
+
"""Test multiple wait cycles in the acquire loop."""
|
| 239 |
+
bucket = AsyncTokenBucket(rate_limit=10.0, max_tokens=10.0)
|
| 240 |
+
|
| 241 |
+
# Consume all tokens
|
| 242 |
+
await bucket.acquire(10.0)
|
| 243 |
+
|
| 244 |
+
# Acquire tokens that require multiple refill cycles
|
| 245 |
+
start = time.time()
|
| 246 |
+
await bucket.acquire(15.0)
|
| 247 |
+
elapsed = time.time() - start
|
| 248 |
+
|
| 249 |
+
# Should wait for 15 tokens / 10 tokens per second = 1.5 seconds
|
| 250 |
+
assert 1.3 < elapsed < 1.8, f"Expected ~1.5s for multiple refill cycles, got {elapsed:.3f}s"
|
| 251 |
+
|
| 252 |
+
@pytest.mark.asyncio
|
| 253 |
+
async def test_rapid_small_acquires(self):
|
| 254 |
+
"""Test many rapid small acquisitions."""
|
| 255 |
+
bucket = AsyncTokenBucket(rate_limit=100.0, max_tokens=100.0)
|
| 256 |
+
|
| 257 |
+
start = time.time()
|
| 258 |
+
for _ in range(50):
|
| 259 |
+
await bucket.acquire(2.0) # 100 tokens total
|
| 260 |
+
elapsed = time.time() - start
|
| 261 |
+
|
| 262 |
+
# Should complete quickly since we're within capacity
|
| 263 |
+
assert elapsed < 0.5, f"Rapid small acquires took too long: {elapsed:.3f}s"
|
| 264 |
+
|
| 265 |
+
|
| 266 |
+
if __name__ == "__main__":
|
| 267 |
+
pytest.main([__file__, "-v"])
|
code/RL_model/verl/verl_train/tests/experimental/reward_loop/test_math_verify.py
ADDED
|
@@ -0,0 +1,100 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
import os
|
| 15 |
+
|
| 16 |
+
import ray
|
| 17 |
+
from hydra import compose, initialize_config_dir
|
| 18 |
+
from torchdata.stateful_dataloader import StatefulDataLoader
|
| 19 |
+
from transformers import AutoTokenizer
|
| 20 |
+
|
| 21 |
+
from verl.experimental.agent_loop import AgentLoopManager
|
| 22 |
+
from verl.protocol import DataProto
|
| 23 |
+
from verl.trainer.main_ppo import create_rl_sampler
|
| 24 |
+
from verl.utils.dataset.rl_dataset import RLHFDataset, collate_fn
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def test_agent_loop_reward_manager():
|
| 28 |
+
ray.init(
|
| 29 |
+
runtime_env={
|
| 30 |
+
"env_vars": {
|
| 31 |
+
"TOKENIZERS_PARALLELISM": "true",
|
| 32 |
+
"NCCL_DEBUG": "WARN",
|
| 33 |
+
"VLLM_LOGGING_LEVEL": "INFO",
|
| 34 |
+
"VLLM_USE_V1": "1",
|
| 35 |
+
}
|
| 36 |
+
}
|
| 37 |
+
)
|
| 38 |
+
with initialize_config_dir(config_dir=os.path.abspath("verl/trainer/config")):
|
| 39 |
+
config = compose(config_name="ppo_trainer")
|
| 40 |
+
|
| 41 |
+
rollout_model_path = os.path.expanduser("~/models/Qwen/Qwen2.5-3B-Instruct")
|
| 42 |
+
|
| 43 |
+
# actor_rollout_ref config
|
| 44 |
+
config.data.return_raw_chat = True
|
| 45 |
+
config.data.max_prompt_length = 1024
|
| 46 |
+
config.data.max_response_length = 4096
|
| 47 |
+
config.actor_rollout_ref.model.path = rollout_model_path
|
| 48 |
+
config.actor_rollout_ref.actor.use_dynamic_bsz = True
|
| 49 |
+
config.actor_rollout_ref.rollout.name = os.getenv("ROLLOUT_NAME", "vllm")
|
| 50 |
+
config.actor_rollout_ref.rollout.mode = "async"
|
| 51 |
+
config.actor_rollout_ref.rollout.tensor_model_parallel_size = 2
|
| 52 |
+
config.actor_rollout_ref.rollout.gpu_memory_utilization = 0.9
|
| 53 |
+
config.actor_rollout_ref.rollout.enforce_eager = True
|
| 54 |
+
config.actor_rollout_ref.rollout.prompt_length = 2048
|
| 55 |
+
config.actor_rollout_ref.rollout.response_length = 4096
|
| 56 |
+
config.actor_rollout_ref.rollout.skip_tokenizer_init = True
|
| 57 |
+
config.trainer.n_gpus_per_node = 8
|
| 58 |
+
config.trainer.nnodes = 1
|
| 59 |
+
|
| 60 |
+
config.reward_model.reward_manager = "remote"
|
| 61 |
+
config.reward_model.num_workers = 2
|
| 62 |
+
config.custom_reward_function.path = "tests/experimental/reward_loop/reward_fn.py"
|
| 63 |
+
config.custom_reward_function.name = "compute_score_math_verify"
|
| 64 |
+
|
| 65 |
+
# 1. init reward model manager
|
| 66 |
+
agent_loop_manager = AgentLoopManager(config)
|
| 67 |
+
|
| 68 |
+
# 2. init test data
|
| 69 |
+
local_folder = os.path.expanduser("~/data/math/")
|
| 70 |
+
data_files = [os.path.join(local_folder, "train.parquet")]
|
| 71 |
+
tokenizer = AutoTokenizer.from_pretrained(rollout_model_path)
|
| 72 |
+
|
| 73 |
+
dataset = RLHFDataset(
|
| 74 |
+
data_files=data_files,
|
| 75 |
+
tokenizer=tokenizer,
|
| 76 |
+
config=config.data,
|
| 77 |
+
processor=None,
|
| 78 |
+
)
|
| 79 |
+
|
| 80 |
+
batch_size = 64
|
| 81 |
+
sampler = create_rl_sampler(config.data, dataset)
|
| 82 |
+
dataloader = StatefulDataLoader(
|
| 83 |
+
dataset=dataset,
|
| 84 |
+
batch_size=batch_size,
|
| 85 |
+
num_workers=config.data.dataloader_num_workers,
|
| 86 |
+
drop_last=True,
|
| 87 |
+
collate_fn=collate_fn,
|
| 88 |
+
sampler=sampler,
|
| 89 |
+
)
|
| 90 |
+
|
| 91 |
+
# 3. generate responses
|
| 92 |
+
batch_dict = next(iter(dataloader))
|
| 93 |
+
batch = DataProto.from_single_dict(batch_dict)
|
| 94 |
+
gen_batch = agent_loop_manager.generate_sequences(prompts=batch)
|
| 95 |
+
|
| 96 |
+
rm_scores = gen_batch.batch["rm_scores"]
|
| 97 |
+
accuracy = rm_scores.sum(dim=-1).mean()
|
| 98 |
+
print(accuracy)
|
| 99 |
+
|
| 100 |
+
ray.shutdown()
|
code/RL_model/verl/verl_train/tests/experimental/reward_loop/test_rate_limited_reward_manager_on_cpu.py
ADDED
|
@@ -0,0 +1,528 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import asyncio
|
| 16 |
+
import logging
|
| 17 |
+
import os.path
|
| 18 |
+
import time
|
| 19 |
+
|
| 20 |
+
import pytest
|
| 21 |
+
import torch
|
| 22 |
+
from omegaconf import DictConfig
|
| 23 |
+
from transformers import AutoTokenizer
|
| 24 |
+
|
| 25 |
+
from verl import DataProto
|
| 26 |
+
from verl.experimental.reward_loop.reward_manager.limited import RateLimitedRewardManager
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
# Mock API reward functions for testing
|
| 30 |
+
class MockAPICounter:
|
| 31 |
+
"""Shared counter to track API calls across tests."""
|
| 32 |
+
|
| 33 |
+
def __init__(self):
|
| 34 |
+
self.call_count = 0
|
| 35 |
+
self.call_times = []
|
| 36 |
+
self.lock = asyncio.Lock()
|
| 37 |
+
|
| 38 |
+
async def record_call(self):
|
| 39 |
+
async with self.lock:
|
| 40 |
+
self.call_count += 1
|
| 41 |
+
self.call_times.append(time.time())
|
| 42 |
+
|
| 43 |
+
def reset(self):
|
| 44 |
+
self.call_count = 0
|
| 45 |
+
self.call_times.clear()
|
| 46 |
+
|
| 47 |
+
def get_rate_per_second(self, window_start: float = None):
|
| 48 |
+
"""Calculate API call rate over a time window."""
|
| 49 |
+
if window_start is None:
|
| 50 |
+
if not self.call_times:
|
| 51 |
+
return 0.0
|
| 52 |
+
window_start = self.call_times[0]
|
| 53 |
+
|
| 54 |
+
if not self.call_times:
|
| 55 |
+
return 0.0
|
| 56 |
+
|
| 57 |
+
window_end = self.call_times[-1]
|
| 58 |
+
duration = window_end - window_start
|
| 59 |
+
|
| 60 |
+
if duration <= 0:
|
| 61 |
+
return 0.0
|
| 62 |
+
|
| 63 |
+
calls_in_window = sum(1 for t in self.call_times if t >= window_start)
|
| 64 |
+
return calls_in_window / duration
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
# Global counter instance
|
| 68 |
+
api_counter = MockAPICounter()
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
def mock_sync_reward_function(
|
| 72 |
+
data_source: str, solution_str: str, ground_truth: str, extra_info: dict, **kwargs
|
| 73 |
+
) -> float:
|
| 74 |
+
"""Synchronous mock reward function that simulates API call."""
|
| 75 |
+
# Simulate API processing time
|
| 76 |
+
time.sleep(0.01)
|
| 77 |
+
|
| 78 |
+
# Simple scoring logic
|
| 79 |
+
score = 1.0 if solution_str.strip() == ground_truth.strip() else 0.0
|
| 80 |
+
return score
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
async def mock_async_reward_function(
|
| 84 |
+
data_source: str, solution_str: str, ground_truth: str, extra_info: dict, **kwargs
|
| 85 |
+
) -> float:
|
| 86 |
+
"""Asynchronous mock reward function that simulates API call."""
|
| 87 |
+
# Record API call for rate tracking
|
| 88 |
+
await api_counter.record_call()
|
| 89 |
+
|
| 90 |
+
# Simulate async API call (e.g., HTTP request)
|
| 91 |
+
await asyncio.sleep(0.01)
|
| 92 |
+
|
| 93 |
+
# Simple scoring logic
|
| 94 |
+
score = 1.0 if solution_str.strip() == ground_truth.strip() else 0.0
|
| 95 |
+
return score
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
async def mock_slow_api_function(
|
| 99 |
+
data_source: str, solution_str: str, ground_truth: str, extra_info: dict, **kwargs
|
| 100 |
+
) -> float:
|
| 101 |
+
"""Slow mock API function for timeout testing."""
|
| 102 |
+
await asyncio.sleep(2.0) # Simulate slow API
|
| 103 |
+
return 0.5
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
async def mock_failing_api_function(
|
| 107 |
+
data_source: str, solution_str: str, ground_truth: str, extra_info: dict, **kwargs
|
| 108 |
+
) -> float:
|
| 109 |
+
"""Mock API function that raises an exception."""
|
| 110 |
+
await api_counter.record_call()
|
| 111 |
+
raise ValueError("Simulated API error")
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
async def mock_dict_result_function(
|
| 115 |
+
data_source: str, solution_str: str, ground_truth: str, extra_info: dict, **kwargs
|
| 116 |
+
) -> dict:
|
| 117 |
+
"""Mock API function that returns dict result."""
|
| 118 |
+
await api_counter.record_call()
|
| 119 |
+
await asyncio.sleep(0.01)
|
| 120 |
+
|
| 121 |
+
correct = solution_str.strip() == ground_truth.strip()
|
| 122 |
+
return {"score": 1.0 if correct else 0.0, "correct": correct, "reasoning": "Mock reasoning"}
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
def create_test_data_proto(tokenizer, response_text: str, ground_truth: str, data_source: str = "test"):
|
| 126 |
+
"""Helper to create DataProto for testing."""
|
| 127 |
+
response_ids = tokenizer.encode(response_text, add_special_tokens=False)
|
| 128 |
+
response_tensor = torch.tensor([response_ids], dtype=torch.long)
|
| 129 |
+
attention_mask = torch.ones_like(response_tensor)
|
| 130 |
+
|
| 131 |
+
data = DataProto.from_dict(
|
| 132 |
+
{
|
| 133 |
+
"responses": response_tensor,
|
| 134 |
+
"attention_mask": attention_mask,
|
| 135 |
+
}
|
| 136 |
+
)
|
| 137 |
+
|
| 138 |
+
# Wrap non-tensor values in lists to match batch dimension
|
| 139 |
+
data.non_tensor_batch = {"data_source": [data_source], "reward_model": [{"ground_truth": ground_truth}]}
|
| 140 |
+
|
| 141 |
+
return data
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
class TestRateLimitedRewardManager:
|
| 145 |
+
"""Integration tests for RateLimitedRewardManager with mock API functions."""
|
| 146 |
+
|
| 147 |
+
@pytest.fixture(autouse=True)
|
| 148 |
+
def setup_and_teardown(self):
|
| 149 |
+
"""Reset global state before each test."""
|
| 150 |
+
api_counter.reset()
|
| 151 |
+
# Reset class state
|
| 152 |
+
RateLimitedRewardManager._class_initialized = False
|
| 153 |
+
RateLimitedRewardManager._semaphore = None
|
| 154 |
+
RateLimitedRewardManager._rpm_limiter = None
|
| 155 |
+
RateLimitedRewardManager._tpm_limiter = None
|
| 156 |
+
yield
|
| 157 |
+
# Cleanup
|
| 158 |
+
api_counter.reset()
|
| 159 |
+
|
| 160 |
+
@pytest.fixture
|
| 161 |
+
def tokenizer(self):
|
| 162 |
+
"""Load a simple tokenizer for testing."""
|
| 163 |
+
return AutoTokenizer.from_pretrained(os.path.expanduser("~/models/Qwen/Qwen2.5-0.5B-Instruct"))
|
| 164 |
+
|
| 165 |
+
@pytest.mark.asyncio
|
| 166 |
+
async def test_basic_reward_computation(self, tokenizer):
|
| 167 |
+
"""Test basic reward computation without rate limiting."""
|
| 168 |
+
config = DictConfig({"reward_model": {"max_concurrent": 10, "timeout": 10.0}})
|
| 169 |
+
|
| 170 |
+
RateLimitedRewardManager.init_class(config, tokenizer)
|
| 171 |
+
manager = RateLimitedRewardManager(config=config, tokenizer=tokenizer, compute_score=mock_async_reward_function)
|
| 172 |
+
|
| 173 |
+
# Create test data
|
| 174 |
+
data = create_test_data_proto(tokenizer, "correct answer", "correct answer")
|
| 175 |
+
|
| 176 |
+
# Compute reward
|
| 177 |
+
result = await manager.run_single(data)
|
| 178 |
+
|
| 179 |
+
assert "reward_score" in result
|
| 180 |
+
assert result["reward_score"] == 1.0
|
| 181 |
+
assert api_counter.call_count == 1
|
| 182 |
+
|
| 183 |
+
@pytest.mark.asyncio
|
| 184 |
+
async def test_rpm_rate_limiting(self, tokenizer):
|
| 185 |
+
"""Test request per minute (RPM) rate limiting."""
|
| 186 |
+
# Set RPM limit to 60 (1 request per second)
|
| 187 |
+
config = DictConfig(
|
| 188 |
+
{
|
| 189 |
+
"reward_model": {
|
| 190 |
+
"max_concurrent": 10,
|
| 191 |
+
"max_rpm": 60, # 1 request per second
|
| 192 |
+
"timeout": 10.0,
|
| 193 |
+
}
|
| 194 |
+
}
|
| 195 |
+
)
|
| 196 |
+
|
| 197 |
+
RateLimitedRewardManager.init_class(config, tokenizer)
|
| 198 |
+
manager = RateLimitedRewardManager(config=config, tokenizer=tokenizer, compute_score=mock_async_reward_function)
|
| 199 |
+
|
| 200 |
+
# Create test data
|
| 201 |
+
data = create_test_data_proto(tokenizer, "answer", "answer")
|
| 202 |
+
|
| 203 |
+
# Make 3 requests - should be rate limited
|
| 204 |
+
start_time = time.time()
|
| 205 |
+
|
| 206 |
+
results = []
|
| 207 |
+
for _ in range(3):
|
| 208 |
+
result = await manager.run_single(data)
|
| 209 |
+
results.append(result)
|
| 210 |
+
|
| 211 |
+
elapsed = time.time() - start_time
|
| 212 |
+
|
| 213 |
+
# Should take at least ~2 seconds for 3 requests at 1 req/sec
|
| 214 |
+
assert elapsed >= 1.8, f"RPM limiting failed: {elapsed:.3f}s for 3 requests"
|
| 215 |
+
assert all(r["reward_score"] == 1.0 for r in results)
|
| 216 |
+
assert api_counter.call_count == 3
|
| 217 |
+
|
| 218 |
+
@pytest.mark.asyncio
|
| 219 |
+
async def test_tpm_rate_limiting(self, tokenizer):
|
| 220 |
+
"""Test tokens per minute (TPM) rate limiting."""
|
| 221 |
+
# Set TPM limit to 6000 (100 tokens per second)
|
| 222 |
+
# With 2000 tokens per request, that's 0.05 req/sec or 20 seconds per request
|
| 223 |
+
config = DictConfig(
|
| 224 |
+
{
|
| 225 |
+
"reward_model": {
|
| 226 |
+
"max_concurrent": 10,
|
| 227 |
+
"max_tpm": 6000, # 100 tokens per second
|
| 228 |
+
"estimated_tokens_per_request": 2000, # Each request = 2000 tokens
|
| 229 |
+
"timeout": 30.0,
|
| 230 |
+
}
|
| 231 |
+
}
|
| 232 |
+
)
|
| 233 |
+
|
| 234 |
+
RateLimitedRewardManager.init_class(config, tokenizer)
|
| 235 |
+
manager = RateLimitedRewardManager(config=config, tokenizer=tokenizer, compute_score=mock_async_reward_function)
|
| 236 |
+
|
| 237 |
+
data = create_test_data_proto(tokenizer, "answer", "answer")
|
| 238 |
+
|
| 239 |
+
# Make 2 requests
|
| 240 |
+
start_time = time.time()
|
| 241 |
+
|
| 242 |
+
result1 = await manager.run_single(data)
|
| 243 |
+
result2 = await manager.run_single(data)
|
| 244 |
+
|
| 245 |
+
elapsed = time.time() - start_time
|
| 246 |
+
|
| 247 |
+
# First request: consumes 2000 tokens (immediate)
|
| 248 |
+
# Second request: needs 2000 tokens, waits for refill
|
| 249 |
+
# Wait time: 2000 tokens / 100 tokens per second = 20 seconds
|
| 250 |
+
assert elapsed >= 18.0, f"TPM limiting failed: {elapsed:.3f}s for 2 requests"
|
| 251 |
+
assert result1["reward_score"] == 1.0
|
| 252 |
+
assert result2["reward_score"] == 1.0
|
| 253 |
+
|
| 254 |
+
@pytest.mark.asyncio
|
| 255 |
+
async def test_concurrency_limiting(self, tokenizer):
|
| 256 |
+
"""Test concurrent request limiting."""
|
| 257 |
+
config = DictConfig(
|
| 258 |
+
{
|
| 259 |
+
"reward_model": {
|
| 260 |
+
"max_concurrent": 2, # Only 2 concurrent requests
|
| 261 |
+
"timeout": 10.0,
|
| 262 |
+
}
|
| 263 |
+
}
|
| 264 |
+
)
|
| 265 |
+
|
| 266 |
+
RateLimitedRewardManager.init_class(config, tokenizer)
|
| 267 |
+
manager = RateLimitedRewardManager(config=config, tokenizer=tokenizer, compute_score=mock_async_reward_function)
|
| 268 |
+
|
| 269 |
+
data = create_test_data_proto(tokenizer, "answer", "answer")
|
| 270 |
+
|
| 271 |
+
# Launch 5 concurrent requests
|
| 272 |
+
start_time = time.time()
|
| 273 |
+
|
| 274 |
+
tasks = [manager.run_single(data) for _ in range(5)]
|
| 275 |
+
results = await asyncio.gather(*tasks)
|
| 276 |
+
|
| 277 |
+
elapsed = time.time() - start_time
|
| 278 |
+
|
| 279 |
+
# All should succeed
|
| 280 |
+
assert len(results) == 5
|
| 281 |
+
assert all(r["reward_score"] == 1.0 for r in results)
|
| 282 |
+
|
| 283 |
+
# With concurrency=2 and 0.01s per request, should take at least 0.03s
|
| 284 |
+
# (3 batches: 2+2+1)
|
| 285 |
+
assert elapsed >= 0.02, f"Concurrency limiting may not be working: {elapsed:.3f}s"
|
| 286 |
+
|
| 287 |
+
@pytest.mark.asyncio
|
| 288 |
+
async def test_timeout_handling(self, tokenizer):
|
| 289 |
+
"""Test timeout handling for slow API."""
|
| 290 |
+
config = DictConfig(
|
| 291 |
+
{
|
| 292 |
+
"reward_model": {
|
| 293 |
+
"max_concurrent": 10,
|
| 294 |
+
"timeout": 0.5, # 500ms timeout
|
| 295 |
+
}
|
| 296 |
+
}
|
| 297 |
+
)
|
| 298 |
+
|
| 299 |
+
RateLimitedRewardManager.init_class(config, tokenizer)
|
| 300 |
+
manager = RateLimitedRewardManager(config=config, tokenizer=tokenizer, compute_score=mock_slow_api_function)
|
| 301 |
+
|
| 302 |
+
data = create_test_data_proto(tokenizer, "answer", "answer")
|
| 303 |
+
|
| 304 |
+
# Should timeout and return 0.0
|
| 305 |
+
result = await manager.run_single(data)
|
| 306 |
+
|
| 307 |
+
assert result["reward_score"] == 0.0
|
| 308 |
+
assert result["reward_extra_info"].get("timeout") is True
|
| 309 |
+
assert result["reward_extra_info"].get("acc") == 0.0
|
| 310 |
+
|
| 311 |
+
@pytest.mark.asyncio
|
| 312 |
+
async def test_error_handling(self, tokenizer):
|
| 313 |
+
"""Test error handling for failing API."""
|
| 314 |
+
config = DictConfig({"reward_model": {"max_concurrent": 10, "timeout": 10.0}})
|
| 315 |
+
|
| 316 |
+
RateLimitedRewardManager.init_class(config, tokenizer)
|
| 317 |
+
manager = RateLimitedRewardManager(config=config, tokenizer=tokenizer, compute_score=mock_failing_api_function)
|
| 318 |
+
|
| 319 |
+
data = create_test_data_proto(tokenizer, "answer", "answer")
|
| 320 |
+
|
| 321 |
+
# Should catch exception and return 0.0
|
| 322 |
+
result = await manager.run_single(data)
|
| 323 |
+
|
| 324 |
+
assert result["reward_score"] == 0.0
|
| 325 |
+
assert "error" in result["reward_extra_info"]
|
| 326 |
+
assert "Simulated API error" in result["reward_extra_info"]["error"]
|
| 327 |
+
assert result["reward_extra_info"].get("acc") == 0.0
|
| 328 |
+
assert api_counter.call_count == 1
|
| 329 |
+
|
| 330 |
+
@pytest.mark.asyncio
|
| 331 |
+
async def test_dict_result_format(self, tokenizer):
|
| 332 |
+
"""Test handling of dict return format from reward function."""
|
| 333 |
+
config = DictConfig({"reward_model": {"max_concurrent": 10, "timeout": 10.0}})
|
| 334 |
+
|
| 335 |
+
RateLimitedRewardManager.init_class(config, tokenizer)
|
| 336 |
+
manager = RateLimitedRewardManager(config=config, tokenizer=tokenizer, compute_score=mock_dict_result_function)
|
| 337 |
+
|
| 338 |
+
data = create_test_data_proto(tokenizer, "correct", "correct")
|
| 339 |
+
|
| 340 |
+
result = await manager.run_single(data)
|
| 341 |
+
|
| 342 |
+
assert result["reward_score"] == 1.0
|
| 343 |
+
assert result["reward_extra_info"]["score"] == 1.0
|
| 344 |
+
assert result["reward_extra_info"]["correct"] is True
|
| 345 |
+
assert result["reward_extra_info"]["reasoning"] == "Mock reasoning"
|
| 346 |
+
|
| 347 |
+
@pytest.mark.asyncio
|
| 348 |
+
async def test_sync_reward_function(self, tokenizer):
|
| 349 |
+
"""Test that synchronous reward functions work correctly."""
|
| 350 |
+
config = DictConfig({"reward_model": {"max_concurrent": 10, "timeout": 10.0}})
|
| 351 |
+
|
| 352 |
+
RateLimitedRewardManager.init_class(config, tokenizer)
|
| 353 |
+
manager = RateLimitedRewardManager(config=config, tokenizer=tokenizer, compute_score=mock_sync_reward_function)
|
| 354 |
+
|
| 355 |
+
data = create_test_data_proto(tokenizer, "answer", "answer")
|
| 356 |
+
|
| 357 |
+
result = await manager.run_single(data)
|
| 358 |
+
|
| 359 |
+
assert result["reward_score"] == 1.0
|
| 360 |
+
assert manager.is_async_reward_score is False
|
| 361 |
+
|
| 362 |
+
@pytest.mark.asyncio
|
| 363 |
+
async def test_combined_rate_limits(self, tokenizer):
|
| 364 |
+
"""Test all three rate limiting layers together."""
|
| 365 |
+
config = DictConfig(
|
| 366 |
+
{
|
| 367 |
+
"reward_model": {
|
| 368 |
+
"max_concurrent": 2,
|
| 369 |
+
"max_rpm": 120, # 2 requests per second
|
| 370 |
+
"max_tpm": 12000, # 200 tokens per second
|
| 371 |
+
"estimated_tokens_per_request": 100, # 0.5 seconds per request
|
| 372 |
+
"timeout": 10.0,
|
| 373 |
+
}
|
| 374 |
+
}
|
| 375 |
+
)
|
| 376 |
+
|
| 377 |
+
RateLimitedRewardManager.init_class(config, tokenizer)
|
| 378 |
+
manager = RateLimitedRewardManager(config=config, tokenizer=tokenizer, compute_score=mock_async_reward_function)
|
| 379 |
+
|
| 380 |
+
data = create_test_data_proto(tokenizer, "answer", "answer")
|
| 381 |
+
|
| 382 |
+
# Make 6 requests to exceed burst capacity (RPM bucket starts with 2 tokens)
|
| 383 |
+
start_time = time.time()
|
| 384 |
+
|
| 385 |
+
tasks = [manager.run_single(data) for _ in range(6)]
|
| 386 |
+
results = await asyncio.gather(*tasks)
|
| 387 |
+
|
| 388 |
+
elapsed = time.time() - start_time
|
| 389 |
+
|
| 390 |
+
# Bucket starts with 2 RPM tokens and 200 TPM tokens
|
| 391 |
+
# First 2 requests: use burst capacity (2 RPM tokens, 200 TPM tokens)
|
| 392 |
+
# Next 4 requests: need 4 RPM tokens (wait 2 seconds) and 400 TPM tokens (wait 2 seconds)
|
| 393 |
+
# Limiting factor: RPM at 2 seconds
|
| 394 |
+
assert elapsed >= 1.8, f"Combined rate limiting: {elapsed:.3f}s"
|
| 395 |
+
assert all(r["reward_score"] == 1.0 for r in results)
|
| 396 |
+
assert api_counter.call_count == 6
|
| 397 |
+
|
| 398 |
+
@pytest.mark.asyncio
|
| 399 |
+
async def test_correct_vs_incorrect_answers(self, tokenizer):
|
| 400 |
+
"""Test scoring of correct vs incorrect answers."""
|
| 401 |
+
config = DictConfig({"reward_model": {"max_concurrent": 10, "timeout": 10.0}})
|
| 402 |
+
|
| 403 |
+
RateLimitedRewardManager.init_class(config, tokenizer)
|
| 404 |
+
manager = RateLimitedRewardManager(config=config, tokenizer=tokenizer, compute_score=mock_async_reward_function)
|
| 405 |
+
|
| 406 |
+
# Test correct answer
|
| 407 |
+
data_correct = create_test_data_proto(tokenizer, "right answer", "right answer")
|
| 408 |
+
result_correct = await manager.run_single(data_correct)
|
| 409 |
+
|
| 410 |
+
# Test incorrect answer
|
| 411 |
+
data_incorrect = create_test_data_proto(tokenizer, "wrong answer", "right answer")
|
| 412 |
+
result_incorrect = await manager.run_single(data_incorrect)
|
| 413 |
+
|
| 414 |
+
assert result_correct["reward_score"] == 1.0
|
| 415 |
+
assert result_incorrect["reward_score"] == 0.0
|
| 416 |
+
|
| 417 |
+
@pytest.mark.asyncio
|
| 418 |
+
async def test_high_throughput(self, tokenizer):
|
| 419 |
+
"""Test high throughput with many concurrent requests."""
|
| 420 |
+
config = DictConfig(
|
| 421 |
+
{
|
| 422 |
+
"reward_model": {
|
| 423 |
+
"max_concurrent": 20,
|
| 424 |
+
"max_rpm": 6000, # 100 requests per second
|
| 425 |
+
"timeout": 10.0,
|
| 426 |
+
}
|
| 427 |
+
}
|
| 428 |
+
)
|
| 429 |
+
|
| 430 |
+
RateLimitedRewardManager.init_class(config, tokenizer)
|
| 431 |
+
manager = RateLimitedRewardManager(config=config, tokenizer=tokenizer, compute_score=mock_async_reward_function)
|
| 432 |
+
|
| 433 |
+
data = create_test_data_proto(tokenizer, "answer", "answer")
|
| 434 |
+
|
| 435 |
+
# Launch 200 concurrent requests (more than burst capacity of 100)
|
| 436 |
+
start_time = time.time()
|
| 437 |
+
|
| 438 |
+
tasks = [manager.run_single(data) for _ in range(200)]
|
| 439 |
+
results = await asyncio.gather(*tasks)
|
| 440 |
+
|
| 441 |
+
elapsed = time.time() - start_time
|
| 442 |
+
|
| 443 |
+
assert len(results) == 200
|
| 444 |
+
assert all(r["reward_score"] == 1.0 for r in results)
|
| 445 |
+
|
| 446 |
+
# Bucket starts with 100 tokens (burst capacity)
|
| 447 |
+
# First 100 requests: use burst capacity instantly
|
| 448 |
+
# Next 100 requests: need to wait for refill at 100 tokens/sec = 1 second minimum
|
| 449 |
+
# Total time should be at least 1 second
|
| 450 |
+
assert elapsed >= 0.9, f"Should take at least 0.9s for rate limiting, took {elapsed:.3f}s"
|
| 451 |
+
|
| 452 |
+
# Calculate actual rate over the time window
|
| 453 |
+
actual_rate = api_counter.call_count / elapsed
|
| 454 |
+
|
| 455 |
+
# Average rate should not significantly exceed 100 req/sec
|
| 456 |
+
# Allow some burst overhead due to initial capacity
|
| 457 |
+
assert actual_rate <= 200, f"Rate limiting failed: {actual_rate:.1f} req/sec (max 200)"
|
| 458 |
+
|
| 459 |
+
@pytest.mark.asyncio
|
| 460 |
+
async def test_class_initialization_once(self, tokenizer):
|
| 461 |
+
"""Test that class initialization only happens once."""
|
| 462 |
+
config = DictConfig({"reward_model": {"max_concurrent": 5, "timeout": 10.0}})
|
| 463 |
+
|
| 464 |
+
# Initialize multiple times
|
| 465 |
+
RateLimitedRewardManager.init_class(config, tokenizer)
|
| 466 |
+
first_semaphore = RateLimitedRewardManager._semaphore
|
| 467 |
+
|
| 468 |
+
RateLimitedRewardManager.init_class(config, tokenizer)
|
| 469 |
+
second_semaphore = RateLimitedRewardManager._semaphore
|
| 470 |
+
|
| 471 |
+
# Should be the same object
|
| 472 |
+
assert first_semaphore is second_semaphore
|
| 473 |
+
|
| 474 |
+
def test_warn_when_rate_limits_are_ignored_due_to_prior_init(self, tokenizer, caplog):
|
| 475 |
+
"""Warn when a new config attempts to change global RPM/TPM after the class has been initialized."""
|
| 476 |
+
caplog.set_level(logging.WARNING)
|
| 477 |
+
|
| 478 |
+
# First instantiation without a config (legacy signature) initializes global limiters with defaults.
|
| 479 |
+
_ = RateLimitedRewardManager(
|
| 480 |
+
tokenizer=tokenizer,
|
| 481 |
+
compute_score=mock_async_reward_function,
|
| 482 |
+
num_examine=0,
|
| 483 |
+
reward_fn_key="data_source",
|
| 484 |
+
)
|
| 485 |
+
|
| 486 |
+
# Second instantiation attempts to set RPM limits, but will be ignored due to global initialization.
|
| 487 |
+
config = DictConfig({"reward_model": {"max_concurrent": 10, "max_rpm": 60, "timeout": 10.0}})
|
| 488 |
+
_ = RateLimitedRewardManager(
|
| 489 |
+
config=config,
|
| 490 |
+
tokenizer=tokenizer,
|
| 491 |
+
compute_score=mock_async_reward_function,
|
| 492 |
+
)
|
| 493 |
+
|
| 494 |
+
assert any(
|
| 495 |
+
"RateLimitedRewardManager has already been initialized" in record.getMessage()
|
| 496 |
+
and "ignored" in record.getMessage()
|
| 497 |
+
for record in caplog.records
|
| 498 |
+
), "Expected a warning when attempting to change global rate limits after initialization."
|
| 499 |
+
|
| 500 |
+
@pytest.mark.asyncio
|
| 501 |
+
async def test_extra_info_handling(self, tokenizer):
|
| 502 |
+
"""Test that extra_info is properly passed to reward function."""
|
| 503 |
+
received_extra_info = {}
|
| 504 |
+
|
| 505 |
+
async def mock_reward_with_extra_info(
|
| 506 |
+
data_source: str, solution_str: str, ground_truth: str, extra_info: dict, **kwargs
|
| 507 |
+
):
|
| 508 |
+
received_extra_info.update(extra_info)
|
| 509 |
+
return 1.0
|
| 510 |
+
|
| 511 |
+
config = DictConfig({"reward_model": {"max_concurrent": 10, "timeout": 10.0}})
|
| 512 |
+
|
| 513 |
+
RateLimitedRewardManager.init_class(config, tokenizer)
|
| 514 |
+
manager = RateLimitedRewardManager(
|
| 515 |
+
config=config, tokenizer=tokenizer, compute_score=mock_reward_with_extra_info
|
| 516 |
+
)
|
| 517 |
+
|
| 518 |
+
data = create_test_data_proto(tokenizer, "answer", "answer")
|
| 519 |
+
data.non_tensor_batch["extra_info"] = [{"custom_field": "test_value"}]
|
| 520 |
+
|
| 521 |
+
await manager.run_single(data)
|
| 522 |
+
|
| 523 |
+
assert "custom_field" in received_extra_info
|
| 524 |
+
assert received_extra_info["custom_field"] == "test_value"
|
| 525 |
+
|
| 526 |
+
|
| 527 |
+
if __name__ == "__main__":
|
| 528 |
+
pytest.main([__file__, "-v", "-s"])
|
code/RL_model/verl/verl_train/tests/experimental/reward_loop/test_reward_model_disrm.py
ADDED
|
@@ -0,0 +1,153 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
import os
|
| 15 |
+
|
| 16 |
+
import ray
|
| 17 |
+
import torch
|
| 18 |
+
from hydra import compose, initialize_config_dir
|
| 19 |
+
|
| 20 |
+
from verl.experimental.reward_loop import RewardLoopManager
|
| 21 |
+
from verl.protocol import DataProto
|
| 22 |
+
from verl.utils import hf_tokenizer
|
| 23 |
+
from verl.utils.model import compute_position_id_with_mask
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def create_data_samples(tokenizer) -> DataProto:
|
| 27 |
+
convs = [
|
| 28 |
+
[
|
| 29 |
+
{
|
| 30 |
+
"role": "user",
|
| 31 |
+
"content": "What is the range of the numeric output of a sigmoid node in a neural network?",
|
| 32 |
+
},
|
| 33 |
+
{"role": "assistant", "content": "Between -1 and 1."},
|
| 34 |
+
],
|
| 35 |
+
[
|
| 36 |
+
{
|
| 37 |
+
"role": "user",
|
| 38 |
+
"content": "What is the range of the numeric output of a sigmoid node in a neural network?",
|
| 39 |
+
},
|
| 40 |
+
{"role": "assistant", "content": "Between 0 and 1."},
|
| 41 |
+
],
|
| 42 |
+
[
|
| 43 |
+
{"role": "user", "content": "What is the capital of Australia?"},
|
| 44 |
+
{
|
| 45 |
+
"role": "assistant",
|
| 46 |
+
"content": "Canberra is the capital city of Australia.",
|
| 47 |
+
},
|
| 48 |
+
],
|
| 49 |
+
[
|
| 50 |
+
{"role": "user", "content": "What is the capital of Australia?"},
|
| 51 |
+
{
|
| 52 |
+
"role": "assistant",
|
| 53 |
+
"content": "Sydney is the capital of Australia.",
|
| 54 |
+
},
|
| 55 |
+
],
|
| 56 |
+
]
|
| 57 |
+
raw_prompt = [conv[:1] for conv in convs]
|
| 58 |
+
data_source = ["gsm8k"] * len(convs)
|
| 59 |
+
reward_info = [{"ground_truth": "Not Used"}] * len(convs)
|
| 60 |
+
extra_info = [{"question": conv[0]["content"]} for conv in convs]
|
| 61 |
+
|
| 62 |
+
prompt_length, response_length = 1024, 4096
|
| 63 |
+
pad_token_id = tokenizer.pad_token_id
|
| 64 |
+
prompts, responses, input_ids, attention_masks = [], [], [], []
|
| 65 |
+
for conv in convs:
|
| 66 |
+
prompt_tokens = tokenizer.apply_chat_template(conv[:1], tokenize=True)
|
| 67 |
+
response_tokens = tokenizer.apply_chat_template(conv, tokenize=True)[len(prompt_tokens) :]
|
| 68 |
+
|
| 69 |
+
padded_prompt = [pad_token_id] * (prompt_length - len(prompt_tokens)) + prompt_tokens
|
| 70 |
+
padded_response = response_tokens + [pad_token_id] * (response_length - len(response_tokens))
|
| 71 |
+
attention_mask = (
|
| 72 |
+
[0] * (prompt_length - len(prompt_tokens))
|
| 73 |
+
+ [1] * len(prompt_tokens)
|
| 74 |
+
+ [1] * len(response_tokens)
|
| 75 |
+
+ [0] * (response_length - len(response_tokens))
|
| 76 |
+
)
|
| 77 |
+
prompts.append(torch.tensor(padded_prompt))
|
| 78 |
+
responses.append(torch.tensor(padded_response))
|
| 79 |
+
input_ids.append(torch.tensor(padded_prompt + padded_response))
|
| 80 |
+
attention_masks.append(torch.tensor(attention_mask))
|
| 81 |
+
|
| 82 |
+
prompts = torch.stack(prompts)
|
| 83 |
+
responses = torch.stack(responses)
|
| 84 |
+
input_ids = torch.stack(input_ids)
|
| 85 |
+
attention_masks = torch.stack(attention_masks)
|
| 86 |
+
position_ids = compute_position_id_with_mask(attention_masks)
|
| 87 |
+
|
| 88 |
+
data = DataProto.from_dict(
|
| 89 |
+
tensors={
|
| 90 |
+
"prompts": prompts,
|
| 91 |
+
"responses": responses,
|
| 92 |
+
"input_ids": input_ids,
|
| 93 |
+
"attention_mask": attention_masks,
|
| 94 |
+
"position_ids": position_ids,
|
| 95 |
+
},
|
| 96 |
+
non_tensors={
|
| 97 |
+
"data_source": data_source,
|
| 98 |
+
"reward_model": reward_info,
|
| 99 |
+
"raw_prompt": raw_prompt,
|
| 100 |
+
"extra_info": extra_info,
|
| 101 |
+
},
|
| 102 |
+
)
|
| 103 |
+
return data, convs
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
def test_reward_model_manager():
|
| 107 |
+
ray.init(
|
| 108 |
+
runtime_env={
|
| 109 |
+
"env_vars": {
|
| 110 |
+
"TOKENIZERS_PARALLELISM": "true",
|
| 111 |
+
"NCCL_DEBUG": "WARN",
|
| 112 |
+
"VLLM_LOGGING_LEVEL": "INFO",
|
| 113 |
+
"VLLM_USE_V1": "1",
|
| 114 |
+
}
|
| 115 |
+
}
|
| 116 |
+
)
|
| 117 |
+
with initialize_config_dir(config_dir=os.path.abspath("verl/trainer/config")):
|
| 118 |
+
config = compose(config_name="ppo_trainer")
|
| 119 |
+
|
| 120 |
+
rollout_model_name = os.path.expanduser("~/models/Qwen/Qwen2.5-1.5B-Instruct")
|
| 121 |
+
reward_model_name = os.path.expanduser("~/models/Skywork/Skywork-Reward-V2-Llama-3.2-1B")
|
| 122 |
+
|
| 123 |
+
config.actor_rollout_ref.model.path = rollout_model_name
|
| 124 |
+
config.reward_model.reward_manager = "dapo"
|
| 125 |
+
config.reward_model.enable = True
|
| 126 |
+
config.reward_model.enable_resource_pool = True
|
| 127 |
+
config.reward_model.n_gpus_per_node = 8
|
| 128 |
+
config.reward_model.nnodes = 1
|
| 129 |
+
config.reward_model.model.path = reward_model_name
|
| 130 |
+
config.reward_model.rollout.name = os.getenv("ROLLOUT_NAME", "vllm")
|
| 131 |
+
config.reward_model.rollout.gpu_memory_utilization = 0.9
|
| 132 |
+
config.reward_model.rollout.tensor_model_parallel_size = 2
|
| 133 |
+
config.reward_model.rollout.skip_tokenizer_init = False
|
| 134 |
+
config.reward_model.rollout.prompt_length = 2048
|
| 135 |
+
config.reward_model.rollout.response_length = 4096
|
| 136 |
+
|
| 137 |
+
# 1. init reward model manager
|
| 138 |
+
reward_loop_manager = RewardLoopManager(config)
|
| 139 |
+
|
| 140 |
+
# 2. init test data
|
| 141 |
+
rollout_tokenizer = hf_tokenizer(rollout_model_name)
|
| 142 |
+
data, convs = create_data_samples(rollout_tokenizer)
|
| 143 |
+
|
| 144 |
+
# 3. generate responses
|
| 145 |
+
outputs = reward_loop_manager.compute_rm_score(data)
|
| 146 |
+
|
| 147 |
+
for idx, (conv, output) in enumerate(zip(convs, outputs, strict=True)):
|
| 148 |
+
print(f"Problem {idx}:\n{conv[0]['content']}\n")
|
| 149 |
+
print(f"AI Solution {idx}:\n{conv[1]['content']}\n")
|
| 150 |
+
print(f"DisRM Score {idx}:\n{output.batch['rm_scores'].sum(dim=-1).item()}\n")
|
| 151 |
+
print("=" * 50 + "\n")
|
| 152 |
+
|
| 153 |
+
ray.shutdown()
|
code/RL_model/verl/verl_train/tests/experimental/vla/test_sim_envs.py
ADDED
|
@@ -0,0 +1,101 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import os
|
| 16 |
+
import unittest
|
| 17 |
+
|
| 18 |
+
import numpy as np
|
| 19 |
+
import pytest
|
| 20 |
+
from omegaconf import OmegaConf
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
# @pytest.mark.parametrize("simulator_type", ["libero", "isaac"])
|
| 24 |
+
@pytest.mark.parametrize("simulator_type", ["isaac"])
|
| 25 |
+
def test_sim_env_creation_and_step(simulator_type):
|
| 26 |
+
num_envs = 8
|
| 27 |
+
actions = np.array(
|
| 28 |
+
[
|
| 29 |
+
[5.59112417e-01, 8.06460073e-02, 1.36817226e-02, -4.64279854e-04, -1.72158767e-02, -6.57548380e-04, -1],
|
| 30 |
+
[2.12711899e-03, -3.13366604e-01, 3.41386353e-04, -4.64279854e-04, -8.76528812e-03, -6.57548380e-04, -1],
|
| 31 |
+
[7.38182960e-02, -4.64548351e-02, -6.63602950e-02, -4.64279854e-04, -2.32520114e-02, -6.57548380e-04, -1],
|
| 32 |
+
[7.38182960e-02, -1.60845593e-01, 3.41386353e-04, -4.64279854e-04, 1.05503430e-02, -6.57548380e-04, -1],
|
| 33 |
+
[7.38182960e-02, -3.95982152e-01, -7.97006313e-02, -5.10713711e-03, 3.22804279e-02, -6.57548380e-04, -1],
|
| 34 |
+
[2.41859427e-02, -3.64206941e-01, -6.63602950e-02, -4.64279854e-04, 1.05503430e-02, -6.57548380e-04, -1],
|
| 35 |
+
[4.62447664e-02, -5.16727952e-01, -7.97006313e-02, -4.64279854e-04, 1.05503430e-02, 8.73740975e-03, -1],
|
| 36 |
+
[4.62447664e-02, -5.73923331e-01, 3.41386353e-04, -4.64279854e-04, 6.92866212e-03, -6.57548380e-04, -1],
|
| 37 |
+
]
|
| 38 |
+
)
|
| 39 |
+
cfg = OmegaConf.create(
|
| 40 |
+
{
|
| 41 |
+
"max_episode_steps": 512,
|
| 42 |
+
"only_eval": False,
|
| 43 |
+
"reward_coef": 1.0,
|
| 44 |
+
"init_params": {
|
| 45 |
+
"camera_names": ["agentview"],
|
| 46 |
+
},
|
| 47 |
+
"video_cfg": {
|
| 48 |
+
"save_video": True,
|
| 49 |
+
"video_base_dir": "/tmp/test_sim_env_creation_and_step",
|
| 50 |
+
},
|
| 51 |
+
"task_suite_name": "libero_10",
|
| 52 |
+
"num_envs": num_envs,
|
| 53 |
+
"num_group": 1,
|
| 54 |
+
"group_size": num_envs,
|
| 55 |
+
"seed": 0,
|
| 56 |
+
},
|
| 57 |
+
)
|
| 58 |
+
|
| 59 |
+
sim_env = None
|
| 60 |
+
if simulator_type == "isaac":
|
| 61 |
+
from verl.experimental.vla.envs.isaac_env.isaac_env import IsaacEnv
|
| 62 |
+
|
| 63 |
+
sim_env = IsaacEnv(cfg, rank=0, world_size=1)
|
| 64 |
+
elif simulator_type == "libero":
|
| 65 |
+
from verl.experimental.vla.envs.libero_env.libero_env import LiberoEnv
|
| 66 |
+
|
| 67 |
+
sim_env = LiberoEnv(cfg, rank=0, world_size=1)
|
| 68 |
+
else:
|
| 69 |
+
raise ValueError(f"simulator_type {simulator_type} is not supported")
|
| 70 |
+
|
| 71 |
+
video_count = 0
|
| 72 |
+
for i in [0]:
|
| 73 |
+
# The first call to step with actions=None will reset the environment
|
| 74 |
+
step = 0
|
| 75 |
+
sim_env.reset_envs_to_state_ids([0] * num_envs, [i] * num_envs)
|
| 76 |
+
for action in actions:
|
| 77 |
+
obs_venv, reward_venv, terminated_venv, truncated_venv, info_venv = sim_env.step(
|
| 78 |
+
np.array([action] * num_envs)
|
| 79 |
+
)
|
| 80 |
+
|
| 81 |
+
assert isinstance(obs_venv, dict)
|
| 82 |
+
assert reward_venv.shape == (num_envs,)
|
| 83 |
+
assert terminated_venv.shape == (num_envs,)
|
| 84 |
+
assert truncated_venv.shape == (num_envs,)
|
| 85 |
+
assert isinstance(info_venv, dict)
|
| 86 |
+
|
| 87 |
+
if terminated_venv.any() or truncated_venv.any():
|
| 88 |
+
break
|
| 89 |
+
step += 1
|
| 90 |
+
|
| 91 |
+
sim_env.flush_video(video_sub_dir=f"task_{i}")
|
| 92 |
+
assert os.path.exists(os.path.join(cfg.video_cfg.video_base_dir, f"rank_0/task_{i}/{video_count}.mp4"))
|
| 93 |
+
os.remove(os.path.join(cfg.video_cfg.video_base_dir, f"rank_0/task_{i}/{video_count}.mp4"))
|
| 94 |
+
video_count += 1
|
| 95 |
+
|
| 96 |
+
print("test passed")
|
| 97 |
+
sim_env.close()
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
if __name__ == "__main__":
|
| 101 |
+
unittest.main()
|
code/RL_model/verl/verl_train/tests/single_controller/base/test_decorator.py
ADDED
|
@@ -0,0 +1,76 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import pytest
|
| 16 |
+
|
| 17 |
+
import verl.single_controller.base.decorator as decorator_module
|
| 18 |
+
from verl.single_controller.base.decorator import (
|
| 19 |
+
DISPATCH_MODE_FN_REGISTRY,
|
| 20 |
+
Dispatch,
|
| 21 |
+
_check_dispatch_mode,
|
| 22 |
+
get_predefined_dispatch_fn,
|
| 23 |
+
register_dispatch_mode,
|
| 24 |
+
update_dispatch_mode,
|
| 25 |
+
)
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
@pytest.fixture
|
| 29 |
+
def reset_dispatch_registry():
|
| 30 |
+
# Store original state
|
| 31 |
+
original_registry = DISPATCH_MODE_FN_REGISTRY.copy()
|
| 32 |
+
yield
|
| 33 |
+
# Reset registry after test
|
| 34 |
+
decorator_module.DISPATCH_MODE_FN_REGISTRY.clear()
|
| 35 |
+
decorator_module.DISPATCH_MODE_FN_REGISTRY.update(original_registry)
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def test_register_new_dispatch_mode(reset_dispatch_registry):
|
| 39 |
+
# Test registration
|
| 40 |
+
def dummy_dispatch(worker_group, *args, **kwargs):
|
| 41 |
+
return args, kwargs
|
| 42 |
+
|
| 43 |
+
def dummy_collect(worker_group, output):
|
| 44 |
+
return output
|
| 45 |
+
|
| 46 |
+
register_dispatch_mode("TEST_MODE", dummy_dispatch, dummy_collect)
|
| 47 |
+
|
| 48 |
+
# Verify enum extension
|
| 49 |
+
_check_dispatch_mode(Dispatch.TEST_MODE)
|
| 50 |
+
|
| 51 |
+
# Verify registry update
|
| 52 |
+
assert get_predefined_dispatch_fn(Dispatch.TEST_MODE) == {
|
| 53 |
+
"dispatch_fn": dummy_dispatch,
|
| 54 |
+
"collect_fn": dummy_collect,
|
| 55 |
+
}
|
| 56 |
+
# Clean up
|
| 57 |
+
Dispatch.remove("TEST_MODE")
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
def test_update_existing_dispatch_mode(reset_dispatch_registry):
|
| 61 |
+
# Store original implementation
|
| 62 |
+
original_mode = Dispatch.ONE_TO_ALL
|
| 63 |
+
|
| 64 |
+
# New implementations
|
| 65 |
+
def new_dispatch(worker_group, *args, **kwargs):
|
| 66 |
+
return args, kwargs
|
| 67 |
+
|
| 68 |
+
def new_collect(worker_group, output):
|
| 69 |
+
return output
|
| 70 |
+
|
| 71 |
+
# Test update=
|
| 72 |
+
update_dispatch_mode(original_mode, new_dispatch, new_collect)
|
| 73 |
+
|
| 74 |
+
# Verify update
|
| 75 |
+
assert get_predefined_dispatch_fn(original_mode)["dispatch_fn"] == new_dispatch
|
| 76 |
+
assert get_predefined_dispatch_fn(original_mode)["collect_fn"] == new_collect
|
code/RL_model/verl/verl_train/tests/single_controller/check_worker_alive/main.py
ADDED
|
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import os
|
| 16 |
+
import sys
|
| 17 |
+
import time
|
| 18 |
+
|
| 19 |
+
import ray
|
| 20 |
+
|
| 21 |
+
from verl.single_controller.base.decorator import Dispatch, register
|
| 22 |
+
from verl.single_controller.base.worker import Worker
|
| 23 |
+
from verl.single_controller.ray.base import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
@ray.remote
|
| 27 |
+
class TestActor(Worker):
|
| 28 |
+
def __init__(self) -> None:
|
| 29 |
+
super().__init__()
|
| 30 |
+
|
| 31 |
+
@register(dispatch_mode=Dispatch.ONE_TO_ALL, blocking=False)
|
| 32 |
+
def foo(self, wait_time):
|
| 33 |
+
time.sleep(wait_time)
|
| 34 |
+
sys.exit(1)
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
if __name__ == "__main__":
|
| 38 |
+
wait_time = int(os.getenv("WAIT_TIME", "10"))
|
| 39 |
+
|
| 40 |
+
ray.init()
|
| 41 |
+
|
| 42 |
+
# test single-node-no-partition
|
| 43 |
+
print("test single-node-no-partition")
|
| 44 |
+
resource_pool = RayResourcePool([2], use_gpu=False)
|
| 45 |
+
class_with_args = RayClassWithInitArgs(cls=TestActor)
|
| 46 |
+
|
| 47 |
+
print("create worker group")
|
| 48 |
+
wg = RayWorkerGroup(resource_pool, class_with_args, name_prefix="test")
|
| 49 |
+
|
| 50 |
+
wg.start_worker_aliveness_check(1)
|
| 51 |
+
time.sleep(1)
|
| 52 |
+
|
| 53 |
+
print(time.time(), "start foo")
|
| 54 |
+
|
| 55 |
+
_ = wg.foo(wait_time)
|
| 56 |
+
print("foo started")
|
| 57 |
+
|
| 58 |
+
print(
|
| 59 |
+
time.time(),
|
| 60 |
+
f"wait 6x wait time {wait_time * 6} to let signal returned to process but still not exceed process wait time",
|
| 61 |
+
)
|
| 62 |
+
time.sleep(wait_time * 6)
|
| 63 |
+
|
| 64 |
+
ray.shutdown()
|
code/RL_model/verl/verl_train/tests/single_controller/detached_worker/README.md
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Detached Worker
|
| 2 |
+
## How to run (Only on a single node)
|
| 3 |
+
- Start a local ray cluster:
|
| 4 |
+
```bash
|
| 5 |
+
ray start --head --port=6379
|
| 6 |
+
```
|
| 7 |
+
- Run the server
|
| 8 |
+
```bash
|
| 9 |
+
python3 server.py
|
| 10 |
+
```
|
| 11 |
+
- On another terminal, Run the client
|
| 12 |
+
```bash
|
| 13 |
+
python3 client.py
|
| 14 |
+
```
|
code/RL_model/verl/verl_train/tests/single_controller/detached_worker/client.py
ADDED
|
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
"""
|
| 15 |
+
In client, we can get the server handler and send RPC request
|
| 16 |
+
"""
|
| 17 |
+
|
| 18 |
+
import ray
|
| 19 |
+
import torch
|
| 20 |
+
from server import Trainer
|
| 21 |
+
from tensordict import TensorDict
|
| 22 |
+
|
| 23 |
+
from verl import DataProto
|
| 24 |
+
from verl.single_controller.ray import RayClassWithInitArgs, RayWorkerGroup
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def compute_position_id_with_mask(mask):
|
| 28 |
+
return torch.clip(torch.cumsum(mask, dim=-1) - 1, min=0, max=None)
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
if __name__ == "__main__":
|
| 32 |
+
ray.init(address="auto", namespace="verl")
|
| 33 |
+
# get the worker group using names
|
| 34 |
+
worker_names = ["trainerTrainer_0:0", "trainerTrainer_0:1"]
|
| 35 |
+
cls_with_init_args = RayClassWithInitArgs(cls=Trainer)
|
| 36 |
+
worker_group = RayWorkerGroup.from_detached(worker_names=worker_names, ray_cls_with_init=cls_with_init_args)
|
| 37 |
+
|
| 38 |
+
batch_size = 16
|
| 39 |
+
sequence_length = 1024
|
| 40 |
+
|
| 41 |
+
# give Trainer some data to train
|
| 42 |
+
input_ids = torch.randint(low=0, high=256, size=(batch_size, sequence_length), dtype=torch.int64, device="cuda")
|
| 43 |
+
attention_mask = torch.ones_like(input_ids)
|
| 44 |
+
position_ids = compute_position_id_with_mask(attention_mask)
|
| 45 |
+
|
| 46 |
+
data = DataProto(
|
| 47 |
+
batch=TensorDict(
|
| 48 |
+
{"input_ids": input_ids, "attention_mask": attention_mask, "position_ids": position_ids},
|
| 49 |
+
batch_size=batch_size,
|
| 50 |
+
),
|
| 51 |
+
meta_info={},
|
| 52 |
+
)
|
| 53 |
+
|
| 54 |
+
output = worker_group.train_model(data)
|
| 55 |
+
|
| 56 |
+
print(output)
|
code/RL_model/verl/verl_train/tests/single_controller/detached_worker/run.sh
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
ray start --head --port=6379
|
| 3 |
+
python3 server.py
|
| 4 |
+
python3 client.py
|
| 5 |
+
ray stop --force
|
code/RL_model/verl/verl_train/tests/single_controller/detached_worker/server.py
ADDED
|
@@ -0,0 +1,152 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
"""
|
| 15 |
+
Server starts a Trainer. Client sends data to the server to train.
|
| 16 |
+
"""
|
| 17 |
+
|
| 18 |
+
import os
|
| 19 |
+
|
| 20 |
+
os.environ["MEGATRON_USE_CUDA_TIMER"] = "0"
|
| 21 |
+
os.environ["MEGATRON_START_PROCESS_TIMER"] = "False"
|
| 22 |
+
os.environ["NCCL_DEBUG"] = "WARN"
|
| 23 |
+
|
| 24 |
+
import ray
|
| 25 |
+
import torch
|
| 26 |
+
from megatron.core import parallel_state as mpu
|
| 27 |
+
from megatron.core import tensor_parallel
|
| 28 |
+
from megatron.core.models.gpt.gpt_model import ModelType
|
| 29 |
+
from omegaconf import OmegaConf
|
| 30 |
+
from tensordict import TensorDict
|
| 31 |
+
from torch import nn
|
| 32 |
+
from transformers import LlamaConfig
|
| 33 |
+
|
| 34 |
+
from verl import DataProto
|
| 35 |
+
from verl.models.llama.megatron import ParallelLlamaForCausalLMRmPadPP
|
| 36 |
+
from verl.single_controller.base import Worker
|
| 37 |
+
from verl.single_controller.base.decorator import Dispatch, make_nd_compute_dataproto_dispatch_fn, register
|
| 38 |
+
from verl.single_controller.ray import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup
|
| 39 |
+
from verl.utils.megatron.optimizer import get_megatron_optimizer, init_megatron_optim_config
|
| 40 |
+
from verl.utils.megatron_utils import get_model, mcore_model_parallel_config
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
@ray.remote
|
| 44 |
+
class Trainer(Worker):
|
| 45 |
+
def __init__(self):
|
| 46 |
+
super().__init__()
|
| 47 |
+
|
| 48 |
+
if not torch.distributed.is_initialized():
|
| 49 |
+
rank = int(os.environ["LOCAL_RANK"])
|
| 50 |
+
torch.distributed.init_process_group(backend="nccl")
|
| 51 |
+
torch.cuda.set_device(rank)
|
| 52 |
+
|
| 53 |
+
mpu.initialize_model_parallel(
|
| 54 |
+
tensor_model_parallel_size=2,
|
| 55 |
+
pipeline_model_parallel_size=1,
|
| 56 |
+
virtual_pipeline_model_parallel_size=None,
|
| 57 |
+
use_sharp=False,
|
| 58 |
+
context_parallel_size=1,
|
| 59 |
+
expert_model_parallel_size=1,
|
| 60 |
+
nccl_communicator_config_path=None,
|
| 61 |
+
)
|
| 62 |
+
tensor_parallel.model_parallel_cuda_manual_seed(10)
|
| 63 |
+
|
| 64 |
+
is_collect = (
|
| 65 |
+
mpu.get_tensor_model_parallel_rank() == 0
|
| 66 |
+
and mpu.get_pipeline_model_parallel_rank() == mpu.get_pipeline_model_parallel_world_size() - 1
|
| 67 |
+
and mpu.get_context_parallel_rank() == 0
|
| 68 |
+
)
|
| 69 |
+
self._register_dispatch_collect_info(
|
| 70 |
+
mesh_name="train", dp_rank=mpu.get_data_parallel_rank(), is_collect=is_collect
|
| 71 |
+
)
|
| 72 |
+
|
| 73 |
+
@register(dispatch_mode=Dispatch.ONE_TO_ALL)
|
| 74 |
+
def init_model(self):
|
| 75 |
+
actor_model_config = LlamaConfig(
|
| 76 |
+
vocab_size=256,
|
| 77 |
+
hidden_size=2048,
|
| 78 |
+
intermediate_size=5504,
|
| 79 |
+
num_hidden_layers=24,
|
| 80 |
+
num_attention_heads=16,
|
| 81 |
+
num_key_value_heads=16,
|
| 82 |
+
)
|
| 83 |
+
|
| 84 |
+
megatron_config = mcore_model_parallel_config(sequence_parallel=True, params_dtype=torch.bfloat16)
|
| 85 |
+
self.megatron_config = megatron_config
|
| 86 |
+
|
| 87 |
+
def megatron_actor_model_provider(pre_process, post_process):
|
| 88 |
+
# vpp is not supported yet because it will hang for some reason. Need debugging
|
| 89 |
+
# this_megatron_config = copy.deepcopy(megatron_config)
|
| 90 |
+
# this_megatron_config.virtual_pipeline_model_parallel_rank = vpp_rank
|
| 91 |
+
parallel_model = ParallelLlamaForCausalLMRmPadPP(
|
| 92 |
+
config=actor_model_config,
|
| 93 |
+
megatron_config=megatron_config,
|
| 94 |
+
pre_process=pre_process,
|
| 95 |
+
post_process=post_process,
|
| 96 |
+
)
|
| 97 |
+
parallel_model.cuda()
|
| 98 |
+
return parallel_model
|
| 99 |
+
|
| 100 |
+
actor_module = get_model(
|
| 101 |
+
model_provider_func=megatron_actor_model_provider,
|
| 102 |
+
model_type=ModelType.encoder_or_decoder,
|
| 103 |
+
wrap_with_ddp=True,
|
| 104 |
+
)
|
| 105 |
+
actor_module = nn.ModuleList(actor_module)
|
| 106 |
+
|
| 107 |
+
optim_config = OmegaConf.create({"lr": 1e-6, "clip_grad": 1.0})
|
| 108 |
+
|
| 109 |
+
optim_config = init_megatron_optim_config(optim_config)
|
| 110 |
+
self.optimizer_config = optim_config
|
| 111 |
+
actor_optimizer = get_megatron_optimizer(model=actor_module, config=optim_config)
|
| 112 |
+
|
| 113 |
+
self.model = actor_module[0]
|
| 114 |
+
self.optimizer = actor_optimizer
|
| 115 |
+
|
| 116 |
+
@register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name="train"))
|
| 117 |
+
def train_model(self, data: DataProto) -> DataProto:
|
| 118 |
+
input_ids = data.batch["input_ids"]
|
| 119 |
+
attention_mask = data.batch["attention_mask"]
|
| 120 |
+
position_ids = data.batch["position_ids"]
|
| 121 |
+
|
| 122 |
+
self.optimizer.zero_grad()
|
| 123 |
+
self.model.zero_grad_buffer(
|
| 124 |
+
zero_buffer=(not self.optimizer_config.use_distributed_optimizer)
|
| 125 |
+
) # use use_contiguous_buffers_in_local_ddp and no overlap_dp_param_comm
|
| 126 |
+
# update for 1 iteration
|
| 127 |
+
output = self.model(input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids).logits
|
| 128 |
+
output.mean().backward()
|
| 129 |
+
|
| 130 |
+
update_successful, grad_norm, num_zeros_in_grad = self.optimizer.step(
|
| 131 |
+
self.megatron_config, self.megatron_config.timers
|
| 132 |
+
)
|
| 133 |
+
|
| 134 |
+
return DataProto(batch=TensorDict({"loss": output.detach()}, batch_size=output.shape[0]))
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
if __name__ == "__main__":
|
| 138 |
+
ray.init(address="auto", namespace="verl")
|
| 139 |
+
|
| 140 |
+
resource_pool = RayResourcePool(process_on_nodes=[2], detached=True)
|
| 141 |
+
cls_with_init_args = RayClassWithInitArgs(cls=Trainer)
|
| 142 |
+
worker_group = RayWorkerGroup(
|
| 143 |
+
resource_pool=resource_pool,
|
| 144 |
+
ray_cls_with_init=cls_with_init_args,
|
| 145 |
+
name_prefix="trainer",
|
| 146 |
+
detached=True,
|
| 147 |
+
)
|
| 148 |
+
|
| 149 |
+
worker_group.init_model()
|
| 150 |
+
|
| 151 |
+
worker_names = worker_group.worker_names
|
| 152 |
+
print(worker_names)
|
code/RL_model/verl/verl_train/tests/special_e2e/envs/__init__.py
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
from .digit_completion import DigitCompletion
|
| 16 |
+
|
| 17 |
+
__all__ = ["DigitCompletion"]
|
code/RL_model/verl/verl_train/tests/special_e2e/envs/digit_completion/__init__.py
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
from transformers import AutoTokenizer, LlamaConfig
|
| 16 |
+
|
| 17 |
+
from .task import DigitCompletion, generate_ground_truth_response
|
| 18 |
+
from .tokenizer import CharTokenizer
|
| 19 |
+
|
| 20 |
+
AutoTokenizer.register(LlamaConfig, CharTokenizer, exist_ok=True)
|
| 21 |
+
|
| 22 |
+
__all__ = ["DigitCompletion", "generate_ground_truth_response", "CharTokenizer"]
|
code/RL_model/verl/verl_train/tests/special_e2e/envs/digit_completion/task.py
ADDED
|
@@ -0,0 +1,179 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
"""Task and environment definition for digit completion."""
|
| 15 |
+
|
| 16 |
+
import numpy as np
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class DigitCompletion:
|
| 20 |
+
"""
|
| 21 |
+
The implementation of a simple digit completion task.
|
| 22 |
+
The prompt is a sequence of numbers with fixed difference. The task is to complete the next N numbers.
|
| 23 |
+
If the max number is reached, the next number should be modulo with max number.
|
| 24 |
+
|
| 25 |
+
For example,
|
| 26 |
+
- prompt = [1, 2, 3]
|
| 27 |
+
- N = 5
|
| 28 |
+
- max_number = 6
|
| 29 |
+
|
| 30 |
+
the response should be [4, 5, 6, 7%6, 8%6] = [4, 5, 6, 0, 1]
|
| 31 |
+
|
| 32 |
+
Note that the tokenizer is char-level to increase the difficulty.
|
| 33 |
+
"""
|
| 34 |
+
|
| 35 |
+
def __init__(self, max_number: int, max_diff: int, max_num_in_response: int, seed=0):
|
| 36 |
+
"""
|
| 37 |
+
|
| 38 |
+
Args:
|
| 39 |
+
max_number: the maximum number allowed in the arithmetic sequence
|
| 40 |
+
max_diff: the maximum diff. The actual common diff will be sampled from [0, max_diff]
|
| 41 |
+
max_num_in_response: the maximum number in the response
|
| 42 |
+
"""
|
| 43 |
+
super().__init__()
|
| 44 |
+
self.max_number = max_number
|
| 45 |
+
self.max_diff = max_diff
|
| 46 |
+
self.max_num_in_response = max_num_in_response
|
| 47 |
+
assert self.max_num_in_response < 10
|
| 48 |
+
assert self.max_number > 0
|
| 49 |
+
assert self.max_diff > 0
|
| 50 |
+
self.max_number_length = len(str(max_number))
|
| 51 |
+
# {num1},{num2}:{max_num_in_response},{max_number}
|
| 52 |
+
self._prompt_length = self.max_number_length * 2 + 4 + self.max_number_length # no negative is allowed
|
| 53 |
+
|
| 54 |
+
self.np_rng = np.random.default_rng(seed=seed)
|
| 55 |
+
|
| 56 |
+
def __str__(self):
|
| 57 |
+
return (
|
| 58 |
+
f"Prompt length: {self.prompt_length}. Response length: {self.response_length}, "
|
| 59 |
+
f"Max number: {self.max_number}. Max diff: {self.max_diff}, "
|
| 60 |
+
f"Max number in response: {self.max_num_in_response}"
|
| 61 |
+
)
|
| 62 |
+
|
| 63 |
+
def get_state(self):
|
| 64 |
+
return {"rng": self.np_rng}
|
| 65 |
+
|
| 66 |
+
def set_state(self, state):
|
| 67 |
+
assert "rng" in state, "rng must be inside state"
|
| 68 |
+
self.np_rng = state["rng"]
|
| 69 |
+
|
| 70 |
+
@property
|
| 71 |
+
def prompt_length(self):
|
| 72 |
+
return self._prompt_length
|
| 73 |
+
|
| 74 |
+
@property
|
| 75 |
+
def response_length(self):
|
| 76 |
+
# number length + comma length + [EOS]
|
| 77 |
+
# The actual number times 1.5 to allow 'U'
|
| 78 |
+
return (self.max_num_in_response * self.max_number_length + (self.max_num_in_response - 1) + 1) * 2
|
| 79 |
+
|
| 80 |
+
def add(self, a, b):
|
| 81 |
+
return (a + b) % self.max_number
|
| 82 |
+
|
| 83 |
+
def get_all_prompts(self):
|
| 84 |
+
all_prompts = []
|
| 85 |
+
for first_num in range(self.max_number + 1):
|
| 86 |
+
for diff in range(0, self.max_diff + 1):
|
| 87 |
+
second_num = self.add(first_num, diff)
|
| 88 |
+
for num_to_complete in range(self.max_num_in_response + 1):
|
| 89 |
+
prompt = str(first_num) + "," + str(second_num) + f":{self.max_number},{num_to_complete}"
|
| 90 |
+
all_prompts.append(prompt)
|
| 91 |
+
return all_prompts
|
| 92 |
+
|
| 93 |
+
def sample_str_prompts(self):
|
| 94 |
+
# step 1: sample initial numbers
|
| 95 |
+
first_num = self.np_rng.integers(self.max_number + 1)
|
| 96 |
+
diff = self.np_rng.integers(self.max_diff + 1)
|
| 97 |
+
second_num = self.add(first_num, diff)
|
| 98 |
+
num_to_complete = self.np_rng.integers(self.max_num_in_response + 1)
|
| 99 |
+
prompt = str(first_num) + "," + str(second_num) + f":{self.max_number},{num_to_complete}"
|
| 100 |
+
return prompt
|
| 101 |
+
|
| 102 |
+
def sample_batch_str_prompts(self, batch_size):
|
| 103 |
+
str_prompts = []
|
| 104 |
+
for _ in range(batch_size):
|
| 105 |
+
str_prompts.append(self.sample_str_prompts())
|
| 106 |
+
return str_prompts
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
def compute_attention_mask(prompts, pad_token_id):
|
| 110 |
+
mask = np.ones_like(prompts)
|
| 111 |
+
mask[prompts == pad_token_id] = 0
|
| 112 |
+
return mask
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
def compute_position_id_with_mask(mask):
|
| 116 |
+
return np.clip(np.cumsum(mask, axis=-1) - 1, a_min=0, a_max=None)
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
def generate_ground_truth_response(prompt: str):
|
| 120 |
+
"""Generate ground truth response given a prompt."""
|
| 121 |
+
num, info = prompt.split(":")
|
| 122 |
+
num1, num2 = num.split(",")
|
| 123 |
+
max_number, num_to_gen = info.split(",")
|
| 124 |
+
num1 = int(num1)
|
| 125 |
+
num2 = int(num2)
|
| 126 |
+
max_number = int(max_number)
|
| 127 |
+
num_to_gen = int(num_to_gen)
|
| 128 |
+
diff = (num2 - num1) % max_number
|
| 129 |
+
results = []
|
| 130 |
+
last_num = num2
|
| 131 |
+
for _ in range(num_to_gen):
|
| 132 |
+
curr = (last_num + diff) % max_number
|
| 133 |
+
results.append(str(curr))
|
| 134 |
+
last_num = curr
|
| 135 |
+
response = ",".join(results)
|
| 136 |
+
return response
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
def compute_reward(prompt: str, response: str, sequence_reward=1.0):
|
| 140 |
+
"""We compute dense reward here so that we can directly train RL without SFT"""
|
| 141 |
+
response_length = len(response)
|
| 142 |
+
ground_truth_response = generate_ground_truth_response(prompt)
|
| 143 |
+
per_token_reward = sequence_reward / (len(ground_truth_response) + 1) # including [EOS]
|
| 144 |
+
|
| 145 |
+
# pad
|
| 146 |
+
reward = np.zeros(response_length, dtype=np.float32) # this assumes that each char is a token
|
| 147 |
+
# assign reward until mismatches
|
| 148 |
+
ground_truth_idx = 0
|
| 149 |
+
for i in range(response_length):
|
| 150 |
+
if ground_truth_idx == len(ground_truth_response):
|
| 151 |
+
break
|
| 152 |
+
|
| 153 |
+
ground_truth_response_token = ground_truth_response[ground_truth_idx]
|
| 154 |
+
response_token = response[i]
|
| 155 |
+
if ground_truth_response_token == response_token:
|
| 156 |
+
reward[i] = per_token_reward
|
| 157 |
+
ground_truth_idx += 1
|
| 158 |
+
else:
|
| 159 |
+
# no matches
|
| 160 |
+
break
|
| 161 |
+
|
| 162 |
+
return reward, {"ground_truth_response": ground_truth_response}
|
| 163 |
+
|
| 164 |
+
|
| 165 |
+
if __name__ == "__main__":
|
| 166 |
+
task = DigitCompletion(max_number=20, max_diff=3, max_num_in_response=5)
|
| 167 |
+
print(task.sample_str_prompts())
|
| 168 |
+
|
| 169 |
+
prompt = "7,8:20,0"
|
| 170 |
+
response = ""
|
| 171 |
+
print(compute_reward(prompt, response))
|
| 172 |
+
|
| 173 |
+
prompt = "7,8:20,0"
|
| 174 |
+
response = "E000"
|
| 175 |
+
print(compute_reward(prompt, response))
|
| 176 |
+
|
| 177 |
+
prompt = "9,10:20,2"
|
| 178 |
+
response = "11,12,13"
|
| 179 |
+
print(compute_reward(prompt, response))
|
code/RL_model/verl/verl_train/tests/special_e2e/envs/digit_completion/tokenizer.py
ADDED
|
@@ -0,0 +1,155 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
"""Copied from https://github.com/dariush-bahrami/character-tokenizer/blob/master/charactertokenizer/core.py
|
| 15 |
+
|
| 16 |
+
CharacterTokenzier for Hugging Face Transformers.
|
| 17 |
+
|
| 18 |
+
This is heavily inspired from CanineTokenizer in transformers package.
|
| 19 |
+
"""
|
| 20 |
+
|
| 21 |
+
import json
|
| 22 |
+
import os
|
| 23 |
+
from pathlib import Path
|
| 24 |
+
from typing import Optional, Sequence
|
| 25 |
+
|
| 26 |
+
from transformers.tokenization_utils import AddedToken, PreTrainedTokenizer
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
class CharTokenizer(PreTrainedTokenizer):
|
| 30 |
+
def __init__(self, characters: Sequence[str], model_max_length: int, chat_template, **kwargs):
|
| 31 |
+
"""Character tokenizer for Hugging Face transformers.
|
| 32 |
+
|
| 33 |
+
Args:
|
| 34 |
+
characters (Sequence[str]): List of desired characters. Any character which
|
| 35 |
+
is not included in this list will be replaced by a special token called
|
| 36 |
+
[UNK] with id=6. Following are list of all of the special tokens with
|
| 37 |
+
their corresponding ids:
|
| 38 |
+
"[CLS]": 0
|
| 39 |
+
"[SEP]": 1
|
| 40 |
+
"[BOS]": 2
|
| 41 |
+
"[MASK]": 3
|
| 42 |
+
"[PAD]": 4
|
| 43 |
+
"[RESERVED]": 5
|
| 44 |
+
"[UNK]": 6
|
| 45 |
+
an id (starting at 7) will be assigned to each character.
|
| 46 |
+
|
| 47 |
+
model_max_length (int): Model maximum sequence length.
|
| 48 |
+
"""
|
| 49 |
+
eos_token_str = "E"
|
| 50 |
+
sep_token_str = "S"
|
| 51 |
+
pad_token_str = "P"
|
| 52 |
+
unk_token_str = "U"
|
| 53 |
+
|
| 54 |
+
self.characters = characters
|
| 55 |
+
self.model_max_length = model_max_length
|
| 56 |
+
eos_token = AddedToken(eos_token_str, lstrip=False, rstrip=False)
|
| 57 |
+
sep_token = AddedToken(sep_token_str, lstrip=False, rstrip=False)
|
| 58 |
+
pad_token = AddedToken(pad_token_str, lstrip=False, rstrip=False)
|
| 59 |
+
unk_token = AddedToken(unk_token_str, lstrip=False, rstrip=False)
|
| 60 |
+
|
| 61 |
+
self._vocab_str_to_int = {
|
| 62 |
+
sep_token_str: 0,
|
| 63 |
+
eos_token_str: 1,
|
| 64 |
+
pad_token_str: 2,
|
| 65 |
+
unk_token_str: 3,
|
| 66 |
+
**{ch: i + 4 for i, ch in enumerate(characters)},
|
| 67 |
+
}
|
| 68 |
+
self._vocab_int_to_str = {v: k for k, v in self._vocab_str_to_int.items()}
|
| 69 |
+
|
| 70 |
+
super().__init__(
|
| 71 |
+
eos_token=eos_token,
|
| 72 |
+
sep_token=sep_token,
|
| 73 |
+
pad_token=pad_token,
|
| 74 |
+
unk_token=unk_token,
|
| 75 |
+
add_prefix_space=False,
|
| 76 |
+
model_max_length=model_max_length,
|
| 77 |
+
**kwargs,
|
| 78 |
+
)
|
| 79 |
+
|
| 80 |
+
self.chat_template = chat_template
|
| 81 |
+
|
| 82 |
+
@property
|
| 83 |
+
def vocab_size(self) -> int:
|
| 84 |
+
return len(self._vocab_str_to_int)
|
| 85 |
+
|
| 86 |
+
def get_vocab(self):
|
| 87 |
+
return self._vocab_str_to_int
|
| 88 |
+
|
| 89 |
+
def _tokenize(self, text: str) -> list[str]:
|
| 90 |
+
return list(text)
|
| 91 |
+
|
| 92 |
+
def _convert_token_to_id(self, token: str) -> int:
|
| 93 |
+
return self._vocab_str_to_int.get(token, self._vocab_str_to_int["U"])
|
| 94 |
+
|
| 95 |
+
def _convert_id_to_token(self, index: int) -> str:
|
| 96 |
+
return self._vocab_int_to_str[index]
|
| 97 |
+
|
| 98 |
+
def convert_tokens_to_string(self, tokens):
|
| 99 |
+
return "".join(tokens)
|
| 100 |
+
|
| 101 |
+
def build_inputs_with_special_tokens(
|
| 102 |
+
self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None
|
| 103 |
+
) -> list[int]:
|
| 104 |
+
sep = [self.sep_token_id]
|
| 105 |
+
cls = [self.cls_token_id]
|
| 106 |
+
result = cls + token_ids_0 + sep
|
| 107 |
+
if token_ids_1 is not None:
|
| 108 |
+
result += token_ids_1 + sep
|
| 109 |
+
return result
|
| 110 |
+
|
| 111 |
+
def get_special_tokens_mask(
|
| 112 |
+
self,
|
| 113 |
+
token_ids_0: list[int],
|
| 114 |
+
token_ids_1: Optional[list[int]] = None,
|
| 115 |
+
already_has_special_tokens: bool = False,
|
| 116 |
+
) -> list[int]:
|
| 117 |
+
if already_has_special_tokens:
|
| 118 |
+
return super().get_special_tokens_mask(
|
| 119 |
+
token_ids_0=token_ids_0,
|
| 120 |
+
token_ids_1=token_ids_1,
|
| 121 |
+
already_has_special_tokens=True,
|
| 122 |
+
)
|
| 123 |
+
|
| 124 |
+
result = [1] + ([0] * len(token_ids_0)) + [1]
|
| 125 |
+
if token_ids_1 is not None:
|
| 126 |
+
result += ([0] * len(token_ids_1)) + [1]
|
| 127 |
+
return result
|
| 128 |
+
|
| 129 |
+
def get_config(self) -> dict:
|
| 130 |
+
return {
|
| 131 |
+
"char_ords": [ord(ch) for ch in self.characters],
|
| 132 |
+
"model_max_length": self.model_max_length,
|
| 133 |
+
"chat_template": self.chat_template,
|
| 134 |
+
}
|
| 135 |
+
|
| 136 |
+
@classmethod
|
| 137 |
+
def from_config(cls, config: dict):
|
| 138 |
+
cfg = {}
|
| 139 |
+
cfg["characters"] = [chr(i) for i in config["char_ords"]]
|
| 140 |
+
cfg["model_max_length"] = config["model_max_length"]
|
| 141 |
+
cfg["chat_template"] = config["chat_template"]
|
| 142 |
+
return cls(**cfg)
|
| 143 |
+
|
| 144 |
+
def save_pretrained(self, save_directory: str | os.PathLike, **kwargs):
|
| 145 |
+
cfg_file = Path(save_directory) / "tokenizer_config.json"
|
| 146 |
+
cfg = self.get_config()
|
| 147 |
+
with open(cfg_file, "w") as f:
|
| 148 |
+
json.dump(cfg, f, indent=4)
|
| 149 |
+
|
| 150 |
+
@classmethod
|
| 151 |
+
def from_pretrained(cls, save_directory: str | os.PathLike, **kwargs):
|
| 152 |
+
cfg_file = Path(save_directory) / "tokenizer_config.json"
|
| 153 |
+
with open(cfg_file) as f:
|
| 154 |
+
cfg = json.load(f)
|
| 155 |
+
return cls.from_config(cfg)
|
code/RL_model/verl/verl_train/tests/special_e2e/generation/run_gen_qwen05.sh
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
|
| 2 |
+
# Tested with 1 & 4 GPUs
|
| 3 |
+
set -xeuo pipefail
|
| 4 |
+
|
| 5 |
+
MODEL_ID=${MODEL_ID:-Qwen/Qwen2.5-0.5B-Instruct}
|
| 6 |
+
|
| 7 |
+
NGPUS_PER_NODE=${NGPUS_PER_NODE:-4}
|
| 8 |
+
OUTPUT_PATH=${OUTPUT_PATH:-$HOME/data/gen/qwen_05_gen_test.parquet}
|
| 9 |
+
GEN_TP=${GEN_TP:-2} # Default tensor parallel size to 2
|
| 10 |
+
|
| 11 |
+
python3 -m verl.trainer.main_generation \
|
| 12 |
+
trainer.nnodes=1 \
|
| 13 |
+
trainer.n_gpus_per_node="${NGPUS_PER_NODE}" \
|
| 14 |
+
data.path="${HOME}/data/gsm8k/test.parquet" \
|
| 15 |
+
data.prompt_key=prompt \
|
| 16 |
+
data.n_samples=1 \
|
| 17 |
+
data.output_path="${OUTPUT_PATH}" \
|
| 18 |
+
model.path="${MODEL_ID}" \
|
| 19 |
+
+model.trust_remote_code=True \
|
| 20 |
+
rollout.temperature=1.0 \
|
| 21 |
+
rollout.top_k=50 \
|
| 22 |
+
rollout.top_p=0.7 \
|
| 23 |
+
rollout.prompt_length=2048 \
|
| 24 |
+
rollout.response_length=1024 \
|
| 25 |
+
rollout.tensor_model_parallel_size="${GEN_TP}" \
|
| 26 |
+
rollout.gpu_memory_utilization=0.8
|
code/RL_model/verl/verl_train/tests/special_e2e/generation/run_gen_qwen05_server.sh
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
|
| 2 |
+
# Tested with 1 & 4 GPUs
|
| 3 |
+
set -xeuo pipefail
|
| 4 |
+
|
| 5 |
+
MODEL_ID=${MODEL_ID:-$HOME/models/Qwen/Qwen2.5-0.5B-Instruct}
|
| 6 |
+
NGPUS_PER_NODE=${NGPUS_PER_NODE:-8}
|
| 7 |
+
OUTPUT_PATH=${OUTPUT_PATH:-$HOME/data/gen/qwen_05_gen_test.parquet}
|
| 8 |
+
GEN_TP=${GEN_TP:-2} # Default tensor parallel size to 2
|
| 9 |
+
|
| 10 |
+
python3 -m verl.trainer.main_generation_server \
|
| 11 |
+
trainer.nnodes=1 \
|
| 12 |
+
trainer.n_gpus_per_node="${NGPUS_PER_NODE}" \
|
| 13 |
+
actor_rollout_ref.model.path="${MODEL_ID}" \
|
| 14 |
+
actor_rollout_ref.model.trust_remote_code=True \
|
| 15 |
+
actor_rollout_ref.rollout.temperature=1.0 \
|
| 16 |
+
actor_rollout_ref.rollout.top_k=50 \
|
| 17 |
+
actor_rollout_ref.rollout.top_p=0.7 \
|
| 18 |
+
actor_rollout_ref.rollout.prompt_length=2048 \
|
| 19 |
+
actor_rollout_ref.rollout.response_length=1024 \
|
| 20 |
+
actor_rollout_ref.rollout.tensor_model_parallel_size="${GEN_TP}" \
|
| 21 |
+
actor_rollout_ref.rollout.gpu_memory_utilization=0.9 \
|
| 22 |
+
actor_rollout_ref.rollout.name=vllm \
|
| 23 |
+
actor_rollout_ref.rollout.n=4 \
|
| 24 |
+
data.train_files="${HOME}/data/gsm8k/test.parquet" \
|
| 25 |
+
data.prompt_key=prompt \
|
| 26 |
+
+data.output_path="${OUTPUT_PATH}" \
|
code/RL_model/verl/verl_train/tests/special_e2e/ppo_trainer/expert_parallel/qwen2moe_minimal.json
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"num_hidden_layers": 2,
|
| 3 |
+
"max_window_layers": 2
|
| 4 |
+
}
|
code/RL_model/verl/verl_train/tests/special_e2e/ppo_trainer/expert_parallel/qwen3moe_minimal.json
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"num_hidden_layers": 2,
|
| 3 |
+
"max_window_layers": 2
|
| 4 |
+
}
|
code/RL_model/verl/verl_train/tests/special_e2e/ppo_trainer/run_function_reward.sh
ADDED
|
@@ -0,0 +1,165 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
|
| 2 |
+
set -xeuo pipefail
|
| 3 |
+
|
| 4 |
+
NUM_GPUS=${NUM_GPUS:-8}
|
| 5 |
+
|
| 6 |
+
MODEL_ID=${MODEL_ID:-Qwen/Qwen2.5-0.5B}
|
| 7 |
+
MODEL_PATH=${MODEL_PATH:-${HOME}/models/${MODEL_ID}}
|
| 8 |
+
#hf download "${MODEL_ID}" --local-dir "${MODEL_PATH}"
|
| 9 |
+
|
| 10 |
+
TRAIN_FILES=${TRAIN_FILES:-$HOME/data/gsm8k/train.parquet}
|
| 11 |
+
VAL_FILES=${VAL_FILES:-$HOME/data/gsm8k/test.parquet}
|
| 12 |
+
MAX_PROMPT_LEN=${MAX_PROMPT_LEN:-512}
|
| 13 |
+
MAX_RESPONSE_LEN=${MAX_RESPONSE_LEN:-512}
|
| 14 |
+
|
| 15 |
+
ENGINE=${ENGINE:-vllm}
|
| 16 |
+
if [ "$ENGINE" = "vllm" ]; then
|
| 17 |
+
export VLLM_USE_V1=1
|
| 18 |
+
fi
|
| 19 |
+
ROLLOUT_MODE="async"
|
| 20 |
+
|
| 21 |
+
RETURN_RAW_CHAT="True"
|
| 22 |
+
SKIP_TOKENIZER_INIT="True"
|
| 23 |
+
|
| 24 |
+
GPU_MEMORY_UTILIZATION=${GPU_MEMORY_UTILIZATION:-0.7}
|
| 25 |
+
ACTOR_FSDP_PARAM_OFFLOAD=${ACTOR_FSDP_PARAM_OFFLOAD:-False}
|
| 26 |
+
ACTOR_FSDP_OPTIMIZER_OFFLOAD=${ACTOR_FSDP_OPTIMIZER_OFFLOAD:-False}
|
| 27 |
+
REF_FSDP_PARAM_OFFLOAD=${REF_FSDP_PARAM_OFFLOAD:-True}
|
| 28 |
+
RM_PAD=${RM_PAD:-True}
|
| 29 |
+
FUSED_KERNELS=${FUSED_KERNELS:-False}
|
| 30 |
+
FUSED_KERNEL_BACKEND=${FUSED_KERNEL_BACKEND:-torch} # or 'triton' for triton backend
|
| 31 |
+
ADV_ESTIMATOR=${ADV_ESTIMATOR:-gae}
|
| 32 |
+
LOSS_MODE=${LOSS_MODE:-vanilla}
|
| 33 |
+
USE_KL=${USE_KL:-False}
|
| 34 |
+
CUSTOM_REWARD_FN=${CUSTOM_REWARD_FN:-False}
|
| 35 |
+
ENABLE_CHUNKED_PREFILL=${ENABLE_CHUNKED_PREFILL:-True} # For vLLM VLM placeholder issue: https://github.com/vllm-project/vllm/issues/15185
|
| 36 |
+
STRATEGY=${STRATEGY:-fsdp}
|
| 37 |
+
# LoRA config
|
| 38 |
+
LORA_RANK=${LORA_RANK:-0}
|
| 39 |
+
LORA_ALPHA=${LORA_ALPHA:-${LORA_RANK}}
|
| 40 |
+
LORA_TARGET=${LORA_TARGET:-"all-linear"}
|
| 41 |
+
LORA_EXCLUDE=${LORA_EXCLUDE:-"DONT_EXCLUDE"}
|
| 42 |
+
USE_SHM=${USE_SHM:-False}
|
| 43 |
+
LOAD_FORMAT=${LOAD_FORMAT:-dummy}
|
| 44 |
+
LAYERED_SUMMON=${LAYERED_SUMMON:-False}
|
| 45 |
+
# Validation
|
| 46 |
+
VAL_BEFORE_TRAIN=${VAL_BEFORE_TRAIN:-False}
|
| 47 |
+
TEST_FREQ=${TEST_FREQ:--1}
|
| 48 |
+
# Save & Resume
|
| 49 |
+
RESUME_MODE=${RESUME_MODE:-disable}
|
| 50 |
+
SAVE_FREQ=${SAVE_FREQ:--1}
|
| 51 |
+
TOTAL_TRAIN_STEPS=${TOTAL_TRAIN_STEPS:-1}
|
| 52 |
+
|
| 53 |
+
# whether to save hf_model
|
| 54 |
+
SAVE_HF_MODEL=${SAVE_HF_MODEL:-False}
|
| 55 |
+
FSDP_SIZE=${FSDP_SIZE:--1}
|
| 56 |
+
SP_SIZE=${SP_SIZE:-1}
|
| 57 |
+
|
| 58 |
+
if [ "${SAVE_HF_MODEL}" = "True" ]; then
|
| 59 |
+
CHECKPOINT_CONTENTS="['model','hf_model','optimizer','extra']"
|
| 60 |
+
else
|
| 61 |
+
CHECKPOINT_CONTENTS="['model','optimizer','extra']"
|
| 62 |
+
fi
|
| 63 |
+
|
| 64 |
+
train_traj_micro_bsz_per_gpu=2 # b
|
| 65 |
+
n_resp_per_prompt=4 # g
|
| 66 |
+
|
| 67 |
+
train_traj_micro_bsz=$((train_traj_micro_bsz_per_gpu * NUM_GPUS)) # b * n
|
| 68 |
+
train_traj_mini_bsz=$((train_traj_micro_bsz * 2)) # 2 * b * n
|
| 69 |
+
train_prompt_mini_bsz=$((train_traj_mini_bsz * n_resp_per_prompt)) # 2 * b * n / g
|
| 70 |
+
train_prompt_bsz=$((train_prompt_mini_bsz * 2)) # 4 * b * n / g
|
| 71 |
+
|
| 72 |
+
reward_fn_name=null
|
| 73 |
+
reward_fn_file_path=null
|
| 74 |
+
output_file="$(pwd)/output.txt"
|
| 75 |
+
if [ "${CUSTOM_REWARD_FN}" = "True" ]; then
|
| 76 |
+
reward_fn_name="my_reward_function"
|
| 77 |
+
reward_fn_file_path="$(pwd)/my_reward_function.py"
|
| 78 |
+
rm -rf "${reward_fn_file_path}"
|
| 79 |
+
cat <<EOF > "$reward_fn_file_path"
|
| 80 |
+
def ${reward_fn_name}(data_source, solution_str, ground_truth, extra_info=None):
|
| 81 |
+
print(f"Congratulations!!! You have called ${reward_fn_name} successfully!!!")
|
| 82 |
+
return 0.1
|
| 83 |
+
EOF
|
| 84 |
+
|
| 85 |
+
rm -rf "${output_file}"
|
| 86 |
+
fi
|
| 87 |
+
|
| 88 |
+
exp_name="${VERL_EXP_NAME:-$(basename "${MODEL_ID,,}")-function-reward-minimal}"
|
| 89 |
+
|
| 90 |
+
python3 -m verl.trainer.main_ppo \
|
| 91 |
+
algorithm.adv_estimator="${ADV_ESTIMATOR}" \
|
| 92 |
+
data.train_files="${TRAIN_FILES}" \
|
| 93 |
+
data.val_files="${VAL_FILES}" \
|
| 94 |
+
data.train_batch_size="${train_prompt_bsz}" \
|
| 95 |
+
data.max_prompt_length="${MAX_PROMPT_LEN}" \
|
| 96 |
+
data.max_response_length="${MAX_RESPONSE_LEN}" \
|
| 97 |
+
data.return_raw_chat=${RETURN_RAW_CHAT} \
|
| 98 |
+
actor_rollout_ref.model.path="${MODEL_PATH}" \
|
| 99 |
+
actor_rollout_ref.model.use_shm=${USE_SHM} \
|
| 100 |
+
actor_rollout_ref.model.lora_rank=${LORA_RANK} \
|
| 101 |
+
actor_rollout_ref.model.lora_alpha=${LORA_ALPHA} \
|
| 102 |
+
actor_rollout_ref.model.target_modules=${LORA_TARGET} \
|
| 103 |
+
actor_rollout_ref.model.exclude_modules=${LORA_EXCLUDE} \
|
| 104 |
+
actor_rollout_ref.actor.optim.lr=1e-6 \
|
| 105 |
+
actor_rollout_ref.model.use_remove_padding="${RM_PAD}" \
|
| 106 |
+
actor_rollout_ref.model.use_fused_kernels=${FUSED_KERNELS} \
|
| 107 |
+
actor_rollout_ref.model.fused_kernel_options.impl_backend=${FUSED_KERNEL_BACKEND} \
|
| 108 |
+
actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \
|
| 109 |
+
actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=${train_traj_micro_bsz_per_gpu} \
|
| 110 |
+
actor_rollout_ref.actor.strategy=${STRATEGY} \
|
| 111 |
+
actor_rollout_ref.actor.fsdp_config.param_offload=${ACTOR_FSDP_PARAM_OFFLOAD} \
|
| 112 |
+
actor_rollout_ref.actor.fsdp_config.optimizer_offload=${ACTOR_FSDP_OPTIMIZER_OFFLOAD} \
|
| 113 |
+
actor_rollout_ref.actor.fsdp_config.fsdp_size=${FSDP_SIZE} \
|
| 114 |
+
actor_rollout_ref.actor.ulysses_sequence_parallel_size="${SP_SIZE}" \
|
| 115 |
+
actor_rollout_ref.actor.checkpoint.save_contents=${CHECKPOINT_CONTENTS} \
|
| 116 |
+
actor_rollout_ref.actor.use_kl_loss="${USE_KL}" \
|
| 117 |
+
actor_rollout_ref.actor.policy_loss.loss_mode="${LOSS_MODE}" \
|
| 118 |
+
actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=${train_traj_micro_bsz_per_gpu} \
|
| 119 |
+
actor_rollout_ref.rollout.tensor_model_parallel_size=2 \
|
| 120 |
+
actor_rollout_ref.rollout.name="${ENGINE}" \
|
| 121 |
+
actor_rollout_ref.rollout.mode="${ROLLOUT_MODE}" \
|
| 122 |
+
actor_rollout_ref.rollout.load_format=${LOAD_FORMAT} \
|
| 123 |
+
actor_rollout_ref.rollout.layered_summon=${LAYERED_SUMMON} \
|
| 124 |
+
actor_rollout_ref.rollout.skip_tokenizer_init="${SKIP_TOKENIZER_INIT}" \
|
| 125 |
+
actor_rollout_ref.rollout.gpu_memory_utilization="${GPU_MEMORY_UTILIZATION}" \
|
| 126 |
+
actor_rollout_ref.rollout.enable_chunked_prefill="${ENABLE_CHUNKED_PREFILL}" \
|
| 127 |
+
actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=${train_traj_micro_bsz_per_gpu} \
|
| 128 |
+
actor_rollout_ref.ref.fsdp_config.param_offload="${REF_FSDP_PARAM_OFFLOAD}" \
|
| 129 |
+
critic.optim.lr=1e-5 \
|
| 130 |
+
critic.model.use_remove_padding="${RM_PAD}" \
|
| 131 |
+
critic.model.path="${MODEL_PATH}" \
|
| 132 |
+
critic.model.enable_gradient_checkpointing=False \
|
| 133 |
+
critic.ppo_micro_batch_size_per_gpu=${train_traj_micro_bsz_per_gpu} \
|
| 134 |
+
critic.model.fsdp_config.param_offload=False \
|
| 135 |
+
critic.model.fsdp_config.optimizer_offload=False \
|
| 136 |
+
custom_reward_function.path="${reward_fn_file_path}"\
|
| 137 |
+
custom_reward_function.name="${reward_fn_name}"\
|
| 138 |
+
algorithm.use_kl_in_reward="${USE_KL}" \
|
| 139 |
+
algorithm.kl_penalty=kl \
|
| 140 |
+
algorithm.kl_ctrl.kl_coef=0.001 \
|
| 141 |
+
trainer.critic_warmup=0 \
|
| 142 |
+
trainer.logger=console \
|
| 143 |
+
trainer.project_name='verl-test' \
|
| 144 |
+
trainer.experiment_name="${exp_name}" \
|
| 145 |
+
trainer.nnodes=1 \
|
| 146 |
+
trainer.n_gpus_per_node="${NUM_GPUS}" \
|
| 147 |
+
trainer.val_before_train="${VAL_BEFORE_TRAIN}" \
|
| 148 |
+
trainer.test_freq="${TEST_FREQ}" \
|
| 149 |
+
trainer.save_freq="${SAVE_FREQ}" \
|
| 150 |
+
trainer.resume_mode="${RESUME_MODE}" \
|
| 151 |
+
trainer.total_epochs=2 \
|
| 152 |
+
trainer.device=cuda \
|
| 153 |
+
trainer.total_training_steps="${TOTAL_TRAIN_STEPS}" $@ \
|
| 154 |
+
| tee "${output_file}"
|
| 155 |
+
|
| 156 |
+
if [ "${CUSTOM_REWARD_FN}" = "True" ]; then
|
| 157 |
+
python3 tests/special_e2e/check_custom_rwd_fn.py --output_file="${output_file}"
|
| 158 |
+
check_exit_code=$?
|
| 159 |
+
rm -rf "${reward_fn_file_path}"
|
| 160 |
+
rm -rf "${output_file}"
|
| 161 |
+
# Return the exit code of check_custom_rwd_fn.py if it fails
|
| 162 |
+
if [ $check_exit_code -ne 0 ]; then
|
| 163 |
+
exit $check_exit_code
|
| 164 |
+
fi
|
| 165 |
+
fi
|
code/RL_model/verl/verl_train/tests/special_e2e/ppo_trainer/run_model_reward.sh
ADDED
|
@@ -0,0 +1,101 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
|
| 2 |
+
set -xeuo pipefail
|
| 3 |
+
|
| 4 |
+
NUM_GPUS=${NUM_GPUS:-8}
|
| 5 |
+
|
| 6 |
+
MODEL_ID=${MODEL_ID:-Qwen/Qwen2.5-0.5B}
|
| 7 |
+
MODEL_PATH=${MODEL_PATH:-${HOME}/models/${MODEL_ID}}
|
| 8 |
+
#hf download "${MODEL_ID}" --local-dir "${MODEL_PATH}"
|
| 9 |
+
|
| 10 |
+
TRAIN_FILES=${TRAIN_FILES:-$HOME/data/gsm8k/train.parquet}
|
| 11 |
+
VAL_FILES=${VAL_FILES:-$HOME/data/gsm8k/test.parquet}
|
| 12 |
+
|
| 13 |
+
RM_PAD=${RM_PAD:-True}
|
| 14 |
+
FUSED_KERNELS=${FUSED_KERNELS:-False}
|
| 15 |
+
FUSED_KERNEL_BACKEND=${FUSED_KERNEL_BACKEND:-torch} # or 'triton' for triton backend
|
| 16 |
+
SP_SIZE=${SP_SIZE:-1}
|
| 17 |
+
SEQ_BALANCE=${SEQ_BALANCE:-False}
|
| 18 |
+
LIGER=${LIGER:-False}
|
| 19 |
+
# Validation
|
| 20 |
+
VAL_BEFORE_TRAIN=${VAL_BEFORE_TRAIN:-False}
|
| 21 |
+
TEST_FREQ=${TEST_FREQ:--1}
|
| 22 |
+
# Save & Resume
|
| 23 |
+
RESUME_MODE=${RESUME_MODE:-disable}
|
| 24 |
+
SAVE_FREQ=${SAVE_FREQ:--1}
|
| 25 |
+
TOTAL_TRAIN_STEPS=${TOTAL_TRAIN_STEPS:-1}
|
| 26 |
+
|
| 27 |
+
train_traj_micro_bsz_per_gpu=2 # b
|
| 28 |
+
n_resp_per_prompt=4 # g
|
| 29 |
+
|
| 30 |
+
train_traj_micro_bsz=$((train_traj_micro_bsz_per_gpu * NUM_GPUS)) # b * n
|
| 31 |
+
train_traj_mini_bsz=$((train_traj_micro_bsz * 2)) # 2 * b * n
|
| 32 |
+
train_prompt_mini_bsz=$((train_traj_mini_bsz * n_resp_per_prompt)) # 2 * b * n / g
|
| 33 |
+
train_prompt_bsz=$((train_prompt_mini_bsz * 2)) # 4 * b * n / g
|
| 34 |
+
|
| 35 |
+
train_max_token_num_per_gpu=32768
|
| 36 |
+
infer_max_token_num_per_gpu=32768
|
| 37 |
+
|
| 38 |
+
exp_name="$(basename "${MODEL_ID,,}")-model-reward-minimal"
|
| 39 |
+
|
| 40 |
+
python3 -m verl.trainer.main_ppo \
|
| 41 |
+
algorithm.adv_estimator=gae \
|
| 42 |
+
data.train_files="${TRAIN_FILES}" \
|
| 43 |
+
data.val_files="${VAL_FILES}" \
|
| 44 |
+
data.train_batch_size=${train_prompt_bsz} \
|
| 45 |
+
data.max_prompt_length=512 \
|
| 46 |
+
data.max_response_length=512 \
|
| 47 |
+
data.return_raw_chat=True \
|
| 48 |
+
actor_rollout_ref.model.path="${MODEL_PATH}" \
|
| 49 |
+
actor_rollout_ref.model.use_liger="${LIGER}" \
|
| 50 |
+
actor_rollout_ref.actor.optim.lr=1e-6 \
|
| 51 |
+
actor_rollout_ref.model.use_remove_padding="${RM_PAD}" \
|
| 52 |
+
actor_rollout_ref.model.use_fused_kernels=${FUSED_KERNELS} \
|
| 53 |
+
actor_rollout_ref.model.fused_kernel_options.impl_backend=${FUSED_KERNEL_BACKEND} \
|
| 54 |
+
actor_rollout_ref.actor.optim.lr_warmup_steps_ratio=0.1 \
|
| 55 |
+
actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \
|
| 56 |
+
actor_rollout_ref.actor.use_dynamic_bsz="${SEQ_BALANCE}" \
|
| 57 |
+
actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${train_max_token_num_per_gpu} \
|
| 58 |
+
actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=${train_traj_micro_bsz_per_gpu} \
|
| 59 |
+
actor_rollout_ref.actor.ulysses_sequence_parallel_size="${SP_SIZE}" \
|
| 60 |
+
actor_rollout_ref.actor.fsdp_config.param_offload=False \
|
| 61 |
+
actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \
|
| 62 |
+
actor_rollout_ref.actor.use_kl_loss=False \
|
| 63 |
+
actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_max_token_num_per_gpu} \
|
| 64 |
+
actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=${train_traj_micro_bsz_per_gpu} \
|
| 65 |
+
actor_rollout_ref.rollout.tensor_model_parallel_size=2 \
|
| 66 |
+
actor_rollout_ref.rollout.name=vllm \
|
| 67 |
+
actor_rollout_ref.rollout.gpu_memory_utilization=0.8 \
|
| 68 |
+
actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_max_token_num_per_gpu} \
|
| 69 |
+
actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=${train_traj_micro_bsz_per_gpu} \
|
| 70 |
+
critic.optim.lr=1e-5 \
|
| 71 |
+
critic.ulysses_sequence_parallel_size="${SP_SIZE}" \
|
| 72 |
+
critic.model.use_remove_padding="${RM_PAD}" \
|
| 73 |
+
critic.optim.lr_warmup_steps_ratio=0.05 \
|
| 74 |
+
critic.model.path="${MODEL_PATH}" \
|
| 75 |
+
critic.model.enable_gradient_checkpointing=False \
|
| 76 |
+
critic.use_dynamic_bsz="${SEQ_BALANCE}" \
|
| 77 |
+
critic.ppo_max_token_len_per_gpu=${train_max_token_num_per_gpu} \
|
| 78 |
+
critic.ppo_micro_batch_size_per_gpu=${train_traj_micro_bsz_per_gpu} \
|
| 79 |
+
critic.model.fsdp_config.param_offload=False \
|
| 80 |
+
critic.model.fsdp_config.optimizer_offload=False \
|
| 81 |
+
reward_model.enable=True \
|
| 82 |
+
reward_model.model.path="${MODEL_PATH}" \
|
| 83 |
+
reward_model.use_reward_loop=True \
|
| 84 |
+
reward_model.rollout.gpu_memory_utilization=0.8 \
|
| 85 |
+
reward_model.rollout.tensor_model_parallel_size=1 \
|
| 86 |
+
reward_model.rollout.prompt_length=1024 \
|
| 87 |
+
reward_model.rollout.response_length=512 \
|
| 88 |
+
reward_model.num_workers=8 \
|
| 89 |
+
algorithm.use_kl_in_reward=False \
|
| 90 |
+
trainer.critic_warmup=0 \
|
| 91 |
+
trainer.logger=console \
|
| 92 |
+
trainer.project_name='verl-test' \
|
| 93 |
+
trainer.experiment_name="${exp_name}" \
|
| 94 |
+
trainer.nnodes=1 \
|
| 95 |
+
trainer.n_gpus_per_node="${NUM_GPUS}" \
|
| 96 |
+
trainer.val_before_train="${VAL_BEFORE_TRAIN}" \
|
| 97 |
+
trainer.test_freq="${VAL_BEFORE_TRAIN}" \
|
| 98 |
+
trainer.save_freq="${SAVE_FREQ}" \
|
| 99 |
+
trainer.resume_mode="${RESUME_MODE}" \
|
| 100 |
+
trainer.total_epochs=2 \
|
| 101 |
+
trainer.total_training_steps="${TOTAL_TRAIN_STEPS}" $@
|
code/RL_model/verl/verl_train/tests/special_e2e/ppo_trainer/run_single_gpu.sh
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
PYTHONUNBUFFERED=1 python3 -m verl.trainer.main_ppo \
|
| 2 |
+
data.train_files=$HOME/data/gsm8k/train.parquet \
|
| 3 |
+
data.val_files=$HOME/data/gsm8k/test.parquet \
|
| 4 |
+
data.train_batch_size=256 \
|
| 5 |
+
data.max_prompt_length=512 \
|
| 6 |
+
data.max_response_length=256 \
|
| 7 |
+
actor_rollout_ref.model.path=Qwen/Qwen2.5-0.5B-Instruct \
|
| 8 |
+
actor_rollout_ref.actor.optim.lr=1e-6 \
|
| 9 |
+
actor_rollout_ref.actor.ppo_mini_batch_size=64 \
|
| 10 |
+
actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=4 \
|
| 11 |
+
actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=8 \
|
| 12 |
+
actor_rollout_ref.rollout.tensor_model_parallel_size=1 \
|
| 13 |
+
actor_rollout_ref.rollout.gpu_memory_utilization=0.4 \
|
| 14 |
+
actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=4 \
|
| 15 |
+
critic.optim.lr=1e-5 \
|
| 16 |
+
critic.model.path=Qwen/Qwen2.5-0.5B-Instruct \
|
| 17 |
+
critic.ppo_micro_batch_size_per_gpu=4 \
|
| 18 |
+
algorithm.kl_ctrl.kl_coef=0.001 \
|
| 19 |
+
trainer.logger=console \
|
| 20 |
+
trainer.val_before_train=False \
|
| 21 |
+
trainer.n_gpus_per_node=1 \
|
| 22 |
+
trainer.nnodes=1 \
|
| 23 |
+
actor_rollout_ref.rollout.name=hf \
|
| 24 |
+
trainer.total_training_steps=2
|
code/RL_model/verl/verl_train/tests/special_e2e/ppo_trainer/run_single_gpu_with_engine.sh
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
PYTHONUNBUFFERED=1 python3 -m verl.trainer.main_ppo \
|
| 2 |
+
data.train_files=$HOME/data/gsm8k/train.parquet \
|
| 3 |
+
data.val_files=$HOME/data/gsm8k/test.parquet \
|
| 4 |
+
data.train_batch_size=256 \
|
| 5 |
+
data.max_prompt_length=512 \
|
| 6 |
+
data.max_response_length=256 \
|
| 7 |
+
actor_rollout_ref.model.path=Qwen/Qwen2.5-0.5B-Instruct \
|
| 8 |
+
actor_rollout_ref.actor.optim.lr=1e-6 \
|
| 9 |
+
actor_rollout_ref.actor.ppo_mini_batch_size=64 \
|
| 10 |
+
actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=4 \
|
| 11 |
+
actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=8 \
|
| 12 |
+
actor_rollout_ref.rollout.tensor_model_parallel_size=1 \
|
| 13 |
+
actor_rollout_ref.rollout.gpu_memory_utilization=0.4 \
|
| 14 |
+
actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=4 \
|
| 15 |
+
critic.optim.lr=1e-5 \
|
| 16 |
+
critic.model.path=Qwen/Qwen2.5-0.5B-Instruct \
|
| 17 |
+
critic.ppo_micro_batch_size_per_gpu=4 \
|
| 18 |
+
algorithm.kl_ctrl.kl_coef=0.001 \
|
| 19 |
+
trainer.logger=['console'] \
|
| 20 |
+
trainer.val_before_train=False \
|
| 21 |
+
trainer.n_gpus_per_node=1 \
|
| 22 |
+
trainer.nnodes=1 \
|
| 23 |
+
actor_rollout_ref.rollout.name=hf \
|
| 24 |
+
trainer.use_legacy_worker_impl=disable \
|
| 25 |
+
trainer.total_training_steps=2
|
code/RL_model/verl/verl_train/tests/special_e2e/sft/compare_sft_engine_results.py
ADDED
|
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2025 Bytedance Ltd. and/or its affiliates
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import json
|
| 16 |
+
import os
|
| 17 |
+
|
| 18 |
+
import torch
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def get_result(file):
|
| 22 |
+
file = os.path.expanduser(file)
|
| 23 |
+
result = []
|
| 24 |
+
with open(file) as f:
|
| 25 |
+
lines = f.readlines()
|
| 26 |
+
for line in lines:
|
| 27 |
+
result.append(json.loads(line))
|
| 28 |
+
return result
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def compare_results(golden_results, other_result):
|
| 32 |
+
golden_loss = golden_results[0]["data"]["train/loss"]
|
| 33 |
+
golden_grad_norm = golden_results[0]["data"]["train/grad_norm"]
|
| 34 |
+
|
| 35 |
+
loss = other_result[0]["data"]["train/loss"]
|
| 36 |
+
grad_norm = other_result[0]["data"]["train/grad_norm"]
|
| 37 |
+
|
| 38 |
+
torch.testing.assert_close(golden_loss, loss, atol=1e-2, rtol=1e-2)
|
| 39 |
+
torch.testing.assert_close(golden_grad_norm, grad_norm, atol=1e-4, rtol=3e-2)
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
if __name__ == "__main__":
|
| 43 |
+
golden_results = get_result("~/verl/test/log/golden.jsonl")
|
| 44 |
+
|
| 45 |
+
# get all other results
|
| 46 |
+
other_results = {}
|
| 47 |
+
# walk through all files in ~/verl/test/log
|
| 48 |
+
for file in os.listdir(os.path.expanduser("~/verl/test/log/verl_sft_test")):
|
| 49 |
+
if file.endswith(".jsonl"):
|
| 50 |
+
other_results[file] = get_result(os.path.join(os.path.expanduser("~/verl/test/log/verl_sft_test"), file))
|
| 51 |
+
|
| 52 |
+
# # compare results
|
| 53 |
+
for file, other_result in other_results.items():
|
| 54 |
+
print(f"compare results {file}")
|
| 55 |
+
compare_results(golden_results, other_result)
|
| 56 |
+
print(f"compare results {file} done")
|
| 57 |
+
|
| 58 |
+
print("All results are close to golden results")
|
code/RL_model/verl/verl_train/tests/special_e2e/sft/run_sft.sh
ADDED
|
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
|
| 2 |
+
set -xeuo pipefail
|
| 3 |
+
|
| 4 |
+
ENTRYPOINT=${ENTRYPOINT:-"-m verl.trainer.fsdp_sft_trainer"}
|
| 5 |
+
|
| 6 |
+
NUM_GPUS=${NUM_GPUS:-8}
|
| 7 |
+
|
| 8 |
+
MODEL_ID=${MODEL_ID:-Qwen/Qwen2.5-0.5B-Instruct}
|
| 9 |
+
MODEL_PATH=${MODEL_PATH:-${HOME}/models/${MODEL_ID}}
|
| 10 |
+
#hf download "${MODEL_ID}" --local-dir "${MODEL_PATH}"
|
| 11 |
+
|
| 12 |
+
TRAIN_FILES=${TRAIN_FILES:-$HOME/data/gsm8k/train.parquet}
|
| 13 |
+
VAL_FILES=${VAL_FILES:-$HOME/data/gsm8k/test.parquet}
|
| 14 |
+
|
| 15 |
+
SP_SIZE=${SP_SIZE:-1}
|
| 16 |
+
LIGER=${LIGER:-False}
|
| 17 |
+
MULTITURN=${MULTITURN:-False}
|
| 18 |
+
LORA_RANK=${LORA_RANK:-0}
|
| 19 |
+
RM_PAD=${RM_PAD:-True}
|
| 20 |
+
|
| 21 |
+
TOTAL_TRAIN_STEP=${TOTAL_TRAIN_STEP:-1}
|
| 22 |
+
RESUME_MODE=${RESUME_MODE:-disable}
|
| 23 |
+
SAVE_FREQ=${SAVE_FREQ:-1}
|
| 24 |
+
|
| 25 |
+
micro_bsz=2
|
| 26 |
+
NUM_GPUS=8
|
| 27 |
+
|
| 28 |
+
project_name="verl-test"
|
| 29 |
+
exp_name="$(basename "${MODEL_ID,,}")-sft-minimal"
|
| 30 |
+
ckpts_home=${ckpts_home:-$HOME/${project_name}/${exp_name}}
|
| 31 |
+
|
| 32 |
+
mkdir -p "${ckpts_home}"
|
| 33 |
+
|
| 34 |
+
torchrun --standalone --nnodes=1 --nproc_per_node=${NUM_GPUS} ${ENTRYPOINT} \
|
| 35 |
+
data.train_files="${TRAIN_FILES}" \
|
| 36 |
+
data.val_files="${VAL_FILES}" \
|
| 37 |
+
data.prompt_key=extra_info \
|
| 38 |
+
data.response_key=extra_info \
|
| 39 |
+
data.prompt_dict_keys=['question'] \
|
| 40 |
+
data.response_dict_keys=['answer'] \
|
| 41 |
+
data.multiturn.enable="${MULTITURN}" \
|
| 42 |
+
data.multiturn.messages_key=messages \
|
| 43 |
+
optim.lr=1e-4 \
|
| 44 |
+
data.micro_batch_size_per_gpu=${micro_bsz} \
|
| 45 |
+
model.strategy=fsdp \
|
| 46 |
+
model.partial_pretrain="${MODEL_PATH}" \
|
| 47 |
+
model.lora_rank="${LORA_RANK}" \
|
| 48 |
+
model.lora_alpha=16 \
|
| 49 |
+
model.target_modules=all-linear \
|
| 50 |
+
model.use_liger="${LIGER}" \
|
| 51 |
+
ulysses_sequence_parallel_size="${SP_SIZE}" \
|
| 52 |
+
use_remove_padding="${RM_PAD}" \
|
| 53 |
+
trainer.default_local_dir="${ckpts_home}" \
|
| 54 |
+
trainer.project_name="${project_name}" \
|
| 55 |
+
trainer.experiment_name="${exp_name}" \
|
| 56 |
+
trainer.total_training_steps=${TOTAL_TRAIN_STEP} \
|
| 57 |
+
trainer.save_freq=${SAVE_FREQ} \
|
| 58 |
+
trainer.checkpoint.save_contents=[model,optimizer,extra,hf_model] \
|
| 59 |
+
trainer.max_ckpt_to_keep=1 \
|
| 60 |
+
trainer.resume_mode=${RESUME_MODE} \
|
| 61 |
+
trainer.logger=['console'] $@
|
| 62 |
+
|
| 63 |
+
rm -rf "${ckpts_home:?}/*"
|
code/RL_model/verl/verl_train/tests/special_e2e/sft/run_sft_engine.sh
ADDED
|
@@ -0,0 +1,134 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
|
| 2 |
+
set -xeuo pipefail
|
| 3 |
+
|
| 4 |
+
NUM_GPUS=${NUM_GPUS:-1}
|
| 5 |
+
|
| 6 |
+
mode=${mode:-spmd}
|
| 7 |
+
|
| 8 |
+
if [ "$mode" = "spmd" ]; then
|
| 9 |
+
ENTRYPOINT=${ENTRYPOINT:-"-m verl.trainer.sft_trainer"}
|
| 10 |
+
COMMAND="torchrun --standalone --nnodes=${NNODES:-1} --nproc-per-node=${NUM_GPUS:-1} ${ENTRYPOINT}"
|
| 11 |
+
else
|
| 12 |
+
ENTRYPOINT=${ENTRYPOINT:-"-m verl.trainer.sft_trainer_ray"}
|
| 13 |
+
COMMAND="python ${ENTRYPOINT} trainer.nnodes=${NNODES:-1} trainer.n_gpus_per_node=${NUM_GPUS:-1}"
|
| 14 |
+
fi
|
| 15 |
+
|
| 16 |
+
DATASET_DIR=${DATASET_DIR:-~/data/gsm8k_sft}
|
| 17 |
+
TRAIN_FILES=${DATASET_DIR}/train.parquet
|
| 18 |
+
VAL_FILES=${DATASET_DIR}/test.parquet
|
| 19 |
+
|
| 20 |
+
backend=${BACKEND:-fsdp}
|
| 21 |
+
|
| 22 |
+
project_name=verl_sft_test
|
| 23 |
+
|
| 24 |
+
RESUME_MODE=disable
|
| 25 |
+
|
| 26 |
+
ckpts_home=${ckpts_home:-~/verl/test/gsm8k-sft-${backend}}
|
| 27 |
+
|
| 28 |
+
MODEL_ID=${MODEL_ID:-Qwen/Qwen2.5-0.5B}
|
| 29 |
+
MODEL_PATH=${MODEL_PATH:-${HOME}/models/${MODEL_ID}}
|
| 30 |
+
#hf download "${MODEL_ID}" --local-dir "${MODEL_PATH}"
|
| 31 |
+
|
| 32 |
+
SP_SIZE=${SP_SIZE:-1}
|
| 33 |
+
FSDP_SIZE=${FSDP_SIZE:-${NUM_GPUS}}
|
| 34 |
+
FSDP_STRATEGY=${FSDP_STRATEGY:-"fsdp"}
|
| 35 |
+
|
| 36 |
+
TP_SIZE=${TP_SIZE:-1}
|
| 37 |
+
PP_SIZE=${PP_SIZE:-1}
|
| 38 |
+
VPP_SIZE=${VPP_SIZE:-null}
|
| 39 |
+
CP_SIZE=${CP_SIZE:-1}
|
| 40 |
+
|
| 41 |
+
PAD_MODE=${PAD_MODE:-no_padding}
|
| 42 |
+
|
| 43 |
+
USE_REMOVE_PADDING=${USE_REMOVE_PADDING:-True}
|
| 44 |
+
|
| 45 |
+
FSDP_ENGINE_CONFIG="\
|
| 46 |
+
engine=${backend} \
|
| 47 |
+
optim=${backend} \
|
| 48 |
+
optim.lr=1e-5 \
|
| 49 |
+
optim.lr_warmup_steps_ratio=0.2 \
|
| 50 |
+
optim.weight_decay=0.1 \
|
| 51 |
+
optim.betas="[0.9,0.95]" \
|
| 52 |
+
optim.clip_grad=1.0 \
|
| 53 |
+
optim.min_lr_ratio=0.1 \
|
| 54 |
+
optim.lr_scheduler_type=cosine \
|
| 55 |
+
engine.ulysses_sequence_parallel_size=${SP_SIZE} \
|
| 56 |
+
engine.strategy=${FSDP_STRATEGY} \
|
| 57 |
+
engine.fsdp_size=${FSDP_SIZE}"
|
| 58 |
+
|
| 59 |
+
VEOMNI_ENGINE_CONFIG="\
|
| 60 |
+
engine=${backend} \
|
| 61 |
+
optim=${backend} \
|
| 62 |
+
optim.lr=1e-5 \
|
| 63 |
+
optim.lr_warmup_steps_ratio=0.2 \
|
| 64 |
+
optim.weight_decay=0.1 \
|
| 65 |
+
optim.betas="[0.9,0.95]" \
|
| 66 |
+
optim.clip_grad=1.0 \
|
| 67 |
+
optim.lr_min=1e-6 \
|
| 68 |
+
optim.lr_scheduler_type=cosine \
|
| 69 |
+
engine.ulysses_parallel_size=${SP_SIZE} \
|
| 70 |
+
engine.data_parallel_mode=${FSDP_STRATEGY} \
|
| 71 |
+
engine.data_parallel_size=${FSDP_SIZE}"
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
MEGATRON_ENGINE_CONFIG="\
|
| 75 |
+
engine=${backend} \
|
| 76 |
+
optim=${backend} \
|
| 77 |
+
optim.lr=1e-5 \
|
| 78 |
+
optim.lr_warmup_steps_ratio=0.2 \
|
| 79 |
+
optim.weight_decay=0.1 \
|
| 80 |
+
optim.betas="[0.9,0.95]" \
|
| 81 |
+
optim.clip_grad=1.0 \
|
| 82 |
+
optim.lr_warmup_init=0 \
|
| 83 |
+
optim.lr_decay_style=cosine \
|
| 84 |
+
optim.min_lr=1e-6 \
|
| 85 |
+
engine.tensor_model_parallel_size=${TP_SIZE} \
|
| 86 |
+
engine.pipeline_model_parallel_size=${PP_SIZE} \
|
| 87 |
+
engine.virtual_pipeline_model_parallel_size=${VPP_SIZE} \
|
| 88 |
+
engine.context_parallel_size=${CP_SIZE} \
|
| 89 |
+
+engine.override_transformer_config.context_parallel_size=${CP_SIZE} \
|
| 90 |
+
engine.use_mbridge=True"
|
| 91 |
+
|
| 92 |
+
if [ "$backend" = "fsdp" ]; then
|
| 93 |
+
ENGINE_CONFIG="$FSDP_ENGINE_CONFIG"
|
| 94 |
+
echo "Using fsdp engine"
|
| 95 |
+
exp_name=gsm8k-${backend}-${FSDP_STRATEGY}-sp${SP_SIZE}-fsdp${FSDP_SIZE}-pad-${PAD_MODE}-use_remove_padding-${USE_REMOVE_PADDING}-mode-${mode}
|
| 96 |
+
elif [ "$backend" = "veomni" ]; then
|
| 97 |
+
ENGINE_CONFIG="$VEOMNI_ENGINE_CONFIG"
|
| 98 |
+
echo "Using veomni engine"
|
| 99 |
+
exp_name=gsm8k-${backend}-sp${SP_SIZE}-fsdp${FSDP_SIZE}-pad-${PAD_MODE}-use_remove_padding-${USE_REMOVE_PADDING}-mode-${mode}
|
| 100 |
+
else
|
| 101 |
+
ENGINE_CONFIG="$MEGATRON_ENGINE_CONFIG"
|
| 102 |
+
echo "Using megatron engine"
|
| 103 |
+
exp_name=gsm8k-${backend}-tp${TP_SIZE}-pp${PP_SIZE}-vpp${VPP_SIZE}-cp${CP_SIZE}-pad-${PAD_MODE}-use_remove_padding-${USE_REMOVE_PADDING}-mode-${mode}
|
| 104 |
+
fi
|
| 105 |
+
|
| 106 |
+
mkdir -p "${ckpts_home}"
|
| 107 |
+
|
| 108 |
+
$COMMAND \
|
| 109 |
+
data.train_files="${TRAIN_FILES}" \
|
| 110 |
+
data.val_files="${VAL_FILES}" \
|
| 111 |
+
data.train_batch_size=128 \
|
| 112 |
+
data.pad_mode=${PAD_MODE} \
|
| 113 |
+
data.truncation=error \
|
| 114 |
+
data.use_dynamic_bsz=True \
|
| 115 |
+
data.max_token_len_per_gpu=2048 \
|
| 116 |
+
data.messages_key=messages \
|
| 117 |
+
model.path=$MODEL_PATH \
|
| 118 |
+
model.use_remove_padding=${USE_REMOVE_PADDING} \
|
| 119 |
+
${ENGINE_CONFIG} \
|
| 120 |
+
trainer.test_freq=after_each_epoch \
|
| 121 |
+
trainer.save_freq=-1 \
|
| 122 |
+
trainer.logger=['console','file'] \
|
| 123 |
+
trainer.project_name="${project_name}" \
|
| 124 |
+
trainer.experiment_name="${exp_name}" \
|
| 125 |
+
trainer.total_epochs=2 \
|
| 126 |
+
trainer.total_training_steps=2 \
|
| 127 |
+
trainer.default_local_dir="${ckpts_home}" \
|
| 128 |
+
trainer.resume_mode=${RESUME_MODE} \
|
| 129 |
+
|
| 130 |
+
# trainer.total_training_steps=${TOTAL_TRAIN_STEP} \
|
| 131 |
+
# trainer.checkpoint.save_contents=[model,optimizer,extra,hf_model] \
|
| 132 |
+
# trainer.max_ckpt_to_keep=1 \
|
| 133 |
+
|
| 134 |
+
rm -rf "${ckpts_home:?}/*"
|
code/RL_model/verl/verl_train/tests/special_e2e/sft/test_sft_engine_all.sh
ADDED
|
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
|
| 2 |
+
set -xeuo pipefail
|
| 3 |
+
|
| 4 |
+
rm -rf ~/verl/test/log
|
| 5 |
+
mkdir -p ~/verl/test/log
|
| 6 |
+
|
| 7 |
+
export VERL_FILE_LOGGER_ROOT=~/verl/test/log
|
| 8 |
+
VPP_SIZE=${VPP_SIZE:-2}
|
| 9 |
+
|
| 10 |
+
# test with single gpu as golden
|
| 11 |
+
echo "run with single gpu as golden"
|
| 12 |
+
BACKEND=fsdp SP_SIZE=1 FSDP_SIZE=1 NUM_GPUS=1 FSDP_STRATEGY=fsdp VERL_FILE_LOGGER_PATH=~/verl/test/log/golden.jsonl bash tests/special_e2e/sft/run_sft_engine.sh
|
| 13 |
+
|
| 14 |
+
# test with fsdp 1
|
| 15 |
+
echo "run with sp2 fsdp_size2 num_gpus8 fsdp_strategy fsdp pad_mode no_padding"
|
| 16 |
+
BACKEND=fsdp SP_SIZE=2 FSDP_SIZE=2 NUM_GPUS=8 FSDP_STRATEGY=fsdp PAD_MODE=no_padding bash tests/special_e2e/sft/run_sft_engine.sh
|
| 17 |
+
|
| 18 |
+
# test with fsdp 1 use_remove_padding and pad_mode no_padding
|
| 19 |
+
echo "run with sp4 fsdp_size4 num_gpus8 fsdp_strategy fsdp pad_mode no_padding use_remove_padding False"
|
| 20 |
+
BACKEND=fsdp SP_SIZE=1 FSDP_SIZE=-1 NUM_GPUS=8 FSDP_STRATEGY=fsdp PAD_MODE=no_padding USE_REMOVE_PADDING=False bash tests/special_e2e/sft/run_sft_engine.sh
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
# test with fsdp 2
|
| 24 |
+
echo "run with sp2 fsdp_size2 num_gpus8 fsdp_strategy fsdp2"
|
| 25 |
+
BACKEND=fsdp SP_SIZE=2 FSDP_SIZE=2 NUM_GPUS=8 FSDP_STRATEGY=fsdp2 bash tests/special_e2e/sft/run_sft_engine.sh
|
| 26 |
+
|
| 27 |
+
# test with veomni
|
| 28 |
+
echo "run with sp2 fsdp_size4 num_gpus8 fsdp_strategy fsdp2"
|
| 29 |
+
BACKEND=veomni SP_SIZE=2 FSDP_SIZE=4 NUM_GPUS=8 FSDP_STRATEGY=fsdp2 bash tests/special_e2e/sft/run_sft_engine.sh
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
# test with megatron
|
| 33 |
+
echo "run with tp2 pp2 vpp2 cp2 num_gpus8"
|
| 34 |
+
BACKEND=megatron TP_SIZE=2 PP_SIZE=2 VPP_SIZE=${VPP_SIZE} CP_SIZE=2 NUM_GPUS=8 bash tests/special_e2e/sft/run_sft_engine.sh
|
| 35 |
+
|
| 36 |
+
# test with cp in ray
|
| 37 |
+
echo "run with tp2 pp2 vpp2 cp2 num_gpus8 mode=ray"
|
| 38 |
+
BACKEND=megatron TP_SIZE=2 PP_SIZE=2 VPP_SIZE=${VPP_SIZE} CP_SIZE=2 NUM_GPUS=8 mode=ray bash tests/special_e2e/sft/run_sft_engine.sh
|
| 39 |
+
|
| 40 |
+
python3 tests/special_e2e/sft/compare_sft_engine_results.py
|
| 41 |
+
|
| 42 |
+
rm -rf ~/verl/test/log
|
code/RL_model/verl/verl_train/tests/special_e2e/sft/test_sp_loss_match.py
ADDED
|
@@ -0,0 +1,150 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import torch
|
| 16 |
+
import torch.distributed
|
| 17 |
+
from tensordict import TensorDict
|
| 18 |
+
from torch.distributed.device_mesh import init_device_mesh
|
| 19 |
+
|
| 20 |
+
from verl.trainer.fsdp_sft_trainer import FSDPSFTTrainer
|
| 21 |
+
from verl.utils.distributed import initialize_global_process_group
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def test_trainer_forward_consistency(trainer: FSDPSFTTrainer, total_steps: int = 4):
|
| 25 |
+
"""Test consistency between original forward pass and SP+rmpad forward passes.
|
| 26 |
+
|
| 27 |
+
Args:
|
| 28 |
+
trainer: The FSDPSFTTrainer instance to test
|
| 29 |
+
total_steps: Number of steps to test (default: 4)
|
| 30 |
+
"""
|
| 31 |
+
if trainer.device_mesh.get_rank() == 0:
|
| 32 |
+
print("\nStarting debug comparison between original and SP+rmpad forward passes...")
|
| 33 |
+
print(f"Sequence parallel size: {trainer.config.ulysses_sequence_parallel_size}")
|
| 34 |
+
print(f"Remove padding: {trainer.use_remove_padding}\n")
|
| 35 |
+
|
| 36 |
+
steps_remaining = total_steps
|
| 37 |
+
|
| 38 |
+
for epoch in range(1): # Just one epoch for testing
|
| 39 |
+
trainer.train_sampler.set_epoch(epoch=epoch)
|
| 40 |
+
for data in trainer.train_dataloader:
|
| 41 |
+
data = TensorDict(data, batch_size=trainer.config.data.train_batch_size).cuda()
|
| 42 |
+
trainer.fsdp_model.train()
|
| 43 |
+
micro_batches = data.split(trainer.config.data.micro_batch_size_per_gpu)
|
| 44 |
+
|
| 45 |
+
for idx, micro_batch in enumerate(micro_batches):
|
| 46 |
+
if trainer.device_mesh.get_rank() == 0:
|
| 47 |
+
print(f"\nProcessing micro batch {idx + 1}/{len(micro_batches)}")
|
| 48 |
+
|
| 49 |
+
# Compute losses using both methods
|
| 50 |
+
# Disable SP and rmpad
|
| 51 |
+
trainer.use_remove_padding = False
|
| 52 |
+
old_sp = trainer.config.ulysses_sequence_parallel_size
|
| 53 |
+
trainer.config.ulysses_sequence_parallel_size = 1
|
| 54 |
+
loss_ref = trainer._compute_loss_and_backward(micro_batch.copy(), do_backward=False)
|
| 55 |
+
|
| 56 |
+
# Do SP and rmpad
|
| 57 |
+
trainer.config.ulysses_sequence_parallel_size = old_sp
|
| 58 |
+
trainer.use_remove_padding = True
|
| 59 |
+
loss_sp = trainer._compute_loss_and_backward(micro_batch.copy(), do_backward=False)
|
| 60 |
+
|
| 61 |
+
# Collect losses across all ranks
|
| 62 |
+
loss_ref_all = loss_ref.clone()
|
| 63 |
+
loss_sp_all = loss_sp.clone()
|
| 64 |
+
torch.distributed.all_reduce(loss_ref_all, op=torch.distributed.ReduceOp.AVG)
|
| 65 |
+
torch.distributed.all_reduce(loss_sp_all, op=torch.distributed.ReduceOp.AVG)
|
| 66 |
+
|
| 67 |
+
# Calculate relative difference of averaged losses
|
| 68 |
+
rel_diff = torch.abs(loss_ref_all - loss_sp_all) / (torch.abs(loss_ref_all) + 1e-8)
|
| 69 |
+
|
| 70 |
+
if trainer.device_mesh.get_rank() == 0:
|
| 71 |
+
print("\nComparison Results (Averaged across ranks):")
|
| 72 |
+
print(f"Reference Loss: {loss_ref_all.item():.6f}")
|
| 73 |
+
print(f"SP+rmpad Loss: {loss_sp_all.item():.6f}")
|
| 74 |
+
print(f"Relative Difference: {rel_diff.item():.6f}")
|
| 75 |
+
|
| 76 |
+
assert rel_diff.item() < 1e-2, "Significant difference detected between averaged losses!"
|
| 77 |
+
print("Loss difference is within the acceptable range.")
|
| 78 |
+
|
| 79 |
+
steps_remaining -= 1
|
| 80 |
+
if steps_remaining == 0:
|
| 81 |
+
break
|
| 82 |
+
if steps_remaining == 0:
|
| 83 |
+
break
|
| 84 |
+
break
|
| 85 |
+
|
| 86 |
+
if trainer.device_mesh.get_rank() == 0:
|
| 87 |
+
print("\nDebug comparison completed successfully.")
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
def create_trainer(config):
|
| 91 |
+
"""Create and initialize a trainer instance with the given config.
|
| 92 |
+
|
| 93 |
+
Args:
|
| 94 |
+
config: Configuration object with training parameters
|
| 95 |
+
|
| 96 |
+
Returns:
|
| 97 |
+
FSDPSFTTrainer: Initialized trainer instance
|
| 98 |
+
"""
|
| 99 |
+
local_rank, rank, world_size = initialize_global_process_group()
|
| 100 |
+
|
| 101 |
+
device_mesh = init_device_mesh(device_type="cuda", mesh_shape=(world_size,), mesh_dim_names=("fsdp",))
|
| 102 |
+
|
| 103 |
+
dp_size = world_size // config.ulysses_sequence_parallel_size
|
| 104 |
+
ulysses_device_mesh = init_device_mesh(
|
| 105 |
+
device_type="cuda", mesh_shape=(dp_size, config.ulysses_sequence_parallel_size), mesh_dim_names=("dp", "sp")
|
| 106 |
+
)
|
| 107 |
+
|
| 108 |
+
# build tokenizer and datasets first
|
| 109 |
+
from verl.trainer.fsdp_sft_trainer import create_sft_dataset
|
| 110 |
+
from verl.utils import hf_tokenizer
|
| 111 |
+
from verl.utils.fs import copy_to_local
|
| 112 |
+
|
| 113 |
+
local_model_path = copy_to_local(src=config.model.partial_pretrain, verbose=True)
|
| 114 |
+
tokenizer = hf_tokenizer(local_model_path, trust_remote_code=config.model.trust_remote_code)
|
| 115 |
+
train_dataset = create_sft_dataset(
|
| 116 |
+
config.data.train_files, config.data, tokenizer, max_samples=config.data.get("train_max_samples", -1)
|
| 117 |
+
)
|
| 118 |
+
val_dataset = create_sft_dataset(
|
| 119 |
+
config.data.val_files, config.data, tokenizer, max_samples=config.data.get("val_max_samples", -1)
|
| 120 |
+
)
|
| 121 |
+
|
| 122 |
+
return FSDPSFTTrainer(
|
| 123 |
+
config=config,
|
| 124 |
+
device_mesh=device_mesh,
|
| 125 |
+
ulysses_device_mesh=ulysses_device_mesh,
|
| 126 |
+
tokenizer=tokenizer,
|
| 127 |
+
train_dataset=train_dataset,
|
| 128 |
+
val_dataset=val_dataset,
|
| 129 |
+
)
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
def main(config):
|
| 133 |
+
"""Main function to run trainer tests.
|
| 134 |
+
|
| 135 |
+
Args:
|
| 136 |
+
config: Configuration object with training parameters
|
| 137 |
+
"""
|
| 138 |
+
trainer = create_trainer(config)
|
| 139 |
+
test_trainer_forward_consistency(trainer)
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
if __name__ == "__main__":
|
| 143 |
+
import hydra
|
| 144 |
+
from omegaconf import DictConfig
|
| 145 |
+
|
| 146 |
+
@hydra.main(config_path="../../../verl/trainer/config", config_name="sft_trainer")
|
| 147 |
+
def hydra_entry(cfg: DictConfig) -> None:
|
| 148 |
+
main(cfg)
|
| 149 |
+
|
| 150 |
+
hydra_entry()
|
code/RL_model/verl/verl_train/tests/trainer/config/__init__.py
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
code/RL_model/verl/verl_train/tests/utils/ckpt/test_checkpoint_cleanup_on_cpu.py
ADDED
|
@@ -0,0 +1,139 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import os
|
| 16 |
+
import shutil
|
| 17 |
+
import tempfile
|
| 18 |
+
|
| 19 |
+
import pytest
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
class TestCheckpointCleanupLogic:
|
| 23 |
+
"""Tests for checkpoint cleanup methods in BaseCheckpointManager."""
|
| 24 |
+
|
| 25 |
+
@pytest.fixture(autouse=True)
|
| 26 |
+
def setup(self):
|
| 27 |
+
"""Set up test fixtures."""
|
| 28 |
+
self.test_dir = tempfile.mkdtemp()
|
| 29 |
+
yield
|
| 30 |
+
shutil.rmtree(self.test_dir, ignore_errors=True)
|
| 31 |
+
|
| 32 |
+
@pytest.fixture
|
| 33 |
+
def manager(self, monkeypatch):
|
| 34 |
+
"""Create a minimal BaseCheckpointManager for testing."""
|
| 35 |
+
import torch.distributed
|
| 36 |
+
|
| 37 |
+
monkeypatch.setattr(torch.distributed, "get_rank", lambda: 0)
|
| 38 |
+
monkeypatch.setattr(torch.distributed, "get_world_size", lambda: 1)
|
| 39 |
+
|
| 40 |
+
from verl.utils.checkpoint.checkpoint_manager import BaseCheckpointManager
|
| 41 |
+
|
| 42 |
+
class MockModel:
|
| 43 |
+
pass
|
| 44 |
+
|
| 45 |
+
class MockOptimizer:
|
| 46 |
+
pass
|
| 47 |
+
|
| 48 |
+
return BaseCheckpointManager(
|
| 49 |
+
model=MockModel(),
|
| 50 |
+
optimizer=MockOptimizer(),
|
| 51 |
+
lr_scheduler=None,
|
| 52 |
+
processing_class=None,
|
| 53 |
+
checkpoint_config=None,
|
| 54 |
+
)
|
| 55 |
+
|
| 56 |
+
def _create_checkpoint_dir(self, step: int) -> str:
|
| 57 |
+
"""Create a mock checkpoint directory."""
|
| 58 |
+
path = os.path.join(self.test_dir, f"global_step_{step}")
|
| 59 |
+
os.makedirs(path, exist_ok=True)
|
| 60 |
+
with open(os.path.join(path, "checkpoint.txt"), "w") as f:
|
| 61 |
+
f.write(f"step={step}")
|
| 62 |
+
return path
|
| 63 |
+
|
| 64 |
+
def test_max_ckpt_1_preserves_existing_before_save(self, manager):
|
| 65 |
+
"""
|
| 66 |
+
Regression test: max_ckpt_to_keep=1 must NOT delete existing checkpoint before save.
|
| 67 |
+
"""
|
| 68 |
+
ckpt_100 = self._create_checkpoint_dir(100)
|
| 69 |
+
manager.previous_saved_paths = [ckpt_100]
|
| 70 |
+
|
| 71 |
+
manager.ensure_checkpoint_capacity(max_ckpt_to_keep=1)
|
| 72 |
+
|
| 73 |
+
assert os.path.exists(ckpt_100), "Bug: checkpoint deleted before save!"
|
| 74 |
+
assert manager.previous_saved_paths == [ckpt_100]
|
| 75 |
+
|
| 76 |
+
def test_max_ckpt_1_deletes_old_after_save(self, manager):
|
| 77 |
+
"""After save succeeds, old checkpoint should be deleted."""
|
| 78 |
+
ckpt_100 = self._create_checkpoint_dir(100)
|
| 79 |
+
manager.previous_saved_paths = [ckpt_100]
|
| 80 |
+
|
| 81 |
+
ckpt_200 = self._create_checkpoint_dir(200)
|
| 82 |
+
manager.register_checkpoint(ckpt_200, max_ckpt_to_keep=1)
|
| 83 |
+
|
| 84 |
+
assert not os.path.exists(ckpt_100)
|
| 85 |
+
assert os.path.exists(ckpt_200)
|
| 86 |
+
assert manager.previous_saved_paths == [ckpt_200]
|
| 87 |
+
|
| 88 |
+
def test_max_ckpt_2_keeps_one_before_save(self, manager):
|
| 89 |
+
"""With max_ckpt_to_keep=2, pre-save cleanup keeps 1 checkpoint."""
|
| 90 |
+
ckpt_100 = self._create_checkpoint_dir(100)
|
| 91 |
+
ckpt_200 = self._create_checkpoint_dir(200)
|
| 92 |
+
manager.previous_saved_paths = [ckpt_100, ckpt_200]
|
| 93 |
+
|
| 94 |
+
manager.ensure_checkpoint_capacity(max_ckpt_to_keep=2)
|
| 95 |
+
|
| 96 |
+
assert not os.path.exists(ckpt_100)
|
| 97 |
+
assert os.path.exists(ckpt_200)
|
| 98 |
+
assert len(manager.previous_saved_paths) == 1
|
| 99 |
+
|
| 100 |
+
def test_max_ckpt_0_keeps_all(self, manager):
|
| 101 |
+
"""max_ckpt_to_keep=0 means unlimited - no deletions."""
|
| 102 |
+
ckpt_100 = self._create_checkpoint_dir(100)
|
| 103 |
+
ckpt_200 = self._create_checkpoint_dir(200)
|
| 104 |
+
manager.previous_saved_paths = [ckpt_100, ckpt_200]
|
| 105 |
+
|
| 106 |
+
manager.ensure_checkpoint_capacity(max_ckpt_to_keep=0)
|
| 107 |
+
ckpt_300 = self._create_checkpoint_dir(300)
|
| 108 |
+
manager.register_checkpoint(ckpt_300, max_ckpt_to_keep=0)
|
| 109 |
+
|
| 110 |
+
assert os.path.exists(ckpt_100)
|
| 111 |
+
assert os.path.exists(ckpt_200)
|
| 112 |
+
assert os.path.exists(ckpt_300)
|
| 113 |
+
assert len(manager.previous_saved_paths) == 3
|
| 114 |
+
|
| 115 |
+
def test_full_save_cycle_max_ckpt_1(self, manager):
|
| 116 |
+
"""Simulate multiple save cycles with max_ckpt_to_keep=1."""
|
| 117 |
+
# First save
|
| 118 |
+
manager.ensure_checkpoint_capacity(1)
|
| 119 |
+
ckpt_100 = self._create_checkpoint_dir(100)
|
| 120 |
+
manager.register_checkpoint(ckpt_100, 1)
|
| 121 |
+
assert manager.previous_saved_paths == [ckpt_100]
|
| 122 |
+
|
| 123 |
+
# Second save - existing checkpoint must survive pre-save
|
| 124 |
+
manager.ensure_checkpoint_capacity(1)
|
| 125 |
+
assert os.path.exists(ckpt_100), "Bug: checkpoint deleted before save!"
|
| 126 |
+
|
| 127 |
+
ckpt_200 = self._create_checkpoint_dir(200)
|
| 128 |
+
manager.register_checkpoint(ckpt_200, 1)
|
| 129 |
+
assert not os.path.exists(ckpt_100)
|
| 130 |
+
assert manager.previous_saved_paths == [ckpt_200]
|
| 131 |
+
|
| 132 |
+
# Third save
|
| 133 |
+
manager.ensure_checkpoint_capacity(1)
|
| 134 |
+
assert os.path.exists(ckpt_200), "Bug: checkpoint deleted before save!"
|
| 135 |
+
|
| 136 |
+
ckpt_300 = self._create_checkpoint_dir(300)
|
| 137 |
+
manager.register_checkpoint(ckpt_300, 1)
|
| 138 |
+
assert not os.path.exists(ckpt_200)
|
| 139 |
+
assert manager.previous_saved_paths == [ckpt_300]
|