dn6 HF Staff commited on
Commit
4900749
Β·
verified Β·
1 Parent(s): 0646043

Upload folder using huggingface_hub

Browse files
README.md ADDED
@@ -0,0 +1,290 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Protein Design with Diffusers
2
+
3
+ A [diffusers](https://github.com/huggingface/diffusers) `ModularPipeline` wrapper for protein design, combining structure generation ([RFdiffusion3](https://www.biorxiv.org/content/10.1101/2025.09.18.676967v2)) and sequence design ([ProteinMPNN](https://www.science.org/doi/10.1126/science.add2187) / [LigandMPNN](https://www.nature.com/articles/s41592-025-02626-1)) into composable, swappable pipeline blocks.
4
+
5
+ All three models β€” RFD3, ProteinMPNN, and LigandMPNN β€” rely on [Foundry](https://github.com/RosettaCommons/foundry) for their underlying implementations and [AtomWorks](https://github.com/RosettaCommons/atomworks) for structure I/O. This package adds only the thin wrappers needed for diffusers integration.
6
+
7
+ ## Getting Started
8
+
9
+ ### Installation
10
+
11
+ ```bash
12
+ # Install foundry (provides model implementations + AtomWorks)
13
+ pip install rc-foundry[all]
14
+
15
+ # Install diffusers with modular pipeline support
16
+ pip install diffusers
17
+ ```
18
+
19
+ ### Running with Diffusers
20
+
21
+ ```python
22
+ import torch
23
+ from diffusers import ModularPipeline
24
+
25
+ # Load the pipeline
26
+ pipe = ModularPipeline.from_pretrained(
27
+ "dn6/RFDiffusion-3",
28
+ trust_remote_code=True,
29
+ )
30
+ pipe.load_components(
31
+ device_map="cuda",
32
+ torch_dtype=torch.bfloat16,
33
+ trust_remote_code=True,
34
+ )
35
+
36
+ # Generate a 100-residue protein backbone
37
+ state = pipe(contigs="100")
38
+
39
+ # Access coordinates
40
+ print(state.output.xyz.shape) # [B, L, 3]
41
+ ```
42
+
43
+ ## Why Diffusers?
44
+
45
+ Wrapping RFdiffusion3 as a diffusers `ModularPipeline` gives you access to the full diffusers ecosystem out of the box:
46
+
47
+ ### CPU Offloading
48
+
49
+ Run large models on limited VRAM by offloading components to CPU when not in use:
50
+
51
+ ```python
52
+ pipe.enable_model_cpu_offload()
53
+ state = pipe(contigs="100")
54
+ ```
55
+
56
+ ### Hub Integration
57
+
58
+ All models are hosted on the Hugging Face Hub. Load by repo ID, share fine-tuned variants, and version your checkpoints:
59
+
60
+ ```python
61
+ pipe = ModularPipeline.from_pretrained("dn6/RFDiffusion-3", trust_remote_code=True)
62
+ ```
63
+
64
+ ### LoRA Fine-Tuning
65
+
66
+ Fine-tune RFdiffusion3 on custom datasets with LoRA β€” supported natively by `ModelMixin`:
67
+
68
+ ```python
69
+ from peft import LoraConfig
70
+
71
+ lora_config = LoraConfig(r=16, lora_alpha=16, target_modules=["to_q", "to_k", "to_v"])
72
+ pipe.transformer.add_adapter(lora_config)
73
+
74
+ # After training
75
+ pipe.transformer.save_pretrained("my-rfd3-lora")
76
+ ```
77
+
78
+ ### Composable Workflows
79
+
80
+ Inspect, swap, and extend pipeline blocks at runtime:
81
+
82
+ ```python
83
+ # Inspect the pipeline
84
+ print(pipe.blocks)
85
+
86
+ # Swap ProteinMPNN for LigandMPNN
87
+ mpnn = AutoModel.from_pretrained("dn6/RFDiffusion-3", subfolder="mpnn_ligand", trust_remote_code=True)
88
+ pipe.update_components(mpnn=mpnn)
89
+
90
+ # Add a custom post-processing block
91
+ from diffusers.modular_pipelines import ModularPipelineBlocks, PipelineState
92
+ from diffusers.modular_pipelines.modular_pipeline_utils import InputParam, OutputParam
93
+
94
+ class ScoreDesignStep(ModularPipelineBlocks):
95
+ @property
96
+ def inputs(self):
97
+ return [InputParam("xyz", required=True)]
98
+
99
+ @property
100
+ def intermediate_outputs(self):
101
+ return [OutputParam("radius_of_gyration")]
102
+
103
+ def __call__(self, components, state):
104
+ block_state = self.get_block_state(state)
105
+ xyz = block_state.xyz
106
+ centroid = xyz.mean(dim=-2, keepdim=True)
107
+ block_state.radius_of_gyration = ((xyz - centroid) ** 2).sum(-1).mean().sqrt()
108
+ self.set_block_state(state, block_state)
109
+ return components, state
110
+
111
+ # Insert after the decoder
112
+ pipe._blocks.sub_blocks.insert("score", ScoreDesignStep(), index=3)
113
+ ```
114
+
115
+ ## Output Types
116
+
117
+ All output types return the base tensors (`xyz`, `sequence_indices`, `sequence_logits`). The `output_type` parameter controls what additional format is produced:
118
+
119
+ | `output_type` | Additional output | Writes to disk |
120
+ |---|---|---|
121
+ | `"tensor"` | β€” | β€” |
122
+ | `"pdb"` | `pdb_string` | `.pdb` file |
123
+ | `"cif"` | `atom_array`, `atom_array_stack`, `trajectory_stack` | `.cif` via AtomWorks |
124
+ | `"cif.gz"` | Same as `"cif"` | `.cif.gz` compressed |
125
+
126
+ CIF outputs use [AtomWorks](https://github.com/RosettaCommons/atomworks) `to_cif_file` and return [biotite](https://www.biotite-python.org/) `AtomArray` / `AtomArrayStack` objects, matching the foundry output format.
127
+
128
+ ```python
129
+ # Save as compressed CIF (matches foundry output format)
130
+ state = pipe(contigs="100", output_type="cif.gz", output_path="design_0")
131
+
132
+ # AtomArray is available directly
133
+ atom_array = state.output.atom_array
134
+ print(atom_array) # biotite AtomArray with CA coords + residue names
135
+
136
+ # Denoising trajectory as AtomArrayStack (one model per step)
137
+ trajectory = state.output.trajectory_stack
138
+
139
+ # PDB string output
140
+ state = pipe(contigs="100", output_type="pdb", output_path="design_0.pdb")
141
+ print(state.output.pdb_string[:200])
142
+ ```
143
+
144
+ ## Models
145
+
146
+ By default, `load_components` loads the RFdiffusion3 transformer and scheduler. MPNN models are optional β€” load them separately when you need sequence design.
147
+
148
+ ### RFdiffusion3 (RFD3)
149
+
150
+ [RFdiffusion3](https://www.biorxiv.org/content/10.1101/2025.09.18.676967v2) is an all-atom generative model that designs protein structures via iterative denoising. Uses an EDM noise schedule with 200 steps. Loaded automatically by `load_components`.
151
+
152
+ | Component | Subfolder | Description |
153
+ |-----------|-----------|-------------|
154
+ | `transformer` | `transformer/` | `RFDiffusionTransformerModel` (168M params) |
155
+ | `scheduler` | `scheduler/` | `RFDiffusionScheduler` (EDM noise schedule + Euler stepping) |
156
+
157
+ ### ProteinMPNN / LigandMPNN
158
+
159
+ [ProteinMPNN](https://www.science.org/doi/10.1126/science.add2187) and [LigandMPNN](https://www.nature.com/articles/s41592-025-02626-1) are inverse-folding models that design amino acid sequences for a given protein backbone. These are **not** loaded by default β€” load them with `AutoModel` and register via `update_components`:
160
+
161
+ ```python
162
+ from diffusers import AutoModel
163
+
164
+ mpnn = AutoModel.from_pretrained("dn6/RFDiffusion-3", subfolder="mpnn", trust_remote_code=True)
165
+ pipe.update_components(mpnn=mpnn)
166
+ ```
167
+
168
+ Three variants are available:
169
+
170
+ | Subfolder | Variant | Params | Description |
171
+ |-----------|---------|--------|-------------|
172
+ | `mpnn/` | ProteinMPNN | 1.66M | Standard protein sequence design |
173
+ | `mpnn_ligand/` | LigandMPNN | 2.62M | Ligand-aware sequence design |
174
+ | `mpnn_soluble/` | SolubleMPNN | 1.66M | Optimized for soluble proteins |
175
+
176
+ ## Workflows
177
+
178
+ The active workflow is selected automatically based on which inputs you provide. Passing `temperature` triggers the MPNN sequence design step; passing `input_xyz` enables motif conditioning.
179
+
180
+ | Workflow | Trigger inputs | What runs |
181
+ |----------|---------------|-----------|
182
+ | `structure_only` | `contigs` | RFdiffusion3 |
183
+ | `structure_and_sequence` | `contigs`, `temperature` | RFdiffusion3 β†’ MPNN |
184
+ | `motif_structure_and_sequence` | `contigs`, `input_xyz`, `temperature` | Motif-conditioned RFdiffusion3 β†’ MPNN |
185
+
186
+ > Workflows that include MPNN require loading an MPNN variant first (see above).
187
+
188
+ You can also select a workflow explicitly:
189
+
190
+ ```python
191
+ workflow = pipe.get_workflow("structure_and_sequence")
192
+ ```
193
+
194
+ ### Structure Only
195
+
196
+ ```python
197
+ state = pipe(contigs="100")
198
+ print(state.output.xyz.shape) # [1, 100, 3]
199
+ ```
200
+
201
+ ### Structure + Sequence Design
202
+
203
+ ```python
204
+ from diffusers import AutoModel
205
+
206
+ # Load an MPNN variant and register it
207
+ mpnn = AutoModel.from_pretrained("dn6/RFDiffusion-3", subfolder="mpnn", trust_remote_code=True)
208
+ pipe.update_components(mpnn=mpnn)
209
+
210
+ # Passing temperature triggers the MPNN step
211
+ state = pipe(contigs="100", temperature=0.1)
212
+ print(state.mpnn_output.designed_sequence) # e.g. "MKVLSEG..."
213
+ ```
214
+
215
+ ### Motif-Conditioned Design
216
+
217
+ ```python
218
+ import torch
219
+
220
+ motif_coords = torch.randn(16, 3) # [N_motif, 3]
221
+ state = pipe(
222
+ contigs="A10-25/50",
223
+ input_xyz=motif_coords,
224
+ temperature=0.1,
225
+ )
226
+ ```
227
+
228
+ ## Full Design Pipeline
229
+
230
+ The three pipelines can be composed into a complete protein design workflow:
231
+
232
+ ```
233
+ RFD3 (design backbone) β†’ MPNN (design sequence) β†’ RF3 (validate fold)
234
+ ```
235
+
236
+ Each is a standalone `ModularPipeline` that can run independently. Here's the full end-to-end flow:
237
+
238
+ ```python
239
+ import torch
240
+ from diffusers import AutoModel, ModularPipeline
241
+
242
+ # 1. Design a backbone with RFdiffusion3
243
+ design_pipe = ModularPipeline.from_pretrained("dn6/RFDiffusion-3", trust_remote_code=True)
244
+ design_pipe.load_components(device_map="cuda", torch_dtype=torch.bfloat16, trust_remote_code=True)
245
+
246
+ mpnn = AutoModel.from_pretrained("dn6/RFDiffusion-3", subfolder="mpnn", trust_remote_code=True)
247
+ design_pipe.update_components(mpnn=mpnn)
248
+
249
+ state = design_pipe(contigs="100", temperature=0.1, output_type="cif.gz", output_path="design")
250
+ designed_sequence = state.mpnn_output.designed_sequence
251
+
252
+ # 2. Validate the design with RF3 (structure prediction)
253
+ fold_pipe = ModularPipeline.from_pretrained("dn6/RosettaFold-3", trust_remote_code=True)
254
+ fold_pipe.load_components(device_map="cuda", torch_dtype=torch.bfloat16, trust_remote_code=True)
255
+
256
+ state = fold_pipe(sequence=designed_sequence, output_type="cif.gz", output_path="prediction")
257
+ ```
258
+
259
+ > [RF3](https://www.biorxiv.org/content/10.1101/2025.08.14.670328) (RosettaFold3) is available as a separate pipeline at [`dn6/RosettaFold-3`](https://huggingface.co/dn6/RosettaFold-3).
260
+
261
+ ## Citation
262
+
263
+ If you use this code, please cite the relevant work:
264
+
265
+ ```bibtex
266
+ @article{butcher2025_rfdiffusion3,
267
+ author = {Butcher, Jasper and Krishna, Rohith and Mitra, Raktim and Brent, Rafael Isaac and Li, Yanjing and Corley, Nathaniel and Kim, Paul T and Funk, Jonathan and Mathis, Simon Valentin and Salike, Saman and Muraishi, Aiko and Eisenach, Helen and Thompson, Tuscan Rock and Chen, Jie and Politanska, Yuliya and Sehgal, Enisha and Coventry, Brian and Zhang, Odin and Qiang, Bo and Didi, Kieran and Kazman, Maxwell and DiMaio, Frank and Baker, David},
268
+ title = {De novo Design of All-atom Biomolecular Interactions with RFdiffusion3},
269
+ journal = {bioRxiv},
270
+ year = {2025},
271
+ doi = {10.1101/2025.09.18.676967},
272
+ }
273
+
274
+ @article{dauparas2022robust,
275
+ author = {Dauparas, Justas and Anishchenko, Ivan and Bennett, Nathaniel and Bai, Hua and Ragotte, Robert J and Milles, Lukas F and Wicky, Basile IM and Courbet, Alexis and de Haas, Rob J and Bethel, Neville and others},
276
+ title = {Robust deep learning--based protein sequence design using ProteinMPNN},
277
+ journal = {Science},
278
+ volume = {378},
279
+ number = {6615},
280
+ pages = {49--56},
281
+ year = {2022},
282
+ }
283
+
284
+ @article{dauparas2025atomic,
285
+ author = {Dauparas, Justas and Lee, Gyu Rie and Pecoraro, Robert and An, Linna and Anishchenko, Ivan and Glasscock, Cameron and Baker, David},
286
+ title = {Atomic context-conditioned protein sequence design using LigandMPNN},
287
+ journal = {Nature Methods},
288
+ year = {2025},
289
+ }
290
+ ```
__init__.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 Dhruv Nair. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from .transformer import RFDiffusionTransformerModel, RFDiffusionTransformerOutput
16
+ from .scheduler import RFDiffusionScheduler
17
+ from .modular_blocks import (
18
+ ALL_BLOCKS,
19
+ AUTO_BLOCKS,
20
+ UNCONDITIONAL_BLOCKS,
21
+ RFDiffusionAutoBeforeDenoiseStep,
22
+ RFDiffusionAutoBlocks,
23
+ RFDiffusionAutoDecodeStep,
24
+ RFDiffusionAutoDenoiseStep,
25
+ )
26
+ from .before_denoise import (
27
+ RFDiffusionInputStep,
28
+ RFDiffusionPrepareLatentsStep,
29
+ RFDiffusionSetTimestepsStep,
30
+ )
31
+ from .denoise import RFDiffusionDenoiseStep
32
+ from .decoders import RFDiffusionDecodeStep, RFDiffusionPipelineOutput
33
+ from .mpnn import MPNNModel, MPNNModelOutput
34
+ from .modular_blocks import MPNNAutoDesignStep, MPNNPipelineOutput, MPNNSequenceDesignStep
before_denoise.py ADDED
@@ -0,0 +1,312 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 Dhruv Nair. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import List, Tuple, Union
16
+
17
+ import torch
18
+
19
+ from diffusers.utils import logging
20
+ from diffusers.modular_pipelines import ModularPipeline, ModularPipelineBlocks, PipelineState
21
+ from diffusers.modular_pipelines.modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
22
+
23
+
24
+ logger = logging.get_logger(__name__)
25
+
26
+
27
+ def parse_contig_string(contig_str: str) -> Tuple[int, List[Tuple[int, int]]]:
28
+ """
29
+ Parse contig specification string.
30
+
31
+ Supports formats like:
32
+ - "100" -> 100 residues to design
33
+ - "50-100" -> random length between 50-100
34
+ - "A10-25/50" -> motif from chain A residues 10-25, plus 50 designed
35
+
36
+ Returns:
37
+ total_length: Total protein length
38
+ motif_ranges: List of (start, end) for motif residues (0-indexed)
39
+ """
40
+ parts = contig_str.split("/")
41
+ total_length = 0
42
+ motif_ranges = []
43
+
44
+ for part in parts:
45
+ part = part.strip()
46
+ if not part:
47
+ continue
48
+
49
+ if part[0].isalpha():
50
+ chain = part[0]
51
+ residue_spec = part[1:]
52
+ if "-" in residue_spec:
53
+ start, end = map(int, residue_spec.split("-"))
54
+ else:
55
+ start = end = int(residue_spec)
56
+ motif_len = end - start + 1
57
+ motif_ranges.append((total_length, total_length + motif_len))
58
+ total_length += motif_len
59
+ else:
60
+ if "-" in part:
61
+ min_len, max_len = map(int, part.split("-"))
62
+ add_len = (min_len + max_len) // 2
63
+ else:
64
+ add_len = int(part)
65
+ total_length += add_len
66
+
67
+ return total_length, motif_ranges
68
+
69
+
70
+ class RFDiffusionInputStep(ModularPipelineBlocks):
71
+ """
72
+ Input processing step for RFDiffusion.
73
+
74
+ Parses contigs to prepare features for structure generation.
75
+ """
76
+
77
+ model_name = "rfdiffusion"
78
+
79
+ @property
80
+ def description(self) -> str:
81
+ return (
82
+ "Input processing step that:\n"
83
+ " 1. Parses contig specification to determine protein length and design regions\n"
84
+ " 2. Generates masks for motif positions\n"
85
+ )
86
+
87
+ @property
88
+ def inputs(self) -> List[InputParam]:
89
+ return [
90
+ InputParam(
91
+ "contigs",
92
+ required=True,
93
+ type_hint=Union[str, List[str]],
94
+ description="Contig specification defining design regions (e.g., '100' or 'A10-25/50-100')",
95
+ ),
96
+ InputParam(
97
+ "input_xyz",
98
+ type_hint=torch.Tensor,
99
+ description="Input coordinates for motif residues [N_motif, 3]",
100
+ ),
101
+ ]
102
+
103
+ @property
104
+ def intermediate_outputs(self) -> List[OutputParam]:
105
+ return [
106
+ OutputParam(
107
+ "motif_mask",
108
+ type_hint=torch.Tensor,
109
+ description="Boolean mask for motif (fixed) positions",
110
+ ),
111
+ OutputParam(
112
+ "motif_xyz",
113
+ type_hint=torch.Tensor,
114
+ description="Coordinates for motif residues",
115
+ ),
116
+ OutputParam(
117
+ "L",
118
+ type_hint=int,
119
+ description="Total length of the protein being designed",
120
+ ),
121
+ OutputParam(
122
+ "batch_size",
123
+ type_hint=int,
124
+ description="Batch size (typically 1 for RFDiffusion)",
125
+ ),
126
+ OutputParam(
127
+ "dtype",
128
+ type_hint=torch.dtype,
129
+ description="Data type for tensors",
130
+ ),
131
+ ]
132
+
133
+ def check_inputs(self, components, block_state):
134
+ if block_state.contigs is None:
135
+ raise ValueError("`contigs` must be provided to specify protein design regions")
136
+
137
+ @torch.no_grad()
138
+ def __call__(self, components: ModularPipeline, state: PipelineState) -> PipelineState:
139
+ block_state = self.get_block_state(state)
140
+ self.check_inputs(components, block_state)
141
+
142
+ contigs = block_state.contigs
143
+ input_xyz = block_state.input_xyz
144
+
145
+ if isinstance(contigs, list):
146
+ contig_str = "/".join(contigs)
147
+ else:
148
+ contig_str = contigs
149
+
150
+ L, motif_ranges = parse_contig_string(contig_str)
151
+
152
+ motif_mask = torch.zeros(L, dtype=torch.bool)
153
+ for start, end in motif_ranges:
154
+ motif_mask[start:end] = True
155
+
156
+ if input_xyz is not None:
157
+ motif_xyz = input_xyz
158
+ else:
159
+ motif_xyz = None
160
+
161
+ block_state.motif_mask = motif_mask
162
+ block_state.motif_xyz = motif_xyz
163
+ block_state.L = L
164
+ block_state.batch_size = 1
165
+ block_state.dtype = torch.float32
166
+
167
+ self.set_block_state(state, block_state)
168
+ return components, state
169
+
170
+
171
+ class RFDiffusionSetTimestepsStep(ModularPipelineBlocks):
172
+ """
173
+ Set up the EDM noise schedule for RFDiffusion3.
174
+ """
175
+
176
+ model_name = "rfdiffusion"
177
+
178
+ @property
179
+ def description(self) -> str:
180
+ return "Sets up the EDM noise schedule matching the original inference sampler."
181
+
182
+ @property
183
+ def expected_components(self) -> List[ComponentSpec]:
184
+ return [
185
+ ComponentSpec("scheduler", description="RFDiffusion3 EDM scheduler"),
186
+ ]
187
+
188
+ @property
189
+ def inputs(self) -> List[InputParam]:
190
+ return [
191
+ InputParam(
192
+ "num_inference_steps",
193
+ default=None,
194
+ type_hint=int,
195
+ description="Number of denoising steps (default: use scheduler config)",
196
+ ),
197
+ InputParam("L", required=True, type_hint=int, description="Protein length"),
198
+ ]
199
+
200
+ @property
201
+ def intermediate_outputs(self) -> List[OutputParam]:
202
+ return [
203
+ OutputParam(
204
+ "noise_schedule",
205
+ type_hint=torch.Tensor,
206
+ description="EDM noise schedule [num_timesteps] from high to low noise",
207
+ ),
208
+ OutputParam(
209
+ "num_inference_steps",
210
+ type_hint=int,
211
+ description="Number of inference steps",
212
+ ),
213
+ ]
214
+
215
+ @torch.no_grad()
216
+ def __call__(self, components: ModularPipeline, state: PipelineState) -> PipelineState:
217
+ block_state = self.get_block_state(state)
218
+
219
+ if hasattr(components, "scheduler") and components.scheduler is not None:
220
+ noise_schedule = components.scheduler.get_noise_schedule()
221
+ else:
222
+ # Fallback: simple linear schedule
223
+ noise_schedule = torch.linspace(160.0 * 16.0, 4e-4 * 16.0, 200)
224
+
225
+ block_state.noise_schedule = noise_schedule
226
+ block_state.num_inference_steps = len(noise_schedule)
227
+
228
+ self.set_block_state(state, block_state)
229
+ return components, state
230
+
231
+
232
+ class RFDiffusionPrepareLatentsStep(ModularPipelineBlocks):
233
+ """
234
+ Prepare initial noised coordinates for RFDiffusion3.
235
+
236
+ Matches the original _get_initial_structure:
237
+ noise = c0 * randn(D, L, 3)
238
+ noise[..., is_motif, :] = 0
239
+ X_L = noise + coord_motif
240
+ """
241
+
242
+ model_name = "rfdiffusion"
243
+
244
+ @property
245
+ def description(self) -> str:
246
+ return (
247
+ "Prepares initial coordinates by sampling Gaussian noise scaled by "
248
+ "the first noise schedule value, matching the original sampler."
249
+ )
250
+
251
+ @property
252
+ def expected_components(self) -> List[ComponentSpec]:
253
+ return [
254
+ ComponentSpec("scheduler", description="RFDiffusion3 EDM scheduler"),
255
+ ComponentSpec("transformer", description="RFDiffusion transformer model"),
256
+ ]
257
+
258
+ @property
259
+ def inputs(self) -> List[InputParam]:
260
+ return [
261
+ InputParam("generator", type_hint=torch.Generator, description="Random generator for reproducibility"),
262
+ InputParam("diffusion_batch_size", default=1, type_hint=int, description="Number of samples to generate in parallel"),
263
+ InputParam("L", required=True, type_hint=int, description="Protein length"),
264
+ InputParam("motif_mask", required=True, type_hint=torch.Tensor),
265
+ InputParam("motif_xyz", type_hint=torch.Tensor),
266
+ InputParam("noise_schedule", required=True, type_hint=torch.Tensor),
267
+ InputParam("dtype", type_hint=torch.dtype),
268
+ ]
269
+
270
+ @property
271
+ def intermediate_outputs(self) -> List[OutputParam]:
272
+ return [
273
+ OutputParam("xyz", type_hint=torch.Tensor, description="Initial noised coordinates [D, L, 3]"),
274
+ ]
275
+
276
+ @torch.no_grad()
277
+ def __call__(self, components: ModularPipeline, state: PipelineState) -> PipelineState:
278
+ block_state = self.get_block_state(state)
279
+
280
+ L = block_state.L
281
+ motif_mask = block_state.motif_mask
282
+ motif_xyz = block_state.motif_xyz
283
+ noise_schedule = block_state.noise_schedule
284
+ dtype = block_state.dtype or torch.float32
285
+ generator = block_state.generator
286
+ D = block_state.diffusion_batch_size or 1
287
+
288
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
289
+
290
+ # Initial noise scaled by first noise level (c0), matching original:
291
+ # noise = c0 * randn(D, L, 3)
292
+ c0 = noise_schedule[0]
293
+ noise = c0 * torch.randn((D, L, 3), dtype=dtype, device=device, generator=generator)
294
+
295
+ # Zero out noise for motif atoms
296
+ if motif_mask is not None:
297
+ noise[:, motif_mask] = 0.0
298
+
299
+ # Build initial coordinates: motif coords + noise
300
+ coord_motif = torch.zeros((D, L, 3), dtype=dtype, device=device)
301
+ if motif_xyz is not None and motif_mask is not None:
302
+ motif_indices = motif_mask.nonzero(as_tuple=True)[0]
303
+ for i, idx in enumerate(motif_indices):
304
+ if i < motif_xyz.shape[0]:
305
+ coord_motif[:, idx] = motif_xyz[i].to(dtype=dtype, device=device)
306
+
307
+ xyz = noise + coord_motif
308
+
309
+ block_state.xyz = xyz
310
+
311
+ self.set_block_state(state, block_state)
312
+ return components, state
decoders.py ADDED
@@ -0,0 +1,281 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 Dhruv Nair. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from dataclasses import dataclass
16
+ from typing import List, Optional
17
+
18
+ import numpy as np
19
+ import torch
20
+ from atomworks.io.utils.io_utils import to_cif_file
21
+ from biotite.structure import AtomArray, AtomArrayStack, stack
22
+
23
+ from diffusers.utils import logging
24
+ from diffusers.modular_pipelines import ModularPipeline, ModularPipelineBlocks, PipelineState
25
+ from diffusers.modular_pipelines.modular_pipeline_utils import InputParam, OutputParam
26
+
27
+
28
+ logger = logging.get_logger(__name__)
29
+
30
+
31
+ AA_NAMES = [
32
+ "ALA", "ARG", "ASN", "ASP", "CYS", "GLN", "GLU", "GLY", "HIS", "ILE",
33
+ "LEU", "LYS", "MET", "PHE", "PRO", "SER", "THR", "TRP", "TYR", "VAL",
34
+ "UNK",
35
+ ]
36
+
37
+
38
+ def _build_atom_array(xyz: torch.Tensor, seq: Optional[torch.Tensor] = None) -> AtomArray:
39
+ """
40
+ Build a biotite AtomArray from CA coordinates and optional sequence.
41
+
42
+ Args:
43
+ xyz: Coordinates for a single structure [L, 3].
44
+ seq: Sequence indices [L] (indexes into AA_NAMES).
45
+ """
46
+ xyz_np = xyz.detach().cpu().float().numpy()
47
+ L = xyz_np.shape[0]
48
+
49
+ arr = AtomArray(L)
50
+ arr.coord = xyz_np
51
+ arr.atom_name = np.full(L, "CA")
52
+ arr.element = np.full(L, "C")
53
+ arr.chain_id = np.full(L, "A")
54
+ arr.res_id = np.arange(1, L + 1)
55
+
56
+ if seq is not None:
57
+ seq_np = seq.detach().cpu().numpy()
58
+ arr.res_name = np.array([
59
+ AA_NAMES[int(idx)] if int(idx) < len(AA_NAMES) else "UNK"
60
+ for idx in seq_np
61
+ ])
62
+ else:
63
+ arr.res_name = np.full(L, "ALA")
64
+
65
+ return arr
66
+
67
+
68
+ def _build_atom_array_stack(
69
+ xyz: torch.Tensor,
70
+ seq: Optional[torch.Tensor] = None,
71
+ ) -> AtomArrayStack:
72
+ """
73
+ Build an AtomArrayStack from batched coordinates [B, L, 3].
74
+
75
+ Matches foundry ``build_stack_from_atom_array_and_batched_coords``.
76
+ """
77
+ template = _build_atom_array(xyz[0], seq[0] if seq is not None else None)
78
+ B = xyz.shape[0]
79
+ arr_stack = stack([template for _ in range(B)])
80
+ arr_stack.coord = xyz.detach().cpu().float().numpy()
81
+ return arr_stack
82
+
83
+
84
+ def _build_trajectory_stack(
85
+ trajectory: List[torch.Tensor],
86
+ seq: Optional[torch.Tensor] = None,
87
+ ) -> AtomArrayStack:
88
+ """
89
+ Build an AtomArrayStack from a denoising trajectory.
90
+
91
+ Each entry is [B, L, 3]; takes the first batch element per step.
92
+ """
93
+ coords = torch.stack([t[0] for t in trajectory]) # [N_steps, L, 3]
94
+ template = _build_atom_array(coords[0], seq[0] if seq is not None else None)
95
+ arr_stack = stack([template for _ in range(coords.shape[0])])
96
+ arr_stack.coord = coords.detach().cpu().float().numpy()
97
+ return arr_stack
98
+
99
+
100
+ @dataclass
101
+ class RFDiffusionPipelineOutput:
102
+ """Output class for RFDiffusion pipeline."""
103
+
104
+ xyz: torch.Tensor
105
+ atom_array: Optional[AtomArray] = None
106
+ atom_array_stack: Optional[AtomArrayStack] = None
107
+ trajectory_stack: Optional[AtomArrayStack] = None
108
+ sequence_indices: Optional[torch.Tensor] = None
109
+ sequence_logits: Optional[torch.Tensor] = None
110
+ single: Optional[torch.Tensor] = None
111
+ pair: Optional[torch.Tensor] = None
112
+ pdb_string: Optional[str] = None
113
+ trajectory: Optional[List[torch.Tensor]] = None
114
+
115
+
116
+ class RFDiffusionDecodeStep(ModularPipelineBlocks):
117
+ """
118
+ Decode step for RFDiffusion.
119
+
120
+ Converts denoised coordinates to final output format.
121
+
122
+ Supported ``output_type`` values:
123
+
124
+ - ``"tensor"`` β€” raw tensors only
125
+ - ``"pdb"`` β€” tensors + PDB format string
126
+ - ``"cif"`` β€” tensors + AtomArray via AtomWorks, writes ``.cif``
127
+ - ``"cif.gz"`` β€” same as ``"cif"`` but compressed
128
+ """
129
+
130
+ model_name = "rfdiffusion"
131
+
132
+ @property
133
+ def description(self) -> str:
134
+ return (
135
+ "Decode step that converts denoised coordinates to final output, "
136
+ "supporting tensor, PDB, and CIF (via AtomWorks) formats."
137
+ )
138
+
139
+ @property
140
+ def inputs(self) -> List[InputParam]:
141
+ return [
142
+ InputParam(
143
+ "output_type",
144
+ default="tensor",
145
+ type_hint=str,
146
+ description="Output format: 'tensor', 'pdb', 'cif', or 'cif.gz'",
147
+ ),
148
+ InputParam(
149
+ "output_path",
150
+ type_hint=str,
151
+ description="Path to save output structure",
152
+ ),
153
+ InputParam("xyz", required=True, type_hint=torch.Tensor, description="Denoised coordinates [B, L, 3]"),
154
+ InputParam("sequence_indices", type_hint=torch.Tensor, description="Predicted sequence [B, L]"),
155
+ InputParam("sequence_logits", type_hint=torch.Tensor, description="Sequence logits [B, L, n_aa]"),
156
+ InputParam("single", type_hint=torch.Tensor, description="Single representation"),
157
+ InputParam("pair", type_hint=torch.Tensor, description="Pair representation"),
158
+ InputParam("trajectory", type_hint=List[torch.Tensor], description="Denoising trajectory"),
159
+ ]
160
+
161
+ @property
162
+ def intermediate_outputs(self) -> List[OutputParam]:
163
+ return [
164
+ OutputParam("output", type_hint=RFDiffusionPipelineOutput, description="Final pipeline output"),
165
+ ]
166
+
167
+ @torch.no_grad()
168
+ def __call__(self, components: ModularPipeline, state: PipelineState) -> PipelineState:
169
+ block_state = self.get_block_state(state)
170
+
171
+ xyz = block_state.xyz
172
+ sequence_indices = block_state.sequence_indices
173
+ sequence_logits = block_state.sequence_logits
174
+ single = block_state.single
175
+ pair = block_state.pair
176
+ trajectory = block_state.trajectory
177
+ output_type = block_state.output_type or "tensor"
178
+ output_path = block_state.output_path
179
+
180
+ pdb_string = None
181
+ atom_array = None
182
+ atom_array_stack = None
183
+ trajectory_stack = None
184
+
185
+ # Build AtomArray for CIF output types
186
+ if output_type in ("cif", "cif.gz"):
187
+ atom_array = _build_atom_array(
188
+ xyz[0], sequence_indices[0] if sequence_indices is not None else None,
189
+ )
190
+ if xyz.shape[0] > 1:
191
+ atom_array_stack = _build_atom_array_stack(xyz, sequence_indices)
192
+ if trajectory:
193
+ trajectory_stack = _build_trajectory_stack(trajectory, sequence_indices)
194
+
195
+ # Build PDB string
196
+ if output_type == "pdb":
197
+ pdb_string = self._coords_to_pdb(
198
+ xyz.squeeze(0),
199
+ sequence_indices.squeeze(0) if sequence_indices is not None else None,
200
+ )
201
+
202
+ # Write to disk
203
+ if output_path is not None:
204
+ import os
205
+ os.makedirs(os.path.dirname(output_path) or ".", exist_ok=True)
206
+
207
+ if output_type in ("cif", "cif.gz"):
208
+ to_write = atom_array_stack if atom_array_stack is not None else atom_array
209
+ base = output_path.rsplit(".", 1)[0] if "." in output_path else output_path
210
+ to_cif_file(to_write, base, file_type=output_type, include_entity_poly=False)
211
+ if trajectory_stack is not None:
212
+ to_cif_file(trajectory_stack, base + "_trajectory", file_type=output_type, include_entity_poly=False)
213
+ elif output_type == "pdb" and pdb_string is not None:
214
+ with open(output_path, "w") as f:
215
+ f.write(pdb_string)
216
+
217
+ output = RFDiffusionPipelineOutput(
218
+ xyz=xyz,
219
+ atom_array=atom_array,
220
+ atom_array_stack=atom_array_stack,
221
+ trajectory_stack=trajectory_stack,
222
+ sequence_indices=sequence_indices,
223
+ sequence_logits=sequence_logits,
224
+ single=single,
225
+ pair=pair,
226
+ pdb_string=pdb_string,
227
+ trajectory=trajectory,
228
+ )
229
+
230
+ block_state.output = output
231
+ self.set_block_state(state, block_state)
232
+ return components, state
233
+
234
+ def _coords_to_pdb(
235
+ self,
236
+ xyz: torch.Tensor,
237
+ seq: Optional[torch.Tensor] = None,
238
+ ) -> str:
239
+ """Convert coordinates to PDB format string."""
240
+ xyz_np = xyz.cpu().numpy()
241
+ L = xyz_np.shape[0]
242
+
243
+ if seq is not None:
244
+ seq_np = seq.cpu().numpy()
245
+ else:
246
+ seq_np = None
247
+
248
+ lines = []
249
+ atom_idx = 1
250
+
251
+ for i in range(L):
252
+ if seq_np is not None:
253
+ aa_idx = int(seq_np[i])
254
+ aa_name = AA_NAMES[aa_idx] if aa_idx < len(AA_NAMES) else "UNK"
255
+ else:
256
+ aa_name = "ALA"
257
+
258
+ if xyz_np.ndim == 2:
259
+ x, y, z = xyz_np[i, :]
260
+ line = (
261
+ f"ATOM {atom_idx:5d} CA {aa_name:3s} A"
262
+ f"{i+1:4d} {x:8.3f}{y:8.3f}{z:8.3f} 1.00 0.00 C "
263
+ )
264
+ lines.append(line)
265
+ atom_idx += 1
266
+ else:
267
+ for j, atom_name in enumerate(["N", "CA", "C"]):
268
+ if j >= xyz_np.shape[1]:
269
+ break
270
+ x, y, z = xyz_np[i, j, :]
271
+
272
+ line = (
273
+ f"ATOM {atom_idx:5d} {atom_name:<3s} {aa_name:3s} A"
274
+ f"{i+1:4d} {x:8.3f}{y:8.3f}{z:8.3f} 1.00 0.00 "
275
+ f"{atom_name[0]:>2s} "
276
+ )
277
+ lines.append(line)
278
+ atom_idx += 1
279
+
280
+ lines.append("END")
281
+ return "\n".join(lines)
denoise.py ADDED
@@ -0,0 +1,193 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 Dhruv Nair. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """
16
+ Denoising loop for RFDiffusion3.
17
+
18
+ Implements the iterative denoising procedure from the original
19
+ inference_sampler.py (SampleDiffusionWithMotif.sample_diffusion_like_af3).
20
+
21
+ The loop iterates over consecutive pairs (c_t_minus_1, c_t) in the noise schedule:
22
+ 1. Inject stochastic noise: t_hat = c_t_minus_1 * (gamma + 1), epsilon ~ N(0, noise_scale * sqrt(t_hat^2 - c_t_minus_1^2))
23
+ 2. Call model: X_denoised = model(X_noisy, t_hat)
24
+ 3. Euler update: X = X_noisy + step_scale * (c_t - t_hat) * (X_noisy - X_denoised) / t_hat
25
+ """
26
+
27
+ from typing import Callable, List
28
+
29
+ import torch
30
+
31
+ from diffusers.utils import logging
32
+ from diffusers.modular_pipelines import ModularPipeline, ModularPipelineBlocks, PipelineState
33
+ from diffusers.modular_pipelines.modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
34
+
35
+
36
+ logger = logging.get_logger(__name__)
37
+
38
+
39
+ class RFDiffusionDenoiseStep(ModularPipelineBlocks):
40
+ """
41
+ Iterative denoising step for RFDiffusion3.
42
+
43
+ Implements the EDM stochastic sampling loop matching the original
44
+ SampleDiffusionWithMotif.sample_diffusion_like_af3.
45
+ """
46
+
47
+ model_name = "rfdiffusion"
48
+
49
+ @property
50
+ def description(self) -> str:
51
+ return (
52
+ "Iteratively denoise protein structure through reverse diffusion. "
53
+ "Uses EDM stochastic sampling with gamma noise injection and step scaling."
54
+ )
55
+
56
+ @property
57
+ def expected_components(self) -> List[ComponentSpec]:
58
+ return [
59
+ ComponentSpec("transformer", description="RFDiffusion transformer for structure prediction"),
60
+ ComponentSpec("scheduler", description="Scheduler for noise injection and stepping"),
61
+ ]
62
+
63
+ @property
64
+ def inputs(self) -> List[InputParam]:
65
+ return [
66
+ InputParam(
67
+ "n_recycle",
68
+ default=None,
69
+ type_hint=int,
70
+ description="Number of recycling iterations (None uses model default)",
71
+ ),
72
+ InputParam(
73
+ "callback",
74
+ type_hint=Callable,
75
+ description="Optional callback function called at each step",
76
+ ),
77
+ InputParam(
78
+ "callback_steps",
79
+ default=1,
80
+ type_hint=int,
81
+ description="Frequency of callback invocation",
82
+ ),
83
+ InputParam("xyz", required=True, type_hint=torch.Tensor, description="Initial noised coordinates [D, L, 3]"),
84
+ InputParam("noise_schedule", required=True, type_hint=torch.Tensor, description="EDM noise schedule"),
85
+ InputParam("motif_mask", required=True, type_hint=torch.Tensor, description="Mask for fixed motif positions"),
86
+ ]
87
+
88
+ @property
89
+ def intermediate_outputs(self) -> List[OutputParam]:
90
+ return [
91
+ OutputParam("xyz", type_hint=torch.Tensor, description="Denoised coordinates [D, L, 3]"),
92
+ OutputParam("single", type_hint=torch.Tensor, description="Single representation"),
93
+ OutputParam("pair", type_hint=torch.Tensor, description="Pair representation"),
94
+ OutputParam("sequence_logits", type_hint=torch.Tensor, description="Predicted sequence logits"),
95
+ OutputParam("sequence_indices", type_hint=torch.Tensor, description="Predicted sequence indices"),
96
+ OutputParam("trajectory", type_hint=List[torch.Tensor], description="Denoising trajectory"),
97
+ ]
98
+
99
+ @torch.no_grad()
100
+ def __call__(self, components: ModularPipeline, state: PipelineState) -> PipelineState:
101
+ block_state = self.get_block_state(state)
102
+
103
+ xyz = block_state.xyz
104
+ noise_schedule = block_state.noise_schedule
105
+ motif_mask = block_state.motif_mask
106
+
107
+ n_recycle = block_state.n_recycle
108
+ callback = block_state.callback
109
+ callback_steps = block_state.callback_steps or 1
110
+
111
+ X_denoised_L_traj = []
112
+ X_L = xyz.clone()
113
+ D = X_L.shape[0]
114
+ device = X_L.device
115
+
116
+ # Ensure all tensors are on the same device as xyz
117
+ noise_schedule = noise_schedule.to(device)
118
+ if motif_mask is not None:
119
+ motif_mask = motif_mask.to(device)
120
+
121
+ single = None
122
+ pair = None
123
+ sequence_logits = None
124
+ sequence_indices = None
125
+
126
+ has_transformer = hasattr(components, "transformer") and components.transformer is not None
127
+ has_scheduler = hasattr(components, "scheduler") and components.scheduler is not None
128
+
129
+ # Iterate over consecutive pairs (c_t_minus_1, c_t) in the noise schedule
130
+ # noise_schedule goes from high noise to low noise
131
+ for step_num in range(len(noise_schedule) - 1):
132
+ c_t_minus_1 = noise_schedule[step_num]
133
+ c_t = noise_schedule[step_num + 1]
134
+
135
+ # Step 1: Inject stochastic noise (matching original sampler)
136
+ if has_scheduler:
137
+ X_noisy_L, t_hat = components.scheduler.add_noise(
138
+ X_L, c_t_minus_1, c_t, motif_mask=motif_mask
139
+ )
140
+ else:
141
+ X_noisy_L = X_L
142
+ t_hat = c_t_minus_1
143
+
144
+ # Step 2: Model forward pass
145
+ if has_transformer:
146
+ # t_hat is a scalar, tile to batch dimension
147
+ t_batch = (t_hat.to(device).expand(D) if isinstance(t_hat, torch.Tensor)
148
+ else torch.full((D,), t_hat, device=device))
149
+
150
+ output = components.transformer(
151
+ xyz_noisy=X_noisy_L,
152
+ t=t_batch,
153
+ motif_mask=motif_mask,
154
+ n_recycle=n_recycle,
155
+ )
156
+
157
+ X_denoised_L = output.xyz
158
+ single = output.single
159
+ pair = output.pair
160
+ sequence_logits = output.sequence_logits
161
+ sequence_indices = output.sequence_indices
162
+ else:
163
+ X_denoised_L = X_noisy_L
164
+
165
+ # Step 3: Euler update with step_scale (matching original sampler)
166
+ if has_scheduler:
167
+ X_L = components.scheduler.step(
168
+ xyz_pred=X_denoised_L,
169
+ xyz_noisy=X_noisy_L,
170
+ c_t_minus_1=c_t_minus_1,
171
+ c_t=c_t,
172
+ motif_mask=motif_mask,
173
+ )
174
+ else:
175
+ # Fallback simple Euler step
176
+ delta_L = (X_noisy_L - X_denoised_L) / (t_hat + 1e-8)
177
+ d_t = c_t - t_hat
178
+ X_L = X_noisy_L + d_t * delta_L
179
+
180
+ X_denoised_L_traj.append(X_denoised_L.clone())
181
+
182
+ if callback is not None and step_num % callback_steps == 0:
183
+ callback(step_num, c_t_minus_1, X_L)
184
+
185
+ block_state.xyz = X_L
186
+ block_state.single = single
187
+ block_state.pair = pair
188
+ block_state.sequence_logits = sequence_logits
189
+ block_state.sequence_indices = sequence_indices
190
+ block_state.trajectory = X_denoised_L_traj
191
+
192
+ self.set_block_state(state, block_state)
193
+ return components, state
modular_blocks.py ADDED
@@ -0,0 +1,413 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 Dhruv Nair. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from dataclasses import dataclass
16
+ from typing import List, Optional
17
+
18
+ import torch
19
+
20
+ from diffusers.utils import logging
21
+ from diffusers.modular_pipelines import (
22
+ AutoPipelineBlocks,
23
+ ModularPipeline,
24
+ ModularPipelineBlocks,
25
+ PipelineState,
26
+ SequentialPipelineBlocks,
27
+ )
28
+ from diffusers.modular_pipelines.modular_pipeline_utils import ComponentSpec, InputParam, InsertableDict, OutputParam
29
+
30
+ from .before_denoise import (
31
+ RFDiffusionInputStep,
32
+ RFDiffusionPrepareLatentsStep,
33
+ RFDiffusionSetTimestepsStep,
34
+ )
35
+ from .decoders import RFDiffusionDecodeStep
36
+ from .denoise import RFDiffusionDenoiseStep
37
+
38
+
39
+ logger = logging.get_logger(__name__)
40
+
41
+
42
+ # ─── Amino acid mappings (used by MPNN blocks) ─────────────────────────
43
+
44
+ THREE_TO_ONE = {
45
+ "ALA": "A", "ARG": "R", "ASN": "N", "ASP": "D", "CYS": "C",
46
+ "GLN": "Q", "GLU": "E", "GLY": "G", "HIS": "H", "ILE": "I",
47
+ "LEU": "L", "LYS": "K", "MET": "M", "PHE": "F", "PRO": "P",
48
+ "SER": "S", "THR": "T", "TRP": "W", "TYR": "Y", "VAL": "V",
49
+ "UNK": "X",
50
+ }
51
+
52
+ AA_NAMES = list(THREE_TO_ONE.keys())
53
+
54
+
55
+ # ═══════════════════════════════════════════════════════════════════════════
56
+ # RFDiffusion blocks
57
+ # ═══════════════════════════════════════════════════════════════════════════
58
+
59
+
60
+ class RFDiffusionBeforeDenoiseStep(SequentialPipelineBlocks):
61
+ """Sequential block for pre-denoising preparation."""
62
+
63
+ block_classes = [
64
+ RFDiffusionInputStep,
65
+ RFDiffusionSetTimestepsStep,
66
+ RFDiffusionPrepareLatentsStep,
67
+ ]
68
+ block_names = ["input", "set_timesteps", "prepare_latents"]
69
+
70
+ @property
71
+ def description(self):
72
+ return (
73
+ "Before denoise step that prepares the inputs for the denoise step.\n"
74
+ "This is a sequential pipeline blocks:\n"
75
+ " - `RFDiffusionInputStep` processes contigs and prepares input features\n"
76
+ " - `RFDiffusionSetTimestepsStep` sets up the diffusion timesteps\n"
77
+ " - `RFDiffusionPrepareLatentsStep` initializes noised coordinates\n"
78
+ )
79
+
80
+
81
+ class RFDiffusionAutoBeforeDenoiseStep(AutoPipelineBlocks):
82
+ """Auto-select before denoise step based on task."""
83
+
84
+ block_classes = [RFDiffusionBeforeDenoiseStep]
85
+ block_names = ["unconditional"]
86
+ block_trigger_inputs = [None]
87
+
88
+ @property
89
+ def description(self):
90
+ return (
91
+ "Before denoise step that prepares the inputs for the denoise step.\n"
92
+ "This is an auto pipeline block for protein structure generation.\n"
93
+ " - `RFDiffusionBeforeDenoiseStep` (unconditional) is used.\n"
94
+ )
95
+
96
+
97
+ class RFDiffusionAutoDenoiseStep(AutoPipelineBlocks):
98
+ """Auto-select denoise step."""
99
+
100
+ block_classes = [RFDiffusionDenoiseStep]
101
+ block_names = ["denoise"]
102
+ block_trigger_inputs = [None]
103
+
104
+ @property
105
+ def description(self) -> str:
106
+ return (
107
+ "Denoise step that iteratively denoises the protein structure. "
108
+ "This is an auto pipeline block for protein structure generation. "
109
+ " - `RFDiffusionDenoiseStep` (denoise) for structure generation."
110
+ )
111
+
112
+
113
+ class RFDiffusionAutoDecodeStep(AutoPipelineBlocks):
114
+ """Auto-select decode step."""
115
+
116
+ block_classes = [RFDiffusionDecodeStep]
117
+ block_names = ["decode"]
118
+ block_trigger_inputs = [None]
119
+
120
+ @property
121
+ def description(self):
122
+ return "Decode step that converts denoised coordinates to PDB output.\n - `RFDiffusionDecodeStep`"
123
+
124
+
125
+ # ═══════════════════════════════════════════════════════════════════════════
126
+ # MPNN blocks (defined before RFDiffusionAutoBlocks which references them)
127
+ # ═══════════════════════════════════════════════════════════════════════════
128
+
129
+
130
+ @dataclass
131
+ class MPNNPipelineOutput:
132
+ """Output from ProteinMPNN / LigandMPNN sequence design."""
133
+
134
+ designed_sequence: str
135
+ sequence_indices: torch.Tensor # [B, L] token indices
136
+ sequence_logits: torch.Tensor # [B, L, n_vocab] logits
137
+ xyz: torch.Tensor # [B, L, 3] input structure (passed through)
138
+ pdb_string: Optional[str] = None # PDB with designed sequence
139
+ sequence_recovery: Optional[float] = None
140
+
141
+
142
+ class MPNNSequenceDesignStep(ModularPipelineBlocks):
143
+ """
144
+ Design sequences for a protein backbone using ProteinMPNN / LigandMPNN.
145
+
146
+ Takes ``xyz`` coordinates (typically from an upstream RFDiffusion denoise
147
+ step) and runs the ``MPNNModel`` to produce amino acid sequences for
148
+ the designable regions.
149
+
150
+ When no ``mpnn`` component is loaded, falls back to using the sequence
151
+ predictions from upstream RFDiffusion (or glycine everywhere).
152
+ """
153
+
154
+ model_name = "mpnn"
155
+
156
+ @property
157
+ def description(self) -> str:
158
+ return (
159
+ "Design amino acid sequences for protein backbones using "
160
+ "ProteinMPNN or LigandMPNN. Accepts structure coordinates "
161
+ "from an upstream diffusion step."
162
+ )
163
+
164
+ @property
165
+ def expected_components(self) -> List[ComponentSpec]:
166
+ return [
167
+ ComponentSpec("mpnn", description="MPNNModel (ProteinMPNN or LigandMPNN)"),
168
+ ]
169
+
170
+ @property
171
+ def inputs(self) -> List[InputParam]:
172
+ return [
173
+ InputParam(
174
+ "xyz", required=True, type_hint=torch.Tensor,
175
+ description="Protein backbone coordinates [B, L, 3] (CA atoms)",
176
+ ),
177
+ InputParam(
178
+ "motif_mask", type_hint=torch.Tensor,
179
+ description="Mask for fixed/motif positions [L]. True = fixed sequence.",
180
+ ),
181
+ InputParam(
182
+ "sequence_indices", type_hint=torch.Tensor,
183
+ description="Known sequence indices for motif positions [B, L] (from RFDiffusion)",
184
+ ),
185
+ InputParam(
186
+ "temperature", default=0.1, type_hint=float,
187
+ description="Sampling temperature (lower = more deterministic)",
188
+ ),
189
+ InputParam(
190
+ "num_designs", default=1, type_hint=int,
191
+ description="Number of sequence designs to generate per structure",
192
+ ),
193
+ InputParam(
194
+ "output_type", default="tensor", type_hint=str,
195
+ description="'tensor', 'pdb', or 'cif'",
196
+ ),
197
+ InputParam(
198
+ "output_path", type_hint=str,
199
+ description="Path to save designed PDB",
200
+ ),
201
+ ]
202
+
203
+ @property
204
+ def intermediate_outputs(self) -> List[OutputParam]:
205
+ return [
206
+ OutputParam(
207
+ "mpnn_output", type_hint=MPNNPipelineOutput,
208
+ description="MPNN sequence design output",
209
+ ),
210
+ ]
211
+
212
+ @torch.no_grad()
213
+ def __call__(self, components: ModularPipeline, state: PipelineState) -> PipelineState:
214
+ block_state = self.get_block_state(state)
215
+
216
+ xyz = block_state.xyz
217
+ motif_mask = block_state.motif_mask
218
+ known_seq = block_state.sequence_indices
219
+ temperature = block_state.temperature or 0.1
220
+ output_type = block_state.output_type or "tensor"
221
+ output_path = block_state.output_path
222
+
223
+ B, L, _ = xyz.shape
224
+ device = xyz.device
225
+
226
+ has_mpnn = hasattr(components, "mpnn") and components.mpnn is not None
227
+
228
+ if has_mpnn:
229
+ sequence_logits, sequence_indices = self._run_mpnn(
230
+ components.mpnn, xyz, motif_mask, known_seq, temperature,
231
+ )
232
+ else:
233
+ if known_seq is not None:
234
+ sequence_indices = known_seq
235
+ else:
236
+ sequence_indices = torch.full((B, L), 7, dtype=torch.long, device=device) # GLY
237
+ sequence_logits = torch.zeros(B, L, len(AA_NAMES), device=device)
238
+ sequence_logits.scatter_(2, sequence_indices.unsqueeze(-1), 1.0)
239
+
240
+ seq_list = sequence_indices[0].cpu().tolist()
241
+ designed_sequence = "".join(
242
+ THREE_TO_ONE.get(AA_NAMES[min(idx, len(AA_NAMES) - 1)], "X")
243
+ for idx in seq_list
244
+ )
245
+
246
+ pdb_string = None
247
+ if output_type in ("pdb",):
248
+ pdb_string = self._coords_to_pdb(xyz[0], sequence_indices[0])
249
+ if output_path:
250
+ import os
251
+ os.makedirs(os.path.dirname(output_path) or ".", exist_ok=True)
252
+ with open(output_path, "w") as f:
253
+ f.write(pdb_string)
254
+
255
+ output = MPNNPipelineOutput(
256
+ designed_sequence=designed_sequence,
257
+ sequence_indices=sequence_indices,
258
+ sequence_logits=sequence_logits,
259
+ xyz=xyz,
260
+ pdb_string=pdb_string,
261
+ )
262
+
263
+ block_state.mpnn_output = output
264
+ self.set_block_state(state, block_state)
265
+ return components, state
266
+
267
+ def _run_mpnn(self, mpnn, xyz, motif_mask, known_seq, temperature):
268
+ """Run the MPNNModel wrapper on backbone coordinates."""
269
+ B, L, _ = xyz.shape
270
+ device = xyz.device
271
+ dtype = xyz.dtype
272
+
273
+ ca = xyz
274
+ n_offset = torch.tensor([-1.458, 0.0, 0.0], device=device, dtype=dtype)
275
+ c_offset = torch.tensor([0.550, 1.424, 0.0], device=device, dtype=dtype)
276
+ o_offset = torch.tensor([0.550, 2.500, 0.0], device=device, dtype=dtype)
277
+
278
+ X = torch.stack([
279
+ ca + n_offset, ca, ca + c_offset, ca + o_offset,
280
+ ], dim=2)
281
+
282
+ if motif_mask is not None:
283
+ designed_mask = ~motif_mask.unsqueeze(0).expand(B, -1)
284
+ else:
285
+ designed_mask = None
286
+
287
+ output = mpnn(
288
+ X=X, S=known_seq, designed_residue_mask=designed_mask, temperature=temperature,
289
+ )
290
+
291
+ logits = output.sequence_logits
292
+ indices = output.sequence_indices
293
+
294
+ if motif_mask is not None and known_seq is not None:
295
+ indices[:, motif_mask] = known_seq[:, motif_mask]
296
+
297
+ return logits, indices
298
+
299
+ def _coords_to_pdb(self, xyz: torch.Tensor, seq: torch.Tensor) -> str:
300
+ xyz_np = xyz.cpu().numpy()
301
+ seq_np = seq.cpu().numpy()
302
+ L = xyz_np.shape[0]
303
+ lines = []
304
+ for i in range(L):
305
+ aa_idx = int(seq_np[i])
306
+ aa_name = AA_NAMES[min(aa_idx, len(AA_NAMES) - 1)]
307
+ x, y, z = xyz_np[i, :]
308
+ lines.append(
309
+ f"ATOM {i+1:5d} CA {aa_name:3s} A{i+1:4d} "
310
+ f"{x:8.3f}{y:8.3f}{z:8.3f} 1.00 0.00 C "
311
+ )
312
+ lines.append("END")
313
+ return "\n".join(lines)
314
+
315
+
316
+ class MPNNAutoDesignStep(AutoPipelineBlocks):
317
+ """Auto-select MPNN design step."""
318
+
319
+ block_classes = [MPNNSequenceDesignStep]
320
+ block_names = ["sequence_design"]
321
+ block_trigger_inputs = [None]
322
+
323
+ @property
324
+ def description(self) -> str:
325
+ return "Sequence design using ProteinMPNN or LigandMPNN."
326
+
327
+
328
+ # ═══════════════════════════════════════════════════════════════════════════
329
+ # Top-level pipeline blocks
330
+ # ═══════════════════════════════════════════════════════════════════════════
331
+
332
+
333
+ class RFDiffusionAutoBlocks(SequentialPipelineBlocks):
334
+ """
335
+ Full protein design pipeline: RFDiffusion3 + optional ProteinMPNN/LigandMPNN.
336
+
337
+ The active workflow is selected by trigger inputs:
338
+ - ``contigs`` only β†’ structure generation
339
+ - ``contigs`` + ``temperature`` β†’ structure + sequence design
340
+ - ``contigs`` + ``input_xyz`` + ``temperature`` β†’ motif-conditioned + sequence design
341
+
342
+ The MPNN step is skipped when ``temperature`` is not provided or when
343
+ no ``mpnn`` component is loaded.
344
+ """
345
+
346
+ block_classes = [
347
+ RFDiffusionAutoBeforeDenoiseStep,
348
+ RFDiffusionAutoDenoiseStep,
349
+ RFDiffusionAutoDecodeStep,
350
+ MPNNAutoDesignStep,
351
+ ]
352
+ block_names = [
353
+ "before_denoise",
354
+ "denoise",
355
+ "decoder",
356
+ "sequence_design",
357
+ ]
358
+
359
+ _workflow_map = {
360
+ "structure_only": {
361
+ "contigs": True,
362
+ },
363
+ "structure_and_sequence": {
364
+ "contigs": True,
365
+ "temperature": True,
366
+ },
367
+ "motif_structure_and_sequence": {
368
+ "contigs": True,
369
+ "input_xyz": True,
370
+ "temperature": True,
371
+ },
372
+ }
373
+
374
+ @property
375
+ def description(self):
376
+ return (
377
+ "Modular pipeline for protein design using RFDiffusion3.\n"
378
+ "Workflows:\n"
379
+ " - structure_only: backbone generation\n"
380
+ " - structure_and_sequence: backbone + MPNN sequence design\n"
381
+ " - motif_structure_and_sequence: motif-conditioned + MPNN\n"
382
+ )
383
+
384
+
385
+ # ═══════════════════════════════════════════════════════════════════════════
386
+ # Block registries
387
+ # ═══════════════════════════════════════════════════════════════════════════
388
+
389
+
390
+ UNCONDITIONAL_BLOCKS = InsertableDict(
391
+ [
392
+ ("input", RFDiffusionInputStep),
393
+ ("set_timesteps", RFDiffusionSetTimestepsStep),
394
+ ("prepare_latents", RFDiffusionPrepareLatentsStep),
395
+ ("denoise", RFDiffusionDenoiseStep),
396
+ ("decode", RFDiffusionDecodeStep),
397
+ ("sequence_design", MPNNSequenceDesignStep),
398
+ ]
399
+ )
400
+
401
+ AUTO_BLOCKS = InsertableDict(
402
+ [
403
+ ("before_denoise", RFDiffusionAutoBeforeDenoiseStep),
404
+ ("denoise", RFDiffusionAutoDenoiseStep),
405
+ ("decode", RFDiffusionAutoDecodeStep),
406
+ ("sequence_design", MPNNAutoDesignStep),
407
+ ]
408
+ )
409
+
410
+ ALL_BLOCKS = {
411
+ "unconditional": UNCONDITIONAL_BLOCKS,
412
+ "auto": AUTO_BLOCKS,
413
+ }
modular_config.json ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": "RFDiffusionAutoBlocks",
3
+ "_diffusers_version": "0.37.0.dev0",
4
+ "auto_map": {
5
+ "ModularPipelineBlocks": "modular_blocks.RFDiffusionAutoBlocks"
6
+ }
7
+ }
modular_model_index.json ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_blocks_class_name": "RFDiffusionAutoBlocks",
3
+ "_class_name": "ModularPipeline",
4
+ "_diffusers_version": "0.37.0.dev0",
5
+ "default_num_residues": 100,
6
+ "default_num_timesteps": 200,
7
+ "sigma_data": 16.0,
8
+ "transformer": [
9
+ null,
10
+ null,
11
+ {
12
+ "pretrained_model_name_or_path": "dn6/RFDiffusion-3",
13
+ "subfolder": "transformer",
14
+ "type_hint": [
15
+ "diffusers",
16
+ "AutoModel"
17
+ ],
18
+ "revision": null,
19
+ "variant": null
20
+ }
21
+ ],
22
+ "scheduler": [
23
+ null,
24
+ null,
25
+ {
26
+ "pretrained_model_name_or_path": "dn6/RFDiffusion-3",
27
+ "subfolder": "scheduler",
28
+ "type_hint": [
29
+ "diffusers",
30
+ "AutoModel"
31
+ ],
32
+ "revision": null,
33
+ "variant": null,
34
+ "default_creation_method": "from_config"
35
+ }
36
+ ]
37
+ }
mpnn/__init__.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 Dhruv Nair. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from .model_mpnn import MPNNModel, MPNNModelOutput
mpnn/config.json ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": "MPNNModel",
3
+ "_diffusers_version": "0.37.0.dev0",
4
+ "auto_map": {
5
+ "AutoModel": "model_mpnn.MPNNModel"
6
+ },
7
+ "model_type": "protein_mpnn",
8
+ "hidden_dim": 128,
9
+ "num_encoder_layers": 3,
10
+ "num_decoder_layers": 3,
11
+ "num_neighbors": 48,
12
+ "dropout_rate": 0.1,
13
+ "num_positional_embeddings": 16,
14
+ "min_rbf_mean": 2.0,
15
+ "max_rbf_mean": 22.0,
16
+ "num_rbf": 16
17
+ }
mpnn/diffusion_pytorch_model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6f94bbbaa904c3554d3b10397dce3c90e01adb725fd80e3275ff48f11cc4745f
3
+ size 6653924
mpnn/model_mpnn.py ADDED
@@ -0,0 +1,178 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 Dhruv Nair. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """
16
+ ProteinMPNN / LigandMPNN model wrapper.
17
+
18
+ A thin diffusers-compatible wrapper around the foundry MPNN model,
19
+ following the same pattern as the transformer and scheduler wrappers.
20
+ Reuses the foundry model implementation directly, adding only the
21
+ ModelMixin/ConfigMixin interface for diffusers integration.
22
+ """
23
+
24
+ from dataclasses import dataclass
25
+ from typing import Optional
26
+
27
+ import torch
28
+ import torch.nn as nn
29
+
30
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
31
+ from diffusers.models.modeling_utils import ModelMixin
32
+
33
+ from mpnn.model.mpnn import LigandMPNN, ProteinMPNN
34
+
35
+
36
+ MODEL_CLASSES = {
37
+ "protein_mpnn": ProteinMPNN,
38
+ "ligand_mpnn": LigandMPNN,
39
+ }
40
+
41
+
42
+ @dataclass
43
+ class MPNNModelOutput:
44
+ """Output from the MPNN model wrapper."""
45
+
46
+ sequence_logits: torch.Tensor # [B, L, n_vocab]
47
+ sequence_indices: torch.Tensor # [B, L]
48
+ decoder_features: dict # full decoder output dict
49
+
50
+
51
+ class MPNNModel(ModelMixin, ConfigMixin):
52
+ """
53
+ Diffusers-compatible wrapper around the foundry ProteinMPNN / LigandMPNN.
54
+
55
+ Wraps `mpnn.model.mpnn.ProteinMPNN` (or `LigandMPNN`) to provide a
56
+ diffusers ModelMixin/ConfigMixin interface. All model logic is delegated
57
+ to the foundry implementation.
58
+
59
+ State dict keys match the foundry checkpoint format via the `model.*`
60
+ prefix (stripped on load).
61
+ """
62
+
63
+ config_name = "config.json"
64
+
65
+ @register_to_config
66
+ def __init__(
67
+ self,
68
+ model_type: str = "protein_mpnn",
69
+ hidden_dim: int = 128,
70
+ num_encoder_layers: int = 3,
71
+ num_decoder_layers: int = 3,
72
+ num_neighbors: int = 48,
73
+ dropout_rate: float = 0.1,
74
+ num_positional_embeddings: int = 16,
75
+ min_rbf_mean: float = 2.0,
76
+ max_rbf_mean: float = 22.0,
77
+ num_rbf: int = 16,
78
+ # LigandMPNN-specific
79
+ num_context_atoms: int = 25,
80
+ num_context_encoding_layers: int = 2,
81
+ ):
82
+ super().__init__()
83
+
84
+ model_cls = MODEL_CLASSES.get(model_type)
85
+ if model_cls is None:
86
+ raise ValueError(
87
+ f"Unknown model_type '{model_type}'. "
88
+ f"Choose from: {list(MODEL_CLASSES.keys())}"
89
+ )
90
+
91
+ common_kwargs = dict(
92
+ num_node_features=hidden_dim,
93
+ num_edge_features=hidden_dim,
94
+ hidden_dim=hidden_dim,
95
+ num_encoder_layers=num_encoder_layers,
96
+ num_decoder_layers=num_decoder_layers,
97
+ num_neighbors=num_neighbors,
98
+ dropout_rate=dropout_rate,
99
+ num_positional_embeddings=num_positional_embeddings,
100
+ min_rbf_mean=min_rbf_mean,
101
+ max_rbf_mean=max_rbf_mean,
102
+ num_rbf=num_rbf,
103
+ )
104
+
105
+ if model_type == "ligand_mpnn":
106
+ common_kwargs["num_context_atoms"] = num_context_atoms
107
+ common_kwargs["num_context_encoding_layers"] = num_context_encoding_layers
108
+
109
+ self.model = model_cls(**common_kwargs)
110
+
111
+ def forward(
112
+ self,
113
+ X: torch.Tensor,
114
+ S: Optional[torch.Tensor] = None,
115
+ residue_mask: Optional[torch.Tensor] = None,
116
+ designed_residue_mask: Optional[torch.Tensor] = None,
117
+ chain_labels: Optional[torch.Tensor] = None,
118
+ R_idx: Optional[torch.Tensor] = None,
119
+ temperature: float = 0.1,
120
+ **kwargs,
121
+ ) -> MPNNModelOutput:
122
+ """
123
+ Run ProteinMPNN / LigandMPNN sequence design.
124
+
125
+ Args:
126
+ X: Backbone atom coordinates [B, L, num_atoms, 3].
127
+ For ProteinMPNN: num_atoms=4 (N, CA, C, O).
128
+ S: Ground-truth sequence tokens [B, L] (optional, for teacher forcing).
129
+ residue_mask: Valid residue mask [B, L] (default: all valid).
130
+ designed_residue_mask: Which residues to design [B, L] (default: all).
131
+ chain_labels: Chain identifiers [B, L] (default: single chain).
132
+ R_idx: Residue indices [B, L] (default: 0..L-1).
133
+ temperature: Sampling temperature (default: 0.1).
134
+
135
+ Returns:
136
+ MPNNModelOutput with sequence logits and sampled indices.
137
+ """
138
+ B, L = X.shape[0], X.shape[1]
139
+ device = X.device
140
+
141
+ if S is None:
142
+ S = torch.zeros(B, L, dtype=torch.long, device=device)
143
+ if residue_mask is None:
144
+ residue_mask = torch.ones(B, L, dtype=torch.bool, device=device)
145
+ if designed_residue_mask is None:
146
+ designed_residue_mask = torch.ones(B, L, dtype=torch.bool, device=device)
147
+ if chain_labels is None:
148
+ chain_labels = torch.zeros(B, L, dtype=torch.long, device=device)
149
+ if R_idx is None:
150
+ R_idx = torch.arange(L, device=device).unsqueeze(0).expand(B, -1)
151
+
152
+ # Atom mask: mark all atoms as valid based on coordinate presence
153
+ X_m = (X.abs().sum(dim=-1) > 0).float() # [B, L, num_atoms]
154
+
155
+ network_input = {
156
+ "X": X,
157
+ "X_m": X_m,
158
+ "S": S,
159
+ "R_idx": R_idx,
160
+ "chain_labels": chain_labels,
161
+ "residue_mask": residue_mask,
162
+ "designed_residue_mask": designed_residue_mask,
163
+ "temperature": temperature,
164
+ **kwargs,
165
+ }
166
+
167
+ output = self.model(network_input)
168
+
169
+ logits = output["decoder_features"]["logits"] # [B, L, n_vocab]
170
+ S_sampled = output["decoder_features"].get(
171
+ "S_sampled", logits.argmax(dim=-1)
172
+ )
173
+
174
+ return MPNNModelOutput(
175
+ sequence_logits=logits,
176
+ sequence_indices=S_sampled,
177
+ decoder_features=output["decoder_features"],
178
+ )
mpnn_ligand/config.json ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": "MPNNModel",
3
+ "_diffusers_version": "0.37.0.dev0",
4
+ "auto_map": {
5
+ "AutoModel": "model_mpnn.MPNNModel"
6
+ },
7
+ "model_type": "ligand_mpnn",
8
+ "hidden_dim": 128,
9
+ "num_encoder_layers": 3,
10
+ "num_decoder_layers": 3,
11
+ "num_neighbors": 32,
12
+ "dropout_rate": 0.1,
13
+ "num_positional_embeddings": 16,
14
+ "min_rbf_mean": 2.0,
15
+ "max_rbf_mean": 22.0,
16
+ "num_rbf": 16,
17
+ "num_context_atoms": 25,
18
+ "num_context_encoding_layers": 2
19
+ }
mpnn_ligand/diffusion_pytorch_model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:63d94e809089821f2f01c965b8f3a5e736d6b9b53c378fc27d42f47376e18964
3
+ size 10497908
mpnn_ligand/model_mpnn.py ADDED
@@ -0,0 +1,178 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 Dhruv Nair. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """
16
+ ProteinMPNN / LigandMPNN model wrapper.
17
+
18
+ A thin diffusers-compatible wrapper around the foundry MPNN model,
19
+ following the same pattern as the transformer and scheduler wrappers.
20
+ Reuses the foundry model implementation directly, adding only the
21
+ ModelMixin/ConfigMixin interface for diffusers integration.
22
+ """
23
+
24
+ from dataclasses import dataclass
25
+ from typing import Optional
26
+
27
+ import torch
28
+ import torch.nn as nn
29
+
30
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
31
+ from diffusers.models.modeling_utils import ModelMixin
32
+
33
+ from mpnn.model.mpnn import LigandMPNN, ProteinMPNN
34
+
35
+
36
+ MODEL_CLASSES = {
37
+ "protein_mpnn": ProteinMPNN,
38
+ "ligand_mpnn": LigandMPNN,
39
+ }
40
+
41
+
42
+ @dataclass
43
+ class MPNNModelOutput:
44
+ """Output from the MPNN model wrapper."""
45
+
46
+ sequence_logits: torch.Tensor # [B, L, n_vocab]
47
+ sequence_indices: torch.Tensor # [B, L]
48
+ decoder_features: dict # full decoder output dict
49
+
50
+
51
+ class MPNNModel(ModelMixin, ConfigMixin):
52
+ """
53
+ Diffusers-compatible wrapper around the foundry ProteinMPNN / LigandMPNN.
54
+
55
+ Wraps `mpnn.model.mpnn.ProteinMPNN` (or `LigandMPNN`) to provide a
56
+ diffusers ModelMixin/ConfigMixin interface. All model logic is delegated
57
+ to the foundry implementation.
58
+
59
+ State dict keys match the foundry checkpoint format via the `model.*`
60
+ prefix (stripped on load).
61
+ """
62
+
63
+ config_name = "config.json"
64
+
65
+ @register_to_config
66
+ def __init__(
67
+ self,
68
+ model_type: str = "protein_mpnn",
69
+ hidden_dim: int = 128,
70
+ num_encoder_layers: int = 3,
71
+ num_decoder_layers: int = 3,
72
+ num_neighbors: int = 48,
73
+ dropout_rate: float = 0.1,
74
+ num_positional_embeddings: int = 16,
75
+ min_rbf_mean: float = 2.0,
76
+ max_rbf_mean: float = 22.0,
77
+ num_rbf: int = 16,
78
+ # LigandMPNN-specific
79
+ num_context_atoms: int = 25,
80
+ num_context_encoding_layers: int = 2,
81
+ ):
82
+ super().__init__()
83
+
84
+ model_cls = MODEL_CLASSES.get(model_type)
85
+ if model_cls is None:
86
+ raise ValueError(
87
+ f"Unknown model_type '{model_type}'. "
88
+ f"Choose from: {list(MODEL_CLASSES.keys())}"
89
+ )
90
+
91
+ common_kwargs = dict(
92
+ num_node_features=hidden_dim,
93
+ num_edge_features=hidden_dim,
94
+ hidden_dim=hidden_dim,
95
+ num_encoder_layers=num_encoder_layers,
96
+ num_decoder_layers=num_decoder_layers,
97
+ num_neighbors=num_neighbors,
98
+ dropout_rate=dropout_rate,
99
+ num_positional_embeddings=num_positional_embeddings,
100
+ min_rbf_mean=min_rbf_mean,
101
+ max_rbf_mean=max_rbf_mean,
102
+ num_rbf=num_rbf,
103
+ )
104
+
105
+ if model_type == "ligand_mpnn":
106
+ common_kwargs["num_context_atoms"] = num_context_atoms
107
+ common_kwargs["num_context_encoding_layers"] = num_context_encoding_layers
108
+
109
+ self.model = model_cls(**common_kwargs)
110
+
111
+ def forward(
112
+ self,
113
+ X: torch.Tensor,
114
+ S: Optional[torch.Tensor] = None,
115
+ residue_mask: Optional[torch.Tensor] = None,
116
+ designed_residue_mask: Optional[torch.Tensor] = None,
117
+ chain_labels: Optional[torch.Tensor] = None,
118
+ R_idx: Optional[torch.Tensor] = None,
119
+ temperature: float = 0.1,
120
+ **kwargs,
121
+ ) -> MPNNModelOutput:
122
+ """
123
+ Run ProteinMPNN / LigandMPNN sequence design.
124
+
125
+ Args:
126
+ X: Backbone atom coordinates [B, L, num_atoms, 3].
127
+ For ProteinMPNN: num_atoms=4 (N, CA, C, O).
128
+ S: Ground-truth sequence tokens [B, L] (optional, for teacher forcing).
129
+ residue_mask: Valid residue mask [B, L] (default: all valid).
130
+ designed_residue_mask: Which residues to design [B, L] (default: all).
131
+ chain_labels: Chain identifiers [B, L] (default: single chain).
132
+ R_idx: Residue indices [B, L] (default: 0..L-1).
133
+ temperature: Sampling temperature (default: 0.1).
134
+
135
+ Returns:
136
+ MPNNModelOutput with sequence logits and sampled indices.
137
+ """
138
+ B, L = X.shape[0], X.shape[1]
139
+ device = X.device
140
+
141
+ if S is None:
142
+ S = torch.zeros(B, L, dtype=torch.long, device=device)
143
+ if residue_mask is None:
144
+ residue_mask = torch.ones(B, L, dtype=torch.bool, device=device)
145
+ if designed_residue_mask is None:
146
+ designed_residue_mask = torch.ones(B, L, dtype=torch.bool, device=device)
147
+ if chain_labels is None:
148
+ chain_labels = torch.zeros(B, L, dtype=torch.long, device=device)
149
+ if R_idx is None:
150
+ R_idx = torch.arange(L, device=device).unsqueeze(0).expand(B, -1)
151
+
152
+ # Atom mask: mark all atoms as valid based on coordinate presence
153
+ X_m = (X.abs().sum(dim=-1) > 0).float() # [B, L, num_atoms]
154
+
155
+ network_input = {
156
+ "X": X,
157
+ "X_m": X_m,
158
+ "S": S,
159
+ "R_idx": R_idx,
160
+ "chain_labels": chain_labels,
161
+ "residue_mask": residue_mask,
162
+ "designed_residue_mask": designed_residue_mask,
163
+ "temperature": temperature,
164
+ **kwargs,
165
+ }
166
+
167
+ output = self.model(network_input)
168
+
169
+ logits = output["decoder_features"]["logits"] # [B, L, n_vocab]
170
+ S_sampled = output["decoder_features"].get(
171
+ "S_sampled", logits.argmax(dim=-1)
172
+ )
173
+
174
+ return MPNNModelOutput(
175
+ sequence_logits=logits,
176
+ sequence_indices=S_sampled,
177
+ decoder_features=output["decoder_features"],
178
+ )
mpnn_soluble/config.json ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": "MPNNModel",
3
+ "_diffusers_version": "0.37.0.dev0",
4
+ "auto_map": {
5
+ "AutoModel": "model_mpnn.MPNNModel"
6
+ },
7
+ "model_type": "protein_mpnn",
8
+ "hidden_dim": 128,
9
+ "num_encoder_layers": 3,
10
+ "num_decoder_layers": 3,
11
+ "num_neighbors": 48,
12
+ "dropout_rate": 0.1,
13
+ "num_positional_embeddings": 16,
14
+ "min_rbf_mean": 2.0,
15
+ "max_rbf_mean": 22.0,
16
+ "num_rbf": 16
17
+ }
mpnn_soluble/diffusion_pytorch_model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1d7ce5491820a26c7cfae3ab46421c40d1af645accb6c10b4818243eea2f0165
3
+ size 6653924
mpnn_soluble/model_mpnn.py ADDED
@@ -0,0 +1,178 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 Dhruv Nair. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """
16
+ ProteinMPNN / LigandMPNN model wrapper.
17
+
18
+ A thin diffusers-compatible wrapper around the foundry MPNN model,
19
+ following the same pattern as the transformer and scheduler wrappers.
20
+ Reuses the foundry model implementation directly, adding only the
21
+ ModelMixin/ConfigMixin interface for diffusers integration.
22
+ """
23
+
24
+ from dataclasses import dataclass
25
+ from typing import Optional
26
+
27
+ import torch
28
+ import torch.nn as nn
29
+
30
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
31
+ from diffusers.models.modeling_utils import ModelMixin
32
+
33
+ from mpnn.model.mpnn import LigandMPNN, ProteinMPNN
34
+
35
+
36
+ MODEL_CLASSES = {
37
+ "protein_mpnn": ProteinMPNN,
38
+ "ligand_mpnn": LigandMPNN,
39
+ }
40
+
41
+
42
+ @dataclass
43
+ class MPNNModelOutput:
44
+ """Output from the MPNN model wrapper."""
45
+
46
+ sequence_logits: torch.Tensor # [B, L, n_vocab]
47
+ sequence_indices: torch.Tensor # [B, L]
48
+ decoder_features: dict # full decoder output dict
49
+
50
+
51
+ class MPNNModel(ModelMixin, ConfigMixin):
52
+ """
53
+ Diffusers-compatible wrapper around the foundry ProteinMPNN / LigandMPNN.
54
+
55
+ Wraps `mpnn.model.mpnn.ProteinMPNN` (or `LigandMPNN`) to provide a
56
+ diffusers ModelMixin/ConfigMixin interface. All model logic is delegated
57
+ to the foundry implementation.
58
+
59
+ State dict keys match the foundry checkpoint format via the `model.*`
60
+ prefix (stripped on load).
61
+ """
62
+
63
+ config_name = "config.json"
64
+
65
+ @register_to_config
66
+ def __init__(
67
+ self,
68
+ model_type: str = "protein_mpnn",
69
+ hidden_dim: int = 128,
70
+ num_encoder_layers: int = 3,
71
+ num_decoder_layers: int = 3,
72
+ num_neighbors: int = 48,
73
+ dropout_rate: float = 0.1,
74
+ num_positional_embeddings: int = 16,
75
+ min_rbf_mean: float = 2.0,
76
+ max_rbf_mean: float = 22.0,
77
+ num_rbf: int = 16,
78
+ # LigandMPNN-specific
79
+ num_context_atoms: int = 25,
80
+ num_context_encoding_layers: int = 2,
81
+ ):
82
+ super().__init__()
83
+
84
+ model_cls = MODEL_CLASSES.get(model_type)
85
+ if model_cls is None:
86
+ raise ValueError(
87
+ f"Unknown model_type '{model_type}'. "
88
+ f"Choose from: {list(MODEL_CLASSES.keys())}"
89
+ )
90
+
91
+ common_kwargs = dict(
92
+ num_node_features=hidden_dim,
93
+ num_edge_features=hidden_dim,
94
+ hidden_dim=hidden_dim,
95
+ num_encoder_layers=num_encoder_layers,
96
+ num_decoder_layers=num_decoder_layers,
97
+ num_neighbors=num_neighbors,
98
+ dropout_rate=dropout_rate,
99
+ num_positional_embeddings=num_positional_embeddings,
100
+ min_rbf_mean=min_rbf_mean,
101
+ max_rbf_mean=max_rbf_mean,
102
+ num_rbf=num_rbf,
103
+ )
104
+
105
+ if model_type == "ligand_mpnn":
106
+ common_kwargs["num_context_atoms"] = num_context_atoms
107
+ common_kwargs["num_context_encoding_layers"] = num_context_encoding_layers
108
+
109
+ self.model = model_cls(**common_kwargs)
110
+
111
+ def forward(
112
+ self,
113
+ X: torch.Tensor,
114
+ S: Optional[torch.Tensor] = None,
115
+ residue_mask: Optional[torch.Tensor] = None,
116
+ designed_residue_mask: Optional[torch.Tensor] = None,
117
+ chain_labels: Optional[torch.Tensor] = None,
118
+ R_idx: Optional[torch.Tensor] = None,
119
+ temperature: float = 0.1,
120
+ **kwargs,
121
+ ) -> MPNNModelOutput:
122
+ """
123
+ Run ProteinMPNN / LigandMPNN sequence design.
124
+
125
+ Args:
126
+ X: Backbone atom coordinates [B, L, num_atoms, 3].
127
+ For ProteinMPNN: num_atoms=4 (N, CA, C, O).
128
+ S: Ground-truth sequence tokens [B, L] (optional, for teacher forcing).
129
+ residue_mask: Valid residue mask [B, L] (default: all valid).
130
+ designed_residue_mask: Which residues to design [B, L] (default: all).
131
+ chain_labels: Chain identifiers [B, L] (default: single chain).
132
+ R_idx: Residue indices [B, L] (default: 0..L-1).
133
+ temperature: Sampling temperature (default: 0.1).
134
+
135
+ Returns:
136
+ MPNNModelOutput with sequence logits and sampled indices.
137
+ """
138
+ B, L = X.shape[0], X.shape[1]
139
+ device = X.device
140
+
141
+ if S is None:
142
+ S = torch.zeros(B, L, dtype=torch.long, device=device)
143
+ if residue_mask is None:
144
+ residue_mask = torch.ones(B, L, dtype=torch.bool, device=device)
145
+ if designed_residue_mask is None:
146
+ designed_residue_mask = torch.ones(B, L, dtype=torch.bool, device=device)
147
+ if chain_labels is None:
148
+ chain_labels = torch.zeros(B, L, dtype=torch.long, device=device)
149
+ if R_idx is None:
150
+ R_idx = torch.arange(L, device=device).unsqueeze(0).expand(B, -1)
151
+
152
+ # Atom mask: mark all atoms as valid based on coordinate presence
153
+ X_m = (X.abs().sum(dim=-1) > 0).float() # [B, L, num_atoms]
154
+
155
+ network_input = {
156
+ "X": X,
157
+ "X_m": X_m,
158
+ "S": S,
159
+ "R_idx": R_idx,
160
+ "chain_labels": chain_labels,
161
+ "residue_mask": residue_mask,
162
+ "designed_residue_mask": designed_residue_mask,
163
+ "temperature": temperature,
164
+ **kwargs,
165
+ }
166
+
167
+ output = self.model(network_input)
168
+
169
+ logits = output["decoder_features"]["logits"] # [B, L, n_vocab]
170
+ S_sampled = output["decoder_features"].get(
171
+ "S_sampled", logits.argmax(dim=-1)
172
+ )
173
+
174
+ return MPNNModelOutput(
175
+ sequence_logits=logits,
176
+ sequence_indices=S_sampled,
177
+ decoder_features=output["decoder_features"],
178
+ )
pyproject.toml ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [build-system]
2
+ requires = ["setuptools>=61.0", "wheel"]
3
+ build-backend = "setuptools.build_meta"
4
+
5
+ [project]
6
+ name = "modular-rfdiffusion"
7
+ version = "0.1.0"
8
+ description = "Modular Diffusers Pipeline for RFDiffusion protein structure generation"
9
+ readme = "README.md"
10
+ license = {text = "Apache-2.0"}
11
+ requires-python = ">=3.8"
12
+ authors = [
13
+ {name = "Dhruv Nair"}
14
+ ]
15
+ keywords = ["diffusion", "protein", "rfdiffusion", "deep-learning", "pytorch"]
16
+ classifiers = [
17
+ "Development Status :: 3 - Alpha",
18
+ "Intended Audience :: Science/Research",
19
+ "License :: OSI Approved :: Apache Software License",
20
+ "Programming Language :: Python :: 3",
21
+ "Programming Language :: Python :: 3.8",
22
+ "Programming Language :: Python :: 3.9",
23
+ "Programming Language :: Python :: 3.10",
24
+ "Programming Language :: Python :: 3.11",
25
+ "Topic :: Scientific/Engineering :: Artificial Intelligence",
26
+ "Topic :: Scientific/Engineering :: Bio-Informatics",
27
+ ]
28
+
29
+ dependencies = [
30
+ "torch>=2.0.0",
31
+ "numpy>=1.21.0",
32
+ "diffusers>=0.25.0",
33
+ "huggingface-hub>=0.20.0",
34
+ "scipy>=1.7.0",
35
+ "hydra-core>=1.0.0",
36
+ "omegaconf>=2.0.0",
37
+ ]
38
+
39
+ [project.optional-dependencies]
40
+ dev = [
41
+ "pytest>=7.0.0",
42
+ "pytest-cov>=4.0.0",
43
+ "black>=23.0.0",
44
+ "ruff>=0.1.0",
45
+ "mypy>=1.0.0",
46
+ ]
47
+
48
+ [project.urls]
49
+ Homepage = "https://github.com/DN6/modular-diffusers"
50
+ Repository = "https://github.com/DN6/modular-diffusers"
51
+
52
+ [tool.setuptools.packages.find]
53
+ where = ["."]
54
+ include = ["*"]
55
+
56
+ [tool.black]
57
+ line-length = 119
58
+ target-version = ["py38", "py39", "py310", "py311"]
59
+
60
+ [tool.ruff]
61
+ line-length = 119
62
+ target-version = "py38"
63
+
64
+ [tool.ruff.lint]
65
+ select = ["E", "F", "W", "I", "N"]
66
+ ignore = ["E501"]
scheduler/__init__.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 Dhruv Nair. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from .model import RFDiffusionScheduler
scheduler/config.json ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": "RFDiffusionScheduler",
3
+ "_diffusers_version": "0.37.0.dev0",
4
+ "auto_map": {
5
+ "AutoModel": "model.RFDiffusionScheduler"
6
+ },
7
+ "num_timesteps": 200,
8
+ "sigma_data": 16.0,
9
+ "s_min": 4e-4,
10
+ "s_max": 160.0,
11
+ "p": 7.0,
12
+ "gamma_0": 0.6,
13
+ "gamma_min": 1.0,
14
+ "noise_scale": 1.003,
15
+ "step_scale": 1.5
16
+ }
scheduler/model.py ADDED
@@ -0,0 +1,152 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 Dhruv Nair. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """
16
+ RFDiffusion3 Scheduler.
17
+
18
+ A thin diffusers-compatible wrapper around the foundry EDM noise schedule
19
+ and stochastic sampling logic from `rfd3.model.inference_sampler`.
20
+ """
21
+
22
+ from typing import Optional
23
+
24
+ import torch
25
+
26
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
27
+
28
+ # Reuse the original noise schedule and sampling config directly
29
+ from rfd3.model.inference_sampler import SampleDiffusionWithMotif
30
+
31
+
32
+ class RFDiffusionScheduler(ConfigMixin):
33
+ """
34
+ Diffusers-compatible scheduler wrapping the foundry EDM sampler.
35
+
36
+ Delegates noise schedule construction and sampling parameters to
37
+ `rfd3.model.inference_sampler.SampleDiffusionWithMotif`.
38
+ """
39
+
40
+ config_name = "config.json"
41
+
42
+ @register_to_config
43
+ def __init__(
44
+ self,
45
+ num_timesteps: int = 200,
46
+ sigma_data: float = 16.0,
47
+ s_min: float = 4e-4,
48
+ s_max: float = 160.0,
49
+ p: float = 7.0,
50
+ gamma_0: float = 0.6,
51
+ gamma_min: float = 1.0,
52
+ noise_scale: float = 1.003,
53
+ step_scale: float = 1.5,
54
+ ):
55
+ # Instantiate the foundry sampler with matching parameters
56
+ self._sampler = SampleDiffusionWithMotif(
57
+ num_timesteps=num_timesteps,
58
+ sigma_data=sigma_data,
59
+ s_min=s_min,
60
+ s_max=s_max,
61
+ p=p,
62
+ gamma_0=gamma_0,
63
+ gamma_min=gamma_min,
64
+ noise_scale=noise_scale,
65
+ step_scale=step_scale,
66
+ )
67
+
68
+ @property
69
+ def sampler(self) -> SampleDiffusionWithMotif:
70
+ return self._sampler
71
+
72
+ def get_noise_schedule(self, device: torch.device = None) -> torch.Tensor:
73
+ """
74
+ Construct the EDM noise schedule using the foundry implementation.
75
+
76
+ Returns:
77
+ torch.Tensor: Noise schedule [num_timesteps] from high to low noise.
78
+ """
79
+ return self._sampler._construct_inference_noise_schedule(
80
+ device=device or torch.device("cpu")
81
+ )
82
+
83
+ def get_initial_noise_level(self, device: torch.device = None) -> torch.Tensor:
84
+ """Get the first (largest) noise level from the schedule."""
85
+ return self.get_noise_schedule(device=device)[0]
86
+
87
+ def step(
88
+ self,
89
+ xyz_pred: torch.Tensor,
90
+ xyz_noisy: torch.Tensor,
91
+ c_t_minus_1: torch.Tensor,
92
+ c_t: torch.Tensor,
93
+ motif_mask: Optional[torch.Tensor] = None,
94
+ ) -> torch.Tensor:
95
+ """
96
+ Perform one Euler denoising step matching the foundry sampler.
97
+
98
+ The foundry ``sample_diffusion_like_af3`` does NOT clamp motif
99
+ coordinates after the Euler update β€” it relies on noise injection
100
+ having zeroed epsilon for motif atoms so the model's delta is ~0
101
+ there. We replicate that behaviour here.
102
+
103
+ Args:
104
+ xyz_pred: Model's denoised prediction X_denoised_L [B, L, 3]
105
+ xyz_noisy: Noise-injected coordinates X_noisy_L [B, L, 3]
106
+ c_t_minus_1: Previous noise level
107
+ c_t: Next (lower) noise level
108
+ motif_mask: Boolean mask for fixed positions (True = fixed) [L]
109
+ (unused β€” kept for API compatibility)
110
+
111
+ Returns:
112
+ Updated coordinates X_L [B, L, 3]
113
+ """
114
+ gamma = self._sampler.gamma_0 if c_t > self._sampler.gamma_min else 0.0
115
+ t_hat = c_t_minus_1 * (gamma + 1.0)
116
+
117
+ delta_L = (xyz_noisy - xyz_pred) / t_hat
118
+ d_t = c_t - t_hat
119
+ xyz_next = xyz_noisy + self._sampler.step_scale * d_t * delta_L
120
+
121
+ return xyz_next
122
+
123
+ def add_noise(
124
+ self,
125
+ xyz: torch.Tensor,
126
+ c_t_minus_1: torch.Tensor,
127
+ c_t: torch.Tensor,
128
+ motif_mask: Optional[torch.Tensor] = None,
129
+ ) -> tuple[torch.Tensor, torch.Tensor]:
130
+ """
131
+ Inject stochastic noise before the model call, matching the foundry sampler.
132
+
133
+ Args:
134
+ xyz: Current coordinates X_L [B, L, 3]
135
+ c_t_minus_1: Previous noise level
136
+ c_t: Current (next lower) noise level
137
+ motif_mask: Boolean mask for fixed positions (True = fixed) [L]
138
+
139
+ Returns:
140
+ Tuple of (noisy coordinates X_noisy_L, t_hat scalar)
141
+ """
142
+ gamma = self._sampler.gamma_0 if c_t > self._sampler.gamma_min else 0.0
143
+ t_hat = c_t_minus_1 * (gamma + 1.0)
144
+
145
+ noise_std = self._sampler.noise_scale * torch.sqrt(t_hat**2 - c_t_minus_1**2)
146
+ epsilon = noise_std * torch.randn_like(xyz)
147
+
148
+ if motif_mask is not None:
149
+ epsilon[:, motif_mask] = 0.0
150
+
151
+ xyz_noisy = xyz + epsilon
152
+ return xyz_noisy, t_hat
transformer/__init__.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 Dhruv Nair. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from .model_rfdiffusion import RFDiffusionTransformerModel, RFDiffusionTransformerOutput
transformer/config.json ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": "RFDiffusionTransformerModel",
3
+ "_diffusers_version": "0.37.0.dev0",
4
+ "auto_map": {
5
+ "AutoModel": "model_rfdiffusion.RFDiffusionTransformerModel"
6
+ },
7
+ "c_s": 384,
8
+ "c_z": 128,
9
+ "c_atom": 128,
10
+ "c_atompair": 16,
11
+ "c_token": 768,
12
+ "c_t_embed": 256,
13
+ "sigma_data": 16.0,
14
+ "n_pairformer_block": 2,
15
+ "n_diffusion_block": 18,
16
+ "n_atom_encoder_block": 3,
17
+ "n_atom_decoder_block": 3,
18
+ "n_head": 16,
19
+ "n_pairformer_head": 16,
20
+ "n_recycle": 2,
21
+ "p_drop": 0.0
22
+ }
transformer/diffusion_pytorch_model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1fe11cd4fe730654e0659d5218228a9d27bdcdac4c880faad993feaa532e99eb
3
+ size 672279561
transformer/model_rfdiffusion.py ADDED
@@ -0,0 +1,297 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 Dhruv Nair. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """
16
+ RFDiffusion3 Transformer model.
17
+
18
+ A thin diffusers-compatible wrapper around the existing foundry RFD3
19
+ model components. Reuses all layer implementations from `rfd3.model.layers.*`
20
+ and `foundry.model.layers.*` directly, adding only the ModelMixin/ConfigMixin
21
+ interface needed for diffusers ModularPipeline integration.
22
+ """
23
+
24
+ from dataclasses import dataclass
25
+ from typing import Optional
26
+
27
+ import torch
28
+ import torch.nn as nn
29
+
30
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
31
+ from diffusers.models.modeling_utils import ModelMixin
32
+
33
+ # Reuse all components from the foundry/rfd3 package directly
34
+ from rfd3.model.RFD3_diffusion_module import RFD3DiffusionModule
35
+ from rfd3.model.layers.encoders import TokenInitializer
36
+
37
+
38
+ @dataclass
39
+ class RFDiffusionTransformerOutput:
40
+ """Output class for RFDiffusion transformer."""
41
+
42
+ xyz: torch.Tensor
43
+ single: torch.Tensor
44
+ pair: torch.Tensor
45
+ sequence_logits: Optional[torch.Tensor] = None
46
+ sequence_indices: Optional[torch.Tensor] = None
47
+
48
+
49
+ class RFDiffusionTransformerModel(ModelMixin, ConfigMixin):
50
+ """
51
+ Diffusers-compatible wrapper around the foundry RFD3 model.
52
+
53
+ This wraps `rfd3.model.RFD3_diffusion_module.RFD3DiffusionModule` and
54
+ `rfd3.model.layers.encoders.TokenInitializer` to provide a diffusers
55
+ ModelMixin/ConfigMixin interface. All actual model logic is delegated
56
+ to the foundry implementation.
57
+
58
+ State dict keys match the foundry checkpoint format via the
59
+ `token_initializer.*` and `diffusion_module.*` prefixes.
60
+ """
61
+
62
+ config_name = "config.json"
63
+ _supports_gradient_checkpointing = True
64
+
65
+ @register_to_config
66
+ def __init__(
67
+ self,
68
+ c_s: int = 384,
69
+ c_z: int = 128,
70
+ c_atom: int = 128,
71
+ c_atompair: int = 16,
72
+ c_token: int = 768,
73
+ c_t_embed: int = 256,
74
+ sigma_data: float = 16.0,
75
+ n_pairformer_block: int = 2,
76
+ n_diffusion_block: int = 18,
77
+ n_atom_encoder_block: int = 3,
78
+ n_atom_decoder_block: int = 3,
79
+ n_head: int = 16,
80
+ n_pairformer_head: int = 16,
81
+ n_recycle: int = 2,
82
+ p_drop: float = 0.0,
83
+ ):
84
+ super().__init__()
85
+
86
+ # ── Shared sub-configs matching rfd3_net.yaml exactly ────────────
87
+ cross_attention_block = {
88
+ "n_head": 4,
89
+ "c_model": c_atom,
90
+ "dropout": 0.0,
91
+ "kq_norm": True,
92
+ }
93
+ downcast_cfg = {
94
+ "method": "cross_attention",
95
+ "cross_attention_block": cross_attention_block,
96
+ }
97
+ upcast_cfg = {
98
+ "method": "cross_attention",
99
+ "n_split": 3,
100
+ "cross_attention_block": cross_attention_block,
101
+ }
102
+
103
+ # ── TokenInitializer (foundry) ──────────────────────────────────
104
+ token_1d_features = {
105
+ "ref_motif_token_type": 3,
106
+ "restype": 32,
107
+ "ref_plddt": 1,
108
+ "is_non_loopy": 1,
109
+ }
110
+ atom_1d_features = {
111
+ "ref_atom_name_chars": 256,
112
+ "ref_element": 128,
113
+ "ref_charge": 1,
114
+ "ref_mask": 1,
115
+ "ref_is_motif_atom_with_fixed_coord": 1,
116
+ "ref_is_motif_atom_unindexed": 1,
117
+ "has_zero_occupancy": 1,
118
+ "ref_pos": 3,
119
+ "ref_atomwise_rasa": 3,
120
+ "active_donor": 1,
121
+ "active_acceptor": 1,
122
+ "is_atom_level_hotspot": 1,
123
+ }
124
+
125
+ self.token_initializer = TokenInitializer(
126
+ c_s=c_s,
127
+ c_z=c_z,
128
+ c_atom=c_atom,
129
+ c_atompair=c_atompair,
130
+ relative_position_encoding={"r_max": 32, "s_max": 2},
131
+ n_pairformer_blocks=n_pairformer_block,
132
+ pairformer_block={
133
+ "attention_pair_bias": {"n_head": n_pairformer_head, "kq_norm": True},
134
+ },
135
+ downcast=downcast_cfg,
136
+ token_1d_features=token_1d_features,
137
+ atom_1d_features=atom_1d_features,
138
+ atom_transformer={
139
+ "n_blocks": 0,
140
+ "atom_transformer_block": {
141
+ "n_head": 4,
142
+ "kq_norm": True,
143
+ "no_residual_connection_between_attention_and_transition": False,
144
+ "dropout": 0.0,
145
+ "n_attn_seq_neighbours": 4,
146
+ "n_attn_keys": 128,
147
+ },
148
+ },
149
+ )
150
+
151
+ # ── RFD3DiffusionModule (foundry) ───────────────────────────────
152
+ self.diffusion_module = RFD3DiffusionModule(
153
+ c_atom=c_atom,
154
+ c_atompair=c_atompair,
155
+ c_token=c_token,
156
+ c_s=c_s,
157
+ c_z=c_z,
158
+ c_t_embed=c_t_embed,
159
+ sigma_data=sigma_data,
160
+ f_pred="edm",
161
+ n_attn_seq_neighbours=2,
162
+ n_attn_keys=128,
163
+ n_recycle=n_recycle,
164
+ atom_attention_encoder={
165
+ "n_blocks": n_atom_encoder_block,
166
+ "atom_transformer_block": {
167
+ "n_head": 4,
168
+ "kq_norm": True,
169
+ "no_residual_connection_between_attention_and_transition": False,
170
+ "dropout": 0.0,
171
+ },
172
+ },
173
+ diffusion_token_encoder={
174
+ "sigma_data": sigma_data,
175
+ "n_pairformer_blocks": n_pairformer_block,
176
+ "pairformer_block": {
177
+ "attention_pair_bias": {"n_head": n_pairformer_head, "kq_norm": True},
178
+ },
179
+ "use_distogram": True,
180
+ "use_self": True,
181
+ "use_sinusoidal_distogram_embedder": False,
182
+ },
183
+ diffusion_transformer={
184
+ "n_block": n_diffusion_block,
185
+ "diffusion_transformer_block": {
186
+ "n_head": n_head,
187
+ "kq_norm": True,
188
+ "no_residual_connection_between_attention_and_transition": False,
189
+ "dropout": p_drop,
190
+ },
191
+ },
192
+ atom_attention_decoder={
193
+ "n_blocks": n_atom_decoder_block,
194
+ "upcast": upcast_cfg,
195
+ "downcast": downcast_cfg,
196
+ "atom_transformer_block": {
197
+ "n_head": 4,
198
+ "kq_norm": True,
199
+ "no_residual_connection_between_attention_and_transition": False,
200
+ "dropout": p_drop,
201
+ },
202
+ },
203
+ downcast=downcast_cfg,
204
+ )
205
+
206
+ @property
207
+ def sigma_data(self) -> float:
208
+ return self.diffusion_module.sigma_data
209
+
210
+ def forward(
211
+ self,
212
+ xyz_noisy: torch.Tensor,
213
+ t: torch.Tensor,
214
+ f: Optional[dict] = None,
215
+ motif_mask: Optional[torch.Tensor] = None,
216
+ n_recycle: Optional[int] = None,
217
+ **kwargs,
218
+ ) -> RFDiffusionTransformerOutput:
219
+ """
220
+ Forward pass delegated to the foundry RFD3DiffusionModule.
221
+
222
+ Args:
223
+ xyz_noisy: Noisy atom coordinates [B, L, 3]
224
+ t: Noise level / timestep [B]
225
+ f: Feature dictionary (as expected by foundry). If None, a minimal
226
+ feature dict is constructed from xyz_noisy and motif_mask.
227
+ motif_mask: Mask for fixed motif atoms [L] (used when f is None)
228
+ n_recycle: Number of recycling iterations
229
+
230
+ Returns:
231
+ RFDiffusionTransformerOutput with denoised coordinates and predictions
232
+ """
233
+ B, L, _ = xyz_noisy.shape
234
+
235
+ # If caller provides a full feature dict, use the native foundry path
236
+ if f is not None:
237
+ initializer_outputs = self.token_initializer(f)
238
+ outs = self.diffusion_module(
239
+ X_noisy_L=xyz_noisy,
240
+ t=t,
241
+ f=f,
242
+ n_recycle=n_recycle,
243
+ **initializer_outputs,
244
+ )
245
+ return RFDiffusionTransformerOutput(
246
+ xyz=outs["X_L"],
247
+ single=torch.zeros(1), # not directly exposed by foundry
248
+ pair=torch.zeros(1),
249
+ sequence_logits=outs.get("sequence_logits_I"),
250
+ sequence_indices=outs.get("sequence_indices_I"),
251
+ )
252
+
253
+ # Simplified path: construct minimal feature dict and call dm.forward()
254
+ # For unconditional generation, each residue has 1 atom (CA), so L = I
255
+ device = xyz_noisy.device
256
+ dtype = xyz_noisy.dtype
257
+
258
+ if motif_mask is None:
259
+ motif_mask = torch.zeros(L, dtype=torch.bool, device=device)
260
+ else:
261
+ motif_mask = motif_mask.to(device)
262
+
263
+ # Construct minimal feature dict with all keys required by foundry
264
+ f = {
265
+ "atom_to_token_map": torch.arange(L, device=device), # 1:1 atom-to-token
266
+ "unindexing_pair_mask": torch.zeros(L, L, dtype=torch.bool, device=device),
267
+ "is_ca": torch.ones(L, dtype=torch.bool, device=device),
268
+ "is_motif_atom_with_fixed_coord": motif_mask,
269
+ "is_motif_token_with_fully_fixed_coord": motif_mask,
270
+ }
271
+
272
+ # Zero-initialized TokenInitializer outputs (no conditioning features)
273
+ Q_L_init = torch.zeros(L, self.config.c_atom, device=device, dtype=dtype)
274
+ C_L = torch.zeros(L, self.config.c_atom, device=device, dtype=dtype)
275
+ P_LL = torch.zeros(L, L, self.config.c_atompair, device=device, dtype=dtype)
276
+ S_I = torch.zeros(L, self.config.c_s, device=device, dtype=dtype)
277
+ Z_II = torch.zeros(L, L, self.config.c_z, device=device, dtype=dtype)
278
+
279
+ outs = self.diffusion_module(
280
+ X_noisy_L=xyz_noisy,
281
+ t=t,
282
+ f=f,
283
+ Q_L_init=Q_L_init,
284
+ C_L=C_L,
285
+ P_LL=P_LL,
286
+ S_I=S_I,
287
+ Z_II=Z_II,
288
+ n_recycle=n_recycle,
289
+ )
290
+
291
+ return RFDiffusionTransformerOutput(
292
+ xyz=outs["X_L"],
293
+ single=torch.zeros(1),
294
+ pair=torch.zeros(1),
295
+ sequence_logits=outs.get("sequence_logits_I"),
296
+ sequence_indices=outs.get("sequence_indices_I"),
297
+ )