yuccaaa commited on
Commit
acbfbc3
·
verified ·
1 Parent(s): ffcfc75

Add files using upload-large-folder tool

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. BioReason-0813/=1.18.1 +10 -0
  2. BioReason-0813/__pycache__/blip2_dna_module.cpython-310.pyc +0 -0
  3. BioReason-0813/__pycache__/blip2_grpo_trainer.cpython-310.pyc +0 -0
  4. BioReason-0813/bioreason/dna_modules/__pycache__/dna_module.cpython-310.pyc +0 -0
  5. BioReason-0813/bioreason/dna_modules/dna_module.py +49 -0
  6. BioReason-0813/bioreason/trainer/grpo_config.py +365 -0
  7. BioReason-0813/blip2_dna_module.py +349 -0
  8. BioReason-0813/blip2_grpo_trainer.py +591 -0
  9. BioReason-0813/blips_reason.py +866 -0
  10. BioReason-0813/model/__pycache__/blip2.cpython-310.pyc +0 -0
  11. BioReason-0813/model/__pycache__/blip2_opt.cpython-310.pyc +0 -0
  12. BioReason-0813/model/__pycache__/blip2_opt.cpython-311.pyc +0 -0
  13. BioReason-0813/model/__pycache__/blip2_stage2.cpython-310.pyc +0 -0
  14. BioReason-0813/model/__pycache__/blip2_stage2.cpython-311.pyc +0 -0
  15. BioReason-0813/model/__pycache__/help_funcs.cpython-310.pyc +0 -0
  16. BioReason-0813/model/blip2.py +126 -0
  17. BioReason-0813/model/blip2_opt.py +550 -0
  18. BioReason-0813/model/blip2_stage2.py +365 -0
  19. BioReason-0813/model/help_funcs.py +112 -0
  20. BioReason-0813/prompt_templates.py +57 -0
  21. BioReason-0813/run.sh +103 -0
  22. BioReason-main/.gitignore +180 -0
  23. BioReason-main/LICENSE +201 -0
  24. BioReason-main/README.md +148 -0
  25. BioReason-main/bioreason.egg-info/PKG-INFO +181 -0
  26. BioReason-main/bioreason.egg-info/SOURCES.txt +9 -0
  27. BioReason-main/bioreason.egg-info/dependency_links.txt +1 -0
  28. BioReason-main/bioreason.egg-info/requires.txt +19 -0
  29. BioReason-main/bioreason.egg-info/top_level.txt +1 -0
  30. BioReason-main/bioreason/__init__.py +0 -0
  31. BioReason-main/bioreason/dataset/__init__.py +11 -0
  32. BioReason-main/bioreason/dataset/kegg.py +382 -0
  33. BioReason-main/bioreason/dataset/utils.py +59 -0
  34. BioReason-main/bioreason/dataset/variant_effect.py +98 -0
  35. BioReason-main/bioreason/dna_modules/__init__.py +4 -0
  36. BioReason-main/bioreason/dna_modules/dna_module.py +49 -0
  37. BioReason-main/bioreason/dna_modules/nucleotide_module.py +263 -0
  38. BioReason-main/bioreason/models/__init__.py +9 -0
  39. BioReason-main/bioreason/models/dl/__init__.py +1 -0
  40. BioReason-main/bioreason/models/dl/chat_template_dl.py +1 -0
  41. BioReason-main/bioreason/models/dl/configuration_dl.py +232 -0
  42. BioReason-main/bioreason/models/dl/processing_dl.py +275 -0
  43. BioReason-main/bioreason/models/dna_llm.py +306 -0
  44. BioReason-main/bioreason/models/dna_only.py +203 -0
  45. BioReason-main/bioreason/models/evo2_tokenizer.py +219 -0
  46. BioReason-main/bioreason/trainer/__init__.py +7 -0
  47. BioReason-main/bioreason/trainer/demo_grpo.py +811 -0
  48. BioReason-main/bioreason/trainer/grpo_config.py +365 -0
  49. BioReason-main/bioreason/trainer/grpo_trainer.py +905 -0
  50. BioReason-main/bioreason/utils/__init__.py +0 -0
BioReason-0813/=1.18.1 ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ Looking in indexes: https://mirrors.aliyun.com/pypi/simple/
2
+ Requirement already satisfied: modelscope in /root/miniconda3/envs/bioreason/lib/python3.11/site-packages (1.29.1)
3
+ Requirement already satisfied: filelock in /root/miniconda3/envs/bioreason/lib/python3.11/site-packages (from modelscope) (3.18.0)
4
+ Requirement already satisfied: requests>=2.25 in /root/miniconda3/envs/bioreason/lib/python3.11/site-packages (from modelscope) (2.32.4)
5
+ Requirement already satisfied: setuptools in /root/miniconda3/envs/bioreason/lib/python3.11/site-packages (from modelscope) (78.1.1)
6
+ Requirement already satisfied: tqdm>=4.64.0 in /root/miniconda3/envs/bioreason/lib/python3.11/site-packages (from modelscope) (4.67.1)
7
+ Requirement already satisfied: urllib3>=1.26 in /root/miniconda3/envs/bioreason/lib/python3.11/site-packages (from modelscope) (2.5.0)
8
+ Requirement already satisfied: charset_normalizer<4,>=2 in /root/miniconda3/envs/bioreason/lib/python3.11/site-packages (from requests>=2.25->modelscope) (3.4.3)
9
+ Requirement already satisfied: idna<4,>=2.5 in /root/miniconda3/envs/bioreason/lib/python3.11/site-packages (from requests>=2.25->modelscope) (3.10)
10
+ Requirement already satisfied: certifi>=2017.4.17 in /root/miniconda3/envs/bioreason/lib/python3.11/site-packages (from requests>=2.25->modelscope) (2025.8.3)
BioReason-0813/__pycache__/blip2_dna_module.cpython-310.pyc ADDED
Binary file (11.3 kB). View file
 
BioReason-0813/__pycache__/blip2_grpo_trainer.cpython-310.pyc ADDED
Binary file (15.6 kB). View file
 
BioReason-0813/bioreason/dna_modules/__pycache__/dna_module.cpython-310.pyc ADDED
Binary file (2.53 kB). View file
 
BioReason-0813/bioreason/dna_modules/dna_module.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from abc import ABC, abstractmethod
2
+ from typing import Dict, Any, Union
3
+ import torch
4
+
5
+ class DNABaseModule(ABC):
6
+ def __init__(self):
7
+ super().__init__()
8
+
9
+ @abstractmethod
10
+ def get_dnallm_key(self):
11
+ pass
12
+
13
+ @abstractmethod
14
+ def get_model_class(self, model_id: str, model_init_kwargs: dict):
15
+ pass
16
+
17
+ def post_model_init(self, model, processing_class):
18
+ pass
19
+
20
+ def is_embeds_input(self):
21
+ return False
22
+
23
+ @abstractmethod
24
+ def get_processing_class(self):
25
+ pass
26
+
27
+ @abstractmethod
28
+ def get_dnallm_modules_keywords(self):
29
+ pass
30
+
31
+ @abstractmethod
32
+ def get_custom_multimodal_keywords(self):
33
+ pass
34
+
35
+ @abstractmethod
36
+ def get_non_generate_params(self):
37
+ pass
38
+
39
+ @abstractmethod
40
+ def get_custom_processing_keywords(self):
41
+ pass
42
+
43
+ @abstractmethod
44
+ def prepare_prompt(self, processing_class, inputs: dict[str, Union[torch.Tensor, Any]]):
45
+ pass
46
+
47
+ @abstractmethod
48
+ def prepare_model_inputs(self, processing_class, prompts_text, images, return_tensors, padding, padding_side, add_special_tokens):
49
+ pass
BioReason-0813/bioreason/trainer/grpo_config.py ADDED
@@ -0,0 +1,365 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from dataclasses import dataclass, field
16
+ from typing import Optional, Union
17
+
18
+ from transformers import TrainingArguments
19
+
20
+
21
+ @dataclass
22
+ class DNALLMGRPOConfig(TrainingArguments):
23
+ r"""
24
+ Configuration class for the [`GRPOTrainer`].
25
+
26
+ Only the parameters specific to GRPO training are listed here. For details on other parameters, refer to the
27
+ [`~transformers.TrainingArguments`] documentation.
28
+
29
+ Using [`~transformers.HfArgumentParser`] we can turn this class into
30
+ [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
31
+ command line.
32
+
33
+ Parameters:
34
+ > Parameters that control the model and reference model
35
+
36
+ model_init_kwargs (`dict[str, Any]` or `None`, *optional*, defaults to `None`):
37
+ Keyword arguments for [`~transformers.AutoModelForCausalLM.from_pretrained`], used when the `model`
38
+ argument of the [`GRPOTrainer`] is provided as a string.
39
+
40
+ > Parameters that control the data preprocessing
41
+
42
+ remove_unused_columns (`bool`, *optional*, defaults to `False`):
43
+ Whether to only keep the column `"prompt"` in the dataset. If you use a custom reward function that
44
+ requires any column other than `"prompts"` and `"completions"`, you should keep this to `False`.
45
+ max_prompt_length (`int` or `None`, *optional*, defaults to `512`):
46
+ Maximum length of the prompt. If the prompt is longer than this value, it will be truncated left.
47
+ num_generations (`int` or `None`, *optional*, defaults to `8`):
48
+ Number of generations per prompt to sample. The global batch size (num_processes * per_device_batch_size)
49
+ must be divisible by this value.
50
+ max_completion_length (`int` or `None`, *optional*, defaults to `256`):
51
+ Maximum length of the generated completion.
52
+ ds3_gather_for_generation (`bool`, *optional*, defaults to `True`):
53
+ This setting applies to DeepSpeed ZeRO-3. If enabled, the policy model weights are gathered for generation,
54
+ improving generation speed. However, disabling this option allows training models that exceed the VRAM
55
+ capacity of a single GPU, albeit at the cost of slower generation. Disabling this option is not compatible
56
+ with vLLM generation.
57
+
58
+ > Parameters that control generation
59
+
60
+ temperature (`float`, defaults to `0.9`):
61
+ Temperature for sampling. The higher the temperature, the more random the completions.
62
+ top_p (`float`, *optional*, defaults to `1.0`):
63
+ Float that controls the cumulative probability of the top tokens to consider. Must be in (0, 1]. Set to
64
+ `1.0` to consider all tokens.
65
+ top_k (`int` or `None`, *optional*, defaults to `50`):
66
+ Number of highest probability vocabulary tokens to keep for top-k-filtering. If `None`, top-k-filtering is
67
+ disabled.
68
+ min_p (`float` or `None`, *optional*, defaults to `None`):
69
+ Minimum token probability, which will be scaled by the probability of the most likely token. It must be a
70
+ value between `0.0` and `1.0`. Typical values are in the `0.01-0.2` range.
71
+ repetition_penalty (`float`, *optional*, defaults to `1.0`):
72
+ Float that penalizes new tokens based on whether they appear in the prompt and the generated text so far.
73
+ Values > `1.0` encourage the model to use new tokens, while values < `1.0` encourage the model to repeat
74
+ tokens.
75
+ cache_implementation (`str` or `None`, *optional*, defaults to `None`):
76
+ Implementation of the cache method for faster generation when use_vllm is set to False.
77
+
78
+ > Parameters that control generation acceleration powered by vLLM
79
+
80
+ use_vllm (`bool`, *optional*, defaults to `False`):
81
+ Whether to use vLLM for generating completions. If set to `True`, ensure that a GPU is kept unused for
82
+ training, as vLLM will require one for generation. vLLM must be installed (`pip install vllm`).
83
+ vllm_device (`str`, *optional*, defaults to `"auto"`):
84
+ Device where vLLM generation will run, e.g. `"cuda:1"`. If set to `"auto"` (default), the system will
85
+ automatically select the next available GPU after the last one used for training. This assumes that
86
+ training has not already occupied all available GPUs. If only one device is available, the device will be
87
+ shared between both training and vLLM.
88
+ vllm_gpu_memory_utilization (`float`, *optional*, defaults to `0.9`):
89
+ Ratio (between 0 and 1) of GPU memory to reserve for the model weights, activations, and KV cache on the
90
+ device dedicated to generation powered by vLLM. Higher values will increase the KV cache size and thus
91
+ improve the model's throughput. However, if the value is too high, it may cause out-of-memory (OOM) errors
92
+ during initialization.
93
+ vllm_dtype (`str`, *optional*, defaults to `"auto"`):
94
+ Data type to use for vLLM generation. If set to `"auto"`, the data type will be automatically determined
95
+ based on the model configuration. Find the supported values in the vLLM documentation.
96
+ vllm_max_model_len (`int` or `None`, *optional*, defaults to `None`):
97
+ If set, the `max_model_len` to use for vLLM. This could be useful when running with reduced
98
+ `vllm_gpu_memory_utilization`, leading to a reduced KV cache size. If not set, vLLM will use the model
99
+ context size, which might be much larger than the KV cache, leading to inefficiencies.
100
+ vllm_enable_prefix_caching (`bool`, *optional*, defaults to `True`):
101
+ Whether to enable prefix caching in vLLM. If set to `True` (default), ensure that the model and the hardware
102
+ support this feature.
103
+ vllm_guided_decoding_regex (`str` or `None`, *optional*, defaults to `None`):
104
+ Regex for vLLM guided decoding. If `None` (default), guided decoding is disabled.
105
+
106
+ > Parameters that control the training
107
+
108
+ learning_rate (`float`, *optional*, defaults to `1e-6`):
109
+ Initial learning rate for [`AdamW`] optimizer. The default value replaces that of
110
+ [`~transformers.TrainingArguments`].
111
+ beta (`float`, *optional*, defaults to `0.04`):
112
+ KL coefficient. If `0.0`, the reference model is not loaded, reducing memory usage and improving training
113
+ speed, but may be numerically unstable for long training runs.
114
+ num_iterations (`int`, *optional*, defaults to `1`):
115
+ Number of iterations per batch (denoted as μ in the algorithm).
116
+ epsilon (`float`, *optional*, defaults to `0.2`):
117
+ Epsilon value for clipping.
118
+ epsilon_high (`float` or `None`, *optional*, defaults to `None`):
119
+ Upper-bound epsilon value for clipping. If not specified, it defaults to the same value as the lower-bound
120
+ specified in argument `epsilon`. Paper [DAPO](https://huggingface.co/papers/2503.14476) recommends `0.28`.
121
+ reward_weights (`list[float]` or `None`, *optional*, defaults to `None`):
122
+ Weights for each reward function. Must match the number of reward functions. If `None`, all rewards are
123
+ weighted equally with weight `1.0`.
124
+ sync_ref_model (`bool`, *optional*, defaults to `False`):
125
+ Whether to synchronize the reference model with the active model every `ref_model_sync_steps` steps, using
126
+ the `ref_model_mixup_alpha` parameter. This synchronization originites from the
127
+ [TR-DPO](https://huggingface.co/papers/2404.09656) paper.
128
+ ref_model_mixup_alpha (`float`, *optional*, defaults to `0.6`):
129
+ α parameter from the [TR-DPO](https://huggingface.co/papers/2404.09656) paper, which controls the mix
130
+ between the current policy and the previous reference policy during updates. The reference policy is
131
+ updated according to the equation: `π_ref = α * π_θ + (1 - α) * π_ref_prev`. To use this parameter, you
132
+ must set `sync_ref_model=True`.
133
+ ref_model_sync_steps (`int`, *optional*, defaults to `512`):
134
+ τ parameter from the [TR-DPO](https://huggingface.co/papers/2404.09656) paper, which determines how
135
+ frequently the current policy is synchronized with the reference policy. To use this parameter, you must
136
+ set `sync_ref_model=True`.
137
+
138
+ > Parameters that control the logging
139
+
140
+ log_completions (`bool`, *optional*, defaults to `False`):
141
+ Whether to log a sample of (prompt, completion) pairs every `logging_steps` steps. If `rich` is
142
+ installed, it prints the sample. If `wandb` logging is enabled, it logs it to `wandb`.
143
+ """
144
+
145
+ # Parameters that control the model and reference model
146
+ model_init_kwargs: Optional[dict] = field(
147
+ default=None,
148
+ metadata={
149
+ "help": "Keyword arguments for `transformers.AutoModelForCausalLM.from_pretrained`, used when the `model` "
150
+ "argument of the `GRPOTrainer` is provided as a string."
151
+ },
152
+ )
153
+
154
+ # Parameters that control the data preprocessing
155
+ # The default value remove_unused_columns is overwritten from the parent class, because in GRPO we usually rely on
156
+ # additional columns to compute the reward
157
+ remove_unused_columns: Optional[bool] = field(
158
+ default=False,
159
+ metadata={
160
+ "help": "Whether to only keep the column 'prompt' in the dataset. If you use a custom reward function "
161
+ "that requires any column other than 'prompts' and 'completions', you should keep this to `False`."
162
+ },
163
+ )
164
+ max_prompt_length: Optional[int] = field(
165
+ default=512,
166
+ metadata={
167
+ "help": "Maximum length of the prompt. If the prompt is longer than this value, it will be truncated left."
168
+ },
169
+ )
170
+ num_generations: Optional[int] = field(
171
+ default=8,
172
+ metadata={
173
+ "help": "Number of generations to sample. The global batch size (num_processes * per_device_batch_size) "
174
+ "must be divisible by this value."
175
+ },
176
+ )
177
+ max_completion_length: Optional[int] = field(
178
+ default=800,
179
+ metadata={"help": "Maximum length of the generated completion."},
180
+ )
181
+ ds3_gather_for_generation: bool = field(
182
+ default=True,
183
+ metadata={
184
+ "help": "This setting applies to DeepSpeed ZeRO-3. If enabled, the policy model weights are gathered for "
185
+ "generation, improving generation speed. However, disabling this option allows training models that "
186
+ "exceed the VRAM capacity of a single GPU, albeit at the cost of slower generation. Disabling this option "
187
+ "is not compatible with vLLM generation."
188
+ },
189
+ )
190
+
191
+ # Parameters that control generation
192
+ temperature: float = field(
193
+ default=0.6,
194
+ metadata={"help": "Temperature for sampling. The higher the temperature, the more random the completions."},
195
+ )
196
+ top_p: float = field(
197
+ default=0.95,
198
+ metadata={
199
+ "help": "Float that controls the cumulative probability of the top tokens to consider. Must be in (0, 1]. "
200
+ "Set to 1.0 to consider all tokens."
201
+ },
202
+ )
203
+ top_k: Optional[int] = field(
204
+ default=20,
205
+ metadata={
206
+ "help": "Number of highest probability vocabulary tokens to keep for top-k-filtering. If `None`, "
207
+ "top-k-filtering is disabled."
208
+ },
209
+ )
210
+ min_p: Optional[float] = field(
211
+ default=None,
212
+ metadata={
213
+ "help": "Minimum token probability, which will be scaled by the probability of the most likely token. It "
214
+ "must be a value between 0.0 and 1.0. Typical values are in the 0.01-0.2 range."
215
+ },
216
+ )
217
+ repetition_penalty: float = field(
218
+ default=1.0,
219
+ metadata={
220
+ "help": "Float that penalizes new tokens based on whether they appear in the prompt and the generated "
221
+ "text so far. Values > 1.0 encourage the model to use new tokens, while values < 1.0 encourage the model "
222
+ "to repeat tokens."
223
+ },
224
+ )
225
+ cache_implementation: Optional[str] = field(
226
+ default=None,
227
+ metadata={"help": "Implementation of the cache method for faster generation when use_vllm is set to False."},
228
+ )
229
+
230
+ # Parameters that control generation acceleration powered by vLLM
231
+ use_vllm: Optional[bool] = field(
232
+ default=False,
233
+ metadata={
234
+ "help": "Whether to use vLLM for generating completions. If set to `True`, ensure that a GPU is kept "
235
+ "unused for training, as vLLM will require one for generation. vLLM must be installed "
236
+ "(`pip install vllm`)."
237
+ },
238
+ )
239
+ vllm_device: Optional[str] = field(
240
+ default="auto",
241
+ metadata={
242
+ "help": "Device where vLLM generation will run, e.g. 'cuda:1'. If set to 'auto' (default), the system "
243
+ "will automatically select the next available GPU after the last one used for training. This assumes "
244
+ "that training has not already occupied all available GPUs."
245
+ },
246
+ )
247
+ vllm_gpu_memory_utilization: float = field(
248
+ default=0.9,
249
+ metadata={
250
+ "help": "Ratio (between 0 and 1) of GPU memory to reserve for the model weights, activations, and KV "
251
+ "cache on the device dedicated to generation powered by vLLM. Higher values will increase the KV cache "
252
+ "size and thus improve the model's throughput. However, if the value is too high, it may cause "
253
+ "out-of-memory (OOM) errors during initialization."
254
+ },
255
+ )
256
+ vllm_dtype: Optional[str] = field(
257
+ default="auto",
258
+ metadata={
259
+ "help": "Data type to use for vLLM generation. If set to 'auto', the data type will be automatically "
260
+ "determined based on the model configuration. Find the supported values in the vLLM documentation."
261
+ },
262
+ )
263
+ vllm_max_model_len: Optional[int] = field(
264
+ default=None,
265
+ metadata={
266
+ "help": "If set, the `max_model_len` to use for vLLM. This could be useful when running with reduced "
267
+ "`vllm_gpu_memory_utilization`, leading to a reduced KV cache size. If not set, vLLM will use the model "
268
+ "context size, which might be much larger than the KV cache, leading to inefficiencies."
269
+ },
270
+ )
271
+ vllm_enable_prefix_caching: Optional[bool] = field(
272
+ default=True,
273
+ metadata={
274
+ "help": "Whether to enable prefix caching in vLLM. If set to `True` (default), ensure that the model and "
275
+ "the hardware support this feature."
276
+ },
277
+ )
278
+ vllm_guided_decoding_regex: Optional[str] = field(
279
+ default=None,
280
+ metadata={"help": "Regex for vLLM guided decoding. If `None` (default), guided decoding is disabled."},
281
+ )
282
+
283
+ # Parameters that control the training
284
+ learning_rate: float = field(
285
+ default=1e-6,
286
+ metadata={
287
+ "help": "Initial learning rate for `AdamW` optimizer. The default value replaces that of "
288
+ "`transformers.TrainingArguments`."
289
+ },
290
+ )
291
+ beta: float = field(
292
+ default=0.04,
293
+ metadata={
294
+ "help": "KL coefficient. If `0.0`, the reference model is not loaded, reducing memory usage and improving "
295
+ "training speed, but may be numerically unstable for long training runs."
296
+ },
297
+ )
298
+ num_iterations: int = field(
299
+ default=1,
300
+ metadata={"help": "Number of iterations per batch (denoted as μ in the algorithm)."},
301
+ )
302
+ epsilon: float = field(
303
+ default=0.2,
304
+ metadata={"help": "Epsilon value for clipping."},
305
+ )
306
+ epsilon_high: Optional[float] = field(
307
+ default=None,
308
+ metadata={
309
+ "help": "Upper-bound epsilon value for clipping. If not specified, it defaults to the same value as the "
310
+ "lower-bound specified in argument `epsilon`. Paper DAPO recommends `0.28`."
311
+ },
312
+ )
313
+ reward_weights: Optional[list[float]] = field(
314
+ default=None,
315
+ metadata={
316
+ "help": "Weights for each reward function. Must match the number of reward functions. If `None`, all "
317
+ "rewards are weighted equally with weight `1.0`."
318
+ },
319
+ )
320
+ sync_ref_model: bool = field(
321
+ default=False,
322
+ metadata={
323
+ "help": "Whether to synchronize the reference model with the active model every `ref_model_sync_steps` "
324
+ "steps, using the `ref_model_mixup_alpha` parameter."
325
+ },
326
+ )
327
+ ref_model_mixup_alpha: float = field(
328
+ default=0.6,
329
+ metadata={
330
+ "help": "α parameter from the TR-DPO paper, which controls the mix between the current policy and the "
331
+ "previous reference policy during updates. The reference policy is updated according to the equation: "
332
+ "`π_ref = α * π_θ + (1 - α) * π_ref_prev`. To use this parameter, you must set `sync_ref_model=True`."
333
+ },
334
+ )
335
+ ref_model_sync_steps: int = field(
336
+ default=512,
337
+ metadata={
338
+ "help": "τ parameter from the TR-DPO paper, which determines how frequently the current policy is "
339
+ "synchronized with the reference policy. To use this parameter, you must set `sync_ref_model=True`."
340
+ },
341
+ )
342
+
343
+ # Parameters that control the logging
344
+ log_completions: bool = field(
345
+ default=True,
346
+ metadata={
347
+ "help": "Whether to log a sample of (prompt, completion) pairs every `logging_steps` steps. If `rich` is "
348
+ "installed, it prints the sample. If `wandb` logging is enabled, it logs it to `wandb`."
349
+ },
350
+ )
351
+
352
+ report_to: Union[None, str, list[str]] = field(
353
+ default="wandb", metadata={"help": "The list of integrations to report the results and logs to."}
354
+ )
355
+
356
+ logging_first_step: bool = field(default=False, metadata={"help": "Log the first global_step"})
357
+ logging_steps: float = field(
358
+ default=2,
359
+ metadata={
360
+ "help": (
361
+ "Log every X updates steps. Should be an integer or a float in range `[0,1)`. "
362
+ "If smaller than 1, will be interpreted as ratio of total training steps."
363
+ )
364
+ },
365
+ )
BioReason-0813/blip2_dna_module.py ADDED
@@ -0,0 +1,349 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import (
2
+ AutoProcessor,
3
+ AutoTokenizer,
4
+ )
5
+ from typing import Dict, Any, Union, List, Optional, Callable, Type
6
+ from trl.data_utils import maybe_apply_chat_template
7
+ import torch
8
+
9
+ from bioreason.dna_modules.dna_module import DNABaseModule
10
+ from model.blip2_stage2 import Blip2Stage2
11
+
12
+
13
+ class Blip2DNAModule(DNABaseModule):
14
+ """
15
+ DNA module implementation for BLIP2-based models.
16
+
17
+ This module provides the interface between BLIP2 models and the GRPO training
18
+ infrastructure, handling model loading, processing setup, and reward functions.
19
+ """
20
+
21
+ def __init__(self):
22
+ """Initialize the Blip2DNAModule."""
23
+ super().__init__()
24
+
25
+ def get_dnallm_key(self) -> str:
26
+ """
27
+ Get the key identifier for this DNA-LLM implementation.
28
+
29
+ Returns:
30
+ String identifier for this module type
31
+ """
32
+ return "blip2"
33
+
34
+ def get_model_class(self, model_id: str, model_init_kwargs: Dict[str, Any]) -> Type:
35
+ """
36
+ Return the appropriate model class based on model ID.
37
+
38
+ Args:
39
+ model_id: Identifier for the model
40
+ model_init_kwargs: Initialization arguments for the model
41
+
42
+ Returns:
43
+ The model class to instantiate
44
+
45
+ Raises:
46
+ ValueError: If the model is not supported
47
+ """
48
+ if "blip2" in model_id.lower() or "stage2" in model_id.lower():
49
+ model_cls = Blip2Stage2
50
+ else:
51
+ raise ValueError(f"Unsupported model: {model_id}")
52
+ return model_cls
53
+
54
+ def post_model_init(self, model: Any, processing_class: Any) -> None:
55
+ """
56
+ Perform any post-initialization setup on the model.
57
+
58
+ Args:
59
+ model: The initialized model
60
+ processing_class: The processor for the model
61
+ """
62
+ # BLIP2 models might need specific post-init setup
63
+ if hasattr(model, 'blip2') and hasattr(model.blip2, 'llm_tokenizer'):
64
+ # Ensure the tokenizer is properly configured
65
+ if not hasattr(model.blip2.llm_tokenizer, 'pad_token') or model.blip2.llm_tokenizer.pad_token is None:
66
+ model.blip2.llm_tokenizer.pad_token = model.blip2.llm_tokenizer.eos_token
67
+
68
+ def get_processing_class(self) -> Type:
69
+ """
70
+ Get the processing class to use with this BLIP2 model.
71
+
72
+ Returns:
73
+ The processing class
74
+ """
75
+ return Blip2Processor
76
+
77
+ def get_dnallm_modules_keywords(self) -> List[str]:
78
+ """
79
+ Get keywords to identify DNA-specific modules in the model.
80
+
81
+ Used to exclude DNA modules from LoRA adaptation during training.
82
+
83
+ Returns:
84
+ List of keywords that identify DNA modules
85
+ """
86
+ return ["plm", "qformer", "opt_proj"]
87
+
88
+ def get_custom_multimodal_keywords(self) -> List[str]:
89
+ """
90
+ Get keywords for multimodal inputs that should be passed to the model.
91
+
92
+ Returns:
93
+ List of input keywords for multimodal processing
94
+ """
95
+ return ["prot_batch", "prompt_batch"]
96
+
97
+ def get_non_generate_params(self) -> List[str]:
98
+ """
99
+ Get parameter names that should be excluded from generation.
100
+
101
+ Returns:
102
+ List of parameter names to exclude from generation calls
103
+ """
104
+ return ["prot_batch"]
105
+
106
+ def get_custom_processing_keywords(self) -> List[tuple]:
107
+ """
108
+ Get custom processing keywords for the processor.
109
+
110
+ Returns:
111
+ List of (component, parameter) tuples for custom processing
112
+ """
113
+ return [("plm_tokenizer", "max_length"), ("llm_tokenizer", "max_length")]
114
+
115
+ def prepare_prompt(
116
+ self, processing_class: Any, inputs: List[Dict[str, Union[torch.Tensor, Any]]]
117
+ ) -> List[str]:
118
+ """
119
+ Prepare prompts from input examples.
120
+
121
+ Args:
122
+ processing_class: The processor to use
123
+ inputs: List of input examples
124
+
125
+ Returns:
126
+ List of prepared prompts
127
+ """
128
+ prompts_text = []
129
+ for example in inputs:
130
+ if "prompt" in example:
131
+ # Extract text content from conversational format
132
+ if isinstance(example["prompt"], list) and len(example["prompt"]) > 0:
133
+ user_content = example["prompt"][0].get("content", "")
134
+ if isinstance(user_content, list):
135
+ # Extract text from multimodal content
136
+ text_parts = [item.get("text", "") for item in user_content if item.get("type") == "text"]
137
+ prompt_text = " ".join(text_parts)
138
+ else:
139
+ prompt_text = str(user_content)
140
+ else:
141
+ prompt_text = str(example["prompt"])
142
+ else:
143
+ prompt_text = ""
144
+ prompts_text.append(prompt_text)
145
+ return prompts_text
146
+
147
+ def prepare_model_inputs(
148
+ self,
149
+ processing_class: Any,
150
+ model: Any,
151
+ prompts_text: List[str],
152
+ batch_dna_sequences: List[List[str]],
153
+ return_tensors: str = "pt",
154
+ padding: bool = True,
155
+ padding_side: str = "left",
156
+ add_special_tokens: bool = False,
157
+ ) -> Dict[str, Any]:
158
+ """
159
+ Prepare inputs for the BLIP2 model.
160
+
161
+ Args:
162
+ processing_class: The processor to use
163
+ model: The model to prepare inputs for
164
+ prompts_text: List of text prompts
165
+ batch_dna_sequences: List of lists of DNA sequences (treated as protein sequences)
166
+ return_tensors: Return format for tensors
167
+ padding: Whether to pad inputs
168
+ padding_side: Side to pad on
169
+ add_special_tokens: Whether to add special tokens
170
+
171
+ Returns:
172
+ Processed inputs for the model
173
+ """
174
+ # Get the BLIP2 model from the wrapper
175
+ blip2_model = model.blip2 if hasattr(model, 'blip2') else model
176
+
177
+ # Prepare protein batch (using DNA sequences as protein sequences)
178
+ # Flatten all DNA sequences to treat them as individual protein sequences
179
+ all_sequences = []
180
+ for sequences in batch_dna_sequences:
181
+ all_sequences.extend(sequences)
182
+
183
+ if all_sequences:
184
+ prot_batch = blip2_model.plm_tokenizer(
185
+ all_sequences,
186
+ padding=padding,
187
+ truncation=True,
188
+ max_length=512, # Default protein sequence length
189
+ return_tensors=return_tensors,
190
+ )
191
+ else:
192
+ # Empty batch handling
193
+ prot_batch = {
194
+ 'input_ids': torch.empty(0, 1, dtype=torch.long),
195
+ 'attention_mask': torch.empty(0, 1, dtype=torch.long)
196
+ }
197
+
198
+ # Prepare prompt batch
199
+ prompt_batch = blip2_model.llm_tokenizer(
200
+ prompts_text,
201
+ padding=padding,
202
+ truncation=True,
203
+ max_length=256, # Default prompt length
204
+ return_tensors=return_tensors,
205
+ )
206
+
207
+ return {
208
+ "prot_batch": prot_batch,
209
+ "prompt_batch": prompt_batch,
210
+ "input_ids": prompt_batch["input_ids"], # For compatibility
211
+ "attention_mask": prompt_batch["attention_mask"], # For compatibility
212
+ }
213
+
214
+ def is_embeds_input(self) -> bool:
215
+ """
216
+ Whether the model uses embeddings as input (instead of token IDs).
217
+
218
+ Returns:
219
+ Boolean indicating if the model takes embedding inputs
220
+ """
221
+ return True # BLIP2 uses embeddings internally
222
+
223
+ @staticmethod
224
+ def get_question_template() -> str:
225
+ """
226
+ Get the template for formatting questions.
227
+
228
+ Returns:
229
+ String template for questions
230
+ """
231
+ return "{Question}"
232
+
233
+ @staticmethod
234
+ def format_reward_rec(completions: List[Dict[str, Any]], **kwargs) -> List[float]:
235
+ """
236
+ Check if the BLIP2 model output matches a specific format.
237
+
238
+ Args:
239
+ completions: List of model completions
240
+ **kwargs: Additional arguments
241
+
242
+ Returns:
243
+ List of reward scores (1.0 for match, 0.0 for no match)
244
+ """
245
+ import re
246
+ import os
247
+ from datetime import datetime
248
+
249
+ # Pattern to match the expected output format
250
+ pattern = r"<think>.*?</think>\s*<answer>.*?\{.*\[\d+,\s*\d+,\s*\d+,\s*\d+\].*\}.*?</answer>"
251
+ completion_contents = [completion[0]["content"] for completion in completions]
252
+ matches = [
253
+ re.search(pattern, content, re.DOTALL) is not None
254
+ for content in completion_contents
255
+ ]
256
+
257
+ # Log format results if in debug mode
258
+ current_time = datetime.now().strftime("%d-%H-%M-%S-%f")
259
+ if os.getenv("DEBUG_MODE") == "true":
260
+ log_path = os.getenv("LOG_PATH")
261
+ with open(
262
+ log_path.replace(".txt", "_format.txt"), "a", encoding="utf-8"
263
+ ) as f:
264
+ f.write(f"------------- {current_time} Format reward -------------\n")
265
+ for content, match in zip(completion_contents, matches):
266
+ f.write(f"Content: {content}\n")
267
+ f.write(f"Has format: {bool(match)}\n")
268
+
269
+ return [1.0 if match else 0.0 for match in matches]
270
+
271
+ @staticmethod
272
+ def select_reward_func(func: str, task_type: str) -> Callable:
273
+ """
274
+ Select the appropriate reward function based on function name and task type.
275
+
276
+ Args:
277
+ func: The type of reward function ('accuracy', 'format', etc.)
278
+ task_type: The type of task ('rec', etc.)
279
+
280
+ Returns:
281
+ The reward function to use
282
+
283
+ Raises:
284
+ ValueError: If the function or task type is not supported
285
+ """
286
+ if func == "accuracy":
287
+ match task_type:
288
+ case "rec":
289
+ return Blip2DNAModule.iou_reward
290
+ case _:
291
+ raise ValueError(f"Unsupported reward function: {func}")
292
+ elif func == "format":
293
+ match task_type:
294
+ case "rec":
295
+ return Blip2DNAModule.format_reward_rec
296
+ case _:
297
+ raise ValueError(f"Unsupported reward function: {func}")
298
+ else:
299
+ raise ValueError(f"Unsupported reward function: {func}")
300
+
301
+ @staticmethod
302
+ def iou_reward(completions: List[Dict[str, Any]], **kwargs) -> List[float]:
303
+ """
304
+ Placeholder IoU reward function.
305
+
306
+ Args:
307
+ completions: List of model completions
308
+ **kwargs: Additional arguments
309
+
310
+ Returns:
311
+ List of reward scores
312
+ """
313
+ # Placeholder implementation
314
+ return [1.0] * len(completions)
315
+
316
+
317
+ class Blip2Processor:
318
+ """
319
+ Simple processor wrapper for BLIP2 models to maintain compatibility
320
+ with the GRPO trainer interface.
321
+ """
322
+
323
+ def __init__(self, plm_tokenizer=None, llm_tokenizer=None):
324
+ self.plm_tokenizer = plm_tokenizer
325
+ self.llm_tokenizer = llm_tokenizer
326
+
327
+ # Set compatibility attributes
328
+ if llm_tokenizer:
329
+ self.eos_token_id = llm_tokenizer.eos_token_id
330
+ self.pad_token_id = llm_tokenizer.pad_token_id
331
+
332
+ def __call__(self, *args, **kwargs):
333
+ """
334
+ Process inputs for BLIP2 model.
335
+ This is a simplified version that delegates to the appropriate tokenizer.
336
+ """
337
+ # For compatibility, return a simple tokenization result
338
+ if self.llm_tokenizer:
339
+ return self.llm_tokenizer(*args, **kwargs)
340
+ else:
341
+ # Fallback behavior
342
+ return {"input_ids": torch.tensor([[1]]), "attention_mask": torch.tensor([[1]])}
343
+
344
+ def batch_decode(self, *args, **kwargs):
345
+ """Decode token sequences."""
346
+ if self.llm_tokenizer:
347
+ return self.llm_tokenizer.batch_decode(*args, **kwargs)
348
+ else:
349
+ return [""]
BioReason-0813/blip2_grpo_trainer.py ADDED
@@ -0,0 +1,591 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import os
16
+ import time
17
+ import textwrap
18
+ import pandas as pd
19
+ from collections import defaultdict
20
+ from typing import Any, Callable, Optional, Union, Sized
21
+
22
+ import torch
23
+ import torch.utils.data
24
+ import transformers
25
+ from datasets import Dataset, IterableDataset
26
+ from packaging import version
27
+ from transformers import (
28
+ AutoModelForCausalLM,
29
+ AutoModelForSequenceClassification,
30
+ AutoProcessor,
31
+ AutoTokenizer,
32
+ GenerationConfig,
33
+ PreTrainedModel,
34
+ PreTrainedTokenizerBase,
35
+ Trainer,
36
+ TrainerCallback,
37
+ is_wandb_available,
38
+ )
39
+ from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
40
+ from transformers.utils import is_peft_available
41
+
42
+ from trl.data_utils import apply_chat_template, is_conversational, maybe_apply_chat_template
43
+ from trl.models import create_reference_model, prepare_deepspeed, unwrap_model_for_generation
44
+ from trl.trainer.grpo_config import GRPOConfig
45
+ from trl.trainer.utils import generate_model_card, get_comet_experiment_url
46
+
47
+ from accelerate.utils import is_peft_model, set_seed, gather_object
48
+ import PIL.Image
49
+
50
+ import copy
51
+ from torch.utils.data import Sampler
52
+ import warnings
53
+
54
+ if is_peft_available():
55
+ from peft import PeftConfig, get_peft_model, prepare_model_for_kbit_training
56
+
57
+ if is_wandb_available():
58
+ import wandb
59
+
60
+ from bioreason.dna_modules.dna_module import DNABaseModule
61
+ from bioreason.trainer import DNALLMGRPOConfig
62
+
63
+ # Import the RepeatRandomSampler from the original trainer
64
+ from bioreason.trainer.grpo_trainer import RepeatRandomSampler
65
+
66
+ # What we call a reward function is a callable that takes a list of prompts and completions and returns a list of
67
+ # rewards. When it's a string, it's a model ID, so it's loaded as a pretrained model.
68
+ RewardFunc = Union[str, PreTrainedModel, Callable[[list, list], list[float]]]
69
+
70
+
71
+ class Blip2GRPOTrainer(Trainer):
72
+ """
73
+ Modified GRPO Trainer for BLIP2 models.
74
+
75
+ This trainer adapts the original GRPO trainer to work with BLIP2 architecture,
76
+ handling the different input formats and forward pass requirements.
77
+ """
78
+
79
+ def __init__(
80
+ self,
81
+ model: Union[str, PreTrainedModel],
82
+ reward_funcs: Union[RewardFunc, list[RewardFunc]],
83
+ args: DNALLMGRPOConfig = None,
84
+ dna_module: DNABaseModule = None,
85
+ train_dataset: Optional[Union[Dataset, IterableDataset]] = None,
86
+ eval_dataset: Optional[Union[Dataset, IterableDataset, dict[str, Union[Dataset, IterableDataset]]]] = None,
87
+ processing_class: Optional[PreTrainedTokenizerBase] = None,
88
+ reward_processing_classes: Optional[Union[PreTrainedTokenizerBase, list[PreTrainedTokenizerBase]]] = None,
89
+ callbacks: Optional[list[TrainerCallback]] = None,
90
+ optimizers: tuple[Optional[torch.optim.Optimizer], Optional[torch.optim.lr_scheduler.LambdaLR]] = (None, None),
91
+ peft_config: Optional["PeftConfig"] = None,
92
+ freeze_dna_modules: Optional[bool] = False,
93
+ attn_implementation: str = "flash_attention_2",
94
+ torch_dtype: str = "bfloat16",
95
+ **kwargs,
96
+ ):
97
+ # Args
98
+ if args is None:
99
+ model_name = model if isinstance(model, str) else "blip2-model"
100
+ args = GRPOConfig(f"{model_name}-GRPO")
101
+
102
+ self.dna_module = dna_module
103
+
104
+ # Models
105
+ model_init_kwargs = args.model_init_kwargs or {}
106
+ model_init_kwargs["attn_implementation"] = attn_implementation
107
+ if model_init_kwargs.get("torch_dtype") is None:
108
+ model_init_kwargs["torch_dtype"] = torch_dtype
109
+
110
+ assert not isinstance(model, str), "model must NOT be a string in the current implementation"
111
+
112
+ torch_dtype = model_init_kwargs.get("torch_dtype")
113
+ if isinstance(torch_dtype, torch.dtype) or torch_dtype == "auto" or torch_dtype is None:
114
+ pass # torch_dtype is already a torch.dtype or "auto" or None
115
+ elif isinstance(torch_dtype, str): # it's a str, but not "auto"
116
+ torch_dtype = getattr(torch, torch_dtype)
117
+ else:
118
+ raise ValueError(
119
+ "Invalid `torch_dtype` passed to `GRPOConfig`. Expected either 'auto' or a string representing "
120
+ f"a `torch.dtype` (e.g., 'float32'), but got {torch_dtype}."
121
+ )
122
+
123
+ # Disable caching if gradient checkpointing is enabled (not supported)
124
+ if hasattr(model, 'blip2') and hasattr(model.blip2, 'llm_model'):
125
+ model.blip2.llm_model.config.use_cache = (
126
+ False if args.gradient_checkpointing else model.blip2.llm_model.config.use_cache
127
+ )
128
+
129
+ # LoRA setup for BLIP2
130
+ self.dna_modules_keywords = self.dna_module.get_dnallm_modules_keywords()
131
+ if peft_config is not None:
132
+ print("Applying LoRA...")
133
+ def find_all_linear_names(model, multimodal_keywords):
134
+ cls = torch.nn.Linear
135
+ lora_module_names = set()
136
+
137
+ # Focus on the LLM part of BLIP2
138
+ if hasattr(model, 'blip2') and hasattr(model.blip2, 'llm_model'):
139
+ llm_model = model.blip2.llm_model
140
+ for name, module in llm_model.named_modules():
141
+ # Skip DNA/multimodal modules
142
+ if any(mm_keyword in name for mm_keyword in multimodal_keywords):
143
+ continue
144
+ if isinstance(module, cls):
145
+ lora_module_names.add(name)
146
+
147
+ # Remove embedding layers
148
+ for m in list(lora_module_names):
149
+ if "embed_tokens" in m or "embedding" in m:
150
+ lora_module_names.remove(m)
151
+
152
+ return list(lora_module_names)
153
+
154
+ target_modules = find_all_linear_names(model, self.dna_modules_keywords)
155
+ peft_config.target_modules = target_modules
156
+
157
+ # Apply LoRA to the LLM part
158
+ if hasattr(model, 'blip2') and hasattr(model.blip2, 'llm_model'):
159
+ model.blip2.llm_model = prepare_model_for_kbit_training(model.blip2.llm_model)
160
+ model.blip2.llm_model = get_peft_model(model.blip2.llm_model, peft_config)
161
+
162
+ # Freeze DNA/protein modules if requested
163
+ if freeze_dna_modules:
164
+ print("Freezing protein/DNA modules...")
165
+ if hasattr(model, 'blip2'):
166
+ # Freeze protein language model
167
+ if hasattr(model.blip2, 'plm'):
168
+ for p in model.blip2.plm.parameters():
169
+ p.requires_grad = False
170
+
171
+ # Freeze Q-former if specified
172
+ if hasattr(model.blip2, 'Qformer'):
173
+ for p in model.blip2.Qformer.parameters():
174
+ p.requires_grad = False
175
+
176
+ # Count trainable parameters
177
+ trainable_params = [p for p in model.parameters() if p.requires_grad]
178
+ total_params = sum(p.numel() for p in trainable_params)
179
+ print(f"Total trainable parameters: {total_params}")
180
+
181
+ # Enable gradient checkpointing if requested
182
+ if args.gradient_checkpointing:
183
+ model = self._enable_gradient_checkpointing(model, args)
184
+
185
+ # Reference model
186
+ self.beta = args.beta
187
+ if self.beta == 0.0:
188
+ self.ref_model = None
189
+ elif is_deepspeed_zero3_enabled():
190
+ # Create reference model for DeepSpeed
191
+ self.ref_model = type(model)(model.args) # Create same type of model
192
+ elif is_peft_model(model.blip2.llm_model if hasattr(model, 'blip2') else model):
193
+ self.ref_model = None
194
+ else:
195
+ self.ref_model = create_reference_model(model)
196
+
197
+ # Processing class setup
198
+ if processing_class is None:
199
+ processing_cls = self.dna_module.get_processing_class()
200
+
201
+ # Get tokenizers from BLIP2 model
202
+ if hasattr(model, 'blip2'):
203
+ plm_tokenizer = getattr(model.blip2, 'plm_tokenizer', None)
204
+ llm_tokenizer = getattr(model.blip2, 'llm_tokenizer', None)
205
+ processing_class = processing_cls(plm_tokenizer=plm_tokenizer, llm_tokenizer=llm_tokenizer)
206
+ else:
207
+ processing_class = processing_cls()
208
+
209
+ # Set up tokenizer attributes
210
+ if hasattr(processing_class, 'llm_tokenizer') and processing_class.llm_tokenizer:
211
+ processing_class.pad_token_id = processing_class.llm_tokenizer.pad_token_id
212
+ processing_class.eos_token_id = processing_class.llm_tokenizer.eos_token_id
213
+ else:
214
+ # Fallback
215
+ processing_class.pad_token_id = 0
216
+ processing_class.eos_token_id = 1
217
+
218
+ self.dna_module.post_model_init(model, processing_class)
219
+ self.dna_module.post_model_init(self.ref_model, processing_class)
220
+
221
+ # Reward functions
222
+ if not isinstance(reward_funcs, list):
223
+ reward_funcs = [reward_funcs]
224
+ for i, reward_func in enumerate(reward_funcs):
225
+ if isinstance(reward_func, str):
226
+ reward_funcs[i] = AutoModelForSequenceClassification.from_pretrained(
227
+ reward_func, num_labels=1, **model_init_kwargs
228
+ )
229
+ self.reward_funcs = reward_funcs
230
+
231
+ # Reward processing classes
232
+ if reward_processing_classes is None:
233
+ reward_processing_classes = [None] * len(reward_funcs)
234
+ elif not isinstance(reward_processing_classes, list):
235
+ reward_processing_classes = [reward_processing_classes]
236
+
237
+ for i, (reward_processing_class, reward_func) in enumerate(zip(reward_processing_classes, reward_funcs)):
238
+ if isinstance(reward_func, PreTrainedModel):
239
+ if reward_processing_class is None:
240
+ reward_processing_class = AutoTokenizer.from_pretrained(reward_func.config._name_or_path)
241
+ if reward_processing_class.pad_token_id is None:
242
+ reward_processing_class.pad_token = reward_processing_class.eos_token
243
+ reward_func.config.pad_token_id = reward_processing_class.pad_token_id
244
+ reward_processing_classes[i] = reward_processing_class
245
+ self.reward_processing_classes = reward_processing_classes
246
+
247
+ # Data collator
248
+ def data_collator(features):
249
+ return features
250
+
251
+ # Training arguments
252
+ self.max_prompt_length = args.max_prompt_length
253
+ self.max_prompt_length = None
254
+ if args.max_prompt_length is not None:
255
+ warnings.warn("Setting max_prompt_length is currently not supported, it has been set to None")
256
+
257
+ self.max_completion_length = args.max_completion_length
258
+ self.num_generations = args.num_generations
259
+
260
+ # Generation config for BLIP2
261
+ self.generation_config = GenerationConfig(
262
+ max_new_tokens=self.max_completion_length,
263
+ do_sample=True,
264
+ temperature=0.6,
265
+ top_p=0.95,
266
+ top_k=20,
267
+ pad_token_id=processing_class.pad_token_id,
268
+ eos_token_id=processing_class.eos_token_id,
269
+ )
270
+
271
+ self.beta = args.beta
272
+ self.epsilon_low = args.epsilon
273
+ self.epsilon_high = args.epsilon_high if args.epsilon_high is not None else args.epsilon
274
+
275
+ # Multi-step
276
+ self.num_iterations = args.num_iterations
277
+ self._step = 0
278
+ self._buffered_inputs = [None] * args.gradient_accumulation_steps
279
+
280
+ # Initialize metrics
281
+ self._metrics = defaultdict(list)
282
+ self.log_completions = args.log_completions
283
+
284
+ super().__init__(
285
+ model=model,
286
+ args=args,
287
+ data_collator=data_collator,
288
+ train_dataset=train_dataset,
289
+ eval_dataset=eval_dataset,
290
+ processing_class=processing_class,
291
+ callbacks=callbacks,
292
+ optimizers=optimizers,
293
+ )
294
+
295
+ # Validate batch sizes
296
+ num_processes = self.accelerator.num_processes
297
+ global_batch_size = args.per_device_train_batch_size * num_processes
298
+ possible_values = [n_gen for n_gen in range(2, global_batch_size + 1) if (global_batch_size) % n_gen == 0]
299
+ if self.num_generations not in possible_values:
300
+ raise ValueError(
301
+ f"The global train batch size ({num_processes} x {args.per_device_train_batch_size}) must be evenly "
302
+ f"divisible by the number of generations per prompt ({self.num_generations}). Given the current train "
303
+ f"batch size, the valid values for the number of generations are: {possible_values}."
304
+ )
305
+
306
+ # Set unique seed per process
307
+ set_seed(args.seed, device_specific=True)
308
+
309
+ # Gradient accumulation settings
310
+ self.model_accepts_loss_kwargs = False
311
+
312
+ # Prepare reference model and reward functions
313
+ if self.ref_model is not None:
314
+ if is_deepspeed_zero3_enabled():
315
+ self.ref_model = prepare_deepspeed(self.ref_model, self.accelerator)
316
+ else:
317
+ self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
318
+
319
+ for i, reward_func in enumerate(self.reward_funcs):
320
+ if isinstance(reward_func, PreTrainedModel):
321
+ self.reward_funcs[i] = self.accelerator.prepare_model(reward_func, evaluation_mode=True)
322
+
323
+ def _enable_gradient_checkpointing(self, model: PreTrainedModel, args: DNALLMGRPOConfig) -> PreTrainedModel:
324
+ """Enables gradient checkpointing for BLIP2 model."""
325
+ if hasattr(model, 'blip2'):
326
+ # Enable for the LLM component
327
+ if hasattr(model.blip2, 'llm_model'):
328
+ model.blip2.llm_model.config.use_cache = False
329
+ if hasattr(model.blip2.llm_model, 'gradient_checkpointing_enable'):
330
+ model.blip2.llm_model.gradient_checkpointing_enable()
331
+
332
+ # Enable for protein model if needed
333
+ if hasattr(model.blip2, 'plm') and hasattr(model.blip2.plm, 'gradient_checkpointing_enable'):
334
+ model.blip2.plm.gradient_checkpointing_enable()
335
+
336
+ return model
337
+
338
+ def _set_signature_columns_if_needed(self):
339
+ if self._signature_columns is None:
340
+ self._signature_columns = ["prompt"]
341
+
342
+ def _get_key_from_inputs(self, x, key):
343
+ ele = x.get(key, None)
344
+ assert ele is not None, f"The key {key} is not found in the input"
345
+ if isinstance(ele, list):
346
+ return [e for e in ele]
347
+ else:
348
+ return [ele]
349
+
350
+ def _generate_and_score_completions(self, inputs: dict[str, Union[torch.Tensor, Any]], model) -> dict[str, Union[torch.Tensor, Any]]:
351
+ device = self.accelerator.device
352
+ prompts = [x["prompt"] for x in inputs]
353
+ prompts_text = self.dna_module.prepare_prompt(self.processing_class, inputs)
354
+
355
+ # Handle DNA sequences (treat as protein sequences for BLIP2)
356
+ batch_dna_sequences = []
357
+ print("_generate_and_score_completions (BLIP2 GRPO):")
358
+ for x in inputs:
359
+ if 'dna_sequences' in x:
360
+ dnas = self._get_key_from_inputs(x, "dna_sequences")
361
+ batch_dna_sequences.append(dnas)
362
+ else:
363
+ batch_dna_sequences.append([])
364
+
365
+ # Prepare model inputs for BLIP2
366
+ prompt_inputs = self.dna_module.prepare_model_inputs(
367
+ self.processing_class,
368
+ model,
369
+ prompts_text,
370
+ batch_dna_sequences,
371
+ return_tensors="pt",
372
+ padding=True,
373
+ padding_side="left",
374
+ add_special_tokens=False,
375
+ )
376
+
377
+ prompt_inputs = super()._prepare_inputs(prompt_inputs)
378
+
379
+ # Extract BLIP2-specific inputs
380
+ prot_batch = prompt_inputs.get("prot_batch")
381
+ prompt_batch = prompt_inputs.get("prompt_batch")
382
+
383
+ # Generate completions using BLIP2
384
+ start = time.time()
385
+ with unwrap_model_for_generation(model, self.accelerator) as unwrapped_model:
386
+ # Prepare samples for BLIP2 generation
387
+ samples = {
388
+ 'prot_batch': prot_batch,
389
+ 'prompt_batch': prompt_batch
390
+ }
391
+
392
+ # Use BLIP2's generate method
393
+ if hasattr(unwrapped_model, 'blip2'):
394
+ completions_text = unwrapped_model.blip2.generate(
395
+ samples,
396
+ do_sample=True,
397
+ temperature=0.6,
398
+ top_p=0.95,
399
+ num_beams=1,
400
+ max_length=self.max_completion_length,
401
+ min_length=1,
402
+ )
403
+ else:
404
+ # Fallback if not BLIP2 structure
405
+ completions_text = ["Generated text"] * len(prompts_text)
406
+
407
+ end = time.time()
408
+ print(f"Generation time: {end - start:.9f} seconds")
409
+
410
+ # Convert completions to expected format
411
+ if is_conversational(inputs[0]):
412
+ completions = [[{"role": "assistant", "content": completion}] for completion in completions_text]
413
+ else:
414
+ completions = completions_text
415
+
416
+ # Compute rewards
417
+ print("Reward calculation...")
418
+ rewards_per_func = torch.zeros(len(prompts), len(self.reward_funcs), device=device)
419
+ for i, (reward_func, reward_processing_class) in enumerate(
420
+ zip(self.reward_funcs, self.reward_processing_classes)
421
+ ):
422
+ if isinstance(reward_func, PreTrainedModel):
423
+ if is_conversational(inputs[0]):
424
+ messages = [{"messages": p + c} for p, c in zip(prompts, completions)]
425
+ texts = [apply_chat_template(x, reward_processing_class)["text"] for x in messages]
426
+ else:
427
+ texts = [p + c for p, c in zip(prompts, completions)]
428
+ reward_inputs = reward_processing_class(
429
+ texts, return_tensors="pt", padding=True, padding_side="right", add_special_tokens=False
430
+ )
431
+ reward_inputs = super()._prepare_inputs(reward_inputs)
432
+ with torch.inference_mode():
433
+ rewards_per_func[:, i] = reward_func(**reward_inputs).logits[:, 0]
434
+ else:
435
+ # Custom reward function
436
+ reward_kwargs = {key: [] for key in inputs[0].keys() if key not in ["prompt", "completion"]}
437
+ for key in reward_kwargs:
438
+ for example in inputs:
439
+ reward_kwargs[key].extend([example[key]])
440
+ output_reward_func = reward_func(prompts=prompts, completions=completions, **reward_kwargs)
441
+ rewards_per_func[:, i] = torch.tensor(output_reward_func, dtype=torch.float32, device=device)
442
+
443
+ # Gather rewards across processes
444
+ rewards_per_func = self.accelerator.gather(rewards_per_func)
445
+ rewards = rewards_per_func.sum(dim=1)
446
+
447
+ # Compute grouped-wise rewards
448
+ mean_grouped_rewards = rewards.view(-1, self.num_generations).mean(dim=1)
449
+ std_grouped_rewards = rewards.view(-1, self.num_generations).std(dim=1)
450
+
451
+ # Normalize rewards to compute advantages
452
+ mean_grouped_rewards = mean_grouped_rewards.repeat_interleave(self.num_generations, dim=0)
453
+ std_grouped_rewards = std_grouped_rewards.repeat_interleave(self.num_generations, dim=0)
454
+ advantages = (rewards - mean_grouped_rewards) / (std_grouped_rewards + 1e-4)
455
+
456
+ # Get local slice of advantages
457
+ process_slice = slice(
458
+ self.accelerator.process_index * len(prompts),
459
+ (self.accelerator.process_index + 1) * len(prompts),
460
+ )
461
+ advantages = advantages[process_slice]
462
+
463
+ # Log metrics
464
+ print("Logging metrics...")
465
+ completion_length = len(completions_text[0].split()) if completions_text else 0
466
+ self._metrics["completion_length"].append(completion_length)
467
+
468
+ reward_per_func = self.accelerator.gather_for_metrics(rewards_per_func).mean(0)
469
+ for i, reward_func in enumerate(self.reward_funcs):
470
+ if isinstance(reward_func, PreTrainedModel):
471
+ reward_func_name = reward_func.config._name_or_path.split("/")[-1]
472
+ else:
473
+ reward_func_name = reward_func.__name__
474
+ self._metrics[f"rewards/{reward_func_name}"].append(reward_per_func[i].item())
475
+
476
+ self._metrics["reward"].append(self.accelerator.gather_for_metrics(rewards).mean().item())
477
+ self._metrics["reward_std"].append(self.accelerator.gather_for_metrics(std_grouped_rewards).mean().item())
478
+
479
+ # Log completions if enabled
480
+ if (
481
+ self.log_completions
482
+ and self.state.global_step % self.args.logging_steps == 0
483
+ and "wandb" in self.args.report_to
484
+ ):
485
+ timestamp = time.time()
486
+ num_items = len(gather_object(prompts_text))
487
+
488
+ table = {
489
+ "step": [f"{self.state.global_step}_{timestamp}"] * num_items,
490
+ "prompt": gather_object(prompts_text),
491
+ "completion": gather_object(completions_text),
492
+ "reward": rewards.tolist(),
493
+ }
494
+ df = pd.DataFrame(table)
495
+
496
+ if wandb.run is not None and self.accelerator.is_main_process:
497
+ wandb.log({f"completions_{self.state.global_step}_{timestamp}": wandb.Table(dataframe=df)})
498
+
499
+ return {
500
+ "prot_batch": prot_batch,
501
+ "prompt_batch": prompt_batch,
502
+ "completions_text": completions_text,
503
+ "old_per_token_logps": None, # BLIP2 doesn't need this for current implementation
504
+ "ref_per_token_logps": None, # BLIP2 doesn't need this for current implementation
505
+ "advantages": advantages,
506
+ "multimodal_inputs": {"prot_batch": prot_batch, "prompt_batch": prompt_batch}
507
+ }
508
+
509
+ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
510
+ if return_outputs:
511
+ raise ValueError("The BLIP2 GRPO Trainer does not support returning outputs")
512
+
513
+ print("compute_loss - index 1")
514
+ if self.state.global_step % self.num_iterations == 0:
515
+ inputs = self._generate_and_score_completions(inputs, model)
516
+ self._buffered_inputs[self._step % self.args.gradient_accumulation_steps] = inputs
517
+ else:
518
+ inputs = self._buffered_inputs[self._step % self.args.gradient_accumulation_steps]
519
+ self._step += 1
520
+
521
+ print("compute_loss - index 2")
522
+
523
+ # For BLIP2, we need to compute loss differently
524
+ # This is a simplified version - you may need to adapt based on your specific BLIP2 implementation
525
+
526
+ # Extract the necessary components
527
+ prot_batch = inputs.get("prot_batch")
528
+ prompt_batch = inputs.get("prompt_batch")
529
+ advantages = inputs.get("advantages")
530
+
531
+ print("compute_loss - index 3")
532
+
533
+ # Create a batch for BLIP2 forward pass
534
+ # This assumes your BLIP2 model expects (prot_batch, prompt_batch, text_dict) format
535
+ text_dict = {"targets": inputs.get("completions_text", [])}
536
+ batch = (prot_batch, prompt_batch, text_dict)
537
+
538
+ print("compute_loss - index 4")
539
+
540
+ # Forward pass through BLIP2
541
+ if hasattr(model, 'blip2'):
542
+ loss = model.blip2(batch)
543
+ else:
544
+ loss = model(batch)
545
+
546
+ print("compute_loss - index 5")
547
+
548
+ # For now, return the basic loss
549
+ # You may want to incorporate the advantages into the loss calculation
550
+ # based on your specific GRPO implementation needs
551
+
552
+ if advantages is not None:
553
+ # Apply advantages weighting (simplified)
554
+ advantage_weight = advantages.mean().item()
555
+ loss = loss * (1.0 + advantage_weight)
556
+
557
+ print("Computing final loss...")
558
+ return loss
559
+
560
+ def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> None:
561
+ metrics = {key: sum(val) / len(val) for key, val in self._metrics.items()}
562
+ logs = {**logs, **metrics}
563
+ if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
564
+ super().log(logs, start_time)
565
+ else:
566
+ super().log(logs)
567
+ self._metrics.clear()
568
+
569
+ def _get_train_sampler(self) -> Sampler:
570
+ """Returns a sampler that ensures proper data sampling for GRPO training."""
571
+ effective_batch_size = (
572
+ self.args.per_device_train_batch_size
573
+ * self.accelerator.num_processes
574
+ * self.args.gradient_accumulation_steps
575
+ )
576
+
577
+ return RepeatRandomSampler(
578
+ data_source=self.train_dataset,
579
+ mini_repeat_count=self.num_generations,
580
+ batch_size=effective_batch_size // self.num_generations,
581
+ repeat_count=self.num_iterations,
582
+ seed=self.args.seed,
583
+ )
584
+
585
+ def _get_eval_sampler(self, eval_dataset) -> Sampler:
586
+ """Returns a sampler for evaluation."""
587
+ return RepeatRandomSampler(
588
+ data_source=eval_dataset,
589
+ mini_repeat_count=self.num_generations,
590
+ seed=self.args.seed,
591
+ )
BioReason-0813/blips_reason.py ADDED
@@ -0,0 +1,866 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import re
3
+ import pathlib
4
+ from argparse import ArgumentParser
5
+ from typing import List, Dict, Optional
6
+ from dataclasses import dataclass, field
7
+
8
+ import torch
9
+ from torch import nn
10
+ import torch.nn.functional as F
11
+ from torch.optim import AdamW
12
+ from torch.utils.data import DataLoader, Dataset
13
+ from transformers import get_cosine_schedule_with_warmup, AutoTokenizer
14
+
15
+ from transformers import (
16
+ AutoTokenizer,
17
+ AutoModelForCausalLM,
18
+ AutoModelForMaskedLM,
19
+ AutoProcessor,
20
+ )
21
+
22
+ from datasets import load_dataset, DatasetDict
23
+
24
+ from peft import get_peft_model, LoraConfig, prepare_model_for_kbit_training
25
+ from transformers import BitsAndBytesConfig
26
+
27
+ import pytorch_lightning as pl
28
+ from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor
29
+ from pytorch_lightning.loggers import WandbLogger
30
+
31
+ from trl import GRPOConfig, GRPOTrainer, ModelConfig, ScriptArguments, TrlParser, get_peft_config
32
+
33
+ # Import BLIP2 modules
34
+ from model.blip2_stage2 import Blip2Stage2
35
+ from blip2_dna_module import Blip2DNAModule
36
+ from blip2_grpo_trainer import Blip2GRPOTrainer
37
+ from bioreason.trainer import DNALLMGRPOConfig
38
+
39
+ # Custom TrainerCallback to override the saving mechanism
40
+ from transformers import TrainerCallback, TrainerState, TrainerControl
41
+ from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
42
+
43
+ from prompt_templates import prompt_templates
44
+
45
+ class SaveWithPyTorchCallback(TrainerCallback):
46
+ """Custom callback to save models with PyTorch's native save mechanism instead of safetensors"""
47
+ def on_save(self, args, state, control, **kwargs):
48
+ # Get the checkpoint folder
49
+ checkpoint_folder = os.path.join(
50
+ args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}"
51
+ )
52
+ os.makedirs(checkpoint_folder, exist_ok=True)
53
+
54
+ # Save with PyTorch instead of safetensors
55
+ checkpoint_path = os.path.join(checkpoint_folder, "pytorch_model.bin")
56
+ model = kwargs.get("model")
57
+
58
+ # Get model unwrapped from accelerator etc.
59
+ unwrapped_model = model.module if hasattr(model, "module") else model
60
+
61
+ # Save using PyTorch directly
62
+ torch.save(unwrapped_model.state_dict(), checkpoint_path)
63
+
64
+ # For BLIP2, save the config from the LLM component
65
+ if hasattr(unwrapped_model, "blip2") and hasattr(unwrapped_model.blip2, "llm_model"):
66
+ if hasattr(unwrapped_model.blip2.llm_model, "config"):
67
+ unwrapped_model.blip2.llm_model.config.save_pretrained(checkpoint_folder)
68
+ elif hasattr(unwrapped_model.blip2.llm_model, "base_model") and hasattr(unwrapped_model.blip2.llm_model.base_model, "config"):
69
+ unwrapped_model.blip2.llm_model.base_model.config.save_pretrained(checkpoint_folder)
70
+
71
+ # Print info about what's being saved
72
+ print(f"Saved model checkpoint to {checkpoint_folder}")
73
+ lora_params = [k for k in unwrapped_model.state_dict().keys() if "lora" in k]
74
+ print(f"Checkpoint contains {len(lora_params)} LoRA parameters")
75
+
76
+ # Signal that we've saved
77
+ control.should_save = False
78
+ return control
79
+
80
+ def extract_xml_answer(text: str) -> str:
81
+ """提取answer标签中的内容,如果没有则返回think标签后的内容"""
82
+ # 首先尝试提取answer标签
83
+ answer_match = re.search(r"<answer>(.*?)</answer>", text, re.DOTALL)
84
+ if answer_match:
85
+ return answer_match.group(1).strip()
86
+
87
+ # 如果没有answer标签,尝试提取think标签后的内容
88
+ think_split = text.split("</think>")
89
+ if len(think_split) > 1:
90
+ return think_split[-1].strip()
91
+
92
+ # 如果都没有,返回原文
93
+ return text.strip()
94
+
95
+ def extract_classification_answer(text: str) -> str:
96
+ """专门用于提取分类答案的函数"""
97
+ # 提取answer标签中的内容
98
+ answer_match = re.search(r"<answer>(.*?)</answer>", text, re.DOTALL)
99
+ if answer_match:
100
+ answer_content = answer_match.group(1).strip()
101
+
102
+ # 查找分类相关的模式
103
+ classification_patterns = [
104
+ r"[Cc]lassification:\s*(\d+)",
105
+ r"[Cc]lass:\s*(\d+)",
106
+ r"[Ll]abel:\s*(\d+)",
107
+ r"[Pp]rediction:\s*(\d+)",
108
+ r"(\d+)", # 任何数字
109
+ ]
110
+
111
+ for pattern in classification_patterns:
112
+ match = re.search(pattern, answer_content)
113
+ if match:
114
+ return match.group(1)
115
+
116
+ return answer_content
117
+
118
+ return extract_xml_answer(text)
119
+
120
+ def extract_hash_answer(text: str) -> str | None:
121
+ if "####" not in text:
122
+ return None
123
+ return text.split("####")[1].strip()
124
+
125
+ def get_kegg_questions() -> Dataset:
126
+ """保留原有的KEGG数据集加载函数作为fallback"""
127
+ try:
128
+ data = load_dataset('wanglab/kegg', 'default') # type: ignore
129
+ example_dna_sequences = ["ATCTACATGCAT", "CAGCAGCTACAG", "CATCACATCGACATCGAC"]
130
+ num_dna_sequences = 2
131
+
132
+ data = data.map(lambda x: { # type: ignore
133
+ 'prompt': [
134
+ {
135
+ 'role': 'user',
136
+ 'content': [
137
+ *({'type': 'dna', 'text': None} for _ in range(num_dna_sequences)),
138
+ {'type': 'text', 'text': x['question']},
139
+ ],
140
+ },
141
+ ],
142
+ 'dna_sequences': [x['reference_sequence'], x['variant_sequence']],
143
+ 'answer': x['answer'],
144
+ }) # type: ignore
145
+
146
+ return data
147
+ except Exception as e:
148
+ print(f"Failed to load KEGG dataset: {e}")
149
+ # 返回一个空的数据集结构
150
+ from datasets import Dataset
151
+ empty_data = {
152
+ 'prompt': [],
153
+ 'dna_sequences': [],
154
+ 'answer': []
155
+ }
156
+ dataset = Dataset.from_dict(empty_data)
157
+ return {'train': dataset, 'val': dataset}
158
+
159
+ def get_protein_classification_data(data_path: str = None, prompt_template: str = None) -> Dataset:
160
+ """
161
+ 加载蛋白质分类数据集
162
+ 数据格式:name,aa_seq,label,location,unique_id,pdb_hash
163
+ """
164
+ import pandas as pd
165
+ from datasets import Dataset
166
+
167
+ if data_path is None:
168
+ # 如果没有提供路径,使用默认的kegg数据集作为fallback
169
+ return get_kegg_questions()
170
+
171
+ # 读取CSV数据
172
+ if data_path.endswith('.csv'):
173
+ df = pd.read_csv(data_path)
174
+ else:
175
+ # 假设是其他格式,可以扩展
176
+ raise ValueError(f"Unsupported file format: {data_path}")
177
+
178
+ # 默认prompt模板
179
+ if prompt_template is None:
180
+ prompt_template = """
181
+ Please analyze the following protein sequence and predict its classification.
182
+
183
+ Protein sequence: <protein>{aa_seq}</protein>
184
+
185
+ Question: What is the classification of this protein sequence?
186
+
187
+ Please provide your reasoning in <think></think> tags and your final answer in <answer></answer> tags.
188
+ """
189
+
190
+ # 数据转换
191
+ def process_example(row):
192
+ # 构建prompt
193
+ prompt_text = prompt_template.format(
194
+ aa_seq=row['aa_seq'],
195
+ name=row.get('name', ''),
196
+ location=row.get('location', ''),
197
+ unique_id=row.get('unique_id', ''),
198
+ )
199
+
200
+ return {
201
+ 'prompt': [
202
+ {
203
+ 'role': 'user',
204
+ 'content': [
205
+ {'type': 'protein', 'text': None}, # 蛋白质序列占位符
206
+ {'type': 'text', 'text': prompt_text},
207
+ ],
208
+ },
209
+ ],
210
+ 'dna_sequences': [row['aa_seq']], # 使用aa_seq作为"dna_sequences"
211
+ 'answer': str(row['label']), # label作为答案
212
+ 'metadata': {
213
+ 'name': row.get('name', ''),
214
+ 'location': row.get('location', ''),
215
+ 'unique_id': row.get('unique_id', ''),
216
+ 'pdb_hash': row.get('pdb_hash', ''),
217
+ }
218
+ }
219
+
220
+ # 转换所有数据
221
+ processed_data = []
222
+ for _, row in df.iterrows():
223
+ processed_data.append(process_example(row))
224
+
225
+ # 创建数据集
226
+ dataset = Dataset.from_list(processed_data)
227
+
228
+ # 划分训练集和验证集
229
+ if len(dataset) > 100: # 如果数据足够大,进行划分
230
+ dataset = dataset.train_test_split(test_size=0.1, seed=42)
231
+ else:
232
+ # 数据较小时,复制训练集作为验证集
233
+ dataset = {
234
+ 'train': dataset,
235
+ 'val': dataset.select(range(min(10, len(dataset)))) # 选择前10个作为验证
236
+ }
237
+
238
+ return dataset
239
+
240
+ def get_custom_protein_data_with_prompts(data_path: str = None,
241
+ prompt_templates: Dict[str, str] = None) -> Dataset:
242
+ """
243
+ 更灵活的蛋白质数据加载函数,支持多种prompt模板
244
+ """
245
+ import pandas as pd
246
+ from datasets import Dataset
247
+ import random
248
+
249
+ if data_path is None:
250
+ return get_kegg_questions()
251
+
252
+ # 读取数据
253
+ df = pd.read_csv(data_path)
254
+
255
+ def process_example(row, template_name=None):
256
+ # 随机选择或指定template
257
+ if template_name is None:
258
+ template_name = random.choice(list(prompt_templates.keys()))
259
+
260
+ template = prompt_templates[template_name]
261
+
262
+ # 格式化prompt
263
+ prompt_text = template.format(
264
+ aa_seq=row['aa_seq'][:500] + "..." if len(row['aa_seq']) > 500 else row['aa_seq'], # 截断长序列
265
+ label=row['label'],
266
+ name=row.get('name', ''),
267
+ location=row.get('location', ''),
268
+ )
269
+
270
+ return {
271
+ 'prompt': [
272
+ {
273
+ 'role': 'user',
274
+ 'content': [
275
+ {'type': 'protein', 'text': None},
276
+ {'type': 'text', 'text': prompt_text.split('<protein>')[0]}, # prompt前半部分
277
+ ],
278
+ },
279
+ ],
280
+ 'dna_sequences': [row['aa_seq']], # 完整序列用于模型处理
281
+ 'answer': str(row['label']),
282
+ 'template_used': template_name,
283
+ 'metadata': {
284
+ 'name': row.get('name', ''),
285
+ 'location': row.get('location', ''),
286
+ 'unique_id': row.get('unique_id', ''),
287
+ 'pdb_hash': row.get('pdb_hash', ''),
288
+ 'full_prompt': prompt_text,
289
+ }
290
+ }
291
+
292
+ # 处理数据
293
+ processed_data = []
294
+ print("template_name")
295
+ print(script_args.template_name)
296
+ for _, row in df.iterrows():
297
+ processed_data.append(process_example(row,script_args.template_name))
298
+
299
+ dataset = Dataset.from_list(processed_data)
300
+
301
+ # 数据集划分
302
+ if len(dataset) > 50:
303
+ dataset = dataset.train_test_split(test_size=0.1, seed=42)
304
+ else:
305
+ dataset = {
306
+ 'train': dataset,
307
+ 'val': dataset.select(range(min(5, len(dataset))))
308
+ }
309
+
310
+ return dataset
311
+
312
+ def get_gsm8k_questions(question_prompt: str) -> Dataset:
313
+ data = load_dataset('openai/gsm8k', 'main') # type: ignore
314
+
315
+ example_dna_sequences = ["ATCTACATGCAT", "CAGCAGCTACAG", "CATCACATCGACATCGAC"]
316
+ data = data.map(lambda x: { # type: ignore
317
+ 'prompt': [
318
+ {
319
+ 'role': 'user',
320
+ 'content': [
321
+ *({'type': 'dna', 'text': None} for _ in range(len(example_dna_sequences))),
322
+ {'type': 'text', 'text': 'Give me a short introduction to large language model.'}
323
+ ]
324
+ },
325
+ ],
326
+ 'dna_sequences': [dna for dna in example_dna_sequences],
327
+ 'answer': extract_hash_answer(x['answer']),
328
+ }) # type: ignore
329
+
330
+ return data # type: ignore
331
+
332
+ # Reward functions
333
+ def format_correct_reward_func(completions, **kwargs) -> list[float]:
334
+ """
335
+ 奖励函数:检查格式是否正确
336
+ 要求:包含 <think>...</think> 和 <answer>...</answer> 标签
337
+ """
338
+ responses = [completion[0]["content"] for completion in completions]
339
+ rewards = []
340
+
341
+ for response in responses:
342
+ score = 0.0
343
+
344
+ # 检查是否有think标签
345
+ if "<think>" in response and "</think>" in response:
346
+ score += 0.5
347
+
348
+ # 检查是否有answer标签
349
+ if "<answer>" in response and "</answer>" in response:
350
+ score += 0.5
351
+
352
+ # 检查标签的顺序是否正确
353
+ think_start = response.find("<think>")
354
+ think_end = response.find("</think>")
355
+ answer_start = response.find("<answer>")
356
+ answer_end = response.find("</answer>")
357
+
358
+ if (think_start != -1 and think_end != -1 and
359
+ answer_start != -1 and answer_end != -1 and
360
+ think_start < think_end < answer_start < answer_end):
361
+ score += 0.5 # 格式完全正确的额外奖励
362
+
363
+ rewards.append(score)
364
+
365
+ return rewards
366
+
367
+ def accuracy_reward_func(prompts, completions, answer, **kwargs) -> list[float]:
368
+ """
369
+ 奖励函数:检查答案准确率
370
+ 适配蛋白质分类任务
371
+ """
372
+ responses = [completion[0]['content'] for completion in completions]
373
+ rewards = []
374
+
375
+ for i, response in enumerate(responses):
376
+ # 提取answer标签中的内容
377
+ answer_match = re.search(r"<answer>(.*?)</answer>", response, re.DOTALL)
378
+ if answer_match:
379
+ extracted_answer = answer_match.group(1).strip()
380
+ else:
381
+ extracted_answer = response.strip()
382
+
383
+ # 获取正确答案
384
+ if isinstance(answer, list) and len(answer) > i:
385
+ correct_answer = str(answer[i]).strip()
386
+ elif isinstance(answer, list) and len(answer) > 0:
387
+ correct_answer = str(answer[0]).strip()
388
+ else:
389
+ correct_answer = str(answer).strip()
390
+
391
+ # 计算准确率奖励
392
+ # 对于分类任务,检查数字或类别匹配
393
+ extracted_clean = re.sub(r'[^\w\d]', '', extracted_answer.lower())
394
+ correct_clean = re.sub(r'[^\w\d]', '', correct_answer.lower())
395
+
396
+ if correct_clean in extracted_clean or extracted_clean == correct_clean:
397
+ rewards.append(1.0) # 完全匹配
398
+ elif any(word in extracted_clean for word in correct_clean.split()):
399
+ rewards.append(0.5) # 部分匹配
400
+ else:
401
+ rewards.append(0.0) # 不匹配
402
+
403
+ return rewards
404
+
405
+ def classification_specific_reward_func(prompts, completions, answer, **kwargs) -> list[float]:
406
+ """
407
+ 针对蛋白质分类任务的专门奖励函数
408
+ """
409
+ responses = [completion[0]['content'] for completion in completions]
410
+ rewards = []
411
+
412
+ for i, response in enumerate(responses):
413
+ score = 0.0
414
+
415
+ # 提取答案
416
+ answer_match = re.search(r"<answer>(.*?)</answer>", response, re.DOTALL)
417
+ if answer_match:
418
+ extracted_answer = answer_match.group(1).strip()
419
+ else:
420
+ extracted_answer = response.strip()
421
+
422
+ # 获取正确答案
423
+ if isinstance(answer, list) and len(answer) > i:
424
+ correct_answer = str(answer[i]).strip()
425
+ elif isinstance(answer, list) and len(answer) > 0:
426
+ correct_answer = str(answer[0]).strip()
427
+ else:
428
+ correct_answer = str(answer).strip()
429
+
430
+ # 检查是否包含分类关键词
431
+ classification_keywords = ['classification', 'class', 'category', 'type', 'function', 'family']
432
+ if any(keyword in extracted_answer.lower() for keyword in classification_keywords):
433
+ score += 0.2
434
+
435
+ # 检查数字匹配(对于数字标签)
436
+ if correct_answer.isdigit():
437
+ if correct_answer in extracted_answer:
438
+ score += 0.8
439
+ # 检查数字临近性
440
+ try:
441
+ extracted_numbers = re.findall(r'\d+', extracted_answer)
442
+ if extracted_numbers:
443
+ closest_num = min(extracted_numbers, key=lambda x: abs(int(x) - int(correct_answer)))
444
+ if abs(int(closest_num) - int(correct_answer)) <= 1:
445
+ score += 0.4
446
+ except:
447
+ pass
448
+ else:
449
+ # 文本标签匹配
450
+ if correct_answer.lower() in extracted_answer.lower():
451
+ score += 0.8
452
+
453
+ # 检查是否有推理过程
454
+ if "<think>" in response and "</think>" in response:
455
+ think_content = re.search(r"<think>(.*?)</think>", response, re.DOTALL)
456
+ if think_content and len(think_content.group(1).strip()) > 20:
457
+ score += 0.2
458
+
459
+ rewards.append(min(score, 1.0)) # 确保不超过1.0
460
+
461
+ return rewards
462
+
463
+ def repetition_penalty_reward_func(completions, **kwargs) -> list[float]:
464
+ """
465
+ 奖励函数:检查重复率(越低越好)
466
+ 计算文本中重复词汇的比例,重复率越低奖励越高
467
+ """
468
+ responses = [completion[0]["content"] for completion in completions]
469
+ rewards = []
470
+
471
+ for response in responses:
472
+ # 提取answer部分的文本
473
+ answer_match = re.search(r"<answer>(.*?)</answer>", response, re.DOTALL)
474
+ if answer_match:
475
+ text_to_analyze = answer_match.group(1).strip()
476
+ else:
477
+ text_to_analyze = response.strip()
478
+
479
+ # 分词并计算重复率
480
+ words = text_to_analyze.lower().split()
481
+
482
+ if len(words) == 0:
483
+ rewards.append(0.0)
484
+ continue
485
+
486
+ # 计算词汇重复率
487
+ unique_words = set(words)
488
+ repetition_rate = 1.0 - (len(unique_words) / len(words))
489
+
490
+ # 计算句子重复率
491
+ sentences = [s.strip() for s in text_to_analyze.split('.') if s.strip()]
492
+ if len(sentences) > 1:
493
+ unique_sentences = set(sentences)
494
+ sentence_repetition_rate = 1.0 - (len(unique_sentences) / len(sentences))
495
+ else:
496
+ sentence_repetition_rate = 0.0
497
+
498
+ # 综合重复率
499
+ overall_repetition = (repetition_rate + sentence_repetition_rate) / 2
500
+
501
+ # 重复率越低,奖励越高
502
+ reward = max(0.0, 1.0 - overall_repetition * 2) # 乘以2让惩罚更明显
503
+ rewards.append(reward)
504
+
505
+ return rewards
506
+
507
+ def combined_reward_func(prompts, completions, answer,
508
+ format_weight=0.3, accuracy_weight=0.5, repetition_weight=0.2,
509
+ **kwargs) -> list[float]:
510
+ """
511
+ 组合奖励函数:格式+准确率+重复率的加权组合
512
+ """
513
+ format_rewards = format_correct_reward_func(completions, **kwargs)
514
+ accuracy_rewards = accuracy_reward_func(prompts, completions, answer, **kwargs)
515
+ repetition_rewards = repetition_penalty_reward_func(completions, **kwargs)
516
+
517
+ # 确保权重总和为1
518
+ total_weight = format_weight + accuracy_weight + repetition_weight
519
+ if total_weight != 1.0:
520
+ format_weight /= total_weight
521
+ accuracy_weight /= total_weight
522
+ repetition_weight /= total_weight
523
+ print(f"Normalized weights - Format: {format_weight:.3f}, Accuracy: {accuracy_weight:.3f}, Repetition: {repetition_weight:.3f}")
524
+
525
+ combined_rewards = []
526
+ for f_reward, a_reward, r_reward in zip(format_rewards, accuracy_rewards, repetition_rewards):
527
+ combined = (format_weight * f_reward +
528
+ accuracy_weight * a_reward +
529
+ repetition_weight * r_reward)
530
+ combined_rewards.append(combined)
531
+
532
+ return combined_rewards
533
+
534
+ # 保留一些原有的奖励函数作为备选
535
+ def less_than_4_reward_func(completions, **kwargs) -> list[float]:
536
+ responses = [completion[0]['content'] for completion in completions]
537
+ extracted_responses = [extract_xml_answer(r) for r in responses]
538
+ return [0.5 if len(r.split(' ')) <= 4 else 0.0 for r in extracted_responses]
539
+
540
+ def strict_format_reward_func(completions, **kwargs) -> list[float]:
541
+ """Reward function that checks if the completion has a specific format."""
542
+ pattern = r"^<think>\n.*?\n</think>\n.*?\n$"
543
+ responses = [completion[0]["content"] for completion in completions]
544
+ matches = [re.match(pattern, r) for r in responses]
545
+ return [0.5 if match else 0.0 for match in matches]
546
+
547
+ def xmlcount_reward_func(completions, **kwargs) -> list[float]:
548
+ contents = [completion[0]["content"] for completion in completions]
549
+ return [count_xml(c) for c in contents]
550
+
551
+ def count_xml(text) -> float:
552
+ count = 0.0
553
+ if text.count("<think>\n") == 1:
554
+ count += 0.125
555
+ if text.count("\n</think>\n") == 1:
556
+ count += 0.125
557
+ return count
558
+
559
+ @dataclass
560
+ class Blip2ModelConfig(ModelConfig):
561
+ # BLIP2 specific configuration
562
+ model_name_or_path: str = field(default="blip2-model", metadata={"help": "Model checkpoint for weights initialization."})
563
+
564
+ # BLIP2 Architecture parameters
565
+ bert_name: str = field(default="/path/to/bert", metadata={"help": "BERT model for Q-former"})
566
+ num_query_token: int = field(default=8, metadata={"help": "Number of query tokens"})
567
+ cross_attention_freq: int = field(default=2, metadata={"help": "Cross attention frequency"})
568
+ plm_model: str = field(default="facebook/esm2_t30_150M_UR50D", metadata={"help": "Protein language model"})
569
+ plm_tune: str = field(default="freeze", metadata={"help": "PLM tuning strategy"})
570
+ llm_name: str = field(default="facebook/galactica-1.3b", metadata={"help": "Language model name"})
571
+ llm_tune: str = field(default="lora", metadata={"help": "LLM tuning strategy"})
572
+ qformer_tune: str = field(default="train", metadata={"help": "Q-former tuning strategy"})
573
+ peft_dir: str = field(default="", metadata={"help": "PEFT directory"})
574
+
575
+ # LoRA parameters
576
+ lora_r: int = field(default=8, metadata={"help": "LoRA rank"})
577
+ lora_alpha: int = field(default=16, metadata={"help": "LoRA alpha"})
578
+ lora_dropout: float = field(default=0.1, metadata={"help": "LoRA dropout"})
579
+
580
+ # Training parameters
581
+ enbale_gradient_checkpointing: bool = field(default=False, metadata={"help": "Enable gradient checkpointing"})
582
+ enable_flash: bool = field(default=False, metadata={"help": "Enable flash attention"})
583
+
584
+ # Other parameters
585
+ cache_dir: str = field(default=None, metadata={"help": "Path to model cache directory."})
586
+ sft_checkpoint: str = field(default=None, metadata={"help": "Path to the checkpoint for SFT."})
587
+ freeze_dna_modules: bool = field(default=False, metadata={"help": "Freeze DNA/protein modules"})
588
+
589
+ @dataclass
590
+ class GRPOScriptArguments(ScriptArguments):
591
+ """
592
+ Script arguments for the GRPO training script with BLIP2.
593
+ """
594
+ dataset_name: str = field(default="wanglab/kegg", metadata={"help": "Dataset name with default."})
595
+ data_file_paths: str = field(
596
+ default=None,
597
+ metadata={"help": "Path to protein classification CSV file (format: name,aa_seq,label,location,unique_id,pdb_hash)"},
598
+ )
599
+ arrow_cache_dir: str = field(
600
+ default=None,
601
+ metadata={"help": "Path to arrow cache directory"},
602
+ )
603
+ val_split_ratio: float = field(
604
+ default=0.1,
605
+ metadata={"help": "Ratio of validation split, default 0.1"},
606
+ )
607
+ reward_funcs: list[str] = field(
608
+ # 选项1:使用组合奖励函数(推荐)
609
+ default_factory=lambda: ["combined"],
610
+
611
+ # 选项2:使用分离的奖励函数
612
+ # default_factory=lambda: ["format_correct", "accuracy", "repetition_penalty"],
613
+
614
+ # 选项3:使用蛋白质分类专用奖励
615
+ # default_factory=lambda: ["format_correct", "classification_specific", "repetition_penalty"],
616
+
617
+ metadata={"help": "List of reward functions. Available: 'combined', 'format_correct', 'accuracy', 'classification_specific', 'repetition_penalty', 'xmlcount', 'strict_format', 'less_than_4'"},
618
+ )
619
+
620
+ # 奖励函数权重配置
621
+ format_weight: float = field(
622
+ default=0.3,
623
+ metadata={"help": "Weight for format correctness reward (used in combined reward)"}
624
+ )
625
+ accuracy_weight: float = field(
626
+ default=0.5,
627
+ metadata={"help": "Weight for accuracy reward (used in combined reward)"}
628
+ )
629
+ repetition_weight: float = field(
630
+ default=0.2,
631
+ metadata={"help": "Weight for repetition penalty reward (used in combined reward)"}
632
+ )
633
+
634
+ # 数据处理参数
635
+ template_name: str = field(
636
+ default="classification",
637
+ metadata={"help": "Prompt template to use: 'classification', 'function_prediction', 'location_prediction'"}
638
+ )
639
+ max_seq_length: int = field(
640
+ default=1000,
641
+ metadata={"help": "Maximum protein sequence length for display in prompt"}
642
+ )
643
+ use_custom_prompts: bool = field(
644
+ default=True,
645
+ metadata={"help": "Whether to use custom protein-specific prompts"}
646
+ )
647
+
648
+ reward_funcs_registry = {
649
+ # 新的三合一奖励函数
650
+ "combined": combined_reward_func, # 格式+准确率+重复率组合
651
+
652
+ # 分离的奖励函数
653
+ "format_correct": format_correct_reward_func, # 格式正确性
654
+ "accuracy": accuracy_reward_func, # 准确率
655
+ "repetition_penalty": repetition_penalty_reward_func, # 重复率惩罚
656
+ "classification_specific": classification_specific_reward_func, # 蛋白质分类专用
657
+
658
+ # 原有的奖励函数(保留作为备选)
659
+ "xmlcount": xmlcount_reward_func,
660
+ "strict_format": strict_format_reward_func,
661
+ "less_than_4": less_than_4_reward_func,
662
+ }
663
+
664
+ def get_vlm_module(model_name_or_path):
665
+ # Always use BLIP2 module for this implementation
666
+ return Blip2DNAModule
667
+
668
+ def create_blip2_args_from_config(model_args):
669
+ """Create BLIP2 args from model config"""
670
+ # Convert model config to the format expected by BLIP2
671
+ blip2_args = {
672
+ 'bert_name': model_args.bert_name,
673
+ 'num_query_token': model_args.num_query_token,
674
+ 'cross_attention_freq': model_args.cross_attention_freq,
675
+ 'plm_model': model_args.plm_model,
676
+ 'plm_tune': model_args.plm_tune,
677
+ 'llm_name': model_args.llm_name,
678
+ 'llm_tune': model_args.llm_tune,
679
+ 'qformer_tune': model_args.qformer_tune,
680
+ 'peft_dir': model_args.peft_dir,
681
+ 'lora_r': model_args.lora_r,
682
+ 'lora_alpha': model_args.lora_alpha,
683
+ 'lora_dropout': model_args.lora_dropout,
684
+ 'enbale_gradient_checkpointing': model_args.enbale_gradient_checkpointing,
685
+ 'enable_flash': model_args.enable_flash,
686
+ }
687
+ return blip2_args
688
+
689
+ def _prep_for_training(model, training_args):
690
+ """
691
+ Prepare BLIP2 model for training with LoRA.
692
+ """
693
+ # The BLIP2 model should handle its own LoRA setup
694
+ # This is mainly for any additional preparation needed
695
+
696
+ target_modules = ["q_proj", "k_proj", "v_proj", "out_proj", "fc1", "fc2"]
697
+
698
+ lora_config = LoraConfig(
699
+ r=training_args.lora_r,
700
+ lora_alpha=training_args.lora_alpha,
701
+ lora_dropout=training_args.lora_dropout,
702
+ target_modules=target_modules,
703
+ init_lora_weights="gaussian",
704
+ bias="none",
705
+ task_type="CAUSAL_LM",
706
+ )
707
+
708
+ return lora_config
709
+
710
+ def main(script_args, training_args, model_args):
711
+ print(training_args.output_dir)
712
+ torch.cuda.empty_cache()
713
+ torch.set_float32_matmul_precision("medium")
714
+
715
+ # Create BLIP2 model
716
+ blip2_args = create_blip2_args_from_config(model_args)
717
+ model = Blip2Stage2(blip2_args)
718
+
719
+ # Load checkpoint if specified
720
+ if model_args.sft_checkpoint is not None:
721
+ print(f"Loading SFT checkpoint from {model_args.sft_checkpoint}")
722
+ model = Blip2Stage2.load_from_checkpoint(model_args.sft_checkpoint, strict=False, args=blip2_args, map_location='cpu')
723
+
724
+ # if os.path.isdir(model_args.sft_checkpoint):
725
+ # # Load Lightning checkpoint
726
+ # checkpoint = torch.load(os.path.join(model_args.sft_checkpoint, "last.ckpt"), map_location='cpu')
727
+ # model.load_state_dict(checkpoint['state_dict'], strict=False)
728
+ # print("Loaded Lightning checkpoint")
729
+ # else:
730
+ # # Load PyTorch state dict
731
+ # checkpoint = torch.load(model_args.sft_checkpoint, map_location='cpu')
732
+
733
+ # if "state_dict" in checkpoint:
734
+ # state_dict = checkpoint["state_dict"]
735
+ # else:
736
+ # state_dict = checkpoint
737
+
738
+ # # Remove module prefix if present
739
+ # state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
740
+
741
+ # result = model.load_state_dict(state_dict, strict=False)
742
+ # print(f"Loaded checkpoint with {len(result.missing_keys)} missing keys and {len(result.unexpected_keys)} unexpected keys")
743
+
744
+ # Get reward functions with weights
745
+ reward_funcs = []
746
+ for func_name in script_args.reward_funcs:
747
+ if func_name == "combined":
748
+ # 为组合奖励函数传递权重参数
749
+ def weighted_combined_reward(prompts, completions, answer, **kwargs):
750
+ return combined_reward_func(
751
+ prompts, completions, answer,
752
+ format_weight=script_args.format_weight,
753
+ accuracy_weight=script_args.accuracy_weight,
754
+ repetition_weight=script_args.repetition_weight,
755
+ **kwargs
756
+ )
757
+ reward_funcs.append(weighted_combined_reward)
758
+ else:
759
+ reward_funcs.append(reward_funcs_registry[func_name])
760
+
761
+ print("reward_funcs:", [func.__name__ if hasattr(func, '__name__') else 'weighted_combined_reward' for func in reward_funcs])
762
+ print(f"Reward weights - Format: {script_args.format_weight}, Accuracy: {script_args.accuracy_weight}, Repetition: {script_args.repetition_weight}")
763
+
764
+ vlm_module_cls = get_vlm_module(model_args.model_name_or_path)
765
+ print("using vlm module:", vlm_module_cls.__name__)
766
+ question_prompt = vlm_module_cls.get_question_template()
767
+
768
+ # Load dataset based on data source
769
+ if script_args.data_file_paths and script_args.use_custom_prompts:
770
+ print(f"Loading custom protein data from: {script_args.data_file_paths}")
771
+
772
+
773
+ dataset = get_custom_protein_data_with_prompts(
774
+ data_path=script_args.data_file_paths,
775
+ prompt_templates=prompt_templates,
776
+ template_name=script_args.template_name
777
+ )
778
+ elif script_args.data_file_paths:
779
+ print(f"Loading protein data from: {script_args.data_file_paths}")
780
+ dataset = get_protein_classification_data(
781
+ data_path=script_args.data_file_paths
782
+ )
783
+ else:
784
+ print("Using default KEGG dataset")
785
+ dataset = get_kegg_questions()
786
+
787
+ print("Dataset loaded:")
788
+ print(f"Train size: {len(dataset['train'])}")
789
+ print(f"Val size: {len(dataset.get('val', []))}")
790
+
791
+ # 打印数据样例
792
+ if len(dataset['train']) > 0:
793
+ print("\nSample data:")
794
+ sample = dataset['train'][0]
795
+ print(f"Prompt type: {type(sample.get('prompt', 'Unknown'))}")
796
+ print(f"DNA sequences count: {len(sample.get('dna_sequences', []))}")
797
+ print(f"Answer: {sample.get('answer', 'N/A')}")
798
+ if 'metadata' in sample:
799
+ print(f"Metadata: {sample['metadata']}")
800
+ print(f"First 100 chars of sequence: {sample.get('dna_sequences', [''])[0][:100]}...")
801
+
802
+
803
+ # Custom callback to handle saving with PyTorch's native mechanism
804
+ custom_save_callback = SaveWithPyTorchCallback()
805
+
806
+ # Initialize the BLIP2 GRPO trainer
807
+ trainer = Blip2GRPOTrainer(
808
+ model=model,
809
+ reward_funcs=reward_funcs,
810
+ args=training_args,
811
+ dna_module=vlm_module_cls(),
812
+ train_dataset=dataset['train'],
813
+ eval_dataset=dataset['val'] if training_args.eval_strategy != "no" else None,
814
+ peft_config=get_peft_config(model_args),
815
+ attn_implementation=getattr(model_args, 'attn_implementation', 'flash_attention_2'),
816
+ torch_dtype=getattr(model_args, 'torch_dtype', 'bfloat16'),
817
+ callbacks=[custom_save_callback],
818
+ )
819
+
820
+ # Set the trainer to save in PyTorch format instead of safetensors
821
+ training_args.save_safetensors = False
822
+
823
+ # Train the model
824
+ trainer.train()
825
+
826
+ if __name__ == "__main__":
827
+ print(f"CUDA_VISIBLE_DEVICES: {os.environ.get('CUDA_VISIBLE_DEVICES')}")
828
+ parser = TrlParser((GRPOScriptArguments, DNALLMGRPOConfig, Blip2ModelConfig))
829
+ script_args, training_args, model_args = parser.parse_args_and_config()
830
+
831
+ # Ensure we use PyTorch's save mechanism instead of safetensors
832
+ training_args.save_safetensors = False
833
+
834
+ main(script_args, training_args, model_args)
835
+
836
+ # 使用示例:
837
+ """
838
+ 使用你的蛋白质数据进行训练:
839
+
840
+ 1. 准备CSV文件,格式:name,aa_seq,label,location,unique_id,pdb_hash
841
+
842
+ 2. 运行训练:
843
+ python blip2_reason.py \
844
+ --data_file_paths /path/to/your/protein_data.csv \
845
+ --reward_funcs combined \
846
+ --format_weight 0.2 \
847
+ --accuracy_weight 0.6 \
848
+ --repetition_weight 0.2 \
849
+ --use_custom_prompts \
850
+ --prompt_template classification \
851
+ --max_seq_length 1000 \
852
+ --output_dir ./output \
853
+ --per_device_train_batch_size 4 \
854
+ --num_train_epochs 3 \
855
+ --learning_rate 1e-5
856
+
857
+ 3. 或者使用分离的奖励函数:
858
+ python blip2_reason.py \
859
+ --data_file_paths /path/to/your/protein_data.csv \
860
+ --reward_funcs format_correct classification_specific repetition_penalty \
861
+ --use_custom_prompts \
862
+ --prompt_template function_prediction
863
+
864
+ 数据格式示例:
865
+ P0DM40,MLRVVVESASINPPLSTTPKAFVTVYFRDMMKRTRVEEGHDPIWNETLIWHLWNQPLENDSFLKVILQDSVSKKKERFIGLATVPLKRLAQRPKEVMFVRDLILLNHSMKPTNCTVTLHVAQIYDQDTEMTGNEELLGSTVNEVTQKKLMVSGLPMHRALASKPQHFQVRVKVFEARQLLGNNIKPVVKVNIADQQHLTRIKMGNNPFFNEIFFQNFHEVPAKFFEENISIEVVDSAASRSKAEIGRFQTDIGFIYHSPGHTLLRKWLGLCQRNKTTSGVRGYLKVTICALGVGDQALVDQKLPYEQNTRVQIFKSKEVPVSLAYLQFFIYCAEDLHFGTHKSATPVLEVELIGDKLRTKPQNPSDNPIWNQILTFQIQLPCLSSYIKFRVMDCSKYKCQDEIGSASLCLSQISSTGEEIQGMYSGFLPCFGPSFLTLRGGKKPPFRTSEEGTCIMDAVQHGLAYRGRIFVEIVTKIKSQQDSVMKDLSQEVTQVEMQYYRQKYGLCVIFLSCTMMPKFKDLIQFEVSMGHYGNKTDPNYKPLVSTTQYSPVIYDGTTYHYVPWYNTKPVVAVTSNWEDVGFRMNCLNLLHITRDRLKTNLDILKSIRNPRDPALLQQWEKLLKELQEDCRRPLPCMTDQPRANSLDRNKWQLRSQLLQQLAQMAKEAKPVNMVGTAKEWLHRLNAVIPEPQESLPDVLIWLMSRQQRVAYARVPAHTVLFSPAGPLSSGKFCGKIQNILLQYPEGEGQDTFPASLRVCMWLGNVKYSKNLKLLQQGSMVVYAETYENQAKTRDDWGQQGLYHCPNFSDVMGRKALPKTDFKAPPGWHWKDDWVVEPQRRLLLDIDINKSQVLEEVYENQLRNATGAWVPAAIPNTDVNGQPVEALENVKCPQGWHFKKNWIVKLNHAVDSEGWEYGVGIPPSGLPQIWNSVEKTYHSCRRRRWVRVRFRNHKELGQERSQEQETLSFLQMQDLSEEGKEGWEYGTFDSRFHLDPQPTSRFRRRCWHRQLAPNKDRGVASIFLLEGSLAVEQKDQPRKEMEKTRSWQPWKDLRHTPEDPRIPTTPFIYYILNKPHYYQLFCYIYQARNLMYNQILTFQEPFIQVVFLNHSLCTQTLRSSAAPTWSQSIIFQHLLLFEDPKDTRENPPLVVLELWQHDSRGNKILWGRSMWPPVVWLGLQDWVFTPLRWHPLVRELGEEEGEILASCELILETQKLKELHPPILSIPCKDGIYLLPKNIQPTMKMMAIEIMAWGLRNMTKVRYPQLLLECGGESLKTEPISNFQENPNFPTSTFFFTVFMPLEETHAQPLVVKVVDNQEYGQQIVVGQANIDFLQPYFCDPWSLNYTTVKLPTLSVKKPDTFLDFVYKKFWFDSSKDEEVYEEEVDWWSKLFWATGDADKSLNYNHKSYHTLKVYDCELEAVLTFKGLQDFCQTFKLYQEKPKVDSPVVGEFKGLFRIYPFPEDPEAPKPPRQFSAWPEIEDFPQMCLVRVYLIRAINLQPQDYNGLCDPYVILKLGQTKLGSRDSYYPNTLDPIFGMMYELTCNIPLEKDLEIQLFDFDLITADDEIGSTVIDLENRLLSGFGARCGLSKSYCKSGPFKWRDQMTPSYLLYRYAKQKGLPPPVFDLEGDSLYYNGETFKLQSFESAPPTYKHLGPKKERLALYILNTQGLVPEHVETRTLHSNSQPGIDQGKIQMWVDIFPKMLGPPGPQVNISPRKPKRYQLRCIIWSTAEVDLVQETFSKEKMSDIYVKGWLFGLEEDTQKTDVHYHSLTGEATFNWRFIFTMDYLTTERACVQSQKDYIWSLDPTSTKFPARLMIQIWDNDFFSPDDFLGVLELDLSDMPLPAQNIKQCSLKMMETDSKWPFTPQKRISLFKKTNVTGWWPCQVLDGDKWRLSGKVKMTLEMLSEREALIRPAGRGQSEPNQFPMLHPPERNDSFLLWYQSPIKNFCYAVCKRYRSKIICLVVTLVIGFILLNFVYSAPSYFAMNWIKPQLRLSSPIKIVNLIGTVNTSNINSSILTMEGSTYHASHVFPEAPAP,0,M,af67d99c09f74ea8af5004cc2906bbc5,d55cbc3d94bd9668d97a678b4a04176a
866
+ """
BioReason-0813/model/__pycache__/blip2.cpython-310.pyc ADDED
Binary file (3.17 kB). View file
 
BioReason-0813/model/__pycache__/blip2_opt.cpython-310.pyc ADDED
Binary file (9.75 kB). View file
 
BioReason-0813/model/__pycache__/blip2_opt.cpython-311.pyc ADDED
Binary file (18.1 kB). View file
 
BioReason-0813/model/__pycache__/blip2_stage2.cpython-310.pyc ADDED
Binary file (14.1 kB). View file
 
BioReason-0813/model/__pycache__/blip2_stage2.cpython-311.pyc ADDED
Binary file (28.2 kB). View file
 
BioReason-0813/model/__pycache__/help_funcs.cpython-310.pyc ADDED
Binary file (3.97 kB). View file
 
BioReason-0813/model/blip2.py ADDED
@@ -0,0 +1,126 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Copyright (c) 2023, salesforce.com, inc.
3
+ All rights reserved.
4
+ SPDX-License-Identifier: BSD-3-Clause
5
+ For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ """
7
+ import torch
8
+ import torch.nn as nn
9
+
10
+ from lavis.models.base_model import BaseModel
11
+ from lavis.models.blip2_models.Qformer import BertConfig, BertLMHeadModel
12
+ from transformers import BertTokenizer, BitsAndBytesConfig
13
+ from transformers import EsmTokenizer, EsmModel
14
+ import os
15
+ from pathlib import Path # 添加到文件顶部
16
+
17
+
18
+ def get_gpu_memory(device=0):
19
+ # t = torch.cuda.get_device_properties(device).total_memory
20
+ # r = torch.cuda.memory_reserved(device)
21
+ # a = torch.cuda.memory_allocated(device)
22
+ # f = r-a # free inside reserved
23
+ free, total = torch.cuda.mem_get_info(device)
24
+ free = free / (1024 ** 3)
25
+ total = total / (1024 ** 3)
26
+ return free, total-free, total
27
+
28
+
29
+ class Blip2Base(BaseModel):
30
+ # @classmethod
31
+ # def init_tokenizer(cls):
32
+ # tokenizer = BertTokenizer.from_pretrained('./bert_pretrained/')
33
+ # tokenizer.add_special_tokens({"bos_token": "[DEC]"})
34
+ # return tokenizer
35
+
36
+ @classmethod
37
+ def init_Qformer(cls, model_name, num_query_token, plm_width, cross_attention_freq=2):
38
+ # assert model_name == 'microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract'
39
+ # print("bert load microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext")
40
+
41
+ print(f"Loading Qformer from: {model_name}")
42
+
43
+ # 修改2:添加本地路径检查逻辑
44
+ if not model_name.startswith('microsoft/') and Path(model_name).exists():
45
+ print("Loading from local path...")
46
+ else:
47
+ print("Loading from Hugging Face Hub...")
48
+
49
+ encoder_config = BertConfig.from_pretrained(model_name)
50
+ encoder_config.encoder_width = plm_width
51
+ # insert cross-attention layer every other block
52
+ encoder_config.add_cross_attention = True
53
+ encoder_config.cross_attention_freq = cross_attention_freq
54
+ encoder_config.query_length = num_query_token
55
+
56
+ Qformer = BertLMHeadModel.from_pretrained(model_name, config=encoder_config)
57
+ query_tokens = nn.Parameter(
58
+ torch.zeros(1, num_query_token, encoder_config.hidden_size)
59
+ )
60
+ query_tokens.data.normal_(mean=0.0, std=encoder_config.initializer_range)
61
+
62
+ tokenizer = BertTokenizer.from_pretrained(model_name)
63
+ tokenizer.add_special_tokens({"bos_token": "[DEC]"})
64
+ return tokenizer, Qformer, query_tokens
65
+
66
+
67
+ def init_protein_encoder(self, plm_name, load_4bit=False):
68
+ # assert plm_name.startswith('facebook/esm2')
69
+ # plm_tokenizer = EsmTokenizer.from_pretrained(plm_name)
70
+ # 检查是否为本地路径(判断是否存在文件夹或文件)
71
+ if os.path.isdir(plm_name) or os.path.exists(os.path.join(plm_name, "config.json")):
72
+ print(f"Loading local PLM from {plm_name}")
73
+ plm_tokenizer = EsmTokenizer.from_pretrained(plm_name)
74
+ else:
75
+ # 保留远程加载逻辑(可选)
76
+ print(f"Loading remote PLM from {plm_name}")
77
+ plm_tokenizer = EsmTokenizer.from_pretrained(plm_name)
78
+
79
+ if not load_4bit:
80
+ plm = EsmModel.from_pretrained(plm_name, add_pooling_layer=False, torch_dtype=torch.bfloat16)
81
+ else:
82
+ quant_config = BitsAndBytesConfig(
83
+ load_in_4bit=True,
84
+ load_in_8bit=False,
85
+ llm_int8_threshold=6.0,
86
+ llm_int8_has_fp16_weight=False,
87
+ bnb_4bit_compute_dtype=torch.bfloat16,
88
+ bnb_4bit_use_double_quant=True,
89
+ bnb_4bit_quant_type='nf4',
90
+ )
91
+ ## give a device map that assign all layers to device 0
92
+ outputs = get_gpu_memory(6)
93
+ used_memory = outputs[1]
94
+ if used_memory > 1:
95
+ device_map = {"": 7}
96
+ else:
97
+ device_map = {"": 6}
98
+ plm = EsmModel.from_pretrained(
99
+ plm_name,
100
+ add_pooling_layer=False,
101
+ quantization_config=quant_config,
102
+ load_in_4bit=True,
103
+ load_in_8bit=False,
104
+ device_map=device_map,
105
+ torch_dtype=torch.bfloat16,
106
+ )
107
+
108
+ plm.num_features = plm.config.hidden_size
109
+ ln_layer = nn.LayerNorm(plm.num_features)
110
+ return plm_tokenizer, plm, ln_layer
111
+
112
+
113
+ def disabled_train(self, mode=True):
114
+ """Overwrite model.train with this function to make sure train/eval mode
115
+ does not change anymore."""
116
+ return self
117
+
118
+
119
+ # class LayerNorm(nn.LayerNorm):
120
+ # """Subclass torch's LayerNorm to handle fp16."""
121
+
122
+ # def forward(self, x: torch.Tensor):
123
+ # orig_type = x.dtype
124
+ # ret = super().forward(x.type(torch.float32))
125
+ # return ret.type(orig_type)
126
+
BioReason-0813/model/blip2_opt.py ADDED
@@ -0,0 +1,550 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Copyright (c) 2023, salesforce.com, inc.
3
+ All rights reserved.
4
+ SPDX-License-Identifier: BSD-3-Clause
5
+ For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ """
7
+ import logging
8
+ import torch
9
+ import torch.nn as nn
10
+ from torch.cuda.amp import autocast as autocast
11
+ # from peft import get_peft_config, get_peft_model, get_peft_model_state_dict, LoraConfig, TaskType, PeftModel
12
+ from lavis.models.blip2_models.blip2 import disabled_train
13
+ from model.blip2 import Blip2Base
14
+ from transformers import AutoTokenizer
15
+ from transformers import OPTForCausalLM
16
+ from transformers import AutoTokenizer, AutoModelForCausalLM
17
+ from opendelta import LoraModel
18
+ from opendelta.delta_models.lora import LoraConfig as DeltaLoraConfig
19
+ from transformers import BertTokenizer, BitsAndBytesConfig
20
+ from model.help_funcs import hf_enable_gradient_checkpointing
21
+ import json
22
+ # from accelerate import Accelerator
23
+ # import torch.distributed as dist
24
+
25
+ # from peft.tuners.lora import LoraLayer
26
+ # from peft import (
27
+ # prepare_model_for_kbit_training,
28
+ # LoraConfig as PeftLoraConfig,
29
+ # get_peft_model,
30
+ # PeftModel
31
+ # )
32
+
33
+ # from opendelta.delta_configs
34
+
35
+ opt_model_list = [
36
+ "facebook/galactica-125m",
37
+ "facebook/galactica-1.3b",
38
+ "facebook/galactica-6.7b",
39
+ "facebook/galactica-30b",
40
+ ]
41
+
42
+ def get_gpu_memory(device=0):
43
+ # t = torch.cuda.get_device_properties(device).total_memory
44
+ # r = torch.cuda.memory_reserved(device)
45
+ # a = torch.cuda.memory_allocated(device)
46
+ # f = r-a # free inside reserved
47
+ free, total = torch.cuda.mem_get_info(device)
48
+ free = free / (1024 ** 3)
49
+ total = total / (1024 ** 3)
50
+ return free, total-free, total
51
+
52
+ def mask_by_len(input, lens, fill_value=0):
53
+ '''
54
+ input: shape = [N, D]
55
+ lens: shape = [N]
56
+ '''
57
+ mask = torch.arange(input.shape[1], device=input.device).reshape(1, -1)
58
+ mask = mask < lens.reshape(-1, 1)
59
+ input[mask] = fill_value
60
+ return input
61
+
62
+
63
+
64
+ class Blip2OPT(Blip2Base):
65
+ """
66
+ BLIP2 first-stage model with Q-former and ViT.
67
+ Supported model types:
68
+ - pretrained: pretrained model with vit-g
69
+ - pretrain_vitL: pretrained model with vit-large
70
+ - coco: fintuned model on coco
71
+ Usage:
72
+ >>> from lavis.models import load_model
73
+ >>> model = load_model("blip2", "pretrain")
74
+ """
75
+ def __init__(
76
+ self,
77
+ bert_name,
78
+ num_query_token=32,
79
+ cross_attention_freq=2,
80
+ plm_model="facebook/esm2_t30_150M_UR50D",
81
+ plm_tune='freeze',
82
+ llm_name="facebook/galactica-1.3b",
83
+ llm_tune='freeze',
84
+ qformer_tune='train',
85
+ peft_dir='',
86
+ args=None,
87
+ ):
88
+ super().__init__()
89
+ self.args = args
90
+ self.enbale_gradient_checkpointing = args.enbale_gradient_checkpointing
91
+
92
+ self.plm_tokenizer, self.plm, self.ln_layer = self.init_protein_encoder(plm_model)
93
+ self.plm_tune = plm_tune
94
+ if plm_tune == 'freeze':
95
+ for name, param in self.plm.named_parameters():
96
+ param.requires_grad = False
97
+ self.plm = self.plm.eval()
98
+ self.plm.train = disabled_train
99
+ logging.info("freeze plm encoder")
100
+ elif plm_tune == 'lora':
101
+ lora_config = DeltaLoraConfig(args.lora_r,
102
+ args.lora_alpha,
103
+ args.lora_dropout,
104
+ modified_modules=["query", "value"])
105
+ self.delta = LoraModel.from_config(lora_config, self.plm)
106
+ self.delta.freeze_module(set_state_dict=False)
107
+ self.delta.log()
108
+ else:
109
+ raise NotImplementedError()
110
+
111
+ self.num_query_token = num_query_token
112
+ self.qformer_tokenizer, self.Qformer, self.query_tokens = self.init_Qformer(bert_name, num_query_token, self.plm.num_features, cross_attention_freq)
113
+ ### remove the unused parameters
114
+ self.Qformer.cls = None
115
+ self.Qformer.bert.embeddings.word_embeddings = None
116
+ self.Qformer.bert.embeddings.position_embeddings = None
117
+ for layer in self.Qformer.bert.encoder.layer:
118
+ layer.output = None
119
+ layer.intermediate = None
120
+
121
+ # === 3. 控制 Qformer 是否冻结 ===
122
+ self.qformer_tune = qformer_tune
123
+ if self.qformer_tune == 'freeze':
124
+ for name, param in self.Qformer.named_parameters():
125
+ param.requires_grad = False
126
+ self.Qformer = self.Qformer.eval()
127
+ self.Qformer.train = disabled_train
128
+ logging.info("freeze Qformer encoder")
129
+ elif self.qformer_tune == 'train':
130
+ logging.info("train Qformer encoder")
131
+ else:
132
+ raise NotImplementedError(f"Unsupported qformer_tune mode: {self.qformer_tune}")
133
+
134
+ ## initialize llm model
135
+ # self.init_distributed()
136
+ self.llm_model, self.llm_tokenizer = self.load_llm(llm_name)
137
+
138
+ #self.llm_model, self.llm_tokenizer = self.load_model_on_single_gpu(llm_name)
139
+ self.eos_token_id = self.llm_tokenizer.eos_token_id
140
+ self.pad_token_id = self.llm_tokenizer.pad_token_id
141
+
142
+ if llm_tune == 'freeze':
143
+ for name, param in self.llm_model.named_parameters():
144
+ param.requires_grad = False
145
+ elif llm_tune == 'full':
146
+ for name, param in self.llm_model.named_parameters():
147
+ param.requires_grad = True
148
+ elif llm_tune == 'lora':
149
+ lora_config = DeltaLoraConfig(args.lora_r,
150
+ args.lora_alpha,
151
+ args.lora_dropout,)
152
+ self.delta = LoraModel.from_config(lora_config, self.llm_model)
153
+ self.delta.freeze_module(set_state_dict=False)
154
+ self.delta.log()
155
+ elif llm_tune == 'mid_lora':
156
+ print("================")
157
+ print("加载了小lora")
158
+ print("=================")
159
+ lora_config = DeltaLoraConfig(args.lora_r, args.lora_alpha, args.lora_dropout, modified_modules=["q_proj", "v_proj", 'k_proj', "out_proj", "fc1", "fc2"])
160
+ self.delta = LoraModel.from_config(lora_config, self.llm_model)
161
+ self.delta.freeze_module(set_state_dict=False)
162
+ self.delta.log()
163
+ elif llm_tune == 'peft_lora':
164
+ config = PeftLoraConfig(
165
+ r=args.lora_r,
166
+ lora_alpha=args.lora_alpha,
167
+ # target_modules=modules,
168
+ lora_dropout=args.lora_dropout,
169
+ bias="none",
170
+ task_type="CAUSAL_LM",
171
+ )
172
+ self.llm_model = get_peft_model(self.llm_model, config)
173
+ for name, module in self.llm_model.named_modules():
174
+ if isinstance(module, LoraLayer):
175
+ if True:
176
+ module = module.to(torch.bfloat16)
177
+ if 'norm' in name:
178
+ module = module.to(torch.float32)
179
+ if 'lm_head' in name or 'embed_tokens' in name:
180
+ if hasattr(module, 'weight'):
181
+ if True and module.weight.dtype == torch.float32:
182
+ module = module.to(torch.bfloat16)
183
+ else:
184
+ raise NotImplementedError()
185
+
186
+ ## fixme: this is different from the original BLIP2
187
+ # self.eos_token_id = self.llm_tokenizer(
188
+ # "\n", add_special_tokens=False
189
+ # ).input_ids[0]
190
+ self.opt_proj = nn.Linear(self.Qformer.config.hidden_size, self.llm_model.config.hidden_size)
191
+
192
+ def load_llm(self, llm_model, load_4bit=False, enable_gradient_checkpointing=True):
193
+ llm_tokenizer = AutoTokenizer.from_pretrained(llm_model, use_fast=False, padding_side='right')
194
+ llm_tokenizer.add_special_tokens({'pad_token': '<pad>'})
195
+
196
+ special_tokens_dict = {'additional_special_tokens': ['<PROT>', '<TEXT>']}
197
+ llm_tokenizer.add_special_tokens(special_tokens_dict)
198
+
199
+ llm_model = AutoModelForCausalLM.from_pretrained(llm_model, torch_dtype=torch.bfloat16)
200
+ llm_model.resize_token_embeddings(len(llm_tokenizer)) ## this will cause bug when
201
+
202
+ return llm_model, llm_tokenizer
203
+
204
+
205
+ # def forward(self, batch):
206
+ # prot_batch, text_batch = batch
207
+ # prot_embeds = self.plm(**prot_batch, return_dict=True)
208
+ # prot_embeds = prot_embeds.last_hidden_state
209
+ # if self.plm_tune == 'freeze':
210
+ # prot_embeds = prot_embeds.detach()
211
+ # prot_embeds = self.ln_layer(prot_embeds)
212
+ # device = prot_embeds.device
213
+ # query_tokens = self.query_tokens.expand(prot_embeds.shape[0], -1, -1)
214
+ # query_output = self.Qformer.bert(
215
+ # query_embeds=query_tokens,
216
+ # encoder_hidden_states=prot_embeds,
217
+ # encoder_attention_mask=prot_batch.attention_mask,
218
+ # return_dict=True,
219
+ # )
220
+ # prot_tokens = self.opt_proj(query_output.last_hidden_state)
221
+ # prot_mask = torch.ones(prot_tokens.shape[:2], dtype=text_batch.attention_mask.dtype, device=device)
222
+ # prot_empty_targets = torch.ones(prot_tokens.shape[:2], dtype=torch.long, device=device).fill_(-100)
223
+
224
+ # targets = text_batch.input_ids.masked_fill(text_batch.input_ids == self.llm_tokenizer.pad_token_id, -100)
225
+ # targets = targets.masked_fill(text_batch.token_type_ids == 0, -100)
226
+ # targets = torch.cat([prot_empty_targets, targets], dim=1)
227
+
228
+ # inputs_embeds = self.llm_model.get_input_embeddings()(text_batch.input_ids)
229
+ # inputs_embeds = torch.cat((prot_tokens, inputs_embeds), dim=1)
230
+ # attention_mask = torch.cat([prot_mask, text_batch.attention_mask], dim=1)
231
+
232
+ # outputs = self.llm_model(
233
+ # inputs_embeds=inputs_embeds,
234
+ # attention_mask=attention_mask,
235
+ # return_dict=True,
236
+ # labels=targets,
237
+ # )
238
+ # loss = outputs.loss
239
+ # return loss
240
+
241
+ def forward(self, batch):
242
+ prot_batch, prompt_batch, text_dict = batch
243
+ text_seqs = text_dict['targets']
244
+ batch_size = prompt_batch['input_ids'].size(0)
245
+ # print("{{{{{}}}}}")
246
+ # print(batch_size)
247
+
248
+ prot_embeds = self.plm(**prot_batch, return_dict=True)
249
+ prot_embeds = prot_embeds.last_hidden_state
250
+ if self.plm_tune == 'freeze':
251
+ prot_embeds = prot_embeds.detach()
252
+ prot_embeds = self.ln_layer(prot_embeds)
253
+ device = prot_embeds.device
254
+ query_tokens = self.query_tokens.expand(prot_embeds.shape[0], -1, -1)
255
+ query_output = self.Qformer.bert(
256
+ query_embeds=query_tokens,
257
+ encoder_hidden_states=prot_embeds,
258
+ encoder_attention_mask=prot_batch.attention_mask,
259
+ return_dict=True,
260
+ )
261
+ prot_tokens = self.opt_proj(query_output.last_hidden_state)
262
+ prot_mask = torch.ones(prot_tokens.shape[:2], dtype=torch.long, device=device)
263
+
264
+ # === Step 3: 编码 prompt 输入 ===
265
+ prompt_embeds = self.llm_model.get_input_embeddings()(prompt_batch.input_ids) # [B, L_prompt, D_llm]
266
+ prompt_mask = prompt_batch['attention_mask']
267
+
268
+
269
+ text_batch = self.llm_tokenizer(
270
+ list(text_seqs),
271
+ padding='longest',
272
+ truncation=True,
273
+ max_length=1024,
274
+ return_tensors='pt'
275
+ ).to(device)
276
+ target_embeds = self.llm_model.get_input_embeddings()(text_batch['input_ids']) # [B, T, D]
277
+ target_mask = text_batch['attention_mask']
278
+ targets = text_batch['input_ids'].masked_fill(text_batch['input_ids'] == self.llm_tokenizer.pad_token_id, -100)
279
+
280
+ # === : 加入 ChatML 特殊 token embedding ===
281
+ embedding_layer = self.llm_model.get_input_embeddings()
282
+
283
+ def embed_special_str(token_str):
284
+ # 先 tokenize,得到一系列 ID
285
+ ids = self.llm_tokenizer(token_str, add_special_tokens=False).input_ids
286
+ # 把它变成 [1, N] tensor
287
+ ids_tensor = torch.tensor([ids], device=device)
288
+ # 查 embedding 层:
289
+ embs = embedding_layer(ids_tensor) # shape [1, N, D]
290
+ # Expand 到 batch 大小
291
+ return embs.expand(batch_size, -1, -1)
292
+
293
+ # 示例
294
+ embed_im_start = embed_special_str("<|im_start|>user\n protein sequence is:<protein>") # 可能对应多个 sub-tokens
295
+ embed_protein = embed_special_str("</protein>")
296
+ embed_im_end = embed_special_str("<|im_end|>\n")
297
+ embed_assistant= embed_special_str("<|im_start|>assistant\n")
298
+
299
+
300
+ user_embeds = torch.cat([embed_im_start, prot_tokens , embed_protein, prompt_embeds,embed_im_end, embed_assistant], dim=1)
301
+ user_mask = torch.ones(user_embeds.shape[:2], dtype=torch.long, device=device)
302
+
303
+ assistant_embeds = target_embeds
304
+ assistant_mask = target_mask
305
+
306
+ inputs_embeds = torch.cat([user_embeds, assistant_embeds], dim=1)
307
+ attention_mask = torch.cat([user_mask, assistant_mask], dim=1)
308
+
309
+ # === Step 6: 构造 labels,只监督 assistant 部分 ===
310
+ ignore_labels = torch.full(user_embeds.shape[:2], -100, dtype=torch.long, device=device)
311
+ assistant_labels = targets
312
+ labels = torch.cat([ignore_labels, assistant_labels], dim=1)
313
+
314
+ # print("embed_im_start:", embed_im_start.shape)
315
+ # print("prompt_embeds:", prompt_embeds.shape)
316
+ # print("prot_tokens:", prot_tokens.shape)
317
+ # print("embed_im_end:", embed_im_end.shape)
318
+ # print("embed_assistant:", embed_assistant.shape)
319
+ # print("target_embeds:", target_embeds.shape)
320
+ # print("labels:", labels.shape)
321
+ # print("inputs_embeds:", inputs_embeds.shape)
322
+
323
+ #============================
324
+
325
+ # inputs_embeds = torch.cat([prot_tokens, prompt_embeds, target_embeds], dim=1)
326
+ # attention_mask = torch.cat([prot_mask, prompt_mask, target_mask], dim=1)
327
+
328
+ # # === Step 7: 构造 labels,只监督 target 部分 ===
329
+ # prot_label_pad = torch.full(prot_tokens.shape[:2], -100, dtype=torch.long, device=device)
330
+ # prompt_label_pad = torch.full(prompt_mask.shape, -100, dtype=torch.long, device=device)
331
+ # labels = torch.cat([prot_label_pad, prompt_label_pad, targets], dim=1)
332
+
333
+ # === Step 8: 送入 LLM ===
334
+ outputs = self.llm_model(
335
+ inputs_embeds=inputs_embeds,
336
+ attention_mask=attention_mask,
337
+ labels=labels,
338
+ return_dict=True,
339
+ )
340
+ loss = outputs.loss
341
+ # prot_mask = torch.ones(prot_tokens.shape[:2], dtype=text_batch.attention_mask.dtype, device=device)
342
+ # prot_empty_targets = torch.ones(prot_tokens.shape[:2], dtype=torch.long, device=device).fill_(-100)
343
+ # empty_targets = torch.ones(prompt_batch.attention_mask.shape, dtype=torch.long, device=device).fill_(-100)
344
+ # targets = text_batch.input_ids.masked_fill(text_batch.input_ids == self.llm_tokenizer.pad_token_id, -100)
345
+ # targets = torch.cat([prot_empty_targets, empty_targets, targets], dim=1)
346
+
347
+ # prompt_embeds = self.llm_model.get_input_embeddings()(prompt_batch.input_ids)
348
+ # inputs_embeds = self.llm_model.get_input_embeddings()(text_batch.input_ids)
349
+ # inputs_embeds = torch.cat((prot_tokens, prompt_embeds, inputs_embeds), dim=1)
350
+ # attention_mask = torch.cat([prot_mask, prompt_batch.attention_mask, text_batch.attention_mask], dim=1)
351
+
352
+ # outputs = self.llm_model(
353
+ # inputs_embeds=inputs_embeds,
354
+ # attention_mask=attention_mask,
355
+ # return_dict=True,
356
+ # labels=targets,
357
+ # )
358
+ # loss = outputs.loss
359
+ return loss
360
+
361
+ # def forwardv2(self, batch):
362
+ # prot_batch, prompt_batch, text_batch = batch
363
+ # prot_embeds = self.plm(**prot_batch, return_dict=True)
364
+ # prot_embeds = prot_embeds.last_hidden_state
365
+ # if self.plm_tune == 'freeze':
366
+ # prot_embeds = prot_embeds.detach()
367
+ # prot_embeds = self.ln_layer(prot_embeds)
368
+ # device = prot_embeds.device
369
+ # query_tokens = self.query_tokens.expand(prot_embeds.shape[0], -1, -1)
370
+ # query_output = self.Qformer.bert(
371
+ # query_embeds=query_tokens,
372
+ # encoder_hidden_states=prot_embeds,
373
+ # encoder_attention_mask=prot_batch.attention_mask,
374
+ # return_dict=True,
375
+ # )
376
+ # prot_tokens = self.opt_proj(query_output.last_hidden_state)
377
+ # prot_mask = torch.ones(prot_tokens.shape[:2], dtype=text_batch.attention_mask.dtype, device=device)
378
+ # targets = text_batch.input_ids.masked_fill(text_batch.input_ids == self.llm_tokenizer.pad_token_id, -100)
379
+
380
+ # ### forward prefix
381
+ # prompt_embeds = self.llm_model.get_input_embeddings()(prompt_batch.input_ids)
382
+ # prefix_embeds = torch.cat([prot_tokens, prompt_embeds], dim=1)
383
+ # prefix_mask = torch.cat([prot_mask, prompt_batch.attention_mask], dim=1)
384
+ # prefix_output = self.llm_model.model(
385
+ # inputs_embeds=prefix_embeds,
386
+ # attention_mask=prefix_mask,
387
+ # use_cache=True,
388
+ # return_dict=True,
389
+ # )
390
+
391
+ # ## forward decoding
392
+ # if False:
393
+ # attention_mask = torch.cat([prot_mask, prompt_batch.attention_mask, text_batch.attention_mask], dim=1)
394
+ # else:
395
+ # attention_mask = text_batch.attention_mask
396
+ # print(prefix_output.past_key_values)
397
+ # outputs = self.llm_model(
398
+ # input_ids=text_batch.input_ids,
399
+ # attention_mask=attention_mask,
400
+ # past_key_values=prefix_output.past_key_values,
401
+ # return_dict=True,
402
+ # labels=targets,
403
+ # )
404
+ # loss = outputs.loss
405
+ # return loss
406
+
407
+ @torch.no_grad()
408
+ def generate(
409
+ self,
410
+ samples,
411
+ do_sample=False,
412
+ num_beams=5,
413
+ max_length=128,
414
+ min_length=1,
415
+ top_p=0.9,
416
+ repetition_penalty=1.0,
417
+ length_penalty=1.0,
418
+ num_captions=1,
419
+ temperature=1,
420
+ ):
421
+ """
422
+ Args:
423
+ samples (dict): A dictionary containing the following keys:
424
+ - image (torch.Tensor): A tensor of shape (batch_size, 3, H, W)
425
+ num_beams (int): Number of beams for beam search. 1 means no beam search.
426
+ max_length (int): The maximum length of the sequence to be generated.
427
+ min_length (int): The minimum length of the sequence to be generated.
428
+ top_p (float): The cumulative probability for nucleus sampling.
429
+ repetition_penalty (float): The parameter for repetition penalty. 1.0 means no penalty.
430
+ num_captions (int): Number of captions to be generated for each image.
431
+ Returns:
432
+ captions (list): A list of strings of length batch_size * num_captions.
433
+ """
434
+ # prot_batch = samples['prot_batch']
435
+ # prompt_batch = samples['prompt_batch']
436
+
437
+ # # with self.maybe_autocast():
438
+ # prot_embeds = self.plm(**prot_batch, return_dict=True)
439
+ # prot_embeds = self.ln_layer(prot_embeds.last_hidden_state)
440
+
441
+ # query_tokens = self.query_tokens.expand(prot_embeds.shape[0], -1, -1)
442
+ # query_output = self.Qformer.bert(
443
+ # query_embeds=query_tokens,
444
+ # encoder_hidden_states=prot_embeds,
445
+ # encoder_attention_mask=prot_batch['attention_mask'],
446
+ # return_dict=True,
447
+ # )
448
+ # prot_tokens = self.opt_proj(query_output.last_hidden_state)
449
+
450
+
451
+
452
+ # # prompt_batch = samples['prompt_batch']
453
+ # prompt_input_ids = prompt_batch['input_ids'] # shape: [B, L]
454
+ # # for i, ids in enumerate(prompt_input_ids):
455
+ # # print(f"Prompt {i} token length: {len(ids)}")
456
+ # decoded_texts = [self.llm_tokenizer.decode(ids, skip_special_tokens=True) for ids in prompt_input_ids]
457
+ # # print("=========")
458
+ # # print(decoded_texts)
459
+ # #print(decoded_texts)
460
+ # save_path = "decoded_prompts.json"
461
+
462
+ # # 将 list 写入 JSON 文件
463
+ # with open(save_path, 'w', encoding='utf-8') as f:
464
+ # json.dump(decoded_texts, f, ensure_ascii=False, indent=4)
465
+
466
+ # prompt_attention_mask = prompt_batch['attention_mask']
467
+ # prompt_embeds = self.llm_model.model.embed_tokens(prompt_input_ids)
468
+
469
+ # inputs_embeds = torch.cat((prompt_embeds, prot_tokens), dim=1)
470
+
471
+ # prot_attention_mask = torch.ones(prot_tokens.shape[:2], dtype=prompt_attention_mask.dtype, device=prompt_attention_mask.device)
472
+ # #attention_mask = torch.cat([prot_attention_mask, prompt_attention_mask], dim=1)
473
+ # attention_mask = torch.cat([ prompt_attention_mask,prot_attention_mask], dim=1)
474
+
475
+ #==========================
476
+ prot_batch = samples['prot_batch']
477
+ prompt_batch = samples['prompt_batch']
478
+
479
+
480
+ device = prompt_batch['input_ids'].device
481
+ batch_size = prompt_batch['input_ids'].size(0)
482
+
483
+ # === Step 1: 编码蛋白质 + QFormer ===
484
+ prot_embeds = self.plm(**prot_batch, return_dict=True).last_hidden_state
485
+ prot_embeds = self.ln_layer(prot_embeds)
486
+ query_tokens = self.query_tokens.expand(prot_embeds.shape[0], -1, -1)
487
+ query_output = self.Qformer.bert(
488
+ query_embeds=query_tokens,
489
+ encoder_hidden_states=prot_embeds,
490
+ encoder_attention_mask=prot_batch['attention_mask'],
491
+ return_dict=True,
492
+ )
493
+ prot_tokens = self.opt_proj(query_output.last_hidden_state) # [B, L_qformer, D]
494
+
495
+ # === Step 2: 编码 prompt 输入 ===
496
+ prompt_input_ids = prompt_batch['input_ids']
497
+ prompt_attention_mask = prompt_batch['attention_mask']
498
+ prompt_embeds = self.llm_model.get_input_embeddings()(prompt_input_ids)
499
+
500
+ # === Step 3: 获取 ChatML 特殊 token 的 embedding ===
501
+ embedding_layer = self.llm_model.get_input_embeddings()
502
+
503
+ def embed_special_str(token_str):
504
+ # 先 tokenize,得到一系列 ID
505
+ ids = self.llm_tokenizer(token_str, add_special_tokens=False).input_ids
506
+ # 把它变成 [1, N] tensor
507
+ ids_tensor = torch.tensor([ids], device=device)
508
+ # 查 embedding 层:
509
+ embs = embedding_layer(ids_tensor) # shape [1, N, D]
510
+ # Expand 到 batch 大小
511
+ return embs.expand(batch_size, -1, -1)
512
+
513
+ # 示例
514
+ embed_im_start = embed_special_str("<|im_start|>user\nprotein sequence is: <protein>") # 可能对应多个 sub-tokens
515
+ embed_protein = embed_special_str("</protein>")
516
+ embed_im_end = embed_special_str("<|im_end|>\n")
517
+ embed_assistant= embed_special_str("<|im_start|>assistant\n")
518
+
519
+
520
+ # === Step 4: 拼接 Embeddings ===
521
+ user_embeds = torch.cat([embed_im_start, prot_tokens, embed_protein, prompt_embeds, embed_im_end], dim=1)
522
+ assistant_prefix = embed_assistant # 模型从这里开始生成
523
+ inputs_embeds = torch.cat([user_embeds, assistant_prefix], dim=1)
524
+
525
+ # === Step 5: attention_mask ===
526
+ user_mask = torch.ones(user_embeds.shape[:2], dtype=torch.long, device=device)
527
+ assistant_mask = torch.ones((batch_size, embed_assistant.size(1)), dtype=torch.long, device=device)
528
+ attention_mask = torch.cat([user_mask, assistant_mask], dim=1)
529
+
530
+ outputs = self.llm_model.generate(
531
+ inputs_embeds=inputs_embeds,
532
+ attention_mask=attention_mask,
533
+ do_sample=do_sample,
534
+ top_p=top_p,
535
+ temperature=temperature,
536
+ num_beams=num_beams,
537
+ max_new_tokens=max_length,
538
+ min_length=min_length,
539
+ # pad_token_id=self.pad_token_id,
540
+ eos_token_id=self.eos_token_id,
541
+ repetition_penalty=repetition_penalty,
542
+ length_penalty=length_penalty,
543
+ num_return_sequences=num_captions,
544
+ use_cache=True,
545
+ cache_implementation="hybrid"
546
+ )
547
+ output_text = self.llm_tokenizer.batch_decode(outputs, skip_special_tokens=True)
548
+ output_text = [text.strip() for text in output_text]
549
+ # print(output_text)
550
+ return output_text
BioReason-0813/model/blip2_stage2.py ADDED
@@ -0,0 +1,365 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ from model.blip2_opt import Blip2OPT
4
+ import pytorch_lightning as pl
5
+ from torch import optim
6
+ from lavis.common.optims import LinearWarmupCosineLRScheduler, LinearWarmupStepLRScheduler
7
+ import json
8
+ import torch.distributed as dist
9
+ # from peft import LoraConfig, TaskType
10
+ from typing import Any, Dict
11
+ from model.help_funcs import caption_evaluate, AttrDict
12
+ try:
13
+ from model.opt_flash_attention import replace_opt_attn_with_flash_attn, replace_opt_attn_with_original_attn
14
+ except ModuleNotFoundError:
15
+ pass
16
+
17
+
18
+ def get_module_state_dict(state_dict, module_name):
19
+ module_state_dict = {}
20
+ for key, value in state_dict.items():
21
+ if key.startswith(module_name):
22
+ key = key[len(module_name) + 1:]
23
+ if key == '':
24
+ return value
25
+ module_state_dict[key] = value
26
+ return module_state_dict
27
+
28
+ class Blip2Stage2(pl.LightningModule):
29
+ def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
30
+ # checkpoint.pop('optimizer_states')
31
+ to_be_removed = []
32
+ for key, value in checkpoint['state_dict'].items():
33
+ try:
34
+ if not self.get_parameter(key).requires_grad:
35
+ to_be_removed.append(key)
36
+ except AttributeError:
37
+ to_be_removed.append(key)
38
+ for key in to_be_removed:
39
+ checkpoint['state_dict'].pop(key)
40
+
41
+ def __init__(self, args):
42
+ super().__init__()
43
+ if isinstance(args, dict):
44
+ args = AttrDict(**args)
45
+
46
+ self.args = args
47
+ self.caption_eval_epoch = args.caption_eval_epoch
48
+ self.do_sample = args.do_sample
49
+ self.num_beams = args.num_beams
50
+ self.max_inference_len = args.max_inference_len
51
+ self.min_inference_len = args.min_inference_len
52
+ self.llm_tune = args.llm_tune
53
+ self.enable_flash = args.enable_flash
54
+ # if args.llm_name.find('galactica') >= 0:
55
+ self.blip2 = Blip2OPT(args.bert_name,
56
+ args.num_query_token,
57
+ args.cross_attention_freq,
58
+ args.plm_model,
59
+ args.plm_tune,
60
+ args.llm_name,
61
+ args.llm_tune,
62
+ args.qformer_tune,
63
+ args.peft_dir,
64
+ args)
65
+ # else:
66
+ # raise NotImplementedError()
67
+ self.save_hyperparameters(args)
68
+
69
+ def load_from_stage1_checkpoint(self, path):
70
+ ckpt = torch.load(path, map_location='cpu')
71
+ state_dict = ckpt['state_dict']
72
+ state_dict = {k.split('blip2qformer.')[1]:v for k, v in state_dict.items()}
73
+ self.blip2.load_state_dict(state_dict, strict=False)
74
+ return self
75
+
76
+ def configure_optimizers(self):
77
+ self.trainer.fit_loop.setup_data()
78
+ warmup_steps = min(len(self.trainer.train_dataloader), self.args.warmup_steps)
79
+ optimizer = optim.AdamW(self.parameters(), lr=self.args.init_lr, weight_decay=self.args.weight_decay)
80
+ if self.args.scheduler == 'linear_warmup_cosine_lr':
81
+ self.scheduler = LinearWarmupCosineLRScheduler(optimizer, self.args.max_epochs, self.args.min_lr, self.args.init_lr, warmup_steps, self.args.warmup_lr)
82
+ elif self.args.scheduler == 'linear_warmup_step_lr':
83
+ self.scheduler = LinearWarmupStepLRScheduler(optimizer, self.args.max_epochs, self.args.min_lr, self.args.init_lr, self.args.lr_decay_rate, self.args.warmup_lr, warmup_steps)
84
+ elif self.args.scheduler == 'None':
85
+ self.scheduler = None
86
+ else:
87
+ raise NotImplementedError()
88
+ return optimizer
89
+
90
+ def save_predictions(self, predictions, targets, q_types=None, log_prefix=''):
91
+ assert len(predictions) == len(targets)
92
+ if log_prefix:
93
+ name = f'{log_prefix}_predictions.txt'
94
+ else:
95
+ name = 'predictions.txt'
96
+ with open(os.path.join(self.logger.log_dir, name), 'w', encoding='utf8') as f:
97
+ if q_types is not None:
98
+ for p, t, q in zip(predictions, targets, q_types):
99
+ line = {'prediction': p, 'target': t, 'q_type': q}
100
+ f.write(json.dumps(line, ensure_ascii=True) + '\n')
101
+ else:
102
+ for p, t in zip(predictions, targets):
103
+ line = {'prediction': p, 'target': t}
104
+ f.write(json.dumps(line, ensure_ascii=True) + '\n')
105
+
106
+ def on_validation_epoch_start(self) -> None:
107
+ if self.enable_flash:
108
+ replace_opt_attn_with_original_attn()
109
+ self.saved_dict_list = []
110
+ self.prediction_list0 = []
111
+ self.target_list0 = []
112
+ self.prediction_list1 = []
113
+ self.target_list1 = []
114
+
115
+ @torch.no_grad()
116
+ def validation_step(self, batch, batch_idx, dataloader_idx=0):
117
+ prot_batch, prompt_batch, target_dict = batch
118
+ if (dataloader_idx % 2) == 0:
119
+ # text_batch = batch[-1]
120
+ # batch_size = text_batch.input_ids.shape[0]
121
+ batch_size = len(target_dict['targets']) # ✅ 正确获取batch大小
122
+ loss = self.blip2(batch)
123
+ ###============== Overall Loss ===================###
124
+ self.log(f"dataloader{dataloader_idx}/val loss", float(loss), batch_size=batch_size, sync_dist=True)
125
+ elif (dataloader_idx % 2) == 1:
126
+ if (self.current_epoch+1) % self.caption_eval_epoch != 0:
127
+ return
128
+ # prot_batch, prompt_batch, target_dict = batch
129
+ ###============== Captioning Results ===================###
130
+ samples = {'prot_batch': prot_batch, 'prompt_batch': prompt_batch}
131
+ predictions = self.blip2.generate(
132
+ samples,
133
+ do_sample=self.do_sample,
134
+ num_beams=self.num_beams,
135
+ max_length=self.max_inference_len,
136
+ min_length=self.min_inference_len
137
+ )
138
+ target_dict['predictions'] = predictions
139
+ self.saved_dict_list.append(target_dict)
140
+
141
+ def gather_dict_results(self, dict_list):
142
+ list_of_dict_list = [None for _ in range(self.trainer.world_size)]
143
+ dist.all_gather_object(list_of_dict_list, dict_list)
144
+ dict_list = [i for ii in list_of_dict_list for i in ii] ## dict list, each dict has values that are lists of predictions, etc.
145
+ keys = dict_list[0].keys()
146
+ gathered_dict = {} # each value is a list of predictions, etc.
147
+ for key in keys:
148
+ gathered_dict[key] = [i for d in dict_list for i in d[key]]
149
+ dict_list = []
150
+ for i in range(len(gathered_dict['predictions'])):
151
+ d = {k:gathered_dict[k][i] for k in keys}
152
+ dict_list.append(d)
153
+ return dict_list
154
+
155
+ def save_results(self, dict_list, log_prefix=""):
156
+ ## save the results
157
+ if log_prefix:
158
+ name = f'results/{log_prefix}_predictions.txt'
159
+ else:
160
+ name = 'predictions.txt'
161
+ with open(name, 'w', encoding='utf8') as f:
162
+ for d in dict_list:
163
+ f.write(json.dumps(d, ensure_ascii=True) + '\n')
164
+
165
+ def on_validation_epoch_end(self):
166
+ if self.enable_flash:
167
+ replace_opt_attn_with_flash_attn()
168
+ if (self.current_epoch+1) % self.caption_eval_epoch != 0:
169
+ return
170
+ result_list = self.gather_dict_results(self.saved_dict_list)
171
+ ## empty cache
172
+ self.saved_dict_list = []
173
+
174
+ if self.global_rank == 0:
175
+ # 假设 args.filename = 'stage2_continue_deeplocmulti_07241522'
176
+ filename_parts = self.args.filename.split('_')
177
+ # 获取最后两部分并组合
178
+ new_filename = '_'.join(filename_parts[-2:]) # 得到 'deeplocmulti_07241522'
179
+ self.save_results(result_list, new_filename)
180
+ all_predictions = [i['predictions'] for i in result_list]
181
+ all_targets = [i['targets'] for i in result_list]
182
+
183
+ log_prefix = 'dataset0' ## fixme: this is just a placeholder
184
+ if 'q_types' in result_list[0]:
185
+ ## evaluate protein qa
186
+ pass
187
+ else:
188
+ ## evaluate captioning
189
+ bleu2, bleu4, rouge_1, rouge_2, rouge_l, meteor_score = \
190
+ caption_evaluate(all_predictions, all_targets, self.blip2.llm_tokenizer, self.max_inference_len)
191
+ acc = evaluate_exact_match(all_predictions, all_targets)
192
+ self.log(f"{log_prefix}/acc", acc, sync_dist=False)
193
+ self.log(f"{log_prefix}/bleu2", bleu2, sync_dist=False)
194
+ self.log(f"{log_prefix}/bleu4", bleu4, sync_dist=False)
195
+ self.log(f"{log_prefix}/rouge_1", rouge_1, sync_dist=False)
196
+ self.log(f"{log_prefix}/rouge_2", rouge_2, sync_dist=False)
197
+ self.log(f"{log_prefix}/rouge_l", rouge_l, sync_dist=False)
198
+ self.log(f"{log_prefix}/meteor_score", meteor_score, sync_dist=False)
199
+
200
+ @torch.no_grad()
201
+ def validation_step_old(self, batch, batch_idx, dataloader_idx=0):
202
+ if (dataloader_idx % 2) == 0:
203
+ text_batch = batch[-1]
204
+ batch_size = text_batch.input_ids.shape[0]
205
+ loss = self.blip2(batch)
206
+ ###============== Overall Loss ===================###
207
+ self.log(f"dataloader{dataloader_idx}/val loss", float(loss), batch_size=batch_size, sync_dist=True)
208
+ elif (dataloader_idx % 2) == 1:
209
+ if (self.current_epoch+1) % self.caption_eval_epoch != 0:
210
+ return
211
+ prot_batch, prompt_batch, target_dict = batch
212
+ ###============== Captioning Results ===================###
213
+ samples = {'prot_batch': prot_batch, 'prompt_batch': prompt_batch}
214
+ predictions = self.blip2.generate(
215
+ samples,
216
+ do_sample=self.do_sample,
217
+ num_beams=self.num_beams,
218
+ max_length=self.max_inference_len,
219
+ min_length=self.min_inference_len
220
+ )
221
+ if dataloader_idx // 2 == 0:
222
+ self.prediction_list0.append(predictions)
223
+ self.target_list0.append(target_dict)
224
+ elif dataloader_idx // 2 == 1:
225
+ self.prediction_list1.append(predictions)
226
+ self.target_list1.append(target_dict)
227
+ else:
228
+ raise NotImplementedError
229
+ else:
230
+ raise NotImplementedError
231
+
232
+ def on_validation_epoch_end_old(self):
233
+ if self.enable_flash:
234
+ replace_opt_attn_with_flash_attn()
235
+ if (self.current_epoch+1) % self.caption_eval_epoch != 0:
236
+ return
237
+ predictions0 = [i for ii in self.prediction_list0 for i in ii]
238
+ targets0 = [i for ii in self.target_list0 for i in ii['answers']]
239
+ if 'q_types' in self.target_list0[0]:
240
+ q_types0 = [i for ii in self.target_list0 for i in ii['q_types']]
241
+ self.reduce_and_evaluate_qa(predictions0, targets0, q_types0, 'dataset0')
242
+ else:
243
+ self.reduce_and_evaluate_captioning(predictions0, targets0, 'dataset0')
244
+
245
+ if len(self.prediction_list1) > 0:
246
+ predictions1 = [i for ii in self.prediction_list1 for i in ii]
247
+ targets1 = [i for ii in self.target_list1 for i in ii]
248
+ self.reduce_and_evaluate_captioning(predictions1, targets1, 'dataset1')
249
+
250
+ def reduce_and_evaluate_qa(self, predictions, targets, q_types, log_prefix=""):
251
+ all_predictions = [None for _ in range(self.trainer.world_size)]
252
+ all_targets = [None for _ in range(self.trainer.world_size)]
253
+ all_q_types = [None for _ in range(self.trainer.world_size)]
254
+ dist.all_gather_object(all_predictions, predictions)
255
+ dist.all_gather_object(all_targets, targets)
256
+ dist.all_gather_object(all_q_types, q_types)
257
+ if self.global_rank == 0:
258
+ all_predictions = [i for ii in all_predictions for i in ii]
259
+ all_targets = [i for ii in all_targets for i in ii]
260
+ all_q_types = [i for ii in all_q_types for i in ii]
261
+ self.save_predictions(all_predictions, all_targets, all_q_types, log_prefix=log_prefix)
262
+
263
+ def reduce_and_evaluate_captioning(self, predictions, targets, log_prefix=""):
264
+ all_predictions = [None for _ in range(self.trainer.world_size)]
265
+ all_targets = [None for _ in range(self.trainer.world_size)]
266
+ dist.all_gather_object(all_predictions, predictions)
267
+ dist.all_gather_object(all_targets, targets)
268
+ if self.global_rank == 0:
269
+ all_predictions = [i for ii in all_predictions for i in ii]
270
+ all_targets = [i for ii in all_targets for i in ii]
271
+ self.save_predictions(all_predictions, all_targets, log_prefix)
272
+ ## fixme: I am not sure if the max length is the same as previous experiments
273
+ bleu2, bleu4, rouge_1, rouge_2, rouge_l, meteor_score = \
274
+ caption_evaluate(all_predictions, all_targets, self.blip2.llm_tokenizer, self.max_inference_len)
275
+ acc = evaluate_exact_match(all_predictions, all_targets)
276
+ self.log(f"{log_prefix}/acc", acc, sync_dist=False)
277
+ self.log(f"{log_prefix}/bleu2", bleu2, sync_dist=False)
278
+ self.log(f"{log_prefix}/bleu4", bleu4, sync_dist=False)
279
+ self.log(f"{log_prefix}/rouge_1", rouge_1, sync_dist=False)
280
+ self.log(f"{log_prefix}/rouge_2", rouge_2, sync_dist=False)
281
+ self.log(f"{log_prefix}/rouge_l", rouge_l, sync_dist=False)
282
+ self.log(f"{log_prefix}/meteor_score", meteor_score, sync_dist=False)
283
+
284
+ def training_step(self, batch, batch_idx):
285
+ if self.scheduler:
286
+ self.scheduler.step(self.trainer.current_epoch, self.trainer.global_step)
287
+
288
+ #batch_size = batch[-1].input_ids.size(0)
289
+ batch_size = len(batch[-1]['targets'])
290
+ ###============== Overall Loss ===================###
291
+ loss = self.blip2(batch)
292
+ self.log("loss", float(loss), batch_size=batch_size, sync_dist=True)
293
+ self.log("lr", self.trainer.optimizers[0].param_groups[0]['lr'], batch_size=batch_size, sync_dist=True)
294
+ return loss
295
+
296
+ @staticmethod
297
+ def add_model_specific_args(parent_parser):
298
+ parser = parent_parser.add_argument_group("ProtBlip2")
299
+ # train mode
300
+ parser.add_argument('--save_every_n_epochs', type=int, default=0)
301
+
302
+ # Bert
303
+ parser.add_argument('--bert_name', type=str, default='/nas/shared/kilab/wangyujia/ProtT3/plm_model/microsoft')
304
+ parser.add_argument('--cross_attention_freq', type=int, default=2)
305
+ parser.add_argument('--num_query_token', type=int, default=8)
306
+ parser.add_argument('--qformer_tune',type=str,default='train')
307
+ # OPT
308
+ parser.add_argument('--llm_name', type=str, default="/oss/wangyujia/BIO/construction_finetuning/alpaca/v1-20250609-141541/checkpoint-50-merged")
309
+ parser.add_argument('--num_beams', type=int, default=5)
310
+ parser.add_argument('--do_sample', action='store_true', default=False)
311
+ parser.add_argument('--max_inference_len', type=int, default=512)
312
+ parser.add_argument('--min_inference_len', type=int, default=1)
313
+ parser.add_argument('--llm_tune', type=str, default='freeze')
314
+ parser.add_argument('--peft_config', type=str, default='')
315
+ parser.add_argument('--peft_dir', type=str, default='')
316
+
317
+ ## plm model
318
+ parser.add_argument('--plm_model', type=str, default='/nas/shared/kilab/wangyujia/ProtT3/plm_model/esm2-150m')
319
+ parser.add_argument('--plm_tune', type=str, default='freeze')
320
+
321
+ ## lora config
322
+ parser.add_argument('--lora_r', type=int, default=8)
323
+ parser.add_argument('--lora_alpha', type=int, default=16)
324
+ parser.add_argument('--lora_dropout', type=int, default=0.1)
325
+ parser.add_argument('--enbale_gradient_checkpointing', action='store_true', default=False)
326
+
327
+ # optimization
328
+ parser.add_argument('--weight_decay', type=float, default=0.05, help='optimizer weight decay')
329
+ parser.add_argument('--init_lr', type=float, default=1e-4, help='optimizer init learning rate')
330
+ parser.add_argument('--min_lr', type=float, default=1e-5, help='optimizer min learning rate')
331
+ parser.add_argument('--warmup_lr', type=float, default=1e-6, help='optimizer warmup learning rate')
332
+ parser.add_argument('--warmup_steps', type=int, default=1000, help='optimizer warmup steps')
333
+ parser.add_argument('--lr_decay_rate', type=float, default=0.9, help='optimizer lr decay rate')
334
+ parser.add_argument('--scheduler', type=str, default='linear_warmup_cosine_lr', help='type of scheduler') # or linear_warmup_step_lr
335
+ parser.add_argument('--stage1_path', type=str, default='')
336
+ parser.add_argument('--stage2_path', type=str, default='')
337
+ parser.add_argument('--init_checkpoint', type=str, default='/nas/shared/kilab/wangyujia/ProtT3/all_checkpoints/stage2_07070513_2datasets_construct/epoch=09.ckpt/converted.ckpt')
338
+ parser.add_argument('--caption_eval_epoch', type=int, default=5)
339
+ return parent_parser
340
+
341
+
342
+
343
+ # def evaluate_exact_match(predictions, targets):
344
+ # acc = 0
345
+ # for prediction, target in zip(predictions, targets):
346
+ # if prediction.strip() == target.strip():
347
+ # acc += 1
348
+ # acc = round(acc / len(predictions) * 100, 2)
349
+ # return acc
350
+
351
+ import re
352
+
353
+ def evaluate_exact_match(predictions, targets):
354
+ acc = 0
355
+ for prediction, target in zip(predictions, targets):
356
+ # 使用正则提取 <answer>...</answer> 中的内容
357
+ match = re.search(r"<answer>(.*?)</answer>", target.strip(), re.DOTALL)
358
+ if match:
359
+ answer = match.group(1).strip()
360
+ if prediction.strip() == answer:
361
+ acc += 1
362
+ else:
363
+ print(f"Warning: No <answer> tag found in target: {target}")
364
+ acc = round(acc / len(predictions) * 100, 2)
365
+ return acc
BioReason-0813/model/help_funcs.py ADDED
@@ -0,0 +1,112 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from nltk.translate.bleu_score import corpus_bleu
3
+ from nltk.translate.meteor_score import meteor_score
4
+ from rouge_score import rouge_scorer
5
+ from tqdm import tqdm
6
+ import numpy as np
7
+
8
+
9
+ def caption_evaluate(predictions, targets, tokenizer, text_trunc_length):
10
+ targets = [t.strip() for t in targets]
11
+ meteor_scores = []
12
+ references = []
13
+ hypotheses = []
14
+ for gt, out in tqdm(zip(targets, predictions)):
15
+ gt_tokens = tokenizer.tokenize(gt, truncation=True, max_length=text_trunc_length,
16
+ padding='max_length')
17
+ ## added for galactica
18
+ gt_tokens = list(filter(('<pad>').__ne__, gt_tokens))
19
+ gt_tokens = list(filter(('[PAD]').__ne__, gt_tokens))
20
+ gt_tokens = list(filter(('[CLS]').__ne__, gt_tokens))
21
+ gt_tokens = list(filter(('[SEP]').__ne__, gt_tokens))
22
+
23
+ out_tokens = tokenizer.tokenize(out, truncation=True, max_length=text_trunc_length,
24
+ padding='max_length')
25
+ out_tokens = list(filter(('<pad>').__ne__, out_tokens))
26
+ gt_tokens = list(filter(('[PAD]').__ne__, gt_tokens))
27
+ out_tokens = list(filter(('[CLS]').__ne__, out_tokens))
28
+ out_tokens = list(filter(('[SEP]').__ne__, out_tokens))
29
+
30
+ references.append([gt_tokens])
31
+ hypotheses.append(out_tokens)
32
+
33
+ mscore = meteor_score([gt_tokens], out_tokens)
34
+ meteor_scores.append(mscore)
35
+
36
+ bleu2 = corpus_bleu(references, hypotheses, weights=(.5,.5))
37
+ bleu4 = corpus_bleu(references, hypotheses, weights=(.25,.25,.25,.25))
38
+ bleu2 *= 100
39
+ bleu4 *= 100
40
+
41
+ print('BLEU-2 score:', bleu2)
42
+ print('BLEU-4 score:', bleu4)
43
+ _meteor_score = np.mean(meteor_scores)
44
+ _meteor_score *= 100
45
+ print('Average Meteor score:', _meteor_score)
46
+
47
+ scorer = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rougeL'])
48
+
49
+ rouge_scores = []
50
+
51
+ references = []
52
+ hypotheses = []
53
+
54
+ for gt, out in tqdm(zip(targets, predictions)):
55
+ rs = scorer.score(out, gt)
56
+ rouge_scores.append(rs)
57
+
58
+ print('ROUGE score:')
59
+ rouge_1 = np.mean([rs['rouge1'].fmeasure for rs in rouge_scores]) * 100
60
+ rouge_2 = np.mean([rs['rouge2'].fmeasure for rs in rouge_scores]) * 100
61
+ rouge_l = np.mean([rs['rougeL'].fmeasure for rs in rouge_scores]) * 100
62
+ print('rouge1:', rouge_1)
63
+ print('rouge2:', rouge_2)
64
+ print('rougeL:', rouge_l)
65
+ return bleu2, bleu4, rouge_1, rouge_2, rouge_l, _meteor_score
66
+
67
+
68
+ class AttrDict(dict):
69
+ def __init__(self, *args, **kwargs):
70
+ super(AttrDict, self).__init__(*args, **kwargs)
71
+ self.__dict__ = self
72
+
73
+
74
+ def pad_and_concat(tensor_list, fill_value=0):
75
+ '''
76
+ concat the first dimension and pad the second dimension
77
+ tensor_list: [[B (diff), N_num, *], ...]
78
+ '''
79
+ device = tensor_list[0].device
80
+ dtype=tensor_list[0].dtype
81
+ max_dim1 = max(t.shape[1] for t in tensor_list)
82
+ sum_dim0 = sum(t.shape[0] for t in tensor_list)
83
+ if len(tensor_list[0].shape) == 3:
84
+ out = torch.full((sum_dim0, max_dim1, tensor_list[0].shape[-1]), fill_value=fill_value, device=device, dtype=dtype)
85
+ i = 0
86
+ for t in tensor_list:
87
+ out[i:i+t.shape[0], :t.shape[1]] = t
88
+ i += t.shape[0]
89
+ return out
90
+ elif len(tensor_list[0].shape) == 2:
91
+ out = torch.full((sum_dim0, max_dim1), fill_value=fill_value, device=device, dtype=dtype)
92
+ i = 0
93
+ for t in tensor_list:
94
+ out[i:i+t.shape[0], :t.shape[1]] = t
95
+ i += t.shape[0]
96
+ return out
97
+ raise NotImplementedError()
98
+
99
+
100
+ def hf_enable_gradient_checkpointing(hf_model):
101
+ if hasattr(hf_model, "enable_input_require_grads"):
102
+ hf_model.enable_input_require_grads()
103
+ else:
104
+
105
+ def make_inputs_require_grad(module, input, output):
106
+ output.requires_grad_(True)
107
+
108
+ hf_model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
109
+
110
+ # enable gradient checkpointing for memory efficiency
111
+ hf_model.gradient_checkpointing_enable()
112
+ return hf_model
BioReason-0813/prompt_templates.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ prompt_templates = {
2
+ "classification": """
3
+ Analyze the following protein sequence and predict its classification.
4
+
5
+ Protein sequence: <protein>{aa_seq}</protein>
6
+
7
+ Please provide your reasoning and classification.
8
+
9
+ <think>
10
+ Let me analyze this protein sequence step by step:
11
+ 1. Sequence length: {seq_length}
12
+ 2. Composition analysis...
13
+ 3. Structural predictions...
14
+ 4. Functional domains...
15
+ </think>
16
+
17
+ <answer>
18
+ Classification: {label}
19
+ </answer>
20
+ """,
21
+ "function_prediction": """
22
+ Given the protein sequence below, predict its function and classification:
23
+
24
+ Sequence: <protein>{aa_seq}</protein>
25
+
26
+ Analyze the sequence and provide your prediction.
27
+
28
+ <think>
29
+ Sequence analysis:
30
+ - Length: {seq_length} amino acids
31
+ - Notable features...
32
+ - Homology considerations...
33
+ </think>
34
+
35
+ <answer>
36
+ Function prediction: {label}
37
+ </answer>
38
+ """,
39
+ "location_prediction": """
40
+ Predict the cellular location and classification of this protein:
41
+
42
+ Protein sequence: <protein>{aa_seq}</protein>
43
+
44
+ What is the most likely classification for this protein?
45
+
46
+ <think>
47
+ Location and function analysis:
48
+ - Sequence characteristics...
49
+ - Signal peptides...
50
+ - Transmembrane regions...
51
+ </think>
52
+
53
+ <answer>
54
+ Classification: {label}
55
+ </answer>
56
+ """
57
+ }
BioReason-0813/run.sh ADDED
@@ -0,0 +1,103 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ echo "Starting GRPO training..."
2
+
3
+ #!/bin/bash
4
+ # run_blip2.sh
5
+ # 用于启动 BLIP2 + GRPO 训练的脚本
6
+
7
+ # ===== 基本路径配置 =====
8
+ DATA_FILE=/oss/wangyujia/ProtT3/ProtT3/data/sft/dataset/DeepLocBinary/test.csv
9
+ DATASET_NAME=deeplocbinary
10
+ OUTPUT_DIR=./output
11
+ CACHE_DIR=./cache
12
+
13
+ # ===== 模型配置 =====
14
+ BERT_PATH=/nas/shared/kilab/wangyujia/ProtT3/plm_model/microsoft
15
+ PLM_MODEL=/nas/shared/kilab/wangyujia/ProtT3/plm_model/esm2-150m
16
+ LLM_MODEL=/oss/wangyujia/BIO/construction_finetuning/alpaca/v1-20250609-141541/checkpoint-50-merged
17
+ SFT_CHECKPOINT=/nas/shared/kilab/wangyujia/ProtT3/all_checkpoints/stage2_07301646_2datasets_construct/epoch=09.ckpt/converted.ckpt
18
+ # ===== 训练参数 =====
19
+ BATCH_SIZE=4
20
+ EPOCHS=3
21
+ LR=1e-5
22
+
23
+ # ===== 奖励函数权重 =====
24
+ FORMAT_WEIGHT=0.2
25
+ ACCURACY_WEIGHT=0.6
26
+ REPETITION_WEIGHT=0.2
27
+
28
+ # ===== 运行训练 =====
29
+ python blips_reason.py \
30
+ --data_file_paths ${DATA_FILE} \
31
+ --dataset_name ${DATASET_NAME} \
32
+ --reward_funcs combined \
33
+ --format_weight ${FORMAT_WEIGHT} \
34
+ --accuracy_weight ${ACCURACY_WEIGHT} \
35
+ --repetition_weight ${REPETITION_WEIGHT} \
36
+ --use_custom_prompts \
37
+ --template_name classification \
38
+ --max_seq_length 1000 \
39
+ --output_dir ${OUTPUT_DIR} \
40
+ --per_device_train_batch_size ${BATCH_SIZE} \
41
+ --num_train_epochs ${EPOCHS} \
42
+ --learning_rate ${LR} \
43
+ --bert_name ${BERT_PATH} \
44
+ --plm_model ${PLM_MODEL} \
45
+ --llm_name ${LLM_MODEL} \
46
+ --sft_checkpoint ${SFT_CHECKPOINT} \
47
+ --plm_tune freeze \
48
+ --llm_tune lora \
49
+ --qformer_tune train \
50
+ --lora_r 8 \
51
+ --lora_alpha 16 \
52
+ --lora_dropout 0.1 \
53
+ --enable_flash \
54
+ --cache_dir ${CACHE_DIR}
55
+
56
+
57
+
58
+ # python protein_reason.py \
59
+ # --output_dir "./grpo_outputs" \
60
+ # --model_name_or_path "Qwen/Qwen3-0.6B" \
61
+ # --protein_model_name_or_path "facebook/esm2_t6_8M_UR50D" \
62
+ # --qformer_model_name_or_path "microsoft/BiomedNLP-BiomedBERT-base-uncased-abstract-fulltext" \
63
+ # --dataset_name "wanglab/protein_function" \
64
+ # --sft_checkpoint "./checkpoints/best_model" \
65
+ # --per_device_train_batch_size 4 \
66
+ # --gradient_accumulation_steps 4 \
67
+ # --num_train_epochs 3 \
68
+ # --learning_rate 1e-6 \
69
+ # --beta 0.04 \
70
+ # --temperature 0.6 \
71
+ # --top_p 0.95 \
72
+ # --top_k 20 \
73
+ # --max_completion_length 800 \
74
+ # --num_generations 8 \
75
+ # --reward_funcs "xmlcount" "soft_format" "strict_format" "correctness" \
76
+ # --lora_r 32 \
77
+ # --lora_alpha 64 \
78
+ # --lora_dropout 0.05 \
79
+ # --freeze_protein_modules \
80
+ # --logging_steps 2 \
81
+ # --eval_strategy "steps" \
82
+ # --eval_steps 100 \
83
+ # --save_steps 200 \
84
+ # --report_to "wandb" \
85
+ # --log_completions
86
+
87
+ # python blip2_reason.py \
88
+ # --data_file_paths /oss/wangyujia/ProtT3/ProtT3/data/sft/dataset/DeepLocBinary/test.csv \
89
+ # --reward_funcs combined \
90
+ # --format_weight 0.2 \
91
+ # --accuracy_weight 0.6 \
92
+ # --repetition_weight 0.2 \
93
+ # --use_custom_prompts \
94
+ # --template_name classification \
95
+ # --max_seq_length 1000 \
96
+ # --output_dir ./output \
97
+ # --per_device_train_batch_size 4 \
98
+ # --num_train_epochs 3 \
99
+ # --learning_rate 1e-5
100
+
101
+ echo "GRPO training completed!"
102
+
103
+ echo "All training stages completed successfully!"
BioReason-main/.gitignore ADDED
@@ -0,0 +1,180 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Byte-compiled / optimized / DLL files
2
+ .idea/
3
+ __pycache__/
4
+ *.py[cod]
5
+ *$py.class
6
+ wandb/
7
+ .DS_Store
8
+ .vscode/
9
+ .venv/
10
+ .env
11
+ .pytest_cache/
12
+
13
+ # C extensions
14
+ *.so
15
+
16
+ outputs/
17
+
18
+ # Distribution / packaging
19
+ .Python
20
+ build/
21
+ develop-eggs/
22
+ dist/
23
+ downloads/
24
+ eggs/
25
+ .eggs/
26
+ lib/
27
+ lib64/
28
+ parts/
29
+ sdist/
30
+ var/
31
+ wheels/
32
+ share/python-wheels/
33
+ *.egg-info/
34
+ .installed.cfg
35
+ *.egg
36
+ MANIFEST
37
+
38
+ # PyInstaller
39
+ # Usually these files are written by a python script from a template
40
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
41
+ *.manifest
42
+ *.spec
43
+
44
+ # Installer logs
45
+ pip-log.txt
46
+ pip-delete-this-directory.txt
47
+
48
+ # Unit test / coverage reports
49
+ htmlcov/
50
+ .tox/
51
+ .nox/
52
+ .coverage
53
+ .coverage.*
54
+ .cache
55
+ nosetests.xml
56
+ coverage.xml
57
+ *.cover
58
+ *.py,cover
59
+ .hypothesis/
60
+ .pytest_cache/
61
+ cover/
62
+
63
+ # Translations
64
+ *.mo
65
+ *.pot
66
+
67
+ # Django stuff:
68
+ *.log
69
+ local_settings.py
70
+ db.sqlite3
71
+ db.sqlite3-journal
72
+
73
+ # Flask stuff:
74
+ instance/
75
+ .webassets-cache
76
+
77
+ # Scrapy stuff:
78
+ .scrapy
79
+
80
+ # Sphinx documentation
81
+ docs/_build/
82
+
83
+ # PyBuilder
84
+ .pybuilder/
85
+ target/
86
+
87
+ # Jupyter Notebook
88
+ .ipynb_checkpoints
89
+
90
+ # IPython
91
+ profile_default/
92
+ ipython_config.py
93
+
94
+ # pyenv
95
+ # For a library or package, you might want to ignore these files since the code is
96
+ # intended to run in multiple environments; otherwise, check them in:
97
+ # .python-version
98
+
99
+ # pipenv
100
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
101
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
102
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
103
+ # install all needed dependencies.
104
+ #Pipfile.lock
105
+
106
+ # UV
107
+ # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
108
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
109
+ # commonly ignored for libraries.
110
+ #uv.lock
111
+
112
+ # poetry
113
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
114
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
115
+ # commonly ignored for libraries.
116
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
117
+ #poetry.lock
118
+
119
+ # pdm
120
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
121
+ #pdm.lock
122
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
123
+ # in version control.
124
+ # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
125
+ .pdm.toml
126
+ .pdm-python
127
+ .pdm-build/
128
+
129
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
130
+ __pypackages__/
131
+
132
+ # Celery stuff
133
+ celerybeat-schedule
134
+ celerybeat.pid
135
+
136
+ # SageMath parsed files
137
+ *.sage.py
138
+
139
+ # Environments
140
+ .env
141
+ .venv
142
+ env/
143
+ venv/
144
+ ENV/
145
+ env.bak/
146
+ venv.bak/
147
+
148
+ # Spyder project settings
149
+ .spyderproject
150
+ .spyproject
151
+
152
+ # Rope project settings
153
+ .ropeproject
154
+
155
+ # mkdocs documentation
156
+ /site
157
+
158
+ # mypy
159
+ .mypy_cache/
160
+ .dmypy.json
161
+ dmypy.json
162
+
163
+ # Pyre type checker
164
+ .pyre/
165
+
166
+ # pytype static type analyzer
167
+ .pytype/
168
+
169
+ # Cython debug symbols
170
+ cython_debug/
171
+
172
+ # PyCharm
173
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
174
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
175
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
176
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
177
+ #.idea/
178
+
179
+ # PyPI configuration file
180
+ .pypirc
BioReason-main/LICENSE ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright [yyyy] [name of copyright owner]
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
BioReason-main/README.md ADDED
@@ -0,0 +1,148 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <h1 align="center">
2
+ 🧬 BioReason<br>Incentivizing Multimodal Biological Reasoning<br>within a DNA-LLM Model
3
+ </h1>
4
+
5
+ <p align="center">
6
+ <a href="https://www.arxiv.org/abs/2505.23579" target="_blank"><img src="https://img.shields.io/badge/arXiv-2505.23579-FF6B6B?style=for-the-badge&logo=arxiv&logoColor=white" alt="arXiv"></a>
7
+ <a href="https://github.com/bowang-lab/BioReason"><img src="https://img.shields.io/badge/GitHub-Code-4A90E2?style=for-the-badge&logo=github&logoColor=white" alt="GitHub"></a>
8
+ <a href="https://bowang-lab.github.io/BioReason/"><img src="https://img.shields.io/badge/Website-Online-00B89E?style=for-the-badge&logo=internet-explorer&logoColor=white" alt="Website"></a>
9
+ <a href="https://huggingface.co/collections/wanglab/bioreason-683cd17172a037a31d208f70"><img src="https://img.shields.io/badge/HuggingFace-Dataset-FFBF00?style=for-the-badge&logo=huggingface&logoColor=white" alt="HuggingFace Dataset"></a>
10
+ </p>
11
+
12
+ <br>
13
+
14
+ ## Updates [Jun 10, 2025]
15
+ - We are integrating vLLM to improve the speed and efficiency of the GRPO pipeline. We expect this to be pushed by end of week.
16
+ - Checkpoints along with the custom DNA-LLM model class will be released on HuggingFace by end of week.
17
+ - More training results with GRPO will be shared soon.
18
+
19
+ <br>
20
+
21
+ ## Abstract
22
+
23
+ Unlocking deep, interpretable biological reasoning from complex genomic data is a major AI challenge hindering scientific discovery. Current DNA foundation models, despite strong sequence representation, struggle with multi-step reasoning and lack inherent transparent, biologically intuitive explanations. We introduce BioReason, a pioneering architecture that, for the first time, deeply integrates a DNA foundation model with a large language model (LLM). This novel connection enables the LLM to directly process and reason with genomic information as a fundamental input, fostering a new form of multimodal biological understanding. BioReason's sophisticated multi-step reasoning is developed through supervised fine-tuning and targeted reinforcement learning, guiding the system to generate logical, biologically coherent deductions. On biological reasoning benchmarks including KEGG-based disease pathway prediction—where accuracy improves from 88% to 97%—and variant effect prediction, BioReason demonstrates an average 15% performance gain over strong single-modality baselines.
24
+
25
+ <br>
26
+
27
+ ## Key Contributions
28
+
29
+ • **Novel multimodal architecture**: The first successful integration of a DNA foundation model with an LLM, establishing a new methodology for AI-driven biological studies.
30
+
31
+ • **Advanced reasoning methodology**: A systematic training approach combining supervised fine-tuning and reinforcement learning that incentivizes multi-step biological reasoning.
32
+
33
+ • **New biological reasoning benchmarks**: Development and curation of novel benchmarks for evaluating biological reasoning capabilities, including an annotated reasoning dataset for gene pathway and disease prediction from KEGG.
34
+
35
+ • **Empirical performance improvements**: Demonstration that BioReason outperforms both DNA foundation models and LLMs used independently or in simple combination, with average performance gains of 15%+ over baseline.
36
+
37
+ • **Interpretable reasoning traces**: A mechanism for generating step-by-step biological reasoning traces that provide interpretable predictions, enhancing scientific insight and hypothesis generation.
38
+
39
+ <br>
40
+
41
+ ## Datasets
42
+
43
+ The datasets used to train and evaluate BioReason can be found on our [HuggingFace collection](https://huggingface.co/collections/wanglab/bioreason-683cd17172a037a31d208f70) with detailed download and usage instructions.
44
+
45
+ <br>
46
+
47
+ ## Checkpoints
48
+
49
+ We will release the checkpoints soon!
50
+
51
+ <br>
52
+
53
+ ## Installation
54
+
55
+ ### Prerequisites
56
+ - Python 3.11+
57
+ - CUDA/GPU for best performance
58
+
59
+ ### Installation Steps
60
+ ```bash
61
+ # Clone the repository
62
+ git clone https://github.com/bowang-lab/BioReason.git
63
+ cd BioReason
64
+
65
+ # Install package
66
+ pip install -e .
67
+ ```
68
+
69
+ <br>
70
+
71
+ ## Results
72
+
73
+ ### KEGG-Derived Biological Reasoning Task
74
+ Performance comparison on 290 test datapoints for multi-step mechanistic reasoning:
75
+
76
+ | Model | Accuracy | F1-Score | Precision | Recall |
77
+ |-------|----------|----------|-----------|---------|
78
+ | [DNA] NT - 500M | 86.55 | 69.76 | 73.23 | 66.61 |
79
+ | [DNA] Evo2 - 1B | 88.28 | 72.43 | 75.23 | 69.83 |
80
+ | [LLM] Qwen3 - 1B | 85.17 | 65.71 | 71.39 | 64.19 |
81
+ | [LLM] Qwen3 - 4B | 93.48 | 85.44 | 88.31 | 86.72 |
82
+ | [DNA-LLM] NT + Qwen3 - 1B | 88.42 | 72.13 | 75.42 | 71.91 |
83
+ | [DNA-LLM] NT + Qwen3 - 1B (+RL) | 89.66 | 74.11 | 78.82 | 72.96 |
84
+ | [DNA-LLM] NT + Qwen3 - 4B | 96.90 | **89.03** | **90.99** | **89.38** |
85
+ | [DNA-LLM] Evo2 + Qwen3 - 1B | 90.42 | 75.62 | 77.42 | 73.91 |
86
+ | [DNA-LLM] Evo2 + Qwen3 - 4B | **97.24** | 86.30 | 86.75 | 87.25 |
87
+
88
+ ### Variant Effect Prediction Benchmarks
89
+ Performance on pathogenic/benign classification:
90
+
91
+ | Model | Variant Effect - Coding | | Variant Effect - Non-SNV | |
92
+ |-------|------------|----------|------------|----------|
93
+ | | Accuracy | F1-Score | Accuracy | F1-Score |
94
+ | [DNA] NT - 500M | 60.91 | 45.20 | 67.93 | 65.97 |
95
+ | [DNA] Evo2 - 1B | 70.07 | 49.19 | 76.17 | 66.51 |
96
+ | [LLM] Qwen3 - 1B | 46.55 | 34.82 | 70.67 | 76.21 |
97
+ | [LLM] Qwen3 - 4B | 48.99 | 39.58 | 61.86 | 67.60 |
98
+ | [DNA-LLM] NT + Qwen3 - 1B | 55.58 | 54.50 | 72.82 | 76.93 |
99
+ | [DNA-LLM] NT + Qwen3 - 4B | 60.94 | 55.66 | 65.59 | 73.00 |
100
+ | [DNA-LLM] Evo2 + Qwen3 - 1B | 72.83 | 68.90 | **88.20** | **89.91** |
101
+ | [DNA-LLM] Evo2 + Qwen3 - 4B | **80.21** | **80.00** | 83.85 | 85.02 |
102
+
103
+ <br>
104
+
105
+ ## Citation
106
+
107
+ If you find this work useful, please cite our paper:
108
+
109
+ ```bibtex
110
+ @misc{fallahpour2025bioreasonincentivizingmultimodalbiological,
111
+ title={BioReason: Incentivizing Multimodal Biological Reasoning within a DNA-LLM Model},
112
+ author={Adibvafa Fallahpour and Andrew Magnuson and Purav Gupta and Shihao Ma and Jack Naimer and Arnav Shah and Haonan Duan and Omar Ibrahim and Hani Goodarzi and Chris J. Maddison and Bo Wang},
113
+ year={2025},
114
+ eprint={2505.23579},
115
+ archivePrefix={arXiv},
116
+ primaryClass={cs.LG},
117
+ url={https://arxiv.org/abs/2505.23579},
118
+ }
119
+ ```
120
+
121
+ <br>
122
+
123
+ ## Authors
124
+
125
+ - **Adibvafa Fallahpour**¹²³⁵ * (adibvafa.fallahpour@mail.utoronto.ca)
126
+ - **Andrew Magnuson**¹² *
127
+ - **Purav Gupta**¹² *
128
+ - **Shihao Ma**¹²³
129
+ - **Jack Naimer**¹²³
130
+ - **Arnav Shah**¹²³
131
+ - **Haonan Duan**¹²
132
+ - **Omar Ibrahim**³
133
+ - **Hani Goodarzi**†⁴⁶
134
+ - **Chris J. Maddison**†¹²⁷
135
+ - **Bo Wang**†¹²³
136
+
137
+ ¹ University of Toronto ² Vector Institute ³ University Health Network (UHN) <br>
138
+ ⁴ Arc Institute ⁵ Cohere ⁶ University of California, San Francisco ⁷ Google DeepMind
139
+
140
+ <br>
141
+ * Equal contribution <br>
142
+ † Equal advising
143
+
144
+ ---
145
+
146
+ <p align="center">
147
+ Made with ❤️ at University of Toronto, Vector Institute, and University Health Network
148
+ </p>
BioReason-main/bioreason.egg-info/PKG-INFO ADDED
@@ -0,0 +1,181 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Metadata-Version: 2.4
2
+ Name: bioreason
3
+ Version: 0.1.0
4
+ Summary: Bio-related Reasoning with Language Models
5
+ License: UNKNOWN
6
+ Platform: UNKNOWN
7
+ Classifier: Programming Language :: Python :: 3
8
+ Classifier: Programming Language :: Python :: 3.11
9
+ Classifier: License :: OSI Approved :: MIT License
10
+ Classifier: Operating System :: OS Independent
11
+ Requires-Python: >=3.11
12
+ Description-Content-Type: text/markdown
13
+ License-File: LICENSE
14
+ Requires-Dist: torch
15
+ Requires-Dist: torchvision
16
+ Requires-Dist: transformers
17
+ Requires-Dist: accelerate
18
+ Requires-Dist: qwen-vl-utils
19
+ Requires-Dist: jupyter
20
+ Requires-Dist: datasets
21
+ Requires-Dist: peft
22
+ Requires-Dist: pytorch_lightning
23
+ Requires-Dist: wandb
24
+ Requires-Dist: trl[vllm]
25
+ Requires-Dist: bitsandbytes
26
+ Requires-Dist: deepspeed
27
+ Provides-Extra: dev
28
+ Requires-Dist: pytest; extra == "dev"
29
+ Requires-Dist: black; extra == "dev"
30
+ Requires-Dist: isort; extra == "dev"
31
+ Requires-Dist: mypy; extra == "dev"
32
+ Dynamic: license-file
33
+
34
+ <h1 align="center">
35
+ 🧬 BioReason<br>Incentivizing Multimodal Biological Reasoning<br>within a DNA-LLM Model
36
+ </h1>
37
+
38
+ <p align="center">
39
+ <a href="https://www.arxiv.org/abs/2505.23579" target="_blank"><img src="https://img.shields.io/badge/arXiv-2505.23579-FF6B6B?style=for-the-badge&logo=arxiv&logoColor=white" alt="arXiv"></a>
40
+ <a href="https://github.com/bowang-lab/BioReason"><img src="https://img.shields.io/badge/GitHub-Code-4A90E2?style=for-the-badge&logo=github&logoColor=white" alt="GitHub"></a>
41
+ <a href="https://bowang-lab.github.io/BioReason/"><img src="https://img.shields.io/badge/Website-Online-00B89E?style=for-the-badge&logo=internet-explorer&logoColor=white" alt="Website"></a>
42
+ <a href="https://huggingface.co/collections/wanglab/bioreason-683cd17172a037a31d208f70"><img src="https://img.shields.io/badge/HuggingFace-Dataset-FFBF00?style=for-the-badge&logo=huggingface&logoColor=white" alt="HuggingFace Dataset"></a>
43
+ </p>
44
+
45
+ <br>
46
+
47
+ ## Updates [Jun 10, 2025]
48
+ - We are integrating vLLM to improve the speed and efficiency of the GRPO pipeline. We expect this to be pushed by end of week.
49
+ - Checkpoints along with the custom DNA-LLM model class will be released on HuggingFace by end of week.
50
+ - More training results with GRPO will be shared soon.
51
+
52
+ <br>
53
+
54
+ ## Abstract
55
+
56
+ Unlocking deep, interpretable biological reasoning from complex genomic data is a major AI challenge hindering scientific discovery. Current DNA foundation models, despite strong sequence representation, struggle with multi-step reasoning and lack inherent transparent, biologically intuitive explanations. We introduce BioReason, a pioneering architecture that, for the first time, deeply integrates a DNA foundation model with a large language model (LLM). This novel connection enables the LLM to directly process and reason with genomic information as a fundamental input, fostering a new form of multimodal biological understanding. BioReason's sophisticated multi-step reasoning is developed through supervised fine-tuning and targeted reinforcement learning, guiding the system to generate logical, biologically coherent deductions. On biological reasoning benchmarks including KEGG-based disease pathway prediction—where accuracy improves from 88% to 97%—and variant effect prediction, BioReason demonstrates an average 15% performance gain over strong single-modality baselines.
57
+
58
+ <br>
59
+
60
+ ## Key Contributions
61
+
62
+ • **Novel multimodal architecture**: The first successful integration of a DNA foundation model with an LLM, establishing a new methodology for AI-driven biological studies.
63
+
64
+ • **Advanced reasoning methodology**: A systematic training approach combining supervised fine-tuning and reinforcement learning that incentivizes multi-step biological reasoning.
65
+
66
+ • **New biological reasoning benchmarks**: Development and curation of novel benchmarks for evaluating biological reasoning capabilities, including an annotated reasoning dataset for gene pathway and disease prediction from KEGG.
67
+
68
+ • **Empirical performance improvements**: Demonstration that BioReason outperforms both DNA foundation models and LLMs used independently or in simple combination, with average performance gains of 15%+ over baseline.
69
+
70
+ • **Interpretable reasoning traces**: A mechanism for generating step-by-step biological reasoning traces that provide interpretable predictions, enhancing scientific insight and hypothesis generation.
71
+
72
+ <br>
73
+
74
+ ## Datasets
75
+
76
+ The datasets used to train and evaluate BioReason can be found on our [HuggingFace collection](https://huggingface.co/collections/wanglab/bioreason-683cd17172a037a31d208f70) with detailed download and usage instructions.
77
+
78
+ <br>
79
+
80
+ ## Checkpoints
81
+
82
+ We will release the checkpoints soon!
83
+
84
+ <br>
85
+
86
+ ## Installation
87
+
88
+ ### Prerequisites
89
+ - Python 3.11+
90
+ - CUDA/GPU for best performance
91
+
92
+ ### Installation Steps
93
+ ```bash
94
+ # Clone the repository
95
+ git clone https://github.com/bowang-lab/BioReason.git
96
+ cd BioReason
97
+
98
+ # Install package
99
+ pip install -e .
100
+ ```
101
+
102
+ <br>
103
+
104
+ ## Results
105
+
106
+ ### KEGG-Derived Biological Reasoning Task
107
+ Performance comparison on 290 test datapoints for multi-step mechanistic reasoning:
108
+
109
+ | Model | Accuracy | F1-Score | Precision | Recall |
110
+ |-------|----------|----------|-----------|---------|
111
+ | [DNA] NT - 500M | 86.55 | 69.76 | 73.23 | 66.61 |
112
+ | [DNA] Evo2 - 1B | 88.28 | 72.43 | 75.23 | 69.83 |
113
+ | [LLM] Qwen3 - 1B | 85.17 | 65.71 | 71.39 | 64.19 |
114
+ | [LLM] Qwen3 - 4B | 93.48 | 85.44 | 88.31 | 86.72 |
115
+ | [DNA-LLM] NT + Qwen3 - 1B | 88.42 | 72.13 | 75.42 | 71.91 |
116
+ | [DNA-LLM] NT + Qwen3 - 1B (+RL) | 89.66 | 74.11 | 78.82 | 72.96 |
117
+ | [DNA-LLM] NT + Qwen3 - 4B | 96.90 | **89.03** | **90.99** | **89.38** |
118
+ | [DNA-LLM] Evo2 + Qwen3 - 1B | 90.42 | 75.62 | 77.42 | 73.91 |
119
+ | [DNA-LLM] Evo2 + Qwen3 - 4B | **97.24** | 86.30 | 86.75 | 87.25 |
120
+
121
+ ### Variant Effect Prediction Benchmarks
122
+ Performance on pathogenic/benign classification:
123
+
124
+ | Model | Variant Effect - Coding | | Variant Effect - Non-SNV | |
125
+ |-------|------------|----------|------------|----------|
126
+ | | Accuracy | F1-Score | Accuracy | F1-Score |
127
+ | [DNA] NT - 500M | 60.91 | 45.20 | 67.93 | 65.97 |
128
+ | [DNA] Evo2 - 1B | 70.07 | 49.19 | 76.17 | 66.51 |
129
+ | [LLM] Qwen3 - 1B | 46.55 | 34.82 | 70.67 | 76.21 |
130
+ | [LLM] Qwen3 - 4B | 48.99 | 39.58 | 61.86 | 67.60 |
131
+ | [DNA-LLM] NT + Qwen3 - 1B | 55.58 | 54.50 | 72.82 | 76.93 |
132
+ | [DNA-LLM] NT + Qwen3 - 4B | 60.94 | 55.66 | 65.59 | 73.00 |
133
+ | [DNA-LLM] Evo2 + Qwen3 - 1B | 72.83 | 68.90 | **88.20** | **89.91** |
134
+ | [DNA-LLM] Evo2 + Qwen3 - 4B | **80.21** | **80.00** | 83.85 | 85.02 |
135
+
136
+ <br>
137
+
138
+ ## Citation
139
+
140
+ If you find this work useful, please cite our paper:
141
+
142
+ ```bibtex
143
+ @misc{fallahpour2025bioreasonincentivizingmultimodalbiological,
144
+ title={BioReason: Incentivizing Multimodal Biological Reasoning within a DNA-LLM Model},
145
+ author={Adibvafa Fallahpour and Andrew Magnuson and Purav Gupta and Shihao Ma and Jack Naimer and Arnav Shah and Haonan Duan and Omar Ibrahim and Hani Goodarzi and Chris J. Maddison and Bo Wang},
146
+ year={2025},
147
+ eprint={2505.23579},
148
+ archivePrefix={arXiv},
149
+ primaryClass={cs.LG},
150
+ url={https://arxiv.org/abs/2505.23579},
151
+ }
152
+ ```
153
+
154
+ <br>
155
+
156
+ ## Authors
157
+
158
+ - **Adibvafa Fallahpour**¹²³⁵ * (adibvafa.fallahpour@mail.utoronto.ca)
159
+ - **Andrew Magnuson**¹² *
160
+ - **Purav Gupta**¹² *
161
+ - **Shihao Ma**¹²³
162
+ - **Jack Naimer**¹²³
163
+ - **Arnav Shah**¹²³
164
+ - **Haonan Duan**¹²
165
+ - **Omar Ibrahim**³
166
+ - **Hani Goodarzi**†⁴⁶
167
+ - **Chris J. Maddison**†¹²⁷
168
+ - **Bo Wang**†¹²³
169
+
170
+ ¹ University of Toronto ² Vector Institute ³ University Health Network (UHN) <br>
171
+ ⁴ Arc Institute ⁵ Cohere ⁶ University of California, San Francisco ⁷ Google DeepMind
172
+
173
+ <br>
174
+ * Equal contribution <br>
175
+ † Equal advising
176
+
177
+ ---
178
+
179
+ <p align="center">
180
+ Made with ❤️ at University of Toronto, Vector Institute, and University Health Network
181
+ </p>
BioReason-main/bioreason.egg-info/SOURCES.txt ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ LICENSE
2
+ README.md
3
+ pyproject.toml
4
+ bioreason/__init__.py
5
+ bioreason.egg-info/PKG-INFO
6
+ bioreason.egg-info/SOURCES.txt
7
+ bioreason.egg-info/dependency_links.txt
8
+ bioreason.egg-info/requires.txt
9
+ bioreason.egg-info/top_level.txt
BioReason-main/bioreason.egg-info/dependency_links.txt ADDED
@@ -0,0 +1 @@
 
 
1
+
BioReason-main/bioreason.egg-info/requires.txt ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ torch
2
+ torchvision
3
+ transformers
4
+ accelerate
5
+ qwen-vl-utils
6
+ jupyter
7
+ datasets
8
+ peft
9
+ pytorch_lightning
10
+ wandb
11
+ trl[vllm]
12
+ bitsandbytes
13
+ deepspeed
14
+
15
+ [dev]
16
+ pytest
17
+ black
18
+ isort
19
+ mypy
BioReason-main/bioreason.egg-info/top_level.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ bioreason
BioReason-main/bioreason/__init__.py ADDED
File without changes
BioReason-main/bioreason/dataset/__init__.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .kegg import KEGGDataset, split_kegg_dataset
2
+ from .utils import torch_to_hf_dataset, truncate_dna
3
+ from .variant_effect import get_format_variant_effect_function
4
+
5
+ __all__ = [
6
+ "KEGGDataset",
7
+ "split_kegg_dataset",
8
+ "torch_to_hf_dataset",
9
+ "truncate_dna",
10
+ "get_format_variant_effect_function",
11
+ ]
BioReason-main/bioreason/dataset/kegg.py ADDED
@@ -0,0 +1,382 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import os
3
+ import random
4
+ import sys
5
+ import torch
6
+ from torch.utils.data import Dataset, DataLoader
7
+ from typing import Any, Dict, List, Tuple
8
+
9
+ from bioreason.dataset.utils import torch_to_hf_dataset
10
+ from bioreason.models.dl.processing_dl import DLProcessor
11
+ from trl.data_utils import maybe_apply_chat_template
12
+
13
+
14
+ class KEGGDataset(Dataset):
15
+ """Dataset for KEGG data."""
16
+
17
+ def __init__(self, data_dir: str):
18
+ """
19
+ Initialize the dataset by loading all JSON files from the given directory.
20
+
21
+ Args:
22
+ data_dir: Path to the directory containing JSON files
23
+ """
24
+ self.data_dir = data_dir
25
+ self.data = []
26
+
27
+ # Load all JSON files
28
+ json_files = sorted([f for f in os.listdir(data_dir) if f.endswith(".json")])
29
+
30
+ # Process each file
31
+ for filename in json_files:
32
+ file_path = os.path.join(data_dir, filename)
33
+ kegg_id = filename.split("_")[1]
34
+
35
+ with open(file_path, "r", encoding="utf-8") as f:
36
+ item = json.load(f)
37
+ item["kegg_id"] = kegg_id
38
+ processed_item = self._process_item(item)
39
+ self.data.append(processed_item)
40
+
41
+ def _process_item(self, item: Dict[str, Any]) -> Dict[str, Any]:
42
+ """
43
+ Process a single data item to format fields as required.
44
+
45
+ Args:
46
+ item: Original data item from JSON
47
+
48
+ Returns:
49
+ Processed data item
50
+ """
51
+ # Extract question as is
52
+ question = item.get("question", "")
53
+
54
+ # Convert answer to lowercase and strip whitespace
55
+ answer = item.get("answer", "").lower().strip()
56
+
57
+ # Combine reasoning steps into a single paragraph with newlines
58
+ reasoning_steps = item.get("reasoning", {}).get("reasoning_steps", [])
59
+ reasoning = "\n".join(reasoning_steps)
60
+
61
+ # Convert sequences to uppercase and strip whitespace
62
+ reference_sequence = item.get("reference_sequence", "").upper().strip()
63
+ variant_sequence = item.get("variant_sequence", "").upper().strip()
64
+
65
+ return {
66
+ "question": question,
67
+ "answer": answer,
68
+ "reasoning": reasoning,
69
+ "reference_sequence": reference_sequence,
70
+ "variant_sequence": variant_sequence,
71
+ }
72
+
73
+ def __len__(self) -> int:
74
+ """Return the number of items in the dataset."""
75
+ return len(self.data)
76
+
77
+ def __getitem__(self, idx: int) -> Dict[str, Any]:
78
+ """Return a specific item from the dataset."""
79
+ return self.data[idx]
80
+
81
+
82
+ def split_kegg_dataset(
83
+ dataset: KEGGDataset,
84
+ train_ratio: float = 0.8,
85
+ val_ratio: float = 0.1,
86
+ test_ratio: float = 0.1,
87
+ seed: int = 42,
88
+ ) -> Tuple[KEGGDataset, KEGGDataset, KEGGDataset]:
89
+ """
90
+ Split a KEGG dataset into train, validation, and test sets.
91
+
92
+ Args:
93
+ dataset: The dataset to split
94
+ train_ratio: Proportion of data for training
95
+ val_ratio: Proportion of data for validation
96
+ test_ratio: Proportion of data for testing
97
+ batch_size: Batch size for the dataloaders
98
+ seed: Random seed for reproducibility
99
+
100
+ Returns:
101
+ Tuple of (train_dataset, val_dataset, test_dataset)
102
+ """
103
+ # Calculate the size of each split
104
+ dataset_size = len(dataset)
105
+ train_size = int(train_ratio * dataset_size)
106
+ val_size = int(val_ratio * dataset_size)
107
+ test_size = dataset_size - train_size - val_size
108
+ assert train_ratio + val_ratio + test_ratio == 1.0, "Ratios must sum to 1"
109
+
110
+ # Set the random seed
111
+ torch.manual_seed(seed)
112
+ random.seed(seed)
113
+
114
+ # Split the dataset
115
+ train_dataset, val_dataset, test_dataset = torch.utils.data.random_split(
116
+ dataset, [train_size, val_size, test_size]
117
+ )
118
+
119
+ return train_dataset, val_dataset, test_dataset
120
+
121
+
122
+ def create_kegg_dataloader(
123
+ data_dir: str,
124
+ batch_size: int = 2,
125
+ shuffle: bool = True,
126
+ num_workers: int = 2,
127
+ pin_memory: bool = True,
128
+ ) -> DataLoader:
129
+ """
130
+ Create a DataLoader for the KEGG dataset.
131
+
132
+ Args:
133
+ data_dir: Path to the directory containing JSON files
134
+ batch_size: Batch size for the dataloader
135
+ shuffle: Whether to shuffle the data
136
+ num_workers: Number of worker processes for loading data
137
+ pin_memory: Whether to pin memory for faster data transfer
138
+
139
+ Returns:
140
+ DataLoader for the KEGG dataset
141
+ """
142
+ dataset = KEGGDataset(data_dir)
143
+ return DataLoader(
144
+ dataset,
145
+ batch_size=batch_size,
146
+ shuffle=shuffle,
147
+ num_workers=num_workers,
148
+ pin_memory=pin_memory,
149
+ )
150
+
151
+
152
+ def get_format_kegg_function(model_name: str) -> Any:
153
+ """
154
+ Get the appropriate format function for a given model name.
155
+ """
156
+ if model_name.lower() == "llm":
157
+ return format_kegg_for_llm
158
+ elif model_name.lower() == "dna-llm":
159
+ return format_kegg_for_dna_llm
160
+ else:
161
+ raise ValueError(f"Unsupported model name: {model_name}")
162
+
163
+
164
+ def format_kegg_for_dna_llm(example: Dict[str, Any]) -> Dict[str, Any]:
165
+ """
166
+ Format a KEGG example into the required chat format for DNA-LLM.
167
+ """
168
+ return {
169
+ "prompt": [
170
+ {
171
+ "role": "user",
172
+ "content": [
173
+ *({"type": "dna", "text": None} for _ in range(2)),
174
+ {"type": "text", "text": example["question"].strip()},
175
+ ],
176
+ },
177
+ {
178
+ "role": "assistant",
179
+ "reasoning_content": example["reasoning"].strip(),
180
+ "content": [
181
+ {"type": "text", "text": f"Answer: {example['answer'].strip()}"},
182
+ ],
183
+ },
184
+ ],
185
+ "dna_sequences": [
186
+ example["reference_sequence"],
187
+ example["variant_sequence"],
188
+ ],
189
+ "answer": example["answer"],
190
+ }
191
+
192
+
193
+ def format_kegg_for_llm(example: Dict[str, Any]) -> Dict[str, Any]:
194
+ """
195
+ Format a KEGG example into the required chat format for LLM.
196
+ """
197
+ question = f"Reference sequence: {example['reference_sequence']}\nVariant sequence: {example['variant_sequence']}\nQuestion: {example['question']}"
198
+ return {
199
+ "prompt": [
200
+ {
201
+ "role": "user",
202
+ "content": [
203
+ *({"type": "dna", "text": None} for _ in range(2)),
204
+ {"type": "text", "text": question.strip()},
205
+ ],
206
+ },
207
+ {
208
+ "role": "assistant",
209
+ "reasoning_content": example["reasoning"].strip(),
210
+ "content": [
211
+ {"type": "text", "text": f"Answer: {example['answer'].strip()}"},
212
+ ],
213
+ },
214
+ ],
215
+ "dna_sequences": [
216
+ "",
217
+ "",
218
+ ],
219
+ "answer": example["answer"],
220
+ }
221
+
222
+
223
+ def qwen_dna_collate_fn(
224
+ examples: List[Dict],
225
+ processor: DLProcessor,
226
+ max_length_text: int,
227
+ max_length_dna: int,
228
+ return_answer_in_batch: bool = False,
229
+ ) -> Dict:
230
+ """
231
+ Custom collate function for Qwen DNA models.
232
+
233
+ Creates a batch with proper labels for supervised fine-tuning where only
234
+ the assistant responses contribute to the loss calculation.
235
+ """
236
+ prompts_text = [
237
+ maybe_apply_chat_template(example, processor)["prompt"] for example in examples
238
+ ]
239
+ batch_dna_sequences = [example["dna_sequences"] for example in examples]
240
+
241
+ batch = processor(
242
+ text=prompts_text,
243
+ batch_dna_sequences=batch_dna_sequences,
244
+ return_tensors="pt",
245
+ padding=True,
246
+ padding_side="left",
247
+ add_special_tokens=False,
248
+ max_length_text=max_length_text,
249
+ max_length_dna=max_length_dna,
250
+ )
251
+
252
+ # Create labels tensor filled with -100 (ignored in loss calculation)
253
+ labels = torch.full_like(batch["input_ids"], -100)
254
+
255
+ # Get token IDs for special markers
256
+ assistant_start_marker = "<|im_start|>assistant\n"
257
+ im_end_marker = "<|im_end|>"
258
+
259
+ assistant_start_token_ids = processor.tokenizer.encode(
260
+ assistant_start_marker, add_special_tokens=False
261
+ )
262
+ im_end_token_ids = processor.tokenizer.encode(
263
+ im_end_marker, add_special_tokens=False
264
+ )
265
+
266
+ # Convert token arrays to tensors for faster comparison
267
+ assistant_marker_tensor = torch.tensor(
268
+ assistant_start_token_ids, device=batch["input_ids"].device
269
+ )
270
+ im_end_marker_tensor = torch.tensor(
271
+ im_end_token_ids, device=batch["input_ids"].device
272
+ )
273
+
274
+ # Get dimensions for easier reference
275
+ assistant_marker_len = len(assistant_start_token_ids)
276
+ im_end_marker_len = len(im_end_token_ids)
277
+
278
+ # For each sequence in the batch
279
+ for i in range(batch["input_ids"].shape[0]):
280
+ input_ids = batch["input_ids"][i]
281
+ seq_len = input_ids.size(0)
282
+
283
+ # Track assistant sections
284
+ assistant_sections = []
285
+
286
+ # Find all assistant start markers
287
+ start_positions = []
288
+ for pos in range(seq_len - assistant_marker_len + 1):
289
+ if torch.all(
290
+ input_ids[pos : pos + assistant_marker_len] == assistant_marker_tensor
291
+ ):
292
+ start_positions.append(
293
+ pos + assistant_marker_len
294
+ ) # Store position after marker
295
+
296
+ # Find all end markers
297
+ end_positions = []
298
+ for pos in range(seq_len - im_end_marker_len + 1):
299
+ if torch.all(
300
+ input_ids[pos : pos + im_end_marker_len] == im_end_marker_tensor
301
+ ):
302
+ end_positions.append(pos) # Store position at start of end marker
303
+
304
+ # Match start and end markers to create sections
305
+ for start_pos in start_positions:
306
+ # Find the next end marker after this start position
307
+ valid_ends = [pos for pos in end_positions if pos > start_pos]
308
+ if valid_ends:
309
+ end_pos = min(valid_ends) # Take the first end marker after start
310
+ # Only include content between markers (not the markers themselves)
311
+ if start_pos < end_pos:
312
+ assistant_sections.append((start_pos, end_pos))
313
+ else:
314
+ # If no end marker, assume the section runs to the end of the sequence
315
+ assistant_sections.append((start_pos, seq_len))
316
+
317
+ # Set labels for all identified assistant sections
318
+ for start_pos, end_pos in assistant_sections:
319
+ if start_pos < end_pos and start_pos < seq_len:
320
+ end_pos = min(end_pos, seq_len) # Safety check
321
+ labels[i, start_pos:end_pos] = input_ids[start_pos:end_pos]
322
+
323
+ # Also mask padding tokens
324
+ labels[batch["input_ids"] == processor.tokenizer.pad_token_id] = -100
325
+
326
+ # Add labels to batch
327
+ batch["labels"] = labels
328
+
329
+ # Add answer to batch
330
+ if return_answer_in_batch:
331
+ batch["answer"] = [example["answer"].strip() for example in examples]
332
+
333
+ return batch
334
+
335
+
336
+ def dna_collate_fn(
337
+ batch: List[Dict[str, Any]],
338
+ dna_tokenizer: Any,
339
+ label2id: Dict[str, int],
340
+ max_length: int = 2048,
341
+ ) -> Dict[str, Any]:
342
+ """
343
+ Custom collate function for DNA models.
344
+ """
345
+ ref_sequences = [item["reference_sequence"] for item in batch]
346
+ alt_sequences = [item["variant_sequence"] for item in batch]
347
+
348
+ # Tokenize DNA sequences separately
349
+ tokenized_ref = dna_tokenizer(
350
+ ref_sequences,
351
+ padding=True,
352
+ truncation=True,
353
+ max_length=max_length,
354
+ return_tensors="pt",
355
+ )
356
+
357
+ tokenized_alt = dna_tokenizer(
358
+ alt_sequences,
359
+ padding=True,
360
+ truncation=True,
361
+ max_length=max_length,
362
+ return_tensors="pt",
363
+ )
364
+
365
+ # Get labels
366
+ labels = []
367
+ for item in batch:
368
+ label = label2id[item["answer"]]
369
+ labels.append(label)
370
+
371
+ # Create labels tensor
372
+ labels_tensor = torch.tensor(labels, dtype=torch.long)
373
+
374
+ tokenized_batch = {
375
+ "ref_ids": tokenized_ref.input_ids,
376
+ "ref_attention_mask": tokenized_ref.attention_mask,
377
+ "alt_ids": tokenized_alt.input_ids,
378
+ "alt_attention_mask": tokenized_alt.attention_mask,
379
+ "labels": labels_tensor,
380
+ }
381
+
382
+ return tokenized_batch
BioReason-main/bioreason/dataset/utils.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from datasets import Dataset as HFDataset
2
+ from torch.utils.data import Dataset as TorchDataset
3
+ from typing import Dict, Any, Union, List
4
+
5
+
6
+ def truncate_dna(
7
+ example: Dict[str, Any], truncate_dna_per_side: int = 1024
8
+ ) -> Dict[str, Any]:
9
+ """
10
+ Truncate DNA sequences by removing a specified number of base pairs from both ends.
11
+ If the sequence is too short, it will return the middle portion.
12
+ """
13
+ for key in ["reference_sequence", "variant_sequence"]:
14
+ sequence = example[key]
15
+ seq_len = len(sequence)
16
+
17
+ if seq_len > 2 * truncate_dna_per_side + 8:
18
+ example[key] = sequence[truncate_dna_per_side:-truncate_dna_per_side]
19
+
20
+ return example
21
+
22
+
23
+ def torch_to_hf_dataset(torch_dataset: TorchDataset) -> HFDataset:
24
+ """
25
+ Convert a PyTorch Dataset to a Hugging Face Dataset.
26
+
27
+ This function takes a PyTorch Dataset and converts it to a Hugging Face Dataset
28
+ by extracting all items and organizing them into a dictionary structure that
29
+ can be used to create a Hugging Face Dataset.
30
+
31
+ Args:
32
+ torch_dataset: A PyTorch Dataset object to be converted
33
+
34
+ Returns:
35
+ A Hugging Face Dataset containing the same data as the input PyTorch Dataset
36
+ """
37
+ # Get first item to determine structure
38
+ if len(torch_dataset) == 0:
39
+ return HFDataset.from_dict({})
40
+
41
+ first_item = torch_dataset[0]
42
+
43
+ # Initialize dictionary based on first item's keys
44
+ data_dict = (
45
+ {k: [] for k in first_item.keys()}
46
+ if isinstance(first_item, dict)
47
+ else {"data": []}
48
+ )
49
+
50
+ # Populate dictionary
51
+ for i in range(len(torch_dataset)):
52
+ item = torch_dataset[i]
53
+ if isinstance(item, dict):
54
+ for k in data_dict:
55
+ data_dict[k].append(item[k])
56
+ else:
57
+ data_dict["data"].append(item)
58
+
59
+ return HFDataset.from_dict(data_dict)
BioReason-main/bioreason/dataset/variant_effect.py ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import os
3
+ import random
4
+ import sys
5
+ import torch
6
+ from torch.utils.data import Dataset, DataLoader
7
+ from typing import Any, Dict, List, Tuple
8
+
9
+ from bioreason.dataset.utils import torch_to_hf_dataset
10
+ from bioreason.models.dl.processing_dl import DLProcessor
11
+ from trl.data_utils import maybe_apply_chat_template
12
+
13
+
14
+ def get_format_variant_effect_function(model_name: str) -> Any:
15
+ """
16
+ Get the appropriate format function for a given model name.
17
+ """
18
+ if model_name.lower() == "llm":
19
+ return format_variant_effect_for_llm
20
+ elif model_name.lower() == "dna-llm":
21
+ return format_variant_effect_for_dna_llm
22
+ else:
23
+ raise ValueError(f"Unsupported model name: {model_name}")
24
+
25
+
26
+ def clean_variant_effect_example(example: Dict[str, Any]) -> Dict[str, Any]:
27
+ """
28
+ Clean a variant effect example.
29
+ """
30
+ example['answer'] = example['answer'].split(";")[0].strip().lower()
31
+ return example
32
+
33
+
34
+ def clean_variant_effect_non_snv_example(example: Dict[str, Any]) -> Dict[str, Any]:
35
+ """
36
+ Clean a variant effect non-SNV example.
37
+ """
38
+ example['answer'] = example['answer'].replace("[", "").replace("]", "").replace("'", "").replace("_", " ").strip()
39
+ return example
40
+
41
+
42
+ def format_variant_effect_for_dna_llm(example: Dict[str, Any]) -> Dict[str, Any]:
43
+ """
44
+ Format a VEP example into the required chat format for DNA-LLM.
45
+ """
46
+ return {
47
+ "prompt": [
48
+ {
49
+ "role": "user",
50
+ "content": [
51
+ *({"type": "dna", "text": None} for _ in range(2)),
52
+ {"type": "text", "text": example["question"].strip()},
53
+ ],
54
+ },
55
+ {
56
+ "role": "assistant",
57
+ "reasoning_content": f"Answer: {example['answer'].strip()}",
58
+ "content": [
59
+ {"type": "text", "text": f"Answer: {example['answer'].strip()}"},
60
+ ],
61
+ },
62
+ ],
63
+ "dna_sequences": [
64
+ example["reference_sequence"],
65
+ example["variant_sequence"],
66
+ ],
67
+ "answer": example["answer"].strip(),
68
+ }
69
+
70
+
71
+ def format_variant_effect_for_llm(example: Dict[str, Any]) -> Dict[str, Any]:
72
+ """
73
+ Format a VEP example into the required chat format for LLM.
74
+ """
75
+ question = f"Reference sequence: {example['reference_sequence']}\nVariant sequence: {example['variant_sequence']}\nQuestion: {example['question']}"
76
+ return {
77
+ "prompt": [
78
+ {
79
+ "role": "user",
80
+ "content": [
81
+ *({"type": "dna", "text": None} for _ in range(2)),
82
+ {"type": "text", "text": question.strip()},
83
+ ],
84
+ },
85
+ {
86
+ "role": "assistant",
87
+ "reasoning_content": f"Answer: {example['answer'].strip()}",
88
+ "content": [
89
+ {"type": "text", "text": f"Answer: {example['answer'].strip()}"},
90
+ ],
91
+ },
92
+ ],
93
+ "dna_sequences": [
94
+ "",
95
+ "",
96
+ ],
97
+ "answer": example["answer"].strip(),
98
+ }
BioReason-main/bioreason/dna_modules/__init__.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ from .dna_module import DNABaseModule
2
+ from .nucleotide_module import NucleotideDNAModule
3
+
4
+ __all__ = ["DNABaseModule", "NucleotideDNAModule"]
BioReason-main/bioreason/dna_modules/dna_module.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from abc import ABC, abstractmethod
2
+ from typing import Dict, Any, Union
3
+ import torch
4
+
5
+ class DNABaseModule(ABC):
6
+ def __init__(self):
7
+ super().__init__()
8
+
9
+ @abstractmethod
10
+ def get_dnallm_key(self):
11
+ pass
12
+
13
+ @abstractmethod
14
+ def get_model_class(self, model_id: str, model_init_kwargs: dict):
15
+ pass
16
+
17
+ def post_model_init(self, model, processing_class):
18
+ pass
19
+
20
+ def is_embeds_input(self):
21
+ return False
22
+
23
+ @abstractmethod
24
+ def get_processing_class(self):
25
+ pass
26
+
27
+ @abstractmethod
28
+ def get_dnallm_modules_keywords(self):
29
+ pass
30
+
31
+ @abstractmethod
32
+ def get_custom_multimodal_keywords(self):
33
+ pass
34
+
35
+ @abstractmethod
36
+ def get_non_generate_params(self):
37
+ pass
38
+
39
+ @abstractmethod
40
+ def get_custom_processing_keywords(self):
41
+ pass
42
+
43
+ @abstractmethod
44
+ def prepare_prompt(self, processing_class, inputs: dict[str, Union[torch.Tensor, Any]]):
45
+ pass
46
+
47
+ @abstractmethod
48
+ def prepare_model_inputs(self, processing_class, prompts_text, images, return_tensors, padding, padding_side, add_special_tokens):
49
+ pass
BioReason-main/bioreason/dna_modules/nucleotide_module.py ADDED
@@ -0,0 +1,263 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import (
2
+ Qwen2_5_VLForConditionalGeneration,
3
+ Qwen2VLForConditionalGeneration,
4
+ AutoProcessor,
5
+ )
6
+ from typing import Dict, Any, Union, List, Optional, Callable, Type
7
+ from trl.data_utils import maybe_apply_chat_template
8
+ from trl import SFTTrainer
9
+ import torch
10
+
11
+ from bioreason.dna_modules.dna_module import DNABaseModule
12
+ from bioreason.models.dna_llm import DNALLMModel
13
+ from bioreason.models.dl.processing_dl import DLProcessor
14
+
15
+
16
+ class NucleotideDNAModule(DNABaseModule):
17
+ """
18
+ DNA module implementation for NucleotideTransformer-based models.
19
+
20
+ This module provides the interface between DNA-LLM models and the training
21
+ infrastructure, handling model loading, processing setup, and reward functions.
22
+ """
23
+
24
+ def __init__(self):
25
+ """Initialize the NucleotideDNAModule."""
26
+ super().__init__()
27
+
28
+ def get_dnallm_key(self) -> str:
29
+ """
30
+ Get the key identifier for this DNA-LLM implementation.
31
+
32
+ Returns:
33
+ String identifier for this module type
34
+ """
35
+ return "qwen"
36
+
37
+ def get_model_class(self, model_id: str, model_init_kwargs: Dict[str, Any]) -> Type:
38
+ """
39
+ Return the appropriate model class based on model ID.
40
+
41
+ Args:
42
+ model_id: Identifier for the model
43
+ model_init_kwargs: Initialization arguments for the model
44
+
45
+ Returns:
46
+ The model class to instantiate
47
+
48
+ Raises:
49
+ ValueError: If the model is not supported
50
+ """
51
+ if "DNALLM" in model_id:
52
+ model_cls = DNALLMModel
53
+ else:
54
+ raise ValueError(f"Unsupported model: {model_id}")
55
+ return model_cls
56
+
57
+ def post_model_init(self, model: Any, processing_class: Any) -> None:
58
+ """
59
+ Perform any post-initialization setup on the model.
60
+
61
+ Args:
62
+ model: The initialized model
63
+ processing_class: The processor for the model
64
+ """
65
+ # No post-init needed for this implementation
66
+ pass
67
+
68
+ def get_processing_class(self) -> Type:
69
+ """
70
+ Get the processing class to use with this DNA-LLM model.
71
+
72
+ Returns:
73
+ The processing class
74
+ """
75
+ return DLProcessor
76
+
77
+ def get_dnallm_modules_keywords(self) -> List[str]:
78
+ """
79
+ Get keywords to identify DNA-specific modules in the model.
80
+
81
+ Used to exclude DNA modules from LoRA adaptation during training.
82
+
83
+ Returns:
84
+ List of keywords that identify DNA modules
85
+ """
86
+ return ["dna"]
87
+
88
+ def get_custom_multimodal_keywords(self) -> List[str]:
89
+ """
90
+ Get keywords for multimodal inputs that should be passed to the model.
91
+
92
+ Returns:
93
+ List of input keywords for multimodal processing
94
+ """
95
+ return ["dna_tokenized", "batch_idx_map"]
96
+
97
+ def get_non_generate_params(self) -> List[str]:
98
+ """
99
+ Get parameter names that should be excluded from generation.
100
+
101
+ Returns:
102
+ List of parameter names to exclude from generation calls
103
+ """
104
+ return []
105
+
106
+ def get_custom_processing_keywords(self) -> List[tuple]:
107
+ """
108
+ Get custom processing keywords for the processor.
109
+
110
+ Returns:
111
+ List of (component, parameter) tuples for custom processing
112
+ """
113
+ return [("dna_tokenizer", "max_length")]
114
+
115
+ def prepare_prompt(
116
+ self, processing_class: Any, inputs: List[Dict[str, Union[torch.Tensor, Any]]]
117
+ ) -> List[str]:
118
+ """
119
+ Prepare prompts from input examples.
120
+
121
+ Args:
122
+ processing_class: The processor to use
123
+ inputs: List of input examples
124
+
125
+ Returns:
126
+ List of prepared prompts
127
+ """
128
+ prompts_text = [
129
+ maybe_apply_chat_template(example, processing_class)["prompt"]
130
+ for example in inputs
131
+ ]
132
+ return prompts_text
133
+
134
+ def prepare_model_inputs(
135
+ self,
136
+ processing_class: Any,
137
+ model: Any,
138
+ prompts_text: List[str],
139
+ batch_dna_sequences: List[List[str]],
140
+ return_tensors: str = "pt",
141
+ padding: bool = True,
142
+ padding_side: str = "left",
143
+ add_special_tokens: bool = False,
144
+ ) -> Dict[str, Any]:
145
+ """
146
+ Prepare inputs for the model.
147
+
148
+ Args:
149
+ processing_class: The processor to use
150
+ model: The model to prepare inputs for
151
+ prompts_text: List of text prompts
152
+ batch_dna_sequences: List of lists of DNA sequences
153
+ return_tensors: Return format for tensors
154
+ padding: Whether to pad inputs
155
+ padding_side: Side to pad on
156
+ add_special_tokens: Whether to add special tokens
157
+
158
+ Returns:
159
+ Processed inputs for the model
160
+ """
161
+ # Handle DataParallel wrapped models by accessing the module attribute if needed
162
+ max_length_text = model.max_length_text if not hasattr(model, 'module') else model.module.max_length_text
163
+ max_length_dna = model.max_length_dna if not hasattr(model, 'module') else model.module.max_length_dna
164
+
165
+ prompt_inputs = processing_class(
166
+ text=prompts_text,
167
+ batch_dna_sequences=batch_dna_sequences,
168
+ return_tensors=return_tensors,
169
+ padding=padding,
170
+ padding_side=padding_side,
171
+ add_special_tokens=add_special_tokens,
172
+ max_length_text=max_length_text,
173
+ max_length_dna=max_length_dna,
174
+ )
175
+
176
+ return prompt_inputs
177
+
178
+ def is_embeds_input(self) -> bool:
179
+ """
180
+ Whether the model uses embeddings as input (instead of token IDs).
181
+
182
+ Returns:
183
+ Boolean indicating if the model takes embedding inputs
184
+ """
185
+ return True
186
+
187
+ @staticmethod
188
+ def get_question_template() -> str:
189
+ """
190
+ Get the template for formatting questions.
191
+
192
+ Returns:
193
+ String template for questions
194
+ """
195
+ return "{Question}"
196
+
197
+ @staticmethod
198
+ def format_reward_rec(completions: List[Dict[str, Any]], **kwargs) -> List[float]:
199
+ """
200
+ Check if the Qwen model output matches a specific format.
201
+
202
+ Args:
203
+ completions: List of model completions
204
+ **kwargs: Additional arguments
205
+
206
+ Returns:
207
+ List of reward scores (1.0 for match, 0.0 for no match)
208
+ """
209
+ import re
210
+ import os
211
+ from datetime import datetime
212
+
213
+ # Pattern to match the expected output format
214
+ pattern = r"<think>.*?</think>\s*<answer>.*?\{.*\[\d+,\s*\d+,\s*\d+,\s*\d+\].*\}.*?</answer>"
215
+ completion_contents = [completion[0]["content"] for completion in completions]
216
+ matches = [
217
+ re.search(pattern, content, re.DOTALL) is not None
218
+ for content in completion_contents
219
+ ]
220
+
221
+ # Log format results if in debug mode
222
+ current_time = datetime.now().strftime("%d-%H-%M-%S-%f")
223
+ if os.getenv("DEBUG_MODE") == "true":
224
+ log_path = os.getenv("LOG_PATH")
225
+ with open(
226
+ log_path.replace(".txt", "_format.txt"), "a", encoding="utf-8"
227
+ ) as f:
228
+ f.write(f"------------- {current_time} Format reward -------------\n")
229
+ for content, match in zip(completion_contents, matches):
230
+ f.write(f"Content: {content}\n")
231
+ f.write(f"Has format: {bool(match)}\n")
232
+
233
+ return [1.0 if match else 0.0 for match in matches]
234
+
235
+ @staticmethod
236
+ def select_reward_func(func: str, task_type: str) -> Callable:
237
+ """
238
+ Select the appropriate reward function based on function name and task type.
239
+
240
+ Args:
241
+ func: The type of reward function ('accuracy', 'format', etc.)
242
+ task_type: The type of task ('rec', etc.)
243
+
244
+ Returns:
245
+ The reward function to use
246
+
247
+ Raises:
248
+ ValueError: If the function or task type is not supported
249
+ """
250
+ if func == "accuracy":
251
+ match task_type:
252
+ case "rec":
253
+ return NucleotideDNAModule.iou_reward
254
+ case _:
255
+ raise ValueError(f"Unsupported reward function: {func}")
256
+ elif func == "format":
257
+ match task_type:
258
+ case "rec":
259
+ return NucleotideDNAModule.format_reward_rec
260
+ case _:
261
+ raise ValueError(f"Unsupported reward function: {func}")
262
+ else:
263
+ raise ValueError(f"Unsupported reward function: {func}")
BioReason-main/bioreason/models/__init__.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ from .dna_only import DNAClassifierModel
2
+ from .dna_llm import DNALLMModel
3
+ from .evo2_tokenizer import Evo2Tokenizer
4
+
5
+ __all__ = [
6
+ "DNAClassifierModel",
7
+ "DNALLMModel",
8
+ "Evo2Tokenizer",
9
+ ]
BioReason-main/bioreason/models/dl/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+
BioReason-main/bioreason/models/dl/chat_template_dl.py ADDED
@@ -0,0 +1 @@
 
 
1
+ CHAT_TEMPLATE = "{%- set dna_count = namespace(value=0) %}{%- if tools %}\n {{- '<|im_start|>system\\n' }}\n {%- if messages[0].role == 'system' %}\n {{- messages[0].content + '\\n\\n' }}\n {%- endif %}\n {{- \"# Tools\\n\\nYou may call one or more functions to assist with the user query.\\n\\nYou are provided with function signatures within <tools></tools> XML tags:\\n<tools>\" }}\n {%- for tool in tools %}\n {{- \"\\n\" }}\n {{- tool | tojson }}\n {%- endfor %}\n {{- \"\\n</tools>\\n\\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\\n<tool_call>\\n{\\\"name\\\": <function-name>, \\\"arguments\\\": <args-json-object>}\\n</tool_call><|im_end|>\\n\" }}\n{%- else %}\n {%- if messages[0].role == 'system' %}\n {{- '<|im_start|>system\\n' + messages[0].content + '<|im_end|>\\n' }}\n {%- endif %}\n{%- endif %}\n{%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %}\n{%- for message in messages[::-1] %}\n {%- set index = (messages|length - 1) - loop.index0 %}\n {%- if ns.multi_step_tool and message.role == \"user\" and not(message.content is string and message.content.startswith('<tool_response>') and message.content.endswith('</tool_response>')) %}\n {%- set ns.multi_step_tool = false %}\n {%- set ns.last_query_index = index %}\n {%- endif %}\n{%- endfor %}\n{%- for message in messages %}\n {%- if (message.role == \"user\") or (message.role == \"system\" and not loop.first) %}\n {{- '<|im_start|>' + message.role + '\\n' }} {%- if message.content is string %}{{- message.content + '<|im_end|>' + '\\n' }}{%- else %}{%- for content in message.content %}{%- if content.type == 'dna' or 'dna' in content %}{%- set dna_count.value = dna_count.value + 1 %}{%- if add_dna_id %}DNA Sequence {{- dna_count.value }}: {%- endif %}<|dna_start|><|dna_pad|><|dna_end|>{%- elif 'text' in content %}{{- content.text }}{%- endif %}{%- endfor %}{{- '<|im_end|>' + '\\n' }}{%- endif %}{%- elif message.role == \"assistant\" %}\n {%- set content = message.content[0].text %}\n {%- set reasoning_content = '' %}\n {%- if message.reasoning_content is defined and message.reasoning_content is not none %}\n {%- set reasoning_content = message.reasoning_content %}\n {%- else %}\n {%- if '</think>' in message.content %}\n {%- set content = message.content[0].text.split('</think>')[-1].lstrip('\\n') %}\n {%- set reasoning_content = message.content[0].text.split('</think>')[0].rstrip('\\n').split('<think>')[-1].lstrip('\\n') %}\n {%- endif %}\n {%- endif %}\n {%- if loop.index0 > ns.last_query_index %}\n {%- if loop.last or (not loop.last and reasoning_content) %}\n {{- '<|im_start|>' + message.role + '\\n<think>\\n' + reasoning_content.strip('\\n') + '\\n</think>\\n\\n' + content.lstrip('\\n') }}\n {%- else %}\n {{- '<|im_start|>' + message.role + '\\n' + content }}\n {%- endif %}\n {%- else %}\n {{- '<|im_start|>' + message.role + '\\n' + content }}\n {%- endif %}\n {%- if message.tool_calls %}\n {%- for tool_call in message.tool_calls %}\n {%- if (loop.first and content) or (not loop.first) %}\n {{- '\\n' }}\n {%- endif %}\n {%- if tool_call.function %}\n {%- set tool_call = tool_call.function %}\n {%- endif %}\n {{- '<tool_call>\\n{\"name\": \"' }}\n {{- tool_call.name }}\n {{- '\", \"arguments\": ' }}\n {%- if tool_call.arguments is string %}\n {{- tool_call.arguments }}\n {%- else %}\n {{- tool_call.arguments | tojson }}\n {%- endif %}\n {{- '}\\n</tool_call>' }}\n {%- endfor %}\n {%- endif %}\n {{- '<|im_end|>\\n' }}\n {%- elif message.role == \"tool\" %}\n {%- if loop.first or (messages[loop.index0 - 1].role != \"tool\") %}\n {{- '<|im_start|>user' }}\n {%- endif %}\n {{- '\\n<tool_response>\\n' }}\n {{- message.content }}\n {{- '\\n</tool_response>' }}\n {%- if loop.last or (messages[loop.index0 + 1].role != \"tool\") %}\n {{- '<|im_end|>\\n' }}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|im_start|>assistant\\n' }}\n {%- if enable_thinking is defined and enable_thinking is false %}\n {{- '<think>\\n\\n</think>\\n\\n' }}\n {%- endif %}\n{%- endif %}"
BioReason-main/bioreason/models/dl/configuration_dl.py ADDED
@@ -0,0 +1,232 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import PretrainedConfig
2
+
3
+ class DLDNAConfig(PretrainedConfig):
4
+ model_type = "dl"
5
+ base_config_key = "dna_config"
6
+
7
+ def __init__(
8
+ self,
9
+ depth=32,
10
+ hidden_size=3584,
11
+ hidden_act="silu",
12
+ intermediate_size=3420,
13
+ num_heads=16,
14
+ in_channels=3,
15
+ patch_size=14,
16
+ spatial_merge_size=2,
17
+ temporal_patch_size=2,
18
+ tokens_per_second=4,
19
+ window_size=112,
20
+ out_hidden_size=3584,
21
+ fullatt_block_indexes=[7, 15, 23, 31],
22
+ **kwargs,
23
+ ):
24
+ super().__init__(**kwargs)
25
+
26
+ self.depth = depth
27
+ self.hidden_size = hidden_size
28
+ self.hidden_act = hidden_act
29
+ self.intermediate_size = intermediate_size
30
+ self.num_heads = num_heads
31
+ self.in_channels = in_channels
32
+ self.patch_size = patch_size
33
+ self.spatial_merge_size = spatial_merge_size
34
+ self.temporal_patch_size = temporal_patch_size
35
+ self.tokens_per_second = tokens_per_second
36
+ self.window_size = window_size
37
+ self.fullatt_block_indexes = fullatt_block_indexes
38
+ self.out_hidden_size = out_hidden_size
39
+
40
+ class DLConfig(PretrainedConfig):
41
+ r"""
42
+ This is the configuration class to store the configuration of a [`Qwen2_5_VLModel`]. It is used to instantiate a
43
+ Qwen2-VL model according to the specified arguments, defining the model architecture. Instantiating a configuration
44
+ with the defaults will yield a similar configuration to that of
45
+ Qwen2-VL-7B-Instruct [Qwen/Qwen2-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct).
46
+
47
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
48
+ documentation from [`PretrainedConfig`] for more information.
49
+
50
+
51
+ Args:
52
+ vocab_size (`int`, *optional*, defaults to 152064):
53
+ Vocabulary size of the Qwen2_5_VL model. Defines the number of different tokens that can be represented by the
54
+ `inputs_ids` passed when calling [`Qwen2_5_VLModel`]
55
+ hidden_size (`int`, *optional*, defaults to 8192):
56
+ Dimension of the hidden representations.
57
+ intermediate_size (`int`, *optional*, defaults to 29568):
58
+ Dimension of the MLP representations.
59
+ num_hidden_layers (`int`, *optional*, defaults to 80):
60
+ Number of hidden layers in the Transformer encoder.
61
+ num_attention_heads (`int`, *optional*, defaults to 64):
62
+ Number of attention heads for each attention layer in the Transformer encoder.
63
+ num_key_value_heads (`int`, *optional*, defaults to 8):
64
+ This is the number of key_value heads that should be used to implement Grouped Query Attention. If
65
+ `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
66
+ `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
67
+ converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
68
+ by meanpooling all the original heads within that group. For more details checkout [this
69
+ paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `32`.
70
+ hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
71
+ The non-linear activation function (function or string) in the decoder.
72
+ max_position_embeddings (`int`, *optional*, defaults to 32768):
73
+ The maximum sequence length that this model might ever be used with.
74
+ initializer_range (`float`, *optional*, defaults to 0.02):
75
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
76
+ rms_norm_eps (`float`, *optional*, defaults to 1e-05):
77
+ The epsilon used by the rms normalization layers.
78
+ use_cache (`bool`, *optional*, defaults to `True`):
79
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
80
+ relevant if `config.is_decoder=True`.
81
+ tie_word_embeddings (`bool`, *optional*, defaults to `False`):
82
+ Whether the model's input and output word embeddings should be tied.
83
+ rope_theta (`float`, *optional*, defaults to 1000000.0):
84
+ The base period of the RoPE embeddings.
85
+ use_sliding_window (`bool`, *optional*, defaults to `False`):
86
+ Whether to use sliding window attention.
87
+ sliding_window (`int`, *optional*, defaults to 4096):
88
+ Sliding window attention (SWA) window size. If not specified, will default to `4096`.
89
+ max_window_layers (`int`, *optional*, defaults to 80):
90
+ The number of layers that use SWA (Sliding Window Attention). The bottom layers use SWA while the top use full attention.
91
+ attention_dropout (`float`, *optional*, defaults to 0.0):
92
+ The dropout ratio for the attention probabilities.
93
+ vision_config (`Dict`, *optional*):
94
+ The config for the visual encoder initialization.
95
+ rope_scaling (`Dict`, *optional*):
96
+ Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
97
+ and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
98
+ accordingly.
99
+ Expected contents:
100
+ `rope_type` (`str`):
101
+ The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
102
+ 'llama3'], with 'default' being the original RoPE implementation.
103
+ `factor` (`float`, *optional*):
104
+ Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
105
+ most scaling types, a `factor` of x will enable the model to handle sequences of length x *
106
+ original maximum pre-trained length.
107
+ `original_max_position_embeddings` (`int`, *optional*):
108
+ Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
109
+ pretraining.
110
+ `attention_factor` (`float`, *optional*):
111
+ Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
112
+ computation. If unspecified, it defaults to value recommended by the implementation, using the
113
+ `factor` field to infer the suggested value.
114
+ `beta_fast` (`float`, *optional*):
115
+ Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
116
+ ramp function. If unspecified, it defaults to 32.
117
+ `beta_slow` (`float`, *optional*):
118
+ Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
119
+ ramp function. If unspecified, it defaults to 1.
120
+ `short_factor` (`List[float]`, *optional*):
121
+ Only used with 'longrope'. The scaling factor to be applied to short contexts (<
122
+ `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
123
+ size divided by the number of attention heads divided by 2
124
+ `long_factor` (`List[float]`, *optional*):
125
+ Only used with 'longrope'. The scaling factor to be applied to long contexts (<
126
+ `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
127
+ size divided by the number of attention heads divided by 2
128
+ `low_freq_factor` (`float`, *optional*):
129
+ Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
130
+ `high_freq_factor` (`float`, *optional*):
131
+ Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
132
+
133
+ ```python
134
+ >>> from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2_5_VLConfig
135
+
136
+ >>> # Initializing a Qwen2_5_VL style configuration
137
+ >>> configuration = Qwen2_5_VLConfig()
138
+
139
+ >>> # Initializing a model from the Qwen2-VL-7B style configuration
140
+ >>> model = Qwen2_5_VLForConditionalGeneration(configuration)
141
+
142
+ >>> # Accessing the model configuration
143
+ >>> configuration = model.config
144
+ ```"""
145
+
146
+ model_type = "dl"
147
+ sub_configs = {"dna_config": DLDNAConfig}
148
+ keys_to_ignore_at_inference = ["past_key_values"]
149
+ # Default tensor parallel plan for base model `Qwen2_5_VL`
150
+ base_model_tp_plan = {
151
+ "layers.*.self_attn.q_proj": "colwise",
152
+ "layers.*.self_attn.k_proj": "colwise",
153
+ "layers.*.self_attn.v_proj": "colwise",
154
+ "layers.*.self_attn.o_proj": "rowwise",
155
+ "layers.*.mlp.gate_proj": "colwise",
156
+ "layers.*.mlp.up_proj": "colwise",
157
+ "layers.*.mlp.down_proj": "rowwise",
158
+ }
159
+ base_model_pp_plan = {
160
+ "embed_tokens": (["input_ids"], ["inputs_embeds"]),
161
+ "layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
162
+ "norm": (["hidden_states"], ["hidden_states"]),
163
+ }
164
+
165
+ def __init__(
166
+ self,
167
+ vocab_size=152064,
168
+ hidden_size=8192,
169
+ intermediate_size=29568,
170
+ num_hidden_layers=80,
171
+ num_attention_heads=64,
172
+ num_key_value_heads=8,
173
+ hidden_act="silu",
174
+ max_position_embeddings=32768,
175
+ initializer_range=0.02,
176
+ rms_norm_eps=1e-05,
177
+ use_cache=True,
178
+ tie_word_embeddings=False,
179
+ rope_theta=1000000.0,
180
+ use_sliding_window=False,
181
+ sliding_window=4096,
182
+ max_window_layers=80,
183
+ attention_dropout=0.0,
184
+ vision_config=None,
185
+ rope_scaling=None,
186
+ image_token_id=None,
187
+ **kwargs,
188
+ ):
189
+ if isinstance(vision_config, dict):
190
+ self.vision_config = self.sub_configs["vision_config"](**vision_config)
191
+ elif vision_config is None:
192
+ self.vision_config = self.sub_configs["vision_config"]()
193
+
194
+ self.vocab_size = vocab_size
195
+ self.max_position_embeddings = max_position_embeddings
196
+ self.hidden_size = hidden_size
197
+ self.intermediate_size = intermediate_size
198
+ self.num_hidden_layers = num_hidden_layers
199
+ self.num_attention_heads = num_attention_heads
200
+ self.use_sliding_window = use_sliding_window
201
+ self.sliding_window = sliding_window
202
+ self.max_window_layers = max_window_layers
203
+
204
+ # for backward compatibility
205
+ if num_key_value_heads is None:
206
+ num_key_value_heads = num_attention_heads
207
+
208
+ self.num_key_value_heads = num_key_value_heads
209
+ self.hidden_act = hidden_act
210
+ self.initializer_range = initializer_range
211
+ self.rms_norm_eps = rms_norm_eps
212
+ self.use_cache = use_cache
213
+ self.rope_theta = rope_theta
214
+ self.attention_dropout = attention_dropout
215
+ self.rope_scaling = rope_scaling
216
+
217
+ self.dna_token_id = image_token_id
218
+
219
+ # Validate the correctness of rotary position embeddings parameters
220
+ # BC: if there is a 'type' field, move it to 'rope_type'.
221
+ # and change type from 'mrope' to 'default' because `mrope` does default RoPE calculations
222
+ # one can set it to "linear"/"dynamic" etc. to have scaled RoPE
223
+ # TODO: @raushan update config in the hub
224
+ if self.rope_scaling is not None and "type" in self.rope_scaling:
225
+ if self.rope_scaling["type"] == "mrope":
226
+ self.rope_scaling["type"] = "default"
227
+ self.rope_scaling["rope_type"] = self.rope_scaling["type"]
228
+ rope_config_validation(self, ignore_keys={"mrope_section"})
229
+
230
+ super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
231
+
232
+ __all__ = ["DLConfig"]
BioReason-main/bioreason/models/dl/processing_dl.py ADDED
@@ -0,0 +1,275 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List, Optional, Union, Dict, Any, Tuple
2
+
3
+ import torch
4
+ from torch import nn
5
+ import torch.nn.functional as F
6
+
7
+ from transformers import AutoTokenizer
8
+ from transformers.processing_utils import (
9
+ CommonKwargs,
10
+ ProcessingKwargs,
11
+ ProcessorMixin,
12
+ Unpack,
13
+ )
14
+ from transformers.feature_extraction_utils import BatchFeature
15
+ from transformers.tokenization_utils_base import PreTokenizedInput, TextInput
16
+ from transformers.utils import logging
17
+
18
+ from bioreason.utils.dna_utils import DNAInput
19
+
20
+ class DLDNAKwargs(CommonKwargs):
21
+ """Keyword arguments specific to DNA processing"""
22
+ max_length_text: Optional[int]
23
+ max_length_dna: Optional[int]
24
+
25
+
26
+ class DLProcessorKwargs(ProcessingKwargs, total=False):
27
+ """Processing keyword arguments for the DL processor"""
28
+ dna_kwargs: DLDNAKwargs
29
+ _defaults = {
30
+ "text_kwargs": {
31
+ "padding": False,
32
+ },
33
+ }
34
+
35
+ class DLProcessor(ProcessorMixin):
36
+ r"""
37
+ Constructs a DL processor which wraps a NucleotideTransformer DNA processor and a Qwen2_5 tokenizer into a single processor.
38
+ This processor handles both text and DNA sequence processing to prepare inputs for the DNALLMModel.
39
+
40
+ Args:
41
+ tokenizer (PreTrainedTokenizerBase, *optional*):
42
+ The text tokenizer used for processing text inputs.
43
+ dna_tokenizer (PreTrainedTokenizerBase, *optional*):
44
+ The DNA tokenizer used for processing DNA sequences.
45
+ chat_template (`str`, *optional*):
46
+ A Jinja template for chat formatting. If None, will use the tokenizer's template.
47
+ """
48
+
49
+ attributes = ["tokenizer", "dna_tokenizer"]
50
+ valid_kwargs = ["model", "chat_template"]
51
+ tokenizer_class = (
52
+ "Qwen2Tokenizer", "Qwen2TokenizerFast",
53
+ "GPT2TokenizerFast",
54
+ )
55
+ dna_tokenizer_class = ("EsmTokenizer", "Evo2Tokenizer")
56
+
57
+ def __init__(
58
+ self, tokenizer=None, dna_tokenizer=None, chat_template=None, **kwargs
59
+ ):
60
+ """
61
+ Initialize the processor with text and DNA tokenizers.
62
+
63
+ Args:
64
+ tokenizer: Text tokenizer (usually from a language model)
65
+ dna_tokenizer: DNA tokenizer (usually from a DNA model)
66
+ chat_template: Template for formatting chat conversations
67
+ **kwargs: Additional arguments
68
+ """
69
+ self.tokenizer = tokenizer
70
+ self.dna_tokenizer = dna_tokenizer
71
+
72
+ self.dna_token = (
73
+ "<|dna_pad|>"
74
+ if not hasattr(self.tokenizer, "dna_token")
75
+ else self.tokenizer.dna_token
76
+ )
77
+
78
+ # Get chat template from tokenizer if not provided
79
+ if chat_template is None and hasattr(self.tokenizer, "chat_template"):
80
+ chat_template = self.tokenizer.chat_template
81
+ super().__init__(tokenizer, dna_tokenizer, chat_template=chat_template)
82
+
83
+ # The GRPO trainer might expect this to be set
84
+ if not hasattr(self.tokenizer, 'pad_token') or self.tokenizer.pad_token is None:
85
+ self.tokenizer.pad_token = self.tokenizer.eos_token
86
+
87
+ def tokenize_dna_sequences(
88
+ self,
89
+ batch_dna_sequences: List[List[str]],
90
+ max_length: int = 2048,
91
+ return_tensors: str = "pt",
92
+ device: str = "cuda",
93
+ ) -> Dict[str, Any]:
94
+ """
95
+ Tokenize a batch of DNA sequences.
96
+
97
+ Args:
98
+ batch_dna_sequences: List of lists of DNA sequences per batch item
99
+ max_length: Maximum allowed length for DNA sequences
100
+ return_tensors: Return format for tensors ("pt" for PyTorch)
101
+ device: Device to place tensors on
102
+
103
+ Returns:
104
+ Dict containing:
105
+ - dna_tokenized: The tokenized DNA sequences
106
+ - batch_idx_map: Mapping of which sequences belong to which batch item
107
+ """
108
+ # Create a mapping to track which sequences belong to which batch item
109
+ batch_idx_map = []
110
+ all_sequences = []
111
+
112
+ # Flatten all sequences with batch tracking
113
+ for batch_idx, dna_sequences in enumerate(batch_dna_sequences):
114
+ for seq in dna_sequences:
115
+ all_sequences.append(seq)
116
+ batch_idx_map.append(batch_idx)
117
+
118
+ # If no sequences in the entire batch, return empty dict
119
+ if not all_sequences:
120
+ return {"dna_tokenized": None, "batch_idx_map": []}
121
+
122
+ # Tokenize all sequences at once
123
+ dna_tokenized = self.dna_tokenizer(
124
+ all_sequences,
125
+ padding=True,
126
+ truncation=True,
127
+ max_length=max_length,
128
+ return_tensors=return_tensors,
129
+ return_attention_mask=True,
130
+ )
131
+
132
+ return {"dna_tokenized": dna_tokenized, "batch_idx_map": batch_idx_map}
133
+
134
+ def __call__(
135
+ self,
136
+ batch_dna_sequences: Optional[List[List[str]]] = None,
137
+ text: Optional[
138
+ Union[
139
+ TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]
140
+ ]
141
+ ] = None,
142
+ max_length_text: int = 512,
143
+ max_length_dna: int = 2048,
144
+ return_tensors: str = "pt",
145
+ device: str = "cuda",
146
+ **kwargs: Unpack[DLProcessorKwargs],
147
+ ) -> BatchFeature:
148
+ """
149
+ Process text and DNA sequences for model input.
150
+
151
+ Args:
152
+ batch_dna_sequences: List of lists of DNA sequences per batch item
153
+ text: Input text or list of texts
154
+ max_length_text: Maximum length for text sequences
155
+ max_length_dna: Maximum length for DNA sequences
156
+ return_tensors: Return format for tensors
157
+ device: Device to place tensors on
158
+ **kwargs: Additional processor keyword arguments
159
+
160
+ Returns:
161
+ BatchFeature with tokenized inputs for the model
162
+ """
163
+ output_kwargs = self._merge_kwargs(
164
+ DLProcessorKwargs,
165
+ tokenizer_init_kwargs=self.tokenizer.init_kwargs,
166
+ **kwargs,
167
+ )
168
+
169
+ # Ensure text is a list
170
+ if not isinstance(text, list):
171
+ text = [text]
172
+
173
+ # flattened_dna_sequences = [dna_sequence for dna_sequences in batch_dna_sequences for dna_sequence in dna_sequences]
174
+ dna_inputs = {}
175
+ if batch_dna_sequences is not None:
176
+ # Tokenize DNA sequences
177
+ dna_processing_result = self.tokenize_dna_sequences(
178
+ batch_dna_sequences,
179
+ max_length=max_length_dna,
180
+ return_tensors=return_tensors,
181
+ device=device,
182
+ )
183
+
184
+ # Replace DNA tokens in text if needed
185
+ index = 0
186
+ for i in range(len(text)):
187
+ while self.dna_token in text[i]:
188
+ num_dna_tokens = (dna_processing_result['dna_tokenized']['input_ids'][index] != 1).sum().item()
189
+ text[i] = text[i].replace(
190
+ self.dna_token, "<|placeholder|>" * num_dna_tokens, 1
191
+ )
192
+ index += 1
193
+ text[i] = text[i].replace("<|placeholder|>", self.dna_token)
194
+
195
+
196
+
197
+ # Add batch info to the output
198
+ dna_inputs = {
199
+ # "batch_dna_sequences": batch_dna_sequences,
200
+ "dna_tokenized": dna_processing_result["dna_tokenized"],
201
+ "batch_idx_map": dna_processing_result["batch_idx_map"],
202
+ }
203
+
204
+ # Tokenize text
205
+ text_kwargs = output_kwargs.get("text_kwargs", {})
206
+
207
+ if 'padding' in text_kwargs:
208
+ del text_kwargs['padding']
209
+
210
+ # print("__call__ (processor):", text)
211
+ text_inputs = self.tokenizer(
212
+ text,
213
+ max_length=max_length_text + 2 * max_length_dna,
214
+ return_tensors=return_tensors,
215
+ padding=True,
216
+ truncation=True,
217
+ **text_kwargs,
218
+ )
219
+
220
+ # The BatchFeature should have all required fields for the model's forward pass
221
+ return BatchFeature(data={**text_inputs, **dna_inputs})
222
+
223
+ def batch_decode(self, *args, **kwargs) -> List[str]:
224
+ """
225
+ This method forwards all its arguments to the tokenizer's batch_decode.
226
+
227
+ Returns:
228
+ List of decoded strings
229
+ """
230
+ return self.tokenizer.batch_decode(*args, **kwargs)
231
+
232
+ def decode(self, *args, **kwargs) -> str:
233
+ """
234
+ This method forwards all its arguments to the tokenizer's decode.
235
+
236
+ Returns:
237
+ Decoded string
238
+ """
239
+ return self.tokenizer.decode(*args, **kwargs)
240
+
241
+ def post_process_dna_to_text(
242
+ self,
243
+ generated_outputs: torch.Tensor,
244
+ skip_special_tokens: bool = True,
245
+ **kwargs,
246
+ ) -> List[str]:
247
+ """
248
+ Post-process the model output to decode the text.
249
+
250
+ Args:
251
+ generated_outputs: The token IDs generated by the model
252
+ skip_special_tokens: Whether to skip special tokens in the output
253
+ **kwargs: Additional arguments for the decoder
254
+
255
+ Returns:
256
+ List of decoded strings
257
+ """
258
+ return self.tokenizer.batch_decode(
259
+ generated_outputs,
260
+ skip_special_tokens=skip_special_tokens,
261
+ **kwargs,
262
+ )
263
+
264
+ @property
265
+ def model_input_names(self) -> List[str]:
266
+ """
267
+ Get the input names expected by the model.
268
+
269
+ Returns:
270
+ List of input names
271
+ """
272
+ tokenizer_input_names = self.tokenizer.model_input_names
273
+ dna_input_names = ["dna_tokenized", "batch_idx_map"]
274
+
275
+ return list(dict.fromkeys(tokenizer_input_names + dna_input_names))
BioReason-main/bioreason/models/dna_llm.py ADDED
@@ -0,0 +1,306 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from argparse import ArgumentParser
3
+ import torch
4
+ import torch.nn as nn
5
+ from transformers import (
6
+ AutoTokenizer,
7
+ AutoModelForCausalLM,
8
+ AutoModelForMaskedLM,
9
+ )
10
+
11
+ from typing import Optional, List, Dict, Any, Union, Tuple
12
+
13
+ from bioreason.utils.dna_utils import DNAInput
14
+ from bioreason.models.dl.processing_dl import DLProcessor
15
+ from bioreason.models.dl.chat_template_dl import CHAT_TEMPLATE
16
+ from bioreason.models.evo2_tokenizer import Evo2Tokenizer
17
+
18
+ class DNALLMModel(nn.Module):
19
+ """
20
+ A combined model that processes both DNA sequences and text inputs.
21
+
22
+ The model uses a DNA encoder (like NucleotideTransformer) to extract features from DNA sequences
23
+ and a text model (LLM) to process text inputs and generate responses. The DNA features are
24
+ projected to the text model's embedding space and prepended to the text embeddings.
25
+ """
26
+
27
+ def __init__(
28
+ self,
29
+ text_model_name: str,
30
+ dna_model_name: str,
31
+ cache_dir: Optional[str] = None,
32
+ max_length_dna: int = 2048,
33
+ max_length_text: int = 512,
34
+ text_model_finetune: bool = True,
35
+ dna_model_finetune: bool = True,
36
+ dna_is_evo2: bool = False,
37
+ dna_embedding_layer: str = None
38
+ ):
39
+ """
40
+ Initialize the DNALLMModel.
41
+
42
+ Args:
43
+ text_model_name: Name of the text model to be used.
44
+ dna_model_name: Name of the DNA model to be used.
45
+ cache_dir: Directory to cache the models.
46
+ max_length_dna: Maximum length of DNA sequences. Defaults to 2048.
47
+ max_length_text: Maximum length of text sequences. Defaults to 512.
48
+ text_model_finetune: Whether to finetune the text model. Defaults to True.
49
+ dna_model_finetune: Whether to finetune the DNA model. Defaults to True.
50
+ dna_is_evo2: Whether the DNA model is Evo2. Defaults to False.
51
+ dna_embedding_layer: Name of the layer to use for the Evo2 model. Defaults to None.
52
+ """
53
+ super().__init__()
54
+
55
+ self.text_model_finetune = text_model_finetune
56
+ self.dna_model_finetune = dna_model_finetune
57
+ self.max_length_dna = max_length_dna
58
+ self.max_length_text = max_length_text
59
+ self.dna_is_evo2 = dna_is_evo2
60
+ self.dna_embedding_layer = dna_embedding_layer
61
+
62
+
63
+ # Load the text model and tokenizer
64
+ self.text_model = AutoModelForCausalLM.from_pretrained(
65
+ text_model_name, cache_dir=cache_dir, trust_remote_code=True
66
+ )
67
+ self.text_tokenizer = AutoTokenizer.from_pretrained(text_model_name, trust_remote_code=True)
68
+ self.text_config = self.text_model.config
69
+ self.text_tokenizer.chat_template = CHAT_TEMPLATE
70
+ self.text_tokenizer.pad_token = self.text_tokenizer.eos_token
71
+
72
+ new_tokens = ["<|dna_start|>", "<|dna_pad|>", "<|dna_end|>"]
73
+ self.text_tokenizer.add_special_tokens({"additional_special_tokens": new_tokens})
74
+ self.dna_token_id = self.text_tokenizer.convert_tokens_to_ids("<|dna_pad|>")
75
+
76
+
77
+ # Load the DNA model and tokenizer
78
+ if not self.dna_is_evo2:
79
+ self.dna_model = AutoModelForMaskedLM.from_pretrained(
80
+ dna_model_name, cache_dir=cache_dir, trust_remote_code=True
81
+ )
82
+ self.dna_tokenizer = AutoTokenizer.from_pretrained(dna_model_name, trust_remote_code=True)
83
+ self.dna_config = self.dna_model.config
84
+
85
+ else:
86
+ from evo2 import Evo2
87
+ self.dna_model = Evo2(dna_model_name)
88
+ self.dna_tokenizer = Evo2Tokenizer(self.dna_model.tokenizer)
89
+ self.dna_config = self.dna_model.model.config
90
+ self.dna_embedding_layer = self.dna_embedding_layer
91
+
92
+ # Get model dimensions
93
+ self.text_hidden_size = self.text_config.hidden_size
94
+ self.dna_hidden_size = self.dna_config.hidden_size
95
+
96
+ # Create projection layer to map DNA embeddings to text model's embedding space
97
+ self.dna_projection = nn.Linear(self.dna_hidden_size, self.text_hidden_size)
98
+
99
+ # Create processor for handling inputs
100
+ self.processor = DLProcessor(tokenizer=self.text_tokenizer, dna_tokenizer=self.dna_tokenizer)
101
+
102
+
103
+ def process_dna_embeddings(
104
+ self,
105
+ dna_tokenized: Dict[str, torch.Tensor],
106
+ batch_idx_map: List[int],
107
+ batch_size: int,
108
+ ) -> List[torch.Tensor]:
109
+ """
110
+ Process DNA sequences to obtain embeddings.
111
+
112
+ Args:
113
+ dna_tokenized: Tokenized DNA sequences
114
+ batch_idx_map: Mapping of each sequence to its batch item
115
+ batch_size: Number of items in the batch
116
+
117
+ Returns:
118
+ List of tensor embeddings for each batch item
119
+ """
120
+ # Process all sequences to get DNA representations
121
+ with torch.no_grad():
122
+ # Handle different model types based on dna_is_evo2 attribute
123
+ if self.dna_is_evo2 and self.dna_embedding_layer is not None: # Evo2 model
124
+ # Get embeddings from the specific layer in Evo2
125
+ hidden_states_list = []
126
+
127
+ for seq_idx in range(len(dna_tokenized["input_ids"])):
128
+ # Extract single sequence
129
+ input_ids = dna_tokenized["input_ids"][seq_idx:seq_idx+1]
130
+
131
+ # Call Evo2 with return_embeddings=True
132
+ _, embeddings = self.dna_model(
133
+ input_ids,
134
+ return_embeddings=True,
135
+ layer_names=[self.dna_embedding_layer]
136
+ )
137
+
138
+ # Get embeddings for the specified layer
139
+ seq_embeddings = embeddings[self.dna_embedding_layer].squeeze(0)
140
+ hidden_states_list.append(seq_embeddings)
141
+
142
+ # Stack to get same format as non-Evo2 output
143
+ if hidden_states_list:
144
+ hidden_states = torch.stack(hidden_states_list)
145
+ else:
146
+ return [torch.zeros((0, self.text_hidden_size)) for _ in range(batch_size)]
147
+
148
+ else: # Standard HuggingFace model
149
+ # Use existing code path for HF models
150
+ outputs = self.dna_model(
151
+ input_ids=dna_tokenized["input_ids"],
152
+ attention_mask=dna_tokenized["attention_mask"],
153
+ output_hidden_states=True,
154
+ )
155
+ # Get the last hidden state
156
+ hidden_states = outputs.hidden_states[-1] # shape: [n_seqs, seq_len, hidden_dim]
157
+
158
+ # Project all embeddings at once
159
+ hidden_states = hidden_states.to(device=self.dna_projection.weight.device, dtype=self.dna_projection.weight.dtype)
160
+ projected_states = self.dna_projection(hidden_states)
161
+
162
+ # Group embeddings by batch item
163
+ result = [[] for _ in range(batch_size)]
164
+
165
+ # For each sequence, get its embeddings and add to appropriate batch result
166
+ for seq_idx, batch_idx in enumerate(batch_idx_map):
167
+ # Get only the valid (non-padding) tokens
168
+ valid_length = dna_tokenized["attention_mask"][seq_idx].sum().item()
169
+ seq_embedding = projected_states[seq_idx, :valid_length]
170
+ result[batch_idx].append(seq_embedding)
171
+
172
+ # Concatenate embeddings for each batch item
173
+ for i in range(batch_size):
174
+ if result[i]:
175
+ result[i] = torch.cat(result[i], dim=0)
176
+ else:
177
+ result[i] = torch.zeros((0, self.text_hidden_size))
178
+
179
+ return result
180
+
181
+ def forward(
182
+ self,
183
+ input_ids: Optional[torch.Tensor] = None,
184
+ attention_mask: Optional[torch.Tensor] = None,
185
+ dna_tokenized: Optional[Dict[str, torch.Tensor]] = None,
186
+ batch_idx_map: Optional[List[int]] = None,
187
+ labels: Optional[torch.Tensor] = None,
188
+ **kwargs,
189
+ ) -> torch.Tensor:
190
+ """
191
+ Generate text based on DNA and text inputs.
192
+
193
+ Args:
194
+ input_ids: Input IDs (used if provided directly)
195
+ attention_mask: Attention mask (used if provided directly)
196
+ dna_tokenized: Tokenized DNA sequences (used if provided directly)
197
+ batch_idx_map: Batch mapping for DNA sequences (used if provided directly)
198
+ labels: Labels for supervised fine-tuning (used if provided directly)
199
+ **kwargs: Additional arguments for generation
200
+
201
+ Returns:
202
+ Outputs from the text model
203
+ """
204
+ # Ensure required inputs are available
205
+ if input_ids is None or attention_mask is None:
206
+ raise ValueError("Either 'inputs' or 'input_ids'/'attention_mask' must be provided")
207
+
208
+ batch_size = input_ids.shape[0]
209
+
210
+ # Get text embeddings from the model's embedding layer
211
+ text_inputs_embeds = self.text_model.get_input_embeddings()(input_ids)
212
+
213
+ if dna_tokenized is not None and batch_idx_map:
214
+ batch_dna_embeds = self.process_dna_embeddings(dna_tokenized, batch_idx_map, batch_size)
215
+
216
+ mask = input_ids == self.dna_token_id
217
+
218
+ n_dna_tokens = mask.sum().item()
219
+ dna_embeds_flat = torch.cat(batch_dna_embeds, dim=0)
220
+ n_dna_features = dna_embeds_flat.shape[0]
221
+
222
+ if n_dna_features != n_dna_tokens:
223
+ raise ValueError(
224
+ f"DNA features and DNA tokens do not match: features {n_dna_features}, tokens: {n_dna_tokens}"
225
+ )
226
+
227
+ # Ensure DNA embeddings have the same dtype as the text embeddings
228
+ dna_embeds_flat = dna_embeds_flat.to(dtype=text_inputs_embeds.dtype)
229
+ text_inputs_embeds[mask] = dna_embeds_flat
230
+
231
+ # Handle labels if provided (for training)
232
+ if labels is not None:
233
+ # TODO: Implement this
234
+ pass
235
+
236
+ # Forward pass through the text model (loss is computed if labels is provided)
237
+ outputs = self.text_model(
238
+ inputs_embeds=text_inputs_embeds,
239
+ attention_mask=attention_mask,
240
+ labels=labels,
241
+ **kwargs,
242
+ )
243
+
244
+ return outputs
245
+
246
+ def generate(
247
+ self,
248
+ input_ids: Optional[torch.Tensor] = None,
249
+ attention_mask: Optional[torch.Tensor] = None,
250
+ dna_tokenized: Optional[Dict[str, torch.Tensor]] = None,
251
+ batch_idx_map: Optional[List[int]] = None,
252
+ **generation_kwargs,
253
+ ) -> Union[torch.Tensor, List[str]]:
254
+ """
255
+ Generate text based on DNA and text inputs.
256
+
257
+ Args:
258
+ inputs: The preprocessed inputs from the processor (preferred method)
259
+ batch_dna_sequences: List of lists of DNA sequences per batch item (legacy method)
260
+ input_texts: List of input texts (legacy method)
261
+ input_ids: Input IDs (used if provided directly)
262
+ attention_mask: Attention mask (used if provided directly)
263
+ dna_tokenized: Tokenized DNA sequences (used if provided directly)
264
+ batch_idx_map: Batch mapping for DNA sequences (used if provided directly)
265
+ **generation_kwargs: Additional arguments for generation
266
+
267
+ Returns:
268
+ Generated token IDs which can be decoded using the processor
269
+ """
270
+ # Ensure required inputs are available
271
+ if input_ids is None or attention_mask is None:
272
+ raise ValueError("Either 'inputs' or 'input_ids'/'attention_mask' must be provided")
273
+
274
+ batch_size = input_ids.shape[0]
275
+
276
+ # Get text embeddings from the model's embedding layer
277
+ text_inputs_embeds = self.text_model.get_input_embeddings()(input_ids)
278
+
279
+ if dna_tokenized is not None and batch_idx_map:
280
+ batch_dna_embeds = self.process_dna_embeddings(dna_tokenized, batch_idx_map, batch_size)
281
+
282
+ mask = input_ids == self.dna_token_id
283
+
284
+ n_dna_tokens = mask.sum().item()
285
+ dna_embeds_flat = torch.cat(batch_dna_embeds, dim=0)
286
+ n_dna_features = dna_embeds_flat.shape[0]
287
+
288
+ if n_dna_features != n_dna_tokens:
289
+ raise ValueError(
290
+ f"DNA features and DNA tokens do not match: features {n_dna_features}, tokens: {n_dna_tokens}"
291
+ )
292
+
293
+ # Ensure DNA embeddings have the same dtype as the text embeddings
294
+ dna_embeds_flat = dna_embeds_flat.to(dtype=text_inputs_embeds.dtype)
295
+ text_inputs_embeds[mask] = dna_embeds_flat
296
+
297
+ # Generation parameters may need adjustment based on model type
298
+ with torch.no_grad():
299
+ outputs = self.text_model.generate(
300
+ inputs_embeds=text_inputs_embeds,
301
+ attention_mask=attention_mask,
302
+ use_cache=True,
303
+ **generation_kwargs,
304
+ )
305
+
306
+ return outputs
BioReason-main/bioreason/models/dna_only.py ADDED
@@ -0,0 +1,203 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from typing import Dict
5
+ from transformers import AutoModelForMaskedLM, AutoTokenizer
6
+
7
+
8
+ class SelfAttentionPooling(nn.Module):
9
+ def __init__(self, hidden_size, num_heads=8):
10
+ super().__init__()
11
+ # Use PyTorch's built-in multi-head attention
12
+ self.attention = nn.MultiheadAttention(
13
+ embed_dim=hidden_size,
14
+ num_heads=num_heads,
15
+ batch_first=True
16
+ )
17
+ # Learnable query vector
18
+ self.query = nn.Parameter(torch.randn(1, 1, hidden_size))
19
+
20
+ def forward(self, embeddings, attention_mask=None):
21
+ # Expand query to batch size
22
+ batch_size = embeddings.size(0)
23
+ query = self.query.expand(batch_size, -1, -1)
24
+
25
+ # Create key padding mask from attention mask if provided
26
+ key_padding_mask = None
27
+ if attention_mask is not None:
28
+ key_padding_mask = attention_mask == 0 # Convert to boolean mask where True means ignore
29
+
30
+ # Apply attention: query attends to embeddings
31
+ context, _ = self.attention(
32
+ query=query, # [batch_size, 1, hidden_size]
33
+ key=embeddings, # [batch_size, seq_len, hidden_size]
34
+ value=embeddings, # [batch_size, seq_len, hidden_size]
35
+ key_padding_mask=key_padding_mask
36
+ )
37
+
38
+ # Squeeze out the singleton dimension
39
+ return context.squeeze(1) # [batch_size, hidden_size]
40
+
41
+
42
+ class DNAClassifierModel(nn.Module):
43
+ """
44
+ A simple classifier that uses a DNA model with a classification head.
45
+ """
46
+
47
+ def __init__(
48
+ self,
49
+ dna_model_name: str,
50
+ cache_dir: str = None,
51
+ max_length_dna: int = 4096,
52
+ num_classes: int = 2, # Binary classification by default
53
+ dna_is_evo2: bool = False,
54
+ dna_embedding_layer: str = None,
55
+ train_just_classifier: bool = True
56
+ ):
57
+ """
58
+ Initialize the DNAClassifierModel.
59
+
60
+ Args:
61
+ dna_model_name (str): Name of the DNA model to use
62
+ cache_dir (str): Directory to cache models
63
+ max_length_dna (int): Maximum sequence length
64
+ num_classes (int): Number of output classes
65
+ dna_is_evo2: Whether the DNA model is Evo2. Defaults to False
66
+ dna_embedding_layer: Name of the layer to use for the Evo2 model. Defaults to None
67
+ train_just_classifier: Whether to train just the classifier. Defaults to True
68
+ """
69
+ super().__init__()
70
+
71
+ self.dna_model_name = dna_model_name
72
+ self.cache_dir = cache_dir
73
+ self.max_length_dna = max_length_dna
74
+ self.num_classes = num_classes
75
+ self.dna_is_evo2 = dna_is_evo2
76
+ self.dna_embedding_layer = dna_embedding_layer
77
+ self.train_just_classifier = train_just_classifier
78
+
79
+ # Load the DNA model and tokenizer
80
+ if not self.dna_is_evo2:
81
+ self.dna_model = AutoModelForMaskedLM.from_pretrained(
82
+ dna_model_name, cache_dir=cache_dir, trust_remote_code=True
83
+ )
84
+ self.dna_tokenizer = AutoTokenizer.from_pretrained(dna_model_name, trust_remote_code=True)
85
+ self.dna_config = self.dna_model.config
86
+
87
+ else:
88
+ from evo2 import Evo2
89
+ from bioreason.models.evo2_tokenizer import Evo2Tokenizer
90
+ self.dna_model = Evo2(dna_model_name)
91
+ self.dna_tokenizer = Evo2Tokenizer(self.dna_model.tokenizer)
92
+ self.dna_config = self.dna_model.model.config
93
+ self.dna_embedding_layer = self.dna_embedding_layer
94
+
95
+ # Get hidden size from model config
96
+ self.hidden_size = self.dna_config.hidden_size
97
+
98
+ # Add the self-attention pooling module
99
+ self.pooler = SelfAttentionPooling(self.hidden_size)
100
+
101
+ # Create classification head that takes concatenated embeddings from both sequences
102
+ self.classifier = nn.Sequential(
103
+ nn.Linear(self.hidden_size * 2, self.hidden_size),
104
+ nn.ReLU(),
105
+ nn.Dropout(0.1),
106
+ nn.Linear(self.hidden_size, num_classes),
107
+ )
108
+
109
+ self.max_length_dna = max_length_dna
110
+
111
+ def get_dna_embedding(self, input_ids: torch.Tensor, attention_mask: torch.Tensor):
112
+ """
113
+ Get DNA embedding for a single DNA sequence using self-attention pooling.
114
+
115
+ Args:
116
+ input_ids: DNA tokenized sequence
117
+ attention_mask: DNA tokenized sequence attention mask
118
+
119
+ Returns:
120
+ torch.Tensor: Tensor containing the self-attention pooled DNA embedding
121
+ """
122
+ # Add batch dimension if not present
123
+ if input_ids.dim() == 1:
124
+ input_ids = input_ids.unsqueeze(0) # [1, seq_len]
125
+
126
+ # Handle attention mask - create if not provided or add batch dimension
127
+ if attention_mask is None:
128
+ attention_mask = torch.ones_like(input_ids)
129
+ elif attention_mask.dim() == 1:
130
+ attention_mask = attention_mask.unsqueeze(0) # [1, seq_len]
131
+
132
+ # Get embeddings from DNA model
133
+ with torch.set_grad_enabled(not self.train_just_classifier): # Enable gradients for fine-tuning
134
+
135
+ if self.dna_is_evo2 and self.dna_embedding_layer is not None: # Evo2 model
136
+ # Get embeddings from the specific layer in Evo2
137
+ _, embeddings = self.dna_model(
138
+ input_ids,
139
+ return_embeddings=True,
140
+ layer_names=[self.dna_embedding_layer]
141
+ )
142
+
143
+ # Get embeddings for the specified layer
144
+ hidden_states = embeddings[self.dna_embedding_layer]
145
+
146
+ else:
147
+ # Get embeddings from the last hidden state
148
+ outputs = self.dna_model(
149
+ input_ids,
150
+ attention_mask=attention_mask,
151
+ output_hidden_states=True,
152
+ )
153
+
154
+ # Get the last hidden state
155
+ hidden_states = outputs.hidden_states[-1]
156
+
157
+ # Apply self-attention pooling to get a weighted representation
158
+ sequence_embedding = self.pooler(hidden_states, attention_mask)
159
+ return sequence_embedding.squeeze(0)
160
+
161
+ def forward(
162
+ self, ref_ids=None, alt_ids=None, ref_attention_mask=None, alt_attention_mask=None
163
+ ):
164
+ """
165
+ Forward pass of the model.
166
+
167
+ Args:
168
+ ref_ids: Reference sequence token IDsself.dna_model
169
+ alt_ids: Alternate sequence token IDsself.dna_model
170
+ ref_attention_mask: Reference sequence attention maskself.dna_model
171
+ alt_attention_mask: Alternate sequence attention maskself.dna_model
172
+
173
+ Returns:
174
+ torch.Tensor: Classification logits
175
+ """
176
+ batch_size = ref_ids.shape[0] if ref_ids is not None else alt_ids.shape[0]
177
+
178
+ if batch_size is None:
179
+ raise ValueError("Either token IDs must be provided")
180
+
181
+ ref_embeddings = []
182
+ alt_embeddings = []
183
+
184
+ # Process each example in the batch
185
+ for i in range(batch_size):
186
+
187
+ # Get sequence embeddings
188
+ ref_embed = self.get_dna_embedding(ref_ids[i], ref_attention_mask[i])
189
+ alt_embed = self.get_dna_embedding(alt_ids[i], alt_attention_mask[i])
190
+ ref_embeddings.append(ref_embed)
191
+ alt_embeddings.append(alt_embed)
192
+
193
+ # Stack embeddings
194
+ ref_embeddings = torch.stack(ref_embeddings)
195
+ alt_embeddings = torch.stack(alt_embeddings)
196
+
197
+ # Concatenate ref and alt embeddings
198
+ combined_embeddings = torch.cat([ref_embeddings, alt_embeddings], dim=1)
199
+
200
+ # Pass through classifier
201
+ logits = self.classifier(combined_embeddings)
202
+
203
+ return logits
BioReason-main/bioreason/models/evo2_tokenizer.py ADDED
@@ -0,0 +1,219 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers.tokenization_utils import PreTrainedTokenizer
2
+ from transformers.utils import logging
3
+ from transformers import AutoTokenizer
4
+ from transformers.tokenization_utils_base import BatchEncoding
5
+ import torch
6
+ import numpy as np
7
+ from typing import List, Dict, Optional, Union, Tuple
8
+
9
+ # Register the tokenizer with AutoTokenizer
10
+ from transformers.models.auto import AutoTokenizer
11
+ from transformers.models.auto.tokenization_auto import TOKENIZER_MAPPING
12
+ from transformers.models.auto.configuration_auto import CONFIG_MAPPING
13
+
14
+ logger = logging.get_logger(__name__)
15
+
16
+ class Evo2Tokenizer(PreTrainedTokenizer):
17
+ """
18
+ Tokenizer for Evo2 models - wraps the CharLevelTokenizer to be compatible with HuggingFace.
19
+ """
20
+ vocab_files_names = {} # No vocab files needed
21
+ model_input_names = ["input_ids", "attention_mask"]
22
+
23
+ def __init__(
24
+ self,
25
+ evo2_tokenizer,
26
+ bos_token="<s>",
27
+ eos_token="</s>",
28
+ pad_token="<pad>",
29
+ unk_token="<unk>",
30
+ **kwargs
31
+ ):
32
+ """
33
+ Initialize the Evo2Tokenizer.
34
+
35
+ Args:
36
+ evo2_tokenizer: The Evo2 CharLevelTokenizer to wrap
37
+ bos_token: Beginning of sequence token
38
+ eos_token: End of sequence token
39
+ pad_token: Padding token
40
+ unk_token: Unknown token
41
+ """
42
+ self.evo2_tokenizer = evo2_tokenizer
43
+
44
+ # Map special tokens to Evo2 tokenizer's special token IDs
45
+ self._pad_token = pad_token
46
+ self._eos_token = eos_token
47
+ self._bos_token = bos_token
48
+ self._unk_token = unk_token
49
+
50
+ # Initialize with special tokens
51
+ super().__init__(
52
+ bos_token=bos_token,
53
+ eos_token=eos_token,
54
+ pad_token=pad_token,
55
+ unk_token=unk_token,
56
+ **kwargs
57
+ )
58
+
59
+ # Set token IDs from Evo2 tokenizer
60
+ self.pad_token_id = self.evo2_tokenizer.pad_id
61
+ self.eos_token_id = self.evo2_tokenizer.eos_id
62
+
63
+ @property
64
+ def vocab_size(self) -> int:
65
+ """Return the vocab size of the tokenizer."""
66
+ return self.evo2_tokenizer.vocab_size
67
+
68
+ def get_vocab(self) -> Dict:
69
+ """Return vocab as a dictionary."""
70
+ # Evo2 CharLevelTokenizer doesn't have a traditional vocab dict
71
+ # Create a simple mapping of ASCII codes to tokens
72
+ return {chr(i): i for i in range(self.vocab_size)}
73
+
74
+ def _tokenize(self, text: str) -> List[int]:
75
+ """Tokenize a string using the Evo2 tokenizer."""
76
+ return [chr(int(token)) for token in self.evo2_tokenizer.tokenize(text)]
77
+
78
+ def _convert_token_to_id(self, token: str) -> int:
79
+ """Convert a token to an id using the Evo2 tokenizer."""
80
+ # Since tokens are just characters, convert to their ASCII value
81
+ return ord(token)
82
+
83
+ def _convert_id_to_token(self, index: int) -> str:
84
+ """Convert an id to a token using the Evo2 tokenizer."""
85
+ # Convert ASCII value back to character
86
+ return chr(index)
87
+
88
+ def convert_tokens_to_string(self, tokens: List[str]) -> str:
89
+ """Convert a sequence of tokens to a single string."""
90
+ return "".join(tokens)
91
+
92
+ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
93
+ """No vocabulary to save for Evo2Tokenizer, so just return an empty tuple."""
94
+ return ()
95
+
96
+ def __call__(
97
+ self,
98
+ text: Union[str, List[str]],
99
+ text_pair: Optional[Union[str, List[str]]] = None,
100
+ padding: Union[bool, str] = False,
101
+ truncation: Union[bool, str] = False,
102
+ max_length: Optional[int] = None,
103
+ return_tensors: Optional[str] = None,
104
+ return_token_type_ids: Optional[bool] = None,
105
+ return_attention_mask: Optional[bool] = True,
106
+ **kwargs
107
+ ) -> Dict[str, torch.Tensor]:
108
+ """
109
+ Main tokenization method that handles batching and converts to tensors.
110
+ """
111
+ # Handle single string vs list of strings
112
+ if isinstance(text, str):
113
+ text = [text]
114
+
115
+ # Tokenize all sequences - note: tokenizer only accepts strings, not lists
116
+ input_ids_list = []
117
+ for seq in text:
118
+ # Tokenize and convert numpy.uint8 to Python integers
119
+ tokens = [int(token) for token in self.evo2_tokenizer.tokenize(seq)]
120
+
121
+ # Truncate if needed
122
+ if truncation and max_length and len(tokens) > max_length:
123
+ tokens = tokens[:max_length]
124
+
125
+ input_ids_list.append(tokens)
126
+
127
+ # Apply padding if needed
128
+ if padding:
129
+ if False:#max_length:
130
+ max_len = max_length
131
+ else:
132
+ max_len = max(len(ids) for ids in input_ids_list)
133
+
134
+ # Create padded sequences and attention masks
135
+ padded_input_ids = []
136
+ attention_mask = []
137
+
138
+ for ids in input_ids_list:
139
+ # Apply left padding (pad on the left)
140
+ padding_length = max_len - len(ids)
141
+ padded_ids = [self.pad_token_id] * padding_length + ids
142
+ mask = [0] * padding_length + [1] * len(ids)
143
+
144
+ padded_input_ids.append(padded_ids)
145
+ attention_mask.append(mask)
146
+
147
+ input_ids_list = padded_input_ids
148
+ else:
149
+ # Create attention mask without padding
150
+ attention_mask = [[1] * len(ids) for ids in input_ids_list]
151
+
152
+ # Create result dictionary
153
+ result = {"input_ids": input_ids_list}
154
+ if return_attention_mask:
155
+ result["attention_mask"] = attention_mask
156
+
157
+ # Convert to tensors if requested
158
+ if return_tensors == "pt":
159
+ result = {k: torch.tensor(v) for k, v in result.items()}
160
+
161
+ # Return a BatchEncoding object rather than a plain dictionary
162
+ return BatchEncoding(
163
+ data=result,
164
+ tensor_type=return_tensors,
165
+ prepend_batch_axis=False, # Already handled in our tensor creation
166
+ encoding=None # No encoding info from Evo2's tokenizer
167
+ )
168
+
169
+ def batch_decode(
170
+ self,
171
+ sequences: Union[List[int], List[List[int]], torch.Tensor],
172
+ skip_special_tokens: bool = False,
173
+ **kwargs
174
+ ) -> List[str]:
175
+ """
176
+ Decode a batch of token ids to strings.
177
+ """
178
+ if isinstance(sequences, torch.Tensor):
179
+ sequences = sequences.tolist()
180
+
181
+ return self.evo2_tokenizer.detokenize_batch(sequences)
182
+
183
+ def decode(
184
+ self,
185
+ token_ids: Union[int, List[int], torch.Tensor],
186
+ skip_special_tokens: bool = False,
187
+ **kwargs
188
+ ) -> str:
189
+ """
190
+ Decode a single sequence of token ids to a string.
191
+ """
192
+ if isinstance(token_ids, torch.Tensor):
193
+ token_ids = token_ids.tolist()
194
+
195
+ # Single sequence
196
+ if not isinstance(token_ids, list) or not token_ids or not isinstance(token_ids[0], (list, torch.Tensor)):
197
+ return self.evo2_tokenizer.detokenize(token_ids)
198
+
199
+ # Batch with one item
200
+ return self.batch_decode(token_ids, skip_special_tokens, **kwargs)[0]
201
+
202
+
203
+ # Register the tokenizer - you'll need to do this when your script loads
204
+ # You might want to put this in your __init__.py file
205
+ def register_evo2_tokenizer():
206
+ """Register the Evo2Tokenizer with HuggingFace's AutoTokenizer."""
207
+
208
+ # This will register the tokenizer so AutoTokenizer.from_pretrained knows about it
209
+ AutoTokenizer.register("evo2", Evo2Tokenizer)
210
+
211
+ # If you have a config class, you would also register that
212
+ # from transformers.models.auto import AutoConfig
213
+ # AutoConfig.register("evo2", Evo2Config)
214
+
215
+ print("Evo2Tokenizer registered with AutoTokenizer")
216
+
217
+
218
+ if __name__ == "__main__":
219
+ register_evo2_tokenizer()
BioReason-main/bioreason/trainer/__init__.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ from .grpo_config import DNALLMGRPOConfig
2
+ from .grpo_trainer import DNALLMGRPOTrainer
3
+
4
+ __all__ = [
5
+ "DNALLMGRPOConfig",
6
+ "DNALLMGRPOTrainer",
7
+ ]
BioReason-main/bioreason/trainer/demo_grpo.py ADDED
@@ -0,0 +1,811 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import os
16
+ import textwrap
17
+ import warnings
18
+ from collections import defaultdict
19
+ from typing import Any, Callable, Optional, Sized, Union
20
+ from unittest.mock import patch
21
+
22
+ import torch
23
+ import torch.utils.data
24
+ import transformers
25
+ from accelerate.utils import broadcast_object_list, gather, gather_object, is_peft_model, set_seed
26
+ from accelerate.utils.other import is_compiled_module
27
+ from datasets import Dataset, IterableDataset
28
+ from packaging import version
29
+ from torch import nn
30
+ from torch.utils.data import Sampler
31
+ from transformers import (
32
+ AutoModelForCausalLM,
33
+ AutoModelForSequenceClassification,
34
+ AutoTokenizer,
35
+ GenerationConfig,
36
+ PreTrainedModel,
37
+ PreTrainedTokenizerBase,
38
+ Trainer,
39
+ TrainerCallback,
40
+ is_wandb_available,
41
+ )
42
+ from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
43
+ from transformers.utils import is_peft_available
44
+
45
+ from trl.data_utils import apply_chat_template, is_conversational, maybe_apply_chat_template
46
+ from trl.import_utils import is_vllm_available
47
+ from trl.models import create_reference_model, prepare_deepspeed, unwrap_model_for_generation
48
+ from trl import SyncRefModelCallback
49
+ from trl import GRPOConfig
50
+ from trl.trainer.utils import generate_model_card, get_comet_experiment_url, pad, selective_log_softmax
51
+
52
+
53
+ if is_peft_available():
54
+ from peft import PeftConfig, get_peft_model
55
+
56
+ if is_vllm_available():
57
+ from vllm import LLM, SamplingParams
58
+
59
+ if is_wandb_available():
60
+ import wandb
61
+
62
+ # What we call a reward function is a callable that takes a list of prompts and completions and returns a list of
63
+ # rewards. When it's a string, it's a model ID, so it's loaded as a pretrained model.
64
+ RewardFunc = Union[str, PreTrainedModel, Callable[[list, list], list[float]]]
65
+
66
+
67
+ class RepeatRandomSampler(Sampler):
68
+ """
69
+ Sampler that repeats the indices of a dataset N times.
70
+
71
+ Args:
72
+ data_source (`Sized`):
73
+ Dataset to sample from.
74
+ repeat_count (`int`):
75
+ Number of times to repeat each index.
76
+ seed (`Optional[int]`):
77
+ Random seed for reproducibility (only affects this sampler).
78
+
79
+ Example:
80
+ ```python
81
+ >>> sampler = RepeatRandomSampler(["a", "b", "c", "d"], repeat_count=2)
82
+ >>> list(sampler)
83
+ [2, 2, 0, 0, 3, 3, 1, 1]
84
+ ```
85
+ """
86
+
87
+ def __init__(self, data_source: Sized, repeat_count: int, seed: Optional[int] = None):
88
+ self.data_source = data_source
89
+ self.repeat_count = repeat_count
90
+ self.num_samples = len(data_source)
91
+ self.seed = seed
92
+ self.generator = torch.Generator() # Create a local random generator
93
+ if seed is not None:
94
+ self.generator.manual_seed(seed)
95
+
96
+ def __iter__(self):
97
+ indexes = [
98
+ idx
99
+ for idx in torch.randperm(self.num_samples, generator=self.generator).tolist()
100
+ for _ in range(self.repeat_count)
101
+ ]
102
+ return iter(indexes)
103
+
104
+ def __len__(self):
105
+ return self.num_samples * self.repeat_count
106
+
107
+ # made this to test out the usual pipeline of GRPOTrainer data, and add my own debug messages
108
+ class FakeGRPOTrainer(Trainer):
109
+ """
110
+ Trainer for the Group Relative Policy Optimization (GRPO) method. This algorithm was initially proposed in the
111
+ paper [DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models](https://huggingface.co/papers/2402.03300).
112
+
113
+ Example:
114
+
115
+ ```python
116
+ from datasets import load_dataset
117
+ from trl import GRPOTrainer
118
+
119
+ dataset = load_dataset("trl-lib/tldr", split="train")
120
+
121
+ def reward_func(completions, **kwargs):
122
+ # Dummy reward function that rewards completions with more unique letters.
123
+ return [float(len(set(completion))) for completion in completions]
124
+
125
+ trainer = GRPOTrainer(
126
+ model="Qwen/Qwen2-0.5B-Instruct",
127
+ reward_funcs=reward_func,
128
+ train_dataset=dataset,
129
+ )
130
+
131
+ trainer.train()
132
+ ```
133
+
134
+ Args:
135
+ model (`Union[str, PreTrainedModel]`):
136
+ Model to be trained. Can be either:
137
+
138
+ - A string, being the *model id* of a pretrained model hosted inside a model repo on huggingface.co, or
139
+ a path to a *directory* containing model weights saved using
140
+ [`~transformers.PreTrainedModel.save_pretrained`], e.g., `'./my_model_directory/'`. The model is
141
+ loaded using [`~transformers.AutoModelForCausalLM.from_pretrained`] with the keywork arguments
142
+ in `args.model_init_kwargs`.
143
+ - A [`~transformers.PreTrainedModel`] object. Only causal language models are supported.
144
+ reward_funcs (`Union[RewardFunc, list[RewardFunc]]`):
145
+ Reward functions to be used for computing the rewards. To compute the rewards, we call all the reward
146
+ functions with the prompts and completions and sum the rewards. Can be either:
147
+
148
+ - A single reward function, such as:
149
+ - A string: The *model ID* of a pretrained model hosted inside a model repo on huggingface.co, or a
150
+ path to a *directory* containing model weights saved using
151
+ [`~transformers.PreTrainedModel.save_pretrained`], e.g., `'./my_model_directory/'`. The model is loaded
152
+ using [`~transformers.AutoModelForSequenceClassification.from_pretrained`] with `num_labels=1` and the
153
+ keyword arguments in `args.model_init_kwargs`.
154
+ - A [`~transformers.PreTrainedModel`] object: Only sequence classification models are supported.
155
+ - A custom reward function: The function is provided with the prompts and the generated completions,
156
+ plus any additional columns in the dataset. It should return a list of rewards. For more details, see
157
+ [Using a custom reward function](#using-a-custom-reward-function).
158
+ - A list of reward functions, where each item can independently be any of the above types. Mixing different
159
+ types within the list (e.g., a string model ID and a custom reward function) is allowed.
160
+ args ([`GRPOConfig`], *optional*, defaults to `None`):
161
+ Configuration for this trainer. If `None`, a default configuration is used.
162
+ train_dataset ([`~datasets.Dataset`] or [`~datasets.IterableDataset`]):
163
+ Dataset to use for training. It must include a column `"prompt"`. Any additional columns in the dataset is
164
+ ignored. The format of the samples can be either:
165
+
166
+ - [Standard](dataset_formats#standard): Each sample contains plain text.
167
+ - [Conversational](dataset_formats#conversational): Each sample contains structured messages (e.g., role
168
+ and content).
169
+ eval_dataset ([`~datasets.Dataset`], [`~datasets.IterableDataset`] or `dict[str, Union[Dataset, IterableDataset]]`):
170
+ Dataset to use for evaluation. It must meet the same requirements as `train_dataset`.
171
+ processing_class ([`~transformers.PreTrainedTokenizerBase`], *optional*, defaults to `None`):
172
+ Processing class used to process the data. The padding side must be set to "left". If `None`, the
173
+ processing class is loaded from the model's name with [`~transformers.AutoTokenizer.from_pretrained`].
174
+ reward_processing_classes (`Union[PreTrainedTokenizerBase, list[PreTrainedTokenizerBase]]`, *optional*, defaults to `None`):
175
+ Processing classes corresponding to the reward functions specified in `reward_funcs`. Can be either:
176
+
177
+ - A single processing class: Used when `reward_funcs` contains only one reward function.
178
+ - A list of processing classes: Must match the order and length of the reward functions in `reward_funcs`.
179
+ If set to `None`, or if an element of the list corresponding to a [`~transformers.PreTrainedModel`] is
180
+ `None`, the tokenizer for the model is automatically loaded using [`~transformers.AutoTokenizer.from_pretrained`].
181
+ For elements in `reward_funcs` that are custom reward functions (not [`~transformers.PreTrainedModel`]),
182
+ the corresponding entries in `reward_processing_classes` are ignored.
183
+ callbacks (list of [`~transformers.TrainerCallback`], *optional*, defaults to `None`):
184
+ List of callbacks to customize the training loop. Will add those to the list of default callbacks
185
+ detailed in [here](https://huggingface.co/docs/transformers/main_classes/callback).
186
+
187
+ If you want to remove one of the default callbacks used, use the [`~transformers.Trainer.remove_callback`]
188
+ method.
189
+ optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`, *optional*, defaults to `(None, None)`):
190
+ A tuple containing the optimizer and the scheduler to use. Will default to an instance of [`AdamW`] on your
191
+ model and a scheduler given by [`get_linear_schedule_with_warmup`] controlled by `args`.
192
+ peft_config ([`~peft.PeftConfig`], *optional*, defaults to `None`):
193
+ PEFT configuration used to wrap the model. If `None`, the model is not wrapped.
194
+ """
195
+
196
+ _tag_names = ["trl", "grpo"]
197
+
198
+ def __init__(
199
+ self,
200
+ model: Union[str, PreTrainedModel],
201
+ reward_funcs: Union[RewardFunc, list[RewardFunc]],
202
+ args: GRPOConfig = None,
203
+ train_dataset: Optional[Union[Dataset, IterableDataset]] = None,
204
+ eval_dataset: Optional[Union[Dataset, IterableDataset, dict[str, Union[Dataset, IterableDataset]]]] = None,
205
+ processing_class: Optional[PreTrainedTokenizerBase] = None,
206
+ reward_processing_classes: Optional[Union[PreTrainedTokenizerBase, list[PreTrainedTokenizerBase]]] = None,
207
+ callbacks: Optional[list[TrainerCallback]] = None,
208
+ optimizers: tuple[Optional[torch.optim.Optimizer], Optional[torch.optim.lr_scheduler.LambdaLR]] = (None, None),
209
+ peft_config: Optional["PeftConfig"] = None,
210
+ ):
211
+ # Args
212
+ if args is None:
213
+ model_name = model if isinstance(model, str) else model.config._name_or_path
214
+ model_name = model_name.split("/")[-1]
215
+ args = GRPOConfig(f"{model_name}-GRPO")
216
+
217
+ # Models
218
+ # Trained model
219
+ model_init_kwargs = args.model_init_kwargs or {}
220
+ if isinstance(model, str):
221
+ model_id = model
222
+ torch_dtype = model_init_kwargs.get("torch_dtype")
223
+ if isinstance(torch_dtype, torch.dtype) or torch_dtype == "auto" or torch_dtype is None:
224
+ pass # torch_dtype is already a torch.dtype or "auto" or None
225
+ elif isinstance(torch_dtype, str): # it's a str, but not "auto"
226
+ torch_dtype = getattr(torch, torch_dtype)
227
+ model_init_kwargs["torch_dtype"] = torch_dtype
228
+ else:
229
+ raise ValueError(
230
+ "Invalid `torch_dtype` passed to `GRPOConfig`. Expected either 'auto' or a string representing "
231
+ f"a `torch.dtype` (e.g., 'float32'), but got {torch_dtype}."
232
+ )
233
+ # Disable caching if gradient checkpointing is enabled (not supported)
234
+ model_init_kwargs["use_cache"] = (
235
+ False if args.gradient_checkpointing else model_init_kwargs.get("use_cache")
236
+ )
237
+ model = AutoModelForCausalLM.from_pretrained(model, **model_init_kwargs)
238
+ else:
239
+ model_id = model.config._name_or_path
240
+ if args.model_init_kwargs is not None:
241
+ raise ValueError(
242
+ "You passed `model_init_kwargs` to the `GRPOConfig`, but your model is already instantiated. "
243
+ "This argument can only be used when the `model` argument is a string."
244
+ )
245
+
246
+ if peft_config is not None:
247
+ model = get_peft_model(model, peft_config)
248
+
249
+ # Reference model
250
+ if is_deepspeed_zero3_enabled():
251
+ self.ref_model = AutoModelForCausalLM.from_pretrained(model_id, **model_init_kwargs)
252
+ elif not is_peft_model(model):
253
+ # If PEFT configuration is not provided, create a reference model based on the initial model.
254
+ self.ref_model = create_reference_model(model)
255
+ else:
256
+ # If PEFT is used, the reference model is not needed since the adapter can be disabled
257
+ # to revert to the initial model.
258
+ self.ref_model = None
259
+
260
+ # Processing class
261
+ if processing_class is None:
262
+ processing_class = AutoTokenizer.from_pretrained(model.config._name_or_path, padding_side="left")
263
+
264
+ # Reward functions
265
+ if not isinstance(reward_funcs, list):
266
+ reward_funcs = [reward_funcs]
267
+ for i, reward_func in enumerate(reward_funcs):
268
+ if isinstance(reward_func, str):
269
+ reward_funcs[i] = AutoModelForSequenceClassification.from_pretrained(
270
+ reward_func, num_labels=1, **model_init_kwargs
271
+ )
272
+ self.reward_funcs = reward_funcs
273
+
274
+ # Reward weights
275
+ if args.reward_weights is not None:
276
+ if len(args.reward_weights) != len(reward_funcs):
277
+ raise ValueError(
278
+ f"Number of reward weights ({len(args.reward_weights)}) must match number of reward "
279
+ f"functions ({len(reward_funcs)})"
280
+ )
281
+ self.reward_weights = torch.tensor(args.reward_weights, dtype=torch.float32)
282
+ else:
283
+ self.reward_weights = torch.ones(len(reward_funcs), dtype=torch.float32)
284
+
285
+ # Reward processing class
286
+ if reward_processing_classes is None:
287
+ reward_processing_classes = [None] * len(reward_funcs)
288
+ elif not isinstance(reward_processing_classes, list):
289
+ reward_processing_classes = [reward_processing_classes]
290
+ else:
291
+ if len(reward_processing_classes) != len(reward_funcs):
292
+ raise ValueError("The number of reward processing classes must match the number of reward functions.")
293
+
294
+ for i, (reward_processing_class, reward_func) in enumerate(zip(reward_processing_classes, reward_funcs)):
295
+ if isinstance(reward_func, PreTrainedModel):
296
+ if reward_processing_class is None:
297
+ reward_processing_class = AutoTokenizer.from_pretrained(reward_func.config._name_or_path)
298
+ if reward_processing_class.pad_token_id is None:
299
+ reward_processing_class.pad_token = reward_processing_class.eos_token
300
+ # The reward model computes the reward for the latest non-padded token in the input sequence.
301
+ # So it's important to set the pad token ID to the padding token ID of the processing class.
302
+ reward_func.config.pad_token_id = reward_processing_class.pad_token_id
303
+ reward_processing_classes[i] = reward_processing_class
304
+ self.reward_processing_classes = reward_processing_classes
305
+
306
+ # Data collator
307
+ def data_collator(features): # No data collation is needed in GRPO
308
+ return features
309
+
310
+ # Training arguments
311
+ self.max_prompt_length = args.max_prompt_length
312
+ self.max_completion_length = args.max_completion_length # = |o_i| in the GRPO paper
313
+ self.num_generations = args.num_generations # = G in the GRPO paper
314
+ self.use_vllm = args.use_vllm
315
+
316
+ self.beta = args.beta
317
+
318
+ # The trainer estimates the number of FLOPs (floating-point operations) using the number of elements in the
319
+ # input tensor associated with the key "input_ids". However, in GRPO, the sampled data does not include the
320
+ # "input_ids" key. Instead, the available keys is "prompt". As a result, the trainer issues the warning:
321
+ # "Could not estimate the number of tokens of the input, floating-point operations will not be computed." To
322
+ # suppress this warning, we set the "estimate_tokens" key in the model's "warnings_issued" dictionary to True.
323
+ # This acts as a flag to indicate that the warning has already been issued.
324
+ model.warnings_issued["estimate_tokens"] = True
325
+
326
+ # Initialize the metrics
327
+ self._metrics = defaultdict(list)
328
+ self.log_completions = args.log_completions
329
+
330
+ super().__init__(
331
+ model=model,
332
+ args=args,
333
+ data_collator=data_collator,
334
+ train_dataset=train_dataset,
335
+ eval_dataset=eval_dataset,
336
+ processing_class=processing_class,
337
+ callbacks=callbacks,
338
+ optimizers=optimizers,
339
+ )
340
+
341
+ # Check if the per_device_train/eval_batch_size * num processes can be divided by the number of generations
342
+ num_processes = self.accelerator.num_processes
343
+ global_batch_size = args.per_device_train_batch_size * num_processes
344
+ possible_values = [n_gen for n_gen in range(2, global_batch_size + 1) if (global_batch_size) % n_gen == 0]
345
+ if self.num_generations not in possible_values:
346
+ raise ValueError(
347
+ f"The global train batch size ({num_processes} x {args.per_device_train_batch_size}) must be evenly "
348
+ f"divisible by the number of generations per prompt ({self.num_generations}). Given the current train "
349
+ f"batch size, the valid values for the number of generations are: {possible_values}."
350
+ )
351
+ if self.args.eval_strategy != "no":
352
+ global_batch_size = args.per_device_eval_batch_size * num_processes
353
+ possible_values = [n_gen for n_gen in range(2, global_batch_size + 1) if (global_batch_size) % n_gen == 0]
354
+ if self.num_generations not in possible_values:
355
+ raise ValueError(
356
+ f"The global eval batch size ({num_processes} x {args.per_device_eval_batch_size}) must be evenly "
357
+ f"divisible by the number of generations per prompt ({self.num_generations}). Given the current "
358
+ f"eval batch size, the valid values for the number of generations are: {possible_values}."
359
+ )
360
+
361
+ # Ensure each process receives a unique seed to prevent duplicate completions when generating with
362
+ # transformers if num_generations exceeds per_device_train_batch_size. We could skip it if we use vLLM, but
363
+ # it's safer to set it in all cases.
364
+ set_seed(args.seed, device_specific=True)
365
+
366
+ if self.use_vllm:
367
+ if not is_vllm_available():
368
+ raise ImportError(
369
+ "vLLM is not available and `use_vllm` is set to True. Please install vLLM with "
370
+ "`pip install vllm` to use it."
371
+ )
372
+
373
+ if self.accelerator.is_main_process:
374
+ vllm_device = self.args.vllm_device
375
+ if vllm_device == "auto":
376
+ if torch.cuda.device_count() == 1:
377
+ vllm_device = "cuda:0" # particular case when training with onyl 1 GPU: share it
378
+ else:
379
+ vllm_device = f"cuda:{self.accelerator.num_processes}" # take the next GPU idx
380
+ # Check that the requested device is available
381
+ if vllm_device.split(":")[0] == "cuda" and int(vllm_device.split(":")[1]) >= torch.cuda.device_count():
382
+ raise ValueError(
383
+ f"The requested device for vllm ({vllm_device}) is not available. You are likely using vLLM "
384
+ "without restricting the number of GPUs for training. Set the `--num_processes` argument to a "
385
+ "value lower than the number of GPUs available on your machine—typically, reducing it by one "
386
+ f"is sufficient. In your case: `--num_processes {torch.cuda.device_count() - 1}`."
387
+ )
388
+ # Check that the requested device is not also used for training
389
+ if vllm_device in {f"cuda:{idx}" for idx in range(self.accelerator.num_processes)}:
390
+ warnings.warn(
391
+ f"The requested device {vllm_device} is also being used for training. For higher throughput "
392
+ "and to avoid out-of-memory errors, it is recommended to use a dedicated device for vLLM. "
393
+ "If this is intentional, you may ignore this warning but should adjust "
394
+ "`vllm_gpu_memory_utilization` accordingly."
395
+ )
396
+ # vLLM is not compatible with accelerate. So we need to patch it to make sure we can (1) place the vLLM
397
+ # model on the desired device (world_size_patch) and (2) avoid a test that is not designed for our
398
+ # setting (profiling_patch).
399
+ world_size_patch = patch("torch.distributed.get_world_size", return_value=1)
400
+ profiling_patch = patch(
401
+ "vllm.worker.worker.Worker._assert_memory_footprint_increased_during_profiling", return_value=None
402
+ )
403
+ with world_size_patch, profiling_patch:
404
+ self.llm = LLM(
405
+ model=model.name_or_path,
406
+ device=vllm_device,
407
+ gpu_memory_utilization=self.args.vllm_gpu_memory_utilization,
408
+ dtype=self.args.vllm_dtype,
409
+ # Automatic Prefix Caching caches the KV cache of existing queries, so that a new query can
410
+ # directly reuse the KV cache if it shares the same prefix with one of the existing queries.
411
+ # This is particularly useful here because we generate completions from the same prompts.
412
+ enable_prefix_caching=True,
413
+ max_model_len=self.args.vllm_max_model_len,
414
+ )
415
+ self.sampling_params = SamplingParams(
416
+ temperature=args.temperature,
417
+ max_tokens=self.max_completion_length,
418
+ )
419
+
420
+ self._last_loaded_step = 0 # tag to avoid useless loading during grad accumulation
421
+
422
+ # When using vLLM, the main process is responsible for loading the model weights. This can cause process
423
+ # desynchronization and seems to lead to DeepSpeed hanging during initialization. To prevent this, we
424
+ # synchronize all processes after vLLM has been fully initialized.
425
+ self.accelerator.wait_for_everyone()
426
+ else:
427
+ self.generation_config = GenerationConfig(
428
+ max_new_tokens=self.max_completion_length,
429
+ do_sample=True,
430
+ temperature=args.temperature,
431
+ pad_token_id=processing_class.pad_token_id,
432
+ )
433
+
434
+ # Gradient accumulation requires scaled loss. Normally, loss scaling in the parent class depends on whether the
435
+ # model accepts loss-related kwargs. Since we compute our own loss, this check is irrelevant. We set
436
+ # self.model_accepts_loss_kwargs to False to enable scaling.
437
+ self.model_accepts_loss_kwargs = False
438
+
439
+ # Add tags to the model
440
+ self.model.add_model_tags(self._tag_names)
441
+
442
+ if self.ref_model is not None:
443
+ if self.is_deepspeed_enabled:
444
+ self.ref_model = prepare_deepspeed(self.ref_model, self.accelerator)
445
+ else:
446
+ self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
447
+
448
+ if args.sync_ref_model:
449
+ self.add_callback(SyncRefModelCallback(ref_model=self.ref_model, accelerator=self.accelerator))
450
+
451
+ for i, reward_func in enumerate(self.reward_funcs):
452
+ if isinstance(reward_func, PreTrainedModel):
453
+ self.reward_funcs[i] = self.accelerator.prepare_model(reward_func, evaluation_mode=True)
454
+
455
+ def _set_signature_columns_if_needed(self):
456
+ # If `self.args.remove_unused_columns` is True, non-signature columns are removed.
457
+ # By default, this method sets `self._signature_columns` to the model's expected inputs.
458
+ # In GRPOTrainer, we preprocess data, so using the model's signature columns doesn't work.
459
+ # Instead, we set them to the columns expected by the `training_step` method, hence the override.
460
+ if self._signature_columns is None:
461
+ self._signature_columns = ["prompt"]
462
+
463
+ def _get_train_sampler(self) -> Sampler:
464
+ # Returns a sampler that ensures each prompt is repeated across multiple processes. This guarantees that
465
+ # identical prompts are distributed to different GPUs, allowing rewards to be computed and normalized correctly
466
+ # within each prompt group. Using the same seed across processes ensures consistent prompt assignment,
467
+ # preventing discrepancies in group formation.
468
+ return RepeatRandomSampler(self.train_dataset, self.num_generations, seed=self.args.seed)
469
+
470
+ def _get_eval_sampler(self, eval_dataset) -> Sampler:
471
+ # Returns a sampler that ensures each prompt is repeated across multiple processes. This guarantees that
472
+ # identical prompts are distributed to different GPUs, allowing rewards to be computed and normalized correctly
473
+ # within each prompt group. Using the same seed across processes ensures consistent prompt assignment,
474
+ # preventing discrepancies in group formation.
475
+ return RepeatRandomSampler(eval_dataset, self.num_generations, seed=self.args.seed)
476
+
477
+ # Get the per-token log probabilities for the completions for the model and the reference model
478
+ def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep):
479
+ # We add 1 to `logits_to_keep` because the last logits of the sequence is later excluded
480
+ logits = model(input_ids=input_ids, attention_mask=attention_mask, logits_to_keep=logits_to_keep + 1).logits
481
+ logits = logits[:, :-1, :] # (B, L-1, V), exclude the last logit: it corresponds to the next token pred
482
+
483
+ input_ids = input_ids[:, -logits_to_keep:]
484
+ # For transformers<=4.48, logits_to_keep argument isn't supported, so here we drop logits ourselves.
485
+ # See https://github.com/huggingface/trl/issues/2770
486
+ logits = logits[:, -logits_to_keep:]
487
+ return selective_log_softmax(logits, input_ids) # compute logprobs for the input tokens
488
+
489
+ def _move_model_to_vllm(self):
490
+ with unwrap_model_for_generation(
491
+ self.model, self.accelerator, gather_deepspeed3_params=self.args.ds3_gather_for_generation
492
+ ) as unwrapped_model:
493
+ if is_compiled_module(unwrapped_model):
494
+ unwrapped_model = unwrapped_model._orig_mod
495
+ if is_peft_model(unwrapped_model):
496
+ unwrapped_model.merge_adapter()
497
+ state_dict = unwrapped_model.state_dict()
498
+ # Remove base_model and base_layer prefixes
499
+ state_dict = {
500
+ k.removeprefix("base_model.model.").replace(".base_layer", ""): v for k, v in state_dict.items()
501
+ }
502
+ # Remove values with adapter prefix (example: "_lora")
503
+ state_dict = {k: v for k, v in state_dict.items() if unwrapped_model.prefix not in k}
504
+ # When module to save, remove its prefix and discard the original module
505
+ state_dict = {
506
+ k.replace("modules_to_save.default.", ""): v
507
+ for k, v in state_dict.items()
508
+ if "original_module" not in k
509
+ }
510
+ else:
511
+ state_dict = unwrapped_model.state_dict()
512
+ if self.accelerator.is_main_process:
513
+ llm_model = self.llm.llm_engine.model_executor.driver_worker.model_runner.model
514
+ llm_model.load_weights(state_dict.items())
515
+ # Unmerge the adapter to restore the model to its original state.
516
+ # This must be done after loading weights to ensure they correspond to the merged state.
517
+ if is_peft_model(unwrapped_model):
518
+ unwrapped_model.unmerge_adapter()
519
+
520
+ def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[str, Union[torch.Tensor, Any]]:
521
+ device = self.accelerator.device
522
+ prompts = [x["prompt"] for x in inputs]
523
+ prompts_text = [maybe_apply_chat_template(example, self.processing_class)["prompt"] for example in inputs]
524
+ prompt_inputs = self.processing_class(
525
+ prompts_text, return_tensors="pt", padding=True, padding_side="left", add_special_tokens=False
526
+ )
527
+ prompt_inputs = super()._prepare_inputs(prompt_inputs)
528
+ prompt_ids, prompt_mask = prompt_inputs["input_ids"], prompt_inputs["attention_mask"]
529
+
530
+ if self.max_prompt_length is not None:
531
+ prompt_ids = prompt_ids[:, -self.max_prompt_length :]
532
+ prompt_mask = prompt_mask[:, -self.max_prompt_length :]
533
+
534
+ # Generate completions using either vLLM or regular generation
535
+ if self.args.use_vllm:
536
+ # First, have main process load weights if needed
537
+ if self.state.global_step != self._last_loaded_step:
538
+ self._move_model_to_vllm()
539
+ self._last_loaded_step = self.state.global_step
540
+
541
+ # Generate completions using vLLM: gather all prompts and use them in a single call in the main process
542
+ all_prompts_text = gather_object(prompts_text)
543
+ if self.accelerator.is_main_process:
544
+ outputs = self.llm.generate(all_prompts_text, sampling_params=self.sampling_params, use_tqdm=False)
545
+ completion_ids = [out.token_ids for completions in outputs for out in completions.outputs]
546
+ else:
547
+ completion_ids = [None] * len(all_prompts_text)
548
+ # Broadcast the completions from the main process to all processes, ensuring each process receives its
549
+ # corresponding slice.
550
+ completion_ids = broadcast_object_list(completion_ids, from_process=0)
551
+ process_slice = slice(
552
+ self.accelerator.process_index * len(prompts),
553
+ (self.accelerator.process_index + 1) * len(prompts),
554
+ )
555
+ completion_ids = completion_ids[process_slice]
556
+
557
+ # Pad the completions, and concatenate them with the prompts
558
+ completion_ids = [torch.tensor(ids, device=device) for ids in completion_ids]
559
+ completion_ids = pad(completion_ids, padding_value=self.processing_class.pad_token_id)
560
+ prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1)
561
+ else:
562
+ print("about to generate!!")
563
+ # Regular generation path
564
+ with unwrap_model_for_generation(self.model, self.accelerator) as unwrapped_model:
565
+ prompt_completion_ids = unwrapped_model.generate(
566
+ prompt_ids, attention_mask=prompt_mask, generation_config=self.generation_config
567
+ )
568
+
569
+ print('prompts_ids', prompt_ids, 'attention_mask', prompt_mask)
570
+ print('prompt_completion_ids', prompt_completion_ids)
571
+ print('prompt len', prompt_ids.size(1))
572
+
573
+ # Compute prompt length and extract completion ids
574
+ prompt_length = prompt_ids.size(1)
575
+ prompt_ids = prompt_completion_ids[:, :prompt_length]
576
+ completion_ids = prompt_completion_ids[:, prompt_length:]
577
+
578
+ # Mask everything after the first EOS token
579
+ is_eos = completion_ids == self.processing_class.eos_token_id
580
+ eos_idx = torch.full((is_eos.size(0),), is_eos.size(1), dtype=torch.long, device=device)
581
+ eos_idx[is_eos.any(dim=1)] = is_eos.int().argmax(dim=1)[is_eos.any(dim=1)]
582
+ sequence_indices = torch.arange(is_eos.size(1), device=device).expand(is_eos.size(0), -1)
583
+ completion_mask = (sequence_indices <= eos_idx.unsqueeze(1)).int()
584
+
585
+ # Concatenate prompt_mask with completion_mask for logit computation
586
+ attention_mask = torch.cat([prompt_mask, completion_mask], dim=1) # (B*G, P+C)
587
+
588
+ logits_to_keep = completion_ids.size(1) # we only need to compute the logits for the completion tokens
589
+
590
+ with torch.inference_mode():
591
+ if self.ref_model is not None:
592
+ ref_per_token_logps = self._get_per_token_logps(
593
+ self.ref_model, prompt_completion_ids, attention_mask, logits_to_keep
594
+ )
595
+ else:
596
+ with self.accelerator.unwrap_model(self.model).disable_adapter():
597
+ ref_per_token_logps = self._get_per_token_logps(
598
+ self.model, prompt_completion_ids, attention_mask, logits_to_keep
599
+ )
600
+
601
+ # Decode the generated completions
602
+ completions_text = self.processing_class.batch_decode(completion_ids, skip_special_tokens=True)
603
+ if is_conversational(inputs[0]):
604
+ completions = []
605
+ for prompt, completion in zip(prompts, completions_text):
606
+ bootstrap = prompt.pop()["content"] if prompt[-1]["role"] == "assistant" else ""
607
+ completions.append([{"role": "assistant", "content": bootstrap + completion}])
608
+ else:
609
+ completions = completions_text
610
+
611
+ rewards_per_func = torch.zeros(len(prompts), len(self.reward_funcs), device=device)
612
+ for i, (reward_func, reward_processing_class) in enumerate(
613
+ zip(self.reward_funcs, self.reward_processing_classes)
614
+ ):
615
+ if isinstance(reward_func, nn.Module): # Module instead of PretrainedModel for compat with compiled models
616
+ if is_conversational(inputs[0]):
617
+ messages = [{"messages": p + c} for p, c in zip(prompts, completions)]
618
+ texts = [apply_chat_template(x, reward_processing_class)["text"] for x in messages]
619
+ else:
620
+ texts = [p + c for p, c in zip(prompts, completions)]
621
+ reward_inputs = reward_processing_class(
622
+ texts, return_tensors="pt", padding=True, padding_side="right", add_special_tokens=False
623
+ )
624
+ reward_inputs = super()._prepare_inputs(reward_inputs)
625
+ with torch.inference_mode():
626
+ rewards_per_func[:, i] = reward_func(**reward_inputs).logits[:, 0] # Shape (B*G,)
627
+ else:
628
+ # Repeat all input columns (but "prompt" and "completion") to match the number of generations
629
+ keys = [key for key in inputs[0] if key not in ["prompt", "completion"]]
630
+ reward_kwargs = {key: [example[key] for example in inputs] for key in keys}
631
+ output_reward_func = reward_func(prompts=prompts, completions=completions, **reward_kwargs)
632
+ rewards_per_func[:, i] = torch.tensor(output_reward_func, dtype=torch.float32, device=device)
633
+
634
+ # Gather the reward per function: this part is crucial, because the rewards are normalized per group and the
635
+ # completions may be distributed across processes
636
+ rewards_per_func = gather(rewards_per_func)
637
+
638
+ # Apply weights to each reward function's output and sum
639
+ rewards = (rewards_per_func * self.reward_weights.to(device).unsqueeze(0)).sum(dim=1)
640
+
641
+ # Compute grouped-wise rewards
642
+ mean_grouped_rewards = rewards.view(-1, self.num_generations).mean(dim=1)
643
+ std_grouped_rewards = rewards.view(-1, self.num_generations).std(dim=1)
644
+
645
+ # Normalize the rewards to compute the advantages
646
+ mean_grouped_rewards = mean_grouped_rewards.repeat_interleave(self.num_generations, dim=0)
647
+ std_grouped_rewards = std_grouped_rewards.repeat_interleave(self.num_generations, dim=0)
648
+ advantages = (rewards - mean_grouped_rewards) / (std_grouped_rewards + 1e-4)
649
+
650
+ # Slice to keep only the local part of the data
651
+ process_slice = slice(
652
+ self.accelerator.process_index * len(prompts),
653
+ (self.accelerator.process_index + 1) * len(prompts),
654
+ )
655
+ advantages = advantages[process_slice]
656
+
657
+ # Log the metrics
658
+ reward_per_func = rewards_per_func.mean(0)
659
+ for i, reward_func in enumerate(self.reward_funcs):
660
+ if isinstance(reward_func, nn.Module): # Module instead of PretrainedModel for compat with compiled models
661
+ reward_func_name = reward_func.config._name_or_path.split("/")[-1]
662
+ else:
663
+ reward_func_name = reward_func.__name__
664
+ self._metrics[f"rewards/{reward_func_name}"].append(reward_per_func[i].item())
665
+
666
+ self._metrics["reward"].append(rewards.mean().item())
667
+ self._metrics["reward_std"].append(std_grouped_rewards.mean().item())
668
+
669
+ if (
670
+ self.log_completions
671
+ and self.state.global_step % self.args.logging_steps == 0
672
+ and "wandb" in self.args.report_to
673
+ ):
674
+ import pandas as pd
675
+
676
+ # For logging
677
+ table = {
678
+ "step": [str(self.state.global_step)] * len(rewards),
679
+ "prompt": gather_object(prompts_text),
680
+ "completion": gather_object(completions_text),
681
+ "reward": rewards.tolist(),
682
+ }
683
+ df = pd.DataFrame(table)
684
+
685
+ if wandb.run is not None and self.accelerator.is_main_process:
686
+ wandb.log({"completions": wandb.Table(dataframe=df)})
687
+
688
+ return {
689
+ "prompt_ids": prompt_ids,
690
+ "prompt_mask": prompt_mask,
691
+ "completion_ids": completion_ids,
692
+ "completion_mask": completion_mask,
693
+ "ref_per_token_logps": ref_per_token_logps,
694
+ "advantages": advantages,
695
+ }
696
+
697
+ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
698
+ if return_outputs:
699
+ raise ValueError("The GRPOTrainer does not support returning outputs")
700
+ # Compute the per-token log probabilities for the model
701
+
702
+ prompt_ids, prompt_mask = inputs["prompt_ids"], inputs["prompt_mask"]
703
+ completion_ids, completion_mask = inputs["completion_ids"], inputs["completion_mask"]
704
+ input_ids = torch.cat([prompt_ids, completion_ids], dim=1)
705
+ attention_mask = torch.cat([prompt_mask, completion_mask], dim=1)
706
+ logits_to_keep = completion_ids.size(1) # we only need to compute the logits for the completion tokens
707
+
708
+ per_token_logps = self._get_per_token_logps(model, input_ids, attention_mask, logits_to_keep)
709
+
710
+ # Compute the KL divergence between the model and the reference model
711
+ ref_per_token_logps = inputs["ref_per_token_logps"]
712
+ per_token_kl = torch.exp(ref_per_token_logps - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1
713
+
714
+ # x - x.detach() allows for preserving gradients from x
715
+ advantages = inputs["advantages"]
716
+ per_token_loss = torch.exp(per_token_logps - per_token_logps.detach()) * advantages.unsqueeze(1)
717
+ per_token_loss = -(per_token_loss - self.beta * per_token_kl)
718
+ loss = ((per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()
719
+
720
+ # Log the metrics
721
+ completion_length = self.accelerator.gather_for_metrics(completion_mask.sum(1)).float().mean().item()
722
+ self._metrics["completion_length"].append(completion_length)
723
+
724
+ mean_kl = ((per_token_kl * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()
725
+ self._metrics["kl"].append(self.accelerator.gather_for_metrics(mean_kl).mean().item())
726
+
727
+ return loss
728
+
729
+ def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys: Optional[list[str]] = None):
730
+ inputs = self._prepare_inputs(inputs)
731
+ print("about to loss")
732
+ with torch.no_grad():
733
+ with self.compute_loss_context_manager():
734
+ loss = self.compute_loss(model, inputs)
735
+ loss = loss.mean().detach()
736
+ print("loss computed")
737
+ return loss, None, None
738
+
739
+ def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> None:
740
+ metrics = {key: sum(val) / len(val) for key, val in self._metrics.items()} # average the metrics
741
+
742
+ # This method can be called both in training and evaluation. When called in evaluation, the keys in `logs`
743
+ # start with "eval_". We need to add the prefix "eval_" to the keys in `metrics` to match the format.
744
+ if next(iter(logs.keys())).startswith("eval_"):
745
+ metrics = {f"eval_{key}": val for key, val in metrics.items()}
746
+
747
+ logs = {**logs, **metrics}
748
+ if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
749
+ super().log(logs, start_time)
750
+ else: # transformers<=4.46
751
+ super().log(logs)
752
+ self._metrics.clear()
753
+
754
+ def create_model_card(
755
+ self,
756
+ model_name: Optional[str] = None,
757
+ dataset_name: Optional[str] = None,
758
+ tags: Union[str, list[str], None] = None,
759
+ ):
760
+ """
761
+ Creates a draft of a model card using the information available to the `Trainer`.
762
+
763
+ Args:
764
+ model_name (`str` or `None`, *optional*, defaults to `None`):
765
+ Name of the model.
766
+ dataset_name (`str` or `None`, *optional*, defaults to `None`):
767
+ Name of the dataset used for training.
768
+ tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
769
+ Tags to be associated with the model card.
770
+ """
771
+ if not self.is_world_process_zero():
772
+ return
773
+
774
+ if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
775
+ base_model = self.model.config._name_or_path
776
+ else:
777
+ base_model = None
778
+
779
+ tags = tags or []
780
+ if isinstance(tags, str):
781
+ tags = [tags]
782
+
783
+ if hasattr(self.model.config, "unsloth_version"):
784
+ tags.append("unsloth")
785
+
786
+ citation = textwrap.dedent(
787
+ """\
788
+ @article{zhihong2024deepseekmath,
789
+ title = {{DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models}},
790
+ author = {Zhihong Shao and Peiyi Wang and Qihao Zhu and Runxin Xu and Junxiao Song and Mingchuan Zhang and Y. K. Li and Y. Wu and Daya Guo},
791
+ year = 2024,
792
+ eprint = {arXiv:2402.03300},
793
+ }
794
+ """
795
+ )
796
+
797
+ model_card = generate_model_card(
798
+ base_model=base_model,
799
+ model_name=model_name,
800
+ hub_model_id=self.hub_model_id,
801
+ dataset_name=dataset_name,
802
+ tags=tags,
803
+ wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
804
+ comet_url=get_comet_experiment_url(),
805
+ trainer_name="GRPO",
806
+ trainer_citation=citation,
807
+ paper_title="DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models",
808
+ paper_id="2402.03300",
809
+ )
810
+
811
+ model_card.save(os.path.join(self.args.output_dir, "README.md"))
BioReason-main/bioreason/trainer/grpo_config.py ADDED
@@ -0,0 +1,365 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from dataclasses import dataclass, field
16
+ from typing import Optional, Union
17
+
18
+ from transformers import TrainingArguments
19
+
20
+
21
+ @dataclass
22
+ class DNALLMGRPOConfig(TrainingArguments):
23
+ r"""
24
+ Configuration class for the [`GRPOTrainer`].
25
+
26
+ Only the parameters specific to GRPO training are listed here. For details on other parameters, refer to the
27
+ [`~transformers.TrainingArguments`] documentation.
28
+
29
+ Using [`~transformers.HfArgumentParser`] we can turn this class into
30
+ [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
31
+ command line.
32
+
33
+ Parameters:
34
+ > Parameters that control the model and reference model
35
+
36
+ model_init_kwargs (`dict[str, Any]` or `None`, *optional*, defaults to `None`):
37
+ Keyword arguments for [`~transformers.AutoModelForCausalLM.from_pretrained`], used when the `model`
38
+ argument of the [`GRPOTrainer`] is provided as a string.
39
+
40
+ > Parameters that control the data preprocessing
41
+
42
+ remove_unused_columns (`bool`, *optional*, defaults to `False`):
43
+ Whether to only keep the column `"prompt"` in the dataset. If you use a custom reward function that
44
+ requires any column other than `"prompts"` and `"completions"`, you should keep this to `False`.
45
+ max_prompt_length (`int` or `None`, *optional*, defaults to `512`):
46
+ Maximum length of the prompt. If the prompt is longer than this value, it will be truncated left.
47
+ num_generations (`int` or `None`, *optional*, defaults to `8`):
48
+ Number of generations per prompt to sample. The global batch size (num_processes * per_device_batch_size)
49
+ must be divisible by this value.
50
+ max_completion_length (`int` or `None`, *optional*, defaults to `256`):
51
+ Maximum length of the generated completion.
52
+ ds3_gather_for_generation (`bool`, *optional*, defaults to `True`):
53
+ This setting applies to DeepSpeed ZeRO-3. If enabled, the policy model weights are gathered for generation,
54
+ improving generation speed. However, disabling this option allows training models that exceed the VRAM
55
+ capacity of a single GPU, albeit at the cost of slower generation. Disabling this option is not compatible
56
+ with vLLM generation.
57
+
58
+ > Parameters that control generation
59
+
60
+ temperature (`float`, defaults to `0.9`):
61
+ Temperature for sampling. The higher the temperature, the more random the completions.
62
+ top_p (`float`, *optional*, defaults to `1.0`):
63
+ Float that controls the cumulative probability of the top tokens to consider. Must be in (0, 1]. Set to
64
+ `1.0` to consider all tokens.
65
+ top_k (`int` or `None`, *optional*, defaults to `50`):
66
+ Number of highest probability vocabulary tokens to keep for top-k-filtering. If `None`, top-k-filtering is
67
+ disabled.
68
+ min_p (`float` or `None`, *optional*, defaults to `None`):
69
+ Minimum token probability, which will be scaled by the probability of the most likely token. It must be a
70
+ value between `0.0` and `1.0`. Typical values are in the `0.01-0.2` range.
71
+ repetition_penalty (`float`, *optional*, defaults to `1.0`):
72
+ Float that penalizes new tokens based on whether they appear in the prompt and the generated text so far.
73
+ Values > `1.0` encourage the model to use new tokens, while values < `1.0` encourage the model to repeat
74
+ tokens.
75
+ cache_implementation (`str` or `None`, *optional*, defaults to `None`):
76
+ Implementation of the cache method for faster generation when use_vllm is set to False.
77
+
78
+ > Parameters that control generation acceleration powered by vLLM
79
+
80
+ use_vllm (`bool`, *optional*, defaults to `False`):
81
+ Whether to use vLLM for generating completions. If set to `True`, ensure that a GPU is kept unused for
82
+ training, as vLLM will require one for generation. vLLM must be installed (`pip install vllm`).
83
+ vllm_device (`str`, *optional*, defaults to `"auto"`):
84
+ Device where vLLM generation will run, e.g. `"cuda:1"`. If set to `"auto"` (default), the system will
85
+ automatically select the next available GPU after the last one used for training. This assumes that
86
+ training has not already occupied all available GPUs. If only one device is available, the device will be
87
+ shared between both training and vLLM.
88
+ vllm_gpu_memory_utilization (`float`, *optional*, defaults to `0.9`):
89
+ Ratio (between 0 and 1) of GPU memory to reserve for the model weights, activations, and KV cache on the
90
+ device dedicated to generation powered by vLLM. Higher values will increase the KV cache size and thus
91
+ improve the model's throughput. However, if the value is too high, it may cause out-of-memory (OOM) errors
92
+ during initialization.
93
+ vllm_dtype (`str`, *optional*, defaults to `"auto"`):
94
+ Data type to use for vLLM generation. If set to `"auto"`, the data type will be automatically determined
95
+ based on the model configuration. Find the supported values in the vLLM documentation.
96
+ vllm_max_model_len (`int` or `None`, *optional*, defaults to `None`):
97
+ If set, the `max_model_len` to use for vLLM. This could be useful when running with reduced
98
+ `vllm_gpu_memory_utilization`, leading to a reduced KV cache size. If not set, vLLM will use the model
99
+ context size, which might be much larger than the KV cache, leading to inefficiencies.
100
+ vllm_enable_prefix_caching (`bool`, *optional*, defaults to `True`):
101
+ Whether to enable prefix caching in vLLM. If set to `True` (default), ensure that the model and the hardware
102
+ support this feature.
103
+ vllm_guided_decoding_regex (`str` or `None`, *optional*, defaults to `None`):
104
+ Regex for vLLM guided decoding. If `None` (default), guided decoding is disabled.
105
+
106
+ > Parameters that control the training
107
+
108
+ learning_rate (`float`, *optional*, defaults to `1e-6`):
109
+ Initial learning rate for [`AdamW`] optimizer. The default value replaces that of
110
+ [`~transformers.TrainingArguments`].
111
+ beta (`float`, *optional*, defaults to `0.04`):
112
+ KL coefficient. If `0.0`, the reference model is not loaded, reducing memory usage and improving training
113
+ speed, but may be numerically unstable for long training runs.
114
+ num_iterations (`int`, *optional*, defaults to `1`):
115
+ Number of iterations per batch (denoted as μ in the algorithm).
116
+ epsilon (`float`, *optional*, defaults to `0.2`):
117
+ Epsilon value for clipping.
118
+ epsilon_high (`float` or `None`, *optional*, defaults to `None`):
119
+ Upper-bound epsilon value for clipping. If not specified, it defaults to the same value as the lower-bound
120
+ specified in argument `epsilon`. Paper [DAPO](https://huggingface.co/papers/2503.14476) recommends `0.28`.
121
+ reward_weights (`list[float]` or `None`, *optional*, defaults to `None`):
122
+ Weights for each reward function. Must match the number of reward functions. If `None`, all rewards are
123
+ weighted equally with weight `1.0`.
124
+ sync_ref_model (`bool`, *optional*, defaults to `False`):
125
+ Whether to synchronize the reference model with the active model every `ref_model_sync_steps` steps, using
126
+ the `ref_model_mixup_alpha` parameter. This synchronization originites from the
127
+ [TR-DPO](https://huggingface.co/papers/2404.09656) paper.
128
+ ref_model_mixup_alpha (`float`, *optional*, defaults to `0.6`):
129
+ α parameter from the [TR-DPO](https://huggingface.co/papers/2404.09656) paper, which controls the mix
130
+ between the current policy and the previous reference policy during updates. The reference policy is
131
+ updated according to the equation: `π_ref = α * π_θ + (1 - α) * π_ref_prev`. To use this parameter, you
132
+ must set `sync_ref_model=True`.
133
+ ref_model_sync_steps (`int`, *optional*, defaults to `512`):
134
+ τ parameter from the [TR-DPO](https://huggingface.co/papers/2404.09656) paper, which determines how
135
+ frequently the current policy is synchronized with the reference policy. To use this parameter, you must
136
+ set `sync_ref_model=True`.
137
+
138
+ > Parameters that control the logging
139
+
140
+ log_completions (`bool`, *optional*, defaults to `False`):
141
+ Whether to log a sample of (prompt, completion) pairs every `logging_steps` steps. If `rich` is
142
+ installed, it prints the sample. If `wandb` logging is enabled, it logs it to `wandb`.
143
+ """
144
+
145
+ # Parameters that control the model and reference model
146
+ model_init_kwargs: Optional[dict] = field(
147
+ default=None,
148
+ metadata={
149
+ "help": "Keyword arguments for `transformers.AutoModelForCausalLM.from_pretrained`, used when the `model` "
150
+ "argument of the `GRPOTrainer` is provided as a string."
151
+ },
152
+ )
153
+
154
+ # Parameters that control the data preprocessing
155
+ # The default value remove_unused_columns is overwritten from the parent class, because in GRPO we usually rely on
156
+ # additional columns to compute the reward
157
+ remove_unused_columns: Optional[bool] = field(
158
+ default=False,
159
+ metadata={
160
+ "help": "Whether to only keep the column 'prompt' in the dataset. If you use a custom reward function "
161
+ "that requires any column other than 'prompts' and 'completions', you should keep this to `False`."
162
+ },
163
+ )
164
+ max_prompt_length: Optional[int] = field(
165
+ default=512,
166
+ metadata={
167
+ "help": "Maximum length of the prompt. If the prompt is longer than this value, it will be truncated left."
168
+ },
169
+ )
170
+ num_generations: Optional[int] = field(
171
+ default=8,
172
+ metadata={
173
+ "help": "Number of generations to sample. The global batch size (num_processes * per_device_batch_size) "
174
+ "must be divisible by this value."
175
+ },
176
+ )
177
+ max_completion_length: Optional[int] = field(
178
+ default=800,
179
+ metadata={"help": "Maximum length of the generated completion."},
180
+ )
181
+ ds3_gather_for_generation: bool = field(
182
+ default=True,
183
+ metadata={
184
+ "help": "This setting applies to DeepSpeed ZeRO-3. If enabled, the policy model weights are gathered for "
185
+ "generation, improving generation speed. However, disabling this option allows training models that "
186
+ "exceed the VRAM capacity of a single GPU, albeit at the cost of slower generation. Disabling this option "
187
+ "is not compatible with vLLM generation."
188
+ },
189
+ )
190
+
191
+ # Parameters that control generation
192
+ temperature: float = field(
193
+ default=0.6,
194
+ metadata={"help": "Temperature for sampling. The higher the temperature, the more random the completions."},
195
+ )
196
+ top_p: float = field(
197
+ default=0.95,
198
+ metadata={
199
+ "help": "Float that controls the cumulative probability of the top tokens to consider. Must be in (0, 1]. "
200
+ "Set to 1.0 to consider all tokens."
201
+ },
202
+ )
203
+ top_k: Optional[int] = field(
204
+ default=20,
205
+ metadata={
206
+ "help": "Number of highest probability vocabulary tokens to keep for top-k-filtering. If `None`, "
207
+ "top-k-filtering is disabled."
208
+ },
209
+ )
210
+ min_p: Optional[float] = field(
211
+ default=None,
212
+ metadata={
213
+ "help": "Minimum token probability, which will be scaled by the probability of the most likely token. It "
214
+ "must be a value between 0.0 and 1.0. Typical values are in the 0.01-0.2 range."
215
+ },
216
+ )
217
+ repetition_penalty: float = field(
218
+ default=1.0,
219
+ metadata={
220
+ "help": "Float that penalizes new tokens based on whether they appear in the prompt and the generated "
221
+ "text so far. Values > 1.0 encourage the model to use new tokens, while values < 1.0 encourage the model "
222
+ "to repeat tokens."
223
+ },
224
+ )
225
+ cache_implementation: Optional[str] = field(
226
+ default=None,
227
+ metadata={"help": "Implementation of the cache method for faster generation when use_vllm is set to False."},
228
+ )
229
+
230
+ # Parameters that control generation acceleration powered by vLLM
231
+ use_vllm: Optional[bool] = field(
232
+ default=False,
233
+ metadata={
234
+ "help": "Whether to use vLLM for generating completions. If set to `True`, ensure that a GPU is kept "
235
+ "unused for training, as vLLM will require one for generation. vLLM must be installed "
236
+ "(`pip install vllm`)."
237
+ },
238
+ )
239
+ vllm_device: Optional[str] = field(
240
+ default="auto",
241
+ metadata={
242
+ "help": "Device where vLLM generation will run, e.g. 'cuda:1'. If set to 'auto' (default), the system "
243
+ "will automatically select the next available GPU after the last one used for training. This assumes "
244
+ "that training has not already occupied all available GPUs."
245
+ },
246
+ )
247
+ vllm_gpu_memory_utilization: float = field(
248
+ default=0.9,
249
+ metadata={
250
+ "help": "Ratio (between 0 and 1) of GPU memory to reserve for the model weights, activations, and KV "
251
+ "cache on the device dedicated to generation powered by vLLM. Higher values will increase the KV cache "
252
+ "size and thus improve the model's throughput. However, if the value is too high, it may cause "
253
+ "out-of-memory (OOM) errors during initialization."
254
+ },
255
+ )
256
+ vllm_dtype: Optional[str] = field(
257
+ default="auto",
258
+ metadata={
259
+ "help": "Data type to use for vLLM generation. If set to 'auto', the data type will be automatically "
260
+ "determined based on the model configuration. Find the supported values in the vLLM documentation."
261
+ },
262
+ )
263
+ vllm_max_model_len: Optional[int] = field(
264
+ default=None,
265
+ metadata={
266
+ "help": "If set, the `max_model_len` to use for vLLM. This could be useful when running with reduced "
267
+ "`vllm_gpu_memory_utilization`, leading to a reduced KV cache size. If not set, vLLM will use the model "
268
+ "context size, which might be much larger than the KV cache, leading to inefficiencies."
269
+ },
270
+ )
271
+ vllm_enable_prefix_caching: Optional[bool] = field(
272
+ default=True,
273
+ metadata={
274
+ "help": "Whether to enable prefix caching in vLLM. If set to `True` (default), ensure that the model and "
275
+ "the hardware support this feature."
276
+ },
277
+ )
278
+ vllm_guided_decoding_regex: Optional[str] = field(
279
+ default=None,
280
+ metadata={"help": "Regex for vLLM guided decoding. If `None` (default), guided decoding is disabled."},
281
+ )
282
+
283
+ # Parameters that control the training
284
+ learning_rate: float = field(
285
+ default=1e-6,
286
+ metadata={
287
+ "help": "Initial learning rate for `AdamW` optimizer. The default value replaces that of "
288
+ "`transformers.TrainingArguments`."
289
+ },
290
+ )
291
+ beta: float = field(
292
+ default=0.04,
293
+ metadata={
294
+ "help": "KL coefficient. If `0.0`, the reference model is not loaded, reducing memory usage and improving "
295
+ "training speed, but may be numerically unstable for long training runs."
296
+ },
297
+ )
298
+ num_iterations: int = field(
299
+ default=1,
300
+ metadata={"help": "Number of iterations per batch (denoted as μ in the algorithm)."},
301
+ )
302
+ epsilon: float = field(
303
+ default=0.2,
304
+ metadata={"help": "Epsilon value for clipping."},
305
+ )
306
+ epsilon_high: Optional[float] = field(
307
+ default=None,
308
+ metadata={
309
+ "help": "Upper-bound epsilon value for clipping. If not specified, it defaults to the same value as the "
310
+ "lower-bound specified in argument `epsilon`. Paper DAPO recommends `0.28`."
311
+ },
312
+ )
313
+ reward_weights: Optional[list[float]] = field(
314
+ default=None,
315
+ metadata={
316
+ "help": "Weights for each reward function. Must match the number of reward functions. If `None`, all "
317
+ "rewards are weighted equally with weight `1.0`."
318
+ },
319
+ )
320
+ sync_ref_model: bool = field(
321
+ default=False,
322
+ metadata={
323
+ "help": "Whether to synchronize the reference model with the active model every `ref_model_sync_steps` "
324
+ "steps, using the `ref_model_mixup_alpha` parameter."
325
+ },
326
+ )
327
+ ref_model_mixup_alpha: float = field(
328
+ default=0.6,
329
+ metadata={
330
+ "help": "α parameter from the TR-DPO paper, which controls the mix between the current policy and the "
331
+ "previous reference policy during updates. The reference policy is updated according to the equation: "
332
+ "`π_ref = α * π_θ + (1 - α) * π_ref_prev`. To use this parameter, you must set `sync_ref_model=True`."
333
+ },
334
+ )
335
+ ref_model_sync_steps: int = field(
336
+ default=512,
337
+ metadata={
338
+ "help": "τ parameter from the TR-DPO paper, which determines how frequently the current policy is "
339
+ "synchronized with the reference policy. To use this parameter, you must set `sync_ref_model=True`."
340
+ },
341
+ )
342
+
343
+ # Parameters that control the logging
344
+ log_completions: bool = field(
345
+ default=True,
346
+ metadata={
347
+ "help": "Whether to log a sample of (prompt, completion) pairs every `logging_steps` steps. If `rich` is "
348
+ "installed, it prints the sample. If `wandb` logging is enabled, it logs it to `wandb`."
349
+ },
350
+ )
351
+
352
+ report_to: Union[None, str, list[str]] = field(
353
+ default="wandb", metadata={"help": "The list of integrations to report the results and logs to."}
354
+ )
355
+
356
+ logging_first_step: bool = field(default=False, metadata={"help": "Log the first global_step"})
357
+ logging_steps: float = field(
358
+ default=2,
359
+ metadata={
360
+ "help": (
361
+ "Log every X updates steps. Should be an integer or a float in range `[0,1)`. "
362
+ "If smaller than 1, will be interpreted as ratio of total training steps."
363
+ )
364
+ },
365
+ )
BioReason-main/bioreason/trainer/grpo_trainer.py ADDED
@@ -0,0 +1,905 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import os
16
+ import time
17
+ import textwrap
18
+ import pandas as pd
19
+ from collections import defaultdict
20
+ from typing import Any, Callable, Optional, Union, Sized
21
+
22
+ import torch
23
+ import torch.utils.data
24
+ import transformers
25
+ from datasets import Dataset, IterableDataset
26
+ from packaging import version
27
+ from transformers import (
28
+ AriaForConditionalGeneration,
29
+ AriaProcessor,
30
+ AutoModelForCausalLM,
31
+ AutoModelForSequenceClassification,
32
+ AutoProcessor,
33
+ AutoTokenizer,
34
+ GenerationConfig,
35
+ PreTrainedModel,
36
+ PreTrainedTokenizerBase,
37
+ Qwen2VLForConditionalGeneration,
38
+ Qwen2_5_VLForConditionalGeneration,
39
+ Trainer,
40
+ TrainerCallback,
41
+ is_wandb_available,
42
+ )
43
+ from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
44
+ from transformers.utils import is_peft_available
45
+
46
+ from trl.data_utils import apply_chat_template, is_conversational, maybe_apply_chat_template
47
+ from trl.models import create_reference_model, prepare_deepspeed, unwrap_model_for_generation
48
+ from trl.trainer.grpo_config import GRPOConfig
49
+ from trl.trainer.utils import generate_model_card, get_comet_experiment_url
50
+ # from trl import GRPOTrainer
51
+
52
+ from accelerate.utils import is_peft_model, set_seed, gather_object
53
+ import PIL.Image
54
+
55
+ import copy
56
+ from torch.utils.data import Sampler
57
+ import warnings
58
+
59
+ if is_peft_available():
60
+ from peft import PeftConfig, get_peft_model, prepare_model_for_kbit_training
61
+
62
+ if is_wandb_available():
63
+ import wandb
64
+
65
+ from bioreason.dna_modules.dna_module import DNABaseModule
66
+ from bioreason.trainer import DNALLMGRPOConfig
67
+ # What we call a reward function is a callable that takes a list of prompts and completions and returns a list of
68
+ # rewards. When it's a string, it's a model ID, so it's loaded as a pretrained model.
69
+ RewardFunc = Union[str, PreTrainedModel, Callable[[list, list], list[float]]]
70
+
71
+
72
+ class RepeatRandomSampler(Sampler):
73
+ """
74
+ Sampler that repeats the indices of a dataset in a structured manner.
75
+
76
+ Args:
77
+ data_source (`Sized`):
78
+ Dataset to sample from.
79
+ mini_repeat_count (`int`):
80
+ Number of times to repeat each index per batch.
81
+ batch_size (`int`, *optional*, defaults to `1`):
82
+ Number of unique indices per batch.
83
+ repeat_count (`int`, *optional*, defaults to `1`):
84
+ Number of times to repeat the full sampling process.
85
+ seed (`int` or `None`, *optional*, defaults to `None`):
86
+ Random seed for reproducibility.
87
+ """
88
+
89
+ def __init__(
90
+ self,
91
+ data_source: Sized,
92
+ mini_repeat_count: int,
93
+ batch_size: int = 1,
94
+ repeat_count: int = 1,
95
+ seed: Optional[int] = None,
96
+ ):
97
+ self.data_source = data_source
98
+ self.mini_repeat_count = mini_repeat_count
99
+ self.batch_size = batch_size
100
+ self.repeat_count = repeat_count
101
+ self.num_samples = len(data_source)
102
+ self.seed = seed
103
+ self.generator = torch.Generator()
104
+ if seed is not None:
105
+ self.generator.manual_seed(seed)
106
+
107
+ def __iter__(self):
108
+ indexes = torch.randperm(self.num_samples, generator=self.generator).tolist()
109
+ indexes = [indexes[i : i + self.batch_size] for i in range(0, len(indexes), self.batch_size)]
110
+ indexes = [chunk for chunk in indexes if len(chunk) == self.batch_size]
111
+
112
+ for chunk in indexes:
113
+ for _ in range(self.repeat_count):
114
+ for index in chunk:
115
+ for _ in range(self.mini_repeat_count):
116
+ yield index
117
+
118
+ def __len__(self) -> int:
119
+ return self.num_samples * self.mini_repeat_count * self.repeat_count
120
+
121
+
122
+ class DNALLMGRPOTrainer(Trainer):
123
+ """
124
+ Trainer for the Group Relative Policy Optimization (GRPO) method. This algorithm was initially proposed in the
125
+ paper [DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models](https://huggingface.co/papers/2402.03300).
126
+
127
+ Example:
128
+
129
+ ```python
130
+ from datasets import load_dataset
131
+ from trl import GRPOTrainer
132
+
133
+ dataset = load_dataset("trl-lib/tldr", split="train")
134
+
135
+ trainer = GRPOTrainer(
136
+ model="Qwen/Qwen2-0.5B-Instruct",
137
+ reward_funcs="weqweasdas/RM-Gemma-2B",
138
+ train_dataset=dataset,
139
+ )
140
+
141
+ trainer.train()
142
+ ```
143
+
144
+ Args:
145
+ model (`Union[str, PreTrainedModel]`):
146
+ Model to be trained. Can be either:
147
+
148
+ - A string, being the *model id* of a pretrained model hosted inside a model repo on huggingface.co, or
149
+ a path to a *directory* containing model weights saved using
150
+ [`~transformers.PreTrainedModel.save_pretrained`], e.g., `'./my_model_directory/'`. The model is
151
+ loaded using [`~transformers.AutoModelForCausalLM.from_pretrained`] with the keywork arguments
152
+ in `args.model_init_kwargs`.
153
+ - A [`~transformers.PreTrainedModel`] object. Only causal language models are supported.
154
+ reward_funcs (`Union[RewardFunc, list[RewardFunc]]`):
155
+ Reward functions to be used for computing the rewards. To compute the rewards, we call all the reward
156
+ functions with the prompts and completions and sum the rewards. Can be either:
157
+
158
+ - A single reward function, such as:
159
+ - A string: The *model ID* of a pretrained model hosted inside a model repo on huggingface.co, or a
160
+ path to a *directory* containing model weights saved using
161
+ [`~transformers.PreTrainedModel.save_pretrained`], e.g., `'./my_model_directory/'`. The model is loaded
162
+ using [`~transformers.AutoModelForSequenceClassification.from_pretrained`] with `num_labels=1` and the
163
+ keyword arguments in `args.model_init_kwargs`.
164
+ - A [`~transformers.PreTrainedModel`] object: Only sequence classification models are supported.
165
+ - A custom reward function: The function is provided with the prompts and the generated completions,
166
+ plus any additional columns in the dataset. It should return a list of rewards. For more details, see
167
+ [Using a custom reward function](#using-a-custom-reward-function).
168
+ - A list of reward functions, where each item can independently be any of the above types. Mixing different
169
+ types within the list (e.g., a string model ID and a custom reward function) is allowed.
170
+ args ([`GRPOConfig`], *optional*, defaults to `None`):
171
+ Configuration for this trainer. If `None`, a default configuration is used.
172
+ train_dataset ([`~datasets.Dataset`] or [`~datasets.IterableDataset`]):
173
+ Dataset to use for training. It must include a column `"prompt"`. Any additional columns in the dataset is
174
+ ignored. The format of the samples can be either:
175
+
176
+ - [Standard](dataset_formats#standard): Each sample contains plain text.
177
+ - [Conversational](dataset_formats#conversational): Each sample contains structured messages (e.g., role
178
+ and content).
179
+ eval_dataset ([`~datasets.Dataset`], [`~datasets.IterableDataset`] or `dict[str, Union[Dataset, IterableDataset]]`):
180
+ Dataset to use for evaluation. It must meet the same requirements as `train_dataset`.
181
+ processing_class ([`~transformers.PreTrainedTokenizerBase`], *optional*, defaults to `None`):
182
+ Processing class used to process the data. The padding side must be set to "left". If `None`, the
183
+ processing class is loaded from the model's name with [`~transformers.AutoTokenizer.from_pretrained`].
184
+ reward_processing_classes (`Union[PreTrainedTokenizerBase, list[PreTrainedTokenizerBase]]`, *optional*, defaults to `None`):
185
+ Processing classes corresponding to the reward functions specified in `reward_funcs`. Can be either:
186
+
187
+ - A single processing class: Used when `reward_funcs` contains only one reward function.
188
+ - A list of processing classes: Must match the order and length of the reward functions in `reward_funcs`.
189
+ If set to `None`, or if an element of the list corresponding to a [`~transformers.PreTrainedModel`] is
190
+ `None`, the tokenizer for the model is automatically loaded using [`~transformers.AutoTokenizer.from_pretrained`].
191
+ For elements in `reward_funcs` that are custom reward functions (not [`~transformers.PreTrainedModel`]),
192
+ the corresponding entries in `reward_processing_classes` are ignored.
193
+ callbacks (list of [`~transformers.TrainerCallback`], *optional*, defaults to `None`):
194
+ List of callbacks to customize the training loop. Will add those to the list of default callbacks
195
+ detailed in [here](https://huggingface.co/docs/transformers/main_classes/callback).
196
+
197
+ If you want to remove one of the default callbacks used, use the [`~transformers.Trainer.remove_callback`]
198
+ method.
199
+ optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`, *optional*, defaults to `(None, None)`):
200
+ A tuple containing the optimizer and the scheduler to use. Will default to an instance of [`AdamW`] on your
201
+ model and a scheduler given by [`get_linear_schedule_with_warmup`] controlled by `args`.
202
+ peft_config ([`~peft.PeftConfig`], *optional*, defaults to `None`):
203
+ PEFT configuration used to wrap the model. If `None`, the model is not wrapped.
204
+ """
205
+
206
+ def __init__(
207
+ self,
208
+ model: Union[str, PreTrainedModel],
209
+ reward_funcs: Union[RewardFunc, list[RewardFunc]],
210
+ args: DNALLMGRPOConfig = None,
211
+ dna_module: DNABaseModule = None,
212
+ train_dataset: Optional[Union[Dataset, IterableDataset]] = None,
213
+ eval_dataset: Optional[Union[Dataset, IterableDataset, dict[str, Union[Dataset, IterableDataset]]]] = None,
214
+ processing_class: Optional[PreTrainedTokenizerBase] = None,
215
+ reward_processing_classes: Optional[Union[PreTrainedTokenizerBase, list[PreTrainedTokenizerBase]]] = None,
216
+ callbacks: Optional[list[TrainerCallback]] = None,
217
+ optimizers: tuple[Optional[torch.optim.Optimizer], Optional[torch.optim.lr_scheduler.LambdaLR]] = (None, None),
218
+ peft_config: Optional["PeftConfig"] = None,
219
+ freeze_dna_modules: Optional[bool] = False,
220
+ attn_implementation: str = "flash_attention_2",
221
+ torch_dtype: str = "bfloat16",
222
+ **kwargs,
223
+ ):
224
+ # Args
225
+ if args is None:
226
+ model_name = model if isinstance(model, str) else model.config._name_or_path
227
+ model_name = model_name.split("/")[-1]
228
+ args = GRPOConfig(f"{model_name}-GRPO")
229
+
230
+ self.dna_module = dna_module
231
+
232
+ # Models
233
+ # Trained model
234
+ model_init_kwargs = args.model_init_kwargs or {}
235
+ # FIXME
236
+ # Remember to modify it in the invernvl
237
+ model_init_kwargs["attn_implementation"] = attn_implementation
238
+ if model_init_kwargs.get("torch_dtype") is None:
239
+ model_init_kwargs["torch_dtype"] = torch_dtype
240
+
241
+ assert not isinstance(model, str), "model must NOT be a string in the current implementation"
242
+
243
+ torch_dtype = model_init_kwargs.get("torch_dtype")
244
+ if isinstance(torch_dtype, torch.dtype) or torch_dtype == "auto" or torch_dtype is None:
245
+ pass # torch_dtype is already a torch.dtype or "auto" or None
246
+ elif isinstance(torch_dtype, str): # it's a str, but not "auto"
247
+ torch_dtype = getattr(torch, torch_dtype)
248
+ else:
249
+ raise ValueError(
250
+ "Invalid `torch_dtype` passed to `GRPOConfig`. Expected either 'auto' or a string representing "
251
+ f"a `torch.dtype` (e.g., 'float32'), but got {torch_dtype}."
252
+ )
253
+ # Disable caching if gradient checkpointing is enabled (not supported)
254
+ model_init_kwargs["use_cache"] = (
255
+ False if args.gradient_checkpointing else model_init_kwargs.get("use_cache")
256
+ )
257
+
258
+ # LoRA
259
+ self.dna_modules_keywords = self.dna_module.get_dnallm_modules_keywords()
260
+ if peft_config is not None:
261
+ print("Applying LoRA...")
262
+ def find_all_linear_names(model, multimodal_keywords):
263
+ cls = torch.nn.Linear
264
+ lora_module_names = set()
265
+ for name, module in model.named_modules():
266
+ print('name:', name, 'module:', module)
267
+ # LoRA is not applied to the DNA modules
268
+ if any(mm_keyword in name for mm_keyword in multimodal_keywords):
269
+ continue
270
+ if isinstance(module, cls):
271
+ lora_module_names.add(name)
272
+ for m in lora_module_names: # needed for 16-bit
273
+ if "embed_tokens" in m:
274
+ lora_module_names.remove(m)
275
+ return list(lora_module_names)
276
+ target_modules = find_all_linear_names(model, self.dna_modules_keywords)
277
+ peft_config.target_modules = target_modules
278
+ model = prepare_model_for_kbit_training(model)
279
+ model = get_peft_model(model, peft_config)
280
+
281
+ # Freeze DNA modules
282
+ if freeze_dna_modules:
283
+ print("Freezing DNA modules...")
284
+ for p in model.dna_model.parameters():
285
+ p.requires_grad = False
286
+
287
+ # Make projection layer trainable
288
+ for p in model.dna_projection.parameters():
289
+ p.required_grad = True
290
+
291
+ # Compute the number of trainable parameters and print the parameter that is trainable
292
+ trainable_params = [p for p in model.parameters() if p.requires_grad]
293
+ total_params = sum(p.numel() for p in trainable_params)
294
+ # for n, p in model.named_parameters():
295
+ # if p.requires_grad:
296
+ # print(n, p.shape)
297
+ print(f"Total trainable parameters: {total_params}")
298
+
299
+ # Enable gradient checkpointing if requested
300
+ if args.gradient_checkpointing:
301
+ model = self._enable_gradient_checkpointing(model, args)
302
+
303
+ # Reference model
304
+ self.beta = args.beta
305
+ if self.beta == 0.0:
306
+ # If beta is 0.0, the reference model is not needed
307
+ self.ref_model = None
308
+ elif is_deepspeed_zero3_enabled():
309
+ self.ref_model = model_cls.from_pretrained(model_id, **model_init_kwargs)
310
+ elif is_peft_model(model):
311
+ # If PEFT is used, the reference model is not needed since the adapter can be disabled
312
+ # to revert to the initial model.
313
+ self.ref_model = None
314
+ else:
315
+ # If PEFT configuration is not provided, create a reference model based on the initial model.
316
+ self.ref_model = create_reference_model(model)
317
+
318
+ # Processing class
319
+ if processing_class is None:
320
+ processing_cls = self.dna_module.get_processing_class()
321
+
322
+ #if isinstance(model.text_model)
323
+ processing_class = processing_cls(tokenizer=model.text_tokenizer, dna_tokenizer=model.dna_tokenizer)
324
+ # print(model.tokenizer.chat_template)
325
+ for component, processing_keyword in self.dna_module.get_custom_processing_keywords():
326
+ if processing_keyword in kwargs:
327
+ # If we cannot find component in processing_class, return the processing_class itself
328
+ processing_component = getattr(processing_class, component, processing_class)
329
+ setattr(processing_component, processing_keyword, kwargs[processing_keyword])
330
+ if getattr(processing_class, "tokenizer", None) is not None:
331
+ pad_token_id = processing_class.tokenizer.pad_token_id
332
+ processing_class.pad_token_id = pad_token_id
333
+ processing_class.eos_token_id = processing_class.tokenizer.eos_token_id
334
+ else:
335
+ assert isinstance(processing_class, PreTrainedTokenizerBase), "processing_class must be an instance of PreTrainedTokenizerBase if it has no tokenizer attribute"
336
+ pad_token_id = processing_class.pad_token_id
337
+
338
+ self.dna_module.post_model_init(model, processing_class)
339
+ self.dna_module.post_model_init(self.ref_model, processing_class)
340
+
341
+ # Reward functions
342
+ if not isinstance(reward_funcs, list):
343
+ reward_funcs = [reward_funcs]
344
+ for i, reward_func in enumerate(reward_funcs):
345
+ if isinstance(reward_func, str):
346
+ reward_funcs[i] = AutoModelForSequenceClassification.from_pretrained(
347
+ reward_func, num_labels=1, **model_init_kwargs
348
+ )
349
+ self.reward_funcs = reward_funcs
350
+
351
+ # Reward processing class
352
+ if reward_processing_classes is None:
353
+ reward_processing_classes = [None] * len(reward_funcs)
354
+ elif not isinstance(reward_processing_classes, list):
355
+ reward_processing_classes = [reward_processing_classes]
356
+ else:
357
+ if len(reward_processing_classes) != len(reward_funcs):
358
+ raise ValueError("The number of reward processing classes must match the number of reward functions.")
359
+
360
+ for i, (reward_processing_class, reward_func) in enumerate(zip(reward_processing_classes, reward_funcs)):
361
+ if isinstance(reward_func, PreTrainedModel):
362
+ if reward_processing_class is None:
363
+ reward_processing_class = AutoTokenizer.from_pretrained(reward_func.config._name_or_path)
364
+ if reward_processing_class.pad_token_id is None:
365
+ reward_processing_class.pad_token = reward_processing_class.eos_token
366
+ # The reward model computes the reward for the latest non-padded token in the input sequence.
367
+ # So it's important to set the pad token ID to the padding token ID of the processing class.
368
+ reward_func.config.pad_token_id = reward_processing_class.pad_token_id
369
+ reward_processing_classes[i] = reward_processing_class
370
+ self.reward_processing_classes = reward_processing_classes
371
+
372
+ # Data collator
373
+ def data_collator(features): # No data collation is needed in GRPO
374
+ return features
375
+
376
+ # Training arguments
377
+ self.max_prompt_length = args.max_prompt_length
378
+ self.max_prompt_length = None
379
+ if args.max_prompt_length is not None:
380
+ warnings.warn("Setting max_prompt_length is currently not supported, it has been set to None")
381
+
382
+ self.max_completion_length = args.max_completion_length # = |o_i| in the GRPO paper
383
+ self.num_generations = args.num_generations # = G in the GRPO paper
384
+ self.generation_config = GenerationConfig(
385
+ max_new_tokens=self.max_completion_length,
386
+ do_sample=True,
387
+ temperature=0.6,
388
+ top_p=0.95,
389
+ top_k=20,
390
+ pad_token_id=pad_token_id,
391
+ )
392
+ if hasattr(self.dna_module, "get_eos_token_id"): # For InternVL
393
+ self.generation_config.eos_token_id = self.dna_module.get_eos_token_id(processing_class)
394
+ self.beta = args.beta
395
+ self.epsilon_low = args.epsilon
396
+ self.epsilon_high = args.epsilon_high if args.epsilon_high is not None else args.epsilon
397
+
398
+ # Multi-step
399
+ self.num_iterations = args.num_iterations # = 𝜇 in the GRPO paper
400
+ # Tracks the number of iterations (forward + backward passes), including those within a gradient accumulation cycle
401
+ self._step = 0
402
+ # Buffer the batch to reuse generated outputs across multiple updates
403
+ self._buffered_inputs = [None] * args.gradient_accumulation_steps
404
+
405
+ # The trainer estimates the number of FLOPs (floating-point operations) using the number of elements in the
406
+ # input tensor associated with the key "input_ids". However, in GRPO, the sampled data does not include the
407
+ # "input_ids" key. Instead, the available keys is "prompt". As a result, the trainer issues the warning:
408
+ # "Could not estimate the number of tokens of the input, floating-point operations will not be computed." To
409
+ # suppress this warning, we set the "estimate_tokens" key in the model's "warnings_issued" dictionary to True.
410
+ # This acts as a flag to indicate that the warning has already been issued.
411
+ model.warnings_issued["estimate_tokens"] = True
412
+
413
+ # Initialize the metrics
414
+ self._metrics = defaultdict(list)
415
+ self.log_completions = args.log_completions
416
+
417
+ super().__init__(
418
+ model=model,
419
+ args=args,
420
+ data_collator=data_collator,
421
+ train_dataset=train_dataset,
422
+ eval_dataset=eval_dataset,
423
+ processing_class=processing_class,
424
+ callbacks=callbacks,
425
+ optimizers=optimizers,
426
+ )
427
+
428
+ # Check if the per_device_train/eval_batch_size * num processes can be divided by the number of generations
429
+ num_processes = self.accelerator.num_processes
430
+ global_batch_size = args.per_device_train_batch_size * num_processes
431
+ possible_values = [n_gen for n_gen in range(2, global_batch_size + 1) if (global_batch_size) % n_gen == 0]
432
+ if self.num_generations not in possible_values:
433
+ raise ValueError(
434
+ f"The global train batch size ({num_processes} x {args.per_device_train_batch_size}) must be evenly "
435
+ f"divisible by the number of generations per prompt ({self.num_generations}). Given the current train "
436
+ f"batch size, the valid values for the number of generations are: {possible_values}."
437
+ )
438
+ if self.args.eval_strategy != "no":
439
+ global_batch_size = args.per_device_eval_batch_size * num_processes
440
+ possible_values = [n_gen for n_gen in range(2, global_batch_size + 1) if (global_batch_size) % n_gen == 0]
441
+ if self.num_generations not in possible_values:
442
+ raise ValueError(
443
+ f"The global eval batch size ({num_processes} x {args.per_device_eval_batch_size}) must be evenly "
444
+ f"divisible by the number of generations per prompt ({self.num_generations}). Given the current "
445
+ f"eval batch size, the valid values for the number of generations are: {possible_values}."
446
+ )
447
+
448
+ # Ensure each process receives a unique seed to prevent duplicate completions when generating with
449
+ # transformers if num_generations exceeds per_device_train_batch_size. We could skip it if we use vLLM, but
450
+ # it's safer to set it in all cases.
451
+ set_seed(args.seed, device_specific=True)
452
+
453
+ # Gradient accumulation requires scaled loss. Normally, loss scaling in the parent class depends on whether the
454
+ # model accepts loss-related kwargs. Since we compute our own loss, this check is irrelevant. We set
455
+ # self.model_accepts_loss_kwargs to False to enable scaling.
456
+ self.model_accepts_loss_kwargs = False
457
+
458
+ if self.ref_model is not None:
459
+ # if self.is_deepspeed_enabled:
460
+ if is_deepspeed_zero3_enabled():
461
+ self.ref_model = prepare_deepspeed(self.ref_model, self.accelerator)
462
+ else:
463
+ self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
464
+
465
+ for i, reward_func in enumerate(self.reward_funcs):
466
+ if isinstance(reward_func, PreTrainedModel):
467
+ self.reward_funcs[i] = self.accelerator.prepare_model(reward_func, evaluation_mode=True)
468
+
469
+ def _enable_gradient_checkpointing(self, model: PreTrainedModel, args: GRPOConfig) -> PreTrainedModel:
470
+ """Enables gradient checkpointing for the model."""
471
+ # Ensure use_cache is disabled
472
+ model.config.use_cache = False
473
+
474
+ # Enable gradient checkpointing on the base model for PEFT
475
+ if is_peft_model(model):
476
+ model.base_model.gradient_checkpointing_enable()
477
+ # Enable gradient checkpointing for non-PEFT models
478
+ else:
479
+ if getattr(model, "language_model", None) is not None:
480
+ # For InternVL; these operations are copied from the original training script of InternVL
481
+ model.language_model.config.use_cache = False
482
+ model.dna_model.gradient_checkpointing = True
483
+ model.dna_model.encoder.gradient_checkpointing = True
484
+ model.language_model._set_gradient_checkpointing()
485
+ # This line is necessary, otherwise the `model.gradient_checkpointing_enable()` will be executed during the training process, leading to an error since InternVL does not support this operation.
486
+ args.gradient_checkpointing = False
487
+ else:
488
+ model.gradient_checkpointing_enable()
489
+
490
+ gradient_checkpointing_kwargs = args.gradient_checkpointing_kwargs or {}
491
+ use_reentrant = (
492
+ "use_reentrant" not in gradient_checkpointing_kwargs or gradient_checkpointing_kwargs["use_reentrant"]
493
+ )
494
+
495
+ if use_reentrant:
496
+ model.enable_input_require_grads()
497
+
498
+ return model
499
+
500
+ def _set_signature_columns_if_needed(self):
501
+ # If `self.args.remove_unused_columns` is True, non-signature columns are removed.
502
+ # By default, this method sets `self._signature_columns` to the model's expected inputs.
503
+ # In GRPOTrainer, we preprocess data, so using the model's signature columns doesn't work.
504
+ # Instead, we set them to the columns expected by the `training_step` method, hence the override.
505
+ if self._signature_columns is None:
506
+ self._signature_columns = ["prompt"]
507
+
508
+
509
+ # Get the per-token log probabilities for the completions for the model and the reference model
510
+ def _get_per_token_logps(self, model, input_ids, attention_mask, **custom_multimodal_inputs):
511
+ logits = model(input_ids=input_ids, attention_mask=attention_mask, **custom_multimodal_inputs).logits # (B, L, V)
512
+ logits = logits[:, :-1, :] # (B, L-1, V), exclude the last logit: it corresponds to the next token pred
513
+ input_ids = input_ids[:, 1:] # (B, L-1), exclude the first input ID since we don't have logits for it
514
+ # Compute the log probabilities for the input tokens. Use a loop to reduce memory peak.
515
+ per_token_logps = []
516
+ for logits_row, input_ids_row in zip(logits, input_ids):
517
+ log_probs = logits_row.log_softmax(dim=-1)
518
+ token_log_prob = torch.gather(log_probs, dim=1, index=input_ids_row.unsqueeze(1)).squeeze(1)
519
+ per_token_logps.append(token_log_prob)
520
+ return torch.stack(per_token_logps)
521
+
522
+
523
+ def _prepare_inputs(self, inputs):
524
+ # Simple pass-through, just like original
525
+ return inputs
526
+
527
+ def _get_key_from_inputs(self, x, key):
528
+ ele = x.get(key, None)
529
+ assert ele is not None, f"The key {key} is not found in the input"
530
+ if isinstance(ele, list):
531
+ return [e for e in ele]
532
+ else:
533
+ return [ele]
534
+
535
+ def _generate_and_score_completions(self, inputs: dict[str, Union[torch.Tensor, Any]], model) -> dict[str, Union[torch.Tensor, Any]]:
536
+ device = self.accelerator.device
537
+ prompts = [x["prompt"] for x in inputs]
538
+ prompts_text = self.dna_module.prepare_prompt(self.processing_class, inputs)
539
+ # Handle both pre-loaded images and image paths
540
+ batch_dna_sequences = []
541
+ print("_generate_and_score_completions (GRPO):")
542
+ for x in inputs:
543
+ #print('---')
544
+ #print(x)
545
+ if 'dna_sequences' in x:
546
+ dnas = self._get_key_from_inputs(x, "dna_sequences")
547
+
548
+ for dna in dnas:
549
+ # clean if desired
550
+ pass
551
+ batch_dna_sequences.append(dnas)
552
+ # NOTE: typically appends dna, so dna_sequences is all the dna in one list
553
+ # odd. trying this instead
554
+
555
+
556
+ prompt_inputs = self.dna_module.prepare_model_inputs(
557
+ self.processing_class,
558
+ model,
559
+ prompts_text,
560
+ batch_dna_sequences,
561
+ return_tensors="pt",
562
+ padding=True,
563
+ padding_side="left",
564
+ add_special_tokens=False,
565
+ )
566
+
567
+ prompt_inputs = super()._prepare_inputs(prompt_inputs)
568
+ prompt_ids, prompt_mask = prompt_inputs["input_ids"], prompt_inputs["attention_mask"]
569
+
570
+ # max_prompt_length is not supported yet
571
+ # if self.max_prompt_length is not None:
572
+ # prompt_ids = prompt_ids[:, -self.max_prompt_length :]
573
+ # prompt_inputs["input_ids"] = prompt_ids
574
+ # prompt_mask = prompt_mask[:, -self.max_prompt_length :]
575
+ # prompt_inputs["attention_mask"] = prompt_mask
576
+
577
+ # Generate completions
578
+ start = time.time()
579
+ with unwrap_model_for_generation(model, self.accelerator) as unwrapped_model:
580
+ kwargs = {k: v for k, v in prompt_inputs.items() if k not in self.dna_module.get_non_generate_params()}
581
+ generate_returned_result = unwrapped_model.generate(
582
+ **kwargs,
583
+ generation_config=self.generation_config
584
+ )
585
+ end = time.time()
586
+ print(f"Generation time: {end - start:.9f} seconds")
587
+ prompt_length = prompt_ids.size(1)
588
+ if not self.dna_module.is_embeds_input():
589
+ prompt_completion_ids = generate_returned_result
590
+ prompt_ids = prompt_completion_ids[:, :prompt_length]
591
+ completion_ids = prompt_completion_ids[:, prompt_length:]
592
+ else:
593
+ # In this case, the input of the LLM backbone is the embedding of the combination of the image and text prompt
594
+ # So the returned result of the `generate` method only contains the completion ids
595
+ completion_ids = generate_returned_result
596
+ prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1)
597
+
598
+ # Mask everything after the first EOS token
599
+ # print('completion:', completion_ids)
600
+ # print('generate_returned_result', generate_returned_result, generate_returned_result.shape)
601
+ # print('prompt_inputs["input_ids"]', prompt_inputs["input_ids"], prompt_inputs["input_ids"].shape)
602
+ # print('prompt_ids', prompt_ids, prompt_ids.shape)
603
+ # print('prompt_length', prompt_length)
604
+ # print('prompt_completion_ids', prompt_completion_ids, prompt_completion_ids.shape)
605
+ is_eos = completion_ids == self.processing_class.eos_token_id
606
+ eos_idx = torch.full((is_eos.size(0),), is_eos.size(1), dtype=torch.long, device=device)
607
+ eos_idx[is_eos.any(dim=1)] = is_eos.int().argmax(dim=1)[is_eos.any(dim=1)]
608
+ sequence_indices = torch.arange(is_eos.size(1), device=device).expand(is_eos.size(0), -1)
609
+ completion_mask = (sequence_indices <= eos_idx.unsqueeze(1)).int()
610
+
611
+ # Concatenate prompt_mask with completion_mask for logit computation
612
+ attention_mask = torch.cat([prompt_mask, completion_mask], dim=1) # (B, P+C)
613
+
614
+ # Get the multimodal inputs
615
+ multimodal_keywords = self.dna_module.get_custom_multimodal_keywords()
616
+ multimodal_inputs = {k: prompt_inputs[k] if k in prompt_inputs else None for k in multimodal_keywords}
617
+ with torch.no_grad():
618
+ # When using num_iterations == 1, old_per_token_logps == per_token_logps, so we can skip its
619
+ # computation here, and use per_token_logps.detach() instead.
620
+ if self.num_iterations > 1:
621
+ old_per_token_logps = self._get_per_token_logps(
622
+ model, prompt_completion_ids, attention_mask, **multimodal_inputs
623
+ )
624
+ old_per_token_logps = old_per_token_logps[:, prompt_length - 1:]
625
+ else:
626
+ old_per_token_logps = None
627
+
628
+ if self.beta == 0.0:
629
+ ref_per_token_logps = None
630
+ elif self.ref_model is not None:
631
+ ref_per_token_logps = self._get_per_token_logps(
632
+ self.ref_model, prompt_completion_ids, attention_mask, **multimodal_inputs
633
+ )
634
+ else:
635
+ with self.accelerator.unwrap_model(model).disable_adapter():
636
+ ref_per_token_logps = self._get_per_token_logps(
637
+ model, prompt_completion_ids, attention_mask, **multimodal_inputs
638
+ )
639
+ if ref_per_token_logps is not None:
640
+ ref_per_token_logps = ref_per_token_logps[:, prompt_length - 1:]
641
+
642
+ # Decode the generated completions
643
+ completions_text = self.processing_class.batch_decode(completion_ids, skip_special_tokens=True)
644
+ if is_conversational(inputs[0]):
645
+ completions = [[{"role": "assistant", "content": completion}] for completion in completions_text]
646
+ else:
647
+ completions = completions_text
648
+ # Compute the rewards
649
+ # No need to duplicate prompts as we're not generating multiple completions per prompt
650
+ print("Reward calculation...")
651
+ rewards_per_func = torch.zeros(len(prompts), len(self.reward_funcs), device=device)
652
+ for i, (reward_func, reward_processing_class) in enumerate(
653
+ zip(self.reward_funcs, self.reward_processing_classes)
654
+ ):
655
+ if isinstance(reward_func, PreTrainedModel):
656
+ if is_conversational(inputs[0]):
657
+ messages = [{"messages": p + c} for p, c in zip(prompts, completions)]
658
+ texts = [apply_chat_template(x, reward_processing_class)["text"] for x in messages]
659
+ else:
660
+ texts = [p + c for p, c in zip(prompts, completions)]
661
+ reward_inputs = reward_processing_class(
662
+ texts, return_tensors="pt", padding=True, padding_side="right", add_special_tokens=False
663
+ )
664
+ reward_inputs = super()._prepare_inputs(reward_inputs)
665
+ with torch.inference_mode():
666
+ rewards_per_func[:, i] = reward_func(**reward_inputs).logits[:, 0] # Shape (B*G,)
667
+ else:
668
+ # Repeat all input columns (but "prompt" and "completion") to match the number of generations
669
+ reward_kwargs = {key: [] for key in inputs[0].keys() if key not in ["prompt", "completion"]}
670
+ for key in reward_kwargs:
671
+ for example in inputs:
672
+ # No need to duplicate prompts as we're not generating multiple completions per prompt
673
+ # reward_kwargs[key].extend([example[key]] * self.num_generations)
674
+ reward_kwargs[key].extend([example[key]])
675
+ output_reward_func = reward_func(prompts=prompts, completions=completions, **reward_kwargs)
676
+ rewards_per_func[:, i] = torch.tensor(output_reward_func, dtype=torch.float32, device=device)
677
+
678
+ # Gather rewards across processes
679
+ rewards_per_func = self.accelerator.gather(rewards_per_func)
680
+
681
+ # Sum the rewards from all reward functions
682
+ rewards = rewards_per_func.sum(dim=1)
683
+
684
+ # Compute grouped-wise rewards
685
+ # Each group consists of num_generations completions for the same prompt
686
+ mean_grouped_rewards = rewards.view(-1, self.num_generations).mean(dim=1)
687
+ std_grouped_rewards = rewards.view(-1, self.num_generations).std(dim=1)
688
+
689
+ # Normalize the rewards to compute the advantages
690
+ mean_grouped_rewards = mean_grouped_rewards.repeat_interleave(self.num_generations, dim=0)
691
+ std_grouped_rewards = std_grouped_rewards.repeat_interleave(self.num_generations, dim=0)
692
+ advantages = (rewards - mean_grouped_rewards) / (std_grouped_rewards + 1e-4)
693
+
694
+ # Get only the local slice of advantages
695
+ process_slice = slice(
696
+ self.accelerator.process_index * len(prompts),
697
+ (self.accelerator.process_index + 1) * len(prompts),
698
+ )
699
+ advantages = advantages[process_slice]
700
+
701
+ # Log the metrics
702
+ print("Logging metrics...")
703
+ completion_length = self.accelerator.gather_for_metrics(completion_mask.sum(1)).float().mean().item()
704
+ self._metrics["completion_length"].append(completion_length)
705
+
706
+ reward_per_func = self.accelerator.gather_for_metrics(rewards_per_func).mean(0)
707
+ for i, reward_func in enumerate(self.reward_funcs):
708
+ if isinstance(reward_func, PreTrainedModel):
709
+ reward_func_name = reward_func.config._name_or_path.split("/")[-1]
710
+ else:
711
+ reward_func_name = reward_func.__name__
712
+ self._metrics[f"rewards/{reward_func_name}"].append(reward_per_func[i].item())
713
+
714
+ self._metrics["reward"].append(self.accelerator.gather_for_metrics(rewards).mean().item())
715
+
716
+ self._metrics["reward_std"].append(self.accelerator.gather_for_metrics(std_grouped_rewards).mean().item())
717
+
718
+ print(self.log_completions, self.state.global_step, self.args.logging_steps, self.args.report_to)
719
+ if (
720
+ self.log_completions
721
+ and self.state.global_step % self.args.logging_steps == 0
722
+ and "wandb" in self.args.report_to
723
+ ):
724
+ timestamp = time.time()
725
+
726
+ # Get the length of one of the other arrays
727
+ num_items = len(gather_object(prompts_text))
728
+
729
+ table = {
730
+ "step": [f"{self.state.global_step}_{timestamp}"] * num_items, # Repeat to match length
731
+ "prompt": gather_object(prompts_text),
732
+ "completion": gather_object(completions_text),
733
+ "reward": rewards.tolist(),
734
+ }
735
+ df = pd.DataFrame(table)
736
+
737
+ if wandb.run is not None and self.accelerator.is_main_process:
738
+ wandb.log({f"completions_{self.state.global_step}_{timestamp}": wandb.Table(dataframe=df)})
739
+
740
+ return {
741
+ "prompt_ids": prompt_ids,
742
+ "prompt_mask": prompt_mask,
743
+ "completion_ids": completion_ids,
744
+ "completion_mask": completion_mask,
745
+ "old_per_token_logps": old_per_token_logps,
746
+ "ref_per_token_logps": ref_per_token_logps,
747
+ "advantages": advantages,
748
+ "multimodal_inputs": multimodal_inputs
749
+ }
750
+
751
+ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
752
+ if return_outputs:
753
+ raise ValueError("The GRPOTrainer does not support returning outputs")
754
+
755
+ # Check if we need to generate new completions or use buffered ones
756
+ print("index 1")
757
+ if self.state.global_step % self.num_iterations == 0:
758
+ inputs = self._generate_and_score_completions(inputs, model)
759
+ self._buffered_inputs[self._step % self.args.gradient_accumulation_steps] = inputs
760
+ else:
761
+ inputs = self._buffered_inputs[self._step % self.args.gradient_accumulation_steps]
762
+ self._step += 1
763
+
764
+ print("index 2")
765
+ # Get the prepared inputs
766
+ prompt_ids, prompt_mask = inputs["prompt_ids"], inputs["prompt_mask"]
767
+ completion_ids, completion_mask = inputs["completion_ids"], inputs["completion_mask"]
768
+ multimodal_inputs = inputs["multimodal_inputs"]
769
+
770
+ # Concatenate for full sequence
771
+ input_ids = torch.cat([prompt_ids, completion_ids], dim=1)
772
+ attention_mask = torch.cat([prompt_mask, completion_mask], dim=1)
773
+ print("index 3")
774
+ # Get the current policy's log probabilities
775
+
776
+ print("index 4")
777
+ per_token_logps = self._get_per_token_logps(model, input_ids, attention_mask, **multimodal_inputs)
778
+ # Get rid of the prompt (-1 because of the shift done in get_per_token_logps)
779
+ per_token_logps = per_token_logps[:, prompt_ids.size(1) - 1:]
780
+
781
+ # Get the advantages from inputs
782
+ advantages = inputs["advantages"]
783
+ print("index 5")
784
+ # When using num_iterations == 1, old_per_token_logps == per_token_logps, so we can skip its computation
785
+ # and use per_token_logps.detach() instead
786
+ old_per_token_logps = inputs["old_per_token_logps"] if self.num_iterations > 1 else per_token_logps.detach()
787
+
788
+ # Compute the policy ratio and clipped version
789
+ coef_1 = torch.exp(per_token_logps - old_per_token_logps)
790
+ coef_2 = torch.clamp(coef_1, 1 - self.epsilon_low, 1 + self.epsilon_high)
791
+ per_token_loss1 = coef_1 * advantages.unsqueeze(1)
792
+ per_token_loss2 = coef_2 * advantages.unsqueeze(1)
793
+ per_token_loss = -torch.min(per_token_loss1, per_token_loss2)
794
+ print("index 6")
795
+ # Add KL penalty if beta > 0
796
+ if self.beta > 0:
797
+ ref_per_token_logps = inputs["ref_per_token_logps"]
798
+ per_token_kl = torch.exp(ref_per_token_logps - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1
799
+ per_token_loss = per_token_loss + self.beta * per_token_kl
800
+
801
+ # Log KL divergence
802
+ mean_kl = ((per_token_kl * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()
803
+ self._metrics["kl"].append(self.accelerator.gather_for_metrics(mean_kl).mean().item())
804
+
805
+ # Compute final loss
806
+ print("Computing final loss...")
807
+ loss = ((per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()
808
+
809
+ # Log clip ratio
810
+ is_clipped = (per_token_loss1 < per_token_loss2).float()
811
+ clip_ratio = (is_clipped * completion_mask).sum() / completion_mask.sum()
812
+ self._metrics["clip_ratio"].append(self.accelerator.gather_for_metrics(clip_ratio).mean().item())
813
+
814
+ return loss
815
+
816
+ def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> None:
817
+ metrics = {key: sum(val) / len(val) for key, val in self._metrics.items()} # average the metrics
818
+ logs = {**logs, **metrics}
819
+ if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
820
+ super().log(logs, start_time)
821
+ else: # transformers<=4.46
822
+ super().log(logs)
823
+ self._metrics.clear()
824
+
825
+ def create_model_card(
826
+ self,
827
+ model_name: Optional[str] = None,
828
+ dataset_name: Optional[str] = None,
829
+ tags: Union[str, list[str], None] = None,
830
+ ):
831
+ """
832
+ Creates a draft of a model card using the information available to the `Trainer`.
833
+
834
+ Args:
835
+ model_name (`str` or `None`, *optional*, defaults to `None`):
836
+ Name of the model.
837
+ dataset_name (`str` or `None`, *optional*, defaults to `None`):
838
+ Name of the dataset used for training.
839
+ tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
840
+ Tags to be associated with the model card.
841
+ """
842
+ if not self.is_world_process_zero():
843
+ return
844
+
845
+ if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
846
+ base_model = self.model.config._name_or_path
847
+ else:
848
+ base_model = None
849
+
850
+ tags = tags or []
851
+ if isinstance(tags, str):
852
+ tags = [tags]
853
+
854
+ if hasattr(self.model.config, "unsloth_version"):
855
+ tags.append("unsloth")
856
+
857
+ citation = textwrap.dedent(
858
+ """\
859
+ @article{zhihong2024deepseekmath,
860
+ title = {{DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models}},
861
+ author = {Zhihong Shao and Peiyi Wang and Qihao Zhu and Runxin Xu and Junxiao Song and Mingchuan Zhang and Y. K. Li and Y. Wu and Daya Guo},
862
+ year = 2024,
863
+ eprint = {arXiv:2402.03300},
864
+ """
865
+ )
866
+
867
+ model_card = generate_model_card(
868
+ base_model=base_model,
869
+ model_name=model_name,
870
+ hub_model_id=self.hub_model_id,
871
+ dataset_name=dataset_name,
872
+ tags=tags,
873
+ wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
874
+ comet_url=get_comet_experiment_url(),
875
+ trainer_name="GRPO",
876
+ trainer_citation=citation,
877
+ paper_title="DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models",
878
+ paper_id="2402.03300",
879
+ )
880
+
881
+ model_card.save(os.path.join(self.args.output_dir, "README.md"))
882
+
883
+ def _get_train_sampler(self) -> Sampler:
884
+ """Returns a sampler that ensures proper data sampling for GRPO training."""
885
+ effective_batch_size = (
886
+ self.args.per_device_train_batch_size
887
+ * self.accelerator.num_processes
888
+ * self.args.gradient_accumulation_steps
889
+ )
890
+
891
+ return RepeatRandomSampler(
892
+ data_source=self.train_dataset,
893
+ mini_repeat_count=self.num_generations,
894
+ batch_size=effective_batch_size // self.num_generations,
895
+ repeat_count=self.num_iterations,
896
+ seed=self.args.seed,
897
+ )
898
+
899
+ def _get_eval_sampler(self, eval_dataset) -> Sampler:
900
+ """Returns a sampler for evaluation."""
901
+ return RepeatRandomSampler(
902
+ data_source=eval_dataset,
903
+ mini_repeat_count=self.num_generations,
904
+ seed=self.args.seed,
905
+ )
BioReason-main/bioreason/utils/__init__.py ADDED
File without changes