| |
| from pathlib import Path |
| from typing import Optional |
| from uuid import uuid4 |
| import hashlib |
| import json |
| import gradio as gr |
| import gemmi |
| from gradio_molecule3d import Molecule3D |
| from modal_app import app, chai1_inference, download_inference_dependencies, here |
| from numpy import load |
| from typing import List |
|
|
# Gradio theme shared by the whole interface: the stock Default theme with the
# "md" text-size preset and the "lg" corner-radius preset.
theme = gr.themes.Default(text_size="md", radius_size="lg")
|
|
| |
def select_best_model(
    run_id: str,
    number_of_scores: int = 5,
    scores_to_print: Optional[List[str]] = None,
    results_dir: str = "results/score",
    prefix: str = "-scores.model_idx_",
):
    """Select the best model of an inference run from its score files.

    One ``.npz`` score file is read per model index, from
    ``{results_dir}/{run_id}{prefix}{model_index}.npz``. Models flagged with
    inter-chain clashes are skipped; among the remaining ones the model with
    the highest aggregate score wins (the first one wins on ties).

    Args:
        run_id: Unique identifier of the inference run.
        number_of_scores: Number of score files (models) to read.
        scores_to_print: Names of the scores printed for each clash-free
            model. Defaults to ``["aggregate_score", "ptm", "iptm"]``.
        results_dir: Directory that contains the score files.
        prefix: Text placed between the run id and the model index in the
            score file names.

    Returns:
        A ``(best_model, max_aggregate_score)`` tuple. ``best_model`` is the
        index of the winning model, or ``None`` when every model has
        inter-chain clashes (or ``number_of_scores`` is 0), in which case the
        score is ``0.0``. The score is returned as a plain Python ``float``.

    Raises:
        FileNotFoundError: If one of the expected score files is missing.
    """
    print("🧬 Start reading scores for each inference...")
    if scores_to_print is None:
        scores_to_print = ["aggregate_score", "ptm", "iptm"]

    max_aggregate_score = 0.0
    best_model = None
    for model_index in range(number_of_scores):
        print(f" 🧬 Reading scores for model {model_index}...")
        score_path = f"{results_dir}/{run_id}{prefix}{model_index}.npz"
        # An .npz archive stays open until it is closed explicitly, so use it
        # as a context manager instead of leaking one handle per model.
        with load(score_path) as data:
            if data["has_inter_chain_clashes"][0]:
                print(f" 🧬 Model {model_index} has inter-chain clashes, skipping scores.")
                continue
            for item in scores_to_print:
                print(f"{item}: {data[item][0]}")
            aggregate_score = float(data["aggregate_score"][0])

        # `best_model is None` lets the first clash-free model win even when
        # its score is not strictly above the initial 0.0.
        if best_model is None or aggregate_score > max_aggregate_score:
            max_aggregate_score = aggregate_score
            best_model = model_index

    print(
        f"🧬 Best model is {best_model} with an aggregate score of {max_aggregate_score}."
    )
    return best_model, max_aggregate_score
|
|
| |
| |
def create_fasta_file(file_content: str, name: Optional[str] = None, seq_name: Optional[str] = None) -> str:
    """Create a FASTA file from a biomolecule sequence string.

    The file is written to ``here / "inputs/fasta"``; the directory is created
    when it does not exist yet.

    Args:
        file_content (str): Content of the FASTA file, with optional line
            breaks. When it is empty or blank, a built-in example protein is
            written instead. When it does not start with a ``>`` header line,
            one is prepended using ``seq_name``.
        name (str, optional): FASTA file name, ideally ending with ``.fasta``.
            Only the final path component is used, so the file always lands
            in the FASTA input directory. If not provided (or blank), a name
            with a random 8-character id is generated.
        seq_name (str, optional): Identifier used for the generated header
            line. Defaults to "protein".

    Returns:
        str: Name of the created FASTA file.
    """
    # Strip once so the header check and the written content agree.
    fasta_content = (file_content or "").strip()
    if not fasta_content:
        print("Fasta file content cannot be empty so the example fasta file will be used")
        fasta_content = ">protein|name=example-protein\nAGSHSMRYFSTSVSRPGRGEPRFIAVGYVDDTQFVRFD"

    if not fasta_content.startswith(">"):
        fasta_content = f">{seq_name or 'protein'}\n{fasta_content}"

    # The name comes from a UI textbox / tool argument: keep the base name
    # only so that values such as "../x" or "/tmp/x" cannot escape the folder.
    file_name = Path(name.strip()).name if name else ""
    if file_name in ("", ".", ".."):
        unique_id = hashlib.sha256(uuid4().bytes).hexdigest()[:8]
        file_name = f"chai1_{unique_id}.fasta"

    output_dir = here / "inputs/fasta"
    output_dir.mkdir(parents=True, exist_ok=True)
    with open(output_dir / file_name, "w") as f:
        f.write(fasta_content)

    return file_name
|
|
|
|
| |
def create_json_config(
    num_diffn_timesteps: int,
    num_trunk_recycles: int,
    seed: int,
    options: list,
    name: Optional[str] = None
) -> str:
    """Create a JSON configuration file from the Gradio interface inputs.

    The file is written to ``here / "inputs/config"``; the directory is
    created when it does not exist yet.

    Args:
        num_diffn_timesteps (int): Number of diffusion timesteps from slider
        num_trunk_recycles (int): Number of trunk recycles from slider
        seed (int): Random seed from slider
        options (list): List of selected options from checkbox group
            ("ESM_embeddings" and/or "MSA_server"). ``None`` is treated as an
            empty selection.
        name (str, optional): JSON config file name, ideally ending with
            ``.json``. Only the final path component is used. If not provided
            (or blank), a name with a random 8-character id is generated.

    Returns:
        str: Name of the created JSON file
    """
    selected = set(options or [])

    # Values are coerced to int so that a caller sending 300.0 (or "300")
    # still produces integer settings in the written file.
    config = {
        "num_trunk_recycles": int(num_trunk_recycles),
        "num_diffn_timesteps": int(num_diffn_timesteps),
        "seed": int(seed),
        "use_esm_embeddings": "ESM_embeddings" in selected,
        "use_msa_server": "MSA_server" in selected,
    }

    # The name comes from a UI textbox / tool argument: keep the base name
    # only so that values such as "../x" or "/tmp/x" cannot escape the folder.
    file_name = Path(name.strip()).name if name else ""
    if file_name in ("", ".", ".."):
        unique_id = hashlib.sha256(uuid4().bytes).hexdigest()[:8]
        file_name = f"chai1_{unique_id}.json"

    output_dir = here / "inputs/config"
    output_dir.mkdir(parents=True, exist_ok=True)
    with open(output_dir / file_name, "w") as f:
        json.dump(config, f, indent=4)

    return file_name
|
|
|
|
| |
def compute_Chai1(
    fasta_file_name: Optional[str] = "",
    inference_config_file_name: Optional[str] = "",
):
    """Compute a Chai1 simulation.

    The FASTA content and the inference configuration are read locally, the
    inference runs remotely inside ``app.run()``, and every returned model is
    saved under ``./results/score`` (``.npz`` scores) and
    ``./results/molecules`` (``.cif`` structures).

    Args:
        fasta_file_name (str, optional): FASTA file name to use for the Chai1 simulation.
            Resolved against ``here / "inputs/fasta"``.
            If not provided, uses the default input file.
        inference_config_file_name (str, optional): JSON configuration file name for inference.
            Resolved against ``here / "inputs/config"``.
            If not provided, uses the default quick inference configuration.

    Returns:
        pd.DataFrame: DataFrame containing model scores and CIF file names,
        sorted by decreasing aggregate score. Models with inter-chain clashes
        are left out; when none remains the DataFrame is empty but still has
        its five columns.

    Raises:
        FileNotFoundError: If the FASTA or configuration file does not exist.
    """
    import pandas as pd

    # Resolve and read the inputs first so a wrong file name fails before any
    # remote work is started.
    fasta_dir = here / "inputs/fasta"
    if fasta_file_name:
        fasta_path = fasta_dir / fasta_file_name
    else:
        fasta_path = fasta_dir / "chai1_default_input.fasta"
    print(f"🧬 running Chai inference on {fasta_path}")
    fasta_content = Path(fasta_path).read_text()

    config_dir = here / "inputs/config"
    if inference_config_file_name:
        config_path = config_dir / inference_config_file_name
    else:
        config_path = config_dir / "chai1_quick_inference.json"
    print(f"🧬 loading Chai inference config from {config_path}")
    inference_config = json.loads(Path(config_path).read_text())

    # Short random identifier used in every output file name of this run.
    run_id = hashlib.sha256(uuid4().bytes).hexdigest()[:8]

    with app.run():
        print("🧬 checking inference dependencies")
        download_inference_dependencies.remote(force=False)

        print(f"🧬 running inference with {run_id=}")
        results = chai1_inference.remote(fasta_content, inference_config, run_id)

    output_dir = Path("./results")
    score_dir = output_dir / "score"
    molecule_dir = output_dir / "molecules"
    # Both sub-directories are written to below, so both must exist.
    score_dir.mkdir(parents=True, exist_ok=True)
    molecule_dir.mkdir(parents=True, exist_ok=True)

    print(f"🧬 saving results to disk locally in {output_dir}")

    columns = ["Model Index", "Aggregate Score", "PTM", "IPTM", "CIF File"]
    model_data = []
    for ii, (scores, cif) in enumerate(results):
        score_file = score_dir / f"{run_id}-scores.model_idx_{ii}.npz"
        cif_file = molecule_dir / f"{run_id}-preds.model_idx_{ii}.cif"

        score_file.write_bytes(scores)
        cif_file.write_text(cif)

        # Close the .npz archive as soon as the scalar scores are extracted.
        with load(str(score_file)) as data:
            if data["has_inter_chain_clashes"][0]:
                continue
            model_data.append({
                "Model Index": ii,
                "Aggregate Score": float(data["aggregate_score"][0]),
                "PTM": float(data["ptm"][0]),
                "IPTM": float(data["iptm"][0]),
                "CIF File": cif_file.name,
            })

    # Explicit columns keep sort_values() valid when every model was skipped.
    results_df = pd.DataFrame(model_data, columns=columns).sort_values(
        "Aggregate Score", ascending=False
    )

    return results_df
|
|
|
|
| |
def plot_protein(result_df) -> str:
    """Plot the 3D structure of a biomolecule using the DataFrame from compute_Chai1.

    The first row of the DataFrame is taken as the best model. Its CIF file is
    looked up in ``results/molecules`` and converted to a PDB file next to it,
    unless that PDB file already exists.

    Args:
        result_df (pd.DataFrame): DataFrame containing model information and
            scores, with a "CIF File" column holding CIF file names.

    Returns:
        str: Path to the PDB file of the best model, or an empty string when
        there is nothing to plot (``None``, an empty DataFrame, or a first row
        without a CIF file name).
    """
    if result_df is None or result_df.empty:
        return ""

    # This function is chained after the inference step and may receive a
    # placeholder table (e.g. a blank row); treat that as "nothing to plot"
    # instead of handing a directory path to the structure reader.
    cif_name = result_df.iloc[0]["CIF File"]
    if not isinstance(cif_name, str) or not cif_name.strip():
        return ""

    best_cif = Path("results/molecules") / cif_name
    # with_suffix() only touches the extension, unlike str.replace().
    pdb_file = best_cif.with_suffix(".pdb")

    if not pdb_file.exists():
        st = gemmi.read_structure(str(best_cif))
        st.write_minimal_pdb(str(pdb_file))

    return str(pdb_file)
|
|
|
|
| |
def show_cif_file(cif_file):
    """Plot a 3D structure from a CIF file with the Molecule3D library.

    The CIF file is read with gemmi and written back as a minimal PDB file
    with the same name and a ``.pdb`` suffix, in the same directory.

    Args:
        cif_file: A biomolecule structure file in CIF format. This can be a
            file uploaded by the user (an object whose ``name`` attribute is
            the file path) or a plain path given as ``str`` or ``Path``.
            If None or empty, the function will return None.

    Returns:
        str or None: Path of the generated PDB file, or None if no file was
        provided. Errors raised while reading or converting the structure are
        not caught and propagate to the caller.
    """
    if not cif_file:
        return None

    # Plain paths are used directly: Path.name would drop the directory.
    if isinstance(cif_file, (str, Path)):
        cif_path = Path(cif_file)
    else:
        cif_path = Path(cif_file.name)

    st = gemmi.read_structure(str(cif_path))
    pdb_file = cif_path.with_suffix(".pdb")
    st.write_minimal_pdb(str(pdb_file))

    return str(pdb_file)
|
|
|
|
| |
# Representation list handed to both Molecule3D viewers below: a single entry
# drawing model 0 as a cartoon coloured by hydrophobicity.
reps = [
    {
        "model": 0,
        "style": "cartoon",
        "color": "hydrophobicity",
    }
]
|
|
# Gradio interface: four tabs (introduction, input generation, simulation run,
# CIF viewer) wired to the functions defined above.
# NOTE(review): the nesting of this layout was reconstructed from the call
# sequence; confirm that "Example Fasta File" and the refresh button sit at
# tab level rather than inside the preceding column.
with gr.Blocks(theme=theme) as demo:

    gr.Markdown(
        """
        # Protein Folding Simulation Interface
        This interface provides the tools to fold FASTA chains based on Chai-1 model. Also, this is a MCP server to provide all the tools to automate the process of folding biomolecules with LLMs.
        """)

    # --- Tab 1: static project presentation ---
    with gr.Tab("Introduction 🔭"):

        gr.Image("images/logo1.png", show_label=False, width=600, show_download_button=False, show_fullscreen_button=False, show_share_button=False)

        gr.Markdown(
            """
            # Stakes

            The industry is undergoing a profound transformation due to the development of Large Language Models (LLMs) and the recent advancements that enable them to access external tools.
            For years, companies have leveraged simulation tools to accelerate and reduce the costs of product development.
            One of the primary challenges in the coming years will be to create agents capable of setting up, running, and processing simulations to further expedite innovation.
            Engineers will focus on analysis rather than simulation setup, allowing them to concentrate on the most critical aspects of their work.

            # Objective

            This project represents a first step towards developing AI agents that can perform simulations using existing engineering softwares.
            Key domains of application include:
            - **CFD** (Computational Fluid Dynamics) simulations
            - **Biology** (Protein Folding, Molecular Dynamics, etc.)
            - **Neural network applications**

            While this project focuses on biomolecules folding, the principles employed can be extended to other domains.
            Specifically, it uses [Chai-1](https://www.chaidiscovery.com/blog/introducing-chai-1), a multi-modal foundation model for molecular structure prediction that achieves state-of-the-art performance across various benchmarks.
            Chai-1 enables unified prediction of proteins, small molecules, DNA, RNA, glycosylations, and more.

            Industrial computations frequently require substantial resources (large number of CPUs and GPUs) that are performed on High-Performance Computing (HPC) clusters.
            To this end, [Modal Labs](https://modal.com/), a serverless platform that offers a straightforward method to run any application with the latest CPU and GPU hardware, will be used.

            MCP servers are an efficient solution to connect LLMs to real world engineering applications by providing access to a set of tools.
            The purpose of this project is to enable users to run biomolecule folding simulations using the Chai-1 model through any LLM chat or with a Gradio interface.

            # Benefits

            1. **Efficiency**: The MCP server's connected to high-performance computing capabilities ensure that simulations are run quickly and efficiently.

            2. **Ease of Use**: Only provide necessary parameters to the user to simplify the process of setting up and running complex simulations.

            3. **Integration**: The seamless integration between the LLM's chat interface and the MCP server allows for a streamlined workflow, from simulation setup to results analysis.

            The following video illustrates a practical use of the MCP server to run a biomolecule folding simulation using the Chai-1 model.
            In this scenario, Copilot is used in Agent mode with Claude 3.5 Sonnet to leverage the tools provided by the MCP server.

            """
        )

        # Centered embedded demonstration video.
        gr.HTML(
            """<style>
            iframe {
            display: block;
            margin: 0 auto;
            }
            </style>
            <iframe width="600" height="338"
            src="https://www.youtube.com/embed/P9cAKxJ9Zh8"
            frameborder="0" allowfullscreen></iframe>""",
            label="MCP demonstration video"
        )

        gr.Markdown(
            """
            # MCP tools
            1. `create_fasta_file`: Create a FASTA file from a biomolecule sequence string with a unique name.
            2. `create_json_config`: Create a JSON configuration file from the Gradio interface inputs.
            3. `compute_Chai1`: Compute a Chai-1 simulation on Modal labs server. Return a DataFrame with protein scores.
            4. `plot_protein`: Plot the 3D structure of a biomolecule using the DataFrame from `compute_Chai1` (Use for Gradio interface).
            5. `show_cif_file`: Plot a 3D structure from a CIF file with the Molecule3D library (Use for the Gradio interface).
            """)

        # Explicit UTF-8 so the page content does not depend on the platform's
        # default text encoding.
        with open("introduction_page.md", "r", encoding="utf-8") as f:
            intro_md = f.read()
        gr.Markdown(intro_md)

        gr.Markdown(
            """
            # Result example
            The following image shows an example of a protein folding simulation using the Chai-1 model.
            The simulation was run with the default configuration and the image is 3D view from the Gradio interface.
            """)

        gr.Image("images/protein.png", show_label=True, width=400, label="Protein Folding example", show_download_button=False, show_fullscreen_button=False, show_share_button=False)

        gr.Markdown(
            """
            # What's next?
            1. Expose additional tools to post-process the results of the simulations (ex: Plot images of the molecule structure from a file).
            The current post-processing tools are suited for the Gradio interface.
            2. Continue the pipeline by adding software like [OpenMM](https://openmm.org/) or [Gromacs](https://www.gromacs.org/) for molecular dynamics simulations.
            3. Perform full simulation plans including loops over parameters fully automated by the LLM.

            # Contact
            For any issues or questions, please contact the developer or refer to the documentation.
            """)

    # --- Tab 2: generate the JSON configuration and FASTA input files ---
    with gr.Tab("Configuration 📦"):

        gr.Markdown(
            """
            ## Fasta file and configuration generator (optional)
            """)

        with gr.Row():
            # Left column: inference settings written by create_json_config.
            with gr.Column(scale=1):
                slider_nb = gr.Slider(1, 500, value=300, label="Number of diffusion time steps", info="Choose the number of diffusion time steps for the simulation", step=1, interactive=True, elem_id="num_iterations")
                slider_trunk = gr.Slider(1, 5, value=3, label="Number of trunk recycles", info="Choose the number of iterations for the simulation", step=1, interactive=True, elem_id="trunk_number")
                slider_seed = gr.Slider(1, 100, value=42, label="Seed", info="Choose the seed", step=1, interactive=True, elem_id="seed")
                check_options = gr.CheckboxGroup(["ESM_embeddings", "MSA_server"], value=["ESM_embeddings",], label="Additional options", info="Options to use ESM embeddings and MSA server", elem_id="options")
                config_name = gr.Textbox(placeholder="Enter a name for the json file (optional)", label="JSON file name")
                button_json = gr.Button("Create Config file")
                button_json.click(fn=create_json_config, inputs=[slider_nb, slider_trunk, slider_seed, check_options, config_name], outputs=[])

            # Right column: sequence text written by create_fasta_file.
            with gr.Column(scale=1):
                fasta_input = gr.Textbox(placeholder="Fasta format sequences", label="Fasta content", lines=10)
                fasta_name = gr.Textbox(placeholder="Enter the name of the fasta file name (optional)", label="Fasta file name")
                fasta_button = gr.Button("Create Fasta file")
                fasta_button.click(fn=create_fasta_file, inputs=[fasta_input, fasta_name], outputs=[])

        gr.Markdown(
            """
            ## Example Fasta File
            ```
            >protein|name=example-protein
            AGSHSMRYFSTSVSRPGRGEPRFIAVGYVDDTQFVRFD
            ```
            """)

    # --- Tab 3: pick the input files, run the inference, show the results ---
    with gr.Tab("Run folding simulation 🚀"):
        with gr.Row():
            with gr.Column(scale=1):
                inp2 = gr.FileExplorer(root_dir=here / "inputs/config",
                                       value="chai1_default_inference.json",
                                       label="Configuration file",
                                       file_count='single')

            with gr.Column(scale=1):
                inp1 = gr.FileExplorer(root_dir=here / "inputs/fasta",
                                       value="chai1_default_input.fasta",
                                       label="Input Fasta file",
                                       file_count='single')
        btn_refresh = gr.Button("Refresh available files")

        # Refresh runs in two chained steps: both explorers are first pointed
        # at `here`, then back at their own input directories (presumably so
        # the components re-list newly created files - confirm).
        def update_file_explorer():
            """Don't need to be used by LLMs, but useful for the interface to update the file explorer"""
            return gr.FileExplorer(root_dir=here), gr.FileExplorer(root_dir=here)

        def update_file_explorer_2():
            """Don't need to be used by LLMs, but useful for the interface to update the file explorer"""
            return gr.FileExplorer(root_dir=here / "inputs/fasta"), gr.FileExplorer(root_dir=here / "inputs/config")

        btn_refresh.click(update_file_explorer, outputs=[inp1,inp2]).then(update_file_explorer_2, outputs=[inp1, inp2])

        out = gr.DataFrame(
            headers=["Model Index", "Aggregate Score", "PTM", "IPTM", "CIF File"],
            datatype=["number", "number", "number", "number", "str"],
            label="Inference Results sorted by Aggregate Score",
            visible=True,
        )
        out2 = Molecule3D(label="Plot the 3D Molecule", reps=reps)

        # inp1 (FASTA) and inp2 (config) follow compute_Chai1's parameter
        # order; the resulting table is then passed to plot_protein.
        btn = gr.Button("Run Simulation")
        btn.click(fn=compute_Chai1, inputs=[inp1 , inp2], outputs=[out]).then(
            fn=plot_protein,
            inputs=out,
            outputs=out2
        )

    # --- Tab 4: view any uploaded CIF file ---
    with gr.Tab("Plot CIF file 💻"):

        gr.Markdown(
            """
            ## Plot a 3D structure from a CIF file
            """)

        cif_input = gr.File(label="Input CIF file", file_count='single')
        cif_output = Molecule3D(label="Plot the 3D Molecule", reps=reps)
        cif_input.change(fn=show_cif_file, inputs=cif_input, outputs=cif_output)
|
|
| |
# Script entry point: start the Gradio app with its MCP server enabled.
if __name__ == "__main__":
    demo.launch(mcp_server=True)
|
|