# LigandMPNN / space_utils / pipelines.py
# Author: gabboud
# Commit: integrate download_zip logics inside run_generation_folder (116a83c)
import json
import os
import random
import shlex
import shutil
import subprocess
import time

import gradio as gr
import spaces

from space_utils.handle_files import display_fasta, download_results_as_zip
# NOTE(review): dead code below — single-PDB variant superseded by
# run_generation_folder; kept commented out for reference.
#def run_generation_single_pdb(pdb_file, num_batches, num_designs_per_batch, chains_to_design, temperature, extra_args):
#
# random_hash = random.getrandbits(128)
# time_stamp = time.strftime("%Y-%m-%d-%H-%M-%S")
# out_dir = f"./output/ligandmpnn_results_{time_stamp}_{random_hash}"
# os.makedirs(out_dir, exist_ok=True)
#
# try:
# command= f"python run.py --model_type 'ligand_mpnn' --checkpoint_ligand_mpnn './model_params/ligandmpnn_v_32_030_25.pt' \
# --pdb_path {pdb_file} --number_of_batches {num_batches} --batch_size {num_designs_per_batch} \
# --temperature {temperature} --out_folder '{out_dir}'"
# if extra_args:
# command += f" {extra_args}"
# if chains_to_design:
# command += f" --chains_to_design {chains_to_design}"
# res = subprocess.run(command, shell=True, check=True, text=True, capture_output=True)
#
# return gr.update(value="Generation complete!"), gr.update(value=out_dir)
# except subprocess.CalledProcessError as e:
# return gr.update(value=f"Generation failed: \n{e.stderr}"), gr.update(value=None)
def get_duration(pdb_folder, num_batches, num_designs_per_batch, chains_to_design, temperature, extra_args, max_duration):
    """Duration resolver for the ``@spaces.GPU`` decorator.

    The signature mirrors ``run_generation_folder`` (the decorated function);
    the GPU allocation requested is simply the caller-supplied ceiling.
    """
    return max_duration
@spaces.GPU(duration=get_duration)
def run_generation_folder(pdb_folder, num_batches, num_designs_per_batch, chains_to_design, temperature, extra_args, max_duration):
    """Run LigandMPNN sequence design over every .pdb file in an uploaded folder.

    Args:
        pdb_folder: list of file paths from the Gradio folder upload, or None.
        num_batches: forwarded to run.py as --number_of_batches.
        num_designs_per_batch: forwarded to run.py as --batch_size.
        chains_to_design: optional chain spec, forwarded as --chains_to_design.
        temperature: sampling temperature forwarded to run.py.
        extra_args: raw extra CLI arguments appended verbatim to the command.
        max_duration: GPU allocation ceiling, consumed by get_duration() via
            the @spaces.GPU decorator (unused in the body).

    Returns:
        (status_message, output_dir_or_None, zip_path_or_None)
    """
    if pdb_folder is None:
        return "Please upload a folder of PDB files to run generation.", None, None

    # Unique per-request working dirs so concurrent runs never collide.
    random_hash = random.getrandbits(128)
    time_stamp = time.strftime("%Y-%m-%d-%H-%M-%S")
    in_dir = f"./input/ligandmpnn_inputs_{time_stamp}_{random_hash}"
    out_dir = f"./output/ligandmpnn_results_{time_stamp}_{random_hash}"
    os.makedirs(in_dir, exist_ok=True)
    os.makedirs(out_dir, exist_ok=True)

    # run.py's batch mode takes a JSON file whose keys are pdb paths; the JSON
    # and all pdb files must live in the same input directory, hence the copy.
    print(pdb_folder)
    print("Scanning uploaded folder for pdb files...")
    config_dict = {}
    for file_path in pdb_folder:
        if file_path.endswith(".pdb"):
            print("Found pdb file:", file_path)
            shutil.copy2(file_path, in_dir)
            config_dict[os.path.join(in_dir, os.path.basename(file_path))] = ""

    # Fail fast instead of launching run.py on an empty config.
    if not config_dict:
        return "No .pdb files found in the uploaded folder.", None, None

    # Count only the .pdb files actually queued, not every uploaded file.
    status_string = (
        f"Running generation with batch size {num_designs_per_batch} and "
        f"number of batches {num_batches} for {len(config_dict)} pdb files..."
    )

    json_path = os.path.join(in_dir, "config.json")
    with open(json_path, "w") as f:
        json.dump(config_dict, f)
    with open(json_path, "r") as f:
        print("Config JSON content:\n", json.load(f))
    print("Files in input directory:", os.listdir(in_dir))

    try:
        print(out_dir)
        # shlex.quote protects the internally generated paths against
        # whitespace/shell metacharacters.
        # NOTE(review): extra_args (and chains_to_design) are appended verbatim
        # to a shell=True command — deliberate power-user passthrough, but a
        # shell-injection vector for untrusted input; consider an argv list
        # with shell=False if this ever faces untrusted users.
        command = (
            "python run.py --model_type 'ligand_mpnn' "
            "--checkpoint_ligand_mpnn './model_params/ligandmpnn_v_32_030_25.pt' "
            f"--pdb_path_multi {shlex.quote(json_path)} "
            f"--number_of_batches {num_batches} --batch_size {num_designs_per_batch} "
            f"--temperature {temperature} --out_folder {shlex.quote(out_dir)}"
        )
        if extra_args:
            command += f" {extra_args}"
        if chains_to_design:
            command += f" --chains_to_design {chains_to_design}"
        status_string += "\nRunning command:\n" + command
        res = subprocess.run(command, shell=True, check=True, text=True, capture_output=True)
        status_string += "\nGeneration complete!"
        zip_path = download_results_as_zip(out_dir)
        return status_string, out_dir, zip_path
    except subprocess.CalledProcessError as e:
        return status_string + f"\nGeneration failed: \n{e.stderr}", None, None