Ninjani committed on
Commit
ec7ba9f
·
1 Parent(s): 9a13a4a
Files changed (2) hide show
  1. Dockerfile +33 -0
  2. app.py +80 -0
Dockerfile ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Container image for the Stoic Gradio demo, built on a micromamba base image.
#
# ARGs declared before FROM are only in scope for the FROM line itself.
ARG BASE_IMAGE=mambaorg/micromamba
ARG BASE_TAG=1.5-jammy
ARG MAMBA_PYTHON_VERSION=3.10

FROM --platform=linux/amd64 ${BASE_IMAGE}:${BASE_TAG}

# Re-declare so the value chosen above is visible inside this build stage.
ARG MAMBA_PYTHON_VERSION
ENV DEBIAN_FRONTEND=noninteractive

WORKDIR /usr/src/app

# Install the requested Python into the base environment, then drop the
# package caches so they do not end up in the image layer.
RUN micromamba install -y -n base -c conda-forge \
    python=${MAMBA_PYTHON_VERSION} \
    && micromamba clean --all --yes

# Activate the base environment for the RUN instructions that follow.
ARG MAMBA_DOCKERFILE_ACTIVATE=1
ENV BASH_ENV=/usr/local/bin/_activate_current_env.sh
# Append /opt/conda/lib without leaving an empty leading entry when
# LD_LIBRARY_PATH is unset: an empty entry makes the dynamic loader search
# the current working directory.
ENV LD_LIBRARY_PATH=${LD_LIBRARY_PATH:+${LD_LIBRARY_PATH}:}/opt/conda/lib

RUN pip install --no-cache-dir \
    "stoic @ git+https://github.com/PickyBinders/stoic.git" \
    gradio==6.9.0

EXPOSE 7860
# Bind to all interfaces so the server is reachable from outside the container.
ENV GRADIO_SERVER_NAME="0.0.0.0"

# COPY (not ADD) for plain local files, and hand ownership to the runtime
# user so the application files are not root-owned once USER is switched.
COPY --chown=$MAMBA_USER:$MAMBA_USER . .

USER $MAMBA_USER

ENTRYPOINT ["/usr/local/bin/_entrypoint.sh"]

CMD ["python", "app.py"]
app.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import time
2
+
3
+ import torch
4
+ import gradio as gr
5
+ from loguru import logger
6
+
7
+ from stoic.model import Stoic
8
+
9
# Run on the GPU when torch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

logger.info(f"Loading model on {device}")
# Load the pretrained weights once at import time, move them to the chosen
# device and switch the model to evaluation mode.
model = Stoic.from_pretrained("PickyBinders/stoic").to(device).eval()
logger.info("Model loaded")
14
+
15
+
16
def _parse_sequences(sequences_text: str) -> list[str]:
    """Return the non-blank, stripped input lines, keeping first occurrences only."""
    stripped = (line.strip() for line in (sequences_text or "").splitlines())
    # dict.fromkeys drops repeated sequences while preserving input order.
    return list(dict.fromkeys(line for line in stripped if line))


def _render_results(sequences: list[str], results) -> str:
    """Build the Markdown table of candidates followed by the chain legend."""
    chain_labels = [chr(ord("A") + i) for i in range(len(sequences))]

    header = "| Rank | " + " | ".join(f"Chain {label}" for label in chain_labels) + " | Stoichiometry |"
    separator = "|------|" + "|".join("-----" for _ in chain_labels) + "|---------------|"

    rows = []
    for rank, candidate in enumerate(results, 1):
        # NOTE(review): each candidate is assumed to be a mapping from sequence
        # to copy number; sequences it does not mention are shown as 0.
        copies = [candidate.get(seq, 0) for seq in sequences]
        stoich = "".join(
            f"{label}<sub>{count}</sub>" for label, count in zip(chain_labels, copies)
        )
        rows.append(
            f"| {rank} | " + " | ".join(str(count) for count in copies) + f" | {stoich} |"
        )
    table = "\n".join([header, separator] + rows)

    legend_lines = ["\n\n**Sequences:**"]
    for label, seq in zip(chain_labels, sequences):
        # Long sequences are truncated to keep the legend readable.
        preview = seq[:50] + "..." if len(seq) > 50 else seq
        legend_lines.append(f"- **Chain {label}**: `{preview}`")

    return table + "\n".join(legend_lines)


def predict(sequences_text: str, top_n: int) -> tuple[str, str]:
    """Predict stoichiometry candidates for the sequences entered in the UI.

    Args:
        sequences_text: Raw textbox content, one protein sequence per line.
            Blank lines are ignored and repeated sequences are collapsed to
            a single chain, because results are looked up per sequence.
        top_n: Number of candidates to request from the model. Coerced to
            ``int`` because UI sliders may deliver a float.

    Returns:
        A ``(markdown, runtime)`` pair: the Markdown results table with a
        chain legend, and the inference time formatted as ``"<seconds>s"``.

    Raises:
        gr.Error: If no sequence was entered, or more than 26 unique
            sequences were entered (chains are labelled A-Z).
    """
    max_chains = 26  # one label per letter A-Z

    sequences = _parse_sequences(sequences_text)
    if not sequences:
        raise gr.Error("Please enter at least one protein sequence.")
    if len(sequences) > max_chains:
        raise gr.Error("Maximum 26 unique chains supported.")

    # perf_counter is monotonic, so the measurement cannot be skewed by
    # wall-clock adjustments.
    start = time.perf_counter()
    with torch.no_grad():
        results = model.predict_stoichiometry(sequences, top_n=int(top_n))
    elapsed = time.perf_counter() - start

    return _render_results(sequences, results), f"{elapsed:.2f}s"
47
+
48
+
49
# Introductory Markdown rendered at the top of the page.
_INTRO_MARKDOWN = (
    "# Stoic\n"
    "**Fast and accurate protein stoichiometry prediction**\n\n"
    "Enter one protein sequence per line (one per unique chain type). "
    "Stoic predicts how many copies of each chain are present in the assembled complex."
)

with gr.Blocks(title="Stoic - Protein Stoichiometry Prediction") as app:
    gr.Markdown(_INTRO_MARKDOWN)

    with gr.Row():
        # First column: user inputs and the trigger button.
        with gr.Column():
            sequences_input = gr.Textbox(
                lines=6,
                label="Protein Sequences (one per line)",
                placeholder="MKTLLILTLFLAIAASSASA...\nMGSSHHHHHHSSGLVPR...",
            )
            top_n = gr.Slider(
                label="Number of candidates to return",
                minimum=1,
                maximum=10,
                step=1,
                value=3,
            )
            btn = gr.Button("Predict Stoichiometry", variant="primary")

        # Second column: rendered results and the measured runtime.
        with gr.Column():
            results_output = gr.Markdown(value="Results will appear here.")
            run_time = gr.Textbox(label="Runtime")

    # Wire the button to the prediction function.
    btn.click(
        fn=predict,
        inputs=[sequences_input, top_n],
        outputs=[results_output, run_time],
    )

app.launch()