Spaces:
Running
Running
progress tracker
Browse files
app.py
CHANGED
|
@@ -1,32 +1,27 @@
|
|
| 1 |
-
from gradio import Blocks, Button, Checkbox, DownloadButton, Dropdown, Error, Examples, Image, HTML, Markdown, Tab, Textbox
|
| 2 |
|
| 3 |
from model import ModelFactory
|
| 4 |
from data import Data
|
| 5 |
|
| 6 |
-
# Define scoring strategies
|
| 7 |
-
SCORING = ["wt-marginals", "masked-marginals"]
|
| 8 |
-
|
| 9 |
# Get available models
|
| 10 |
MODELS = ModelFactory.models()
|
| 11 |
|
| 12 |
-
def app(
|
| 13 |
"Main application function"
|
| 14 |
-
|
| 15 |
-
seq, trg, model_name, *_ = argv
|
| 16 |
-
scoring = SCORING[scoring_strategy.value]
|
| 17 |
|
| 18 |
# Validate the input
|
| 19 |
if 1 > len(seq):
|
| 20 |
raise Error("Sequence cannot be empty")
|
| 21 |
-
if 1 > len(
|
| 22 |
raise Error("Substitutions cannot be empty")
|
| 23 |
|
| 24 |
# Calculate the data based on the input parameters
|
| 25 |
try:
|
| 26 |
-
data = Data(seq,
|
| 27 |
if isinstance(data.image, str):
|
| 28 |
return ( Image(value=data.image, type='filepath', visible=True)
|
| 29 |
-
, HTML(
|
| 30 |
, DownloadButton(value=data.csv, visible=True) )
|
| 31 |
else:
|
| 32 |
return ( Image(visible=False)
|
|
@@ -45,19 +40,20 @@ with Blocks() as demo:
|
|
| 45 |
, label="Sequence"
|
| 46 |
, placeholder="FASTA sequence here..."
|
| 47 |
, value='' )
|
| 48 |
-
|
| 49 |
, label="Substitutions"
|
| 50 |
, placeholder="Substitutions here..."
|
| 51 |
, value='' )
|
| 52 |
-
model_name
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
|
|
|
| 61 |
ex = Examples(
|
| 62 |
examples=[
|
| 63 |
[ "MVEQYLLEAIVRDARDGITISDCSRPDNPLVFVNDAFTRMTGYDAEEVIGKNCRFLQRGDINLSAVHTIKIAMLTHEPCLVTLKNYRKDGTIFWNELSLTPIINKNGLITHYLGIQKDVSAQVILNQTLHEENHLLKSNKEMLEYLVNIDALTGLHNRRFLEDQLVIQWKLASRHINTITIFMIDIDYFKAFNDTYGHTAGDEALRTIAKTLNNCFMRGSDFVARYGGEEFTILAIGMTELQAHEYSTKLVQKIENLNIHHKGSPLGHLTISLGYSQANPQYHNDQNLVIEQADRALYSAKVEGKNRAVAYREQ"
|
|
@@ -73,10 +69,8 @@ with Blocks() as demo:
|
|
| 73 |
, "MVEQYLLEAIVRDARDGITISDCSRPDNPLVFVNDAFTRMTGYDAEEVIGKNCRFLQRGDINLSAVHTIKIAMLTHEPCLVTLKNYRKDGTIFWNELSLTPIINKNGLITHYLGIQKDVSAQVILNQTLHEENHLLKSNKEMLEYLVNIDALTGLHNRRFLEDQLVIQWKLASRHINTITIFMIDIDYFKAFNDTYGHTAGDEALRTIAKTLNNCFMWGSDFVARYGGEEFTILAIGMTELQAHEYSTKLVQKIENLNIHHKGSPLGHLTISLGYSQANPQYHNDQNLVIEQADRALYSAKVEGKNRAVAYREQ"
|
| 74 |
, "facebook/esm2_t33_650M_UR50D" ],
|
| 75 |
]
|
| 76 |
-
, inputs=[
|
| 77 |
-
|
| 78 |
-
, model_name ]
|
| 79 |
-
, outputs=[out]
|
| 80 |
, fn=app
|
| 81 |
, cache_examples=False )
|
| 82 |
with Tab("Instructions"):
|
|
|
|
| 1 |
+
from gradio import Blocks, Button, Checkbox, DownloadButton, Dropdown, Error, Examples, Image, HTML, Markdown, Progress, Tab, Textbox
|
| 2 |
|
| 3 |
from model import ModelFactory
|
| 4 |
from data import Data
|
| 5 |
|
|
|
|
|
|
|
|
|
|
| 6 |
# Get available models
|
| 7 |
MODELS = ModelFactory.models()
|
| 8 |
|
| 9 |
+
def app(seq, sub, model_name, acc):
|
| 10 |
"Main application function"
|
| 11 |
+
scoring = "masked-marginals" if acc else "wt-marginals"
|
|
|
|
|
|
|
| 12 |
|
| 13 |
# Validate the input
|
| 14 |
if 1 > len(seq):
|
| 15 |
raise Error("Sequence cannot be empty")
|
| 16 |
+
if 1 > len(sub):
|
| 17 |
raise Error("Substitutions cannot be empty")
|
| 18 |
|
| 19 |
# Calculate the data based on the input parameters
|
| 20 |
try:
|
| 21 |
+
data = Data(seq, sub, model_name, scoring).calculate(progress)
|
| 22 |
if isinstance(data.image, str):
|
| 23 |
return ( Image(value=data.image, type='filepath', visible=True)
|
| 24 |
+
, HTML()
|
| 25 |
, DownloadButton(value=data.csv, visible=True) )
|
| 26 |
else:
|
| 27 |
return ( Image(visible=False)
|
|
|
|
| 40 |
, label="Sequence"
|
| 41 |
, placeholder="FASTA sequence here..."
|
| 42 |
, value='' )
|
| 43 |
+
sub = Textbox( lines=1
|
| 44 |
, label="Substitutions"
|
| 45 |
, placeholder="Substitutions here..."
|
| 46 |
, value='' )
|
| 47 |
+
model_name = Dropdown(MODELS, label="Model", value="facebook/esm2_t30_150M_UR50D")
|
| 48 |
+
acc_box = Checkbox(value=True, label="Use higher accuracy scoring", interactive=True)
|
| 49 |
+
run_btn = Button(value="Run", variant="primary")
|
| 50 |
+
dl_btn = DownloadButton(label="Download raw data", visible=False)
|
| 51 |
+
progress = Progress()
|
| 52 |
+
out_html = HTML()
|
| 53 |
+
out_img = Image(visible=False)
|
| 54 |
+
run_btn.click( fn=app
|
| 55 |
+
, inputs=[seq, sub, model_name, acc_box]
|
| 56 |
+
, outputs=[out_img, out_html, dl_btn] )
|
| 57 |
ex = Examples(
|
| 58 |
examples=[
|
| 59 |
[ "MVEQYLLEAIVRDARDGITISDCSRPDNPLVFVNDAFTRMTGYDAEEVIGKNCRFLQRGDINLSAVHTIKIAMLTHEPCLVTLKNYRKDGTIFWNELSLTPIINKNGLITHYLGIQKDVSAQVILNQTLHEENHLLKSNKEMLEYLVNIDALTGLHNRRFLEDQLVIQWKLASRHINTITIFMIDIDYFKAFNDTYGHTAGDEALRTIAKTLNNCFMRGSDFVARYGGEEFTILAIGMTELQAHEYSTKLVQKIENLNIHHKGSPLGHLTISLGYSQANPQYHNDQNLVIEQADRALYSAKVEGKNRAVAYREQ"
|
|
|
|
| 69 |
, "MVEQYLLEAIVRDARDGITISDCSRPDNPLVFVNDAFTRMTGYDAEEVIGKNCRFLQRGDINLSAVHTIKIAMLTHEPCLVTLKNYRKDGTIFWNELSLTPIINKNGLITHYLGIQKDVSAQVILNQTLHEENHLLKSNKEMLEYLVNIDALTGLHNRRFLEDQLVIQWKLASRHINTITIFMIDIDYFKAFNDTYGHTAGDEALRTIAKTLNNCFMWGSDFVARYGGEEFTILAIGMTELQAHEYSTKLVQKIENLNIHHKGSPLGHLTISLGYSQANPQYHNDQNLVIEQADRALYSAKVEGKNRAVAYREQ"
|
| 70 |
, "facebook/esm2_t33_650M_UR50D" ],
|
| 71 |
]
|
| 72 |
+
, inputs=[seq, sub, model_name]
|
| 73 |
+
, outputs=[out_img]
|
|
|
|
|
|
|
| 74 |
, fn=app
|
| 75 |
, cache_examples=False )
|
| 76 |
with Tab("Instructions"):
|
data.py
CHANGED
|
@@ -72,7 +72,7 @@ class Data:
|
|
| 72 |
self.parse_seq(src)
|
| 73 |
self.parse_sub(trg)
|
| 74 |
self.scoring_strategy = scoring_strategy
|
| 75 |
-
self.
|
| 76 |
self.out = pd.DataFrame(self.sub, columns=['0', self.model_name])
|
| 77 |
self.out_img = f"{out_file}.png"
|
| 78 |
self.out_csv = f"{out_file}.csv"
|
|
@@ -121,11 +121,9 @@ class Data:
|
|
| 121 |
def concat_and_set_axis(self):
|
| 122 |
return (pd.concat([(self.out.iloc[19*x:19*(x+1)]
|
| 123 |
.pipe(self.create_dataframe).sort_values(['0'], ascending=[True])
|
| 124 |
-
.drop(["resi", '0'], axis=1)
|
| 125 |
-
.astype(float)
|
| 126 |
.set_axis([ 'A', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'K', 'L'
|
| 127 |
-
, 'M', 'N', 'P', 'Q', 'R', 'S', 'T', 'V', 'W', 'Y'])) for x in range(self.out.shape[0]//19)]
|
| 128 |
-
, axis=1)
|
| 129 |
.set_axis([f'{a}{i}' for i, a in enumerate(self.seq, 1)], axis="columns"))
|
| 130 |
|
| 131 |
def create_dataframe(self, df):
|
|
@@ -181,8 +179,9 @@ class Data:
|
|
| 181 |
ax[i].set_xticklabels(ax[i].get_xticklabels(), rotation=90)
|
| 182 |
fig.tight_layout()
|
| 183 |
|
| 184 |
-
def calculate(self):
|
| 185 |
"run model and parse output"
|
|
|
|
| 186 |
self.model.run_model(self)
|
| 187 |
self.parse_output()
|
| 188 |
return self
|
|
|
|
| 72 |
self.parse_seq(src)
|
| 73 |
self.parse_sub(trg)
|
| 74 |
self.scoring_strategy = scoring_strategy
|
| 75 |
+
self.progress = None
|
| 76 |
self.out = pd.DataFrame(self.sub, columns=['0', self.model_name])
|
| 77 |
self.out_img = f"{out_file}.png"
|
| 78 |
self.out_csv = f"{out_file}.csv"
|
|
|
|
| 121 |
def concat_and_set_axis(self):
|
| 122 |
return (pd.concat([(self.out.iloc[19*x:19*(x+1)]
|
| 123 |
.pipe(self.create_dataframe).sort_values(['0'], ascending=[True])
|
| 124 |
+
.drop(["resi", '0'], axis=1).astype(float)
|
|
|
|
| 125 |
.set_axis([ 'A', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'K', 'L'
|
| 126 |
+
, 'M', 'N', 'P', 'Q', 'R', 'S', 'T', 'V', 'W', 'Y'])) for x in range(self.out.shape[0]//19)], axis=1)
|
|
|
|
| 127 |
.set_axis([f'{a}{i}' for i, a in enumerate(self.seq, 1)], axis="columns"))
|
| 128 |
|
| 129 |
def create_dataframe(self, df):
|
|
|
|
| 179 |
ax[i].set_xticklabels(ax[i].get_xticklabels(), rotation=90)
|
| 180 |
fig.tight_layout()
|
| 181 |
|
| 182 |
+
def calculate(self, progress):
|
| 183 |
"run model and parse output"
|
| 184 |
+
self.progress = progress
|
| 185 |
self.model.run_model(self)
|
| 186 |
self.parse_output()
|
| 187 |
return self
|
model.py
CHANGED
|
@@ -1,6 +1,5 @@
|
|
| 1 |
from huggingface_hub import HfApi
|
| 2 |
import torch
|
| 3 |
-
from tqdm import tqdm
|
| 4 |
from typing import Any
|
| 5 |
from transformers import AutoTokenizer, AutoModelForMaskedLM
|
| 6 |
from transformers.tokenization_utils_base import BatchEncoding
|
|
@@ -54,7 +53,7 @@ class ESMModel:
|
|
| 54 |
if data.scoring_strategy.startswith("masked-marginals"):
|
| 55 |
all_token_probs = []
|
| 56 |
# For each token in the batch
|
| 57 |
-
for i in tqdm(range(batch_tokens.size()[1])):
|
| 58 |
# If the token is in the list of residues
|
| 59 |
if i in data.resi:
|
| 60 |
# Clone the batch tokens and mask the current token
|
|
@@ -73,9 +72,7 @@ class ESMModel:
|
|
| 73 |
|
| 74 |
# Apply the label_row function to each row of the substitutions dataframe
|
| 75 |
data.out[self.model_name] = data.sub.apply(
|
| 76 |
-
lambda row: label_row(
|
| 77 |
-
row['0']
|
| 78 |
-
, token_probs )
|
| 79 |
, axis=1 )
|
| 80 |
|
| 81 |
class E1Model:
|
|
@@ -96,7 +93,7 @@ class E1Model:
|
|
| 96 |
self.scorer = E1Scorer(self.model, EncoderScoreMethod.WILDTYPE_MARGINAL)
|
| 97 |
batch_size = 60 ## chunking to avoid OOM
|
| 98 |
out = []
|
| 99 |
-
for chunk in tqdm([data.trg[i:i+batch_size] for i in range(0, len(data.trg), batch_size)]):
|
| 100 |
scores = self.scorer.score(parent_sequence=data.seq, sequences=chunk)
|
| 101 |
out.extend(s['score'] for s in scores)
|
| 102 |
data.out[self.model_name] = out
|
|
|
|
| 1 |
from huggingface_hub import HfApi
|
| 2 |
import torch
|
|
|
|
| 3 |
from typing import Any
|
| 4 |
from transformers import AutoTokenizer, AutoModelForMaskedLM
|
| 5 |
from transformers.tokenization_utils_base import BatchEncoding
|
|
|
|
| 53 |
if data.scoring_strategy.startswith("masked-marginals"):
|
| 54 |
all_token_probs = []
|
| 55 |
# For each token in the batch
|
| 56 |
+
for i in data.progress.tqdm(range(batch_tokens.size()[1]), desc="Calculating"):
|
| 57 |
# If the token is in the list of residues
|
| 58 |
if i in data.resi:
|
| 59 |
# Clone the batch tokens and mask the current token
|
|
|
|
| 72 |
|
| 73 |
# Apply the label_row function to each row of the substitutions dataframe
|
| 74 |
data.out[self.model_name] = data.sub.apply(
|
| 75 |
+
lambda row: label_row(row['0'], token_probs)
|
|
|
|
|
|
|
| 76 |
, axis=1 )
|
| 77 |
|
| 78 |
class E1Model:
|
|
|
|
| 93 |
self.scorer = E1Scorer(self.model, EncoderScoreMethod.WILDTYPE_MARGINAL)
|
| 94 |
batch_size = 60 ## chunking to avoid OOM
|
| 95 |
out = []
|
| 96 |
+
for chunk in data.progress.tqdm([data.trg[i:i+batch_size] for i in range(0, len(data.trg), batch_size)], desc="Calculating"):
|
| 97 |
scores = self.scorer.score(parent_sequence=data.seq, sequences=chunk)
|
| 98 |
out.extend(s['score'] for s in scores)
|
| 99 |
data.out[self.model_name] = out
|