MCP_Chai1_Modal / app.py
PhDFlo's picture
Modif func fasta and example
0a38dda
raw
history blame
7.51 kB
# Import libraries
from pathlib import Path
from typing import Optional
from uuid import uuid4
import hashlib
import json
import gradio as gr
import gemmi
from gradio_molecule3d import Molecule3D
from modal_app import app, chai1_inference, download_inference_dependencies, here
# Definition of the tools for the MCP server
# Function to return a fasta file
def create_fasta_file(sequence: str, name: Optional[str] = None) -> str:
    """Create a FASTA file from a protein sequence string with a unique name.

    The text is cleaned before it is written:
    - Line endings are normalized to ``\\n``.
    - Every line is stripped.
    - Blank lines are dropped.
    As a result, the header check and the written content always agree. For
    example, a header pasted after a leading blank line is still recognized
    and becomes the first line of the file. The file always ends with a
    newline.

    Args:
        sequence (str): The sequence text. It may already be in FASTA format
            (first non-blank line starting with ``>``) and may span several
            lines.
        name (str, optional): Identifier used to build a ``>name`` header
            when ``sequence`` has no header. Ignored when a header is
            present. Defaults to "PROTEIN".

    Returns:
        str: Name (not full path) of the created FASTA file, written under
        ``here / "inputs"``.

    Raises:
        ValueError: If ``sequence`` contains no non-whitespace characters.
    """
    # splitlines() handles \n, \r\n and \r, so pasted Windows text works too.
    lines = [line.strip() for line in sequence.splitlines()]
    lines = [line for line in lines if line]
    if not lines:
        raise ValueError("The sequence is empty.")

    # Add a header when a bare sequence was provided.
    # NOTE(review): the example in the UI uses ">protein|name=..." headers;
    # confirm that Chai-1 accepts the bare ">PROTEIN" default.
    if not lines[0].startswith(">"):
        header = name if name is not None else "PROTEIN"
        lines.insert(0, f">{header}")

    fasta_content = "\n".join(lines) + "\n"

    # Unique file name so concurrent users do not overwrite each other.
    unique_id = hashlib.sha256(uuid4().bytes).hexdigest()[:8]
    file_name = f"chai1_{unique_id}_input.fasta"

    inputs_dir = here / "inputs"
    inputs_dir.mkdir(parents=True, exist_ok=True)
    (inputs_dir / file_name).write_text(fasta_content, encoding="utf-8")
    return file_name
# Function to create a custom JSON config file
def create_custom_config(
    num_trunk_recycles: int = 3,
    num_diffn_timesteps: int = 200,
    seed: int = 42,
    use_esm_embeddings: bool = True,
    use_msa_server: bool = True,
    output_file: Optional[str] = None
) -> str:
    """Create a custom JSON configuration file for Chai1 inference.

    Args:
        num_trunk_recycles (int, optional): Number of trunk recycles. Defaults to 3.
        num_diffn_timesteps (int, optional): Number of diffusion timesteps. Defaults to 200.
        seed (int, optional): Random seed. Defaults to 42.
        use_esm_embeddings (bool, optional): Whether to use ESM embeddings. Defaults to True.
        use_msa_server (bool, optional): Whether to use MSA server. Defaults to True.
        output_file (str, optional): Path to save the config file. Missing parent
            directories are created. If None or empty, a uniquely named file
            ``chai1_<id>_inference.json`` is created under ``here / "inputs"``.

    Returns:
        str: Path to the created config file.
    """
    config = {
        "num_trunk_recycles": num_trunk_recycles,
        "num_diffn_timesteps": num_diffn_timesteps,
        "seed": seed,
        "use_esm_embeddings": use_esm_embeddings,
        "use_msa_server": use_msa_server
    }

    if not output_file:
        # A unique name (same scheme as the FASTA files) keeps concurrent
        # users from overwriting each other's configuration.
        unique_id = hashlib.sha256(uuid4().bytes).hexdigest()[:8]
        output_path = here / "inputs" / f"chai1_{unique_id}_inference.json"
    else:
        output_path = Path(output_file)

    output_path.parent.mkdir(parents=True, exist_ok=True)
    with open(output_path, "w", encoding="utf-8") as f:
        json.dump(config, f, indent=4)
    return str(output_path)
# Function to compute Chai1 inference
def compute_Chai1(
    fasta_file: Optional[str] = "",
    inference_config_file: Optional[str] = "",
) -> str:
    """Run a Chai1 structure prediction on Modal and return the predicted structure as PDB.

    Args:
        fasta_file (str, optional): Name of a FASTA file in the inputs directory,
            e.g. the name returned by create_fasta_file. An absolute path is used
            as-is. If empty, the default chai1_default_input.fasta is used.
        inference_config_file (str, optional): Path to a JSON inference
            configuration, e.g. the path returned by create_custom_config. If it
            does not exist as given, it is looked up in the inputs directory.
            If empty, the default chai1_quick_inference.json is used.

    Returns:
        str: Path of the PDB file converted from the last model returned by the
        inference.

    Raises:
        RuntimeError: If the inference returns no models.
    """
    inputs_dir = here / "inputs"
    # NOTE(review): both file arguments come from the UI / MCP client and are
    # not confined to inputs_dir; consider restricting them if the app is public.
    with app.run():
        force_redownload = False
        print("🧬 checking inference dependencies")
        download_inference_dependencies.remote(force=force_redownload)

        # Resolve the FASTA file exactly once. Joining an absolute path onto
        # inputs_dir yields that absolute path unchanged (pathlib semantics).
        fasta_path = inputs_dir / (fasta_file or "chai1_default_input.fasta")
        print(f"🧬 running Chai inference on {fasta_path}")
        fasta_content = fasta_path.read_text(encoding="utf-8")

        # Resolve the inference config: use the path as given when it exists,
        # otherwise look it up in the inputs directory.
        if inference_config_file:
            config_path = Path(inference_config_file)
            if not config_path.exists():
                config_path = inputs_dir / inference_config_file
        else:
            config_path = inputs_dir / "chai1_quick_inference.json"
        print(f"🧬 loading Chai inference config from {config_path}")
        inference_config = json.loads(config_path.read_text(encoding="utf-8"))

        # Generate a unique run ID
        run_id = hashlib.sha256(uuid4().bytes).hexdigest()[:8]  # short id
        print(f"🧬 running inference with {run_id=}")
        results = chai1_inference.remote(fasta_content, inference_config, run_id)

        # Save every returned (scores, cif) pair locally.
        output_dir = Path("./results")
        output_dir.mkdir(parents=True, exist_ok=True)
        print(f"🧬 saving results to disk locally in {output_dir}")
        cif_path = None
        for ii, (scores, cif) in enumerate(results):
            (output_dir / f"{run_id}-scores.model_idx_{ii}.npz").write_bytes(scores)
            cif_path = output_dir / f"{run_id}-preds.model_idx_{ii}.cif"
            cif_path.write_text(cif, encoding="utf-8")

        if cif_path is None:
            raise RuntimeError(f"Chai1 inference {run_id} returned no models")

        # Convert the last written CIF model to PDB for the Molecule3D viewer.
        pdb_path = cif_path.with_suffix(".pdb")
        st = gemmi.read_structure(str(cif_path))
        st.write_minimal_pdb(str(pdb_path))
    return str(pdb_path)
# Create the Gradio interface
# Display settings for the Molecule3D viewer: model 0 as a cartoon colored by hydrophobicity.
reps = [{"model": 0, "style": "cartoon", "color": "hydrophobicity"}]

with gr.Blocks() as demo:
    gr.Markdown(
        """
        # Chai1 Simulation Interface
        This interface allows you to run Chai1 simulations on a given Fasta sequence file.
        """)

    # Tab 1: turn pasted sequence text into a uniquely named FASTA file.
    with gr.Tab("Configuration 📦"):
        text_input = gr.Textbox(placeholder="Fasta format sequences", label="Fasta content", lines=10)
        text_output = gr.Textbox(placeholder="Fasta file name", label="Fasta file name")
        text_button = gr.Button("Create Fasta file")
        text_button.click(fn=create_fasta_file, inputs=[text_input], outputs=[text_output])

        gr.Markdown(
            """
            You can input a Fasta file containing the sequence of the molecule you want to simulate.
            The output will be a 3D representation of the molecule based on the Chai1 model.
            ## Instructions
            1. Paste the molecule sequence (Fasta format) above and click "Create Fasta file"; note the returned file name.
            2. In the "Run folding simulation 🚀" tab, enter that file name (and optionally a JSON config file), then click the "Run" button to start the simulation.
            3. The output will be a 3D visualization of the molecule.
            ## Example Input
            Leave the file fields empty to use the default Fasta and config files provided in the inputs directory, or create your own Fasta file as described above.
            ## Output
            The output will be a 3D representation of the molecule, which you can interact with.
            ## Note
            Make sure to have the necessary dependencies installed and the Chai1 model available in the specified directory.
            ## Disclaimer
            This interface is for educational and research purposes only. The results may vary based on the input sequence and the Chai1 model's capabilities.
            ## Contact
            For any issues or questions, please contact the developer or refer to the documentation.
            ## Example Fasta File
            ```
            >protein|name=example-protein
            AGSHSMRYFSTSVSRPGRGEPRFIAVGYVDDTQFVRFD
            ```
            """)

    # Tab 2: run the folding and display the resulting PDB.
    with gr.Tab("Run folding simulation 🚀"):
        inp1 = gr.Textbox(placeholder="Fasta Sequence file", label="Input Fasta file")
        inp2 = gr.Textbox(placeholder="Config file", label="JSON Config file")
        btn = gr.Button("Run")
        out = Molecule3D(label="Molecule3D", reps=reps)
        btn.click(fn=compute_Chai1, inputs=[inp1, inp2], outputs=[out])
# Launch both the Gradio web interface and the MCP server
def _main() -> None:
    """Start the Gradio app with its MCP server enabled."""
    demo.launch(mcp_server=True)


if __name__ == "__main__":
    _main()