shahidul034 commited on
Commit
e00ff48
·
verified ·
1 Parent(s): ad5ba74

Add files using upload-large-folder tool

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. code/RL_model/verl/Search-R1/outputs/2026-02-01/20-18-07/.hydra/config.yaml +169 -0
  2. code/RL_model/verl/Search-R1/outputs/2026-02-01/20-18-07/.hydra/hydra.yaml +189 -0
  3. code/RL_model/verl/Search-R1/outputs/2026-02-01/20-18-07/.hydra/overrides.yaml +35 -0
  4. code/RL_model/verl/Search-R1/outputs/2026-02-01/20-18-07/main_ppo.log +0 -0
  5. code/RL_model/verl/Search-R1/outputs/2026-02-01/20-20-59/.hydra/config.yaml +169 -0
  6. code/RL_model/verl/Search-R1/outputs/2026-02-01/20-20-59/.hydra/hydra.yaml +189 -0
  7. code/RL_model/verl/Search-R1/outputs/2026-02-01/20-20-59/.hydra/overrides.yaml +35 -0
  8. code/RL_model/verl/Search-R1/outputs/2026-02-01/20-20-59/main_ppo.log +0 -0
  9. code/RL_model/verl/Search-R1/outputs/2026-02-01/20-24-15/.hydra/config.yaml +169 -0
  10. code/RL_model/verl/Search-R1/outputs/2026-02-01/20-24-15/.hydra/hydra.yaml +189 -0
  11. code/RL_model/verl/Search-R1/outputs/2026-02-01/20-24-15/.hydra/overrides.yaml +35 -0
  12. code/RL_model/verl/Search-R1/outputs/2026-02-01/20-24-15/main_ppo.log +0 -0
  13. code/RL_model/verl/Search-R1/outputs/2026-02-01/20-26-44/.hydra/config.yaml +169 -0
  14. code/RL_model/verl/Search-R1/outputs/2026-02-01/20-26-44/.hydra/hydra.yaml +189 -0
  15. code/RL_model/verl/Search-R1/outputs/2026-02-01/20-26-44/.hydra/overrides.yaml +35 -0
  16. code/RL_model/verl/Search-R1/outputs/2026-02-01/20-33-01/.hydra/config.yaml +169 -0
  17. code/RL_model/verl/Search-R1/outputs/2026-02-01/20-33-01/.hydra/hydra.yaml +189 -0
  18. code/RL_model/verl/Search-R1/outputs/2026-02-01/20-33-01/.hydra/overrides.yaml +35 -0
  19. code/RL_model/verl/Search-R1/outputs/2026-02-01/20-33-01/main_ppo.log +0 -0
  20. code/RL_model/verl/Search-R1/outputs/2026-02-01/20-35-38/.hydra/config.yaml +169 -0
  21. code/RL_model/verl/Search-R1/outputs/2026-02-01/20-35-38/.hydra/hydra.yaml +189 -0
  22. code/RL_model/verl/Search-R1/outputs/2026-02-01/20-35-38/.hydra/overrides.yaml +35 -0
  23. code/RL_model/verl/Search-R1/outputs/2026-02-01/20-35-38/main_ppo.log +0 -0
  24. code/RL_model/verl/Search-R1/outputs/2026-02-01/20-41-08/.hydra/config.yaml +169 -0
  25. code/RL_model/verl/Search-R1/outputs/2026-02-01/20-41-08/.hydra/hydra.yaml +189 -0
  26. code/RL_model/verl/Search-R1/outputs/2026-02-01/20-41-08/.hydra/overrides.yaml +35 -0
  27. code/RL_model/verl/Search-R1/outputs/2026-02-01/20-41-08/main_ppo.log +0 -0
  28. code/RL_model/verl/Search-R1/outputs/2026-02-01/20-42-57/.hydra/config.yaml +169 -0
  29. code/RL_model/verl/Search-R1/outputs/2026-02-01/20-42-57/.hydra/hydra.yaml +189 -0
  30. code/RL_model/verl/Search-R1/outputs/2026-02-01/20-42-57/.hydra/overrides.yaml +35 -0
  31. code/RL_model/verl/Search-R1/outputs/2026-02-01/20-42-57/main_ppo.log +0 -0
  32. code/RL_model/verl/Search-R1/search_r1/llm_agent/__init__.py +0 -0
  33. code/RL_model/verl/Search-R1/search_r1/llm_agent/generation.py +469 -0
  34. code/RL_model/verl/Search-R1/search_r1/llm_agent/tensor_helper.py +75 -0
  35. code/RL_model/verl/Search-R1/search_r1/search/build_index.sh +19 -0
  36. code/RL_model/verl/Search-R1/search_r1/search/google_search_server.py +202 -0
  37. code/RL_model/verl/Search-R1/search_r1/search/index_builder.py +349 -0
  38. code/RL_model/verl/Search-R1/search_r1/search/rerank_server.py +161 -0
  39. code/RL_model/verl/Search-R1/search_r1/search/retrieval.py +368 -0
  40. code/RL_model/verl/Search-R1/search_r1/search/retrieval.sh +25 -0
  41. code/RL_model/verl/Search-R1/search_r1/search/retrieval_request.py +23 -0
  42. code/RL_model/verl/Search-R1/search_r1/search/retrieval_rerank_server.py +123 -0
  43. code/RL_model/verl/Search-R1/search_r1/search/retrieval_server.py +392 -0
  44. code/RL_model/verl/Search-R1/search_r1/search/serp_search_server.py +112 -0
  45. code/RL_model/verl/Search-R1/verl.egg-info/SOURCES.txt +190 -0
  46. code/RL_model/verl/Search-R1/verl/single_controller/__init__.py +20 -0
  47. code/RL_model/verl/Search-R1/verl/trainer/__init__.py +13 -0
  48. code/RL_model/verl/Search-R1/verl/trainer/main_eval.py +69 -0
  49. code/RL_model/verl/Search-R1/verl/utils/__init__.py +18 -0
  50. code/RL_model/verl/Search-R1/verl/utils/config.py +23 -0
code/RL_model/verl/Search-R1/outputs/2026-02-01/20-18-07/.hydra/config.yaml ADDED
@@ -0,0 +1,169 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ data:
2
+ tokenizer: null
3
+ train_files: /home/mshahidul/readctrl/code/RL_model/verl/Search-R1/dataset/train.parquet
4
+ val_files: /home/mshahidul/readctrl/code/RL_model/verl/Search-R1/dataset/test.parquet
5
+ train_data_num: null
6
+ val_data_num: null
7
+ prompt_key: prompt
8
+ max_prompt_length: 4096
9
+ max_response_length: 1024
10
+ max_start_length: 256
11
+ max_obs_length: 512
12
+ train_batch_size: 128
13
+ val_batch_size: 64
14
+ return_raw_input_ids: false
15
+ return_raw_chat: false
16
+ shuffle_train_dataloader: true
17
+ actor_rollout_ref:
18
+ hybrid_engine: true
19
+ model:
20
+ path: Qwen/Qwen3-4B-Instruct-2507
21
+ external_lib: null
22
+ override_config: {}
23
+ enable_gradient_checkpointing: true
24
+ use_remove_padding: true
25
+ actor:
26
+ strategy: fsdp
27
+ ppo_mini_batch_size: 64
28
+ ppo_micro_batch_size: 64
29
+ use_dynamic_bsz: false
30
+ ppo_max_token_len_per_gpu: 16384
31
+ grad_clip: 1.0
32
+ state_masking: false
33
+ clip_ratio: 0.2
34
+ entropy_coeff: 0.001
35
+ use_kl_loss: false
36
+ kl_loss_coef: 0.001
37
+ kl_loss_type: low_var_kl
38
+ ppo_epochs: 1
39
+ shuffle: false
40
+ ulysses_sequence_parallel_size: 1
41
+ optim:
42
+ lr: 1.0e-06
43
+ lr_warmup_steps_ratio: 0.0
44
+ min_lr_ratio: null
45
+ warmup_style: constant
46
+ total_training_steps: -1
47
+ fsdp_config:
48
+ wrap_policy:
49
+ min_num_params: 0
50
+ param_offload: true
51
+ grad_offload: false
52
+ optimizer_offload: true
53
+ fsdp_size: -1
54
+ ppo_micro_batch_size_per_gpu: 16
55
+ ref:
56
+ fsdp_config:
57
+ param_offload: true
58
+ wrap_policy:
59
+ min_num_params: 0
60
+ fsdp_size: -1
61
+ log_prob_micro_batch_size: 64
62
+ log_prob_use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz}
63
+ log_prob_max_token_len_per_gpu: ${actor_rollout_ref.actor.ppo_max_token_len_per_gpu}
64
+ ulysses_sequence_parallel_size: ${actor_rollout_ref.actor.ulysses_sequence_parallel_size}
65
+ rollout:
66
+ name: vllm
67
+ temperature: 1.0
68
+ top_k: -1
69
+ top_p: 0.95
70
+ prompt_length: ${data.max_prompt_length}
71
+ response_length: ${data.max_response_length}
72
+ dtype: bfloat16
73
+ gpu_memory_utilization: 0.4
74
+ ignore_eos: false
75
+ enforce_eager: true
76
+ free_cache_engine: true
77
+ load_format: dummy_dtensor
78
+ tensor_model_parallel_size: 1
79
+ max_num_batched_tokens: 8192
80
+ max_num_seqs: 1024
81
+ log_prob_micro_batch_size: 64
82
+ log_prob_use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz}
83
+ log_prob_max_token_len_per_gpu: ${actor_rollout_ref.actor.ppo_max_token_len_per_gpu}
84
+ do_sample: true
85
+ 'n': 1
86
+ n_agent: 1
87
+ critic:
88
+ strategy: fsdp
89
+ optim:
90
+ lr: 1.0e-05
91
+ lr_warmup_steps_ratio: 0.0
92
+ min_lr_ratio: null
93
+ warmup_style: constant
94
+ total_training_steps: -1
95
+ model:
96
+ path: ~/models/deepseek-llm-7b-chat
97
+ tokenizer_path: ${actor_rollout_ref.model.path}
98
+ override_config: {}
99
+ external_lib: ${actor_rollout_ref.model.external_lib}
100
+ enable_gradient_checkpointing: false
101
+ use_remove_padding: false
102
+ fsdp_config:
103
+ param_offload: false
104
+ grad_offload: false
105
+ optimizer_offload: false
106
+ wrap_policy:
107
+ min_num_params: 0
108
+ fsdp_size: -1
109
+ ppo_mini_batch_size: ${actor_rollout_ref.actor.ppo_mini_batch_size}
110
+ ppo_micro_batch_size: 64
111
+ forward_micro_batch_size: ${critic.ppo_micro_batch_size}
112
+ use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz}
113
+ ppo_max_token_len_per_gpu: 32768
114
+ forward_max_token_len_per_gpu: ${critic.ppo_max_token_len_per_gpu}
115
+ ulysses_sequence_parallel_size: 1
116
+ ppo_epochs: ${actor_rollout_ref.actor.ppo_epochs}
117
+ shuffle: ${actor_rollout_ref.actor.shuffle}
118
+ grad_clip: 1.0
119
+ cliprange_value: 0.5
120
+ reward_model:
121
+ enable: false
122
+ strategy: fsdp
123
+ model:
124
+ input_tokenizer: ${actor_rollout_ref.model.path}
125
+ path: ~/models/FsfairX-LLaMA3-RM-v0.1
126
+ external_lib: ${actor_rollout_ref.model.external_lib}
127
+ use_remove_padding: false
128
+ fsdp_config:
129
+ min_num_params: 0
130
+ param_offload: false
131
+ micro_batch_size: 64
132
+ max_length: null
133
+ ulysses_sequence_parallel_size: 1
134
+ use_dynamic_bsz: ${critic.use_dynamic_bsz}
135
+ forward_max_token_len_per_gpu: ${critic.forward_max_token_len_per_gpu}
136
+ structure_format_score: 0
137
+ final_format_score: 0
138
+ retrieval_score: 0
139
+ retriever:
140
+ url: http://127.0.0.1:8000/retrieve
141
+ topk: 3
142
+ algorithm:
143
+ gamma: 1.0
144
+ lam: 1.0
145
+ adv_estimator: grpo
146
+ no_think_rl: false
147
+ kl_penalty: kl
148
+ kl_ctrl:
149
+ type: fixed
150
+ kl_coef: 0.001
151
+ state_masking:
152
+ start_state_marker: <information>
153
+ end_state_marker: </information>
154
+ trainer:
155
+ total_epochs: 15
156
+ total_training_steps: 1005
157
+ project_name: ''
158
+ experiment_name: llm_guard_3B_10k_v2
159
+ logger:
160
+ - wandb
161
+ nnodes: 1
162
+ n_gpus_per_node: 2
163
+ save_freq: 100
164
+ test_freq: 50
165
+ critic_warmup: 0
166
+ default_hdfs_dir: ~/experiments/gsm8k/ppo/${trainer.experiment_name}
167
+ default_local_dir: verl_checkpoints/llm_guard_3B_10k_v2
168
+ max_turns: 1
169
+ do_search: false
code/RL_model/verl/Search-R1/outputs/2026-02-01/20-18-07/.hydra/hydra.yaml ADDED
@@ -0,0 +1,189 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ hydra:
2
+ run:
3
+ dir: outputs/${now:%Y-%m-%d}/${now:%H-%M-%S}
4
+ sweep:
5
+ dir: multirun/${now:%Y-%m-%d}/${now:%H-%M-%S}
6
+ subdir: ${hydra.job.num}
7
+ launcher:
8
+ _target_: hydra._internal.core_plugins.basic_launcher.BasicLauncher
9
+ sweeper:
10
+ _target_: hydra._internal.core_plugins.basic_sweeper.BasicSweeper
11
+ max_batch_size: null
12
+ params: null
13
+ help:
14
+ app_name: ${hydra.job.name}
15
+ header: '${hydra.help.app_name} is powered by Hydra.
16
+
17
+ '
18
+ footer: 'Powered by Hydra (https://hydra.cc)
19
+
20
+ Use --hydra-help to view Hydra specific help
21
+
22
+ '
23
+ template: '${hydra.help.header}
24
+
25
+ == Configuration groups ==
26
+
27
+ Compose your configuration from those groups (group=option)
28
+
29
+
30
+ $APP_CONFIG_GROUPS
31
+
32
+
33
+ == Config ==
34
+
35
+ Override anything in the config (foo.bar=value)
36
+
37
+
38
+ $CONFIG
39
+
40
+
41
+ ${hydra.help.footer}
42
+
43
+ '
44
+ hydra_help:
45
+ template: 'Hydra (${hydra.runtime.version})
46
+
47
+ See https://hydra.cc for more info.
48
+
49
+
50
+ == Flags ==
51
+
52
+ $FLAGS_HELP
53
+
54
+
55
+ == Configuration groups ==
56
+
57
+ Compose your configuration from those groups (For example, append hydra/job_logging=disabled
58
+ to command line)
59
+
60
+
61
+ $HYDRA_CONFIG_GROUPS
62
+
63
+
64
+ Use ''--cfg hydra'' to Show the Hydra config.
65
+
66
+ '
67
+ hydra_help: ???
68
+ hydra_logging:
69
+ version: 1
70
+ formatters:
71
+ simple:
72
+ format: '[%(asctime)s][HYDRA] %(message)s'
73
+ handlers:
74
+ console:
75
+ class: logging.StreamHandler
76
+ formatter: simple
77
+ stream: ext://sys.stdout
78
+ root:
79
+ level: INFO
80
+ handlers:
81
+ - console
82
+ loggers:
83
+ logging_example:
84
+ level: DEBUG
85
+ disable_existing_loggers: false
86
+ job_logging:
87
+ version: 1
88
+ formatters:
89
+ simple:
90
+ format: '[%(asctime)s][%(name)s][%(levelname)s] - %(message)s'
91
+ handlers:
92
+ console:
93
+ class: logging.StreamHandler
94
+ formatter: simple
95
+ stream: ext://sys.stdout
96
+ file:
97
+ class: logging.FileHandler
98
+ formatter: simple
99
+ filename: ${hydra.runtime.output_dir}/${hydra.job.name}.log
100
+ root:
101
+ level: INFO
102
+ handlers:
103
+ - console
104
+ - file
105
+ disable_existing_loggers: false
106
+ env: {}
107
+ mode: RUN
108
+ searchpath: []
109
+ callbacks: {}
110
+ output_subdir: .hydra
111
+ overrides:
112
+ hydra:
113
+ - hydra.mode=RUN
114
+ task:
115
+ - data.train_files=/home/mshahidul/readctrl/code/RL_model/verl/Search-R1/dataset/train.parquet
116
+ - data.val_files=/home/mshahidul/readctrl/code/RL_model/verl/Search-R1/dataset/test.parquet
117
+ - data.train_batch_size=128
118
+ - data.val_batch_size=64
119
+ - data.max_prompt_length=4096
120
+ - data.max_response_length=1024
121
+ - data.shuffle_train_dataloader=True
122
+ - algorithm.adv_estimator=grpo
123
+ - actor_rollout_ref.model.path=Qwen/Qwen3-4B-Instruct-2507
124
+ - actor_rollout_ref.model.enable_gradient_checkpointing=true
125
+ - actor_rollout_ref.model.use_remove_padding=True
126
+ - actor_rollout_ref.actor.optim.lr=1e-6
127
+ - actor_rollout_ref.actor.ppo_mini_batch_size=64
128
+ - +actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=16
129
+ - actor_rollout_ref.actor.fsdp_config.param_offload=true
130
+ - actor_rollout_ref.actor.fsdp_config.optimizer_offload=true
131
+ - actor_rollout_ref.rollout.log_prob_micro_batch_size=64
132
+ - actor_rollout_ref.rollout.tensor_model_parallel_size=1
133
+ - actor_rollout_ref.rollout.name=vllm
134
+ - actor_rollout_ref.rollout.gpu_memory_utilization=0.4
135
+ - actor_rollout_ref.ref.log_prob_micro_batch_size=64
136
+ - actor_rollout_ref.ref.fsdp_config.param_offload=True
137
+ - actor_rollout_ref.actor.kl_loss_coef=0.001
138
+ - trainer.logger=[wandb]
139
+ - trainer.n_gpus_per_node=2
140
+ - trainer.nnodes=1
141
+ - trainer.save_freq=100
142
+ - trainer.test_freq=50
143
+ - trainer.project_name=
144
+ - trainer.experiment_name=llm_guard_3B_10k_v2
145
+ - trainer.total_epochs=15
146
+ - trainer.total_training_steps=1005
147
+ - trainer.default_local_dir=verl_checkpoints/llm_guard_3B_10k_v2
148
+ - do_search=false
149
+ - max_turns=1
150
+ job:
151
+ name: main_ppo
152
+ chdir: null
153
+ override_dirname: +actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=16,actor_rollout_ref.actor.fsdp_config.optimizer_offload=true,actor_rollout_ref.actor.fsdp_config.param_offload=true,actor_rollout_ref.actor.kl_loss_coef=0.001,actor_rollout_ref.actor.optim.lr=1e-6,actor_rollout_ref.actor.ppo_mini_batch_size=64,actor_rollout_ref.model.enable_gradient_checkpointing=true,actor_rollout_ref.model.path=Qwen/Qwen3-4B-Instruct-2507,actor_rollout_ref.model.use_remove_padding=True,actor_rollout_ref.ref.fsdp_config.param_offload=True,actor_rollout_ref.ref.log_prob_micro_batch_size=64,actor_rollout_ref.rollout.gpu_memory_utilization=0.4,actor_rollout_ref.rollout.log_prob_micro_batch_size=64,actor_rollout_ref.rollout.name=vllm,actor_rollout_ref.rollout.tensor_model_parallel_size=1,algorithm.adv_estimator=grpo,data.max_prompt_length=4096,data.max_response_length=1024,data.shuffle_train_dataloader=True,data.train_batch_size=128,data.train_files=/home/mshahidul/readctrl/code/RL_model/verl/Search-R1/dataset/train.parquet,data.val_batch_size=64,data.val_files=/home/mshahidul/readctrl/code/RL_model/verl/Search-R1/dataset/test.parquet,do_search=false,max_turns=1,trainer.default_local_dir=verl_checkpoints/llm_guard_3B_10k_v2,trainer.experiment_name=llm_guard_3B_10k_v2,trainer.logger=[wandb],trainer.n_gpus_per_node=2,trainer.nnodes=1,trainer.project_name=,trainer.save_freq=100,trainer.test_freq=50,trainer.total_epochs=15,trainer.total_training_steps=1005
154
+ id: ???
155
+ num: ???
156
+ config_name: ppo_trainer
157
+ env_set: {}
158
+ env_copy: []
159
+ config:
160
+ override_dirname:
161
+ kv_sep: '='
162
+ item_sep: ','
163
+ exclude_keys: []
164
+ runtime:
165
+ version: 1.3.2
166
+ version_base: '1.3'
167
+ cwd: /data/home_beta/mshahidul/readctrl/code/RL_model/verl/Search-R1
168
+ config_sources:
169
+ - path: hydra.conf
170
+ schema: pkg
171
+ provider: hydra
172
+ - path: /data/home_beta/mshahidul/readctrl/code/RL_model/verl/Search-R1/verl/trainer/config
173
+ schema: file
174
+ provider: main
175
+ - path: ''
176
+ schema: structured
177
+ provider: schema
178
+ output_dir: /data/home_beta/mshahidul/readctrl/code/RL_model/verl/Search-R1/outputs/2026-02-01/20-18-07
179
+ choices:
180
+ hydra/env: default
181
+ hydra/callbacks: null
182
+ hydra/job_logging: default
183
+ hydra/hydra_logging: default
184
+ hydra/hydra_help: default
185
+ hydra/help: default
186
+ hydra/sweeper: basic
187
+ hydra/launcher: basic
188
+ hydra/output: default
189
+ verbose: false
code/RL_model/verl/Search-R1/outputs/2026-02-01/20-18-07/.hydra/overrides.yaml ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ - data.train_files=/home/mshahidul/readctrl/code/RL_model/verl/Search-R1/dataset/train.parquet
2
+ - data.val_files=/home/mshahidul/readctrl/code/RL_model/verl/Search-R1/dataset/test.parquet
3
+ - data.train_batch_size=128
4
+ - data.val_batch_size=64
5
+ - data.max_prompt_length=4096
6
+ - data.max_response_length=1024
7
+ - data.shuffle_train_dataloader=True
8
+ - algorithm.adv_estimator=grpo
9
+ - actor_rollout_ref.model.path=Qwen/Qwen3-4B-Instruct-2507
10
+ - actor_rollout_ref.model.enable_gradient_checkpointing=true
11
+ - actor_rollout_ref.model.use_remove_padding=True
12
+ - actor_rollout_ref.actor.optim.lr=1e-6
13
+ - actor_rollout_ref.actor.ppo_mini_batch_size=64
14
+ - +actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=16
15
+ - actor_rollout_ref.actor.fsdp_config.param_offload=true
16
+ - actor_rollout_ref.actor.fsdp_config.optimizer_offload=true
17
+ - actor_rollout_ref.rollout.log_prob_micro_batch_size=64
18
+ - actor_rollout_ref.rollout.tensor_model_parallel_size=1
19
+ - actor_rollout_ref.rollout.name=vllm
20
+ - actor_rollout_ref.rollout.gpu_memory_utilization=0.4
21
+ - actor_rollout_ref.ref.log_prob_micro_batch_size=64
22
+ - actor_rollout_ref.ref.fsdp_config.param_offload=True
23
+ - actor_rollout_ref.actor.kl_loss_coef=0.001
24
+ - trainer.logger=[wandb]
25
+ - trainer.n_gpus_per_node=2
26
+ - trainer.nnodes=1
27
+ - trainer.save_freq=100
28
+ - trainer.test_freq=50
29
+ - trainer.project_name=
30
+ - trainer.experiment_name=llm_guard_3B_10k_v2
31
+ - trainer.total_epochs=15
32
+ - trainer.total_training_steps=1005
33
+ - trainer.default_local_dir=verl_checkpoints/llm_guard_3B_10k_v2
34
+ - do_search=false
35
+ - max_turns=1
code/RL_model/verl/Search-R1/outputs/2026-02-01/20-18-07/main_ppo.log ADDED
File without changes
code/RL_model/verl/Search-R1/outputs/2026-02-01/20-20-59/.hydra/config.yaml ADDED
@@ -0,0 +1,169 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ data:
2
+ tokenizer: null
3
+ train_files: /home/mshahidul/readctrl/code/RL_model/verl/Search-R1/dataset/train.parquet
4
+ val_files: /home/mshahidul/readctrl/code/RL_model/verl/Search-R1/dataset/test.parquet
5
+ train_data_num: null
6
+ val_data_num: null
7
+ prompt_key: prompt
8
+ max_prompt_length: 4096
9
+ max_response_length: 1024
10
+ max_start_length: 256
11
+ max_obs_length: 512
12
+ train_batch_size: 128
13
+ val_batch_size: 64
14
+ return_raw_input_ids: false
15
+ return_raw_chat: false
16
+ shuffle_train_dataloader: true
17
+ actor_rollout_ref:
18
+ hybrid_engine: true
19
+ model:
20
+ path: Qwen/Qwen3-4B-Instruct-2507
21
+ external_lib: null
22
+ override_config: {}
23
+ enable_gradient_checkpointing: true
24
+ use_remove_padding: true
25
+ actor:
26
+ strategy: fsdp
27
+ ppo_mini_batch_size: 64
28
+ ppo_micro_batch_size: 64
29
+ use_dynamic_bsz: false
30
+ ppo_max_token_len_per_gpu: 16384
31
+ grad_clip: 1.0
32
+ state_masking: false
33
+ clip_ratio: 0.2
34
+ entropy_coeff: 0.001
35
+ use_kl_loss: false
36
+ kl_loss_coef: 0.001
37
+ kl_loss_type: low_var_kl
38
+ ppo_epochs: 1
39
+ shuffle: false
40
+ ulysses_sequence_parallel_size: 1
41
+ optim:
42
+ lr: 1.0e-06
43
+ lr_warmup_steps_ratio: 0.0
44
+ min_lr_ratio: null
45
+ warmup_style: constant
46
+ total_training_steps: -1
47
+ fsdp_config:
48
+ wrap_policy:
49
+ min_num_params: 0
50
+ param_offload: true
51
+ grad_offload: false
52
+ optimizer_offload: true
53
+ fsdp_size: -1
54
+ ppo_micro_batch_size_per_gpu: 16
55
+ ref:
56
+ fsdp_config:
57
+ param_offload: true
58
+ wrap_policy:
59
+ min_num_params: 0
60
+ fsdp_size: -1
61
+ log_prob_micro_batch_size: 64
62
+ log_prob_use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz}
63
+ log_prob_max_token_len_per_gpu: ${actor_rollout_ref.actor.ppo_max_token_len_per_gpu}
64
+ ulysses_sequence_parallel_size: ${actor_rollout_ref.actor.ulysses_sequence_parallel_size}
65
+ rollout:
66
+ name: vllm
67
+ temperature: 1.0
68
+ top_k: -1
69
+ top_p: 0.95
70
+ prompt_length: ${data.max_prompt_length}
71
+ response_length: ${data.max_response_length}
72
+ dtype: bfloat16
73
+ gpu_memory_utilization: 0.4
74
+ ignore_eos: false
75
+ enforce_eager: true
76
+ free_cache_engine: true
77
+ load_format: dummy_dtensor
78
+ tensor_model_parallel_size: 1
79
+ max_num_batched_tokens: 8192
80
+ max_num_seqs: 1024
81
+ log_prob_micro_batch_size: 64
82
+ log_prob_use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz}
83
+ log_prob_max_token_len_per_gpu: ${actor_rollout_ref.actor.ppo_max_token_len_per_gpu}
84
+ do_sample: true
85
+ 'n': 1
86
+ n_agent: 1
87
+ critic:
88
+ strategy: fsdp
89
+ optim:
90
+ lr: 1.0e-05
91
+ lr_warmup_steps_ratio: 0.0
92
+ min_lr_ratio: null
93
+ warmup_style: constant
94
+ total_training_steps: -1
95
+ model:
96
+ path: ~/models/deepseek-llm-7b-chat
97
+ tokenizer_path: ${actor_rollout_ref.model.path}
98
+ override_config: {}
99
+ external_lib: ${actor_rollout_ref.model.external_lib}
100
+ enable_gradient_checkpointing: false
101
+ use_remove_padding: false
102
+ fsdp_config:
103
+ param_offload: false
104
+ grad_offload: false
105
+ optimizer_offload: false
106
+ wrap_policy:
107
+ min_num_params: 0
108
+ fsdp_size: -1
109
+ ppo_mini_batch_size: ${actor_rollout_ref.actor.ppo_mini_batch_size}
110
+ ppo_micro_batch_size: 64
111
+ forward_micro_batch_size: ${critic.ppo_micro_batch_size}
112
+ use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz}
113
+ ppo_max_token_len_per_gpu: 32768
114
+ forward_max_token_len_per_gpu: ${critic.ppo_max_token_len_per_gpu}
115
+ ulysses_sequence_parallel_size: 1
116
+ ppo_epochs: ${actor_rollout_ref.actor.ppo_epochs}
117
+ shuffle: ${actor_rollout_ref.actor.shuffle}
118
+ grad_clip: 1.0
119
+ cliprange_value: 0.5
120
+ reward_model:
121
+ enable: false
122
+ strategy: fsdp
123
+ model:
124
+ input_tokenizer: ${actor_rollout_ref.model.path}
125
+ path: ~/models/FsfairX-LLaMA3-RM-v0.1
126
+ external_lib: ${actor_rollout_ref.model.external_lib}
127
+ use_remove_padding: false
128
+ fsdp_config:
129
+ min_num_params: 0
130
+ param_offload: false
131
+ micro_batch_size: 64
132
+ max_length: null
133
+ ulysses_sequence_parallel_size: 1
134
+ use_dynamic_bsz: ${critic.use_dynamic_bsz}
135
+ forward_max_token_len_per_gpu: ${critic.forward_max_token_len_per_gpu}
136
+ structure_format_score: 0
137
+ final_format_score: 0
138
+ retrieval_score: 0
139
+ retriever:
140
+ url: http://127.0.0.1:8000/retrieve
141
+ topk: 3
142
+ algorithm:
143
+ gamma: 1.0
144
+ lam: 1.0
145
+ adv_estimator: grpo
146
+ no_think_rl: false
147
+ kl_penalty: kl
148
+ kl_ctrl:
149
+ type: fixed
150
+ kl_coef: 0.001
151
+ state_masking:
152
+ start_state_marker: <information>
153
+ end_state_marker: </information>
154
+ trainer:
155
+ total_epochs: 15
156
+ total_training_steps: 1005
157
+ project_name: ''
158
+ experiment_name: llm_guard_3B_10k_v2
159
+ logger:
160
+ - wandb
161
+ nnodes: 1
162
+ n_gpus_per_node: 2
163
+ save_freq: 100
164
+ test_freq: 50
165
+ critic_warmup: 0
166
+ default_hdfs_dir: ~/experiments/gsm8k/ppo/${trainer.experiment_name}
167
+ default_local_dir: verl_checkpoints/llm_guard_3B_10k_v2
168
+ max_turns: 1
169
+ do_search: false
code/RL_model/verl/Search-R1/outputs/2026-02-01/20-20-59/.hydra/hydra.yaml ADDED
@@ -0,0 +1,189 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ hydra:
2
+ run:
3
+ dir: outputs/${now:%Y-%m-%d}/${now:%H-%M-%S}
4
+ sweep:
5
+ dir: multirun/${now:%Y-%m-%d}/${now:%H-%M-%S}
6
+ subdir: ${hydra.job.num}
7
+ launcher:
8
+ _target_: hydra._internal.core_plugins.basic_launcher.BasicLauncher
9
+ sweeper:
10
+ _target_: hydra._internal.core_plugins.basic_sweeper.BasicSweeper
11
+ max_batch_size: null
12
+ params: null
13
+ help:
14
+ app_name: ${hydra.job.name}
15
+ header: '${hydra.help.app_name} is powered by Hydra.
16
+
17
+ '
18
+ footer: 'Powered by Hydra (https://hydra.cc)
19
+
20
+ Use --hydra-help to view Hydra specific help
21
+
22
+ '
23
+ template: '${hydra.help.header}
24
+
25
+ == Configuration groups ==
26
+
27
+ Compose your configuration from those groups (group=option)
28
+
29
+
30
+ $APP_CONFIG_GROUPS
31
+
32
+
33
+ == Config ==
34
+
35
+ Override anything in the config (foo.bar=value)
36
+
37
+
38
+ $CONFIG
39
+
40
+
41
+ ${hydra.help.footer}
42
+
43
+ '
44
+ hydra_help:
45
+ template: 'Hydra (${hydra.runtime.version})
46
+
47
+ See https://hydra.cc for more info.
48
+
49
+
50
+ == Flags ==
51
+
52
+ $FLAGS_HELP
53
+
54
+
55
+ == Configuration groups ==
56
+
57
+ Compose your configuration from those groups (For example, append hydra/job_logging=disabled
58
+ to command line)
59
+
60
+
61
+ $HYDRA_CONFIG_GROUPS
62
+
63
+
64
+ Use ''--cfg hydra'' to Show the Hydra config.
65
+
66
+ '
67
+ hydra_help: ???
68
+ hydra_logging:
69
+ version: 1
70
+ formatters:
71
+ simple:
72
+ format: '[%(asctime)s][HYDRA] %(message)s'
73
+ handlers:
74
+ console:
75
+ class: logging.StreamHandler
76
+ formatter: simple
77
+ stream: ext://sys.stdout
78
+ root:
79
+ level: INFO
80
+ handlers:
81
+ - console
82
+ loggers:
83
+ logging_example:
84
+ level: DEBUG
85
+ disable_existing_loggers: false
86
+ job_logging:
87
+ version: 1
88
+ formatters:
89
+ simple:
90
+ format: '[%(asctime)s][%(name)s][%(levelname)s] - %(message)s'
91
+ handlers:
92
+ console:
93
+ class: logging.StreamHandler
94
+ formatter: simple
95
+ stream: ext://sys.stdout
96
+ file:
97
+ class: logging.FileHandler
98
+ formatter: simple
99
+ filename: ${hydra.runtime.output_dir}/${hydra.job.name}.log
100
+ root:
101
+ level: INFO
102
+ handlers:
103
+ - console
104
+ - file
105
+ disable_existing_loggers: false
106
+ env: {}
107
+ mode: RUN
108
+ searchpath: []
109
+ callbacks: {}
110
+ output_subdir: .hydra
111
+ overrides:
112
+ hydra:
113
+ - hydra.mode=RUN
114
+ task:
115
+ - data.train_files=/home/mshahidul/readctrl/code/RL_model/verl/Search-R1/dataset/train.parquet
116
+ - data.val_files=/home/mshahidul/readctrl/code/RL_model/verl/Search-R1/dataset/test.parquet
117
+ - data.train_batch_size=128
118
+ - data.val_batch_size=64
119
+ - data.max_prompt_length=4096
120
+ - data.max_response_length=1024
121
+ - data.shuffle_train_dataloader=True
122
+ - algorithm.adv_estimator=grpo
123
+ - actor_rollout_ref.model.path=Qwen/Qwen3-4B-Instruct-2507
124
+ - actor_rollout_ref.model.enable_gradient_checkpointing=true
125
+ - actor_rollout_ref.model.use_remove_padding=True
126
+ - actor_rollout_ref.actor.optim.lr=1e-6
127
+ - actor_rollout_ref.actor.ppo_mini_batch_size=64
128
+ - +actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=16
129
+ - actor_rollout_ref.actor.fsdp_config.param_offload=true
130
+ - actor_rollout_ref.actor.fsdp_config.optimizer_offload=true
131
+ - actor_rollout_ref.rollout.log_prob_micro_batch_size=64
132
+ - actor_rollout_ref.rollout.tensor_model_parallel_size=1
133
+ - actor_rollout_ref.rollout.name=vllm
134
+ - actor_rollout_ref.rollout.gpu_memory_utilization=0.4
135
+ - actor_rollout_ref.ref.log_prob_micro_batch_size=64
136
+ - actor_rollout_ref.ref.fsdp_config.param_offload=True
137
+ - actor_rollout_ref.actor.kl_loss_coef=0.001
138
+ - trainer.logger=[wandb]
139
+ - trainer.n_gpus_per_node=2
140
+ - trainer.nnodes=1
141
+ - trainer.save_freq=100
142
+ - trainer.test_freq=50
143
+ - trainer.project_name=
144
+ - trainer.experiment_name=llm_guard_3B_10k_v2
145
+ - trainer.total_epochs=15
146
+ - trainer.total_training_steps=1005
147
+ - trainer.default_local_dir=verl_checkpoints/llm_guard_3B_10k_v2
148
+ - do_search=false
149
+ - max_turns=1
150
+ job:
151
+ name: main_ppo
152
+ chdir: null
153
+ override_dirname: +actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=16,actor_rollout_ref.actor.fsdp_config.optimizer_offload=true,actor_rollout_ref.actor.fsdp_config.param_offload=true,actor_rollout_ref.actor.kl_loss_coef=0.001,actor_rollout_ref.actor.optim.lr=1e-6,actor_rollout_ref.actor.ppo_mini_batch_size=64,actor_rollout_ref.model.enable_gradient_checkpointing=true,actor_rollout_ref.model.path=Qwen/Qwen3-4B-Instruct-2507,actor_rollout_ref.model.use_remove_padding=True,actor_rollout_ref.ref.fsdp_config.param_offload=True,actor_rollout_ref.ref.log_prob_micro_batch_size=64,actor_rollout_ref.rollout.gpu_memory_utilization=0.4,actor_rollout_ref.rollout.log_prob_micro_batch_size=64,actor_rollout_ref.rollout.name=vllm,actor_rollout_ref.rollout.tensor_model_parallel_size=1,algorithm.adv_estimator=grpo,data.max_prompt_length=4096,data.max_response_length=1024,data.shuffle_train_dataloader=True,data.train_batch_size=128,data.train_files=/home/mshahidul/readctrl/code/RL_model/verl/Search-R1/dataset/train.parquet,data.val_batch_size=64,data.val_files=/home/mshahidul/readctrl/code/RL_model/verl/Search-R1/dataset/test.parquet,do_search=false,max_turns=1,trainer.default_local_dir=verl_checkpoints/llm_guard_3B_10k_v2,trainer.experiment_name=llm_guard_3B_10k_v2,trainer.logger=[wandb],trainer.n_gpus_per_node=2,trainer.nnodes=1,trainer.project_name=,trainer.save_freq=100,trainer.test_freq=50,trainer.total_epochs=15,trainer.total_training_steps=1005
154
+ id: ???
155
+ num: ???
156
+ config_name: ppo_trainer
157
+ env_set: {}
158
+ env_copy: []
159
+ config:
160
+ override_dirname:
161
+ kv_sep: '='
162
+ item_sep: ','
163
+ exclude_keys: []
164
+ runtime:
165
+ version: 1.3.2
166
+ version_base: '1.3'
167
+ cwd: /data/home_beta/mshahidul/readctrl/code/RL_model/verl/Search-R1
168
+ config_sources:
169
+ - path: hydra.conf
170
+ schema: pkg
171
+ provider: hydra
172
+ - path: /data/home_beta/mshahidul/readctrl/code/RL_model/verl/Search-R1/verl/trainer/config
173
+ schema: file
174
+ provider: main
175
+ - path: ''
176
+ schema: structured
177
+ provider: schema
178
+ output_dir: /data/home_beta/mshahidul/readctrl/code/RL_model/verl/Search-R1/outputs/2026-02-01/20-20-59
179
+ choices:
180
+ hydra/env: default
181
+ hydra/callbacks: null
182
+ hydra/job_logging: default
183
+ hydra/hydra_logging: default
184
+ hydra/hydra_help: default
185
+ hydra/help: default
186
+ hydra/sweeper: basic
187
+ hydra/launcher: basic
188
+ hydra/output: default
189
+ verbose: false
code/RL_model/verl/Search-R1/outputs/2026-02-01/20-20-59/.hydra/overrides.yaml ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ - data.train_files=/home/mshahidul/readctrl/code/RL_model/verl/Search-R1/dataset/train.parquet
2
+ - data.val_files=/home/mshahidul/readctrl/code/RL_model/verl/Search-R1/dataset/test.parquet
3
+ - data.train_batch_size=128
4
+ - data.val_batch_size=64
5
+ - data.max_prompt_length=4096
6
+ - data.max_response_length=1024
7
+ - data.shuffle_train_dataloader=True
8
+ - algorithm.adv_estimator=grpo
9
+ - actor_rollout_ref.model.path=Qwen/Qwen3-4B-Instruct-2507
10
+ - actor_rollout_ref.model.enable_gradient_checkpointing=true
11
+ - actor_rollout_ref.model.use_remove_padding=True
12
+ - actor_rollout_ref.actor.optim.lr=1e-6
13
+ - actor_rollout_ref.actor.ppo_mini_batch_size=64
14
+ - +actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=16
15
+ - actor_rollout_ref.actor.fsdp_config.param_offload=true
16
+ - actor_rollout_ref.actor.fsdp_config.optimizer_offload=true
17
+ - actor_rollout_ref.rollout.log_prob_micro_batch_size=64
18
+ - actor_rollout_ref.rollout.tensor_model_parallel_size=1
19
+ - actor_rollout_ref.rollout.name=vllm
20
+ - actor_rollout_ref.rollout.gpu_memory_utilization=0.4
21
+ - actor_rollout_ref.ref.log_prob_micro_batch_size=64
22
+ - actor_rollout_ref.ref.fsdp_config.param_offload=True
23
+ - actor_rollout_ref.actor.kl_loss_coef=0.001
24
+ - trainer.logger=[wandb]
25
+ - trainer.n_gpus_per_node=2
26
+ - trainer.nnodes=1
27
+ - trainer.save_freq=100
28
+ - trainer.test_freq=50
29
+ - trainer.project_name=
30
+ - trainer.experiment_name=llm_guard_3B_10k_v2
31
+ - trainer.total_epochs=15
32
+ - trainer.total_training_steps=1005
33
+ - trainer.default_local_dir=verl_checkpoints/llm_guard_3B_10k_v2
34
+ - do_search=false
35
+ - max_turns=1
code/RL_model/verl/Search-R1/outputs/2026-02-01/20-20-59/main_ppo.log ADDED
File without changes
code/RL_model/verl/Search-R1/outputs/2026-02-01/20-24-15/.hydra/config.yaml ADDED
@@ -0,0 +1,169 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ data:
2
+ tokenizer: null
3
+ train_files: /home/mshahidul/readctrl/code/RL_model/verl/Search-R1/dataset/train.parquet
4
+ val_files: /home/mshahidul/readctrl/code/RL_model/verl/Search-R1/dataset/test.parquet
5
+ train_data_num: null
6
+ val_data_num: null
7
+ prompt_key: prompt
8
+ max_prompt_length: 4096
9
+ max_response_length: 1024
10
+ max_start_length: 256
11
+ max_obs_length: 512
12
+ train_batch_size: 128
13
+ val_batch_size: 64
14
+ return_raw_input_ids: false
15
+ return_raw_chat: false
16
+ shuffle_train_dataloader: true
17
+ actor_rollout_ref:
18
+ hybrid_engine: true
19
+ model:
20
+ path: Qwen/Qwen3-4B-Instruct-2507
21
+ external_lib: null
22
+ override_config: {}
23
+ enable_gradient_checkpointing: true
24
+ use_remove_padding: true
25
+ actor:
26
+ strategy: fsdp
27
+ ppo_mini_batch_size: 64
28
+ ppo_micro_batch_size: 64
29
+ use_dynamic_bsz: false
30
+ ppo_max_token_len_per_gpu: 16384
31
+ grad_clip: 1.0
32
+ state_masking: false
33
+ clip_ratio: 0.2
34
+ entropy_coeff: 0.001
35
+ use_kl_loss: false
36
+ kl_loss_coef: 0.001
37
+ kl_loss_type: low_var_kl
38
+ ppo_epochs: 1
39
+ shuffle: false
40
+ ulysses_sequence_parallel_size: 1
41
+ optim:
42
+ lr: 1.0e-06
43
+ lr_warmup_steps_ratio: 0.0
44
+ min_lr_ratio: null
45
+ warmup_style: constant
46
+ total_training_steps: -1
47
+ fsdp_config:
48
+ wrap_policy:
49
+ min_num_params: 0
50
+ param_offload: true
51
+ grad_offload: false
52
+ optimizer_offload: true
53
+ fsdp_size: -1
54
+ ppo_micro_batch_size_per_gpu: 16
55
+ ref:
56
+ fsdp_config:
57
+ param_offload: true
58
+ wrap_policy:
59
+ min_num_params: 0
60
+ fsdp_size: -1
61
+ log_prob_micro_batch_size: 64
62
+ log_prob_use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz}
63
+ log_prob_max_token_len_per_gpu: ${actor_rollout_ref.actor.ppo_max_token_len_per_gpu}
64
+ ulysses_sequence_parallel_size: ${actor_rollout_ref.actor.ulysses_sequence_parallel_size}
65
+ rollout:
66
+ name: vllm
67
+ temperature: 1.0
68
+ top_k: -1
69
+ top_p: 0.95
70
+ prompt_length: ${data.max_prompt_length}
71
+ response_length: ${data.max_response_length}
72
+ dtype: bfloat16
73
+ gpu_memory_utilization: 0.4
74
+ ignore_eos: false
75
+ enforce_eager: true
76
+ free_cache_engine: true
77
+ load_format: dummy_dtensor
78
+ tensor_model_parallel_size: 1
79
+ max_num_batched_tokens: 8192
80
+ max_num_seqs: 1024
81
+ log_prob_micro_batch_size: 64
82
+ log_prob_use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz}
83
+ log_prob_max_token_len_per_gpu: ${actor_rollout_ref.actor.ppo_max_token_len_per_gpu}
84
+ do_sample: true
85
+ 'n': 1
86
+ n_agent: 1
87
+ critic:
88
+ strategy: fsdp
89
+ optim:
90
+ lr: 1.0e-05
91
+ lr_warmup_steps_ratio: 0.0
92
+ min_lr_ratio: null
93
+ warmup_style: constant
94
+ total_training_steps: -1
95
+ model:
96
+ path: ~/models/deepseek-llm-7b-chat
97
+ tokenizer_path: ${actor_rollout_ref.model.path}
98
+ override_config: {}
99
+ external_lib: ${actor_rollout_ref.model.external_lib}
100
+ enable_gradient_checkpointing: false
101
+ use_remove_padding: false
102
+ fsdp_config:
103
+ param_offload: false
104
+ grad_offload: false
105
+ optimizer_offload: false
106
+ wrap_policy:
107
+ min_num_params: 0
108
+ fsdp_size: -1
109
+ ppo_mini_batch_size: ${actor_rollout_ref.actor.ppo_mini_batch_size}
110
+ ppo_micro_batch_size: 64
111
+ forward_micro_batch_size: ${critic.ppo_micro_batch_size}
112
+ use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz}
113
+ ppo_max_token_len_per_gpu: 32768
114
+ forward_max_token_len_per_gpu: ${critic.ppo_max_token_len_per_gpu}
115
+ ulysses_sequence_parallel_size: 1
116
+ ppo_epochs: ${actor_rollout_ref.actor.ppo_epochs}
117
+ shuffle: ${actor_rollout_ref.actor.shuffle}
118
+ grad_clip: 1.0
119
+ cliprange_value: 0.5
120
+ reward_model:
121
+ enable: false
122
+ strategy: fsdp
123
+ model:
124
+ input_tokenizer: ${actor_rollout_ref.model.path}
125
+ path: ~/models/FsfairX-LLaMA3-RM-v0.1
126
+ external_lib: ${actor_rollout_ref.model.external_lib}
127
+ use_remove_padding: false
128
+ fsdp_config:
129
+ min_num_params: 0
130
+ param_offload: false
131
+ micro_batch_size: 64
132
+ max_length: null
133
+ ulysses_sequence_parallel_size: 1
134
+ use_dynamic_bsz: ${critic.use_dynamic_bsz}
135
+ forward_max_token_len_per_gpu: ${critic.forward_max_token_len_per_gpu}
136
+ structure_format_score: 0
137
+ final_format_score: 0
138
+ retrieval_score: 0
139
+ retriever:
140
+ url: http://127.0.0.1:8000/retrieve
141
+ topk: 3
142
+ algorithm:
143
+ gamma: 1.0
144
+ lam: 1.0
145
+ adv_estimator: grpo
146
+ no_think_rl: false
147
+ kl_penalty: kl
148
+ kl_ctrl:
149
+ type: fixed
150
+ kl_coef: 0.001
151
+ state_masking:
152
+ start_state_marker: <information>
153
+ end_state_marker: </information>
154
+ trainer:
155
+ total_epochs: 15
156
+ total_training_steps: 1005
157
+ project_name: ''
158
+ experiment_name: llm_guard_3B_10k_v2
159
+ logger:
160
+ - wandb
161
+ nnodes: 1
162
+ n_gpus_per_node: 2
163
+ save_freq: 100
164
+ test_freq: 50
165
+ critic_warmup: 0
166
+ default_hdfs_dir: ~/experiments/gsm8k/ppo/${trainer.experiment_name}
167
+ default_local_dir: verl_checkpoints/llm_guard_3B_10k_v2
168
+ max_turns: 1
169
+ do_search: false
code/RL_model/verl/Search-R1/outputs/2026-02-01/20-24-15/.hydra/hydra.yaml ADDED
@@ -0,0 +1,189 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ hydra:
2
+ run:
3
+ dir: outputs/${now:%Y-%m-%d}/${now:%H-%M-%S}
4
+ sweep:
5
+ dir: multirun/${now:%Y-%m-%d}/${now:%H-%M-%S}
6
+ subdir: ${hydra.job.num}
7
+ launcher:
8
+ _target_: hydra._internal.core_plugins.basic_launcher.BasicLauncher
9
+ sweeper:
10
+ _target_: hydra._internal.core_plugins.basic_sweeper.BasicSweeper
11
+ max_batch_size: null
12
+ params: null
13
+ help:
14
+ app_name: ${hydra.job.name}
15
+ header: '${hydra.help.app_name} is powered by Hydra.
16
+
17
+ '
18
+ footer: 'Powered by Hydra (https://hydra.cc)
19
+
20
+ Use --hydra-help to view Hydra specific help
21
+
22
+ '
23
+ template: '${hydra.help.header}
24
+
25
+ == Configuration groups ==
26
+
27
+ Compose your configuration from those groups (group=option)
28
+
29
+
30
+ $APP_CONFIG_GROUPS
31
+
32
+
33
+ == Config ==
34
+
35
+ Override anything in the config (foo.bar=value)
36
+
37
+
38
+ $CONFIG
39
+
40
+
41
+ ${hydra.help.footer}
42
+
43
+ '
44
+ hydra_help:
45
+ template: 'Hydra (${hydra.runtime.version})
46
+
47
+ See https://hydra.cc for more info.
48
+
49
+
50
+ == Flags ==
51
+
52
+ $FLAGS_HELP
53
+
54
+
55
+ == Configuration groups ==
56
+
57
+ Compose your configuration from those groups (For example, append hydra/job_logging=disabled
58
+ to command line)
59
+
60
+
61
+ $HYDRA_CONFIG_GROUPS
62
+
63
+
64
+ Use ''--cfg hydra'' to Show the Hydra config.
65
+
66
+ '
67
+ hydra_help: ???
68
+ hydra_logging:
69
+ version: 1
70
+ formatters:
71
+ simple:
72
+ format: '[%(asctime)s][HYDRA] %(message)s'
73
+ handlers:
74
+ console:
75
+ class: logging.StreamHandler
76
+ formatter: simple
77
+ stream: ext://sys.stdout
78
+ root:
79
+ level: INFO
80
+ handlers:
81
+ - console
82
+ loggers:
83
+ logging_example:
84
+ level: DEBUG
85
+ disable_existing_loggers: false
86
+ job_logging:
87
+ version: 1
88
+ formatters:
89
+ simple:
90
+ format: '[%(asctime)s][%(name)s][%(levelname)s] - %(message)s'
91
+ handlers:
92
+ console:
93
+ class: logging.StreamHandler
94
+ formatter: simple
95
+ stream: ext://sys.stdout
96
+ file:
97
+ class: logging.FileHandler
98
+ formatter: simple
99
+ filename: ${hydra.runtime.output_dir}/${hydra.job.name}.log
100
+ root:
101
+ level: INFO
102
+ handlers:
103
+ - console
104
+ - file
105
+ disable_existing_loggers: false
106
+ env: {}
107
+ mode: RUN
108
+ searchpath: []
109
+ callbacks: {}
110
+ output_subdir: .hydra
111
+ overrides:
112
+ hydra:
113
+ - hydra.mode=RUN
114
+ task:
115
+ - data.train_files=/home/mshahidul/readctrl/code/RL_model/verl/Search-R1/dataset/train.parquet
116
+ - data.val_files=/home/mshahidul/readctrl/code/RL_model/verl/Search-R1/dataset/test.parquet
117
+ - data.train_batch_size=128
118
+ - data.val_batch_size=64
119
+ - data.max_prompt_length=4096
120
+ - data.max_response_length=1024
121
+ - data.shuffle_train_dataloader=True
122
+ - algorithm.adv_estimator=grpo
123
+ - actor_rollout_ref.model.path=Qwen/Qwen3-4B-Instruct-2507
124
+ - actor_rollout_ref.model.enable_gradient_checkpointing=true
125
+ - actor_rollout_ref.model.use_remove_padding=True
126
+ - actor_rollout_ref.actor.optim.lr=1e-6
127
+ - actor_rollout_ref.actor.ppo_mini_batch_size=64
128
+ - +actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=16
129
+ - actor_rollout_ref.actor.fsdp_config.param_offload=true
130
+ - actor_rollout_ref.actor.fsdp_config.optimizer_offload=true
131
+ - actor_rollout_ref.rollout.log_prob_micro_batch_size=64
132
+ - actor_rollout_ref.rollout.tensor_model_parallel_size=1
133
+ - actor_rollout_ref.rollout.name=vllm
134
+ - actor_rollout_ref.rollout.gpu_memory_utilization=0.4
135
+ - actor_rollout_ref.ref.log_prob_micro_batch_size=64
136
+ - actor_rollout_ref.ref.fsdp_config.param_offload=True
137
+ - actor_rollout_ref.actor.kl_loss_coef=0.001
138
+ - trainer.logger=[wandb]
139
+ - trainer.n_gpus_per_node=2
140
+ - trainer.nnodes=1
141
+ - trainer.save_freq=100
142
+ - trainer.test_freq=50
143
+ - trainer.project_name=
144
+ - trainer.experiment_name=llm_guard_3B_10k_v2
145
+ - trainer.total_epochs=15
146
+ - trainer.total_training_steps=1005
147
+ - trainer.default_local_dir=verl_checkpoints/llm_guard_3B_10k_v2
148
+ - do_search=false
149
+ - max_turns=1
150
+ job:
151
+ name: main_ppo
152
+ chdir: null
153
+ override_dirname: +actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=16,actor_rollout_ref.actor.fsdp_config.optimizer_offload=true,actor_rollout_ref.actor.fsdp_config.param_offload=true,actor_rollout_ref.actor.kl_loss_coef=0.001,actor_rollout_ref.actor.optim.lr=1e-6,actor_rollout_ref.actor.ppo_mini_batch_size=64,actor_rollout_ref.model.enable_gradient_checkpointing=true,actor_rollout_ref.model.path=Qwen/Qwen3-4B-Instruct-2507,actor_rollout_ref.model.use_remove_padding=True,actor_rollout_ref.ref.fsdp_config.param_offload=True,actor_rollout_ref.ref.log_prob_micro_batch_size=64,actor_rollout_ref.rollout.gpu_memory_utilization=0.4,actor_rollout_ref.rollout.log_prob_micro_batch_size=64,actor_rollout_ref.rollout.name=vllm,actor_rollout_ref.rollout.tensor_model_parallel_size=1,algorithm.adv_estimator=grpo,data.max_prompt_length=4096,data.max_response_length=1024,data.shuffle_train_dataloader=True,data.train_batch_size=128,data.train_files=/home/mshahidul/readctrl/code/RL_model/verl/Search-R1/dataset/train.parquet,data.val_batch_size=64,data.val_files=/home/mshahidul/readctrl/code/RL_model/verl/Search-R1/dataset/test.parquet,do_search=false,max_turns=1,trainer.default_local_dir=verl_checkpoints/llm_guard_3B_10k_v2,trainer.experiment_name=llm_guard_3B_10k_v2,trainer.logger=[wandb],trainer.n_gpus_per_node=2,trainer.nnodes=1,trainer.project_name=,trainer.save_freq=100,trainer.test_freq=50,trainer.total_epochs=15,trainer.total_training_steps=1005
154
+ id: ???
155
+ num: ???
156
+ config_name: ppo_trainer
157
+ env_set: {}
158
+ env_copy: []
159
+ config:
160
+ override_dirname:
161
+ kv_sep: '='
162
+ item_sep: ','
163
+ exclude_keys: []
164
+ runtime:
165
+ version: 1.3.2
166
+ version_base: '1.3'
167
+ cwd: /data/home_beta/mshahidul/readctrl/code/RL_model/verl/Search-R1
168
+ config_sources:
169
+ - path: hydra.conf
170
+ schema: pkg
171
+ provider: hydra
172
+ - path: /data/home_beta/mshahidul/readctrl/code/RL_model/verl/Search-R1/verl/trainer/config
173
+ schema: file
174
+ provider: main
175
+ - path: ''
176
+ schema: structured
177
+ provider: schema
178
+ output_dir: /data/home_beta/mshahidul/readctrl/code/RL_model/verl/Search-R1/outputs/2026-02-01/20-24-15
179
+ choices:
180
+ hydra/env: default
181
+ hydra/callbacks: null
182
+ hydra/job_logging: default
183
+ hydra/hydra_logging: default
184
+ hydra/hydra_help: default
185
+ hydra/help: default
186
+ hydra/sweeper: basic
187
+ hydra/launcher: basic
188
+ hydra/output: default
189
+ verbose: false
code/RL_model/verl/Search-R1/outputs/2026-02-01/20-24-15/.hydra/overrides.yaml ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ - data.train_files=/home/mshahidul/readctrl/code/RL_model/verl/Search-R1/dataset/train.parquet
2
+ - data.val_files=/home/mshahidul/readctrl/code/RL_model/verl/Search-R1/dataset/test.parquet
3
+ - data.train_batch_size=128
4
+ - data.val_batch_size=64
5
+ - data.max_prompt_length=4096
6
+ - data.max_response_length=1024
7
+ - data.shuffle_train_dataloader=True
8
+ - algorithm.adv_estimator=grpo
9
+ - actor_rollout_ref.model.path=Qwen/Qwen3-4B-Instruct-2507
10
+ - actor_rollout_ref.model.enable_gradient_checkpointing=true
11
+ - actor_rollout_ref.model.use_remove_padding=True
12
+ - actor_rollout_ref.actor.optim.lr=1e-6
13
+ - actor_rollout_ref.actor.ppo_mini_batch_size=64
14
+ - +actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=16
15
+ - actor_rollout_ref.actor.fsdp_config.param_offload=true
16
+ - actor_rollout_ref.actor.fsdp_config.optimizer_offload=true
17
+ - actor_rollout_ref.rollout.log_prob_micro_batch_size=64
18
+ - actor_rollout_ref.rollout.tensor_model_parallel_size=1
19
+ - actor_rollout_ref.rollout.name=vllm
20
+ - actor_rollout_ref.rollout.gpu_memory_utilization=0.4
21
+ - actor_rollout_ref.ref.log_prob_micro_batch_size=64
22
+ - actor_rollout_ref.ref.fsdp_config.param_offload=True
23
+ - actor_rollout_ref.actor.kl_loss_coef=0.001
24
+ - trainer.logger=[wandb]
25
+ - trainer.n_gpus_per_node=2
26
+ - trainer.nnodes=1
27
+ - trainer.save_freq=100
28
+ - trainer.test_freq=50
29
+ - trainer.project_name=
30
+ - trainer.experiment_name=llm_guard_3B_10k_v2
31
+ - trainer.total_epochs=15
32
+ - trainer.total_training_steps=1005
33
+ - trainer.default_local_dir=verl_checkpoints/llm_guard_3B_10k_v2
34
+ - do_search=false
35
+ - max_turns=1
code/RL_model/verl/Search-R1/outputs/2026-02-01/20-24-15/main_ppo.log ADDED
File without changes
code/RL_model/verl/Search-R1/outputs/2026-02-01/20-26-44/.hydra/config.yaml ADDED
@@ -0,0 +1,169 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ data:
2
+ tokenizer: null
3
+ train_files: /home/mshahidul/readctrl/code/RL_model/verl/Search-R1/dataset/train.parquet
4
+ val_files: /home/mshahidul/readctrl/code/RL_model/verl/Search-R1/dataset/test.parquet
5
+ train_data_num: null
6
+ val_data_num: null
7
+ prompt_key: prompt
8
+ max_prompt_length: 4096
9
+ max_response_length: 1024
10
+ max_start_length: 256
11
+ max_obs_length: 512
12
+ train_batch_size: 128
13
+ val_batch_size: 64
14
+ return_raw_input_ids: false
15
+ return_raw_chat: false
16
+ shuffle_train_dataloader: true
17
+ actor_rollout_ref:
18
+ hybrid_engine: true
19
+ model:
20
+ path: Qwen/Qwen3-4B-Instruct-2507
21
+ external_lib: null
22
+ override_config: {}
23
+ enable_gradient_checkpointing: true
24
+ use_remove_padding: false
25
+ actor:
26
+ strategy: fsdp
27
+ ppo_mini_batch_size: 64
28
+ ppo_micro_batch_size: 64
29
+ use_dynamic_bsz: false
30
+ ppo_max_token_len_per_gpu: 16384
31
+ grad_clip: 1.0
32
+ state_masking: false
33
+ clip_ratio: 0.2
34
+ entropy_coeff: 0.001
35
+ use_kl_loss: false
36
+ kl_loss_coef: 0.001
37
+ kl_loss_type: low_var_kl
38
+ ppo_epochs: 1
39
+ shuffle: false
40
+ ulysses_sequence_parallel_size: 1
41
+ optim:
42
+ lr: 1.0e-06
43
+ lr_warmup_steps_ratio: 0.0
44
+ min_lr_ratio: null
45
+ warmup_style: constant
46
+ total_training_steps: -1
47
+ fsdp_config:
48
+ wrap_policy:
49
+ min_num_params: 0
50
+ param_offload: true
51
+ grad_offload: false
52
+ optimizer_offload: true
53
+ fsdp_size: -1
54
+ ppo_micro_batch_size_per_gpu: 16
55
+ ref:
56
+ fsdp_config:
57
+ param_offload: true
58
+ wrap_policy:
59
+ min_num_params: 0
60
+ fsdp_size: -1
61
+ log_prob_micro_batch_size: 64
62
+ log_prob_use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz}
63
+ log_prob_max_token_len_per_gpu: ${actor_rollout_ref.actor.ppo_max_token_len_per_gpu}
64
+ ulysses_sequence_parallel_size: ${actor_rollout_ref.actor.ulysses_sequence_parallel_size}
65
+ rollout:
66
+ name: vllm
67
+ temperature: 1.0
68
+ top_k: -1
69
+ top_p: 0.95
70
+ prompt_length: ${data.max_prompt_length}
71
+ response_length: ${data.max_response_length}
72
+ dtype: bfloat16
73
+ gpu_memory_utilization: 0.4
74
+ ignore_eos: false
75
+ enforce_eager: true
76
+ free_cache_engine: true
77
+ load_format: dummy_dtensor
78
+ tensor_model_parallel_size: 1
79
+ max_num_batched_tokens: 8192
80
+ max_num_seqs: 1024
81
+ log_prob_micro_batch_size: 64
82
+ log_prob_use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz}
83
+ log_prob_max_token_len_per_gpu: ${actor_rollout_ref.actor.ppo_max_token_len_per_gpu}
84
+ do_sample: true
85
+ 'n': 1
86
+ n_agent: 1
87
+ critic:
88
+ strategy: fsdp
89
+ optim:
90
+ lr: 1.0e-05
91
+ lr_warmup_steps_ratio: 0.0
92
+ min_lr_ratio: null
93
+ warmup_style: constant
94
+ total_training_steps: -1
95
+ model:
96
+ path: ~/models/deepseek-llm-7b-chat
97
+ tokenizer_path: ${actor_rollout_ref.model.path}
98
+ override_config: {}
99
+ external_lib: ${actor_rollout_ref.model.external_lib}
100
+ enable_gradient_checkpointing: false
101
+ use_remove_padding: false
102
+ fsdp_config:
103
+ param_offload: false
104
+ grad_offload: false
105
+ optimizer_offload: false
106
+ wrap_policy:
107
+ min_num_params: 0
108
+ fsdp_size: -1
109
+ ppo_mini_batch_size: ${actor_rollout_ref.actor.ppo_mini_batch_size}
110
+ ppo_micro_batch_size: 64
111
+ forward_micro_batch_size: ${critic.ppo_micro_batch_size}
112
+ use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz}
113
+ ppo_max_token_len_per_gpu: 32768
114
+ forward_max_token_len_per_gpu: ${critic.ppo_max_token_len_per_gpu}
115
+ ulysses_sequence_parallel_size: 1
116
+ ppo_epochs: ${actor_rollout_ref.actor.ppo_epochs}
117
+ shuffle: ${actor_rollout_ref.actor.shuffle}
118
+ grad_clip: 1.0
119
+ cliprange_value: 0.5
120
+ reward_model:
121
+ enable: false
122
+ strategy: fsdp
123
+ model:
124
+ input_tokenizer: ${actor_rollout_ref.model.path}
125
+ path: ~/models/FsfairX-LLaMA3-RM-v0.1
126
+ external_lib: ${actor_rollout_ref.model.external_lib}
127
+ use_remove_padding: false
128
+ fsdp_config:
129
+ min_num_params: 0
130
+ param_offload: false
131
+ micro_batch_size: 64
132
+ max_length: null
133
+ ulysses_sequence_parallel_size: 1
134
+ use_dynamic_bsz: ${critic.use_dynamic_bsz}
135
+ forward_max_token_len_per_gpu: ${critic.forward_max_token_len_per_gpu}
136
+ structure_format_score: 0
137
+ final_format_score: 0
138
+ retrieval_score: 0
139
+ retriever:
140
+ url: http://127.0.0.1:8000/retrieve
141
+ topk: 3
142
+ algorithm:
143
+ gamma: 1.0
144
+ lam: 1.0
145
+ adv_estimator: grpo
146
+ no_think_rl: false
147
+ kl_penalty: kl
148
+ kl_ctrl:
149
+ type: fixed
150
+ kl_coef: 0.001
151
+ state_masking:
152
+ start_state_marker: <information>
153
+ end_state_marker: </information>
154
+ trainer:
155
+ total_epochs: 15
156
+ total_training_steps: 1005
157
+ project_name: ''
158
+ experiment_name: llm_guard_3B_10k_v2
159
+ logger:
160
+ - wandb
161
+ nnodes: 1
162
+ n_gpus_per_node: 2
163
+ save_freq: 100
164
+ test_freq: 50
165
+ critic_warmup: 0
166
+ default_hdfs_dir: ~/experiments/gsm8k/ppo/${trainer.experiment_name}
167
+ default_local_dir: verl_checkpoints/llm_guard_3B_10k_v2
168
+ max_turns: 1
169
+ do_search: false
code/RL_model/verl/Search-R1/outputs/2026-02-01/20-26-44/.hydra/hydra.yaml ADDED
@@ -0,0 +1,189 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ hydra:
2
+ run:
3
+ dir: outputs/${now:%Y-%m-%d}/${now:%H-%M-%S}
4
+ sweep:
5
+ dir: multirun/${now:%Y-%m-%d}/${now:%H-%M-%S}
6
+ subdir: ${hydra.job.num}
7
+ launcher:
8
+ _target_: hydra._internal.core_plugins.basic_launcher.BasicLauncher
9
+ sweeper:
10
+ _target_: hydra._internal.core_plugins.basic_sweeper.BasicSweeper
11
+ max_batch_size: null
12
+ params: null
13
+ help:
14
+ app_name: ${hydra.job.name}
15
+ header: '${hydra.help.app_name} is powered by Hydra.
16
+
17
+ '
18
+ footer: 'Powered by Hydra (https://hydra.cc)
19
+
20
+ Use --hydra-help to view Hydra specific help
21
+
22
+ '
23
+ template: '${hydra.help.header}
24
+
25
+ == Configuration groups ==
26
+
27
+ Compose your configuration from those groups (group=option)
28
+
29
+
30
+ $APP_CONFIG_GROUPS
31
+
32
+
33
+ == Config ==
34
+
35
+ Override anything in the config (foo.bar=value)
36
+
37
+
38
+ $CONFIG
39
+
40
+
41
+ ${hydra.help.footer}
42
+
43
+ '
44
+ hydra_help:
45
+ template: 'Hydra (${hydra.runtime.version})
46
+
47
+ See https://hydra.cc for more info.
48
+
49
+
50
+ == Flags ==
51
+
52
+ $FLAGS_HELP
53
+
54
+
55
+ == Configuration groups ==
56
+
57
+ Compose your configuration from those groups (For example, append hydra/job_logging=disabled
58
+ to command line)
59
+
60
+
61
+ $HYDRA_CONFIG_GROUPS
62
+
63
+
64
+ Use ''--cfg hydra'' to Show the Hydra config.
65
+
66
+ '
67
+ hydra_help: ???
68
+ hydra_logging:
69
+ version: 1
70
+ formatters:
71
+ simple:
72
+ format: '[%(asctime)s][HYDRA] %(message)s'
73
+ handlers:
74
+ console:
75
+ class: logging.StreamHandler
76
+ formatter: simple
77
+ stream: ext://sys.stdout
78
+ root:
79
+ level: INFO
80
+ handlers:
81
+ - console
82
+ loggers:
83
+ logging_example:
84
+ level: DEBUG
85
+ disable_existing_loggers: false
86
+ job_logging:
87
+ version: 1
88
+ formatters:
89
+ simple:
90
+ format: '[%(asctime)s][%(name)s][%(levelname)s] - %(message)s'
91
+ handlers:
92
+ console:
93
+ class: logging.StreamHandler
94
+ formatter: simple
95
+ stream: ext://sys.stdout
96
+ file:
97
+ class: logging.FileHandler
98
+ formatter: simple
99
+ filename: ${hydra.runtime.output_dir}/${hydra.job.name}.log
100
+ root:
101
+ level: INFO
102
+ handlers:
103
+ - console
104
+ - file
105
+ disable_existing_loggers: false
106
+ env: {}
107
+ mode: RUN
108
+ searchpath: []
109
+ callbacks: {}
110
+ output_subdir: .hydra
111
+ overrides:
112
+ hydra:
113
+ - hydra.mode=RUN
114
+ task:
115
+ - data.train_files=/home/mshahidul/readctrl/code/RL_model/verl/Search-R1/dataset/train.parquet
116
+ - data.val_files=/home/mshahidul/readctrl/code/RL_model/verl/Search-R1/dataset/test.parquet
117
+ - data.train_batch_size=128
118
+ - data.val_batch_size=64
119
+ - data.max_prompt_length=4096
120
+ - data.max_response_length=1024
121
+ - data.shuffle_train_dataloader=True
122
+ - algorithm.adv_estimator=grpo
123
+ - actor_rollout_ref.model.path=Qwen/Qwen3-4B-Instruct-2507
124
+ - actor_rollout_ref.model.enable_gradient_checkpointing=true
125
+ - actor_rollout_ref.model.use_remove_padding=False
126
+ - actor_rollout_ref.actor.optim.lr=1e-6
127
+ - actor_rollout_ref.actor.ppo_mini_batch_size=64
128
+ - +actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=16
129
+ - actor_rollout_ref.actor.fsdp_config.param_offload=true
130
+ - actor_rollout_ref.actor.fsdp_config.optimizer_offload=true
131
+ - actor_rollout_ref.rollout.log_prob_micro_batch_size=64
132
+ - actor_rollout_ref.rollout.tensor_model_parallel_size=1
133
+ - actor_rollout_ref.rollout.name=vllm
134
+ - actor_rollout_ref.rollout.gpu_memory_utilization=0.4
135
+ - actor_rollout_ref.ref.log_prob_micro_batch_size=64
136
+ - actor_rollout_ref.ref.fsdp_config.param_offload=True
137
+ - actor_rollout_ref.actor.kl_loss_coef=0.001
138
+ - trainer.logger=[wandb]
139
+ - trainer.n_gpus_per_node=2
140
+ - trainer.nnodes=1
141
+ - trainer.save_freq=100
142
+ - trainer.test_freq=50
143
+ - trainer.project_name=
144
+ - trainer.experiment_name=llm_guard_3B_10k_v2
145
+ - trainer.total_epochs=15
146
+ - trainer.total_training_steps=1005
147
+ - trainer.default_local_dir=verl_checkpoints/llm_guard_3B_10k_v2
148
+ - do_search=false
149
+ - max_turns=1
150
+ job:
151
+ name: main_ppo
152
+ chdir: null
153
+ override_dirname: +actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=16,actor_rollout_ref.actor.fsdp_config.optimizer_offload=true,actor_rollout_ref.actor.fsdp_config.param_offload=true,actor_rollout_ref.actor.kl_loss_coef=0.001,actor_rollout_ref.actor.optim.lr=1e-6,actor_rollout_ref.actor.ppo_mini_batch_size=64,actor_rollout_ref.model.enable_gradient_checkpointing=true,actor_rollout_ref.model.path=Qwen/Qwen3-4B-Instruct-2507,actor_rollout_ref.model.use_remove_padding=False,actor_rollout_ref.ref.fsdp_config.param_offload=True,actor_rollout_ref.ref.log_prob_micro_batch_size=64,actor_rollout_ref.rollout.gpu_memory_utilization=0.4,actor_rollout_ref.rollout.log_prob_micro_batch_size=64,actor_rollout_ref.rollout.name=vllm,actor_rollout_ref.rollout.tensor_model_parallel_size=1,algorithm.adv_estimator=grpo,data.max_prompt_length=4096,data.max_response_length=1024,data.shuffle_train_dataloader=True,data.train_batch_size=128,data.train_files=/home/mshahidul/readctrl/code/RL_model/verl/Search-R1/dataset/train.parquet,data.val_batch_size=64,data.val_files=/home/mshahidul/readctrl/code/RL_model/verl/Search-R1/dataset/test.parquet,do_search=false,max_turns=1,trainer.default_local_dir=verl_checkpoints/llm_guard_3B_10k_v2,trainer.experiment_name=llm_guard_3B_10k_v2,trainer.logger=[wandb],trainer.n_gpus_per_node=2,trainer.nnodes=1,trainer.project_name=,trainer.save_freq=100,trainer.test_freq=50,trainer.total_epochs=15,trainer.total_training_steps=1005
154
+ id: ???
155
+ num: ???
156
+ config_name: ppo_trainer
157
+ env_set: {}
158
+ env_copy: []
159
+ config:
160
+ override_dirname:
161
+ kv_sep: '='
162
+ item_sep: ','
163
+ exclude_keys: []
164
+ runtime:
165
+ version: 1.3.2
166
+ version_base: '1.3'
167
+ cwd: /data/home_beta/mshahidul/readctrl/code/RL_model/verl/Search-R1
168
+ config_sources:
169
+ - path: hydra.conf
170
+ schema: pkg
171
+ provider: hydra
172
+ - path: /data/home_beta/mshahidul/readctrl/code/RL_model/verl/Search-R1/verl/trainer/config
173
+ schema: file
174
+ provider: main
175
+ - path: ''
176
+ schema: structured
177
+ provider: schema
178
+ output_dir: /data/home_beta/mshahidul/readctrl/code/RL_model/verl/Search-R1/outputs/2026-02-01/20-26-44
179
+ choices:
180
+ hydra/env: default
181
+ hydra/callbacks: null
182
+ hydra/job_logging: default
183
+ hydra/hydra_logging: default
184
+ hydra/hydra_help: default
185
+ hydra/help: default
186
+ hydra/sweeper: basic
187
+ hydra/launcher: basic
188
+ hydra/output: default
189
+ verbose: false
code/RL_model/verl/Search-R1/outputs/2026-02-01/20-26-44/.hydra/overrides.yaml ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ - data.train_files=/home/mshahidul/readctrl/code/RL_model/verl/Search-R1/dataset/train.parquet
2
+ - data.val_files=/home/mshahidul/readctrl/code/RL_model/verl/Search-R1/dataset/test.parquet
3
+ - data.train_batch_size=128
4
+ - data.val_batch_size=64
5
+ - data.max_prompt_length=4096
6
+ - data.max_response_length=1024
7
+ - data.shuffle_train_dataloader=True
8
+ - algorithm.adv_estimator=grpo
9
+ - actor_rollout_ref.model.path=Qwen/Qwen3-4B-Instruct-2507
10
+ - actor_rollout_ref.model.enable_gradient_checkpointing=true
11
+ - actor_rollout_ref.model.use_remove_padding=False
12
+ - actor_rollout_ref.actor.optim.lr=1e-6
13
+ - actor_rollout_ref.actor.ppo_mini_batch_size=64
14
+ - +actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=16
15
+ - actor_rollout_ref.actor.fsdp_config.param_offload=true
16
+ - actor_rollout_ref.actor.fsdp_config.optimizer_offload=true
17
+ - actor_rollout_ref.rollout.log_prob_micro_batch_size=64
18
+ - actor_rollout_ref.rollout.tensor_model_parallel_size=1
19
+ - actor_rollout_ref.rollout.name=vllm
20
+ - actor_rollout_ref.rollout.gpu_memory_utilization=0.4
21
+ - actor_rollout_ref.ref.log_prob_micro_batch_size=64
22
+ - actor_rollout_ref.ref.fsdp_config.param_offload=True
23
+ - actor_rollout_ref.actor.kl_loss_coef=0.001
24
+ - trainer.logger=[wandb]
25
+ - trainer.n_gpus_per_node=2
26
+ - trainer.nnodes=1
27
+ - trainer.save_freq=100
28
+ - trainer.test_freq=50
29
+ - trainer.project_name=
30
+ - trainer.experiment_name=llm_guard_3B_10k_v2
31
+ - trainer.total_epochs=15
32
+ - trainer.total_training_steps=1005
33
+ - trainer.default_local_dir=verl_checkpoints/llm_guard_3B_10k_v2
34
+ - do_search=false
35
+ - max_turns=1
code/RL_model/verl/Search-R1/outputs/2026-02-01/20-33-01/.hydra/config.yaml ADDED
@@ -0,0 +1,169 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ data:
2
+ tokenizer: null
3
+ train_files: /home/mshahidul/readctrl/code/RL_model/verl/Search-R1/dataset/train.parquet
4
+ val_files: /home/mshahidul/readctrl/code/RL_model/verl/Search-R1/dataset/test.parquet
5
+ train_data_num: null
6
+ val_data_num: null
7
+ prompt_key: prompt
8
+ max_prompt_length: 4096
9
+ max_response_length: 1024
10
+ max_start_length: 256
11
+ max_obs_length: 512
12
+ train_batch_size: 128
13
+ val_batch_size: 64
14
+ return_raw_input_ids: false
15
+ return_raw_chat: false
16
+ shuffle_train_dataloader: true
17
+ actor_rollout_ref:
18
+ hybrid_engine: true
19
+ model:
20
+ path: Qwen/Qwen3-4B-Instruct-2507
21
+ external_lib: null
22
+ override_config: {}
23
+ enable_gradient_checkpointing: true
24
+ use_remove_padding: false
25
+ actor:
26
+ strategy: fsdp
27
+ ppo_mini_batch_size: 64
28
+ ppo_micro_batch_size: 64
29
+ use_dynamic_bsz: false
30
+ ppo_max_token_len_per_gpu: 16384
31
+ grad_clip: 1.0
32
+ state_masking: false
33
+ clip_ratio: 0.2
34
+ entropy_coeff: 0.001
35
+ use_kl_loss: false
36
+ kl_loss_coef: 0.001
37
+ kl_loss_type: low_var_kl
38
+ ppo_epochs: 1
39
+ shuffle: false
40
+ ulysses_sequence_parallel_size: 1
41
+ optim:
42
+ lr: 1.0e-06
43
+ lr_warmup_steps_ratio: 0.0
44
+ min_lr_ratio: null
45
+ warmup_style: constant
46
+ total_training_steps: -1
47
+ fsdp_config:
48
+ wrap_policy:
49
+ min_num_params: 0
50
+ param_offload: true
51
+ grad_offload: false
52
+ optimizer_offload: true
53
+ fsdp_size: -1
54
+ ppo_micro_batch_size_per_gpu: 16
55
+ ref:
56
+ fsdp_config:
57
+ param_offload: true
58
+ wrap_policy:
59
+ min_num_params: 0
60
+ fsdp_size: -1
61
+ log_prob_micro_batch_size: 64
62
+ log_prob_use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz}
63
+ log_prob_max_token_len_per_gpu: ${actor_rollout_ref.actor.ppo_max_token_len_per_gpu}
64
+ ulysses_sequence_parallel_size: ${actor_rollout_ref.actor.ulysses_sequence_parallel_size}
65
+ rollout:
66
+ name: vllm
67
+ temperature: 1.0
68
+ top_k: -1
69
+ top_p: 0.95
70
+ prompt_length: ${data.max_prompt_length}
71
+ response_length: ${data.max_response_length}
72
+ dtype: bfloat16
73
+ gpu_memory_utilization: 0.4
74
+ ignore_eos: false
75
+ enforce_eager: true
76
+ free_cache_engine: true
77
+ load_format: dummy_dtensor
78
+ tensor_model_parallel_size: 1
79
+ max_num_batched_tokens: 8192
80
+ max_num_seqs: 1024
81
+ log_prob_micro_batch_size: 64
82
+ log_prob_use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz}
83
+ log_prob_max_token_len_per_gpu: ${actor_rollout_ref.actor.ppo_max_token_len_per_gpu}
84
+ do_sample: true
85
+ 'n': 1
86
+ n_agent: 1
87
+ critic:
88
+ strategy: fsdp
89
+ optim:
90
+ lr: 1.0e-05
91
+ lr_warmup_steps_ratio: 0.0
92
+ min_lr_ratio: null
93
+ warmup_style: constant
94
+ total_training_steps: -1
95
+ model:
96
+ path: ~/models/deepseek-llm-7b-chat
97
+ tokenizer_path: ${actor_rollout_ref.model.path}
98
+ override_config: {}
99
+ external_lib: ${actor_rollout_ref.model.external_lib}
100
+ enable_gradient_checkpointing: false
101
+ use_remove_padding: false
102
+ fsdp_config:
103
+ param_offload: false
104
+ grad_offload: false
105
+ optimizer_offload: false
106
+ wrap_policy:
107
+ min_num_params: 0
108
+ fsdp_size: -1
109
+ ppo_mini_batch_size: ${actor_rollout_ref.actor.ppo_mini_batch_size}
110
+ ppo_micro_batch_size: 64
111
+ forward_micro_batch_size: ${critic.ppo_micro_batch_size}
112
+ use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz}
113
+ ppo_max_token_len_per_gpu: 32768
114
+ forward_max_token_len_per_gpu: ${critic.ppo_max_token_len_per_gpu}
115
+ ulysses_sequence_parallel_size: 1
116
+ ppo_epochs: ${actor_rollout_ref.actor.ppo_epochs}
117
+ shuffle: ${actor_rollout_ref.actor.shuffle}
118
+ grad_clip: 1.0
119
+ cliprange_value: 0.5
120
+ reward_model:
121
+ enable: false
122
+ strategy: fsdp
123
+ model:
124
+ input_tokenizer: ${actor_rollout_ref.model.path}
125
+ path: ~/models/FsfairX-LLaMA3-RM-v0.1
126
+ external_lib: ${actor_rollout_ref.model.external_lib}
127
+ use_remove_padding: false
128
+ fsdp_config:
129
+ min_num_params: 0
130
+ param_offload: false
131
+ micro_batch_size: 64
132
+ max_length: null
133
+ ulysses_sequence_parallel_size: 1
134
+ use_dynamic_bsz: ${critic.use_dynamic_bsz}
135
+ forward_max_token_len_per_gpu: ${critic.forward_max_token_len_per_gpu}
136
+ structure_format_score: 0
137
+ final_format_score: 0
138
+ retrieval_score: 0
139
+ retriever:
140
+ url: http://127.0.0.1:8000/retrieve
141
+ topk: 3
142
+ algorithm:
143
+ gamma: 1.0
144
+ lam: 1.0
145
+ adv_estimator: grpo
146
+ no_think_rl: false
147
+ kl_penalty: kl
148
+ kl_ctrl:
149
+ type: fixed
150
+ kl_coef: 0.001
151
+ state_masking:
152
+ start_state_marker: <information>
153
+ end_state_marker: </information>
154
+ trainer:
155
+ total_epochs: 15
156
+ total_training_steps: 1005
157
+ project_name: ''
158
+ experiment_name: llm_guard_3B_10k_v2
159
+ logger:
160
+ - wandb
161
+ nnodes: 1
162
+ n_gpus_per_node: 2
163
+ save_freq: 100
164
+ test_freq: 50
165
+ critic_warmup: 0
166
+ default_hdfs_dir: ~/experiments/gsm8k/ppo/${trainer.experiment_name}
167
+ default_local_dir: verl_checkpoints/llm_guard_3B_10k_v2
168
+ max_turns: 1
169
+ do_search: false
code/RL_model/verl/Search-R1/outputs/2026-02-01/20-33-01/.hydra/hydra.yaml ADDED
@@ -0,0 +1,189 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ hydra:
2
+ run:
3
+ dir: outputs/${now:%Y-%m-%d}/${now:%H-%M-%S}
4
+ sweep:
5
+ dir: multirun/${now:%Y-%m-%d}/${now:%H-%M-%S}
6
+ subdir: ${hydra.job.num}
7
+ launcher:
8
+ _target_: hydra._internal.core_plugins.basic_launcher.BasicLauncher
9
+ sweeper:
10
+ _target_: hydra._internal.core_plugins.basic_sweeper.BasicSweeper
11
+ max_batch_size: null
12
+ params: null
13
+ help:
14
+ app_name: ${hydra.job.name}
15
+ header: '${hydra.help.app_name} is powered by Hydra.
16
+
17
+ '
18
+ footer: 'Powered by Hydra (https://hydra.cc)
19
+
20
+ Use --hydra-help to view Hydra specific help
21
+
22
+ '
23
+ template: '${hydra.help.header}
24
+
25
+ == Configuration groups ==
26
+
27
+ Compose your configuration from those groups (group=option)
28
+
29
+
30
+ $APP_CONFIG_GROUPS
31
+
32
+
33
+ == Config ==
34
+
35
+ Override anything in the config (foo.bar=value)
36
+
37
+
38
+ $CONFIG
39
+
40
+
41
+ ${hydra.help.footer}
42
+
43
+ '
44
+ hydra_help:
45
+ template: 'Hydra (${hydra.runtime.version})
46
+
47
+ See https://hydra.cc for more info.
48
+
49
+
50
+ == Flags ==
51
+
52
+ $FLAGS_HELP
53
+
54
+
55
+ == Configuration groups ==
56
+
57
+ Compose your configuration from those groups (For example, append hydra/job_logging=disabled
58
+ to command line)
59
+
60
+
61
+ $HYDRA_CONFIG_GROUPS
62
+
63
+
64
+ Use ''--cfg hydra'' to Show the Hydra config.
65
+
66
+ '
67
+ hydra_help: ???
68
+ hydra_logging:
69
+ version: 1
70
+ formatters:
71
+ simple:
72
+ format: '[%(asctime)s][HYDRA] %(message)s'
73
+ handlers:
74
+ console:
75
+ class: logging.StreamHandler
76
+ formatter: simple
77
+ stream: ext://sys.stdout
78
+ root:
79
+ level: INFO
80
+ handlers:
81
+ - console
82
+ loggers:
83
+ logging_example:
84
+ level: DEBUG
85
+ disable_existing_loggers: false
86
+ job_logging:
87
+ version: 1
88
+ formatters:
89
+ simple:
90
+ format: '[%(asctime)s][%(name)s][%(levelname)s] - %(message)s'
91
+ handlers:
92
+ console:
93
+ class: logging.StreamHandler
94
+ formatter: simple
95
+ stream: ext://sys.stdout
96
+ file:
97
+ class: logging.FileHandler
98
+ formatter: simple
99
+ filename: ${hydra.runtime.output_dir}/${hydra.job.name}.log
100
+ root:
101
+ level: INFO
102
+ handlers:
103
+ - console
104
+ - file
105
+ disable_existing_loggers: false
106
+ env: {}
107
+ mode: RUN
108
+ searchpath: []
109
+ callbacks: {}
110
+ output_subdir: .hydra
111
+ overrides:
112
+ hydra:
113
+ - hydra.mode=RUN
114
+ task:
115
+ - data.train_files=/home/mshahidul/readctrl/code/RL_model/verl/Search-R1/dataset/train.parquet
116
+ - data.val_files=/home/mshahidul/readctrl/code/RL_model/verl/Search-R1/dataset/test.parquet
117
+ - data.train_batch_size=128
118
+ - data.val_batch_size=64
119
+ - data.max_prompt_length=4096
120
+ - data.max_response_length=1024
121
+ - data.shuffle_train_dataloader=True
122
+ - algorithm.adv_estimator=grpo
123
+ - actor_rollout_ref.model.path=Qwen/Qwen3-4B-Instruct-2507
124
+ - actor_rollout_ref.model.enable_gradient_checkpointing=true
125
+ - actor_rollout_ref.model.use_remove_padding=False
126
+ - actor_rollout_ref.actor.optim.lr=1e-6
127
+ - actor_rollout_ref.actor.ppo_mini_batch_size=64
128
+ - +actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=16
129
+ - actor_rollout_ref.actor.fsdp_config.param_offload=true
130
+ - actor_rollout_ref.actor.fsdp_config.optimizer_offload=true
131
+ - actor_rollout_ref.rollout.log_prob_micro_batch_size=64
132
+ - actor_rollout_ref.rollout.tensor_model_parallel_size=1
133
+ - actor_rollout_ref.rollout.name=vllm
134
+ - actor_rollout_ref.rollout.gpu_memory_utilization=0.4
135
+ - actor_rollout_ref.ref.log_prob_micro_batch_size=64
136
+ - actor_rollout_ref.ref.fsdp_config.param_offload=True
137
+ - actor_rollout_ref.actor.kl_loss_coef=0.001
138
+ - trainer.logger=[wandb]
139
+ - trainer.n_gpus_per_node=2
140
+ - trainer.nnodes=1
141
+ - trainer.save_freq=100
142
+ - trainer.test_freq=50
143
+ - trainer.project_name=
144
+ - trainer.experiment_name=llm_guard_3B_10k_v2
145
+ - trainer.total_epochs=15
146
+ - trainer.total_training_steps=1005
147
+ - trainer.default_local_dir=verl_checkpoints/llm_guard_3B_10k_v2
148
+ - do_search=false
149
+ - max_turns=1
150
+ job:
151
+ name: main_ppo
152
+ chdir: null
153
+ override_dirname: +actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=16,actor_rollout_ref.actor.fsdp_config.optimizer_offload=true,actor_rollout_ref.actor.fsdp_config.param_offload=true,actor_rollout_ref.actor.kl_loss_coef=0.001,actor_rollout_ref.actor.optim.lr=1e-6,actor_rollout_ref.actor.ppo_mini_batch_size=64,actor_rollout_ref.model.enable_gradient_checkpointing=true,actor_rollout_ref.model.path=Qwen/Qwen3-4B-Instruct-2507,actor_rollout_ref.model.use_remove_padding=False,actor_rollout_ref.ref.fsdp_config.param_offload=True,actor_rollout_ref.ref.log_prob_micro_batch_size=64,actor_rollout_ref.rollout.gpu_memory_utilization=0.4,actor_rollout_ref.rollout.log_prob_micro_batch_size=64,actor_rollout_ref.rollout.name=vllm,actor_rollout_ref.rollout.tensor_model_parallel_size=1,algorithm.adv_estimator=grpo,data.max_prompt_length=4096,data.max_response_length=1024,data.shuffle_train_dataloader=True,data.train_batch_size=128,data.train_files=/home/mshahidul/readctrl/code/RL_model/verl/Search-R1/dataset/train.parquet,data.val_batch_size=64,data.val_files=/home/mshahidul/readctrl/code/RL_model/verl/Search-R1/dataset/test.parquet,do_search=false,max_turns=1,trainer.default_local_dir=verl_checkpoints/llm_guard_3B_10k_v2,trainer.experiment_name=llm_guard_3B_10k_v2,trainer.logger=[wandb],trainer.n_gpus_per_node=2,trainer.nnodes=1,trainer.project_name=,trainer.save_freq=100,trainer.test_freq=50,trainer.total_epochs=15,trainer.total_training_steps=1005
154
+ id: ???
155
+ num: ???
156
+ config_name: ppo_trainer
157
+ env_set: {}
158
+ env_copy: []
159
+ config:
160
+ override_dirname:
161
+ kv_sep: '='
162
+ item_sep: ','
163
+ exclude_keys: []
164
+ runtime:
165
+ version: 1.3.2
166
+ version_base: '1.3'
167
+ cwd: /data/home_beta/mshahidul/readctrl/code/RL_model/verl/Search-R1
168
+ config_sources:
169
+ - path: hydra.conf
170
+ schema: pkg
171
+ provider: hydra
172
+ - path: /data/home_beta/mshahidul/readctrl/code/RL_model/verl/Search-R1/verl/trainer/config
173
+ schema: file
174
+ provider: main
175
+ - path: ''
176
+ schema: structured
177
+ provider: schema
178
+ output_dir: /data/home_beta/mshahidul/readctrl/code/RL_model/verl/Search-R1/outputs/2026-02-01/20-33-01
179
+ choices:
180
+ hydra/env: default
181
+ hydra/callbacks: null
182
+ hydra/job_logging: default
183
+ hydra/hydra_logging: default
184
+ hydra/hydra_help: default
185
+ hydra/help: default
186
+ hydra/sweeper: basic
187
+ hydra/launcher: basic
188
+ hydra/output: default
189
+ verbose: false
code/RL_model/verl/Search-R1/outputs/2026-02-01/20-33-01/.hydra/overrides.yaml ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ - data.train_files=/home/mshahidul/readctrl/code/RL_model/verl/Search-R1/dataset/train.parquet
2
+ - data.val_files=/home/mshahidul/readctrl/code/RL_model/verl/Search-R1/dataset/test.parquet
3
+ - data.train_batch_size=128
4
+ - data.val_batch_size=64
5
+ - data.max_prompt_length=4096
6
+ - data.max_response_length=1024
7
+ - data.shuffle_train_dataloader=True
8
+ - algorithm.adv_estimator=grpo
9
+ - actor_rollout_ref.model.path=Qwen/Qwen3-4B-Instruct-2507
10
+ - actor_rollout_ref.model.enable_gradient_checkpointing=true
11
+ - actor_rollout_ref.model.use_remove_padding=False
12
+ - actor_rollout_ref.actor.optim.lr=1e-6
13
+ - actor_rollout_ref.actor.ppo_mini_batch_size=64
14
+ - +actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=16
15
+ - actor_rollout_ref.actor.fsdp_config.param_offload=true
16
+ - actor_rollout_ref.actor.fsdp_config.optimizer_offload=true
17
+ - actor_rollout_ref.rollout.log_prob_micro_batch_size=64
18
+ - actor_rollout_ref.rollout.tensor_model_parallel_size=1
19
+ - actor_rollout_ref.rollout.name=vllm
20
+ - actor_rollout_ref.rollout.gpu_memory_utilization=0.4
21
+ - actor_rollout_ref.ref.log_prob_micro_batch_size=64
22
+ - actor_rollout_ref.ref.fsdp_config.param_offload=True
23
+ - actor_rollout_ref.actor.kl_loss_coef=0.001
24
+ - trainer.logger=[wandb]
25
+ - trainer.n_gpus_per_node=2
26
+ - trainer.nnodes=1
27
+ - trainer.save_freq=100
28
+ - trainer.test_freq=50
29
+ - trainer.project_name=
30
+ - trainer.experiment_name=llm_guard_3B_10k_v2
31
+ - trainer.total_epochs=15
32
+ - trainer.total_training_steps=1005
33
+ - trainer.default_local_dir=verl_checkpoints/llm_guard_3B_10k_v2
34
+ - do_search=false
35
+ - max_turns=1
code/RL_model/verl/Search-R1/outputs/2026-02-01/20-33-01/main_ppo.log ADDED
File without changes
code/RL_model/verl/Search-R1/outputs/2026-02-01/20-35-38/.hydra/config.yaml ADDED
@@ -0,0 +1,169 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ data:
2
+ tokenizer: null
3
+ train_files: /home/mshahidul/readctrl/code/RL_model/verl/Search-R1/dataset/train.parquet
4
+ val_files: /home/mshahidul/readctrl/code/RL_model/verl/Search-R1/dataset/test.parquet
5
+ train_data_num: null
6
+ val_data_num: null
7
+ prompt_key: prompt
8
+ max_prompt_length: 4096
9
+ max_response_length: 1024
10
+ max_start_length: 256
11
+ max_obs_length: 512
12
+ train_batch_size: 128
13
+ val_batch_size: 64
14
+ return_raw_input_ids: false
15
+ return_raw_chat: false
16
+ shuffle_train_dataloader: true
17
+ actor_rollout_ref:
18
+ hybrid_engine: true
19
+ model:
20
+ path: Qwen/Qwen3-4B-Instruct-2507
21
+ external_lib: null
22
+ override_config: {}
23
+ enable_gradient_checkpointing: true
24
+ use_remove_padding: false
25
+ actor:
26
+ strategy: fsdp
27
+ ppo_mini_batch_size: 64
28
+ ppo_micro_batch_size: 64
29
+ use_dynamic_bsz: false
30
+ ppo_max_token_len_per_gpu: 16384
31
+ grad_clip: 1.0
32
+ state_masking: false
33
+ clip_ratio: 0.2
34
+ entropy_coeff: 0.001
35
+ use_kl_loss: false
36
+ kl_loss_coef: 0.001
37
+ kl_loss_type: low_var_kl
38
+ ppo_epochs: 1
39
+ shuffle: false
40
+ ulysses_sequence_parallel_size: 1
41
+ optim:
42
+ lr: 1.0e-06
43
+ lr_warmup_steps_ratio: 0.0
44
+ min_lr_ratio: null
45
+ warmup_style: constant
46
+ total_training_steps: -1
47
+ fsdp_config:
48
+ wrap_policy:
49
+ min_num_params: 0
50
+ param_offload: true
51
+ grad_offload: false
52
+ optimizer_offload: true
53
+ fsdp_size: -1
54
+ ppo_micro_batch_size_per_gpu: 16
55
+ ref:
56
+ fsdp_config:
57
+ param_offload: true
58
+ wrap_policy:
59
+ min_num_params: 0
60
+ fsdp_size: -1
61
+ log_prob_micro_batch_size: 64
62
+ log_prob_use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz}
63
+ log_prob_max_token_len_per_gpu: ${actor_rollout_ref.actor.ppo_max_token_len_per_gpu}
64
+ ulysses_sequence_parallel_size: ${actor_rollout_ref.actor.ulysses_sequence_parallel_size}
65
+ rollout:
66
+ name: vllm
67
+ temperature: 1.0
68
+ top_k: -1
69
+ top_p: 0.95
70
+ prompt_length: ${data.max_prompt_length}
71
+ response_length: ${data.max_response_length}
72
+ dtype: bfloat16
73
+ gpu_memory_utilization: 0.4
74
+ ignore_eos: false
75
+ enforce_eager: true
76
+ free_cache_engine: true
77
+ load_format: dummy_dtensor
78
+ tensor_model_parallel_size: 1
79
+ max_num_batched_tokens: 8192
80
+ max_num_seqs: 1024
81
+ log_prob_micro_batch_size: 64
82
+ log_prob_use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz}
83
+ log_prob_max_token_len_per_gpu: ${actor_rollout_ref.actor.ppo_max_token_len_per_gpu}
84
+ do_sample: true
85
+ 'n': 1
86
+ n_agent: 1
87
+ critic:
88
+ strategy: fsdp
89
+ optim:
90
+ lr: 1.0e-05
91
+ lr_warmup_steps_ratio: 0.0
92
+ min_lr_ratio: null
93
+ warmup_style: constant
94
+ total_training_steps: -1
95
+ model:
96
+ path: ~/models/deepseek-llm-7b-chat
97
+ tokenizer_path: ${actor_rollout_ref.model.path}
98
+ override_config: {}
99
+ external_lib: ${actor_rollout_ref.model.external_lib}
100
+ enable_gradient_checkpointing: false
101
+ use_remove_padding: false
102
+ fsdp_config:
103
+ param_offload: false
104
+ grad_offload: false
105
+ optimizer_offload: false
106
+ wrap_policy:
107
+ min_num_params: 0
108
+ fsdp_size: -1
109
+ ppo_mini_batch_size: ${actor_rollout_ref.actor.ppo_mini_batch_size}
110
+ ppo_micro_batch_size: 64
111
+ forward_micro_batch_size: ${critic.ppo_micro_batch_size}
112
+ use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz}
113
+ ppo_max_token_len_per_gpu: 32768
114
+ forward_max_token_len_per_gpu: ${critic.ppo_max_token_len_per_gpu}
115
+ ulysses_sequence_parallel_size: 1
116
+ ppo_epochs: ${actor_rollout_ref.actor.ppo_epochs}
117
+ shuffle: ${actor_rollout_ref.actor.shuffle}
118
+ grad_clip: 1.0
119
+ cliprange_value: 0.5
120
+ reward_model:
121
+ enable: false
122
+ strategy: fsdp
123
+ model:
124
+ input_tokenizer: ${actor_rollout_ref.model.path}
125
+ path: ~/models/FsfairX-LLaMA3-RM-v0.1
126
+ external_lib: ${actor_rollout_ref.model.external_lib}
127
+ use_remove_padding: false
128
+ fsdp_config:
129
+ min_num_params: 0
130
+ param_offload: false
131
+ micro_batch_size: 64
132
+ max_length: null
133
+ ulysses_sequence_parallel_size: 1
134
+ use_dynamic_bsz: ${critic.use_dynamic_bsz}
135
+ forward_max_token_len_per_gpu: ${critic.forward_max_token_len_per_gpu}
136
+ structure_format_score: 0
137
+ final_format_score: 0
138
+ retrieval_score: 0
139
+ retriever:
140
+ url: http://127.0.0.1:8000/retrieve
141
+ topk: 3
142
+ algorithm:
143
+ gamma: 1.0
144
+ lam: 1.0
145
+ adv_estimator: grpo
146
+ no_think_rl: false
147
+ kl_penalty: kl
148
+ kl_ctrl:
149
+ type: fixed
150
+ kl_coef: 0.001
151
+ state_masking:
152
+ start_state_marker: <information>
153
+ end_state_marker: </information>
154
+ trainer:
155
+ total_epochs: 15
156
+ total_training_steps: 1005
157
+ project_name: ''
158
+ experiment_name: llm_guard_3B_10k_v2
159
+ logger:
160
+ - wandb
161
+ nnodes: 1
162
+ n_gpus_per_node: 2
163
+ save_freq: 100
164
+ test_freq: 50
165
+ critic_warmup: 0
166
+ default_hdfs_dir: ~/experiments/gsm8k/ppo/${trainer.experiment_name}
167
+ default_local_dir: verl_checkpoints/llm_guard_3B_10k_v2
168
+ max_turns: 1
169
+ do_search: false
code/RL_model/verl/Search-R1/outputs/2026-02-01/20-35-38/.hydra/hydra.yaml ADDED
@@ -0,0 +1,189 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ hydra:
2
+ run:
3
+ dir: outputs/${now:%Y-%m-%d}/${now:%H-%M-%S}
4
+ sweep:
5
+ dir: multirun/${now:%Y-%m-%d}/${now:%H-%M-%S}
6
+ subdir: ${hydra.job.num}
7
+ launcher:
8
+ _target_: hydra._internal.core_plugins.basic_launcher.BasicLauncher
9
+ sweeper:
10
+ _target_: hydra._internal.core_plugins.basic_sweeper.BasicSweeper
11
+ max_batch_size: null
12
+ params: null
13
+ help:
14
+ app_name: ${hydra.job.name}
15
+ header: '${hydra.help.app_name} is powered by Hydra.
16
+
17
+ '
18
+ footer: 'Powered by Hydra (https://hydra.cc)
19
+
20
+ Use --hydra-help to view Hydra specific help
21
+
22
+ '
23
+ template: '${hydra.help.header}
24
+
25
+ == Configuration groups ==
26
+
27
+ Compose your configuration from those groups (group=option)
28
+
29
+
30
+ $APP_CONFIG_GROUPS
31
+
32
+
33
+ == Config ==
34
+
35
+ Override anything in the config (foo.bar=value)
36
+
37
+
38
+ $CONFIG
39
+
40
+
41
+ ${hydra.help.footer}
42
+
43
+ '
44
+ hydra_help:
45
+ template: 'Hydra (${hydra.runtime.version})
46
+
47
+ See https://hydra.cc for more info.
48
+
49
+
50
+ == Flags ==
51
+
52
+ $FLAGS_HELP
53
+
54
+
55
+ == Configuration groups ==
56
+
57
+ Compose your configuration from those groups (For example, append hydra/job_logging=disabled
58
+ to command line)
59
+
60
+
61
+ $HYDRA_CONFIG_GROUPS
62
+
63
+
64
+ Use ''--cfg hydra'' to Show the Hydra config.
65
+
66
+ '
67
+ hydra_help: ???
68
+ hydra_logging:
69
+ version: 1
70
+ formatters:
71
+ simple:
72
+ format: '[%(asctime)s][HYDRA] %(message)s'
73
+ handlers:
74
+ console:
75
+ class: logging.StreamHandler
76
+ formatter: simple
77
+ stream: ext://sys.stdout
78
+ root:
79
+ level: INFO
80
+ handlers:
81
+ - console
82
+ loggers:
83
+ logging_example:
84
+ level: DEBUG
85
+ disable_existing_loggers: false
86
+ job_logging:
87
+ version: 1
88
+ formatters:
89
+ simple:
90
+ format: '[%(asctime)s][%(name)s][%(levelname)s] - %(message)s'
91
+ handlers:
92
+ console:
93
+ class: logging.StreamHandler
94
+ formatter: simple
95
+ stream: ext://sys.stdout
96
+ file:
97
+ class: logging.FileHandler
98
+ formatter: simple
99
+ filename: ${hydra.runtime.output_dir}/${hydra.job.name}.log
100
+ root:
101
+ level: INFO
102
+ handlers:
103
+ - console
104
+ - file
105
+ disable_existing_loggers: false
106
+ env: {}
107
+ mode: RUN
108
+ searchpath: []
109
+ callbacks: {}
110
+ output_subdir: .hydra
111
+ overrides:
112
+ hydra:
113
+ - hydra.mode=RUN
114
+ task:
115
+ - data.train_files=/home/mshahidul/readctrl/code/RL_model/verl/Search-R1/dataset/train.parquet
116
+ - data.val_files=/home/mshahidul/readctrl/code/RL_model/verl/Search-R1/dataset/test.parquet
117
+ - data.train_batch_size=128
118
+ - data.val_batch_size=64
119
+ - data.max_prompt_length=4096
120
+ - data.max_response_length=1024
121
+ - data.shuffle_train_dataloader=True
122
+ - algorithm.adv_estimator=grpo
123
+ - actor_rollout_ref.model.path=Qwen/Qwen3-4B-Instruct-2507
124
+ - actor_rollout_ref.model.enable_gradient_checkpointing=true
125
+ - actor_rollout_ref.model.use_remove_padding=False
126
+ - actor_rollout_ref.actor.optim.lr=1e-6
127
+ - actor_rollout_ref.actor.ppo_mini_batch_size=64
128
+ - +actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=16
129
+ - actor_rollout_ref.actor.fsdp_config.param_offload=true
130
+ - actor_rollout_ref.actor.fsdp_config.optimizer_offload=true
131
+ - actor_rollout_ref.rollout.log_prob_micro_batch_size=64
132
+ - actor_rollout_ref.rollout.tensor_model_parallel_size=1
133
+ - actor_rollout_ref.rollout.name=vllm
134
+ - actor_rollout_ref.rollout.gpu_memory_utilization=0.4
135
+ - actor_rollout_ref.ref.log_prob_micro_batch_size=64
136
+ - actor_rollout_ref.ref.fsdp_config.param_offload=True
137
+ - actor_rollout_ref.actor.kl_loss_coef=0.001
138
+ - trainer.logger=[wandb]
139
+ - trainer.n_gpus_per_node=2
140
+ - trainer.nnodes=1
141
+ - trainer.save_freq=100
142
+ - trainer.test_freq=50
143
+ - trainer.project_name=
144
+ - trainer.experiment_name=llm_guard_3B_10k_v2
145
+ - trainer.total_epochs=15
146
+ - trainer.total_training_steps=1005
147
+ - trainer.default_local_dir=verl_checkpoints/llm_guard_3B_10k_v2
148
+ - do_search=false
149
+ - max_turns=1
150
+ job:
151
+ name: main_ppo
152
+ chdir: null
153
+ override_dirname: +actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=16,actor_rollout_ref.actor.fsdp_config.optimizer_offload=true,actor_rollout_ref.actor.fsdp_config.param_offload=true,actor_rollout_ref.actor.kl_loss_coef=0.001,actor_rollout_ref.actor.optim.lr=1e-6,actor_rollout_ref.actor.ppo_mini_batch_size=64,actor_rollout_ref.model.enable_gradient_checkpointing=true,actor_rollout_ref.model.path=Qwen/Qwen3-4B-Instruct-2507,actor_rollout_ref.model.use_remove_padding=False,actor_rollout_ref.ref.fsdp_config.param_offload=True,actor_rollout_ref.ref.log_prob_micro_batch_size=64,actor_rollout_ref.rollout.gpu_memory_utilization=0.4,actor_rollout_ref.rollout.log_prob_micro_batch_size=64,actor_rollout_ref.rollout.name=vllm,actor_rollout_ref.rollout.tensor_model_parallel_size=1,algorithm.adv_estimator=grpo,data.max_prompt_length=4096,data.max_response_length=1024,data.shuffle_train_dataloader=True,data.train_batch_size=128,data.train_files=/home/mshahidul/readctrl/code/RL_model/verl/Search-R1/dataset/train.parquet,data.val_batch_size=64,data.val_files=/home/mshahidul/readctrl/code/RL_model/verl/Search-R1/dataset/test.parquet,do_search=false,max_turns=1,trainer.default_local_dir=verl_checkpoints/llm_guard_3B_10k_v2,trainer.experiment_name=llm_guard_3B_10k_v2,trainer.logger=[wandb],trainer.n_gpus_per_node=2,trainer.nnodes=1,trainer.project_name=,trainer.save_freq=100,trainer.test_freq=50,trainer.total_epochs=15,trainer.total_training_steps=1005
154
+ id: ???
155
+ num: ???
156
+ config_name: ppo_trainer
157
+ env_set: {}
158
+ env_copy: []
159
+ config:
160
+ override_dirname:
161
+ kv_sep: '='
162
+ item_sep: ','
163
+ exclude_keys: []
164
+ runtime:
165
+ version: 1.3.2
166
+ version_base: '1.3'
167
+ cwd: /data/home_beta/mshahidul/readctrl/code/RL_model/verl/Search-R1
168
+ config_sources:
169
+ - path: hydra.conf
170
+ schema: pkg
171
+ provider: hydra
172
+ - path: /data/home_beta/mshahidul/readctrl/code/RL_model/verl/Search-R1/verl/trainer/config
173
+ schema: file
174
+ provider: main
175
+ - path: ''
176
+ schema: structured
177
+ provider: schema
178
+ output_dir: /data/home_beta/mshahidul/readctrl/code/RL_model/verl/Search-R1/outputs/2026-02-01/20-35-38
179
+ choices:
180
+ hydra/env: default
181
+ hydra/callbacks: null
182
+ hydra/job_logging: default
183
+ hydra/hydra_logging: default
184
+ hydra/hydra_help: default
185
+ hydra/help: default
186
+ hydra/sweeper: basic
187
+ hydra/launcher: basic
188
+ hydra/output: default
189
+ verbose: false
code/RL_model/verl/Search-R1/outputs/2026-02-01/20-35-38/.hydra/overrides.yaml ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ - data.train_files=/home/mshahidul/readctrl/code/RL_model/verl/Search-R1/dataset/train.parquet
2
+ - data.val_files=/home/mshahidul/readctrl/code/RL_model/verl/Search-R1/dataset/test.parquet
3
+ - data.train_batch_size=128
4
+ - data.val_batch_size=64
5
+ - data.max_prompt_length=4096
6
+ - data.max_response_length=1024
7
+ - data.shuffle_train_dataloader=True
8
+ - algorithm.adv_estimator=grpo
9
+ - actor_rollout_ref.model.path=Qwen/Qwen3-4B-Instruct-2507
10
+ - actor_rollout_ref.model.enable_gradient_checkpointing=true
11
+ - actor_rollout_ref.model.use_remove_padding=False
12
+ - actor_rollout_ref.actor.optim.lr=1e-6
13
+ - actor_rollout_ref.actor.ppo_mini_batch_size=64
14
+ - +actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=16
15
+ - actor_rollout_ref.actor.fsdp_config.param_offload=true
16
+ - actor_rollout_ref.actor.fsdp_config.optimizer_offload=true
17
+ - actor_rollout_ref.rollout.log_prob_micro_batch_size=64
18
+ - actor_rollout_ref.rollout.tensor_model_parallel_size=1
19
+ - actor_rollout_ref.rollout.name=vllm
20
+ - actor_rollout_ref.rollout.gpu_memory_utilization=0.4
21
+ - actor_rollout_ref.ref.log_prob_micro_batch_size=64
22
+ - actor_rollout_ref.ref.fsdp_config.param_offload=True
23
+ - actor_rollout_ref.actor.kl_loss_coef=0.001
24
+ - trainer.logger=[wandb]
25
+ - trainer.n_gpus_per_node=2
26
+ - trainer.nnodes=1
27
+ - trainer.save_freq=100
28
+ - trainer.test_freq=50
29
+ - trainer.project_name=
30
+ - trainer.experiment_name=llm_guard_3B_10k_v2
31
+ - trainer.total_epochs=15
32
+ - trainer.total_training_steps=1005
33
+ - trainer.default_local_dir=verl_checkpoints/llm_guard_3B_10k_v2
34
+ - do_search=false
35
+ - max_turns=1
code/RL_model/verl/Search-R1/outputs/2026-02-01/20-35-38/main_ppo.log ADDED
File without changes
code/RL_model/verl/Search-R1/outputs/2026-02-01/20-41-08/.hydra/config.yaml ADDED
@@ -0,0 +1,169 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ data:
2
+ tokenizer: null
3
+ train_files: /home/mshahidul/readctrl/code/RL_model/verl/Search-R1/dataset/train.parquet
4
+ val_files: /home/mshahidul/readctrl/code/RL_model/verl/Search-R1/dataset/test.parquet
5
+ train_data_num: null
6
+ val_data_num: null
7
+ prompt_key: prompt
8
+ max_prompt_length: 4096
9
+ max_response_length: 1024
10
+ max_start_length: 256
11
+ max_obs_length: 512
12
+ train_batch_size: 128
13
+ val_batch_size: 64
14
+ return_raw_input_ids: false
15
+ return_raw_chat: false
16
+ shuffle_train_dataloader: true
17
+ actor_rollout_ref:
18
+ hybrid_engine: true
19
+ model:
20
+ path: Qwen/Qwen3-4B-Instruct-2507
21
+ external_lib: null
22
+ override_config: {}
23
+ enable_gradient_checkpointing: true
24
+ use_remove_padding: false
25
+ actor:
26
+ strategy: fsdp
27
+ ppo_mini_batch_size: 64
28
+ ppo_micro_batch_size: 64
29
+ use_dynamic_bsz: false
30
+ ppo_max_token_len_per_gpu: 16384
31
+ grad_clip: 1.0
32
+ state_masking: false
33
+ clip_ratio: 0.2
34
+ entropy_coeff: 0.001
35
+ use_kl_loss: false
36
+ kl_loss_coef: 0.001
37
+ kl_loss_type: low_var_kl
38
+ ppo_epochs: 1
39
+ shuffle: false
40
+ ulysses_sequence_parallel_size: 1
41
+ optim:
42
+ lr: 1.0e-06
43
+ lr_warmup_steps_ratio: 0.0
44
+ min_lr_ratio: null
45
+ warmup_style: constant
46
+ total_training_steps: -1
47
+ fsdp_config:
48
+ wrap_policy:
49
+ min_num_params: 0
50
+ param_offload: true
51
+ grad_offload: false
52
+ optimizer_offload: true
53
+ fsdp_size: -1
54
+ ppo_micro_batch_size_per_gpu: 16
55
+ ref:
56
+ fsdp_config:
57
+ param_offload: true
58
+ wrap_policy:
59
+ min_num_params: 0
60
+ fsdp_size: -1
61
+ log_prob_micro_batch_size: 64
62
+ log_prob_use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz}
63
+ log_prob_max_token_len_per_gpu: ${actor_rollout_ref.actor.ppo_max_token_len_per_gpu}
64
+ ulysses_sequence_parallel_size: ${actor_rollout_ref.actor.ulysses_sequence_parallel_size}
65
+ rollout:
66
+ name: vllm
67
+ temperature: 1.0
68
+ top_k: -1
69
+ top_p: 0.95
70
+ prompt_length: ${data.max_prompt_length}
71
+ response_length: ${data.max_response_length}
72
+ dtype: bfloat16
73
+ gpu_memory_utilization: 0.4
74
+ ignore_eos: false
75
+ enforce_eager: true
76
+ free_cache_engine: true
77
+ load_format: dummy_dtensor
78
+ tensor_model_parallel_size: 1
79
+ max_num_batched_tokens: 8192
80
+ max_num_seqs: 1024
81
+ log_prob_micro_batch_size: 64
82
+ log_prob_use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz}
83
+ log_prob_max_token_len_per_gpu: ${actor_rollout_ref.actor.ppo_max_token_len_per_gpu}
84
+ do_sample: true
85
+ 'n': 1
86
+ n_agent: 1
87
+ critic:
88
+ strategy: fsdp
89
+ optim:
90
+ lr: 1.0e-05
91
+ lr_warmup_steps_ratio: 0.0
92
+ min_lr_ratio: null
93
+ warmup_style: constant
94
+ total_training_steps: -1
95
+ model:
96
+ path: ~/models/deepseek-llm-7b-chat
97
+ tokenizer_path: ${actor_rollout_ref.model.path}
98
+ override_config: {}
99
+ external_lib: ${actor_rollout_ref.model.external_lib}
100
+ enable_gradient_checkpointing: false
101
+ use_remove_padding: false
102
+ fsdp_config:
103
+ param_offload: false
104
+ grad_offload: false
105
+ optimizer_offload: false
106
+ wrap_policy:
107
+ min_num_params: 0
108
+ fsdp_size: -1
109
+ ppo_mini_batch_size: ${actor_rollout_ref.actor.ppo_mini_batch_size}
110
+ ppo_micro_batch_size: 64
111
+ forward_micro_batch_size: ${critic.ppo_micro_batch_size}
112
+ use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz}
113
+ ppo_max_token_len_per_gpu: 32768
114
+ forward_max_token_len_per_gpu: ${critic.ppo_max_token_len_per_gpu}
115
+ ulysses_sequence_parallel_size: 1
116
+ ppo_epochs: ${actor_rollout_ref.actor.ppo_epochs}
117
+ shuffle: ${actor_rollout_ref.actor.shuffle}
118
+ grad_clip: 1.0
119
+ cliprange_value: 0.5
120
+ reward_model:
121
+ enable: false
122
+ strategy: fsdp
123
+ model:
124
+ input_tokenizer: ${actor_rollout_ref.model.path}
125
+ path: ~/models/FsfairX-LLaMA3-RM-v0.1
126
+ external_lib: ${actor_rollout_ref.model.external_lib}
127
+ use_remove_padding: false
128
+ fsdp_config:
129
+ min_num_params: 0
130
+ param_offload: false
131
+ micro_batch_size: 64
132
+ max_length: null
133
+ ulysses_sequence_parallel_size: 1
134
+ use_dynamic_bsz: ${critic.use_dynamic_bsz}
135
+ forward_max_token_len_per_gpu: ${critic.forward_max_token_len_per_gpu}
136
+ structure_format_score: 0
137
+ final_format_score: 0
138
+ retrieval_score: 0
139
+ retriever:
140
+ url: http://127.0.0.1:8000/retrieve
141
+ topk: 3
142
+ algorithm:
143
+ gamma: 1.0
144
+ lam: 1.0
145
+ adv_estimator: grpo
146
+ no_think_rl: false
147
+ kl_penalty: kl
148
+ kl_ctrl:
149
+ type: fixed
150
+ kl_coef: 0.001
151
+ state_masking:
152
+ start_state_marker: <information>
153
+ end_state_marker: </information>
154
+ trainer:
155
+ total_epochs: 15
156
+ total_training_steps: 1005
157
+ project_name: ''
158
+ experiment_name: llm_guard_3B_10k_v2
159
+ logger:
160
+ - wandb
161
+ nnodes: 1
162
+ n_gpus_per_node: 2
163
+ save_freq: 100
164
+ test_freq: 50
165
+ critic_warmup: 0
166
+ default_hdfs_dir: ~/experiments/gsm8k/ppo/${trainer.experiment_name}
167
+ default_local_dir: verl_checkpoints/llm_guard_3B_10k_v2
168
+ max_turns: 1
169
+ do_search: false
code/RL_model/verl/Search-R1/outputs/2026-02-01/20-41-08/.hydra/hydra.yaml ADDED
@@ -0,0 +1,189 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ hydra:
2
+ run:
3
+ dir: outputs/${now:%Y-%m-%d}/${now:%H-%M-%S}
4
+ sweep:
5
+ dir: multirun/${now:%Y-%m-%d}/${now:%H-%M-%S}
6
+ subdir: ${hydra.job.num}
7
+ launcher:
8
+ _target_: hydra._internal.core_plugins.basic_launcher.BasicLauncher
9
+ sweeper:
10
+ _target_: hydra._internal.core_plugins.basic_sweeper.BasicSweeper
11
+ max_batch_size: null
12
+ params: null
13
+ help:
14
+ app_name: ${hydra.job.name}
15
+ header: '${hydra.help.app_name} is powered by Hydra.
16
+
17
+ '
18
+ footer: 'Powered by Hydra (https://hydra.cc)
19
+
20
+ Use --hydra-help to view Hydra specific help
21
+
22
+ '
23
+ template: '${hydra.help.header}
24
+
25
+ == Configuration groups ==
26
+
27
+ Compose your configuration from those groups (group=option)
28
+
29
+
30
+ $APP_CONFIG_GROUPS
31
+
32
+
33
+ == Config ==
34
+
35
+ Override anything in the config (foo.bar=value)
36
+
37
+
38
+ $CONFIG
39
+
40
+
41
+ ${hydra.help.footer}
42
+
43
+ '
44
+ hydra_help:
45
+ template: 'Hydra (${hydra.runtime.version})
46
+
47
+ See https://hydra.cc for more info.
48
+
49
+
50
+ == Flags ==
51
+
52
+ $FLAGS_HELP
53
+
54
+
55
+ == Configuration groups ==
56
+
57
+ Compose your configuration from those groups (For example, append hydra/job_logging=disabled
58
+ to command line)
59
+
60
+
61
+ $HYDRA_CONFIG_GROUPS
62
+
63
+
64
+ Use ''--cfg hydra'' to Show the Hydra config.
65
+
66
+ '
67
+ hydra_help: ???
68
+ hydra_logging:
69
+ version: 1
70
+ formatters:
71
+ simple:
72
+ format: '[%(asctime)s][HYDRA] %(message)s'
73
+ handlers:
74
+ console:
75
+ class: logging.StreamHandler
76
+ formatter: simple
77
+ stream: ext://sys.stdout
78
+ root:
79
+ level: INFO
80
+ handlers:
81
+ - console
82
+ loggers:
83
+ logging_example:
84
+ level: DEBUG
85
+ disable_existing_loggers: false
86
+ job_logging:
87
+ version: 1
88
+ formatters:
89
+ simple:
90
+ format: '[%(asctime)s][%(name)s][%(levelname)s] - %(message)s'
91
+ handlers:
92
+ console:
93
+ class: logging.StreamHandler
94
+ formatter: simple
95
+ stream: ext://sys.stdout
96
+ file:
97
+ class: logging.FileHandler
98
+ formatter: simple
99
+ filename: ${hydra.runtime.output_dir}/${hydra.job.name}.log
100
+ root:
101
+ level: INFO
102
+ handlers:
103
+ - console
104
+ - file
105
+ disable_existing_loggers: false
106
+ env: {}
107
+ mode: RUN
108
+ searchpath: []
109
+ callbacks: {}
110
+ output_subdir: .hydra
111
+ overrides:
112
+ hydra:
113
+ - hydra.mode=RUN
114
+ task:
115
+ - data.train_files=/home/mshahidul/readctrl/code/RL_model/verl/Search-R1/dataset/train.parquet
116
+ - data.val_files=/home/mshahidul/readctrl/code/RL_model/verl/Search-R1/dataset/test.parquet
117
+ - data.train_batch_size=128
118
+ - data.val_batch_size=64
119
+ - data.max_prompt_length=4096
120
+ - data.max_response_length=1024
121
+ - data.shuffle_train_dataloader=True
122
+ - algorithm.adv_estimator=grpo
123
+ - actor_rollout_ref.model.path=Qwen/Qwen3-4B-Instruct-2507
124
+ - actor_rollout_ref.model.enable_gradient_checkpointing=true
125
+ - actor_rollout_ref.model.use_remove_padding=False
126
+ - actor_rollout_ref.actor.optim.lr=1e-6
127
+ - actor_rollout_ref.actor.ppo_mini_batch_size=64
128
+ - +actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=16
129
+ - actor_rollout_ref.actor.fsdp_config.param_offload=true
130
+ - actor_rollout_ref.actor.fsdp_config.optimizer_offload=true
131
+ - actor_rollout_ref.rollout.log_prob_micro_batch_size=64
132
+ - actor_rollout_ref.rollout.tensor_model_parallel_size=1
133
+ - actor_rollout_ref.rollout.name=vllm
134
+ - actor_rollout_ref.rollout.gpu_memory_utilization=0.4
135
+ - actor_rollout_ref.ref.log_prob_micro_batch_size=64
136
+ - actor_rollout_ref.ref.fsdp_config.param_offload=True
137
+ - actor_rollout_ref.actor.kl_loss_coef=0.001
138
+ - trainer.logger=[wandb]
139
+ - trainer.n_gpus_per_node=2
140
+ - trainer.nnodes=1
141
+ - trainer.save_freq=100
142
+ - trainer.test_freq=50
143
+ - trainer.project_name=
144
+ - trainer.experiment_name=llm_guard_3B_10k_v2
145
+ - trainer.total_epochs=15
146
+ - trainer.total_training_steps=1005
147
+ - trainer.default_local_dir=verl_checkpoints/llm_guard_3B_10k_v2
148
+ - do_search=false
149
+ - max_turns=1
150
+ job:
151
+ name: main_ppo
152
+ chdir: null
153
+ override_dirname: +actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=16,actor_rollout_ref.actor.fsdp_config.optimizer_offload=true,actor_rollout_ref.actor.fsdp_config.param_offload=true,actor_rollout_ref.actor.kl_loss_coef=0.001,actor_rollout_ref.actor.optim.lr=1e-6,actor_rollout_ref.actor.ppo_mini_batch_size=64,actor_rollout_ref.model.enable_gradient_checkpointing=true,actor_rollout_ref.model.path=Qwen/Qwen3-4B-Instruct-2507,actor_rollout_ref.model.use_remove_padding=False,actor_rollout_ref.ref.fsdp_config.param_offload=True,actor_rollout_ref.ref.log_prob_micro_batch_size=64,actor_rollout_ref.rollout.gpu_memory_utilization=0.4,actor_rollout_ref.rollout.log_prob_micro_batch_size=64,actor_rollout_ref.rollout.name=vllm,actor_rollout_ref.rollout.tensor_model_parallel_size=1,algorithm.adv_estimator=grpo,data.max_prompt_length=4096,data.max_response_length=1024,data.shuffle_train_dataloader=True,data.train_batch_size=128,data.train_files=/home/mshahidul/readctrl/code/RL_model/verl/Search-R1/dataset/train.parquet,data.val_batch_size=64,data.val_files=/home/mshahidul/readctrl/code/RL_model/verl/Search-R1/dataset/test.parquet,do_search=false,max_turns=1,trainer.default_local_dir=verl_checkpoints/llm_guard_3B_10k_v2,trainer.experiment_name=llm_guard_3B_10k_v2,trainer.logger=[wandb],trainer.n_gpus_per_node=2,trainer.nnodes=1,trainer.project_name=,trainer.save_freq=100,trainer.test_freq=50,trainer.total_epochs=15,trainer.total_training_steps=1005
154
+ id: ???
155
+ num: ???
156
+ config_name: ppo_trainer
157
+ env_set: {}
158
+ env_copy: []
159
+ config:
160
+ override_dirname:
161
+ kv_sep: '='
162
+ item_sep: ','
163
+ exclude_keys: []
164
+ runtime:
165
+ version: 1.3.2
166
+ version_base: '1.3'
167
+ cwd: /data/home_beta/mshahidul/readctrl/code/RL_model/verl/Search-R1
168
+ config_sources:
169
+ - path: hydra.conf
170
+ schema: pkg
171
+ provider: hydra
172
+ - path: /data/home_beta/mshahidul/readctrl/code/RL_model/verl/Search-R1/verl/trainer/config
173
+ schema: file
174
+ provider: main
175
+ - path: ''
176
+ schema: structured
177
+ provider: schema
178
+ output_dir: /data/home_beta/mshahidul/readctrl/code/RL_model/verl/Search-R1/outputs/2026-02-01/20-41-08
179
+ choices:
180
+ hydra/env: default
181
+ hydra/callbacks: null
182
+ hydra/job_logging: default
183
+ hydra/hydra_logging: default
184
+ hydra/hydra_help: default
185
+ hydra/help: default
186
+ hydra/sweeper: basic
187
+ hydra/launcher: basic
188
+ hydra/output: default
189
+ verbose: false
code/RL_model/verl/Search-R1/outputs/2026-02-01/20-41-08/.hydra/overrides.yaml ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ - data.train_files=/home/mshahidul/readctrl/code/RL_model/verl/Search-R1/dataset/train.parquet
2
+ - data.val_files=/home/mshahidul/readctrl/code/RL_model/verl/Search-R1/dataset/test.parquet
3
+ - data.train_batch_size=128
4
+ - data.val_batch_size=64
5
+ - data.max_prompt_length=4096
6
+ - data.max_response_length=1024
7
+ - data.shuffle_train_dataloader=True
8
+ - algorithm.adv_estimator=grpo
9
+ - actor_rollout_ref.model.path=Qwen/Qwen3-4B-Instruct-2507
10
+ - actor_rollout_ref.model.enable_gradient_checkpointing=true
11
+ - actor_rollout_ref.model.use_remove_padding=False
12
+ - actor_rollout_ref.actor.optim.lr=1e-6
13
+ - actor_rollout_ref.actor.ppo_mini_batch_size=64
14
+ - +actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=16
15
+ - actor_rollout_ref.actor.fsdp_config.param_offload=true
16
+ - actor_rollout_ref.actor.fsdp_config.optimizer_offload=true
17
+ - actor_rollout_ref.rollout.log_prob_micro_batch_size=64
18
+ - actor_rollout_ref.rollout.tensor_model_parallel_size=1
19
+ - actor_rollout_ref.rollout.name=vllm
20
+ - actor_rollout_ref.rollout.gpu_memory_utilization=0.4
21
+ - actor_rollout_ref.ref.log_prob_micro_batch_size=64
22
+ - actor_rollout_ref.ref.fsdp_config.param_offload=True
23
+ - actor_rollout_ref.actor.kl_loss_coef=0.001
24
+ - trainer.logger=[wandb]
25
+ - trainer.n_gpus_per_node=2
26
+ - trainer.nnodes=1
27
+ - trainer.save_freq=100
28
+ - trainer.test_freq=50
29
+ - trainer.project_name=
30
+ - trainer.experiment_name=llm_guard_3B_10k_v2
31
+ - trainer.total_epochs=15
32
+ - trainer.total_training_steps=1005
33
+ - trainer.default_local_dir=verl_checkpoints/llm_guard_3B_10k_v2
34
+ - do_search=false
35
+ - max_turns=1
code/RL_model/verl/Search-R1/outputs/2026-02-01/20-41-08/main_ppo.log ADDED
File without changes
code/RL_model/verl/Search-R1/outputs/2026-02-01/20-42-57/.hydra/config.yaml ADDED
@@ -0,0 +1,169 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ data:
2
+ tokenizer: null
3
+ train_files: /home/mshahidul/readctrl/code/RL_model/verl/Search-R1/dataset/train.parquet
4
+ val_files: /home/mshahidul/readctrl/code/RL_model/verl/Search-R1/dataset/test.parquet
5
+ train_data_num: null
6
+ val_data_num: null
7
+ prompt_key: prompt
8
+ max_prompt_length: 4096
9
+ max_response_length: 1024
10
+ max_start_length: 256
11
+ max_obs_length: 512
12
+ train_batch_size: 128
13
+ val_batch_size: 64
14
+ return_raw_input_ids: false
15
+ return_raw_chat: false
16
+ shuffle_train_dataloader: true
17
+ actor_rollout_ref:
18
+ hybrid_engine: true
19
+ model:
20
+ path: Qwen/Qwen3-4B-Instruct-2507
21
+ external_lib: null
22
+ override_config: {}
23
+ enable_gradient_checkpointing: true
24
+ use_remove_padding: false
25
+ actor:
26
+ strategy: fsdp
27
+ ppo_mini_batch_size: 64
28
+ ppo_micro_batch_size: 64
29
+ use_dynamic_bsz: false
30
+ ppo_max_token_len_per_gpu: 16384
31
+ grad_clip: 1.0
32
+ state_masking: false
33
+ clip_ratio: 0.2
34
+ entropy_coeff: 0.001
35
+ use_kl_loss: false
36
+ kl_loss_coef: 0.001
37
+ kl_loss_type: low_var_kl
38
+ ppo_epochs: 1
39
+ shuffle: false
40
+ ulysses_sequence_parallel_size: 1
41
+ optim:
42
+ lr: 1.0e-06
43
+ lr_warmup_steps_ratio: 0.0
44
+ min_lr_ratio: null
45
+ warmup_style: constant
46
+ total_training_steps: -1
47
+ fsdp_config:
48
+ wrap_policy:
49
+ min_num_params: 0
50
+ param_offload: true
51
+ grad_offload: false
52
+ optimizer_offload: true
53
+ fsdp_size: -1
54
+ ppo_micro_batch_size_per_gpu: 16
55
+ ref:
56
+ fsdp_config:
57
+ param_offload: true
58
+ wrap_policy:
59
+ min_num_params: 0
60
+ fsdp_size: -1
61
+ log_prob_micro_batch_size: 64
62
+ log_prob_use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz}
63
+ log_prob_max_token_len_per_gpu: ${actor_rollout_ref.actor.ppo_max_token_len_per_gpu}
64
+ ulysses_sequence_parallel_size: ${actor_rollout_ref.actor.ulysses_sequence_parallel_size}
65
+ rollout:
66
+ name: vllm
67
+ temperature: 1.0
68
+ top_k: -1
69
+ top_p: 0.95
70
+ prompt_length: ${data.max_prompt_length}
71
+ response_length: ${data.max_response_length}
72
+ dtype: bfloat16
73
+ gpu_memory_utilization: 0.4
74
+ ignore_eos: false
75
+ enforce_eager: true
76
+ free_cache_engine: true
77
+ load_format: dummy_dtensor
78
+ tensor_model_parallel_size: 1
79
+ max_num_batched_tokens: 8192
80
+ max_num_seqs: 1024
81
+ log_prob_micro_batch_size: 64
82
+ log_prob_use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz}
83
+ log_prob_max_token_len_per_gpu: ${actor_rollout_ref.actor.ppo_max_token_len_per_gpu}
84
+ do_sample: true
85
+ 'n': 1
86
+ n_agent: 1
87
+ critic:
88
+ strategy: fsdp
89
+ optim:
90
+ lr: 1.0e-05
91
+ lr_warmup_steps_ratio: 0.0
92
+ min_lr_ratio: null
93
+ warmup_style: constant
94
+ total_training_steps: -1
95
+ model:
96
+ path: ~/models/deepseek-llm-7b-chat
97
+ tokenizer_path: ${actor_rollout_ref.model.path}
98
+ override_config: {}
99
+ external_lib: ${actor_rollout_ref.model.external_lib}
100
+ enable_gradient_checkpointing: false
101
+ use_remove_padding: false
102
+ fsdp_config:
103
+ param_offload: false
104
+ grad_offload: false
105
+ optimizer_offload: false
106
+ wrap_policy:
107
+ min_num_params: 0
108
+ fsdp_size: -1
109
+ ppo_mini_batch_size: ${actor_rollout_ref.actor.ppo_mini_batch_size}
110
+ ppo_micro_batch_size: 64
111
+ forward_micro_batch_size: ${critic.ppo_micro_batch_size}
112
+ use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz}
113
+ ppo_max_token_len_per_gpu: 32768
114
+ forward_max_token_len_per_gpu: ${critic.ppo_max_token_len_per_gpu}
115
+ ulysses_sequence_parallel_size: 1
116
+ ppo_epochs: ${actor_rollout_ref.actor.ppo_epochs}
117
+ shuffle: ${actor_rollout_ref.actor.shuffle}
118
+ grad_clip: 1.0
119
+ cliprange_value: 0.5
120
+ reward_model:
121
+ enable: false
122
+ strategy: fsdp
123
+ model:
124
+ input_tokenizer: ${actor_rollout_ref.model.path}
125
+ path: ~/models/FsfairX-LLaMA3-RM-v0.1
126
+ external_lib: ${actor_rollout_ref.model.external_lib}
127
+ use_remove_padding: false
128
+ fsdp_config:
129
+ min_num_params: 0
130
+ param_offload: false
131
+ micro_batch_size: 64
132
+ max_length: null
133
+ ulysses_sequence_parallel_size: 1
134
+ use_dynamic_bsz: ${critic.use_dynamic_bsz}
135
+ forward_max_token_len_per_gpu: ${critic.forward_max_token_len_per_gpu}
136
+ structure_format_score: 0
137
+ final_format_score: 0
138
+ retrieval_score: 0
139
+ retriever:
140
+ url: http://127.0.0.1:8000/retrieve
141
+ topk: 3
142
+ algorithm:
143
+ gamma: 1.0
144
+ lam: 1.0
145
+ adv_estimator: grpo
146
+ no_think_rl: false
147
+ kl_penalty: kl
148
+ kl_ctrl:
149
+ type: fixed
150
+ kl_coef: 0.001
151
+ state_masking:
152
+ start_state_marker: <information>
153
+ end_state_marker: </information>
154
+ trainer:
155
+ total_epochs: 15
156
+ total_training_steps: 1005
157
+ project_name: ''
158
+ experiment_name: llm_guard_3B_10k_v2
159
+ logger:
160
+ - wandb
161
+ nnodes: 1
162
+ n_gpus_per_node: 2
163
+ save_freq: 100
164
+ test_freq: 50
165
+ critic_warmup: 0
166
+ default_hdfs_dir: ~/experiments/gsm8k/ppo/${trainer.experiment_name}
167
+ default_local_dir: verl_checkpoints/llm_guard_3B_10k_v2
168
+ max_turns: 1
169
+ do_search: false
code/RL_model/verl/Search-R1/outputs/2026-02-01/20-42-57/.hydra/hydra.yaml ADDED
@@ -0,0 +1,189 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ hydra:
2
+ run:
3
+ dir: outputs/${now:%Y-%m-%d}/${now:%H-%M-%S}
4
+ sweep:
5
+ dir: multirun/${now:%Y-%m-%d}/${now:%H-%M-%S}
6
+ subdir: ${hydra.job.num}
7
+ launcher:
8
+ _target_: hydra._internal.core_plugins.basic_launcher.BasicLauncher
9
+ sweeper:
10
+ _target_: hydra._internal.core_plugins.basic_sweeper.BasicSweeper
11
+ max_batch_size: null
12
+ params: null
13
+ help:
14
+ app_name: ${hydra.job.name}
15
+ header: '${hydra.help.app_name} is powered by Hydra.
16
+
17
+ '
18
+ footer: 'Powered by Hydra (https://hydra.cc)
19
+
20
+ Use --hydra-help to view Hydra specific help
21
+
22
+ '
23
+ template: '${hydra.help.header}
24
+
25
+ == Configuration groups ==
26
+
27
+ Compose your configuration from those groups (group=option)
28
+
29
+
30
+ $APP_CONFIG_GROUPS
31
+
32
+
33
+ == Config ==
34
+
35
+ Override anything in the config (foo.bar=value)
36
+
37
+
38
+ $CONFIG
39
+
40
+
41
+ ${hydra.help.footer}
42
+
43
+ '
44
+ hydra_help:
45
+ template: 'Hydra (${hydra.runtime.version})
46
+
47
+ See https://hydra.cc for more info.
48
+
49
+
50
+ == Flags ==
51
+
52
+ $FLAGS_HELP
53
+
54
+
55
+ == Configuration groups ==
56
+
57
+ Compose your configuration from those groups (For example, append hydra/job_logging=disabled
58
+ to command line)
59
+
60
+
61
+ $HYDRA_CONFIG_GROUPS
62
+
63
+
64
+ Use ''--cfg hydra'' to Show the Hydra config.
65
+
66
+ '
67
+ hydra_help: ???
68
+ hydra_logging:
69
+ version: 1
70
+ formatters:
71
+ simple:
72
+ format: '[%(asctime)s][HYDRA] %(message)s'
73
+ handlers:
74
+ console:
75
+ class: logging.StreamHandler
76
+ formatter: simple
77
+ stream: ext://sys.stdout
78
+ root:
79
+ level: INFO
80
+ handlers:
81
+ - console
82
+ loggers:
83
+ logging_example:
84
+ level: DEBUG
85
+ disable_existing_loggers: false
86
+ job_logging:
87
+ version: 1
88
+ formatters:
89
+ simple:
90
+ format: '[%(asctime)s][%(name)s][%(levelname)s] - %(message)s'
91
+ handlers:
92
+ console:
93
+ class: logging.StreamHandler
94
+ formatter: simple
95
+ stream: ext://sys.stdout
96
+ file:
97
+ class: logging.FileHandler
98
+ formatter: simple
99
+ filename: ${hydra.runtime.output_dir}/${hydra.job.name}.log
100
+ root:
101
+ level: INFO
102
+ handlers:
103
+ - console
104
+ - file
105
+ disable_existing_loggers: false
106
+ env: {}
107
+ mode: RUN
108
+ searchpath: []
109
+ callbacks: {}
110
+ output_subdir: .hydra
111
+ overrides:
112
+ hydra:
113
+ - hydra.mode=RUN
114
+ task:
115
+ - data.train_files=/home/mshahidul/readctrl/code/RL_model/verl/Search-R1/dataset/train.parquet
116
+ - data.val_files=/home/mshahidul/readctrl/code/RL_model/verl/Search-R1/dataset/test.parquet
117
+ - data.train_batch_size=128
118
+ - data.val_batch_size=64
119
+ - data.max_prompt_length=4096
120
+ - data.max_response_length=1024
121
+ - data.shuffle_train_dataloader=True
122
+ - algorithm.adv_estimator=grpo
123
+ - actor_rollout_ref.model.path=Qwen/Qwen3-4B-Instruct-2507
124
+ - actor_rollout_ref.model.enable_gradient_checkpointing=true
125
+ - actor_rollout_ref.model.use_remove_padding=False
126
+ - actor_rollout_ref.actor.optim.lr=1e-6
127
+ - actor_rollout_ref.actor.ppo_mini_batch_size=64
128
+ - +actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=16
129
+ - actor_rollout_ref.actor.fsdp_config.param_offload=true
130
+ - actor_rollout_ref.actor.fsdp_config.optimizer_offload=true
131
+ - actor_rollout_ref.rollout.log_prob_micro_batch_size=64
132
+ - actor_rollout_ref.rollout.tensor_model_parallel_size=1
133
+ - actor_rollout_ref.rollout.name=vllm
134
+ - actor_rollout_ref.rollout.gpu_memory_utilization=0.4
135
+ - actor_rollout_ref.ref.log_prob_micro_batch_size=64
136
+ - actor_rollout_ref.ref.fsdp_config.param_offload=True
137
+ - actor_rollout_ref.actor.kl_loss_coef=0.001
138
+ - trainer.logger=[wandb]
139
+ - trainer.n_gpus_per_node=2
140
+ - trainer.nnodes=1
141
+ - trainer.save_freq=100
142
+ - trainer.test_freq=50
143
+ - trainer.project_name=
144
+ - trainer.experiment_name=llm_guard_3B_10k_v2
145
+ - trainer.total_epochs=15
146
+ - trainer.total_training_steps=1005
147
+ - trainer.default_local_dir=verl_checkpoints/llm_guard_3B_10k_v2
148
+ - do_search=false
149
+ - max_turns=1
150
+ job:
151
+ name: main_ppo
152
+ chdir: null
153
+ override_dirname: +actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=16,actor_rollout_ref.actor.fsdp_config.optimizer_offload=true,actor_rollout_ref.actor.fsdp_config.param_offload=true,actor_rollout_ref.actor.kl_loss_coef=0.001,actor_rollout_ref.actor.optim.lr=1e-6,actor_rollout_ref.actor.ppo_mini_batch_size=64,actor_rollout_ref.model.enable_gradient_checkpointing=true,actor_rollout_ref.model.path=Qwen/Qwen3-4B-Instruct-2507,actor_rollout_ref.model.use_remove_padding=False,actor_rollout_ref.ref.fsdp_config.param_offload=True,actor_rollout_ref.ref.log_prob_micro_batch_size=64,actor_rollout_ref.rollout.gpu_memory_utilization=0.4,actor_rollout_ref.rollout.log_prob_micro_batch_size=64,actor_rollout_ref.rollout.name=vllm,actor_rollout_ref.rollout.tensor_model_parallel_size=1,algorithm.adv_estimator=grpo,data.max_prompt_length=4096,data.max_response_length=1024,data.shuffle_train_dataloader=True,data.train_batch_size=128,data.train_files=/home/mshahidul/readctrl/code/RL_model/verl/Search-R1/dataset/train.parquet,data.val_batch_size=64,data.val_files=/home/mshahidul/readctrl/code/RL_model/verl/Search-R1/dataset/test.parquet,do_search=false,max_turns=1,trainer.default_local_dir=verl_checkpoints/llm_guard_3B_10k_v2,trainer.experiment_name=llm_guard_3B_10k_v2,trainer.logger=[wandb],trainer.n_gpus_per_node=2,trainer.nnodes=1,trainer.project_name=,trainer.save_freq=100,trainer.test_freq=50,trainer.total_epochs=15,trainer.total_training_steps=1005
154
+ id: ???
155
+ num: ???
156
+ config_name: ppo_trainer
157
+ env_set: {}
158
+ env_copy: []
159
+ config:
160
+ override_dirname:
161
+ kv_sep: '='
162
+ item_sep: ','
163
+ exclude_keys: []
164
+ runtime:
165
+ version: 1.3.2
166
+ version_base: '1.3'
167
+ cwd: /data/home_beta/mshahidul/readctrl/code/RL_model/verl/Search-R1
168
+ config_sources:
169
+ - path: hydra.conf
170
+ schema: pkg
171
+ provider: hydra
172
+ - path: /data/home_beta/mshahidul/readctrl/code/RL_model/verl/Search-R1/verl/trainer/config
173
+ schema: file
174
+ provider: main
175
+ - path: ''
176
+ schema: structured
177
+ provider: schema
178
+ output_dir: /data/home_beta/mshahidul/readctrl/code/RL_model/verl/Search-R1/outputs/2026-02-01/20-42-57
179
+ choices:
180
+ hydra/env: default
181
+ hydra/callbacks: null
182
+ hydra/job_logging: default
183
+ hydra/hydra_logging: default
184
+ hydra/hydra_help: default
185
+ hydra/help: default
186
+ hydra/sweeper: basic
187
+ hydra/launcher: basic
188
+ hydra/output: default
189
+ verbose: false
code/RL_model/verl/Search-R1/outputs/2026-02-01/20-42-57/.hydra/overrides.yaml ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ - data.train_files=/home/mshahidul/readctrl/code/RL_model/verl/Search-R1/dataset/train.parquet
2
+ - data.val_files=/home/mshahidul/readctrl/code/RL_model/verl/Search-R1/dataset/test.parquet
3
+ - data.train_batch_size=128
4
+ - data.val_batch_size=64
5
+ - data.max_prompt_length=4096
6
+ - data.max_response_length=1024
7
+ - data.shuffle_train_dataloader=True
8
+ - algorithm.adv_estimator=grpo
9
+ - actor_rollout_ref.model.path=Qwen/Qwen3-4B-Instruct-2507
10
+ - actor_rollout_ref.model.enable_gradient_checkpointing=true
11
+ - actor_rollout_ref.model.use_remove_padding=False
12
+ - actor_rollout_ref.actor.optim.lr=1e-6
13
+ - actor_rollout_ref.actor.ppo_mini_batch_size=64
14
+ - +actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=16
15
+ - actor_rollout_ref.actor.fsdp_config.param_offload=true
16
+ - actor_rollout_ref.actor.fsdp_config.optimizer_offload=true
17
+ - actor_rollout_ref.rollout.log_prob_micro_batch_size=64
18
+ - actor_rollout_ref.rollout.tensor_model_parallel_size=1
19
+ - actor_rollout_ref.rollout.name=vllm
20
+ - actor_rollout_ref.rollout.gpu_memory_utilization=0.4
21
+ - actor_rollout_ref.ref.log_prob_micro_batch_size=64
22
+ - actor_rollout_ref.ref.fsdp_config.param_offload=True
23
+ - actor_rollout_ref.actor.kl_loss_coef=0.001
24
+ - trainer.logger=[wandb]
25
+ - trainer.n_gpus_per_node=2
26
+ - trainer.nnodes=1
27
+ - trainer.save_freq=100
28
+ - trainer.test_freq=50
29
+ - trainer.project_name=
30
+ - trainer.experiment_name=llm_guard_3B_10k_v2
31
+ - trainer.total_epochs=15
32
+ - trainer.total_training_steps=1005
33
+ - trainer.default_local_dir=verl_checkpoints/llm_guard_3B_10k_v2
34
+ - do_search=false
35
+ - max_turns=1
code/RL_model/verl/Search-R1/outputs/2026-02-01/20-42-57/main_ppo.log ADDED
File without changes
code/RL_model/verl/Search-R1/search_r1/llm_agent/__init__.py ADDED
File without changes
code/RL_model/verl/Search-R1/search_r1/llm_agent/generation.py ADDED
@@ -0,0 +1,469 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import re
3
+ from collections import defaultdict
4
+ import os
5
+ from typing import List, Dict, Any, Tuple
6
+ from dataclasses import dataclass
7
+ from .tensor_helper import TensorHelper, TensorConfig
8
+ from verl import DataProto
9
+ from verl.utils.tracking import Tracking
10
+ import shutil
11
+ import requests
12
+
13
+ @dataclass
14
+ class GenerationConfig:
15
+ max_turns: int
16
+ max_start_length: int
17
+ max_prompt_length: int
18
+ max_response_length: int
19
+ max_obs_length: int
20
+ num_gpus: int
21
+ no_think_rl: bool=False
22
+ search_url: str = None
23
+ topk: int = 3
24
+
25
+ class LLMGenerationManager:
26
+ def __init__(
27
+ self,
28
+ tokenizer,
29
+ actor_rollout_wg,
30
+ config: GenerationConfig,
31
+ is_validation: bool = False,
32
+ ):
33
+ self.tokenizer = tokenizer
34
+ self.actor_rollout_wg = actor_rollout_wg
35
+ self.config = config
36
+ self.is_validation = is_validation
37
+
38
+ self.tensor_fn = TensorHelper(TensorConfig(
39
+ pad_token_id=tokenizer.pad_token_id,
40
+ max_prompt_length=config.max_prompt_length,
41
+ max_obs_length=config.max_obs_length,
42
+ max_start_length=config.max_start_length
43
+ ))
44
+
45
+ def _batch_tokenize(self, responses: List[str]) -> torch.Tensor:
46
+ """Tokenize a batch of responses."""
47
+ return self.tokenizer(
48
+ responses,
49
+ add_special_tokens=False,
50
+ return_tensors='pt',
51
+ padding="longest"
52
+ )['input_ids']
53
+
54
+ def _postprocess_responses(self, responses: torch.Tensor) -> torch.Tensor:
55
+ """Process responses to stop at search operation or answer operation."""
56
+ responses_str = self.tokenizer.batch_decode(
57
+ responses,
58
+ skip_special_tokens=True
59
+ )
60
+
61
+ responses_str = [resp.split('</search>')[0] + '</search>'
62
+ if '</search>' in resp
63
+ else resp.split('</answer>')[0] + '</answer>'
64
+ if '</answer>' in resp
65
+ else resp
66
+ for resp in responses_str]
67
+
68
+ if self.config.no_think_rl:
69
+ raise ValueError('stop')
70
+ # if no_think_rl is enabled, only keep action in the str
71
+ actions, _ = self.env.postprocess_predictions(responses_str)
72
+ responses_str=[f"<answer>{envs[idx].ACTION_LOOKUP[action]}</answer>" for idx, action in enumerate(actions)]
73
+ print("RESPONSES:", responses_str)
74
+ responses = self._batch_tokenize(responses_str)
75
+ return responses, responses_str
76
+
77
+ def _process_next_obs(self, next_obs: List[str]) -> torch.Tensor:
78
+ """Process next observations from environment."""
79
+
80
+ next_obs_ids = self.tokenizer(
81
+ next_obs,
82
+ padding='longest',
83
+ return_tensors='pt',
84
+ add_special_tokens=False, # Prevents adding special tokens
85
+ )['input_ids']
86
+
87
+ if next_obs_ids.shape[1] > self.config.max_obs_length:
88
+ print(f"[WARNING] OBSERVATION TOO LONG, CONSIDER CHANGING YOUR CONFIG, {next_obs_ids.shape[1]} & {self.config.max_obs_length}")
89
+ next_obs_ids = next_obs_ids[:, :self.config.max_obs_length]
90
+
91
+ return next_obs_ids
92
+
93
+ def _update_rolling_state(self, rollings: DataProto, cur_responses: torch.Tensor,
94
+ next_obs_ids: torch.Tensor) -> Dict:
95
+ """Update rolling state with new responses and observations."""
96
+ # Concatenate and handle padding
97
+ new_input_ids = self.tensor_fn.concatenate_with_padding([
98
+ rollings.batch['input_ids'],
99
+ cur_responses,
100
+ next_obs_ids
101
+ ])
102
+
103
+ # Create attention mask and position ids
104
+ new_attention_mask = self.tensor_fn.create_attention_mask(new_input_ids)
105
+ new_position_ids = self.tensor_fn.create_position_ids(new_attention_mask)
106
+
107
+ # Cut to appropriate length
108
+ effective_len = new_attention_mask.sum(dim=1).max()
109
+ max_len = min(self.config.max_prompt_length, effective_len)
110
+
111
+ new_rollings = DataProto.from_dict({
112
+ 'input_ids': new_input_ids[:, -max_len:],
113
+ 'position_ids': new_position_ids[:, -max_len:],
114
+ 'attention_mask': new_attention_mask[:, -max_len:]
115
+ })
116
+ new_rollings.meta_info.update(rollings.meta_info)
117
+
118
+ return new_rollings
119
+
120
+ def _info_masked_concatenate_with_padding(self,
121
+ prompt: torch.Tensor,
122
+ prompt_with_mask: torch.Tensor,
123
+ response: torch.Tensor,
124
+ info: torch.Tensor = None,
125
+ pad_to_left: bool = True
126
+ ) -> torch.Tensor:
127
+ """Concatenate tensors and handle padding. Additionally, create a mask (info_mask) to cover the information block if it exists."""
128
+ pad_id = self.tokenizer.pad_token_id
129
+ tensors = [prompt, response]
130
+ tensors_with_mask = [prompt_with_mask, response]
131
+ if info is not None:
132
+ tensors.append(info)
133
+ info_mask = torch.full(info.size(), pad_id, dtype=info.dtype, device=info.device) # information mask
134
+ tensors_with_mask.append(info_mask)
135
+
136
+ concatenated = torch.cat(tensors, dim=1)
137
+ concatenated_with_info = torch.cat(tensors_with_mask, dim=1)
138
+ mask = concatenated != pad_id if pad_to_left else concatenated == pad_id
139
+ sorted_indices = mask.to(torch.int64).argsort(dim=1, stable=True)
140
+ padded_tensor = concatenated.gather(1, sorted_indices)
141
+ padded_tensor_with_info = concatenated_with_info.gather(1, sorted_indices)
142
+
143
+ return padded_tensor, padded_tensor_with_info
144
+
145
+ def _update_right_side(self, right_side: Dict,
146
+ cur_responses: torch.Tensor,
147
+ next_obs_ids: torch.Tensor = None) -> Dict:
148
+ """Update right side state."""
149
+ if next_obs_ids != None:
150
+ responses, responses_with_info_mask = self._info_masked_concatenate_with_padding(
151
+ right_side['responses'],
152
+ right_side['responses_with_info_mask'],
153
+ cur_responses,
154
+ next_obs_ids,
155
+ pad_to_left=False
156
+ )
157
+ else:
158
+ responses, responses_with_info_mask = self._info_masked_concatenate_with_padding(
159
+ right_side['responses'],
160
+ right_side['responses_with_info_mask'],
161
+ cur_responses,
162
+ pad_to_left=False
163
+ )
164
+ effective_len = self.tensor_fn.create_attention_mask(responses).sum(dim=1).max()
165
+ max_len = min(self.config.max_prompt_length, effective_len)
166
+
167
+ return {'responses': responses[:, :max_len], 'responses_with_info_mask': responses_with_info_mask[:, :max_len]}
168
+
169
    def _generate_with_gpu_padding(self, active_batch: DataProto) -> DataProto:
        """
        Wrapper for generation that handles multi-GPU padding requirements.
        if num_gpus <= 1, return self.actor_rollout_wg.generate_sequences(active_batch)
        if active_batch size is not divisible by num_gpus, pad with first sequence
        then remove padding from output
        """
        num_gpus = self.config.num_gpus
        if num_gpus <= 1:
            return self.actor_rollout_wg.generate_sequences(active_batch)

        batch_size = active_batch.batch['input_ids'].shape[0]
        remainder = batch_size % num_gpus

        # Token-id tensors must be integer typed for the rollout workers.
        for key in active_batch.batch.keys():
            active_batch.batch[key] = active_batch.batch[key].long()
        if remainder == 0:
            return self.actor_rollout_wg.generate_sequences(active_batch)

        # Add padding sequences
        padding_size = num_gpus - remainder
        padded_batch = {}

        for k, v in active_batch.batch.items():
            # Use first sequence as padding template
            pad_sequence = v[0:1].repeat(padding_size, *[1] * (len(v.shape) - 1))
            padded_batch[k] = torch.cat([v, pad_sequence], dim=0)

        padded_active_batch = DataProto.from_dict(padded_batch)
        for key in padded_active_batch.batch.keys():
            padded_active_batch.batch[key] = padded_active_batch.batch[key].long()

        # Generate with padded batch
        padded_output = self.actor_rollout_wg.generate_sequences(padded_active_batch)

        # Remove padding from output
        trimmed_batch = {k: v[:-padding_size] for k, v in padded_output.batch.items()}

        # Handle meta_info if present
        # NOTE(review): tensor-valued meta_info entries are assumed to be
        # batch-aligned along dim 0 — confirm before relying on this trim.
        if hasattr(padded_output, 'meta_info') and padded_output.meta_info:
            trimmed_meta = {}
            for k, v in padded_output.meta_info.items():
                if isinstance(v, torch.Tensor):
                    trimmed_meta[k] = v[:-padding_size]
                else:
                    trimmed_meta[k] = v
            padded_output.meta_info = trimmed_meta

        padded_output.batch = trimmed_batch
        return padded_output
219
+
220
    def run_llm_loop(self, gen_batch, initial_input_ids: torch.Tensor) -> Tuple[Dict, Dict]:
        """Run main LLM generation loop.

        Alternates policy generation and environment (search) steps for up to
        ``config.max_turns`` turns, then performs one final no-search rollout
        for examples that are still active.
        """
        # Left side: the (truncated) initial prompt. Right side: everything
        # generated plus retrieved information, accumulated turn by turn.
        original_left_side = {'input_ids': initial_input_ids[:, -self.config.max_start_length:]}
        original_right_side = {'responses': initial_input_ids[:, []], 'responses_with_info_mask': initial_input_ids[:, []]}

        # Per-example bookkeeping across the whole batch.
        active_mask = torch.ones(gen_batch.batch['input_ids'].shape[0], dtype=torch.bool)
        turns_stats = torch.ones(gen_batch.batch['input_ids'].shape[0], dtype=torch.int)
        valid_action_stats = torch.zeros(gen_batch.batch['input_ids'].shape[0], dtype=torch.int)
        valid_search_stats = torch.zeros(gen_batch.batch['input_ids'].shape[0], dtype=torch.int)
        active_num_list = [active_mask.sum().item()]
        rollings = gen_batch

        # Main generation loop
        for step in range(self.config.max_turns):
            if not active_mask.sum():
                break
            rollings.batch = self.tensor_fn.cut_to_effective_len(
                rollings.batch,
                keys=['input_ids', 'attention_mask', 'position_ids']
            )

            # Generate only for still-active examples, then pad results back
            # to the full batch size below.
            # gen_output = self.actor_rollout_wg.generate_sequences(rollings)
            rollings_active = DataProto.from_dict({
                k: v[active_mask] for k, v in rollings.batch.items()
            })
            gen_output = self._generate_with_gpu_padding(rollings_active)

            meta_info = gen_output.meta_info
            responses_ids, responses_str = self._postprocess_responses(gen_output.batch['responses'])
            responses_ids, responses_str = self.tensor_fn._example_level_pad(responses_ids, responses_str, active_mask)

            # Execute in environment and process observations
            next_obs, dones, valid_action, is_search = self.execute_predictions(
                responses_str, self.tokenizer.pad_token, active_mask
            )

            # An example stays active only while it was active AND not done.
            curr_active_mask = torch.tensor([not done for done in dones], dtype=torch.bool)
            active_mask = active_mask * curr_active_mask
            active_num_list.append(active_mask.sum().item())
            turns_stats[curr_active_mask] += 1
            valid_action_stats += torch.tensor(valid_action, dtype=torch.int)
            valid_search_stats += torch.tensor(is_search, dtype=torch.int)

            next_obs_ids = self._process_next_obs(next_obs)

            # Update states
            rollings = self._update_rolling_state(
                rollings,
                responses_ids,
                next_obs_ids
            )
            original_right_side = self._update_right_side(
                original_right_side,
                responses_ids,
                next_obs_ids
            )

        # final LLM rollout: one last turn with search disabled for examples
        # that exhausted max_turns while still active.
        if active_mask.sum():
            rollings.batch = self.tensor_fn.cut_to_effective_len(
                rollings.batch,
                keys=['input_ids', 'attention_mask', 'position_ids']
            )

            # gen_output = self.actor_rollout_wg.generate_sequences(rollings)
            rollings_active = DataProto.from_dict({
                k: v[active_mask] for k, v in rollings.batch.items()
            })
            gen_output = self._generate_with_gpu_padding(rollings_active)

            meta_info = gen_output.meta_info
            responses_ids, responses_str = self._postprocess_responses(gen_output.batch['responses'])
            responses_ids, responses_str = self.tensor_fn._example_level_pad(responses_ids, responses_str, active_mask)

            # # Execute in environment and process observations
            _, dones, valid_action, is_search = self.execute_predictions(
                responses_str, self.tokenizer.pad_token, active_mask, do_search=False
            )

            curr_active_mask = torch.tensor([not done for done in dones], dtype=torch.bool)
            active_mask = active_mask * curr_active_mask
            active_num_list.append(active_mask.sum().item())
            valid_action_stats += torch.tensor(valid_action, dtype=torch.int)
            valid_search_stats += torch.tensor(is_search, dtype=torch.int)


            original_right_side = self._update_right_side(
                original_right_side,
                responses_ids,
            )

        # NOTE(review): meta_info comes from the last generation call; if the
        # batch started with no active examples it would be unbound here.
        meta_info['turns_stats'] = turns_stats.tolist()
        meta_info['active_mask'] = active_mask.tolist()
        meta_info['valid_action_stats'] = valid_action_stats.tolist()
        meta_info['valid_search_stats'] = valid_search_stats.tolist()

        print("ACTIVE_TRAJ_NUM:", active_num_list)

        return self._compose_final_output(original_left_side, original_right_side, meta_info)
320
+
321
+ def _compose_final_output(self, left_side: Dict,
322
+ right_side: Dict,
323
+ meta_info: Dict) -> Tuple[Dict, Dict]:
324
+ """Compose final generation output."""
325
+ final_output = right_side.copy()
326
+ final_output['prompts'] = left_side['input_ids']
327
+
328
+ # Combine input IDs
329
+ final_output['input_ids'] = torch.cat([
330
+ left_side['input_ids'],
331
+ right_side['responses']
332
+ ], dim=1)
333
+
334
+ # Create attention mask and position ids
335
+ final_output['attention_mask'] = torch.cat([
336
+ self.tensor_fn.create_attention_mask(left_side['input_ids']),
337
+ self.tensor_fn.create_attention_mask(final_output['responses'])
338
+ ], dim=1)
339
+ final_output['info_mask'] = torch.cat([
340
+ self.tensor_fn.create_attention_mask(left_side['input_ids']),
341
+ self.tensor_fn.create_attention_mask(final_output['responses_with_info_mask'])
342
+ ], dim=1)
343
+
344
+ final_output['position_ids'] = self.tensor_fn.create_position_ids(
345
+ final_output['attention_mask']
346
+ )
347
+
348
+ final_output = DataProto.from_dict(final_output)
349
+ final_output.meta_info.update(meta_info)
350
+
351
+ return final_output
352
+
353
+ def execute_predictions(self, predictions: List[str], pad_token: str, active_mask=None, do_search=True) -> List[str]:
354
+ """
355
+ Execute predictions across multiple environments.
356
+ NOTE: the function is the actual `step` function in the environment
357
+ NOTE penalty_for_invalid is not included in observation shown to the LLM
358
+
359
+ Args:
360
+ envs: List of environment instances
361
+ predictions: List of action predictions
362
+ pad_token: Token to use for padding
363
+
364
+ Returns:
365
+ List of observation strings
366
+ """
367
+ cur_actions, contents = self.postprocess_predictions(predictions)
368
+ next_obs, dones, valid_action, is_search = [], [], [], []
369
+
370
+ search_queries = [content for action, content in zip(cur_actions, contents) if action == 'search']
371
+ if do_search:
372
+ search_results = self.batch_search(search_queries)
373
+ assert len(search_results) == sum([1 for action in cur_actions if action == 'search'])
374
+ else:
375
+ search_results = [''] * sum([1 for action in cur_actions if action == 'search'])
376
+
377
+ for i, (action, active) in enumerate(zip(cur_actions, active_mask)):
378
+
379
+ if not active:
380
+ next_obs.append('')
381
+ dones.append(1)
382
+ valid_action.append(0)
383
+ is_search.append(0)
384
+ else:
385
+ if action == 'answer':
386
+ next_obs.append('')
387
+ dones.append(1)
388
+ valid_action.append(1)
389
+ is_search.append(0)
390
+ elif action == 'search':
391
+ next_obs.append(f'\n\n<information>{search_results.pop(0).strip()}</information>\n\n')
392
+ dones.append(0)
393
+ valid_action.append(1)
394
+ is_search.append(1)
395
+ else:
396
+ next_obs.append(f'\nMy previous action is invalid. \
397
+ If I want to search, I should put the query between <search> and </search>. \
398
+ If I want to give the final answer, I should put the answer between <answer> and </answer>. Let me try again.\n')
399
+ dones.append(0)
400
+ valid_action.append(0)
401
+ is_search.append(0)
402
+
403
+ assert len(search_results) == 0
404
+
405
+ return next_obs, dones, valid_action, is_search
406
+
407
+ def postprocess_predictions(self, predictions: List[Any]) -> Tuple[List[int], List[bool]]:
408
+ """
409
+ Process (text-based) predictions from llm into actions and validity flags.
410
+
411
+ Args:
412
+ predictions: List of raw predictions
413
+
414
+ Returns:
415
+ Tuple of (actions list, validity flags list)
416
+ """
417
+ actions = []
418
+ contents = []
419
+
420
+ for prediction in predictions:
421
+ if isinstance(prediction, str): # for llm output
422
+ pattern = r'<(search|answer)>(.*?)</\1>'
423
+ match = re.search(pattern, prediction, re.DOTALL)
424
+ if match:
425
+ content = match.group(2).strip() # Return only the content inside the tags
426
+ action = match.group(1)
427
+ else:
428
+ content = ''
429
+ action = None
430
+ else:
431
+ raise ValueError(f"Invalid prediction type: {type(prediction)}")
432
+
433
+ actions.append(action)
434
+ contents.append(content)
435
+
436
+ return actions, contents
437
+
438
+ def batch_search(self, queries: List[str] = None) -> str:
439
+ """
440
+ Batchified search for queries.
441
+ Args:
442
+ queries: queries to call the search engine
443
+ Returns:
444
+ search results which is concatenated into a string
445
+ """
446
+ results = self._batch_search(queries)['result']
447
+
448
+ return [self._passages2string(result) for result in results]
449
+
450
    def _batch_search(self, queries):
        """POST the queries to the retrieval server and return its JSON payload.

        The caller (batch_search) reads the 'result' key: one passage list per
        query.
        """
        payload = {
            "queries": queries,
            "topk": self.config.topk,  # passages retrieved per query
            "return_scores": True
        }

        # NOTE(review): no timeout is set — a stalled retrieval server will
        # hang the rollout loop; consider requests.post(..., timeout=...).
        return requests.post(self.config.search_url, json=payload).json()
459
+
460
+ def _passages2string(self, retrieval_result):
461
+ format_reference = ''
462
+ for idx, doc_item in enumerate(retrieval_result):
463
+
464
+ content = doc_item['document']['contents']
465
+ title = content.split("\n")[0]
466
+ text = "\n".join(content.split("\n")[1:])
467
+ format_reference += f"Doc {idx+1}(Title: {title}) {text}\n"
468
+
469
+ return format_reference
code/RL_model/verl/Search-R1/search_r1/llm_agent/tensor_helper.py ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from typing import Dict, Tuple, List
3
+ from dataclasses import dataclass
4
+
5
@dataclass
class TensorConfig:
    """Static limits and token ids shared by TensorHelper operations."""
    pad_token_id: int
    max_prompt_length: int
    max_obs_length: int
    max_start_length: int


class TensorHelper:
    """Padding-aware tensor utilities used by the rollout generation loop."""

    def __init__(self, config: TensorConfig):
        self.config = config

    def cut_to_effective_len(self, tensor_dict: Dict[str, torch.Tensor],
                             keys: List[str], cut_left: bool = True) -> Dict[str, torch.Tensor]:
        """Trim the listed tensors to the longest non-padded row length,
        keeping the right end when cut_left is True (left-padded data) and
        the left end otherwise."""
        effective_len = tensor_dict['attention_mask'].sum(dim=1).max()
        window = slice(-effective_len, None) if cut_left else slice(None, effective_len)
        trimmed = dict(tensor_dict)
        trimmed.update({key: tensor_dict[key][:, window] for key in keys})
        return trimmed

    def convert_pad_structure(self, tensor: torch.Tensor, pad_to_left: bool = True) -> Tuple[torch.Tensor, torch.Tensor]:
        """Move all pad tokens to one side; return the re-ordered tensor and
        the gather indices that produced it."""
        pad_id = self.config.pad_token_id
        # A stable argsort on a 0/1 key shifts 0-key entries left while
        # preserving relative order within each group.
        key = (tensor != pad_id) if pad_to_left else (tensor == pad_id)
        sorted_indices = key.to(torch.int64).argsort(dim=1, stable=True)
        return tensor.gather(1, sorted_indices), sorted_indices

    def create_attention_mask(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Return 1 for real tokens and 0 for pad tokens."""
        return (input_ids != self.config.pad_token_id).long()

    def create_position_ids(self, attention_mask: torch.Tensor) -> torch.Tensor:
        """Number non-pad tokens from 0; pad positions are forced to 0."""
        offsets = attention_mask.cumsum(dim=1) - 1
        return offsets * attention_mask

    def concatenate_with_padding(self, tensors: List[torch.Tensor],
                                 pad_to_left: bool = True) -> torch.Tensor:
        """Concatenate along dim 1, then regroup the padding onto one side."""
        merged = torch.cat(tensors, dim=1)
        regrouped, _ = self.convert_pad_structure(merged, pad_to_left)
        return regrouped

    def _example_level_pad(self, responses: torch.Tensor,
                           responses_str: List[str],
                           active_mask: torch.Tensor) -> Tuple[torch.Tensor, List[str]]:
        """
        Scatter the responses of the active examples into a full-batch tensor,
        filling inactive rows with pad tokens ('' for the string list).
        """
        assert active_mask.sum() == responses.shape[0]
        batch_size = active_mask.shape[0]
        # new_full preserves the dtype/device of `responses`.
        padded_responses = responses.new_full(
            (batch_size, responses.shape[1]), self.config.pad_token_id
        )
        padded_responses[active_mask] = responses

        remaining = iter(responses_str)
        padded_responses_str = [next(remaining) if flag else "" for flag in active_mask]

        return padded_responses, padded_responses_str
code/RL_model/verl/Search-R1/search_r1/search/build_index.sh ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
# Build a dense (or BM25) retrieval index over a JSONL corpus with
# index_builder.py. Edit the four variables below before running.

corpus_file=/your/corpus/jsonl/file # jsonl
save_dir=/the/path/to/save/index
retriever_name=e5 # this is for indexing naming
retriever_model=intfloat/e5-base-v2

# change faiss_type to HNSW32/64/128 for ANN indexing
# change retriever_name to bm25 for BM25 indexing
# --save_embedding keeps the raw embedding memmap so the FAISS index can be
# rebuilt later without re-encoding the corpus.
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python index_builder.py \
    --retrieval_method $retriever_name \
    --model_path $retriever_model \
    --corpus_path $corpus_file \
    --save_dir $save_dir \
    --use_fp16 \
    --max_length 256 \
    --batch_size 512 \
    --pooling_method mean \
    --faiss_type Flat \
    --save_embedding
code/RL_model/verl/Search-R1/search_r1/search/google_search_server.py ADDED
@@ -0,0 +1,202 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import re
3
+ import requests
4
+ import argparse
5
+ import asyncio
6
+ import random
7
+ from typing import List, Optional, Dict
8
+ from concurrent.futures import ThreadPoolExecutor
9
+
10
+ import chardet
11
+ import aiohttp
12
+ import bs4
13
+ import uvicorn
14
+ from fastapi import FastAPI
15
+ from pydantic import BaseModel
16
+ from googleapiclient.discovery import build
17
+
18
+
19
# --- CLI Args ---
# Parsed at import time: this module is intended to be run as a script
# (see the uvicorn entry point at the bottom of the file).
parser = argparse.ArgumentParser(description="Launch online search server.")
parser.add_argument('--api_key', type=str, required=True, help="API key for Google search")
parser.add_argument('--cse_id', type=str, required=True, help="CSE ID for Google search")
parser.add_argument('--topk', type=int, default=3, help="Number of results to return per query")
parser.add_argument('--snippet_only', action='store_true', help="If set, only return snippets; otherwise, return full context.")
args = parser.parse_args()
26
+
27
+
28
# --- Config ---
class OnlineSearchConfig:
    """Settings for the Google Custom Search proxy server."""

    def __init__(self, topk: int = 3, api_key: Optional[str] = None, cse_id: Optional[str] = None, snippet_only: bool = False):
        self.topk = topk                  # contexts returned per query
        self.api_key = api_key            # Google API key
        self.cse_id = cse_id              # Custom Search Engine id
        self.snippet_only = snippet_only  # skip page download, use snippets only
35
+
36
+
37
+ # --- Utilities ---
38
def parse_snippet(snippet: str) -> List[str]:
    """Split a search snippet on "..." and keep only substantial fragments
    (more than five whitespace-separated words)."""
    pieces = (part.strip() for part in snippet.split("..."))
    return [piece for piece in pieces if len(piece.split()) > 5]
41
+
42
+
43
def sanitize_search_query(query: str) -> str:
    """Normalize a query to plain words separated by single spaces.

    Non-alphanumeric symbols and control whitespace become spaces; runs of
    whitespace are collapsed and the ends trimmed.
    """
    steps = (
        (re.compile(r'[^\w\s]'), ' '),        # punctuation/symbols -> space
        (re.compile(r'[\t\r\f\v\n]'), ' '),   # control whitespace -> space
        (re.compile(r'\s+'), ' '),            # collapse runs of spaces
    )
    cleaned = query
    for pattern, replacement in steps:
        cleaned = pattern.sub(replacement, cleaned)
    return cleaned.strip()
51
+
52
+
53
def filter_links(search_results: List[Dict]) -> List[str]:
    """Collect the HTML-page links from raw Google CSE responses.

    Items carrying an explicit "mime" field (PDFs, images, ...) and links
    whose extension is not an HTML variant are skipped.
    """
    html_extensions = ("", ".html", ".htm", ".shtml")
    links = []
    for result in search_results:
        for item in result.get("items", []):
            if "mime" in item:  # explicit MIME type => not a plain web page
                continue
            if os.path.splitext(item["link"])[1] in html_extensions:
                links.append(item["link"])
    return links
63
+
64
+
65
async def fetch(session: aiohttp.ClientSession, url: str, semaphore: asyncio.Semaphore) -> str:
    """Download *url* and return its decoded body, or '' on any client/timeout
    error.

    The semaphore bounds the number of in-flight requests; the User-Agent is
    rotated per request to reduce the chance of being blocked.
    """
    user_agents = [
        "Mozilla/5.0 (Linux; Android 6.0.1; Nexus 5X Build/MMB29P)...",
        "Mozilla/5.0 AppleWebKit/537.36...",
        "Mozilla/5.0 (compatible; Googlebot/2.1; +https://www.google.com/bot.html)",
    ]
    headers = {"User-Agent": random.choice(user_agents)}

    async with semaphore:
        try:
            async with session.get(url, headers=headers) as response:
                raw = await response.read()
                # Sniff the charset from the bytes; servers often omit or
                # misstate it in the headers.
                detected = chardet.detect(raw)
                encoding = detected["encoding"] or "utf-8"
                return raw.decode(encoding, errors="ignore")
        except (aiohttp.ClientError, asyncio.TimeoutError):
            return ""
82
+
83
+
84
async def fetch_all(urls: List[str], limit: int = 8) -> List[str]:
    """Fetch all *urls* concurrently, at most *limit* in flight, with a 5 s
    total timeout per request; failed fetches come back as '' (see fetch)."""
    semaphore = asyncio.Semaphore(limit)
    timeout = aiohttp.ClientTimeout(total=5)
    connector = aiohttp.TCPConnector(limit_per_host=limit, force_close=True)

    async with aiohttp.ClientSession(timeout=timeout, connector=connector) as session:
        tasks = [fetch(session, url, semaphore) for url in urls]
        return await asyncio.gather(*tasks)
92
+
93
+
94
+ # --- Search Engine ---
95
class OnlineSearchEngine:
    """Google Custom Search wrapper that returns passage-style contexts."""

    def __init__(self, config: OnlineSearchConfig):
        self.config = config

    def collect_context(self, snippet: str, doc: str) -> str:
        """Expand each snippet fragment to the full newline-delimited paragraph
        of *doc* that contains it; duplicate paragraphs are dropped."""
        snippets = parse_snippet(snippet)
        ctx_paras = []

        for s in snippets:
            # Search a newline-flattened copy; offsets still line up with
            # `doc` because the replacement is length-preserving.
            pos = doc.replace("\n", " ").find(s)
            if pos == -1:
                continue
            sta = pos
            while sta > 0 and doc[sta] != "\n":
                sta -= 1
            end = pos + len(s)
            while end < len(doc) and doc[end] != "\n":
                end += 1
            para = doc[sta:end].strip()
            if para not in ctx_paras:
                ctx_paras.append(para)

        return "\n".join(ctx_paras)

    def fetch_web_content(self, search_results: List[Dict]) -> Dict[str, str]:
        """Download the HTML pages behind the result links and extract their
        paragraph text, keyed by URL."""
        links = filter_links(search_results)
        contents = asyncio.run(fetch_all(links))
        content_dict = {}
        for html, link in zip(contents, links):
            soup = bs4.BeautifulSoup(html, "html.parser")
            content_dict[link] = "\n".join(p.get_text() for p in soup.find_all("p"))
        return content_dict

    def search(self, search_term: str, num_iter: int = 1) -> List[Dict]:
        """Run the CSE query and return up to *num_iter* raw result pages.

        Whitespace-only queries short-circuit to an empty list (now before the
        service object is built).
        """
        results = []
        if search_term.isspace():
            return results
        service = build('customsearch', 'v1', developerKey=self.config.api_key)
        sanitize_search_term = sanitize_search_query(search_term)
        res = service.cse().list(q=sanitize_search_term, cx=self.config.cse_id).execute()
        results.append(res)

        for _ in range(num_iter - 1):
            if 'nextPage' not in res.get('queries', {}):
                break
            start_idx = res['queries']['nextPage'][0]['startIndex']
            res = service.cse().list(q=search_term, cx=self.config.cse_id, start=start_idx).execute()
            results.append(res)

        return results

    def batch_search(self, queries: List[str]) -> List[List[str]]:
        """Retrieve contexts for all queries concurrently (thread pool)."""
        with ThreadPoolExecutor() as executor:
            return list(executor.map(self._retrieve_context, queries))

    def _retrieve_context(self, query: str) -> List[str]:
        """Return up to config.topk context dicts for a single query.

        Fixed: the full-content branch previously referenced `search_results`
        before assignment (NameError whenever snippet_only was False); the
        search now runs once, up front, for both branches.
        """
        search_results = self.search(query)
        contexts = []

        if self.config.snippet_only:
            # Cheap path: use the CSE snippets directly, no page download.
            for result in search_results:
                for item in result.get("items", []):
                    title = item.get("title", "")
                    context = ' '.join(parse_snippet(item.get("snippet", "")))
                    if title != "" or context != "":
                        title = "No title." if not title else title
                        context = "No snippet available." if not context else context
                        contexts.append({
                            'document': {"contents": f'\"{title}\"\n{context}'},
                        })
        else:
            # Full path: download pages and expand snippets to paragraphs.
            content_dict = self.fetch_web_content(search_results)
            for result in search_results:
                for item in result.get("items", []):
                    link = item["link"]
                    title = item.get("title", "")
                    snippet = item.get("snippet", "")
                    if link in content_dict:
                        context = self.collect_context(snippet, content_dict[link])
                        if title != "" or context != "":
                            title = "No title." if not title else title
                            context = "No snippet available." if not context else context
                            contexts.append({
                                'document': {"contents": f'\"{title}\"\n{context}'},
                            })

        return contexts[:self.config.topk]
184
+
185
+
186
# --- FastAPI App ---
app = FastAPI(title="Online Search Proxy Server")

class SearchRequest(BaseModel):
    # Batch of search queries; mirrors the local retrieval server's schema.
    queries: List[str]

# Module-level singletons built from the CLI arguments parsed above.
config = OnlineSearchConfig(api_key=args.api_key, cse_id=args.cse_id, topk=args.topk, snippet_only=args.snippet_only)
engine = OnlineSearchEngine(config)

@app.post("/retrieve")
def search_endpoint(request: SearchRequest):
    # Same response shape as the dense retrieval server: one context list
    # per query under "result".
    results = engine.batch_search(request.queries)
    return {"result": results}


if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8000)
code/RL_model/verl/Search-R1/search_r1/search/index_builder.py ADDED
@@ -0,0 +1,349 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import faiss
3
+ import json
4
+ import warnings
5
+ import numpy as np
6
+ from typing import cast, List, Dict
7
+ import shutil
8
+ import subprocess
9
+ import argparse
10
+ import torch
11
+ from tqdm import tqdm
12
+ # from LongRAG.retriever.utils import load_model, load_corpus, pooling
13
+ import datasets
14
+ from transformers import AutoTokenizer, AutoModel, AutoConfig
15
+
16
+
17
def load_model(
    model_path: str,
    use_fp16: bool = False
):
    """Load a HuggingFace encoder and its tokenizer onto the GPU.

    Args:
        model_path: HF hub id or local checkpoint directory.
        use_fp16: cast the model to half precision after loading.

    Returns:
        (model, tokenizer) with the model in eval mode on CUDA.
    """
    # Fixed: dropped an unused AutoConfig.from_pretrained call (dead local
    # that triggered an extra config fetch); AutoModel loads the config itself.
    model = AutoModel.from_pretrained(model_path, trust_remote_code=True)
    model.eval()
    model.cuda()
    if use_fp16:
        model = model.half()
    tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True, trust_remote_code=True)

    return model, tokenizer
30
+
31
+
32
def pooling(
    pooler_output,
    last_hidden_state,
    attention_mask = None,
    pooling_method = "mean"
):
    """Reduce token-level hidden states to one embedding per sequence.

    Args:
        pooler_output: model pooler output (used only for "pooler").
        last_hidden_state: (batch, seq, hidden) hidden states.
        attention_mask: (batch, seq) mask, required for "mean".
        pooling_method: "mean" | "cls" | "pooler".

    Raises:
        NotImplementedError: for any other pooling_method.
    """
    if pooling_method == "mean":
        # Zero out padded positions, then average over the real tokens only.
        keep = attention_mask[..., None].bool()
        summed = last_hidden_state.masked_fill(~keep, 0.0).sum(dim=1)
        counts = attention_mask.sum(dim=1)[..., None]
        return summed / counts
    if pooling_method == "cls":
        return last_hidden_state[:, 0]
    if pooling_method == "pooler":
        return pooler_output
    raise NotImplementedError("Pooling method not implemented!")
47
+
48
+
49
def load_corpus(corpus_path: str):
    """Load a JSONL corpus (one document per line) as a HuggingFace dataset."""
    corpus = datasets.load_dataset(
        'json',
        data_files=corpus_path,
        split="train",
        num_proc=4)  # parallel JSON parsing
    return corpus
56
+
57
+
58
class Index_Builder:
    r"""Builds a retrieval index (BM25 via Pyserini, or dense via FAISS) over a
    JSONL corpus.
    """
    def __init__(
        self,
        retrieval_method,
        model_path,
        corpus_path,
        save_dir,
        max_length,
        batch_size,
        use_fp16,
        pooling_method,
        faiss_type=None,
        embedding_path=None,
        save_embedding=False,
        faiss_gpu=False
    ):
        """
        Args:
            retrieval_method: "bm25" or a dense encoder name (e.g. "e5").
            model_path: HF checkpoint for the dense encoder (unused for bm25).
            corpus_path: JSONL corpus with a "contents" field per document.
            save_dir: output directory for the index (and optional embeddings).
            max_length: tokenizer truncation length per document.
            batch_size: per-step encoding batch size (scaled by GPU count).
            use_fp16: load the encoder in half precision.
            pooling_method: "mean" | "cls" | "pooler" (see pooling()).
            faiss_type: faiss index-factory string; defaults to exact 'Flat'.
            embedding_path: reuse a precomputed embedding memmap instead of encoding.
            save_embedding: persist embeddings as a float32 memmap.
            faiss_gpu: build the FAISS index on all visible GPUs.
        """
        self.retrieval_method = retrieval_method.lower()
        self.model_path = model_path
        self.corpus_path = corpus_path
        self.save_dir = save_dir
        self.max_length = max_length
        self.batch_size = batch_size
        self.use_fp16 = use_fp16
        self.pooling_method = pooling_method
        self.faiss_type = faiss_type if faiss_type is not None else 'Flat'
        self.embedding_path = embedding_path
        self.save_embedding = save_embedding
        self.faiss_gpu = faiss_gpu

        self.gpu_num = torch.cuda.device_count()
        # prepare save dir
        print(self.save_dir)
        if not os.path.exists(self.save_dir):
            os.makedirs(self.save_dir)
        else:
            if not self._check_dir(self.save_dir):
                warnings.warn("Some files already exists in save dir and may be overwritten.", UserWarning)

        self.index_save_path = os.path.join(self.save_dir, f"{self.retrieval_method}_{self.faiss_type}.index")

        self.embedding_save_path = os.path.join(self.save_dir, f"emb_{self.retrieval_method}.memmap")

        self.corpus = load_corpus(self.corpus_path)

        print("Finish loading...")

    @staticmethod
    def _check_dir(dir_path):
        r"""Return True if *dir_path* is empty (creating it if missing), False
        if it already contains files.

        Fixed: an existing-but-empty directory previously fell through and
        returned None (falsy), which wrongly triggered the overwrite warning.
        """
        if os.path.isdir(dir_path):
            return len(os.listdir(dir_path)) == 0
        os.makedirs(dir_path, exist_ok=True)
        return True

    def build_index(self):
        r"""Dispatch to the BM25 or dense builder based on retrieval_method."""
        if self.retrieval_method == "bm25":
            self.build_bm25_index()
        else:
            self.build_dense_index()

    def build_bm25_index(self):
        """Building BM25 index based on Pyserini library.

        Reference: https://github.com/castorini/pyserini/blob/master/docs/usage-index.md#building-a-bm25-index-direct-java-implementation
        """
        # to use pyserini pipeline, we first need to place the jsonl file in a folder
        self.save_dir = os.path.join(self.save_dir, "bm25")
        os.makedirs(self.save_dir, exist_ok=True)
        temp_dir = self.save_dir + "/temp"
        temp_file_path = temp_dir + "/temp.jsonl"
        # Fixed: exist_ok=True so a rerun after a failed build does not crash.
        os.makedirs(temp_dir, exist_ok=True)

        shutil.copyfile(self.corpus_path, temp_file_path)

        print("Start building bm25 index...")
        pyserini_args = ["--collection", "JsonCollection",
                         "--input", temp_dir,
                         "--index", self.save_dir,
                         "--generator", "DefaultLuceneDocumentGenerator",
                         "--threads", "1"]

        subprocess.run(["python", "-m", "pyserini.index.lucene"] + pyserini_args)

        shutil.rmtree(temp_dir)

        print("Finish!")

    def _load_embedding(self, embedding_path, corpus_size, hidden_size):
        """Memory-map a previously saved float32 embedding matrix."""
        all_embeddings = np.memmap(
            embedding_path,
            mode="r",
            dtype=np.float32
        ).reshape(corpus_size, hidden_size)
        return all_embeddings

    def _save_embedding(self, all_embeddings):
        """Persist embeddings to a writable memmap, chunked to bound peak memory."""
        memmap = np.memmap(
            self.embedding_save_path,
            shape=all_embeddings.shape,
            mode="w+",
            dtype=all_embeddings.dtype
        )
        length = all_embeddings.shape[0]
        # add in batch
        save_batch_size = 10000
        if length > save_batch_size:
            for i in tqdm(range(0, length, save_batch_size), leave=False, desc="Saving Embeddings"):
                j = min(i + save_batch_size, length)
                memmap[i: j] = all_embeddings[i: j]
        else:
            memmap[:] = all_embeddings

    def encode_all(self):
        """Encode the whole corpus; returns a float32 (N, hidden) array."""
        if self.gpu_num > 1:
            print("Use multi gpu!")
            self.encoder = torch.nn.DataParallel(self.encoder)
            self.batch_size = self.batch_size * self.gpu_num

        all_embeddings = []

        for start_idx in tqdm(range(0, len(self.corpus), self.batch_size), desc='Inference Embeddings:'):

            batch_data = self.corpus[start_idx:start_idx+self.batch_size]['contents']

            # E5 expects a "passage: " prefix on document-side inputs.
            if self.retrieval_method == "e5":
                batch_data = [f"passage: {doc}" for doc in batch_data]

            inputs = self.tokenizer(
                batch_data,
                padding=True,
                truncation=True,
                return_tensors='pt',
                max_length=self.max_length,
            ).to('cuda')
            # Fixed: removed a second, redundant per-tensor .cuda() move —
            # the BatchEncoding above is already on the GPU.

            #TODO: support encoder-only T5 model
            if "T5" in type(self.encoder).__name__:
                # T5-based retrieval model: feed a single decoder start token
                # and use its hidden state as the embedding.
                decoder_input_ids = torch.zeros(
                    (inputs['input_ids'].shape[0], 1), dtype=torch.long
                ).to(inputs['input_ids'].device)
                output = self.encoder(
                    **inputs, decoder_input_ids=decoder_input_ids, return_dict=True
                )
                embeddings = output.last_hidden_state[:, 0, :]

            else:
                output = self.encoder(**inputs, return_dict=True)
                embeddings = pooling(output.pooler_output,
                                     output.last_hidden_state,
                                     inputs['attention_mask'],
                                     self.pooling_method)
                # DPR uses raw inner-product scores; other encoders are
                # L2-normalized so inner product == cosine similarity.
                if "dpr" not in self.retrieval_method:
                    embeddings = torch.nn.functional.normalize(embeddings, dim=-1)

            embeddings = cast(torch.Tensor, embeddings)
            embeddings = embeddings.detach().cpu().numpy()
            all_embeddings.append(embeddings)

        all_embeddings = np.concatenate(all_embeddings, axis=0)
        all_embeddings = all_embeddings.astype(np.float32)

        return all_embeddings

    @torch.no_grad()
    def build_dense_index(self):
        """Obtain document representations with the (BERT-style) embedding
        model — or load cached embeddings — and write a FAISS index.
        """
        if os.path.exists(self.index_save_path):
            print("The index file already exists and will be overwritten.")

        self.encoder, self.tokenizer = load_model(model_path=self.model_path,
                                                  use_fp16=self.use_fp16)
        if self.embedding_path is not None:
            hidden_size = self.encoder.config.hidden_size
            corpus_size = len(self.corpus)
            all_embeddings = self._load_embedding(self.embedding_path, corpus_size, hidden_size)
        else:
            all_embeddings = self.encode_all()
            if self.save_embedding:
                self._save_embedding(all_embeddings)
            del self.corpus  # free the raw corpus before building the index

        # build index
        print("Creating index")
        dim = all_embeddings.shape[-1]
        faiss_index = faiss.index_factory(dim, self.faiss_type, faiss.METRIC_INNER_PRODUCT)

        if self.faiss_gpu:
            co = faiss.GpuMultipleClonerOptions()
            co.useFloat16 = True
            co.shard = True
            faiss_index = faiss.index_cpu_to_all_gpus(faiss_index, co)
            if not faiss_index.is_trained:
                faiss_index.train(all_embeddings)
            faiss_index.add(all_embeddings)
            faiss_index = faiss.index_gpu_to_cpu(faiss_index)
        else:
            if not faiss_index.is_trained:
                faiss_index.train(all_embeddings)
            faiss_index.add(all_embeddings)

        faiss.write_index(faiss_index, self.index_save_path)
        print("Finish!")
287
+
288
+
289
# Default pooling strategy per embedding-model family; matched by substring
# against the retrieval method name in main() when --pooling_method is absent.
MODEL2POOLING = {
    "e5": "mean",
    "bge": "cls",
    "contriever": "mean",
    'jina': 'mean'
}
295
+
296
+
297
def main():
    """CLI entry point: parse arguments, resolve the pooling method, and
    build the index."""
    parser = argparse.ArgumentParser(description = "Creating index.")

    # Basic parameters
    parser.add_argument('--retrieval_method', type=str)
    parser.add_argument('--model_path', type=str, default=None)
    parser.add_argument('--corpus_path', type=str)
    parser.add_argument('--save_dir', default= 'indexes/',type=str)

    # Parameters for building dense index
    parser.add_argument('--max_length', type=int, default=180)
    parser.add_argument('--batch_size', type=int, default=512)
    parser.add_argument('--use_fp16', default=False, action='store_true')
    parser.add_argument('--pooling_method', type=str, default=None)
    parser.add_argument('--faiss_type',default=None,type=str)
    parser.add_argument('--embedding_path', default=None, type=str)
    parser.add_argument('--save_embedding', action='store_true', default=False)
    parser.add_argument('--faiss_gpu', default=False, action='store_true')

    args = parser.parse_args()

    # Infer the pooling method from the model family (MODEL2POOLING) when not
    # given explicitly; fall back to mean pooling.
    if args.pooling_method is None:
        pooling_method = 'mean'
        for k,v in MODEL2POOLING.items():
            if k in args.retrieval_method.lower():
                pooling_method = v
                break
    else:
        if args.pooling_method not in ['mean','cls','pooler']:
            raise NotImplementedError
        else:
            pooling_method = args.pooling_method


    index_builder = Index_Builder(
        retrieval_method = args.retrieval_method,
        model_path = args.model_path,
        corpus_path = args.corpus_path,
        save_dir = args.save_dir,
        max_length = args.max_length,
        batch_size = args.batch_size,
        use_fp16 = args.use_fp16,
        pooling_method = pooling_method,
        faiss_type = args.faiss_type,
        embedding_path = args.embedding_path,
        save_embedding = args.save_embedding,
        faiss_gpu = args.faiss_gpu
    )
    index_builder.build_index()


if __name__ == "__main__":
    main()
code/RL_model/verl/Search-R1/search_r1/search/rerank_server.py ADDED
@@ -0,0 +1,161 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ from collections import defaultdict
3
+ from typing import Optional
4
+ from dataclasses import dataclass, field
5
+
6
+ from sentence_transformers import CrossEncoder
7
+ import torch
8
+ from transformers import HfArgumentParser
9
+ import numpy as np
10
+
11
+ import uvicorn
12
+ from fastapi import FastAPI
13
+ from pydantic import BaseModel
14
+
15
+
16
class BaseCrossEncoder:
    """Base class for cross-encoder rerankers.

    Subclasses implement ``_predict`` (score a list of (query, passage)
    pairs) and ``load`` (construct from a model name/path).
    """

    def __init__(self, model, batch_size=32, device="cuda"):
        self.model = model
        self.batch_size = batch_size
        self.model.to(device)

    def _passage_to_string(self, doc_item):
        """Flatten one retrieval-server doc dict into '(Title: t) body'."""
        # Accept both raw items ({"contents": ...}) and wrapped items
        # ({"document": {"contents": ...}}).
        if "document" not in doc_item:
            content = doc_item['contents']
        else:
            content = doc_item['document']['contents']
        # Convention: first line of `contents` is the title, rest is body.
        title = content.split("\n")[0]
        text = "\n".join(content.split("\n")[1:])

        return f"(Title: {title}) {text}"

    def rerank(self,
               queries: list[str],
               documents: list[list[dict]]):
        """Score each query against its own candidate documents.

        ``documents[i]`` holds the candidates for ``queries[i]`` (retrieval
        server output format: dicts with "id"/"contents").

        Returns:
            dict mapping query index -> list of (passage_string, score),
            sorted by descending score.

        BUGFIX: the original inner loop iterated over *all* queries'
        documents (``for document in documents``), so every query was
        scored against every other query's candidates as well. Each query
        is now paired only with ``documents[qid]``.
        """
        assert len(queries) == len(documents)

        pairs = []
        qids = []
        for qid, query in enumerate(queries):
            for doc_item in documents[qid]:
                doc = self._passage_to_string(doc_item)
                pairs.append((query, doc))
                qids.append(qid)

        scores = self._predict(pairs)
        query_to_doc_scores = defaultdict(list)

        assert len(scores) == len(pairs) == len(qids)
        for (query, doc), score, qid in zip(pairs, scores, qids):
            query_to_doc_scores[qid].append((doc, score))

        # Sort each query's candidates by descending relevance score.
        sorted_query_to_doc_scores = {}
        for qid, doc_scores in query_to_doc_scores.items():
            sorted_query_to_doc_scores[qid] = sorted(doc_scores, key=lambda x: x[1], reverse=True)

        return sorted_query_to_doc_scores

    def _predict(self, pairs: list[tuple[str, str]]):
        raise NotImplementedError

    @classmethod
    def load(cls, model_name_or_path, **kwargs):
        raise NotImplementedError
72
+
73
+
74
class SentenceTransformerCrossEncoder(BaseCrossEncoder):
    """Cross-encoder reranker backed by sentence-transformers' CrossEncoder."""

    def __init__(self, model, batch_size=32, device="cuda"):
        super().__init__(model, batch_size, device)

    def _predict(self, pairs: list[tuple[str, str]]):
        # CrossEncoder.predict may return a torch tensor or numpy array;
        # normalize the output to a plain Python list either way.
        raw_scores = self.model.predict(pairs, batch_size=self.batch_size)
        if isinstance(raw_scores, (torch.Tensor, np.ndarray)):
            return raw_scores.tolist()
        return raw_scores

    @classmethod
    def load(cls, model_name_or_path, **kwargs):
        return cls(CrossEncoder(model_name_or_path), **kwargs)
87
+
88
+
89
class RerankRequest(BaseModel):
    """Body schema for POST /rerank."""
    queries: list[str]                  # one query string per batch item
    documents: list[list[dict]]         # per-query candidate docs (retrieval-server format)
    rerank_topk: Optional[int] = None   # falls back to the server default when omitted/0
    return_scores: bool = False         # include scores alongside documents in the response
94
+
95
+
96
@dataclass
class RerankerArguments:
    """CLI-configurable settings for the rerank server."""
    # Max sequence length; not read in this file — presumably consumed by
    # the cross-encoder model config. TODO confirm.
    max_length: int = field(default=512)
    # Default number of documents kept per query when a request omits rerank_topk.
    rerank_topk: int = field(default=3)
    # HF model id or local path of the cross-encoder model.
    rerank_model_name_or_path: str = field(default="cross-encoder/ms-marco-MiniLM-L12-v2")
    # Inference batch size for pair scoring.
    batch_size: int = field(default=32)
    # Only "sentence_transformer" is handled by get_reranker.
    reranker_type: str = field(default="sentence_transformer")
103
+
104
def get_reranker(config):
    """Build the reranker selected by ``config.reranker_type``.

    Raises:
        ValueError: for any type other than "sentence_transformer".
    """
    if config.reranker_type != "sentence_transformer":
        raise ValueError(f"Unknown reranker type: {config.reranker_type}")
    # Prefer GPU when one is visible; fall back to CPU inference.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    return SentenceTransformerCrossEncoder.load(
        config.rerank_model_name_or_path,
        batch_size=config.batch_size,
        device=device,
    )
113
+
114
+
115
app = FastAPI()

@app.post("/rerank")
def rerank_endpoint(request: RerankRequest):
    """
    Endpoint that accepts queries and performs reranking.
    Input format:
    {
      "queries": ["What is Python?", "Tell me about neural networks."],
      "documents": [[doc_item_1, ..., doc_item_k], [doc_item_1, ..., doc_item_k]],
      "rerank_topk": 3,
      "return_scores": true
    }
    """
    # Fall back to the server-wide default when the client omits topk (or sends 0).
    if not request.rerank_topk:
        request.rerank_topk = config.rerank_topk  # fallback to default

    # Batched reranking; each query's doc_scores come back sorted by score.
    query_to_doc_scores = reranker.rerank(request.queries, request.documents)

    # Shape the response: keep only the top-k entries per query.
    resp = []
    for doc_scores in query_to_doc_scores.values():
        kept = doc_scores[:request.rerank_topk]
        if request.return_scores:
            resp.append([{"document": doc, "score": score} for doc, score in kept])
        else:
            resp.append([doc for doc, _ in kept])
    return {"result": resp}
148
+
149
+
150
if __name__ == "__main__":

    # Parse CLI flags into the reranker config (HfArgumentParser handles
    # dataclass fields as arguments).
    config = HfArgumentParser(RerankerArguments).parse_args_into_dataclasses()[0]

    # Instantiate the reranker once at startup; the endpoint reuses this
    # module-level instance across requests.
    reranker = get_reranker(config)

    # Serve on all interfaces, port 6980.
    uvicorn.run(app, host="0.0.0.0", port=6980)
code/RL_model/verl/Search-R1/search_r1/search/retrieval.py ADDED
@@ -0,0 +1,368 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import os
3
+ import warnings
4
+ from typing import List, Dict
5
+ import functools
6
+ from tqdm import tqdm
7
+ from multiprocessing import Pool
8
+ import faiss
9
+ import torch
10
+ import numpy as np
11
+ from transformers import AutoConfig, AutoTokenizer, AutoModel
12
+ import argparse
13
+ import datasets
14
+
15
+
16
def load_corpus(corpus_path: str):
    """Load a JSONL corpus file as a HuggingFace dataset (its "train" split)."""
    return datasets.load_dataset(
        'json',
        data_files=corpus_path,
        split="train",
        num_proc=4,
    )
23
+
24
+
25
def read_jsonl(file_path):
    """Parse a JSONL file into a list of decoded objects, one per line."""
    with open(file_path, "r") as f:
        return [json.loads(line) for line in f]
33
+
34
+
35
def load_docs(corpus, doc_idxs):
    """Fetch corpus rows for the given indices (indices are coerced to int)."""
    return [corpus[int(i)] for i in doc_idxs]
39
+
40
+
41
def load_model(
    model_path: str,
    use_fp16: bool = False
):
    """Load a HF encoder model and tokenizer onto the default CUDA device.

    Args:
        model_path: HF model id or local checkpoint directory.
        use_fp16: cast model weights to fp16 after loading.

    Returns:
        (model, tokenizer) with the model in eval mode on GPU.
    """
    # NOTE(review): the config is loaded but never used — kept as-is.
    model_config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
    model = AutoModel.from_pretrained(model_path, trust_remote_code=True)
    model.eval()
    # Assumes a CUDA device is available — TODO confirm for CPU-only hosts.
    model.cuda()
    if use_fp16:
        model = model.half()
    tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True, trust_remote_code=True)

    return model, tokenizer
54
+
55
+
56
def pooling(
    pooler_output,
    last_hidden_state,
    attention_mask = None,
    pooling_method = "mean"
):
    """Reduce per-token hidden states to one vector per sequence.

    Supported methods: "mean" (mask-weighted average), "cls" (first token),
    "pooler" (model's pooler output). Anything else raises.
    """
    if pooling_method == "cls":
        return last_hidden_state[:, 0]
    if pooling_method == "pooler":
        return pooler_output
    if pooling_method == "mean":
        # Zero out padding positions, then average over real tokens only.
        pad_positions = ~attention_mask[..., None].bool()
        summed = last_hidden_state.masked_fill(pad_positions, 0.0).sum(dim=1)
        return summed / attention_mask.sum(dim=1)[..., None]
    raise NotImplementedError("Pooling method not implemented!")
71
+
72
+
73
class Encoder:
    """Wraps a HF bi-encoder; turns text into float32 embedding vectors."""

    def __init__(self, model_name, model_path, pooling_method, max_length, use_fp16):
        self.model_name = model_name          # retriever name; drives model-specific prefixes below
        self.model_path = model_path
        self.pooling_method = pooling_method  # "mean" | "cls" | "pooler"
        self.max_length = max_length          # tokenizer truncation length
        self.use_fp16 = use_fp16

        self.model, self.tokenizer = load_model(model_path=model_path,
                                                use_fp16=use_fp16)

    @torch.no_grad()
    def encode(self, query_list: List[str], is_query=True) -> np.ndarray:
        """Embed a string or list of strings; returns an (n, dim) float32 array."""
        # processing query for different encoders
        if isinstance(query_list, str):
            query_list = [query_list]

        # E5 models expect "query: " / "passage: " prefixes depending on side.
        if "e5" in self.model_name.lower():
            if is_query:
                query_list = [f"query: {query}" for query in query_list]
            else:
                query_list = [f"passage: {query}" for query in query_list]

        # BGE models expect an instruction prefix on the query side only.
        if "bge" in self.model_name.lower():
            if is_query:
                query_list = [f"Represent this sentence for searching relevant passages: {query}" for query in query_list]

        inputs = self.tokenizer(query_list,
                                max_length=self.max_length,
                                padding=True,
                                truncation=True,
                                return_tensors="pt"
                                )
        inputs = {k: v.cuda() for k, v in inputs.items()}

        if "T5" in type(self.model).__name__:
            # T5-based retrieval model: run the decoder one step and take
            # the first decoder hidden state as the embedding.
            decoder_input_ids = torch.zeros(
                (inputs['input_ids'].shape[0], 1), dtype=torch.long
            ).to(inputs['input_ids'].device)
            output = self.model(
                **inputs, decoder_input_ids=decoder_input_ids, return_dict=True
            )
            query_emb = output.last_hidden_state[:, 0, :]

        else:
            output = self.model(**inputs, return_dict=True)
            query_emb = pooling(output.pooler_output,
                                output.last_hidden_state,
                                inputs['attention_mask'],
                                self.pooling_method)
            # DPR embeddings are compared unnormalized; all others are
            # L2-normalized — presumably to make inner product = cosine.
            if "dpr" not in self.model_name.lower():
                query_emb = torch.nn.functional.normalize(query_emb, dim=-1)

        # faiss expects C-contiguous float32 on the CPU.
        query_emb = query_emb.detach().cpu().numpy()
        query_emb = query_emb.astype(np.float32, order="C")
        return query_emb
130
+
131
+
132
class BaseRetriever:
    """Common interface shared by all retrievers."""

    def __init__(self, config):
        self.config = config
        self.retrieval_method = config.retrieval_method
        self.topk = config.retrieval_topk

        self.index_path = config.index_path
        self.corpus_path = config.corpus_path

        # self.cache_save_path = os.path.join(config.save_dir, 'retrieval_cache.json')

    def _search(self, query: str, num: int, return_score: bool) -> List[Dict[str, str]]:
        r"""Retrieve the top-``num`` documents for one query.

        Each returned dict carries at least ``contents`` (index text), plus
        ``title``/``text`` when the corpus provides them. No-op in the base.
        """
        pass

    def _batch_search(self, query_list, num, return_score):
        # Subclasses provide a batched implementation. No-op in the base.
        pass

    def search(self, *args, **kwargs):
        # Public entry point; delegates to the subclass hook.
        return self._search(*args, **kwargs)

    def batch_search(self, *args, **kwargs):
        return self._batch_search(*args, **kwargs)
163
+
164
+
165
class BM25Retriever(BaseRetriever):
    r"""BM25 retriever based on pre-built pyserini index."""

    def __init__(self, config):
        super().__init__(config)
        # Imported lazily so the module works without pyserini installed.
        from pyserini.search.lucene import LuceneSearcher
        self.searcher = LuceneSearcher(self.index_path)
        self.contain_doc = self._check_contain_doc()
        # When the index stores no raw text, fall back to the JSONL corpus.
        if not self.contain_doc:
            self.corpus = load_corpus(self.corpus_path)
        self.max_process_num = 8

    def _check_contain_doc(self):
        r"""Check if the index contains document content
        """
        return self.searcher.doc(0).raw() is not None

    def _search(self, query: str, num: int = None, return_score = False) -> List[Dict[str, str]]:
        if num is None:
            num = self.topk

        hits = self.searcher.search(query, num)
        if len(hits) < 1:
            if return_score:
                return [],[]
            else:
                return []

        # Scores are collected before the (no-op) truncation below; pyserini
        # already returns at most `num` hits.
        scores = [hit.score for hit in hits]
        if len(hits) < num:
            warnings.warn('Not enough documents retrieved!')
        else:
            hits = hits[:num]

        if self.contain_doc:
            # Raw docs live in the index: split "title\nbody" out of contents.
            all_contents = [json.loads(self.searcher.doc(hit.docid).raw())['contents'] for hit in hits]
            results = [{'title': content.split("\n")[0].strip("\""),
                        'text': "\n".join(content.split("\n")[1:]),
                        'contents': content} for content in all_contents]
        else:
            results = load_docs(self.corpus, [hit.docid for hit in hits])

        if return_score:
            return results, scores
        else:
            return results

    def _batch_search(self, query_list, num: int = None, return_score = False):
        # TODO: modify batch method
        # Sequential per-query search; no true batching in pyserini here.
        results = []
        scores = []
        for query in query_list:
            item_result, item_score = self._search(query, num,True)
            results.append(item_result)
            scores.append(item_score)

        if return_score:
            return results, scores
        else:
            return results
225
+
226
def get_available_gpu_memory():
    """Return [(device_index, free_memory_in_GB)] for every visible CUDA device.

    "Free" here is total device memory minus memory currently allocated by
    this process's torch allocator.
    """
    stats = []
    for dev in range(torch.cuda.device_count()):
        total = torch.cuda.get_device_properties(dev).total_memory
        free = total - torch.cuda.memory_allocated(dev)
        stats.append((dev, free / 1e9))  # bytes -> GB
    return stats
234
+
235
+
236
class DenseRetriever(BaseRetriever):
    r"""Dense retriever based on pre-built faiss index."""

    def __init__(self, config: dict):
        super().__init__(config)
        self.index = faiss.read_index(self.index_path)
        if config.faiss_gpu:
            # Shard the index across all visible GPUs in fp16 to cut memory.
            co = faiss.GpuMultipleClonerOptions()
            co.useFloat16 = True
            co.shard = True
            self.index = faiss.index_cpu_to_all_gpus(self.index, co=co)
            # self.index = faiss.index_cpu_to_all_gpus(self.index)

        self.corpus = load_corpus(self.corpus_path)
        self.encoder = Encoder(
            model_name = self.retrieval_method,
            model_path = config.retrieval_model_path,
            pooling_method = config.retrieval_pooling_method,
            max_length = config.retrieval_query_max_length,
            use_fp16 = config.retrieval_use_fp16
        )
        self.topk = config.retrieval_topk
        self.batch_size = self.config.retrieval_batch_size

    def _search(self, query: str, num: int = None, return_score = False):
        """Encode one query and return its top-`num` corpus entries."""
        if num is None:
            num = self.topk
        query_emb = self.encoder.encode(query)
        scores, idxs = self.index.search(query_emb, k=num)
        # faiss returns (1, num) arrays for a single query; take row 0.
        idxs = idxs[0]
        scores = scores[0]

        results = load_docs(self.corpus, idxs)
        if return_score:
            return results, scores
        else:
            return results

    def _batch_search(self, query_list: List[str], num: int = None, return_score = False):
        """Encode queries in batches and search the index once per batch."""
        if isinstance(query_list, str):
            query_list = [query_list]
        if num is None:
            num = self.topk

        batch_size = self.batch_size

        results = []
        scores = []

        for start_idx in tqdm(range(0, len(query_list), batch_size), desc='Retrieval process: '):
            query_batch = query_list[start_idx:start_idx + batch_size]

            # from time import time
            # a = time()
            batch_emb = self.encoder.encode(query_batch)
            # b = time()
            # print(f'################### encode time {b-a} #####################')
            batch_scores, batch_idxs = self.index.search(batch_emb, k=num)
            batch_scores = batch_scores.tolist()
            batch_idxs = batch_idxs.tolist()
            # print(f'################### search time {time()-b} #####################')
            # exit()

            # Flatten the (batch, num) ids, fetch docs once, then regroup
            # into per-query lists of length `num`.
            flat_idxs = sum(batch_idxs, [])
            batch_results = load_docs(self.corpus, flat_idxs)
            batch_results = [batch_results[i*num : (i+1)*num] for i in range(len(batch_idxs))]

            scores.extend(batch_scores)
            results.extend(batch_results)

        if return_score:
            return results, scores
        else:
            return results
310
+
311
def get_retriever(config):
    r"""Automatically select retriever class based on config's retrieval method

    Args:
        config (dict): configuration with 'retrieval_method' key

    Returns:
        Retriever: retriever instance
    """
    # "bm25" gets the sparse pyserini retriever; everything else is dense.
    if config.retrieval_method == "bm25":
        return BM25Retriever(config)
    return DenseRetriever(config)
324
+
325
+
326
def get_dataset(config):
    """Read ``{data_split}.jsonl`` from the configured dataset directory."""
    return read_jsonl(os.path.join(config.dataset_path, f'{config.data_split}.jsonl'))
331
+
332
+
333
if __name__ == '__main__':

    parser = argparse.ArgumentParser(description = "Retrieval")

    # Basic parameters
    parser.add_argument('--retrieval_method', type=str)
    parser.add_argument('--retrieval_topk', type=int, default=10)
    parser.add_argument('--index_path', type=str, default=None)
    parser.add_argument('--corpus_path', type=str)
    parser.add_argument('--dataset_path', default=None, type=str)

    # BUGFIX: argparse `type=bool` treats ANY non-empty string as True
    # ("--faiss_gpu False" used to enable the GPU); parse truthy/falsy text
    # explicitly while keeping the True default.
    def _str2bool(value):
        return str(value).lower() not in ('false', '0', 'no', '')

    parser.add_argument('--faiss_gpu', default=True, type=_str2bool)
    parser.add_argument('--data_split', default="train", type=str)

    parser.add_argument('--retrieval_model_path', type=str, default=None)
    parser.add_argument('--retrieval_pooling_method', default='mean', type=str)
    # BUGFIX: was `type=str`, which handed a *string* max_length to the
    # tokenizer whenever the flag was passed on the command line.
    parser.add_argument('--retrieval_query_max_length', default=256, type=int)
    parser.add_argument('--retrieval_use_fp16', action='store_true', default=False)
    parser.add_argument('--retrieval_batch_size', default=512, type=int)

    args = parser.parse_args()

    # Dense indexes live at <index_path>/<method>_Flat.index; BM25 at <index_path>/bm25.
    args.index_path = os.path.join(args.index_path, f'{args.retrieval_method}_Flat.index') if args.retrieval_method != 'bm25' else os.path.join(args.index_path, 'bm25')

    # Load the dataset and use the first 512 questions as queries.
    all_split = get_dataset(args)

    input_query = [sample['question'] for sample in all_split[:512]]

    # Initialize the retriever and conduct batched retrieval.
    retriever = get_retriever(args)
    print('Start Retrieving ...')
    results, scores = retriever.batch_search(input_query, return_score=True)

    # from IPython import embed
    # embed()
code/RL_model/verl/Search-R1/search_r1/search/retrieval.sh ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
# Run batch retrieval over the wiki-18 corpus with an e5 dense retriever.
DATA_NAME=nq

DATASET_PATH="/home/peterjin/mnt/data/$DATA_NAME"

SPLIT='test'
TOPK=3

INDEX_PATH=/home/peterjin/mnt/index/wiki-18
CORPUS_PATH=/home/peterjin/mnt/data/retrieval-corpus/wiki-18.jsonl
SAVE_NAME=e5_${TOPK}_wiki18.json

# Alternative: 2021 Wikipedia dump
# INDEX_PATH=/home/peterjin/rm_retrieval_corpus/index/wiki-21
# CORPUS_PATH=/home/peterjin/rm_retrieval_corpus/corpora/wiki/enwiki-dec2021/text-list-100-sec.jsonl
# SAVE_NAME=e5_${TOPK}_wiki21.json

# NOTE(review): SAVE_NAME is defined but never passed to retrieval.py — confirm intended.
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python retrieval.py --retrieval_method e5 \
    --retrieval_topk $TOPK \
    --index_path $INDEX_PATH \
    --corpus_path $CORPUS_PATH \
    --dataset_path $DATASET_PATH \
    --data_split $SPLIT \
    --retrieval_model_path "intfloat/e5-base-v2" \
    --retrieval_pooling_method "mean" \
    --retrieval_batch_size 512 \
code/RL_model/verl/Search-R1/search_r1/search/retrieval_request.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import requests

# URL for your local FastAPI server
url = "http://127.0.0.1:8000/retrieve"

# Example payload: 400 queries, top-5, scores included.
payload = {
    "queries": ["What is the capital of France?", "Explain neural networks."] * 200,
    "topk": 5,
    "return_scores": True,
}

# Send the POST request and fail loudly on any HTTP error status.
response = requests.post(url, json=payload)
response.raise_for_status()

# Decode and display the JSON response.
retrieved_data = response.json()
print("Response from server:")
print(retrieved_data)
code/RL_model/verl/Search-R1/search_r1/search/retrieval_rerank_server.py ADDED
@@ -0,0 +1,123 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # pip install -U sentence-transformers
2
+ import os
3
+ import re
4
+ import argparse
5
+ from dataclasses import dataclass, field
6
+ from typing import List, Optional
7
+ from collections import defaultdict
8
+
9
+ import torch
10
+ import numpy as np
11
+ from fastapi import FastAPI
12
+ from pydantic import BaseModel
13
+ from sentence_transformers import CrossEncoder
14
+
15
+ from retrieval_server import get_retriever, Config as RetrieverConfig
16
+ from rerank_server import SentenceTransformerCrossEncoder
17
+
18
+ app = FastAPI()
19
+
20
def convert_title_format(text):
    """Rewrite '(Title: X) body' into '"X"\\nbody'; pass other text through."""
    m = re.match(r'\(Title:\s*([^)]+)\)\s*(.+)', text, re.DOTALL)
    if m is None:
        return text
    title, content = m.groups()
    return f'"{title}"\n{content}'
28
+
29
# ----------- Combined Request Schema -----------
class SearchRequest(BaseModel):
    """Body schema for POST /retrieve (retrieval followed by reranking)."""
    queries: List[str]                  # batch of query strings
    topk_retrieval: Optional[int] = 10  # candidates fetched per query before reranking
    topk_rerank: Optional[int] = 3      # candidates kept per query after reranking
    return_scores: bool = False         # include rerank scores in the response
35
+
36
# ----------- Reranker Config Schema -----------
# NOTE(review): duplicates RerankerArguments in rerank_server.py — consider
# importing it from there instead of redefining.
@dataclass
class RerankerArguments:
    """Settings for the cross-encoder reranker half of this server."""
    # Max sequence length; not read in this file — presumably consumed by
    # the cross-encoder model config. TODO confirm.
    max_length: int = field(default=512)
    # Default number of documents kept per query after reranking.
    rerank_topk: int = field(default=3)
    # HF model id or local path of the cross-encoder model.
    rerank_model_name_or_path: str = field(default="cross-encoder/ms-marco-MiniLM-L12-v2")
    # Inference batch size for pair scoring.
    batch_size: int = field(default=32)
    # Only "sentence_transformer" is handled by get_reranker.
    reranker_type: str = field(default="sentence_transformer")
44
+
45
def get_reranker(config):
    """Build the reranker selected by ``config.reranker_type``.

    Raises:
        ValueError: for any type other than "sentence_transformer".
    """
    if config.reranker_type != "sentence_transformer":
        raise ValueError(f"Unknown reranker type: {config.reranker_type}")
    # Prefer GPU when one is visible; fall back to CPU inference.
    target_device = "cuda" if torch.cuda.is_available() else "cpu"
    return SentenceTransformerCrossEncoder.load(
        config.rerank_model_name_or_path,
        batch_size=config.batch_size,
        device=target_device,
    )
54
+
55
# ----------- Endpoint -----------
@app.post("/retrieve")
def search_endpoint(request: SearchRequest):
    """Retrieve-then-rerank pipeline for a batch of queries."""
    # Step 1: fetch candidate documents for every query.
    candidates = retriever.batch_search(
        query_list=request.queries,
        num=request.topk_retrieval,
        return_score=False
    )

    # Step 2: cross-encoder reranking; each query's list is score-sorted.
    reranked = reranker.rerank(request.queries, candidates)

    # Step 3: format the response, keeping topk_rerank docs per query and
    # converting passages back to the '"title"\nbody' corpus format.
    response = []
    for _, doc_scores in reranked.items():
        kept = doc_scores[:request.topk_rerank]
        if request.return_scores:
            response.append(
                [{"document": convert_title_format(doc), "score": score} for doc, score in kept]
            )
        else:
            response.append([convert_title_format(doc) for doc, _ in kept])

    return {"result": response}
81
+
82
+
83
if __name__ == "__main__":

    parser = argparse.ArgumentParser(description="Launch the local faiss retriever.")
    # retriever
    parser.add_argument("--index_path", type=str, default="/home/peterjin/mnt/index/wiki-18/e5_Flat.index", help="Corpus indexing file.")
    parser.add_argument("--corpus_path", type=str, default="/home/peterjin/mnt/data/retrieval-corpus/wiki-18.jsonl", help="Local corpus file.")
    parser.add_argument("--retrieval_topk", type=int, default=10, help="Number of retrieved passages for one query.")
    parser.add_argument("--retriever_name", type=str, default="e5", help="Name of the retriever model.")
    parser.add_argument("--retriever_model", type=str, default="intfloat/e5-base-v2", help="Path of the retriever model.")
    parser.add_argument('--faiss_gpu', action='store_true', help='Use GPU for computation')
    # reranker
    parser.add_argument("--reranking_topk", type=int, default=3, help="Number of reranked passages for one query.")
    parser.add_argument("--reranker_model", type=str, default="cross-encoder/ms-marco-MiniLM-L12-v2", help="Path of the reranker model.")
    parser.add_argument("--reranker_batch_size", type=int, default=32, help="Batch size for the reranker inference.")

    args = parser.parse_args()

    # ----------- Load Retriever and Reranker -----------
    # Both are module-level globals so the FastAPI endpoint above can reuse
    # the loaded models across requests.
    retriever_config = RetrieverConfig(
        retrieval_method = args.retriever_name,
        index_path=args.index_path,
        corpus_path=args.corpus_path,
        retrieval_topk=args.retrieval_topk,
        faiss_gpu=args.faiss_gpu,
        retrieval_model_path=args.retriever_model,
        retrieval_pooling_method="mean",
        retrieval_query_max_length=256,
        retrieval_use_fp16=True,
        retrieval_batch_size=512,
    )
    retriever = get_retriever(retriever_config)

    reranker_config = RerankerArguments(
        rerank_topk = args.reranking_topk,
        rerank_model_name_or_path = args.reranker_model,
        batch_size = args.reranker_batch_size,
    )
    reranker = get_reranker(reranker_config)

    # Serve on all interfaces, port 8000 (the port retrieval_request.py targets).
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)
code/RL_model/verl/Search-R1/search_r1/search/retrieval_server.py ADDED
@@ -0,0 +1,392 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import os
3
+ import warnings
4
+ from typing import List, Dict, Optional
5
+ import argparse
6
+
7
+ import faiss
8
+ import torch
9
+ import numpy as np
10
+ from transformers import AutoConfig, AutoTokenizer, AutoModel
11
+ from tqdm import tqdm
12
+ import datasets
13
+
14
+ import uvicorn
15
+ from fastapi import FastAPI
16
+ from pydantic import BaseModel
17
+
18
def load_corpus(corpus_path: str):
    """Load a JSONL corpus file as a HuggingFace dataset ("train" split)."""
    dataset = datasets.load_dataset(
        'json',
        data_files=corpus_path,
        split="train",
        num_proc=4,
    )
    return dataset
26
+
27
def read_jsonl(file_path):
    """Parse a JSONL file into a list of decoded objects, one per line."""
    with open(file_path, "r") as f:
        records = [json.loads(raw_line) for raw_line in f.readlines()]
    return records
33
+
34
def load_docs(corpus, doc_idxs):
    """Look up corpus rows by index (indices are coerced to int first)."""
    fetched = []
    for idx in doc_idxs:
        fetched.append(corpus[int(idx)])
    return fetched
37
+
38
def load_model(model_path: str, use_fp16: bool = False):
    """Load a HF encoder model and tokenizer onto the default CUDA device.

    Args:
        model_path: HF model id or local checkpoint directory.
        use_fp16: cast model weights to fp16 after loading.

    Returns:
        (model, tokenizer) with the model in eval mode on GPU.
    """
    # NOTE(review): the config is loaded but never used — kept as-is.
    model_config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
    model = AutoModel.from_pretrained(model_path, trust_remote_code=True)
    model.eval()
    # Assumes a CUDA device is available — TODO confirm for CPU-only hosts.
    model.cuda()
    if use_fp16:
        model = model.half()
    tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True, trust_remote_code=True)
    return model, tokenizer
47
+
48
def pooling(
    pooler_output,
    last_hidden_state,
    attention_mask = None,
    pooling_method = "mean"
):
    """Collapse per-token hidden states into one embedding per sequence.

    "mean" averages over non-padding tokens, "cls" takes the first token,
    "pooler" forwards the model's pooler output; anything else raises.
    """
    if pooling_method == "pooler":
        return pooler_output
    if pooling_method == "cls":
        return last_hidden_state[:, 0]
    if pooling_method == "mean":
        # Zero padded positions before summing, then divide by token counts.
        keep = attention_mask[..., None].bool()
        totals = last_hidden_state.masked_fill(~keep, 0.0).sum(dim=1)
        return totals / attention_mask.sum(dim=1)[..., None]
    raise NotImplementedError("Pooling method not implemented!")
63
+
64
class Encoder:
    """Wraps a HF bi-encoder; turns text into float32 embedding vectors."""

    def __init__(self, model_name, model_path, pooling_method, max_length, use_fp16):
        self.model_name = model_name          # retriever name; drives model-specific prefixes below
        self.model_path = model_path
        self.pooling_method = pooling_method  # "mean" | "cls" | "pooler"
        self.max_length = max_length          # tokenizer truncation length
        self.use_fp16 = use_fp16

        self.model, self.tokenizer = load_model(model_path=model_path, use_fp16=use_fp16)
        self.model.eval()

    @torch.no_grad()
    def encode(self, query_list: List[str], is_query=True) -> np.ndarray:
        """Embed a string or list of strings; returns an (n, dim) float32 array."""
        # processing query for different encoders
        if isinstance(query_list, str):
            query_list = [query_list]

        # E5 models expect "query: " / "passage: " prefixes depending on side.
        if "e5" in self.model_name.lower():
            if is_query:
                query_list = [f"query: {query}" for query in query_list]
            else:
                query_list = [f"passage: {query}" for query in query_list]

        # BGE models expect an instruction prefix on the query side only.
        if "bge" in self.model_name.lower():
            if is_query:
                query_list = [f"Represent this sentence for searching relevant passages: {query}" for query in query_list]

        inputs = self.tokenizer(query_list,
                                max_length=self.max_length,
                                padding=True,
                                truncation=True,
                                return_tensors="pt"
                                )
        inputs = {k: v.cuda() for k, v in inputs.items()}

        if "T5" in type(self.model).__name__:
            # T5-based retrieval model: run the decoder one step and take
            # the first decoder hidden state as the embedding.
            decoder_input_ids = torch.zeros(
                (inputs['input_ids'].shape[0], 1), dtype=torch.long
            ).to(inputs['input_ids'].device)
            output = self.model(
                **inputs, decoder_input_ids=decoder_input_ids, return_dict=True
            )
            query_emb = output.last_hidden_state[:, 0, :]
        else:
            output = self.model(**inputs, return_dict=True)
            query_emb = pooling(output.pooler_output,
                                output.last_hidden_state,
                                inputs['attention_mask'],
                                self.pooling_method)
            # DPR embeddings are compared unnormalized; all others are
            # L2-normalized — presumably to make inner product = cosine.
            if "dpr" not in self.model_name.lower():
                query_emb = torch.nn.functional.normalize(query_emb, dim=-1)

        # faiss expects C-contiguous float32 on the CPU.
        query_emb = query_emb.detach().cpu().numpy()
        query_emb = query_emb.astype(np.float32, order="C")

        # Free GPU activations between server requests; keeps long-running
        # memory flat at the cost of some allocator churn.
        del inputs, output
        torch.cuda.empty_cache()

        return query_emb
124
+
125
class BaseRetriever:
    """Abstract retriever; subclasses provide _search/_batch_search."""

    def __init__(self, config):
        self.config = config
        self.retrieval_method = config.retrieval_method
        self.topk = config.retrieval_topk

        self.index_path = config.index_path
        self.corpus_path = config.corpus_path

    def _search(self, query: str, num: int, return_score: bool):
        raise NotImplementedError

    def _batch_search(self, query_list: List[str], num: int, return_score: bool):
        raise NotImplementedError

    def search(self, query: str, num: int = None, return_score: bool = False):
        # Public entry point; delegates to the subclass hook.
        return self._search(query, num, return_score)

    def batch_search(self, query_list: List[str], num: int = None, return_score: bool = False):
        return self._batch_search(query_list, num, return_score)
145
+
146
class BM25Retriever(BaseRetriever):
    """BM25 retriever backed by a pre-built pyserini Lucene index."""

    def __init__(self, config):
        super().__init__(config)
        # Imported lazily so the module works without pyserini installed.
        from pyserini.search.lucene import LuceneSearcher
        self.searcher = LuceneSearcher(self.index_path)
        self.contain_doc = self._check_contain_doc()
        # When the index stores no raw text, fall back to the JSONL corpus.
        if not self.contain_doc:
            self.corpus = load_corpus(self.corpus_path)
        self.max_process_num = 8

    def _check_contain_doc(self):
        # True when the index stores raw document content.
        return self.searcher.doc(0).raw() is not None

    def _search(self, query: str, num: int = None, return_score: bool = False):
        if num is None:
            num = self.topk
        hits = self.searcher.search(query, num)
        if len(hits) < 1:
            if return_score:
                return [], []
            else:
                return []
        # Scores are collected before the (no-op) truncation below; pyserini
        # already returns at most `num` hits.
        scores = [hit.score for hit in hits]
        if len(hits) < num:
            warnings.warn('Not enough documents retrieved!')
        else:
            hits = hits[:num]

        if self.contain_doc:
            # Raw docs live in the index: split "title\nbody" out of contents.
            all_contents = [
                json.loads(self.searcher.doc(hit.docid).raw())['contents']
                for hit in hits
            ]
            results = [
                {
                    'title': content.split("\n")[0].strip("\""),
                    'text': "\n".join(content.split("\n")[1:]),
                    'contents': content
                }
                for content in all_contents
            ]
        else:
            results = load_docs(self.corpus, [hit.docid for hit in hits])

        if return_score:
            return results, scores
        else:
            return results

    def _batch_search(self, query_list: List[str], num: int = None, return_score: bool = False):
        # Sequential per-query search; no true batching in pyserini here.
        results = []
        scores = []
        for query in query_list:
            item_result, item_score = self._search(query, num, True)
            results.append(item_result)
            scores.append(item_score)
        if return_score:
            return results, scores
        else:
            return results
206
+
207
class DenseRetriever(BaseRetriever):
    """Dense retriever: a FAISS index plus a neural query encoder."""

    def __init__(self, config):
        super().__init__(config)
        self.index = faiss.read_index(self.index_path)
        if config.faiss_gpu:
            # Shard the index across all visible GPUs, storing vectors in fp16.
            cloner_opts = faiss.GpuMultipleClonerOptions()
            cloner_opts.useFloat16 = True
            cloner_opts.shard = True
            self.index = faiss.index_cpu_to_all_gpus(self.index, co=cloner_opts)

        self.corpus = load_corpus(self.corpus_path)
        self.encoder = Encoder(
            model_name=self.retrieval_method,
            model_path=config.retrieval_model_path,
            pooling_method=config.retrieval_pooling_method,
            max_length=config.retrieval_query_max_length,
            use_fp16=config.retrieval_use_fp16,
        )
        self.topk = config.retrieval_topk
        self.batch_size = config.retrieval_batch_size

    def _search(self, query: str, num: int = None, return_score: bool = False):
        """Encode one query and look up its top-`num` nearest documents."""
        num = self.topk if num is None else num
        emb = self.encoder.encode(query)
        scores, idxs = self.index.search(emb, k=num)
        docs = load_docs(self.corpus, idxs[0])
        if return_score:
            return docs, scores[0].tolist()
        return docs

    def _batch_search(self, query_list: List[str], num: int = None, return_score: bool = False):
        """Encode queries in mini-batches and search the index for each batch."""
        if isinstance(query_list, str):
            query_list = [query_list]
        num = self.topk if num is None else num

        all_docs = []
        all_scores = []
        for begin in tqdm(range(0, len(query_list), self.batch_size), desc='Retrieval process: '):
            batch = query_list[begin:begin + self.batch_size]
            emb = self.encoder.encode(batch)
            score_mat, idx_mat = self.index.search(emb, k=num)
            score_rows = score_mat.tolist()
            idx_rows = idx_mat.tolist()

            # load_docs takes a flat id list, so flatten then re-chunk per query.
            flat_docs = load_docs(self.corpus, sum(idx_rows, []))
            chunked = [flat_docs[row * num:(row + 1) * num] for row in range(len(idx_rows))]

            all_docs.extend(chunked)
            all_scores.extend(score_rows)

            # Drop intermediates between batches to keep GPU memory flat.
            del emb, score_mat, idx_mat, batch, flat_docs, chunked
            torch.cuda.empty_cache()

        if return_score:
            return all_docs, all_scores
        return all_docs
272
+
273
def get_retriever(config):
    """Factory: choose the retriever backend from ``config.retrieval_method``.

    "bm25" selects the sparse Lucene backend; any other value is treated
    as the name of a dense encoder model.
    """
    if config.retrieval_method == "bm25":
        return BM25Retriever(config)
    return DenseRetriever(config)
278
+
279
+
280
+ #####################################
281
+ # FastAPI server below
282
+ #####################################
283
+
284
class Config:
    """
    Minimal configuration container for the retrieval server
    (simulating argparse). Replace with real arguments or load dynamically.
    """

    def __init__(
        self,
        retrieval_method: str = "bm25",
        retrieval_topk: int = 10,
        index_path: str = "./index/bm25",
        corpus_path: str = "./data/corpus.jsonl",
        dataset_path: str = "./data",
        data_split: str = "train",
        faiss_gpu: bool = True,
        retrieval_model_path: str = "./model",
        retrieval_pooling_method: str = "mean",
        retrieval_query_max_length: int = 256,
        retrieval_use_fp16: bool = False,
        retrieval_batch_size: int = 128,
    ):
        # Copy every keyword argument onto the instance under the same name.
        params = dict(locals())
        params.pop("self")
        self.__dict__.update(params)
316
+
317
+
318
class QueryRequest(BaseModel):
    """Request body for the /retrieve endpoint."""
    # Free-text queries to retrieve documents for.
    queries: List[str]
    # Optional per-request top-k; server config's topk is used when omitted.
    topk: Optional[int] = None
    # When true, each document in the response is paired with its score.
    return_scores: bool = False
322
+
323
+
324
app = FastAPI()  # module-level ASGI app; the /retrieve route is registered below
325
+
326
@app.post("/retrieve")
def retrieve_endpoint(request: QueryRequest):
    """
    Endpoint that accepts queries and performs retrieval.

    Input format:
    {
        "queries": ["What is Python?", "Tell me about neural networks."],
        "topk": 3,
        "return_scores": true
    }
    """
    if not request.topk:
        request.topk = config.retrieval_topk  # fallback to default

    # BUG FIX: batch_search only returns a (results, scores) pair when
    # return_score=True; the original unconditionally unpacked two values,
    # which raises ValueError whenever return_scores is false. Always
    # request scores and simply drop them if the caller did not ask for them.
    results, scores = retriever.batch_search(
        query_list=request.queries,
        num=request.topk,
        return_score=True,
    )

    # Format response: one entry per input query.
    resp = []
    for single_result, single_scores in zip(results, scores):
        if request.return_scores:
            # Combine each document with its retrieval score.
            resp.append([
                {"document": doc, "score": score}
                for doc, score in zip(single_result, single_scores)
            ])
        else:
            resp.append(single_result)
    return {"result": resp}
359
+
360
+
361
if __name__ == "__main__":

    parser = argparse.ArgumentParser(description="Launch the local faiss retriever.")
    parser.add_argument("--index_path", type=str, default="/home/peterjin/mnt/index/wiki-18/e5_Flat.index", help="Corpus indexing file.")
    parser.add_argument("--corpus_path", type=str, default="/home/peterjin/mnt/data/retrieval-corpus/wiki-18.jsonl", help="Local corpus file.")
    parser.add_argument("--topk", type=int, default=3, help="Number of retrieved passages for one query.")
    parser.add_argument("--retriever_name", type=str, default="e5", help="Name of the retriever model.")
    parser.add_argument("--retriever_model", type=str, default="intfloat/e5-base-v2", help="Path of the retriever model.")
    parser.add_argument('--faiss_gpu', action='store_true', help='Use GPU for computation')

    args = parser.parse_args()

    # 1) Build a config (could also parse from arguments).
    # In real usage, you'd parse your CLI arguments or environment variables.
    # NOTE(review): retrieval_method doubles as the encoder model name; any
    # value other than "bm25" selects the dense retriever (see get_retriever).
    config = Config(
        retrieval_method = args.retriever_name, # or "dense"
        index_path=args.index_path,
        corpus_path=args.corpus_path,
        retrieval_topk=args.topk,
        faiss_gpu=args.faiss_gpu,
        retrieval_model_path=args.retriever_model,
        retrieval_pooling_method="mean",
        retrieval_query_max_length=256,
        retrieval_use_fp16=True,
        retrieval_batch_size=512,
    )

    # 2) Instantiate a global retriever so it is loaded once and reused.
    retriever = get_retriever(config)

    # 3) Launch the server. By default, it listens on http://127.0.0.1:8000
    # (bound to 0.0.0.0, so also reachable from other hosts on port 8000).
    uvicorn.run(app, host="0.0.0.0", port=8000)
code/RL_model/verl/Search-R1/search_r1/search/serp_search_server.py ADDED
@@ -0,0 +1,112 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import requests
3
+ from fastapi import FastAPI
4
+ from pydantic import BaseModel
5
+ from typing import List, Optional, Dict
6
+ from concurrent.futures import ThreadPoolExecutor
7
+ import argparse
8
+ import uvicorn
9
+
10
# Command-line flags are parsed at import time so the module-level
# config/engine singletons below can be built before the app starts.
# NOTE(review): this makes the module unimportable without these flags;
# consider moving parsing under __main__ if the module is ever imported.
parser = argparse.ArgumentParser(description="Launch online search server.")
parser.add_argument('--search_url', type=str, required=True,
                    help="URL for search engine (e.g. https://serpapi.com/search)")
parser.add_argument('--topk', type=int, default=3,
                    help="Number of results to return per query")
parser.add_argument('--serp_api_key', type=str, default=None,
                    help="SerpAPI key for online search")
parser.add_argument('--serp_engine', type=str, default="google",
                    help="SerpAPI engine for online search")
args = parser.parse_args()
20
+
21
+ # --- Config ---
22
# --- Config ---
class OnlineSearchConfig:
    """Settings for the online (SerpAPI-style) search proxy."""

    def __init__(
        self,
        search_url: str = "https://serpapi.com/search",
        topk: int = 3,
        serp_api_key: Optional[str] = None,
        serp_engine: Optional[str] = None,
    ):
        # Endpoint of the SerpAPI-compatible service.
        self.search_url, self.topk = search_url, topk
        # Credentials / engine selection forwarded on every request.
        self.serp_api_key, self.serp_engine = serp_api_key, serp_engine
34
+
35
+
36
+ # --- Online Search Wrapper ---
37
# --- Online Search Wrapper ---
class OnlineSearchEngine:
    """Thin client for a SerpAPI-style web search endpoint."""

    def __init__(self, config: "OnlineSearchConfig"):
        self.config = config

    def _search_query(self, query: str):
        """Issue one GET request for `query` and return the parsed JSON payload."""
        response = requests.get(
            self.config.search_url,
            params={
                "engine": self.config.serp_engine,
                "q": query,
                "api_key": self.config.serp_api_key,
            },
        )
        return response.json()

    def batch_search(self, queries: List[str]):
        """Run all queries concurrently, then post-process each raw response."""
        with ThreadPoolExecutor() as pool:
            raw_responses = list(pool.map(self._search_query, queries))
        return [self._process_result(raw) for raw in raw_responses]

    def _process_result(self, search_result: Dict):
        """Flatten answer box, organic results and related questions into
        a list of {'document': {'contents': '"title"\\nsnippet'}} entries."""
        docs = []

        def add(title, snippet):
            docs.append({'document': {"contents": f'\"{title}\"\n{snippet}'}})

        box = search_result.get('answer_box', {})
        if box:
            add(box.get('title', 'No title.'),
                box.get('snippet', 'No snippet available.'))

        for item in search_result.get('organic_results', [])[:self.config.topk]:
            add(item.get('title', 'No title.'),
                item.get('snippet', 'No snippet available.'))

        # For related questions, the 'question' field plays the role of the title.
        for item in search_result.get('related_questions', [])[:self.config.topk]:
            add(item.get('question', 'No title.'),
                item.get('snippet', 'No snippet available.'))

        return docs
85
+
86
+
87
# --- FastAPI Setup ---
app = FastAPI(title="Online Search Proxy Server")  # module-level ASGI app
89
+
90
class SearchRequest(BaseModel):
    """Request body for /retrieve: a batch of free-text search queries."""
    queries: List[str]
92
+
93
+ # Instantiate global config + engine
94
+ config = OnlineSearchConfig(
95
+ search_url=args.search_url,
96
+ topk=args.topk,
97
+ serp_api_key=args.serp_api_key,
98
+ serp_engine=args.serp_engine,
99
+ )
100
+ engine = OnlineSearchEngine(config)
101
+
102
# --- Routes ---
@app.post("/retrieve")
def search_endpoint(request: SearchRequest):
    """Run every query in the request through the online search engine.

    Returns {"result": [...]} with one inner list of document entries
    per input query (see OnlineSearchEngine._process_result for the shape).
    """
    results = engine.batch_search(request.queries)
    return {"result": results}

## return {"result": List[List[{'document': {"id": xx, "content": "title" + \n + "content"}, 'score': xx}]]}
109
+
110
if __name__ == "__main__":
    # 3) Launch the server. By default, it listens on http://127.0.0.1:8000
    # (binds 0.0.0.0:8000, so also reachable from other hosts).
    uvicorn.run(app, host="0.0.0.0", port=8000)
code/RL_model/verl/Search-R1/verl.egg-info/SOURCES.txt ADDED
@@ -0,0 +1,190 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ LICENSE
2
+ README.md
3
+ pyproject.toml
4
+ setup.py
5
+ ./search_r1/__init__.py
6
+ ./search_r1/llm_agent/__init__.py
7
+ ./search_r1/llm_agent/generation.py
8
+ ./search_r1/llm_agent/tensor_helper.py
9
+ ./verl/__init__.py
10
+ ./verl/protocol.py
11
+ ./verl/models/__init__.py
12
+ ./verl/models/registry.py
13
+ ./verl/models/weight_loader_registry.py
14
+ ./verl/models/llama/__init__.py
15
+ ./verl/models/llama/megatron/__init__.py
16
+ ./verl/models/llama/megatron/modeling_llama_megatron.py
17
+ ./verl/models/llama/megatron/checkpoint_utils/__init__.py
18
+ ./verl/models/llama/megatron/checkpoint_utils/llama_loader.py
19
+ ./verl/models/llama/megatron/checkpoint_utils/llama_saver.py
20
+ ./verl/models/llama/megatron/layers/__init__.py
21
+ ./verl/models/llama/megatron/layers/parallel_attention.py
22
+ ./verl/models/llama/megatron/layers/parallel_decoder.py
23
+ ./verl/models/llama/megatron/layers/parallel_linear.py
24
+ ./verl/models/llama/megatron/layers/parallel_mlp.py
25
+ ./verl/models/llama/megatron/layers/parallel_rmsnorm.py
26
+ ./verl/models/transformers/__init__.py
27
+ ./verl/models/transformers/llama.py
28
+ ./verl/models/transformers/monkey_patch.py
29
+ ./verl/models/transformers/qwen2.py
30
+ ./verl/single_controller/__init__.py
31
+ ./verl/single_controller/base/__init__.py
32
+ ./verl/single_controller/base/decorator.py
33
+ ./verl/single_controller/base/worker.py
34
+ ./verl/single_controller/base/worker_group.py
35
+ ./verl/single_controller/base/megatron/__init__.py
36
+ ./verl/single_controller/base/megatron/worker.py
37
+ ./verl/single_controller/base/megatron/worker_group.py
38
+ ./verl/single_controller/base/register_center/__init__.py
39
+ ./verl/single_controller/base/register_center/ray.py
40
+ ./verl/single_controller/ray/__init__.py
41
+ ./verl/single_controller/ray/base.py
42
+ ./verl/single_controller/ray/megatron.py
43
+ ./verl/third_party/__init__.py
44
+ ./verl/third_party/vllm/__init__.py
45
+ ./verl/third_party/vllm/vllm_v_0_3_1/__init__.py
46
+ ./verl/third_party/vllm/vllm_v_0_3_1/arg_utils.py
47
+ ./verl/third_party/vllm/vllm_v_0_3_1/config.py
48
+ ./verl/third_party/vllm/vllm_v_0_3_1/llm.py
49
+ ./verl/third_party/vllm/vllm_v_0_3_1/llm_engine_sp.py
50
+ ./verl/third_party/vllm/vllm_v_0_3_1/model_loader.py
51
+ ./verl/third_party/vllm/vllm_v_0_3_1/model_runner.py
52
+ ./verl/third_party/vllm/vllm_v_0_3_1/parallel_state.py
53
+ ./verl/third_party/vllm/vllm_v_0_3_1/tokenizer.py
54
+ ./verl/third_party/vllm/vllm_v_0_3_1/weight_loaders.py
55
+ ./verl/third_party/vllm/vllm_v_0_3_1/worker.py
56
+ ./verl/third_party/vllm/vllm_v_0_4_2/__init__.py
57
+ ./verl/third_party/vllm/vllm_v_0_4_2/arg_utils.py
58
+ ./verl/third_party/vllm/vllm_v_0_4_2/config.py
59
+ ./verl/third_party/vllm/vllm_v_0_4_2/dtensor_weight_loaders.py
60
+ ./verl/third_party/vllm/vllm_v_0_4_2/hf_weight_loader.py
61
+ ./verl/third_party/vllm/vllm_v_0_4_2/llm.py
62
+ ./verl/third_party/vllm/vllm_v_0_4_2/llm_engine_sp.py
63
+ ./verl/third_party/vllm/vllm_v_0_4_2/megatron_weight_loaders.py
64
+ ./verl/third_party/vllm/vllm_v_0_4_2/model_loader.py
65
+ ./verl/third_party/vllm/vllm_v_0_4_2/model_runner.py
66
+ ./verl/third_party/vllm/vllm_v_0_4_2/parallel_state.py
67
+ ./verl/third_party/vllm/vllm_v_0_4_2/spmd_gpu_executor.py
68
+ ./verl/third_party/vllm/vllm_v_0_4_2/tokenizer.py
69
+ ./verl/third_party/vllm/vllm_v_0_4_2/worker.py
70
+ ./verl/third_party/vllm/vllm_v_0_5_4/__init__.py
71
+ ./verl/third_party/vllm/vllm_v_0_5_4/arg_utils.py
72
+ ./verl/third_party/vllm/vllm_v_0_5_4/config.py
73
+ ./verl/third_party/vllm/vllm_v_0_5_4/dtensor_weight_loaders.py
74
+ ./verl/third_party/vllm/vllm_v_0_5_4/hf_weight_loader.py
75
+ ./verl/third_party/vllm/vllm_v_0_5_4/llm.py
76
+ ./verl/third_party/vllm/vllm_v_0_5_4/llm_engine_sp.py
77
+ ./verl/third_party/vllm/vllm_v_0_5_4/megatron_weight_loaders.py
78
+ ./verl/third_party/vllm/vllm_v_0_5_4/model_loader.py
79
+ ./verl/third_party/vllm/vllm_v_0_5_4/model_runner.py
80
+ ./verl/third_party/vllm/vllm_v_0_5_4/parallel_state.py
81
+ ./verl/third_party/vllm/vllm_v_0_5_4/spmd_gpu_executor.py
82
+ ./verl/third_party/vllm/vllm_v_0_5_4/tokenizer.py
83
+ ./verl/third_party/vllm/vllm_v_0_5_4/worker.py
84
+ ./verl/third_party/vllm/vllm_v_0_6_3/__init__.py
85
+ ./verl/third_party/vllm/vllm_v_0_6_3/arg_utils.py
86
+ ./verl/third_party/vllm/vllm_v_0_6_3/config.py
87
+ ./verl/third_party/vllm/vllm_v_0_6_3/dtensor_weight_loaders.py
88
+ ./verl/third_party/vllm/vllm_v_0_6_3/hf_weight_loader.py
89
+ ./verl/third_party/vllm/vllm_v_0_6_3/llm.py
90
+ ./verl/third_party/vllm/vllm_v_0_6_3/llm_engine_sp.py
91
+ ./verl/third_party/vllm/vllm_v_0_6_3/megatron_weight_loaders.py
92
+ ./verl/third_party/vllm/vllm_v_0_6_3/model_loader.py
93
+ ./verl/third_party/vllm/vllm_v_0_6_3/model_runner.py
94
+ ./verl/third_party/vllm/vllm_v_0_6_3/parallel_state.py
95
+ ./verl/third_party/vllm/vllm_v_0_6_3/spmd_gpu_executor.py
96
+ ./verl/third_party/vllm/vllm_v_0_6_3/tokenizer.py
97
+ ./verl/third_party/vllm/vllm_v_0_6_3/worker.py
98
+ ./verl/trainer/__init__.py
99
+ ./verl/trainer/fsdp_sft_trainer.py
100
+ ./verl/trainer/main_eval.py
101
+ ./verl/trainer/main_generation.py
102
+ ./verl/trainer/main_ppo.py
103
+ ./verl/trainer/main_ppo_format.py
104
+ ./verl/trainer/config/evaluation.yaml
105
+ ./verl/trainer/config/generation.yaml
106
+ ./verl/trainer/config/ppo_megatron_trainer.yaml
107
+ ./verl/trainer/config/ppo_trainer.yaml
108
+ ./verl/trainer/config/sft_trainer.yaml
109
+ ./verl/trainer/ppo/__init__.py
110
+ ./verl/trainer/ppo/core_algos.py
111
+ ./verl/trainer/ppo/ray_trainer.py
112
+ ./verl/utils/__init__.py
113
+ ./verl/utils/config.py
114
+ ./verl/utils/distributed.py
115
+ ./verl/utils/flops_counter.py
116
+ ./verl/utils/fs.py
117
+ ./verl/utils/fsdp_utils.py
118
+ ./verl/utils/hdfs_io.py
119
+ ./verl/utils/import_utils.py
120
+ ./verl/utils/logging_utils.py
121
+ ./verl/utils/megatron_utils.py
122
+ ./verl/utils/memory_buffer.py
123
+ ./verl/utils/model.py
124
+ ./verl/utils/py_functional.py
125
+ ./verl/utils/ray_utils.py
126
+ ./verl/utils/seqlen_balancing.py
127
+ ./verl/utils/tokenizer.py
128
+ ./verl/utils/torch_dtypes.py
129
+ ./verl/utils/torch_functional.py
130
+ ./verl/utils/tracking.py
131
+ ./verl/utils/ulysses.py
132
+ ./verl/utils/dataset/__init__.py
133
+ ./verl/utils/dataset/rl_dataset.py
134
+ ./verl/utils/dataset/rm_dataset.py
135
+ ./verl/utils/debug/__init__.py
136
+ ./verl/utils/debug/performance.py
137
+ ./verl/utils/debug/trajectory_tracker.py
138
+ ./verl/utils/logger/__init__.py
139
+ ./verl/utils/logger/aggregate_logger.py
140
+ ./verl/utils/megatron/__init__.py
141
+ ./verl/utils/megatron/memory.py
142
+ ./verl/utils/megatron/optimizer.py
143
+ ./verl/utils/megatron/optimizer_config.py
144
+ ./verl/utils/megatron/pipeline_parallel.py
145
+ ./verl/utils/megatron/sequence_parallel.py
146
+ ./verl/utils/megatron/tensor_parallel.py
147
+ ./verl/utils/rendezvous/__init__.py
148
+ ./verl/utils/rendezvous/ray_backend.py
149
+ ./verl/utils/reward_score/__init__.py
150
+ ./verl/utils/reward_score/countdown.py
151
+ ./verl/utils/reward_score/gsm8k.py
152
+ ./verl/utils/reward_score/math.py
153
+ ./verl/utils/reward_score/multiply.py
154
+ ./verl/utils/reward_score/qa_em.py
155
+ ./verl/utils/reward_score/qa_em_format.py
156
+ ./verl/version/version
157
+ ./verl/workers/__init__.py
158
+ ./verl/workers/fsdp_workers.py
159
+ ./verl/workers/megatron_workers.py
160
+ ./verl/workers/actor/__init__.py
161
+ ./verl/workers/actor/base.py
162
+ ./verl/workers/actor/dp_actor.py
163
+ ./verl/workers/actor/megatron_actor.py
164
+ ./verl/workers/critic/__init__.py
165
+ ./verl/workers/critic/base.py
166
+ ./verl/workers/critic/dp_critic.py
167
+ ./verl/workers/critic/megatron_critic.py
168
+ ./verl/workers/reward_model/__init__.py
169
+ ./verl/workers/reward_model/base.py
170
+ ./verl/workers/reward_model/megatron/__init__.py
171
+ ./verl/workers/reward_model/megatron/reward_model.py
172
+ ./verl/workers/rollout/__init__.py
173
+ ./verl/workers/rollout/base.py
174
+ ./verl/workers/rollout/hf_rollout.py
175
+ ./verl/workers/rollout/tokenizer.py
176
+ ./verl/workers/rollout/naive/__init__.py
177
+ ./verl/workers/rollout/naive/naive_rollout.py
178
+ ./verl/workers/rollout/vllm_rollout/__init__.py
179
+ ./verl/workers/rollout/vllm_rollout/vllm_rollout.py
180
+ ./verl/workers/sharding_manager/__init__.py
181
+ ./verl/workers/sharding_manager/base.py
182
+ ./verl/workers/sharding_manager/fsdp_ulysses.py
183
+ ./verl/workers/sharding_manager/fsdp_vllm.py
184
+ ./verl/workers/sharding_manager/megatron_vllm.py
185
+ verl.egg-info/PKG-INFO
186
+ verl.egg-info/SOURCES.txt
187
+ verl.egg-info/dependency_links.txt
188
+ verl.egg-info/requires.txt
189
+ verl.egg-info/top_level.txt
190
+ verl/version/version
code/RL_model/verl/Search-R1/verl/single_controller/__init__.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 Bytedance Ltd. and/or its affiliates
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
import os

# Directory of this package; the released version string lives in
# a plain-text `version/version` file next to it.
# NOTE(review): confirm `version/version` actually exists under
# verl/single_controller/ — this header looks copied from verl/__init__.py.
# The original wrapped os.path.abspath(__file__) in a one-argument
# os.path.join, which is a no-op; dirname alone is sufficient.
version_folder = os.path.dirname(os.path.abspath(__file__))

with open(os.path.join(version_folder, 'version/version')) as f:
    __version__ = f.read().strip()
code/RL_model/verl/Search-R1/verl/trainer/__init__.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 Bytedance Ltd. and/or its affiliates
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
code/RL_model/verl/Search-R1/verl/trainer/main_eval.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 Bytedance Ltd. and/or its affiliates
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """
15
+ Offline evaluate the performance of a generated file using reward model and ground truth verifier.
16
+ The input is a parquet file that contains N generated sequences and (optional) the ground truth.
17
+
18
+ """
19
+
20
+ import hydra
21
+ from verl.utils.fs import copy_local_path_from_hdfs
22
+ from verl.utils.reward_score import math, gsm8k
23
+ import pandas as pd
24
+ import numpy as np
25
+
26
+
27
def select_reward_fn(data_source):
    """Map a dataset identifier to its scoring function.

    Only 'lighteval/MATH' is supported; any other source raises
    NotImplementedError.
    """
    if data_source != 'lighteval/MATH':
        raise NotImplementedError
    return math.compute_score
32
+
33
+
34
@hydra.main(config_path='config', config_name='evaluation', version_base=None)
def main(config):
    """Offline pass@k evaluation of pre-generated responses.

    Reads a parquet file with N generated responses per prompt, scores each
    response against the ground truth, and reports the fraction of prompts
    where at least one response scores exactly 1.
    """
    local_path = copy_local_path_from_hdfs(config.data.path)
    dataset = pd.read_parquet(local_path)
    prompts = dataset[config.data.prompt_key]
    responses = dataset[config.data.response_key]
    data_sources = dataset[config.data.data_source_key]
    reward_model_data = dataset[config.data.reward_model_key]

    passes = 0
    total = len(dataset)
    num_responses = None  # inferred from the data; used for the pass@k label

    for i in range(total):
        response_lst = responses[i]
        data_source = data_sources[i]
        # select reward score based on data_source
        prompt = prompts[i]
        reward_data = reward_model_data[i]
        reward_fn = select_reward_fn(data_source)
        ground_truth = reward_data['ground_truth']
        if num_responses is None:
            num_responses = len(response_lst)
        score_lst = [reward_fn(r, ground_truth) for r in response_lst]

        # A prompt "passes" when any of its sampled responses is fully correct.
        if np.max(score_lst) == 1:
            passes += 1

    # BUG FIX: the original hard-coded 'pass@5' in the printed label even
    # though each prompt may carry a different number of responses.
    print(f'pass@{num_responses}: {passes / total}')
66
+
67
+
68
+ if __name__ == '__main__':
69
+ main()
code/RL_model/verl/Search-R1/verl/utils/__init__.py ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 Bytedance Ltd. and/or its affiliates
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from . import tokenizer
16
+ from .tokenizer import *
17
+
18
+ __all__ = tokenizer.__all__
code/RL_model/verl/Search-R1/verl/utils/config.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 Bytedance Ltd. and/or its affiliates
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import Dict
16
+
17
+ from omegaconf import DictConfig
18
+
19
+
20
def update_dict_with_config(dictionary: Dict, config: "DictConfig"):
    """Overwrite entries of `dictionary` with same-named attributes of `config`.

    Keys absent from `config` are left untouched, and no new keys are added.
    """
    overrides = {key: getattr(config, key) for key in dictionary if hasattr(config, key)}
    dictionary.update(overrides)