vladimir.manuylov committed on
Commit
a26c5b0
·
1 Parent(s): 8140c5e

fix app.py

Browse files
Files changed (1) hide show
  1. app.py +15 -7
app.py CHANGED
@@ -1,8 +1,7 @@
1
  # app.py
2
  # --- IMPORTS ---
3
  import re
4
- from pathlib import Path
5
-
6
  import gradio as gr
7
  import torch
8
  from torch.utils.data import DataLoader
@@ -28,11 +27,15 @@ def generate_smiles_for_sequence(protein_sequence: str, num_samples: int):
28
  if len(protein_sequence) < 10:
29
  raise gr.Error("Protein sequence is too short.")
30
 
31
- embedding = get_esm_embedding(
32
- protein_sequence,
33
- 'esm2_t33_650M_UR50D',
34
- device
35
- ).to(dtype=torch.bfloat16)
 
 
 
 
36
  n_batches = num_samples // 10
37
  dataset = InferenceDataset(embedding, batch_size=10, n_batches=n_batches)
38
  loader = DataLoader(dataset, batch_size=None)
@@ -56,6 +59,10 @@ def generate_smiles_for_sequence(protein_sequence: str, num_samples: int):
56
 
57
  # Load models on app startup
58
  device = "cuda" if torch.cuda.is_available() else "cpu"
 
 
 
 
59
  tokenizer_path = hf_hub_download(
60
  repo_id=REPO_ID,
61
  filename=TOKENIZER_FILENAME,
@@ -141,6 +148,7 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
141
 
142
  # Launch the app
143
  if __name__ == "__main__":
 
144
  demo.launch(share=True)
145
 
146
 
 
1
  # app.py
2
  # --- IMPORTS ---
3
  import re
4
+ import esm
 
5
  import gradio as gr
6
  import torch
7
  from torch.utils.data import DataLoader
 
27
  if len(protein_sequence) < 10:
28
  raise gr.Error("Protein sequence is too short.")
29
 
30
+ print(">> inference started, attempts:", num_samples, flush=True)
31
+
32
+ with torch.no_grad():
33
+ batch_converter = alphabet.get_batch_converter()
34
+ _, _, tokens = batch_converter([("protein", protein_sequence)])
35
+ tokens = tokens.to(device)
36
+ embedding = esm_model(tokens, repr_layers=[33])["representations"][33][:, 1:-1, :]
37
+ embedding = embedding.float() if device == "cpu" else embedding.bfloat16()
38
+
39
  n_batches = num_samples // 10
40
  dataset = InferenceDataset(embedding, batch_size=10, n_batches=n_batches)
41
  loader = DataLoader(dataset, batch_size=None)
 
59
 
60
  # Load models on app startup
61
  device = "cuda" if torch.cuda.is_available() else "cpu"
62
+ esm_model, alphabet = esm.pretrained.load_model_and_alphabet('esm2_t33_650M_UR50D')
63
+ esm_model.eval()
64
+ esm_model = esm_model.to(device)
65
+
66
  tokenizer_path = hf_hub_download(
67
  repo_id=REPO_ID,
68
  filename=TOKENIZER_FILENAME,
 
148
 
149
  # Launch the app
150
  if __name__ == "__main__":
151
+ demo.queue(max_size=10)
152
  demo.launch(share=True)
153
 
154