# Copyright 2025 Dhruv Nair. All rights reserved.
# Licensed under the Apache License, Version 2.0
"""
Pre-denoising steps for RF3: input processing, timestep setup, recycling trunk,
latent preparation.
"""

from typing import List, Tuple

import torch

from diffusers.utils import logging
from diffusers.modular_pipelines import ModularPipeline, ModularPipelineBlocks, PipelineState
from diffusers.modular_pipelines.modular_pipeline_utils import ComponentSpec, InputParam, OutputParam

logger = logging.get_logger(__name__)

# One-letter amino-acid codes, in the column order used by the `restype` one-hot.
_AA_ORDER = "ARNDCQEGHILKMFPSTWYV"
_AA_TO_INDEX = {aa: i for i, aa in enumerate(_AA_ORDER)}
# `restype` column used for any letter not found in _AA_ORDER.
_UNKNOWN_RESTYPE = 20
# Width of the `restype` one-hot built from a bare sequence.
_NUM_RESTYPE_CLASSES = 32

# Fallback noise schedule used only when no `scheduler` component is loaded.
# NOTE(review): this is a linear spacing, not the rho-warped schedule an EDM
# scheduler usually produces — confirm it is only meant as a placeholder.
_FALLBACK_SCHEDULE_START = 160.0 * 16.0
_FALLBACK_SCHEDULE_END = 4e-4 * 16.0
_FALLBACK_NUM_STEPS = 200

_DEFAULT_DIFFUSION_BATCH_SIZE = 5


def _default_device() -> torch.device:
    """Return the CUDA device when available, otherwise the CPU device."""
    return torch.device("cuda" if torch.cuda.is_available() else "cpu")


class RF3InputStep(ModularPipelineBlocks):
    """Parse sequence input and prepare feature dict for RF3.

    Either builds a minimal CA-only feature dict (one atom and one token per
    residue) from ``sequence``, or passes a caller-supplied feature dict ``f``
    through unchanged. It then publishes the atom count ``L`` and the token
    count ``I``.
    """

    model_name = "rf3"

    @property
    def description(self) -> str:
        return "Parse sequence and optional MSA/template inputs for structure prediction."

    @property
    def inputs(self) -> List[InputParam]:
        return [
            # Optional: a pre-built `f` makes the sequence unnecessary; __call__
            # raises a clear error when neither is provided.
            InputParam("sequence", type_hint=str, description="Amino acid sequence (one-letter codes)"),
            InputParam("f", type_hint=dict, description="Pre-built feature dict (overrides sequence)"),
        ]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        return [
            OutputParam("f", type_hint=dict, description="Feature dictionary for RF3"),
            OutputParam("L", type_hint=int, description="Sequence length (num atoms)"),
            OutputParam("I", type_hint=int, description="Num tokens"),
        ]

    @staticmethod
    def _build_minimal_features(sequence: str, device: torch.device) -> dict:
        """Build a CA-only feature dict with one atom per residue of ``sequence``."""
        num_residues = len(sequence)
        # Case-insensitive lookup; unrecognized letters go to the unknown column.
        indices = torch.tensor(
            [_AA_TO_INDEX.get(aa.upper(), _UNKNOWN_RESTYPE) for aa in sequence],
            dtype=torch.long,
            device=device,
        )
        restype = torch.zeros(num_residues, _NUM_RESTYPE_CLASSES, device=device)
        restype[torch.arange(num_residues, device=device), indices] = 1.0
        return {
            "restype": restype,
            "atom_to_token_map": torch.arange(num_residues, device=device),
            "is_ca": torch.ones(num_residues, dtype=torch.bool, device=device),
            "ref_pos": torch.zeros(num_residues, 3, device=device),
            "ref_charge": torch.zeros(num_residues, device=device),
            "ref_mask": torch.ones(num_residues, device=device),
            "ref_element": torch.zeros(num_residues, 128, device=device),
            "ref_atom_name_chars": torch.zeros(num_residues, 4, 64, device=device),
        }

    @staticmethod
    def _infer_sizes(f: dict) -> Tuple[int, int]:
        """Return ``(num_atoms, num_tokens)`` for a caller-supplied feature dict.

        Assumes per-atom features (``ref_element``) and per-token features
        (``restype``) are indexed along dim 0 — TODO confirm against the
        feature builder.

        Raises:
            ValueError: if ``f`` has neither ``ref_element`` nor ``restype``.
        """
        restype = f.get("restype")
        atom_feature = f.get("ref_element")
        if atom_feature is None:
            atom_feature = restype
        if atom_feature is None:
            raise ValueError("Feature dict `f` must contain `ref_element` or `restype` to infer its size.")
        num_atoms = atom_feature.shape[0]
        # Token count comes from the per-token `restype`; only fall back to the
        # atom count (the CA-only assumption) when `restype` is absent.
        num_tokens = restype.shape[0] if restype is not None else num_atoms
        return num_atoms, num_tokens

    @torch.no_grad()
    def __call__(self, components: ModularPipeline, state: PipelineState):
        """Populate ``f``, ``L`` and ``I`` on the pipeline state.

        Raises:
            ValueError: if neither ``sequence`` nor ``f`` is provided, or if
                ``f`` lacks the keys needed to infer its size.
        """
        block_state = self.get_block_state(state)
        f = block_state.f

        if f is None:
            sequence = block_state.sequence
            if sequence is None:
                raise ValueError("RF3InputStep needs either `sequence` or a pre-built feature dict `f`.")
            f = self._build_minimal_features(sequence, _default_device())
            # CA-only: token count equals atom count equals residue count.
            num_atoms = num_tokens = len(sequence)
        else:
            num_atoms, num_tokens = self._infer_sizes(f)

        block_state.f = f
        block_state.L = num_atoms
        block_state.I = num_tokens
        self.set_block_state(state, block_state)
        return components, state


class RF3SetTimestepsStep(ModularPipelineBlocks):
    """Set up EDM noise schedule for RF3.

    Uses the ``scheduler`` component's schedule when one is loaded. Otherwise
    it builds a linear fallback whose length honours ``num_inference_steps``.
    """

    model_name = "rf3"

    @property
    def description(self) -> str:
        return "Construct EDM noise schedule for RF3 diffusion sampling."

    @property
    def expected_components(self) -> List[ComponentSpec]:
        return [ComponentSpec("scheduler", description="RF3 EDM scheduler")]

    @property
    def inputs(self) -> List[InputParam]:
        return [
            InputParam("num_inference_steps", default=None, type_hint=int),
            InputParam("L", required=True, type_hint=int),
        ]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        return [
            OutputParam("noise_schedule", type_hint=torch.Tensor),
            OutputParam("num_inference_steps", type_hint=int),
        ]

    @torch.no_grad()
    def __call__(self, components: ModularPipeline, state: PipelineState):
        """Populate ``noise_schedule`` and ``num_inference_steps`` on the state.

        Raises:
            ValueError: if the fallback schedule is requested with fewer than
                one step.
        """
        block_state = self.get_block_state(state)
        requested_steps = block_state.num_inference_steps
        scheduler = getattr(components, "scheduler", None)

        if scheduler is not None:
            noise_schedule = scheduler.get_noise_schedule()
            # The scheduler's own configuration decides the length; say so
            # instead of silently dropping the caller's request.
            if requested_steps is not None and requested_steps != len(noise_schedule):
                logger.warning(
                    "num_inference_steps=%s ignored; the loaded scheduler provides %d steps.",
                    requested_steps,
                    len(noise_schedule),
                )
        else:
            num_steps = _FALLBACK_NUM_STEPS if requested_steps is None else int(requested_steps)
            if num_steps < 1:
                raise ValueError(f"num_inference_steps must be >= 1, got {num_steps}.")
            noise_schedule = torch.linspace(_FALLBACK_SCHEDULE_START, _FALLBACK_SCHEDULE_END, num_steps)

        block_state.noise_schedule = noise_schedule
        block_state.num_inference_steps = len(noise_schedule)
        self.set_block_state(state, block_state)
        return components, state


class RF3RecyclingStep(ModularPipelineBlocks):
    """Run the recycling trunk (pairformer + MSA + templates)."""

    model_name = "rf3"

    @property
    def description(self) -> str:
        return "Run RF3 recycling trunk to produce single/pair representations."

    @property
    def expected_components(self) -> List[ComponentSpec]:
        return [ComponentSpec("transformer", description="RF3 transformer model")]

    @property
    def inputs(self) -> List[InputParam]:
        return [
            InputParam("f", required=True, type_hint=dict),
            InputParam("n_recycles", default=None, type_hint=int),
        ]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        return [
            OutputParam("single", type_hint=torch.Tensor, description="Single representation [I, c_s]"),
            OutputParam("pair", type_hint=torch.Tensor, description="Pair representation [I, I, c_z]"),
            OutputParam("s_inputs", type_hint=torch.Tensor, description="Input embeddings [I, c_s_inputs]"),
            OutputParam("distogram", type_hint=torch.Tensor, description="Distogram prediction [I, I, bins]"),
        ]

    @torch.no_grad()
    def __call__(self, components: ModularPipeline, state: PipelineState):
        """Populate ``single``, ``pair``, ``s_inputs`` and ``distogram`` on the state.

        Without a loaded ``transformer`` every output is ``None`` and a warning
        is logged.
        """
        block_state = self.get_block_state(state)
        transformer = getattr(components, "transformer", None)

        if transformer is None:
            logger.warning("No `transformer` component loaded; RF3RecyclingStep outputs are None placeholders.")
            single = pair = distogram = s_inputs = None
        else:
            output = transformer(f=block_state.f, n_recycles=block_state.n_recycles)
            single, pair, distogram = output.single, output.pair, output.distogram
            # Forward the input embeddings when the output object exposes them;
            # otherwise keep the previous behaviour of publishing None.
            s_inputs = getattr(output, "s_inputs", None)

        block_state.single = single
        block_state.pair = pair
        block_state.distogram = distogram
        block_state.s_inputs = s_inputs
        self.set_block_state(state, block_state)
        return components, state


class RF3PrepareLatentsStep(ModularPipelineBlocks):
    """Prepare initial noised coordinates for diffusion sampling."""

    model_name = "rf3"

    @property
    def description(self) -> str:
        return "Sample initial Gaussian noise scaled by the first noise schedule value."

    @property
    def inputs(self) -> List[InputParam]:
        return [
            InputParam("generator", type_hint=torch.Generator),
            InputParam("diffusion_batch_size", default=5, type_hint=int),
            InputParam("L", required=True, type_hint=int),
            InputParam("noise_schedule", required=True, type_hint=torch.Tensor),
        ]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        return [
            OutputParam("xyz", type_hint=torch.Tensor, description="Initial noised coords [D, L, 3]"),
        ]

    @torch.no_grad()
    def __call__(self, components: ModularPipeline, state: PipelineState):
        """Populate ``xyz`` with ``noise_schedule[0] * N(0, 1)`` of shape ``[D, L, 3]``.

        Raises:
            ValueError: if ``noise_schedule`` is empty.
        """
        block_state = self.get_block_state(state)
        num_atoms = block_state.L
        noise_schedule = block_state.noise_schedule
        batch_size = block_state.diffusion_batch_size or _DEFAULT_DIFFUSION_BATCH_SIZE
        generator = block_state.generator
        device = _default_device()

        if len(noise_schedule) == 0:
            raise ValueError("`noise_schedule` is empty; cannot scale the initial noise.")

        shape = (batch_size, num_atoms, 3)
        if generator is not None and generator.device.type != device.type:
            # torch.randn needs the generator on the sampling device: draw where
            # the generator lives (e.g. a CPU generator on a CUDA machine), then move.
            noise = torch.randn(shape, generator=generator, device=generator.device).to(device)
        else:
            noise = torch.randn(shape, generator=generator, device=device)

        # Put the scale on the same device as the noise to avoid a mixed-device multiply.
        c0 = torch.as_tensor(noise_schedule[0], device=device)
        block_state.xyz = c0 * noise
        self.set_block_state(state, block_state)
        return components, state