Instructions to use dn6/RosettaFold-3 with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Diffusers
How to use dn6/RosettaFold-3 with Diffusers:
pip install -U diffusers transformers accelerate
import torch
from diffusers import DiffusionPipeline

# switch to "mps" for apple devices
pipe = DiffusionPipeline.from_pretrained("dn6/RosettaFold-3", dtype=torch.bfloat16, device_map="cuda")
prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
image = pipe(prompt).images[0]
- Notebooks
- Google Colab
- Kaggle
| # Copyright 2025 Dhruv Nair. All rights reserved. | |
| # Licensed under the Apache License, Version 2.0 | |
| """ | |
| RF3 (RosettaFold3) Transformer model. | |
| A diffusers-compatible wrapper around the foundry RF3 model components. | |
| Reuses FeatureInitializer, Recycler, DiffusionModule, and DistogramHead | |
| from ``rf3.model.*`` directly, adding only the ModelMixin/ConfigMixin | |
| interface needed for diffusers ModularPipeline integration. | |
| RF3 is structurally similar to RFD3 but adds a trunk recycler (48 | |
| pairformer blocks + MSA + templates) for sequence-conditioned folding. | |
| """ | |
| from dataclasses import dataclass | |
| from typing import Optional | |
| import torch | |
| import torch.nn as nn | |
| from diffusers.configuration_utils import ConfigMixin, register_to_config | |
| from diffusers.models.modeling_utils import ModelMixin | |
| from rf3.model.RF3_structure import DiffusionModule, DistogramHead, Recycler | |
| from rf3.model.layers.pairformer_layers import FeatureInitializer | |
@dataclass
class RF3TransformerOutput:
    """Output container for the RF3 transformer.

    Declared as a dataclass so it can be built with keyword arguments,
    which is how ``RF3TransformerModel.forward`` constructs it. Only
    ``xyz`` is required; every other field defaults to ``None``.

    The bracketed shape notes below are carried over from the original
    field comments.
    """

    # Predicted coordinates, [D, L, 3].
    xyz: torch.Tensor
    # Distogram prediction, [I, I, bins].
    distogram: Optional[torch.Tensor] = None
    # Single representation, [I, c_s].
    single: Optional[torch.Tensor] = None
    # Pair representation, [I, I, c_z].
    pair: Optional[torch.Tensor] = None
    # Noisy coordinates along the sampling trajectory, list of [D, L, 3].
    trajectory_noisy: Optional[list] = None
    # Denoised coordinates along the sampling trajectory, list of [D, L, 3].
    trajectory_denoised: Optional[list] = None
class RF3TransformerModel(ModelMixin, ConfigMixin):
    """
    Diffusers-compatible wrapper around the foundry RF3 model.

    Wraps FeatureInitializer, Recycler, DiffusionModule, and DistogramHead
    to provide a diffusers ModelMixin/ConfigMixin interface.

    State dict keys match the foundry checkpoint format via the
    ``feature_initializer.*``, ``recycler.*``, ``diffusion_module.*``,
    and ``distogram_head.*`` prefixes.
    """

    config_name = "config.json"
    _supports_gradient_checkpointing = True

    @staticmethod
    def _atom_transformer_config(n_block: int) -> dict:
        """Return a fresh atom-transformer config dict with ``n_block`` blocks.

        A new dict is built on every call so the three consumers (input
        embedder, diffusion atom encoder, diffusion atom decoder) never
        share a mutable object.
        """
        return {
            "n_queries": 32,
            "n_keys": 128,
            "diffusion_transformer": {
                "n_block": n_block,
                "diffusion_transformer_block": {
                    "n_head": 4,
                    "no_residual_connection_between_attention_and_transition": True,
                    "kq_norm": True,
                },
            },
        }

    # ``register_to_config`` records every constructor argument in
    # ``self.config`` so the model can be saved/reloaded through the
    # diffusers config machinery.
    @register_to_config
    def __init__(
        self,
        c_s: int = 384,
        c_z: int = 128,
        c_atom: int = 128,
        c_atompair: int = 16,
        c_s_inputs: int = 449,
        c_token: int = 768,
        sigma_data: float = 16.0,
        n_pairformer_blocks: int = 48,
        n_diffusion_blocks: int = 24,
        n_atom_encoder_blocks: int = 3,
        n_atom_decoder_blocks: int = 3,
        n_msa_blocks: int = 4,
        n_template_blocks: int = 2,
        n_head: int = 16,
        n_pairformer_head: int = 16,
        n_recycles: int = 10,
        distogram_bins: int = 65,
        p_drop: float = 0.25,
    ):
        """
        Build the four RF3 submodules from scalar hyper-parameters.

        Args:
            c_s: Single-representation channel count.
            c_z: Pair-representation channel count.
            c_atom: Atom-level channel count.
            c_atompair: Atom-pair channel count.
            c_s_inputs: Channel count of the input single features.
            c_token: Token channel count passed to the diffusion module.
            sigma_data: ``sigma_data`` passed to the diffusion module.
            n_pairformer_blocks: Number of pairformer blocks in the recycler.
            n_diffusion_blocks: Number of token-level diffusion transformer blocks.
            n_atom_encoder_blocks: Blocks in the diffusion atom-attention encoder.
            n_atom_decoder_blocks: Blocks in the diffusion atom-attention decoder.
            n_msa_blocks: Number of MSA module blocks.
            n_template_blocks: Number of template embedder blocks.
            n_head: Attention heads of the token-level diffusion transformer.
            n_pairformer_head: Attention heads of the pairformer's
                attention-with-pair-bias layer.
            n_recycles: Default number of recycling iterations used by ``forward``.
            distogram_bins: Number of distogram bins.
            p_drop: Dropout rate used in the pairformer, template embedder
                and the pair track of the MSA module.
        """
        super().__init__()

        # -- FeatureInitializer ------------------------------------------
        self.feature_initializer = FeatureInitializer(
            c_s=c_s,
            c_z=c_z,
            c_atom=c_atom,
            c_atompair=c_atompair,
            c_s_inputs=c_s_inputs,
            input_feature_embedder={
                "features": ["restype", "profile", "deletion_mean"],
                "atom_attention_encoder": {
                    "c_token": c_s,
                    "c_atom_1d_features": 389,
                    "c_tokenpair": c_z,
                    "use_inv_dist_squared": True,
                    "atom_1d_features": [
                        "ref_pos", "ref_charge", "ref_mask",
                        "ref_element", "ref_atom_name_chars",
                    ],
                    # The input embedder's block count is fixed at 3 and is
                    # independent of ``n_atom_encoder_blocks``.
                    "atom_transformer": self._atom_transformer_config(3),
                },
            },
            relative_position_encoding={"r_max": 32, "s_max": 2},
        )

        # -- Recycler (trunk) --------------------------------------------
        self.recycler = Recycler(
            c_s=c_s,
            c_z=c_z,
            n_pairformer_blocks=n_pairformer_blocks,
            pairformer_block={
                "p_drop": p_drop,
                "triangle_multiplication": {"d_hidden": 128},
                "triangle_attention": {"n_head": 4, "d_hidden": 32},
                # Uses the dedicated pairformer head count; previously this
                # was wired to ``n_head`` and ``n_pairformer_head`` was
                # silently ignored (both default to 16).
                "attention_pair_bias": {"n_head": n_pairformer_head},
            },
            template_embedder={
                "n_block": n_template_blocks,
                "raw_template_dim": 108,
                "c": 64,
                "p_drop": p_drop,
            },
            msa_module={
                "n_block": n_msa_blocks,
                "c_m": 64,
                "p_drop_msa": 0.15,
                "p_drop_pair": p_drop,
                "msa_subsample_embedder": {
                    "num_sequences": 1024,
                    "dim_raw_msa": 34,
                    "c_s_inputs": c_s_inputs,
                    "c_msa_embed": 64,
                },
                "outer_product": {
                    "c_msa_embed": 64,
                    "c_outer_product": 32,
                    "c_out": c_z,
                },
                "msa_pair_weighted_averaging": {
                    "n_heads": 8,
                    "c_weighted_average": 32,
                    "c_msa_embed": 64,
                    "c_z": c_z,
                    "separate_gate_for_every_channel": True,
                },
                "msa_transition": {"n": 4, "c": 64},
                "triangle_multiplication_outgoing": {
                    "d_pair": c_z, "d_hidden": 128, "bias": True,
                },
                "triangle_multiplication_incoming": {
                    "d_pair": c_z, "d_hidden": 128, "bias": True,
                },
                "triangle_attention_starting": {
                    "d_pair": c_z, "n_head": 4, "d_hidden": 32, "p_drop": 0.0,
                },
                "triangle_attention_ending": {
                    "d_pair": c_z, "n_head": 4, "d_hidden": 32, "p_drop": 0.0,
                },
                "pair_transition": {"n": 4, "c": c_z},
            },
        )

        # -- DiffusionModule ---------------------------------------------
        self.diffusion_module = DiffusionModule(
            sigma_data=sigma_data,
            c_atom=c_atom,
            c_atompair=c_atompair,
            c_token=c_token,
            c_s=c_s,
            c_z=c_z,
            diffusion_conditioning={
                "c_s_inputs": c_s_inputs,
                "c_t_embed": 256,
                "relative_position_encoding": {"r_max": 32, "s_max": 2},
            },
            atom_attention_encoder={
                "c_tokenpair": c_z,
                "c_atom_1d_features": 389,
                "use_inv_dist_squared": True,
                "atom_1d_features": [
                    "ref_pos", "ref_charge", "ref_mask",
                    "ref_element", "ref_atom_name_chars",
                ],
                "atom_transformer": self._atom_transformer_config(
                    n_atom_encoder_blocks
                ),
                "broadcast_trunk_feats_on_1dim_old": False,
                "use_chiral_features": True,
                "no_grad_on_chiral_center": False,
            },
            diffusion_transformer={
                "n_block": n_diffusion_blocks,
                "diffusion_transformer_block": {
                    "n_head": n_head,
                    "no_residual_connection_between_attention_and_transition": True,
                    "kq_norm": True,
                },
            },
            atom_attention_decoder={
                "atom_transformer": self._atom_transformer_config(
                    n_atom_decoder_blocks
                ),
            },
        )

        # -- DistogramHead -----------------------------------------------
        self.distogram_head = DistogramHead(c_z=c_z, bins=distogram_bins)

        # Kept as a plain attribute (in addition to the registered config
        # entry) for code that reads ``_n_recycles`` directly.
        self._n_recycles = n_recycles

    def forward(
        self,
        f: dict,
        n_recycles: Optional[int] = None,
        diffusion_batch_size: int = 1,
        coord_atom_lvl_to_be_noised: Optional[torch.Tensor] = None,
    ) -> RF3TransformerOutput:
        """
        Run feature initialization, the recycling trunk and the distogram head.

        This method does not call ``diffusion_module``; the returned ``xyz``
        is a placeholder that the sampler is expected to fill in during the
        denoise step.

        Gradients are disabled for every recycling iteration except the
        last one. The last iteration runs under the caller's grad mode, so
        wrapping the call in ``torch.no_grad()`` keeps the whole trunk
        gradient-free.

        Args:
            f: Feature dictionary (sequence, MSA, templates, atom features).
            n_recycles: Number of recycling iterations. ``None`` selects the
                value given at construction; an explicit ``0`` skips the
                recycler and uses the initialized representations as-is.
            diffusion_batch_size: Number of diffusion samples. Accepted for
                interface compatibility; not used by this method.
            coord_atom_lvl_to_be_noised: Initial coordinates for partial
                diffusion. Accepted for interface compatibility; not used
                by this method.

        Returns:
            RF3TransformerOutput with the distogram, the final single and
            pair representations, and a placeholder ``xyz``.
        """
        # ``is None`` rather than ``or`` so an explicit 0 is honoured.
        if n_recycles is None:
            n_recycles = self._n_recycles

        # Pre-recycle: initialize features
        initialized = self.feature_initializer(f)
        S_inputs_I = initialized["S_inputs_I"]
        S_I = initialized.get("S_init_I", initialized.get("S_I"))
        Z_II = initialized.get("Z_init_II", initialized.get("Z_II"))

        # Recycling trunk. Capture the caller's grad mode once so the final
        # iteration inherits it instead of force-enabling autograd.
        caller_grad_enabled = torch.is_grad_enabled()
        last_iteration = n_recycles - 1
        for i in range(n_recycles):
            track_grad = caller_grad_enabled and i == last_iteration
            with torch.set_grad_enabled(track_grad):
                recycled = self.recycler(
                    S_I=S_I,
                    Z_II=Z_II,
                    S_inputs_I=S_inputs_I,
                    f=f,
                )
            S_I = recycled["S_I"]
            Z_II = recycled["Z_II"]

        # Distogram prediction
        distogram = self.distogram_head(Z_II)

        return RF3TransformerOutput(
            xyz=torch.zeros(1),  # placeholder - filled by sampler in denoise step
            distogram=distogram,
            single=S_I,
            pair=Z_II,
        )