Spaces:
Running
Running
update
Browse files- Dockerfile +1 -0
- app.py +11 -5
Dockerfile
CHANGED
|
@@ -17,6 +17,7 @@ ARG MAMBA_DOCKERFILE_ACTIVATE=1
|
|
| 17 |
ENV BASH_ENV=/usr/local/bin/_activate_current_env.sh
|
| 18 |
ENV LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/opt/conda/lib
|
| 19 |
|
|
|
|
| 20 |
RUN pip install --no-cache-dir \
|
| 21 |
"stoic @ git+https://github.com/PickyBinders/stoic.git" \
|
| 22 |
gradio==6.9.0
|
|
|
|
| 17 |
ENV BASH_ENV=/usr/local/bin/_activate_current_env.sh
|
| 18 |
ENV LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/opt/conda/lib
|
| 19 |
|
| 20 |
+
RUN pip install --no-cache-dir torch --index-url https://download.pytorch.org/whl/cpu
|
| 21 |
RUN pip install --no-cache-dir \
|
| 22 |
"stoic @ git+https://github.com/PickyBinders/stoic.git" \
|
| 23 |
gradio==6.9.0
|
app.py
CHANGED
|
@@ -1,3 +1,4 @@
|
|
|
|
|
| 1 |
import time
|
| 2 |
|
| 3 |
import torch
|
|
@@ -6,11 +7,15 @@ from loguru import logger
|
|
| 6 |
|
| 7 |
from stoic.model import Stoic
|
| 8 |
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
logger.info("
|
|
|
|
|
|
|
|
|
|
|
|
|
| 14 |
|
| 15 |
|
| 16 |
def predict(sequences_text: str, top_n: int) -> tuple[str, str]:
|
|
@@ -20,6 +25,7 @@ def predict(sequences_text: str, top_n: int) -> tuple[str, str]:
|
|
| 20 |
if len(sequences) > 26:
|
| 21 |
raise gr.Error("Maximum 26 unique chains supported.")
|
| 22 |
|
|
|
|
| 23 |
start = time.time()
|
| 24 |
with torch.no_grad():
|
| 25 |
results = model.predict_stoichiometry(sequences, top_n=top_n)
|
|
|
|
| 1 |
+
import functools
|
| 2 |
import time
|
| 3 |
|
| 4 |
import torch
|
|
|
|
| 7 |
|
| 8 |
from stoic.model import Stoic
|
| 9 |
|
| 10 |
+
|
| 11 |
+
@functools.lru_cache(maxsize=1)
|
| 12 |
+
def get_model():
|
| 13 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 14 |
+
logger.info(f"Loading model on {device}")
|
| 15 |
+
model = Stoic.from_pretrained("PickyBinders/stoic")
|
| 16 |
+
model = model.to(device).eval()
|
| 17 |
+
logger.info("Model loaded")
|
| 18 |
+
return model
|
| 19 |
|
| 20 |
|
| 21 |
def predict(sequences_text: str, top_n: int) -> tuple[str, str]:
|
|
|
|
| 25 |
if len(sequences) > 26:
|
| 26 |
raise gr.Error("Maximum 26 unique chains supported.")
|
| 27 |
|
| 28 |
+
model = get_model()
|
| 29 |
start = time.time()
|
| 30 |
with torch.no_grad():
|
| 31 |
results = model.predict_stoichiometry(sequences, top_n=top_n)
|