vladimir.manuylov commited on
Commit
f42fb15
·
1 Parent(s): a26c5b0

fixed app for zero gpu

Browse files
Files changed (1) hide show
  1. app.py +6 -3
app.py CHANGED
@@ -10,13 +10,14 @@ from protobind_diff.esm_inference import get_esm_embedding
10
  from protobind_diff.model import ModelGenerator
11
  from protobind_diff.data_loader import InferenceDataset
12
  from huggingface_hub import hf_hub_download
 
13
 
14
  # Hugging Face Hub details
15
  REPO_ID = "ai-gero/ProtoBind-Diff"
16
  MODEL_FILENAME = "model.ckpt"
17
  TOKENIZER_FILENAME = "tokenizer_smiles_diffusion.json"
18
 
19
-
20
  def generate_smiles_for_sequence(protein_sequence: str, num_samples: int):
21
  """
22
  The main prediction function that runs the full pipeline.
@@ -27,7 +28,9 @@ def generate_smiles_for_sequence(protein_sequence: str, num_samples: int):
27
  if len(protein_sequence) < 10:
28
  raise gr.Error("Protein sequence is too short.")
29
 
30
- print(">> inference started, attempts:", num_samples, flush=True)
 
 
31
 
32
  with torch.no_grad():
33
  batch_converter = alphabet.get_batch_converter()
@@ -58,7 +61,7 @@ def generate_smiles_for_sequence(protein_sequence: str, num_samples: int):
58
  # --- GRADIO APP DEFINITION ---
59
 
60
  # Load models on app startup
61
- device = "cuda" if torch.cuda.is_available() else "cpu"
62
  esm_model, alphabet = esm.pretrained.load_model_and_alphabet('esm2_t33_650M_UR50D')
63
  esm_model.eval()
64
  esm_model = esm_model.to(device)
 
10
  from protobind_diff.model import ModelGenerator
11
  from protobind_diff.data_loader import InferenceDataset
12
  from huggingface_hub import hf_hub_download
13
+ import spaces
14
 
15
  # Hugging Face Hub details
16
  REPO_ID = "ai-gero/ProtoBind-Diff"
17
  MODEL_FILENAME = "model.ckpt"
18
  TOKENIZER_FILENAME = "tokenizer_smiles_diffusion.json"
19
 
20
+ @spaces.GPU(duration=120)
21
  def generate_smiles_for_sequence(protein_sequence: str, num_samples: int):
22
  """
23
  The main prediction function that runs the full pipeline.
 
28
  if len(protein_sequence) < 10:
29
  raise gr.Error("Protein sequence is too short.")
30
 
31
+ device = "cuda"
32
+ esm_model.to(device)
33
+ protobind_model.to(device)
34
 
35
  with torch.no_grad():
36
  batch_converter = alphabet.get_batch_converter()
 
61
  # --- GRADIO APP DEFINITION ---
62
 
63
  # Load models on app startup
64
+ device = "cpu"
65
  esm_model, alphabet = esm.pretrained.load_model_and_alphabet('esm2_t33_650M_UR50D')
66
  esm_model.eval()
67
  esm_model = esm_model.to(device)