Instructions to use dn6/RosettaFold-3 with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Diffusers
How to use dn6/RosettaFold-3 with Diffusers:
pip install -U diffusers transformers accelerate
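import torch
from diffusers import DiffusionPipeline

# switch to "mps" for apple devices
pipe = DiffusionPipeline.from_pretrained("dn6/RosettaFold-3", dtype=torch.bfloat16, device_map="cuda")

# NOTE: this is the generic Diffusers template. RF3 decodes protein coordinates into
# structures, so the text prompt and `.images` access below are placeholders; check
# the model card for the pipeline's actual inputs and outputs.
prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
image = pipe(prompt).images[0]

- Notebooks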
- Google Colab
- Kaggle
| # Copyright 2025 Dhruv Nair. All rights reserved. | |
| # Licensed under the Apache License, Version 2.0 | |
| """ | |
| Decode step for RF3 — converts denoised coordinates to output structures. | |
| Supports tensor, PDB, and CIF (via AtomWorks) output formats. | |
| """ | |
import os
from dataclasses import dataclass
from typing import List, Optional, Tuple

import numpy as np
import torch
from atomworks.io.utils.io_utils import to_cif_file
from biotite.structure import AtomArray, AtomArrayStack, stack
from diffusers.modular_pipelines import ModularPipeline, ModularPipelineBlocks, PipelineState
from diffusers.modular_pipelines.modular_pipeline_utils import InputParam, OutputParam
from diffusers.utils import logging

logger = logging.get_logger(__name__)
# One-letter amino-acid codes. The position of a letter in this string is the
# index of its three-letter name in AA_NAMES_3.
AA_ORDER = "ARNDCQEGHILKMFPSTWYV"

# Three-letter residue names aligned position-for-position with AA_ORDER.
# "UNK" is the final entry, so it is also what index -1 (a failed
# ``AA_ORDER.find``) resolves to.
AA_NAMES_3 = (
    "ALA ARG ASN ASP CYS GLN GLU GLY HIS ILE "
    "LEU LYS MET PHE PRO SER THR TRP TYR VAL UNK"
).split()
def _build_atom_array(xyz: torch.Tensor, sequence: Optional[str] = None) -> AtomArray:
    """Build a single-chain, CA-only biotite ``AtomArray`` from one set of coordinates.

    Args:
        xyz: Per-residue coordinates; ``xyz.shape[0]`` is taken as the residue count L.
            Presumably ``(L, 3)`` CA positions -- TODO confirm with callers. The tensor
            is detached, moved to the CPU and cast to float32 before conversion.
        sequence: Optional one-letter amino-acid sequence. Residue ``i`` is named from
            ``sequence[i]``; letters not in ``AA_ORDER`` become ``"UNK"``. Residues past
            the end of the sequence (or every residue when no sequence is given) are
            named ``"ALA"``, and surplus sequence letters are ignored. This is the same
            naming rule the decode step uses for PDB output.

    Returns:
        An ``AtomArray`` of length L with atom name ``"CA"``, element ``"C"``,
        chain ``"A"`` and residue ids ``1..L``.
    """
    coords = xyz.detach().cpu().float().numpy()
    n_res = coords.shape[0]
    seq = sequence or ""

    def _res_name(i: int) -> str:
        # Pad with ALA beyond the sequence; map unknown letters to UNK.
        if i >= len(seq):
            return "ALA"
        aa = seq[i]
        return AA_NAMES_3[AA_ORDER.index(aa)] if aa in AA_ORDER else "UNK"

    arr = AtomArray(n_res)
    arr.coord = coords
    arr.atom_name = np.full(n_res, "CA")
    arr.element = np.full(n_res, "C")
    arr.chain_id = np.full(n_res, "A")
    arr.res_id = np.arange(1, n_res + 1)
    # Size the names by the coordinate count, not the sequence length: biotite
    # requires every annotation array to match the number of atoms.
    arr.res_name = np.array([_res_name(i) for i in range(n_res)], dtype=str)
    return arr
def _build_atom_array_stack(xyz: torch.Tensor, sequence: Optional[str] = None) -> AtomArrayStack:
    """Build a multi-model biotite ``AtomArrayStack`` from batched coordinates.

    Every model shares the annotations of the first model (built with
    ``_build_atom_array``); the coordinates of all ``xyz.shape[0]`` models are
    then written into the stack in one assignment.

    Args:
        xyz: Batched coordinates indexed as ``xyz[model]``; presumably
            ``(n_models, L, 3)`` -- TODO confirm with callers.
        sequence: Optional one-letter sequence used for residue names.

    Returns:
        An ``AtomArrayStack`` holding ``xyz.shape[0]`` models.
    """
    n_models = xyz.shape[0]
    first_model = _build_atom_array(xyz[0], sequence)
    models = stack([first_model] * n_models)
    models.coord = xyz.detach().cpu().float().numpy()
    return models
@dataclass
class RF3PipelineOutput:
    """Output class for RF3 pipeline.

    A plain dataclass: ``xyz`` is required and every other field defaults to
    ``None`` unless the decode step fills it in.
    """

    # Predicted coordinates exactly as found in the pipeline state.
    xyz: torch.Tensor
    # First model as a CA-only biotite structure (built for CIF output types).
    atom_array: Optional[AtomArray] = None
    # All models as a stack (built for CIF output types when more than one model exists).
    atom_array_stack: Optional[AtomArrayStack] = None
    # Trajectory frames as a stack (built for CIF output types when a trajectory exists).
    trajectory_stack: Optional[AtomArrayStack] = None
    # Distogram passed through from the pipeline state, if any.
    distogram: Optional[torch.Tensor] = None
    # One-letter sequence passed through from the pipeline state, if any.
    sequence: Optional[str] = None
    # PDB text for the first model (``output_type == "pdb"`` only).
    pdb_string: Optional[str] = None
    # Raw per-step coordinate tensors passed through from the pipeline state, if any.
    trajectory: Optional[List[torch.Tensor]] = None
class RF3DecodeStep(ModularPipelineBlocks):
    """
    Decode step for RF3.

    Wraps the predicted coordinates (together with any sequence, distogram and
    trajectory in the block state) in an :class:`RF3PipelineOutput`. Depending on
    ``output_type``, it also builds PDB text or biotite structures and writes them
    to ``output_path``.

    Supported ``output_type`` values: ``"tensor"``, ``"pdb"``, ``"cif"``, ``"cif.gz"``.
    Any other value is logged as a warning and handled like ``"tensor"``.
    """

    model_name = "rf3"

    _SUPPORTED_OUTPUT_TYPES = ("tensor", "pdb", "cif", "cif.gz")
    _CIF_OUTPUT_TYPES = ("cif", "cif.gz")

    # description / inputs / intermediate_outputs are properties to match the
    # ModularPipelineBlocks interface, which reads them as attributes.
    @property
    def description(self) -> str:
        return "Convert predicted coordinates to output format (tensor/PDB/CIF)."

    @property
    def inputs(self) -> List[InputParam]:
        return [
            InputParam("output_type", default="tensor", type_hint=str),
            InputParam("output_path", type_hint=str),
            InputParam("xyz", required=True, type_hint=torch.Tensor),
            InputParam("sequence", type_hint=str),
            InputParam("distogram", type_hint=torch.Tensor),
            InputParam("trajectory", type_hint=List[torch.Tensor]),
        ]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        return [
            OutputParam("output", type_hint=RF3PipelineOutput),
        ]

    @torch.no_grad()
    def __call__(
        self, components: ModularPipeline, state: PipelineState
    ) -> Tuple[ModularPipeline, PipelineState]:
        """Decode ``xyz`` into the requested format and store it as ``output``.

        ``xyz`` is indexed as ``xyz[0]`` for the first model and ``xyz.shape[0]``
        is the model count, so it is presumably ``(n_models, L, 3)`` -- TODO confirm.

        Returns:
            The ``(components, state)`` pair. ``state`` now holds an
            :class:`RF3PipelineOutput` under ``output``.
        """
        block_state = self.get_block_state(state)
        xyz = block_state.xyz
        sequence = block_state.sequence
        trajectory = block_state.trajectory
        output_type = block_state.output_type or "tensor"
        output_path = block_state.output_path

        if output_type not in self._SUPPORTED_OUTPUT_TYPES:
            logger.warning(
                "Unsupported output_type %r (expected one of %s); returning tensor output only.",
                output_type,
                ", ".join(self._SUPPORTED_OUTPUT_TYPES),
            )

        atom_array = atom_array_stack = trajectory_stack = None
        if output_type in self._CIF_OUTPUT_TYPES:
            atom_array = _build_atom_array(xyz[0], sequence)
            if xyz.shape[0] > 1:
                atom_array_stack = _build_atom_array_stack(xyz, sequence)
            if trajectory:
                # One stack member per trajectory step, taking the first model of each.
                traj_coords = torch.stack([frame[0] for frame in trajectory])
                trajectory_stack = _build_atom_array_stack(traj_coords, sequence)

        pdb_string = self._coords_to_pdb(xyz[0], sequence) if output_type == "pdb" else None

        if output_path is not None:
            # Prefer the multi-model stack for CIF when one was built.
            structure = atom_array_stack if atom_array_stack is not None else atom_array
            self._write_output(output_path, output_type, structure, pdb_string)

        block_state.output = RF3PipelineOutput(
            xyz=xyz,
            atom_array=atom_array,
            atom_array_stack=atom_array_stack,
            trajectory_stack=trajectory_stack,
            distogram=block_state.distogram,
            sequence=sequence,
            pdb_string=pdb_string,
            trajectory=trajectory,
        )
        self.set_block_state(state, block_state)
        return components, state

    def _write_output(
        self,
        output_path: str,
        output_type: str,
        structure,
        pdb_string: Optional[str],
    ) -> None:
        """Write the decoded result to ``output_path``; warn when nothing can be written.

        ``structure`` is the ``AtomArray`` or ``AtomArrayStack`` used for CIF output;
        ``pdb_string`` is the text used for PDB output.
        """
        if output_type not in self._CIF_OUTPUT_TYPES and not (output_type == "pdb" and pdb_string):
            logger.warning(
                "output_path %r is ignored for output_type %r; nothing was written.",
                output_path,
                output_type,
            )
            return

        os.makedirs(os.path.dirname(output_path) or ".", exist_ok=True)
        if output_type in self._CIF_OUTPUT_TYPES:
            # to_cif_file receives the path without its extension plus file_type.
            # Only the file-name component is stripped, so dots in directory
            # names (e.g. "./results/model") are preserved.
            to_cif_file(
                structure,
                self._strip_extension(output_path),
                file_type=output_type,
                include_entity_poly=False,
            )
        else:
            with open(output_path, "w", encoding="utf-8") as handle:
                handle.write(pdb_string)

    @staticmethod
    def _strip_extension(path: str) -> str:
        """Return ``path`` without the extension of its last component.

        A trailing ``.gz`` is treated as part of a double extension, so
        ``"out.cif.gz"`` and ``"out.cif"`` both become ``"out"``.
        """
        directory, filename = os.path.split(path)
        if filename.endswith(".gz"):
            filename = filename[: -len(".gz")]
        stem, _ = os.path.splitext(filename)
        return os.path.join(directory, stem)

    def _coords_to_pdb(self, xyz: torch.Tensor, sequence: Optional[str] = None) -> str:
        """Serialise CA-only coordinates as fixed-column PDB ``ATOM`` records plus ``END``.

        Args:
            xyz: Per-residue coordinates iterated as ``(x, y, z)`` rows, i.e.
                presumably ``(L, 3)`` -- TODO confirm. The tensor is detached,
                moved to the CPU and cast to float32 first (numpy has no bfloat16).
            sequence: Optional one-letter sequence. Residues beyond its end are
                written as ALA; letters not in ``AA_ORDER`` as UNK.

        Returns:
            Newline-joined PDB records (chain ``A``, occupancy 1.00, B-factor 0.00).
        """
        coords = xyz.detach().cpu().float().numpy()
        seq = sequence or ""
        lines = []
        for i, (x, y, z) in enumerate(coords):
            aa = seq[i] if i < len(seq) else "A"
            aa3 = AA_NAMES_3[AA_ORDER.find(aa)] if aa in AA_ORDER else "UNK"
            serial = i + 1
            # PDB columns: serial 7-11, name 13-16, resName 18-20, chain 22,
            # resSeq 23-26, x/y/z 31-54, occupancy 55-60, B-factor 61-66, element 77-78.
            lines.append(
                f"ATOM  {serial:5d}  CA  {aa3:3s} A{serial:4d}    "
                f"{x:8.3f}{y:8.3f}{z:8.3f}{1.0:6.2f}{0.0:6.2f}          {'C':>2s}"
            )
        lines.append("END")
        return "\n".join(lines)