# Copyright 2025 Dhruv Nair. All rights reserved.
# Licensed under the Apache License, Version 2.0
"""
RF3 (RosettaFold3) Transformer model.

A diffusers-compatible wrapper around the foundry RF3 model components.
Reuses FeatureInitializer, Recycler, DiffusionModule, and DistogramHead
from ``rf3.model.*`` directly, adding only the ModelMixin/ConfigMixin
interface needed for diffusers ModularPipeline integration.

RF3 is structurally similar to RFD3 but adds a trunk recycler
(48 pairformer blocks + MSA + templates) for sequence-conditioned folding.
"""

from dataclasses import dataclass
from typing import Optional

import torch
import torch.nn as nn
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.modeling_utils import ModelMixin
from rf3.model.RF3_structure import DiffusionModule, DistogramHead, Recycler
from rf3.model.layers.pairformer_layers import FeatureInitializer


@dataclass
class RF3TransformerOutput:
    """Output class for RF3 transformer.

    Only ``xyz`` is required; every other field defaults to ``None``.
    """

    xyz: torch.Tensor  # [D, L, 3]
    distogram: Optional[torch.Tensor] = None  # [I, I, bins]
    single: Optional[torch.Tensor] = None  # [I, c_s]
    pair: Optional[torch.Tensor] = None  # [I, I, c_z]
    trajectory_noisy: Optional[list] = None  # list of [D, L, 3]
    trajectory_denoised: Optional[list] = None  # list of [D, L, 3]


class RF3TransformerModel(ModelMixin, ConfigMixin):
    """
    Diffusers-compatible wrapper around the foundry RF3 model.

    Wraps FeatureInitializer, Recycler, DiffusionModule, and DistogramHead
    to provide a diffusers ModelMixin/ConfigMixin interface. State dict keys
    match the foundry checkpoint format via the ``feature_initializer.*``,
    ``recycler.*``, ``diffusion_module.*``, and ``distogram_head.*`` prefixes.
    """

    config_name = "config.json"
    _supports_gradient_checkpointing = True

    @register_to_config
    def __init__(
        self,
        c_s: int = 384,
        c_z: int = 128,
        c_atom: int = 128,
        c_atompair: int = 16,
        c_s_inputs: int = 449,
        c_token: int = 768,
        sigma_data: float = 16.0,
        n_pairformer_blocks: int = 48,
        n_diffusion_blocks: int = 24,
        n_atom_encoder_blocks: int = 3,
        n_atom_decoder_blocks: int = 3,
        n_msa_blocks: int = 4,
        n_template_blocks: int = 2,
        n_head: int = 16,
        n_pairformer_head: int = 16,
        n_recycles: int = 10,
        distogram_bins: int = 65,
        p_drop: float = 0.25,
    ):
        """
        Build the four RF3 sub-modules from a flat set of hyperparameters.

        Args:
            c_s: Forwarded as ``c_s`` to the feature initializer, recycler and
                diffusion module (also used as the input embedder's
                ``c_token``).
            c_z: Forwarded as ``c_z`` / ``c_tokenpair`` / ``d_pair`` to every
                sub-module that takes a pair width, and to the distogram head.
            c_atom: Forwarded to the feature initializer and diffusion module.
            c_atompair: Forwarded to the feature initializer and diffusion
                module.
            c_s_inputs: Forwarded to the feature initializer, the MSA
                subsample embedder and the diffusion conditioning.
            c_token: Forwarded to the diffusion module.
            sigma_data: Forwarded to the diffusion module.
            n_pairformer_blocks: Number of pairformer blocks in the recycler.
            n_diffusion_blocks: ``n_block`` of the diffusion transformer.
            n_atom_encoder_blocks: ``n_block`` of the diffusion module's atom
                attention encoder. The input embedder's encoder is fixed at 3.
            n_atom_decoder_blocks: ``n_block`` of the diffusion module's atom
                attention decoder.
            n_msa_blocks: ``n_block`` of the recycler's MSA module.
            n_template_blocks: ``n_block`` of the recycler's template embedder.
            n_head: Head count of the diffusion transformer blocks.
            n_pairformer_head: Head count of the pairformer's
                ``attention_pair_bias``.
            n_recycles: Default number of recycler passes used by
                :meth:`forward` when the caller does not override it.
            distogram_bins: ``bins`` of the distogram head.
            p_drop: Dropout rate passed to the pairformer blocks, template
                embedder and the MSA module's pair dropout.
        """
        super().__init__()

        # ── FeatureInitializer ──────────────────────────────────────────
        self.feature_initializer = FeatureInitializer(
            c_s=c_s,
            c_z=c_z,
            c_atom=c_atom,
            c_atompair=c_atompair,
            c_s_inputs=c_s_inputs,
            input_feature_embedder={
                "features": ["restype", "profile", "deletion_mean"],
                "atom_attention_encoder": {
                    "c_token": c_s,
                    "c_atom_1d_features": 389,
                    "c_tokenpair": c_z,
                    "use_inv_dist_squared": True,
                    "atom_1d_features": [
                        "ref_pos",
                        "ref_charge",
                        "ref_mask",
                        "ref_element",
                        "ref_atom_name_chars",
                    ],
                    "atom_transformer": {
                        "n_queries": 32,
                        "n_keys": 128,
                        "diffusion_transformer": {
                            "n_block": 3,
                            "diffusion_transformer_block": {
                                "n_head": 4,
                                "no_residual_connection_between_attention_and_transition": True,
                                "kq_norm": True,
                            },
                        },
                    },
                },
            },
            relative_position_encoding={"r_max": 32, "s_max": 2},
        )

        # ── Recycler (trunk) ───────────────────────────────────────────
        self.recycler = Recycler(
            c_s=c_s,
            c_z=c_z,
            n_pairformer_blocks=n_pairformer_blocks,
            pairformer_block={
                "p_drop": p_drop,
                "triangle_multiplication": {"d_hidden": 128},
                "triangle_attention": {"n_head": 4, "d_hidden": 32},
                # The pairformer has its own head-count knob; it was previously
                # wired to ``n_head`` which left ``n_pairformer_head`` unused.
                "attention_pair_bias": {"n_head": n_pairformer_head},
            },
            template_embedder={
                "n_block": n_template_blocks,
                "raw_template_dim": 108,
                "c": 64,
                "p_drop": p_drop,
            },
            msa_module={
                "n_block": n_msa_blocks,
                "c_m": 64,
                "p_drop_msa": 0.15,
                "p_drop_pair": p_drop,
                "msa_subsample_embedder": {
                    "num_sequences": 1024,
                    "dim_raw_msa": 34,
                    "c_s_inputs": c_s_inputs,
                    "c_msa_embed": 64,
                },
                "outer_product": {
                    "c_msa_embed": 64,
                    "c_outer_product": 32,
                    "c_out": c_z,
                },
                "msa_pair_weighted_averaging": {
                    "n_heads": 8,
                    "c_weighted_average": 32,
                    "c_msa_embed": 64,
                    "c_z": c_z,
                    "separate_gate_for_every_channel": True,
                },
                "msa_transition": {"n": 4, "c": 64},
                "triangle_multiplication_outgoing": {
                    "d_pair": c_z,
                    "d_hidden": 128,
                    "bias": True,
                },
                "triangle_multiplication_incoming": {
                    "d_pair": c_z,
                    "d_hidden": 128,
                    "bias": True,
                },
                "triangle_attention_starting": {
                    "d_pair": c_z,
                    "n_head": 4,
                    "d_hidden": 32,
                    "p_drop": 0.0,
                },
                "triangle_attention_ending": {
                    "d_pair": c_z,
                    "n_head": 4,
                    "d_hidden": 32,
                    "p_drop": 0.0,
                },
                "pair_transition": {"n": 4, "c": c_z},
            },
        )

        # ── DiffusionModule ────────────────────────────────────────────
        self.diffusion_module = DiffusionModule(
            sigma_data=sigma_data,
            c_atom=c_atom,
            c_atompair=c_atompair,
            c_token=c_token,
            c_s=c_s,
            c_z=c_z,
            diffusion_conditioning={
                "c_s_inputs": c_s_inputs,
                "c_t_embed": 256,
                "relative_position_encoding": {"r_max": 32, "s_max": 2},
            },
            atom_attention_encoder={
                "c_tokenpair": c_z,
                "c_atom_1d_features": 389,
                "use_inv_dist_squared": True,
                "atom_1d_features": [
                    "ref_pos",
                    "ref_charge",
                    "ref_mask",
                    "ref_element",
                    "ref_atom_name_chars",
                ],
                "atom_transformer": {
                    "n_queries": 32,
                    "n_keys": 128,
                    "diffusion_transformer": {
                        "n_block": n_atom_encoder_blocks,
                        "diffusion_transformer_block": {
                            "n_head": 4,
                            "no_residual_connection_between_attention_and_transition": True,
                            "kq_norm": True,
                        },
                    },
                },
                "broadcast_trunk_feats_on_1dim_old": False,
                "use_chiral_features": True,
                "no_grad_on_chiral_center": False,
            },
            diffusion_transformer={
                "n_block": n_diffusion_blocks,
                "diffusion_transformer_block": {
                    "n_head": n_head,
                    "no_residual_connection_between_attention_and_transition": True,
                    "kq_norm": True,
                },
            },
            atom_attention_decoder={
                "atom_transformer": {
                    "n_queries": 32,
                    "n_keys": 128,
                    "diffusion_transformer": {
                        "n_block": n_atom_decoder_blocks,
                        "diffusion_transformer_block": {
                            "n_head": 4,
                            "no_residual_connection_between_attention_and_transition": True,
                            "kq_norm": True,
                        },
                    },
                },
            },
        )

        # ── DistogramHead ──────────────────────────────────────────────
        self.distogram_head = DistogramHead(c_z=c_z, bins=distogram_bins)

        self._n_recycles = n_recycles

    def forward(
        self,
        f: dict,
        n_recycles: Optional[int] = None,
        diffusion_batch_size: int = 1,
        coord_atom_lvl_to_be_noised: Optional[torch.Tensor] = None,
    ) -> RF3TransformerOutput:
        """
        Run the recycling trunk and the distogram head.

        The diffusion module is *not* invoked here; ``xyz`` in the returned
        object is a placeholder that the sampler fills in during the denoise
        step.

        Args:
            f: Feature dictionary (sequence, MSA, templates, atom features).
                Passed to the feature initializer and to every recycler call.
            n_recycles: Number of recycling iterations. ``None`` selects the
                value given at construction time. An explicit ``0`` skips the
                recycler and feeds the initialized pair features straight to
                the distogram head.
            diffusion_batch_size: Number of diffusion samples. Accepted for
                interface compatibility; not read by this method.
            coord_atom_lvl_to_be_noised: Initial coordinates for partial
                diffusion. Accepted for interface compatibility; not read by
                this method.

        Returns:
            RF3TransformerOutput carrying the distogram, the final single and
            pair representations, and a placeholder ``xyz``.
        """
        # ``is None`` rather than ``or`` so an explicit 0 is honoured instead
        # of silently falling back to the configured default.
        if n_recycles is None:
            n_recycles = self._n_recycles

        # Pre-recycle: initialize features
        initialized = self.feature_initializer(f)
        S_inputs_I = initialized["S_inputs_I"]
        S_I = initialized.get("S_init_I", initialized.get("S_I"))
        Z_II = initialized.get("Z_init_II", initialized.get("Z_II"))

        # Recycling trunk. Only the final pass may track gradients, and only
        # if the caller has autograd enabled: forcing ``enable_grad`` would
        # override an outer ``torch.no_grad()`` during inference and build a
        # graph for the whole trunk.
        ambient_grad = torch.is_grad_enabled()
        last_pass = n_recycles - 1
        for i in range(n_recycles):
            with torch.set_grad_enabled(ambient_grad and i == last_pass):
                recycled = self.recycler(
                    S_I=S_I,
                    Z_II=Z_II,
                    S_inputs_I=S_inputs_I,
                    f=f,
                )
            S_I = recycled["S_I"]
            Z_II = recycled["Z_II"]

        # Distogram prediction
        distogram = self.distogram_head(Z_II)

        return RF3TransformerOutput(
            xyz=torch.zeros(1),  # placeholder — filled by sampler in denoise step
            distogram=distogram,
            single=S_I,
            pair=Z_II,
        )