File size: 3,541 Bytes
94ff1b9
 
 
 
 
c1e206f
5b8fa42
aced7ae
c49e7b8
aced7ae
94ff1b9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26f1fe3
94ff1b9
aced7ae
94ff1b9
 
 
 
 
 
 
 
 
 
464a533
aced7ae
 
8705e46
46eb3c3
aced7ae
 
 
 
 
4f422de
aced7ae
 
 
 
5b8fa42
 
94ff1b9
aced7ae
 
 
46eb3c3
c49e7b8
46eb3c3
c49e7b8
46eb3c3
 
 
 
c49e7b8
46eb3c3
 
 
c49e7b8
 
 
5d306a8
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
from rfd3.engine import RFD3InferenceConfig, RFD3InferenceEngine
import gradio as gr
from lightning.fabric import seed_everything
import time
import os
import spaces
import subprocess
import gzip
import gemmi



@spaces.GPU(duration=300)
def test_rfd3_from_notebook():
    """Smoke-test RFD3 unconditional generation with a fixed seed.

    Returns
    -------
    str
        A human-readable summary of the generated structures on success,
        or an ``"Error: ..."`` string if generation fails.
    """
    # Fixed seed so repeated test runs produce identical structures.
    seed_everything(0)

    # Configure RFD3 inference.
    config = RFD3InferenceConfig(
        specification={
            'length': 40,  # generate 40-residue proteins
        },
        diffusion_batch_size=2,  # generate 2 structures per batch
    )

    # Initialize engine and run generation.
    try:
        # NOTE(review): assumes RFD3InferenceConfig supports mapping-style
        # ** unpacking — confirm against the rfd3 API.
        model = RFD3InferenceEngine(**config)
        outputs = model.run(
            inputs=None,    # None => unconditional generation
            out_dir=None,   # None => return results in memory (no file output)
            n_batches=1,    # generate a single batch
        )
        return_str = "RFD3 test passed! Generated structures:\n"

        for idx, data in outputs.items():
            return_str += f"Batch {idx}: {len(data)} structure(s)\n"
            for i, struct in enumerate(data):
                return_str += f"Structure {i+1}: {struct.atom_array.array_length()} Atoms\n"
        return return_str
    except Exception as e:
        return f"Error: {str(e)}"
    

# Initialize engine and run generation
@spaces.GPU(duration=300)
def unconditional_generation(num_batches, num_designs_per_batch, length,
                             request: gr.Request = None):
    """Run unconditional RFD3 generation and write results to a session directory.

    Parameters
    ----------
    num_batches : int
        Number of diffusion batches to run.
    num_designs_per_batch : int
        Structures generated per batch (``diffusion_batch_size``).
    length : int
        Protein length (residues) to generate.
    request : gr.Request, optional
        Injected by Gradio when this function is bound to an event and the
        parameter is type-annotated. Constructing ``gr.Request()`` manually
        (as this code previously did) does not carry the live session hash.

    Returns
    -------
    tuple[str, list[dict]]
        The output directory and one record per design with the expected
        CIF and converted PDB paths.

    Raises
    ------
    RuntimeError
        If engine construction or generation fails.
    """
    config = RFD3InferenceConfig(
        specification={
            'length': length,
        },
        diffusion_batch_size=num_designs_per_batch,  # structures per batch
    )

    # Use the injected request's session hash when available so each
    # browser session gets its own output directory.
    session_hash = request.session_hash if request is not None else None
    time_stamp = time.strftime("%Y-%m-%d-%H-%M-%S")
    directory = f"./outputs/unconditional_generation/session_{session_hash}_{time_stamp}"
    # exist_ok=False: fail loudly rather than mix outputs from two runs
    # that collide on the same session/second.
    os.makedirs(directory, exist_ok=False)

    try:
        model = RFD3InferenceEngine(**config)
        outputs = model.run(
            inputs=None,          # None => unconditional generation
            out_dir=directory,    # write structure files to disk
            n_batches=num_batches,
        )

        results = []
        for batch in range(num_batches):
            for design in range(num_designs_per_batch):
                # NOTE(review): filename pattern assumed to match RFD3's
                # on-disk naming ("_<batch>_model_<design>.cif.gz") — verify.
                file_name = os.path.join(directory, f"_{batch}_model_{design}.cif.gz")
                results.append({
                    "batch": batch,
                    "design": design,
                    "cif_path": file_name,
                    "pdb_path": mcif_gz_to_pdb(file_name),
                })

        print(results)
        return directory, results

    except Exception as e:
        raise RuntimeError(f"Error during generation: {str(e)}")
    
def collect_outputs(gen_directory, num_batches, num_designs_per_batch):
    """Return a recursive ``ls -R`` listing of *gen_directory* as a string.

    Parameters
    ----------
    gen_directory : str
        Directory whose contents should be listed.
    num_batches, num_designs_per_batch : int
        Accepted for interface compatibility; currently unused.

    Returns
    -------
    str
        The directory listing, or an ``"Error: ..."`` string on failure.
    """
    try:
        # List-form argv with shell=False: the path embeds a session hash
        # and timestamp, and interpolating it into a shell string
        # (the previous f"ls -R {gen_directory}" + shell=True) is a
        # shell-injection risk.
        return subprocess.check_output(["ls", "-R", gen_directory]).decode()
    except Exception as e:
        return f"Error: {str(e)}"
    

def mcif_gz_to_pdb(file_path: str) -> str:
    """Convert a gzipped mmCIF file (``.cif.gz``) to PDB format.

    The PDB file is written alongside the input, with the trailing
    ``.cif.gz`` suffix replaced by ``.pdb``.

    Parameters
    ----------
    file_path : str
        Path to the ``.cif.gz`` file.

    Returns
    -------
    str
        Path to the generated PDB file.

    Raises
    ------
    ValueError
        If *file_path* does not end in ``.cif.gz``. (Previously a
        mismatched suffix silently overwrote the input file, because
        ``str.replace`` found nothing to replace.)
    """
    suffix = ".cif.gz"
    if not file_path.endswith(suffix):
        raise ValueError(f"Expected a {suffix} file, got: {file_path}")
    st = gemmi.read_structure(file_path)
    st.setup_entities()  # normalize entity records for consistent PDB output
    # Strip only the trailing suffix; str.replace would also rewrite any
    # ".cif.gz" occurring earlier in the path (e.g. a directory name).
    pdb_path = file_path[:-len(suffix)] + ".pdb"
    st.write_minimal_pdb(pdb_path)
    return pdb_path