# RosettaFold3 — utils/pipelines.py
# Commit 0cd364e (gabboud): include zip_file download logic in fold_all_jobs,
# remove gr.update returns.
import gradio as gr
from lightning.fabric import seed_everything
import time
import os
import spaces
import subprocess
import gzip
from utils.handle_files import *
import sys
import shutil
from time import perf_counter
import random
def get_duration(job_files, support_files, num_predictions, early_stopping, diffusion_steps, max_duration):
    """Duration hook for the @spaces.GPU decorator.

    Receives the same arguments as fold_all_jobs; only max_duration is
    consulted — the ZeroGPU allocation is exactly the timeout the user chose.
    """
    duration = max_duration
    return duration
@spaces.GPU(duration=get_duration)
def fold_all_jobs(job_files, support_files, num_predictions, early_stopping, diffusion_steps, max_duration):
    """
    Folds protein structures using RosettaFold3 given user-uploaded job files (.json, .pdb, .cif) and support files (.pdb, .cif).
    Job files are fed to the model as inputs, while support files are needed when conditioning a job like when templating.
    Parameters:
    ----------
    job_files: list of str, or gr.File object with file_count="multiple"
        List of paths to the user-uploaded job files, which can be .json, .pdb, or .cif files.
    support_files: list of str, or gr.File object with file_count="multiple"
        List of paths to the user-uploaded support files, which can be .pdb or .cif files.
    num_predictions: int
        The number of structure predictions to generate for each input job.
    early_stopping: float
        The pLDDT threshold for early stopping. If set to 0, early stopping is disabled.
    diffusion_steps: int
        The number of diffusion steps to use during generation. Higher values may improve structure quality but will increase runtime.
    max_duration: int
        The maximum duration (in seconds) for the folding process. This is used to set the timeout for the Spaces GPU allocation on ZeroGPU.
    Returns:
    -------
    status: str
        Human-readable status of the generation; propagates subprocess stderr if the generation fails.
    zip_path: str or None
        If generation is successful, the path to the generated zip file containing the results, which can be passed to gr.File for download. None if generation fails.
    """
    # Guard both None and an empty upload list — either way there is nothing to fold.
    if not job_files:
        return "Please ensure you have uploaded at least one job file: .json, .pdb, or .cif", None
    job_names = [os.path.basename(f.name) for f in job_files]
    if support_files is None:
        status_update = f"Creating {num_predictions} structure predictions per job.\n jobs uploaded {job_names}\nno support files uploaded."
    else:
        support_names = [os.path.basename(f.name) for f in support_files]
        status_update = f"Creating {num_predictions} structure predictions per job.\n jobs uploaded {job_names}\nsupport files {support_names}."
    session_hash = random.getrandbits(128)
    time_stamp = time.strftime("%Y-%m-%d-%H-%M-%S")
    directory = os.path.join(os.getcwd(), f"outputs/structure_prediction/session_{session_hash}_{time_stamp}")
    # exist_ok=False: a 128-bit session hash makes collisions implausible; fail loudly if one happens.
    os.makedirs(directory, exist_ok=False)
    try:
        shared_dir = os.path.join(os.getcwd(), "uploads", f"{time_stamp}_{session_hash}")
        # gr.File(file_count="multiple") stores each upload in its own temp directory
        # on the server; rf3 needs every job/support file co-located in one directory.
        new_job_paths, new_support_paths = move_all_files_to_shared_directory(job_files, support_files, shared_dir)
        input_file_str = "[" + ",".join(new_job_paths) + "]"  # hydra list syntax: process multiple rf3 jobs in one go
        # Pass an argument list (shell=False, the subprocess default) instead of a
        # shell string: the bracketed hydra list is a glob pattern under sh, and
        # unusual characters in uploaded filenames must not reach a shell.
        args = [
            "rf3", "fold",
            f"inputs={input_file_str}",
            f"out_dir={directory}",
            f"diffusion_batch_size={num_predictions}",
            f"num_steps={diffusion_steps}",
            f"early_stopping_plddt_threshold={early_stopping}",
            "ckpt_path=/home/user/.foundry/checkpoints/rf3_foundry_01_24_latest_remapped.ckpt",
        ]
        # As of 28th of February 2026 it seems that the default for downloading the weights of RF3
        # using "foundry install" and the default checkpoint running "rf3 fold" do not use the same
        # checkpoint file, which is why the ckpt_path is hard-coded here. They might notice this and
        # fix it in the future, and this workaround would then produce an error and need removal.
        command = " ".join(args)  # printable form only; execution uses the list
        print(f"Running command: {command}")
        status_update += f"\nRunning command: {command}."
        start = perf_counter()
        subprocess.run(args, check=True, text=True, capture_output=True, cwd=shared_dir)
        elapsed = perf_counter() - start  # measure once so the log and the status text agree
        print("Command took", elapsed, "seconds to run.")
        status_update += f"\nGeneration successful! Command took {elapsed:.2f} seconds to run."
        zip_path = download_results_as_zip(directory)
        return status_update, zip_path
    except subprocess.CalledProcessError as e:
        print("subprocess threw an error", e.stderr)
        return f"Generation failed:\n{e.stderr}", None