# RFDiffusion-3 / modular_blocks.py
# dn6's picture
# dn6 HF Staff
# Upload folder using huggingface_hub
# 4900749 verified
# Copyright 2025 Dhruv Nair. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import List, Optional
import torch
from diffusers.utils import logging
from diffusers.modular_pipelines import (
AutoPipelineBlocks,
ModularPipeline,
ModularPipelineBlocks,
PipelineState,
SequentialPipelineBlocks,
)
from diffusers.modular_pipelines.modular_pipeline_utils import ComponentSpec, InputParam, InsertableDict, OutputParam
from .before_denoise import (
RFDiffusionInputStep,
RFDiffusionPrepareLatentsStep,
RFDiffusionSetTimestepsStep,
)
from .decoders import RFDiffusionDecodeStep
from .denoise import RFDiffusionDenoiseStep
# Module-level logger from diffusers' logging utilities (`diffusers.utils.logging`).
logger = logging.get_logger(__name__)
# ─── Amino acid mappings (used by MPNN blocks) ─────────────────────────
# Three-letter residue code -> one-letter code. Insertion order is significant:
# AA_NAMES is derived from it, and sequence token indices are looked up in
# AA_NAMES (index 7 is "GLY"; the final entry "UNK" is the catch-all).
THREE_TO_ONE = dict(
    zip(
        (
            "ALA", "ARG", "ASN", "ASP", "CYS",
            "GLN", "GLU", "GLY", "HIS", "ILE",
            "LEU", "LYS", "MET", "PHE", "PRO",
            "SER", "THR", "TRP", "TYR", "VAL",
            "UNK",
        ),
        "ARNDCQEGHILKMFPSTWYVX",
    )
)
# Token index -> three-letter residue name.
AA_NAMES = [*THREE_TO_ONE]
# ═══════════════════════════════════════════════════════════════════════════
# RFDiffusion blocks
# ═══════════════════════════════════════════════════════════════════════════
class RFDiffusionBeforeDenoiseStep(SequentialPipelineBlocks):
    """Run input processing, timestep setup and latent initialization, in that order."""

    block_classes = [RFDiffusionInputStep, RFDiffusionSetTimestepsStep, RFDiffusionPrepareLatentsStep]
    block_names = ["input", "set_timesteps", "prepare_latents"]

    @property
    def description(self):
        lines = [
            "Before denoise step that prepares the inputs for the denoise step.",
            "This is a sequential pipeline blocks:",
            " - `RFDiffusionInputStep` processes contigs and prepares input features",
            " - `RFDiffusionSetTimestepsStep` sets up the diffusion timesteps",
            " - `RFDiffusionPrepareLatentsStep` initializes noised coordinates",
        ]
        # Every line, including the last, is newline-terminated.
        return "\n".join(lines) + "\n"
class RFDiffusionAutoBeforeDenoiseStep(AutoPipelineBlocks):
    """Auto wrapper around the before-denoise stage.

    Holds a single option, ``RFDiffusionBeforeDenoiseStep``, registered under the
    name ``"unconditional"`` with a ``None`` trigger input.
    """

    block_classes = [RFDiffusionBeforeDenoiseStep]
    block_names = ["unconditional"]
    block_trigger_inputs = [None]

    @property
    def description(self):
        parts = (
            "Before denoise step that prepares the inputs for the denoise step.",
            "This is an auto pipeline block for protein structure generation.",
            " - `RFDiffusionBeforeDenoiseStep` (unconditional) is used.",
        )
        return "".join(part + "\n" for part in parts)
class RFDiffusionAutoDenoiseStep(AutoPipelineBlocks):
    """Auto wrapper around the denoise stage.

    Holds a single option, ``RFDiffusionDenoiseStep``, registered under the
    name ``"denoise"`` with a ``None`` trigger input.
    """

    block_classes = [RFDiffusionDenoiseStep]
    block_names = ["denoise"]
    block_trigger_inputs = [None]

    @property
    def description(self) -> str:
        # Newline-separated, one item per line, like the sibling auto-block
        # descriptions (otherwise the bullet runs onto the previous sentence).
        return (
            "Denoise step that iteratively denoises the protein structure.\n"
            "This is an auto pipeline block for protein structure generation.\n"
            " - `RFDiffusionDenoiseStep` (denoise) for structure generation.\n"
        )
class RFDiffusionAutoDecodeStep(AutoPipelineBlocks):
    """Auto wrapper around the decode stage.

    Holds a single option, ``RFDiffusionDecodeStep``, registered under the
    name ``"decode"`` with a ``None`` trigger input.
    """

    block_classes = [RFDiffusionDecodeStep]
    block_names = ["decode"]
    block_trigger_inputs = [None]

    @property
    def description(self):
        headline = "Decode step that converts denoised coordinates to PDB output."
        bullet = " - `RFDiffusionDecodeStep`"
        return "\n".join((headline, bullet))
# ═══════════════════════════════════════════════════════════════════════════
# MPNN blocks (defined before RFDiffusionAutoBlocks which references them)
# ═══════════════════════════════════════════════════════════════════════════
@dataclass
class MPNNPipelineOutput:
    """Output from ProteinMPNN / LigandMPNN sequence design.

    Built by ``MPNNSequenceDesignStep`` and stored in the pipeline state as
    ``mpnn_output``.
    """

    # One-letter sequence; MPNNSequenceDesignStep builds it from batch element 0 only.
    designed_sequence: str
    sequence_indices: torch.Tensor  # [B, L] token indices
    sequence_logits: torch.Tensor  # [B, L, n_vocab] logits
    xyz: torch.Tensor  # [B, L, 3] input structure (passed through)
    pdb_string: Optional[str] = None  # PDB with designed sequence
    # NOTE(review): MPNNSequenceDesignStep never sets this; it stays None.
    sequence_recovery: Optional[float] = None
class MPNNSequenceDesignStep(ModularPipelineBlocks):
    """
    Design sequences for a protein backbone using ProteinMPNN / LigandMPNN.

    Takes ``xyz`` coordinates (typically from an upstream RFDiffusion denoise
    step) and runs the ``mpnn`` component to produce amino acid sequences for
    the designable regions.

    When no ``mpnn`` component is loaded, falls back to the upstream
    ``sequence_indices`` if given, otherwise glycine at every position.

    Token indices are interpreted through ``AA_NAMES``; any index outside
    ``[0, len(AA_NAMES))`` is treated as ``"UNK"``.
    """

    model_name = "mpnn"

    # Vocabulary positions looked up by name instead of hard-coding integers.
    _GLY_INDEX = AA_NAMES.index("GLY")
    _UNK_INDEX = AA_NAMES.index("UNK")

    @property
    def description(self) -> str:
        return (
            "Design amino acid sequences for protein backbones using "
            "ProteinMPNN or LigandMPNN. Accepts structure coordinates "
            "from an upstream diffusion step."
        )

    @property
    def expected_components(self) -> List[ComponentSpec]:
        return [
            ComponentSpec("mpnn", description="MPNNModel (ProteinMPNN or LigandMPNN)"),
        ]

    @property
    def inputs(self) -> List[InputParam]:
        return [
            InputParam(
                "xyz", required=True, type_hint=torch.Tensor,
                description="Protein backbone coordinates [B, L, 3] (CA atoms)",
            ),
            InputParam(
                "motif_mask", type_hint=torch.Tensor,
                description="Mask for fixed/motif positions [L]. True = fixed sequence.",
            ),
            InputParam(
                "sequence_indices", type_hint=torch.Tensor,
                description="Known sequence indices for motif positions [B, L] (from RFDiffusion)",
            ),
            InputParam(
                "temperature", default=0.1, type_hint=float,
                description="Sampling temperature (lower = more deterministic)",
            ),
            InputParam(
                "num_designs", default=1, type_hint=int,
                description="Number of sequence designs to generate per structure",
            ),
            InputParam(
                "output_type", default="tensor", type_hint=str,
                description="'tensor', 'pdb', or 'cif'",
            ),
            InputParam(
                "output_path", type_hint=str,
                description="Path to save designed PDB",
            ),
        ]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        return [
            OutputParam(
                "mpnn_output", type_hint=MPNNPipelineOutput,
                description="MPNN sequence design output",
            ),
        ]

    @torch.no_grad()
    def __call__(self, components: ModularPipeline, state: PipelineState) -> PipelineState:
        """Design a sequence for ``xyz`` and store it in the state as ``mpnn_output``.

        Returns:
            ``(components, state)``, following the diffusers block convention.

        Raises:
            ValueError: if ``xyz`` is not a 3-D ``[B, L, 3]`` tensor.
        """
        block_state = self.get_block_state(state)

        xyz = block_state.xyz
        if xyz.ndim != 3:
            raise ValueError(f"`xyz` must have shape [B, L, 3], got {tuple(xyz.shape)}")
        motif_mask = block_state.motif_mask
        known_seq = block_state.sequence_indices
        # NOTE(review): `or` also maps an explicit 0.0 to 0.1 — presumably to keep
        # MPNN sampling away from a zero temperature; confirm before changing.
        temperature = block_state.temperature or 0.1
        output_type = block_state.output_type or "tensor"
        output_path = block_state.output_path

        num_designs = getattr(block_state, "num_designs", None)
        if num_designs not in (None, 1):
            # The input is declared but nothing below consumes it; warn instead
            # of silently returning a single design.
            logger.warning(
                "num_designs=%s is not supported yet; generating one design per structure.",
                num_designs,
            )

        mpnn = getattr(components, "mpnn", None)
        if mpnn is not None:
            sequence_logits, sequence_indices = self._run_mpnn(
                mpnn, xyz, motif_mask, known_seq, temperature,
            )
        else:
            sequence_logits, sequence_indices = self._fallback_sequence(xyz, known_seq)

        # Only the first structure in the batch is rendered as a string / PDB.
        designed_sequence = "".join(
            THREE_TO_ONE[self._residue_name(int(index))]
            for index in sequence_indices[0].cpu().tolist()
        )

        pdb_string = None
        if output_type == "pdb":
            pdb_string = self._coords_to_pdb(xyz[0], sequence_indices[0])
            if output_path:
                self._write_text(output_path, pdb_string)
        elif output_type == "cif":
            logger.warning("output_type='cif' is not implemented; no structure file is produced.")

        block_state.mpnn_output = MPNNPipelineOutput(
            designed_sequence=designed_sequence,
            sequence_indices=sequence_indices,
            sequence_logits=sequence_logits,
            xyz=xyz,
            pdb_string=pdb_string,
        )
        self.set_block_state(state, block_state)
        return components, state

    # ── helpers ─────────────────────────────────────────────────────────

    @staticmethod
    def _residue_name(index: int) -> str:
        """Three-letter name for a token index; anything out of range is ``"UNK"``."""
        if 0 <= index < len(AA_NAMES):
            return AA_NAMES[index]
        return "UNK"

    @classmethod
    def _clamp_to_vocab(cls, indices: torch.Tensor) -> torch.Tensor:
        """Replace indices outside ``[0, len(AA_NAMES))`` with the UNK index."""
        in_vocab = (indices >= 0) & (indices < len(AA_NAMES))
        return indices.masked_fill(~in_vocab, cls._UNK_INDEX)

    def _fallback_sequence(self, xyz: torch.Tensor, known_seq: Optional[torch.Tensor]):
        """Return ``(one-hot logits, indices)`` when no MPNN model is loaded.

        Uses ``known_seq`` unchanged if given, otherwise glycine everywhere.
        """
        batch_size, length = xyz.shape[0], xyz.shape[1]
        device = xyz.device
        if known_seq is not None:
            sequence_indices = known_seq
        else:
            sequence_indices = torch.full(
                (batch_size, length), self._GLY_INDEX, dtype=torch.long, device=device,
            )
        # scatter_ needs int64 indices inside the vocabulary; out-of-range tokens
        # (e.g. from a larger upstream vocabulary) become UNK, matching how they
        # are rendered in the sequence string and the PDB.
        one_hot_index = self._clamp_to_vocab(sequence_indices.to(device=device, dtype=torch.long))
        sequence_logits = torch.zeros(batch_size, length, len(AA_NAMES), device=device)
        sequence_logits.scatter_(2, one_hot_index.unsqueeze(-1), 1.0)
        return sequence_logits, sequence_indices

    def _run_mpnn(self, mpnn, xyz, motif_mask, known_seq, temperature):
        """Run the MPNNModel wrapper on pseudo-backbone coordinates built from ``xyz``.

        N, C and O are placed at fixed offsets from each CA. This is a crude
        placeholder, not a reconstruction of real backbone geometry. Positions
        where ``motif_mask`` is True keep ``known_seq`` when both are given.

        Returns:
            ``(sequence_logits, sequence_indices)`` from the model output.
        """
        batch_size = xyz.shape[0]
        device = xyz.device
        dtype = xyz.dtype

        ca = xyz
        n_offset = torch.tensor([-1.458, 0.0, 0.0], device=device, dtype=dtype)
        c_offset = torch.tensor([0.550, 1.424, 0.0], device=device, dtype=dtype)
        o_offset = torch.tensor([0.550, 2.500, 0.0], device=device, dtype=dtype)
        X = torch.stack([ca + n_offset, ca, ca + c_offset, ca + o_offset], dim=2)  # [B, L, 4, 3]

        designed_mask = None
        if motif_mask is not None:
            # A 0/1 integer or float mask would otherwise get a bitwise `~` and
            # integer-index selection below; normalize to bool first.
            motif_mask = motif_mask.to(device=device, dtype=torch.bool)
            designed_mask = ~motif_mask.unsqueeze(0).expand(batch_size, -1)

        output = mpnn(
            X=X, S=known_seq, designed_residue_mask=designed_mask, temperature=temperature,
        )
        logits = output.sequence_logits
        indices = output.sequence_indices

        if motif_mask is not None and known_seq is not None:
            # Pin motif positions to the upstream sequence. Clone so the model's
            # output object is not mutated, and match dtype/device for index_put.
            indices = indices.clone()
            indices[:, motif_mask] = known_seq[:, motif_mask].to(indices)
        return logits, indices

    def _coords_to_pdb(self, xyz: torch.Tensor, seq: torch.Tensor) -> str:
        """Render CA-only coordinates ``[L, 3]`` as fixed-width PDB ATOM records (chain A).

        Raises:
            ValueError: if ``xyz`` and ``seq`` have different lengths.
        """
        coords = xyz.detach().cpu().tolist()
        residues = seq.detach().cpu().tolist()
        if len(coords) != len(residues):
            raise ValueError(
                f"coordinate/sequence length mismatch: {len(coords)} vs {len(residues)}"
            )
        lines = []
        for serial, ((x, y, z), aa_index) in enumerate(zip(coords, residues), start=1):
            res_name = self._residue_name(int(aa_index))
            # Columns per the PDB v3.3 ATOM record: serial 7-11, name 13-16,
            # resName 18-20, chain 22, resSeq 23-26, xyz 31-54, occupancy 55-60,
            # tempFactor 61-66, element 77-78.
            lines.append(
                f"ATOM  {serial:5d}  CA  {res_name:3s} A{serial:4d}    "
                f"{x:8.3f}{y:8.3f}{z:8.3f}{1.0:6.2f}{0.0:6.2f}          {'C':>2s}  "
            )
        lines.append("END")
        return "\n".join(lines)

    @staticmethod
    def _write_text(path: str, text: str) -> None:
        """Write ``text`` to ``path``, creating parent directories as needed."""
        import os

        os.makedirs(os.path.dirname(path) or ".", exist_ok=True)
        with open(path, "w", encoding="utf-8") as handle:
            handle.write(text)
class MPNNAutoDesignStep(AutoPipelineBlocks):
    """Auto wrapper around the sequence-design stage.

    Holds a single option, ``MPNNSequenceDesignStep``, registered under the
    name ``"sequence_design"`` with a ``None`` trigger input.
    """

    block_classes = [MPNNSequenceDesignStep]
    block_names = ["sequence_design"]
    block_trigger_inputs = [None]

    @property
    def description(self) -> str:
        summary = "Sequence design using ProteinMPNN or LigandMPNN."
        return summary
# ═══════════════════════════════════════════════════════════════════════════
# Top-level pipeline blocks
# ═══════════════════════════════════════════════════════════════════════════
class RFDiffusionAutoBlocks(SequentialPipelineBlocks):
    """
    End-to-end protein design pipeline: RFDiffusion3 structure generation
    followed by ProteinMPNN / LigandMPNN sequence design.

    Sub-blocks run in order: ``before_denoise`` -> ``denoise`` -> ``decoder``
    -> ``sequence_design``.

    ``_workflow_map`` names three workflows by the inputs they supply:

    - ``structure_only``: ``contigs``
    - ``structure_and_sequence``: ``contigs`` + ``temperature``
    - ``motif_structure_and_sequence``: ``contigs`` + ``input_xyz`` + ``temperature``

    NOTE(review): ``MPNNAutoDesignStep`` registers a single ``None`` trigger,
    so sequence design does not appear to be gated on ``temperature``. Without
    an ``mpnn`` component, ``MPNNSequenceDesignStep`` falls back to upstream
    indices or glycine rather than being skipped. Confirm the intended gating.
    """

    block_classes = [
        RFDiffusionAutoBeforeDenoiseStep,
        RFDiffusionAutoDenoiseStep,
        RFDiffusionAutoDecodeStep,
        MPNNAutoDesignStep,
    ]
    # NOTE(review): the decode stage is named "decoder" here but "decode" in AUTO_BLOCKS.
    block_names = ["before_denoise", "denoise", "decoder", "sequence_design"]

    _workflow_map = dict(
        structure_only=dict(contigs=True),
        structure_and_sequence=dict(contigs=True, temperature=True),
        motif_structure_and_sequence=dict(contigs=True, input_xyz=True, temperature=True),
    )

    @property
    def description(self):
        header = "Modular pipeline for protein design using RFDiffusion3.\nWorkflows:\n"
        workflows = (
            "structure_only: backbone generation",
            "structure_and_sequence: backbone + MPNN sequence design",
            "motif_structure_and_sequence: motif-conditioned + MPNN",
        )
        return header + "".join(f" - {item}\n" for item in workflows)
# ═══════════════════════════════════════════════════════════════════════════
# Block registries
# ═══════════════════════════════════════════════════════════════════════════
# Flat registry: every leaf step in execution order, keyed by step name.
UNCONDITIONAL_BLOCKS = InsertableDict(
    {
        "input": RFDiffusionInputStep,
        "set_timesteps": RFDiffusionSetTimestepsStep,
        "prepare_latents": RFDiffusionPrepareLatentsStep,
        "denoise": RFDiffusionDenoiseStep,
        "decode": RFDiffusionDecodeStep,
        "sequence_design": MPNNSequenceDesignStep,
    }
)

# Stage-level registry built from the auto (trigger-dispatching) wrappers.
AUTO_BLOCKS = InsertableDict(
    {
        "before_denoise": RFDiffusionAutoBeforeDenoiseStep,
        "denoise": RFDiffusionAutoDenoiseStep,
        "decode": RFDiffusionAutoDecodeStep,
        "sequence_design": MPNNAutoDesignStep,
    }
)

# Preset name -> block registry.
ALL_BLOCKS = dict(
    unconditional=UNCONDITIONAL_BLOCKS,
    auto=AUTO_BLOCKS,
)