Ninjani commited on
Commit
1ca89d7
·
1 Parent(s): ec7ba9f
Files changed (2) hide show
  1. Dockerfile +1 -0
  2. app.py +11 -5
Dockerfile CHANGED
@@ -17,6 +17,7 @@ ARG MAMBA_DOCKERFILE_ACTIVATE=1
17
  ENV BASH_ENV=/usr/local/bin/_activate_current_env.sh
18
  ENV LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/opt/conda/lib
19
 
 
20
  RUN pip install --no-cache-dir \
21
  "stoic @ git+https://github.com/PickyBinders/stoic.git" \
22
  gradio==6.9.0
 
17
  ENV BASH_ENV=/usr/local/bin/_activate_current_env.sh
18
  ENV LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/opt/conda/lib
19
 
20
+ RUN pip install --no-cache-dir torch --index-url https://download.pytorch.org/whl/cpu
21
  RUN pip install --no-cache-dir \
22
  "stoic @ git+https://github.com/PickyBinders/stoic.git" \
23
  gradio==6.9.0
app.py CHANGED
@@ -1,3 +1,4 @@
 
1
  import time
2
 
3
  import torch
@@ -6,11 +7,15 @@ from loguru import logger
6
 
7
  from stoic.model import Stoic
8
 
9
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
10
- logger.info(f"Loading model on {device}")
11
- model = Stoic.from_pretrained("PickyBinders/stoic")
12
- model = model.to(device).eval()
13
- logger.info("Model loaded")
 
 
 
 
14
 
15
 
16
  def predict(sequences_text: str, top_n: int) -> tuple[str, str]:
@@ -20,6 +25,7 @@ def predict(sequences_text: str, top_n: int) -> tuple[str, str]:
20
  if len(sequences) > 26:
21
  raise gr.Error("Maximum 26 unique chains supported.")
22
 
 
23
  start = time.time()
24
  with torch.no_grad():
25
  results = model.predict_stoichiometry(sequences, top_n=top_n)
 
1
+ import functools
2
  import time
3
 
4
  import torch
 
7
 
8
  from stoic.model import Stoic
9
 
10
+
11
+ @functools.lru_cache(maxsize=1)
12
+ def get_model():
13
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
14
+ logger.info(f"Loading model on {device}")
15
+ model = Stoic.from_pretrained("PickyBinders/stoic")
16
+ model = model.to(device).eval()
17
+ logger.info("Model loaded")
18
+ return model
19
 
20
 
21
  def predict(sequences_text: str, top_n: int) -> tuple[str, str]:
 
25
  if len(sequences) > 26:
26
  raise gr.Error("Maximum 26 unique chains supported.")
27
 
28
+ model = get_model()
29
  start = time.time()
30
  with torch.no_grad():
31
  results = model.predict_stoichiometry(sequences, top_n=top_n)