# RosettaFold-3 / denoise.py
# NOTE: provenance — uploaded via huggingface_hub by dn6 (HF Staff), revision a376829 (verified).
# Copyright 2025 Dhruv Nair. All rights reserved.
# Licensed under the Apache License, Version 2.0
"""
Denoising loop for RF3.
Same EDM stochastic sampling as RFD3, but conditioned on trunk
representations (single S_I, pair Z_II) from the recycling step.
"""
from typing import Callable, List, Tuple

import torch

from diffusers.modular_pipelines import ModularPipeline, ModularPipelineBlocks, PipelineState
from diffusers.modular_pipelines.modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
from diffusers.utils import logging
logger = logging.get_logger(__name__)
class RF3DenoiseStep(ModularPipelineBlocks):
    """
    Iterative denoising step for RF3.

    Runs the EDM-style stochastic sampling loop over ``noise_schedule``,
    using trunk representations from the recycling step (``single``,
    ``pair``, ``s_inputs``) as conditioning for the diffusion module at
    every denoising step.  When the transformer and/or scheduler
    components are absent the loop degrades gracefully: missing model
    means identity denoising, missing scheduler means a plain Euler
    update with no extra noise injection.
    """

    model_name = "rf3"

    @property
    def description(self) -> str:
        return "Iteratively denoise protein structure conditioned on sequence/MSA representations."

    @property
    def expected_components(self) -> List[ComponentSpec]:
        # Both components are optional at runtime (see __call__); they are
        # declared here so the pipeline can wire them in when available.
        return [
            ComponentSpec("transformer", description="RF3 transformer (provides diffusion_module)"),
            ComponentSpec("scheduler", description="RF3 EDM scheduler"),
        ]

    @property
    def inputs(self) -> List[InputParam]:
        return [
            InputParam("xyz", required=True, type_hint=torch.Tensor, description="Initial noised coords [D, L, 3]"),
            InputParam("noise_schedule", required=True, type_hint=torch.Tensor),
            InputParam("f", required=True, type_hint=dict, description="Feature dictionary"),
            InputParam("single", type_hint=torch.Tensor, description="Trunk single repr [I, c_s]"),
            InputParam("pair", type_hint=torch.Tensor, description="Trunk pair repr [I, I, c_z]"),
            InputParam("s_inputs", type_hint=torch.Tensor, description="Input embeddings [I, c_s_inputs]"),
            InputParam("callback", type_hint=Callable),
            InputParam("callback_steps", default=1, type_hint=int),
        ]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        return [
            OutputParam("xyz", type_hint=torch.Tensor, description="Denoised coords [D, L, 3]"),
            OutputParam("trajectory", type_hint=List[torch.Tensor]),
        ]

    @torch.no_grad()
    def __call__(self, components: ModularPipeline, state: PipelineState) -> Tuple[ModularPipeline, PipelineState]:
        """
        Execute the denoising loop and write results back into ``state``.

        Parameters
        ----------
        components:
            Pipeline component container; ``transformer`` and ``scheduler``
            are looked up on it but are allowed to be missing/None.
        state:
            Pipeline state holding the declared inputs.

        Returns
        -------
        The ``(components, state)`` pair, per the modular-pipeline block
        convention.  (The previous annotation claimed ``PipelineState``
        alone, which did not match the actual return value.)
        """
        block_state = self.get_block_state(state)
        xyz = block_state.xyz
        noise_schedule = block_state.noise_schedule
        f = block_state.f
        single = block_state.single
        pair = block_state.pair
        s_inputs = block_state.s_inputs
        callback = block_state.callback
        # Guard against callback_steps == 0/None (would raise ZeroDivisionError
        # in the modulo below).
        callback_steps = block_state.callback_steps or 1

        X_L = xyz.clone()
        D = X_L.shape[0]  # NOTE(review): presumably the diffusion batch dim — confirm
        device = X_L.device
        noise_schedule = noise_schedule.to(device)
        trajectory = []

        # Components are optional: fall back to identity/plain-Euler behavior
        # when they are not wired into the pipeline.
        has_transformer = hasattr(components, "transformer") and components.transformer is not None
        has_scheduler = hasattr(components, "scheduler") and components.scheduler is not None

        for step_num in range(len(noise_schedule) - 1):
            c_t_minus_1 = noise_schedule[step_num]
            c_t = noise_schedule[step_num + 1]

            # Noise injection (EDM "churn"); without a scheduler we keep the
            # current coordinates and use the schedule value as t_hat.
            if has_scheduler:
                X_noisy, t_hat = components.scheduler.add_noise(X_L, c_t_minus_1, c_t)
            else:
                X_noisy = X_L
                t_hat = c_t_minus_1

            # Model forward: diffusion module conditioned on the trunk
            # representations produced by the recycling step.
            if has_transformer:
                t_batch = (t_hat.to(device).expand(D) if isinstance(t_hat, torch.Tensor)
                           else torch.full((D,), t_hat, device=device))
                outs = components.transformer.diffusion_module(
                    X_noisy_L=X_noisy,
                    t=t_batch,
                    f=f,
                    S_inputs_I=s_inputs,
                    S_trunk_I=single,
                    Z_trunk_II=pair,
                )
                # HACK: if `outs` is a dict without "X_L", this falls back to
                # the dict itself and will fail later at .clone() with an
                # unrelated AttributeError — kept for compatibility, but the
                # diffusion module is expected to return a tensor or a
                # mapping containing "X_L".
                X_denoised = outs if isinstance(outs, torch.Tensor) else outs.get("X_L", outs)
            else:
                X_denoised = X_noisy

            # Euler step toward the denoised prediction.
            if has_scheduler:
                X_L = components.scheduler.step(
                    xyz_pred=X_denoised, xyz_noisy=X_noisy,
                    c_t_minus_1=c_t_minus_1, c_t=c_t,
                )
            else:
                # d/dt of the probability-flow ODE; epsilon avoids division
                # by zero at t_hat == 0 (final schedule entry).
                delta = (X_noisy - X_denoised) / (t_hat + 1e-8)
                d_t = c_t - t_hat
                X_L = X_noisy + d_t * delta

            trajectory.append(X_denoised.clone())

            # NOTE(review): the callback receives c_t_minus_1 even though X_L
            # has already advanced to sigma c_t — looks intentional ("sigma at
            # start of step"), but confirm against callback consumers.
            if callback is not None and step_num % callback_steps == 0:
                callback(step_num, c_t_minus_1, X_L)

        block_state.xyz = X_L
        block_state.trajectory = trajectory
        self.set_block_state(state, block_state)
        return components, state