Spaces:
Runtime error
Runtime error
| import argparse | |
| import ray | |
| import time | |
| from diffab.tools.relax.openmm_relaxer import run_openmm | |
| from diffab.tools.relax.pyrosetta_relaxer import run_pyrosetta, run_pyrosetta_fixbb | |
| from diffab.tools.relax.base import TaskScanner | |
@ray.remote
def run_openmm_remote(task):
    """Ray task wrapper around ``run_openmm``; returns its result."""
    return run_openmm(task)


@ray.remote
def run_pyrosetta_remote(task):
    """Ray task wrapper around ``run_pyrosetta``; returns its result."""
    return run_pyrosetta(task)


@ray.remote
def run_pyrosetta_fixbb_remote(task):
    """Ray task wrapper around ``run_pyrosetta_fixbb``; returns its result."""
    return run_pyrosetta_fixbb(task)


def _run_stages(task, stages):
    """Chain Ray-task ``stages`` on ``task`` and return the last stage's result.

    Each stage is submitted with the previous stage's ``ObjectRef``. Ray waits
    for a top-level ``ObjectRef`` argument to be ready and passes the resolved
    value, so the stages run one after another. Only the final result is
    fetched here with ``ray.get``.

    Args:
        task: Input handed to the first stage.
        stages: Non-empty sequence of ``@ray.remote`` functions.

    Returns:
        The value produced by the last stage.
    """
    ref = task
    for stage in stages:
        ref = stage.remote(ref)
    return ray.get(ref)


@ray.remote
def pipeline_openmm_pyrosetta(task):
    """Run ``run_openmm`` then ``run_pyrosetta`` on ``task``; return the final result."""
    return _run_stages(task, [run_openmm_remote, run_pyrosetta_remote])


@ray.remote
def pipeline_pyrosetta(task):
    """Run ``run_pyrosetta`` on ``task`` as a Ray task; return its result."""
    return _run_stages(task, [run_pyrosetta_remote])


@ray.remote
def pipeline_pyrosetta_fixbb(task):
    """Run ``run_pyrosetta_fixbb`` on ``task`` as a Ray task; return its result."""
    return _run_stages(task, [run_pyrosetta_fixbb_remote])
# Maps the ``--pipeline`` command-line name to the pipeline function it selects.
pipeline_dict = dict(
    openmm_pyrosetta=pipeline_openmm_pyrosetta,
    pyrosetta=pipeline_pyrosetta,
    pyrosetta_fixbb=pipeline_pyrosetta_fixbb,
)
def main():
    """Repeatedly scan ``--root`` for tasks and run them through a Ray pipeline.

    Parses the command line, starts Ray, then loops forever. Each pass:

    - asks a ``TaskScanner`` for tasks;
    - submits every task to the selected pipeline;
    - prints a line as each one finishes;
    - sleeps one second before scanning again.

    Command-line options:
        --root: Directory handed to ``TaskScanner`` (default ``./results``).
        --pipeline: One of the keys of ``pipeline_dict``
            (default ``openmm_pyrosetta``).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--root', type=str, default='./results')
    # Validate the name with ``choices`` instead of a ``type`` lambda.
    # argparse does not turn a ``KeyError`` raised by ``type`` into a usage
    # error, so an unknown name previously crashed with a traceback.
    parser.add_argument(
        '--pipeline',
        type=str,
        choices=list(pipeline_dict),
        default='openmm_pyrosetta',
    )
    args = parser.parse_args()
    pipeline = pipeline_dict[args.pipeline]

    # Start Ray only after argument parsing, so ``--help`` and invalid
    # arguments exit without initializing Ray.
    ray.init()

    # The fix-backbone pipeline uses the 'fixbb' final postfix; all others use 'rosetta'.
    final_pfx = 'fixbb' if args.pipeline == 'pyrosetta_fixbb' else 'rosetta'
    scanner = TaskScanner(args.root, final_postfix=final_pfx)
    while True:
        tasks = scanner.scan()
        futures = [pipeline.remote(t) for t in tasks]
        if futures:
            print(f'Submitted {len(futures)} tasks.')
        while futures:
            # Report tasks one by one as soon as each finishes, in completion order.
            done_ids, futures = ray.wait(futures, num_returns=1)
            for done_id in done_ids:
                done_task = ray.get(done_id)
                print(f'Remaining {len(futures)}. Finished {done_task.current_path}')
        time.sleep(1.0)
# Start the scan-and-relax loop only when this file is executed as a script.
if __name__ == '__main__':
    main()