yuccaaa committed
Commit 4d12519 · verified · Parent: a17e46e

Add files using upload-large-folder tool

This view is limited to 50 files because it contains too many changes. See raw diff.
Files changed (50):
  1. .gitattributes +15 -0
  2. all_checkpoints/stage2_08011616_2datasets_withoutpretrain/last.ckpt/checkpoint/mp_rank_00_model_states.pt +3 -0
  3. all_checkpoints/stage2_08011616_2datasets_withoutpretrain/wandb/run-20250801_162632-790wkw0g/files/output.log +0 -0
  4. all_checkpoints/stage2_08011616_2datasets_withoutpretrain/wandb/run-20250801_162632-790wkw0g/files/requirements.txt +225 -0
  5. all_checkpoints/stage2_08011616_2datasets_withoutpretrain/wandb/run-20250801_162632-790wkw0g/files/wandb-metadata.json +108 -0
  6. all_checkpoints/stage2_08011616_2datasets_withoutpretrain/wandb/run-20250801_162632-790wkw0g/files/wandb-summary.json +1 -0
  7. all_checkpoints/stage2_08011616_2datasets_withoutpretrain/wandb/run-20250801_162632-790wkw0g/logs/debug-internal.log +166 -0
  8. all_checkpoints/stage2_08011616_2datasets_withoutpretrain/wandb/run-20250801_162632-790wkw0g/logs/debug.log +24 -0
  9. all_checkpoints/stage2_08011616_2datasets_withoutpretrain/wandb/run-20250801_162632-790wkw0g/run-790wkw0g.wandb +3 -0
  10. data/.gitignore +2 -0
  11. data/OntoProteinDatasetV2/test.txt +3 -0
  12. data/OntoProteinDatasetV2/train.txt +3 -0
  13. data/OntoProteinDatasetV2/valid.txt +3 -0
  14. data/PDBDataset/abstract.json +3 -0
  15. data/PDBDataset/q_types.txt +30 -0
  16. data/PDBDataset/qa_all.json +3 -0
  17. data/PDBDataset/test.txt +0 -0
  18. data/PDBDataset/train.txt +3 -0
  19. data/PDBDataset/val.txt +0 -0
  20. data/SwissProtV3/test_set.jsonl +0 -0
  21. data/SwissProtV3/train_set.jsonl +3 -0
  22. data/SwissProtV3/valid_set.jsonl +0 -0
  23. data/protein-molecule/protein-text.zip +3 -0
  24. data/protein-text/eval_assist.zipg3ebgjl7.tmp +3 -0
  25. data/protein-text/eval_assist.ziphwjr8q2y.tmp +3 -0
  26. data/protein-text/eval_assist.zipzh1pdmj_.tmp +3 -0
  27. data_provider/__pycache__/bindingdb.cpython-310.pyc +0 -0
  28. data_provider/__pycache__/go.cpython-310.pyc +0 -0
  29. data_provider/__pycache__/metalIonbinding.cpython-310.pyc +0 -0
  30. data_provider/__pycache__/mutation.cpython-310.pyc +0 -0
  31. data_provider/__pycache__/production.cpython-310.pyc +0 -0
  32. data_provider/__pycache__/prot_qa_dm.cpython-310.pyc +0 -0
  33. data_provider/__pycache__/prot_qa_dm.cpython-311.pyc +0 -0
  34. data_provider/__pycache__/stage1_dm.cpython-310.pyc +0 -0
  35. data_provider/__pycache__/stage1_dm.cpython-311.pyc +0 -0
  36. data_provider/__pycache__/stage2_dm.cpython-310.pyc +0 -0
  37. data_provider/__pycache__/stage3_dm.cpython-310.pyc +0 -0
  38. data_provider/__pycache__/stage3_dm.cpython-311.pyc +0 -0
  39. data_provider/bindingdb.py +62 -0
  40. data_provider/gal_helpers.py +45 -0
  41. data_provider/go.py +237 -0
  42. data_provider/llm_tuning_dm.py +261 -0
  43. data_provider/llm_tuning_prot_qa_dm.py +164 -0
  44. data_provider/metalIonbinding.py +63 -0
  45. data_provider/mutation.py +119 -0
  46. data_provider/production.py +237 -0
  47. data_provider/prot_qa_dm.py +299 -0
  48. data_provider/proteinchat_dm.py +254 -0
  49. data_provider/stage1_dm.py +539 -0
  50. data_provider/stage2_dm.py +386 -0
.gitattributes CHANGED
@@ -39,3 +39,18 @@ all_checkpoints/stage2_07070337_2datasets_noconstruct/wandb/run-20250707_041231-
  all_checkpoints/stage2_07070513_2datasets_construct/wandb/run-20250707_052104-615z4bme/run-615z4bme.wandb filter=lfs diff=lfs merge=lfs -text
  all_checkpoints/stage2_07070513_2datasets_construct/wandb/run-20250707_053222-9cjzn0v3/run-9cjzn0v3.wandb filter=lfs diff=lfs merge=lfs -text
  all_checkpoints/stage2_07301646_2datasets_construct/wandb/run-20250730_175623-pbf2bxo6/run-pbf2bxo6.wandb filter=lfs diff=lfs merge=lfs -text
+ data/OntoProteinDatasetV2/test.txt filter=lfs diff=lfs merge=lfs -text
+ all_checkpoints/stage2_08011616_2datasets_withoutpretrain/wandb/run-20250801_162632-790wkw0g/run-790wkw0g.wandb filter=lfs diff=lfs merge=lfs -text
+ data/OntoProteinDatasetV2/train.txt filter=lfs diff=lfs merge=lfs -text
+ data/OntoProteinDatasetV2/valid.txt filter=lfs diff=lfs merge=lfs -text
+ data/PDBDataset/abstract.json filter=lfs diff=lfs merge=lfs -text
+ data/PDBDataset/qa_all.json filter=lfs diff=lfs merge=lfs -text
+ data/PDBDataset/train.txt filter=lfs diff=lfs merge=lfs -text
+ data/SwissProtV3/train_set.jsonl filter=lfs diff=lfs merge=lfs -text
+ data/protein-text/eval_assist.ziphwjr8q2y.tmp filter=lfs diff=lfs merge=lfs -text
+ data/protein-text/eval_assist.zipg3ebgjl7.tmp filter=lfs diff=lfs merge=lfs -text
+ data/protein-text/eval_assist.zipzh1pdmj_.tmp filter=lfs diff=lfs merge=lfs -text
+ data_small/OntoProteinDatasetV2/train.txt filter=lfs diff=lfs merge=lfs -text
+ data_small/PDBDataset/abstract.json filter=lfs diff=lfs merge=lfs -text
+ data_small/PDBDataset/qa_all.json filter=lfs diff=lfs merge=lfs -text
+ data_small/SwissProtV3/train_set_.jsonl filter=lfs diff=lfs merge=lfs -text
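Every added rule carries the same `filter=lfs diff=lfs merge=lfs -text` attribute string, which routes the matching path through Git LFS. A minimal Python sketch of how such rules could be generated for oversized files under `data/`; the 10 MB threshold and the script itself are assumptions, not part of this commit:

```python
# Hypothetical helper, not part of this repo: print a .gitattributes LFS rule
# for every file under data/ larger than an assumed 10 MB cutoff.
from pathlib import Path

THRESHOLD = 10 * 1024 * 1024  # assumed cutoff; pick what your host requires

for path in sorted(Path("data").rglob("*")):
    if path.is_file() and path.stat().st_size > THRESHOLD:
        # Same attribute string as the rules added in this commit.
        print(f"{path.as_posix()} filter=lfs diff=lfs merge=lfs -text")
```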
all_checkpoints/stage2_08011616_2datasets_withoutpretrain/last.ckpt/checkpoint/mp_rank_00_model_states.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:6efbc882065731ceb2c9886091da92484e90352ef036f1fff44e77700ff80f41
+ size 208795384
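The three lines above are a standard Git LFS pointer: the real ~209 MB checkpoint lives in LFS storage, keyed by its SHA-256. A minimal sketch (the file paths are illustrative, not repo conventions) for parsing a pointer and verifying a downloaded blob against it:

```python
# Sketch only: parse a Git LFS pointer file and verify a downloaded blob.
import hashlib
from pathlib import Path

def parse_lfs_pointer(text: str) -> dict:
    # Each pointer line is "key value"; oid is "sha256:<hex digest>".
    fields = dict(line.split(" ", 1) for line in text.strip().splitlines())
    algo, digest = fields["oid"].split(":", 1)
    assert algo == "sha256"
    return {"digest": digest, "size": int(fields["size"])}

def verify_blob(pointer: dict, blob: Path) -> bool:
    data = blob.read_bytes()
    return (len(data) == pointer["size"]
            and hashlib.sha256(data).hexdigest() == pointer["digest"])

ptr = parse_lfs_pointer(Path("mp_rank_00_model_states.pt").read_text())  # pointer file
print(verify_blob(ptr, Path("mp_rank_00_model_states.pt.blob")))         # fetched blob
```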
all_checkpoints/stage2_08011616_2datasets_withoutpretrain/wandb/run-20250801_162632-790wkw0g/files/output.log ADDED
The diff for this file is too large to render. See raw diff
 
all_checkpoints/stage2_08011616_2datasets_withoutpretrain/wandb/run-20250801_162632-790wkw0g/files/requirements.txt ADDED
@@ -0,0 +1,225 @@
+ opendatasets==0.1.22
+ salesforce-lavis==1.0.2
+ Pygments==2.19.1
+ nvidia-nccl-cu12==2.21.5
+ tornado==6.5.1
+ nvidia-cuda-runtime-cu12==12.4.127
+ requests==2.32.3
+ nvidia-cuda-cupti-cu12==12.4.127
+ decord==0.6.0
+ braceexpand==0.1.7
+ frozenlist==1.6.0
+ markdown-it-py==3.0.0
+ shellingham==1.5.4
+ absl-py==2.2.2
+ pycocoevalcap==1.2
+ contexttimer==0.3.3
+ bleach==6.2.0
+ jsonschema-specifications==2025.4.1
+ pycocotools==2.0.8
+ python-slugify==8.0.4
+ tqdm==4.67.1
+ numpy==2.2.6
+ urllib3==2.4.0
+ deepspeed==0.16.10+b666844f
+ watchdog==6.0.0
+ wrapt==1.17.2
+ setuptools==78.1.1
+ matplotlib==3.10.3
+ pydeck==0.9.1
+ aiosignal==1.3.2
+ gitdb==4.0.12
+ hjson==3.1.0
+ timm==0.4.12
+ blis==1.3.0
+ PyYAML==6.0.2
+ referencing==0.36.2
+ contourpy==1.3.2
+ kaggle==1.7.4.5
+ triton==3.2.0
+ catalogue==2.0.10
+ idna==3.10
+ torch==2.6.0
+ text-unidecode==1.3
+ altair==5.5.0
+ cloudpathlib==0.21.1
+ protobuf==6.31.0
+ nvidia-cusolver-cu12==11.6.1.9
+ pytz==2025.2
+ sympy==1.13.1
+ spacy==3.8.7
+ MarkupSafe==3.0.2
+ thinc==8.3.6
+ nvidia-cudnn-cu12==9.1.0.70
+ wasabi==1.1.3
+ aiohappyeyeballs==2.6.1
+ nvidia-nvtx-cu12==12.4.127
+ rich==14.0.0
+ ipython==8.36.0
+ yarl==1.20.0
+ torchmetrics==1.7.1
+ multidict==6.4.4
+ cfgv==3.4.0
+ smmap==5.0.2
+ srsly==2.5.1
+ scikit-image==0.25.2
+ matplotlib-inline==0.1.7
+ annotated-types==0.7.0
+ lazy_loader==0.4
+ tenacity==9.1.2
+ GitPython==3.1.44
+ language_data==1.3.0
+ pydantic_core==2.33.2
+ sentencepiece==0.2.0
+ platformdirs==4.3.8
+ distlib==0.3.9
+ nvidia-cusparselt-cu12==0.6.2
+ blinker==1.9.0
+ regex==2024.11.6
+ tifffile==2025.5.10
+ py-cpuinfo==9.0.0
+ attrs==25.3.0
+ mdurl==0.1.2
+ prompt_toolkit==3.0.51
+ packaging==24.2
+ async-timeout==5.0.1
+ six==1.17.0
+ executing==2.2.0
+ parso==0.8.4
+ omegaconf==2.3.0
+ wcwidth==0.2.13
+ murmurhash==1.0.13
+ stack-data==0.6.3
+ nvidia-cufft-cu12==11.2.1.3
+ virtualenv==20.31.2
+ langcodes==3.5.0
+ fonttools==4.58.0
+ opencv-python-headless==4.5.5.64
+ jedi==0.19.2
+ torchvision==0.21.0
+ plotly==6.1.1
+ nodeenv==1.9.1
+ smart-open==7.1.0
+ toml==0.10.2
+ pytorch-lightning==2.5.1.post0
+ typing_extensions==4.13.2
+ safetensors==0.5.3
+ psutil==7.0.0
+ pillow==11.2.1
+ python-dateutil==2.9.0.post0
+ ftfy==6.3.1
+ scipy==1.15.3
+ webdataset==0.2.111
+ charset-normalizer==3.4.2
+ nvidia-nvjitlink-cu12==12.4.127
+ kiwisolver==1.4.8
+ nvidia-ml-py==12.575.51
+ confection==0.1.5
+ nvidia-curand-cu12==10.3.5.147
+ pandas==2.2.3
+ nltk==3.9.1
+ webencodings==0.5.1
+ pyarrow==20.0.0
+ asttokens==3.0.0
+ exceptiongroup==1.3.0
+ pre_commit==4.2.0
+ ninja==1.11.1.4
+ spacy-loggers==1.0.5
+ msgpack==1.1.0
+ lightning-utilities==0.14.3
+ nvidia-cublas-cu12==12.4.5.8
+ tzdata==2025.2
+ cycler==0.12.1
+ hf-xet==1.1.2
+ antlr4-python3-runtime==4.9.3
+ iopath==0.1.10
+ pexpect==4.9.0
+ imageio==2.37.0
+ streamlit==1.45.1
+ python-magic==0.4.27
+ networkx==3.4.2
+ portalocker==3.1.1
+ nvidia-cusparse-cu12==12.3.1.170
+ propcache==0.3.1
+ ptyprocess==0.7.0
+ fairscale==0.4.4
+ rpds-py==0.25.1
+ certifi==2025.4.26
+ rouge_score==0.1.2
+ traitlets==5.14.3
+ identify==2.6.12
+ spacy-legacy==3.0.12
+ weasel==0.4.1
+ mpmath==1.3.0
+ cymem==2.0.11
+ typing-inspection==0.4.1
+ nvidia-cuda-nvrtc-cu12==12.4.127
+ marisa-trie==1.2.1
+ einops==0.8.1
+ nvidia-cufile-cu12==1.11.1.6
+ pydantic==2.11.5
+ cachetools==5.5.2
+ joblib==1.5.1
+ Jinja2==3.1.6
+ filelock==3.18.0
+ pyparsing==3.2.3
+ pure_eval==0.2.3
+ decorator==5.2.1
+ wheel==0.45.1
+ pycryptodome==3.23.0
+ cheroot==10.0.1
+ multiprocess==0.70.16
+ aiohttp==3.12.2
+ crcmod==1.7
+ fsspec==2025.3.0
+ jmespath==0.10.0
+ preshed==3.0.10
+ jaraco.functools==4.1.0
+ cryptography==45.0.3
+ sentry-sdk==2.29.1
+ tokenizers==0.21.1
+ opendelta==0.3.2
+ pycparser==2.22
+ narwhals==1.41.0
+ scikit-learn==1.6.1
+ dill==0.3.8
+ oss2==2.15.0
+ yacs==0.1.8
+ more-itertools==10.7.0
+ pip==25.1.1
+ threadpoolctl==3.6.0
+ flash-attn==2.7.1.post1
+ bigmodelvis==0.0.1
+ pathlib==1.0.1
+ delta-center-client==0.0.4
+ xxhash==3.5.0
+ wandb==0.19.11
+ setproctitle==1.3.6
+ aliyun-python-sdk-core==2.16.0
+ transformers==4.52.3
+ aliyun-python-sdk-kms==2.16.5
+ datasets==3.6.0
+ typer==0.16.0
+ docker-pycreds==0.4.0
+ click==8.2.1
+ huggingface-hub==0.32.1
+ web.py==0.62
+ cffi==1.17.1
+ opencv-python==4.11.0.86
+ jsonschema==4.24.0
+ typing_extensions==4.12.2
+ jaraco.functools==4.0.1
+ jaraco.text==3.12.1
+ jaraco.collections==5.1.0
+ inflect==7.3.1
+ more-itertools==10.3.0
+ packaging==24.2
+ importlib_metadata==8.0.0
+ backports.tarfile==1.2.0
+ typeguard==4.3.0
+ zipp==3.19.2
+ platformdirs==4.2.2
+ autocommand==2.2.2
+ jaraco.context==5.3.0
+ tomli==2.0.1
+ wheel==0.45.1
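To reproduce the run, these pins can be checked against the active environment; note that `deepspeed==0.16.10+b666844f` carries a local build tag and will not resolve from PyPI as written, and a few names (`typing_extensions`, `wheel`, `packaging`) appear twice with conflicting pins. A minimal sketch for comparing pins to installed versions (the script itself is an illustration, not part of the repo):

```python
# Sketch only: report packages whose installed version differs from the pin.
from importlib.metadata import version, PackageNotFoundError

def check_pins(path: str = "requirements.txt") -> None:
    for raw in open(path):
        line = raw.strip()
        if not line or line.startswith("#") or "==" not in line:
            continue
        name, pinned = line.split("==", 1)
        try:
            installed = version(name)
        except PackageNotFoundError:
            print(f"MISSING   {name}=={pinned}")
            continue
        if installed != pinned:
            print(f"MISMATCH  {name}: pinned {pinned}, installed {installed}")

check_pins()
```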
all_checkpoints/stage2_08011616_2datasets_withoutpretrain/wandb/run-20250801_162632-790wkw0g/files/wandb-metadata.json ADDED
@@ -0,0 +1,108 @@
+ {
+   "os": "Linux-5.10.134-008.16.kangaroo.al8.x86_64-x86_64-with-glibc2.35",
+   "python": "CPython 3.10.0",
+   "startedAt": "2025-08-01T08:26:32.241935Z",
+   "args": [
+     "--devices",
+     "0,1,2,3,4,5,6,7",
+     "--mode",
+     "train",
+     "--filename",
+     "stage2_08011616_2datasets_qweninstruct",
+     "--num_query_token",
+     "8",
+     "--save_every_n_epochs",
+     "2",
+     "--max_epochs",
+     "10",
+     "--batch_size",
+     "4",
+     "--precision",
+     "bf16-mixed",
+     "--num_workers",
+     "8",
+     "--plm_model",
+     "/nas/shared/kilab/wangyujia/ProtT3/plm_model/esm2-150m",
+     "--bert_name",
+     "/nas/shared/kilab/wangyujia/ProtT3/plm_model/microsoft",
+     "--llm_name",
+     "/nas/shared/kilab/hf-hub/models--Qwen--Qwen2.5-7B-Instruct/snapshots/a09a35458c702b33eeacc393d103063234e8bc28",
+     "--llm_tune",
+     "mid_lora",
+     "--stage1_path",
+     "/nas/shared/kilab/wangyujia/ProtT3/all_checkpoints/stage1_07041727_2dataset/epoch=29.ckpt/converted.ckpt",
+     "--use_wandb_logger",
+     "--dataset",
+     "swiss-prot"
+   ],
+   "program": "/nas/shared/kilab/wangyujia/ProtT3/stage2.py",
+   "codePath": "wangyujia/ProtT3/stage2.py",
+   "git": {
+     "remote": "https://github.com/PorUna-byte/PAR.git",
+     "commit": "b8caf406aa1699c788f0ca6e44a1769452c317db"
+   },
+   "root": "./all_checkpoints/stage2_08011616_2datasets_qweninstruct/",
+   "host": "dsw-265304-58fbcf9d9b-zvtdx",
+   "executable": "/root/miniconda3/envs/protT3/bin/python",
+   "codePathLocal": "stage2.py",
+   "cpu_count": 64,
+   "cpu_count_logical": 64,
+   "gpu": "NVIDIA A800-SXM4-80GB",
+   "gpu_count": 8,
+   "disk": {
+     "/": {
+       "total": "1623302262784",
+       "used": "1476915200"
+     }
+   },
+   "memory": {
+     "total": "549755813888"
+   },
+   "cpu": {
+     "count": 64,
+     "countLogical": 64
+   },
+   "gpu_nvidia": [
+     {
+       "name": "NVIDIA A800-SXM4-80GB",
+       "memoryTotal": "85198045184",
+       "architecture": "Ampere"
+     },
+     {
+       "name": "NVIDIA A800-SXM4-80GB",
+       "memoryTotal": "85198045184",
+       "architecture": "Ampere"
+     },
+     {
+       "name": "NVIDIA A800-SXM4-80GB",
+       "memoryTotal": "85198045184",
+       "architecture": "Ampere"
+     },
+     {
+       "name": "NVIDIA A800-SXM4-80GB",
+       "memoryTotal": "85198045184",
+       "architecture": "Ampere"
+     },
+     {
+       "name": "NVIDIA A800-SXM4-80GB",
+       "memoryTotal": "85198045184",
+       "architecture": "Ampere"
+     },
+     {
+       "name": "NVIDIA A800-SXM4-80GB",
+       "memoryTotal": "85198045184",
+       "architecture": "Ampere"
+     },
+     {
+       "name": "NVIDIA A800-SXM4-80GB",
+       "memoryTotal": "85198045184",
+       "architecture": "Ampere"
+     },
+     {
+       "name": "NVIDIA A800-SXM4-80GB",
+       "memoryTotal": "85198045184",
+       "architecture": "Ampere"
+     }
+   ],
+   "cudaVersion": "12.1"
+ }
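Since the metadata records the interpreter, entry point, and full argument list, the exact stage-2 launch command can be reconstructed from it. A minimal sketch (reading the file under its captured name):

```python
# Sketch only: rebuild the training invocation recorded in wandb-metadata.json.
import json
import shlex

with open("wandb-metadata.json") as fh:
    meta = json.load(fh)

cmd = [meta["executable"], meta["program"], *meta["args"]]
print(shlex.join(cmd))
# -> /root/miniconda3/envs/protT3/bin/python /nas/shared/kilab/wangyujia/ProtT3/stage2.py --devices 0,1,2,3,4,5,6,7 --mode train ...
```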
all_checkpoints/stage2_08011616_2datasets_withoutpretrain/wandb/run-20250801_162632-790wkw0g/files/wandb-summary.json ADDED
@@ -0,0 +1 @@
+ {"trainer/global_step":134559,"dataset0/rouge_2":37.98429870605469,"_step":2700,"dataset0/bleu2":37.878440856933594,"dataset0/meteor_score":53.153656005859375,"lr":1.2202456673549023e-05,"dataset0/rouge_1":45.05393981933594,"_wandb":{"runtime":61483},"_timestamp":1.7540982570645058e+09,"dataset0/rouge_l":42.82152557373047,"dataset0/bleu4":34.96389389038086,"_runtime":61467.758475346,"epoch":9,"loss":0.37739327549934387,"dataloader0/val loss/dataloader_idx_0":0.6249951124191284,"dataset0/acc":0}
all_checkpoints/stage2_08011616_2datasets_withoutpretrain/wandb/run-20250801_162632-790wkw0g/logs/debug-internal.log ADDED
@@ -0,0 +1,166 @@
+ {"time":"2025-08-01T16:26:32.258940024+08:00","level":"INFO","msg":"stream: starting","core version":"0.19.11","symlink path":"all_checkpoints/stage2_08011616_2datasets_qweninstruct/wandb/run-20250801_162632-790wkw0g/logs/debug-core.log"}
+ {"time":"2025-08-01T16:26:47.100144119+08:00","level":"INFO","msg":"created new stream","id":"790wkw0g"}
+ {"time":"2025-08-01T16:26:47.101616008+08:00","level":"INFO","msg":"stream: started","id":"790wkw0g"}
+ {"time":"2025-08-01T16:26:47.101680831+08:00","level":"INFO","msg":"sender: started","stream_id":"790wkw0g"}
+ {"time":"2025-08-01T16:26:47.101644151+08:00","level":"INFO","msg":"writer: Do: started","stream_id":"790wkw0g"}
+ {"time":"2025-08-01T16:26:47.101748992+08:00","level":"INFO","msg":"handler: started","stream_id":"790wkw0g"}
+ {"time":"2025-08-01T16:26:50.002107204+08:00","level":"INFO","msg":"Starting system monitor"}
+ {"time":"2025-08-01T16:31:09.07203296+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/files/gia0603yucca/stage2_08011616_2datasets_qweninstruct/790wkw0g/file_stream\": read tcp 10.1.11.100:36310->104.21.20.172:443: read: connection timed out"}
+ {"time":"2025-08-01T16:38:20.137992503+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/graphql\": net/http: request canceled (Client.Timeout exceeded while awaiting headers)"}
+ {"time":"2025-08-01T16:41:40.88100016+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/files/gia0603yucca/stage2_08011616_2datasets_qweninstruct/790wkw0g/file_stream\": read tcp 10.1.11.100:41026->104.21.20.172:443: read: connection timed out"}
+ {"time":"2025-08-01T16:43:35.141344346+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/graphql\": net/http: request canceled (Client.Timeout exceeded while awaiting headers)"}
+ {"time":"2025-08-01T16:44:07.544267418+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/graphql\": net/http: request canceled (Client.Timeout exceeded while awaiting headers)"}
+ {"time":"2025-08-01T16:44:41.670680082+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/graphql\": context deadline exceeded"}
+ {"time":"2025-08-01T16:45:20.186411207+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/graphql\": context deadline exceeded"}
+ {"time":"2025-08-01T16:45:57.904074825+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/files/gia0603yucca/stage2_08011616_2datasets_qweninstruct/790wkw0g/file_stream\": read tcp 10.1.11.100:57230->172.67.193.61:443: read: connection timed out"}
+ {"time":"2025-08-01T16:46:09.248932426+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/graphql\": net/http: request canceled (Client.Timeout exceeded while awaiting headers)"}
+ {"time":"2025-08-01T16:50:22.097023685+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/files/gia0603yucca/stage2_08011616_2datasets_qweninstruct/790wkw0g/file_stream\": read tcp 10.1.11.100:45466->104.21.20.172:443: read: connection timed out"}
+ {"time":"2025-08-01T17:44:15.888009977+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/files/gia0603yucca/stage2_08011616_2datasets_qweninstruct/790wkw0g/file_stream\": read tcp 10.1.11.100:47510->104.21.20.172:443: read: connection timed out"}
+ {"time":"2025-08-01T20:05:52.724089483+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/files/gia0603yucca/stage2_08011616_2datasets_qweninstruct/790wkw0g/file_stream\": read tcp 10.1.11.100:49218->104.21.20.172:443: read: connection reset by peer"}
+ {"time":"2025-08-01T20:13:58.928063914+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/files/gia0603yucca/stage2_08011616_2datasets_qweninstruct/790wkw0g/file_stream\": read tcp 10.1.11.100:46332->104.21.20.172:443: read: connection timed out"}
+ {"time":"2025-08-01T20:27:15.089020249+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/files/gia0603yucca/stage2_08011616_2datasets_qweninstruct/790wkw0g/file_stream\": read tcp 10.1.11.100:51288->104.21.20.172:443: read: connection timed out"}
+ {"time":"2025-08-01T20:33:59.569042011+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/files/gia0603yucca/stage2_08011616_2datasets_qweninstruct/790wkw0g/file_stream\": read tcp 10.1.11.100:38854->172.67.193.61:443: read: connection timed out"}
+ {"time":"2025-08-01T20:34:32.461391926+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/files/gia0603yucca/stage2_08011616_2datasets_qweninstruct/790wkw0g/file_stream\": unexpected EOF"}
+ {"time":"2025-08-01T20:38:30.928056966+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/files/gia0603yucca/stage2_08011616_2datasets_qweninstruct/790wkw0g/file_stream\": read tcp 10.1.11.100:45810->104.21.20.172:443: read: connection timed out"}
+ {"time":"2025-08-01T20:43:03.824002354+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/files/gia0603yucca/stage2_08011616_2datasets_qweninstruct/790wkw0g/file_stream\": read tcp 10.1.11.100:51984->104.21.20.172:443: read: connection timed out"}
+ {"time":"2025-08-01T20:46:25.450089946+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/files/gia0603yucca/stage2_08011616_2datasets_qweninstruct/790wkw0g/file_stream\": read tcp 10.1.11.100:42776->172.67.193.61:443: read: connection reset by peer"}
+ {"time":"2025-08-01T20:49:05.052996031+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/files/gia0603yucca/stage2_08011616_2datasets_qweninstruct/790wkw0g/file_stream\": read tcp 10.1.11.100:54072->104.21.20.172:443: read: connection reset by peer"}
+ {"time":"2025-08-01T20:49:50.926181978+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/graphql\": net/http: request canceled (Client.Timeout exceeded while awaiting headers)"}
+ {"time":"2025-08-01T20:50:58.961126603+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/files/gia0603yucca/stage2_08011616_2datasets_qweninstruct/790wkw0g/file_stream\": read tcp 10.1.11.100:46418->104.21.20.172:443: read: connection reset by peer"}
+ {"time":"2025-08-01T20:54:00.208036179+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/files/gia0603yucca/stage2_08011616_2datasets_qweninstruct/790wkw0g/file_stream\": read tcp 10.1.11.100:47556->104.21.20.172:443: read: connection timed out"}
+ {"time":"2025-08-01T21:03:33.648034502+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/files/gia0603yucca/stage2_08011616_2datasets_qweninstruct/790wkw0g/file_stream\": read tcp 10.1.11.100:45288->104.21.20.172:443: read: connection timed out"}
+ {"time":"2025-08-01T21:06:14.416001149+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/files/gia0603yucca/stage2_08011616_2datasets_qweninstruct/790wkw0g/file_stream\": read tcp 10.1.11.100:39640->172.67.193.61:443: read: connection timed out"}
+ {"time":"2025-08-01T21:06:50.935397642+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/graphql\": context deadline exceeded"}
+ {"time":"2025-08-01T21:07:22.991031701+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/graphql\": net/http: request canceled (Client.Timeout exceeded while awaiting headers)"}
+ {"time":"2025-08-01T21:07:57.40524595+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/graphql\": context deadline exceeded"}
+ {"time":"2025-08-01T21:08:36.813512769+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/graphql\": net/http: request canceled (Client.Timeout exceeded while awaiting headers)"}
+ {"time":"2025-08-01T21:09:24.678708826+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/graphql\": net/http: request canceled (Client.Timeout exceeded while awaiting headers)"}
+ {"time":"2025-08-01T21:10:33.369865787+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/graphql\": net/http: request canceled (Client.Timeout exceeded while awaiting headers)"}
+ {"time":"2025-08-01T21:13:09.331256642+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/files/gia0603yucca/stage2_08011616_2datasets_qweninstruct/790wkw0g/file_stream\": read tcp 10.1.11.100:55554->104.21.20.172:443: read: connection reset by peer"}
+ {"time":"2025-08-01T21:18:12.321384192+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/files/gia0603yucca/stage2_08011616_2datasets_qweninstruct/790wkw0g/file_stream\": unexpected EOF"}
+ {"time":"2025-08-01T21:21:14.379467148+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/files/gia0603yucca/stage2_08011616_2datasets_qweninstruct/790wkw0g/file_stream\": net/http: request canceled (Client.Timeout exceeded while awaiting headers)"}
+ {"time":"2025-08-01T21:24:32.868908249+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/files/gia0603yucca/stage2_08011616_2datasets_qweninstruct/790wkw0g/file_stream\": context deadline exceeded"}
+ {"time":"2025-08-01T21:28:13.328011263+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/files/gia0603yucca/stage2_08011616_2datasets_qweninstruct/790wkw0g/file_stream\": read tcp 10.1.11.100:53294->104.21.20.172:443: read: connection timed out"}
+ {"time":"2025-08-01T21:30:01.911292424+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/files/gia0603yucca/stage2_08011616_2datasets_qweninstruct/790wkw0g/file_stream\": read tcp 10.1.11.100:45656->172.67.193.61:443: read: connection reset by peer"}
+ {"time":"2025-08-01T21:31:05.951614465+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/graphql\": net/http: request canceled (Client.Timeout exceeded while awaiting headers)"}
+ {"time":"2025-08-01T21:32:43.664038993+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/files/gia0603yucca/stage2_08011616_2datasets_qweninstruct/790wkw0g/file_stream\": read tcp 10.1.11.100:58998->172.67.193.61:443: read: connection timed out"}
+ {"time":"2025-08-01T21:33:35.953483769+08:00","level":"ERROR","msg":"sender: sendStopStatus: failed to get run stopped status: context deadline exceeded (Client.Timeout or context cancellation while reading body)"}
+ {"time":"2025-08-01T21:35:00.100184499+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/files/gia0603yucca/stage2_08011616_2datasets_qweninstruct/790wkw0g/file_stream\": read tcp 10.1.11.100:48498->104.21.20.172:443: read: connection reset by peer"}
+ {"time":"2025-08-01T21:37:35.955239872+08:00","level":"ERROR","msg":"sender: sendStopStatus: failed to get run stopped status: context deadline exceeded (Client.Timeout or context cancellation while reading body)"}
+ {"time":"2025-08-01T21:38:05.712022165+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/files/gia0603yucca/stage2_08011616_2datasets_qweninstruct/790wkw0g/file_stream\": read tcp 10.1.11.100:60866->104.21.20.172:443: read: connection timed out"}
+ {"time":"2025-08-01T21:38:20.956070921+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/graphql\": context deadline exceeded"}
+ {"time":"2025-08-01T21:41:06.960060464+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/files/gia0603yucca/stage2_08011616_2datasets_qweninstruct/790wkw0g/file_stream\": read tcp 10.1.11.100:55500->172.67.193.61:443: read: connection timed out"}
+ {"time":"2025-08-01T21:41:57.697000345+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/files/gia0603yucca/stage2_08011616_2datasets_qweninstruct/790wkw0g/file_stream\": read tcp 10.1.11.100:36556->172.67.193.61:443: read: connection reset by peer"}
+ {"time":"2025-08-01T21:45:01.968042256+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/files/gia0603yucca/stage2_08011616_2datasets_qweninstruct/790wkw0g/file_stream\": read tcp 10.1.11.100:53014->172.67.193.61:443: read: connection timed out"}
+ {"time":"2025-08-01T21:46:05.960063741+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/graphql\": net/http: request canceled (Client.Timeout exceeded while awaiting headers)"}
+ {"time":"2025-08-01T21:47:02.576818619+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/files/gia0603yucca/stage2_08011616_2datasets_qweninstruct/790wkw0g/file_stream\": read tcp 10.1.11.100:44354->104.21.20.172:443: read: connection reset by peer"}
+ {"time":"2025-08-01T21:54:27.728015687+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/files/gia0603yucca/stage2_08011616_2datasets_qweninstruct/790wkw0g/file_stream\": read tcp 10.1.11.100:59618->104.21.20.172:443: read: connection timed out"}
+ {"time":"2025-08-01T21:57:21.808006007+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/files/gia0603yucca/stage2_08011616_2datasets_qweninstruct/790wkw0g/file_stream\": read tcp 10.1.11.100:46320->104.21.20.172:443: read: connection timed out"}
+ {"time":"2025-08-01T22:00:19.984010382+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/files/gia0603yucca/stage2_08011616_2datasets_qweninstruct/790wkw0g/file_stream\": read tcp 10.1.11.100:53128->104.21.20.172:443: read: connection timed out"}
+ {"time":"2025-08-01T22:00:50.968803427+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/graphql\": net/http: request canceled (Client.Timeout exceeded while awaiting headers)"}
+ {"time":"2025-08-01T22:03:12.527983291+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/files/gia0603yucca/stage2_08011616_2datasets_qweninstruct/790wkw0g/file_stream\": read tcp 10.1.11.100:34872->104.21.20.172:443: read: connection timed out"}
+ {"time":"2025-08-01T22:04:48.400864668+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/files/gia0603yucca/stage2_08011616_2datasets_qweninstruct/790wkw0g/file_stream\": read tcp 10.1.11.100:43930->104.21.20.172:443: read: connection reset by peer"}
+ {"time":"2025-08-01T22:04:50.970855218+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/graphql\": net/http: request canceled (Client.Timeout exceeded while awaiting headers)"}
+ {"time":"2025-08-01T22:07:31.600992168+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/files/gia0603yucca/stage2_08011616_2datasets_qweninstruct/790wkw0g/file_stream\": read tcp 10.1.11.100:49046->104.21.20.172:443: read: connection timed out"}
+ {"time":"2025-08-01T22:09:05.97302798+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/graphql\": context deadline exceeded"}
+ {"time":"2025-08-01T22:09:46.459558104+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/files/gia0603yucca/stage2_08011616_2datasets_qweninstruct/790wkw0g/file_stream\": read tcp 10.1.11.100:47316->104.21.20.172:443: read: connection reset by peer"}
+ {"time":"2025-08-01T22:10:35.97491736+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/graphql\": net/http: request canceled (Client.Timeout exceeded while awaiting headers)"}
+ {"time":"2025-08-01T22:11:08.18377384+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/graphql\": context deadline exceeded"}
+ {"time":"2025-08-01T22:11:25.779600104+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/files/gia0603yucca/stage2_08011616_2datasets_qweninstruct/790wkw0g/file_stream\": read tcp 10.1.11.100:34352->172.67.193.61:443: read: connection reset by peer"}
+ {"time":"2025-08-01T22:11:42.262306077+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/graphql\": context deadline exceeded"}
+ {"time":"2025-08-01T22:12:50.976429474+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/graphql\": net/http: request canceled (Client.Timeout exceeded while awaiting headers)"}
+ {"time":"2025-08-01T22:13:23.229614619+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/graphql\": context deadline exceeded"}
+ {"time":"2025-08-01T22:13:57.892901227+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/graphql\": net/http: request canceled (Client.Timeout exceeded while awaiting headers)"}
+ {"time":"2025-08-01T22:15:04.820393306+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/files/gia0603yucca/stage2_08011616_2datasets_qweninstruct/790wkw0g/file_stream\": read tcp 10.1.11.100:60344->104.21.20.172:443: read: connection reset by peer"}
+ {"time":"2025-08-01T22:15:45.367773065+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/files/gia0603yucca/stage2_08011616_2datasets_qweninstruct/790wkw0g/file_stream\": read tcp 10.1.11.100:54996->172.67.193.61:443: read: connection reset by peer"}
+ {"time":"2025-08-01T22:17:57.605995597+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/files/gia0603yucca/stage2_08011616_2datasets_qweninstruct/790wkw0g/file_stream\": unexpected EOF"}
+ {"time":"2025-08-01T22:21:13.87299516+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/files/gia0603yucca/stage2_08011616_2datasets_qweninstruct/790wkw0g/file_stream\": read tcp 10.1.11.100:49910->104.21.20.172:443: read: connection timed out"}
+ {"time":"2025-08-01T22:24:16.144029987+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/files/gia0603yucca/stage2_08011616_2datasets_qweninstruct/790wkw0g/file_stream\": read tcp 10.1.11.100:41940->172.67.193.61:443: read: connection timed out"}
+ {"time":"2025-08-01T22:27:28.144021751+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/files/gia0603yucca/stage2_08011616_2datasets_qweninstruct/790wkw0g/file_stream\": read tcp 10.1.11.100:46560->172.67.193.61:443: read: connection timed out"}
+ {"time":"2025-08-01T22:28:35.984483121+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/graphql\": net/http: request canceled (Client.Timeout exceeded while awaiting headers)"}
+ {"time":"2025-08-01T22:31:45.168984421+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/files/gia0603yucca/stage2_08011616_2datasets_qweninstruct/790wkw0g/file_stream\": read tcp 10.1.11.100:38656->172.67.193.61:443: read: connection timed out"}
+ {"time":"2025-08-01T22:35:18.671993015+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/files/gia0603yucca/stage2_08011616_2datasets_qweninstruct/790wkw0g/file_stream\": read tcp 10.1.11.100:49116->104.21.20.172:443: read: connection timed out"}
+ {"time":"2025-08-01T22:39:02.928024828+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/files/gia0603yucca/stage2_08011616_2datasets_qweninstruct/790wkw0g/file_stream\": read tcp 10.1.11.100:40350->172.67.193.61:443: read: connection timed out"}
+ {"time":"2025-08-01T22:39:35.380477212+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/files/gia0603yucca/stage2_08011616_2datasets_qweninstruct/790wkw0g/file_stream\": read tcp 10.1.11.100:35004->104.21.20.172:443: read: connection reset by peer"}
+ {"time":"2025-08-01T22:39:55.432951061+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/files/gia0603yucca/stage2_08011616_2datasets_qweninstruct/790wkw0g/file_stream\": unexpected EOF"}
+ {"time":"2025-08-01T22:41:50.991964572+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/graphql\": net/http: request canceled (Client.Timeout exceeded while awaiting headers)"}
+ {"time":"2025-08-01T22:41:53.61821248+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/files/gia0603yucca/stage2_08011616_2datasets_qweninstruct/790wkw0g/file_stream\": read tcp 10.1.11.100:37852->104.21.20.172:443: read: connection reset by peer"}
+ {"time":"2025-08-01T22:43:35.932757702+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/files/gia0603yucca/stage2_08011616_2datasets_qweninstruct/790wkw0g/file_stream\": http2: client conn is closed"}
+ {"time":"2025-08-01T22:44:04.110612747+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/files/gia0603yucca/stage2_08011616_2datasets_qweninstruct/790wkw0g/file_stream\": read tcp 10.1.11.100:53322->104.21.20.172:443: read: connection reset by peer"}
+ {"time":"2025-08-01T22:46:35.994608155+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/graphql\": context deadline exceeded"}
+ {"time":"2025-08-01T22:47:08.18720141+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/graphql\": net/http: request canceled (Client.Timeout exceeded while awaiting headers)"}
+ {"time":"2025-08-01T22:47:34.92805984+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/files/gia0603yucca/stage2_08011616_2datasets_qweninstruct/790wkw0g/file_stream\": read tcp 10.1.11.100:51190->172.67.193.61:443: read: connection timed out"}
+ {"time":"2025-08-01T22:47:42.321052323+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/graphql\": context deadline exceeded"}
+ {"time":"2025-08-01T22:50:05.996775834+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/graphql\": context deadline exceeded"}
+ {"time":"2025-08-01T22:50:30.544005531+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/files/gia0603yucca/stage2_08011616_2datasets_qweninstruct/790wkw0g/file_stream\": read tcp 10.1.11.100:49020->172.67.193.61:443: read: connection timed out"}
+ {"time":"2025-08-01T22:50:38.303517763+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/graphql\": net/http: request canceled (Client.Timeout exceeded while awaiting headers)"}
+ {"time":"2025-08-01T22:51:12.580566425+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/graphql\": context deadline exceeded"}
+ {"time":"2025-08-01T22:51:51.301848016+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/graphql\": context deadline exceeded"}
+ {"time":"2025-08-01T22:52:39.862980102+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/graphql\": context deadline exceeded"}
+ {"time":"2025-08-01T22:53:43.460661511+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/graphql\": net/http: request canceled (Client.Timeout exceeded while awaiting headers)"}
+ {"time":"2025-08-01T22:55:13.462771183+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/graphql\": context deadline exceeded"}
+ {"time":"2025-08-01T22:55:50.033040855+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/files/gia0603yucca/stage2_08011616_2datasets_qweninstruct/790wkw0g/file_stream\": read tcp 10.1.11.100:38370->172.67.193.61:443: read: connection timed out"}
+ {"time":"2025-08-01T22:56:24.977146749+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/files/gia0603yucca/stage2_08011616_2datasets_qweninstruct/790wkw0g/file_stream\": read tcp 10.1.11.100:39018->104.21.20.172:443: read: connection reset by peer"}
+ {"time":"2025-08-01T22:56:43.465087253+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/graphql\": net/http: request canceled (Client.Timeout exceeded while awaiting headers)"}
+ {"time":"2025-08-01T22:58:13.466056502+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/graphql\": context deadline exceeded"}
+ {"time":"2025-08-01T22:59:29.003634152+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/files/gia0603yucca/stage2_08011616_2datasets_qweninstruct/790wkw0g/file_stream\": context deadline exceeded (Client.Timeout exceeded while awaiting headers)"}
+ {"time":"2025-08-01T22:59:35.997189916+08:00","level":"WARN","msg":"sender: taking a long time","seconds":600.000438436,"work":"WorkRecord(*service_go_proto.Request_StopStatus); Control(local:true mailbox_slot:\"zbp0fazrv773\" connection_id:\"127.0.0.1:32880\")"}
+ {"time":"2025-08-01T22:59:43.468121454+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/graphql\": net/http: request canceled (Client.Timeout exceeded while awaiting headers)"}
+ {"time":"2025-08-01T22:59:56.405426213+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/files/gia0603yucca/stage2_08011616_2datasets_qweninstruct/790wkw0g/file_stream\": unexpected EOF"}
+ {"time":"2025-08-01T22:59:56.885071194+08:00","level":"WARN","msg":"runwork: taking a long time","seconds":600.000154693,"work":"WorkRecord(*service_go_proto.Record_OutputRaw); Control(connection_id:\"127.0.0.1:32880\")"}
+ {"time":"2025-08-01T23:00:05.004711621+08:00","level":"WARN","msg":"runwork: taking a long time","seconds":600.00033328,"work":"WorkRecord(*service_go_proto.Record_Stats); Control(always_send:true)"}
+ {"time":"2025-08-01T23:00:05.076924009+08:00","level":"WARN","msg":"runwork: taking a long time","seconds":600.000302659,"work":"WorkRecord(*service_go_proto.Record_Stats); Control(always_send:true)"}
+ {"time":"2025-08-01T23:01:13.46903123+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/graphql\": context deadline exceeded"}
+ {"time":"2025-08-01T23:02:43.472282526+08:00","level":"INFO","msg":"sender: succeeded after taking longer than expected","seconds":787.475510948,"work":"WorkRecord(*service_go_proto.Request_StopStatus); Control(local:true mailbox_slot:\"zbp0fazrv773\" connection_id:\"127.0.0.1:32880\")"}
+ {"time":"2025-08-01T23:02:43.472384259+08:00","level":"INFO","msg":"runwork: succeeded after taking longer than expected","seconds":758.39578933,"work":"WorkRecord(*service_go_proto.Record_Stats); Control(always_send:true)"}
+ {"time":"2025-08-01T23:02:43.472386124+08:00","level":"INFO","msg":"runwork: succeeded after taking longer than expected","seconds":758.468026974,"work":"WorkRecord(*service_go_proto.Record_Stats); Control(always_send:true)"}
+ {"time":"2025-08-01T23:02:43.47239327+08:00","level":"INFO","msg":"runwork: succeeded after taking longer than expected","seconds":766.587479442,"work":"WorkRecord(*service_go_proto.Record_OutputRaw); Control(connection_id:\"127.0.0.1:32880\")"}
+ {"time":"2025-08-01T23:03:03.192446429+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/files/gia0603yucca/stage2_08011616_2datasets_qweninstruct/790wkw0g/file_stream\": read tcp 10.1.11.100:44012->172.67.193.61:443: read: connection reset by peer"}
+ {"time":"2025-08-01T23:03:21.005846988+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/graphql\": context deadline exceeded"}
+ {"time":"2025-08-01T23:08:40.592045526+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/files/gia0603yucca/stage2_08011616_2datasets_qweninstruct/790wkw0g/file_stream\": read tcp 10.1.11.100:38310->104.21.20.172:443: read: connection timed out"}
+ {"time":"2025-08-01T23:11:46.960031857+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/files/gia0603yucca/stage2_08011616_2datasets_qweninstruct/790wkw0g/file_stream\": read tcp 10.1.11.100:55084->104.21.20.172:443: read: connection timed out"}
+ {"time":"2025-08-01T23:12:24.849142697+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/files/gia0603yucca/stage2_08011616_2datasets_qweninstruct/790wkw0g/file_stream\": read tcp 10.1.11.100:49670->104.21.20.172:443: read: connection reset by peer"}
+ {"time":"2025-08-01T23:15:08.177028412+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/files/gia0603yucca/stage2_08011616_2datasets_qweninstruct/790wkw0g/file_stream\": read tcp 10.1.11.100:52862->104.21.20.172:443: read: connection timed out"}
+ {"time":"2025-08-01T23:16:46.602539365+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/files/gia0603yucca/stage2_08011616_2datasets_qweninstruct/790wkw0g/file_stream\": read tcp 10.1.11.100:45510->104.21.20.172:443: read: connection reset by peer"}
+ {"time":"2025-08-01T23:18:23.28263426+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/graphql\": read tcp 10.1.11.100:48380->172.67.193.61:443: read: connection reset by peer"}
+ {"time":"2025-08-01T23:19:27.841519011+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/files/gia0603yucca/stage2_08011616_2datasets_qweninstruct/790wkw0g/file_stream\": unexpected EOF"}
+ {"time":"2025-08-01T23:23:01.165294384+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/files/gia0603yucca/stage2_08011616_2datasets_qweninstruct/790wkw0g/file_stream\": context deadline exceeded"}
+ {"time":"2025-08-01T23:26:53.200081773+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/files/gia0603yucca/stage2_08011616_2datasets_qweninstruct/790wkw0g/file_stream\": read tcp 10.1.11.100:43360->104.21.20.172:443: read: connection timed out"}
+ {"time":"2025-08-01T23:34:57.039993804+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/files/gia0603yucca/stage2_08011616_2datasets_qweninstruct/790wkw0g/file_stream\": read tcp 10.1.11.100:52350->172.67.193.61:443: read: connection timed out"}
+ {"time":"2025-08-01T23:38:16.720010379+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/files/gia0603yucca/stage2_08011616_2datasets_qweninstruct/790wkw0g/file_stream\": read tcp 10.1.11.100:57172->104.21.20.172:443: read: connection timed out"}
+ {"time":"2025-08-01T23:41:21.040008763+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/files/gia0603yucca/stage2_08011616_2datasets_qweninstruct/790wkw0g/file_stream\": read tcp 10.1.11.100:36508->104.21.20.172:443: read: connection timed out"}
+ {"time":"2025-08-01T23:44:44.816004249+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/files/gia0603yucca/stage2_08011616_2datasets_qweninstruct/790wkw0g/file_stream\": read tcp 10.1.11.100:55942->172.67.193.61:443: read: connection timed out"}
+ {"time":"2025-08-01T23:48:52.62500069+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/files/gia0603yucca/stage2_08011616_2datasets_qweninstruct/790wkw0g/file_stream\": read tcp 10.1.11.100:47172->172.67.193.61:443: read: connection timed out"}
+ {"time":"2025-08-01T23:51:31.345006101+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/files/gia0603yucca/stage2_08011616_2datasets_qweninstruct/790wkw0g/file_stream\": read tcp 10.1.11.100:52938->104.21.20.172:443: read: connection timed out"}
+ {"time":"2025-08-01T23:54:39.760994652+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/files/gia0603yucca/stage2_08011616_2datasets_qweninstruct/790wkw0g/file_stream\": read tcp 10.1.11.100:57488->104.21.20.172:443: read: connection timed out"}
+ {"time":"2025-08-01T23:56:53.243171871+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/files/gia0603yucca/stage2_08011616_2datasets_qweninstruct/790wkw0g/file_stream\": unexpected EOF"}
+ {"time":"2025-08-02T00:01:11.439949517+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/files/gia0603yucca/stage2_08011616_2datasets_qweninstruct/790wkw0g/file_stream\": read tcp 10.1.11.100:48808->172.67.193.61:443: read: connection timed out"}
+ {"time":"2025-08-02T00:03:57.840030605+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/files/gia0603yucca/stage2_08011616_2datasets_qweninstruct/790wkw0g/file_stream\": read tcp 10.1.11.100:58796->172.67.193.61:443: read: connection timed out"}
+ {"time":"2025-08-02T00:04:51.340371717+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/graphql\": read tcp 10.1.11.100:40756->104.21.20.172:443: read: connection reset by peer"}
+ {"time":"2025-08-02T00:06:51.040127189+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/graphql\": net/http: request canceled (Client.Timeout exceeded while awaiting headers)"}
+ {"time":"2025-08-02T00:12:23.188896321+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/graphql\": read tcp 10.1.11.100:52718->172.67.193.61:443: read: connection reset by peer"}
+ {"time":"2025-08-02T00:19:09.499959872+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/files/gia0603yucca/stage2_08011616_2datasets_qweninstruct/790wkw0g/file_stream\": unexpected EOF"}
+ {"time":"2025-08-02T00:27:17.136029213+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/files/gia0603yucca/stage2_08011616_2datasets_qweninstruct/790wkw0g/file_stream\": read tcp 10.1.11.100:33250->104.21.20.172:443: read: connection timed out"}
+ {"time":"2025-08-02T00:31:24.717936901+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/files/gia0603yucca/stage2_08011616_2datasets_qweninstruct/790wkw0g/file_stream\": read tcp 10.1.11.100:42528->104.21.20.172:443: read: connection reset by peer"}
+ {"time":"2025-08-02T00:49:10.92803379+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/files/gia0603yucca/stage2_08011616_2datasets_qweninstruct/790wkw0g/file_stream\": read tcp 10.1.11.100:49640->172.67.193.61:443: read: connection timed out"}
+ {"time":"2025-08-02T00:53:16.688059564+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/files/gia0603yucca/stage2_08011616_2datasets_qweninstruct/790wkw0g/file_stream\": read tcp 10.1.11.100:41832->172.67.193.61:443: read: connection timed out"}
+ {"time":"2025-08-02T01:03:03.440032369+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/files/gia0603yucca/stage2_08011616_2datasets_qweninstruct/790wkw0g/file_stream\": read tcp 10.1.11.100:36398->104.21.20.172:443: read: connection timed out"}
+ {"time":"2025-08-02T01:15:25.202630664+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/files/gia0603yucca/stage2_08011616_2datasets_qweninstruct/790wkw0g/file_stream\": read tcp 10.1.11.100:35078->104.21.20.172:443: read: connection reset by peer"}
+ {"time":"2025-08-02T01:23:06.089931169+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/graphql\": net/http: request canceled (Client.Timeout exceeded while awaiting headers)"}
+ {"time":"2025-08-02T01:28:16.400050797+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/files/gia0603yucca/stage2_08011616_2datasets_qweninstruct/790wkw0g/file_stream\": read tcp 10.1.11.100:37336->172.67.193.61:443: read: connection timed out"}
+ {"time":"2025-08-02T02:22:06.219804028+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/graphql\": net/http: request canceled (Client.Timeout exceeded while awaiting headers)"}
+ {"time":"2025-08-02T03:05:52.144050294+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/files/gia0603yucca/stage2_08011616_2datasets_qweninstruct/790wkw0g/file_stream\": read tcp 10.1.11.100:51464->172.67.193.61:443: read: connection timed out"}
+ {"time":"2025-08-02T06:13:41.315904777+08:00","level":"INFO","msg":"api: retrying HTTP error","status":504,"url":"https://api.bandw.top/files/gia0603yucca/stage2_08011616_2datasets_qweninstruct/790wkw0g/file_stream","body":"error code: 504"}
+ {"time":"2025-08-02T06:13:52.289991198+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/graphql\": context deadline exceeded"}
+ {"time":"2025-08-02T06:55:22.310503234+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/graphql\": context deadline exceeded"}
+ {"time":"2025-08-02T06:55:25.73261549+08:00","level":"INFO","msg":"api: retrying HTTP error","status":504,"url":"https://api.bandw.top/files/gia0603yucca/stage2_08011616_2datasets_qweninstruct/790wkw0g/file_stream","body":"error code: 504"}
+ {"time":"2025-08-02T07:13:22.52185475+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/graphql\": read tcp 10.1.11.100:41606->104.21.20.172:443: read: connection reset by peer"}
+ {"time":"2025-08-02T07:38:52.528290374+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/graphql\": read tcp 10.1.11.100:38352->104.21.20.172:443: read: connection reset by peer"}
+ {"time":"2025-08-02T09:31:16.192630538+08:00","level":"INFO","msg":"stream: closing","id":"790wkw0g"}
+ {"time":"2025-08-02T09:31:16.192679202+08:00","level":"INFO","msg":"Stopping system monitor"}
+ {"time":"2025-08-02T09:31:16.194392232+08:00","level":"INFO","msg":"Stopped system monitor"}
+ {"time":"2025-08-02T09:31:21.149902109+08:00","level":"INFO","msg":"fileTransfer: Close: file transfer manager closed"}
+ {"time":"2025-08-02T09:31:22.465229216+08:00","level":"INFO","msg":"handler: closed","stream_id":"790wkw0g"}
+ {"time":"2025-08-02T09:31:22.465258111+08:00","level":"INFO","msg":"writer: Close: closed","stream_id":"790wkw0g"}
+ {"time":"2025-08-02T09:31:22.465263174+08:00","level":"INFO","msg":"sender: closed","stream_id":"790wkw0g"}
+ {"time":"2025-08-02T09:31:22.469479316+08:00","level":"INFO","msg":"stream: closed","id":"790wkw0g"}
all_checkpoints/stage2_08011616_2datasets_withoutpretrain/wandb/run-20250801_162632-790wkw0g/logs/debug.log ADDED
@@ -0,0 +1,24 @@
+ 2025-08-01 16:26:32,237 INFO MainThread:764673 [wandb_setup.py:_flush():70] Current SDK version is 0.19.11
+ 2025-08-01 16:26:32,237 INFO MainThread:764673 [wandb_setup.py:_flush():70] Configure stats pid to 764673
+ 2025-08-01 16:26:32,237 INFO MainThread:764673 [wandb_setup.py:_flush():70] Loading settings from /root/.config/wandb/settings
+ 2025-08-01 16:26:32,237 INFO MainThread:764673 [wandb_setup.py:_flush():70] Loading settings from /nas/shared/kilab/wangyujia/ProtT3/wandb/settings
+ 2025-08-01 16:26:32,237 INFO MainThread:764673 [wandb_setup.py:_flush():70] Loading settings from environment variables
+ 2025-08-01 16:26:32,237 INFO MainThread:764673 [wandb_init.py:setup_run_log_directory():724] Logging user logs to ./all_checkpoints/stage2_08011616_2datasets_qweninstruct/wandb/run-20250801_162632-790wkw0g/logs/debug.log
+ 2025-08-01 16:26:32,237 INFO MainThread:764673 [wandb_init.py:setup_run_log_directory():725] Logging internal logs to ./all_checkpoints/stage2_08011616_2datasets_qweninstruct/wandb/run-20250801_162632-790wkw0g/logs/debug-internal.log
+ 2025-08-01 16:26:32,237 INFO MainThread:764673 [wandb_init.py:init():852] calling init triggers
+ 2025-08-01 16:26:32,237 INFO MainThread:764673 [wandb_init.py:init():857] wandb.init called with sweep_config: {}
+ config: {'_wandb': {}}
+ 2025-08-01 16:26:32,237 INFO MainThread:764673 [wandb_init.py:init():893] starting backend
+ 2025-08-01 16:26:32,237 INFO MainThread:764673 [wandb_init.py:init():897] sending inform_init request
+ 2025-08-01 16:26:32,241 INFO MainThread:764673 [backend.py:_multiprocessing_setup():101] multiprocessing start_methods=fork,spawn,forkserver, using: spawn
+ 2025-08-01 16:26:32,241 INFO MainThread:764673 [wandb_init.py:init():907] backend started and connected
+ 2025-08-01 16:26:32,242 INFO MainThread:764673 [wandb_init.py:init():1005] updated telemetry
+ 2025-08-01 16:26:32,306 INFO MainThread:764673 [wandb_init.py:init():1029] communicating run to backend with 90.0 second timeout
+ 2025-08-01 16:26:49,938 INFO MainThread:764673 [wandb_init.py:init():1104] starting run threads in backend
+ 2025-08-01 16:26:50,132 INFO MainThread:764673 [wandb_run.py:_console_start():2573] atexit reg
+ 2025-08-01 16:26:50,132 INFO MainThread:764673 [wandb_run.py:_redirect():2421] redirect: wrap_raw
+ 2025-08-01 16:26:50,132 INFO MainThread:764673 [wandb_run.py:_redirect():2490] Wrapping output streams.
+ 2025-08-01 16:26:50,132 INFO MainThread:764673 [wandb_run.py:_redirect():2513] Redirects installed.
+ 2025-08-01 16:26:50,134 INFO MainThread:764673 [wandb_init.py:init():1150] run started, returning control to user process
+ 2025-08-01 16:26:58,724 INFO MainThread:764673 [wandb_run.py:_config_callback():1436] config_cb None None {'filename': 'stage2_08011616_2datasets_qweninstruct', 'seed': 42, 'mode': 'train', 'strategy': 'deepspeed', 'accelerator': 'gpu', 'devices': '0,1,2,3,4,5,6,7', 'precision': 'bf16-mixed', 'max_epochs': 10, 'accumulate_grad_batches': 1, 'check_val_every_n_epoch': 1, 'enable_flash': False, 'use_wandb_logger': True, 'mix_dataset': False, 'dataset': 'swiss-prot', 'save_every_n_epochs': 2, 'bert_name': '/nas/shared/kilab/wangyujia/ProtT3/plm_model/microsoft', 'cross_attention_freq': 2, 'num_query_token': 8, 'qformer_tune': 'train', 'llm_name': '/nas/shared/kilab/hf-hub/models--Qwen--Qwen2.5-7B-Instruct/snapshots/a09a35458c702b33eeacc393d103063234e8bc28', 'num_beams': 5, 'do_sample': False, 'max_inference_len': 128, 'min_inference_len': 1, 'llm_tune': 'mid_lora', 'peft_config': '', 'peft_dir': '', 'plm_model': '/nas/shared/kilab/wangyujia/ProtT3/plm_model/esm2-150m', 'plm_tune': 'freeze', 'lora_r': 8, 'lora_alpha': 16, 'lora_dropout': 0.1, 'enbale_gradient_checkpointing': False, 'weight_decay': 0.05, 'init_lr': 0.0001, 'min_lr': 1e-05, 'warmup_lr': 1e-06, 'warmup_steps': 1000, 'lr_decay_rate': 0.9, 'scheduler': 'linear_warmup_cosine_lr', 'stage1_path': '/nas/shared/kilab/wangyujia/ProtT3/all_checkpoints/stage1_07041727_2dataset/epoch=29.ckpt/converted.ckpt', 'stage2_path': '', 'init_checkpoint': '/nas/shared/kilab/wangyujia/ProtT3/all_checkpoints/stage2_07070513_2datasets_construct/epoch=09.ckpt/converted.ckpt', 'caption_eval_epoch': 5, 'num_workers': 8, 'batch_size': 4, 'inference_batch_size': 4, 'root': 'data', 'text_max_len': 2048, 'q_max_len': 29, 'a_max_len': 36, 'prot_max_len': 1024, 'prompt': 'The protein has the following properties:', 'filter_side_qa': False}
+ 2025-08-02 09:31:16,189 INFO MsgRouterThr:764673 [mailbox.py:close():129] [no run ID] Closing mailbox, abandoning 1 handles.
all_checkpoints/stage2_08011616_2datasets_withoutpretrain/wandb/run-20250801_162632-790wkw0g/run-790wkw0g.wandb ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:cb5d0be9388288b74f1f5a6f3f8c39d9ee6aa7643fe0c3426a6f51210ad77dc1
+ size 84438700
data/.gitignore ADDED
@@ -0,0 +1,2 @@
+ *
+ !.gitignore
data/OntoProteinDatasetV2/test.txt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:87c0edf3fd59defd24decb12c8af64d9bb5fa1ef727fce2bf5fc36c06fe4066f
+ size 12085289
data/OntoProteinDatasetV2/train.txt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:632a5830aea6b3029feaf2eeba5a1c2b63e2badc61b72f92734c517aef85e879
+ size 478799520
data/OntoProteinDatasetV2/valid.txt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:40cefeb9b08cff978fee2ce21f259974fa8881bbe6614f04e896ca7560d4e11f
+ size 11829916
data/PDBDataset/abstract.json ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:59cf438e3325c38809ab41fc09fe00ec78fc4bd41d7073f0c22f664168ea5ca3
+ size 190905202
data/PDBDataset/q_types.txt ADDED
@@ -0,0 +1,30 @@
+ What is the nucleic acid polymer entity type for this protein? String structure/property
+ When is this protein first published? Number side information
+ How many polymer monomers does this protein have? Number structure/property
+ How many assemblies does this protein have? Number structure/property
+ How many heavy solvent atom coordinates records does this protein have? Number side information
+ Does this protein have cis-peptide linkages? String structure/property
+ Does this protein contain branched entities? String structure/property
+ How many entities does this protein have? Number structure/property
+ Does this protein contain solvent entities? String structure/property
+ What is the polymer entity type for this protein? String structure/property
+ How many nucleic acid polymer entities (DNA or RNA) does this protein have? Number structure/property
+ What is the polymer entity composition for this protein? String structure/property
+ Does this protein contain DNA polymer entities? String structure/property
+ Does this protein contain non-polymer entities? String structure/property
+ How many heavy atom coordinates records does this protein have? Number side information
+ How many intermolecular metalic bonds does this protein have? Number structure/property
+ Does this protein have hybrid nucleic acid polymer entities? String structure/property
+ What is the molecular mass (KDa) of polymer and non-polymer entities (exclusive of solvent) for this protein? Number structure/property
+ Is this protein determined by experimental or computational methods? String side information
+ What is the radiation wavelength in angstroms for this protein? Number structure/property
+ Does this protein have unmodeled polymer monomers? String side information
+ How many intermolecular covalent bonds does this protein have? Number structure/property
+ How many model structures deposited for this protein? Number side information
+ What are the software programs reported in connection with the production of this protein? String side information
+ How many hydrogen atom coordinates records does this protein have? Number side information
+ What experimental method(s) were used to determine the structure of this protein? String side information
+ Does this protein contain polymer entities? String structure/property
+ Does this protein contain RNA polymer entities? String structure/property
+ What are the bound nonpolymer components for this protein String structure/property
+ What are the terms characterizing the protein? String structure/property
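Each line above pairs a question template with an answer type (String or Number) and a question category (structure/property or side information). A minimal parsing sketch under that assumption; the helper name and regex are hypothetical, not part of the repo:

import re

# Hypothetical helper: split one q_types.txt line into (question, answer_type, category).
Q_TYPE_RE = re.compile(r"^(.*?)\s+(String|Number)\s+(structure/property|side information)$")

def parse_q_type(line):
    m = Q_TYPE_RE.match(line.strip())
    if m is None:
        raise ValueError(f"unrecognized q_types line: {line!r}")
    return m.groups()  # (question, answer_type, category)

# parse_q_type("How many entities does this protein have? Number structure/property")
# -> ("How many entities does this protein have?", "Number", "structure/property")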
data/PDBDataset/qa_all.json ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:dfd9d775036717127b66aa9ba9c9d61f8fc324c4ea8fb548d6a1fd2ab4506194
+ size 523476104
data/PDBDataset/test.txt ADDED
The diff for this file is too large to render.
 
data/PDBDataset/train.txt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:494c6c0680bdedd58dc0da2df11393b123ad5a4998c68d432da0b1db777c8870
+ size 34040909
data/PDBDataset/val.txt ADDED
The diff for this file is too large to render.
 
data/SwissProtV3/test_set.jsonl ADDED
The diff for this file is too large to render.
 
data/SwissProtV3/train_set.jsonl ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:d92b283a50056b185af292082ea3c733d2c979e70ca233c19d4b06bf81381713
+ size 312044762
data/SwissProtV3/valid_set.jsonl ADDED
The diff for this file is too large to render.
 
data/protein-molecule/protein-text.zip ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:5ae14ae1857962332f2b73a3f7ef5651541e7be7d886bb71031b4fc745626f3a
+ size 1566756112
data/protein-text/eval_assist.zipg3ebgjl7.tmp ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:28e38b555c6d456da8e2b95dd2af19f884b2b86687f733d745b7972079dff708
+ size 113836360
data/protein-text/eval_assist.ziphwjr8q2y.tmp ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:58bd802e8fc350fd425e9a827245662790580004bc064aa84b265cad8f2e6d3f
+ size 9836668
data/protein-text/eval_assist.zipzh1pdmj_.tmp ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:3ddebab3f8ea6218c001053e6861781b90b2b70ef429bf9239761c2c0c7927e0
+ size 34493776
data_provider/__pycache__/bindingdb.cpython-310.pyc ADDED
Binary file (2.78 kB).

data_provider/__pycache__/go.cpython-310.pyc ADDED
Binary file (6.44 kB).

data_provider/__pycache__/metalIonbinding.cpython-310.pyc ADDED
Binary file (2.83 kB).

data_provider/__pycache__/mutation.cpython-310.pyc ADDED
Binary file (3.97 kB).

data_provider/__pycache__/production.cpython-310.pyc ADDED
Binary file (6.97 kB).

data_provider/__pycache__/prot_qa_dm.cpython-310.pyc ADDED
Binary file (7.92 kB).

data_provider/__pycache__/prot_qa_dm.cpython-311.pyc ADDED
Binary file (15.9 kB).

data_provider/__pycache__/stage1_dm.cpython-310.pyc ADDED
Binary file (14.9 kB).

data_provider/__pycache__/stage1_dm.cpython-311.pyc ADDED
Binary file (30.1 kB).

data_provider/__pycache__/stage2_dm.cpython-310.pyc ADDED
Binary file (12.4 kB).

data_provider/__pycache__/stage3_dm.cpython-310.pyc ADDED
Binary file (23.7 kB).

data_provider/__pycache__/stage3_dm.cpython-311.pyc ADDED
Binary file (12.9 kB).

data_provider/bindingdb.py ADDED
@@ -0,0 +1,63 @@
+ import json
+ import random
+ from pytorch_lightning import LightningDataModule
+ from torch.utils.data import DataLoader, ConcatDataset, Dataset
+ from data_provider.stage1_dm import SwissProtDataset, OntoProteinDataset
+ import pandas as pd
+
+ class BindingDB(Dataset):
+     def __init__(self, data_path, prompt='', return_prompt=False):
+         super(BindingDB, self).__init__()
+         self.data_path = data_path
+         self.user_prompt = prompt
+         self.return_prompt = return_prompt
+
+         self.data_list = self._load_and_preprocess(self.data_path)
+         self.text2id = self._build_text_vocab()
+
+     def _load_and_preprocess(self, data_path):
+         data_list = []
+         df = pd.read_csv(data_path)
+         for _, row in df.iterrows():
+             try:
+                 ligand_smiles = str(row['ligand']).strip()
+                 prot_seq = str(row['protein']).strip()
+                 result = str(row['ic50']).strip()
+
+                 text_seq = f"<answer>{result}</answer>\n"
+
+                 prompt = f"""
+                 【Protein sequence (1-letter amino acid codes)】;{ligand_smiles}【Ligand structure (SMILES)】
+                 Task: Evaluate the inhibitory effect of the ligand on the given protein.
+                 Note: IC50 (half maximal inhibitory concentration) is the concentration of a substance required to inhibit 50% of the protein's activity. Lower IC50 values indicate stronger inhibition.
+                 Based on the provided protein and ligand, predict the inhibitory strength by classifying the IC50 level:
+                 """
+                 if self.user_prompt:
+                     prompt += self.user_prompt
+
+                 # `extra` could also return the raw feather string, feather_vals,
+                 # or feather_raw
+                 data_list.append((prot_seq, text_seq, prompt))
+             except Exception as e:
+                 print(f"Warning: skipping malformed row: {row}, reason: {e}")
+         return data_list
+
+     def _build_text_vocab(self):
+         text2id = {}
+         for _, text_seq, _ in self.data_list:
+             if text_seq not in text2id:
+                 text2id[text_seq] = len(text2id)
+         return text2id
+
+     def shuffle(self):
+         random.shuffle(self.data_list)
+         return self
+
+     def __len__(self):
+         return len(self.data_list)
+
+     def __getitem__(self, index):
+         prot_seq, text_seq, prompt = self.data_list[index]
+         if self.return_prompt:
+             return prot_seq, prompt, text_seq, index
+         return prot_seq, text_seq, index
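A minimal usage sketch for BindingDB; the CSV path is hypothetical, and the file must provide 'ligand', 'protein', and 'ic50' columns as read above:

from torch.utils.data import DataLoader
from data_provider.bindingdb import BindingDB

dataset = BindingDB('data/bindingdb/train.csv', return_prompt=True)  # hypothetical path
prot_seq, prompt, text_seq, idx = dataset[0]  # return_prompt=True yields a 4-tuple
loader = DataLoader(dataset, batch_size=4, shuffle=True)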
data_provider/gal_helpers.py ADDED
@@ -0,0 +1,45 @@
+ import re
+
+
+ # we split individual characters inside special tokens like [START_DNA]
+ CUSTOM_SEQ_RE = re.compile(r"(\[START_(DNA|SMILES|I_SMILES|AMINO)])(.*?)(\[END_\2])")
+
+ # token added to implement a custom sequence tokenization. This token is added at
+ # corpus cleaning step and removed in pretokenization. The digits are added to increase the chance
+ # that they do not occur in the corpus. The digits are escaped so that the token does not appear
+ # literally in the source code in case we ever include it in the training data.
+ SPLIT_MARKER = f"SPL{1}T-TH{1}S-Pl3A5E"
+
+ def _insert_split_marker(m: re.Match):
+     """
+     Applies split marker based on a regex match of special tokens such as
+     [START_DNA].
+
+     Parameters
+     ----------
+     m : re.Match
+         Regex match over a custom sequence span
+
+     Returns
+     ----------
+     str - the text with the split token added
+     """
+     start_token, _, sequence, end_token = m.groups()
+     sequence = re.sub(r"(.)", fr"{SPLIT_MARKER}\1", sequence, flags=re.DOTALL)
+     return f"{start_token}{sequence}{SPLIT_MARKER}{end_token}"
+
+
+ def escape_custom_split_sequence(text):
+     """
+     Applies custom splitting to the text for GALILEO's tokenization
+
+     Parameters
+     ----------
+     text : str
+         Input text to split
+
+     Returns
+     ----------
+     str - the text with the split token added
+     """
+     return CUSTOM_SEQ_RE.sub(_insert_split_marker, text)
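A quick check of the escaping behaviour, following directly from the definitions above:

from data_provider.gal_helpers import escape_custom_split_sequence, SPLIT_MARKER

escaped = escape_custom_split_sequence("[START_AMINO]MKT[END_AMINO]")
# Each residue is prefixed with SPLIT_MARKER, and one extra marker precedes the end token:
assert escaped == ("[START_AMINO]"
                   + "".join(SPLIT_MARKER + c for c in "MKT")
                   + SPLIT_MARKER + "[END_AMINO]")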
data_provider/go.py ADDED
@@ -0,0 +1,238 @@
+ import json
+ import random
+ from pytorch_lightning import LightningDataModule
+ from torch.utils.data import DataLoader, ConcatDataset, Dataset
+ from data_provider.stage1_dm import SwissProtDataset, OntoProteinDataset
+ import pandas as pd
+
+ class GO_BP(Dataset):
+     def __init__(self, data_path, prompt='', return_prompt=False):
+         super(GO_BP, self).__init__()
+         self.data_path = data_path
+         self.user_prompt = prompt
+         self.return_prompt = return_prompt
+
+         self.data_list = self._load_and_preprocess(self.data_path)
+         self.text2id = self._build_text_vocab()
+
+     def _load_and_preprocess(self, data_path):
+         data_list = []
+         df = pd.read_csv(data_path)
+         for _, row in df.iterrows():
+             try:
+                 prot_seq = str(row['question']).strip()
+                 result = str(row['answer']).strip()
+
+                 text_seq = f"<answer>{result}</answer>\n"
+
+                 prompt = f"""
+                 【Task】Predict the biological processes involving the given protein.
+                 【Background】Each process is represented by a GO-BP term (e.g., GO:0008150), describing a series of molecular events relevant to protein function.
+                 【Output Format】List the predicted GO-BP terms, separated by commas, and wrap them in <answer> </answer> tags.
+                 Example: <answer>GO:0008150, GO:0009987, GO:0050896</answer>
+                 """
+                 if self.user_prompt:
+                     prompt += self.user_prompt
+
+                 # `extra` could also return the raw feather string, feather_vals,
+                 # or feather_raw
+                 data_list.append((prot_seq, text_seq, prompt))
+             except Exception as e:
+                 print(f"Warning: skipping malformed row: {row}, reason: {e}")
+         return data_list
+
+     def _build_text_vocab(self):
+         text2id = {}
+         for _, text_seq, _ in self.data_list:
+             if text_seq not in text2id:
+                 text2id[text_seq] = len(text2id)
+         return text2id
+
+     def shuffle(self):
+         random.shuffle(self.data_list)
+         return self
+
+     def __len__(self):
+         return len(self.data_list)
+
+     def __getitem__(self, index):
+         prot_seq, text_seq, prompt = self.data_list[index]
+         if self.return_prompt:
+             return prot_seq, prompt, text_seq, index
+         return prot_seq, text_seq, index
+
+
+ class GO_CC(Dataset):
+     def __init__(self, data_path, prompt='', return_prompt=False):
+         super(GO_CC, self).__init__()
+         self.data_path = data_path
+         self.user_prompt = prompt
+         self.return_prompt = return_prompt
+
+         self.data_list = self._load_and_preprocess(self.data_path)
+         self.text2id = self._build_text_vocab()
+
+     def _load_and_preprocess(self, data_path):
+         data_list = []
+         df = pd.read_csv(data_path)
+         for _, row in df.iterrows():
+             try:
+                 prot_seq = str(row['question']).strip()
+                 result = str(row['answer']).strip()
+
+                 text_seq = f"<answer>{result}</answer>\n"
+
+                 prompt = f"""
+                 【Task】Predict the cellular components associated with this protein.
+                 【Background】Each location is represented by a GO-CC term (e.g., GO:0005737), indicating where the protein functions within the cell.
+                 【Output Format】List the predicted GO-CC terms, separated by commas, and wrap them in <answer> </answer> tags.
+                 Example: <answer>GO:0005737, GO:0005829, GO:0005886</answer>
+                 """
+                 if self.user_prompt:
+                     prompt += self.user_prompt
+
+                 # `extra` could also return the raw feather string, feather_vals,
+                 # or feather_raw
+                 data_list.append((prot_seq, text_seq, prompt))
+             except Exception as e:
+                 print(f"Warning: skipping malformed row: {row}, reason: {e}")
+         return data_list
+
+     def _build_text_vocab(self):
+         text2id = {}
+         for _, text_seq, _ in self.data_list:
+             if text_seq not in text2id:
+                 text2id[text_seq] = len(text2id)
+         return text2id
+
+     def shuffle(self):
+         random.shuffle(self.data_list)
+         return self
+
+     def __len__(self):
+         return len(self.data_list)
+
+     def __getitem__(self, index):
+         prot_seq, text_seq, prompt = self.data_list[index]
+         if self.return_prompt:
+             return prot_seq, prompt, text_seq, index
+         return prot_seq, text_seq, index
+
+
+ class GO_MF(Dataset):
+     def __init__(self, data_path, prompt='', return_prompt=False):
+         super(GO_MF, self).__init__()
+         self.data_path = data_path
+         self.user_prompt = prompt
+         self.return_prompt = return_prompt
+
+         self.data_list = self._load_and_preprocess(self.data_path)
+         self.text2id = self._build_text_vocab()
+
+     def _load_and_preprocess(self, data_path):
+         data_list = []
+         df = pd.read_csv(data_path)
+         for _, row in df.iterrows():
+             try:
+                 prot_seq = str(row['question']).strip()
+                 result = str(row['answer']).strip()
+
+                 text_seq = f"<answer>{result}</answer>\n"
+
+                 prompt = f"""
+                 【Task】Predict the molecular functions performed by this protein.
+                 【Background】Each function is represented by a GO-MF term (e.g., GO:0003677), describing specific biochemical activities of the protein.
+                 【Output Format】List the predicted GO-MF terms, separated by commas, and wrap them in <answer> </answer> tags.
+                 Example: <answer>GO:0003677, GO:0005524, GO:0016787</answer>
+                 """
+                 if self.user_prompt:
+                     prompt += self.user_prompt
+
+                 # `extra` could also return the raw feather string, feather_vals,
+                 # or feather_raw
+                 data_list.append((prot_seq, text_seq, prompt))
+             except Exception as e:
+                 print(f"Warning: skipping malformed row: {row}, reason: {e}")
+         return data_list
+
+     def _build_text_vocab(self):
+         text2id = {}
+         for _, text_seq, _ in self.data_list:
+             if text_seq not in text2id:
+                 text2id[text_seq] = len(text2id)
+         return text2id
+
+     def shuffle(self):
+         random.shuffle(self.data_list)
+         return self
+
+     def __len__(self):
+         return len(self.data_list)
+
+     def __getitem__(self, index):
+         prot_seq, text_seq, prompt = self.data_list[index]
+         if self.return_prompt:
+             return prot_seq, prompt, text_seq, index
+         return prot_seq, text_seq, index
+
+
+ class EC(Dataset):
+     def __init__(self, data_path, prompt='', return_prompt=False):
+         super(EC, self).__init__()
+         self.data_path = data_path
+         self.user_prompt = prompt
+         self.return_prompt = return_prompt
+
+         self.data_list = self._load_and_preprocess(self.data_path)
+         self.text2id = self._build_text_vocab()
+
+     def _load_and_preprocess(self, data_path):
+         data_list = []
+         df = pd.read_csv(data_path)
+         for _, row in df.iterrows():
+             try:
+                 name = str(row['name']).strip()
+                 # First split on '-' to get the structure part and the UniProt ID
+                 structure_part, uniprot_id = name.split('-')  # '3r7t_A', 'Q9PMG4'
+
+                 # Then split the structure part on '_' to get the PDB ID and chain ID
+                 pdb_id, chain_id = structure_part.split('_')  # '3r7t', 'A'
+                 prot_seq = str(row['aa_seq']).strip()
+                 result = str(row['label']).strip()
+
+                 text_seq = f"<answer>{result}</answer>\n"
+
+                 prompt = f"""
+                 The information provided above describes a protein: chain {chain_id} of the crystal structure {pdb_id}, with accession {uniprot_id} in the UniProt sequence database.
+                 Based on this information, infer its possible enzyme activity and predict the corresponding EC number(s).
+                 【Output Format】List the predicted EC numbers, separated by commas, wrapped in <answer> </answer> tags.
+                 """
+                 if self.user_prompt:
+                     prompt += self.user_prompt
+
+                 # `extra` could also return the raw feather string, feather_vals,
+                 # or feather_raw
+                 data_list.append((prot_seq, text_seq, prompt))
+             except Exception as e:
+                 print(f"Warning: skipping malformed row: {row}, reason: {e}")
+         return data_list
+
+     def _build_text_vocab(self):
+         text2id = {}
+         for _, text_seq, _ in self.data_list:
+             if text_seq not in text2id:
+                 text2id[text_seq] = len(text2id)
+         return text2id
+
+     def shuffle(self):
+         random.shuffle(self.data_list)
+         return self
+
+     def __len__(self):
+         return len(self.data_list)
+
+     def __getitem__(self, index):
+         prot_seq, text_seq, prompt = self.data_list[index]
+         if self.return_prompt:
+             return prot_seq, prompt, text_seq, index
+         return prot_seq, text_seq, index
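For reference, the EC loader's `name` parsing, shown on the example from its own inline comments:

name = '3r7t_A-Q9PMG4'
structure_part, uniprot_id = name.split('-')  # '3r7t_A', 'Q9PMG4'
pdb_id, chain_id = structure_part.split('_')  # '3r7t', 'A'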
data_provider/llm_tuning_dm.py ADDED
@@ -0,0 +1,261 @@
+ # Copyright (c) Microsoft Corporation.
+ # Licensed under the MIT License.
+ from pytorch_lightning import LightningDataModule
+ from data_provider.gal_helpers import escape_custom_split_sequence
+ from data_provider.stage1_dm import SwissProtDataset, OntoProteinDataset
+ from torch.utils.data import DataLoader, ConcatDataset
+
+
+ class LLMTuningCollater:
+     def __init__(self, tokenizer, text_max_len, prot_max_len, use_gal):
+         self.text_max_len = text_max_len
+         self.prot_max_len = prot_max_len
+         self.tokenizer = tokenizer
+         self.use_gal = use_gal
+
+     def __call__(self, batch):
+         prot_seqs, prompt_seqs, text_seqs, _ = zip(*batch)
+         prot_seqs = [prompt.format(p) for prompt, p in zip(prompt_seqs, prot_seqs)]
+         if self.use_gal:
+             prot_seqs = [escape_custom_split_sequence(p) for p in prot_seqs]
+         ## deal with prompt
+         self.tokenizer.padding_side = 'left'
+         prot_batch = self.tokenizer(text=prot_seqs,
+                                     truncation=True,
+                                     padding='max_length',
+                                     add_special_tokens=True,
+                                     max_length=self.prot_max_len,
+                                     return_tensors='pt',
+                                     return_attention_mask=True)
+         self.tokenizer.padding_side = 'right'
+         text_batch = self.tokenizer(text=text_seqs,
+                                     truncation=True,
+                                     padding='max_length',
+                                     add_special_tokens=True,
+                                     max_length=self.text_max_len,
+                                     return_tensors='pt',
+                                     return_attention_mask=True)
+         return prot_batch, text_batch
+
+
+ class InferenceCollater:
+     def __init__(self, tokenizer, text_max_len, prot_max_len, use_gal):
+         self.text_max_len = text_max_len
+         self.prot_max_len = prot_max_len
+         self.tokenizer = tokenizer
+         self.use_gal = use_gal
+
+     def __call__(self, batch):
+         prot_seqs, prompt_seqs, text_seqs, indices = zip(*batch)
+         prot_seqs = [prompt.format(p) for prompt, p in zip(prompt_seqs, prot_seqs)]
+         if self.use_gal:
+             prot_seqs = [escape_custom_split_sequence(p) for p in prot_seqs]
+         ## deal with prompt
+         self.tokenizer.padding_side = 'left'
+         prot_batch = self.tokenizer(text=prot_seqs,
+                                     truncation=True,
+                                     padding='max_length',
+                                     add_special_tokens=True,
+                                     max_length=self.prot_max_len,
+                                     return_tensors='pt',
+                                     return_attention_mask=True)
+         target_dict = {'targets': text_seqs, 'indices': indices}
+         return prot_batch, target_dict
+
+
+
+ class LLMTuningDM(LightningDataModule):
+     def __init__(
+         self,
+         root: str = 'data/',
+         args=None,
+     ):
+         super().__init__()
+         self.batch_size = args.batch_size
+         self.inference_batch_size = args.inference_batch_size
+         self.num_workers = args.num_workers
+         self.prot_max_len = args.prot_max_len
+         self.text_max_len = args.text_max_len
+         if root.find('SwissProtV3') >= 0:
+             self.train_dataset = SwissProtDataset(root+'/train_set.json', prompt='[START_AMINO]{}[END_AMINO]. Swiss-Prot description: ', return_prompt=True)
+             self.val_dataset = SwissProtDataset(root+'/valid_set.json', prompt='[START_AMINO]{}[END_AMINO]. Swiss-Prot description: ', return_prompt=True)
+             self.test_dataset = SwissProtDataset(root+'/test_set.json', prompt='[START_AMINO]{}[END_AMINO]. Swiss-Prot description: ', return_prompt=True)
+         elif root.find('OntoProteinDatasetV2') >= 0:
+             self.train_dataset = OntoProteinDataset(root+'/train.txt', prompt='[START_AMINO]{}[END_AMINO]. Gene Ontology description: ', return_prompt=True)
+             self.val_dataset = OntoProteinDataset(root+'/valid.txt', prompt='[START_AMINO]{}[END_AMINO]. Gene Ontology description: ', return_prompt=True)
+             self.test_dataset = OntoProteinDataset(root+'/test.txt', prompt='[START_AMINO]{}[END_AMINO]. Gene Ontology description: ', return_prompt=True)
+         else:
+             raise NotImplementedError()
+         self.tokenizer = None
+         self.use_gal = args.llm_name.find('gal') >= 0
+
+     def init_tokenizer(self, tokenizer):
+         self.tokenizer = tokenizer
+
+     def train_dataloader(self):
+         loader = DataLoader(
+             self.train_dataset,
+             batch_size=self.batch_size,
+             shuffle=True,
+             num_workers=self.num_workers,
+             pin_memory=False,
+             drop_last=True,
+             persistent_workers=False,
+             collate_fn=LLMTuningCollater(self.tokenizer, self.text_max_len, self.prot_max_len, self.use_gal),
+         )
+         return loader
+
+     def val_dataloader(self):
+         val_loader = DataLoader(
+             self.val_dataset,
+             batch_size=self.batch_size,
+             shuffle=False,
+             num_workers=self.num_workers,
+             pin_memory=False,
+             drop_last=False,
+             persistent_workers=False,
+             collate_fn=LLMTuningCollater(self.tokenizer, self.text_max_len, self.prot_max_len, self.use_gal),
+         )
+         test_loader = DataLoader(
+             self.test_dataset,
+             batch_size=self.inference_batch_size,
+             shuffle=False,
+             num_workers=self.num_workers,
+             pin_memory=False,
+             drop_last=False,
+             persistent_workers=False,
+             collate_fn=InferenceCollater(self.tokenizer, self.text_max_len, self.prot_max_len, self.use_gal),
+         )
+         return [val_loader, test_loader]
+
+     def add_model_specific_args(parent_parser):
+         parser = parent_parser.add_argument_group("Data module")
+         parser.add_argument('--num_workers', type=int, default=2)
+         parser.add_argument('--batch_size', type=int, default=32)
+         parser.add_argument('--inference_batch_size', type=int, default=4)
+         parser.add_argument('--root', type=str, default='data/SwissProtV3')
+         parser.add_argument('--text_max_len', type=int, default=128)
+         parser.add_argument('--prot_max_len', type=int, default=1024)
+         parser.add_argument('--q_max_len', type=int, default=1064)
+         parser.add_argument('--a_max_len', type=int, default=36)
+         parser.add_argument('--prompt', type=str, default='[START_AMINO]{}[END_AMINO]. The protein has the following properties: ')
+         parser.add_argument('--filter_side_qa', action='store_true', default=False)
+         return parent_parser
+
+
+ class LLMTuningMixDM(LightningDataModule):
+     def __init__(
+         self,
+         root: str = 'data/',
+         args=None,
+     ):
+         super().__init__()
+         self.batch_size = args.batch_size
+         self.inference_batch_size = args.inference_batch_size
+         self.num_workers = args.num_workers
+         self.prot_max_len = args.prot_max_len
+         self.text_max_len = args.text_max_len
+
+         train_dataset1 = SwissProtDataset(root+'/SwissProtV3/train_set.json', prompt='[START_AMINO]{}[END_AMINO]. Swiss-Prot description: ', return_prompt=True)
+         train_dataset2 = OntoProteinDataset(root+'/OntoProteinDatasetV2/train.txt', prompt='[START_AMINO]{}[END_AMINO]. Gene Ontology description: ', return_prompt=True)
+         self.train_dataset = ConcatDataset([train_dataset1, train_dataset2])
+         self.swiss_val_dataset = SwissProtDataset(root+'/SwissProtV3/valid_set.json', prompt='[START_AMINO]{}[END_AMINO]. Swiss-Prot description: ', return_prompt=True)
+         self.onto_val_dataset = OntoProteinDataset(root+'/OntoProteinDatasetV2/valid.txt', prompt='[START_AMINO]{}[END_AMINO]. Gene Ontology description: ', return_prompt=True)
+         self.swiss_test_dataset = SwissProtDataset(root+'/SwissProtV3/test_set.json', prompt='[START_AMINO]{}[END_AMINO]. Swiss-Prot description: ', return_prompt=True)
+         self.onto_test_dataset = OntoProteinDataset(root+'/OntoProteinDatasetV2/test.txt', prompt='[START_AMINO]{}[END_AMINO]. Gene Ontology description: ', return_prompt=True)
+
+         self.tokenizer = None
+         self.use_gal = args.llm_name.find('gal') >= 0
+
+     def init_tokenizer(self, tokenizer):
+         self.tokenizer = tokenizer
+
+     def train_dataloader(self):
+         loader = DataLoader(
+             self.train_dataset,
+             batch_size=self.batch_size,
+             shuffle=True,
+             num_workers=self.num_workers,
+             pin_memory=False,
+             drop_last=True,
+             persistent_workers=False,
+             collate_fn=LLMTuningCollater(self.tokenizer, self.text_max_len, self.prot_max_len, self.use_gal),
+         )
+         return loader
+
+     def val_dataloader(self):
+         swiss_val_loader = DataLoader(
+             self.swiss_val_dataset,
+             batch_size=self.batch_size,
+             shuffle=False,
+             num_workers=self.num_workers,
+             pin_memory=False,
+             drop_last=False,
+             persistent_workers=False,
+             collate_fn=LLMTuningCollater(self.tokenizer, self.text_max_len, self.prot_max_len, self.use_gal),
+         )
+         swiss_test_loader = DataLoader(
+             self.swiss_test_dataset,
+             batch_size=self.inference_batch_size,
+             shuffle=False,
+             num_workers=self.num_workers,
+             pin_memory=False,
+             drop_last=False,
+             persistent_workers=False,
+             collate_fn=InferenceCollater(self.tokenizer, self.text_max_len, self.prot_max_len, self.use_gal),
+         )
+
+         onto_val_loader = DataLoader(
+             self.onto_val_dataset,
+             batch_size=self.batch_size,
+             shuffle=False,
+             num_workers=self.num_workers,
+             pin_memory=False,
+             drop_last=False,
+             persistent_workers=False,
+             collate_fn=LLMTuningCollater(self.tokenizer, self.text_max_len, self.prot_max_len, self.use_gal),
+         )
+         onto_test_loader = DataLoader(
+             self.onto_test_dataset,
+             batch_size=self.inference_batch_size,
+             shuffle=False,
+             num_workers=self.num_workers,
+             pin_memory=False,
+             drop_last=False,
+             persistent_workers=False,
+             collate_fn=InferenceCollater(self.tokenizer, self.text_max_len, self.prot_max_len, self.use_gal),
+         )
+         return [swiss_val_loader, swiss_test_loader, onto_val_loader, onto_test_loader]
+
+     def add_model_specific_args(parent_parser):
+         parser = parent_parser.add_argument_group("Data module")
+         parser.add_argument('--num_workers', type=int, default=2)
+         parser.add_argument('--batch_size', type=int, default=32)
+         parser.add_argument('--inference_batch_size', type=int, default=4)
+         parser.add_argument('--root', type=str, default='data/SwissProtV3')
+         parser.add_argument('--text_max_len', type=int, default=128)
+         parser.add_argument('--prot_max_len', type=int, default=1024)
+         parser.add_argument('--q_max_len', type=int, default=1064)
+         parser.add_argument('--a_max_len', type=int, default=36)
+         parser.add_argument('--prompt', type=str, default='[START_AMINO]{}[END_AMINO]. The protein has the following properties: ')
+         parser.add_argument('--filter_side_qa', action='store_true', default=False)
+         return parent_parser
+
+
+ if __name__ == '__main__':
+     dataset = SwissProtDataset('../data/SwissProtV3/train_set.json', prompt='[START_AMINO]{}[END_AMINO]. Swiss-Prot description: ', return_prompt=True)
+     from transformers import AutoTokenizer
+     tokenizer = AutoTokenizer.from_pretrained('facebook/galactica-1.3b')
+     tokenizer.add_special_tokens({'pad_token': '<pad>'})
+     loader = DataLoader(
+         dataset,
+         batch_size=16,
+         shuffle=True,
+         num_workers=0,
+         pin_memory=False,
+         drop_last=True,
+         persistent_workers=False,
+         collate_fn=LLMTuningCollater(tokenizer, 128, 1024, True),  # tokenizer, text_max_len, prot_max_len, use_gal
+     )
+     for data in loader:
+         input()  # pause between batches for manual inspection
data_provider/llm_tuning_prot_qa_dm.py ADDED
@@ -0,0 +1,164 @@
+ # Copyright (c) Microsoft Corporation.
+ # Licensed under the MIT License.
+ from pytorch_lightning import LightningDataModule
+ from torch.utils.data import DataLoader
+ from data_provider.prot_qa_dm import PDBQADataset
+ from data_provider.gal_helpers import escape_custom_split_sequence
+
+
+ class LLMTuningProtQACollater(object):
+     def __init__(self, tokenizer, q_max_len, a_max_len, use_gal, prompt):
+         self.tokenizer = tokenizer
+         self.q_max_len = q_max_len
+         self.a_max_len = a_max_len
+         self.use_gal = use_gal
+         self.prompt = prompt
+         assert prompt.find('{}') >= 0
+
+     def __call__(self, batch):
+         prot_seqs, questions, answers, q_types = zip(*batch)
+         assert len(prot_seqs) == len(questions) == len(answers)
+         questions = [self.prompt.format(prot_seqs[i], questions[i]) for i in range(len(prot_seqs))]
+
+         if self.use_gal:
+             questions = [escape_custom_split_sequence(q) for q in questions]
+         answers = [a + '\n' for a in answers]
+         if False:  # disabled path: tokenize questions and answers separately; kept for reference
+             self.tokenizer.padding_side = 'left'
+             q_batch = self.tokenizer(questions,
+                                      truncation=True,
+                                      padding='max_length',
+                                      add_special_tokens=True,
+                                      max_length=self.q_max_len,
+                                      return_tensors='pt',
+                                      return_attention_mask=True,
+                                      return_token_type_ids=False)
+             self.tokenizer.padding_side = 'right'
+             a_batch = self.tokenizer(answers,
+                                      truncation=True,
+                                      padding='max_length',
+                                      add_special_tokens=True,
+                                      max_length=self.a_max_len,
+                                      return_tensors='pt',
+                                      return_attention_mask=True,
+                                      return_token_type_ids=False)
+             return q_batch, a_batch
+         else:
+             self.tokenizer.padding_side = 'right'
+             qa_pair = [[q, a] for q, a in zip(questions, answers)]
+             qa_batch = self.tokenizer(qa_pair,
+                                       truncation=True,
+                                       padding='max_length',
+                                       add_special_tokens=True,
+                                       max_length=self.q_max_len + self.a_max_len,
+                                       return_tensors='pt',
+                                       return_attention_mask=True,
+                                       return_token_type_ids=True)
+             return qa_batch
+
+
+ class InferenceCollater(object):
+     def __init__(self, tokenizer, q_max_len, a_max_len, use_gal, prompt):
+         self.tokenizer = tokenizer
+         self.q_max_len = q_max_len
+         self.a_max_len = a_max_len
+         self.use_gal = use_gal
+         self.prompt = prompt
+         assert prompt.find('{}') >= 0
+
+     def __call__(self, batch):
+         prot_seqs, questions, answers, q_types, indices = zip(*batch)
+         assert len(prot_seqs) == len(questions) == len(answers)
+         questions = [self.prompt.format(prot_seqs[i], questions[i]) for i in range(len(prot_seqs))]
+
+         if self.use_gal:
+             questions = [escape_custom_split_sequence(q) for q in questions]
+         answers = [a + '\n' for a in answers]
+         self.tokenizer.padding_side = 'left'
+         q_batch = self.tokenizer(questions,
+                                  truncation=True,
+                                  padding='max_length',
+                                  add_special_tokens=True,
+                                  max_length=self.q_max_len,
+                                  return_tensors='pt',
+                                  return_attention_mask=True,
+                                  return_token_type_ids=False)
+         target_dict = {'targets': answers, 'q_types': q_types, 'indices': indices}
+         return q_batch, target_dict
+
+
+ class LLMTuningProtQADM(LightningDataModule):
+     def __init__(
+         self,
+         root: str = 'data/',
+         args=None,
+     ):
+         super().__init__()
+         self.args = args
+         self.batch_size = args.batch_size
+         self.inference_batch_size = args.inference_batch_size
+         self.num_workers = args.num_workers
+         self.q_max_len = args.q_max_len
+         self.a_max_len = args.a_max_len
+         self.prompt = args.prompt
+
+         self.train_dataset = PDBQADataset(root, 'train.txt', "Question: {} Answer:", filter_side_qa=args.filter_side_qa)
+         self.val_dataset = PDBQADataset(root, 'val.txt', "Question: {} Answer:", filter_side_qa=args.filter_side_qa)
+         self.test_dataset = PDBQADataset(root, 'test.txt', "Question: {} Answer:", filter_side_qa=args.filter_side_qa)
+
+         self.tokenizer = None
+         self.use_gal = args.llm_name.find('gal') >= 0
+
+     def init_tokenizer(self, tokenizer):
+         self.tokenizer = tokenizer
+
+     def train_dataloader(self):
+         loader = DataLoader(
+             self.train_dataset,
+             batch_size=self.batch_size,
+             shuffle=True,
+             num_workers=self.num_workers,
+             pin_memory=False,
+             drop_last=True,
+             persistent_workers=False,
+             collate_fn=LLMTuningProtQACollater(self.tokenizer, self.q_max_len, self.a_max_len, self.use_gal, self.prompt),
+         )
+         return loader
+
+     def val_dataloader(self):
+         val_loader = DataLoader(
+             self.val_dataset,
+             batch_size=self.batch_size,
+             shuffle=False,
+             num_workers=self.num_workers,
+             pin_memory=False,
+             drop_last=False,
+             persistent_workers=False,
+             collate_fn=LLMTuningProtQACollater(self.tokenizer, self.q_max_len, self.a_max_len, self.use_gal, self.prompt),
+         )
+         test_loader = DataLoader(
+             self.test_dataset,
+             batch_size=self.inference_batch_size,
+             shuffle=False,
+             num_workers=self.num_workers,
+             pin_memory=False,
+             drop_last=False,
+             persistent_workers=False,
+             collate_fn=InferenceCollater(self.tokenizer, self.q_max_len, self.a_max_len, self.use_gal, self.prompt),
+         )
+         return [val_loader, test_loader]
+
+
+     def add_model_specific_args(parent_parser):
+         parser = parent_parser.add_argument_group("Data module")
+         parser.add_argument('--num_workers', type=int, default=2)
+         parser.add_argument('--batch_size', type=int, default=32)
+         parser.add_argument('--inference_batch_size', type=int, default=4)
+         parser.add_argument('--root', type=str, default='data/SwissProtV3')
+         parser.add_argument('--q_max_len', type=int, default=1064)
+         parser.add_argument('--a_max_len', type=int, default=36)
+         parser.add_argument('--prompt', type=str, default='[START_AMINO]{}[END_AMINO]. {}')
+         parser.add_argument('--filter_side_qa', action='store_true', default=False)
+         return parent_parser
+
+
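A note on the two padding sides used above: prompts are left-padded at inference so generation starts immediately after the last real token, while training inputs are right-padded so labels align from the start of the sequence. A minimal illustration (the model choice here is hypothetical):

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained('facebook/galactica-1.3b')  # hypothetical choice
tok.add_special_tokens({'pad_token': '<pad>'})

tok.padding_side = 'left'   # inference: pad tokens go before the prompt
left = tok(['short prompt'], padding='max_length', max_length=8, return_tensors='pt')

tok.padding_side = 'right'  # training: pad tokens go after the text
right = tok(['short prompt'], padding='max_length', max_length=8, return_tensors='pt')
# left['input_ids'] starts with pad ids; right['input_ids'] ends with them.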
data_provider/metalIonbinding.py ADDED
@@ -0,0 +1,64 @@
+ import json
+ import random
+ from pytorch_lightning import LightningDataModule
+ from torch.utils.data import DataLoader, ConcatDataset, Dataset
+ from data_provider.stage1_dm import SwissProtDataset, OntoProteinDataset
+ import pandas as pd
+
+ class MetallonBinding(Dataset):
+     def __init__(self, data_path, prompt='', return_prompt=False):
+         super(MetallonBinding, self).__init__()
+         self.data_path = data_path
+         self.user_prompt = prompt
+         self.return_prompt = return_prompt
+
+         self.data_list = self._load_and_preprocess(self.data_path)
+         self.text2id = self._build_text_vocab()
+
+     def _load_and_preprocess(self, data_path):
+         data_list = []
+         df = pd.read_csv(data_path)
+         for _, row in df.iterrows():
+             try:
+                 name = str(row['name']).strip()
+                 prot_seq = str(row['aa_seq']).strip()
+                 result = str(int(row['label'])).strip()
+
+                 text_seq = f"<answer>{result}</answer>\n"
+
+                 prompt = f"""
+                 Task: Determine whether this protein is a metalloprotein based on the provided sequence and protein name {name}.
+                 Background: Metalloproteins are proteins that bind metal ions, often through specific amino acid residues such as histidine (H), cysteine (C), aspartate (D), or glutamate (E).
+                 Question: Does this protein bind metal ions? Please choose one of the following options:
+                 0: Non-metalloprotein — This protein does **not** bind to any metal ions.
+                 1: Metalloprotein — This protein **binds** to one or more metal ions.
+                 """
+                 if self.user_prompt:
+                     prompt += self.user_prompt
+
+                 # `extra` could also return the raw feather string, feather_vals,
+                 # or feather_raw
+                 data_list.append((prot_seq, text_seq, prompt))
+             except Exception as e:
+                 print(f"Warning: skipping malformed row: {row}, reason: {e}")
+         return data_list
+
+     def _build_text_vocab(self):
+         text2id = {}
+         for _, text_seq, _ in self.data_list:
+             if text_seq not in text2id:
+                 text2id[text_seq] = len(text2id)
+         return text2id
+
+     def shuffle(self):
+         random.shuffle(self.data_list)
+         return self
+
+     def __len__(self):
+         return len(self.data_list)
+
+     def __getitem__(self, index):
+         prot_seq, text_seq, prompt = self.data_list[index]
+         if self.return_prompt:
+             return prot_seq, prompt, text_seq, index
+         return prot_seq, text_seq, index
data_provider/mutation.py ADDED
@@ -0,0 +1,120 @@
+ import json
+ import random
+ from pytorch_lightning import LightningDataModule
+ from torch.utils.data import DataLoader, ConcatDataset, Dataset
+ from data_provider.stage1_dm import SwissProtDataset, OntoProteinDataset
+ import pandas as pd
+
+ class TAPE_Stability(Dataset):
+     def __init__(self, data_path, prompt='', return_prompt=False):
+         super(TAPE_Stability, self).__init__()
+         self.data_path = data_path
+         self.user_prompt = prompt
+         self.return_prompt = return_prompt
+
+         self.data_list = self._load_and_preprocess(self.data_path)
+         self.text2id = self._build_text_vocab()
+
+     def _load_and_preprocess(self, data_path):
+         data_list = []
+         df = pd.read_csv(data_path)
+         for _, row in df.iterrows():
+             try:
+
+                 prot_seq = str(row['aa_seq']).strip()
+                 result = str(row['label']).strip()
+
+                 text_seq = f"<answer>{result}</answer>\n"
+
+                 prompt = """
+                 【Task】Predict the thermostability score of the given protein sequence, which reflects its ability to maintain proper folding above a concentration threshold.
+                 【Background】Protein stability is an important biophysical property indicating a protein’s resistance to denaturation or unfolding under thermal or chemical stress. In this task, each protein is evaluated by a numerical stability score, where higher values indicate greater ability to remain folded under extreme conditions. This score serves as a proxy for the protein’s intrinsic stability.
+                 【Question】What is the predicted stability score for this sequence?
+                 【Output Format】You must return only the score number, wrapped in <answer></answer> tags.
+                 """
+                 if self.user_prompt:
+                     prompt += self.user_prompt
+
+                 # `extra` could also return the raw feather string, feather_vals,
+                 # or feather_raw
+                 data_list.append((prot_seq, text_seq, prompt))
+             except Exception as e:
+                 print(f"Warning: skipping malformed row: {row}, reason: {e}")
+         return data_list
+
+     def _build_text_vocab(self):
+         text2id = {}
+         for _, text_seq, _ in self.data_list:
+             if text_seq not in text2id:
+                 text2id[text_seq] = len(text2id)
+         return text2id
+
+     def shuffle(self):
+         random.shuffle(self.data_list)
+         return self
+
+     def __len__(self):
+         return len(self.data_list)
+
+     def __getitem__(self, index):
+         prot_seq, text_seq, prompt = self.data_list[index]
+         if self.return_prompt:
+             return prot_seq, prompt, text_seq, index
+         return prot_seq, text_seq, index
+
+ class TAPE_Fluorescence(Dataset):
+     def __init__(self, data_path, prompt='', return_prompt=False):
+         super(TAPE_Fluorescence, self).__init__()
+         self.data_path = data_path
+         self.user_prompt = prompt
+         self.return_prompt = return_prompt
+
+         self.data_list = self._load_and_preprocess(self.data_path)
+         self.text2id = self._build_text_vocab()
+
+     def _load_and_preprocess(self, data_path):
+         data_list = []
+         df = pd.read_csv(data_path)
+         for _, row in df.iterrows():
+             try:
+
+                 prot_seq = str(row['aa_seq']).strip()
+                 result = str(row['label']).strip()
+
+                 text_seq = f"<answer>{result}</answer>\n"
+
+                 prompt = """
+                 【Task】Predict the log fluorescence intensity of the given protein sequence.
+                 【Output Format】You must return only the numerical value, wrapped in <answer></answer> tags.
+                 """
+                 # 【Background】Fluorescence intensity reflects how strongly a protein emits light when excited by a specific wavelength. It is commonly measured in protein variants such as GFP (Green Fluorescent Protein) mutants. The log-transformed fluorescence value quantifies the brightness on a logarithmic scale. Mutations in the sequence can increase or decrease fluorescence intensity.
+                 # 【Question】What is the predicted log fluorescence intensity for this sequence?
+                 if self.user_prompt:
+                     prompt += self.user_prompt
+
+                 # `extra` could also return the raw feather string, feather_vals,
+                 # or feather_raw
+                 data_list.append((prot_seq, text_seq, prompt))
+             except Exception as e:
+                 print(f"Warning: skipping malformed row: {row}, reason: {e}")
+         return data_list
+
+     def _build_text_vocab(self):
+         text2id = {}
+         for _, text_seq, _ in self.data_list:
+             if text_seq not in text2id:
+                 text2id[text_seq] = len(text2id)
+         return text2id
+
+     def shuffle(self):
+         random.shuffle(self.data_list)
+         return self
+
+     def __len__(self):
+         return len(self.data_list)
+
+     def __getitem__(self, index):
+         prot_seq, text_seq, prompt = self.data_list[index]
+         if self.return_prompt:
+             return prot_seq, prompt, text_seq, index
+         return prot_seq, text_seq, index
data_provider/production.py ADDED
@@ -0,0 +1,237 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import json
3
+ from pytorch_lightning import LightningDataModule
4
+ from torch.utils.data import DataLoader, ConcatDataset, Dataset
5
+ from data_provider.stage1_dm import SwissProtDataset, OntoProteinDataset
6
+ import pandas as pd
7
+ class Antibiotic_Resistance(Dataset):
8
+ def __init__(self, data_path, prompt='', return_prompt=False):
9
+ super(Antibiotic_Resistance, self).__init__()
10
+ self.data_path = data_path
11
+ self.user_prompt = prompt
12
+ self.return_prompt = return_prompt
13
+
14
+ self.data_list = self._load_and_preprocess(self.data_path)
15
+ self.text2id = self._build_text_vocab()
16
+
17
+ def _load_and_preprocess(self, data_path):
18
+ data_list = []
19
+ df = pd.read_csv(data_path)
20
+ for _, row in df.iterrows():
21
+ try:
22
+ prot_seq = str(row['aa_seq']).strip()
23
+ result = str(row['label']).strip()
24
+ text_seq = f"<answer>{result}</answer>\n"
25
+
26
+ prompt = f"""
27
+ 【Task】Predict the antibiotic resistance class of the given protein.
28
+ 【Background】Antibiotic resistance refers to the ability of bacteria or other microbes to resist the effects of antibiotics that were once effective against them. Each protein is associated with resistance to exactly one of 19 antibiotic classes.
29
+
30
+ 【Prediction Goal】Based on the provided protein sequence, determine which single antibiotic class (from 1 to 19) this protein confers resistance to.
31
+
32
+ 【Output Format】Return only one predicted resistance class (a number from 1 to 19), wrapped in <answer> </answer> tags.
33
+ Example: <answer>7</answer>
34
+ """
35
+ if self.user_prompt:
36
+ prompt += self.user_prompt
37
+
38
+ # extra可以返回原始feather字符串,也可以返回feather_vals
39
+ # 或 feather_raw
40
+ data_list.append((prot_seq, text_seq, prompt))
41
+ except Exception as e:
42
+ print(f"警告: 跳过有问题的行: {row},原因: {e}")
43
+ return data_list
44
+
45
+ def _build_text_vocab(self):
46
+ text2id = {}
47
+ for _, text_seq, _ in self.data_list:
48
+ if text_seq not in text2id:
49
+ text2id[text_seq] = len(text2id)
50
+ return text2id
51
+
52
+ def shuffle(self):
53
+ random.shuffle(self.data_list)
54
+ return self
55
+
56
+ def __len__(self):
57
+ return len(self.data_list)
58
+
59
+ def __getitem__(self, index):
60
+ prot_seq, text_seq, prompt = self.data_list[index]
61
+ if self.return_prompt:
62
+ return prot_seq, prompt, text_seq,index
63
+ return prot_seq, text_seq, index
64
+
65
+ class Thermostability(Dataset):
66
+ def __init__(self, data_path, prompt='', return_prompt=False):
67
+ super(Thermostability, self).__init__()
68
+ self.data_path = data_path
69
+ self.user_prompt = prompt
70
+ self.return_prompt = return_prompt
71
+
72
+ self.data_list = self._load_and_preprocess(self.data_path)
73
+ self.text2id = self._build_text_vocab()
74
+
75
+ def _load_and_preprocess(self, data_path):
76
+ data_list = []
77
+ df = pd.read_csv(data_path)
78
+ for _, row in df.iterrows():
79
+ try:
80
+ #name = str(row['name']).strip()
81
+ prot_seq = str(row['aa_seq']).strip()
82
+ result = str(row['label']).strip()
83
+ text_seq = f"<answer>{result}</answer>\n"
84
+
85
+ prompt = f"""
86
+ 【Task】Predict the thermostability value of the given protein.
87
+ 【Background】Thermostability refers to the ability of a molecule to resist irreversible chemical or physical changes at high temperatures, such as decomposition or aggregation.
88
+                 【Output Format】Provide the predicted thermostability as a numeric value (e.g., melting temperature in °C). Wrap your answer in <answer></answer> tags.
+
+                 """
+                 if self.user_prompt:
+                     prompt += self.user_prompt
+
+                 # extra could return the raw feather string, or feather_vals
+                 # or feather_raw
+                 data_list.append((prot_seq, text_seq, prompt))
+             except Exception as e:
+                 print(f"Warning: skipping malformed row: {row}, reason: {e}")
+         return data_list
+
+     def _build_text_vocab(self):
+         text2id = {}
+         for _, text_seq, _ in self.data_list:
+             if text_seq not in text2id:
+                 text2id[text_seq] = len(text2id)
+         return text2id
+
+     def shuffle(self):
+         random.shuffle(self.data_list)
+         return self
+
+     def __len__(self):
+         return len(self.data_list)
+
+     def __getitem__(self, index):
+         prot_seq, text_seq, prompt = self.data_list[index]
+         if self.return_prompt:
+             return prot_seq, prompt, text_seq, index
+         return prot_seq, text_seq, index
+
+
+ class Material(Dataset):
+     def __init__(self, data_path, prompt='', return_prompt=False):
+         super(Material, self).__init__()
+         self.data_path = data_path
+         self.user_prompt = prompt
+         self.return_prompt = return_prompt
+
+         self.data_list = self._load_and_preprocess(self.data_path)
+         self.text2id = self._build_text_vocab()
+
+     def _load_and_preprocess(self, data_path):
+         data_list = []
+         df = pd.read_csv(data_path)
+         for _, row in df.iterrows():
+             try:
+                 # name = str(row['name']).strip()
+                 prot_seq = str(row['aa_seq']).strip()
+                 result = str(row['label']).strip()
+                 text_seq = f"<answer>{result}</answer>\n"
+
+                 prompt = f"""
+                 【Task】Determine whether the given material can be successfully produced.
+                 【Background】In materials science, certain chemical compounds or materials may or may not be synthesizable (i.e., producible) under realistic experimental conditions. This task requires classifying whether the input material composition and structure allow for successful production. This is a binary classification problem.
+                 【Question】Can this material be successfully produced?
+                 【Output Format】Respond with either "1" or "0", and wrap your answer in <answer></answer> tags.
+
+                 """
+                 if self.user_prompt:
+                     prompt += self.user_prompt
+
+                 # extra could return the raw feather string, or feather_vals
+                 # or feather_raw
+                 data_list.append((prot_seq, text_seq, prompt))
+             except Exception as e:
+                 print(f"Warning: skipping malformed row: {row}, reason: {e}")
+         return data_list
+
+     def _build_text_vocab(self):
+         text2id = {}
+         for _, text_seq, _ in self.data_list:
+             if text_seq not in text2id:
+                 text2id[text_seq] = len(text2id)
+         return text2id
+
+     def shuffle(self):
+         random.shuffle(self.data_list)
+         return self
+
+     def __len__(self):
+         return len(self.data_list)
+
+     def __getitem__(self, index):
+         prot_seq, text_seq, prompt = self.data_list[index]
+         if self.return_prompt:
+             return prot_seq, prompt, text_seq, index
+         return prot_seq, text_seq, index
+
+
+ class Clone(Dataset):
+     def __init__(self, data_path, prompt='', return_prompt=False):
+         super(Clone, self).__init__()
+         self.data_path = data_path
+         self.user_prompt = prompt
+         self.return_prompt = return_prompt
+
+         self.data_list = self._load_and_preprocess(self.data_path)
+         self.text2id = self._build_text_vocab()
+
+     def _load_and_preprocess(self, data_path):
+         data_list = []
+         df = pd.read_csv(data_path)
+         for _, row in df.iterrows():
+             try:
+                 # name = str(row['name']).strip()
+                 prot_seq = str(row['aa_seq']).strip()
+                 result = str(row['label']).strip()
+                 text_seq = f"<answer>{result}</answer>\n"
+
+                 prompt = f"""
+
+                 【Task】Determine whether the given protein sequence can be successfully cloned.
+                 【Background】In molecular biology, cloning refers to the process of creating copies of a DNA or protein sequence. Some sequences can be challenging to clone due to their length, GC-content, secondary structures, or toxicity to the host. This task requires predicting whether the given protein sequence is likely to be successfully cloned. This is a binary classification problem.
+                 【Question】Can this protein sequence be successfully cloned?
+                 【Output Format】Respond with either "1" or "0", and wrap your answer in <answer></answer> tags.
+                 """
+                 if self.user_prompt:
+                     prompt += self.user_prompt
+
+                 # extra could return the raw feather string, or feather_vals
+                 # or feather_raw
+                 data_list.append((prot_seq, text_seq, prompt))
+             except Exception as e:
+                 print(f"Warning: skipping malformed row: {row}, reason: {e}")
+         return data_list
+
+     def _build_text_vocab(self):
+         text2id = {}
+         for _, text_seq, _ in self.data_list:
+             if text_seq not in text2id:
+                 text2id[text_seq] = len(text2id)
+         return text2id
+
+     def shuffle(self):
+         random.shuffle(self.data_list)
+         return self
+
+     def __len__(self):
+         return len(self.data_list)
+
+     def __getitem__(self, index):
+         prot_seq, text_seq, prompt = self.data_list[index]
+         if self.return_prompt:
+             return prot_seq, prompt, text_seq, index
+         return prot_seq, text_seq, index
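
Usage note: Thermostability, Material, and Clone share the same CSV-backed interface, so one smoke test covers all three. A minimal sketch, assuming a CSV with `aa_seq` and `label` columns (the path below is illustrative, not a real file in this commit):

    from data_provider.production import Clone

    ds = Clone('path/to/cloning_clf/train.csv', prompt='', return_prompt=True)  # hypothetical path
    prot_seq, prompt, text_seq, idx = ds[0]   # return_prompt=True yields a 4-tuple
    print(len(ds), repr(text_seq))            # text_seq looks like '<answer>1</answer>\n'
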
data_provider/prot_qa_dm.py ADDED
@@ -0,0 +1,299 @@
+ # Copyright (c) Microsoft Corporation.
+ # Licensed under the MIT License.
+ import json
+ from pytorch_lightning import LightningDataModule
+ # import torch_geometric
+ from torch.utils.data import DataLoader, Dataset
+ from pathlib import Path
+
+
+ class ProtQACollater(object):
+     def __init__(self, tokenizer, prot_tokenizer, q_max_len, a_max_len, prot_max_len):
+         self.tokenizer = tokenizer
+         self.prot_tokenizer = prot_tokenizer
+         self.q_max_len = q_max_len
+         self.a_max_len = a_max_len
+         self.prot_max_len = prot_max_len
+
+     def __call__(self, batch):
+         prot_seqs, questions, answers, _, _ = zip(*batch)
+         answers = [a + '\n' for a in answers]
+         prot_batch = self.prot_tokenizer(prot_seqs,
+                                          truncation=True,
+                                          padding='max_length',
+                                          max_length=self.prot_max_len,
+                                          return_tensors="pt",
+                                          return_attention_mask=True,
+                                          return_token_type_ids=False)
+         if False:  # disabled branch: tokenize questions (left-padded) and answers (right-padded) separately
+             self.tokenizer.padding_side = 'left'
+             q_batch = self.tokenizer(questions,
+                                      truncation=True,
+                                      padding='max_length',
+                                      add_special_tokens=True,
+                                      max_length=self.q_max_len,
+                                      return_tensors='pt',
+                                      return_attention_mask=True,
+                                      return_token_type_ids=False)
+             self.tokenizer.padding_side = 'right'
+             a_batch = self.tokenizer(answers,
+                                      truncation=True,
+                                      padding='max_length',
+                                      add_special_tokens=True,
+                                      max_length=self.a_max_len,
+                                      return_tensors='pt',
+                                      return_attention_mask=True,
+                                      return_token_type_ids=False)
+             return prot_batch, q_batch, a_batch
+         else:  # pack each question/answer pair into one right-padded sequence
+             self.tokenizer.padding_side = 'right'
+             qa_pair = [[q, a] for q, a in zip(questions, answers)]
+             qa_batch = self.tokenizer(qa_pair,
+                                       truncation=True,
+                                       padding='max_length',
+                                       add_special_tokens=True,
+                                       max_length=self.q_max_len + self.a_max_len,
+                                       return_tensors='pt',
+                                       return_attention_mask=True,
+                                       return_token_type_ids=True)
+             return prot_batch, qa_batch
+
+
+ class InferenceCollater(object):
+     def __init__(self, tokenizer, prot_tokenizer, q_max_len, a_max_len, prot_max_len):
+         self.tokenizer = tokenizer
+         self.prot_tokenizer = prot_tokenizer
+         self.q_max_len = q_max_len
+         self.a_max_len = a_max_len
+         self.prot_max_len = prot_max_len
+
+     def __call__(self, batch):
+         prot_seqs, questions, answers, q_types, indices = zip(*batch)
+         answers = [a + '\n' for a in answers]
+         prot_batch = self.prot_tokenizer(prot_seqs,
+                                          truncation=True,
+                                          padding='max_length',
+                                          max_length=self.prot_max_len,
+                                          return_tensors="pt",
+                                          return_attention_mask=True,
+                                          return_token_type_ids=False)
+         self.tokenizer.padding_side = 'left'
+         q_batch = self.tokenizer(questions,
+                                  truncation=True,
+                                  padding='max_length',
+                                  add_special_tokens=True,
+                                  max_length=self.q_max_len,
+                                  return_tensors='pt',
+                                  return_attention_mask=True,
+                                  return_token_type_ids=False)
+         target_dict = {'targets': answers, 'q_types': q_types, 'indices': indices}
+         return prot_batch, q_batch, target_dict
+
+
+ class ProtQADM(LightningDataModule):
+     def __init__(
+         self,
+         root: str = 'data/',
+         args=None,
+     ):
+         super().__init__()
+         self.args = args
+         self.batch_size = args.batch_size
+         self.inference_batch_size = args.inference_batch_size
+         self.num_workers = args.num_workers
+         self.q_max_len = args.q_max_len
+         self.a_max_len = args.a_max_len
+         self.prot_max_len = args.prot_max_len
+         self.prompt = args.prompt
+
+         self.train_dataset = PDBQADataset(root, 'train.txt', prompt=self.prompt, filter_side_qa=args.filter_side_qa)
+         self.val_dataset = PDBQADataset(root, 'val.txt', prompt=self.prompt, filter_side_qa=args.filter_side_qa)
+         self.test_dataset = PDBQADataset(root, 'test.txt', prompt=self.prompt, filter_side_qa=args.filter_side_qa)
+
+         self.tokenizer = None
+         self.prot_tokenizer = None
+
+     def init_tokenizer(self, tokenizer, prot_tokenizer):
+         self.tokenizer = tokenizer
+         self.prot_tokenizer = prot_tokenizer
+
+     def train_dataloader(self):
+         loader = DataLoader(
+             self.train_dataset,
+             batch_size=self.batch_size,
+             shuffle=True,
+             num_workers=self.num_workers,
+             pin_memory=False,
+             drop_last=True,
+             persistent_workers=False,
+             collate_fn=ProtQACollater(self.tokenizer, self.prot_tokenizer, self.q_max_len, self.a_max_len, self.prot_max_len),
+         )
+         return loader
+
+     def val_dataloader(self):
+         val_loader = DataLoader(
+             self.val_dataset,
+             batch_size=self.batch_size,
+             shuffle=False,
+             num_workers=self.num_workers,
+             pin_memory=False,
+             drop_last=False,
+             persistent_workers=False,
+             collate_fn=ProtQACollater(self.tokenizer, self.prot_tokenizer, self.q_max_len, self.a_max_len, self.prot_max_len),
+         )
+         test_loader = DataLoader(
+             self.test_dataset,
+             batch_size=self.inference_batch_size,
+             shuffle=False,
+             num_workers=self.num_workers,
+             pin_memory=False,
+             drop_last=False,
+             persistent_workers=False,
+             collate_fn=InferenceCollater(self.tokenizer, self.prot_tokenizer, self.q_max_len, self.a_max_len, self.prot_max_len),
+         )
+         return [val_loader, test_loader]
+
+     def add_model_specific_args(parent_parser):
+         parser = parent_parser.add_argument_group("Data module")
+         parser.add_argument('--num_workers', type=int, default=2)
+         parser.add_argument('--batch_size', type=int, default=32)
+         parser.add_argument('--inference_batch_size', type=int, default=4)
+         parser.add_argument('--root', type=str, default='data/SwissProtV3')
+         parser.add_argument('--text_max_len', type=int, default=128)
+         parser.add_argument('--q_max_len', type=int, default=34)
+         parser.add_argument('--a_max_len', type=int, default=36)
+         parser.add_argument('--prot_max_len', type=int, default=1024)
+         parser.add_argument('--prompt', type=str, default='The protein has the following properties: ')
+         parser.add_argument('--filter_side_qa', action='store_true', default=False)
+         return parent_parser
+
+
+ class PDBQADataset(Dataset):
+     def __init__(self, root_path, subset, prompt="Question: {} Answer:", filter_side_qa=False):
+         super(PDBQADataset, self).__init__()
+         self.data_path = Path(root_path) / subset
+         self.qa_path = Path(root_path) / 'qa_all.json'
+         self.q_type_path = Path(root_path) / 'q_types.txt'
+         self.prompt = prompt
+
+         ## load dataset
+         with open(self.qa_path, 'r') as f:
+             qa_data = json.load(f)
+
+         with open(self.data_path, 'r') as f:
+             lines = f.readlines()
+             pdb2seq = [line.strip().split('\t') for line in lines]
+
+         ## load q types
+         with open(self.q_type_path, 'r') as f:
+             q_types = [line.strip().split('\t') for line in f.readlines()]
+             self.q_type_dict = {q: t for q, t in q_types}
+
+         ## process dataset
+         pdb_set = set(i[0] for i in pdb2seq)
+         ## filter qa data
+         qa_data = {k: v for k, v in qa_data.items() if k in pdb_set}
+         assert len(qa_data) == len(pdb_set), f'{len(qa_data)} != {len(pdb_set)}'
+
+         ## generate qa data
+         self.data_list = []
+         for pdb_id, seq in pdb2seq:
+             qa_list = qa_data[pdb_id]
+             for qa in qa_list:
+                 q = qa['Q']
+                 a = str(qa['A'])
+                 if filter_side_qa:
+                     q_type = self.q_type_dict[q]
+                     if q_type.find('side information') >= 0:
+                         continue
+                 self.data_list.append((seq, q, a))
+
+     def __len__(self):
+         return len(self.data_list)
+
+     def __getitem__(self, index):
+         seq, q, a = self.data_list[index]
+         q_type = self.q_type_dict[q]
+         q = self.prompt.format(q)
+         return seq, q, a, q_type, index
+
+
+ if __name__ == '__main__':
+     import numpy as np
+     from collections import defaultdict, Counter
+     train_dataset = PDBQADataset('../data/PDBDataset', 'train.txt', filter_side_qa=True)
+     val_dataset = PDBQADataset('../data/PDBDataset', 'val.txt', filter_side_qa=True)
+     test_dataset = PDBQADataset('../data/PDBDataset', 'test.txt', filter_side_qa=True)
+     if True:  # print question/answer word-length statistics per split
+         # print(len(train_dataset), len(val_dataset), len(test_dataset))
+         # train_protein_lens = np.asarray([len(p) for p in train_dataset.protein_list])
+         # val_protein_lens = np.asarray([len(p) for p in val_dataset.protein_list])
+         # test_protein_lens = np.asarray([len(p) for p in test_dataset.protein_list])
+
+         q_lens = []
+         a_lens = []
+         for seq, q, a in train_dataset.data_list:
+             q_lens.append(len(q.split()))
+             a_lens.append(len(a.split()))
+         print(np.asarray(q_lens).min(), np.asarray(q_lens).max(), np.asarray(q_lens).mean())
+         print(np.asarray(a_lens).min(), np.asarray(a_lens).max(), np.asarray(a_lens).mean())
+
+         q_lens = []
+         a_lens = []
+         for seq, q, a in val_dataset.data_list:
+             q_lens.append(len(q.split()))
+             a_lens.append(len(a.split()))
+         print(np.asarray(q_lens).min(), np.asarray(q_lens).max(), np.asarray(q_lens).mean())
+         print(np.asarray(a_lens).min(), np.asarray(a_lens).max(), np.asarray(a_lens).mean())
+
+         q_lens = []
+         a_lens = []
+         for seq, q, a in test_dataset.data_list:
+             q_lens.append(len(q.split()))
+             a_lens.append(len(a.split()))
+         print(np.asarray(q_lens).min(), np.asarray(q_lens).max(), np.asarray(q_lens).mean())
+         print(np.asarray(a_lens).min(), np.asarray(a_lens).max(), np.asarray(a_lens).mean())
+     elif False:  # disabled: majority-vote baseline that answers each question with its most common training answer
+         ## construct the guess for prediction by number
+         train_counter = defaultdict(Counter)
+         for _, q, a in train_dataset.data_list:
+             train_counter[q.lower()][a] += 1
+         ## get the most common answer
+         q2a = {}
+         for q, counter in train_counter.items():
+             q2a[q] = counter.most_common(1)[0][0]
+
+         ## test the guess
+         acc = 0
+         for _, q, a in test_dataset.data_list:
+             if q.lower() in q2a:
+                 predict = q2a[q.lower()]
+                 if predict.lower() == a.lower():
+                     acc += 1
+         print(acc / len(test_dataset.data_list))
+     elif False:  # disabled: smoke-test the training collater with real tokenizers
+         from transformers import AutoTokenizer, EsmTokenizer
+         llm_tokenizer = AutoTokenizer.from_pretrained('facebook/galactica-1.3b', use_fast=False, padding_side='right')
+         plm_tokenizer = EsmTokenizer.from_pretrained('facebook/esm2_t30_150M_UR50D')
+         llm_tokenizer.add_special_tokens({'pad_token': '<pad>'})
+         loader = DataLoader(
+             train_dataset,
+             batch_size=32,
+             shuffle=True,
+             num_workers=4,
+             pin_memory=False,
+             drop_last=True,
+             persistent_workers=False,
+             collate_fn=ProtQACollater(llm_tokenizer, plm_tokenizer, 40, 40, 1024),
+         )
+     else:
+         print(len(train_dataset.data_list))
+         print(len(val_dataset.data_list))
+         print(len(test_dataset.data_list))
+
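
Usage note: a minimal wiring sketch for ProtQADM (a sketch only, not part of the diff; it reuses the tokenizer checkpoints from the disabled smoke test above and assumes the PDBDataset split files shipped in this commit exist under data/PDBDataset):

    from argparse import ArgumentParser
    from transformers import AutoTokenizer, EsmTokenizer
    from data_provider.prot_qa_dm import ProtQADM

    parser = ArgumentParser()
    parser = ProtQADM.add_model_specific_args(parser)
    args = parser.parse_args(['--root', 'data/PDBDataset'])

    dm = ProtQADM(root=args.root, args=args)
    llm_tokenizer = AutoTokenizer.from_pretrained('facebook/galactica-1.3b', use_fast=False)
    llm_tokenizer.add_special_tokens({'pad_token': '<pad>'})
    prot_tokenizer = EsmTokenizer.from_pretrained('facebook/esm2_t30_150M_UR50D')
    dm.init_tokenizer(llm_tokenizer, prot_tokenizer)

    prot_batch, qa_batch = next(iter(dm.train_dataloader()))  # packed question+answer batch
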
data_provider/proteinchat_dm.py ADDED
@@ -0,0 +1,254 @@
+ # Copyright (c) Microsoft Corporation.
+ # Licensed under the MIT License.
+ import json
+ import random
+ import torch
+ import os
+ import numpy as np
+ from pytorch_lightning import LightningDataModule
+ from torch.utils.data import DataLoader, Dataset
+ from data_provider.gal_helpers import escape_custom_split_sequence
+ from pathlib import Path
+ from torch.utils.data.dataloader import default_collate
+
+
+ class ProteinChatCollater(object):
+     def __init__(self, tokenizer, q_max_len, a_max_len, use_gal):
+         self.tokenizer = tokenizer
+         self.q_max_len = q_max_len
+         self.a_max_len = a_max_len
+         self.use_gal = use_gal
+
+     def __call__(self, batch):
+         embeds, prot_seqs, questions, answers, q_types = zip(*batch)
+         max_embed_len = 896
+         ## concatenate the per-protein embeddings into one padded batch tensor
+         if False:  # disabled: numpy padding to the longest embedding in the batch
+             max_dim = max([e.shape[0] for e in embeds])
+             padded_embeds = []
+             for embed in embeds:
+                 shape_dim0 = embed.shape[0]
+                 pad1 = ((0, max_dim - shape_dim0), (0, 0), (0, 0))
+                 padded_embeds.append(np.pad(embed, pad1, mode='constant'))
+             padded_embeds = default_collate(padded_embeds).squeeze(dim=2)[:, :1024, :]
+         else:  # zero-pad/clip every embedding to a fixed length of 896
+             padded_embeds = torch.zeros(len(embeds), max_embed_len, 512)
+             for i in range(len(embeds)):
+                 padded_embeds[i, :embeds[i].shape[0], :] = embeds[i][:max_embed_len, :]
+             padded_embeds = padded_embeds.detach()
+
+         assert len(prot_seqs) == len(questions) == len(answers)
+
+         if self.use_gal:
+             questions = [escape_custom_split_sequence(q) for q in questions]
+         answers = [a + '\n' for a in answers]
+         self.tokenizer.padding_side = 'left'
+         q_batch = self.tokenizer(questions,
+                                  truncation=True,
+                                  padding='max_length',
+                                  add_special_tokens=True,
+                                  max_length=self.q_max_len,
+                                  return_tensors='pt',
+                                  return_attention_mask=True,
+                                  return_token_type_ids=False)
+         self.tokenizer.padding_side = 'right'
+         a_batch = self.tokenizer(answers,
+                                  truncation=True,
+                                  padding='max_length',
+                                  add_special_tokens=True,
+                                  max_length=self.a_max_len,
+                                  return_tensors='pt',
+                                  return_attention_mask=True,
+                                  return_token_type_ids=False)
+         prot_mask = torch.ones(padded_embeds.shape[0], padded_embeds.shape[1], dtype=torch.bool)
+         return (padded_embeds, prot_mask), q_batch, a_batch
+
+
+ class InferenceCollater(object):
+     def __init__(self, tokenizer, q_max_len, a_max_len, use_gal):
+         self.tokenizer = tokenizer
+         self.q_max_len = q_max_len
+         self.a_max_len = a_max_len
+         self.use_gal = use_gal
+
+     def __call__(self, batch):
+         embeds, prot_seqs, questions, answers, q_types = zip(*batch)
+         max_embed_len = 896
+         ## concatenate the per-protein embeddings into one padded batch tensor
+         if False:  # disabled: numpy padding to the longest embedding in the batch
+             max_dim = max([e.shape[0] for e in embeds])
+             padded_embeds = []
+             for embed in embeds:
+                 shape_dim0 = embed.shape[0]
+                 pad1 = ((0, max_dim - shape_dim0), (0, 0), (0, 0))
+                 padded_embeds.append(np.pad(embed, pad1, mode='constant'))
+             padded_embeds = default_collate(padded_embeds).squeeze(dim=2)[:, :1024, :]
+         else:  # zero-pad/clip every embedding to a fixed length of 896
+             padded_embeds = torch.zeros(len(embeds), max_embed_len, 512)
+             for i in range(len(embeds)):
+                 padded_embeds[i, :embeds[i].shape[0], :] = embeds[i][:max_embed_len, :]
+             padded_embeds = padded_embeds.detach()
+
+         assert len(prot_seqs) == len(questions) == len(answers)
+
+         if self.use_gal:
+             questions = [escape_custom_split_sequence(q) for q in questions]
+         answers = [a + '\n' for a in answers]
+         self.tokenizer.padding_side = 'left'
+         q_batch = self.tokenizer(questions,
+                                  truncation=True,
+                                  padding='max_length',
+                                  add_special_tokens=True,
+                                  max_length=self.q_max_len,
+                                  return_tensors='pt',
+                                  return_attention_mask=True,
+                                  return_token_type_ids=False)
+         prot_mask = torch.ones(padded_embeds.shape[0], padded_embeds.shape[1], dtype=torch.bool)
+         target_dict = {'answers': answers, "q_types": q_types}
+         return (padded_embeds, prot_mask), q_batch, target_dict
+
+
+ class ProteinChatDM(LightningDataModule):
+     def __init__(
+         self,
+         root: str = 'data/',
+         args=None,
+     ):
+         super().__init__()
+         self.args = args
+         self.batch_size = args.batch_size
+         self.inference_batch_size = args.inference_batch_size
+         self.num_workers = args.num_workers
+         self.q_max_len = args.q_max_len
+         self.a_max_len = args.a_max_len
+         self.prompt = args.prompt
+
+         self.train_dataset = ProteinChatDataset(root, 'train.txt', prompt="### Human: {}\n### Assistant: ", pt_file_path=args.pt_file_path)
+         self.val_dataset = ProteinChatDataset(root, 'val.txt', prompt="### Human: {}\n### Assistant: ", pt_file_path=args.pt_file_path)
+         self.test_dataset = ProteinChatDataset(root, 'test.txt', prompt="### Human: {}\n### Assistant: ", pt_file_path=args.pt_file_path)
+
+         self.tokenizer = None
+         self.use_gal = args.llm_name.find('gal') >= 0
+
+     def init_tokenizer(self, tokenizer):
+         self.tokenizer = tokenizer
+
+     def train_dataloader(self):
+         loader = DataLoader(
+             self.train_dataset,
+             batch_size=self.batch_size,
+             shuffle=True,
+             num_workers=self.num_workers,
+             pin_memory=False,
+             drop_last=True,
+             persistent_workers=False,
+             collate_fn=ProteinChatCollater(self.tokenizer, self.q_max_len, self.a_max_len, self.use_gal),
+         )
+         return loader
+
+     def val_dataloader(self):
+         val_loader = DataLoader(
+             self.val_dataset,
+             batch_size=self.batch_size,
+             shuffle=False,
+             num_workers=self.num_workers,
+             pin_memory=False,
+             drop_last=False,
+             persistent_workers=False,
+             collate_fn=ProteinChatCollater(self.tokenizer, self.q_max_len, self.a_max_len, self.use_gal),
+         )
+         test_loader = DataLoader(
+             self.test_dataset,
+             batch_size=self.inference_batch_size,
+             shuffle=False,
+             num_workers=self.num_workers,
+             pin_memory=False,
+             drop_last=False,
+             persistent_workers=False,
+             collate_fn=InferenceCollater(self.tokenizer, self.q_max_len, self.a_max_len, self.use_gal),
+         )
+         return [val_loader, test_loader]
+
+     def add_model_specific_args(parent_parser):
+         parser = parent_parser.add_argument_group("Data module")
+         parser.add_argument('--num_workers', type=int, default=2)
+         parser.add_argument('--batch_size', type=int, default=32)
+         parser.add_argument('--inference_batch_size', type=int, default=4)
+         parser.add_argument('--root', type=str, default='data/SwissProtV3')
+         parser.add_argument('--q_max_len', type=int, default=30)
+         parser.add_argument('--a_max_len', type=int, default=36)
+         parser.add_argument('--prompt', type=str, default='[START_AMINO]{}[END_AMINO]. Question: {} Answer:')
+         parser.add_argument('--pt_file_path', type=str, default='/home/XXXX-2/proteinchatdata/proteinchat')
+         return parent_parser
+
+
+ class ProteinChatDataset(Dataset):
+     def __init__(self, root_path, subset, pt_file_path, prompt="Question: {} Answer:"):
+         super(ProteinChatDataset, self).__init__()
+         self.data_path = Path(root_path) / subset
+         self.qa_path = Path(root_path) / 'qa_all.json'
+         self.q_type_path = Path(root_path) / 'q_types.txt'
+         self.prompt = prompt
+
+         ## load dataset
+         with open(self.qa_path, 'r') as f:
+             qa_data = json.load(f)
+
+         with open(self.data_path, 'r') as f:
+             lines = f.readlines()
+             pdb2seq = [line.strip().split('\t') for line in lines]
+
+         ## process dataset
+         pdb_set = set(i[0] for i in pdb2seq)
+         ## filter qa data
+         qa_data = {k: v for k, v in qa_data.items() if k in pdb_set}
+         assert len(qa_data) == len(pdb_set), f'{len(qa_data)} != {len(pdb_set)}'
+
+         pt_file = Path(pt_file_path).glob('*.pt')
+         pt_file_ids = {f.name.split('.pt')[0] for f in pt_file}
+         self.pt_file_path = pt_file_path
+
+         ## load q types
+         with open(self.q_type_path, 'r') as f:
+             q_types = [line.strip().split('\t') for line in f.readlines()]
+             self.q_type_dict = {q: t for q, t in q_types}
+
+         ## generate qa data
+         self.data_list = []
+         for pdb_id, seq in pdb2seq:
+             if pdb_id not in pt_file_ids:
+                 continue
+             qa_list = qa_data[pdb_id]
+             for qa in qa_list:
+                 q = qa['Q']
+                 a = str(qa['A'])
+                 self.data_list.append((pdb_id, seq, q, a))
+
+     def shuffle(self):
+         random.shuffle(self.data_list)
+         return self
+
+     def __len__(self):
+         return len(self.data_list)
+
+     def __getitem__(self, index):
+         pdb_id, seq, q, a = self.data_list[index]
+         q_type = self.q_type_dict[q]
+         path = os.path.join(self.pt_file_path, pdb_id + '.pt')
+         embed = torch.load(path, map_location=torch.device('cpu'))
+         embed = embed.squeeze(dim=1)
+         embed = embed.detach()
+         q = self.prompt.format(q)
+         return embed, seq, q, a, q_type
+
+
+ if __name__ == '__main__':
+     # pt_file_path is a required argument; reuse the CLI default so the demo runs
+     dataset = ProteinChatDataset('./data/PDBDataset', 'train.txt',
+                                  pt_file_path='/home/XXXX-2/proteinchatdata/proteinchat')
+     dataset.shuffle()
+     for i in range(1000):
+         print(dataset[i][0].shape)
+
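
Side note on the fixed-length padding in both collaters above: no explicit min() is needed because tensor slices clamp to the valid range. A small self-contained check (the 896/512 shapes follow the collater's constants; the random tensors are made up):

    import torch

    max_embed_len, dim = 896, 512
    embeds = [torch.randn(300, dim), torch.randn(1200, dim)]  # one short, one over-long embedding
    padded = torch.zeros(len(embeds), max_embed_len, dim)
    for i, e in enumerate(embeds):
        padded[i, :e.shape[0], :] = e[:max_embed_len, :]  # both slices clamp to <= 896 rows
    print(padded.shape)  # torch.Size([2, 896, 512])
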
data_provider/stage1_dm.py ADDED
@@ -0,0 +1,539 @@
+ # Copyright (c) Microsoft Corporation.
+ # Licensed under the MIT License.
+ from pytorch_lightning import LightningDataModule
+ import json
+ from torch.utils.data import DataLoader, Dataset, ConcatDataset
+ import random
+ from pathlib import Path
+
+
+ def rand_seq_crop(seq, max_len):
+     if len(seq) <= max_len:
+         return seq
+     rand_pos = random.randint(0, len(seq) - max_len)  # inclusive bound, so the final residue can appear in a crop
+     return seq[rand_pos:rand_pos + max_len]
+
+
+ class Stage1Collater(object):
+     def __init__(self, tokenizer, prot_tokenizer, text_max_len, prot_max_len, prot_aug='None'):
+         self.tokenizer = tokenizer
+         self.prot_tokenizer = prot_tokenizer
+         self.text_max_len = text_max_len
+         self.prot_max_len = prot_max_len
+         self.prot_aug = prot_aug
+
+     def __call__(self, batch):
+         prot_seqs, text_seqs, _ = zip(*batch)
+         if self.prot_aug == 'rand_crop':
+             prot_seqs = [rand_seq_crop(seq, self.prot_max_len - 2) for seq in prot_seqs]  # -2 for the two special tokens
+
+         text_tokens = self.tokenizer(text_seqs,
+                                      truncation=True,
+                                      padding='max_length',
+                                      add_special_tokens=True,
+                                      max_length=self.text_max_len,
+                                      return_tensors='pt',
+                                      return_attention_mask=True,
+                                      return_token_type_ids=False)
+         prot_tokens = self.prot_tokenizer(prot_seqs,
+                                           truncation=True,
+                                           padding='max_length',
+                                           max_length=self.prot_max_len,
+                                           return_tensors="pt",
+                                           return_attention_mask=True,
+                                           return_token_type_ids=False)
+         return prot_tokens, text_tokens
+
+
+ class Stage1DM(LightningDataModule):
+     def __init__(
+         self,
+         num_workers: int = 0,
+         batch_size: int = 256,
+         root: str = 'data/',
+         args=None,
+     ):
+         super().__init__()
+         self.batch_size = batch_size
+         self.match_batch_size = args.match_batch_size
+         self.num_workers = num_workers
+         self.text_max_len = args.text_max_len
+         self.prot_max_len = args.prot_max_len
+         if root.find('SwissProt') >= 0:
+             self.train_dataset = SwissProtDataset(root + '/train_set.jsonl')
+             self.val_dataset = SwissProtDataset(root + '/valid_set.jsonl')
+             self.test_dataset = SwissProtDataset(root + '/test_set.jsonl')
+             self.val_dataset_match = SwissProtDataset(root + '/valid_set.jsonl').shuffle()
+             self.test_dataset_match = SwissProtDataset(root + '/test_set.jsonl').shuffle()
+         elif root.find('PDBDataset') >= 0:
+             self.train_dataset = PDBAbstractDataset(root, 'train.txt')
+             self.val_dataset = PDBAbstractDataset(root, 'val.txt')
+             self.test_dataset = PDBAbstractDataset(root, 'test.txt')
+             self.val_dataset_match = PDBAbstractDataset(root, 'val.txt').shuffle()
+             self.test_dataset_match = PDBAbstractDataset(root, 'test.txt').shuffle()
+         elif root.find('OntoProtein') >= 0:
+             self.train_dataset = OntoProteinDataset(root + '/train.txt')
+             self.val_dataset = OntoProteinDataset(root + '/valid.txt')
+             self.test_dataset = OntoProteinDataset(root + '/test.txt')
+             self.val_dataset_match = OntoProteinDataset(root + '/valid.txt').shuffle()
+             self.test_dataset_match = OntoProteinDataset(root + '/test.txt').shuffle()
+         else:
+             raise NotImplementedError
+
+         self.tokenizer = None
+         self.prot_tokenizer = None
+         self.prot_aug = args.prot_aug
+
+     def init_tokenizer(self, tokenizer, prot_tokenizer):
+         self.tokenizer = tokenizer
+         self.prot_tokenizer = prot_tokenizer
+
+     def train_dataloader(self):
+         loader = DataLoader(
+             self.train_dataset,
+             batch_size=self.batch_size,
+             shuffle=True,
+             num_workers=self.num_workers,
+             pin_memory=False,
+             drop_last=True,
+             # persistent_workers=True,
+             collate_fn=Stage1Collater(self.tokenizer, self.prot_tokenizer, self.text_max_len, self.prot_max_len, self.prot_aug)
+         )
+         return loader
+
+     def val_dataloader(self):
+         loader = DataLoader(
+             self.val_dataset,
+             batch_size=self.batch_size,
+             shuffle=False,
+             num_workers=self.num_workers,
+             pin_memory=False,
+             drop_last=False,
+             # persistent_workers=True,
+             collate_fn=Stage1Collater(self.tokenizer, self.prot_tokenizer, self.text_max_len, self.prot_max_len, self.prot_aug)
+         )
+         return loader
+
+     def match_dataloader(self):
+         val_match_loader = DataLoader(self.val_dataset_match,
+                                       batch_size=self.match_batch_size,
+                                       shuffle=False,
+                                       num_workers=self.num_workers,
+                                       pin_memory=False,
+                                       drop_last=False,
+                                       # persistent_workers=True,
+                                       collate_fn=Stage1Collater(self.tokenizer, self.prot_tokenizer, self.text_max_len, self.prot_max_len, self.prot_aug))
+         test_match_loader = DataLoader(self.test_dataset_match,
+                                        batch_size=self.match_batch_size,
+                                        shuffle=False,
+                                        num_workers=self.num_workers,
+                                        pin_memory=False,
+                                        drop_last=False,
+                                        # persistent_workers=True,
+                                        collate_fn=Stage1Collater(self.tokenizer, self.prot_tokenizer, self.text_max_len, self.prot_max_len, self.prot_aug))
+         return val_match_loader, test_match_loader
+
+     def add_model_specific_args(parent_parser):
+         parser = parent_parser.add_argument_group("Data module")
+         parser.add_argument('--num_workers', type=int, default=4)
+         parser.add_argument('--batch_size', type=int, default=32)
+         parser.add_argument('--match_batch_size', type=int, default=64)
+         parser.add_argument('--root', type=str, default='data/SwissProtV3')
+         parser.add_argument('--text_max_len', type=int, default=128)
+         parser.add_argument('--prot_max_len', type=int, default=1024)
+         parser.add_argument('--prot_aug', type=str, default='None')
+         return parent_parser
+
+
+ class Stage1MixDM(LightningDataModule):
+     def __init__(
+         self,
+         num_workers: int = 0,
+         batch_size: int = 256,
+         root: str = 'data/',
+         args=None,
+     ):
+         super().__init__()
+         self.batch_size = batch_size
+         self.match_batch_size = args.match_batch_size
+         self.num_workers = num_workers
+         self.text_max_len = args.text_max_len
+         self.prot_max_len = args.prot_max_len
+         assert args.mix_dataset
+
+         train_dataset1 = SwissProtDataset(root + '/SwissProtV3/train_set.jsonl')
+         train_dataset2 = OntoProteinDataset(root + '/OntoProteinDatasetV2/train.txt')
+         # Added PDBAbstract training set (subset must be 'train.txt')
+         train_dataset3 = PDBAbstractDataset(root + '/PDBDataset/', subset='train.txt')
+
+         # self.train_dataset = ConcatDataset([train_dataset1, train_dataset2, train_dataset3])
+         self.train_dataset = ConcatDataset([train_dataset1, train_dataset2])
+
+         self.swiss_val_dataset = SwissProtDataset(root + '/SwissProtV3/valid_set.jsonl')
+         self.onto_val_dataset = OntoProteinDataset(root + '/OntoProteinDatasetV2/valid.txt')
+         self.pdb_val_dataset = PDBAbstractDataset(root + '/PDBDataset/', subset='val.txt')
+
+         self.swiss_test_dataset = SwissProtDataset(root + '/SwissProtV3/test_set.jsonl')
+         self.onto_test_dataset = OntoProteinDataset(root + '/OntoProteinDatasetV2/test.txt')
+         self.pdb_test_dataset = PDBAbstractDataset(root + '/PDBDataset/', subset='test.txt')
+
+         self.swiss_val_dataset_match = SwissProtDataset(root + '/SwissProtV3/valid_set.jsonl').shuffle()
+         self.onto_val_dataset_match = OntoProteinDataset(root + '/OntoProteinDatasetV2/valid.txt').shuffle()
+         self.pdb_val_dataset_match = PDBAbstractDataset(root + '/PDBDataset/', subset='val.txt').shuffle()
+
+         self.swiss_test_dataset_match = SwissProtDataset(root + '/SwissProtV3/test_set.jsonl').shuffle()
+         self.onto_test_dataset_match = OntoProteinDataset(root + '/OntoProteinDatasetV2/test.txt').shuffle()
+         self.pdb_test_dataset_match = PDBAbstractDataset(root + '/PDBDataset/', subset='test.txt').shuffle()
+
+         self.tokenizer = None
+         self.prot_tokenizer = None
+         self.prot_aug = args.prot_aug
+
+     def init_tokenizer(self, tokenizer, prot_tokenizer):
+         self.tokenizer = tokenizer
+         self.prot_tokenizer = prot_tokenizer
+
+     def train_dataloader(self):
+         loader = DataLoader(
+             self.train_dataset,
+             batch_size=self.batch_size,
+             shuffle=True,
+             num_workers=self.num_workers,
+             pin_memory=False,
+             drop_last=True,
+             collate_fn=Stage1Collater(self.tokenizer, self.prot_tokenizer, self.text_max_len, self.prot_max_len, self.prot_aug)
+         )
+         return loader
+
+     def val_dataloader(self):
+         loader1 = DataLoader(
+             self.swiss_val_dataset,
+             batch_size=self.batch_size,
+             shuffle=False,
+             num_workers=self.num_workers,
+             pin_memory=False,
+             drop_last=False,
+             collate_fn=Stage1Collater(self.tokenizer, self.prot_tokenizer, self.text_max_len, self.prot_max_len, self.prot_aug)
+         )
+         loader2 = DataLoader(
+             self.onto_val_dataset,
+             batch_size=self.batch_size,
+             shuffle=False,
+             num_workers=self.num_workers,
+             pin_memory=False,
+             drop_last=False,
+             collate_fn=Stage1Collater(self.tokenizer, self.prot_tokenizer, self.text_max_len, self.prot_max_len, self.prot_aug)
+         )
+         # Added PDB validation loader
+         loader3 = DataLoader(
+             self.pdb_val_dataset,
+             batch_size=self.batch_size,
+             shuffle=False,
+             num_workers=self.num_workers,
+             pin_memory=False,
+             drop_last=False,
+             collate_fn=Stage1Collater(self.tokenizer, self.prot_tokenizer, self.text_max_len, self.prot_max_len, self.prot_aug)
+         )
+
+         return [loader1, loader2, loader3]
+
+     def swiss_match_dataloader(self):
+         val_match_loader = DataLoader(self.swiss_val_dataset_match,
+                                       batch_size=self.match_batch_size,
+                                       shuffle=False,
+                                       num_workers=self.num_workers,
+                                       pin_memory=False,
+                                       drop_last=False,
+                                       collate_fn=Stage1Collater(self.tokenizer, self.prot_tokenizer, self.text_max_len, self.prot_max_len, self.prot_aug))
+         test_match_loader = DataLoader(self.swiss_test_dataset_match,
+                                        batch_size=self.match_batch_size,
+                                        shuffle=False,
+                                        num_workers=self.num_workers,
+                                        pin_memory=False,
+                                        drop_last=False,
+                                        collate_fn=Stage1Collater(self.tokenizer, self.prot_tokenizer, self.text_max_len, self.prot_max_len, self.prot_aug))
+         return val_match_loader, test_match_loader
+
+     def onto_match_dataloader(self):
+         val_match_loader = DataLoader(self.onto_val_dataset_match,
+                                       batch_size=self.match_batch_size,
+                                       shuffle=False,
+                                       num_workers=self.num_workers,
+                                       pin_memory=False,
+                                       drop_last=False,
+                                       collate_fn=Stage1Collater(self.tokenizer, self.prot_tokenizer, self.text_max_len, self.prot_max_len, self.prot_aug))
+         test_match_loader = DataLoader(self.onto_test_dataset_match,
+                                        batch_size=self.match_batch_size,
+                                        shuffle=False,
+                                        num_workers=self.num_workers,
+                                        pin_memory=False,
+                                        drop_last=False,
+                                        collate_fn=Stage1Collater(self.tokenizer, self.prot_tokenizer, self.text_max_len, self.prot_max_len, self.prot_aug))
+         return val_match_loader, test_match_loader
+
+     def pdb_match_dataloader(self):
+         val_match_loader = DataLoader(
+             self.pdb_val_dataset_match,
+             batch_size=self.match_batch_size,
+             shuffle=False,
+             num_workers=self.num_workers,
+             pin_memory=False,
+             drop_last=False,
+             collate_fn=Stage1Collater(self.tokenizer, self.prot_tokenizer, self.text_max_len, self.prot_max_len, self.prot_aug)
+         )
+         test_match_loader = DataLoader(
+             self.pdb_test_dataset_match,
+             batch_size=self.match_batch_size,
+             shuffle=False,
+             num_workers=self.num_workers,
+             pin_memory=False,
+             drop_last=False,
+             collate_fn=Stage1Collater(self.tokenizer, self.prot_tokenizer, self.text_max_len, self.prot_max_len, self.prot_aug)
+         )
+         return val_match_loader, test_match_loader
+
+     def add_model_specific_args(parent_parser):
+         parser = parent_parser.add_argument_group("Data module")
+         parser.add_argument('--num_workers', type=int, default=4)
+         parser.add_argument('--batch_size', type=int, default=32)
+         parser.add_argument('--match_batch_size', type=int, default=64)
+         parser.add_argument('--root', type=str, default='data')
+         parser.add_argument('--text_max_len', type=int, default=128)
+         parser.add_argument('--prot_max_len', type=int, default=1024)
+         parser.add_argument('--prot_aug', type=str, default='None')
+         return parent_parser
+
+
+ # class SwissProtDataset(Dataset):
+ #     def __init__(self, data_path, prompt='Swiss-Prot description: ', return_prompt=False):
+ #         super(SwissProtDataset, self).__init__()
+ #         self.data_path = data_path
+ #
+ #         ## load data
+ #         with open(data_path, 'r') as f:
+ #             lines = f.readlines()
+ #             lines = [line.strip() for line in lines]
+ #             self.data_list = [json.loads(line) for line in lines]
+ #
+ #         ## preprocessing
+ #         self.data_list = [(p, t.strip() + '\n') for p, t in self.data_list]
+ #
+ #         self.text2id = {}
+ #         for prot_seq, text_seq in self.data_list:
+ #             if text_seq not in self.text2id:
+ #                 self.text2id[text_seq] = len(self.text2id)
+ #
+ #         self.prompt = prompt
+ #         self.return_prompt = return_prompt
+ #
+ #     def shuffle(self):
+ #         random.shuffle(self.data_list)
+ #         return self
+ #
+ #     def len(self,):
+ #         return len(self)
+ #
+ #     def get(self, idx):
+ #         return self.__getitem__(idx)
+ #
+ #     def __len__(self):
+ #         return len(self.data_list)
+ #
+ #     def __getitem__(self, index):
+ #         prot_seq, text_seq = self.data_list[index]
+ #         if self.return_prompt:
+ #             return prot_seq, self.prompt, text_seq, index
+ #         return prot_seq, text_seq, index
+
+
+ class SwissProtDataset(Dataset):
+     def __init__(self, data_path, prompt='Swiss-Prot description: ', return_prompt=False):
+         super(SwissProtDataset, self).__init__()
+         self.data_path = data_path
+         self.prompt = prompt
+         self.return_prompt = return_prompt
+
+         # load and preprocess the data
+         self.data_list = self._load_and_preprocess(data_path)
+         self.text2id = self._build_text_vocab()
+
+     def _load_and_preprocess(self, data_path):
+         """Load the JSONL file and preprocess each record."""
+         data_list = []
+         with open(data_path, 'r', encoding='utf-8') as f:
+             for line in f:
+                 try:
+                     item = json.loads(line.strip())
+                     # make sure the required fields are present
+                     if 'protein' in item and 'text' in item:
+                         prot_seq = item['protein']
+                         text_seq = item['text'].strip() + '\n'  # append a trailing newline
+                         data_list.append((prot_seq, text_seq))
+                     else:
+                         print(f"Warning: skipping line with missing fields: {line[:50]}...")
+                 except json.JSONDecodeError:
+                     print(f"Warning: skipping invalid JSON line: {line[:50]}...")
+         return data_list
+
+     def _build_text_vocab(self):
+         """Build the text-to-ID mapping."""
+         text2id = {}
+         for _, text_seq in self.data_list:
+             if text_seq not in text2id:
+                 text2id[text_seq] = len(text2id)
+         return text2id
+
+     def shuffle(self):
+         """Shuffle the dataset in place."""
+         random.shuffle(self.data_list)
+         return self
+
+     def __len__(self):
+         return len(self.data_list)
+
+     def __getitem__(self, index):
+         prot_seq, text_seq = self.data_list[index]
+         if self.return_prompt:
+             return prot_seq, self.prompt, text_seq, index
+         return prot_seq, text_seq, index
+
+     # convenience helpers
+     def get_protein_sequence(self, index):
+         """Return the protein sequence at the given index."""
+         return self.data_list[index][0]
+
+     def get_text_description(self, index):
+         """Return the text description at the given index."""
+         return self.data_list[index][1]
+
+     def get_text_id(self, text_seq):
+         """Return the ID of a text description (-1 if unseen)."""
+         return self.text2id.get(text_seq, -1)
+
+
+ class PDBAbstractDataset(Dataset):
+     def __init__(self, root_path, subset, prompt='ABSTRACT: ', return_prompt=False):
+         super(PDBAbstractDataset, self).__init__()
+         self.data_path = Path(root_path) / subset
+         self.abstract_path = Path(root_path) / 'abstract.json'
+
+         ## load dataset
+         with open(self.abstract_path, 'r') as f:
+             abstract_data = json.load(f)
+             abstract_data_dict = {line['pdb_id']: line['caption'] for line in abstract_data}
+
+         with open(self.data_path, 'r') as f:
+             lines = f.readlines()
+             pdb2seq = [line.strip().split('\t') for line in lines]
+
+         ## process dataset
+         data_list = []
+         for pdb_id, seq in pdb2seq:
+             abstract = abstract_data_dict[pdb_id]
+             abstract = abstract.replace('\n', ' ').strip() + '\n'
+             data_list.append((seq, abstract))
+         self.data_list = data_list
+         self.prompt = prompt
+         self.return_prompt = return_prompt
+
+     def shuffle(self):
+         random.shuffle(self.data_list)
+         return self
+
+     def len(self):
+         return len(self)
+
+     def get(self, idx):
+         return self.__getitem__(idx)
+
+     def __len__(self):
+         return len(self.data_list)
+
+     def __getitem__(self, index):
+         seq, abstract = self.data_list[index]
+         if self.return_prompt:
+             return seq, self.prompt, abstract, index
+         return seq, abstract, index
+
+
+ class OntoProteinDataset(Dataset):
+     def __init__(self, data_path, prompt='Gene Ontology description: ', return_prompt=False):
+         super(OntoProteinDataset, self).__init__()
+         self.data_path = data_path
+
+         ## load data
+         with open(data_path, 'r') as f:
+             lines = f.readlines()
+             self.data_list = [line.strip().split('\t') for line in lines]
+
+         ## preprocessing
+         ## fixme: I have disabled the signal word for this dataset. However, it was used in previous experiments.
+         if True:
+             self.data_list = [(p, t.strip() + '\n') for p, t in self.data_list]
+         else:
+             self.data_list = [(p, "KG: " + t.strip() + '\n') for p, t in self.data_list]
+         self.prompt = prompt
+         self.return_prompt = return_prompt
+
+     def shuffle(self):
+         random.shuffle(self.data_list)
+         return self
+
+     def __len__(self):
+         return len(self.data_list)
+
+     def __getitem__(self, index):
+         prot_seq, text_seq = self.data_list[index]
+         if self.return_prompt:
+             return prot_seq, self.prompt, text_seq, index
+         return prot_seq, text_seq, index
+
+
+ if __name__ == '__main__':
+     import numpy as np
+     ## get statistics for the Swiss-Prot dataset
+     if False:
+         swiss_train = SwissProtDataset('../data/SwissProtV3/train_set.jsonl')
+         swiss_valid = SwissProtDataset('../data/SwissProtV3/valid_set.jsonl')
+         swiss_test = SwissProtDataset('../data/SwissProtV3/test_set.jsonl')
+         print(len(swiss_train), len(swiss_valid), len(swiss_test))
+
+         ## get amino acid statistics
+         aa_lens = np.asarray([len(seq) for seq, _ in swiss_train.data_list])
+         print('Train dataset mean: ', np.mean(aa_lens), 'min: ', aa_lens.min(), 'max: ', aa_lens.max())
+         aa_lens = np.asarray([len(seq) for seq, _ in swiss_valid.data_list])
+         print('Valid dataset mean: ', np.mean(aa_lens), 'min: ', aa_lens.min(), 'max: ', aa_lens.max())
+         aa_lens = np.asarray([len(seq) for seq, _ in swiss_test.data_list])
+         print('Test dataset mean: ', np.mean(aa_lens), 'min: ', aa_lens.min(), 'max: ', aa_lens.max())
+
+         ## get text statistics
+         text_lens = np.asarray([len(seq.split()) for _, seq in swiss_train.data_list])
+         print('Train dataset mean: ', np.mean(text_lens), 'min: ', text_lens.min(), 'max: ', text_lens.max())
+         text_lens = np.asarray([len(seq.split()) for _, seq in swiss_valid.data_list])
+         print('Valid dataset mean: ', np.mean(text_lens), 'min: ', text_lens.min(), 'max: ', text_lens.max())
+         text_lens = np.asarray([len(seq.split()) for _, seq in swiss_test.data_list])
+         print('Test dataset mean: ', np.mean(text_lens), 'min: ', text_lens.min(), 'max: ', text_lens.max())
+         print('---------------------------')
+
+     ## get statistics for the OntoProtein dataset
+     onto_train = OntoProteinDataset('../data/OntoProteinDatasetV2/train.txt')
+     onto_valid = OntoProteinDataset('../data/OntoProteinDatasetV2/valid.txt')
+     onto_test = OntoProteinDataset('../data/OntoProteinDatasetV2/test.txt')
+     print(len(onto_train), len(onto_valid), len(onto_test))
+
+     ## get amino acid statistics
+     aa_lens = np.asarray([len(seq) for seq, _ in onto_train.data_list])
+     print('Train dataset mean: ', np.mean(aa_lens), 'min: ', aa_lens.min(), 'max: ', aa_lens.max())
+     aa_lens = np.asarray([len(seq) for seq, _ in onto_valid.data_list])
+     print('Valid dataset mean: ', np.mean(aa_lens), 'min: ', aa_lens.min(), 'max: ', aa_lens.max())
+     aa_lens = np.asarray([len(seq) for seq, _ in onto_test.data_list])
+     print('Test dataset mean: ', np.mean(aa_lens), 'min: ', aa_lens.min(), 'max: ', aa_lens.max())
+
+     ## get text statistics
+     text_lens = np.asarray([len(seq.split()) for _, seq in onto_train.data_list])
+     print('Train dataset mean: ', np.mean(text_lens), 'min: ', text_lens.min(), 'max: ', text_lens.max())
+     text_lens = np.asarray([len(seq.split()) for _, seq in onto_valid.data_list])
+     print('Valid dataset mean: ', np.mean(text_lens), 'min: ', text_lens.min(), 'max: ', text_lens.max())
+     text_lens = np.asarray([len(seq.split()) for _, seq in onto_test.data_list])
+     print('Test dataset mean: ', np.mean(text_lens), 'min: ', text_lens.min(), 'max: ', text_lens.max())
+     print('---------------------------')
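
Usage note: Stage1DM follows the same wiring pattern as the other data modules in this diff. A minimal sketch (the defaults point at data/SwissProtV3, which this commit ships; the tokenizer checkpoints mirror those used elsewhere in the repo and are an assumption here):

    from argparse import ArgumentParser
    from transformers import AutoTokenizer, EsmTokenizer
    from data_provider.stage1_dm import Stage1DM

    parser = ArgumentParser()
    parser = Stage1DM.add_model_specific_args(parser)
    args = parser.parse_args([])

    dm = Stage1DM(num_workers=args.num_workers, batch_size=args.batch_size, root=args.root, args=args)
    text_tokenizer = AutoTokenizer.from_pretrained('facebook/galactica-1.3b', use_fast=False)
    text_tokenizer.add_special_tokens({'pad_token': '<pad>'})
    prot_tokenizer = EsmTokenizer.from_pretrained('facebook/esm2_t30_150M_UR50D')
    dm.init_tokenizer(text_tokenizer, prot_tokenizer)

    prot_tokens, text_tokens = next(iter(dm.train_dataloader()))
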
data_provider/stage2_dm.py ADDED
@@ -0,0 +1,386 @@
1
+ # Copyright (c) Microsoft Corporation.
2
+ # Licensed under the MIT License.
3
+ import json
4
+ from pytorch_lightning import LightningDataModule
5
+ from torch.utils.data import DataLoader, ConcatDataset
6
+ from data_provider.stage1_dm import SwissProtDataset, OntoProteinDataset
7
+ from data_provider.stage3_dm import DeepLocBinaryDataset,AlpacaDataset,MolInstructionDataset,DeepLocMultiDataset,Deepsol,DeepsoluE,Protsolm,FLIP_GB1,FLIP_AAV
8
+ from data_provider.bindingdb import BindingDB
9
+ from data_provider.metalIonbinding import MetallonBinding
10
+ from data_provider.go import GO_BP,EC
11
+ from data_provider.production import Antibiotic_Resistance,Thermostability,Material,Clone
12
+ from data_provider.mutation import TAPE_Stability,TAPE_Fluorescence
13
+
14
+ class Stage2Collater(object):
15
+ def __init__(self, tokenizer, prot_tokenizer, text_max_len, prot_max_len):
16
+ self.tokenizer = tokenizer
17
+ self.prot_tokenizer = prot_tokenizer
18
+ self.text_max_len = text_max_len
19
+ self.prot_max_len = prot_max_len
20
+
21
+ def __call__(self, batch):
22
+ prot_seqs, prompt_seqs, text_seqs, _ = zip(*batch)
23
+ prot_tokens = self.prot_tokenizer(prot_seqs,
24
+ truncation=True,
25
+ padding='max_length',
26
+ max_length=self.prot_max_len,
27
+ return_tensors="pt",
28
+ return_attention_mask=True,
29
+ return_token_type_ids=False)
30
+ if False:
31
+ self.tokenizer.padding_side = 'left'
32
+ prompt_tokens = self.tokenizer(prompt_seqs,
33
+ truncation=True,
34
+ padding='longest',
35
+ add_special_tokens=True,
36
+ max_length=self.text_max_len,
37
+ return_tensors='pt',
38
+ return_attention_mask=True,
39
+ return_token_type_ids=False)
40
+ self.tokenizer.padding_side = 'right'
41
+ text_tokens = self.tokenizer(text_seqs,
42
+ truncation=True,
43
+ padding='max_length',
44
+ add_special_tokens=True,
45
+ max_length=self.text_max_len,
46
+ return_tensors='pt',
47
+ return_attention_mask=True,
48
+ return_token_type_ids=False)
49
+ else:
50
+ self.tokenizer.padding_side = 'left'
51
+ prompt_tokens = self.tokenizer(prompt_seqs,
52
+ truncation=True,
53
+ padding='longest',
54
+ add_special_tokens=True,
55
+ max_length=self.text_max_len,
56
+ return_tensors='pt',
57
+ return_attention_mask=True,
58
+ return_token_type_ids=False)
59
+ max_prompt_len = int(prompt_tokens.attention_mask.sum(dim=1).max())
60
+ input_pair = [[p, t] for p, t in zip(prompt_seqs, text_seqs)]
61
+ input_tokens = self.tokenizer(input_pair,
62
+ truncation=True,
63
+ padding='max_length',
64
+ add_special_tokens=True,
65
+ max_length=self.text_max_len + max_prompt_len,
66
+ return_tensors='pt',
67
+ return_attention_mask=True,
68
+ return_token_type_ids=True)
69
+ return prot_tokens, input_tokens
70
+
71
+
72
+ class InferenceCollater(object):
73
+ def __init__(self, tokenizer, prot_tokenizer, text_max_len, prot_max_len):
74
+ self.tokenizer = tokenizer
75
+ self.prot_tokenizer = prot_tokenizer
76
+ self.text_max_len = text_max_len
77
+ self.prot_max_len = prot_max_len
78
+
79
+ def __call__(self, batch):
80
+ prot_seqs, prompt_seqs, text_seqs, indices = zip(*batch)
81
+ # print("=========")
82
+ # print(prot_seqs)
83
+
84
+ self.tokenizer.padding_side = 'right'
85
+ prompt_tokens = self.tokenizer(prompt_seqs,
86
+ truncation=True,
87
+ padding='longest',
88
+ add_special_tokens=False,
89
+ max_length=self.text_max_len,
90
+ return_tensors='pt',
91
+ return_attention_mask=True,
92
+ return_token_type_ids=False)
93
+ prot_tokens = self.prot_tokenizer(prot_seqs,
94
+ truncation=True,
95
+ padding='max_length',
96
+ max_length=self.prot_max_len,
97
+ return_tensors="pt",
98
+ return_attention_mask=True,
99
+ return_token_type_ids=False)
100
+ target_dict = {'targets': text_seqs, 'indices': indices}
101
+ return prot_tokens, prompt_tokens, target_dict
102
+
103
+
104
+ class Stage2DM(LightningDataModule):
105
+ def __init__(
106
+ self,
107
+ root: str = 'data/',
108
+ args=None,
109
+ ):
110
+ super().__init__()
111
+ self.args = args
112
+ self.batch_size = args.batch_size
113
+ self.inference_batch_size = args.inference_batch_size
114
+ self.num_workers = args.num_workers
115
+ self.text_max_len = args.text_max_len
116
+ self.prot_max_len = args.prot_max_len
117
+ self.prompt = args.prompt
118
+
119
+ # self.train_dataset = AlpacaDataset('/nas/shared/kilab/wangyujia/pretrain_data/instruct/alpaca-gpt4-train.jsonl', prompt=self.prompt, return_prompt=True)
120
+ # self.val_dataset = AlpacaDataset('/nas/shared/kilab/wangyujia/pretrain_data/instruct/alpaca-gpt4-valid.jsonl', prompt=self.prompt, return_prompt=True)
121
+ # self.test_dataset = AlpacaDataset('/nas/shared/kilab/wangyujia/pretrain_data/instruct/alpaca-gpt4-test.jsonl', prompt=self.prompt, return_prompt=True)
122
+
123
+ # self.train_dataset = MolInstructionDataset('/oss/wangyujia/pretrain-bench/mol-instruction/train.jsonl', prompt='', return_prompt=True)
124
+ # self.val_dataset = MolInstructionDataset('/oss/wangyujia/pretrain-bench/mol-instruction/train.jsonl', prompt='', return_prompt=True)
125
+ # self.test_dataset = MolInstructionDataset('/oss/wangyujia/pretrain-bench/mol-instruction/train.jsonl', prompt='', return_prompt=True)
126
+
127
+ if self.args.dataset=='deeplocbinary':
128
+ self.train_dataset = DeepLocBinaryDataset('/oss/wangyujia/pretrain-bench/locate/deeplocbinary/train.csv', prompt=self.prompt, return_prompt=True)
129
+ self.val_dataset = DeepLocBinaryDataset('/oss/wangyujia/pretrain-bench/locate/deeplocbinary/valid.csv', prompt=self.prompt, return_prompt=True)
130
+ self.test_dataset = DeepLocBinaryDataset('/oss/wangyujia/pretrain-bench/locate/deeplocbinary/test.csv', prompt=self.prompt, return_prompt=True)
131
+
132
+ elif self.args.dataset=='deeplocmulti':
133
+ self.train_dataset = DeepLocMultiDataset('/oss/wangyujia/pretrain-bench/locate/deeplocmulti/train.csv',prompt=self.prompt, return_prompt=True)
134
+ self.val_dataset = DeepLocMultiDataset('/oss/wangyujia/pretrain-bench/locate/deeplocmulti/valid.csv', prompt=self.prompt, return_prompt=True)
135
+ self.test_dataset = DeepLocMultiDataset('/oss/wangyujia/pretrain-bench/locate/deeplocmulti/test.csv', prompt=self.prompt, return_prompt=True)
136
+
137
+ elif self.args.dataset=='deepsol':
138
+ self.train_dataset = Deepsol('/nas/shared/kilab/wangyujia/sft_data/deepsol/clean/train.csv',prompt=self.prompt, return_prompt=True)
139
+ self.val_dataset = Deepsol('/nas/shared/kilab/wangyujia/sft_data/deepsol/clean/valid.csv', prompt=self.prompt, return_prompt=True)
140
+ self.test_dataset = Deepsol('/nas/shared/kilab/wangyujia/sft_data/deepsol/clean/test.csv', prompt=self.prompt, return_prompt=True)
141
+ elif self.args.dataset=='deepsolue':
142
+ self.train_dataset = DeepsoluE('/nas/shared/kilab/wangyujia/sft_data/deepsoluE/clean/train.csv',prompt=self.prompt, return_prompt=True)
143
+ self.val_dataset = DeepsoluE('/nas/shared/kilab/wangyujia/sft_data/deepsoluE/clean/test.csv', prompt=self.prompt, return_prompt=True)
144
+ self.test_dataset = DeepsoluE('/nas/shared/kilab/wangyujia/sft_data/deepsoluE/clean/test.csv', prompt=self.prompt, return_prompt=True)
145
+ elif self.args.dataset=='protsolm':
146
+ self.train_dataset = Protsolm('/oss/wangyujia/pretrain-bench/solubility/protsolm/train.csv',prompt=self.prompt, return_prompt=True)
147
+ self.val_dataset = Protsolm('/oss/wangyujia/pretrain-bench/solubility/protsolm/valid.csv', prompt=self.prompt, return_prompt=True)
148
+ self.test_dataset = Protsolm('/oss/wangyujia/pretrain-bench/solubility/protsolm/test.csv', prompt=self.prompt, return_prompt=True)
149
+ elif self.args.dataset=='gb1':
150
+ self.train_dataset = FLIP_GB1('/nas/shared/kilab/wangyujia/sft_data/mutation/gb1/train.csv',prompt=self.prompt, return_prompt=True)
151
+ self.val_dataset = FLIP_GB1('/nas/shared/kilab/wangyujia/sft_data/mutation/gb1/valid.csv', prompt=self.prompt, return_prompt=True)
152
+ self.test_dataset = FLIP_GB1('/nas/shared/kilab/wangyujia/sft_data/mutation/gb1/test.csv', prompt=self.prompt, return_prompt=True)
153
+ elif self.args.dataset=='gb1_low':
154
+ self.train_dataset = FLIP_GB1('/nas/shared/kilab/wangyujia/sft_data/mutation/gb1/clean/train.csv',prompt=self.prompt, return_prompt=True)
155
+ self.val_dataset = FLIP_GB1('/nas/shared/kilab/wangyujia/sft_data/mutation/gb1/clean/valid.csv', prompt=self.prompt, return_prompt=True)
156
+ self.test_dataset = FLIP_GB1('/nas/shared/kilab/wangyujia/sft_data/mutation/gb1/clean/test.csv', prompt=self.prompt, return_prompt=True)
157
+
158
+ elif self.args.dataset=='aav':
159
+ self.train_dataset = FLIP_AAV('/nas/shared/kilab/wangyujia/sft_data/mutation/aav/train.csv',prompt=self.prompt, return_prompt=True)
160
+ self.val_dataset = FLIP_AAV('/nas/shared/kilab/wangyujia/sft_data/mutation/aav/valid.csv', prompt=self.prompt, return_prompt=True)
161
+ self.test_dataset = FLIP_AAV('/nas/shared/kilab/wangyujia/sft_data/mutation/aav/test.csv', prompt=self.prompt, return_prompt=True)
162
+
163
+ elif self.args.dataset=='bindingdb':
164
+ self.train_dataset = BindingDB('/nas/shared/kilab/wangyujia/sft_data/bindingdb/clean/train_small.csv',prompt=self.prompt, return_prompt=True)
165
+ self.val_dataset = BindingDB('/nas/shared/kilab/wangyujia/sft_data/bindingdb/clean/valid_small.csv', prompt=self.prompt, return_prompt=True)
166
+ self.test_dataset = BindingDB('/nas/shared/kilab/wangyujia/sft_data/bindingdb/clean/test_small.csv', prompt=self.prompt, return_prompt=True)
+        elif self.args.dataset == 'metallonbinding':
+            self.train_dataset = MetallonBinding('/nas/shared/kilab/wangyujia/sft_data/MetalIonBinding/train.csv', prompt=self.prompt, return_prompt=True)
+            self.val_dataset = MetallonBinding('/nas/shared/kilab/wangyujia/sft_data/MetalIonBinding/valid.csv', prompt=self.prompt, return_prompt=True)
+            self.test_dataset = MetallonBinding('/nas/shared/kilab/wangyujia/sft_data/MetalIonBinding/test.csv', prompt=self.prompt, return_prompt=True)
+        elif self.args.dataset == 'bp':
+            self.train_dataset = GO_BP('/nas/shared/kilab/wangyujia/sft_data/go/clean/BP_train.csv', prompt=self.prompt, return_prompt=True)
+            self.val_dataset = GO_BP('/nas/shared/kilab/wangyujia/sft_data/go/clean/BP_valid.csv', prompt=self.prompt, return_prompt=True)
+            self.test_dataset = GO_BP('/nas/shared/kilab/wangyujia/sft_data/go/clean/BP_test.csv', prompt=self.prompt, return_prompt=True)
+        elif self.args.dataset == 'ec':
+            self.train_dataset = EC('/nas/shared/kilab/wangyujia/sft_data/EC/train.csv', prompt=self.prompt, return_prompt=True)
+            self.val_dataset = EC('/nas/shared/kilab/wangyujia/sft_data/EC/valid.csv', prompt=self.prompt, return_prompt=True)
+            self.test_dataset = EC('/nas/shared/kilab/wangyujia/sft_data/EC/test.csv', prompt=self.prompt, return_prompt=True)
+        elif self.args.dataset == 'antibiotic':
+            self.train_dataset = Antibiotic_Resistance('/nas/shared/kilab/wangyujia/sft_data/production/Antibiotic_Resistance/train.csv', prompt=self.prompt, return_prompt=True)
+            # validation reuses the test split here (no valid.csv is referenced for this dataset)
+            self.val_dataset = Antibiotic_Resistance('/nas/shared/kilab/wangyujia/sft_data/production/Antibiotic_Resistance/test.csv', prompt=self.prompt, return_prompt=True)
+            self.test_dataset = Antibiotic_Resistance('/nas/shared/kilab/wangyujia/sft_data/production/Antibiotic_Resistance/test.csv', prompt=self.prompt, return_prompt=True)
+        elif self.args.dataset == 'thermostability':
+            self.train_dataset = Thermostability('/nas/shared/kilab/wangyujia/sft_data/production/Thermostability/train.csv', prompt=self.prompt, return_prompt=True)
+            self.val_dataset = Thermostability('/nas/shared/kilab/wangyujia/sft_data/production/Thermostability/valid.csv', prompt=self.prompt, return_prompt=True)
+            self.test_dataset = Thermostability('/nas/shared/kilab/wangyujia/sft_data/production/Thermostability/test.csv', prompt=self.prompt, return_prompt=True)
+        elif self.args.dataset == 'material':
+            self.train_dataset = Material('/oss/wangyujia/ProtT3/ProtT3/data/sft/dataset/material_production/train.csv', prompt=self.prompt, return_prompt=True)
+            self.val_dataset = Material('/oss/wangyujia/ProtT3/ProtT3/data/sft/dataset/material_production/val.csv', prompt=self.prompt, return_prompt=True)
+            self.test_dataset = Material('/oss/wangyujia/ProtT3/ProtT3/data/sft/dataset/material_production/test.csv', prompt=self.prompt, return_prompt=True)
+        #6
+        elif self.args.dataset == 'clone':
+            self.train_dataset = Clone('/oss/wangyujia/ProtT3/ProtT3/data/sft/dataset/cloning_clf/train.csv', prompt=self.prompt, return_prompt=True)
+            self.val_dataset = Clone('/oss/wangyujia/ProtT3/ProtT3/data/sft/dataset/cloning_clf/val.csv', prompt=self.prompt, return_prompt=True)
+            self.test_dataset = Clone('/oss/wangyujia/ProtT3/ProtT3/data/sft/dataset/cloning_clf/test.csv', prompt=self.prompt, return_prompt=True)
+        elif self.args.dataset == 'stability':
+            self.train_dataset = TAPE_Stability('/oss/wangyujia/ProtT3/ProtT3/data/sft/dataset/TAPE_Stability/train.csv', prompt=self.prompt, return_prompt=True)
+            self.val_dataset = TAPE_Stability('/oss/wangyujia/ProtT3/ProtT3/data/sft/dataset/TAPE_Stability/val.csv', prompt=self.prompt, return_prompt=True)
+            self.test_dataset = TAPE_Stability('/oss/wangyujia/ProtT3/ProtT3/data/sft/dataset/TAPE_Stability/test.csv', prompt=self.prompt, return_prompt=True)
+        elif self.args.dataset == 'fluorescence':
+            self.train_dataset = TAPE_Fluorescence('/oss/wangyujia/ProtT3/ProtT3/data/sft/dataset/TAPE_Fluorescence/fluorescence_prediction_train.csv', prompt=self.prompt, return_prompt=True)
+            self.val_dataset = TAPE_Fluorescence('/oss/wangyujia/ProtT3/ProtT3/data/sft/dataset/TAPE_Fluorescence/fluorescence_prediction_valid.csv', prompt=self.prompt, return_prompt=True)
+            self.test_dataset = TAPE_Fluorescence('/oss/wangyujia/ProtT3/ProtT3/data/sft/dataset/TAPE_Fluorescence/test.csv', prompt=self.prompt, return_prompt=True)
+        elif self.args.dataset == 'empty':
+            self.train_dataset = Empty(prompt=self.prompt, return_prompt=True)
+            self.val_dataset = Empty(prompt=self.prompt, return_prompt=True)
+            self.test_dataset = Empty(prompt=self.prompt, return_prompt=True)
+        elif self.args.dataset == 'swiss-prot':
+            self.train_dataset = SwissProtDataset(root + '/SwissProtV3/train_set.jsonl', prompt='Swiss-Prot description: ', return_prompt=True)
+            self.val_dataset = SwissProtDataset(root + '/SwissProtV3/valid_set.jsonl', prompt='Swiss-Prot description: ', return_prompt=True)
+            self.test_dataset = SwissProtDataset(root + '/SwissProtV3/test_set.jsonl', prompt='Swiss-Prot description: ', return_prompt=True)
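+        # note: the chain has no final else, so an unrecognized --dataset value
+        # leaves the train/val/test datasets unset and only fails later in the dataloaders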
+
+        self.tokenizer = None
+        self.prot_tokenizer = None
+
+    def init_tokenizer(self, tokenizer, prot_tokenizer):
+        self.tokenizer = tokenizer
+        self.prot_tokenizer = prot_tokenizer
+
+    def train_dataloader(self):
+        loader = DataLoader(
+            self.train_dataset,
+            batch_size=self.batch_size,
+            shuffle=True,
+            num_workers=self.num_workers,
+            pin_memory=False,
+            drop_last=True,
+            persistent_workers=False,
+            collate_fn=InferenceCollater(self.tokenizer, self.prot_tokenizer, self.text_max_len, self.prot_max_len),
+        )
+        return loader
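+
+    # Both collaters in this class are built from self.tokenizer / self.prot_tokenizer,
+    # so init_tokenizer() must be called before any dataloader is requested.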
+
+    def val_dataloader(self):
+        val_loader = DataLoader(
+            self.val_dataset,
+            batch_size=self.batch_size,
+            shuffle=False,
+            num_workers=self.num_workers,
+            pin_memory=False,
+            drop_last=False,
+            persistent_workers=False,
+            collate_fn=InferenceCollater(self.tokenizer, self.prot_tokenizer, self.text_max_len, self.prot_max_len),
+        )
+        test_loader = DataLoader(
+            self.test_dataset,
+            batch_size=self.inference_batch_size,
+            shuffle=False,
+            num_workers=self.num_workers,
+            pin_memory=False,
+            drop_last=False,
+            persistent_workers=False,
+            collate_fn=InferenceCollater(self.tokenizer, self.prot_tokenizer, self.text_max_len, self.prot_max_len),
+        )
+        return [val_loader, test_loader]
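+
+    # Returning a list from val_dataloader() makes Lightning evaluate both splits
+    # each epoch, passing dataloader_idx to validation_step (0 = val, 1 = test).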
+
+    @staticmethod
+    def add_model_specific_args(parent_parser):
+        parser = parent_parser.add_argument_group("Data module")
+        parser.add_argument('--num_workers', type=int, default=2)
+        parser.add_argument('--batch_size', type=int, default=32)
+        parser.add_argument('--inference_batch_size', type=int, default=4)
+        parser.add_argument('--root', type=str, default='data')
+        parser.add_argument('--text_max_len', type=int, default=2048)
+        parser.add_argument('--q_max_len', type=int, default=29)
+        parser.add_argument('--a_max_len', type=int, default=36)
+        parser.add_argument('--prot_max_len', type=int, default=1024)
+        parser.add_argument('--prompt', type=str, default='The protein has the following properties:')
+        parser.add_argument('--filter_side_qa', action='store_true', default=False)
+        return parent_parser
+
+
+class Stage2MixDM(LightningDataModule):
+    def __init__(
+        self,
+        root: str = 'data/',
+        args=None,
+    ):
+        super().__init__()
+        self.args = args
+        self.batch_size = args.batch_size
+        self.inference_batch_size = args.inference_batch_size
+        self.num_workers = args.num_workers
+        self.text_max_len = args.text_max_len
+        self.prot_max_len = args.prot_max_len
+        # self.prompt = args.prompt
+        assert args.mix_dataset
+
+        train_dataset1 = SwissProtDataset(root + '/SwissProtV3/train_set.jsonl', prompt='Swiss-Prot description: ', return_prompt=True)
+        train_dataset2 = OntoProteinDataset(root + '/OntoProteinDatasetV2/train.txt', prompt='Gene Ontology description: ', return_prompt=True)
+        self.train_dataset = ConcatDataset([train_dataset1, train_dataset2])
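+        # ConcatDataset appends the two corpora end to end; with shuffle=True in
+        # train_dataloader, each training batch can mix Swiss-Prot and Gene Ontology samples.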
+        self.swiss_val_dataset = SwissProtDataset(root + '/SwissProtV3/valid_set.jsonl', prompt='Swiss-Prot description: ', return_prompt=True)
+        self.onto_val_dataset = OntoProteinDataset(root + '/OntoProteinDatasetV2/valid.txt', prompt='Gene Ontology description: ', return_prompt=True)
+        self.swiss_test_dataset = SwissProtDataset(root + '/SwissProtV3/test_set.jsonl', prompt='Swiss-Prot description: ', return_prompt=True)
+        self.onto_test_dataset = OntoProteinDataset(root + '/OntoProteinDatasetV2/test.txt', prompt='Gene Ontology description: ', return_prompt=True)
+
+        self.tokenizer = None
+        self.prot_tokenizer = None
+
+    def init_tokenizer(self, tokenizer, prot_tokenizer):
+        self.tokenizer = tokenizer
+        self.prot_tokenizer = prot_tokenizer
+
+    def train_dataloader(self):
+        loader = DataLoader(
+            self.train_dataset,
+            batch_size=self.batch_size,
+            shuffle=True,
+            num_workers=self.num_workers,
+            pin_memory=False,
+            drop_last=True,
+            persistent_workers=False,
+            collate_fn=InferenceCollater(self.tokenizer, self.prot_tokenizer, self.text_max_len, self.prot_max_len),
+        )
+        return loader
+
+    def val_dataloader(self):
+        swiss_val_loader = DataLoader(
+            self.swiss_val_dataset,
+            batch_size=self.batch_size,
+            shuffle=False,
+            num_workers=self.num_workers,
+            pin_memory=False,
+            drop_last=False,
+            persistent_workers=False,
+            collate_fn=InferenceCollater(self.tokenizer, self.prot_tokenizer, self.text_max_len, self.prot_max_len),
+        )
+        swiss_test_loader = DataLoader(
+            self.swiss_test_dataset,
+            batch_size=self.inference_batch_size,
+            shuffle=False,
+            num_workers=self.num_workers,
+            pin_memory=False,
+            drop_last=False,
+            persistent_workers=False,
+            collate_fn=InferenceCollater(self.tokenizer, self.prot_tokenizer, self.text_max_len, self.prot_max_len),
+        )
+        onto_val_loader = DataLoader(
+            self.onto_val_dataset,
+            batch_size=self.batch_size,
+            shuffle=False,
+            num_workers=self.num_workers,
+            pin_memory=False,
+            drop_last=False,
+            persistent_workers=False,
+            collate_fn=InferenceCollater(self.tokenizer, self.prot_tokenizer, self.text_max_len, self.prot_max_len),
+        )
+        onto_test_loader = DataLoader(
+            self.onto_test_dataset,
+            batch_size=self.inference_batch_size,
+            shuffle=False,
+            num_workers=self.num_workers,
+            pin_memory=False,
+            drop_last=False,
+            persistent_workers=False,
+            collate_fn=InferenceCollater(self.tokenizer, self.prot_tokenizer, self.text_max_len, self.prot_max_len),
+        )
+        return [swiss_val_loader, swiss_test_loader, onto_val_loader, onto_test_loader]
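+
+    # dataloader_idx in validation_step maps these loaders as: 0 = Swiss-Prot val,
+    # 1 = Swiss-Prot test, 2 = OntoProtein val, 3 = OntoProtein test.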
+
+    @staticmethod
+    def add_model_specific_args(parent_parser):
+        parser = parent_parser.add_argument_group("Data module")
+        parser.add_argument('--num_workers', type=int, default=2)
+        parser.add_argument('--batch_size', type=int, default=32)
+        parser.add_argument('--inference_batch_size', type=int, default=4)
+        parser.add_argument('--root', type=str, default='data')
+        parser.add_argument('--text_max_len', type=int, default=1024)
+        parser.add_argument('--q_max_len', type=int, default=29)
+        parser.add_argument('--a_max_len', type=int, default=36)
+        parser.add_argument('--prot_max_len', type=int, default=1024)
+        # parser.add_argument('--prompt', type=str, default='The protein has the following properties: ')
+        return parent_parser
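+
+    # Minimal wiring sketch (illustrative only; `model`, `trainer`, and the two
+    # tokenizer objects are assumed here, and --mix_dataset is parsed elsewhere):
+    #   parser = argparse.ArgumentParser()
+    #   parser = Stage2MixDM.add_model_specific_args(parser)
+    #   args = parser.parse_args(['--batch_size', '16'])
+    #   args.mix_dataset = True  # asserted in __init__
+    #   dm = Stage2MixDM(root=args.root, args=args)
+    #   dm.init_tokenizer(text_tokenizer, prot_tokenizer)
+    #   trainer.fit(model, datamodule=dm)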
+