Spaces:
Sleeping
Sleeping
Implement get_attention for tape BERT
Browse files- poetry.lock +118 -1
- protention/attention.py +30 -11
- pyproject.toml +1 -0
- tests/test_attention.py +10 -1
poetry.lock
CHANGED
|
@@ -171,6 +171,38 @@ category = "main"
|
|
| 171 |
optional = false
|
| 172 |
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*"
|
| 173 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 174 |
[[package]]
|
| 175 |
name = "cachetools"
|
| 176 |
version = "5.3.0"
|
|
@@ -572,6 +604,14 @@ MarkupSafe = ">=2.0"
|
|
| 572 |
[package.extras]
|
| 573 |
i18n = ["Babel (>=2.7)"]
|
| 574 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 575 |
[[package]]
|
| 576 |
name = "jsonpointer"
|
| 577 |
version = "2.3"
|
|
@@ -749,6 +789,14 @@ category = "main"
|
|
| 749 |
optional = false
|
| 750 |
python-versions = "*"
|
| 751 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 752 |
[[package]]
|
| 753 |
name = "markdown-it-py"
|
| 754 |
version = "2.2.0"
|
|
@@ -1474,6 +1522,36 @@ pygments = ">=2.13.0,<3.0.0"
|
|
| 1474 |
[package.extras]
|
| 1475 |
jupyter = ["ipywidgets (>=7.5.1,<9)"]
|
| 1476 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1477 |
[[package]]
|
| 1478 |
name = "semver"
|
| 1479 |
version = "2.13.0"
|
|
@@ -1613,6 +1691,37 @@ python-versions = ">=3.8"
|
|
| 1613 |
[package.dependencies]
|
| 1614 |
mpmath = ">=0.19"
|
| 1615 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1616 |
[[package]]
|
| 1617 |
name = "terminado"
|
| 1618 |
version = "0.17.1"
|
|
@@ -1983,7 +2092,7 @@ testing = ["pytest (>=6)", "pytest-checkdocs (>=2.4)", "flake8 (<5)", "pytest-co
|
|
| 1983 |
[metadata]
|
| 1984 |
lock-version = "1.1"
|
| 1985 |
python-versions = "^3.10"
|
| 1986 |
-
content-hash = "
|
| 1987 |
|
| 1988 |
[metadata.files]
|
| 1989 |
altair = []
|
|
@@ -2027,6 +2136,8 @@ beautifulsoup4 = []
|
|
| 2027 |
biopython = []
|
| 2028 |
bleach = []
|
| 2029 |
blinker = []
|
|
|
|
|
|
|
| 2030 |
cachetools = []
|
| 2031 |
certifi = []
|
| 2032 |
cffi = []
|
|
@@ -2071,6 +2182,7 @@ ipywidgets = []
|
|
| 2071 |
isoduration = []
|
| 2072 |
jedi = []
|
| 2073 |
jinja2 = []
|
|
|
|
| 2074 |
jsonpointer = []
|
| 2075 |
jsonschema = []
|
| 2076 |
jupyter-client = []
|
|
@@ -2082,6 +2194,7 @@ jupyter-server-terminals = []
|
|
| 2082 |
jupyterlab-pygments = []
|
| 2083 |
jupyterlab-widgets = []
|
| 2084 |
lit = []
|
|
|
|
| 2085 |
markdown-it-py = []
|
| 2086 |
markupsafe = []
|
| 2087 |
matplotlib-inline = []
|
|
@@ -2205,6 +2318,8 @@ requests = []
|
|
| 2205 |
rfc3339-validator = []
|
| 2206 |
rfc3986-validator = []
|
| 2207 |
rich = []
|
|
|
|
|
|
|
| 2208 |
semver = []
|
| 2209 |
send2trash = [
|
| 2210 |
{file = "Send2Trash-1.8.0-py3-none-any.whl", hash = "sha256:f20eaadfdb517eaca5ce077640cb261c7d2698385a6a0f072a4a5447fd49fa08"},
|
|
@@ -2225,6 +2340,8 @@ stack-data = []
|
|
| 2225 |
stmol = []
|
| 2226 |
streamlit = []
|
| 2227 |
sympy = []
|
|
|
|
|
|
|
| 2228 |
terminado = []
|
| 2229 |
tinycss2 = []
|
| 2230 |
tokenizers = []
|
|
|
|
| 171 |
optional = false
|
| 172 |
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*"
|
| 173 |
|
| 174 |
+
[[package]]
|
| 175 |
+
name = "boto3"
|
| 176 |
+
version = "1.26.95"
|
| 177 |
+
description = "The AWS SDK for Python"
|
| 178 |
+
category = "main"
|
| 179 |
+
optional = false
|
| 180 |
+
python-versions = ">= 3.7"
|
| 181 |
+
|
| 182 |
+
[package.dependencies]
|
| 183 |
+
botocore = ">=1.29.95,<1.30.0"
|
| 184 |
+
jmespath = ">=0.7.1,<2.0.0"
|
| 185 |
+
s3transfer = ">=0.6.0,<0.7.0"
|
| 186 |
+
|
| 187 |
+
[package.extras]
|
| 188 |
+
crt = ["botocore[crt] (>=1.21.0,<2.0a0)"]
|
| 189 |
+
|
| 190 |
+
[[package]]
|
| 191 |
+
name = "botocore"
|
| 192 |
+
version = "1.29.95"
|
| 193 |
+
description = "Low-level, data-driven core of boto 3."
|
| 194 |
+
category = "main"
|
| 195 |
+
optional = false
|
| 196 |
+
python-versions = ">= 3.7"
|
| 197 |
+
|
| 198 |
+
[package.dependencies]
|
| 199 |
+
jmespath = ">=0.7.1,<2.0.0"
|
| 200 |
+
python-dateutil = ">=2.1,<3.0.0"
|
| 201 |
+
urllib3 = ">=1.25.4,<1.27"
|
| 202 |
+
|
| 203 |
+
[package.extras]
|
| 204 |
+
crt = ["awscrt (==0.16.9)"]
|
| 205 |
+
|
| 206 |
[[package]]
|
| 207 |
name = "cachetools"
|
| 208 |
version = "5.3.0"
|
|
|
|
| 604 |
[package.extras]
|
| 605 |
i18n = ["Babel (>=2.7)"]
|
| 606 |
|
| 607 |
+
[[package]]
|
| 608 |
+
name = "jmespath"
|
| 609 |
+
version = "1.0.1"
|
| 610 |
+
description = "JSON Matching Expressions"
|
| 611 |
+
category = "main"
|
| 612 |
+
optional = false
|
| 613 |
+
python-versions = ">=3.7"
|
| 614 |
+
|
| 615 |
[[package]]
|
| 616 |
name = "jsonpointer"
|
| 617 |
version = "2.3"
|
|
|
|
| 789 |
optional = false
|
| 790 |
python-versions = "*"
|
| 791 |
|
| 792 |
+
[[package]]
|
| 793 |
+
name = "lmdb"
|
| 794 |
+
version = "1.4.0"
|
| 795 |
+
description = "Universal Python binding for the LMDB 'Lightning' Database"
|
| 796 |
+
category = "main"
|
| 797 |
+
optional = false
|
| 798 |
+
python-versions = "*"
|
| 799 |
+
|
| 800 |
[[package]]
|
| 801 |
name = "markdown-it-py"
|
| 802 |
version = "2.2.0"
|
|
|
|
| 1522 |
[package.extras]
|
| 1523 |
jupyter = ["ipywidgets (>=7.5.1,<9)"]
|
| 1524 |
|
| 1525 |
+
[[package]]
|
| 1526 |
+
name = "s3transfer"
|
| 1527 |
+
version = "0.6.0"
|
| 1528 |
+
description = "An Amazon S3 Transfer Manager"
|
| 1529 |
+
category = "main"
|
| 1530 |
+
optional = false
|
| 1531 |
+
python-versions = ">= 3.7"
|
| 1532 |
+
|
| 1533 |
+
[package.dependencies]
|
| 1534 |
+
botocore = ">=1.12.36,<2.0a.0"
|
| 1535 |
+
|
| 1536 |
+
[package.extras]
|
| 1537 |
+
crt = ["botocore[crt] (>=1.20.29,<2.0a.0)"]
|
| 1538 |
+
|
| 1539 |
+
[[package]]
|
| 1540 |
+
name = "scipy"
|
| 1541 |
+
version = "1.9.3"
|
| 1542 |
+
description = "Fundamental algorithms for scientific computing in Python"
|
| 1543 |
+
category = "main"
|
| 1544 |
+
optional = false
|
| 1545 |
+
python-versions = ">=3.8"
|
| 1546 |
+
|
| 1547 |
+
[package.dependencies]
|
| 1548 |
+
numpy = ">=1.18.5,<1.26.0"
|
| 1549 |
+
|
| 1550 |
+
[package.extras]
|
| 1551 |
+
test = ["pytest", "pytest-cov", "pytest-xdist", "asv", "mpmath", "gmpy2", "threadpoolctl", "scikit-umfpack"]
|
| 1552 |
+
doc = ["sphinx (!=4.1.0)", "pydata-sphinx-theme (==0.9.0)", "sphinx-panels (>=0.5.2)", "matplotlib (>2)", "numpydoc", "sphinx-tabs"]
|
| 1553 |
+
dev = ["mypy", "typing-extensions", "pycodestyle", "flake8"]
|
| 1554 |
+
|
| 1555 |
[[package]]
|
| 1556 |
name = "semver"
|
| 1557 |
version = "2.13.0"
|
|
|
|
| 1691 |
[package.dependencies]
|
| 1692 |
mpmath = ">=0.19"
|
| 1693 |
|
| 1694 |
+
[[package]]
|
| 1695 |
+
name = "tape-proteins"
|
| 1696 |
+
version = "0.5"
|
| 1697 |
+
description = "Repostory of Protein Benchmarking and Modeling"
|
| 1698 |
+
category = "main"
|
| 1699 |
+
optional = false
|
| 1700 |
+
python-versions = "*"
|
| 1701 |
+
|
| 1702 |
+
[package.dependencies]
|
| 1703 |
+
biopython = "*"
|
| 1704 |
+
boto3 = "*"
|
| 1705 |
+
filelock = "*"
|
| 1706 |
+
lmdb = "*"
|
| 1707 |
+
requests = "*"
|
| 1708 |
+
scipy = "*"
|
| 1709 |
+
tensorboardX = "*"
|
| 1710 |
+
tqdm = "*"
|
| 1711 |
+
|
| 1712 |
+
[[package]]
|
| 1713 |
+
name = "tensorboardx"
|
| 1714 |
+
version = "2.6"
|
| 1715 |
+
description = "TensorBoardX lets you watch Tensors Flow without Tensorflow"
|
| 1716 |
+
category = "main"
|
| 1717 |
+
optional = false
|
| 1718 |
+
python-versions = "*"
|
| 1719 |
+
|
| 1720 |
+
[package.dependencies]
|
| 1721 |
+
numpy = "*"
|
| 1722 |
+
packaging = "*"
|
| 1723 |
+
protobuf = ">=3.8.0,<4"
|
| 1724 |
+
|
| 1725 |
[[package]]
|
| 1726 |
name = "terminado"
|
| 1727 |
version = "0.17.1"
|
|
|
|
| 2092 |
[metadata]
|
| 2093 |
lock-version = "1.1"
|
| 2094 |
python-versions = "^3.10"
|
| 2095 |
+
content-hash = "ad6054ae4a119d961e9941f135489d1b89310303aefc27d3132fbd1ed1c35a0f"
|
| 2096 |
|
| 2097 |
[metadata.files]
|
| 2098 |
altair = []
|
|
|
|
| 2136 |
biopython = []
|
| 2137 |
bleach = []
|
| 2138 |
blinker = []
|
| 2139 |
+
boto3 = []
|
| 2140 |
+
botocore = []
|
| 2141 |
cachetools = []
|
| 2142 |
certifi = []
|
| 2143 |
cffi = []
|
|
|
|
| 2182 |
isoduration = []
|
| 2183 |
jedi = []
|
| 2184 |
jinja2 = []
|
| 2185 |
+
jmespath = []
|
| 2186 |
jsonpointer = []
|
| 2187 |
jsonschema = []
|
| 2188 |
jupyter-client = []
|
|
|
|
| 2194 |
jupyterlab-pygments = []
|
| 2195 |
jupyterlab-widgets = []
|
| 2196 |
lit = []
|
| 2197 |
+
lmdb = []
|
| 2198 |
markdown-it-py = []
|
| 2199 |
markupsafe = []
|
| 2200 |
matplotlib-inline = []
|
|
|
|
| 2318 |
rfc3339-validator = []
|
| 2319 |
rfc3986-validator = []
|
| 2320 |
rich = []
|
| 2321 |
+
s3transfer = []
|
| 2322 |
+
scipy = []
|
| 2323 |
semver = []
|
| 2324 |
send2trash = [
|
| 2325 |
{file = "Send2Trash-1.8.0-py3-none-any.whl", hash = "sha256:f20eaadfdb517eaca5ce077640cb261c7d2698385a6a0f072a4a5447fd49fa08"},
|
|
|
|
| 2340 |
stmol = []
|
| 2341 |
streamlit = []
|
| 2342 |
sympy = []
|
| 2343 |
+
tape-proteins = []
|
| 2344 |
+
tensorboardx = []
|
| 2345 |
terminado = []
|
| 2346 |
tinycss2 = []
|
| 2347 |
tokenizers = []
|
protention/attention.py
CHANGED
|
@@ -1,11 +1,16 @@
|
|
|
|
|
| 1 |
from io import StringIO
|
| 2 |
from urllib import request
|
| 3 |
|
| 4 |
import torch
|
| 5 |
from Bio.PDB import PDBParser, Polypeptide, Structure
|
|
|
|
| 6 |
from transformers import T5EncoderModel, T5Tokenizer
|
| 7 |
|
| 8 |
|
|
|
|
|
|
|
|
|
|
| 9 |
def get_structure(pdb_code: str) -> Structure:
|
| 10 |
"""
|
| 11 |
Get structure from PDB
|
|
@@ -46,9 +51,14 @@ def get_protT5() -> tuple[T5Tokenizer, T5EncoderModel]:
|
|
| 46 |
|
| 47 |
return tokenizer, model
|
| 48 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 49 |
|
| 50 |
def get_attention(
|
| 51 |
-
pdb_code: str,
|
| 52 |
):
|
| 53 |
"""
|
| 54 |
Get attention from T5
|
|
@@ -57,13 +67,22 @@ def get_attention(
|
|
| 57 |
structure = get_structure(pdb_code)
|
| 58 |
# Get list of sequences
|
| 59 |
sequences = get_sequences(structure)
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from enum import Enum
|
| 2 |
from io import StringIO
|
| 3 |
from urllib import request
|
| 4 |
|
| 5 |
import torch
|
| 6 |
from Bio.PDB import PDBParser, Polypeptide, Structure
|
| 7 |
+
from tape import ProteinBertModel, TAPETokenizer
|
| 8 |
from transformers import T5EncoderModel, T5Tokenizer
|
| 9 |
|
| 10 |
|
| 11 |
+
class Model(str, Enum):
|
| 12 |
+
tape_bert = "bert-base"
|
| 13 |
+
|
| 14 |
def get_structure(pdb_code: str) -> Structure:
|
| 15 |
"""
|
| 16 |
Get structure from PDB
|
|
|
|
| 51 |
|
| 52 |
return tokenizer, model
|
| 53 |
|
| 54 |
+
def get_tape_bert() -> tuple[TAPETokenizer, ProteinBertModel]:
|
| 55 |
+
tokenizer = TAPETokenizer()
|
| 56 |
+
model = ProteinBertModel.from_pretrained('bert-base', output_attentions=True)
|
| 57 |
+
return tokenizer, model
|
| 58 |
+
|
| 59 |
|
| 60 |
def get_attention(
|
| 61 |
+
pdb_code: str, model: Model = Model.tape_bert
|
| 62 |
):
|
| 63 |
"""
|
| 64 |
Get attention from T5
|
|
|
|
| 67 |
structure = get_structure(pdb_code)
|
| 68 |
# Get list of sequences
|
| 69 |
sequences = get_sequences(structure)
|
| 70 |
+
# TODO handle multiple sequences
|
| 71 |
+
sequence = sequences[0]
|
| 72 |
+
|
| 73 |
+
match model:
|
| 74 |
+
case model.tape_bert:
|
| 75 |
+
tokenizer, model = get_tape_bert()
|
| 76 |
+
token_idxs = tokenizer.encode(sequence).tolist()
|
| 77 |
+
inputs = torch.tensor(token_idxs).unsqueeze(0)
|
| 78 |
+
with torch.no_grad():
|
| 79 |
+
attns = model(inputs)[-1]
|
| 80 |
+
# Remove attention from <CLS> (first) and <SEP> (last) token
|
| 81 |
+
attns = [attn[:, :, 1:-1, 1:-1] for attn in attns]
|
| 82 |
+
attns = torch.stack([attn.squeeze(0) for attn in attns])
|
| 83 |
+
case model.prot_T5:
|
| 84 |
+
# Space separate sequences
|
| 85 |
+
sequences = [" ".join(sequence) for sequence in sequences]
|
| 86 |
+
tokenizer, model = get_protT5()
|
| 87 |
+
|
| 88 |
+
return attns
|
pyproject.toml
CHANGED
|
@@ -12,6 +12,7 @@ biopython = "^1.81"
|
|
| 12 |
transformers = "^4.27.1"
|
| 13 |
torch = "^2.0.0"
|
| 14 |
sentencepiece = "^0.1.97"
|
|
|
|
| 15 |
|
| 16 |
[tool.poetry.dev-dependencies]
|
| 17 |
pytest = "^7.2.2"
|
|
|
|
| 12 |
transformers = "^4.27.1"
|
| 13 |
torch = "^2.0.0"
|
| 14 |
sentencepiece = "^0.1.97"
|
| 15 |
+
tape-proteins = "^0.5"
|
| 16 |
|
| 17 |
[tool.poetry.dev-dependencies]
|
| 18 |
pytest = "^7.2.2"
|
tests/test_attention.py
CHANGED
|
@@ -1,7 +1,9 @@
|
|
|
|
|
| 1 |
from Bio.PDB.Structure import Structure
|
| 2 |
from transformers import T5EncoderModel, T5Tokenizer
|
| 3 |
|
| 4 |
-
from protention.attention import
|
|
|
|
| 5 |
|
| 6 |
|
| 7 |
def test_get_structure():
|
|
@@ -33,3 +35,10 @@ def test_get_protT5():
|
|
| 33 |
|
| 34 |
assert isinstance(tokenizer, T5Tokenizer)
|
| 35 |
assert isinstance(model, T5EncoderModel)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
from Bio.PDB.Structure import Structure
|
| 3 |
from transformers import T5EncoderModel, T5Tokenizer
|
| 4 |
|
| 5 |
+
from protention.attention import (Model, get_attention, get_protT5,
|
| 6 |
+
get_sequences, get_structure)
|
| 7 |
|
| 8 |
|
| 9 |
def test_get_structure():
|
|
|
|
| 35 |
|
| 36 |
assert isinstance(tokenizer, T5Tokenizer)
|
| 37 |
assert isinstance(model, T5EncoderModel)
|
| 38 |
+
|
| 39 |
+
def test_get_attention_tape():
|
| 40 |
+
|
| 41 |
+
result = get_attention("1AKE", model=Model.tape_bert)
|
| 42 |
+
|
| 43 |
+
assert result is not None
|
| 44 |
+
assert result.shape == torch.Size([12,12,456,456])
|