| |
| from pathlib import Path |
| from typing import Optional |
| from uuid import uuid4 |
| import hashlib |
| import json |
| import gradio as gr |
| import gemmi |
| from gradio_molecule3d import Molecule3D |
| from modal_app import app, chai1_inference, download_inference_dependencies, here |
|
|
| |
|
|
| |
def _cif_to_pdb(cif_path: Path) -> Path:
    """Convert an mmCIF file to a minimal PDB file next to it; return the PDB path."""
    # with_suffix only replaces the final ".cif", unlike str.split(".cif").
    pdb_path = cif_path.with_suffix(".pdb")
    structure = gemmi.read_structure(str(cif_path))
    structure.write_minimal_pdb(str(pdb_path))
    return pdb_path


def compute_Chai1(
    force_redownload: bool = False,
    fasta_file: Optional[str] = None,
    inference_config_file: Optional[str] = None,
    output_dir: Optional[str] = None,
    run_id: Optional[str] = None,
) -> str:
    """Run a Chai-1 structure prediction on Modal and return a PDB file path.

    The FASTA input and JSON inference config are read locally, inference runs
    remotely via ``chai1_inference``, and every returned model is written to
    ``output_dir`` as an ``.npz`` scores file plus a ``.cif`` structure. Each
    ``.cif`` is also converted to a minimal ``.pdb`` with gemmi.

    Args:
        force_redownload (bool): Passed as ``force`` to
            ``download_inference_dependencies`` to force a re-download.
        fasta_file (str | None): Path to the input FASTA file. ``None`` or an
            empty string uses ``inputs/chai1_default_input.fasta``.
        inference_config_file (str | None): Path to a JSON inference config.
            ``None`` or an empty string uses ``inputs/chai1_quick_inference.json``.
        output_dir (str | None): Local directory for the results; created if
            missing. Defaults to ``./results``.
        run_id (str | None): Prefix for the output file names. Defaults to a
            random 8-character hex id.

    Returns:
        str: Path of the PDB file converted from the last model returned by
        the inference call.

    Raises:
        RuntimeError: If the inference call returned no models.
    """
    with app.run():
        print("🧬 checking inference dependencies")
        download_inference_dependencies.remote(force=force_redownload)

        # Gradio passes "" (not None) for an empty Textbox, so any falsy value
        # falls back to the bundled default input.
        if not fasta_file:
            fasta_file = here / "inputs" / "chai1_default_input.fasta"
        print(f"🧬 running Chai inference on {fasta_file}")
        fasta_content = Path(fasta_file).read_text()

        if not inference_config_file:
            inference_config_file = here / "inputs" / "chai1_quick_inference.json"
        print(f"🧬 loading Chai inference config from {inference_config_file}")
        inference_config = json.loads(Path(inference_config_file).read_text())

        if not run_id:
            run_id = hashlib.sha256(uuid4().bytes).hexdigest()[:8]
        print(f"🧬 running inference with {run_id=}")

        results = chai1_inference.remote(fasta_content, inference_config, run_id)

        # Always work with a Path (a caller may pass a str) and make sure the
        # directory exists whether or not it was supplied.
        out_dir = Path(output_dir) if output_dir else Path("./results")
        out_dir.mkdir(parents=True, exist_ok=True)

        print(f"🧬 saving results to disk locally in {out_dir}")
        pdb_path: Optional[Path] = None
        for ii, (scores, cif) in enumerate(results):
            (out_dir / f"{run_id}-scores.model_idx_{ii}.npz").write_bytes(scores)
            cif_path = out_dir / f"{run_id}-preds.model_idx_{ii}.cif"
            cif_path.write_text(cif)
            pdb_path = _cif_to_pdb(cif_path)
            print(f"🧬 wrote {cif_path} and {pdb_path}")

        # Without this guard an empty result list would surface as an
        # UnboundLocalError instead of a meaningful message.
        if pdb_path is None:
            raise RuntimeError(f"Chai-1 inference returned no models for {run_id=}")
        return str(pdb_path)
|
|
|
|
| |
|
|
# Viewer representation: model 0 drawn as a cartoon with the "whiteCarbon"
# colour scheme.
reps = [
    {
        "model": 0,
        "style": "cartoon",
        "color": "whiteCarbon",
    }
]


with gr.Blocks() as demo:
    # Free-text path to a FASTA file; forwarded as compute_Chai1's first
    # positional argument (fasta_file).
    inp = gr.Textbox(label="Input Fasta file", placeholder="Sequence file")
    btn = gr.Button(value="Run")
    out = Molecule3D(reps=reps, label="Molecule3D")
    # compute_Chai1's return value (a structure file path) feeds the viewer.
    btn.click(compute_Chai1, inputs=[inp], outputs=[out])


if __name__ == "__main__":
    # mcp_server=True additionally starts Gradio's MCP server for this app.
    demo.launch(mcp_server=True)
|
|