File size: 7,870 Bytes
a376829
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
# Copyright 2025 Dhruv Nair. All rights reserved.
# Licensed under the Apache License, Version 2.0

"""
Pre-denoising steps for RF3: input processing, timestep setup, recycling trunk, latent preparation.
"""

from typing import List

import torch

from diffusers.utils import logging
from diffusers.modular_pipelines import ModularPipeline, ModularPipelineBlocks, PipelineState
from diffusers.modular_pipelines.modular_pipeline_utils import ComponentSpec, InputParam, OutputParam


# Module-level logger obtained from diffusers' logging utilities, named after this module.
logger = logging.get_logger(__name__)


class RF3InputStep(ModularPipelineBlocks):
    """Parse sequence input and prepare feature dict for RF3.

    If a pre-built feature dict ``f`` is supplied it is used as-is and ``sequence``
    is ignored. Otherwise a minimal CA-only feature dict is built from ``sequence``
    with one pseudo-atom per residue, so atom count ``L`` and token count ``I``
    coincide.
    """

    model_name = "rf3"

    # One-letter amino-acid codes in the column order of the ``restype`` one-hot.
    # Any character not in this string is encoded at ``_UNKNOWN_RESTYPE``.
    _AA_ORDER = "ARNDCQEGHILKMFPSTWYV"
    _UNKNOWN_RESTYPE = 20
    _NUM_RESTYPES = 32

    @property
    def description(self) -> str:
        return "Parse sequence and optional MSA/template inputs for structure prediction."

    @property
    def inputs(self) -> List[InputParam]:
        return [
            InputParam("sequence", required=True, type_hint=str, description="Amino acid sequence (one-letter codes)"),
            InputParam("f", type_hint=dict, description="Pre-built feature dict (overrides sequence)"),
        ]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        return [
            OutputParam("f", type_hint=dict, description="Feature dictionary for RF3"),
            OutputParam("L", type_hint=int, description="Sequence length (num atoms)"),
            OutputParam("I", type_hint=int, description="Num tokens"),
        ]

    @classmethod
    def _build_ca_only_features(cls, sequence: str) -> dict:
        """Build a minimal CA-only feature dict (one pseudo-atom per residue).

        Letters are matched case-insensitively against ``_AA_ORDER``; anything else
        is one-hot encoded as unknown. Tensors go on CUDA when available, else CPU.
        """
        L = len(sequence)
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        aa_to_index = {aa: i for i, aa in enumerate(cls._AA_ORDER)}
        # Upper-case per character (not the whole string) so multi-char case
        # mappings cannot change the length and misalign indices with residues.
        indices = torch.tensor(
            [aa_to_index.get(aa.upper(), cls._UNKNOWN_RESTYPE) for aa in sequence],
            dtype=torch.long,
            device=device,
        )
        restype = torch.zeros(L, cls._NUM_RESTYPES, device=device)
        restype[torch.arange(L, device=device), indices] = 1.0

        return {
            "restype": restype,
            "atom_to_token_map": torch.arange(L, device=device),
            "is_ca": torch.ones(L, dtype=torch.bool, device=device),
            "ref_pos": torch.zeros(L, 3, device=device),
            "ref_charge": torch.zeros(L, device=device),
            "ref_mask": torch.ones(L, device=device),
            "ref_element": torch.zeros(L, 128, device=device),
            "ref_atom_name_chars": torch.zeros(L, 4, 64, device=device),
        }

    @staticmethod
    def _infer_sizes(f: dict):
        """Return ``(num_atoms, num_tokens)`` for a pre-built feature dict.

        ``num_atoms`` is the leading dim of ``ref_element`` (falling back to
        ``restype``), and ``num_tokens`` is the leading dim of ``restype``
        (falling back to ``num_atoms``).

        Raises:
            ValueError: if neither ``ref_element`` nor ``restype`` is available.
        """
        atom_feature = f.get("ref_element", f.get("restype"))
        if atom_feature is None:
            raise ValueError(
                "Pre-built feature dict `f` must contain 'ref_element' or 'restype' to infer its size."
            )
        num_atoms = atom_feature.shape[0]
        restype = f.get("restype")
        # Per-token features define I; for full-atom inputs this differs from L.
        num_tokens = restype.shape[0] if restype is not None else num_atoms
        return num_atoms, num_tokens

    @torch.no_grad()
    def __call__(self, components, state):
        """Populate ``f``, ``L`` (atoms) and ``I`` (tokens) on the pipeline state.

        Raises:
            ValueError: if neither ``f`` nor ``sequence`` is provided, or ``f``
                lacks the keys needed to infer its size.
        """
        block_state = self.get_block_state(state)

        f = block_state.f
        if f is None:
            sequence = block_state.sequence
            if sequence is None:
                raise ValueError("RF3InputStep needs either `sequence` or a pre-built feature dict `f`.")
            f = self._build_ca_only_features(sequence)
            # CA-only: one pseudo-atom per residue, so atoms == tokens.
            num_atoms = num_tokens = len(sequence)
        else:
            num_atoms, num_tokens = self._infer_sizes(f)

        block_state.f = f
        block_state.L = num_atoms
        block_state.I = num_tokens

        self.set_block_state(state, block_state)
        return components, state


class RF3SetTimestepsStep(ModularPipelineBlocks):
    """Set up EDM noise schedule for RF3.

    Uses the ``scheduler`` component's schedule when one is loaded; otherwise
    builds a fallback EDM (Karras, rho-power) sigma schedule whose length is
    ``num_inference_steps`` (default 200).
    """

    model_name = "rf3"

    # Fallback schedule parameters, used only when no scheduler component is loaded.
    # The endpoints are the same values the previous linear fallback used
    # (presumably sigma_data=16 times s_max=160 / s_min=4e-4, AF3-style — confirm).
    _FALLBACK_SIGMA_MAX = 160.0 * 16.0
    _FALLBACK_SIGMA_MIN = 4e-4 * 16.0
    _FALLBACK_RHO = 7.0
    _FALLBACK_NUM_STEPS = 200

    @property
    def description(self) -> str:
        return "Construct EDM noise schedule for RF3 diffusion sampling."

    @property
    def expected_components(self) -> List[ComponentSpec]:
        return [ComponentSpec("scheduler", description="RF3 EDM scheduler")]

    @property
    def inputs(self) -> List[InputParam]:
        return [
            InputParam("num_inference_steps", default=None, type_hint=int),
            InputParam("L", required=True, type_hint=int),
        ]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        return [
            OutputParam("noise_schedule", type_hint=torch.Tensor),
            OutputParam("num_inference_steps", type_hint=int),
        ]

    @staticmethod
    def _edm_sigmas(num_steps: int, sigma_max: float, sigma_min: float, rho: float) -> torch.Tensor:
        """Return ``num_steps`` decreasing sigmas from ``sigma_max`` to ``sigma_min``.

        Interpolates linearly in ``sigma ** (1 / rho)`` space (EDM schedule), which
        concentrates steps at low noise levels. A single step yields ``[sigma_max]``.

        Raises:
            ValueError: if ``num_steps < 1``.
        """
        if num_steps < 1:
            raise ValueError(f"num_inference_steps must be >= 1, got {num_steps}.")
        ramp = torch.linspace(0.0, 1.0, num_steps)
        max_inv_rho = sigma_max ** (1.0 / rho)
        min_inv_rho = sigma_min ** (1.0 / rho)
        return (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho

    @torch.no_grad()
    def __call__(self, components, state):
        """Write ``noise_schedule`` and its length ``num_inference_steps`` to the state."""
        block_state = self.get_block_state(state)
        requested_steps = block_state.num_inference_steps

        scheduler = getattr(components, "scheduler", None)
        if scheduler is not None:
            noise_schedule = scheduler.get_noise_schedule()
            if requested_steps is not None and requested_steps != len(noise_schedule):
                # get_noise_schedule() takes no step count here, so the request
                # cannot be honoured; make the mismatch visible instead of silent.
                logger.warning(
                    "num_inference_steps=%s was requested but the scheduler produced %d steps; using the scheduler's schedule.",
                    requested_steps,
                    len(noise_schedule),
                )
        else:
            num_steps = self._FALLBACK_NUM_STEPS if requested_steps is None else requested_steps
            noise_schedule = self._edm_sigmas(
                num_steps, self._FALLBACK_SIGMA_MAX, self._FALLBACK_SIGMA_MIN, self._FALLBACK_RHO
            )

        block_state.noise_schedule = noise_schedule
        block_state.num_inference_steps = len(noise_schedule)

        self.set_block_state(state, block_state)
        return components, state


class RF3RecyclingStep(ModularPipelineBlocks):
    """Run the recycling trunk (pairformer + MSA + templates).

    Calls the ``transformer`` component with the feature dict and recycle count,
    and copies its ``single``/``pair``/``distogram`` (and ``s_inputs`` when the
    output exposes it) onto the pipeline state. Without a transformer, all
    outputs are ``None`` and a warning is logged.
    """

    model_name = "rf3"

    @property
    def description(self) -> str:
        return "Run RF3 recycling trunk to produce single/pair representations."

    @property
    def expected_components(self) -> List[ComponentSpec]:
        return [ComponentSpec("transformer", description="RF3 transformer model")]

    @property
    def inputs(self) -> List[InputParam]:
        return [
            InputParam("f", required=True, type_hint=dict),
            InputParam("n_recycles", default=None, type_hint=int),
        ]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        return [
            OutputParam("single", type_hint=torch.Tensor, description="Single representation [I, c_s]"),
            OutputParam("pair", type_hint=torch.Tensor, description="Pair representation [I, I, c_z]"),
            OutputParam("s_inputs", type_hint=torch.Tensor, description="Input embeddings [I, c_s_inputs]"),
            OutputParam("distogram", type_hint=torch.Tensor, description="Distogram prediction [I, I, bins]"),
        ]

    @torch.no_grad()
    def __call__(self, components, state):
        """Run the trunk (if loaded) and store its representations on the state."""
        block_state = self.get_block_state(state)

        transformer = getattr(components, "transformer", None)
        if transformer is not None:
            output = transformer(f=block_state.f, n_recycles=block_state.n_recycles)
            block_state.single = output.single
            block_state.pair = output.pair
            block_state.distogram = output.distogram
            # s_inputs is a declared output; forward it when the model output
            # carries it rather than unconditionally discarding it.
            block_state.s_inputs = getattr(output, "s_inputs", None)
        else:
            # Placeholder path: downstream steps will receive None, so say so.
            logger.warning("RF3RecyclingStep: no transformer loaded; single/pair/distogram/s_inputs are set to None.")
            block_state.single = None
            block_state.pair = None
            block_state.distogram = None
            block_state.s_inputs = None

        self.set_block_state(state, block_state)
        return components, state


class RF3PrepareLatentsStep(ModularPipelineBlocks):
    """Prepare initial noised coordinates for diffusion sampling.

    Draws standard Gaussian noise of shape ``[D, L, 3]`` and scales it by the
    first entry of ``noise_schedule``. ``D`` is ``diffusion_batch_size``
    (falsy values fall back to 5).
    """

    model_name = "rf3"

    @property
    def description(self) -> str:
        return "Sample initial Gaussian noise scaled by the first noise schedule value."

    @property
    def inputs(self) -> List[InputParam]:
        return [
            InputParam("generator", type_hint=torch.Generator),
            InputParam("diffusion_batch_size", default=5, type_hint=int),
            InputParam("L", required=True, type_hint=int),
            InputParam("noise_schedule", required=True, type_hint=torch.Tensor),
        ]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        return [
            OutputParam("xyz", type_hint=torch.Tensor, description="Initial noised coords [D, L, 3]"),
        ]

    @torch.no_grad()
    def __call__(self, components, state):
        """Write initial coordinates ``xyz`` (CUDA if available, else CPU) to the state."""
        block_state = self.get_block_state(state)

        L = block_state.L
        noise_schedule = block_state.noise_schedule
        D = block_state.diffusion_batch_size or 5
        generator = block_state.generator

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # torch.randn requires the generator to live on the device it samples on.
        # A default torch.Generator() is CPU-only, so sample where the generator
        # lives (keeps seeds reproducible across hosts) and then move the result.
        noise_device = generator.device if generator is not None else device
        noise = torch.randn((D, L, 3), device=noise_device, generator=generator).to(device)

        c0 = noise_schedule[0]
        block_state.xyz = c0 * noise

        self.set_block_state(state, block_state)
        return components, state