BoltzGen_Demo / app.py
almo03's picture
correct log path
829bdf3 verified
import gradio as gr
import os
import subprocess
from pathlib import Path
import shutil
import spaces
# Function to check that the files are in the same directory
def check_files_in_same_directory(file1, file2):
    """Validate that both files were provided and live in the same directory.

    Each argument may be a filesystem path (``str`` or ``os.PathLike``) or an
    upload object that exposes its full path via a ``.name`` attribute (as
    Gradio upload values do).

    Raises:
        ValueError: if either file is missing, or if their parent directories
            differ.
    """
    if file1 is None or file2 is None:
        raise ValueError("Both the YAML config and target PDB/CIF files must be provided.")

    def _path_of(f):
        # ``pathlib.Path`` also has a ``.name`` attribute, but there it is only
        # the basename (whose parent is always "."), so path-like values must be
        # used directly; only non-path upload objects are read through ``.name``.
        if isinstance(f, (str, os.PathLike)):
            return Path(f)
        return Path(f.name)

    if _path_of(file1).parent != _path_of(file2).parent:
        raise ValueError("YAML config and target PDB/CIF must be in the same directory.")
def move_file_to_output(file, output_dir):
    """Copy ``file`` into ``output_dir`` (the source is left in place).

    Args:
        file: a filesystem path (``str``/``os.PathLike``), an upload object that
            exposes its full path via ``.name``, or None.
        output_dir: destination directory; created if it does not exist.

    Returns:
        The destination ``Path`` (``output_dir / <basename>``), or None when
        ``file`` is None.
    """
    if file is None:
        return None
    # ``Path`` also has ``.name``, but it is only the basename, so path-like
    # inputs are used as-is; upload objects expose the full temp path via ``.name``.
    src_path = Path(file) if isinstance(file, (str, os.PathLike)) else Path(file.name)
    dest_path = Path(output_dir) / src_path.name
    dest_path.parent.mkdir(parents=True, exist_ok=True)
    # Copying a file onto itself raises shutil.SameFileError; if it is already
    # in place there is nothing to do.
    if src_path.resolve() != dest_path.resolve():
        shutil.copy2(src_path, dest_path)  # preserves metadata
    return dest_path
def move_files_to_output(yaml_file, target_file, output_dir):
    """Copy the YAML config and the target structure side by side into ``output_dir``.

    Args:
        yaml_file: the design YAML config (upload object or path).
        target_file: the target PDB/CIF structure (upload object or path).
        output_dir: directory both files are copied into.

    Returns:
        ``(yaml_dest, target_dest)``: the copied files' paths.

    Raises:
        ValueError: if either file is missing.
    """
    # Validate before copying so a missing upload does not leave a lone,
    # half-staged file behind in ``output_dir``.
    if yaml_file is None or target_file is None:
        raise ValueError("Both the YAML config and target PDB/CIF files must be provided.")
    yaml_path_dest = move_file_to_output(yaml_file, output_dir)
    target_path_dest = move_file_to_output(target_file, output_dir)
    # Sanity check that the staged copies are co-located (presumably because the
    # YAML refers to the target by a relative path — TODO confirm).
    check_files_in_same_directory(yaml_path_dest, target_path_dest)
    return yaml_path_dest, target_path_dest
def download_results(results_dir, zip_name):
    """Archive the contents of ``results_dir`` as a ZIP file for download.

    Args:
        results_dir: directory to archive; entries in the ZIP are relative to it.
        zip_name: destination archive path. Any suffix other than ``.zip`` is
            replaced with ``.zip``; missing parent directories are created.

    Returns:
        ``(zip_path, "Download ready!")`` on success, or
        ``(None, "No results available for download.")`` when ``results_dir``
        is not an existing directory.
    """
    source_dir = Path(results_dir)
    if not (source_dir.exists() and source_dir.is_dir()):
        return None, "No results available for download."

    archive = Path(zip_name)
    if archive.suffix != ".zip":
        archive = archive.with_suffix(".zip")

    # Path("") normalises to Path("."), so a bare file name skips the mkdir.
    archive_parent = archive.parent
    if archive_parent != Path(""):
        archive_parent.mkdir(parents=True, exist_ok=True)

    # make_archive appends ".zip" itself, so it takes the suffix-less stem.
    stem = str(archive.with_suffix(""))
    shutil.make_archive(base_name=stem, format="zip", root_dir=str(source_dir))
    return str(Path(f"{stem}.zip")), "Download ready!"
def find_log_file(output_dir):
    """Return the most recently modified ``boltzgen*.log`` file in ``output_dir``.

    Only the directory itself is searched (no recursion). Returns None when no
    matching file exists.
    """
    candidates = Path(output_dir).glob("boltzgen*.log")
    return max(candidates, key=lambda log: log.stat().st_mtime, default=None)
@spaces.GPU(duration=300)
def run_boltzgen(yaml_file, target_file, protocol, num_designs, budget):
    """Stage the inputs, run ``boltzgen run`` and return a status message.

    Args:
        yaml_file: uploaded design YAML config.
        target_file: uploaded target structure (PDB/CIF).
        protocol: value passed to ``--protocol`` (e.g. "protein-anything").
        num_designs: value passed to ``--num_designs``.
        budget: value passed to ``--budget``; must not exceed ``num_designs``.

    Returns:
        A human-readable status string for the UI. Failures are reported in
        the returned string rather than raised.
    """
    output_dir = "/tmp/output"
    try:
        if num_designs is None or budget is None:
            return "Failed: number of designs and budget must both be set."
        num_designs, budget = int(num_designs), int(budget)
        if budget > num_designs:
            return "Failed: budget must be less than or equal to the number of designs."

        # Start from an empty output directory so results left over from a
        # previous run cannot be reported (or downloaded) as this run's results.
        shutil.rmtree(output_dir, ignore_errors=True)
        os.makedirs(output_dir, exist_ok=True)

        # Staging is inside the try so a missing upload becomes a status message.
        yaml_path_dest, _target_path_dest = move_files_to_output(yaml_file, target_file, output_dir)

        cmd = [
            "boltzgen", "run", str(yaml_path_dest),
            "--output", output_dir,
            "--protocol", protocol,
            "--num_designs", str(num_designs),
            "--budget", str(budget),
            "--cache", "/tmp/cache",
        ]
        result = subprocess.run(cmd, capture_output=True, text=True, cwd=output_dir, timeout=300)
    except subprocess.TimeoutExpired:
        return "Timeout (300s limit)"
    except Exception as e:  # UI boundary: report any failure as a status message
        return f"Failed: {e}"

    if result.returncode != 0:
        # Fall back to stdout so the user does not get an empty error message.
        return f"Error:\n{result.stderr or result.stdout}"
    final_designs_dir = Path(output_dir) / "final_ranked_designs"
    if final_designs_dir.is_dir():
        return f"Success!\nDesigns saved to {final_designs_dir}/"
    return "No designs passed the filtering criteria."
# Interface
def _read_latest_log(log_dir="/tmp"):
    """Return the text of the newest BoltzGen log found by ``find_log_file``.

    Returns a placeholder message when no log exists. Undecodable bytes are
    replaced so a partially written log cannot crash the UI handler.
    """
    log_file = find_log_file(log_dir)
    if log_file is None:
        return "No log file found."
    return log_file.read_text(errors="replace")


with gr.Blocks() as demo:
    # Page title
    gr.Markdown(""" <h1 style="font-size: 48px;">BoltzGen Design Space</h1>""")
    gr.Markdown(
        """
<span style="font-size:20px">
**Design novel protein, peptide, and nanobody binders** against any biomolecular target (proteins, DNA, small molecules) using the BoltzGen diffusion model.
**BoltzGen** unifies structure prediction and binder design in a single all-atom generative model with flexible design constraints.
Developed by Hannes Stärk and team at MIT: [description](https://jclinic.mit.edu/boltzgen/) [git](https://github.com/HannesStark/boltzgen)
</span>
"""
    )
    gr.Markdown(
        """
### Instructions
1. Prepare a YAML config file specifying your design task. See [BoltzGen Docs](https://github.com/HannesStark/boltzgen/tree/main/example) for examples.
2. Upload the target structure file (PDB or CIF) referenced by your YAML config.
3. Select the design protocol (choose "protein-anything" for general results).
4. Specify the number of designs to generate. Warning: high numbers may take longer to compute than your session allows.
5. Specify your budget (# of designs wanted in the final diversified set). Note that the budget must be less than or equal to the number of designs.
6. Click "Run BoltzGen" and wait for the results!
7. Download your results using the buttons at the bottom:
   - Final Ranked Designs: structures, metrics and report summary of the best designs after filtering and ranking (interaction score and diversity).
   - Intermediate Designs: all generated trajectories before filtering/ranking.
"""
    )

    with gr.Row():
        yaml_file = gr.File(
            label="Design YAML Config",
            file_types=[".yaml", ".yml"],
            file_count="single",
        )
        target_file = gr.File(
            label="Target PDB/CIF",
            file_types=[".pdb", ".cif"],
            file_count="single",
        )

    with gr.Row():
        num_designs = gr.Number(
            value=10,
            label="Number of Designs",
            precision=0,
            minimum=1,
            maximum=1000,
        )
        # The budget is a design count (size of the final set), not a token count.
        budget = gr.Number(
            value=5,
            label="Budget (designs in final set)",
            precision=0,
            minimum=1,
            maximum=1000,
        )
        protocol = gr.Dropdown(
            choices=["protein-anything", "peptide-anything", "nanobody-anything"],
            value="protein-anything",
            label="Protocol",
        )

    run_button = gr.Button("Run BoltzGen", variant="primary")
    output_text = gr.Textbox(label="Run Status", lines=2)
    run_event = run_button.click(
        run_boltzgen,
        inputs=[yaml_file, target_file, protocol, num_designs, budget],
        outputs=[output_text],
    )

    gr.Markdown("## Download the results")
    with gr.Row():
        download_final_ranked_designs_button = gr.Button("Download Final Results")
        download_intermediate_results_button = gr.Button("Download Intermediate Designs")
    with gr.Row():
        final_designs_file = gr.File(label="Download Final Ranked Designs")
        intermediate_designs_file = gr.File(label="Download Intermediate Designs")
    download_status = gr.Textbox(label="Download Status", lines=1)

    # download_results returns (zip_path_or_None, message); each value needs
    # its own output component, otherwise the whole tuple is fed to gr.File.
    download_final_ranked_designs_button.click(
        download_results,
        inputs=[
            gr.State("/tmp/output/final_ranked_designs"),
            gr.State("final_ranked_designs.zip"),
        ],
        outputs=[final_designs_file, download_status],
    )
    download_intermediate_results_button.click(
        download_results,
        inputs=[
            gr.State("/tmp/output"),
            gr.State("intermediate_results.zip"),
        ],
        outputs=[intermediate_designs_file, download_status],
    )

    # Show the log output in a textbox for debugging
    with gr.Row():
        log_output = gr.Textbox(label="BoltzGen Log Output", lines=10)
    # Read the log only after the run finishes: a second click listener on
    # run_button would fire concurrently with run_boltzgen and show a stale log.
    run_event.then(
        _read_latest_log,
        inputs=[],
        outputs=log_output,
    )

demo.queue()
demo.launch()