dn6 HF Staff commited on
Commit
a376829
Β·
verified Β·
1 Parent(s): 9f83200

Upload folder using huggingface_hub

Browse files
README.md ADDED
@@ -0,0 +1,140 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Protein Structure Prediction with Diffusers
2
+
3
+ A [diffusers](https://github.com/huggingface/diffusers) `ModularPipeline` wrapper for [RosettaFold3](https://doi.org/10.1101/2025.08.14.670328) (RF3) β€” a diffusion-based protein structure prediction model that predicts 3D atomic coordinates from amino acid sequences.
4
+
5
+ RF3 relies on [Foundry](https://github.com/RosettaCommons/foundry) for its underlying implementation and [AtomWorks](https://github.com/RosettaCommons/atomworks) for structure I/O. This package adds only the thin wrappers needed for diffusers integration.
6
+
7
+ ## Getting Started
8
+
9
+ ### Installation
10
+
11
+ ```bash
12
+ pip install rc-foundry[all]
13
+ pip install diffusers
14
+ ```
15
+
16
+ ### Running with Diffusers
17
+
18
+ ```python
19
+ import torch
20
+ from diffusers import ModularPipeline
21
+
22
+ pipe = ModularPipeline.from_pretrained("dn6/RosettaFold-3", trust_remote_code=True)
23
+ pipe.load_components(device_map="cuda", torch_dtype=torch.bfloat16, trust_remote_code=True)
24
+
25
+ state = pipe(sequence="MKVLSEGDPWRK...")
26
+ print(state.output.xyz.shape) # [D, L, 3]
27
+ ```
28
+
29
+ ## Workflows
30
+
31
+ | Workflow | Trigger inputs | What runs |
32
+ |----------|---------------|-----------|
33
+ | `fold` | `sequence` | Full structure prediction (recycling trunk + diffusion) |
34
+
35
+ ### Fold a Sequence
36
+
37
+ ```python
38
+ state = pipe(sequence="MKVLSEGDPWRK...", output_type="cif.gz", output_path="prediction")
39
+ print(state.output.atom_array)
40
+ ```
41
+
42
+ ### Full Design Pipeline
43
+
44
+ RF3 is typically used as a validation step after backbone design with [RFdiffusion3](https://huggingface.co/dn6/RFDiffusion-3):
45
+
46
+ ```
47
+ RFD3 (design backbone) β†’ MPNN (design sequence) β†’ RF3 (validate fold)
48
+ ```
49
+
50
+ ```python
51
+ import torch
52
+ from diffusers import AutoModel, ModularPipeline
53
+
54
+ # 1. Design a backbone + sequence
55
+ design_pipe = ModularPipeline.from_pretrained("dn6/RFDiffusion-3", trust_remote_code=True)
56
+ design_pipe.load_components(device_map="cuda", torch_dtype=torch.bfloat16, trust_remote_code=True)
57
+
58
+ mpnn = AutoModel.from_pretrained("dn6/RFDiffusion-3", subfolder="mpnn", trust_remote_code=True)
59
+ design_pipe.update_components(mpnn=mpnn)
60
+
61
+ state = design_pipe(contigs="100", temperature=0.1)
62
+ designed_sequence = state.mpnn_output.designed_sequence
63
+
64
+ # 2. Validate the fold
65
+ fold_pipe = ModularPipeline.from_pretrained("dn6/RosettaFold-3", trust_remote_code=True)
66
+ fold_pipe.load_components(device_map="cuda", torch_dtype=torch.bfloat16, trust_remote_code=True)
67
+
68
+ state = fold_pipe(sequence=designed_sequence, output_type="cif.gz", output_path="prediction")
69
+ ```
70
+
71
+ ## Customizing Workflows
72
+
73
+ ```python
74
+ # Inspect the pipeline structure
75
+ print(pipe.blocks)
76
+
77
+ # Add a custom block
78
+ from diffusers.modular_pipelines import ModularPipelineBlocks, PipelineState
79
+ from diffusers.modular_pipelines.modular_pipeline_utils import InputParam, OutputParam
80
+
81
+ class ComputeRadiusOfGyration(ModularPipelineBlocks):
82
+ @property
83
+ def inputs(self):
84
+ return [InputParam("xyz", required=True)]
85
+
86
+ @property
87
+ def intermediate_outputs(self):
88
+ return [OutputParam("radius_of_gyration")]
89
+
90
+ def __call__(self, components, state):
91
+ block_state = self.get_block_state(state)
92
+ xyz = block_state.xyz
93
+ centroid = xyz.mean(dim=-2, keepdim=True)
94
+ block_state.radius_of_gyration = ((xyz - centroid) ** 2).sum(-1).mean().sqrt()
95
+ self.set_block_state(state, block_state)
96
+ return components, state
97
+
98
+ pipe._blocks.sub_blocks.insert("rog", ComputeRadiusOfGyration(), index=3)
99
+ ```
100
+
101
+ ## Output Types
102
+
103
+ | `output_type` | Additional output | Writes to disk |
104
+ |---|---|---|
105
+ | `"tensor"` | β€” | β€” |
106
+ | `"pdb"` | `pdb_string` | `.pdb` file |
107
+ | `"cif"` | `atom_array`, `atom_array_stack`, `trajectory_stack` | `.cif` via AtomWorks |
108
+ | `"cif.gz"` | Same as `"cif"` | `.cif.gz` compressed |
109
+
110
+ ```python
111
+ # CIF output with AtomArray
112
+ state = pipe(sequence="MKVLSEG...", output_type="cif.gz", output_path="fold_0")
113
+ atom_array = state.output.atom_array
114
+
115
+ # Denoising trajectory
116
+ trajectory = state.output.trajectory_stack
117
+
118
+ # PDB output
119
+ state = pipe(sequence="MKVLSEG...", output_type="pdb", output_path="fold_0.pdb")
120
+ ```
121
+
122
+ ## Model Architecture
123
+
124
+ RF3 is a diffusion model with the same EDM noise schedule as RFdiffusion3 (200 steps), but conditioned on sequence/MSA/template representations from a large recycling trunk:
125
+
126
+ | Component | Subfolder | Description |
127
+ |-----------|-----------|-------------|
128
+ | `transformer` | `transformer/` | `RF3TransformerModel` (366M params) β€” FeatureInitializer + Recycler (48 pairformer blocks) + DiffusionModule (24 transformer blocks) + DistogramHead |
129
+ | `scheduler` | `scheduler/` | `RF3Scheduler` (EDM schedule, gamma_0=0.8) |
130
+
131
+ ## Citation
132
+
133
+ ```bibtex
134
+ @article{corley2025accelerating,
135
+ author = {Corley, Nathaniel and Mathis, Simon and Krishna, Rohith and Bauer, Magnus S and Thompson, Tuscan R and Ahern, Woody and Kazman, Maxwell W and Brent, Rafael I and Didi, Kieran and Kubaney, Andrew and others},
136
+ title = {Accelerating biomolecular modeling with AtomWorks and RF3},
137
+ journal = {bioRxiv},
138
+ year = {2025},
139
+ }
140
+ ```
__init__.py ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 Dhruv Nair. All rights reserved.
2
+ # Licensed under the Apache License, Version 2.0
3
+
4
+ from .transformer import RF3TransformerModel, RF3TransformerOutput
5
+ from .scheduler import RF3Scheduler
6
+ from .modular_blocks import (
7
+ RF3AutoBeforeDenoiseStep,
8
+ RF3AutoBlocks,
9
+ RF3AutoDecodeStep,
10
+ RF3AutoDenoiseStep,
11
+ )
12
+ from .before_denoise import (
13
+ RF3InputStep,
14
+ RF3PrepareLatentsStep,
15
+ RF3RecyclingStep,
16
+ RF3SetTimestepsStep,
17
+ )
18
+ from .denoise import RF3DenoiseStep
19
+ from .decoders import RF3DecodeStep, RF3PipelineOutput
before_denoise.py ADDED
@@ -0,0 +1,223 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 Dhruv Nair. All rights reserved.
2
+ # Licensed under the Apache License, Version 2.0
3
+
4
+ """
5
+ Pre-denoising steps for RF3: input processing, timestep setup, recycling trunk, latent preparation.
6
+ """
7
+
8
+ from typing import List
9
+
10
+ import torch
11
+
12
+ from diffusers.utils import logging
13
+ from diffusers.modular_pipelines import ModularPipeline, ModularPipelineBlocks, PipelineState
14
+ from diffusers.modular_pipelines.modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
15
+
16
+
17
+ logger = logging.get_logger(__name__)
18
+
19
+
20
+ class RF3InputStep(ModularPipelineBlocks):
21
+ """Parse sequence input and prepare feature dict for RF3."""
22
+
23
+ model_name = "rf3"
24
+
25
+ @property
26
+ def description(self) -> str:
27
+ return "Parse sequence and optional MSA/template inputs for structure prediction."
28
+
29
+ @property
30
+ def inputs(self) -> List[InputParam]:
31
+ return [
32
+ InputParam("sequence", required=True, type_hint=str, description="Amino acid sequence (one-letter codes)"),
33
+ InputParam("f", type_hint=dict, description="Pre-built feature dict (overrides sequence)"),
34
+ ]
35
+
36
+ @property
37
+ def intermediate_outputs(self) -> List[OutputParam]:
38
+ return [
39
+ OutputParam("f", type_hint=dict, description="Feature dictionary for RF3"),
40
+ OutputParam("L", type_hint=int, description="Sequence length (num atoms)"),
41
+ OutputParam("I", type_hint=int, description="Num tokens"),
42
+ ]
43
+
44
+ @torch.no_grad()
45
+ def __call__(self, components, state):
46
+ block_state = self.get_block_state(state)
47
+
48
+ f = block_state.f
49
+ sequence = block_state.sequence
50
+
51
+ if f is None:
52
+ # Build minimal feature dict from sequence
53
+ L = len(sequence)
54
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
55
+
56
+ # Map sequence to restype indices
57
+ AA_ORDER = "ARNDCQEGHILKMFPSTWYV"
58
+ restype = torch.zeros(L, 32, device=device)
59
+ for i, aa in enumerate(sequence):
60
+ idx = AA_ORDER.find(aa)
61
+ if idx >= 0:
62
+ restype[i, idx] = 1.0
63
+ else:
64
+ restype[i, 20] = 1.0 # unknown
65
+
66
+ f = {
67
+ "restype": restype,
68
+ "atom_to_token_map": torch.arange(L, device=device),
69
+ "is_ca": torch.ones(L, dtype=torch.bool, device=device),
70
+ "ref_pos": torch.zeros(L, 3, device=device),
71
+ "ref_charge": torch.zeros(L, device=device),
72
+ "ref_mask": torch.ones(L, device=device),
73
+ "ref_element": torch.zeros(L, 128, device=device),
74
+ "ref_atom_name_chars": torch.zeros(L, 4, 64, device=device),
75
+ }
76
+ else:
77
+ L = f.get("ref_element", f.get("restype")).shape[0]
78
+
79
+ block_state.f = f
80
+ block_state.L = L
81
+ block_state.I = L # token count = atom count for CA-only
82
+
83
+ self.set_block_state(state, block_state)
84
+ return components, state
85
+
86
+
87
+ class RF3SetTimestepsStep(ModularPipelineBlocks):
88
+ """Set up EDM noise schedule for RF3."""
89
+
90
+ model_name = "rf3"
91
+
92
+ @property
93
+ def description(self) -> str:
94
+ return "Construct EDM noise schedule for RF3 diffusion sampling."
95
+
96
+ @property
97
+ def expected_components(self) -> List[ComponentSpec]:
98
+ return [ComponentSpec("scheduler", description="RF3 EDM scheduler")]
99
+
100
+ @property
101
+ def inputs(self) -> List[InputParam]:
102
+ return [
103
+ InputParam("num_inference_steps", default=None, type_hint=int),
104
+ InputParam("L", required=True, type_hint=int),
105
+ ]
106
+
107
+ @property
108
+ def intermediate_outputs(self) -> List[OutputParam]:
109
+ return [
110
+ OutputParam("noise_schedule", type_hint=torch.Tensor),
111
+ OutputParam("num_inference_steps", type_hint=int),
112
+ ]
113
+
114
+ @torch.no_grad()
115
+ def __call__(self, components, state):
116
+ block_state = self.get_block_state(state)
117
+
118
+ if hasattr(components, "scheduler") and components.scheduler is not None:
119
+ noise_schedule = components.scheduler.get_noise_schedule()
120
+ else:
121
+ noise_schedule = torch.linspace(160.0 * 16.0, 4e-4 * 16.0, 200)
122
+
123
+ block_state.noise_schedule = noise_schedule
124
+ block_state.num_inference_steps = len(noise_schedule)
125
+
126
+ self.set_block_state(state, block_state)
127
+ return components, state
128
+
129
+
130
+ class RF3RecyclingStep(ModularPipelineBlocks):
131
+ """Run the recycling trunk (pairformer + MSA + templates)."""
132
+
133
+ model_name = "rf3"
134
+
135
+ @property
136
+ def description(self) -> str:
137
+ return "Run RF3 recycling trunk to produce single/pair representations."
138
+
139
+ @property
140
+ def expected_components(self) -> List[ComponentSpec]:
141
+ return [ComponentSpec("transformer", description="RF3 transformer model")]
142
+
143
+ @property
144
+ def inputs(self) -> List[InputParam]:
145
+ return [
146
+ InputParam("f", required=True, type_hint=dict),
147
+ InputParam("n_recycles", default=None, type_hint=int),
148
+ ]
149
+
150
+ @property
151
+ def intermediate_outputs(self) -> List[OutputParam]:
152
+ return [
153
+ OutputParam("single", type_hint=torch.Tensor, description="Single representation [I, c_s]"),
154
+ OutputParam("pair", type_hint=torch.Tensor, description="Pair representation [I, I, c_z]"),
155
+ OutputParam("s_inputs", type_hint=torch.Tensor, description="Input embeddings [I, c_s_inputs]"),
156
+ OutputParam("distogram", type_hint=torch.Tensor, description="Distogram prediction [I, I, bins]"),
157
+ ]
158
+
159
+ @torch.no_grad()
160
+ def __call__(self, components, state):
161
+ block_state = self.get_block_state(state)
162
+
163
+ f = block_state.f
164
+ n_recycles = block_state.n_recycles
165
+
166
+ if hasattr(components, "transformer") and components.transformer is not None:
167
+ output = components.transformer(f=f, n_recycles=n_recycles)
168
+ block_state.single = output.single
169
+ block_state.pair = output.pair
170
+ block_state.distogram = output.distogram
171
+ block_state.s_inputs = None # populated inside forward
172
+ else:
173
+ # Placeholder when no model loaded
174
+ block_state.single = None
175
+ block_state.pair = None
176
+ block_state.distogram = None
177
+ block_state.s_inputs = None
178
+
179
+ self.set_block_state(state, block_state)
180
+ return components, state
181
+
182
+
183
+ class RF3PrepareLatentsStep(ModularPipelineBlocks):
184
+ """Prepare initial noised coordinates for diffusion sampling."""
185
+
186
+ model_name = "rf3"
187
+
188
+ @property
189
+ def description(self) -> str:
190
+ return "Sample initial Gaussian noise scaled by the first noise schedule value."
191
+
192
+ @property
193
+ def inputs(self) -> List[InputParam]:
194
+ return [
195
+ InputParam("generator", type_hint=torch.Generator),
196
+ InputParam("diffusion_batch_size", default=5, type_hint=int),
197
+ InputParam("L", required=True, type_hint=int),
198
+ InputParam("noise_schedule", required=True, type_hint=torch.Tensor),
199
+ ]
200
+
201
+ @property
202
+ def intermediate_outputs(self) -> List[OutputParam]:
203
+ return [
204
+ OutputParam("xyz", type_hint=torch.Tensor, description="Initial noised coords [D, L, 3]"),
205
+ ]
206
+
207
+ @torch.no_grad()
208
+ def __call__(self, components, state):
209
+ block_state = self.get_block_state(state)
210
+
211
+ L = block_state.L
212
+ noise_schedule = block_state.noise_schedule
213
+ D = block_state.diffusion_batch_size or 5
214
+ generator = block_state.generator
215
+
216
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
217
+ c0 = noise_schedule[0]
218
+ xyz = c0 * torch.randn((D, L, 3), device=device, generator=generator)
219
+
220
+ block_state.xyz = xyz
221
+
222
+ self.set_block_state(state, block_state)
223
+ return components, state
decoders.py ADDED
@@ -0,0 +1,169 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 Dhruv Nair. All rights reserved.
2
+ # Licensed under the Apache License, Version 2.0
3
+
4
+ """
5
+ Decode step for RF3 β€” converts denoised coordinates to output structures.
6
+ Supports tensor, PDB, and CIF (via AtomWorks) output formats.
7
+ """
8
+
9
+ from dataclasses import dataclass
10
+ from typing import List, Optional
11
+
12
+ import numpy as np
13
+ import torch
14
+ from atomworks.io.utils.io_utils import to_cif_file
15
+ from biotite.structure import AtomArray, AtomArrayStack, stack
16
+
17
+ from diffusers.utils import logging
18
+ from diffusers.modular_pipelines import ModularPipeline, ModularPipelineBlocks, PipelineState
19
+ from diffusers.modular_pipelines.modular_pipeline_utils import InputParam, OutputParam
20
+
21
+
22
+ logger = logging.get_logger(__name__)
23
+
24
+ AA_ORDER = "ARNDCQEGHILKMFPSTWYV"
25
+ AA_NAMES_3 = [
26
+ "ALA", "ARG", "ASN", "ASP", "CYS", "GLN", "GLU", "GLY", "HIS", "ILE",
27
+ "LEU", "LYS", "MET", "PHE", "PRO", "SER", "THR", "TRP", "TYR", "VAL", "UNK",
28
+ ]
29
+
30
+
31
+ def _build_atom_array(xyz: torch.Tensor, sequence: Optional[str] = None) -> AtomArray:
32
+ xyz_np = xyz.detach().cpu().float().numpy()
33
+ L = xyz_np.shape[0]
34
+ arr = AtomArray(L)
35
+ arr.coord = xyz_np
36
+ arr.atom_name = np.full(L, "CA")
37
+ arr.element = np.full(L, "C")
38
+ arr.chain_id = np.full(L, "A")
39
+ arr.res_id = np.arange(1, L + 1)
40
+ if sequence:
41
+ arr.res_name = np.array([
42
+ AA_NAMES_3[AA_ORDER.find(aa)] if aa in AA_ORDER else "UNK"
43
+ for aa in sequence
44
+ ])
45
+ else:
46
+ arr.res_name = np.full(L, "ALA")
47
+ return arr
48
+
49
+
50
+ def _build_atom_array_stack(xyz: torch.Tensor, sequence: Optional[str] = None) -> AtomArrayStack:
51
+ template = _build_atom_array(xyz[0], sequence)
52
+ arr_stack = stack([template for _ in range(xyz.shape[0])])
53
+ arr_stack.coord = xyz.detach().cpu().float().numpy()
54
+ return arr_stack
55
+
56
+
57
+ @dataclass
58
+ class RF3PipelineOutput:
59
+ """Output class for RF3 pipeline."""
60
+
61
+ xyz: torch.Tensor
62
+ atom_array: Optional[AtomArray] = None
63
+ atom_array_stack: Optional[AtomArrayStack] = None
64
+ trajectory_stack: Optional[AtomArrayStack] = None
65
+ distogram: Optional[torch.Tensor] = None
66
+ sequence: Optional[str] = None
67
+ pdb_string: Optional[str] = None
68
+ trajectory: Optional[List[torch.Tensor]] = None
69
+
70
+
71
+ class RF3DecodeStep(ModularPipelineBlocks):
72
+ """
73
+ Decode step for RF3.
74
+
75
+ Supported ``output_type`` values: ``"tensor"``, ``"pdb"``, ``"cif"``, ``"cif.gz"``.
76
+ """
77
+
78
+ model_name = "rf3"
79
+
80
+ @property
81
+ def description(self) -> str:
82
+ return "Convert predicted coordinates to output format (tensor/PDB/CIF)."
83
+
84
+ @property
85
+ def inputs(self) -> List[InputParam]:
86
+ return [
87
+ InputParam("output_type", default="tensor", type_hint=str),
88
+ InputParam("output_path", type_hint=str),
89
+ InputParam("xyz", required=True, type_hint=torch.Tensor),
90
+ InputParam("sequence", type_hint=str),
91
+ InputParam("distogram", type_hint=torch.Tensor),
92
+ InputParam("trajectory", type_hint=List[torch.Tensor]),
93
+ ]
94
+
95
+ @property
96
+ def intermediate_outputs(self) -> List[OutputParam]:
97
+ return [
98
+ OutputParam("output", type_hint=RF3PipelineOutput),
99
+ ]
100
+
101
+ @torch.no_grad()
102
+ def __call__(self, components: ModularPipeline, state: PipelineState) -> PipelineState:
103
+ block_state = self.get_block_state(state)
104
+
105
+ xyz = block_state.xyz
106
+ sequence = block_state.sequence
107
+ distogram = block_state.distogram
108
+ trajectory = block_state.trajectory
109
+ output_type = block_state.output_type or "tensor"
110
+ output_path = block_state.output_path
111
+
112
+ pdb_string = None
113
+ atom_array = None
114
+ atom_array_stack = None
115
+ trajectory_stack = None
116
+
117
+ if output_type in ("cif", "cif.gz"):
118
+ atom_array = _build_atom_array(xyz[0], sequence)
119
+ if xyz.shape[0] > 1:
120
+ atom_array_stack = _build_atom_array_stack(xyz, sequence)
121
+ if trajectory:
122
+ traj_coords = torch.stack([t[0] for t in trajectory])
123
+ template = _build_atom_array(traj_coords[0], sequence)
124
+ trajectory_stack = stack([template for _ in range(traj_coords.shape[0])])
125
+ trajectory_stack.coord = traj_coords.detach().cpu().float().numpy()
126
+
127
+ if output_type == "pdb":
128
+ pdb_string = self._coords_to_pdb(xyz[0], sequence)
129
+
130
+ if output_path is not None:
131
+ import os
132
+ os.makedirs(os.path.dirname(output_path) or ".", exist_ok=True)
133
+ if output_type in ("cif", "cif.gz"):
134
+ to_write = atom_array_stack if atom_array_stack is not None else atom_array
135
+ base = output_path.rsplit(".", 1)[0] if "." in output_path else output_path
136
+ to_cif_file(to_write, base, file_type=output_type, include_entity_poly=False)
137
+ elif output_type == "pdb" and pdb_string:
138
+ with open(output_path, "w") as f:
139
+ f.write(pdb_string)
140
+
141
+ output = RF3PipelineOutput(
142
+ xyz=xyz,
143
+ atom_array=atom_array,
144
+ atom_array_stack=atom_array_stack,
145
+ trajectory_stack=trajectory_stack,
146
+ distogram=distogram,
147
+ sequence=sequence,
148
+ pdb_string=pdb_string,
149
+ trajectory=trajectory,
150
+ )
151
+
152
+ block_state.output = output
153
+ self.set_block_state(state, block_state)
154
+ return components, state
155
+
156
+ def _coords_to_pdb(self, xyz: torch.Tensor, sequence: Optional[str] = None) -> str:
157
+ xyz_np = xyz.cpu().numpy()
158
+ L = xyz_np.shape[0]
159
+ lines = []
160
+ for i in range(L):
161
+ aa = sequence[i] if sequence and i < len(sequence) else "A"
162
+ aa3 = AA_NAMES_3[AA_ORDER.find(aa)] if aa in AA_ORDER else "UNK"
163
+ x, y, z = xyz_np[i, :]
164
+ lines.append(
165
+ f"ATOM {i+1:5d} CA {aa3:3s} A{i+1:4d} "
166
+ f"{x:8.3f}{y:8.3f}{z:8.3f} 1.00 0.00 C "
167
+ )
168
+ lines.append("END")
169
+ return "\n".join(lines)
denoise.py ADDED
@@ -0,0 +1,135 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 Dhruv Nair. All rights reserved.
2
+ # Licensed under the Apache License, Version 2.0
3
+
4
+ """
5
+ Denoising loop for RF3.
6
+
7
+ Same EDM stochastic sampling as RFD3, but conditioned on trunk
8
+ representations (single S_I, pair Z_II) from the recycling step.
9
+ """
10
+
11
+ from typing import Callable, List
12
+
13
+ import torch
14
+
15
+ from diffusers.utils import logging
16
+ from diffusers.modular_pipelines import ModularPipeline, ModularPipelineBlocks, PipelineState
17
+ from diffusers.modular_pipelines.modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
18
+
19
+
20
+ logger = logging.get_logger(__name__)
21
+
22
+
23
+ class RF3DenoiseStep(ModularPipelineBlocks):
24
+ """
25
+ Iterative denoising step for RF3.
26
+
27
+ Uses trunk representations from the recycling step as conditioning
28
+ for the diffusion module at each denoising step.
29
+ """
30
+
31
+ model_name = "rf3"
32
+
33
+ @property
34
+ def description(self) -> str:
35
+ return "Iteratively denoise protein structure conditioned on sequence/MSA representations."
36
+
37
+ @property
38
+ def expected_components(self) -> List[ComponentSpec]:
39
+ return [
40
+ ComponentSpec("transformer", description="RF3 transformer (provides diffusion_module)"),
41
+ ComponentSpec("scheduler", description="RF3 EDM scheduler"),
42
+ ]
43
+
44
+ @property
45
+ def inputs(self) -> List[InputParam]:
46
+ return [
47
+ InputParam("xyz", required=True, type_hint=torch.Tensor, description="Initial noised coords [D, L, 3]"),
48
+ InputParam("noise_schedule", required=True, type_hint=torch.Tensor),
49
+ InputParam("f", required=True, type_hint=dict, description="Feature dictionary"),
50
+ InputParam("single", type_hint=torch.Tensor, description="Trunk single repr [I, c_s]"),
51
+ InputParam("pair", type_hint=torch.Tensor, description="Trunk pair repr [I, I, c_z]"),
52
+ InputParam("s_inputs", type_hint=torch.Tensor, description="Input embeddings [I, c_s_inputs]"),
53
+ InputParam("callback", type_hint=Callable),
54
+ InputParam("callback_steps", default=1, type_hint=int),
55
+ ]
56
+
57
+ @property
58
+ def intermediate_outputs(self) -> List[OutputParam]:
59
+ return [
60
+ OutputParam("xyz", type_hint=torch.Tensor, description="Denoised coords [D, L, 3]"),
61
+ OutputParam("trajectory", type_hint=List[torch.Tensor]),
62
+ ]
63
+
64
+ @torch.no_grad()
65
+ def __call__(self, components: ModularPipeline, state: PipelineState) -> PipelineState:
66
+ block_state = self.get_block_state(state)
67
+
68
+ xyz = block_state.xyz
69
+ noise_schedule = block_state.noise_schedule
70
+ f = block_state.f
71
+ single = block_state.single
72
+ pair = block_state.pair
73
+ s_inputs = block_state.s_inputs
74
+ callback = block_state.callback
75
+ callback_steps = block_state.callback_steps or 1
76
+
77
+ X_L = xyz.clone()
78
+ D = X_L.shape[0]
79
+ device = X_L.device
80
+ noise_schedule = noise_schedule.to(device)
81
+
82
+ trajectory = []
83
+
84
+ has_transformer = hasattr(components, "transformer") and components.transformer is not None
85
+ has_scheduler = hasattr(components, "scheduler") and components.scheduler is not None
86
+
87
+ for step_num in range(len(noise_schedule) - 1):
88
+ c_t_minus_1 = noise_schedule[step_num]
89
+ c_t = noise_schedule[step_num + 1]
90
+
91
+ # Noise injection
92
+ if has_scheduler:
93
+ X_noisy, t_hat = components.scheduler.add_noise(X_L, c_t_minus_1, c_t)
94
+ else:
95
+ X_noisy = X_L
96
+ t_hat = c_t_minus_1
97
+
98
+ # Model forward (diffusion module conditioned on trunk)
99
+ if has_transformer:
100
+ t_batch = (t_hat.to(device).expand(D) if isinstance(t_hat, torch.Tensor)
101
+ else torch.full((D,), t_hat, device=device))
102
+
103
+ outs = components.transformer.diffusion_module(
104
+ X_noisy_L=X_noisy,
105
+ t=t_batch,
106
+ f=f,
107
+ S_inputs_I=s_inputs,
108
+ S_trunk_I=single,
109
+ Z_trunk_II=pair,
110
+ )
111
+ X_denoised = outs if isinstance(outs, torch.Tensor) else outs.get("X_L", outs)
112
+ else:
113
+ X_denoised = X_noisy
114
+
115
+ # Euler step
116
+ if has_scheduler:
117
+ X_L = components.scheduler.step(
118
+ xyz_pred=X_denoised, xyz_noisy=X_noisy,
119
+ c_t_minus_1=c_t_minus_1, c_t=c_t,
120
+ )
121
+ else:
122
+ delta = (X_noisy - X_denoised) / (t_hat + 1e-8)
123
+ d_t = c_t - t_hat
124
+ X_L = X_noisy + d_t * delta
125
+
126
+ trajectory.append(X_denoised.clone())
127
+
128
+ if callback is not None and step_num % callback_steps == 0:
129
+ callback(step_num, c_t_minus_1, X_L)
130
+
131
+ block_state.xyz = X_L
132
+ block_state.trajectory = trajectory
133
+
134
+ self.set_block_state(state, block_state)
135
+ return components, state
modular_blocks.py ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 Dhruv Nair. All rights reserved.
2
+ # Licensed under the Apache License, Version 2.0
3
+
4
+ from diffusers.utils import logging
5
+ from diffusers.modular_pipelines import AutoPipelineBlocks, SequentialPipelineBlocks
6
+
7
+ from .before_denoise import (
8
+ RF3InputStep,
9
+ RF3PrepareLatentsStep,
10
+ RF3RecyclingStep,
11
+ RF3SetTimestepsStep,
12
+ )
13
+ from .decoders import RF3DecodeStep
14
+ from .denoise import RF3DenoiseStep
15
+
16
+
17
+ logger = logging.get_logger(__name__)
18
+
19
+
20
+ class RF3BeforeDenoiseStep(SequentialPipelineBlocks):
21
+ """Sequential block for pre-denoising: input β†’ timesteps β†’ recycling β†’ latents."""
22
+
23
+ block_classes = [
24
+ RF3InputStep,
25
+ RF3SetTimestepsStep,
26
+ RF3RecyclingStep,
27
+ RF3PrepareLatentsStep,
28
+ ]
29
+ block_names = ["input", "set_timesteps", "recycling", "prepare_latents"]
30
+
31
+ @property
32
+ def description(self):
33
+ return (
34
+ "Before denoise step:\n"
35
+ " - `RF3InputStep` parses sequence and builds feature dict\n"
36
+ " - `RF3SetTimestepsStep` constructs EDM noise schedule\n"
37
+ " - `RF3RecyclingStep` runs trunk recycler (pairformer + MSA + templates)\n"
38
+ " - `RF3PrepareLatentsStep` samples initial noised coordinates\n"
39
+ )
40
+
41
+
42
+ class RF3AutoBeforeDenoiseStep(AutoPipelineBlocks):
43
+ block_classes = [RF3BeforeDenoiseStep]
44
+ block_names = ["fold"]
45
+ block_trigger_inputs = [None]
46
+
47
+ @property
48
+ def description(self):
49
+ return "Before denoise step for RF3 structure prediction."
50
+
51
+
52
+ class RF3AutoDenoiseStep(AutoPipelineBlocks):
53
+ block_classes = [RF3DenoiseStep]
54
+ block_names = ["denoise"]
55
+ block_trigger_inputs = [None]
56
+
57
+ @property
58
+ def description(self) -> str:
59
+ return "Denoise step for RF3 structure prediction."
60
+
61
+
62
+ class RF3AutoDecodeStep(AutoPipelineBlocks):
63
+ block_classes = [RF3DecodeStep]
64
+ block_names = ["decode"]
65
+ block_trigger_inputs = [None]
66
+
67
+ @property
68
+ def description(self):
69
+ return "Decode step for RF3 β€” coordinates to tensor/PDB/CIF."
70
+
71
+
72
+ class RF3AutoBlocks(SequentialPipelineBlocks):
73
+ """Full RF3 structure prediction pipeline."""
74
+
75
+ block_classes = [
76
+ RF3AutoBeforeDenoiseStep,
77
+ RF3AutoDenoiseStep,
78
+ RF3AutoDecodeStep,
79
+ ]
80
+ block_names = [
81
+ "before_denoise",
82
+ "denoise",
83
+ "decoder",
84
+ ]
85
+
86
+ @property
87
+ def description(self):
88
+ return (
89
+ "Modular pipeline for protein structure prediction using RF3.\n"
90
+ "Provide `sequence` to predict a protein's 3D structure."
91
+ )
modular_config.json ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": "RF3AutoBlocks",
3
+ "_diffusers_version": "0.37.0.dev0",
4
+ "auto_map": {
5
+ "ModularPipelineBlocks": "modular_blocks.RF3AutoBlocks"
6
+ }
7
+ }
modular_model_index.json ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_blocks_class_name": "RF3AutoBlocks",
3
+ "_class_name": "ModularPipeline",
4
+ "_diffusers_version": "0.37.0.dev0",
5
+ "transformer": [
6
+ null,
7
+ null,
8
+ {
9
+ "pretrained_model_name_or_path": "dn6/RosettaFold-3",
10
+ "subfolder": "transformer",
11
+ "type_hint": [
12
+ "diffusers",
13
+ "AutoModel"
14
+ ],
15
+ "revision": null,
16
+ "variant": null
17
+ }
18
+ ],
19
+ "scheduler": [
20
+ null,
21
+ null,
22
+ {
23
+ "pretrained_model_name_or_path": "dn6/RosettaFold-3",
24
+ "subfolder": "scheduler",
25
+ "type_hint": [
26
+ "diffusers",
27
+ "AutoModel"
28
+ ],
29
+ "revision": null,
30
+ "variant": null,
31
+ "default_creation_method": "from_config"
32
+ }
33
+ ]
34
+ }
scheduler/__init__.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ # Copyright 2025 Dhruv Nair. All rights reserved.
2
+ # Licensed under the Apache License, Version 2.0
3
+
4
+ from .model import RF3Scheduler
scheduler/config.json ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": "RF3Scheduler",
3
+ "_diffusers_version": "0.37.0.dev0",
4
+ "auto_map": {
5
+ "AutoModel": "model.RF3Scheduler"
6
+ },
7
+ "num_timesteps": 200,
8
+ "sigma_data": 16.0,
9
+ "s_min": 4e-4,
10
+ "s_max": 160.0,
11
+ "p": 7.0,
12
+ "gamma_0": 0.8,
13
+ "gamma_min": 1.0,
14
+ "noise_scale": 1.003,
15
+ "step_scale": 1.5
16
+ }
scheduler/model.py ADDED
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 Dhruv Nair. All rights reserved.
2
+ # Licensed under the Apache License, Version 2.0
3
+
4
+ """
5
+ RF3 Scheduler.
6
+
7
+ A diffusers-compatible wrapper around the foundry EDM noise schedule
8
+ for RF3. Same schedule formula as RFD3 but with gamma_0=0.8 (vs 0.6).
9
+ """
10
+
11
+ from typing import Optional
12
+
13
+ import torch
14
+
15
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
16
+
17
+ from rf3.diffusion_samplers.inference_sampler import SampleDiffusion
18
+
19
+
20
+ class RF3Scheduler(ConfigMixin):
21
+ """
22
+ Diffusers-compatible scheduler wrapping the foundry RF3 EDM sampler.
23
+ """
24
+
25
+ config_name = "config.json"
26
+
27
+ @register_to_config
28
+ def __init__(
29
+ self,
30
+ num_timesteps: int = 200,
31
+ sigma_data: float = 16.0,
32
+ s_min: float = 4e-4,
33
+ s_max: float = 160.0,
34
+ p: float = 7.0,
35
+ gamma_0: float = 0.8,
36
+ gamma_min: float = 1.0,
37
+ noise_scale: float = 1.003,
38
+ step_scale: float = 1.5,
39
+ ):
40
+ self._sampler = SampleDiffusion(
41
+ num_timesteps=num_timesteps,
42
+ min_t=0,
43
+ max_t=1,
44
+ sigma_data=sigma_data,
45
+ s_min=s_min,
46
+ s_max=s_max,
47
+ p=p,
48
+ gamma_0=gamma_0,
49
+ gamma_min=gamma_min,
50
+ noise_scale=noise_scale,
51
+ step_scale=step_scale,
52
+ solver="af3",
53
+ )
54
+
55
+ @property
56
+ def sampler(self) -> SampleDiffusion:
57
+ return self._sampler
58
+
59
+ def get_noise_schedule(self, device: torch.device = None) -> torch.Tensor:
60
+ """Construct the EDM noise schedule."""
61
+ return self._sampler._construct_inference_noise_schedule(
62
+ device=device or torch.device("cpu")
63
+ )
64
+
65
+ def add_noise(
66
+ self,
67
+ xyz: torch.Tensor,
68
+ c_t_minus_1: torch.Tensor,
69
+ c_t: torch.Tensor,
70
+ ) -> tuple[torch.Tensor, torch.Tensor]:
71
+ """Inject stochastic noise before the model call."""
72
+ gamma = self._sampler.gamma_0 if c_t > self._sampler.gamma_min else 0.0
73
+ t_hat = c_t_minus_1 * (gamma + 1.0)
74
+ noise_std = self._sampler.noise_scale * torch.sqrt(t_hat**2 - c_t_minus_1**2)
75
+ epsilon = noise_std * torch.randn_like(xyz)
76
+ return xyz + epsilon, t_hat
77
+
78
+ def step(
79
+ self,
80
+ xyz_pred: torch.Tensor,
81
+ xyz_noisy: torch.Tensor,
82
+ c_t_minus_1: torch.Tensor,
83
+ c_t: torch.Tensor,
84
+ ) -> torch.Tensor:
85
+ """Perform one Euler denoising step."""
86
+ gamma = self._sampler.gamma_0 if c_t > self._sampler.gamma_min else 0.0
87
+ t_hat = c_t_minus_1 * (gamma + 1.0)
88
+ delta = (xyz_noisy - xyz_pred) / t_hat
89
+ d_t = c_t - t_hat
90
+ return xyz_noisy + self._sampler.step_scale * d_t * delta
transformer/__init__.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ # Copyright 2025 Dhruv Nair. All rights reserved.
2
+ # Licensed under the Apache License, Version 2.0
3
+
4
+ from .model import RF3TransformerModel, RF3TransformerOutput
transformer/config.json ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": "RF3TransformerModel",
3
+ "_diffusers_version": "0.37.0.dev0",
4
+ "auto_map": {
5
+ "AutoModel": "model.RF3TransformerModel"
6
+ },
7
+ "c_s": 384,
8
+ "c_z": 128,
9
+ "c_atom": 128,
10
+ "c_atompair": 16,
11
+ "c_s_inputs": 449,
12
+ "c_token": 768,
13
+ "sigma_data": 16.0,
14
+ "n_pairformer_blocks": 48,
15
+ "n_diffusion_blocks": 24,
16
+ "n_atom_encoder_blocks": 3,
17
+ "n_atom_decoder_blocks": 3,
18
+ "n_msa_blocks": 4,
19
+ "n_template_blocks": 2,
20
+ "n_head": 16,
21
+ "n_pairformer_head": 16,
22
+ "n_recycles": 10,
23
+ "distogram_bins": 65,
24
+ "p_drop": 0.25
25
+ }
transformer/diffusion_pytorch_model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0d30d1003e326ef59e8154bc3cf3af928562715fa5410acb63de88c6486ae275
3
+ size 1466334428
transformer/model.py ADDED
@@ -0,0 +1,286 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 Dhruv Nair. All rights reserved.
2
+ # Licensed under the Apache License, Version 2.0
3
+
4
+ """
5
+ RF3 (RosettaFold3) Transformer model.
6
+
7
+ A diffusers-compatible wrapper around the foundry RF3 model components.
8
+ Reuses FeatureInitializer, Recycler, DiffusionModule, and DistogramHead
9
+ from ``rf3.model.*`` directly, adding only the ModelMixin/ConfigMixin
10
+ interface needed for diffusers ModularPipeline integration.
11
+
12
+ RF3 is structurally similar to RFD3 but adds a trunk recycler (48
13
+ pairformer blocks + MSA + templates) for sequence-conditioned folding.
14
+ """
15
+
16
+ from dataclasses import dataclass
17
+ from typing import Optional
18
+
19
+ import torch
20
+ import torch.nn as nn
21
+
22
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
23
+ from diffusers.models.modeling_utils import ModelMixin
24
+
25
+ from rf3.model.RF3_structure import DiffusionModule, DistogramHead, Recycler
26
+ from rf3.model.layers.pairformer_layers import FeatureInitializer
27
+
28
+
29
+ @dataclass
30
+ class RF3TransformerOutput:
31
+ """Output class for RF3 transformer."""
32
+
33
+ xyz: torch.Tensor # [D, L, 3]
34
+ distogram: Optional[torch.Tensor] = None # [I, I, bins]
35
+ single: Optional[torch.Tensor] = None # [I, c_s]
36
+ pair: Optional[torch.Tensor] = None # [I, I, c_z]
37
+ trajectory_noisy: Optional[list] = None # list of [D, L, 3]
38
+ trajectory_denoised: Optional[list] = None # list of [D, L, 3]
39
+
40
+
41
+ class RF3TransformerModel(ModelMixin, ConfigMixin):
42
+ """
43
+ Diffusers-compatible wrapper around the foundry RF3 model.
44
+
45
+ Wraps FeatureInitializer, Recycler, DiffusionModule, and DistogramHead
46
+ to provide a diffusers ModelMixin/ConfigMixin interface.
47
+
48
+ State dict keys match the foundry checkpoint format via the
49
+ ``feature_initializer.*``, ``recycler.*``, ``diffusion_module.*``,
50
+ and ``distogram_head.*`` prefixes.
51
+ """
52
+
53
+ config_name = "config.json"
54
+ _supports_gradient_checkpointing = True
55
+
56
+ @register_to_config
57
+ def __init__(
58
+ self,
59
+ c_s: int = 384,
60
+ c_z: int = 128,
61
+ c_atom: int = 128,
62
+ c_atompair: int = 16,
63
+ c_s_inputs: int = 449,
64
+ c_token: int = 768,
65
+ sigma_data: float = 16.0,
66
+ n_pairformer_blocks: int = 48,
67
+ n_diffusion_blocks: int = 24,
68
+ n_atom_encoder_blocks: int = 3,
69
+ n_atom_decoder_blocks: int = 3,
70
+ n_msa_blocks: int = 4,
71
+ n_template_blocks: int = 2,
72
+ n_head: int = 16,
73
+ n_pairformer_head: int = 16,
74
+ n_recycles: int = 10,
75
+ distogram_bins: int = 65,
76
+ p_drop: float = 0.25,
77
+ ):
78
+ super().__init__()
79
+
80
+ # ── FeatureInitializer ──────────────────────────────────────────
81
+ self.feature_initializer = FeatureInitializer(
82
+ c_s=c_s,
83
+ c_z=c_z,
84
+ c_atom=c_atom,
85
+ c_atompair=c_atompair,
86
+ c_s_inputs=c_s_inputs,
87
+ input_feature_embedder={
88
+ "features": ["restype", "profile", "deletion_mean"],
89
+ "atom_attention_encoder": {
90
+ "c_token": c_s,
91
+ "c_atom_1d_features": 389,
92
+ "c_tokenpair": c_z,
93
+ "use_inv_dist_squared": True,
94
+ "atom_1d_features": [
95
+ "ref_pos", "ref_charge", "ref_mask",
96
+ "ref_element", "ref_atom_name_chars",
97
+ ],
98
+ "atom_transformer": {
99
+ "n_queries": 32,
100
+ "n_keys": 128,
101
+ "diffusion_transformer": {
102
+ "n_block": 3,
103
+ "diffusion_transformer_block": {
104
+ "n_head": 4,
105
+ "no_residual_connection_between_attention_and_transition": True,
106
+ "kq_norm": True,
107
+ },
108
+ },
109
+ },
110
+ },
111
+ },
112
+ relative_position_encoding={"r_max": 32, "s_max": 2},
113
+ )
114
+
115
+ # ── Recycler (trunk) ───────────────────────────────────────────
116
+ self.recycler = Recycler(
117
+ c_s=c_s,
118
+ c_z=c_z,
119
+ n_pairformer_blocks=n_pairformer_blocks,
120
+ pairformer_block={
121
+ "p_drop": p_drop,
122
+ "triangle_multiplication": {"d_hidden": 128},
123
+ "triangle_attention": {"n_head": 4, "d_hidden": 32},
124
+ "attention_pair_bias": {"n_head": n_head},
125
+ },
126
+ template_embedder={
127
+ "n_block": n_template_blocks,
128
+ "raw_template_dim": 108,
129
+ "c": 64,
130
+ "p_drop": p_drop,
131
+ },
132
+ msa_module={
133
+ "n_block": n_msa_blocks,
134
+ "c_m": 64,
135
+ "p_drop_msa": 0.15,
136
+ "p_drop_pair": p_drop,
137
+ "msa_subsample_embedder": {
138
+ "num_sequences": 1024,
139
+ "dim_raw_msa": 34,
140
+ "c_s_inputs": c_s_inputs,
141
+ "c_msa_embed": 64,
142
+ },
143
+ "outer_product": {
144
+ "c_msa_embed": 64,
145
+ "c_outer_product": 32,
146
+ "c_out": c_z,
147
+ },
148
+ "msa_pair_weighted_averaging": {
149
+ "n_heads": 8,
150
+ "c_weighted_average": 32,
151
+ "c_msa_embed": 64,
152
+ "c_z": c_z,
153
+ "separate_gate_for_every_channel": True,
154
+ },
155
+ "msa_transition": {"n": 4, "c": 64},
156
+ "triangle_multiplication_outgoing": {
157
+ "d_pair": c_z, "d_hidden": 128, "bias": True,
158
+ },
159
+ "triangle_multiplication_incoming": {
160
+ "d_pair": c_z, "d_hidden": 128, "bias": True,
161
+ },
162
+ "triangle_attention_starting": {
163
+ "d_pair": c_z, "n_head": 4, "d_hidden": 32, "p_drop": 0.0,
164
+ },
165
+ "triangle_attention_ending": {
166
+ "d_pair": c_z, "n_head": 4, "d_hidden": 32, "p_drop": 0.0,
167
+ },
168
+ "pair_transition": {"n": 4, "c": c_z},
169
+ },
170
+ )
171
+
172
+ # ── DiffusionModule ────────────────────────────────────────────
173
+ self.diffusion_module = DiffusionModule(
174
+ sigma_data=sigma_data,
175
+ c_atom=c_atom,
176
+ c_atompair=c_atompair,
177
+ c_token=c_token,
178
+ c_s=c_s,
179
+ c_z=c_z,
180
+ diffusion_conditioning={
181
+ "c_s_inputs": c_s_inputs,
182
+ "c_t_embed": 256,
183
+ "relative_position_encoding": {"r_max": 32, "s_max": 2},
184
+ },
185
+ atom_attention_encoder={
186
+ "c_tokenpair": c_z,
187
+ "c_atom_1d_features": 389,
188
+ "use_inv_dist_squared": True,
189
+ "atom_1d_features": [
190
+ "ref_pos", "ref_charge", "ref_mask",
191
+ "ref_element", "ref_atom_name_chars",
192
+ ],
193
+ "atom_transformer": {
194
+ "n_queries": 32,
195
+ "n_keys": 128,
196
+ "diffusion_transformer": {
197
+ "n_block": n_atom_encoder_blocks,
198
+ "diffusion_transformer_block": {
199
+ "n_head": 4,
200
+ "no_residual_connection_between_attention_and_transition": True,
201
+ "kq_norm": True,
202
+ },
203
+ },
204
+ },
205
+ "broadcast_trunk_feats_on_1dim_old": False,
206
+ "use_chiral_features": True,
207
+ "no_grad_on_chiral_center": False,
208
+ },
209
+ diffusion_transformer={
210
+ "n_block": n_diffusion_blocks,
211
+ "diffusion_transformer_block": {
212
+ "n_head": n_head,
213
+ "no_residual_connection_between_attention_and_transition": True,
214
+ "kq_norm": True,
215
+ },
216
+ },
217
+ atom_attention_decoder={
218
+ "atom_transformer": {
219
+ "n_queries": 32,
220
+ "n_keys": 128,
221
+ "diffusion_transformer": {
222
+ "n_block": n_atom_decoder_blocks,
223
+ "diffusion_transformer_block": {
224
+ "n_head": 4,
225
+ "no_residual_connection_between_attention_and_transition": True,
226
+ "kq_norm": True,
227
+ },
228
+ },
229
+ },
230
+ },
231
+ )
232
+
233
+ # ── DistogramHead ──────────────────────────────────────────────
234
+ self.distogram_head = DistogramHead(c_z=c_z, bins=distogram_bins)
235
+
236
+ self._n_recycles = n_recycles
237
+
238
+ def forward(
239
+ self,
240
+ f: dict,
241
+ n_recycles: Optional[int] = None,
242
+ diffusion_batch_size: int = 1,
243
+ coord_atom_lvl_to_be_noised: Optional[torch.Tensor] = None,
244
+ ) -> RF3TransformerOutput:
245
+ """
246
+ Forward pass: recycling trunk β†’ diffusion sampling.
247
+
248
+ Args:
249
+ f: Feature dictionary (sequence, MSA, templates, atom features).
250
+ n_recycles: Number of recycling iterations (default: config value).
251
+ diffusion_batch_size: Number of diffusion samples.
252
+ coord_atom_lvl_to_be_noised: Initial coordinates for partial diffusion.
253
+
254
+ Returns:
255
+ RF3TransformerOutput with predicted coordinates and distogram.
256
+ """
257
+ n_recycles = n_recycles or self._n_recycles
258
+
259
+ # Pre-recycle: initialize features
260
+ initialized = self.feature_initializer(f)
261
+ S_inputs_I = initialized["S_inputs_I"]
262
+ S_I = initialized.get("S_init_I", initialized.get("S_I"))
263
+ Z_II = initialized.get("Z_init_II", initialized.get("Z_II"))
264
+
265
+ # Recycling trunk
266
+ for i in range(n_recycles):
267
+ ctx = torch.no_grad() if i < n_recycles - 1 else torch.enable_grad()
268
+ with ctx:
269
+ recycled = self.recycler(
270
+ S_I=S_I,
271
+ Z_II=Z_II,
272
+ S_inputs_I=S_inputs_I,
273
+ f=f,
274
+ )
275
+ S_I = recycled["S_I"]
276
+ Z_II = recycled["Z_II"]
277
+
278
+ # Distogram prediction
279
+ distogram = self.distogram_head(Z_II)
280
+
281
+ return RF3TransformerOutput(
282
+ xyz=torch.zeros(1), # placeholder β€” filled by sampler in denoise step
283
+ distogram=distogram,
284
+ single=S_I,
285
+ pair=Z_II,
286
+ )