yuccaaa commited on
Commit
349aa7a
·
verified ·
1 Parent(s): acbfbc3

Add files using upload-large-folder tool

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. BIO/sft/qwen-production-08022302/v0-20250802-230250/checkpoint-1029-merged/added_tokens.json +24 -0
  2. BioReason-main/data/README.md +35 -0
  3. BioReason-main/data/VEP.ipynb +0 -0
  4. BioReason-main/grpo_trainer_lora_model/adapter_config.json +37 -0
  5. BioReason-main/grpo_trainer_lora_model/ds_config_stage2.json +41 -0
  6. BioReason_new/bioreason/dataset/__pycache__/protein.cpython-310.pyc +0 -0
  7. BioReason_new/bioreason/dataset/__pycache__/protein.cpython-311.pyc +0 -0
  8. BioReason_new/bioreason/dataset/__pycache__/utils.cpython-310.pyc +0 -0
  9. BioReason_new/bioreason/dataset/__pycache__/utils.cpython-311.pyc +0 -0
  10. BioReason_new/bioreason/dataset/protein.py +421 -0
  11. BioReason_new/bioreason/dataset/utils.py +135 -0
  12. BioReason_new/bioreason/models/__pycache__/protein_llm.cpython-310.pyc +0 -0
  13. BioReason_new/bioreason/models/__pycache__/protein_llm.cpython-311.pyc +0 -0
  14. BioReason_new/bioreason/models/pl/__pycache__/processing_pl.cpython-310.pyc +0 -0
  15. BioReason_new/bioreason/models/pl/__pycache__/processing_pl.cpython-311.pyc +0 -0
  16. BioReason_new/bioreason/models/pl/processing_pl.py +279 -0
  17. BioReason_new/bioreason/models/protein_llm.py +1093 -0
  18. BioReason_new/bioreason/protein_modules/_init_.py +7 -0
  19. BioReason_new/bioreason/protein_modules/protein_base_module.py +49 -0
  20. BioReason_new/bioreason/protein_modules/protein_module.py +257 -0
  21. BioReason_new/bioreason/trainer/__pycache__/contrast_trainer_new.cpython-310.pyc +0 -0
  22. BioReason_new/bioreason/trainer/__pycache__/contrast_trainer_new.cpython-311.pyc +0 -0
  23. BioReason_new/bioreason/trainer/_init_.py +11 -0
  24. BioReason_new/bioreason/trainer/contrast_trainer.py +372 -0
  25. BioReason_new/bioreason/trainer/contrast_trainer_new.py +659 -0
  26. BioReason_new/bioreason/trainer/grpo_config.py +338 -0
  27. BioReason_new/bioreason/trainer/grpo_trainer.py +719 -0
  28. BioReason_new/bioreason/utils/__pycache__/protein_utils.cpython-310.pyc +0 -0
  29. BioReason_new/bioreason/utils/__pycache__/protein_utils.cpython-311.pyc +0 -0
  30. BioReason_new/bioreason/utils/protein_utils.py +229 -0
  31. BioReason_new/readme.md +8 -0
  32. BioReason_new/reason.py +520 -0
  33. BioReason_new/run.sh +107 -0
  34. BioReason_new/run_contrast.sh +31 -0
  35. BioReason_new/train_contrastive.py +552 -0
  36. BioReason_new/train_protein_qwen.py +839 -0
  37. BioReason_new/wandb/debug-internal.log +28 -0
  38. BioReason_new/wandb/debug.log +23 -0
  39. BioReason_new/wandb/run-20250811_215805-k21eogb7/files/config.yaml +159 -0
  40. BioReason_new/wandb/run-20250811_215805-k21eogb7/files/output.log +21 -0
  41. BioReason_new/wandb/run-20250811_215805-k21eogb7/files/requirements.txt +233 -0
  42. BioReason_new/wandb/run-20250811_215805-k21eogb7/files/wandb-metadata.json +57 -0
  43. BioReason_new/wandb/run-20250811_215805-k21eogb7/files/wandb-summary.json +1 -0
  44. BioReason_new/wandb/run-20250811_215805-k21eogb7/logs/debug-internal.log +15 -0
  45. BioReason_new/wandb/run-20250811_215805-k21eogb7/logs/debug.log +22 -0
  46. BioReason_new/wandb/run-20250811_215805-k21eogb7/run-k21eogb7.wandb +0 -0
  47. BioReason_new/wandb/run-20250811_220309-2qgjwsxa/files/config.yaml +195 -0
  48. BioReason_new/wandb/run-20250811_220309-2qgjwsxa/files/output.log +30 -0
  49. BioReason_new/wandb/run-20250811_220309-2qgjwsxa/files/requirements.txt +233 -0
  50. BioReason_new/wandb/run-20250811_220309-2qgjwsxa/files/wandb-metadata.json +113 -0
BIO/sft/qwen-production-08022302/v0-20250802-230250/checkpoint-1029-merged/added_tokens.json ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "</tool_call>": 151658,
3
+ "<tool_call>": 151657,
4
+ "<|box_end|>": 151649,
5
+ "<|box_start|>": 151648,
6
+ "<|endoftext|>": 151643,
7
+ "<|file_sep|>": 151664,
8
+ "<|fim_middle|>": 151660,
9
+ "<|fim_pad|>": 151662,
10
+ "<|fim_prefix|>": 151659,
11
+ "<|fim_suffix|>": 151661,
12
+ "<|im_end|>": 151645,
13
+ "<|im_start|>": 151644,
14
+ "<|image_pad|>": 151655,
15
+ "<|object_ref_end|>": 151647,
16
+ "<|object_ref_start|>": 151646,
17
+ "<|quad_end|>": 151651,
18
+ "<|quad_start|>": 151650,
19
+ "<|repo_name|>": 151663,
20
+ "<|video_pad|>": 151656,
21
+ "<|vision_end|>": 151653,
22
+ "<|vision_pad|>": 151654,
23
+ "<|vision_start|>": 151652
24
+ }
BioReason-main/data/README.md ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # BioReasoning Data Curation
2
+
3
+ Jupyter notebooks for processing genetic variant data and creating ML datasets for biological reasoning tasks.
4
+
5
+ ## Notebooks
6
+
7
+ **Core Analysis**
8
+ - `BioReasoning_DataCuration_KEGG.ipynb` - KEGG pathway analysis with Claude API
9
+ - `Clinvar_Coding.ipynb` - ClinVar variant processing and gene mapping
10
+ - `Clinvar_SNV_Non_SNV.ipynb` - SNV/structural variant datasets with VEP annotations
11
+
12
+ **KEGG Pipeline**
13
+ - `KEGG_Data_1.ipynb` - KEGG network data processing and variant identification
14
+ - `KEGG_Data_2.ipynb` - Variant parsing and sequence generation
15
+ - `KEGG_Data_3.ipynb` - Final ML dataset creation with Q&A pairs
16
+
17
+ **Variant Prediction**
18
+ - `VEP.ipynb` - Variant effect prediction datasets (ClinVar, OMIM, eQTL)
19
+
20
+ ## Setup
21
+
22
+ ```bash
23
+ brew install brewsci/bio/edirect # For ClinVar (macOS)
24
+ export ANTHROPIC_API_KEY="your-key" # For KEGG analysis
25
+ ```
26
+
27
+ ## Usage
28
+
29
+ Each notebook has a configuration section - update paths/keys as needed, then run sequentially.
30
+
31
+ **Key Outputs:**
32
+ - KEGG biological reasoning datasets
33
+ - ClinVar variant-disease associations
34
+ - VEP prediction task datasets
35
+ - Genomic sequences with variant context
BioReason-main/data/VEP.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
BioReason-main/grpo_trainer_lora_model/adapter_config.json ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "alpha_pattern": {},
3
+ "auto_mapping": null,
4
+ "base_model_name_or_path": "unsloth/qwen2.5-1.5b-instruct-unsloth-bnb-4bit",
5
+ "bias": "none",
6
+ "eva_config": null,
7
+ "exclude_modules": null,
8
+ "fan_in_fan_out": false,
9
+ "inference_mode": false,
10
+ "init_lora_weights": true,
11
+ "layer_replication": null,
12
+ "layers_pattern": null,
13
+ "layers_to_transform": null,
14
+ "loftq_config": {},
15
+ "lora_alpha": 64,
16
+ "lora_bias": false,
17
+ "lora_dropout": 0,
18
+ "megatron_config": null,
19
+ "megatron_core": "megatron.core",
20
+ "modules_to_save": null,
21
+ "peft_type": "LORA",
22
+ "r": 64,
23
+ "rank_pattern": {},
24
+ "revision": null,
25
+ "target_modules": [
26
+ "o_proj",
27
+ "gate_proj",
28
+ "v_proj",
29
+ "up_proj",
30
+ "q_proj",
31
+ "down_proj",
32
+ "k_proj"
33
+ ],
34
+ "task_type": "CAUSAL_LM",
35
+ "use_dora": false,
36
+ "use_rslora": false
37
+ }
BioReason-main/grpo_trainer_lora_model/ds_config_stage2.json ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "bf16": {
3
+ "enabled": true
4
+ },
5
+ "optimizer": {
6
+ "type": "AdamW",
7
+ "params": {
8
+ "lr": "auto",
9
+ "betas": "auto",
10
+ "eps": "auto",
11
+ "weight_decay": "auto"
12
+ }
13
+ },
14
+ "scheduler": {
15
+ "type": "WarmupLR",
16
+ "params": {
17
+ "warmup_min_lr": "auto",
18
+ "warmup_max_lr": "auto",
19
+ "warmup_num_steps": "auto"
20
+ }
21
+ },
22
+ "zero_optimization": {
23
+ "stage": 2,
24
+ "offload_optimizer": {
25
+ "device": "cpu",
26
+ "pin_memory": true
27
+ },
28
+ "contiguous_gradients": true,
29
+ "overlap_comm": true,
30
+ "allgather_partitions": true,
31
+ "allgather_bucket_size": 5e8,
32
+ "reduce_scatter": true,
33
+ "reduce_bucket_size": 5e8
34
+ },
35
+ "gradient_accumulation_steps": "auto",
36
+ "gradient_clipping": "auto",
37
+ "steps_per_print": 2000,
38
+ "train_batch_size": "auto",
39
+ "train_micro_batch_size_per_gpu": "auto",
40
+ "wall_clock_breakdown": false
41
+ }
BioReason_new/bioreason/dataset/__pycache__/protein.cpython-310.pyc ADDED
Binary file (9.29 kB). View file
 
BioReason_new/bioreason/dataset/__pycache__/protein.cpython-311.pyc ADDED
Binary file (17 kB). View file
 
BioReason_new/bioreason/dataset/__pycache__/utils.cpython-310.pyc ADDED
Binary file (3.34 kB). View file
 
BioReason_new/bioreason/dataset/__pycache__/utils.cpython-311.pyc ADDED
Binary file (5.46 kB). View file
 
BioReason_new/bioreason/dataset/protein.py ADDED
@@ -0,0 +1,421 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import os
3
+ import random
4
+ import sys
5
+ import torch
6
+ from torch.utils.data import Dataset, DataLoader
7
+ from typing import Any, Dict, List, Tuple
8
+
9
+ from bioreason.dataset.utils import torch_to_hf_dataset
10
+ from bioreason.models.pl.processing_pl import ProteinLLMProcessor
11
+ from trl.data_utils import maybe_apply_chat_template
12
+
13
+
14
+ class ProteinDataset(Dataset):
15
+ """Dataset for protein-text paired data."""
16
+
17
+ def __init__(self, data_dir: str):
18
+ """
19
+ Initialize the dataset by loading all JSON files from the given directory.
20
+
21
+ Args:
22
+ data_dir: Path to the directory containing JSON files
23
+ """
24
+ self.data_dir = data_dir
25
+ self.data = []
26
+
27
+ # Load all JSON files
28
+ json_files = sorted([f for f in os.listdir(data_dir) if f.endswith(".json")])
29
+
30
+ # Process each file
31
+ for filename in json_files:
32
+ file_path = os.path.join(data_dir, filename)
33
+
34
+ with open(file_path, "r", encoding="utf-8") as f:
35
+ items = json.load(f)
36
+ if isinstance(items, list):
37
+ for item in items:
38
+ processed_item = self._process_item(item)
39
+ self.data.append(processed_item)
40
+ else:
41
+ processed_item = self._process_item(items)
42
+ self.data.append(processed_item)
43
+
44
+ def _process_item(self, item: Dict[str, Any]) -> Dict[str, Any]:
45
+ """
46
+ Process a single data item to format fields as required.
47
+
48
+ Args:
49
+ item: Original data item from JSON
50
+
51
+ Returns:
52
+ Processed data item
53
+ """
54
+ # Extract question as is
55
+ question = item.get("question", "")
56
+
57
+ # Convert answer to lowercase and strip whitespace
58
+ answer = item.get("answer", "").lower().strip()
59
+
60
+ # Combine reasoning steps into a single paragraph with newlines
61
+ reasoning_steps = item.get("reasoning", {}).get("reasoning_steps", [])
62
+ if isinstance(reasoning_steps, list):
63
+ reasoning = "\n".join(reasoning_steps)
64
+ else:
65
+ reasoning = str(reasoning_steps)
66
+
67
+ # Process protein sequence - remove any whitespace and convert to uppercase
68
+ protein_sequence = item.get("protein_sequence", "").replace(" ", "").upper().strip()
69
+
70
+ # Handle protein description/function
71
+ protein_description = item.get("protein_description", item.get("function", "")).strip()
72
+
73
+ return {
74
+ "question": question,
75
+ "answer": answer,
76
+ "reasoning": reasoning,
77
+ "protein_sequence": protein_sequence,
78
+ "protein_description": protein_description,
79
+ }
80
+
81
+ def __len__(self) -> int:
82
+ """Return the number of items in the dataset."""
83
+ return len(self.data)
84
+
85
+ def __getitem__(self, idx: int) -> Dict[str, Any]:
86
+ """Return a specific item from the dataset."""
87
+ return self.data[idx]
88
+
89
+
90
+ def split_protein_dataset(
91
+ dataset: ProteinDataset,
92
+ train_ratio: float = 0.8,
93
+ val_ratio: float = 0.1,
94
+ test_ratio: float = 0.1,
95
+ seed: int = 42,
96
+ ) -> Tuple[ProteinDataset, ProteinDataset, ProteinDataset]:
97
+ """
98
+ Split a protein dataset into train, validation, and test sets.
99
+
100
+ Args:
101
+ dataset: The dataset to split
102
+ train_ratio: Proportion of data for training
103
+ val_ratio: Proportion of data for validation
104
+ test_ratio: Proportion of data for testing
105
+ seed: Random seed for reproducibility
106
+
107
+ Returns:
108
+ Tuple of (train_dataset, val_dataset, test_dataset)
109
+ """
110
+ # Calculate the size of each split
111
+ dataset_size = len(dataset)
112
+ train_size = int(train_ratio * dataset_size)
113
+ val_size = int(val_ratio * dataset_size)
114
+ test_size = dataset_size - train_size - val_size
115
+ assert train_ratio + val_ratio + test_ratio == 1.0, "Ratios must sum to 1"
116
+
117
+ # Set the random seed
118
+ torch.manual_seed(seed)
119
+ random.seed(seed)
120
+
121
+ # Split the dataset
122
+ train_dataset, val_dataset, test_dataset = torch.utils.data.random_split(
123
+ dataset, [train_size, val_size, test_size]
124
+ )
125
+
126
+ return train_dataset, val_dataset, test_dataset
127
+
128
+
129
+ def create_protein_dataloader(
130
+ data_dir: str,
131
+ batch_size: int = 2,
132
+ shuffle: bool = True,
133
+ num_workers: int = 2,
134
+ pin_memory: bool = True,
135
+ ) -> DataLoader:
136
+ """
137
+ Create a DataLoader for the protein dataset.
138
+
139
+ Args:
140
+ data_dir: Path to the directory containing JSON files
141
+ batch_size: Batch size for the dataloader
142
+ shuffle: Whether to shuffle the data
143
+ num_workers: Number of worker processes for loading data
144
+ pin_memory: Whether to pin memory for faster data transfer
145
+
146
+ Returns:
147
+ DataLoader for the protein dataset
148
+ """
149
+ dataset = ProteinDataset(data_dir)
150
+ return DataLoader(
151
+ dataset,
152
+ batch_size=batch_size,
153
+ shuffle=shuffle,
154
+ num_workers=num_workers,
155
+ pin_memory=pin_memory,
156
+ )
157
+
158
+
159
+ def get_format_protein_function(model_name: str) -> Any:
160
+ """
161
+ Get the appropriate format function for a given model name.
162
+ """
163
+ if model_name.lower() == "llm":
164
+ return format_protein_for_llm
165
+ elif model_name.lower() == "protein-llm":
166
+ return format_protein_for_protein_llm
167
+ else:
168
+ raise ValueError(f"Unsupported model name: {model_name}")
169
+
170
+
171
+ def format_protein_for_protein_llm(example: Dict[str, Any]) -> Dict[str, Any]:
172
+ """
173
+ Format a protein example into the required chat format for Protein-LLM.
174
+ """
175
+ return {
176
+ "prompt": [
177
+ {
178
+ "role": "user",
179
+ "content": [
180
+ {"type": "protein", "text": None},
181
+ {"type": "text", "text": example["question"].strip()},
182
+ ],
183
+ },
184
+ {
185
+ "role": "assistant",
186
+ "reasoning_content": example["reasoning"].strip(),
187
+ "content": [
188
+ {"type": "text", "text": f"Answer: {example['answer'].strip()}"},
189
+ ],
190
+ },
191
+ ],
192
+ "protein_sequences": [
193
+ example["protein_sequence"],
194
+ ],
195
+ "answer": example["answer"],
196
+ }
197
+
198
+
199
+ def format_protein_for_llm(example: Dict[str, Any]) -> Dict[str, Any]:
200
+ """
201
+ Format a protein example into the required chat format for LLM.
202
+ """
203
+ question = f"Protein sequence: {example['protein_sequence']}\nQuestion: {example['question']}"
204
+ return {
205
+ "prompt": [
206
+ {
207
+ "role": "user",
208
+ "content": [
209
+ {"type": "protein", "text": None},
210
+ {"type": "text", "text": question.strip()},
211
+ ],
212
+ },
213
+ {
214
+ "role": "assistant",
215
+ "reasoning_content": example["reasoning"].strip(),
216
+ "content": [
217
+ {"type": "text", "text": f"Answer: {example['answer'].strip()}"},
218
+ ],
219
+ },
220
+ ],
221
+ "protein_sequences": [
222
+ "",
223
+ ],
224
+ "answer": example["answer"],
225
+ }
226
+
227
+
228
+ def format_protein_contrastive(example: Dict[str, Any]) -> Dict[str, Any]:
229
+ """
230
+ Format a protein example for contrastive learning.
231
+ """
232
+ # return {
233
+ # "protein": example["protein"],
234
+ # "text": example["text"],
235
+ # }
236
+ protein_seq = example.get("protein_sequence") or example.get("protein") or ""
237
+ text_desc = (example.get("protein_description") or
238
+ example.get("text") or
239
+ example.get("description") or
240
+ example.get("function") or "")
241
+
242
+ return {
243
+ "protein": protein_seq,
244
+ "text": text_desc,
245
+ "protein_sequence": protein_seq, # 保持向后兼容
246
+ "text_description": text_desc, # 保持向后兼容
247
+ }
248
+
249
+ def protein_text_collate_fn(batch: List[Dict[str, Any]]) -> Dict[str, List[str]]:
250
+ """
251
+ 修复后的 collate function for protein-text contrastive learning.
252
+ """
253
+ protein_sequences = []
254
+ text_sequences = []
255
+
256
+ for item in batch:
257
+ # 尝试多个可能的字段名
258
+ protein_seq = (item.get("protein_sequence") or
259
+ item.get("protein") or "")
260
+ text_seq = (item.get("text_description") or
261
+ item.get("text") or
262
+ item.get("description") or "")
263
+
264
+ protein_sequences.append(protein_seq)
265
+ text_sequences.append(text_seq)
266
+
267
+ return {
268
+ "protein_sequences": protein_sequences,
269
+ "text_sequences": text_sequences,
270
+ }
271
+
272
+
273
+ def protein_llm_collate_fn(
274
+ examples: List[Dict],
275
+ processor: ProteinLLMProcessor,
276
+ max_length_text: int,
277
+ max_length_protein: int,
278
+ return_answer_in_batch: bool = False,
279
+ ) -> Dict:
280
+ """
281
+ Custom collate function for Protein-LLM models.
282
+
283
+ Creates a batch with proper labels for supervised fine-tuning where only
284
+ the assistant responses contribute to the loss calculation.
285
+ """
286
+ prompts_text = [
287
+ maybe_apply_chat_template(example, processor)["prompt"] for example in examples
288
+ ]
289
+ batch_protein_sequences = [example["protein_sequences"] for example in examples]
290
+
291
+ batch = processor(
292
+ text=prompts_text,
293
+ batch_protein_sequences=batch_protein_sequences,
294
+ return_tensors="pt",
295
+ padding=True,
296
+ padding_side="left",
297
+ add_special_tokens=False,
298
+ max_length_text=max_length_text,
299
+ max_length_protein=max_length_protein,
300
+ )
301
+
302
+ # Create labels tensor filled with -100 (ignored in loss calculation)
303
+ labels = torch.full_like(batch["input_ids"], -100)
304
+
305
+ # Get token IDs for special markers
306
+ assistant_start_marker = "<|im_start|>assistant\n"
307
+ im_end_marker = "<|im_end|>"
308
+
309
+ assistant_start_token_ids = processor.tokenizer.encode(
310
+ assistant_start_marker, add_special_tokens=False
311
+ )
312
+ im_end_token_ids = processor.tokenizer.encode(
313
+ im_end_marker, add_special_tokens=False
314
+ )
315
+
316
+ # Convert token arrays to tensors for faster comparison
317
+ assistant_marker_tensor = torch.tensor(
318
+ assistant_start_token_ids, device=batch["input_ids"].device
319
+ )
320
+ im_end_marker_tensor = torch.tensor(
321
+ im_end_token_ids, device=batch["input_ids"].device
322
+ )
323
+
324
+ # Get dimensions for easier reference
325
+ assistant_marker_len = len(assistant_start_token_ids)
326
+ im_end_marker_len = len(im_end_token_ids)
327
+
328
+ # For each sequence in the batch
329
+ for i in range(batch["input_ids"].shape[0]):
330
+ input_ids = batch["input_ids"][i]
331
+ seq_len = input_ids.size(0)
332
+
333
+ # Track assistant sections
334
+ assistant_sections = []
335
+
336
+ # Find all assistant start markers
337
+ start_positions = []
338
+ for pos in range(seq_len - assistant_marker_len + 1):
339
+ if torch.all(
340
+ input_ids[pos : pos + assistant_marker_len] == assistant_marker_tensor
341
+ ):
342
+ start_positions.append(
343
+ pos + assistant_marker_len
344
+ ) # Store position after marker
345
+
346
+ # Find all end markers
347
+ end_positions = []
348
+ for pos in range(seq_len - im_end_marker_len + 1):
349
+ if torch.all(
350
+ input_ids[pos : pos + im_end_marker_len] == im_end_marker_tensor
351
+ ):
352
+ end_positions.append(pos) # Store position at start of end marker
353
+
354
+ # Match start and end markers to create sections
355
+ for start_pos in start_positions:
356
+ # Find the next end marker after this start position
357
+ valid_ends = [pos for pos in end_positions if pos > start_pos]
358
+ if valid_ends:
359
+ end_pos = min(valid_ends) # Take the first end marker after start
360
+ # Only include content between markers (not the markers themselves)
361
+ if start_pos < end_pos:
362
+ assistant_sections.append((start_pos, end_pos))
363
+ else:
364
+ # If no end marker, assume the section runs to the end of the sequence
365
+ assistant_sections.append((start_pos, seq_len))
366
+
367
+ # Set labels for all identified assistant sections
368
+ for start_pos, end_pos in assistant_sections:
369
+ if start_pos < end_pos and start_pos < seq_len:
370
+ end_pos = min(end_pos, seq_len) # Safety check
371
+ labels[i, start_pos:end_pos] = input_ids[start_pos:end_pos]
372
+
373
+ # Also mask padding tokens
374
+ labels[batch["input_ids"] == processor.tokenizer.pad_token_id] = -100
375
+
376
+ # Add labels to batch
377
+ batch["labels"] = labels
378
+
379
+ # Add answer to batch
380
+ if return_answer_in_batch:
381
+ batch["answer"] = [example["answer"].strip() for example in examples]
382
+
383
+ return batch
384
+
385
+
386
+ def protein_collate_fn(
387
+ batch: List[Dict[str, Any]],
388
+ protein_tokenizer: Any,
389
+ label2id: Dict[str, int],
390
+ max_length: int = 1024,
391
+ ) -> Dict[str, Any]:
392
+ """
393
+ Custom collate function for protein models.
394
+ """
395
+ protein_sequences = [item["protein_sequence"] for item in batch]
396
+
397
+ # Tokenize protein sequences
398
+ tokenized_protein = protein_tokenizer(
399
+ protein_sequences,
400
+ padding=True,
401
+ truncation=True,
402
+ max_length=max_length,
403
+ return_tensors="pt",
404
+ )
405
+
406
+ # Get labels
407
+ labels = []
408
+ for item in batch:
409
+ label = label2id[item["answer"]]
410
+ labels.append(label)
411
+
412
+ # Create labels tensor
413
+ labels_tensor = torch.tensor(labels, dtype=torch.long)
414
+
415
+ tokenized_batch = {
416
+ "protein_ids": tokenized_protein.input_ids,
417
+ "protein_attention_mask": tokenized_protein.attention_mask,
418
+ "labels": labels_tensor,
419
+ }
420
+
421
+ return tokenized_batch
BioReason_new/bioreason/dataset/utils.py ADDED
@@ -0,0 +1,135 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from datasets import Dataset as HfDataset
2
+ import torch
3
+ from torch.utils.data import Dataset
4
+ from typing import Dict, Any
5
+
6
+
7
+ def torch_to_hf_dataset(torch_dataset: Dataset) -> HfDataset:
8
+ """
9
+ Convert a PyTorch Dataset to a Hugging Face Dataset.
10
+
11
+ Args:
12
+ torch_dataset: PyTorch Dataset to convert
13
+
14
+ Returns:
15
+ HfDataset: Converted Hugging Face Dataset
16
+ """
17
+ # Extract all data from PyTorch dataset
18
+ data = []
19
+ for i in range(len(torch_dataset)):
20
+ data.append(torch_dataset[i])
21
+
22
+ return HfDataset.from_list(data)
23
+
24
+
25
+ def truncate_protein(example: Dict[str, Any], truncate_protein_per_side: int = 1024) -> Dict[str, Any]:
26
+ """
27
+ Truncate protein sequences to a maximum length.
28
+
29
+ Args:
30
+ example: Dataset example containing protein sequences
31
+ truncate_protein_per_side: Maximum length to keep from each side
32
+
33
+ Returns:
34
+ Dict[str, Any]: Modified example with truncated protein sequences
35
+ """
36
+ if "protein_sequence" in example:
37
+ protein_seq = example["protein_sequence"]
38
+ if len(protein_seq) > 2 * truncate_protein_per_side:
39
+ # Keep the first and last parts of the sequence
40
+ truncated_seq = protein_seq[:truncate_protein_per_side] + protein_seq[-truncate_protein_per_side:]
41
+ example["protein_sequence"] = truncated_seq
42
+
43
+ if "protein_sequences" in example:
44
+ truncated_sequences = []
45
+ for seq in example["protein_sequences"]:
46
+ if len(seq) > 2 * truncate_protein_per_side:
47
+ truncated_seq = seq[:truncate_protein_per_side] + seq[-truncate_protein_per_side:]
48
+ truncated_sequences.append(truncated_seq)
49
+ else:
50
+ truncated_sequences.append(seq)
51
+ example["protein_sequences"] = truncated_sequences
52
+
53
+ return example
54
+
55
+
56
+ def clean_protein_sequence(sequence: str) -> str:
57
+ """
58
+ Clean protein sequence by removing invalid characters and normalizing.
59
+
60
+ Args:
61
+ sequence: Raw protein sequence
62
+
63
+ Returns:
64
+ str: Cleaned protein sequence
65
+ """
66
+ # Standard amino acid codes
67
+ valid_amino_acids = set("ACDEFGHIKLMNPQRSTVWY")
68
+
69
+ # Remove whitespace and convert to uppercase
70
+ sequence = sequence.upper().replace(" ", "").replace("\n", "").replace("\t", "")
71
+
72
+ # Keep only valid amino acid characters
73
+ cleaned_sequence = "".join([char for char in sequence if char in valid_amino_acids])
74
+
75
+ return cleaned_sequence
76
+
77
+
78
+ def validate_protein_example(example: Dict[str, Any]) -> bool:
79
+ """
80
+ Validate that a protein example has required fields and valid data.
81
+
82
+ Args:
83
+ example: Dataset example to validate
84
+
85
+ Returns:
86
+ bool: True if example is valid, False otherwise
87
+ """
88
+ # Check for required fields
89
+ required_fields = ["protein_sequence"]
90
+ for field in required_fields:
91
+ if field not in example or not example[field]:
92
+ return False
93
+
94
+ # Check protein sequence validity
95
+ protein_seq = example["protein_sequence"]
96
+ if not isinstance(protein_seq, str) or len(protein_seq.strip()) == 0:
97
+ return False
98
+
99
+ # Check for minimum sequence length (e.g., at least 10 amino acids)
100
+ cleaned_seq = clean_protein_sequence(protein_seq)
101
+ if len(cleaned_seq) < 10:
102
+ return False
103
+
104
+ return True
105
+
106
+
107
+ def format_protein_qa_example(example: Dict[str, Any]) -> Dict[str, Any]:
108
+ """
109
+ Format a protein example for question-answering tasks.
110
+
111
+ Args:
112
+ example: Raw protein example
113
+
114
+ Returns:
115
+ Dict[str, Any]: Formatted example
116
+ """
117
+ # Clean protein sequence
118
+ if "protein_sequence" in example:
119
+ example["protein_sequence"] = clean_protein_sequence(example["protein_sequence"])
120
+
121
+ # Ensure answer is properly formatted
122
+ if "answer" in example:
123
+ answer = example["answer"]
124
+ if isinstance(answer, str):
125
+ example["answer"] = answer.strip().lower()
126
+ else:
127
+ example["answer"] = str(answer).strip().lower()
128
+
129
+ # Format question if needed
130
+ if "question" in example:
131
+ question = example["question"]
132
+ if not question.endswith("?"):
133
+ example["question"] = question.strip() + "?"
134
+
135
+ return example
BioReason_new/bioreason/models/__pycache__/protein_llm.cpython-310.pyc ADDED
Binary file (10.1 kB). View file
 
BioReason_new/bioreason/models/__pycache__/protein_llm.cpython-311.pyc ADDED
Binary file (20.9 kB). View file
 
BioReason_new/bioreason/models/pl/__pycache__/processing_pl.cpython-310.pyc ADDED
Binary file (8.08 kB). View file
 
BioReason_new/bioreason/models/pl/__pycache__/processing_pl.cpython-311.pyc ADDED
Binary file (12.3 kB). View file
 
BioReason_new/bioreason/models/pl/processing_pl.py ADDED
@@ -0,0 +1,279 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List, Optional, Union, Dict, Any, Tuple
2
+
3
+ import torch
4
+ from torch import nn
5
+ import torch.nn.functional as F
6
+
7
+ from transformers import AutoTokenizer
8
+ from transformers.processing_utils import (
9
+ CommonKwargs,
10
+ ProcessingKwargs,
11
+ ProcessorMixin,
12
+ Unpack,
13
+ )
14
+ from transformers.feature_extraction_utils import BatchFeature
15
+ from transformers.tokenization_utils_base import PreTokenizedInput, TextInput
16
+ from transformers.utils import logging
17
+
18
+ from bioreason.utils.protein_utils import ProteinInput
19
+
20
+ class ProteinLLMKwargs(CommonKwargs):
21
+ """Keyword arguments specific to protein processing"""
22
+ max_length_text: Optional[int]
23
+ max_length_protein: Optional[int]
24
+
25
+
26
+ class ProteinLLMProcessorKwargs(ProcessingKwargs, total=False):
27
+ """Processing keyword arguments for the ProteinLLM processor"""
28
+ protein_kwargs: ProteinLLMKwargs
29
+ _defaults = {
30
+ "text_kwargs": {
31
+ "padding": False,
32
+ },
33
+ }
34
+
35
+ class ProteinLLMProcessor(ProcessorMixin):
36
+ r"""
37
+ Constructs a ProteinLLM processor which wraps a ESM2 protein processor and a Qwen tokenizer into a single processor.
38
+ This processor handles both text and protein sequence processing to prepare inputs for the ProteinLLMModel.
39
+
40
+ Args:
41
+ tokenizer (PreTrainedTokenizerBase, *optional*):
42
+ The text tokenizer used for processing text inputs.
43
+ protein_tokenizer (PreTrainedTokenizerBase, *optional*):
44
+ The protein tokenizer used for processing protein sequences.
45
+ chat_template (`str`, *optional*):
46
+ A Jinja template for chat formatting. If None, will use the tokenizer's template.
47
+ """
48
+
49
+ attributes = ["tokenizer", "protein_tokenizer"]
50
+ valid_kwargs = ["model", "chat_template"]
51
+ tokenizer_class = (
52
+ "Qwen2Tokenizer", "Qwen2TokenizerFast",
53
+ "GPT2TokenizerFast",
54
+ )
55
+ protein_tokenizer_class = ("EsmTokenizer",)
56
+
57
+ def __init__(
58
+ self, tokenizer=None, protein_tokenizer=None, chat_template=None, **kwargs
59
+ ):
60
+ """
61
+ Initialize the processor with text and protein tokenizers.
62
+
63
+ Args:
64
+ tokenizer: Text tokenizer (usually from a language model)
65
+ protein_tokenizer: Protein tokenizer (usually from ESM2)
66
+ chat_template: Template for formatting chat conversations
67
+ **kwargs: Additional arguments
68
+ """
69
+ self.tokenizer = tokenizer
70
+ self.protein_tokenizer = protein_tokenizer
71
+
72
+ self.protein_token = (
73
+ "<|protein_pad|>"
74
+ if not hasattr(self.tokenizer, "protein_token")
75
+ else self.tokenizer.protein_token
76
+ )
77
+
78
+ # Get chat template from tokenizer if not provided
79
+ if chat_template is None and hasattr(self.tokenizer, "chat_template"):
80
+ chat_template = self.tokenizer.chat_template
81
+ super().__init__(tokenizer, protein_tokenizer, chat_template=chat_template)
82
+
83
+ # The GRPO trainer might expect this to be set
84
+ if not hasattr(self.tokenizer, 'pad_token') or self.tokenizer.pad_token is None:
85
+ self.tokenizer.pad_token = self.tokenizer.eos_token
86
+
87
def tokenize_protein_sequences(
    self,
    batch_protein_sequences: List[List[str]],
    max_length: int = 1024,
    return_tensors: str = "pt",
    device: str = "cuda",
) -> Dict[str, Any]:
    """
    Tokenize every protein sequence in the batch with one tokenizer call.

    Args:
        batch_protein_sequences: Per-batch-item lists of protein sequences.
        max_length: Truncation length for protein sequences.
        return_tensors: Tensor format ("pt" for PyTorch).
        device: Device the resulting tensors are moved to.

    Returns:
        Dict with:
            - "protein_tokenized": tokenizer output, or None when the batch
              contains no sequences at all
            - "batch_idx_map": batch index of each flattened sequence
    """
    # Flatten the nested lists while remembering each sequence's batch item.
    all_sequences = [
        seq for sequences in batch_protein_sequences for seq in sequences
    ]
    batch_idx_map = [
        batch_idx
        for batch_idx, sequences in enumerate(batch_protein_sequences)
        for _ in sequences
    ]

    # Nothing to tokenize anywhere in the batch.
    if not all_sequences:
        return {"protein_tokenized": None, "batch_idx_map": []}

    protein_tokenized = self.protein_tokenizer(
        all_sequences,
        padding=True,
        truncation=True,
        max_length=max_length,
        return_tensors=return_tensors,
        return_attention_mask=True,
    )

    # Move tensor outputs onto the requested device.
    if return_tensors == "pt" and device is not None:
        protein_tokenized = {
            key: value.to(device) if isinstance(value, torch.Tensor) else value
            for key, value in protein_tokenized.items()
        }

    return {"protein_tokenized": protein_tokenized, "batch_idx_map": batch_idx_map}
138
+
139
def __call__(
    self,
    batch_protein_sequences: Optional[List[List[str]]] = None,
    text: Optional[
        Union[
            TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]
        ]
    ] = None,
    max_length_text: int = 512,
    max_length_protein: int = 1024,
    return_tensors: str = "pt",
    device: str = "cuda",
    **kwargs: Unpack[ProteinLLMProcessorKwargs],
) -> BatchFeature:
    """
    Process text and protein sequences for model input.

    Each occurrence of the protein placeholder token in `text` is expanded to
    one placeholder per (non-pad) token of the corresponding protein sequence,
    in flattened batch order, so that the language model sees exactly one text
    position per injected protein embedding.

    Args:
        batch_protein_sequences: List of lists of protein sequences per batch item
        text: Input text or list of texts
        max_length_text: Maximum length for text sequences
        max_length_protein: Maximum length for protein sequences
        return_tensors: Return format for tensors
        device: Device to place tensors on
        **kwargs: Additional processor keyword arguments

    Returns:
        BatchFeature with tokenized text inputs plus "protein_tokenized" and
        "batch_idx_map" when protein sequences were provided.
    """
    # Merge caller kwargs with processor defaults (HF ProcessorMixin helper).
    output_kwargs = self._merge_kwargs(
        ProteinLLMProcessorKwargs,
        tokenizer_init_kwargs=self.tokenizer.init_kwargs,
        **kwargs,
    )

    # Normalize a single string into a one-element batch.
    if not isinstance(text, list):
        text = [text]

    protein_inputs = {}
    if batch_protein_sequences is not None:
        # Tokenize protein sequences (flattened across the batch).
        protein_processing_result = self.tokenize_protein_sequences(
            batch_protein_sequences,
            max_length=max_length_protein,
            return_tensors=return_tensors,
            device=device,
        )

        # Expand each protein placeholder to one token per protein position.
        # `index` walks the flattened sequences in the same order they were
        # tokenized, so placeholders must appear in flattening order.
        # NOTE(review): the non-pad count includes ESM's special tokens
        # (CLS/EOS) - confirm the embedding-injection side counts the same way.
        index = 0
        for i in range(len(text)):
            while self.protein_token in text[i]:
                num_protein_tokens = (protein_processing_result['protein_tokenized']['input_ids'][index] != self.protein_tokenizer.pad_token_id).sum().item()
                # Use a temporary marker so already-expanded regions are not
                # re-matched by the while-condition.
                text[i] = text[i].replace(
                    self.protein_token, "<|placeholder|>" * num_protein_tokens, 1
                )
                index += 1
            text[i] = text[i].replace("<|placeholder|>", self.protein_token)

        # Add batch info to the output
        protein_inputs = {
            "protein_tokenized": protein_processing_result["protein_tokenized"],
            "batch_idx_map": protein_processing_result["batch_idx_map"],
        }

    # Tokenize text with the merged kwargs; padding is forced to True below,
    # so drop any conflicting merged value first (mutates output_kwargs).
    text_kwargs = output_kwargs.get("text_kwargs", {})

    if 'padding' in text_kwargs:
        del text_kwargs['padding']

    # Budget extra room for the expanded protein placeholders.
    text_inputs = self.tokenizer(
        text,
        max_length=max_length_text + 2 * max_length_protein,
        return_tensors=return_tensors,
        padding=True,
        truncation=True,
        **text_kwargs,
    )
    # Move text tensors to device if specified.
    if return_tensors == "pt" and device is not None:
        text_inputs = {k: v.to(device) if isinstance(v, torch.Tensor) else v
                       for k, v in text_inputs.items()}

    # The BatchFeature should have all required fields for the model's forward pass.
    return BatchFeature(data={**text_inputs, **protein_inputs})
226
+
227
def batch_decode(self, *args, **kwargs) -> List[str]:
    """
    Delegate batch decoding to the underlying text tokenizer.

    Returns:
        The decoded strings, one per input sequence.
    """
    return self.tokenizer.batch_decode(*args, **kwargs)
235
+
236
def decode(self, *args, **kwargs) -> str:
    """
    Delegate single-sequence decoding to the underlying text tokenizer.

    Returns:
        The decoded string.
    """
    return self.tokenizer.decode(*args, **kwargs)
244
+
245
def post_process_protein_to_text(
    self,
    generated_outputs: torch.Tensor,
    skip_special_tokens: bool = True,
    **kwargs,
) -> List[str]:
    """
    Decode generated token IDs back into text.

    Args:
        generated_outputs: Token IDs produced by the model.
        skip_special_tokens: Drop special tokens from the decoded output.
        **kwargs: Extra arguments forwarded to the tokenizer's decoder.

    Returns:
        One decoded string per generated sequence.
    """
    return self.tokenizer.batch_decode(
        generated_outputs,
        skip_special_tokens=skip_special_tokens,
        **kwargs,
    )
267
+
268
@property
def model_input_names(self) -> List[str]:
    """
    Names of the inputs the combined model consumes.

    Returns:
        The text tokenizer's input names followed by the protein-specific
        ones, de-duplicated while preserving order.
    """
    protein_names = ["protein_tokenized", "batch_idx_map"]
    combined = self.tokenizer.model_input_names + protein_names
    # dict.fromkeys removes duplicates without reordering.
    return list(dict.fromkeys(combined))
BioReason_new/bioreason/models/protein_llm.py ADDED
@@ -0,0 +1,1093 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # import torch
2
+ # import torch.nn as nn
3
+ # from typing import Optional, List, Dict, Any, Union, Tuple
4
+ # from transformers import (
5
+ # AutoTokenizer,
6
+ # AutoModelForCausalLM,
7
+ # EsmModel,
8
+ # EsmTokenizer,
9
+ # BertModel,
10
+ # BertTokenizer,
11
+ # )
12
+
13
+ # from bioreason.models.pl.processing_pl import ProteinLLMProcessor
14
+ # #from bioreason.models.dl.chat_template_dl import CHAT_TEMPLATE
15
+
16
+
17
+
18
+ import torch
19
+ import torch.nn as nn
20
+ from typing import Optional, List, Dict, Any, Union, Tuple
21
+ from transformers import (
22
+ AutoTokenizer,
23
+ AutoModelForCausalLM,
24
+ EsmModel,
25
+ EsmTokenizer,
26
+ BertModel,
27
+ BertTokenizer,
28
+ BertConfig,
29
+ )
30
+
31
+ from bioreason.models.pl.processing_pl import ProteinLLMProcessor
32
+ #from bioreason.models.dl.chat_template_dl import CHAT_TEMPLATE
33
+
34
+
35
class QFormerProjector(nn.Module):
    """
    QFormer-based projector that maps protein embeddings to text space.
    Uses cross-attention mechanism for better alignment.
    """
    # NOTE(review): this definition is immediately shadowed by the second
    # `class QFormerProjector` declared below (after the repeated import
    # block); it is dead code and a candidate for removal.
40
+
41
+ import torch
42
+ import torch.nn as nn
43
+ from typing import Optional, List, Dict, Any, Union, Tuple
44
+ from transformers import (
45
+ AutoTokenizer,
46
+ AutoModelForCausalLM,
47
+ EsmModel,
48
+ EsmTokenizer,
49
+ BertModel,
50
+ BertTokenizer,
51
+ )
52
+
53
+ from bioreason.models.pl.processing_pl import ProteinLLMProcessor
54
+ #from bioreason.models.dl.chat_template_dl import CHAT_TEMPLATE
55
+
56
+
57
class QFormerProjector(nn.Module):
    """
    QFormer-based projector that maps protein embeddings to text space.

    A set of learnable query tokens is prepended to the (linearly projected)
    protein embeddings and the concatenation is run through a pretrained BERT
    encoder; the query-token outputs are then projected into the text model's
    hidden space. Full cross-attention here is realized via BERT self-attention
    over the concatenated [queries; protein] sequence.
    """

    def __init__(
        self,
        protein_hidden_size: int,
        text_hidden_size: int,
        qformer_model_name: str = "microsoft/BiomedNLP-BiomedBERT-base-uncased-abstract-fulltext",
        num_query_tokens: int = 32,
        cross_attention_layers: int = 6,
        max_protein_length: int = 400,  # Conservative limit: 32 + 400 = 432 < 512
    ):
        """
        Args:
            protein_hidden_size: Hidden size of the protein encoder outputs.
            text_hidden_size: Hidden size of the target text model.
            qformer_model_name: HF checkpoint loaded as the BERT-based QFormer.
            num_query_tokens: Number of learnable query tokens.
            cross_attention_layers: NOTE(review) - accepted but unused; the
                full pretrained BERT stack is used as-is.
            max_protein_length: Hard cap on protein length so that
                queries + protein tokens stay under BERT's position limit.
        """
        super().__init__()

        self.protein_hidden_size = protein_hidden_size
        self.text_hidden_size = text_hidden_size
        self.num_query_tokens = num_query_tokens
        self.max_protein_length = max_protein_length  # 32 + 400 = 432 < 512 (BERT limit)

        # Load QFormer (BERT-based) - keep original config to avoid size mismatch.
        self.qformer = BertModel.from_pretrained(qformer_model_name)
        self.qformer_hidden_size = self.qformer.config.hidden_size

        # Learnable query tokens, initialized like BERT embeddings (std 0.02).
        self.query_tokens = nn.Parameter(
            torch.zeros(1, num_query_tokens, self.qformer_hidden_size)
        )
        self.query_tokens.data.normal_(mean=0.0, std=0.02)

        # Project protein features to QFormer dimension.
        self.protein_projection = nn.Linear(protein_hidden_size, self.qformer_hidden_size)

        # Final projection to text space.
        self.text_projection = nn.Linear(self.qformer_hidden_size, text_hidden_size)

        # Layer norm for stability of the projected text-space embeddings.
        self.layer_norm = nn.LayerNorm(text_hidden_size)

    def forward(
        self,
        protein_embeddings: torch.Tensor,  # [batch_size, seq_len, protein_hidden_size]
        protein_attention_mask: torch.Tensor = None,  # [batch_size, seq_len]
    ) -> torch.Tensor:
        """
        Forward pass through QFormer projector.

        Args:
            protein_embeddings: Protein embeddings from the protein encoder.
            protein_attention_mask: Attention mask for protein sequences.

        Returns:
            Projected embeddings in text space
            [batch_size, num_query_tokens, text_hidden_size].

        Raises:
            ValueError: If the sequence cannot be made to fit the QFormer's
                position limit even after truncation.
        """
        batch_size, seq_len, _ = protein_embeddings.size()

        # Truncate protein sequence if necessary (first-pass cap).
        if seq_len > self.max_protein_length:
            protein_embeddings = protein_embeddings[:, :self.max_protein_length, :]
            if protein_attention_mask is not None:
                protein_attention_mask = protein_attention_mask[:, :self.max_protein_length]
            seq_len = self.max_protein_length

        # Project protein embeddings to QFormer dimension.
        protein_embeds = self.protein_projection(protein_embeddings)  # [B, L, H_qformer]

        # Expand query tokens for batch (no copy; expand is a view).
        query_tokens = self.query_tokens.expand(batch_size, -1, -1)  # [B, num_queries, H_qformer]

        # Queries are always attended to.
        query_atts = torch.ones(query_tokens.size()[:-1], dtype=torch.long, device=query_tokens.device)

        if protein_attention_mask is None:
            protein_attention_mask = torch.ones(
                protein_embeds.size()[:-1], dtype=torch.long, device=protein_embeds.device
            )

        attention_mask = torch.cat([query_atts, protein_attention_mask], dim=1)

        # Second-pass cap: ensure queries + protein fit the model's position limit.
        total_length = attention_mask.size(1)
        max_length = self.qformer.config.max_position_embeddings

        if total_length > max_length:
            # Truncate protein sequence further if needed (drops the tail).
            excess = total_length - max_length
            if excess > 0 and protein_embeds.size(1) > excess:
                protein_embeds = protein_embeds[:, :-excess, :]
                protein_attention_mask = protein_attention_mask[:, :-excess]
                attention_mask = torch.cat([query_atts, protein_attention_mask], dim=1)
            else:
                raise ValueError(f"Cannot fit sequence into model max length {max_length}")

        # Combine embeddings: [queries; protein] along the sequence axis.
        inputs_embeds = torch.cat([query_tokens, protein_embeds], dim=1)

        # Pass through QFormer without explicit position_ids (let model auto-generate).
        outputs = self.qformer(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            return_dict=True,
        )

        # Extract query outputs (first num_query_tokens positions).
        query_output = outputs.last_hidden_state[:, :self.num_query_tokens, :]

        # Project to text space and normalize.
        text_embeds = self.text_projection(query_output)
        text_embeds = self.layer_norm(text_embeds)

        return text_embeds
170
+
171
+
172
class ProteinLLMModel(nn.Module):
    """
    A combined model that processes both protein sequences and text inputs.

    Uses ESM2 for protein encoding, a QFormer for projection into text space,
    and a causal language model (Qwen) for text generation. Protein embeddings
    are spliced into the text embedding sequence at positions of the
    `<|protein_pad|>` placeholder token.
    """

    def __init__(
        self,
        text_model_name: str,
        protein_model_name: str,
        qformer_model_name: str = "microsoft/BiomedNLP-BiomedBERT-base-uncased-abstract-fulltext",
        cache_dir: Optional[str] = None,
        max_length_protein: int = 1024,
        max_length_text: int = 512,
        text_model_finetune: bool = True,
        protein_model_finetune: bool = True,
        num_query_tokens: int = 32,
        cross_attention_layers: int = 6,
    ):
        """
        Initialize the ProteinLLMModel.

        Args:
            text_model_name: Name of the text model (Qwen)
            protein_model_name: Name of the protein model (ESM2)
            qformer_model_name: Name of the QFormer model
            cache_dir: Directory to cache the models
            max_length_protein: Maximum length of protein sequences
            max_length_text: Maximum length of text sequences
            text_model_finetune: Whether to finetune the text model
            protein_model_finetune: Whether to finetune the protein model
                (NOTE(review): currently has no effect - see
                process_protein_embeddings, which wraps the protein model in
                torch.no_grad())
            num_query_tokens: Number of learnable query tokens
            cross_attention_layers: Number of cross-attention layers in QFormer
        """
        super().__init__()

        self.text_model_finetune = text_model_finetune
        self.protein_model_finetune = protein_model_finetune
        self.max_length_protein = max_length_protein
        self.max_length_text = max_length_text
        self.num_query_tokens = num_query_tokens

        # Load the text model and tokenizer (Qwen).
        self.text_model = AutoModelForCausalLM.from_pretrained(
            text_model_name, cache_dir=cache_dir, trust_remote_code=True
        )
        self.text_tokenizer = AutoTokenizer.from_pretrained(text_model_name, trust_remote_code=True)
        self.text_config = self.text_model.config
        #self.text_tokenizer.chat_template = CHAT_TEMPLATE
        self.text_tokenizer.pad_token = self.text_tokenizer.eos_token

        # Add special tokens for protein.
        # NOTE(review): the tokenizer vocabulary grows here but
        # text_model.resize_token_embeddings is never called - confirm the
        # base checkpoint's embedding matrix already covers these token ids.
        new_tokens = ["<|protein_start|>", "<|protein_pad|>", "<|protein_end|>"]
        self.text_tokenizer.add_special_tokens({"additional_special_tokens": new_tokens})
        self.protein_token_id = self.text_tokenizer.convert_tokens_to_ids("<|protein_pad|>")

        # Load the protein model and tokenizer (ESM2).
        self.protein_model = EsmModel.from_pretrained(
            protein_model_name, cache_dir=cache_dir, trust_remote_code=True
        )
        self.protein_tokenizer = EsmTokenizer.from_pretrained(protein_model_name, trust_remote_code=True)
        self.protein_config = self.protein_model.config

        # Get model dimensions.
        self.text_hidden_size = self.text_config.hidden_size
        self.protein_hidden_size = self.protein_config.hidden_size

        # Create QFormer projector (protein space -> text space).
        self.protein_projection = QFormerProjector(
            protein_hidden_size=self.protein_hidden_size,
            text_hidden_size=self.text_hidden_size,
            qformer_model_name=qformer_model_name,
            num_query_tokens=num_query_tokens,
            cross_attention_layers=cross_attention_layers,
        )

        # Create processor for handling inputs.
        self.processor = ProteinLLMProcessor(
            tokenizer=self.text_tokenizer,
            protein_tokenizer=self.protein_tokenizer
        )

    def process_protein_embeddings(
        self,
        protein_tokenized: Dict[str, torch.Tensor],
        batch_idx_map: List[int],
        batch_size: int,
    ) -> List[torch.Tensor]:
        """
        Process protein sequences to obtain projected embeddings per batch item.

        Args:
            protein_tokenized: Tokenized protein sequences (flattened batch)
            batch_idx_map: Mapping of each sequence to its batch item
            batch_size: Number of items in the batch

        Returns:
            List of length `batch_size`; each entry is the concatenation of
            that item's projected sequences
            [num_seqs_for_item * num_query_tokens, text_hidden_size]
            (an empty [0, text_hidden_size] tensor when the item has none).
        """
        # Process all sequences to get protein representations.
        # NOTE(review): no_grad here freezes the protein encoder regardless of
        # the protein_model_finetune flag.
        with torch.no_grad():
            outputs = self.protein_model(
                input_ids=protein_tokenized["input_ids"],
                attention_mask=protein_tokenized["attention_mask"],
            )
            # Get the last hidden state.
            hidden_states = outputs.last_hidden_state  # shape: [n_seqs, seq_len, hidden_dim]

        # Match the projector's device/dtype before projecting.
        hidden_states = hidden_states.to(
            device=self.protein_projection.query_tokens.device,
            dtype=self.protein_projection.query_tokens.dtype
        )

        # Project sequences one at a time.
        # NOTE(review): this per-sequence loop could be batched through the
        # projector in a single call - confirm and optimize if it shows up in
        # profiles.
        projected_states_list = []
        for seq_idx in range(hidden_states.size(0)):
            seq_embedding = hidden_states[seq_idx:seq_idx+1]  # [1, seq_len, hidden_dim]
            seq_attention_mask = protein_tokenized["attention_mask"][seq_idx:seq_idx+1]

            projected_embedding = self.protein_projection(
                seq_embedding, seq_attention_mask
            )  # [1, num_query_tokens, text_hidden_size]
            projected_states_list.append(projected_embedding.squeeze(0))  # [num_query_tokens, text_hidden_size]

        # Group embeddings by batch item.
        result = [[] for _ in range(batch_size)]

        # For each sequence, get its embeddings and add to appropriate batch result.
        for seq_idx, batch_idx in enumerate(batch_idx_map):
            result[batch_idx].append(projected_states_list[seq_idx])

        # Concatenate embeddings for each batch item (empty tensor when none).
        for i in range(batch_size):
            if result[i]:
                result[i] = torch.cat(result[i], dim=0)
            else:
                result[i] = torch.zeros((0, self.text_hidden_size), device=self.device)

        return result

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        protein_tokenized: Optional[Dict[str, torch.Tensor]] = None,
        batch_idx_map: Optional[List[int]] = None,
        labels: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> torch.Tensor:
        """
        Forward pass: splice projected protein embeddings into the text
        embedding sequence at `<|protein_pad|>` positions, then run the LM.

        Raises:
            ValueError: If input_ids/attention_mask are missing, or if the
                number of protein embedding rows does not equal the number of
                protein placeholder tokens in input_ids.
        """
        # NOTE(review): the message mentions 'inputs', which is not a
        # parameter of this method - likely copied from another signature.
        if input_ids is None or attention_mask is None:
            raise ValueError("Either 'inputs' or 'input_ids'/'attention_mask' must be provided")

        batch_size = input_ids.shape[0]

        # Get text embeddings from the model's embedding layer.
        text_inputs_embeds = self.text_model.get_input_embeddings()(input_ids)

        if protein_tokenized is not None and batch_idx_map:
            batch_protein_embeds = self.process_protein_embeddings(
                protein_tokenized, batch_idx_map, batch_size
            )

            # Positions in the text that should receive protein embeddings.
            mask = input_ids == self.protein_token_id

            n_protein_tokens = mask.sum().item()
            protein_embeds_flat = torch.cat(batch_protein_embeds, dim=0)
            n_protein_features = protein_embeds_flat.shape[0]

            # One embedding row per placeholder token, in flattened order.
            if n_protein_features != n_protein_tokens:
                raise ValueError(
                    f"Protein features and protein tokens do not match: features {n_protein_features}, tokens: {n_protein_tokens}"
                )

            # Ensure protein embeddings have the same dtype as the text embeddings.
            protein_embeds_flat = protein_embeds_flat.to(dtype=text_inputs_embeds.dtype)
            text_inputs_embeds[mask] = protein_embeds_flat

        # Forward pass through the text model.
        outputs = self.text_model(
            inputs_embeds=text_inputs_embeds,
            attention_mask=attention_mask,
            labels=labels,
            **kwargs,
        )

        return outputs

    def generate(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        protein_tokenized: Optional[Dict[str, torch.Tensor]] = None,
        batch_idx_map: Optional[List[int]] = None,
        **generation_kwargs,
    ) -> Union[torch.Tensor, List[str]]:
        """
        Generate text based on protein and text inputs.

        Same embedding-splicing logic as forward(), followed by
        text_model.generate on the spliced inputs_embeds.
        """
        if input_ids is None or attention_mask is None:
            raise ValueError("Either 'inputs' or 'input_ids'/'attention_mask' must be provided")

        batch_size = input_ids.shape[0]

        # Get text embeddings from the model's embedding layer.
        text_inputs_embeds = self.text_model.get_input_embeddings()(input_ids)

        if protein_tokenized is not None and batch_idx_map:
            batch_protein_embeds = self.process_protein_embeddings(
                protein_tokenized, batch_idx_map, batch_size
            )

            mask = input_ids == self.protein_token_id

            n_protein_tokens = mask.sum().item()
            protein_embeds_flat = torch.cat(batch_protein_embeds, dim=0)
            n_protein_features = protein_embeds_flat.shape[0]

            if n_protein_features != n_protein_tokens:
                raise ValueError(
                    f"Protein features and protein tokens do not match: features {n_protein_features}, tokens: {n_protein_tokens}"
                )

            # Ensure protein embeddings have the same dtype as the text embeddings.
            protein_embeds_flat = protein_embeds_flat.to(dtype=text_inputs_embeds.dtype)
            text_inputs_embeds[mask] = protein_embeds_flat

        # Generation (inference only).
        with torch.no_grad():
            outputs = self.text_model.generate(
                inputs_embeds=text_inputs_embeds,
                attention_mask=attention_mask,
                use_cache=True,
                **generation_kwargs,
            )

        return outputs

    @property
    def device(self):
        """Get the device of the model (device of its first parameter)."""
        return next(self.parameters()).device

    def to_device(self, tensor_dict):
        """Move every tensor value of a dict to the model's device."""
        return {k: v.to(self.device) if isinstance(v, torch.Tensor) else v
                for k, v in tensor_dict.items()}

    def get_protein_embeddings(self, protein_sequences: List[str]) -> torch.Tensor:
        """
        Get raw protein embeddings before projection.

        Args:
            protein_sequences: List of protein sequences

        Returns:
            Raw protein embeddings [batch_size, seq_len, protein_hidden_size]
        """
        # Tokenize protein sequences.
        protein_inputs = self.protein_tokenizer(
            protein_sequences,
            padding=True,
            truncation=True,
            max_length=self.max_length_protein,
            return_tensors="pt",
        )

        # Move to correct device.
        protein_inputs = self.to_device(protein_inputs)

        # Get protein embeddings (inference only; encoder frozen here).
        with torch.no_grad():
            protein_outputs = self.protein_model(**protein_inputs)
            protein_embeddings = protein_outputs.last_hidden_state

        return protein_embeddings

    def get_protein_features(
        self,
        protein_sequences: List[str],
        return_tensors: str = "pt",
    ) -> torch.Tensor:
        """
        Extract protein features for contrastive learning.

        Args:
            protein_sequences: List of protein sequences
            return_tensors: Return format for tensors

        Returns:
            Protein features [batch_size, num_query_tokens, text_hidden_size]
        """
        # Tokenize protein sequences.
        protein_inputs = self.protein_tokenizer(
            protein_sequences,
            padding=True,
            truncation=True,
            max_length=self.max_length_protein,
            return_tensors=return_tensors,
        )

        protein_inputs = self.to_device(protein_inputs)

        # Get protein embeddings (encoder frozen; projector still gets grads).
        with torch.no_grad():
            protein_outputs = self.protein_model(**protein_inputs)
            protein_embeddings = protein_outputs.last_hidden_state

        # Project through QFormer - Fixed: only pass two required arguments.
        protein_features = self.protein_projection(
            protein_embeddings, protein_inputs["attention_mask"]
        )

        # Global average pooling over query tokens (disabled; per-query
        # features are returned as-is).
        # protein_features = protein_features.mean(dim=1)  # [batch_size, text_hidden_size]

        return protein_features

    def get_text_features(
        self,
        text_sequences: List[str],
        return_tensors: str = "pt",
    ) -> torch.Tensor:
        """
        Extract text features for contrastive learning.

        Uses the LM's input embedding table (not contextual hidden states)
        with mask-weighted mean pooling.

        Args:
            text_sequences: List of text descriptions
            return_tensors: Return format for tensors

        Returns:
            Text features [batch_size, text_hidden_size]
        """
        # Tokenize text sequences.
        text_inputs = self.text_tokenizer(
            text_sequences,
            padding=True,
            truncation=True,
            max_length=self.max_length_text,
            return_tensors=return_tensors,
        )

        text_inputs = self.to_device(text_inputs)

        # Get text embeddings from the embedding layer.
        with torch.no_grad():
            text_embeddings = self.text_model.get_input_embeddings()(text_inputs["input_ids"])

        # Apply attention mask and average pooling over non-pad positions.
        attention_mask = text_inputs["attention_mask"].unsqueeze(-1)
        masked_embeddings = text_embeddings * attention_mask
        text_features = masked_embeddings.sum(dim=1) / attention_mask.sum(dim=1)

        return text_features
529
+
530
+ # class QFormerProjector(nn.Module):
531
+ # """
532
+ # QFormer-based projector that maps protein embeddings to text space.
533
+ # Uses cross-attention mechanism for better alignment.
534
+ # """
535
+
536
+ # # def __init__(
537
+ # # self,
538
+ # # protein_hidden_size: int,
539
+ # # text_hidden_size: int,
540
+ # # qformer_model_name: str = "microsoft/BiomedNLP-BiomedBERT-base-uncased-abstract-fulltext",
541
+ # # num_query_tokens: int = 32,
542
+ # # cross_attention_layers: int = 6,
543
+ # # ):
544
+ # # super().__init__()
545
+
546
+ # # self.protein_hidden_size = protein_hidden_size
547
+ # # self.text_hidden_size = text_hidden_size
548
+ # # self.num_query_tokens = num_query_tokens
549
+
550
+ # # # Load QFormer (BERT-based)
551
+ # # self.qformer = BertModel.from_pretrained(qformer_model_name)
552
+ # # self.qformer_hidden_size = self.qformer.config.hidden_size
553
+
554
+ # # # Learnable query tokens
555
+ # # self.query_tokens = nn.Parameter(
556
+ # # torch.zeros(1, num_query_tokens, self.qformer_hidden_size)
557
+ # # )
558
+ # # self.query_tokens.data.normal_(mean=0.0, std=0.02)
559
+
560
+ # # # Project protein features to QFormer dimension
561
+ # # self.protein_projection = nn.Linear(protein_hidden_size, self.qformer_hidden_size)
562
+
563
+ # # # Final projection to text space
564
+ # # self.text_projection = nn.Linear(self.qformer_hidden_size, text_hidden_size)
565
+
566
+ # # # Layer norm for stability
567
+ # # self.layer_norm = nn.LayerNorm(text_hidden_size)
568
+
569
+ # # def forward(
570
+ # # self,
571
+ # # protein_embeddings: torch.Tensor, # [batch_size, seq_len, protein_hidden_size]
572
+ # # protein_attention_mask: torch.Tensor = None, # [batch_size, seq_len]
573
+ # # ) -> torch.Tensor:
574
+ # # """
575
+ # # Forward pass through QFormer projector.
576
+
577
+ # # Args:
578
+ # # protein_embeddings: Protein embeddings from ESM2
579
+ # # protein_attention_mask: Attention mask for protein sequences
580
+
581
+ # # Returns:
582
+ # # Projected embeddings in text space [batch_size, num_query_tokens, text_hidden_size]
583
+ # # """
584
+ # # batch_size = protein_embeddings.size(0)
585
+
586
+ # # # Project protein embeddings to QFormer dimension
587
+ # # protein_embeds = self.protein_projection(protein_embeddings) # [B, L, H_qformer]
588
+
589
+ # # # Expand query tokens for batch
590
+ # # query_tokens = self.query_tokens.expand(batch_size, -1, -1) # [B, num_queries, H_qformer]
591
+
592
+ # # # Concatenate query tokens and protein embeddings
593
+ # # query_atts = torch.ones(query_tokens.size()[:-1], dtype=torch.long, device=query_tokens.device)
594
+
595
+ # # if protein_attention_mask is None:
596
+ # # protein_attention_mask = torch.ones(
597
+ # # protein_embeds.size()[:-1], dtype=torch.long, device=protein_embeds.device
598
+ # # )
599
+
600
+ # # attention_mask = torch.cat([query_atts, protein_attention_mask], dim=1)
601
+
602
+ # # # Create position ids
603
+ # # position_ids = torch.arange(
604
+ # # attention_mask.size(1), dtype=torch.long, device=attention_mask.device
605
+ # # ).unsqueeze(0).expand(batch_size, -1)
606
+
607
+ # # # Combine embeddings
608
+ # # inputs_embeds = torch.cat([query_tokens, protein_embeds], dim=1)
609
+
610
+ # # # Pass through QFormer
611
+ # # outputs = self.qformer(
612
+ # # inputs_embeds=inputs_embeds,
613
+ # # attention_mask=attention_mask,
614
+ # # position_ids=position_ids,
615
+ # # return_dict=True,
616
+ # # )
617
+
618
+ # # # Extract query outputs (first num_query_tokens)
619
+ # # query_output = outputs.last_hidden_state[:, :self.num_query_tokens, :]
620
+
621
+ # # # Project to text space
622
+ # # text_embeds = self.text_projection(query_output)
623
+ # # text_embeds = self.layer_norm(text_embeds)
624
+
625
+ # # return text_embeds
626
+ # def __init__(
627
+ # self,
628
+ # protein_hidden_size: int,
629
+ # text_hidden_size: int,
630
+ # qformer_model_name: str = "microsoft/BiomedNLP-BiomedBERT-base-uncased-abstract-fulltext",
631
+ # num_query_tokens: int = 32,
632
+ # cross_attention_layers: int = 6,
633
+ # max_protein_length: int = 480, # new: limit protein sequence length
634
+ # ):
635
+ # super().__init__()
636
+
637
+ # self.protein_hidden_size = protein_hidden_size
638
+ # self.text_hidden_size = text_hidden_size
639
+ # self.num_query_tokens = num_query_tokens
640
+ # self.max_protein_length = max_protein_length # 32 + 480 = 512
641
+
642
+ # # Load QFormer (BERT-based) with longer sequence support
643
+ # from transformers import BertConfig
644
+ # config = BertConfig.from_pretrained(qformer_model_name)
645
+
646
+ # # Option 1: extend the model's max position embeddings (if the base model supports it)
647
+ # config.max_position_embeddings = max(1024, num_query_tokens + max_protein_length)
648
+
649
+ # self.qformer = BertModel.from_pretrained(qformer_model_name, config=config)
650
+ # self.qformer_hidden_size = self.qformer.config.hidden_size
651
+
652
+ # # Learnable query tokens
653
+ # self.query_tokens = nn.Parameter(
654
+ # torch.zeros(1, num_query_tokens, self.qformer_hidden_size)
655
+ # )
656
+ # self.query_tokens.data.normal_(mean=0.0, std=0.02)
657
+
658
+ # # Project protein features to QFormer dimension
659
+ # self.protein_projection = nn.Linear(protein_hidden_size, self.qformer_hidden_size)
660
+
661
+ # # Final projection to text space
662
+ # self.text_projection = nn.Linear(self.qformer_hidden_size, text_hidden_size)
663
+
664
+ # # Layer norm for stability
665
+ # self.layer_norm = nn.LayerNorm(text_hidden_size)
666
+
667
+ # def forward(
668
+ # self,
669
+ # protein_embeddings: torch.Tensor, # [batch_size, seq_len, protein_hidden_size]
670
+ # protein_attention_mask: torch.Tensor = None, # [batch_size, seq_len]
671
+ # ) -> torch.Tensor:
672
+ # """
673
+ # Forward pass through QFormer projector.
674
+
675
+ # Args:
676
+ # protein_embeddings: Protein embeddings from ESM2
677
+ # protein_attention_mask: Attention mask for protein sequences
678
+
679
+ # Returns:
680
+ # Projected embeddings in text space [batch_size, num_query_tokens, text_hidden_size]
681
+ # """
682
+ # batch_size, seq_len, _ = protein_embeddings.size()
683
+
684
+ # # 方案2:截断蛋白质序列
685
+ # if seq_len > self.max_protein_length:
686
+ # protein_embeddings = protein_embeddings[:, :self.max_protein_length, :]
687
+ # if protein_attention_mask is not None:
688
+ # protein_attention_mask = protein_attention_mask[:, :self.max_protein_length]
689
+ # seq_len = self.max_protein_length
690
+
691
+ # # Project protein embeddings to QFormer dimension
692
+ # protein_embeds = self.protein_projection(protein_embeddings) # [B, L, H_qformer]
693
+
694
+ # # Expand query tokens for batch
695
+ # query_tokens = self.query_tokens.expand(batch_size, -1, -1) # [B, num_queries, H_qformer]
696
+
697
+ # # Create attention masks
698
+ # query_atts = torch.ones(query_tokens.size()[:-1], dtype=torch.long, device=query_tokens.device)
699
+
700
+ # if protein_attention_mask is None:
701
+ # protein_attention_mask = torch.ones(
702
+ # protein_embeds.size()[:-1], dtype=torch.long, device=protein_embeds.device
703
+ # )
704
+
705
+ # attention_mask = torch.cat([query_atts, protein_attention_mask], dim=1)
706
+
707
+ # # 确保总长度不超过模型限制
708
+ # total_length = attention_mask.size(1)
709
+ # max_length = self.qformer.config.max_position_embeddings
710
+
711
+ # if total_length > max_length:
712
+ # raise ValueError(f"Total sequence length {total_length} exceeds model max length {max_length}")
713
+
714
+ # # Combine embeddings
715
+ # inputs_embeds = torch.cat([query_tokens, protein_embeds], dim=1)
716
+
717
+ # # 方案3:不使用position_ids,让模型自动生成
718
+ # # Pass through QFormer without explicit position_ids
719
+ # outputs = self.qformer(
720
+ # inputs_embeds=inputs_embeds,
721
+ # attention_mask=attention_mask,
722
+ # return_dict=True,
723
+ # )
724
+
725
+ # # Extract query outputs (first num_query_tokens)
726
+ # query_output = outputs.last_hidden_state[:, :self.num_query_tokens, :]
727
+
728
+ # # Project to text space
729
+ # text_embeds = self.text_projection(query_output)
730
+ # text_embeds = self.layer_norm(text_embeds)
731
+
732
+ # return text_embeds
733
+
734
+
735
+ # class ProteinLLMModel(nn.Module):
736
+ # """
737
+ # A combined model that processes both protein sequences and text inputs.
738
+ # Uses ESM2 for protein encoding, QFormer for projection, and Qwen for text generation.
739
+ # """
740
+
741
+ # def __init__(
742
+ # self,
743
+ # text_model_name: str,
744
+ # protein_model_name: str,
745
+ # qformer_model_name: str = "microsoft/BiomedNLP-BiomedBERT-base-uncased-abstract-fulltext",
746
+ # cache_dir: Optional[str] = None,
747
+ # max_length_protein: int = 1024,
748
+ # max_length_text: int = 512,
749
+ # text_model_finetune: bool = True,
750
+ # protein_model_finetune: bool = True,
751
+ # num_query_tokens: int = 32,
752
+ # cross_attention_layers: int = 6,
753
+ # ):
754
+ # """
755
+ # Initialize the ProteinLLMModel.
756
+
757
+ # Args:
758
+ # text_model_name: Name of the text model (Qwen)
759
+ # protein_model_name: Name of the protein model (ESM2)
760
+ # qformer_model_name: Name of the QFormer model
761
+ # cache_dir: Directory to cache the models
762
+ # max_length_protein: Maximum length of protein sequences
763
+ # max_length_text: Maximum length of text sequences
764
+ # text_model_finetune: Whether to finetune the text model
765
+ # protein_model_finetune: Whether to finetune the protein model
766
+ # num_query_tokens: Number of learnable query tokens
767
+ # cross_attention_layers: Number of cross-attention layers in QFormer
768
+ # """
769
+ # super().__init__()
770
+
771
+ # self.text_model_finetune = text_model_finetune
772
+ # self.protein_model_finetune = protein_model_finetune
773
+ # self.max_length_protein = max_length_protein
774
+ # self.max_length_text = max_length_text
775
+ # self.num_query_tokens = num_query_tokens
776
+
777
+ # # Load the text model and tokenizer (Qwen)
778
+ # self.text_model = AutoModelForCausalLM.from_pretrained(
779
+ # text_model_name, cache_dir=cache_dir, trust_remote_code=True
780
+ # )
781
+ # self.text_tokenizer = AutoTokenizer.from_pretrained(text_model_name, trust_remote_code=True)
782
+ # self.text_config = self.text_model.config
783
+ # #self.text_tokenizer.chat_template = CHAT_TEMPLATE
784
+ # self.text_tokenizer.pad_token = self.text_tokenizer.eos_token
785
+
786
+ # # Add special tokens for protein
787
+ # new_tokens = ["<|protein_start|>", "<|protein_pad|>", "<|protein_end|>"]
788
+ # self.text_tokenizer.add_special_tokens({"additional_special_tokens": new_tokens})
789
+ # self.protein_token_id = self.text_tokenizer.convert_tokens_to_ids("<|protein_pad|>")
790
+
791
+ # # Load the protein model and tokenizer (ESM2)
792
+ # self.protein_model = EsmModel.from_pretrained(
793
+ # protein_model_name, cache_dir=cache_dir, trust_remote_code=True
794
+ # )
795
+ # self.protein_tokenizer = EsmTokenizer.from_pretrained(protein_model_name, trust_remote_code=True)
796
+ # self.protein_config = self.protein_model.config
797
+
798
+ # # Get model dimensions
799
+ # self.text_hidden_size = self.text_config.hidden_size
800
+ # self.protein_hidden_size = self.protein_config.hidden_size
801
+
802
+ # # Create QFormer projector
803
+ # self.protein_projection = QFormerProjector(
804
+ # protein_hidden_size=self.protein_hidden_size,
805
+ # text_hidden_size=self.text_hidden_size,
806
+ # qformer_model_name=qformer_model_name,
807
+ # num_query_tokens=num_query_tokens,
808
+ # cross_attention_layers=cross_attention_layers,
809
+ # )
810
+
811
+ # # Create processor for handling inputs
812
+ # self.processor = ProteinLLMProcessor(
813
+ # tokenizer=self.text_tokenizer,
814
+ # protein_tokenizer=self.protein_tokenizer
815
+ # )
816
+
817
+ # def process_protein_embeddings(
818
+ # self,
819
+ # protein_tokenized: Dict[str, torch.Tensor],
820
+ # batch_idx_map: List[int],
821
+ # batch_size: int,
822
+ # ) -> List[torch.Tensor]:
823
+ # """
824
+ # Process protein sequences to obtain embeddings.
825
+
826
+ # Args:
827
+ # protein_tokenized: Tokenized protein sequences
828
+ # batch_idx_map: Mapping of each sequence to its batch item
829
+ # batch_size: Number of items in the batch
830
+
831
+ # Returns:
832
+ # List of tensor embeddings for each batch item
833
+ # """
834
+ # # Process all sequences to get protein representations
835
+ # with torch.no_grad():
836
+ # outputs = self.protein_model(
837
+ # input_ids=protein_tokenized["input_ids"],
838
+ # attention_mask=protein_tokenized["attention_mask"],
839
+ # )
840
+ # # Get the last hidden state
841
+ # hidden_states = outputs.last_hidden_state # shape: [n_seqs, seq_len, hidden_dim]
842
+
843
+ # # Apply QFormer projection
844
+ # hidden_states = hidden_states.to(
845
+ # device=self.protein_projection.query_tokens.device,
846
+ # dtype=self.protein_projection.query_tokens.dtype
847
+ # )
848
+
849
+ # # Project all embeddings at once
850
+ # projected_states_list = []
851
+ # for seq_idx in range(hidden_states.size(0)):
852
+ # seq_embedding = hidden_states[seq_idx:seq_idx+1] # [1, seq_len, hidden_dim]
853
+ # seq_attention_mask = protein_tokenized["attention_mask"][seq_idx:seq_idx+1]
854
+
855
+ # projected_embedding = self.protein_projection(
856
+ # seq_embedding, seq_attention_mask
857
+ # ) # [1, num_query_tokens, text_hidden_size]
858
+ # projected_states_list.append(projected_embedding.squeeze(0)) # [num_query_tokens, text_hidden_size]
859
+
860
+ # # Group embeddings by batch item
861
+ # result = [[] for _ in range(batch_size)]
862
+
863
+ # # For each sequence, get its embeddings and add to appropriate batch result
864
+ # for seq_idx, batch_idx in enumerate(batch_idx_map):
865
+ # result[batch_idx].append(projected_states_list[seq_idx])
866
+
867
+ # # Concatenate embeddings for each batch item
868
+ # for i in range(batch_size):
869
+ # if result[i]:
870
+ # result[i] = torch.cat(result[i], dim=0)
871
+ # else:
872
+ # result[i] = torch.zeros((0, self.text_hidden_size))
873
+
874
+ # return result
875
+
876
+ # def forward(
877
+ # self,
878
+ # input_ids: Optional[torch.Tensor] = None,
879
+ # attention_mask: Optional[torch.Tensor] = None,
880
+ # protein_tokenized: Optional[Dict[str, torch.Tensor]] = None,
881
+ # batch_idx_map: Optional[List[int]] = None,
882
+ # labels: Optional[torch.Tensor] = None,
883
+ # **kwargs,
884
+ # ) -> torch.Tensor:
885
+ # """
886
+ # Forward pass through the model.
887
+ # """
888
+ # if input_ids is None or attention_mask is None:
889
+ # raise ValueError("Either 'inputs' or 'input_ids'/'attention_mask' must be provided")
890
+
891
+ # batch_size = input_ids.shape[0]
892
+
893
+ # # Get text embeddings from the model's embedding layer
894
+ # text_inputs_embeds = self.text_model.get_input_embeddings()(input_ids)
895
+
896
+ # if protein_tokenized is not None and batch_idx_map:
897
+ # batch_protein_embeds = self.process_protein_embeddings(
898
+ # protein_tokenized, batch_idx_map, batch_size
899
+ # )
900
+
901
+ # mask = input_ids == self.protein_token_id
902
+
903
+ # n_protein_tokens = mask.sum().item()
904
+ # protein_embeds_flat = torch.cat(batch_protein_embeds, dim=0)
905
+ # n_protein_features = protein_embeds_flat.shape[0]
906
+
907
+ # if n_protein_features != n_protein_tokens:
908
+ # raise ValueError(
909
+ # f"Protein features and protein tokens do not match: features {n_protein_features}, tokens: {n_protein_tokens}"
910
+ # )
911
+
912
+ # # Ensure protein embeddings have the same dtype as the text embeddings
913
+ # protein_embeds_flat = protein_embeds_flat.to(dtype=text_inputs_embeds.dtype)
914
+ # text_inputs_embeds[mask] = protein_embeds_flat
915
+
916
+ # # Forward pass through the text model
917
+ # outputs = self.text_model(
918
+ # inputs_embeds=text_inputs_embeds,
919
+ # attention_mask=attention_mask,
920
+ # labels=labels,
921
+ # **kwargs,
922
+ # )
923
+
924
+ # return outputs
925
+
926
+ # def generate(
927
+ # self,
928
+ # input_ids: Optional[torch.Tensor] = None,
929
+ # attention_mask: Optional[torch.Tensor] = None,
930
+ # protein_tokenized: Optional[Dict[str, torch.Tensor]] = None,
931
+ # batch_idx_map: Optional[List[int]] = None,
932
+ # **generation_kwargs,
933
+ # ) -> Union[torch.Tensor, List[str]]:
934
+ # """
935
+ # Generate text based on protein and text inputs.
936
+ # """
937
+ # if input_ids is None or attention_mask is None:
938
+ # raise ValueError("Either 'inputs' or 'input_ids'/'attention_mask' must be provided")
939
+
940
+ # batch_size = input_ids.shape[0]
941
+
942
+ # # Get text embeddings from the model's embedding layer
943
+ # text_inputs_embeds = self.text_model.get_input_embeddings()(input_ids)
944
+
945
+ # if protein_tokenized is not None and batch_idx_map:
946
+ # batch_protein_embeds = self.process_protein_embeddings(
947
+ # protein_tokenized, batch_idx_map, batch_size
948
+ # )
949
+
950
+ # mask = input_ids == self.protein_token_id
951
+
952
+ # n_protein_tokens = mask.sum().item()
953
+ # protein_embeds_flat = torch.cat(batch_protein_embeds, dim=0)
954
+ # n_protein_features = protein_embeds_flat.shape[0]
955
+
956
+ # if n_protein_features != n_protein_tokens:
957
+ # raise ValueError(
958
+ # f"Protein features and protein tokens do not match: features {n_protein_features}, tokens: {n_protein_tokens}"
959
+ # )
960
+
961
+ # # Ensure protein embeddings have the same dtype as the text embeddings
962
+ # protein_embeds_flat = protein_embeds_flat.to(dtype=text_inputs_embeds.dtype)
963
+ # text_inputs_embeds[mask] = protein_embeds_flat
964
+
965
+ # # Generation
966
+ # with torch.no_grad():
967
+ # outputs = self.text_model.generate(
968
+ # inputs_embeds=text_inputs_embeds,
969
+ # attention_mask=attention_mask,
970
+ # use_cache=True,
971
+ # **generation_kwargs,
972
+ # )
973
+
974
+ # return outputs
975
+
976
+
977
+ # @property
978
+ # def device(self):
979
+ # """Get the device of the model."""
980
+ # return next(self.parameters()).device
981
+
982
+ # def to_device(self, tensor_dict):
983
+ # """Move tensor dictionary to model device."""
984
+ # return {k: v.to(self.device) if isinstance(v, torch.Tensor) else v
985
+ # for k, v in tensor_dict.items()}
986
+
987
+
988
+ # def get_protein_embeddings(self, protein_sequences: List[str]) -> torch.Tensor:
989
+ # """
990
+ # Get raw protein embeddings before projection.
991
+
992
+ # Args:
993
+ # protein_sequences: List of protein sequences
994
+
995
+ # Returns:
996
+ # Raw protein embeddings [batch_size, seq_len, protein_hidden_size]
997
+ # """
998
+ # # Tokenize protein sequences
999
+ # protein_inputs = self.protein_tokenizer(
1000
+ # protein_sequences,
1001
+ # padding=True,
1002
+ # truncation=True,
1003
+ # max_length=self.max_length_protein,
1004
+ # return_tensors="pt",
1005
+ # )
1006
+
1007
+ # # 移动到正确设备
1008
+ # protein_inputs = self.to_device(protein_inputs)
1009
+
1010
+ # # Get protein embeddings
1011
+ # with torch.no_grad():
1012
+ # protein_outputs = self.protein_model(**protein_inputs)
1013
+ # protein_embeddings = protein_outputs.last_hidden_state
1014
+
1015
+ # return protein_embeddings
1016
+
1017
+ # def get_protein_features(
1018
+ # self,
1019
+ # protein_sequences: List[str],
1020
+ # return_tensors: str = "pt",
1021
+ # ) -> torch.Tensor:
1022
+ # """
1023
+ # Extract protein features for contrastive learning.
1024
+
1025
+ # Args:
1026
+ # protein_sequences: List of protein sequences
1027
+ # return_tensors: Return format for tensors
1028
+
1029
+ # Returns:
1030
+ # Protein features [batch_size, num_query_tokens, text_hidden_size]
1031
+ # """
1032
+ # # Tokenize protein sequences
1033
+ # protein_inputs = self.protein_tokenizer(
1034
+ # protein_sequences,
1035
+ # padding=True,
1036
+ # truncation=True,
1037
+ # max_length=self.max_length_protein,
1038
+ # return_tensors=return_tensors,
1039
+ # )
1040
+
1041
+ # protein_inputs = self.to_device(protein_inputs)
1042
+
1043
+ # # Get protein embeddings
1044
+ # with torch.no_grad():
1045
+ # protein_outputs = self.protein_model(**protein_inputs)
1046
+ # protein_embeddings = protein_outputs.last_hidden_state
1047
+
1048
+ # # Project through QFormer
1049
+ # protein_features = self.protein_projection(
1050
+ # protein_embeddings, protein_inputs["attention_mask"],token_type_ids=None
1051
+ # )
1052
+
1053
+ # # Global average pooling over query tokens
1054
+ # protein_features = protein_features.mean(dim=1) # [batch_size, text_hidden_size]
1055
+
1056
+ # return protein_features
1057
+
1058
+ # def get_text_features(
1059
+ # self,
1060
+ # text_sequences: List[str],
1061
+ # return_tensors: str = "pt",
1062
+ # ) -> torch.Tensor:
1063
+ # """
1064
+ # Extract text features for contrastive learning.
1065
+
1066
+ # Args:
1067
+ # text_sequences: List of text descriptions
1068
+ # return_tensors: Return format for tensors
1069
+
1070
+ # Returns:
1071
+ # Text features [batch_size, text_hidden_size]
1072
+ # """
1073
+ # # Tokenize text sequences
1074
+ # text_inputs = self.text_tokenizer(
1075
+ # text_sequences,
1076
+ # padding=True,
1077
+ # truncation=True,
1078
+ # max_length=self.max_length_text,
1079
+ # return_tensors=return_tensors,
1080
+ # )
1081
+
1082
+ # text_inputs = self.to_device(text_inputs)
1083
+
1084
+ # # Get text embeddings from the embedding layer
1085
+ # with torch.no_grad():
1086
+ # text_embeddings = self.text_model.get_input_embeddings()(text_inputs["input_ids"])
1087
+
1088
+ # # Apply attention mask and average pooling
1089
+ # attention_mask = text_inputs["attention_mask"].unsqueeze(-1)
1090
+ # masked_embeddings = text_embeddings * attention_mask
1091
+ # text_features = masked_embeddings.sum(dim=1) / attention_mask.sum(dim=1)
1092
+
1093
+ # return text_features
BioReason_new/bioreason/protein_modules/_init_.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ from .protein_base_module import ProteinBaseModule
2
+ from .protein_module import ESM2ProteinModule
3
+
4
+ __all__ = [
5
+ "ProteinBaseModule",
6
+ "ESM2ProteinModule",
7
+ ]
BioReason_new/bioreason/protein_modules/protein_base_module.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from abc import ABC, abstractmethod
2
+ from typing import Dict, Any, Union
3
+ import torch
4
+
5
class ProteinBaseModule(ABC):
    """
    Abstract interface adapting a concrete protein-LLM implementation to the
    shared training infrastructure.

    Subclasses declare which model/processor classes to use and how to turn raw
    examples (text prompts + protein sequences) into model inputs.
    """

    def __init__(self):
        super().__init__()

    @abstractmethod
    def get_proteinllm_key(self):
        """Return a short string key identifying this protein-LLM implementation."""
        pass

    @abstractmethod
    def get_model_class(self, model_id: str, model_init_kwargs: dict):
        """Return the model class to instantiate for the given model identifier."""
        pass

    def post_model_init(self, model, processing_class):
        # Optional hook run after the model is constructed; no-op by default.
        pass

    def is_embeds_input(self):
        # Whether the model consumes input embeddings rather than token ids.
        # Default is token-id input; subclasses override as needed.
        return False

    @abstractmethod
    def get_processing_class(self):
        """Return the processor class used to build model inputs."""
        pass

    @abstractmethod
    def get_proteinllm_modules_keywords(self):
        """Return keywords identifying protein-specific submodules (e.g. to skip in LoRA)."""
        pass

    @abstractmethod
    def get_custom_multimodal_keywords(self):
        """Return input keys (beyond text) that must be forwarded to the model."""
        pass

    @abstractmethod
    def get_non_generate_params(self):
        """Return parameter names to strip before calling `generate`."""
        pass

    @abstractmethod
    def get_custom_processing_keywords(self):
        """Return (component, parameter) pairs for processor-specific settings."""
        pass

    @abstractmethod
    def prepare_prompt(self, processing_class, inputs: dict[str, Union[torch.Tensor, Any]]):
        """Turn raw input examples into prompt strings for the tokenizer."""
        pass

    @abstractmethod
    def prepare_model_inputs(self, processing_class, prompts_text, proteins, return_tensors, padding, padding_side, add_special_tokens):
        # NOTE(review): the concrete ESM2ProteinModule implementation adds a
        # `model` argument and renames `proteins` — its signature diverges from
        # this abstract one; confirm which contract callers rely on.
        """Build the final batched model inputs from prompts and protein sequences."""
        pass
BioReason_new/bioreason/protein_modules/protein_module.py ADDED
@@ -0,0 +1,257 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
from typing import Dict, Any, Union, List, Optional, Callable, Type
from trl.data_utils import maybe_apply_chat_template
import torch

# Bug fix: ProteinBaseModule is defined in protein_base_module.py. The previous
# import pulled it from bioreason.protein_modules.protein_module — i.e. this
# very module — a circular self-import that raises ImportError at load time.
from bioreason.protein_modules.protein_base_module import ProteinBaseModule
from bioreason.models.protein_llm import ProteinLLMModel
# Bug fix: ProteinLLMProcessor lives in the protein-LLM processing module
# (models/pl/processing_pl.py), not in the DNA-LLM path models/dl/processing_dl.
from bioreason.models.pl.processing_pl import ProteinLLMProcessor
10
class ESM2ProteinModule(ProteinBaseModule):
    """
    Protein module implementation for ESM2-based models with QFormer projection.

    This module provides the interface between Protein-LLM models and the training
    infrastructure, handling model loading, processing setup, and reward functions.
    """

    def __init__(self):
        """Initialize the ESM2ProteinModule."""
        super().__init__()

    def get_proteinllm_key(self) -> str:
        """
        Get the key identifier for this Protein-LLM implementation.

        Returns:
            String identifier for this module type
        """
        # NOTE(review): returns "qwen" (the text-backbone family), not "esm2" —
        # confirm this matches the key the trainer registry looks up.
        return "qwen"

    def get_model_class(self, model_id: str, model_init_kwargs: Dict[str, Any]) -> Type:
        """
        Return the appropriate model class based on model ID.

        Args:
            model_id: Identifier for the model
            model_init_kwargs: Initialization arguments for the model
                (unused here; kept for interface compatibility)

        Returns:
            The model class to instantiate

        Raises:
            ValueError: If the model is not supported
        """
        # Only identifiers containing "ProteinLLM" are supported.
        if "ProteinLLM" in model_id:
            model_cls = ProteinLLMModel
        else:
            raise ValueError(f"Unsupported model: {model_id}")
        return model_cls

    def post_model_init(self, model: Any, processing_class: Any) -> None:
        """
        Perform any post-initialization setup on the model.

        Args:
            model: The initialized model
            processing_class: The processor for the model
        """
        # No post-init needed for this implementation
        pass

    def get_processing_class(self) -> Type:
        """
        Get the processing class to use with this Protein-LLM model.

        Returns:
            The processing class
        """
        return ProteinLLMProcessor

    def get_proteinllm_modules_keywords(self) -> List[str]:
        """
        Get keywords to identify protein-specific modules in the model.

        Used to exclude protein modules from LoRA adaptation during training.

        Returns:
            List of keywords that identify protein modules
        """
        return ["protein", "qformer", "projection"]

    def get_custom_multimodal_keywords(self) -> List[str]:
        """
        Get keywords for multimodal inputs that should be passed to the model.

        Returns:
            List of input keywords for multimodal processing
        """
        return ["protein_tokenized", "batch_idx_map"]

    def get_non_generate_params(self) -> List[str]:
        """
        Get parameter names that should be excluded from generation.

        Returns:
            List of parameter names to exclude from generation calls
        """
        return []

    def get_custom_processing_keywords(self) -> List[tuple]:
        """
        Get custom processing keywords for the processor.

        Returns:
            List of (component, parameter) tuples for custom processing
        """
        return [("protein_tokenizer", "max_length")]

    def prepare_prompt(
        self, processing_class: Any, inputs: List[Dict[str, Union[torch.Tensor, Any]]]
    ) -> List[str]:
        """
        Prepare prompts from input examples.

        Args:
            processing_class: The processor to use
            inputs: List of input examples

        Returns:
            List of prepared prompts
        """
        # Applies the chat template when the example is conversational;
        # plain prompts pass through unchanged.
        prompts_text = [
            maybe_apply_chat_template(example, processing_class)["prompt"]
            for example in inputs
        ]
        return prompts_text

    def prepare_model_inputs(
        self,
        processing_class: Any,
        model: Any,
        prompts_text: List[str],
        batch_protein_sequences: List[List[str]],
        return_tensors: str = "pt",
        padding: bool = True,
        padding_side: str = "left",
        add_special_tokens: bool = False,
    ) -> Dict[str, Any]:
        """
        Prepare inputs for the model.

        NOTE(review): adds a `model` parameter not present in the abstract
        ProteinBaseModule.prepare_model_inputs — confirm callers use this
        concrete signature.

        Args:
            processing_class: The processor to use
            model: The model to prepare inputs for (source of max-length limits)
            prompts_text: List of text prompts
            batch_protein_sequences: List of lists of protein sequences
            return_tensors: Return format for tensors
            padding: Whether to pad inputs
            padding_side: Side to pad on
            add_special_tokens: Whether to add special tokens

        Returns:
            Processed inputs for the model
        """
        # Handle DataParallel wrapped models by accessing the module attribute if needed
        max_length_text = model.max_length_text if not hasattr(model, 'module') else model.module.max_length_text
        max_length_protein = model.max_length_protein if not hasattr(model, 'module') else model.module.max_length_protein

        prompt_inputs = processing_class(
            text=prompts_text,
            batch_protein_sequences=batch_protein_sequences,
            return_tensors=return_tensors,
            padding=padding,
            padding_side=padding_side,
            add_special_tokens=add_special_tokens,
            max_length_text=max_length_text,
            max_length_protein=max_length_protein,
        )

        return prompt_inputs

    def is_embeds_input(self) -> bool:
        """
        Whether the model uses embeddings as input (instead of token IDs).

        Returns:
            Boolean indicating if the model takes embedding inputs
        """
        return True

    @staticmethod
    def get_question_template() -> str:
        """
        Get the template for formatting questions.

        Returns:
            String template for questions
        """
        return "{Question}"

    @staticmethod
    def format_reward_rec(completions: List[Dict[str, Any]], **kwargs) -> List[float]:
        """
        Check if the model output matches a specific format.

        Args:
            completions: List of model completions
            **kwargs: Additional arguments

        Returns:
            List of reward scores (1.0 for match, 0.0 for no match)
        """
        import re
        import os
        from datetime import datetime

        # Pattern to match the expected output format: <think>...</think>
        # followed by an <answer> containing a brace-enclosed object with a
        # 4-integer list (presumably a bounding box for the 'rec' task — confirm).
        pattern = r"<think>.*?</think>\s*<answer>.*?\{.*\[\d+,\s*\d+,\s*\d+,\s*\d+\].*\}.*?</answer>"
        completion_contents = [completion[0]["content"] for completion in completions]
        matches = [
            re.search(pattern, content, re.DOTALL) is not None
            for content in completion_contents
        ]

        # Log format results if in debug mode
        current_time = datetime.now().strftime("%d-%H-%M-%S-%f")
        if os.getenv("DEBUG_MODE") == "true":
            # NOTE(review): assumes LOG_PATH is set whenever DEBUG_MODE is
            # "true"; os.getenv returns None otherwise and .replace would raise.
            log_path = os.getenv("LOG_PATH")
            with open(
                log_path.replace(".txt", "_format.txt"), "a", encoding="utf-8"
            ) as f:
                f.write(f"------------- {current_time} Format reward -------------\n")
                for content, match in zip(completion_contents, matches):
                    f.write(f"Content: {content}\n")
                    f.write(f"Has format: {bool(match)}\n")

        return [1.0 if match else 0.0 for match in matches]

    @staticmethod
    def select_reward_func(func: str, task_type: str) -> Callable:
        """
        Select the appropriate reward function based on function name and task type.

        Args:
            func: The type of reward function ('accuracy', 'format', etc.)
            task_type: The type of task ('rec', etc.)

        Returns:
            The reward function to use

        Raises:
            ValueError: If the function or task type is not supported
        """
        # Uses match/case — requires Python 3.10+.
        if func == "accuracy":
            match task_type:
                case "rec":
                    # NOTE(review): iou_reward is referenced but not defined
                    # anywhere in this class — confirm it is attached elsewhere,
                    # otherwise this path raises AttributeError at call time.
                    return ESM2ProteinModule.iou_reward
                case _:
                    raise ValueError(f"Unsupported reward function: {func}")
        elif func == "format":
            match task_type:
                case "rec":
                    return ESM2ProteinModule.format_reward_rec
                case _:
                    raise ValueError(f"Unsupported reward function: {func}")
        else:
            raise ValueError(f"Unsupported reward function: {func}")
BioReason_new/bioreason/trainer/__pycache__/contrast_trainer_new.cpython-310.pyc ADDED
Binary file (13.7 kB). View file
 
BioReason_new/bioreason/trainer/__pycache__/contrast_trainer_new.cpython-311.pyc ADDED
Binary file (25.4 kB). View file
 
BioReason_new/bioreason/trainer/_init_.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .grpo_config import DNALLMGRPOConfig
2
+ from .grpo_trainer import DNALLMGRPOTrainer
3
+ from .protein_grpo_config import ProteinLLMGRPOConfig
4
+ from .protein_grpo_trainer import ProteinLLMGRPOTrainer
5
+
6
+ __all__ = [
7
+ "DNALLMGRPOConfig",
8
+ "DNALLMGRPOTrainer",
9
+ "ProteinLLMGRPOConfig",
10
+ "ProteinLLMGRPOTrainer",
11
+ ]
BioReason_new/bioreason/trainer/contrast_trainer.py ADDED
@@ -0,0 +1,372 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ from torch.utils.data import DataLoader
6
+ from typing import Dict, List, Optional, Any, Union
7
+ from transformers import Trainer, TrainingArguments
8
+ from dataclasses import dataclass, field
9
+ import wandb
10
+ from datasets import Dataset
11
+
12
+ from bioreason.models.protein_llm import ProteinLLMModel
13
+
14
def pl_concat_all_gather(tensor):
    """
    Gather a tensor from all processes in distributed training and concatenate
    the per-rank copies along dim 0.

    Falls back to returning the input tensor unchanged when torch.distributed
    is unavailable, not initialized, or the world size is 1.

    Args:
        tensor: Tensor to gather across ranks.

    Returns:
        Concatenation (dim 0) of the tensor from every rank, or the input
        tensor itself outside multi-process runs.
    """
    # Bug fix: `dist` was referenced but never imported in this module,
    # so any distributed run would hit a NameError here.
    import torch.distributed as dist

    if not dist.is_available() or not dist.is_initialized():
        return tensor

    world_size = dist.get_world_size()
    if world_size == 1:
        return tensor

    # NOTE(review): all_gather writes detached copies into the output buffers,
    # so gradients do not flow through the gathered tensors — confirm this is
    # intended for the contrastive loss.
    gathered_tensors = [torch.zeros_like(tensor) for _ in range(world_size)]
    dist.all_gather(gathered_tensors, tensor)
    return torch.cat(gathered_tensors, dim=0)
29
+
30
@dataclass
class ContrastiveTrainingArguments(TrainingArguments):
    """
    Arguments for contrastive learning training.

    Extends transformers.TrainingArguments with options specific to
    protein-text contrastive alignment.
    """
    # Softmax temperature for the contrastive loss; smaller values sharpen logits.
    temperature: float = field(
        default=0.07,
        metadata={"help": "Temperature parameter for contrastive loss"}
    )
    # Freeze the ESM2 protein backbone so only the projection trains.
    freeze_protein_model: bool = field(
        default=True,
        metadata={"help": "Whether to freeze the protein model during training"}
    )
    # Freeze the text LLM backbone so only the projection trains.
    freeze_text_model: bool = field(
        default=True,
        metadata={"help": "Whether to freeze the text model during training"}
    )
    # NOTE(review): these weights are stored by the trainer but not visibly
    # applied in this chunk — confirm they are used in compute_loss.
    protein_weight: float = field(
        default=1.0,
        metadata={"help": "Weight for protein features in contrastive loss"}
    )
    text_weight: float = field(
        default=1.0,
        metadata={"help": "Weight for text features in contrastive loss"}
    )
    # Tokenizer truncation limit for protein sequences.
    max_length_protein: int = field(
        default=1024,
        metadata={"help": "Maximum length for protein sequences"}
    )
    # Tokenizer truncation limit for text sequences.
    max_length_text: int = field(
        default=512,
        metadata={"help": "Maximum length for text sequences"}
    )
63
+
64
+
65
+
66
class ContrastiveLoss(nn.Module):
    """Symmetric InfoNCE loss aligning protein and text embeddings.

    Each (protein, text) pair at the same batch index is a positive; every
    other cross-pairing in the batch serves as an in-batch negative.
    """

    def __init__(self, temperature: float = 0.07):
        super().__init__()
        # Scaling applied to cosine similarities before the softmax.
        self.temperature = temperature

    def forward(
        self,
        protein_features: torch.Tensor,
        text_features: torch.Tensor
    ) -> torch.Tensor:
        """Compute the symmetric contrastive loss.

        Args:
            protein_features: [batch_size, hidden_size] protein embeddings.
            text_features: [batch_size, hidden_size] text embeddings.

        Returns:
            Scalar loss averaged over both matching directions.
        """
        # L2-normalize so dot products become cosine similarities.
        p = F.normalize(protein_features, dim=-1)
        t = F.normalize(text_features, dim=-1)

        # Temperature-scaled logits: logits[i, j] = <p_i, t_j> / T.
        logits = (p @ t.T) / self.temperature

        # Row i's positive sits on the diagonal, so targets are 0..B-1.
        targets = torch.arange(logits.size(0), device=logits.device)

        # Cross-entropy in both directions (protein->text and text->protein),
        # then average.
        loss_p2t = F.cross_entropy(logits, targets)
        loss_t2p = F.cross_entropy(logits.T, targets)
        return 0.5 * (loss_p2t + loss_t2p)
109
+
110
+
111
class ContrastiveTrainer(Trainer):
    """
    Trainer for contrastive learning between proteins and text.

    Optimizes a symmetric InfoNCE objective over batches of
    (protein_sequence, text_description) pairs produced by
    ``protein_text_collate_fn``. Encoder backbones can be frozen so that
    only the projection layers are updated.
    """

    def __init__(
        self,
        model: ProteinLLMModel,
        args: ContrastiveTrainingArguments,
        train_dataset: Optional[Dataset] = None,
        eval_dataset: Optional[Dataset] = None,
        data_collator: Optional[callable] = None,
        **kwargs
    ):
        self.contrastive_loss = ContrastiveLoss(temperature=args.temperature)
        self.freeze_protein_model = args.freeze_protein_model
        self.freeze_text_model = args.freeze_text_model
        # NOTE(review): protein_weight / text_weight are stored but never
        # applied in compute_loss below — confirm whether weighting is still
        # intended or the fields are vestigial.
        self.protein_weight = args.protein_weight
        self.text_weight = args.text_weight

        # Freeze encoder backbones if requested.
        if self.freeze_protein_model:
            for param in model.protein_model.parameters():
                param.requires_grad = False

        if self.freeze_text_model:
            for param in model.text_model.parameters():
                param.requires_grad = False

        # Ensure projection layers stay trainable even when backbones are frozen.
        for param in model.protein_projection.parameters():
            param.requires_grad = True

        super().__init__(
            model=model,
            args=args,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            data_collator=data_collator,
            **kwargs
        )

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        """
        Compute contrastive loss.

        Accepts **kwargs so extra arguments passed by newer transformers
        versions (e.g. ``num_items_in_batch`` since v4.46) do not raise a
        TypeError; this matches the signature of the sibling trainer in
        contrast_trainer_new.py.

        Args:
            model: The ProteinLLMModel
            inputs: Dictionary containing protein_sequences and text_sequences
            return_outputs: Whether to return model outputs

        Returns:
            Contrastive loss, optionally paired with a dict of outputs.
        """
        protein_sequences = inputs["protein_sequences"]
        text_sequences = inputs["text_sequences"]

        # Encode both modalities.
        protein_features = model.get_protein_features(protein_sequences)
        text_features = model.get_text_features(text_sequences)

        # Symmetric InfoNCE loss over the batch.
        loss = self.contrastive_loss(protein_features, text_features)

        # Log similarity statistics for monitoring (no gradients needed).
        with torch.no_grad():
            protein_features_norm = F.normalize(protein_features, dim=-1)
            text_features_norm = F.normalize(text_features, dim=-1)
            similarity_matrix = torch.matmul(protein_features_norm, text_features_norm.T)

            # Diagonal elements are positive pairs; all others are negatives.
            positive_similarities = torch.diag(similarity_matrix)
            negative_similarities = similarity_matrix[~torch.eye(similarity_matrix.size(0), dtype=bool)]

            self.log({
                "contrastive_loss": loss.item(),
                "positive_similarity_mean": positive_similarities.mean().item(),
                "negative_similarity_mean": negative_similarities.mean().item(),
                "positive_similarity_std": positive_similarities.std().item(),
                "similarity_gap": (positive_similarities.mean() - negative_similarities.mean()).item(),
            })

        if return_outputs:
            outputs = {
                "protein_features": protein_features,
                "text_features": text_features,
                "similarity_matrix": similarity_matrix,
            }
            return (loss, outputs)

        return loss

    def evaluation_loop(self, dataloader, description, prediction_loss_only=None, ignore_keys=None, metric_key_prefix="eval"):
        """
        Custom evaluation loop for contrastive learning.

        Averages the contrastive loss over batches and computes corpus-level
        retrieval metrics from all collected features.

        NOTE(review): this returns a plain metrics dict rather than the
        ``EvalLoopOutput`` the stock ``Trainer.evaluate()`` expects —
        confirm how callers consume the result.
        """
        model = self._wrap_model(self.model, training=False, dataloader=dataloader)
        model.eval()

        total_loss = 0.0
        total_samples = 0
        all_protein_features = []
        all_text_features = []

        for step, inputs in enumerate(dataloader):
            with torch.no_grad():
                loss, outputs = self.compute_loss(model, inputs, return_outputs=True)

            total_loss += loss.item()
            total_samples += len(inputs["protein_sequences"])

            # Move features off-device so the whole corpus fits in host memory.
            all_protein_features.append(outputs["protein_features"].cpu())
            all_text_features.append(outputs["text_features"].cpu())

        # Average loss per batch.
        avg_loss = total_loss / len(dataloader)

        # Concatenate all features for corpus-level retrieval metrics.
        all_protein_features = torch.cat(all_protein_features, dim=0)
        all_text_features = torch.cat(all_text_features, dim=0)

        retrieval_metrics = self.compute_retrieval_metrics(all_protein_features, all_text_features)

        metrics = {
            f"{metric_key_prefix}_loss": avg_loss,
            **{f"{metric_key_prefix}_{k}": v for k, v in retrieval_metrics.items()}
        }

        return metrics

    def compute_retrieval_metrics(self, protein_features: torch.Tensor, text_features: torch.Tensor) -> Dict[str, float]:
        """
        Compute retrieval metrics (Recall@K and mean rank).

        Args:
            protein_features: [num_samples, hidden_size]
            text_features: [num_samples, hidden_size]

        Returns:
            Dictionary of retrieval metrics.
        """
        # Normalize so dot products are cosine similarities.
        protein_features = F.normalize(protein_features, dim=-1)
        text_features = F.normalize(text_features, dim=-1)

        similarity_matrix = torch.matmul(protein_features, text_features.T)

        # Protein-to-text retrieval: 1-based rank of the true match; ties
        # with the true score all count toward the rank.
        p2t_ranks = []
        for i in range(similarity_matrix.size(0)):
            similarities = similarity_matrix[i]
            rank = (similarities >= similarities[i]).sum().item()
            p2t_ranks.append(rank)

        # Text-to-protein retrieval.
        t2p_ranks = []
        for i in range(similarity_matrix.size(1)):
            similarities = similarity_matrix[:, i]
            rank = (similarities >= similarities[i]).sum().item()
            t2p_ranks.append(rank)

        # Recall@K in both directions plus their average.
        metrics = {}
        for k in [1, 5, 10]:
            p2t_recall_k = sum(1 for rank in p2t_ranks if rank <= k) / len(p2t_ranks)
            t2p_recall_k = sum(1 for rank in t2p_ranks if rank <= k) / len(t2p_ranks)

            metrics[f"p2t_recall_at_{k}"] = p2t_recall_k
            metrics[f"t2p_recall_at_{k}"] = t2p_recall_k
            metrics[f"avg_recall_at_{k}"] = (p2t_recall_k + t2p_recall_k) / 2

        # Mean rank of the true match in each direction.
        metrics["p2t_mean_rank"] = sum(p2t_ranks) / len(p2t_ranks)
        metrics["t2p_mean_rank"] = sum(t2p_ranks) / len(t2p_ranks)

        return metrics
292
+
293
+
294
def protein_text_collate_fn(batch: List[Dict[str, Any]]) -> Dict[str, List[str]]:
    """
    Collate function for protein-text contrastive learning.

    Args:
        batch: List of samples, each containing "protein_sequence" and
            "text_description"

    Returns:
        Dictionary with parallel lists of protein sequences and text
        descriptions.
    """
    collated: Dict[str, List[str]] = {
        "protein_sequences": [],
        "text_sequences": [],
    }
    for sample in batch:
        collated["protein_sequences"].append(sample["protein_sequence"])
        collated["text_sequences"].append(sample["text_description"])
    return collated
311
+
312
+
313
+ # Example usage
314
def train_contrastive_model(
    model: ProteinLLMModel,
    train_dataset: Dataset,
    eval_dataset: Optional[Dataset] = None,
    output_dir: str = "./contrastive_outputs",
    num_epochs: int = 10,
    batch_size: int = 32,
    learning_rate: float = 1e-4,
    temperature: float = 0.07,
    **kwargs
):
    """
    Train the model with contrastive learning.

    Builds a ContrastiveTrainingArguments config, wires it into a
    ContrastiveTrainer with protein_text_collate_fn, runs training, and
    saves the final model to ``output_dir``.

    Args:
        model: ProteinLLMModel to train
        train_dataset: Training dataset with protein_sequence and text_description
        eval_dataset: Optional evaluation dataset
        output_dir: Directory to save outputs
        num_epochs: Number of training epochs
        batch_size: Training batch size
        learning_rate: Learning rate
        temperature: Temperature for contrastive loss
        **kwargs: Additional keyword arguments forwarded verbatim to
            ContrastiveTrainingArguments

    Returns:
        The fitted ContrastiveTrainer instance.
    """
    training_args = ContrastiveTrainingArguments(
        output_dir=output_dir,
        num_train_epochs=num_epochs,
        per_device_train_batch_size=batch_size,
        per_device_eval_batch_size=batch_size,
        learning_rate=learning_rate,
        temperature=temperature,
        logging_steps=10,
        # Evaluation / best-model tracking only makes sense with an eval set.
        # NOTE(review): transformers >= 4.41 deprecates `evaluation_strategy`
        # in favour of `eval_strategy` — confirm the pinned version accepts this.
        evaluation_strategy="steps" if eval_dataset else "no",
        eval_steps=100 if eval_dataset else None,
        save_steps=500,
        save_total_limit=3,
        load_best_model_at_end=True if eval_dataset else False,
        metric_for_best_model="eval_avg_recall_at_1" if eval_dataset else None,
        greater_is_better=True,
        # Only report to wandb when a run is already active.
        report_to=["wandb"] if wandb.run else [],
        **kwargs
    )

    trainer = ContrastiveTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        data_collator=protein_text_collate_fn,
    )

    # Train the model
    trainer.train()

    # Save the final model
    trainer.save_model()

    return trainer
BioReason_new/bioreason/trainer/contrast_trainer_new.py ADDED
@@ -0,0 +1,659 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ import torch.distributed as dist
6
+ from torch.utils.data import DataLoader
7
+ from typing import Dict, List, Optional, Any, Union
8
+ from dataclasses import dataclass, field
9
+ from transformers import Trainer, TrainingArguments
10
+ import wandb
11
+ from datasets import Dataset
12
+
13
+ from bioreason.models.protein_llm import ProteinLLMModel
14
+
15
+
16
def pl_concat_all_gather(tensor):
    """
    Gather tensors from all processes in distributed training.

    Returns the input tensor unchanged when not running under an
    initialized process group (or with a single process); otherwise
    returns the concatenation of every rank's copy along dim 0.
    """
    running_distributed = dist.is_available() and dist.is_initialized()
    if not running_distributed:
        return tensor

    world_size = dist.get_world_size()
    if world_size == 1:
        # Single-process group: nothing to gather.
        return tensor

    # One receive buffer per rank, then stitch them together.
    buckets = [torch.zeros_like(tensor) for _ in range(world_size)]
    dist.all_gather(buckets, tensor)
    return torch.cat(buckets, dim=0)
31
+
32
+
33
@dataclass
class ContrastiveTrainingArguments(TrainingArguments):
    """
    Arguments for contrastive learning training.

    Extends transformers' TrainingArguments with contrastive-learning and
    protein-text-matching (PTM) specific options; each field is described
    by its metadata "help" string.
    """
    # Debug `print(TrainingArguments.__module__)` statements were removed
    # from the class body: they executed as import-time side effects.
    temperature: float = field(
        default=0.07,
        metadata={"help": "Temperature parameter for contrastive loss"}
    )
    freeze_protein_model: bool = field(
        default=True,
        metadata={"help": "Whether to freeze the protein model during training"}
    )
    freeze_text_model: bool = field(
        default=True,
        metadata={"help": "Whether to freeze the text model during training"}
    )
    protein_weight: float = field(
        default=1.0,
        metadata={"help": "Weight for protein features in contrastive loss"}
    )
    text_weight: float = field(
        default=1.0,
        metadata={"help": "Weight for text features in contrastive loss"}
    )
    max_length_protein: int = field(
        default=1024,
        metadata={"help": "Maximum length for protein sequences"}
    )
    max_length_text: int = field(
        default=512,
        metadata={"help": "Maximum length for text sequences"}
    )
    enable_ptm: bool = field(
        default=True,
        metadata={"help": "Enable protein-text matching task"}
    )
    ptm_weight: float = field(
        default=1.0,
        metadata={"help": "Weight for protein-text matching loss"}
    )
76
+
77
+
78
class EnhancedContrastiveLoss(nn.Module):
    """
    Enhanced contrastive loss for protein-text alignment with multi-query support.
    Based on BLIP2 QFormer implementation.
    """

    def __init__(self, temperature: float = 0.07, enable_ptm: bool = True):
        super().__init__()
        # Softmax temperature applied to similarity logits.
        self.temperature = temperature
        # Whether the protein-text matching (PTM) head/loss is used.
        self.enable_ptm = enable_ptm
        if enable_ptm:
            # NOTE(review): compute_ptm_loss below may re-create this head
            # lazily with a different hidden size; a lazily-created layer is
            # not seen by an optimizer built at trainer construction time —
            # confirm its parameters are actually trained.
            self.ptm_head = nn.Linear(768, 2)  # Assuming hidden size of 768

    def contrast_global(self, features_protein, features_text, features_protein_all, features_text_all, return_sim=False):
        """
        Compute global contrastive loss across all processes.

        Args:
            features_protein: [B, num_queries, D] - local protein features
            features_text: [B, D] - local text features
            features_protein_all: [B * num_gpus, num_queries, D] - all protein features
            features_text_all: [B * num_gpus, D] - all text features
            return_sim: whether to return similarity matrices

        Returns:
            The scalar loss, or (logits_per_protein, logits_per_text, loss)
            when return_sim is True. The returned "similarity" matrices are
            already divided by the temperature.
        """
        bs = features_protein.size(0)
        device = features_protein.device

        # Protein-to-text similarity
        # shape: [B, 1, num_queries, D] @ [B * num_gpus, D, 1] -> [B, B * num_gpus, num_queries]
        # NOTE(review): .squeeze() removes ALL singleton dims — behaviour
        # differs when B == 1 or num_queries == 1; confirm batch sizes > 1.
        sim_p2t = (features_protein.unsqueeze(1) @ features_text_all.unsqueeze(-1)).squeeze()
        sim_p2t, _ = sim_p2t.max(-1)  # Take max over query tokens: [B, B * num_gpus]

        logits_per_protein = sim_p2t / self.temperature

        # Text-to-protein similarity
        # shape: [B, 1, 1, D] @ [B*num_gpus, D, num_queries] -> [B, B*num_gpus, num_queries]
        sim_t2p = (features_text.unsqueeze(1).unsqueeze(1) @ features_protein_all.permute(0, 2, 1)).squeeze()
        sim_t2p, _ = sim_t2p.max(-1)  # Take max over query tokens: [B, B * num_gpus]
        logits_per_text = sim_t2p / self.temperature

        # Create labels for current rank: this rank's positives sit in the
        # gathered matrix at columns [rank*bs, rank*bs + bs).
        if dist.is_available() and dist.is_initialized():
            rank = dist.get_rank()
            labels = torch.linspace(rank * bs, rank * bs + bs - 1, bs, dtype=int).to(device)
        else:
            labels = torch.arange(bs, dtype=torch.long, device=device)

        # Symmetric cross-entropy in both retrieval directions.
        loss_protein = F.cross_entropy(logits_per_protein, labels)
        loss_text = F.cross_entropy(logits_per_text, labels)
        loss = (loss_protein + loss_text) / 2

        if return_sim:
            return logits_per_protein, logits_per_text, loss
        else:
            return loss

    def compute_ptm_loss(self, protein_embeds, protein_mask, text_ids, text_mask,
                         query_tokens, tokenizer, qformer, sim_p2t, sim_t2p):
        """
        Compute protein-text matching loss.
        Adjusted to match the standard BertModel API.
        """
        batch_size = protein_embeds.size(0)
        device = protein_embeds.device

        # Get world features for negative sampling
        protein_embeds_world = pl_concat_all_gather(protein_embeds)
        protein_mask_world = pl_concat_all_gather(protein_mask)
        text_ids_world = pl_concat_all_gather(text_ids)
        text_mask_world = pl_concat_all_gather(text_mask)

        with torch.no_grad():
            if dist.is_available() and dist.is_initialized():
                rank = dist.get_rank()
            else:
                rank = 0

            # Compute weights for negative sampling: softmax over similarity
            # makes hard negatives more likely; the diagonal (the true
            # positive for this rank) is zeroed out so it cannot be drawn.
            weights_t2p = F.softmax(sim_t2p, dim=1) + 1e-4
            weights_t2p[:, rank * batch_size : rank * batch_size + batch_size].fill_diagonal_(0)

            weights_p2t = F.softmax(sim_p2t, dim=1) + 1e-4
            weights_p2t[:, rank * batch_size : rank * batch_size + batch_size].fill_diagonal_(0)

        # Select negative proteins for each text
        protein_embeds_neg = []
        protein_mask_neg = []
        for b in range(batch_size):
            neg_idx = torch.multinomial(weights_t2p[b], 1).item()
            protein_embeds_neg.append(protein_embeds_world[neg_idx])
            protein_mask_neg.append(protein_mask_world[neg_idx])

        protein_embeds_neg = torch.stack(protein_embeds_neg, dim=0)
        protein_mask_neg = torch.stack(protein_mask_neg, dim=0)

        # Select negative texts for each protein
        text_ids_neg = []
        text_mask_neg = []
        for b in range(batch_size):
            neg_idx = torch.multinomial(weights_p2t[b], 1).item()
            text_ids_neg.append(text_ids_world[neg_idx])
            text_mask_neg.append(text_mask_world[neg_idx])

        text_ids_neg = torch.stack(text_ids_neg, dim=0)
        text_mask_neg = torch.stack(text_mask_neg, dim=0)

        # Prepare inputs for PTM
        text_ids_all = torch.cat([text_ids, text_ids, text_ids_neg], dim=0)  # pos, pos, neg
        text_mask_all = torch.cat([text_mask, text_mask, text_mask_neg], dim=0)

        # Get text embeddings
        text_embeds_all = qformer.embeddings.word_embeddings(text_ids_all)

        # Expand query tokens for all samples
        query_tokens_ptm = query_tokens.expand(text_ids_all.shape[0], -1, -1)
        query_mask_ptm = torch.ones(query_tokens_ptm.size()[:-1], dtype=torch.long, device=device)

        # Approach 1: use only query tokens and text; do not encode the protein directly.
        # This better matches the current QFormer architecture.
        # NOTE(review): the gathered protein embeddings / negatives above are
        # never fed into the QFormer in this approach — confirm the matching
        # task still discriminates protein-text pairs as intended.
        inputs_embeds = torch.cat([query_tokens_ptm, text_embeds_all], dim=1)
        attention_mask_all = torch.cat([query_mask_ptm, text_mask_all], dim=1)

        # Ensure the sequence length does not exceed the position-embedding limit
        max_length = qformer.config.max_position_embeddings
        if attention_mask_all.size(1) > max_length:
            # Truncate the text portion (trailing tokens)
            excess = attention_mask_all.size(1) - max_length
            if excess > 0:
                inputs_embeds = inputs_embeds[:, :-excess, :]
                attention_mask_all = attention_mask_all[:, :-excess]

        # Forward through QFormer - using the standard BERT API
        output_ptm = qformer(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask_all,
            return_dict=True,
        )

        # Extract the query-token positions from the output sequence.
        pl_embeddings = output_ptm.last_hidden_state[:, :query_tokens_ptm.size(1), :]

        # Ensure ptm_head exists with the correct dimensions
        if not hasattr(self, 'ptm_head'):
            hidden_size = pl_embeddings.size(-1)
            self.ptm_head = nn.Linear(hidden_size, 2).to(device)

        pl_output = self.ptm_head(pl_embeddings)
        logits = pl_output.mean(dim=1)  # [batch_size * 3, 2]

        # Create labels: positive pairs get label 1, negative pairs get label 0.
        # NOTE(review): the first two thirds of text_ids_all are identical
        # positive texts, yet only the first third is labelled positive —
        # verify this label layout against the intended pos/neg pairing.
        ptm_labels = torch.cat([
            torch.ones(batch_size, dtype=torch.long),   # text-protein positive
            torch.zeros(batch_size, dtype=torch.long),  # text-protein_neg negative
            torch.zeros(batch_size, dtype=torch.long)   # text_neg-protein negative
        ], dim=0).to(device)

        loss_ptm = F.cross_entropy(logits, ptm_labels)
        return loss_ptm
237
+
238
+
239
class ContrastiveTrainer(Trainer):
    """
    Enhanced trainer for contrastive learning between proteins and text.

    Combines a global (multi-GPU) InfoNCE objective with an optional
    protein-text matching (PTM) auxiliary loss.
    """

    def __init__(
        self,
        model: ProteinLLMModel,
        args: ContrastiveTrainingArguments,
        train_dataset: Optional[Dataset] = None,
        eval_dataset: Optional[Dataset] = None,
        data_collator: Optional[callable] = None,
        **kwargs
    ):
        self.contrastive_loss = EnhancedContrastiveLoss(
            temperature=args.temperature,
            enable_ptm=args.enable_ptm
        )
        self.freeze_protein_model = args.freeze_protein_model
        self.freeze_text_model = args.freeze_text_model
        self.protein_weight = args.protein_weight
        self.text_weight = args.text_weight
        self.enable_ptm = args.enable_ptm
        self.ptm_weight = args.ptm_weight

        # Freeze models if specified
        if self.freeze_protein_model:
            for param in model.protein_model.parameters():
                param.requires_grad = False

        if self.freeze_text_model:
            for param in model.text_model.parameters():
                param.requires_grad = False

        # Ensure projection layers are trainable
        for param in model.protein_projection.parameters():
            param.requires_grad = True

        super().__init__(
            model=model,
            args=args,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            data_collator=data_collator,
            **kwargs
        )

    def compute_loss(self, model, inputs, return_outputs=False,**kwargs):
        """
        Compute enhanced contrastive loss with optional PTM.

        Accepts **kwargs for forward compatibility with transformers
        versions that pass extra arguments (e.g. num_items_in_batch).
        """
        # Check whether the model is wrapped by DataParallel
        if hasattr(model, 'module'):
            # If wrapped, access the underlying model via .module
            model = model.module
        else:
            # Otherwise use the model as-is
            model = model

        # NOTE(review): extracted but unused below — confirm whether loss
        # normalization by num_items_in_batch is needed.
        num_items_in_batch = kwargs.get('num_items_in_batch', None)

        protein_sequences = inputs["protein_sequences"]
        text_sequences = inputs["text_sequences"]

        # Get device from model
        device = next(model.parameters()).device

        # Get protein embeddings (before projection)
        protein_embeds = model.get_protein_embeddings(protein_sequences)  # [B, seq_len, hidden]

        # Get protein features through projection (query tokens)
        protein_features = model.get_protein_features(protein_sequences)  # [B, num_queries, embed_dim]

        # Get text features
        text_features = model.get_text_features(text_sequences)  # [B, embed_dim]

        # Normalize features
        protein_features = F.normalize(protein_features, p=2, dim=-1)
        text_features = F.normalize(text_features, p=2, dim=-1)

        # Gather features from all processes for global contrastive learning
        protein_features_all = pl_concat_all_gather(protein_features)
        text_features_all = pl_concat_all_gather(text_features)

        # Compute contrastive loss (sim_* are temperature-scaled logits)
        sim_p2t, sim_t2p, loss_contrastive = self.contrastive_loss.contrast_global(
            protein_features, text_features, protein_features_all, text_features_all, return_sim=True
        )

        total_loss = loss_contrastive

        # Compute PTM loss if enabled
        loss_ptm = 0
        if self.enable_ptm:
            # Tokenize text for PTM
            text_tokenized = model.text_tokenizer(
                text_sequences,
                padding=True,
                truncation=True,
                return_tensors="pt",
                max_length=self.args.max_length_text
            ).to(model.device)

            # Get protein attention mask
            protein_tokenized = model.protein_tokenizer(
                protein_sequences,
                padding=True,
                truncation=True,
                return_tensors="pt",
                max_length=self.args.max_length_protein
            ).to(model.device)

            loss_ptm = self.contrastive_loss.compute_ptm_loss(
                protein_embeds=protein_embeds,
                protein_mask=protein_tokenized.attention_mask,
                text_ids=text_tokenized.input_ids,
                text_mask=text_tokenized.attention_mask,
                query_tokens=model.protein_projection.query_tokens,
                tokenizer=model.text_tokenizer,
                qformer=model.protein_projection.qformer,
                sim_p2t=sim_p2t,
                sim_t2p=sim_t2p
            )

            # Weighted sum of the two objectives.
            total_loss = total_loss + self.ptm_weight * loss_ptm

        # Log detailed metrics
        with torch.no_grad():
            # Compute similarity statistics
            # NOTE(review): flatten(0, 1) makes this matrix [B*num_queries, B],
            # which is non-square whenever num_queries > 1; the torch.eye mask
            # below assumes a square matrix — verify this does not raise.
            similarity_matrix = torch.matmul(protein_features.flatten(0, 1), text_features.T)
            positive_similarities = torch.diag(similarity_matrix[:protein_features.size(0)])
            negative_similarities = similarity_matrix[~torch.eye(similarity_matrix.size(0), dtype=bool)]

            log_dict = {
                "contrastive_loss": loss_contrastive.item(),
                "total_loss": total_loss.item(),
                "positive_similarity_mean": positive_similarities.mean().item(),
                "negative_similarity_mean": negative_similarities.mean().item(),
                "positive_similarity_std": positive_similarities.std().item(),
                "similarity_gap": (positive_similarities.mean() - negative_similarities.mean()).item(),
            }

            if self.enable_ptm:
                log_dict["ptm_loss"] = loss_ptm.item() if isinstance(loss_ptm, torch.Tensor) else loss_ptm

            self.log(log_dict)

        if return_outputs:
            outputs = {
                "protein_features": protein_features,
                "text_features": text_features,
                "similarity_matrix_p2t": sim_p2t,
                "similarity_matrix_t2p": sim_t2p,
                "loss_contrastive": loss_contrastive,
                "loss_ptm": loss_ptm,
            }
            return (total_loss, outputs)

        return total_loss

    def evaluation_loop(self, dataloader, description, prediction_loss_only=None, ignore_keys=None, metric_key_prefix="eval"):
        """
        Enhanced evaluation loop for contrastive learning with PTM.

        NOTE(review): returns a plain metrics dict rather than the
        ``EvalLoopOutput`` transformers' Trainer.evaluate() expects —
        confirm how callers consume the result.
        """
        model = self._wrap_model(self.model, training=False, dataloader=dataloader)
        model.eval()

        total_loss = 0.0
        total_contrastive_loss = 0.0
        total_ptm_loss = 0.0
        total_samples = 0
        all_protein_features = []
        all_text_features = []

        for step, inputs in enumerate(dataloader):
            with torch.no_grad():
                loss, outputs = self.compute_loss(model, inputs, return_outputs=True)

            total_loss += loss.item()
            total_contrastive_loss += outputs["loss_contrastive"].item()
            if self.enable_ptm:
                ptm_loss_val = outputs["loss_ptm"].item() if isinstance(outputs["loss_ptm"], torch.Tensor) else outputs["loss_ptm"]
                total_ptm_loss += ptm_loss_val

            total_samples += len(inputs["protein_sequences"])

            # Collect features for retrieval metrics
            all_protein_features.append(outputs["protein_features"].cpu())
            all_text_features.append(outputs["text_features"].cpu())

        # Compute average losses
        avg_loss = total_loss / len(dataloader)
        avg_contrastive_loss = total_contrastive_loss / len(dataloader)
        avg_ptm_loss = total_ptm_loss / len(dataloader) if self.enable_ptm else 0

        # Concatenate all features for retrieval metrics
        all_protein_features = torch.cat(all_protein_features, dim=0)
        all_text_features = torch.cat(all_text_features, dim=0)

        # Compute retrieval metrics
        retrieval_metrics = self.compute_retrieval_metrics(all_protein_features, all_text_features)

        metrics = {
            f"{metric_key_prefix}_loss": avg_loss,
            f"{metric_key_prefix}_contrastive_loss": avg_contrastive_loss,
            **{f"{metric_key_prefix}_{k}": v for k, v in retrieval_metrics.items()}
        }

        if self.enable_ptm:
            metrics[f"{metric_key_prefix}_ptm_loss"] = avg_ptm_loss

        return metrics

    def compute_retrieval_metrics(self, protein_features: torch.Tensor, text_features: torch.Tensor) -> Dict[str, float]:
        """
        Compute retrieval metrics for multi-query protein features.

        Multi-query features [batch, num_queries, embed_dim] are mean-pooled
        over the query dimension before ranking.
        """
        # Handle multi-query protein features by taking mean or max
        if protein_features.dim() == 3:  # [batch, num_queries, embed_dim]
            protein_features_pooled = protein_features.mean(dim=1)  # Pool query tokens
        else:
            protein_features_pooled = protein_features

        # Normalize features
        protein_features_pooled = F.normalize(protein_features_pooled, dim=-1)
        text_features = F.normalize(text_features, dim=-1)

        # Compute similarity matrix
        similarity_matrix = torch.matmul(protein_features_pooled, text_features.T)

        # Protein-to-text retrieval: 1-based rank of the true match
        # (ties with the true score all count toward the rank).
        p2t_ranks = []
        for i in range(similarity_matrix.size(0)):
            similarities = similarity_matrix[i]
            rank = (similarities >= similarities[i]).sum().item()
            p2t_ranks.append(rank)

        # Text-to-protein retrieval
        t2p_ranks = []
        for i in range(similarity_matrix.size(1)):
            similarities = similarity_matrix[:, i]
            rank = (similarities >= similarities[i]).sum().item()
            t2p_ranks.append(rank)

        # Compute Recall@K
        metrics = {}
        for k in [1, 5, 10]:
            p2t_recall_k = sum(1 for rank in p2t_ranks if rank <= k) / len(p2t_ranks)
            t2p_recall_k = sum(1 for rank in t2p_ranks if rank <= k) / len(t2p_ranks)

            metrics[f"p2t_recall_at_{k}"] = p2t_recall_k
            metrics[f"t2p_recall_at_{k}"] = t2p_recall_k
            metrics[f"avg_recall_at_{k}"] = (p2t_recall_k + t2p_recall_k) / 2

        # Mean rank
        metrics["p2t_mean_rank"] = sum(p2t_ranks) / len(p2t_ranks)
        metrics["t2p_mean_rank"] = sum(t2p_ranks) / len(t2p_ranks)

        return metrics
498
+
499
+
500
def protein_text_collate_fn(batch: List[Dict[str, Any]]) -> Dict[str, List[str]]:
    """
    Collate function for protein-text contrastive learning.

    Gathers the "protein_sequence" and "text_description" fields of each
    sample into two parallel lists.
    """
    proteins: List[str] = []
    descriptions: List[str] = []
    for sample in batch:
        proteins.append(sample["protein_sequence"])
        descriptions.append(sample["text_description"])

    return {
        "protein_sequences": proteins,
        "text_sequences": descriptions,
    }
511
+
512
+
513
+ # Example usage function remains the same but with enhanced arguments
514
+ # def train_contrastive_model(
515
+ # model: ProteinLLMModel,
516
+ # train_dataset: Dataset,
517
+ # eval_dataset: Optional[Dataset] = None,
518
+ # output_dir: str = "./contrastive_outputs",
519
+ # num_epochs: int = 10,
520
+ # batch_size: int = 32,
521
+ # learning_rate: float = 1e-4,
522
+ # temperature: float = 0.07,
523
+ # enable_ptm: bool = True,
524
+ # ptm_weight: float = 1.0,
525
+ # **kwargs
526
+ # ):
527
+ # """
528
+ # Train the model with enhanced contrastive learning.
529
+ # """
530
+ # training_args = ContrastiveTrainingArguments(
531
+ # output_dir=output_dir,
532
+ # num_train_epochs=num_epochs,
533
+ # per_device_train_batch_size=batch_size,
534
+ # per_device_eval_batch_size=batch_size,
535
+ # learning_rate=learning_rate,
536
+ # temperature=temperature,
537
+ # enable_ptm=enable_ptm,
538
+ # ptm_weight=ptm_weight,
539
+ # logging_steps=10,
540
+ # evaluation_strategy="steps" if eval_dataset else "no",
541
+ # eval_steps=100 if eval_dataset else None,
542
+ # save_steps=500,
543
+ # save_total_limit=3,
544
+ # load_best_model_at_end=True if eval_dataset else False,
545
+ # metric_for_best_model="eval_avg_recall_at_1" if eval_dataset else None,
546
+ # greater_is_better=True,
547
+ # report_to=["wandb"] if wandb.run else [],
548
+ # **kwargs
549
+ # )
550
+
551
+ # trainer = ContrastiveTrainer(
552
+ # model=model,
553
+ # args=training_args,
554
+ # train_dataset=train_dataset,
555
+ # eval_dataset=eval_dataset,
556
+ # data_collator=protein_text_collate_fn,
557
+ # )
558
+
559
+ # # Train the model
560
+ # trainer.train()
561
+
562
+ # # Save the final model
563
+ # trainer.save_model()
564
+
565
+ # return trainer
566
+
567
+
568
+
569
+
570
+ # def compute_ptm_loss(self, protein_embeds, protein_mask, text_ids, text_mask,
571
+ # query_tokens, tokenizer, qformer, sim_p2t, sim_t2p):
572
+ # """
573
+ # Compute protein-text matching loss.
574
+ # """
575
+ # batch_size = protein_embeds.size(0)
576
+ # device = protein_embeds.device
577
+
578
+ # # Get world features for negative sampling
579
+ # protein_embeds_world = pl_concat_all_gather(protein_embeds)
580
+ # protein_mask_world = pl_concat_all_gather(protein_mask)
581
+ # text_ids_world = pl_concat_all_gather(text_ids)
582
+ # text_mask_world = pl_concat_all_gather(text_mask)
583
+
584
+ # with torch.no_grad():
585
+ # if dist.is_available() and dist.is_initialized():
586
+ # rank = dist.get_rank()
587
+ # else:
588
+ # rank = 0
589
+
590
+ # # Compute weights for negative sampling
591
+ # weights_t2p = F.softmax(sim_t2p, dim=1) + 1e-4
592
+ # weights_t2p[:, rank * batch_size : rank * batch_size + batch_size].fill_diagonal_(0)
593
+
594
+ # weights_p2t = F.softmax(sim_p2t, dim=1) + 1e-4
595
+ # weights_p2t[:, rank * batch_size : rank * batch_size + batch_size].fill_diagonal_(0)
596
+
597
+ # # Select negative proteins for each text
598
+ # protein_embeds_neg = []
599
+ # protein_mask_neg = []
600
+ # for b in range(batch_size):
601
+ # neg_idx = torch.multinomial(weights_t2p[b], 1).item()
602
+ # protein_embeds_neg.append(protein_embeds_world[neg_idx])
603
+ # protein_mask_neg.append(protein_mask_world[neg_idx])
604
+
605
+ # protein_embeds_neg = torch.stack(protein_embeds_neg, dim=0)
606
+ # protein_mask_neg = torch.stack(protein_mask_neg, dim=0)
607
+
608
+ # # Select negative texts for each protein
609
+ # text_ids_neg = []
610
+ # text_mask_neg = []
611
+ # for b in range(batch_size):
612
+ # neg_idx = torch.multinomial(weights_p2t[b], 1).item()
613
+ # text_ids_neg.append(text_ids_world[neg_idx])
614
+ # text_mask_neg.append(text_mask_world[neg_idx])
615
+
616
+ # text_ids_neg = torch.stack(text_ids_neg, dim=0)
617
+ # text_mask_neg = torch.stack(text_mask_neg, dim=0)
618
+
619
+ # # Prepare inputs for PTM
620
+ # text_ids_all = torch.cat([text_ids, text_ids, text_ids_neg], dim=0) # pos, pos, neg
621
+ # text_mask_all = torch.cat([text_mask, text_mask, text_mask_neg], dim=0)
622
+
623
+ # query_tokens_ptm = query_tokens.expand(text_ids_all.shape[0], -1, -1)
624
+ # query_mask_ptm = torch.ones(query_tokens_ptm.size()[:-1], dtype=torch.long, device=device)
625
+ # attention_mask_all = torch.cat([query_mask_ptm, text_mask_all], dim=1)
626
+
627
+ # protein_embeds_all = torch.cat([protein_embeds, protein_embeds_neg, protein_embeds], dim=0) # pos, neg, pos
628
+ # protein_mask_all = torch.cat([protein_mask, protein_mask_neg, protein_mask], dim=0)
629
+
630
+ # # Combine embeddings
631
+ # inputs_embeds = torch.cat([query_tokens_ptm,
632
+ # qformer.embeddings.word_embeddings(text_ids_all)], dim=1)
633
+
634
+ # # Create position ids
635
+ # position_ids = torch.arange(
636
+ # attention_mask_all.size(1), dtype=torch.long, device=device
637
+ # ).unsqueeze(0).expand(text_ids_all.size(0), -1)
638
+
639
+ # # Forward through QFormer for PTM - using BERT's forward method directly
640
+ # output_ptm = qformer(
641
+ # inputs_embeds=inputs_embeds,
642
+ # attention_mask=attention_mask_all,
643
+ # position_ids=position_ids,
644
+ # encoder_hidden_states=protein_embeds_all,
645
+ # encoder_attention_mask=protein_mask_all,
646
+ # return_dict=True,
647
+ # )
648
+
649
+ # pl_embeddings = output_ptm.last_hidden_state[:, :query_tokens_ptm.size(1), :]
650
+ # pl_output = self.ptm_head(pl_embeddings)
651
+ # logits = pl_output.mean(dim=1)
652
+
653
+ # ptm_labels = torch.cat([
654
+ # torch.ones(batch_size, dtype=torch.long),
655
+ # torch.zeros(2 * batch_size, dtype=torch.long)
656
+ # ], dim=0).to(device)
657
+
658
+ # loss_ptm = F.cross_entropy(logits, ptm_labels)
659
+ # return loss_ptm
BioReason_new/bioreason/trainer/grpo_config.py ADDED
@@ -0,0 +1,338 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from dataclasses import dataclass, field
16
+ from typing import Optional, Union
17
+
18
+ from transformers import TrainingArguments
19
+
20
+
21
@dataclass
class ProteinLLMGRPOConfig(TrainingArguments):
    r"""
    Configuration class for the [`ProteinLLMGRPOTrainer`].

    Only the parameters specific to GRPO training are listed here. For details on other parameters, refer to the
    [`~transformers.TrainingArguments`] documentation.

    Using [`~transformers.HfArgumentParser`] we can turn this class into
    [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
    command line.

    Parameters:
        > Parameters that control the model and reference model

        model_init_kwargs (`dict[str, Any]` or `None`, *optional*, defaults to `None`):
            Keyword arguments for [`~transformers.AutoModelForCausalLM.from_pretrained`], used when the `model`
            argument of the [`ProteinLLMGRPOTrainer`] is provided as a string.

        > Parameters that control the data preprocessing

        remove_unused_columns (`bool`, *optional*, defaults to `False`):
            Whether to only keep the column `"prompt"` in the dataset. If you use a custom reward function that
            requires any column other than `"prompts"` and `"completions"`, you should keep this to `False`.
        max_prompt_length (`int` or `None`, *optional*, defaults to `512`):
            Maximum length of the prompt. If the prompt is longer than this value, it will be truncated left.
        num_generations (`int` or `None`, *optional*, defaults to `8`):
            Number of generations per prompt to sample. The global batch size (num_processes * per_device_batch_size)
            must be divisible by this value.
        max_completion_length (`int` or `None`, *optional*, defaults to `800`):
            Maximum length of the generated completion.
        ds3_gather_for_generation (`bool`, *optional*, defaults to `True`):
            This setting applies to DeepSpeed ZeRO-3. If enabled, the policy model weights are gathered for generation,
            improving generation speed. However, disabling this option allows training models that exceed the VRAM
            capacity of a single GPU, albeit at the cost of slower generation. Disabling this option is not compatible
            with vLLM generation.

        > Parameters that control generation

        temperature (`float`, defaults to `0.6`):
            Temperature for sampling. The higher the temperature, the more random the completions.
        top_p (`float`, *optional*, defaults to `0.95`):
            Float that controls the cumulative probability of the top tokens to consider. Must be in (0, 1]. Set to
            `1.0` to consider all tokens.
        top_k (`int` or `None`, *optional*, defaults to `20`):
            Number of highest probability vocabulary tokens to keep for top-k-filtering. If `None`, top-k-filtering is
            disabled.
        min_p (`float` or `None`, *optional*, defaults to `None`):
            Minimum token probability, which will be scaled by the probability of the most likely token. It must be a
            value between `0.0` and `1.0`. Typical values are in the `0.01-0.2` range.
        repetition_penalty (`float`, *optional*, defaults to `1.0`):
            Float that penalizes new tokens based on whether they appear in the prompt and the generated text so far.
            Values > `1.0` encourage the model to use new tokens, while values < `1.0` encourage the model to repeat
            tokens.
        cache_implementation (`str` or `None`, *optional*, defaults to `None`):
            Implementation of the cache method for faster generation when use_vllm is set to False.

        > Parameters that control the training

        learning_rate (`float`, *optional*, defaults to `1e-6`):
            Initial learning rate for [`AdamW`] optimizer. The default value replaces that of
            [`~transformers.TrainingArguments`].
        beta (`float`, *optional*, defaults to `0.04`):
            KL coefficient. If `0.0`, the reference model is not loaded, reducing memory usage and improving training
            speed, but may be numerically unstable for long training runs.
        num_iterations (`int`, *optional*, defaults to `1`):
            Number of iterations per batch (denoted as μ in the algorithm).
        epsilon (`float`, *optional*, defaults to `0.2`):
            Epsilon value for clipping.
        epsilon_high (`float` or `None`, *optional*, defaults to `None`):
            Upper-bound epsilon value for clipping. If not specified, it defaults to the same value as the lower-bound
            specified in argument `epsilon`. Paper [DAPO](https://huggingface.co/papers/2503.14476) recommends `0.28`.
        reward_weights (`list[float]` or `None`, *optional*, defaults to `None`):
            Weights for each reward function. Must match the number of reward functions. If `None`, all rewards are
            weighted equally with weight `1.0`.
        sync_ref_model (`bool`, *optional*, defaults to `False`):
            Whether to synchronize the reference model with the active model every `ref_model_sync_steps` steps, using
            the `ref_model_mixup_alpha` parameter. This synchronization originates from the
            [TR-DPO](https://huggingface.co/papers/2404.09656) paper.
        ref_model_mixup_alpha (`float`, *optional*, defaults to `0.6`):
            α parameter from the [TR-DPO](https://huggingface.co/papers/2404.09656) paper, which controls the mix
            between the current policy and the previous reference policy during updates. The reference policy is
            updated according to the equation: `π_ref = α * π_θ + (1 - α) * π_ref_prev`. To use this parameter, you
            must set `sync_ref_model=True`.
        ref_model_sync_steps (`int`, *optional*, defaults to `512`):
            τ parameter from the [TR-DPO](https://huggingface.co/papers/2404.09656) paper, which determines how
            frequently the current policy is synchronized with the reference policy. To use this parameter, you must
            set `sync_ref_model=True`.

        > Parameters that control the logging

        log_completions (`bool`, *optional*, defaults to `True`):
            Whether to log a sample of (prompt, completion) pairs every `logging_steps` steps. If `rich` is
            installed, it prints the sample. If `wandb` logging is enabled, it logs it to `wandb`.

        > Parameters that control generation acceleration powered by vLLM

        use_vllm (`bool`, *optional*, defaults to `False`):
            Whether to use vLLM for generating completions. Requires a GPU kept free of training and
            `pip install vllm`. The remaining `vllm_*` fields (`vllm_device`, `vllm_gpu_memory_utilization`,
            `vllm_dtype`, `vllm_max_model_len`, `vllm_enable_prefix_caching`, `vllm_guided_decoding_regex`)
            tune that generation backend; see each field's `help` string below.
    """

    # Parameters that control the model and reference model
    model_init_kwargs: Optional[dict] = field(
        default=None,
        metadata={
            "help": "Keyword arguments for `transformers.AutoModelForCausalLM.from_pretrained`, used when the `model` "
            "argument of the `ProteinLLMGRPOTrainer` is provided as a string."
        },
    )

    # Parameters that control the data preprocessing
    # The default value remove_unused_columns is overwritten from the parent class, because in GRPO we usually rely on
    # additional columns to compute the reward
    remove_unused_columns: Optional[bool] = field(
        default=False,
        metadata={
            "help": "Whether to only keep the column 'prompt' in the dataset. If you use a custom reward function "
            "that requires any column other than 'prompts' and 'completions', you should keep this to `False`."
        },
    )
    max_prompt_length: Optional[int] = field(
        default=512,
        metadata={
            "help": "Maximum length of the prompt. If the prompt is longer than this value, it will be truncated left."
        },
    )
    num_generations: Optional[int] = field(
        default=8,
        metadata={
            "help": "Number of generations to sample. The global batch size (num_processes * per_device_batch_size) "
            "must be divisible by this value."
        },
    )
    max_completion_length: Optional[int] = field(
        default=800,
        metadata={"help": "Maximum length of the generated completion."},
    )
    ds3_gather_for_generation: bool = field(
        default=True,
        metadata={
            "help": "This setting applies to DeepSpeed ZeRO-3. If enabled, the policy model weights are gathered for "
            "generation, improving generation speed. However, disabling this option allows training models that "
            "exceed the VRAM capacity of a single GPU, albeit at the cost of slower generation. Disabling this option "
            "is not compatible with vLLM generation."
        },
    )

    # Parameters that control generation
    temperature: float = field(
        default=0.6,
        metadata={"help": "Temperature for sampling. The higher the temperature, the more random the completions."},
    )
    top_p: float = field(
        default=0.95,
        metadata={
            "help": "Float that controls the cumulative probability of the top tokens to consider. Must be in (0, 1]. "
            "Set to 1.0 to consider all tokens."
        },
    )
    top_k: Optional[int] = field(
        default=20,
        metadata={
            "help": "Number of highest probability vocabulary tokens to keep for top-k-filtering. If `None`, "
            "top-k-filtering is disabled."
        },
    )
    min_p: Optional[float] = field(
        default=None,
        metadata={
            "help": "Minimum token probability, which will be scaled by the probability of the most likely token. It "
            "must be a value between 0.0 and 1.0. Typical values are in the 0.01-0.2 range."
        },
    )
    repetition_penalty: float = field(
        default=1.0,
        metadata={
            "help": "Float that penalizes new tokens based on whether they appear in the prompt and the generated "
            "text so far. Values > 1.0 encourage the model to use new tokens, while values < 1.0 encourage the model "
            "to repeat tokens."
        },
    )
    cache_implementation: Optional[str] = field(
        default=None,
        metadata={"help": "Implementation of the cache method for faster generation when use_vllm is set to False."},
    )

    # Parameters that control the training
    learning_rate: float = field(
        default=1e-6,
        metadata={
            "help": "Initial learning rate for `AdamW` optimizer. The default value replaces that of "
            "`transformers.TrainingArguments`."
        },
    )
    beta: float = field(
        default=0.04,
        metadata={
            "help": "KL coefficient. If `0.0`, the reference model is not loaded, reducing memory usage and improving "
            "training speed, but may be numerically unstable for long training runs."
        },
    )
    num_iterations: int = field(
        default=1,
        metadata={"help": "Number of iterations per batch (denoted as μ in the algorithm)."},
    )
    epsilon: float = field(
        default=0.2,
        metadata={"help": "Epsilon value for clipping."},
    )
    epsilon_high: Optional[float] = field(
        default=None,
        metadata={
            "help": "Upper-bound epsilon value for clipping. If not specified, it defaults to the same value as the "
            "lower-bound specified in argument `epsilon`. Paper DAPO recommends `0.28`."
        },
    )
    reward_weights: Optional[list[float]] = field(
        default=None,
        metadata={
            "help": "Weights for each reward function. Must match the number of reward functions. If `None`, all "
            "rewards are weighted equally with weight `1.0`."
        },
    )
    sync_ref_model: bool = field(
        default=False,
        metadata={
            "help": "Whether to synchronize the reference model with the active model every `ref_model_sync_steps` "
            "steps, using the `ref_model_mixup_alpha` parameter."
        },
    )
    ref_model_mixup_alpha: float = field(
        default=0.6,
        metadata={
            "help": "α parameter from the TR-DPO paper, which controls the mix between the current policy and the "
            "previous reference policy during updates. The reference policy is updated according to the equation: "
            "`π_ref = α * π_θ + (1 - α) * π_ref_prev`. To use this parameter, you must set `sync_ref_model=True`."
        },
    )
    ref_model_sync_steps: int = field(
        default=512,
        metadata={
            "help": "τ parameter from the TR-DPO paper, which determines how frequently the current policy is "
            "synchronized with the reference policy. To use this parameter, you must set `sync_ref_model=True`."
        },
    )

    # Parameters that control the logging
    log_completions: bool = field(
        default=True,
        metadata={
            "help": "Whether to log a sample of (prompt, completion) pairs every `logging_steps` steps. If `rich` is "
            "installed, it prints the sample. If `wandb` logging is enabled, it logs it to `wandb`."
        },
    )

    report_to: Union[None, str, list[str]] = field(
        default="wandb", metadata={"help": "The list of integrations to report the results and logs to."}
    )

    logging_first_step: bool = field(default=False, metadata={"help": "Log the first global_step"})
    logging_steps: float = field(
        default=2,
        metadata={
            "help": (
                "Log every X updates steps. Should be an integer or a float in range `[0,1)`. "
                "If smaller than 1, will be interpreted as ratio of total training steps."
            )
        },
    )


    # Parameters that control generation acceleration powered by vLLM
    use_vllm: Optional[bool] = field(
        default=False,
        metadata={
            "help": "Whether to use vLLM for generating completions. If set to `True`, ensure that a GPU is kept "
            "unused for training, as vLLM will require one for generation. vLLM must be installed "
            "(`pip install vllm`)."
        },
    )
    vllm_device: Optional[str] = field(
        default="auto",
        metadata={
            "help": "Device where vLLM generation will run, e.g. 'cuda:1'. If set to 'auto' (default), the system "
            "will automatically select the next available GPU after the last one used for training. This assumes "
            "that training has not already occupied all available GPUs."
        },
    )
    vllm_gpu_memory_utilization: float = field(
        default=0.9,
        metadata={
            "help": "Ratio (between 0 and 1) of GPU memory to reserve for the model weights, activations, and KV "
            "cache on the device dedicated to generation powered by vLLM. Higher values will increase the KV cache "
            "size and thus improve the model's throughput. However, if the value is too high, it may cause "
            "out-of-memory (OOM) errors during initialization."
        },
    )
    vllm_dtype: Optional[str] = field(
        default="auto",
        metadata={
            "help": "Data type to use for vLLM generation. If set to 'auto', the data type will be automatically "
            "determined based on the model configuration. Find the supported values in the vLLM documentation."
        },
    )
    vllm_max_model_len: Optional[int] = field(
        default=None,
        metadata={
            "help": "If set, the `max_model_len` to use for vLLM. This could be useful when running with reduced "
            "`vllm_gpu_memory_utilization`, leading to a reduced KV cache size. If not set, vLLM will use the model "
            "context size, which might be much larger than the KV cache, leading to inefficiencies."
        },
    )
    vllm_enable_prefix_caching: Optional[bool] = field(
        default=True,
        metadata={
            "help": "Whether to enable prefix caching in vLLM. If set to `True` (default), ensure that the model and "
            "the hardware support this feature."
        },
    )
    vllm_guided_decoding_regex: Optional[str] = field(
        default=None,
        metadata={"help": "Regex for vLLM guided decoding. If `None` (default), guided decoding is disabled."},
    )
BioReason_new/bioreason/trainer/grpo_trainer.py ADDED
@@ -0,0 +1,719 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import time
3
+ import textwrap
4
+ import pandas as pd
5
+ from collections import defaultdict
6
+ from typing import Any, Callable, Optional, Union, Sized
7
+
8
+ import torch
9
+ import torch.utils.data
10
+ import transformers
11
+ from datasets import Dataset, IterableDataset
12
+ from packaging import version
13
+ from transformers import (
14
+ AutoModelForCausalLM,
15
+ AutoModelForSequenceClassification,
16
+ AutoTokenizer,
17
+ GenerationConfig,
18
+ PreTrainedModel,
19
+ PreTrainedTokenizerBase,
20
+ Trainer,
21
+ TrainerCallback,
22
+ is_wandb_available,
23
+ )
24
+ from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
25
+ from transformers.utils import is_peft_available
26
+
27
+ from trl.data_utils import apply_chat_template, is_conversational, maybe_apply_chat_template
28
+ from trl.models import create_reference_model, prepare_deepspeed, unwrap_model_for_generation
29
+ from trl.trainer.grpo_config import GRPOConfig
30
+ from trl.trainer.utils import generate_model_card, get_comet_experiment_url
31
+
32
+ from accelerate.utils import is_peft_model, set_seed, gather_object
33
+ import copy
34
+ from torch.utils.data import Sampler
35
+ import warnings
36
+
37
+ if is_peft_available():
38
+ from peft import PeftConfig, get_peft_model, prepare_model_for_kbit_training
39
+
40
+ if is_wandb_available():
41
+ import wandb
42
+
43
+ from bioreason.protein_modules.protein_base_module import ProteinBaseModule
44
+ from bioreason.trainer.protein_grpo_config import ProteinLLMGRPOConfig
45
+
46
# A reward function maps a list of prompts and a list of completions to a list
# of float rewards. It may be given directly as a callable, as an already
# loaded reward model, or as a model ID string that gets loaded as a
# pretrained sequence-classification model.
RewardFunc = Union[Callable[[list, list], list[float]], PreTrainedModel, str]
49
+
50
+
51
class RepeatRandomSampler(Sampler):
    """
    Sampler that repeats the indices of a dataset in a structured manner.

    One shuffled permutation is drawn per epoch, grouped into full batches of
    `batch_size` (a trailing partial batch is dropped). Each batch is then
    emitted `repeat_count` times, and within a batch every index is yielded
    `mini_repeat_count` times consecutively.
    """

    def __init__(
        self,
        data_source: Sized,
        mini_repeat_count: int,
        batch_size: int = 1,
        repeat_count: int = 1,
        seed: Optional[int] = None,
    ):
        self.data_source = data_source
        self.mini_repeat_count = mini_repeat_count
        self.batch_size = batch_size
        self.repeat_count = repeat_count
        self.num_samples = len(data_source)
        self.seed = seed
        self.generator = torch.Generator()
        if seed is not None:
            self.generator.manual_seed(seed)

    def __iter__(self):
        shuffled = torch.randperm(self.num_samples, generator=self.generator).tolist()
        # Only complete batches are kept; the range stop drops any remainder.
        full_batches = [
            shuffled[start : start + self.batch_size]
            for start in range(0, self.num_samples - self.batch_size + 1, self.batch_size)
        ]

        for batch in full_batches:
            for _ in range(self.repeat_count):
                for sample_index in batch:
                    yield from [sample_index] * self.mini_repeat_count

    def __len__(self) -> int:
        # NOTE: counts every sample, including those in a dropped partial batch.
        return self.num_samples * self.mini_repeat_count * self.repeat_count
87
+
88
+
89
+ class ProteinLLMGRPOTrainer(Trainer):
90
+ """
91
+ Trainer for the Group Relative Policy Optimization (GRPO) method adapted for Protein-LLM models.
92
+ """
93
+
94
+ def __init__(
95
+ self,
96
+ model: Union[str, PreTrainedModel],
97
+ reward_funcs: Union[RewardFunc, list[RewardFunc]],
98
+ args: ProteinLLMGRPOConfig = None,
99
+ protein_module: ProteinBaseModule = None,
100
+ train_dataset: Optional[Union[Dataset, IterableDataset]] = None,
101
+ eval_dataset: Optional[Union[Dataset, IterableDataset, dict[str, Union[Dataset, IterableDataset]]]] = None,
102
+ processing_class: Optional[PreTrainedTokenizerBase] = None,
103
+ reward_processing_classes: Optional[Union[PreTrainedTokenizerBase, list[PreTrainedTokenizerBase]]] = None,
104
+ callbacks: Optional[list[TrainerCallback]] = None,
105
+ optimizers: tuple[Optional[torch.optim.Optimizer], Optional[torch.optim.lr_scheduler.LambdaLR]] = (None, None),
106
+ peft_config: Optional["PeftConfig"] = None,
107
+ freeze_protein_modules: Optional[bool] = False,
108
+ attn_implementation: str = "flash_attention_2",
109
+ torch_dtype: str = "bfloat16",
110
+ **kwargs,
111
+ ):
112
+ # Args
113
+ if args is None:
114
+ model_name = model if isinstance(model, str) else model.config._name_or_path
115
+ model_name = model_name.split("/")[-1]
116
+ args = GRPOConfig(f"{model_name}-GRPO")
117
+
118
+ self.protein_module = protein_module
119
+
120
+ # Models
121
+ model_init_kwargs = args.model_init_kwargs or {}
122
+ model_init_kwargs["attn_implementation"] = attn_implementation
123
+ if model_init_kwargs.get("torch_dtype") is None:
124
+ model_init_kwargs["torch_dtype"] = torch_dtype
125
+
126
+ assert not isinstance(model, str), "model must NOT be a string in the current implementation"
127
+
128
+ torch_dtype = model_init_kwargs.get("torch_dtype")
129
+ if isinstance(torch_dtype, torch.dtype) or torch_dtype == "auto" or torch_dtype is None:
130
+ pass
131
+ elif isinstance(torch_dtype, str):
132
+ torch_dtype = getattr(torch, torch_dtype)
133
+ else:
134
+ raise ValueError(
135
+ "Invalid `torch_dtype` passed to `GRPOConfig`. Expected either 'auto' or a string representing "
136
+ f"a `torch.dtype` (e.g., 'float32'), but got {torch_dtype}."
137
+ )
138
+
139
+ model_init_kwargs["use_cache"] = (
140
+ False if args.gradient_checkpointing else model_init_kwargs.get("use_cache")
141
+ )
142
+
143
+ # LoRA
144
+ self.protein_modules_keywords = self.protein_module.get_proteinllm_modules_keywords()
145
+ if peft_config is not None:
146
+ print("Applying LoRA...")
147
+ def find_all_linear_names(model, multimodal_keywords):
148
+ cls = torch.nn.Linear
149
+ lora_module_names = set()
150
+ for name, module in model.named_modules():
151
+ print('name:', name, 'module:', module)
152
+ # LoRA is not applied to the protein modules
153
+ if any(mm_keyword in name for mm_keyword in multimodal_keywords):
154
+ continue
155
+ if isinstance(module, cls):
156
+ lora_module_names.add(name)
157
+ for m in lora_module_names:
158
+ if "embed_tokens" in m:
159
+ lora_module_names.remove(m)
160
+ return list(lora_module_names)
161
+ target_modules = find_all_linear_names(model, self.protein_modules_keywords)
162
+ peft_config.target_modules = target_modules
163
+ model = prepare_model_for_kbit_training(model)
164
+ model = get_peft_model(model, peft_config)
165
+
166
+ # Freeze protein modules
167
+ if freeze_protein_modules:
168
+ print("Freezing protein modules...")
169
+ for p in model.protein_model.parameters():
170
+ p.requires_grad = False
171
+
172
+ # Make projection layer trainable
173
+ for p in model.protein_projection.parameters():
174
+ p.requires_grad = True
175
+
176
+ # Compute the number of trainable parameters
177
+ trainable_params = [p for p in model.parameters() if p.requires_grad]
178
+ total_params = sum(p.numel() for p in trainable_params)
179
+ print(f"Total trainable parameters: {total_params}")
180
+
181
+ # Enable gradient checkpointing if requested
182
+ if args.gradient_checkpointing:
183
+ model = self._enable_gradient_checkpointing(model, args)
184
+
185
+ # Reference model
186
+ self.beta = args.beta
187
+ if self.beta == 0.0:
188
+ self.ref_model = None
189
+ elif is_deepspeed_zero3_enabled():
190
+ self.ref_model = model_cls.from_pretrained(model_id, **model_init_kwargs)
191
+ elif is_peft_model(model):
192
+ self.ref_model = None
193
+ else:
194
+ self.ref_model = create_reference_model(model)
195
+
196
+ # Processing class
197
+ if processing_class is None:
198
+ processing_cls = self.protein_module.get_processing_class()
199
+ processing_class = processing_cls(
200
+ tokenizer=model.text_tokenizer,
201
+ protein_tokenizer=model.protein_tokenizer
202
+ )
203
+
204
+ for component, processing_keyword in self.protein_module.get_custom_processing_keywords():
205
+ if processing_keyword in kwargs:
206
+ processing_component = getattr(processing_class, component, processing_class)
207
+ setattr(processing_component, processing_keyword, kwargs[processing_keyword])
208
+
209
+ if getattr(processing_class, "tokenizer", None) is not None:
210
+ pad_token_id = processing_class.tokenizer.pad_token_id
211
+ processing_class.pad_token_id = pad_token_id
212
+ processing_class.eos_token_id = processing_class.tokenizer.eos_token_id
213
+ else:
214
+ assert isinstance(processing_class, PreTrainedTokenizerBase), "processing_class must be an instance of PreTrainedTokenizerBase if it has no tokenizer attribute"
215
+ pad_token_id = processing_class.pad_token_id
216
+
217
+ self.protein_module.post_model_init(model, processing_class)
218
+ if self.ref_model is not None:
219
+ self.protein_module.post_model_init(self.ref_model, processing_class)
220
+
221
+ # Reward functions
222
+ if not isinstance(reward_funcs, list):
223
+ reward_funcs = [reward_funcs]
224
+ for i, reward_func in enumerate(reward_funcs):
225
+ if isinstance(reward_func, str):
226
+ reward_funcs[i] = AutoModelForSequenceClassification.from_pretrained(
227
+ reward_func, num_labels=1, **model_init_kwargs
228
+ )
229
+ self.reward_funcs = reward_funcs
230
+
231
+ # Reward processing class
232
+ if reward_processing_classes is None:
233
+ reward_processing_classes = [None] * len(reward_funcs)
234
+ elif not isinstance(reward_processing_classes, list):
235
+ reward_processing_classes = [reward_processing_classes]
236
+ else:
237
+ if len(reward_processing_classes) != len(reward_funcs):
238
+ raise ValueError("The number of reward processing classes must match the number of reward functions.")
239
+
240
+ for i, (reward_processing_class, reward_func) in enumerate(zip(reward_processing_classes, reward_funcs)):
241
+ if isinstance(reward_func, PreTrainedModel):
242
+ if reward_processing_class is None:
243
+ reward_processing_class = AutoTokenizer.from_pretrained(reward_func.config._name_or_path)
244
+ if reward_processing_class.pad_token_id is None:
245
+ reward_processing_class.pad_token = reward_processing_class.eos_token
246
+ reward_func.config.pad_token_id = reward_processing_class.pad_token_id
247
+ reward_processing_classes[i] = reward_processing_class
248
+ self.reward_processing_classes = reward_processing_classes
249
+
250
+ # Data collator
251
+ def data_collator(features):
252
+ return features
253
+
254
+ # Training arguments
255
+ self.max_prompt_length = args.max_prompt_length
256
+ self.max_prompt_length = None
257
+ if args.max_prompt_length is not None:
258
+ warnings.warn("Setting max_prompt_length is currently not supported, it has been set to None")
259
+
260
+ self.max_completion_length = args.max_completion_length
261
+ self.num_generations = args.num_generations
262
+ self.generation_config = GenerationConfig(
263
+ max_new_tokens=self.max_completion_length,
264
+ do_sample=True,
265
+ temperature=0.6,
266
+ top_p=0.95,
267
+ top_k=20,
268
+ pad_token_id=pad_token_id,
269
+ )
270
+
271
+ if hasattr(self.protein_module, "get_eos_token_id"):
272
+ self.generation_config.eos_token_id = self.protein_module.get_eos_token_id(processing_class)
273
+
274
+ self.beta = args.beta
275
+ self.epsilon_low = args.epsilon
276
+ self.epsilon_high = args.epsilon_high if args.epsilon_high is not None else args.epsilon
277
+
278
+ # Multi-step
279
+ self.num_iterations = args.num_iterations
280
+ self._step = 0
281
+ self._buffered_inputs = [None] * args.gradient_accumulation_steps
282
+
283
+ # Suppress warnings
284
+ model.warnings_issued["estimate_tokens"] = True
285
+
286
+ # Initialize the metrics
287
+ self._metrics = defaultdict(list)
288
+ self.log_completions = args.log_completions
289
+
290
+ super().__init__(
291
+ model=model,
292
+ args=args,
293
+ data_collator=data_collator,
294
+ train_dataset=train_dataset,
295
+ eval_dataset=eval_dataset,
296
+ processing_class=processing_class,
297
+ callbacks=callbacks,
298
+ optimizers=optimizers,
299
+ )
300
+
301
+ # Check batch size compatibility
302
+ num_processes = self.accelerator.num_processes
303
+ global_batch_size = args.per_device_train_batch_size * num_processes
304
+ possible_values = [n_gen for n_gen in range(2, global_batch_size + 1) if (global_batch_size) % n_gen == 0]
305
+ if self.num_generations not in possible_values:
306
+ raise ValueError(
307
+ f"The global train batch size ({num_processes} x {args.per_device_train_batch_size}) must be evenly "
308
+ f"divisible by the number of generations per prompt ({self.num_generations}). Given the current train "
309
+ f"batch size, the valid values for the number of generations are: {possible_values}."
310
+ )
311
+
312
+ if self.args.eval_strategy != "no":
313
+ global_batch_size = args.per_device_eval_batch_size * num_processes
314
+ possible_values = [n_gen for n_gen in range(2, global_batch_size + 1) if (global_batch_size) % n_gen == 0]
315
+ if self.num_generations not in possible_values:
316
+ raise ValueError(
317
+ f"The global eval batch size ({num_processes} x {args.per_device_eval_batch_size}) must be evenly "
318
+ f"divisible by the number of generations per prompt ({self.num_generations}). Given the current "
319
+ f"eval batch size, the valid values for the number of generations are: {possible_values}."
320
+ )
321
+
322
+ # Set seed for reproducibility
323
+ set_seed(args.seed, device_specific=True)
324
+
325
+ # Gradient accumulation setup
326
+ self.model_accepts_loss_kwargs = False
327
+
328
+ if self.ref_model is not None:
329
+ if is_deepspeed_zero3_enabled():
330
+ self.ref_model = prepare_deepspeed(self.ref_model, self.accelerator)
331
+ else:
332
+ self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
333
+
334
+ for i, reward_func in enumerate(self.reward_funcs):
335
+ if isinstance(reward_func, PreTrainedModel):
336
+ self.reward_funcs[i] = self.accelerator.prepare_model(reward_func, evaluation_mode=True)
337
+
338
def _enable_gradient_checkpointing(self, model: PreTrainedModel, args: ProteinLLMGRPOConfig) -> PreTrainedModel:
    """Enables gradient checkpointing for the model.

    Handles three shapes of model: PEFT-wrapped models, composite
    protein+language models (which expose `language_model` / `protein_model`
    sub-modules), and plain transformers models.
    """
    # Gradient checkpointing is incompatible with the generation KV cache.
    model.config.use_cache = False

    if is_peft_model(model):
        model.base_model.gradient_checkpointing_enable()
    else:
        if getattr(model, "language_model", None) is not None:
            # Composite model: toggle checkpointing on each sub-module.
            model.language_model.config.use_cache = False
            model.protein_model.gradient_checkpointing = True
            # NOTE(review): _set_gradient_checkpointing is a private
            # transformers API — confirm it exists on this model class.
            model.language_model._set_gradient_checkpointing()
            # Prevent the base Trainer from re-enabling checkpointing on the
            # wrapper, which would double-apply it.
            args.gradient_checkpointing = False
        else:
            model.gradient_checkpointing_enable()

    gradient_checkpointing_kwargs = args.gradient_checkpointing_kwargs or {}
    # Reentrant checkpointing is the default when the key is absent.
    use_reentrant = (
        "use_reentrant" not in gradient_checkpointing_kwargs or gradient_checkpointing_kwargs["use_reentrant"]
    )

    if use_reentrant:
        # Reentrant checkpointing needs inputs that require grad.
        model.enable_input_require_grads()

    return model
362
+
363
+ def _set_signature_columns_if_needed(self):
364
+ if self._signature_columns is None:
365
+ self._signature_columns = ["prompt"]
366
+
367
+ def _get_per_token_logps(self, model, input_ids, attention_mask, **custom_multimodal_inputs):
368
+ logits = model(input_ids=input_ids, attention_mask=attention_mask, **custom_multimodal_inputs).logits
369
+ logits = logits[:, :-1, :]
370
+ input_ids = input_ids[:, 1:]
371
+
372
+ # Compute the log probabilities for the input tokens
373
+ per_token_logps = []
374
+ for logits_row, input_ids_row in zip(logits, input_ids):
375
+ log_probs = logits_row.log_softmax(dim=-1)
376
+ token_log_prob = torch.gather(log_probs, dim=1, index=input_ids_row.unsqueeze(1)).squeeze(1)
377
+ per_token_logps.append(token_log_prob)
378
+ return torch.stack(per_token_logps)
379
+
380
+ def _prepare_inputs(self, inputs):
381
+ return inputs
382
+
383
+ def _get_key_from_inputs(self, x, key):
384
+ ele = x.get(key, None)
385
+ assert ele is not None, f"The key {key} is not found in the input"
386
+ if isinstance(ele, list):
387
+ return [e for e in ele]
388
+ else:
389
+ return [ele]
390
+
391
def _generate_and_score_completions(self, inputs: dict[str, Union[torch.Tensor, Any]], model) -> dict[str, Union[torch.Tensor, Any]]:
    """Generate completions for a batch of prompts, score them with the
    reward functions, and compute group-normalized advantages for GRPO.

    Args:
        inputs: list of dataset examples, each with a "prompt" key and
            optionally "protein_sequences" (assumed from the access pattern
            below — TODO confirm against the dataset schema).
        model: the policy model used for generation and log-prob scoring.

    Returns:
        Dict with prompt/completion ids and masks, old/reference per-token
        log-probs, per-sample advantages, and the multimodal inputs needed
        to recompute logits in `compute_loss`.
    """
    device = self.accelerator.device
    prompts = [x["prompt"] for x in inputs]
    prompts_text = self.protein_module.prepare_prompt(self.processing_class, inputs)

    # Handle protein sequences
    batch_protein_sequences = []
    print("_generate_and_score_completions (GRPO):")
    for x in inputs:
        if 'protein_sequences' in x:
            proteins = self._get_key_from_inputs(x, "protein_sequences")
        else:
            proteins = []
        batch_protein_sequences.append(proteins)

    # Tokenize prompts + proteins; left padding so generation appends at the end.
    prompt_inputs = self.protein_module.prepare_model_inputs(
        self.processing_class,
        model,
        prompts_text,
        batch_protein_sequences,
        return_tensors="pt",
        padding=True,
        padding_side="left",
        add_special_tokens=False,
    )

    # Base Trainer's _prepare_inputs moves tensors to the right device.
    prompt_inputs = super()._prepare_inputs(prompt_inputs)
    prompt_ids, prompt_mask = prompt_inputs["input_ids"], prompt_inputs["attention_mask"]

    # Generate completions
    start = time.time()
    with unwrap_model_for_generation(model, self.accelerator) as unwrapped_model:
        # Drop keys that generate() does not accept (module-specific).
        kwargs = {k: v for k, v in prompt_inputs.items() if k not in self.protein_module.get_non_generate_params()}
        generate_returned_result = unwrapped_model.generate(
            **kwargs,
            generation_config=self.generation_config
        )
    end = time.time()
    print(f"Generation time: {end - start:.9f} seconds")
    prompt_length = prompt_ids.size(1)

    # With token-id input, generate() returns prompt+completion; with
    # embedding input it returns only the new completion tokens.
    if not self.protein_module.is_embeds_input():
        prompt_completion_ids = generate_returned_result
        prompt_ids = prompt_completion_ids[:, :prompt_length]
        completion_ids = prompt_completion_ids[:, prompt_length:]
    else:
        completion_ids = generate_returned_result
        prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1)

    # Mask everything after the first EOS token
    is_eos = completion_ids == self.processing_class.eos_token_id
    # Default eos index = sequence length, i.e. "no EOS found".
    eos_idx = torch.full((is_eos.size(0),), is_eos.size(1), dtype=torch.long, device=device)
    eos_idx[is_eos.any(dim=1)] = is_eos.int().argmax(dim=1)[is_eos.any(dim=1)]
    sequence_indices = torch.arange(is_eos.size(1), device=device).expand(is_eos.size(0), -1)
    completion_mask = (sequence_indices <= eos_idx.unsqueeze(1)).int()

    # Concatenate prompt_mask with completion_mask for logit computation
    attention_mask = torch.cat([prompt_mask, completion_mask], dim=1)

    # Get the multimodal inputs
    multimodal_keywords = self.protein_module.get_custom_multimodal_keywords()
    multimodal_inputs = {k: prompt_inputs[k] if k in prompt_inputs else None for k in multimodal_keywords}

    with torch.no_grad():
        # With num_iterations > 1 the same completions are replayed, so the
        # log-probs of the generating policy must be snapshotted now.
        if self.num_iterations > 1:
            old_per_token_logps = self._get_per_token_logps(
                model, prompt_completion_ids, attention_mask, **multimodal_inputs
            )
            old_per_token_logps = old_per_token_logps[:, prompt_length - 1:]
        else:
            old_per_token_logps = None

        # Reference log-probs for the KL penalty (skipped when beta == 0).
        if self.beta == 0.0:
            ref_per_token_logps = None
        elif self.ref_model is not None:
            ref_per_token_logps = self._get_per_token_logps(
                self.ref_model, prompt_completion_ids, attention_mask, **multimodal_inputs
            )
        else:
            # PEFT case: the base model with adapters disabled is the reference.
            with self.accelerator.unwrap_model(model).disable_adapter():
                ref_per_token_logps = self._get_per_token_logps(
                    model, prompt_completion_ids, attention_mask, **multimodal_inputs
                )

    if ref_per_token_logps is not None:
        ref_per_token_logps = ref_per_token_logps[:, prompt_length - 1:]

    # Decode the generated completions
    completions_text = self.processing_class.batch_decode(completion_ids, skip_special_tokens=True)
    if is_conversational(inputs[0]):
        completions = [[{"role": "assistant", "content": completion}] for completion in completions_text]
    else:
        completions = completions_text

    # Compute the rewards
    print("Reward calculation...")
    rewards_per_func = torch.zeros(len(prompts), len(self.reward_funcs), device=device)
    for i, (reward_func, reward_processing_class) in enumerate(
        zip(self.reward_funcs, self.reward_processing_classes)
    ):
        if isinstance(reward_func, PreTrainedModel):
            # Model-based reward: run prompt+completion through the reward model.
            if is_conversational(inputs[0]):
                messages = [{"messages": p + c} for p, c in zip(prompts, completions)]
                texts = [apply_chat_template(x, reward_processing_class)["text"] for x in messages]
            else:
                texts = [p + c for p, c in zip(prompts, completions)]
            reward_inputs = reward_processing_class(
                texts, return_tensors="pt", padding=True, padding_side="right", add_special_tokens=False
            )
            reward_inputs = super()._prepare_inputs(reward_inputs)
            with torch.inference_mode():
                rewards_per_func[:, i] = reward_func(**reward_inputs).logits[:, 0]
        else:
            # Callable reward: forward all extra example fields as kwargs.
            reward_kwargs = {key: [] for key in inputs[0].keys() if key not in ["prompt", "completion"]}
            for key in reward_kwargs:
                for example in inputs:
                    reward_kwargs[key].extend([example[key]])
            output_reward_func = reward_func(prompts=prompts, completions=completions, **reward_kwargs)
            rewards_per_func[:, i] = torch.tensor(output_reward_func, dtype=torch.float32, device=device)

    # Gather rewards across processes
    rewards_per_func = self.accelerator.gather(rewards_per_func)

    # Sum the rewards from all reward functions
    rewards = rewards_per_func.sum(dim=1)

    # Compute grouped-wise rewards (one group = num_generations completions
    # of the same prompt).
    mean_grouped_rewards = rewards.view(-1, self.num_generations).mean(dim=1)
    std_grouped_rewards = rewards.view(-1, self.num_generations).std(dim=1)

    # Normalize the rewards to compute the advantages
    mean_grouped_rewards = mean_grouped_rewards.repeat_interleave(self.num_generations, dim=0)
    std_grouped_rewards = std_grouped_rewards.repeat_interleave(self.num_generations, dim=0)
    # 1e-4 guards against zero std when all rewards in a group are equal.
    advantages = (rewards - mean_grouped_rewards) / (std_grouped_rewards + 1e-4)

    # Get only the local slice of advantages (rewards were gathered globally).
    process_slice = slice(
        self.accelerator.process_index * len(prompts),
        (self.accelerator.process_index + 1) * len(prompts),
    )
    advantages = advantages[process_slice]

    # Log the metrics
    print("Logging metrics...")
    completion_length = self.accelerator.gather_for_metrics(completion_mask.sum(1)).float().mean().item()
    self._metrics["completion_length"].append(completion_length)

    reward_per_func = self.accelerator.gather_for_metrics(rewards_per_func).mean(0)
    for i, reward_func in enumerate(self.reward_funcs):
        if isinstance(reward_func, PreTrainedModel):
            reward_func_name = reward_func.config._name_or_path.split("/")[-1]
        else:
            reward_func_name = reward_func.__name__
        self._metrics[f"rewards/{reward_func_name}"].append(reward_per_func[i].item())

    self._metrics["reward"].append(self.accelerator.gather_for_metrics(rewards).mean().item())
    self._metrics["reward_std"].append(self.accelerator.gather_for_metrics(std_grouped_rewards).mean().item())

    # Optionally log a table of sampled completions to Weights & Biases.
    if (
        self.log_completions
        and self.state.global_step % self.args.logging_steps == 0
        and "wandb" in self.args.report_to
    ):
        timestamp = time.time()
        num_items = len(gather_object(prompts_text))

        table = {
            "step": [f"{self.state.global_step}_{timestamp}"] * num_items,
            "prompt": gather_object(prompts_text),
            "completion": gather_object(completions_text),
            "reward": rewards.tolist(),
        }
        df = pd.DataFrame(table)

        if wandb.run is not None and self.accelerator.is_main_process:
            wandb.log({f"completions_{self.state.global_step}_{timestamp}": wandb.Table(dataframe=df)})

    return {
        "prompt_ids": prompt_ids,
        "prompt_mask": prompt_mask,
        "completion_ids": completion_ids,
        "completion_mask": completion_mask,
        "old_per_token_logps": old_per_token_logps,
        "ref_per_token_logps": ref_per_token_logps,
        "advantages": advantages,
        "multimodal_inputs": multimodal_inputs
    }
578
+
579
def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
    """GRPO clipped-surrogate loss over the completion tokens.

    Completions are regenerated periodically and replayed from a buffer in
    between, so each generation batch can be optimized `num_iterations` times.

    Raises:
        ValueError: if `return_outputs` is requested (unsupported here).
    """
    if return_outputs:
        raise ValueError("The ProteinLLMGRPOTrainer does not support returning outputs")

    # Check if we need to generate new completions or use buffered ones
    # NOTE(review): upstream TRL keys this off self._step rather than
    # global_step — confirm the two stay in sync under gradient accumulation.
    if self.state.global_step % self.num_iterations == 0:
        inputs = self._generate_and_score_completions(inputs, model)
        self._buffered_inputs[self._step % self.args.gradient_accumulation_steps] = inputs
    else:
        inputs = self._buffered_inputs[self._step % self.args.gradient_accumulation_steps]
    self._step += 1

    # Get the prepared inputs
    prompt_ids, prompt_mask = inputs["prompt_ids"], inputs["prompt_mask"]
    completion_ids, completion_mask = inputs["completion_ids"], inputs["completion_mask"]
    multimodal_inputs = inputs["multimodal_inputs"]

    # Concatenate for full sequence
    input_ids = torch.cat([prompt_ids, completion_ids], dim=1)
    attention_mask = torch.cat([prompt_mask, completion_mask], dim=1)

    # Get the current policy's log probabilities (completion tokens only).
    per_token_logps = self._get_per_token_logps(model, input_ids, attention_mask, **multimodal_inputs)
    per_token_logps = per_token_logps[:, prompt_ids.size(1) - 1:]

    # Get the advantages from inputs
    advantages = inputs["advantages"]

    # When using num_iterations == 1, old_per_token_logps == per_token_logps,
    # so the importance ratio is exactly 1 (detach keeps it out of the graph).
    old_per_token_logps = inputs["old_per_token_logps"] if self.num_iterations > 1 else per_token_logps.detach()

    # Compute the policy ratio and clipped version (PPO-style surrogate).
    coef_1 = torch.exp(per_token_logps - old_per_token_logps)
    coef_2 = torch.clamp(coef_1, 1 - self.epsilon_low, 1 + self.epsilon_high)
    per_token_loss1 = coef_1 * advantages.unsqueeze(1)
    per_token_loss2 = coef_2 * advantages.unsqueeze(1)
    # Negated min: maximize the clipped objective.
    per_token_loss = -torch.min(per_token_loss1, per_token_loss2)

    # Add KL penalty if beta > 0
    if self.beta > 0:
        ref_per_token_logps = inputs["ref_per_token_logps"]
        # k3 estimator of KL(policy || reference): exp(d) - d - 1, d = ref - pi.
        per_token_kl = torch.exp(ref_per_token_logps - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1
        per_token_loss = per_token_loss + self.beta * per_token_kl

        # Log KL divergence
        mean_kl = ((per_token_kl * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()
        self._metrics["kl"].append(self.accelerator.gather_for_metrics(mean_kl).mean().item())

    # Compute final loss: per-sequence token mean, then batch mean.
    loss = ((per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()

    # Log clip ratio (fraction of tokens where the clipped branch is active).
    is_clipped = (per_token_loss1 < per_token_loss2).float()
    clip_ratio = (is_clipped * completion_mask).sum() / completion_mask.sum()
    self._metrics["clip_ratio"].append(self.accelerator.gather_for_metrics(clip_ratio).mean().item())

    return loss
636
+
637
def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> None:
    """Merge the averaged GRPO metrics into `logs`, delegate, then reset."""
    averaged = {name: sum(vals) / len(vals) for name, vals in self._metrics.items()}
    merged = {**logs, **averaged}
    # transformers >= 4.47 added the start_time argument to Trainer.log.
    if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
        super().log(merged, start_time)
    else:
        super().log(merged)
    self._metrics.clear()
645
+
646
def create_model_card(
    self,
    model_name: Optional[str] = None,
    dataset_name: Optional[str] = None,
    tags: Union[str, list[str], None] = None,
):
    """
    Creates a draft of a model card using the information available to the `Trainer`.

    Args:
        model_name: display name for the card (optional).
        dataset_name: dataset to credit on the card (optional).
        tags: extra tag or tags to attach (optional).

    Writes README.md into `args.output_dir`; no-op on non-main processes.
    """
    if not self.is_world_process_zero():
        return

    # A hub id (not a local path) in config counts as the base model.
    if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
        base_model = self.model.config._name_or_path
    else:
        base_model = None

    tags = tags or []
    if isinstance(tags, str):
        tags = [tags]

    if hasattr(self.model.config, "unsloth_version"):
        tags.append("unsloth")

    # GRPO originates from the DeepSeekMath paper; cite it on the card.
    citation = textwrap.dedent(
        """\
        @article{zhihong2024deepseekmath,
            title        = {{DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models}},
            author       = {Zhihong Shao and Peiyi Wang and Qihao Zhu and Runxin Xu and Junxiao Song and Mingchuan Zhang and Y. K. Li and Y. Wu and Daya Guo},
            year         = 2024,
            eprint       = {arXiv:2402.03300},
        }
        """
    )

    model_card = generate_model_card(
        base_model=base_model,
        model_name=model_name,
        hub_model_id=self.hub_model_id,
        dataset_name=dataset_name,
        tags=tags,
        wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
        comet_url=get_comet_experiment_url(),
        trainer_name="GRPO",
        trainer_citation=citation,
        paper_title="DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models",
        paper_id="2402.03300",
    )

    model_card.save(os.path.join(self.args.output_dir, "README.md"))
696
+
697
def _get_train_sampler(self) -> Sampler:
    """Build the GRPO train sampler.

    Each unique prompt is repeated `num_generations` times within a batch
    and the whole pass is repeated `num_iterations` times.
    """
    # Samples consumed per optimizer step across all devices.
    samples_per_step = (
        self.args.per_device_train_batch_size
        * self.accelerator.num_processes
        * self.args.gradient_accumulation_steps
    )
    unique_prompts_per_step = samples_per_step // self.num_generations

    return RepeatRandomSampler(
        data_source=self.train_dataset,
        mini_repeat_count=self.num_generations,
        batch_size=unique_prompts_per_step,
        repeat_count=self.num_iterations,
        seed=self.args.seed,
    )
712
+
713
def _get_eval_sampler(self, eval_dataset) -> Sampler:
    """Eval sampler: each prompt repeated `num_generations` times."""
    return RepeatRandomSampler(
        data_source=eval_dataset,
        mini_repeat_count=self.num_generations,
        seed=self.args.seed,
    )
BioReason_new/bioreason/utils/__pycache__/protein_utils.cpython-310.pyc ADDED
Binary file (485 Bytes). View file
 
BioReason_new/bioreason/utils/__pycache__/protein_utils.cpython-311.pyc ADDED
Binary file (724 Bytes). View file
 
BioReason_new/bioreason/utils/protein_utils.py ADDED
@@ -0,0 +1,229 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # from typing import TYPE_CHECKING, Callable, Optional, Union
2
+
3
+ # import numpy as np
4
+
5
+ # from transformers.utils import is_torch_available
6
+
7
+ # if is_torch_available():
8
+ # import torch
9
+
10
+ # ProteinInput = Union[
11
+ # str, list[int], np.ndarray, "torch.Tensor", list[str], list[list[int]], list[np.ndarray], list["torch.Tensor"]
12
+ # ] # noqa
13
+
14
+
15
+ # def clean_protein_sequence(sequence: str) -> str:
16
+ # """
17
+ # Clean protein sequence by removing invalid characters and normalizing.
18
+
19
+ # Args:
20
+ # sequence: Raw protein sequence string
21
+
22
+ # Returns:
23
+ # Cleaned protein sequence
24
+ # """
25
+ # # Remove whitespace and convert to uppercase
26
+ # sequence = sequence.replace(" ", "").replace("\n", "").upper()
27
+
28
+ # # Keep only valid amino acid characters
29
+ # valid_aa = set("ACDEFGHIKLMNPQRSTVWY")
30
+ # cleaned_sequence = "".join(char for char in sequence if char in valid_aa)
31
+
32
+ # return cleaned_sequence
33
+
34
+
35
+ # def truncate_protein_sequence(sequence: str, max_length: int = 1024) -> str:
36
+ # """
37
+ # Truncate protein sequence to maximum length.
38
+
39
+ # Args:
40
+ # sequence: Protein sequence string
41
+ # max_length: Maximum allowed length
42
+
43
+ # Returns:
44
+ # Truncated protein sequence
45
+ # """
46
+ # if len(sequence) <= max_length:
47
+ # return sequence
48
+
49
+ # # Truncate from both ends to keep the middle part (often most important)
50
+ # if max_length >= 100:
51
+ # start_keep = max_length // 3
52
+ # end_keep = max_length - start_keep
53
+ # return sequence[:start_keep] + sequence[-end_keep:]
54
+ # else:
55
+ # # If very short max_length, just truncate from end
56
+ # return sequence[:max_length]
57
+
58
+
59
+ # def validate_protein_sequence(sequence: str) -> bool:
60
+ # """
61
+ # Validate if a sequence contains only valid amino acid characters.
62
+
63
+ # Args:
64
+ # sequence: Protein sequence string
65
+
66
+ # Returns:
67
+ # True if valid, False otherwise
68
+ # """
69
+ # valid_aa = set("ACDEFGHIKLMNPQRSTVWY")
70
+ # return all(char in valid_aa for char in sequence.upper())
71
+
72
+
73
+ # def get_sequence_stats(sequence: str) -> dict:
74
+ # """
75
+ # Get basic statistics about a protein sequence.
76
+
77
+ # Args:
78
+ # sequence: Protein sequence string
79
+
80
+ # Returns:
81
+ # Dictionary with sequence statistics
82
+ # """
83
+ # sequence = sequence.upper()
84
+ # length = len(sequence)
85
+
86
+ # if length == 0:
87
+ # return {"length": 0, "composition": {}, "molecular_weight": 0.0}
88
+
89
+ # # Amino acid composition
90
+ # composition = {}
91
+ # for aa in "ACDEFGHIKLMNPQRSTVWY":
92
+ # count = sequence.count(aa)
93
+ # composition[aa] = {
94
+ # "count": count,
95
+ # "frequency": count / length if length > 0 else 0
96
+ # }
97
+
98
+ # # Approximate molecular weight (Da)
99
+ # aa_weights = {
100
+ # 'A': 89.1, 'C': 121.0, 'D': 133.1, 'E': 147.1, 'F': 165.2,
101
+ # 'G': 75.1, 'H': 155.2, 'I': 131.2, 'K': 146.2, 'L': 131.2,
102
+ # 'M': 149.2, 'N': 132.1, 'P': 115.1, 'Q': 146.2, 'R': 174.2,
103
+ # 'S': 105.1, 'T': 119.1, 'V': 117.1, 'W': 204.2, 'Y': 181.2
104
+ # }
105
+
106
+ # molecular_weight = sum(aa_weights.get(aa, 0) for aa in sequence)
107
+ # # Subtract water molecules for peptide bonds
108
+ # molecular_weight -= (length - 1) * 18.015 if length > 1 else 0
109
+
110
+ # return {
111
+ # "length": length,
112
+ # "composition": composition,
113
+ # "molecular_weight": molecular_weight
114
+ # }
115
+
116
+
117
+ # def format_protein_for_display(sequence: str, line_length: int = 80) -> str:
118
+ # """
119
+ # Format protein sequence for display with line breaks.
120
+
121
+ # Args:
122
+ # sequence: Protein sequence string
123
+ # line_length: Number of characters per line
124
+
125
+ # Returns:
126
+ # Formatted sequence string
127
+ # """
128
+ # if not sequence:
129
+ # return ""
130
+
131
+ # lines = []
132
+ # for i in range(0, len(sequence), line_length):
133
+ # line = sequence[i:i + line_length]
134
+ # # Add position numbers
135
+ # pos_start = i + 1
136
+ # pos_end = min(i + line_length, len(sequence))
137
+ # lines.append(f"{pos_start:>8} {line} {pos_end}")
138
+
139
+ # return "\n".join(lines)
140
+
141
+
142
+ # def compare_protein_sequences(seq1: str, seq2: str) -> dict:
143
+ # """
144
+ # Compare two protein sequences and return similarity metrics.
145
+
146
+ # Args:
147
+ # seq1: First protein sequence
148
+ # seq2: Second protein sequence
149
+
150
+ # Returns:
151
+ # Dictionary with comparison metrics
152
+ # """
153
+ # seq1 = seq1.upper().replace(" ", "")
154
+ # seq2 = seq2.upper().replace(" ", "")
155
+
156
+ # if not seq1 or not seq2:
157
+ # return {"identity": 0.0, "similarity": 0.0, "gaps": 0}
158
+
159
+ # # Simple identity calculation (without proper alignment)
160
+ # min_len = min(len(seq1), len(seq2))
161
+ # max_len = max(len(seq1), len(seq2))
162
+
163
+ # identical = 0
164
+ # for i in range(min_len):
165
+ # if seq1[i] == seq2[i]:
166
+ # identical += 1
167
+
168
+ # identity = identical / max_len if max_len > 0 else 0.0
169
+ # gaps = abs(len(seq1) - len(seq2))
170
+
171
+ # # Simple similarity (identical positions / shorter sequence length)
172
+ # similarity = identical / min_len if min_len > 0 else 0.0
173
+
174
+ # return {
175
+ # "identity": identity,
176
+ # "similarity": similarity,
177
+ # "gaps": gaps,
178
+ # "identical_positions": identical,
179
+ # "seq1_length": len(seq1),
180
+ # "seq2_length": len(seq2)
181
+ # }
182
+
183
+
184
+ # def extract_protein_domains(sequence: str, domain_patterns: dict = None) -> list:
185
+ # """
186
+ # Extract potential protein domains based on simple patterns.
187
+
188
+ # Args:
189
+ # sequence: Protein sequence string
190
+ # domain_patterns: Dictionary of domain name to regex pattern
191
+
192
+ # Returns:
193
+ # List of detected domains
194
+ # """
195
+ # import re
196
+
197
+ # if domain_patterns is None:
198
+ # # Simple example patterns (in real use, you'd use proper domain databases)
199
+ # domain_patterns = {
200
+ # "signal_peptide": r"^M[A-Z]{10,30}[RK]", # Very simple signal peptide pattern
201
+ # "transmembrane": r"[AILMFWYV]{15,25}", # Hydrophobic stretch
202
+ # "nuclear_localization": r"[KR]{2,}[A-Z]{10,20}[KR]{2,}", # Basic NLS pattern
203
+ # }
204
+
205
+ # domains = []
206
+ # for domain_name, pattern in domain_patterns.items():
207
+ # matches = list(re.finditer(pattern, sequence))
208
+ # for match in matches:
209
+ # domains.append({
210
+ # "domain": domain_name,
211
+ # "start": match.start() + 1, # 1-based indexing
212
+ # "end": match.end(),
213
+ # "sequence": match.group()
214
+ # })
215
+
216
+ # return domains
217
+
218
from typing import TYPE_CHECKING, Callable, Optional, Union

import numpy as np

from transformers.utils import is_torch_available

# Import torch only when available so the alias below does not require it.
if is_torch_available():
    import torch

# Accepted forms of a protein input: raw sequence string, token ids, or
# array/tensor encodings — or a batch (list) of any of those.
ProteinInput = Union[
    str, list[int], np.ndarray, "torch.Tensor", list[str], list[list[int]], list[np.ndarray], list["torch.Tensor"]
]  # noqa
BioReason_new/readme.md ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ # 1. 对比学习预训练
2
+ python train_contrastive.py --use_wandb --freeze_protein_model --freeze_text_model
3
+
4
+ # 2. 监督微调
5
+ python train_protein_qwen.py --model_type protein-llm --text_model_finetune True
6
+
7
+ # 3. GRPO训练
8
+ python protein_reason.py --sft_checkpoint ./checkpoints/best_model
BioReason_new/reason.py ADDED
@@ -0,0 +1,520 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import re
3
+ import pathlib
4
+ from argparse import ArgumentParser
5
+ from typing import List, Dict, Optional
6
+ from dataclasses import dataclass, field
7
+
8
+ import torch
9
+ from torch import nn
10
+ import torch.nn.functional as F
11
+ from torch.optim import AdamW
12
+ from torch.utils.data import DataLoader, Dataset
13
+ from transformers import get_cosine_schedule_with_warmup, AutoTokenizer
14
+
15
+ from transformers import (
16
+ AutoTokenizer,
17
+ AutoModelForCausalLM,
18
+ AutoModelForMaskedLM,
19
+ AutoProcessor,
20
+ )
21
+
22
+ from datasets import load_dataset, DatasetDict
23
+
24
+ from peft import get_peft_model, LoraConfig, prepare_model_for_kbit_training
25
+ from transformers import BitsAndBytesConfig
26
+
27
+ import pytorch_lightning as pl
28
+ from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor
29
+ from pytorch_lightning.loggers import WandbLogger
30
+
31
+ from trl import GRPOConfig, GRPOTrainer, ModelConfig, ScriptArguments, TrlParser, get_peft_config
32
+
33
+ # Import BLIP2 modules
34
+ from model.blip2_stage2 import Blip2Stage2
35
+ from blip2_dna_module import Blip2DNAModule
36
+ from blip2_grpo_trainer import Blip2GRPOTrainer
37
+ from bioreason.trainer import DNALLMGRPOConfig
38
+
39
+ # Custom TrainerCallback to override the saving mechanism
40
+ from transformers import TrainerCallback, TrainerState, TrainerControl
41
+ from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
42
+
43
class SaveWithPyTorchCallback(TrainerCallback):
    """Custom callback to save models with PyTorch's native save mechanism instead of safetensors"""

    def on_save(self, args, state, control, **kwargs):
        # Checkpoint directory for the current global step.
        checkpoint_folder = os.path.join(
            args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}"
        )
        os.makedirs(checkpoint_folder, exist_ok=True)

        # Unwrap accelerator/DDP wrappers before saving.
        model = kwargs.get("model")
        unwrapped_model = getattr(model, "module", model)

        # Persist weights with torch.save rather than safetensors.
        checkpoint_path = os.path.join(checkpoint_folder, "pytorch_model.bin")
        torch.save(unwrapped_model.state_dict(), checkpoint_path)

        # For BLIP2 models, also persist the LLM component's config.
        blip2 = getattr(unwrapped_model, "blip2", None)
        llm = getattr(blip2, "llm_model", None)
        if llm is not None:
            if hasattr(llm, "config"):
                llm.config.save_pretrained(checkpoint_folder)
            elif hasattr(llm, "base_model") and hasattr(llm.base_model, "config"):
                llm.base_model.config.save_pretrained(checkpoint_folder)

        # Report what was written, including the LoRA parameter count.
        print(f"Saved model checkpoint to {checkpoint_folder}")
        lora_params = [k for k in unwrapped_model.state_dict().keys() if "lora" in k]
        print(f"Checkpoint contains {len(lora_params)} LoRA parameters")

        # Tell the Trainer its own save path has been handled.
        control.should_save = False
        return control
77
+
78
def extract_xml_answer(text: str) -> str:
    """Return the stripped text after the final </think> tag (or the whole
    text, stripped, when no tag is present)."""
    _, _, tail = text.rpartition("</think>")
    return tail.strip() if tail or "</think>" in text else text.strip()
81
+
82
+ def extract_hash_answer(text: str) -> str | None:
83
+ if "####" not in text:
84
+ return None
85
+ return text.split("####")[1].strip()
86
+
87
def get_kegg_questions() -> Dataset:
    """Load wanglab/kegg and reshape each row into a GRPO chat prompt.

    Each example becomes a user message with two DNA placeholders followed by
    the question text, plus the reference/variant sequences and the answer.
    NOTE(review): load_dataset returns a DatasetDict here, despite the
    Dataset annotation — confirm callers index a split.
    """
    data = load_dataset('wanglab/kegg', 'default')  # type: ignore
    example_dna_sequences = ["ATCTACATGCAT", "CAGCAGCTACAG", "CATCACATCGACATCGAC"]
    num_dna_sequences = 2

    def to_prompt(x):
        content = [{'type': 'dna', 'text': None} for _ in range(num_dna_sequences)]
        content.append({'type': 'text', 'text': x['question']})
        return {
            'prompt': [{'role': 'user', 'content': content}],
            'dna_sequences': [x['reference_sequence'], x['variant_sequence']],
            'answer': x['answer'],
        }

    return data.map(to_prompt)  # type: ignore
107
+
108
def get_gsm8k_questions(question_prompt: str) -> Dataset:
    """Load openai/gsm8k and wrap each row in a DNA+text chat prompt.

    NOTE(review): the text content is a fixed placeholder string and
    `question_prompt` is unused — presumably debug scaffolding; confirm.
    """
    data = load_dataset('openai/gsm8k', 'main')  # type: ignore

    example_dna_sequences = ["ATCTACATGCAT", "CAGCAGCTACAG", "CATCACATCGACATCGAC"]

    def to_prompt(x):
        content = [{'type': 'dna', 'text': None} for _ in range(len(example_dna_sequences))]
        content.append({'type': 'text', 'text': 'Give me a short introduction to large language model.'})
        return {
            'prompt': [{'role': 'user', 'content': content}],
            'dna_sequences': list(example_dna_sequences),
            'answer': extract_hash_answer(x['answer']),
        }

    return data.map(to_prompt)  # type: ignore
127
+
128
+ # Reward functions
129
def format_correct_reward_func(completions, **kwargs) -> list[float]:
    """
    Reward function: checks output formatting.
    Awards 0.5 for a <think>...</think> pair, 0.5 for an <answer>...</answer>
    pair, and a further 0.5 when all four tags appear in the correct order.
    """
    scores = []
    for completion in completions:
        text = completion[0]["content"]
        score = 0.0

        # Thinking section present?
        if "<think>" in text and "</think>" in text:
            score += 0.5

        # Answer section present?
        if "<answer>" in text and "</answer>" in text:
            score += 0.5

        # Bonus when the tags appear in strictly increasing order.
        positions = (
            text.find("<think>"),
            text.find("</think>"),
            text.find("<answer>"),
            text.find("</answer>"),
        )
        t_open, t_close, a_open, a_close = positions
        if -1 not in positions and t_open < t_close < a_open < a_close:
            score += 0.5

        scores.append(score)

    return scores
162
+
163
def accuracy_reward_func(prompts, completions, answer, **kwargs) -> list[float]:
    """
    Reward function: scores answer accuracy.
    1.0 when the reference answer occurs inside the extracted <answer> text,
    0.5 when any reference word occurs, 0.0 otherwise.
    """
    rewards = []
    for idx, completion in enumerate(completions):
        text = completion[0]['content']

        # Prefer the content of the <answer> tag; fall back to the full text.
        match = re.search(r"<answer>(.*?)</answer>", text, re.DOTALL)
        predicted = match.group(1).strip() if match else text.strip()

        # Pick the per-sample reference, falling back to the first / scalar.
        if isinstance(answer, list) and len(answer) > idx:
            reference = str(answer[idx]).strip()
        elif isinstance(answer, list) and len(answer) > 0:
            reference = str(answer[0]).strip()
        else:
            reference = str(answer).strip()

        pred_lower = predicted.lower()
        ref_lower = reference.lower()
        if ref_lower in pred_lower:
            rewards.append(1.0)  # full match
        elif any(word in pred_lower for word in ref_lower.split()):
            rewards.append(0.5)  # partial word overlap
        else:
            rewards.append(0.0)  # no match

    return rewards
195
+
196
def repetition_penalty_reward_func(completions, **kwargs) -> list[float]:
    """Reward low repetition in the answer text (lower repetition is better).

    Combines a word-level and a sentence-level repetition rate; the reward
    is 1.0 for fully unique text and drops to 0.0 once the averaged rate
    reaches 50% (the penalty is doubled to make repetition costly).
    """
    rewards = []
    for completion in completions:
        text = completion[0]["content"]

        # Analyze the <answer> span when present, else the whole response.
        match = re.search(r"<answer>(.*?)</answer>", text, re.DOTALL)
        body = (match.group(1) if match else text).strip()

        tokens = body.lower().split()
        if not tokens:
            rewards.append(0.0)
            continue

        # Fraction of duplicated words.
        word_repetition = 1.0 - (len(set(tokens)) / len(tokens))

        # Fraction of duplicated sentences (split on '.').
        sentences = [part.strip() for part in body.split('.') if part.strip()]
        if len(sentences) > 1:
            sentence_repetition = 1.0 - (len(set(sentences)) / len(sentences))
        else:
            sentence_repetition = 0.0

        combined = (word_repetition + sentence_repetition) / 2

        # Double the penalty so even moderate repetition is clearly punished.
        rewards.append(max(0.0, 1.0 - combined * 2))

    return rewards
239
+
240
def combined_reward_func(prompts, completions, answer,
                         format_weight=0.3, accuracy_weight=0.5, repetition_weight=0.2,
                         **kwargs) -> list[float]:
    """Combined reward: weighted sum of format, accuracy and repetition rewards.

    Args:
        prompts: Prompts forwarded to the accuracy reward.
        completions: Model completions, one list of messages per sample.
        answer: Ground-truth answer(s) for the accuracy reward.
        format_weight / accuracy_weight / repetition_weight: Mixing weights;
            renormalized to sum to 1 when they do not already.

    Returns:
        One combined reward per completion.
    """
    format_rewards = format_correct_reward_func(completions, **kwargs)
    accuracy_rewards = accuracy_reward_func(prompts, completions, answer, **kwargs)
    repetition_rewards = repetition_penalty_reward_func(completions, **kwargs)

    # Renormalize so the weights sum to 1. Compare with a tolerance rather
    # than `total_weight != 1.0`: sums like 0.1 + 0.2 + 0.7 differ from 1.0
    # only by float rounding and must not trigger a spurious renormalization
    # (and the accompanying log line) on every call.
    total_weight = format_weight + accuracy_weight + repetition_weight
    if abs(total_weight - 1.0) > 1e-9:
        format_weight /= total_weight
        accuracy_weight /= total_weight
        repetition_weight /= total_weight
        print(f"Normalized weights - Format: {format_weight:.3f}, Accuracy: {accuracy_weight:.3f}, Repetition: {repetition_weight:.3f}")

    combined_rewards = []
    for f_reward, a_reward, r_reward in zip(format_rewards, accuracy_rewards, repetition_rewards):
        combined = (format_weight * f_reward +
                    accuracy_weight * a_reward +
                    repetition_weight * r_reward)
        combined_rewards.append(combined)

    return combined_rewards
266
+
267
# Keep some of the original reward functions as alternatives
268
def less_than_4_reward_func(completions, **kwargs) -> list[float]:
    """Give 0.5 when the extracted answer is at most four space-separated words."""
    answers = [extract_xml_answer(completion[0]['content']) for completion in completions]
    return [0.5 if len(ans.split(' ')) <= 4 else 0.0 for ans in answers]
272
+
273
def strict_format_reward_func(completions, **kwargs) -> list[float]:
    """Reward completions that exactly follow the strict layout.

    Expected shape of the whole completion: "<think>\\n...\\n</think>\\n...\\n".
    Returns 0.5 for a match and 0.0 otherwise.
    """
    # re.DOTALL lets ".*?" span newlines; without it, a reasoning block that
    # contains more than one line inside <think>...</think> could never match,
    # so the reward was effectively always 0.0 for multi-line reasoning.
    pattern = r"^<think>\n.*?\n</think>\n.*?\n$"
    responses = [completion[0]["content"] for completion in completions]
    matches = [re.match(pattern, r, re.DOTALL) for r in responses]
    return [0.5 if match else 0.0 for match in matches]
279
+
280
def xmlcount_reward_func(completions, **kwargs) -> list[float]:
    """Score each completion with the tag-counting heuristic in count_xml."""
    return [count_xml(completion[0]["content"]) for completion in completions]
283
+
284
def count_xml(text) -> float:
    """Return a small reward for a single, well-placed <think> tag pair.

    Each of "<think>\\n" and "\\n</think>\\n" contributes 0.125 when it
    appears exactly once, for a maximum of 0.25.
    """
    reward = 0.0
    for marker in ("<think>\n", "\n</think>\n"):
        if text.count(marker) == 1:
            reward += 0.125
    return reward
291
+
292
@dataclass
class Blip2ModelConfig(ModelConfig):
    """Model configuration for BLIP2-style DNA/protein-LLM GRPO training.

    Extends the base ModelConfig with BLIP2 architecture options, LoRA
    hyperparameters and checkpoint-loading settings used by this script.
    """
    # BLIP2 specific configuration
    model_name_or_path: str = field(default="blip2-model", metadata={"help": "Model checkpoint for weights initialization."})

    # BLIP2 Architecture parameters
    bert_name: str = field(default="/path/to/bert", metadata={"help": "BERT model for Q-former"})
    num_query_token: int = field(default=32, metadata={"help": "Number of query tokens"})
    cross_attention_freq: int = field(default=2, metadata={"help": "Cross attention frequency"})
    plm_model: str = field(default="facebook/esm2_t30_150M_UR50D", metadata={"help": "Protein language model"})
    plm_tune: str = field(default="freeze", metadata={"help": "PLM tuning strategy"})
    llm_name: str = field(default="facebook/galactica-1.3b", metadata={"help": "Language model name"})
    llm_tune: str = field(default="lora", metadata={"help": "LLM tuning strategy"})
    qformer_tune: str = field(default="train", metadata={"help": "Q-former tuning strategy"})
    peft_dir: str = field(default="", metadata={"help": "PEFT directory"})

    # LoRA parameters
    lora_r: int = field(default=8, metadata={"help": "LoRA rank"})
    lora_alpha: int = field(default=16, metadata={"help": "LoRA alpha"})
    lora_dropout: float = field(default=0.1, metadata={"help": "LoRA dropout"})

    # Training parameters
    # NOTE(review): "enbale" is a typo for "enable", but the name is also read
    # by create_blip2_args_from_config(), so any rename must touch both places.
    enbale_gradient_checkpointing: bool = field(default=False, metadata={"help": "Enable gradient checkpointing"})
    enable_flash: bool = field(default=False, metadata={"help": "Enable flash attention"})

    # Other parameters
    # NOTE(review): the next two fields are annotated str but default to None —
    # effectively Optional[str].
    cache_dir: str = field(default=None, metadata={"help": "Path to model cache directory."})
    sft_checkpoint: str = field(default=None, metadata={"help": "Path to the checkpoint for SFT."})
    freeze_dna_modules: bool = field(default=False, metadata={"help": "Freeze DNA/protein modules"})
321
+
322
@dataclass
class GRPOScriptArguments(ScriptArguments):
    """
    Script arguments for the GRPO training script with BLIP2.
    """
    dataset_name: str = field(default="wanglab/kegg", metadata={"help": "Dataset name with default."})
    data_file_paths: str = field(
        default=None,
        metadata={"help": "Paths to data files, separated by ':'"},
    )
    arrow_cache_dir: str = field(
        default=None,
        metadata={"help": "Path to arrow cache directory"},
    )
    val_split_ratio: float = field(
        default=0.0,
        metadata={"help": "Ratio of validation split, default 0.0"},
    )
    reward_funcs: list[str] = field(
        # Option 1: use the combined reward function (recommended)
        default_factory=lambda: ["combined"],

        # Option 2: use the three separate reward functions
        # default_factory=lambda: ["format_correct", "accuracy", "repetition_penalty"],

        # Option 3: custom combination
        # default_factory=lambda: ["format_correct", "accuracy", "repetition_penalty", "xmlcount"],

        metadata={"help": "List of reward functions. Available: 'combined', 'format_correct', 'accuracy', 'repetition_penalty', 'xmlcount', 'strict_format', 'less_than_4'"},
    )

    # Reward-function weight configuration (consumed by the "combined" reward)
    format_weight: float = field(
        default=0.3,
        metadata={"help": "Weight for format correctness reward (used in combined reward)"}
    )
    accuracy_weight: float = field(
        default=0.5,
        metadata={"help": "Weight for accuracy reward (used in combined reward)"}
    )
    repetition_weight: float = field(
        default=0.2,
        metadata={"help": "Weight for repetition penalty reward (used in combined reward)"}
    )
366
+
367
# Registry mapping reward-function names (as selected via --reward_funcs)
# to their implementations.
reward_funcs_registry = {
    # New three-in-one reward function
    "combined": combined_reward_func,  # format + accuracy + repetition combined

    # Individual reward functions
    "format_correct": format_correct_reward_func,  # format correctness
    "accuracy": accuracy_reward_func,  # answer accuracy
    "repetition_penalty": repetition_penalty_reward_func,  # repetition penalty

    # Original reward functions (kept as alternatives)
    "xmlcount": xmlcount_reward_func,
    "strict_format": strict_format_reward_func,
    "less_than_4": less_than_4_reward_func,
}
381
+
382
def get_vlm_module(model_name_or_path):
    """Return the VLM module class to use.

    The model path is accepted for interface compatibility but ignored:
    this implementation always uses the BLIP2 DNA module.
    """
    return Blip2DNAModule
385
+
386
def create_blip2_args_from_config(model_args):
    """Collect the BLIP2 constructor arguments from a model config.

    Every key is read from the attribute of the same name on *model_args*,
    producing the plain dict that the BLIP2 model expects.
    """
    arg_names = (
        'bert_name',
        'num_query_token',
        'cross_attention_freq',
        'plm_model',
        'plm_tune',
        'llm_name',
        'llm_tune',
        'qformer_tune',
        'peft_dir',
        'lora_r',
        'lora_alpha',
        'lora_dropout',
        'enbale_gradient_checkpointing',
        'enable_flash',
    )
    return {name: getattr(model_args, name) for name in arg_names}
406
+
407
def _prep_for_training(model, training_args):
    """Build the LoRA configuration used to prepare the BLIP2 model.

    The BLIP2 model handles its own LoRA wiring; this helper only assembles
    the LoraConfig from the training arguments. *model* is accepted for
    interface compatibility and is not modified here.
    """
    return LoraConfig(
        r=training_args.lora_r,
        lora_alpha=training_args.lora_alpha,
        lora_dropout=training_args.lora_dropout,
        target_modules=["q_proj", "k_proj", "v_proj", "out_proj", "fc1", "fc2"],
        init_lora_weights="gaussian",
        bias="none",
        task_type="CAUSAL_LM",
    )
427
+
428
def main(script_args, training_args, model_args):
    """Run GRPO training for the BLIP2 DNA/protein-LLM.

    Builds the Blip2Stage2 model from *model_args*, optionally restores an
    SFT checkpoint, assembles the requested reward functions (with weights
    from *script_args*) and launches the Blip2GRPOTrainer on the dataset
    returned by get_kegg_questions().
    """
    print(training_args.output_dir)
    torch.cuda.empty_cache()
    torch.set_float32_matmul_precision("medium")

    # Create BLIP2 model
    blip2_args = create_blip2_args_from_config(model_args)
    model = Blip2Stage2(blip2_args)

    # Load checkpoint if specified
    if model_args.sft_checkpoint is not None:
        print(f"Loading SFT checkpoint from {model_args.sft_checkpoint}")

        if os.path.isdir(model_args.sft_checkpoint):
            # Load Lightning checkpoint — a directory is assumed to contain "last.ckpt"
            checkpoint = torch.load(os.path.join(model_args.sft_checkpoint, "last.ckpt"), map_location='cpu')
            model.load_state_dict(checkpoint['state_dict'], strict=False)
            print("Loaded Lightning checkpoint")
        else:
            # Load PyTorch state dict
            checkpoint = torch.load(model_args.sft_checkpoint, map_location='cpu')

            if "state_dict" in checkpoint:
                state_dict = checkpoint["state_dict"]
            else:
                state_dict = checkpoint

            # Remove module prefix if present (e.g. checkpoints saved from DataParallel/DDP)
            state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}

            result = model.load_state_dict(state_dict, strict=False)
            print(f"Loaded checkpoint with {len(result.missing_keys)} missing keys and {len(result.unexpected_keys)} unexpected keys")

    # Get reward functions with weights
    reward_funcs = []
    for func_name in script_args.reward_funcs:
        if func_name == "combined":
            # Pass the configured weights through to the combined reward function
            def weighted_combined_reward(prompts, completions, answer, **kwargs):
                return combined_reward_func(
                    prompts, completions, answer,
                    format_weight=script_args.format_weight,
                    accuracy_weight=script_args.accuracy_weight,
                    repetition_weight=script_args.repetition_weight,
                    **kwargs
                )
            reward_funcs.append(weighted_combined_reward)
        else:
            reward_funcs.append(reward_funcs_registry[func_name])

    print("reward_funcs:", [func.__name__ if hasattr(func, '__name__') else 'weighted_combined_reward' for func in reward_funcs])
    print(f"Reward weights - Format: {script_args.format_weight}, Accuracy: {script_args.accuracy_weight}, Repetition: {script_args.repetition_weight}")

    vlm_module_cls = get_vlm_module(model_args.model_name_or_path)
    print("using vlm module:", vlm_module_cls.__name__)
    question_prompt = vlm_module_cls.get_question_template()

    # Load dataset
    dataset = get_kegg_questions()
    print(dataset)

    # Custom callback to handle saving with PyTorch's native mechanism
    custom_save_callback = SaveWithPyTorchCallback()

    # Initialize the BLIP2 GRPO trainer
    trainer = Blip2GRPOTrainer(
        model=model,
        reward_funcs=reward_funcs,
        args=training_args,
        dna_module=vlm_module_cls(),
        train_dataset=dataset['train'],
        eval_dataset=dataset['val'] if training_args.eval_strategy != "no" else None,
        peft_config=get_peft_config(model_args),
        attn_implementation=getattr(model_args, 'attn_implementation', 'flash_attention_2'),
        torch_dtype=getattr(model_args, 'torch_dtype', 'bfloat16'),
        callbacks=[custom_save_callback],
    )

    # Set the trainer to save in PyTorch format instead of safetensors.
    # NOTE(review): this is set after the trainer was constructed — confirm the
    # trainer reads the live args object rather than a copy made at init time.
    training_args.save_safetensors = False

    # Train the model
    trainer.train()
511
+
512
if __name__ == "__main__":
    # Log GPU visibility up front to make multi-GPU misconfiguration obvious.
    print(f"CUDA_VISIBLE_DEVICES: {os.environ.get('CUDA_VISIBLE_DEVICES')}")
    # Parse script-, GRPO-training- and model-level arguments from CLI/config.
    parser = TrlParser((GRPOScriptArguments, DNALLMGRPOConfig, Blip2ModelConfig))
    script_args, training_args, model_args = parser.parse_args_and_config()

    # Ensure we use PyTorch's save mechanism instead of safetensors
    training_args.save_safetensors = False

    main(script_args, training_args, model_args)
BioReason_new/run.sh ADDED
@@ -0,0 +1,107 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/bin/bash

# Example training scripts for Protein-LLM project
# NOTE(review): there is no `set -e`, so a failing stage does not stop the
# later stages from running.

# =============================================================================
# 1. Contrastive Pre-training (Stage 1)
# Train QFormer projection layer for protein-text alignment
# =============================================================================
echo "Starting contrastive pre-training..."

python train_contrastive.py \
    --text_model_name "Qwen/Qwen3-1.7B" \
    --protein_model_name "facebook/esm2_t6_8M_UR50D" \
    --qformer_model_name "microsoft/BiomedNLP-BiomedBERT-base-uncased-abstract-fulltext" \
    --dataset_name "wanglab/protein_descriptions" \
    --output_dir "./contrastive_outputs" \
    --num_epochs 10 \
    --batch_size 32 \
    --learning_rate 1e-4 \
    --temperature 0.07 \
    --freeze_protein_model \
    --freeze_text_model \
    --max_length_protein 1024 \
    --max_length_text 512 \
    --eval_dataset \
    --use_wandb \
    --wandb_project "protein-llm-contrastive" \
    --logging_steps 100 \
    --eval_steps 500 \
    --save_steps 1000

echo "Contrastive pre-training completed!"

# =============================================================================
# 2. Supervised Fine-tuning (Stage 2)
# Fine-tune the entire model on protein function prediction tasks
# =============================================================================
echo "Starting supervised fine-tuning..."

python train_protein_qwen.py \
    --model_type "protein-llm" \
    --text_model_name "Qwen/Qwen3-1.7B" \
    --protein_model_name "facebook/esm2_t6_8M_UR50D" \
    --qformer_model_name "microsoft/BiomedNLP-BiomedBERT-base-uncased-abstract-fulltext" \
    --dataset_type "protein_function" \
    --protein_function_data_dir_huggingface "wanglab/protein_function" \
    --text_model_finetune True \
    --protein_model_finetune False \
    --num_query_tokens 32 \
    --seed 23 \
    --batch_size 4 \
    --max_epochs 5 \
    --learning_rate 5e-5 \
    --weight_decay 0.01 \
    --gradient_accumulation_steps 8 \
    --max_length_protein 1024 \
    --max_length_text 1024 \
    --lora_rank 32 \
    --lora_alpha 64 \
    --lora_dropout 0.05 \
    --num_gpus 1 \
    --strategy "ddp" \
    --wandb_project "esm2-qwen3-1.7b-finetune" \
    --checkpoint_dir "./checkpoints" \
    --log_dir "./logs" \
    --cache_dir "/model-weights"

echo "Supervised fine-tuning completed!"

# =============================================================================
# 3. GRPO Training (Stage 3)
# Reinforcement learning with Group Relative Policy Optimization
# =============================================================================
echo "Starting GRPO training..."

python protein_reason.py \
    --output_dir "./grpo_outputs" \
    --model_name_or_path "Qwen/Qwen3-0.6B" \
    --protein_model_name_or_path "facebook/esm2_t6_8M_UR50D" \
    --qformer_model_name_or_path "microsoft/BiomedNLP-BiomedBERT-base-uncased-abstract-fulltext" \
    --dataset_name "wanglab/protein_function" \
    --sft_checkpoint "./checkpoints/best_model" \
    --per_device_train_batch_size 4 \
    --gradient_accumulation_steps 4 \
    --num_train_epochs 3 \
    --learning_rate 1e-6 \
    --beta 0.04 \
    --temperature 0.6 \
    --top_p 0.95 \
    --top_k 20 \
    --max_completion_length 800 \
    --num_generations 8 \
    --reward_funcs "xmlcount" "soft_format" "strict_format" "correctness" \
    --lora_r 32 \
    --lora_alpha 64 \
    --lora_dropout 0.05 \
    --freeze_protein_modules \
    --logging_steps 2 \
    --eval_strategy "steps" \
    --eval_steps 100 \
    --save_steps 200 \
    --report_to "wandb" \
    --log_completions

echo "GRPO training completed!"

echo "All training stages completed successfully!"
BioReason_new/run_contrast.sh ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
echo "Starting contrastive pre-training..."
export WANDB_BASE_URL=https://api.bandw.top

# Select the GPUs to use (here all eight cards, 0-7)
export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
python train_contrastive.py \
    --text_model_name "/oss/wangyujia/BIO/construction_finetuning/alpaca/v1-20250609-141541/checkpoint-50-merged" \
    --protein_model_name "/nas/shared/kilab/wangyujia/ProtT3/plm_model/esm2-150m" \
    --qformer_model_name "/nas/shared/kilab/wangyujia/ProtT3/plm_model/microsoft" \
    --num_query_tokens 8 \
    --train_dataset "/nas/shared/kilab/wangyujia/ProtT3/data/SwissProtV3/train_set.jsonl" \
    --valid_dataset "/nas/shared/kilab/wangyujia/ProtT3/data/SwissProtV3/valid_set.jsonl" \
    --output_dir "./contrastive_outputs" \
    --num_epochs 10 \
    --batch_size 8 \
    --learning_rate 1e-4 \
    --temperature 0.07 \
    --freeze_protein_model \
    --freeze_text_model \
    --enable_ptm \
    --max_length_protein 1024 \
    --max_length_text 512 \
    --num_workers 8 \
    --eval_dataset \
    --use_wandb \
    --wandb_project "protein-llm-contrastive" \
    --logging_steps 100 \
    --eval_steps 500 \
    --save_steps 1000

echo "Contrastive pre-training completed!"
BioReason_new/train_contrastive.py ADDED
@@ -0,0 +1,552 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # import os
2
+ # import time
3
+ # from argparse import ArgumentParser
4
+ # from functools import partial
5
+
6
+ # import torch
7
+ # import wandb
8
+ # from datasets import load_dataset
9
+ # from torch.utils.data import DataLoader
10
+
11
+ # import pytorch_lightning as pl
12
+ # from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
13
+ # from pytorch_lightning.loggers import WandbLogger
14
+ # from pytorch_lightning.strategies import DeepSpeedStrategy
15
+
16
+ # from bioreason.models.protein_llm import ProteinLLMModel
17
+ # from bioreason.models.contrast_trainer import (
18
+ # ContrastiveTrainer,
19
+ # ContrastiveTrainingArguments,
20
+ # protein_text_collate_fn,
21
+ # )
22
+ # from bioreason.dataset.protein import format_protein_contrastive
23
+
24
+
25
+ # def main(args):
26
+ # """
27
+ # Main function for contrastive pre-training of Protein-LLM.
28
+
29
+ # This script trains the QFormer projection layer to align protein and text representations
30
+ # using contrastive learning before fine-tuning on downstream tasks.
31
+ # """
32
+
33
+ # # Set random seed
34
+ # pl.seed_everything(args.seed)
35
+ # torch.cuda.empty_cache()
36
+ # torch.set_float32_matmul_precision("medium")
37
+
38
+ # # Initialize wandb
39
+ # if args.use_wandb:
40
+ # wandb.init(
41
+ # project=args.wandb_project,
42
+ # entity=args.wandb_entity,
43
+ # name=f"contrastive-{args.text_model_name.split('/')[-1]}-{time.strftime('%Y%m%d-%H%M%S')}",
44
+ # config=vars(args)
45
+ # )
46
+
47
+ # print("Loading model...")
48
+ # # Load the Protein-LLM model
49
+ # model = ProteinLLMModel(
50
+ # text_model_name=args.text_model_name,
51
+ # protein_model_name=args.protein_model_name,
52
+ # qformer_model_name=args.qformer_model_name,
53
+ # cache_dir=args.cache_dir,
54
+ # max_length_protein=args.max_length_protein,
55
+ # max_length_text=args.max_length_text,
56
+ # text_model_finetune=False, # Don't fine-tune during contrastive learning
57
+ # protein_model_finetune=False, # Don't fine-tune during contrastive learning
58
+ # num_query_tokens=args.num_query_tokens,
59
+ # )
60
+
61
+ # print("Loading datasets...")
62
+ # # Load datasets for contrastive learning
63
+ # train_dataset = load_dataset(args.dataset_name, split="train")
64
+ # eval_dataset = load_dataset(args.dataset_name, split="validation") if args.eval_dataset else None
65
+
66
+ # # Format datasets for contrastive learning
67
+ # train_dataset = train_dataset.map(format_protein_contrastive)
68
+ # if eval_dataset:
69
+ # eval_dataset = eval_dataset.map(format_protein_contrastive)
70
+
71
+ # # Filter out examples without protein sequences or descriptions
72
+ # train_dataset = train_dataset.filter(
73
+ # lambda x: x["protein_sequence"] and x["text_description"]
74
+ # and len(x["protein_sequence"].strip()) > 0
75
+ # and len(x["text_description"].strip()) > 0
76
+ # )
77
+
78
+ # if eval_dataset:
79
+ # eval_dataset = eval_dataset.filter(
80
+ # lambda x: x["protein_sequence"] and x["text_description"]
81
+ # and len(x["protein_sequence"].strip()) > 0
82
+ # and len(x["text_description"].strip()) > 0
83
+ # )
84
+
85
+ # print(f"Training dataset size: {len(train_dataset)}")
86
+ # if eval_dataset:
87
+ # print(f"Eval dataset size: {len(eval_dataset)}")
88
+
89
+ # # Setup training arguments for contrastive learning
90
+ # training_args = ContrastiveTrainingArguments(
91
+ # output_dir=args.output_dir,
92
+ # num_train_epochs=args.num_epochs,
93
+ # per_device_train_batch_size=args.batch_size,
94
+ # per_device_eval_batch_size=args.batch_size,
95
+ # learning_rate=args.learning_rate,
96
+ # weight_decay=args.weight_decay,
97
+ # temperature=args.temperature,
98
+ # freeze_protein_model=args.freeze_protein_model,
99
+ # freeze_text_model=args.freeze_text_model,
100
+ # protein_weight=args.protein_weight,
101
+ # text_weight=args.text_weight,
102
+ # max_length_protein=args.max_length_protein,
103
+ # max_length_text=args.max_length_text,
104
+ # logging_steps=args.logging_steps,
105
+ # evaluation_strategy="steps" if eval_dataset else "no",
106
+ # eval_steps=args.eval_steps if eval_dataset else None,
107
+ # save_steps=args.save_steps,
108
+ # save_total_limit=args.save_total_limit,
109
+ # load_best_model_at_end=True if eval_dataset else False,
110
+ # metric_for_best_model="eval_avg_recall_at_1" if eval_dataset else None,
111
+ # greater_is_better=True,
112
+ # report_to=["wandb"] if args.use_wandb else [],
113
+ # warmup_steps=args.warmup_steps,
114
+ # gradient_accumulation_steps=args.gradient_accumulation_steps,
115
+ # fp16=args.fp16,
116
+ # bf16=args.bf16,
117
+ # dataloader_num_workers=args.num_workers,
118
+ # remove_unused_columns=False,
119
+ # seed=args.seed,
120
+ # )
121
+
122
+ # print("Initializing trainer...")
123
+ # # Initialize the contrastive trainer
124
+ # trainer = ContrastiveTrainer(
125
+ # model=model,
126
+ # args=training_args,
127
+ # train_dataset=train_dataset,
128
+ # eval_dataset=eval_dataset,
129
+ # data_collator=protein_text_collate_fn,
130
+ # )
131
+
132
+ # print("Starting contrastive training...")
133
+ # # Train the model
134
+ # trainer.train()
135
+
136
+ # print("Saving final model...")
137
+ # # Save the final model
138
+ # trainer.save_model()
139
+
140
+ # # Save only the projection layer weights for later use
141
+ # projection_path = os.path.join(args.output_dir, "protein_projection.pt")
142
+ # torch.save(model.protein_projection.state_dict(), projection_path)
143
+ # print(f"Saved protein projection weights to: {projection_path}")
144
+
145
+ # # Final evaluation
146
+ # if eval_dataset:
147
+ # print("Running final evaluation...")
148
+ # eval_results = trainer.evaluate()
149
+ # print(f"Final evaluation results: {eval_results}")
150
+
151
+ # if args.use_wandb:
152
+ # wandb.log({"final_eval": eval_results})
153
+
154
+ # print("Contrastive training completed!")
155
+
156
+ # if args.use_wandb:
157
+ # wandb.finish()
158
+
159
+ # return trainer
160
+
161
+
162
+ # if __name__ == "__main__":
163
+ # parser = ArgumentParser(description="Contrastive pre-training for Protein-LLM")
164
+
165
+ # # Model configuration
166
+ # parser.add_argument("--text_model_name", type=str, default="Qwen/Qwen3-1.7B",
167
+ # help="Name or path to the text model")
168
+ # parser.add_argument("--protein_model_name", type=str, default="facebook/esm2_t6_8M_UR50D",
169
+ # help="Name or path to the protein model")
170
+ # parser.add_argument("--qformer_model_name", type=str,
171
+ # default="microsoft/BiomedNLP-BiomedBERT-base-uncased-abstract-fulltext",
172
+ # help="Name or path to the QFormer model")
173
+ # parser.add_argument("--cache_dir", type=str, default="/model-weights",
174
+ # help="Directory to cache downloaded models")
175
+ # parser.add_argument("--num_query_tokens", type=int, default=32,
176
+ # help="Number of query tokens in QFormer")
177
+
178
+ # # Dataset configuration
179
+ # parser.add_argument("--dataset_name", type=str, default="wanglab/protein_descriptions",
180
+ # help="Name of the dataset for contrastive learning")
181
+ # parser.add_argument("--eval_dataset", action="store_true",
182
+ # help="Whether to use evaluation dataset")
183
+
184
+ # # Training configuration
185
+ # parser.add_argument("--output_dir", type=str, default="./contrastive_outputs",
186
+ # help="Output directory for model and logs")
187
+ # parser.add_argument("--num_epochs", type=int, default=10,
188
+ # help="Number of training epochs")
189
+ # parser.add_argument("--batch_size", type=int, default=32,
190
+ # help="Batch size per device")
191
+ # parser.add_argument("--learning_rate", type=float, default=1e-4,
192
+ # help="Learning rate")
193
+ # parser.add_argument("--weight_decay", type=float, default=0.01,
194
+ # help="Weight decay")
195
+ # parser.add_argument("--warmup_steps", type=int, default=1000,
196
+ # help="Number of warmup steps")
197
+ # parser.add_argument("--gradient_accumulation_steps", type=int, default=1,
198
+ # help="Gradient accumulation steps")
199
+
200
+ # # Contrastive learning specific
201
+ # parser.add_argument("--temperature", type=float, default=0.07,
202
+ # help="Temperature for contrastive loss")
203
+ # parser.add_argument("--freeze_protein_model", action="store_true", default=True,
204
+ # help="Freeze protein model during training")
205
+ # parser.add_argument("--freeze_text_model", action="store_true", default=True,
206
+ # help="Freeze text model during training")
207
+ # parser.add_argument("--protein_weight", type=float, default=1.0,
208
+ # help="Weight for protein features in contrastive loss")
209
+ # parser.add_argument("--text_weight", type=float, default=1.0,
210
+ # help="Weight for text features in contrastive loss")
211
+
212
+ # # Data configuration
213
+ # parser.add_argument("--max_length_protein", type=int, default=1024,
214
+ # help="Maximum length for protein sequences")
215
+ # parser.add_argument("--max_length_text", type=int, default=512,
216
+ # help="Maximum length for text sequences")
217
+ # parser.add_argument("--num_workers", type=int, default=4,
218
+ # help="Number of data loading workers")
219
+
220
+ # # Logging and evaluation
221
+ # parser.add_argument("--logging_steps", type=int, default=100,
222
+ # help="Number of steps between logging")
223
+ # parser.add_argument("--eval_steps", type=int, default=500,
224
+ # help="Number of steps between evaluations")
225
+ # parser.add_argument("--save_steps", type=int, default=1000,
226
+ # help="Number of steps between saving checkpoints")
227
+ # parser.add_argument("--save_total_limit", type=int, default=3,
228
+ # help="Maximum number of checkpoints to keep")
229
+
230
+ # # Hardware configuration
231
+ # parser.add_argument("--fp16", action="store_true",
232
+ # help="Use FP16 precision")
233
+ # parser.add_argument("--bf16", action="store_true",
234
+ # help="Use BF16 precision")
235
+ # parser.add_argument("--seed", type=int, default=42,
236
+ # help="Random seed")
237
+
238
+ # # Wandb logging
239
+ # parser.add_argument("--use_wandb", action="store_true",
240
+ # help="Use Weights & Biases for logging")
241
+ # parser.add_argument("--wandb_project", type=str, default="protein-llm-contrastive",
242
+ # help="Wandb project name")
243
+ # parser.add_argument("--wandb_entity", type=str, default=None,
244
+ # help="Wandb entity name")
245
+
246
+ # args = parser.parse_args()
247
+
248
+ # # Create output directory
249
+ # os.makedirs(args.output_dir, exist_ok=True)
250
+
251
+ # # Run contrastive training
252
+ # trainer = main(args)
253
+
254
+
255
+ import os
256
+ import time
257
+ from argparse import ArgumentParser
258
+ from functools import partial
259
+
260
+ import torch
261
+ import wandb
262
+ from datasets import load_dataset
263
+ from torch.utils.data import DataLoader
264
+
265
+ import pytorch_lightning as pl
266
+ from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
267
+ from pytorch_lightning.loggers import WandbLogger
268
+ from pytorch_lightning.strategies import DeepSpeedStrategy
269
+
270
+ from bioreason.models.protein_llm import ProteinLLMModel
271
+ from bioreason.trainer.contrast_trainer_new import (
272
+ ContrastiveTrainer,
273
+ ContrastiveTrainingArguments,
274
+ protein_text_collate_fn,
275
+ )
276
+ from bioreason.dataset.protein import format_protein_contrastive
277
+
278
+
279
def main(args):
    """
    Main function for enhanced contrastive pre-training of Protein-LLM.

    This script trains the QFormer projection layer to align protein and text representations
    using enhanced contrastive learning with optional protein-text matching before fine-tuning.
    """
    # Reproducibility + CUDA housekeeping.
    pl.seed_everything(args.seed)
    torch.cuda.empty_cache()
    torch.set_float32_matmul_precision("medium")

    # Optional experiment tracking.
    if args.use_wandb:
        run_name = f"enhanced-contrastive-{args.text_model_name.split('/')[-1]}-{time.strftime('%Y%m%d-%H%M%S')}"
        wandb.init(
            project=args.wandb_project,
            entity=args.wandb_entity,
            name=run_name,
            config=vars(args),
        )

    print("Loading model...")
    # Both backbone encoders stay frozen here; only the projection is trained.
    model = ProteinLLMModel(
        text_model_name=args.text_model_name,
        protein_model_name=args.protein_model_name,
        qformer_model_name=args.qformer_model_name,
        cache_dir=args.cache_dir,
        max_length_protein=args.max_length_protein,
        max_length_text=args.max_length_text,
        text_model_finetune=False,  # Don't fine-tune during contrastive learning
        protein_model_finetune=False,  # Don't fine-tune during contrastive learning
        num_query_tokens=args.num_query_tokens,
    )

    print("Loading datasets...")
    # NOTE(review): the datasets are loaded as local JSON files, while the CLI
    # defaults look like hub ids ("wanglab/...") — confirm the expected format.
    train_dataset = load_dataset("json", data_files=args.train_dataset, split="train")
    eval_dataset = (
        load_dataset("json", data_files=args.valid_dataset, split="train")
        if args.eval_dataset
        else None
    )

    # Reshape each example into aligned (protein, text) pairs.
    train_dataset = train_dataset.map(format_protein_contrastive)
    if eval_dataset:
        eval_dataset = eval_dataset.map(format_protein_contrastive)

    def _has_nonempty_pair(example):
        # Keep only examples with both a protein sequence AND a description.
        return (
            example["protein"]
            and example["text"]
            and len(example["protein"].strip()) > 0
            and len(example["text"].strip()) > 0
        )

    train_dataset = train_dataset.filter(_has_nonempty_pair)
    if eval_dataset:
        eval_dataset = eval_dataset.filter(_has_nonempty_pair)

    print(f"Training dataset size: {len(train_dataset)}")
    if eval_dataset:
        print(f"Eval dataset size: {len(eval_dataset)}")

    # Hyperparameters for the enhanced contrastive objective; eval-related
    # options are only activated when an eval split is present.
    training_args = ContrastiveTrainingArguments(
        output_dir=args.output_dir,
        num_train_epochs=args.num_epochs,
        per_device_train_batch_size=args.batch_size,
        per_device_eval_batch_size=args.batch_size,
        learning_rate=args.learning_rate,
        weight_decay=args.weight_decay,
        temperature=args.temperature,
        freeze_protein_model=args.freeze_protein_model,
        freeze_text_model=args.freeze_text_model,
        protein_weight=args.protein_weight,
        text_weight=args.text_weight,
        enable_ptm=args.enable_ptm,
        ptm_weight=args.ptm_weight,
        max_length_protein=args.max_length_protein,
        max_length_text=args.max_length_text,
        logging_steps=args.logging_steps,
        eval_strategy="steps" if eval_dataset else "no",
        eval_steps=args.eval_steps if eval_dataset else None,
        save_steps=args.save_steps,
        save_total_limit=args.save_total_limit,
        load_best_model_at_end=bool(eval_dataset),
        metric_for_best_model="eval_avg_recall_at_1" if eval_dataset else None,
        greater_is_better=True,
        report_to=["wandb"] if args.use_wandb else [],
        warmup_steps=args.warmup_steps,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        fp16=args.fp16,
        bf16=args.bf16,
        dataloader_num_workers=args.num_workers,
        remove_unused_columns=False,
        seed=args.seed,
        # Distributed training settings
        ddp_find_unused_parameters=False,
        dataloader_pin_memory=True,
    )

    print("Initializing enhanced trainer...")
    trainer = ContrastiveTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        data_collator=protein_text_collate_fn,
    )

    print("Starting enhanced contrastive training...")
    print(f"- Contrastive learning enabled")
    print(f"- Protein-text matching: {'enabled' if args.enable_ptm else 'disabled'}")
    print(f"- Temperature: {args.temperature}")
    print(f"- PTM weight: {args.ptm_weight}")

    trainer.train()

    print("Saving final model...")
    trainer.save_model()

    # Persist the projection weights separately so they can be reloaded on
    # their own for downstream fine-tuning.
    projection_path = os.path.join(args.output_dir, "protein_projection.pt")
    torch.save(model.protein_projection.state_dict(), projection_path)
    print(f"Saved protein projection weights to: {projection_path}")

    # The PTM head lives on the trainer's loss object, if it was created.
    if args.enable_ptm and hasattr(trainer.contrastive_loss, 'ptm_head'):
        ptm_head_path = os.path.join(args.output_dir, "ptm_head.pt")
        torch.save(trainer.contrastive_loss.ptm_head.state_dict(), ptm_head_path)
        print(f"Saved PTM head weights to: {ptm_head_path}")

    if eval_dataset:
        print("Running final evaluation...")
        eval_results = trainer.evaluate()
        print(f"Final evaluation results: {eval_results}")

        if args.use_wandb:
            wandb.log({"final_eval": eval_results})

    print("Enhanced contrastive training completed!")

    if args.use_wandb:
        wandb.finish()

    return trainer
430
+
431
+
432
if __name__ == "__main__":
    # CLI for the enhanced contrastive pre-training entry point.
    parser = ArgumentParser(description="Enhanced contrastive pre-training for Protein-LLM")

    # Model configuration
    parser.add_argument("--text_model_name", type=str, default="Qwen/Qwen3-1.7B",
                        help="Name or path to the text model")
    parser.add_argument("--protein_model_name", type=str, default="facebook/esm2_t6_8M_UR50D",
                        help="Name or path to the protein model")
    parser.add_argument("--qformer_model_name", type=str,
                        default="microsoft/BiomedNLP-BiomedBERT-base-uncased-abstract-fulltext",
                        help="Name or path to the QFormer model")
    parser.add_argument("--cache_dir", type=str, default="/model-weights",
                        help="Directory to cache downloaded models")
    parser.add_argument("--num_query_tokens", type=int, default=32,
                        help="Number of query tokens in QFormer")

    # Dataset configuration
    parser.add_argument("--train_dataset", type=str, default="wanglab/protein_descriptions",
                        help="Name of the dataset for contrastive learning")
    parser.add_argument("--valid_dataset", type=str, default="wanglab/protein_descriptions",
                        help="Name of the dataset for contrastive learning")
    parser.add_argument("--eval_dataset", action="store_true",
                        help="Whether to use evaluation dataset")

    # Training configuration
    parser.add_argument("--output_dir", type=str, default="./enhanced_contrastive_outputs",
                        help="Output directory for model and logs")
    parser.add_argument("--num_epochs", type=int, default=10,
                        help="Number of training epochs")
    parser.add_argument("--batch_size", type=int, default=32,
                        help="Batch size per device")
    parser.add_argument("--learning_rate", type=float, default=1e-4,
                        help="Learning rate")
    parser.add_argument("--weight_decay", type=float, default=0.01,
                        help="Weight decay")
    parser.add_argument("--warmup_steps", type=int, default=1000,
                        help="Number of warmup steps")
    parser.add_argument("--gradient_accumulation_steps", type=int, default=1,
                        help="Gradient accumulation steps")

    # Enhanced contrastive learning specific
    parser.add_argument("--temperature", type=float, default=0.07,
                        help="Temperature for contrastive loss")
    # NOTE(review): action="store_true" combined with default=True means these
    # two flags can never be turned off from the command line — confirm intent
    # (a paired --no_freeze_* flag or BooleanOptionalAction would allow both).
    parser.add_argument("--freeze_protein_model", action="store_true", default=True,
                        help="Freeze protein model during training")
    parser.add_argument("--freeze_text_model", action="store_true", default=True,
                        help="Freeze text model during training")
    parser.add_argument("--protein_weight", type=float, default=1.0,
                        help="Weight for protein features in contrastive loss")
    parser.add_argument("--text_weight", type=float, default=1.0,
                        help="Weight for text features in contrastive loss")

    # Protein-Text Matching (PTM) configuration
    parser.add_argument("--enable_ptm", action="store_true", default=True,
                        help="Enable protein-text matching task")
    parser.add_argument("--ptm_weight", type=float, default=1.0,
                        help="Weight for protein-text matching loss")

    # Data configuration
    parser.add_argument("--max_length_protein", type=int, default=1024,
                        help="Maximum length for protein sequences")
    parser.add_argument("--max_length_text", type=int, default=512,
                        help="Maximum length for text sequences")
    parser.add_argument("--num_workers", type=int, default=4,
                        help="Number of data loading workers")

    # Logging and evaluation
    parser.add_argument("--logging_steps", type=int, default=100,
                        help="Number of steps between logging")
    parser.add_argument("--eval_steps", type=int, default=500,
                        help="Number of steps between evaluations")
    parser.add_argument("--save_steps", type=int, default=1000,
                        help="Number of steps between saving checkpoints")
    parser.add_argument("--save_total_limit", type=int, default=3,
                        help="Maximum number of checkpoints to keep")

    # Hardware configuration
    parser.add_argument("--fp16", action="store_true",
                        help="Use FP16 precision")
    parser.add_argument("--bf16", action="store_true",
                        help="Use BF16 precision")
    parser.add_argument("--seed", type=int, default=42,
                        help="Random seed")

    # Wandb logging
    parser.add_argument("--use_wandb", action="store_true",
                        help="Use Weights & Biases for logging")
    parser.add_argument("--wandb_project", type=str, default="protein-llm-enhanced-contrastive",
                        help="Wandb project name")
    parser.add_argument("--wandb_entity", type=str, default=None,
                        help="Wandb entity name")

    args = parser.parse_args()

    # (Removed dead code: a `hasattr(args, 'ptm_weight')` fallback that could
    # never trigger — argparse always defines --ptm_weight with default 1.0,
    # so the attribute is guaranteed to exist after parse_args().)

    # Create output directory
    os.makedirs(args.output_dir, exist_ok=True)

    # Print configuration
    print("=" * 50)
    print("Enhanced Contrastive Training Configuration:")
    print("=" * 50)
    print(f"Text model: {args.text_model_name}")
    print(f"Protein model: {args.protein_model_name}")
    print(f"QFormer model: {args.qformer_model_name}")
    print(f"Dataset: {args.train_dataset}")
    print(f"Output directory: {args.output_dir}")
    print(f"Batch size: {args.batch_size}")
    print(f"Learning rate: {args.learning_rate}")
    print(f"Temperature: {args.temperature}")
    print(f"Enable PTM: {args.enable_ptm}")
    print(f"PTM weight: {args.ptm_weight}")
    print(f"Number of epochs: {args.num_epochs}")
    print(f"Query tokens: {args.num_query_tokens}")
    print("=" * 50)

    # Run enhanced contrastive training
    trainer = main(args)
BioReason_new/train_protein_qwen.py ADDED
@@ -0,0 +1,839 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import csv
2
+ import gc
3
+ import io
4
+ import multiprocessing
5
+ import os
6
+ import time
7
+ import traceback
8
+ from argparse import ArgumentParser
9
+ from functools import partial
10
+ from typing import *
11
+
12
+ import pandas as pd
13
+ import torch
14
+ import wandb
15
+ from datasets import DatasetDict, concatenate_datasets, load_dataset
16
+ from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
17
+ from torch.optim import AdamW
18
+ from torch.utils.data import DataLoader
19
+ from transformers import AutoTokenizer, get_cosine_schedule_with_warmup
20
+ from transformers.tokenization_utils_base import BatchEncoding
21
+
22
+ import pytorch_lightning as pl
23
+ from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
24
+ from pytorch_lightning.loggers import WandbLogger
25
+ from pytorch_lightning.strategies import DeepSpeedStrategy
26
+
27
+ from bioreason.dataset.protein import get_format_protein_function, protein_llm_collate_fn
28
+ from bioreason.dataset.utils import truncate_protein
29
+ from bioreason.models.dl.processing_dl import ProteinLLMProcessor
30
+ from bioreason.models.protein_llm import ProteinLLMModel
31
+
32
# Use the file-system sharing strategy so tensors handed between DataLoader
# worker processes don't exhaust file descriptors, and silence the
# huggingface/tokenizers fork-parallelism warning.
# (Despite the original note, this does NOT set the multiprocessing start
# method; only the tensor sharing strategy is configured here.)
torch.multiprocessing.set_sharing_strategy("file_system")
os.environ["TOKENIZERS_PARALLELISM"] = "false"
35
+
36
+
37
+ class ProteinLLMFineTuner(pl.LightningModule):
38
+ """
39
+ PyTorch Lightning module for fine-tuning Protein-LLM models.
40
+ """
41
+
42
+ def __init__(self, hparams):
43
+ """
44
+ Initialize the ProteinLLMFineTuner.
45
+
46
+ Args:
47
+ hparams: Hyperparameters for the model and training
48
+ """
49
+ super().__init__()
50
+ self.save_hyperparameters(hparams)
51
+
52
+ self.text_model_name = self.hparams.text_model_name
53
+ self.protein_model_name = self.hparams.protein_model_name
54
+ self.qformer_model_name = self.hparams.qformer_model_name
55
+ self.cache_dir = self.hparams.cache_dir
56
+ self.learning_rate = self.hparams.learning_rate
57
+ self.weight_decay = self.hparams.weight_decay
58
+ self.text_model_finetune = self.hparams.text_model_finetune
59
+ self.protein_model_finetune = self.hparams.protein_model_finetune
60
+ self.lora_rank = self.hparams.lora_rank
61
+ self.lora_alpha = self.hparams.lora_alpha
62
+ self.lora_dropout = self.hparams.lora_dropout
63
+ self.max_length_protein = self.hparams.max_length_protein
64
+ self.max_length_text = self.hparams.max_length_text
65
+ self.num_query_tokens = self.hparams.num_query_tokens
66
+ self.return_answer_in_batch = self.hparams.return_answer_in_batch
67
+ self.merge_val_test_set = self.hparams.merge_val_test_set
68
+
69
+ # Store dataset configuration
70
+ self.dataset_type = self.hparams.dataset_type
71
+
72
+ # Load model
73
+ self.model = ProteinLLMModel(
74
+ text_model_name=self.text_model_name,
75
+ protein_model_name=self.protein_model_name,
76
+ qformer_model_name=self.qformer_model_name,
77
+ cache_dir=self.cache_dir,
78
+ max_length_protein=self.max_length_protein,
79
+ max_length_text=self.max_length_text,
80
+ text_model_finetune=self.text_model_finetune,
81
+ protein_model_finetune=self.protein_model_finetune,
82
+ num_query_tokens=self.num_query_tokens,
83
+ )
84
+
85
+ self.text_model = self.model.text_model
86
+ self.protein_model = self.model.protein_model
87
+ self.protein_projection = self.model.protein_projection
88
+
89
+ # Load tokenizer for target text
90
+ self.tokenizer = self.model.text_tokenizer
91
+
92
+ # Prepare model for training
93
+ self.lora_config = self._prep_for_training()
94
+
95
+ def _get_target_modules(self):
96
+ # Apply LoRA to all linear layers in the text model
97
+ target_modules = []
98
+
99
+ # Get all unique linear layer names
100
+ seen_names = set()
101
+ for name, module in self.text_model.named_modules():
102
+ if isinstance(module, torch.nn.Linear):
103
+ names = name.split(".")
104
+ target_name = names[-1] # Use the last part of the name
105
+
106
+ # Skip output head but include all other linear layers
107
+ if target_name != "lm_head" and target_name not in seen_names:
108
+ target_modules.append(target_name)
109
+ seen_names.add(target_name)
110
+
111
+ # Add attention-specific layers
112
+ attention_patterns = [
113
+ "q_proj",
114
+ "k_proj",
115
+ "v_proj",
116
+ "out_proj",
117
+ "query",
118
+ "key",
119
+ "value",
120
+ ]
121
+ for pattern in attention_patterns:
122
+ if pattern not in seen_names:
123
+ target_modules.append(pattern)
124
+
125
+ # Return all unique layer names to apply LoRA to all layers
126
+ return list(target_modules)
127
+
128
+ def _prep_for_training(self) -> LoraConfig:
129
+ """
130
+ Load and configure the ProteinLLMModel.
131
+ """
132
+
133
+ # Freeze protein encoder parameters
134
+ if self.protein_model_finetune:
135
+ pass
136
+ else:
137
+ for param in self.protein_model.parameters():
138
+ param.requires_grad = False
139
+
140
+ if self.text_model_finetune:
141
+ target_modules = self._get_target_modules()
142
+
143
+ lora_config = LoraConfig(
144
+ r=self.lora_rank,
145
+ lora_alpha=self.lora_alpha,
146
+ lora_dropout=self.lora_dropout,
147
+ target_modules=target_modules,
148
+ init_lora_weights="gaussian",
149
+ bias="none",
150
+ task_type="CAUSAL_LM",
151
+ )
152
+
153
+ # Prepare text model for training
154
+ self.text_model = prepare_model_for_kbit_training(self.text_model)
155
+ self.text_model = get_peft_model(self.text_model, lora_config)
156
+ else:
157
+ # Freeze text model parameters
158
+ for param in self.text_model.parameters():
159
+ param.requires_grad = False
160
+ lora_config = None
161
+
162
+ # Make projection layer trainable
163
+ for param in self.protein_projection.parameters():
164
+ param.requires_grad = True
165
+
166
+ return lora_config
167
+
168
    def _step(self, batch: Dict, batch_idx: int, prefix: str) -> torch.Tensor:
        """
        Performs a single step for training, validation, or testing.

        Runs the forward pass, periodically prints/logs a sample generation
        for debugging, and logs loss + learning-rate metrics.

        Args:
            batch: Dictionary containing the batch data
            batch_idx: Integer indicating the batch index
            prefix: String indicating the step type ('train', 'val', or 'test')

        Returns:
            torch.Tensor: The computed loss for this batch.
            NOTE(review): for prefix == "test" this actually returns a dict
            {"loss": tensor(0.0)} rather than a bare tensor — confirm callers
            accept both shapes.
        """
        # Test-time loss is computed elsewhere (see on_test_epoch_end); skip.
        if prefix == "test":
            return {"loss": torch.tensor(0.0, device=self.device)}

        # Get batch data from the collate function
        input_ids = batch["input_ids"].to(self.device)
        attention_mask = batch["attention_mask"].to(self.device)
        labels = batch["labels"].to(self.device) if "labels" in batch else None
        protein_tokenized = batch.get("protein_tokenized")
        if protein_tokenized is not None:
            protein_tokenized = protein_tokenized.to(self.device)
        # Maps each protein sequence to the batch example it belongs to.
        batch_idx_map = batch.get("batch_idx_map")

        # Forward pass through the model
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            protein_tokenized=protein_tokenized,
            batch_idx_map=batch_idx_map,
            labels=labels,
        )

        # Get the loss from model outputs
        loss = outputs.loss

        # Occasionally show generations for debugging purposes - ONLY during training/validation
        if (prefix == "train" and (self.global_step % 3000 == 0)) or (prefix == "val" and (batch_idx % 300 == 0)):
            try:
                # Select first example from batch for demonstration
                example_idx = 0

                print(
                    f"\n=== Sample Generation (step {self.global_step} / {self.trainer.estimated_stepping_batches}) ==="
                )

                # Get the tokens that define the assistant pattern
                # (Qwen chat template turn marker).
                assistant_start_marker = "<|im_start|>assistant\n"
                assistant_marker_tokens = self.tokenizer.encode(assistant_start_marker, add_special_tokens=False)
                marker_tensor = torch.tensor(assistant_marker_tokens, device=input_ids.device)
                marker_len = len(assistant_marker_tokens)

                # Find non-padding tokens in input (assumes left padding so
                # the prompt starts at the first non-pad token — TODO confirm).
                non_pad = (input_ids[example_idx] != self.tokenizer.pad_token_id).nonzero(as_tuple=True)[0]
                if len(non_pad) > 0:
                    start_idx = non_pad[0].item()  # First non-padding token
                else:
                    start_idx = 0

                # For each position, check if the next marker_len tokens match the pattern
                matches = []
                for pos in range(start_idx, input_ids.size(1) - marker_len + 1):
                    if torch.all(input_ids[example_idx, pos : pos + marker_len] == marker_tensor):
                        matches.append(pos)
                        break  # Stop at first match

                assistant_pos = matches[0] if matches else None

                if assistant_pos is not None:
                    # Get input up to and including the assistant marker
                    gen_input_ids = input_ids[
                        example_idx : example_idx + 1, start_idx : assistant_pos + marker_len
                    ]
                    gen_attention_mask = attention_mask[
                        example_idx : example_idx + 1, start_idx : assistant_pos + marker_len
                    ]

                    # Extract protein data for this example
                    example_protein_data = None
                    example_batch_map = None

                    if protein_tokenized is not None and batch_idx_map is not None:
                        # Find protein sequences for this example
                        example_indices = [i for i, idx in enumerate(batch_idx_map) if idx == example_idx]

                        if len(example_indices) > 0:
                            # Extract just this example's protein data
                            example_protein_data = BatchEncoding(
                                {
                                    "input_ids": protein_tokenized.input_ids[example_indices].to(self.device),
                                    "attention_mask": protein_tokenized.attention_mask[example_indices].to(self.device),
                                }
                            )

                            # For generation we need all sequences mapped to index 0
                            example_batch_map = [0] * len(example_indices)

                    # Generate text (sampled, so output varies run to run).
                    with torch.no_grad():
                        generated = self.model.generate(
                            input_ids=gen_input_ids,
                            attention_mask=gen_attention_mask,
                            protein_tokenized=example_protein_data,
                            batch_idx_map=example_batch_map,
                            max_new_tokens=800,
                            temperature=0.6,
                            top_p=0.95,
                            top_k=20,
                            do_sample=True,
                        )

                    # Decode and display
                    user_input = self.tokenizer.decode(gen_input_ids[0], skip_special_tokens=False).strip()
                    generation = self.tokenizer.decode(generated[0], skip_special_tokens=False).strip()

                    # Free memory early
                    del generated, gen_input_ids, gen_attention_mask, example_protein_data, example_batch_map
                    gc.collect()

                    print(f"=====[Sample {prefix} {batch_idx}]=====")
                    print(f"=====[User input]=====\n{user_input}")
                    print(f"=====[Complete generation]=====\n{generation}")

                    # Get ground truth if available
                    ground_truth = ""
                    if labels is not None:
                        # Find all positions where we have valid labels (not -100)
                        valid_label_pos = (labels[example_idx] != -100).nonzero(as_tuple=True)[0]

                        if len(valid_label_pos) > 0:
                            # Check if valid labels start after assistant marker
                            if valid_label_pos[0] >= assistant_pos + marker_len:
                                ground_truth = self.tokenizer.decode(
                                    input_ids[example_idx, valid_label_pos], skip_special_tokens=False
                                ).strip()
                                print(f"=====[Ground truth]=====\n{ground_truth}")

                    # Log to wandb (unique key per step so tables don't collide)
                    timestamp = time.time()
                    step_id = f"gen_{self.global_step}-{timestamp}"
                    wandb_logger = self.logger.experiment
                    wandb_logger.log(
                        {
                            step_id: wandb.Table(
                                columns=["timestamp", "prefix", "batch_idx", "user_input", "generation", "ground_truth"],
                                data=[[timestamp, prefix, batch_idx, user_input, generation, ground_truth]],
                            )
                        }
                    )

                    # Clean up memory
                    del user_input, generation, ground_truth
                    torch.cuda.empty_cache()
                    gc.collect()

                else:
                    print("No assistant marker found in the input sequence")

            except Exception as e:
                # Debug generation must never crash a training step.
                print(f"Error during sample generation: {str(e)}")
                traceback.print_exc()

        # Get current learning rate (skip during test as scheduler might not be available)
        if prefix != "test":
            current_lr = self.lr_schedulers().get_last_lr()[0]
        else:
            current_lr = 0

        # Logging metrics
        self.log(
            f"{prefix}_loss",
            loss,
            on_step=True,
            on_epoch=False,
            prog_bar=True,
            logger=True,
        )
        self.log(
            f"{prefix}_loss_epoch",
            loss,
            on_step=False,
            on_epoch=True,
            prog_bar=True,
            logger=True,
            sync_dist=True,
        )

        # Only log learning rate during training/validation
        if prefix != "test":
            self.log(
                "lr",
                current_lr,
                on_step=True,
                on_epoch=True,
                prog_bar=True,
                logger=True,
                sync_dist=True,
            )

        return loss
368
+
369
+ def training_step(self, batch: Dict, batch_idx: int) -> torch.Tensor:
370
+ """Perform a single training step."""
371
+ return self._step(batch, batch_idx, prefix="train")
372
+
373
+ def validation_step(self, batch: Dict, batch_idx: int) -> torch.Tensor:
374
+ """Perform a single validation step."""
375
+ return self._step(batch, batch_idx, prefix="val")
376
+
377
+ def test_step(self, batch: Dict, batch_idx: int) -> torch.Tensor:
378
+ """Perform a single test step."""
379
+ return self._step(batch, batch_idx, prefix="test")
380
+
381
+ def configure_optimizers(self):
382
+ """
383
+ Configure optimizers and learning rate schedulers.
384
+
385
+ Returns:
386
+ Tuple[List, List]: A tuple containing a list of optimizers and schedulers
387
+ """
388
+ optimizer = AdamW(self.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay)
389
+
390
+ total_steps = self.trainer.estimated_stepping_batches
391
+ warmup_steps = int(0.1 * total_steps)
392
+
393
+ scheduler = get_cosine_schedule_with_warmup(
394
+ optimizer,
395
+ num_warmup_steps=warmup_steps,
396
+ num_training_steps=total_steps,
397
+ )
398
+
399
+ return [optimizer], [{"scheduler": scheduler, "interval": "step"}]
400
+
401
+ def train_dataloader(self) -> DataLoader:
402
+ """Create and return the training DataLoader."""
403
+ # Load dataset based on type specified in hyperparameters
404
+
405
+ if self.hparams.dataset_type == "protein_function":
406
+ # Use Hugging Face dataset if provided
407
+ dataset = load_dataset(self.hparams.protein_function_data_dir_huggingface)
408
+ dataset = dataset.map(get_format_protein_function(self.hparams.model_type))
409
+
410
+ labels = []
411
+ for split, data in dataset.items():
412
+ labels.extend(data["answer"])
413
+ self.labels = sorted(list(set(labels)))
414
+
415
+ train_dataset = dataset["train"]
416
+
417
+ if self.hparams.truncate_protein_per_side:
418
+ train_dataset = train_dataset.map(
419
+ truncate_protein, fn_kwargs={"truncate_protein_per_side": self.hparams.truncate_protein_per_side}
420
+ )
421
+
422
+ processor = ProteinLLMProcessor(
423
+ tokenizer=self.model.text_tokenizer,
424
+ protein_tokenizer=self.model.protein_tokenizer,
425
+ )
426
+
427
+ # Create partial function with all required arguments except the batch
428
+ collate_fn = partial(
429
+ protein_llm_collate_fn,
430
+ processor=processor,
431
+ max_length_text=self.max_length_text,
432
+ max_length_protein=self.max_length_protein,
433
+ return_answer_in_batch=self.return_answer_in_batch,
434
+ )
435
+
436
+ else:
437
+ raise ValueError(f"Unknown dataset type: {self.hparams.dataset_type}")
438
+
439
+ return DataLoader(
440
+ train_dataset,
441
+ batch_size=self.hparams.batch_size,
442
+ shuffle=True,
443
+ collate_fn=collate_fn,
444
+ num_workers=self.hparams.num_workers,
445
+ persistent_workers=False,
446
+ pin_memory=False,
447
+ )
448
+
449
+ def val_dataloader(self) -> DataLoader:
450
+ """Create and return the validation DataLoader."""
451
+
452
+ if self.hparams.dataset_type == "protein_function":
453
+ # Use Hugging Face dataset
454
+ dataset = load_dataset(self.hparams.protein_function_data_dir_huggingface)
455
+ dataset = dataset.map(get_format_protein_function(self.hparams.model_type))
456
+
457
+ if self.hparams.merge_val_test_set:
458
+ val_dataset = concatenate_datasets([dataset['test'], dataset['val']])
459
+ else:
460
+ val_dataset = dataset["val"]
461
+
462
+ labels = []
463
+ for split, data in dataset.items():
464
+ labels.extend(data["answer"])
465
+ self.labels = sorted(list(set(labels)))
466
+
467
+ if self.hparams.truncate_protein_per_side:
468
+ val_dataset = val_dataset.map(
469
+ truncate_protein, fn_kwargs={"truncate_protein_per_side": self.hparams.truncate_protein_per_side}
470
+ )
471
+
472
+ processor = ProteinLLMProcessor(
473
+ tokenizer=self.model.text_tokenizer,
474
+ protein_tokenizer=self.model.protein_tokenizer,
475
+ )
476
+
477
+ # Create partial function with all required arguments except the batch
478
+ collate_fn = partial(
479
+ protein_llm_collate_fn,
480
+ processor=processor,
481
+ max_length_text=self.max_length_text,
482
+ max_length_protein=self.max_length_protein,
483
+ return_answer_in_batch=self.return_answer_in_batch,
484
+ )
485
+
486
+ else:
487
+ raise ValueError(f"Unknown dataset type: {self.hparams.dataset_type}")
488
+
489
+ return DataLoader(
490
+ val_dataset,
491
+ batch_size=self.hparams.batch_size,
492
+ shuffle=False,
493
+ collate_fn=collate_fn,
494
+ num_workers=self.hparams.num_workers,
495
+ persistent_workers=False,
496
+ pin_memory=False,
497
+ )
498
+
499
+ def test_dataloader(self) -> DataLoader:
500
+ """Create and return the test DataLoader."""
501
+ return self.val_dataloader()
502
+
503
    # For protein function datasets, use the resulting generations in W&B
    def on_test_epoch_end(self) -> dict:
        """
        Called at the end of test epoch to generate text for all test examples
        and calculate accuracy based on whether the label appears in the generated response.

        For every test example this method:
          1. locates the ``<|im_start|>assistant\\n`` marker in the tokenized prompt,
          2. truncates the input there and calls ``self.model.generate`` (sampling),
          3. scores the generation as correct if the (lower-cased) ground-truth
             label occurs anywhere in the generated text,
        then logs per-batch progress, a final accuracy, a wandb Table of all
        generations, and writes the generations to a timestamped CSV file.

        NOTE(review): this iterates ``self.test_dataloader()`` manually, so test
        generation happens in addition to whatever batches Lightning already fed
        through ``test_step``.
        NOTE(review): ``batch["answer"]`` is only present when the collate fn
        returns answers (``return_answer_in_batch``) — confirm the test collate
        path always includes it, otherwise this raises ``KeyError``.
        NOTE(review): Lightning ignores the return value of this hook; the dict
        at the end is informational only.
        """
        # Get wandb logger
        # NOTE(review): assumes the trainer logger is a WandbLogger; .experiment
        # is the raw wandb run object.
        wandb_logger = self.logger.experiment
        wandb_logger.log({"test_progress": 0.0, "status": "starting test generation"})

        # Set model to eval mode
        self.model.eval()

        # Get test dataloader
        test_dataloader = self.test_dataloader()
        total_batches = len(test_dataloader)

        # Get negative and positive labels
        # Falls back to generic label names when self.labels is short/empty.
        neg_label = self.labels[0] if len(self.labels) > 0 else "negative"
        pos_label = self.labels[1] if len(self.labels) > 1 else "positive"

        # Log label information
        wandb_logger.log({
            "positive_label": pos_label,
            "negative_label": neg_label
        })
        print(f"Using labels - Positive: '{pos_label}', Negative: '{neg_label}'")

        # Initialize counters and storage for generations
        total_examples = 0
        correct_predictions = 0
        processed_batches = 0
        generations = []

        # Process each batch in the test dataloader
        for batch_idx, batch in enumerate(test_dataloader):
            # Log batch start to wandb
            wandb_logger.log({
                "test_progress": batch_idx / total_batches,
                "status": f"processing batch {batch_idx}/{total_batches}"
            })

            # Get batch data
            input_ids = batch["input_ids"].to(self.device)
            attention_mask = batch["attention_mask"].to(self.device)
            answer = batch["answer"]
            protein_tokenized = batch.get("protein_tokenized")
            if protein_tokenized is not None:
                protein_tokenized = protein_tokenized.to(self.device)
            # batch_idx_map maps each protein row back to its example index
            # within the batch (one example may own several proteins).
            batch_idx_map = batch.get("batch_idx_map")

            # Get assistant marker position
            # The chat template places the assistant turn after this marker;
            # everything up to and including it is the generation prompt.
            assistant_start_marker = "<|im_start|>assistant\n"
            assistant_marker_tokens = self.tokenizer.encode(assistant_start_marker, add_special_tokens=False)
            marker_tensor = torch.tensor(assistant_marker_tokens, device=input_ids.device)
            marker_len = len(assistant_marker_tokens)

            # Process examples in the batch
            examples_in_batch = 0
            for example_idx in range(input_ids.size(0)):
                # Find non-padding tokens (inputs are assumed left-padded, so
                # the first non-pad position is the true sequence start).
                non_pad = (input_ids[example_idx] != self.tokenizer.pad_token_id).nonzero(as_tuple=True)[0]
                start_idx = non_pad[0].item() if len(non_pad) > 0 else 0

                # Find assistant marker position via linear token-window scan.
                assistant_pos = None
                for pos in range(start_idx, input_ids.size(1) - marker_len + 1):
                    if torch.all(input_ids[example_idx, pos:pos + marker_len] == marker_tensor):
                        assistant_pos = pos
                        break

                # Examples whose prompt lacks the marker are silently skipped.
                if assistant_pos is not None:
                    # Prepare input for generation: prompt up to and including
                    # the assistant marker, as a batch of one.
                    gen_input_ids = input_ids[example_idx:example_idx + 1, start_idx:assistant_pos + marker_len]
                    gen_attention_mask = attention_mask[example_idx:example_idx + 1, start_idx:assistant_pos + marker_len]

                    # Extract protein data for this example
                    example_protein_data = None
                    example_batch_map = None

                    if protein_tokenized is not None and batch_idx_map is not None:
                        # Select only the protein rows belonging to this example.
                        example_indices = [i for i, idx in enumerate(batch_idx_map) if idx == example_idx]

                        if example_indices:
                            example_protein_data = BatchEncoding({
                                "input_ids": protein_tokenized.input_ids[example_indices].to(self.device),
                                "attention_mask": protein_tokenized.attention_mask[example_indices].to(self.device),
                            })
                            # All selected proteins map to example 0 of the
                            # single-example generation batch.
                            example_batch_map = [0] * len(example_indices)

                    # Generate text (sampling; hyperparameters are fixed here).
                    with torch.no_grad():
                        generated = self.model.generate(
                            input_ids=gen_input_ids,
                            attention_mask=gen_attention_mask,
                            protein_tokenized=example_protein_data,
                            batch_idx_map=example_batch_map,
                            max_new_tokens=800,
                            temperature=0.6,
                            top_p=0.95,
                            top_k=20,
                            do_sample=True,
                        )

                    # Decode user input and generated text
                    user_input = self.tokenizer.decode(gen_input_ids[0], skip_special_tokens=False).strip()
                    generation = self.tokenizer.decode(generated[0], skip_special_tokens=False).strip()

                    # Get ground truth and clean it if needed
                    # (multi-label answers are ';'-separated; only the first
                    # label is scored).
                    ground_truth = answer[example_idx]
                    if ";" in ground_truth:
                        ground_truth = ground_truth.split(";")[0]

                    # Check if the generated text contains the ground truth
                    # (case-insensitive substring match).
                    generation_contains_ground_truth = ground_truth.lower() in generation.lower()

                    # Update metrics
                    total_examples += 1
                    examples_in_batch += 1

                    if generation_contains_ground_truth:
                        correct_predictions += 1

                    # Store generation data
                    generations.append({
                        "batch_idx": batch_idx,
                        "example_idx": example_idx,
                        "user_input": user_input,
                        "generation": generation,
                        "ground_truth": ground_truth,
                        "contains_ground_truth": generation_contains_ground_truth,
                    })

                    # Clean up memory
                    # NOTE(review): empty_cache()/gc.collect() per example is
                    # expensive; presumably needed to avoid OOM with long
                    # generations — confirm before removing.
                    torch.cuda.empty_cache()
                    gc.collect()

            # Log batch completion to wandb
            processed_batches += 1
            current_accuracy = correct_predictions / max(total_examples, 1)

            wandb_logger.log({
                "batches_processed": processed_batches,
                "examples_processed": total_examples,
                "examples_in_last_batch": examples_in_batch,
                "current_accuracy": current_accuracy,
                "progress_percentage": (batch_idx + 1) / total_batches * 100
            })

        # Calculate final metrics (max(...) guards division by zero when no
        # example contained the assistant marker).
        accuracy = correct_predictions / max(total_examples, 1)

        # Log final metrics to wandb
        wandb_logger.log({
            "test_accuracy": accuracy,
            "correct_predictions": correct_predictions,
            "total_examples_processed": total_examples,
            "test_status": "completed"
        })

        # Create a table with all the generations
        if generations:
            columns = [
                "batch_idx",
                "example_idx",
                "user_input",
                "generation",
                "ground_truth",
                "contains_ground_truth"
            ]
            data = []
            for g in generations:
                row = [g.get(c, "") for c in columns]
                data.append(row)

            wandb_logger.log({
                f"test_generations_{time.strftime('%Y%m%d-%H%M%S')}:": wandb.Table(columns=columns, data=data)
            })

        # Save generations to a CSV file (next to the resumed checkpoint if
        # one was given, otherwise in the run's checkpoint dir).
        model_name = self.hparams.text_model_name.split('/')[-1]
        if self.hparams.ckpt_path:
            csv_path = os.path.join(self.hparams.ckpt_path, f"{time.strftime('%Y%m%d-%H%M%S')}-test_generations_{model_name}.csv")
        else:
            csv_path = os.path.join(self.hparams.checkpoint_dir, f"{time.strftime('%Y%m%d-%H%M%S')}-test_generations_{model_name}.csv")

        # Best-effort write: failures are logged to wandb, not raised.
        try:
            with open(csv_path, 'w', newline='', encoding='utf-8') as f:
                if generations:
                    writer = csv.DictWriter(f, fieldnames=generations[0].keys())
                    writer.writeheader()
                    for g in generations:
                        writer.writerow(g)

            wandb_logger.log({"csv_saved": True, "csv_path": csv_path})
        except Exception as e:
            wandb_logger.log({"csv_saved": False, "csv_path": csv_path, "error": str(e)})

        # Log a summary of the metrics
        summary = (
            f"Test Results Summary:\n"
            f"Total examples: {total_examples}\n"
            f"Accuracy: {accuracy:.4f}\n"
            f"Correct: {correct_predictions}\n"
        )
        print(summary)
        wandb_logger.log({"test_summary": summary})

        # Force garbage collection
        torch.cuda.empty_cache()
        gc.collect()

        return {
            "test_accuracy": accuracy,
        }
718
+
719
+
720
def main(args: ArgumentParser):
    """
    Main function to run the Protein-Text fine-tuning process.

    Builds the Lightning module, callbacks, wandb logger, and Trainer from the
    parsed CLI options, then runs ``fit`` followed by ``test``.

    Args:
        args (ArgumentParser): Parsed command-line arguments
            NOTE(review): the value is actually an ``argparse.Namespace``
            (the result of ``parser.parse_args()``), not an ArgumentParser.
    """
    # Set random seed and environment variables
    pl.seed_everything(args.seed)
    torch.cuda.empty_cache()
    torch.set_float32_matmul_precision("medium")

    # Setup directories
    # NOTE: args.checkpoint_dir is mutated in place so the model and the
    # ModelCheckpoint callback below see the same timestamped directory.
    run_name = f"{args.wandb_project}-{args.dataset_type}-{args.text_model_name.split('/')[-1]}"
    args.checkpoint_dir = f"{args.checkpoint_dir}/{run_name}-{time.strftime('%Y%m%d-%H%M%S')}"

    # Initialize model
    model = ProteinLLMFineTuner(args)

    # Setup callbacks: keep the 2 best checkpoints by epoch-level val loss,
    # always keep the last, and track LR per optimizer step.
    callbacks = [
        ModelCheckpoint(
            dirpath=args.checkpoint_dir,
            filename=f"{run_name}-" + "{epoch:02d}-{val_loss_epoch:.4f}",
            save_top_k=2,
            monitor="val_loss_epoch",
            mode="min",
            save_last=True,
        ),
        LearningRateMonitor(logging_interval="step"),
    ]

    # Setup logger ("allow" lets wandb resume the prior run when restarting
    # from a checkpoint).
    is_resuming = args.ckpt_path is not None
    logger = WandbLogger(
        project=args.wandb_project,
        entity=args.wandb_entity,
        save_dir=args.log_dir,
        name=run_name,
        resume="allow" if is_resuming else None,
    )

    # Initialize the PyTorch Lightning Trainer
    trainer = pl.Trainer(
        max_epochs=args.max_epochs,
        accelerator="gpu",
        devices=args.num_gpus,
        # Either plain DDP or DeepSpeed ZeRO stage 2 (optimizer sharding,
        # no CPU offload).
        strategy=(
            "ddp"
            if args.strategy == "ddp"
            else DeepSpeedStrategy(stage=2, offload_optimizer=False, allgather_bucket_size=5e8, reduce_bucket_size=5e8)
        ),
        precision="bf16-mixed",
        callbacks=callbacks,
        logger=logger,
        deterministic=False,
        enable_checkpointing=True,
        enable_progress_bar=True,
        enable_model_summary=True,
        log_every_n_steps=5,
        accumulate_grad_batches=args.gradient_accumulation_steps,
        gradient_clip_val=1.0,
        # Validate three times per training epoch.
        val_check_interval=1 / 3,
    )

    # Start the training process; afterwards evaluate the best (or the
    # explicitly supplied) checkpoint.
    trainer.fit(model, ckpt_path=args.ckpt_path)
    trainer.test(model, ckpt_path=args.ckpt_path if args.ckpt_path else "best")
788
+
789
+ if __name__ == "__main__":
790
+ parser = ArgumentParser()
791
+
792
+ # Model configuration
793
+ parser.add_argument("--model_type", type=str, choices=["llm", "protein-llm"], default="protein-llm")
794
+ parser.add_argument("--text_model_name", type=str, default="Qwen/Qwen3-1.7B")
795
+ parser.add_argument("--protein_model_name", type=str, default="facebook/esm2_t6_8M_UR50D")
796
+ parser.add_argument("--qformer_model_name", type=str, default="microsoft/BiomedNLP-BiomedBERT-base-uncased-abstract-fulltext")
797
+ parser.add_argument("--text_model_finetune", type=bool, default=True)
798
+ parser.add_argument("--protein_model_finetune", type=bool, default=False)
799
+ parser.add_argument("--num_query_tokens", type=int, default=32)
800
+
801
+ # Training parameters
802
+ parser.add_argument("--seed", type=int, default=23)
803
+ parser.add_argument("--batch_size", type=int, default=1)
804
+ parser.add_argument("--max_epochs", type=int, default=5)
805
+ parser.add_argument("--learning_rate", type=float, default=5e-5)
806
+ parser.add_argument("--weight_decay", type=float, default=0.01)
807
+ parser.add_argument("--gradient_accumulation_steps", type=int, default=8)
808
+ parser.add_argument("--max_length_protein", type=int, default=1024)
809
+ parser.add_argument("--max_length_text", type=int, default=1024)
810
+ parser.add_argument("--truncate_protein_per_side", type=int, default=1024)
811
+ parser.add_argument("--return_answer_in_batch", type=bool, default=False)
812
+
813
+ # LoRA parameters
814
+ parser.add_argument("--lora_rank", type=int, default=32)
815
+ parser.add_argument("--lora_alpha", type=int, default=64)
816
+ parser.add_argument("--lora_dropout", type=float, default=0.05)
817
+
818
+ # Infrastructure and paths
819
+ parser.add_argument("--checkpoint_dir", type=str, default="checkpoints")
820
+ parser.add_argument("--log_dir", type=str, default="logs")
821
+ parser.add_argument("--cache_dir", type=str, default="/model-weights")
822
+ parser.add_argument("--ckpt_path", type=str, default=None)
823
+ parser.add_argument("--num_workers", type=int, default=4)
824
+ parser.add_argument("--num_gpus", type=int, default=1)
825
+ parser.add_argument("--strategy", type=str, default="ddp")
826
+
827
+ # Dataset configuration
828
+ parser.add_argument("--dataset_type", type=str, choices=["protein_function"], default="protein_function")
829
+ parser.add_argument("--use_protein_llm_collate_fn", type=bool, default=True)
830
+ parser.add_argument("--protein_function_data_dir_huggingface", type=str, default="wanglab/protein_function")
831
+ parser.add_argument("--merge_val_test_set", type=bool, default=False)
832
+
833
+ # Logging and monitoring
834
+ parser.add_argument("--wandb_project", type=str, default="esm2-qwen3-1.7b-finetune")
835
+ parser.add_argument("--wandb_entity", type=str)
836
+
837
+ args = parser.parse_args()
838
+
839
+ main(args)
BioReason_new/wandb/debug-internal.log ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {"time":"2025-08-13T15:33:18.937087932+08:00","level":"INFO","msg":"stream: starting","core version":"0.21.1"}
2
+ {"time":"2025-08-13T15:33:22.125246419+08:00","level":"INFO","msg":"stream: created new stream","id":"ig4rhoqf"}
3
+ {"time":"2025-08-13T15:33:22.126020019+08:00","level":"INFO","msg":"stream: started","id":"ig4rhoqf"}
4
+ {"time":"2025-08-13T15:33:22.12605541+08:00","level":"INFO","msg":"writer: started","stream_id":"ig4rhoqf"}
5
+ {"time":"2025-08-13T15:33:22.126066944+08:00","level":"INFO","msg":"handler: started","stream_id":"ig4rhoqf"}
6
+ {"time":"2025-08-13T15:33:22.126093203+08:00","level":"INFO","msg":"sender: started","stream_id":"ig4rhoqf"}
7
+ {"time":"2025-08-13T15:33:29.266636932+08:00","level":"ERROR","msg":"request failed","error":"Put \"https://storage.googleapis.com/wandb-production.appspot.com/gia0603yucca/protein-llm-contrastive/ig4rhoqf/wandb-metadata.json?X-Goog-Algorithm=GOOG4-RSA-SHA256&X-Goog-Credential=gorilla-files-url-signer-man%40wandb-production.iam.gserviceaccount.com%2F20250813%2Fauto%2Fstorage%2Fgoog4_request&X-Goog-Date=20250813T073326Z&X-Goog-Expires=86399&X-Goog-Signature=1c9fbc6cbd51ae71160b439aa8a692f95ee16291fd45e2c33912f90a5ae3a56a09f33e745a416b04da501c1cf40827caa3399f6cde0752ce3f571c12353aa477d27f258d11b74762a029b2fada9797f27a85c7fe03e3bb1ec12fb98789a6b389cf000b2742ca658544ea41f22c9631272e67f2ea275c4a5b7b9c516fed6ed22ccce6b26b684c341e337963a684be9b248bd64b665fb73437d28ce7365767fe6721770c9f8f75e150767d9e4154a983b7f6aafcc31bbc050a26df2e9a3b99e065c111d7a92525568b04cf8d8a7490e6b8f77aed12c26aa378d0708385b0f40b22af3988b5c6cfb25918b1c9471be072bdfdc7a986ca0774f732d27361d23f3448&X-Goog-SignedHeaders=host&X-User=gia0603yucca\": read tcp 10.1.1.90:52222->142.250.73.123:443: read: connection reset by 
peer","method":"PUT","url":"https://storage.googleapis.com/wandb-production.appspot.com/gia0603yucca/protein-llm-contrastive/ig4rhoqf/wandb-metadata.json?X-Goog-Algorithm=GOOG4-RSA-SHA256&X-Goog-Credential=gorilla-files-url-signer-man%40wandb-production.iam.gserviceaccount.com%2F20250813%2Fauto%2Fstorage%2Fgoog4_request&X-Goog-Date=20250813T073326Z&X-Goog-Expires=86399&X-Goog-Signature=1c9fbc6cbd51ae71160b439aa8a692f95ee16291fd45e2c33912f90a5ae3a56a09f33e745a416b04da501c1cf40827caa3399f6cde0752ce3f571c12353aa477d27f258d11b74762a029b2fada9797f27a85c7fe03e3bb1ec12fb98789a6b389cf000b2742ca658544ea41f22c9631272e67f2ea275c4a5b7b9c516fed6ed22ccce6b26b684c341e337963a684be9b248bd64b665fb73437d28ce7365767fe6721770c9f8f75e150767d9e4154a983b7f6aafcc31bbc050a26df2e9a3b99e065c111d7a92525568b04cf8d8a7490e6b8f77aed12c26aa378d0708385b0f40b22af3988b5c6cfb25918b1c9471be072bdfdc7a986ca0774f732d27361d23f3448&X-Goog-SignedHeaders=host&X-User=gia0603yucca"}
8
+ {"time":"2025-08-13T15:33:31.205628552+08:00","level":"ERROR","msg":"request failed","error":"Put \"https://storage.googleapis.com/wandb-production.appspot.com/gia0603yucca/protein-llm-contrastive/ig4rhoqf/requirements.txt?X-Goog-Algorithm=GOOG4-RSA-SHA256&X-Goog-Credential=gorilla-files-url-signer-man%40wandb-production.iam.gserviceaccount.com%2F20250813%2Fauto%2Fstorage%2Fgoog4_request&X-Goog-Date=20250813T073326Z&X-Goog-Expires=86399&X-Goog-Signature=27041959870a9e2ba8f974ca28bb5723ad374fc1917b982522ea3f555344b25fed8f73d5b443663eb1d7d10499d748efa90c8f6d91802daabf9e8c31258a7371e93aca1694d3e8bcb35a72d7c9ca243dc164f5dfae6ec3009c25cc78ea8cf37629f017b7538998f44f7c65ccfb675e343601421475a4490c754c5ee0370d3d5dfa928faddfe9a90621302ce69efd3d26c51f49c23f92148b018281ccd02f22c42e73e318594ea9c2ff9b25ad13163b60c37f1ededb4dc3a50712bcabdffed71883e2e5b5f04c40f13612d9ff1b6762b9f79d6e19873e959fbf9495e4827e901e8dbd25d7b6f291841915bc428212ad0a142cc3a9b7d04dbba9b49873573481fa&X-Goog-SignedHeaders=host&X-User=gia0603yucca\": read tcp 10.1.1.90:57666->142.250.73.91:443: read: connection reset by 
peer","method":"PUT","url":"https://storage.googleapis.com/wandb-production.appspot.com/gia0603yucca/protein-llm-contrastive/ig4rhoqf/requirements.txt?X-Goog-Algorithm=GOOG4-RSA-SHA256&X-Goog-Credential=gorilla-files-url-signer-man%40wandb-production.iam.gserviceaccount.com%2F20250813%2Fauto%2Fstorage%2Fgoog4_request&X-Goog-Date=20250813T073326Z&X-Goog-Expires=86399&X-Goog-Signature=27041959870a9e2ba8f974ca28bb5723ad374fc1917b982522ea3f555344b25fed8f73d5b443663eb1d7d10499d748efa90c8f6d91802daabf9e8c31258a7371e93aca1694d3e8bcb35a72d7c9ca243dc164f5dfae6ec3009c25cc78ea8cf37629f017b7538998f44f7c65ccfb675e343601421475a4490c754c5ee0370d3d5dfa928faddfe9a90621302ce69efd3d26c51f49c23f92148b018281ccd02f22c42e73e318594ea9c2ff9b25ad13163b60c37f1ededb4dc3a50712bcabdffed71883e2e5b5f04c40f13612d9ff1b6762b9f79d6e19873e959fbf9495e4827e901e8dbd25d7b6f291841915bc428212ad0a142cc3a9b7d04dbba9b49873573481fa&X-Goog-SignedHeaders=host&X-User=gia0603yucca"}
9
+ {"time":"2025-08-13T15:33:31.959506763+08:00","level":"ERROR","msg":"request failed","error":"Put \"https://storage.googleapis.com/wandb-production.appspot.com/gia0603yucca/protein-llm-contrastive/ig4rhoqf/wandb-metadata.json?X-Goog-Algorithm=GOOG4-RSA-SHA256&X-Goog-Credential=gorilla-files-url-signer-man%40wandb-production.iam.gserviceaccount.com%2F20250813%2Fauto%2Fstorage%2Fgoog4_request&X-Goog-Date=20250813T073326Z&X-Goog-Expires=86399&X-Goog-Signature=1c9fbc6cbd51ae71160b439aa8a692f95ee16291fd45e2c33912f90a5ae3a56a09f33e745a416b04da501c1cf40827caa3399f6cde0752ce3f571c12353aa477d27f258d11b74762a029b2fada9797f27a85c7fe03e3bb1ec12fb98789a6b389cf000b2742ca658544ea41f22c9631272e67f2ea275c4a5b7b9c516fed6ed22ccce6b26b684c341e337963a684be9b248bd64b665fb73437d28ce7365767fe6721770c9f8f75e150767d9e4154a983b7f6aafcc31bbc050a26df2e9a3b99e065c111d7a92525568b04cf8d8a7490e6b8f77aed12c26aa378d0708385b0f40b22af3988b5c6cfb25918b1c9471be072bdfdc7a986ca0774f732d27361d23f3448&X-Goog-SignedHeaders=host&X-User=gia0603yucca\": read tcp 10.1.1.90:52232->142.250.73.123:443: read: connection reset by 
peer","method":"PUT","url":"https://storage.googleapis.com/wandb-production.appspot.com/gia0603yucca/protein-llm-contrastive/ig4rhoqf/wandb-metadata.json?X-Goog-Algorithm=GOOG4-RSA-SHA256&X-Goog-Credential=gorilla-files-url-signer-man%40wandb-production.iam.gserviceaccount.com%2F20250813%2Fauto%2Fstorage%2Fgoog4_request&X-Goog-Date=20250813T073326Z&X-Goog-Expires=86399&X-Goog-Signature=1c9fbc6cbd51ae71160b439aa8a692f95ee16291fd45e2c33912f90a5ae3a56a09f33e745a416b04da501c1cf40827caa3399f6cde0752ce3f571c12353aa477d27f258d11b74762a029b2fada9797f27a85c7fe03e3bb1ec12fb98789a6b389cf000b2742ca658544ea41f22c9631272e67f2ea275c4a5b7b9c516fed6ed22ccce6b26b684c341e337963a684be9b248bd64b665fb73437d28ce7365767fe6721770c9f8f75e150767d9e4154a983b7f6aafcc31bbc050a26df2e9a3b99e065c111d7a92525568b04cf8d8a7490e6b8f77aed12c26aa378d0708385b0f40b22af3988b5c6cfb25918b1c9471be072bdfdc7a986ca0774f732d27361d23f3448&X-Goog-SignedHeaders=host&X-User=gia0603yucca"}
10
+ {"time":"2025-08-13T15:33:33.871387522+08:00","level":"ERROR","msg":"request failed","error":"Put \"https://storage.googleapis.com/wandb-production.appspot.com/gia0603yucca/protein-llm-contrastive/ig4rhoqf/requirements.txt?X-Goog-Algorithm=GOOG4-RSA-SHA256&X-Goog-Credential=gorilla-files-url-signer-man%40wandb-production.iam.gserviceaccount.com%2F20250813%2Fauto%2Fstorage%2Fgoog4_request&X-Goog-Date=20250813T073326Z&X-Goog-Expires=86399&X-Goog-Signature=27041959870a9e2ba8f974ca28bb5723ad374fc1917b982522ea3f555344b25fed8f73d5b443663eb1d7d10499d748efa90c8f6d91802daabf9e8c31258a7371e93aca1694d3e8bcb35a72d7c9ca243dc164f5dfae6ec3009c25cc78ea8cf37629f017b7538998f44f7c65ccfb675e343601421475a4490c754c5ee0370d3d5dfa928faddfe9a90621302ce69efd3d26c51f49c23f92148b018281ccd02f22c42e73e318594ea9c2ff9b25ad13163b60c37f1ededb4dc3a50712bcabdffed71883e2e5b5f04c40f13612d9ff1b6762b9f79d6e19873e959fbf9495e4827e901e8dbd25d7b6f291841915bc428212ad0a142cc3a9b7d04dbba9b49873573481fa&X-Goog-SignedHeaders=host&X-User=gia0603yucca\": read tcp 10.1.1.90:52238->142.250.73.123:443: read: connection reset by 
peer","method":"PUT","url":"https://storage.googleapis.com/wandb-production.appspot.com/gia0603yucca/protein-llm-contrastive/ig4rhoqf/requirements.txt?X-Goog-Algorithm=GOOG4-RSA-SHA256&X-Goog-Credential=gorilla-files-url-signer-man%40wandb-production.iam.gserviceaccount.com%2F20250813%2Fauto%2Fstorage%2Fgoog4_request&X-Goog-Date=20250813T073326Z&X-Goog-Expires=86399&X-Goog-Signature=27041959870a9e2ba8f974ca28bb5723ad374fc1917b982522ea3f555344b25fed8f73d5b443663eb1d7d10499d748efa90c8f6d91802daabf9e8c31258a7371e93aca1694d3e8bcb35a72d7c9ca243dc164f5dfae6ec3009c25cc78ea8cf37629f017b7538998f44f7c65ccfb675e343601421475a4490c754c5ee0370d3d5dfa928faddfe9a90621302ce69efd3d26c51f49c23f92148b018281ccd02f22c42e73e318594ea9c2ff9b25ad13163b60c37f1ededb4dc3a50712bcabdffed71883e2e5b5f04c40f13612d9ff1b6762b9f79d6e19873e959fbf9495e4827e901e8dbd25d7b6f291841915bc428212ad0a142cc3a9b7d04dbba9b49873573481fa&X-Goog-SignedHeaders=host&X-User=gia0603yucca"}
11
+ {"time":"2025-08-13T15:33:36.261477525+08:00","level":"ERROR","msg":"request failed","error":"Put \"https://storage.googleapis.com/wandb-production.appspot.com/gia0603yucca/protein-llm-contrastive/ig4rhoqf/wandb-metadata.json?X-Goog-Algorithm=GOOG4-RSA-SHA256&X-Goog-Credential=gorilla-files-url-signer-man%40wandb-production.iam.gserviceaccount.com%2F20250813%2Fauto%2Fstorage%2Fgoog4_request&X-Goog-Date=20250813T073326Z&X-Goog-Expires=86399&X-Goog-Signature=1c9fbc6cbd51ae71160b439aa8a692f95ee16291fd45e2c33912f90a5ae3a56a09f33e745a416b04da501c1cf40827caa3399f6cde0752ce3f571c12353aa477d27f258d11b74762a029b2fada9797f27a85c7fe03e3bb1ec12fb98789a6b389cf000b2742ca658544ea41f22c9631272e67f2ea275c4a5b7b9c516fed6ed22ccce6b26b684c341e337963a684be9b248bd64b665fb73437d28ce7365767fe6721770c9f8f75e150767d9e4154a983b7f6aafcc31bbc050a26df2e9a3b99e065c111d7a92525568b04cf8d8a7490e6b8f77aed12c26aa378d0708385b0f40b22af3988b5c6cfb25918b1c9471be072bdfdc7a986ca0774f732d27361d23f3448&X-Goog-SignedHeaders=host&X-User=gia0603yucca\": read tcp 10.1.1.90:57678->142.250.73.91:443: read: connection reset by 
peer","method":"PUT","url":"https://storage.googleapis.com/wandb-production.appspot.com/gia0603yucca/protein-llm-contrastive/ig4rhoqf/wandb-metadata.json?X-Goog-Algorithm=GOOG4-RSA-SHA256&X-Goog-Credential=gorilla-files-url-signer-man%40wandb-production.iam.gserviceaccount.com%2F20250813%2Fauto%2Fstorage%2Fgoog4_request&X-Goog-Date=20250813T073326Z&X-Goog-Expires=86399&X-Goog-Signature=1c9fbc6cbd51ae71160b439aa8a692f95ee16291fd45e2c33912f90a5ae3a56a09f33e745a416b04da501c1cf40827caa3399f6cde0752ce3f571c12353aa477d27f258d11b74762a029b2fada9797f27a85c7fe03e3bb1ec12fb98789a6b389cf000b2742ca658544ea41f22c9631272e67f2ea275c4a5b7b9c516fed6ed22ccce6b26b684c341e337963a684be9b248bd64b665fb73437d28ce7365767fe6721770c9f8f75e150767d9e4154a983b7f6aafcc31bbc050a26df2e9a3b99e065c111d7a92525568b04cf8d8a7490e6b8f77aed12c26aa378d0708385b0f40b22af3988b5c6cfb25918b1c9471be072bdfdc7a986ca0774f732d27361d23f3448&X-Goog-SignedHeaders=host&X-User=gia0603yucca"}
12
+ {"time":"2025-08-13T15:33:41.039143093+08:00","level":"INFO","msg":"stream: closing","id":"ig4rhoqf"}
13
+ {"time":"2025-08-13T15:38:00.686011134+08:00","level":"ERROR","msg":"request failed","error":"Put \"https://storage.googleapis.com/wandb-production.appspot.com/gia0603yucca/protein-llm-contrastive/ig4rhoqf/requirements.txt?X-Goog-Algorithm=GOOG4-RSA-SHA256&X-Goog-Credential=gorilla-files-url-signer-man%40wandb-production.iam.gserviceaccount.com%2F20250813%2Fauto%2Fstorage%2Fgoog4_request&X-Goog-Date=20250813T073326Z&X-Goog-Expires=86399&X-Goog-Signature=27041959870a9e2ba8f974ca28bb5723ad374fc1917b982522ea3f555344b25fed8f73d5b443663eb1d7d10499d748efa90c8f6d91802daabf9e8c31258a7371e93aca1694d3e8bcb35a72d7c9ca243dc164f5dfae6ec3009c25cc78ea8cf37629f017b7538998f44f7c65ccfb675e343601421475a4490c754c5ee0370d3d5dfa928faddfe9a90621302ce69efd3d26c51f49c23f92148b018281ccd02f22c42e73e318594ea9c2ff9b25ad13163b60c37f1ededb4dc3a50712bcabdffed71883e2e5b5f04c40f13612d9ff1b6762b9f79d6e19873e959fbf9495e4827e901e8dbd25d7b6f291841915bc428212ad0a142cc3a9b7d04dbba9b49873573481fa&X-Goog-SignedHeaders=host&X-User=gia0603yucca\": read tcp 10.1.1.90:44768->142.250.217.91:443: read: connection reset by 
peer","method":"PUT","url":"https://storage.googleapis.com/wandb-production.appspot.com/gia0603yucca/protein-llm-contrastive/ig4rhoqf/requirements.txt?X-Goog-Algorithm=GOOG4-RSA-SHA256&X-Goog-Credential=gorilla-files-url-signer-man%40wandb-production.iam.gserviceaccount.com%2F20250813%2Fauto%2Fstorage%2Fgoog4_request&X-Goog-Date=20250813T073326Z&X-Goog-Expires=86399&X-Goog-Signature=27041959870a9e2ba8f974ca28bb5723ad374fc1917b982522ea3f555344b25fed8f73d5b443663eb1d7d10499d748efa90c8f6d91802daabf9e8c31258a7371e93aca1694d3e8bcb35a72d7c9ca243dc164f5dfae6ec3009c25cc78ea8cf37629f017b7538998f44f7c65ccfb675e343601421475a4490c754c5ee0370d3d5dfa928faddfe9a90621302ce69efd3d26c51f49c23f92148b018281ccd02f22c42e73e318594ea9c2ff9b25ad13163b60c37f1ededb4dc3a50712bcabdffed71883e2e5b5f04c40f13612d9ff1b6762b9f79d6e19873e959fbf9495e4827e901e8dbd25d7b6f291841915bc428212ad0a142cc3a9b7d04dbba9b49873573481fa&X-Goog-SignedHeaders=host&X-User=gia0603yucca"}
14
+ {"time":"2025-08-13T15:38:04.732136338+08:00","level":"ERROR","msg":"request failed","error":"Put \"https://storage.googleapis.com/wandb-production.appspot.com/gia0603yucca/protein-llm-contrastive/ig4rhoqf/output.log?X-Goog-Algorithm=GOOG4-RSA-SHA256&X-Goog-Credential=gorilla-files-url-signer-man%40wandb-production.iam.gserviceaccount.com%2F20250813%2Fauto%2Fstorage%2Fgoog4_request&X-Goog-Date=20250813T073342Z&X-Goog-Expires=86399&X-Goog-Signature=1d3cd94b4887f4c0adaa644487d39281fc878818fa20fa0c39b307b4d29a0774e1077acff64b5047465d575c9417f96645240e67e6dab6d1cc25f617e622bcf0f8715629cc8509684074bcd0688360af452c0f89dc80a4e9025500a6e0d16485ec7617abfd0ff10c8ae62cf6012c0f43aad8d30bb7c9de79f87c56f73ccabc572e83285debd09c51e5fcc95e1a7fe1cc011606044e056dc5f2d6756e365041194a3e52cc500d0e71499406481acf9cc760c1da018d07a35e07830dd04208781140f1fc38c86f5aa24bf65e1a58d35406ae84e7026fe6cf59353cf16ca53b51a813beb22417486e4fabc77de1aaef6ee20b755ce21b765571f6bfbf8473942f9d&X-Goog-SignedHeaders=host&X-User=gia0603yucca\": read tcp 10.1.1.90:44778->142.250.217.91:443: read: connection reset by 
peer","method":"PUT","url":"https://storage.googleapis.com/wandb-production.appspot.com/gia0603yucca/protein-llm-contrastive/ig4rhoqf/output.log?X-Goog-Algorithm=GOOG4-RSA-SHA256&X-Goog-Credential=gorilla-files-url-signer-man%40wandb-production.iam.gserviceaccount.com%2F20250813%2Fauto%2Fstorage%2Fgoog4_request&X-Goog-Date=20250813T073342Z&X-Goog-Expires=86399&X-Goog-Signature=1d3cd94b4887f4c0adaa644487d39281fc878818fa20fa0c39b307b4d29a0774e1077acff64b5047465d575c9417f96645240e67e6dab6d1cc25f617e622bcf0f8715629cc8509684074bcd0688360af452c0f89dc80a4e9025500a6e0d16485ec7617abfd0ff10c8ae62cf6012c0f43aad8d30bb7c9de79f87c56f73ccabc572e83285debd09c51e5fcc95e1a7fe1cc011606044e056dc5f2d6756e365041194a3e52cc500d0e71499406481acf9cc760c1da018d07a35e07830dd04208781140f1fc38c86f5aa24bf65e1a58d35406ae84e7026fe6cf59353cf16ca53b51a813beb22417486e4fabc77de1aaef6ee20b755ce21b765571f6bfbf8473942f9d&X-Goog-SignedHeaders=host&X-User=gia0603yucca"}
15
+ {"time":"2025-08-13T15:38:04.732752873+08:00","level":"ERROR","msg":"request failed","error":"Put \"https://storage.googleapis.com/wandb-production.appspot.com/gia0603yucca/protein-llm-contrastive/ig4rhoqf/wandb-summary.json?X-Goog-Algorithm=GOOG4-RSA-SHA256&X-Goog-Credential=gorilla-files-url-signer-man%40wandb-production.iam.gserviceaccount.com%2F20250813%2Fauto%2Fstorage%2Fgoog4_request&X-Goog-Date=20250813T073342Z&X-Goog-Expires=86399&X-Goog-Signature=43378ba7835541b25e7884ddd1295c99cdb59af218b575e11ebf49219cd7b70414883b2eaabf6ce5bf29b6dcf3aa4854cc71ac3f00aceb098b2bd9f008cbf3c3e776aeb0eec09eafbd1633dbdf5cd30ca267e7fc78a69846132727fb83834b13e3aaf2e17cf6af4974dc4f96f982c762fbd056b369608aece84f3c97eed6c02b7648d28a1ab3f3566026c830a6eb4d57a55676a251768961021c9050a7b6974e74c7efd13ca9324ac1a6768bca2daa0daf12b3c8ebc3ad02e5d5c935024048da77f0b6047842ebceaf42ec881dd23660f5f4caad61c04ea8736f4b8abd5c81197ebd10853a45e4944b85d318500c385596344922b43e114698bc05474aad582c&X-Goog-SignedHeaders=host&X-User=gia0603yucca\": read tcp 10.1.1.90:43548->142.250.73.123:443: read: connection reset by 
peer","method":"PUT","url":"https://storage.googleapis.com/wandb-production.appspot.com/gia0603yucca/protein-llm-contrastive/ig4rhoqf/wandb-summary.json?X-Goog-Algorithm=GOOG4-RSA-SHA256&X-Goog-Credential=gorilla-files-url-signer-man%40wandb-production.iam.gserviceaccount.com%2F20250813%2Fauto%2Fstorage%2Fgoog4_request&X-Goog-Date=20250813T073342Z&X-Goog-Expires=86399&X-Goog-Signature=43378ba7835541b25e7884ddd1295c99cdb59af218b575e11ebf49219cd7b70414883b2eaabf6ce5bf29b6dcf3aa4854cc71ac3f00aceb098b2bd9f008cbf3c3e776aeb0eec09eafbd1633dbdf5cd30ca267e7fc78a69846132727fb83834b13e3aaf2e17cf6af4974dc4f96f982c762fbd056b369608aece84f3c97eed6c02b7648d28a1ab3f3566026c830a6eb4d57a55676a251768961021c9050a7b6974e74c7efd13ca9324ac1a6768bca2daa0daf12b3c8ebc3ad02e5d5c935024048da77f0b6047842ebceaf42ec881dd23660f5f4caad61c04ea8736f4b8abd5c81197ebd10853a45e4944b85d318500c385596344922b43e114698bc05474aad582c&X-Goog-SignedHeaders=host&X-User=gia0603yucca"}
16
+ {"time":"2025-08-13T15:38:04.830072606+08:00","level":"ERROR","msg":"request failed","error":"Put \"https://storage.googleapis.com/wandb-production.appspot.com/gia0603yucca/protein-llm-contrastive/ig4rhoqf/config.yaml?X-Goog-Algorithm=GOOG4-RSA-SHA256&X-Goog-Credential=gorilla-files-url-signer-man%40wandb-production.iam.gserviceaccount.com%2F20250813%2Fauto%2Fstorage%2Fgoog4_request&X-Goog-Date=20250813T073342Z&X-Goog-Expires=86399&X-Goog-Signature=c0195bb166fc1589772f7ac8000fd444b8e06323ba6372e465965830037d0e3abe33eb458ac66e538461c44cd0aad54fae2cf3ed0f374b29b4cc12eefa1bd24dbcc6dd7346e5d2fe8238aca5f9f454f181d5839a406a34dc8e519e37ef8c4fc8de26b075401ceec8431d3f6c56345d424b5cace6503127369284bda44efb6fe17a73b78fc162b5452942692cd51024f02807e743229ac8d7ddee751592f125d012c3e4de0790abc7b061c8d8d6916a275902955479dcddb6342607f857379f4552e01426a19bf54b2231de5fb22f340330a26d23e6cae7d90831b2172a8cfdd004a5de1aca9d47db2eded16ebcbeb22d617adfb1a0324e51847993c2374681dd&X-Goog-SignedHeaders=host&X-User=gia0603yucca\": read tcp 10.1.1.90:43546->142.250.73.123:443: read: connection reset by 
peer","method":"PUT","url":"https://storage.googleapis.com/wandb-production.appspot.com/gia0603yucca/protein-llm-contrastive/ig4rhoqf/config.yaml?X-Goog-Algorithm=GOOG4-RSA-SHA256&X-Goog-Credential=gorilla-files-url-signer-man%40wandb-production.iam.gserviceaccount.com%2F20250813%2Fauto%2Fstorage%2Fgoog4_request&X-Goog-Date=20250813T073342Z&X-Goog-Expires=86399&X-Goog-Signature=c0195bb166fc1589772f7ac8000fd444b8e06323ba6372e465965830037d0e3abe33eb458ac66e538461c44cd0aad54fae2cf3ed0f374b29b4cc12eefa1bd24dbcc6dd7346e5d2fe8238aca5f9f454f181d5839a406a34dc8e519e37ef8c4fc8de26b075401ceec8431d3f6c56345d424b5cace6503127369284bda44efb6fe17a73b78fc162b5452942692cd51024f02807e743229ac8d7ddee751592f125d012c3e4de0790abc7b061c8d8d6916a275902955479dcddb6342607f857379f4552e01426a19bf54b2231de5fb22f340330a26d23e6cae7d90831b2172a8cfdd004a5de1aca9d47db2eded16ebcbeb22d617adfb1a0324e51847993c2374681dd&X-Goog-SignedHeaders=host&X-User=gia0603yucca"}
17
+ {"time":"2025-08-13T15:38:08.914871723+08:00","level":"ERROR","msg":"request failed","error":"Put \"https://storage.googleapis.com/wandb-production.appspot.com/gia0603yucca/protein-llm-contrastive/ig4rhoqf/wandb-metadata.json?X-Goog-Algorithm=GOOG4-RSA-SHA256&X-Goog-Credential=gorilla-files-url-signer-man%40wandb-production.iam.gserviceaccount.com%2F20250813%2Fauto%2Fstorage%2Fgoog4_request&X-Goog-Date=20250813T073326Z&X-Goog-Expires=86399&X-Goog-Signature=1c9fbc6cbd51ae71160b439aa8a692f95ee16291fd45e2c33912f90a5ae3a56a09f33e745a416b04da501c1cf40827caa3399f6cde0752ce3f571c12353aa477d27f258d11b74762a029b2fada9797f27a85c7fe03e3bb1ec12fb98789a6b389cf000b2742ca658544ea41f22c9631272e67f2ea275c4a5b7b9c516fed6ed22ccce6b26b684c341e337963a684be9b248bd64b665fb73437d28ce7365767fe6721770c9f8f75e150767d9e4154a983b7f6aafcc31bbc050a26df2e9a3b99e065c111d7a92525568b04cf8d8a7490e6b8f77aed12c26aa378d0708385b0f40b22af3988b5c6cfb25918b1c9471be072bdfdc7a986ca0774f732d27361d23f3448&X-Goog-SignedHeaders=host&X-User=gia0603yucca\": read tcp 10.1.1.90:40682->142.250.73.91:443: read: connection reset by 
peer","method":"PUT","url":"https://storage.googleapis.com/wandb-production.appspot.com/gia0603yucca/protein-llm-contrastive/ig4rhoqf/wandb-metadata.json?X-Goog-Algorithm=GOOG4-RSA-SHA256&X-Goog-Credential=gorilla-files-url-signer-man%40wandb-production.iam.gserviceaccount.com%2F20250813%2Fauto%2Fstorage%2Fgoog4_request&X-Goog-Date=20250813T073326Z&X-Goog-Expires=86399&X-Goog-Signature=1c9fbc6cbd51ae71160b439aa8a692f95ee16291fd45e2c33912f90a5ae3a56a09f33e745a416b04da501c1cf40827caa3399f6cde0752ce3f571c12353aa477d27f258d11b74762a029b2fada9797f27a85c7fe03e3bb1ec12fb98789a6b389cf000b2742ca658544ea41f22c9631272e67f2ea275c4a5b7b9c516fed6ed22ccce6b26b684c341e337963a684be9b248bd64b665fb73437d28ce7365767fe6721770c9f8f75e150767d9e4154a983b7f6aafcc31bbc050a26df2e9a3b99e065c111d7a92525568b04cf8d8a7490e6b8f77aed12c26aa378d0708385b0f40b22af3988b5c6cfb25918b1c9471be072bdfdc7a986ca0774f732d27361d23f3448&X-Goog-SignedHeaders=host&X-User=gia0603yucca"}
18
+ {"time":"2025-08-13T15:38:09.512888171+08:00","level":"ERROR","msg":"request failed","error":"Put \"https://storage.googleapis.com/wandb-production.appspot.com/gia0603yucca/protein-llm-contrastive/ig4rhoqf/requirements.txt?X-Goog-Algorithm=GOOG4-RSA-SHA256&X-Goog-Credential=gorilla-files-url-signer-man%40wandb-production.iam.gserviceaccount.com%2F20250813%2Fauto%2Fstorage%2Fgoog4_request&X-Goog-Date=20250813T073326Z&X-Goog-Expires=86399&X-Goog-Signature=27041959870a9e2ba8f974ca28bb5723ad374fc1917b982522ea3f555344b25fed8f73d5b443663eb1d7d10499d748efa90c8f6d91802daabf9e8c31258a7371e93aca1694d3e8bcb35a72d7c9ca243dc164f5dfae6ec3009c25cc78ea8cf37629f017b7538998f44f7c65ccfb675e343601421475a4490c754c5ee0370d3d5dfa928faddfe9a90621302ce69efd3d26c51f49c23f92148b018281ccd02f22c42e73e318594ea9c2ff9b25ad13163b60c37f1ededb4dc3a50712bcabdffed71883e2e5b5f04c40f13612d9ff1b6762b9f79d6e19873e959fbf9495e4827e901e8dbd25d7b6f291841915bc428212ad0a142cc3a9b7d04dbba9b49873573481fa&X-Goog-SignedHeaders=host&X-User=gia0603yucca\": read tcp 10.1.1.90:53738->142.250.73.123:443: read: connection reset by 
peer","method":"PUT","url":"https://storage.googleapis.com/wandb-production.appspot.com/gia0603yucca/protein-llm-contrastive/ig4rhoqf/requirements.txt?X-Goog-Algorithm=GOOG4-RSA-SHA256&X-Goog-Credential=gorilla-files-url-signer-man%40wandb-production.iam.gserviceaccount.com%2F20250813%2Fauto%2Fstorage%2Fgoog4_request&X-Goog-Date=20250813T073326Z&X-Goog-Expires=86399&X-Goog-Signature=27041959870a9e2ba8f974ca28bb5723ad374fc1917b982522ea3f555344b25fed8f73d5b443663eb1d7d10499d748efa90c8f6d91802daabf9e8c31258a7371e93aca1694d3e8bcb35a72d7c9ca243dc164f5dfae6ec3009c25cc78ea8cf37629f017b7538998f44f7c65ccfb675e343601421475a4490c754c5ee0370d3d5dfa928faddfe9a90621302ce69efd3d26c51f49c23f92148b018281ccd02f22c42e73e318594ea9c2ff9b25ad13163b60c37f1ededb4dc3a50712bcabdffed71883e2e5b5f04c40f13612d9ff1b6762b9f79d6e19873e959fbf9495e4827e901e8dbd25d7b6f291841915bc428212ad0a142cc3a9b7d04dbba9b49873573481fa&X-Goog-SignedHeaders=host&X-User=gia0603yucca"}
19
+ {"time":"2025-08-13T15:38:25.342508911+08:00","level":"ERROR","msg":"request failed","error":"Put \"https://storage.googleapis.com/wandb-production.appspot.com/gia0603yucca/protein-llm-contrastive/ig4rhoqf/wandb-metadata.json?X-Goog-Algorithm=GOOG4-RSA-SHA256&X-Goog-Credential=gorilla-files-url-signer-man%40wandb-production.iam.gserviceaccount.com%2F20250813%2Fauto%2Fstorage%2Fgoog4_request&X-Goog-Date=20250813T073326Z&X-Goog-Expires=86399&X-Goog-Signature=1c9fbc6cbd51ae71160b439aa8a692f95ee16291fd45e2c33912f90a5ae3a56a09f33e745a416b04da501c1cf40827caa3399f6cde0752ce3f571c12353aa477d27f258d11b74762a029b2fada9797f27a85c7fe03e3bb1ec12fb98789a6b389cf000b2742ca658544ea41f22c9631272e67f2ea275c4a5b7b9c516fed6ed22ccce6b26b684c341e337963a684be9b248bd64b665fb73437d28ce7365767fe6721770c9f8f75e150767d9e4154a983b7f6aafcc31bbc050a26df2e9a3b99e065c111d7a92525568b04cf8d8a7490e6b8f77aed12c26aa378d0708385b0f40b22af3988b5c6cfb25918b1c9471be072bdfdc7a986ca0774f732d27361d23f3448&X-Goog-SignedHeaders=host&X-User=gia0603yucca\": read tcp 10.1.1.90:49916->142.250.73.91:443: read: connection reset by 
peer","method":"PUT","url":"https://storage.googleapis.com/wandb-production.appspot.com/gia0603yucca/protein-llm-contrastive/ig4rhoqf/wandb-metadata.json?X-Goog-Algorithm=GOOG4-RSA-SHA256&X-Goog-Credential=gorilla-files-url-signer-man%40wandb-production.iam.gserviceaccount.com%2F20250813%2Fauto%2Fstorage%2Fgoog4_request&X-Goog-Date=20250813T073326Z&X-Goog-Expires=86399&X-Goog-Signature=1c9fbc6cbd51ae71160b439aa8a692f95ee16291fd45e2c33912f90a5ae3a56a09f33e745a416b04da501c1cf40827caa3399f6cde0752ce3f571c12353aa477d27f258d11b74762a029b2fada9797f27a85c7fe03e3bb1ec12fb98789a6b389cf000b2742ca658544ea41f22c9631272e67f2ea275c4a5b7b9c516fed6ed22ccce6b26b684c341e337963a684be9b248bd64b665fb73437d28ce7365767fe6721770c9f8f75e150767d9e4154a983b7f6aafcc31bbc050a26df2e9a3b99e065c111d7a92525568b04cf8d8a7490e6b8f77aed12c26aa378d0708385b0f40b22af3988b5c6cfb25918b1c9471be072bdfdc7a986ca0774f732d27361d23f3448&X-Goog-SignedHeaders=host&X-User=gia0603yucca"}
20
+ {"time":"2025-08-13T15:38:59.014612034+08:00","level":"ERROR","msg":"request failed","error":"Put \"https://storage.googleapis.com/wandb-production.appspot.com/gia0603yucca/protein-llm-contrastive/ig4rhoqf/wandb-metadata.json?X-Goog-Algorithm=GOOG4-RSA-SHA256&X-Goog-Credential=gorilla-files-url-signer-man%40wandb-production.iam.gserviceaccount.com%2F20250813%2Fauto%2Fstorage%2Fgoog4_request&X-Goog-Date=20250813T073326Z&X-Goog-Expires=86399&X-Goog-Signature=1c9fbc6cbd51ae71160b439aa8a692f95ee16291fd45e2c33912f90a5ae3a56a09f33e745a416b04da501c1cf40827caa3399f6cde0752ce3f571c12353aa477d27f258d11b74762a029b2fada9797f27a85c7fe03e3bb1ec12fb98789a6b389cf000b2742ca658544ea41f22c9631272e67f2ea275c4a5b7b9c516fed6ed22ccce6b26b684c341e337963a684be9b248bd64b665fb73437d28ce7365767fe6721770c9f8f75e150767d9e4154a983b7f6aafcc31bbc050a26df2e9a3b99e065c111d7a92525568b04cf8d8a7490e6b8f77aed12c26aa378d0708385b0f40b22af3988b5c6cfb25918b1c9471be072bdfdc7a986ca0774f732d27361d23f3448&X-Goog-SignedHeaders=host&X-User=gia0603yucca\": read tcp 10.1.1.90:44582->142.250.73.123:443: read: connection reset by 
peer","method":"PUT","url":"https://storage.googleapis.com/wandb-production.appspot.com/gia0603yucca/protein-llm-contrastive/ig4rhoqf/wandb-metadata.json?X-Goog-Algorithm=GOOG4-RSA-SHA256&X-Goog-Credential=gorilla-files-url-signer-man%40wandb-production.iam.gserviceaccount.com%2F20250813%2Fauto%2Fstorage%2Fgoog4_request&X-Goog-Date=20250813T073326Z&X-Goog-Expires=86399&X-Goog-Signature=1c9fbc6cbd51ae71160b439aa8a692f95ee16291fd45e2c33912f90a5ae3a56a09f33e745a416b04da501c1cf40827caa3399f6cde0752ce3f571c12353aa477d27f258d11b74762a029b2fada9797f27a85c7fe03e3bb1ec12fb98789a6b389cf000b2742ca658544ea41f22c9631272e67f2ea275c4a5b7b9c516fed6ed22ccce6b26b684c341e337963a684be9b248bd64b665fb73437d28ce7365767fe6721770c9f8f75e150767d9e4154a983b7f6aafcc31bbc050a26df2e9a3b99e065c111d7a92525568b04cf8d8a7490e6b8f77aed12c26aa378d0708385b0f40b22af3988b5c6cfb25918b1c9471be072bdfdc7a986ca0774f732d27361d23f3448&X-Goog-SignedHeaders=host&X-User=gia0603yucca"}
21
+ {"time":"2025-08-13T15:40:15.762761262+08:00","level":"ERROR","msg":"request failed","error":"Put \"https://storage.googleapis.com/wandb-production.appspot.com/gia0603yucca/protein-llm-contrastive/ig4rhoqf/output.log?X-Goog-Algorithm=GOOG4-RSA-SHA256&X-Goog-Credential=gorilla-files-url-signer-man%40wandb-production.iam.gserviceaccount.com%2F20250813%2Fauto%2Fstorage%2Fgoog4_request&X-Goog-Date=20250813T073342Z&X-Goog-Expires=86399&X-Goog-Signature=1d3cd94b4887f4c0adaa644487d39281fc878818fa20fa0c39b307b4d29a0774e1077acff64b5047465d575c9417f96645240e67e6dab6d1cc25f617e622bcf0f8715629cc8509684074bcd0688360af452c0f89dc80a4e9025500a6e0d16485ec7617abfd0ff10c8ae62cf6012c0f43aad8d30bb7c9de79f87c56f73ccabc572e83285debd09c51e5fcc95e1a7fe1cc011606044e056dc5f2d6756e365041194a3e52cc500d0e71499406481acf9cc760c1da018d07a35e07830dd04208781140f1fc38c86f5aa24bf65e1a58d35406ae84e7026fe6cf59353cf16ca53b51a813beb22417486e4fabc77de1aaef6ee20b755ce21b765571f6bfbf8473942f9d&X-Goog-SignedHeaders=host&X-User=gia0603yucca\": read tcp 10.1.1.90:46682->142.250.73.91:443: read: connection reset by 
peer","method":"PUT","url":"https://storage.googleapis.com/wandb-production.appspot.com/gia0603yucca/protein-llm-contrastive/ig4rhoqf/output.log?X-Goog-Algorithm=GOOG4-RSA-SHA256&X-Goog-Credential=gorilla-files-url-signer-man%40wandb-production.iam.gserviceaccount.com%2F20250813%2Fauto%2Fstorage%2Fgoog4_request&X-Goog-Date=20250813T073342Z&X-Goog-Expires=86399&X-Goog-Signature=1d3cd94b4887f4c0adaa644487d39281fc878818fa20fa0c39b307b4d29a0774e1077acff64b5047465d575c9417f96645240e67e6dab6d1cc25f617e622bcf0f8715629cc8509684074bcd0688360af452c0f89dc80a4e9025500a6e0d16485ec7617abfd0ff10c8ae62cf6012c0f43aad8d30bb7c9de79f87c56f73ccabc572e83285debd09c51e5fcc95e1a7fe1cc011606044e056dc5f2d6756e365041194a3e52cc500d0e71499406481acf9cc760c1da018d07a35e07830dd04208781140f1fc38c86f5aa24bf65e1a58d35406ae84e7026fe6cf59353cf16ca53b51a813beb22417486e4fabc77de1aaef6ee20b755ce21b765571f6bfbf8473942f9d&X-Goog-SignedHeaders=host&X-User=gia0603yucca"}
22
+ {"time":"2025-08-13T15:40:15.7960472+08:00","level":"ERROR","msg":"request failed","error":"Put \"https://storage.googleapis.com/wandb-production.appspot.com/gia0603yucca/protein-llm-contrastive/ig4rhoqf/wandb-summary.json?X-Goog-Algorithm=GOOG4-RSA-SHA256&X-Goog-Credential=gorilla-files-url-signer-man%40wandb-production.iam.gserviceaccount.com%2F20250813%2Fauto%2Fstorage%2Fgoog4_request&X-Goog-Date=20250813T073342Z&X-Goog-Expires=86399&X-Goog-Signature=43378ba7835541b25e7884ddd1295c99cdb59af218b575e11ebf49219cd7b70414883b2eaabf6ce5bf29b6dcf3aa4854cc71ac3f00aceb098b2bd9f008cbf3c3e776aeb0eec09eafbd1633dbdf5cd30ca267e7fc78a69846132727fb83834b13e3aaf2e17cf6af4974dc4f96f982c762fbd056b369608aece84f3c97eed6c02b7648d28a1ab3f3566026c830a6eb4d57a55676a251768961021c9050a7b6974e74c7efd13ca9324ac1a6768bca2daa0daf12b3c8ebc3ad02e5d5c935024048da77f0b6047842ebceaf42ec881dd23660f5f4caad61c04ea8736f4b8abd5c81197ebd10853a45e4944b85d318500c385596344922b43e114698bc05474aad582c&X-Goog-SignedHeaders=host&X-User=gia0603yucca\": read tcp 10.1.1.90:46696->142.250.73.91:443: read: connection reset by 
peer","method":"PUT","url":"https://storage.googleapis.com/wandb-production.appspot.com/gia0603yucca/protein-llm-contrastive/ig4rhoqf/wandb-summary.json?X-Goog-Algorithm=GOOG4-RSA-SHA256&X-Goog-Credential=gorilla-files-url-signer-man%40wandb-production.iam.gserviceaccount.com%2F20250813%2Fauto%2Fstorage%2Fgoog4_request&X-Goog-Date=20250813T073342Z&X-Goog-Expires=86399&X-Goog-Signature=43378ba7835541b25e7884ddd1295c99cdb59af218b575e11ebf49219cd7b70414883b2eaabf6ce5bf29b6dcf3aa4854cc71ac3f00aceb098b2bd9f008cbf3c3e776aeb0eec09eafbd1633dbdf5cd30ca267e7fc78a69846132727fb83834b13e3aaf2e17cf6af4974dc4f96f982c762fbd056b369608aece84f3c97eed6c02b7648d28a1ab3f3566026c830a6eb4d57a55676a251768961021c9050a7b6974e74c7efd13ca9324ac1a6768bca2daa0daf12b3c8ebc3ad02e5d5c935024048da77f0b6047842ebceaf42ec881dd23660f5f4caad61c04ea8736f4b8abd5c81197ebd10853a45e4944b85d318500c385596344922b43e114698bc05474aad582c&X-Goog-SignedHeaders=host&X-User=gia0603yucca"}
23
+ {"time":"2025-08-13T15:40:15.806561829+08:00","level":"ERROR","msg":"request failed","error":"Put \"https://storage.googleapis.com/wandb-production.appspot.com/gia0603yucca/protein-llm-contrastive/ig4rhoqf/config.yaml?X-Goog-Algorithm=GOOG4-RSA-SHA256&X-Goog-Credential=gorilla-files-url-signer-man%40wandb-production.iam.gserviceaccount.com%2F20250813%2Fauto%2Fstorage%2Fgoog4_request&X-Goog-Date=20250813T073342Z&X-Goog-Expires=86399&X-Goog-Signature=c0195bb166fc1589772f7ac8000fd444b8e06323ba6372e465965830037d0e3abe33eb458ac66e538461c44cd0aad54fae2cf3ed0f374b29b4cc12eefa1bd24dbcc6dd7346e5d2fe8238aca5f9f454f181d5839a406a34dc8e519e37ef8c4fc8de26b075401ceec8431d3f6c56345d424b5cace6503127369284bda44efb6fe17a73b78fc162b5452942692cd51024f02807e743229ac8d7ddee751592f125d012c3e4de0790abc7b061c8d8d6916a275902955479dcddb6342607f857379f4552e01426a19bf54b2231de5fb22f340330a26d23e6cae7d90831b2172a8cfdd004a5de1aca9d47db2eded16ebcbeb22d617adfb1a0324e51847993c2374681dd&X-Goog-SignedHeaders=host&X-User=gia0603yucca\": read tcp 10.1.1.90:46686->142.250.73.91:443: read: connection reset by 
peer","method":"PUT","url":"https://storage.googleapis.com/wandb-production.appspot.com/gia0603yucca/protein-llm-contrastive/ig4rhoqf/config.yaml?X-Goog-Algorithm=GOOG4-RSA-SHA256&X-Goog-Credential=gorilla-files-url-signer-man%40wandb-production.iam.gserviceaccount.com%2F20250813%2Fauto%2Fstorage%2Fgoog4_request&X-Goog-Date=20250813T073342Z&X-Goog-Expires=86399&X-Goog-Signature=c0195bb166fc1589772f7ac8000fd444b8e06323ba6372e465965830037d0e3abe33eb458ac66e538461c44cd0aad54fae2cf3ed0f374b29b4cc12eefa1bd24dbcc6dd7346e5d2fe8238aca5f9f454f181d5839a406a34dc8e519e37ef8c4fc8de26b075401ceec8431d3f6c56345d424b5cace6503127369284bda44efb6fe17a73b78fc162b5452942692cd51024f02807e743229ac8d7ddee751592f125d012c3e4de0790abc7b061c8d8d6916a275902955479dcddb6342607f857379f4552e01426a19bf54b2231de5fb22f340330a26d23e6cae7d90831b2172a8cfdd004a5de1aca9d47db2eded16ebcbeb22d617adfb1a0324e51847993c2374681dd&X-Goog-SignedHeaders=host&X-User=gia0603yucca"}
24
+ {"time":"2025-08-13T15:40:20.882189315+08:00","level":"ERROR","msg":"request failed","error":"Put \"https://storage.googleapis.com/wandb-production.appspot.com/gia0603yucca/protein-llm-contrastive/ig4rhoqf/config.yaml?X-Goog-Algorithm=GOOG4-RSA-SHA256&X-Goog-Credential=gorilla-files-url-signer-man%40wandb-production.iam.gserviceaccount.com%2F20250813%2Fauto%2Fstorage%2Fgoog4_request&X-Goog-Date=20250813T073342Z&X-Goog-Expires=86399&X-Goog-Signature=c0195bb166fc1589772f7ac8000fd444b8e06323ba6372e465965830037d0e3abe33eb458ac66e538461c44cd0aad54fae2cf3ed0f374b29b4cc12eefa1bd24dbcc6dd7346e5d2fe8238aca5f9f454f181d5839a406a34dc8e519e37ef8c4fc8de26b075401ceec8431d3f6c56345d424b5cace6503127369284bda44efb6fe17a73b78fc162b5452942692cd51024f02807e743229ac8d7ddee751592f125d012c3e4de0790abc7b061c8d8d6916a275902955479dcddb6342607f857379f4552e01426a19bf54b2231de5fb22f340330a26d23e6cae7d90831b2172a8cfdd004a5de1aca9d47db2eded16ebcbeb22d617adfb1a0324e51847993c2374681dd&X-Goog-SignedHeaders=host&X-User=gia0603yucca\": read tcp 10.1.1.90:55506->142.250.217.91:443: read: connection reset by 
peer","method":"PUT","url":"https://storage.googleapis.com/wandb-production.appspot.com/gia0603yucca/protein-llm-contrastive/ig4rhoqf/config.yaml?X-Goog-Algorithm=GOOG4-RSA-SHA256&X-Goog-Credential=gorilla-files-url-signer-man%40wandb-production.iam.gserviceaccount.com%2F20250813%2Fauto%2Fstorage%2Fgoog4_request&X-Goog-Date=20250813T073342Z&X-Goog-Expires=86399&X-Goog-Signature=c0195bb166fc1589772f7ac8000fd444b8e06323ba6372e465965830037d0e3abe33eb458ac66e538461c44cd0aad54fae2cf3ed0f374b29b4cc12eefa1bd24dbcc6dd7346e5d2fe8238aca5f9f454f181d5839a406a34dc8e519e37ef8c4fc8de26b075401ceec8431d3f6c56345d424b5cace6503127369284bda44efb6fe17a73b78fc162b5452942692cd51024f02807e743229ac8d7ddee751592f125d012c3e4de0790abc7b061c8d8d6916a275902955479dcddb6342607f857379f4552e01426a19bf54b2231de5fb22f340330a26d23e6cae7d90831b2172a8cfdd004a5de1aca9d47db2eded16ebcbeb22d617adfb1a0324e51847993c2374681dd&X-Goog-SignedHeaders=host&X-User=gia0603yucca"}
25
+ {"time":"2025-08-13T15:42:32.350288869+08:00","level":"INFO","msg":"fileTransfer: Close: file transfer manager closed"}
26
+ {"time":"2025-08-13T15:42:37.496789088+08:00","level":"INFO","msg":"handler: closed","stream_id":"ig4rhoqf"}
27
+ {"time":"2025-08-13T15:42:37.502696559+08:00","level":"INFO","msg":"sender: closed","stream_id":"ig4rhoqf"}
28
+ {"time":"2025-08-13T15:42:37.502726375+08:00","level":"INFO","msg":"stream: closed","id":"ig4rhoqf"}
BioReason_new/wandb/debug.log ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ 2025-08-13 15:33:18,716 INFO MainThread:13510 [wandb_setup.py:_flush():80] Current SDK version is 0.21.1
2
+ 2025-08-13 15:33:18,716 INFO MainThread:13510 [wandb_setup.py:_flush():80] Configure stats pid to 13510
3
+ 2025-08-13 15:33:18,716 INFO MainThread:13510 [wandb_setup.py:_flush():80] Loading settings from /root/.config/wandb/settings
4
+ 2025-08-13 15:33:18,716 INFO MainThread:13510 [wandb_setup.py:_flush():80] Loading settings from /nas/shared/kilab/wangyujia/BioReason_new/wandb/settings
5
+ 2025-08-13 15:33:18,716 INFO MainThread:13510 [wandb_setup.py:_flush():80] Loading settings from environment variables
6
+ 2025-08-13 15:33:18,716 INFO MainThread:13510 [wandb_init.py:setup_run_log_directory():703] Logging user logs to /nas/shared/kilab/wangyujia/BioReason_new/wandb/run-20250813_153318-ig4rhoqf/logs/debug.log
7
+ 2025-08-13 15:33:18,716 INFO MainThread:13510 [wandb_init.py:setup_run_log_directory():704] Logging internal logs to /nas/shared/kilab/wangyujia/BioReason_new/wandb/run-20250813_153318-ig4rhoqf/logs/debug-internal.log
8
+ 2025-08-13 15:33:18,716 INFO MainThread:13510 [wandb_init.py:init():830] calling init triggers
9
+ 2025-08-13 15:33:18,716 INFO MainThread:13510 [wandb_init.py:init():835] wandb.init called with sweep_config: {}
10
+ config: {'text_model_name': '/oss/wangyujia/BIO/construction_finetuning/alpaca/v1-20250609-141541/checkpoint-50-merged', 'protein_model_name': '/nas/shared/kilab/wangyujia/ProtT3/plm_model/esm2-150m', 'qformer_model_name': '/nas/shared/kilab/wangyujia/ProtT3/plm_model/microsoft', 'cache_dir': '/model-weights', 'num_query_tokens': 8, 'train_dataset': '/nas/shared/kilab/wangyujia/ProtT3/data/SwissProtV3/train_set.jsonl', 'valid_dataset': '/nas/shared/kilab/wangyujia/ProtT3/data/SwissProtV3/valid_set.jsonl', 'eval_dataset': True, 'output_dir': './contrastive_outputs', 'num_epochs': 10, 'batch_size': 8, 'learning_rate': 0.0001, 'weight_decay': 0.01, 'warmup_steps': 1000, 'gradient_accumulation_steps': 1, 'temperature': 0.07, 'freeze_protein_model': True, 'freeze_text_model': True, 'protein_weight': 1.0, 'text_weight': 1.0, 'enable_ptm': True, 'ptm_weight': 1.0, 'max_length_protein': 1024, 'max_length_text': 512, 'num_workers': 8, 'logging_steps': 100, 'eval_steps': 500, 'save_steps': 1000, 'save_total_limit': 3, 'fp16': False, 'bf16': False, 'seed': 42, 'use_wandb': True, 'wandb_project': 'protein-llm-contrastive', 'wandb_entity': None, '_wandb': {}}
11
+ 2025-08-13 15:33:18,716 INFO MainThread:13510 [wandb_init.py:init():871] starting backend
12
+ 2025-08-13 15:33:18,925 INFO MainThread:13510 [wandb_init.py:init():874] sending inform_init request
13
+ 2025-08-13 15:33:18,931 INFO MainThread:13510 [wandb_init.py:init():882] backend started and connected
14
+ 2025-08-13 15:33:18,933 INFO MainThread:13510 [wandb_init.py:init():953] updated telemetry
15
+ 2025-08-13 15:33:18,961 INFO MainThread:13510 [wandb_init.py:init():977] communicating run to backend with 90.0 second timeout
16
+ 2025-08-13 15:33:25,902 INFO MainThread:13510 [wandb_init.py:init():1029] starting run threads in backend
17
+ 2025-08-13 15:33:26,011 INFO MainThread:13510 [wandb_run.py:_console_start():2494] atexit reg
18
+ 2025-08-13 15:33:26,011 INFO MainThread:13510 [wandb_run.py:_redirect():2342] redirect: wrap_raw
19
+ 2025-08-13 15:33:26,011 INFO MainThread:13510 [wandb_run.py:_redirect():2411] Wrapping output streams.
20
+ 2025-08-13 15:33:26,011 INFO MainThread:13510 [wandb_run.py:_redirect():2434] Redirects installed.
21
+ 2025-08-13 15:33:26,014 INFO MainThread:13510 [wandb_init.py:init():1075] run started, returning control to user process
22
+ 2025-08-13 15:33:36,032 INFO MainThread:13510 [wandb_run.py:_config_callback():1380] config_cb None None {'output_dir': './contrastive_outputs', 'overwrite_output_dir': False, 'do_train': False, 'do_eval': True, 'do_predict': False, 'eval_strategy': 'steps', 'prediction_loss_only': False, 'per_device_train_batch_size': 8, 'per_device_eval_batch_size': 8, 'per_gpu_train_batch_size': None, 'per_gpu_eval_batch_size': None, 'gradient_accumulation_steps': 1, 'eval_accumulation_steps': None, 'eval_delay': 0, 'torch_empty_cache_steps': None, 'learning_rate': 0.0001, 'weight_decay': 0.01, 'adam_beta1': 0.9, 'adam_beta2': 0.999, 'adam_epsilon': 1e-08, 'max_grad_norm': 1.0, 'num_train_epochs': 10, 'max_steps': -1, 'lr_scheduler_type': 'linear', 'lr_scheduler_kwargs': {}, 'warmup_ratio': 0.0, 'warmup_steps': 1000, 'log_level': 'passive', 'log_level_replica': 'warning', 'log_on_each_node': True, 'logging_dir': './contrastive_outputs/runs/Aug13_15-33-30_dsw-265304-585cc9d768-ckfd7', 'logging_strategy': 'steps', 'logging_first_step': False, 'logging_steps': 100, 'logging_nan_inf_filter': True, 'save_strategy': 'steps', 'save_steps': 1000, 'save_total_limit': 3, 'save_safetensors': True, 'save_on_each_node': False, 'save_only_model': False, 'restore_callback_states_from_checkpoint': False, 'no_cuda': False, 'use_cpu': False, 'use_mps_device': False, 'seed': 42, 'data_seed': None, 'jit_mode_eval': False, 'use_ipex': False, 'bf16': False, 'fp16': False, 'fp16_opt_level': 'O1', 'half_precision_backend': 'auto', 'bf16_full_eval': False, 'fp16_full_eval': False, 'tf32': None, 'local_rank': 0, 'ddp_backend': None, 'tpu_num_cores': None, 'tpu_metrics_debug': False, 'debug': [], 'dataloader_drop_last': False, 'eval_steps': 500, 'dataloader_num_workers': 8, 'dataloader_prefetch_factor': None, 'past_index': -1, 'run_name': None, 'disable_tqdm': False, 'remove_unused_columns': False, 'label_names': None, 'load_best_model_at_end': True, 'metric_for_best_model': 'eval_avg_recall_at_1', 
'greater_is_better': True, 'ignore_data_skip': False, 'fsdp': [], 'fsdp_min_num_params': 0, 'fsdp_config': {'min_num_params': 0, 'xla': False, 'xla_fsdp_v2': False, 'xla_fsdp_grad_ckpt': False}, 'fsdp_transformer_layer_cls_to_wrap': None, 'accelerator_config': {'split_batches': False, 'dispatch_batches': None, 'even_batches': True, 'use_seedable_sampler': True, 'non_blocking': False, 'gradient_accumulation_kwargs': None}, 'deepspeed': None, 'label_smoothing_factor': 0.0, 'optim': 'adamw_torch', 'optim_args': None, 'adafactor': False, 'group_by_length': False, 'length_column_name': 'length', 'report_to': ['wandb'], 'ddp_find_unused_parameters': False, 'ddp_bucket_cap_mb': None, 'ddp_broadcast_buffers': None, 'dataloader_pin_memory': True, 'dataloader_persistent_workers': False, 'skip_memory_metrics': True, 'use_legacy_prediction_loop': False, 'push_to_hub': False, 'resume_from_checkpoint': None, 'hub_model_id': None, 'hub_strategy': 'every_save', 'hub_token': '<HUB_TOKEN>', 'hub_private_repo': None, 'hub_always_push': False, 'hub_revision': None, 'gradient_checkpointing': False, 'gradient_checkpointing_kwargs': None, 'include_inputs_for_metrics': False, 'include_for_metrics': [], 'eval_do_concat_batches': True, 'fp16_backend': 'auto', 'push_to_hub_model_id': None, 'push_to_hub_organization': None, 'push_to_hub_token': '<PUSH_TO_HUB_TOKEN>', 'mp_parameters': '', 'auto_find_batch_size': False, 'full_determinism': False, 'torchdynamo': None, 'ray_scope': 'last', 'ddp_timeout': 1800, 'torch_compile': False, 'torch_compile_backend': None, 'torch_compile_mode': None, 'include_tokens_per_second': False, 'include_num_input_tokens_seen': False, 'neftune_noise_alpha': None, 'optim_target_modules': None, 'batch_eval_metrics': False, 'eval_on_start': False, 'use_liger_kernel': False, 'liger_kernel_config': None, 'eval_use_gather_object': False, 'average_tokens_across_devices': False, 'temperature': 0.07, 'freeze_protein_model': True, 'freeze_text_model': True, 'protein_weight': 
1.0, 'text_weight': 1.0, 'max_length_protein': 1024, 'max_length_text': 512, 'enable_ptm': True, 'ptm_weight': 1.0}
23
+ 2025-08-13 15:33:41,038 INFO MsgRouterThr:13510 [mailbox.py:close():129] [no run ID] Closing mailbox, abandoning 2 handles.
BioReason_new/wandb/run-20250811_215805-k21eogb7/files/config.yaml ADDED
@@ -0,0 +1,159 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ _wandb:
2
+ value:
3
+ cli_version: 0.21.1
4
+ e:
5
+ ryu7ghs3jgr22n9qfads9dl5ddd2hfe3:
6
+ args:
7
+ - --text_model_name
8
+ - /oss/wangyujia/BIO/construction_finetuning/alpaca/v1-20250609-141541/checkpoint-50-merged
9
+ - --protein_model_name
10
+ - /nas/shared/kilab/wangyujia/ProtT3/plm_model/esm2-150m
11
+ - --qformer_model_name
12
+ - /nas/shared/kilab/wangyujia/ProtT3/plm_model/microsoft
13
+ - --num_query_tokens
14
+ - "8"
15
+ - --dataset_name
16
+ - /nas/shared/kilab/wangyujia/ProtT3/data/SwissProtV3/train_set.jsonl
17
+ - --output_dir
18
+ - ./contrastive_outputs
19
+ - --num_epochs
20
+ - "10"
21
+ - --batch_size
22
+ - "32"
23
+ - --learning_rate
24
+ - "1e-4"
25
+ - --temperature
26
+ - "0.07"
27
+ - --freeze_protein_model
28
+ - --freeze_text_model
29
+ - --enable_ptm
30
+ - --max_length_protein
31
+ - "1024"
32
+ - --max_length_text
33
+ - "512"
34
+ - --num_workers
35
+ - "8"
36
+ - --eval_dataset
37
+ - --use_wandb
38
+ - --wandb_project
39
+ - protein-llm-contrastive
40
+ - --logging_steps
41
+ - "100"
42
+ - --eval_steps
43
+ - "500"
44
+ - --save_steps
45
+ - "1000"
46
+ codePath: wangyujia/BioReason_new/train_contrastive.py
47
+ codePathLocal: train_contrastive.py
48
+ executable: /root/miniconda3/envs/bioreason/bin/python
49
+ git:
50
+ commit: b8caf406aa1699c788f0ca6e44a1769452c317db
51
+ remote: https://github.com/PorUna-byte/PAR.git
52
+ host: dsw-265304-f8bc5ff76-4mdt5
53
+ os: Linux-5.10.134-008.18.kangaroo.al8.x86_64-x86_64-with-glibc2.35
54
+ program: /nas/shared/kilab/wangyujia/BioReason_new/train_contrastive.py
55
+ python: CPython 3.11.0
56
+ root: /nas/shared/kilab/wangyujia/BioReason_new
57
+ startedAt: "2025-08-11T13:58:05.851565Z"
58
+ writerId: ryu7ghs3jgr22n9qfads9dl5ddd2hfe3
59
+ m: []
60
+ python_version: 3.11.0
61
+ t:
62
+ "1":
63
+ - 1
64
+ - 9
65
+ - 11
66
+ - 41
67
+ - 49
68
+ - 51
69
+ - 71
70
+ - 84
71
+ - 98
72
+ - 103
73
+ "2":
74
+ - 1
75
+ - 9
76
+ - 11
77
+ - 41
78
+ - 49
79
+ - 51
80
+ - 71
81
+ - 84
82
+ - 98
83
+ - 103
84
+ "3":
85
+ - 13
86
+ - 16
87
+ "4": 3.11.0
88
+ "5": 0.21.1
89
+ "6": 4.55.0
90
+ "12": 0.21.1
91
+ "13": linux-x86_64
92
+ batch_size:
93
+ value: 32
94
+ bf16:
95
+ value: false
96
+ cache_dir:
97
+ value: /model-weights
98
+ dataset_name:
99
+ value: /nas/shared/kilab/wangyujia/ProtT3/data/SwissProtV3/train_set.jsonl
100
+ enable_ptm:
101
+ value: true
102
+ eval_dataset:
103
+ value: true
104
+ eval_steps:
105
+ value: 500
106
+ fp16:
107
+ value: false
108
+ freeze_protein_model:
109
+ value: true
110
+ freeze_text_model:
111
+ value: true
112
+ gradient_accumulation_steps:
113
+ value: 1
114
+ learning_rate:
115
+ value: 0.0001
116
+ logging_steps:
117
+ value: 100
118
+ max_length_protein:
119
+ value: 1024
120
+ max_length_text:
121
+ value: 512
122
+ num_epochs:
123
+ value: 10
124
+ num_query_tokens:
125
+ value: 8
126
+ num_workers:
127
+ value: 8
128
+ output_dir:
129
+ value: ./contrastive_outputs
130
+ protein_model_name:
131
+ value: /nas/shared/kilab/wangyujia/ProtT3/plm_model/esm2-150m
132
+ protein_weight:
133
+ value: 1
134
+ ptm_weight:
135
+ value: 1
136
+ qformer_model_name:
137
+ value: /nas/shared/kilab/wangyujia/ProtT3/plm_model/microsoft
138
+ save_steps:
139
+ value: 1000
140
+ save_total_limit:
141
+ value: 3
142
+ seed:
143
+ value: 42
144
+ temperature:
145
+ value: 0.07
146
+ text_model_name:
147
+ value: /oss/wangyujia/BIO/construction_finetuning/alpaca/v1-20250609-141541/checkpoint-50-merged
148
+ text_weight:
149
+ value: 1
150
+ use_wandb:
151
+ value: true
152
+ wandb_entity:
153
+ value: null
154
+ wandb_project:
155
+ value: protein-llm-contrastive
156
+ warmup_steps:
157
+ value: 1000
158
+ weight_decay:
159
+ value: 0.01
BioReason_new/wandb/run-20250811_215805-k21eogb7/files/output.log ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Loading model...
2
+ Loading checkpoint shards: 100%|█████████████████████████████████████████████████████████████████| 4/4 [00:01<00:00, 2.30it/s]
3
+ Some weights of EsmModel were not initialized from the model checkpoint at /nas/shared/kilab/wangyujia/ProtT3/plm_model/esm2-150m and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
4
+ You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
5
+ Loading datasets...
6
+ Traceback (most recent call last):
7
+ File "/nas/shared/kilab/wangyujia/BioReason_new/train_contrastive.py", line 549, in <module>
8
+ trainer = main(args)
9
+ ^^^^^^^^^^
10
+ File "/nas/shared/kilab/wangyujia/BioReason_new/train_contrastive.py", line 317, in main
11
+ train_dataset = load_dataset(args.dataset_name, split="train")
12
+ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
13
+ File "/root/miniconda3/envs/bioreason/lib/python3.11/site-packages/datasets/load.py", line 1392, in load_dataset
14
+ builder_instance = load_dataset_builder(
15
+ ^^^^^^^^^^^^^^^^^^^^^
16
+ File "/root/miniconda3/envs/bioreason/lib/python3.11/site-packages/datasets/load.py", line 1132, in load_dataset_builder
17
+ dataset_module = dataset_module_factory(
18
+ ^^^^^^^^^^^^^^^^^^^^^^^
19
+ File "/root/miniconda3/envs/bioreason/lib/python3.11/site-packages/datasets/load.py", line 1033, in dataset_module_factory
20
+ raise FileNotFoundError(f"Couldn't find any data file at {relative_to_absolute_path(path)}.")
21
+ FileNotFoundError: Couldn't find any data file at /nas/shared/kilab/wangyujia/ProtT3/data/SwissProtV3/train_set.jsonl.
BioReason_new/wandb/run-20250811_215805-k21eogb7/files/requirements.txt ADDED
@@ -0,0 +1,233 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ nvidia-nccl-cu12==2.26.2
2
+ cbor2==5.6.5
3
+ jupyter_server==2.16.0
4
+ nvidia-curand-cu12==10.3.7.77
5
+ bleach==6.2.0
6
+ py-cpuinfo==9.0.0
7
+ llvmlite==0.44.0
8
+ fsspec==2025.3.0
9
+ uvloop==0.21.0
10
+ rfc3986-validator==0.1.1
11
+ smmap==5.0.2
12
+ pip==25.1
13
+ compressed-tensors==0.10.2
14
+ ipython_pygments_lexers==1.1.1
15
+ fastapi-cli==0.0.8
16
+ filelock==3.18.0
17
+ msgspec==0.19.0
18
+ hjson==3.1.0
19
+ markdown-it-py==3.0.0
20
+ pyzmq==27.0.1
21
+ interegular==0.3.3
22
+ widgetsnbextension==4.0.14
23
+ vllm==0.10.0
24
+ ipykernel==6.30.1
25
+ pydantic==2.11.7
26
+ click==8.2.1
27
+ torchvision==0.22.1
28
+ fastapi-cloud-cli==0.1.5
29
+ httpcore==1.0.9
30
+ nvidia-cuda-nvrtc-cu12==12.6.77
31
+ mdurl==0.1.2
32
+ rich-toolkit==0.15.0
33
+ Pygments==2.19.2
34
+ pure_eval==0.2.3
35
+ types-python-dateutil==2.9.0.20250809
36
+ referencing==0.36.2
37
+ jupyterlab_widgets==3.0.15
38
+ typing-inspection==0.4.1
39
+ stack-data==0.6.3
40
+ jupyter_client==8.6.3
41
+ regex==2025.7.33
42
+ platformdirs==4.3.8
43
+ arrow==1.3.0
44
+ aiosignal==1.4.0
45
+ python-dateutil==2.9.0.post0
46
+ numpy==2.2.6
47
+ jupyter-lsp==2.2.6
48
+ transformers==4.55.0
49
+ mpmath==1.3.0
50
+ six==1.17.0
51
+ python-json-logger==3.3.0
52
+ distro==1.9.0
53
+ partial-json-parser==0.2.1.1.post6
54
+ bitsandbytes==0.46.1
55
+ nvidia-cusparselt-cu12==0.6.3
56
+ pandocfilters==1.5.1
57
+ pexpect==4.9.0
58
+ pydantic-extra-types==2.10.5
59
+ Jinja2==3.1.6
60
+ sentencepiece==0.2.0
61
+ uvicorn==0.35.0
62
+ babel==2.17.0
63
+ trl==0.21.0
64
+ urllib3==2.5.0
65
+ prometheus_client==0.22.1
66
+ watchfiles==1.1.0
67
+ prometheus-fastapi-instrumentator==7.1.0
68
+ jsonschema-specifications==2025.4.1
69
+ diskcache==5.6.3
70
+ webcolors==24.11.1
71
+ peft==0.17.0
72
+ jiter==0.10.0
73
+ triton==3.3.1
74
+ gitdb==4.0.12
75
+ gguf==0.17.1
76
+ safetensors==0.6.2
77
+ cloudpickle==3.1.1
78
+ multiprocess==0.70.16
79
+ aiohttp==3.12.15
80
+ tornado==6.5.2
81
+ nvidia-nvtx-cu12==12.6.77
82
+ nbclient==0.10.2
83
+ nbconvert==7.16.6
84
+ psutil==7.0.0
85
+ llguidance==0.7.30
86
+ ray==2.48.0
87
+ wcwidth==0.2.13
88
+ rignore==0.6.4
89
+ nvidia-cudnn-cu12==9.5.1.17
90
+ soupsieve==2.7
91
+ wandb==0.21.1
92
+ overrides==7.7.0
93
+ opencv-python-headless==4.12.0.88
94
+ pycparser==2.22
95
+ scipy==1.16.1
96
+ terminado==0.18.1
97
+ typer==0.16.0
98
+ parso==0.8.4
99
+ lark==1.2.2
100
+ msgpack==1.1.1
101
+ websockets==15.0.1
102
+ idna==3.10
103
+ fastrlock==0.8.3
104
+ jedi==0.19.2
105
+ accelerate==1.10.0
106
+ jupyter==1.1.1
107
+ beautifulsoup4==4.13.4
108
+ h11==0.16.0
109
+ MarkupSafe==3.0.2
110
+ python-dotenv==1.1.1
111
+ aiohappyeyeballs==2.6.1
112
+ rich==14.1.0
113
+ nbformat==5.10.4
114
+ traitlets==5.14.3
115
+ decorator==5.2.1
116
+ soxr==0.5.0.post1
117
+ propcache==0.3.2
118
+ ninja==1.11.1.4
119
+ cffi==1.17.1
120
+ cupy-cuda12x==13.5.1
121
+ pandas==2.3.1
122
+ deepspeed==0.17.4
123
+ setuptools==78.1.1
124
+ websocket-client==1.8.0
125
+ qwen-vl-utils==0.0.11
126
+ webencodings==0.5.1
127
+ httptools==0.6.4
128
+ jupyterlab==4.4.5
129
+ ptyprocess==0.7.0
130
+ shellingham==1.5.4
131
+ attrs==25.3.0
132
+ fqdn==1.5.1
133
+ huggingface-hub==0.34.4
134
+ tokenizers==0.21.4
135
+ asttokens==3.0.0
136
+ jupyter_server_terminals==0.5.3
137
+ av==15.0.0
138
+ nvidia-cuda-cupti-cu12==12.6.80
139
+ typing_extensions==4.14.1
140
+ hf-xet==1.1.7
141
+ jupyter_core==5.8.1
142
+ starlette==0.47.2
143
+ fastjsonschema==2.21.1
144
+ fastapi==0.116.1
145
+ lightning-utilities==0.15.2
146
+ jupyter-console==6.6.3
147
+ pybase64==1.4.2
148
+ jupyter-events==0.12.0
149
+ requests==2.32.4
150
+ numba==0.61.2
151
+ networkx==3.5
152
+ nvidia-cusparse-cu12==12.5.4.2
153
+ jsonpointer==3.0.0
154
+ pyarrow==21.0.0
155
+ dnspython==2.7.0
156
+ torchaudio==2.7.1
157
+ ipython==9.4.0
158
+ isoduration==20.11.0
159
+ bioreason==0.1.0
160
+ matplotlib-inline==0.1.7
161
+ packaging==25.0
162
+ xxhash==3.5.0
163
+ depyf==0.19.0
164
+ sentry-sdk==2.34.1
165
+ prompt_toolkit==3.0.51
166
+ nvidia-cublas-cu12==12.6.4.1
167
+ rfc3339-validator==0.1.4
168
+ nvidia-cufft-cu12==11.3.0.4
169
+ email_validator==2.2.0
170
+ pycountry==24.6.1
171
+ argon2-cffi==25.1.0
172
+ nvidia-cufile-cu12==1.11.1.6
173
+ frozenlist==1.7.0
174
+ json5==0.12.0
175
+ tinycss2==1.4.0
176
+ defusedxml==0.7.1
177
+ lm-format-enforcer==0.10.12
178
+ Send2Trash==1.8.3
179
+ anyio==4.10.0
180
+ rfc3987-syntax==1.1.0
181
+ pydantic_core==2.33.2
182
+ debugpy==1.8.16
183
+ async-lru==2.0.5
184
+ nvidia-cuda-runtime-cu12==12.6.77
185
+ tiktoken==0.11.0
186
+ comm==0.2.3
187
+ PyYAML==6.0.2
188
+ blake3==1.0.5
189
+ nvidia-cusolver-cu12==11.7.1.2
190
+ torch==2.7.1
191
+ torchmetrics==1.8.1
192
+ yarl==1.20.1
193
+ dill==0.3.8
194
+ wheel==0.45.1
195
+ cachetools==6.1.0
196
+ multidict==6.6.3
197
+ pytz==2025.2
198
+ pillow==11.3.0
199
+ annotated-types==0.7.0
200
+ astor==0.8.1
201
+ nest-asyncio==1.6.0
202
+ httpx==0.28.1
203
+ argon2-cffi-bindings==25.1.0
204
+ notebook_shim==0.2.4
205
+ jsonschema==4.25.0
206
+ python-multipart==0.0.20
207
+ charset-normalizer==3.4.3
208
+ tqdm==4.67.1
209
+ xformers==0.0.31
210
+ tzdata==2025.2
211
+ einops==0.8.1
212
+ mistral_common==1.8.3
213
+ jupyterlab_server==2.27.3
214
+ sympy==1.14.0
215
+ datasets==4.0.0
216
+ GitPython==3.1.45
217
+ mistune==3.1.3
218
+ ipywidgets==8.1.7
219
+ nvidia-ml-py==13.580.65
220
+ uri-template==1.3.0
221
+ notebook==7.4.5
222
+ certifi==2025.8.3
223
+ nvidia-nvjitlink-cu12==12.6.85
224
+ openai==1.90.0
225
+ xgrammar==0.1.21
226
+ executing==2.2.0
227
+ soundfile==0.13.1
228
+ jupyterlab_pygments==0.3.0
229
+ outlines_core==0.2.10
230
+ sniffio==1.3.1
231
+ pytorch-lightning==2.5.2
232
+ rpds-py==0.27.0
233
+ protobuf==6.31.1
BioReason_new/wandb/run-20250811_215805-k21eogb7/files/wandb-metadata.json ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "os": "Linux-5.10.134-008.18.kangaroo.al8.x86_64-x86_64-with-glibc2.35",
3
+ "python": "CPython 3.11.0",
4
+ "startedAt": "2025-08-11T13:58:05.851565Z",
5
+ "args": [
6
+ "--text_model_name",
7
+ "/oss/wangyujia/BIO/construction_finetuning/alpaca/v1-20250609-141541/checkpoint-50-merged",
8
+ "--protein_model_name",
9
+ "/nas/shared/kilab/wangyujia/ProtT3/plm_model/esm2-150m",
10
+ "--qformer_model_name",
11
+ "/nas/shared/kilab/wangyujia/ProtT3/plm_model/microsoft",
12
+ "--num_query_tokens",
13
+ "8",
14
+ "--dataset_name",
15
+ "/nas/shared/kilab/wangyujia/ProtT3/data/SwissProtV3/train_set.jsonl",
16
+ "--output_dir",
17
+ "./contrastive_outputs",
18
+ "--num_epochs",
19
+ "10",
20
+ "--batch_size",
21
+ "32",
22
+ "--learning_rate",
23
+ "1e-4",
24
+ "--temperature",
25
+ "0.07",
26
+ "--freeze_protein_model",
27
+ "--freeze_text_model",
28
+ "--enable_ptm",
29
+ "--max_length_protein",
30
+ "1024",
31
+ "--max_length_text",
32
+ "512",
33
+ "--num_workers",
34
+ "8",
35
+ "--eval_dataset",
36
+ "--use_wandb",
37
+ "--wandb_project",
38
+ "protein-llm-contrastive",
39
+ "--logging_steps",
40
+ "100",
41
+ "--eval_steps",
42
+ "500",
43
+ "--save_steps",
44
+ "1000"
45
+ ],
46
+ "program": "/nas/shared/kilab/wangyujia/BioReason_new/train_contrastive.py",
47
+ "codePath": "wangyujia/BioReason_new/train_contrastive.py",
48
+ "codePathLocal": "train_contrastive.py",
49
+ "git": {
50
+ "remote": "https://github.com/PorUna-byte/PAR.git",
51
+ "commit": "b8caf406aa1699c788f0ca6e44a1769452c317db"
52
+ },
53
+ "root": "/nas/shared/kilab/wangyujia/BioReason_new",
54
+ "host": "dsw-265304-f8bc5ff76-4mdt5",
55
+ "executable": "/root/miniconda3/envs/bioreason/bin/python",
56
+ "writerId": "ryu7ghs3jgr22n9qfads9dl5ddd2hfe3"
57
+ }
BioReason_new/wandb/run-20250811_215805-k21eogb7/files/wandb-summary.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"_runtime":3,"_wandb":{"runtime":3}}
BioReason_new/wandb/run-20250811_215805-k21eogb7/logs/debug-internal.log ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {"time":"2025-08-11T21:58:06.266313778+08:00","level":"INFO","msg":"stream: starting","core version":"0.21.1"}
2
+ {"time":"2025-08-11T21:58:11.67768592+08:00","level":"INFO","msg":"stream: created new stream","id":"k21eogb7"}
3
+ {"time":"2025-08-11T21:58:11.67881+08:00","level":"INFO","msg":"stream: started","id":"k21eogb7"}
4
+ {"time":"2025-08-11T21:58:11.678831626+08:00","level":"INFO","msg":"writer: started","stream_id":"k21eogb7"}
5
+ {"time":"2025-08-11T21:58:11.678836167+08:00","level":"INFO","msg":"sender: started","stream_id":"k21eogb7"}
6
+ {"time":"2025-08-11T21:58:11.678866578+08:00","level":"INFO","msg":"handler: started","stream_id":"k21eogb7"}
7
+ {"time":"2025-08-11T21:58:17.422177077+08:00","level":"INFO","msg":"stream: closing","id":"k21eogb7"}
8
+ {"time":"2025-08-11T21:58:17.894015835+08:00","level":"ERROR","msg":"request failed","error":"Put \"https://storage.googleapis.com/wandb-production.appspot.com/gia0603yucca/protein-llm-contrastive/k21eogb7/wandb-metadata.json?X-Goog-Algorithm=GOOG4-RSA-SHA256&X-Goog-Credential=gorilla-files-url-signer-man%40wandb-production.iam.gserviceaccount.com%2F20250811%2Fauto%2Fstorage%2Fgoog4_request&X-Goog-Date=20250811T135814Z&X-Goog-Expires=86399&X-Goog-Signature=ac3cb37c5d4b2fc5862a0022639cf77d48b1d5bedf77614e65d21af8685129bd5888f0c2a5a385e3890e6bfe21ebfae49170fa950c85170f44ff15f9e914184a3755ccca89484e99e1c879cf0fe2606c7e0e29a42d8203a5625e24b124c0c3b9149353e88a5877f1c69381bbacde2febade041b46fa50ba30d73fd36d4edd17f8dbbc3e6d5589bf6a016fb4d5a7b8c061530322df6e2f1d2ce72f94f80d90e592b7ff11dee46dda00343f8c5cf527bfb758d250201795f13740771c5083590019038f3b4af4250b46cce19aa2c7d05c26aa69e1e6091738740ba50a56247d537ee611862178b0ecae0c476bb79d1e77eb7d472406c0ec7371a76080c28bdf798&X-Goog-SignedHeaders=host&X-User=gia0603yucca\": read tcp 10.1.8.118:37628->142.250.73.123:443: read: connection reset by 
peer","method":"PUT","url":"https://storage.googleapis.com/wandb-production.appspot.com/gia0603yucca/protein-llm-contrastive/k21eogb7/wandb-metadata.json?X-Goog-Algorithm=GOOG4-RSA-SHA256&X-Goog-Credential=gorilla-files-url-signer-man%40wandb-production.iam.gserviceaccount.com%2F20250811%2Fauto%2Fstorage%2Fgoog4_request&X-Goog-Date=20250811T135814Z&X-Goog-Expires=86399&X-Goog-Signature=ac3cb37c5d4b2fc5862a0022639cf77d48b1d5bedf77614e65d21af8685129bd5888f0c2a5a385e3890e6bfe21ebfae49170fa950c85170f44ff15f9e914184a3755ccca89484e99e1c879cf0fe2606c7e0e29a42d8203a5625e24b124c0c3b9149353e88a5877f1c69381bbacde2febade041b46fa50ba30d73fd36d4edd17f8dbbc3e6d5589bf6a016fb4d5a7b8c061530322df6e2f1d2ce72f94f80d90e592b7ff11dee46dda00343f8c5cf527bfb758d250201795f13740771c5083590019038f3b4af4250b46cce19aa2c7d05c26aa69e1e6091738740ba50a56247d537ee611862178b0ecae0c476bb79d1e77eb7d472406c0ec7371a76080c28bdf798&X-Goog-SignedHeaders=host&X-User=gia0603yucca"}
9
+ {"time":"2025-08-11T21:58:17.907861315+08:00","level":"ERROR","msg":"request failed","error":"Put \"https://storage.googleapis.com/wandb-production.appspot.com/gia0603yucca/protein-llm-contrastive/k21eogb7/requirements.txt?X-Goog-Algorithm=GOOG4-RSA-SHA256&X-Goog-Credential=gorilla-files-url-signer-man%40wandb-production.iam.gserviceaccount.com%2F20250811%2Fauto%2Fstorage%2Fgoog4_request&X-Goog-Date=20250811T135814Z&X-Goog-Expires=86399&X-Goog-Signature=ba6b1293deff25490307723c6c45ae5a2004c6679aca654218f25fa93a552e701ef0117ff99cef40a8a866d8bfe1a7a30662ae892d1831dd9af3e37dbabb7395c8e38d0701570166647ba2ab06044cc1206ce74aac46ce4e3c16fe6ede3e068e29538f5ce545aa06ccc122ff37f2e0f230c5e9bb4c5f9e8fb37cb62a240476326613b447cf162d0a360167f2a2000f4ce82833edc6ed6fbea897f4f74b4210dcc3be2c47f92f8af94d520341231457a0a5796e6169943c0a3ac34cc0c5fc29141610ab6bb31c50000377c7bd4e21f805feb05e31212a54ca8a02be4a129bcd603dc585268f4e8ab0c5cadb4100d577cc3d6e9f38b36192b2f7222e2fe99f654a&X-Goog-SignedHeaders=host&X-User=gia0603yucca\": read tcp 10.1.8.118:37620->142.250.73.123:443: read: connection reset by 
peer","method":"PUT","url":"https://storage.googleapis.com/wandb-production.appspot.com/gia0603yucca/protein-llm-contrastive/k21eogb7/requirements.txt?X-Goog-Algorithm=GOOG4-RSA-SHA256&X-Goog-Credential=gorilla-files-url-signer-man%40wandb-production.iam.gserviceaccount.com%2F20250811%2Fauto%2Fstorage%2Fgoog4_request&X-Goog-Date=20250811T135814Z&X-Goog-Expires=86399&X-Goog-Signature=ba6b1293deff25490307723c6c45ae5a2004c6679aca654218f25fa93a552e701ef0117ff99cef40a8a866d8bfe1a7a30662ae892d1831dd9af3e37dbabb7395c8e38d0701570166647ba2ab06044cc1206ce74aac46ce4e3c16fe6ede3e068e29538f5ce545aa06ccc122ff37f2e0f230c5e9bb4c5f9e8fb37cb62a240476326613b447cf162d0a360167f2a2000f4ce82833edc6ed6fbea897f4f74b4210dcc3be2c47f92f8af94d520341231457a0a5796e6169943c0a3ac34cc0c5fc29141610ab6bb31c50000377c7bd4e21f805feb05e31212a54ca8a02be4a129bcd603dc585268f4e8ab0c5cadb4100d577cc3d6e9f38b36192b2f7222e2fe99f654a&X-Goog-SignedHeaders=host&X-User=gia0603yucca"}
10
+ {"time":"2025-08-11T21:58:20.579200782+08:00","level":"ERROR","msg":"request failed","error":"Put \"https://storage.googleapis.com/wandb-production.appspot.com/gia0603yucca/protein-llm-contrastive/k21eogb7/wandb-metadata.json?X-Goog-Algorithm=GOOG4-RSA-SHA256&X-Goog-Credential=gorilla-files-url-signer-man%40wandb-production.iam.gserviceaccount.com%2F20250811%2Fauto%2Fstorage%2Fgoog4_request&X-Goog-Date=20250811T135814Z&X-Goog-Expires=86399&X-Goog-Signature=ac3cb37c5d4b2fc5862a0022639cf77d48b1d5bedf77614e65d21af8685129bd5888f0c2a5a385e3890e6bfe21ebfae49170fa950c85170f44ff15f9e914184a3755ccca89484e99e1c879cf0fe2606c7e0e29a42d8203a5625e24b124c0c3b9149353e88a5877f1c69381bbacde2febade041b46fa50ba30d73fd36d4edd17f8dbbc3e6d5589bf6a016fb4d5a7b8c061530322df6e2f1d2ce72f94f80d90e592b7ff11dee46dda00343f8c5cf527bfb758d250201795f13740771c5083590019038f3b4af4250b46cce19aa2c7d05c26aa69e1e6091738740ba50a56247d537ee611862178b0ecae0c476bb79d1e77eb7d472406c0ec7371a76080c28bdf798&X-Goog-SignedHeaders=host&X-User=gia0603yucca\": read tcp 10.1.8.118:51916->142.250.73.91:443: read: connection reset by 
peer","method":"PUT","url":"https://storage.googleapis.com/wandb-production.appspot.com/gia0603yucca/protein-llm-contrastive/k21eogb7/wandb-metadata.json?X-Goog-Algorithm=GOOG4-RSA-SHA256&X-Goog-Credential=gorilla-files-url-signer-man%40wandb-production.iam.gserviceaccount.com%2F20250811%2Fauto%2Fstorage%2Fgoog4_request&X-Goog-Date=20250811T135814Z&X-Goog-Expires=86399&X-Goog-Signature=ac3cb37c5d4b2fc5862a0022639cf77d48b1d5bedf77614e65d21af8685129bd5888f0c2a5a385e3890e6bfe21ebfae49170fa950c85170f44ff15f9e914184a3755ccca89484e99e1c879cf0fe2606c7e0e29a42d8203a5625e24b124c0c3b9149353e88a5877f1c69381bbacde2febade041b46fa50ba30d73fd36d4edd17f8dbbc3e6d5589bf6a016fb4d5a7b8c061530322df6e2f1d2ce72f94f80d90e592b7ff11dee46dda00343f8c5cf527bfb758d250201795f13740771c5083590019038f3b4af4250b46cce19aa2c7d05c26aa69e1e6091738740ba50a56247d537ee611862178b0ecae0c476bb79d1e77eb7d472406c0ec7371a76080c28bdf798&X-Goog-SignedHeaders=host&X-User=gia0603yucca"}
11
+ {"time":"2025-08-11T21:58:20.615403659+08:00","level":"ERROR","msg":"request failed","error":"Put \"https://storage.googleapis.com/wandb-production.appspot.com/gia0603yucca/protein-llm-contrastive/k21eogb7/requirements.txt?X-Goog-Algorithm=GOOG4-RSA-SHA256&X-Goog-Credential=gorilla-files-url-signer-man%40wandb-production.iam.gserviceaccount.com%2F20250811%2Fauto%2Fstorage%2Fgoog4_request&X-Goog-Date=20250811T135814Z&X-Goog-Expires=86399&X-Goog-Signature=ba6b1293deff25490307723c6c45ae5a2004c6679aca654218f25fa93a552e701ef0117ff99cef40a8a866d8bfe1a7a30662ae892d1831dd9af3e37dbabb7395c8e38d0701570166647ba2ab06044cc1206ce74aac46ce4e3c16fe6ede3e068e29538f5ce545aa06ccc122ff37f2e0f230c5e9bb4c5f9e8fb37cb62a240476326613b447cf162d0a360167f2a2000f4ce82833edc6ed6fbea897f4f74b4210dcc3be2c47f92f8af94d520341231457a0a5796e6169943c0a3ac34cc0c5fc29141610ab6bb31c50000377c7bd4e21f805feb05e31212a54ca8a02be4a129bcd603dc585268f4e8ab0c5cadb4100d577cc3d6e9f38b36192b2f7222e2fe99f654a&X-Goog-SignedHeaders=host&X-User=gia0603yucca\": read tcp 10.1.8.118:51920->142.250.73.91:443: read: connection reset by 
peer","method":"PUT","url":"https://storage.googleapis.com/wandb-production.appspot.com/gia0603yucca/protein-llm-contrastive/k21eogb7/requirements.txt?X-Goog-Algorithm=GOOG4-RSA-SHA256&X-Goog-Credential=gorilla-files-url-signer-man%40wandb-production.iam.gserviceaccount.com%2F20250811%2Fauto%2Fstorage%2Fgoog4_request&X-Goog-Date=20250811T135814Z&X-Goog-Expires=86399&X-Goog-Signature=ba6b1293deff25490307723c6c45ae5a2004c6679aca654218f25fa93a552e701ef0117ff99cef40a8a866d8bfe1a7a30662ae892d1831dd9af3e37dbabb7395c8e38d0701570166647ba2ab06044cc1206ce74aac46ce4e3c16fe6ede3e068e29538f5ce545aa06ccc122ff37f2e0f230c5e9bb4c5f9e8fb37cb62a240476326613b447cf162d0a360167f2a2000f4ce82833edc6ed6fbea897f4f74b4210dcc3be2c47f92f8af94d520341231457a0a5796e6169943c0a3ac34cc0c5fc29141610ab6bb31c50000377c7bd4e21f805feb05e31212a54ca8a02be4a129bcd603dc585268f4e8ab0c5cadb4100d577cc3d6e9f38b36192b2f7222e2fe99f654a&X-Goog-SignedHeaders=host&X-User=gia0603yucca"}
12
+ {"time":"2025-08-11T21:58:35.386126561+08:00","level":"INFO","msg":"fileTransfer: Close: file transfer manager closed"}
13
+ {"time":"2025-08-11T21:59:02.436399111+08:00","level":"INFO","msg":"handler: closed","stream_id":"k21eogb7"}
14
+ {"time":"2025-08-11T21:59:02.440170612+08:00","level":"INFO","msg":"sender: closed","stream_id":"k21eogb7"}
15
+ {"time":"2025-08-11T21:59:02.440183082+08:00","level":"INFO","msg":"stream: closed","id":"k21eogb7"}
BioReason_new/wandb/run-20250811_215805-k21eogb7/logs/debug.log ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ 2025-08-11 21:58:06,022 INFO MainThread:79345 [wandb_setup.py:_flush():80] Current SDK version is 0.21.1
2
+ 2025-08-11 21:58:06,022 INFO MainThread:79345 [wandb_setup.py:_flush():80] Configure stats pid to 79345
3
+ 2025-08-11 21:58:06,022 INFO MainThread:79345 [wandb_setup.py:_flush():80] Loading settings from /root/.config/wandb/settings
4
+ 2025-08-11 21:58:06,022 INFO MainThread:79345 [wandb_setup.py:_flush():80] Loading settings from /nas/shared/kilab/wangyujia/BioReason_new/wandb/settings
5
+ 2025-08-11 21:58:06,022 INFO MainThread:79345 [wandb_setup.py:_flush():80] Loading settings from environment variables
6
+ 2025-08-11 21:58:06,022 INFO MainThread:79345 [wandb_init.py:setup_run_log_directory():703] Logging user logs to /nas/shared/kilab/wangyujia/BioReason_new/wandb/run-20250811_215805-k21eogb7/logs/debug.log
7
+ 2025-08-11 21:58:06,022 INFO MainThread:79345 [wandb_init.py:setup_run_log_directory():704] Logging internal logs to /nas/shared/kilab/wangyujia/BioReason_new/wandb/run-20250811_215805-k21eogb7/logs/debug-internal.log
8
+ 2025-08-11 21:58:06,022 INFO MainThread:79345 [wandb_init.py:init():830] calling init triggers
9
+ 2025-08-11 21:58:06,022 INFO MainThread:79345 [wandb_init.py:init():835] wandb.init called with sweep_config: {}
10
+ config: {'text_model_name': '/oss/wangyujia/BIO/construction_finetuning/alpaca/v1-20250609-141541/checkpoint-50-merged', 'protein_model_name': '/nas/shared/kilab/wangyujia/ProtT3/plm_model/esm2-150m', 'qformer_model_name': '/nas/shared/kilab/wangyujia/ProtT3/plm_model/microsoft', 'cache_dir': '/model-weights', 'num_query_tokens': 8, 'dataset_name': '/nas/shared/kilab/wangyujia/ProtT3/data/SwissProtV3/train_set.jsonl', 'eval_dataset': True, 'output_dir': './contrastive_outputs', 'num_epochs': 10, 'batch_size': 32, 'learning_rate': 0.0001, 'weight_decay': 0.01, 'warmup_steps': 1000, 'gradient_accumulation_steps': 1, 'temperature': 0.07, 'freeze_protein_model': True, 'freeze_text_model': True, 'protein_weight': 1.0, 'text_weight': 1.0, 'enable_ptm': True, 'ptm_weight': 1.0, 'max_length_protein': 1024, 'max_length_text': 512, 'num_workers': 8, 'logging_steps': 100, 'eval_steps': 500, 'save_steps': 1000, 'save_total_limit': 3, 'fp16': False, 'bf16': False, 'seed': 42, 'use_wandb': True, 'wandb_project': 'protein-llm-contrastive', 'wandb_entity': None, '_wandb': {}}
11
+ 2025-08-11 21:58:06,022 INFO MainThread:79345 [wandb_init.py:init():871] starting backend
12
+ 2025-08-11 21:58:06,233 INFO MainThread:79345 [wandb_init.py:init():874] sending inform_init request
13
+ 2025-08-11 21:58:06,259 INFO MainThread:79345 [wandb_init.py:init():882] backend started and connected
14
+ 2025-08-11 21:58:06,263 INFO MainThread:79345 [wandb_init.py:init():953] updated telemetry
15
+ 2025-08-11 21:58:06,327 INFO MainThread:79345 [wandb_init.py:init():977] communicating run to backend with 90.0 second timeout
16
+ 2025-08-11 21:58:13,781 INFO MainThread:79345 [wandb_init.py:init():1029] starting run threads in backend
17
+ 2025-08-11 21:58:13,892 INFO MainThread:79345 [wandb_run.py:_console_start():2494] atexit reg
18
+ 2025-08-11 21:58:13,892 INFO MainThread:79345 [wandb_run.py:_redirect():2342] redirect: wrap_raw
19
+ 2025-08-11 21:58:13,892 INFO MainThread:79345 [wandb_run.py:_redirect():2411] Wrapping output streams.
20
+ 2025-08-11 21:58:13,892 INFO MainThread:79345 [wandb_run.py:_redirect():2434] Redirects installed.
21
+ 2025-08-11 21:58:13,895 INFO MainThread:79345 [wandb_init.py:init():1075] run started, returning control to user process
22
+ 2025-08-11 21:58:17,421 INFO MsgRouterThr:79345 [mailbox.py:close():129] [no run ID] Closing mailbox, abandoning 2 handles.
BioReason_new/wandb/run-20250811_215805-k21eogb7/run-k21eogb7.wandb ADDED
Binary file (7.1 kB). View file
 
BioReason_new/wandb/run-20250811_220309-2qgjwsxa/files/config.yaml ADDED
@@ -0,0 +1,195 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ _wandb:
2
+ value:
3
+ cli_version: 0.21.1
4
+ e:
5
+ r185oiuz6xjarzg7yyfap3b9flv6ll88:
6
+ args:
7
+ - --text_model_name
8
+ - /oss/wangyujia/BIO/construction_finetuning/alpaca/v1-20250609-141541/checkpoint-50-merged
9
+ - --protein_model_name
10
+ - /nas/shared/kilab/wangyujia/ProtT3/plm_model/esm2-150m
11
+ - --qformer_model_name
12
+ - /nas/shared/kilab/wangyujia/ProtT3/plm_model/microsoft
13
+ - --num_query_tokens
14
+ - "8"
15
+ - --dataset_name
16
+ - /nas/shared/kilab/wangyujia/ProtT3/data/SwissProtV3/train_set.jsonl
17
+ - --output_dir
18
+ - ./contrastive_outputs
19
+ - --num_epochs
20
+ - "10"
21
+ - --batch_size
22
+ - "32"
23
+ - --learning_rate
24
+ - "1e-4"
25
+ - --temperature
26
+ - "0.07"
27
+ - --freeze_protein_model
28
+ - --freeze_text_model
29
+ - --enable_ptm
30
+ - --max_length_protein
31
+ - "1024"
32
+ - --max_length_text
33
+ - "512"
34
+ - --num_workers
35
+ - "8"
36
+ - --eval_dataset
37
+ - --use_wandb
38
+ - --wandb_project
39
+ - protein-llm-contrastive
40
+ - --logging_steps
41
+ - "100"
42
+ - --eval_steps
43
+ - "500"
44
+ - --save_steps
45
+ - "1000"
46
+ codePath: wangyujia/BioReason_new/train_contrastive.py
47
+ codePathLocal: train_contrastive.py
48
+ cpu_count: 64
49
+ cpu_count_logical: 64
50
+ cudaVersion: "12.1"
51
+ disk:
52
+ /:
53
+ total: "1623302262784"
54
+ used: "28193923072"
55
+ executable: /root/miniconda3/envs/bioreason/bin/python
56
+ git:
57
+ commit: b8caf406aa1699c788f0ca6e44a1769452c317db
58
+ remote: https://github.com/PorUna-byte/PAR.git
59
+ gpu: NVIDIA A800-SXM4-80GB
60
+ gpu_count: 8
61
+ gpu_nvidia:
62
+ - architecture: Ampere
63
+ name: NVIDIA A800-SXM4-80GB
64
+ uuid: GPU-71607f78-ad31-1ea4-19c1-908e3e31aaf1
65
+ - architecture: Ampere
66
+ name: NVIDIA A800-SXM4-80GB
67
+ uuid: GPU-92b7dbbd-7ef5-3c5f-ce1c-1d179d7fa587
68
+ - architecture: Ampere
69
+ name: NVIDIA A800-SXM4-80GB
70
+ uuid: GPU-bbc35439-ad79-578b-381b-aba6f0cc0168
71
+ - architecture: Ampere
72
+ name: NVIDIA A800-SXM4-80GB
73
+ uuid: GPU-e492e147-ca2e-76f2-85da-4e08e4deeb14
74
+ - architecture: Ampere
75
+ name: NVIDIA A800-SXM4-80GB
76
+ uuid: GPU-8c4f8e67-4b52-5107-3095-0f007e6378ac
77
+ - architecture: Ampere
78
+ name: NVIDIA A800-SXM4-80GB
79
+ uuid: GPU-7063f0b9-4ca2-6a72-522d-1262899ac5ad
80
+ - architecture: Ampere
81
+ name: NVIDIA A800-SXM4-80GB
82
+ uuid: GPU-3b6e9a37-bcf3-387c-7874-4f8de4abd115
83
+ - architecture: Ampere
84
+ name: NVIDIA A800-SXM4-80GB
85
+ uuid: GPU-92456839-e814-7be9-6817-f3e8da8aa80c
86
+ host: dsw-265304-f8bc5ff76-4mdt5
87
+ memory:
88
+ total: "549755813888"
89
+ os: Linux-5.10.134-008.18.kangaroo.al8.x86_64-x86_64-with-glibc2.35
90
+ program: /nas/shared/kilab/wangyujia/BioReason_new/train_contrastive.py
91
+ python: CPython 3.11.0
92
+ root: /nas/shared/kilab/wangyujia/BioReason_new
93
+ startedAt: "2025-08-11T14:03:09.687288Z"
94
+ writerId: r185oiuz6xjarzg7yyfap3b9flv6ll88
95
+ m: []
96
+ python_version: 3.11.0
97
+ t:
98
+ "1":
99
+ - 1
100
+ - 9
101
+ - 11
102
+ - 41
103
+ - 49
104
+ - 51
105
+ - 71
106
+ - 84
107
+ - 98
108
+ - 103
109
+ "2":
110
+ - 1
111
+ - 9
112
+ - 11
113
+ - 41
114
+ - 49
115
+ - 51
116
+ - 71
117
+ - 84
118
+ - 98
119
+ - 103
120
+ "3":
121
+ - 13
122
+ - 16
123
+ "4": 3.11.0
124
+ "5": 0.21.1
125
+ "6": 4.55.0
126
+ "12": 0.21.1
127
+ "13": linux-x86_64
128
+ batch_size:
129
+ value: 32
130
+ bf16:
131
+ value: false
132
+ cache_dir:
133
+ value: /model-weights
134
+ dataset_name:
135
+ value: /nas/shared/kilab/wangyujia/ProtT3/data/SwissProtV3/train_set.jsonl
136
+ enable_ptm:
137
+ value: true
138
+ eval_dataset:
139
+ value: true
140
+ eval_steps:
141
+ value: 500
142
+ fp16:
143
+ value: false
144
+ freeze_protein_model:
145
+ value: true
146
+ freeze_text_model:
147
+ value: true
148
+ gradient_accumulation_steps:
149
+ value: 1
150
+ learning_rate:
151
+ value: 0.0001
152
+ logging_steps:
153
+ value: 100
154
+ max_length_protein:
155
+ value: 1024
156
+ max_length_text:
157
+ value: 512
158
+ num_epochs:
159
+ value: 10
160
+ num_query_tokens:
161
+ value: 8
162
+ num_workers:
163
+ value: 8
164
+ output_dir:
165
+ value: ./contrastive_outputs
166
+ protein_model_name:
167
+ value: /nas/shared/kilab/wangyujia/ProtT3/plm_model/esm2-150m
168
+ protein_weight:
169
+ value: 1
170
+ ptm_weight:
171
+ value: 1
172
+ qformer_model_name:
173
+ value: /nas/shared/kilab/wangyujia/ProtT3/plm_model/microsoft
174
+ save_steps:
175
+ value: 1000
176
+ save_total_limit:
177
+ value: 3
178
+ seed:
179
+ value: 42
180
+ temperature:
181
+ value: 0.07
182
+ text_model_name:
183
+ value: /oss/wangyujia/BIO/construction_finetuning/alpaca/v1-20250609-141541/checkpoint-50-merged
184
+ text_weight:
185
+ value: 1
186
+ use_wandb:
187
+ value: true
188
+ wandb_entity:
189
+ value: null
190
+ wandb_project:
191
+ value: protein-llm-contrastive
192
+ warmup_steps:
193
+ value: 1000
194
+ weight_decay:
195
+ value: 0.01
BioReason_new/wandb/run-20250811_220309-2qgjwsxa/files/output.log ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Loading model...
2
+ Loading checkpoint shards: 100%|█████████████████████████████████████████████████████████████████| 4/4 [00:01<00:00, 2.49it/s]
3
+ Some weights of EsmModel were not initialized from the model checkpoint at /nas/shared/kilab/wangyujia/ProtT3/plm_model/esm2-150m and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
4
+ You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
5
+ Loading datasets...
6
+ Traceback (most recent call last):
7
+ File "/nas/shared/kilab/wangyujia/BioReason_new/train_contrastive.py", line 549, in <module>
8
+ trainer = main(args)
9
+ ^^^^^^^^^^
10
+ File "/nas/shared/kilab/wangyujia/BioReason_new/train_contrastive.py", line 317, in main
11
+ train_dataset = load_dataset("json", args.dataset_name, split="train")
12
+ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
13
+ File "/root/miniconda3/envs/bioreason/lib/python3.11/site-packages/datasets/load.py", line 1392, in load_dataset
14
+ builder_instance = load_dataset_builder(
15
+ ^^^^^^^^^^^^^^^^^^^^^
16
+ File "/root/miniconda3/envs/bioreason/lib/python3.11/site-packages/datasets/load.py", line 1166, in load_dataset_builder
17
+ builder_instance: DatasetBuilder = builder_cls(
18
+ ^^^^^^^^^^^^
19
+ File "/root/miniconda3/envs/bioreason/lib/python3.11/site-packages/datasets/builder.py", line 343, in __init__
20
+ self.config, self.config_id = self._create_builder_config(
21
+ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
22
+ File "/root/miniconda3/envs/bioreason/lib/python3.11/site-packages/datasets/builder.py", line 552, in _create_builder_config
23
+ builder_config = self.BUILDER_CONFIG_CLASS(**config_kwargs)
24
+ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
25
+ File "<string>", line 16, in __init__
26
+ File "/root/miniconda3/envs/bioreason/lib/python3.11/site-packages/datasets/packaged_modules/json/json.py", line 55, in __post_init__
27
+ super().__post_init__()
28
+ File "/root/miniconda3/envs/bioreason/lib/python3.11/site-packages/datasets/builder.py", line 126, in __post_init__
29
+ raise InvalidConfigName(
30
+ datasets.builder.InvalidConfigName: Bad characters from black list '<>:/\|?*' found in '/nas/shared/kilab/wangyujia/ProtT3/data/SwissProtV3/train_set.jsonl'. They could create issues when creating a directory for this config on Windows filesystem.
BioReason_new/wandb/run-20250811_220309-2qgjwsxa/files/requirements.txt ADDED
@@ -0,0 +1,233 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ nvidia-nccl-cu12==2.26.2
2
+ cbor2==5.6.5
3
+ jupyter_server==2.16.0
4
+ nvidia-curand-cu12==10.3.7.77
5
+ bleach==6.2.0
6
+ py-cpuinfo==9.0.0
7
+ llvmlite==0.44.0
8
+ fsspec==2025.3.0
9
+ uvloop==0.21.0
10
+ rfc3986-validator==0.1.1
11
+ smmap==5.0.2
12
+ pip==25.1
13
+ compressed-tensors==0.10.2
14
+ ipython_pygments_lexers==1.1.1
15
+ fastapi-cli==0.0.8
16
+ filelock==3.18.0
17
+ msgspec==0.19.0
18
+ hjson==3.1.0
19
+ markdown-it-py==3.0.0
20
+ pyzmq==27.0.1
21
+ interegular==0.3.3
22
+ widgetsnbextension==4.0.14
23
+ vllm==0.10.0
24
+ ipykernel==6.30.1
25
+ pydantic==2.11.7
26
+ click==8.2.1
27
+ torchvision==0.22.1
28
+ fastapi-cloud-cli==0.1.5
29
+ httpcore==1.0.9
30
+ nvidia-cuda-nvrtc-cu12==12.6.77
31
+ mdurl==0.1.2
32
+ rich-toolkit==0.15.0
33
+ Pygments==2.19.2
34
+ pure_eval==0.2.3
35
+ types-python-dateutil==2.9.0.20250809
36
+ referencing==0.36.2
37
+ jupyterlab_widgets==3.0.15
38
+ typing-inspection==0.4.1
39
+ stack-data==0.6.3
40
+ jupyter_client==8.6.3
41
+ regex==2025.7.33
42
+ platformdirs==4.3.8
43
+ arrow==1.3.0
44
+ aiosignal==1.4.0
45
+ python-dateutil==2.9.0.post0
46
+ numpy==2.2.6
47
+ jupyter-lsp==2.2.6
48
+ transformers==4.55.0
49
+ mpmath==1.3.0
50
+ six==1.17.0
51
+ python-json-logger==3.3.0
52
+ distro==1.9.0
53
+ partial-json-parser==0.2.1.1.post6
54
+ bitsandbytes==0.46.1
55
+ nvidia-cusparselt-cu12==0.6.3
56
+ pandocfilters==1.5.1
57
+ pexpect==4.9.0
58
+ pydantic-extra-types==2.10.5
59
+ Jinja2==3.1.6
60
+ sentencepiece==0.2.0
61
+ uvicorn==0.35.0
62
+ babel==2.17.0
63
+ trl==0.21.0
64
+ urllib3==2.5.0
65
+ prometheus_client==0.22.1
66
+ watchfiles==1.1.0
67
+ prometheus-fastapi-instrumentator==7.1.0
68
+ jsonschema-specifications==2025.4.1
69
+ diskcache==5.6.3
70
+ webcolors==24.11.1
71
+ peft==0.17.0
72
+ jiter==0.10.0
73
+ triton==3.3.1
74
+ gitdb==4.0.12
75
+ gguf==0.17.1
76
+ safetensors==0.6.2
77
+ cloudpickle==3.1.1
78
+ multiprocess==0.70.16
79
+ aiohttp==3.12.15
80
+ tornado==6.5.2
81
+ nvidia-nvtx-cu12==12.6.77
82
+ nbclient==0.10.2
83
+ nbconvert==7.16.6
84
+ psutil==7.0.0
85
+ llguidance==0.7.30
86
+ ray==2.48.0
87
+ wcwidth==0.2.13
88
+ rignore==0.6.4
89
+ nvidia-cudnn-cu12==9.5.1.17
90
+ soupsieve==2.7
91
+ wandb==0.21.1
92
+ overrides==7.7.0
93
+ opencv-python-headless==4.12.0.88
94
+ pycparser==2.22
95
+ scipy==1.16.1
96
+ terminado==0.18.1
97
+ typer==0.16.0
98
+ parso==0.8.4
99
+ lark==1.2.2
100
+ msgpack==1.1.1
101
+ websockets==15.0.1
102
+ idna==3.10
103
+ fastrlock==0.8.3
104
+ jedi==0.19.2
105
+ accelerate==1.10.0
106
+ jupyter==1.1.1
107
+ beautifulsoup4==4.13.4
108
+ h11==0.16.0
109
+ MarkupSafe==3.0.2
110
+ python-dotenv==1.1.1
111
+ aiohappyeyeballs==2.6.1
112
+ rich==14.1.0
113
+ nbformat==5.10.4
114
+ traitlets==5.14.3
115
+ decorator==5.2.1
116
+ soxr==0.5.0.post1
117
+ propcache==0.3.2
118
+ ninja==1.11.1.4
119
+ cffi==1.17.1
120
+ cupy-cuda12x==13.5.1
121
+ pandas==2.3.1
122
+ deepspeed==0.17.4
123
+ setuptools==78.1.1
124
+ websocket-client==1.8.0
125
+ qwen-vl-utils==0.0.11
126
+ webencodings==0.5.1
127
+ httptools==0.6.4
128
+ jupyterlab==4.4.5
129
+ ptyprocess==0.7.0
130
+ shellingham==1.5.4
131
+ attrs==25.3.0
132
+ fqdn==1.5.1
133
+ huggingface-hub==0.34.4
134
+ tokenizers==0.21.4
135
+ asttokens==3.0.0
136
+ jupyter_server_terminals==0.5.3
137
+ av==15.0.0
138
+ nvidia-cuda-cupti-cu12==12.6.80
139
+ typing_extensions==4.14.1
140
+ hf-xet==1.1.7
141
+ jupyter_core==5.8.1
142
+ starlette==0.47.2
143
+ fastjsonschema==2.21.1
144
+ fastapi==0.116.1
145
+ lightning-utilities==0.15.2
146
+ jupyter-console==6.6.3
147
+ pybase64==1.4.2
148
+ jupyter-events==0.12.0
149
+ requests==2.32.4
150
+ numba==0.61.2
151
+ networkx==3.5
152
+ nvidia-cusparse-cu12==12.5.4.2
153
+ jsonpointer==3.0.0
154
+ pyarrow==21.0.0
155
+ dnspython==2.7.0
156
+ torchaudio==2.7.1
157
+ ipython==9.4.0
158
+ isoduration==20.11.0
159
+ bioreason==0.1.0
160
+ matplotlib-inline==0.1.7
161
+ packaging==25.0
162
+ xxhash==3.5.0
163
+ depyf==0.19.0
164
+ sentry-sdk==2.34.1
165
+ prompt_toolkit==3.0.51
166
+ nvidia-cublas-cu12==12.6.4.1
167
+ rfc3339-validator==0.1.4
168
+ nvidia-cufft-cu12==11.3.0.4
169
+ email_validator==2.2.0
170
+ pycountry==24.6.1
171
+ argon2-cffi==25.1.0
172
+ nvidia-cufile-cu12==1.11.1.6
173
+ frozenlist==1.7.0
174
+ json5==0.12.0
175
+ tinycss2==1.4.0
176
+ defusedxml==0.7.1
177
+ lm-format-enforcer==0.10.12
178
+ Send2Trash==1.8.3
179
+ anyio==4.10.0
180
+ rfc3987-syntax==1.1.0
181
+ pydantic_core==2.33.2
182
+ debugpy==1.8.16
183
+ async-lru==2.0.5
184
+ nvidia-cuda-runtime-cu12==12.6.77
185
+ tiktoken==0.11.0
186
+ comm==0.2.3
187
+ PyYAML==6.0.2
188
+ blake3==1.0.5
189
+ nvidia-cusolver-cu12==11.7.1.2
190
+ torch==2.7.1
191
+ torchmetrics==1.8.1
192
+ yarl==1.20.1
193
+ dill==0.3.8
194
+ wheel==0.45.1
195
+ cachetools==6.1.0
196
+ multidict==6.6.3
197
+ pytz==2025.2
198
+ pillow==11.3.0
199
+ annotated-types==0.7.0
200
+ astor==0.8.1
201
+ nest-asyncio==1.6.0
202
+ httpx==0.28.1
203
+ argon2-cffi-bindings==25.1.0
204
+ notebook_shim==0.2.4
205
+ jsonschema==4.25.0
206
+ python-multipart==0.0.20
207
+ charset-normalizer==3.4.3
208
+ tqdm==4.67.1
209
+ xformers==0.0.31
210
+ tzdata==2025.2
211
+ einops==0.8.1
212
+ mistral_common==1.8.3
213
+ jupyterlab_server==2.27.3
214
+ sympy==1.14.0
215
+ datasets==4.0.0
216
+ GitPython==3.1.45
217
+ mistune==3.1.3
218
+ ipywidgets==8.1.7
219
+ nvidia-ml-py==13.580.65
220
+ uri-template==1.3.0
221
+ notebook==7.4.5
222
+ certifi==2025.8.3
223
+ nvidia-nvjitlink-cu12==12.6.85
224
+ openai==1.90.0
225
+ xgrammar==0.1.21
226
+ executing==2.2.0
227
+ soundfile==0.13.1
228
+ jupyterlab_pygments==0.3.0
229
+ outlines_core==0.2.10
230
+ sniffio==1.3.1
231
+ pytorch-lightning==2.5.2
232
+ rpds-py==0.27.0
233
+ protobuf==6.31.1
BioReason_new/wandb/run-20250811_220309-2qgjwsxa/files/wandb-metadata.json ADDED
@@ -0,0 +1,113 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "os": "Linux-5.10.134-008.18.kangaroo.al8.x86_64-x86_64-with-glibc2.35",
3
+ "python": "CPython 3.11.0",
4
+ "startedAt": "2025-08-11T14:03:09.687288Z",
5
+ "args": [
6
+ "--text_model_name",
7
+ "/oss/wangyujia/BIO/construction_finetuning/alpaca/v1-20250609-141541/checkpoint-50-merged",
8
+ "--protein_model_name",
9
+ "/nas/shared/kilab/wangyujia/ProtT3/plm_model/esm2-150m",
10
+ "--qformer_model_name",
11
+ "/nas/shared/kilab/wangyujia/ProtT3/plm_model/microsoft",
12
+ "--num_query_tokens",
13
+ "8",
14
+ "--dataset_name",
15
+ "/nas/shared/kilab/wangyujia/ProtT3/data/SwissProtV3/train_set.jsonl",
16
+ "--output_dir",
17
+ "./contrastive_outputs",
18
+ "--num_epochs",
19
+ "10",
20
+ "--batch_size",
21
+ "32",
22
+ "--learning_rate",
23
+ "1e-4",
24
+ "--temperature",
25
+ "0.07",
26
+ "--freeze_protein_model",
27
+ "--freeze_text_model",
28
+ "--enable_ptm",
29
+ "--max_length_protein",
30
+ "1024",
31
+ "--max_length_text",
32
+ "512",
33
+ "--num_workers",
34
+ "8",
35
+ "--eval_dataset",
36
+ "--use_wandb",
37
+ "--wandb_project",
38
+ "protein-llm-contrastive",
39
+ "--logging_steps",
40
+ "100",
41
+ "--eval_steps",
42
+ "500",
43
+ "--save_steps",
44
+ "1000"
45
+ ],
46
+ "program": "/nas/shared/kilab/wangyujia/BioReason_new/train_contrastive.py",
47
+ "codePath": "wangyujia/BioReason_new/train_contrastive.py",
48
+ "codePathLocal": "train_contrastive.py",
49
+ "git": {
50
+ "remote": "https://github.com/PorUna-byte/PAR.git",
51
+ "commit": "b8caf406aa1699c788f0ca6e44a1769452c317db"
52
+ },
53
+ "root": "/nas/shared/kilab/wangyujia/BioReason_new",
54
+ "host": "dsw-265304-f8bc5ff76-4mdt5",
55
+ "executable": "/root/miniconda3/envs/bioreason/bin/python",
56
+ "cpu_count": 64,
57
+ "cpu_count_logical": 64,
58
+ "gpu": "NVIDIA A800-SXM4-80GB",
59
+ "gpu_count": 8,
60
+ "disk": {
61
+ "/": {
62
+ "total": "1623302262784",
63
+ "used": "28193923072"
64
+ }
65
+ },
66
+ "memory": {
67
+ "total": "549755813888"
68
+ },
69
+ "gpu_nvidia": [
70
+ {
71
+ "name": "NVIDIA A800-SXM4-80GB",
72
+ "architecture": "Ampere",
73
+ "uuid": "GPU-71607f78-ad31-1ea4-19c1-908e3e31aaf1"
74
+ },
75
+ {
76
+ "name": "NVIDIA A800-SXM4-80GB",
77
+ "architecture": "Ampere",
78
+ "uuid": "GPU-92b7dbbd-7ef5-3c5f-ce1c-1d179d7fa587"
79
+ },
80
+ {
81
+ "name": "NVIDIA A800-SXM4-80GB",
82
+ "architecture": "Ampere",
83
+ "uuid": "GPU-bbc35439-ad79-578b-381b-aba6f0cc0168"
84
+ },
85
+ {
86
+ "name": "NVIDIA A800-SXM4-80GB",
87
+ "architecture": "Ampere",
88
+ "uuid": "GPU-e492e147-ca2e-76f2-85da-4e08e4deeb14"
89
+ },
90
+ {
91
+ "name": "NVIDIA A800-SXM4-80GB",
92
+ "architecture": "Ampere",
93
+ "uuid": "GPU-8c4f8e67-4b52-5107-3095-0f007e6378ac"
94
+ },
95
+ {
96
+ "name": "NVIDIA A800-SXM4-80GB",
97
+ "architecture": "Ampere",
98
+ "uuid": "GPU-7063f0b9-4ca2-6a72-522d-1262899ac5ad"
99
+ },
100
+ {
101
+ "name": "NVIDIA A800-SXM4-80GB",
102
+ "architecture": "Ampere",
103
+ "uuid": "GPU-3b6e9a37-bcf3-387c-7874-4f8de4abd115"
104
+ },
105
+ {
106
+ "name": "NVIDIA A800-SXM4-80GB",
107
+ "architecture": "Ampere",
108
+ "uuid": "GPU-92456839-e814-7be9-6817-f3e8da8aa80c"
109
+ }
110
+ ],
111
+ "cudaVersion": "12.1",
112
+ "writerId": "r185oiuz6xjarzg7yyfap3b9flv6ll88"
113
+ }