# RosettaFold-3 / decoders.py
# Uploaded by dn6 (HF Staff) using huggingface_hub — commit a376829 (verified).
# Copyright 2025 Dhruv Nair. All rights reserved.
# Licensed under the Apache License, Version 2.0
"""
Decode step for RF3 — converts denoised coordinates to output structures.
Supports tensor, PDB, and CIF (via AtomWorks) output formats.
"""
from dataclasses import dataclass
from typing import List, Optional
import numpy as np
import torch
from atomworks.io.utils.io_utils import to_cif_file
from biotite.structure import AtomArray, AtomArrayStack, stack
from diffusers.utils import logging
from diffusers.modular_pipelines import ModularPipeline, ModularPipelineBlocks, PipelineState
from diffusers.modular_pipelines.modular_pipeline_utils import InputParam, OutputParam
# Module-level logger from diffusers' logging utilities (imported above as `logging`).
logger = logging.get_logger(__name__)
# One-letter amino-acid codes. The position of a letter in this string is the
# index of its three-letter residue name in AA_NAMES_3.
AA_ORDER = "ARNDCQEGHILKMFPSTWYV"

# Three-letter residue names in the same order as AA_ORDER, followed by "UNK"
# at index len(AA_ORDER).
AA_NAMES_3 = (
    "ALA ARG ASN ASP CYS GLN GLU GLY HIS ILE "
    "LEU LYS MET PHE PRO SER THR TRP TYR VAL "
    "UNK"
).split()
def _build_atom_array(xyz: torch.Tensor, sequence: Optional[str] = None) -> AtomArray:
    """Build a CA-only biotite ``AtomArray`` from per-residue coordinates.

    Args:
        xyz: Coordinates with one row per residue; row ``i`` becomes the CA atom
            (element ``C``) of residue ``i + 1`` on chain ``"A"``. The tensor is
            detached, moved to CPU and converted to float32 before use.
        sequence: Optional one-letter sequence used for residue names. Letters
            outside ``AA_ORDER`` become ``UNK``. If its length differs from the
            number of rows in ``xyz``, a warning is logged, positions past its end
            are named ``ALA`` and surplus letters are ignored — the same
            convention ``RF3DecodeStep._coords_to_pdb`` uses. When omitted or
            empty, every residue is ``ALA``.

    Returns:
        An ``AtomArray`` with one atom per row of ``xyz``.
    """
    coords = xyz.detach().cpu().float().numpy()
    n_res = coords.shape[0]

    arr = AtomArray(n_res)
    arr.coord = coords
    arr.atom_name = np.full(n_res, "CA")
    arr.element = np.full(n_res, "C")
    arr.chain_id = np.full(n_res, "A")
    arr.res_id = np.arange(1, n_res + 1)

    if sequence:
        if len(sequence) != n_res:
            logger.warning(
                "sequence has %d residues but xyz has %d positions; "
                "residue names are truncated / padded with ALA to match.",
                len(sequence),
                n_res,
            )
        # Force the annotation length to equal the atom count; a mismatched
        # length would otherwise be rejected by biotite when assigning res_name.
        letters = sequence[:n_res].ljust(n_res, "A")
        arr.res_name = np.array([
            AA_NAMES_3[AA_ORDER.find(aa)] if aa in AA_ORDER else "UNK"
            for aa in letters
        ])
    else:
        arr.res_name = np.full(n_res, "ALA")
    return arr
def _build_atom_array_stack(xyz: torch.Tensor, sequence: Optional[str] = None) -> AtomArrayStack:
    """Build a multi-model ``AtomArrayStack`` whose models share one set of annotations.

    The annotations come from ``_build_atom_array(xyz[0], sequence)``; that
    template is repeated once per entry along the first axis of ``xyz``. The
    coordinates of all models are then replaced with ``xyz``, detached, moved to
    CPU and converted to float32.
    """
    n_models = xyz.shape[0]
    first_model = _build_atom_array(xyz[0], sequence)
    models = stack([first_model] * n_models)
    models.coord = xyz.detach().cpu().float().numpy()
    return models
@dataclass
class RF3PipelineOutput:
    """Output class for RF3 pipeline.

    Produced by ``RF3DecodeStep``. Only ``xyz`` is required; the remaining
    fields default to ``None`` and are filled in depending on the requested
    output format and on which optional inputs were present in the pipeline
    state. Field order defines the positional constructor signature.
    """

    # Coordinates as received by the decode step (stored as-is, not converted).
    xyz: torch.Tensor
    # CA-only structure built from xyz[0] (set for "cif" / "cif.gz" output).
    atom_array: Optional[AtomArray] = None
    # All entries of xyz as a multi-model stack; set for "cif" / "cif.gz" output
    # only when xyz.shape[0] > 1.
    atom_array_stack: Optional[AtomArrayStack] = None
    # One frame per trajectory entry, built from the first model of each entry
    # (set for "cif" / "cif.gz" output when a trajectory was supplied).
    trajectory_stack: Optional[AtomArrayStack] = None
    # Distogram passed through unchanged from the pipeline state, if present.
    distogram: Optional[torch.Tensor] = None
    # One-letter sequence used for residue naming, if provided.
    sequence: Optional[str] = None
    # PDB text for xyz[0] (set for "pdb" output only).
    pdb_string: Optional[str] = None
    # Trajectory tensors passed through unchanged from the pipeline state, if present.
    trajectory: Optional[List[torch.Tensor]] = None
class RF3DecodeStep(ModularPipelineBlocks):
    """
    Decode step for RF3.

    Reads ``xyz`` (plus optional ``sequence``, ``distogram`` and ``trajectory``)
    from the pipeline state, builds the representation selected by
    ``output_type``, optionally writes it to ``output_path``, and stores an
    ``RF3PipelineOutput`` in the state under ``output``.

    Supported ``output_type`` values: ``"tensor"``, ``"pdb"``, ``"cif"``, ``"cif.gz"``.
    """

    model_name = "rf3"

    @property
    def description(self) -> str:
        return "Convert predicted coordinates to output format (tensor/PDB/CIF)."

    @property
    def inputs(self) -> List[InputParam]:
        return [
            InputParam("output_type", default="tensor", type_hint=str),
            # When None, nothing is written to disk.
            InputParam("output_path", type_hint=str),
            InputParam("xyz", required=True, type_hint=torch.Tensor),
            InputParam("sequence", type_hint=str),
            InputParam("distogram", type_hint=torch.Tensor),
            InputParam("trajectory", type_hint=List[torch.Tensor]),
        ]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        return [
            OutputParam("output", type_hint=RF3PipelineOutput),
        ]

    @torch.no_grad()
    def __call__(self, components: ModularPipeline, state: PipelineState) -> PipelineState:
        """Build the requested output and store it as ``output`` in the block state.

        ``xyz[0]`` is treated as the primary model and ``xyz.shape[0]`` as the
        number of models (assumed layout ``(n_models, L, 3)`` — TODO confirm
        against the upstream denoising step).

        Returns:
            ``(components, state)``, following the modular-pipeline block convention.
        """
        block_state = self.get_block_state(state)
        xyz = block_state.xyz
        sequence = block_state.sequence
        distogram = block_state.distogram
        trajectory = block_state.trajectory
        output_type = block_state.output_type or "tensor"
        output_path = block_state.output_path

        if output_type not in ("tensor", "pdb", "cif", "cif.gz"):
            # Previously an unknown type (e.g. a typo) was silently treated as "tensor".
            logger.warning(
                "Unsupported output_type %r (expected 'tensor', 'pdb', 'cif' or 'cif.gz'); "
                "only tensors will be returned.",
                output_type,
            )

        pdb_string = None
        atom_array = None
        atom_array_stack = None
        trajectory_stack = None

        if output_type in ("cif", "cif.gz"):
            atom_array = _build_atom_array(xyz[0], sequence)
            if xyz.shape[0] > 1:
                atom_array_stack = _build_atom_array_stack(xyz, sequence)
            # len() instead of truthiness so a tensor-valued trajectory also works.
            if trajectory is not None and len(trajectory) > 0:
                # Keep the first model of every trajectory entry: one frame per entry.
                traj_coords = torch.stack([t[0] for t in trajectory])
                trajectory_stack = _build_atom_array_stack(traj_coords, sequence)
        elif output_type == "pdb":
            pdb_string = self._coords_to_pdb(xyz[0], sequence)

        if output_path is not None:
            structure = atom_array_stack if atom_array_stack is not None else atom_array
            self._write_output(output_path, output_type, structure, pdb_string)

        block_state.output = RF3PipelineOutput(
            xyz=xyz,
            atom_array=atom_array,
            atom_array_stack=atom_array_stack,
            trajectory_stack=trajectory_stack,
            distogram=distogram,
            sequence=sequence,
            pdb_string=pdb_string,
            trajectory=trajectory,
        )
        self.set_block_state(state, block_state)
        return components, state

    def _write_output(
        self,
        output_path,
        output_type: str,
        structure,
        pdb_string: Optional[str],
    ) -> None:
        """Write the decoded result to ``output_path``, creating parent directories.

        Args:
            output_path: Destination, as ``str`` or any path-like object.
            output_type: ``"pdb"`` writes ``pdb_string`` verbatim; ``"cif"`` /
                ``"cif.gz"`` serialize ``structure`` with AtomWorks'
                ``to_cif_file``. For any other value (including ``"tensor"``)
                nothing is written and a warning is logged.
            structure: ``AtomArray`` or ``AtomArrayStack`` used for CIF output.
            pdb_string: PDB text used for PDB output.
        """
        import os

        output_path = os.fspath(output_path)
        writes_pdb = output_type == "pdb" and bool(pdb_string)
        writes_cif = output_type in ("cif", "cif.gz") and structure is not None
        if not (writes_pdb or writes_cif):
            logger.warning(
                "output_path %r ignored: output_type %r does not produce a file.",
                output_path,
                output_type,
            )
            return

        os.makedirs(os.path.dirname(output_path) or ".", exist_ok=True)
        if writes_pdb:
            with open(output_path, "w", encoding="utf-8") as f:
                f.write(pdb_string)
        else:
            to_cif_file(
                structure,
                self._cif_base_path(output_path),
                file_type=output_type,
                include_entity_poly=False,
            )

    @staticmethod
    def _cif_base_path(output_path: str) -> str:
        """Return ``output_path`` without its extension, for use as the ``to_cif_file`` path.

        The extension is stripped and the format is passed separately via
        ``file_type``, as this block has always done. ``.cif.gz`` is removed as
        a unit. Only the final path component is inspected, so dots in directory
        names (``runs/v1.2/model``) or a leading ``./`` are left intact.
        """
        import os

        if output_path.endswith(".cif.gz"):
            return output_path[: -len(".cif.gz")]
        return os.path.splitext(output_path)[0]

    def _coords_to_pdb(self, xyz: torch.Tensor, sequence: Optional[str] = None) -> str:
        """Render per-residue coordinates as CA-only PDB ATOM records on chain ``A``.

        Args:
            xyz: One row per residue, each unpacked as ``(x, y, z)``. The tensor is
                detached and converted to float32, so grad-tracking or half-precision
                inputs are accepted.
            sequence: Optional one-letter sequence. Positions past its end use
                ``"A"`` (ALA); letters outside ``AA_ORDER`` become ``UNK``.

        Returns:
            Newline-joined ATOM records followed by a final ``END`` line.
        """
        coords = xyz.detach().cpu().float().numpy()
        lines = []
        for i, (x, y, z) in enumerate(coords):
            aa = sequence[i] if sequence and i < len(sequence) else "A"
            aa3 = AA_NAMES_3[AA_ORDER.find(aa)] if aa in AA_ORDER else "UNK"
            # Fixed PDB columns: record 1-6, serial 7-11, atom name 13-16,
            # resName 18-20, chainID 22, resSeq 23-26, x/y/z 31-54,
            # occupancy 55-60, B-factor 61-66, element 77-78.
            lines.append(
                f"ATOM  {i + 1:5d}  CA  {aa3:>3s} A{i + 1:4d}    "
                f"{x:8.3f}{y:8.3f}{z:8.3f}{1.0:6.2f}{0.0:6.2f}          {'C':>2s}"
            )
        lines.append("END")
        return "\n".join(lines)