Spaces:
Running
Running
E1 model addition
Browse files
app.py
CHANGED
|
@@ -1,13 +1,13 @@
|
|
| 1 |
-
from gradio import Blocks, Button, Checkbox, DataFrame, DownloadButton, Dropdown, Examples, Image, Markdown, Tab, Textbox
|
| 2 |
|
| 3 |
-
from model import
|
| 4 |
from data import Data
|
| 5 |
|
| 6 |
# Define scoring strategies
|
| 7 |
SCORING = ["wt-marginals", "masked-marginals"]
|
| 8 |
|
| 9 |
# Get available models
|
| 10 |
-
MODELS =
|
| 11 |
|
| 12 |
def app(*argv):
|
| 13 |
"""
|
|
@@ -17,12 +17,15 @@ def app(*argv):
|
|
| 17 |
seq, trg, model_name, *_ = argv
|
| 18 |
scoring = SCORING[scoring_strategy.value]
|
| 19 |
# Calculate the data based on the input parameters
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
|
|
|
|
|
|
|
|
|
| 26 |
|
| 27 |
return *out, DownloadButton(value=data.csv(), visible=True)
|
| 28 |
|
|
@@ -32,58 +35,52 @@ with Blocks() as esm_scan:
|
|
| 32 |
# Define the interface components
|
| 33 |
with Tab("App"):
|
| 34 |
Markdown(open("header.md", "r", encoding="utf-8").read())
|
| 35 |
-
seq = Textbox(
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
value=""
|
| 46 |
-
)
|
| 47 |
model_name = Dropdown(MODELS, label="Model", value="facebook/esm2_t30_150M_UR50D")
|
| 48 |
scoring_strategy = Checkbox(value=True, label="Use higher accuracy scoring", interactive=True)
|
| 49 |
btn = Button(value="Run", variant="primary")
|
| 50 |
dlb = DownloadButton(label="Download raw data", visible=False)
|
| 51 |
out = Image(visible=False)
|
| 52 |
ouu = DataFrame(visible=False)
|
| 53 |
-
btn.click(
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
)
|
| 58 |
ex = Examples(
|
| 59 |
examples=[
|
| 60 |
-
[
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
"facebook/esm2_t6_8M_UR50D"
|
| 64 |
],
|
| 65 |
-
[
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
"facebook/esm2_t12_35M_UR50D"
|
| 69 |
],
|
| 70 |
-
[
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
"facebook/esm2_t30_150M_UR50D",
|
| 74 |
],
|
| 75 |
-
[
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
"facebook/esm2_t33_650M_UR50D",
|
| 79 |
],
|
| 80 |
-
]
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
|
|
|
| 87 |
)
|
| 88 |
with Tab("Instructions"):
|
| 89 |
Markdown(open("instructions.md", "r", encoding="utf-8").read())
|
|
|
|
| 1 |
+
from gradio import Blocks, Button, Checkbox, DataFrame, DownloadButton, Dropdown, Error, Examples, Image, Markdown, Tab, Textbox
|
| 2 |
|
| 3 |
+
from model import ModelFactory
|
| 4 |
from data import Data
|
| 5 |
|
| 6 |
# Define scoring strategies
|
| 7 |
SCORING = ["wt-marginals", "masked-marginals"]
|
| 8 |
|
| 9 |
# Get available models
|
| 10 |
+
MODELS = ModelFactory.models()
|
| 11 |
|
| 12 |
def app(*argv):
|
| 13 |
"""
|
|
|
|
| 17 |
seq, trg, model_name, *_ = argv
|
| 18 |
scoring = SCORING[scoring_strategy.value]
|
| 19 |
# Calculate the data based on the input parameters
|
| 20 |
+
try:
|
| 21 |
+
data = Data(seq, trg, model_name, scoring).calculate()
|
| 22 |
+
if isinstance(data.image(), str):
|
| 23 |
+
out = Image(value=data.image(), type='filepath', visible=True), DataFrame(visible=False)
|
| 24 |
+
else:
|
| 25 |
+
out = Image(visible=False), DataFrame(value=data.image(), visible=True)
|
| 26 |
+
except Exception as e:
|
| 27 |
+
out = Image(visible=False), DataFrame(visible=False)
|
| 28 |
+
raise Error(str(e))
|
| 29 |
|
| 30 |
return *out, DownloadButton(value=data.csv(), visible=True)
|
| 31 |
|
|
|
|
| 35 |
# Define the interface components
|
| 36 |
with Tab("App"):
|
| 37 |
Markdown(open("header.md", "r", encoding="utf-8").read())
|
| 38 |
+
seq = Textbox( lines=2
|
| 39 |
+
, label="Sequence"
|
| 40 |
+
, placeholder="FASTA sequence here..."
|
| 41 |
+
, value=''
|
| 42 |
+
)
|
| 43 |
+
trg = Textbox( lines=1
|
| 44 |
+
, label="Substitutions"
|
| 45 |
+
, placeholder="Substitutions here..."
|
| 46 |
+
, value=""
|
| 47 |
+
)
|
|
|
|
|
|
|
| 48 |
model_name = Dropdown(MODELS, label="Model", value="facebook/esm2_t30_150M_UR50D")
|
| 49 |
scoring_strategy = Checkbox(value=True, label="Use higher accuracy scoring", interactive=True)
|
| 50 |
btn = Button(value="Run", variant="primary")
|
| 51 |
dlb = DownloadButton(label="Download raw data", visible=False)
|
| 52 |
out = Image(visible=False)
|
| 53 |
ouu = DataFrame(visible=False)
|
| 54 |
+
btn.click( fn=app
|
| 55 |
+
, inputs=[seq, trg, model_name]
|
| 56 |
+
, outputs=[out, ouu, dlb]
|
| 57 |
+
)
|
|
|
|
| 58 |
ex = Examples(
|
| 59 |
examples=[
|
| 60 |
+
[ "MVEQYLLEAIVRDARDGITISDCSRPDNPLVFVNDAFTRMTGYDAEEVIGKNCRFLQRGDINLSAVHTIKIAMLTHEPCLVTLKNYRKDGTIFWNELSLTPIINKNGLITHYLGIQKDVSAQVILNQTLHEENHLLKSNKEMLEYLVNIDALTGLHNRRFLEDQLVIQWKLASRHINTITIFMIDIDYFKAFNDTYGHTAGDEALRTIAKTLNNCFMRGSDFVARYGGEEFTILAIGMTELQAHEYSTKLVQKIENLNIHHKGSPLGHLTISLGYSQANPQYHNDQNLVIEQADRALYSAKVEGKNRAVAYREQ"
|
| 61 |
+
, "deep mutational scanning"
|
| 62 |
+
, "facebook/esm2_t6_8M_UR50D"
|
|
|
|
| 63 |
],
|
| 64 |
+
[ "MVEQYLLEAIVRDARDGITISDCSRPDNPLVFVNDAFTRMTGYDAEEVIGKNCRFLQRGDINLSAVHTIKIAMLTHEPCLVTLKNYRKDGTIFWNELSLTPIINKNGLITHYLGIQKDVSAQVILNQTLHEENHLLKSNKEMLEYLVNIDALTGLHNRRFLEDQLVIQWKLASRHINTITIFMIDIDYFKAFNDTYGHTAGDEALRTIAKTLNNCFMRGSDFVARYGGEEFTILAIGMTELQAHEYSTKLVQKIENLNIHHKGSPLGHLTISLGYSQANPQYHNDQNLVIEQADRALYSAKVEGKNRAVAYREQ"
|
| 65 |
+
, "217 218 219"
|
| 66 |
+
, "facebook/esm2_t12_35M_UR50D"
|
|
|
|
| 67 |
],
|
| 68 |
+
[ "MVEQYLLEAIVRDARDGITISDCSRPDNPLVFVNDAFTRMTGYDAEEVIGKNCRFLQRGDINLSAVHTIKIAMLTHEPCLVTLKNYRKDGTIFWNELSLTPIINKNGLITHYLGIQKDVSAQVILNQTLHEENHLLKSNKEMLEYLVNIDALTGLHNRRFLEDQLVIQWKLASRHINTITIFMIDIDYFKAFNDTYGHTAGDEALRTIAKTLNNCFMRGSDFVARYGGEEFTILAIGMTELQAHEYSTKLVQKIENLNIHHKGSPLGHLTISLGYSQANPQYHNDQNLVIEQADRALYSAKVEGKNRAVAYREQ"
|
| 69 |
+
, "R218K R218S R218N R218A R218V R218D"
|
| 70 |
+
, "facebook/esm2_t30_150M_UR50D"
|
|
|
|
| 71 |
],
|
| 72 |
+
[ "MVEQYLLEAIVRDARDGITISDCSRPDNPLVFVNDAFTRMTGYDAEEVIGKNCRFLQRGDINLSAVHTIKIAMLTHEPCLVTLKNYRKDGTIFWNELSLTPIINKNGLITHYLGIQKDVSAQVILNQTLHEENHLLKSNKEMLEYLVNIDALTGLHNRRFLEDQLVIQWKLASRHINTITIFMIDIDYFKAFNDTYGHTAGDEALRTIAKTLNNCFMRGSDFVARYGGEEFTILAIGMTELQAHEYSTKLVQKIENLNIHHKGSPLGHLTISLGYSQANPQYHNDQNLVIEQADRALYSAKVEGKNRAVAYREQ"
|
| 73 |
+
, "MVEQYLLEAIVRDARDGITISDCSRPDNPLVFVNDAFTRMTGYDAEEVIGKNCRFLQRGDINLSAVHTIKIAMLTHEPCLVTLKNYRKDGTIFWNELSLTPIINKNGLITHYLGIQKDVSAQVILNQTLHEENHLLKSNKEMLEYLVNIDALTGLHNRRFLEDQLVIQWKLASRHINTITIFMIDIDYFKAFNDTYGHTAGDEALRTIAKTLNNCFMWGSDFVARYGGEEFTILAIGMTELQAHEYSTKLVQKIENLNIHHKGSPLGHLTISLGYSQANPQYHNDQNLVIEQADRALYSAKVEGKNRAVAYREQ"
|
| 74 |
+
, "facebook/esm2_t33_650M_UR50D"
|
|
|
|
| 75 |
],
|
| 76 |
+
]
|
| 77 |
+
, inputs=[ seq
|
| 78 |
+
, trg
|
| 79 |
+
, model_name
|
| 80 |
+
]
|
| 81 |
+
, outputs=[out]
|
| 82 |
+
, fn=app
|
| 83 |
+
, cache_examples=False
|
| 84 |
)
|
| 85 |
with Tab("Instructions"):
|
| 86 |
Markdown(open("instructions.md", "r", encoding="utf-8").read())
|
data.py
CHANGED
|
@@ -4,18 +4,15 @@ import pandas as pd
|
|
| 4 |
from re import match
|
| 5 |
import seaborn as sns
|
| 6 |
|
| 7 |
-
from model import
|
| 8 |
|
| 9 |
class Data:
|
| 10 |
"""Container for input and output data"""
|
| 11 |
-
# Initialise empty model as static class member for efficiency
|
| 12 |
-
model = Model()
|
| 13 |
-
|
| 14 |
def parse_seq(self, src: str):
|
| 15 |
"""Parse input sequence"""
|
| 16 |
self.seq = src.strip().upper().replace('\n', '')
|
| 17 |
if not all(x in self.model.alphabet for x in self.seq):
|
| 18 |
-
raise RuntimeError("
|
| 19 |
|
| 20 |
def parse_sub(self, trg: str):
|
| 21 |
"""Parse input substitutions"""
|
|
@@ -36,34 +33,42 @@ class Data:
|
|
| 36 |
if all(match(r'\d+', x) for x in self.trg):
|
| 37 |
# If all strings are numbers, deep mutational scanning mode
|
| 38 |
self.mode = 'DMS'
|
|
|
|
| 39 |
for resi in map(int, self.trg):
|
| 40 |
src = self.seq[resi-1]
|
| 41 |
for trg in "ACDEFGHIKLMNPQRSTVWY".replace(src, ''):
|
| 42 |
self.sub.append(f"{src}{resi}{trg}")
|
|
|
|
| 43 |
self.resi.append(resi)
|
|
|
|
| 44 |
elif all(match(r'[A-Z]\d+[A-Z]', x) for x in self.trg):
|
| 45 |
# If all strings are of the form X#Y, single substitution mode
|
| 46 |
self.mode = 'MUT'
|
| 47 |
self.sub = self.trg
|
| 48 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 49 |
for s, *resi, _ in self.trg:
|
| 50 |
if self.seq[int(''.join(resi))-1] != s:
|
| 51 |
-
raise RuntimeError(f"Unrecognised input substitution {self.seq[int(''.join(resi))]}{int(''.join(resi))} /= {s}{int(''.join(resi))}")
|
|
|
|
| 52 |
else:
|
| 53 |
self.mode = 'TMS'
|
|
|
|
| 54 |
for resi, src in enumerate(self.seq, 1):
|
| 55 |
for trg in "ACDEFGHIKLMNPQRSTVWY".replace(src, ''):
|
| 56 |
self.sub.append(f"{src}{resi}{trg}")
|
|
|
|
| 57 |
self.resi.append(resi)
|
| 58 |
|
| 59 |
self.sub = pd.DataFrame(self.sub, columns=['0'])
|
| 60 |
|
| 61 |
def __init__(self, src:str, trg:str, model_name:str='facebook/esm2_t33_650M_UR50D', scoring_strategy:str='masked-marginals', out_file='out'):
|
| 62 |
"initialise data"
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
self.model_name = model_name
|
| 66 |
-
self.model = Model(model_name)
|
| 67 |
self.parse_seq(src)
|
| 68 |
self.offset = 0
|
| 69 |
self.parse_sub(trg)
|
|
@@ -101,8 +106,8 @@ class Data:
|
|
| 101 |
.groupby(['resi'])
|
| 102 |
.head(19)
|
| 103 |
.drop(['resi'], axis=1))
|
| 104 |
-
self.out = pd.concat([self.out.iloc[19*x:19*(x+1)].reset_index(drop=True) for x in range(self.out.shape[0]//19)]
|
| 105 |
-
|
| 106 |
|
| 107 |
def process_tms_mode(self):
|
| 108 |
self.out = self.assign_resi_and_group()
|
|
@@ -124,8 +129,8 @@ class Data:
|
|
| 124 |
.pipe(self.create_dataframe)
|
| 125 |
.sort_values(['0'], ascending=[True])
|
| 126 |
.drop(['resi', '0'], axis=1)
|
| 127 |
-
.set_axis(['A', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'K', 'L'
|
| 128 |
-
|
| 129 |
.astype(float)
|
| 130 |
) for x in range(self.out.shape[0]//19)]
|
| 131 |
, axis=1)
|
|
@@ -152,16 +157,16 @@ class Data:
|
|
| 152 |
|
| 153 |
def plot_single_heatmap(self):
|
| 154 |
fig = plt.figure(figsize=(12, 6))
|
| 155 |
-
sns.heatmap(self.out
|
| 156 |
-
|
| 157 |
-
|
| 158 |
-
|
| 159 |
-
|
| 160 |
-
|
| 161 |
-
|
| 162 |
-
|
| 163 |
-
|
| 164 |
-
|
| 165 |
fig.tight_layout()
|
| 166 |
|
| 167 |
def plot_multiple_heatmaps(self, ncols, nrows):
|
|
@@ -169,17 +174,17 @@ class Data:
|
|
| 169 |
for i in range(nrows):
|
| 170 |
tmp = self.out.iloc[:,i*ncols:(i+1)*ncols]
|
| 171 |
label = tmp.map(lambda x: ' ' if x != 0 else '·')
|
| 172 |
-
sns.heatmap(tmp
|
| 173 |
-
|
| 174 |
-
|
| 175 |
-
|
| 176 |
-
|
| 177 |
-
|
| 178 |
-
|
| 179 |
-
|
| 180 |
-
|
| 181 |
-
|
| 182 |
-
|
| 183 |
ax[i].set_yticklabels(ax[i].get_yticklabels(), rotation=0)
|
| 184 |
ax[i].set_xticklabels(ax[i].get_xticklabels(), rotation=90)
|
| 185 |
fig.tight_layout()
|
|
|
|
| 4 |
from re import match
|
| 5 |
import seaborn as sns
|
| 6 |
|
| 7 |
+
from model import ModelFactory
|
| 8 |
|
| 9 |
class Data:
|
| 10 |
"""Container for input and output data"""
|
|
|
|
|
|
|
|
|
|
| 11 |
def parse_seq(self, src: str):
|
| 12 |
"""Parse input sequence"""
|
| 13 |
self.seq = src.strip().upper().replace('\n', '')
|
| 14 |
if not all(x in self.model.alphabet for x in self.seq):
|
| 15 |
+
raise RuntimeError(f"Unsupported characters in sequence: {''.join(x for x in self.seq if x not in self.model.alphabet)}")
|
| 16 |
|
| 17 |
def parse_sub(self, trg: str):
|
| 18 |
"""Parse input substitutions"""
|
|
|
|
| 33 |
if all(match(r'\d+', x) for x in self.trg):
|
| 34 |
# If all strings are numbers, deep mutational scanning mode
|
| 35 |
self.mode = 'DMS'
|
| 36 |
+
trh = list()
|
| 37 |
for resi in map(int, self.trg):
|
| 38 |
src = self.seq[resi-1]
|
| 39 |
for trg in "ACDEFGHIKLMNPQRSTVWY".replace(src, ''):
|
| 40 |
self.sub.append(f"{src}{resi}{trg}")
|
| 41 |
+
trh.append(self.seq[:resi-1]+trg+self.seq[resi:])
|
| 42 |
self.resi.append(resi)
|
| 43 |
+
self.trg = trh
|
| 44 |
elif all(match(r'[A-Z]\d+[A-Z]', x) for x in self.trg):
|
| 45 |
# If all strings are of the form X#Y, single substitution mode
|
| 46 |
self.mode = 'MUT'
|
| 47 |
self.sub = self.trg
|
| 48 |
+
trh = list()
|
| 49 |
+
for x in self.trg:
|
| 50 |
+
idx = int(x[1:-1])
|
| 51 |
+
self.resi.append(idx)
|
| 52 |
+
trh.append(self.seq[:idx-1]+x[-1]+self.seq[idx:])
|
| 53 |
for s, *resi, _ in self.trg:
|
| 54 |
if self.seq[int(''.join(resi))-1] != s:
|
| 55 |
+
raise RuntimeError(f"Unrecognised input substitution: {self.seq[int(''.join(resi))]}{int(''.join(resi))} /= {s}{int(''.join(resi))}")
|
| 56 |
+
self.trg = trh
|
| 57 |
else:
|
| 58 |
self.mode = 'TMS'
|
| 59 |
+
self.trg = list()
|
| 60 |
for resi, src in enumerate(self.seq, 1):
|
| 61 |
for trg in "ACDEFGHIKLMNPQRSTVWY".replace(src, ''):
|
| 62 |
self.sub.append(f"{src}{resi}{trg}")
|
| 63 |
+
self.trg.append(self.seq[:resi-1]+trg+self.seq[resi:])
|
| 64 |
self.resi.append(resi)
|
| 65 |
|
| 66 |
self.sub = pd.DataFrame(self.sub, columns=['0'])
|
| 67 |
|
| 68 |
def __init__(self, src:str, trg:str, model_name:str='facebook/esm2_t33_650M_UR50D', scoring_strategy:str='masked-marginals', out_file='out'):
|
| 69 |
"initialise data"
|
| 70 |
+
self.model_name = model_name
|
| 71 |
+
self.model = ModelFactory(model_name)
|
|
|
|
|
|
|
| 72 |
self.parse_seq(src)
|
| 73 |
self.offset = 0
|
| 74 |
self.parse_sub(trg)
|
|
|
|
| 106 |
.groupby(['resi'])
|
| 107 |
.head(19)
|
| 108 |
.drop(['resi'], axis=1))
|
| 109 |
+
self.out = pd.concat([ self.out.iloc[19*x:19*(x+1)].reset_index(drop=True) for x in range(self.out.shape[0]//19)]
|
| 110 |
+
, axis=1).set_axis(range(self.out.shape[0]//19*2), axis='columns')
|
| 111 |
|
| 112 |
def process_tms_mode(self):
|
| 113 |
self.out = self.assign_resi_and_group()
|
|
|
|
| 129 |
.pipe(self.create_dataframe)
|
| 130 |
.sort_values(['0'], ascending=[True])
|
| 131 |
.drop(['resi', '0'], axis=1)
|
| 132 |
+
.set_axis([ 'A', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'K', 'L'
|
| 133 |
+
, 'M', 'N', 'P', 'Q', 'R', 'S', 'T', 'V', 'W', 'Y'])
|
| 134 |
.astype(float)
|
| 135 |
) for x in range(self.out.shape[0]//19)]
|
| 136 |
, axis=1)
|
|
|
|
| 157 |
|
| 158 |
def plot_single_heatmap(self):
|
| 159 |
fig = plt.figure(figsize=(12, 6))
|
| 160 |
+
sns.heatmap( self.out
|
| 161 |
+
, cmap='RdBu'
|
| 162 |
+
, cbar=False
|
| 163 |
+
, square=True
|
| 164 |
+
, xticklabels=1
|
| 165 |
+
, yticklabels=1
|
| 166 |
+
, center=0
|
| 167 |
+
, annot=self.out.map(lambda x: ' ' if x != 0 else '·')
|
| 168 |
+
, fmt='s'
|
| 169 |
+
, annot_kws={'size': 'xx-large'})
|
| 170 |
fig.tight_layout()
|
| 171 |
|
| 172 |
def plot_multiple_heatmaps(self, ncols, nrows):
|
|
|
|
| 174 |
for i in range(nrows):
|
| 175 |
tmp = self.out.iloc[:,i*ncols:(i+1)*ncols]
|
| 176 |
label = tmp.map(lambda x: ' ' if x != 0 else '·')
|
| 177 |
+
sns.heatmap( tmp
|
| 178 |
+
, ax=ax[i]
|
| 179 |
+
, cmap='RdBu'
|
| 180 |
+
, cbar=False
|
| 181 |
+
, square=True
|
| 182 |
+
, xticklabels=1
|
| 183 |
+
, yticklabels=1
|
| 184 |
+
, center=0
|
| 185 |
+
, annot=label
|
| 186 |
+
, fmt='s'
|
| 187 |
+
, annot_kws={'size': 'xx-large'})
|
| 188 |
ax[i].set_yticklabels(ax[i].get_yticklabels(), rotation=0)
|
| 189 |
ax[i].set_xticklabels(ax[i].get_xticklabels(), rotation=90)
|
| 190 |
fig.tight_layout()
|
model.py
CHANGED
|
@@ -1,31 +1,17 @@
|
|
| 1 |
from huggingface_hub import HfApi
|
| 2 |
import torch
|
| 3 |
from tqdm import tqdm
|
|
|
|
| 4 |
from transformers import AutoTokenizer, AutoModelForMaskedLM
|
| 5 |
from transformers.tokenization_utils_base import BatchEncoding
|
| 6 |
from transformers.modeling_outputs import MaskedLMOutput
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
def get_models() -> list[None|str]:
|
| 10 |
-
"""Fetch suitable ESM models from HuggingFace Hub."""
|
| 11 |
-
if not any(
|
| 12 |
-
out := [
|
| 13 |
-
m.modelId for m in HfApi().list_models(
|
| 14 |
-
author="facebook",
|
| 15 |
-
model_name="esm",
|
| 16 |
-
task="fill-mask",
|
| 17 |
-
sort="lastModified",
|
| 18 |
-
direction=-1
|
| 19 |
-
)
|
| 20 |
-
]
|
| 21 |
-
):
|
| 22 |
-
raise RuntimeError("Error while retrieving models from HuggingFace Hub")
|
| 23 |
-
return out
|
| 24 |
|
| 25 |
# Class to wrap ESM models
|
| 26 |
-
class
|
| 27 |
"""Wrapper for ESM models."""
|
| 28 |
-
def __init__(self, model_name:
|
| 29 |
"""Load selected model and tokenizer."""
|
| 30 |
self.model_name = model_name
|
| 31 |
if model_name:
|
|
@@ -95,9 +81,63 @@ class Model:
|
|
| 95 |
|
| 96 |
# Apply the label_row function to each row of the substitutions dataframe
|
| 97 |
data.out[self.model_name] = data.sub.apply(
|
| 98 |
-
lambda row: label_row(
|
| 99 |
-
row['0']
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
from huggingface_hub import HfApi
|
| 2 |
import torch
|
| 3 |
from tqdm import tqdm
|
| 4 |
+
from typing import Any
|
| 5 |
from transformers import AutoTokenizer, AutoModelForMaskedLM
|
| 6 |
from transformers.tokenization_utils_base import BatchEncoding
|
| 7 |
from transformers.modeling_outputs import MaskedLMOutput
|
| 8 |
+
from E1.modeling import E1ForMaskedLM
|
| 9 |
+
from E1.scorer import E1Scorer, EncoderScoreMethod
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 10 |
|
| 11 |
# Class to wrap ESM models
|
| 12 |
+
class ESMModel:
|
| 13 |
"""Wrapper for ESM models."""
|
| 14 |
+
def __init__(self, model_name:str):
|
| 15 |
"""Load selected model and tokenizer."""
|
| 16 |
self.model_name = model_name
|
| 17 |
if model_name:
|
|
|
|
| 81 |
|
| 82 |
# Apply the label_row function to each row of the substitutions dataframe
|
| 83 |
data.out[self.model_name] = data.sub.apply(
|
| 84 |
+
lambda row: label_row(
|
| 85 |
+
row['0']
|
| 86 |
+
, token_probs
|
| 87 |
+
)
|
| 88 |
+
, axis=1
|
| 89 |
+
)
|
| 90 |
+
|
| 91 |
+
# Class to wrap E1 models
|
| 92 |
+
class E1Model:
|
| 93 |
+
def __init__(self, model_name:str):
|
| 94 |
+
self.model_name = model_name
|
| 95 |
+
self.scoring_strategy = EncoderScoreMethod.MASKED_MARGINAL
|
| 96 |
+
if model_name:
|
| 97 |
+
self.model = E1ForMaskedLM.from_pretrained(model_name, dtype=torch.float)
|
| 98 |
+
if torch.cuda.is_available():
|
| 99 |
+
self.model = self.model.cuda()
|
| 100 |
+
self.device = torch.device("cuda")
|
| 101 |
+
else:
|
| 102 |
+
self.device = torch.device("cpu")
|
| 103 |
+
self.scorer = E1Scorer(self.model, method=self.scoring_strategy)
|
| 104 |
+
self.alphabet = self.scorer.vocab
|
| 105 |
+
|
| 106 |
+
def run_model(self, data):
|
| 107 |
+
if not data.scoring_strategy.startswith("masked-marginals"):
|
| 108 |
+
self.scorer = E1Scorer(self.model, method=EncoderScoreMethod.WILDTYPE_MARGINAL)
|
| 109 |
+
scores = self.scorer.score(parent_sequence=data.seq, sequences=data.trg)
|
| 110 |
+
data.out[self.model_name] = [s['score'] for s in scores]
|
| 111 |
+
|
| 112 |
+
class ModelFactory:
|
| 113 |
+
_models = {
|
| 114 |
+
**{ m.modelId:ESMModel for m in HfApi().list_models(
|
| 115 |
+
author="facebook"
|
| 116 |
+
, model_name="esm"
|
| 117 |
+
, filter="fill-mask"
|
| 118 |
+
, sort="lastModified"
|
| 119 |
+
, direction=-1
|
| 120 |
+
)
|
| 121 |
+
}
|
| 122 |
+
, **{ m.modelId:E1Model for m in HfApi().list_models(
|
| 123 |
+
author="Profluent-Bio"
|
| 124 |
+
, model_name="E1"
|
| 125 |
+
, sort="lastModified"
|
| 126 |
+
, direction=-1
|
| 127 |
+
)
|
| 128 |
+
}
|
| 129 |
+
}
|
| 130 |
+
|
| 131 |
+
@classmethod
|
| 132 |
+
def register(cls, model_name, model_cls):
|
| 133 |
+
cls._models[model_name] = model_cls
|
| 134 |
+
|
| 135 |
+
@classmethod
|
| 136 |
+
def models(cls):
|
| 137 |
+
return [m for m in cls._models.keys()]
|
| 138 |
+
|
| 139 |
+
def __new__(cls, model_name:str) -> Any:
|
| 140 |
+
return cls._models[model_name](model_name)
|
| 141 |
+
|
| 142 |
+
for m,c in ModelFactory._models.items():
|
| 143 |
+
ModelFactory.register(m, c)
|
requirements.txt
CHANGED
|
@@ -3,3 +3,4 @@ pandas
|
|
| 3 |
seaborn
|
| 4 |
torch
|
| 5 |
transformers
|
|
|
|
|
|
| 3 |
seaborn
|
| 4 |
torch
|
| 5 |
transformers
|
| 6 |
+
E1@git+https://github.com/Profluent-AI/E1.git@main#egg=E1
|