dn6's picture
dn6 HF Staff
Upload folder using huggingface_hub
a376829 verified
# Copyright 2025 Dhruv Nair. All rights reserved.
# Licensed under the Apache License, Version 2.0
"""
RF3 (RosettaFold3) Transformer model.
A diffusers-compatible wrapper around the foundry RF3 model components.
Reuses FeatureInitializer, Recycler, DiffusionModule, and DistogramHead
from ``rf3.model.*`` directly, adding only the ModelMixin/ConfigMixin
interface needed for diffusers ModularPipeline integration.
RF3 is structurally similar to RFD3 but adds a trunk recycler (48
pairformer blocks + MSA + templates) for sequence-conditioned folding.
"""
from dataclasses import dataclass
from typing import Optional
import torch
import torch.nn as nn
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.modeling_utils import ModelMixin
from rf3.model.RF3_structure import DiffusionModule, DistogramHead, Recycler
from rf3.model.layers.pairformer_layers import FeatureInitializer
@dataclass
class RF3TransformerOutput:
    """Container for the tensors returned by the RF3 transformer.

    Attributes:
        xyz: Coordinates, shaped ``[D, L, 3]``.
        distogram: Optional distogram tensor, shaped ``[I, I, bins]``.
        single: Optional single representation, shaped ``[I, c_s]``.
        pair: Optional pair representation, shaped ``[I, I, c_z]``.
        trajectory_noisy: Optional list of noisy coordinate snapshots,
            each shaped ``[D, L, 3]``.
        trajectory_denoised: Optional list of denoised coordinate snapshots,
            each shaped ``[D, L, 3]``.
    """

    xyz: torch.Tensor
    distogram: Optional[torch.Tensor] = None
    single: Optional[torch.Tensor] = None
    pair: Optional[torch.Tensor] = None
    trajectory_noisy: Optional[list] = None
    trajectory_denoised: Optional[list] = None
class RF3TransformerModel(ModelMixin, ConfigMixin):
    """
    Diffusers-compatible wrapper around the foundry RF3 model.

    Wraps FeatureInitializer, Recycler, DiffusionModule, and DistogramHead
    to provide a diffusers ModelMixin/ConfigMixin interface.

    State dict keys match the foundry checkpoint format via the
    ``feature_initializer.*``, ``recycler.*``, ``diffusion_module.*``,
    and ``distogram_head.*`` prefixes.

    :meth:`forward` runs only the trunk (feature initialization, recycling
    and the distogram head). ``diffusion_module`` is constructed so that its
    weights are part of the state dict, but :meth:`forward` does not call it.
    """

    config_name = "config.json"
    # NOTE(review): this wrapper implements no checkpointing logic itself;
    # whether enabling gradient checkpointing has any effect depends on the
    # wrapped foundry modules — confirm.
    _supports_gradient_checkpointing = True

    # Per-atom reference features consumed by both atom attention encoders
    # (input embedder and diffusion module).
    _ATOM_1D_FEATURES = (
        "ref_pos",
        "ref_charge",
        "ref_mask",
        "ref_element",
        "ref_atom_name_chars",
    )
    # Width of the concatenated ``_ATOM_1D_FEATURES`` vector. Presumably
    # 3 (pos) + 1 (charge) + 1 (mask) + 128 (element) + 4*64 (name chars);
    # TODO confirm against the foundry featurizer.
    _C_ATOM_1D_FEATURES = 389

    @staticmethod
    def _atom_transformer_config(n_block: int) -> dict:
        """Return a fresh local atom-transformer config with ``n_block`` blocks.

        The input embedder's atom encoder, the diffusion atom encoder and the
        diffusion atom decoder share this layout and differ only in block count.
        """
        return {
            "n_queries": 32,
            "n_keys": 128,
            "diffusion_transformer": {
                "n_block": n_block,
                "diffusion_transformer_block": {
                    "n_head": 4,
                    "no_residual_connection_between_attention_and_transition": True,
                    "kq_norm": True,
                },
            },
        }

    @staticmethod
    def _first_present(features, *keys: str):
        """Return ``features[key]`` for the first of ``keys`` present in ``features``.

        Raises:
            KeyError: If none of ``keys`` is present.
        """
        for key in keys:
            if key in features:
                return features[key]
        raise KeyError(
            f"feature_initializer output has none of the keys {keys}; "
            f"available keys: {list(features)}"
        )

    @register_to_config
    def __init__(
        self,
        c_s: int = 384,
        c_z: int = 128,
        c_atom: int = 128,
        c_atompair: int = 16,
        c_s_inputs: int = 449,
        c_token: int = 768,
        sigma_data: float = 16.0,
        n_pairformer_blocks: int = 48,
        n_diffusion_blocks: int = 24,
        n_atom_encoder_blocks: int = 3,
        n_atom_decoder_blocks: int = 3,
        n_msa_blocks: int = 4,
        n_template_blocks: int = 2,
        n_head: int = 16,
        n_pairformer_head: int = 16,
        n_recycles: int = 10,
        distogram_bins: int = 65,
        p_drop: float = 0.25,
    ):
        """Build the RF3 sub-modules.

        Args:
            c_s: Single representation width.
            c_z: Pair representation width.
            c_atom: Atom channel width (FeatureInitializer, DiffusionModule).
            c_atompair: Atom-pair channel width (FeatureInitializer, DiffusionModule).
            c_s_inputs: Width of the input single features (``S_inputs_I``).
            c_token: Token width inside the DiffusionModule.
            sigma_data: Passed to the DiffusionModule.
            n_pairformer_blocks: Number of pairformer blocks in the recycler trunk.
            n_diffusion_blocks: Blocks in the token-level diffusion transformer.
            n_atom_encoder_blocks: Blocks in the diffusion atom encoder transformer.
            n_atom_decoder_blocks: Blocks in the diffusion atom decoder transformer.
            n_msa_blocks: Blocks in the recycler's MSA module.
            n_template_blocks: Blocks in the recycler's template embedder.
            n_head: Attention heads in the token-level diffusion transformer.
            n_pairformer_head: Attention heads of the pairformer's
                ``attention_pair_bias``.
            n_recycles: Default number of recycling iterations in :meth:`forward`.
            distogram_bins: Number of bins of the distogram head.
            p_drop: Dropout used by the pairformer, the template embedder and
                the MSA module's pair dropout.
        """
        super().__init__()

        # ── FeatureInitializer ──────────────────────────────────────────
        self.feature_initializer = FeatureInitializer(
            c_s=c_s,
            c_z=c_z,
            c_atom=c_atom,
            c_atompair=c_atompair,
            c_s_inputs=c_s_inputs,
            input_feature_embedder={
                "features": ["restype", "profile", "deletion_mean"],
                "atom_attention_encoder": {
                    # The input embedder's atom encoder projects to c_s, not c_token.
                    "c_token": c_s,
                    "c_atom_1d_features": self._C_ATOM_1D_FEATURES,
                    "c_tokenpair": c_z,
                    "use_inv_dist_squared": True,
                    "atom_1d_features": list(self._ATOM_1D_FEATURES),
                    # Block count is fixed here (not tied to n_atom_encoder_blocks).
                    "atom_transformer": self._atom_transformer_config(3),
                },
            },
            relative_position_encoding={"r_max": 32, "s_max": 2},
        )

        # ── Recycler (trunk) ───────────────────────────────────────────
        self.recycler = Recycler(
            c_s=c_s,
            c_z=c_z,
            n_pairformer_blocks=n_pairformer_blocks,
            pairformer_block={
                "p_drop": p_drop,
                "triangle_multiplication": {"d_hidden": 128},
                "triangle_attention": {"n_head": 4, "d_hidden": 32},
                # Pairformer heads are configured separately from the
                # diffusion transformer's ``n_head``.
                "attention_pair_bias": {"n_head": n_pairformer_head},
            },
            template_embedder={
                "n_block": n_template_blocks,
                "raw_template_dim": 108,
                "c": 64,
                "p_drop": p_drop,
            },
            msa_module={
                "n_block": n_msa_blocks,
                "c_m": 64,
                "p_drop_msa": 0.15,
                "p_drop_pair": p_drop,
                "msa_subsample_embedder": {
                    "num_sequences": 1024,
                    "dim_raw_msa": 34,
                    "c_s_inputs": c_s_inputs,
                    "c_msa_embed": 64,
                },
                "outer_product": {
                    "c_msa_embed": 64,
                    "c_outer_product": 32,
                    "c_out": c_z,
                },
                "msa_pair_weighted_averaging": {
                    "n_heads": 8,
                    "c_weighted_average": 32,
                    "c_msa_embed": 64,
                    "c_z": c_z,
                    "separate_gate_for_every_channel": True,
                },
                "msa_transition": {"n": 4, "c": 64},
                "triangle_multiplication_outgoing": {
                    "d_pair": c_z,
                    "d_hidden": 128,
                    "bias": True,
                },
                "triangle_multiplication_incoming": {
                    "d_pair": c_z,
                    "d_hidden": 128,
                    "bias": True,
                },
                "triangle_attention_starting": {
                    "d_pair": c_z,
                    "n_head": 4,
                    "d_hidden": 32,
                    "p_drop": 0.0,
                },
                "triangle_attention_ending": {
                    "d_pair": c_z,
                    "n_head": 4,
                    "d_hidden": 32,
                    "p_drop": 0.0,
                },
                "pair_transition": {"n": 4, "c": c_z},
            },
        )

        # ── DiffusionModule ────────────────────────────────────────────
        self.diffusion_module = DiffusionModule(
            sigma_data=sigma_data,
            c_atom=c_atom,
            c_atompair=c_atompair,
            c_token=c_token,
            c_s=c_s,
            c_z=c_z,
            diffusion_conditioning={
                "c_s_inputs": c_s_inputs,
                "c_t_embed": 256,
                "relative_position_encoding": {"r_max": 32, "s_max": 2},
            },
            atom_attention_encoder={
                "c_tokenpair": c_z,
                "c_atom_1d_features": self._C_ATOM_1D_FEATURES,
                "use_inv_dist_squared": True,
                "atom_1d_features": list(self._ATOM_1D_FEATURES),
                "atom_transformer": self._atom_transformer_config(n_atom_encoder_blocks),
                "broadcast_trunk_feats_on_1dim_old": False,
                "use_chiral_features": True,
                "no_grad_on_chiral_center": False,
            },
            diffusion_transformer={
                "n_block": n_diffusion_blocks,
                "diffusion_transformer_block": {
                    "n_head": n_head,
                    "no_residual_connection_between_attention_and_transition": True,
                    "kq_norm": True,
                },
            },
            atom_attention_decoder={
                "atom_transformer": self._atom_transformer_config(n_atom_decoder_blocks),
            },
        )

        # ── DistogramHead ──────────────────────────────────────────────
        self.distogram_head = DistogramHead(c_z=c_z, bins=distogram_bins)

        self._n_recycles = n_recycles

    def forward(
        self,
        f: dict,
        n_recycles: Optional[int] = None,
        diffusion_batch_size: int = 1,
        coord_atom_lvl_to_be_noised: Optional[torch.Tensor] = None,
    ) -> RF3TransformerOutput:
        """
        Run the trunk: feature initialization → recycling → distogram head.

        Diffusion sampling is not performed here. The returned ``xyz`` is a
        placeholder intended to be filled by the sampler in a denoise step.

        Args:
            f: Feature dictionary (sequence, MSA, templates, atom features),
                passed to ``feature_initializer`` and to every ``recycler`` call.
            n_recycles: Number of recycling iterations. ``None`` uses the
                config value. Must be >= 1.
            diffusion_batch_size: Currently unused; accepted for interface
                compatibility.
            coord_atom_lvl_to_be_noised: Currently unused; accepted for
                interface compatibility.

        Returns:
            RF3TransformerOutput with ``distogram``, ``single`` (``S_I``) and
            ``pair`` (``Z_II``) from the final recycle, plus a placeholder ``xyz``.

        Raises:
            ValueError: If ``n_recycles`` is less than 1.
            KeyError: If the feature initializer output lacks the single or
                pair representation.
        """
        # ``is None`` (not ``or``) so an explicit 0 is not silently replaced
        # by the config default.
        if n_recycles is None:
            n_recycles = self._n_recycles
        if n_recycles < 1:
            raise ValueError(f"n_recycles must be >= 1, got {n_recycles}")

        # Pre-recycle: initialize features
        initialized = self.feature_initializer(f)
        S_inputs_I = initialized["S_inputs_I"]
        S_I = self._first_present(initialized, "S_init_I", "S_I")
        Z_II = self._first_present(initialized, "Z_init_II", "Z_II")

        # Recycling trunk. Earlier recycles never build an autograd graph; the
        # final one does only if the caller has gradients enabled, so an outer
        # ``torch.no_grad()`` is respected instead of being overridden.
        grad_enabled = torch.is_grad_enabled()
        for i in range(n_recycles):
            is_last = i == n_recycles - 1
            with torch.set_grad_enabled(grad_enabled and is_last):
                recycled = self.recycler(
                    S_I=S_I,
                    Z_II=Z_II,
                    S_inputs_I=S_inputs_I,
                    f=f,
                )
            S_I = recycled["S_I"]
            Z_II = recycled["Z_II"]

        # Distogram prediction
        distogram = self.distogram_head(Z_II)

        return RF3TransformerOutput(
            # Placeholder — filled by the sampler in the denoise step.
            xyz=torch.zeros(1, device=Z_II.device),
            distogram=distogram,
            single=S_I,
            pair=Z_II,
        )