File size: 4,981 Bytes
a376829
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
# Copyright 2025 Dhruv Nair. All rights reserved.
# Licensed under the Apache License, Version 2.0

"""
Denoising loop for RF3.

Same EDM stochastic sampling as RFD3, but conditioned on trunk
representations (single S_I, pair Z_II) from the recycling step.
"""

from typing import Callable, List

import torch

from diffusers.utils import logging
from diffusers.modular_pipelines import ModularPipeline, ModularPipelineBlocks, PipelineState
from diffusers.modular_pipelines.modular_pipeline_utils import ComponentSpec, InputParam, OutputParam


logger = logging.get_logger(__name__)


class RF3DenoiseStep(ModularPipelineBlocks):
    """RF3 iterative denoising block.

    Runs the EDM-style sampling loop over the noise schedule, conditioning
    the diffusion module at every step on the trunk representations
    (single/pair) produced by the recycling stage.
    """

    model_name = "rf3"

    @property
    def description(self) -> str:
        return "Iteratively denoise protein structure conditioned on sequence/MSA representations."

    @property
    def expected_components(self) -> List[ComponentSpec]:
        specs = [
            ComponentSpec("transformer", description="RF3 transformer (provides diffusion_module)"),
            ComponentSpec("scheduler", description="RF3 EDM scheduler"),
        ]
        return specs

    @property
    def inputs(self) -> List[InputParam]:
        params = [
            InputParam("xyz", required=True, type_hint=torch.Tensor, description="Initial noised coords [D, L, 3]"),
            InputParam("noise_schedule", required=True, type_hint=torch.Tensor),
            InputParam("f", required=True, type_hint=dict, description="Feature dictionary"),
            InputParam("single", type_hint=torch.Tensor, description="Trunk single repr [I, c_s]"),
            InputParam("pair", type_hint=torch.Tensor, description="Trunk pair repr [I, I, c_z]"),
            InputParam("s_inputs", type_hint=torch.Tensor, description="Input embeddings [I, c_s_inputs]"),
            InputParam("callback", type_hint=Callable),
            InputParam("callback_steps", default=1, type_hint=int),
        ]
        return params

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        outs = [
            OutputParam("xyz", type_hint=torch.Tensor, description="Denoised coords [D, L, 3]"),
            OutputParam("trajectory", type_hint=List[torch.Tensor]),
        ]
        return outs

    @torch.no_grad()
    def __call__(self, components: ModularPipeline, state: PipelineState) -> PipelineState:
        block_state = self.get_block_state(state)

        # Pull everything the loop needs out of the block state up front.
        features = block_state.f
        single_repr = block_state.single
        pair_repr = block_state.pair
        input_embeds = block_state.s_inputs
        callback = block_state.callback
        callback_every = block_state.callback_steps or 1  # guard against 0/None

        coords = block_state.xyz.clone()
        n_diffusion = coords.shape[0]
        device = coords.device
        sigmas = block_state.noise_schedule.to(device)

        # Components are optional: a missing transformer/scheduler degrades to
        # identity forward / plain-Euler stepping respectively.
        transformer_ready = getattr(components, "transformer", None) is not None
        scheduler_ready = getattr(components, "scheduler", None) is not None

        trajectory = []

        for step_num in range(len(sigmas) - 1):
            sigma_prev = sigmas[step_num]
            sigma_cur = sigmas[step_num + 1]

            # --- Noise injection ---
            if scheduler_ready:
                noisy, t_hat = components.scheduler.add_noise(coords, sigma_prev, sigma_cur)
            else:
                noisy, t_hat = coords, sigma_prev

            # --- Diffusion-module forward, conditioned on trunk reprs ---
            if transformer_ready:
                if isinstance(t_hat, torch.Tensor):
                    t_batch = t_hat.to(device).expand(n_diffusion)
                else:
                    t_batch = torch.full((n_diffusion,), t_hat, device=device)

                outs = components.transformer.diffusion_module(
                    X_noisy_L=noisy,
                    t=t_batch,
                    f=features,
                    S_inputs_I=input_embeds,
                    S_trunk_I=single_repr,
                    Z_trunk_II=pair_repr,
                )
                # Module may return the tensor directly or a dict keyed "X_L".
                if isinstance(outs, torch.Tensor):
                    denoised = outs
                else:
                    denoised = outs.get("X_L", outs)
            else:
                denoised = noisy

            # --- Euler update toward the denoised prediction ---
            if scheduler_ready:
                coords = components.scheduler.step(
                    xyz_pred=denoised, xyz_noisy=noisy,
                    c_t_minus_1=sigma_prev, c_t=sigma_cur,
                )
            else:
                # Manual Euler fallback when no scheduler component is attached.
                slope = (noisy - denoised) / (t_hat + 1e-8)
                coords = noisy + (sigma_cur - t_hat) * slope

            trajectory.append(denoised.clone())

            if callback is not None and step_num % callback_every == 0:
                callback(step_num, sigma_prev, coords)

        block_state.xyz = coords
        block_state.trajectory = trajectory

        self.set_block_state(state, block_state)
        return components, state