File size: 5,111 Bytes
94ff1b9
 
 
 
 
c1e206f
5b8fa42
aced7ae
a7294d8
c1a83a3
37b5de1
2535c77
44157d7
94ff1b9
 
 
50f1940
 
 
19021e5
50f1940
7f14dfe
 
 
 
 
 
 
37b5de1
 
7f14dfe
 
 
d848ff0
 
7f14dfe
 
 
 
 
 
 
e165b97
 
4a67952
e165b97
 
 
 
 
44157d7
7f14dfe
44157d7
7f14dfe
 
 
37b5de1
 
bcc516a
 
37b5de1
 
 
 
 
 
 
b212ae4
 
e165b97
2535c77
b041365
 
 
7f14dfe
 
 
 
833a14a
a76c148
7f14dfe
b8f3849
 
7f14dfe
 
 
4a67952
 
 
 
7f14dfe
c1a83a3
4a67952
aa9564f
 
50f1940
 
 
 
 
 
 
aa9564f
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
from rfd3.engine import RFD3InferenceConfig, RFD3InferenceEngine
import gradio as gr
from lightning.fabric import seed_everything
import time
import os
import spaces
import subprocess
import gzip
from utils.handle_files import *
import sys
import shutil
from time import perf_counter
import random


    
def get_duration(input_file, pdb_file, num_batches, num_designs_per_batch, extra_args, max_duration):
    """Duration callback for the ``@spaces.GPU`` decorator.

    The decorator calls this with the same arguments as the decorated
    function to decide how long a GPU should be reserved. All arguments
    except ``max_duration`` are ignored; the user-supplied limit is
    returned unchanged.
    """
    duration_seconds = max_duration
    return duration_seconds

@spaces.GPU(duration=get_duration) 
def generation_with_input_config(input_file, pdb_file, num_batches, num_designs_per_batch, extra_args, max_duration):
    """
    Run an ``rfd3 design`` generation job using an uploaded config file, optionally
    conditioned on an uploaded PDB scaffold/target. Outputs are written to a
    per-session timestamped directory under ``./outputs`` and collected into a
    results list plus a downloadable zip.

    Parameters
    ----------
    input_file: str or None,
        Server-side path to the uploaded config file (.yaml or .json). The code
        passes it directly to ``os.path.basename``/``shutil.copy2``, so a plain
        filepath string is expected (presumably gr.File with type="filepath" —
        TODO confirm against the UI wiring).
    pdb_file: str or None,
        Server-side path to the uploaded PDB file used to condition generation;
        may be None for unconditional generation.
    num_batches: int,
        Number of batches, forwarded as ``n_batches`` to the rfd3 CLI.
    num_designs_per_batch: int,
        Designs per batch, forwarded as ``diffusion_batch_size``.
    extra_args: str,
        Extra CLI arguments appended verbatim to the command when non-empty.
        NOTE(review): this is interpolated into a ``shell=True`` command string,
        so untrusted input here amounts to shell injection — confirm the Space
        restricts who can submit jobs, or switch to a list argv with shell=False.
    max_duration: int,
        GPU reservation limit in seconds; consumed by the ``get_duration``
        callback of the ``@spaces.GPU`` decorator, not used in the body.

    Returns
    -------
        status_update: str,
            Human-readable status text for the UI textbox; on failure this is the
            error message (subprocess stderr is propagated into it).
        results: list of dicts or None,
            One dict per generated structure with keys "batch", "design",
            "cif_path", and "pdb_path". None when validation or the subprocess fails.
        zip_path: str or None,
            Path to a zip archive of the output directory (from
            ``download_results_as_zip``). None on failure.
    """


    # Guard clause: a config file is mandatory; the PDB is optional conditioning.
    if input_file is None:        
        return "Please ensure you have uploaded a configuration file: .yaml or .json", None, None
    elif pdb_file is None:
        status_update = f"Running generation for {num_batches} batches of {num_designs_per_batch}\n job configuration uploaded from file {os.path.basename(input_file)}\n no scaffold/target provided"
    else:
        status_update = f"Running generation for {num_batches} batches of {num_designs_per_batch}\n job configuration uploaded from file {os.path.basename(input_file)}\n scaffold/target provided from file {os.path.basename(pdb_file)}"

    # Unique per-request output directory: 128-bit random session id + timestamp.
    # exist_ok=False makes an (astronomically unlikely) collision fail loudly
    # rather than mixing two sessions' outputs.
    session_hash = random.getrandbits(128)
    time_stamp = time.strftime("%Y-%m-%d-%H-%M-%S")
    directory = f"./outputs/generation_with_input_config/session_{session_hash}_{time_stamp}"
    os.makedirs(directory, exist_ok=False)

    try: 
        if pdb_file is not None:
            # I need to do this because uploading files to a HF space stores each file in a separate temp directory so I need to copy them again to the same place.
            shared_dir = os.path.join("uploads", f"{time_stamp}_{session_hash}")
            os.makedirs(shared_dir)
            copied_config_file = os.path.join(shared_dir, os.path.basename(input_file))
            shutil.copy2(input_file, copied_config_file)
            copied_pdb_file = os.path.join(shared_dir, os.path.basename(pdb_file))
            shutil.copy2(pdb_file, copied_pdb_file)
            command = f"rfd3 design inputs={copied_config_file} out_dir={directory} n_batches={num_batches} diffusion_batch_size={num_designs_per_batch}"
        else:
            command = f"rfd3 design inputs={input_file} out_dir={directory} n_batches={num_batches} diffusion_batch_size={num_designs_per_batch}"
        if extra_args:
            # Appended verbatim — see the security note in the docstring.
            command += f" {extra_args}"
        status_update += f"\nRunning command: {command}."
        start = perf_counter()
        # check=True raises CalledProcessError on nonzero exit; capture_output
        # keeps stderr available for the failure message below.
        res = subprocess.run(command, shell=True, check=True, text=True, capture_output=True)
        status_update += f"\nGeneration successful! Command took {perf_counter() - start:.2f} seconds to run."


        # Collect generated structures. Output filenames are expected to contain
        # "..._<batch>_model_<design>..." — presumably the rfd3 naming scheme;
        # terms.index("model") will raise ValueError on any .cif.gz that does
        # not follow it (TODO confirm the CLI's output naming is stable).
        results = []
        for file_name in os.listdir(directory):
            if file_name.endswith(".cif.gz"):
                name = os.path.basename(file_name).split(".")[0]  #filename without extension
                terms = name.split("_")
                model_index = terms.index("model")
                batch = int(terms[model_index - 1])
                design = int(terms[model_index + 1])
                cif_path = os.path.join(directory, file_name)
                # Convert the gzipped mmCIF to PDB for viewers that need it.
                pdb_path = mcif_gz_to_pdb(cif_path)
                results.append({"batch": batch, "design": design, "cif_path": cif_path, "pdb_path": pdb_path})
        
        zip_path = download_results_as_zip(directory)

        return status_update, results, zip_path
    
    except subprocess.CalledProcessError as e:
        # Surface the CLI's stderr in the status textbox; keep results/zip None.
        return f"Generation failed:\n{e.stderr}", None, None


#def generation_with_input_config_factory(max_duration):
#
#    @spaces.GPU(duration=max_duration)
#    def generation_with_correct_time_limit(input_file, pdb_file, num_batches, num_designs_per_batch, extra_args):
#        return generation_with_input_config_impl(input_file, pdb_file, num_batches, num_designs_per_batch, extra_args)
#
#    return generation_with_correct_time_limit