| |
| |
|
|
| """ |
| Denoising loop for RF3. |
| |
| Same EDM stochastic sampling as RFD3, but conditioned on trunk |
| representations (single S_I, pair Z_II) from the recycling step. |
| """ |
|
|
| from typing import Callable, List |
|
|
| import torch |
|
|
| from diffusers.utils import logging |
| from diffusers.modular_pipelines import ModularPipeline, ModularPipelineBlocks, PipelineState |
| from diffusers.modular_pipelines.modular_pipeline_utils import ComponentSpec, InputParam, OutputParam |
|
|
|
|
# Module-level logger obtained from diffusers' logging utilities.
logger = logging.get_logger(__name__)
|
|
|
|
class RF3DenoiseStep(ModularPipelineBlocks):
    """
    Iterative denoising step for RF3.

    Uses trunk representations from the recycling step as conditioning
    for the diffusion module at each denoising step.

    For each consecutive pair ``(noise_schedule[i], noise_schedule[i + 1])``
    the step:

    1. re-noises the current coordinates with ``scheduler.add_noise``
       (skipped when no scheduler is set),
    2. predicts denoised coordinates with ``transformer.diffusion_module``
       conditioned on ``s_inputs`` / ``single`` / ``pair`` (skipped when no
       transformer is set; the noisy coordinates are then used as-is), and
    3. advances the coordinates with ``scheduler.step``, or with a plain
       Euler update when no scheduler is set.
    """

    model_name = "rf3"

    @property
    def description(self) -> str:
        return "Iteratively denoise protein structure conditioned on sequence/MSA representations."

    @property
    def expected_components(self) -> List[ComponentSpec]:
        return [
            ComponentSpec("transformer", description="RF3 transformer (provides diffusion_module)"),
            ComponentSpec("scheduler", description="RF3 EDM scheduler"),
        ]

    @property
    def inputs(self) -> List[InputParam]:
        return [
            InputParam("xyz", required=True, type_hint=torch.Tensor, description="Initial noised coords [D, L, 3]"),
            InputParam("noise_schedule", required=True, type_hint=torch.Tensor),
            InputParam("f", required=True, type_hint=dict, description="Feature dictionary"),
            InputParam("single", type_hint=torch.Tensor, description="Trunk single repr [I, c_s]"),
            InputParam("pair", type_hint=torch.Tensor, description="Trunk pair repr [I, I, c_z]"),
            InputParam("s_inputs", type_hint=torch.Tensor, description="Input embeddings [I, c_s_inputs]"),
            InputParam("callback", type_hint=Callable),
            InputParam("callback_steps", default=1, type_hint=int),
        ]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        return [
            OutputParam("xyz", type_hint=torch.Tensor, description="Denoised coords [D, L, 3]"),
            OutputParam("trajectory", type_hint=List[torch.Tensor]),
        ]

    @staticmethod
    def _extract_prediction(outs) -> torch.Tensor:
        """Return the denoised-coordinate tensor from a ``diffusion_module`` output.

        Accepts either a bare tensor or a dict-like object (anything exposing
        ``.get``) that stores the prediction under the ``"X_L"`` key.

        Raises:
            TypeError: if no tensor prediction can be found. Failing here
                points at the model output instead of letting a non-tensor
                reach the scheduler / Euler update further down.
        """
        if isinstance(outs, torch.Tensor):
            return outs

        getter = getattr(outs, "get", None)
        if callable(getter):
            prediction = getter("X_L")
            if isinstance(prediction, torch.Tensor):
                return prediction
            keys = list(outs.keys()) if hasattr(outs, "keys") else "unknown"
            raise TypeError(
                "diffusion_module returned a mapping without a tensor under 'X_L' "
                f"(available keys: {keys})."
            )

        raise TypeError(
            f"Unsupported diffusion_module output of type {type(outs).__name__}; "
            "expected a tensor or a mapping with an 'X_L' entry."
        )

    @staticmethod
    def _predict_denoised(diffusion_module, X_noisy, t_hat, num_samples, device, *, f, s_inputs, single, pair):
        """Run ``diffusion_module`` once and return its denoised-coordinate prediction.

        ``t_hat`` may be a tensor or a Python number; it is broadcast to a
        1-D tensor with ``num_samples`` entries on ``device`` before the call.
        """
        if isinstance(t_hat, torch.Tensor):
            t_batch = t_hat.to(device).expand(num_samples)
        else:
            t_batch = torch.full((num_samples,), t_hat, device=device)

        outs = diffusion_module(
            X_noisy_L=X_noisy,
            t=t_batch,
            f=f,
            S_inputs_I=s_inputs,
            S_trunk_I=single,
            Z_trunk_II=pair,
        )
        return RF3DenoiseStep._extract_prediction(outs)

    @torch.no_grad()
    def __call__(self, components: ModularPipeline, state: PipelineState) -> PipelineState:
        """Run the denoising loop and write ``xyz`` / ``trajectory`` into the block state.

        ``xyz`` receives the final coordinates. ``trajectory`` receives a clone
        of the denoised prediction from every step, in order. If ``callback``
        is given, it is called as ``callback(step_num, c_t_minus_1, X_L)``
        every ``callback_steps`` steps.

        Returns:
            The ``(components, state)`` pair (a tuple, despite the
            ``PipelineState`` return annotation).
        """
        block_state = self.get_block_state(state)

        f = block_state.f
        single = block_state.single
        pair = block_state.pair
        s_inputs = block_state.s_inputs
        callback = block_state.callback
        # Treat None / 0 / negative values as "every step" so the modulo below
        # is always well defined.
        callback_steps = max(block_state.callback_steps or 1, 1)

        X_L = block_state.xyz.clone()
        num_samples = X_L.shape[0]
        device = X_L.device
        # Tensors are only moved to `device`; lists/arrays of noise levels are also accepted.
        noise_schedule = torch.as_tensor(block_state.noise_schedule, device=device)

        transformer = getattr(components, "transformer", None)
        scheduler = getattr(components, "scheduler", None)

        # These fallbacks used to be silent; surface them since they change the output.
        if transformer is None:
            logger.warning(
                "RF3DenoiseStep: no `transformer` component is loaded; the diffusion module is "
                "skipped and the noisy coordinates are used as the prediction."
            )
        if scheduler is None:
            logger.warning(
                "RF3DenoiseStep: no `scheduler` component is loaded; no noise is injected and a "
                "plain Euler update is used for each step."
            )

        trajectory = []

        for step_num in range(len(noise_schedule) - 1):
            c_t_minus_1 = noise_schedule[step_num]
            c_t = noise_schedule[step_num + 1]

            # 1) Optionally inject noise; t_hat is the noise level the prediction is made at.
            if scheduler is not None:
                X_noisy, t_hat = scheduler.add_noise(X_L, c_t_minus_1, c_t)
            else:
                X_noisy, t_hat = X_L, c_t_minus_1

            # 2) Predict denoised coordinates (identity fallback without a transformer).
            if transformer is not None:
                X_denoised = self._predict_denoised(
                    transformer.diffusion_module,
                    X_noisy,
                    t_hat,
                    num_samples,
                    device,
                    f=f,
                    s_inputs=s_inputs,
                    single=single,
                    pair=pair,
                )
            else:
                X_denoised = X_noisy

            # 3) Advance to the next noise level.
            if scheduler is not None:
                X_L = scheduler.step(
                    xyz_pred=X_denoised,
                    xyz_noisy=X_noisy,
                    c_t_minus_1=c_t_minus_1,
                    c_t=c_t,
                )
            else:
                # Euler update: direction d = (X_noisy - X_denoised) / t_hat, step size (c_t - t_hat).
                # The epsilon guards against division by a zero noise level.
                delta = (X_noisy - X_denoised) / (t_hat + 1e-8)
                X_L = X_noisy + (c_t - t_hat) * delta

            trajectory.append(X_denoised.clone())

            if callback is not None and step_num % callback_steps == 0:
                callback(step_num, c_t_minus_1, X_L)

        block_state.xyz = X_L
        block_state.trajectory = trajectory

        self.set_block_state(state, block_state)
        return components, state
|
|