# Copyright 2025 Dhruv Nair. All rights reserved.
# Licensed under the Apache License, Version 2.0
"""
Denoising loop for RF3.
Same EDM stochastic sampling as RFD3, but conditioned on trunk
representations (single S_I, pair Z_II) from the recycling step.
"""
from typing import Callable, List, Tuple

import torch
from diffusers.modular_pipelines import ModularPipeline, ModularPipelineBlocks, PipelineState
from diffusers.modular_pipelines.modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
from diffusers.utils import logging
logger = logging.get_logger(__name__)
class RF3DenoiseStep(ModularPipelineBlocks):
    """
    Iterative denoising step for RF3.

    Uses the trunk representations from the recycling step (``single``, ``pair``) together with the
    input embeddings (``s_inputs``) as conditioning for the diffusion module at every denoising step.

    Step ``i`` moves from noise level ``noise_schedule[i]`` to ``noise_schedule[i + 1]``:

    1. Noise injection via ``components.scheduler.add_noise``, which returns the noised coordinates
       and a noise level ``t_hat``. Without a scheduler the coordinates and ``noise_schedule[i]``
       pass through unchanged.
    2. Denoising via ``components.transformer.diffusion_module``. Without a transformer the noised
       coordinates are used as the "denoised" prediction, so the update below is a no-op.
    3. An update via ``components.scheduler.step``. Without a scheduler a plain Euler step
       ``x_noisy + (c_t - t_hat) * (x_noisy - x_denoised) / t_hat`` is applied.

    Outputs the final coordinates (``xyz``) and the per-step denoised predictions (``trajectory``).
    """

    model_name = "rf3"

    @property
    def description(self) -> str:
        return "Iteratively denoise protein structure conditioned on sequence/MSA representations."

    @property
    def expected_components(self) -> List[ComponentSpec]:
        return [
            ComponentSpec("transformer", description="RF3 transformer (provides diffusion_module)"),
            ComponentSpec("scheduler", description="RF3 EDM scheduler"),
        ]

    @property
    def inputs(self) -> List[InputParam]:
        return [
            InputParam("xyz", required=True, type_hint=torch.Tensor, description="Initial noised coords [D, L, 3]"),
            InputParam("noise_schedule", required=True, type_hint=torch.Tensor),
            InputParam("f", required=True, type_hint=dict, description="Feature dictionary"),
            InputParam("single", type_hint=torch.Tensor, description="Trunk single repr [I, c_s]"),
            InputParam("pair", type_hint=torch.Tensor, description="Trunk pair repr [I, I, c_z]"),
            InputParam("s_inputs", type_hint=torch.Tensor, description="Input embeddings [I, c_s_inputs]"),
            InputParam("callback", type_hint=Callable),
            InputParam("callback_steps", default=1, type_hint=int),
        ]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        return [
            OutputParam("xyz", type_hint=torch.Tensor, description="Denoised coords [D, L, 3]"),
            OutputParam("trajectory", type_hint=List[torch.Tensor]),
        ]

    @staticmethod
    def _time_batch(t_hat, batch_size: int, device: torch.device) -> torch.Tensor:
        """Broadcast the noise level ``t_hat`` (tensor or Python number) to a ``[batch_size]`` tensor on ``device``."""
        if isinstance(t_hat, torch.Tensor):
            return t_hat.to(device).expand(batch_size)
        return torch.full((batch_size,), t_hat, device=device)

    @staticmethod
    def _extract_denoised(outs) -> torch.Tensor:
        """
        Return the denoised coordinates from the diffusion module's output.

        Accepts either a bare tensor or a mapping-like object (anything with ``.get``) that holds
        the coordinates under ``"X_L"``.

        Raises:
            KeyError: if a mapping-like output has no ``"X_L"`` entry.
            TypeError: if the output is neither a tensor nor mapping-like.
        """
        if isinstance(outs, torch.Tensor):
            return outs
        if hasattr(outs, "get"):
            denoised = outs.get("X_L")
            if denoised is None:
                raise KeyError(
                    f"diffusion_module returned a {type(outs).__name__} without an 'X_L' entry; "
                    "expected the denoised coordinates under 'X_L'."
                )
            return denoised
        raise TypeError(
            f"diffusion_module returned an unsupported type {type(outs).__name__}; "
            "expected a torch.Tensor or a mapping with an 'X_L' entry."
        )

    @torch.no_grad()
    def __call__(self, components: ModularPipeline, state: PipelineState) -> Tuple[ModularPipeline, PipelineState]:
        """
        Run the denoising loop and write ``xyz`` and ``trajectory`` back into ``state``.

        Args:
            components: Pipeline that may provide ``transformer`` and ``scheduler``. Either may be
                absent or ``None``, in which case the pass-through / plain-Euler fallbacks described
                on the class are used and a warning is logged.
            state: Pipeline state supplying the inputs declared in ``inputs``.

        Returns:
            The ``(components, state)`` pair, with ``state`` updated via ``set_block_state``.

        Raises:
            KeyError / TypeError: if the diffusion module's output does not contain coordinates
                (see ``_extract_denoised``).
        """
        block_state = self.get_block_state(state)

        # Clone so the loop never mutates the caller's initial coordinates.
        X_L = block_state.xyz.clone()
        D = X_L.shape[0]
        device = X_L.device
        noise_schedule = block_state.noise_schedule.to(device)
        f = block_state.f
        single = block_state.single
        pair = block_state.pair
        s_inputs = block_state.s_inputs
        callback = block_state.callback
        # None or 0 both mean "invoke the callback on every step".
        callback_steps = block_state.callback_steps or 1

        transformer = getattr(components, "transformer", None)
        scheduler = getattr(components, "scheduler", None)
        if transformer is None:
            logger.warning("No `transformer` component is loaded; RF3DenoiseStep will not denoise the coordinates.")
        if scheduler is None:
            logger.warning(
                "No `scheduler` component is loaded; RF3DenoiseStep skips noise injection and uses a plain Euler step."
            )

        trajectory = []
        for step_num in range(len(noise_schedule) - 1):
            c_t_minus_1 = noise_schedule[step_num]
            c_t = noise_schedule[step_num + 1]

            # 1. Noise injection (level semantics of t_hat are defined by the scheduler).
            if scheduler is not None:
                X_noisy, t_hat = scheduler.add_noise(X_L, c_t_minus_1, c_t)
            else:
                X_noisy, t_hat = X_L, c_t_minus_1

            # 2. Denoise, conditioned on the trunk representations.
            if transformer is not None:
                outs = transformer.diffusion_module(
                    X_noisy_L=X_noisy,
                    t=self._time_batch(t_hat, D, device),
                    f=f,
                    S_inputs_I=s_inputs,
                    S_trunk_I=single,
                    Z_trunk_II=pair,
                )
                X_denoised = self._extract_denoised(outs)
            else:
                X_denoised = X_noisy

            # 3. Update towards the next noise level c_t.
            if scheduler is not None:
                X_L = scheduler.step(
                    xyz_pred=X_denoised,
                    xyz_noisy=X_noisy,
                    c_t_minus_1=c_t_minus_1,
                    c_t=c_t,
                )
            else:
                # The small epsilon guards against division by zero when t_hat == 0.
                delta = (X_noisy - X_denoised) / (t_hat + 1e-8)
                X_L = X_noisy + (c_t - t_hat) * delta

            trajectory.append(X_denoised.clone())
            if callback is not None and step_num % callback_steps == 0:
                callback(step_num, c_t_minus_1, X_L)

        block_state.xyz = X_L
        block_state.trajectory = trajectory
        self.set_block_state(state, block_state)
        return components, state
|