Upload folder using huggingface_hub
- README.md +290 -0
- __init__.py +34 -0
- before_denoise.py +312 -0
- decoders.py +281 -0
- denoise.py +193 -0
- modular_blocks.py +413 -0
- modular_config.json +7 -0
- modular_model_index.json +37 -0
- mpnn/__init__.py +15 -0
- mpnn/config.json +17 -0
- mpnn/diffusion_pytorch_model.safetensors +3 -0
- mpnn/model_mpnn.py +178 -0
- mpnn_ligand/config.json +19 -0
- mpnn_ligand/diffusion_pytorch_model.safetensors +3 -0
- mpnn_ligand/model_mpnn.py +178 -0
- mpnn_soluble/config.json +17 -0
- mpnn_soluble/diffusion_pytorch_model.safetensors +3 -0
- mpnn_soluble/model_mpnn.py +178 -0
- pyproject.toml +66 -0
- scheduler/__init__.py +15 -0
- scheduler/config.json +16 -0
- scheduler/model.py +152 -0
- transformer/__init__.py +15 -0
- transformer/config.json +22 -0
- transformer/diffusion_pytorch_model.safetensors +3 -0
- transformer/model_rfdiffusion.py +297 -0
README.md
ADDED
@@ -0,0 +1,290 @@
# Protein Design with Diffusers

A [diffusers](https://github.com/huggingface/diffusers) `ModularPipeline` wrapper for protein design, combining structure generation ([RFdiffusion3](https://www.biorxiv.org/content/10.1101/2025.09.18.676967v2)) and sequence design ([ProteinMPNN](https://www.science.org/doi/10.1126/science.add2187) / [LigandMPNN](https://www.nature.com/articles/s41592-025-02626-1)) into composable, swappable pipeline blocks.

All three models (RFD3, ProteinMPNN, and LigandMPNN) rely on [Foundry](https://github.com/RosettaCommons/foundry) for their underlying implementations and [AtomWorks](https://github.com/RosettaCommons/atomworks) for structure I/O. This package adds only the thin wrappers needed for diffusers integration.

## Getting Started

### Installation

```bash
# Install foundry (provides model implementations + AtomWorks)
# Quoting keeps shells such as zsh from expanding the brackets
pip install "rc-foundry[all]"

# Install diffusers with modular pipeline support
pip install diffusers
```

### Running with Diffusers

```python
import torch
from diffusers import ModularPipeline

# Load the pipeline
pipe = ModularPipeline.from_pretrained(
    "dn6/RFDiffusion-3",
    trust_remote_code=True,
)
pipe.load_components(
    device_map="cuda",
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
)

# Generate a 100-residue protein backbone
state = pipe(contigs="100")

# Access coordinates
print(state.output.xyz.shape)  # [B, L, 3]
```

## Why Diffusers?

Wrapping RFdiffusion3 as a diffusers `ModularPipeline` gives you access to the full diffusers ecosystem out of the box:

### CPU Offloading

Run large models on limited VRAM by offloading components to CPU when not in use:

```python
pipe.enable_model_cpu_offload()
state = pipe(contigs="100")
```

### Hub Integration

All models are hosted on the Hugging Face Hub. Load by repo ID, share fine-tuned variants, and version your checkpoints:

```python
pipe = ModularPipeline.from_pretrained("dn6/RFDiffusion-3", trust_remote_code=True)
```

### LoRA Fine-Tuning

Fine-tune RFdiffusion3 on custom datasets with LoRA, supported natively by `ModelMixin`:

```python
from peft import LoraConfig

lora_config = LoraConfig(r=16, lora_alpha=16, target_modules=["to_q", "to_k", "to_v"])
pipe.transformer.add_adapter(lora_config)

# After training
pipe.transformer.save_pretrained("my-rfd3-lora")
```

### Composable Workflows

Inspect, swap, and extend pipeline blocks at runtime:

```python
from diffusers import AutoModel

# Inspect the pipeline
print(pipe.blocks)

# Swap ProteinMPNN for LigandMPNN
mpnn = AutoModel.from_pretrained("dn6/RFDiffusion-3", subfolder="mpnn_ligand", trust_remote_code=True)
pipe.update_components(mpnn=mpnn)

# Add a custom post-processing block
from diffusers.modular_pipelines import ModularPipelineBlocks, PipelineState
from diffusers.modular_pipelines.modular_pipeline_utils import InputParam, OutputParam

class ScoreDesignStep(ModularPipelineBlocks):
    @property
    def inputs(self):
        return [InputParam("xyz", required=True)]

    @property
    def intermediate_outputs(self):
        return [OutputParam("radius_of_gyration")]

    def __call__(self, components, state):
        block_state = self.get_block_state(state)
        xyz = block_state.xyz
        # Centroid over the residue axis, then RMS distance to it
        centroid = xyz.mean(dim=-2, keepdim=True)
        block_state.radius_of_gyration = ((xyz - centroid) ** 2).sum(-1).mean().sqrt()
        self.set_block_state(state, block_state)
        return components, state

# Insert after the decoder
pipe._blocks.sub_blocks.insert("score", ScoreDesignStep(), index=3)
```
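
To make the quantity computed in `ScoreDesignStep` concrete, here is a minimal pure-Python sketch of the same radius-of-gyration formula (root-mean-square distance of the points from their centroid) on a toy coordinate list. The helper name `radius_of_gyration` and the sample points are illustrative assumptions, not part of this repository.

```python
import math
from typing import Sequence


def radius_of_gyration(coords: Sequence[Sequence[float]]) -> float:
    """Root-mean-square distance of 3D points from their centroid."""
    n = len(coords)
    if n == 0:
        raise ValueError("coords must contain at least one point")
    # Centroid: per-axis mean over all points.
    centroid = [sum(p[axis] for p in coords) / n for axis in range(3)]
    # Mean squared distance to the centroid, then square root.
    mean_sq = sum(
        sum((p[axis] - centroid[axis]) ** 2 for axis in range(3)) for p in coords
    ) / n
    return math.sqrt(mean_sq)


# Two points 2 units apart: each sits 1 unit from the centroid.
toy = [(0.0, 0.0, 0.0), (2.0, 0.0, 0.0)]
rg = radius_of_gyration(toy)
print(rg)
```

Unlike the tensor version above, this sketch handles a single structure; the pipeline block averages over every leading dimension of `xyz` at once.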

## Output Types

All output types return the base tensors (`xyz`, `sequence_indices`, `sequence_logits`). The `output_type` parameter controls what additional format is produced:

| `output_type` | Additional output | Writes to disk |
|---|---|---|
| `"tensor"` | none | no |
| `"pdb"` | `pdb_string` | `.pdb` file |
| `"cif"` | `atom_array`, `atom_array_stack`, `trajectory_stack` | `.cif` via AtomWorks |
| `"cif.gz"` | Same as `"cif"` | `.cif.gz` compressed |

CIF outputs use [AtomWorks](https://github.com/RosettaCommons/atomworks) `to_cif_file` and return [biotite](https://www.biotite-python.org/) `AtomArray` / `AtomArrayStack` objects, matching the foundry output format.

```python
# Save as compressed CIF (matches foundry output format)
state = pipe(contigs="100", output_type="cif.gz", output_path="design_0")

# AtomArray is available directly
atom_array = state.output.atom_array
print(atom_array)  # biotite AtomArray with CA coords + residue names

# Denoising trajectory as AtomArrayStack (one model per step)
trajectory = state.output.trajectory_stack

# PDB string output
state = pipe(contigs="100", output_type="pdb", output_path="design_0.pdb")
print(state.output.pdb_string[:200])
```

## Models

By default, `load_components` loads the RFdiffusion3 transformer and scheduler. MPNN models are optional; load them separately when you need sequence design.

### RFdiffusion3 (RFD3)

[RFdiffusion3](https://www.biorxiv.org/content/10.1101/2025.09.18.676967v2) is an all-atom generative model that designs protein structures via iterative denoising. It uses an EDM noise schedule with 200 steps and is loaded automatically by `load_components`.

| Component | Subfolder | Description |
|-----------|-----------|-------------|
| `transformer` | `transformer/` | `RFDiffusionTransformerModel` (168M params) |
| `scheduler` | `scheduler/` | `RFDiffusionScheduler` (EDM noise schedule + Euler stepping) |
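
The scheduler's exact configuration is not shown in this README. As a rough illustration of what an EDM-style schedule looks like, the sketch below uses the power-law interpolation commonly associated with EDM samplers. The values `sigma_data=16`, `s_max=160`, and `s_min=4e-4` are taken from the linear fallback in `before_denoise.py`; the exponent `rho=7` and the function name `edm_noise_schedule` are assumptions made for illustration, so `scheduler/config.json` remains the authoritative source.

```python
from typing import List


def edm_noise_schedule(
    num_steps: int = 200,
    sigma_data: float = 16.0,  # from the fallback in before_denoise.py
    s_max: float = 160.0,      # from the fallback in before_denoise.py
    s_min: float = 4e-4,       # from the fallback in before_denoise.py
    rho: float = 7.0,          # assumed exponent, not confirmed by this repo
) -> List[float]:
    """Power-law interpolation from s_max down to s_min, scaled by sigma_data."""
    if num_steps < 2:
        raise ValueError("num_steps must be at least 2")
    hi = s_max ** (1.0 / rho)
    lo = s_min ** (1.0 / rho)
    schedule = []
    for i in range(num_steps):
        t = i / (num_steps - 1)  # 0 at the noisiest step, 1 at the cleanest
        schedule.append(sigma_data * (hi + t * (lo - hi)) ** rho)
    return schedule


schedule = edm_noise_schedule()
print(len(schedule), schedule[0], schedule[-1])
```

Compared with a linear ramp between the same endpoints, this form spends many more steps at low noise levels, which is the usual motivation for the power-law shape.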

### ProteinMPNN / LigandMPNN

[ProteinMPNN](https://www.science.org/doi/10.1126/science.add2187) and [LigandMPNN](https://www.nature.com/articles/s41592-025-02626-1) are inverse-folding models that design amino acid sequences for a given protein backbone. These are **not** loaded by default; load them with `AutoModel` and register them via `update_components`:

```python
from diffusers import AutoModel

mpnn = AutoModel.from_pretrained("dn6/RFDiffusion-3", subfolder="mpnn", trust_remote_code=True)
pipe.update_components(mpnn=mpnn)
```

Three variants are available:

| Subfolder | Variant | Params | Description |
|-----------|---------|--------|-------------|
| `mpnn/` | ProteinMPNN | 1.66M | Standard protein sequence design |
| `mpnn_ligand/` | LigandMPNN | 2.62M | Ligand-aware sequence design |
| `mpnn_soluble/` | SolubleMPNN | 1.66M | Optimized for soluble proteins |

## Workflows

The active workflow is selected automatically based on which inputs you provide. Passing `temperature` triggers the MPNN sequence design step; passing `input_xyz` enables motif conditioning.

| Workflow | Trigger inputs | What runs |
|----------|---------------|-----------|
| `structure_only` | `contigs` | RFdiffusion3 |
| `structure_and_sequence` | `contigs`, `temperature` | RFdiffusion3 → MPNN |
| `motif_structure_and_sequence` | `contigs`, `input_xyz`, `temperature` | Motif-conditioned RFdiffusion3 → MPNN |

> Workflows that include MPNN require loading an MPNN variant first (see above).

You can also select a workflow explicitly:

```python
workflow = pipe.get_workflow("structure_and_sequence")
```
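
The trigger logic in the workflow table can be pictured as a simple lookup on which inputs are present. The sketch below is a hypothetical stand-in, not the pipeline's actual dispatch code: `select_workflow` is an illustrative name, and the behavior for combinations the table does not list (for example `input_xyz` without `temperature`) is an assumption.

```python
from typing import Any, Dict


def select_workflow(inputs: Dict[str, Any]) -> str:
    """Map the provided inputs to a workflow name, following the table."""
    provided = {name for name, value in inputs.items() if value is not None}
    if "contigs" not in provided:
        raise ValueError("`contigs` is required by every workflow")
    if {"input_xyz", "temperature"} <= provided:
        return "motif_structure_and_sequence"
    if "temperature" in provided:
        return "structure_and_sequence"
    # Assumption: anything else falls back to structure-only generation.
    return "structure_only"


print(select_workflow({"contigs": "100"}))
print(select_workflow({"contigs": "100", "temperature": 0.1}))
print(select_workflow({"contigs": "A10-25/50", "input_xyz": [[0.0, 0.0, 0.0]], "temperature": 0.1}))
```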

### Structure Only

```python
state = pipe(contigs="100")
print(state.output.xyz.shape)  # [1, 100, 3]
```

### Structure + Sequence Design

```python
from diffusers import AutoModel

# Load an MPNN variant and register it
mpnn = AutoModel.from_pretrained("dn6/RFDiffusion-3", subfolder="mpnn", trust_remote_code=True)
pipe.update_components(mpnn=mpnn)

# Passing temperature triggers the MPNN step
state = pipe(contigs="100", temperature=0.1)
print(state.mpnn_output.designed_sequence)  # e.g. "MKVLSEG..."
```

### Motif-Conditioned Design

```python
import torch

# Random placeholder coordinates for illustration; pass real motif coordinates in practice
motif_coords = torch.randn(16, 3)  # [N_motif, 3]; A10-25 spans 16 residues
state = pipe(
    contigs="A10-25/50",
    input_xyz=motif_coords,
    temperature=0.1,
)
```
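
For reference, contig strings are parsed by `parse_contig_string` in `before_denoise.py`. The condensed sketch below mirrors that function's rules as they appear in this commit: a leading letter marks a motif segment, and a bare `min-max` range resolves to its midpoint. The name `parse_contig` is illustrative, and the function in the repository should be treated as authoritative.

```python
from typing import List, Tuple


def parse_contig(contig: str) -> Tuple[int, List[Tuple[int, int]]]:
    """Return (total_length, motif_ranges); ranges are 0-indexed and end-exclusive."""
    total = 0
    motif_ranges: List[Tuple[int, int]] = []
    for part in (p.strip() for p in contig.split("/")):
        if not part:
            continue
        if part[0].isalpha():
            # Motif segment such as "A10-25": chain letter, then residue range.
            spec = part[1:]
            start, end = map(int, spec.split("-")) if "-" in spec else (int(spec),) * 2
            length = end - start + 1
            motif_ranges.append((total, total + length))
        elif "-" in part:
            # Designed segment given as a length range: the midpoint is used.
            lo, hi = map(int, part.split("-"))
            length = (lo + hi) // 2
        else:
            # Designed segment of fixed length.
            length = int(part)
        total += length
    return total, motif_ranges


print(parse_contig("100"))
print(parse_contig("50-100"))
print(parse_contig("A10-25/50"))
```

The 16-residue motif in `A10-25` is why `motif_coords` has shape `[16, 3]` in the motif-conditioned example.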

## Full Design Pipeline

The three pipelines can be composed into a complete protein design workflow:

```
RFD3 (design backbone) → MPNN (design sequence) → RF3 (validate fold)
```

Each is a standalone `ModularPipeline` that can run independently. Here's the full end-to-end flow:

```python
import torch
from diffusers import AutoModel, ModularPipeline

# 1. Design a backbone with RFdiffusion3
design_pipe = ModularPipeline.from_pretrained("dn6/RFDiffusion-3", trust_remote_code=True)
design_pipe.load_components(device_map="cuda", torch_dtype=torch.bfloat16, trust_remote_code=True)

mpnn = AutoModel.from_pretrained("dn6/RFDiffusion-3", subfolder="mpnn", trust_remote_code=True)
design_pipe.update_components(mpnn=mpnn)

state = design_pipe(contigs="100", temperature=0.1, output_type="cif.gz", output_path="design")
designed_sequence = state.mpnn_output.designed_sequence

# 2. Validate the design with RF3 (structure prediction)
fold_pipe = ModularPipeline.from_pretrained("dn6/RosettaFold-3", trust_remote_code=True)
fold_pipe.load_components(device_map="cuda", torch_dtype=torch.bfloat16, trust_remote_code=True)

state = fold_pipe(sequence=designed_sequence, output_type="cif.gz", output_path="prediction")
```

> [RF3](https://www.biorxiv.org/content/10.1101/2025.08.14.670328) (RosettaFold3) is available as a separate pipeline at [`dn6/RosettaFold-3`](https://huggingface.co/dn6/RosettaFold-3).

## Citation

If you use this code, please cite the relevant work:

```bibtex
@article{butcher2025_rfdiffusion3,
  author = {Butcher, Jasper and Krishna, Rohith and Mitra, Raktim and Brent, Rafael Isaac and Li, Yanjing and Corley, Nathaniel and Kim, Paul T and Funk, Jonathan and Mathis, Simon Valentin and Salike, Saman and Muraishi, Aiko and Eisenach, Helen and Thompson, Tuscan Rock and Chen, Jie and Politanska, Yuliya and Sehgal, Enisha and Coventry, Brian and Zhang, Odin and Qiang, Bo and Didi, Kieran and Kazman, Maxwell and DiMaio, Frank and Baker, David},
  title = {De novo Design of All-atom Biomolecular Interactions with RFdiffusion3},
  journal = {bioRxiv},
  year = {2025},
  doi = {10.1101/2025.09.18.676967},
}

@article{dauparas2022robust,
  author = {Dauparas, Justas and Anishchenko, Ivan and Bennett, Nathaniel and Bai, Hua and Ragotte, Robert J and Milles, Lukas F and Wicky, Basile IM and Courbet, Alexis and de Haas, Rob J and Bethel, Neville and others},
  title = {Robust deep learning--based protein sequence design using ProteinMPNN},
  journal = {Science},
  volume = {378},
  number = {6615},
  pages = {49--56},
  year = {2022},
}

@article{dauparas2025atomic,
  author = {Dauparas, Justas and Lee, Gyu Rie and Pecoraro, Robert and An, Linna and Anishchenko, Ivan and Glasscock, Cameron and Baker, David},
  title = {Atomic context-conditioned protein sequence design using LigandMPNN},
  journal = {Nature Methods},
  year = {2025},
}
```
__init__.py
ADDED
@@ -0,0 +1,34 @@
# Copyright 2025 Dhruv Nair. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .transformer import RFDiffusionTransformerModel, RFDiffusionTransformerOutput
from .scheduler import RFDiffusionScheduler
from .modular_blocks import (
    ALL_BLOCKS,
    AUTO_BLOCKS,
    UNCONDITIONAL_BLOCKS,
    RFDiffusionAutoBeforeDenoiseStep,
    RFDiffusionAutoBlocks,
    RFDiffusionAutoDecodeStep,
    RFDiffusionAutoDenoiseStep,
)
from .before_denoise import (
    RFDiffusionInputStep,
    RFDiffusionPrepareLatentsStep,
    RFDiffusionSetTimestepsStep,
)
from .denoise import RFDiffusionDenoiseStep
from .decoders import RFDiffusionDecodeStep, RFDiffusionPipelineOutput
from .mpnn import MPNNModel, MPNNModelOutput
from .modular_blocks import MPNNAutoDesignStep, MPNNPipelineOutput, MPNNSequenceDesignStep
before_denoise.py
ADDED
@@ -0,0 +1,312 @@
# Copyright 2025 Dhruv Nair. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List, Tuple, Union

import torch

from diffusers.utils import logging
from diffusers.modular_pipelines import ModularPipeline, ModularPipelineBlocks, PipelineState
from diffusers.modular_pipelines.modular_pipeline_utils import ComponentSpec, InputParam, OutputParam


logger = logging.get_logger(__name__)


def parse_contig_string(contig_str: str) -> Tuple[int, List[Tuple[int, int]]]:
    """
    Parse contig specification string.

    Supports formats like:
    - "100" -> 100 residues to design
    - "50-100" -> designed length range; the midpoint (75) is used, not a random draw
    - "A10-25/50" -> motif from chain A residues 10-25, plus 50 designed

    Returns:
        total_length: Total protein length
        motif_ranges: List of (start, end) for motif residues (0-indexed, end-exclusive)
    """
    parts = contig_str.split("/")
    total_length = 0
    motif_ranges = []

    for part in parts:
        part = part.strip()
        if not part:
            continue

        if part[0].isalpha():
            # Motif segment: the leading chain letter is skipped; only the residue range is used.
            residue_spec = part[1:]
            if "-" in residue_spec:
                start, end = map(int, residue_spec.split("-"))
            else:
                start = end = int(residue_spec)
            motif_len = end - start + 1
            motif_ranges.append((total_length, total_length + motif_len))
            total_length += motif_len
        else:
            # Designed segment: a "min-max" range resolves to its integer midpoint.
            if "-" in part:
                min_len, max_len = map(int, part.split("-"))
                add_len = (min_len + max_len) // 2
            else:
                add_len = int(part)
            total_length += add_len

    return total_length, motif_ranges


class RFDiffusionInputStep(ModularPipelineBlocks):
    """
    Input processing step for RFDiffusion.

    Parses contigs to prepare features for structure generation.
    """

    model_name = "rfdiffusion"

    @property
    def description(self) -> str:
        return (
            "Input processing step that:\n"
            " 1. Parses contig specification to determine protein length and design regions\n"
            " 2. Generates masks for motif positions\n"
        )

    @property
    def inputs(self) -> List[InputParam]:
        return [
            InputParam(
                "contigs",
                required=True,
                type_hint=Union[str, List[str]],
                description="Contig specification defining design regions (e.g., '100' or 'A10-25/50-100')",
            ),
            InputParam(
                "input_xyz",
                type_hint=torch.Tensor,
                description="Input coordinates for motif residues [N_motif, 3]",
            ),
        ]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        return [
            OutputParam(
                "motif_mask",
                type_hint=torch.Tensor,
                description="Boolean mask for motif (fixed) positions",
            ),
            OutputParam(
                "motif_xyz",
                type_hint=torch.Tensor,
                description="Coordinates for motif residues",
            ),
            OutputParam(
                "L",
                type_hint=int,
                description="Total length of the protein being designed",
            ),
            OutputParam(
                "batch_size",
                type_hint=int,
                description="Batch size (typically 1 for RFDiffusion)",
            ),
            OutputParam(
                "dtype",
                type_hint=torch.dtype,
                description="Data type for tensors",
            ),
        ]

    def check_inputs(self, components, block_state):
        if block_state.contigs is None:
            raise ValueError("`contigs` must be provided to specify protein design regions")

    @torch.no_grad()
    def __call__(self, components: ModularPipeline, state: PipelineState) -> Tuple[ModularPipeline, PipelineState]:
        block_state = self.get_block_state(state)
        self.check_inputs(components, block_state)

        contigs = block_state.contigs
        input_xyz = block_state.input_xyz

        # A list of contig segments is joined into a single "/"-separated string.
        if isinstance(contigs, list):
            contig_str = "/".join(contigs)
        else:
            contig_str = contigs

        L, motif_ranges = parse_contig_string(contig_str)

        # Mark motif (fixed) positions; ranges are end-exclusive.
        motif_mask = torch.zeros(L, dtype=torch.bool)
        for start, end in motif_ranges:
            motif_mask[start:end] = True

        # Motif coordinates are passed through unchanged (None when not provided).
        motif_xyz = input_xyz

        block_state.motif_mask = motif_mask
        block_state.motif_xyz = motif_xyz
        block_state.L = L
        block_state.batch_size = 1
        block_state.dtype = torch.float32

        self.set_block_state(state, block_state)
        return components, state


class RFDiffusionSetTimestepsStep(ModularPipelineBlocks):
    """
    Set up the EDM noise schedule for RFDiffusion3.
    """

    model_name = "rfdiffusion"

    @property
    def description(self) -> str:
        return "Sets up the EDM noise schedule matching the original inference sampler."

    @property
    def expected_components(self) -> List[ComponentSpec]:
        return [
            ComponentSpec("scheduler", description="RFDiffusion3 EDM scheduler"),
        ]

    @property
    def inputs(self) -> List[InputParam]:
        return [
            InputParam(
                "num_inference_steps",
                default=None,
                type_hint=int,
                description="Number of denoising steps (default: use scheduler config)",
            ),
            InputParam("L", required=True, type_hint=int, description="Protein length"),
        ]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        return [
            OutputParam(
                "noise_schedule",
                type_hint=torch.Tensor,
                description="EDM noise schedule [num_timesteps] from high to low noise",
            ),
            OutputParam(
                "num_inference_steps",
                type_hint=int,
                description="Number of inference steps",
            ),
        ]

    @torch.no_grad()
    def __call__(self, components: ModularPipeline, state: PipelineState) -> Tuple[ModularPipeline, PipelineState]:
        block_state = self.get_block_state(state)

        # NOTE: the `num_inference_steps` input is not applied here; the step
        # count written back below is the length of the resulting schedule.
        if hasattr(components, "scheduler") and components.scheduler is not None:
            noise_schedule = components.scheduler.get_noise_schedule()
        else:
            # Fallback: simple linear schedule
            noise_schedule = torch.linspace(160.0 * 16.0, 4e-4 * 16.0, 200)

        block_state.noise_schedule = noise_schedule
        block_state.num_inference_steps = len(noise_schedule)

        self.set_block_state(state, block_state)
        return components, state


class RFDiffusionPrepareLatentsStep(ModularPipelineBlocks):
    """
    Prepare initial noised coordinates for RFDiffusion3.

    Matches the original _get_initial_structure:
        noise = c0 * randn(D, L, 3)
        noise[..., is_motif, :] = 0
        X_L = noise + coord_motif
    """

    model_name = "rfdiffusion"

    @property
    def description(self) -> str:
        return (
            "Prepares initial coordinates by sampling Gaussian noise scaled by "
            "the first noise schedule value, matching the original sampler."
        )

    @property
    def expected_components(self) -> List[ComponentSpec]:
        return [
            ComponentSpec("scheduler", description="RFDiffusion3 EDM scheduler"),
            ComponentSpec("transformer", description="RFDiffusion transformer model"),
        ]

    @property
    def inputs(self) -> List[InputParam]:
        return [
            InputParam("generator", type_hint=torch.Generator, description="Random generator for reproducibility"),
            InputParam("diffusion_batch_size", default=1, type_hint=int, description="Number of samples to generate in parallel"),
            InputParam("L", required=True, type_hint=int, description="Protein length"),
            InputParam("motif_mask", required=True, type_hint=torch.Tensor),
            InputParam("motif_xyz", type_hint=torch.Tensor),
            InputParam("noise_schedule", required=True, type_hint=torch.Tensor),
            InputParam("dtype", type_hint=torch.dtype),
        ]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        return [
            OutputParam("xyz", type_hint=torch.Tensor, description="Initial noised coordinates [D, L, 3]"),
        ]

    @torch.no_grad()
    def __call__(self, components: ModularPipeline, state: PipelineState) -> Tuple[ModularPipeline, PipelineState]:
        block_state = self.get_block_state(state)

        L = block_state.L
        motif_mask = block_state.motif_mask
        motif_xyz = block_state.motif_xyz
        noise_schedule = block_state.noise_schedule
        dtype = block_state.dtype or torch.float32
        generator = block_state.generator
        D = block_state.diffusion_batch_size or 1

        # The device is chosen from CUDA availability, not from the loaded components.
        # A `generator`, if passed, must live on this same device.
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Initial noise scaled by first noise level (c0), matching original:
        # noise = c0 * randn(D, L, 3)
        c0 = noise_schedule[0]
        noise = c0 * torch.randn((D, L, 3), dtype=dtype, device=device, generator=generator)

        # Zero out noise for motif atoms
        if motif_mask is not None:
            noise[:, motif_mask] = 0.0

        # Build initial coordinates: motif coords + noise
        coord_motif = torch.zeros((D, L, 3), dtype=dtype, device=device)
        if motif_xyz is not None and motif_mask is not None:
            motif_indices = motif_mask.nonzero(as_tuple=True)[0]
            # Motif positions beyond the supplied coordinates stay at the origin.
            for i, idx in enumerate(motif_indices):
                if i < motif_xyz.shape[0]:
                    coord_motif[:, idx] = motif_xyz[i].to(dtype=dtype, device=device)

        xyz = noise + coord_motif

        block_state.xyz = xyz

        self.set_block_state(state, block_state)
        return components, state
|
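A note on the latent-preparation block above: it draws `c0 * randn(D, L, 3)`, zeroes the noise at motif positions, and then adds the motif coordinates back in. The sketch below restates that logic for a single sample in plain Python, so it runs without `torch`. The function name `init_coords`, the `seed` argument, and the value `c0=160.0` are illustrative assumptions, not names or settings taken from this commit.

```python
import random
from typing import List, Sequence, Tuple

Vec3 = Tuple[float, float, float]


def init_coords(
    motif_mask: Sequence[bool],
    motif_xyz: Sequence[Vec3],
    c0: float,
    seed: int = 0,
) -> List[List[float]]:
    """Return one [L, 3] sample: motif residues fixed, all others c0 * N(0, 1)."""
    rng = random.Random(seed)
    motif_iter = iter(motif_xyz)
    coords: List[List[float]] = []
    for is_motif in motif_mask:
        if is_motif:
            # Motif positions get no noise. If fewer motif coordinates than
            # masked positions are supplied, the remainder stay at the origin,
            # mirroring the `i < motif_xyz.shape[0]` guard in the block above.
            coords.append(list(next(motif_iter, (0.0, 0.0, 0.0))))
        else:
            # Designable positions start as pure noise scaled by c0.
            coords.append([c0 * rng.gauss(0.0, 1.0) for _ in range(3)])
    return coords


sample = init_coords(
    motif_mask=[False, True, False, True],
    motif_xyz=[(1.0, 2.0, 3.0), (4.0, 5.0, 6.0)],
    c0=160.0,  # hypothetical first noise level
)
```

Motif residues come back unchanged while every other residue is a scaled Gaussian draw, which is the invariant the denoising loop relies on when it receives `motif_mask`.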
decoders.py
ADDED
|
@@ -0,0 +1,281 @@
|
| 1 |
+
# Copyright 2025 Dhruv Nair. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
from dataclasses import dataclass
|
| 16 |
+
from typing import List, Optional
|
| 17 |
+
|
| 18 |
+
import numpy as np
|
| 19 |
+
import torch
|
| 20 |
+
from atomworks.io.utils.io_utils import to_cif_file
|
| 21 |
+
from biotite.structure import AtomArray, AtomArrayStack, stack
|
| 22 |
+
|
| 23 |
+
from diffusers.utils import logging
|
| 24 |
+
from diffusers.modular_pipelines import ModularPipeline, ModularPipelineBlocks, PipelineState
|
| 25 |
+
from diffusers.modular_pipelines.modular_pipeline_utils import InputParam, OutputParam
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
logger = logging.get_logger(__name__)
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
AA_NAMES = [
|
| 32 |
+
"ALA", "ARG", "ASN", "ASP", "CYS", "GLN", "GLU", "GLY", "HIS", "ILE",
|
| 33 |
+
"LEU", "LYS", "MET", "PHE", "PRO", "SER", "THR", "TRP", "TYR", "VAL",
|
| 34 |
+
"UNK",
|
| 35 |
+
]
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def _build_atom_array(xyz: torch.Tensor, seq: Optional[torch.Tensor] = None) -> AtomArray:
|
| 39 |
+
"""
|
| 40 |
+
Build a biotite AtomArray from CA coordinates and optional sequence.
|
| 41 |
+
|
| 42 |
+
Args:
|
| 43 |
+
xyz: Coordinates for a single structure [L, 3].
|
| 44 |
+
seq: Sequence indices [L] (indexes into AA_NAMES).
|
| 45 |
+
"""
|
| 46 |
+
xyz_np = xyz.detach().cpu().float().numpy()
|
| 47 |
+
L = xyz_np.shape[0]
|
| 48 |
+
|
| 49 |
+
arr = AtomArray(L)
|
| 50 |
+
arr.coord = xyz_np
|
| 51 |
+
arr.atom_name = np.full(L, "CA")
|
| 52 |
+
arr.element = np.full(L, "C")
|
| 53 |
+
arr.chain_id = np.full(L, "A")
|
| 54 |
+
arr.res_id = np.arange(1, L + 1)
|
| 55 |
+
|
| 56 |
+
if seq is not None:
|
| 57 |
+
seq_np = seq.detach().cpu().numpy()
|
| 58 |
+
arr.res_name = np.array([
|
| 59 |
+
AA_NAMES[int(idx)] if int(idx) < len(AA_NAMES) else "UNK"
|
| 60 |
+
for idx in seq_np
|
| 61 |
+
])
|
| 62 |
+
else:
|
| 63 |
+
arr.res_name = np.full(L, "ALA")
|
| 64 |
+
|
| 65 |
+
return arr
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
def _build_atom_array_stack(
|
| 69 |
+
xyz: torch.Tensor,
|
| 70 |
+
seq: Optional[torch.Tensor] = None,
|
| 71 |
+
) -> AtomArrayStack:
|
| 72 |
+
"""
|
| 73 |
+
Build an AtomArrayStack from batched coordinates [B, L, 3].
|
| 74 |
+
|
| 75 |
+
Matches foundry ``build_stack_from_atom_array_and_batched_coords``.
|
| 76 |
+
"""
|
| 77 |
+
template = _build_atom_array(xyz[0], seq[0] if seq is not None else None)
|
| 78 |
+
B = xyz.shape[0]
|
| 79 |
+
arr_stack = stack([template for _ in range(B)])
|
| 80 |
+
arr_stack.coord = xyz.detach().cpu().float().numpy()
|
| 81 |
+
return arr_stack
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
def _build_trajectory_stack(
|
| 85 |
+
trajectory: List[torch.Tensor],
|
| 86 |
+
seq: Optional[torch.Tensor] = None,
|
| 87 |
+
) -> AtomArrayStack:
|
| 88 |
+
"""
|
| 89 |
+
Build an AtomArrayStack from a denoising trajectory.
|
| 90 |
+
|
| 91 |
+
Each entry is [B, L, 3]; takes the first batch element per step.
|
| 92 |
+
"""
|
| 93 |
+
coords = torch.stack([t[0] for t in trajectory]) # [N_steps, L, 3]
|
| 94 |
+
template = _build_atom_array(coords[0], seq[0] if seq is not None else None)
|
| 95 |
+
arr_stack = stack([template for _ in range(coords.shape[0])])
|
| 96 |
+
arr_stack.coord = coords.detach().cpu().float().numpy()
|
| 97 |
+
return arr_stack
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
@dataclass
|
| 101 |
+
class RFDiffusionPipelineOutput:
|
| 102 |
+
"""Output class for RFDiffusion pipeline."""
|
| 103 |
+
|
| 104 |
+
xyz: torch.Tensor
|
| 105 |
+
atom_array: Optional[AtomArray] = None
|
| 106 |
+
atom_array_stack: Optional[AtomArrayStack] = None
|
| 107 |
+
trajectory_stack: Optional[AtomArrayStack] = None
|
| 108 |
+
sequence_indices: Optional[torch.Tensor] = None
|
| 109 |
+
sequence_logits: Optional[torch.Tensor] = None
|
| 110 |
+
single: Optional[torch.Tensor] = None
|
| 111 |
+
pair: Optional[torch.Tensor] = None
|
| 112 |
+
pdb_string: Optional[str] = None
|
| 113 |
+
trajectory: Optional[List[torch.Tensor]] = None
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
class RFDiffusionDecodeStep(ModularPipelineBlocks):
|
| 117 |
+
"""
|
| 118 |
+
Decode step for RFDiffusion.
|
| 119 |
+
|
| 120 |
+
Converts denoised coordinates to final output format.
|
| 121 |
+
|
| 122 |
+
Supported ``output_type`` values:
|
| 123 |
+
|
| 124 |
+
- ``"tensor"``: raw tensors only
|
| 125 |
+
- ``"pdb"``: tensors + PDB format string
|
| 126 |
+
- ``"cif"``: tensors + AtomArray via AtomWorks, writes ``.cif``
|
| 127 |
+
- ``"cif.gz"``: same as ``"cif"`` but compressed
|
| 128 |
+
"""
|
| 129 |
+
|
| 130 |
+
model_name = "rfdiffusion"
|
| 131 |
+
|
| 132 |
+
@property
|
| 133 |
+
def description(self) -> str:
|
| 134 |
+
return (
|
| 135 |
+
"Decode step that converts denoised coordinates to final output, "
|
| 136 |
+
"supporting tensor, PDB, and CIF (via AtomWorks) formats."
|
| 137 |
+
)
|
| 138 |
+
|
| 139 |
+
@property
|
| 140 |
+
def inputs(self) -> List[InputParam]:
|
| 141 |
+
return [
|
| 142 |
+
InputParam(
|
| 143 |
+
"output_type",
|
| 144 |
+
default="tensor",
|
| 145 |
+
type_hint=str,
|
| 146 |
+
description="Output format: 'tensor', 'pdb', 'cif', or 'cif.gz'",
|
| 147 |
+
),
|
| 148 |
+
InputParam(
|
| 149 |
+
"output_path",
|
| 150 |
+
type_hint=str,
|
| 151 |
+
description="Path to save output structure",
|
| 152 |
+
),
|
| 153 |
+
InputParam("xyz", required=True, type_hint=torch.Tensor, description="Denoised coordinates [B, L, 3]"),
|
| 154 |
+
InputParam("sequence_indices", type_hint=torch.Tensor, description="Predicted sequence [B, L]"),
|
| 155 |
+
InputParam("sequence_logits", type_hint=torch.Tensor, description="Sequence logits [B, L, n_aa]"),
|
| 156 |
+
InputParam("single", type_hint=torch.Tensor, description="Single representation"),
|
| 157 |
+
InputParam("pair", type_hint=torch.Tensor, description="Pair representation"),
|
| 158 |
+
InputParam("trajectory", type_hint=List[torch.Tensor], description="Denoising trajectory"),
|
| 159 |
+
]
|
| 160 |
+
|
| 161 |
+
@property
|
| 162 |
+
def intermediate_outputs(self) -> List[OutputParam]:
|
| 163 |
+
return [
|
| 164 |
+
OutputParam("output", type_hint=RFDiffusionPipelineOutput, description="Final pipeline output"),
|
| 165 |
+
]
|
| 166 |
+
|
| 167 |
+
@torch.no_grad()
|
| 168 |
+
def __call__(self, components: ModularPipeline, state: PipelineState) -> PipelineState:
|
| 169 |
+
block_state = self.get_block_state(state)
|
| 170 |
+
|
| 171 |
+
xyz = block_state.xyz
|
| 172 |
+
sequence_indices = block_state.sequence_indices
|
| 173 |
+
sequence_logits = block_state.sequence_logits
|
| 174 |
+
single = block_state.single
|
| 175 |
+
pair = block_state.pair
|
| 176 |
+
trajectory = block_state.trajectory
|
| 177 |
+
output_type = block_state.output_type or "tensor"
|
| 178 |
+
output_path = block_state.output_path
|
| 179 |
+
|
| 180 |
+
pdb_string = None
|
| 181 |
+
atom_array = None
|
| 182 |
+
atom_array_stack = None
|
| 183 |
+
trajectory_stack = None
|
| 184 |
+
|
| 185 |
+
# Build AtomArray for CIF output types
|
| 186 |
+
if output_type in ("cif", "cif.gz"):
|
| 187 |
+
atom_array = _build_atom_array(
|
| 188 |
+
xyz[0], sequence_indices[0] if sequence_indices is not None else None,
|
| 189 |
+
)
|
| 190 |
+
if xyz.shape[0] > 1:
|
| 191 |
+
atom_array_stack = _build_atom_array_stack(xyz, sequence_indices)
|
| 192 |
+
if trajectory:
|
| 193 |
+
trajectory_stack = _build_trajectory_stack(trajectory, sequence_indices)
|
| 194 |
+
|
| 195 |
+
# Build PDB string
|
| 196 |
+
if output_type == "pdb":
|
| 197 |
+
pdb_string = self._coords_to_pdb(
|
| 198 |
+
xyz[0],
|
| 199 |
+
sequence_indices[0] if sequence_indices is not None else None,
|
| 200 |
+
)
|
| 201 |
+
|
| 202 |
+
# Write to disk
|
| 203 |
+
if output_path is not None:
|
| 204 |
+
import os
|
| 205 |
+
os.makedirs(os.path.dirname(output_path) or ".", exist_ok=True)
|
| 206 |
+
|
| 207 |
+
if output_type in ("cif", "cif.gz"):
|
| 208 |
+
to_write = atom_array_stack if atom_array_stack is not None else atom_array
|
| 209 |
+
base = output_path.rsplit(".", 1)[0] if "." in output_path else output_path
|
| 210 |
+
to_cif_file(to_write, base, file_type=output_type, include_entity_poly=False)
|
| 211 |
+
if trajectory_stack is not None:
|
| 212 |
+
to_cif_file(trajectory_stack, base + "_trajectory", file_type=output_type, include_entity_poly=False)
|
| 213 |
+
elif output_type == "pdb" and pdb_string is not None:
|
| 214 |
+
with open(output_path, "w") as f:
|
| 215 |
+
f.write(pdb_string)
|
| 216 |
+
|
| 217 |
+
output = RFDiffusionPipelineOutput(
|
| 218 |
+
xyz=xyz,
|
| 219 |
+
atom_array=atom_array,
|
| 220 |
+
atom_array_stack=atom_array_stack,
|
| 221 |
+
trajectory_stack=trajectory_stack,
|
| 222 |
+
sequence_indices=sequence_indices,
|
| 223 |
+
sequence_logits=sequence_logits,
|
| 224 |
+
single=single,
|
| 225 |
+
pair=pair,
|
| 226 |
+
pdb_string=pdb_string,
|
| 227 |
+
trajectory=trajectory,
|
| 228 |
+
)
|
| 229 |
+
|
| 230 |
+
block_state.output = output
|
| 231 |
+
self.set_block_state(state, block_state)
|
| 232 |
+
return components, state
|
| 233 |
+
|
| 234 |
+
def _coords_to_pdb(
|
| 235 |
+
self,
|
| 236 |
+
xyz: torch.Tensor,
|
| 237 |
+
seq: Optional[torch.Tensor] = None,
|
| 238 |
+
) -> str:
|
| 239 |
+
"""Convert coordinates to PDB format string."""
|
| 240 |
+
xyz_np = xyz.detach().cpu().float().numpy()
|
| 241 |
+
L = xyz_np.shape[0]
|
| 242 |
+
|
| 243 |
+
if seq is not None:
|
| 244 |
+
seq_np = seq.cpu().numpy()
|
| 245 |
+
else:
|
| 246 |
+
seq_np = None
|
| 247 |
+
|
| 248 |
+
lines = []
|
| 249 |
+
atom_idx = 1
|
| 250 |
+
|
| 251 |
+
for i in range(L):
|
| 252 |
+
if seq_np is not None:
|
| 253 |
+
aa_idx = int(seq_np[i])
|
| 254 |
+
aa_name = AA_NAMES[aa_idx] if aa_idx < len(AA_NAMES) else "UNK"
|
| 255 |
+
else:
|
| 256 |
+
aa_name = "ALA"
|
| 257 |
+
|
| 258 |
+
if xyz_np.ndim == 2:
|
| 259 |
+
x, y, z = xyz_np[i, :]
|
| 260 |
+
line = (
|
| 261 |
+
f"ATOM  {atom_idx:5d}  CA  {aa_name:3s} A"
|
| 262 |
+
f"{i+1:4d}    {x:8.3f}{y:8.3f}{z:8.3f}  1.00  0.00           C  "
|
| 263 |
+
)
|
| 264 |
+
lines.append(line)
|
| 265 |
+
atom_idx += 1
|
| 266 |
+
else:
|
| 267 |
+
for j, atom_name in enumerate(["N", "CA", "C"]):
|
| 268 |
+
if j >= xyz_np.shape[1]:
|
| 269 |
+
break
|
| 270 |
+
x, y, z = xyz_np[i, j, :]
|
| 271 |
+
|
| 272 |
+
line = (
|
| 273 |
+
f"ATOM  {atom_idx:5d}  {atom_name:<3s} {aa_name:3s} A"
|
| 274 |
+
f"{i+1:4d}    {x:8.3f}{y:8.3f}{z:8.3f}  1.00  0.00          "
|
| 275 |
+
f"{atom_name[0]:>2s} "
|
| 276 |
+
)
|
| 277 |
+
lines.append(line)
|
| 278 |
+
atom_idx += 1
|
| 279 |
+
|
| 280 |
+
lines.append("END")
|
| 281 |
+
return "\n".join(lines)
|
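Since `_coords_to_pdb` above builds ATOM records by hand, it may help to see the fixed-column layout spelled out. The sketch below formats a CA-only trace using the standard PDB column positions (serial in columns 7-11, atom name in 13-16, residue name in 18-20, chain in 22, residue number in 23-26, coordinates in 31-54, element in 77-78). The helper names `format_ca_record` and `ca_trace_to_pdb` are hypothetical and not part of this commit.

```python
from typing import List, Sequence, Tuple

Vec3 = Tuple[float, float, float]


def format_ca_record(
    serial: int,
    res_name: str,
    res_seq: int,
    xyz: Vec3,
    chain_id: str = "A",
) -> str:
    """Format one CA atom as a fixed-width PDB ATOM record (78 columns)."""
    x, y, z = xyz
    return (
        f"ATOM  {serial:5d}  CA  {res_name:>3s} {chain_id}{res_seq:4d}    "
        f"{x:8.3f}{y:8.3f}{z:8.3f}{1.0:6.2f}{0.0:6.2f}"
        f"{'':10s}{'C':>2s}"  # columns 67-76 blank, element right-justified in 77-78
    )


def ca_trace_to_pdb(coords: Sequence[Vec3], res_names: Sequence[str]) -> str:
    """Join one ATOM record per residue and terminate with END."""
    lines: List[str] = [
        format_ca_record(i, name, i, xyz)
        for i, (xyz, name) in enumerate(zip(coords, res_names), start=1)
    ]
    lines.append("END")
    return "\n".join(lines)


pdb_text = ca_trace_to_pdb([(0.0, 0.0, 0.0), (3.8, 0.0, 0.0)], ["GLY", "SER"])
```

Because PDB parsers read by column rather than by whitespace, keeping every field at its fixed width matters more than how the line looks when printed.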
denoise.py
ADDED
|
@@ -0,0 +1,193 @@
|
| 1 |
+
# Copyright 2025 Dhruv Nair. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
"""
|
| 16 |
+
Denoising loop for RFDiffusion3.
|
| 17 |
+
|
| 18 |
+
Implements the iterative denoising procedure from the original
|
| 19 |
+
inference_sampler.py (SampleDiffusionWithMotif.sample_diffusion_like_af3).
|
| 20 |
+
|
| 21 |
+
The loop iterates over consecutive pairs (c_t_minus_1, c_t) in the noise schedule:
|
| 22 |
+
1. Inject stochastic noise: t_hat = c_t_minus_1 * (gamma + 1), epsilon ~ N(0, noise_scale * sqrt(t_hat^2 - c_t_minus_1^2))
|
| 23 |
+
2. Call model: X_denoised = model(X_noisy, t_hat)
|
| 24 |
+
3. Euler update: X = X_noisy + step_scale * (c_t - t_hat) * (X_noisy - X_denoised) / t_hat
|
| 25 |
+
"""
|
| 26 |
+
|
| 27 |
+
from typing import Callable, List
|
| 28 |
+
|
| 29 |
+
import torch
|
| 30 |
+
|
| 31 |
+
from diffusers.utils import logging
|
| 32 |
+
from diffusers.modular_pipelines import ModularPipeline, ModularPipelineBlocks, PipelineState
|
| 33 |
+
from diffusers.modular_pipelines.modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
logger = logging.get_logger(__name__)
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
class RFDiffusionDenoiseStep(ModularPipelineBlocks):
|
| 40 |
+
"""
|
| 41 |
+
Iterative denoising step for RFDiffusion3.
|
| 42 |
+
|
| 43 |
+
Implements the EDM stochastic sampling loop matching the original
|
| 44 |
+
SampleDiffusionWithMotif.sample_diffusion_like_af3.
|
| 45 |
+
"""
|
| 46 |
+
|
| 47 |
+
model_name = "rfdiffusion"
|
| 48 |
+
|
| 49 |
+
@property
|
| 50 |
+
def description(self) -> str:
|
| 51 |
+
return (
|
| 52 |
+
"Iteratively denoise protein structure through reverse diffusion. "
|
| 53 |
+
"Uses EDM stochastic sampling with gamma noise injection and step scaling."
|
| 54 |
+
)
|
| 55 |
+
|
| 56 |
+
@property
|
| 57 |
+
def expected_components(self) -> List[ComponentSpec]:
|
| 58 |
+
return [
|
| 59 |
+
ComponentSpec("transformer", description="RFDiffusion transformer for structure prediction"),
|
| 60 |
+
ComponentSpec("scheduler", description="Scheduler for noise injection and stepping"),
|
| 61 |
+
]
|
| 62 |
+
|
| 63 |
+
@property
|
| 64 |
+
def inputs(self) -> List[InputParam]:
|
| 65 |
+
return [
|
| 66 |
+
InputParam(
|
| 67 |
+
"n_recycle",
|
| 68 |
+
default=None,
|
| 69 |
+
type_hint=int,
|
| 70 |
+
description="Number of recycling iterations (None uses model default)",
|
| 71 |
+
),
|
| 72 |
+
InputParam(
|
| 73 |
+
"callback",
|
| 74 |
+
type_hint=Callable,
|
| 75 |
+
description="Optional callback function called at each step",
|
| 76 |
+
),
|
| 77 |
+
InputParam(
|
| 78 |
+
"callback_steps",
|
| 79 |
+
default=1,
|
| 80 |
+
type_hint=int,
|
| 81 |
+
description="Frequency of callback invocation",
|
| 82 |
+
),
|
| 83 |
+
InputParam("xyz", required=True, type_hint=torch.Tensor, description="Initial noised coordinates [D, L, 3]"),
|
| 84 |
+
InputParam("noise_schedule", required=True, type_hint=torch.Tensor, description="EDM noise schedule"),
|
| 85 |
+
InputParam("motif_mask", required=True, type_hint=torch.Tensor, description="Mask for fixed motif positions"),
|
| 86 |
+
]
|
| 87 |
+
|
| 88 |
+
@property
|
| 89 |
+
def intermediate_outputs(self) -> List[OutputParam]:
|
| 90 |
+
return [
|
| 91 |
+
OutputParam("xyz", type_hint=torch.Tensor, description="Denoised coordinates [D, L, 3]"),
|
| 92 |
+
OutputParam("single", type_hint=torch.Tensor, description="Single representation"),
|
| 93 |
+
OutputParam("pair", type_hint=torch.Tensor, description="Pair representation"),
|
| 94 |
+
OutputParam("sequence_logits", type_hint=torch.Tensor, description="Predicted sequence logits"),
|
| 95 |
+
OutputParam("sequence_indices", type_hint=torch.Tensor, description="Predicted sequence indices"),
|
| 96 |
+
OutputParam("trajectory", type_hint=List[torch.Tensor], description="Denoising trajectory"),
|
| 97 |
+
]
|
| 98 |
+
|
| 99 |
+
@torch.no_grad()
|
| 100 |
+
def __call__(self, components: ModularPipeline, state: PipelineState) -> PipelineState:
|
| 101 |
+
block_state = self.get_block_state(state)
|
| 102 |
+
|
| 103 |
+
xyz = block_state.xyz
|
| 104 |
+
noise_schedule = block_state.noise_schedule
|
| 105 |
+
motif_mask = block_state.motif_mask
|
| 106 |
+
|
| 107 |
+
n_recycle = block_state.n_recycle
|
| 108 |
+
callback = block_state.callback
|
| 109 |
+
callback_steps = block_state.callback_steps or 1
|
| 110 |
+
|
| 111 |
+
X_denoised_L_traj = []
|
| 112 |
+
X_L = xyz.clone()
|
| 113 |
+
D = X_L.shape[0]
|
| 114 |
+
device = X_L.device
|
| 115 |
+
|
| 116 |
+
# Ensure all tensors are on the same device as xyz
|
| 117 |
+
noise_schedule = noise_schedule.to(device)
|
| 118 |
+
if motif_mask is not None:
|
| 119 |
+
motif_mask = motif_mask.to(device)
|
| 120 |
+
|
| 121 |
+
single = None
|
| 122 |
+
pair = None
|
| 123 |
+
sequence_logits = None
|
| 124 |
+
sequence_indices = None
|
| 125 |
+
|
| 126 |
+
has_transformer = hasattr(components, "transformer") and components.transformer is not None
|
| 127 |
+
has_scheduler = hasattr(components, "scheduler") and components.scheduler is not None
|
| 128 |
+
|
| 129 |
+
# Iterate over consecutive pairs (c_t_minus_1, c_t) in the noise schedule
|
| 130 |
+
# noise_schedule goes from high noise to low noise
|
| 131 |
+
for step_num in range(len(noise_schedule) - 1):
|
| 132 |
+
c_t_minus_1 = noise_schedule[step_num]
|
| 133 |
+
c_t = noise_schedule[step_num + 1]
|
| 134 |
+
|
| 135 |
+
# Step 1: Inject stochastic noise (matching original sampler)
|
| 136 |
+
if has_scheduler:
|
| 137 |
+
X_noisy_L, t_hat = components.scheduler.add_noise(
|
| 138 |
+
X_L, c_t_minus_1, c_t, motif_mask=motif_mask
|
| 139 |
+
)
|
| 140 |
+
else:
|
| 141 |
+
X_noisy_L = X_L
|
| 142 |
+
t_hat = c_t_minus_1
|
| 143 |
+
|
| 144 |
+
# Step 2: Model forward pass
|
| 145 |
+
if has_transformer:
|
| 146 |
+
# t_hat is a scalar, tile to batch dimension
|
| 147 |
+
t_batch = (t_hat.to(device).expand(D) if isinstance(t_hat, torch.Tensor)
|
| 148 |
+
else torch.full((D,), t_hat, device=device))
|
| 149 |
+
|
| 150 |
+
output = components.transformer(
|
| 151 |
+
xyz_noisy=X_noisy_L,
|
| 152 |
+
t=t_batch,
|
| 153 |
+
motif_mask=motif_mask,
|
| 154 |
+
n_recycle=n_recycle,
|
| 155 |
+
)
|
| 156 |
+
|
| 157 |
+
X_denoised_L = output.xyz
|
| 158 |
+
single = output.single
|
| 159 |
+
pair = output.pair
|
| 160 |
+
sequence_logits = output.sequence_logits
|
| 161 |
+
sequence_indices = output.sequence_indices
|
| 162 |
+
else:
|
| 163 |
+
X_denoised_L = X_noisy_L
|
| 164 |
+
|
| 165 |
+
# Step 3: Euler update with step_scale (matching original sampler)
|
| 166 |
+
if has_scheduler:
|
| 167 |
+
X_L = components.scheduler.step(
|
| 168 |
+
xyz_pred=X_denoised_L,
|
| 169 |
+
xyz_noisy=X_noisy_L,
|
| 170 |
+
c_t_minus_1=c_t_minus_1,
|
| 171 |
+
c_t=c_t,
|
| 172 |
+
motif_mask=motif_mask,
|
| 173 |
+
)
|
| 174 |
+
else:
|
| 175 |
+
# Fallback simple Euler step
|
| 176 |
+
delta_L = (X_noisy_L - X_denoised_L) / (t_hat + 1e-8)
|
| 177 |
+
d_t = c_t - t_hat
|
| 178 |
+
X_L = X_noisy_L + d_t * delta_L
|
| 179 |
+
|
| 180 |
+
X_denoised_L_traj.append(X_denoised_L.clone())
|
| 181 |
+
|
| 182 |
+
if callback is not None and step_num % callback_steps == 0:
|
| 183 |
+
callback(step_num, c_t_minus_1, X_L)
|
| 184 |
+
|
| 185 |
+
block_state.xyz = X_L
|
| 186 |
+
block_state.single = single
|
| 187 |
+
block_state.pair = pair
|
| 188 |
+
block_state.sequence_logits = sequence_logits
|
| 189 |
+
block_state.sequence_indices = sequence_indices
|
| 190 |
+
block_state.trajectory = X_denoised_L_traj
|
| 191 |
+
|
| 192 |
+
self.set_block_state(state, block_state)
|
| 193 |
+
return components, state
|
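The three steps listed in the `denoise.py` module docstring can be traced on a scalar to check the arithmetic. The sketch below is a minimal stand-in, assuming a one-dimensional state and a caller-supplied denoiser. The scheduler's actual `add_noise` and `step` methods are not shown in this diff, so `edm_sample` and its defaults are illustrative only. The parameter names `gamma`, `noise_scale`, and `step_scale` follow the docstring.

```python
import math
import random
from typing import Callable, Sequence


def edm_sample(
    x: float,
    schedule: Sequence[float],
    denoise: Callable[[float, float], float],
    gamma: float = 0.0,
    noise_scale: float = 1.0,
    step_scale: float = 1.0,
    seed: int = 0,
) -> float:
    """Run the stochastic Euler loop over consecutive (c_prev, c_t) pairs."""
    rng = random.Random(seed)
    for c_prev, c_t in zip(schedule[:-1], schedule[1:]):
        # Step 1: raise the noise level to t_hat and inject matching noise.
        t_hat = c_prev * (gamma + 1.0)
        sigma = noise_scale * math.sqrt(max(t_hat**2 - c_prev**2, 0.0))
        x_noisy = x + sigma * rng.gauss(0.0, 1.0)
        # Step 2: ask the model for its estimate of the clean sample.
        x_denoised = denoise(x_noisy, t_hat)
        # Step 3: Euler update from t_hat down to c_t.
        x = x_noisy + step_scale * (c_t - t_hat) * (x_noisy - x_denoised) / t_hat
    return x


# Toy denoiser that always predicts 2.0, standing in for the transformer.
final = edm_sample(7.0, [10.0, 5.0, 1.0, 0.0], lambda x_noisy, t: 2.0)
```

With `gamma = 0` and `step_scale = 1`, each update multiplies the distance to the denoiser's prediction by `c_t / c_prev`, so a schedule that ends at zero lands exactly on the prediction.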
modular_blocks.py
ADDED
|
@@ -0,0 +1,413 @@
|
| 1 |
+
# Copyright 2025 Dhruv Nair. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
from dataclasses import dataclass
|
| 16 |
+
from typing import List, Optional
|
| 17 |
+
|
| 18 |
+
import torch
|
| 19 |
+
|
| 20 |
+
from diffusers.utils import logging
|
| 21 |
+
from diffusers.modular_pipelines import (
|
| 22 |
+
AutoPipelineBlocks,
|
| 23 |
+
ModularPipeline,
|
| 24 |
+
ModularPipelineBlocks,
|
| 25 |
+
PipelineState,
|
| 26 |
+
SequentialPipelineBlocks,
|
| 27 |
+
)
|
| 28 |
+
from diffusers.modular_pipelines.modular_pipeline_utils import ComponentSpec, InputParam, InsertableDict, OutputParam
|
| 29 |
+
|
| 30 |
+
from .before_denoise import (
|
| 31 |
+
RFDiffusionInputStep,
|
| 32 |
+
RFDiffusionPrepareLatentsStep,
|
| 33 |
+
RFDiffusionSetTimestepsStep,
|
| 34 |
+
)
|
| 35 |
+
from .decoders import RFDiffusionDecodeStep
|
| 36 |
+
from .denoise import RFDiffusionDenoiseStep
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
logger = logging.get_logger(__name__)
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
# ─── Amino acid mappings (used by MPNN blocks) ─────────────────────────
|
| 43 |
+
|
| 44 |
+
THREE_TO_ONE = {
|
| 45 |
+
"ALA": "A", "ARG": "R", "ASN": "N", "ASP": "D", "CYS": "C",
|
| 46 |
+
"GLN": "Q", "GLU": "E", "GLY": "G", "HIS": "H", "ILE": "I",
|
| 47 |
+
"LEU": "L", "LYS": "K", "MET": "M", "PHE": "F", "PRO": "P",
|
| 48 |
+
"SER": "S", "THR": "T", "TRP": "W", "TYR": "Y", "VAL": "V",
|
| 49 |
+
"UNK": "X",
|
| 50 |
+
}
|
| 51 |
+
|
| 52 |
+
AA_NAMES = list(THREE_TO_ONE.keys())
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
# ───────────────────────────────────────────────────────────────────────────
|
| 56 |
+
# RFDiffusion blocks
|
| 57 |
+
# ───────────────────────────────────────────────────────────────────────────
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
class RFDiffusionBeforeDenoiseStep(SequentialPipelineBlocks):
|
| 61 |
+
"""Sequential block for pre-denoising preparation."""
|
| 62 |
+
|
| 63 |
+
block_classes = [
|
| 64 |
+
RFDiffusionInputStep,
|
| 65 |
+
RFDiffusionSetTimestepsStep,
|
| 66 |
+
RFDiffusionPrepareLatentsStep,
|
| 67 |
+
]
|
| 68 |
+
block_names = ["input", "set_timesteps", "prepare_latents"]
|
| 69 |
+
|
| 70 |
+
@property
|
| 71 |
+
def description(self):
|
| 72 |
+
return (
|
| 73 |
+
"Before denoise step that prepares the inputs for the denoise step.\n"
|
| 74 |
+
"This is a sequential pipeline block:\n"
|
| 75 |
+
" - `RFDiffusionInputStep` processes contigs and prepares input features\n"
|
| 76 |
+
" - `RFDiffusionSetTimestepsStep` sets up the diffusion timesteps\n"
|
| 77 |
+
" - `RFDiffusionPrepareLatentsStep` initializes noised coordinates\n"
|
| 78 |
+
)
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
class RFDiffusionAutoBeforeDenoiseStep(AutoPipelineBlocks):
|
| 82 |
+
"""Auto-select before denoise step based on task."""
|
| 83 |
+
|
| 84 |
+
block_classes = [RFDiffusionBeforeDenoiseStep]
|
| 85 |
+
block_names = ["unconditional"]
|
| 86 |
+
block_trigger_inputs = [None]
|
| 87 |
+
|
| 88 |
+
@property
|
| 89 |
+
def description(self):
|
| 90 |
+
return (
|
| 91 |
+
"Before denoise step that prepares the inputs for the denoise step.\n"
|
| 92 |
+
"This is an auto pipeline block for protein structure generation.\n"
|
| 93 |
+
" - `RFDiffusionBeforeDenoiseStep` (unconditional) is used.\n"
|
| 94 |
+
)
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
class RFDiffusionAutoDenoiseStep(AutoPipelineBlocks):
|
| 98 |
+
"""Auto-select denoise step."""
|
| 99 |
+
|
| 100 |
+
block_classes = [RFDiffusionDenoiseStep]
|
| 101 |
+
block_names = ["denoise"]
|
| 102 |
+
block_trigger_inputs = [None]
|
| 103 |
+
|
| 104 |
+
@property
|
| 105 |
+
def description(self) -> str:
|
| 106 |
+
return (
|
| 107 |
+
"Denoise step that iteratively denoises the protein structure.\n"
|
| 108 |
+
"This is an auto pipeline block for protein structure generation.\n"
|
| 109 |
+
" - `RFDiffusionDenoiseStep` (denoise) for structure generation.\n"
|
| 110 |
+
)
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
class RFDiffusionAutoDecodeStep(AutoPipelineBlocks):
|
| 114 |
+
"""Auto-select decode step."""
|
| 115 |
+
|
| 116 |
+
block_classes = [RFDiffusionDecodeStep]
|
| 117 |
+
block_names = ["decode"]
|
| 118 |
+
block_trigger_inputs = [None]
|
| 119 |
+
|
| 120 |
+
@property
|
| 121 |
+
def description(self):
|
| 122 |
+
return "Decode step that converts denoised coordinates to the requested output format (tensor, PDB, or CIF).\n - `RFDiffusionDecodeStep`"
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
# ───────────────────────────────────────────────────────────────────────────
|
| 126 |
+
# MPNN blocks (defined before RFDiffusionAutoBlocks which references them)
|
| 127 |
+
# ───────────────────────────────────────────────────────────────────────────
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
@dataclass
|
| 131 |
+
class MPNNPipelineOutput:
|
| 132 |
+
"""Output from ProteinMPNN / LigandMPNN sequence design."""
|
| 133 |
+
|
| 134 |
+
designed_sequence: str
|
| 135 |
+
sequence_indices: torch.Tensor # [B, L] token indices
|
| 136 |
+
sequence_logits: torch.Tensor # [B, L, n_vocab] logits
|
| 137 |
+
xyz: torch.Tensor # [B, L, 3] input structure (passed through)
|
| 138 |
+
pdb_string: Optional[str] = None # PDB with designed sequence
|
| 139 |
+
sequence_recovery: Optional[float] = None
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
class MPNNSequenceDesignStep(ModularPipelineBlocks):
    """
    Design sequences for a protein backbone using ProteinMPNN / LigandMPNN.

    Takes ``xyz`` coordinates (typically from an upstream RFDiffusion denoise
    step) and runs the ``MPNNModel`` to produce amino acid sequences for
    the designable regions.

    When no ``mpnn`` component is loaded, falls back to using the sequence
    predictions from upstream RFDiffusion (or glycine everywhere).
    """

    model_name = "mpnn"

    @property
    def description(self) -> str:
        return (
            "Design amino acid sequences for protein backbones using "
            "ProteinMPNN or LigandMPNN. Accepts structure coordinates "
            "from an upstream diffusion step."
        )

    @property
    def expected_components(self) -> List[ComponentSpec]:
        return [
            ComponentSpec("mpnn", description="MPNNModel (ProteinMPNN or LigandMPNN)"),
        ]

    @property
    def inputs(self) -> List[InputParam]:
        return [
            InputParam(
                "xyz", required=True, type_hint=torch.Tensor,
                description="Protein backbone coordinates [B, L, 3] (CA atoms)",
            ),
            InputParam(
                "motif_mask", type_hint=torch.Tensor,
                description="Mask for fixed/motif positions [L]. True = fixed sequence.",
            ),
            InputParam(
                "sequence_indices", type_hint=torch.Tensor,
                description="Known sequence indices for motif positions [B, L] (from RFDiffusion)",
            ),
            InputParam(
                "temperature", default=0.1, type_hint=float,
                description="Sampling temperature (lower = more deterministic)",
            ),
            InputParam(
                "num_designs", default=1, type_hint=int,
                description="Number of sequence designs to generate per structure",
            ),
            InputParam(
                "output_type", default="tensor", type_hint=str,
                description="'tensor', 'pdb', or 'cif'",
            ),
            InputParam(
                "output_path", type_hint=str,
                description="Path to save designed PDB",
            ),
        ]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        return [
            OutputParam(
                "mpnn_output", type_hint=MPNNPipelineOutput,
                description="MPNN sequence design output",
            ),
        ]

    @torch.no_grad()
    def __call__(self, components: ModularPipeline, state: PipelineState) -> PipelineState:
        block_state = self.get_block_state(state)

        xyz = block_state.xyz
        motif_mask = block_state.motif_mask
        known_seq = block_state.sequence_indices
        temperature = block_state.temperature or 0.1
        output_type = block_state.output_type or "tensor"
        output_path = block_state.output_path

        B, L, _ = xyz.shape
        device = xyz.device

        has_mpnn = hasattr(components, "mpnn") and components.mpnn is not None

        if has_mpnn:
            sequence_logits, sequence_indices = self._run_mpnn(
                components.mpnn, xyz, motif_mask, known_seq, temperature,
            )
        else:
            if known_seq is not None:
                sequence_indices = known_seq
            else:
                sequence_indices = torch.full((B, L), 7, dtype=torch.long, device=device)  # GLY
            sequence_logits = torch.zeros(B, L, len(AA_NAMES), device=device)
            sequence_logits.scatter_(2, sequence_indices.unsqueeze(-1), 1.0)

        seq_list = sequence_indices[0].cpu().tolist()
        designed_sequence = "".join(
            THREE_TO_ONE.get(AA_NAMES[min(idx, len(AA_NAMES) - 1)], "X")
            for idx in seq_list
        )

        pdb_string = None
        if output_type in ("pdb",):
            pdb_string = self._coords_to_pdb(xyz[0], sequence_indices[0])
            if output_path:
                import os
                os.makedirs(os.path.dirname(output_path) or ".", exist_ok=True)
                with open(output_path, "w") as f:
                    f.write(pdb_string)

        output = MPNNPipelineOutput(
            designed_sequence=designed_sequence,
            sequence_indices=sequence_indices,
            sequence_logits=sequence_logits,
            xyz=xyz,
            pdb_string=pdb_string,
        )

        block_state.mpnn_output = output
        self.set_block_state(state, block_state)
        return components, state

    def _run_mpnn(self, mpnn, xyz, motif_mask, known_seq, temperature):
        """Run the MPNNModel wrapper on backbone coordinates."""
        B, L, _ = xyz.shape
        device = xyz.device
        dtype = xyz.dtype

        ca = xyz
        n_offset = torch.tensor([-1.458, 0.0, 0.0], device=device, dtype=dtype)
        c_offset = torch.tensor([0.550, 1.424, 0.0], device=device, dtype=dtype)
        o_offset = torch.tensor([0.550, 2.500, 0.0], device=device, dtype=dtype)

        X = torch.stack([
            ca + n_offset, ca, ca + c_offset, ca + o_offset,
        ], dim=2)

        if motif_mask is not None:
            designed_mask = ~motif_mask.unsqueeze(0).expand(B, -1)
        else:
            designed_mask = None

        output = mpnn(
            X=X, S=known_seq, designed_residue_mask=designed_mask, temperature=temperature,
        )

        logits = output.sequence_logits
        indices = output.sequence_indices

        if motif_mask is not None and known_seq is not None:
            indices[:, motif_mask] = known_seq[:, motif_mask]

        return logits, indices

    def _coords_to_pdb(self, xyz: torch.Tensor, seq: torch.Tensor) -> str:
        xyz_np = xyz.cpu().numpy()
        seq_np = seq.cpu().numpy()
        L = xyz_np.shape[0]
        lines = []
        for i in range(L):
            aa_idx = int(seq_np[i])
            aa_name = AA_NAMES[min(aa_idx, len(AA_NAMES) - 1)]
            x, y, z = xyz_np[i, :]
            # Fixed-column PDB ATOM record: coordinates occupy columns 31-54.
            lines.append(
                f"ATOM  {i+1:5d}  CA  {aa_name:3s} A{i+1:4d}    "
                f"{x:8.3f}{y:8.3f}{z:8.3f}  1.00  0.00           C  "
            )
        lines.append("END")
        return "\n".join(lines)


class MPNNAutoDesignStep(AutoPipelineBlocks):
    """Auto-select MPNN design step."""

    block_classes = [MPNNSequenceDesignStep]
    block_names = ["sequence_design"]
    block_trigger_inputs = [None]

    @property
    def description(self) -> str:
        return "Sequence design using ProteinMPNN or LigandMPNN."


# ───────────────────────────────────────────────────────────────────────────
# Top-level pipeline blocks
# ───────────────────────────────────────────────────────────────────────────


class RFDiffusionAutoBlocks(SequentialPipelineBlocks):
    """
    Full protein design pipeline: RFDiffusion3 + optional ProteinMPNN/LigandMPNN.

    The active workflow is selected by trigger inputs:
    - ``contigs`` only → structure generation
    - ``contigs`` + ``temperature`` → structure + sequence design
    - ``contigs`` + ``input_xyz`` + ``temperature`` → motif-conditioned + sequence design

    The MPNN step is skipped when ``temperature`` is not provided or when
    no ``mpnn`` component is loaded.
    """

    block_classes = [
        RFDiffusionAutoBeforeDenoiseStep,
        RFDiffusionAutoDenoiseStep,
        RFDiffusionAutoDecodeStep,
        MPNNAutoDesignStep,
    ]
    block_names = [
        "before_denoise",
        "denoise",
        "decoder",
        "sequence_design",
    ]

    _workflow_map = {
        "structure_only": {
            "contigs": True,
        },
        "structure_and_sequence": {
            "contigs": True,
            "temperature": True,
        },
        "motif_structure_and_sequence": {
            "contigs": True,
            "input_xyz": True,
            "temperature": True,
        },
    }

    @property
    def description(self):
        return (
            "Modular pipeline for protein design using RFDiffusion3.\n"
            "Workflows:\n"
            " - structure_only: backbone generation\n"
            " - structure_and_sequence: backbone + MPNN sequence design\n"
            " - motif_structure_and_sequence: motif-conditioned + MPNN\n"
        )


# ───────────────────────────────────────────────────────────────────────────
# Block registries
# ───────────────────────────────────────────────────────────────────────────


UNCONDITIONAL_BLOCKS = InsertableDict(
    [
        ("input", RFDiffusionInputStep),
        ("set_timesteps", RFDiffusionSetTimestepsStep),
        ("prepare_latents", RFDiffusionPrepareLatentsStep),
        ("denoise", RFDiffusionDenoiseStep),
        ("decode", RFDiffusionDecodeStep),
        ("sequence_design", MPNNSequenceDesignStep),
    ]
)

AUTO_BLOCKS = InsertableDict(
    [
        ("before_denoise", RFDiffusionAutoBeforeDenoiseStep),
        ("denoise", RFDiffusionAutoDenoiseStep),
        ("decode", RFDiffusionAutoDecodeStep),
        ("sequence_design", MPNNAutoDesignStep),
    ]
)

ALL_BLOCKS = {
    "unconditional": UNCONDITIONAL_BLOCKS,
    "auto": AUTO_BLOCKS,
}
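The sequence decoding and CA-only PDB output in `MPNNSequenceDesignStep` can be tried without torch. The sketch below is a minimal, standalone illustration under stated assumptions: `AA_NAMES` and `THREE_TO_ONE` are defined elsewhere in the repository and are not shown in this diff, so the tables here are hypothetical stand-ins (the source only confirms that index 7 is GLY), and `indices_to_sequence` and `ca_atom_record` are illustrative helper names, not functions from the repository.

```python
# Hypothetical vocabulary; the diff only confirms that index 7 maps to GLY.
AA_NAMES = [
    "ALA", "ARG", "ASN", "ASP", "CYS", "GLN", "GLU", "GLY", "HIS", "ILE",
    "LEU", "LYS", "MET", "PHE", "PRO", "SER", "THR", "TRP", "TYR", "VAL",
    "UNK",
]
THREE_TO_ONE = {
    "ALA": "A", "ARG": "R", "ASN": "N", "ASP": "D", "CYS": "C",
    "GLN": "Q", "GLU": "E", "GLY": "G", "HIS": "H", "ILE": "I",
    "LEU": "L", "LYS": "K", "MET": "M", "PHE": "F", "PRO": "P",
    "SER": "S", "THR": "T", "TRP": "W", "TYR": "Y", "VAL": "V",
}


def indices_to_sequence(indices):
    """Map token indices to one-letter codes.

    Indices past the end of AA_NAMES are clamped to the last entry, and any
    name without a one-letter code becomes "X", mirroring the block above.
    """
    return "".join(
        THREE_TO_ONE.get(AA_NAMES[min(idx, len(AA_NAMES) - 1)], "X")
        for idx in indices
    )


def ca_atom_record(serial, res_name, coord, chain="A"):
    """Format one CA atom as an 80-column PDB ATOM record."""
    x, y, z = coord
    return (
        f"ATOM  {serial:5d}  CA  {res_name:3s} {chain}{serial:4d}    "
        f"{x:8.3f}{y:8.3f}{z:8.3f}  1.00  0.00           C  "
    )


sequence = indices_to_sequence([7, 0, 11, 99])
record = ca_atom_record(1, "GLY", (1.0, -2.5, 3.25))
print(sequence)
print(record)
```

Because PDB is a fixed-column format, the widths in the f-string matter: downstream parsers read coordinates from columns 31-54 rather than splitting on whitespace.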
modular_config.json
ADDED
@@ -0,0 +1,7 @@
{
  "_class_name": "RFDiffusionAutoBlocks",
  "_diffusers_version": "0.37.0.dev0",
  "auto_map": {
    "ModularPipelineBlocks": "modular_blocks.RFDiffusionAutoBlocks"
  }
}
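The `auto_map` value above reads as a `module.ClassName` reference into the repository's own `modular_blocks.py`. As a rough sketch of how such an entry could be split, assuming that dotted convention holds (the helper `split_auto_map_entry` is hypothetical and is not how diffusers itself resolves the entry):

```python
import json

# The config shown above, inlined so the example is self-contained.
config_text = """
{
  "_class_name": "RFDiffusionAutoBlocks",
  "_diffusers_version": "0.37.0.dev0",
  "auto_map": {
    "ModularPipelineBlocks": "modular_blocks.RFDiffusionAutoBlocks"
  }
}
"""


def split_auto_map_entry(entry):
    """Split 'module.ClassName' into (module, class_name). Hypothetical helper."""
    module_name, _, class_name = entry.rpartition(".")
    if not module_name or not class_name:
        raise ValueError(f"Expected 'module.ClassName', got {entry!r}")
    return module_name, class_name


config = json.loads(config_text)
module_name, class_name = split_auto_map_entry(
    config["auto_map"]["ModularPipelineBlocks"]
)
print(module_name, class_name)
```

In this config the class half of the entry matches `_class_name`, which makes a cheap consistency check when editing the file by hand.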
modular_model_index.json
ADDED
@@ -0,0 +1,37 @@
{
  "_blocks_class_name": "RFDiffusionAutoBlocks",
  "_class_name": "ModularPipeline",
  "_diffusers_version": "0.37.0.dev0",
  "default_num_residues": 100,
  "default_num_timesteps": 200,
  "sigma_data": 16.0,
  "transformer": [
    null,
    null,
    {
      "pretrained_model_name_or_path": "dn6/RFDiffusion-3",
      "subfolder": "transformer",
      "type_hint": [
        "diffusers",
        "AutoModel"
      ],
      "revision": null,
      "variant": null
    }
  ],
  "scheduler": [
    null,
    null,
    {
      "pretrained_model_name_or_path": "dn6/RFDiffusion-3",
      "subfolder": "scheduler",
      "type_hint": [
        "diffusers",
        "AutoModel"
      ],
      "revision": null,
      "variant": null,
      "default_creation_method": "from_config"
    }
  ]
}
mpnn/__init__.py
ADDED
@@ -0,0 +1,15 @@
# Copyright 2025 Dhruv Nair. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .model_mpnn import MPNNModel, MPNNModelOutput
mpnn/config.json
ADDED
@@ -0,0 +1,17 @@
{
  "_class_name": "MPNNModel",
  "_diffusers_version": "0.37.0.dev0",
  "auto_map": {
    "AutoModel": "model_mpnn.MPNNModel"
  },
  "model_type": "protein_mpnn",
  "hidden_dim": 128,
  "num_encoder_layers": 3,
  "num_decoder_layers": 3,
  "num_neighbors": 48,
  "dropout_rate": 0.1,
  "num_positional_embeddings": 16,
  "min_rbf_mean": 2.0,
  "max_rbf_mean": 22.0,
  "num_rbf": 16
}
mpnn/diffusion_pytorch_model.safetensors
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:6f94bbbaa904c3554d3b10397dce3c90e01adb725fd80e3275ff48f11cc4745f
size 6653924
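The three lines above are a Git LFS pointer, not the weights themselves: the pointer records the spec version, the SHA-256 of the real file, and its size in bytes. A minimal sketch of reading one, where `parse_lfs_pointer` is a hypothetical helper name and the input is the pointer text from this diff:

```python
def parse_lfs_pointer(text):
    """Parse a Git LFS pointer file into a dict of its fields."""
    # Each pointer line is "<key> <value>", separated by a single space.
    fields = dict(line.split(" ", 1) for line in text.strip().splitlines())
    hash_algo, _, digest = fields["oid"].partition(":")
    return {
        "version": fields["version"],
        "hash_algo": hash_algo,
        "digest": digest,
        "size": int(fields["size"]),
    }


pointer_text = """\
version https://git-lfs.github.com/spec/v1
oid sha256:6f94bbbaa904c3554d3b10397dce3c90e01adb725fd80e3275ff48f11cc4745f
size 6653924
"""

pointer = parse_lfs_pointer(pointer_text)
print(pointer["hash_algo"], pointer["size"])
```

After downloading the actual `.safetensors` file, comparing its byte length and SHA-256 against these two fields is a quick integrity check.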
mpnn/model_mpnn.py
ADDED
@@ -0,0 +1,178 @@
# Copyright 2025 Dhruv Nair. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
ProteinMPNN / LigandMPNN model wrapper.

A thin diffusers-compatible wrapper around the foundry MPNN model,
following the same pattern as the transformer and scheduler wrappers.
Reuses the foundry model implementation directly, adding only the
ModelMixin/ConfigMixin interface for diffusers integration.
"""

from dataclasses import dataclass
from typing import Optional

import torch
import torch.nn as nn

from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.modeling_utils import ModelMixin

from mpnn.model.mpnn import LigandMPNN, ProteinMPNN


MODEL_CLASSES = {
    "protein_mpnn": ProteinMPNN,
    "ligand_mpnn": LigandMPNN,
}


@dataclass
class MPNNModelOutput:
    """Output from the MPNN model wrapper."""

    sequence_logits: torch.Tensor  # [B, L, n_vocab]
    sequence_indices: torch.Tensor  # [B, L]
    decoder_features: dict  # full decoder output dict


class MPNNModel(ModelMixin, ConfigMixin):
    """
    Diffusers-compatible wrapper around the foundry ProteinMPNN / LigandMPNN.

    Wraps `mpnn.model.mpnn.ProteinMPNN` (or `LigandMPNN`) to provide a
    diffusers ModelMixin/ConfigMixin interface. All model logic is delegated
    to the foundry implementation.

    State dict keys match the foundry checkpoint format via the `model.*`
    prefix (stripped on load).
    """

    config_name = "config.json"

    @register_to_config
    def __init__(
        self,
        model_type: str = "protein_mpnn",
        hidden_dim: int = 128,
        num_encoder_layers: int = 3,
        num_decoder_layers: int = 3,
        num_neighbors: int = 48,
        dropout_rate: float = 0.1,
        num_positional_embeddings: int = 16,
        min_rbf_mean: float = 2.0,
        max_rbf_mean: float = 22.0,
        num_rbf: int = 16,
        # LigandMPNN-specific
        num_context_atoms: int = 25,
        num_context_encoding_layers: int = 2,
    ):
        super().__init__()

        model_cls = MODEL_CLASSES.get(model_type)
        if model_cls is None:
            raise ValueError(
                f"Unknown model_type '{model_type}'. "
                f"Choose from: {list(MODEL_CLASSES.keys())}"
            )

        common_kwargs = dict(
            num_node_features=hidden_dim,
            num_edge_features=hidden_dim,
            hidden_dim=hidden_dim,
            num_encoder_layers=num_encoder_layers,
            num_decoder_layers=num_decoder_layers,
            num_neighbors=num_neighbors,
            dropout_rate=dropout_rate,
            num_positional_embeddings=num_positional_embeddings,
            min_rbf_mean=min_rbf_mean,
            max_rbf_mean=max_rbf_mean,
            num_rbf=num_rbf,
        )

        if model_type == "ligand_mpnn":
            common_kwargs["num_context_atoms"] = num_context_atoms
            common_kwargs["num_context_encoding_layers"] = num_context_encoding_layers

        self.model = model_cls(**common_kwargs)

    def forward(
        self,
        X: torch.Tensor,
        S: Optional[torch.Tensor] = None,
        residue_mask: Optional[torch.Tensor] = None,
        designed_residue_mask: Optional[torch.Tensor] = None,
        chain_labels: Optional[torch.Tensor] = None,
        R_idx: Optional[torch.Tensor] = None,
        temperature: float = 0.1,
        **kwargs,
    ) -> MPNNModelOutput:
        """
        Run ProteinMPNN / LigandMPNN sequence design.

        Args:
            X: Backbone atom coordinates [B, L, num_atoms, 3].
                For ProteinMPNN: num_atoms=4 (N, CA, C, O).
            S: Ground-truth sequence tokens [B, L] (optional, for teacher forcing).
            residue_mask: Valid residue mask [B, L] (default: all valid).
            designed_residue_mask: Which residues to design [B, L] (default: all).
            chain_labels: Chain identifiers [B, L] (default: single chain).
            R_idx: Residue indices [B, L] (default: 0..L-1).
            temperature: Sampling temperature (default: 0.1).

        Returns:
            MPNNModelOutput with sequence logits and sampled indices.
        """
        B, L = X.shape[0], X.shape[1]
        device = X.device

        if S is None:
            S = torch.zeros(B, L, dtype=torch.long, device=device)
        if residue_mask is None:
            residue_mask = torch.ones(B, L, dtype=torch.bool, device=device)
        if designed_residue_mask is None:
            designed_residue_mask = torch.ones(B, L, dtype=torch.bool, device=device)
        if chain_labels is None:
            chain_labels = torch.zeros(B, L, dtype=torch.long, device=device)
        if R_idx is None:
            R_idx = torch.arange(L, device=device).unsqueeze(0).expand(B, -1)

        # Atom mask: mark all atoms as valid based on coordinate presence
        X_m = (X.abs().sum(dim=-1) > 0).float()  # [B, L, num_atoms]

        network_input = {
            "X": X,
            "X_m": X_m,
            "S": S,
            "R_idx": R_idx,
            "chain_labels": chain_labels,
            "residue_mask": residue_mask,
            "designed_residue_mask": designed_residue_mask,
            "temperature": temperature,
            **kwargs,
        }

        output = self.model(network_input)

        logits = output["decoder_features"]["logits"]  # [B, L, n_vocab]
        S_sampled = output["decoder_features"].get(
            "S_sampled", logits.argmax(dim=-1)
        )

        return MPNNModelOutput(
            sequence_logits=logits,
            sequence_indices=S_sampled,
            decoder_features=output["decoder_features"],
        )
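The atom mask in `forward` is derived purely from the coordinates: an atom counts as present when the sum of the absolute values of its x, y, z is greater than zero. The sketch below restates that rule on plain nested lists so it can be run without torch; `atom_presence_mask` is a hypothetical helper name, and the coordinates are made up for illustration.

```python
def atom_presence_mask(coords):
    """Return a [L][num_atoms] mask of 1.0/0.0 from [L][num_atoms][3] coordinates.

    Pure-Python analogue of (X.abs().sum(dim=-1) > 0).float() for one structure.
    """
    return [
        [1.0 if sum(abs(c) for c in atom) > 0 else 0.0 for atom in residue]
        for residue in coords
    ]


# Two residues with four backbone atoms each (N, CA, C, O); values are made up.
coords = [
    [[-1.458, 0.0, 0.0], [0.0, 0.0, 0.0], [0.55, 1.424, 0.0], [0.55, 2.5, 0.0]],
    [[2.3, 1.0, 0.5], [3.8, 1.0, 0.5], [4.3, 2.4, 0.5], [4.3, 3.5, 0.5]],
]

mask = atom_presence_mask(coords)
print(mask)
```

One consequence worth keeping in mind: an atom that genuinely sits at the origin, like the CA of the first residue here, is indistinguishable from a missing atom under this rule, so structures centred on one of their own atoms may lose it from the mask.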
mpnn_ligand/config.json
ADDED
@@ -0,0 +1,19 @@
{
  "_class_name": "MPNNModel",
  "_diffusers_version": "0.37.0.dev0",
  "auto_map": {
    "AutoModel": "model_mpnn.MPNNModel"
  },
  "model_type": "ligand_mpnn",
  "hidden_dim": 128,
  "num_encoder_layers": 3,
  "num_decoder_layers": 3,
  "num_neighbors": 32,
  "dropout_rate": 0.1,
  "num_positional_embeddings": 16,
  "min_rbf_mean": 2.0,
  "max_rbf_mean": 22.0,
  "num_rbf": 16,
  "num_context_atoms": 25,
  "num_context_encoding_layers": 2
}
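Compared with `mpnn/config.json`, this config switches `model_type` to `ligand_mpnn`, lowers `num_neighbors` from 48 to 32, and adds two context keys. The sketch below mirrors how `MPNNModel.__init__` turns such a config into constructor keyword arguments, without importing torch or the foundry `mpnn` package; `build_model_kwargs` is a hypothetical stand-in for that logic, not a function in the repository.

```python
SHARED_KEYS = (
    "num_encoder_layers", "num_decoder_layers", "num_neighbors",
    "dropout_rate", "num_positional_embeddings",
    "min_rbf_mean", "max_rbf_mean", "num_rbf",
)
LIGAND_ONLY_KEYS = ("num_context_atoms", "num_context_encoding_layers")
KNOWN_MODEL_TYPES = ("protein_mpnn", "ligand_mpnn")


def build_model_kwargs(config):
    """Build constructor kwargs from a config dict, as the wrapper's __init__ does."""
    model_type = config["model_type"]
    if model_type not in KNOWN_MODEL_TYPES:
        raise ValueError(
            f"Unknown model_type '{model_type}'. Choose from: {list(KNOWN_MODEL_TYPES)}"
        )
    hidden_dim = config["hidden_dim"]
    # Node and edge feature widths are both tied to hidden_dim.
    kwargs = {
        "num_node_features": hidden_dim,
        "num_edge_features": hidden_dim,
        "hidden_dim": hidden_dim,
    }
    kwargs.update({key: config[key] for key in SHARED_KEYS})
    if model_type == "ligand_mpnn":
        # Context keys are forwarded only for the ligand-aware model.
        kwargs.update({key: config[key] for key in LIGAND_ONLY_KEYS})
    return kwargs


shared = {
    "hidden_dim": 128, "num_encoder_layers": 3, "num_decoder_layers": 3,
    "dropout_rate": 0.1, "num_positional_embeddings": 16,
    "min_rbf_mean": 2.0, "max_rbf_mean": 22.0, "num_rbf": 16,
}
protein_config = {"model_type": "protein_mpnn", "num_neighbors": 48, **shared}
ligand_config = {
    "model_type": "ligand_mpnn", "num_neighbors": 32,
    "num_context_atoms": 25, "num_context_encoding_layers": 2, **shared,
}

protein_kwargs = build_model_kwargs(protein_config)
ligand_kwargs = build_model_kwargs(ligand_config)
print(sorted(set(ligand_kwargs) - set(protein_kwargs)))
```

Keeping the ligand-only keys out of the ProteinMPNN call is what allows a single wrapper class and a single `model_mpnn.py` to serve all three `mpnn*` subfolders.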
mpnn_ligand/diffusion_pytorch_model.safetensors
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:63d94e809089821f2f01c965b8f3a5e736d6b9b53c378fc27d42f47376e18964
size 10497908
mpnn_ligand/model_mpnn.py
ADDED
@@ -0,0 +1,178 @@
# Copyright 2025 Dhruv Nair. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
ProteinMPNN / LigandMPNN model wrapper.

A thin diffusers-compatible wrapper around the foundry MPNN model,
following the same pattern as the transformer and scheduler wrappers.
Reuses the foundry model implementation directly, adding only the
ModelMixin/ConfigMixin interface for diffusers integration.
"""

from dataclasses import dataclass
from typing import Optional

import torch
import torch.nn as nn

from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.modeling_utils import ModelMixin

from mpnn.model.mpnn import LigandMPNN, ProteinMPNN


MODEL_CLASSES = {
    "protein_mpnn": ProteinMPNN,
    "ligand_mpnn": LigandMPNN,
}


@dataclass
class MPNNModelOutput:
    """Output from the MPNN model wrapper."""

    sequence_logits: torch.Tensor  # [B, L, n_vocab]
    sequence_indices: torch.Tensor  # [B, L]
    decoder_features: dict  # full decoder output dict


class MPNNModel(ModelMixin, ConfigMixin):
    """
    Diffusers-compatible wrapper around the foundry ProteinMPNN / LigandMPNN.

    Wraps `mpnn.model.mpnn.ProteinMPNN` (or `LigandMPNN`) to provide a
    diffusers ModelMixin/ConfigMixin interface. All model logic is delegated
    to the foundry implementation.

    State dict keys match the foundry checkpoint format via the `model.*`
    prefix (stripped on load).
    """

    config_name = "config.json"

    @register_to_config
    def __init__(
        self,
        model_type: str = "protein_mpnn",
        hidden_dim: int = 128,
        num_encoder_layers: int = 3,
        num_decoder_layers: int = 3,
        num_neighbors: int = 48,
        dropout_rate: float = 0.1,
        num_positional_embeddings: int = 16,
        min_rbf_mean: float = 2.0,
        max_rbf_mean: float = 22.0,
        num_rbf: int = 16,
        # LigandMPNN-specific
        num_context_atoms: int = 25,
        num_context_encoding_layers: int = 2,
    ):
        super().__init__()

        model_cls = MODEL_CLASSES.get(model_type)
        if model_cls is None:
            raise ValueError(
                f"Unknown model_type '{model_type}'. "
                f"Choose from: {list(MODEL_CLASSES.keys())}"
            )

        common_kwargs = dict(
            num_node_features=hidden_dim,
            num_edge_features=hidden_dim,
            hidden_dim=hidden_dim,
            num_encoder_layers=num_encoder_layers,
            num_decoder_layers=num_decoder_layers,
            num_neighbors=num_neighbors,
            dropout_rate=dropout_rate,
            num_positional_embeddings=num_positional_embeddings,
            min_rbf_mean=min_rbf_mean,
            max_rbf_mean=max_rbf_mean,
            num_rbf=num_rbf,
        )

        if model_type == "ligand_mpnn":
            common_kwargs["num_context_atoms"] = num_context_atoms
            common_kwargs["num_context_encoding_layers"] = num_context_encoding_layers

        self.model = model_cls(**common_kwargs)

    def forward(
        self,
        X: torch.Tensor,
        S: Optional[torch.Tensor] = None,
        residue_mask: Optional[torch.Tensor] = None,
        designed_residue_mask: Optional[torch.Tensor] = None,
        chain_labels: Optional[torch.Tensor] = None,
        R_idx: Optional[torch.Tensor] = None,
        temperature: float = 0.1,
        **kwargs,
    ) -> MPNNModelOutput:
        """
        Run ProteinMPNN / LigandMPNN sequence design.

        Args:
            X: Backbone atom coordinates [B, L, num_atoms, 3].
                For ProteinMPNN: num_atoms=4 (N, CA, C, O).
            S: Ground-truth sequence tokens [B, L] (optional, for teacher forcing).
            residue_mask: Valid residue mask [B, L] (default: all valid).
            designed_residue_mask: Which residues to design [B, L] (default: all).
            chain_labels: Chain identifiers [B, L] (default: single chain).
            R_idx: Residue indices [B, L] (default: 0..L-1).
            temperature: Sampling temperature (default: 0.1).

        Returns:
            MPNNModelOutput with sequence logits and sampled indices.
        """
        B, L = X.shape[0], X.shape[1]
        device = X.device

        if S is None:
            S = torch.zeros(B, L, dtype=torch.long, device=device)
        if residue_mask is None:
            residue_mask = torch.ones(B, L, dtype=torch.bool, device=device)
        if designed_residue_mask is None:
            designed_residue_mask = torch.ones(B, L, dtype=torch.bool, device=device)
        if chain_labels is None:
            chain_labels = torch.zeros(B, L, dtype=torch.long, device=device)
        if R_idx is None:
            R_idx = torch.arange(L, device=device).unsqueeze(0).expand(B, -1)

        # Atom mask: mark all atoms as valid based on coordinate presence
        X_m = (X.abs().sum(dim=-1) > 0).float()  # [B, L, num_atoms]

        network_input = {
            "X": X,
            "X_m": X_m,
            "S": S,
            "R_idx": R_idx,
            "chain_labels": chain_labels,
            "residue_mask": residue_mask,
            "designed_residue_mask": designed_residue_mask,
            "temperature": temperature,
            **kwargs,
        }

        output = self.model(network_input)

        logits = output["decoder_features"]["logits"]  # [B, L, n_vocab]
        S_sampled = output["decoder_features"].get(
            "S_sampled", logits.argmax(dim=-1)
        )

        return MPNNModelOutput(
            sequence_logits=logits,
            sequence_indices=S_sampled,
            decoder_features=output["decoder_features"],
        )
mpnn_soluble/config.json
ADDED
@@ -0,0 +1,17 @@
{
  "_class_name": "MPNNModel",
  "_diffusers_version": "0.37.0.dev0",
  "auto_map": {
    "AutoModel": "model_mpnn.MPNNModel"
  },
  "model_type": "protein_mpnn",
  "hidden_dim": 128,
  "num_encoder_layers": 3,
  "num_decoder_layers": 3,
  "num_neighbors": 48,
  "dropout_rate": 0.1,
  "num_positional_embeddings": 16,
  "min_rbf_mean": 2.0,
  "max_rbf_mean": 22.0,
  "num_rbf": 16
}
mpnn_soluble/diffusion_pytorch_model.safetensors
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:1d7ce5491820a26c7cfae3ab46421c40d1af645accb6c10b4818243eea2f0165
size 6653924
mpnn_soluble/model_mpnn.py
ADDED
@@ -0,0 +1,178 @@
# Copyright 2025 Dhruv Nair. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
ProteinMPNN / LigandMPNN model wrapper.

A thin diffusers-compatible wrapper around the foundry MPNN model,
following the same pattern as the transformer and scheduler wrappers.
Reuses the foundry model implementation directly, adding only the
ModelMixin/ConfigMixin interface for diffusers integration.
"""

from dataclasses import dataclass
from typing import Optional

import torch
import torch.nn as nn

from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.modeling_utils import ModelMixin

from mpnn.model.mpnn import LigandMPNN, ProteinMPNN


MODEL_CLASSES = {
    "protein_mpnn": ProteinMPNN,
    "ligand_mpnn": LigandMPNN,
}


@dataclass
class MPNNModelOutput:
    """Output from the MPNN model wrapper."""

    sequence_logits: torch.Tensor  # [B, L, n_vocab]
    sequence_indices: torch.Tensor  # [B, L]
    decoder_features: dict  # full decoder output dict


class MPNNModel(ModelMixin, ConfigMixin):
    """
    Diffusers-compatible wrapper around the foundry ProteinMPNN / LigandMPNN.

    Wraps `mpnn.model.mpnn.ProteinMPNN` (or `LigandMPNN`) to provide a
    diffusers ModelMixin/ConfigMixin interface. All model logic is delegated
    to the foundry implementation.

    State dict keys match the foundry checkpoint format via the `model.*`
    prefix (stripped on load).
    """

    config_name = "config.json"

    @register_to_config
    def __init__(
        self,
        model_type: str = "protein_mpnn",
        hidden_dim: int = 128,
        num_encoder_layers: int = 3,
        num_decoder_layers: int = 3,
        num_neighbors: int = 48,
        dropout_rate: float = 0.1,
        num_positional_embeddings: int = 16,
        min_rbf_mean: float = 2.0,
        max_rbf_mean: float = 22.0,
        num_rbf: int = 16,
        # LigandMPNN-specific
        num_context_atoms: int = 25,
        num_context_encoding_layers: int = 2,
    ):
        super().__init__()

        model_cls = MODEL_CLASSES.get(model_type)
        if model_cls is None:
            raise ValueError(
                f"Unknown model_type '{model_type}'. "
                f"Choose from: {list(MODEL_CLASSES.keys())}"
            )

        common_kwargs = dict(
            num_node_features=hidden_dim,
            num_edge_features=hidden_dim,
            hidden_dim=hidden_dim,
            num_encoder_layers=num_encoder_layers,
            num_decoder_layers=num_decoder_layers,
            num_neighbors=num_neighbors,
            dropout_rate=dropout_rate,
            num_positional_embeddings=num_positional_embeddings,
            min_rbf_mean=min_rbf_mean,
            max_rbf_mean=max_rbf_mean,
            num_rbf=num_rbf,
        )

        if model_type == "ligand_mpnn":
            common_kwargs["num_context_atoms"] = num_context_atoms
            common_kwargs["num_context_encoding_layers"] = num_context_encoding_layers

        self.model = model_cls(**common_kwargs)

    def forward(
        self,
        X: torch.Tensor,
        S: Optional[torch.Tensor] = None,
        residue_mask: Optional[torch.Tensor] = None,
        designed_residue_mask: Optional[torch.Tensor] = None,
        chain_labels: Optional[torch.Tensor] = None,
        R_idx: Optional[torch.Tensor] = None,
        temperature: float = 0.1,
        **kwargs,
    ) -> MPNNModelOutput:
        """
        Run ProteinMPNN / LigandMPNN sequence design.

        Args:
            X: Backbone atom coordinates [B, L, num_atoms, 3].
                For ProteinMPNN: num_atoms=4 (N, CA, C, O).
            S: Ground-truth sequence tokens [B, L] (optional, for teacher forcing).
            residue_mask: Valid residue mask [B, L] (default: all valid).
            designed_residue_mask: Which residues to design [B, L] (default: all).
            chain_labels: Chain identifiers [B, L] (default: single chain).
            R_idx: Residue indices [B, L] (default: 0..L-1).
            temperature: Sampling temperature (default: 0.1).

        Returns:
            MPNNModelOutput with sequence logits and sampled indices.
        """
        B, L = X.shape[0], X.shape[1]
        device = X.device

        if S is None:
            S = torch.zeros(B, L, dtype=torch.long, device=device)
        if residue_mask is None:
            residue_mask = torch.ones(B, L, dtype=torch.bool, device=device)
        if designed_residue_mask is None:
            designed_residue_mask = torch.ones(B, L, dtype=torch.bool, device=device)
        if chain_labels is None:
            chain_labels = torch.zeros(B, L, dtype=torch.long, device=device)
        if R_idx is None:
            R_idx = torch.arange(L, device=device).unsqueeze(0).expand(B, -1)

        # Atom mask: mark all atoms as valid based on coordinate presence
        X_m = (X.abs().sum(dim=-1) > 0).float()  # [B, L, num_atoms]

        network_input = {
            "X": X,
            "X_m": X_m,
            "S": S,
            "R_idx": R_idx,
            "chain_labels": chain_labels,
            "residue_mask": residue_mask,
            "designed_residue_mask": designed_residue_mask,
            "temperature": temperature,
            **kwargs,
        }

        output = self.model(network_input)

        logits = output["decoder_features"]["logits"]  # [B, L, n_vocab]
        S_sampled = output["decoder_features"].get(
            "S_sampled", logits.argmax(dim=-1)
        )

        return MPNNModelOutput(
            sequence_logits=logits,
            sequence_indices=S_sampled,
            decoder_features=output["decoder_features"],
        )
|
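The wrapper forwards `temperature` to the foundry model without showing how it is applied. A common convention, assumed here rather than confirmed by this file, is to divide the logits by the temperature before the softmax, so that low values such as the default 0.1 concentrate probability on the highest-scoring residue type. The function names below are hypothetical.

```python
import math
import random
from typing import List, Optional


def softmax_with_temperature(logits: List[float], temperature: float) -> List[float]:
    """Softmax over logits / temperature (assumed sampling convention)."""
    if temperature <= 0:
        raise ValueError("temperature must be positive")
    scaled = [x / temperature for x in logits]
    peak = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(x - peak) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def sample_token(
    logits: List[float],
    temperature: float,
    rng: Optional[random.Random] = None,
) -> int:
    """Draw one token index from the temperature-scaled distribution."""
    rng = rng or random.Random()
    probs = softmax_with_temperature(logits, temperature)
    return rng.choices(range(len(probs)), weights=probs, k=1)[0]


logits = [2.0, 1.0, 0.5]
sharp = softmax_with_temperature(logits, 0.1)  # close to greedy decoding
flat = softmax_with_temperature(logits, 1.0)   # plain softmax
print(sharp[0], flat[0])
```

Under this convention the `logits.argmax(dim=-1)` fallback in `forward` corresponds to the limit of temperature approaching zero.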
pyproject.toml
ADDED
@@ -0,0 +1,66 @@
[build-system]
requires = ["setuptools>=61.0", "wheel"]
build-backend = "setuptools.build_meta"

[project]
name = "modular-rfdiffusion"
version = "0.1.0"
description = "Modular Diffusers Pipeline for RFDiffusion protein structure generation"
readme = "README.md"
license = {text = "Apache-2.0"}
requires-python = ">=3.8"
authors = [
    {name = "Dhruv Nair"}
]
keywords = ["diffusion", "protein", "rfdiffusion", "deep-learning", "pytorch"]
classifiers = [
    "Development Status :: 3 - Alpha",
    "Intended Audience :: Science/Research",
    "License :: OSI Approved :: Apache Software License",
    "Programming Language :: Python :: 3",
    "Programming Language :: Python :: 3.8",
    "Programming Language :: Python :: 3.9",
    "Programming Language :: Python :: 3.10",
    "Programming Language :: Python :: 3.11",
    "Topic :: Scientific/Engineering :: Artificial Intelligence",
    "Topic :: Scientific/Engineering :: Bio-Informatics",
]

dependencies = [
    "torch>=2.0.0",
    "numpy>=1.21.0",
    "diffusers>=0.25.0",
    "huggingface-hub>=0.20.0",
    "scipy>=1.7.0",
    "hydra-core>=1.0.0",
    "omegaconf>=2.0.0",
]

[project.optional-dependencies]
dev = [
    "pytest>=7.0.0",
    "pytest-cov>=4.0.0",
    "black>=23.0.0",
    "ruff>=0.1.0",
    "mypy>=1.0.0",
]

[project.urls]
Homepage = "https://github.com/DN6/modular-diffusers"
Repository = "https://github.com/DN6/modular-diffusers"

[tool.setuptools.packages.find]
where = ["."]
include = ["*"]

[tool.black]
line-length = 119
target-version = ["py38", "py39", "py310", "py311"]

[tool.ruff]
line-length = 119
target-version = "py38"

[tool.ruff.lint]
select = ["E", "F", "W", "I", "N"]
ignore = ["E501"]
|
scheduler/__init__.py
ADDED
@@ -0,0 +1,15 @@
# Copyright 2025 Dhruv Nair. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .model import RFDiffusionScheduler
|
scheduler/config.json
ADDED
@@ -0,0 +1,16 @@
{
  "_class_name": "RFDiffusionScheduler",
  "_diffusers_version": "0.37.0.dev0",
  "auto_map": {
    "AutoModel": "model.RFDiffusionScheduler"
  },
  "num_timesteps": 200,
  "sigma_data": 16.0,
  "s_min": 4e-4,
  "s_max": 160.0,
  "p": 7.0,
  "gamma_0": 0.6,
  "gamma_min": 1.0,
  "noise_scale": 1.003,
  "step_scale": 1.5
}
|
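The `sigma_data`, `s_min`, `s_max` and `p` values above parameterize an EDM-style noise schedule. The sketch below assumes the interpolation popularized by AlphaFold 3, where the schedule is a power-law blend between `s_max` and `s_min` scaled by `sigma_data`. The foundry `_construct_inference_noise_schedule` is not shown in this commit, so its exact formula and endpoint handling are assumptions, and `edm_noise_schedule` is a hypothetical name.

```python
import math
from typing import List


def edm_noise_schedule(
    num_timesteps: int = 200,
    sigma_data: float = 16.0,
    s_min: float = 4e-4,
    s_max: float = 160.0,
    p: float = 7.0,
) -> List[float]:
    """Noise levels from high to low, one per timestep (assumed formula).

    c(tau) = sigma_data * (s_max^(1/p) + tau * (s_min^(1/p) - s_max^(1/p)))^p
    with tau evenly spaced on [0, 1].
    """
    inv_p = 1.0 / p
    hi, lo = s_max ** inv_p, s_min ** inv_p
    schedule = []
    for i in range(num_timesteps):
        tau = i / (num_timesteps - 1)
        schedule.append(sigma_data * (hi + tau * (lo - hi)) ** p)
    return schedule


schedule = edm_noise_schedule()
print(schedule[0], schedule[-1])
```

Under this assumption the first level is `sigma_data * s_max` (2560 with these defaults) and the last is `sigma_data * s_min` (0.0064), with most steps spent at low noise because of the exponent `p`.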
scheduler/model.py
ADDED
@@ -0,0 +1,152 @@
# Copyright 2025 Dhruv Nair. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
RFDiffusion3 Scheduler.

A thin diffusers-compatible wrapper around the foundry EDM noise schedule
and stochastic sampling logic from `rfd3.model.inference_sampler`.
"""

from typing import Optional, Tuple

import torch

from diffusers.configuration_utils import ConfigMixin, register_to_config

# Reuse the original noise schedule and sampling config directly
from rfd3.model.inference_sampler import SampleDiffusionWithMotif


class RFDiffusionScheduler(ConfigMixin):
    """
    Diffusers-compatible scheduler wrapping the foundry EDM sampler.

    Delegates noise schedule construction and sampling parameters to
    `rfd3.model.inference_sampler.SampleDiffusionWithMotif`.
    """

    config_name = "config.json"

    @register_to_config
    def __init__(
        self,
        num_timesteps: int = 200,
        sigma_data: float = 16.0,
        s_min: float = 4e-4,
        s_max: float = 160.0,
        p: float = 7.0,
        gamma_0: float = 0.6,
        gamma_min: float = 1.0,
        noise_scale: float = 1.003,
        step_scale: float = 1.5,
    ):
        # Instantiate the foundry sampler with matching parameters
        self._sampler = SampleDiffusionWithMotif(
            num_timesteps=num_timesteps,
            sigma_data=sigma_data,
            s_min=s_min,
            s_max=s_max,
            p=p,
            gamma_0=gamma_0,
            gamma_min=gamma_min,
            noise_scale=noise_scale,
            step_scale=step_scale,
        )

    @property
    def sampler(self) -> SampleDiffusionWithMotif:
        return self._sampler

    def get_noise_schedule(self, device: Optional[torch.device] = None) -> torch.Tensor:
        """
        Construct the EDM noise schedule using the foundry implementation.

        Returns:
            torch.Tensor: Noise schedule [num_timesteps] from high to low noise.
        """
        return self._sampler._construct_inference_noise_schedule(
            device=device or torch.device("cpu")
        )

    def get_initial_noise_level(self, device: Optional[torch.device] = None) -> torch.Tensor:
        """Get the first (largest) noise level from the schedule."""
        return self.get_noise_schedule(device=device)[0]

    def step(
        self,
        xyz_pred: torch.Tensor,
        xyz_noisy: torch.Tensor,
        c_t_minus_1: torch.Tensor,
        c_t: torch.Tensor,
        motif_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Perform one Euler denoising step matching the foundry sampler.

        The foundry ``sample_diffusion_like_af3`` does NOT clamp motif
        coordinates after the Euler update; it relies on noise injection
        having zeroed epsilon for motif atoms so the model's delta is ~0
        there. We replicate that behaviour here.

        Args:
            xyz_pred: Model's denoised prediction X_denoised_L [B, L, 3]
            xyz_noisy: Noise-injected coordinates X_noisy_L [B, L, 3]
            c_t_minus_1: Previous noise level
            c_t: Next (lower) noise level
            motif_mask: Boolean mask for fixed positions (True = fixed) [L]
                (unused; kept for API compatibility)

        Returns:
            Updated coordinates X_L [B, L, 3]
        """
        gamma = self._sampler.gamma_0 if c_t > self._sampler.gamma_min else 0.0
        t_hat = c_t_minus_1 * (gamma + 1.0)

        delta_L = (xyz_noisy - xyz_pred) / t_hat
        d_t = c_t - t_hat
        xyz_next = xyz_noisy + self._sampler.step_scale * d_t * delta_L

        return xyz_next

    def add_noise(
        self,
        xyz: torch.Tensor,
        c_t_minus_1: torch.Tensor,
        c_t: torch.Tensor,
        motif_mask: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Inject stochastic noise before the model call, matching the foundry sampler.

        Args:
            xyz: Current coordinates X_L [B, L, 3]
            c_t_minus_1: Previous noise level
            c_t: Current (next lower) noise level
            motif_mask: Boolean mask for fixed positions (True = fixed) [L]

        Returns:
            Tuple of (noisy coordinates X_noisy_L, t_hat scalar)
        """
        gamma = self._sampler.gamma_0 if c_t > self._sampler.gamma_min else 0.0
        t_hat = c_t_minus_1 * (gamma + 1.0)

        noise_std = self._sampler.noise_scale * torch.sqrt(t_hat**2 - c_t_minus_1**2)
        epsilon = noise_std * torch.randn_like(xyz)

        if motif_mask is not None:
            epsilon[:, motif_mask] = 0.0

        xyz_noisy = xyz + epsilon
        return xyz_noisy, t_hat
|
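The arithmetic in `add_noise` and `step` can be followed on a single scalar coordinate. The sketch below re-implements both formulas in pure Python with the default `gamma_0`, `gamma_min` and `noise_scale` from the scheduler config. The function names are hypothetical and the scalar setting is a simplification of the tensor code above, not the foundry sampler itself.

```python
import math
import random
from typing import Tuple

GAMMA_0 = 0.6
GAMMA_MIN = 1.0
NOISE_SCALE = 1.003


def t_hat_for(c_prev: float, c_next: float) -> float:
    """Inflated noise level: gamma is only active while c_next > GAMMA_MIN."""
    gamma = GAMMA_0 if c_next > GAMMA_MIN else 0.0
    return c_prev * (gamma + 1.0)


def add_noise(x: float, c_prev: float, c_next: float, rng: random.Random) -> Tuple[float, float]:
    """Raise the noise level of x from c_prev to t_hat by adding Gaussian noise."""
    t_hat = t_hat_for(c_prev, c_next)
    noise_std = NOISE_SCALE * math.sqrt(t_hat ** 2 - c_prev ** 2)
    return x + noise_std * rng.gauss(0.0, 1.0), t_hat


def euler_step(x_pred: float, x_noisy: float, c_prev: float, c_next: float, step_scale: float) -> float:
    """One Euler update from noise level t_hat down to c_next."""
    t_hat = t_hat_for(c_prev, c_next)
    delta = (x_noisy - x_pred) / t_hat
    return x_noisy + step_scale * (c_next - t_hat) * delta


# Low-noise regime (c_next <= GAMMA_MIN): gamma is 0, so no noise is injected.
x_noisy, t_hat = add_noise(2.0, 0.5, 0.0, random.Random(0))

# Stepping all the way to noise level 0 with step_scale=1 lands on the prediction;
# step_scale=1.5 moves 1.5 times as far along the same direction.
exact = euler_step(x_pred=1.0, x_noisy=3.0, c_prev=0.5, c_next=0.0, step_scale=1.0)
overshoot = euler_step(x_pred=1.0, x_noisy=3.0, c_prev=0.5, c_next=0.0, step_scale=1.5)
print(x_noisy, t_hat, exact, overshoot)
```

In the high-noise regime, for example `c_prev=100.0` and `c_next=50.0`, `t_hat` becomes 160 and the injected noise has standard deviation `1.003 * sqrt(160**2 - 100**2)`, roughly 125.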
transformer/__init__.py
ADDED
@@ -0,0 +1,15 @@
# Copyright 2025 Dhruv Nair. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .model_rfdiffusion import RFDiffusionTransformerModel, RFDiffusionTransformerOutput
|
transformer/config.json
ADDED
@@ -0,0 +1,22 @@
{
  "_class_name": "RFDiffusionTransformerModel",
  "_diffusers_version": "0.37.0.dev0",
  "auto_map": {
    "AutoModel": "model_rfdiffusion.RFDiffusionTransformerModel"
  },
  "c_s": 384,
  "c_z": 128,
  "c_atom": 128,
  "c_atompair": 16,
  "c_token": 768,
  "c_t_embed": 256,
  "sigma_data": 16.0,
  "n_pairformer_block": 2,
  "n_diffusion_block": 18,
  "n_atom_encoder_block": 3,
  "n_atom_decoder_block": 3,
  "n_head": 16,
  "n_pairformer_head": 16,
  "n_recycle": 2,
  "p_drop": 0.0
}
|
transformer/diffusion_pytorch_model.safetensors
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:1fe11cd4fe730654e0659d5218228a9d27bdcdac4c880faad993feaa532e99eb
size 672279561
|
transformer/model_rfdiffusion.py
ADDED
@@ -0,0 +1,297 @@
# Copyright 2025 Dhruv Nair. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
RFDiffusion3 Transformer model.

A thin diffusers-compatible wrapper around the existing foundry RFD3
model components. Reuses all layer implementations from `rfd3.model.layers.*`
and `foundry.model.layers.*` directly, adding only the ModelMixin/ConfigMixin
interface needed for diffusers ModularPipeline integration.
"""

from dataclasses import dataclass
from typing import Optional

import torch
import torch.nn as nn

from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.modeling_utils import ModelMixin

# Reuse all components from the foundry/rfd3 package directly
from rfd3.model.RFD3_diffusion_module import RFD3DiffusionModule
from rfd3.model.layers.encoders import TokenInitializer


@dataclass
class RFDiffusionTransformerOutput:
    """Output class for RFDiffusion transformer."""

    xyz: torch.Tensor
    single: torch.Tensor
    pair: torch.Tensor
    sequence_logits: Optional[torch.Tensor] = None
    sequence_indices: Optional[torch.Tensor] = None


class RFDiffusionTransformerModel(ModelMixin, ConfigMixin):
    """
    Diffusers-compatible wrapper around the foundry RFD3 model.

    This wraps `rfd3.model.RFD3_diffusion_module.RFD3DiffusionModule` and
    `rfd3.model.layers.encoders.TokenInitializer` to provide a diffusers
    ModelMixin/ConfigMixin interface. All actual model logic is delegated
    to the foundry implementation.

    State dict keys match the foundry checkpoint format via the
    `token_initializer.*` and `diffusion_module.*` prefixes.
    """

    config_name = "config.json"
    _supports_gradient_checkpointing = True

    @register_to_config
    def __init__(
        self,
        c_s: int = 384,
        c_z: int = 128,
        c_atom: int = 128,
        c_atompair: int = 16,
        c_token: int = 768,
        c_t_embed: int = 256,
        sigma_data: float = 16.0,
        n_pairformer_block: int = 2,
        n_diffusion_block: int = 18,
        n_atom_encoder_block: int = 3,
        n_atom_decoder_block: int = 3,
        n_head: int = 16,
        n_pairformer_head: int = 16,
        n_recycle: int = 2,
        p_drop: float = 0.0,
    ):
        super().__init__()

        # ── Shared sub-configs matching rfd3_net.yaml exactly ────────────
        cross_attention_block = {
            "n_head": 4,
            "c_model": c_atom,
            "dropout": 0.0,
            "kq_norm": True,
        }
        downcast_cfg = {
            "method": "cross_attention",
            "cross_attention_block": cross_attention_block,
        }
        upcast_cfg = {
            "method": "cross_attention",
            "n_split": 3,
            "cross_attention_block": cross_attention_block,
        }

        # ── TokenInitializer (foundry) ───────────────────────────────────
        token_1d_features = {
            "ref_motif_token_type": 3,
            "restype": 32,
            "ref_plddt": 1,
            "is_non_loopy": 1,
        }
        atom_1d_features = {
            "ref_atom_name_chars": 256,
            "ref_element": 128,
            "ref_charge": 1,
            "ref_mask": 1,
            "ref_is_motif_atom_with_fixed_coord": 1,
            "ref_is_motif_atom_unindexed": 1,
            "has_zero_occupancy": 1,
            "ref_pos": 3,
            "ref_atomwise_rasa": 3,
            "active_donor": 1,
            "active_acceptor": 1,
            "is_atom_level_hotspot": 1,
        }

        self.token_initializer = TokenInitializer(
            c_s=c_s,
            c_z=c_z,
            c_atom=c_atom,
            c_atompair=c_atompair,
            relative_position_encoding={"r_max": 32, "s_max": 2},
            n_pairformer_blocks=n_pairformer_block,
            pairformer_block={
                "attention_pair_bias": {"n_head": n_pairformer_head, "kq_norm": True},
            },
            downcast=downcast_cfg,
            token_1d_features=token_1d_features,
            atom_1d_features=atom_1d_features,
            atom_transformer={
                "n_blocks": 0,
                "atom_transformer_block": {
                    "n_head": 4,
                    "kq_norm": True,
                    "no_residual_connection_between_attention_and_transition": False,
                    "dropout": 0.0,
                    "n_attn_seq_neighbours": 4,
                    "n_attn_keys": 128,
                },
            },
        )

        # ── RFD3DiffusionModule (foundry) ────────────────────────────────
        self.diffusion_module = RFD3DiffusionModule(
            c_atom=c_atom,
            c_atompair=c_atompair,
            c_token=c_token,
            c_s=c_s,
            c_z=c_z,
            c_t_embed=c_t_embed,
            sigma_data=sigma_data,
            f_pred="edm",
            n_attn_seq_neighbours=2,
            n_attn_keys=128,
            n_recycle=n_recycle,
            atom_attention_encoder={
                "n_blocks": n_atom_encoder_block,
                "atom_transformer_block": {
                    "n_head": 4,
                    "kq_norm": True,
                    "no_residual_connection_between_attention_and_transition": False,
                    "dropout": 0.0,
                },
            },
            diffusion_token_encoder={
                "sigma_data": sigma_data,
                "n_pairformer_blocks": n_pairformer_block,
                "pairformer_block": {
                    "attention_pair_bias": {"n_head": n_pairformer_head, "kq_norm": True},
                },
                "use_distogram": True,
                "use_self": True,
                "use_sinusoidal_distogram_embedder": False,
            },
            diffusion_transformer={
                "n_block": n_diffusion_block,
                "diffusion_transformer_block": {
                    "n_head": n_head,
                    "kq_norm": True,
                    "no_residual_connection_between_attention_and_transition": False,
                    "dropout": p_drop,
                },
            },
            atom_attention_decoder={
                "n_blocks": n_atom_decoder_block,
                "upcast": upcast_cfg,
                "downcast": downcast_cfg,
                "atom_transformer_block": {
                    "n_head": 4,
                    "kq_norm": True,
                    "no_residual_connection_between_attention_and_transition": False,
                    "dropout": p_drop,
                },
            },
            downcast=downcast_cfg,
        )

    @property
    def sigma_data(self) -> float:
        return self.diffusion_module.sigma_data

    def forward(
        self,
        xyz_noisy: torch.Tensor,
        t: torch.Tensor,
        f: Optional[dict] = None,
        motif_mask: Optional[torch.Tensor] = None,
        n_recycle: Optional[int] = None,
        **kwargs,
    ) -> RFDiffusionTransformerOutput:
        """
        Forward pass delegated to the foundry RFD3DiffusionModule.

        Args:
            xyz_noisy: Noisy atom coordinates [B, L, 3]
            t: Noise level / timestep [B]
            f: Feature dictionary (as expected by foundry). If None, a minimal
                feature dict is constructed from xyz_noisy and motif_mask.
            motif_mask: Mask for fixed motif atoms [L] (used when f is None)
            n_recycle: Number of recycling iterations

        Returns:
            RFDiffusionTransformerOutput with denoised coordinates and predictions
        """
        B, L, _ = xyz_noisy.shape

        # If caller provides a full feature dict, use the native foundry path
        if f is not None:
            initializer_outputs = self.token_initializer(f)
            outs = self.diffusion_module(
                X_noisy_L=xyz_noisy,
                t=t,
                f=f,
                n_recycle=n_recycle,
                **initializer_outputs,
            )
            return RFDiffusionTransformerOutput(
                xyz=outs["X_L"],
                single=torch.zeros(1),  # not directly exposed by foundry
                pair=torch.zeros(1),
                sequence_logits=outs.get("sequence_logits_I"),
                sequence_indices=outs.get("sequence_indices_I"),
            )

        # Simplified path: construct minimal feature dict and call dm.forward()
        # For unconditional generation, each residue has 1 atom (CA), so L = I
        device = xyz_noisy.device
        dtype = xyz_noisy.dtype

        if motif_mask is None:
            motif_mask = torch.zeros(L, dtype=torch.bool, device=device)
        else:
            motif_mask = motif_mask.to(device)

        # Construct minimal feature dict with all keys required by foundry
        f = {
            "atom_to_token_map": torch.arange(L, device=device),  # 1:1 atom-to-token
            "unindexing_pair_mask": torch.zeros(L, L, dtype=torch.bool, device=device),
            "is_ca": torch.ones(L, dtype=torch.bool, device=device),
            "is_motif_atom_with_fixed_coord": motif_mask,
            "is_motif_token_with_fully_fixed_coord": motif_mask,
        }

        # Zero-initialized TokenInitializer outputs (no conditioning features)
        Q_L_init = torch.zeros(L, self.config.c_atom, device=device, dtype=dtype)
|
| 274 |
+
C_L = torch.zeros(L, self.config.c_atom, device=device, dtype=dtype)
|
| 275 |
+
P_LL = torch.zeros(L, L, self.config.c_atompair, device=device, dtype=dtype)
|
| 276 |
+
S_I = torch.zeros(L, self.config.c_s, device=device, dtype=dtype)
|
| 277 |
+
Z_II = torch.zeros(L, L, self.config.c_z, device=device, dtype=dtype)
|
| 278 |
+
|
| 279 |
+
outs = self.diffusion_module(
|
| 280 |
+
X_noisy_L=xyz_noisy,
|
| 281 |
+
t=t,
|
| 282 |
+
f=f,
|
| 283 |
+
Q_L_init=Q_L_init,
|
| 284 |
+
C_L=C_L,
|
| 285 |
+
P_LL=P_LL,
|
| 286 |
+
S_I=S_I,
|
| 287 |
+
Z_II=Z_II,
|
| 288 |
+
n_recycle=n_recycle,
|
| 289 |
+
)
|
| 290 |
+
|
| 291 |
+
return RFDiffusionTransformerOutput(
|
| 292 |
+
xyz=outs["X_L"],
|
| 293 |
+
single=torch.zeros(1),
|
| 294 |
+
pair=torch.zeros(1),
|
| 295 |
+
sequence_logits=outs.get("sequence_logits_I"),
|
| 296 |
+
sequence_indices=outs.get("sequence_indices_I"),
|
| 297 |
+
)
|
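The simplified path in `forward` builds a fixed set of tensors whose shapes depend only on the atom count `L` and a few channel widths from the config. The sketch below summarizes those shapes in plain Python so they can be checked without loading the model. The helper name `simplified_path_shapes` and the channel widths in the usage lines are illustrative assumptions, not part of the repository. It also assumes one CA atom per residue, as the code comments state, and that `motif_mask` has shape `[L]`.

```python
from typing import Dict, Tuple

Shape = Tuple[int, ...]


def simplified_path_shapes(
    n_atoms: int, c_atom: int, c_atompair: int, c_s: int, c_z: int
) -> Dict[str, Shape]:
    """Return the shape of each tensor the simplified path constructs.

    Hypothetical helper for illustration. Assumes one atom (CA) per residue,
    so the atom count L equals the token count I.
    """
    L = n_atoms
    return {
        # Entries of the minimal feature dict `f`
        "atom_to_token_map": (L,),
        "unindexing_pair_mask": (L, L),
        "is_ca": (L,),
        "is_motif_atom_with_fixed_coord": (L,),
        "is_motif_token_with_fully_fixed_coord": (L,),
        # Zero-initialized stand-ins for the TokenInitializer outputs
        "Q_L_init": (L, c_atom),
        "C_L": (L, c_atom),
        "P_LL": (L, L, c_atompair),
        "S_I": (L, c_s),
        "Z_II": (L, L, c_z),
    }


# Channel widths here are placeholders chosen for illustration only;
# the real values come from the model config (c_atom, c_atompair, c_s, c_z).
shapes = simplified_path_shapes(n_atoms=5, c_atom=8, c_atompair=4, c_s=16, c_z=6)
for name, shape in shapes.items():
    print(f"{name}: {shape}")
```

The pair tensors (`unindexing_pair_mask`, `P_LL`, `Z_II`) grow quadratically with `L`, so this path allocates noticeably more memory as `L` increases even though every conditioning tensor is zero.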