mgtotaro committed on
Commit
31be27c
·
1 Parent(s): 7e12824

progress tracker

Browse files
Files changed (3) hide show
  1. app.py +19 -25
  2. data.py +5 -6
  3. model.py +3 -6
app.py CHANGED
@@ -1,32 +1,27 @@
1
- from gradio import Blocks, Button, Checkbox, DownloadButton, Dropdown, Error, Examples, Image, HTML, Markdown, Tab, Textbox
2
 
3
  from model import ModelFactory
4
  from data import Data
5
 
6
- # Define scoring strategies
7
- SCORING = ["wt-marginals", "masked-marginals"]
8
-
9
  # Get available models
10
  MODELS = ModelFactory.models()
11
 
12
- def app(*argv):
13
  "Main application function"
14
- # Unpack the arguments
15
- seq, trg, model_name, *_ = argv
16
- scoring = SCORING[scoring_strategy.value]
17
 
18
  # Validate the input
19
  if 1 > len(seq):
20
  raise Error("Sequence cannot be empty")
21
- if 1 > len(trg):
22
  raise Error("Substitutions cannot be empty")
23
 
24
  # Calculate the data based on the input parameters
25
  try:
26
- data = Data(seq, trg, model_name, scoring).calculate()
27
  if isinstance(data.image, str):
28
  return ( Image(value=data.image, type='filepath', visible=True)
29
- , HTML(visible=False)
30
  , DownloadButton(value=data.csv, visible=True) )
31
  else:
32
  return ( Image(visible=False)
@@ -45,19 +40,20 @@ with Blocks() as demo:
45
  , label="Sequence"
46
  , placeholder="FASTA sequence here..."
47
  , value='' )
48
- trg = Textbox( lines=1
49
  , label="Substitutions"
50
  , placeholder="Substitutions here..."
51
  , value='' )
52
- model_name = Dropdown(MODELS, label="Model", value="facebook/esm2_t30_150M_UR50D")
53
- scoring_strategy = Checkbox(value=True, label="Use higher accuracy scoring", interactive=True)
54
- btn = Button(value="Run", variant="primary")
55
- dlb = DownloadButton(label="Download raw data", visible=False)
56
- out = Image(visible=False)
57
- ouu = HTML(visible=False)
58
- btn.click( fn=app
59
- , inputs=[seq, trg, model_name]
60
- , outputs=[out, ouu, dlb] )
 
61
  ex = Examples(
62
  examples=[
63
  [ "MVEQYLLEAIVRDARDGITISDCSRPDNPLVFVNDAFTRMTGYDAEEVIGKNCRFLQRGDINLSAVHTIKIAMLTHEPCLVTLKNYRKDGTIFWNELSLTPIINKNGLITHYLGIQKDVSAQVILNQTLHEENHLLKSNKEMLEYLVNIDALTGLHNRRFLEDQLVIQWKLASRHINTITIFMIDIDYFKAFNDTYGHTAGDEALRTIAKTLNNCFMRGSDFVARYGGEEFTILAIGMTELQAHEYSTKLVQKIENLNIHHKGSPLGHLTISLGYSQANPQYHNDQNLVIEQADRALYSAKVEGKNRAVAYREQ"
@@ -73,10 +69,8 @@ with Blocks() as demo:
73
  , "MVEQYLLEAIVRDARDGITISDCSRPDNPLVFVNDAFTRMTGYDAEEVIGKNCRFLQRGDINLSAVHTIKIAMLTHEPCLVTLKNYRKDGTIFWNELSLTPIINKNGLITHYLGIQKDVSAQVILNQTLHEENHLLKSNKEMLEYLVNIDALTGLHNRRFLEDQLVIQWKLASRHINTITIFMIDIDYFKAFNDTYGHTAGDEALRTIAKTLNNCFMWGSDFVARYGGEEFTILAIGMTELQAHEYSTKLVQKIENLNIHHKGSPLGHLTISLGYSQANPQYHNDQNLVIEQADRALYSAKVEGKNRAVAYREQ"
74
  , "facebook/esm2_t33_650M_UR50D" ],
75
  ]
76
- , inputs=[ seq
77
- , trg
78
- , model_name ]
79
- , outputs=[out]
80
  , fn=app
81
  , cache_examples=False )
82
  with Tab("Instructions"):
 
1
+ from gradio import Blocks, Button, Checkbox, DownloadButton, Dropdown, Error, Examples, Image, HTML, Markdown, Progress, Tab, Textbox
2
 
3
  from model import ModelFactory
4
  from data import Data
5
 
 
 
 
6
  # Get available models
7
  MODELS = ModelFactory.models()
8
 
9
+ def app(seq, sub, model_name, acc):
10
  "Main application function"
11
+ scoring = "masked-marginals" if acc else "wt-marginals"
 
 
12
 
13
  # Validate the input
14
  if 1 > len(seq):
15
  raise Error("Sequence cannot be empty")
16
+ if 1 > len(sub):
17
  raise Error("Substitutions cannot be empty")
18
 
19
  # Calculate the data based on the input parameters
20
  try:
21
+ data = Data(seq, sub, model_name, scoring).calculate(progress)
22
  if isinstance(data.image, str):
23
  return ( Image(value=data.image, type='filepath', visible=True)
24
+ , HTML()
25
  , DownloadButton(value=data.csv, visible=True) )
26
  else:
27
  return ( Image(visible=False)
 
40
  , label="Sequence"
41
  , placeholder="FASTA sequence here..."
42
  , value='' )
43
+ sub = Textbox( lines=1
44
  , label="Substitutions"
45
  , placeholder="Substitutions here..."
46
  , value='' )
47
+ model_name = Dropdown(MODELS, label="Model", value="facebook/esm2_t30_150M_UR50D")
48
+ acc_box = Checkbox(value=True, label="Use higher accuracy scoring", interactive=True)
49
+ run_btn = Button(value="Run", variant="primary")
50
+ dl_btn = DownloadButton(label="Download raw data", visible=False)
51
+ progress = Progress()
52
+ out_html = HTML()
53
+ out_img = Image(visible=False)
54
+ run_btn.click( fn=app
55
+ , inputs=[seq, sub, model_name, acc_box]
56
+ , outputs=[out_img, out_html, dl_btn] )
57
  ex = Examples(
58
  examples=[
59
  [ "MVEQYLLEAIVRDARDGITISDCSRPDNPLVFVNDAFTRMTGYDAEEVIGKNCRFLQRGDINLSAVHTIKIAMLTHEPCLVTLKNYRKDGTIFWNELSLTPIINKNGLITHYLGIQKDVSAQVILNQTLHEENHLLKSNKEMLEYLVNIDALTGLHNRRFLEDQLVIQWKLASRHINTITIFMIDIDYFKAFNDTYGHTAGDEALRTIAKTLNNCFMRGSDFVARYGGEEFTILAIGMTELQAHEYSTKLVQKIENLNIHHKGSPLGHLTISLGYSQANPQYHNDQNLVIEQADRALYSAKVEGKNRAVAYREQ"
 
69
  , "MVEQYLLEAIVRDARDGITISDCSRPDNPLVFVNDAFTRMTGYDAEEVIGKNCRFLQRGDINLSAVHTIKIAMLTHEPCLVTLKNYRKDGTIFWNELSLTPIINKNGLITHYLGIQKDVSAQVILNQTLHEENHLLKSNKEMLEYLVNIDALTGLHNRRFLEDQLVIQWKLASRHINTITIFMIDIDYFKAFNDTYGHTAGDEALRTIAKTLNNCFMWGSDFVARYGGEEFTILAIGMTELQAHEYSTKLVQKIENLNIHHKGSPLGHLTISLGYSQANPQYHNDQNLVIEQADRALYSAKVEGKNRAVAYREQ"
70
  , "facebook/esm2_t33_650M_UR50D" ],
71
  ]
72
+ , inputs=[seq, sub, model_name]
73
+ , outputs=[out_img]
 
 
74
  , fn=app
75
  , cache_examples=False )
76
  with Tab("Instructions"):
data.py CHANGED
@@ -72,7 +72,7 @@ class Data:
72
  self.parse_seq(src)
73
  self.parse_sub(trg)
74
  self.scoring_strategy = scoring_strategy
75
- self.token_probs = None
76
  self.out = pd.DataFrame(self.sub, columns=['0', self.model_name])
77
  self.out_img = f"{out_file}.png"
78
  self.out_csv = f"{out_file}.csv"
@@ -121,11 +121,9 @@ class Data:
121
  def concat_and_set_axis(self):
122
  return (pd.concat([(self.out.iloc[19*x:19*(x+1)]
123
  .pipe(self.create_dataframe).sort_values(['0'], ascending=[True])
124
- .drop(["resi", '0'], axis=1)
125
- .astype(float)
126
  .set_axis([ 'A', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'K', 'L'
127
- , 'M', 'N', 'P', 'Q', 'R', 'S', 'T', 'V', 'W', 'Y'])) for x in range(self.out.shape[0]//19)]
128
- , axis=1)
129
  .set_axis([f'{a}{i}' for i, a in enumerate(self.seq, 1)], axis="columns"))
130
 
131
  def create_dataframe(self, df):
@@ -181,8 +179,9 @@ class Data:
181
  ax[i].set_xticklabels(ax[i].get_xticklabels(), rotation=90)
182
  fig.tight_layout()
183
 
184
- def calculate(self):
185
  "run model and parse output"
 
186
  self.model.run_model(self)
187
  self.parse_output()
188
  return self
 
72
  self.parse_seq(src)
73
  self.parse_sub(trg)
74
  self.scoring_strategy = scoring_strategy
75
+ self.progress = None
76
  self.out = pd.DataFrame(self.sub, columns=['0', self.model_name])
77
  self.out_img = f"{out_file}.png"
78
  self.out_csv = f"{out_file}.csv"
 
121
  def concat_and_set_axis(self):
122
  return (pd.concat([(self.out.iloc[19*x:19*(x+1)]
123
  .pipe(self.create_dataframe).sort_values(['0'], ascending=[True])
124
+ .drop(["resi", '0'], axis=1).astype(float)
 
125
  .set_axis([ 'A', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'K', 'L'
126
+ , 'M', 'N', 'P', 'Q', 'R', 'S', 'T', 'V', 'W', 'Y'])) for x in range(self.out.shape[0]//19)], axis=1)
 
127
  .set_axis([f'{a}{i}' for i, a in enumerate(self.seq, 1)], axis="columns"))
128
 
129
  def create_dataframe(self, df):
 
179
  ax[i].set_xticklabels(ax[i].get_xticklabels(), rotation=90)
180
  fig.tight_layout()
181
 
182
+ def calculate(self, progress):
183
  "run model and parse output"
184
+ self.progress = progress
185
  self.model.run_model(self)
186
  self.parse_output()
187
  return self
model.py CHANGED
@@ -1,6 +1,5 @@
1
  from huggingface_hub import HfApi
2
  import torch
3
- from tqdm import tqdm
4
  from typing import Any
5
  from transformers import AutoTokenizer, AutoModelForMaskedLM
6
  from transformers.tokenization_utils_base import BatchEncoding
@@ -54,7 +53,7 @@ class ESMModel:
54
  if data.scoring_strategy.startswith("masked-marginals"):
55
  all_token_probs = []
56
  # For each token in the batch
57
- for i in tqdm(range(batch_tokens.size()[1])):
58
  # If the token is in the list of residues
59
  if i in data.resi:
60
  # Clone the batch tokens and mask the current token
@@ -73,9 +72,7 @@ class ESMModel:
73
 
74
  # Apply the label_row function to each row of the substitutions dataframe
75
  data.out[self.model_name] = data.sub.apply(
76
- lambda row: label_row(
77
- row['0']
78
- , token_probs )
79
  , axis=1 )
80
 
81
  class E1Model:
@@ -96,7 +93,7 @@ class E1Model:
96
  self.scorer = E1Scorer(self.model, EncoderScoreMethod.WILDTYPE_MARGINAL)
97
  batch_size = 60 ## chunking to avoid OOM
98
  out = []
99
- for chunk in tqdm([data.trg[i:i+batch_size] for i in range(0, len(data.trg), batch_size)]):
100
  scores = self.scorer.score(parent_sequence=data.seq, sequences=chunk)
101
  out.extend(s['score'] for s in scores)
102
  data.out[self.model_name] = out
 
1
  from huggingface_hub import HfApi
2
  import torch
 
3
  from typing import Any
4
  from transformers import AutoTokenizer, AutoModelForMaskedLM
5
  from transformers.tokenization_utils_base import BatchEncoding
 
53
  if data.scoring_strategy.startswith("masked-marginals"):
54
  all_token_probs = []
55
  # For each token in the batch
56
+ for i in data.progress.tqdm(range(batch_tokens.size()[1]), desc="Calculating"):
57
  # If the token is in the list of residues
58
  if i in data.resi:
59
  # Clone the batch tokens and mask the current token
 
72
 
73
  # Apply the label_row function to each row of the substitutions dataframe
74
  data.out[self.model_name] = data.sub.apply(
75
+ lambda row: label_row(row['0'], token_probs)
 
 
76
  , axis=1 )
77
 
78
  class E1Model:
 
93
  self.scorer = E1Scorer(self.model, EncoderScoreMethod.WILDTYPE_MARGINAL)
94
  batch_size = 60 ## chunking to avoid OOM
95
  out = []
96
+ for chunk in data.progress.tqdm([data.trg[i:i+batch_size] for i in range(0, len(data.trg), batch_size)], desc="Calculating"):
97
  scores = self.scorer.score(parent_sequence=data.seq, sequences=chunk)
98
  out.extend(s['score'] for s in scores)
99
  data.out[self.model_name] = out