Add files using upload-large-folder tool
This view is limited to 50 files because it contains too many changes. See the raw diff for the full change set.
- .gitattributes +15 -0
- all_checkpoints/stage2_08011616_2datasets_withoutpretrain/last.ckpt/checkpoint/mp_rank_00_model_states.pt +3 -0
- all_checkpoints/stage2_08011616_2datasets_withoutpretrain/wandb/run-20250801_162632-790wkw0g/files/output.log +0 -0
- all_checkpoints/stage2_08011616_2datasets_withoutpretrain/wandb/run-20250801_162632-790wkw0g/files/requirements.txt +225 -0
- all_checkpoints/stage2_08011616_2datasets_withoutpretrain/wandb/run-20250801_162632-790wkw0g/files/wandb-metadata.json +108 -0
- all_checkpoints/stage2_08011616_2datasets_withoutpretrain/wandb/run-20250801_162632-790wkw0g/files/wandb-summary.json +1 -0
- all_checkpoints/stage2_08011616_2datasets_withoutpretrain/wandb/run-20250801_162632-790wkw0g/logs/debug-internal.log +166 -0
- all_checkpoints/stage2_08011616_2datasets_withoutpretrain/wandb/run-20250801_162632-790wkw0g/logs/debug.log +24 -0
- all_checkpoints/stage2_08011616_2datasets_withoutpretrain/wandb/run-20250801_162632-790wkw0g/run-790wkw0g.wandb +3 -0
- data/.gitignore +2 -0
- data/OntoProteinDatasetV2/test.txt +3 -0
- data/OntoProteinDatasetV2/train.txt +3 -0
- data/OntoProteinDatasetV2/valid.txt +3 -0
- data/PDBDataset/abstract.json +3 -0
- data/PDBDataset/q_types.txt +30 -0
- data/PDBDataset/qa_all.json +3 -0
- data/PDBDataset/test.txt +0 -0
- data/PDBDataset/train.txt +3 -0
- data/PDBDataset/val.txt +0 -0
- data/SwissProtV3/test_set.jsonl +0 -0
- data/SwissProtV3/train_set.jsonl +3 -0
- data/SwissProtV3/valid_set.jsonl +0 -0
- data/protein-molecule/protein-text.zip +3 -0
- data/protein-text/eval_assist.zipg3ebgjl7.tmp +3 -0
- data/protein-text/eval_assist.ziphwjr8q2y.tmp +3 -0
- data/protein-text/eval_assist.zipzh1pdmj_.tmp +3 -0
- data_provider/__pycache__/bindingdb.cpython-310.pyc +0 -0
- data_provider/__pycache__/go.cpython-310.pyc +0 -0
- data_provider/__pycache__/metalIonbinding.cpython-310.pyc +0 -0
- data_provider/__pycache__/mutation.cpython-310.pyc +0 -0
- data_provider/__pycache__/production.cpython-310.pyc +0 -0
- data_provider/__pycache__/prot_qa_dm.cpython-310.pyc +0 -0
- data_provider/__pycache__/prot_qa_dm.cpython-311.pyc +0 -0
- data_provider/__pycache__/stage1_dm.cpython-310.pyc +0 -0
- data_provider/__pycache__/stage1_dm.cpython-311.pyc +0 -0
- data_provider/__pycache__/stage2_dm.cpython-310.pyc +0 -0
- data_provider/__pycache__/stage3_dm.cpython-310.pyc +0 -0
- data_provider/__pycache__/stage3_dm.cpython-311.pyc +0 -0
- data_provider/bindingdb.py +62 -0
- data_provider/gal_helpers.py +45 -0
- data_provider/go.py +237 -0
- data_provider/llm_tuning_dm.py +261 -0
- data_provider/llm_tuning_prot_qa_dm.py +164 -0
- data_provider/metalIonbinding.py +63 -0
- data_provider/mutation.py +119 -0
- data_provider/production.py +237 -0
- data_provider/prot_qa_dm.py +299 -0
- data_provider/proteinchat_dm.py +254 -0
- data_provider/stage1_dm.py +539 -0
- data_provider/stage2_dm.py +386 -0
.gitattributes
CHANGED
@@ -39,3 +39,18 @@ all_checkpoints/stage2_07070337_2datasets_noconstruct/wandb/run-20250707_041231-
 all_checkpoints/stage2_07070513_2datasets_construct/wandb/run-20250707_052104-615z4bme/run-615z4bme.wandb filter=lfs diff=lfs merge=lfs -text
 all_checkpoints/stage2_07070513_2datasets_construct/wandb/run-20250707_053222-9cjzn0v3/run-9cjzn0v3.wandb filter=lfs diff=lfs merge=lfs -text
 all_checkpoints/stage2_07301646_2datasets_construct/wandb/run-20250730_175623-pbf2bxo6/run-pbf2bxo6.wandb filter=lfs diff=lfs merge=lfs -text
+data/OntoProteinDatasetV2/test.txt filter=lfs diff=lfs merge=lfs -text
+all_checkpoints/stage2_08011616_2datasets_withoutpretrain/wandb/run-20250801_162632-790wkw0g/run-790wkw0g.wandb filter=lfs diff=lfs merge=lfs -text
+data/OntoProteinDatasetV2/train.txt filter=lfs diff=lfs merge=lfs -text
+data/OntoProteinDatasetV2/valid.txt filter=lfs diff=lfs merge=lfs -text
+data/PDBDataset/abstract.json filter=lfs diff=lfs merge=lfs -text
+data/PDBDataset/qa_all.json filter=lfs diff=lfs merge=lfs -text
+data/PDBDataset/train.txt filter=lfs diff=lfs merge=lfs -text
+data/SwissProtV3/train_set.jsonl filter=lfs diff=lfs merge=lfs -text
+data/protein-text/eval_assist.ziphwjr8q2y.tmp filter=lfs diff=lfs merge=lfs -text
+data/protein-text/eval_assist.zipg3ebgjl7.tmp filter=lfs diff=lfs merge=lfs -text
+data/protein-text/eval_assist.zipzh1pdmj_.tmp filter=lfs diff=lfs merge=lfs -text
+data_small/OntoProteinDatasetV2/train.txt filter=lfs diff=lfs merge=lfs -text
+data_small/PDBDataset/abstract.json filter=lfs diff=lfs merge=lfs -text
+data_small/PDBDataset/qa_all.json filter=lfs diff=lfs merge=lfs -text
+data_small/SwissProtV3/train_set_.jsonl filter=lfs diff=lfs merge=lfs -text
all_checkpoints/stage2_08011616_2datasets_withoutpretrain/last.ckpt/checkpoint/mp_rank_00_model_states.pt
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:6efbc882065731ceb2c9886091da92484e90352ef036f1fff44e77700ff80f41
+size 208795384
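The filter=lfs diff=lfs merge=lfs -text rules in .gitattributes above mark these paths as Git LFS-tracked, so the repository itself stores only the three-line text pointer added here; the ~209 MB checkpoint payload lives in LFS object storage, addressed by its sha256 oid. A minimal Python sketch (illustrative, not part of the repo) that parses such a pointer:

    def parse_lfs_pointer(text: str) -> dict:
        """Split the 'key value' lines of a Git LFS pointer into fields."""
        fields = dict(line.split(" ", 1) for line in text.strip().splitlines())
        assert fields["version"].startswith("https://git-lfs.github.com/spec/")
        return {
            "oid": fields["oid"].removeprefix("sha256:"),
            "size": int(fields["size"]),  # payload size in bytes
        }

    pointer = (
        "version https://git-lfs.github.com/spec/v1\n"
        "oid sha256:6efbc882065731ceb2c9886091da92484e90352ef036f1fff44e77700ff80f41\n"
        "size 208795384\n"
    )
    print(parse_lfs_pointer(pointer))  # {'oid': '6efbc882...', 'size': 208795384}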
all_checkpoints/stage2_08011616_2datasets_withoutpretrain/wandb/run-20250801_162632-790wkw0g/files/output.log
ADDED
The diff for this file is too large to render. See the raw diff.
all_checkpoints/stage2_08011616_2datasets_withoutpretrain/wandb/run-20250801_162632-790wkw0g/files/requirements.txt
ADDED
@@ -0,0 +1,225 @@
+opendatasets==0.1.22
+salesforce-lavis==1.0.2
+Pygments==2.19.1
+nvidia-nccl-cu12==2.21.5
+tornado==6.5.1
+nvidia-cuda-runtime-cu12==12.4.127
+requests==2.32.3
+nvidia-cuda-cupti-cu12==12.4.127
+decord==0.6.0
+braceexpand==0.1.7
+frozenlist==1.6.0
+markdown-it-py==3.0.0
+shellingham==1.5.4
+absl-py==2.2.2
+pycocoevalcap==1.2
+contexttimer==0.3.3
+bleach==6.2.0
+jsonschema-specifications==2025.4.1
+pycocotools==2.0.8
+python-slugify==8.0.4
+tqdm==4.67.1
+numpy==2.2.6
+urllib3==2.4.0
+deepspeed==0.16.10+b666844f
+watchdog==6.0.0
+wrapt==1.17.2
+setuptools==78.1.1
+matplotlib==3.10.3
+pydeck==0.9.1
+aiosignal==1.3.2
+gitdb==4.0.12
+hjson==3.1.0
+timm==0.4.12
+blis==1.3.0
+PyYAML==6.0.2
+referencing==0.36.2
+contourpy==1.3.2
+kaggle==1.7.4.5
+triton==3.2.0
+catalogue==2.0.10
+idna==3.10
+torch==2.6.0
+text-unidecode==1.3
+altair==5.5.0
+cloudpathlib==0.21.1
+protobuf==6.31.0
+nvidia-cusolver-cu12==11.6.1.9
+pytz==2025.2
+sympy==1.13.1
+spacy==3.8.7
+MarkupSafe==3.0.2
+thinc==8.3.6
+nvidia-cudnn-cu12==9.1.0.70
+wasabi==1.1.3
+aiohappyeyeballs==2.6.1
+nvidia-nvtx-cu12==12.4.127
+rich==14.0.0
+ipython==8.36.0
+yarl==1.20.0
+torchmetrics==1.7.1
+multidict==6.4.4
+cfgv==3.4.0
+smmap==5.0.2
+srsly==2.5.1
+scikit-image==0.25.2
+matplotlib-inline==0.1.7
+annotated-types==0.7.0
+lazy_loader==0.4
+tenacity==9.1.2
+GitPython==3.1.44
+language_data==1.3.0
+pydantic_core==2.33.2
+sentencepiece==0.2.0
+platformdirs==4.3.8
+distlib==0.3.9
+nvidia-cusparselt-cu12==0.6.2
+blinker==1.9.0
+regex==2024.11.6
+tifffile==2025.5.10
+py-cpuinfo==9.0.0
+attrs==25.3.0
+mdurl==0.1.2
+prompt_toolkit==3.0.51
+packaging==24.2
+async-timeout==5.0.1
+six==1.17.0
+executing==2.2.0
+parso==0.8.4
+omegaconf==2.3.0
+wcwidth==0.2.13
+murmurhash==1.0.13
+stack-data==0.6.3
+nvidia-cufft-cu12==11.2.1.3
+virtualenv==20.31.2
+langcodes==3.5.0
+fonttools==4.58.0
+opencv-python-headless==4.5.5.64
+jedi==0.19.2
+torchvision==0.21.0
+plotly==6.1.1
+nodeenv==1.9.1
+smart-open==7.1.0
+toml==0.10.2
+pytorch-lightning==2.5.1.post0
+typing_extensions==4.13.2
+safetensors==0.5.3
+psutil==7.0.0
+pillow==11.2.1
+python-dateutil==2.9.0.post0
+ftfy==6.3.1
+scipy==1.15.3
+webdataset==0.2.111
+charset-normalizer==3.4.2
+nvidia-nvjitlink-cu12==12.4.127
+kiwisolver==1.4.8
+nvidia-ml-py==12.575.51
+confection==0.1.5
+nvidia-curand-cu12==10.3.5.147
+pandas==2.2.3
+nltk==3.9.1
+webencodings==0.5.1
+pyarrow==20.0.0
+asttokens==3.0.0
+exceptiongroup==1.3.0
+pre_commit==4.2.0
+ninja==1.11.1.4
+spacy-loggers==1.0.5
+msgpack==1.1.0
+lightning-utilities==0.14.3
+nvidia-cublas-cu12==12.4.5.8
+tzdata==2025.2
+cycler==0.12.1
+hf-xet==1.1.2
+antlr4-python3-runtime==4.9.3
+iopath==0.1.10
+pexpect==4.9.0
+imageio==2.37.0
+streamlit==1.45.1
+python-magic==0.4.27
+networkx==3.4.2
+portalocker==3.1.1
+nvidia-cusparse-cu12==12.3.1.170
+propcache==0.3.1
+ptyprocess==0.7.0
+fairscale==0.4.4
+rpds-py==0.25.1
+certifi==2025.4.26
+rouge_score==0.1.2
+traitlets==5.14.3
+identify==2.6.12
+spacy-legacy==3.0.12
+weasel==0.4.1
+mpmath==1.3.0
+cymem==2.0.11
+typing-inspection==0.4.1
+nvidia-cuda-nvrtc-cu12==12.4.127
+marisa-trie==1.2.1
+einops==0.8.1
+nvidia-cufile-cu12==1.11.1.6
+pydantic==2.11.5
+cachetools==5.5.2
+joblib==1.5.1
+Jinja2==3.1.6
+filelock==3.18.0
+pyparsing==3.2.3
+pure_eval==0.2.3
+decorator==5.2.1
+wheel==0.45.1
+pycryptodome==3.23.0
+cheroot==10.0.1
+multiprocess==0.70.16
+aiohttp==3.12.2
+crcmod==1.7
+fsspec==2025.3.0
+jmespath==0.10.0
+preshed==3.0.10
+jaraco.functools==4.1.0
+cryptography==45.0.3
+sentry-sdk==2.29.1
+tokenizers==0.21.1
+opendelta==0.3.2
+pycparser==2.22
+narwhals==1.41.0
+scikit-learn==1.6.1
+dill==0.3.8
+oss2==2.15.0
+yacs==0.1.8
+more-itertools==10.7.0
+pip==25.1.1
+threadpoolctl==3.6.0
+flash-attn==2.7.1.post1
+bigmodelvis==0.0.1
+pathlib==1.0.1
+delta-center-client==0.0.4
+xxhash==3.5.0
+wandb==0.19.11
+setproctitle==1.3.6
+aliyun-python-sdk-core==2.16.0
+transformers==4.52.3
+aliyun-python-sdk-kms==2.16.5
+datasets==3.6.0
+typer==0.16.0
+docker-pycreds==0.4.0
+click==8.2.1
+huggingface-hub==0.32.1
+web.py==0.62
+cffi==1.17.1
+opencv-python==4.11.0.86
+jsonschema==4.24.0
+typing_extensions==4.12.2
+jaraco.functools==4.0.1
+jaraco.text==3.12.1
+jaraco.collections==5.1.0
+inflect==7.3.1
+more-itertools==10.3.0
+packaging==24.2
+importlib_metadata==8.0.0
+backports.tarfile==1.2.0
+typeguard==4.3.0
+zipp==3.19.2
+platformdirs==4.2.2
+autocommand==2.2.2
+jaraco.context==5.3.0
+tomli==2.0.1
+wheel==0.45.1
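Note that this wandb-captured snapshot pins several distributions twice with conflicting versions (typing_extensions, jaraco.functools, more-itertools, packaging, platformdirs, wheel), likely from overlapping site-packages; treat it as a record of the run's environment rather than a directly installable requirements file. A small sketch (file path assumed) for auditing a current environment against these pins:

    from importlib.metadata import PackageNotFoundError, version

    def check_pins(requirements_path: str) -> None:
        """Report packages that are missing or diverge from their '==' pin."""
        for raw in open(requirements_path):
            line = raw.strip()
            if not line or line.startswith("#"):
                continue
            name, _, pinned = line.partition("==")
            try:
                installed = version(name)
            except PackageNotFoundError:
                print(f"MISSING   {name}")
                continue
            if installed != pinned:
                print(f"MISMATCH  {name}: installed {installed}, pinned {pinned}")

    check_pins("requirements.txt")  # path is illustrative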
all_checkpoints/stage2_08011616_2datasets_withoutpretrain/wandb/run-20250801_162632-790wkw0g/files/wandb-metadata.json
ADDED
@@ -0,0 +1,108 @@
+{
+  "os": "Linux-5.10.134-008.16.kangaroo.al8.x86_64-x86_64-with-glibc2.35",
+  "python": "CPython 3.10.0",
+  "startedAt": "2025-08-01T08:26:32.241935Z",
+  "args": [
+    "--devices",
+    "0,1,2,3,4,5,6,7",
+    "--mode",
+    "train",
+    "--filename",
+    "stage2_08011616_2datasets_qweninstruct",
+    "--num_query_token",
+    "8",
+    "--save_every_n_epochs",
+    "2",
+    "--max_epochs",
+    "10",
+    "--batch_size",
+    "4",
+    "--precision",
+    "bf16-mixed",
+    "--num_workers",
+    "8",
+    "--plm_model",
+    "/nas/shared/kilab/wangyujia/ProtT3/plm_model/esm2-150m",
+    "--bert_name",
+    "/nas/shared/kilab/wangyujia/ProtT3/plm_model/microsoft",
+    "--llm_name",
+    "/nas/shared/kilab/hf-hub/models--Qwen--Qwen2.5-7B-Instruct/snapshots/a09a35458c702b33eeacc393d103063234e8bc28",
+    "--llm_tune",
+    "mid_lora",
+    "--stage1_path",
+    "/nas/shared/kilab/wangyujia/ProtT3/all_checkpoints/stage1_07041727_2dataset/epoch=29.ckpt/converted.ckpt",
+    "--use_wandb_logger",
+    "--dataset",
+    "swiss-prot"
+  ],
+  "program": "/nas/shared/kilab/wangyujia/ProtT3/stage2.py",
+  "codePath": "wangyujia/ProtT3/stage2.py",
+  "git": {
+    "remote": "https://github.com/PorUna-byte/PAR.git",
+    "commit": "b8caf406aa1699c788f0ca6e44a1769452c317db"
+  },
+  "root": "./all_checkpoints/stage2_08011616_2datasets_qweninstruct/",
+  "host": "dsw-265304-58fbcf9d9b-zvtdx",
+  "executable": "/root/miniconda3/envs/protT3/bin/python",
+  "codePathLocal": "stage2.py",
+  "cpu_count": 64,
+  "cpu_count_logical": 64,
+  "gpu": "NVIDIA A800-SXM4-80GB",
+  "gpu_count": 8,
+  "disk": {
+    "/": {
+      "total": "1623302262784",
+      "used": "1476915200"
+    }
+  },
+  "memory": {
+    "total": "549755813888"
+  },
+  "cpu": {
+    "count": 64,
+    "countLogical": 64
+  },
+  "gpu_nvidia": [
+    {
+      "name": "NVIDIA A800-SXM4-80GB",
+      "memoryTotal": "85198045184",
+      "architecture": "Ampere"
+    },
+    {
+      "name": "NVIDIA A800-SXM4-80GB",
+      "memoryTotal": "85198045184",
+      "architecture": "Ampere"
+    },
+    {
+      "name": "NVIDIA A800-SXM4-80GB",
+      "memoryTotal": "85198045184",
+      "architecture": "Ampere"
+    },
+    {
+      "name": "NVIDIA A800-SXM4-80GB",
+      "memoryTotal": "85198045184",
+      "architecture": "Ampere"
+    },
+    {
+      "name": "NVIDIA A800-SXM4-80GB",
+      "memoryTotal": "85198045184",
+      "architecture": "Ampere"
+    },
+    {
+      "name": "NVIDIA A800-SXM4-80GB",
+      "memoryTotal": "85198045184",
+      "architecture": "Ampere"
+    },
+    {
+      "name": "NVIDIA A800-SXM4-80GB",
+      "memoryTotal": "85198045184",
+      "architecture": "Ampere"
+    },
+    {
+      "name": "NVIDIA A800-SXM4-80GB",
+      "memoryTotal": "85198045184",
+      "architecture": "Ampere"
+    }
+  ],
+  "cudaVersion": "12.1"
+}
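The "args" array above records the exact stage2.py invocation for this run: 8 GPUs, 10 epochs at batch size 4 in bf16-mixed precision, 8 query tokens, an ESM2-150M protein encoder, a Qwen2.5-7B-Instruct LLM tuned via mid_lora, initialization from a stage-1 checkpoint, and swiss-prot as the dataset. Rejoined into a single shell command (every value below is taken verbatim from the metadata):

    import shlex

    # argv exactly as recorded in wandb-metadata.json's "args", prefixed with
    # the interpreter and program the metadata reports.
    argv = [
        "python", "stage2.py",
        "--devices", "0,1,2,3,4,5,6,7",
        "--mode", "train",
        "--filename", "stage2_08011616_2datasets_qweninstruct",
        "--num_query_token", "8",
        "--save_every_n_epochs", "2",
        "--max_epochs", "10",
        "--batch_size", "4",
        "--precision", "bf16-mixed",
        "--num_workers", "8",
        "--plm_model", "/nas/shared/kilab/wangyujia/ProtT3/plm_model/esm2-150m",
        "--bert_name", "/nas/shared/kilab/wangyujia/ProtT3/plm_model/microsoft",
        "--llm_name", "/nas/shared/kilab/hf-hub/models--Qwen--Qwen2.5-7B-Instruct/snapshots/a09a35458c702b33eeacc393d103063234e8bc28",
        "--llm_tune", "mid_lora",
        "--stage1_path", "/nas/shared/kilab/wangyujia/ProtT3/all_checkpoints/stage1_07041727_2dataset/epoch=29.ckpt/converted.ckpt",
        "--use_wandb_logger",
        "--dataset", "swiss-prot",
    ]
    print(shlex.join(argv))  # one copy-pastable command line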
all_checkpoints/stage2_08011616_2datasets_withoutpretrain/wandb/run-20250801_162632-790wkw0g/files/wandb-summary.json
ADDED
@@ -0,0 +1 @@
+{"trainer/global_step":134559,"dataset0/rouge_2":37.98429870605469,"_step":2700,"dataset0/bleu2":37.878440856933594,"dataset0/meteor_score":53.153656005859375,"lr":1.2202456673549023e-05,"dataset0/rouge_1":45.05393981933594,"_wandb":{"runtime":61483},"_timestamp":1.7540982570645058e+09,"dataset0/rouge_l":42.82152557373047,"dataset0/bleu4":34.96389389038086,"_runtime":61467.758475346,"epoch":9,"loss":0.37739327549934387,"dataloader0/val loss/dataloader_idx_0":0.6249951124191284,"dataset0/acc":0}
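The one-line summary captures the run's final state: epoch 9 at global step 134559 after roughly 17 hours of runtime, training loss 0.377, validation loss 0.625, and dataset0 generation metrics of ROUGE-1 45.05, ROUGE-2 37.98, ROUGE-L 42.82, BLEU-2 37.88, BLEU-4 34.96, and METEOR 53.15. A throwaway sketch (run from the files/ directory) to tabulate those metrics:

    import json

    with open("wandb-summary.json") as f:
        summary = json.load(f)

    # Pull out just the dataset0/* generation metrics logged above.
    for key in sorted(summary):
        if key.startswith("dataset0/"):
            print(f"{key.split('/', 1)[1]:>14}: {summary[key]:.2f}")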
all_checkpoints/stage2_08011616_2datasets_withoutpretrain/wandb/run-20250801_162632-790wkw0g/logs/debug-internal.log
ADDED
@@ -0,0 +1,166 @@
+{"time":"2025-08-01T16:26:32.258940024+08:00","level":"INFO","msg":"stream: starting","core version":"0.19.11","symlink path":"all_checkpoints/stage2_08011616_2datasets_qweninstruct/wandb/run-20250801_162632-790wkw0g/logs/debug-core.log"}
+{"time":"2025-08-01T16:26:47.100144119+08:00","level":"INFO","msg":"created new stream","id":"790wkw0g"}
+{"time":"2025-08-01T16:26:47.101616008+08:00","level":"INFO","msg":"stream: started","id":"790wkw0g"}
+{"time":"2025-08-01T16:26:47.101680831+08:00","level":"INFO","msg":"sender: started","stream_id":"790wkw0g"}
+{"time":"2025-08-01T16:26:47.101644151+08:00","level":"INFO","msg":"writer: Do: started","stream_id":"790wkw0g"}
+{"time":"2025-08-01T16:26:47.101748992+08:00","level":"INFO","msg":"handler: started","stream_id":"790wkw0g"}
+{"time":"2025-08-01T16:26:50.002107204+08:00","level":"INFO","msg":"Starting system monitor"}
+{"time":"2025-08-01T16:31:09.07203296+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/files/gia0603yucca/stage2_08011616_2datasets_qweninstruct/790wkw0g/file_stream\": read tcp 10.1.11.100:36310->104.21.20.172:443: read: connection timed out"}
+{"time":"2025-08-01T16:38:20.137992503+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/graphql\": net/http: request canceled (Client.Timeout exceeded while awaiting headers)"}
+{"time":"2025-08-01T16:41:40.88100016+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/files/gia0603yucca/stage2_08011616_2datasets_qweninstruct/790wkw0g/file_stream\": read tcp 10.1.11.100:41026->104.21.20.172:443: read: connection timed out"}
+{"time":"2025-08-01T16:43:35.141344346+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/graphql\": net/http: request canceled (Client.Timeout exceeded while awaiting headers)"}
+{"time":"2025-08-01T16:44:07.544267418+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/graphql\": net/http: request canceled (Client.Timeout exceeded while awaiting headers)"}
+{"time":"2025-08-01T16:44:41.670680082+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/graphql\": context deadline exceeded"}
+{"time":"2025-08-01T16:45:20.186411207+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/graphql\": context deadline exceeded"}
+{"time":"2025-08-01T16:45:57.904074825+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/files/gia0603yucca/stage2_08011616_2datasets_qweninstruct/790wkw0g/file_stream\": read tcp 10.1.11.100:57230->172.67.193.61:443: read: connection timed out"}
+{"time":"2025-08-01T16:46:09.248932426+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/graphql\": net/http: request canceled (Client.Timeout exceeded while awaiting headers)"}
+{"time":"2025-08-01T16:50:22.097023685+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/files/gia0603yucca/stage2_08011616_2datasets_qweninstruct/790wkw0g/file_stream\": read tcp 10.1.11.100:45466->104.21.20.172:443: read: connection timed out"}
+{"time":"2025-08-01T17:44:15.888009977+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/files/gia0603yucca/stage2_08011616_2datasets_qweninstruct/790wkw0g/file_stream\": read tcp 10.1.11.100:47510->104.21.20.172:443: read: connection timed out"}
+{"time":"2025-08-01T20:05:52.724089483+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/files/gia0603yucca/stage2_08011616_2datasets_qweninstruct/790wkw0g/file_stream\": read tcp 10.1.11.100:49218->104.21.20.172:443: read: connection reset by peer"}
+{"time":"2025-08-01T20:13:58.928063914+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/files/gia0603yucca/stage2_08011616_2datasets_qweninstruct/790wkw0g/file_stream\": read tcp 10.1.11.100:46332->104.21.20.172:443: read: connection timed out"}
+{"time":"2025-08-01T20:27:15.089020249+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/files/gia0603yucca/stage2_08011616_2datasets_qweninstruct/790wkw0g/file_stream\": read tcp 10.1.11.100:51288->104.21.20.172:443: read: connection timed out"}
+{"time":"2025-08-01T20:33:59.569042011+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/files/gia0603yucca/stage2_08011616_2datasets_qweninstruct/790wkw0g/file_stream\": read tcp 10.1.11.100:38854->172.67.193.61:443: read: connection timed out"}
+{"time":"2025-08-01T20:34:32.461391926+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/files/gia0603yucca/stage2_08011616_2datasets_qweninstruct/790wkw0g/file_stream\": unexpected EOF"}
+{"time":"2025-08-01T20:38:30.928056966+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/files/gia0603yucca/stage2_08011616_2datasets_qweninstruct/790wkw0g/file_stream\": read tcp 10.1.11.100:45810->104.21.20.172:443: read: connection timed out"}
+{"time":"2025-08-01T20:43:03.824002354+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/files/gia0603yucca/stage2_08011616_2datasets_qweninstruct/790wkw0g/file_stream\": read tcp 10.1.11.100:51984->104.21.20.172:443: read: connection timed out"}
+{"time":"2025-08-01T20:46:25.450089946+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/files/gia0603yucca/stage2_08011616_2datasets_qweninstruct/790wkw0g/file_stream\": read tcp 10.1.11.100:42776->172.67.193.61:443: read: connection reset by peer"}
+{"time":"2025-08-01T20:49:05.052996031+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/files/gia0603yucca/stage2_08011616_2datasets_qweninstruct/790wkw0g/file_stream\": read tcp 10.1.11.100:54072->104.21.20.172:443: read: connection reset by peer"}
+{"time":"2025-08-01T20:49:50.926181978+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/graphql\": net/http: request canceled (Client.Timeout exceeded while awaiting headers)"}
+{"time":"2025-08-01T20:50:58.961126603+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/files/gia0603yucca/stage2_08011616_2datasets_qweninstruct/790wkw0g/file_stream\": read tcp 10.1.11.100:46418->104.21.20.172:443: read: connection reset by peer"}
+{"time":"2025-08-01T20:54:00.208036179+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/files/gia0603yucca/stage2_08011616_2datasets_qweninstruct/790wkw0g/file_stream\": read tcp 10.1.11.100:47556->104.21.20.172:443: read: connection timed out"}
+{"time":"2025-08-01T21:03:33.648034502+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/files/gia0603yucca/stage2_08011616_2datasets_qweninstruct/790wkw0g/file_stream\": read tcp 10.1.11.100:45288->104.21.20.172:443: read: connection timed out"}
+{"time":"2025-08-01T21:06:14.416001149+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/files/gia0603yucca/stage2_08011616_2datasets_qweninstruct/790wkw0g/file_stream\": read tcp 10.1.11.100:39640->172.67.193.61:443: read: connection timed out"}
+{"time":"2025-08-01T21:06:50.935397642+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/graphql\": context deadline exceeded"}
+{"time":"2025-08-01T21:07:22.991031701+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/graphql\": net/http: request canceled (Client.Timeout exceeded while awaiting headers)"}
+{"time":"2025-08-01T21:07:57.40524595+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/graphql\": context deadline exceeded"}
+{"time":"2025-08-01T21:08:36.813512769+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/graphql\": net/http: request canceled (Client.Timeout exceeded while awaiting headers)"}
+{"time":"2025-08-01T21:09:24.678708826+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/graphql\": net/http: request canceled (Client.Timeout exceeded while awaiting headers)"}
+{"time":"2025-08-01T21:10:33.369865787+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/graphql\": net/http: request canceled (Client.Timeout exceeded while awaiting headers)"}
+{"time":"2025-08-01T21:13:09.331256642+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/files/gia0603yucca/stage2_08011616_2datasets_qweninstruct/790wkw0g/file_stream\": read tcp 10.1.11.100:55554->104.21.20.172:443: read: connection reset by peer"}
+{"time":"2025-08-01T21:18:12.321384192+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/files/gia0603yucca/stage2_08011616_2datasets_qweninstruct/790wkw0g/file_stream\": unexpected EOF"}
+{"time":"2025-08-01T21:21:14.379467148+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/files/gia0603yucca/stage2_08011616_2datasets_qweninstruct/790wkw0g/file_stream\": net/http: request canceled (Client.Timeout exceeded while awaiting headers)"}
+{"time":"2025-08-01T21:24:32.868908249+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/files/gia0603yucca/stage2_08011616_2datasets_qweninstruct/790wkw0g/file_stream\": context deadline exceeded"}
+{"time":"2025-08-01T21:28:13.328011263+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/files/gia0603yucca/stage2_08011616_2datasets_qweninstruct/790wkw0g/file_stream\": read tcp 10.1.11.100:53294->104.21.20.172:443: read: connection timed out"}
+{"time":"2025-08-01T21:30:01.911292424+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/files/gia0603yucca/stage2_08011616_2datasets_qweninstruct/790wkw0g/file_stream\": read tcp 10.1.11.100:45656->172.67.193.61:443: read: connection reset by peer"}
+{"time":"2025-08-01T21:31:05.951614465+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/graphql\": net/http: request canceled (Client.Timeout exceeded while awaiting headers)"}
+{"time":"2025-08-01T21:32:43.664038993+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/files/gia0603yucca/stage2_08011616_2datasets_qweninstruct/790wkw0g/file_stream\": read tcp 10.1.11.100:58998->172.67.193.61:443: read: connection timed out"}
+{"time":"2025-08-01T21:33:35.953483769+08:00","level":"ERROR","msg":"sender: sendStopStatus: failed to get run stopped status: context deadline exceeded (Client.Timeout or context cancellation while reading body)"}
+{"time":"2025-08-01T21:35:00.100184499+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/files/gia0603yucca/stage2_08011616_2datasets_qweninstruct/790wkw0g/file_stream\": read tcp 10.1.11.100:48498->104.21.20.172:443: read: connection reset by peer"}
+{"time":"2025-08-01T21:37:35.955239872+08:00","level":"ERROR","msg":"sender: sendStopStatus: failed to get run stopped status: context deadline exceeded (Client.Timeout or context cancellation while reading body)"}
+{"time":"2025-08-01T21:38:05.712022165+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/files/gia0603yucca/stage2_08011616_2datasets_qweninstruct/790wkw0g/file_stream\": read tcp 10.1.11.100:60866->104.21.20.172:443: read: connection timed out"}
+{"time":"2025-08-01T21:38:20.956070921+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/graphql\": context deadline exceeded"}
+{"time":"2025-08-01T21:41:06.960060464+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/files/gia0603yucca/stage2_08011616_2datasets_qweninstruct/790wkw0g/file_stream\": read tcp 10.1.11.100:55500->172.67.193.61:443: read: connection timed out"}
+{"time":"2025-08-01T21:41:57.697000345+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/files/gia0603yucca/stage2_08011616_2datasets_qweninstruct/790wkw0g/file_stream\": read tcp 10.1.11.100:36556->172.67.193.61:443: read: connection reset by peer"}
+{"time":"2025-08-01T21:45:01.968042256+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/files/gia0603yucca/stage2_08011616_2datasets_qweninstruct/790wkw0g/file_stream\": read tcp 10.1.11.100:53014->172.67.193.61:443: read: connection timed out"}
+{"time":"2025-08-01T21:46:05.960063741+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/graphql\": net/http: request canceled (Client.Timeout exceeded while awaiting headers)"}
+{"time":"2025-08-01T21:47:02.576818619+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/files/gia0603yucca/stage2_08011616_2datasets_qweninstruct/790wkw0g/file_stream\": read tcp 10.1.11.100:44354->104.21.20.172:443: read: connection reset by peer"}
+{"time":"2025-08-01T21:54:27.728015687+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/files/gia0603yucca/stage2_08011616_2datasets_qweninstruct/790wkw0g/file_stream\": read tcp 10.1.11.100:59618->104.21.20.172:443: read: connection timed out"}
+{"time":"2025-08-01T21:57:21.808006007+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/files/gia0603yucca/stage2_08011616_2datasets_qweninstruct/790wkw0g/file_stream\": read tcp 10.1.11.100:46320->104.21.20.172:443: read: connection timed out"}
+{"time":"2025-08-01T22:00:19.984010382+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/files/gia0603yucca/stage2_08011616_2datasets_qweninstruct/790wkw0g/file_stream\": read tcp 10.1.11.100:53128->104.21.20.172:443: read: connection timed out"}
+{"time":"2025-08-01T22:00:50.968803427+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/graphql\": net/http: request canceled (Client.Timeout exceeded while awaiting headers)"}
+{"time":"2025-08-01T22:03:12.527983291+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/files/gia0603yucca/stage2_08011616_2datasets_qweninstruct/790wkw0g/file_stream\": read tcp 10.1.11.100:34872->104.21.20.172:443: read: connection timed out"}
+{"time":"2025-08-01T22:04:48.400864668+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/files/gia0603yucca/stage2_08011616_2datasets_qweninstruct/790wkw0g/file_stream\": read tcp 10.1.11.100:43930->104.21.20.172:443: read: connection reset by peer"}
+{"time":"2025-08-01T22:04:50.970855218+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/graphql\": net/http: request canceled (Client.Timeout exceeded while awaiting headers)"}
+{"time":"2025-08-01T22:07:31.600992168+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/files/gia0603yucca/stage2_08011616_2datasets_qweninstruct/790wkw0g/file_stream\": read tcp 10.1.11.100:49046->104.21.20.172:443: read: connection timed out"}
+{"time":"2025-08-01T22:09:05.97302798+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/graphql\": context deadline exceeded"}
+{"time":"2025-08-01T22:09:46.459558104+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/files/gia0603yucca/stage2_08011616_2datasets_qweninstruct/790wkw0g/file_stream\": read tcp 10.1.11.100:47316->104.21.20.172:443: read: connection reset by peer"}
+{"time":"2025-08-01T22:10:35.97491736+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/graphql\": net/http: request canceled (Client.Timeout exceeded while awaiting headers)"}
+{"time":"2025-08-01T22:11:08.18377384+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/graphql\": context deadline exceeded"}
+{"time":"2025-08-01T22:11:25.779600104+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/files/gia0603yucca/stage2_08011616_2datasets_qweninstruct/790wkw0g/file_stream\": read tcp 10.1.11.100:34352->172.67.193.61:443: read: connection reset by peer"}
+{"time":"2025-08-01T22:11:42.262306077+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/graphql\": context deadline exceeded"}
+{"time":"2025-08-01T22:12:50.976429474+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/graphql\": net/http: request canceled (Client.Timeout exceeded while awaiting headers)"}
+{"time":"2025-08-01T22:13:23.229614619+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/graphql\": context deadline exceeded"}
+{"time":"2025-08-01T22:13:57.892901227+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/graphql\": net/http: request canceled (Client.Timeout exceeded while awaiting headers)"}
+{"time":"2025-08-01T22:15:04.820393306+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/files/gia0603yucca/stage2_08011616_2datasets_qweninstruct/790wkw0g/file_stream\": read tcp 10.1.11.100:60344->104.21.20.172:443: read: connection reset by peer"}
+{"time":"2025-08-01T22:15:45.367773065+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/files/gia0603yucca/stage2_08011616_2datasets_qweninstruct/790wkw0g/file_stream\": read tcp 10.1.11.100:54996->172.67.193.61:443: read: connection reset by peer"}
+{"time":"2025-08-01T22:17:57.605995597+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/files/gia0603yucca/stage2_08011616_2datasets_qweninstruct/790wkw0g/file_stream\": unexpected EOF"}
+{"time":"2025-08-01T22:21:13.87299516+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/files/gia0603yucca/stage2_08011616_2datasets_qweninstruct/790wkw0g/file_stream\": read tcp 10.1.11.100:49910->104.21.20.172:443: read: connection timed out"}
+{"time":"2025-08-01T22:24:16.144029987+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/files/gia0603yucca/stage2_08011616_2datasets_qweninstruct/790wkw0g/file_stream\": read tcp 10.1.11.100:41940->172.67.193.61:443: read: connection timed out"}
+{"time":"2025-08-01T22:27:28.144021751+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/files/gia0603yucca/stage2_08011616_2datasets_qweninstruct/790wkw0g/file_stream\": read tcp 10.1.11.100:46560->172.67.193.61:443: read: connection timed out"}
+{"time":"2025-08-01T22:28:35.984483121+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/graphql\": net/http: request canceled (Client.Timeout exceeded while awaiting headers)"}
+{"time":"2025-08-01T22:31:45.168984421+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/files/gia0603yucca/stage2_08011616_2datasets_qweninstruct/790wkw0g/file_stream\": read tcp 10.1.11.100:38656->172.67.193.61:443: read: connection timed out"}
+{"time":"2025-08-01T22:35:18.671993015+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/files/gia0603yucca/stage2_08011616_2datasets_qweninstruct/790wkw0g/file_stream\": read tcp 10.1.11.100:49116->104.21.20.172:443: read: connection timed out"}
+{"time":"2025-08-01T22:39:02.928024828+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/files/gia0603yucca/stage2_08011616_2datasets_qweninstruct/790wkw0g/file_stream\": read tcp 10.1.11.100:40350->172.67.193.61:443: read: connection timed out"}
+{"time":"2025-08-01T22:39:35.380477212+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/files/gia0603yucca/stage2_08011616_2datasets_qweninstruct/790wkw0g/file_stream\": read tcp 10.1.11.100:35004->104.21.20.172:443: read: connection reset by peer"}
+{"time":"2025-08-01T22:39:55.432951061+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/files/gia0603yucca/stage2_08011616_2datasets_qweninstruct/790wkw0g/file_stream\": unexpected EOF"}
+{"time":"2025-08-01T22:41:50.991964572+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/graphql\": net/http: request canceled (Client.Timeout exceeded while awaiting headers)"}
+{"time":"2025-08-01T22:41:53.61821248+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/files/gia0603yucca/stage2_08011616_2datasets_qweninstruct/790wkw0g/file_stream\": read tcp 10.1.11.100:37852->104.21.20.172:443: read: connection reset by peer"}
+{"time":"2025-08-01T22:43:35.932757702+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/files/gia0603yucca/stage2_08011616_2datasets_qweninstruct/790wkw0g/file_stream\": http2: client conn is closed"}
+{"time":"2025-08-01T22:44:04.110612747+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/files/gia0603yucca/stage2_08011616_2datasets_qweninstruct/790wkw0g/file_stream\": read tcp 10.1.11.100:53322->104.21.20.172:443: read: connection reset by peer"}
+{"time":"2025-08-01T22:46:35.994608155+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/graphql\": context deadline exceeded"}
+{"time":"2025-08-01T22:47:08.18720141+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/graphql\": net/http: request canceled (Client.Timeout exceeded while awaiting headers)"}
+{"time":"2025-08-01T22:47:34.92805984+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/files/gia0603yucca/stage2_08011616_2datasets_qweninstruct/790wkw0g/file_stream\": read tcp 10.1.11.100:51190->172.67.193.61:443: read: connection timed out"}
+{"time":"2025-08-01T22:47:42.321052323+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/graphql\": context deadline exceeded"}
+{"time":"2025-08-01T22:50:05.996775834+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/graphql\": context deadline exceeded"}
+{"time":"2025-08-01T22:50:30.544005531+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/files/gia0603yucca/stage2_08011616_2datasets_qweninstruct/790wkw0g/file_stream\": read tcp 10.1.11.100:49020->172.67.193.61:443: read: connection timed out"}
+{"time":"2025-08-01T22:50:38.303517763+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/graphql\": net/http: request canceled (Client.Timeout exceeded while awaiting headers)"}
+{"time":"2025-08-01T22:51:12.580566425+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/graphql\": context deadline exceeded"}
+{"time":"2025-08-01T22:51:51.301848016+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/graphql\": context deadline exceeded"}
+{"time":"2025-08-01T22:52:39.862980102+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/graphql\": context deadline exceeded"}
+{"time":"2025-08-01T22:53:43.460661511+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/graphql\": net/http: request canceled (Client.Timeout exceeded while awaiting headers)"}
+{"time":"2025-08-01T22:55:13.462771183+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/graphql\": context deadline exceeded"}
+{"time":"2025-08-01T22:55:50.033040855+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/files/gia0603yucca/stage2_08011616_2datasets_qweninstruct/790wkw0g/file_stream\": read tcp 10.1.11.100:38370->172.67.193.61:443: read: connection timed out"}
+{"time":"2025-08-01T22:56:24.977146749+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/files/gia0603yucca/stage2_08011616_2datasets_qweninstruct/790wkw0g/file_stream\": read tcp 10.1.11.100:39018->104.21.20.172:443: read: connection reset by peer"}
+{"time":"2025-08-01T22:56:43.465087253+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/graphql\": net/http: request canceled (Client.Timeout exceeded while awaiting headers)"}
+{"time":"2025-08-01T22:58:13.466056502+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/graphql\": context deadline exceeded"}
+{"time":"2025-08-01T22:59:29.003634152+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/files/gia0603yucca/stage2_08011616_2datasets_qweninstruct/790wkw0g/file_stream\": context deadline exceeded (Client.Timeout exceeded while awaiting headers)"}
+{"time":"2025-08-01T22:59:35.997189916+08:00","level":"WARN","msg":"sender: taking a long time","seconds":600.000438436,"work":"WorkRecord(*service_go_proto.Request_StopStatus); Control(local:true mailbox_slot:\"zbp0fazrv773\" connection_id:\"127.0.0.1:32880\")"}
+{"time":"2025-08-01T22:59:43.468121454+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/graphql\": net/http: request canceled (Client.Timeout exceeded while awaiting headers)"}
+{"time":"2025-08-01T22:59:56.405426213+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/files/gia0603yucca/stage2_08011616_2datasets_qweninstruct/790wkw0g/file_stream\": unexpected EOF"}
+{"time":"2025-08-01T22:59:56.885071194+08:00","level":"WARN","msg":"runwork: taking a long time","seconds":600.000154693,"work":"WorkRecord(*service_go_proto.Record_OutputRaw); Control(connection_id:\"127.0.0.1:32880\")"}
+{"time":"2025-08-01T23:00:05.004711621+08:00","level":"WARN","msg":"runwork: taking a long time","seconds":600.00033328,"work":"WorkRecord(*service_go_proto.Record_Stats); Control(always_send:true)"}
+{"time":"2025-08-01T23:00:05.076924009+08:00","level":"WARN","msg":"runwork: taking a long time","seconds":600.000302659,"work":"WorkRecord(*service_go_proto.Record_Stats); Control(always_send:true)"}
+{"time":"2025-08-01T23:01:13.46903123+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/graphql\": context deadline exceeded"}
+{"time":"2025-08-01T23:02:43.472282526+08:00","level":"INFO","msg":"sender: succeeded after taking longer than expected","seconds":787.475510948,"work":"WorkRecord(*service_go_proto.Request_StopStatus); Control(local:true mailbox_slot:\"zbp0fazrv773\" connection_id:\"127.0.0.1:32880\")"}
+{"time":"2025-08-01T23:02:43.472384259+08:00","level":"INFO","msg":"runwork: succeeded after taking longer than expected","seconds":758.39578933,"work":"WorkRecord(*service_go_proto.Record_Stats); Control(always_send:true)"}
+{"time":"2025-08-01T23:02:43.472386124+08:00","level":"INFO","msg":"runwork: succeeded after taking longer than expected","seconds":758.468026974,"work":"WorkRecord(*service_go_proto.Record_Stats); Control(always_send:true)"}
+{"time":"2025-08-01T23:02:43.47239327+08:00","level":"INFO","msg":"runwork: succeeded after taking longer than expected","seconds":766.587479442,"work":"WorkRecord(*service_go_proto.Record_OutputRaw); Control(connection_id:\"127.0.0.1:32880\")"}
+{"time":"2025-08-01T23:03:03.192446429+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/files/gia0603yucca/stage2_08011616_2datasets_qweninstruct/790wkw0g/file_stream\": read tcp 10.1.11.100:44012->172.67.193.61:443: read: connection reset by peer"}
+{"time":"2025-08-01T23:03:21.005846988+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/graphql\": context deadline exceeded"}
+{"time":"2025-08-01T23:08:40.592045526+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/files/gia0603yucca/stage2_08011616_2datasets_qweninstruct/790wkw0g/file_stream\": read tcp 10.1.11.100:38310->104.21.20.172:443: read: connection timed out"}
+{"time":"2025-08-01T23:11:46.960031857+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/files/gia0603yucca/stage2_08011616_2datasets_qweninstruct/790wkw0g/file_stream\": read tcp 10.1.11.100:55084->104.21.20.172:443: read: connection timed out"}
+{"time":"2025-08-01T23:12:24.849142697+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/files/gia0603yucca/stage2_08011616_2datasets_qweninstruct/790wkw0g/file_stream\": read tcp 10.1.11.100:49670->104.21.20.172:443: read: connection reset by peer"}
+{"time":"2025-08-01T23:15:08.177028412+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/files/gia0603yucca/stage2_08011616_2datasets_qweninstruct/790wkw0g/file_stream\": read tcp 10.1.11.100:52862->104.21.20.172:443: read: connection timed out"}
+{"time":"2025-08-01T23:16:46.602539365+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/files/gia0603yucca/stage2_08011616_2datasets_qweninstruct/790wkw0g/file_stream\": read tcp 10.1.11.100:45510->104.21.20.172:443: read: connection reset by peer"}
+{"time":"2025-08-01T23:18:23.28263426+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/graphql\": read tcp 10.1.11.100:48380->172.67.193.61:443: read: connection reset by peer"}
+{"time":"2025-08-01T23:19:27.841519011+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/files/gia0603yucca/stage2_08011616_2datasets_qweninstruct/790wkw0g/file_stream\": unexpected EOF"}
+{"time":"2025-08-01T23:23:01.165294384+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/files/gia0603yucca/stage2_08011616_2datasets_qweninstruct/790wkw0g/file_stream\": context deadline exceeded"}
+{"time":"2025-08-01T23:26:53.200081773+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/files/gia0603yucca/stage2_08011616_2datasets_qweninstruct/790wkw0g/file_stream\": read tcp 10.1.11.100:43360->104.21.20.172:443: read: connection timed out"}
+{"time":"2025-08-01T23:34:57.039993804+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/files/gia0603yucca/stage2_08011616_2datasets_qweninstruct/790wkw0g/file_stream\": read tcp 10.1.11.100:52350->172.67.193.61:443: read: connection timed out"}
+{"time":"2025-08-01T23:38:16.720010379+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/files/gia0603yucca/stage2_08011616_2datasets_qweninstruct/790wkw0g/file_stream\": read tcp 10.1.11.100:57172->104.21.20.172:443: read: connection timed out"}
+{"time":"2025-08-01T23:41:21.040008763+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/files/gia0603yucca/stage2_08011616_2datasets_qweninstruct/790wkw0g/file_stream\": read tcp 10.1.11.100:36508->104.21.20.172:443: read: connection timed out"}
+{"time":"2025-08-01T23:44:44.816004249+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/files/gia0603yucca/stage2_08011616_2datasets_qweninstruct/790wkw0g/file_stream\": read tcp 10.1.11.100:55942->172.67.193.61:443: read: connection timed out"}
+{"time":"2025-08-01T23:48:52.62500069+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/files/gia0603yucca/stage2_08011616_2datasets_qweninstruct/790wkw0g/file_stream\": read tcp 10.1.11.100:47172->172.67.193.61:443: read: connection timed out"}
+{"time":"2025-08-01T23:51:31.345006101+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/files/gia0603yucca/stage2_08011616_2datasets_qweninstruct/790wkw0g/file_stream\": read tcp 10.1.11.100:52938->104.21.20.172:443: read: connection timed out"}
+{"time":"2025-08-01T23:54:39.760994652+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/files/gia0603yucca/stage2_08011616_2datasets_qweninstruct/790wkw0g/file_stream\": read tcp 10.1.11.100:57488->104.21.20.172:443: read: connection timed out"}
+{"time":"2025-08-01T23:56:53.243171871+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/files/gia0603yucca/stage2_08011616_2datasets_qweninstruct/790wkw0g/file_stream\": unexpected EOF"}
+{"time":"2025-08-02T00:01:11.439949517+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/files/gia0603yucca/stage2_08011616_2datasets_qweninstruct/790wkw0g/file_stream\": read tcp 10.1.11.100:48808->172.67.193.61:443: read: connection timed out"}
+{"time":"2025-08-02T00:03:57.840030605+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/files/gia0603yucca/stage2_08011616_2datasets_qweninstruct/790wkw0g/file_stream\": read tcp 10.1.11.100:58796->172.67.193.61:443: read: connection timed out"}
+{"time":"2025-08-02T00:04:51.340371717+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/graphql\": read tcp 10.1.11.100:40756->104.21.20.172:443: read: connection reset by peer"}
+{"time":"2025-08-02T00:06:51.040127189+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/graphql\": net/http: request canceled (Client.Timeout exceeded while awaiting headers)"}
+{"time":"2025-08-02T00:12:23.188896321+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/graphql\": read tcp 10.1.11.100:52718->172.67.193.61:443: read: connection reset by peer"}
+{"time":"2025-08-02T00:19:09.499959872+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/files/gia0603yucca/stage2_08011616_2datasets_qweninstruct/790wkw0g/file_stream\": unexpected EOF"}
+{"time":"2025-08-02T00:27:17.136029213+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/files/gia0603yucca/stage2_08011616_2datasets_qweninstruct/790wkw0g/file_stream\": read tcp 10.1.11.100:33250->104.21.20.172:443: read: connection timed out"}
+{"time":"2025-08-02T00:31:24.717936901+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/files/gia0603yucca/stage2_08011616_2datasets_qweninstruct/790wkw0g/file_stream\": read tcp 10.1.11.100:42528->104.21.20.172:443: read: connection reset by peer"}
+{"time":"2025-08-02T00:49:10.92803379+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/files/gia0603yucca/stage2_08011616_2datasets_qweninstruct/790wkw0g/file_stream\": read tcp 10.1.11.100:49640->172.67.193.61:443: read: connection timed out"}
+{"time":"2025-08-02T00:53:16.688059564+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/files/gia0603yucca/stage2_08011616_2datasets_qweninstruct/790wkw0g/file_stream\": read tcp 10.1.11.100:41832->172.67.193.61:443: read: connection timed out"}
+{"time":"2025-08-02T01:03:03.440032369+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/files/gia0603yucca/stage2_08011616_2datasets_qweninstruct/790wkw0g/file_stream\": read tcp 10.1.11.100:36398->104.21.20.172:443: read: connection timed out"}
+{"time":"2025-08-02T01:15:25.202630664+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/files/gia0603yucca/stage2_08011616_2datasets_qweninstruct/790wkw0g/file_stream\": read tcp 10.1.11.100:35078->104.21.20.172:443: read: connection reset by peer"}
+{"time":"2025-08-02T01:23:06.089931169+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/graphql\": net/http: request canceled (Client.Timeout exceeded while awaiting headers)"}
+{"time":"2025-08-02T01:28:16.400050797+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/files/gia0603yucca/stage2_08011616_2datasets_qweninstruct/790wkw0g/file_stream\": read tcp 10.1.11.100:37336->172.67.193.61:443: read: connection timed out"}
+{"time":"2025-08-02T02:22:06.219804028+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/graphql\": net/http: request canceled (Client.Timeout exceeded while awaiting headers)"}
+{"time":"2025-08-02T03:05:52.144050294+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/files/gia0603yucca/stage2_08011616_2datasets_qweninstruct/790wkw0g/file_stream\": read tcp 10.1.11.100:51464->172.67.193.61:443: read: connection timed out"}
+{"time":"2025-08-02T06:13:41.315904777+08:00","level":"INFO","msg":"api: retrying HTTP error","status":504,"url":"https://api.bandw.top/files/gia0603yucca/stage2_08011616_2datasets_qweninstruct/790wkw0g/file_stream","body":"error code: 504"}
+{"time":"2025-08-02T06:13:52.289991198+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/graphql\": context deadline exceeded"}
+{"time":"2025-08-02T06:55:22.310503234+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/graphql\": context deadline exceeded"}
+{"time":"2025-08-02T06:55:25.73261549+08:00","level":"INFO","msg":"api: retrying HTTP error","status":504,"url":"https://api.bandw.top/files/gia0603yucca/stage2_08011616_2datasets_qweninstruct/790wkw0g/file_stream","body":"error code: 504"}
+{"time":"2025-08-02T07:13:22.52185475+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/graphql\": read tcp 10.1.11.100:41606->104.21.20.172:443: read: connection reset by peer"}
+{"time":"2025-08-02T07:38:52.528290374+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.bandw.top/graphql\": read tcp 10.1.11.100:38352->104.21.20.172:443: read: connection reset by peer"}
+{"time":"2025-08-02T09:31:16.192630538+08:00","level":"INFO","msg":"stream: closing","id":"790wkw0g"}
+{"time":"2025-08-02T09:31:16.192679202+08:00","level":"INFO","msg":"Stopping system monitor"}
+{"time":"2025-08-02T09:31:16.194392232+08:00","level":"INFO","msg":"Stopped system monitor"}
|
| 162 |
+
{"time":"2025-08-02T09:31:21.149902109+08:00","level":"INFO","msg":"fileTransfer: Close: file transfer manager closed"}
|
| 163 |
+
{"time":"2025-08-02T09:31:22.465229216+08:00","level":"INFO","msg":"handler: closed","stream_id":"790wkw0g"}
|
| 164 |
+
{"time":"2025-08-02T09:31:22.465258111+08:00","level":"INFO","msg":"writer: Close: closed","stream_id":"790wkw0g"}
|
| 165 |
+
{"time":"2025-08-02T09:31:22.465263174+08:00","level":"INFO","msg":"sender: closed","stream_id":"790wkw0g"}
|
| 166 |
+
{"time":"2025-08-02T09:31:22.469479316+08:00","level":"INFO","msg":"stream: closed","id":"790wkw0g"}
|
all_checkpoints/stage2_08011616_2datasets_withoutpretrain/wandb/run-20250801_162632-790wkw0g/logs/debug.log
ADDED
@@ -0,0 +1,24 @@
2025-08-01 16:26:32,237 INFO MainThread:764673 [wandb_setup.py:_flush():70] Current SDK version is 0.19.11
2025-08-01 16:26:32,237 INFO MainThread:764673 [wandb_setup.py:_flush():70] Configure stats pid to 764673
2025-08-01 16:26:32,237 INFO MainThread:764673 [wandb_setup.py:_flush():70] Loading settings from /root/.config/wandb/settings
2025-08-01 16:26:32,237 INFO MainThread:764673 [wandb_setup.py:_flush():70] Loading settings from /nas/shared/kilab/wangyujia/ProtT3/wandb/settings
2025-08-01 16:26:32,237 INFO MainThread:764673 [wandb_setup.py:_flush():70] Loading settings from environment variables
2025-08-01 16:26:32,237 INFO MainThread:764673 [wandb_init.py:setup_run_log_directory():724] Logging user logs to ./all_checkpoints/stage2_08011616_2datasets_qweninstruct/wandb/run-20250801_162632-790wkw0g/logs/debug.log
2025-08-01 16:26:32,237 INFO MainThread:764673 [wandb_init.py:setup_run_log_directory():725] Logging internal logs to ./all_checkpoints/stage2_08011616_2datasets_qweninstruct/wandb/run-20250801_162632-790wkw0g/logs/debug-internal.log
2025-08-01 16:26:32,237 INFO MainThread:764673 [wandb_init.py:init():852] calling init triggers
2025-08-01 16:26:32,237 INFO MainThread:764673 [wandb_init.py:init():857] wandb.init called with sweep_config: {}
config: {'_wandb': {}}
2025-08-01 16:26:32,237 INFO MainThread:764673 [wandb_init.py:init():893] starting backend
2025-08-01 16:26:32,237 INFO MainThread:764673 [wandb_init.py:init():897] sending inform_init request
2025-08-01 16:26:32,241 INFO MainThread:764673 [backend.py:_multiprocessing_setup():101] multiprocessing start_methods=fork,spawn,forkserver, using: spawn
2025-08-01 16:26:32,241 INFO MainThread:764673 [wandb_init.py:init():907] backend started and connected
2025-08-01 16:26:32,242 INFO MainThread:764673 [wandb_init.py:init():1005] updated telemetry
2025-08-01 16:26:32,306 INFO MainThread:764673 [wandb_init.py:init():1029] communicating run to backend with 90.0 second timeout
2025-08-01 16:26:49,938 INFO MainThread:764673 [wandb_init.py:init():1104] starting run threads in backend
2025-08-01 16:26:50,132 INFO MainThread:764673 [wandb_run.py:_console_start():2573] atexit reg
2025-08-01 16:26:50,132 INFO MainThread:764673 [wandb_run.py:_redirect():2421] redirect: wrap_raw
2025-08-01 16:26:50,132 INFO MainThread:764673 [wandb_run.py:_redirect():2490] Wrapping output streams.
2025-08-01 16:26:50,132 INFO MainThread:764673 [wandb_run.py:_redirect():2513] Redirects installed.
2025-08-01 16:26:50,134 INFO MainThread:764673 [wandb_init.py:init():1150] run started, returning control to user process
2025-08-01 16:26:58,724 INFO MainThread:764673 [wandb_run.py:_config_callback():1436] config_cb None None {'filename': 'stage2_08011616_2datasets_qweninstruct', 'seed': 42, 'mode': 'train', 'strategy': 'deepspeed', 'accelerator': 'gpu', 'devices': '0,1,2,3,4,5,6,7', 'precision': 'bf16-mixed', 'max_epochs': 10, 'accumulate_grad_batches': 1, 'check_val_every_n_epoch': 1, 'enable_flash': False, 'use_wandb_logger': True, 'mix_dataset': False, 'dataset': 'swiss-prot', 'save_every_n_epochs': 2, 'bert_name': '/nas/shared/kilab/wangyujia/ProtT3/plm_model/microsoft', 'cross_attention_freq': 2, 'num_query_token': 8, 'qformer_tune': 'train', 'llm_name': '/nas/shared/kilab/hf-hub/models--Qwen--Qwen2.5-7B-Instruct/snapshots/a09a35458c702b33eeacc393d103063234e8bc28', 'num_beams': 5, 'do_sample': False, 'max_inference_len': 128, 'min_inference_len': 1, 'llm_tune': 'mid_lora', 'peft_config': '', 'peft_dir': '', 'plm_model': '/nas/shared/kilab/wangyujia/ProtT3/plm_model/esm2-150m', 'plm_tune': 'freeze', 'lora_r': 8, 'lora_alpha': 16, 'lora_dropout': 0.1, 'enbale_gradient_checkpointing': False, 'weight_decay': 0.05, 'init_lr': 0.0001, 'min_lr': 1e-05, 'warmup_lr': 1e-06, 'warmup_steps': 1000, 'lr_decay_rate': 0.9, 'scheduler': 'linear_warmup_cosine_lr', 'stage1_path': '/nas/shared/kilab/wangyujia/ProtT3/all_checkpoints/stage1_07041727_2dataset/epoch=29.ckpt/converted.ckpt', 'stage2_path': '', 'init_checkpoint': '/nas/shared/kilab/wangyujia/ProtT3/all_checkpoints/stage2_07070513_2datasets_construct/epoch=09.ckpt/converted.ckpt', 'caption_eval_epoch': 5, 'num_workers': 8, 'batch_size': 4, 'inference_batch_size': 4, 'root': 'data', 'text_max_len': 2048, 'q_max_len': 29, 'a_max_len': 36, 'prot_max_len': 1024, 'prompt': 'The protein has the following properties:', 'filter_side_qa': False}
2025-08-02 09:31:16,189 INFO MsgRouterThr:764673 [mailbox.py:close():129] [no run ID] Closing mailbox, abandoning 1 handles.
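The config_cb entry above embeds the run's full training configuration as a Python dict literal. A minimal sketch for recovering it programmatically, assuming a local copy of the debug.log at the path shown in this diff; ast.literal_eval is appropriate here because the dict contains only plain literals:

import ast

# Hypothetical local copy of the debug.log shown above.
LOG_PATH = 'all_checkpoints/stage2_08011616_2datasets_withoutpretrain/wandb/run-20250801_162632-790wkw0g/logs/debug.log'
MARKER = 'config_cb None None '

config = None
with open(LOG_PATH) as f:
    for line in f:
        if MARKER in line:
            # Everything after the marker is a dict literal of plain values.
            config = ast.literal_eval(line.split(MARKER, 1)[1])
            break

print(config['llm_name'], config['llm_tune'], config['init_lr'])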
all_checkpoints/stage2_08011616_2datasets_withoutpretrain/wandb/run-20250801_162632-790wkw0g/run-790wkw0g.wandb
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:cb5d0be9388288b74f1f5a6f3f8c39d9ee6aa7643fe0c3426a6f51210ad77dc1
size 84438700
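The .wandb file above is stored as a Git LFS pointer, as are many of the data files below. A small sketch (not part of the repository) for reading the pointer fields back; it parses the three version/oid/size lines of the LFS pointer format:

def read_lfs_pointer(path):
    # Each pointer line is 'key value', e.g. 'size 84438700'.
    fields = {}
    with open(path) as f:
        for line in f:
            key, _, value = line.strip().partition(' ')
            fields[key] = value
    return fields

ptr = read_lfs_pointer('all_checkpoints/stage2_08011616_2datasets_withoutpretrain/'
                       'wandb/run-20250801_162632-790wkw0g/run-790wkw0g.wandb')
print(ptr['oid'], int(ptr['size']))  # sha256:cb5d0be9..., 84438700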
data/.gitignore
ADDED
@@ -0,0 +1,2 @@
*
!.gitignore
data/OntoProteinDatasetV2/test.txt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:87c0edf3fd59defd24decb12c8af64d9bb5fa1ef727fce2bf5fc36c06fe4066f
size 12085289

data/OntoProteinDatasetV2/train.txt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:632a5830aea6b3029feaf2eeba5a1c2b63e2badc61b72f92734c517aef85e879
size 478799520

data/OntoProteinDatasetV2/valid.txt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:40cefeb9b08cff978fee2ce21f259974fa8881bbe6614f04e896ca7560d4e11f
size 11829916

data/PDBDataset/abstract.json
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:59cf438e3325c38809ab41fc09fe00ec78fc4bd41d7073f0c22f664168ea5ca3
size 190905202
data/PDBDataset/q_types.txt
ADDED
@@ -0,0 +1,30 @@
What is the nucleic acid polymer entity type for this protein? String structure/property
When is this protein first published? Number side information
How many polymer monomers does this protein have? Number structure/property
How many assemblies does this protein have? Number structure/property
How many heavy solvent atom coordinates records does this protein have? Number side information
Does this protein have cis-peptide linkages? String structure/property
Does this protein contain branched entities? String structure/property
How many entities does this protein have? Number structure/property
Does this protein contain solvent entities? String structure/property
What is the polymer entity type for this protein? String structure/property
How many nucleic acid polymer entities (DNA or RNA) does this protein have? Number structure/property
What is the polymer entity composition for this protein? String structure/property
Does this protein contain DNA polymer entities? String structure/property
Does this protein contain non-polymer entities? String structure/property
How many heavy atom coordinates records does this protein have? Number side information
How many intermolecular metalic bonds does this protein have? Number structure/property
Does this protein have hybrid nucleic acid polymer entities? String structure/property
What is the molecular mass (KDa) of polymer and non-polymer entities (exclusive of solvent) for this protein? Number structure/property
Is this protein determined by experimental or computational methods? String side information
What is the radiation wavelength in angstroms for this protein? Number structure/property
Does this protein have unmodeled polymer monomers? String side information
How many intermolecular covalent bonds does this protein have? Number structure/property
How many model structures deposited for this protein? Number side information
What are the software programs reported in connection with the production of this protein? String side information
How many hydrogen atom coordinates records does this protein have? Number side information
What experimental method(s) were used to determine the structure of this protein? String side information
Does this protein contain polymer entities? String structure/property
Does this protein contain RNA polymer entities? String structure/property
What are the bound nonpolymer components for this protein String structure/property
What are the terms characterizing the protein? String structure/property
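Each q_types.txt line packs a question template, an answer type (String or Number), and a category (structure/property or side information) into one whitespace-separated row. A minimal parsing sketch, assuming exactly the layout visible above:

import re

# question text (lazy), then the answer type, then the category, anchored at line end
Q_TYPE_RE = re.compile(r'^(.*?)\s+(String|Number)\s+(structure/property|side information)$')

def parse_q_type(line):
    m = Q_TYPE_RE.match(line.strip())
    return m.groups() if m else None  # (question, answer_type, category)

with open('data/PDBDataset/q_types.txt') as f:
    q_types = [parse_q_type(line) for line in f if line.strip()]
print(len(q_types))  # 30 question templates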
data/PDBDataset/qa_all.json
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:dfd9d775036717127b66aa9ba9c9d61f8fc324c4ea8fb548d6a1fd2ab4506194
size 523476104

data/PDBDataset/test.txt
ADDED
The diff for this file is too large to render.

data/PDBDataset/train.txt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:494c6c0680bdedd58dc0da2df11393b123ad5a4998c68d432da0b1db777c8870
size 34040909

data/PDBDataset/val.txt
ADDED
The diff for this file is too large to render.

data/SwissProtV3/test_set.jsonl
ADDED
The diff for this file is too large to render.

data/SwissProtV3/train_set.jsonl
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:d92b283a50056b185af292082ea3c733d2c979e70ca233c19d4b06bf81381713
size 312044762

data/SwissProtV3/valid_set.jsonl
ADDED
The diff for this file is too large to render.

data/protein-molecule/protein-text.zip
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:5ae14ae1857962332f2b73a3f7ef5651541e7be7d886bb71031b4fc745626f3a
size 1566756112

data/protein-text/eval_assist.zipg3ebgjl7.tmp
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:28e38b555c6d456da8e2b95dd2af19f884b2b86687f733d745b7972079dff708
size 113836360

data/protein-text/eval_assist.ziphwjr8q2y.tmp
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:58bd802e8fc350fd425e9a827245662790580004bc064aa84b265cad8f2e6d3f
size 9836668

data/protein-text/eval_assist.zipzh1pdmj_.tmp
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:3ddebab3f8ea6218c001053e6861781b90b2b70ef429bf9239761c2c0c7927e0
size 34493776
data_provider/__pycache__/bindingdb.cpython-310.pyc
ADDED
Binary file (2.78 kB)

data_provider/__pycache__/go.cpython-310.pyc
ADDED
Binary file (6.44 kB)

data_provider/__pycache__/metalIonbinding.cpython-310.pyc
ADDED
Binary file (2.83 kB)

data_provider/__pycache__/mutation.cpython-310.pyc
ADDED
Binary file (3.97 kB)

data_provider/__pycache__/production.cpython-310.pyc
ADDED
Binary file (6.97 kB)

data_provider/__pycache__/prot_qa_dm.cpython-310.pyc
ADDED
Binary file (7.92 kB)

data_provider/__pycache__/prot_qa_dm.cpython-311.pyc
ADDED
Binary file (15.9 kB)

data_provider/__pycache__/stage1_dm.cpython-310.pyc
ADDED
Binary file (14.9 kB)

data_provider/__pycache__/stage1_dm.cpython-311.pyc
ADDED
Binary file (30.1 kB)

data_provider/__pycache__/stage2_dm.cpython-310.pyc
ADDED
Binary file (12.4 kB)

data_provider/__pycache__/stage3_dm.cpython-310.pyc
ADDED
Binary file (23.7 kB)

data_provider/__pycache__/stage3_dm.cpython-311.pyc
ADDED
Binary file (12.9 kB)
data_provider/bindingdb.py
ADDED
@@ -0,0 +1,62 @@
import json
import random  # needed by shuffle() below
from pytorch_lightning import LightningDataModule
from torch.utils.data import DataLoader, ConcatDataset, Dataset
from data_provider.stage1_dm import SwissProtDataset, OntoProteinDataset
import pandas as pd

class BindingDB(Dataset):
    def __init__(self, data_path, prompt='', return_prompt=False):
        super(BindingDB, self).__init__()
        self.data_path = data_path
        self.user_prompt = prompt
        self.return_prompt = return_prompt

        self.data_list = self._load_and_preprocess(self.data_path)
        self.text2id = self._build_text_vocab()

    def _load_and_preprocess(self, data_path):
        data_list = []
        df = pd.read_csv(data_path)
        for _, row in df.iterrows():
            try:
                ligand_smiles = str(row['ligand']).strip()
                prot_seq = str(row['protein']).strip()
                result = str(row['ic50']).strip()

                text_seq = f"<answer>{result}</answer>\n"

                prompt = f"""
【Protein sequence (1-letter amino acid codes)】;{ligand_smiles}【Ligand structure (SMILES)】
Task: Evaluate the inhibitory effect of the ligand on the given protein.
Note: IC50 (half maximal inhibitory concentration) is the concentration of a substance required to inhibit 50% of the protein's activity. Lower IC50 values indicate stronger inhibition.
Based on the provided protein and ligand, predict the inhibitory strength by classifying the IC50 level:
"""
                if self.user_prompt:
                    prompt += self.user_prompt

                # extra can return the raw feather string, or feather_vals
                # or feather_raw
                data_list.append((prot_seq, text_seq, prompt))
            except Exception as e:
                print(f"Warning: skipping problematic row: {row}, reason: {e}")
        return data_list

    def _build_text_vocab(self):
        text2id = {}
        for _, text_seq, _ in self.data_list:
            if text_seq not in text2id:
                text2id[text_seq] = len(text2id)
        return text2id

    def shuffle(self):
        random.shuffle(self.data_list)
        return self

    def __len__(self):
        return len(self.data_list)

    def __getitem__(self, index):
        prot_seq, text_seq, prompt = self.data_list[index]
        if self.return_prompt:
            return prot_seq, prompt, text_seq, index
        return prot_seq, text_seq, index
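A minimal usage sketch for the BindingDB dataset above; the CSV path is hypothetical, while the 'ligand', 'protein', and 'ic50' column names are the ones the loader actually reads:

from data_provider.bindingdb import BindingDB

# return_prompt=True makes __getitem__ yield (prot_seq, prompt, text_seq, index)
dataset = BindingDB('data/bindingdb/test.csv',  # hypothetical path
                    prompt='Answer with one IC50 level wrapped in <answer> </answer> tags.',
                    return_prompt=True)
prot_seq, prompt, text_seq, idx = dataset[0]
print(text_seq)  # e.g. "<answer>...</answer>\n" built from the ic50 column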
data_provider/gal_helpers.py
ADDED
@@ -0,0 +1,45 @@
import re


# we split individual characters inside special tokens like [START_DNA]
CUSTOM_SEQ_RE = re.compile(r"(\[START_(DNA|SMILES|I_SMILES|AMINO)])(.*?)(\[END_\2])")

# token added to implement a custom sequence tokenization. This token is added at
# corpus cleaning step and removed in pretokenization. The digits are added to increase the chance
# that they do not occur in the corpus. The digits are escaped so that the token does not appear
# literally in the source code in case we ever include it in the training data.
SPLIT_MARKER = f"SPL{1}T-TH{1}S-Pl3A5E"

def _insert_split_marker(m: re.Match):
    """
    Applies split marker based on a regex match of special tokens such as
    [START_DNA].

    Parameters
    ----------
    m : re.Match
        Regex match covering one special-token span

    Returns
    ----------
    str - the text with the split token added
    """
    start_token, _, sequence, end_token = m.groups()
    sequence = re.sub(r"(.)", fr"{SPLIT_MARKER}\1", sequence, flags=re.DOTALL)
    return f"{start_token}{sequence}{SPLIT_MARKER}{end_token}"


def escape_custom_split_sequence(text):
    """
    Applies custom splitting to the text for GALILEO's tokenization

    Parameters
    ----------
    text : str
        Input text to split

    Returns
    ----------
    str - the text with the split token added
    """
    return CUSTOM_SEQ_RE.sub(_insert_split_marker, text)
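To make the splitting behavior concrete, a self-contained check of escape_custom_split_sequence on a made-up string; the marker lands before every character inside the [START_AMINO]...[END_AMINO] span, plus once before the end token:

from data_provider.gal_helpers import escape_custom_split_sequence, SPLIT_MARKER

out = escape_custom_split_sequence('Protein: [START_AMINO]MKT[END_AMINO].')
expected = ('Protein: [START_AMINO]'
            + SPLIT_MARKER + 'M' + SPLIT_MARKER + 'K' + SPLIT_MARKER + 'T'
            + SPLIT_MARKER + '[END_AMINO].')
assert out == expected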
data_provider/go.py
ADDED
@@ -0,0 +1,237 @@
import json
import random  # needed by shuffle() below
from pytorch_lightning import LightningDataModule
from torch.utils.data import DataLoader, ConcatDataset, Dataset
from data_provider.stage1_dm import SwissProtDataset, OntoProteinDataset
import pandas as pd

class GO_BP(Dataset):
    def __init__(self, data_path, prompt='', return_prompt=False):
        super(GO_BP, self).__init__()
        self.data_path = data_path
        self.user_prompt = prompt
        self.return_prompt = return_prompt

        self.data_list = self._load_and_preprocess(self.data_path)
        self.text2id = self._build_text_vocab()

    def _load_and_preprocess(self, data_path):
        data_list = []
        df = pd.read_csv(data_path)
        for _, row in df.iterrows():
            try:
                prot_seq = str(row['question']).strip()
                result = str(row['answer']).strip()

                text_seq = f"<answer>{result}</answer>\n"

                prompt = f"""
【Task】Predict the biological processes involving the given protein.
【Background】Each process is represented by a GO-BP term (e.g., GO:0008150), describing a series of molecular events relevant to protein function.
【Output Format】List the predicted GO-BP terms, separated by commas, and wrap them in <answer> </answer> tags.
Example: <answer>GO:0008150, GO:0009987, GO:0050896</answer>
"""
                if self.user_prompt:
                    prompt += self.user_prompt

                # extra can return the raw feather string, or feather_vals
                # or feather_raw
                data_list.append((prot_seq, text_seq, prompt))
            except Exception as e:
                print(f"Warning: skipping problematic row: {row}, reason: {e}")
        return data_list

    def _build_text_vocab(self):
        text2id = {}
        for _, text_seq, _ in self.data_list:
            if text_seq not in text2id:
                text2id[text_seq] = len(text2id)
        return text2id

    def shuffle(self):
        random.shuffle(self.data_list)
        return self

    def __len__(self):
        return len(self.data_list)

    def __getitem__(self, index):
        prot_seq, text_seq, prompt = self.data_list[index]
        if self.return_prompt:
            return prot_seq, prompt, text_seq, index
        return prot_seq, text_seq, index


class GO_CC(Dataset):
    def __init__(self, data_path, prompt='', return_prompt=False):
        super(GO_CC, self).__init__()
        self.data_path = data_path
        self.user_prompt = prompt
        self.return_prompt = return_prompt

        self.data_list = self._load_and_preprocess(self.data_path)
        self.text2id = self._build_text_vocab()

    def _load_and_preprocess(self, data_path):
        data_list = []
        df = pd.read_csv(data_path)
        for _, row in df.iterrows():
            try:
                prot_seq = str(row['question']).strip()
                result = str(row['answer']).strip()

                text_seq = f"<answer>{result}</answer>\n"

                prompt = f"""
【Task】Predict the cellular components associated with this protein.
【Background】Each location is represented by a GO-CC term (e.g., GO:0005737), indicating where the protein functions within the cell.
【Output Format】List the predicted GO-CC terms, separated by commas, and wrap them in <answer> </answer> tags.
Example: <answer>GO:0005737, GO:0005829, GO:0005886</answer>
"""
                if self.user_prompt:
                    prompt += self.user_prompt

                # extra can return the raw feather string, or feather_vals
                # or feather_raw
                data_list.append((prot_seq, text_seq, prompt))
            except Exception as e:
                print(f"Warning: skipping problematic row: {row}, reason: {e}")
        return data_list

    def _build_text_vocab(self):
        text2id = {}
        for _, text_seq, _ in self.data_list:
            if text_seq not in text2id:
                text2id[text_seq] = len(text2id)
        return text2id

    def shuffle(self):
        random.shuffle(self.data_list)
        return self

    def __len__(self):
        return len(self.data_list)

    def __getitem__(self, index):
        prot_seq, text_seq, prompt = self.data_list[index]
        if self.return_prompt:
            return prot_seq, prompt, text_seq, index
        return prot_seq, text_seq, index


class GO_MF(Dataset):
    def __init__(self, data_path, prompt='', return_prompt=False):
        super(GO_MF, self).__init__()
        self.data_path = data_path
        self.user_prompt = prompt
        self.return_prompt = return_prompt

        self.data_list = self._load_and_preprocess(self.data_path)
        self.text2id = self._build_text_vocab()

    def _load_and_preprocess(self, data_path):
        data_list = []
        df = pd.read_csv(data_path)
        for _, row in df.iterrows():
            try:
                prot_seq = str(row['question']).strip()
                result = str(row['answer']).strip()

                text_seq = f"<answer>{result}</answer>\n"

                prompt = f"""
【Task】Predict the molecular functions performed by this protein.
【Background】Each function is represented by a GO-MF term (e.g., GO:0003677), describing specific biochemical activities of the protein.
【Output Format】List the predicted GO-MF terms, separated by commas, and wrap them in <answer> </answer> tags.
Example: <answer>GO:0003677, GO:0005524, GO:0016787</answer>
"""
                if self.user_prompt:
                    prompt += self.user_prompt

                # extra can return the raw feather string, or feather_vals
                # or feather_raw
                data_list.append((prot_seq, text_seq, prompt))
            except Exception as e:
                print(f"Warning: skipping problematic row: {row}, reason: {e}")
        return data_list

    def _build_text_vocab(self):
        text2id = {}
        for _, text_seq, _ in self.data_list:
            if text_seq not in text2id:
                text2id[text_seq] = len(text2id)
        return text2id

    def shuffle(self):
        random.shuffle(self.data_list)
        return self

    def __len__(self):
        return len(self.data_list)

    def __getitem__(self, index):
        prot_seq, text_seq, prompt = self.data_list[index]
        if self.return_prompt:
            return prot_seq, prompt, text_seq, index
        return prot_seq, text_seq, index


class EC(Dataset):
    def __init__(self, data_path, prompt='', return_prompt=False):
        super(EC, self).__init__()
        self.data_path = data_path
        self.user_prompt = prompt
        self.return_prompt = return_prompt

        self.data_list = self._load_and_preprocess(self.data_path)
        self.text2id = self._build_text_vocab()

    def _load_and_preprocess(self, data_path):
        data_list = []
        df = pd.read_csv(data_path)
        for _, row in df.iterrows():
            try:
                name = str(row['name']).strip()
                # First split on '-' to separate the structure info from the UniProt ID
                structure_part, uniprot_id = name.split('-')  # '3r7t_A', 'Q9PMG4'

                # Then split the structure info on '_' to get the PDB ID and chain ID
                pdb_id, chain_id = structure_part.split('_')  # '3r7t', 'A'
                prot_seq = str(row['aa_seq']).strip()
                result = str(row['label']).strip()

                text_seq = f"<answer>{result}</answer>\n"

                prompt = f"""
The information provided above is protein information, one of the chains of its crystal structure {pdb_id}, named {chain_id}, and numbered {uniprot_id} in the Uniprot sequence database.
Based on this information, the possible enzyme activity is inferred and the corresponding EC number is predicted.
【Output Format】List predicted EC numbers, separated by commas, wrapped in <answer> </answer> tags.
"""
                if self.user_prompt:
                    prompt += self.user_prompt

                # extra can return the raw feather string, or feather_vals
                # or feather_raw
                data_list.append((prot_seq, text_seq, prompt))
            except Exception as e:
                print(f"Warning: skipping problematic row: {row}, reason: {e}")
        return data_list

    def _build_text_vocab(self):
        text2id = {}
        for _, text_seq, _ in self.data_list:
            if text_seq not in text2id:
                text2id[text_seq] = len(text2id)
        return text2id

    def shuffle(self):
        random.shuffle(self.data_list)
        return self

    def __len__(self):
        return len(self.data_list)

    def __getitem__(self, index):
        prot_seq, text_seq, prompt = self.data_list[index]
        if self.return_prompt:
            return prot_seq, prompt, text_seq, index
        return prot_seq, text_seq, index
data_provider/llm_tuning_dm.py
ADDED
@@ -0,0 +1,261 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from pytorch_lightning import LightningDataModule
from data_provider.gal_helpers import escape_custom_split_sequence
from data_provider.stage1_dm import SwissProtDataset, OntoProteinDataset
from torch.utils.data import DataLoader, ConcatDataset


class LLMTuningCollater:
    def __init__(self, tokenizer, text_max_len, prot_max_len, use_gal):
        self.text_max_len = text_max_len
        self.prot_max_len = prot_max_len
        self.tokenizer = tokenizer
        self.use_gal = use_gal

    def __call__(self, batch):
        prot_seqs, prompt_seqs, text_seqs, _ = zip(*batch)
        prot_seqs = [prompt.format(p) for prompt, p in zip(prompt_seqs, prot_seqs)]
        if self.use_gal:
            prot_seqs = [escape_custom_split_sequence(p) for p in prot_seqs]
        ## deal with prompt
        self.tokenizer.padding_side = 'left'
        prot_batch = self.tokenizer(text=prot_seqs,
                                    truncation=True,
                                    padding='max_length',
                                    add_special_tokens=True,
                                    max_length=self.prot_max_len,
                                    return_tensors='pt',
                                    return_attention_mask=True)
        self.tokenizer.padding_side = 'right'
        text_batch = self.tokenizer(text=text_seqs,
                                    truncation=True,
                                    padding='max_length',
                                    add_special_tokens=True,
                                    max_length=self.text_max_len,
                                    return_tensors='pt',
                                    return_attention_mask=True)
        return prot_batch, text_batch


class InferenceCollater:
    def __init__(self, tokenizer, text_max_len, prot_max_len, use_gal):
        self.text_max_len = text_max_len
        self.prot_max_len = prot_max_len
        self.tokenizer = tokenizer
        self.use_gal = use_gal

    def __call__(self, batch):
        prot_seqs, prompt_seqs, text_seqs, indices = zip(*batch)
        prot_seqs = [prompt.format(p) for prompt, p in zip(prompt_seqs, prot_seqs)]
        if self.use_gal:
            prot_seqs = [escape_custom_split_sequence(p) for p in prot_seqs]
        ## deal with prompt
        self.tokenizer.padding_side = 'left'
        prot_batch = self.tokenizer(text=prot_seqs,
                                    truncation=True,
                                    padding='max_length',
                                    add_special_tokens=True,
                                    max_length=self.prot_max_len,
                                    return_tensors='pt',
                                    return_attention_mask=True)
        target_dict = {'targets': text_seqs, 'indices': indices}
        return prot_batch, target_dict


class LLMTuningDM(LightningDataModule):
    def __init__(
        self,
        root: str = 'data/',
        args=None,
    ):
        super().__init__()
        self.batch_size = args.batch_size
        self.inference_batch_size = args.inference_batch_size
        self.num_workers = args.num_workers
        self.prot_max_len = args.prot_max_len
        self.text_max_len = args.text_max_len
        if root.find('SwissProtV3') >= 0:
            self.train_dataset = SwissProtDataset(root+'/train_set.json', prompt='[START_AMINO]{}[END_AMINO]. Swiss-Prot description: ', return_prompt=True)
            self.val_dataset = SwissProtDataset(root+'/valid_set.json', prompt='[START_AMINO]{}[END_AMINO]. Swiss-Prot description: ', return_prompt=True)
            self.test_dataset = SwissProtDataset(root+'/test_set.json', prompt='[START_AMINO]{}[END_AMINO]. Swiss-Prot description: ', return_prompt=True)
        elif root.find('OntoProteinDatasetV2') >= 0:
            self.train_dataset = OntoProteinDataset(root+'/train.txt', prompt='[START_AMINO]{}[END_AMINO]. Gene Ontology description: ', return_prompt=True)
            self.val_dataset = OntoProteinDataset(root+'/valid.txt', prompt='[START_AMINO]{}[END_AMINO]. Gene Ontology description: ', return_prompt=True)
            self.test_dataset = OntoProteinDataset(root+'/test.txt', prompt='[START_AMINO]{}[END_AMINO]. Gene Ontology description: ', return_prompt=True)
        else:
            raise NotImplementedError()
        self.tokenizer = None
        self.use_gal = args.llm_name.find('gal') >= 0

    def init_tokenizer(self, tokenizer):
        self.tokenizer = tokenizer

    def train_dataloader(self):
        loader = DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers,
            pin_memory=False,
            drop_last=True,
            persistent_workers=False,
            collate_fn=LLMTuningCollater(self.tokenizer, self.text_max_len, self.prot_max_len, self.use_gal),
        )
        return loader

    def val_dataloader(self):
        val_loader = DataLoader(
            self.val_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            pin_memory=False,
            drop_last=False,
            persistent_workers=False,
            collate_fn=LLMTuningCollater(self.tokenizer, self.text_max_len, self.prot_max_len, self.use_gal),
        )
        test_loader = DataLoader(
            self.test_dataset,
            batch_size=self.inference_batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            pin_memory=False,
            drop_last=False,
            persistent_workers=False,
            collate_fn=InferenceCollater(self.tokenizer, self.text_max_len, self.prot_max_len, self.use_gal),
        )
        return [val_loader, test_loader]

    def add_model_specific_args(parent_parser):
        parser = parent_parser.add_argument_group("Data module")
        parser.add_argument('--num_workers', type=int, default=2)
        parser.add_argument('--batch_size', type=int, default=32)
        parser.add_argument('--inference_batch_size', type=int, default=4)
        parser.add_argument('--root', type=str, default='data/SwissProtV3')
        parser.add_argument('--text_max_len', type=int, default=128)
        parser.add_argument('--prot_max_len', type=int, default=1024)
        parser.add_argument('--q_max_len', type=int, default=1064)
        parser.add_argument('--a_max_len', type=int, default=36)
        parser.add_argument('--prompt', type=str, default='[START_AMINO]{}[END_AMINO]. The protein has the following properties: ')
        parser.add_argument('--filter_side_qa', action='store_true', default=False)
        return parent_parser


class LLMTuningMixDM(LightningDataModule):
    def __init__(
        self,
        root: str = 'data/',
        args=None,
    ):
        super().__init__()
        self.batch_size = args.batch_size
        self.inference_batch_size = args.inference_batch_size
        self.num_workers = args.num_workers
        self.prot_max_len = args.prot_max_len
        self.text_max_len = args.text_max_len

        train_dataset1 = SwissProtDataset(root+'/SwissProtV3/train_set.json', prompt='[START_AMINO]{}[END_AMINO]. Swiss-Prot description: ', return_prompt=True)
        train_dataset2 = OntoProteinDataset(root+'/OntoProteinDatasetV2/train.txt', prompt='[START_AMINO]{}[END_AMINO]. Gene Ontology description: ', return_prompt=True)
        self.train_dataset = ConcatDataset([train_dataset1, train_dataset2])
        self.swiss_val_dataset = SwissProtDataset(root+'/SwissProtV3/valid_set.json', prompt='[START_AMINO]{}[END_AMINO]. Swiss-Prot description: ', return_prompt=True)
        self.onto_val_dataset = OntoProteinDataset(root+'/OntoProteinDatasetV2/valid.txt', prompt='[START_AMINO]{}[END_AMINO]. Gene Ontology description: ', return_prompt=True)
        self.swiss_test_dataset = SwissProtDataset(root+'/SwissProtV3/test_set.json', prompt='[START_AMINO]{}[END_AMINO]. Swiss-Prot description: ', return_prompt=True)
        self.onto_test_dataset = OntoProteinDataset(root+'/OntoProteinDatasetV2/test.txt', prompt='[START_AMINO]{}[END_AMINO]. Gene Ontology description: ', return_prompt=True)

        self.tokenizer = None
        self.use_gal = args.llm_name.find('gal') >= 0

    def init_tokenizer(self, tokenizer):
        self.tokenizer = tokenizer

    def train_dataloader(self):
        loader = DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers,
            pin_memory=False,
            drop_last=True,
            persistent_workers=False,
            collate_fn=LLMTuningCollater(self.tokenizer, self.text_max_len, self.prot_max_len, self.use_gal),
        )
        return loader

    def val_dataloader(self):
        swiss_val_loader = DataLoader(
            self.swiss_val_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            pin_memory=False,
            drop_last=False,
            persistent_workers=False,
            collate_fn=LLMTuningCollater(self.tokenizer, self.text_max_len, self.prot_max_len, self.use_gal),
        )
        swiss_test_loader = DataLoader(
            self.swiss_test_dataset,
            batch_size=self.inference_batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            pin_memory=False,
            drop_last=False,
            persistent_workers=False,
            collate_fn=InferenceCollater(self.tokenizer, self.text_max_len, self.prot_max_len, self.use_gal),
        )

        onto_val_loader = DataLoader(
            self.onto_val_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            pin_memory=False,
            drop_last=False,
            persistent_workers=False,
            collate_fn=LLMTuningCollater(self.tokenizer, self.text_max_len, self.prot_max_len, self.use_gal),
        )
        onto_test_loader = DataLoader(
            self.onto_test_dataset,
            batch_size=self.inference_batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            pin_memory=False,
            drop_last=False,
            persistent_workers=False,
            collate_fn=InferenceCollater(self.tokenizer, self.text_max_len, self.prot_max_len, self.use_gal),
        )
        return [swiss_val_loader, swiss_test_loader, onto_val_loader, onto_test_loader]

    def add_model_specific_args(parent_parser):
        parser = parent_parser.add_argument_group("Data module")
        parser.add_argument('--num_workers', type=int, default=2)
        parser.add_argument('--batch_size', type=int, default=32)
        parser.add_argument('--inference_batch_size', type=int, default=4)
        parser.add_argument('--root', type=str, default='data/SwissProtV3')
        parser.add_argument('--text_max_len', type=int, default=128)
        parser.add_argument('--prot_max_len', type=int, default=1024)
        parser.add_argument('--q_max_len', type=int, default=1064)
        parser.add_argument('--a_max_len', type=int, default=36)
        parser.add_argument('--prompt', type=str, default='[START_AMINO]{}[END_AMINO]. The protein has the following properties: ')
        parser.add_argument('--filter_side_qa', action='store_true', default=False)
        return parent_parser


if __name__ == '__main__':
    # return_prompt=True so the collater receives (prot_seq, prompt, text_seq, index) tuples
    dataset = SwissProtDataset('../data/SwissProtV3/train_set.json', prompt='[START_AMINO]{}[END_AMINO]. Swiss-Prot description: ', return_prompt=True)
    from transformers import AutoTokenizer
    tokenizer = AutoTokenizer.from_pretrained('facebook/galactica-1.3b')
    tokenizer.add_special_tokens({'pad_token': '<pad>'})
    loader = DataLoader(
        dataset,
        batch_size=16,
        shuffle=True,
        num_workers=0,
        pin_memory=False,
        drop_last=True,
        persistent_workers=False,
        collate_fn=LLMTuningCollater(tokenizer, 128, 1024, True),  # (tokenizer, text_max_len, prot_max_len, use_gal)
    )
    for data in loader:
        input()
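A sketch of wiring the data module end to end, assuming the SwissProtV3 files exist under the default root and a Galactica tokenizer; llm_name is used only to pick the tokenizer and set the use_gal flag:

import argparse
from transformers import AutoTokenizer
from data_provider.llm_tuning_dm import LLMTuningDM

parser = argparse.ArgumentParser()
parser = LLMTuningDM.add_model_specific_args(parser)
parser.add_argument('--llm_name', type=str, default='facebook/galactica-1.3b')
args = parser.parse_args([])

tokenizer = AutoTokenizer.from_pretrained(args.llm_name)
tokenizer.add_special_tokens({'pad_token': '<pad>'})

dm = LLMTuningDM(root=args.root, args=args)  # reads train/valid/test sets from args.root
dm.init_tokenizer(tokenizer)
prot_batch, text_batch = next(iter(dm.train_dataloader()))
print(prot_batch.input_ids.shape, text_batch.input_ids.shape)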
data_provider/llm_tuning_prot_qa_dm.py
ADDED
@@ -0,0 +1,164 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from pytorch_lightning import LightningDataModule
from torch.utils.data import DataLoader
from data_provider.prot_qa_dm import PDBQADataset
from data_provider.gal_helpers import escape_custom_split_sequence


class LLMTuningProtQACollater(object):
    def __init__(self, tokenizer, q_max_len, a_max_len, use_gal, prompt):
        self.tokenizer = tokenizer
        self.q_max_len = q_max_len
        self.a_max_len = a_max_len
        self.use_gal = use_gal
        self.prompt = prompt
        assert prompt.find('{}') >= 0

    def __call__(self, batch):
        prot_seqs, questions, answers, q_types = zip(*batch)
        assert len(prot_seqs) == len(questions) == len(answers)
        questions = [self.prompt.format(prot_seqs[i], questions[i]) for i in range(len(prot_seqs))]

        if self.use_gal:
            questions = [escape_custom_split_sequence(q) for q in questions]
        answers = [a + '\n' for a in answers]
        if False:  # disabled branch: tokenize question and answer separately
            self.tokenizer.padding_side = 'left'
            q_batch = self.tokenizer(questions,
                                     truncation=True,
                                     padding='max_length',
                                     add_special_tokens=True,
                                     max_length=self.q_max_len,
                                     return_tensors='pt',
                                     return_attention_mask=True,
                                     return_token_type_ids=False)
            self.tokenizer.padding_side = 'right'
            a_batch = self.tokenizer(answers,
                                     truncation=True,
                                     padding='max_length',
                                     add_special_tokens=True,
                                     max_length=self.a_max_len,
                                     return_tensors='pt',
                                     return_attention_mask=True,
                                     return_token_type_ids=False)
            return q_batch, a_batch
        else:
            # active branch: tokenize each (question, answer) pair jointly
            self.tokenizer.padding_side = 'right'
            qa_pair = [[q, a] for q, a in zip(questions, answers)]
            qa_batch = self.tokenizer(qa_pair,
                                      truncation=True,
                                      padding='max_length',
                                      add_special_tokens=True,
                                      max_length=self.q_max_len + self.a_max_len,
                                      return_tensors='pt',
                                      return_attention_mask=True,
                                      return_token_type_ids=True)
            return qa_batch


class InferenceCollater(object):
    def __init__(self, tokenizer, q_max_len, a_max_len, use_gal, prompt):
        self.tokenizer = tokenizer
        self.q_max_len = q_max_len
        self.a_max_len = a_max_len
        self.use_gal = use_gal
        self.prompt = prompt
        assert prompt.find('{}') >= 0

    def __call__(self, batch):
        prot_seqs, questions, answers, q_types, indices = zip(*batch)
        assert len(prot_seqs) == len(questions) == len(answers)
        questions = [self.prompt.format(prot_seqs[i], questions[i]) for i in range(len(prot_seqs))]

        if self.use_gal:
            questions = [escape_custom_split_sequence(q) for q in questions]
        answers = [a + '\n' for a in answers]
        self.tokenizer.padding_side = 'left'
        q_batch = self.tokenizer(questions,
                                 truncation=True,
                                 padding='max_length',
                                 add_special_tokens=True,
                                 max_length=self.q_max_len,
                                 return_tensors='pt',
                                 return_attention_mask=True,
                                 return_token_type_ids=False)
        target_dict = {'targets': answers, 'q_types': q_types, 'indices': indices}
        return q_batch, target_dict


class LLMTuningProtQADM(LightningDataModule):
    def __init__(
        self,
        root: str = 'data/',
        args=None,
    ):
        super().__init__()
        self.args = args
        self.batch_size = args.batch_size
        self.inference_batch_size = args.inference_batch_size
        self.num_workers = args.num_workers
        self.q_max_len = args.q_max_len
        self.a_max_len = args.a_max_len
        self.prompt = args.prompt

        self.train_dataset = PDBQADataset(root, 'train.txt', "Question: {} Answer:", filter_side_qa=args.filter_side_qa)
        self.val_dataset = PDBQADataset(root, 'val.txt', "Question: {} Answer:", filter_side_qa=args.filter_side_qa)
        self.test_dataset = PDBQADataset(root, 'test.txt', "Question: {} Answer:", filter_side_qa=args.filter_side_qa)

        self.tokenizer = None
        self.use_gal = args.llm_name.find('gal') >= 0

    def init_tokenizer(self, tokenizer):
        self.tokenizer = tokenizer

    def train_dataloader(self):
        loader = DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers,
            pin_memory=False,
            drop_last=True,
            persistent_workers=False,
            collate_fn=LLMTuningProtQACollater(self.tokenizer, self.q_max_len, self.a_max_len, self.use_gal, self.prompt),
        )
        return loader

    def val_dataloader(self):
        val_loader = DataLoader(
            self.val_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            pin_memory=False,
            drop_last=False,
            persistent_workers=False,
            collate_fn=LLMTuningProtQACollater(self.tokenizer, self.q_max_len, self.a_max_len, self.use_gal, self.prompt),
        )
        test_loader = DataLoader(
            self.test_dataset,
            batch_size=self.inference_batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            pin_memory=False,
            drop_last=False,
            persistent_workers=False,
            collate_fn=InferenceCollater(self.tokenizer, self.q_max_len, self.a_max_len, self.use_gal, self.prompt),
        )
        return [val_loader, test_loader]

    def add_model_specific_args(parent_parser):
        parser = parent_parser.add_argument_group("Data module")
        parser.add_argument('--num_workers', type=int, default=2)
        parser.add_argument('--batch_size', type=int, default=32)
        parser.add_argument('--inference_batch_size', type=int, default=4)
        parser.add_argument('--root', type=str, default='data/SwissProtV3')
        parser.add_argument('--q_max_len', type=int, default=1064)
        parser.add_argument('--a_max_len', type=int, default=36)
        parser.add_argument('--prompt', type=str, default='[START_AMINO]{}[END_AMINO]. {}')
        parser.add_argument('--filter_side_qa', action='store_true', default=False)
        return parent_parser
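A toy call of the QA collater above, assuming a Galactica-style tokenizer; the batch tuple mirrors what PDBQADataset yields, (prot_seq, question, answer, q_type), with the question already wrapped in the dataset's "Question: {} Answer:" template:

from transformers import AutoTokenizer
from data_provider.llm_tuning_prot_qa_dm import LLMTuningProtQACollater

tokenizer = AutoTokenizer.from_pretrained('facebook/galactica-1.3b')
tokenizer.add_special_tokens({'pad_token': '<pad>'})

collater = LLMTuningProtQACollater(tokenizer, 1064, 36, True, '[START_AMINO]{}[END_AMINO]. {}')
batch = [('MKT', 'Question: How many entities does this protein have? Answer:', '1', 'structure/property')]
qa_batch = collater(batch)  # question and answer tokenized jointly, right-padded
print(qa_batch.input_ids.shape)  # (1, 1100) = q_max_len + a_max_len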
data_provider/metalIonbinding.py
ADDED
|
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import json
import random  # needed by shuffle() below; missing from the original imports
from pytorch_lightning import LightningDataModule
from torch.utils.data import DataLoader, ConcatDataset, Dataset
from data_provider.stage1_dm import SwissProtDataset, OntoProteinDataset
import pandas as pd


class MetallonBinding(Dataset):
    def __init__(self, data_path, prompt='', return_prompt=False):
        super(MetallonBinding, self).__init__()
        self.data_path = data_path
        self.user_prompt = prompt
        self.return_prompt = return_prompt

        self.data_list = self._load_and_preprocess(self.data_path)
        self.text2id = self._build_text_vocab()

    def _load_and_preprocess(self, data_path):
        data_list = []
        df = pd.read_csv(data_path)
        for _, row in df.iterrows():
            try:
                name = str(row['name']).strip()
                prot_seq = str(row['aa_seq']).strip()
                result = str(int(row['label'])).strip()

                text_seq = f"<answer>{result}</answer>\n"

                prompt = f"""
Task: Determine whether this protein is a metalloprotein based on the provided sequence and protein name {name}.
Background: Metalloproteins are proteins that bind metal ions, often through specific amino acid residues such as histidine (H), cysteine (C), aspartate (D), or glutamate (E).
Question: Does this protein bind metal ions? Please choose one of the following options:
0: Non-metalloprotein — This protein does **not** bind to any metal ions.
1: Metalloprotein — This protein **binds** to one or more metal ions.
"""
                if self.user_prompt:
                    prompt += self.user_prompt

                # `extra` could return the raw feather string, or feather_vals
                # (or feather_raw)
                data_list.append((prot_seq, text_seq, prompt))
            except Exception as e:
                print(f"Warning: skipping malformed row: {row}, reason: {e}")
        return data_list

    def _build_text_vocab(self):
        text2id = {}
        for _, text_seq, _ in self.data_list:
            if text_seq not in text2id:
                text2id[text_seq] = len(text2id)
        return text2id

    def shuffle(self):
        random.shuffle(self.data_list)
        return self

    def __len__(self):
        return len(self.data_list)

    def __getitem__(self, index):
        prot_seq, text_seq, prompt = self.data_list[index]
        if self.return_prompt:
            return prot_seq, prompt, text_seq, index
        return prot_seq, text_seq, index
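From `_load_and_preprocess` above, the input CSV must carry `name`, `aa_seq`, and an integer `label` column. A minimal loading sketch (toy rows and file name, invented; assumes the class above is in scope):

import pandas as pd

# Build a two-row CSV matching the columns the parser expects.
pd.DataFrame({
    'name': ['protA', 'protB'],
    'aa_seq': ['MHHCDE', 'MKTAYI'],
    'label': [1, 0],
}).to_csv('metal_demo.csv', index=False)

ds = MetallonBinding('metal_demo.csv', return_prompt=True)
prot_seq, prompt, text_seq, idx = ds[0]
print(text_seq)  # <answer>1</answer>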
data_provider/mutation.py
ADDED
@@ -0,0 +1,119 @@
import json
import random  # needed by shuffle() below; missing from the original imports
from pytorch_lightning import LightningDataModule
from torch.utils.data import DataLoader, ConcatDataset, Dataset
from data_provider.stage1_dm import SwissProtDataset, OntoProteinDataset
import pandas as pd


class TAPE_Stability(Dataset):
    def __init__(self, data_path, prompt='', return_prompt=False):
        super(TAPE_Stability, self).__init__()
        self.data_path = data_path
        self.user_prompt = prompt
        self.return_prompt = return_prompt

        self.data_list = self._load_and_preprocess(self.data_path)
        self.text2id = self._build_text_vocab()

    def _load_and_preprocess(self, data_path):
        data_list = []
        df = pd.read_csv(data_path)
        for _, row in df.iterrows():
            try:
                prot_seq = str(row['aa_seq']).strip()
                result = str(row['label']).strip()

                text_seq = f"<answer>{result}</answer>\n"

                prompt = """
【Task】Predict the thermostability score of the given protein sequence, which reflects its ability to maintain proper folding above a concentration threshold.
【Background】Protein stability is an important biophysical property indicating a protein’s resistance to denaturation or unfolding under thermal or chemical stress. In this task, each protein is evaluated by a numerical stability score, where higher values indicate greater ability to remain folded under extreme conditions. This score serves as a proxy for the protein’s intrinsic stability.
【Question】What is the predicted stability score for this sequence?
【Output Format】You must return only the score number, wrapped in <answer></answer> tags.
"""
                if self.user_prompt:
                    prompt += self.user_prompt

                # `extra` could return the raw feather string, or feather_vals
                # (or feather_raw)
                data_list.append((prot_seq, text_seq, prompt))
            except Exception as e:
                print(f"Warning: skipping malformed row: {row}, reason: {e}")
        return data_list

    def _build_text_vocab(self):
        text2id = {}
        for _, text_seq, _ in self.data_list:
            if text_seq not in text2id:
                text2id[text_seq] = len(text2id)
        return text2id

    def shuffle(self):
        random.shuffle(self.data_list)
        return self

    def __len__(self):
        return len(self.data_list)

    def __getitem__(self, index):
        prot_seq, text_seq, prompt = self.data_list[index]
        if self.return_prompt:
            return prot_seq, prompt, text_seq, index
        return prot_seq, text_seq, index


class TAPE_Fluorescence(Dataset):
    def __init__(self, data_path, prompt='', return_prompt=False):
        super(TAPE_Fluorescence, self).__init__()
        self.data_path = data_path
        self.user_prompt = prompt
        self.return_prompt = return_prompt

        self.data_list = self._load_and_preprocess(self.data_path)
        self.text2id = self._build_text_vocab()

    def _load_and_preprocess(self, data_path):
        data_list = []
        df = pd.read_csv(data_path)
        for _, row in df.iterrows():
            try:
                prot_seq = str(row['aa_seq']).strip()
                result = str(row['label']).strip()

                text_seq = f"<answer>{result}</answer>\n"

                prompt = """
【Task】Predict the log fluorescence intensity of the given protein sequence.
【Output Format】You must return only the numerical value, wrapped in <answer></answer> tags.
"""
                # 【Background】Fluorescence intensity reflects how strongly a protein emits light when excited by a specific wavelength. It is commonly measured in protein variants such as GFP (Green Fluorescent Protein) mutants. The log-transformed fluorescence value quantifies the brightness on a logarithmic scale. Mutations in the sequence can increase or decrease fluorescence intensity.
                # 【Question】What is the predicted log fluorescence intensity for this sequence?
                if self.user_prompt:
                    prompt += self.user_prompt

                # `extra` could return the raw feather string, or feather_vals
                # (or feather_raw)
                data_list.append((prot_seq, text_seq, prompt))
            except Exception as e:
                print(f"Warning: skipping malformed row: {row}, reason: {e}")
        return data_list

    def _build_text_vocab(self):
        text2id = {}
        for _, text_seq, _ in self.data_list:
            if text_seq not in text2id:
                text2id[text_seq] = len(text2id)
        return text2id

    def shuffle(self):
        random.shuffle(self.data_list)
        return self

    def __len__(self):
        return len(self.data_list)

    def __getitem__(self, index):
        prot_seq, text_seq, prompt = self.data_list[index]
        if self.return_prompt:
            return prot_seq, prompt, text_seq, index
        return prot_seq, text_seq, index
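Both TAPE datasets share the same two-column CSV contract (`aa_seq`, `label`) and differ only in the prompt text; `return_prompt` switches between a 3-tuple and a 4-tuple item. A sketch of the two shapes (hypothetical file name, for illustration):

# Without the prompt: (sequence, <answer>-wrapped label, index)
ds = TAPE_Stability('stability_demo.csv')
seq, text_seq, idx = ds[0]

# With the prompt: (sequence, instruction prompt, <answer>-wrapped label, index)
ds = TAPE_Stability('stability_demo.csv', return_prompt=True)
seq, prompt, text_seq, idx = ds[0]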
data_provider/production.py
ADDED
@@ -0,0 +1,237 @@
import json
import random  # needed by shuffle() below; missing from the original imports
from pytorch_lightning import LightningDataModule
from torch.utils.data import DataLoader, ConcatDataset, Dataset
from data_provider.stage1_dm import SwissProtDataset, OntoProteinDataset
import pandas as pd


class Antibiotic_Resistance(Dataset):
    def __init__(self, data_path, prompt='', return_prompt=False):
        super(Antibiotic_Resistance, self).__init__()
        self.data_path = data_path
        self.user_prompt = prompt
        self.return_prompt = return_prompt

        self.data_list = self._load_and_preprocess(self.data_path)
        self.text2id = self._build_text_vocab()

    def _load_and_preprocess(self, data_path):
        data_list = []
        df = pd.read_csv(data_path)
        for _, row in df.iterrows():
            try:
                prot_seq = str(row['aa_seq']).strip()
                result = str(row['label']).strip()
                text_seq = f"<answer>{result}</answer>\n"

                prompt = f"""
【Task】Predict the antibiotic resistance class of the given protein.
【Background】Antibiotic resistance refers to the ability of bacteria or other microbes to resist the effects of antibiotics that were once effective against them. Each protein is associated with resistance to exactly one of 19 antibiotic classes.

【Prediction Goal】Based on the provided protein sequence, determine which single antibiotic class (from 1 to 19) this protein confers resistance to.

【Output Format】Return only one predicted resistance class (a number from 1 to 19), wrapped in <answer> </answer> tags.
Example: <answer>7</answer>
"""
                if self.user_prompt:
                    prompt += self.user_prompt

                # `extra` could return the raw feather string, or feather_vals
                # (or feather_raw)
                data_list.append((prot_seq, text_seq, prompt))
            except Exception as e:
                print(f"Warning: skipping malformed row: {row}, reason: {e}")
        return data_list

    def _build_text_vocab(self):
        text2id = {}
        for _, text_seq, _ in self.data_list:
            if text_seq not in text2id:
                text2id[text_seq] = len(text2id)
        return text2id

    def shuffle(self):
        random.shuffle(self.data_list)
        return self

    def __len__(self):
        return len(self.data_list)

    def __getitem__(self, index):
        prot_seq, text_seq, prompt = self.data_list[index]
        if self.return_prompt:
            return prot_seq, prompt, text_seq, index
        return prot_seq, text_seq, index


class Thermostability(Dataset):
    def __init__(self, data_path, prompt='', return_prompt=False):
        super(Thermostability, self).__init__()
        self.data_path = data_path
        self.user_prompt = prompt
        self.return_prompt = return_prompt

        self.data_list = self._load_and_preprocess(self.data_path)
        self.text2id = self._build_text_vocab()

    def _load_and_preprocess(self, data_path):
        data_list = []
        df = pd.read_csv(data_path)
        for _, row in df.iterrows():
            try:
                # name = str(row['name']).strip()
                prot_seq = str(row['aa_seq']).strip()
                result = str(row['label']).strip()
                text_seq = f"<answer>{result}</answer>\n"

                prompt = f"""
【Task】Predict the thermostability value of the given protein.
【Background】Thermostability refers to the ability of a molecule to resist irreversible chemical or physical changes at high temperatures, such as decomposition or aggregation.
【Output Format】Provide the predicted thermostability as a numeric value (e.g., melting temperature in °C). Wrap your answer in <answer></answer> tags.
"""
                if self.user_prompt:
                    prompt += self.user_prompt

                # `extra` could return the raw feather string, or feather_vals
                # (or feather_raw)
                data_list.append((prot_seq, text_seq, prompt))
            except Exception as e:
                print(f"Warning: skipping malformed row: {row}, reason: {e}")
        return data_list

    def _build_text_vocab(self):
        text2id = {}
        for _, text_seq, _ in self.data_list:
            if text_seq not in text2id:
                text2id[text_seq] = len(text2id)
        return text2id

    def shuffle(self):
        random.shuffle(self.data_list)
        return self

    def __len__(self):
        return len(self.data_list)

    def __getitem__(self, index):
        prot_seq, text_seq, prompt = self.data_list[index]
        if self.return_prompt:
            return prot_seq, prompt, text_seq, index
        return prot_seq, text_seq, index


class Material(Dataset):
    def __init__(self, data_path, prompt='', return_prompt=False):
        super(Material, self).__init__()
        self.data_path = data_path
        self.user_prompt = prompt
        self.return_prompt = return_prompt

        self.data_list = self._load_and_preprocess(self.data_path)
        self.text2id = self._build_text_vocab()

    def _load_and_preprocess(self, data_path):
        data_list = []
        df = pd.read_csv(data_path)
        for _, row in df.iterrows():
            try:
                # name = str(row['name']).strip()
                prot_seq = str(row['aa_seq']).strip()
                result = str(row['label']).strip()
                text_seq = f"<answer>{result}</answer>\n"

                prompt = f"""
【Task】Determine whether the given material can be successfully produced.
【Background】In materials science, certain chemical compounds or materials may or may not be synthesizable (i.e., producible) under realistic experimental conditions. This task requires classifying whether the input material composition and structure allow for successful production. This is a binary classification problem.
【Question】Can this material be successfully produced?
【Output Format】Respond with either "1" or "0", and wrap your answer in <answer></answer> tags.
"""
                if self.user_prompt:
                    prompt += self.user_prompt

                # `extra` could return the raw feather string, or feather_vals
                # (or feather_raw)
                data_list.append((prot_seq, text_seq, prompt))
            except Exception as e:
                print(f"Warning: skipping malformed row: {row}, reason: {e}")
        return data_list

    def _build_text_vocab(self):
        text2id = {}
        for _, text_seq, _ in self.data_list:
            if text_seq not in text2id:
                text2id[text_seq] = len(text2id)
        return text2id

    def shuffle(self):
        random.shuffle(self.data_list)
        return self

    def __len__(self):
        return len(self.data_list)

    def __getitem__(self, index):
        prot_seq, text_seq, prompt = self.data_list[index]
        if self.return_prompt:
            return prot_seq, prompt, text_seq, index
        return prot_seq, text_seq, index


class Clone(Dataset):
    def __init__(self, data_path, prompt='', return_prompt=False):
        super(Clone, self).__init__()
        self.data_path = data_path
        self.user_prompt = prompt
        self.return_prompt = return_prompt

        self.data_list = self._load_and_preprocess(self.data_path)
        self.text2id = self._build_text_vocab()

    def _load_and_preprocess(self, data_path):
        data_list = []
        df = pd.read_csv(data_path)
        for _, row in df.iterrows():
            try:
                # name = str(row['name']).strip()
                prot_seq = str(row['aa_seq']).strip()
                result = str(row['label']).strip()
                text_seq = f"<answer>{result}</answer>\n"

                prompt = f"""
【Task】Determine whether the given protein sequence can be successfully cloned.
【Background】In molecular biology, cloning refers to the process of creating copies of a DNA or protein sequence. Some sequences can be challenging to clone due to their length, GC-content, secondary structures, or toxicity to the host. This task requires predicting whether the given protein sequence is likely to be successfully cloned. This is a binary classification problem.
【Question】Can this protein sequence be successfully cloned?
【Output Format】Respond with either "1" or "0", and wrap your answer in <answer></answer> tags.
"""
                if self.user_prompt:
                    prompt += self.user_prompt

                # `extra` could return the raw feather string, or feather_vals
                # (or feather_raw)
                data_list.append((prot_seq, text_seq, prompt))
            except Exception as e:
                print(f"Warning: skipping malformed row: {row}, reason: {e}")
        return data_list

    def _build_text_vocab(self):
        text2id = {}
        for _, text_seq, _ in self.data_list:
            if text_seq not in text2id:
                text2id[text_seq] = len(text2id)
        return text2id

    def shuffle(self):
        random.shuffle(self.data_list)
        return self

    def __len__(self):
        return len(self.data_list)

    def __getitem__(self, index):
        prot_seq, text_seq, prompt = self.data_list[index]
        if self.return_prompt:
            return prot_seq, prompt, text_seq, index
        return prot_seq, text_seq, index
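Every dataset in this file serializes the target as `<answer>{label}</answer>\n`, and the prompts instruct the model to answer in the same tags, so downstream evaluation has to parse the tag back out. A minimal parser sketch (the regex approach here is an assumption, not code from this repo):

import re

def parse_answer(text):
    # Return the payload of the first <answer>...</answer> span, or None.
    m = re.search(r'<answer>(.*?)</answer>', text, re.DOTALL)
    return m.group(1).strip() if m else None

assert parse_answer('<answer>7</answer>\n') == '7'
assert parse_answer('no tags here') is None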
data_provider/prot_qa_dm.py
ADDED
@@ -0,0 +1,299 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import json
from pytorch_lightning import LightningDataModule
# import torch_geometric
from torch.utils.data import DataLoader, Dataset
from pathlib import Path


class ProtQACollater(object):
    def __init__(self, tokenizer, prot_tokenizer, q_max_len, a_max_len, prot_max_len):
        self.tokenizer = tokenizer
        self.prot_tokenizer = prot_tokenizer
        self.q_max_len = q_max_len
        self.a_max_len = a_max_len
        self.prot_max_len = prot_max_len

    def __call__(self, batch):
        prot_seqs, questions, answers, _, _ = zip(*batch)
        answers = [a + '\n' for a in answers]
        prot_batch = self.prot_tokenizer(prot_seqs,
                                         truncation=True,
                                         padding='max_length',
                                         max_length=self.prot_max_len,
                                         return_tensors="pt",
                                         return_attention_mask=True,
                                         return_token_type_ids=False)
        if False:  # disabled development path: tokenize questions and answers separately
            self.tokenizer.padding_side = 'left'
            q_batch = self.tokenizer(questions,
                                     truncation=True,
                                     padding='max_length',
                                     add_special_tokens=True,
                                     max_length=self.q_max_len,
                                     return_tensors='pt',
                                     return_attention_mask=True,
                                     return_token_type_ids=False)
            self.tokenizer.padding_side = 'right'
            a_batch = self.tokenizer(answers,
                                     truncation=True,
                                     padding='max_length',
                                     add_special_tokens=True,
                                     max_length=self.a_max_len,
                                     return_tensors='pt',
                                     return_attention_mask=True,
                                     return_token_type_ids=False)
            return prot_batch, q_batch, a_batch
        else:  # active path: tokenize each (question, answer) pair jointly
            self.tokenizer.padding_side = 'right'
            qa_pair = [[q, a] for q, a in zip(questions, answers)]
            qa_batch = self.tokenizer(qa_pair,
                                      truncation=True,
                                      padding='max_length',
                                      add_special_tokens=True,
                                      max_length=self.q_max_len + self.a_max_len,
                                      return_tensors='pt',
                                      return_attention_mask=True,
                                      return_token_type_ids=True)
            return prot_batch, qa_batch


class InferenceCollater(object):
    def __init__(self, tokenizer, prot_tokenizer, q_max_len, a_max_len, prot_max_len):
        self.tokenizer = tokenizer
        self.prot_tokenizer = prot_tokenizer
        self.q_max_len = q_max_len
        self.a_max_len = a_max_len
        self.prot_max_len = prot_max_len

    def __call__(self, batch):
        prot_seqs, questions, answers, q_types, indices = zip(*batch)
        answers = [a + '\n' for a in answers]
        prot_batch = self.prot_tokenizer(prot_seqs,
                                         truncation=True,
                                         padding='max_length',
                                         max_length=self.prot_max_len,
                                         return_tensors="pt",
                                         return_attention_mask=True,
                                         return_token_type_ids=False)
        self.tokenizer.padding_side = 'left'
        q_batch = self.tokenizer(questions,
                                 truncation=True,
                                 padding='max_length',
                                 add_special_tokens=True,
                                 max_length=self.q_max_len,
                                 return_tensors='pt',
                                 return_attention_mask=True,
                                 return_token_type_ids=False)
        target_dict = {'targets': answers, 'q_types': q_types, 'indices': indices}
        return prot_batch, q_batch, target_dict


class ProtQADM(LightningDataModule):
    def __init__(
        self,
        root: str = 'data/',
        args=None,
    ):
        super().__init__()
        self.args = args
        self.batch_size = args.batch_size
        self.inference_batch_size = args.inference_batch_size
        self.num_workers = args.num_workers
        self.q_max_len = args.q_max_len
        self.a_max_len = args.a_max_len
        self.prot_max_len = args.prot_max_len
        self.prompt = args.prompt

        self.train_dataset = PDBQADataset(root, 'train.txt', prompt=self.prompt, filter_side_qa=args.filter_side_qa)
        self.val_dataset = PDBQADataset(root, 'val.txt', prompt=self.prompt, filter_side_qa=args.filter_side_qa)
        self.test_dataset = PDBQADataset(root, 'test.txt', prompt=self.prompt, filter_side_qa=args.filter_side_qa)

        self.tokenizer = None
        self.prot_tokenizer = None

    def init_tokenizer(self, tokenizer, prot_tokenizer):
        self.tokenizer = tokenizer
        self.prot_tokenizer = prot_tokenizer

    def train_dataloader(self):
        loader = DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers,
            pin_memory=False,
            drop_last=True,
            persistent_workers=False,
            collate_fn=ProtQACollater(self.tokenizer, self.prot_tokenizer, self.q_max_len, self.a_max_len, self.prot_max_len),
        )
        return loader

    def val_dataloader(self):
        val_loader = DataLoader(
            self.val_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            pin_memory=False,
            drop_last=False,
            persistent_workers=False,
            collate_fn=ProtQACollater(self.tokenizer, self.prot_tokenizer, self.q_max_len, self.a_max_len, self.prot_max_len),
        )
        test_loader = DataLoader(
            self.test_dataset,
            batch_size=self.inference_batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            pin_memory=False,
            drop_last=False,
            persistent_workers=False,
            collate_fn=InferenceCollater(self.tokenizer, self.prot_tokenizer, self.q_max_len, self.a_max_len, self.prot_max_len),
        )
        return [val_loader, test_loader]

    def add_model_specific_args(parent_parser):
        parser = parent_parser.add_argument_group("Data module")
        parser.add_argument('--num_workers', type=int, default=2)
        parser.add_argument('--batch_size', type=int, default=32)
        parser.add_argument('--inference_batch_size', type=int, default=4)
        parser.add_argument('--root', type=str, default='data/SwissProtV3')
        parser.add_argument('--text_max_len', type=int, default=128)
        parser.add_argument('--q_max_len', type=int, default=34)
        parser.add_argument('--a_max_len', type=int, default=36)
        parser.add_argument('--prot_max_len', type=int, default=1024)
        parser.add_argument('--prompt', type=str, default='The protein has the following properties: ')
        parser.add_argument('--filter_side_qa', action='store_true', default=False)
        return parent_parser


class PDBQADataset(Dataset):
    def __init__(self, root_path, subset, prompt="Question: {} Answer:", filter_side_qa=False):
        super(PDBQADataset, self).__init__()
        self.data_path = Path(root_path) / subset
        self.qa_path = Path(root_path) / 'qa_all.json'
        self.q_type_path = Path(root_path) / 'q_types.txt'
        self.prompt = prompt

        ## load dataset
        with open(self.qa_path, 'r') as f:
            qa_data = json.load(f)

        with open(self.data_path, 'r') as f:
            lines = f.readlines()
            pdb2seq = [line.strip().split('\t') for line in lines]

        ## load q types
        with open(self.q_type_path, 'r') as f:
            q_types = [line.strip().split('\t') for line in f.readlines()]
            self.q_type_dict = {q: t for q, t in q_types}

        ## process dataset
        pdb_set = set(i[0] for i in pdb2seq)
        ## filter qa data
        qa_data = {k: v for k, v in qa_data.items() if k in pdb_set}
        assert len(qa_data) == len(pdb_set), print(len(qa_data), len(pdb_set))

        ## generate qa data
        self.data_list = []
        for pdb_id, seq in pdb2seq:
            qa_list = qa_data[pdb_id]
            for qa in qa_list:
                q = qa['Q']
                a = str(qa['A'])
                if filter_side_qa:
                    q_type = self.q_type_dict[q]
                    if q_type.find('side information') >= 0:
                        continue
                self.data_list.append((seq, q, a))

    def __len__(self):
        return len(self.data_list)

    def __getitem__(self, index):
        seq, q, a = self.data_list[index]
        q_type = self.q_type_dict[q]
        q = self.prompt.format(q)
        return seq, q, a, q_type, index


if __name__ == '__main__':
    import numpy as np
    from collections import defaultdict, Counter
    train_dataset = PDBQADataset('../data/PDBDataset', 'train.txt', filter_side_qa=True)
    val_dataset = PDBQADataset('../data/PDBDataset', 'val.txt', filter_side_qa=True)
    test_dataset = PDBQADataset('../data/PDBDataset', 'test.txt', filter_side_qa=True)
    if True:
        # print(len(train_dataset), len(val_dataset), len(test_dataset))
        # train_protein_lens = np.asarray([len(p) for p in train_dataset.protein_list])
        # val_protein_lens = np.asarray([len(p) for p in val_dataset.protein_list])
        # test_protein_lens = np.asarray([len(p) for p in test_dataset.protein_list])

        q_lens = []
        a_lens = []
        for seq, q, a in train_dataset.data_list:
            q_lens.append(len(q.split()))
            a_lens.append(len(a.split()))

        print(np.asarray(q_lens).min(), np.asarray(q_lens).max(), np.asarray(q_lens).mean())
        print(np.asarray(a_lens).min(), np.asarray(a_lens).max(), np.asarray(a_lens).mean())

        q_lens = []
        a_lens = []
        for seq, q, a in val_dataset.data_list:
            q_lens.append(len(q.split()))
            a_lens.append(len(a.split()))

        print(np.asarray(q_lens).min(), np.asarray(q_lens).max(), np.asarray(q_lens).mean())
        print(np.asarray(a_lens).min(), np.asarray(a_lens).max(), np.asarray(a_lens).mean())

        q_lens = []
        a_lens = []
        for seq, q, a in test_dataset.data_list:
            q_lens.append(len(q.split()))
            a_lens.append(len(a.split()))

        print(np.asarray(q_lens).min(), np.asarray(q_lens).max(), np.asarray(q_lens).mean())
        print(np.asarray(a_lens).min(), np.asarray(a_lens).max(), np.asarray(a_lens).mean())

    elif False:
        ## construct the guess for prediction by number
        train_counter = defaultdict(Counter)
        for _, q, a in train_dataset.data_list:
            train_counter[q.lower()][a] += 1
        ## get the most common answer
        q2a = {}
        for q, counter in train_counter.items():
            q2a[q] = counter.most_common(1)[0][0]

        ## test the guess
        acc = 0
        for _, q, a in test_dataset.data_list:
            if q.lower() in q2a:
                predict = q2a[q.lower()]
                if predict.lower() == a.lower():
                    acc += 1
        print(acc / len(test_dataset.data_list))
    elif False:
        from transformers import AutoTokenizer, EsmTokenizer
        llm_tokenizer = AutoTokenizer.from_pretrained('facebook/galactica-1.3b', use_fast=False, padding_side='right')
        plm_tokenizer = EsmTokenizer.from_pretrained('facebook/esm2_t30_150M_UR50D')
        llm_tokenizer.add_special_tokens({'pad_token': '<pad>'})
        loader = DataLoader(
            train_dataset,
            batch_size=32,
            shuffle=True,
            num_workers=4,
            pin_memory=False,
            drop_last=True,
            persistent_workers=False,
            collate_fn=ProtQACollater(llm_tokenizer, plm_tokenizer, 40, 40, 1024),
        )
    else:
        print(len(train_dataset.data_list))
        print(len(val_dataset.data_list))
        print(len(test_dataset.data_list))
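PDBQADataset stitches together three files under one root: a tab-separated ID-to-sequence list per split, `qa_all.json`, and `q_types.txt`. A toy end-to-end sketch of those formats (all contents invented to match the parser above; assumes the class is in scope):

import json
from pathlib import Path

root = Path('pdbqa_demo')
root.mkdir(exist_ok=True)
# <subset>: one "pdb_id<TAB>sequence" pair per line
(root / 'train.txt').write_text('1abc\tMKTAYIAKQR\n')
# qa_all.json: pdb_id -> list of {"Q": question, "A": answer}
(root / 'qa_all.json').write_text(json.dumps(
    {'1abc': [{'Q': 'Is this protein an enzyme?', 'A': 'Yes'}]}))
# q_types.txt: "question<TAB>type"; 'side information' types can be filtered out
(root / 'q_types.txt').write_text('Is this protein an enzyme?\tfunction\n')

ds = PDBQADataset(root, 'train.txt')
print(ds[0])  # (seq, formatted question, answer, q_type, index)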
data_provider/proteinchat_dm.py
ADDED
@@ -0,0 +1,254 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import json
import random
import torch
import os
import numpy as np
from pytorch_lightning import LightningDataModule
from torch.utils.data import DataLoader, Dataset
from data_provider.gal_helpers import escape_custom_split_sequence
from pathlib import Path
from torch.utils.data.dataloader import default_collate


class ProteinChatCollater(object):
    def __init__(self, tokenizer, q_max_len, a_max_len, use_gal):
        self.tokenizer = tokenizer
        self.q_max_len = q_max_len
        self.a_max_len = a_max_len
        self.use_gal = use_gal

    def __call__(self, batch):
        embeds, prot_seqs, questions, answers, q_types = zip(*batch)
        max_embed_len = 896
        ## pad the per-residue embeddings to a fixed length
        if False:  # disabled development path: pad to the batch maximum instead
            max_dim = max([e.shape[0] for e in embeds])

            padded_embeds = []
            for embed in embeds:
                shape_dim0 = embed.shape[0]
                pad1 = ((0, max_dim - shape_dim0), (0, 0), (0, 0))
                padded_embeds.append(np.pad(embed, pad1, mode='constant'))
            padded_embeds = default_collate(padded_embeds).squeeze(dim=2)[:, :1024, :]
        else:
            padded_embeds = torch.zeros(len(embeds), max_embed_len, 512)
            for i in range(len(embeds)):
                padded_embeds[i, :embeds[i].shape[0], :] = embeds[i][:max_embed_len, :]
            padded_embeds = padded_embeds.detach()

        assert len(prot_seqs) == len(questions) == len(answers)

        if self.use_gal:
            questions = [escape_custom_split_sequence(q) for q in questions]
        answers = [a + '\n' for a in answers]
        self.tokenizer.padding_side = 'left'
        q_batch = self.tokenizer(questions,
                                 truncation=True,
                                 padding='max_length',
                                 add_special_tokens=True,
                                 max_length=self.q_max_len,
                                 return_tensors='pt',
                                 return_attention_mask=True,
                                 return_token_type_ids=False)
        self.tokenizer.padding_side = 'right'
        a_batch = self.tokenizer(answers,
                                 truncation=True,
                                 padding='max_length',
                                 add_special_tokens=True,
                                 max_length=self.a_max_len,
                                 return_tensors='pt',
                                 return_attention_mask=True,
                                 return_token_type_ids=False)
        prot_mask = torch.ones(padded_embeds.shape[0], padded_embeds.shape[1], dtype=torch.bool)
        return (padded_embeds, prot_mask), q_batch, a_batch


class InferenceCollater(object):
    def __init__(self, tokenizer, q_max_len, a_max_len, use_gal):
        self.tokenizer = tokenizer
        self.q_max_len = q_max_len
        self.a_max_len = a_max_len
        self.use_gal = use_gal

    def __call__(self, batch):
        embeds, prot_seqs, questions, answers, q_types = zip(*batch)
        max_embed_len = 896
        ## pad the per-residue embeddings to a fixed length
        if False:  # disabled development path: pad to the batch maximum instead
            max_dim = max([e.shape[0] for e in embeds])

            padded_embeds = []
            for embed in embeds:
                shape_dim0 = embed.shape[0]
                pad1 = ((0, max_dim - shape_dim0), (0, 0), (0, 0))
                padded_embeds.append(np.pad(embed, pad1, mode='constant'))
            padded_embeds = default_collate(padded_embeds).squeeze(dim=2)[:, :1024, :]
        else:
            padded_embeds = torch.zeros(len(embeds), max_embed_len, 512)
            for i in range(len(embeds)):
                padded_embeds[i, :embeds[i].shape[0], :] = embeds[i][:max_embed_len, :]
            padded_embeds = padded_embeds.detach()

        assert len(prot_seqs) == len(questions) == len(answers)

        if self.use_gal:
            questions = [escape_custom_split_sequence(q) for q in questions]
        answers = [a + '\n' for a in answers]
        self.tokenizer.padding_side = 'left'
        q_batch = self.tokenizer(questions,
                                 truncation=True,
                                 padding='max_length',
                                 add_special_tokens=True,
                                 max_length=self.q_max_len,
                                 return_tensors='pt',
                                 return_attention_mask=True,
                                 return_token_type_ids=False)
        prot_mask = torch.ones(padded_embeds.shape[0], padded_embeds.shape[1], dtype=torch.bool)
        target_dict = {'answers': answers, "q_types": q_types}
        return (padded_embeds, prot_mask), q_batch, target_dict


class ProteinChatDM(LightningDataModule):
    def __init__(
        self,
        root: str = 'data/',
        args=None,
    ):
        super().__init__()
        self.args = args
        self.batch_size = args.batch_size
        self.inference_batch_size = args.inference_batch_size
        self.num_workers = args.num_workers
        self.q_max_len = args.q_max_len
        self.a_max_len = args.a_max_len
        self.prompt = args.prompt

        self.train_dataset = ProteinChatDataset(root, 'train.txt', prompt="### Human: {}\n### Assistant: ", pt_file_path=args.pt_file_path)
        self.val_dataset = ProteinChatDataset(root, 'val.txt', prompt="### Human: {}\n### Assistant: ", pt_file_path=args.pt_file_path)
        self.test_dataset = ProteinChatDataset(root, 'test.txt', prompt="### Human: {}\n### Assistant: ", pt_file_path=args.pt_file_path)

        self.tokenizer = None
        self.use_gal = args.llm_name.find('gal') >= 0

    def init_tokenizer(self, tokenizer):
        self.tokenizer = tokenizer

    def train_dataloader(self):
        loader = DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers,
            pin_memory=False,
            drop_last=True,
            persistent_workers=False,
            collate_fn=ProteinChatCollater(self.tokenizer, self.q_max_len, self.a_max_len, self.use_gal),
        )
        return loader

    def val_dataloader(self):
        val_loader = DataLoader(
            self.val_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            pin_memory=False,
            drop_last=False,
            persistent_workers=False,
            collate_fn=ProteinChatCollater(self.tokenizer, self.q_max_len, self.a_max_len, self.use_gal),
        )
        test_loader = DataLoader(
            self.test_dataset,
            batch_size=self.inference_batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            pin_memory=False,
            drop_last=False,
            persistent_workers=False,
            collate_fn=InferenceCollater(self.tokenizer, self.q_max_len, self.a_max_len, self.use_gal),
        )
        return [val_loader, test_loader]

    def add_model_specific_args(parent_parser):
        parser = parent_parser.add_argument_group("Data module")
        parser.add_argument('--num_workers', type=int, default=2)
        parser.add_argument('--batch_size', type=int, default=32)
        parser.add_argument('--inference_batch_size', type=int, default=4)
        parser.add_argument('--root', type=str, default='data/SwissProtV3')
        parser.add_argument('--q_max_len', type=int, default=30)
        parser.add_argument('--a_max_len', type=int, default=36)
        parser.add_argument('--prompt', type=str, default='[START_AMINO]{}[END_AMINO]. Question: {} Answer:')
        parser.add_argument('--pt_file_path', type=str, default='/home/XXXX-2/proteinchatdata/proteinchat')
        return parent_parser


class ProteinChatDataset(Dataset):
    def __init__(self, root_path, subset, pt_file_path, prompt="Question: {} Answer:"):
        super(ProteinChatDataset, self).__init__()
        self.data_path = Path(root_path) / subset
        self.qa_path = Path(root_path) / 'qa_all.json'
        self.q_type_path = Path(root_path) / 'q_types.txt'
        self.prompt = prompt

        ## load dataset
        with open(self.qa_path, 'r') as f:
            qa_data = json.load(f)

        with open(self.data_path, 'r') as f:
            lines = f.readlines()
            pdb2seq = [line.strip().split('\t') for line in lines]

        ## process dataset
        pdb_set = set(i[0] for i in pdb2seq)
        ## filter qa data
        qa_data = {k: v for k, v in qa_data.items() if k in pdb_set}
        assert len(qa_data) == len(pdb_set), print(len(qa_data), len(pdb_set))

        pt_file = Path(pt_file_path).glob('*.pt')
        pt_file_ids = {f.name.split('.pt')[0] for f in pt_file}
        self.pt_file_path = pt_file_path

        ## load q types
        with open(self.q_type_path, 'r') as f:
            q_types = [line.strip().split('\t') for line in f.readlines()]
            self.q_type_dict = {q: t for q, t in q_types}

        ## generate qa data
        self.data_list = []
        for pdb_id, seq in pdb2seq:
            if pdb_id not in pt_file_ids:
                continue
            qa_list = qa_data[pdb_id]
            for qa in qa_list:
                q = qa['Q']
                a = str(qa['A'])
                self.data_list.append((pdb_id, seq, q, a))

    def shuffle(self):
        random.shuffle(self.data_list)
        return self

    def __len__(self):
        return len(self.data_list)

    def __getitem__(self, index):
        pdb_id, seq, q, a = self.data_list[index]
        q_type = self.q_type_dict[q]
        path = os.path.join(self.pt_file_path, pdb_id + '.pt')
        embed = torch.load(path, map_location=torch.device('cpu'))
        embed = embed.squeeze(dim=1)
        embed = embed.detach()
        q = self.prompt.format(q)
        return embed, seq, q, a, q_type


if __name__ == '__main__':
    # pt_file_path is required by the constructor; reuse the default embedding
    # directory from the args above (the original call omitted it).
    dataset = ProteinChatDataset('./data/PDBDataset', 'train.txt',
                                 pt_file_path='/home/XXXX-2/proteinchatdata/proteinchat')
    dataset.shuffle()
    for i in range(1000):
        print(dataset[i][0].shape)
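The two collaters above right-pad each per-residue embedding of shape [L, 512] into a fixed [B, 896, 512] tensor; sequences longer than 896 are truncated (slice end indices clamp in PyTorch, so the same assignment both pads and truncates). The same logic in isolation (toy tensors; 896 and 512 come from the code above):

import torch

embeds = [torch.randn(300, 512), torch.randn(1200, 512)]  # variable-length [L, 512]
max_embed_len = 896
padded = torch.zeros(len(embeds), max_embed_len, 512)
for i, e in enumerate(embeds):
    padded[i, :e.shape[0], :] = e[:max_embed_len, :]  # pads short rows, truncates long ones
prot_mask = torch.ones(padded.shape[0], padded.shape[1], dtype=torch.bool)
print(padded.shape)  # torch.Size([2, 896, 512])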
data_provider/stage1_dm.py
ADDED
@@ -0,0 +1,539 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from pytorch_lightning import LightningDataModule
import json
from torch.utils.data import DataLoader, Dataset, ConcatDataset
import random
from pathlib import Path


def rand_seq_crop(seq, max_len):
    if len(seq) <= max_len:
        return seq
    # randint is inclusive on both ends; the original bound of len(seq)-1-max_len
    # could never place the crop window at the very end of the sequence
    rand_pos = random.randint(0, len(seq) - max_len)
    return seq[rand_pos:rand_pos + max_len]
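A quick check of the crop behavior (toy input, for illustration only):

import random
random.seed(0)
print(rand_seq_crop('ABCDEFGHIJ', max_len=4))  # a random 4-character window
print(rand_seq_crop('ABC', max_len=4))         # 'ABC': short sequences pass through unchanged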
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class Stage1Collater(object):
|
| 18 |
+
def __init__(self, tokenizer, prot_tokenizer, text_max_len, prot_max_len, prot_aug='None'):
|
| 19 |
+
self.tokenizer = tokenizer
|
| 20 |
+
self.prot_tokenizer = prot_tokenizer
|
| 21 |
+
self.text_max_len = text_max_len
|
| 22 |
+
self.prot_max_len = prot_max_len
|
| 23 |
+
self.prot_aug = prot_aug
|
| 24 |
+
|
| 25 |
+
def __call__(self, batch):
|
| 26 |
+
prot_seqs, text_seqs, _ = zip(*batch)
|
| 27 |
+
if self.prot_aug == 'rand_crop':
|
| 28 |
+
prot_seqs = [rand_seq_crop(seq, self.prot_max_len-2) for seq in prot_seqs] # -2 for the two special tokens
|
| 29 |
+
|
| 30 |
+
text_tokens = self.tokenizer(text_seqs,
|
| 31 |
+
truncation=True,
|
| 32 |
+
padding='max_length',
|
| 33 |
+
add_special_tokens=True,
|
| 34 |
+
max_length=self.text_max_len,
|
| 35 |
+
return_tensors='pt',
|
| 36 |
+
return_attention_mask=True,
|
| 37 |
+
return_token_type_ids=False)
|
| 38 |
+
prot_tokens = self.prot_tokenizer(prot_seqs,
|
| 39 |
+
truncation=True,
|
| 40 |
+
padding='max_length',
|
| 41 |
+
max_length=self.prot_max_len,
|
| 42 |
+
return_tensors="pt",
|
| 43 |
+
return_attention_mask=True,
|
| 44 |
+
return_token_type_ids=False)
|
| 45 |
+
return prot_tokens, text_tokens
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
class Stage1DM(LightningDataModule):
|
| 49 |
+
def __init__(
|
| 50 |
+
self,
|
| 51 |
+
num_workers: int = 0,
|
| 52 |
+
batch_size: int = 256,
|
| 53 |
+
root: str = 'data/',
|
| 54 |
+
args=None,
|
| 55 |
+
):
|
| 56 |
+
super().__init__()
|
| 57 |
+
self.batch_size = batch_size
|
| 58 |
+
self.match_batch_size = args.match_batch_size
|
| 59 |
+
self.num_workers = num_workers
|
| 60 |
+
self.text_max_len = args.text_max_len
|
| 61 |
+
self.prot_max_len = args.prot_max_len
|
| 62 |
+
if root.find('SwissProt') >= 0:
|
| 63 |
+
self.train_dataset = SwissProtDataset(root+'/train_set.jsonl')
|
| 64 |
+
self.val_dataset = SwissProtDataset(root+'/valid_set.jsonl')
|
| 65 |
+
self.test_dataset = SwissProtDataset(root+'/test_set.jsonl')
|
| 66 |
+
self.val_dataset_match = SwissProtDataset(root+'/valid_set.jsonl').shuffle()
|
| 67 |
+
self.test_dataset_match = SwissProtDataset(root+'/test_set.jsonl').shuffle()
|
| 68 |
+
elif root.find('PDBDataset') >= 0:
|
| 69 |
+
self.train_dataset = PDBAbstractDataset(root, 'train.txt')
|
| 70 |
+
self.val_dataset = PDBAbstractDataset(root, 'val.txt')
|
| 71 |
+
self.test_dataset = PDBAbstractDataset(root, 'test.txt')
|
| 72 |
+
self.val_dataset_match = PDBAbstractDataset(root, 'val.txt').shuffle()
|
| 73 |
+
self.test_dataset_match = PDBAbstractDataset(root, 'test.txt').shuffle()
|
| 74 |
+
elif root.find('OntoProtein') >= 0:
|
| 75 |
+
self.train_dataset = OntoProteinDataset(root+'/train.txt')
|
| 76 |
+
self.val_dataset = OntoProteinDataset(root+'/valid.txt')
|
| 77 |
+
self.test_dataset = OntoProteinDataset(root+'/test.txt')
|
| 78 |
+
self.val_dataset_match = OntoProteinDataset(root+'/valid.txt').shuffle()
|
| 79 |
+
self.test_dataset_match = OntoProteinDataset(root+'/test.txt').shuffle()
|
| 80 |
+
else:
|
| 81 |
+
raise NotImplementedError
|
| 82 |
+
|
| 83 |
+
self.tokenizer = None
|
| 84 |
+
self.prot_tokenizer = None
|
| 85 |
+
self.prot_aug = args.prot_aug
|
| 86 |
+
|
| 87 |
+
def init_tokenizer(self, tokenizer, prot_tokenizer):
|
| 88 |
+
self.tokenizer = tokenizer
|
| 89 |
+
self.prot_tokenizer = prot_tokenizer
|
| 90 |
+
|
| 91 |
+
def train_dataloader(self):
|
| 92 |
+
loader = DataLoader(
|
| 93 |
+
self.train_dataset,
|
| 94 |
+
batch_size=self.batch_size,
|
| 95 |
+
shuffle=True,
|
| 96 |
+
num_workers=self.num_workers,
|
| 97 |
+
pin_memory=False,
|
| 98 |
+
drop_last=True,
|
| 99 |
+
# persistent_workers=True,
|
| 100 |
+
collate_fn=Stage1Collater(self.tokenizer, self.prot_tokenizer, self.text_max_len, self.prot_max_len, self.prot_aug)
|
| 101 |
+
)
|
| 102 |
+
return loader
|
| 103 |
+
|
| 104 |
+
def val_dataloader(self):
|
| 105 |
+
loader = DataLoader(
|
| 106 |
+
self.val_dataset,
|
| 107 |
+
batch_size=self.batch_size,
|
| 108 |
+
shuffle=False,
|
| 109 |
+
num_workers=self.num_workers,
|
| 110 |
+
pin_memory=False,
|
| 111 |
+
drop_last=False,
|
| 112 |
+
# persistent_workers=True,
|
| 113 |
+
collate_fn=Stage1Collater(self.tokenizer, self.prot_tokenizer, self.text_max_len, self.prot_max_len, self.prot_aug)
|
| 114 |
+
)
|
| 115 |
+
return loader
|
| 116 |
+
|
| 117 |
+
def match_dataloader(self):
|
| 118 |
+
val_match_loader = DataLoader(self.val_dataset_match,
|
| 119 |
+
batch_size=self.match_batch_size,
|
| 120 |
+
shuffle=False,
|
| 121 |
+
num_workers=self.num_workers,
|
| 122 |
+
pin_memory=False,
|
| 123 |
+
drop_last=False,
|
| 124 |
+
# persistent_workers=True,
|
| 125 |
+
collate_fn=Stage1Collater(self.tokenizer, self.prot_tokenizer, self.text_max_len, self.prot_max_len, self.prot_aug))
|
| 126 |
+
test_match_loader = DataLoader(self.test_dataset_match,
|
| 127 |
+
batch_size=self.match_batch_size,
|
| 128 |
+
shuffle=False,
|
| 129 |
+
num_workers=self.num_workers,
|
| 130 |
+
pin_memory=False,
|
| 131 |
+
drop_last=False,
|
| 132 |
+
# persistent_workers=True,
|
| 133 |
+
collate_fn=Stage1Collater(self.tokenizer, self.prot_tokenizer, self.text_max_len, self.prot_max_len, self.prot_aug))
|
| 134 |
+
return val_match_loader, test_match_loader
|
| 135 |
+
|
| 136 |
+
def add_model_specific_args(parent_parser):
|
| 137 |
+
parser = parent_parser.add_argument_group("Data module")
|
| 138 |
+
parser.add_argument('--num_workers', type=int, default=4)
|
| 139 |
+
parser.add_argument('--batch_size', type=int, default=32)
|
| 140 |
+
parser.add_argument('--match_batch_size', type=int, default=64)
|
| 141 |
+
parser.add_argument('--root', type=str, default='data/SwissProtV3')
|
| 142 |
+
parser.add_argument('--text_max_len', type=int, default=128)
|
| 143 |
+
parser.add_argument('--prot_max_len', type=int, default=1024)
|
| 144 |
+
parser.add_argument('--prot_aug', type=str, default='None')
|
| 145 |
+
return parent_parser
|
| 146 |
+
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
class Stage1MixDM(LightningDataModule):
    def __init__(
        self,
        num_workers: int = 0,
        batch_size: int = 256,
        root: str = 'data/',
        args=None,
    ):
        super().__init__()
        self.batch_size = batch_size
        self.match_batch_size = args.match_batch_size
        self.num_workers = num_workers
        self.text_max_len = args.text_max_len
        self.prot_max_len = args.prot_max_len
        assert args.mix_dataset

        train_dataset1 = SwissProtDataset(root+'/SwissProtV3/train_set.jsonl')
        train_dataset2 = OntoProteinDataset(root+'/OntoProteinDatasetV2/train.txt')
        # New PDBAbstract training set (subset must be 'train.txt')
        train_dataset3 = PDBAbstractDataset(root + '/PDBDataset/', subset='train.txt')

        # self.train_dataset = ConcatDataset([train_dataset1, train_dataset2, train_dataset3])
        self.train_dataset = ConcatDataset([train_dataset1, train_dataset2])

        self.swiss_val_dataset = SwissProtDataset(root+'/SwissProtV3/valid_set.jsonl')
        self.onto_val_dataset = OntoProteinDataset(root+'/OntoProteinDatasetV2/valid.txt')
        self.pdb_val_dataset = PDBAbstractDataset(root + '/PDBDataset/', subset='val.txt')

        self.swiss_test_dataset = SwissProtDataset(root+'/SwissProtV3/test_set.jsonl')
        self.onto_test_dataset = OntoProteinDataset(root+'/OntoProteinDatasetV2/test.txt')
        self.pdb_test_dataset = PDBAbstractDataset(root + '/PDBDataset/', subset='test.txt')

        self.swiss_val_dataset_match = SwissProtDataset(root+'/SwissProtV3/valid_set.jsonl').shuffle()
        self.onto_val_dataset_match = OntoProteinDataset(root+'/OntoProteinDatasetV2/valid.txt').shuffle()
        self.pdb_val_dataset_match = PDBAbstractDataset(root + '/PDBDataset/', subset='val.txt').shuffle()

        self.swiss_test_dataset_match = SwissProtDataset(root+'/SwissProtV3/test_set.jsonl').shuffle()
        self.onto_test_dataset_match = OntoProteinDataset(root+'/OntoProteinDatasetV2/test.txt').shuffle()
        self.pdb_test_dataset_match = PDBAbstractDataset(root + '/PDBDataset/', subset='test.txt').shuffle()

        self.tokenizer = None
        self.prot_tokenizer = None
        self.prot_aug = args.prot_aug

    def init_tokenizer(self, tokenizer, prot_tokenizer):
        self.tokenizer = tokenizer
        self.prot_tokenizer = prot_tokenizer

    def train_dataloader(self):
        loader = DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers,
            pin_memory=False,
            drop_last=True,
            collate_fn=Stage1Collater(self.tokenizer, self.prot_tokenizer, self.text_max_len, self.prot_max_len, self.prot_aug)
        )
        return loader

    def val_dataloader(self):
        loader1 = DataLoader(
            self.swiss_val_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            pin_memory=False,
            drop_last=False,
            collate_fn=Stage1Collater(self.tokenizer, self.prot_tokenizer, self.text_max_len, self.prot_max_len, self.prot_aug)
        )
        loader2 = DataLoader(
            self.onto_val_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            pin_memory=False,
            drop_last=False,
            collate_fn=Stage1Collater(self.tokenizer, self.prot_tokenizer, self.text_max_len, self.prot_max_len, self.prot_aug)
        )
        # New PDB validation loader
        loader3 = DataLoader(
            self.pdb_val_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            pin_memory=False,
            drop_last=False,
            collate_fn=Stage1Collater(self.tokenizer, self.prot_tokenizer, self.text_max_len, self.prot_max_len, self.prot_aug)
        )

        return [loader1, loader2, loader3]

    def swiss_match_dataloader(self):
        val_match_loader = DataLoader(self.swiss_val_dataset_match,
                                      batch_size=self.match_batch_size,
                                      shuffle=False,
                                      num_workers=self.num_workers,
                                      pin_memory=False,
                                      drop_last=False,
                                      collate_fn=Stage1Collater(self.tokenizer, self.prot_tokenizer, self.text_max_len, self.prot_max_len, self.prot_aug))
        test_match_loader = DataLoader(self.swiss_test_dataset_match,
                                       batch_size=self.match_batch_size,
                                       shuffle=False,
                                       num_workers=self.num_workers,
                                       pin_memory=False,
                                       drop_last=False,
                                       collate_fn=Stage1Collater(self.tokenizer, self.prot_tokenizer, self.text_max_len, self.prot_max_len, self.prot_aug))
        return val_match_loader, test_match_loader

    def onto_match_dataloader(self):
        val_match_loader = DataLoader(self.onto_val_dataset_match,
                                      batch_size=self.match_batch_size,
                                      shuffle=False,
                                      num_workers=self.num_workers,
                                      pin_memory=False,
                                      drop_last=False,
                                      collate_fn=Stage1Collater(self.tokenizer, self.prot_tokenizer, self.text_max_len, self.prot_max_len, self.prot_aug))
        test_match_loader = DataLoader(self.onto_test_dataset_match,
                                       batch_size=self.match_batch_size,
                                       shuffle=False,
                                       num_workers=self.num_workers,
                                       pin_memory=False,
                                       drop_last=False,
                                       collate_fn=Stage1Collater(self.tokenizer, self.prot_tokenizer, self.text_max_len, self.prot_max_len, self.prot_aug))
        return val_match_loader, test_match_loader

    def pdb_match_dataloader(self):
        val_match_loader = DataLoader(
            self.pdb_val_dataset_match,
            batch_size=self.match_batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            pin_memory=False,
            drop_last=False,
            collate_fn=Stage1Collater(self.tokenizer, self.prot_tokenizer, self.text_max_len, self.prot_max_len, self.prot_aug)
        )
        test_match_loader = DataLoader(
            self.pdb_test_dataset_match,
            batch_size=self.match_batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            pin_memory=False,
            drop_last=False,
            collate_fn=Stage1Collater(self.tokenizer, self.prot_tokenizer, self.text_max_len, self.prot_max_len, self.prot_aug)
        )
        return val_match_loader, test_match_loader

    def add_model_specific_args(parent_parser):
        parser = parent_parser.add_argument_group("Data module")
        parser.add_argument('--num_workers', type=int, default=4)
        parser.add_argument('--batch_size', type=int, default=32)
        parser.add_argument('--match_batch_size', type=int, default=64)
        parser.add_argument('--root', type=str, default='data')
        parser.add_argument('--text_max_len', type=int, default=128)
        parser.add_argument('--prot_max_len', type=int, default=1024)
        parser.add_argument('--prot_aug', type=str, default='None')
        return parent_parser

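# Note (editor's sketch): PyTorch Lightning treats the list returned by
# val_dataloader() as separate validation loaders, and validation_step then
# receives a dataloader_idx argument identifying the source, in the order
# returned above:
#   0 -> Swiss-Prot, 1 -> OntoProtein, 2 -> PDB abstracts
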
# Previous SwissProtDataset implementation, kept for reference: it read each
# JSONL line as a bare (protein, text) pair instead of a keyed record.
# class SwissProtDataset(Dataset):
#     def __init__(self, data_path, prompt='Swiss-Prot description: ', return_prompt=False):
#         super(SwissProtDataset, self).__init__()
#         self.data_path = data_path
#
#         ## load data
#         with open(data_path, 'r') as f:
#             lines = f.readlines()
#             lines = [line.strip() for line in lines]
#             self.data_list = [json.loads(line) for line in lines]
#
#         ## preprocessing
#         self.data_list = [(p, t.strip() + '\n') for p, t in self.data_list]
#
#         self.text2id = {}
#         for prot_seq, text_seq in self.data_list:
#             if text_seq not in self.text2id:
#                 self.text2id[text_seq] = len(self.text2id)
#
#         self.prompt = prompt
#         self.return_prompt = return_prompt
#
#     def shuffle(self):
#         random.shuffle(self.data_list)
#         return self
#
#     def len(self,):
#         return len(self)
#
#     def get(self, idx):
#         return self.__getitem__(idx)
#
#     def __len__(self):
#         return len(self.data_list)
#
#     def __getitem__(self, index):
#         prot_seq, text_seq = self.data_list[index]
#         if self.return_prompt:
#             return prot_seq, self.prompt, text_seq, index
#         return prot_seq, text_seq, index

class SwissProtDataset(Dataset):
    def __init__(self, data_path, prompt='Swiss-Prot description: ', return_prompt=False):
        super(SwissProtDataset, self).__init__()
        self.data_path = data_path
        self.prompt = prompt
        self.return_prompt = return_prompt

        # Load and preprocess the data
        self.data_list = self._load_and_preprocess(data_path)
        self.text2id = self._build_text_vocab()

    def _load_and_preprocess(self, data_path):
        """Load and preprocess the JSONL file."""
        data_list = []
        with open(data_path, 'r', encoding='utf-8') as f:
            for line in f:
                try:
                    item = json.loads(line.strip())
                    # Make sure the required fields are present
                    if 'protein' in item and 'text' in item:
                        prot_seq = item['protein']
                        text_seq = item['text'].strip() + '\n'  # append a trailing newline
                        data_list.append((prot_seq, text_seq))
                    else:
                        print(f"Warning: skipping line with missing fields: {line[:50]}...")
                except json.JSONDecodeError:
                    print(f"Warning: skipping invalid JSON line: {line[:50]}...")
        return data_list

    def _build_text_vocab(self):
        """Build the text-to-ID mapping."""
        text2id = {}
        for _, text_seq in self.data_list:
            if text_seq not in text2id:
                text2id[text_seq] = len(text2id)
        return text2id

    def shuffle(self):
        """Shuffle the dataset in place."""
        random.shuffle(self.data_list)
        return self

    def __len__(self):
        return len(self.data_list)

    def __getitem__(self, index):
        prot_seq, text_seq = self.data_list[index]
        if self.return_prompt:
            return prot_seq, self.prompt, text_seq, index
        return prot_seq, text_seq, index

    # Convenience accessors
    def get_protein_sequence(self, index):
        """Return the protein sequence at the given index."""
        return self.data_list[index][0]

    def get_text_description(self, index):
        """Return the text description at the given index."""
        return self.data_list[index][1]

    def get_text_id(self, text_seq):
        """Return the ID of a text description, or -1 if unseen."""
        return self.text2id.get(text_seq, -1)

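# Expected input format, as a sketch inferred from _load_and_preprocess above
# (the field names 'protein' and 'text' are the ones the loader checks for;
# the example values are hypothetical):
#
#   {"protein": "MKTAYIAKQR...", "text": "Catalyzes the hydrolysis of ..."}
#
# One JSON object per line; malformed or incomplete lines are skipped with a warning.
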
class PDBAbstractDataset(Dataset):
    def __init__(self, root_path, subset, prompt='ABSTRACT: ', return_prompt=False):
        super(PDBAbstractDataset, self).__init__()
        self.data_path = Path(root_path) / subset
        self.abstract_path = Path(root_path) / 'abstract.json'

        ## load dataset
        with open(self.abstract_path, 'r') as f:
            abstract_data = json.load(f)
        abstract_data_dict = {line['pdb_id']: line['caption'] for line in abstract_data}

        with open(self.data_path, 'r') as f:
            lines = f.readlines()
            pdb2seq = [line.strip().split('\t') for line in lines]

        ## process dataset
        data_list = []
        for pdb_id, seq in pdb2seq:
            abstract = abstract_data_dict[pdb_id]
            abstract = abstract.replace('\n', ' ').strip() + '\n'
            data_list.append((seq, abstract))
        self.data_list = data_list
        self.prompt = prompt
        self.return_prompt = return_prompt

    def shuffle(self):
        random.shuffle(self.data_list)
        return self

    def len(self,):
        return len(self)

    def get(self, idx):
        return self.__getitem__(idx)

    def __len__(self):
        return len(self.data_list)

    def __getitem__(self, index):
        seq, abstract = self.data_list[index]
        if self.return_prompt:
            return seq, self.prompt, abstract, index
        return seq, abstract, index

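# Data layout assumed by PDBAbstractDataset (inferred from the parsing code
# above; the file names match the repo's data/PDBDataset directory):
#
#   abstract.json                  -- JSON list: [{"pdb_id": "...", "caption": "..."}, ...]
#   train.txt / val.txt / test.txt -- tab-separated lines: "<pdb_id>\t<amino-acid sequence>"
#
# A pdb_id present in a split file but absent from abstract.json would raise
# a KeyError in __init__.
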
class OntoProteinDataset(Dataset):
    def __init__(self, data_path, prompt='Gene Ontology description: ', return_prompt=False):
        super(OntoProteinDataset, self).__init__()
        self.data_path = data_path

        ## load data
        with open(data_path, 'r') as f:
            lines = f.readlines()
            self.data_list = [line.strip().split('\t') for line in lines]

        ## preprocessing
        ## fixme: the "KG: " signal word is disabled for this dataset, although it
        ## was used in previous experiments; flip the condition below to restore it.
        if True:
            self.data_list = [(p, t.strip() + '\n') for p, t in self.data_list]
        else:
            self.data_list = [(p, "KG: " + t.strip() + '\n') for p, t in self.data_list]
        self.prompt = prompt
        self.return_prompt = return_prompt

    def shuffle(self):
        random.shuffle(self.data_list)
        return self

    def __len__(self):
        return len(self.data_list)

    def __getitem__(self, index):
        prot_seq, text_seq = self.data_list[index]
        if self.return_prompt:
            return prot_seq, self.prompt, text_seq, index
        return prot_seq, text_seq, index

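# Expected input format (inferred from the split('\t') above; the example
# content is hypothetical): one example per line,
#
#   <amino-acid sequence>\t<Gene Ontology description>
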
if __name__ == '__main__':
    import numpy as np

    ## get statistics for the Swiss-Prot dataset
    if False:
        swiss_train = SwissProtDataset('../data/SwissProtV3/train_set.jsonl')
        swiss_valid = SwissProtDataset('../data/SwissProtV3/valid_set.jsonl')
        swiss_test = SwissProtDataset('../data/SwissProtV3/test_set.jsonl')
        print(len(swiss_train), len(swiss_valid), len(swiss_test))

        ## get amino acid statistics
        aa_lens = np.asarray([len(seq) for seq, _ in swiss_train.data_list])
        print('Train dataset mean: ', np.mean(aa_lens), 'min: ', aa_lens.min(), 'max: ', aa_lens.max())
        aa_lens = np.asarray([len(seq) for seq, _ in swiss_valid.data_list])
        print('Valid dataset mean: ', np.mean(aa_lens), 'min: ', aa_lens.min(), 'max: ', aa_lens.max())
        aa_lens = np.asarray([len(seq) for seq, _ in swiss_test.data_list])
        print('Test dataset mean: ', np.mean(aa_lens), 'min: ', aa_lens.min(), 'max: ', aa_lens.max())

        ## get text statistics
        text_lens = np.asarray([len(seq.split()) for _, seq in swiss_train.data_list])
        print('Train dataset mean: ', np.mean(text_lens), 'min: ', text_lens.min(), 'max: ', text_lens.max())
        text_lens = np.asarray([len(seq.split()) for _, seq in swiss_valid.data_list])
        print('Valid dataset mean: ', np.mean(text_lens), 'min: ', text_lens.min(), 'max: ', text_lens.max())
        text_lens = np.asarray([len(seq.split()) for _, seq in swiss_test.data_list])
        print('Test dataset mean: ', np.mean(text_lens), 'min: ', text_lens.min(), 'max: ', text_lens.max())
        print('---------------------------')

    ## get statistics for the OntoProtein dataset
    onto_train = OntoProteinDataset('../data/OntoProteinDatasetV2/train.txt')
    onto_valid = OntoProteinDataset('../data/OntoProteinDatasetV2/valid.txt')
    onto_test = OntoProteinDataset('../data/OntoProteinDatasetV2/test.txt')
    print(len(onto_train), len(onto_valid), len(onto_test))

    ## get amino acid statistics
    aa_lens = np.asarray([len(seq) for seq, _ in onto_train.data_list])
    print('Train dataset mean: ', np.mean(aa_lens), 'min: ', aa_lens.min(), 'max: ', aa_lens.max())
    aa_lens = np.asarray([len(seq) for seq, _ in onto_valid.data_list])
    print('Valid dataset mean: ', np.mean(aa_lens), 'min: ', aa_lens.min(), 'max: ', aa_lens.max())
    aa_lens = np.asarray([len(seq) for seq, _ in onto_test.data_list])
    print('Test dataset mean: ', np.mean(aa_lens), 'min: ', aa_lens.min(), 'max: ', aa_lens.max())

    ## get text statistics
    text_lens = np.asarray([len(seq.split()) for _, seq in onto_train.data_list])
    print('Train dataset mean: ', np.mean(text_lens), 'min: ', text_lens.min(), 'max: ', text_lens.max())
    text_lens = np.asarray([len(seq.split()) for _, seq in onto_valid.data_list])
    print('Valid dataset mean: ', np.mean(text_lens), 'min: ', text_lens.min(), 'max: ', text_lens.max())
    text_lens = np.asarray([len(seq.split()) for _, seq in onto_test.data_list])
    print('Test dataset mean: ', np.mean(text_lens), 'min: ', text_lens.min(), 'max: ', text_lens.max())
    print('---------------------------')
data_provider/stage2_dm.py
ADDED
@@ -0,0 +1,386 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import json
from pytorch_lightning import LightningDataModule
from torch.utils.data import DataLoader, ConcatDataset
from data_provider.stage1_dm import SwissProtDataset, OntoProteinDataset
from data_provider.stage3_dm import DeepLocBinaryDataset, AlpacaDataset, MolInstructionDataset, DeepLocMultiDataset, Deepsol, DeepsoluE, Protsolm, FLIP_GB1, FLIP_AAV
from data_provider.bindingdb import BindingDB
from data_provider.metalIonbinding import MetallonBinding
from data_provider.go import GO_BP, EC
from data_provider.production import Antibiotic_Resistance, Thermostability, Material, Clone
from data_provider.mutation import TAPE_Stability, TAPE_Fluorescence


class Stage2Collater(object):
    def __init__(self, tokenizer, prot_tokenizer, text_max_len, prot_max_len):
        self.tokenizer = tokenizer
        self.prot_tokenizer = prot_tokenizer
        self.text_max_len = text_max_len
        self.prot_max_len = prot_max_len

    def __call__(self, batch):
        prot_seqs, prompt_seqs, text_seqs, _ = zip(*batch)
        prot_tokens = self.prot_tokenizer(prot_seqs,
                                          truncation=True,
                                          padding='max_length',
                                          max_length=self.prot_max_len,
                                          return_tensors="pt",
                                          return_attention_mask=True,
                                          return_token_type_ids=False)
        if False:
            self.tokenizer.padding_side = 'left'
            prompt_tokens = self.tokenizer(prompt_seqs,
                                           truncation=True,
                                           padding='longest',
                                           add_special_tokens=True,
                                           max_length=self.text_max_len,
                                           return_tensors='pt',
                                           return_attention_mask=True,
                                           return_token_type_ids=False)
            self.tokenizer.padding_side = 'right'
            text_tokens = self.tokenizer(text_seqs,
                                         truncation=True,
                                         padding='max_length',
                                         add_special_tokens=True,
                                         max_length=self.text_max_len,
                                         return_tensors='pt',
                                         return_attention_mask=True,
                                         return_token_type_ids=False)
        else:
            self.tokenizer.padding_side = 'left'
            prompt_tokens = self.tokenizer(prompt_seqs,
                                           truncation=True,
                                           padding='longest',
                                           add_special_tokens=True,
                                           max_length=self.text_max_len,
                                           return_tensors='pt',
                                           return_attention_mask=True,
                                           return_token_type_ids=False)
            max_prompt_len = int(prompt_tokens.attention_mask.sum(dim=1).max())
            input_pair = [[p, t] for p, t in zip(prompt_seqs, text_seqs)]
            input_tokens = self.tokenizer(input_pair,
                                          truncation=True,
                                          padding='max_length',
                                          add_special_tokens=True,
                                          max_length=self.text_max_len + max_prompt_len,
                                          return_tensors='pt',
                                          return_attention_mask=True,
                                          return_token_type_ids=True)
        return prot_tokens, input_tokens

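# What the collater produces, in brief: each prompt/text pair is tokenized as
# a two-sequence input, so for tokenizers that support sentence pairs the
# token_type_ids mark prompt positions (type 0) versus answer positions
# (type 1), which lets a training loop mask the prompt out of the LM loss.
# A toy illustration (hypothetical batch; any HuggingFace-style tokenizer
# pair is assumed):
#
#   collater = Stage2Collater(text_tokenizer, prot_tokenizer, 128, 1024)
#   batch = [("MKTAYIAKQR", "Swiss-Prot description: ", "A kinase.\n", 0)]
#   prot_tokens, input_tokens = collater(batch)
#   # input_tokens.token_type_ids is 0 over the prompt span, 1 over the text span
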
class InferenceCollater(object):
    def __init__(self, tokenizer, prot_tokenizer, text_max_len, prot_max_len):
        self.tokenizer = tokenizer
        self.prot_tokenizer = prot_tokenizer
        self.text_max_len = text_max_len
        self.prot_max_len = prot_max_len

    def __call__(self, batch):
        prot_seqs, prompt_seqs, text_seqs, indices = zip(*batch)

        self.tokenizer.padding_side = 'right'
        prompt_tokens = self.tokenizer(prompt_seqs,
                                       truncation=True,
                                       padding='longest',
                                       add_special_tokens=False,
                                       max_length=self.text_max_len,
                                       return_tensors='pt',
                                       return_attention_mask=True,
                                       return_token_type_ids=False)
        prot_tokens = self.prot_tokenizer(prot_seqs,
                                          truncation=True,
                                          padding='max_length',
                                          max_length=self.prot_max_len,
                                          return_tensors="pt",
                                          return_attention_mask=True,
                                          return_token_type_ids=False)
        target_dict = {'targets': text_seqs, 'indices': indices}
        return prot_tokens, prompt_tokens, target_dict

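# Inference-time sketch (hypothetical; `model` stands for whatever LM wrapper
# consumes these batches): the collater returns tokenized prompts and protein
# sequences plus the raw target strings, so generated text can be compared
# against target_dict['targets'] downstream.
#
#   prot_tokens, prompt_tokens, target_dict = collater(batch)
#   # generated = model.generate(prot_tokens, prompt_tokens)
#   # score(generated, target_dict['targets'])
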
class Stage2DM(LightningDataModule):
    def __init__(
        self,
        root: str = 'data/',
        args=None,
    ):
        super().__init__()
        self.args = args
        self.batch_size = args.batch_size
        self.inference_batch_size = args.inference_batch_size
        self.num_workers = args.num_workers
        self.text_max_len = args.text_max_len
        self.prot_max_len = args.prot_max_len
        self.prompt = args.prompt

        # self.train_dataset = AlpacaDataset('/nas/shared/kilab/wangyujia/pretrain_data/instruct/alpaca-gpt4-train.jsonl', prompt=self.prompt, return_prompt=True)
        # self.val_dataset = AlpacaDataset('/nas/shared/kilab/wangyujia/pretrain_data/instruct/alpaca-gpt4-valid.jsonl', prompt=self.prompt, return_prompt=True)
        # self.test_dataset = AlpacaDataset('/nas/shared/kilab/wangyujia/pretrain_data/instruct/alpaca-gpt4-test.jsonl', prompt=self.prompt, return_prompt=True)

        # self.train_dataset = MolInstructionDataset('/oss/wangyujia/pretrain-bench/mol-instruction/train.jsonl', prompt='', return_prompt=True)
        # self.val_dataset = MolInstructionDataset('/oss/wangyujia/pretrain-bench/mol-instruction/train.jsonl', prompt='', return_prompt=True)
        # self.test_dataset = MolInstructionDataset('/oss/wangyujia/pretrain-bench/mol-instruction/train.jsonl', prompt='', return_prompt=True)

        if self.args.dataset == 'deeplocbinary':
            self.train_dataset = DeepLocBinaryDataset('/oss/wangyujia/pretrain-bench/locate/deeplocbinary/train.csv', prompt=self.prompt, return_prompt=True)
            self.val_dataset = DeepLocBinaryDataset('/oss/wangyujia/pretrain-bench/locate/deeplocbinary/valid.csv', prompt=self.prompt, return_prompt=True)
            self.test_dataset = DeepLocBinaryDataset('/oss/wangyujia/pretrain-bench/locate/deeplocbinary/test.csv', prompt=self.prompt, return_prompt=True)
        elif self.args.dataset == 'deeplocmulti':
            self.train_dataset = DeepLocMultiDataset('/oss/wangyujia/pretrain-bench/locate/deeplocmulti/train.csv', prompt=self.prompt, return_prompt=True)
            self.val_dataset = DeepLocMultiDataset('/oss/wangyujia/pretrain-bench/locate/deeplocmulti/valid.csv', prompt=self.prompt, return_prompt=True)
            self.test_dataset = DeepLocMultiDataset('/oss/wangyujia/pretrain-bench/locate/deeplocmulti/test.csv', prompt=self.prompt, return_prompt=True)
        elif self.args.dataset == 'deepsol':
            self.train_dataset = Deepsol('/nas/shared/kilab/wangyujia/sft_data/deepsol/clean/train.csv', prompt=self.prompt, return_prompt=True)
            self.val_dataset = Deepsol('/nas/shared/kilab/wangyujia/sft_data/deepsol/clean/valid.csv', prompt=self.prompt, return_prompt=True)
            self.test_dataset = Deepsol('/nas/shared/kilab/wangyujia/sft_data/deepsol/clean/test.csv', prompt=self.prompt, return_prompt=True)
        elif self.args.dataset == 'deepsolue':
            # note: no separate validation split; the test split is reused for validation
            self.train_dataset = DeepsoluE('/nas/shared/kilab/wangyujia/sft_data/deepsoluE/clean/train.csv', prompt=self.prompt, return_prompt=True)
            self.val_dataset = DeepsoluE('/nas/shared/kilab/wangyujia/sft_data/deepsoluE/clean/test.csv', prompt=self.prompt, return_prompt=True)
            self.test_dataset = DeepsoluE('/nas/shared/kilab/wangyujia/sft_data/deepsoluE/clean/test.csv', prompt=self.prompt, return_prompt=True)
        elif self.args.dataset == 'protsolm':
            self.train_dataset = Protsolm('/oss/wangyujia/pretrain-bench/solubility/protsolm/train.csv', prompt=self.prompt, return_prompt=True)
            self.val_dataset = Protsolm('/oss/wangyujia/pretrain-bench/solubility/protsolm/valid.csv', prompt=self.prompt, return_prompt=True)
            self.test_dataset = Protsolm('/oss/wangyujia/pretrain-bench/solubility/protsolm/test.csv', prompt=self.prompt, return_prompt=True)
        elif self.args.dataset == 'gb1':
            self.train_dataset = FLIP_GB1('/nas/shared/kilab/wangyujia/sft_data/mutation/gb1/train.csv', prompt=self.prompt, return_prompt=True)
            self.val_dataset = FLIP_GB1('/nas/shared/kilab/wangyujia/sft_data/mutation/gb1/valid.csv', prompt=self.prompt, return_prompt=True)
            self.test_dataset = FLIP_GB1('/nas/shared/kilab/wangyujia/sft_data/mutation/gb1/test.csv', prompt=self.prompt, return_prompt=True)
        elif self.args.dataset == 'gb1_low':
            self.train_dataset = FLIP_GB1('/nas/shared/kilab/wangyujia/sft_data/mutation/gb1/clean/train.csv', prompt=self.prompt, return_prompt=True)
            self.val_dataset = FLIP_GB1('/nas/shared/kilab/wangyujia/sft_data/mutation/gb1/clean/valid.csv', prompt=self.prompt, return_prompt=True)
            self.test_dataset = FLIP_GB1('/nas/shared/kilab/wangyujia/sft_data/mutation/gb1/clean/test.csv', prompt=self.prompt, return_prompt=True)
        elif self.args.dataset == 'aav':
            self.train_dataset = FLIP_AAV('/nas/shared/kilab/wangyujia/sft_data/mutation/aav/train.csv', prompt=self.prompt, return_prompt=True)
            self.val_dataset = FLIP_AAV('/nas/shared/kilab/wangyujia/sft_data/mutation/aav/valid.csv', prompt=self.prompt, return_prompt=True)
            self.test_dataset = FLIP_AAV('/nas/shared/kilab/wangyujia/sft_data/mutation/aav/test.csv', prompt=self.prompt, return_prompt=True)
        elif self.args.dataset == 'bindingdb':
            self.train_dataset = BindingDB('/nas/shared/kilab/wangyujia/sft_data/bindingdb/clean/train_small.csv', prompt=self.prompt, return_prompt=True)
            self.val_dataset = BindingDB('/nas/shared/kilab/wangyujia/sft_data/bindingdb/clean/valid_small.csv', prompt=self.prompt, return_prompt=True)
            self.test_dataset = BindingDB('/nas/shared/kilab/wangyujia/sft_data/bindingdb/clean/test_small.csv', prompt=self.prompt, return_prompt=True)
        elif self.args.dataset == 'metallonbinding':
            self.train_dataset = MetallonBinding('/nas/shared/kilab/wangyujia/sft_data/MetalIonBinding/train.csv', prompt=self.prompt, return_prompt=True)
            self.val_dataset = MetallonBinding('/nas/shared/kilab/wangyujia/sft_data/MetalIonBinding/valid.csv', prompt=self.prompt, return_prompt=True)
            self.test_dataset = MetallonBinding('/nas/shared/kilab/wangyujia/sft_data/MetalIonBinding/test.csv', prompt=self.prompt, return_prompt=True)
        elif self.args.dataset == 'bp':
            self.train_dataset = GO_BP('/nas/shared/kilab/wangyujia/sft_data/go/clean/BP_train.csv', prompt=self.prompt, return_prompt=True)
            self.val_dataset = GO_BP('/nas/shared/kilab/wangyujia/sft_data/go/clean/BP_valid.csv', prompt=self.prompt, return_prompt=True)
            self.test_dataset = GO_BP('/nas/shared/kilab/wangyujia/sft_data/go/clean/BP_test.csv', prompt=self.prompt, return_prompt=True)
        elif self.args.dataset == 'ec':
            self.train_dataset = EC('/nas/shared/kilab/wangyujia/sft_data/EC/train.csv', prompt=self.prompt, return_prompt=True)
            self.val_dataset = EC('/nas/shared/kilab/wangyujia/sft_data/EC/valid.csv', prompt=self.prompt, return_prompt=True)
            self.test_dataset = EC('/nas/shared/kilab/wangyujia/sft_data/EC/test.csv', prompt=self.prompt, return_prompt=True)
        elif self.args.dataset == 'antibiotic':
            # note: no separate validation split; the test split is reused for validation
            self.train_dataset = Antibiotic_Resistance('/nas/shared/kilab/wangyujia/sft_data/production/Antibiotic_Resistance/train.csv', prompt=self.prompt, return_prompt=True)
            self.val_dataset = Antibiotic_Resistance('/nas/shared/kilab/wangyujia/sft_data/production/Antibiotic_Resistance/test.csv', prompt=self.prompt, return_prompt=True)
            self.test_dataset = Antibiotic_Resistance('/nas/shared/kilab/wangyujia/sft_data/production/Antibiotic_Resistance/test.csv', prompt=self.prompt, return_prompt=True)
        elif self.args.dataset == 'thermostability':
            self.train_dataset = Thermostability('/nas/shared/kilab/wangyujia/sft_data/production/Thermostability/train.csv', prompt=self.prompt, return_prompt=True)
            self.val_dataset = Thermostability('/nas/shared/kilab/wangyujia/sft_data/production/Thermostability/valid.csv', prompt=self.prompt, return_prompt=True)
            self.test_dataset = Thermostability('/nas/shared/kilab/wangyujia/sft_data/production/Thermostability/test.csv', prompt=self.prompt, return_prompt=True)
        elif self.args.dataset == 'material':
            self.train_dataset = Material('/oss/wangyujia/ProtT3/ProtT3/data/sft/dataset/material_production/train.csv', prompt=self.prompt, return_prompt=True)
            self.val_dataset = Material('/oss/wangyujia/ProtT3/ProtT3/data/sft/dataset/material_production/val.csv', prompt=self.prompt, return_prompt=True)
            self.test_dataset = Material('/oss/wangyujia/ProtT3/ProtT3/data/sft/dataset/material_production/test.csv', prompt=self.prompt, return_prompt=True)
        elif self.args.dataset == 'clone':
            self.train_dataset = Clone('/oss/wangyujia/ProtT3/ProtT3/data/sft/dataset/cloning_clf/train.csv', prompt=self.prompt, return_prompt=True)
            self.val_dataset = Clone('/oss/wangyujia/ProtT3/ProtT3/data/sft/dataset/cloning_clf/val.csv', prompt=self.prompt, return_prompt=True)
            self.test_dataset = Clone('/oss/wangyujia/ProtT3/ProtT3/data/sft/dataset/cloning_clf/test.csv', prompt=self.prompt, return_prompt=True)
        elif self.args.dataset == 'stability':
            self.train_dataset = TAPE_Stability('/oss/wangyujia/ProtT3/ProtT3/data/sft/dataset/TAPE_Stability/train.csv', prompt=self.prompt, return_prompt=True)
            self.val_dataset = TAPE_Stability('/oss/wangyujia/ProtT3/ProtT3/data/sft/dataset/TAPE_Stability/val.csv', prompt=self.prompt, return_prompt=True)
            self.test_dataset = TAPE_Stability('/oss/wangyujia/ProtT3/ProtT3/data/sft/dataset/TAPE_Stability/test.csv', prompt=self.prompt, return_prompt=True)
        elif self.args.dataset == 'fluorescence':
            self.train_dataset = TAPE_Fluorescence('/oss/wangyujia/ProtT3/ProtT3/data/sft/dataset/TAPE_Fluorescence/fluorescence_prediction_train.csv', prompt=self.prompt, return_prompt=True)
            self.val_dataset = TAPE_Fluorescence('/oss/wangyujia/ProtT3/ProtT3/data/sft/dataset/TAPE_Fluorescence/fluorescence_prediction_valid.csv', prompt=self.prompt, return_prompt=True)
            self.test_dataset = TAPE_Fluorescence('/oss/wangyujia/ProtT3/ProtT3/data/sft/dataset/TAPE_Fluorescence/test.csv', prompt=self.prompt, return_prompt=True)
        elif self.args.dataset == 'empty':
            # note: Empty is not imported above; it must be provided elsewhere in the package
            self.train_dataset = Empty(prompt=self.prompt, return_prompt=True)
            self.val_dataset = Empty(prompt=self.prompt, return_prompt=True)
            self.test_dataset = Empty(prompt=self.prompt, return_prompt=True)
        elif self.args.dataset == 'swiss-prot':
            self.train_dataset = SwissProtDataset(root+'/SwissProtV3/train_set.jsonl', prompt='Swiss-Prot description: ', return_prompt=True)
            self.val_dataset = SwissProtDataset(root+'/SwissProtV3/valid_set.jsonl', prompt='Swiss-Prot description: ', return_prompt=True)
            self.test_dataset = SwissProtDataset(root+'/SwissProtV3/test_set.jsonl', prompt='Swiss-Prot description: ', return_prompt=True)

        self.tokenizer = None
        self.prot_tokenizer = None

    def init_tokenizer(self, tokenizer, prot_tokenizer):
        self.tokenizer = tokenizer
        self.prot_tokenizer = prot_tokenizer

    def train_dataloader(self):
        loader = DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers,
            pin_memory=False,
            drop_last=True,
            persistent_workers=False,
            collate_fn=InferenceCollater(self.tokenizer, self.prot_tokenizer, self.text_max_len, self.prot_max_len),
        )
        return loader

    def val_dataloader(self):
        val_loader = DataLoader(
            self.val_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            pin_memory=False,
            drop_last=False,
            persistent_workers=False,
            collate_fn=InferenceCollater(self.tokenizer, self.prot_tokenizer, self.text_max_len, self.prot_max_len),
        )
        test_loader = DataLoader(
            self.test_dataset,
            batch_size=self.inference_batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            pin_memory=False,
            drop_last=False,
            persistent_workers=False,
            collate_fn=InferenceCollater(self.tokenizer, self.prot_tokenizer, self.text_max_len, self.prot_max_len),
        )
        return [val_loader, test_loader]

    def add_model_specific_args(parent_parser):
        parser = parent_parser.add_argument_group("Data module")
        parser.add_argument('--num_workers', type=int, default=2)
        parser.add_argument('--batch_size', type=int, default=32)
        parser.add_argument('--inference_batch_size', type=int, default=4)
        parser.add_argument('--root', type=str, default='data')
        parser.add_argument('--text_max_len', type=int, default=2048)
        parser.add_argument('--q_max_len', type=int, default=29)
        parser.add_argument('--a_max_len', type=int, default=36)
        parser.add_argument('--prot_max_len', type=int, default=1024)
        parser.add_argument('--prompt', type=str, default='The protein has the following properties:')
        parser.add_argument('--filter_side_qa', action='store_true', default=False)
        return parent_parser

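# Wiring sketch (hypothetical training script; --dataset is assumed to be
# registered elsewhere): the DM's argument group is merged into the main
# parser, and the parsed args then drive dataset selection in __init__.
#
#   import argparse
#   parser = argparse.ArgumentParser()
#   parser = Stage2DM.add_model_specific_args(parser)
#   args = parser.parse_args()  # plus --dataset handled elsewhere
#   dm = Stage2DM(root=args.root, args=args)
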
class Stage2MixDM(LightningDataModule):
    def __init__(
        self,
        root: str = 'data/',
        args=None,
    ):
        super().__init__()
        self.args = args
        self.batch_size = args.batch_size
        self.inference_batch_size = args.inference_batch_size
        self.num_workers = args.num_workers
        self.text_max_len = args.text_max_len
        self.prot_max_len = args.prot_max_len
        # self.prompt = args.prompt
        assert args.mix_dataset

        train_dataset1 = SwissProtDataset(root+'/SwissProtV3/train_set.jsonl', prompt='Swiss-Prot description: ', return_prompt=True)
        train_dataset2 = OntoProteinDataset(root+'/OntoProteinDatasetV2/train.txt', prompt='Gene Ontology description: ', return_prompt=True)
        self.train_dataset = ConcatDataset([train_dataset1, train_dataset2])
        self.swiss_val_dataset = SwissProtDataset(root+'/SwissProtV3/valid_set.jsonl', prompt='Swiss-Prot description: ', return_prompt=True)
        self.onto_val_dataset = OntoProteinDataset(root+'/OntoProteinDatasetV2/valid.txt', prompt='Gene Ontology description: ', return_prompt=True)
        self.swiss_test_dataset = SwissProtDataset(root+'/SwissProtV3/test_set.jsonl', prompt='Swiss-Prot description: ', return_prompt=True)
        self.onto_test_dataset = OntoProteinDataset(root+'/OntoProteinDatasetV2/test.txt', prompt='Gene Ontology description: ', return_prompt=True)

        self.tokenizer = None
        self.prot_tokenizer = None

    def init_tokenizer(self, tokenizer, prot_tokenizer):
        self.tokenizer = tokenizer
        self.prot_tokenizer = prot_tokenizer

    def train_dataloader(self):
        loader = DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers,
            pin_memory=False,
            drop_last=True,
            persistent_workers=False,
            collate_fn=InferenceCollater(self.tokenizer, self.prot_tokenizer, self.text_max_len, self.prot_max_len),
        )
        return loader

    def val_dataloader(self):
        swiss_val_loader = DataLoader(
            self.swiss_val_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            pin_memory=False,
            drop_last=False,
            persistent_workers=False,
            collate_fn=InferenceCollater(self.tokenizer, self.prot_tokenizer, self.text_max_len, self.prot_max_len),
        )
        swiss_test_loader = DataLoader(
            self.swiss_test_dataset,
            batch_size=self.inference_batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            pin_memory=False,
            drop_last=False,
            persistent_workers=False,
            collate_fn=InferenceCollater(self.tokenizer, self.prot_tokenizer, self.text_max_len, self.prot_max_len),
        )

        onto_val_loader = DataLoader(
            self.onto_val_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            pin_memory=False,
            drop_last=False,
            persistent_workers=False,
            collate_fn=InferenceCollater(self.tokenizer, self.prot_tokenizer, self.text_max_len, self.prot_max_len),
        )
        onto_test_loader = DataLoader(
            self.onto_test_dataset,
            batch_size=self.inference_batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            pin_memory=False,
            drop_last=False,
            persistent_workers=False,
            collate_fn=InferenceCollater(self.tokenizer, self.prot_tokenizer, self.text_max_len, self.prot_max_len),
        )
        return [swiss_val_loader, swiss_test_loader, onto_val_loader, onto_test_loader]

    def add_model_specific_args(parent_parser):
        parser = parent_parser.add_argument_group("Data module")
        parser.add_argument('--num_workers', type=int, default=2)
        parser.add_argument('--batch_size', type=int, default=32)
        parser.add_argument('--inference_batch_size', type=int, default=4)
        parser.add_argument('--root', type=str, default='data')
        parser.add_argument('--text_max_len', type=int, default=1024)
        parser.add_argument('--q_max_len', type=int, default=29)
        parser.add_argument('--a_max_len', type=int, default=36)
        parser.add_argument('--prot_max_len', type=int, default=1024)
        # parser.add_argument('--prompt', type=str, default='The protein has the following properties: ')
        return parent_parser
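# End-of-module note (editor's sketch): Stage2MixDM.val_dataloader() returns
# four loaders in the fixed order [swiss_val, swiss_test, onto_val, onto_test];
# under PyTorch Lightning, the dataloader_idx seen in validation_step follows
# that same order.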