mgtotaro committed on
Commit
08446bb
·
1 Parent(s): 8ecc9a8

E1 model addition

Browse files
Files changed (4) hide show
  1. app.py +46 -49
  2. data.py +41 -36
  3. model.py +65 -25
  4. requirements.txt +1 -0
app.py CHANGED
@@ -1,13 +1,13 @@
1
- from gradio import Blocks, Button, Checkbox, DataFrame, DownloadButton, Dropdown, Examples, Image, Markdown, Tab, Textbox
2
 
3
- from model import get_models
4
  from data import Data
5
 
6
  # Define scoring strategies
7
  SCORING = ["wt-marginals", "masked-marginals"]
8
 
9
  # Get available models
10
- MODELS = get_models()
11
 
12
  def app(*argv):
13
  """
@@ -17,12 +17,15 @@ def app(*argv):
17
  seq, trg, model_name, *_ = argv
18
  scoring = SCORING[scoring_strategy.value]
19
  # Calculate the data based on the input parameters
20
- data = Data(seq, trg, model_name, scoring).calculate()
21
-
22
- if isinstance(data.image(), str):
23
- out = Image(value=data.image(), type='filepath', visible=True), DataFrame(visible=False)
24
- else:
25
- out = Image(visible=False), DataFrame(value=data.image(), visible=True)
 
 
 
26
 
27
  return *out, DownloadButton(value=data.csv(), visible=True)
28
 
@@ -32,58 +35,52 @@ with Blocks() as esm_scan:
32
  # Define the interface components
33
  with Tab("App"):
34
  Markdown(open("header.md", "r", encoding="utf-8").read())
35
- seq = Textbox(
36
- lines=2,
37
- label="Sequence",
38
- placeholder="FASTA sequence here...",
39
- value=''
40
- )
41
- trg = Textbox(
42
- lines=1,
43
- label="Substitutions",
44
- placeholder="Substitutions here...",
45
- value=""
46
- )
47
  model_name = Dropdown(MODELS, label="Model", value="facebook/esm2_t30_150M_UR50D")
48
  scoring_strategy = Checkbox(value=True, label="Use higher accuracy scoring", interactive=True)
49
  btn = Button(value="Run", variant="primary")
50
  dlb = DownloadButton(label="Download raw data", visible=False)
51
  out = Image(visible=False)
52
  ouu = DataFrame(visible=False)
53
- btn.click(
54
- fn=app,
55
- inputs=[seq, trg, model_name],
56
- outputs=[out, ouu, dlb]
57
- )
58
  ex = Examples(
59
  examples=[
60
- [
61
- "MVEQYLLEAIVRDARDGITISDCSRPDNPLVFVNDAFTRMTGYDAEEVIGKNCRFLQRGDINLSAVHTIKIAMLTHEPCLVTLKNYRKDGTIFWNELSLTPIINKNGLITHYLGIQKDVSAQVILNQTLHEENHLLKSNKEMLEYLVNIDALTGLHNRRFLEDQLVIQWKLASRHINTITIFMIDIDYFKAFNDTYGHTAGDEALRTIAKTLNNCFMRGSDFVARYGGEEFTILAIGMTELQAHEYSTKLVQKIENLNIHHKGSPLGHLTISLGYSQANPQYHNDQNLVIEQADRALYSAKVEGKNRAVAYREQ",
62
- "deep mutational scanning",
63
- "facebook/esm2_t6_8M_UR50D"
64
  ],
65
- [
66
- "MVEQYLLEAIVRDARDGITISDCSRPDNPLVFVNDAFTRMTGYDAEEVIGKNCRFLQRGDINLSAVHTIKIAMLTHEPCLVTLKNYRKDGTIFWNELSLTPIINKNGLITHYLGIQKDVSAQVILNQTLHEENHLLKSNKEMLEYLVNIDALTGLHNRRFLEDQLVIQWKLASRHINTITIFMIDIDYFKAFNDTYGHTAGDEALRTIAKTLNNCFMRGSDFVARYGGEEFTILAIGMTELQAHEYSTKLVQKIENLNIHHKGSPLGHLTISLGYSQANPQYHNDQNLVIEQADRALYSAKVEGKNRAVAYREQ",
67
- "217 218 219",
68
- "facebook/esm2_t12_35M_UR50D"
69
  ],
70
- [
71
- "MVEQYLLEAIVRDARDGITISDCSRPDNPLVFVNDAFTRMTGYDAEEVIGKNCRFLQRGDINLSAVHTIKIAMLTHEPCLVTLKNYRKDGTIFWNELSLTPIINKNGLITHYLGIQKDVSAQVILNQTLHEENHLLKSNKEMLEYLVNIDALTGLHNRRFLEDQLVIQWKLASRHINTITIFMIDIDYFKAFNDTYGHTAGDEALRTIAKTLNNCFMRGSDFVARYGGEEFTILAIGMTELQAHEYSTKLVQKIENLNIHHKGSPLGHLTISLGYSQANPQYHNDQNLVIEQADRALYSAKVEGKNRAVAYREQ",
72
- "R218K R218S R218N R218A R218V R218D",
73
- "facebook/esm2_t30_150M_UR50D",
74
  ],
75
- [
76
- "MVEQYLLEAIVRDARDGITISDCSRPDNPLVFVNDAFTRMTGYDAEEVIGKNCRFLQRGDINLSAVHTIKIAMLTHEPCLVTLKNYRKDGTIFWNELSLTPIINKNGLITHYLGIQKDVSAQVILNQTLHEENHLLKSNKEMLEYLVNIDALTGLHNRRFLEDQLVIQWKLASRHINTITIFMIDIDYFKAFNDTYGHTAGDEALRTIAKTLNNCFMRGSDFVARYGGEEFTILAIGMTELQAHEYSTKLVQKIENLNIHHKGSPLGHLTISLGYSQANPQYHNDQNLVIEQADRALYSAKVEGKNRAVAYREQ",
77
- "MVEQYLLEAIVRDARDGITISDCSRPDNPLVFVNDAFTRMTGYDAEEVIGKNCRFLQRGDINLSAVHTIKIAMLTHEPCLVTLKNYRKDGTIFWNELSLTPIINKNGLITHYLGIQKDVSAQVILNQTLHEENHLLKSNKEMLEYLVNIDALTGLHNRRFLEDQLVIQWKLASRHINTITIFMIDIDYFKAFNDTYGHTAGDEALRTIAKTLNNCFMWGSDFVARYGGEEFTILAIGMTELQAHEYSTKLVQKIENLNIHHKGSPLGHLTISLGYSQANPQYHNDQNLVIEQADRALYSAKVEGKNRAVAYREQ",
78
- "facebook/esm2_t33_650M_UR50D",
79
  ],
80
- ],
81
- inputs=[seq,
82
- trg,
83
- model_name],
84
- outputs=[out],
85
- fn=app,
86
- cache_examples=False
 
87
  )
88
  with Tab("Instructions"):
89
  Markdown(open("instructions.md", "r", encoding="utf-8").read())
 
1
+ from gradio import Blocks, Button, Checkbox, DataFrame, DownloadButton, Dropdown, Error, Examples, Image, Markdown, Tab, Textbox
2
 
3
+ from model import ModelFactory
4
  from data import Data
5
 
6
  # Define scoring strategies
7
  SCORING = ["wt-marginals", "masked-marginals"]
8
 
9
  # Get available models
10
+ MODELS = ModelFactory.models()
11
 
12
  def app(*argv):
13
  """
 
17
  seq, trg, model_name, *_ = argv
18
  scoring = SCORING[scoring_strategy.value]
19
  # Calculate the data based on the input parameters
20
+ try:
21
+ data = Data(seq, trg, model_name, scoring).calculate()
22
+ if isinstance(data.image(), str):
23
+ out = Image(value=data.image(), type='filepath', visible=True), DataFrame(visible=False)
24
+ else:
25
+ out = Image(visible=False), DataFrame(value=data.image(), visible=True)
26
+ except Exception as e:
27
+ out = Image(visible=False), DataFrame(visible=False)
28
+ raise Error(str(e))
29
 
30
  return *out, DownloadButton(value=data.csv(), visible=True)
31
 
 
35
  # Define the interface components
36
  with Tab("App"):
37
  Markdown(open("header.md", "r", encoding="utf-8").read())
38
+ seq = Textbox( lines=2
39
+ , label="Sequence"
40
+ , placeholder="FASTA sequence here..."
41
+ , value=''
42
+ )
43
+ trg = Textbox( lines=1
44
+ , label="Substitutions"
45
+ , placeholder="Substitutions here..."
46
+ , value=""
47
+ )
 
 
48
  model_name = Dropdown(MODELS, label="Model", value="facebook/esm2_t30_150M_UR50D")
49
  scoring_strategy = Checkbox(value=True, label="Use higher accuracy scoring", interactive=True)
50
  btn = Button(value="Run", variant="primary")
51
  dlb = DownloadButton(label="Download raw data", visible=False)
52
  out = Image(visible=False)
53
  ouu = DataFrame(visible=False)
54
+ btn.click( fn=app
55
+ , inputs=[seq, trg, model_name]
56
+ , outputs=[out, ouu, dlb]
57
+ )
 
58
  ex = Examples(
59
  examples=[
60
+ [ "MVEQYLLEAIVRDARDGITISDCSRPDNPLVFVNDAFTRMTGYDAEEVIGKNCRFLQRGDINLSAVHTIKIAMLTHEPCLVTLKNYRKDGTIFWNELSLTPIINKNGLITHYLGIQKDVSAQVILNQTLHEENHLLKSNKEMLEYLVNIDALTGLHNRRFLEDQLVIQWKLASRHINTITIFMIDIDYFKAFNDTYGHTAGDEALRTIAKTLNNCFMRGSDFVARYGGEEFTILAIGMTELQAHEYSTKLVQKIENLNIHHKGSPLGHLTISLGYSQANPQYHNDQNLVIEQADRALYSAKVEGKNRAVAYREQ"
61
+ , "deep mutational scanning"
62
+ , "facebook/esm2_t6_8M_UR50D"
 
63
  ],
64
+ [ "MVEQYLLEAIVRDARDGITISDCSRPDNPLVFVNDAFTRMTGYDAEEVIGKNCRFLQRGDINLSAVHTIKIAMLTHEPCLVTLKNYRKDGTIFWNELSLTPIINKNGLITHYLGIQKDVSAQVILNQTLHEENHLLKSNKEMLEYLVNIDALTGLHNRRFLEDQLVIQWKLASRHINTITIFMIDIDYFKAFNDTYGHTAGDEALRTIAKTLNNCFMRGSDFVARYGGEEFTILAIGMTELQAHEYSTKLVQKIENLNIHHKGSPLGHLTISLGYSQANPQYHNDQNLVIEQADRALYSAKVEGKNRAVAYREQ"
65
+ , "217 218 219"
66
+ , "facebook/esm2_t12_35M_UR50D"
 
67
  ],
68
+ [ "MVEQYLLEAIVRDARDGITISDCSRPDNPLVFVNDAFTRMTGYDAEEVIGKNCRFLQRGDINLSAVHTIKIAMLTHEPCLVTLKNYRKDGTIFWNELSLTPIINKNGLITHYLGIQKDVSAQVILNQTLHEENHLLKSNKEMLEYLVNIDALTGLHNRRFLEDQLVIQWKLASRHINTITIFMIDIDYFKAFNDTYGHTAGDEALRTIAKTLNNCFMRGSDFVARYGGEEFTILAIGMTELQAHEYSTKLVQKIENLNIHHKGSPLGHLTISLGYSQANPQYHNDQNLVIEQADRALYSAKVEGKNRAVAYREQ"
69
+ , "R218K R218S R218N R218A R218V R218D"
70
+ , "facebook/esm2_t30_150M_UR50D"
 
71
  ],
72
+ [ "MVEQYLLEAIVRDARDGITISDCSRPDNPLVFVNDAFTRMTGYDAEEVIGKNCRFLQRGDINLSAVHTIKIAMLTHEPCLVTLKNYRKDGTIFWNELSLTPIINKNGLITHYLGIQKDVSAQVILNQTLHEENHLLKSNKEMLEYLVNIDALTGLHNRRFLEDQLVIQWKLASRHINTITIFMIDIDYFKAFNDTYGHTAGDEALRTIAKTLNNCFMRGSDFVARYGGEEFTILAIGMTELQAHEYSTKLVQKIENLNIHHKGSPLGHLTISLGYSQANPQYHNDQNLVIEQADRALYSAKVEGKNRAVAYREQ"
73
+ , "MVEQYLLEAIVRDARDGITISDCSRPDNPLVFVNDAFTRMTGYDAEEVIGKNCRFLQRGDINLSAVHTIKIAMLTHEPCLVTLKNYRKDGTIFWNELSLTPIINKNGLITHYLGIQKDVSAQVILNQTLHEENHLLKSNKEMLEYLVNIDALTGLHNRRFLEDQLVIQWKLASRHINTITIFMIDIDYFKAFNDTYGHTAGDEALRTIAKTLNNCFMWGSDFVARYGGEEFTILAIGMTELQAHEYSTKLVQKIENLNIHHKGSPLGHLTISLGYSQANPQYHNDQNLVIEQADRALYSAKVEGKNRAVAYREQ"
74
+ , "facebook/esm2_t33_650M_UR50D"
 
75
  ],
76
+ ]
77
+ , inputs=[ seq
78
+ , trg
79
+ , model_name
80
+ ]
81
+ , outputs=[out]
82
+ , fn=app
83
+ , cache_examples=False
84
  )
85
  with Tab("Instructions"):
86
  Markdown(open("instructions.md", "r", encoding="utf-8").read())
data.py CHANGED
@@ -4,18 +4,15 @@ import pandas as pd
4
  from re import match
5
  import seaborn as sns
6
 
7
- from model import Model
8
 
9
  class Data:
10
  """Container for input and output data"""
11
- # Initialise empty model as static class member for efficiency
12
- model = Model()
13
-
14
  def parse_seq(self, src: str):
15
  """Parse input sequence"""
16
  self.seq = src.strip().upper().replace('\n', '')
17
  if not all(x in self.model.alphabet for x in self.seq):
18
- raise RuntimeError("Unrecognised characters in sequence")
19
 
20
  def parse_sub(self, trg: str):
21
  """Parse input substitutions"""
@@ -36,34 +33,42 @@ class Data:
36
  if all(match(r'\d+', x) for x in self.trg):
37
  # If all strings are numbers, deep mutational scanning mode
38
  self.mode = 'DMS'
 
39
  for resi in map(int, self.trg):
40
  src = self.seq[resi-1]
41
  for trg in "ACDEFGHIKLMNPQRSTVWY".replace(src, ''):
42
  self.sub.append(f"{src}{resi}{trg}")
 
43
  self.resi.append(resi)
 
44
  elif all(match(r'[A-Z]\d+[A-Z]', x) for x in self.trg):
45
  # If all strings are of the form X#Y, single substitution mode
46
  self.mode = 'MUT'
47
  self.sub = self.trg
48
- self.resi = [int(x[1:-1]) for x in self.trg]
 
 
 
 
49
  for s, *resi, _ in self.trg:
50
  if self.seq[int(''.join(resi))-1] != s:
51
- raise RuntimeError(f"Unrecognised input substitution {self.seq[int(''.join(resi))]}{int(''.join(resi))} /= {s}{int(''.join(resi))}")
 
52
  else:
53
  self.mode = 'TMS'
 
54
  for resi, src in enumerate(self.seq, 1):
55
  for trg in "ACDEFGHIKLMNPQRSTVWY".replace(src, ''):
56
  self.sub.append(f"{src}{resi}{trg}")
 
57
  self.resi.append(resi)
58
 
59
  self.sub = pd.DataFrame(self.sub, columns=['0'])
60
 
61
  def __init__(self, src:str, trg:str, model_name:str='facebook/esm2_t33_650M_UR50D', scoring_strategy:str='masked-marginals', out_file='out'):
62
  "initialise data"
63
- # if model has changed, load new model
64
- if self.model.model_name != model_name:
65
- self.model_name = model_name
66
- self.model = Model(model_name)
67
  self.parse_seq(src)
68
  self.offset = 0
69
  self.parse_sub(trg)
@@ -101,8 +106,8 @@ class Data:
101
  .groupby(['resi'])
102
  .head(19)
103
  .drop(['resi'], axis=1))
104
- self.out = pd.concat([self.out.iloc[19*x:19*(x+1)].reset_index(drop=True) for x in range(self.out.shape[0]//19)]
105
- , axis=1).set_axis(range(self.out.shape[0]//19*2), axis='columns')
106
 
107
  def process_tms_mode(self):
108
  self.out = self.assign_resi_and_group()
@@ -124,8 +129,8 @@ class Data:
124
  .pipe(self.create_dataframe)
125
  .sort_values(['0'], ascending=[True])
126
  .drop(['resi', '0'], axis=1)
127
- .set_axis(['A', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'K', 'L',
128
- 'M', 'N', 'P', 'Q', 'R', 'S', 'T', 'V', 'W', 'Y'])
129
  .astype(float)
130
  ) for x in range(self.out.shape[0]//19)]
131
  , axis=1)
@@ -152,16 +157,16 @@ class Data:
152
 
153
  def plot_single_heatmap(self):
154
  fig = plt.figure(figsize=(12, 6))
155
- sns.heatmap(self.out
156
- , cmap='RdBu'
157
- , cbar=False
158
- , square=True
159
- , xticklabels=1
160
- , yticklabels=1
161
- , center=0
162
- , annot=self.out.map(lambda x: ' ' if x != 0 else '·')
163
- , fmt='s'
164
- , annot_kws={'size': 'xx-large'})
165
  fig.tight_layout()
166
 
167
  def plot_multiple_heatmaps(self, ncols, nrows):
@@ -169,17 +174,17 @@ class Data:
169
  for i in range(nrows):
170
  tmp = self.out.iloc[:,i*ncols:(i+1)*ncols]
171
  label = tmp.map(lambda x: ' ' if x != 0 else '·')
172
- sns.heatmap(tmp
173
- , ax=ax[i]
174
- , cmap='RdBu'
175
- , cbar=False
176
- , square=True
177
- , xticklabels=1
178
- , yticklabels=1
179
- , center=0
180
- , annot=label
181
- , fmt='s'
182
- , annot_kws={'size': 'xx-large'})
183
  ax[i].set_yticklabels(ax[i].get_yticklabels(), rotation=0)
184
  ax[i].set_xticklabels(ax[i].get_xticklabels(), rotation=90)
185
  fig.tight_layout()
 
4
  from re import match
5
  import seaborn as sns
6
 
7
+ from model import ModelFactory
8
 
9
  class Data:
10
  """Container for input and output data"""
 
 
 
11
  def parse_seq(self, src: str):
12
  """Parse input sequence"""
13
  self.seq = src.strip().upper().replace('\n', '')
14
  if not all(x in self.model.alphabet for x in self.seq):
15
+ raise RuntimeError(f"Unsupported characters in sequence: {''.join(x for x in self.seq if x not in self.model.alphabet)}")
16
 
17
  def parse_sub(self, trg: str):
18
  """Parse input substitutions"""
 
33
  if all(match(r'\d+', x) for x in self.trg):
34
  # If all strings are numbers, deep mutational scanning mode
35
  self.mode = 'DMS'
36
+ trh = list()
37
  for resi in map(int, self.trg):
38
  src = self.seq[resi-1]
39
  for trg in "ACDEFGHIKLMNPQRSTVWY".replace(src, ''):
40
  self.sub.append(f"{src}{resi}{trg}")
41
+ trh.append(self.seq[:resi-1]+trg+self.seq[resi:])
42
  self.resi.append(resi)
43
+ self.trg = trh
44
  elif all(match(r'[A-Z]\d+[A-Z]', x) for x in self.trg):
45
  # If all strings are of the form X#Y, single substitution mode
46
  self.mode = 'MUT'
47
  self.sub = self.trg
48
+ trh = list()
49
+ for x in self.trg:
50
+ idx = int(x[1:-1])
51
+ self.resi.append(idx)
52
+ trh.append(self.seq[:idx-1]+x[-1]+self.seq[idx:])
53
  for s, *resi, _ in self.trg:
54
  if self.seq[int(''.join(resi))-1] != s:
55
+ raise RuntimeError(f"Unrecognised input substitution: {self.seq[int(''.join(resi))]}{int(''.join(resi))} /= {s}{int(''.join(resi))}")
56
+ self.trg = trh
57
  else:
58
  self.mode = 'TMS'
59
+ self.trg = list()
60
  for resi, src in enumerate(self.seq, 1):
61
  for trg in "ACDEFGHIKLMNPQRSTVWY".replace(src, ''):
62
  self.sub.append(f"{src}{resi}{trg}")
63
+ self.trg.append(self.seq[:resi-1]+trg+self.seq[resi:])
64
  self.resi.append(resi)
65
 
66
  self.sub = pd.DataFrame(self.sub, columns=['0'])
67
 
68
  def __init__(self, src:str, trg:str, model_name:str='facebook/esm2_t33_650M_UR50D', scoring_strategy:str='masked-marginals', out_file='out'):
69
  "initialise data"
70
+ self.model_name = model_name
71
+ self.model = ModelFactory(model_name)
 
 
72
  self.parse_seq(src)
73
  self.offset = 0
74
  self.parse_sub(trg)
 
106
  .groupby(['resi'])
107
  .head(19)
108
  .drop(['resi'], axis=1))
109
+ self.out = pd.concat([ self.out.iloc[19*x:19*(x+1)].reset_index(drop=True) for x in range(self.out.shape[0]//19)]
110
+ , axis=1).set_axis(range(self.out.shape[0]//19*2), axis='columns')
111
 
112
  def process_tms_mode(self):
113
  self.out = self.assign_resi_and_group()
 
129
  .pipe(self.create_dataframe)
130
  .sort_values(['0'], ascending=[True])
131
  .drop(['resi', '0'], axis=1)
132
+ .set_axis([ 'A', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'K', 'L'
133
+ , 'M', 'N', 'P', 'Q', 'R', 'S', 'T', 'V', 'W', 'Y'])
134
  .astype(float)
135
  ) for x in range(self.out.shape[0]//19)]
136
  , axis=1)
 
157
 
158
  def plot_single_heatmap(self):
159
  fig = plt.figure(figsize=(12, 6))
160
+ sns.heatmap( self.out
161
+ , cmap='RdBu'
162
+ , cbar=False
163
+ , square=True
164
+ , xticklabels=1
165
+ , yticklabels=1
166
+ , center=0
167
+ , annot=self.out.map(lambda x: ' ' if x != 0 else '·')
168
+ , fmt='s'
169
+ , annot_kws={'size': 'xx-large'})
170
  fig.tight_layout()
171
 
172
  def plot_multiple_heatmaps(self, ncols, nrows):
 
174
  for i in range(nrows):
175
  tmp = self.out.iloc[:,i*ncols:(i+1)*ncols]
176
  label = tmp.map(lambda x: ' ' if x != 0 else '·')
177
+ sns.heatmap( tmp
178
+ , ax=ax[i]
179
+ , cmap='RdBu'
180
+ , cbar=False
181
+ , square=True
182
+ , xticklabels=1
183
+ , yticklabels=1
184
+ , center=0
185
+ , annot=label
186
+ , fmt='s'
187
+ , annot_kws={'size': 'xx-large'})
188
  ax[i].set_yticklabels(ax[i].get_yticklabels(), rotation=0)
189
  ax[i].set_xticklabels(ax[i].get_xticklabels(), rotation=90)
190
  fig.tight_layout()
model.py CHANGED
@@ -1,31 +1,17 @@
1
  from huggingface_hub import HfApi
2
  import torch
3
  from tqdm import tqdm
 
4
  from transformers import AutoTokenizer, AutoModelForMaskedLM
5
  from transformers.tokenization_utils_base import BatchEncoding
6
  from transformers.modeling_outputs import MaskedLMOutput
7
-
8
- # Function to fetch suitable ESM models from HuggingFace Hub
9
- def get_models() -> list[None|str]:
10
- """Fetch suitable ESM models from HuggingFace Hub."""
11
- if not any(
12
- out := [
13
- m.modelId for m in HfApi().list_models(
14
- author="facebook",
15
- model_name="esm",
16
- task="fill-mask",
17
- sort="lastModified",
18
- direction=-1
19
- )
20
- ]
21
- ):
22
- raise RuntimeError("Error while retrieving models from HuggingFace Hub")
23
- return out
24
 
25
  # Class to wrap ESM models
26
- class Model:
27
  """Wrapper for ESM models."""
28
- def __init__(self, model_name: str = ""):
29
  """Load selected model and tokenizer."""
30
  self.model_name = model_name
31
  if model_name:
@@ -95,9 +81,63 @@ class Model:
95
 
96
  # Apply the label_row function to each row of the substitutions dataframe
97
  data.out[self.model_name] = data.sub.apply(
98
- lambda row: label_row(
99
- row['0'],
100
- token_probs,
101
- ),
102
- axis=1,
103
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  from huggingface_hub import HfApi
2
  import torch
3
  from tqdm import tqdm
4
+ from typing import Any
5
  from transformers import AutoTokenizer, AutoModelForMaskedLM
6
  from transformers.tokenization_utils_base import BatchEncoding
7
  from transformers.modeling_outputs import MaskedLMOutput
8
+ from E1.modeling import E1ForMaskedLM
9
+ from E1.scorer import E1Scorer, EncoderScoreMethod
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
 
11
  # Class to wrap ESM models
12
+ class ESMModel:
13
  """Wrapper for ESM models."""
14
+ def __init__(self, model_name:str):
15
  """Load selected model and tokenizer."""
16
  self.model_name = model_name
17
  if model_name:
 
81
 
82
  # Apply the label_row function to each row of the substitutions dataframe
83
  data.out[self.model_name] = data.sub.apply(
84
+ lambda row: label_row(
85
+ row['0']
86
+ , token_probs
87
+ )
88
+ , axis=1
89
+ )
90
+
91
# Class to wrap E1 models
class E1Model:
    """Wrapper for Profluent E1 masked-LM models, mirroring the ESM wrapper API."""

    def __init__(self, model_name: str):
        """Load the selected E1 model and build a scorer for it.

        Args:
            model_name: HuggingFace model id; an empty string leaves the
                wrapper unloaded (no model/scorer/alphabet attributes).
        """
        self.model_name = model_name
        # Default to the higher-accuracy masked-marginal scoring.
        self.scoring_strategy = EncoderScoreMethod.MASKED_MARGINAL
        if model_name:
            self.model = E1ForMaskedLM.from_pretrained(model_name, dtype=torch.float)
            if torch.cuda.is_available():
                self.model = self.model.cuda()
                self.device = torch.device("cuda")
            else:
                self.device = torch.device("cpu")
            self.scorer = E1Scorer(self.model, method=self.scoring_strategy)
            # The scorer vocabulary doubles as the accepted residue alphabet.
            self.alphabet = self.scorer.vocab

    def run_model(self, data):
        """Score data.trg variant sequences against data.seq.

        Writes one score per variant into data.out under this model's name.
        """
        # Select the scoring method for THIS call. The previous version only
        # ever switched masked -> wildtype and never back, so a subsequent
        # masked-marginals request silently reused the wildtype scorer.
        if data.scoring_strategy.startswith("masked-marginals"):
            method = EncoderScoreMethod.MASKED_MARGINAL
        else:
            method = EncoderScoreMethod.WILDTYPE_MARGINAL
        if method is not self.scoring_strategy:
            self.scoring_strategy = method
            self.scorer = E1Scorer(self.model, method=method)
        scores = self.scorer.score(parent_sequence=data.seq, sequences=data.trg)
        data.out[self.model_name] = [s['score'] for s in scores]
111
+
112
class ModelFactory:
    """Registry mapping model names to their wrapper classes.

    Calling ``ModelFactory(name)`` returns an instance of the wrapper
    class registered for that name.
    """

    # Built once at import time from the HuggingFace Hub listings:
    # facebook ESM fill-mask models, then Profluent-Bio E1 models
    # (later entries would overwrite earlier ones on a name clash).
    _models = dict(
        [(m.modelId, ESMModel) for m in HfApi().list_models(
            author="facebook",
            model_name="esm",
            filter="fill-mask",
            sort="lastModified",
            direction=-1)]
        + [(m.modelId, E1Model) for m in HfApi().list_models(
            author="Profluent-Bio",
            model_name="E1",
            sort="lastModified",
            direction=-1)]
    )

    @classmethod
    def register(cls, model_name, model_cls):
        """Add (or replace) the wrapper class used for model_name."""
        cls._models[model_name] = model_cls

    @classmethod
    def models(cls):
        """Return the registered model names as a list."""
        return list(cls._models)

    def __new__(cls, model_name: str) -> Any:
        """Instantiate and return the wrapper registered for model_name."""
        return cls._models[model_name](model_name)
141
+
142
+ for m,c in ModelFactory._models.items():
143
+ ModelFactory.register(m, c)
requirements.txt CHANGED
@@ -3,3 +3,4 @@ pandas
3
  seaborn
4
  torch
5
  transformers
 
 
3
  seaborn
4
  torch
5
  transformers
6
+ E1@git+https://github.com/Profluent-AI/E1.git@main#egg=E1