RFDiffusion-3 / decoders.py
dn6's picture
dn6 HF Staff
Upload folder using huggingface_hub
4900749 verified
# Copyright 2025 Dhruv Nair. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import List, Optional
import numpy as np
import torch
from atomworks.io.utils.io_utils import to_cif_file
from biotite.structure import AtomArray, AtomArrayStack, stack
from diffusers.utils import logging
from diffusers.modular_pipelines import ModularPipeline, ModularPipelineBlocks, PipelineState
from diffusers.modular_pipelines.modular_pipeline_utils import InputParam, OutputParam
logger = logging.get_logger(__name__)

# Three-letter residue codes; a sequence index ``i`` maps to ``AA_NAMES[i]``.
# The final entry (index 20, "UNK") is the unknown-residue label.
# NOTE(review): the order presumably has to match the model's residue
# vocabulary — confirm against the model before reordering.
AA_NAMES = (
    "ALA ARG ASN ASP CYS GLN GLU GLY HIS ILE "
    "LEU LYS MET PHE PRO SER THR TRP TYR VAL "
    "UNK"
).split()
def _build_atom_array(xyz: torch.Tensor, seq: Optional[torch.Tensor] = None) -> AtomArray:
    """
    Build a CA-only biotite ``AtomArray`` for a single structure.

    Every residue is represented by one ``CA`` atom (element ``C``) on chain
    ``A``. Residue ids are numbered ``1..L``.

    Args:
        xyz: Coordinates for a single structure, shape ``[L, 3]``. The tensor is
            detached, moved to CPU and cast to float32 before conversion, so
            GPU, bf16 or grad-tracking inputs are accepted.
        seq: Optional sequence indices, shape ``[L]``, indexing into
            ``AA_NAMES``. Any index outside ``[0, len(AA_NAMES))`` is written as
            ``"UNK"``. This includes negative values such as a mask/padding
            sentinel, which would otherwise silently wrap to the end of the
            table. When omitted, every residue is labelled ``"ALA"``.

    Returns:
        An ``AtomArray`` with ``L`` atoms.
    """
    xyz_np = xyz.detach().cpu().float().numpy()
    n_res = xyz_np.shape[0]

    arr = AtomArray(n_res)
    arr.coord = xyz_np
    arr.atom_name = np.full(n_res, "CA")
    arr.element = np.full(n_res, "C")
    arr.chain_id = np.full(n_res, "A")
    arr.res_id = np.arange(1, n_res + 1)

    if seq is None:
        arr.res_name = np.full(n_res, "ALA")
    else:
        n_names = len(AA_NAMES)
        indices = (int(idx) for idx in seq.detach().cpu().numpy())
        # Guard both ends: a plain upper-bound check lets negative indices
        # index from the end of AA_NAMES (e.g. -2 -> "VAL").
        arr.res_name = np.array(
            [AA_NAMES[i] if 0 <= i < n_names else "UNK" for i in indices]
        )
    return arr
def _build_atom_array_stack(
    xyz: torch.Tensor,
    seq: Optional[torch.Tensor] = None,
) -> AtomArrayStack:
    """
    Build an ``AtomArrayStack`` (one model per batch element) from ``[B, L, 3]`` coordinates.

    Matches foundry ``build_stack_from_atom_array_and_batched_coords``.

    All models share the annotations of a template built from the first batch
    element. Only ``seq[0]`` is used for residue names, so per-model sequence
    differences are not reflected. The coordinates of every batch element are
    then copied into the stack.
    """
    first_seq = None if seq is None else seq[0]
    template = _build_atom_array(xyz[0], first_seq)
    n_models = xyz.shape[0]
    # The same template object repeated B times; only the coordinates differ per model.
    arr_stack = stack([template] * n_models)
    arr_stack.coord = xyz.detach().cpu().float().numpy()
    return arr_stack
def _build_trajectory_stack(
    trajectory: List[torch.Tensor],
    seq: Optional[torch.Tensor] = None,
) -> AtomArrayStack:
    """
    Build an ``AtomArrayStack`` holding one model per denoising step.

    Each trajectory entry is ``[B, L, 3]``. Only batch element 0 is kept for
    every step. Residue names come from ``seq[0]`` when a sequence is given.
    ``trajectory`` must be non-empty.
    """
    frames = [step[0] for step in trajectory]
    coords = torch.stack(frames)  # [N_steps, L, 3]
    first_seq = None if seq is None else seq[0]
    template = _build_atom_array(coords[0], first_seq)
    arr_stack = stack([template] * coords.shape[0])
    arr_stack.coord = coords.detach().cpu().float().numpy()
    return arr_stack
@dataclass
class RFDiffusionPipelineOutput:
    """
    Output class for RFDiffusion pipeline.

    Populated by ``RFDiffusionDecodeStep``. Which optional fields are filled
    depends on that step's ``output_type``:

    - ``atom_array`` / ``atom_array_stack`` / ``trajectory_stack`` are only
      built for ``"cif"`` / ``"cif.gz"``.
    - ``pdb_string`` is only set for ``"pdb"``.

    The tensor fields are passed through from the pipeline state unchanged.
    """
    # Denoised coordinates; documented by the decode step as [B, L, 3].
    xyz: torch.Tensor
    # CA-only structure built from the first batch element (CIF output types only).
    atom_array: Optional[AtomArray] = None
    # Every batch element as a multi-model stack; only built for CIF output when B > 1.
    atom_array_stack: Optional[AtomArrayStack] = None
    # One model per denoising step, first batch element only (CIF output with a trajectory).
    trajectory_stack: Optional[AtomArrayStack] = None
    # Predicted sequence indices into AA_NAMES; documented upstream as [B, L].
    sequence_indices: Optional[torch.Tensor] = None
    # Sequence logits; documented upstream as [B, L, n_aa].
    sequence_logits: Optional[torch.Tensor] = None
    # Single representation, passed through as-is.
    single: Optional[torch.Tensor] = None
    # Pair representation, passed through as-is.
    pair: Optional[torch.Tensor] = None
    # PDB-format text; only set when output_type == "pdb".
    pdb_string: Optional[str] = None
    # Raw denoising trajectory; each entry presumably [B, L, 3] like xyz — confirm with the denoise step.
    trajectory: Optional[List[torch.Tensor]] = None
class RFDiffusionDecodeStep(ModularPipelineBlocks):
    """
    Decode step for RFDiffusion.

    Converts denoised coordinates to an ``RFDiffusionPipelineOutput``.

    Supported ``output_type`` values:

    - ``"tensor"`` — raw tensors only
    - ``"pdb"`` — tensors + PDB format string. A batch with more than one
      structure is written as one ``MODEL``/``ENDMDL`` section per structure.
    - ``"cif"`` — tensors + AtomArray via AtomWorks, writes ``.cif``
    - ``"cif.gz"`` — same as ``"cif"`` but compressed

    Any other ``output_type`` is handled like ``"tensor"`` and logs a warning.
    """

    model_name = "rfdiffusion"

    # output_type values that build AtomArrays and are written via AtomWorks.
    _CIF_OUTPUT_TYPES = ("cif", "cif.gz")
    # Every output_type this step understands.
    _KNOWN_OUTPUT_TYPES = ("tensor", "pdb", "cif", "cif.gz")
    # Names given, in order, to the per-residue atoms of an [L, A, 3] structure.
    # Atoms beyond the third are not written.
    _BACKBONE_ATOM_NAMES = ("N", "CA", "C")

    @property
    def description(self) -> str:
        return (
            "Decode step that converts denoised coordinates to final output, "
            "supporting tensor, PDB, and CIF (via AtomWorks) formats."
        )

    @property
    def inputs(self) -> List[InputParam]:
        return [
            InputParam(
                "output_type",
                default="tensor",
                type_hint=str,
                description="Output format: 'tensor', 'pdb', 'cif', or 'cif.gz'",
            ),
            InputParam(
                "output_path",
                type_hint=str,
                description="Path to save output structure",
            ),
            InputParam("xyz", required=True, type_hint=torch.Tensor, description="Denoised coordinates [B, L, 3]"),
            InputParam("sequence_indices", type_hint=torch.Tensor, description="Predicted sequence [B, L]"),
            InputParam("sequence_logits", type_hint=torch.Tensor, description="Sequence logits [B, L, n_aa]"),
            InputParam("single", type_hint=torch.Tensor, description="Single representation"),
            InputParam("pair", type_hint=torch.Tensor, description="Pair representation"),
            InputParam("trajectory", type_hint=List[torch.Tensor], description="Denoising trajectory"),
        ]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        return [
            OutputParam("output", type_hint=RFDiffusionPipelineOutput, description="Final pipeline output"),
        ]

    @torch.no_grad()
    def __call__(self, components: ModularPipeline, state: PipelineState) -> PipelineState:
        """
        Build the requested output formats and store them on ``block_state.output``.

        Steps:

        - Reads ``xyz`` (``[B, L, 3]``) and the optional sequence, representation
          and trajectory tensors from the pipeline state.
        - Builds AtomArrays (CIF types) or a PDB string (``"pdb"``).
        - When ``output_path`` is set, writes the structure to disk.

        Returns ``(components, state)``.
        """
        block_state = self.get_block_state(state)
        xyz = block_state.xyz
        sequence_indices = block_state.sequence_indices
        trajectory = block_state.trajectory
        output_type = block_state.output_type or "tensor"
        output_path = block_state.output_path

        if output_type not in self._KNOWN_OUTPUT_TYPES:
            logger.warning(
                "Unknown output_type %r (expected one of %s); returning tensors only.",
                output_type,
                ", ".join(self._KNOWN_OUTPUT_TYPES),
            )

        pdb_string = None
        atom_array = None
        atom_array_stack = None
        trajectory_stack = None

        if output_type in self._CIF_OUTPUT_TYPES:
            atom_array = _build_atom_array(xyz[0], self._sequence_for(sequence_indices, 0))
            if xyz.shape[0] > 1:
                atom_array_stack = _build_atom_array_stack(xyz, sequence_indices)
            if trajectory:
                trajectory_stack = _build_trajectory_stack(trajectory, sequence_indices)
        elif output_type == "pdb":
            pdb_string = self._batch_to_pdb(xyz, sequence_indices)

        if output_path is not None:
            self._write_outputs(
                output_path, output_type, atom_array, atom_array_stack, trajectory_stack, pdb_string,
            )

        block_state.output = RFDiffusionPipelineOutput(
            xyz=xyz,
            atom_array=atom_array,
            atom_array_stack=atom_array_stack,
            trajectory_stack=trajectory_stack,
            sequence_indices=sequence_indices,
            sequence_logits=block_state.sequence_logits,
            single=block_state.single,
            pair=block_state.pair,
            pdb_string=pdb_string,
            trajectory=trajectory,
        )
        self.set_block_state(state, block_state)
        return components, state

    def _write_outputs(
        self,
        output_path: str,
        output_type: str,
        atom_array: Optional[AtomArray],
        atom_array_stack: Optional[AtomArrayStack],
        trajectory_stack: Optional[AtomArrayStack],
        pdb_string: Optional[str],
    ) -> None:
        """Write the built structure(s) for *output_type* to *output_path*, creating parent dirs."""
        import os

        if output_type not in self._CIF_OUTPUT_TYPES and output_type != "pdb":
            logger.warning(
                "output_path=%r was given but output_type=%r does not write files; nothing was saved.",
                output_path,
                output_type,
            )
            return

        os.makedirs(os.path.dirname(output_path) or ".", exist_ok=True)

        if output_type == "pdb":
            with open(output_path, "w", encoding="utf-8") as f:
                f.write(pdb_string)
            return

        # Prefer the full multi-model stack when the batch had more than one structure.
        to_write = atom_array_stack if atom_array_stack is not None else atom_array
        # to_cif_file is handed a suffix-less base path. Presumably it appends
        # ".{file_type}" itself — confirm against atomworks.
        base = self._strip_structure_suffix(output_path)
        to_cif_file(to_write, base, file_type=output_type, include_entity_poly=False)
        if trajectory_stack is not None:
            to_cif_file(trajectory_stack, base + "_trajectory", file_type=output_type, include_entity_poly=False)

    @staticmethod
    def _strip_structure_suffix(path: str) -> str:
        """
        Drop the file extension from the last component of *path*.

        A ``.cif.gz`` double suffix is removed entirely. Dots inside directory
        names (``./out/x``, ``runs/v1.2/x``) are left alone, because
        ``os.path.splitext`` only inspects the final component.
        """
        import os

        root, ext = os.path.splitext(path)
        if ext.lower() == ".gz":
            inner_root, inner_ext = os.path.splitext(root)
            if inner_ext.lower() == ".cif":
                root = inner_root
        return root

    @staticmethod
    def _sequence_for(seq: Optional[torch.Tensor], index: int) -> Optional[torch.Tensor]:
        """Return the sequence of batch element *index*; a ``None`` or unbatched (1-D) sequence is returned as-is."""
        if seq is None or seq.ndim < 2:
            return seq
        return seq[index]

    def _batch_to_pdb(self, xyz: torch.Tensor, seq: Optional[torch.Tensor] = None) -> str:
        """
        Convert (possibly batched) coordinates to a PDB string.

        Accepted shapes:

        - ``[L, 3]`` — a single unbatched structure.
        - ``[B, L, 3]`` or ``[B, L, A, 3]`` — a batch. ``B == 1`` produces a
          plain single-structure file. ``B > 1`` produces one
          ``MODEL``/``ENDMDL`` section per structure.
        """
        if xyz.ndim == 2:
            return self._coords_to_pdb(xyz, self._sequence_for(seq, 0))

        n_models = xyz.shape[0]
        if n_models == 1:
            return self._coords_to_pdb(xyz[0], self._sequence_for(seq, 0))

        lines: List[str] = []
        for model_idx in range(n_models):
            # MODEL serial occupies columns 11-14.
            lines.append(f"MODEL     {model_idx + 1:4d}")
            lines.extend(self._atom_lines(xyz[model_idx], self._sequence_for(seq, model_idx)))
            lines.append("ENDMDL")
        lines.append("END")
        return "\n".join(lines)

    def _coords_to_pdb(
        self,
        xyz: torch.Tensor,
        seq: Optional[torch.Tensor] = None,
    ) -> str:
        """
        Convert a single structure to a PDB format string terminated by ``END``.

        Args:
            xyz: ``[L, 3]`` (one ``CA`` per residue) or ``[L, A, 3]``. In the
                second form the first three atoms are named N, CA and C.
            seq: Optional ``[L]`` indices into ``AA_NAMES``. Residues default to
                ``"ALA"``; out-of-range indices become ``"UNK"``.
        """
        lines = self._atom_lines(xyz, seq)
        lines.append("END")
        return "\n".join(lines)

    def _atom_lines(self, xyz: torch.Tensor, seq: Optional[torch.Tensor] = None) -> List[str]:
        """Return the ATOM records for one structure, with atom serials starting at 1."""
        # detach/float: .numpy() fails on grad-tracking or bf16 tensors.
        xyz_np = xyz.detach().cpu().float().numpy()
        seq_np = seq.detach().cpu().numpy() if seq is not None else None
        n_names = len(AA_NAMES)

        lines: List[str] = []
        serial = 1
        for res_idx in range(xyz_np.shape[0]):
            if seq_np is None:
                res_name = "ALA"
            else:
                aa_idx = int(seq_np[res_idx])
                # Guard both ends so negative indices do not wrap around AA_NAMES.
                res_name = AA_NAMES[aa_idx] if 0 <= aa_idx < n_names else "UNK"

            if xyz_np.ndim == 2:
                atoms = (("CA", xyz_np[res_idx]),)
            else:
                # zip stops at the shorter side: fewer than 3 atoms -> fewer
                # records; more than 3 -> only N, CA, C are written.
                atoms = zip(self._BACKBONE_ATOM_NAMES, xyz_np[res_idx])

            for atom_name, (x, y, z) in atoms:
                lines.append(self._format_atom_record(serial, atom_name, res_name, res_idx + 1, x, y, z))
                serial += 1
        return lines

    @staticmethod
    def _format_atom_record(
        serial: int,
        atom_name: str,
        res_name: str,
        res_seq: int,
        x: float,
        y: float,
        z: float,
        chain_id: str = "A",
    ) -> str:
        """
        Format one 80-column PDB ATOM record.

        Occupancy is fixed at 1.00, the B-factor at 0.00, and the element is
        taken as the first letter of the atom name.
        """
        # Names shorter than 4 chars start in column 14 (column 13 blank), the
        # PDB convention for one-letter elements such as N and C.
        name_field = atom_name if len(atom_name) >= 4 else f" {atom_name:<3s}"
        element = atom_name[0]
        # Columns: 1-6 record, 7-11 serial, 13-16 name, 17 altLoc, 18-20 resName,
        # 22 chain, 23-26 resSeq, 27 iCode, 31-54 xyz, 55-60 occupancy, 61-66 B-factor.
        record = (
            f"ATOM  {serial:5d} {name_field:<4s} {res_name:>3s} {chain_id:1s}{res_seq:4d}    "
            f"{x:8.3f}{y:8.3f}{z:8.3f}{1.0:6.2f}{0.0:6.2f}"
        )
        # Element in columns 77-78; the blank charge field fills columns 79-80.
        return f"{record:<76s}{element:>2s}  "