Spaces:
Sleeping
Sleeping
vladimir.manuylov
committed on
Commit
·
f42fb15
1
Parent(s):
a26c5b0
fixed app for zero gpu
Browse files
app.py
CHANGED
|
@@ -10,13 +10,14 @@ from protobind_diff.esm_inference import get_esm_embedding
|
|
| 10 |
from protobind_diff.model import ModelGenerator
|
| 11 |
from protobind_diff.data_loader import InferenceDataset
|
| 12 |
from huggingface_hub import hf_hub_download
|
|
|
|
| 13 |
|
| 14 |
# Hugging Face Hub details
|
| 15 |
REPO_ID = "ai-gero/ProtoBind-Diff"
|
| 16 |
MODEL_FILENAME = "model.ckpt"
|
| 17 |
TOKENIZER_FILENAME = "tokenizer_smiles_diffusion.json"
|
| 18 |
|
| 19 |
-
|
| 20 |
def generate_smiles_for_sequence(protein_sequence: str, num_samples: int):
|
| 21 |
"""
|
| 22 |
The main prediction function that runs the full pipeline.
|
|
@@ -27,7 +28,9 @@ def generate_smiles_for_sequence(protein_sequence: str, num_samples: int):
|
|
| 27 |
if len(protein_sequence) < 10:
|
| 28 |
raise gr.Error("Protein sequence is too short.")
|
| 29 |
|
| 30 |
-
|
|
|
|
|
|
|
| 31 |
|
| 32 |
with torch.no_grad():
|
| 33 |
batch_converter = alphabet.get_batch_converter()
|
|
@@ -58,7 +61,7 @@ def generate_smiles_for_sequence(protein_sequence: str, num_samples: int):
|
|
| 58 |
# --- GRADIO APP DEFINITION ---
|
| 59 |
|
| 60 |
# Load models on app startup
|
| 61 |
-
device = "
|
| 62 |
esm_model, alphabet = esm.pretrained.load_model_and_alphabet('esm2_t33_650M_UR50D')
|
| 63 |
esm_model.eval()
|
| 64 |
esm_model = esm_model.to(device)
|
|
|
|
| 10 |
from protobind_diff.model import ModelGenerator
|
| 11 |
from protobind_diff.data_loader import InferenceDataset
|
| 12 |
from huggingface_hub import hf_hub_download
|
| 13 |
+
import spaces
|
| 14 |
|
| 15 |
# Hugging Face Hub details
|
| 16 |
REPO_ID = "ai-gero/ProtoBind-Diff"
|
| 17 |
MODEL_FILENAME = "model.ckpt"
|
| 18 |
TOKENIZER_FILENAME = "tokenizer_smiles_diffusion.json"
|
| 19 |
|
| 20 |
+
@spaces.GPU(duration=120)
|
| 21 |
def generate_smiles_for_sequence(protein_sequence: str, num_samples: int):
|
| 22 |
"""
|
| 23 |
The main prediction function that runs the full pipeline.
|
|
|
|
| 28 |
if len(protein_sequence) < 10:
|
| 29 |
raise gr.Error("Protein sequence is too short.")
|
| 30 |
|
| 31 |
+
device = "cuda"
|
| 32 |
+
esm_model.to(device)
|
| 33 |
+
protobind_model.to(device)
|
| 34 |
|
| 35 |
with torch.no_grad():
|
| 36 |
batch_converter = alphabet.get_batch_converter()
|
|
|
|
| 61 |
# --- GRADIO APP DEFINITION ---
|
| 62 |
|
| 63 |
# Load models on app startup
|
| 64 |
+
device = "cpu"
|
| 65 |
esm_model, alphabet = esm.pretrained.load_model_and_alphabet('esm2_t33_650M_UR50D')
|
| 66 |
esm_model.eval()
|
| 67 |
esm_model = esm_model.to(device)
|