File size: 10,781 Bytes
e96c79e
20c69f6
c456acd
 
 
 
20c69f6
 
e96c79e
0b631cd
e96c79e
 
 
 
 
 
c456acd
0b631cd
 
 
e96c79e
 
 
 
 
 
 
0b631cd
b32e543
0b631cd
c456acd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0b631cd
e96c79e
20c69f6
e96c79e
 
 
 
 
 
 
 
 
 
c456acd
e96c79e
 
 
 
c78551f
b32e543
e96c79e
c456acd
20c69f6
e96c79e
 
 
 
 
 
 
 
 
 
 
c456acd
20c69f6
e96c79e
 
 
 
c456acd
20c69f6
e96c79e
 
 
 
 
 
 
 
 
0b631cd
c456acd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20c69f6
c78551f
 
c456acd
e96c79e
c78551f
c456acd
 
 
 
 
c78551f
 
e96c79e
0b631cd
20c69f6
c456acd
0b631cd
e96c79e
 
 
20c69f6
 
 
0b631cd
 
20c69f6
0b631cd
 
e96c79e
20c69f6
0b631cd
20c69f6
0b631cd
 
 
 
 
 
 
 
 
 
 
 
20c69f6
0b631cd
20c69f6
0b631cd
 
 
 
20c69f6
0b631cd
a857335
0b631cd
20c69f6
c456acd
0b631cd
c78551f
20c69f6
 
0b631cd
20c69f6
0b631cd
 
 
 
c456acd
0b631cd
c456acd
 
 
 
20c69f6
0b631cd
c456acd
 
 
 
 
 
 
 
 
 
 
c34378f
 
c456acd
c34378f
0b631cd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c456acd
0b631cd
c456acd
 
c34378f
 
0b631cd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20c69f6
c456acd
c34378f
20c69f6
c456acd
20c69f6
0b631cd
 
c456acd
c34378f
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
#!/usr/bin/env python
"""
Gradio interface for PepTron protein ensemble generation (+ 3Dmol.js viewer).

This version treats multi-MODEL PDBs as trajectory frames via viewer.addModelsAsFrames(),
which is the recommended way to visualize an ensemble in 3Dmol.js.
"""

import glob
import hashlib
import html
import json
import logging
import os
import shutil
import subprocess
import tempfile

import gradio as gr
from gradio import InputHTMLAttributes

# Set OMP_NUM_THREADS for this process; the inference subprocess inherits it
# because it is launched with a copy of os.environ.
os.environ["OMP_NUM_THREADS"] = "8"

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# --- Configuration ---
# Forwarded to the inference script as --config.inference.steps.
INFERENCE_STEPS = 10
# Forwarded to the inference script as --config.inference.max_batch_size.
MAX_BATCH_SIZE = 1
# Forwarded to the inference script as --config.inference.num_workers.
NUM_WORKERS = 8
# Must exist and be non-empty, otherwise peptron_inference() reports an error.
CHECKPOINT_DIR = "/data/checkpoints/peptron-checkpoint"
# Working directory of the inference subprocess.
REPO_DIR = "/app/PepTron"
# A best-effort second copy of every generated PDB is written here.
RESULTS_DIR = "/data/results"
# Holds the PDB files handed back to the UI; emptied at process start.
OUTPUTS_DIR = "outputs"  # -> /app/outputs in HF Spaces container


def cleanup_outputs():
    """Ensure OUTPUTS_DIR exists and delete every entry matched by ``*`` in it.

    Deletion is best-effort: an entry that cannot be removed is logged as a
    warning and skipped, so a single failure never aborts the sweep.
    """
    os.makedirs(OUTPUTS_DIR, exist_ok=True)
    pattern = os.path.join(OUTPUTS_DIR, "*")
    for entry in glob.glob(pattern):
        try:
            # Directories need a recursive delete; everything else is unlinked.
            remover = shutil.rmtree if os.path.isdir(entry) else os.remove
            remover(entry)
        except Exception as exc:
            logger.warning(f"Could not delete {entry}: {exc}")


# Start every process with an empty outputs folder (Space restart / rebuild).
cleanup_outputs()


# Bash pipeline written to a temporary run.sh by peptron_inference().
# Positional arguments: checkpoint path, results path, input CSV, number of
# samples, inference steps, max batch size, number of workers.
# Stages: (1) inference, (2) conversion of the results into PDB ensembles,
# (3) filtering of each ensemble into "physical_ensembles".
# `set -e` makes any failing stage abort the script with a non-zero status.
BASH_SCRIPT_CONTENT = """#!/bin/bash
set -e
CKPT_PATH="$1"
RESULTS_PATH="$2"
CSV_FILE="$3"
NUM_SAMPLES="$4"
STEPS="$5"
BATCH_SIZE="$6"
WORKERS="$7"


export NCCL_TIMEOUT=3600
export TORCH_NCCL_ENABLE_MONITORING=0
export TORCHDYNAMO_SUPPRESS_ERRORS=1
export CUDA_LAUNCH_BLOCKING=1
# PYTHONPATH=. tells Python to look for modules in the current directory (PepTron)
export PYTHONPATH=/app/PepTron


echo ">>> Inference..."
python -m peptron.infer \\
    --config.inference.num_nodes 1 \\
    --config.inference.checkpoint_path "$CKPT_PATH" \\
    --config.inference.chains_path "$CSV_FILE" \\
    --config.inference.results_path "$RESULTS_PATH" \\
    --config.inference.num_gpus 1 \\
    --config.inference.max_batch_size "$BATCH_SIZE" \\
    --config.inference.num_workers "$WORKERS" \\
    --config.inference.samples "$NUM_SAMPLES" \\
    --config.inference.steps "$STEPS"


echo ">>> Convert..."
# Use half of the available cores, but never fewer than one worker
# (integer division yields 0 on a single-core host).
PROCS=$(( $(nproc) / 2 ))
if [ "$PROCS" -lt 1 ]; then PROCS=1; fi
python -m peptron.pt_to_structure -i "$RESULTS_PATH" \\
    -o "$RESULTS_PATH/ensembles" \\
    -p "$PROCS"


echo ">>> Filter..."
mkdir -p "$RESULTS_PATH/physical_ensembles"
for trajectory_file in "$RESULTS_PATH/ensembles/"*.pdb; do
    [ -e "$trajectory_file" ] || continue
    base_name=$(basename "$trajectory_file" .pdb)
    output_file="$RESULTS_PATH/physical_ensembles/${base_name}_filtered.pdb"
    python -m peptron.utils.filter_unphysical_traj --trajectory "$trajectory_file" --outfile "$output_file"
done
"""


def _read_text(path: str) -> str:
    """Return the contents of *path* decoded as UTF-8, dropping undecodable bytes."""
    with open(path, "r", encoding="utf-8", errors="ignore") as f:
        return f.read()


def pdb_to_3dmol_iframe(pdb_path: str, height: int = 520, autoplay: bool = False) -> str:
    """Build an ``<iframe srcdoc="...">`` that shows a PDB file with 3Dmol.js.

    The PDB text is loaded through ``viewer.addModelsAsFrames()`` so that a
    multi-MODEL file becomes a sequence of frames, driven by Play / Pause
    buttons and a frame slider in a small toolbar above the viewer.

    Args:
        pdb_path: Path of the PDB file to embed.
        height: Height of the viewer area in pixels; the iframe itself is
            46 px taller to leave room for the toolbar.
        autoplay: If truthy, start the frame animation as soon as it loads.

    Returns:
        An HTML string containing a single ``<iframe>`` element.
    """
    pdb = _read_text(pdb_path)

    # Emit the PDB as a JSON string literal (valid JavaScript) instead of a
    # raw template literal, so backticks, "${", backslashes and quotes in the
    # file cannot break or inject into the script. "<" is additionally
    # escaped so the text can never close the surrounding <script> element.
    pdb_js = json.dumps(pdb).replace("<", "\\u003c")
    autoplay_js = "true" if autoplay else "false"

    html_doc = f"""
<!DOCTYPE html>
<html>
<head>
  <meta http-equiv="content-type" content="text/html; charset=UTF-8" />
  <script src="https://3Dmol.csb.pitt.edu/build/3Dmol-min.js"></script>
  <style>
    html, body {{ margin:0; padding:0; }}
    .toolbar {{
      font-family: ui-sans-serif, system-ui, -apple-system, Segoe UI, Roboto, Helvetica, Arial;
      font-size: 12px;
      padding: 8px 10px;
      border-bottom: 1px solid #eee;
      display: flex;
      gap: 10px;
      align-items: center;
    }}
    .mol-container {{ width: 100%; height: {height}px; position: relative; }}
    input[type="range"] {{ width: 240px; }}
  </style>
</head>
<body>
  <div class="toolbar">
    <button id="btn_play">Play</button>
    <button id="btn_pause">Pause</button>
    <span>Frame:</span>
    <input id="frame" type="range" min="0" max="0" value="0" step="1" />
    <span id="frame_label">0</span>
  </div>
  <div id="container" class="mol-container"></div>

  <script>
    const pdb = {pdb_js};
    const element = document.getElementById("container");
    const viewer = $3Dmol.createViewer(element, {{ backgroundColor: "white" }});

    // Multi-model PDB -> frames in one model (ensemble/trajectory).
    const model = viewer.addModelsAsFrames(pdb, "pdb");

    viewer.setStyle({{}}, {{ cartoon: {{ color: "steelblue" }} }});
    viewer.zoomTo();
    viewer.render();

    // Frame controls
    const slider = document.getElementById("frame");
    const label = document.getElementById("frame_label");

    const nframes = (model && model.getNumFrames) ? model.getNumFrames() : 1;
    slider.max = Math.max(0, nframes - 1);
    slider.value = 0;
    label.textContent = "0";

    slider.addEventListener("input", (e) => {{
      const f = parseInt(e.target.value);
      label.textContent = String(f);
      if (model && model.setFrame) model.setFrame(f);
      viewer.render();
    }});

    document.getElementById("btn_play").addEventListener("click", () => {{
      viewer.animate({{loop: "forward"}});  // animate frames forward
    }});

    document.getElementById("btn_pause").addEventListener("click", () => {{
      viewer.stopAnimate();
    }});

    if ({autoplay_js}) {{
      viewer.animate({{loop: "forward"}});
    }}
  </script>
</body>
</html>
"""

    # srcdoc is an HTML attribute value: "&" must be escaped as well as the
    # quotes, otherwise entities inside the document are decoded too early.
    srcdoc = html.escape(html_doc, quote=True)
    return f'<iframe style="width:100%; height:{height+46}px; border:0;" srcdoc="{srcdoc}"></iframe>'


def peptron_inference(sequence, num_samples):
    """Run the PepTron pipeline for one sequence and prepare the UI outputs.

    The sequence is written to a one-row CSV in a temporary directory, the
    bash pipeline in BASH_SCRIPT_CONTENT is executed from REPO_DIR, and the
    first resulting PDB (filtered ensemble if present, raw ensemble
    otherwise) is copied to OUTPUTS_DIR and, best-effort, to RESULTS_DIR.

    Args:
        sequence: Amino-acid sequence in one-letter code. Case and any
            whitespace (spaces, tabs, line breaks) are ignored.
        num_samples: Number of samples to generate; must convert to an
            integer >= 1.

    Returns:
        A ``(status, viewer_html, pdb_path)`` tuple. On any failure the
        status carries the error message and the other two items are None.
    """
    # 1. Validate Environment
    # isdir (not exists) so a stray file at that path cannot make listdir raise.
    if not os.path.isdir(CHECKPOINT_DIR) or not os.listdir(CHECKPOINT_DIR):
        return f"❌ Error: Checkpoint not found in {CHECKPOINT_DIR}.", None, None

    if not os.path.isdir(REPO_DIR):
        return (
            f"❌ Error: '{REPO_DIR}' directory not found. Please upload the PepTron folder to the Space.",
            None,
            None,
        )

    # 2. Validate Input
    # split()/join removes every kind of whitespace, including tabs and the
    # "\r" of pasted Windows line endings; None is treated as empty input.
    clean_seq = "".join((sequence or "").split()).upper()
    valid_chars = set("ACDEFGHIKLMNPQRSTVWY")
    if not clean_seq or set(clean_seq) - valid_chars:
        return "❌ Error: Invalid amino acid sequence.", None, None

    # The value is passed on the command line, so it has to be a plain
    # integer ("5", never "5.0").
    try:
        n_samples = int(num_samples)
    except (TypeError, ValueError):
        n_samples = 0
    if n_samples < 1:
        return "❌ Error: Invalid number of samples.", None, None

    try:
        with tempfile.TemporaryDirectory() as tmpdir:
            input_csv = os.path.join(tmpdir, "input.csv")
            out_dir = os.path.join(tmpdir, "out")
            script_path = os.path.join(tmpdir, "run.sh")
            os.makedirs(out_dir, exist_ok=True)

            with open(input_csv, "w") as f:
                f.write("name,seqres\ntarget," + clean_seq + "\n")

            with open(script_path, "w") as f:
                f.write(BASH_SCRIPT_CONTENT)
            os.chmod(script_path, 0o755)

            # 3. Run Inference
            cmd = [
                "bash",
                script_path,
                CHECKPOINT_DIR,
                out_dir,
                input_csv,
                str(n_samples),
                str(INFERENCE_STEPS),
                str(MAX_BATCH_SIZE),
                str(NUM_WORKERS),
            ]

            logger.info("Running inference...")

            res = subprocess.run(
                cmd,
                capture_output=True,
                text=True,
                timeout=3600,
                cwd=os.path.abspath(REPO_DIR),
                env=os.environ.copy(),
            )

            if res.returncode != 0:
                logger.error("Inference script exited with status %s", res.returncode)
                return f"❌ Failed:\n{res.stderr}", None, None

            # 4. Process Outputs
            phys_dir = os.path.join(out_dir, "physical_ensembles")
            raw_dir = os.path.join(out_dir, "ensembles")

            # glob order is unspecified; sort so the chosen file is deterministic.
            found = sorted(glob.glob(os.path.join(phys_dir, "*.pdb")))
            if not found:
                found = sorted(glob.glob(os.path.join(raw_dir, "*.pdb")))

            if not found:
                return "❌ No PDB generated", None, None

            # Copy out of the temporary directory before it is deleted.
            seq_hash = hashlib.sha256(clean_seq.encode("utf-8")).hexdigest()
            output_filename = f"ensemble_{seq_hash[:16]}_L{len(clean_seq)}.pdb"
            final_out = os.path.join(OUTPUTS_DIR, output_filename)
            os.makedirs(OUTPUTS_DIR, exist_ok=True)
            shutil.copy(found[0], final_out)

            # Optional: keep a copy in persistent results dir (if mounted/writable)
            try:
                os.makedirs(RESULTS_DIR, exist_ok=True)
                shutil.copy(found[0], os.path.join(RESULTS_DIR, output_filename))
            except Exception as e:
                logger.warning(f"Could not copy to RESULTS_DIR: {e}")

            # 5. Viewer HTML (multi-model PDB handled as frames).
            viewer_html = pdb_to_3dmol_iframe(final_out, height=520, autoplay=False)

            return f"✅ Done! ({n_samples} samples)", viewer_html, final_out

    except Exception as e:
        # Top-level boundary for the UI callback: keep the traceback in the
        # log and show a short message to the user instead of crashing.
        logger.exception("PepTron inference failed")
        return f"❌ System Error: {str(e)}", None, None


# Synthetic example sequence (85 residues, standard one-letter codes only).
EXAMPLE_SEQ_100ISH = "".join(
    [
        "MSTNPKPQRKTKRNTNRRPQDVKPGGKKQTKK",
        "GDSAENLQKLRDNLVQRLKNNGVSVEKVTKELG",
        "ADKVEEMLAKLGADVVVVES",
    ]
)


def load_example() -> str:
    """Return the built-in example sequence for the "Load example" button."""
    example = EXAMPLE_SEQ_100ISH
    return example


# Extra stylesheet passed to gr.Blocks: renders the textarea of the element
# with id "sequence_box" in a monospace font with slight letter spacing and
# ligatures disabled.
CSS = """
#sequence_box textarea {
  font-family: ui-monospace, SFMono-Regular, Menlo, Monaco, Consolas,
               "Liberation Mono", "DejaVu Sans Mono", "Courier New", monospace !important;
  letter-spacing: 0.04em;
  font-variant-ligatures: none;
}
"""


# UI definition: inputs in the left column, status / 3D viewer / download in
# the right column.
with gr.Blocks(theme=gr.themes.Soft(), css=CSS) as demo:
    gr.Markdown("# 🧬 PepTron - Multidomain protein ensemble generation")

    with gr.Row():
        with gr.Column():
            # elem_id ties this textbox to the "#sequence_box" rule in CSS;
            # the HTML attributes switch off browser text assistance, which
            # is unwanted for raw sequence input.
            seq = gr.Textbox(
                label="Sequence",
                lines=4,
                placeholder="Sequence...",
                elem_id="sequence_box",
                html_attributes=InputHTMLAttributes(
                    spellcheck=False,
                    autocorrect="off",
                    autocapitalize="off",
                    autocomplete="off",
                ),
            )
            with gr.Row():
                example_btn = gr.Button("Load example")
                btn = gr.Button("Generate", variant="primary")
            samps = gr.Slider(1, 50, value=5, label="Samples", step=1)

        with gr.Column():
            stat = gr.Textbox(label="Status")
            mol = gr.HTML(label="3D view")
            out = gr.File(label="PDB")

    # "Load example" fills the sequence box; "Generate" runs the pipeline and
    # fills status, viewer and download in that order.
    example_btn.click(load_example, inputs=None, outputs=seq)
    btn.click(peptron_inference, [seq, samps], [stat, mol, out])

if __name__ == "__main__":
    # Listen on all interfaces on port 7860.
    demo.launch(server_name="0.0.0.0", server_port=7860)